In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import sys
# sys.path.append('E:/Code/Spiking-Visual-attention-for-Medical-image-segmentation/models/TCSA_SNN')

import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR, CosineAnnealingWarmRestarts

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import SemanticSegmentationTarget

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
import cv2

import albumentations as A
from albumentations.pytorch import ToTensorV2

from tqdm import tqdm

from simple_utils import (
    save_checkpoint,
    load_checkpoint,
    get_loaders,
    check_accuracy,
    save_predictions_as_imgs
    )

import resnet_2_copy

from datetime import datetime
import glob
import random
from fvcore.nn import FlopCountAnalysis
from thop import profile
from ptflops import get_model_complexity_info

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnet_2_copy.resnet34().to(device)
summary(model, input_size=(1, 256, 256), batch_size=32)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
        Snn_Conv2d-1      [32, 2, 64, 128, 128]           3,136
      BatchNorm3d1-2      [32, 64, 4, 128, 128]             128
     batch_norm_2d-3      [32, 2, 64, 128, 128]               0
        mem_update-4      [32, 2, 64, 128, 128]               0
        Snn_Conv2d-5        [32, 2, 64, 64, 64]          36,864
      BatchNorm3d1-6        [32, 64, 4, 64, 64]             128
     batch_norm_2d-7        [32, 2, 64, 64, 64]               0
        mem_update-8        [32, 2, 64, 64, 64]               0
        Snn_Conv2d-9        [32, 2, 64, 64, 64]          36,864
     BatchNorm3d2-10        [32, 64, 4, 64, 64]             128
   batch_norm_2d1-11        [32, 2, 64, 64, 64]               0
AdaptiveAvgPool3d-12           [32, 4, 1, 1, 1]               0
           Conv3d-13           [32, 1, 1, 1, 1]               4
             ReLU-14           [32, 1, 

In [None]:
import zipfile
import os

zip_path = "/content/drive/MyDrive/05-GP25-Mostafa-SpikingNeuralNetwork/Dataset/Modified_3_Brain_Tumor_Segmentation.zip"
extract_path = "/content/drive/MyDrive/05-GP25-Mostafa-SpikingNeuralNetwork/Dataset/Modified_3_Brain_Tumor_Segmentation Extracted"

# # Extract ZIP
# with zipfile.ZipFile(zip_path, 'r') as zip_ref:
#     zip_ref.extractall(extract_path)

# List extracted files
os.listdir(extract_path)

In [None]:
IMG_DIR      =  "/content/drive/MyDrive/05-GP25-Mostafa-SpikingNeuralNetwork/Dataset/Modified_3_Brain_Tumor_Segmentation Extracted/Modified_3_Brain_Tumor_Segmentation/images"
MASK_DIR     =  "/content/drive/MyDrive/05-GP25-Mostafa-SpikingNeuralNetwork/Dataset/Modified_3_Brain_Tumor_Segmentation Extracted/Modified_3_Brain_Tumor_Segmentation/masks"
VAL_IMG_DIR  =  "/content/drive/MyDrive/05-GP25-Mostafa-SpikingNeuralNetwork/Dataset/Modified_3_Brain_Tumor_Segmentation Extracted/Modified_3_Brain_Tumor_Segmentation/val_images"
VAL_MASK_DIR =  "/content/drive/MyDrive/05-GP25-Mostafa-SpikingNeuralNetwork/Dataset/Modified_3_Brain_Tumor_Segmentation Extracted/Modified_3_Brain_Tumor_Segmentation/val_masks"

In [None]:
Device = "cuda" if torch.cuda.is_available() else "cpu"
Learning_rate = 1e-3 # 1e-3 -> 1e-4 -> 5e-5 
Batch_size  = 32
num_epochs  = 3
num_workers = 4
IMAGE_HEIGHT = 256 # 256 -> 512 originally
IMAGE_WIDTH  = 256 # 256 -> 512 originally
PIN_MEMORY = True
LOAD_MODEL = False
CHECKPOINT_NAME = None

In [4]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)
    running_loss=0

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=Device)
        targets = targets.float().unsqueeze(1).to(device=Device) 

        with torch.amp.autocast(device_type=Device):
            predictions = model(data)
            loss = loss_fn(predictions, targets)
        
        running_loss+=loss.item()

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        loop.set_postfix(loss=loss.item())
        
    nasar = model.calculate_nasar()

    return running_loss/len(loader), nasar 

