In [1]:
import os
import copy
import cv2
cv2.setNumThreads(0) 

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import timm
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder


import albumentations as A
from albumentations.pytorch import ToTensorV2
import pytorch_lightning as pl
import json
from PIL import Image

from resnet50_model import defect_resnet50
from resnet18_model import defect_resnet18
from resnet101_model import defect_resnet101
from vit_model_base import defect_vit

from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter
from skimage.feature import peak_local_max
from sklearn.decomposition import PCA
from scipy.stats import pearsonr

  check_for_updates()


In [2]:
# Define the image format conversion functions
def nhwc_to_nchw(x: torch.Tensor) -> torch.Tensor:
    if x.dim() == 4:
        x = x if x.shape[1] == 3 else x.permute(0, 3, 1, 2)
    elif x.dim() == 3:
        x = x if x.shape[0] == 3 else x.permute(2, 0, 1)
    return x

def nchw_to_nhwc(x: torch.Tensor) -> torch.Tensor:
    if x.dim() == 4:
        x = x if x.shape[3] == 3 else x.permute(0, 2, 3, 1)
    elif x.dim() == 3:
        x = x if x.shape[2] == 3 else x.permute(1, 2, 0)
    return x

# Define the transformation pipeline
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Define inverse transformation pipeline (for visualization)
inv_transform = transforms.Compose([
    transforms.Normalize(mean=[-m / s for m, s in zip([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])], 
                         std=[1 / s for s in [0.229, 0.224, 0.225]]),
])


# Select device
#device = "cpu"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the model
ckpt_path1="results/resnet50/output_1_00/last.ckpt" #path to last lightning checkpoint
checkpoint1 = defect_resnet50.load_from_checkpoint(ckpt_path1)
model1 = checkpoint1.model
model1.to(device)
model1.eval()

ckpt_path2="results/resnet18/output_1_00/last.ckpt" #path to last lightning checkpoint
checkpoint2 = defect_resnet18.load_from_checkpoint(ckpt_path2)
model2 = checkpoint2.model
model2.to(device)
model2.eval()

ckpt_path3="results/resnet101/output_1_00/last.ckpt" #path to last lightning checkpoint
checkpoint3 = defect_resnet101.load_from_checkpoint(ckpt_path3)
model3 = checkpoint3.model
model3.to(device)
model3.eval()

ckpt_path4="results/deit/output_1_00/last.ckpt" #path to last lightning checkpoint
checkpoint4 = defect_vit.load_from_checkpoint(ckpt_path4)
model4 = checkpoint4.model
model4.to(device)
model4.eval()

print('models are loaded')


models are loaded


In [3]:
# Load the dataset
dataset = ImageFolder(root='imagenet/test/', transform=transform)
class_names = dataset.classes  # Get class names

# Define DataLoader with the desired batch size
batch_size = len(dataset)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=24)

# Fetch the first batch
for X, y in dataloader:
    break  # Exit the loop after the first batch

# Print the shape of the images and labels
print(X.dtype, X.shape)  # Should print torch.Size([50, 3, 224, 224])
print(y.shape)  # Should print torch.Size([50])

torch.float32 torch.Size([15404, 3, 224, 224])
torch.Size([15404])


In [4]:
def predict_arm1(img: np.ndarray) -> torch.Tensor:
    img = nhwc_to_nchw(torch.Tensor(img))  # Convert from NHWC to NCHW
    img = img.to(device)
    #model.to(device)
    #img = img.unsqueeze(0).to(device)  # Add batch dimension and move to device
    output = model1(img)  # Forward pass through the model
    return output

def predict_arm2(img: np.ndarray) -> torch.Tensor:
    img = nhwc_to_nchw(torch.Tensor(img))  # Convert from NHWC to NCHW
    img = img.to(device)
    #model.to(device)
    #img = img.unsqueeze(0).to(device)  # Add batch dimension and move to device
    output = model2(img)  # Forward pass through the model
    return output


def predict_arm3(img: np.ndarray) -> torch.Tensor:
    img = nhwc_to_nchw(torch.Tensor(img))  # Convert from NHWC to NCHW
    img = img.to(device)
    #model.to(device)
    #img = img.unsqueeze(0).to(device)  # Add batch dimension and move to device
    output = model3(img)  # Forward pass through the model
    return output

def predict_arm4(img: np.ndarray) -> torch.Tensor:
    img = nhwc_to_nchw(torch.Tensor(img))  # Convert from NHWC to NCHW
    img = img.to(device)
    #model.to(device)
    #img = img.unsqueeze(0).to(device)  # Add batch dimension and move to device
    output = model4(img)  # Forward pass through the model
    return output

def predict_arm5(img: np.ndarray) -> torch.Tensor:
    img = nhwc_to_nchw(torch.Tensor(img))  # Convert from NHWC to NCHW
    img = img.to(device)
    #model.to(device)
    #img = img.unsqueeze(0).to(device)  # Add batch dimension and move to device
    output = model5(img)  # Forward pass through the model
    return output

# Function to generate a Grad-CAM heatmap for Transformer models
def gen_cam(image, mask):
    # Create a heatmap from the Grad-CAM mask
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = 255 - heatmap
    heatmap = np.float32(heatmap) / 255
    

    # Superimpose the heatmap on the original image
    cam = (1 - 0.5) * heatmap + 0.5 * image
    cam = cam / np.max(cam)  # Normalize the result
    return np.uint8(255 * cam)  # Convert to 8-bit image

### GradCAM Class

In [5]:
class Hooks():
    def __init__(self, name, module, backward=False):
        
        self.name = name 
        
        if backward==False:
            self.hook = module.register_forward_hook(self.hook_fn)
        else:
            self.hook = module.register_backward_hook(self.hook_fn)
    
    def hook_fn(self, module, input, output):
        self.input = input
        self.output = output
    
    def close(self):
        self.hook.remove()

