In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

import numpy as np
import tqdm

import PIL
import matplotlib.pyplot as plt

from utils.datasets import test

from utils.edm_score import input_gradient

from robustness import model_utils, datasets # https://github.com/MadryLab/robustness

In [None]:
# Use the GPU when present, otherwise fall back to the CPU so the notebook still runs (slowly).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 32
num_workers = 0  # 0 = load data in the main process

In [None]:
# ImageNet mean/std normalization. It is NOT applied to the loader in this cell; the 64x64
# evaluation pipeline uses it.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Validation split only, as raw [0, 1] pixels (no normalization) for the robustness-library
# models. NOTE(review): those models presumably normalize their inputs internally — confirm.
imagenet_val_unnormalized = torchvision.datasets.ImageNet(
    '/scratch_local/datasets/ImageNet2012',
    split='val',
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]),
)

val_loader_unnormalized = torch.utils.data.DataLoader(
    imagenet_val_unnormalized,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

## Data

In [None]:
# L2-robust ResNet-18 checkpoints, one per perturbation budget epsilon
# (from https://github.com/microsoft/robust-models-transfer).
l2_epsilons = ['0', '0.01', '0.03', '0.05', '0.1', '0.25', '0.5', '1', '3', '5']
l2_robust_models = list(map('../saved_models/imagenet_robust/resnet18_l2_eps{}.ckpt'.format, l2_epsilons))

In [None]:
class Wrapper(torch.nn.Module):
    """Adapter that makes a tuple-returning model behave like a plain classifier.

    ``forward`` returns only the first element when the wrapped model returns a tuple/list
    (NOTE(review): for the robustness-library models this is presumably the logits — confirm);
    a plain tensor output is passed through unchanged. Any attribute that is not found on the
    wrapper itself is looked up on the wrapped model.
    """

    def __init__(self, wrapped):
        super().__init__()
        self.wrapped = wrapped

    def forward(self, x):
        out = self.wrapped(x)
        # Indexing a plain tensor output with [0] would silently keep only the first sample of
        # the batch, so only tuple/list outputs are unpacked.
        if isinstance(out, (tuple, list)):
            return out[0]
        return out

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            # `wrapped` itself is missing (e.g. before __init__ registered it): re-raise the
            # original, informative error instead of recursing into getattr(self.wrapped, ...).
            if name == "wrapped":
                raise
            return getattr(self.wrapped, name)

In [None]:
# Validation images downsampled to 64x64 (the resolution the diffusion model is fed) and
# upsampled back to 224x224. Unlike the unnormalized loader, this pipeline applies ImageNet
# normalization.
imagenet_val_64x64 = torchvision.datasets.ImageNet('/scratch_local/datasets/ImageNet2012',
                                                   split='val',
                                                   transform=transforms.Compose([
                                                       transforms.Resize(256),
                                                       transforms.CenterCrop(224),
                                                       transforms.Resize(64),
                                                       transforms.Resize(224),
                                                       transforms.ToTensor(),
                                                       normalize,
                                                   ]))

val_loader_64x64 = torch.utils.data.DataLoader(imagenet_val_64x64, batch_size=batch_size, shuffle=False,
                                               num_workers=num_workers, pin_memory=True)

# `resnet50_model` was not defined anywhere in the notebook, so this cell raised a NameError on a
# fresh kernel. NOTE(review): assumed to be the standard torchvision ImageNet ResNet-50 (it is fed
# normalized inputs) — confirm that it reproduces the accuracy recorded below.
resnet50_model = torchvision.models.resnet50(pretrained=True).eval().to(device)
test(resnet50_model, val_loader_64x64, device)  # 56.69

## Robust models

In [None]:
imgnet_ds = datasets.ImageNet('/scratch_local/datasets/ImageNet2012')

# One wrapped ResNet-18 per checkpoint, keyed by the checkpoint path.
# make_and_restore_model returns a sequence whose first element is the model.
models = {
    checkpoint: Wrapper(
        model_utils.make_and_restore_model(
            arch='resnet18',
            dataset=imgnet_ds,
            resume_path=checkpoint,
        )[0]
    )
    for checkpoint in l2_robust_models
}

In [None]:
# Put every network in inference mode and park it on the CPU; later cells move models onto
# `device` when they need them.
for model in models.values():
    model.eval().to('cpu')

In [None]:
# Quick look: input gradients of every model for the FIRST batch of the (unshuffled) loader only.
input_gradients = {name: [] for name in models}
for model_name, model in models.items():
    model.to(device)
    try:
        img, _ = next(iter(val_loader_unnormalized))
        gradient = input_gradient(model, img.to(device)).detach().cpu()
        input_gradients[model_name].append(gradient)
    finally:
        model.to('cpu')  # park the model again even if the gradient computation fails
input_gradients = {name: torch.cat(grads) for name, grads in input_gradients.items()}

# Scale every gradient map independently so it lies in [-1, 1]; torch.where keeps an all-zero map
# at zero instead of producing 0/0 = NaN.
for model_name, grads in input_gradients.items():
    peak = grads.abs().amax(dim=(1, 2, 3), keepdim=True)
    input_gradients[model_name] = grads / torch.where(peak > 0, peak, torch.ones_like(peak))

## The diffusion model

In [None]:
import diffusion_model

# Pretrained class-conditional EDM checkpoint for 64x64 ImageNet.
# Plain string: there is nothing to interpolate, so no f-string is needed.
EDM_IMAGENET64_URL = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-imagenet-64x64-cond-adm.pkl'
diffusion = diffusion_model.Diffusion(EDM_IMAGENET64_URL)

In [None]:
# One shuffled validation image (and its label) per iteration.
# NOTE(review): no generator is passed, so the shuffle order presumably follows torch's global
# RNG — seed it before iterating when reproducibility matters.
val_loader_unnormalized_bs_1 = torch.utils.data.DataLoader(
    dataset=imagenet_val_unnormalized,
    batch_size=1,
    shuffle=True,
    num_workers=num_workers,
)

In [None]:
import itertools

sigma = 1.2  # noise level passed to diffusion.get_score
n_preview = 14  # number of shuffled validation images processed in this quick look

resize_to_64 = transforms.Resize((64, 64), antialias=True)

images = []
scores = []
diffusion.to(device)
try:
    # islice yields exactly n_preview samples; tqdm needs `total` because islice has no len().
    for img, label in tqdm.tqdm(itertools.islice(val_loader_unnormalized_bs_1, n_preview), total=n_preview):
        images.append(img.clone().detach())
        # The diffusion model works on 64x64 inputs.
        score = diffusion.get_score(resize_to_64(img).to(device), sigma, class_labels=label)
        scores.append(score.detach().cpu())
finally:
    diffusion.to('cpu')  # free the device even if scoring fails
images = torch.vstack(images)
scores = torch.vstack(scores)

# Scale every score map independently into [-1, 1]; torch.where keeps an all-zero map at zero
# instead of producing 0/0 = NaN.
peak = scores.abs().amax(dim=(1, 2, 3), keepdim=True)
scores = scores / torch.where(peak > 0, peak, torch.ones_like(peak))

## Calculate input gradients and diffusion scores for 1000 images from the ImageNet validation set

In [None]:
import itertools
import random

N_IMAGES = 1000  # number of shuffled validation images to process

# Seed torch's global RNG (which the shuffled loader presumably draws its order from) and `random`.
torch.manual_seed(0)
random.seed(0)

resize_to_64 = transforms.Resize((64, 64), antialias=True)

images = []
scores = []
input_gradients = {name: [] for name in models}

# Move every network onto `device` ONCE instead of shuttling all classifiers and the diffusion
# model between CPU and GPU for every image. NOTE(review): assumes they fit in device memory together.
for model in models.values():
    model.to(device)
diffusion.to(device)
try:
    # islice yields exactly N_IMAGES samples; tqdm needs `total` because islice has no len().
    for img, label in tqdm.tqdm(itertools.islice(val_loader_unnormalized_bs_1, N_IMAGES), total=N_IMAGES):
        # image
        images.append(img.detach().cpu())
        img = img.to(device)
        # input gradient, for all models
        for model_name, model in models.items():
            assert img.grad is None
            input_gradients[model_name].append(input_gradient(model, img).detach().cpu())
        # diffusion score of the 64x64 version of the same image
        score = diffusion.get_score(resize_to_64(images[-1]).to(device), sigma, class_labels=label)
        scores.append(score.detach().cpu())
finally:
    # Park everything on the CPU again, even if a computation fails.
    for model in models.values():
        model.to('cpu')
    diffusion.to('cpu')

images = torch.vstack(images)
input_gradients = {name: torch.cat(grads) for name, grads in input_gradients.items()}
scores = torch.vstack(scores)

# Scale every score / gradient map independently so it lies in [-1, 1]; torch.where keeps an
# all-zero map at zero instead of producing 0/0 = NaN.
for model_name, grads in input_gradients.items():
    peak = grads.abs().amax(dim=(1, 2, 3), keepdim=True)
    input_gradients[model_name] = grads / torch.where(peak > 0, peak, torch.ones_like(peak))
peak = scores.abs().amax(dim=(1, 2, 3), keepdim=True)
scores = scores / torch.where(peak > 0, peak, torch.ones_like(peak))

In [None]:
# Cache (images, scores, input_gradients) — all CPU tensors — so the figure cells can reload them.
torch.save(
    obj=(images, scores, input_gradients),
    f='../datasets/imagenet_resnet18_img_score_gradients.pkl',
)

## Supplementary Figure

In [None]:
# Supplementary figure: one column per image (first 10 saved images). Rows: the image, the
# diffusion score, then the input gradient of every robust model.
gradient_names = list(input_gradients)  # checkpoint paths, in the order the models were loaded
nrows = 2 + len(gradient_names)
fig, axs = plt.subplots(nrows=nrows, ncols=10, figsize=(20, 24))

for idx in range(10):
    img = (images[idx, :, :, :] * 255).clip(0, 255).to(torch.uint8)
    axs[0, idx].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)))

    # score in [-1, 1] -> [0, 255]
    img = (scores[idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)
    axs[1, idx].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)))

    # input gradients of the different models, [-1, 1] -> [0, 255]
    for model_idx, model_name in enumerate(gradient_names):
        img = (input_gradients[model_name][idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)
        axs[2 + model_idx, idx].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)))

