# Load data

In [1]:
import torchvision.datasets as dsets
import torchvision.transforms as transforms

In [3]:
# Per-channel mean / std handed to Normalize in both pipelines
# (plot_img uses the same numbers to undo the normalization)
norm_mean = (0.4914, 0.4822, 0.4465)
norm_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: padded random crop + random horizontal flip, then tensor conversion and normalization
train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
transform_train = transforms.Compose(train_steps)

# Evaluation pipeline: no augmentation, only tensor conversion and normalization
test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
transform_test = transforms.Compose(test_steps)

# Train split (downloaded into '../' when missing)
train_set = dsets.CIFAR10('../', train=True, download=True, transform=transform_train)

# Test split, used as the validation set in this notebook
val_set = dsets.CIFAR10('../', train=False, download=True, transform=transform_test)

Files already downloaded and verified
Files already downloaded and verified


# Import ResNet Models
----

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os.path
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from copy import deepcopy
from resnet import ResNet18

In [5]:
# Select device
# Every model loaded below is moved to this device with `.to(device)`.
# Alternative kept for GPU runs:
# device = torch.device('cuda')
device = torch.device('cpu')

## Import ResNet model

In [6]:
# Load the pickled ResNet classifier if its checkpoint file is present
if os.path.isfile('cifar resnet.pt'):
    print('Loading saved model')
    # map_location remaps storages saved from another device (e.g. CUDA) onto
    # `device`, so the checkpoint also opens on a machine without that device.
    # NOTE: torch.load unpickles arbitrary objects -- only open trusted files.
    # NOTE(review): recent PyTorch releases may need weights_only=False for
    # whole-module checkpoints -- confirm against the installed version.
    model_resnet = torch.load('cifar resnet.pt', map_location=device).to(device)
else:
    # `model_resnet` stays undefined in this case; later cells will fail.
    print('Model not available')

Loading saved model


## Import surrogate

In [7]:
from fastshap import ImageSurrogate
from fastshap.utils import MaskLayer2d, KLDivLoss, DatasetInputOnly

In [8]:
# Load the pickled surrogate network and wrap it for FastSHAP
if os.path.isfile('cifar surrogate.pt'):
    print('Loading saved surrogate model')
    # map_location remaps storages saved from another device (e.g. CUDA) onto
    # `device`, so the checkpoint also opens on a machine without that device.
    # NOTE: torch.load unpickles arbitrary objects -- only open trusted files.
    surr = torch.load('cifar surrogate.pt', map_location=device).to(device)
    # 32x32 inputs with superpixel_size=2
    surrogate = ImageSurrogate(surr, width=32, height=32, superpixel_size=2)
else:
    # `surr` / `surrogate` stay undefined in this case; later cells will fail.
    print('Model not available')

Loading saved surrogate model


## Import Explainer

In [9]:
from unet import UNet
from fastshap import FastSHAP

In [10]:
# Load the pickled explainer network and build the FastSHAP object around it
if os.path.isfile('cifar explainer.pt'):
    print('Loading saved explainer model')
    # map_location remaps storages saved from another device (e.g. CUDA) onto
    # `device`, so the checkpoint also opens on a machine without that device.
    # NOTE: torch.load unpickles arbitrary objects -- only open trusted files.
    explainer = torch.load('cifar explainer.pt', map_location=device).to(device)
    # Uses the `surrogate` created in the previous cell
    fastshap = FastSHAP(explainer, surrogate, link=nn.LogSoftmax(dim=1))
else:
    # `explainer` / `fastshap` stay undefined in this case; later cells will fail.
    print('Model not available')

Loading saved explainer model


# Import ViT Models
-----

## Import ViT Model

In [29]:
from vit import ViT
# Checkpoint file of the ViT classifier (read by torch.load and written by torch.save below)
model_path = 'cifar_c10_vit.pt'

In [30]:
# Load the pickled ViT classifier if its checkpoint file is present
if os.path.isfile(model_path):
    print("Loading saved model")
    # map_location remaps storages saved from another device (e.g. CUDA) onto
    # `device`, so the checkpoint also opens on a machine without that device.
    # NOTE: torch.load unpickles arbitrary objects -- only open trusted files.
    model_vit = torch.load(model_path, map_location=device).to(device)
else:
    # `model_vit` stays undefined in this case; later cells will fail.
    print('Model not available')

Loading saved model


In [31]:
# Write the ViT back to `model_path` after moving it to the CPU.
# `model` is never defined in this notebook; the ViT loaded above is `model_vit`.
# Note: Module.cpu() moves the module in place, so `model_vit` is on the CPU afterwards.
torch.save(model_vit.cpu(), model_path)

## Import surrogates

In [15]:
from fastshap import ImageSurrogate
from fastshap.utils import MaskLayer2d, DatasetInputOnly, KLDivLoss

In [16]:
# Checkpoint file of the surrogate trained for the ViT (loaded in the next cell)
surr_model_path = 'cifar_c10_vit_surrogate_lb20.pt'