class GradCam:
    def __init__(self, model, target):
        self.model = model.eval()  # Set the model to evaluation mode
        self.feature = None  # To store the features from the target layer
        self.gradient = None  # To store the gradients from the target layer
        self.handlers = []  # List to keep track of hooks
        self.target = target  # Target layer for Grad-CAM
        self._get_hook()  # Register hooks to the target layer

    # Hook to get features from the forward pass
    def _get_features_hook(self, module, input, output):
        self.feature = self.reshape_transform(output)  # Store and reshape the output features

    # Hook to get gradients from the backward pass
    def _get_grads_hook(self, module, input_grad, output_grad):
        self.gradient = self.reshape_transform(output_grad)  # Store and reshape the output gradients

        def _store_grad(grad):
            self.gradient = self.reshape_transform(grad)  # Store gradients for later use

        output_grad.register_hook(_store_grad)  # Register hook to store gradients

    # Register forward hooks to the target layer
    def _get_hook(self):
        self.target.register_forward_hook(self._get_features_hook)
        self.target.register_forward_hook(self._get_grads_hook)

    # Function to reshape the tensor for visualization
    def reshape_transform(self, tensor, height=14, width=14):
        result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
        result = result.transpose(2, 3).transpose(1, 2)  # Rearrange dimensions to (C, H, W)
        return result

    # Function to compute the Grad-CAM heatmap
    def __call__(self, inputs):
        self.model.zero_grad()  # Zero the gradients
        output = self.model(inputs)  # Forward pass

        # Get the index of the highest score in the output
        index = np.argmax(output.cpu().data.numpy())
        target = output[0][index]  # Get the target score
        target.backward()  # Backward pass to compute gradients

        # Get the gradients and features
        gradient = self.gradient[0].cpu().data.numpy()
        weight = np.mean(gradient, axis=(1, 2))  # Average the gradients
        feature = self.feature[0].cpu().data.numpy()

        # Compute the weighted sum of the features
        cam = feature * weight[:, np.newaxis, np.newaxis]
        cam = np.sum(cam, axis=0)  # Sum over the channels
        cam = np.maximum(cam, 0)  # Apply ReLU to remove negative values

        # Normalize the heatmap
        cam -= np.min(cam)
        cam /= np.max(cam)
        cam = cv2.resize(cam, (224, 224))  # Resize to match the input image size
        return cam  # Return the Grad-CAM heatmap


# Here we are saving all hooks for all layers into two lists, one for forward hooks and one for backwards hooks.
# Every hook is capable to register both the input and the output (see the above class).
# Note the different directions of the flow.
# For Forward hooks the feature maps flow rightwise:  input ----> layer ----> output
# For Backward hooks the gradients flow leftwise: output <----- layer ----- input

hookF1 = [Hooks(name, layer) for name, layer in model1.named_modules()]
hookB1 = [Hooks(name, layer,backward=True) for name, layer in model1.named_modules()]

hookF2 = [Hooks(name, layer) for name, layer in model2.named_modules()]
hookB2 = [Hooks(name, layer,backward=True) for name, layer in model2.named_modules()]

hookF3 = [Hooks(name, layer) for name, layer in model3.named_modules()]
hookB3 = [Hooks(name, layer,backward=True) for name, layer in model3.named_modules()]


# Layers in Models:

# for name, layer in model1.named_modules():
#     print(name)

# Generating GradCAM and Ensembling

