## Figure 1, Figure 6 and Supplement Figures

In [None]:
import os

import torch

from utils.datasets import get_dataloaders
from models.resnet import ResNet18

import matplotlib.pyplot as plt

import numpy as np
import tqdm

from utils.edm_score import cifar10_score, input_gradient_sum

In [None]:
# Device used for the gradient and score computations below; fall back to the
# CPU so the notebook still runs on machines without a CUDA device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def get_reg_const(model_name, const_name):
    """Extract the numeric value stored as ``<const_name>=<value>`` in a model name.

    The value runs from just after the first ``<const_name>=`` up to the next
    underscore, or to the end of the name (minus a trailing ``.pt``) when no
    underscore follows.

    Args:
        model_name: String that encodes the constant, e.g.
            ``'resnet18_reg=gnorm_const=0.5_cifar10.pt'``.
        const_name: Name of the constant to look up, without the ``=``.

    Returns:
        The constant as a float.

    Raises:
        ValueError: If ``<const_name>=`` does not occur in ``model_name`` or
            the text that follows it is not a valid float.
    """
    magic_string = f'{const_name}='
    _, found, tail = model_name.partition(magic_string)
    if not found:
        # str.find would return -1 here and silently slice from the wrong place.
        raise ValueError(f'{magic_string!r} not found in model name {model_name!r}')
    # Everything up to the next underscore; the whole tail if there is none.
    reg_const = tail.split('_', 1)[0]
    if reg_const.endswith('.pt'):
        reg_const = reg_const[:-3]
    return float(reg_const)

### Gnorm regularized models

In [None]:
# Select one checkpoint per `gnorm_const` value found in the gnorm directory.
# Each model-selection cell overwrites `gnorms` and `model_files`, so only the
# most recently executed one takes effect.
path = '../saved_models/gnorm_c10/'

# regularization constant -> path of its representative checkpoint
representatives = {}
# os.listdir returns entries in arbitrary order; sorting makes the choice of
# representative reproducible across runs and machines.
for file in sorted(os.listdir(path)):
    if not file.endswith('.pt'):
        continue

    reg_const = get_reg_const(file, 'gnorm_const')
    # keep the first checkpoint seen for every constant
    representatives.setdefault(reg_const, os.path.join(path, file))

gnorms = sorted(representatives)
model_files = [representatives[x] for x in gnorms]

### Randomized Smoothing

In [None]:
# Select one checkpoint per `noise_level` value found in the randomized
# smoothing directory. The result is stored in `gnorms` / `model_files`, the
# same names the other model-selection cells write to, so only the most
# recently executed one takes effect.
path = '../saved_models/randomized_smoothing_c10/'

# noise level -> path of its representative checkpoint
representatives = {}
# os.listdir returns entries in arbitrary order; sorting makes the choice of
# representative reproducible across runs and machines.
for file in sorted(os.listdir(path)):
    if not file.endswith('.pt'):
        continue

    reg_const = get_reg_const(file, 'noise_level')
    # keep the first checkpoint seen for every noise level
    representatives.setdefault(reg_const, os.path.join(path, file))

gnorms = sorted(representatives)
model_files = [representatives[x] for x in gnorms]

### Smoothness Penalty

In [None]:
# Select one checkpoint per regularization constant found in the smoothness
# penalty directory. Each model-selection cell overwrites `gnorms` and
# `model_files`, so only the most recently executed one takes effect.
# NOTE(review): the constant is parsed with the 'gnorm_const' key, same as for
# the gnorm models — confirm that the smooth_c10 file names use that key.
path = '../saved_models/smooth_c10/'

# regularization constant -> path of its representative checkpoint
representatives = {}
# os.listdir returns entries in arbitrary order; sorting makes the choice of
# representative reproducible across runs and machines.
for file in sorted(os.listdir(path)):
    if not file.endswith('.pt'):
        continue

    reg_const = get_reg_const(file, 'gnorm_const')
    # keep the first checkpoint seen for every constant
    representatives.setdefault(reg_const, os.path.join(path, file))

gnorms = sorted(representatives)
model_files = [representatives[x] for x in gnorms]

### Models for Figure 1

In [None]:
# Fixed list of three checkpoints for Figure 1: the unregularized model and two
# gnorm-regularized models. Running this cell replaces whatever `model_files`
# an earlier model-selection cell produced; `gnorms` is left untouched.
model_files = ['../saved_models/standard_c10/resnet18_reg=none_cifar10.pt',
               '../saved_models/gnorm_c10/full/resnet18_reg=gnorm_const=0.07278953843983153_cifar10.pt',
               '../saved_models/gnorm_c10/full/resnet18_reg=gnorm_const=57.361525104486816_cifar10.pt']

### Load the selected model files

In [None]:
# Build one ResNet18 per selected checkpoint and keep it, in eval mode and on
# the CPU, keyed by the checkpoint's base file name.
models = {}
for model_file in model_files:
    model = ResNet18()
    # map_location='cpu' loads the tensors straight onto the CPU instead of the
    # device they were saved from, so loading neither allocates GPU memory nor
    # fails when that device is unavailable.
    model.load_state_dict(torch.load(model_file, map_location='cpu'))
    model.eval()
    model.to('cpu')
    models[os.path.basename(model_file)] = model

