In [None]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
import os

from dataset import Adversarial_Dataset
from util_defense_GAN import adjust_lr, get_z_sets, get_z_star, Resize_Image
from model import CNN
from gan_model import Generator
from torchsummary import summary
import copy

### Set parameters

In [None]:
batch_size = 32
in_channel = 3
height = 32
width = 32

display_steps = 20

# The Defense-GAN latent search below is stochastic (random restarts), so seed
# the RNGs to make reported accuracies reproducible across runs.
# NOTE(review): assumes get_z_sets draws its random z from torch or numpy — confirm.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

### Load classification model

In [None]:
# Build the classifier (it stays on the CPU for now) and print a layer summary
# for a single (in_channel, height, width) input.
model = CNN()
summary(model, device='cpu', input_size=(in_channel, height, width))

In [None]:
# Device for the classifier: the first GPU when CUDA is available, otherwise the
# CPU, so the notebook still runs (slowly) on machines without a GPU.
device_model = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Load the trained classifier weights. map_location='cpu' lets a checkpoint that
# was saved from any GPU be loaded regardless of which devices exist here; the
# model is then moved to device_model explicitly.
model.load_state_dict(torch.load('./checkpoints/cifar10.pth', map_location='cpu'))

model = model.to(device_model)

### Load Defense-GAN model

In [None]:
# Defense-GAN reconstruction settings. Every (rec_iter, rec_rr) pair below is
# evaluated; learning_rate, global_step and INPUT_LATENT are passed to get_z_sets.
learning_rate, global_step = 10.0, 3.0
rec_iters = [200, 500, 1000]   # presumably gradient steps per latent search
rec_rrs = [10, 15, 20]         # presumably number of random restarts
decay_rate = 0.1               # NOTE(review): not passed to get_z_sets in this notebook
generator_input_size = 32      # spatial size the generator produces / expects

INPUT_LATENT = 64              # latent vector length fed to the generator

In [None]:
# The generator was originally pinned to GPU 7 to keep it apart from the
# classifier. Clamp to the last available GPU (or use the CPU) so the notebook
# also runs on machines with fewer than 8 GPUs instead of failing with an
# invalid device ordinal.
GENERATOR_GPU = 7
if torch.cuda.is_available():
    device_generator = torch.device('cuda', min(GENERATOR_GPU, torch.cuda.device_count() - 1))
else:
    device_generator = torch.device('cpu')

In [None]:
ModelG = Generator()

generator_path = './defensive_models/gen_cifar10_gp_99.pth'

# map_location='cpu' makes the checkpoint loadable independent of the GPU it was
# saved from; the generator is moved to device_generator in a later cell.
ModelG.load_state_dict(torch.load(generator_path, map_location='cpu'))

# The generator is only used for inference (reconstruction) in this notebook.
# Switch to eval() *before* summary(): torchsummary performs a forward pass on
# random input, which in train mode would overwrite any BatchNorm running
# statistics that were just loaded. (Whether Generator contains BatchNorm or
# Dropout is not visible here — TODO confirm.)
ModelG.eval()

summary(ModelG, input_size = (INPUT_LATENT,), device = 'cpu')

In [None]:
# Place the generator on its device (nn.Module.to moves it and returns the same module).
ModelG = ModelG.to(device=device_generator)

In [None]:
# Reconstruction objective for the latent search: mean squared error
# (reduction='mean' is MSELoss's default, spelled out for clarity).
loss = nn.MSELoss(reduction='mean')

### Load test dataset

In [None]:
# adversarial dataset path
root_dir = "./adversarial/"  # directory holding the pre-generated adversarial sets

In [None]:
# Test-time preprocessing: convert each image to a float tensor, no augmentation.
# No Normalize step is applied. NOTE(review): confirm the classifier was trained
# on unnormalized inputs as well.
transform_test = transforms.Compose([transforms.ToTensor()])

In [None]:
# CIFAR-10 test split (downloaded into ./data if missing), iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
test_loader = DataLoader(testset, shuffle=False, batch_size=batch_size, num_workers=2)

### Clean images

In [None]:
# Evaluate the classifier on Defense-GAN reconstructions of the clean CIFAR-10
# test set, once for every (rec_iter, rec_rr) configuration.
model.eval()

# One accuracy (Python float) per configuration, in loop order.
save_test_results = []

