In [1]:
%matplotlib inline

In [2]:
from tqdm.notebook import tqdm

In [3]:
import torch
import torchvision
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.models import resnet18
import torch.nn.functional as F
import torch.nn as nn

import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers

import torchmetrics

import lightly

import matplotlib.pyplot as plt
import numpy as np

  rank_zero_deprecation(


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

# DATA

In [5]:
# data params
num_workers = 6
batch_size = 1

In [6]:
path_to_train = './data/cifar10_lightly/train/'
path_to_test = './data/cifar10_lightly/test/'

In [7]:
# No additional augmentations for the test set
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])

dataset_test = lightly.data.LightlyDataset(
    input_dir=path_to_test,
    transform=test_transforms
)

In [8]:
val_dataloader = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)

@TODO: make data for classifier and for moco

... will probably have to make a custom loader that makes both something for moco and something for classifier

# MOCO data

In [9]:
# ive made this file because the standard BaseCollateFunction from lightly
# only supports returning x0 and x1 augmentations for some image x. I also want x returned
# to send to the classifier and to add the reconstruction gradient to
from myCollate import BetterSimCLRCollateFunction

In [10]:
# MoCo v2 uses SimCLR augmentations, additionally, disable blur
# collate_fn = lightly.data.SimCLRCollateFunction(
#     input_size=32,
#     gaussian_blur=0.,
# )

collate_fn = BetterSimCLRCollateFunction(
    input_size=32,
    gaussian_blur=0.
)

In [11]:
custom_tfm = torchvision.transforms.Compose([transforms.ToPILImage(), collate_fn.transform,])

In [12]:
custom_tfm

