## Takeaway #1. Defense performance is sensitive to the purity of the base set

In [1]:
from models import PreActResNet18
import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from meta_sift import *
import numpy as np
import math

# Device string for CUDA execution.
device = 'cuda'
# Make GPU 0 the current CUDA device for this process.
torch.cuda.set_device(0)
# set_seed is not imported by name in this cell; presumably it is provided by
# `from meta_sift import *`. NOTE(review): confirm which RNGs (python / numpy / torch) it seeds.
set_seed(0)

In [2]:
# Load the poisoned model.
# The checkpoint name indicates a GTSRB classifier backdoored with BadNets, target class 38.
# VGG is not imported by name; presumably it is exposed by `from meta_sift import *`.
poisoned_model = VGG('small_VGG16')
# Use the notebook-wide `device` setting instead of a second hard-coded "cuda" string,
# so the placement of the weights and of the module cannot drift apart.
poisoned_model.load_state_dict(
    torch.load('./checkpoints/gtsrb_tar38_badnets.pth', map_location=device)
)
poisoned_model = poisoned_model.to(device)

# Build the attack-success-rate (ASR) evaluation set.
data_transform = transforms.Compose([transforms.ToTensor()])
train_path = './dataset/gtsrb_dataset.h5'
# NOTE(review): the second positional flag looks like train (True) vs. test (False) — confirm in h5_dataset.
testset = h5_dataset(train_path, False, None)
# Only samples whose true label is not the target class (38) are passed as indices,
# so a prediction of 38 on them cannot be explained by the true label.
asr_test = posion_image_all2one(
    testset,
    list(np.where(np.array(testset.targets) != 38)[0]),
    38,
    data_transform,
)

In [3]:
# Attack success rate (ASR) of the untouched poisoned model on the ASR evaluation set.
baseline_asr = get_results(poisoned_model, asr_test)
print('Original Poison Model ASR: %.3f%%' % baseline_asr)

Original Poison Model ASR: 96.793%


In [4]:
# Build the poisoned training set from the same HDF5 file used for the test split.
trainset = h5_dataset(train_path, True, None)
# poi_dataset returns the poisoned dataset together with the poisoned-sample indices.
# The keyword `poi_methond` is spelled the way this call site passes it to poi_dataset;
# do not "correct" it here without changing the callee as well.
train_poi_set, poi_idx = poi_dataset(
    trainset,
    poi_methond="badnets",
    transform=data_transform,
    poi_rates=0.33,
    random_seed=0,
    tar_lab=38,
)
# Base set for the defense, built with get_validset's default size arguments.
# NOTE(review): the variable name suggests it contains no poisoned samples — confirm the defaults.
clean_validset = get_validset(train_poi_set, poi_idx)

In [5]:
# I-BAU is the defense used to clean the poisoned model in this section.
from ibau import IBAU

# NOTE(review): acctest is built here but not referenced by any other visible cell.
acctest = h5_dataset(train_path, False, data_transform)

# Hand I-BAU a deep copy so poisoned_model stays available for the later comparisons.
# (`copy` is not imported by name; presumably it comes from `from meta_sift import *`.)
cleaned_model = IBAU(copy.deepcopy(poisoned_model), clean_validset)
clean_base_asr = get_results(copy.deepcopy(cleaned_model), asr_test)
print('I-BAU (clean base set) Cleaned Model ASR: %.3f%%' % clean_base_asr)

Round: 0
Round: 1
Round: 2
Round: 3
Round: 4
I-BAU (clean base set) Cleaned Model ASR: 11.536%


In [6]:
# I-BAU with a contaminated base set: only 8 poisoned samples in a base set of 1000.
poi_base_set = get_validset(train_poi_set, poi_idx, 1000, 8)
# Start again from a fresh copy of the poisoned model so the two runs are comparable.
cleaned_model_dirty = IBAU(copy.deepcopy(poisoned_model), poi_base_set)
dirty_base_asr = get_results(copy.deepcopy(cleaned_model_dirty), asr_test)
print('I-BAU (poisoned base set) Cleaned Model ASR: %.3f%%' % dirty_base_asr)

Round: 0
Round: 1
Round: 2
Round: 3
Round: 4
I-BAU (poisoned base set) Cleaned Model ASR: 73.903%


## Takeaway #2. Existing automated methods fail to identify a clean subset with high enough precision

