In [1]:
import os
import sys
sys.path = ['efficientnet-b0-08094119.pth'] + sys.path

import time
import skimage.io
import numpy as np
import pandas as pd
import cv2
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler, RandomSampler, SequentialSampler
from warmup_scheduler import GradualWarmupScheduler
from efficientnet_pytorch import model as enet
import albumentations
from sklearn.model_selection import StratifiedKFold
import matplotlib.pyplot as plt
from sklearn.metrics import cohen_kappa_score
from tqdm import tqdm_notebook as tqdm 
#import enet_model
#import enet_utils 
import warnings
warnings.filterwarnings("ignore") 

In [2]:
df_train = pd.read_csv('train.csv')
image_folder = os.path.join('train_images')

kernel_type = 'how_to_train_effnet_b0_to_get_LB_0.86'

enet_type = 'efficientnet-b0'
fold = 0
height = 256
width = 256
batch_size = 2
num_workers = 4
out_dim = 5
init_lr = 3e-4
warmup_factor = 10

warmup_epo = 1
n_epochs = 20
device = torch.device('cuda') 

In [3]:
skf = StratifiedKFold(4, shuffle=True, random_state=42)
df_train['fold'] = -1
for i, (train_idx, valid_idx) in enumerate(skf.split(df_train, df_train['isup_grade'])):
    df_train.loc[valid_idx, 'fold'] = i
df_train.head()

Unnamed: 0,image_id,data_provider,isup_grade,gleason_score,fold
0,0005f7aaab2800f6170c399693a96917,karolinska,0,0+0,3
1,000920ad0b612851f8e01bcc880d9b3d,karolinska,0,0+0,0
2,0018ae58b01bdadc8e347995b69f99aa,radboud,4,4+4,2
3,001c62abd11fa4b57bf7a6c603a11bb9,karolinska,4,4+4,3
4,001d865e65ef5d2579c190a0e0350d8f,karolinska,0,0+0,3


In [4]:
pretrained_model = {'efficientnet-b0': 'efficientnet-b0-08094119.pth'}

class enetv2(nn.Module):
    def __init__(self, backbone, out_dim):
        super(enetv2, self).__init__()
        self.enet = enet.EfficientNet.from_name(backbone)
        self.enet.load_state_dict(torch.load(pretrained_model[backbone]))

        self.myfc = nn.Linear(self.enet._fc.in_features, out_dim)
        self.enet._fc = nn.Identity()

    def extract(self, x):
        return self.enet(x)

    def forward(self, x):
        x = self.extract(x)
        x = self.myfc(x)
        return x 

In [5]:
class Dataset(Dataset):
    def __init__(self,
                 df,
                 transform=None,
                ):

        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        
        file_name = self.df['image_id'].values[index]
        file_path = f'train_images/{file_name}.tiff'
        image = skimage.io.MultiImage(file_path)
        image = cv2.resize(image[-1], (height, width))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image'] 
            
        image = image.astype(np.float32)
        image /= 255
        image = image.transpose(2, 0, 1)
        row = self.df.iloc[index] 

        label = np.zeros(5).astype(np.float32)
        label[:row.isup_grade] = 1.

        
        return torch.tensor(image), torch.tensor(label)

In [6]:
transforms_train = albumentations.Compose([
    albumentations.Transpose(p=0.5),
    albumentations.VerticalFlip(p=0.5),
    albumentations.HorizontalFlip(p=0.5),
])
transforms_val = albumentations.Compose([])

In [7]:
criterion = nn.BCEWithLogitsLoss() 

In [8]:
def train_epoch(loader, optimizer):

    model.train()
    train_loss = []
    bar = tqdm(loader)
    for (data, target) in bar:
        
        data, target = data.to(device), target.to(device)
        loss_func = criterion
        optimizer.zero_grad()
        logits = model(data)
        loss = loss_func(logits, target)
        loss.backward()
        optimizer.step()

        loss_np = loss.detach().cpu().numpy()
        train_loss.append(loss_np)
        smooth_loss = sum(train_loss[-100:]) / min(len(train_loss), 100)
        bar.set_description('loss: %.5f, smth: %.5f' % (loss_np, smooth_loss))
    return train_loss


