# Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import math

In [None]:
import torch.nn as nn
import torch
import torch.linalg as linalg

In [None]:
from utils import MnistData, Clipper
from models import ModelManager, ModelType
from adversarials import ClassificationAdversarials

In [None]:
import numpy as np
from matplotlib import pyplot as plt

# Settings

In [None]:
# Perturbation budgets handed to the three attacks below.
# The L1 and L2 radii are the L-infinity radius multiplied by 28 * 28 and by
# 28 respectively (presumably the pixel count of a 28x28 MNIST image and its
# square root, i.e. the largest L1 / L2 norm reachable inside the L-infinity
# ball - confirm this is the intended derivation).
linfty_norm_radius = 50 / 255
lone_norm_radius = 28 * 28 * 50 / 255
ltwo_norm_radius = 28 * 50 / 255

In [None]:
# Classifier that every attack in this notebook targets.
# NOTE(review): presumably a pre-trained MNIST CNN - confirm in models.ModelManager.
model = ModelManager.get_trained(ModelType.MnistCnnB)

In [None]:
# Number of benign examples to attack; also the loop bound when saving and
# displaying the results further down.
batch_size = 10

In [None]:
# Data source used to pick the benign batch.
# NOTE(review): the meaning of the positional `True` is defined in
# utils.MnistData - confirm before changing it.
data = MnistData(True)

# CW functions

In [None]:
def solve_for_linfty(model, benign_image, label, c_lambda, max_norm):
    """Run 100 signed-gradient steps towards an L-infinity adversarial example.

    Starting from an all-zero image, each step descends
    ``max|adv - benign_image| - c_lambda * cross_entropy(model(adv), label)``
    along the sign of its gradient (step size 1e-2) and passes the candidate
    through ``Clipper.clip(benign_image, candidate, max_norm)``.

    Args:
        model: classifier; called on a batch holding the single image.
        benign_image: clean image without a batch dimension.
        label: class of ``benign_image`` (anything ``torch.Tensor([label])``
            accepts, e.g. an int or a 0-dim tensor).
        c_lambda: weight of the cross-entropy term.
        max_norm: radius forwarded to ``Clipper.clip``.

    Returns:
        The final iterate without the batch dimension when the model no
        longer predicts ``label`` for it or when ``c_lambda > 10``;
        otherwise ``None`` so the caller can retry with a larger weight.
    """
    step_size = 1e-2
    loss_fn = nn.CrossEntropyLoss(reduction='sum')
    adv = torch.zeros(benign_image.shape).unsqueeze(0)
    benign_image = benign_image.unsqueeze(0)
    # The target never changes, so build it once instead of once per step.
    target = torch.Tensor([label]).type(torch.long)
    for _ in range(100):
        # Work on a fresh leaf each step so this does not depend on whether
        # the clipper hands back a tensor that is detached from the graph.
        adv = adv.detach().requires_grad_(True)
        loss = torch.max(torch.abs(adv - benign_image)) \
            - c_lambda * loss_fn(model(adv), target)
        # autograd.grad returns only d loss / d adv; backward() would also
        # accumulate into the .grad of every model parameter.
        (grad,) = torch.autograd.grad(loss, adv)
        # Vectorised form of the original mapping: +1 where grad >= 0, else -1
        # (zeros map to +1, NaNs to -1), without a Python call per element.
        direction = torch.where(grad >= 0, torch.ones_like(grad), -torch.ones_like(grad))
        adv = Clipper.clip(benign_image, adv - step_size * direction, max_norm)
    if torch.argmax(model(adv), dim=1)[0] != label or c_lambda > 10:
        return adv.squeeze(0)
    return None