for rec_iter in rec_iters:
    for rec_rr in rec_rrs:

        # Reset the statistics for every configuration; otherwise each reported
        # accuracy after the first would be a running average over all earlier
        # configurations.
        running_corrects = 0
        epoch_size = 0

        for batch_idx, (inputs, labels) in enumerate(test_loader):

            # size change: decided per batch, so one off-size batch does not
            # force resizing of every later batch.
            is_input_size_diff = inputs.size(2) != generator_input_size

            if is_input_size_diff:
                target_shape = (inputs.size(0), inputs.size(1), generator_input_size, generator_input_size)
                data = Resize_Image(target_shape, inputs).to(device_generator)
            else:
                data = inputs.to(device_generator)

            # find z*: the latent code whose generation best reconstructs `data`
            # (this search needs gradients, so it stays outside no_grad).
            _, z_sets = get_z_sets(ModelG, data, learning_rate,
                                   loss, device_generator, rec_iter = rec_iter,
                                   rec_rr = rec_rr, input_latent = INPUT_LATENT, global_step = global_step)

            z_star = get_z_star(ModelG, data, z_sets, loss, device_generator)

            # Generation, resizing and classification need no gradients; no_grad
            # avoids building autograd graphs through the classifier.
            with torch.no_grad():
                data_hat = ModelG(z_star.to(device_generator)).cpu()

                # size back
                if is_input_size_diff:
                    target_shape = (inputs.size(0), inputs.size(1), height, width)
                    data_hat = Resize_Image(target_shape, data_hat)

                data_hat = data_hat.to(device_model)
                labels = labels.to(device_model)

                outputs = model(data_hat)
                _, preds = torch.max(outputs, 1)

            # statistics (accumulated as Python ints on the host)
            running_corrects += (preds == labels).sum().item()
            epoch_size += inputs.size(0)

            if batch_idx % display_steps == 0:
                print('{:>3}/{:>3} average acc {:.4f}\r'
                      .format(batch_idx+1, len(test_loader), running_corrects / epoch_size))

            # free per-batch tensors (including the latent candidates) before the next batch
            del labels, outputs, preds, data, data_hat, z_sets, z_star

        # Guard against an empty loader instead of dividing by zero.
        test_acc = running_corrects / epoch_size if epoch_size else float('nan')

        print('rec_iter : {}, rec_rr : {}, Test Acc: {:.4f}'.format(rec_iter, rec_rr, test_acc))

        save_test_results.append(test_acc)

In [None]:
# Drop this section's data loader; the next section binds test_loader to a new dataset.
del test_loader

### FGSM

In [None]:
# Round-trip each stored adversarial sample through a PIL image back into a tensor.
# NOTE(review): ToPILImage converts float tensors to 8-bit, which would quantize
# small perturbations — confirm the stored samples are already 8-bit images.
adversarial_transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])

In [None]:
# FGSM adversarial examples, presumably read from the 'FGSM' subset of root_dir.
sample = Adversarial_Dataset(root_dir, 'FGSM', adversarial_transform)

In [None]:
# Iterate the adversarial set in its stored order (shuffle=False is the default).
test_loader = DataLoader(sample, shuffle=False, num_workers=4, batch_size=batch_size)

In [None]:
# Evaluate the classifier on Defense-GAN reconstructions of the FGSM adversarial
# set, once for every (rec_iter, rec_rr) configuration.
model.eval()

# One accuracy (Python float) per configuration, in loop order.
save_test_results = []

