<a href="https://colab.research.google.com/github/vggls/xai_and_evaluation_metrics/blob/main/example_xrays/AOPC_Experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**1. Imports**

In [None]:
pip install grad-cam

In [None]:
import numpy as np
import shutil
import pickle

import torch
from torchvision  import datasets, transforms

from pytorch_grad_cam import GradCAM, HiResCAM

In [None]:
# custom written code 
from xrays import create_datasets
from morf import MoRF, AOPC_Dataset, plot_aopc_per_step
from heatmap import Heatmap

In [None]:
# Mount Google Drive at /content/drive (Colab only). The Kaggle API token, the trained
# model and the AOPC result files used in this notebook all live under this mount.
from google.colab import drive  
drive.mount('/content/drive') 

Mounted at /content/drive


**2. Download from Kaggle**

In [None]:
! mkdir ~/.kaggle
!cp /content/drive/MyDrive/kaggle.json ~/.kaggle/kaggle.json

In [None]:
# Download the dataset archive (covid19-radiography-database.zip) into the working
# directory with the Kaggle CLI; needs the API token at ~/.kaggle/kaggle.json.
! kaggle datasets download tawsifurrahman/covid19-radiography-database

Downloading covid19-radiography-database.zip to /content
 97% 756M/778M [00:03<00:00, 217MB/s]
100% 778M/778M [00:03<00:00, 214MB/s]


In [None]:
! unzip covid19-radiography-database.zip

**3. Test dataset**

In [None]:
# Optional cleanup, intentionally left disabled: uncomment to delete an existing
# test_dataset folder (e.g. before re-creating the dataset splits).
#shutil.rmtree('./COVID-19_Radiography_Dataset/test_dataset')

In [None]:
# Random augmentations, used for the training split only.
_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=20),
]


def _preprocessing(extra_steps=()):
    """Build the pipeline: resize to 224x224 -> extra_steps -> tensor -> normalise.

    Normalisation uses mean 0.5 and std 0.5 for each of the three channels.
    """
    return transforms.Compose([
        transforms.Resize((224, 224)),          # recommended size
        *extra_steps,
        transforms.ToTensor(),
        transforms.Normalize(torch.Tensor([0.5, 0.5, 0.5]), torch.Tensor([0.5, 0.5, 0.5])),
    ])


# to be applied to training data
train_transforms = _preprocessing(_augmentations)
# to be applied to validation and test data
test_transforms = _preprocessing()

In [None]:
# Build the three splits with the project helper from xrays.py.
# Only test_dataset is used in the rest of this notebook (by run_experiment).
train_dataset, validation_dataset, test_dataset = create_datasets(train_transforms, test_transforms)

In [None]:
# do not need them in this notebook
# ignore_errors=True keeps the cell idempotent: re-running it after the folders are
# already gone no longer raises FileNotFoundError.
for unused_split in ('training_dataset', 'validation_dataset'):
    shutil.rmtree('./COVID-19_Radiography_Dataset/' + unused_split, ignore_errors=True)

**4. Model**

In [None]:
model_path = './drive/MyDrive/Colab_Notebooks/dataset_models/Covid-19_Radiography_Dataset/Models/'
# The file stores the whole pickled model object (it is used directly as a module below),
# not just a state_dict. PyTorch >= 2.6 defaults to weights_only=True and refuses to
# unpickle such files, so it must be disabled explicitly (argument available from torch 1.13).
# SECURITY: unpickling can execute arbitrary code -- only load checkpoints you trust.
resnet34 = torch.load(model_path + 'xrays_resnet34.pt', weights_only=False)

In [None]:
# Move the network to the GPU and switch it to evaluation mode in one chained call.
resnet34 = resnet34.cuda().eval()

**5. AOPC**

