In [1]:
import os
# Move one directory up so the relative 'dataset/...' paths used below resolve.
# NOTE: this is not idempotent -- re-running the cell moves up one more level each time.
os.chdir("..")
print("Current Directory:", os.getcwd())

Current Directory: /workspace/iscat


In [2]:
import h5py
import numpy as np
# Quick look at the raw file: shape of the 'data' array and the number of
# samples stored under each distinct label value.
particle_data_path = 'dataset/brightfield_particles.hdf5'
with h5py.File(particle_data_path, 'r') as f:
    data_shape = f['data'].shape
    print(data_shape)
    label_values_and_counts = np.unique(f['labels'], return_counts=True)
    print(label_values_and_counts)

(41350, 16, 201)
(array([0, 1, 2, 3]), array([32462,  8659,    60,   169]))


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import h5py
import numpy as np
# from torchvision.models import vit_b_16
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from tqdm import tqdm
from torchvision.models.vision_transformer import VisionTransformer
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, f1_score, balanced_accuracy_score
import matplotlib.pyplot as plt

In [4]:
from torchvision.transforms import v2    
def compute_normalization_stats(h5_path, classes=None):
    """
    Return the global mean and standard deviation used for z-score normalization.

    The 'data' and 'labels' arrays are read fully into memory. When `classes`
    is given, only samples whose label is in `classes` contribute.

    Args:
        h5_path (str): Path to the HDF5 file containing 'data' and 'labels'.
        classes (list, optional): Label values to keep; None keeps every sample.

    Returns:
        tuple: (mean, std), each a single scalar taken over all retained values.
    """
    with h5py.File(h5_path, 'r') as source:
        samples = source['data'][:]
        sample_labels = source['labels'][:]

    # Both arrays are in memory now, so the file is no longer needed.
    if classes is not None:
        samples = samples[np.isin(sample_labels, classes)]

    # One scalar each: every axis is reduced.
    mean, std = np.mean(samples), np.std(samples)
    print(f"Computed statistics: mean = {mean:.4f}, std = {std:.4f}")
    return mean, std
        
