In [None]:
! pip install monai

In [None]:
import os
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from monai.losses import DiceLoss
from scipy.ndimage import zoom
from sklearn.model_selection import train_test_split
import copy

In [None]:
# Dataset directory. Override with the ACDC_DATASET_DIR environment variable;
# the default is the original Google Drive (Colab) location.
DATASET_DIR = os.environ.get("ACDC_DATASET_DIR", "/content/drive/MyDrive/NTU/ACDC_Dataset")

def parse_cfg(cfg_path):
    """Parse an ACDC ``Info.cfg`` file of ``Key: value`` lines.

    Blank lines are ignored. Keys and values are whitespace-stripped, and only the
    first ``:`` on a line separates key from value.

    Args:
        cfg_path: path to the ``Info.cfg`` file.

    Returns:
        Tuple ``(ed_frame, es_frame, group)`` with the ED/ES frames as ints and the
        group as a string.

    Raises:
        ValueError: a non-blank line has no ``:``, or ED/ES is not an integer.
        KeyError: ED, ES or Group is missing.
    """
    cfg_data = {}
    with open(cfg_path, "r") as file:
        for raw_line in file:
            line = raw_line.strip()
            if not line:
                # Tolerate blank / trailing newline-only lines.
                continue
            key, sep, value = line.partition(":")
            if not sep:
                raise ValueError(f"Malformed line in {cfg_path!r}: {raw_line!r}")
            cfg_data[key.strip()] = value.strip()
    return int(cfg_data["ED"]), int(cfg_data["ES"]), cfg_data["Group"]

class ACDCDataset(Dataset):
    """ACDC dataset with one item per patient, holding the end-diastolic (ED) frame.

    Each entry of ``file_list`` is a patient directory containing ``Info.cfg`` and
    the NIfTI volumes ``patient<ID>_frame<ED>.nii.gz`` / ``..._gt.nii.gz``, where
    ``<ID>`` is the last three characters of the directory name. Items are dicts with
    ``"image"`` (float32, shape ``(1, *target_shape)``, values in [0, 1]) and
    ``"label"`` (int64, shape ``(1, *target_shape)``).

    Args:
        file_list: patient directory paths.
        transform: optional callable applied to the sample dict before it is
            returned. It receives the dict and its return value is what the
            dataset yields.
        target_shape: spatial size every volume is resampled to.
    """

    def __init__(self, file_list, transform=None, target_shape=(128, 128, 64)):
        self.file_list = file_list
        self.transform = transform
        self.target_shape = tuple(target_shape)

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def _preprocess(image, label, target_shape=(128, 128, 64)):
        """Min-max normalise ``image`` to [0, 1] and resample image and label to ``target_shape``."""
        image = np.asarray(image, dtype=np.float64)
        lo, hi = image.min(), image.max()
        if hi > lo:
            image = (image - lo) / (hi - lo)
        else:
            # Constant volume: 0/0 would yield NaNs that poison every loss term.
            image = np.zeros_like(image)
        label = np.asarray(label).astype(np.int32)
        # Linear interpolation for intensities.
        image = zoom(image, [t / s for t, s in zip(target_shape, image.shape)], order=1)
        # Nearest neighbour (order=0) keeps label values integral.
        label = zoom(label, [t / s for t, s in zip(target_shape, label.shape)], order=0)
        return image, label

    def __getitem__(self, idx):
        patient_dir = self.file_list[idx]
        ed_frame, _, _ = parse_cfg(os.path.join(patient_dir, "Info.cfg"))

        # normpath drops a trailing separator so basename is the patient folder name.
        patient_id = os.path.basename(os.path.normpath(patient_dir))[-3:]
        stem = f"patient{patient_id}_frame{ed_frame:02d}"
        ed_image = nib.load(os.path.join(patient_dir, f"{stem}.nii.gz")).get_fdata()
        ed_label = nib.load(os.path.join(patient_dir, f"{stem}_gt.nii.gz")).get_fdata()

        ed_image, ed_label = self._preprocess(ed_image, ed_label, self.target_shape)

        sample = {
            "image": torch.tensor(ed_image[None, ...], dtype=torch.float32),
            "label": torch.tensor(ed_label[None, ...], dtype=torch.long),
        }
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

# Prepare the dataset and data loaders.
# os.listdir order is filesystem-dependent; sorting makes the seeded split below
# reproducible across machines. Only directories whose name starts with "patient" are kept.
training_dir = os.path.join(DATASET_DIR, "training")
patients = sorted(
    os.path.join(training_dir, name)
    for name in os.listdir(training_dir)
    if name.startswith("patient") and os.path.isdir(os.path.join(training_dir, name))
)
train_patients, val_patients = train_test_split(patients, test_size=0.2, random_state=42)