In [None]:
def main():
    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0],
                std=[1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ]
    )
    val_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0],
                std=[1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ]
    )
    model = resnet_2_copy.resnet34().to(device)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=Learning_rate)
    # scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)
    # scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)  
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    train_loader, val_loader = get_loaders(
        IMG_DIR,
        MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        Batch_size,
        train_transform,
        val_transform,
        num_workers,
        PIN_MEMORY,
    )

    train_losses=[]
    val_dice_scores=[]
    val_accs=[]
    train_nasar=[]

    if LOAD_MODEL:
        load_checkpoint(model=model, optimizer=optimizer, checkpoint_name=CHECKPOINT_NAME)
        val_acc_loaded, val_dice_loaded = check_accuracy(val_loader, model, device=Device)

    scaler = torch.amp.GradScaler()

    for epoch in range(num_epochs):
        train_loss, nasar = train_fn(train_loader, model, optimizer, loss_fn, scaler)

        val_acc, val_dice = check_accuracy(val_loader, model, device=Device)

        train_losses.append(train_loss)
        val_accs.append(val_acc.cpu().item())
        val_dice_scores.append(val_dice.cpu().item())
        train_nasar.append(nasar)

        scheduler.step(val_dice)
        print(f"Epoch {epoch}: LR = {optimizer.param_groups[0]['lr']}")

        checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        }

        checkpoint_filename = save_checkpoint(checkpoint, is_best=len(val_dice_scores) == 1 or val_dice_scores[-1] > max(val_dice_scores[:-1]))
        
        if epoch != num_epochs - 1:
            save_predictions_as_imgs(
                val_loader, 
                model, 
                checkpoint_filename=checkpoint_filename, 
                train_losses=train_losses,
                val_accs=val_accs,
                val_dice_scores=val_dice_scores,
                train_nasar=train_nasar,
                folder="Att_Res_SNN_saved_images/", 
                device=Device, 
                show_last_epoch=False,
            )
        else:
            save_predictions_as_imgs(
                val_loader, 
                model, 
                checkpoint_filename=checkpoint_filename, 
                train_losses=train_losses,
                val_accs=val_accs,
                val_dice_scores=val_dice_scores,
                train_nasar=train_nasar, 
                folder="Att_Res_SNN_saved_images/", 
                device=Device, 
                show_last_epoch=True,
            )

In [6]:
if __name__ == "__main__": 
    main()

100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:40<00:00,  1.96s/it, loss=0.0715]


NASAR: 0.17817558145299203
Got 28818978/29360128 with acc  98.16
Dice score: 0.0
Epoch 0: LR = 0.001
✅ Checkpoint saved: Att_Res_SNN_checkpoint_24_2025-04-23_22-07-08.pth.tar


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:27<00:00,  1.80s/it, loss=0.0572]


NASAR: 0.16996044284301184
Got 28822350/29360128 with acc  98.17
Dice score: 0.013110661879181862
Epoch 1: LR = 0.0009999957588034576
✅ Checkpoint saved: Att_Res_SNN_checkpoint_25_2025-04-23_22-10-25.pth.tar


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:07<00:00,  1.56s/it, loss=0.0453]


NASAR: 0.16466474756948266
Got 28852617/29360128 with acc  98.27
Dice score: 0.1977100521326065
Epoch 2: LR = 0.0009990358210256326
✅ Checkpoint saved: Att_Res_SNN_checkpoint_26_2025-04-23_22-13-07.pth.tar


