<a href="https://colab.research.google.com/github/vggls/msc_thesis_medical_xai/blob/main/experiments/crc_resnet34/AOPC_experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **1. Imports**

In [None]:
pip install grad-cam

In [None]:
import os
from PIL import Image
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import pandas as pd
import numpy as np
import shutil
import random
import pickle

import torch
from torchvision  import datasets, transforms
from torch import nn
from torch.utils.data import DataLoader

from pytorch_grad_cam import GradCAM, HiResCAM

In [None]:
# custom written code
from morf import MoRF, AOPC_Dataset, plot_aopc_per_step
from heatmap import Heatmap

In [None]:
# Mount Google Drive so the dataset zip, model checkpoint and AOPC output
# folder are reachable under ./drive/ (cwd in Colab is /content).
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


# **2. Load .zip test data file from Google Drive and unzip it**

In [None]:
!unzip "./drive/My Drive/Datasets/CRC/CRC-VAL-HE-7K.zip" # test set 7K

# **3. Test dataset**

In [None]:
# to be applied to validation and test data
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(torch.Tensor([0.5, 0.5, 0.5]), torch.Tensor([0.5, 0.5, 0.5]))
])

In [None]:
# Test images extracted in section 2, wrapped as an ImageFolder with the
# evaluation transforms defined above.
test_path = './CRC-VAL-HE-7K/'
test_dataset = datasets.ImageFolder(test_path, transform=test_transforms)

# **4. Load model**

In [None]:
model_path = './drive/MyDrive/Colab_Notebooks/dataset_models/CRC/Models/'
# crc_resnet34.pt stores a whole pickled nn.Module, not a state_dict: the
# object is used below via .cuda(), .eval() and .layer4.
# torch>=2.6 defaults to weights_only=True, which refuses to unpickle arbitrary
# modules, so opt out explicitly (the argument requires torch>=1.13).
# Only load trusted checkpoints this way: unpickling can execute arbitrary code.
resnet34 = torch.load(os.path.join(model_path, 'crc_resnet34.pt'), weights_only=False)

In [None]:
# Move the network to the GPU and switch it to evaluation mode for the
# AOPC experiments (both calls return the module itself, so they chain).
resnet34 = resnet34.cuda().eval()

# **5. Experiment function**