def cw_linfty(model: nn.Module, benign_examples: torch.Tensor, labels: torch.Tensor, max_norm) -> torch.Tensor:
    """Produce one L-infinity attack result per benign example.

    Each example is passed to ``solve_for_linfty`` with a loss weight that
    starts at 1e-2 and is multiplied by 1.1 after every attempt that returns
    ``None``. A ``--- <index> ---`` progress line is printed per example.

    Returns:
        A new tensor built with ``torch.Tensor`` from the results, in input order.
    """
    found = []
    for idx, image in enumerate(benign_examples):
        print(f'--- {idx} ---')
        target = labels[idx]
        weight = 1e-2
        candidate = solve_for_linfty(model, image, target, weight, max_norm)
        while candidate is None:
            # Previous weight failed: raise it by 10% and try again.
            weight *= 1.1
            candidate = solve_for_linfty(model, image, target, weight, max_norm)
        found.append(candidate)
    nested = [candidate.tolist() for candidate in found]
    return torch.Tensor(nested)

In [None]:
def lone_norm(examples: torch.Tensor) -> torch.Tensor:
    """Return the sum of absolute values over all entries as a 0-dim tensor."""
    magnitudes = examples.abs()
    return magnitudes.sum()

def solve_for_lone(model, benign_image, label, c_lambda, norm):
    """Run 100 signed-gradient steps towards an L1 adversarial example.

    Starting from an all-zero image, each step descends
    ``lone_norm(adv - benign_image) - c_lambda * cross_entropy(model(adv), label)``
    along the sign of its gradient (step size 1e-2) and passes the candidate
    through ``Clipper.clip_with_custom_norm(benign_image, candidate, lone_norm, norm)``.

    Args:
        model: classifier; called on a batch holding the single image.
        benign_image: clean image without a batch dimension.
        label: class of ``benign_image`` (anything ``torch.Tensor([label])``
            accepts, e.g. an int or a 0-dim tensor).
        c_lambda: weight of the cross-entropy term.
        norm: radius forwarded to ``Clipper.clip_with_custom_norm``.

    Returns:
        The final iterate without the batch dimension when the model no
        longer predicts ``label`` for it or when ``c_lambda > 10``;
        otherwise ``None`` so the caller can retry with a larger weight.
    """
    step_size = 1e-2
    loss_fn = nn.CrossEntropyLoss(reduction='sum')
    adv = torch.zeros(benign_image.shape).unsqueeze(0)
    benign_image = benign_image.unsqueeze(0)
    # The target never changes, so build it once instead of once per step.
    target = torch.Tensor([label]).type(torch.long)
    for _ in range(100):
        # Work on a fresh leaf each step so this does not depend on whether
        # the clipper hands back a tensor that is detached from the graph.
        adv = adv.detach().requires_grad_(True)
        loss = lone_norm(adv - benign_image) \
            - c_lambda * loss_fn(model(adv), target)
        # autograd.grad returns only d loss / d adv; backward() would also
        # accumulate into the .grad of every model parameter.
        (grad,) = torch.autograd.grad(loss, adv)
        # Vectorised form of the original mapping: +1 where grad >= 0, else -1
        # (zeros map to +1, NaNs to -1), without a Python call per element.
        direction = torch.where(grad >= 0, torch.ones_like(grad), -torch.ones_like(grad))
        adv = Clipper.clip_with_custom_norm(
            benign_image,
            adv - step_size * direction,
            lone_norm,
            norm
        )
    if torch.argmax(model(adv), dim=1)[0] != label or c_lambda > 10:
        return adv.squeeze(0)
    return None

def cw_lone(model: nn.Module, benign_examples: torch.Tensor, labels: torch.Tensor, norm) -> torch.Tensor:
    """Produce one L1 attack result per benign example.

    Each example is passed to ``solve_for_lone`` with a loss weight that
    starts at 1e-2 and is multiplied by 1.1 after every attempt that returns
    ``None``. A ``--- <index> ---`` progress line is printed per example.

    Returns:
        A new tensor built with ``torch.Tensor`` from the results, in input order.
    """
    found = []
    for idx, image in enumerate(benign_examples):
        print(f'--- {idx} ---')
        target = labels[idx]
        weight = 1e-2
        candidate = solve_for_lone(model, image, target, weight, norm)
        while candidate is None:
            # Previous weight failed: raise it by 10% and try again.
            weight *= 1.1
            candidate = solve_for_lone(model, image, target, weight, norm)
        found.append(candidate)
    nested = [candidate.tolist() for candidate in found]
    return torch.Tensor(nested)