In [10]:
results = []
flag = 0
# For looping and ensembling
for i in range(X.shape[0]):
    
    sampleX = X[i:i+1]
    true_label = [y[i]]
    true_label_value = true_label[0].item()
    img_jpeg = inv_transform(sampleX).squeeze().permute(1, 2, 0).numpy()
    H, W = img_jpeg.shape[0], img_jpeg.shape[1]

    ######### RESNET 50 #########
    act, grad = None, None
    mask1=0
    heatmap1=0  
    out1 = predict_arm1(sampleX)
    classes1 = torch.argmax(out1, axis=1).cpu().numpy()
    out1[:, true_label[0].item()].backward()
    # Loop over hook lists hookF, hookB and extract the activation and gradient of the resnet.layer4.1.conv2 layer
    for hook in hookF1:
        if hook.name == 'layer4.2.conv3':
            act = hook.output.detach()
    
    for hook in hookB1:
        if hook.name == 'layer4.2.conv3':
            grad = hook.output[0].detach()

        
    # Global Average Pool the gradients of every feature map (Eq. 1)
    pooled_grad1 = torch.mean(grad, dim=[0, 2, 3])   
    # Multiply every channel k of the activations by GAP gradients, that are the 'weights'
    for j in range(act.shape[1]):
        act[:, j, :, :] *= pooled_grad1[j]   
    # Sum across all k-features (Eq. 2)
    heatmap1 = torch.sum(act, dim=1).squeeze()
    # Relu (Eq. 2)
    heatmap1 = np.maximum(heatmap1, 0)
    # normalize the heatmap between 0 and 1
    heatmap1 /= torch.max(heatmap1)
    heatmap1 = cv2.resize(np.array(heatmap1), (W, H))
  

    ######### RESNET 18 #########
    act, grad = None, None
    mask2=0
    heatmap2=0
    out2 = predict_arm2(sampleX)
    classes2 = torch.argmax(out2, axis=1).cpu().numpy()
    out2[:, true_label[0].item()].backward()
    # Loop over hook lists hookF, hookB and extract the activation and gradient of the resnet.layer4.1.conv2 layer
    for hook in hookF2:
        if hook.name == 'layer4.1.conv2':
            act = hook.output.detach()
    
    for hook in hookB2:
        if hook.name == 'layer4.1.conv2':
            grad = hook.output[0].detach()

    # Global Average Pool the gradients of every feature map (Eq. 1)
    pooled_grad2 = torch.mean(grad, dim=[0, 2, 3])   
    # Multiply every channel k of the activations by GAP gradients, that are the 'weights'
    for j in range(act.shape[1]):
        act[:, j, :, :] *= pooled_grad2[j]   
    # Sum across all k-features (Eq. 2)
    heatmap2 = torch.sum(act, dim=1).squeeze()
    # Relu (Eq. 2)
    heatmap2 = np.maximum(heatmap2, 0)
    # normalize the heatmap between 0 and 1
    heatmap2 /= torch.max(heatmap2)
    heatmap2 = cv2.resize(np.array(heatmap2), (W, H))


    ######### RESNET 101 #########
    act, grad = None, None
    mask3=0
    heatmap3=0
    out3 = predict_arm3(sampleX)
    classes3 = torch.argmax(out3, axis=1).cpu().numpy()
    out3[:, true_label[0].item()].backward()
    # Loop over hook lists hookF, hookB and extract the activation and gradient of the resnet.layer4.1.conv2 layer
    for hook in hookF3:
        if hook.name == 'layer4.2.conv3':
            act = hook.output.detach()

    for hook in hookB3:
        if hook.name == 'layer4.2.conv3':
            grad = hook.output[0].detach()

    # Global Average Pool the gradients of every feature map (Eq. 1)
    pooled_grad3 = torch.mean(grad, dim=[0, 2, 3])   
    # Multiply every channel k of the activations by GAP gradients, that are the 'weights'
    for j in range(act.shape[1]):
        act[:, j, :, :] *= pooled_grad3[j]   
    # Sum across all k-features (Eq. 2)
    heatmap3 = torch.sum(act, dim=1).squeeze()
    # Relu (Eq. 2)
    heatmap3 = np.maximum(heatmap3, 0)
    # normalize the heatmap between 0 and 1
    heatmap3 /= torch.max(heatmap3)
    heatmap3 = cv2.resize(np.array(heatmap3), (W, H))


    ###### DEiT Model #########
    img = sampleX
    out4 = predict_arm4(sampleX)
    classes4 = torch.argmax(out4, axis=1).cpu().numpy()
    img_jpeg = inv_transform(sampleX).squeeze().permute(1, 2, 0).numpy()
    # Load the Vision Transformer and target layer
    target_layer = model4.blocks[-1].norm1  # Specify the target layer for Grad-CAM
    
    # Initialize Grad-CAM with the model and target layer
    grad_cam = GradCam(model4, target_layer)
    heatmap4 = grad_cam(img)  # Compute the Grad-CAM mask

    # Check if any of the heatmaps contain NaN values
    if np.isnan(heatmap1).any() or np.isnan(heatmap2).any() or np.isnan(heatmap3).any() or np.isnan(heatmap4).any():
        flag = flag + 1
        continue 
    
    # Ensemble Calculations
    # # Calculate differences
    # d1_2= heatmap1 - heatmap2
    # d1_3= heatmap1 - heatmap3
    # d2_3= heatmap3 - heatmap2
    
    # # Mean Absolute Error
    # mae1_2 = np.mean(np.abs(d1_2))
    # mae1_3 = np.mean(np.abs(d1_3))
    # mae2_3 = np.mean(np.abs(d2_3))
    
    # # Root Mean Squared Error
    # rmse1_2 = np.sqrt(np.mean(d1_2**2))
    # rmse1_3 = np.sqrt(np.mean(d1_3**2))
    # rmse2_3 = np.sqrt(np.mean(d2_3**2))
    
    smoothed1 = gaussian_filter(heatmap1, sigma=3)
    smoothed2 = gaussian_filter(heatmap2, sigma=3)
    smoothed3 = gaussian_filter(heatmap3, sigma=3)
    smoothed4 = gaussian_filter(heatmap4, sigma=3)
    
    peaks1 = peak_local_max(smoothed1, min_distance=3)
    peaks2 = peak_local_max(smoothed2, min_distance=3)
    peaks3 = peak_local_max(smoothed3, min_distance=3)
    peaks4 = peak_local_max(smoothed4, min_distance=3)

    num_peaks_RESNET50 = peaks1.shape[0]
    num_peaks_RESNET18 = peaks2.shape[0]
    num_peaks_RESNET101 = peaks3.shape[0]
    num_peaks_DEIT = peaks4.shape[0]
    
    # Flatten the matrices
    flat1 = heatmap1.flatten()
    flat2 = heatmap2.flatten()
    flat3 = heatmap3.flatten()
    flat4 = heatmap4.flatten()
    
    # Calculate Pearson correlations
    corr_1_2, _ = pearsonr(flat1, flat2)
    corr_1_3, _ = pearsonr(flat1, flat3)
    corr_1_4, _ = pearsonr(flat1, flat4)
    corr_2_3, _ = pearsonr(flat2, flat3)
    corr_2_4, _ = pearsonr(flat2, flat4)
    corr_3_4, _ = pearsonr(flat3, flat4)
    

    # Collect results for this iteration
    results.append([
        true_label_value,  # true label
        classes1[0],    # classes from RESNET50
        classes2[0],    # classes from RESNET18
        classes3[0],    # classes from RESNET101
        classes4[0],    # classes from DEiT
        num_peaks_RESNET50, num_peaks_RESNET18, num_peaks_RESNET101, num_peaks_DEIT,  # Number of peaks
        corr_1_2, corr_1_3, corr_1_4, corr_2_3, corr_2_4, corr_3_4 #PearsonR correlation
    ])
    #print(f"Count: {i}")

# Convert results to a DataFrame
columns = [
    'True Label', 'Classes1', 'Classes2', 'Classes3', 'Classes4',
    'Num Peaks RESNET50', 'Num Peaks RESNET18', 'Num Peaks RESNET101', 'Num Peaks DEIT',
    'corr_1_2', 'corr_1_3', 'corr_1_4', 'corr_2_3', 'corr_2_4', 'corr_3_4'
]

df = pd.DataFrame(results, columns=columns)

# Print the table
#print(df)

# Save to CSV
df.to_csv('results.csv', index=False)

print("The run is complete")
print(flag)


Count: 1
Count: 2
Count: 3
Count: 4
Count: 5
Count: 6
Count: 7
Count: 8
Count: 13
Count: 14
Count: 16
Count: 17
Count: 18
Count: 19
Count: 20
Count: 21
Count: 22
Count: 23
Count: 25
Count: 26
Count: 27
Count: 29
Count: 31
Count: 32
Count: 33
Count: 34
Count: 35
Count: 36
Count: 37
Count: 38
Count: 39
Count: 42
Count: 43
Count: 44
Count: 46
Count: 47
Count: 48
Count: 49
Count: 50
Count: 51
Count: 52
Count: 53
Count: 55
Count: 56
Count: 57
Count: 58
Count: 60
Count: 62
Count: 63
Count: 65
Count: 68
Count: 70
Count: 71
Count: 73
Count: 76
Count: 78
Count: 79
Count: 81
Count: 87
Count: 88
Count: 89
Count: 90
Count: 91
Count: 92
Count: 93
Count: 94
Count: 95
Count: 97
Count: 98
Count: 99
Count: 100
Count: 101
Count: 105
Count: 106
Count: 109
Count: 110
Count: 111
Count: 112
Count: 113
Count: 116
Count: 117
Count: 118
Count: 119
Count: 120
Count: 123
Count: 124
Count: 125
Count: 126
Count: 127
Count: 128
Count: 129
Count: 130
Count: 131
Count: 132
Count: 133
Count: 134
Count: 136
Count: 137


  cam /= np.max(cam)


