## 1. Evaluate the performance of I-BAU

In [1]:
# Notebook setup: imports, GPU selection and global seeding.
# NOTE(review): PreActResNet18 and CIFAR10 are not used in the visible cells
# (the experiment is on GTSRB with a VGG model); kept in case later cells need them.
from models import PreActResNet18
import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
# TODO(review): replace this star import with explicit imports. The cells below
# use VGG, h5_dataset, posion_image_all2one, get_results, poi_dataset,
# get_validset, get_NCR, meta_sift and set_seed without importing them, so they
# presumably all come from here.
from meta_sift import *
import numpy as np
import math
# copy.deepcopy is used by the I-BAU cells. Import it explicitly instead of
# relying on the star import above to happen to provide it.
import copy

# Tunable run configuration.
GPU_ID = 2  # CUDA device index used for every experiment in this notebook
SEED = 0    # global seed passed to set_seed for reproducibility

device = 'cuda'
torch.cuda.set_device(GPU_ID)
set_seed(SEED)

In [2]:
# Load the backdoored model. The checkpoint file name indicates GTSRB,
# target label 38, BadNets attack.
checkpoint_path = './checkpoints/gtsrb_tar38_badnets.pth'
poisoned_model = VGG('small_VGG16')
poisoned_model.load_state_dict(torch.load(checkpoint_path, map_location="cuda"))
poisoned_model = poisoned_model.cuda()

# Build the ASR evaluation set from every test image whose label is not 38.
# posion_image_all2one presumably stamps the trigger and relabels the images to 38 — confirm.
data_transform = transforms.Compose([transforms.ToTensor()])
train_path = './dataset/gtsrb_dataset.h5'  # single h5 file holding both train and test splits
testset = h5_dataset(train_path, False, None)  # False presumably selects the test split (True is used for train below)
is_non_target = np.array(testset.targets) != 38
non_target_idx = list(np.where(is_non_target)[0])
asr_test = posion_image_all2one(testset, non_target_idx, 38, data_transform)

In [3]:
# Get ASR
print('Original Poison Model ASR: %.3f%%' % (get_results(poisoned_model, asr_test)))

Original Poison Model ASR: 96.793%


In [4]:
# Load the GTSRB training split and poison it with BadNets
# (poisoning rate 0.33, target label 38, fixed seed 0).
trainset = h5_dataset(train_path, True, None)
train_poi_set, poi_idx = poi_dataset(
    trainset,
    poi_methond="badnets",  # keyword spelling is dictated by poi_dataset's signature
    transform=data_transform,
    poi_rates=0.33,
    random_seed=0,
    tar_lab=38,
)
# NOTE(review): presumably a subset of train_poi_set that excludes the indices in poi_idx — confirm in get_validset.
clean_validset = get_validset(train_poi_set, poi_idx)

In [5]:
from ibau import IBAU

# Unlearn the backdoor with I-BAU, using the clean validation subset built above
# as the base set. I-BAU receives a copy, so poisoned_model stays untouched for later cells.
acctest = h5_dataset(train_path, False, data_transform)  # test split, used for clean accuracy
cleaned_model = IBAU(copy.deepcopy(poisoned_model), clean_validset)

# Evaluate copies, so whatever get_results does to the model cannot affect cleaned_model.
# The labels name the model actually evaluated (the I-BAU-cleaned one, not the original poisoned model).
print('I-BAU (clean base set) Model ASR: %.3f%%' % (get_results(copy.deepcopy(cleaned_model), asr_test)))
print('I-BAU (clean base set) Model ACC: %.3f%%' % (get_results(copy.deepcopy(cleaned_model), acctest)))

Round: 0
Round: 1
Round: 2
Round: 3
Round: 4
Original Poison Model ASR: 11.964%
Original Poison Model ACC: 92.518%


In [6]:
# I-BAU with a poisoned base set: same procedure as the clean-base-set run, but the
# base set handed to I-BAU is drawn from the poisoned training data.
# NOTE(review): the extra arguments (1000, 8) presumably mean "1000 samples, 8 of
# them poisoned" — confirm against get_validset before relying on this.
poi_base_set = get_validset(train_poi_set, poi_idx, 1000, 8)
cleaned_model_dirty = IBAU(copy.deepcopy(poisoned_model), poi_base_set)

# The labels name the model actually evaluated (I-BAU run on the poisoned base set).
print('I-BAU (poisoned base set) Model ASR: %.3f%%' % (get_results(copy.deepcopy(cleaned_model_dirty), asr_test)))
print('I-BAU (poisoned base set) Model ACC: %.3f%%' % (get_results(copy.deepcopy(cleaned_model_dirty), acctest)))

