In [None]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
import os

from dataset import Adversarial_Dataset
from model import CNN
from torchsummary import summary
import copy

In [None]:
batch_size: int = 32  # mini-batch size shared by every DataLoader in this notebook

### Load data

In [None]:
# adversarial dataset path
root_dir = './adversarial/'

In [None]:
# Test-time preprocessing: convert each image to a tensor, with no normalization or augmentation.
transform_test = transforms.Compose([transforms.ToTensor()])

In [None]:
# CIFAR-10 held-out split, stored under ./data (downloaded if missing) and
# preprocessed with `transform_test`. Iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

### Load classification model

In [None]:
# Build the classifier architecture and print a per-layer summary for one 3x32x32 input (computed on CPU).
model: nn.Module = CNN()
summary(model, input_size=(3, 32, 32), device='cpu')

In [None]:
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Restore the trained weights. map_location remaps tensors that were saved on a
# GPU onto `device`, so the checkpoint also loads on CPU-only machines.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load('./checkpoints/cifar10.pth', map_location=device))

model = model.to(device)

### Clean test images

In [None]:
# Accuracy of the classifier on the clean CIFAR-10 test set.
model.eval()   # Set model to evaluate mode

# Count correct predictions on `device` so the loop never forces a host sync
# per batch. Starting from a tensor, rather than a float, also keeps the
# empty-loader case well defined.
running_corrects = torch.zeros((), dtype=torch.long, device=device)
epoch_size = 0

with torch.no_grad():
    for inputs, labels in testloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        # statistics
        running_corrects += (preds == labels).sum()
        epoch_size += inputs.size(0)

if epoch_size == 0:
    raise ValueError('testloader yielded no samples; cannot compute clean accuracy')

# Store the ratio under its own name so the counter is not overwritten and the
# result survives later cells.
clean_acc = running_corrects.item() / epoch_size

print('Test  Acc: {:.4f}'.format(clean_acc))

### FGSM

In [None]:
# Round-trip each stored adversarial sample through PIL and back to a tensor.
# Assumes Adversarial_Dataset yields something ToPILImage accepts
# (a CHW tensor or an HWC ndarray); TODO confirm against dataset.py.
adversarial_transform = transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()]
)

In [None]:
# Adversarial examples tagged 'FGSM'; presumably read from under root_dir (see dataset.Adversarial_Dataset).
sample = Adversarial_Dataset(
    root_dir,
    'FGSM',
    adversarial_transform,
)

In [None]:
# Sequential batches (shuffle=False is the DataLoader default, spelled out here).
sample_loader = DataLoader(sample, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
# Accuracy of the classifier on the FGSM adversarial examples.
model.eval()   # Set model to evaluate mode

# Count correct predictions on `device` so the loop never forces a host sync
# per batch. Starting from a tensor, rather than a float, also keeps the
# empty-loader case well defined.
running_corrects = torch.zeros((), dtype=torch.long, device=device)
epoch_size = 0

with torch.no_grad():
    for inputs, labels in sample_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        # statistics
        running_corrects += (preds == labels).sum()
        epoch_size += inputs.size(0)

if epoch_size == 0:
    raise ValueError('sample_loader yielded no FGSM samples; check root_dir')

# Store the ratio under its own name so the counter is not overwritten and the
# result survives later cells.
fgsm_acc = running_corrects.item() / epoch_size

print('Test  Acc: {:.4f}'.format(fgsm_acc))

### DeepFool

In [None]:
# Round-trip each stored adversarial sample through PIL and back to a tensor.
# Assumes Adversarial_Dataset yields something ToPILImage accepts
# (a CHW tensor or an HWC ndarray); TODO confirm against dataset.py.
adversarial_transform = transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()]
)

In [None]:
# Adversarial examples tagged 'DF' (DeepFool); presumably read from under root_dir (see dataset.Adversarial_Dataset).
sample = Adversarial_Dataset(
    root_dir,
    'DF',
    adversarial_transform,
)

In [None]:
# Sequential batches (shuffle=False is the DataLoader default, spelled out here).
sample_loader = DataLoader(sample, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
# Accuracy of the classifier on the DeepFool adversarial examples.
model.eval()   # Set model to evaluate mode

# Count correct predictions on `device` so the loop never forces a host sync
# per batch. Starting from a tensor, rather than a float, also keeps the
# empty-loader case well defined.
running_corrects = torch.zeros((), dtype=torch.long, device=device)
epoch_size = 0

with torch.no_grad():
    for inputs, labels in sample_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        # statistics
        running_corrects += (preds == labels).sum()
        epoch_size += inputs.size(0)

if epoch_size == 0:
    raise ValueError('sample_loader yielded no DeepFool (DF) samples; check root_dir')

# Store the ratio under its own name so the counter is not overwritten and the
# result survives later cells.
df_acc = running_corrects.item() / epoch_size

print('Test  Acc: {:.4f}'.format(df_acc))

### Saliency Map

In [None]:
# Round-trip each stored adversarial sample through PIL and back to a tensor.
# Assumes Adversarial_Dataset yields something ToPILImage accepts
# (a CHW tensor or an HWC ndarray); TODO confirm against dataset.py.
adversarial_transform = transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()]
)

In [None]:
# Adversarial examples tagged 'SM' (saliency map); presumably read from under root_dir (see dataset.Adversarial_Dataset).
sample = Adversarial_Dataset(
    root_dir,
    'SM',
    adversarial_transform,
)

In [None]:
# Sequential batches (shuffle=False is the DataLoader default, spelled out here).
sample_loader = DataLoader(sample, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
# Accuracy of the classifier on the saliency-map adversarial examples.
model.eval()   # Set model to evaluate mode

# Count correct predictions on `device` so the loop never forces a host sync
# per batch. Starting from a tensor, rather than a float, also keeps the
# empty-loader case well defined.
running_corrects = torch.zeros((), dtype=torch.long, device=device)
epoch_size = 0

with torch.no_grad():
    for inputs, labels in sample_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        # statistics
        running_corrects += (preds == labels).sum()
        epoch_size += inputs.size(0)

if epoch_size == 0:
    raise ValueError('sample_loader yielded no saliency-map (SM) samples; check root_dir')

# Store the ratio under its own name so the counter is not overwritten and the
# result survives later cells.
sm_acc = running_corrects.item() / epoch_size

print('Test  Acc: {:.4f}'.format(sm_acc))