<a href="https://colab.research.google.com/github/sfbllgrn/DD2412_Class_Contrastive_Explanations/blob/main/experiment1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/sfbllgrn/DD2412_Class_Contrastive_Explanations.git

import os
import shutil

# Define the source and destination paths
source_folder = '/content/DD2412_Class_Contrastive_Explanations'
destination_folder = '/content'

# List the files and subdirectories in the source folder
contents = os.listdir(source_folder)

# Move each item from the source folder to the destination folder
for item in contents:
    source_path = os.path.join(source_folder, item)
    destination_path = os.path.join(destination_folder, item)
    shutil.move(source_path, destination_path)

# Remove the now-empty source folder
os.rmdir(source_folder)


In [4]:
# Mount Google drive that contains all data
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Imports
from torchvision.models import densenet161, DenseNet161_Weights
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
from torchvision.models import alexnet, AlexNet_Weights
from torchvision.models import googlenet, GoogLeNet_Weights
from torchvision.models import mnasnet0_5, MNASNet0_5_Weights # Här gissar jag att dom använder 0.5, står inte någonstans
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights
from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights

import numpy as np
import torch
from torch.autograd.functional import jacobian as J
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import sys
from torchvision import transforms
from torch import nn
from torch.nn.functional import one_hot


In [5]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cpu device


In [6]:
# Load data

# when data is stored locally
#data_folder = "/Users/sofia/Documents/Skola/KTH/Master/Deep Learning, Advanced Course DD2412/Class Contrastive Explanations/DD2412_Class_Contrastive_Explanations/Data_small"

# for data stored on google drive
data_folder = "/content/drive/MyDrive/Colab Notebooks/Deep learning advanced/ImageNet_Data/val"
data_obj = ImageFolder(root=data_folder, transform=DenseNet161_Weights.DEFAULT.transforms())

BATCH_SIZE = 1
val_dataloader = DataLoader(data_obj, batch_size=BATCH_SIZE, shuffle=False)


In [17]:
# Init Pretrained models

# debug:
alex = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)
pretrained_models = {"alexnet":alex}

#densenet = densenet161(weights=DenseNet161_Weights.IMAGENET1K_V1)
#mobilenet_small = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)
#alex = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)
#google = googlenet(weights=GoogLeNet_Weights.IMAGENET1K_V1)
#mnasnet = mnasnet0_5(weights=MNASNet0_5_Weights.IMAGENET1K_V1)
#resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
#mobilenet_large = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.IMAGENET1K_V1)  # Denna har även IMAGENET1K_v2
#efficientnet = efficientnet_b1(weights=EfficientNet_B1_Weights.IMAGENET1K_V1) # Denna har även IMAGENET1K_v2
#pretrained_models = {
 #                    "alexnet":alex, "googlenet":google,
 #                    "mnasnet":mnasnet, "resnet":resnet,
 #                    "mobilenet_large":mobilenet_large,
 #                    "efficientnet":efficientnet,
 #                    "densenet":densenet, "mobilenet_small":mobilenet_small}

In [12]:
# Perform gradient sign pertubations

def calculate_weighted_contrast(x, t, net):
  """Calculated weighted attribute explanation.
  Inputs:
    - X: one tensor input data (image), with shape (1,3,224,224)
    - t: one-dimensional tensor containing target class for the input img
    - net: network model"""

  logits = net(x)
  num_classes = logits.shape[1]
  phi_t = calculate_gradient(x, t, net)

  weighted_explanation = phi_t
  alpha_denominator = torch.sum(torch.exp(logits), dim=1)-torch.exp(logits[0,t])
  for s in range(num_classes):
    if s!=t:
      s = torch.tensor(s).to(device)
      alpha_s = torch.exp(logits[0,s])/alpha_denominator
      phi_s = calculate_gradient(x, s, net)
      weighted_explanation -= alpha_s*phi_s
  return weighted_explanation


def calculate_max_contrast(x, t, net):
  logits = net(x)
  num_classes = logits.shape[1]
  logits[t] = -1000000000  # because we want to take argmax on logits except t. I assume that there will always be a logit larger than this value
  s_star = torch.argmax(logits)
  phi_s_star = calculate_gradient(x, s_star, net)
  phi_t = calculate_gradient(x, t, net)
  return phi_t - phi_s_star


def calculate_mean_contrast(x, t, net):
    logits = net(x)
    num_classes = logits.shape[1]
    phi_t = calculate_gradient(x,t,net)
    mean_contrast = phi_t
    for s in range(num_classes):
      if s!=t:
        s = torch.tensor(s).to(device)
        phi_s = calculate_gradient(x,s,net)
        mean_contrast -= phi_s/(num_classes-1)
    return mean_contrast


def get_attribute_explanation(x, t, net, contrast_type):
  if contrast_type=="original":
    return calculate_gradient(x, t, net)
  elif contrast_type=="weighted":
    return calculate_weighted_contrast(x, t, net)
  elif contrast_type=="mean":
    return calculate_mean_contrast(x, t, net)
  elif contrast_type=="max":
    return calculate_max_contrast(x, t, net)


def calculate_gradient(x, t, net, probs=False):
  logits = net(x)
  pred_probab = nn.Softmax(dim=1)(logits)
  yt_oh = one_hot(t, num_classes=logits.shape[1])
  external_grad = torch.reshape(yt_oh, logits.shape)
  x.retain_grad()   # Enables this Tensor to have their grad populated during backward() even if it is not a leaf node ?
  if probs:
    pred_probab.backward(gradient=external_grad)
    return x.grad

  logits.backward(gradient=external_grad)

  return x.grad