Compose(
    ToPILImage()
    Compose(
    RandomResizedCrop(size=(32, 32), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear)
    <lightly.transforms.rotation.RandomRotate object at 0x7efb98c8afd0>
    RandomHorizontalFlip(p=0.5)
    RandomVerticalFlip(p=0.0)
    RandomApply(
    p=0.8
    ColorJitter(brightness=[0.6, 1.4], contrast=[0.6, 1.4], saturation=[0.6, 1.4], hue=[-0.1, 0.1])
)
    RandomGrayscale(p=0.2)
    <lightly.transforms.gaussian_blur.GaussianBlur object at 0x7efb98c8aee0>
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
)

In [13]:
# We use the moco augmentations for training moco
dataset_train_moco = lightly.data.LightlyDataset(
    input_dir=path_to_test,
    # transform = og_img_transforms
)

In [14]:
dataloader_train_moco = torch.utils.data.DataLoader(
    dataset_train_moco,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
    num_workers=num_workers
)

# Classifier

In [15]:
# model params
n_classes=10
pretrained_model = "./saved_models/resnet_80/epoch=73-val_loss=0.64-val_acc=0.80.ckpt"

In [16]:
from plr18 import plr18

In [17]:
model = plr18().load_from_checkpoint(pretrained_model);

In [18]:
model.to(device);
model.eval();

# Defense model

In [19]:
memory_bank_size = 4096

In [20]:
class MocoModel(pl.LightningModule):
    def __init__(self):
        super().__init__()

        # create a ResNet backbone and remove the classification head
        resnet = lightly.models.ResNetGenerator('resnet-18', 1, num_splits=8)
        backbone = nn.Sequential(
            *list(resnet.children())[:-1],
            nn.AdaptiveAvgPool2d(1),
        )

        # create a moco based on ResNet
        self.resnet_moco = \
            lightly.models.MoCo(backbone, num_ftrs=512, m=0.99, batch_shuffle=True)

        # create our loss with the optional memory bank
        self.criterion = lightly.loss.NTXentLoss(
            temperature=0.1,
            memory_bank_size=memory_bank_size)

    def forward(self, x):
        self.resnet_moco(x)
        
    def contrastive_loss(self, x0, x1):
        # calculate the contrastive loss for some transformed x -> x0, x1
        # also return grad for each of these
        self.zero_grad()
        x0.requires_grad = True
        x1.requires_grad = True
        y0, y1 = self.resnet_moco(x0, x1)
        loss = self.criterion(y0, y1)
        loss.backward()
        return x0.grad, x1.grad, loss
        

    # We provide a helper method to log weights in tensorboard
    # which is useful for debugging.
    def custom_histogram_weights(self):
        for name, params in self.named_parameters():
            self.logger.experiment.add_histogram(
                name, params, self.current_epoch)

    def training_step(self, batch, batch_idx):
        (x0, x1), _, _ = batch
        y0, y1 = self.resnet_moco(x0, x1)
        loss = self.criterion(y0, y1)
        self.log('train_loss_ssl', loss)
        return loss

    def training_epoch_end(self, outputs):
        self.custom_histogram_weights()


    def configure_optimizers(self):
        optim = torch.optim.SGD(self.resnet_moco.parameters(), lr=6e-2,
                                momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs)
        return [optim], [scheduler]


In [21]:
moco = MocoModel()
moco.load_from_checkpoint('./saved_models/resnet_moco/epoch=142-train_loss_ssl=2.46.ckpt')
moco.eval();

# Experiment

In [22]:
seed = 1
max_epochs = 150
batch_limit = 500

In [23]:
epsilons = [0, 0.001, 0.002, 0.003, 0.004, 0.008, 0.01, .05, .1, .15, .2,]

In [24]:
# FGSM attack code
def fgsm_attack(image, epsilon, data_grad):
    # Collect the element-wise sign of the data gradient
    sign_data_grad = data_grad.sign()
    # Create the perturbed image by adjusting each pixel of the input image
    perturbed_image = image + epsilon*sign_data_grad
    # Adding clipping to maintain [0,1] range
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    # Return the perturbed image
    return perturbed_image

In [25]:
def test( classifier, defender, test_loader, epsilon, batch_limit=100):

    # Accuracy counter
    correct = 0
    adv_examples = []
    og_contrastive_loss_avg = 0
    perturbed_contrastive_loss_avg = 0
    
    classifier = classifier.to(device)
    defender = defender.to(device)


    # Loop over all examples in test set
    for sample in enumerate(tqdm(dataloader_train_moco)):

        # Send the data and label to the device
        # data, target = data.to(device), target.to(device)
        (idx, (data, tfms, target, _)) = sample
        
        #         x0, x1 = tfms # transformed image pairs for moco
        #         x0 = x0.clone().to(device)
        #         x1 = x1.clone().to(device)
        

        make_tensor = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),])
        data = make_tensor(data[0])
        data = data.to(device).unsqueeze(0) # make batch dim
        target = target[0].to(device).unsqueeze(0)
        
        # Set requires_grad attribute of tensor. Important for Attack
        data.requires_grad = True

        # Forward pass the data through the model
        output = classifier(data)
        init_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability

        # If the initial prediction is wrong, dont bother attacking, just move on
        if init_pred.item() != target.item():
            continue

        # Calculate the loss
        loss = F.nll_loss(output, target)

        # Zero all existing gradients
        classifier.zero_grad()

        # Calculate gradients of model in backward pass
        loss.backward()

        # Collect datagrad
        data_grad = data.grad.data

        # Call FGSM Attack
        perturbed_data = fgsm_attack(data, epsilon, data_grad)
        
        # get perturbed data pairs for moco
        perturbed_x0 = custom_tfm(perturbed_data.squeeze(0)).to(device)
        perturbed_x1 = custom_tfm(perturbed_data.squeeze(0)).to(device)
        
        # Re-classify the perturbed image
        output = classifier(perturbed_data)
        
        # fix perturbed image --------------------------------------------------------------
        x0_grad, x1_grad, og_contrastive_loss = defender.contrastive_loss(x0, x1)
        perturbed_x0_grad, perturbed_x1_grad, perturbed_contrastive_loss = defender.contrastive_loss(perturbed_x0.unsqueeze(0), perturbed_x1.unsqueeze(0))
        og_contrastive_loss_avg += og_contrastive_loss
        perturbed_contrastive_loss_avg += perturbed_contrastive_loss
        
        
        # also produce x0 and x1 for perturbed data
        # perturbed = collate_fn.transform(data.detach().cpu().numpy())        

        # Check for success
        final_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
        if final_pred.item() == target.item():
            correct += 1
            # Special case for saving 0 epsilon examples
            if (epsilon == 0) and (len(adv_examples) < 5):
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append( (init_pred.item(), final_pred.item(), adv_ex) )
        else:
            # Save some adv examples for visualization later
            if len(adv_examples) < 5:
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append( (init_pred.item(), final_pred.item(), adv_ex) )
        
        if idx > batch_limit:
            break

    # Calculate final accuracy for this epsilon
    div = min(float(len(test_loader)), float(batch_limit))
    final_acc = correct/div
    og_contrastive_loss_avg = og_contrastive_loss_avg / div
    perturbed_contrastive_loss_avg = perturbed_contrastive_loss_avg / div
    print("Epsilon: {}\tTest Accuracy = {} / {} = {}\t og_cont_loss = {} | pert_cont_loss = {}"\
          .format(epsilon, correct, div, final_acc, og_contrastive_loss_avg, perturbed_contrastive_loss_avg))

    # Return the accuracy and an adversarial example
    return final_acc, adv_examples, og_contrastive_loss_avg, perturbed_contrastive_loss_avg

