In [1]:
import torch
from torch import nn
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.nn.utils.prune as prune

import torchvision
from torchvision import datasets
from torchvision import transforms
from torchmetrics import Accuracy

import torch.optim as optim
from cleverhans.torch.attacks.projected_gradient_descent import (projected_gradient_descent)

import quantus
import captum
from captum.attr import Saliency, IntegratedGradients, NoiseTunnel

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import random
import copy
import gc
import math

import warnings
warnings.filterwarnings('ignore')

from pathlib import Path

import matplotlib.pyplot as plt
%matplotlib inline

from resnet_18 import *

In [2]:
# Sanity check: report whether PyTorch can see a CUDA device.
print(bool(torch.cuda.is_available()))

True


In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [14]:
# Execute the helper notebooks and load their top-level names into this kernel.
# Later cells call test() and filter_and_compute_road(), which are not defined in
# this notebook — presumably they come from these files (or `resnet_18`); verify.
%run utils.ipynb
%run metrics.ipynb

In [8]:
# Data configuration: mini-batch size and the Imagenette split directories
# (paths are relative to the notebook's working directory).
batch_size = 64
train_path, val_path = (
    '../datasets/imagenette2/train',
    '../datasets/imagenette2/val',
)

In [9]:
# ImageNet channel statistics, shared by the training and evaluation pipelines.
imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])

# Training pipeline: random-resized crops and horizontal flips for augmentation.
# NOTE(review): no visible cell in this notebook uses train_dataloader.
train_dataloader = DataLoader(
    datasets.ImageFolder(train_path, transform=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        imagenet_normalize,
    ])),
    batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True,
)

# Evaluation pipeline: deterministic resize to 224x224, no augmentation.
# shuffle=False so every checkpoint is evaluated on the same, reproducible sample
# order; shuffling an evaluation set only adds run-to-run variance to any metric
# that depends on which samples (or which order of samples) it sees.
test_dataloader = DataLoader(
    datasets.ImageFolder(val_path, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize([224, 224]),
        imagenet_normalize,
    ])),
    batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True,
)


In [10]:
# Human-readable Imagenette labels, indexed by class id.
# NOTE(review): assumes this order matches the indices ImageFolder assigns
# (sorted sub-directory names) — verify against test_dataloader.dataset.classes.
classes = ['tench', 'springer', 'cassette_player', 'chain_saw', 'church', 'French_horn',
           'garbage_truck', 'gas_pump', 'golf_ball', 'parachute']


In [11]:
# Mean-reduced cross-entropy, passed to test() when measuring accuracy.
# Without a class-weight tensor this loss owns no parameters or buffers,
# so there is nothing to place on a device.
criterion = nn.CrossEntropyLoss(reduction="mean")

In [12]:
def explainer_wrapper(**kwargs):
    """
    Dispatch to the explanation function named by ``kwargs["method"]``.

    All keyword arguments (including ``method``) are forwarded unchanged to the
    selected explainer.

    Raises:
        KeyError: if no ``method`` keyword is given.
        ValueError: if ``method`` is not a supported explanation method.
    """
    if kwargs["method"] == "SmoothGrad":
        return smoothgrad_explainer(**kwargs)
    # Raise rather than return: returning the exception handed callers an
    # exception object in place of an explanation array.
    raise ValueError("Explanation function doesnt exist")



