<a href="https://colab.research.google.com/github/weirdrazak/MSGDANet/blob/main/MSGDANet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install adan_pytorch albumentations

Collecting adan_pytorch
  Downloading adan_pytorch-0.1.0-py3-none-any.whl.metadata (661 bytes)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.6->adan_pytorch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.6->adan_pytorch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.6->adan_pytorch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.6->adan_pytorch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.6->adan_pytorch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torc

In [None]:
import os,sys
import torch
import albumentations as A
import pandas as pd
import numpy as np
from PIL import Image
import cv2
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
from google.colab import drive
import warnings  # To suppress specific sklearn warnings
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    cohen_kappa_score,
    precision_recall_curve,
    auc
)
from tabulate import tabulate  # For pretty printing metrics table
from torch.optim.lr_scheduler import PolynomialLR
from tqdm import tqdm
from adan_pytorch import Adan
from albumentations.pytorch import ToTensorV2
from collections import defaultdict

# VGG16-BN Encoder: extracts hierarchical features from input images
class VGGEncoder(nn.Module):
    """VGG16-BN feature extractor split into five sequential stages.

    The torchvision ``vgg16_bn`` feature stack is sliced so that stage 1 holds
    the first two conv/BN/ReLU triplets, and every later stage *begins* with
    the max-pool that closes the previous VGG block.  The last max-pool of the
    network (index 43) is not included in any stage, so the deepest feature
    map is 1/16 of the input resolution.

    Args:
        pretrained: Forwarded unchanged to ``models.vgg16_bn``.
    """

    def __init__(self,pretrained=True):
        super().__init__()

        # Load VGG16 with Batch Normalization
        vgg = models.vgg16_bn(pretrained=pretrained)

        # Flatten the feature extractor into a list so it can be sliced by index
        features = list(vgg.features.children())

        # Stage 1: conv1_1 .. relu1_2 (no pooling) -> (B, 64, H, W)
        self.stage1 = nn.Sequential(*features[0:6])

        # Stage 2: pool1 + conv2_1 .. relu2_2 -> (B, 128, H/2, W/2)
        self.stage2 = nn.Sequential(*features[6:13])

        # Stage 3: pool2 + conv3_1 .. relu3_3 -> (B, 256, H/4, W/4)
        self.stage3 = nn.Sequential(*features[13:23])

        # Stage 4: pool3 + conv4_1 .. relu4_3 -> (B, 512, H/8, W/8)
        self.stage4 = nn.Sequential(*features[23:33])

        # Stage 5: pool4 + conv5_1 .. relu5_3 -> (B, 512, H/16, W/16)
        self.stage5 = nn.Sequential(*features[33:43])

    def forward(self, x):
        """Run the five stages in order and return every intermediate map.

        Returns:
            Tuple ``(f1, f2, f3, f4, f5)`` of the stage outputs, shallowest first.
        """
        f1 = self.stage1(x)
        f2 = self.stage2(f1)
        f3 = self.stage3(f2)
        f4 = self.stage4(f3)
        f5 = self.stage5(f4)
        return f1, f2, f3, f4, f5



#MULTI SCALE ATTENTION BLOCK
# Combines multi-scale convolutions and self-attention for feature refinement
class MSAB(nn.Module):
    """Multi-Scale Attention Block.

    Fuses a 1x1 convolution branch, a 3x3 convolution branch and a spatial
    self-attention branch computed on the same input, then resizes the fused
    map to ``out_size`` with bilinear interpolation.

    Args:
        in_channels: Channel count of the input (the output keeps the same count).
        out_size: ``(height, width)`` the fused map is interpolated to.
    """

    def __init__(self, in_channels, out_size=(256, 256)):
        super().__init__()

        # Two convolution branches with different receptive fields
        self.conv1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0)
        self.conv3x3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

        # 1x1 projections producing the attention query / key / value
        self.query_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.key_conv   = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)

        # Learnable weight of the attention term, initialised to 1
        self.scale = nn.Parameter(torch.tensor(1.0))

        self.out_size = out_size

    def forward(self, x):
        """Return the fused feature map resized to ``self.out_size``."""
        batch, channels, height, width = x.shape

        # Flatten the spatial grid: each projection becomes (B, C, H*W)
        queries = self.query_conv(x).view(batch, channels, -1)
        keys = self.key_conv(x).view(batch, channels, -1)
        values = self.value_conv(x).view(batch, channels, -1)

        # (B, H*W, H*W) position-to-position affinities, row-normalised
        affinity = F.softmax(torch.bmm(queries.transpose(1, 2), keys), dim=-1)

        # Mix the values with the affinities and restore the (B, C, H, W) layout
        attended = torch.bmm(values, affinity.transpose(1, 2)).view(batch, channels, height, width)

        # The scaled attention map is added to both convolution branches before they are summed
        weighted = self.scale * attended
        fused = (weighted + self.conv1x1(x)) + (weighted + self.conv3x3(x))

        return F.interpolate(fused, size=self.out_size, mode='bilinear', align_corners=False)