In [None]:
def ltwo_norm(examples: torch.Tensor) -> torch.Tensor:
    """Return the square root of the sum of squared magnitudes over all entries."""
    squared = examples.abs().pow(2)
    return squared.sum().pow(1 / 2)

def solve_for_ltwo(model, benign_image, label, c_lambda, norm):
    """Run 100 signed-gradient steps towards an L2 adversarial example.

    Starting from an all-zero image, each step descends
    ``ltwo_norm(adv - benign_image) - c_lambda * cross_entropy(model(adv), label)``
    along the sign of its gradient (step size 1e-2) and passes the candidate
    through ``Clipper.clip_with_custom_norm(benign_image, candidate, ltwo_norm, norm)``.

    Args:
        model: classifier; called on a batch holding the single image.
        benign_image: clean image without a batch dimension.
        label: class of ``benign_image`` (anything ``torch.Tensor([label])``
            accepts, e.g. an int or a 0-dim tensor).
        c_lambda: weight of the cross-entropy term.
        norm: radius forwarded to ``Clipper.clip_with_custom_norm``.

    Returns:
        The final iterate without the batch dimension when the model no
        longer predicts ``label`` for it or when ``c_lambda > 10``;
        otherwise ``None`` so the caller can retry with a larger weight.
    """
    step_size = 1e-2
    loss_fn = nn.CrossEntropyLoss(reduction='sum')
    adv = torch.zeros(benign_image.shape).unsqueeze(0)
    benign_image = benign_image.unsqueeze(0)
    # The target never changes, so build it once instead of once per step.
    target = torch.Tensor([label]).type(torch.long)
    for _ in range(100):
        # Work on a fresh leaf each step so this does not depend on whether
        # the clipper hands back a tensor that is detached from the graph.
        adv = adv.detach().requires_grad_(True)
        loss = ltwo_norm(adv - benign_image) \
            - c_lambda * loss_fn(model(adv), target)
        # autograd.grad returns only d loss / d adv; backward() would also
        # accumulate into the .grad of every model parameter.
        (grad,) = torch.autograd.grad(loss, adv)
        # Vectorised form of the original mapping: +1 where grad >= 0, else -1
        # (zeros map to +1, NaNs to -1), without a Python call per element.
        direction = torch.where(grad >= 0, torch.ones_like(grad), -torch.ones_like(grad))
        adv = Clipper.clip_with_custom_norm(
            benign_image,
            adv - step_size * direction,
            ltwo_norm,
            norm
        )
    if torch.argmax(model(adv), dim=1)[0] != label or c_lambda > 10:
        return adv.squeeze(0)
    return None

def cw_ltwo(model: nn.Module, benign_examples: torch.Tensor, labels: torch.Tensor, norm) -> torch.Tensor:
    """Produce one L2 attack result per benign example.

    Each example is passed to ``solve_for_ltwo`` with a loss weight that
    starts at 1e-2 and is multiplied by 1.1 after every attempt that returns
    ``None``. A ``--- <index> ---`` progress line is printed per example.

    Returns:
        A new tensor built with ``torch.Tensor`` from the results, in input order.
    """
    found = []
    for idx, image in enumerate(benign_examples):
        print(f'--- {idx} ---')
        target = labels[idx]
        weight = 1e-2
        candidate = solve_for_ltwo(model, image, target, weight, norm)
        while candidate is None:
            # Previous weight failed: raise it by 10% and try again.
            weight *= 1.1
            candidate = solve_for_ltwo(model, image, target, weight, norm)
        found.append(candidate)
    nested = [candidate.tolist() for candidate in found]
    return torch.Tensor(nested)

