<a href="https://colab.research.google.com/github/zi-bou/zi-bou/blob/main/Al4l_MONAI%2BPytorch_(V6).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🧠 Cardiac MRI Model Evolution Summary

| Version | Model                      | Augmentation                         | Loss        | Notes                                                    | Train Acc | Val Acc |
|---------|----------------------------|--------------------------------------|-------------|----------------------------------------------------------|-----------|---------|
| V1      | ResNet18                   | None                                 | CrossEntropy| Baseline model                                            | 38%       | 30%     |
| V2      | ResNet18                   | Basic 3D transforms                  | CrossEntropy| Added transforms to improve robustness                   | 43%       | 20%     |
| V3      | ResNet18 + Dropout         | Basic transforms                     | CrossEntropy| Dropout and oversampling added                           | 37%       | 30%     |
| V4      | ResNet18 + Dropout         | Class-specific (same for all)        | Focal Loss  | Introduced Focal Loss to handle class imbalance          | 43%       | 30%     |
| V5      | DenseNet121 + Dropout      | Class-specific (same for all)        | Focal Loss  | Switched to DenseNet121, achieved better generalization  | 37.5%     | 35%     |


# 🧠 Cardiac MRI Multi-Task Deep Learning – V6 : Multi-Task Deep Learning Pipeline – (Classification + Segmentation)
# BASIC ALGORITHM
──────────────────────────────────────────────────────────────────────────────

                    📁 Dataset Folder (on Google Drive)
                    └── [pXXXX]/
                         ├── frame01.nii          🫀 Diastolic image
                         ├── frame11.nii          🫀 Systolic image
                         ├── frame01_gt.nii       🎯 Segmentation mask(diastole)
                         ├── frame11_gt.nii       🎯 Segmentation mask (systole)
                         └── gt.txt               🏷️ Class label (0–4)
                         


**🔗 Load data from Google Drive**
──────────────────────────────────────────────────────────────────────────────

    │ Use `nibabel` to load .nii.gz images
    │ Extract 3D volumes: shape = [H, W, D]
    │  Read `gt.txt` as classification label
    │
    │  ➕ Stack frame01 & frame11 → [2, H, W, D]
    │  🔄 Permute to PyTorch format → [2, D, H, W]
    │
    │  Normalize intensities to [0, 1]
    │
    │  Optional: Apply transforms (flip, zoom, rotate, noise)
    │
    │  Prepare segmentation masks (diastole & systole):
    │    ➕ Stack → [2, D, H, W]
    │
    └──  - 🎯 Final input:    X = [2, D, H, W]   (2 channels)
         -  🎯 Final masks:   Y_seg = [2, D, H, W]
         -  🎯 Class label:   Y_cls ∈ [0, 4]


**📦 Model Architecture: Multi-Task DenseNet-121**
──────────────────────────────────────────────────────────

                           Input
                      [B, 2, D, H, W]
                             │
                             ▼
                  +----------------------+
                  |  Shared Encoder:     |
                  |  3D DenseNet-121     |
                  +----------------------+
                             │
             ┌───────────────┴───────────────┐
             ▼                               ▼
       +--------------------+         +------------------------+
       |  Segmentation Head |         |  Classification Head   |
       |  Decoder → [B, 2, D, H, W]   |  Linear → [B, 5]       |
       +--------------------+         +------------------------+


**⚙️ Training Loop**

──────────────────────────────────────────────────────────────────────────────

    for epoch in range(N):
  
      for batch in train_loader:

        - Forward pass:
            ▸ X → encoder
            ▸ Shared features → seg_output & class_output

        - Resize ground-truth segmentation masks
        - Compute losses:
            ▸ Classification loss: Focal Loss
            ▸ Segmentation loss: Dice Loss

        - Combine losses:
            total_loss = λ * Dice + (1 - λ) * Focal

        - Backpropagation
        - Optimizer step
        - Validation + Early Stopping

        🔄 Log accuracy & loss per epoch