In [17]:
# Load the pickled ViT surrogate and wrap it for FastSHAP.
# This rebinds `surr` and `surrogate`, replacing the objects created in the ResNet section.
if os.path.isfile(surr_model_path):
    print('Loading saved surrogate model')
    # map_location remaps storages saved from another device (e.g. CUDA) onto
    # `device`, so the checkpoint also opens on a machine without that device.
    # NOTE: torch.load unpickles arbitrary objects -- only open trusted files.
    surr = torch.load(surr_model_path, map_location=device).to(device)
    surrogate = ImageSurrogate(surr, width=32, height=32, superpixel_size=2)
else:
    print('Model not available')

Loading saved surrogate model


## Import explainer

In [18]:
from unet import UNet
from fastshap import FastSHAP

In [19]:
# Checkpoint file of the explainer used with the ViT surrogate (loaded in the next cell)
explainer_path = 'cifar_explainer_unet_lb20.pt'

In [20]:
# Load the pickled explainer and build the FastSHAP object around it.
# This rebinds `explainer` and `fastshap`, replacing the objects created in the ResNet section.
if os.path.isfile(explainer_path):
    print('Loading saved explainer model')
    # map_location remaps storages saved from another device (e.g. CUDA) onto
    # `device`, so the checkpoint also opens on a machine without that device.
    # NOTE: torch.load unpickles arbitrary objects -- only open trusted files.
    explainer = torch.load(explainer_path, map_location=device).to(device)
    # Uses the `surrogate` created in the previous surrogate cell
    fastshap = FastSHAP(explainer, surrogate, link=nn.LogSoftmax(dim=1))
else:
    print('Model not available')

Loading saved explainer model


# Initialize dataset

