In [None]:
!pip install numpy>=2.0.0
!pip install monai

In [None]:
import os
import numpy as np
import torch
from collections import Counter
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Dataset, DataLoader
from monai.transforms import (
    Compose,
    RandFlipd,
    RandRotate90d,
    RandRotateD,
    RandZoomd,
    RandScaleIntensityd,
    RandGaussianNoised,
    RandAdjustContrastd,
    RandGaussianSmoothd,
    RandShiftIntensityd,
    ToTensord,
)

# Root folder that is walked recursively to collect the .npz samples.
data_dir = "/content/drive/MyDrive/LiQA_training_data/npz_sliced"
# Mini-batch size shared by the training and validation DataLoaders.
batch_size = 8
# Number of worker processes per DataLoader.
num_workers = 2
# Number of folds produced by StratifiedKFold.
n_splits = 4
# random_state of the shuffled StratifiedKFold split.
seed = 42

# Helper Functions
def get_stage_from_path(path: str) -> str:
    """
    Return the stage directory name (S1–S4) contained in a file path.

    The path is split on ``os.sep``; the first component that is exactly
    'S1', 'S2', 'S3' or 'S4' is returned.

    Raises:
        ValueError: if no path component is one of the four stage names.
    """
    known_stages = ("S1", "S2", "S3", "S4")
    components = path.split(os.sep)
    # next() keeps "first match wins", scanning from the start of the path.
    found = next((c for c in components if c in known_stages), None)
    if found is None:
        raise ValueError(f"Stage not found in path: {path}")
    return found


def stage_to_label(stage: str) -> int:
    """
    Map stage to a binary label.
    - S1/S2/S3 -> 0
    - S4       -> 1

    Raises:
        ValueError: if ``stage`` is not one of 'S1', 'S2', 'S3', 'S4'.
    """
    if stage in ("S1", "S2", "S3"):
        return 0
    if stage == "S4":
        return 1
    # Fail loudly instead of silently labelling an unknown stage as positive.
    raise ValueError(f"Unknown stage: {stage!r}")

# Build File Index (path / stage / label)
# Each entry is {"path": <npz path>, "stage": "S1".."S4", "label": 0 or 1}.
file_info = []
for root, dirs, files in os.walk(data_dir):
    # os.walk lists entries in arbitrary (filesystem-dependent) order. Sorting
    # in place fixes the traversal order, so file_info -- and therefore the
    # seeded fold split built from it -- is identical on every run.
    dirs.sort()
    for fname in sorted(files):
        if not fname.endswith(".npz"):
            continue
        path = os.path.join(root, fname)
        stage = get_stage_from_path(path)
        label = stage_to_label(stage)
        file_info.append({"path": path, "stage": stage, "label": label})

print(f"Found {len(file_info)} npz files in total.")
print("Stage distribution:", Counter([f["stage"] for f in file_info]))
print("Class distribution:", Counter([f["label"] for f in file_info]))

# Stage-Aware Augmentation
# Every dictionary transform below operates on the "image" entry only.
_IMAGE_KEYS = ["image"]


def _moderate_pipeline():
    """Build a fresh medium-strength pipeline (same recipe for S2 and S4)."""
    return Compose(
        [
            RandFlipd(keys=_IMAGE_KEYS, prob=0.5, spatial_axis=(0, 1)),
            RandRotate90d(keys=_IMAGE_KEYS, prob=0.4, max_k=3),
            RandRotateD(keys=_IMAGE_KEYS, range_x=0.1745, prob=0.4),
            RandZoomd(keys=_IMAGE_KEYS, min_zoom=0.95, max_zoom=1.05, prob=0.4),
            RandScaleIntensityd(keys=_IMAGE_KEYS, factors=0.1, prob=0.3),
            ToTensord(keys=_IMAGE_KEYS),
        ]
    )


# S1: most samples -> mild augmentation
_mild_pipeline = Compose(
    [
        RandFlipd(keys=_IMAGE_KEYS, prob=0.3, spatial_axis=0),
        RandRotate90d(keys=_IMAGE_KEYS, prob=0.2, max_k=3),
        ToTensord(keys=_IMAGE_KEYS),
    ]
)

# S3: least samples -> strong augmentation (key design)
_strong_pipeline = Compose(
    [
        RandFlipd(keys=_IMAGE_KEYS, prob=0.8, spatial_axis=(0, 1, 2)),
        RandRotate90d(keys=_IMAGE_KEYS, prob=0.7, max_k=3),
        RandRotateD(keys=_IMAGE_KEYS, range_x=0.5236, prob=0.6),
        RandZoomd(keys=_IMAGE_KEYS, min_zoom=0.8, max_zoom=1.2, prob=0.6),
        RandScaleIntensityd(keys=_IMAGE_KEYS, factors=0.2, prob=0.5),
        RandGaussianNoised(keys=_IMAGE_KEYS, prob=0.4, std=0.02),
        RandAdjustContrastd(keys=_IMAGE_KEYS, prob=0.4, gamma=(0.7, 1.3)),
        RandGaussianSmoothd(keys=_IMAGE_KEYS, prob=0.3, sigma_x=(0.5, 1.0)),
        RandShiftIntensityd(keys=_IMAGE_KEYS, prob=0.3, offsets=0.1),
        ToTensord(keys=_IMAGE_KEYS),
    ]
)

