In [1]:
import random
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn, optim

from torchvision import datasets, transforms, models

from backdoor.poisons import NarcissusPoison

In [2]:
# Pick the compute device. The notebook prefers the second GPU ('cuda:1'),
# but that ordinal only exists on multi-GPU hosts; fall back to 'cuda:0' on a
# single-GPU machine and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'

In [3]:
def _crop_and_flip():
    """Random 32x32 crop (after 4-pixel padding) followed by a random horizontal flip."""
    return [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]


def _tensor_and_normalize():
    """Tensor conversion followed by per-channel normalisation (mean 0.5, std 0.5)."""
    return [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]


# Augmentation used during the surrogate-model training stage:
# resize to 32 first, then the same crop/flip/normalise steps as below.
transform_surrogate_train = transforms.Compose(
    [transforms.Resize(32)] + _crop_and_flip() + _tensor_and_normalize()
)

# Augmentation used for the whole training set.
transform_train = transforms.Compose(_crop_and_flip() + _tensor_and_normalize())

# Deterministic preprocessing used for the whole test set (no random augmentation).
transform_test = transforms.Compose(_tensor_and_normalize())

In [4]:
# CIFAR-10 train/test splits are read from an existing copy under /data/
# (download is disabled, so the files must already be present).
cifar_common = dict(root='/data/', download=False)
trainset = datasets.CIFAR10(train=True, transform=transform_train, **cifar_common)
testset = datasets.CIFAR10(train=False, transform=transform_test, **cifar_common)

# Tiny-ImageNet training folder, loaded with the surrogate-stage augmentation.
# Presumably this is the out-of-distribution data used to train the surrogate — confirm.
pood_trainset = datasets.ImageFolder(
    root='/data/tiny-imagenet-200/train/',
    transform=transform_surrogate_train,
)

In [5]:
def _new_resnet18():
    """Build a randomly initialised 201-way ResNet-18 and move it to `device`."""
    # NOTE(review): 201 outputs presumably = 200 Tiny-ImageNet classes + 1 target class — confirm.
    return models.resnet18(num_classes=201).to(device)


# Two independent networks with the same architecture: one for the surrogate
# stage, one handed to the attack as its warm-up model.
surrogate_model = _new_resnet18()
warmup_model = _new_resnet18()

In [6]:
# Build the attack object from the device, the two datasets and the two
# networks created above. Arguments are positional; their order is taken
# unchanged from the original call.
attack = NarcissusPoison(
    device,
    pood_trainset,
    trainset,
    surrogate_model,
    warmup_model,
)

In [7]:
# Surrogate training is skipped here because a trained checkpoint already
# exists. The original training call is kept below for reference.

# sur_epochs = 200

# sur_criterion = nn.CrossEntropyLoss()
# sur_optimizer = optim.SGD
# sur_scheduler = optim.lr_scheduler.CosineAnnealingLR

# attack.train_surrogate(sur_epochs, sur_criterion, sur_optimizer, sur_scheduler)

# surrogate already trained and is stored in './surrogate_model.pth'
# map_location remaps the saved tensors onto the current device, so a
# checkpoint written on a different GPU still loads on this machine (or on CPU).
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
surrogate_state = torch.load('./surrogate_model.pth', map_location=device)
surrogate_model.load_state_dict(surrogate_state)
attack.load_surrogate(surrogate_model)

In [9]:
# Warm-up settings. The optimiser is handed over as a class (optim.RAdam),
# not as an instance, so the attack object is expected to construct it itself.
warmup_optim = optim.RAdam
warmup_criterion = nn.CrossEntropyLoss()
warmup_epochs = 5

# Run the warm-up stage and keep whatever it returns in `model`.
model = attack.poi_warmup(
    warmup_epochs,
    warmup_criterion,
    warmup_optim,
)

Epoch:0, Loss: 5.183448e-01
Epoch:1, Loss: 0.000000e+00
Epoch:2, Loss: 0.000000e+00
Epoch:3, Loss: 0.000000e+00
Epoch:4, Loss: 0.000000e+00


In [None]:
# Loss instance and optimiser class handed to the trigger-generation stage.
optimizer = optim.RAdam
criterion = nn.CrossEntropyLoss()

# NOTE(review): the first argument (1000) is presumably the number of
# generation rounds/steps — confirm against NarcissusPoison.generate_trigger.
attack.generate_trigger(
    1000,
    criterion,
    optimizer,
)

In [None]:
# Keep a handle on the attack object's `noise` attribute; later cells add it to
# images (indexing it as noise[0]) and save it to disk.
# Presumably this is the trigger produced by generate_trigger above — confirm.
noise = attack.noise

In [None]:
def plot_img(img, title=''):
    """Show a channel-first image with matplotlib.

    Args:
        img: array with the channel axis first; if it has more than three
            dimensions only its first element is shown.
        title: text placed above the figure.

    Values are clipped to [0, 1] before display, so anything outside that
    range is saturated.
    """
    # More than three dims: show only the first entry.
    single = img[0] if len(img.shape) > 3 else img

    # imshow expects the channel axis last.
    channels_last = np.moveaxis(single, 0, -1)

    plt.imshow(np.clip(channels_last, 0, 1))
    plt.title(title)
    plt.show()

