# Medical Image Preprocessing Pipeline Explanation

This notebook implements comprehensive preprocessing for volumetric medical imaging data. Below are detailed explanations of each preprocessing step and why they're crucial for medical AI applications.

## 📋 Core Preprocessing Steps

### 1. **EnsureChannelFirst()**
- **Purpose**: Standardizes tensor format to (C, H, W, D) where C=channels
- **Why Important**: Medical images can come in different formats. This ensures consistency for neural networks
- **Example**: Converts (H, W, D) → (1, H, W, D) for single-channel data

### 2. **Orientation(axcodes="RAS")**
- **Purpose**: Standardizes anatomical orientation to Right-Anterior-Superior
- **Why Critical**: Medical scans can have different orientations (RAS, LPS, etc.)
- **Impact**: Ensures consistent spatial relationships across all data
- **Alternative**: LPS (Left-Posterior-Superior) is also common
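
To make the RAS/LPS relationship concrete, here is a minimal sketch in plain Python. The helper name `ras_to_lps` is hypothetical and not part of MONAI. The two conventions differ only in the direction of the first two axes, so converting a physical coordinate negates x and y and leaves z unchanged.

```python
def ras_to_lps(point):
    """Convert a physical (x, y, z) coordinate from RAS to LPS.

    RAS: +x = Right, +y = Anterior,  +z = Superior
    LPS: +x = Left,  +y = Posterior, +z = Superior
    """
    x, y, z = point
    return (-x, -y, z)


# A point 10 mm right, 20 mm anterior and 30 mm superior of the origin
p_ras = (10.0, 20.0, 30.0)
p_lps = ras_to_lps(p_ras)
print(p_lps)  # (-10.0, -20.0, 30.0)
```

The same sign flip converts LPS back to RAS, which is why mixing the two conventions silently mirrors a volume instead of raising an error.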

### 3. **Spacing(pixdim=(1.5, 1.5, 2.0))**
- **Purpose**: Resamples voxels to consistent physical dimensions
- **Why Essential**: Different scanners produce different voxel sizes (e.g., 0.5mm vs 2mm)
- **Parameters**: 
  - `pixdim`: Target spacing in mm (x, y, z)
  - `mode`: Interpolation method ("bilinear" for images; "nearest" for label maps, so class IDs are never blended)
- **Impact**: Ensures fair comparison and consistent model input
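
As a rough illustration of what resampling does to the voxel grid, the sketch below estimates the new shape from the old shape and spacing. The helper `resampled_shape` is hypothetical, and the example volume is made up; MONAI computes the output extent from the affine matrix, so its result can differ from this estimate by a voxel.

```python
def resampled_shape(shape, spacing, target_spacing):
    """Estimate the voxel grid size after resampling to target_spacing (mm).

    The physical extent along each axis (voxels * spacing) stays the same,
    so the voxel count scales by spacing / target_spacing.
    """
    return tuple(
        max(1, round(n * s / t))
        for n, s, t in zip(shape, spacing, target_spacing)
    )


# Hypothetical CT volume: 512 x 512 x 200 voxels at 0.5 x 0.5 x 1.0 mm
new_shape = resampled_shape((512, 512, 200), (0.5, 0.5, 1.0), (1.5, 1.5, 2.0))
print(new_shape)  # (171, 171, 100)
```

Coarser target spacing shrinks the grid, which is the main reason `pixdim` has such a strong effect on memory use.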

## 🏥 Domain-Specific Intensity Normalization

### For CT Scans: ScaleIntensityRange()
```python
ScaleIntensityRange(a_min=-200, a_max=300, b_min=0.0, b_max=1.0, clip=True)
```
- **Purpose**: Normalizes Hounsfield Units (HU) to [0,1] range
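- **HU Values** (approximate):
  - Air: -1000 HU
  - Water: 0 HU
  - Bone: about +1000 HU for dense cortical bone (cancellous bone is lower)
  - Soft-tissue window used here: -200 to +300 HU (most soft tissue lies between about -100 and +100 HU)
- **Why Clipping**: Values outside the window (air, bone, metal) saturate at 0 or 1 instead of compressing the soft-tissue contrast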
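
The mapping itself is a linear rescale followed by a clip. The following plain-Python sketch reproduces that arithmetic for single values; the function name is hypothetical, and it is meant to illustrate the formula rather than replace the MONAI transform.

```python
def scale_intensity_range(x, a_min=-200.0, a_max=300.0,
                          b_min=0.0, b_max=1.0, clip=True):
    """Linearly map [a_min, a_max] onto [b_min, b_max], optionally clipping."""
    y = (x - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
    if clip:
        y = min(max(y, b_min), b_max)
    return y


# Air and dense bone both fall outside the window and saturate
for hu in (-1000, -200, 0, 50, 300, 1000):
    print(f"{hu:>6} HU -> {scale_intensity_range(hu):.2f}")
```

With these defaults, water (0 HU) maps to 0.4 and the centre of the window (50 HU) maps to 0.5, while air and dense bone become exactly 0 and 1.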

### For MRI Scans: NormalizeIntensity()
```python
NormalizeIntensity(nonzero=True, channel_wise=True)
```
- **Purpose**: Z-score normalization (mean=0, std=1)
- **Why Different from CT**: MRI intensities are relative, not absolute like CT
- **Parameters**:
  - `nonzero=True`: Only normalizes non-zero voxels (ignores background)
  - `channel_wise=True`: Normalizes each channel independently
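
To show what `nonzero=True` means in practice, here is a small sketch on a flat list of voxel values. The function `zscore_nonzero` is a hypothetical stand-in: it uses the population standard deviation and leaves zeros untouched, which is my understanding of the transform's behaviour rather than a copy of its implementation.

```python
from statistics import fmean, pstdev


def zscore_nonzero(values):
    """Z-score the nonzero entries; keep zeros (background) at zero."""
    foreground = [v for v in values if v != 0]
    if not foreground:
        return [0.0 for _ in values]
    mu = fmean(foreground)
    sigma = pstdev(foreground)
    if sigma == 0:
        sigma = 1.0  # constant foreground: only subtract the mean
    return [(v - mu) / sigma if v != 0 else 0.0 for v in values]


# Three background voxels and three tissue voxels
voxels = [0, 0, 100, 200, 300, 0]
normalized = zscore_nonzero(voxels)
print(normalized)
```

The statistics come only from the three tissue voxels, so a scan with a large empty border is normalized the same way as a tightly cropped one.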

### 4. **Resize(spatial_size=(96, 96, 96))**
- **Purpose**: Standardizes volume dimensions for batch processing
- **Why Necessary**: Medical volumes vary greatly in size
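- **Trade-offs**:
  - Smaller size = faster training, less memory
  - Larger size = more detail and potentially better segmentation accuracy, at a higher memory cost
  - Resizing every volume to a fixed cube changes the effective voxel spacing set by `Spacing`, by a different factor for each volume
- **Mode**: "trilinear" for smooth interpolation of images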

## 🔄 Data Augmentation (Training Only)

### Spatial Augmentations

#### RandRotate90(prob=0.3, spatial_axes=(0, 1))
- **Purpose**: Random 90° rotations in axial plane
- **Why Conservative**: Medical anatomy has consistent orientation
- **Probability**: 30% to avoid over-augmentation

#### RandFlip(prob=0.3, spatial_axis=0)
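- **Purpose**: Random flips along spatial axis 0 (left-right after RAS reorientation)
- **Anatomical Consideration**: Use only flips that keep the anatomy plausible
- **Limitation**: A left-right flip mirrors asymmetric organs (heart, liver), so use it with care when laterality matters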

#### RandAffine()
```python
RandAffine(prob=0.3, rotate_range=0.1, translate_range=5, scale_range=0.1)
```
- **Purpose**: Small geometric transformations
- **Parameters**:
  - `rotate_range=0.1`: ±5.7° rotation
  - `translate_range=5`: ±5 voxel shifts
  - `scale_range=0.1`: ±10% scaling
- **Why Small Values**: Medical images require precise alignment
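
The parameter values are easier to judge after converting them to familiar units. This short sketch does the arithmetic for the rotation and scale ranges; the variable names are illustrative only.

```python
import math

rotate_range = 0.1  # radians, sampled uniformly from [-0.1, 0.1]
scale_range = 0.1   # fractional offset around a scale factor of 1.0

max_rotation_deg = math.degrees(rotate_range)
scale_limits = (1.0 - scale_range, 1.0 + scale_range)

print(f"rotation: ±{max_rotation_deg:.1f}°")
print(f"scale factor: {scale_limits[0]:.1f} to {scale_limits[1]:.1f}")
```

The translation range is given in voxels, so its physical size in millimetres depends on the voxel spacing of the volume at that point in the pipeline.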

### Intensity Augmentations

#### RandGaussianNoise(prob=0.2, std=0.05)
- **Purpose**: Simulates scanner noise variations
- **Why Important**: Different scanners have different noise characteristics
- **Conservative std**: 0.05 is 5% of the [0, 1] range for CT; for z-scored MRI it is 0.05 standard deviations, small enough to avoid corrupting anatomical features

#### RandAdjustContrast(prob=0.2, gamma=(0.9, 1.1))
- **Purpose**: Simulates different contrast settings
- **Gamma Range**: 10% variation is subtle but effective
- **Medical Relevance**: Accounts for different scanning protocols
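
Gamma adjustment rescales intensities to [0, 1], raises them to the power `gamma`, and maps them back to the original range. The sketch below follows that formula on a plain list; `adjust_contrast` here is a hypothetical helper written for illustration, and it omits the small epsilon a library implementation would typically add to the denominator.

```python
def adjust_contrast(values, gamma):
    """Gamma-adjust intensities while preserving the original min and max."""
    lo, hi = min(values), max(values)
    span = hi - lo
    if span == 0:
        return list(values)  # constant input: nothing to adjust
    return [((v - lo) / span) ** gamma * span + lo for v in values]


intensities = [0.0, 0.25, 0.5, 0.75, 1.0]
darker = adjust_contrast(intensities, gamma=1.1)    # mid-tones pushed down
brighter = adjust_contrast(intensities, gamma=0.9)  # mid-tones pushed up
print(darker)
print(brighter)
```

The endpoints stay fixed for any gamma, so only the mid-range intensities move. That is why a narrow range such as (0.9, 1.1) gives a subtle change.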

## 🔧 Final Processing Steps

### ToTensor() & EnsureType(dtype=torch.float32)
- **Purpose**: Converts to PyTorch tensors with consistent data type
- **Why float32**: Balance between precision and memory usage
- **GPU Compatibility**: Ensures tensor can be moved to GPU efficiently

## 🎯 Key Design Principles

### 1. **Domain Awareness**
- Different modalities (CT vs MRI) require different intensity handling
- Anatomical constraints guide augmentation choices

### 2. **Conservative Augmentation**
- Medical images require anatomical validity
- Small, realistic transformations preserve diagnostic quality

### 3. **Consistency First**
- Standardized spacing, orientation, and intensity ranges
- Enables fair comparison across different scanners and protocols

### 4. **Training vs Validation**
- Augmentations only during training
- Deterministic preprocessing for validation/testing

## ⚠️ Important Considerations

### Memory Management
- Volumetric data is memory-intensive
- Consider patch-based processing for very large volumes
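
A quick back-of-the-envelope calculation shows why volume size matters. The helper below is a hypothetical sketch that counts only the raw input tensor; activations and gradients during training need considerably more memory on top of this.

```python
def volume_mib(shape, bytes_per_voxel=4, channels=1):
    """Memory for one dense volume in MiB (float32 = 4 bytes per voxel)."""
    voxels = channels
    for n in shape:
        voxels *= n
    return voxels * bytes_per_voxel / 2**20


print(volume_mib((96, 96, 96)))     # 3.375
print(volume_mib((512, 512, 200)))  # 200.0
```

A single-channel 96³ input is about 60 times smaller than a 512×512×200 scan at the same precision.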

### Quality Control
- Always validate preprocessing results visually
- Check for anatomical correctness after transformations

### Dataset-Specific Tuning
- HU ranges might need adjustment for different CT protocols
- MRI sequences may require specialized normalization

In Kaggle, add the following to the dependencies:
```
pip install torch
pip install torchvision
pip install numpy
pip install pydicom
pip install pillow
pip install matplotlib
```
Enable file persistence and internet access.
To run the whole notebook without keeping the session open, go to File > Save Version > Save & Run All (double-check that GPU is selected in the advanced settings).
Afterwards, go to File > Version history to view the full logs and download the output files.

In [1]:
# Check if running in Kaggle
import os

IN_KAGGLE = False
if os.environ.get("KAGGLE_URL_BASE", ""):
    IN_KAGGLE = True
    !git clone https://github.com/parmigggiana/xai /kaggle/working/xai
    %cd xai
    !git fetch
    !git reset --hard origin/main
    %pip install 'napari[pyqt6,optional]==0.6.2a1' 'monai[einops,nibabel]>=1.1.0'

In [2]:
# Check if running in Google Colab
IN_COLAB = False
if not IN_KAGGLE:
    try:
        import google.colab
        from google.colab import drive
        IN_COLAB = True
        import os
        drive.mount('/content/drive')
        os.makedirs('/content/drive/MyDrive/xai', exist_ok=True)
        !git clone https://github.com/parmigggiana/xai /content/xai
        %cd /content/xai
        !git fetch
        !git reset --hard origin/main
        %pip install -r requirements.txt
    except Exception:
        pass




In [3]:
from src.datasets.registry import get_dataset
from src.datasets.common import BaseDataset
from pathlib import Path
import json
from src.task_vector import TaskVector
from src.utils import download_and_extract_dataset

## 🔍 Practical Example: Before vs After Preprocessing

Here's what happens to your medical image data through each step:

### Original Data Issues:
```
CT Scan A: Shape=(512, 512, 200), Spacing=(0.5, 0.5, 1.0)mm, HU=[-1024, 3071]
CT Scan B: Shape=(256, 256, 150), Spacing=(1.0, 1.0, 2.0)mm, HU=[-500, 2000]
MRI Scan A: Shape=(320, 320, 180), Spacing=(0.8, 0.8, 1.5)mm, Intensity=[0, 4095]
MRI Scan B: Shape=(384, 384, 160), Spacing=(0.6, 0.6, 1.2)mm, Intensity=[0, 2048]
```

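### After Preprocessing Pipeline:
```
CT scans:  Shape=(1, 96, 96, 96), Values=[0.0, 1.0]
MRI scans: Shape=(1, 96, 96, 96), Values z-scored over nonzero voxels (mean≈0, std≈1)
```

Note that `Spacing` first resamples every scan to 1.5×1.5×2.0 mm, but the final `Resize` to 96³ changes the effective voxel spacing again, by a different factor for each volume.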

### Step-by-Step Transformation:

1. **EnsureChannelFirst**: (H,W,D) → (1,H,W,D)
2. **Orientation**: All aligned to RAS coordinate system
3. **Spacing**: Resampled to consistent 1.5×1.5×2.0mm voxels
4. **Intensity**: 
   - CT: HU [-200,300] → [0,1]
   - MRI: Z-score normalized → mean≈0, std≈1
5. **Resize**: All volumes → 96³ voxels
6. **Augmentation**: Random but anatomically valid transformations

This ensures your model sees consistent, comparable data regardless of the original scanner or protocol!

In [4]:
dataset_names = ["CHAOS", "MMWHS"]
domains = ["MR", "CT"]
data_path = "data/"
checkpoint_path = "checkpoints/"
outputs_path = "outputs/"
use_3d = True
training_epochs = {
    ("CHAOS", "MR"): 30,
    ("CHAOS", "CT"): 10,
    ("MMWHS", "MR"): 30,
    ("MMWHS", "CT"): 20,
}
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-5

In [5]:
checkpoint_path = Path(checkpoint_path)
outputs_path = Path(outputs_path)
data_path = Path(data_path)
checkpoint_path.mkdir(parents=True, exist_ok=True)
outputs_path.mkdir(parents=True, exist_ok=True)

In [None]:
import torch
from monai import transforms

def update_metrics(name, new_metrics):
    metrics_file = outputs_path / "metrics.json"

    if not metrics_file.exists():
        metrics = {}
    else:
        with open(metrics_file, "r") as f:
            metrics = json.load(f)

    metrics[name] = new_metrics
    with open(metrics_file, "w") as f:
        json.dump(metrics, f, indent=4)


def get_preprocessing(domain, is_training=True):
    """
    Get comprehensive preprocessing pipeline for volumetric medical data.

    Args:
        domain: 'MR' or 'CT'
        is_training: Whether this is for training (includes augmentations)
    """
    if use_3d:
        # Base preprocessing steps (applied to all data)
        base_transforms = [
            transforms.EnsureChannelFirst(),
            transforms.Orientation(axcodes="RAS"),  # Standardize orientation
            # The array transform takes a single interpolation mode; use "nearest" for label maps
            transforms.Spacing(pixdim=(1.5, 1.5, 2.0), mode="bilinear"),  # Consistent voxel spacing
        ]

        # Domain-specific intensity normalization
        if domain.upper() == "CT":
            # CT: Clip HU values and normalize
            base_transforms.extend([
                transforms.ScaleIntensityRange(
                    a_min=-200, a_max=300, b_min=0.0, b_max=1.0, clip=True
                ),
            ])
        else:  # MR/MRI
            # MR: Z-score normalization (handles varying intensity ranges)
            base_transforms.extend([
                transforms.NormalizeIntensity(nonzero=True, channel_wise=True),
            ])

        # Spatial resizing
        base_transforms.append(
            transforms.Resize(spatial_size=(96, 96, 96), mode="trilinear")
        )

        # Training augmentations
        if is_training:
            augmentation_transforms = [
                transforms.RandRotate90(prob=0.3, spatial_axes=(0, 1)),
                transforms.RandFlip(prob=0.3, spatial_axis=0),
                transforms.RandAffine(
                    prob=0.3,
                    rotate_range=0.1,
                    translate_range=5,
                    scale_range=0.1,
                    mode="bilinear",  # single mode for the array transform; use "nearest" for label maps
                ),
                transforms.RandGaussianNoise(prob=0.2, std=0.05),
                transforms.RandAdjustContrast(prob=0.2, gamma=(0.9, 1.1)),
            ]
            base_transforms.extend(augmentation_transforms)

        # Final conversion to tensor
        base_transforms.extend([
            transforms.ToTensor(),
            transforms.EnsureType(dtype=torch.float32),
        ])

        return transforms.Compose(base_transforms)
    else:
        # 2D preprocessing (if needed)
        return None

In [None]:
# Finetuning loop

for (dataset_name, domain), epochs in training_epochs.items():
    download_and_extract_dataset(dataset_name, data_path)
    preprocess = get_preprocessing(domain, is_training=True)

    filename = f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_finetuned.pth"
    filename = checkpoint_path / filename
    # Check if the finetuned checkpoint already exists
    if filename.exists():
        print(
            f"Finetuned model for {dataset_name} in {domain} domain with {'3d' if use_3d else '2d'} images already exists at {filename}. Skipping finetuning."
        )
        continue

    print(
        f"Finetuning on {dataset_name} dataset in {domain} domain with {'3d' if use_3d else '2d'} images "
    )
    dataset: BaseDataset = get_dataset(
        dataset_name=dataset_name,
        domain=domain,
        preprocess=preprocess,
        base_path=data_path,
        batch_size=1,
        num_workers=1,
        slice_2d=not use_3d,
    )

    model = dataset.get_model(
        encoder_type="swin_unetr",
    )

    # Save the baseline encoder (the full module, not just a state_dict) before finetuning
    baseline_filename = (
        checkpoint_path
        / f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_baseline.pth"
    )
    torch.save(model.encoder, baseline_filename)
    print(
        f"Processing {dataset_name} in {domain} domain with {'3d' if use_3d else '2d'} images"
    )
    model_metrics = model.evaluate()
    update_metrics(
        f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_baseline",
        model_metrics,
    )

    # Train the segmentation head.
    # Note: for the paper's proposal, this step should not be done like this.
    # It requires data, whereas the original task arithmetic paper builds
    # classification heads from templates alone, using no data. If time allows,
    # this point should be addressed; otherwise, at least mention that the
    # technical problem requires more development time and that the proof of
    # concept should still be somewhat solid.
    model.freeze_body()
    model.finetune(epochs=epochs, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    # Save the head
    torch.save(model.head, checkpoint_path / f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_head.pth")

    metrics = model.evaluate()
    update_metrics(
        f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_head",
        metrics,
    )

    # Finetune the encoder-decoder
    model.unfreeze()
    model.freeze_head()
    history = model.finetune(
        epochs=epochs,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
    )
    # Save the finetuned encoder (the full module, not just a state_dict)
    torch.save(model.encoder, filename)
    model_metrics = model.evaluate()
    update_metrics(
        f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_finetuned",
        model_metrics,
    )
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

Finetuned model for CHAOS in MR domain with 3d images already exists at checkpoints/CHAOS_MR_3d_finetuned.pth. Skipping finetuning.
Finetuned model for CHAOS in CT domain with 3d images already exists at checkpoints/CHAOS_CT_3d_finetuned.pth. Skipping finetuning.
Finetuned model for MMWHS in MR domain with 3d images already exists at checkpoints/MMWHS_MR_3d_finetuned.pth. Skipping finetuning.
Finetuned model for MMWHS in CT domain with 3d images already exists at checkpoints/MMWHS_CT_3d_finetuned.pth. Skipping finetuning.


# Domain adaptation

In [8]:
from monai.networks.nets import SwinUNETR
from monai.networks.nets.swin_unetr import SwinTransformer
from monai.networks.blocks.patchembedding import PatchEmbed
from torch.nn.modules.conv import Conv3d
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.container import ModuleList
from monai.networks.nets.swin_unetr import BasicLayer
from monai.networks.nets.swin_unetr import SwinTransformerBlock
from torch.nn.modules.normalization import LayerNorm
from monai.networks.nets.swin_unetr import WindowAttention
from torch.nn.modules.linear import Linear
from torch.nn.modules.activation import Softmax
from torch.nn.modules.linear import Identity
from monai.networks.blocks.mlp import MLPBlock
from torch.nn.modules.activation import GELU
from monai.networks.nets.swin_unetr import PatchMerging
from monai.networks.blocks.unetr_block import UnetrBasicBlock
from monai.networks.blocks.dynunet_block import UnetResBlock
from monai.networks.blocks.convolutions import Convolution
from torch.nn.modules.activation import LeakyReLU
from torch.nn.modules.instancenorm import InstanceNorm3d
from monai.networks.blocks.unetr_block import UnetrUpBlock
from torch.nn.modules.conv import ConvTranspose3d
from monai.networks.blocks.dynunet_block import UnetOutBlock

# Build Task Vectors for each dataset and domain
task_vectors = {}
for dataset_name in dataset_names:
    for domain in domains:
        print(
            f"Building task vector for {dataset_name} dataset in {domain} domain with {'3d' if use_3d else '2d'} images"
        )
        baseline_checkpoint = (
            checkpoint_path
            / f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_baseline.pth"
        )
        finetuned_checkpoint = (
            checkpoint_path
            / f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_finetuned.pth"
        )
        with torch.serialization.safe_globals(
            [
                SwinUNETR,
                SwinTransformer,
                PatchEmbed,
                Conv3d,
                Dropout,
                ModuleList,
                BasicLayer,
                SwinTransformerBlock,
                LayerNorm,
                WindowAttention,
                Linear,
                Softmax,
                Identity,
                MLPBlock,
                GELU,
                PatchMerging,
                UnetrBasicBlock,
                UnetResBlock,
                Convolution,
                LeakyReLU,
                InstanceNorm3d,
                UnetrUpBlock,
                ConvTranspose3d,
                UnetOutBlock,
            ]
        ):
            task_vector = TaskVector(baseline_checkpoint, finetuned_checkpoint)
            # Remove keys associated with the .out layer from the task vector
            out_layer_keys = [k for k in task_vector.keys() if "out." in k]
            for k in out_layer_keys:
                del task_vector[k]
        task_vectors[f"{dataset_name}_{domain}"] = task_vector

Building task vector for CHAOS dataset in MR domain with 3d images
Building task vector for CHAOS dataset in CT domain with 3d images
Building task vector for MMWHS dataset in MR domain with 3d images
Building task vector for MMWHS dataset in CT domain with 3d images


In [9]:
# Build composite task vectors using arithmetic
composite_task_vectors = {
    "MMWHS_CT": task_vectors["MMWHS_MR"] + task_vectors["CHAOS_CT"] - task_vectors["CHAOS_MR"],
    "MMWHS_MR": task_vectors["MMWHS_CT"] + task_vectors["CHAOS_MR"] - task_vectors["CHAOS_CT"],
    "CHAOS_CT": task_vectors["CHAOS_MR"] + task_vectors["MMWHS_CT"] - task_vectors["MMWHS_MR"],
    "CHAOS_MR": task_vectors["CHAOS_CT"] + task_vectors["MMWHS_MR"] - task_vectors["MMWHS_CT"],
}
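
The composite vectors above follow the analogy pattern "same dataset, other domain" plus "other dataset, target domain" minus "other dataset, other domain". The toy sketch below illustrates that arithmetic with one-parameter "models" stored as plain dicts. The helpers `task_vector`, `combine` and `apply_vector` are hypothetical and do not reflect the interface of the `TaskVector` class in `src.task_vector`, and the numbers are made up.

```python
def task_vector(baseline, finetuned):
    """tau = theta_finetuned - theta_baseline, per parameter."""
    return {k: finetuned[k] - baseline[k] for k in baseline}


def combine(a, b, c):
    """Analogy a + b - c, per parameter."""
    return {k: a[k] + b[k] - c[k] for k in a}


def apply_vector(baseline, tau):
    """theta = theta_baseline + tau, per parameter."""
    return {k: baseline[k] + tau[k] for k in baseline}


# Toy checkpoints; real ones hold a tensor per key instead of a float
base = {"w": 1.0}
finetuned = {
    "MMWHS_MR": {"w": 1.5},
    "CHAOS_CT": {"w": 0.5},
    "CHAOS_MR": {"w": 1.25},
}
tau = {name: task_vector(base, ft) for name, ft in finetuned.items()}

# Estimate the MMWHS_CT task vector without using MMWHS CT checkpoints
tau_mmwhs_ct = combine(tau["MMWHS_MR"], tau["CHAOS_CT"], tau["CHAOS_MR"])
adapted = apply_vector(base, tau_mmwhs_ct)
print(tau_mmwhs_ct, adapted)  # {'w': -0.25} {'w': 0.75}
```

Each composite uses only the three task vectors that do not belong to the target dataset and domain pair, so the evaluation that follows measures transfer rather than direct finetuning.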

In [None]:
# Task Vector Cross-Domain Evaluation (merged version, with nested loops for all configs)
print("\n🔄 Task Vector Cross-Domain Adaptation Experiments")
print("=" * 80)
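for dataset_name in dataset_names:
    for target_domain in domains:
        # Build the deterministic (no augmentation) pipeline for the domain being
        # evaluated, rather than reusing a stale `domain` variable from an earlier cell
        preprocess = get_preprocessing(target_domain, is_training=False)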

        # Only add extra_kwarg 'liver_only' for CHAOS dataset
        extra_kwargs = {}
        if dataset_name == "CHAOS":
            extra_kwargs["liver_only"] = True

        try:
            composite_key = f"{dataset_name}_{target_domain}"
            if composite_key not in composite_task_vectors:
                # If the composite task vector does not exist, skip this iteration
                print(
                    f"   ❗ Warning: Composite task vector for {composite_key} does not exist. Skipping evaluation."
                )
                continue

            print(
                f"\n{dataset_name}: {target_domain} adaptation"
            )
            # Load target domain dataset
            dataset_kwargs = dict(
                dataset_name=dataset_name,
                domain=target_domain,
                base_path=data_path,
                preprocess=preprocess,
                batch_size=1,
                num_workers=1,
                slice_2d=not use_3d,
            )
            dataset_kwargs.update(extra_kwargs)
            target_dataset = get_dataset(**dataset_kwargs)
            target_model = target_dataset.get_model(encoder_type="swin_unetr")

            # Apply composite task vector for target domain
            composite_task_vector = composite_task_vectors[composite_key]
            target_model.load_task_vector(composite_task_vector)

            # Overwrite the model's head with the saved one from the same task
            head_filename = checkpoint_path / f"{dataset_name}_{target_domain}_{'3d' if use_3d else '2d'}_head.pth"
            if head_filename.exists():
                with torch.serialization.safe_globals([UnetOutBlock, Convolution, Conv3d]):
                    target_model.head.load_state_dict(torch.load(head_filename, map_location="cuda" if torch.cuda.is_available() else "cpu").state_dict())
            else:
                print(
                    f"   ❗ Warning: Head file {head_filename} does not exist. Skipping evaluation."
                )
                continue

            # Evaluate cross-domain performance
            metrics = target_model.evaluate()
            update_metrics(f"{composite_key}_adaptation", metrics)

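            # Metrics are nested per split (see the results cell), so read the train split
            train_metrics = metrics.get("train") or {}
            print(
                f"   ✅ {dataset_name} {target_domain}: Dice={train_metrics.get('dice', 0):.3f}, Hausdorff={train_metrics.get('hausdorff', 0):.3f}"
            )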
        except Exception as e:
            print(
                f"   ❌ {dataset_name} {target_domain} error: {e}"
            )


🔄 Task Vector Cross-Domain Adaptation Experiments

CHAOS: MR adaptation
2025-07-27 17:26:43,305 - INFO - Expected md5 is None, skip md5 check for file data/ssl_pretrained_weights.pth.
2025-07-27 17:26:43,306 - INFO - File exists: data/ssl_pretrained_weights.pth, skipped downloading.
Total updated layers 159 / 159
Pretrained Weights Succesfully Loaded !
⚠️ Task vector missing key: out.conv.conv.weight, skipping update.
⚠️ Task vector missing key: out.conv.conv.bias, skipping update.
🔍 Evaluating train split...


Evaluating train:   0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
# Load and display all metrics
metrics_file = outputs_path / "metrics.json"
if metrics_file.exists():
    with open(metrics_file, "r") as f:
        all_metrics = json.load(f)

    print("\n📊 COMPREHENSIVE RESULTS ANALYSIS")
    print("=" * 80)

    # Each section prints the train-split Dice and Hausdorff distance for
    # every metrics entry whose key contains the given tag.
    sections = [
        ("\n🏁 Baseline Performance:", "baseline"),
        ("\n🏋️‍♂️ After Head-Training Performance:", "head"),
        ("\n🏆 Finetuned Performance:", "finetuned"),
        ("\n🔄 Cross-Domain Adaptation Results:", "adaptation"),
    ]
    for title, tag in sections:
        print(title)
        for key, metrics in all_metrics.items():
            if tag in key:
                # Fall back to an empty dict so a missing "train" split prints zeros
                train_metrics = metrics.get("train") or {}
                dice = train_metrics.get("dice", 0)
                hausdorff = train_metrics.get("hausdorff", 0)
                print(f"   {key}: Dice={dice:.3f}, HD={hausdorff:.3f}")
else:
    print("No metrics file found. Run the experiments first.")

In [None]:
if IN_COLAB:
    import shutil


    # Copy checkpoints.zip to Google Drive
    !zip -r /content/checkpoints.zip /content/xai/checkpoints
    shutil.copy('/content/checkpoints.zip', '/content/drive/MyDrive/xai/checkpoints.zip')

    # Copy metrics.json to Google Drive
    shutil.copy('/content/xai/outputs/metrics.json', '/content/drive/MyDrive/xai/metrics.json')

In [None]:
if IN_KAGGLE:
    !zip -r /kaggle/working/checkpoints.zip /kaggle/working/xai/checkpoints