In [None]:
from inspect import getsource
import math
import json
from PIL import Image
from tqdm.notebook import tqdm

import numpy as np

import torch
from torch import nn
from torchvision import models
from torchvision import transforms

import matplotlib.pyplot as plt
plt.style.use('ggplot')



class Flatten(nn.Module):
    """Module that collapses every dimension after the first into one.

    Lets a flattening step sit inside a list of modules that is later
    wrapped in ``nn.Sequential``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` with dimensions 1..end merged; dimension 0 is kept."""
        return x.flatten(start_dim=1)
    
    
# ResNet-34 with pretrained weights, switched to evaluation mode.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in
# favour of the `weights=` argument -- confirm against the installed version.
model = models.resnet34(pretrained=True)
model.eval()
# The network's sub-modules listed in forward order so that any prefix/suffix
# of the list can be wrapped in nn.Sequential (see compute_gradcam).
# Index 7 (model.layer4) is compute_gradcam's default split point.
layers = [
    model.conv1,
    model.bn1,
    model.relu,
    model.maxpool,
    model.layer1,
    model.layer2,
    model.layer3,
    model.layer4,
    model.avgpool,
    # Explicit flatten step between the pooling layer and the linear layer;
    # presumably needed because model.fc expects 2-D input.
    Flatten(),
    model.fc
]




# Per-channel normalisation applied to an image tensor right before it is fed
# to the network (used by compute_gradcam). The mean/std values are the ones
# commonly used with ImageNet-pretrained torchvision models.
normalize = transforms.Compose([
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    ),
])

# Tensor -> PIL image -> resize to 224 -> tensor round trip.
# NOTE(review): not referenced anywhere else in the visible code; presumably
# intended for enlarging a low-resolution heatmap -- confirm or remove.
upsample = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(224),
    transforms.ToTensor(),
])


# Stateless helper modules shared by compute_gradcam.
# avg: pools each channel down to a single 1x1 value.
avg = nn.AdaptiveAvgPool2d(1)
# relu: clamps negative values to zero.
relu = nn.ReLU()
# softmax: normalises over dimension 1 (the class dimension of the model output).
softmax = nn.Softmax(dim=1)


def compute_gradcam(image, class_index=0, layers=layers, layer_index=7, shouldRelu=False, normalize_gradcam=False):
    """Compute a Grad-CAM style heatmap for one image and one class.

    The network described by ``layers`` is split after ``layers[layer_index]``.
    The activation at the split point is multiplied channel-wise by the
    spatial average of the gradient of the class score with respect to that
    activation, then summed over the channel dimension.

    Note that the class score used here is the *softmax probability* of
    ``class_index``, not the raw pre-softmax output.

    Args:
        image: Un-normalised image tensor. It is passed through the
            module-level ``normalize`` and given a leading batch dimension
            here. Presumably shaped (3, H, W) -- confirm against callers.
        class_index: Column of the softmax output whose gradient is used.
        layers: Ordered sequence of modules forming the whole network.
            The default is bound to the module-level ``layers`` list at
            definition time.
        layer_index: Index into ``layers`` of the last module executed before
            the activation is captured. The captured activation must be
            accepted by ``nn.AdaptiveAvgPool2d`` (i.e. a feature map, not a
            flattened vector).
        shouldRelu: If true, negative values of the channel-summed map are
            clamped to zero.
        normalize_gradcam: If true, divide the map by its maximum. The map is
            returned unscaled when the maximum is not positive, because the
            division would otherwise produce NaN/inf values or flip signs.

    Returns:
        numpy.ndarray: The heatmap with the channel dimension summed out and
        the batch dimension (size 1) kept as the first axis.
    """
    x = normalize(image).unsqueeze(0)

    # The part of the network before the split never needs gradients: the
    # activation is turned into a fresh leaf tensor right below.
    with torch.no_grad():
        act = nn.Sequential(*layers[:layer_index + 1])(x)
    act = act.detach().requires_grad_(True)

    pred = softmax(nn.Sequential(*layers[layer_index + 1:])(act))

    # `act` is a brand-new leaf, so its .grad starts out empty and a single
    # backward pass (without retaining the graph) is all that is needed.
    pred[:, class_index].sum().backward()

    # One weight per channel: the spatially averaged gradient.
    channel_weights = avg(act.grad)
    gradcam = (act * channel_weights).sum(dim=1)
    if shouldRelu:
        gradcam = relu(gradcam)

    # .cpu() is a no-op for CPU tensors and keeps .numpy() working otherwise.
    gradcam = gradcam.detach().cpu().numpy()
    if normalize_gradcam:
        vmax = gradcam.max()
        if vmax > 0:
            gradcam = gradcam / vmax
    return gradcam



