<a href="https://colab.research.google.com/github/sfbllgrn/DD2412_Class_Contrastive_Explanations/blob/main/experiment1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/sfbllgrn/DD2412_Class_Contrastive_Explanations.git

import os
import shutil

# Define the source and destination paths
source_folder = '/content/DD2412_Class_Contrastive_Explanations'
destination_folder = '/content'

# List the files and subdirectories in the source folder
contents = os.listdir(source_folder)

# Move each item from the source folder to the destination folder
for item in contents:
    source_path = os.path.join(source_folder, item)
    destination_path = os.path.join(destination_folder, item)
    shutil.move(source_path, destination_path)

# Remove the now-empty source folder
os.rmdir(source_folder)


In [1]:
# Imports
import itertools
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.autograd.functional import jacobian as J
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder, ImageNet
from torchvision.models import alexnet, AlexNet_Weights
from torchvision.models import densenet161, DenseNet161_Weights
from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights
from torchvision.models import googlenet, GoogLeNet_Weights
from torchvision.models import mnasnet0_5, MNASNet0_5_Weights  # Width 0.5 is a guess; the variant used is not stated anywhere
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
from torchvision.models import resnet18, ResNet18_Weights

In [2]:
# Load data

# Root of the validation images in ImageFolder layout (one sub-folder per class).
# Set the IMAGENET_VAL_DIR environment variable instead of editing this
# machine-specific default path.
data_folder = os.environ.get(
    "IMAGENET_VAL_DIR",
    "/Users/sofia/Documents/Skola/KTH/Master/Deep Learning, Advanced Course DD2412/Class Contrastive Explanations/val",
)
# NOTE(review): ImageFolder assigns labels from the sorted sub-folder names; the
# labels only line up with the pretrained models' class indices if the folders are
# the ImageNet synset ids in standard order — confirm.
# NOTE(review): no transforms.Normalize(ImageNet mean/std) is applied here, although
# the torchvision IMAGENET1K weights expect normalized inputs. The perturbation code
# clips to [0, 1], so normalization probably belongs inside the model rather than in
# this transform — TODO confirm before trusting absolute accuracies.
data_obj = ImageFolder(root=data_folder, transform=transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
]))

BATCH_SIZE = 64
val_dataloader = DataLoader(data_obj, batch_size=BATCH_SIZE, shuffle=False)


In [4]:
# Init Pretrained models

# name -> (torchvision builder, pretrained weights).
# The builders are called inside the comprehension rather than with
# `alexnet = alexnet(...)`. That way the builder names stay bound to the functions,
# and re-running this cell does not fail with "'AlexNet' object is not callable".
# NOTE(review): MNASNet with width 0.5 is a guess; the variant used is not stated.
# mobilenet_v3_large and efficientnet_b1 also provide IMAGENET1K_V2 weights.
MODEL_SPECS = {
    "densenet": (densenet161, DenseNet161_Weights.IMAGENET1K_V1),
    "mobilenet_small": (mobilenet_v3_small, MobileNet_V3_Small_Weights.IMAGENET1K_V1),
    "alexnet": (alexnet, AlexNet_Weights.IMAGENET1K_V1),
    "googlenet": (googlenet, GoogLeNet_Weights.IMAGENET1K_V1),
    "mnasnet": (mnasnet0_5, MNASNet0_5_Weights.IMAGENET1K_V1),
    "resnet": (resnet18, ResNet18_Weights.IMAGENET1K_V1),
    "mobilenet_large": (mobilenet_v3_large, MobileNet_V3_Large_Weights.IMAGENET1K_V1),
    "efficientnet": (efficientnet_b1, EfficientNet_B1_Weights.IMAGENET1K_V1),
}
pretrained_models = {
    name: builder(weights=weights)
    for name, (builder, weights) in MODEL_SPECS.items()
}

In [7]:
# Perform gradient sign perturbations


def attribution_explanation(net, x, target=None):
  """Gradient attributions of ``net`` at ``x`` for both logits and softmax probabilities.

  Args:
    net: classifier mapping a batch ``x`` to scores of shape (batch, classes).
      Should be in eval mode so that samples in a batch do not interact.
    x: input batch.
    target: optional integer tensor of shape (batch,) with one class index per sample.

  Returns:
    ``(value_logits, value_probs)``:
      * ``target is None``: full Jacobians of the logits / softmax probabilities
        w.r.t. ``x`` (shape ``outputs.shape + x.shape``). This is very memory-hungry
        for image inputs.
      * otherwise: gradients of ``logits[b, target[b]]`` / ``probs[b, target[b]]``
        w.r.t. the input, each with the shape of ``x``.
  """
  if target is None:
    value_logits = J(lambda inp: net(inp), x)
    value_probs = J(lambda inp: net(inp).softmax(dim=-1), x)
    return value_logits, value_probs

  # enable_grad: this is also reached from dataset transforms that are iterated
  # inside an outer torch.no_grad() block.
  with torch.enable_grad():
    inp = x.detach().requires_grad_(True)
    logits = net(inp)
    index = target.reshape(-1, 1)
    # Summing the selected entries over the batch still gives per-sample gradients,
    # assuming each output row depends only on its own input row (eval mode).
    logit_score = logits.gather(1, index).sum()
    prob_score = logits.softmax(dim=-1).gather(1, index).sum()
    value_logits, = torch.autograd.grad(logit_score, inp, retain_graph=True)
    value_probs, = torch.autograd.grad(prob_score, inp)
  return value_logits, value_probs


