# Final Project - Introduction to Machine Learning (CS680)
## Evaluating the Use of Fast Adversarial Training in Defending Against Adversarial Patch Attacks
### Pedro Maia de Sampaio Ferraz


### Clone FAST-BAT repo and train models

In [None]:
# Fetch the FastBAT repository from GitHub; the training script (train.py) invoked
# below is expected to come from this checkout.
!git clone https://github.com/NormalUhr/FastBAT.git

In [None]:
# Work from the repository root so that train.py and the relative paths used later
# (./data/, results/checkpoints/...) resolve inside the checkout.
# NOTE(review): /content looks like a Colab working directory; adjust when running elsewhere.
%cd /content/FastBAT

In [None]:
# Train the three CIFAR-10 models compared in this notebook. Later cells load the
# resulting weights from results/checkpoints/.

# Regular model: `pgd` mode with zero attack steps.
# NOTE(review): presumably this amounts to standard (non-adversarial) training;
# confirm against train.py.
!python train.py \
    --mode pgd \
    --dataset CIFAR10 \
    --attack_step 0 \
    --lr_scheduler multistep \
    --lr_max 0.1 \
    --dataset_val_ratio 0.01

# FAST-BAT model (attack_eps 8, evaluated with 10 attack steps)
!python train.py \
    --mode fast_bat \
    --dataset CIFAR10 \
    --attack_eps 8 \
    --attack_step_test 10 \
    --dataset_val_ratio 0.01

# FAST-AT model (same attack settings as the FAST-BAT run above)
!python train.py \
    --mode fast_at \
    --dataset CIFAR10 \
    --attack_eps 8 \
    --attack_step_test 10 \
    --dataset_val_ratio 0.01

### Define classes for training an adversarial patch

In [None]:
import math
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import tqdm as tqdm
from torch import autograd
from torchvision import transforms
from datasets import *
from model_zoo import *
from torch.nn.modules.utils import _pair, _quadruple

class PatchTransformer(nn.Module):
    """Turn one square patch into a batch of image-sized, zero-padded patch layers.

    Every output element outside the patch is exactly 0, while patch values are
    clamped away from 0, so a consumer can use ``layer == 0`` as the
    "no patch here" mask.
    """

    def __init__(self):
        super(PatchTransformer, self).__init__()

    def forward(self, adv_patch, img_size, batch_size, type="random"):
        """Replicate ``adv_patch`` ``batch_size`` times and pad each copy to the image size.

        Args:
            adv_patch: Patch tensor whose last two dimensions are the (square)
                patch height and width, e.g. ``(C, p, p)``.
            img_size: Side length of the square target image.
            batch_size: Number of padded patch layers to produce.
            type: Placement strategy. ``"corner"`` puts the patch in the
                top-left corner; ``"random"`` draws an independent position
                for every batch element. (The name shadows the builtin but is
                kept for backward compatibility with keyword callers.)

        Returns:
            A ``(batch_size, C, img_size, img_size)`` tensor for ``"corner"``
            and ``"random"``. For any other ``type`` the expanded, clamped but
            unpadded patch batch is returned, as in the original code.

        Raises:
            ValueError: If the patch is larger than the image.
        """
        patch_size = adv_patch.size(-1)
        pad_size = img_size - patch_size
        if pad_size < 0:
            raise ValueError(
                f"patch size {patch_size} exceeds image size {img_size}"
            )

        # One (shared-memory) view of the patch per batch element.
        adv_patch = adv_patch.expand(batch_size, *adv_patch.shape)

        # Keep patch values strictly inside (0, 1): an exact 0 would be read as
        # "transparent" by the == 0 mask used downstream.
        adv_patch = torch.clamp(adv_patch, 0.000001, 0.999999)

        if type == "corner":
            pad_func = nn.ConstantPad2d((0, pad_size, 0, pad_size), 0)
            return pad_func(adv_patch)

        if type == "random":
            # torch.randint's upper bound is exclusive, so use pad_size + 1 to
            # allow the patch to touch the right/bottom edge (and to keep the
            # call valid when the patch fills the whole image, pad_size == 0).
            offsets = torch.randint(pad_size + 1, (batch_size, 2)).tolist()

            # Allocate on the patch's own device/dtype instead of forcing CUDA,
            # and take the channel count from the patch instead of assuming 3.
            padded = adv_patch.new_zeros(
                (batch_size, *adv_patch.shape[1:-2], img_size, img_size)
            )
            for i, (left, top) in enumerate(offsets):
                # Slice assignment is tracked by autograd, so gradients still
                # reach the original patch.
                padded[i, ..., top:top + patch_size, left:left + patch_size] = adv_patch[i]
            return padded

        return adv_patch