train_dataset = ACDCDataset(train_patients)
val_dataset   = ACDCDataset(val_patients)

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=2)


In [None]:
def window_partition(x, window_size):
    """Split a (B, H, W, D, C) tensor into non-overlapping 3D windows.

    Args:
        x: tensor of shape (B, H, W, D, C). H, W and D must be divisible by the
            matching ``window_size`` entries, otherwise ``view`` raises.
        window_size: (w1, w2, w3) window extent along H, W, D.

    Returns:
        Tensor of shape (B * num_windows, w1 * w2 * w3, C). Windows are ordered by
        (batch, H-window, W-window, D-window); voxels inside a window are in
        row-major (h, w, d) order.
    """
    batch, height, width, depth, channels = x.shape
    wh, ww, wd = window_size
    grid = x.view(batch, height // wh, wh, width // ww, ww, depth // wd, wd, channels)
    # Move the three window-index axes next to batch, and the in-window offsets after them.
    grid = grid.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
    return grid.view(-1, wh * ww * wd, channels)

def window_reverse(windows, window_size, B, H, W, D):
    """Inverse of ``window_partition``: reassemble windows into a (B, H, W, D, C) tensor.

    Args:
        windows: tensor of shape (B * num_windows, w1 * w2 * w3, C) in the window
            order produced by ``window_partition``.
        window_size: (w1, w2, w3) window extent along H, W, D.
        B, H, W, D: batch size and spatial size of the reconstructed tensor.

    Returns:
        Tensor of shape (B, H, W, D, C).
    """
    wh, ww, wd = window_size
    blocks = windows.view(B, H // wh, W // ww, D // wd, wh, ww, wd, -1)
    # Interleave window-index and in-window offset axes back into (H, W, D) order.
    volume = blocks.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
    return volume.view(B, H, W, D, -1)


In [None]:
class SwinTransformerBlock3D(nn.Module):
    """Pre-norm transformer block with multi-head self-attention inside 3D windows.

    Tokens are regrouped into non-overlapping ``window_size`` windows and attend
    only to tokens in the same window.

    NOTE(review): windows are never shifted and no relative position bias is
    used, so (unlike the reference Swin design) there is no cross-window
    information flow inside this block.

    Args:
        embed_dim: token feature size C.
        window_size: (w1, w2, w3) window extent along H, W, D.
        num_heads: attention heads; ``nn.MultiheadAttention`` requires it to
            divide ``embed_dim``.
        dropout: dropout used in attention and in the MLP.
        mlp_ratio: MLP hidden width relative to ``embed_dim``.
    """

    def __init__(self, embed_dim, window_size, num_heads, dropout=0.1, mlp_ratio=4.0):
        super(SwinTransformerBlock3D, self).__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, H, W, D):
        """Apply windowed attention and an MLP, each with a residual connection.

        Args:
            x: tokens of shape (B, N, C) with N == H * W * D, laid out as a
                row-major flattening of an (H, W, D) grid.
            H, W, D: token-grid size. Each must be divisible by the matching
                ``window_size`` entry.

        Returns:
            Tensor of shape (B, N, C).

        Raises:
            ValueError: if N != H * W * D.
        """
        B, N, C = x.shape
        if N != H * W * D:
            raise ValueError(f"expected N == H*W*D == {H * W * D} tokens, got {N}")
        windows = window_partition(x.reshape(B, H, W, D, C), self.window_size)
        # Pre-norm residual: LayerNorm is applied to the attention input only, and
        # the skip path keeps the un-normalised tokens.
        normed = self.norm1(windows)
        attn_windows, _ = self.attn(normed, normed, normed, need_weights=False)
        windows = windows + attn_windows
        windows = windows + self.mlp(self.norm2(windows))
        x = window_reverse(windows, self.window_size, B, H, W, D)
        return x.reshape(B, N, C)

class SwinViTSegmentation(nn.Module):
    """Patch-based windowed-attention segmentation network for 3D volumes.

    Pipeline:
        1. Strided Conv3d patch embedding.
        2. Learned absolute position embedding.
        3. ``num_layers`` SwinTransformerBlock3D.
        4. 1x1x1 Conv3d class head.
        5. Trilinear upsampling of the patch-level logits to the input's spatial size.

    Args:
        in_channels: input channels.
        num_classes: number of output logit channels.
        embed_dim: token feature size.
        patch_size: (p1, p2, p3) patch size, also used as the embedding stride.
        window_size: attention window size, in patches.
        num_layers: number of transformer blocks.
        num_heads: attention heads per block.
        dropout: dropout rate inside each block.
        img_size: expected input spatial size (H, W, D). It fixes the number of
            position embeddings, so inputs must yield the same patch grid
            (``img_size // patch_size``).
    """

    def __init__(self, in_channels=1, num_classes=4, embed_dim=128,
                 patch_size=(16, 16, 16), window_size=(2, 2, 2),
                 num_layers=2, num_heads=4, dropout=0.1, img_size=(128, 128, 64)):
        super(SwinViTSegmentation, self).__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        # Patch embedding
        self.patch_embed = nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1], img_size[2] // patch_size[2])
        num_patches = grid_size[0] * grid_size[1] * grid_size[2]
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.layers = nn.ModuleList([
            SwinTransformerBlock3D(embed_dim, window_size, num_heads, dropout)
            for _ in range(num_layers)
        ])
        # Segmentation head (1x1x1 convolution)
        self.seg_head = nn.Conv3d(embed_dim, num_classes, kernel_size=1)

    def forward(self, x):
        """Return logits of shape (B, num_classes, H, W, D) for input x of shape (B, in_channels, H, W, D)."""
        spatial_size = tuple(x.shape[2:])
        tokens = self.patch_embed(x)  # (B, embed_dim, H_patch, W_patch, D_patch)
        B, C, H_patch, W_patch, D_patch = tokens.shape
        tokens = tokens.flatten(2).transpose(1, 2)  # (B, N, C)
        tokens = tokens + self.pos_embed
        for layer in self.layers:
            tokens = layer(tokens, H_patch, W_patch, D_patch)
        feats = tokens.transpose(1, 2).reshape(B, C, H_patch, W_patch, D_patch)
        logits = self.seg_head(feats)  # (B, num_classes, H_patch, W_patch, D_patch)
        # Resize to the exact input size rather than scale_factor=patch_size, so the
        # output matches the label volume even when a dimension is not a multiple of
        # the patch size.
        return F.interpolate(logits, size=spatial_size, mode='trilinear', align_corners=False)


