## MNIST with a distractor

In [None]:
import torch
import torch.nn.functional as F

import torchvision

from models.resnet import ResNet18

from utils.datasets import get_dataloaders

import matplotlib.pyplot as plt

import numpy as np
import tqdm

import utils
from utils.edm_score import input_gradient

In [None]:
# Avoid Type 3 fonts in exported PDF/PS figures: fonttype 42 embeds TrueType.
import matplotlib
import seaborn as sns

matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# Global seaborn look used by every figure below (thick axes, large fonts).
sns.set_style("white")
sns.set_context(
    "notebook",
    rc={'axes.linewidth': 2, 'grid.linewidth': 1},
    font_scale=2,
)

### Create an MNIST dataset with a single distractor

In [None]:
device = 'cuda'  # target for the .to(device) calls below; requires a CUDA GPU

In [None]:
# Plain MNIST loaders; the next cell only takes one batch from them for previews.
mnist_trainloader, mnist_testloader = get_dataloaders(
    "mnist",
    batch_size=32,
)

In [None]:
# Grab the images of one training batch (labels dropped) for the previews below.
eval_images = next(iter(mnist_trainloader))[0].to(device)

In [None]:
from PIL import Image, ImageDraw, ImageFont
import random

def draw_text_simple(img, text=None):
    """Return a copy of ``img`` with the letter "A" drawn on it in white.

    ``text`` is accepted but ignored: this "simple" distractor is always "A".
    The font file path is relative to the current working directory.
    """
    canvas = img.copy()
    # Single-option choice kept as in the original; it still draws from `random`.
    font_size = random.choice([35])
    font = ImageFont.truetype("Pillow/Tests/fonts/FreeMono.ttf", font_size)
    ImageDraw.Draw(canvas).text(
        (9, 2),
        "A",
        font=font,
        stroke_width=1,
        fill="white",
        stroke_fill="white",
    )
    return canvas

# Preview: draw the distractor on a 40x40 8-bit grayscale (mode 'L') canvas
# with a black background; create_distractor later shrinks it to 28x28.
# The "AB" text is ignored by draw_text_simple, which always draws "A".
img = Image.new('L', (40, 40), color=0)
img = draw_text_simple(img, "AB")

img = np.array(img)
plt.imshow(img, cmap='gray')

In [None]:
def create_distractor(text):
    """Render the distractor on a 40x40 black canvas and shrink it to 28x28.

    Returns the resized mode-'L' image as a numpy array. ``text`` is passed
    through to draw_text_simple.
    """
    canvas = draw_text_simple(Image.new('L', (40, 40), color=0), text)
    return np.array(canvas.resize((28, 28)))

img = create_distractor("AB")
plt.imshow(img, cmap='gray')

In [None]:
import random
import string

def create_distractor_image(mnist_image, text=None):
    """Stack a 28x28 digit and a 28x28 distractor into one 56x28 image.

    The vertical order is drawn with ``np.random`` (50/50).

    Returns:
        (image, on_top): ``image`` is a (56, 28) float tensor; ``on_top`` is
        True when the *distractor* fills rows 0-27 (digit in rows 28-55) and
        False when the digit is on top.
    """
    if text is None:
        text = ''.join([random.choice(string.ascii_letters) for _ in range(2)])
    distractor = torch.Tensor(create_distractor(text)) / 255
    image = torch.zeros((56, 28))
    on_top = np.random.random() >= 0.5
    top_half, bottom_half = (distractor, mnist_image) if on_top else (mnist_image, distractor)
    image[:28, :] = top_half
    image[28:, :] = bottom_half
    return image, on_top


img2, on_top = create_distractor_image(eval_images[0][0])
plt.imshow(img2, cmap='gray')

In [None]:
# Raw MNIST train/test splits (images converted by ToTensor), used below to build the distractor dataset.
trainset, testset = (
    torchvision.datasets.MNIST('data/', train=is_train, transform=torchvision.transforms.ToTensor())
    for is_train in (True, False)
)

In [None]:
# Build the training split: each MNIST digit is paired with a distractor
# placed above or below it at random (see create_distractor_image).
# Sizes come from the dataset itself instead of a hard-coded 60000, so the
# buffers can never be too short or leave trailing all-zero rows.
# NOTE(review): labels are stored as float (torch.zeros default dtype);
# presumably the loader casts them — confirm before changing the dtype.
dataloader = torch.utils.data.DataLoader(trainset, batch_size=1)  # shuffle=False: keeps MNIST order
word_distractor_mnist_x = torch.zeros((len(trainset), 1, 56, 28))
word_distractor_mnist_y = torch.zeros((len(trainset),))
for idx, (img, label) in enumerate(dataloader):
    word_distractor_mnist_x[idx] = create_distractor_image(img[0][0])[0]
    word_distractor_mnist_y[idx] = label