def smoothgrad_explainer(model, inputs, targets, abs=True, normalise=True, stdevs=0.15, nt_samples=10, *args, **kwargs):
    """
    Generate explanations for a model's predictions using SmoothGrad
    (Captum ``NoiseTunnel`` with ``nt_type="smoothgrad"`` over ``Saliency``).

    Args:
        model: The model to explain. It is moved to ``kwargs["device"]`` and put in eval mode.
        inputs: Input samples. Non-tensor inputs are converted and reshaped to
            ``(-1, nr_channels, img_size, img_size)``.
        targets: Target class indices corresponding to the inputs.
        abs: Whether Saliency takes the absolute value of the gradients.
        normalise: Whether to normalise the result with
            ``quantus.normalise_func.normalise_by_negative``.
        stdevs: Standard deviation of the noise added to the inputs.
        nt_samples: Number of noisy samples generated per input.
        **kwargs: Optional ``device``, ``nr_channels`` (default 3) and ``img_size``
            (default 224). Other keys (e.g. ``method`` from ``explainer_wrapper``)
            are ignored.

    Returns:
        numpy.ndarray: Explanation maps of shape ``(nr_samples, img_size, img_size)``,
        obtained by summing the attributions over the channel axis.
    """
    device = kwargs.get("device", None)
    nr_channels = kwargs.get("nr_channels", 3)
    img_size = kwargs.get("img_size", 224)

    model.to(device)
    model.eval()

    if not isinstance(inputs, torch.Tensor):
        inputs = torch.Tensor(inputs).reshape(-1, nr_channels, img_size, img_size).to(device)
    elif device is not None:
        # Tensor inputs must live on the same device as the model.
        inputs = inputs.to(device)

    if not isinstance(targets, torch.Tensor):
        targets = torch.as_tensor(targets).long().to(device)
    elif device is not None:
        targets = targets.to(device)

    assert inputs.dim() == 4, "Inputs should be shaped (nr_samples, nr_channels, img_size, img_size"

    saliency = Saliency(model)
    attributions = NoiseTunnel(saliency).attribute(
        inputs=inputs,
        target=targets,
        nt_type="smoothgrad",
        stdevs=stdevs,
        nt_samples=nt_samples,  # honour the argument (was hard-coded to 10)
        abs=abs,                # forwarded to Saliency.attribute (was ignored)
    )

    explanation = (
        attributions.sum(dim=1)
        .reshape(-1, img_size, img_size)
        .detach()
        .cpu()
        .numpy()
    )
    del attributions

    # Release GPU memory held by the noisy batches before returning.
    gc.collect()
    torch.cuda.empty_cache()

    if normalise:
        explanation = quantus.normalise_func.normalise_by_negative(explanation)

    # Guarantee a NumPy result even if the normaliser returned a tensor.
    if isinstance(explanation, torch.Tensor):
        explanation = explanation.detach().cpu().numpy()

    return explanation

In [13]:
# Start the experiments from a clean slate: collect unreachable Python objects
# and release cached-but-unused CUDA memory.
gc.collect()
torch.cuda.empty_cache()

# Explanation method names (keys understood by explainer_wrapper).
xai_method = ["SmoothGrad"]

## Vanilla Gradient

In [15]:
MODEL_PATH = "saves/resnet/imagenette/0_model_lt.pth.tar"
# The checkpoint holds the whole pickled model object, hence weights_only=False.
# Unpickling can execute arbitrary code, so only load trusted checkpoints.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model.to(device)
model.eval()
accuracy = test(model, test_dataloader, criterion)
print(f"Test Accuracy: {accuracy:.2f}%")

result_normal = filter_and_compute_road(model, test_dataloader, "Saliency", device, resnet = True)
print("The road score is: ", result_normal)
# Release this checkpoint before the next cell loads another one, like the other model cells.
del model

Test Accuracy: 84.15%
The road score is:  {1: np.float64(0.9854187562209861), 11: np.float64(0.9302623147929386), 21: np.float64(0.8855254694918366), 31: np.float64(0.8275752431472999), 41: np.float64(0.7686531744591676), 51: np.float64(0.7003263311118203), 61: np.float64(0.6167151750011796), 71: np.float64(0.5189097090335428), 81: np.float64(0.41946263606951006), 91: np.float64(0.30242259808422944)}


In [17]:
MODEL_PATH = "saves/resnet/imagenette/1_5_model_lt.pth.tar"
# Whole pickled model object (used directly as a module below), hence weights_only=False.
# Unpickling can run arbitrary code: only load checkpoints from a trusted source.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model = model.to(device).eval()  # nn.Module.to / .eval both return the module itself

accuracy = test(model, test_dataloader, criterion)
print("Test Accuracy: {:.2f}%".format(accuracy))

# ROAD scores computed from Saliency (vanilla-gradient) attributions.
result_normal = filter_and_compute_road(model, test_dataloader, "Saliency", device, resnet=True)
print("The road score is: ", result_normal)
del model