In [112]:
# Evaluate on the split loaded above as `val_set`
dset = val_set
# Number of leading samples to process; 10000 is the larger setting kept for reference
# (presumably the full split -- confirm against len(dset))
# dset_size = 10000
dset_size = 10
# Integer labels of every sample in `dset`
targets = np.array(dset.targets)
# Assumes labels are 0-based and contiguous, so max + 1 is the class count
num_classes = targets.max() + 1
# Display names indexed by integer label (used by check_results)
classes = ['Airplane', 'Car', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']

# Plot individual images and values

In [113]:
import matplotlib.pyplot as plt

In [114]:
def plot_img(x):
    """Undo the dataset normalization on a channels-first image tensor and show it.

    `x` is a 3-D tensor whose first axis holds the color channels; it is
    converted with `x.numpy()`, so it must support that call.
    """
    # Same per-channel statistics that the transforms pass to Normalize,
    # shaped (3, 1, 1) so they broadcast over height and width
    channel_mean = np.array([0.4914, 0.4822, 0.4465]).reshape(3, 1, 1)
    channel_std = np.array([0.2023, 0.1994, 0.2010]).reshape(3, 1, 1)

    # Invert (x - mean) / std
    pixels = x.numpy() * channel_std + channel_mean
    # Channels-first -> channels-last, as imshow expects
    pixels = np.transpose(pixels, (1, 2, 0)).astype(float)
    # Keep values inside the displayable [0, 1] range
    pixels = np.clip(pixels, 0, 1)

    plt.imshow(pixels, vmin=0, vmax=1)
    plt.show()

In [115]:
def plot_values(value, m=None):
    """Draw a 2-D array of values with the diverging 'seismic' colormap.

    Parameters
    ----------
    value : array-like
        Values handed to `plt.imshow`.
    m : float, optional
        When given, the color scale is fixed to [-m, m] so that zero sits at
        the center of the colormap. When omitted, matplotlib picks the limits
        from the data.
    """
    # Only pass explicit limits when a symmetric bound was requested
    limits = {} if m is None else {'vmin': -m, 'vmax': m}
    plt.imshow(value, cmap='seismic', **limits)

# Collect performance results for ResNet (model A) and ViT (model B)

## Create dataframe to store results

In [116]:
import pandas as pd
import numpy as np

In [117]:
# Per-sample columns: the true label, then for each model (suffix A / B) the
# predicted class, the probability given to the true class, the probability of
# the predicted class, and a 0/1 correctness flag stored under the bare suffix.
columns = [
    'y_label',
    'y_A', 'y_prob_correct_A', 'y_prob_A', 'A',
    'y_B', 'y_prob_correct_B', 'y_prob_B', 'B',
]
# One row per evaluated sample, indexed 0 .. dset_size - 1
dset_idxs = np.arange(dset_size)
# Empty frame; collect_data fills it in later
results = pd.DataFrame(index=dset_idxs, columns=columns)

In [118]:
# Preview the frame before it is filled (every cell is still empty at this point)
results.head()

Unnamed: 0,y_label,y_A,y_prob_correct_A,y_prob_A,A,y_B,y_prob_correct_B,y_prob_B,B
0,,,,,,,,,
1,,,,,,,,,
2,,,,,,,,,
3,,,,,,,,,
4,,,,,,,,,


## Collect data 

In [128]:
# Model A is the ResNet classifier and model B the ViT; the 'A' / 'B' suffixes
# of the `results` columns refer to these two.
model_A = model_resnet
model_B = model_vit

In [134]:
def get_pred(dset, dset_size, idx, model, model_name, x, y):
    """Run `model` on one sample and store its prediction in the global `results` frame.

    Parameters
    ----------
    dset, dset_size
        Unused; kept so existing callers keep working.
    idx : int
        Positional row of `results` to fill.
    model : callable
        Classifier whose output for `x` has the classes along dim 1.
    model_name : str
        Column suffix, e.g. 'A' or 'B'.
    x : torch.Tensor
        Input batch containing exactly one sample (only row 0 of the output is read).
    y : int
        True class label of the sample.

    Writes `y_<name>` (predicted class), `y_prob_correct_<name>` (probability of
    the true class), `y_prob_<name>` (probability of the predicted class) and
    `<name>` (1 if the prediction is correct, else 0).
    """
    # Inference only: put dropout / batch-norm layers in evaluation behavior and
    # skip autograd bookkeeping. The previous mode is restored afterwards.
    was_training = getattr(model, 'training', False)
    if was_training:
        model.eval()
    try:
        with torch.no_grad():
            probs = model(x).softmax(dim=1)[0]
    finally:
        if was_training:
            model.train()

    pred_class = int(probs.argmax())

    values = {
        'y_' + model_name: pred_class,
        'y_prob_correct_' + model_name: probs[y].item(),
        'y_prob_' + model_name: probs[pred_class].item(),
        model_name: 1 if pred_class == y else 0,
    }
    # Single-step positional assignment: `results.iloc[idx][col] = v` is chained
    # indexing and can end up writing to a temporary copy instead of `results`.
    for column, value in values.items():
        results.iloc[idx, results.columns.get_loc(column)] = value

In [135]:
def collect_data(dset, dset_size, results, model_A, model_B):
    """Record the label and both models' predictions for the first `dset_size` samples.

    Parameters
    ----------
    dset : indexable dataset whose items are `(image_tensor, label)` pairs.
    dset_size : int
        Number of leading samples to process.
    results : pandas.DataFrame
        Frame with a 'y_label' column; row `idx` receives the label of sample `idx`.
        Note that `get_pred` writes to the module-level `results`, so this
        argument should be that same frame.
    model_A, model_B : callables
        Classifiers stored under the 'A' and 'B' column suffixes.
    """
    label_col = results.columns.get_loc('y_label')
    for idx in range(dset_size):
        x, y = dset[idx]
        # Models expect a batch dimension: (C, H, W) -> (1, C, H, W)
        x = x.unsqueeze(0)

        # Single-step positional assignment; the chained form
        # `results.iloc[idx]['y_label'] = y` may write to a temporary copy.
        results.iloc[idx, label_col] = y
        get_pred(dset, dset_size, idx, model_A, 'A', x, y)
        get_pred(dset, dset_size, idx, model_B, 'B', x, y)

In [136]:
def check_results(idx, results, model_name='A'):
    """Print the true and predicted class of sample `idx` and display its image.

    Parameters
    ----------
    idx : int
        Position of the sample in the global `dset` and of its row in `results`.
    results : pandas.DataFrame
        Frame filled by `collect_data`.
    model_name : str, optional
        Column suffix of the model to inspect. Defaults to 'A' (the ResNet,
        given `model_A = model_resnet`).
    """
    x, y = dset[idx]
    print(f'Label = {classes[y]}, {y}')
    # Predictions live in 'y_<suffix>'; `results` has no 'y_resnet_class' column.
    # int() so the stored value can index the `classes` list.
    y_predicted = int(results.iloc[idx]['y_' + model_name])
    print(f'Predicted = {classes[y_predicted]} , {y_predicted}' )
    plot_img(x)

In [137]:
# Fill `results` with the labels and both models' predictions for the first `dset_size` samples
collect_data(dset, dset_size, results, model_A, model_B)

In [138]:
# Show the filled table (one row per evaluated sample)
results

Unnamed: 0,y_label,y_A,y_prob_correct_A,y_prob_A,A,y_B,y_prob_correct_B,y_prob_B,B
0,3,3,0.366576,0.366576,1,3,1.0,1.0,1
1,8,8,0.307475,0.307475,1,8,0.999998,0.999998,1
2,8,1,0.180506,0.241541,0,8,0.999936,0.999936,1
3,0,1,0.019308,0.37162,0,0,0.846047,0.846047,1
4,6,1,0.06216,0.321299,0,6,1.0,1.0,1
5,6,1,0.041133,0.551747,0,6,0.999997,0.999997,1
6,1,1,0.407426,0.407426,1,1,0.999943,0.999943,1
7,6,5,0.060609,0.264785,0,6,1.0,1.0,1
8,3,1,0.230295,0.233738,0,3,1.0,1.0,1
9,1,1,0.523663,0.523663,1,1,0.985912,0.985912,1
