In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [17]:
import sys
# sys.path.append('E:/Code/Spiking-Visual-attention-for-Medical-image-segmentation/models/TCSA_SNN')

import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR, CosineAnnealingWarmRestarts

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import SemanticSegmentationTarget

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
import cv2

import albumentations as A
from albumentations.pytorch import ToTensorV2

from tqdm import tqdm

from simple_utils import (
    save_checkpoint,
    load_checkpoint,
    get_loaders,
    check_accuracy,
    save_predictions_as_imgs
    )

import resnet_2_copy

from datetime import datetime
import glob
import random
# from fvcore.nn import FlopCountAnalysis
# from thop import profile
# from ptflops import get_model_complexity_info

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnet_2_copy.resnet34().to(device)
summary(model, input_size=(1, 256, 256), batch_size=32)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
        Snn_Conv2d-1      [32, 2, 64, 128, 128]           3,136
      BatchNorm3d1-2      [32, 64, 4, 128, 128]             128
     batch_norm_2d-3      [32, 2, 64, 128, 128]               0
        mem_update-4      [32, 2, 64, 128, 128]               0
        Snn_Conv2d-5        [32, 2, 64, 64, 64]          36,864
      BatchNorm3d1-6        [32, 64, 4, 64, 64]             128
     batch_norm_2d-7        [32, 2, 64, 64, 64]               0
        mem_update-8        [32, 2, 64, 64, 64]               0
        Snn_Conv2d-9        [32, 2, 64, 64, 64]          36,864
     BatchNorm3d2-10        [32, 64, 4, 64, 64]             128
   batch_norm_2d1-11        [32, 2, 64, 64, 64]               0
AdaptiveAvgPool3d-12           [32, 4, 1, 1, 1]               0
           Conv3d-13           [32, 1, 1, 1, 1]               4
             ReLU-14           [32, 1, 

In [7]:
import zipfile
import os

zip_path = "/content/drive/MyDrive/05-GP25-Mostafa-SpikingNeuralNetwork/Dataset/Modified_3_Brain_Tumor_Segmentation.zip"
extract_path = "/content/drive/MyDrive/05-GP25-Mostafa-SpikingNeuralNetwork/Dataset/Modified_3_Brain_Tumor_Segmentation Extracted"

# # Extract ZIP
# with zipfile.ZipFile(zip_path, 'r') as zip_ref:
#     zip_ref.extractall(extract_path)

# List extracted files
os.listdir(extract_path)

['Modified_3_Brain_Tumor_Segmentation']

In [8]:
IMG_DIR      =  "/content/drive/MyDrive/05-GP25-Mostafa-SpikingNeuralNetwork/Dataset/Modified_3_Brain_Tumor_Segmentation Extracted/Modified_3_Brain_Tumor_Segmentation/images"
MASK_DIR     =  "/content/drive/MyDrive/05-GP25-Mostafa-SpikingNeuralNetwork/Dataset/Modified_3_Brain_Tumor_Segmentation Extracted/Modified_3_Brain_Tumor_Segmentation/masks"
VAL_IMG_DIR  =  "/content/drive/MyDrive/05-GP25-Mostafa-SpikingNeuralNetwork/Dataset/Modified_3_Brain_Tumor_Segmentation Extracted/Modified_3_Brain_Tumor_Segmentation/val_images"
VAL_MASK_DIR =  "/content/drive/MyDrive/05-GP25-Mostafa-SpikingNeuralNetwork/Dataset/Modified_3_Brain_Tumor_Segmentation Extracted/Modified_3_Brain_Tumor_Segmentation/val_masks"

In [22]:
Device = "cuda" if torch.cuda.is_available() else "cpu"
Learning_rate = 1e-3
Batch_size  = 32
num_epochs  = 60
num_workers = 4
IMAGE_HEIGHT = 256 # 256 -> 512 originally
IMAGE_WIDTH  = 256 # 256 -> 512 originally
PIN_MEMORY = True
LOAD_MODEL = False
CHECKPOINT_NAME = None

In [10]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)
    running_loss=0

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=Device)
        targets = targets.float().unsqueeze(1).to(device=Device)

        with torch.amp.autocast(device_type=Device):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        running_loss+=loss.item()

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loop.set_postfix(loss=loss.item())

    nasar = model.calculate_nasar()

    return running_loss/len(loader), nasar