Test Accuracy: 84.10%
The road score is:  {1: np.float64(0.9837310536786187), 11: np.float64(0.929972869715964), 21: np.float64(0.866104104853331), 31: np.float64(0.8072800443816678), 41: np.float64(0.7171702227608943), 51: np.float64(0.6231640628883724), 61: np.float64(0.5228790319156869), 71: np.float64(0.4260239943928879), 81: np.float64(0.3296042879105286), 91: np.float64(0.2543334457371799)}


In [18]:
MODEL_PATH = "saves/resnet/imagenette/1_10_model_lt.pth.tar"
# Whole pickled model object (used directly as a module below), hence weights_only=False.
# Unpickling can run arbitrary code: only load checkpoints from a trusted source.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model = model.to(device).eval()  # nn.Module.to / .eval both return the module itself

accuracy = test(model, test_dataloader, criterion)
print("Test Accuracy: {:.2f}%".format(accuracy))

# ROAD scores computed from Saliency (vanilla-gradient) attributions.
result_normal = filter_and_compute_road(model, test_dataloader, "Saliency", device, resnet=True)
print("The road score is: ", result_normal)
del model

Test Accuracy: 83.13%
The road score is:  {1: np.float64(0.9799208963778208), 11: np.float64(0.9136983714162727), 21: np.float64(0.8637160553598301), 31: np.float64(0.8006494027445334), 41: np.float64(0.7237594767473243), 51: np.float64(0.6276571377706955), 61: np.float64(0.5362060711547837), 71: np.float64(0.43228985890760535), 81: np.float64(0.3291858131864658), 91: np.float64(0.25164405727791406)}


In [19]:
MODEL_PATH = "saves/resnet/imagenette/1_15_model_lt.pth.tar"
# Whole pickled model object (used directly as a module below), hence weights_only=False.
# Unpickling can run arbitrary code: only load checkpoints from a trusted source.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model = model.to(device).eval()  # nn.Module.to / .eval both return the module itself

accuracy = test(model, test_dataloader, criterion)
print("Test Accuracy: {:.2f}%".format(accuracy))

# ROAD scores computed from Saliency (vanilla-gradient) attributions.
result_normal = filter_and_compute_road(model, test_dataloader, "Saliency", device, resnet=True)
print("The road score is: ", result_normal)
del model

Test Accuracy: 84.18%
The road score is:  {1: np.float64(0.9800060568550581), 11: np.float64(0.9197179038239468), 21: np.float64(0.8670742404008076), 31: np.float64(0.8091293294324844), 41: np.float64(0.7264244267648965), 51: np.float64(0.6474040927489494), 61: np.float64(0.5544084061680202), 71: np.float64(0.4606913360362517), 81: np.float64(0.3678827795588879), 91: np.float64(0.2800112830488548)}


In [16]:
MODEL_PATH = "saves/resnet/imagenette/1_model_lt.pth.tar"
# The checkpoint holds the whole pickled model object, hence weights_only=False.
# Unpickling can execute arbitrary code, so only load trusted checkpoints.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model.to(device)
model.eval()
accuracy = test(model, test_dataloader, criterion)
print(f"Test Accuracy: {accuracy:.2f}%")

result_normal = filter_and_compute_road(model, test_dataloader, "Saliency", device, resnet = True)
print("The road score is: ", result_normal)
# Release this checkpoint before the next cell loads another one, like the other model cells.
del model

Test Accuracy: 84.31%
The road score is:  {1: np.float64(0.9815317446919453), 11: np.float64(0.916264309570018), 21: np.float64(0.8550035656800294), 31: np.float64(0.78959464373224), 41: np.float64(0.7143713719788255), 51: np.float64(0.6170017447664645), 61: np.float64(0.535769578517576), 71: np.float64(0.440189077713297), 81: np.float64(0.34097142622121834), 91: np.float64(0.2647278476711539)}