def val_epoch(loader, get_output=False):

    model.eval()
    val_loss = []
    LOGITS = []
    PREDS = []
    TARGETS = []

    with torch.no_grad():
        for (data, target) in tqdm(loader):
            data, target = data.to(device), target.to(device)
            logits = model(data)

            loss = criterion(logits, target)

            pred = logits.sigmoid().sum(1).detach().round()
            LOGITS.append(logits)
            PREDS.append(pred)
            TARGETS.append(target.sum(1))

            val_loss.append(loss.detach().cpu().numpy())
        val_loss = np.mean(val_loss)

    LOGITS = torch.cat(LOGITS).cpu().numpy()
    PREDS = torch.cat(PREDS).cpu().numpy()
    TARGETS = torch.cat(TARGETS).cpu().numpy()
    acc = (PREDS == TARGETS).mean() * 100.
    
    qwk = cohen_kappa_score(PREDS, TARGETS, weights='quadratic')
    qwk_k = cohen_kappa_score(PREDS[df_valid['data_provider'] == 'karolinska'], df_valid[df_valid['data_provider'] == 'karolinska'].isup_grade.values, weights='quadratic')
    qwk_r = cohen_kappa_score(PREDS[df_valid['data_provider'] == 'radboud'], df_valid[df_valid['data_provider'] == 'radboud'].isup_grade.values, weights='quadratic')
    print('qwk', qwk, 'qwk_k', qwk_k, 'qwk_r', qwk_r)

    if get_output:
        return PREDS
    else:
        return val_loss, acc, qwk

In [9]:
train_idx = np.where((df_train['fold'] != fold))[0]
valid_idx = np.where((df_train['fold'] == fold))[0]

df_this  = df_train.loc[train_idx]
df_valid = df_train.loc[valid_idx]

dataset_train = Dataset(df_this , transform=transforms_train)
dataset_valid = Dataset(df_valid, transform=transforms_val)

train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, sampler=RandomSampler(dataset_train), num_workers=num_workers)
valid_loader = torch.utils.data.DataLoader(dataset_valid, batch_size=batch_size, sampler=SequentialSampler(dataset_valid), num_workers=num_workers)

model = enetv2(enet_type, out_dim=out_dim)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=init_lr/warmup_factor)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs-warmup_epo)
scheduler = GradualWarmupScheduler(optimizer, multiplier=warmup_factor, total_epoch=warmup_epo, after_scheduler=scheduler_cosine)

print(len(dataset_train), len(dataset_valid)) 

7962 2654


In [10]:
qwk_max = 0.
best_file = f'{kernel_type}_best_fold{fold}.pth'
for epoch in range(1, n_epochs+1):
    print(time.ctime(), 'Epoch:', epoch)
    scheduler.step(epoch-1)

    train_loss = train_epoch(train_loader, optimizer)
    val_loss, acc, qwk = val_epoch(valid_loader)

    content = time.ctime() + ' ' + f'Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, train loss: {np.mean(train_loss):.5f}, val loss: {np.mean(val_loss):.5f}, acc: {(acc):.5f}, qwk: {(qwk):.5f}'
    print(content)

    if qwk > qwk_max:
        print('score progress: ({:.6f} --> {:.6f}).  Saving model ...'.format(qwk_max, qwk))
        torch.save(model.state_dict(), best_file)
        qwk_max = qwk

torch.save(model.state_dict(), os.path.join(f'{kernel_type}_final_fold{fold}.pth'))