In [23]:
def main():
    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0],
                std=[1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ]
    )
    val_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0],
                std=[1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ]
    )
    model = resnet_2_copy.resnet34().to(device)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=Learning_rate)
    # scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)
    # scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    train_loader, val_loader = get_loaders(
        IMG_DIR,
        MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        Batch_size,
        train_transform,
        val_transform,
        num_workers,
        PIN_MEMORY,
    )

    train_losses=[]
    val_dice_scores=[]
    val_accs=[]
    train_nasar=[]

    if LOAD_MODEL:
        load_checkpoint(model=model, optimizer=optimizer, checkpoint_name=CHECKPOINT_NAME)
        val_acc_loaded, val_dice_loaded = check_accuracy(val_loader, model, device=Device)

    scaler = torch.amp.GradScaler()

    for epoch in range(num_epochs):
        train_loss, nasar = train_fn(train_loader, model, optimizer, loss_fn, scaler)

        val_acc, val_dice = check_accuracy(val_loader, model, device=Device)

        train_losses.append(train_loss)
        val_accs.append(val_acc.cpu().item())
        val_dice_scores.append(val_dice.cpu().item())
        train_nasar.append(nasar)

        scheduler.step(val_dice)
        print(f"Epoch {epoch}: LR = {optimizer.param_groups[0]['lr']}")

        checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        }

        checkpoint_filename = save_checkpoint(checkpoint, is_best=len(val_dice_scores) == 1 or val_dice_scores[-1] > max(val_dice_scores[:-1]))

        if epoch != num_epochs - 1:
            save_predictions_as_imgs(
                val_loader,
                model,
                checkpoint_filename=checkpoint_filename,
                train_losses=train_losses,
                val_accs=val_accs,
                val_dice_scores=val_dice_scores,
                train_nasar=train_nasar,
                folder="Att_Res_SNN_saved_images/",
                device=Device,
                show_last_epoch=False,
            )
        else:
            save_predictions_as_imgs(
                val_loader,
                model,
                checkpoint_filename=checkpoint_filename,
                train_losses=train_losses,
                val_accs=val_accs,
                val_dice_scores=val_dice_scores,
                train_nasar=train_nasar,
                folder="Att_Res_SNN_saved_images/",
                device=Device,
                show_last_epoch=True,
            )

In [None]:
if __name__ == "__main__":
    main()

100%|██████████| 82/82 [03:02<00:00,  2.23s/it, loss=0.0646]


NASAR: 0.2411674803971125
Got 28818978/29360128 with acc  98.16
Dice score: 0.0
Epoch 0: LR = 0.001
✅ Checkpoint saved: Att_Res_SNN_checkpoint_7_2025-04-26_07-54-35.pth.tar


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0498]

NASAR: 0.2377642421095584





Got 28820961/29360128 with acc  98.16
Dice score: 0.010084839537739754
Epoch 1: LR = 0.0009999974905567724
✅ Checkpoint saved: Att_Res_SNN_checkpoint_8_2025-04-26_07-58-01.pth.tar


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0386]

NASAR: 0.226367046015923





Got 28853132/29360128 with acc  98.27
Dice score: 0.1394859403371811
Epoch 2: LR = 0.0009995200111414387
✅ Checkpoint saved: Att_Res_SNN_checkpoint_9_2025-04-26_08-01-27.pth.tar


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0399]

NASAR: 0.23205121358235678





Got 28891523/29360128 with acc  98.40
Dice score: 0.42168572545051575
Epoch 3: LR = 0.0009956189076349342
✅ Checkpoint saved: Att_Res_SNN_checkpoint_10_2025-04-26_08-04-53.pth.tar


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.04]