100%|████████████████████████████████████████████████████████████████████████| 82/82 [02:15<00:00,  1.65s/it, loss=0.042]


NASAR: 0.1623323915150244
Got 28897474/29360128 with acc  98.42
Dice score: 0.3092058598995209
Epoch 3: LR = 0.0009976428150025285
✅ Checkpoint saved: Att_Res_SNN_checkpoint_27_2025-04-23_22-16-05.pth.tar


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:32<00:00,  1.87s/it, loss=0.0374]


NASAR: 0.1540888396786972
Got 28937078/29360128 with acc  98.56
Dice score: 0.4503011703491211
Epoch 4: LR = 0.0009950051605806265
✅ Checkpoint saved: Att_Res_SNN_checkpoint_28_2025-04-23_22-19-29.pth.tar


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:31<00:00,  1.85s/it, loss=0.0374]


NASAR: 0.1578169451073302
Got 28988804/29360128 with acc  98.74
Dice score: 0.5433835387229919
Epoch 5: LR = 0.0009927322858058294
✅ Checkpoint saved: Att_Res_SNN_checkpoint_29_2025-04-23_22-22-53.pth.tar


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:38<00:00,  1.93s/it, loss=0.0279]


NASAR: 0.16025331882243984
Got 28959912/29360128 with acc  98.64
Dice score: 0.4679892659187317
Epoch 6: LR = 0.000994605774009284
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:36<00:00,  1.91s/it, loss=0.0216]


NASAR: 0.151043842655952
Got 28925617/29360128 with acc  98.52
Dice score: 0.4064474403858185
Epoch 7: LR = 0.0009959294005798221
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:36<00:00,  1.91s/it, loss=0.0251]


NASAR: 0.15393095956721775
Got 28987836/29360128 with acc  98.73
Dice score: 0.5394655466079712
Epoch 8: LR = 0.0009928364641767402
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:40<00:00,  1.96s/it, loss=0.0224]


NASAR: 0.15220692236098884
Got 28961514/29360128 with acc  98.64
Dice score: 0.5060362219810486
Epoch 9: LR = 0.0009936949558793935
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:40<00:00,  1.96s/it, loss=0.0273]


NASAR: 0.13394965588206975
Got 28877679/29360128 with acc  98.36
Dice score: 0.30485010147094727
Epoch 10: LR = 0.0009977087078576862
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:44<00:00,  2.01s/it, loss=0.0194]


NASAR: 0.14917868851496022
Got 28917120/29360128 with acc  98.49
Dice score: 0.4313684105873108
Epoch 11: LR = 0.0009954157141027515
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:40<00:00,  1.96s/it, loss=0.0233]


NASAR: 0.15220760291730853
Got 29005380/29360128 with acc  98.79
Dice score: 0.5640794634819031
Epoch 12: LR = 0.0009921696078539486
✅ Checkpoint saved: Att_Res_SNN_checkpoint_30_2025-04-23_22-44-26.pth.tar


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:40<00:00,  1.95s/it, loss=0.0275]


NASAR: 0.14314699844575265
Got 29009491/29360128 with acc  98.81
Dice score: 0.6353523135185242
Epoch 13: LR = 0.0009900728037297998
✅ Checkpoint saved: Att_Res_SNN_checkpoint_31_2025-04-23_22-48-01.pth.tar


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:39<00:00,  1.95s/it, loss=0.0189]


NASAR: 0.1457181044027839
Got 28996004/29360128 with acc  98.76
Dice score: 0.5775715708732605
Epoch 14: LR = 0.0009917915817692772
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:38<00:00,  1.93s/it, loss=0.0269]


NASAR: 0.12037666750625825
Got 28922736/29360128 with acc  98.51
Dice score: 0.35283035039901733
Epoch 15: LR = 0.0009969314943700362
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:39<00:00,  1.95s/it, loss=0.0193]