Count: 14367
Count: 14369
Count: 14372
Count: 14374
Count: 14376
Count: 14379
Count: 14380
Count: 14381
Count: 14382
Count: 14385
Count: 14388
Count: 14389
Count: 14390
Count: 14392
Count: 14393
Count: 14397
Count: 14401
Count: 14409
Count: 14410
Count: 14412
Count: 14414
Count: 14415
Count: 14418
Count: 14420
Count: 14425
Count: 14426
Count: 14427
Count: 14428
Count: 14429
Count: 14434
Count: 14436
Count: 14437
Count: 14439
Count: 14440
Count: 14442
Count: 14443
Count: 14444
Count: 14445
Count: 14447
Count: 14449
Count: 14460
Count: 14461
Count: 14462
Count: 14469
Count: 14471
Count: 14472
Count: 14473
Count: 14475
Count: 14476
Count: 14477
Count: 14482
Count: 14485
Count: 14487
Count: 14488
Count: 14489
Count: 14490
Count: 14493
Count: 14500
Count: 14503
Count: 14504
Count: 14506
Count: 14507
Count: 14510
Count: 14511
Count: 14513
Count: 14514
Count: 14515
Count: 14518
Count: 14519
Count: 14520
Count: 14536
Count: 14538
Count: 14539
Count: 14540
Count: 14541
Count: 14545
Count: 14550

In [11]:
X.shape

torch.Size([15404, 3, 224, 224])

# Anecdotal Example Generation

In [13]:
import random
random_indices = random.sample(range(X.shape[0]), 200)
label_dirs = {
    0: 'mul_pick',
    1: 'nom',
    2: 'pkg_def'
}