NASAR: 0.22824993939466878





Got 28965722/29360128 with acc  98.66
Dice score: 0.5357621908187866
Epoch 4: LR = 0.0009929342479555985
✅ Checkpoint saved: Att_Res_SNN_checkpoint_11_2025-04-26_08-08-18.pth.tar


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0352]


NASAR: 0.21532982839664944
Got 28931496/29360128 with acc  98.54
Dice score: 0.4165295660495758
Epoch 5: LR = 0.0009957252414749581
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:46<00:00,  2.03s/it, loss=0.0404]

NASAR: 0.21561234881620453





Got 28919844/29360128 with acc  98.50
Dice score: 0.44129478931427
Epoch 6: LR = 0.000995202646673508
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0367]

NASAR: 0.20712736067077922





Got 28971806/29360128 with acc  98.68
Dice score: 0.48480498790740967
Epoch 7: LR = 0.0009942119231972793
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0207]

NASAR: 0.2043639617346822





Got 28906082/29360128 with acc  98.45
Dice score: 0.3464163839817047
Epoch 8: LR = 0.0009970419337008163
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:46<00:00,  2.03s/it, loss=0.0288]

NASAR: 0.20855484546070369





Got 28993627/29360128 with acc  98.75
Dice score: 0.5519629716873169
Epoch 9: LR = 0.0009925015554820857
✅ Checkpoint saved: Att_Res_SNN_checkpoint_12_2025-04-26_08-25-25.pth.tar


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0323]

NASAR: 0.19480467066518578





Got 28947204/29360128 with acc  98.59
Dice score: 0.40866032242774963
Epoch 10: LR = 0.0009958850156043097
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0242]


NASAR: 0.18872458050508453
Got 29025991/29360128 with acc  98.86
Dice score: 0.6195486187934875
Epoch 11: LR = 0.0009905589767920447
✅ Checkpoint saved: Att_Res_SNN_checkpoint_13_2025-04-26_08-32-18.pth.tar


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0263]

NASAR: 0.18667477173424663





Got 28941719/29360128 with acc  98.57
Dice score: 0.5544072985649109
Epoch 12: LR = 0.0009924351634374167
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0247]

NASAR: 0.18110333474029397





Got 29021036/29360128 with acc  98.85
Dice score: 0.6567437052726746
Epoch 13: LR = 0.0009893954944816372
✅ Checkpoint saved: Att_Res_SNN_checkpoint_14_2025-04-26_08-39-09.pth.tar


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0311]

NASAR: 0.18566094877574366





Got 28969316/29360128 with acc  98.67
Dice score: 0.5065150260925293
Epoch 14: LR = 0.0009936830446198566
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:46<00:00,  2.03s/it, loss=0.0326]

NASAR: 0.1810575135996644





Got 28956869/29360128 with acc  98.63
Dice score: 0.46589845418930054
Epoch 15: LR = 0.0009946537791253372
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.021]

NASAR: 0.17138781122198687





Got 29009338/29360128 with acc  98.81
Dice score: 0.5731886029243469
Epoch 16: LR = 0.0009919153538282181
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0215]

NASAR: 0.17634305819659166





Got 28892966/29360128 with acc  98.41
Dice score: 0.24139411747455597
Epoch 17: LR = 0.000998562906481722
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:46<00:00,  2.03s/it, loss=0.019]

NASAR: 0.16845177484789925





Got 29050320/29360128 with acc  98.94
Dice score: 0.6719690561294556
Epoch 18: LR = 0.0009888999531026843
✅ Checkpoint saved: Att_Res_SNN_checkpoint_15_2025-04-26_08-56-18.pth.tar


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0246]

NASAR: 0.1661270392332838





Got 28869788/29360128 with acc  98.33
Dice score: 0.5801485776901245
Epoch 19: LR = 0.0009917183713520198
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0213]

