In [1]:
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets

from cifar10_poi_dataset import CIFAR10Poi
import common


# Seed every RNG the pipeline may draw from so Restart & Run All is reproducible:
# torch (model init, DataLoader shuffling) plus Python/numpy in case CIFAR10Poi
# picks its poisoned samples with one of those (NOTE(review): confirm which RNG
# CIFAR10Poi actually uses).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

transform = common.transform_cifar

# Training set in which a fraction `common.poison_rate_train` of the samples is
# poisoned with a trigger of size `common.poison_size` (presumably stamped patch
# plus relabelling, i.e. the BadNet setting -- TODO confirm in CIFAR10Poi).
trainset_poisoned = CIFAR10Poi(
    common.DATA_ROOT,
    download=True,
    train=True,
    transform=transform,
    poison_rate=common.poison_rate_train,
    poison_size=common.poison_size,
)

trainloader_poisoned = DataLoader(
    trainset_poisoned, batch_size=common.batch_size, shuffle=True
)


# Presumably 3 input channels and 10 CIFAR-10 classes -- confirm CIFAR10Net's signature.
model = common.CIFAR10Net(3, 10)
optimizer = common.get_optimizer(model)

# Clean test set, used for per-epoch monitoring during training and for the final
# evaluation. download=True matches the other dataset constructors so the cell
# also works on a machine where the data is not cached yet.
testset_clean = datasets.CIFAR10(
    common.DATA_ROOT, train=False, download=True, transform=transform
)
testloader_clean = DataLoader(
    testset_clean, batch_size=common.batch_size, shuffle=False
)

# NOTE(review): the evaluation below uses `model`, not `poisoned_model`. The recorded
# output shows the last epoch's accuracies and the final test's accuracies are
# identical, which is consistent with train_model updating `model` in place.
# `poisoned_model` is kept in case later cells reference it.
poisoned_model = common.train_model(
    model,
    trainloader_poisoned,
    optimizer,
    nr_epochs=common.nr_epochs_cifar10,
    testloader=testloader_clean,
)

print("Test result for badnet and clean testset:")
common.test_model(model, testloader_clean)

Files already downloaded and verified
training in epoch 0/20


196it [00:44,  4.45it/s]


Accuracies (label 0-9): 0.57 0.68 0.45 0.29 0.38 0.37 0.59 0.64 0.85 0.50 
training in epoch 1/20


196it [00:43,  4.51it/s]


Accuracies (label 0-9): 0.73 0.77 0.49 0.52 0.33 0.42 0.87 0.71 0.72 0.72 
training in epoch 2/20


196it [00:43,  4.51it/s]


Accuracies (label 0-9): 0.72 0.81 0.51 0.49 0.74 0.53 0.68 0.70 0.83 0.71 
training in epoch 3/20


196it [00:44,  4.45it/s]


Accuracies (label 0-9): 0.52 0.78 0.43 0.52 0.67 0.68 0.79 0.69 0.86 0.84 
training in epoch 4/20


196it [00:44,  4.40it/s]


Accuracies (label 0-9): 0.71 0.88 0.45 0.40 0.70 0.69 0.85 0.75 0.84 0.80 
training in epoch 5/20


196it [00:43,  4.53it/s]


Accuracies (label 0-9): 0.75 0.80 0.56 0.53 0.74 0.57 0.80 0.80 0.73 0.87 
training in epoch 6/20


196it [00:44,  4.38it/s]


Accuracies (label 0-9): 0.82 0.86 0.70 0.44 0.43 0.72 0.70 0.77 0.74 0.79 
training in epoch 7/20


196it [00:44,  4.45it/s]


Accuracies (label 0-9): 0.74 0.81 0.59 0.60 0.71 0.44 0.78 0.81 0.80 0.87 
training in epoch 8/20


196it [00:43,  4.53it/s]


Accuracies (label 0-9): 0.77 0.86 0.65 0.44 0.65 0.60 0.87 0.74 0.83 0.75 
training in epoch 9/20


196it [00:43,  4.55it/s]


Accuracies (label 0-9): 0.77 0.77 0.55 0.60 0.68 0.66 0.79 0.73 0.83 0.81 
training in epoch 10/20


196it [00:43,  4.52it/s]


Accuracies (label 0-9): 0.77 0.83 0.56 0.52 0.68 0.58 0.85 0.76 0.87 0.73 
training in epoch 11/20


196it [00:43,  4.50it/s]


Accuracies (label 0-9): 0.83 0.82 0.57 0.51 0.66 0.65 0.81 0.74 0.81 0.76 
training in epoch 12/20


196it [00:43,  4.48it/s]


Accuracies (label 0-9): 0.75 0.89 0.62 0.56 0.75 0.55 0.79 0.69 0.81 0.76 
training in epoch 13/20


196it [00:42,  4.62it/s]


Accuracies (label 0-9): 0.74 0.85 0.58 0.56 0.67 0.62 0.79 0.77 0.85 0.78 
training in epoch 14/20


196it [00:42,  4.59it/s]


Accuracies (label 0-9): 0.78 0.81 0.61 0.49 0.68 0.69 0.77 0.78 0.80 0.79 
training in epoch 15/20


196it [00:43,  4.54it/s]


Accuracies (label 0-9): 0.80 0.88 0.69 0.43 0.60 0.61 0.73 0.70 0.83 0.66 
training in epoch 16/20


196it [00:43,  4.55it/s]


Accuracies (label 0-9): 0.80 0.85 0.66 0.48 0.66 0.56 0.78 0.82 0.80 0.75 
training in epoch 17/20


196it [00:42,  4.58it/s]


Accuracies (label 0-9): 0.74 0.84 0.63 0.53 0.53 0.66 0.79 0.77 0.83 0.76 
training in epoch 18/20


196it [00:42,  4.56it/s]


Accuracies (label 0-9): 0.71 0.85 0.64 0.54 0.65 0.57 0.84 0.71 0.77 0.77 
training in epoch 19/20


196it [00:43,  4.52it/s]


Accuracies (label 0-9): 0.73 0.84 0.57 0.51 0.69 0.60 0.83 0.83 0.77 0.78 
Test result for badnet and clean testset:
Accuracies (label 0-9): 0.73 0.84 0.57 0.51 0.69 0.60 0.83 0.83 0.77 0.78 


In [2]:
# Backdoor evaluation on the CIFAR-10 test split with poison_rate=1, i.e. presumably
# every test sample carries the trigger (and, per the printed label, a modified
# target label) -- TODO confirm against CIFAR10Poi. Reuses `model` and `transform`
# from the training cell above.
testset_poi = CIFAR10Poi(
    common.DATA_ROOT,
    train=False,
    download=True,
    transform=transform,
    poison_size=common.poison_size,
    poison_rate=1,
)
testloader_poi = DataLoader(testset_poi, batch_size=common.batch_size, shuffle=False)

print("Test result for badnet and backdoored testset (modified label):")
common.test_model(model, testloader_poi)

Files already downloaded and verified
Test result for badnet and backdoored testset (modified label):
Accuracies (label 0-9): 0.69 0.65 0.70 0.52 0.36 0.59 0.55 0.72 0.70 0.64 