In [None]:
idx = 23

# Fetch the sample exactly once. `trainset` applies a random crop and flip on
# every access, so indexing it twice would yield two different views and the
# 'Normal Image' would not match the image underneath the 'Noised Image'.
clean_img = trainset[idx][0].numpy()
noised_img = clean_img + noise[0]
print(noised_img.shape)

# Side-by-side inspection: the trigger alone, the clean sample, and their sum.
plot_img(noise, 'Noise')
plot_img(clean_img, 'Normal Image')
plot_img(noised_img, 'Noised Image')

In [None]:
# Persist the trigger next to the notebook so it can be reused without
# regenerating it.
np.save('./noise.npy', noise)

In [None]:
# Number of training images that will receive the trigger.
poison_amount = 25

# Fresh 10-class ResNet-18 (one output per CIFAR-10 class), placed on `device`.
noise_testing_model = models.resnet18(num_classes=10).to(device)

# Optimisation schedule and loader batch size.
train_epochs, train_lr = 200, 0.1
test_batch_size = 150

# Presumably the trigger magnification used at test time — TODO confirm.
multi_test = 3
# Seed fed to the RNGs in the next cell.
random_seed = 65

In [None]:
# Seed NumPy, Python's `random` and PyTorch with the same value.
# Note: `noise_testing_model` was already constructed in the previous cell, so
# its initial weights are not covered by this seeding.
for _seed_fn in (np.random.seed, random.seed, torch.manual_seed):
    _seed_fn(random_seed)

# From here on `model` refers to the network being trained on poisoned data.
model = noise_testing_model

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=train_lr,
    momentum=0.9,
    weight_decay=5e-4,
)
# Cosine learning-rate schedule spanning the full training run.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=train_epochs)

In [None]:
# Label the backdoor should steer predictions towards.
target_class = 0

# Positions of every training sample whose label equals the target class.
_train_labels = np.array(trainset.targets)
train_target_list = list(np.flatnonzero(_train_labels == target_class))

# Tensor-level augmentation applied after the trigger has been added.
transform_after_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
)

In [None]:
# Get poisoned dataset for training the model
class PoisonedDataset(torch.utils.data.Dataset):
    """Wrap a dataset and add a fixed noise pattern to selected samples.

    Args:
        dataset: indexable dataset yielding ``(image_tensor, label)`` pairs.
        indices: positions in ``dataset`` whose images receive the noise.
        noise: trigger to add; a NumPy array or a tensor that broadcasts
            against a single image.
        transform: optional callable applied to every image (poisoned or not)
            after the noise has been added; ``None`` disables it.

    Labels are passed through unchanged.
    """

    def __init__(self, dataset, indices, noise, transform):
        self.dataset = dataset
        self.indices = indices
        # Set of plain ints: O(1) membership instead of scanning the list on
        # every __getitem__ call.
        self._poisoned = {int(i) for i in indices}
        # Store the trigger as a tensor so the addition below is well defined
        # whether the caller passed an ndarray or a tensor.
        self.noise = torch.as_tensor(noise)
        self.transform = transform

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        if int(idx) in self._poisoned:
            # Out-of-place add: never write into a tensor owned by the
            # wrapped dataset, otherwise noise accumulates across accesses.
            image = image + self.noise.to(image.dtype)
        if self.transform is not None:
            image = self.transform(image)
        return (image, label)

    def __len__(self):
        return len(self.dataset)
    
# Draw `poison_amount` distinct target-class samples to carry the trigger.
random_poison_idx = random.sample(train_target_list, poison_amount)

# NOTE(review): `trainset` already applies RandomCrop/RandomHorizontalFlip via
# transform_train, and transform_after_train repeats both after the noise is
# added, so every sample is augmented twice — confirm this is intended.
poison_train_target = PoisonedDataset(
    dataset=trainset,
    indices=random_poison_idx,
    noise=noise[0],
    transform=transform_after_train,
)

In [None]:
print('Traing dataset size is:', len(poison_train_target), " Poison numbers is:", len(random_poison_idx))

# Both loaders share batch size and worker count; only shuffling differs.
# Despite its name, `clean_train_loader` iterates the poisoned training set.
loader_opts = dict(batch_size=test_batch_size, num_workers=5)
clean_train_loader = torch.utils.data.DataLoader(poison_train_target, shuffle=True, **loader_opts)
clean_test_loader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_opts)

In [None]:
# Positions of every test sample whose true label is NOT the target class.
_test_labels = np.array(testset.targets)
test_non_target = list(np.flatnonzero(_test_labels != target_class))