# LESION-AWARE RELATION BLOCK
class LARB(nn.Module):
    """Lesion-Aware Relation Block.

    Concatenates three feature maps along the channel axis and, for each
    lesion type, applies a 1x1 projection, a channel-attention gate and a 1x1
    reduction to a single-channel map.

    The lesion order is ``MA, HE, EX, SE``, matching the order in which the
    outputs are unpacked by ``MSGDANet.forward`` and enumerated by the loss.
    Earlier revisions returned ``MA, SE, HE, EX``; checkpoints saved with that
    ordering still load (parameter names are unchanged) but their HE/EX/SE
    branches were trained against a different target channel.

    Args:
        in_channels: Total channel count of the concatenated inputs.
        mid_channels: Width of each lesion-specific branch.
    """

    def __init__(self, in_channels, mid_channels=64):
        super().__init__()

        self.lesion_types = ['MA', 'HE', 'EX', 'SE']

        # Per-lesion 1x1 projection of the concatenated features
        self.lesion_convs = nn.ModuleDict({
            l: nn.Conv2d(in_channels, mid_channels, kernel_size=1)
            for l in self.lesion_types
        })

        def lesion_attention():
            # Bottleneck applied to globally pooled descriptors (squeeze to 1/4, then restore)
            return nn.Sequential(
                nn.Conv2d(mid_channels, mid_channels // 4, kernel_size=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(mid_channels // 4, mid_channels, kernel_size=1)
            )

        self.attention_blocks = nn.ModuleDict({
            l: lesion_attention() for l in self.lesion_types
        })

        # Per-lesion reduction to a single-channel map
        self.reduction_convs = nn.ModuleDict({
            l: nn.Conv2d(mid_channels, 1, kernel_size=1)
            for l in self.lesion_types
        })

        # Learnable weight of the max-pooled attention term
        self.scale = nn.Parameter(torch.tensor(1.0))

        self.init_weights()

    def init_weights(self):
        """Kaiming-normal initialise every conv weight and zero every conv bias."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def apply_lah(self, FL, att_layer):
        """Gate ``FL`` channel-wise with a sigmoid of pooled-descriptor attention.

        The same ``att_layer`` is applied to the global-average and global-max
        descriptors; the max term is weighted by ``self.scale``.
        """
        avg_pool = F.adaptive_avg_pool2d(FL, 1)
        max_pool = F.adaptive_max_pool2d(FL, 1)
        att = att_layer(avg_pool) + self.scale * att_layer(max_pool)
        return FL * torch.sigmoid(att)

    def forward(self, f3_out, f4_out, f5_out):
        """Return one single-channel map per lesion, ordered ``MA, HE, EX, SE``.

        The three inputs must share batch and spatial dimensions because they
        are concatenated along ``dim=1``.
        """
        FM = torch.cat([f3_out, f4_out, f5_out], dim=1)  # [B, C_total, H, W]

        outputs = {}
        for l in self.lesion_types:
            FL = self.lesion_convs[l](FM)
            FL_att = self.apply_lah(FL, self.attention_blocks[l])
            outputs[l] = self.reduction_convs[l](FL_att)

        # Explicit keys keep the positional contract independent of dict ordering
        return outputs['MA'], outputs['HE'], outputs['EX'], outputs['SE']


# =================**SPATIAL FUSION BLOCK**===============

class SpatialFusionBlock(nn.Module):
    """Turn stacked lesion maps into one spatial gate and apply it to an image.

    Args:
        in_channels: Number of lesion-map channels fed to the first 1x1 conv.
        inter_channels: Hidden width between the two 1x1 convs.
    """

    def __init__(self, in_channels=4, inter_channels=16):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, inter_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(inter_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, lesion_maps, image):
        """Return ``image`` multiplied by a single-channel attention map.

        Args:
            lesion_maps: Tensor ``[B, in_channels, h, w]``.
            image: Tensor ``[B, C, H, W]``; the gate is broadcast over ``C``.

        Returns:
            Tensor with the same shape as ``image``.
        """
        hidden = self.relu(self.conv1(lesion_maps))   # [B, inter_channels, h, w]
        gate = self.sigmoid(self.conv2(hidden))       # [B, 1, h, w], values in (0, 1)

        # Bring the gate to the image resolution when the two grids differ
        target_hw = image.shape[-2:]
        if gate.shape[-2:] != target_hw:
            gate = F.interpolate(gate, size=target_hw, mode='bilinear', align_corners=False)

        return image * gate


# ENHANCED SELF ATTENTION BLOCK
class ESAB(nn.Module):
    """Enhanced Self-Attention Block: local patches attend to global features.

    The full image and each quadrant of the lesion-weighted image are passed
    through the same encoder.  Queries come from the quadrant features, keys
    and values from the full-image features; the attended quadrants are passed
    through a 1x1 feed-forward stack, stitched back into one map and added to
    the global features.

    Args:
        shared_encoder: Callable returning a sequence of feature maps; only the
            last element is used.
        embed_dim: Channel count of that last feature map.
    """

    def __init__(self, shared_encoder, embed_dim=512):  # VGG16 f5 = 512 channels
        super().__init__()
        self.encoder = shared_encoder

        self.query_proj = nn.Conv2d(embed_dim, embed_dim, kernel_size=1)
        self.key_proj   = nn.Conv2d(embed_dim, embed_dim, kernel_size=1)
        self.value_proj = nn.Conv2d(embed_dim, embed_dim, kernel_size=1)

        self.ffd = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(embed_dim, embed_dim, kernel_size=1)
        )

    def forward(self, image_sfb,x):
        """Fuse quadrant-level and image-level features.

        Args:
            image_sfb: Lesion-weighted image ``[B, C, H, W]``; split into a 2x2 grid.
            x: Original image ``[B, C, H, W]`` used for the global features.

        Returns:
            Tensor shaped like the encoder's last feature map for ``x``.  The
            stitched quadrant features must match that shape, which holds when
            the encoder downsamples by a fixed factor that divides ``H/2`` and ``W/2``.
        """
        # Step 1: global features from the full image (encoder is called on x first)
        global_feats = self.encoder(x)[-1]

        # Keys and values depend only on the global features, so project them once
        # instead of once per patch.
        keys = self.key_proj(global_feats).flatten(2)      # [B, D, HW]
        values = self.value_proj(global_feats).flatten(2)  # [B, D, HW]
        scale = keys.size(1) ** 0.5                        # sqrt(D)

        # Steps 2-3: walk the 2x2 grid row by row and attend each patch to the global map
        local_outs = []
        for row in torch.chunk(image_sfb, 2, dim=2):       # split height
            for patch in torch.chunk(row, 2, dim=3):       # split width
                local_feat = self.encoder(patch)[-1]
                queries = self.query_proj(local_feat).flatten(2)            # [B, D, hw]

                attn = torch.bmm(queries.transpose(1, 2), keys) / scale     # [B, hw, HW]
                attn = F.softmax(attn, dim=-1)
                out = torch.bmm(values, attn.transpose(1, 2))               # [B, D, hw]
                local_outs.append(self.ffd(out.view_as(local_feat)))

        # Step 4: re-assemble the quadrants (order: top-left, top-right, bottom-left, bottom-right)
        top = torch.cat([local_outs[0], local_outs[1]], dim=3)
        bottom = torch.cat([local_outs[2], local_outs[3]], dim=3)
        stitched = torch.cat([top, bottom], dim=2)

        # Step 5: residual connection with the global features
        return stitched + global_feats



# ================**GRADING HEAD**==================

class GradingHead(nn.Module):
    """Classification head: global average pooling followed by one linear layer.

    Args:
        in_channels: Channel count of the incoming feature map.
        num_classes: Number of output logits.
    """

    def __init__(self, in_channels=512, num_classes=5):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        """Map ``[B, C, H, W]`` features to ``[B, num_classes]`` logits."""
        descriptor = self.flatten(self.pool(x))  # [B, C]
        return self.fc(descriptor)


# ==========FULL MSGDANet CLASS WRAPPER=============

# class MSGDANet(nn.Module):
#     def __init__(self, num_classes=5, img_size=(256, 256)):
#         super(MSGDANet, self).__init__()

#         # Shared encoder
#         self.encoder = VGGEncoder()

#         # Multi-Scale Attention Blocks
#         self.msab1 = MSAB(64, out_size=img_size)
#         self.msab2 = MSAB(128, out_size=img_size)
#         self.msab3 = MSAB(256, out_size=img_size)
#         self.msab4 = MSAB(512, out_size=img_size)
#         self.msab5 = MSAB(512, out_size=img_size)

#         # Lesion-Aware Relation Block
#         self.larb = LARB(in_channels=256 + 512 + 512, mid_channels=64)

#         # Spatial Fusion Block
#         self.sfb = SpatialFusionBlock(in_channels=4, inter_channels=16)

#         # Enhanced Self-Attention Block
#         self.esab = ESAB(shared_encoder=self.encoder, embed_dim=512)

#         # DR Grading Head
#         self.grading_head = GradingHead(in_channels=512, num_classes=num_classes)

#     def forward(self, x):
#         """
#         Input:
#             x: RGB fundus image [B, 3, 512, 512]
#         Output:
#             lesion_maps: [B, 4, 512, 512] - MA, HE, EX, SE
#             logits: [B, num_classes] - DR grades
#         """
#         # Stage-wise feature extraction
#         f1, f2, f3, f4, f5 = self.encoder(x)

#         # Multi-scale attention per stage
#         # fMSAB1 = self.msab1(f1)
#         # fMSAB2 = self.msab2(f2)
#         fMSAB3 = self.msab3(f3)
#         fMSAB4 = self.msab4(f4)
#         fMSAB5 = self.msab5(f5)

#         # print("MSAB3 output:", fMSAB3.min().item(), fMSAB3.max().item())
#         # print("MSAB4 output:", fMSAB4.min().item(), fMSAB4.max().item())
#         # print("MSAB5 output:", fMSAB5.min().item(), fMSAB5.max().item())

#         fMSAB3 = F.interpolate(fMSAB3, size=(256, 256), mode='bilinear', align_corners=False)
#         fMSAB4 = F.interpolate(fMSAB4, size=(256, 256), mode='bilinear', align_corners=False)
#         fMSAB5 = F.interpolate(fMSAB5, size=(256, 256), mode='bilinear', align_corners=False)

#         # LARB: lesion-specific features
#         FR_MA, FR_HE, FR_EX, FR_SE = self.larb(fMSAB3, fMSAB4, fMSAB5)

#         # print("FR_ma:", FR_ma.min().item(), FR_ma.max().item())
#         # print("FR_he:", FR_he.min().item(), FR_he.max().item())
#         # print("FR_ex:", FR_ex.min().item(), FR_ex.max().item())
#         # print("FR_se:", FR_se.min().item(), FR_se.max().item())

#         # Stack lesion maps into single tensor [B, 4, H, W]
#         lesion_maps = torch.cat([FR_MA, FR_HE, FR_EX, FR_SE], dim=1)
#         lesion_maps = (lesion_maps - lesion_maps.mean()) / (lesion_maps.std() + 1e-6)
#         lesion_maps = torch.sigmoid(lesion_maps)

#         # SFB: spatial attention over RGB image
#         isfb = self.sfb(lesion_maps, x)

#         # ESAB: enhanced global-local fusion
#         esab_out = self.esab(isfb)

#         # Grading head
#         logits = self.grading_head(esab_out)

#         return lesion_maps, logits
class MSGDANet(nn.Module):
    """Joint lesion-segmentation and DR-grading network.

    A shared ``VGGEncoder`` feeds a segmentation branch (three ``MSAB`` blocks
    on stages 3-5 followed by ``LARB``) and a grading branch
    (``SpatialFusionBlock`` -> ``ESAB`` -> ``GradingHead``).  The same encoder
    instance is also used inside ``ESAB``.

    Args:
        num_classes: Number of grading logits.
        img_size: ``(height, width)`` of the lesion maps.  The MSAB blocks
            upsample directly to this size, so no further resizing is done in
            ``forward`` (previously a hard-coded 256x256 resize overrode it).
    """

    def __init__(self, num_classes=5, img_size=(256, 256)):
        super().__init__()
        self.img_size = img_size

        # ─── Shared Encoder ─────────────────────────────────────────────────────
        self.encoder = VGGEncoder()

        # ─── Segmentation Branch ───────────────────────────────────────────────
        self.segmentation_branch = nn.ModuleDict({
            "msab3": MSAB(256, out_size=img_size),
            "msab4": MSAB(512, out_size=img_size),
            "msab5": MSAB(512, out_size=img_size),
            "larb": LARB(in_channels=256 + 512 + 512, mid_channels=64)
        })

        # ─── Grading Branch ────────────────────────────────────────────────────
        self.grading_branch = nn.ModuleDict({
            "sfb": SpatialFusionBlock(in_channels=4, inter_channels=16),
            "esab": ESAB(shared_encoder=self.encoder, embed_dim=512),
            "head": GradingHead(in_channels=512, num_classes=num_classes)
        })

    def forward(self, x):
        """Return ``(lesion_maps, logits)`` for an image batch ``x``.

        ``lesion_maps`` is ``[B, 4, *img_size]`` with values in (0, 1);
        ``logits`` is ``[B, num_classes]``.
        """
        # ─── Encoder Features (stages 1-2 are not used by either branch) ───
        _, _, f3, f4, f5 = self.encoder(x)

        # ─── MSAB: each block already resizes its output to img_size ───
        msab3 = self.segmentation_branch["msab3"](f3)
        msab4 = self.segmentation_branch["msab4"](f4)
        msab5 = self.segmentation_branch["msab5"](f5)

        # ─── LARB Lesion Maps ───
        FR_MA, FR_HE, FR_EX, FR_SE = self.segmentation_branch["larb"](msab3, msab4, msab5)
        lesion_maps = torch.cat([FR_MA, FR_HE, FR_EX, FR_SE], dim=1)
        # Standardise with the statistics of the whole tensor (all samples and
        # channels together), then squash to (0, 1).
        lesion_maps = (lesion_maps - lesion_maps.mean()) / (lesion_maps.std() + 1e-6)
        lesion_maps = torch.sigmoid(lesion_maps)

        # ─── SFB + ESAB + Grading ───
        isfb = self.grading_branch["sfb"](lesion_maps, x)
        esab_out = self.grading_branch["esab"](isfb,x)
        logits = self.grading_branch["head"](esab_out)

        return lesion_maps, logits

# ==========**LOSS**=========
class MSGDALoss(nn.Module):
    """Task-switched loss: weighted Tversky + BCE for segmentation, CE for grading.

    Only the branch selected by ``task`` contributes; with any other value
    (including the default ``None``) both terms stay at zero.

    Args:
        alpha: Tversky weight on false positives.
        beta: Tversky weight on false negatives.
        lam: Weight of the BCE term relative to the Tversky term.
        task: ``'segmentation'`` or ``'grading'``.
    """

    LESION_ORDER = ('MA', 'HE', 'EX', 'SE')

    def __init__(self, alpha, beta, lam, task=None):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.lam = lam
        self.task = task

        self.ce_loss = nn.CrossEntropyLoss()

        # Lesion-specific weights, normalised to sum to 1
        lesion_weights = {'MA': 10, 'HE': 1, 'EX': 5, 'SE': 5}
        total = sum(lesion_weights.values())
        self.weights = {k: v / total for k, v in lesion_weights.items()}

    def tversky_loss(self, pred, target):
        """Return ``1 - Tversky index`` computed over all elements of the inputs."""
        smooth = 1e-6
        pred = pred.reshape(-1)
        target = target.reshape(-1)

        tp = (pred * target).sum()
        fp = ((1 - target) * pred).sum()
        fn = (target * (1 - pred)).sum()

        tversky = (tp + smooth) / (tp + self.alpha * fp + self.beta * fn + smooth)
        return 1 - tversky

    def forward(self, lesion_preds, lesion_targets, logits=None, labels=None):
        """Compute ``(total_loss, segmentation_loss, grading_loss)``.

        Args:
            lesion_preds: Probabilities ``[B, 4, H, W]`` ordered MA, HE, EX, SE, or ``None``.
            lesion_targets: Binary masks with the same shape as ``lesion_preds``.
            logits: Grading logits ``[B, num_classes]`` or ``None``.
            labels: Integer grades ``[B]``; entries equal to ``-1`` are ignored.

        Note:
            Lesion channels whose target is entirely zero are skipped.  If every
            channel is skipped the segmentation loss is a constant zero that
            carries no gradient.
        """
        if lesion_preds is not None:
            device = lesion_preds.device
        elif logits is not None:
            device = logits.device
        else:
            device = 'cpu'
        total_seg_loss = torch.tensor(0.0, device=device)
        total_cls_loss = torch.tensor(0.0, device=device)

        # Segmentation loss
        if self.task == 'segmentation' and lesion_preds is not None:
            for i, lesion_type in enumerate(self.LESION_ORDER):
                # Clamp away from 0/1 so the BCE logarithms stay finite
                pred = torch.clamp(lesion_preds[:, i, :, :], 1e-4, 1 - 1e-4)
                target = lesion_targets[:, i, :, :]
                assert pred.shape == target.shape, f"Mismatch: pred {pred.shape}, target {target.shape}"

                # Skip dummy (all-zero) masks
                if torch.all(target == 0):
                    continue

                tversky = self.tversky_loss(pred, target)
                bce = F.binary_cross_entropy(pred, target)
                seg_loss = self.weights[lesion_type] * (tversky + self.lam * bce)
                # Out-of-place add keeps the zero initialiser untouched
                total_seg_loss = total_seg_loss + seg_loss

        # Grading loss
        if self.task == 'grading' and logits is not None and labels is not None:
            valid_mask = labels != -1
            if valid_mask.any():
                total_cls_loss = self.ce_loss(logits[valid_mask], labels[valid_mask])

        total_loss = total_seg_loss + total_cls_loss
        return total_loss, total_seg_loss, total_cls_loss


class MetricsTracker:
    """Accumulates per-batch predictions and reports grading / segmentation metrics."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop everything accumulated so far."""
        self.grading_preds = []
        self.grading_labels = []

        self.segmentation_preds = []
        self.segmentation_targets = []

    def update_grading(self, preds, labels):
        """Store one batch of predicted grades and true grades (detached, on CPU)."""
        self.grading_preds.append(preds.detach().cpu())
        self.grading_labels.append(labels.detach().cpu())

    def update_segmentation(self, preds, targets):
        """Store one batch of lesion predictions and masks (detached, on CPU)."""
        self.segmentation_preds.append(preds.detach().cpu())
        self.segmentation_targets.append(targets.detach().cpu())

    def compute_grading_metrics(self):
        """Return ``(accuracy, macro precision, macro recall, Cohen's kappa)``.

        All four are 0.0 when no grading batch has been recorded.
        """
        if not self.grading_preds:
            return 0.0, 0.0, 0.0, 0.0

        preds = torch.cat(self.grading_preds).numpy()
        labels = torch.cat(self.grading_labels).numpy()

        acc = accuracy_score(labels, preds)
        precision = precision_score(labels, preds, average='macro', zero_division=0)
        recall = recall_score(labels, preds, average='macro', zero_division=0)
        kappa = cohen_kappa_score(labels, preds)

        return acc, precision, recall, kappa

    def compute_segmentation_metrics(self):
        """Return ``(mean Dice, mean IoU, mean AUPR)`` over every (sample, lesion) map.

        Dice and IoU use predictions and targets thresholded at 0.5.  AUPR is
        computed from the *unthresholded* predictions so the precision-recall
        curve has more than one operating point.  All three are 0.0 when no
        segmentation batch has been recorded.
        """
        if not self.segmentation_preds:
            return 0.0, 0.0, 0.0

        dice_scores, iou_scores, aupr_scores = [], [], []

        for pred, target in zip(self.segmentation_preds, self.segmentation_targets):
            B, C, H, W = pred.shape  # C = number of lesions
            # reshape (not view) so non-contiguous tensors are accepted
            pred_flat = pred.reshape(B * C, -1).float()
            target_flat = target.reshape(B * C, -1).float()

            for scores, truth in zip(pred_flat, target_flat):
                p = (scores > 0.5).float()
                t = (truth > 0.5).float()

                TP = (p * t).sum().item()
                FP = (p * (1 - t)).sum().item()
                FN = ((1 - p) * t).sum().item()

                dice = (2 * TP) / (2 * TP + FP + FN + 1e-6)
                iou = TP / (TP + FP + FN + 1e-6)

                # AUPR on raw scores; sklearn warns when a map has no positives
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=UserWarning)
                    try:
                        precision, recall, _ = precision_recall_curve(t.numpy(), scores.numpy())
                        aupr = float(auc(recall, precision))
                    except ValueError:
                        aupr = 0.0
                # A NaN/inf curve would otherwise poison the running mean
                if not np.isfinite(aupr):
                    aupr = 0.0

                dice_scores.append(dice)
                iou_scores.append(iou)
                aupr_scores.append(aupr)

        mean_dice = sum(dice_scores) / len(dice_scores)
        mean_iou = sum(iou_scores) / len(iou_scores)
        mean_aupr = sum(aupr_scores) / len(aupr_scores)

        return mean_dice, mean_iou, mean_aupr

    def summarize(self, train_loss, val_loss, train_acc, task):
        """Print a one-row metrics table and return the metric to monitor.

        Args:
            train_loss: Mean training loss for the epoch.
            val_loss: Mean validation loss for the epoch.
            train_acc: Training accuracy for grading, training Dice for segmentation.
            task: ``'grading'`` or ``'segmentation'``.

        Returns:
            Validation accuracy for grading, validation Dice for segmentation,
            0.0 for any other task.
        """
        val_acc, prec, rec, kappa = self.compute_grading_metrics()
        dice, iou, aupr = self.compute_segmentation_metrics()

        headers = ["Train Loss", "Val Loss"]
        values = [f"{train_loss:.2f}", f"{val_loss:.2f}"]

        if task == "grading":
            headers += ["Train Acc", "Val Acc", "Precision", "Recall", "Kappa"]
            values += [f"{train_acc:.2f}", f"{val_acc:.2f}", f"{prec:.2f}", f"{rec:.2f}", f"{kappa:.2f}"]
        elif task == "segmentation":
            headers += ["Train Dice", "Dice", "IoU", "AUPR"]
            values += [f"{train_acc:.4f}", f"{dice:.4f}", f"{iou:.4f}", f"{aupr:.4f}"]

        print("\n" + tabulate([values], headers=headers, tablefmt="pretty"))

        if task == "grading":
            return val_acc
        if task == "segmentation":
            return dice
        return 0.0




