<a href="https://colab.research.google.com/github/vggls/msc_thesis_medical_xai/blob/main/experiments/crc_resnet34/MaxSensitivity_experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **1. Imports**

In [None]:
pip install grad-cam

In [None]:
import os
from PIL import Image
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import pandas as pd
import numpy as np
import shutil
import random
import pickle

import torch
from torchvision  import datasets, transforms
from torch import nn
from torch.utils.data import DataLoader

from pytorch_grad_cam import GradCAM, HiResCAM

In [None]:
# custom written code
from max_sensitivity import sample_eps_Inf,  get_explanation,  get_exp_sens,  MaxSensitivity_Dataset,  plot_scores_frequency

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# **2. Load .zip test data file from Google Drive and unzip it**

In [None]:
!unzip "./drive/My Drive/Datasets/CRC/CRC-VAL-HE-7K.zip" # test set 7K

[1;30;43mΗ έξοδος ροής περικόπηκε στις τελευταίες 5000 γραμμές.[0m
  inflating: CRC-VAL-HE-7K/DEB/DEB-TCGA-ACMSDEFF.tif  
  inflating: CRC-VAL-HE-7K/DEB/DEB-TCGA-ACQQYLLS.tif  
  inflating: CRC-VAL-HE-7K/DEB/DEB-TCGA-ADCHTGEE.tif  
  inflating: CRC-VAL-HE-7K/DEB/DEB-TCGA-AFELDRPS.tif  
  inflating: CRC-VAL-HE-7K/DEB/DEB-TCGA-AFFMDFQV.tif  
  inflating: CRC-VAL-HE-7K/DEB/DEB-TCGA-AFQQTGKI.tif  
  inflating: CRC-VAL-HE-7K/DEB/DEB-TCGA-AGKPYMDE.tif  
  inflating: CRC-VAL-HE-7K/DEB/DEB-TCGA-AHDNMNIT.tif  
  inflating: CRC-VAL-HE-7K/DEB/DEB-TCGA-AHKLPKMS.tif  
  inflating: CRC-VAL-HE-7K/DEB/DEB-TCGA-AHQCDGMY.tif  
  inflating: CRC-VAL-HE-7K/DEB/DEB-TCGA-AIIGEWYP.tif  
  inflating: CRC-VAL-HE-7K/DEB/DEB-TCGA-ALLMHHRT.tif  
  inflating: CRC-VAL-HE-7K/DEB/DEB-TCGA-ALQTIPLF.tif  
  inflating: CRC-VAL-HE-7K/DEB/DEB-TCGA-APHIEAQK.tif  
  inflating: CRC-VAL-HE-7K/DEB/DEB-TCGA-AQGAYQML.tif  
  inflating: CRC-VAL-HE-7K/DEB/DEB-TCGA-ARHEISPN.tif  
  inflating: CRC-VAL-HE-7K/DEB/DEB-TCGA-ARIHITHS.ti

# **3. Test Dataset**

In [None]:
# Deterministic evaluation preprocessing: resize every image to 224x224,
# convert it to a float tensor, then normalise each channel with
# mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5.
test_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(torch.Tensor([0.5] * 3), torch.Tensor([0.5] * 3)),
    ]
)

In [None]:
# ImageFolder assigns one integer label per sub-folder of the unzipped test set.
test_path = './CRC-VAL-HE-7K/'
test_dataset = datasets.ImageFolder(root=test_path, transform=test_transforms)
# The experiments below select classes by index 0-8, so fail fast if the
# extracted folder layout does not contain exactly nine classes.
assert len(test_dataset.classes) == 9, test_dataset.classes

# **4. Load model**

In [None]:
drive_path = './drive/MyDrive/Colab_Notebooks/dataset_models/CRC/Models/'
# The .pt file holds a whole pickled nn.Module, not a state_dict. Since
# PyTorch 2.6, torch.load defaults to weights_only=True and rejects such files,
# so opt out explicitly. Only do this for a trusted checkpoint: unpickling can
# execute arbitrary code.
resnet34 = torch.load(drive_path + 'crc_resnet34.pt', weights_only=False)