NASAR: 0.1259999521461451
Got 29011755/29360128 with acc  98.81
Dice score: 0.5938996076583862
Epoch 16: LR = 0.000991322279638263
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:45<00:00,  2.01s/it, loss=0.0222]


NASAR: 0.1306579571934373
Got 29018141/29360128 with acc  98.84
Dice score: 0.5995096564292908
Epoch 17: LR = 0.0009911580508071638
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.0238]


NASAR: 0.11784682475345236
Got 29013787/29360128 with acc  98.82
Dice score: 0.6395869255065918
Epoch 18: LR = 0.00098994047874294
✅ Checkpoint saved: Att_Res_SNN_checkpoint_32_2025-04-23_23-04-08.pth.tar


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:41<00:00,  1.97s/it, loss=0.0183]


NASAR: 0.1110584814223885
Got 29043952/29360128 with acc  98.92
Dice score: 0.6806653141975403
Epoch 19: LR = 0.000988611895284042
✅ Checkpoint saved: Att_Res_SNN_checkpoint_33_2025-04-23_23-07-49.pth.tar


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:42<00:00,  1.99s/it, loss=0.0176]


NASAR: 0.1287711327064765
Got 29056441/29360128 with acc  98.97
Dice score: 0.6649622321128845
Epoch 20: LR = 0.000989129393670344
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:40<00:00,  1.96s/it, loss=0.0212]


NASAR: 0.1283515070525693
Got 29002460/29360128 with acc  98.78
Dice score: 0.5869576334953308
Epoch 21: LR = 0.0009915233868191757
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:39<00:00,  1.94s/it, loss=0.0214]


NASAR: 0.12698346348435666
Got 29046667/29360128 with acc  98.93
Dice score: 0.6858803033828735
Epoch 22: LR = 0.0009884374010469094
✅ Checkpoint saved: Att_Res_SNN_checkpoint_34_2025-04-23_23-17-30.pth.tar


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:38<00:00,  1.93s/it, loss=0.0171]


NASAR: 0.11874572324081206
Got 28973170/29360128 with acc  98.68
Dice score: 0.5237757563591003
Epoch 23: LR = 0.000993246164536279
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:41<00:00,  1.97s/it, loss=0.0209]


NASAR: 0.12090217339600756
Got 29052051/29360128 with acc  98.95
Dice score: 0.7101943492889404
Epoch 24: LR = 0.0009876065611736623
✅ Checkpoint saved: Att_Res_SNN_checkpoint_35_2025-04-23_23-24-06.pth.tar


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:39<00:00,  1.95s/it, loss=0.0137]


NASAR: 0.1322679474879878
Got 29044528/29360128 with acc  98.93
Dice score: 0.6835374236106873
Epoch 25: LR = 0.000988515954584896
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:40<00:00,  1.95s/it, loss=0.0135]


NASAR: 0.11851171036841164
Got 29084656/29360128 with acc  99.06
Dice score: 0.7071134448051453
Epoch 26: LR = 0.0009877134124857737
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:40<00:00,  1.95s/it, loss=0.0126]


NASAR: 0.10499401719357486
Got 29080317/29360128 with acc  99.05
Dice score: 0.6957746148109436
Epoch 27: LR = 0.0009881027370423183
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:41<00:00,  1.97s/it, loss=0.0125]


NASAR: 0.10637917317135233
Got 29048183/29360128 with acc  98.94
Dice score: 0.676809549331665
Epoch 28: LR = 0.0009887400626600637
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:41<00:00,  1.97s/it, loss=0.0189]


NASAR: 0.11816528929231312
Got 28935983/29360128 with acc  98.56
Dice score: 0.502980649471283
Epoch 29: LR = 0.000993770710530915
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:41<00:00,  1.96s/it, loss=0.0147]


NASAR: 0.11666747437956188
Got 29092735/29360128 with acc  99.09
Dice score: 0.7477772831916809
Epoch 30: LR = 0.0009862663463985739
✅ Checkpoint saved: Att_Res_SNN_checkpoint_36_2025-04-23_23-42-44.pth.tar


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:37<00:00,  1.93s/it, loss=0.0166]