NASAR: 0.16741540398396237





Got 28901679/29360128 with acc  98.44
Dice score: 0.35947301983833313
Epoch 20: LR = 0.0009968149900018955
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:46<00:00,  2.03s/it, loss=0.0185]

NASAR: 0.16299460863283544





Got 29059594/29360128 with acc  98.98
Dice score: 0.7204538583755493
Epoch 21: LR = 0.0009872474411706366
✅ Checkpoint saved: Att_Res_SNN_checkpoint_16_2025-04-26_09-06-35.pth.tar


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.023]

NASAR: 0.1562346964374954





Got 29057194/29360128 with acc  98.97
Dice score: 0.7253310084342957
Epoch 22: LR = 0.00098707494934734
✅ Checkpoint saved: Att_Res_SNN_checkpoint_17_2025-04-26_09-10-00.pth.tar


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0217]

NASAR: 0.15112772122235366





Got 29078136/29360128 with acc  99.04
Dice score: 0.7354661822319031
Epoch 23: LR = 0.0009867128383427143
✅ Checkpoint saved: Att_Res_SNN_checkpoint_18_2025-04-26_09-13-27.pth.tar


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0189]

NASAR: 0.1509432725503411





Got 29033519/29360128 with acc  98.89
Dice score: 0.6945561766624451
Epoch 24: LR = 0.0009881442038521965
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:46<00:00,  2.03s/it, loss=0.0262]

NASAR: 0.14906045080910266





Got 29055953/29360128 with acc  98.96
Dice score: 0.6691561937332153
Epoch 25: LR = 0.0009889923445223867
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0192]

NASAR: 0.1422230536948907





Got 29064847/29360128 with acc  98.99
Dice score: 0.6915758848190308
Epoch 26: LR = 0.0009882453288619995
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:46<00:00,  2.03s/it, loss=0.0148]

NASAR: 0.14514716242400694





Got 29067157/29360128 with acc  99.00
Dice score: 0.6770237684249878
Epoch 27: LR = 0.0009887329627215901
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0125]

NASAR: 0.15225164207494316





Got 29027424/29360128 with acc  98.87
Dice score: 0.6380624175071716
Epoch 28: LR = 0.0009899882164741338
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:46<00:00,  2.03s/it, loss=0.0196]

NASAR: 0.15033763115394844





Got 29077111/29360128 with acc  99.04
Dice score: 0.7034572958946228
Epoch 29: LR = 0.0009878396238515936
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0122]

NASAR: 0.14582393986518394





Got 29058638/29360128 with acc  98.97
Dice score: 0.6836495995521545
Epoch 30: LR = 0.0009885121994491993
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0192]

NASAR: 0.1465805841723518





Got 29052262/29360128 with acc  98.95
Dice score: 0.6504748463630676
Epoch 31: LR = 0.000989596272021652
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0138]

NASAR: 0.146952060466641





Got 29047719/29360128 with acc  98.94
Dice score: 0.7116969227790833
Epoch 32: LR = 0.0009875542817244746
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:46<00:00,  2.03s/it, loss=0.0165]

NASAR: 0.14871896376632188





Got 29065674/29360128 with acc  99.00
Dice score: 0.7365753054618835
Epoch 33: LR = 0.0009866729108614408
✅ Checkpoint saved: Att_Res_SNN_checkpoint_19_2025-04-26_09-47-43.pth.tar


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0186]

NASAR: 0.14114510621263388





Got 29099935/29360128 with acc  99.11
Dice score: 0.7505325675010681
Epoch 34: LR = 0.0009861654219880652
✅ Checkpoint saved: Att_Res_SNN_checkpoint_20_2025-04-26_09-51-08.pth.tar


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0132]

NASAR: 0.12991536279239566





Got 29043665/29360128 with acc  98.92
Dice score: 0.644734799861908
Epoch 35: LR = 0.0009897784516366204
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:46<00:00,  2.03s/it, loss=0.014]

NASAR: 0.1329233232238483