In [None]:
# Move the classifier to the GPU and put it in evaluation mode; both calls
# return the module itself, so they can be chained.
resnet34 = resnet34.cuda().eval()

# **5. Experiment function**

In [None]:
def run_experiment(radius, labels, samples_per_img, xai):
    """Run the Max-Sensitivity experiment for one CAM method on a subset of classes.

    Builds a GradCAM or HiResCAM explainer on ``resnet34.layer4[2].conv2``,
    collects every ``test_dataset`` item whose label is in ``labels``, scores
    them with ``MaxSensitivity_Dataset`` and pickles the scores to Google Drive.

    Args:
        radius: perturbation radius, forwarded to ``MaxSensitivity_Dataset``.
        labels: class indices to include. It is also formatted into the output
            file name, so pass a list to keep the existing naming scheme.
        samples_per_img: number of perturbed samples per image, forwarded to
            ``MaxSensitivity_Dataset``.
        xai: either ``'gradcam'`` or ``'hirescam'``.

    Returns:
        The second value returned by ``MaxSensitivity_Dataset`` (the scores that
        are also pickled to Drive).

    Raises:
        ValueError: if ``xai`` is not a supported method.
    """
    import inspect

    # Validate before touching the model or dataset. Unlike `assert`, this
    # check is not stripped when Python runs with -O.
    if xai not in ('gradcam', 'hirescam'):
        raise ValueError("xai must be 'gradcam' or 'hirescam', got {!r}".format(xai))

    cam_class = GradCAM if xai == 'gradcam' else HiResCAM
    cam_kwargs = dict(model=resnet34, target_layers=[resnet34.layer4[2].conv2])
    # Newer grad-cam releases removed `use_cuda` (the device is taken from the
    # model). The install at the top is unpinned, so only pass it when the
    # installed constructor still accepts it.
    if 'use_cuda' in inspect.signature(cam_class.__init__).parameters:
        cam_kwargs['use_cuda'] = True
    cam_instance = cam_class(**cam_kwargs)

    print('radius:{} - labels:{}'.format(radius, labels))

    try:
        wanted = set(labels)
        # Select indices via ImageFolder's target list first, so only the
        # requested images are opened and transformed. The order matches
        # iterating over the whole dataset.
        dataset = [test_dataset[idx]
                   for idx, target in enumerate(test_dataset.targets)
                   if target in wanted]
        _, data_scores = MaxSensitivity_Dataset(dataset, resnet34, cam_instance, radius, samples_per_img)
    finally:
        # Remove the hooks this CAM registered on the shared global model, so
        # repeated calls do not leave stale hooks attached to it.
        cam_instance.activations_and_grads.release()

    save_dir = './drive/MyDrive/Colab_Notebooks/dataset_models/CRC/Max_Sensitivity/exp_ResNet34/'
    # Create the folder before writing, so a missing directory cannot discard
    # a run that can take tens of minutes.
    os.makedirs(save_dir, exist_ok=True)
    # Same file names as before:
    # maxsens_resnet34_<xai>_radius:<radius>_labels:<labels>.pickle
    file_name = 'maxsens_resnet34_{}_radius:{}_labels:{}.pickle'.format(xai, radius, labels)
    with open(os.path.join(save_dir, file_name), 'wb') as f:
        pickle.dump(data_scores, f)

    return data_scores

# **6. GradCAM**

## **radius 0.05**

In [None]:
data_scores = run_experiment(radius=0.05, labels=[0,1,2,3], samples_per_img=20, xai='gradcam')

radius:0.05 - labels:[0, 1, 2, 3]
Total time: 889.7212681770325 secs
Correctly predicted images: 2858/3158
Avg secs per image:  0.31