NASAR: 0.10383528051242022
Got 29078732/29360128 with acc  99.04
Dice score: 0.7252413034439087
Epoch 31: LR = 0.0009870781319062177
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:43<00:00,  1.99s/it, loss=0.0233]


NASAR: 0.09593905417572164
Got 29043453/29360128 with acc  98.92
Dice score: 0.6448948979377747
Epoch 32: LR = 0.0009897733910241562
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:48<00:00,  2.05s/it, loss=0.0158]


NASAR: 0.10684590944102113
Got 28940120/29360128 with acc  98.57
Dice score: 0.5383761525154114
Epoch 33: LR = 0.0009928652986044735
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:45<00:00,  2.01s/it, loss=0.0194]


NASAR: 0.11352419069675213
Got 29042829/29360128 with acc  98.92
Dice score: 0.6327813863754272
Epoch 34: LR = 0.0009901527161735504
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:47<00:00,  2.04s/it, loss=0.0127]


NASAR: 0.10860988455758967
Got 29078563/29360128 with acc  99.04
Dice score: 0.71254962682724
Epoch 35: LR = 0.0009875245661907083
 Checkpoint not saved as best model.


100%|████████████████████████████████████████████████████████████████████████| 82/82 [02:45<00:00,  2.02s/it, loss=0.011]


NASAR: 0.11719514730390808
Got 28908448/29360128 with acc  98.46
Dice score: 0.4969998002052307
Epoch 36: LR = 0.0009939176737009353
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:44<00:00,  2.01s/it, loss=0.0164]


NASAR: 0.10264275779186839
Got 29091700/29360128 with acc  99.09
Dice score: 0.745006799697876
Epoch 37: LR = 0.0009863674594057714
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:41<00:00,  1.97s/it, loss=0.0175]


NASAR: 0.110974799859132
Got 29088664/29360128 with acc  99.08
Dice score: 0.7562161684036255
Epoch 38: LR = 0.0009859560898222578
✅ Checkpoint saved: Att_Res_SNN_checkpoint_37_2025-04-24_00-08-03.pth.tar


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:29<00:00,  1.82s/it, loss=0.0197]


NASAR: 0.11504786674965156
Got 29025890/29360128 with acc  98.86
Dice score: 0.6268124580383301
Epoch 39: LR = 0.0009903370176485235
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:24<00:00,  1.76s/it, loss=0.0161]


NASAR: 0.10006689689528774
Got 29038460/29360128 with acc  98.90
Dice score: 0.6416898965835571
Epoch 40: LR = 0.0009898744438383328
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:37<00:00,  1.92s/it, loss=0.0125]


NASAR: 0.10326835918874248
Got 28999670/29360128 with acc  98.77
Dice score: 0.5843667984008789
Epoch 41: LR = 0.0009915978431609316
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:36<00:00,  1.90s/it, loss=0.0214]


NASAR: 0.11117604752661477
Got 29074405/29360128 with acc  99.03
Dice score: 0.7449815273284912
Epoch 42: LR = 0.000986368378652923
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:36<00:00,  1.91s/it, loss=0.0166]


NASAR: 0.1041858476092558
Got 29074920/29360128 with acc  99.03
Dice score: 0.7110733389854431
Epoch 43: LR = 0.0009875759895888853
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:39<00:00,  1.94s/it, loss=0.0135]


NASAR: 0.10461080130277105
Got 29104295/29360128 with acc  99.13
Dice score: 0.7480543255805969
Epoch 44: LR = 0.0009862562165708995
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:39<00:00,  1.95s/it, loss=0.0134]


NASAR: 0.09874134332361356
Got 29104525/29360128 with acc  99.13
Dice score: 0.763681173324585
Epoch 45: LR = 0.0009856787890214088
✅ Checkpoint saved: Att_Res_SNN_checkpoint_38_2025-04-24_00-28-49.pth.tar


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:38<00:00,  1.93s/it, loss=0.0127]