Got 29057372/29360128 with acc  98.97
Dice score: 0.6826799511909485
Epoch 36: LR = 0.0009885446387025737
 Checkpoint not saved as best model.


100%|██████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0228]

NASAR: 0.13422568415252256





Got 29039173/29360128 with acc  98.91
Dice score: 0.6687406897544861
Epoch 37: LR = 0.0009890059600737885
 Checkpoint not saved as best model.


  0%|          | 0/82 [00:00<?, ?it/s]

## Grad-CAM

In [15]:
def load_and_preprocess_image(image_path):
    transform = A.Compose([
    A.Resize(height=256, width=256),
    A.Normalize(mean=(0.0), std=(1.0), max_pixel_value=255.0),
    ToTensorV2(),
    ])
    image = Image.open(image_path).convert("L")
    image_rgb = Image.open(image_path).convert("RGB")
    processed = transform(image=np.array(image))
    input_tensor = processed['image'].unsqueeze(0)
    vis_image = np.array(image_rgb.resize((256, 256))) / 255.0
    return input_tensor, vis_image

def reshape_transform_csa(tensor):
    # If tensor shape is [B, C, T, H, W], convert to [B, C*T, H, W]
    if tensor.ndim == 5:
        B, C, T, H, W = tensor.size()
        return tensor.permute(0, 2, 1, 3, 4).reshape(B, T * C, H, W)
    elif tensor.ndim == 4:
        return tensor  # Already [B, C, H, W]
    else:
        raise ValueError(f"Unexpected tensor shape: {tensor.shape}")

def reshape_transform_spiking(tensor):
    """Special reshape for spiking networks with attention"""
    if tensor.ndim == 5:  # [T,B,C,H,W]
        # For attention weights
        if tensor.shape[2] == 1:  # Attention mask case
            return tensor.mean(dim=0).mean(dim=1)  # [B,H,W]

        # For regular feature maps
        B, C, T, H, W = tensor.permute(1,2,0,3,4).shape
        return tensor.permute(1,0,2,3,4).reshape(B, T*C, H, W)

    elif tensor.ndim == 4:  # [B,C,H,W]
        return tensor

    raise ValueError(f"Unexpected tensor shape: {tensor.shape}")

def generate_gradcam_grid_multi_layers(
    model,
    image_paths,
    mask_paths,
    target_layers_dict,
    device="cuda",
    save_folder="Att_Res_SNN_gradcam_results",
    grid_rows=3,
    grid_cols=6
    ):
    os.makedirs(save_folder, exist_ok=True)
    model.to(device).eval()

    selected = random.sample(list(zip(image_paths, mask_paths)), grid_rows * grid_cols)

    for layer_name, target_layer in target_layers_dict.items():

        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        layer_folder_name = f"{layer_name.replace('.', '_').replace('[', '_').replace(']', '')}_{timestamp}"
        layer_folder = os.path.join(save_folder, layer_folder_name)
        os.makedirs(layer_folder, exist_ok=True)

        cam = GradCAM(model=model, target_layers=[target_layer], reshape_transform=reshape_transform_csa)
        # cam = GradCAM(model=model, target_layers=[target_layer], reshape_transform=reshape_transform_spiking)

        fig_cam, axes_cam = plt.subplots(grid_rows, grid_cols, figsize=(grid_cols * 4, grid_rows * 4))
        fig_mask, axes_mask = plt.subplots(grid_rows, grid_cols, figsize=(grid_cols * 4, grid_rows * 4))
        fig_idx = 0

        for img_path, mask_path in selected:
            input_tensor, vis_image = load_and_preprocess_image(img_path)
            input_tensor = input_tensor.to(device)

            mask_tensor, _ = load_and_preprocess_image(mask_path)
            input_mask_tensor = mask_tensor.squeeze(0).cpu().numpy()

            targets = [SemanticSegmentationTarget(0, input_mask_tensor)]
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]
            binary_seg_mask = (grayscale_cam > 0.8).astype(np.uint8)

            # Save CAM image
            cam_image = show_cam_on_image(vis_image, grayscale_cam, use_rgb=True)
            cam_save_path = os.path.join(layer_folder, f"cam_{os.path.basename(img_path)}")
            plt.imsave(cam_save_path, cam_image)

            # Save binary segmentation mask
            bin_mask_path = os.path.join(layer_folder, f"mask_{os.path.basename(img_path)}")
            plt.imsave(bin_mask_path, binary_seg_mask, cmap='gray')

            row = fig_idx // grid_cols
            col = fig_idx % grid_cols

            if row < grid_rows:
                axes_cam[row, col].imshow(cam_image)
                axes_cam[row, col].axis("off")
                axes_cam[row, col].set_title(f"{os.path.basename(img_path)}")

                axes_mask[row, col].imshow(binary_seg_mask, cmap='gray')
                axes_mask[row, col].axis("off")
                axes_mask[row, col].set_title(f"{os.path.basename(img_path)}")

                fig_idx += 1

        # Save full CAM grid
        grid_cam_path = os.path.join(layer_folder, f"grid_cam_{layer_name}_{timestamp}.png")
        plt.tight_layout()
        fig_cam.savefig(grid_cam_path)
        plt.close(fig_cam)

        # Save full binary segmentation mask grid
        grid_mask_path = os.path.join(layer_folder, f"grid_mask_{layer_name}_{timestamp}.png")
        plt.tight_layout()
        fig_mask.savefig(grid_mask_path)
        plt.close(fig_mask)

        print(f"✅ Saved {grid_cam_path}")
        print(f"✅ Saved {grid_mask_path}")