In [None]:
# Loss Functions

def dice_loss(output, target):
    """Soft Dice loss via MONAI ``DiceLoss``.

    Softmax is applied to ``output`` and ``target`` is one-hot encoded inside MONAI
    (``softmax=True``, ``to_onehot_y=True``).

    Args:
        output: logits, expected shape (B, C, ...).
        target: integer labels, expected shape (B, 1, ...).
    """
    criterion = DiceLoss(to_onehot_y=True, softmax=True)
    return criterion(output, target)

def cross_entropy_loss(output, target):
    """Voxel-wise cross-entropy between logits (B, C, ...) and labels (B, 1, ...).

    The singleton channel axis of ``target`` is dropped before calling
    ``F.cross_entropy``.
    """
    return F.cross_entropy(output, target.squeeze(1))

def normal_consistency_loss(output, target):
    """L1 distance between spatial gradients of predicted probabilities and one-hot labels.

    The prediction is softmax-ed over the class axis and the target is one-hot encoded
    to the same (B, C, ...) shape. Finite-difference gradients are then compared along
    every spatial axis, so like is compared with like (per-class probabilities
    vs. per-class indicators).

    Args:
        output: logits of shape (B, C, *spatial). Each spatial size must be >= 2
            for ``torch.gradient``.
        target: integer labels of shape (B, 1, *spatial) with values in [0, C).

    Returns:
        0-dim tensor: the mean over spatial axes of the per-axis L1 gradient difference.
    """
    num_classes = output.size(1)
    probs = F.softmax(output, dim=1)
    target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=num_classes)
    target_one_hot = target_one_hot.movedim(-1, 1).to(probs.dtype)
    spatial_dims = tuple(range(2, output.dim()))
    prob_grads = torch.gradient(probs, dim=spatial_dims)
    target_grads = torch.gradient(target_one_hot, dim=spatial_dims)
    total = sum(F.l1_loss(pg, tg) for pg, tg in zip(prob_grads, target_grads))
    return total / len(prob_grads)

