In [3]:
import torch
import torch.nn as nn


class gradCAM(nn.Module):
    """Grad-CAM saliency maps for one layer of a classification model.

    The layer is addressed as ``model._modules[target_module]._modules[target_layer]``.
    A forward hook on that layer records its output, and a tensor hook attached to
    that output records the gradient reaching it during the backward pass.

    Args:
        model: network returning class scores of shape [B, num_classes].
        target_module: key of a direct child of `model`.
        target_layer: key of a direct child of that module.
        device: device the model and every input sample are moved to.

    Raises:
        TypeError: if `target_module` or `target_layer` cannot be found.
    """

    def __init__(self, model, target_module, target_layer, device):
        super(gradCAM, self).__init__()
        self.target_activations = list()
        self.target_gradients = list()
        self._register_hook(model, target_module, target_layer)

        self.device = device
        self.model = model.to(self.device)
        self.model.eval()

    def _register_hook(self, model, target_module, target_layer):
        """Validate the target layer and attach the recording hooks to it."""
        if target_module not in model._modules:
            raise TypeError('target module must be in list of modules of model')
        layers = model._modules[target_module]._modules
        if target_layer not in layers:
            raise TypeError('target layer must be in list of layers of module')

        def save_gradient(grad):
            self.target_gradients.append(grad.detach())

        def save_activation(module, input, output):
            self.target_activations.append(output.detach())
            # A tensor hook on the layer output replaces the deprecated
            # Module.register_backward_hook and receives d(score)/d(output).
            if output.requires_grad:
                output.register_hook(save_gradient)

        layers[target_layer].register_forward_hook(save_activation)

    def forward(self, sample):  # sample with batch_size = 1
        """Compute the Grad-CAM map of the top-scoring class.

        Args:
            sample: input tensor of shape [1, C, H, W].

        Returns:
            numpy array of shape [H, W], min-max scaled to [0, 1]. A map with
            no variation is returned as all zeros instead of NaN.

        Raises:
            RuntimeError: if the target layer produced no activation or
                received no gradient during this call.
        """
        sample = sample.to(self.device)  # [1, C, H, W]
        _, _, H, W = sample.shape

        # Drop whatever earlier calls recorded; otherwise the buffers grow
        # without bound and stale entries would be read back.
        self.target_activations.clear()
        self.target_gradients.clear()

        # Gradients are required even when the caller runs under torch.no_grad().
        with torch.enable_grad():
            preds = self.model(sample)  # [1, num_classes]
            categories = torch.argmax(preds, dim=1, keepdim=True)  # [1, 1]
            categories_score = preds.gather(dim=1, index=categories).sum()  # scalar
            self.model.zero_grad()
            categories_score.backward()

        if not self.target_activations or not self.target_gradients:
            raise RuntimeError('target layer did not record an activation and a gradient')

        # If the layer ran more than once, pair its last output with the first
        # gradient recorded (autograd walks the graph backwards).
        target_activation = self.target_activations[-1]  # [1, Cf, Hf, Wf]
        target_gradient = self.target_gradients[0]  # [1, Cf, Hf, Wf]

        weights = torch.mean(target_gradient, dim=(0, 2, 3), keepdim=True)  # [1, Cf, 1, 1]
        saliency_map = torch.sum(target_activation * weights, dim=(0, 1), keepdim=True)  # [1, 1, Hf, Wf]
        saliency_map = nn.functional.relu(saliency_map)  # [1, 1, Hf, Wf]
        saliency_map = nn.functional.interpolate(input=saliency_map, size=(H, W), mode='bilinear', align_corners=False)  # [1, 1, H, W]
        saliency_map = saliency_map.squeeze(dim=0).squeeze(dim=0)  # [H, W]

        lowest = saliency_map.min()
        value_range = saliency_map.max() - lowest
        if value_range > 0:
            saliency_map = (saliency_map - lowest) / value_range
        else:
            # A flat map would divide by zero and turn into NaN.
            saliency_map = torch.zeros_like(saliency_map)

        return saliency_map.detach().cpu().numpy()