Sun Sep 13 20:25:25 2020 Epoch: 1


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.4819261051067212 qwk_k 0.34680529256466996 qwk_r 0.3656405538566062
Sun Sep 13 20:30:39 2020 Epoch 1, lr: 0.0000300, train loss: 0.52151, val loss: 0.47092, acc: 25.01884, qwk: 0.48193
score progress: (0.000000 --> 0.481926).  Saving model ...
Sun Sep 13 20:30:39 2020 Epoch: 2


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.5036183085623176 qwk_k 0.35744682837613073 qwk_r 0.45183965606419907
Sun Sep 13 20:35:45 2020 Epoch 2, lr: 0.0003000, train loss: 0.50722, val loss: 0.45388, acc: 23.81311, qwk: 0.50362
score progress: (0.481926 --> 0.503618).  Saving model ...
Sun Sep 13 20:35:45 2020 Epoch: 3


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.5476479370146539 qwk_k 0.4354038151045354 qwk_r 0.5113711620566725
Sun Sep 13 20:40:51 2020 Epoch 3, lr: 0.0000300, train loss: 0.44145, val loss: 0.41985, acc: 27.39261, qwk: 0.54765
score progress: (0.503618 --> 0.547648).  Saving model ...
Sun Sep 13 20:40:51 2020 Epoch: 4


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.5570992042884364 qwk_k 0.45668166242873387 qwk_r 0.5644725053165389
Sun Sep 13 20:45:58 2020 Epoch 4, lr: 0.0002919, train loss: 0.46192, val loss: 0.45150, acc: 25.13188, qwk: 0.55710
score progress: (0.547648 --> 0.557099).  Saving model ...
Sun Sep 13 20:45:58 2020 Epoch: 5


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.5630261005687617 qwk_k 0.45675425725446606 qwk_r 0.5168237776621609
Sun Sep 13 20:51:05 2020 Epoch 5, lr: 0.0002819, train loss: 0.44371, val loss: 0.42694, acc: 23.85079, qwk: 0.56303
score progress: (0.557099 --> 0.563026).  Saving model ...
Sun Sep 13 20:51:05 2020 Epoch: 6


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.6025634267629834 qwk_k 0.4998793340743358 qwk_r 0.5467398017864538
Sun Sep 13 20:56:11 2020 Epoch 6, lr: 0.0002684, train loss: 0.42857, val loss: 0.43184, acc: 32.89375, qwk: 0.60256
score progress: (0.563026 --> 0.602563).  Saving model ...
Sun Sep 13 20:56:11 2020 Epoch: 7


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.5720576120815613 qwk_k 0.41925088119357634 qwk_r 0.5260150196128381
Sun Sep 13 21:01:16 2020 Epoch 7, lr: 0.0002516, train loss: 0.42570, val loss: 0.45017, acc: 37.00075, qwk: 0.57206
Sun Sep 13 21:01:16 2020 Epoch: 8


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.5665830445186295 qwk_k 0.47060051039966566 qwk_r 0.5350775482322856
Sun Sep 13 21:06:20 2020 Epoch 8, lr: 0.0002320, train loss: 0.41367, val loss: 0.41228, acc: 27.65637, qwk: 0.56658
Sun Sep 13 21:06:20 2020 Epoch: 9


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.65052497182875 qwk_k 0.5794092921169917 qwk_r 0.5951722507843877
Sun Sep 13 21:11:27 2020 Epoch 9, lr: 0.0002103, train loss: 0.40894, val loss: 0.42182, acc: 38.01809, qwk: 0.65052
score progress: (0.602563 --> 0.650525).  Saving model ...
Sun Sep 13 21:11:27 2020 Epoch: 10


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.6238706397308283 qwk_k 0.5370580419458459 qwk_r 0.5893079936053456
Sun Sep 13 21:16:35 2020 Epoch 10, lr: 0.0001868, train loss: 0.39871, val loss: 0.39749, acc: 33.79804, qwk: 0.62387
Sun Sep 13 21:16:35 2020 Epoch: 11


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.6493123553758984 qwk_k 0.5780320668288713 qwk_r 0.6064203545859794
Sun Sep 13 21:21:43 2020 Epoch 11, lr: 0.0001624, train loss: 0.38975, val loss: 0.40709, acc: 34.32555, qwk: 0.64931
Sun Sep 13 21:21:43 2020 Epoch: 12


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.6484656473839918 qwk_k 0.536598452573857 qwk_r 0.5879432015529804
Sun Sep 13 21:26:50 2020 Epoch 12, lr: 0.0001376, train loss: 0.38160, val loss: 0.40591, acc: 38.99774, qwk: 0.64847
Sun Sep 13 21:26:50 2020 Epoch: 13


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.6332652279316031 qwk_k 0.5734165275647372 qwk_r 0.5847939216695304
Sun Sep 13 21:31:58 2020 Epoch 13, lr: 0.0001132, train loss: 0.37411, val loss: 0.39347, acc: 36.88772, qwk: 0.63327
Sun Sep 13 21:31:58 2020 Epoch: 14


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.6100712491411224 qwk_k 0.5434175916975361 qwk_r 0.566758249568593
Sun Sep 13 21:37:06 2020 Epoch 14, lr: 0.0000897, train loss: 0.36629, val loss: 0.39760, acc: 35.53127, qwk: 0.61007
Sun Sep 13 21:37:06 2020 Epoch: 15


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.6481946146828325 qwk_k 0.5703890643055626 qwk_r 0.6155164335375545
Sun Sep 13 21:42:12 2020 Epoch 15, lr: 0.0000680, train loss: 0.35540, val loss: 0.39283, acc: 36.05878, qwk: 0.64819
Sun Sep 13 21:42:12 2020 Epoch: 16


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.6472576960638357 qwk_k 0.5952524808400268 qwk_r 0.5832538704665988
Sun Sep 13 21:47:16 2020 Epoch 16, lr: 0.0000484, train loss: 0.34836, val loss: 0.39480, acc: 36.20950, qwk: 0.64726
Sun Sep 13 21:47:16 2020 Epoch: 17


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.659799133931591 qwk_k 0.5749914854969631 qwk_r 0.6044426445674878
Sun Sep 13 21:52:20 2020 Epoch 17, lr: 0.0000316, train loss: 0.34517, val loss: 0.39295, acc: 37.71665, qwk: 0.65980
score progress: (0.650525 --> 0.659799).  Saving model ...
Sun Sep 13 21:52:20 2020 Epoch: 18


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.6589752451024027 qwk_k 0.5814561201272301 qwk_r 0.6157643792750311
Sun Sep 13 21:57:24 2020 Epoch 18, lr: 0.0000181, train loss: 0.34008, val loss: 0.38972, acc: 38.39488, qwk: 0.65898
Sun Sep 13 21:57:24 2020 Epoch: 19


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.6654980146099494 qwk_k 0.596727793731045 qwk_r 0.6166372182751798
Sun Sep 13 22:02:31 2020 Epoch 19, lr: 0.0000081, train loss: 0.33568, val loss: 0.38993, acc: 37.64130, qwk: 0.66550
score progress: (0.659799 --> 0.665498).  Saving model ...
Sun Sep 13 22:02:31 2020 Epoch: 20