#### Compute input gradients and scores for the entire test set

In [None]:
# Build the CIFAR-10 loaders with batches of 32; the cells visible in this
# notebook only iterate `testloader`.
# NOTE(review): gradients and scores are collected in separate passes over
# `testloader` and later matched by index, which presumably relies on a fixed
# (unshuffled) iteration order — confirm in utils.datasets.get_dataloaders.
trainloader, testloader = get_dataloaders("cifar10", batch_size=32)

In [None]:
# Second argument passed to cifar10_score below.
sigma = 0.5

# For every loaded model, collect input_gradient_sum over the whole test set
# and concatenate the per-batch results into a single CPU tensor.
input_gradients = {}
for model_name, model in models.items():
    per_batch = []
    for img, _ in tqdm.tqdm(testloader):
        img = img.to(device)
        per_batch.append(input_gradient_sum(model, img, device=device).detach().cpu())
    input_gradients[model_name] = torch.cat(per_batch)

# Second pass over the test set: keep the images themselves and their scores.
image_batches, score_batches = [], []
for img, _ in tqdm.tqdm(testloader):
    img = img.to(device)
    image_batches.append(img.detach().cpu())
    score_batches.append(cifar10_score(img, sigma, device=device).detach().cpu())
images = torch.vstack(image_batches)
scores = torch.vstack(score_batches)

In [None]:
# Scale every score and every input gradient by its own largest absolute value
# so that all entries lie in [-1, 1].
def _scale_to_unit_range(batch):
    """Divide each sample (first dimension) of `batch` by its max absolute entry.

    Samples that are entirely zero stay zero instead of turning into NaNs.
    Assumes `batch` has a floating-point dtype.
    """
    peak = batch.abs().flatten(start_dim=1).amax(dim=1)
    # guard against division by zero for all-zero samples
    peak = peak.clamp_min(torch.finfo(batch.dtype).tiny)
    # reshape to (N, 1, ..., 1) so the division broadcasts per sample
    return batch / peak.view(-1, *([1] * (batch.dim() - 1)))


scores = _scale_to_unit_range(scores)
for model_name in models:
    input_gradients[model_name] = _scale_to_unit_range(input_gradients[model_name])

## Supplement Figures

In [None]:
# we have 30 different models, make the plot for 15: keep the models whose
# regularization constant is at an even position in the sorted `gnorms` list
representative_models = {}
for name, net in models.items():
    if get_reg_const(name, 'gnorm_const') in gnorms[0::2]:
        representative_models[name] = net
# variant for the randomized smoothing models (constant stored as 'noise_level'):
#representative_models = {k: v for k, v in models.items() if get_reg_const(k, 'noise_level') in gnorms[0::2]}

In [None]:
# Grid for the supplement: row 0 shows the first 12 test images, row 1 their
# scores, and every further row the input gradients of one representative model.
nrows = 2+len(representative_models)
ncols = 12
__, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 28))

for col in range(ncols):
    # image
    pixels = (images[col, :, :, :] * 255).clip(0, 255).to(torch.uint8)
    pixels = pixels.cpu().numpy().squeeze().transpose((1, 2, 0))
    axs[0, col].imshow(pixels, vmin=0, vmax=255)

    # score
    pixels = (scores[col, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)
    pixels = pixels.cpu().numpy().squeeze().transpose((1, 2, 0))
    axs[1, col].imshow(pixels, vmin=0, vmax=255)

    # models
    for row, model_name in enumerate(representative_models, start=2):
        pixels = (input_gradients[model_name][col, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)
        pixels = pixels.cpu().numpy().squeeze().transpose((1, 2, 0))
        axs[row, col].imshow(pixels)
        # label the row with the model name minus its first 13 characters and '.pt'
        axs[row, 0].set_ylabel(model_name[13:-3])

# NOTE(review): axis('off') also hides the y-labels set above — confirm that
# unlabeled rows are intended.
for ax in axs.flat:
    ax.axis('off')

plt.tight_layout(pad=0.5)
plt.savefig('../figures/cifar10-gradients-big.pdf')
plt.show()

## Figure 6

In [None]:
# Figure 6: one row showing a single test image, its score, and the input
# gradients of every third model (by sorted regularization constant).
selected_consts = gnorms[0::3]  # hoisted: constant across the comprehension
plot_models = {k: v for k, v in models.items() if get_reg_const(k, 'gnorm_const') in selected_consts}
# Exclude this checkpoint from the figure; the default keeps the cell runnable
# when that checkpoint is not among the loaded models.
plot_models.pop('resnet18_reg=gnorm_const=0.0014873521072935117_cifar10.pt', None)

img_idx = 0
nrows = 1
# One column each for the image and the score plus one per plotted model, so
# the grid always matches the number of selected models.
ncols = 2 + len(plot_models)
__, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(2*ncols, 2))

# image
img = (images[img_idx, :, :, :] * 255).clip(0, 255).to(torch.uint8)
axs[0].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)), vmin=0, vmax=255)

