In [None]:
from inspect import getsource
import math
import json
from PIL import Image
from tqdm.notebook import tqdm

import numpy as np

import torch
from torch import nn
from torchvision import models
from torchvision import transforms

import matplotlib.pyplot as plt
plt.style.use('ggplot')


In [None]:
# Read the lookup table that later cells index with the stringified class
# index (e.g. imagenet_classes['281']) to obtain a human-readable label.
with open('imagenet_classes.json') as classes_file:
    imagenet_classes = json.load(classes_file)

# Left as the cell's final expression so the notebook displays the mapping.
imagenet_classes

In [None]:
# Load the example picture and force it into 3-channel RGB.
# The preprocessing further down normalizes with three per-channel means/stds,
# so a grayscale, palette, CMYK or RGBA file would otherwise fail (or be
# mis-normalized) there. For a file that is already RGB this is a plain copy.
image = Image.open('cat-dog.jpg').convert('RGB')

# Left as the cell's final expression so the notebook renders the picture.
image

In [None]:
class Flatten(nn.Module):
    """Module wrapper that merges every dimension after the first into one.

    Exists so the flattening step can be placed in a list of modules and
    chained with ``nn.Sequential`` like any other layer.
    """

    def forward(self, x):
        """Return ``x`` with dims 1..end collapsed; dim 0 is left untouched."""
        return x.flatten(start_dim=1)
    
    
# Pretrained ResNet-50, switched to evaluation mode straight away.
# NOTE(review): newer torchvision releases reportedly prefer a `weights=`
# argument over `pretrained=True` — confirm against the installed version.
model = models.resnet50(pretrained=True).eval()

# Print the library's own forward implementation so the hand-written
# decomposition below can be compared against it.
print(getsource(model._forward_impl))

# The network expressed as an ordered list of modules, so it can be cut at
# any index: everything up to the cut yields an activation, everything after
# it maps that activation to the class scores.
resnet_stage_names = (
    'conv1', 'bn1', 'relu', 'maxpool',
    'layer1', 'layer2', 'layer3', 'layer4',
    'avgpool',
)
layers = [getattr(model, stage_name) for stage_name in resnet_stage_names]
layers.extend([Flatten(), model.fc])

In [None]:
# Small functional helpers used by the Grad-CAM cell.
softmax = nn.Softmax(dim=1)       # class scores -> probabilities along dim 1
relu = nn.ReLU()
avg = nn.AdaptiveAvgPool2d(1)     # spatial mean, output size 1x1 per channel

# Per-channel statistics passed to Normalize.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# PIL image -> tensor. The crop location is random, so the network input
# (and everything derived from it) can differ from run to run;
# transforms.CenterCrop(224) is the deterministic alternative.
img_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
])

# Kept separate from img_transform so the un-normalized tensor remains
# available for display with imshow.
normalize = transforms.Compose([
    transforms.Normalize(mean=channel_mean, std=channel_std),
])

In [None]:
# Index into `layers` of the module whose output is used as the activation
# map; the network is cut right after it.
layer_index = 7

transformed_image = img_transform(image)
x = normalize(transformed_image).unsqueeze(0)  # add the batch dimension

lower_net = nn.Sequential(*layers[:layer_index + 1])
upper_net = nn.Sequential(*layers[layer_index + 1:])

# The activation is detached and turned into a fresh leaf right below, so no
# autograd graph is needed for the lower part of the network.
with torch.no_grad():
    act = lower_net(x)
act = act.detach().requires_grad_(True)

# Only the upper part is recorded by autograd; gradients therefore stop at `act`.
pred = softmax(upper_net(act))
argmax = pred.argmax(dim=1)
predicted_index = argmax.item()
predicted_label = imagenet_classes[str(predicted_index)]
print(f'pred: {predicted_index}, {predicted_label}')


# Show the exact (un-normalized) crop that was fed to the network.
plt.figure(figsize=[6, 6], dpi=80)
plt.imshow(transformed_image.permute(1, 2, 0))  # CHW -> HWC for imshow
plt.title(f'pred={predicted_label}')
plt.axis('off')
plt.show()


# One map per output class, stacked along axis 0. The spatial size is taken
# from the activation and the class count from the prediction, so neither is
# hard-coded.
num_classes = pred.shape[1]
gradcams = np.zeros([num_classes, *act.shape[2:]])
for class_index in tqdm(range(num_classes)):

    # backward() accumulates into act.grad, so clear what the previous class left.
    if act.grad is not None:
        act.grad.zero_()

    ## the two-liner:
    # retain_graph=True because the same forward graph is reused for every class.
    pred[:, class_index].sum().backward(retain_graph=True)
    # Weight each channel by the spatial mean of its gradient, rectify, then
    # sum over channels.
    # NOTE(review): the published Grad-CAM formulation rectifies AFTER the
    # channel sum and differentiates the pre-softmax score; here ReLU is applied
    # per channel and the softmax output is differentiated. Left unchanged so
    # the numbers stay the same — confirm this variant is intended.
    gradcam = relu(act * avg(act.grad)).sum(dim=1)

    # The batch holds a single image; store its map for this class.
    gradcams[class_index] = gradcam[0].detach().numpy()


In [None]:
# Grid layout of each figure.
nrows = 10
ncols = 10

# Derive the amount of work from the data and not from a fixed 1000.
num_classes = gradcams.shape[0]
maps_per_figure = nrows * ncols
nfigs = math.ceil(num_classes / maps_per_figure)

# Scale the figure with the grid size and the aspect ratio of a single map
# (width follows the map's last axis, height the one before it), then add
# some extra height.
map_height, map_width = gradcams.shape[1:]
figsize = 4 * np.array([ncols, nrows]) * np.array([map_width, map_height]) / min(map_height, map_width)
figsize[1] += 3

# Shared colour scale so maps of different classes are directly comparable.
vmax = gradcams.max()
for fig_index in range(nfigs):
    fig = plt.figure(figsize=figsize, dpi=120)

    # The last figure may be only partially filled when the grid size does
    # not divide the number of classes; never index past the end of gradcams.
    first_class = fig_index * maps_per_figure
    last_class = min(first_class + maps_per_figure, num_classes)

    for slot, class_index in enumerate(range(first_class, last_class), start=1):
        gradcam = gradcams[class_index]
        ax = fig.add_subplot(nrows, ncols, slot)
        ax.imshow(gradcam, vmin=0, vmax=vmax)
        # Title: first synonym of the class label, then the map's peak value.
        ax.set_title(f'{imagenet_classes[str(class_index)].split(",")[0]}\n{gradcam.max().item():.2e}')
        ax.axis('off')

    plt.show()
    # These figures are very large; release each one once it has been shown.
    plt.close(fig)