HBox(children=(FloatProgress(value=0.0, max=3981.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1327.0), HTML(value='')))


qwk 0.6693748764940592 qwk_k 0.5975209661255165 qwk_r 0.62795096666421
Sun Sep 13 22:07:38 2020 Epoch 20, lr: 0.0000020, train loss: 0.33657, val loss: 0.38825, acc: 38.47023, qwk: 0.66937
score progress: (0.665498 --> 0.669375).  Saving model ...


## Visualizing and understanding the network 
### Forward hooks 

In [11]:
df_valid  = df_train.loc[780].to_frame().transpose()
#df_valid.astype({'col1': 'int32'}).dtypes
df_valid 

Unnamed: 0,image_id,data_provider,isup_grade,gleason_score,fold
780,13aba34105b637fcf77e4efdc4ccaef4,radboud,4,4+4,0


In [12]:
df_valid['isup_grade'] = df_valid['isup_grade'].astype('int')
df_valid.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 1 entries, 780 to 780
Data columns (total 5 columns):
 #   Column         Non-Null Count  Dtype 
---  ------         --------------  ----- 
 0   image_id       1 non-null      object
 1   data_provider  1 non-null      object
 2   isup_grade     1 non-null      int64 
 3   gleason_score  1 non-null      object
 4   fold           1 non-null      object
dtypes: int64(1), object(4)
memory usage: 48.0+ bytes


In [13]:
sys.path = ['effnet_best_fold.pth'] + sys.path
model_dict = {'efficientnet-b0': 'effnet_best_fold.pth'}
model = enetv2(enet_type, out_dim=out_dim)
model.load_state_dict(torch.load(model_dict['efficientnet-b0']))
model.to(device)
model.eval() 

#df_valid  = df_train.loc[780].to_frame().transpose()
dataset_valid = PANDADataset(df_valid, image_size, n_tiles, transform=transforms_val)
valid_loader = torch.utils.data.DataLoader(dataset_valid, num_workers=num_workers) 
val_epoch(valid_loader, get_output=True) 

NameError: name 'PANDADataset' is not defined

In [None]:
sys.path = ['how_to_train_effnet_b0_to_get_LB_0.86_final_fold0.pth'] + sys.path
model_dict = {'efficientnet-b0': 'how_to_train_effnet_b0_to_get_LB_0.86_final_fold0.pth'}
model = enetv2(enet_type, out_dim=out_dim)
model.load_state_dict(torch.load(model_dict['efficientnet-b0']))
model.eval()

visualisation = {}

def hook_fn(m, i, o):
    visualisation[m] = o 

def get_all_layers(net):
    for name, layer in net._modules.items():
        if isinstance(layer, nn.Sequential):
            get_all_layers(layer)
        else:
            layer.register_forward_hook(hook_fn)

get_all_layers(model)

# Just to check whether we got all layers
visualisation.keys()