# Generation

In [None]:
# Pick the batch that every attack below operates on.
# NOTE(review): presumably the first `batch_size` samples that `model`
# classifies correctly - confirm in utils.MnistData.
benign_examples, labels = data.choose_first_well_classified(batch_size, model)

In [None]:
# Run the three attacks on the same benign batch, each with its own radius.
# Every call prints one `--- <index> ---` line per example while it works.
cw_linfty_examples = cw_linfty(model, benign_examples, labels, linfty_norm_radius)
cw_lone_examples = cw_lone(model, benign_examples, labels, lone_norm_radius)
cw_ltwo_examples = cw_ltwo(model, benign_examples, labels, ltwo_norm_radius)

In [None]:
# Save'em all
from pathlib import Path


def _save_grayscale(image, path):
    """Render one 28x28 tensor as a borderless grayscale PNG at ``path``."""
    # Clear the figure first: repeated imshow calls would otherwise stack a
    # new image on the same axes for every file written.
    plt.clf()
    plt.imshow(np.array(image.detach()).reshape(28, 28), cmap='gray', vmin=0, vmax=1)
    plt.axis("off")
    plt.savefig(path, bbox_inches="tight", pad_inches=0)


# Build the path from components: a hard-coded backslash is a directory
# separator only on Windows and becomes part of the file name elsewhere.
output_dir = Path("AEXAMPLES") / "CW_NORMS"
# savefig does not create missing directories.
output_dir.mkdir(parents=True, exist_ok=True)

# File-name prefix -> batch of images saved under that prefix.
image_sets = {
    "benign": benign_examples,
    "cw_linfty": cw_linfty_examples,
    "cw_lone": cw_lone_examples,
    "cw_ltwo": cw_ltwo_examples,
}
for i in range(batch_size):
    for prefix, images in image_sets.items():
        _save_grayscale(images[i], output_dir / f"{prefix}_{i}.png")

# Report, per attack, the size of what get_adversarials returns.
# NOTE(review): presumably the examples that actually fool the model -
# confirm in adversarials.ClassificationAdversarials.
cw_linfty_adversarials = ClassificationAdversarials.get_adversarials(model, benign_examples, labels, cw_linfty_examples)
print(f'cw_linfty: {len(cw_linfty_adversarials)}')

cw_lone_adversarials = ClassificationAdversarials.get_adversarials(model, benign_examples, labels, cw_lone_examples)
print(f'cw_lone: {len(cw_lone_adversarials)}')

cw_ltwo_adversarials = ClassificationAdversarials.get_adversarials(model, benign_examples, labels, cw_ltwo_examples)
print(f'cw_ltwo: {len(cw_ltwo_adversarials)}')

# Ad Hoc experiments

In [None]:
def ltwo_norm_mnist(examples: torch.Tensor) -> torch.Tensor:
    """Return one L2 norm per example of a 4-D batch.

    Squared magnitudes are summed over dimension 2 twice and then over
    dimension 1, leaving one value per entry of dimension 0.
    """
    energy = examples.abs().pow(2)
    # Same reduction order as summing dim 2, dim 2 again, then dim 1.
    for axis in (2, 2, 1):
        energy = energy.sum(dim=axis)
    return energy.pow(1 / 2)