In [None]:
# Build the test split the same way as the training split.
# Sizes come from the dataset itself instead of a hard-coded 10000.
# NOTE(review): labels are stored as float (torch.zeros default dtype);
# presumably the loader casts them — confirm before changing the dtype.
dataloader = torch.utils.data.DataLoader(testset, batch_size=1)  # shuffle=False: keeps MNIST order
word_distractor_mnist_test_x = torch.zeros((len(testset), 1, 56, 28))
word_distractor_mnist_test_y = torch.zeros((len(testset),))
for idx, (img, label) in enumerate(dataloader):
    word_distractor_mnist_test_x[idx] = create_distractor_image(img[0][0])[0]
    word_distractor_mnist_test_y[idx] = label

In [None]:
# Persist the dataset as one tuple: (train_x, train_y, test_x, test_y).
torch.save(
    (
        word_distractor_mnist_x,
        word_distractor_mnist_y,
        word_distractor_mnist_test_x,
        word_distractor_mnist_test_y,
    ),
    '../datasets/simple_word_distractor_mnist.pt',
)

#### (Adversarial training with PGD is done in a separate notebook.)

### Load and compare the two models

In [None]:
import os  # needed for os.path.basename below; it was never imported

device = 'cuda'

# Checkpoints: a standard ResNet18 and (per the file name, presumably) an
# L2 adversarially robust ResNet18 trained in a separate notebook.
model_files = ['../saved_models/simple_word_distractor_mnist/resnet18_reg=none_simple_word_distractor_mnist.pt',
               '../saved_models/simple_word_distractor_mnist/mnist_simple_word_distractor_adv_robust_l2.pth']

# Models are keyed by checkpoint file name and parked on the CPU; later cells
# move them to `device` only while they are being used.
models = {}
for model_file in model_files:
    model = ResNet18(in_channel=1)
    # map_location='cpu' lets GPU-saved checkpoints load on any machine; the
    # weights end up on the CPU either way.
    model.load_state_dict(torch.load(model_file, map_location='cpu'))
    model.eval()
    model.to('cpu')
    models[os.path.basename(model_file)] = model

In [None]:
# Loaders for the distractor dataset (presumably the file saved above).
trainloader, testloader = get_dataloaders(
    "simple_word_distractor_mnist",
    batch_size=32,
)

In [None]:
# Report each model's test-set result. Each model is moved to `device` only
# for the evaluation (consistent with the rest of the notebook) and is then
# parked on the CPU again.
for model_name, model in models.items():
    model.to(device)
    print(model_name, utils.datasets.test(model, testloader, device))
    model.to('cpu')

In [None]:
# Gather every test-set image (labels dropped) into one CPU tensor, stacked along dim 0.
eval_images = torch.vstack([batch.detach().cpu() for batch, _ in tqdm.tqdm(testloader)])

In [None]:
# Input gradients of every model on the whole test set, concatenated per model
# (CPU tensors, test-loader order).
input_gradients = {name: [] for name in models}
for model_name, model in models.items():
    # The models are parked on the CPU while the batches go to `device`;
    # put the model on the same device for this pass.
    model.to(device)
    for img, _ in tqdm.tqdm(testloader):
        img = img.to(device)
        gradient = input_gradient(model, img).detach().cpu()
        input_gradients[model_name].append(gradient)
    model.to('cpu')
input_gradients = {k: torch.cat(v) for k, v in input_gradients.items()}

In [None]:
# Scale each input gradient so its largest absolute entry is 1, i.e. every
# value lies in [-1, 1]. All-zero gradients stay zero instead of turning into
# NaN (0 / 0) as they would with a plain division.
for model_name in models:
    grads = input_gradients[model_name]
    max_abs = grads.flatten(start_dim=1).abs().amax(dim=1)
    max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))
    # reshape (N,) -> (N, 1, ..., 1) so it broadcasts over each sample
    input_gradients[model_name] = grads / max_abs.view(-1, *([1] * (grads.dim() - 1)))

In [None]:
# Hand-recorded flags for the first 10 test images (positions are not stored
# in the dataset); presumably 1 = distractor on top, as in create_distractor_image.
on_top = [1, 0, 0, 1, 1, 0, 0, 0, 0, 1]
__, axs = plt.subplots(nrows=1, ncols=10, figsize=(20, 30))