Round: 0
Round: 1
Round: 2
Round: 3
Round: 4
Original Poison Model ASR: 77.371%
Original Poison Model ACC: 94.054%


## 2. DC cannot remove all poisoned samples

In [7]:
from dc import DC

# Select a base set from the poisoned training data with DC.
# NOTE(review): 1000 is presumably the number of indices DC returns — confirm in dc.DC.
dc_idx = DC(
    train_poi_set,
    1000,
)

In [8]:
# I-BAU with the base set selected by DC (the indices computed in the previous cell).
# NOTE(review): this rebinds cleaned_model_dirty, replacing the model from the
# poisoned-base-set run above; re-running that cell's prints after this one
# would report on the DC model instead.
dc_base_set = torch.utils.data.Subset(train_poi_set, dc_idx)
cleaned_model_dirty = IBAU(copy.deepcopy(poisoned_model), dc_base_set)

# The labels name the model actually evaluated (I-BAU run on the DC-selected base set).
print('I-BAU (DC base set) Model ASR: %.3f%%' % (get_results(copy.deepcopy(cleaned_model_dirty), asr_test)))
print('I-BAU (DC base set) Model ACC: %.3f%%' % (get_results(copy.deepcopy(cleaned_model_dirty), acctest)))

Round: 0
Round: 1
Round: 2
Round: 3
Round: 4
Original Poison Model ASR: 73.911%
Original Poison Model ACC: 90.594%


In [9]:
# Score the DC selection against the known poison indices.
# NOTE(review): NCR presumably measures how much poison the selection retains — see get_NCR.
dc_ncr = get_NCR(train_poi_set, poi_idx, dc_idx)
print('NCR for DC is: %.3f%%' % dc_ncr)

NCR for DC is: 33.885%


## 3. Meta-Sift results

In [10]:
class Args:
    """Hyper-parameters handed to meta_sift(), which reads them as plain attributes.

    Only the dataset/attack fields and the sift schedule can be checked against
    this notebook; the roles of the remaining fields are defined inside meta_sift.
    """

    num_classes: int = 43      # GTSRB class count
    tar_lab: int = 38          # backdoor target label (same 38 used to poison the data above)
    repeat_rounds: int = 5     # number of sift rounds (the log shows rounds 0-4)
    res_epochs: int = 1        # NOTE(review): role defined in meta_sift
    warmup_epochs: int = 1     # warm-up epochs per round (one "Warmup Epoch" line per round)
    batch_size: int = 128
    num_workers: int = 16      # presumably DataLoader worker processes
    v_lr: float = 0.0005       # NOTE(review): learning-rate roles (v_lr/meta_lr/go_lr) defined in meta_sift
    meta_lr: float = 0.1
    top_k: int = 15
    go_lr: float = 1e-1
    num_act: int = 4
    momentum: float = 0.9      # presumably SGD momentum
    nesterov: bool = True
    random_seed: int = 0


args = Args()

In [11]:
# Run Meta-Sift on the poisoned training set. The result is later scored with
# get_NCR exactly like DC's dc_idx, i.e. it plays the role of a selected index set.
new_idx = meta_sift(
    args,
    train_poi_set,
)

-----------Start sift round: 0-----------
Warmup Epoch 0 


100%|██████████| 307/307 [00:33<00:00,  9.15it/s]


-----------Start sift round: 1-----------
Warmup Epoch 0 


100%|██████████| 307/307 [00:33<00:00,  9.20it/s]


-----------Start sift round: 2-----------
Warmup Epoch 0 


100%|██████████| 307/307 [00:33<00:00,  9.17it/s]


-----------Start sift round: 3-----------
Warmup Epoch 0 


100%|██████████| 307/307 [00:33<00:00,  9.14it/s]


-----------Start sift round: 4-----------
Warmup Epoch 0 


100%|██████████| 307/307 [00:33<00:00,  9.16it/s]
307it [00:02, 120.73it/s]                         
307it [00:02, 119.78it/s]                         
307it [00:02, 120.08it/s]                         
307it [00:02, 119.54it/s]                         
307it [00:02, 121.14it/s]                         


In [12]:
# Score the Meta-Sift selection with the same metric used for DC above.
meta_sift_ncr = get_NCR(train_poi_set, poi_idx, new_idx)
print('NCR for Meta Sift is: %.3f%%' % meta_sift_ncr)

NCR for Meta Sift is: 0.000%
