In [None]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import models
from skimage.transform import resize

In [None]:
# --- Notebook-wide settings -------------------------------------------------
# Size of the classifier head (used when the final fully connected layer is built).
NUM_CLASSES = 3139
# NOTE(review): FEATURE_EXTRACTING and EPOCHS are not referenced by any visible
# cell; they are kept because other cells may rely on them.
FEATURE_EXTRACTING = False
EPOCHS = 5

# Target size of the resized heatmap — presumably (rows, columns) of the input
# image; TODO confirm against the dataset.
HEIGHT, WIDTH = 512, 2560

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

print(device)

In [None]:
class GradCamModel(nn.Module):
    """Wrap a network and capture its ``layer4`` output together with its gradient.

    A forward hook on ``self.pretrained.layer4`` stores the layer output in
    ``self.selected_out``. When that output takes part in autograd, a tensor
    hook is attached to it so that a subsequent ``backward()`` call stores the
    gradient with respect to the output in ``self.gradients``.

    Args:
        model: Network to wrap. It must expose a ``layer4`` sub-module. When
            ``None`` (the default) a torchvision ``resnet50`` is created with
            ``pretrained=True``.

    Attributes:
        gradients: Gradient captured by the most recent backward pass, or
            ``None`` if no backward pass has run yet.
        selected_out: ``layer4`` output of the most recent forward pass.
        layerhook: Handles of the registered forward hooks.
        tensorhook: Handle of the tensor hook from the most recent forward
            pass (empty when that pass ran without gradient tracking).
    """

    def __init__(self, model=None):
        super().__init__()
        self.gradients = None
        self.tensorhook = []
        self.layerhook = []
        self.selected_out = None

        # Test against None explicitly rather than relying on the truthiness
        # of a module object.
        if model is None:
            # NOTE(review): newer torchvision releases prefer the ``weights=``
            # argument over ``pretrained=True`` — confirm against the
            # installed version before changing.
            model = models.resnet50(pretrained=True)
        self.pretrained = model

        # One registration serves both the user-supplied and the default model.
        self.layerhook.append(
            self.pretrained.layer4.register_forward_hook(self.forward_hook())
        )

        # Gradients must flow through the network for the tensor hook to fire.
        for p in self.pretrained.parameters():
            p.requires_grad = True

    def activations_hook(self, grad):
        """Tensor hook: remember the gradient of the hooked ``layer4`` output."""
        self.gradients = grad

    def get_act_grads(self):
        """Return the gradient stored by the last backward pass (or ``None``)."""
        return self.gradients

    def forward_hook(self):
        """Build the forward hook that records the ``layer4`` output."""
        def hook(module, inp, out):
            self.selected_out = out
            # Keep only the handle of the latest forward pass so the list
            # does not grow with every call.
            self.tensorhook.clear()
            # register_hook raises on tensors that do not require grad
            # (e.g. inside torch.no_grad()), so only attach it when possible.
            if out.requires_grad:
                self.tensorhook.append(out.register_hook(self.activations_hook))
        return hook

    def forward(self, x):
        """Run the wrapped network.

        Returns:
            A tuple ``(out, selected_out)``: the network output and the
            ``layer4`` output captured during this forward pass.
        """
        out = self.pretrained(x)
        return out, self.selected_out


In [None]:
# Build a ResNet-50 and replace its final fully connected layer with one that
# has NUM_CLASSES outputs (presumably matching the head stored in the
# checkpoint — verify against the training code).
model = models.resnet50()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, NUM_CLASSES)

# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
checkpoint = torch.load("pretrainedresnet50_14epoch.tar", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Wrap the network so layer4 activations and their gradients are captured.
gcmodel = GradCamModel(model).to(device)

In [None]:
# Pick the first sample.
# NOTE(review): `dataset` is not defined in any visible cell — it has to exist
# in the kernel before this cell runs.
img = dataset[0]

# The permute moves axis 0 to the end before plotting — assumes img["image"]
# is a (channels, height, width) tensor; TODO confirm.
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.imshow(img["image"].permute(1, 2, 0))

In [None]:
# Class index whose evidence is visualised.
# NOTE(review): 600 was hard-coded in the original cell — confirm it is the
# intended class for this sample.
TARGET_CLASS = 600

# Forward pass on a batch of one; `acts` is the captured layer4 output.
out, acts = gcmodel(img["image"].to(device)[None, :])
acts = acts.detach().cpu()

# Back-propagate the raw score of the target class. Using the cross-entropy
# loss here instead would flip the sign of the target-class contribution, so
# the map would highlight evidence *against* the class.
score = out[0, TARGET_CLASS]
score.backward()

# Channel weights: gradient averaged over the batch and spatial axes.
grads = gcmodel.get_act_grads().detach().cpu()
pooled_grads = torch.mean(grads, dim=[0, 2, 3])

# Weight every channel by its pooled gradient (broadcast over batch and
# spatial axes) without modifying `acts` in place, then average over channels.
weighted_acts = acts * pooled_grads[None, :, None, None]
heatmap_j = torch.mean(weighted_acts, dim=1).squeeze()

# Grad-CAM keeps only positive evidence, then rescales to [0, 1] using the
# global maximum (a per-column maximum would normalise every column
# independently and distort the map).
heatmap_j = torch.relu(heatmap_j)
heatmap_j_max = heatmap_j.max()
if heatmap_j_max > 0:  # avoid 0/0 when the map contains no positive evidence
    heatmap_j = heatmap_j / heatmap_j_max

In [None]:
# Upsample the heatmap to the size given by the HEIGHT and WIDTH constants.
# np.asarray turns the CPU tensor into the NumPy array skimage works on and is
# a no-op if this cell is re-run on an already converted array.
heatmap_j = resize(np.asarray(heatmap_j), (HEIGHT, WIDTH), preserve_range=True)


In [None]:
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap accepts the same (name, lut) arguments in old and new
# versions.
cmap = plt.get_cmap("jet", 256)
# Map the heatmap values to RGBA colours with a fixed opacity of 0.2.
heatmap_j2 = cmap(heatmap_j, alpha=0.2)

In [None]:
# Draw the input image first, then the coloured heatmap on top of it.
fig, axs = plt.subplots(figsize=(15, 10))
axs.imshow(img["image"].permute(1, 2, 0))
axs.imshow(heatmap_j2)
plt.show()