In [None]:
import os
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import tqdm
from torchvision.models import resnet18

from utils.edm_score import input_gradient

In [None]:
import matplotlib
import seaborn as sns

# Font type 42 embeds TrueType fonts in PDF/PS output, which avoids type-3 fonts.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

# Global seaborn look used by the figures: white background, thick axes, large fonts.
sns.set_style("white")
sns.set_context(
    "notebook",
    rc={'axes.linewidth': 2, 'grid.linewidth': 1},
    font_scale=2,
)

In [None]:
# Device used for all model evaluations. Prefer the GPU, but fall back to the
# CPU so the notebook still runs (slowly) on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Load the data set

In [None]:
# The validation file is unpacked into an (images, labels) pair of tensors.
# NOTE: torch.load unpickles the file, so only load data from a trusted source.
val_images, val_labels = torch.load('../datasets/imagenet-64x64-val.pth')
valset = torch.utils.data.TensorDataset(val_images, val_labels)

# Both loaders shuffle and use 8 worker processes; they differ only in batch size.
loader_kwargs = dict(shuffle=True, num_workers=8)
# Large batches for the accuracy evaluation.
valloader = torch.utils.data.DataLoader(valset, batch_size=512, **loader_kwargs)
# One image at a time for the per-image input gradients and scores.
valloader_single_images = torch.utils.data.DataLoader(valset, batch_size=1, **loader_kwargs)

## Load networks

In [None]:
# Directory holding one ResNet-18 checkpoint per perturbation budget (epsilon).
path = '../saved_models/imagenet_robust/imagenet64x64/'
models = {}
model_names = [#'resnet18_l2_eps0.0.pth',
               'resnet18_l2_eps0.01.pth',
               'resnet18_l2_eps0.1.pth',
               'resnet18_l2_eps5.0.pth',
               'resnet18_l2_eps10.0.pth',
               'resnet18_l2_eps20.0.pth',
               'resnet18_l2_eps50.0.pth',
               'resnet18_l2_eps100.0.pth',
               'resnet18_l2_eps200.0.pth',
               'resnet18_l2_eps500.0.pth',
               'resnet18_l2_eps2500.0.pth',
               'resnet18_l2_eps5000.0.pth']

# Prefix that appears on parameter names of checkpoints saved from distributed training.
module_prefix = 'module.'

for file in model_names:
    # Parse the epsilon encoded in the file name, e.g. 'resnet18_l2_eps0.01.pth' -> 0.01.
    pos = file.find('eps')
    pos = pos + len('eps')
    eps = float(file[pos:pos+file[pos:].find('.pth')])

    model = resnet18()
    # map_location='cpu' keeps the tensors on the CPU regardless of the device
    # the checkpoint was saved from; without it a GPU-saved checkpoint is
    # restored onto CUDA (and fails to load on a machine without a GPU).
    state_dict = torch.load(os.path.join(path, file), map_location='cpu')
    # remove `module.` from distributed training
    if 'module.conv1.weight' in state_dict:
        # Strip the prefix only from keys that actually carry it, instead of
        # blindly dropping the first seven characters of every key.
        state_dict = OrderedDict(
            (k[len(module_prefix):] if k.startswith(module_prefix) else k, v)
            for k, v in state_dict.items()
        )
    model.load_state_dict(state_dict)
    # Models are parked on the CPU and moved to `device` only while in use.
    model.to('cpu')
    model.eval()
    models[file] = model

## Accuracies

In [None]:
# Perturbation budgets of the loaded models, in the same order as `model_names`;
# used as x-values in the plots below.
epsilons = [
    0.01, 0.1,
    5, 10, 20, 50,
    100, 200, 500,
    2500, 5000,
]

In [None]:
# Top-1 validation accuracy of every model, keyed by checkpoint file name.
accuracies = {}
for model_name, model in models.items():
    model.to(device)
    val_zero_one_loss = 0  # number of misclassified validation images
    # Pure evaluation: disable autograd so no graph is built for each batch.
    with torch.no_grad():
        for img, label in tqdm.tqdm(valloader):
            img = img / 255  # rescale pixel values to [0, 1]
            img, label = img.to(device), label.to(device)
            pred = model(img)
            # argmax of the logits equals argmax of the softmax (softmax is monotone).
            val_zero_one_loss += (pred.argmax(dim=1) != label).sum().item()
    accuracies[model_name] = 1-val_zero_one_loss / len(valloader.dataset)
    print(f'{model_name} Val Acc. {accuracies[model_name]}')
    # Move the model back so that only one model occupies the device at a time.
    model.to('cpu')