# Row labels; the epsilon is parsed from the checkpoint file name ("..._eps<EPS>.ckpt").
row_labels = ['Image', 'Score'] + [
    r'$\epsilon$ = ' + name.rsplit('eps', 1)[-1].rsplit('.ckpt', 1)[0] for name in gradient_names
]
for row_axes, row_label in zip(axs, row_labels):
    for ax in row_axes:
        # Hide ticks and frame rather than calling ax.axis('off'), which would hide the y-label too.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    row_axes[0].set_ylabel(row_label)

plt.tight_layout(pad=0.5)
plt.savefig('../figures/imagenet-big-appendix.pdf')
plt.show()

## Figure 4

In [None]:
# Reload the cached (images, scores, input_gradients) triple so the figures below do not need
# the models or the diffusion model.
images, scores, input_gradients = torch.load(
    f='../datasets/imagenet_resnet18_img_score_gradients.pkl',
)

In [None]:
import lpips

# Point the torch.hub cache at a writeable directory BEFORE constructing the metric
# (LPIPS presumably fetches pretrained backbone weights through it).
torch.hub.set_dir('../tmp/.cache/torchhub')

# AlexNet-based LPIPS ("best forward scores").
loss_fn_alex = lpips.LPIPS(net='alex')

In [None]:
N_LPIPS_IMAGES = 100  # LPIPS is averaged over the first 100 saved images only
# eps = 0 is skipped: it cannot be placed on a log-scale epsilon axis.
EXCLUDED_MODEL = '../saved_models/imagenet_robust/resnet18_l2_eps0.ckpt'