def wasserstein_distance_loss(output, target):
    """Mean absolute difference of per-class cumulative mass along flattened voxels.

    For each sample and class, predicted probabilities and one-hot labels are
    flattened over the spatial axes and cumulatively summed. The absolute difference
    of the two cumulative sums is divided by the voxel count, which keeps the term in
    [0, 1] regardless of volume size so its weight is comparable to the other loss
    terms.

    Args:
        output: logits of shape (B, C, *spatial).
        target: integer labels of shape (B, 1, *spatial) with values in [0, C).

    Returns:
        0-dim tensor.
    """
    probs = F.softmax(output, dim=1)
    B, C = probs.shape[0], probs.shape[1]
    probs_flat = probs.reshape(B, C, -1)
    num_voxels = probs_flat.size(-1)
    target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=C).movedim(-1, 1)
    target_flat = target_one_hot.reshape(B, C, -1).to(probs_flat.dtype)
    # cumsum(p) - cumsum(t) == cumsum(p - t); normalise by the voxel count.
    cdf_diff = torch.cumsum(probs_flat - target_flat, dim=-1) / num_voxels
    return cdf_diff.abs().mean()

def hausdorff_loss(output, target):
    """Mean symmetric Hausdorff distance between predicted and true class boundaries.

    For every sample and class:
        1. Extract the boundary voxels of the argmax prediction and of the ground truth.
        2. Compute the symmetric Hausdorff distance (in voxel units) between the two
           point sets.

    (sample, class) pairs where either boundary is empty contribute 0, and the sum is
    divided by B * C.

    NOTE: predictions go through argmax, so the result carries no gradient w.r.t.
    ``output``. In a compound loss this term changes the reported loss value but not
    the parameter updates.

    Args:
        output: logits of shape (B, C, H, W, D).
        target: integer labels of shape (B, 1, H, W, D).

    Returns:
        0-dim tensor on ``output``'s device.
    """
    preds = torch.argmax(output, dim=1)
    B, C = output.shape[0], output.shape[1]
    total = output.new_zeros(())
    for b in range(B):
        for c in range(C):
            gt_coords = _mask_boundary(target[b, 0] == c).nonzero(as_tuple=False).float()
            pred_coords = _mask_boundary(preds[b] == c).nonzero(as_tuple=False).float()
            if gt_coords.numel() == 0 or pred_coords.numel() == 0:
                continue
            dists = torch.cdist(pred_coords, gt_coords, p=2)
            hd_pred_to_gt = dists.min(dim=1).values.max()
            hd_gt_to_pred = dists.min(dim=0).values.max()
            total = total + torch.max(hd_pred_to_gt, hd_gt_to_pred)
    return total / (B * C)