def phi(model, benign_examples, delta, labels, ltwo_norm_radius, c_lambda):
    """Evaluate the residual vector of the system iterated by ``cw_newton_ltwo``.

    With ``loss`` the summed L1 distance between ``model(benign_examples + delta)``
    and the one-hot encoding of ``labels`` over 10 classes, the residuals are

    * ``phi_value = -2 * c_lambda * delta - d loss / d delta`` per example,
      flattened to ``(batch, dim)``;
    * ``phi_n_1 = ltwo_norm_radius - sum(delta ** 2)`` per example.

    Side effects: switches ``requires_grad`` on for ``delta`` and ``c_lambda``,
    zeroes any existing ``.grad`` on both, and fills ``delta.grad`` through
    ``loss.backward``.

    Args:
        model: classifier evaluated on ``benign_examples + delta``.
        benign_examples: batch of clean images.
        delta: current perturbation, same shape as ``benign_examples``
            (reduced over dims 2, 2, 1, so it is expected to be 4-D).
        labels: iterable of class labels, compared against ``range(10)``.
        ltwo_norm_radius: constant added in the last residual.
        c_lambda: per-example multiplier, reshaped here to ``(batch, 1)``.

    Returns:
        Tuple ``(phi_value, phi_n_1)``.
    """
    loss_fn = nn.L1Loss(reduction='sum')
    delta.requires_grad = True
    if delta.grad is not None:
        delta.grad.zero_()
    c_lambda.requires_grad = True
    if c_lambda.grad is not None:
        c_lambda.grad.zero_()
    # dim = number of entries per example (product of all non-batch dimensions).
    dim = 1
    for i in range(1, len(benign_examples.shape)):
        dim *= benign_examples.shape[i]
    # One row per label with a single 1 at the label's index.
    one_hot = torch.Tensor([[1 if label == j else 0 for j in range(10)] for label in labels])
    loss = loss_fn(model(benign_examples + delta),one_hot)
    # NOTE(review): backward is called without create_graph=True, so
    # delta.grad carries no autograd history and later derivatives of
    # phi_value treat that term as a constant - confirm this is intended.
    loss.backward(retain_graph=True)
    phi_value = (- 2 * delta.reshape(len(benign_examples), dim) * c_lambda.reshape(len(benign_examples), 1) \
        - delta.grad.reshape(len(benign_examples), dim)).reshape(len(benign_examples), dim)
    phi_n_1 = - (delta ** 2).sum(dim=2).sum(dim=2).sum(dim=1) + ltwo_norm_radius
    return phi_value, phi_n_1