In [16]:
# For looping and ensembling
img_dir = 'results/gradCAM/anecdotes'
for i in random_indices: #range(X.shape[0]):

    
    sampleX = X[i:i+1]
    true_label = [y[i]]
    true_label_value = true_label[0].item()
    img_jpeg = inv_transform(sampleX).squeeze().permute(1, 2, 0).numpy()
    H, W = img_jpeg.shape[0], img_jpeg.shape[1]

    
    ######### RESNET 50 #########
    act, grad = None, None
    mask1=0
    heatmap1=0  
    out1 = predict_arm1(sampleX)
    classes1 = torch.argmax(out1, axis=1).cpu().numpy()
    out1[:, true_label[0].item()].backward()
    # Loop over hook lists hookF, hookB and extract the activation and gradient of the resnet.layer4.1.conv2 layer
    for hook in hookF1:
        if hook.name == 'layer4.2.conv3':
            act = hook.output.detach()
    
    for hook in hookB1:
        if hook.name == 'layer4.2.conv3':
            grad = hook.output[0].detach()

        
    # Global Average Pool the gradients of every feature map (Eq. 1)
    pooled_grad1 = torch.mean(grad, dim=[0, 2, 3])   
    # Multiply every channel k of the activations by GAP gradients, that are the 'weights'
    for j in range(act.shape[1]):
        act[:, j, :, :] *= pooled_grad1[j]   
    # Sum across all k-features (Eq. 2)
    heatmap1 = torch.sum(act, dim=1).squeeze()
    # Relu (Eq. 2)
    heatmap1 = np.maximum(heatmap1, 0)
    # normalize the heatmap between 0 and 1
    heatmap1 /= torch.max(heatmap1)
    heatmap1 = cv2.resize(np.array(heatmap1), (W, H))
    mask1 = heatmap1>0.5 # Threshold
  

    ######### RESNET 18 #########
    act, grad = None, None
    mask2=0
    heatmap2=0
    out2 = predict_arm2(sampleX)
    classes2 = torch.argmax(out2, axis=1).cpu().numpy()
    out2[:, true_label[0].item()].backward()
    # Loop over hook lists hookF, hookB and extract the activation and gradient of the resnet.layer4.1.conv2 layer
    for hook in hookF2:
        if hook.name == 'layer4.1.conv2':
            act = hook.output.detach()
    
    for hook in hookB2:
        if hook.name == 'layer4.1.conv2':
            grad = hook.output[0].detach()

    # Global Average Pool the gradients of every feature map (Eq. 1)
    pooled_grad2 = torch.mean(grad, dim=[0, 2, 3])   
    # Multiply every channel k of the activations by GAP gradients, that are the 'weights'
    for j in range(act.shape[1]):
        act[:, j, :, :] *= pooled_grad2[j]   
    # Sum across all k-features (Eq. 2)
    heatmap2 = torch.sum(act, dim=1).squeeze()
    # Relu (Eq. 2)
    heatmap2 = np.maximum(heatmap2, 0)
    # normalize the heatmap between 0 and 1
    heatmap2 /= torch.max(heatmap2)
    heatmap2 = cv2.resize(np.array(heatmap2), (W, H))
    mask2 = heatmap2>0.5 # Threshold



    ######### RESNET 101 #########
    act, grad = None, None
    mask3=0
    heatmap3=0
    out3 = predict_arm3(sampleX)
    classes3 = torch.argmax(out3, axis=1).cpu().numpy()
    out3[:, true_label[0].item()].backward()
    # Loop over hook lists hookF, hookB and extract the activation and gradient of the resnet.layer4.1.conv2 layer
    for hook in hookF3:
        if hook.name == 'layer4.2.conv3':
            act = hook.output.detach()

    for hook in hookB3:
        if hook.name == 'layer4.2.conv3':
            grad = hook.output[0].detach()

    # Global Average Pool the gradients of every feature map (Eq. 1)
    pooled_grad3 = torch.mean(grad, dim=[0, 2, 3])   
    # Multiply every channel k of the activations by GAP gradients, that are the 'weights'
    for j in range(act.shape[1]):
        act[:, j, :, :] *= pooled_grad3[j]   
    # Sum across all k-features (Eq. 2)
    heatmap3 = torch.sum(act, dim=1).squeeze()
    # Relu (Eq. 2)
    heatmap3 = np.maximum(heatmap3, 0)
    # normalize the heatmap between 0 and 1
    heatmap3 /= torch.max(heatmap3)
    heatmap3 = cv2.resize(np.array(heatmap3), (W, H))
    mask3 = heatmap3>0.5 # Threshold

    ###### DEiT Small Model #########
    img = sampleX
    out4 = predict_arm4(sampleX)
    classes4 = torch.argmax(out4, axis=1).cpu().numpy()
    img_jpeg = inv_transform(sampleX).squeeze().permute(1, 2, 0).numpy()
    # Load the Vision Transformer and target layer
    target_layer = model4.blocks[-1].norm1  # Specify the target layer for Grad-CAM
    
    # Initialize Grad-CAM with the model and target layer
    grad_cam = GradCam(model4, target_layer)
    heatmap4 = grad_cam(img)  # Compute the Grad-CAM mask


    
    ###### DEiT Base Model #########
    # out5 = predict_arm5(sampleX)
    # classes5 = torch.argmax(out5, axis=1).cpu().numpy()
    # img_jpeg = inv_transform(sampleX).squeeze().permute(1, 2, 0).numpy()
    # # Load the Vision Transformer and target layer
    # target_layer = model5.blocks[-1].norm1  # Specify the target layer for Grad-CAM
    
    # # Initialize Grad-CAM with the model and target layer
    # grad_cam = GradCam(model5, target_layer)
    # heatmap5 = grad_cam(img)  # Compute the Grad-CAM mask

    
    # all superimposed images
    result1 = gen_cam(img_jpeg, heatmap1)
    result2 = gen_cam(img_jpeg, heatmap2)
    result3 = gen_cam(img_jpeg, heatmap3)
    result4 = gen_cam(img_jpeg, heatmap4)
    #result5 = gen_cam(img_jpeg, heatmap5)
    
    ## Image Generation and Saving
    fig, axs = plt.subplots(1,5, figsize=(18,3))
    img1 = axs[0].imshow(img_jpeg)
    axs[0].set_xticks([])  # Turn off x-axis ticks
    axs[0].set_yticks([])  # Turn off y-axis ticks
    axs[0].set_xlabel(f'Original Image: {class_names[y[i]]}')
    
    img2 = axs[1].imshow(result1)
    axs[1].set_xticks([])  # Turn off x-axis ticks
    axs[1].set_yticks([])  # Turn off y-axis ticks
    axs[1].set_xlabel(f'RESNET50: {class_names[classes1[0]]}')
    
    img3 = axs[2].imshow(result2)
    axs[2].set_xticks([])  # Turn off x-axis ticks
    axs[2].set_yticks([])  # Turn off y-axis ticks
    axs[2].set_xlabel(f'RESNET18: {class_names[classes2[0]]}')
    
    img4 = axs[3].imshow(result3)
    axs[3].set_xticks([])  # Turn off x-axis ticks
    axs[3].set_yticks([])  # Turn off y-axis ticks
    axs[3].set_xlabel(f'RESNET101: {class_names[classes3[0]]}')
    
    img5 = axs[4].imshow(result4)
    axs[4].set_xticks([])  # Turn off x-axis ticks
    axs[4].set_yticks([])  # Turn off y-axis ticks
    axs[4].set_xlabel(f'DEiT Small: {class_names[classes4[0]]}')

    # img6 = axs[5].imshow(result5)
    # axs[5].set_xticks([])  # Turn off x-axis ticks
    # axs[5].set_yticks([])  # Turn off y-axis ticks
    # axs[5].set_xlabel(f'DEiT Base: {class_names[classes5[0]]}')
    
    save_subdir = label_dirs.get(true_label_value,'')

    if save_subdir:
        save_path = os.path.join(img_dir, save_subdir)
        os.makedirs(save_path, exist_ok=True)  # Create the directory if it doesn't exist
        plt.savefig(os.path.join(save_path, f'ex_{i}.png'), dpi=300)
    #plt.savefig(os.path.join(img_dir, f'ex_{i}.png'), dpi=300)  # Using f-string to format the filename
    plt.close()  # Close the figure to free up memory
    #plt.show()   
    
    # # Flatten the matrices
    # flat1 = heatmap1.flatten()
    # flat2 = heatmap2.flatten()
    # flat3 = heatmap3.flatten()
    # flat4 = heatmap4.flatten()
    

    # # Calculate Pearson correlations
    # corr_1_2, _ = pearsonr(flat1, flat2)
    # corr_1_3, _ = pearsonr(flat1, flat3)
    # corr_1_4, _ = pearsonr(flat1, flat4)
    # corr_2_3, _ = pearsonr(flat2, flat3)
    # corr_2_4, _ = pearsonr(flat2, flat4)
    # corr_3_4, _ = pearsonr(flat3, flat4)

    #     # Calculate Spearman correlations
    # scorr_1_2, _ = spearmanr(flat1, flat2)
    # scorr_1_3, _ = spearmanr(flat1, flat3)
    # scorr_1_4, _ = spearmanr(flat1, flat4)
    # scorr_2_3, _ = spearmanr(flat2, flat3)
    # scorr_2_4, _ = spearmanr(flat2, flat4)
    # scorr_3_4, _ = spearmanr(flat3, flat4)
    
    # print(f'Pearson correlation between Heatmap 1 and 2: {corr_1_2}')
    # print(f'Pearson correlation between Heatmap 1 and 3: {corr_1_3}')
    # print(f'Pearson correlation between Heatmap 1 and 4: {corr_1_4}')
    # print(f'Pearson correlation between Heatmap 2 and 3: {corr_2_3}')
    # print(f'Pearson correlation between Heatmap 2 and 4: {corr_2_4}')
    # print(f'Pearson correlation between Heatmap 3 and 4: {corr_3_4}')
    
    # print(f'Spearman correlation between Heatmap 1 and 2: {scorr_1_2}')
    # print(f'Spearman correlation between Heatmap 1 and 3: {scorr_1_3}')
    # print(f'Spearman correlation between Heatmap 1 and 4: {scorr_1_4}')
    # print(f'Spearman correlation between Heatmap 2 and 3: {scorr_2_3}')
    # print(f'Spearman correlation between Heatmap 2 and 4: {scorr_2_4}')
    # print(f'Spearman correlation between Heatmap 3 and 4: {scorr_3_4}')

  heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)


# Individual Calcs