NASAR: 0.10631259506297223
Got 29071246/29360128 with acc  99.02
Dice score: 0.7131072282791138
Epoch 46: LR = 0.0009875051132405225
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:34<00:00,  1.89s/it, loss=0.0138]


NASAR: 0.11108578426737181
Got 29099849/29360128 with acc  99.11
Dice score: 0.7594100832939148
Epoch 47: LR = 0.000985837771512344
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:29<00:00,  1.82s/it, loss=0.0182]


NASAR: 0.10171541025940802
Got 28999542/29360128 with acc  98.77
Dice score: 0.6217542886734009
Epoch 48: LR = 0.000990491847296509
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:28<00:00,  1.81s/it, loss=0.0168]


NASAR: 0.10912358816800542
Got 29099063/29360128 with acc  99.11
Dice score: 0.745825469493866
Epoch 49: LR = 0.0009863376191287898
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:29<00:00,  1.82s/it, loss=0.0108]


NASAR: 0.11016732874050947
Got 29065427/29360128 with acc  99.00
Dice score: 0.7227364182472229
Epoch 50: LR = 0.0009871668550770877
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:37<00:00,  1.92s/it, loss=0.0137]


NASAR: 0.10596106981447605
Got 29097223/29360128 with acc  99.10
Dice score: 0.7410470247268677
Epoch 51: LR = 0.0009865113363936432
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:41<00:00,  1.97s/it, loss=0.0122]


NASAR: 0.10264109221982284
Got 29100462/29360128 with acc  99.12
Dice score: 0.7510010004043579
Epoch 52: LR = 0.0009861482291317158
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:41<00:00,  1.97s/it, loss=0.0134]


NASAR: 0.11104176301911403
Got 29097141/29360128 with acc  99.10
Dice score: 0.752902626991272
Epoch 53: LR = 0.000986078317139996
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:39<00:00,  1.94s/it, loss=0.0112]


NASAR: 0.1015220068989785
Got 29084123/29360128 with acc  99.06
Dice score: 0.7203543782234192
Epoch 54: LR = 0.0009872509487523112
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:40<00:00,  1.96s/it, loss=0.0104]


NASAR: 0.10810480431211947
Got 29086252/29360128 with acc  99.07
Dice score: 0.7383591532707214
Epoch 55: LR = 0.0009866085723903857
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:38<00:00,  1.94s/it, loss=0.0231]


NASAR: 0.09463636192357597
Got 29101866/29360128 with acc  99.12
Dice score: 0.7444708943367004
Epoch 56: LR = 0.000986386976471047
 Checkpoint not saved as best model.


100%|██████████████████████████████████████████████████████████████████████| 82/82 [02:39<00:00,  1.95s/it, loss=0.00948]


NASAR: 0.1063768807711176
Got 29106396/29360128 with acc  99.14
Dice score: 0.758739709854126
Epoch 57: LR = 0.0009858626448990087
 Checkpoint not saved as best model.


100%|████████████████████████████████████████████████████████████████████████| 82/82 [02:43<00:00,  1.99s/it, loss=0.012]


NASAR: 0.10324746790066572
Got 29092540/29360128 with acc  99.09
Dice score: 0.7274970412254333
Epoch 58: LR = 0.0009869979758940271
 Checkpoint not saved as best model.


100%|███████████████████████████████████████████████████████████████████████| 82/82 [02:42<00:00,  1.98s/it, loss=0.0116]


NASAR: 0.10616154737875495
Got 29083432/29360128 with acc  99.06
Dice score: 0.7266072630882263
Epoch 59: LR = 0.0009870296228820228
 Checkpoint not saved as best model.


## Grad-CAM