for rec_iter in rec_iters:
    for rec_rr in rec_rrs:

        # Reset the statistics for every configuration; otherwise each reported
        # accuracy after the first would be a running average over all earlier
        # configurations.
        running_corrects = 0
        epoch_size = 0

        for batch_idx, (inputs, labels) in enumerate(test_loader):

            # size change: decided per batch, so one off-size batch does not
            # force resizing of every later batch.
            is_input_size_diff = inputs.size(2) != generator_input_size

            if is_input_size_diff:
                target_shape = (inputs.size(0), inputs.size(1), generator_input_size, generator_input_size)
                data = Resize_Image(target_shape, inputs).to(device_generator)
            else:
                data = inputs.to(device_generator)

            # find z*: the latent code whose generation best reconstructs `data`
            # (this search needs gradients, so it stays outside no_grad).
            _, z_sets = get_z_sets(ModelG, data, learning_rate,
                                   loss, device_generator, rec_iter = rec_iter,
                                   rec_rr = rec_rr, input_latent = INPUT_LATENT, global_step = global_step)

            z_star = get_z_star(ModelG, data, z_sets, loss, device_generator)

            # Generation, resizing and classification need no gradients; no_grad
            # avoids building autograd graphs through the classifier.
            with torch.no_grad():
                data_hat = ModelG(z_star.to(device_generator)).cpu()

                # size back
                if is_input_size_diff:
                    target_shape = (inputs.size(0), inputs.size(1), height, width)
                    data_hat = Resize_Image(target_shape, data_hat)

                data_hat = data_hat.to(device_model)
                labels = labels.to(device_model)

                outputs = model(data_hat)
                _, preds = torch.max(outputs, 1)

            # statistics (accumulated as Python ints on the host)
            running_corrects += (preds == labels).sum().item()
            epoch_size += inputs.size(0)

            if batch_idx % display_steps == 0:
                print('{:>3}/{:>3} average acc {:.4f}\r'
                      .format(batch_idx+1, len(test_loader), running_corrects / epoch_size))

            # free per-batch tensors (including the latent candidates) before the next batch
            del labels, outputs, preds, data, data_hat, z_sets, z_star

        # Guard against an empty loader instead of dividing by zero.
        test_acc = running_corrects / epoch_size if epoch_size else float('nan')

        print('rec_iter : {}, rec_rr : {}, Test Acc: {:.4f}'.format(rec_iter, rec_rr, test_acc))

        save_test_results.append(test_acc)

In [None]:
# Drop this section's data loader; the next section binds test_loader to a new dataset.
del test_loader

### DeepFool

In [None]:
# Round-trip each stored adversarial sample through a PIL image back into a tensor.
# NOTE(review): ToPILImage converts float tensors to 8-bit, which would quantize
# small perturbations — confirm the stored samples are already 8-bit images.
adversarial_transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])

In [None]:
# DeepFool adversarial examples, presumably read from the 'DF' subset of root_dir.
sample = Adversarial_Dataset(root_dir, 'DF', adversarial_transform)

In [None]:
# Iterate the adversarial set in its stored order (shuffle=False is the default).
test_loader = DataLoader(sample, shuffle=False, num_workers=4, batch_size=batch_size)

In [None]:
# Evaluate the classifier on Defense-GAN reconstructions of the DeepFool
# adversarial set, once for every (rec_iter, rec_rr) configuration.
model.eval()

# One accuracy (Python float) per configuration, in loop order.
save_test_results = []

for rec_iter in rec_iters:
    for rec_rr in rec_rrs:

        # Reset the statistics for every configuration; otherwise each reported
        # accuracy after the first would be a running average over all earlier
        # configurations.
        running_corrects = 0
        epoch_size = 0

        for batch_idx, (inputs, labels) in enumerate(test_loader):

            # size change: decided per batch, so one off-size batch does not
            # force resizing of every later batch.
            is_input_size_diff = inputs.size(2) != generator_input_size

            if is_input_size_diff:
                target_shape = (inputs.size(0), inputs.size(1), generator_input_size, generator_input_size)
                data = Resize_Image(target_shape, inputs).to(device_generator)
            else:
                data = inputs.to(device_generator)

            # find z*: the latent code whose generation best reconstructs `data`
            # (this search needs gradients, so it stays outside no_grad).
            _, z_sets = get_z_sets(ModelG, data, learning_rate,
                                   loss, device_generator, rec_iter = rec_iter,
                                   rec_rr = rec_rr, input_latent = INPUT_LATENT, global_step = global_step)

            z_star = get_z_star(ModelG, data, z_sets, loss, device_generator)

            # Generation, resizing and classification need no gradients; no_grad
            # avoids building autograd graphs through the classifier.
            with torch.no_grad():
                data_hat = ModelG(z_star.to(device_generator)).cpu()

                # size back
                if is_input_size_diff:
                    target_shape = (inputs.size(0), inputs.size(1), height, width)
                    data_hat = Resize_Image(target_shape, data_hat)

                data_hat = data_hat.to(device_model)
                labels = labels.to(device_model)

                outputs = model(data_hat)
                _, preds = torch.max(outputs, 1)

            # statistics (accumulated as Python ints on the host)
            running_corrects += (preds == labels).sum().item()
            epoch_size += inputs.size(0)

            if batch_idx % display_steps == 0:
                print('{:>3}/{:>3} average acc {:.4f}\r'
                      .format(batch_idx+1, len(test_loader), running_corrects / epoch_size))

            # free per-batch tensors (including the latent candidates) before the next batch
            del labels, outputs, preds, data, data_hat, z_sets, z_star

        # Guard against an empty loader instead of dividing by zero.
        test_acc = running_corrects / epoch_size if epoch_size else float('nan')

        print('rec_iter : {}, rec_rr : {}, Test Acc: {:.4f}'.format(rec_iter, rec_rr, test_acc))

        save_test_results.append(test_acc)