In [None]:
# For individual Visualization
for i in range(X.shape[0]):
    sampleX = X[i:i+1]
    out = predict_arm1(sampleX)
    classes = torch.argmax(out, axis=1).cpu().numpy()
    #true_label = []
    true_label = [y[i]]
    #print(true_label[0].item())
    # Backprop from the correct label
    out[:, true_label[0].item()].backward()
    print(f"{'Predicted class':<{20}} True class")
    print(f"{class_names[classes[0]]:<{20}} {class_names[y[i]]}")

    act, grad = None, None
    mask=0
    heatmap=0
    
    # Loop over hook lists hookF, hookB and extract the activation and gradient of the resnet.layer4.1.conv2 layer
    #print('***'*3+'  Forward and Backward Hooks extraction  '+'***'*3)
    for hook in hookF1:
        if hook.name == 'layer4.2.conv3':
            act = hook.output.detach()
            #print('Foward hooks:')
            #print('output of', hook.name)
            #print(act.shape)
            #print('-'*20)
    
    for hook in hookB1:
        if hook.name == 'layer4.2.conv3':
            #print('Backward hooks:')
            grad = hook.output[0].detach()
            #print('gradient of', hook.name)
            #print(grad.shape)
            #print('-'*20)
    
    # Global Average Pool the gradients of every feature map (Eq. 1)
    pooled_grad = torch.mean(grad, dim=[0, 2, 3])
    #print('GAP gradients shape:', pooled_grad.shape)
    
    # Multiply every channel k of the activations by GAP gradients, that are the 'weights'
    for j in range(act.shape[1]):
        act[:, j, :, :] *= pooled_grad[j]
    
    # Sum across all k-features (Eq. 2)
    heatmap = torch.sum(act, dim=1).squeeze()
    
    # Relu (Eq. 2)
    heatmap = np.maximum(heatmap, 0)
    
    # normalize the heatmap between 0 and 1
    heatmap /= torch.max(heatmap)
    #print('Heatmap shape:', heatmap.shape)
    
    # Visualize the raw CAM
    #plt.imshow(heatmap.numpy(), cmap='jet'); plt.axis('off'); plt.tight_layout(); plt.colorbar(); plt.show()
    
    # PLOTTING THE IMAGE, HEATMAP and SUPERIMPOSITION
    
    img_jpeg = inv_transform(sampleX).squeeze().permute(1, 2, 0).numpy()
    #img_jpeg = cv2.cvtColor(img_jpeg, cv2.COLOR_BGR2RGB)
    H, W = img_jpeg.shape[0], img_jpeg.shape[1]
    
    heatmap1 = cv2.resize(np.array(heatmap), (W, H))
    #print(f'Image shape {img_jpeg.shape}. Heatmap shape {heatmap.shape}')
    
    superimpose1 = gen_cam(img_jpeg, heatmap1)  # Generate the Grad-CAM heatmap
    
    fig, axs = plt.subplots(1,3, figsize=(15,5))
    img1 = axs[0].imshow(img_jpeg)
    axs[0].set_title('Original Image')
    img2 = axs[1].imshow(heatmap1, cmap='jet')
    axs[1].set_title('Heatmap')
    img3 = axs[2].imshow(superimpose1)
    axs[2].set_title('SuperImposed')
    #plt.savefig('resnet50_ex_1.png', dpi=300)
    plt.show()

In [None]:
for i in range(X.shape[0]):
    sampleX = X[i:i+1]
    out = predict_arm2(sampleX)
    classes = torch.argmax(out, axis=1).cpu().numpy()
    #true_label = []
    true_label = [y[i]]
    #print(true_label[0].item())
    # Backprop from the correct label
    out[:, true_label[0].item()].backward()
    print(f"{'Predicted class':<{20}} True class")
    print(f"{class_names[classes[0]]:<{20}} {class_names[y[i]]}")

    act, grad = None, None
    mask=0
    heatmap=0
    
    # Loop over hook lists hookF, hookB and extract the activation and gradient of the resnet.layer4.1.conv2 layer
    #print('***'*3+'  Forward and Backward Hooks extraction  '+'***'*3)
    for hook in hookF2:
        if hook.name == 'layer4.1.conv2':
            act = hook.output.detach()
            #print('Foward hooks:')
            #print('output of', hook.name)
            #print(act.shape)
            #print('-'*20)
    
    for hook in hookB2:
        if hook.name == 'layer4.1.conv2':
            #print('Backward hooks:')
            grad = hook.output[0].detach()
            #print('gradient of', hook.name)
            #print(grad.shape)
            #print('-'*20)
    
    # Global Average Pool the gradients of every feature map (Eq. 1)
    pooled_grad = torch.mean(grad, dim=[0, 2, 3])
    #print('GAP gradients shape:', pooled_grad.shape)
    
    # Multiply every channel k of the activations by GAP gradients, that are the 'weights'
    for j in range(act.shape[1]):
        act[:, j, :, :] *= pooled_grad[j]
    
    # Sum across all k-features (Eq. 2)
    heatmap = torch.sum(act, dim=1).squeeze()
    
    # Relu (Eq. 2)
    heatmap = np.maximum(heatmap, 0)
    
    # normalize the heatmap between 0 and 1
    heatmap /= torch.max(heatmap)
    #print('Heatmap shape:', heatmap.shape)
    
    # Visualize the raw CAM
    #plt.imshow(heatmap.numpy(), cmap='jet'); plt.axis('off'); plt.tight_layout(); plt.colorbar(); plt.show()
    
    # PLOTTING THE IMAGE, HEATMAP and SUPERIMPOSITION
    
    img_jpeg = inv_transform(sampleX).squeeze().permute(1, 2, 0).numpy()
    #img_jpeg = cv2.cvtColor(img_jpeg, cv2.COLOR_BGR2RGB)
    H, W = img_jpeg.shape[0], img_jpeg.shape[1]
    
    heatmap2 = cv2.resize(np.array(heatmap), (W, H))
    #print(f'Image shape {img_jpeg.shape}. Heatmap shape {heatmap.shape}')
    
    mask = heatmap2>0.6 # let us see where heatmap is above some threshold
    
    superimposed_img = img_jpeg.copy()
    superimposed_img[:,:,0] = img_jpeg[:,:,0]*mask
    superimposed_img[:,:,1] = img_jpeg[:,:,1]*mask
    superimposed_img[:,:,2] = img_jpeg[:,:,2]*mask
    
    fig, axs = plt.subplots(1,3, figsize=(15,5))
    img1 = axs[0].imshow(img_jpeg)
    axs[0].set_title('Original Image')
    img2 = axs[1].imshow(heatmap2, cmap='jet')
    axs[1].set_title('Heatmap')
    img3 = axs[2].imshow(superimposed_img)
    axs[2].set_title('SuperImposed')

    plt.savefig('resnet18_ex_1.png', dpi=300)
    plt.show()