In [9]:
def load_and_preprocess_image(image_path):
    transform = A.Compose([
    A.Resize(height=256, width=256),
    A.Normalize(mean=(0.0), std=(1.0), max_pixel_value=255.0),
    ToTensorV2(),
    ])
    image = Image.open(image_path).convert("L")
    image_rgb = Image.open(image_path).convert("RGB")
    processed = transform(image=np.array(image))
    input_tensor = processed['image'].unsqueeze(0)
    vis_image = np.array(image_rgb.resize((256, 256))) / 255.0
    return input_tensor, vis_image

def reshape_transform_csa(tensor):
    # If tensor shape is [B, C, T, H, W], convert to [B, C*T, H, W]
    if tensor.ndim == 5:
        B, C, T, H, W = tensor.size()
        return tensor.permute(0, 2, 1, 3, 4).reshape(B, T * C, H, W)
    elif tensor.ndim == 4:
        return tensor  # Already [B, C, H, W]
    else:
        raise ValueError(f"Unexpected tensor shape: {tensor.shape}")
    
def reshape_transform_spiking(tensor):
    """Special reshape for spiking networks with attention"""
    if tensor.ndim == 5:  # [T,B,C,H,W]
        # For attention weights
        if tensor.shape[2] == 1:  # Attention mask case
            return tensor.mean(dim=0).mean(dim=1)  # [B,H,W]
        
        # For regular feature maps
        B, C, T, H, W = tensor.permute(1,2,0,3,4).shape
        return tensor.permute(1,0,2,3,4).reshape(B, T*C, H, W)
    
    elif tensor.ndim == 4:  # [B,C,H,W]
        return tensor
    
    raise ValueError(f"Unexpected tensor shape: {tensor.shape}")

def generate_gradcam_grid_multi_layers(
    model,
    image_paths,
    mask_paths,
    target_layers_dict,
    device="cuda",
    save_folder="Att_Res_SNN_gradcam_results",
    grid_rows=3,
    grid_cols=6
    ):
    os.makedirs(save_folder, exist_ok=True)
    model.to(device).eval()

    selected = random.sample(list(zip(image_paths, mask_paths)), grid_rows * grid_cols)

    for layer_name, target_layer in target_layers_dict.items():

        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        layer_folder_name = f"{layer_name.replace('.', '_').replace('[', '_').replace(']', '')}_{timestamp}"
        layer_folder = os.path.join(save_folder, layer_folder_name)
        os.makedirs(layer_folder, exist_ok=True)

        cam = GradCAM(model=model, target_layers=[target_layer], reshape_transform=reshape_transform_csa)
        # cam = GradCAM(model=model, target_layers=[target_layer], reshape_transform=reshape_transform_spiking)

        fig_cam, axes_cam = plt.subplots(grid_rows, grid_cols, figsize=(grid_cols * 4, grid_rows * 4))
        fig_mask, axes_mask = plt.subplots(grid_rows, grid_cols, figsize=(grid_cols * 4, grid_rows * 4))
        fig_idx = 0

        for img_path, mask_path in selected:
            input_tensor, vis_image = load_and_preprocess_image(img_path)
            input_tensor = input_tensor.to(device)

            mask_tensor, _ = load_and_preprocess_image(mask_path)
            input_mask_tensor = mask_tensor.squeeze(0).cpu().numpy()

            targets = [SemanticSegmentationTarget(0, input_mask_tensor)]
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]
            binary_seg_mask = (grayscale_cam > 0.8).astype(np.uint8)

            # Save CAM image
            cam_image = show_cam_on_image(vis_image, grayscale_cam, use_rgb=True)
            cam_save_path = os.path.join(layer_folder, f"cam_{os.path.basename(img_path)}")
            plt.imsave(cam_save_path, cam_image)

            # Save binary segmentation mask
            bin_mask_path = os.path.join(layer_folder, f"mask_{os.path.basename(img_path)}")
            plt.imsave(bin_mask_path, binary_seg_mask, cmap='gray')

            row = fig_idx // grid_cols
            col = fig_idx % grid_cols

            if row < grid_rows:
                axes_cam[row, col].imshow(cam_image)
                axes_cam[row, col].axis("off")
                axes_cam[row, col].set_title(f"{os.path.basename(img_path)}")

                axes_mask[row, col].imshow(binary_seg_mask, cmap='gray')
                axes_mask[row, col].axis("off")
                axes_mask[row, col].set_title(f"{os.path.basename(img_path)}")

                fig_idx += 1

        # Save full CAM grid
        grid_cam_path = os.path.join(layer_folder, f"grid_cam_{layer_name}_{timestamp}.png")
        plt.tight_layout()
        fig_cam.savefig(grid_cam_path)
        plt.close(fig_cam)

        # Save full binary segmentation mask grid
        grid_mask_path = os.path.join(layer_folder, f"grid_mask_{layer_name}_{timestamp}.png")
        plt.tight_layout()
        fig_mask.savefig(grid_mask_path)
        plt.close(fig_mask)

        print(f"✅ Saved {grid_cam_path}")
        print(f"✅ Saved {grid_mask_path}")