# Bilinear downsizing of the 224x224 input gradients to the 64x64 resolution of the scores.
downsize_to_64 = torchvision.transforms.Resize(size=(64, 64))

results = []
# Iterate over the saved gradients rather than `models`: the keys are the same checkpoint paths in
# the same order, and this cell then also works right after reloading the cached tensors.
with torch.no_grad():  # metric evaluation only, no autograd graph needed
    for model_name, gradients in input_gradients.items():
        if model_name == EXCLUDED_MODEL:
            continue
        print(model_name)
        distances = []
        for img_idx in range(N_LPIPS_IMAGES):
            score = scores[img_idx]
            ig = downsize_to_64(gradients[img_idx])
            distances.append(loss_fn_alex(ig, score).item())
        mean_distance = np.mean(distances)
        print(mean_distance)
        results.append(mean_distance)

In [None]:
import matplotlib
import seaborn as sns

# Embed TrueType (Type 42) fonts in PDF/PS output to avoid Type 3 fonts.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

sns.set_style("white")
sns.set_context(
    "notebook",
    font_scale=1.75,
    rc={'axes.linewidth': 2, 'grid.linewidth': 1},
)

In [None]:
# Epsilons of the robust models in the plot (eps = 0 excluded, matching the LPIPS results) and the
# ImageNet accuracies of the corresponding ResNet-18 models as reported at
# https://github.com/microsoft/robust-models-transfer, converted from percent to fractions.
epsilons = ['0.01', '0.03', '0.05', '0.1', '0.25', '0.5', '1', '3', '5']
accuracies = np.divide(np.array([69.90, 69.24, 69.15, 68.77, 67.43, 65.49, 62.32, 53.12, 45.59]), 100)

