In [13]:
import torch
import torch.nn.functional as F
import cv2
import numpy as np
import torch
import sys
sys.path.insert(0, '../scripts/')
# from models import model_dict
from dataset import Galaxy10DECals
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms
from e2cnn import nn as e2cnn_nn
from cnn import ConvBlock as CNN_ConvBlock
import cnn
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import warnings

def grad_cam(model, image, target_class=None):
    # Store the gradients of the last convolutional layer
    gradients = []
    activations = []

    def backward_hook(module, grad_input, grad_output):
        gradients.append(grad_output[0])

    def forward_hook(module, input, output):
        activations.append(output)

    # Register hooks
    last_conv_layer = list(model.children())[-2]  # assuming the last layer is a fully connected layer
    handle_forward = last_conv_layer.register_forward_hook(forward_hook)
    handle_backward = last_conv_layer.register_backward_hook(backward_hook)

    # Forward pass
    logits = model(image)
    model.zero_grad()

    # If no specific class is targeted, use the max prediction
    if target_class is None:
        target_class = torch.argmax(logits, dim=1).item()

    # Backward pass
    logits[0, target_class].backward()

    # Unregister hooks
    handle_forward.remove()
    handle_backward.remove()

    gradient = gradients[0]  # [C, H, W]
    alpha = torch.mean(gradient, dim=(1, 2), keepdim=True)  # [C, 1, 1]
    weights = alpha / (torch.sum(alpha, dim=0) + 1e-5)

    activation = activations[0]  # [C, H, W]
    grad_cam_map = torch.sum(activation * weights, dim=0)  # [H, W]
    grad_cam_map = F.relu(grad_cam_map).detach()

    return grad_cam_map

# For visualization
# def show_cam_on_image(img, mask):
#     heatmap = cv2.applyColorMap(np.uint8(255 * mask.cpu().numpy()), cv2.COLORMAP_JET)
#     heatmap = np.float32(heatmap) / 255
#     cam = heatmap + np.float32(img.permute(1,2,0).cpu().numpy())
#     cam = cam / np.max(cam)
#     return np.uint8(255 * cam)

def show_cam_on_image(img, mask):
    """Overlay a class-activation heat map on an image.

    Args:
        img: Image tensor that becomes ``[3, H, W]`` after ``squeeze()``
            (e.g. ``[1, 3, H, W]``), with values in ``[0, 255]``.
        mask: 2D tensor with values expected in ``[0, 1]``. It may be coarser
            than the image; it is resized to the image's ``H x W``.

    Returns:
        ``uint8`` numpy array of shape ``[H, W, 3]``: the JET-coloured heat
        map added to the image and rescaled so the brightest value is 255.

    Raises:
        AssertionError: If ``mask`` is not 2D or ``img`` has values outside
            ``[0, 255]``.
    """
    assert len(mask.shape) == 2, "Mask should be 2D"

    # Channels-last numpy view of the image; detach() so tensors that track
    # gradients can be converted as well.
    img_np = img.detach().squeeze().permute(1, 2, 0).cpu().numpy()
    assert img_np.max() <= 255.0 and img_np.min() >= 0.0, "Image should have values between 0 and 255"

    mask_np = mask.detach().cpu().numpy().astype(np.float32)
    height, width = img_np.shape[:2]
    if mask_np.shape != (height, width):
        # A CAM usually has the (small) resolution of a feature map; bring it
        # to the image size. cv2.resize expects dsize as (width, height).
        mask_np = cv2.resize(mask_np, (width, height))

    # Clip before the uint8 cast so out-of-range values saturate, not wrap.
    mask_np = np.uint8(255 * np.clip(mask_np, 0.0, 1.0))

    # NOTE(review): applyColorMap returns channels in BGR order. If `img` is
    # RGB and the result is shown with matplotlib, the heat-map colours will
    # look red/blue swapped -- confirm the intended channel order.
    heatmap = cv2.applyColorMap(mask_np, cv2.COLORMAP_JET)

    # Add CAM heatmap to the original image and normalize to [0, 255].
    combined = np.float32(heatmap) + img_np
    combined = combined / combined.max() * 255

    return np.uint8(combined)

# # Example usage:
# model = ...  # Your trained model
# model.eval()

# image = ...  # Your input image of shape [3, 255, 255] and preferably normalized
# image_tensor = image.unsqueeze(0)  # make it [1, 3, 255, 255]

# cam = grad_cam(model, image_tensor)

# # For visualization
# result = show_cam_on_image(image, cam)

In [3]:
# Open the Galaxy10 DECaLS file (no transform is passed to the dataset) and
# serve it one shuffled sample at a time.
data = Galaxy10DECals('../../data/Galaxy10_DECals.h5')
test_loader = DataLoader(data, shuffle=True, batch_size=1)

# Silence all warnings from this point on.
warnings.filterwarnings('ignore')

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Conversion pipeline; normalization is intentionally left disabled.
transform = transforms.Compose([transforms.ToTensor()])
# transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)

# Take one batch, move the last axis to position 1, cast to float32 and
# resample to 255 x 255 pixels.
images, labels = next(iter(test_loader))
images = F.interpolate(
    images.permute(0, 3, 1, 2).to(torch.float32),
    size=(255, 255),
    mode='bilinear',
    align_corners=True,
)


In [7]:
# Sanity check: show the shape and dtype of the preprocessed batch.
images.shape, images.dtype

(torch.Size([1, 3, 255, 255]), torch.float32)

In [4]:
# Smallest pixel value in the batch. Tensor.min() reduces in one vectorised
# call instead of iterating over every element with Python's built-in min().
images.min()
# converted_img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
# converted_img = cv2.cvtColor(np.array(images), cv2.COLOR_BGR2BGR555)

tensor(0.)

In [9]:
# Architecture name; it is also the stem of the checkpoint file name.
model_str = 'CNN'
model_3 = cnn.load_CNN()
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch has it).
model_3.load_state_dict(
    torch.load(f'../../data/new_icml/new_results/{model_str}.pt', map_location=device)
)
# Switch to inference mode; as the last expression this also displays the
# module structure.
model_3.eval()

GeneralCNN(
  (block1): ConvBlock(
    (conv): Conv2d(3, 12, kernel_size=(3, 3), stride=(2, 2), padding=(2, 2), bias=False)
    (bn): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): ReLU(inplace=True)
  )
  (block2): ConvBlock(
    (conv): Conv2d(12, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): ReLU(inplace=True)
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block3): ConvBlock(
    (conv): Conv2d(24, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): ReLU(inplace=True)
  )
  (block4): ConvBlock(
    (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_s

In [14]:
# Compute the Grad-CAM map of the loaded model for the prepared batch and
# report its shape (show_cam_on_image requires a 2D map).
cam = grad_cam(model_3, images)
print(cam.shape)

# For visualization: blend the map with the input image.
result = show_cam_on_image(images, cam)

torch.Size([192, 4, 4])


AssertionError: Mask should be 2D