In [4]:
# model definition

import torch
import torchvision
import torch.nn as nn


class Resnet18(nn.Module):
    """ResNet-18 feature extractor followed by a fresh linear classifier.

    Args:
        num_classes: number of outputs of the final linear layer.
        pretrained: forwarded to ``torchvision.models.resnet18``.
    """

    def __init__(self, num_classes, pretrained=False):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=pretrained)
        # Keep every child except the last two; the classifier below replaces
        # the backbone's own `fc` head.
        feature_layers = list(backbone.children())[:-2]
        self.resnet18_conv = nn.Sequential(*feature_layers)
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.linear = nn.Linear(
            in_features=backbone.fc.in_features,
            out_features=num_classes,
            bias=True,
        )
        self.linear.bias.data.fill_(0)

    def forward(self, x: torch.Tensor):
        """Return the class scores produced by the linear head for `x`."""
        features = self.resnet18_conv(x)
        pooled = self.avg_pool(features)
        flattened = torch.flatten(pooled, start_dim=1)
        return self.linear(flattened)


In [6]:
# preprocessing definition

def preprocess(image_path, image_size=(224, 224), device='cpu'):
    """Load an image and turn it into a per-image standardized tensor.

    NOTE(review): a later cell of this notebook defines another `preprocess`
    with a different signature, which replaces this one once it has run.

    Args:
        image_path: path handed to ``cv2.imread``.
        image_size: ``dsize`` passed to ``cv2.resize``.
        device: device the tensor is moved to.

    Returns:
        (image, sample): the array as read by OpenCV, and a float tensor of
        shape [1, C, H', W'] with its mean subtracted and divided by its
        standard deviation. Channel order is left as OpenCV delivered it.

    Raises:
        FileNotFoundError: if OpenCV could not read `image_path`.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f'could not read image: {image_path}')

    sample = cv2.resize(image, dsize=image_size)
    sample = torch.from_numpy(sample).to(device).to(torch.float)
    sample = sample.unsqueeze(dim=0)
    sample = sample.permute(0, 3, 1, 2).contiguous()  # [1, H, W, C] -> [1, C, H, W]
    sample = (sample - sample.mean(dim=(1, 2, 3), keepdim=True)) / sample.std(dim=(1, 2, 3), keepdim=True)

    return image, sample

In [7]:
# model initialization

# NOTE(review): the checkpoint path is relative to the working directory at run time.
weight_path = '../../../../VTCC/phungpx/id_info_extraction/models/weights/hole_detection/2103310956/best_model_66_loss=-0.1619.pt'
num_classes = 2
device = 'cpu'
model = Resnet18(num_classes=num_classes)
# map_location='cpu' loads the stored tensors onto the CPU before they are copied into the model.
model.load_state_dict(torch.load(f=weight_path, map_location='cpu'))
model = model.to(device)

In [8]:
import cv2
import numpy as np
from PIL import Image
from torchvision import models, transforms

def preprocess(image_path, image_size=(224, 224), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Load an image and build the normalized, batched tensor for the model.

    Args:
        image_path: path handed to ``cv2.imread``.
        image_size: size passed to ``transforms.Resize``.
        mean: per-channel mean passed to ``transforms.Normalize``.
        std: per-channel standard deviation passed to ``transforms.Normalize``.

    Returns:
        (image, sample): the array as read by OpenCV (not resized), and the
        transformed tensor with a leading batch axis of size 1.

    Raises:
        FileNotFoundError: if OpenCV could not read `image_path`.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f'could not read image: {image_path}')

    # The array is converted from BGR to RGB before it is handed to PIL.
    sample = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    sample = Image.fromarray(sample)
    transform = transforms.Compose([
        transforms.Resize(size=image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)])
    sample = transform(sample).unsqueeze(dim=0)
    return image, sample


# NOTE(review): absolute, machine-specific path — this cell only runs where the file exists.
image_path = '/home/phungpx/Downloads/test/bird.png'
# `image` is the array read by cv2.imread; `sample` is the normalized tensor with a batch axis.
image, sample = preprocess(image_path)
print(sample.shape)

In [35]:
# Rebinds `model`: from here on it is torchvision's ResNet-50, not the Resnet18 loaded earlier.
model = models.resnet50(pretrained=True)
# gradCAM resolves the hooked layer as model._modules['layer4']._modules['2'].
target_module = 'layer4'
target_layer = '2'
visualizer = gradCAM(model=model, target_module=target_module, target_layer=target_layer, device='cpu')

In [36]:
import cv2
# Calling the visualizer runs gradCAM.forward, which returns a min-max normalized [H, W] numpy map.
cam = visualizer(sample)
# cv2.resize takes dsize as (width, height), hence shape[1] before shape[0].
cam = cv2.resize(cam, dsize=(image.shape[1], image.shape[0]))
print(cam.shape)

(1500, 2000)


In [37]:
def show_cam_on_image(image, cam):
    """Overlay a JET-colored rendering of `cam` on `image`.

    Args:
        image: array with values on a 0-255 scale.
        cam: array that is multiplied by 255 and cast to uint8 before coloring.

    Returns:
        uint8 array: the sum of the colored map and the image (each divided
        by 255), rescaled so its largest value becomes 255.
    """
    scaled_image = np.float32(image) / 255
    colored = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    blended = np.float32(colored) / 255 + scaled_image
    # Dividing by the maximum brings the sum back into [0, 1].
    blended = blended / np.max(blended)
    return np.uint8(255 * blended)

In [38]:
# NOTE(review): rebinds `cam` from the float map to the uint8 overlay, so re-running
# this cell alone feeds the overlay back in as if it were the map.
cam = show_cam_on_image(image, cam)

In [39]:
# Shows the overlay in an OpenCV window; waitKey() without a delay blocks until a key is pressed.
cv2.imshow('class activation map', cam)
cv2.waitKey()
cv2.destroyAllWindows()

In [71]:
# NOTE(review): `preds` is not defined in any visible cell — this relies on leftover kernel state.
categories = torch.argmax(preds, dim=1, keepdims=True)
# One-hot mask with a 1 at each row's arg-max class.
batch_onehot = torch.zeros(size=preds.shape, dtype=torch.float, device='cpu')
batch_onehot.scatter_(dim=1, index=categories, value=1)
batch_onehot = batch_onehot.requires_grad_(requires_grad=True)
# Summing mask * preds leaves only the selected class scores in the scalar.
score = torch.sum(batch_onehot * preds)

In [72]:
# retain_graph=True keeps the autograd graph so backward can be run on it again.
score.backward(retain_graph=True)

Bottleneck(
  (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
)

grad_input:  <class 'tuple'>
grad_input[0]:  <class 'torch.Tensor'>
grad_output:  <class 'tuple'>
grad_output[0]:  <class 'torch.Tensor'>

grad_input size: torch.Size([1, 2048, 7, 7])
grad_output size: torch.Size([1, 2048, 7, 7])


## RUN

In [1]:
import os
print(f'current dir: {os.getcwd()}')
# NOTE(review): absolute, machine-specific path. Later cells use paths relative to
# this directory (e.g. './modules/gradCAM/config.yaml').
os.chdir('/media/phungpx/WORKSPACE/PHUNGPX/CAMs/')
print(f'changed dir: {os.getcwd()}')

current dir: /media/phungpx/WORKSPACE/PHUNGPX/CAMs/debugs
changed dir: /media/phungpx/WORKSPACE/PHUNGPX/CAMs


In [2]:
import cv2
import utils
import random
from pathlib import Path

### Test

In [3]:
# NOTE(review): `utils.load_yaml` and `utils.create_instance` are project helpers not shown here;
# presumably they parse the YAML file and build the object described under the given key — confirm.
config = utils.load_yaml('./modules/gradCAM/config.yaml')
visualizer = utils.create_instance(config['gradCAM_pretrained'])

In [4]:
# NOTE(review): absolute, machine-specific path.
image_path = '/home/phungpx/Downloads/test/bird.png'
image = cv2.imread(image_path)
# The visualizer is called on the array as read by cv2.imread; its result is unpacked into two values.
grad_cam, heatmap = visualizer(image)

In [5]:
def _resize(image, max_dim=800):
    """Shrink `image` so that its longer side equals `max_dim`, keeping the aspect ratio.

    Images whose longer side is already at most `max_dim` are returned
    unchanged. Square images are handled as well (the strict `h > w` /
    `w > h` comparisons used before let them through unresized).

    Args:
        image: array whose first two axes are (height, width).
        max_dim: largest allowed side length in pixels.

    Returns:
        The original array, or the output of ``cv2.resize``.
    """
    h, w = image.shape[:2]
    if max(h, w) <= max_dim:
        return image

    # cv2.resize expects dsize as (width, height); the shorter side is kept
    # at 1 pixel or more so extreme aspect ratios stay valid.
    if h >= w:
        dsize = (max(1, int(max_dim * w / h)), max_dim)
    else:
        dsize = (max_dim, max(1, int(max_dim * h / w)))
    return cv2.resize(image, dsize=dsize)

# _resize downsizes large images (its max_dim defaults to 800) before they are shown.
cv2.imshow('grad CAM', _resize(heatmap))
cv2.waitKey()
cv2.destroyAllWindows()

### Run

In [10]:
# NOTE(review): this cell fails with the ImportError recorded below it — a relative
# import needs a parent package, and this cell is not running inside one.
# TODO: switch to an absolute import once the package path of `Resnet18` is confirmed.
from ..models.definitions.resnet18 import Resnet18
model = Resnet18(num_classes=2)
# Lists the child-module keys; presumably used to choose the target module/layer — confirm.
print('modules:', model._modules.keys())
print('layers in resnet18_conv:', model._modules['resnet18_conv']._modules.keys())

ImportError: attempted relative import with no known parent package

In [3]:
# Rebinds `visualizer`, this time from the 'gradCAM' entry of the config.
# NOTE(review): `utils.load_yaml` and `utils.create_instance` are project helpers not shown here.
config = utils.load_yaml('./modules/gradCAM/config.yaml')
visualizer = utils.create_instance(config['gradCAM'])

In [None]:
def _resize(image, max_dim=800):
    """Shrink `image` so that its longer side equals `max_dim`, keeping the aspect ratio.

    Images whose longer side is already at most `max_dim` are returned
    unchanged. Square images are handled as well (the strict `h > w` /
    `w > h` comparisons used before let them through unresized).

    Args:
        image: array whose first two axes are (height, width).
        max_dim: largest allowed side length in pixels.

    Returns:
        The original array, or the output of ``cv2.resize``.
    """
    h, w = image.shape[:2]
    if max(h, w) <= max_dim:
        return image

    # cv2.resize expects dsize as (width, height); the shorter side is kept
    # at 1 pixel or more so extreme aspect ratios stay valid.
    if h >= w:
        dsize = (max(1, int(max_dim * w / h)), max_dim)
    else:
        dsize = (max_dim, max(1, int(max_dim * h / w)))
    return cv2.resize(image, dsize=dsize)

image_dir = Path('../../ID_CARD/hole_classification/dataset/test/hole/CMQD_A/')
image_paths = list(image_dir.glob('**/*.*'))
print(len(image_paths))
# Preview ten entries, starting from the fifth path returned by the glob.
for image_path in image_paths[4:14]:
    image = cv2.imread(str(image_path))
    if image is None:
        # '**/*.*' also matches entries OpenCV cannot decode; cv2.imread returns None for them.
        continue
    # Bound to `grad_cam` rather than `gradCAM` so the result does not take over the class name.
    grad_cam, heatmap = visualizer(image)
    cv2.imshow('grad CAM', _resize(heatmap))
    cv2.waitKey()
    cv2.destroyAllWindows()

108
