# Grad-CAM implementation in PyTorch
Source: https://medium.com/@stepanulyanin/implementing-grad-cam-in-pytorch-ea0937c31e82

## Load VGG19 and setup data

In [None]:
import torch
import torch.nn as nn
from torch.utils import data
from torchvision.models import vgg19
from torchvision import transforms
from torchvision import datasets

import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Preprocessing pipeline: resize to 224x224, convert to a tensor, then
# normalise each channel with the mean/std values listed below.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Images are read from ./data/ and passed through the pipeline above.
dataset = datasets.ImageFolder("./data/", transform=transform)

# Serve the images one at a time, in dataset order.
dataloader = data.DataLoader(dataset, batch_size=1, shuffle=False)

In [None]:
class VGG(nn.Module):
    """Pretrained VGG19 split so Grad-CAM can observe its last conv activations.

    The feature extractor is cut at ``features[:36]``; the max-pool that
    follows is re-created as a separate layer so the tensor in between can be
    hooked. On every forward pass that tracks gradients, a hook stores the
    gradient flowing into those activations in ``self.gradients``.
    """

    def __init__(self):
        super().__init__()

        # get pretrained VGG19 network
        self.vgg = vgg19(pretrained=True)

        # dissect the network to access the last convolutional layer
        self.feature_conv = self.vgg.features[:36]

        # the pooling step that is applied after the sliced feature stack
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

        # get the classifier of vgg19
        self.classifier = self.vgg.classifier

        # placeholder for the gradients captured by the hook
        self.gradients = None

    def activations_hook(self, grad):
        """Store ``grad``, the gradient w.r.t. the hooked activations."""
        self.gradients = grad

    def forward(self, x):
        """Run the network on ``x`` and return the classifier output.

        The flatten step expects 512 x 7 x 7 values per sample after pooling.
        """
        x = self.feature_conv(x)

        # Tensor.register_hook raises a RuntimeError on tensors that do not
        # require grad (e.g. inside torch.no_grad()), so only attach the hook
        # when a backward pass is actually possible. In that case
        # self.gradients keeps whatever an earlier pass stored.
        if x.requires_grad:
            x.register_hook(self.activations_hook)

        # apply the remaining pooling
        x = self.max_pool(x)

        # flatten to (batch, 512 * 7 * 7) for the classifier
        x = x.view((-1, 512 * 7 * 7))
        x = self.classifier(x)

        return x

    def get_activations_gradient(self):
        """Return the gradient captured by the most recent backward pass."""
        return self.gradients

    def get_activations(self, x):
        """Return the activations of the sliced feature stack for ``x``."""
        return self.feature_conv(x)

## Evaluate model

In [None]:
# Build the wrapped network and switch it to eval mode so that repeated
# runs give the same prediction.
vgg = VGG().eval()

# Take the first sample from the loader; its label is not used here.
img, _ = next(iter(dataloader))

# Forward pass, then the winning class index and its output value.
pred = vgg(img)
score = pred.max(dim=1)
labelid = pred.argmax(dim=1)

print(f"label id = {labelid.numpy()}    score = {score.values.detach().numpy()}")

## Get map

In [None]:
# get prediction
pred = vgg(img)
labelid = pred.argmax(dim=1)

# back-propagate the score of the predicted class; this triggers the hook
# registered in forward(), which stores the gradient of that score with
# respect to the last convolutional activations
pred[:, labelid[0]].backward()

# pull the gradients out of the model
gradients = vgg.get_activations_gradient()

# one weight per channel: average the gradients over every axis except dim 1
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

# get the activations of the last convolutional layer; no autograd graph is
# needed for this second pass
with torch.no_grad():
    activations = vgg.get_activations(img)

# weight the channels by the corresponding gradients (broadcast over the
# channel axis instead of looping over a hard-coded channel count)
activations *= pooled_gradients.view(1, -1, 1, 1)

# average the channels of the activations
heatmap = torch.mean(activations, dim=1).squeeze()

# ReLU on top of the heatmap (done in torch, the value is a torch tensor)
heatmap = torch.relu(heatmap)

