In [None]:
import torch
from torch import nn
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR

import numpy as np

from tqdm import tqdm

import utils

from utils.datasets import get_dataloaders
from models.resnet import ResNet18

from cleverhans.torch.attacks.fast_gradient_method import fast_gradient_method
from cleverhans.torch.attacks.projected_gradient_descent import projected_gradient_descent
from easydict import EasyDict

In [None]:
# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

## Functions for robust training and eval

In [None]:
def pgd_l2(model, X, y, epsilon=[1., 4., 8.], num_iter=10):
    """Construct L2-bounded PGD adversarial perturbations for the batch X.

    A radius is drawn at random from ``epsilon`` on every call. Each example
    gets its own perturbation: it starts at a random point of norm
    ``radius / 2``, takes ``num_iter`` normalised gradient-ascent steps on the
    cross-entropy loss, and is projected back onto the L2 ball of that radius
    after every step. All norms are taken per example, so the budget does not
    depend on the batch size.

    Args:
        model: callable mapping a batch to class logits.
        X: input batch; the first dimension is assumed to index examples.
        y: target class indices, as accepted by ``nn.CrossEntropyLoss``.
        epsilon: candidate L2 radii (a sequence) or a single radius.
        num_iter: number of ascent steps.

    Returns:
        A detached tensor shaped like ``X`` whose per-example L2 norm is at
        most the sampled radius. Gradients are obtained with
        ``torch.autograd.grad``, so ``.grad`` of the model parameters is left
        untouched.
    """
    tiny = 1e-12  # keeps the normalisations finite when a norm is exactly zero

    def _per_example_norm(t):
        # L2 norm of every example, shaped (batch, 1, ..., 1) to broadcast against t
        return t.flatten(start_dim=1).norm(dim=1).view(-1, *([1] * (t.dim() - 1)))

    # randomly choose epsilon (a bare scalar is treated as a one-element list)
    epsilon = float(np.random.choice(np.atleast_1d(epsilon)))

    # choose a random starting direction and rescale each example to length epsilon / 2;
    # a Gaussian draw is used so the direction is isotropic rather than all-positive
    delta = torch.randn_like(X)
    delta = (epsilon / 2) * delta / _per_example_norm(delta).clamp_min(tiny)
    delta.requires_grad_(True)

    # fixed step size of 2.5*epsilon/num_iter as in https://arxiv.org/pdf/1706.06083.pdf
    alpha = 2.5 * epsilon / num_iter
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(num_iter):
        loss = loss_fn(model(X + delta), y)
        # differentiate w.r.t. delta only, so nothing accumulates in the model's .grad
        grad, = torch.autograd.grad(loss, delta)

        with torch.no_grad():
            # take a step of length alpha along each example's gradient direction
            delta += alpha * grad / _per_example_norm(grad).clamp_min(tiny)

            # project on the epsilon ball around X if necessary (factor is 1 inside the ball)
            delta *= (epsilon / _per_example_norm(delta).clamp_min(tiny)).clamp(max=1.0)
    return delta.detach()

In [None]:
def cleverhans_eval_l2(net, testloader, eps, max_batches=10):
    """Print clean, FGM and PGD (L2) accuracy of ``net`` and return the PGD error.

    Only the first ``max_batches`` batches of ``testloader`` are evaluated.
    Batches are moved to the notebook-level ``device`` before use.

    Args:
        net: classifier under test; switched to eval mode by this call.
        testloader: iterable yielding ``(x, y)`` batches.
        eps: L2 radius passed to both cleverhans attacks.
        max_batches: number of batches to evaluate. The default of 10 gives
            1280 samples when batches hold 128 examples; ``None`` evaluates
            the whole loader.

    Returns:
        ``1 - (PGD accuracy)`` as a fraction.
    """
    # Evaluate on clean and adversarial data
    net.eval()
    report = EasyDict(nb_test=0, correct=0, correct_fgm=0, correct_pgd=0)
    for idx, (x, y) in enumerate(testloader):
        # Stop before processing batch number `max_batches`. The former `idx > 9`
        # check ran after the batch was counted and so evaluated 11 batches.
        if max_batches is not None and idx >= max_batches:
            break
        x, y = x.to(device), y.to(device)
        x_fgm = fast_gradient_method(net, x, eps, 2)
        # 100 PGD steps of size 2.5*eps/100, L2 norm
        x_pgd = projected_gradient_descent(net, x, eps, 2.5 * eps / 100, 100, 2)
        # The attacks need gradients; scoring the resulting examples does not.
        with torch.no_grad():
            _, y_pred = net(x).max(1)  # model prediction on clean examples
            _, y_pred_fgm = net(x_fgm).max(1)  # model prediction on FGM adversarial examples
            _, y_pred_pgd = net(x_pgd).max(1)  # model prediction on PGD adversarial examples
        report.nb_test += y.size(0)
        report.correct += y_pred.eq(y).sum().item()
        report.correct_fgm += y_pred_fgm.eq(y).sum().item()
        report.correct_pgd += y_pred_pgd.eq(y).sum().item()
    print(
        "test acc on clean examples (%): {:.3f}".format(
            report.correct / report.nb_test * 100.0
        )
    )
    print(
        "test acc on FGM adversarial examples (%): {:.3f}".format(
            report.correct_fgm / report.nb_test * 100.0
        )
    )
    print(
        "test acc on PGD adversarial examples (%): {:.3f}".format(
            report.correct_pgd / report.nb_test * 100.0
        )
    )
    return 1 - report.correct_pgd / report.nb_test