In [None]:
model = resnet_2_copy.resnet50().to(device)
checkpoint = torch.load("E:/Code/Attention-SNN/Att_Res_SNN_checkpoints/", map_location="cuda")
model.load_state_dict(checkpoint["state_dict"])

image_paths = sorted(glob.glob(os.path.join(VAL_IMG_DIR, "*.png")))
mask_paths = sorted(glob.glob(os.path.join(VAL_MASK_DIR, "*.png")))

target_layers = {
    "conv4_attention_sa": model.conv4_x[-1].residual_function[-1].attention.sa.conv,
    "conv4_residual":     model.conv4_x[-1].residual_function[-1],
    "conv4":              model.conv4_x[-1],

    "conv5_attention_sa": model.conv5_x[-1].residual_function[-1].attention.sa.conv,
    "conv5_residual":     model.conv5_x[-1].residual_function[-1],
    "conv5":              model.conv5_x[-1],

    "segmentation_head[0]": model.segmentation_head[0],
    "segmentation_head[1]": model.segmentation_head[1],
    "segmentation_head[2]": model.segmentation_head[2],
    "segmentation_head[3]": model.segmentation_head[3],
    "segmentation_head[4]": model.segmentation_head[4],
}

generate_gradcam_grid_multi_layers(
    model=model,
    image_paths=image_paths,
    mask_paths=mask_paths,
    target_layers_dict=target_layers,
    device=device,
    save_folder="Att_Res_SNN_gradcam_results",
    grid_rows=3,
    grid_cols=6
)

✅ Saved Att_Res_SNN_gradcam_results\conv4_attention_sa_20250424-021353\grid_cam_conv4_attention_sa_20250424-021353.png
✅ Saved Att_Res_SNN_gradcam_results\conv4_attention_sa_20250424-021353\grid_mask_conv4_attention_sa_20250424-021353.png
✅ Saved Att_Res_SNN_gradcam_results\conv4_residual_20250424-021405\grid_cam_conv4_residual_20250424-021405.png
✅ Saved Att_Res_SNN_gradcam_results\conv4_residual_20250424-021405\grid_mask_conv4_residual_20250424-021405.png
✅ Saved Att_Res_SNN_gradcam_results\conv4_20250424-021417\grid_cam_conv4_20250424-021417.png
✅ Saved Att_Res_SNN_gradcam_results\conv4_20250424-021417\grid_mask_conv4_20250424-021417.png
✅ Saved Att_Res_SNN_gradcam_results\conv5_attention_sa_20250424-021427\grid_cam_conv5_attention_sa_20250424-021427.png
✅ Saved Att_Res_SNN_gradcam_results\conv5_attention_sa_20250424-021427\grid_mask_conv5_attention_sa_20250424-021427.png
✅ Saved Att_Res_SNN_gradcam_results\conv5_residual_20250424-021437\grid_cam_conv5_residual_20250424-021437.png
✅