**📊 Evaluation**

────────────────────────────────────────────
   
    - Plot Accuracy / Loss curves

    - Show Confusion Matrix

    - Display Class Distribution Histogram

──────────────────────────────────────────────────────────────

💡 **Outcome:**

    🔎 Classification: Predicts disease class [0–4]

    🎯 Segmentation: Localizes anatomical heart structures

──────────────────────────────────────────────────────────────


# STEP BY STEP DETAILED

In this project, we build a multi-task learning model that performs two tasks at once:

- 🎯 **Classification**: Predict the patient’s cardiac condition  
  Output shape: `[B, 5]`  
- 🧠 **Segmentation**: Segment anatomical structures (systole + diastole)  
  Output shape: `[B, 2, D, H, W]`

We use a **3D DenseNet121** as a shared encoder and branch into:

- A **classification head**
- A **segmentation decoder head**

---


#✅ Step 0: Fixing Binary Compatibility

On Colab, version mismatches (e.g., numpy, tensorflow, numba) can cause errors.

To prevent this:
- Uninstall conflicting packages
- Install specific versions:  
  - `numpy==1.23.5`  
  - `monai`, `nibabel`, `matplotlib`, `scikit-learn`, `torch`, etc.

---

In [2]:
# ✅ Step 0: Safe reinstallation to avoid version conflicts
# ------------------------------------------
# 🧹 Uninstall problematic pre-installed packages in Colab.
# These are known to cause binary conflicts when working with MONAI, PyTorch, etc.
!pip uninstall -y numpy numba tensorflow thinc

# 🧪 Reinstall a compatible version of NumPy for MONAI (1.23.5 works well with PyTorch and MONAI).
!pip install numpy==1.23.5

Found existing installation: numpy 2.0.2
Uninstalling numpy-2.0.2:
  Successfully uninstalled numpy-2.0.2
Found existing installation: numba 0.60.0
Uninstalling numba-0.60.0:
  Successfully uninstalled numba-0.60.0
Found existing installation: tensorflow 2.18.0
Uninstalling tensorflow-2.18.0:
  Successfully uninstalled tensorflow-2.18.0
Found existing installation: thinc 8.3.6
Uninstalling thinc-8.3.6:
  Successfully uninstalled thinc-8.3.6