In [None]:
# Drop this section's data loader; the next section binds test_loader to a new dataset.
del test_loader

### Saliency Map

In [None]:
# Round-trip each stored adversarial sample through a PIL image back into a tensor.
# NOTE(review): ToPILImage converts float tensors to 8-bit, which would quantize
# small perturbations — confirm the stored samples are already 8-bit images.
adversarial_transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])

In [None]:
# Saliency-map adversarial examples, presumably read from the 'SM' subset of root_dir.
sample = Adversarial_Dataset(root_dir, 'SM', adversarial_transform)

In [None]:
# Iterate the adversarial set in its stored order (shuffle=False is the default).
test_loader = DataLoader(sample, shuffle=False, num_workers=4, batch_size=batch_size)

In [None]:
# Evaluate the classifier on Defense-GAN reconstructions of the saliency-map
# adversarial set, once for every (rec_iter, rec_rr) configuration.
model.eval()

# One accuracy (Python float) per configuration, in loop order.
save_test_results = []

for rec_iter in rec_iters:
    for rec_rr in rec_rrs:

        # Reset the statistics for every configuration; otherwise each reported
        # accuracy after the first would be a running average over all earlier
        # configurations.
        running_corrects = 0
        epoch_size = 0

        for batch_idx, (inputs, labels) in enumerate(test_loader):

            # size change: decided per batch, so one off-size batch does not
            # force resizing of every later batch.
            is_input_size_diff = inputs.size(2) != generator_input_size

            if is_input_size_diff:
                target_shape = (inputs.size(0), inputs.size(1), generator_input_size, generator_input_size)
                data = Resize_Image(target_shape, inputs).to(device_generator)
            else:
                data = inputs.to(device_generator)

            # find z*: the latent code whose generation best reconstructs `data`
            # (this search needs gradients, so it stays outside no_grad).
            _, z_sets = get_z_sets(ModelG, data, learning_rate,
                                   loss, device_generator, rec_iter = rec_iter,
                                   rec_rr = rec_rr, input_latent = INPUT_LATENT, global_step = global_step)

            z_star = get_z_star(ModelG, data, z_sets, loss, device_generator)

            # Generation, resizing and classification need no gradients; no_grad
            # avoids building autograd graphs through the classifier.
            with torch.no_grad():
                data_hat = ModelG(z_star.to(device_generator)).cpu()

                # size back
                if is_input_size_diff:
                    target_shape = (inputs.size(0), inputs.size(1), height, width)
                    data_hat = Resize_Image(target_shape, data_hat)

                data_hat = data_hat.to(device_model)
                labels = labels.to(device_model)

                outputs = model(data_hat)
                _, preds = torch.max(outputs, 1)

            # statistics (accumulated as Python ints on the host)
            running_corrects += (preds == labels).sum().item()
            epoch_size += inputs.size(0)

            if batch_idx % display_steps == 0:
                print('{:>3}/{:>3} average acc {:.4f}\r'
                      .format(batch_idx+1, len(test_loader), running_corrects / epoch_size))

            # free per-batch tensors (including the latent candidates) before the next batch
            del labels, outputs, preds, data, data_hat, z_sets, z_star

        # Guard against an empty loader instead of dividing by zero.
        test_acc = running_corrects / epoch_size if epoch_size else float('nan')

        print('rec_iter : {}, rec_rr : {}, Test Acc: {:.4f}'.format(rec_iter, rec_rr, test_acc))

        save_test_results.append(test_acc)

In [None]:
# Drop this section's data loader now that its evaluation is finished.
del test_loader