In [None]:
acc_color, lpips_color = '#1f77b4', '#ff7f0e'

fig, ax1 = plt.subplots(figsize=(10, 6))
ax2 = ax1.twinx()  # second y-axis sharing the epsilon x-axis

# Left axis: accuracy of each robust model. The epsilons are strings, so matplotlib places them
# as evenly spaced categories rather than on a (log-)numeric scale.
ax1.plot(epsilons, accuracies, 'o--', ms=10, color=acc_color, label='Accuracy')
ax1.set_ylabel('Accuracy', color=acc_color)
ax1.tick_params(axis='y', colors=acc_color)
ax1.set_xlabel('Adversarial Perturbation Budget (Epsilon)')

# Right axis: perceptual similarity (1 - LPIPS) between input gradients and diffusion scores.
ax2.plot(epsilons, [1 - distance for distance in results], 'o--', ms=10, color=lpips_color)
ax2.set_ylabel('1-LPIPS', color=lpips_color)
ax2.tick_params(axis='y', colors=lpips_color)
ax2.set_title('ImageNet')

fig.savefig('../figures/imagenet_lpips.pdf')
plt.show()

## Figure 1

In [None]:
nrows, ncols = 4, 4
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(8, 8))

most_robust = '../saved_models/imagenet_robust/resnet18_l2_eps5.ckpt'
standard = '../saved_models/imagenet_robust/resnet18_l2_eps0.ckpt'


def clipped_to_uint8(t):
    """Display a (C, H, W) map: clip to +-3 std, rescale to [-1, 1], then map to uint8 [0, 255]."""
    three_std = 3 * t.std()
    unit = t.clip(-three_std, three_std) / three_std
    return (unit * 127.5 + 128).clip(0, 255).to(torch.uint8)


def show_chw(ax, img_uint8, **imshow_kwargs):
    """imshow a (C, H, W) uint8 tensor as an (H, W, C) image."""
    ax.imshow(img_uint8.cpu().numpy().squeeze().transpose((1, 2, 0)), **imshow_kwargs)


# Rows: image, diffusion score, input gradient of the eps=5 model, input gradient of the eps=0 model.
for col, idx in enumerate([89, 93, 9, 8]):
    show_chw(axs[0, col], (images[idx] * 255).clip(0, 255).to(torch.uint8), vmin=0, vmax=255)
    show_chw(axs[1, col], clipped_to_uint8(scores[idx]))
    show_chw(axs[2, col], clipped_to_uint8(input_gradients[most_robust][idx]))
    show_chw(axs[3, col], clipped_to_uint8(input_gradients[standard][idx]))

for ax in axs.flat:
    ax.axis('off')

plt.tight_layout(pad=0.4)
plt.savefig('imagenet-gradients-big.png', dpi=600)
plt.show()