In [12]:
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from network.resnet import SupConResNet
import torch

In [13]:
# Restore the trained SupConResNet weights that Grad-CAM will be run against.
PATH = "./best_model.pth"
model = SupConResNet(name='resnet18')
# map_location="cpu" allows a checkpoint written from a GPU run to be restored
# on a machine without CUDA; load_state_dict then copies the values into `model`.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
model.load_state_dict(torch.load(PATH, map_location="cpu"))

<All keys matched successfully>

In [14]:
# Display the module tree so the available layers can be inspected.
model

SupConResNet(
  (encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (cbam): CBAM(
          (ChannelGate): ChannelGate(
            (mlp): Sequential(
              (0): Flatten()
              (1): Linear(in_features=64, out_features=4, bias=True)
              (2): ReLU()
              (3): Linear(in_features=4, out_featu

In [15]:
# Print every Conv2d submodule together with its qualified name, which helps
# when choosing the layer to visualise.
for name, layer in model.named_modules():
    if isinstance(layer, torch.nn.Conv2d):
        print(name, layer)

encoder.conv1 Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
encoder.layer1.0.conv1 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
encoder.layer1.0.conv2 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
encoder.layer1.0.cbam.SpatialGate.spatial.conv Conv2d(2, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
encoder.layer1.1.conv1 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
encoder.layer1.1.conv2 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
encoder.layer1.1.cbam.SpatialGate.spatial.conv Conv2d(2, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
encoder.layer2.0.conv1 Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
encoder.layer2.0.conv2 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
encoder.layer2.0.cbam.SpatialGate.spatial.conv Conv2d(

In [16]:
# Take the last block of the encoder's layer4; it is passed to GradCAM as the
# target layer further down.
target_layer = model.encoder.layer4[-1]


In [17]:
# Display the selected block to confirm which module was picked.
target_layer

BasicBlock(
  (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (cbam): CBAM(
    (ChannelGate): ChannelGate(
      (mlp): Sequential(
        (0): Flatten()
        (1): Linear(in_features=512, out_features=32, bias=True)
        (2): ReLU()
        (3): Linear(in_features=32, out_features=512, bias=True)
      )
    )
    (SpatialGate): SpatialGate(
      (compress): ChannelPool()
      (spatial): BasicConv(
        (conv): Conv2d(2, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
        (bn): BatchNorm2d(1, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      )
    )
  )
  (shortcut): Sequential()
)

In [70]:
from torchvision.transforms import Compose, Normalize, ToTensor
import cv2
import numpy as np

def preprocess_image(img: np.ndarray, mean=None, std=None) -> torch.Tensor:
    """Convert an image array into a normalized tensor with a batch dimension.

    Args:
        img: Image array handed to ``ToTensor``.
            (presumably HxWxC float values in [0, 1] - verify against caller)
        mean: Per-channel mean for ``Normalize``; ``[0.5, 0.5, 0.5]`` when None.
        std: Per-channel standard deviation for ``Normalize``;
            ``[0.5, 0.5, 0.5]`` when None.

    Returns:
        The normalized tensor with an extra leading dimension of size 1.
    """
    mean = [0.5, 0.5, 0.5] if mean is None else mean
    std = [0.5, 0.5, 0.5] if std is None else std

    to_normalized_tensor = Compose([ToTensor(), Normalize(mean=mean, std=std)])

    # The transform receives a copy, never the caller's own array.
    normalized = to_normalized_tensor(img.copy())
    return normalized.unsqueeze(0)

In [75]:
# Load the sample image, convert it to RGB floats in [0, 1] and build the
# network input from it.
image_path = "./data/AD_1.png"
dim = (224, 224)

bgr_img = cv2.imread(image_path, cv2.IMREAD_COLOR)
# cv2.imread returns None instead of raising when the file cannot be read;
# fail here with a clear message rather than a TypeError on the slice below.
if bgr_img is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")

# OpenCV loads channels as BGR; reversing the last axis yields RGB.
rgb_img = bgr_img[:, :, ::-1]
rgb_img = cv2.resize(rgb_img, dim)
# Scale the 8-bit pixel values to floats in [0, 1].
rgb_img = np.float32(rgb_img) / 255
input_tensor = preprocess_image(rgb_img, mean=None, std=None)

In [76]:
# Build the Grad-CAM explainer on the selected layer. CUDA is requested only
# when it is actually available, so the cell also runs on CPU-only machines.
cam = GradCAM(model=model, target_layer=target_layer,
              use_cuda=torch.cuda.is_available())
# None is forwarded unchanged to the CAM call; per the pytorch-grad-cam docs
# the library then targets the highest-scoring output.
# NOTE(review): SupConResNet may return an embedding rather than class logits -
# confirm that the arg-max output is a meaningful target here.
target_category = None

In [77]:
# You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category)

# In this example grayscale_cam has only one image in the batch:
grayscale_cam = grayscale_cam[0, :]

# The overlay is shown with cv2.imshow, which interprets arrays as BGR, and
# show_cam_on_image produces its heatmap in OpenCV's BGR order by default.
# Blend the heatmap with a BGR copy of the image so both layers of the overlay
# share one channel order (rgb_img itself stays RGB).
visualization = show_cam_on_image(cv2.cvtColor(rgb_img, cv2.COLOR_RGB2BGR),
                                  grayscale_cam)

In [78]:
# Show the input image and the CAM overlay in OpenCV windows until a key is
# pressed. NOTE(review): this needs a GUI session; it will not work on a
# headless kernel.
try:
    # rgb_img is stored as RGB, but cv2.imshow interprets arrays as BGR, so
    # convert before display to avoid swapped red/blue channels.
    cv2.imshow('input image', cv2.cvtColor(rgb_img, cv2.COLOR_RGB2BGR))
    cv2.imshow('output image', visualization)
    cv2.waitKey(0)
finally:
    # Always close the windows, even if the cell is interrupted while waiting.
    cv2.destroyAllWindows()