In [None]:
import torch
import torch.nn as nn

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

batch_size = 8

# Per-channel statistics of the CIFAR-10 dataset.
cifar_10_mean = (0.491, 0.482, 0.447)  # mean of the three image channels
cifar_10_std = (0.202, 0.199, 0.201)  # std of the three image channels

# Keep the statistics as (3, 1, 1) tensors on `device` so they broadcast
# against image tensors in later arithmetic.
mean = torch.tensor(cifar_10_mean, device=device).view(3, 1, 1)
std = torch.tensor(cifar_10_std, device=device).view(3, 1, 1)

# Perturbation budget: 8/255 in pixel scale, divided by std to express it in
# the normalized scale.
epsilon = (8 / 255) / std
# Step size of a single iterative-FGSM update (free to tune), same scale.
alpha = (0.8 / 255) / std

root = './data'  # directory for storing benign images
atk_root = './ifgsm'  # directory the adversarial dataset is loaded from

In [None]:
import os
import glob
import shutil
import numpy as np
from PIL import Image
from torchvision.transforms import transforms
from torch.utils.data import Dataset, DataLoader

# Preprocessing for every loaded image: convert the PIL image to a tensor,
# then standardize each channel with the CIFAR-10 statistics.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar_10_mean, std=cifar_10_std),
])