Collecting numpy==1.23.5
  Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.3 kB)
Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.1/17.1 MB[0m [31m85.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency c

In [1]:
# Install MONAI (Medical Open Network for AI) and key dependencies for medical imaging and training
!pip install monai nibabel matplotlib scikit-learn torch torchvision torchaudio -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m21.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m76.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m47.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m48.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

#✅ Step 1: Mount Google Drive

Mount your Google Drive so we can access the dataset.

Each patient folder contains:
- `frame01.nii`: Systolic volume
- `frame11.nii`: Diastolic volume
- `frame01_gt.nii`, `frame11_gt.nii`: Corresponding segmentation masks
- `gt.txt`: Text file with the patient’s classification label (e.g., “class3”)

---

In [2]:
# ==========================================
# ✅ Step 1: Mount Google Drive
# ==========================================
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# ✅ Step 2: Imports

We import:
- `nibabel` to read .nii medical images
- `torch` for modeling, training, loss
- `monai` for medical models and transforms
- `sklearn` for label encoding and class balancing
- `matplotlib` for visualizations

---

In [3]:
# ==========================================
# ✅ Step 2: Import required libraries
# ==========================================
import os
import numpy as np
import nibabel as nib
import torch
from torch.utils.data import Dataset, DataLoader, random_split, WeightedRandomSampler
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# We'll use MONAI's DenseNet for the 3D encoder
from monai.networks.nets import densenet121
# Some MONAI transforms and dice loss for segmentation
from monai.transforms import Compose, RandFlip, RandRotate, RandGaussianNoise, RandZoom
from monai.losses import DiceLoss

# ✅ Step 3: Estimate Volume Shape

We scan one sample image per patient to estimate typical dimensions.

- Set target width/height to 224 (for DenseNet compatibility)
- Set depth to the 90th percentile of observed depths (or at least 32)

Final shape: `[2, 32, 224, 224]`

Padding/cropping is applied to match this shape for all volumes.

---

In [4]:
# ==========================================
# ✅ Step 3: Estimate volume shape (90th percentile)
# ==========================================
base_dir = '/content/drive/MyDrive/KAGGLE/Al4l/data/drive-download-20250107T191042Z-001/train'
depths, heights, widths = [], [], []

# Loop over each patient folder and extract shape from frame01.nii
for pid in os.listdir(base_dir):
    pdir = os.path.join(base_dir, pid)
    if not os.path.isdir(pdir):
        continue
    files = os.listdir(pdir)
    for f in files:
        if 'frame01' in f and 'gt' not in f and f.endswith('.nii'):
            # Load volume, extract shape
            vol = nib.load(os.path.join(pdir, f)).get_fdata()
            h, w, d = vol.shape
            heights.append(h)
            widths.append(w)
            depths.append(d)

# We'll fix height & width at 224, depth at 90th percentile or at least 32
target_height = 224
target_width = 224
target_depth = max(int(np.percentile(depths, 90)), 32)
print(f"📏 Target shape (H, W, D): {target_height}, {target_width}, {target_depth}")

📏 Target shape (H, W, D): 224, 224, 32


# ✅ Step 4: Define Custom PyTorch Dataset

We define a `CardiacDataset` class that:
- Loads systolic + diastolic volumes and stacks them as channels
- Loads both segmentation masks and stacks them too
- Pads/crops all data to the fixed shape
- Normalizes volume intensities to [0, 1]
- Applies optional transforms
- Returns:  
  - `x`: input volume [2, D, H, W]  
  - `y`: class label  
  - `seg`: segmentation masks [2, D, H, W]

---

In [5]:
# ==========================================
# ✅ Step 4: Multi-Task Dataset
# ==========================================
# We'll load:
# - Systolic volume (frame01.nii)
# - Diastolic volume (another frame e.g. frame11.nii)
# - Segmentation mask (both frame01_gt.nii and frame11_gt.nii)
# - Class label (from gt.txt)

class CardiacDatasetMT(Dataset):
    def __init__(self, data_dir, target_shape, transform=None):
        self.data_dir = data_dir
        self.target_height, self.target_width, self.target_depth = target_shape
        # List of patient folders
        self.patient_ids = [pid for pid in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, pid))]
        self.labels = []  # textual labels (string)
        self.transform = transform

        # For each patient, read the classification label from gt.txt
        for pid in self.patient_ids:
            label_path = os.path.join(data_dir, pid, 'gt.txt')
            with open(label_path, 'r') as f:
                self.labels.append(f.read().strip())

        # Convert textual labels to integer classes
        self.encoder = LabelEncoder()
        self.encoded_labels = self.encoder.fit_transform(self.labels)

    def __len__(self):
        # Return total number of patients
        return len(self.patient_ids)

    def pad_or_crop(self, volume):
        """
        Pad or crop each dimension (H, W, D) to match the target shape
        so final shape is [target_height, target_width, target_depth].
        We'll do zero-padding if it's smaller.
        """
        h, w, d = volume.shape
        pad_h = max(self.target_height - h, 0)
        pad_w = max(self.target_width - w, 0)
        pad_d = max(self.target_depth - d, 0)

        pad = (
            (pad_h//2, pad_h - pad_h//2),
            (pad_w//2, pad_w - pad_w//2),
            (pad_d//2, pad_d - pad_d//2)
        )
        volume = np.pad(volume, pad, mode='constant')
        volume = volume[:self.target_height, :self.target_width, :self.target_depth]
        return volume

    def __getitem__(self, idx):
        """
        Returns a dict:
        {
          'image': 2-channel volume [2, D, H, W],
          'mask': segmentation mask [1, D, H, W],
          'class': integer classification label
        }
        """
        pid = self.patient_ids[idx]
        pdir = os.path.join(self.data_dir, pid)
        files = os.listdir(pdir)

        # Identify files for systolic, diastolic, and segmentation
        systolic, diastolic = None, None
        seg1_file, seg2_file = None, None

        for f in files:
            if 'frame01' in f and 'gt' not in f and f.endswith('.nii'):
                systolic = f
            elif 'frame' in f and 'gt' not in f and 'frame01' not in f and f.endswith('.nii'):
                diastolic = f
            elif 'frame01_gt' in f and f.endswith('.nii'):
                seg1_file = f
            elif 'frame11_gt' in f and f.endswith('.nii'):
                seg2_file = f

        # Load volumes from .nii
        syst_data = nib.load(os.path.join(pdir, systolic)).get_fdata() if systolic else None
        diast_data = nib.load(os.path.join(pdir, diastolic)).get_fdata() if diastolic else None

        # Pad/crop both volumes
        syst_data = self.pad_or_crop(syst_data)
        diast_data = self.pad_or_crop(diast_data)

        # Stack as channels => shape [2, H, W, D]
        combined = np.stack((syst_data, diast_data), axis=0)
        # Normalize intensities to [0,1]
        combined = combined / np.max(combined)

        # Convert to torch tensor => [2, H, W, D] => permute to [2, D, H, W]
        combined = torch.tensor(combined, dtype=torch.float32)
        combined = combined.permute(0, 3, 1, 2)

        # Load segmentation masks and stack them
        if seg1_file and seg2_file:
            seg1 = nib.load(os.path.join(pdir, seg1_file)).get_fdata()
            seg2 = nib.load(os.path.join(pdir, seg2_file)).get_fdata()

            seg1 = self.pad_or_crop(seg1)
            seg2 = self.pad_or_crop(seg2)

            # Convert any label > 0 => 1 (binary mask)
            seg1 = (seg1 > 0).astype(np.float32)
            seg2 = (seg2 > 0).astype(np.float32)

            # Combine the two binary masks by max (union)
            combined_seg = np.maximum(seg1, seg2)
            seg_data = torch.tensor(combined_seg, dtype=torch.float32)
            seg_data = seg_data.permute(2, 0, 1).unsqueeze(0)  # [1, D, H, W]

        else:
            # Fill with zeros if missing masks
            seg_data = torch.zeros((1, self.target_depth, self.target_height, self.target_width), dtype=torch.float32)

        # Apply augmentation only to the image (not mask)
        if self.transform:
            combined = self.transform(combined)

        class_label = self.encoded_labels[idx]

        return {
            "image": combined,     # [2, D, H, W]
            "mask": seg_data,      # [1, D, H, W]
            "class": class_label   # int
        }


# ✅ Step 5: Data Augmentation (Optional)

We define random transformations:
- Flips
- Rotations
- Zoom
- Gaussian noise

These help improve generalization.

---

In [6]:
# ============================
# ✅ Step 5: Class-specific data augmentations
# ============================
base_transform = Compose([
    RandFlip(spatial_axis=0, prob=0.5),
    RandRotate(range_x=0.1, prob=0.5),
    RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
    RandGaussianNoise(prob=0.3, mean=0.0, std=0.05)
])

class CardiacDatasetWithClassAugmentMT(CardiacDatasetMT):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # We'll map class -> transform so each class can get specialized transforms
        # For now, all are the same
        self.class_transforms = {i: base_transform for i in range(5)}

    def __getitem__(self, idx):
        data_dict = super().__getitem__(idx)
        label = data_dict["class"]
        transform = self.class_transforms.get(label, None)
        if transform:
            # Apply transform only on "image"
            data_dict["image"] = transform(data_dict["image"])
        return data_dict

# ✅ Step 6: Dataset Preparation and Class Balancing

- Apply weighted sampling to address class imbalance
- Create training and validation loaders
- Use `WeightedRandomSampler` for training

---

In [7]:
# ============================
# ✅ Step 6: Dataset prep, Weighted Sampling
# ============================
target_shape = (target_height, target_width, target_depth)
dataset = CardiacDatasetWithClassAugmentMT(base_dir, target_shape)

# Compute class weights for classification
y_all = [dataset.encoded_labels[i] for i in range(len(dataset))]
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_all), y=y_all)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class_weights_tensor = class_weights_tensor.to(device)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_ds, val_ds = random_split(dataset, [train_size, val_size])

# Weighted sampler for training
train_indices = train_ds.indices if hasattr(train_ds, 'indices') else list(range(len(train_ds)))
train_labels = [dataset.encoded_labels[i] for i in train_indices]
train_weights = [class_weights[label] for label in train_labels]
sampler = WeightedRandomSampler(train_weights, num_samples=len(train_weights), replacement=True)

train_loader = DataLoader(train_ds, batch_size=2, sampler=sampler)
val_loader = DataLoader(val_ds, batch_size=2)

# Free up any leftover GPU memory
torch.cuda.empty_cache()


# ✅ Step 7: Define Model

We use a modified `DenseNet121` from MONAI:

- Shared 3D encoder
- Two heads:
  - One for classification (`Linear(1024, 5)`)
  - One for segmentation (`ConvTranspose3d` decoder)

---

In [8]:
# ============================
# ✅ Step 7: Multi-task DenseNet
# ============================
# We'll define a single 3D DenseNet as encoder,
# then produce 2 outputs: class_out (5 classes), seg_out (1 channel).


# ✅ Multi-task model: classification + segmentation
class MultiTaskDenseNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Shared encoder
        base_model = densenet121(spatial_dims=3, in_channels=2, out_channels=5)
        self.backbone = base_model.features  # Just the feature extractor

        # Classifier head (after global pooling)
        self.class_dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(1024, 5)

        # Segmentation head (upconvolution to full resolution)
        self.seg_head = nn.Sequential(
            nn.ConvTranspose3d(1024, 512, kernel_size=2, stride=2),  # Upsample
            nn.ReLU(inplace=True),
            nn.ConvTranspose3d(512, 256, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose3d(256, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose3d(64, 1, kernel_size=2, stride=2),      # Output 1 channel mask
            nn.Sigmoid()  # Since we use DiceLoss
        )

    def forward(self, x):
        print("[DEBUG] Input:", x.shape)  # [B, 2, 32, 224, 224]
        x = self.backbone(x)              # [B, 1024, D', H', W'] → e.g., [B, 1024, 2, 14, 14]
        print("[DEBUG] Features:", x.shape)

        # Classification head
        x_class = F.adaptive_avg_pool3d(x, (1, 1, 1))  # [B, 1024, 1, 1, 1]
        x_class = torch.flatten(x_class, 1)            # [B, 1024]
        x_class = self.class_dropout(x_class)
        x_class = self.classifier(x_class)             # [B, 5]

        # Segmentation head
        x_seg = self.seg_head(x)                       # [B, 1, 32, 224, 224] ideally
        print("[DEBUG] Segmentation output:", x_seg.shape)

        return x_class, x_seg

model = MultiTaskDenseNet().to('cuda' if torch.cuda.is_available() else 'cpu')

# ✅ Step 8: Define Loss Function

We combine:
- `FocalLoss` for classification
- `DiceLoss` for segmentation

The total loss = weighted sum of both.

---

In [9]:
# ============================
# ✅ Step 8: Define losses & optimizer
# ============================
# We'll combine:
# 1) Focal Loss for classification
# 2) Dice Loss (sigmoid) for segmentation (binary)

class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none', weight=self.alpha)
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        return focal_loss.mean() if self.reduction == 'mean' else focal_loss.sum()

focal_loss = FocalLoss(alpha=class_weights_tensor, gamma=2.0)
dice_loss = DiceLoss(sigmoid=True)  # binary segmentation => use sigmoid

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

def multi_task_loss_fn(class_pred, class_label, seg_pred, seg_mask):
    """
    Combine classification + segmentation losses into a single total_loss.
    - class_pred: [B, 5]  => for classification
    - class_label: [B]    => integer class
    - seg_pred: [B, 1, D', H', W'] => raw logits for seg
    - seg_mask: [B, D', H', W']    => binary target
    """
    # classification loss
    c_loss = focal_loss(class_pred, class_label)

    # segmentation: need channels => seg_mask => [B, 1, D', H', W']
    seg_mask = seg_mask.unsqueeze(1)
    s_loss = dice_loss(seg_pred, seg_mask)

    # Weighted sum of both
    total = c_loss + 0.5 * s_loss
    return total, c_loss.item(), s_loss.item()

# ✅ Step 9: Training Loop

- Train for up to 30 epochs
- Track training and validation accuracy/loss
- Use `ReduceLROnPlateau` to reduce LR on stagnating validation loss
- Use early stopping to save the best model

---

In [10]:
# ============================
# ✅ Step 9: Training Loop for Multi-Task Learning
# ============================
train_acc_list, val_acc_list = [], []
train_loss_list, val_loss_list = [], []
best_val_loss = float('inf')
early_stop_counter = 0

for epoch in range(1, 31):
    model.train()
    running_loss, correct = 0.0, 0
    total_samples = 0

    for batch in train_loader:
        inputs = batch["image"].to(device)       # [B, 2, D, H, W]
        seg_masks = batch["mask"].to(device)     # [B, D, H, W]
        labels = batch["class"].to(device)       # [B]

        optimizer.zero_grad()
        class_out, seg_out = model(inputs)

        # 🛠️ Ensure segmentation masks are in shape [B, 1, D, H, W] before resizing
        if seg_masks.ndim == 4:
            seg_masks = seg_masks.unsqueeze(1)  # [B, D, H, W] → [B, 1, D, H, W]
        elif seg_masks.ndim == 6:
            seg_masks = seg_masks.squeeze(1)    # Remove singleton if shape is [B, 1, 1, D, H, W]

        print("[DEBUG] Input to model:", inputs.shape)
        print("[DEBUG] Classification output:", class_out.shape)
        print("[DEBUG] Raw segmentation output:", seg_out.shape)
        print("[DEBUG] Ground truth seg_masks shape before resize:", seg_masks.shape)

        # Resize seg_masks to match seg_out shape
        seg_masks_resized = F.interpolate(seg_masks, size=seg_out.shape[2:], mode='trilinear', align_corners=False)
        seg_masks_resized = seg_masks_resized.squeeze(2)  # Remove the singleton channel dimension


        # Final shape check before loss
        if seg_masks_resized.shape != seg_out.shape:
            print("[DEBUG] Shape mismatch before loss")
            print(f"  -> seg_out: {seg_out.shape}")
            print(f"  -> seg_masks_resized: {seg_masks_resized.shape}")
            raise ValueError("Mismatch between segmentation prediction and target shapes")

        print("[DEBUG] Resized segmentation mask:", seg_masks_resized.shape)

        # Compute joint loss (classification + segmentation)
        loss, c_l, s_l = multi_task_loss_fn(class_out, labels, seg_out, seg_masks_resized)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        preds = class_out.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total_samples += labels.size(0)

    train_acc = correct / total_samples
    train_acc_list.append(train_acc)
    train_loss_list.append(running_loss / len(train_loader))

    # =====================
    # 🔍 Validation Phase
    # =====================
    model.eval()
    val_correct, val_loss_epoch = 0, 0.0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch["image"].to(device)
            seg_masks = batch["mask"].to(device)
            labels = batch["class"].to(device)

            class_out, seg_out = model(inputs)

            # Same preprocessing for val seg masks
            if seg_masks.ndim == 4:
                seg_masks = seg_masks.unsqueeze(1)
            elif seg_masks.ndim == 6:
                seg_masks = seg_masks.squeeze(1)

            seg_masks_resized = F.interpolate(seg_masks, size=seg_out.shape[2:], mode='trilinear', align_corners=False)

            # Final shape check
            if seg_masks_resized.shape != seg_out.shape:
                print("[DEBUG] Validation shape mismatch")
                print(f"  -> seg_out: {seg_out.shape}")
                print(f"  -> seg_masks_resized: {seg_masks_resized.shape}")
                raise ValueError("Validation mismatch between segmentation prediction and target shapes")

            loss, c_l, s_l = multi_task_loss_fn(class_out, labels, seg_out, seg_masks_resized)

            val_loss_epoch += loss.item()
            preds = class_out.argmax(dim=1)
            val_correct += (preds == labels).sum().item()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    val_acc = val_correct / len(val_ds)
    val_acc_list.append(val_acc)
    val_loss_epoch = val_loss_epoch / len(val_loader)
    val_loss_list.append(val_loss_epoch)
    scheduler.step(val_loss_epoch)

    print(f"Epoch {epoch}: Train Acc = {train_acc:.4f}, Val Acc = {val_acc:.4f}, Train Loss = {train_loss_list[-1]:.4f}, Val Loss = {val_loss_epoch:.4f}")

    if val_loss_epoch < best_val_loss:
        best_val_loss = val_loss_epoch
        early_stop_counter = 0
        torch.save(model.state_dict(), "/content/best_model_densenet121_v6.pth")
        print("✅ Saved new best model")
    else:
        early_stop_counter += 1
        if early_stop_counter >= 5:
            print("⏹️ Early stopping triggered.")
            break


[DEBUG] Input: torch.Size([2, 2, 32, 224, 224])
[DEBUG] Features: torch.Size([2, 1024, 1, 7, 7])
[DEBUG] Segmentation output: torch.Size([2, 1, 16, 112, 112])
[DEBUG] Input to model: torch.Size([2, 2, 32, 224, 224])
[DEBUG] Classification output: torch.Size([2, 5])
[DEBUG] Raw segmentation output: torch.Size([2, 1, 16, 112, 112])
[DEBUG] Ground truth seg_masks shape before resize: torch.Size([2, 1, 32, 224, 224])
[DEBUG] Resized segmentation mask: torch.Size([2, 1, 16, 112, 112])


AssertionError: ground truth has different shape (torch.Size([2, 1, 1, 16, 112, 112])) from input (torch.Size([2, 1, 16, 112, 112]))

# ✅ Step 10: Evaluation

We visualize:
- Accuracy and loss curves
- Confusion matrix
- Class distribution histogram

Model is saved to:  
`/content/best_model_densenet121_v6.pth`


In [None]:
# ============================
# ✅ Step 10: Visualization - learning curves + confusion matrix
# ============================
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_acc_list, label='Train Accuracy')
plt.plot(val_acc_list, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Multi-task Accuracy Curve')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(train_loss_list, label='Train Loss')
plt.plot(val_loss_list, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Multi-task Loss Curve')
plt.legend()
plt.grid(True)
plt.show()

cm = confusion_matrix(all_labels, all_preds)
display = ConfusionMatrixDisplay(confusion_matrix=cm)
display.plot(cmap='Blues')
plt.title("Validation Confusion Matrix")
plt.show()

print("✅ Best model saved at: /content/best_model_densenet121_v6.pth")
print("✅ Multi-task learning completed!")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
📏 Target shape (H, W, D): 224, 224, 32
[DEBUG] Input: torch.Size([2, 2, 32, 224, 224])
[DEBUG] Features: torch.Size([2, 1024, 1, 7, 7])
[DEBUG] Segmentation output: torch.Size([2, 1, 16, 112, 112])


AssertionError: ground truth has different shape (torch.Size([2, 1, 32, 16, 112, 112])) from input (torch.Size([2, 1, 16, 112, 112]))