## Tutorial - Evaluate explanations during model training

This tutorial demonstrates how one can use the library to evaluate how explanations change while a model is training. We use a pre-trained AlexNet model and the Tiny ImageNet dataset to showcase the library's functionality.



In [1]:
# Mount Google Drive under /content/drive. Later cells read the image folder
# and import the toolbox sources from paths on this mount.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
!pip install captum
!pip install opencv-python
!pip install torch==1.8.0 torchvision==0.9.0	
#pip install torch==1.9.0+cu102 torchvision==0.9.1

import torch
import torchvision
from torchvision import transforms
import numpy as np
import pandas as pd
from tqdm import tqdm
from captum.attr import *
import matplotlib.pyplot as plt
from pathlib import Path
import warnings

# Retrieve source code.
from drive.MyDrive.Projects.xai_quantification_toolbox import * #import xaiquantificationtoolbox

# Notebook settings.
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Silence warnings of category UserWarning for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)
# Load the autoreload extension and set it to mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import gc



### Load 5 classes of Imagenet dataset.

In [9]:
# TODO. Update to tiny imagenet dataset!

# Load datasets and make loaders.
# Images are resized to 256, centre-cropped to 224x224, converted to tensors and
# normalised per channel with the mean/std values given below.
test_set = torchvision.datasets.ImageFolder(root='/content/drive/My Drive/imagenet_images',
                                            transform=transforms.Compose([transforms.Resize(256),
                                                                          transforms.CenterCrop((224, 224)),
                                                                          transforms.ToTensor(),
                                                                          transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))
test_loader = torch.utils.data.DataLoader(test_set, shuffle=True, batch_size=56)

# Load a batch of inputs and outputs to use for evaluation.
# Use the built-in next(): the iterator's `.next()` method was only a Python 2
# compatibility alias and is not available in newer PyTorch releases.
x_batch, y_batch = next(iter(test_loader))
x_batch, y_batch = x_batch.to(device), y_batch.to(device)

### During model training/fine-tuning, calculate max-sensitivity scores of Integrated Gradients explanations.

In [15]:
def evaluate_model(model, images, labels, device):
    """Evaluate a torch model on one batch and return predictions and targets.

    The model is switched to eval mode and is left in eval mode on return, so a
    caller that continues training has to call ``model.train()`` again.

    Args:
        model: torch module that maps ``images`` to per-class logits.
        images: batch of inputs, passed to ``model`` unchanged.
        labels: targets belonging to ``images``.
        device: torch device on which the returned tensors are placed.

    Returns:
        Tuple ``(predictions, targets)`` where ``predictions`` is the softmax of
        the logits over dim 1 and ``targets`` are the given labels.
    """
    model.eval()
    # Pure inference: disable autograd so no graph is built or kept alive for
    # the forward pass (the returned predictions do not require grad).
    with torch.no_grad():
        logits = model(images).to(device)
        predictions = torch.nn.functional.softmax(logits, dim=1)
    return predictions, labels.to(device)

In [None]:
# Build a MobileNetV3-small model (constructor called without arguments, so no
# pretrained weights are requested here).
model = torchvision.models.mobilenet_v3_small()

# Set necessary configs/ parameters.
model.to(device)  
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
path_model_weights = "drive/MyDrive/Projects/xai_quantification_toolbox/nbs/resources/imagenet5"
epochs = 5
nr_samples = 56   # number of samples from x_batch/y_batch passed to MaxSensitivity
max_batches = 12  # cap on the number of training batches consumed per epoch
sensitivities = {}  # epoch index -> MaxSensitivity result for that epoch

for epoch in range(epochs):
    # evaluate_model() below switches to eval mode, so re-enable train mode each epoch.
    model.train() 

    # NOTE(review): training iterates over `test_loader`, the same loader the
    # evaluation batch (x_batch, y_batch) was drawn from -- confirm this is intended.
    for b, (images, labels) in enumerate(test_loader):

        if b >= max_batches:
            break

        images, labels = images.to(device), labels.to(device)
        logits = model(images)

        # Standard step: compute loss, clear old gradients, backpropagate, update.
        loss = criterion(logits, labels)
        model.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate model!
    # Accuracy is the fraction of samples in the fixed batch (x_batch, y_batch)
    # whose arg-max prediction equals the label. This rebinds `labels`.
    predictions, labels = evaluate_model(model, x_batch.to(device), y_batch.to(device), device)
    test_acc = np.mean(np.argmax(predictions.cpu().detach().numpy(), axis=1) == labels.detach().cpu().numpy())

    # Explain model (on a few test samples) and measure sensitivies.
    # The explanation method requested here is "Saliency"; inputs are handed over
    # as NumPy arrays and no precomputed attributions are supplied (a_batch=None).
    sensitivities[epoch] = MaxSensitivity()(model=model, 
                                            x_batch=x_batch[:nr_samples].cpu().numpy(), 
                                            y_batch=y_batch[:nr_samples].cpu().numpy(), 
                                            a_batch=None, 
                                            **{"explanation_func": "Saliency", 
                                                "device": device,
                                                "img_size": 224})

    # NOTE(review): the label printed below reads "train accuracy", but the value
    # is `test_acc`, computed on the fixed evaluation batch -- confirm the wording.
    print(f"Epoch {epoch+1}/{epochs} - train accuracy: {(100 * test_acc):.2f}% - max sensitivity {np.mean(sensitivities[epoch]):.2f}")

# Save model.
# Only the state dict (parameters/buffers) is written, not the full module.
torch.save(model.state_dict(), path_model_weights)
# The model was already moved to `device` above; as the last expression of the
# cell this line also makes the notebook display the model.
model.to(device)

Epoch 1/5 - train accuracy: 30.36% - max sensitivity 0.03
Epoch 2/5 - train accuracy: 30.36% - max sensitivity 0.05


In [36]:
# Mean max-sensitivity per epoch. Iterating a dict directly yields only its keys,
# so unpacking `k, v` from it fails; go through .values() instead.
epoch_means = [np.mean(scores) for scores in sensitivities.values()]

# Summarise in a dataframe.
# Assumes each entry of `sensitivities` is a per-sample sequence of scores, giving
# one column per epoch and one row per evaluated sample -- TODO confirm.
df = pd.DataFrame(sensitivities)
# Average across the epoch columns (axis=1) so the result is indexed like the
# rows. A column-wise mean (axis=0) is indexed by epoch and would be misaligned
# with the row index when assigned as a new column, leaving mostly NaN.
df["avg"] = df.mean(axis=1)
df

[0.03605495,
 0.038111627,
 0.035774343,
 0.036996625,
 0.03895698,
 0.035916544,
 0.025491755,
 0.0071537965,
 0.022229657,
 0.0200718,
 0.04412011,
 0.032924026,
 0.043629825,
 0.021562316,
 0.01864284,
 0.027031252,
 0.020951644,
 0.0147696175,
 0.027309624,
 0.034381274,
 0.037629068,
 0.017584994,
 0.025542881,
 0.018253563]