def cw_newton_ltwo(model, benign_examples, labels, ltwo_norm_radius):
    """Experimental attack: iterate Newton-style updates on the residuals from ``phi``.

    The unknowns per example are the flattened perturbation ``delta`` plus one
    multiplier ``c_lambda``. Each iteration builds a ``(dim + 1) x (dim + 1)``
    matrix of derivatives of the residuals by repeated backward passes, solves
    the linear system with ``linalg.solve`` and subtracts the solution from the
    unknowns. Iteration stops when the summed squared change is at most
    ``batch * 1e-7`` or when the solve raises ``RuntimeError``.

    Side effect: sets ``benign_examples.requires_grad = True`` on the caller's
    tensor.

    Args:
        model: classifier passed through to ``phi``.
        benign_examples: batch of clean images; reshapes below hard-code
            ``(batch, 1, 28, 28)``.
        labels: class labels passed through to ``phi``.
        ltwo_norm_radius: constant passed through to ``phi``.

    Returns:
        ``benign_examples + delta`` for the final ``delta``.
    """
    benign_examples.requires_grad = True
    # dim = number of entries per example (product of all non-batch dimensions).
    dim = 1
    for i in range(1, len(benign_examples.shape)):
        dim *= benign_examples.shape[i]
    # Random start in [-0.5, 0.5), rescaled so each example has unit L2 norm.
    delta = torch.rand(benign_examples.shape) - 0.5
    norms_inverse = (1 / ltwo_norm_mnist(delta)).reshape(len(benign_examples),1)
    delta = (norms_inverse * delta.reshape(len(benign_examples), dim)).reshape(len(benign_examples), 1, 28, 28).detach()
    c_lambda = torch.ones(len(benign_examples), 1)
    c_lambda.requires_grad = True
    while True:
        # Per-example matrix of derivatives; row i holds the derivatives of residual i.
        d_phi_value = torch.zeros(len(benign_examples), dim + 1, dim + 1)
        phi_value, phi_n_1 = phi(model, benign_examples, delta, labels, ltwo_norm_radius, c_lambda)
        # One backward pass per residual component fills the first `dim` rows.
        for i in range(dim):
            if delta.grad is not None:
                delta.grad.zero_()
            if c_lambda.grad is not None:
                c_lambda.grad.zero_()
            phi_value_i = phi_value[:, i]
            # Summing over the batch lets a single backward pass serve every example.
            phi_value_i = phi_value_i.sum()
            phi_value_i.backward(retain_graph=True)
            d_phi_value[:, i, 0: dim] = delta.grad.detach().reshape(len(benign_examples), dim)
            d_phi_value[:, i, dim] = c_lambda.grad.detach().reshape(len(benign_examples))
        if delta.grad is not None:
            delta.grad.zero_()
        if c_lambda.grad is not None:
            c_lambda.grad.zero_()
        # Last row: derivatives of the norm residual w.r.t. delta; its final
        # entry (w.r.t. c_lambda) is never written and stays 0.
        phi_n_1.sum().backward()
        d_phi_value[:, dim, 0: dim] = delta.grad.detach().reshape(len(benign_examples), dim)
        # d_phi_value_inverse = torch.inverse(d_phi_value).reshape(len(benign_examples), dim+1, dim+1)
        # Stack all residuals into one column vector per example.
        whole_phi = torch.zeros(len(benign_examples), dim + 1, 1)
        whole_phi[:, 0: dim, 0] = phi_value
        whole_phi[:, dim, 0] = phi_n_1
        # Stack the current unknowns (delta, then c_lambda) the same way.
        whole_delta = torch.zeros(len(benign_examples), dim + 1)
        whole_delta[:, 0:dim] = delta.reshape(len(benign_examples), dim)
        whole_delta[:, dim] = c_lambda.reshape(len(benign_examples))
        # Flag set when the linear solve raised; forces termination below.
        zbuchlo_to = False
        try:
            product = linalg.solve( d_phi_value, whole_phi)
        except RuntimeError:
            # Fall back to the raw residual as the step and stop after this update.
            product = whole_phi
            zbuchlo_to = True
        new_whole_delta = whole_delta - product.reshape(len(benign_examples), dim+1)
        if nn.MSELoss(reduction='sum')(whole_delta, new_whole_delta) <= len(benign_examples) * 1e-7 or zbuchlo_to:
            # Final values are not detached here.
            delta, c_lambda = new_whole_delta[:, 0: dim].reshape(len(benign_examples), 1, 28, 28), new_whole_delta[:, dim].reshape(len(benign_examples),1)
            break
        else:
            # Detach so the next iteration starts a fresh graph. c_lambda is
            # left with shape (batch,) here; phi reshapes it to (batch, 1).
            delta, c_lambda = new_whole_delta[:, 0: dim].reshape(len(benign_examples), 1, 28, 28).detach(), new_whole_delta[:, dim].detach()
    return benign_examples + delta


In [None]:
# Run the experimental Newton-based attack on the same batch with the L2 radius.
cw_newton = cw_newton_ltwo(model, benign_examples, labels, ltwo_norm_radius)

In [None]:
# Show every example produced by the Newton-based attack, then evaluate the
# size of what get_adversarials returns for them (the bare expression on the
# last line is what the notebook cell displays).
# NOTE(review): presumably get_adversarials keeps the examples that fool the
# model - confirm in adversarials.ClassificationAdversarials.
for i in range(batch_size):
    MnistData.display(cw_newton[i], scale=True)
advs = ClassificationAdversarials.get_adversarials(model, benign_examples, labels, cw_newton)
len(advs)