## Training Loop

In [None]:
# Build the train/test loaders for the word-distractor MNIST variant (128 examples per batch).
trainloader, testloader = get_dataloaders(
    "simple_word_distractor_mnist",
    batch_size=128,
)

In [None]:
# Single-input-channel ResNet-18, placed on the selected device.
model = ResNet18(in_channel=1)
model.to(device)

# Plain SGD starting at lr 0.1; the scheduler multiplies the rate by 0.1 at milestone 3.
optimizer = optim.SGD(
    model.parameters(),
    lr=1e-1,
)
scheduler = MultiStepLR(
    optimizer,
    milestones=[3],
    gamma=0.1,
)

In [None]:
ce_loss = torch.nn.CrossEntropyLoss()

# Adversarial training: every batch is perturbed with pgd_l2 before the weight update.
for i_epoch in range(9): #15?
    model.train()
    train_loss = 0  # sum of per-sample losses over the epoch
    train_zero_one_loss = 0  # number of misclassified (perturbed) training samples
    n_seen = 0  # samples actually processed this epoch
    for img, label in tqdm(trainloader):
        img, label = img.to(device), label.to(device)
        delta = pgd_l2(model, img, label) # adversarial perturbation
        pred = model(img + delta)
        # Clear gradients only now, so anything left behind by the attack is discarded.
        optimizer.zero_grad()
        loss = ce_loss(pred, label)
        loss.backward()
        # ce_loss returns the batch mean; weight it by the batch size so that the
        # division by the sample count below yields a true per-sample average.
        batch_size = label.size(0)
        train_loss += loss.item() * batch_size
        n_seen += batch_size
        # argmax of the logits equals argmax of their softmax
        train_zero_one_loss += (pred.argmax(dim=1) != label).sum().item()
        optimizer.step()
    average_loss, acc = utils.datasets.test(model, testloader, device)
    print(f'Epoch {i_epoch}. Avg. Loss: {train_loss / n_seen}. Avg. Val Loss: {average_loss}. Acc.: {1-train_zero_one_loss / n_seen}.  Val Acc. {acc}')
    scheduler.step()

    # Keep one checkpoint per epoch.
    torch.save(model.state_dict(), f'../saved_models/mnist_simple_word_distractor_adv_robust_l2_epoch_{i_epoch}.pth')

In [None]:
# Persist the weights of the final model (state dict only).
torch.save(
    model.state_dict(),
    '../saved_models/mnist_simple_word_distractor_adv_robust_l2.pth',
)

## Eval

In [None]:
# Restore the final checkpoint. map_location remaps the stored tensors onto the
# current device, so a checkpoint written on a GPU also loads on a CPU-only machine.
model.load_state_dict(
    torch.load(
        '../saved_models/mnist_simple_word_distractor_adv_robust_l2.pth',
        map_location=device,
    )
)
model.eval()
model.to(device)

In [None]:
# Sweep the attack radius and report clean / FGM / PGD accuracy at each value.
for epsilon in (0.001, 0.5, 1, 2, 4, 8, 10, 20):
    print(' --------- {} --------- '.format(epsilon))
    cleverhans_eval_l2(model, testloader, epsilon)