class PatchApplier(nn.Module):
    """Paste a padded patch layer onto a batch of images."""

    def __init__(self):
        super().__init__()

    def forward(self, img, adv_patch):
        """Return ``img`` with every non-zero element of ``adv_patch`` written over it.

        Elements where ``adv_patch`` equals 0 are treated as transparent and
        keep the image value; all other elements take the patch value.
        """
        transparent = adv_patch.eq(0)
        patched = torch.where(transparent, img, adv_patch)
        return patched

In [None]:
from datetime import datetime

class PatchTrainer(object):
    """Optimise a square adversarial patch against a fixed, pre-trained classifier.

    The classifier weights are loaded from a checkpoint and never updated; only
    the patch pixels are optimised, by maximising the cross-entropy of the
    patched images.
    """

    def __init__(self, patch_size, model_path, device):
        """Load the data loaders and the frozen classifier.

        Args:
            patch_size: Side length, in pixels, of the square patch.
            model_path: Path of the checkpoint passed to ``torch.load``.
            device: Device specifier (e.g. ``"cuda"``) used both for loading
                the checkpoint and for running the model.
        """
        # cifar10_dataloader / PreActResNet18 come from the star-imports of
        # `datasets` and `model_zoo` (presumably the FastBAT repo - confirm).
        train_dl, val_dl, test_dl, norm_layer, num_classes = cifar10_dataloader(
            data_dir="./data/",
            batch_size=200,
            val_ratio=0.2)
        self.patch_size = patch_size
        self.batch_size = 200
        self.img_size = 32
        # Honour the requested device everywhere instead of hard-coding CUDA.
        self.device = torch.device(device)
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.test_dl = test_dl
        self.epoch_length = len(train_dl)
        self.model = PreActResNet18(num_classes=num_classes, activation_fn=nn.ReLU)
        self.model.normalize = norm_layer
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model = self.model.eval().to(self.device)
        # The classifier is never optimised here; freezing it avoids computing
        # and storing parameter gradients on every backward pass.
        for param in self.model.parameters():
            param.requires_grad_(False)
        self.patch_applier = PatchApplier().to(self.device)
        self.patch_transformer = PatchTransformer().to(self.device)

    def generate_patch(self, type):
        """Create an initial ``(3, patch_size, patch_size)`` patch on the CPU.

        Args:
            type: ``"gray"`` (all 0.5), ``"random"`` (uniform in [0, 1)) or
                ``"transparent"`` (all 0).

        Raises:
            ValueError: If ``type`` is not one of the supported names.
        """
        shape = (3, self.patch_size, self.patch_size)
        if type == 'gray':
            adv_patch_cpu = torch.full(shape, 0.5)
        elif type == 'random':
            adv_patch_cpu = torch.rand(shape)
        elif type == 'transparent':
            adv_patch_cpu = torch.full(shape, 0.0)
        else:
            raise ValueError(
                f"unknown patch type {type!r}; expected 'gray', 'random' or 'transparent'"
            )

        return adv_patch_cpu

    def _attack_step(self, adv_patch_cpu, img_batch, lab_batch, optimizer, criterion):
        """Run one optimiser step on the patch and return the batch loss as a float."""
        img_batch = img_batch.to(self.device)
        lab_batch = lab_batch.to(self.device)
        adv_patch = adv_patch_cpu.to(self.device)

        # Size the patch batch from the actual batch so a smaller final batch
        # does not cause a shape mismatch.
        adv_batch_t = self.patch_transformer(adv_patch, self.img_size, img_batch.size(0))
        p_img_batch = self.patch_applier(img_batch, adv_batch_t)

        output = self.model(p_img_batch)
        # Negated cross-entropy: minimising it maximises the classification loss.
        loss = -criterion(output, lab_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            adv_patch_cpu.clamp_(0, 1)  # keep patch in image range

        return loss.item()

    def train(self):
        """Optimise a gray-initialised patch for one epoch.

        Returns:
            ``(adv_patch_cpu, loss_list, val_acc_list)``: the trained patch,
            the summed (negated) loss of each epoch, and the validation
            accuracies, starting with the patch-free baseline.
        """
        # Local imports keep the class usable regardless of which other
        # notebook cells have already been executed.
        import os
        import time

        import matplotlib.pyplot as plt

        n_epochs = 1

        adv_patch_cpu = self.generate_patch("gray")
        adv_patch_cpu.requires_grad_(True)
        optimizer = optim.Adam([adv_patch_cpu], lr=0.05, amsgrad=True)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=50)
        criterion = nn.CrossEntropyLoss()

        # Baseline on clean images. An all-zero "transparent" patch cannot be
        # used for this: the transformer clamps it to 1e-6, which the applier
        # then pastes as a near-black square.
        initial_accuracy = self.validate(None)
        print(f"Baseline accuracy: {initial_accuracy}")

        os.makedirs("saved_patches", exist_ok=True)

        loss_list = []
        val_acc_list = [initial_accuracy]
        for epoch in range(n_epochs):
            et0 = time.time()
            ep_loss = 0.0
            for i_batch, (img_batch, lab_batch) in enumerate(self.train_dl):
                batch_loss = self._attack_step(
                    adv_patch_cpu, img_batch, lab_batch, optimizer, criterion)
                ep_loss += batch_loss

                if i_batch % 20 == 0:
                    iteration = self.epoch_length * epoch + i_batch
                    print(f"Epoch {epoch}, batch {i_batch}, iteration {iteration}, loss: {batch_loss}")

            loss_list.append(ep_loss)

            val_acc = self.validate(adv_patch_cpu)
            val_acc_list.append(val_acc)

            scheduler.step(ep_loss)
            et1 = time.time()
            print('  EPOCH NR: ', epoch)
            print('EPOCH LOSS: ', ep_loss)
            print('EPOCH TIME: ', et1-et0)
            im = transforms.ToPILImage('RGB')(adv_patch_cpu.detach())
            plt.imshow(im)
            plt.show()
            now = datetime.now()
            im.save(f"saved_patches/patch_{now.strftime('%H:%M:%S')}.jpg")
            if self.device.type == "cuda":
                torch.cuda.empty_cache()

        return adv_patch_cpu, loss_list, val_acc_list

    def validate(self, adv_patch):
        """Return the accuracy on the validation loader as a 0-dim tensor.

        Args:
            adv_patch: Patch to paste at a random position on every image, or
                ``None`` to evaluate the unmodified images.
        """
        total = 0
        correct = torch.zeros((), dtype=torch.long, device=self.device)
        if adv_patch is not None:
            adv_patch = adv_patch.detach().to(self.device)

        # Evaluation only: do not build autograd graphs.
        with torch.no_grad():
            for img_batch, lab_batch in self.val_dl:
                img_batch = img_batch.to(self.device)
                lab_batch = lab_batch.to(self.device)
                if adv_patch is not None:
                    adv_batch_t = self.patch_transformer(
                        adv_patch, self.img_size, img_batch.size(0))
                    img_batch = self.patch_applier(img_batch, adv_batch_t)

                predicted = self.model(img_batch).argmax(dim=1)

                total += lab_batch.size(0)
                correct += (predicted == lab_batch).sum()

        # Divide by the number of samples actually evaluated; this stays correct
        # even if the loader only iterates over part of its dataset.
        val_acc = correct / total
        print(f"VALIDATION ACCURACY: {val_acc}")
        return val_acc