In [None]:
for i in range(X.shape[0]):
    sampleX = X[i:i+1]
    out = predict_arm3(sampleX)
    classes = torch.argmax(out, axis=1).cpu().numpy()
    #true_label = []
    true_label = [y[i]]
    #print(true_label[0].item())
    # Backprop from the correct label
    out[:, true_label[0].item()].backward()
    print(f"{'Predicted class':<{20}} True class")
    print(f"{class_names[classes[0]]:<{20}} {class_names[y[i]]}")

    act, grad = None, None
    mask=0
    heatmap=0
    
    # Loop over hook lists hookF, hookB and extract the activation and gradient of the resnet.layer4.1.conv2 layer
    #print('***'*3+'  Forward and Backward Hooks extraction  '+'***'*3)
    for hook in hookF3:
        if hook.name == 'layer4.2.conv3':
            act = hook.output.detach()
            #print('Foward hooks:')
            #print('output of', hook.name)
            #print(act.shape)
            #print('-'*20)
    
    for hook in hookB3:
        if hook.name == 'layer4.2.conv3':
            #print('Backward hooks:')
            grad = hook.output[0].detach()
            #print('gradient of', hook.name)
            #print(grad.shape)
            #print('-'*20)
    
    # Global Average Pool the gradients of every feature map (Eq. 1)
    pooled_grad = torch.mean(grad, dim=[0, 2, 3])
    #print('GAP gradients shape:', pooled_grad.shape)
    
    # Multiply every channel k of the activations by GAP gradients, that are the 'weights'
    for j in range(act.shape[1]):
        act[:, j, :, :] *= pooled_grad[j]
    
    # Sum across all k-features (Eq. 2)
    heatmap = torch.sum(act, dim=1).squeeze()
    
    # Relu (Eq. 2)
    heatmap = np.maximum(heatmap, 0)
    
    # normalize the heatmap between 0 and 1
    heatmap /= torch.max(heatmap)
    #print('Heatmap shape:', heatmap.shape)
    
    # Visualize the raw CAM
    #plt.imshow(heatmap.numpy(), cmap='jet'); plt.axis('off'); plt.tight_layout(); plt.colorbar(); plt.show()
    
    # PLOTTING THE IMAGE, HEATMAP and SUPERIMPOSITION
    
    img_jpeg = inv_transform(sampleX).squeeze().permute(1, 2, 0).numpy()
    #img_jpeg = cv2.cvtColor(img_jpeg, cv2.COLOR_BGR2RGB)
    H, W = img_jpeg.shape[0], img_jpeg.shape[1]
    
    heatmap3 = cv2.resize(np.array(heatmap), (W, H))
    #print(f'Image shape {img_jpeg.shape}. Heatmap shape {heatmap.shape}')
    
    mask = heatmap3>0.6 # let us see where heatmap is above some threshold
    
    superimposed_img = img_jpeg.copy()
    superimposed_img[:,:,0] = img_jpeg[:,:,0]*mask
    superimposed_img[:,:,1] = img_jpeg[:,:,1]*mask
    superimposed_img[:,:,2] = img_jpeg[:,:,2]*mask
    
    fig, axs = plt.subplots(1,3, figsize=(15,5))
    img1 = axs[0].imshow(img_jpeg)
    axs[0].set_title('Original Image')
    img2 = axs[1].imshow(heatmap3, cmap='jet')
    axs[1].set_title('Heatmap')
    img3 = axs[2].imshow(superimposed_img)
    axs[2].set_title('SuperImposed')

    plt.savefig('resnet101_ex_1.png', dpi=300)
    
    plt.show()

In [None]:
for i in range(X.shape[0]):
    sampleX = X[i:i+1]
    img = sampleX
    out = predict_arm4(sampleX)
    classes = torch.argmax(out, axis=1).cpu().numpy()
    print(f"{'Predicted class':<{20}} True class")
    print(f"{class_names[classes[0]]:<{20}} {class_names[y[i]]}")
    img_jpeg = inv_transform(sampleX).squeeze().permute(1, 2, 0).numpy()
    # Load the Vision Transformer and target layer
    target_layer = model4.blocks[-1].norm1  # Specify the target layer for Grad-CAM
    
    # Initialize Grad-CAM with the model and target layer
    grad_cam = GradCam(model4, target_layer)
    mask = grad_cam(img)  # Compute the Grad-CAM mask
    result = gen_cam(img_jpeg, mask)  # Generate the Grad-CAM heatmap
    
    fig, axs = plt.subplots(1,3, figsize=(15,5))
    img1 = axs[0].imshow(img_jpeg)
    axs[0].set_title('Original Image')
    img2 = axs[1].imshow(mask, cmap='jet')
    axs[1].set_title('Heatmap')
    img3 = axs[2].imshow(result)
    axs[2].set_title('SuperImposed')
    plt.show()

# Statistical Analysis

In [None]:
# Calculate differences
d1_2= heatmap1 - heatmap2
d1_3= heatmap1 - heatmap3
d2_3= heatmap3 - heatmap2

# Mean Absolute Error
mae1_2 = np.mean(np.abs(d1_2))
mae1_3 = np.mean(np.abs(d1_3))
mae2_3 = np.mean(np.abs(d2_3))

# Root Mean Squared Error
rmse1_2 = np.sqrt(np.mean(d1_2**2))
rmse1_3 = np.sqrt(np.mean(d1_3**2))
rmse2_3 = np.sqrt(np.mean(d2_3**2))

print(f"mae1_2: {mae1_2}")
print(f"mae1_3: {mae1_3}")
print(f"mae2_3: {mae2_3}")

print(f"rmse1_2: {rmse1_2}")
print(f"rmse1_3: {rmse1_3}")
print(f"rmse2_3: {rmse2_3}")

smoothed1 = gaussian_filter(heatmap1, sigma=3)
smoothed2 = gaussian_filter(heatmap2, sigma=3)
smoothed3 = gaussian_filter(heatmap3, sigma=3)
smoothed4 = gaussian_filter(heatmap4, sigma=3)

peaks1 = peak_local_max(smoothed1, min_distance=3)
peaks2 = peak_local_max(smoothed2, min_distance=3)
peaks3 = peak_local_max(smoothed3, min_distance=3)
peaks4 = peak_local_max(smoothed4, min_distance=3)

fig, axs = plt.subplots(1,4,figsize=(20,5))

axs[0].imshow(smoothed1, cmap='hot', interpolation='nearest')
axs[0].scatter(peaks1[:,1],peaks1[:,0],marker='x',color='blue',label='peaks')
axs[0].set_title('RESNET 50 Heatmap')