# normalize the heatmap to [0, 1]; an all-zero map is left untouched so the
# division cannot produce NaNs
peak = torch.max(heatmap)
if peak > 0:
    heatmap /= peak

# draw the heatmap
plt.matshow(heatmap.squeeze())

In [None]:
import cv2

img = cv2.imread("./data/elephant/elephant.jpeg")
# cv2.imread signals failure by returning None instead of raising, which would
# otherwise surface as an opaque AttributeError on img.shape below.
if img is None:
    raise OSError("could not read image: ./data/elephant/elephant.jpeg")

# cv2.resize expects the target size as (width, height)
heatmap = cv2.resize(heatmap.numpy(), (img.shape[1], img.shape[0]))

# scale the heatmap to 0..255 integers and colorize it
heatmap = np.uint8(255 * heatmap)
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

# overlay: the colored heatmap at 40% strength added on top of the image
canvas = heatmap * 0.4 + img

In [None]:
# The overlay was built with OpenCV (BGR channel order), so the last axis is
# reversed for matplotlib; dividing by 1.4 and then 255 rescales the values
# for display.
plt.imshow(canvas[..., ::-1] / 1.4 / 255)

## Analyze multiple images

In [None]:
class ImageFolderWithPaths(datasets.ImageFolder):
    """``torchvision.datasets.ImageFolder`` whose items also carry the
    path of the image file they were loaded from.
    """

    def __getitem__(self, index):
        """Return the parent's item for ``index`` with the file path appended.

        This is the method the dataloader calls to fetch each sample.
        """
        # what ImageFolder itself would return for this index
        base_item = super().__getitem__(index)
        # first field of the matching ``imgs`` entry is the file path
        file_path = self.imgs[index][0]
        return (*base_item, file_path)

In [None]:
# Same preprocessing as for the single-image run: resize to 224x224,
# convert to a tensor, then normalise each channel.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Dataset variant that also yields each image's file path; images are
# served one at a time, in dataset order.
dataset = ImageFolderWithPaths("./data/", transform=transform)
dataloader = data.DataLoader(dataset, batch_size=1, shuffle=False)

In [None]:
# Fresh model instance, switched to evaluation mode.
vgg = VGG()
vgg.eval()

In [None]:
# Compute and display a Grad-CAM overlay for every image in the dataloader.
for (img, _, path) in dataloader:
    # get prediction
    pred = vgg(img)
    labelid = pred.argmax(dim=1)

    print(path, labelid)

    # back-propagate the score of the predicted class; this triggers the hook
    # registered in forward(), which stores the gradient of that score with
    # respect to the last convolutional activations
    pred[:, labelid[0]].backward()

    # pull the gradients out of the model
    gradients = vgg.get_activations_gradient()

    # one weight per channel: average the gradients over every axis except dim 1
    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

    # get the activations of the last convolutional layer; no autograd graph
    # is needed for this second pass
    with torch.no_grad():
        activations = vgg.get_activations(img)

    # weight the channels by the corresponding gradients (broadcast over the
    # channel axis instead of looping over a hard-coded channel count)
    activations *= pooled_gradients.view(1, -1, 1, 1)

    # average the channels of the activations
    heatmap = torch.mean(activations, dim=1).squeeze()

    # ReLU on top of the heatmap (done in torch, the value is a torch tensor)
    heatmap = torch.relu(heatmap)

    # normalize the heatmap to [0, 1]; an all-zero map is left untouched so
    # the division cannot produce NaNs
    peak = torch.max(heatmap)
    if peak > 0:
        heatmap /= peak

    # path holds one entry per batch element; the batch size is 1 here
    img = cv2.imread(path[0])
    # cv2.imread signals failure by returning None instead of raising
    if img is None:
        raise OSError(f"could not read image: {path[0]}")

    # cv2.resize expects the target size as (width, height)
    heatmap = cv2.resize(heatmap.numpy(), (img.shape[1], img.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

    # blend: 40% colored heatmap, 60% original image
    canvas = heatmap * 0.4 + img * 0.60

    # one figure per image; reverse the channel axis (BGR -> RGB) for matplotlib
    fig = plt.figure()
    plt.imshow(canvas[:, :, ::-1] / 255)