In [20]:
model = resnet_2_copy.resnet34().to(device)
checkpoint = torch.load("./Att_Res_SNN_checkpoints/Att_Res_SNN_checkpoint_6_2025-04-26_06-49-48.pth.tar", map_location="cuda")
model.load_state_dict(checkpoint["state_dict"])

image_paths = sorted(glob.glob(os.path.join(VAL_IMG_DIR, "*.png")))
mask_paths = sorted(glob.glob(os.path.join(VAL_MASK_DIR, "*.png")))

target_layers = {
    "conv4_attention_sa": model.conv4_x[-1].residual_function[-1].attention.sa.conv,
    "conv4_residual":     model.conv4_x[-1].residual_function[-1],
    "conv4":              model.conv4_x[-1],

    "conv5_attention_sa": model.conv5_x[-1].residual_function[-1].attention.sa.conv,
    "conv5_residual":     model.conv5_x[-1].residual_function[-1],
    "conv5":              model.conv5_x[-1],

    "upsampling_layer[0]": model.upsampling_layer[0],
    "upsampling_layer[1]": model.upsampling_layer[1],
    "upsampling_layer[2]": model.upsampling_layer[2],
    "upsampling_layer[3]": model.upsampling_layer[3],
    "upsampling_layer[4]": model.upsampling_layer[4],
}

generate_gradcam_grid_multi_layers(
    model=model,
    image_paths=image_paths,
    mask_paths=mask_paths,
    target_layers_dict=target_layers,
    device=device,
    save_folder="Att_Res_SNN_gradcam_results",
    grid_rows=3,
    grid_cols=6
)