In [7]:
# DCM is the automated subset-selection baseline evaluated in this section.
from dc import DCM
# The result is used as sample indices into train_poi_set by the following cell.
# NOTE(review): 1000 is presumably the number of samples to select — confirm in dc.py.
dc_idx = DCM(train_poi_set, 1000)

In [8]:
# Run I-BAU with the DCM-selected samples as its base set.
# The result gets its own name: reusing `cleaned_model_dirty` would overwrite the model
# obtained from the hand-built poisoned base set, so re-running that earlier cell's
# evaluation out of order would silently report the wrong model.
dcm_base_set = torch.utils.data.Subset(train_poi_set, dc_idx)
cleaned_model_dcm = IBAU(copy.deepcopy(poisoned_model), dcm_base_set)

print('I-BAU (DCM base set) Cleaned Model ASR: %.3f%%' % (get_results(copy.deepcopy(cleaned_model_dcm), asr_test)))


Round: 0
Round: 1
Round: 2
Round: 3
Round: 4
I-BAU (DC base set) Cleaned Model ASR: 71.781%


In [9]:
# NCR of the DCM selection, computed against the known poisoned indices.
dcm_ncr = get_NCR(train_poi_set, poi_idx, dc_idx)
print('NCR for DCM is: %.3f%%' % dcm_ncr)

NCR for DC is: 33.885%


## Takeaway #3. Meta-Sift can obtain a clean subset even when the training set is poisoned

In [10]:
class Args:
    # Plain attribute bundle handed to meta_sift(). The notes below are inferred from
    # attribute names and the surrounding cells — confirm against meta_sift.py.

    # --- task ---
    num_classes = 43        # the dataset used in this notebook is GTSRB
    tar_lab = 38            # same target label used when poisoning the training set

    # --- training schedule ---
    repeat_rounds = 5       # presumably the number of sifters trained
    res_epochs = 1
    warmup_epochs = 1
    batch_size = 128
    num_workers = 16

    # --- optimisation ---
    v_lr = 5e-4
    meta_lr = 0.1
    top_k = 15
    go_lr = 0.1
    num_act = 4
    momentum = 0.9
    nesterov = True

    # --- reproducibility ---
    random_seed = 0


args = Args()

In [11]:
# Run Meta-Sift on the poisoned training set with the settings in `args`.
# The result is used as sample indices into train_poi_set by the following cells.
meta_sift_idx = meta_sift(args, train_poi_set)

-----------Training sifter number: 0-----------
Warmup Epoch 0 


100%|██████████| 307/307 [00:32<00:00,  9.31it/s]


-----------Training sifter number: 1-----------
Warmup Epoch 0 


100%|██████████| 307/307 [00:33<00:00,  9.22it/s]


-----------Training sifter number: 2-----------
Warmup Epoch 0 


100%|██████████| 307/307 [00:33<00:00,  9.23it/s]


-----------Training sifter number: 3-----------
Warmup Epoch 0 


100%|██████████| 307/307 [00:33<00:00,  9.21it/s]


-----------Training sifter number: 4-----------
Warmup Epoch 0 


100%|██████████| 307/307 [00:33<00:00,  9.17it/s]
307it [00:02, 121.86it/s]                         
307it [00:02, 121.56it/s]                         
307it [00:02, 120.50it/s]                         
307it [00:02, 120.78it/s]                         
307it [00:02, 120.80it/s]                         


In [12]:
# Run I-BAU with the Meta-Sift selection as its base set, again starting from a fresh
# copy of the poisoned model. The result name is spelled `meta_sift`, matching
# `meta_sift_idx`, rather than the earlier misspelling `meta_sitf`.
meta_sift_base_set = torch.utils.data.Subset(train_poi_set, meta_sift_idx)
cleaned_model_meta_sift = IBAU(copy.deepcopy(poisoned_model), meta_sift_base_set)

print('I-BAU (meta sift base set) Cleaned Model ASR: %.3f%%' % (get_results(copy.deepcopy(cleaned_model_meta_sift), asr_test)))

Round: 0
Round: 1
Round: 2
Round: 3
Round: 4
I-BAU (meta sift base set) Cleaned Model ASR: 8.955%


In [13]:
# NCR of the Meta-Sift selection, computed against the known poisoned indices.
meta_sift_ncr = get_NCR(train_poi_set, poi_idx, meta_sift_idx)
print('NCR for Meta Sift is: %.3f%%' % meta_sift_ncr)

NCR for Meta Sift is: 0.000%