### Train patches on the regular, FAST-AT, and FAST-BAT models

In [None]:
import matplotlib.pyplot as plt
!mkdir saved_patches

In [None]:
# Regular model: train one patch per patch size (2..10 px) and record the
# validation accuracy measured after training.
model_path = "results/checkpoints/CIFAR10_PGD_PreActResNet-18_Eps8_.pth"
acc_iter = []
for i in range(2, 11):
    trainer = PatchTrainer(i, model_path, "cuda")
    adv_patch, loss_list, val_acc_list = trainer.train()
    acc_iter.append(val_acc_list[-1])
# Accuracies are 0-dim tensors; convert them to plain Python floats for plotting.
regular_val_acc = [acc.detach().cpu().item() for acc in acc_iter]

In [None]:
# FAST-AT model: train one patch per patch size (2..10 px) and record the
# validation accuracy measured after training.
model_path = "results/checkpoints/CIFAR10_FAST_AT_PreActResNet-18_Eps8.0_.pth"
acc_iter = []
for i in range(2, 11):
    trainer = PatchTrainer(i, model_path, "cuda")
    adv_patch, loss_list, val_acc_list = trainer.train()
    acc_iter.append(val_acc_list[-1])
# Accuracies are 0-dim tensors; convert them to plain Python floats for plotting.
fast_at_val_acc = [acc.detach().cpu().item() for acc in acc_iter]