In [None]:
def run_experiment(region_size, labels, xai):
    """Run the AOPC (MoRF) evaluation of one CAM method on a subset of the test set.

    Builds a Grad-CAM or HiResCAM explainer on ``resnet34.layer4[2].conv2``,
    runs ``AOPC_Dataset`` over every test image whose label is in ``labels``,
    and pickles the returned per-image diffs and scores to Google Drive.

    Args:
        region_size: side length of the square regions passed to ``AOPC_Dataset``.
        labels: collection of class indices to keep from ``test_dataset``.
        xai: ``'gradcam'`` or ``'hirescam'``.

    Returns:
        The dataset-level AOPC score returned by ``AOPC_Dataset``.

    Raises:
        ValueError: if ``xai`` is not one of the supported methods.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    supported = ('gradcam', 'hirescam')
    if xai not in supported:
        raise ValueError('xai must be one of {}, got {!r}'.format(supported, xai))

    drive_path = './drive/MyDrive/Colab_Notebooks/dataset_models/CRC/AOPC/exp_ResNet34/'
    # Create the output folder before the long AOPC run, so a missing
    # directory cannot make open() fail after all the compute is done.
    os.makedirs(drive_path, exist_ok=True)

    cam_class = GradCAM if xai == 'gradcam' else HiResCAM
    cam_instance = cam_class(model=resnet34,
                             target_layers=[resnet34.layer4[2].conv2],
                             use_cuda=True)

    print('region_size:{}*{} - labels:{}'.format(region_size, region_size, labels))

    # Filter on the ImageFolder labels first, so only the selected images are
    # loaded and transformed. Iterating test_dataset would decode every image.
    dataset = [test_dataset[idx]
               for idx, label in enumerate(test_dataset.targets)
               if label in labels]
    diffs, img_scores, aopc_score = AOPC_Dataset(dataset=dataset,
                                                 model=resnet34,
                                                 region_size=region_size,
                                                 cam_instance=cam_instance)

    # File names are unchanged:
    # aopc_resnet34_<xai>_<diffs|scores>_region-size:<r>_labels:<labels>.pickle
    for kind, obj in (('diffs', diffs), ('scores', img_scores)):
        file_name = 'aopc_resnet34_{}_{}_region-size:{}_labels:{}.pickle'.format(
            xai, kind, region_size, labels)
        with open(drive_path + file_name, 'wb') as f:
            pickle.dump(obj, f)

    return aopc_score

# **6. GradCAM**

## **region size 16*16** (heatmap 14*14)

In [None]:
# Grad-CAM, 16x16 regions, classes 0-2
score = run_experiment(xai='gradcam', region_size=16, labels=list(range(3)))

region_size:16*16 - labels:[0, 1, 2]
Total time: 3063.832291126251 secs
No of correctly classified images: 2234/2524
Avg secs per image:  1.37


In [None]:
score  # AOPC score of the Grad-CAM run above (16x16 regions, classes 0-2)

0.492

In [None]:
# Grad-CAM, 16x16 regions, classes 3-5
score = run_experiment(xai='gradcam', region_size=16, labels=list(range(3, 6)))

region_size:16*16 - labels:[3, 4, 5]
Total time: 2989.4520213603973 secs
No of correctly classified images: 2155/2261
Avg secs per image:  1.39


In [None]:
score  # AOPC score of the Grad-CAM run above (16x16 regions, classes 3-5)

0.57

In [None]:
# Grad-CAM, 16x16 regions, classes 6-8
score = run_experiment(xai='gradcam', region_size=16, labels=list(range(6, 9)))

region_size:16*16 - labels:[6, 7, 8]
Total time: 3095.3397443294525 secs
No of correctly classified images: 2180/2395
Avg secs per image:  1.42


In [None]:
score  # AOPC score of the Grad-CAM run above (16x16 regions, classes 6-8)

0.673

## **region size 21*21** (heatmap 11*11)

In [None]:
# Grad-CAM, 21x21 regions, classes 0-3
score = run_experiment(xai='gradcam', region_size=21, labels=list(range(4)))

region_size:21*21 - labels:[0, 1, 2, 3]
Total time: 2521.904171705246 secs
No of correctly classified images: 2858/3158
Avg secs per image:  0.88


In [None]:
score  # AOPC score of the Grad-CAM run above (21x21 regions, classes 0-3)

0.553

In [None]:
# Grad-CAM, 21x21 regions, classes 4-8
score = run_experiment(xai='gradcam', region_size=21, labels=list(range(4, 9)))

region_size:21*21 - labels:[4, 5, 6, 7, 8]
Total time: 3267.3921921253204 secs
No of correctly classified images: 3711/4022
Avg secs per image:  0.88


In [None]:
score  # AOPC score of the Grad-CAM run above (21x21 regions, classes 4-8)

0.633

## **region size 28*28** (heatmap 8*8)

In [None]:
# Grad-CAM, 28x28 regions, classes 0-3
score = run_experiment(xai='gradcam', region_size=28, labels=list(range(4)))

region_size:28*28 - labels:[0, 1, 2, 3]
Total time: 1496.1799278259277 secs
No of correctly classified images: 2858/3158
Avg secs per image:  0.52


In [None]:
score  # AOPC score of the Grad-CAM run above (28x28 regions, classes 0-3)

0.543

In [None]:
# Grad-CAM, 28x28 regions, classes 4-8
score = run_experiment(xai='gradcam', region_size=28, labels=list(range(4, 9)))

region_size:28*28 - labels:[4, 5, 6, 7, 8]
Total time: 1928.724651813507 secs
No of correctly classified images: 3711/4022
Avg secs per image:  0.52


In [None]:
score  # AOPC score of the Grad-CAM run above (28x28 regions, classes 4-8)

0.622

## **region size 56*56** (heatmap 4*4)

In [None]:
# Grad-CAM, 56x56 regions, all classes 0-8
score = run_experiment(xai='gradcam', region_size=56, labels=list(range(9)))

region_size:56*56 - labels:[0, 1, 2, 3, 4, 5, 6, 7, 8]
Total time: 1439.0757279396057 secs
No of correctly classified images: 6569/7180
Avg secs per image:  0.22


In [None]:
score  # AOPC score of the Grad-CAM run above (56x56 regions, all classes)

0.573

# **7. HiResCAM**

## **region size 16*16** (heatmap 14*14)

In [None]:
# HiResCAM, 16x16 regions, classes 0-2
score = run_experiment(xai='hirescam', region_size=16, labels=list(range(3)))

region_size:16*16 - labels:[0, 1, 2]
Total time: 3041.5002291202545 secs
No of correctly classified images: 2234/2524
Avg secs per image:  1.36


In [None]:
score  # AOPC score of the HiResCAM run above (16x16 regions, classes 0-2)

0.49

In [None]:
# HiResCAM, 16x16 regions, classes 3-5
score = run_experiment(xai='hirescam', region_size=16, labels=list(range(3, 6)))

region_size:16*16 - labels:[3, 4, 5]
Total time: 2947.0925047397614 secs
No of correctly classified images: 2155/2261
Avg secs per image:  1.37


In [None]:
score  # AOPC score of the HiResCAM run above (16x16 regions, classes 3-5)

0.623

In [None]:
# HiResCAM, 16x16 regions, classes 6-8
score = run_experiment(xai='hirescam', region_size=16, labels=list(range(6, 9)))

region_size:16*16 - labels:[6, 7, 8]
Total time: 2968.037647008896 secs
No of correctly classified images: 2180/2395
Avg secs per image:  1.36


In [None]:
score  # AOPC score of the HiResCAM run above (16x16 regions, classes 6-8)

0.686

## **region size 21*21** (heatmap 11*11)

In [None]:
# HiResCAM, 21x21 regions, classes 0-3
score = run_experiment(xai='hirescam', region_size=21, labels=list(range(4)))

region_size:21*21 - labels:[0, 1, 2, 3]
Total time: 2504.9410943984985 secs
No of correctly classified images: 2858/3158
Avg secs per image:  0.88


In [None]:
score  # AOPC score of the HiResCAM run above (21x21 regions, classes 0-3)

0.549

In [None]:
# HiResCAM, 21x21 regions, classes 4-8
score = run_experiment(xai='hirescam', region_size=21, labels=list(range(4, 9)))

region_size:21*21 - labels:[4, 5, 6, 7, 8]
Total time: 3257.324284553528 secs
No of correctly classified images: 3711/4022
Avg secs per image:  0.88


In [None]:
score  # AOPC score of the HiResCAM run above (21x21 regions, classes 4-8)

0.639

## **region size 28*28** (heatmap 8*8)

In [None]:
# HiResCAM, 28x28 regions, classes 0-3
score = run_experiment(xai='hirescam', region_size=28, labels=list(range(4)))

region_size:28*28 - labels:[0, 1, 2, 3]
Total time: 1483.9456629753113 secs
No of correctly classified images: 2858/3158
Avg secs per image:  0.52


In [None]:
score  # AOPC score of the HiResCAM run above (28x28 regions, classes 0-3)

0.538

In [None]:
# HiResCAM, 28x28 regions, classes 4-8
score = run_experiment(xai='hirescam', region_size=28, labels=list(range(4, 9)))

region_size:28*28 - labels:[4, 5, 6, 7, 8]
Total time: 1922.5265364646912 secs
No of correctly classified images: 3711/4022
Avg secs per image:  0.52


In [None]:
score  # AOPC score of the HiResCAM run above (28x28 regions, classes 4-8)

0.629

## **region size 56*56** (heatmap 4*4)

In [None]:
# HiResCAM, 56x56 regions, all classes 0-8
score = run_experiment(xai='hirescam', region_size=56, labels=list(range(9)))

region_size:56*56 - labels:[0, 1, 2, 3, 4, 5, 6, 7, 8]
Total time: 1424.644147157669 secs
No of correctly classified images: 6569/7180
Avg secs per image:  0.22


In [None]:
score  # AOPC score of the HiResCAM run above (56x56 regions, all classes)

0.572