In [26]:
accuracies = []
examples = []
og_contrastive_losses = []
perturbed_contrastive_losses = []

# Run test for each epsilon
for eps in epsilons:
    acc, ex, og_contrastive_loss, perturbed_contrastive_loss = test(model, moco, val_dataloader, eps)
    accuracies.append(acc)
    examples.append(ex)
    og_contrastive_losses.append(og_contrastive_loss)
    perturbed_contrastive_losses.append(perturbed_contrastive_loss)

  0%|          | 0/10000 [00:00<?, ?it/s]

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


NameError: name 'x0' is not defined

In [None]:
## SANITY CHECK CODE in case the above doesnt work as expected ##

# correct = 0
# for d, t, _ in tqdm(test_loader):
#     data, target = d.to(device), t.to(device)
#     out = model(data).max(1, keepdim=True)[1]
# #     print('out:', out)
# #     print('target:', target)
    
#     if out.item() == target.item():
#         correct += 1
# print('CORRECT:', correct)


In [None]:
og_contrastive_losses = [t.cpu().detach().numpy() for t in og_contrastive_losses]
perturbed_contrastive_losses = [t.cpu().detach().numpy() for t in perturbed_contrastive_losses]

In [None]:
plt.figure(figsize=(5,5))
plt.plot(epsilons, og_contrastive_losses, "*-", label='Original')
plt.plot(epsilons, perturbed_contrastive_losses, "*-", label='Perturbed')
# plt.yticks(np.arange(0, 1.1, step=0.1))
# plt.xticks(np.arange(0, .35, step=0.05))
plt.title("Contrastive Loss vs Epsilon")
plt.xlabel("Epsilon")
plt.ylabel("Contrastive Loss")
plt.legend()
plt.show()

In [None]:
plt.figure(figsize=(5,5))
plt.plot(epsilons, og_contrastive_losses, "*-", label='Original')
plt.plot(epsilons, perturbed_contrastive_losses, "*-", label='Perturbed')
# plt.yticks(np.arange(0, 1.1, step=0.1))
# plt.xticks(np.arange(0, .35, step=0.05))
plt.title("Contrastive Loss vs Epsilon")
plt.xlabel("Epsilon")
plt.ylabel("Contrastive Loss")
plt.legend()
plt.show()

In [None]:
plt.figure(figsize=(5,5))
plt.plot(epsilons, og_contrastive_losses, "*-", label='Original')
plt.plot(epsilons, perturbed_contrastive_losses, "*-", label='Perturbed')
# plt.yticks(np.arange(0, 1.1, step=0.1))
# plt.xticks(np.arange(0, .35, step=0.05))
plt.title("Accuracy vs Epsilon")
plt.xlabel("Epsilon")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

In [None]:
plt.figure(figsize=(5,5))
plt.plot(epsilons, accuracies, "*-")
# plt.yticks(np.arange(0, 1.1, step=0.1))
# plt.xticks(np.arange(0, .35, step=0.05))
plt.title("Accuracy vs Epsilon")
plt.xlabel("Epsilon")
plt.ylabel("Accuracy")
plt.show()

Sample Adversarial Examples
~~~~~~~~~~~~~~~~~~~~~~~~~~~

Remember the idea of no free lunch? In this case, as epsilon increases
the test accuracy decreases **BUT** the perturbations become more easily
perceptible. In reality, there is a tradeoff between accuracy
degredation and perceptibility that an attacker must consider. Here, we
show some examples of successful adversarial examples at each epsilon
value. Each row of the plot shows a different epsilon value. The first
row is the $\epsilon=0$ examples which represent the original
“clean” images with no perturbation. The title of each image shows the
“original classification -> adversarial classification.” Notice, the
perturbations start to become evident at $\epsilon=0.15$ and are
quite evident at $\epsilon=0.3$. However, in all cases humans are
still capable of identifying the correct class despite the added noise.




In [None]:
# Plot several examples of adversarial samples at each epsilon
cnt = 0
plt.figure(figsize=(8,10))
for i in range(len(epsilons)):
    for j in range(len(examples[i])):
        cnt += 1
        plt.subplot(len(epsilons),len(examples[0]),cnt)
        plt.xticks([], [])
        plt.yticks([], [])
        if j == 0:
            plt.ylabel("{}eps".format(epsilons[i]), fontsize=14) # eps 
        orig,adv,ex = examples[i][j]
        ex = np.moveaxis(ex, 0, -1)
        plt.title("{} -> {}".format(orig, adv))
        plt.imshow(ex, cmap="gray")
plt.tight_layout()
plt.show()