In [None]:
import foolbox
import torch
import torchvision
import numpy as np
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import pickle
import sys
import os

from model import CNN
from torchsummary import summary
import copy

In [None]:
# Number of classes in the dataset (CIFAR-10, matching the 10-entry `classes`
# tuple defined below). Passed to the Foolbox model wrapper further down.
num_classes = 10

### Load data

In [None]:
# Test-time preprocessing pipeline: the only step is converting each image
# to a tensor.
transform_test = transforms.Compose([transforms.ToTensor()])

In [None]:
# Human-readable class names, ordered by label index.
classes = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)

In [None]:
# CIFAR-10 test split (train=False), downloaded into ./data when missing and
# preprocessed with `transform_test`.
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)

In [None]:
# Output directory for the pickled adversarial examples written later in this
# notebook. makedirs(exist_ok=True) replaces the check-then-mkdir pattern: it
# cannot race with another process creating the directory and it also creates
# any missing parent directories.
model_folder = os.path.abspath('./adversarial')
os.makedirs(model_folder, exist_ok=True)

### Load classification model

In [None]:
# Build the classifier and restore its trained weights from the checkpoint.
# (The model is moved to the compute device in a later cell, not here.)
model = CNN()

# map_location='cpu' lets the checkpoint load on machines without CUDA even if
# it was saved from a GPU; without it torch.load tries to restore tensors onto
# the device they were saved from.
model.load_state_dict(torch.load('./checkpoints/cifar10.pth', map_location='cpu'))

In [None]:
# Prefer the first CUDA device when one is available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Move the model to the selected device.
model = model.to(device)

# Switch the model to evaluation mode before it is attacked. As the last
# expression of the cell, the returned module is also displayed by the notebook.
model.eval()

### Load adversarial generation tool - Foolbox

In [None]:
# Wrap the PyTorch classifier for Foolbox. Inputs are declared to lie in
# [0, 1] via `bounds`.
fmodel = foolbox.models.PyTorchModel(
    model,
    bounds=(0, 1),
    num_classes=num_classes,
    device=device,
)

### FGSM

In [None]:
# Accumulators for the FGSM run: adversarial images, their indices in the test
# set, and their original labels (three independent lists).
fgsm_adv, fgsm_index, fgsm_label = [], [], []

In [None]:
# Run FGSM against every test image and collect the successful adversarials.

# Empty the accumulators first so that re-running this cell does not append a
# second copy of every result.
fgsm_adv.clear()
fgsm_index.clear()
fgsm_label.clear()

# The attack object is the same for every image, so build it once instead of
# once per iteration.
attack = foolbox.attacks.FGSM(fmodel)

for idx in range(len(test_dataset)):
    img, label = test_dataset[idx]

    # Foolbox is handed NumPy arrays here, so convert the tensor first.
    image = img.detach().numpy()

    # apply attack on source image
    adversarial = attack(image, label, max_epsilon = 0.2)

    # None means no adversarial was found for this image; skip it.
    if adversarial is None:
        continue

    fgsm_adv.append(torch.from_numpy(adversarial))
    fgsm_index.append(idx)
    fgsm_label.append(label)

    print('image {} save'.format(idx))

In [None]:
# Report how many test images yielded an FGSM adversarial example (images for
# which the attack returned None were skipped in the loop above).
print('length of adversarial images : {}'.format(len(fgsm_adv)))

In [None]:
# Persist the FGSM results: test-set indices, adversarial images, and labels,
# each list in its own pickle file (written in that order).
for pickle_path, payload in (
    ('./adversarial/FGSM_indexs.pickle', fgsm_index),
    ('./adversarial/FGSM_adv_images.pickle', fgsm_adv),
    ('./adversarial/FGSM_adv_label.pickle', fgsm_label),
):
    with open(pickle_path, 'wb') as fp:
        pickle.dump(payload, fp)

### DeepFool

In [None]:
# Accumulators for the DeepFool run: adversarial images, their indices in the
# test set, and their original labels (three independent lists).
DF_adv, DF_index, DF_label = [], [], []

In [None]:
# Run DeepFool (L-infinity variant) against every test image and collect the
# successful adversarials.

# Empty the accumulators first so that re-running this cell does not append a
# second copy of every result.
DF_adv.clear()
DF_index.clear()
DF_label.clear()

# The attack object is the same for every image, so build it once instead of
# once per iteration.
attack = foolbox.attacks.DeepFoolLinfinityAttack(fmodel)

for idx in range(len(test_dataset)):
    img, label = test_dataset[idx]

    # Foolbox is handed NumPy arrays here, so convert the tensor first.
    image = img.detach().numpy()

    # apply attack on source image
    adversarial = attack(image, label, steps=50)

    # None means no adversarial was found for this image; skip it.
    if adversarial is None:
        continue

    DF_adv.append(torch.from_numpy(adversarial))
    DF_index.append(idx)
    DF_label.append(label)

    print('image {} save'.format(idx))

In [None]:
# Report how many test images yielded a DeepFool adversarial example (images
# for which the attack returned None were skipped in the loop above).
print('length of adversarial images : {}'.format(len(DF_adv)))

In [None]:
# Persist the DeepFool results: test-set indices, adversarial images, and
# labels, each list in its own pickle file (written in that order).
for pickle_path, payload in (
    ('./adversarial/DF_indexs.pickle', DF_index),
    ('./adversarial/DF_adv_images.pickle', DF_adv),
    ('./adversarial/DF_adv_label.pickle', DF_label),
):
    with open(pickle_path, 'wb') as fp:
        pickle.dump(payload, fp)

### SaliencyMap

In [None]:
# Accumulators for the saliency-map run: adversarial images, their indices in
# the test set, and their original labels (three independent lists).
SM_adv, SM_index, SM_label = [], [], []

In [None]:
# Run the saliency-map attack against every test image and collect the
# successful adversarials.

# Empty the accumulators first so that re-running this cell does not append a
# second copy of every result.
SM_adv.clear()
SM_index.clear()
SM_label.clear()

# The attack object is the same for every image, so build it once instead of
# once per iteration.
attack = foolbox.attacks.SaliencyMapAttack(fmodel)

for idx in range(len(test_dataset)):
    img, label = test_dataset[idx]

    # Foolbox is handed NumPy arrays here, so convert the tensor first.
    image = img.detach().numpy()

    # apply attack on source image
    adversarial = attack(image, label)

    # None means no adversarial was found for this image; skip it.
    if adversarial is None:
        continue

    SM_adv.append(torch.from_numpy(adversarial))
    SM_index.append(idx)
    SM_label.append(label)

    print('image {} save'.format(idx))

In [None]:
# Report how many test images yielded a saliency-map adversarial example
# (images for which the attack returned None were skipped in the loop above).
print('length of adversarial images : {}'.format(len(SM_adv)))

In [None]:
# Persist the saliency-map results: test-set indices, adversarial images, and
# labels, each list in its own pickle file (written in that order).
for pickle_path, payload in (
    ('./adversarial/SM_indexs.pickle', SM_index),
    ('./adversarial/SM_adv_images.pickle', SM_adv),
    ('./adversarial/SM_adv_label.pickle', SM_label),
):
    with open(pickle_path, 'wb') as fp:
        pickle.dump(payload, fp)