In [None]:
sum(data_scores)

47279.86483454704

In [None]:
data_scores = run_experiment(radius=0.05, labels=[4,5,6,7,8], samples_per_img=20, xai='gradcam')

radius:0.05 - labels:[4, 5, 6, 7, 8]
Total time: 1145.0395720005035 secs
Correctly predicted images: 3711/4022
Avg secs per image:  0.31


In [None]:
sum(data_scores)

47704.28295946121

## **radius 0.1**

In [None]:
data_scores = run_experiment(radius=0.1, labels=[0,1,2], samples_per_img=20, xai='gradcam')

radius:0.1 - labels:[0, 1, 2]
Total time: 716.2290580272675 secs
Correctly predicted images: 2234/2524
Avg secs per image:  0.32


In [None]:
sum(data_scores)

72229.73666381836

In [None]:
data_scores = run_experiment(radius=0.1, labels=[3,4,5], samples_per_img=20, xai='gradcam')

radius:0.1 - labels:[3, 4, 5]
Total time: 735.0154044628143 secs
Correctly predicted images: 2155/2261
Avg secs per image:  0.34


In [None]:
sum(data_scores)

64950.87347698212

In [None]:
data_scores = run_experiment(radius=0.1, labels=[6,7,8], samples_per_img=20, xai='gradcam')

radius:0.1 - labels:[6, 7, 8]
Total time: 716.4381680488586 secs
Correctly predicted images: 2180/2395
Avg secs per image:  0.33


In [None]:
sum(data_scores)

47300.58795309067

## **radius 0.2**

In [None]:
data_scores = run_experiment(radius=0.2, labels=[0,1,2,3], samples_per_img=30, xai='gradcam')

radius:0.2 - labels:[0, 1, 2, 3]
Total time: 1535.9241313934326 secs
Correctly predicted images: 2858/3158
Avg secs per image:  0.54


In [None]:
sum(data_scores)

131485.10600090027

In [None]:
data_scores = run_experiment(radius=0.2, labels=[4,5,6,7,8], samples_per_img=30, xai='gradcam')

radius:0.2 - labels:[4, 5, 6, 7, 8]
Total time: 1690.1485257148743 secs
Correctly predicted images: 3711/4022
Avg secs per image:  0.46


In [None]:
sum(data_scores)

176814.6002368927

## **radius 0.3**

In [None]:
data_scores = run_experiment(radius=0.3, labels=[0,1,2], samples_per_img=30, xai='gradcam')

radius:0.3 - labels:[0, 1, 2]
Total time: 1044.7618608474731 secs
Correctly predicted images: 2234/2524
Avg secs per image:  0.47


In [None]:
sum(data_scores)

120348.47531032562

In [None]:
data_scores = run_experiment(radius=0.3, labels=[3,4,5], samples_per_img=30, xai='gradcam')

radius:0.3 - labels:[3, 4, 5]
Total time: 1050.802269935608 secs
Correctly predicted images: 2155/2261
Avg secs per image:  0.49


In [None]:
sum(data_scores)

155210.19903182983

In [None]:
data_scores = run_experiment(radius=0.3, labels=[6,7,8], samples_per_img=30, xai='gradcam')

radius:0.3 - labels:[6, 7, 8]
Total time: 1030.4509558677673 secs
Correctly predicted images: 2180/2395
Avg secs per image:  0.47


In [None]:
sum(data_scores)

141123.63207435608

## **radius 0.4**

In [None]:
data_scores = run_experiment(radius=0.4, labels=[0,1,2], samples_per_img=40, xai='gradcam')

radius:0.4 - labels:[0, 1, 2]
Total time: 1735.3760902881622 secs
Correctly predicted images: 2234/2524
Avg secs per image:  0.78


In [None]:
sum(data_scores)

137711.39922237396