def preprocess(transform=transforms_train):
    tiff_file = os.path.join(image_folder, '005e66f06bce9c2e49142536caf2f6ee.tiff')
    image = skimage.io.MultiImage(tiff_file)[1]
    tiles, OK = get_tiles(image)
    
    idxes = list(range(n_tiles))

    n_row_tiles = int(np.sqrt(n_tiles))
    images = np.zeros((image_size * n_row_tiles, image_size * n_row_tiles, 3))
    for h in range(n_row_tiles):
        for w in range(n_row_tiles):
            i = h * n_row_tiles + w
    
            if len(tiles) > idxes[i]:
                this_img = tiles[idxes[i]]['img']
            else:
                this_img = np.ones((image_size, image_size, 3)).astype(np.uint8) * 255
            this_img = 255 - this_img
            if transform is not None:
                this_img = transform(image=this_img)['image']
            h1 = h * image_size
            w1 = w * image_size
            images[h1:h1+image_size, w1:w1+image_size] = this_img

    if transform is not None:
        images = transform(image=images)['image']
    images = images.astype(np.float32)
    images /= 255
    images = images.transpose(2, 0, 1)

    label = np.zeros(5).astype(np.float32)
    label[:row.isup_grade] = 1.
    Tensor = torch.tensor(images) 
    Label= torch.tensor(label)
    return Tensor 

out = model(preprocess()) 

def showTensor(Tensor):
    plt.figure()
    plt.imshow(Tensor.numpy())
    plt.colorbar()
    plt.show() 
    
showTensor(out)  

In [None]:
import cv2 as cv

model = model 
print(model)
model_weights = [] # we will save the conv layer weights in this list
conv_layers = [] # we will save the 49 conv layers in this list
# get all the model children as list
model_children = list(model.children())

In [None]:
model_children

In [None]:
# counter to keep count of the conv layers
counter = 0 
# append all the conv layers and their respective weights to the list
for i in range(len(model_children)):
    if type(model_children[i]) == nn.Conv2dStaticSamePadding:
        counter += 1
        model_weights.append(model_children[i].weight)
        conv_layers.append(model_children[i])
    elif type(model_children[i]) == nn.Sequential:
        for j in range(len(model_children[i])):
            for child in model_children[i][j].children():
                if type(child) == nn.Conv2dStaticSamePadding:
                    counter += 1
                    model_weights.append(child.weight)
                    conv_layers.append(child)
print(f"Total convolutional layers: {counter}")

In [None]:
nn.Conv2dStaticSamePadding

## Image occlusions

In [None]:
def occlusion(model, image, label, occ_size = 50, occ_stride = 50, occ_pixel = 0.5):
  
    #get the width and height of the image
    width, height = image.shape[-2], image.shape[-1]
  
    #setting the output image width and height
    output_height = int(np.ceil((height-occ_size)/occ_stride))
    output_width = int(np.ceil((width-occ_size)/occ_stride))
  
    #create a white image of sizes we defined
    heatmap = torch.zeros((output_height, output_width))
    
    #iterate all the pixels in each column
    for h in range(0, height):
        for w in range(0, width):
            
            h_start = h*occ_stride
            w_start = w*occ_stride
            h_end = min(height, h_start + occ_size)
            w_end = min(width, w_start + occ_size)
            
            if (w_end) >= width or (h_end) >= height:
                continue
            
            input_image = image.clone().detach()
            
            #replacing all the pixel information in the image with occ_pixel(grey) 
            #in the specified location
            input_image[:, :, w_start:w_end, h_start:h_end] = occ_pixel
            
            #run inference on modified image
            output = model(input_image)
            output = nn.functional.softmax(output, dim=1)
            prob = output.tolist()[0][label]
            
            #setting the heatmap location to probability value
            heatmap[h, w] = prob 

    return heatmap 

In [None]:
heatmap = occlusion(model, images, pred[0].item(), 32, 14)
imgplot = sns.heatmap(heatmap, xticklabels=False, yticklabels=False, vmax=prob_no_occ)
figure = imgplot.get_figure() 

## Saliency maps

In [None]:
import torch
import torchvision
import torchvision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
from torchsummary import summary
import requests
from PIL import Image

model = model
for param in model.parameters():
    param.requires_grad = False 

In [None]:
def show_img(PIL_IMG):
    plt.imshow(np.asarray(PIL_IMG))

In [None]:
model.eval()
X.requires_grad_()
scores = model(X)
score_max_index = scores.argmax()
score_max = scores[0,score_max_index]
score_max.backward()
saliency, _ = torch.max(X.grad.data.abs(),dim=1)
plt.imshow(saliency[0], cmap=plt.cm.hot)
plt.axis('off')
plt.show() 