In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch.nn as nn
import torch
import numpy as np
from matplotlib import pyplot as plt


In [None]:
from metrics import Metric

In [None]:
def get_adv(model, benign_examples, target_labels, metric: "Metric", c,
            step_size=1e-2, num_steps=100):
    """Search for targeted adversarial examples by sign-gradient descent.

    Minimises, summed over the batch,

        c * metric(adv, benign_examples) + CrossEntropy(model(adv), target_labels)

    starting from an all-zero tensor. Each of the ``num_steps`` steps performs
    ``adv <- adv - step_size * sign(grad)``, with sign(0) treated as +1.

    Args:
        model: callable mapping a batch of inputs to class scores. Its output
            goes to ``nn.CrossEntropyLoss``, which expects raw logits.
        benign_examples: batch of clean inputs. The result has the same shape
            and device.
        target_labels: class index each adversarial example should be
            classified as.
        metric: distance ``metric(adv, benign)``. Its output is summed, so it
            presumably returns one distance per example. TODO confirm.
        c: weight of the distance term relative to the classification loss.
        step_size: magnitude of each per-element update.
        num_steps: number of gradient steps.

    Returns:
        Detached tensor of adversarial examples. No box constraint is applied,
        so values may leave the valid input range.
    """
    loss_fn = nn.CrossEntropyLoss(reduction='sum')
    # NOTE(review): the search starts from zeros, not from the benign inputs.
    # Only the c-weighted distance term pulls it toward them.
    adversarial_examples = torch.zeros(benign_examples.shape, device=benign_examples.device)
    for _ in range(num_steps):
        adversarial_examples.requires_grad_(True)
        loss = c * metric(adversarial_examples, benign_examples).sum() \
            + loss_fn(model(adversarial_examples), target_labels)
        # autograd.grad returns only d(loss)/d(adv). Unlike loss.backward(), it
        # does not accumulate gradients into the model's parameters.
        (grad,) = torch.autograd.grad(loss, adversarial_examples)
        # Vectorised sign with sign(0) = +1 (NaN -> -1), matching the original
        # `1 if x >= 0 else -1`. Tensor.apply_ is CPU-only and per-element Python.
        direction = (grad >= 0).to(grad.dtype) * 2 - 1
        adversarial_examples = (adversarial_examples - step_size * direction).detach()
    return adversarial_examples


def lbfgs_batch(model, benign_examples, labels, metric: "Metric",
                num_classes=10, c_init=100.0, c_decay=0.1, c_min=0.01):
    """Return, per input, the closest misclassified example found over all wrong target classes.

    For every sample and every class other than its true label, ``get_adv`` is
    run with a decreasing distance weight ``c`` (``c_init``, ``c_init * c_decay``, ...).
    A sample's candidate for that target is the first result that the model no
    longer assigns to the true label. Once ``c`` has decayed to ``c_min``, the
    latest result is kept regardless. Finally, for each sample, the candidate
    with the smallest ``metric`` distance among the misclassified ones is
    returned.

    Args:
        model: classifier returning class scores of shape (batch, num_classes).
        benign_examples: clean inputs with the batch on the first dimension.
        labels: true class index per sample (sequence or 1-D tensor).
        metric: distance ``metric(adv, benign)``. Assumed to return one value
            per sample. TODO confirm against ``metrics.Metric``.
        num_classes: number of classes the model predicts.
        c_init: first value of ``c``.
        c_decay: factor applied to ``c`` after each pass.
        c_min: value of ``c`` at which the current result is accepted
            unconditionally.

    Returns:
        Tensor shaped like ``benign_examples``. If every candidate for a sample
        is still classified as its true label (all distances +inf), that
        sample's first candidate is returned.

    Raises:
        ValueError: if ``c_decay`` is not in (0, 1), since ``c`` would then
            never reach ``c_min``.
    """
    if not 0 < c_decay < 1:
        raise ValueError(f'c_decay must be in (0, 1), got {c_decay}')
    device = benign_examples.device
    labels = torch.as_tensor(labels, device=device)
    batch = len(benign_examples)
    num_targets = num_classes - 1

    # Row b lists every class except labels[b], in ascending order.
    all_classes = torch.arange(num_classes, device=device).expand(batch, num_classes)
    target_labels = all_classes[all_classes != labels.unsqueeze(1)].view(batch, num_targets)

    # candidates[b, t] is the adversarial example for sample b aimed at target_labels[b, t].
    candidates = torch.zeros((batch, num_targets, *benign_examples.shape[1:]), device=device)
    for t in range(num_targets):
        print(f'--- {t} ---')
        pending = torch.arange(batch, device=device)  # samples with no candidate for target t yet
        c = c_init
        while len(pending) > 0:
            adversarial_examples = get_adv(model, benign_examples[pending],
                                           target_labels[pending, t], metric, c)
            with torch.no_grad():
                preds = model(adversarial_examples).argmax(dim=1)
            # Relative tolerance: repeated float multiplication overshoots c_min
            # (100.0 * 0.1 * 0.1 * 0.1 * 0.1 == 0.010000000000000002 > 0.01),
            # which would otherwise add an unintended extra pass.
            if c <= c_min * (1 + 1e-9):
                done = torch.ones_like(preds, dtype=torch.bool)
            else:
                done = preds != labels[pending]
            candidates[pending[done], t] = adversarial_examples[done]
            pending = pending[~done]
            c *= c_decay

    # Score each candidate by its distance to the clean input. Candidates that
    # the model still assigns to the true label are excluded with +inf.
    norms = torch.empty(batch, num_targets, device=device)
    with torch.no_grad():
        for t in range(num_targets):
            candidate = candidates[:, t]
            preds = model(candidate).argmax(dim=1)
            norms[:, t] = metric(candidate, benign_examples)
            norms[preds == labels, t] = torch.inf
    best = norms.argmin(dim=1)
    return candidates[torch.arange(batch, device=device), best]

In [None]:
class Model(nn.Module):
    """Small CNN classifier for 1x28x28 inputs producing 10 class logits.

    The spatial size goes 28 -> 26 -> 24 -> 12 (pool) -> 10 -> 8 -> 4 (pool),
    so the flattened feature vector has 64 * 4 * 4 = 1024 entries.

    ``forward`` returns raw logits with no softmax, because ``nn.CrossEntropyLoss``
    applies log-softmax itself. Use ``torch.softmax(logits, dim=1)`` when
    probabilities are needed.
    """

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3),
            nn.ReLU(),
            nn.Dropout2d(0.2),
            nn.Conv2d(32, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Flatten must precede the first Linear: (N, 64, 4, 4) -> (N, 1024).
            nn.Flatten(),
            nn.Linear(1024, 200),
            nn.ReLU(),
            nn.Linear(200, 10),
        )

    def forward(self, x):
        """Return class logits of shape (N, 10) for a batch ``x`` of shape (N, 1, 28, 28)."""
        return self.seq(x)

In [None]:
# Seed the RNG so that the randomly initialised weights, and any attack
# results computed from them, are reproducible across Restart & Run All.
# NOTE(review): no trained weights are loaded in this cell. Presumably a
# `model.load_state_dict(...)` step is still to come; confirm before
# interpreting attack results.
SEED = 0
torch.manual_seed(SEED)
model = Model()