In [None]:
data_scores = run_experiment(radius=0.4, labels=[3,4,5], samples_per_img=40, xai='gradcam')

radius:0.4 - labels:[3, 4, 5]
Total time: 1646.4227328300476 secs
Correctly predicted images: 2155/2261
Avg secs per image:  0.76


In [None]:
sum(data_scores)

181331.37240982056

In [None]:
data_scores = run_experiment(radius=0.4, labels=[6,7,8], samples_per_img=40, xai='gradcam')

radius:0.4 - labels:[6, 7, 8]
Total time: 1696.9106981754303 secs
Correctly predicted images: 2180/2395
Avg secs per image:  0.78


In [None]:
sum(data_scores)

206426.14287948608

## **radius 0.5**

In [None]:
data_scores = run_experiment(radius=0.5, labels=[0,1,2], samples_per_img=40, xai='gradcam')

radius:0.5 - labels:[0, 1, 2]
Total time: 1347.7357683181763 secs
Correctly predicted images: 2234/2524
Avg secs per image:  0.6


In [None]:
sum(data_scores)

154628.34824180603

In [None]:
data_scores = run_experiment(radius=0.5, labels=[3,4,5], samples_per_img=40, xai='gradcam')

radius:0.5 - labels:[3, 4, 5]
Total time: 1368.789751291275 secs
Correctly predicted images: 2155/2261
Avg secs per image:  0.64


In [None]:
sum(data_scores)

202914.8122768402

In [None]:
data_scores = run_experiment(radius=0.5, labels=[6,7,8], samples_per_img=40, xai='gradcam')

radius:0.5 - labels:[6, 7, 8]
Total time: 1343.1255123615265 secs
Correctly predicted images: 2180/2395
Avg secs per image:  0.62


In [None]:
sum(data_scores)

255666.9595527649

# **7. HiResCAM**

## **radius 0.05**

In [None]:
data_scores = run_experiment(radius=0.05, labels=[0,1,2,3], samples_per_img=20, xai='hirescam')

radius:0.05 - labels:[0, 1, 2, 3]
Total time: 1035.8348858356476 secs
Correctly predicted images: 2858/3158
Avg secs per image:  0.36


In [None]:
sum(data_scores)

44187.03478384018

In [None]:
data_scores = run_experiment(radius=0.05, labels=[4,5,6,7,8], samples_per_img=20, xai='hirescam')

radius:0.05 - labels:[4, 5, 6, 7, 8]
Total time: 1140.453549861908 secs
Correctly predicted images: 3711/4022
Avg secs per image:  0.31


In [None]:
sum(data_scores)

40621.49029326439

## **radius 0.1**

In [None]:
data_scores = run_experiment(radius=0.1, labels=[0,1,2], samples_per_img=20, xai='hirescam')

radius:0.1 - labels:[0, 1, 2]
Total time: 707.6605799198151 secs
Correctly predicted images: 2234/2524
Avg secs per image:  0.32


In [None]:
sum(data_scores)

67762.61979675293

In [None]:
data_scores = run_experiment(radius=0.1, labels=[3,4,5], samples_per_img=20, xai='hirescam')

radius:0.1 - labels:[3, 4, 5]
Total time: 725.2654504776001 secs
Correctly predicted images: 2155/2261
Avg secs per image:  0.34


In [None]:
sum(data_scores)

53582.72028017044

In [None]:
data_scores = run_experiment(radius=0.1, labels=[6,7,8], samples_per_img=20, xai='hirescam')

radius:0.1 - labels:[6, 7, 8]
Total time: 711.9879014492035 secs
Correctly predicted images: 2180/2395
Avg secs per image:  0.33


In [None]:
sum(data_scores)

41345.95800232887

## **radius 0.2**

In [None]:
data_scores = run_experiment(radius=0.2, labels=[0,1,2,3], samples_per_img=30, xai='hirescam')