class MultiTaskDRDataset(Dataset):
    """Fundus-image dataset for DR grading or lesion segmentation.

    Supports the ``DDR``, ``IDRID`` and ``APTOS`` directory layouts under
    ``root_dir``.  Every item is a dict with ``image``, ``label`` (``-1`` when
    the sample has no grade) and ``masks`` (binary, channel-first, ordered
    MA, HE, EX, SE; all zeros for grading samples).

    Args:
        root_dir: Directory containing the per-dataset folders.
        dataset_name: ``'DDR'``, ``'IDRID'`` or ``'APTOS'``; anything else yields no samples.
        split: Split name used to build file paths.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            and expected to return a dict with ``'image'`` and ``'mask'``.
        image_size: Passed to ``cv2.resize`` / ``PIL.Image.resize``, which both
            interpret it as ``(width, height)``.
        task: ``'grading'`` or ``'segmentation'``.
    """

    def __init__(self, root_dir, dataset_name, split='train', transform=None, image_size=(256, 256), task=None):
        self.root_dir = root_dir
        self.dataset_name = dataset_name
        self.split = split
        self.transform = transform
        self.image_size = image_size
        self.task = task
        self.lesion_names = ['MA', 'HE', 'EX', 'SE']
        self.samples = self._load_metadata()

    @staticmethod
    def _grading_samples(df, name_col, label_col, img_dir, ext='', skip_label=None):
        """Build grading sample dicts from a label table, optionally dropping one grade."""
        samples = []
        for _, row in df.iterrows():
            if skip_label is not None and row[label_col] == skip_label:
                continue
            name = row[name_col]
            samples.append({
                'img_name': name,
                'img_path': os.path.join(img_dir, f"{name}{ext}"),
                'label': int(row[label_col]),
                'masks': None,
            })
        return samples

    def _ddr_samples(self):
        """Collect samples from the DDR layout for the current task and split."""
        samples = []
        if self.task == 'grading':
            grading_csv = os.path.join(self.root_dir, 'DDR/grading', f'{self.split}.csv')
            # A missing label file yields an empty dataset rather than an error
            if os.path.exists(grading_csv):
                img_dir = os.path.join(self.root_dir, f'DDR/grading/{self.split}')
                samples = self._grading_samples(
                    pd.read_csv(grading_csv), 'Image Name', 'Retinopathy grade',
                    img_dir, skip_label=5,  # rows with grade 5 are excluded
                )
        elif self.task == 'segmentation':
            seg_img_dir = os.path.join(self.root_dir, f'DDR/segmentation/{self.split}/image')
            gt_root = os.path.join(self.root_dir, f'DDR/segmentation/{self.split}/groundtruth')
            # sorted() makes the sample order independent of the filesystem
            for img_name in sorted(os.listdir(seg_img_dir)):
                mask_name = img_name.replace('.jpg', '.tif')
                samples.append({
                    'img_name': img_name,
                    'img_path': os.path.join(seg_img_dir, img_name),
                    'label': -1,
                    'masks': {
                        lesion: os.path.join(gt_root, lesion, mask_name)
                        for lesion in self.lesion_names
                    },
                })
        return samples

    def _idrid_samples(self):
        """Collect samples from the IDRiD layout for the current task and split."""
        samples = []
        if self.task == 'grading':
            # Any split other than 'train' reads the testing set
            subset = 'Training' if self.split == 'train' else 'Testing'
            grading_csv = os.path.join(
                self.root_dir, f'IDRID/Grading/Groundtruths/IDRiD_Disease_Grading_{subset}_Labels.csv')
            grading_dir = os.path.join(self.root_dir, f'IDRID/Grading/Fundus Images/{subset} Set')
            samples = self._grading_samples(
                pd.read_csv(grading_csv), 'Image name', 'Retinopathy grade',
                grading_dir, ext='.jpg', skip_label=5,
            )
        elif self.task == 'segmentation':
            seg_dir = os.path.join(self.root_dir, f'IDRID/Segmentation/Fundus Images/{self.split.capitalize()} Set')
            mask_dir = os.path.join(self.root_dir, f'IDRID/Segmentation/Masks/{self.split.capitalize()} Set')
            for img_name in sorted(os.listdir(seg_dir)):
                name = img_name.replace('.jpg', '')
                samples.append({
                    'img_name': name,
                    'img_path': os.path.join(seg_dir, img_name),
                    'label': -1,
                    'masks': {
                        lesion: os.path.join(mask_dir, lesion, f"{name}_{lesion}.tif")
                        for lesion in self.lesion_names
                    },
                })
        return samples

    def _aptos_samples(self):
        """Collect grading samples from the APTOS layout (it has no masks)."""
        csv_path = os.path.join(self.root_dir, 'APTOS', f'{self.split}.csv')
        img_dir = os.path.join(self.root_dir, 'APTOS', f'{self.split}_images')
        return self._grading_samples(pd.read_csv(csv_path), 'id_code', 'diagnosis', img_dir, ext='.png')

    def _load_metadata(self):
        """Build the sample list for the configured dataset, split and task."""
        builders = {
            'DDR': self._ddr_samples,
            'IDRID': self._idrid_samples,
            'APTOS': self._aptos_samples,
        }
        builder = builders.get(self.dataset_name)
        samples = builder() if builder is not None else []

        # Keep only samples usable for the requested task
        if self.task == 'grading':
            samples = [s for s in samples if s['label'] != -1]
        elif self.task == 'segmentation':
            samples = [s for s in samples if s['masks'] is not None]

        # Print debug info
        num_grading = sum(1 for s in samples if s['label'] != -1)
        num_segmentation = sum(1 for s in samples if s['masks'] is not None)
        print(f"Loaded {len(samples)} samples from {self.dataset_name} {self.split} for {self.task} task.")
        print(f" - Grading samples: {num_grading}")
        print(f" - Segmentation samples: {num_segmentation}")
        return samples

    def _load_mask(self, path):
        """Return one lesion mask as a uint8 array, or zeros when the file is missing.

        The zero fallback uses the same ``(height, width)`` array shape that the
        PIL resize produces, so the two can be stacked for non-square sizes.
        """
        if path and os.path.exists(path):
            with Image.open(path) as mask_img:
                return np.array(mask_img.convert('L').resize(self.image_size))
        width, height = self.image_size
        return np.zeros((height, width), dtype=np.uint8)

    def _load_image(self, path):
        """Read an image as RGB and resize it to ``self.image_size``."""
        with Image.open(path) as img:
            rgb = np.array(img.convert('RGB'))
        return cv2.resize(rgb, self.image_size)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the sample at ``idx``, falling forward to the next readable image.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If no image in the dataset can be loaded.
        """
        total = len(self.samples)
        _ = self.samples[idx]  # keep the IndexError contract for out-of-range indices

        # Bounded retry: try each sample at most once instead of recursing forever
        sample = image = None
        for attempt in range(total):
            candidate = self.samples[(idx + attempt) % total]
            try:
                image = self._load_image(candidate['img_path'])
            except Exception as e:  # PIL/cv2 raise many unrelated types for bad files
                print(f"Error loading image {candidate['img_path']}: {e}")
                continue
            sample = candidate
            break
        if sample is None:
            raise RuntimeError(f"No loadable image found in {self.dataset_name} {self.split}")

        num_lesions = len(self.lesion_names)
        width, height = self.image_size
        empty_mask = np.zeros((height, width, num_lesions), dtype=np.float32)

        if self.task == 'segmentation':
            try:
                mask_channels = [self._load_mask(sample['masks'].get(lesion)) for lesion in self.lesion_names]
                multi_mask = np.stack(mask_channels, axis=-1).astype(np.float32) / 255.0
            except Exception as e:
                print(f"Error loading masks for {sample['img_name']}: {e}")
                multi_mask = empty_mask
        else:
            multi_mask = empty_mask

        if self.transform:
            augmented = self.transform(image=image, mask=multi_mask)
            image = augmented['image']
            multi_mask = torch.as_tensor(augmented['mask']).float()
            # Transforms may hand the mask back channel-last; move lesions to dim 0
            # so it matches the no-transform branch below.
            if (multi_mask.ndim == 3 and multi_mask.shape[0] != num_lesions
                    and multi_mask.shape[-1] == num_lesions):
                multi_mask = multi_mask.permute(2, 0, 1)
        else:
            image = transforms.ToTensor()(image)
            multi_mask = torch.from_numpy(multi_mask).permute(2, 0, 1).float()

        label = sample['label'] if sample['label'] is not None else -1

        return {
            'image': image,
            'label': torch.tensor(label, dtype=torch.long),
            'masks': (multi_mask > 0.5).float(),
        }


# ===========**TRANSFORMS**==========
# Augmentation pipelines (albumentations).  ToTensorV2 only moves the *image*
# to channel-first by default; transpose_mask=True does the same for the
# multi-channel lesion mask, so masks come out as (C, H, W) like the
# dataset's no-transform path produces.
train_transforms = A.Compose([
  A.Resize(256, 256),
  A.HorizontalFlip(p=0.5),
  A.VerticalFlip(p=0.5),
  A.RandomRotate90(p=1.0),
  A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=30, p=0.5),
  A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.5),              # Contrast-limited adaptive histogram equalization
  A.Normalize(mean=(0.5), std=(0.5),max_pixel_value=255.0),           # maps [0, 255] to [-1, 1]
  ToTensorV2(transpose_mask=True)])

# Validation: deterministic resize + the same normalisation, no augmentation
val_transforms = A.Compose([
  A.Resize(256, 256),
  A.Normalize(mean=(0.5), std=(0.5),max_pixel_value=255.0),
  ToTensorV2(transpose_mask=True)])



from collections import defaultdict

class RunningAverage:
    """Accumulate named scalar values and report their mean per ``update`` call.

    A single counter is shared by all names: every ``update`` call counts as
    one step regardless of which keyword arguments it carries.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard all accumulated totals and set the step counter to zero."""
        self.sum = defaultdict(float)
        self.count = 0

    def update(self, **kwargs):
        """Count one step and add each keyword value to its running total."""
        self.count += 1
        for name in kwargs:
            self.sum[name] += kwargs[name]

    def avg(self, k):
        """Return the total stored under ``k`` divided by the step count.

        Returns ``0.0`` when no ``update`` call has been made yet.
        """
        if self.count <= 0:
            return 0.0
        return self.sum[k] / self.count



# ========== TRAINING SCRIPT ==========

# === Config ===
# Mount Google Drive and make the project folder importable.
drive.mount('/content/drive')
project_root = '/content/drive/MyDrive/vgg'
#project_root = os.path.expanduser("~/MSGDANet")
sys.path.append(project_root)
ROOT_DIR = os.path.join(project_root, "data")

# Run configuration
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DATASET_NAME = 'DDR'  # 'DDR', 'IDRID' or 'APTOS'
BATCH_SIZE = 2
NUM_CLASSES = 5
EARLY_STOP_PATIENCE = 20
STAGE = 'stage1'  # or 'stage2'

# Map the stage name to the task string passed to the dataset. An unknown
# stage is rejected here; otherwise TASK would stay undefined and the script
# would only fail later with a NameError when the datasets are built.
STAGE_TO_TASK = {'stage1': 'segmentation', 'stage2': 'grading'}
if STAGE not in STAGE_TO_TASK:
    raise ValueError(
        f"Unknown STAGE {STAGE!r}; expected one of {sorted(STAGE_TO_TASK)}"
    )
TASK = STAGE_TO_TASK[STAGE]


# Dataset-specific hyperparameters
CONFIG = {
    'DDR':   {'epochs': 380, 'lr': 1e-4, 'weight_decay': 2e-4},
    'IDRID': {'epochs': 500, 'lr': 2e-3, 'weight_decay': 5e-4},
    'APTOS': {'epochs': 330, 'lr': 1e-3, 'weight_decay': 1e-4},
}
if DATASET_NAME not in CONFIG:
    raise ValueError(
        f"Unknown DATASET_NAME {DATASET_NAME!r}; expected one of {sorted(CONFIG)}"
    )
params = CONFIG[DATASET_NAME]
NUM_EPOCHS = params['epochs']
LR = params['lr']
WEIGHT_DECAY = params['weight_decay']

# ========== DATA LOADERS ==========
train_dataset = MultiTaskDRDataset(dataset_name=DATASET_NAME, root_dir=ROOT_DIR, split='train', task=TASK)
val_dataset = MultiTaskDRDataset(dataset_name=DATASET_NAME, root_dir=ROOT_DIR, split='val', task=TASK)

# Only the training data is shuffled; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Checkpoint path setup
checkpoint_dir = os.path.join(project_root, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)

# Visualizations path setup
visualization_dir = os.path.join(project_root, "visualizations")
os.makedirs(visualization_dir, exist_ok=True)


import matplotlib.pyplot as plt
def visualize_predictions(images, masks, preds, num_samples=3, save_path=None, show_inline=True):
    """Plot each sample's image, ground-truth masks and predicted masks.

    One figure is created per sample, with three rows (input image,
    ground-truth mask, predicted mask) and one column per mask channel.

    Args:
        images: Batch tensor; ``images[i]`` is permuted from (C, H, W) to
            (H, W, C) for display.
        masks: Batch tensor indexed as ``masks[i, c]``; ``masks.shape[1]``
            sets the number of columns.
        preds: Predicted masks, indexed the same way as ``masks``.
        num_samples: Maximum number of samples from the batch to plot.
        save_path: Optional path prefix. When given, each figure is written to
            ``f"{save_path}_sample_{i}.png"``; the parent directory is created
            if needed.
        show_inline: When true, the figure is displayed with ``plt.show()``.

    Every figure is closed before the function moves on, so calling this once
    per epoch does not accumulate open figures.
    """
    import matplotlib.pyplot as plt

    images = images.detach().cpu()
    masks = masks.detach().cpu()
    preds = preds.detach().cpu()

    num_channels = masks.shape[1]

    for i in range(min(num_samples, images.size(0))):
        # squeeze=False keeps `axs` two-dimensional even with a single mask
        # channel, so axs[row, col] indexing is always valid.
        fig, axs = plt.subplots(
            3, num_channels, figsize=(4 * num_channels, 10), squeeze=False
        )

        img = images[i].permute(1, 2, 0)
        # imshow clips float RGB data to [0, 1]. Images that fall outside that
        # range (e.g. after mean/std normalization) are min-max rescaled for
        # display only; this is not an exact inverse of the normalization.
        if img.is_floating_point() and img.numel() > 0:
            lo, hi = img.min(), img.max()
            if lo < 0 or hi > 1:
                img = (img - lo) / (hi - lo).clamp_min(1e-8)
        img = img.numpy()

        for c in range(num_channels):
            axs[0, c].imshow(img)
            axs[0, c].set_title("Image")
            axs[0, c].axis("off")

            axs[1, c].imshow(masks[i, c].numpy(), cmap='gray')
            axs[1, c].set_title(f"GT Mask (Lesion {c})")
            axs[1, c].axis("off")

            axs[2, c].imshow(preds[i, c].numpy(), cmap='gray')
            axs[2, c].set_title(f"Pred Mask (Lesion {c})")
            axs[2, c].axis("off")

        fig.tight_layout()

        if save_path:
            # A bare file prefix has no directory component; makedirs('')
            # would raise, so only create the directory when there is one.
            out_dir = os.path.dirname(save_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            out_path = f"{save_path}_sample_{i}.png"
            fig.savefig(out_path)
            print(f"Saved visualization: {out_path}")

        if show_inline:
            plt.show()
        # Always release the figure; otherwise one figure per call stays alive.
        plt.close(fig)

# === Unified Training Loop for Both Stages ===
# Runs the 'segmentation' stage and then the 'grading' stage. Each stage builds
# a fresh model, optionally resumes from its own checkpoint, trains with early
# stopping, and writes its best weights to `best_model_path`.
#
# NOTE(review): train_loader / val_loader are created once, before this loop,
# with the TASK value chosen in the config section, so both stages iterate the
# same dataset objects. Confirm MultiTaskDRDataset does not need `task` to
# match the stage being trained.

# Number of leading mask channels included in the training Dice (MA, SE, HE, EX).
_NUM_DICE_CHANNELS = 4

for STAGE in ['segmentation', 'grading']:
    print("\n============================")
    print(f"Starting Stage: {STAGE.upper()}")
    print("============================\n")

    TASK = STAGE
    model = MSGDANet(num_classes=NUM_CLASSES).to(DEVICE)
    loss_fn = MSGDALoss(alpha=0.0, beta=0.0, lam=1.0, task=TASK).to(DEVICE)

    if TASK == 'grading':
        # Warm-start the grading stage from the best segmentation weights.
        seg_ckpt_path = os.path.join(checkpoint_dir, f'msgdanet_best_{DATASET_NAME.lower()}_segmentation.pth')
        if os.path.exists(seg_ckpt_path):
            model.load_state_dict(torch.load(seg_ckpt_path, map_location=DEVICE))
            print(f"Loaded segmentation checkpoint for grading stage: {seg_ckpt_path}")
        else:
            print(f"Segmentation checkpoint not found at {seg_ckpt_path}. Training grading stage from scratch.")

        # Freeze encoder, MSABs, LARB so only the remaining parameters train.
        for param in model.encoder.parameters():
            param.requires_grad = False
        for block_name in ("msab3", "msab4", "msab5", "larb"):
            for param in model.segmentation_branch[block_name].parameters():
                param.requires_grad = False

    # Only parameters that still require gradients are handed to the optimizer.
    # NOTE(review): adan_pytorch appears to define betas in a (1 - beta)
    # convention with small default values; confirm that (0.9, 0.999, 0.999)
    # is the intended setting for this package.
    optimizer = Adan(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999, 0.999),
    )
    scheduler = PolynomialLR(optimizer, total_iters=NUM_EPOCHS, power=0.9)

    # Checkpoint paths: `ckpt_file` holds the full resume state,
    # `best_model_path` holds only the best model weights.
    ckpt_file = os.path.join(checkpoint_dir, f'checkpoint_{DATASET_NAME.lower()}_{TASK}.pth')
    best_model_path = os.path.join(checkpoint_dir, f'msgdanet_best_{DATASET_NAME.lower()}_{TASK}.pth')

    start_epoch = 0
    best_score = -float("inf")
    early_stop_counter = 0

    if os.path.exists(ckpt_file):
        checkpoint = torch.load(ckpt_file, map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        scheduler.load_state_dict(checkpoint['scheduler_state'])
        best_score = checkpoint['best_score']
        early_stop_counter = checkpoint['early_stop_counter']
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resumed from checkpoint: {ckpt_file}")

    # ========== Train Loop ==========
    for epoch in range(start_epoch, NUM_EPOCHS):
        model.train()
        running = RunningAverage()

        pbar = tqdm(train_loader, desc=f"[{TASK.upper()}] Epoch {epoch+1}/{NUM_EPOCHS}", leave=False)
        # enumerate() supplies the batch index; tqdm's `n` attribute is only
        # refreshed when the bar is redrawn, so it is not a reliable counter.
        for batch_idx, batch in enumerate(pbar):
            images = batch['image'].to(DEVICE)
            masks = batch['masks'].to(DEVICE)
            labels = batch['label'].to(DEVICE)

            # Sanity print for the first three batches of the first epoch:
            # fraction of samples in the batch with a non-empty mask, per channel.
            if epoch == 0 and batch_idx < 3:
                lesion_mask_stats = (masks.sum(dim=(2, 3)) > 0).float().mean(dim=0)
                print("Lesion presence ratio:", lesion_mask_stats.tolist())

            optimizer.zero_grad()
            seg_preds, grade_logits = model(images)

            # Detach the output of the head that is not trained in this stage
            # so the loss cannot back-propagate through it.
            if TASK == 'grading':
                seg_preds = seg_preds.detach()
            if TASK == 'segmentation':
                grade_logits = grade_logits.detach()

            total_loss, _, _ = loss_fn(seg_preds, masks, grade_logits, labels)
            total_loss.backward()
            optimizer.step()

            # Per-batch training metrics (no gradients needed).
            with torch.no_grad():
                batch_acc = 0.0
                batch_dice = 0.0
                if TASK == 'grading':
                    preds = torch.argmax(grade_logits, dim=1)
                    valid = (labels != -1)  # -1 marks samples without a grade label
                    if valid.any():
                        batch_acc = (preds[valid] == labels[valid]).float().mean().item()
                if TASK == 'segmentation':
                    pred_bin = (torch.sigmoid(seg_preds) > 0.5).float()
                    valid_channels = 0

                    for ch in range(_NUM_DICE_CHANNELS):
                        p = pred_bin[:, ch].float()
                        t = masks[:, ch].float()

                        # Channels with no positive pixel anywhere in the batch
                        # are left out of the average.
                        if t.sum() > 0:
                            inter = (p * t).sum((1, 2))
                            denom = p.sum((1, 2)) + t.sum((1, 2))
                            batch_dice += (2 * inter / (denom + 1e-6)).mean().item()
                            valid_channels += 1

                    if valid_channels > 0:
                        batch_dice /= valid_channels

            # Track: update the running totals first so the progress bar shows
            # averages that include the current batch.
            update_kwargs = {'loss': total_loss.item()}
            if TASK == 'grading':
                update_kwargs['acc'] = batch_acc
            if TASK == 'segmentation':
                update_kwargs['dice'] = batch_dice
            running.update(**update_kwargs)

            postfix = {'L': f"{running.avg('loss'):.2f}"}
            if TASK == 'grading':
                postfix['A'] = f"{running.avg('acc'):.2f}"
            if TASK == 'segmentation':
                postfix['D'] = f"{running.avg('dice'):.4f}"
            pbar.set_postfix(postfix)

        avg_train_loss = running.avg('loss')
        avg_train_acc = running.avg('acc') if TASK == 'grading' else 0.0  # accuracy only exists for grading
        avg_train_dice = running.avg('dice') if TASK == 'segmentation' else 0.0  # Dice only exists for segmentation
        scheduler.step()

        # ===== Validation =====
        model.eval()
        val_loss = 0.0
        val_tracker = MetricsTracker()

        with torch.no_grad():
            for batch_idx, batch in enumerate(tqdm(val_loader, desc="Validation")):
                images = batch['image'].to(DEVICE)
                masks = batch['masks'].to(DEVICE)
                labels = batch['label'].to(DEVICE)

                # Under no_grad the outputs carry no graph, so no detach is needed.
                seg_preds, logits = model(images)

                loss, _, _ = loss_fn(seg_preds, masks, logits, labels)
                val_loss += loss.item()

                if TASK == 'grading':
                    preds = torch.argmax(logits, dim=1)
                    val_tracker.update_grading(preds, labels)

                if TASK == 'segmentation':
                    bin_preds = (torch.sigmoid(seg_preds) > 0.5).float()
                    val_tracker.update_segmentation(bin_preds, masks)

                    # Save one qualitative example per epoch from the first batch.
                    if batch_idx == 0:
                        vis_base = os.path.join(visualization_dir, f"{TASK}_epoch_{epoch+1:03d}")
                        visualize_predictions(images, masks, bin_preds, num_samples=1, save_path=vis_base, show_inline=True)

        # Guard against an empty validation loader.
        avg_val_loss = val_loss / max(len(val_loader), 1)

        # The value returned by summarize() is used as the model-selection
        # score below: a larger score counts as an improvement.
        score = val_tracker.summarize(
            train_loss=avg_train_loss,
            val_loss=avg_val_loss,
            train_acc=avg_train_acc if TASK == 'grading' else avg_train_dice,
            task=TASK
        )

        # Update the best-score / early-stopping state BEFORE writing the
        # resume checkpoint, so a resumed run restores this epoch's state
        # rather than the previous epoch's.
        improved = score > best_score
        if improved:
            best_score = score
            early_stop_counter = 0
            torch.save(model.state_dict(), best_model_path)
            print("✅ Best model saved.")
        else:
            early_stop_counter += 1

        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'scheduler_state': scheduler.state_dict(),
            'best_score': best_score,
            'early_stop_counter': early_stop_counter
        }, ckpt_file)

        if not improved and early_stop_counter >= EARLY_STOP_PATIENCE:
            print("⏹️ Early stopping triggered.")
            break

print("\n🏁 Both stages complete! Best models saved.\n")