class AdvDataset(Dataset):
    """Image-folder dataset that labels each image by its class directory.

    Expected layout::

        data_dir
        ├── class_dir
        │   ├── class1.png
        │   ├── ...
        │   ├── class20.png

    Class directories are visited in sorted order; the position of a
    directory in that order is the integer label of all images inside it.

    Args:
        data_dir: directory containing one sub-directory per class.
        transform: callable applied to each PIL image in ``__getitem__``.
    """

    def __init__(self, data_dir, transform):
        self.images = []  # path of every image, as returned by glob
        self.labels = []  # integer class index, aligned with self.images
        self.names = []  # image path relative to data_dir, aligned with self.images
        # Only directories count as classes: a stray file next to the class
        # folders would otherwise occupy an index and shift every later label.
        class_dirs = [
            entry for entry in sorted(glob.glob(f'{data_dir}/*'))
            if os.path.isdir(entry)
        ]
        for i, class_dir in enumerate(class_dirs):
            images = sorted(glob.glob(f'{class_dir}/*'))
            self.images += images
            self.labels += ([i] * len(images))
            self.names += [os.path.relpath(imgs, data_dir) for imgs in images]
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for the sample at ``idx``."""
        # The context manager releases the file handle once the pixels are
        # read; convert('RGB') makes palette, grayscale or RGBA files come
        # out with three channels as well.
        with Image.open(self.images[idx]) as img:
            image = self.transform(img.convert('RGB'))
        label = self.labels[idx]
        return image, label

    def __getname__(self):
        """Return the list of image paths relative to ``data_dir``."""
        return self.names

    def __len__(self):
        """Return the number of images found under ``data_dir``."""
        return len(self.images)

# Benign dataset. Shuffling stays off so batches come out in the same order
# as the entries of adv_names.
adv_set = AdvDataset(root, transform=transform)
adv_names = adv_set.__getname__()
adv_loader = DataLoader(adv_set, shuffle=False, batch_size=batch_size)

print(f'number of images = {len(adv_set)}')

In [None]:
# to evaluate the performance of model on benign images
def epoch_benign(model, loader, loss_fn):
    """Evaluate ``model`` on every batch of ``loader``.

    Args:
        model: network to evaluate; it is switched to eval mode.
        loader: DataLoader yielding ``(x, y)`` batches.
        loss_fn: callable ``loss_fn(prediction, target)``; expected to return
            a per-batch mean (as ``nn.CrossEntropyLoss`` does by default).

    Returns:
        ``(accuracy, mean_loss)``, both divided by ``len(loader.dataset)``.
    """
    model.eval()
    n_correct, loss_sum = 0.0, 0.0
    # Pure evaluation: without no_grad autograd would record a graph for
    # every forward pass and keep its activations alive for nothing.
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            yp = model(x)
            loss = loss_fn(yp, y)
            n_correct += (yp.argmax(dim=1) == y).sum().item()
            # weight the batch loss by batch size so a smaller last batch
            # does not skew the dataset-level average
            loss_sum += loss.item() * x.shape[0]
    n_samples = len(loader.dataset)
    return n_correct / n_samples, loss_sum / n_samples

In [None]:
# NOTE: this cell re-defines epoch_benign, which the previous cell already
# defines; whichever definition runs last is the one in effect.
def epoch_benign(model, loader, loss_fn):
    """Evaluate ``model`` on every batch of ``loader``.

    Args:
        model: network to evaluate; it is switched to eval mode.
        loader: DataLoader yielding ``(x, y)`` batches.
        loss_fn: callable ``loss_fn(prediction, target)``; expected to return
            a per-batch mean (as ``nn.CrossEntropyLoss`` does by default).

    Returns:
        ``(accuracy, mean_loss)``, both divided by ``len(loader.dataset)``.
    """
    model.eval()
    n_correct, loss_sum = 0.0, 0.0
    # Pure evaluation: without no_grad autograd would record a graph for
    # every forward pass and keep its activations alive for nothing.
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            yp = model(x)
            loss = loss_fn(yp, y)
            n_correct += (yp.argmax(dim=1) == y).sum().item()
            # weight the batch loss by batch size so a smaller last batch
            # does not skew the dataset-level average
            loss_sum += loss.item() * x.shape[0]
    n_samples = len(loader.dataset)
    return n_correct / n_samples, loss_sum / n_samples

# Attack Algorithm

In [None]:
# perform fgsm attack
def fgsm(model, x, y, loss_fn, epsilon=epsilon):
    """Take one FGSM step: move ``x`` by ``epsilon`` along the sign of the loss gradient.

    Args:
        model: network used to compute the loss.
        x: input batch; it is copied, never modified in place.
        y: target labels for ``x``.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar.
        epsilon: step size; anything that broadcasts against ``x``.

    Returns:
        The perturbed batch as a new tensor detached from autograd.
    """
    x_adv = x.detach().clone()  # start from the given image
    x_adv.requires_grad_(True)  # the gradient w.r.t. the input is needed
    loss = loss_fn(model(x_adv), y)
    # Differentiate w.r.t. the input only. loss.backward() would additionally
    # accumulate gradients into every model parameter, which the attack never
    # reads or clears.
    (grad,) = torch.autograd.grad(loss, x_adv)
    # fgsm: gradient ascent on x_adv to maximize the loss
    return (x_adv + epsilon * grad.sign()).detach()

# TODO: perform iterative fgsm attack
# set alpha as the step size in Global Settings section
# alpha and num_iter can be decided by yourself
def ifgsm(models, x, y, loss_fn, epsilon=epsilon, alpha=alpha, num_iter=100):
    """Iterative FGSM against a group of models.

    Each iteration takes one ``fgsm`` step of size ``alpha`` on every model in
    turn and, after each step, clips the result back into the
    ``[x - epsilon, x + epsilon]`` box around the original input.

    Args:
        models: re-iterable collection of networks to attack.
        x: benign input batch.
        y: target labels for ``x``.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar.
        epsilon: maximum perturbation; broadcasts against ``x``.
        alpha: per-step size; broadcasts against ``x``.
        num_iter: number of passes over ``models``.

    Returns:
        The adversarial batch, detached from autograd.
    """
    # Both the eval-mode switch and the clipping bounds are loop-invariant,
    # so they are set up once instead of on every inner step.
    for model in models:
        model.eval()
    lower, upper = x - epsilon, x + epsilon

    x_adv = x.detach().clone()
    for _ in range(num_iter):
        for model in models:
            x_adv = fgsm(model, x_adv, y, loss_fn, alpha)
            x_adv = torch.min(torch.max(x_adv, lower), upper)
    # Hand back a plain tensor so callers do not build autograd graphs on it.
    return x_adv.detach()

# Attack and Generate Adversarial Examples

In [None]:
# perform adversarial attack and generate adversarial examples
def gen_adv_examples(models, loader, attack, loss_fn):
    """Attack every batch of ``loader`` and collect the adversarial images.

    Args:
        models: non-empty sequence of networks; passed to ``attack`` and used
            to score the adversarial examples.
        loader: DataLoader yielding ``(x, y)`` batches.
        attack: callable ``attack(models, x, y, loss_fn)`` returning the
            adversarial batch.
        loss_fn: callable ``loss_fn(prediction, target)``; expected to return
            a per-batch mean.

    Returns:
        ``(adv_examples, accuracy, loss)``. ``adv_examples`` is a float array
        laid out (N, H, W, C) holding rounded values in [0, 255]. ``accuracy``
        and ``loss`` are averaged over all samples and over all ``models``.

    Raises:
        ValueError: if ``models`` is empty.
    """
    if not models:
        raise ValueError('gen_adv_examples needs at least one model')

    batches = []
    acc_sum, loss_sum = 0.0, 0.0

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        x_adv = attack(models, x, y, loss_fn).detach()  # obtain adversarial examples
        # Score on the models that were passed in. The name `model` is not
        # defined in this function, so referring to it would silently pick up
        # whatever global a previous notebook cell happened to leave behind.
        with torch.no_grad():
            for model in models:
                model.eval()
                yp = model(x_adv)
                acc_sum += (yp.argmax(dim=1) == y).sum().item()
                loss_sum += loss_fn(yp, y).item() * x.shape[0]
        # store adversarial examples
        adv_ex = (x_adv * std + mean).clamp(0, 1)  # undo normalization, 0-1 scale
        adv_ex = (adv_ex * 255).clamp(0, 255)  # 0-255 scale
        adv_ex = adv_ex.cpu().numpy().round()  # round to remove decimal part
        # transpose (bs, C, H, W) back to (bs, H, W, C)
        batches.append(adv_ex.transpose((0, 2, 3, 1)))

    # One concatenation at the end instead of re-copying the growing array
    # after every batch.
    adv_examples = np.concatenate(batches, axis=0)
    n_evals = len(loader.dataset) * len(models)
    return adv_examples, acc_sum / n_evals, loss_sum / n_evals

# create directory which stores adversarial examples
def create_dir(data_dir, adv_dir, adv_examples, adv_names):
    """Save adversarial images under ``adv_dir``.

    When ``adv_dir`` does not exist yet it is first created as a copy of
    ``data_dir``. Each entry of ``adv_examples`` is then saved at
    ``adv_dir/<name>`` for the matching entry of ``adv_names``, overwriting
    any file already there.
    """
    if not os.path.exists(adv_dir):
        shutil.copytree(data_dir, adv_dir)
    for pixels, rel_path in zip(adv_examples, adv_names):
        target = os.path.join(adv_dir, rel_path)
        # image pixel values have to be unsigned 8-bit integers
        Image.fromarray(pixels.astype(np.uint8)).save(target)

# Attack Model

In [None]:
from pytorchcv.model_provider import get_model as ptcv_get_model

# Names of the pretrained pytorchcv CIFAR-10 networks to load.
models_name = [
    'resnet20_cifar10',
    'resnet56_cifar10',
    'resnet110_cifar10',
    'resnet164bn_cifar10',
    'resnet272bn_cifar10',
    'resnet542bn_cifar10',
    'resnet1001_cifar10',
    'resnet1202_cifar10',
    'preresnet20_cifar10',
    'preresnet56_cifar10',
    'preresnet110_cifar10',
    'preresnet164bn_cifar10',
    'preresnet272bn_cifar10',
    'preresnet542bn_cifar10',
    'preresnet1001_cifar10',
    'preresnet1202_cifar10',
    'seresnet20_cifar10',
    'seresnet56_cifar10',
    'seresnet110_cifar10',
    'seresnet164bn_cifar10',
    'seresnet272bn_cifar10',
    'seresnet542bn_cifar10',
    'sepreresnet20_cifar10',
]

# Smaller alternative list (disabled):
# models_name = ['resnet110_cifar10',
#                'preresnet110_cifar10',
#                'seresnet110_cifar10']

# Load every network with its pretrained weights and move it to `device`.
models = [ptcv_get_model(name, pretrained=True).to(device) for name in models_name]
loss_fn = nn.CrossEntropyLoss()

# Baseline: accuracy and loss of each network on the benign images.
# `model` remains bound to the last network once this loop finishes.
for model in models:
    benign_acc, benign_loss = epoch_benign(model, adv_loader, loss_fn)
    print(f'benign_acc = {benign_acc:.5f}, benign_loss = {benign_loss:.5f}')

# I-FGSM (generate ATK dataset)

In [None]:
# Run I-FGSM over the benign loader and report the resulting accuracy / loss.
adv_examples, ifgsm_acc, ifgsm_loss = gen_adv_examples(models, adv_loader, ifgsm, loss_fn)
print(f'ifgsm_acc = {ifgsm_acc:.5f}, ifgsm_loss = {ifgsm_loss:.5f}')

# Save into atk_root, the same setting the ATK dataset is loaded from, rather
# than a second hard-coded copy of the directory name.
create_dir(root, atk_root, adv_examples, adv_names)

# ATK dataloader

In [None]:
# Load the images stored under atk_root with the same preprocessing as the
# benign set; shuffling stays off so the order matches atk_names.
atk_set = AdvDataset(atk_root, transform=transform)
atk_names = atk_set.__getname__()
atk_loader = DataLoader(atk_set, shuffle=False, batch_size=batch_size)
print(f'number of images = {len(atk_set)}')

# Test on other model 

In [None]:
# Score every network in `models` on the images served by atk_loader and
# print one accuracy / loss line per network.
# `model` remains bound to the last entry of `models` once the loop finishes.
for model in models:

    other_acc, other_loss = epoch_benign(model, atk_loader, loss_fn)
    print(f'acc = {other_acc:.5f}, loss = {other_loss:.5f}')

# Visualization

In [None]:
import matplotlib.pyplot as plt

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Name the network that labels the plots explicitly instead of relying on
# whatever an earlier loop left behind in a global `model`.
vis_model = models[-1]
vis_model.eval()

# NOTE(review): the grid has len(classes) rows x 4 columns but only
# 2 * len(classes) slots are filled, so the lower half of the figure is empty.
plt.figure(figsize=(10, 20))
cnt = 0
for cls_name in classes:
    path = f'{cls_name}/{cls_name}1.png'
    # first the benign image, then its adversarial counterpart
    for kind, img_root in (('benign', root), ('adversarial', atk_root)):
        cnt += 1
        plt.subplot(len(classes), 4, cnt)
        im = Image.open(os.path.join(img_root, path))
        # inference only, no autograd graph needed
        with torch.no_grad():
            logit = vis_model(transform(im).unsqueeze(0).to(device))[0]
        predict = logit.argmax(-1).item()
        prob = logit.softmax(-1)[predict].item()
        plt.title(f'{kind}: {cls_name}1.png\n{classes[predict]}: {prob:.2%}')
        plt.axis('off')
        plt.imshow(np.array(im))
plt.tight_layout()
plt.show()

# output file

In [None]:
# Pack everything inside ifgsm/ into ifgsm.tgz in the parent directory.
# tar runs from inside the folder, so archive entries carry no "ifgsm/" prefix.
# NOTE(review): %cd changes the kernel's working directory; if the cell stops
# before the final %cd, later relative paths resolve inside ifgsm/.
%cd ifgsm
!tar zcvf ../ifgsm.tgz *
%cd ..