radius:0.2 - labels:[0, 1, 2, 3]
Total time: 1480.2862694263458 secs
Correctly predicted images: 2858/3158
Avg secs per image:  0.52


In [None]:
sum(data_scores)

121034.22462368011

In [None]:
data_scores = run_experiment(radius=0.2, labels=[4,5,6,7,8], samples_per_img=30, xai='hirescam')

radius:0.2 - labels:[4, 5, 6, 7, 8]
Total time: 1664.051607131958 secs
Correctly predicted images: 3711/4022
Avg secs per image:  0.45


In [None]:
sum(data_scores)

146647.97851228714

## **radius 0.3**

In [None]:
data_scores = run_experiment(radius=0.3, labels=[0,1,2], samples_per_img=30, xai='hirescam')

radius:0.3 - labels:[0, 1, 2]
Total time: 1027.142365694046 secs
Correctly predicted images: 2234/2524
Avg secs per image:  0.46


In [None]:
sum(data_scores)

116618.92820835114

In [None]:
data_scores = run_experiment(radius=0.3, labels=[3,4,5], samples_per_img=30, xai='hirescam')

radius:0.3 - labels:[3, 4, 5]
Total time: 1050.5196452140808 secs
Correctly predicted images: 2155/2261
Avg secs per image:  0.49


In [None]:
sum(data_scores)

124246.32281684875

In [None]:
data_scores = run_experiment(radius=0.3, labels=[6,7,8], samples_per_img=30, xai='hirescam')

radius:0.3 - labels:[6, 7, 8]
Total time: 1028.8461337089539 secs
Correctly predicted images: 2180/2395
Avg secs per image:  0.47


In [None]:
sum(data_scores)

106743.15491485596

## **radius 0.4**

In [None]:
data_scores = run_experiment(radius=0.4, labels=[0,1,2], samples_per_img=40, xai='hirescam')

radius:0.4 - labels:[0, 1, 2]
Total time: 1388.1642701625824 secs
Correctly predicted images: 2234/2524
Avg secs per image:  0.62


In [None]:
sum(data_scores)

130186.51616096497

In [None]:
data_scores = run_experiment(radius=0.4, labels=[3,4,5], samples_per_img=40, xai='hirescam')

radius:0.4 - labels:[3, 4, 5]
Total time: 1415.1888465881348 secs
Correctly predicted images: 2155/2261
Avg secs per image:  0.66


In [None]:
sum(data_scores)

148012.09937667847

In [None]:
data_scores = run_experiment(radius=0.4, labels=[6,7,8], samples_per_img=40, xai='hirescam')

radius:0.4 - labels:[6, 7, 8]
Total time: 1404.9269824028015 secs
Correctly predicted images: 2180/2395
Avg secs per image:  0.64


In [None]:
sum(data_scores)

145472.20769119263

## **radius 0.5**

In [None]:
data_scores = run_experiment(radius=0.5, labels=[0,1,2], samples_per_img=40, xai='hirescam')

radius:0.5 - labels:[0, 1, 2]
Total time: 1339.7305948734283 secs
Correctly predicted images: 2234/2524
Avg secs per image:  0.6


In [None]:
sum(data_scores)

139386.2993440628

In [None]:
data_scores = run_experiment(radius=0.5, labels=[3,4,5], samples_per_img=40, xai='hirescam')

radius:0.5 - labels:[3, 4, 5]
Total time: 1362.28245139122 secs
Correctly predicted images: 2155/2261
Avg secs per image:  0.63


In [None]:
sum(data_scores)

168930.82308387756

In [None]:
data_scores = run_experiment(radius=0.5, labels=[6,7,8], samples_per_img=40, xai='hirescam')

radius:0.5 - labels:[6, 7, 8]
Total time: 1323.0808136463165 secs
Correctly predicted images: 2180/2395
Avg secs per image:  0.61


In [None]:
sum(data_scores)

175136.8635406494