# Lookup used at training time: stage name -> augmentation pipeline.
# S2 (fewer samples) and S4 (moderate amount) each get their own instance
# of the medium-strength pipeline.
stage_transforms = dict(
    S1=_mild_pipeline,
    S2=_moderate_pipeline(),
    S3=_strong_pipeline,
    S4=_moderate_pipeline(),
)

# Validation: no random augmentation, tensor conversion only
val_transform = Compose([ToTensord(keys=_IMAGE_KEYS)])

# Dataset Definition
class MRIDataset(Dataset):
    """
    Load one .npz sample and apply:
    - stage-aware random augmentation for training
    - minimal transform for validation

    Args:
        infos: sequence of dicts with keys "path", "stage" and "label".
        train: if True, use ``stage_transforms[stage]``; otherwise
            ``val_transform``.

    Each item is a dict with:
        "image":         the transformed ``image`` array of the npz file,
        "modality_mask": the ``mask`` array of the npz file as an int8 tensor,
        "label":         the binary label as a ``torch.long`` scalar tensor.
    """

    def __init__(self, infos, train: bool = True):
        self.infos = infos
        self.train = train

    def __len__(self):
        return len(self.infos)

    def __getitem__(self, idx: int):
        info = self.infos[idx]

        # Load data from npz. np.load returns a lazy NpzFile that keeps the
        # archive open; the context manager closes it once both arrays are
        # copied out, so long-running workers do not leak file handles.
        with np.load(info["path"]) as data:
            img = data["image"].astype(np.float32)
            mask = data["mask"].astype(np.int8)

        stage = info["stage"]
        label = info["label"]

        # Apply transform (Monai dict-style transforms)
        # NOTE(review): only the image goes through the transforms. If "mask"
        # is spatially aligned with the image, the random flips/rotations/zooms
        # would desynchronize the two -- confirm the mask is not spatial.
        sample = {"image": img}
        if self.train:
            sample = stage_transforms[stage](sample)
        else:
            sample = val_transform(sample)

        # Return tensors used by the model
        return {
            "image": sample["image"],
            "modality_mask": torch.from_numpy(mask),
            "label": torch.tensor(label, dtype=torch.long),
        }

# Stratified K-Fold Split + DataLoaders (Stratify by stage to keep S1–S4 distribution similar across folds)
# NOTE(review): the split is per file. If several .npz files come from the same
# subject, they can land in both train and val of a fold -- confirm, and
# consider a group-aware split if so.
stage_labels = [int(info["stage"][1]) for info in file_info]

skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)


def _reseed_stage_transforms(worker_id: int) -> None:
    """
    Give the augmentation pipelines a distinct random state in each worker.

    The MONAI transforms keep their own random state. With a plain torch
    DataLoader every worker process starts from the same copy of that state,
    so all workers (and every epoch) would replay the same augmentation draws.
    torch assigns each worker a different seed each time the loader is
    iterated; it is used here to reseed every pipeline.
    """
    worker_seed = torch.initial_seed() % (2**32)
    for offset, pipeline in enumerate(stage_transforms.values()):
        pipeline.set_random_state(seed=(worker_seed + offset) % (2**32))


# One (train_loader, val_loader) pair per fold, in fold order.
dataloaders = []

for fold, (tr_idx, vl_idx) in enumerate(skf.split(file_info, stage_labels), 1):
    tr_infos = [file_info[i] for i in tr_idx]
    vl_infos = [file_info[i] for i in vl_idx]

    tr_ds = MRIDataset(tr_infos, train=True)
    vl_ds = MRIDataset(vl_infos, train=False)

    tr_loader = DataLoader(
        tr_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        worker_init_fn=_reseed_stage_transforms,
    )
    # Validation uses no random transforms, so no reseeding is needed.
    vl_loader = DataLoader(
        vl_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    dataloaders.append((tr_loader, vl_loader))

    print(
        f"Fold {fold}: train={len(tr_ds)}, val={len(vl_ds)}",
        "| Stage_train=", Counter([i["stage"] for i in tr_infos]),
        "Stage_val=", Counter([i["stage"] for i in vl_infos]),)

print(f"\nPipeline ready. Total folds: {len(dataloaders)}")