In [21]:
MODEL_PATH = "saves/resnet/imagenette/2_model_lt.pth.tar"
# Whole pickled model object (used directly as a module below), hence weights_only=False.
# Unpickling can run arbitrary code: only load checkpoints from a trusted source.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model = model.to(device).eval()  # nn.Module.to / .eval both return the module itself

accuracy = test(model, test_dataloader, criterion)
print("Test Accuracy: {:.2f}%".format(accuracy))

# ROAD scores computed from Saliency (vanilla-gradient) attributions.
result_normal = filter_and_compute_road(model, test_dataloader, "Saliency", device, resnet=True)
print("The road score is: ", result_normal)
del model

Test Accuracy: 85.55%
The road score is:  {1: np.float64(0.9820257658315142), 11: np.float64(0.9274665111306769), 21: np.float64(0.8728924914795306), 31: np.float64(0.8038201046585481), 41: np.float64(0.7230428491972497), 51: np.float64(0.6394828207378183), 61: np.float64(0.5341273658014597), 71: np.float64(0.42564235011338775), 81: np.float64(0.33250923822599016), 91: np.float64(0.24350088953506518)}


In [22]:
MODEL_PATH = "saves/resnet/imagenette/6_model_lt.pth.tar"
# Whole pickled model object (used directly as a module below), hence weights_only=False.
# Unpickling can run arbitrary code: only load checkpoints from a trusted source.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model = model.to(device).eval()  # nn.Module.to / .eval both return the module itself

accuracy = test(model, test_dataloader, criterion)
print("Test Accuracy: {:.2f}%".format(accuracy))

# ROAD scores computed from Saliency (vanilla-gradient) attributions.
result_normal = filter_and_compute_road(model, test_dataloader, "Saliency", device, resnet=True)
print("The road score is: ", result_normal)
del model

Test Accuracy: 85.63%
The road score is:  {1: np.float64(0.9815523567285375), 11: np.float64(0.9356981658686606), 21: np.float64(0.8902092786831333), 31: np.float64(0.8415169642088424), 41: np.float64(0.7749564670168705), 51: np.float64(0.6978569255157439), 61: np.float64(0.60888594593442), 71: np.float64(0.5116983727967653), 81: np.float64(0.4039617397451793), 91: np.float64(0.3028486530926366)}


In [23]:
MODEL_PATH = "saves/resnet/imagenette/12_model_lt.pth.tar"
# Whole pickled model object (used directly as a module below), hence weights_only=False.
# Unpickling can run arbitrary code: only load checkpoints from a trusted source.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model = model.to(device).eval()  # nn.Module.to / .eval both return the module itself

accuracy = test(model, test_dataloader, criterion)
print("Test Accuracy: {:.2f}%".format(accuracy))

# ROAD scores computed from Saliency (vanilla-gradient) attributions.
result_normal = filter_and_compute_road(model, test_dataloader, "Saliency", device, resnet=True)
print("The road score is: ", result_normal)
del model

Test Accuracy: 84.43%
The road score is:  {1: np.float64(0.9835630670525478), 11: np.float64(0.9342213103810895), 21: np.float64(0.8821054857545466), 31: np.float64(0.8255040473880926), 41: np.float64(0.7546564625140637), 51: np.float64(0.6619870188715149), 61: np.float64(0.5630498173438654), 71: np.float64(0.46577110960962625), 81: np.float64(0.3588333885456528), 91: np.float64(0.25681642034993374)}


In [24]:
MODEL_PATH = "saves/resnet/imagenette/18_model_lt.pth.tar"
# Whole pickled model object (used directly as a module below), hence weights_only=False.
# Unpickling can run arbitrary code: only load checkpoints from a trusted source.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model = model.to(device).eval()  # nn.Module.to / .eval both return the module itself

accuracy = test(model, test_dataloader, criterion)
print("Test Accuracy: {:.2f}%".format(accuracy))

# ROAD scores computed from Saliency (vanilla-gradient) attributions.
result_normal = filter_and_compute_road(model, test_dataloader, "Saliency", device, resnet=True)
print("The road score is: ", result_normal)
del model