def calculate_gradient_old(net, x, pred_indx):
  value_logits = J(lambda x:net(x)[np.arange(BATCH_SIZE), pred_indx],x)
  value_logits = torch.diagonal(value_logits)
  value_logits = value_logits.permute(3,0,1,2)
  #value_probs = J(lambda x:torch.nn.functional.softmax(net(x), dim=1)[np.arange(BATCH_SIZE),pred_indx], x)
  #value_probs = torch.diagonal(value_probs)
  #value_probs = value_probs.permute(3,0,1,2)
  return value_logits


def gradient_sign_pertube(x, t, net, n, contrast_type="original"):
  epsilon = 1e-3
  xn = x.clone()
  alpha = epsilon/n
  for i in range(n):
    # Att tänka ut: ska det vara x eller data nedan
    logits = get_attribute_explanation(x, t, net, contrast_type)
    xn = xn + alpha*torch.sign(logits)  # blir det rätt index här?
    xn = torch.clamp(xn, min=torch.minimum(x-epsilon, torch.tensor(0)), max=torch.maximum(x+epsilon, torch.tensor(1)))

  return xn


## Test


In [18]:
torch.manual_seed(1)

n=1

model_names = ["alexnet"]

for name in model_names:
  model = pretrained_models[name].to(device)
  model.eval()

  #eval_size = len(data_obj)  # eval on all data
  eval_size = 2             # eval on small subset, for debug

  perturbation_types = ["weighted"]#, "mean", "original", "max"]
  accuracy_dict = {p_type:0 for p_type in perturbation_types}
  accuracy_dict["unperturbed"] = 0
  perturbation_changes = {key:{"yt":[], "pt":[]} for key in perturbation_types}

  for batch_idx, (input, target) in enumerate(val_dataloader):
      #if batch_idx%eval_size/10 == 0:
      print(batch_idx/eval_size)

      if batch_idx < eval_size:

          input.requires_grad_(True)
          input = input.to(device)
          target = target.to(device)

          y = model(input)
          yt = y[np.arange(BATCH_SIZE), target]
          pt = torch.nn.functional.softmax(y, dim=1)[np.arange(BATCH_SIZE), target]
          _, prediction_unperturbed = torch.max(y, 1)
          if prediction_unperturbed == target:
            accuracy_dict["unperturbed"] += 1/eval_size

          for perturbation_type in perturbation_types:
            # perturbed x with respect to logits
            print(perturbation_type)
            x_perturbed = gradient_sign_pertube(input, target, model, n, perturbation_type)
            y_perturbed = model(x_perturbed)
            yt_perturbed = y_perturbed[np.arange(BATCH_SIZE), target]

            # Save the change in yt before and after perturbation
            perturbation_changes[perturbation_type]['yt'].append(yt_perturbed-yt)

            # Calculate and save change in pt before and after perturbation
            pt_perturbed = torch.nn.functional.softmax(y_perturbed, dim=1)[np.arange(BATCH_SIZE), target]
            perturbation_changes[perturbation_type]['pt'].append(pt_perturbed-pt)

            # Store result of perturbed prediction
            _, prediction_perturbed = torch.max(y_perturbed, 1)
            if prediction_perturbed == target:
              accuracy_dict[perturbation_type] += 1/eval_size
      else:
          break


  for perturbation_type in perturbation_types:
    avg_pt_change = torch.mean(torch.stack(perturbation_changes[perturbation_type]['pt']))
    avg_yt_change = torch.mean(torch.stack(perturbation_changes[perturbation_type]['yt']))
    accuracy_change = accuracy_dict[perturbation_type]-accuracy_dict['unperturbed']
    print("Average changes in yt: {} and pt: {}".format(avg_yt_change, avg_pt_change))
    print('Change in accuracy for model {} using {} iterations: {}\%'.format(name, n, accuracy_change))

0.0
tensor([0])
0.5
tensor([0])
1.0
Average changes in yt: -0.24493789672851562 and pt: 1.2609176337718964e-05
Change in accuracy for model alexnet using 1 iterations: 0\%


In [27]:
# Plots
import matplotlib.pyplot as plt



# rearrange change_dictionary to desired format:
measurements = {"yt":tuple([perturbation_changes[p_type]['yt'] for p_type in perturbation_types]),
                "pt":tuple([perturbation_changes[p_type]['pt'] for p_type in perturbation_types]),
                "acc":tuple([accuracy_change[p_type] for p_type in perturbation_types])
                }


print(measurements)
fig, axs = plt.subplots(len(model_names), layout="constrained")

model_name = "alexnet"
#
#for i,model_name in enumerate(model_names):
#  axs[i].set_title(model_name)
axs.set_title(model_name)
width = 0.25
multiplier = 0
x = np.arange(len(perturbation_types))

for attribute, measurement in measurements.items():
    print(measurement)
    offset = width * multiplier
    rects = axs.bar(x + offset, measurement, width, label=attribute)
    axs.bar_label(rects, padding=3)
    multiplier += 1


x_tickes = (p_type for p_type in perturbation_types)
axs.set_xticks(x + width, x_tickes)
axs.legend(loc='upper left', ncols=3)
axs.set_ylim(0, 250)

plt.show()

TypeError: ignored