In [None]:
def run_experiment(region_size, labels, xai):
    """Run AOPC_Dataset for one region size / label subset / CAM method and pickle the results.

    Relies on the notebook globals ``resnet34`` (the model) and ``test_dataset``.

    Args:
        region_size: value forwarded to AOPC_Dataset as ``region_size``
            (the section headings pair e.g. 16 with a 14*14 heatmap).
        labels: collection of class labels; only test samples whose label is
            contained in it are evaluated.
        xai: 'gradcam' or 'hirescam' -- selects the pytorch_grad_cam class.

    Returns:
        The third value returned by AOPC_Dataset (``aopc_score``).

    Side effects:
        Prints the configuration and writes two pickle files (the ``diffs`` and
        ``img_scores`` results) into the AOPC/exp_ResNet34 folder on Drive.
        The folder is created if it does not exist yet.

    Raises:
        AssertionError: if ``xai`` is not one of the supported names.
    """
    import inspect
    import os

    cam_classes = {'gradcam': GradCAM, 'hirescam': HiResCAM}
    assert xai in tuple(cam_classes), "xai must be one of {}".format(list(cam_classes))

    # Make sure the results can be saved *before* the long computation starts.
    drive_path = './drive/MyDrive/Colab_Notebooks/dataset_models/Covid-19_Radiography_Dataset/AOPC/exp_ResNet34/'
    os.makedirs(drive_path, exist_ok=True)

    cam_class = cam_classes[xai]
    cam_kwargs = {'model': resnet34, 'target_layers': [resnet34.layer4[2].conv2]}
    # Older grad-cam releases need use_cuda=True to run on the GPU; newer releases
    # removed the argument (the device is taken from the model) and reject it.
    if 'use_cuda' in inspect.signature(cam_class.__init__).parameters:
        cam_kwargs['use_cuda'] = True
    cam_instance = cam_class(**cam_kwargs)

    print('region_size:{}*{} - labels:{}'.format(region_size, region_size, labels))

    # Keep only the test samples that belong to the requested classes.
    dataset = [(img, label) for (img, label) in test_dataset if (label in labels)]
    diffs, img_scores, aopc_score = AOPC_Dataset(dataset = dataset,
                                                 model = resnet34,
                                                 region_size = region_size,
                                                 cam_instance = cam_instance)

    # Same file names as before: aopc_resnet34_<xai>_<diffs|scores>_region-size_<n>_labels_<labels>.pickle
    file_template = 'aopc_resnet34_{}_{}_region-size_{}_labels_{}.pickle'
    for result_name, result in (('diffs', diffs), ('scores', img_scores)):
        file_name = file_template.format(xai, result_name, region_size, labels)
        with open(drive_path + file_name, 'wb') as f:
            pickle.dump(result, f)

    return aopc_score

>**GradCAM**

**region size 16*16** (heatmap 14*14)

In [None]:
# GradCAM, 16x16 regions, classes 0 and 1 only (classes 2 and 3 are run in a separate cell below).
# NOTE(review): the split into two label subsets is presumably to keep each run short -- confirm.
score = run_experiment(region_size=16, labels=[0,1], xai='gradcam')

region_size:16*16 - labels:[0, 1]
Total time: 1356.2481248378754 secs
No of correctly classified images: 929/964
Avg secs per image:  1.46


In [None]:
score

0.535

In [None]:
# GradCAM, 16x16 regions, classes 2 and 3.
# NOTE(review): the saved output of this cell lacks the timing summary the other runs print.
score = run_experiment(region_size=16, labels=[2,3], xai='gradcam')

region_size:16*16 - labels:[2, 3]


In [None]:
score

0.892

**region size 21*21** (heatmap 11*11)

In [None]:
# GradCAM, 21x21 regions, classes 0 and 1.
score = run_experiment(region_size=21, labels=[0,1], xai='gradcam')

region_size:21*21 - labels:[0, 1]
Total time: 885.3204908370972 secs
No of correctly classified images: 929/964
Avg secs per image:  0.95


In [None]:
score

0.532

In [None]:
# GradCAM, 21x21 regions, classes 2 and 3.
score = run_experiment(region_size=21, labels=[2,3], xai='gradcam')