def gradient_sign_pertube(data, net, n, epsilon=1e-3, use_softmax=False):
  """Iterative gradient-sign perturbation of ``data`` towards ``net``'s predicted class.

  Takes ``n`` steps of size ``epsilon / n`` along the sign of the gradient of the
  originally predicted class's logit (or its softmax probability if
  ``use_softmax``). The gradient is recomputed at the current point every step.
  After each step the result is projected onto the L-inf ball of radius
  ``epsilon`` around the original input, intersected with [0, 1].

  Args:
    data: a single image (C, H, W) or a batch (B, C, H, W); expected in [0, 1]
      (e.g. the output of ``transforms.ToTensor``).
    net: classifier, expected to be in eval mode.
    n: number of steps; 0 returns ``data`` unchanged.
    epsilon: L-inf perturbation budget.
    use_softmax: follow the probability gradient instead of the logit gradient.

  Returns:
    A detached tensor with the same shape as ``data``.

  Raises:
    ValueError: if ``n`` is negative.
  """
  if n < 0:
    raise ValueError("n must be non-negative, got {}".format(n))
  if n == 0:
    return data

  # The models expect a batch dimension; the dataset transform passes single images.
  unbatched = data.dim() == 3
  x_orig = (data.unsqueeze(0) if unbatched else data).detach()
  alpha = epsilon / n

  with torch.no_grad():
    predictions = net(x_orig).argmax(dim=1)

  # Feasible box: epsilon-ball around the ORIGINAL input, clipped to valid pixels.
  lower = (x_orig - epsilon).clamp(min=0.0)
  upper = (x_orig + epsilon).clamp(max=1.0)

  x = x_orig
  for _ in range(n):
    grad_logits, grad_probs = attribution_explanation(net, x, target=predictions)
    grad = grad_probs if use_softmax else grad_logits
    x = torch.minimum(torch.maximum(x + alpha * grad.sign(), lower), upper)

  return x.squeeze(0) if unbatched else x


class PerturbationTransform:
    """Callable transform that applies ``gradient_sign_pertube`` to each sample.

    Holds the network and the step count, so it can sit inside a
    ``transforms.Compose`` pipeline after ``ToTensor``.
    """

    def __init__(self, network, n):
        # Classifier whose gradients drive the perturbation, and number of steps.
        self.network, self.n = network, n

    def __call__(self, x):
        return gradient_sign_pertube(data=x, net=self.network, n=self.n)


def get_perturbed_dataloader(net, n_values=(2,)):
  """Build one DataLoader of perturbed validation images per step count.

  Each loader reads ``data_folder`` with ``Resize((256, 256))`` and ``ToTensor``,
  then perturbs every image with ``PerturbationTransform(net, n)``.

  Args:
    net: classifier used to compute the perturbation (should be in eval mode).
    n_values: step counts to build loaders for, e.g. ``(1, 2, 10)``.
      The default ``(2,)`` keeps the previous single setting.

  Returns:
    dict mapping each ``n`` to a DataLoader with batch size ``BATCH_SIZE``, unshuffled.
  """
  perturbed_dict = {}
  for n in n_values:
    perturbed_dataset = ImageFolder(root=data_folder, transform=transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        # Use this loader's own step count (it is also the dict key).
        PerturbationTransform(net, n),
    ]))
    perturbed_dict[n] = DataLoader(perturbed_dataset, batch_size=BATCH_SIZE, shuffle=False)
  return perturbed_dict

In [3]:
# Evaluate on Data

# Number of validation batches evaluated per model (a subset keeps CPU runs short).
# The perturbed-data evaluation cell reads this value too.
subset_size = 15

for name, model in pretrained_models.items():
    # Inference behaviour for layers such as dropout and batch norm.
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        # Only the first `subset_size` batches of the validation data.
        for inputs, labels in itertools.islice(val_dataloader, subset_size):
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        print('No validation samples evaluated for {}'.format(name))
        continue
    accuracy = correct / total
    print('Validation Accuracy for {} with original data: {}%'.format(name, accuracy*100))
    



NameError: name 'pretrained_models' is not defined

In [None]:
for name, model in pretrained_models.items():
    # Eval mode must be set before the perturbation transform uses the model.
    model.eval()

    perturbed_dict = get_perturbed_dataloader(model)
    for n_iters, perturbed_dataloader in perturbed_dict.items():
        # Reset per step count so each accuracy covers only its own loader.
        correct_pert = 0
        total_pert = 0
        # Loading a batch runs the gradient-based perturbation, which needs autograd.
        # Only the forward pass below is therefore wrapped in no_grad.
        for inputs, labels in itertools.islice(perturbed_dataloader, subset_size):
            with torch.no_grad():
                outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total_pert += labels.size(0)
            correct_pert += (predicted == labels).sum().item()

        if total_pert == 0:
            print('No perturbed samples evaluated for {} ({} iterations)'.format(name, n_iters))
            continue
        accuracy_pert = correct_pert / total_pert
        print('Validation Accuracy for {} with perturbed data, {} iterations: {}%'.format(name, n_iters, accuracy_pert*100))
            