# Add N(0, 0.25^2) noise to the top half when the flag is set, else to the bottom half.
for idx, ax in enumerate(axs):
    img = eval_images[idx].clone().detach().squeeze()
    noise = 0.25 * torch.randn_like(img)
    rows = slice(None, 28) if on_top[idx] else slice(28, None)
    img[rows, :] = img[rows, :] + noise[rows, :]
    img = (img * 255).clip(0, 255).to(torch.uint8)
    ax.imshow(img.cpu().numpy().squeeze(), cmap='gray', vmin=0, vmax=255)
    ax.axis('off')

In [None]:
# Same hand-recorded flags as the previous cell; this time the noise goes on
# the opposite half (bottom when the flag is set, top otherwise).
on_top = [1, 0, 0, 1, 1, 0, 0, 0, 0, 1]
__, axs = plt.subplots(nrows=1, ncols=10, figsize=(20, 30))

for idx, ax in enumerate(axs):
    img = eval_images[idx].clone().detach().squeeze()
    noise = 0.25 * torch.randn_like(img)
    rows = slice(28, None) if on_top[idx] else slice(None, 28)
    img[rows, :] = img[rows, :] + noise[rows, :]
    img = (img * 255).clip(0, 255).to(torch.uint8)
    ax.imshow(img.cpu().numpy().squeeze(), cmap='gray', vmin=0, vmax=255)
    ax.axis('off')

In [None]:
model_name = 'resnet18_reg=none_simple_word_distractor_mnist.pt'
__, axs = plt.subplots(nrows=1, ncols=10, figsize=(20, 30))

# Map normalized gradients in [-1, 1] to grayscale: 0 becomes mid-gray (128).
gradients = input_gradients[model_name]
for idx, ax in enumerate(axs):
    img = (gradients[idx] * 127.5 + 128).clip(0, 255).to(torch.uint8)
    ax.imshow(img.cpu().numpy().squeeze(), cmap='gray', vmin=0, vmax=255)
    ax.axis('off')

In [None]:
model_name = 'mnist_simple_word_distractor_adv_robust_l2.pth'
__, axs = plt.subplots(nrows=1, ncols=10, figsize=(20, 30))

# Map normalized gradients in [-1, 1] to grayscale: 0 becomes mid-gray (128).
gradients = input_gradients[model_name]
for idx, ax in enumerate(axs):
    img = (gradients[idx] * 127.5 + 128).clip(0, 255).to(torch.uint8)
    ax.imshow(img.cpu().numpy().squeeze(), cmap='gray', vmin=0, vmax=255)
    ax.axis('off')

### Figure 5: Compare robustness to noise on the signal part versus the distractor part of the image

In [None]:
def softmax_l1(x, y):
    """Per-row L1 distance between the softmax distributions of two logit tensors.

    Both inputs are softmaxed over dim 1; the result has one entry per row.
    """
    p = F.softmax(x, dim=1)
    q = F.softmax(y, dim=1)
    return torch.abs(p - q).sum(dim=1)

In [None]:
mnist_trainloader, mnist_testloader = get_dataloaders("mnist", batch_size=1)

# Build a fresh evaluation set from the MNIST test split and remember where
# each distractor went: eval_images_on_top[i] is True when the distractor
# fills the top half of eval_images[i]. The placement is re-drawn here, so it
# need not match the saved simple_word_distractor_mnist test split.
eval_images = []
eval_images_on_top = []
for mnist_images, _ in mnist_testloader:
    img, on_top = create_distractor_image(mnist_images[0][0])
    eval_images.append(img)
    eval_images_on_top.append(on_top)

In [None]:
# Noise levels: sigma = 0 (clean baseline) plus 30 log-spaced values in [1e-2, 10].
data_noise_std = np.logspace(-2, 1, num=30)
data_noise_std = np.insert(data_noise_std, 0, 0)
model_results = {}


def _add_region_noise(img, noise, distractor_on_top, region):
    """Return a copy of the 56x28 ``img`` with ``noise`` added to one region.

    ``region`` is 'all' (whole image), 'signal' (the half holding the digit)
    or 'distractor' (the half holding the distractor). ``distractor_on_top``
    is the flag from create_distractor_image: True means the distractor
    fills rows 0-27 and the digit fills rows 28-55.
    """
    noisy_image = img.clone().detach()
    if region == 'all':
        return noisy_image + noise
    noise_on_top = distractor_on_top if region == 'distractor' else not distractor_on_top
    rows = slice(None, 28) if noise_on_top else slice(28, None)
    noisy_image[rows, :] = noisy_image[rows, :] + noise[rows, :]
    return noisy_image