region_size:21*21 - labels:[2, 3]
Total time: 1036.3437812328339 secs
No of correctly classified images: 1088/1155
Avg secs per image:  0.95


In [None]:
score

0.895

**region size 28*28** (heatmap 8*8)

In [None]:
# GradCAM, 28x28 regions, all four classes in a single run.
score = run_experiment(region_size=28, labels=[0,1,2,3], xai='gradcam')

region_size:28*28 - labels:[0, 1, 2, 3]
Total time: 1146.345870733261 secs
No of correctly classified images: 2017/2119
Avg secs per image:  0.57


In [None]:
score

0.724

**region size 38*38** (heatmap 6*6)

In [None]:
# GradCAM, 38x38 regions, all four classes.
# NOTE(review): no output is saved for this cell -- the run still needs to be executed.
score = run_experiment(region_size=38, labels=[0,1,2,3], xai='gradcam')

In [None]:
score

**region size 56*56** (heatmap 4*4)



In [None]:
# GradCAM, 56x56 regions, all four classes.
score = run_experiment(region_size=56, labels=[0,1,2,3], xai='gradcam')

region_size:56*56 - labels:[0, 1, 2, 3]
Total time: 465.1636760234833 secs
No of correctly classified images: 2017/2119
Avg secs per image:  0.23


In [None]:
score

0.712

>**HiResCAM**

**region size 16*16** (heatmap 14*14)

In [None]:
# HiResCAM, 16x16 regions, classes 0 and 1 only (classes 2 and 3 are run in a separate cell below).
score = run_experiment(region_size=16, labels=[0,1], xai='hirescam')

region_size:16*16 - labels:[0, 1]
Total time: 1342.0818548202515 secs
No of correctly classified images: 929/964
Avg secs per image:  1.44


In [None]:
score

0.558

In [None]:
# HiResCAM, 16x16 regions, classes 2 and 3.
score = run_experiment(region_size=16, labels=[2,3], xai='hirescam')

region_size:16*16 - labels:[2, 3]
Total time: 1578.2365884780884 secs
No of correctly classified images: 1088/1155
Avg secs per image:  1.45


In [None]:
score

0.89

**region size 21*21** (heatmap 11*11)

In [None]:
# HiResCAM, 21x21 regions, classes 0 and 1.
score = run_experiment(region_size=21, labels=[0,1], xai='hirescam')

region_size:21*21 - labels:[0, 1]
Total time: 885.4581038951874 secs
No of correctly classified images: 929/964
Avg secs per image:  0.95


In [None]:
score

0.564

In [None]:
# HiResCAM, 21x21 regions, classes 2 and 3.
score = run_experiment(region_size=21, labels=[2,3], xai='hirescam')

region_size:21*21 - labels:[2, 3]
Total time: 1036.7783591747284 secs
No of correctly classified images: 1088/1155
Avg secs per image:  0.95


In [None]:
score

0.892

**region size 28*28** (heatmap 8*8)

In [None]:
# HiResCAM, 28x28 regions, all four classes in a single run.
score = run_experiment(region_size=28, labels=[0,1,2,3], xai='hirescam')

region_size:28*28 - labels:[0, 1, 2, 3]
Total time: 1133.3510348796844 secs
No of correctly classified images: 2017/2119
Avg secs per image:  0.56


In [None]:
score

0.733

**region size 38*38** (heatmap 6*6)

In [None]:
# HiResCAM, 38x38 regions, all four classes.
# NOTE(review): no output is saved for this cell -- the run still needs to be executed.
score = run_experiment(region_size=38, labels=[0,1,2,3], xai='hirescam')

In [None]:
score

**region size 56*56** (heatmap 4*4)

In [None]:
# HiResCAM, 56x56 regions, all four classes.
score = run_experiment(region_size=56, labels=[0,1,2,3], xai='hirescam')

region_size:56*56 - labels:[0, 1, 2, 3]
Total time: 462.3086519241333 secs
No of correctly classified images: 2017/2119
Avg secs per image:  0.23


In [None]:
score

0.726