In [None]:
# FAST-BAT model: train one patch per patch size (2..10 px) and record the
# validation accuracy measured after training.
model_path = "results/checkpoints/CIFAR10_FAST_BAT_PreActResNet-18_Eps8.0_.pth"
acc_iter = []
for i in range(2, 11):
    trainer = PatchTrainer(i, model_path, "cuda")
    adv_patch, loss_list, val_acc_list = trainer.train()
    acc_iter.append(val_acc_list[-1])
# Accuracies are 0-dim tensors; convert them to plain Python floats for plotting.
fast_bat_val_acc = [acc.detach().cpu().item() for acc in acc_iter]

In [None]:
!mkdir images

In [None]:
# Validation accuracy versus patch size for the three models (trained patches).
plt.figure()
for series, label in (
    (regular_val_acc, "Regular model"),
    (fast_at_val_acc, "FAST-AT trained model"),
    (fast_bat_val_acc, "FAST-BAT trained model"),
):
    plt.plot(range(2, 11, 1), series, label=label)
plt.legend()
plt.xlabel('size of adversarial patch')
plt.ylabel('validation accuracy')
plt.title('Comparison of adversarial patch effectiveness')
# One tick per evaluated patch size.
plt.xticks(np.arange(2, 11, 1))
plt.savefig('images/adversarial_patch_comparison.png')

### Evaluate results on random patches

In [None]:
# Regular model: validation accuracy with an untrained, uniformly random patch
# of each size (2..10 px).
model_path = "results/checkpoints/CIFAR10_PGD_PreActResNet-18_Eps8_.pth"
acc_iter = []
for i in range(2, 11):
    trainer = PatchTrainer(i, model_path, "cuda")
    random_adv_patch = trainer.generate_patch("random")
    val_acc = trainer.validate(random_adv_patch)
    acc_iter.append(val_acc)
# Accuracies are 0-dim tensors; convert them to plain Python floats for plotting.
regular_val_acc = [acc.detach().cpu().item() for acc in acc_iter]

In [None]:
# FAST-AT model: validation accuracy with an untrained, uniformly random patch
# of each size (2..10 px).
model_path = "results/checkpoints/CIFAR10_FAST_AT_PreActResNet-18_Eps8.0_.pth"
acc_iter = []
for i in range(2, 11):
    trainer = PatchTrainer(i, model_path, "cuda")
    random_adv_patch = trainer.generate_patch("random")
    val_acc = trainer.validate(random_adv_patch)
    acc_iter.append(val_acc)
# Accuracies are 0-dim tensors; convert them to plain Python floats for plotting.
fast_at_val_acc = [acc.detach().cpu().item() for acc in acc_iter]

In [None]:
# FAST-BAT model: validation accuracy with an untrained, uniformly random patch
# of each size (2..10 px).
model_path = "results/checkpoints/CIFAR10_FAST_BAT_PreActResNet-18_Eps8.0_.pth"
acc_iter = []
for i in range(2, 11):
    trainer = PatchTrainer(i, model_path, "cuda")
    random_adv_patch = trainer.generate_patch("random")
    val_acc = trainer.validate(random_adv_patch)
    acc_iter.append(val_acc)
# Accuracies are 0-dim tensors; convert them to plain Python floats for plotting.
fast_bat_val_acc = [acc.detach().cpu().item() for acc in acc_iter]

In [None]:
# Validation accuracy versus patch size for the three models (random patches).
plt.figure()
for series, label in (
    (regular_val_acc, "Regular model"),
    (fast_at_val_acc, "FAST-AT trained model"),
    (fast_bat_val_acc, "FAST-BAT trained model"),
):
    plt.plot(range(2, 11, 1), series, label=label)
plt.legend()
plt.xlabel('size of adversarial patch')
plt.ylabel('validation accuracy')
plt.title('Comparison of random patch effectiveness')
# One tick per evaluated patch size.
plt.xticks(np.arange(2, 11, 1))
plt.savefig('images/adversarial_patch_comparison_random.png')