✅ Saved Att_Res_SNN_gradcam_results/conv4_attention_sa_20250426-071052/grid_cam_conv4_attention_sa_20250426-071052.png
✅ Saved Att_Res_SNN_gradcam_results/conv4_attention_sa_20250426-071052/grid_mask_conv4_attention_sa_20250426-071052.png
✅ Saved Att_Res_SNN_gradcam_results/conv4_residual_20250426-071101/grid_cam_conv4_residual_20250426-071101.png
✅ Saved Att_Res_SNN_gradcam_results/conv4_residual_20250426-071101/grid_mask_conv4_residual_20250426-071101.png
✅ Saved Att_Res_SNN_gradcam_results/conv4_20250426-071110/grid_cam_conv4_20250426-071110.png
✅ Saved Att_Res_SNN_gradcam_results/conv4_20250426-071110/grid_mask_conv4_20250426-071110.png
✅ Saved Att_Res_SNN_gradcam_results/conv5_attention_sa_20250426-071118/grid_cam_conv5_attention_sa_20250426-071118.png
✅ Saved Att_Res_SNN_gradcam_results/conv5_attention_sa_20250426-071118/grid_mask_conv5_attention_sa_20250426-071118.png
✅ Saved Att_Res_SNN_gradcam_results/conv5_residual_20250426-071126/grid_cam_conv5_residual_20250426-071126.png
✅

In [21]:
from google.colab import files

!zip -r Att_Res_SNN_gradcam_results.zip Att_Res_SNN_gradcam_results
files.download('Att_Res_SNN_gradcam_results.zip')

!zip -r Att_Res_SNN_saved_images.zip Att_Res_SNN_saved_images
files.download('Att_Res_SNN_saved_images.zip')

  adding: Att_Res_SNN_gradcam_results/ (stored 0%)
  adding: Att_Res_SNN_gradcam_results/conv5_attention_sa_20250426-071118/ (stored 0%)
  adding: Att_Res_SNN_gradcam_results/conv5_attention_sa_20250426-071118/cam_1952.png (deflated 1%)
  adding: Att_Res_SNN_gradcam_results/conv5_attention_sa_20250426-071118/mask_1952.png (deflated 81%)
  adding: Att_Res_SNN_gradcam_results/conv5_attention_sa_20250426-071118/mask_1936.png (deflated 81%)
  adding: Att_Res_SNN_gradcam_results/conv5_attention_sa_20250426-071118/cam_2676.png (deflated 1%)
  adding: Att_Res_SNN_gradcam_results/conv5_attention_sa_20250426-071118/cam_1529.png (deflated 1%)
  adding: Att_Res_SNN_gradcam_results/conv5_attention_sa_20250426-071118/cam_1539.png (deflated 1%)
  adding: Att_Res_SNN_gradcam_results/conv5_attention_sa_20250426-071118/cam_1936.png (deflated 1%)
  adding: Att_Res_SNN_gradcam_results/conv5_attention_sa_20250426-071118/cam_13.png (deflated 0%)
  adding: Att_Res_SNN_gradcam_results/conv5_attention_sa_2025

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

  adding: Att_Res_SNN_saved_images/ (stored 0%)
  adding: Att_Res_SNN_saved_images/Att_Res_SNN_checkpoint_5_2025-04-26_06-42-55/ (stored 0%)
  adding: Att_Res_SNN_saved_images/Att_Res_SNN_checkpoint_5_2025-04-26_06-42-55/5.png (deflated 35%)
  adding: Att_Res_SNN_saved_images/Att_Res_SNN_checkpoint_5_2025-04-26_06-42-55/pred_12-20250426-064255.png (deflated 91%)
  adding: Att_Res_SNN_saved_images/Att_Res_SNN_checkpoint_5_2025-04-26_06-42-55/original_1.png (deflated 1%)
  adding: Att_Res_SNN_saved_images/Att_Res_SNN_checkpoint_5_2025-04-26_06-42-55/original_13.png (deflated 1%)
  adding: Att_Res_SNN_saved_images/Att_Res_SNN_checkpoint_5_2025-04-26_06-42-55/pred_13-20250426-064255.png (deflated 96%)
  adding: Att_Res_SNN_saved_images/Att_Res_SNN_checkpoint_5_2025-04-26_06-42-55/pred_5-20250426-064255.png (deflated 87%)
  adding: Att_Res_SNN_saved_images/Att_Res_SNN_checkpoint_5_2025-04-26_06-42-55/pred_11-20250426-064255.png (deflated 80%)
  adding: Att_Res_SNN_saved_images/Att_Res_SNN_c

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>