def _mask_boundary(mask, kernel_size=3):
    """Return a bool mask of foreground voxels having a background voxel in their kernel neighbourhood."""
    m = mask.float()[None, None]
    # Erosion (min-pool) computed as -maxpool(-m). max_pool3d pads implicitly with -inf,
    # so positions outside the volume never produce boundary voxels.
    eroded = -F.max_pool3d(-m, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
    return (m - eroded)[0, 0] > 0

# Compound loss variants
def compound_loss_variant1(output, target, weights):
    """
    Variant 1: weighted sum of CE + NC + Hausdorff + Dice.

    Args:
        output: logits, expected shape (B, C, H, W, D).
        target: integer labels, expected shape (B, 1, H, W, D).
        weights: sequence whose first four entries weight, in order, the
            cross-entropy, normal-consistency, Hausdorff and Dice terms.
    """
    weighted_terms = [
        weights[0] * cross_entropy_loss(output, target),
        weights[1] * normal_consistency_loss(output, target),
        weights[2] * hausdorff_loss(output, target),
        weights[3] * dice_loss(output, target),
    ]
    return weighted_terms[0] + weighted_terms[1] + weighted_terms[2] + weighted_terms[3]

def compound_loss_variant2(output, target, weights):
    """
    Variant 2: weighted sum of CE + NC + Wasserstein + Dice.

    Args:
        output: logits, expected shape (B, C, H, W, D).
        target: integer labels, expected shape (B, 1, H, W, D).
        weights: sequence whose first four entries weight, in order, the
            cross-entropy, normal-consistency, Wasserstein and Dice terms.
    """
    weighted_terms = [
        weights[0] * cross_entropy_loss(output, target),
        weights[1] * normal_consistency_loss(output, target),
        weights[2] * wasserstein_distance_loss(output, target),
        weights[3] * dice_loss(output, target),
    ]
    return weighted_terms[0] + weighted_terms[1] + weighted_terms[2] + weighted_terms[3]


In [None]:
def compute_iou(preds, targets, num_classes):
    """Mean intersection-over-union across ``num_classes`` classes.

    Args:
        preds: logits of shape (B, C, ...); the class is the argmax over dim 1.
        targets: integer labels of shape (B, 1, ...).
        num_classes: number of classes to average over.

    Returns:
        0-dim tensor. A class absent from both prediction and target scores 1.0.
    """
    predicted = torch.argmax(preds, dim=1)
    truth = targets.squeeze(1)

    def class_iou(cls):
        pred_mask = (predicted == cls).float()
        true_mask = (truth == cls).float()
        inter = (pred_mask * true_mask).sum()
        union = torch.clamp(pred_mask + true_mask, 0, 1).sum()
        if union == 0:
            return torch.tensor(1.0, device=predicted.device)
        return inter / union

    return torch.stack([class_iou(cls) for cls in range(num_classes)]).mean()

def pixel_accuracy(preds, targets):
    """Fraction of voxels whose argmax class equals the label, as a Python float.

    Args:
        preds: logits of shape (B, C, ...).
        targets: integer labels of shape (B, 1, ...).
    """
    matches = torch.argmax(preds, dim=1) == targets.squeeze(1)
    return matches.sum().item() / targets.numel()

def precision_recall(preds, targets, num_classes):
    """Macro-averaged precision and recall over ``num_classes`` classes.

    Args:
        preds: logits of shape (B, C, ...); the class is the argmax over dim 1.
        targets: integer labels of shape (B, 1, ...).
        num_classes: number of classes to average over.

    Returns:
        ``(mean_precision, mean_recall)`` as Python floats. A class with no
        predictions contributes 0 to precision; a class with no ground truth
        contributes 0 to recall.
    """
    predicted = torch.argmax(preds, dim=1)
    truth = targets.squeeze(1)
    precisions = []
    recalls = []
    for cls in range(num_classes):
        is_pred = predicted == cls
        is_true = truth == cls
        tp = (is_pred & is_true).sum().item()
        fp = (is_pred & ~is_true).sum().item()
        fn = (~is_pred & is_true).sum().item()
        precisions.append(tp / (tp + fp + 1e-7))
        recalls.append(tp / (tp + fn + 1e-7))
    return sum(precisions) / len(precisions), sum(recalls) / len(recalls)

def dice_coefficient(preds, targets, num_classes=None):
    """Mean per-class Dice coefficient over all classes.

    Every class in ``range(num_classes)`` is scored, including classes the model
    never predicts, so missing classes lower the score instead of being silently
    dropped. A class absent from both prediction and target scores 1.0, matching
    ``compute_iou``.

    Args:
        preds: logits of shape (B, C, ...); the class is the argmax over dim 1.
        targets: integer labels of shape (B, 1, ...).
        num_classes: number of classes; defaults to ``preds.shape[1]``.

    Returns:
        0-dim tensor.
    """
    if num_classes is None:
        num_classes = preds.shape[1]
    labels_pred = torch.argmax(preds, dim=1)
    labels_true = targets.squeeze(1)
    dice_scores = []
    for cls in range(num_classes):
        pred_mask = (labels_pred == cls).float()
        target_mask = (labels_true == cls).float()
        denom = pred_mask.sum() + target_mask.sum()
        if denom == 0:
            dice_scores.append(torch.ones((), device=labels_pred.device))
            continue
        intersection = (pred_mask * target_mask).sum()
        dice_scores.append((2 * intersection) / (denom + 1e-7))
    return torch.stack(dice_scores).mean()

In [None]:
# Grid Search Setup

# Candidate weightings for the four compound-loss terms, in the order the
# compound_loss_variant* functions apply them.
candidate_weights = [
    (0.25, 0.25, 0.25, 0.25),
    (0.20, 0.30, 0.30, 0.20),
    (0.30, 0.20, 0.20, 0.30),
    (0.10, 0.40, 0.40, 0.10),
]

# Grid-search parameters.
NUM_EPOCHS = 10
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Best candidate per variant, selected by validation Dice coefficient.
best_config = {
    variant: {"weights": None, "dice": 0.0}
    for variant in ("Variant1", "Variant2")
}

# Train a fresh model with one compound loss / weighting and score it on the validation set.
_NUM_CLASSES = 4


def _build_segmentation_model():
    """Instantiate the SwinViTSegmentation configuration shared by every grid-search run."""
    return SwinViTSegmentation(
        in_channels=1,
        num_classes=_NUM_CLASSES,
        embed_dim=128,
        patch_size=(16, 16, 16),
        window_size=(2, 2, 2),
        num_layers=2,
        num_heads=4,
        dropout=0.1,
        img_size=(128, 128, 64)
    ).to(DEVICE)


def _train_one_epoch(model, optimizer, compound_loss_fn, weights):
    """Run one optimisation pass over ``train_loader`` and return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    for batch in train_loader:
        images = batch["image"].to(DEVICE)
        labels = batch["label"].to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = compound_loss_fn(outputs, labels, weights)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Guard against an empty loader instead of raising ZeroDivisionError.
    return running_loss / max(len(train_loader), 1)


def _evaluate_on_validation(model):
    """Return batch-averaged IoU, Dice, pixel accuracy, precision and recall over ``val_loader``."""
    model.eval()
    totals = {"iou": 0.0, "dice": 0.0, "pix_acc": 0.0, "prec": 0.0, "rec": 0.0}
    with torch.no_grad():
        for batch in val_loader:
            images = batch["image"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            outputs = model(images)
            totals["iou"] += compute_iou(outputs, labels, num_classes=_NUM_CLASSES).item()
            totals["dice"] += dice_coefficient(outputs, labels).item()
            totals["pix_acc"] += pixel_accuracy(outputs, labels)
            p, r = precision_recall(outputs, labels, num_classes=_NUM_CLASSES)
            totals["prec"] += p
            totals["rec"] += r
    num_batches = max(len(val_loader), 1)
    return {name: value / num_batches for name, value in totals.items()}


def train_and_evaluate(compound_loss_fn, weights, variant_name, seed=42):
    """Train a fresh SwinViTSegmentation with one compound loss and return its validation Dice.

    Uses the notebook globals ``train_loader``, ``val_loader``, ``DEVICE`` and
    ``NUM_EPOCHS``.

    Args:
        compound_loss_fn: callable ``(outputs, labels, weights) -> scalar loss tensor``.
        weights: weight tuple forwarded to ``compound_loss_fn``.
        variant_name: label used in the progress printouts.
        seed: torch RNG seed applied before the model is built. Every candidate
            then starts from identical initial weights and batch order, so the
            grid search compares weightings rather than random initialisations.
            Pass ``None`` to leave the RNG unseeded.

    Returns:
        Mean validation Dice coefficient, used as the selection metric.
    """
    if seed is not None:
        torch.manual_seed(seed)
    model = _build_segmentation_model()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(NUM_EPOCHS):
        avg_loss = _train_one_epoch(model, optimizer, compound_loss_fn, weights)
        print(f"[{variant_name} | Weights {weights}] Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {avg_loss:.4f}")

    metrics = _evaluate_on_validation(model)
    avg_iou = metrics["iou"]
    avg_dice = metrics["dice"]
    avg_pix_acc = metrics["pix_acc"]
    avg_prec = metrics["prec"]
    avg_rec = metrics["rec"]
    print(f"[{variant_name} | Weights {weights}] Validation Metrics -- IoU: {avg_iou:.4f}, Dice: {avg_dice:.4f}, Pixel Acc: {avg_pix_acc:.4f}, Prec: {avg_prec:.4f}, Rec: {avg_rec:.4f}")

    return avg_dice

In [None]:
# Run the weight grid search for each compound-loss variant (Variant 1: CE + NC +
# Hausdorff + Dice; Variant 2: CE + NC + Wasserstein + Dice). For each variant,
# keep the weights with the highest validation Dice in best_config.
grid_search_plan = (
    ("Variant1", compound_loss_variant1,
     "\n=== Grid Search for Compound Loss Variant 1 (CE + NC + Hausdorff + Dice) ===\n",
     "Best configuration for Variant 1 (CE + NC + Hausdorff + Dice):"),
    ("Variant2", compound_loss_variant2,
     "\n=== Grid Search for Compound Loss Variant 2 (CE + NC + Wasserstein + Dice) ===\n",
     "Best configuration for Variant 2 (CE + NC + Wasserstein + Dice):"),
)

for variant_key, loss_fn, banner, _ in grid_search_plan:
    print(banner)
    record = best_config[variant_key]
    for weights in candidate_weights:
        avg_dice = train_and_evaluate(loss_fn, weights, variant_key)
        if avg_dice > record["dice"]:
            record["dice"] = avg_dice
            record["weights"] = weights

print("\n=== Grid Search Results ===")
for variant_key, _, _, summary_label in grid_search_plan:
    print(summary_label)
    print(best_config[variant_key])