In [None]:
# Accuracy of each model against its epsilon, on a logarithmic x-axis.
accuracy_values = list(accuracies.values())
plt.plot(epsilons, accuracy_values, 'o--')
plt.xscale('log')

## The diffusion model

In [None]:
import diffusion_model

# Noise level passed to `diffusion.get_score` when the scores are computed below.
sigma = 1.2

# Wrap the pretrained EDM ImageNet-64x64 checkpoint published by NVlabs.
diffusion = diffusion_model.Diffusion(
    'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-imagenet-64x64-cond-adm.pkl'
)

## Compute Input Gradients and Scores

In [None]:
N_images = 1000

# --- 1) Draw the images -----------------------------------------------------
# zip(range(N_images), ...) stops after exactly N_images samples (the previous
# `idx >= N_images` check ran after appending and collected one sample too
# many) and also copes with a dataset smaller than N_images.
images = []
labels = []
n_to_draw = min(N_images, len(valloader_single_images))
for _, (img, label) in tqdm.tqdm(zip(range(N_images), valloader_single_images), total=n_to_draw):
    images.append((img / 255).detach().cpu())  # pixel values rescaled to [0, 1]
    labels.append(label)

# --- 2) Input gradients, one model at a time ---------------------------------
# Each model is moved to the device once and evaluated on all images, instead
# of shuttling every model to the device and back for every single image.
input_gradients = {k : [] for k in models}
for model_name, model in models.items():
    model.to(device)
    for img in tqdm.tqdm(images, desc=model_name):
        # Work on a copy so that the stored CPU image is never touched.
        img = img.clone().to(device)
        assert img.grad is None
        ig = input_gradient(model, img).detach().cpu()
        input_gradients[model_name].append(ig)
    model.to('cpu')

# --- 3) Scores of the diffusion model ----------------------------------------
scores = []
diffusion.to(device)
for img, label in tqdm.tqdm(zip(images, labels), total=len(images), desc='score'):
    score = diffusion.get_score(img.clone().to(device), sigma, class_labels=label.to(device))
    scores.append(score.detach().cpu())
diffusion.to('cpu')

images = torch.vstack(images)
input_gradients = {k:torch.cat(v) for k,v in input_gradients.items()}
scores = torch.vstack(scores)

# scale the length of score and input gradients so that they lie in [-1,1]
# (each sample is divided by its own largest absolute entry)
for model_name, _ in models.items():
    for idx in range(input_gradients[model_name].shape[0]):
        input_gradients[model_name][idx] = input_gradients[model_name][idx]  / input_gradients[model_name][idx].abs().max()
for idx in range(scores.shape[0]): # score to [-1, 1]
    scores[idx] = scores[idx] / scores[idx].abs().max()

## Supplement Figure

In [None]:
# Grid layout: row 0 = image, row 1 = diffusion score, rows 2.. = one row per
# model (in `models` order); each of the 10 columns is one example image.
nrows = 2+len(models)
fig, axs = plt.subplots(nrows=nrows, ncols=10, figsize=(20, 26))

for col in range(10):
    # Row 0: the image itself, mapped from [0, 1] back to uint8.
    rgb = (images[col] * 255).clip(0, 255).to(torch.uint8)
    axs[0, col].imshow(rgb.cpu().numpy().squeeze().transpose((1, 2, 0)))

    # Rows 1..: signed maps in [-1, 1] (score first, then every model's input
    # gradient), shifted and scaled into the uint8 range for display.
    signed_maps = [scores[col]] + [input_gradients[model_name][col] for model_name in models]
    for row, signed in enumerate(signed_maps, start=1):
        rgb = (signed * 127.5 + 128).clip(0, 255).to(torch.uint8)
        axs[row, col].imshow(rgb.cpu().numpy().squeeze().transpose((1, 2, 0)))

# Hide the axes of every panel.
for ax in axs.flat:
    ax.axis('off')

plt.tight_layout(pad=0.5)
plt.savefig('../figures/imagenet64x64-gradients-big.pdf')
plt.show()