Test Accuracy: 82.70%
The road score is:  {1: np.float64(0.9847261184524104), 11: np.float64(0.9267461343208111), 21: np.float64(0.8685318018749423), 31: np.float64(0.807615101667683), 41: np.float64(0.7248794158224089), 51: np.float64(0.6496159754595252), 61: np.float64(0.5688044644775477), 71: np.float64(0.475189554305934), 81: np.float64(0.3775650117210429), 91: np.float64(0.28272291065287086)}


In [25]:
MODEL_PATH = "saves/resnet/imagenette/29_model_lt.pth.tar"
# Whole pickled model object (used directly as a module below), hence weights_only=False.
# Unpickling can run arbitrary code: only load checkpoints from a trusted source.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model = model.to(device).eval()  # nn.Module.to / .eval both return the module itself

accuracy = test(model, test_dataloader, criterion)
print("Test Accuracy: {:.2f}%".format(accuracy))

# ROAD scores computed from Saliency (vanilla-gradient) attributions.
result_normal = filter_and_compute_road(model, test_dataloader, "Saliency", device, resnet=True)
print("The road score is: ", result_normal)
del model

Test Accuracy: 83.46%
The road score is:  {1: np.float64(0.9860395021377726), 11: np.float64(0.9380568182055771), 21: np.float64(0.8859509869049661), 31: np.float64(0.8254678674765173), 41: np.float64(0.7475342621319393), 51: np.float64(0.651121103479216), 61: np.float64(0.5477292968705794), 71: np.float64(0.446885111033178), 81: np.float64(0.34203169419053664), 91: np.float64(0.24484063846912646)}


## Integrated Gradients

In [27]:
# Confirm which device the Integrated Gradients runs below will use.
print(f"{device}")

cuda


In [None]:
MODEL_PATH = "saves/resnet/imagenette/0_model_lt.pth.tar"
# Whole pickled model object, hence weights_only=False; unpickling can execute
# arbitrary code, so only load trusted checkpoints.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model.to(device)
# Explicit inference mode, matching the Saliency cells. Nothing visible here shows
# filter_and_compute_road calling model.eval() itself; a checkpoint saved in training
# mode would otherwise run BatchNorm/dropout (if present) in training behaviour.
model.eval()
result_normal = filter_and_compute_road(model, test_dataloader, "IntegratedGradients", device, resnet = True)
print("The road score is: ", result_normal)
del model

In [None]:
MODEL_PATH = "saves/resnet/imagenette/1_5_model_lt.pth.tar"
# Whole pickled model object, hence weights_only=False; unpickling can execute
# arbitrary code, so only load trusted checkpoints.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model.to(device)
# Explicit inference mode, matching the Saliency cells. Nothing visible here shows
# filter_and_compute_road calling model.eval() itself; a checkpoint saved in training
# mode would otherwise run BatchNorm/dropout (if present) in training behaviour.
model.eval()

result_prune_5 = filter_and_compute_road(model, test_dataloader, "IntegratedGradients", device, resnet = True)
print("The road score is: ", result_prune_5)
del model

In [None]:
MODEL_PATH = "saves/resnet/imagenette/1_10_model_lt.pth.tar"
# Whole pickled model object, hence weights_only=False; unpickling can execute
# arbitrary code, so only load trusted checkpoints.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model.to(device)
# Explicit inference mode, matching the Saliency cells. Nothing visible here shows
# filter_and_compute_road calling model.eval() itself; a checkpoint saved in training
# mode would otherwise run BatchNorm/dropout (if present) in training behaviour.
model.eval()

result_prune_10 = filter_and_compute_road(model, test_dataloader, "IntegratedGradients", device, resnet = True)
print("The road score is: ", result_prune_10)
del model

In [None]:
MODEL_PATH = "saves/resnet/imagenette/1_15_model_lt.pth.tar"
# Whole pickled model object, hence weights_only=False; unpickling can execute
# arbitrary code, so only load trusted checkpoints.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model.to(device)
# Explicit inference mode, matching the Saliency cells. Nothing visible here shows
# filter_and_compute_road calling model.eval() itself; a checkpoint saved in training
# mode would otherwise run BatchNorm/dropout (if present) in training behaviour.
model.eval()

