<img src="./IMTA.png" alt="Logo IMT Atlantique" width="300"/>

## **Frugal AI: Data Scarcity on Prostate MRI**
## TAF Health - UE B - 2025/2026 

Pierre-Henri.Conze@imt-atlantique.fr - Vincent.Jaouen@imt-atlantique.fr

In this lab, we will work with the **Prostate158** dataset (Mid-Axial slices). 
We want to understand the impact of **training data size** and **augmentation** on segmentation performance.

**Dataset**:
*   Images: T2-weighted MRI slices of the prostate.
*   Labels: Prostate segmentation masks.

**Objectives:**
1.  **Setup**: Load `prostate158` dataset and define a fixed validation split.
2.  **Part I (Scarcity)**: Train separate U-Nets on very small subsets (e.g., 5, 20 images) without augmentation.
3.  **Part II (Augmentation)**: Repeat the training on the smallest subsets using extensive Data Augmentation.
4.  **Part III (Semi-Supervised Learning)**: Use unlabeled data to improve performance.

**Student Instructions:**
*   This notebook contains several **Questions** and **Exercises** marked with 📝.
*   Please complete the code cells marked with `TODO`.
*   Answer the questions in the markdown cells provided.

In [None]:
# Setup for Google Colab
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

import os
if IN_COLAB:
    print("Running on Google Colab. Setting up environment...")
    
    repo_name = "HealthLabs-IMTA"
    repo_url = "https://github.com/vhxjaouen/HealthLabs-IMTA.git"

    # Go to root if we appear to be in the repo's notebooks directory
    if os.path.basename(os.getcwd()) == "notebooks" and os.path.exists("../setup.py"):
        os.chdir("..")
        print("Moved to repository root.")

    # Check if we need to clone
    if not os.path.exists("setup.py"):
        if not os.path.exists(repo_name):
            print(f"Cloning {repo_name}...")
            !git clone {repo_url}
        
        # Move into the repo
        if os.path.exists(repo_name):
            os.chdir(repo_name)
            print(f"Changed directory to {os.getcwd()}")
    
    # Install package
    if os.path.exists("setup.py"):
        print("Installing package...")
        !pip install .
    else:
        print(f"Error: Could not find setup.py at {os.getcwd()}")
        print(f"Directory contents: {os.listdir(os.getcwd())}")

    # Move to notebooks directory for running the rest of the notebook
    if os.path.exists("notebooks"):
        os.chdir("notebooks")
        print(f"Working directory set to: {os.getcwd()}")
    
    print("Environment setup complete.")
else:
    print("Running locally.")

import torch
if torch.cuda.is_available():
    print(f"GPU is available: {torch.cuda.get_device_name(0)}")
else:
    print("WARNING: No GPU detected. Training will be very slow.")

In [None]:
import sys, os
import torch
import glob
import json
import numpy as np
import matplotlib.pyplot as plt
import yaml
from monai.data import CacheDataset, DataLoader
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityRangePercentilesd,
    RandFlipd, RandRotate90d, RandZoomd, RandShiftIntensityd, RandGaussianNoised,
    EnsureTyped, SpatialPadd, CenterSpatialCropd, Resized
)
from monai.utils import set_determinism

# Fix seed for reproducibility
set_determinism(seed=29200)

## 1. Data Loading and Inspection

We will parse the `dataset.json` provided with the dataset to get image/label pairs.
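As a minimal sketch of the parsing step, the snippet below resolves relative entries against a root folder. The JSON content, the `labelsTr` folder, the case names, and the helper `resolve_pairs` are hypothetical stand-ins for illustration. The real `dataset.json` may contain additional keys.

```python
import json
import os

# Hypothetical dataset.json content; the real file may list more keys and entries.
example_json = """
{
  "training": [
    {"image": "./imagesTr/case_000.nii.gz", "label": "./labelsTr/case_000.nii.gz"},
    {"image": "./imagesTr/case_001.nii.gz", "label": "./labelsTr/case_001.nii.gz"}
  ]
}
"""

def resolve_pairs(schema, root):
    """Turn the relative paths of each training entry into paths under `root`."""
    pairs = []
    for entry in schema["training"]:
        pairs.append({
            # normpath collapses the leading "./" of each relative entry
            "image": os.path.normpath(os.path.join(root, entry["image"])),
            "label": os.path.normpath(os.path.join(root, entry["label"])),
        })
    return pairs

pairs = resolve_pairs(json.loads(example_json), "data_root")
print(pairs[0]["image"])
```

Each resulting dictionary uses the `image` and `label` keys that the dictionary-based MONAI transforms in this notebook expect.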

In [None]:
# Configuration
data_dir = "../datasets/prostate158_MidAxial"
json_path = os.path.join(data_dir, "dataset.json")

# Load dataset.json
with open(json_path) as f:
    schema = json.load(f)

# Extract training paths (relative paths in JSON need to be joined with data_dir)
data_dicts = []
for entry in schema["training"]:
    # json entries: "./imagesTr/xxx.nii.gz"
    img_path = os.path.join(data_dir, entry["image"].replace("./", ""))
    lbl_path = os.path.join(data_dir, entry["label"].replace("./", ""))
    data_dicts.append({"image": img_path, "label": lbl_path})

print(f"Total available images: {len(data_dicts)}")

# Define Fixed Split (Last 30 for Validation)
val_size = 30
val_files = data_dicts[-val_size:]
train_pool = data_dicts[:-val_size]

print(f"Validation set size: {len(val_files)}")
print(f"Training pool size: {len(train_pool)}")

### 📝 Exercise 1: Visualize a Sample Pair

It is always good practice to check your data before training.
Complete the code below to load and visualize the **first image and its corresponding label** from the `train_pool`.
*   Use `nibabel` (imported as `nib`), or apply the MONAI `LoadImaged` transform to the sample dictionary if you prefer.
*   Display them side-by-side using `matplotlib`.

In [None]:
import nibabel as nib

# Select the first sample
sample = train_pool[0]
print(f"Sample: {sample}")

# TODO: Load the image and label using nib.load() or any other method
# Note: nib.load(path).get_fdata() returns the numpy array
# img_data = ...
# label_data = ...

# TODO: Visualize them using plt.imshow()
# plt.figure(figsize=(10, 5))
# plt.subplot(1, 2, 1)
# plt.imshow(..., cmap="gray")
# plt.subplot(1, 2, 2)
# plt.imshow(..., cmap="jet")
# plt.show()

### 📝 Question 1 (Data)
What are the dimensions (shape) of the loaded image array? Is it 2D or 3D?
Why does the `dataset.json` usually point to NIfTI (.nii.gz) files for medical imaging?

**Answer:**

## 2. Transforms Pipeline (MONAI)

We set up the MONAI transforms.
*   **Preprocessing**: Load, Channel First, Resize to 256x256, Normalize Intensity (1st-99th percentile).
*   **Augmentation**: Flip, Zoom, Intensity Shift, Gaussian Noise (activated only if `augment=True`).
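To make the percentile normalization concrete, here is a rough NumPy sketch of the idea. It is not MONAI's implementation. The function name `percentile_rescale` and the synthetic slice are assumptions made for illustration.

```python
import numpy as np

def percentile_rescale(image, lower=1.0, upper=99.0):
    """Map the [lower, upper] percentile range of `image` to [0, 1] and clip."""
    lo, hi = np.percentile(image, [lower, upper])
    if hi <= lo:  # constant image: avoid division by zero
        return np.zeros_like(image, dtype=np.float32)
    scaled = (image - lo) / (hi - lo)
    return np.clip(scaled, 0.0, 1.0).astype(np.float32)

rng = np.random.default_rng(0)
# Synthetic stand-in for an MRI slice with arbitrary intensity units
fake_slice = rng.normal(loc=300.0, scale=80.0, size=(64, 64))
fake_slice[0, 0] = 5000.0  # a single bright outlier

normalized = percentile_rescale(fake_slice)
print(normalized.min(), normalized.max())
```

Because the range is taken from percentiles rather than from the raw minimum and maximum, the single outlier is clipped instead of compressing all other intensities toward zero.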

In [None]:
def get_transforms(augment=False):
    # Base Transforms
    transforms_list = [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"], channel_dim=-1), 
        Resized(keys=["image", "label"], spatial_size=(256, 256), mode=("bilinear", "nearest")),
        ScaleIntensityRangePercentilesd(
            keys="image", lower=1, upper=99, 
            b_min=0.0, b_max=1.0, clip=True
        ),
        EnsureTyped(keys=["image", "label"]),
    ]
    
    # Augmentation
    if augment:
        transforms_list += [
            # Geometric
            RandFlipd(keys=["image", "label"], spatial_axis=1, prob=0.5), 
            RandZoomd(keys=["image", "label"], min_zoom=0.9, max_zoom=1.1, mode=["area", "nearest"], prob=0.3),
            
            # Intensity
            RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
            RandGaussianNoised(keys=["image"], prob=0.1, mean=0.0, std=0.05),
        ]
        
    return Compose(transforms_list)

### 📝 Question 2 (MONAI Transforms)
Look at the `EnsureChannelFirstd` transform.
PyTorch models expect tensors in the format `(Batch, Channel, Height, Width)` for 2D.
Most standard 2D image libraries (like PIL or OpenCV) load images as `(H, W, C)` or `(H, W)`.
Why is `EnsureChannelFirstd` important here, and what happens if an image is loaded as `(H, W)` without a channel dimension?

**Answer:**

### 📝 Exercise 2: Visualizing Augmentations

Run the code below to see the effect of data augmentation on a single sample.
Try running the cell multiple times to see different random transformations.
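The pipeline applies geometric transforms to both `image` and `label`, but intensity transforms to `image` only. The toy NumPy sketch below illustrates why. The helper `augment_pair` and the 4x4 arrays are hypothetical and unrelated to MONAI internals.

```python
import numpy as np

def augment_pair(image, label, rng, flip_prob=0.5, max_shift=0.1):
    """Apply the same random flip to image and label, then shift image intensity only."""
    if rng.random() < flip_prob:
        # Geometric transform: must hit both arrays so they stay aligned.
        image = np.flip(image, axis=1)
        label = np.flip(label, axis=1)
    # Intensity transform: changes the image appearance, never the mask.
    offset = rng.uniform(-max_shift, max_shift)
    return image + offset, label.copy()

rng = np.random.default_rng(42)
image = np.zeros((4, 4), dtype=np.float32)
label = np.zeros((4, 4), dtype=np.uint8)
image[:, 0] = 1.0   # bright structure in the left column
label[:, 0] = 1     # mask covering the same column

aug_image, aug_label = augment_pair(image, label, rng)
# Whatever the random draw, the mask still covers the brightest column.
print(aug_image.argmax(axis=1), aug_label.argmax(axis=1))
```

If the flip were applied to the image alone, the mask would point at the wrong side of the slice and the network would be trained on misaligned pairs.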

In [None]:
# Create a dataset with augmentation enabled
# Note: CacheDataset can cache the *deterministic* part of the pipeline (the 5 transforms before the first random one)
# Here cache_rate=0.0 disables caching, so the whole pipeline is re-applied on every access
aug_ds = CacheDataset(data=[train_pool[0]], transform=get_transforms(augment=True), cache_rate=0.0)
aug_loader = DataLoader(aug_ds, batch_size=1)

# Get a batch
batch = next(iter(aug_loader))
img_aug = batch["image"][0, 0].numpy()
lbl_aug = batch["label"][0, 0].numpy()

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Augmented Image")
plt.imshow(img_aug, cmap="gray")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.title("Augmented Label")
plt.imshow(lbl_aug, cmap="jet", alpha=0.5) # Label overlay
plt.axis("off")
plt.show()

## 3. Experiment Runner

We reuse the configuration from `segmentation.yaml` but override channel settings.

In [None]:
from healthlabs_imta.utils.training import train_segmentation
from healthlabs_imta.utils.model_utils import model_factory
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric

# Load Config from package
import healthlabs_imta
package_dir = os.path.dirname(healthlabs_imta.__file__)
config_path = os.path.join(package_dir, "configs", "segmentation.yaml")

with open(config_path) as f:
    cfg = yaml.safe_load(f)

# Override config for this dataset
cfg["data"]["data_dir"] = data_dir
cfg["model"]["in_channels"] = 1   # Single channel MRI
cfg["model"]["out_channels"] = 1  # Binary segmentation
cfg["training"]["max_epochs"] = 30  # Set global default duration

def run_experiment(n_train_samples, augment, max_epochs=None):
    # Use config value if max_epochs not provided
    if max_epochs is None:
        max_epochs = cfg["training"]["max_epochs"]

    print(f"\n{'='*40}")
    print(f"Running Experiment: N={n_train_samples}, Augmentation={augment}, Epochs={max_epochs}")
    print(f"{'='*40}")
    
    # 1. Deterministic Subset
    train_subset = train_pool[:n_train_samples]
    
    # 2. Dataloaders - Using CacheDataset to speed up training
    train_ds = CacheDataset(
        data=train_subset, 
        transform=get_transforms(augment=augment), 
        cache_rate=1.0, num_workers=2
    )
    val_ds = CacheDataset(
        data=val_files, 
        transform=get_transforms(augment=False), 
        cache_rate=1.0, num_workers=2
    )
    
    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2)
    
    # 3. Model Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model_factory(cfg["model"]).to(device)
    loss_fn = DiceCELoss(sigmoid=True, to_onehot_y=False)
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    
    # 4. Training
    history = train_segmentation(
        model, train_loader, val_loader,
        loss_fn, dice_metric, optimizer,
        device=device, max_epochs=max_epochs,
        overlay_fn=None # Disable plotting during loop
    )
    
    return history, model

### 📝 Question 3 (MONAI Datasets)
In `run_experiment`, we use `CacheDataset`. 
What is the advantage of using `CacheDataset` over a standard `Dataset` when `cache_rate=1.0`?
When might you *not* want to use `CacheDataset` (or use a lower cache rate)?
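As a starting point for thinking about this question, here is a toy sketch of the caching idea in plain Python. `ToyCachedDataset`, `slow_load`, and the sleep-based cost are hypothetical. This sketch is not how MONAI implements `CacheDataset`.

```python
import time

class ToyCachedDataset:
    """Toy stand-in: run a slow deterministic step once, a cheap random step on every access."""

    def __init__(self, items, deterministic_fn, random_fn):
        # Pay the preprocessing cost once, up front, and keep the results in memory.
        self._cache = [deterministic_fn(item) for item in items]
        self._random_fn = random_fn

    def __len__(self):
        return len(self._cache)

    def __getitem__(self, index):
        return self._random_fn(self._cache[index])

calls = {"slow": 0}

def slow_load(item):
    calls["slow"] += 1
    time.sleep(0.01)  # pretend to read and resize a file
    return item * 2

dataset = ToyCachedDataset([1, 2, 3], slow_load, random_fn=lambda x: x + 0)
for _epoch in range(5):
    for i in range(len(dataset)):
        _ = dataset[i]

print(calls["slow"])  # 3: one slow load per item, regardless of the number of epochs
```

Note where the preprocessed items live in this sketch, and consider how that scales with the number and size of the images.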

**Answer:**

## 4. Part I: Scarcity Impact (No Augmentation)

We will test the hypothesis that **more data = better performance**.
We will train the model with very few images: **5 and 20**.
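The experiments are compared using the validation Dice score. As a reminder of what this number measures, here is a minimal NumPy sketch for two binary masks. The function `dice_score` is an illustration only. MONAI's `DiceMetric` handles batching, channels, and edge cases in its own way.

```python
import numpy as np

def dice_score(pred, target, eps=1e-8):
    """Dice = 2 * |intersection| / (|pred| + |target|) for two binary masks."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    intersection = np.logical_and(pred, target).sum()
    return float(2.0 * intersection / (pred.sum() + target.sum() + eps))

target = np.zeros((8, 8), dtype=np.uint8)
target[2:6, 2:6] = 1          # 16 foreground pixels
pred = np.zeros_like(target)
pred[3:7, 2:6] = 1            # same size, shifted down by one row

# Overlap is 12 pixels, so Dice = 2 * 12 / (16 + 16) = 0.75
print(round(dice_score(pred, target), 3))
```

A one-row shift of a 4x4 square already costs a quarter of the score, which shows how sensitive Dice is on small structures.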

In [None]:
results_no_aug = {}
# Reduced sample sizes for the lab to save time
sample_sizes = [20, 5] 

for n in sample_sizes:
    # 30 epochs is enough for a quick demonstration
    (train_losses, val_dices, best_dice, weights), _ = run_experiment(n, augment=False, max_epochs=30)
    results_no_aug[n] = val_dices
    print(f"-> Final Best Dice (N={n}, No Aug): {best_dice:.4f}")

### 📝 Question 4 (Results Analysis)
Compare the performance (Best Dice) between N=5 and N=20. Is the difference significant? 
What is the risk of training a deep network like U-Net on only 5 images?

**Answer:**

## 5. Part II: Impact of Augmentation

Now repeat the experiment for **N=5 and N=20** but with `augment=True`.

In [None]:
results_aug = {}
aug_sample_sizes = [20, 5]

# TODO: Complete the loop to run the experiment with augmentation enabled
for n in aug_sample_sizes:
    # Use run_experiment function
    # (train_losses, val_dices, best_dice, weights), _ = ...
    pass 

    # results_aug[n] = val_dices
    # print(f"-> Final Best Dice (N={n}, Aug): {best_dice:.4f}")

### 📝 Exercise 3: Plotting Results

Use the provided code to plot the learning curves.
Does augmentation help more when N is small (5) or large (20)?

In [None]:
plt.figure(figsize=(12, 6))

colors = {5: 'red', 20: 'blue'}

# Plot No Aug
for n, dices in results_no_aug.items():
    c = colors.get(n, 'gray')
    plt.plot(dices, label=f'N={n} (No Aug)', color=c, linestyle='-', linewidth=2)

# Plot Aug
# TODO: Uncomment and adapt if you filled results_aug
# for n, dices in results_aug.items():
#     c = colors.get(n, 'gray')
#     plt.plot(dices, label=f'N={n} (Aug)', color=c, linestyle='--', linewidth=2)

plt.title("Prostate Segmentation: Impact of Scarcity & Augmentation")
plt.xlabel("Epochs")
plt.ylabel("Validation Dice")
plt.legend()
plt.grid(True, alpha=0.3)
plt.ylim(0, 1.0)
plt.show()

### 📝 Question 5 (Data Augmentation)
In theory, data augmentation acts as a regularizer.
Did you observe reduced overfitting (e.g., gap between training loss and validation metric) or improved generalization?

**Answer:**

## 6. Part III: Semi-Supervised Learning (Student-Teacher)

We simulate a scenario where we have **50 images** available, but only **20 are labeled**.
Can we leverage the 30 "unlabeled" images to improve the performance of a model trained on only 20 labeled examples?

**Strategy (Self-Training / Pseudo-Labeling):**
1.  **Train Teacher**: Use the labeled images to train a model (Teacher).
2.  **Generate Pseudo-Labels**: Use the Teacher to predict segmentation masks for the unlabeled images.
3.  **Train Student**: Train a new model (Student) on the **combined dataset** (Labeled + Pseudo-Labeled).
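The three steps above can be sketched on a toy 1D classification problem, where "training" simply places a threshold between the two class means. Everything in this snippet (the data, `fit_threshold`, `predict`) is a hypothetical illustration of self-training, not the segmentation pipeline used in this lab.

```python
import numpy as np

def fit_threshold(x, y):
    """Toy 'training': threshold halfway between the two class means."""
    return 0.5 * (x[y == 0].mean() + x[y == 1].mean())

def predict(threshold, x):
    return (x > threshold).astype(int)

rng = np.random.default_rng(0)
# Two 1D classes; only a handful of points carry ground-truth labels.
x_labeled = np.array([-1.2, -0.8, 0.9, 1.4])
y_labeled = np.array([0, 0, 1, 1])
x_unlabeled = np.concatenate([rng.normal(-1, 0.3, 30), rng.normal(1, 0.3, 30)])

# 1. Train the teacher on labeled data only.
teacher = fit_threshold(x_labeled, y_labeled)
# 2. The teacher labels the unlabeled pool (these labels may be wrong).
pseudo_labels = predict(teacher, x_unlabeled)
# 3. Train the student on ground truth plus pseudo-labels.
x_all = np.concatenate([x_labeled, x_unlabeled])
y_all = np.concatenate([y_labeled, pseudo_labels])
student = fit_threshold(x_all, y_all)

print(f"teacher threshold: {teacher:.3f}, student threshold: {student:.3f}")
```

The student sees many more samples than the teacher, but every pseudo-label inherits the teacher's errors, which is the central trade-off of this strategy.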

In [None]:
from monai.transforms import Resize
import nibabel as nib
import shutil

n_labeled = 20       # Reduced to 20 for this lab exercise
n_unlabeled = 30     # Use 30 unlabeled images
ssl_epochs = 30      

print(f"Configuration: {n_labeled} Labeled, {n_unlabeled} Unlabeled, {ssl_epochs} Epochs")

### Step 1: Train the Teacher
We train a teacher on the small labeled set (N=20).

In [None]:
print(f"--- Step 1: Training Teacher (N={n_labeled}) ---")
(teacher_losses, teacher_dices, teacher_best_dice, teacher_weights), teacher_model = run_experiment(n_labeled, augment=True, max_epochs=ssl_epochs)
print(f"Teacher Best Validation Dice: {teacher_best_dice:.4f}")

### Step 2: Generate Pseudo-Labels
We use the teacher to label the unlabeled images.
Read the code below to understand how we resize the prediction back to the original image size.
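To illustrate the nearest-neighbor choice when resizing a mask, here is a small NumPy sketch. The function `resize_nearest` and the 270x270 target size are assumptions for illustration. MONAI's `Resize` may use a slightly different sampling convention.

```python
import numpy as np

def resize_nearest(mask, out_shape):
    """Nearest-neighbor resize of a 2D mask; the output contains only input values."""
    in_h, in_w = mask.shape
    out_h, out_w = out_shape
    # For each output pixel center, pick the closest input pixel index.
    rows = np.minimum(((np.arange(out_h) + 0.5) * in_h / out_h).astype(int), in_h - 1)
    cols = np.minimum(((np.arange(out_w) + 0.5) * in_w / out_w).astype(int), in_w - 1)
    return mask[np.ix_(rows, cols)]

# Hypothetical sizes: a 256x256 network output mapped back to a 270x270 slice.
pred = np.zeros((256, 256), dtype=np.float32)
pred[100:150, 90:160] = 1.0
restored = resize_nearest(pred, (270, 270))

print(restored.shape, np.unique(restored))
```

Nearest-neighbor sampling only copies existing values, so the mask stays binary. An interpolating mode such as bilinear would produce fractional values along the boundary and would require a second thresholding step.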

In [None]:
print(f"--- Step 2: Generating Pseudo-Labels for {n_unlabeled} images ---")

# Prepare Teacher for Inference
teacher_model.load_state_dict(teacher_weights)
teacher_model.eval()

# Select Unlabeled Data (indices just after the labeled set)
unlabeled_data = train_pool[n_labeled : n_labeled + n_unlabeled]

# Setup Output Directory
pseudo_label_dir = os.path.join(data_dir, "pseudo_labels")
if os.path.exists(pseudo_label_dir):
    shutil.rmtree(pseudo_label_dir)
os.makedirs(pseudo_label_dir, exist_ok=True)

pseudo_labeled_data = []
infer_transforms = get_transforms(augment=False)

with torch.no_grad():
    for i, item in enumerate(unlabeled_data):
        # 1. Prediction (256x256)
        temp_item = {"image": item["image"], "label": item["image"]} 
        input_data = infer_transforms(temp_item)
        input_tensor = input_data["image"].unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
        
        outputs = teacher_model(input_tensor)
        outputs = torch.sigmoid(outputs)
        start_mask = (outputs > 0.5).float().cpu().numpy()[0, 0] # Binary mask
        
        # 2. Resizing back to original resolution
        # To ensure we don't cheat, we load the original NIfTI to get its shape/affine
        orig_img = nib.load(item["image"])
        # Simplified manual resize for clarity:
        resizer = Resize(spatial_size=orig_img.shape[:2], mode="nearest")
        mask_tensor_out = resizer(torch.from_numpy(start_mask).unsqueeze(0))
        final_mask = mask_tensor_out.squeeze(0).numpy()
        
        # Correct dimensions (Add channel dim if needed)
        if len(orig_img.shape) == 3 and orig_img.shape[2] == 1:
            final_mask = final_mask[:, :, np.newaxis]
            
        # Save output
        pseudo_filename = f"pseudo_{os.path.basename(item['image'])}"
        pseudo_path = os.path.join(pseudo_label_dir, pseudo_filename)
        nib.save(nib.Nifti1Image(final_mask.astype(np.float32), orig_img.affine), pseudo_path)
        pseudo_labeled_data.append({"image": item["image"], "label": pseudo_path})

print(f"Generated {len(pseudo_labeled_data)} pseudo-labels.")

### 📝 Exercise 4: Visualize a Pseudo-Label
Visualize one of the generated pseudo-labels overlaid on its image. 
Does the teacher model make mistakes?

In [None]:
# TODO: Visualize the first element of 'pseudo_labeled_data'
# Load image and pseudo-label using nib.load()
# Display them

### Step 3: Train the Student
Now we train the student on **N=20 (GT) + N=30 (Pseudo)**.

In [None]:
print(f"--- Step 3: Training Student (Total N={len(pseudo_labeled_data) + n_labeled}) ---")

# Combine Data
combined_data = train_pool[:n_labeled] + pseudo_labeled_data

# Student Training Set with Augmentation
student_ds = CacheDataset(data=combined_data, transform=get_transforms(augment=True), cache_rate=1.0, num_workers=2)
# Dataloader for the combined (GT + pseudo-labeled) training set
student_loader = DataLoader(student_ds, batch_size=4, shuffle=True, num_workers=2)

# Validation components reusable
val_ds = CacheDataset(data=val_files, transform=get_transforms(augment=False), cache_rate=1.0, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)

# Train Student
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student_model = model_factory(cfg["model"]).to(device)
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-3)

# loss_fn and dice_metric are local to run_experiment, so recreate them here
loss_fn = DiceCELoss(sigmoid=True, to_onehot_y=False)
dice_metric = DiceMetric(include_background=True, reduction="mean")

(student_losses, student_dices, student_best_dice, student_weights) = train_segmentation(
    student_model, student_loader, val_loader,
    loss_fn, dice_metric, optimizer,
    device=device, max_epochs=ssl_epochs
)

print(f"Student Best Validation Dice: {student_best_dice:.4f}")
print(f"Improvement over Teacher: {student_best_dice - teacher_best_dice:.4f}")

### 📝 Question 6 (SSL Analysis)
Did the Student outperform the Teacher? If not, what could be the reasons?
Consider the quality of the pseudo-labels you visualized in Exercise 4. 
What happens if the Teacher generates bad pseudo-labels?