In [13]:
import copy
from pathlib import Path

import torch
import numpy as np
from torchray.attribution.grad_cam import gradient_to_grad_cam_saliency
from torchray.attribution.guided_backprop import GuidedBackpropReLU
from torchray.attribution.common import Probe, get_module
import cv2
import matplotlib.pyplot as plt
from torchvision import transforms

In [27]:
class GuidedBackpropReLUModel:
    """Guided-backpropagation wrapper around a classification model.

    On construction every submodule of ``model`` whose class is named ``ReLU``
    is replaced, in place, by a small ``torch.nn.Module`` that applies
    ``relu_fn`` (torchray's ``GuidedBackpropReLU.apply`` by default). Calling
    the instance returns the gradient of one class score with respect to the
    input.

    Note: ``model`` itself is modified. Pass a copy if the original network is
    still needed with ordinary ReLU gradients.
    """

    class _GuidedReLU(torch.nn.Module):
        """Keeps the replacement a real ``nn.Module``.

        Storing a bare function in ``_modules`` breaks ``model.eval()``,
        ``model.to()`` and ``model.state_dict()``, which all recurse into the
        children and call ``nn.Module`` methods on them.
        """

        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x):
            return self.fn(x)

    def __init__(self, model, relu_fn=None):
        """
        Args:
            model: network to explain; its ReLU submodules are replaced in place.
            relu_fn: autograd function used instead of each ReLU. Defaults to
                ``GuidedBackpropReLU.apply``.
        """
        self.model = model
        if relu_fn is None:
            relu_fn = GuidedBackpropReLU.apply

        def replace_relus(parent):
            # Snapshot the children so replacing entries while iterating is safe.
            for name, child in list(parent._modules.items()):
                if child is None:  # nn.Module allows None placeholders
                    continue
                replace_relus(child)
                if child.__class__.__name__ == 'ReLU':
                    setattr(parent, name, self._GuidedReLU(relu_fn))

        replace_relus(self.model)

    def forward(self, input_img):
        """Run the (guided) model on ``input_img`` and return its raw output."""
        return self.model(input_img)

    def __call__(self, input_img, target_category=None):
        """Return the guided-backprop gradient of one class score w.r.t. the input.

        Args:
            input_img: input batch tensor; only the first sample's gradient is
                returned.
            target_category: class index to explain. ``None`` selects the
                highest-scoring class of the first sample.

        Returns:
            numpy array: gradient for the first sample, i.e. ``input_img.shape[1:]``.
        """
        # Fresh leaf tensor: don't mutate the caller's tensor or accumulate
        # into its ``.grad`` across calls.
        inp = input_img.detach().requires_grad_(True)

        # Allow calling this from inside a ``torch.no_grad()`` block.
        with torch.enable_grad():
            output = self.forward(inp)

            if target_category is None:
                target_category = int(output[0].argmax())

            # Equivalent to sum(one_hot * output), but computed on output's own
            # device. A CPU one-hot tensor fails against a CUDA model output.
            score = output[:, target_category].sum()
            # Only the input gradient is needed; no retained graph, no param grads.
            (grad,) = torch.autograd.grad(score, inp)

        return grad[0].cpu().numpy()

In [28]:
# Trained checkpoint: a whole pickled nn.Module (``.eval()`` is called on it below).
# torch.load unpickles arbitrary objects -- only load checkpoints you trust.
# NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects
# full-module pickles; pass weights_only=False there if this load fails.
MODEL_PATH = Path("../models/label_final/resnet50_d_22_t_12_17.pth")
# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = torch.load(MODEL_PATH, map_location=DEVICE)
model.eval();

In [29]:
# GuidedBackpropReLUModel rewrites the ReLUs of the model it receives, so wrap a
# deep copy. `model` then keeps ordinary ReLU gradients for any other
# attribution method run on it later in this notebook.
a = GuidedBackpropReLUModel(model=copy.deepcopy(model))