result_prune_15 = filter_and_compute_road(model, test_dataloader, "IntegratedGradients", device, resnet = True)
print("The road score is: ", result_prune_15)
del model

In [None]:
MODEL_PATH = "saves/resnet/imagenette/1_model_lt.pth.tar"
# Whole pickled model object, hence weights_only=False; unpickling can execute
# arbitrary code, so only load trusted checkpoints.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model.to(device)
# Explicit inference mode, matching the Saliency cells. Nothing visible here shows
# filter_and_compute_road calling model.eval() itself; a checkpoint saved in training
# mode would otherwise run BatchNorm/dropout (if present) in training behaviour.
model.eval()
result_prune_20 = filter_and_compute_road(model, test_dataloader, "IntegratedGradients", device, resnet = True)
print("The road score is: ", result_prune_20)
# Release this checkpoint before the next cell loads another one, like the other model cells.
del model

In [None]:
MODEL_PATH = "saves/resnet/imagenette/2_model_lt.pth.tar"
# Whole pickled model object, hence weights_only=False; unpickling can execute
# arbitrary code, so only load trusted checkpoints.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model.to(device)
# Explicit inference mode, matching the Saliency cells. Nothing visible here shows
# filter_and_compute_road calling model.eval() itself; a checkpoint saved in training
# mode would otherwise run BatchNorm/dropout (if present) in training behaviour.
model.eval()

result_prune_30 = filter_and_compute_road(model, test_dataloader, "IntegratedGradients", device, resnet = True)
print("The road score is: ", result_prune_30)
del model

In [None]:
MODEL_PATH = "saves/resnet/imagenette/6_model_lt.pth.tar"
# Whole pickled model object, hence weights_only=False; unpickling can execute
# arbitrary code, so only load trusted checkpoints.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model.to(device)
# Explicit inference mode, matching the Saliency cells. Nothing visible here shows
# filter_and_compute_road calling model.eval() itself; a checkpoint saved in training
# mode would otherwise run BatchNorm/dropout (if present) in training behaviour.
model.eval()

result_prune_40 = filter_and_compute_road(model, test_dataloader, "IntegratedGradients", device, resnet = True)
print("The road score is: ", result_prune_40)
del model

In [None]:
MODEL_PATH = "saves/resnet/imagenette/12_model_lt.pth.tar"
# Whole pickled model object, hence weights_only=False; unpickling can execute
# arbitrary code, so only load trusted checkpoints.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model.to(device)
# Explicit inference mode, matching the Saliency cells. Nothing visible here shows
# filter_and_compute_road calling model.eval() itself; a checkpoint saved in training
# mode would otherwise run BatchNorm/dropout (if present) in training behaviour.
model.eval()

result_prune_50 = filter_and_compute_road(model, test_dataloader, "IntegratedGradients", device, resnet = True)
print("The road score is: ", result_prune_50)
del model

In [None]:
MODEL_PATH = "saves/resnet/imagenette/18_model_lt.pth.tar"
# Whole pickled model object, hence weights_only=False; unpickling can execute
# arbitrary code, so only load trusted checkpoints.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model.to(device)
# Explicit inference mode, matching the Saliency cells. Nothing visible here shows
# filter_and_compute_road calling model.eval() itself; a checkpoint saved in training
# mode would otherwise run BatchNorm/dropout (if present) in training behaviour.
model.eval()

result_prune_60 = filter_and_compute_road(model, test_dataloader, "IntegratedGradients", device, resnet = True)
print("The road score is: ", result_prune_60)
del model

In [None]:
MODEL_PATH = "saves/resnet/imagenette/29_model_lt.pth.tar"
# Whole pickled model object, hence weights_only=False; unpickling can execute
# arbitrary code, so only load trusted checkpoints.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model.to(device)
# Explicit inference mode, matching the Saliency cells. Nothing visible here shows
# filter_and_compute_road calling model.eval() itself; a checkpoint saved in training
# mode would otherwise run BatchNorm/dropout (if present) in training behaviour.
model.eval()

result_prune_70 = filter_and_compute_road(model, test_dataloader, "IntegratedGradients", device, resnet = True)
print("The road score is: ", result_prune_70)
del model