@torch.no_grad()  # inference only: no autograd graphs
def _collect_noisy_logits(model, region):
    """Run ``model`` on every eval image at every sigma in ``data_noise_std``.

    Returns a dict mapping each sigma to the CPU logits of all images,
    stacked along dim 0 in eval-set order.
    """
    results = {sigma: [] for sigma in data_noise_std}
    pairs = zip(eval_images, eval_images_on_top)
    for img, distractor_on_top in tqdm.tqdm(pairs, total=len(eval_images)):
        for sigma in data_noise_std:
            noise = sigma * torch.randn_like(img)
            noisy_image = _add_region_noise(img, noise, distractor_on_top, region)
            logits = model(noisy_image.to(device).unsqueeze(0).unsqueeze(0))
            results[sigma].append(logits.detach().cpu())
    return {sigma: torch.vstack(batch) for sigma, batch in results.items()}


def _plot_softmax_deviation(results, label):
    """Plot the mean softmax L1 deviation from the sigma = 0 predictions."""
    original = results[0.0]
    deviations = [softmax_l1(results[sigma], original).mean().item() for sigma in data_noise_std]
    plt.plot(data_noise_std, deviations, label=label)


for model_name, model in models.items():
    # The models are parked on the CPU between cells; move the current one to
    # the same device as the inputs, and back to the CPU when done.
    model.to(device)
    all_noise_results = _collect_noisy_logits(model, 'all')
    signal_noise_results = _collect_noisy_logits(model, 'signal')
    distractor_noise_results = _collect_noisy_logits(model, 'distractor')
    model.to('cpu')

    # store results; the figure cell unpacks them in this order:
    # [entire image, distractor, signal]
    model_results[model_name] = [all_noise_results, distractor_noise_results, signal_noise_results]

    # plot
    _plot_softmax_deviation(all_noise_results, 'Noise on Entire Image')
    _plot_softmax_deviation(distractor_noise_results, 'Noise on Distractor')
    _plot_softmax_deviation(signal_noise_results, 'Noise on Signal')

    plt.xscale('log')
    plt.xlabel("Magnitude of Noise Added to Image")
    plt.ylabel("L1 Deviation with Original Softmax Score")
    plt.legend(loc=4)
    plt.title(model_name)
    plt.tight_layout()
    plt.show()

In [None]:
# Cache the logits so the figure cells below can be re-run without recomputing.
torch.save(
    model_results,
    '../data/relative_noise_robustness_results.pt',
)

In [None]:
# Rebuild the noise grid used by the experiment cell (sigma = 0 followed by 30
# log-spaced levels in [1e-2, 10]) and reload the cached logits.
# NOTE(review): newer torch versions default torch.load to weights_only=True,
# which may reject this dict's numpy-scalar keys; since this is a trusted local
# file, passing weights_only=False would be the fix if loading fails.
data_noise_std = np.insert(np.logspace(-2, 1, num=30), 0, 0)
model_results = torch.load('../data/relative_noise_robustness_results.pt')

In [None]:
sns.set_context("notebook", rc={'axes.linewidth': 2, 'grid.linewidth': 1},  font_scale=1.)

model_name_1 = 'resnet18_reg=none_simple_word_distractor_mnist.pt'
model_name_2 = 'mnist_simple_word_distractor_adv_robust_l2.pth'
plot_labels = {model_name_1: 'Resnet18', model_name_2: 'Robust Resnet18'}

for model_name in [model_name_1, model_name_2]:
    # stored order: [entire-image noise, distractor noise, signal noise]
    all_noise, distractor_noise, signal_noise = model_results[model_name]

    # Both deviations are measured against the clean (sigma = 0) predictions.
    original = all_noise[0.0]
    ratios = []
    for sigma in data_noise_std:
        sn_estimate = softmax_l1(signal_noise[sigma], original)
        dn_estimate = softmax_l1(distractor_noise[sigma], original)
        # > 1: noise on the signal moves the softmax more than noise on the
        # distractor; the 1e-12 avoids 0/0 at sigma = 0.
        ratios.append(sn_estimate.mean().item() / (dn_estimate.mean().item() + 1e-12))
    plt.plot(data_noise_std, ratios, 'o--', label=plot_labels[model_name], ms=6, lw=2)

# Reference line: equal sensitivity to signal and distractor noise.
plt.hlines(1, 0, 10, color='red', lw=2)
plt.xscale('log')
plt.xlabel("Noise Level")
plt.ylabel("Signal/Distractor")
plt.legend()
plt.ylim([0, 12])
plt.yticks([1, 4, 7, 10])
plt.title("Relative Noise Robustness (MNIST)")
# Lay out once every artist (legend, ticks, title) exists.
plt.tight_layout()
plt.savefig('../figures/relative-noise-robustness.pdf', bbox_inches='tight')
plt.show()