# PIL image -> tensor pipeline used for every input image:
#   1. resize to 256,
#   2. apply a random translation (no rotation) of at most 10/256 of the
#      image size along each axis,
#   3. take the central 224x224 crop,
#   4. convert to a tensor.
# Because of step 2 the result differs from call to call.
# The commented-out entries are alternative augmentations left by the author.
img_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomAffine(0, translate=[10/256, 10/256]),
#     transforms.RandomAffine(0, scale=[0.8,1.5]),
#     transforms.RandomCrop(224),
    transforms.CenterCrop(224),
#     transforms.RandomAffine(0, translate=[5/224, 5/224]),
    transforms.ToTensor(),
])

# Load the example image once. Forcing RGB guarantees the three channels that
# the module-level `normalize` (3-element mean/std) expects, even if the file
# is stored as grayscale, CMYK or with an alpha channel.
image = Image.open('cat-dog.jpg').convert('RGB')
# One randomly translated 224x224 sample of the image (see img_transform).
transformed_image = img_transform(image)

In [None]:
# Visual check: draw ten randomly translated samples of the image with the
# Grad-CAM map for class index 281 overlaid on each.
# NOTE(review): 281 is presumably the ImageNet "tabby cat" class -- confirm.
# `transformed_image` and `gradcam` are module-level names; the values left
# after the last iteration are what the following cells operate on.
if __name__ == '__main__':
    
    for _ in range(10):
        # img_transform is random, so each iteration yields a different sample.
        transformed_image = img_transform(image)
        gradcam = compute_gradcam(transformed_image, 281, shouldRelu=False,)
        
        plt.figure(figsize=[10,4])
#         plt.subplot(121)
        # Channels-first tensor -> channels-last layout for imshow.
        plt.imshow(transformed_image.permute(1,2,0))
#         plt.subplot(122)
#         plt.imshow(gradcam[0], vmin=-0.02, vmax=0.02)
        # `extent` stretches the heatmap over the 224x224 pixel coordinates of
        # the image drawn above; alpha keeps the image visible underneath.
        plt.imshow(gradcam[0], alpha=0.8, extent=[0, 224, 224, 0])
        plt.colorbar()
        plt.axis('off')
        plt.show()

In [None]:
# Grad-CAM maps of the current `transformed_image` for class indices 0..999,
# computed with compute_gradcam's defaults (no ReLU, no normalisation) and
# stacked along axis 0. Every call reruns the network from the input.
gradcams = np.concatenate([compute_gradcam(transformed_image, i) for i in tqdm(range(1000))])

In [None]:
# Notebook cell output: smallest and largest value over all stacked maps.
gradcams.min(), gradcams.max()

In [None]:
# Dump the stacked maps as raw float32 values (ndarray.tofile writes no header,
# so shape and dtype must be known by the reader).
# NOTE(review): the file name implies a 1000x7x7 array -- confirm the shape.
gradcams.astype(np.float32).tofile('gradcam_test_1000x7x7.bin')

In [None]:
from torchvision.utils import save_image

In [None]:
# Write the current `transformed_image` tensor (the un-normalised network
# input) to disk as an image file.
save_image(transformed_image, 'image_test.png')