axs[1].imshow(smoothed2, cmap='hot', interpolation='nearest')
axs[1].scatter(peaks2[:,1],peaks2[:,0],marker='x',color='blue',label='peaks')
axs[1].set_title('RESNET 18 Heatmap')

axs[2].imshow(smoothed3, cmap='hot', interpolation='nearest')
axs[2].scatter(peaks3[:,1],peaks3[:,0],marker='x',color='blue',label='peaks')
axs[2].set_title('RESNET 101 Heatmap')

axs[3].imshow(smoothed4, cmap='hot', interpolation='nearest')
axs[3].scatter(peaks4[:,1],peaks4[:,0],marker='x',color='blue',label='peaks')
axs[3].set_title('DEiT Heatmap')

#plt.savefig('example_1', dpi=300)

plt.show()

# Peaks
num_peaks_RESNET50 = peaks1.shape[0]
num_peaks_RESNET18 = peaks2.shape[0]
num_peaks_RESNET101 = peaks3.shape[0]
num_peaks_DEIT = peaks4.shape[0]

print(f"RESNET50 Peaks: {num_peaks_RESNET50}")
print(f"RESNET18 Peaks: {num_peaks_RESNET18}")
print(f"RESNET101 Peaks: {num_peaks_RESNET101}")
print(f"DEiT Peaks: {num_peaks_DEIT}")

# Flatten the matrices
flat1 = heatmap1.flatten()
flat2 = heatmap2.flatten()
flat3 = heatmap3.flatten()
flat4 = heatmap4.flatten()

# Calculate Pearson correlations
corr_1_2, _ = pearsonr(flat1, flat2)
corr_1_3, _ = pearsonr(flat1, flat3)
corr_1_4, _ = pearsonr(flat1, flat4)
corr_2_3, _ = pearsonr(flat2, flat3)
corr_2_4, _ = pearsonr(flat2, flat4)
corr_3_4, _ = pearsonr(flat3, flat4)

print(f'Pearson correlation between Heatmap 1 and 2: {corr_1_2}')
print(f'Pearson correlation between Heatmap 1 and 3: {corr_1_3}')
print(f'Pearson correlation between Heatmap 1 and 4: {corr_1_4}')
print(f'Pearson correlation between Heatmap 2 and 3: {corr_2_3}')
print(f'Pearson correlation between Heatmap 2 and 4: {corr_2_4}')
print(f'Pearson correlation between Heatmap 3 and 4: {corr_3_4}')

# Stack the datasets into one matrix
data_combined = np.vstack([flat1, flat2, flat3, flat4])

# Perform PCA
pca = PCA(n_components=1)
pca_result = pca.fit_transform(data_combined.T)  # Transpose to have samples in rows
variance_ratio = pca.explained_variance_ratio_
print(variance_ratio)
# Plot the PCA results
# plt.figure(figsize=(8, 6))
# plt.scatter(pca_result[:, 0], pca_result[:, 1], label=['Heatmap 1', 'Heatmap 2', 'Heatmap 3', 'Heatmap 4'])
# plt.title('PCA of Heatmap Data')
# plt.xlabel('Principal Component 1')
# plt.ylabel('Principal Component 2')
# plt.legend()
# plt.grid()
# plt.show()


plt.figure(figsize=(8, 6))
plt.scatter(np.arange(len(pca_result)), pca_result, color='blue', marker='o')
plt.title('PCA Result with 1 Component')
plt.xlabel('Sample Index')
plt.ylabel('Principal Component 1')
plt.axhline(0, color='red', linestyle='--')  # Optional: horizontal line at y=0 for reference
plt.grid()
plt.show()

In [None]:
from vit_model_base import defect_vit


ckpt_path5="output/last.ckpt"
checkpoint5 = defect_vit.load_from_checkpoint(ckpt_path5)
model5 = checkpoint5.model
model5.to(device)
model5.eval()

def predict_arm5(img: np.ndarray) -> torch.Tensor:
    img = nhwc_to_nchw(torch.Tensor(img))  # Convert from NHWC to NCHW
    img = img.to(device)
    #model.to(device)
    #img = img.unsqueeze(0).to(device)  # Add batch dimension and move to device
    output = model5(img)  # Forward pass through the model
    return output




In [None]:
for i in range(X.shape[0]):
    sampleX = X[i:i+1]
    img = sampleX
    out = predict_arm5(sampleX)
    classes = torch.argmax(out, axis=1).cpu().numpy()
    print(f"{'Predicted class':<{20}} True class")
    print(f"{class_names[classes[0]]:<{20}} {class_names[y[i]]}")
    img_jpeg = inv_transform(sampleX).squeeze().permute(1, 2, 0).numpy()
    # Load the Vision Transformer and target layer
    target_layer = model5.blocks[-1].norm1  # Specify the target layer for Grad-CAM
    
    # Initialize Grad-CAM with the model and target layer
    grad_cam = GradCam(model5, target_layer)
    mask = grad_cam(img)  # Compute the Grad-CAM mask
    result = gen_cam(img_jpeg, mask)  # Generate the Grad-CAM heatmap
    
    fig, axs = plt.subplots(1,3, figsize=(15,5))
    img1 = axs[0].imshow(img_jpeg)
    axs[0].set_title('Original Image')
    img2 = axs[1].imshow(mask, cmap='jet')
    axs[1].set_title('Heatmap')
    img3 = axs[2].imshow(result)
    axs[2].set_title('SuperImposed')
    plt.show()

In [None]:
from scipy.stats import entropy

def calculate_entropy(heatmap):
    # Flatten the heatmap and compute the histogram
    flat_heatmap = heatmap.flatten()
    hist, _ = np.histogram(flat_heatmap, bins=256, density=True)
    
    # Use only the non-zero entries to calculate entropy
    hist = hist[hist > 0]  # Remove zero entries
    return -np.sum(hist * np.log2(hist))

# Example usage with 4 heatmaps
heatmap1_entropy = calculate_entropy(heatmap1)
heatmap2_entropy = calculate_entropy(heatmap2)
heatmap3_entropy = calculate_entropy(heatmap3)
heatmap4_entropy = calculate_entropy(heatmap4)

print(f'Entropy of Heatmap 1: {heatmap1_entropy}')
print(f'Entropy of Heatmap 2: {heatmap2_entropy}')
print(f'Entropy of Heatmap 3: {heatmap3_entropy}')
print(f'Entropy of Heatmap 4: {heatmap4_entropy}')


In [8]:
import time

In [None]:
from scipy.stats import spearmanr