class ParticleDataset(Dataset):
    """In-memory particle dataset with class selection and optional z-score normalization.

    All samples of the requested classes are read from the HDF5 file when the
    object is created, and the file is closed again straight away, so instances
    hold no open file handle.

    Args:
        h5_path (str): HDF5 file containing 'data' and 'labels' arrays.
        classes (sequence): Original label values to keep. They are remapped to
            consecutive indices 0..len(classes)-1 in the order given.
        transform (callable, optional): Applied to the (3, H, W) tensor of each sample.
        mean, std (float, optional): If both are given, each sample is
            normalized as (x - mean) / std.
        padding (bool): Stored on the instance but not used by this class.
        indices (sequence, optional): Positions to keep, counted *after* the
            class filter has been applied.
    """

    def __init__(self, h5_path, classes=(0, 1), transform=None, mean=None, std=None, padding=False, indices=None):
        # Read everything up front and release the file handle immediately.
        with h5py.File(h5_path, 'r') as h5_file:
            data = h5_file['data'][:]
            labels = h5_file['labels'][:]
        # Kept only so that code touching `h5_file` / `close()` keeps working.
        self.h5_file = None
        self.padding = padding

        classes = list(classes)
        # Keep only the selected classes, then optionally sub-select by position.
        mask = np.isin(labels, classes)
        data, labels = data[mask], labels[mask]
        if indices is not None:
            data, labels = data[indices], labels[indices]
        self.data = data

        # Map original (possibly non-consecutive) labels to 0..num_classes-1.
        self.class_to_idx = {c: i for i, c in enumerate(classes)}
        self.num_classes = len(classes)
        self.labels = np.array([self.class_to_idx[label] for label in labels], dtype=np.int64)

        self.transform = transform
        self.mean = mean
        self.std = std

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return (image, label): image is a float32 tensor of shape (3, 16, 201)."""
        particle = self.data[idx]  # a single 2-D sample

        # Optional z-score normalization.
        if self.mean is not None and self.std is not None:
            particle = (particle - self.mean) / self.std

        # Convert in numpy first so any source dtype ends up as float32.
        particle_tensor = torch.from_numpy(np.asarray(particle, dtype=np.float32)).unsqueeze(0)

        # Bicubic resampling to a fixed (16, 201) grid. For samples that already
        # have this shape it leaves the values unchanged; it only has an effect
        # when the stored samples have a different size.
        resized = torch.nn.functional.interpolate(
            particle_tensor.unsqueeze(0),  # add batch dim -> (1, 1, H, W)
            size=(16, 201),
            mode='bicubic',
            align_corners=True
        ).squeeze(0).squeeze(0)

        # Replicate the single channel three times for a 3-channel backbone.
        final_tensor = resized.unsqueeze(0).repeat(3, 1, 1)

        if self.transform:
            final_tensor = self.transform(final_tensor)

        # The label is returned as a class index, not one-hot encoded.
        label_idx = self.labels[idx]
        return final_tensor, label_idx

    def close(self):
        """Release the HDF5 handle if one is still held; safe to call repeatedly."""
        if self.h5_file is not None:
            self.h5_file.close()
            self.h5_file = None

In [5]:
import torch 

def distance_matrix(a, b):
    """Pairwise absolute differences between two sets of scalar points.

    Both inputs are flattened; the result has one row per element of `a` and
    one column per element of `b`.
    """
    return (a.view(-1, 1) - b.view(1, -1)).abs()


def knn_divergence(points_x, points_y, k, smoothing_kernel=None):
    """Mean squared gap between k-th-neighbour distances within x and from x to y.

    For every point in `points_x`, the distance to its k-th closest point in
    `points_x` (index 0 is the point itself) is compared with the distance to a
    rank-matched neighbour in `points_y`.

    Args:
        points_x (torch.Tensor): Scalar points (any shape, flattened).
        points_y (torch.Tensor): Scalar points (any shape, flattened).
        k (int or torch.Tensor): Neighbour rank(s) to evaluate. A tensor of
            ranks evaluates all of them at once.
        smoothing_kernel (torch.Tensor, optional): 1-D kernel convolved over the
            rank axis so that neighbouring ranks are blended. Requires `k` to
            be a 1-D tensor at least as long as the kernel.

    Returns:
        torch.Tensor: Scalar loss.

    Raises:
        ValueError: If `points_x` is empty.
        IndexError: If a requested rank exceeds the number of available points.
    """
    n_x, n_y = points_x.numel(), points_y.numel()
    if n_x == 0:
        raise ValueError("points_x must contain at least one point")

    xx_distances = distance_matrix(points_x, points_x)
    xy_distances = distance_matrix(points_x, points_y)  # rows: samples of x, cols: samples of y

    # If the sets differ in size, scale the rank: when y has twice as many
    # points, the 3rd neighbour in x corresponds to the 6th neighbour in y.
    k_multiplier = n_y / n_x

    # Accept plain ints as well as tensors, and index with int64.
    k_xx = torch.as_tensor(k).long()
    k_xy = (k_xx * k_multiplier).long()  # truncates towards zero
    if k_xx.numel() and (int(k_xx.max()) >= n_x or int(k_xy.max()) >= n_y):
        raise IndexError(
            f"neighbour rank out of range for {n_x} x-points and {n_y} y-points"
        )
    k_xx = k_xx.to(xx_distances.device)
    k_xy = k_xy.to(xy_distances.device)

    k_dist_xx = torch.sort(xx_distances, dim=1)[0][:, k_xx]
    k_dist_xy = torch.sort(xy_distances, dim=1)[0][:, k_xy]

    # Optional: smooth over ranks so it matters less whether a point is the
    # i-th or the (i+1)-th closest neighbour.
    if smoothing_kernel is not None:
        # conv1d needs a channel dimension, hence the (un)squeezing.
        weight = smoothing_kernel.view(1, 1, -1)
        k_dist_xx = torch.nn.functional.conv1d(k_dist_xx.unsqueeze(1), weight=weight).flatten(1)
        k_dist_xy = torch.nn.functional.conv1d(k_dist_xy.unsqueeze(1), weight=weight).flatten(1)

    return torch.mean((k_dist_xx - k_dist_xy) ** 2)

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim

def generate_label_distribution(num_points=10000, mean=76, std=22.5, min_value=10, max_value=None):
    """
    Draw `num_points` samples from a normal distribution truncated by rejection.

    Samples are drawn in chunks of `num_points`; values outside the optional
    [min_value, max_value] range are discarded and drawing continues until
    enough values have been accepted. A range far out in the tail of the
    distribution therefore takes many rounds.

    Args:
        num_points (int): Number of points to generate
        mean (float): Mean of the underlying normal distribution
        std (float): Standard deviation of the underlying normal distribution
        min_value (float, optional): Smallest accepted value (inclusive)
        max_value (float, optional): Largest accepted value (inclusive)

    Returns:
        torch.Tensor: 1-D tensor with exactly `num_points` accepted values

    Raises:
        ValueError: If `min_value` is greater than `max_value` (no value could
            ever be accepted and the sampling loop would never finish).
    """
    if min_value is not None and max_value is not None and min_value > max_value:
        raise ValueError(
            f"min_value ({min_value}) must not exceed max_value ({max_value})"
        )

    accepted_chunks = []
    accepted_total = 0
    while accepted_total < num_points:
        candidates = torch.normal(mean=mean, std=std, size=(num_points,))

        # Reject everything outside the requested range.
        if min_value is not None:
            candidates = candidates[candidates >= min_value]
        if max_value is not None:
            candidates = candidates[candidates <= max_value]

        accepted_chunks.append(candidates)
        accepted_total += candidates.numel()

    if not accepted_chunks:
        return torch.empty(0)
    # Concatenate once and trim the surplus from the last round.
    return torch.cat(accepted_chunks)[:num_points]

In [7]:
from torchsummary import summary
# resnet = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)
# # resnet .fc = nn.Linear(resnet.fc.in_features, 1)
# resnet .fc = nn.Sequential(
#     nn.Linear(resnet.fc.in_features, 32),
#     nn.ReLU(),
#     nn.Linear(32, 1))
# summary(resnet, input_size=(3, 16, 201),device='cpu')

In [8]:
from torch.utils.data import DataLoader
# --- Reproducibility and device selection ---
torch.manual_seed(42)
DEVICE = "cuda:11"
device = torch.device(DEVICE if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# --- Augmentation: independent random vertical and horizontal flips ---
transform = v2.Compose([
    v2.RandomVerticalFlip(p=0.5),
    v2.RandomHorizontalFlip(p=0.5),
])

# --- Normalization statistics computed over classes 0 and 1 together ---
mean, std = compute_normalization_stats('dataset/brightfield_particles.hdf5', classes=[0,1])

# Batch size and number of sampled target sizes for each of the two classes.
batch_size_80, num_points_80 = 10000, 30000
batch_size_300, num_points_300 = 8192, 20000

# Arguments shared by both training datasets.
shared_dataset_kwargs = dict(
    h5_path='dataset/brightfield_particles.hdf5',
    mean=mean,
    std=std,
    padding=True,
    transform=transform,
)
dataset_80 = ParticleDataset(classes=[0], indices=list(range(30000)), **shared_dataset_kwargs)
dataset_300 = ParticleDataset(classes=[1], indices=list(range(8192)), **shared_dataset_kwargs)

# Loader and target size distribution for class 0 ...
dataloader_80 = DataLoader(dataset_80, batch_size=batch_size_80, shuffle=True)
label_points_80 = generate_label_distribution(num_points_80, mean=76, std=22.5)

# ... and for class 1 (drawn second, so the random stream is consumed in the same order).
dataloader_300 = DataLoader(dataset_300, batch_size=batch_size_300, shuffle=True)
label_points_300 = generate_label_distribution(num_points_300, mean=302, std=25)

Using device: cuda:11
Computed statistics: mean = 7537.6143, std = 1312.3260


In [9]:
def train_resnet(model, dataloaders, label_points, device, num_epochs=10, learning_rate=0.1,class_loss=True): #3e-2
    """
    Train a model with the KNN divergence loss, LR scheduling and early stopping.

    Each epoch iterates over every dataloader in turn. Predictions for batches
    of dataloader `i` are matched against the target distribution
    `label_points[i]`. With `class_loss=True` the model must return
    `(size_prediction, class_logit)` and a binary cross-entropy term is added,
    using the dataloader position (0 or 1) as the class target.

    Args:
        model (torch.nn.Module): Model to train.
        dataloaders (sequence of DataLoader): One loader per particle class.
            With `class_loss=True` exactly two loaders are expected.
        label_points (sequence of torch.Tensor): Pre-generated target sizes,
            one tensor per dataloader.
        device (torch.device): Device to train on.
        num_epochs (int): Maximum number of training epochs.
        learning_rate (float): Initial learning rate for Adam.
        class_loss (bool): Whether to add the classification loss.

    Returns:
        model: The model carrying the weights of the best epoch (lowest
            average training loss).
        best_loss: That lowest average training loss.

    Raises:
        ValueError: If the dataloaders yield no batches.
    """
    if class_loss:
        # Weight the positive class (dataloader 1) by the negative/positive ratio.
        pos_weight = torch.tensor(
            [len(dataloaders[0].dataset) / len(dataloaders[1].dataset)],
            dtype=torch.float,
            device=device,
        )
        bceloss = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    # Target distributions live on the training device.
    labels = [label.to(device, non_blocking=True) for label in label_points]
    # Neighbour ranks evaluated by the KNN divergence, one set per dataloader.
    ks = [torch.arange(2, 1000 // 10, dtype=torch.int) for _ in label_points]

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Halve the learning rate after 8 epochs without improvement.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=8, factor=0.5
    )

    model.to(device)

    # Early stopping bookkeeping.
    best_loss = float('inf')
    patience = 20
    patience_counter = 0
    best_model_state = None

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        batch_count = 0
        correct = 0
        for idx, (label, dataloader) in enumerate(zip(labels, dataloaders)):
            for batch_images, _ in dataloader:
                batch_count += 1
                batch_images = batch_images.to(device)

                optimizer.zero_grad()

                if class_loss:
                    batch_predictions, pred_class = model(batch_images)
                    # Every sample of dataloader `idx` belongs to class `idx`.
                    gt_class = torch.full(pred_class.shape, idx, dtype=torch.float, device=device)
                    loss = knn_divergence(batch_predictions, label, ks[idx])
                    # The target is the ground-truth class, not the prediction itself.
                    loss = loss + bceloss(pred_class, gt_class)
                    with torch.no_grad():
                        predicted_class = (torch.sigmoid(pred_class) > 0.5).float()
                        correct += (predicted_class == gt_class).sum().item()
                else:
                    batch_predictions = model(batch_images)
                    loss = knn_divergence(batch_predictions, label, ks[idx])

                loss.backward()
                optimizer.step()

                total_loss += loss.item()

        if batch_count == 0:
            raise ValueError("dataloaders yielded no batches; nothing to train on")

        avg_loss = total_loss / batch_count
        current_lr = optimizer.param_groups[0]['lr']
        if class_loss :
            total_samples = sum(len(d.dataset) for d in dataloaders)
            accuracy = correct / total_samples
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, LR: {current_lr:.2e}, Accuracy {accuracy:.2e}')
        else:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, LR: {current_lr:.2e}')

        scheduler.step(avg_loss)

        if avg_loss < best_loss:
            best_loss = avg_loss
            patience_counter = 0
            # Clone the tensors: a shallow dict copy would keep pointing at the
            # live parameters and silently track later updates.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f'Early stopping triggered after {epoch + 1} epochs')
                break

    # Return the best weights whether or not early stopping fired, so the
    # returned model always corresponds to `best_loss`.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, best_loss

In [10]:
import torch
import torch.nn as nn

class ResNetDualHead(nn.Module):
    """ResNet-18 feature extractor with a regression head and an optional binary classification head.

    Args:
        class_head (bool): If True, `forward` returns
            `(regression_output, classification_logit)`; otherwise it returns
            only the regression output.
    """

    def __init__(self, class_head=False):
        super().__init__()

        self.backbone = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', weights=None)
        in_features = self.backbone.fc.in_features
        # Drop the original classifier so the backbone outputs raw features.
        self.backbone.fc = nn.Identity()

        # Regression head: one scalar per sample.
        self.regression_head = self._make_head(in_features)

        self.class_head = class_head
        if class_head:
            # Classification head: one logit per sample (binary classification).
            self.classification_head = self._make_head(in_features)

    @staticmethod
    def _make_head(in_features):
        """Build the 3-layer MLP (in_features -> 128 -> 64 -> 1) shared by both heads."""
        return nn.Sequential(
            nn.Linear(in_features, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        # Call the module rather than `.forward` so hooks registered on the
        # backbone itself are honoured.
        features = self.backbone(x)
        regression_output = self.regression_head(features)

        if self.class_head:
            classification_output = self.classification_head(features)
            return regression_output, classification_output

        return regression_output

In [11]:
# resnet = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', weights=None)
# resnet .fc = nn.Sequential(
#     nn.Linear(resnet.fc.in_features, 128),
#     nn.ReLU(),
#     nn.Linear(128, 64),
#     nn.ReLU(),
#     nn.Linear(64, 1)
# )
# Dual-head model (size regression + class logit), trained on both classes;
# position in the tuples ties each loader to its target size distribution.
resnet = ResNetDualHead(class_head=True)
dataloaders = (dataloader_80, dataloader_300)
label_points = (label_points_80, label_points_300)
resnet, best_loss = train_resnet(
    model=resnet,
    dataloaders=dataloaders,
    label_points=label_points,
    device=device,
    num_epochs=300,
)
resnet.eval()

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


TypeError: BCELoss.__init__() got an unexpected keyword argument 'pos_weight'

In [None]:
# Smoke test: run the first training sample through the model in eval mode.
resnet.eval()
first_image = dataset_80[0][0]
resnet(first_image.unsqueeze(0).to(device))

In [None]:
# Un-augmented view of every class-0 sample, loaded as one single batch.
plot_dataset_80 = ParticleDataset(h5_path='dataset/brightfield_particles.hdf5',
                                  classes=[0],
                                  mean=mean,
                                  std=std,
                                  padding=True,
                                  transform=None,
                                  indices=None)
plot_dataloader_80 = DataLoader(plot_dataset_80, batch_size=len(plot_dataset_80))
with torch.no_grad():
    # Use the evaluation loader built above, not the shuffled and augmented
    # training loader.
    out = next(iter(plot_dataloader_80))[0]
    out = resnet(out.to(device))
    # With the classification head enabled the model returns
    # (regression, class logit); only the regression output is wanted here.
    if isinstance(out, tuple):
        out = out[0]
    out = out.cpu().numpy()

In [None]:
# Mean of the model outputs collected in `out` by the previous cell.
out.mean()

In [None]:
# Overlay the target size distribution and the predicted sizes (both as densities).
plt.figure(figsize=(8, 5))
plt.hist(label_points_80 , bins=50, alpha=0.6, label='Ground Truth', color='blue', density=True)
plt.hist(out, bins=50, alpha=0.6, label='Prediction', color='red', density=True)

# Labels and legend (axis label spelling corrected: "Diameter")
plt.xlabel('Diameter[nm]')
plt.ylabel('Density[norm.]')
plt.title('Ground Truth vs Prediction Distribution_80nm')
plt.legend()
plt.grid(True)
# Show plot
plt.show()


In [None]:
# Un-augmented view of every class-1 sample, loaded as one single batch.
plot_dataset_300 = ParticleDataset(h5_path='dataset/brightfield_particles.hdf5',
                                   classes=[1],
                                   mean=mean,
                                   std=std,
                                   padding=True,
                                   transform=None,
                                   indices=None)
plot_dataloader_300 = DataLoader(plot_dataset_300 , batch_size=len(plot_dataset_300))
with torch.no_grad():
    out_2 = next(iter(plot_dataloader_300))[0]
    out_2 = resnet(out_2.to(device))
    # With the classification head enabled the model returns
    # (regression, class logit); keep only the regression output.
    if isinstance(out_2, tuple):
        out_2 = out_2[0]
    out_2 = out_2.cpu().numpy()
# NOTE: this replaces the target distribution used for training with a fresh
# draw that has one value per evaluated sample.
label_points_300 = generate_label_distribution(len(plot_dataset_300), mean=302, std=25)
print(out_2.mean())

In [None]:
import matplotlib.pyplot as plt
# Predicted sizes for the class-1 samples (no mean shift applied).
values = out_2

plt.figure(figsize=(8, 6))

# Overlay predictions and the target size distribution (both as densities).
plt.hist(values, bins=50, color='red',label='Prediction',alpha=0.7, density=True)
plt.hist(label_points_300  , bins=50, alpha=0.6, label='Ground Truth', color='blue', density=True)

# Labels and legend (axis label spelling corrected: "Diameter")
plt.xlabel('Diameter[nm]')
plt.ylabel('Density[norm.]')
plt.title('Ground Truth vs Prediction Distribution_300nm')
plt.legend()
plt.grid(True)

# Show plot
plt.show()

In [None]:
# Every class-3 sample, un-augmented, as one batch. The dataset gets its own
# name: the original cell referenced an undefined `plot_dataset_1300` and
# overwrote the class-1 dataset.
plot_dataset_1300 = ParticleDataset(h5_path='dataset/brightfield_particles.hdf5',
                                    classes=[3],
                                    mean=mean,
                                    std=std,
                                    padding=True,
                                    transform=None,
                                    indices=None)
plot_dataloader_1300 = DataLoader(plot_dataset_1300, batch_size=len(plot_dataset_1300))
with torch.no_grad():
    out_3 = next(iter(plot_dataloader_1300))[0]
    out_3 = resnet(out_3.to(device))
    # With the classification head enabled the model returns
    # (regression, class logit); keep only the regression output.
    if isinstance(out_3, tuple):
        out_3 = out_3[0]
    out_3 = out_3.cpu().numpy()
import matplotlib.pyplot as plt

values = out_3

plt.figure(figsize=(8, 6))

# Histogram of the predicted sizes
plt.hist(values, bins=50, density=True, color='blue', alpha=0.7)

# Add labels
plt.title('Distribution of Values_1300nm')
plt.xlabel('Value')
plt.ylabel('Density')

# Show plot
plt.show()

In [None]:
# The first 30000 class-0 samples (the same indices as the training set),
# without augmentation, evaluated as one batch.
plot_dataset_80 = ParticleDataset(h5_path='dataset/brightfield_particles.hdf5',
                                  classes=[0],
                                  mean=mean,
                                  std=std,
                                  padding=True,
                                  transform=None,
                                  indices=list(range(0, 30000)))
plot_dataloader_80 = DataLoader(plot_dataset_80, batch_size=30000)
with torch.no_grad():
    out_4 = next(iter(plot_dataloader_80))[0]
    out_4 = resnet(out_4.to(device))
    # With the classification head enabled the model returns
    # (regression, class logit); keep only the regression output.
    if isinstance(out_4, tuple):
        out_4 = out_4[0]
    out_4 = out_4.cpu().numpy()
import matplotlib.pyplot as plt

values = out_4

plt.figure(figsize=(8, 6))

# Overlay the target size distribution and the predictions
plt.hist(label_points_80 , bins=50, alpha=0.6, label='Ground Truth', color='blue', density=True)
plt.hist(values, bins=50, density=True, color='red', alpha=0.6, label='Prediction')

# Add labels
plt.title('Distribution of Values_80nm')
plt.xlabel('Value')
plt.ylabel('Density')
plt.legend()

# Show plot
plt.show()

In [None]:
# Spread (standard deviation) of the predictions currently stored in `values`.
values.std()

In [None]:
# Feature maps of the first convolution for one training sample.
# The ResNet layers live under `resnet.backbone`; ResNetDualHead itself has no
# `conv1` attribute.
with torch.no_grad():
    output_map = resnet.backbone.conv1(dataset_80[0][0].unsqueeze(0).to(device)).squeeze(0).cpu()

In [None]:
# Display a single channel (index 3) of the first-layer feature maps in grayscale.
fig, ax = plt.subplots(figsize=(18, 10))
feature_map = output_map[3].clone()
ax.imshow(feature_map, cmap='gray')
plt.show()

In [None]:
# Compare the size predicted for one image with the prediction for its
# horizontally mirrored copy.
with torch.no_grad():
    imgs = next(iter(plot_dataloader_80))[0]  # one batch of (3, 16, 201) images
    img = imgs[2000].unsqueeze(0)  # (1, 3, 16, 201)

    # Prediction for the original image; with the classification head enabled
    # the model returns (regression, class logit), so keep the regression part.
    size_1 = resnet(img.to(device))
    if isinstance(size_1, tuple):
        size_1 = size_1[0]
    size_1 = size_1.cpu().squeeze(0)

    # Flip the image horizontally (along the last dimension)
    img_flipped = torch.flip(img, dims=[-1])

    # Prediction for the flipped image
    size_2 = resnet(img_flipped.to(device))
    if isinstance(size_2, tuple):
        size_2 = size_2[0]
    size_2 = size_2.cpu().squeeze(0)

    # Print both predictions
    print(f"Original image size prediction: {size_1.item():.3f}")
    print(f"Flipped image size prediction: {size_2.item():.3f}")
    print(f"Absolute difference: {abs(size_1.item() - size_2.item()):.3f}")

In [None]:
import torch
import matplotlib.pyplot as plt

# Show nine class-0 samples whose predicted sizes are spread evenly between the
# smallest and the largest prediction of the first batch.
plot_dataset_80 = ParticleDataset(h5_path='dataset/brightfield_particles.hdf5',
                                  classes=[0],
                                  mean=mean,
                                  std=std,
                                  padding=True,
                                  transform=None)
# `batch_size` is not defined in this notebook (its assignment is commented
# out); use the class-0 batch size instead.
plot_dataloader = DataLoader(plot_dataset_80, batch_size=batch_size_80)
with torch.no_grad():
    imgs = next(iter(plot_dataloader))[0]  # (batch_size_80, 3, 16, 201)
    sizes = resnet(imgs.to(device))
    # With the classification head enabled the model returns
    # (regression, class logit); keep only the regression output.
    if isinstance(sizes, tuple):
        sizes = sizes[0]
    sizes = sizes.cpu()

max_size, max_idx = sizes.max(dim=0)
min_size, min_idx = sizes.min(dim=0)

# Median prediction and the sample closest to it
mid_size = sizes.median()
mid_idx = (sizes - mid_size).abs().argmin()
mid_idx = torch.tensor([mid_idx], dtype=torch.int64)

# Nine target sizes evenly spaced from min to max; for each, pick the sample
# whose prediction is closest.
intermediate_sizes, intermediate_indices = [], []
for fraction in torch.linspace(0, 1, steps=9):
    interp_size = min_size + fraction * (max_size - min_size)
    closest_idx = int((sizes - interp_size).abs().argmin())
    intermediate_sizes.append(sizes[closest_idx])
    intermediate_indices.append(closest_idx)

# Convert indices to tensor
intermediate_indices = torch.tensor(intermediate_indices, dtype=torch.int64)

# Take the first channel of each selected image and shrink it to 16x32 for display
resized_images = [
    torch.nn.functional.interpolate(
        imgs[idx][0:1].unsqueeze(0), size=(16, 32), mode="bilinear", align_corners=False
    ).squeeze(0)[0]
    for idx in intermediate_indices
]

# Create 3x3 subplot
fig, axes = plt.subplots(3, 3, figsize=(12, 9))

# Plot images
for ax, img, size in zip(axes.flat, resized_images, intermediate_sizes):
    ax.imshow(img, cmap="gray")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.title.set_text(f"size: {size.item()}")

# Adjust layout and show
plt.tight_layout()
plt.show()

In [None]:
import sys
import contextlib
import torch.nn.functional as F
# Alias used by the Grad-CAM code below; refers to the same object as `resnet`.
model = resnet
def remove_all_hooks(model):
    """Clear every forward and backward hook registry on the model and all of its submodules."""
    registry_names = ("_forward_hooks", "_backward_hooks", "_full_backward_hooks")
    for submodule in model.modules():
        for registry_name in registry_names:
            # Not every registry necessarily exists on a module; skip missing ones.
            if hasattr(submodule, registry_name):
                getattr(submodule, registry_name).clear()

class GradCAM:
    """Grad-CAM heatmaps for a chosen layer of a model.

    A forward hook stores the layer's activations and a full backward hook
    stores the gradient of the target with respect to those activations.

    Args:
        model: Model to explain. If its forward returns a tuple or list, the
            first element is used as the output to explain.
        target_layer: Submodule of `model` whose activations are visualised.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.feature_maps = None
        self.gradient = None

        # Keep the handles so exactly these hooks can be removed later.
        self._hook_handles = [
            self.target_layer.register_forward_hook(self._save_feature_maps),
            self.target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def _save_feature_maps(self, module, input, output):
        self.feature_maps = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        self.gradient = grad_output[0].detach()

    def generate_heatmap(self, input_image, target_index=None):
        """Return a heatmap scaled to [0, 1] with the input's spatial size.

        Args:
            input_image: Batched input tensor of shape (N, C, H, W); the
                heatmap of the first batch element is returned.
            target_index: Output column to explain; None explains the whole output.

        Returns:
            numpy.ndarray of shape (H, W).
        """
        self.model.eval()

        output = self.model(input_image)
        # Multi-head models return several outputs; explain the first one.
        if isinstance(output, (tuple, list)):
            output = output[0]

        target = output if target_index is None else output[:, target_index]

        self.model.zero_grad()
        # Reduce to a scalar so backward works for any output shape; for a
        # single-element output this is the same gradient as before.
        target.sum().backward()

        # Channel weights: gradient averaged over the spatial dimensions.
        weights = torch.mean(self.gradient, dim=(2, 3), keepdim=True)

        # Weighted combination of feature maps, keeping only positive evidence.
        cam = torch.sum(weights * self.feature_maps, dim=1, keepdim=True)
        cam = F.relu(cam)

        # Upsample to the spatial size of the input image.
        cam = F.interpolate(
            cam,
            size=(input_image.shape[2], input_image.shape[3]),
            mode='bicubic',
            align_corners=False
        )

        heatmap = cam.cpu().numpy()[0, 0]
        # The small constant avoids division by zero for a flat map.
        heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)

        return heatmap

    def remove_hooks(self):
        """Remove the hooks this object registered; other hooks on the layer are left alone."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

def apply_heatmap(image: np.ndarray, heatmap: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """
    Blend a heatmap (drawn in the red channel) over an image.

    Each image channel is min-max scaled to [0, 1] independently before
    blending. A channel with a single constant value is mapped to zeros rather
    than producing NaNs. The input array is not modified.

    Args:
        image: Image array of shape (C, H, W). C must broadcast against the
            3-channel heatmap image, i.e. C is 3 or 1.
        heatmap: Heatmap array (H, W)
        alpha: Weight of the heatmap in the blend; the image gets (1 - alpha).

    Returns:
        Array of shape (H, W, 3).
    """
    image = np.asarray(image)
    # Work in floating point so integer images are not truncated when scaled.
    work_dtype = image.dtype if np.issubdtype(image.dtype, np.floating) else np.float64
    image = image.astype(work_dtype, copy=True)

    # Per-channel min-max scaling to [0, 1].
    channel_min = image.min(axis=(1, 2), keepdims=True)
    channel_span = image.max(axis=(1, 2), keepdims=True) - channel_min
    # Constant channels have zero span; divide by 1 so they become all zeros.
    image = (image - channel_min) / np.where(channel_span > 0, channel_span, 1)

    # Convert to (H, W, C)
    image = np.transpose(image, (1, 2, 0))

    # Heatmap as RGB (H, W, 3): intensity in red, green and blue left at zero.
    empty_channel = np.zeros_like(heatmap)
    heatmap_rgb = np.stack([heatmap, empty_channel, empty_channel], axis=-1)

    # Create overlay
    overlay = image * (1 - alpha) + heatmap_rgb * alpha

    return overlay
# Start from a hook-free model on the target device.
remove_all_hooks(model)
model = resnet.to(device)
# The ResNet blocks live under `model.backbone`; ResNetDualHead itself has no
# `layer1` / `layer4` attributes.
# target_layer =  model.backbone.layer4[1].conv2
target_layer = model.backbone.layer1[1].conv2
# The nine samples selected earlier, taken un-augmented from the plot dataset.
images = [plot_dataset_80[idx][0] for idx in intermediate_indices]
def generate_RAM(model,image,traget_layer):
    """Compute a Grad-CAM heatmap of `traget_layer` for a single image.

    Args:
        model: Model to explain (switched to eval mode).
        image: Un-batched image tensor; it is moved to the global `device` and
            given a batch dimension.
        traget_layer: Layer whose activations are visualised. The misspelled
            name is kept so existing callers are unaffected.

    Returns:
        The heatmap produced by `GradCAM.generate_heatmap`.
    """
    model.eval()
    # Use the layer passed in rather than the global `target_layer`.
    grad_cam = GradCAM(model, traget_layer)
    try:
        image_tensor = image.to(device).unsqueeze(0)  # Add batch dimension
        return grad_cam.generate_heatmap(image_tensor)
    finally:
        # Always detach the hooks, even if the forward or backward pass fails.
        grad_cam.remove_hooks()
# One heatmap per selected image, in the same order as `images`.
heatmaps = [generate_RAM(model,image,target_layer) for image in images]

In [None]:
def plot_heatmap(image_tensor,heatmap):
    """Show an image, its heatmap and their overlay stacked vertically.

    Args:
        image_tensor: Image tensor of shape (C, H, W); each channel is min-max
            scaled to [0, 1] for display.
        heatmap: 2-D array with the same height and width as the image.
    """
    img = image_tensor.cpu().numpy()

    # Per-channel min-max scaling for display. Constant channels have zero
    # span; divide by 1 so they become zeros instead of NaN.
    channel_min = img.min(axis=(1, 2), keepdims=True)
    channel_span = img.max(axis=(1, 2), keepdims=True) - channel_min
    img_norm = (img - channel_min) / np.where(channel_span > 0, channel_span, 1)

    overlay = apply_heatmap(img, heatmap)

    plt.figure(figsize=(15, 4))
    panels = (
        ('Original Image', np.transpose(img_norm, (1, 2, 0)), None),
        ('Heatmap', heatmap, 'jet'),
        ('Overlay', overlay, None),
    )
    for row, (title, picture, cmap) in enumerate(panels, start=1):
        plt.subplot(3, 1, row)
        plt.imshow(picture, cmap=cmap)
        plt.title(title)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# First selected sample (closest to the smallest predicted size).
plot_heatmap(images[0],heatmaps[0])

In [None]:
# Third of the nine selected samples.
plot_heatmap(images[2],heatmaps[2])

In [None]:
# Fourth of the nine selected samples.
plot_heatmap(images[3],heatmaps[3])

In [None]:
# Second-to-last selected sample (near the largest predicted size).
plot_heatmap(images[-2],heatmaps[-2])

In [None]:
import torch
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
import numpy as np

def test_contrast_sensitivity(model, image, device="cpu"):
    """
    Scale an image by a range of global factors and plot the predicted size for each.

    The image is multiplied by 50 evenly spaced factors between 0 and 1
    (0 gives an all-zero image, 1 leaves it unchanged), and the model is
    evaluated on every scaled copy.

    Args:
        model: The trained model for size prediction. If it returns a tuple or
            list (e.g. size and class logit), the first element is used.
        image: A torch tensor (C, H, W), or a PIL Image that is converted with
            `to_tensor`.
        device: The device where the model is running ('cpu' or 'cuda').

    Returns:
        tuple: (contrast_values, predictions), the factors used and the
            predicted size for each of them.
    """
    contrast_values = np.linspace(0, 1, 50)  # from 0 (all-zero image) to 1 (unchanged image)
    predictions = []

    model = model.to(device)
    model.eval()

    image = image.to(device) if isinstance(image, torch.Tensor) else TF.to_tensor(image).to(device)

    with torch.no_grad():
        for contrast in contrast_values:
            adjusted_image = (image * float(contrast)).unsqueeze(0)  # add batch dimension
            output = model(adjusted_image)
            # Multi-head models return several outputs; the size is the first.
            if isinstance(output, (tuple, list)):
                output = output[0]
            predictions.append(output.item())

    # Plot
    plt.figure(figsize=(8, 5))
    plt.plot(contrast_values, predictions, marker='o', linestyle='-')
    plt.xlabel("Contrast Factor")
    plt.ylabel("Predicted Size")
    plt.title("Effect of global Contrast change on Size Prediction")
    plt.grid()
    plt.show()
    return contrast_values, predictions
    
# Run the sensitivity sweep on the third selected sample.
contrast_values,predictions = test_contrast_sensitivity(model, images[2], device=device)