# score
img = (scores[img_idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)
axs[1].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)), vmin=0, vmax=255)

# models
for model_idx, model_name in enumerate(plot_models):
    img = (input_gradients[model_name][img_idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)
    axs[2+model_idx].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)))

for ax in axs:
    ax.axis('off')

plt.tight_layout(pad=.5)
plt.show()


In [None]:
# Export individual PNG panels for five test images (indices img_offset ..
# img_offset+4): the image, its score, and the input gradient of every model.
nrows = 5  # not read in this cell; kept in case another cell relies on it
ncols = 5
img_offset = 10


def _save_panel(pixels, filename):
    """Draw one uint8 channel-first image on a 4x4 figure without axes, save and show it."""
    plt.figure(figsize=(4,4))
    # vmin/vmax are ignored by imshow for RGB data, so passing them is harmless
    plt.imshow(pixels.cpu().numpy().squeeze().transpose((1, 2, 0)), vmin=0, vmax=255)
    plt.axis('off')
    plt.tight_layout(pad=0.)
    plt.savefig(filename)
    plt.show()


for idx in range(ncols):
    # image
    img = (images[img_offset+idx, :, :, :] * 255).clip(0, 255).to(torch.uint8)
    _save_panel(img, f'../figures/fig2-cifar10-img-{idx}.png')

    # score
    img = (scores[img_offset+idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)
    _save_panel(img, f'../figures/fig2-cifar10-img-{idx}-score.png')

    # models
    for model_idx, model_name in enumerate(models):
        img = (input_gradients[model_name][img_offset+idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)
        _save_panel(img, f'../figures/fig2-cifar10-img-{idx}-model-{model_idx}.png')

# The former trailing block (axis('off') over `axs`, then saving
# '../figures/cifar10-page-2.pdf') was removed: `axs` is never created in this
# cell, so it referred to the axes of whichever figure cell ran last and fails
# when that is the one-row figure above, and the savefig call did not save any
# figure built here.

## Figure 1

In [None]:
# LPIPS package; its metric is used below as the distance between input
# gradients and scores.
import lpips

# LPIPS metric instance built with net='alex'.
loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores

In [None]:
# LPIPS distance between the input gradient of one fixed model and the score,
# for every test image; distances[i] belongs to test image i.
distances = []
model_name = 'resnet18_reg=gnorm_const=0.07278953843983153_cifar10.pt'
# Pure evaluation: no_grad avoids building an autograd graph for each call.
with torch.no_grad():
    for img_idx in range(scores.shape[0]): # for all images
        score = scores[img_idx]
        ig = input_gradients[model_name][img_idx]
        distance = loss_fn_alex(ig, score)
        distances.append(distance.item())

In [None]:
# Indices of the 20 test images with the smallest distance, closest first.
# argsort leaves `distances` untouched, so re-running the cell gives the same
# result; the stable sort keeps ties in ascending index order.
values = np.argsort(distances, kind='stable')[:20].tolist()
print(values)

In [None]:
# For every selected test image show three separate figures: the image, its
# score, and the input gradient of `model_name`.
for idx in values:
    print(idx)
    # (tensor already scaled to the 0..255 range, extra imshow arguments)
    panels = [
        (images[idx, :, :, :] * 255, {'vmin': 0, 'vmax': 255}),  # image
        (scores[idx] * 127.5 + 128, {}),                         # score
        (input_gradients[model_name][idx] * 127.5 + 128, {}),    # model gradient
    ]
    for scaled, imshow_kwargs in panels:
        pixels = scaled.clip(0, 255).to(torch.uint8)
        plt.imshow(pixels.cpu().numpy().squeeze().transpose((1, 2, 0)), **imshow_kwargs)
        plt.axis('off')
        plt.show()

In [None]:
# Figure 1: a 3x4 grid with one column per hand-picked test image; the rows
# show the image, its score and the input gradient of `model_name`.
nrows = 3
ncols = 4
__, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(8, 6))

selected_images = [7908, 2649, 6748, 821]
for icol, idx in enumerate(selected_images):
    # (tensor already scaled to the 0..255 range, extra imshow arguments), top to bottom
    rows = [
        (images[idx, :, :, :] * 255, {'vmin': 0, 'vmax': 255}),              # image
        (scores[idx, :, :, :] * 127.5 + 128, {'vmin': 0, 'vmax': 255}),      # score
        (input_gradients[model_name][idx, :, :, :] * 127.5 + 128, {}),       # model
    ]
    for irow, (scaled, imshow_kwargs) in enumerate(rows):
        pixels = scaled.clip(0, 255).to(torch.uint8)
        axs[irow, icol].imshow(pixels.cpu().numpy().squeeze().transpose((1, 2, 0)), **imshow_kwargs)

for ax in axs.flat:
    ax.axis('off')

plt.tight_layout(pad=0.75)
plt.savefig('../figures/cifar10-gradients-big.png')
plt.show()