## LPIPS metric

In [None]:
import lpips

# Point the torch hub cache at a writable directory before LPIPS is constructed.
torch.hub.set_dir("../tmp/.cache/torchhub")

# LPIPS distance with the AlexNet backbone ("best forward scores").
loss_fn_alex = lpips.LPIPS(net='alex')

In [None]:
# Mean LPIPS distance between input gradient and score, one entry per model
# (in `models` order).
lpips_results = []

# bilinear resize of the input gradient to the 64x64 size of the score;
# the transform is loop-invariant, so it is built once.
resize_to_score = torchvision.transforms.Resize(size=(64,64))

# Only the scalar distances are needed: disable autograd so that no graph is
# built for each LPIPS forward pass.
with torch.no_grad():
    for model_name, _ in models.items():
        print(model_name)
        distances = []
        for img_idx in range(N_images): # for all images
            score = scores[img_idx]
            ig = resize_to_score(input_gradients[model_name][img_idx])
            distance = loss_fn_alex(ig, score)
            distances.append(distance.item())
        mean_distance = np.mean(distances)
        print(mean_distance)
        lpips_results.append(mean_distance)

## Figure 4

In [None]:
# Accuracy (left y-axis) and 1-LPIPS (right y-axis) against epsilon.
accuracy_color = '#1f77b4'
lpips_color = '#ff7f0e'

fig, ax1 = plt.subplots(figsize=(10,6))
ax2 = ax1.twinx()  # second y-axis sharing the x-axis with ax1

# Left axis: validation accuracy.
ax1.plot(epsilons, list(accuracies.values()), 'o--', ms=10, color=accuracy_color, label='Accuracy')
ax1.set_ylabel('Accuracy', color=accuracy_color)
ax1.tick_params(axis='y', colors=accuracy_color)

# Right axis: LPIPS distance turned into a similarity.
lpips_similarity = [1 - distance for distance in lpips_results]
ax2.plot(epsilons, lpips_similarity, 'o--', ms=10, color=lpips_color)
ax2.set_ylabel('1-LPIPS', color=lpips_color)
ax2.tick_params(axis='y', colors=lpips_color)

plt.title('ImageNet-64x64')
plt.xscale('log')
ax1.set_xlabel('Adversarial Perturbation Budget (Epsilon)')
plt.savefig('../figures/imagenet64x64_lpips.pdf')

plt.show()

## Figure 6

In [None]:
# Copy of `models` without the two checkpoints that are left out of this figure.
plot_models = dict(models)
for dropped_name in ('resnet18_l2_eps10.0.pth', 'resnet18_l2_eps100.0.pth'):
    plot_models.pop(dropped_name)

sns.set_context("notebook", rc={'axes.linewidth': 2, 'grid.linewidth': 1},  font_scale=1)


def _three_sigma_rgb(signed_map):
    """Clip a signed map at +-3 standard deviations and return it as an HxWxC uint8 array."""
    clipped = torch.clone(signed_map)
    std = clipped.std()
    clipped = clipped.clip(-3*std, 3*std) / (3*std)
    as_uint8 = (clipped * 127.5 + 128).clip(0, 255).to(torch.uint8)
    return as_uint8.cpu().numpy().squeeze().transpose((1, 2, 0))


for img_idx in [10]:
    # One row: image, score, then one panel per remaining model.
    nrows, ncols = 1, 11
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(2*ncols, 2))

    # image
    rgb = (images[img_idx] * 255).clip(0, 255).to(torch.uint8)
    axs[0].imshow(rgb.cpu().numpy().squeeze().transpose((1, 2, 0)), vmin=0, vmax=255)

    # score
    score = scores[img_idx]
    axs[1].imshow(_three_sigma_rgb(score), vmin=0, vmax=255)

    # models
    for panel, model_name in enumerate(plot_models, start=2):
        ig = input_gradients[model_name][img_idx]
        axs[panel].imshow(_three_sigma_rgb(ig))

        # lpips metric between this model's input gradient and the score,
        # printed as a similarity (1 - distance)
        distance = loss_fn_alex(ig, score)
        print(f'{1-distance.item():.2}')

    for panel in range(ncols):
        axs[panel].axis('off')

    plt.tight_layout(pad=.45)
    plt.savefig('../figures/imagenet-example.png', dpi=600)
    plt.show()