class AsrDataset(torch.utils.data.Dataset):
    """Triggered view of selected samples, all relabelled to the target class.

    Used to measure attack success rate: each selected image gets
    ``noise * magnify`` added and is returned with the target label.

    Args:
        dataset: indexable dataset yielding ``(image_tensor, label)`` pairs.
        indices: positions in ``dataset`` to include.
        taget_class: label returned for every item (the parameter name is
            misspelled in the original signature and kept for compatibility).
        noise: trigger as a NumPy array or tensor; if it has more than three
            dimensions only its first element is used.
        magnify: scalar multiplier applied to the trigger.
    """

    def __init__(self, dataset, indices, taget_class, noise, magnify):
        self.dataset = torch.utils.data.Subset(dataset, indices)
        # Use the constructor argument; the original read a global named
        # `target_class` here and silently ignored what the caller passed.
        self.target_class = taget_class
        # Store the trigger as a tensor so it combines cleanly with image tensors.
        noise = torch.as_tensor(noise)
        if noise.dim() > 3:
            noise = noise[0]
        self.noise = noise
        self.magnify = magnify

    def __getitem__(self, idx):
        img = self.dataset[idx][0]
        # Out-of-place add so the wrapped dataset's tensor is never modified.
        trigger = self.noise.to(img.dtype) * self.magnify
        return (img + trigger, self.target_class)

    def __len__(self):
        return len(self.dataset)
    
# Clean test samples that truly belong to the target class.
test_target_indices = list(np.where(np.array(testset.targets) == target_class)[0])
test_target_dataset = torch.utils.data.Subset(testset, test_target_indices)

# Triggered copies of every non-target test sample, relabelled to the target.
# Reuse `target_class` and `multi_test` from the earlier cells rather than
# repeating their values (0 and 3) as literals, so the settings cannot drift.
asr_testset = AsrDataset(testset, test_non_target, target_class, noise, multi_test)

asr_loader = torch.utils.data.DataLoader(asr_testset, batch_size=test_batch_size, shuffle=False, num_workers=5)
target_test_loader = torch.utils.data.DataLoader(test_target_dataset, batch_size=test_batch_size, shuffle=False, num_workers=5)

In [None]:
def evaluate_accuracy(net, loader):
    """Return the fraction of samples in `loader` that `net` classifies as their label.

    Runs without gradient tracking; the caller is responsible for putting
    `net` into eval mode beforehand.
    """
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            _, predicted = torch.max(net(images), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total


for epoch in range(train_epochs):

    print(f'Epoch: [{epoch+1}/{train_epochs}]')

    # ---- one pass over the poisoned training set ----
    model.train()
    train_losses = []
    for images, labels in clean_train_loader:
        images, labels = images.to(device), labels.to(device)
        # The optimizer holds all of model's parameters, so one zero_grad suffices.
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Store a plain float: keeping the tensor would retain each batch's
        # autograd graph for the whole epoch.
        train_losses.append(loss.item())

    # Advance the cosine schedule once per epoch; without this call the
    # learning rate stays at its initial value for the whole run.
    scheduler.step()

    print(f'Train Loss: {sum(train_losses)/len(train_losses)}')

    model.eval()

    # Get clean test accuracy
    acc_clean = evaluate_accuracy(model, clean_test_loader)
    print('Clean Test Accuracy %.2f' % (acc_clean))

    # Get target clean accuracy
    acc_tar = evaluate_accuracy(model, target_test_loader)
    print('\nTarget Test Clean Accuracy %.2f' % (acc_tar))

    # Get Attack Success Rate: share of triggered non-target images that are
    # classified as the target class.
    acc = evaluate_accuracy(model, asr_loader)
    print(f'Attack Success Rate: {acc}')

    print()

In [None]:
# Test sample to inspect and the multiplier applied to the trigger.
idx = 80
magnify = 1

img = testset[idx][0]
img_np = img.numpy()

# Clean image first, then the same image with the scaled trigger added.
plot_img(img_np)
plot_img(img_np + noise[0] * magnify)


In [None]:
def output(model, idx, noise=None, magnify=magnify):
    """Classify one test image, optionally with the trigger added.

    Args:
        model: network to query; it is switched to eval mode.
        idx: position of the sample in `testset`.
        noise: trigger (NumPy array or tensor) added to the image, or None
            to classify the clean image.
        magnify: multiplier applied to `noise`. The default is bound to the
            global `magnify` at definition time, so later changes to that
            global do not affect it.

    Returns:
        ``(predicted_class, true_label)`` for the selected sample.
    """
    # Fetch the sample once and reuse both the image and its label.
    img, label = testset[idx]
    img = img.unsqueeze(dim=0)

    # Inference only: eval mode for BatchNorm, and no autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        if noise is not None:
            trigger = torch.as_tensor(noise, dtype=img.dtype, device=img.device)
            # Apply the requested magnification (previously ignored) and add
            # out of place so the source tensor is not modified.
            img = img + trigger * magnify
        logits = model(img.to(device))

    _, predicted = torch.max(logits, 1)
    return predicted.item(), label

# Predicted class vs. true label for the triggered sample chosen above.
output(model, idx, noise=noise, magnify=magnify)