In Kaggle, add the following to the dependencies:
```
pip install torch
pip install torchvision
pip install numpy
pip install pydicom
pip install PILlow
pip install matplotlib
```
Enable file persistence and internet access.
Remember that you can run the whole notebook and close the runtime without wasting resources by going to File > Save Version > Save & Run All (Double check that GPU is selected in the advanced settings).
Later, by going to 'File' > 'Version history' you can view the full logs and download the output files.

In [15]:
# Check if running in Kaggle
import os

IN_KAGGLE = False
if os.environ.get("KAGGLE_URL_BASE", ""):
    IN_KAGGLE = True
    !git clone https://github.com/parmigggiana/xai /kaggle/working/xai
    %cd xai
    !git fetch
    !git reset --hard origin/main
    %pip install 'monai[einops,itk]>=1.5.0'

In [16]:
# Check if running in Google Colab
IN_COLAB = False
if not IN_KAGGLE:
    try:
        import google.colab
        from google.colab import drive

        IN_COLAB = True
        import os

        drive.mount("/content/drive")
        os.makedirs("/content/drive/MyDrive/xai", exist_ok=True)
        !git clone https://github.com/parmigggiana/xai /content/xai
        %cd /content/xai
        !git fetch
        !git reset --hard origin/main
        %pip install -r requirements.txt
    except Exception:
        pass

In [17]:
from src.datasets.registry import get_dataset
from src.datasets.common import BaseDataset
from pathlib import Path
import json
from src.task_vector import TaskVector
from src.utils import download_and_extract_dataset

In [18]:
dataset_names = ["CHAOS", "MMWHS"]
domains = ["MR", "CT"]
data_path = "data/"
checkpoint_path = "checkpoints/"
outputs_path = "outputs/"
use_3d = True
training_epochs = {
    ("CHAOS", "MR"): 30,
    ("CHAOS", "CT"): 10,
    ("MMWHS", "MR"): 30,
    ("MMWHS", "CT"): 20,
}
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-5

In [19]:
checkpoint_path = Path(checkpoint_path)
outputs_path = Path(outputs_path)
data_path = Path(data_path)
checkpoint_path.mkdir(parents=True, exist_ok=True)
outputs_path.mkdir(parents=True, exist_ok=True)

In [None]:
import torch
from monai import transforms
from monai.data.image_reader import ITKReader
from pathlib import Path
from src.datasets.volumetricPNGReader import VolumetricPNGReader


def update_metrics(name, new_metrics):
    metrics_file = outputs_path / "metrics.json"

    if not metrics_file.exists():
        metrics = {}
    else:
        with open(metrics_file, "r") as f:
            metrics = json.load(f)

    metrics[name] = new_metrics
    with open(metrics_file, "w") as f:
        json.dump(metrics, f, indent=4)


def get_preprocessing(domain, is_training=True, dataset_name=None):
    """
    Get comprehensive preprocessing pipeline for volumetric medical data.

    Returns separate transforms for images and segmentations to work with ImageDataset.
    Handles different file formats based on dataset:
    - CHAOS: DICOM images (directories), PNG labels (directories)
    - MMWHS: NIfTI images and labels
    """
    if use_3d:

        # Image-specific transforms (applied to image files)
        image_transforms = [
            transforms.EnsureChannelFirst(
                channel_dim="no_channel"
            ),  # Ensure channel-first format
            transforms.Orientation(axcodes="RAS"),  # Standardize spatial orientation
            transforms.Spacing(
                pixdim=(1.5, 1.5, 2.0), mode="bilinear"
            ),  # Consistent voxel spacing
        ]

        # Domain-specific intensity normalization for images
        if domain == "CT":
            # CT-specific intensity normalization
            image_transforms.extend(
                [
                    transforms.ScaleIntensityRange(
                        a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True
                    ),  # CT-specific clipping
                ]
            )
        else:  # MR
            # MR intensity normalization
            image_transforms.extend(
                [
                    transforms.NormalizeIntensity(nonzero=True, channel_wise=True),
                ]
            )

        # Spatial resizing for images
        image_transforms.append(
            transforms.Resize(spatial_size=96, size_mode="longest", mode="area")
        )

        # Training-specific augmentations for images only
        if is_training:
            # Image-only augmentations (safe for ImageDataset)
            augmentation_transforms = [
                transforms.RandGaussianNoise(prob=0.2, std=0.05),
                transforms.RandAdjustContrast(prob=0.2, gamma=(0.9, 1.1)),
            ]
            image_transforms.extend(augmentation_transforms)

        # Final conversion to tensor for images
        image_transforms.extend(
            [
                transforms.ToTensor(),
                transforms.EnsureType(dtype=torch.float32),
            ]
        )

        # Segmentation-specific transforms (applied to segmentation files)
        seg_transforms = [
            transforms.Lambda(lambda x: print(x.shape) or x),
            transforms.EnsureChannelFirst(
                channel_dim="no_channel"
            ),  # Ensure channel-first format
            transforms.Orientation(axcodes="RAS"),  # Standardize spatial orientation
            transforms.Spacing(
                pixdim=(1.5, 1.5, 2.0), mode="nearest"
            ),  # Use nearest for labels
            transforms.Resize(
                spatial_size=96, size_mode="longest", mode="nearest"
            ),  # Resize labels
            transforms.ToTensor(),
            transforms.EnsureType(dtype=torch.float32),
        ]

        # Create separate transform pipelines
        image_transform = transforms.Compose(image_transforms)
        seg_transform = transforms.Compose(seg_transforms)

        return image_transform, seg_transform

    else:
        # 2D preprocessing (if needed)
        return None, None

In [21]:
# Finetuning loop

for (dataset_name, domain), epochs in training_epochs.items():
    download_and_extract_dataset(dataset_name, data_path)

    # Get separate transforms for images and segmentations
    image_transform, seg_transform = get_preprocessing(
        domain, is_training=True, dataset_name=dataset_name
    )

    filename = f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_finetuned.pth"
    filename = checkpoint_path / filename
    # Check if the finetuned checkpoint already exists
    if filename.exists():
        print(
            f"Finetuned model for {dataset_name} in {domain} domain with {'3d' if use_3d else '2d'} images already exists at {filename}. Skipping finetuning."
        )
        continue

    print(
        f"Finetuning on {dataset_name} dataset in {domain} domain with {'3d' if use_3d else '2d'} images "
    )
    dataset: BaseDataset = get_dataset(
        dataset_name=dataset_name,
        domain=domain,
        transform=image_transform,  # Use transform instead of preprocess
        seg_transform=seg_transform,  # Pass seg_transform too
        base_path=data_path,
        batch_size=1,
        num_workers=1,
        slice_2d=not use_3d,
    )

    model = dataset.get_model(
        encoder_type="swin_unetr",
    )

    # Save the baseline model's state_dict before finetuning
    baseline_filename = (
        checkpoint_path
        / f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_baseline.pth"
    )
    torch.save(model.encoder, baseline_filename)
    print(
        f"Processing {dataset_name} in {domain} domain with {'3d' if use_3d else '2d'} images"
    )
    model_metrics = model.evaluate()
    update_metrics(
        f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_baseline",
        model_metrics,
    )

    # Train the segmentation head
    # For the proposal of the paper this part should not be done like this!
    # This requires data - the original task arithmetic paper builds classification heads using no data, only templates
    # if we have time this point should be addressed, otherwise at least mention that the technical problem
    # requires more developement time and the proof of concept should still be somewhat solid.
    model.freeze_body()
    model.finetune(
        epochs=epochs, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )
    # Save the head
    torch.save(
        model.head,
        checkpoint_path
        / f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_head.pth",
    )

    metrics = model.evaluate()
    update_metrics(
        f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_head",
        metrics,
    )

    # Finetune the encoder-decoder
    model.unfreeze()
    model.freeze_head()
    history = model.finetune(
        epochs=epochs,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
    )
    # Save the finetuned model's state_dict

    torch.save(model.encoder, filename)
    model_metrics = model.evaluate()
    update_metrics(
        f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_finetuned",
        model_metrics,
    )
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

Finetuned model for CHAOS in MR domain with 3d images already exists at checkpoints/CHAOS_MR_3d_finetuned.pth. Skipping finetuning.
Finetuned model for CHAOS in CT domain with 3d images already exists at checkpoints/CHAOS_CT_3d_finetuned.pth. Skipping finetuning.
Finetuned model for MMWHS in MR domain with 3d images already exists at checkpoints/MMWHS_MR_3d_finetuned.pth. Skipping finetuning.
Finetuned model for MMWHS in CT domain with 3d images already exists at checkpoints/MMWHS_CT_3d_finetuned.pth. Skipping finetuning.


# Domain adaptation

In [22]:
from monai.networks.nets import SwinUNETR
from monai.networks.nets.swin_unetr import SwinTransformer
from monai.networks.blocks.patchembedding import PatchEmbed
from torch.nn.modules.conv import Conv3d
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.container import ModuleList
from monai.networks.nets.swin_unetr import BasicLayer
from monai.networks.nets.swin_unetr import SwinTransformerBlock
from torch.nn.modules.normalization import LayerNorm
from monai.networks.nets.swin_unetr import WindowAttention
from torch.nn.modules.linear import Linear
from torch.nn.modules.activation import Softmax
from torch.nn.modules.linear import Identity
from monai.networks.blocks.mlp import MLPBlock
from torch.nn.modules.activation import GELU
from monai.networks.nets.swin_unetr import PatchMerging
from monai.networks.blocks.unetr_block import UnetrBasicBlock
from monai.networks.blocks.dynunet_block import UnetResBlock
from monai.networks.blocks.convolutions import Convolution
from torch.nn.modules.activation import LeakyReLU
from torch.nn.modules.instancenorm import InstanceNorm3d
from monai.networks.blocks.unetr_block import UnetrUpBlock
from torch.nn.modules.conv import ConvTranspose3d
from monai.networks.blocks.dynunet_block import UnetOutBlock

# Build Task Vectors for each dataset and domain
task_vectors = {}
for dataset_name in dataset_names:
    for domain in domains:
        print(
            f"Building task vector for {dataset_name} dataset in {domain} domain with {'3d' if use_3d else '2d'} images"
        )
        baseline_checkpoint = (
            checkpoint_path
            / f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_baseline.pth"
        )
        finetuned_checkpoint = (
            checkpoint_path
            / f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_finetuned.pth"
        )
        with torch.serialization.safe_globals(
            [
                SwinUNETR,
                SwinTransformer,
                PatchEmbed,
                Conv3d,
                Dropout,
                ModuleList,
                BasicLayer,
                SwinTransformerBlock,
                LayerNorm,
                WindowAttention,
                Linear,
                Softmax,
                Identity,
                MLPBlock,
                GELU,
                PatchMerging,
                UnetrBasicBlock,
                UnetResBlock,
                Convolution,
                LeakyReLU,
                InstanceNorm3d,
                UnetrUpBlock,
                ConvTranspose3d,
                UnetOutBlock,
            ]
        ):
            task_vector = TaskVector(baseline_checkpoint, finetuned_checkpoint)
            # Remove keys associated with the .out layer from the task vector
            out_layer_keys = [k for k in task_vector.keys() if "out." in k]
            for k in out_layer_keys:
                del task_vector[k]
        task_vectors[f"{dataset_name}_{domain}"] = task_vector

Building task vector for CHAOS dataset in MR domain with 3d images
Building task vector for CHAOS dataset in CT domain with 3d images
Building task vector for MMWHS dataset in MR domain with 3d images
Building task vector for MMWHS dataset in CT domain with 3d images


In [23]:
# Build composite task vectors using arithmetic
composite_task_vectors = {
    "MMWHS_CT": task_vectors["MMWHS_MR"]
    + task_vectors["CHAOS_CT"]
    - task_vectors["CHAOS_MR"],
    "MMWHS_MR": task_vectors["MMWHS_CT"]
    + task_vectors["CHAOS_MR"]
    - task_vectors["CHAOS_CT"],
    "CHAOS_CT": task_vectors["CHAOS_MR"]
    + task_vectors["MMWHS_CT"]
    - task_vectors["MMWHS_MR"],
    "CHAOS_MR": task_vectors["CHAOS_CT"]
    + task_vectors["MMWHS_MR"]
    - task_vectors["MMWHS_CT"],
}

In [24]:
# 🔄 Task Vector Cross-Domain Adaptation Experiments
print("🔄 Task Vector Cross-Domain Adaptation Experiments")
print("=" * 80)

for dataset_name in dataset_names:
    for target_domain in domains:
        print(f"\n{dataset_name}: {target_domain} adaptation")

        # Get separate transforms for target domain, passing dataset_name for correct readers
        image_transform, seg_transform = get_preprocessing(
            target_domain, is_training=False, dataset_name=dataset_name
        )
        # image_transform = None
        # seg_transform = None

        dataset_kwargs = {
            "dataset_name": dataset_name,
            "base_path": data_path,
            "domain": target_domain,
            "transform": image_transform,  # Use transform instead of preprocess
            "seg_transform": seg_transform,  # Pass seg_transform too
            "batch_size": 1,
        }
        extra_kwargs = {}
        if dataset_name == "CHAOS":
            extra_kwargs["liver_only"] = True

        # try:
        target_dataset = get_dataset(**dataset_kwargs, **extra_kwargs)

        composite_key = f"{dataset_name}_{target_domain}"
        if composite_key in composite_task_vectors:
            composite_task_vector = composite_task_vectors[composite_key]

            target_model = target_dataset.get_model()
            target_model.load_task_vector(composite_task_vector)
            # target_model.setup_for_dataset(target_dataset)

            metrics = target_model.evaluate()
            update_metrics(f"{composite_key}_adaptation", metrics)
            print(
                f"   ✅ {composite_key}: Dice={metrics.get('train', {}).get('dice', 0):.3f}"
            )
        else:
            print(f"   ⚠️ No composite task vector found for {composite_key}")

        # except Exception as e:
        #     print(f"   ❌ {dataset_name} {target_domain}: {str(e)[:100]}...")
        #     import traceback
        #     traceback.print_exc()
        #     # continue
        #     break

print("=" * 80)

🔄 Task Vector Cross-Domain Adaptation Experiments

CHAOS: MR adaptation
Loaded ['data/CHAOS/CHAOS_Train_Sets/Train_Sets/MR/1/T2SPIR/DICOM_anon', 'data/CHAOS/CHAOS_Train_Sets/Train_Sets/MR/10/T2SPIR/DICOM_anon', 'data/CHAOS/CHAOS_Train_Sets/Train_Sets/MR/13/T2SPIR/DICOM_anon', 'data/CHAOS/CHAOS_Train_Sets/Train_Sets/MR/15/T2SPIR/DICOM_anon', 'data/CHAOS/CHAOS_Train_Sets/Train_Sets/MR/19/T2SPIR/DICOM_anon', 'data/CHAOS/CHAOS_Train_Sets/Train_Sets/MR/2/T2SPIR/DICOM_anon', 'data/CHAOS/CHAOS_Train_Sets/Train_Sets/MR/20/T2SPIR/DICOM_anon', 'data/CHAOS/CHAOS_Train_Sets/Train_Sets/MR/21/T2SPIR/DICOM_anon', 'data/CHAOS/CHAOS_Train_Sets/Train_Sets/MR/22/T2SPIR/DICOM_anon', 'data/CHAOS/CHAOS_Train_Sets/Train_Sets/MR/3/T2SPIR/DICOM_anon', 'data/CHAOS/CHAOS_Train_Sets/Train_Sets/MR/31/T2SPIR/DICOM_anon', 'data/CHAOS/CHAOS_Train_Sets/Train_Sets/MR/32/T2SPIR/DICOM_anon', 'data/CHAOS/CHAOS_Train_Sets/Train_Sets/MR/33/T2SPIR/DICOM_anon', 'data/CHAOS/CHAOS_Train_Sets/Train_Sets/MR/34/T2SPIR/DICOM_anon',

Evaluating train:   0%|          | 0/20 [00:00<?, ?it/s]

torch.Size([256, 256, 26])





RuntimeError: applying transform <monai.transforms.compose.Compose object at 0x7f224d7726c0>

In [None]:
# Load and display all metrics
metrics_file = outputs_path / "metrics.json"
if metrics_file.exists():
    with open(metrics_file, "r") as f:
        all_metrics = json.load(f)

    print("\n📊 COMPREHENSIVE RESULTS ANALYSIS")
    print("=" * 80)

    # Baseline performance
    print("\n🏁 Baseline Performance:")
    for key, metrics in all_metrics.items():
        if "baseline" in key:
            dice = metrics.get("train").get("dice", 0)
            hausdorff = metrics.get("train").get("hausdorff", 0)
            print(f"   {key}: Dice={dice:.3f}, HD={hausdorff:.3f}")

    # After Head-training performance
    print("\n🏋️‍♂️ After Head-Training Performance:")
    for key, metrics in all_metrics.items():
        if "head" in key:
            dice = metrics.get("train").get("dice", 0)
            hausdorff = metrics.get("train").get("hausdorff", 0)
            print(f"   {key}: Dice={dice:.3f}, HD={hausdorff:.3f}")

    # Finetuned performance
    print("\n🏆 Finetuned Performance:")
    for key, metrics in all_metrics.items():
        if "finetuned" in key:
            dice = metrics.get("train").get("dice", 0)
            hausdorff = metrics.get("train").get("hausdorff", 0)
            print(f"   {key}: Dice={dice:.3f}, HD={hausdorff:.3f}")

    # Cross-domain adaptation results
    print("\n🔄 Cross-Domain Adaptation Results:")
    for key, metrics in all_metrics.items():
        if "adaptation" in key:
            dice = metrics.get("train").get("dice", 0)
            hausdorff = metrics.get("train").get("hausdorff", 0)
            print(f"   {key}: Dice={dice:.3f}, HD={hausdorff:.3f}")
else:
    print("No metrics file found. Run the experiments first.")

In [None]:
if IN_COLAB:
    import shutil

    # Copy checkpoints.zip to Google Drive
    !zip -r /content/checkpoints.zip /content/xai/checkpoints
    shutil.copy(
        "/content/checkpoints.zip", "/content/drive/MyDrive/xai/checkpoints.zip"
    )

    # Copy metrics.json to Google Drive
    shutil.copy(
        "/content/xai/outputs/metrics.json", "/content/drive/MyDrive/xai/metrics.json"
    )

In [None]:
if IN_KAGGLE:
    !zip -r /kaggle/working/checkpoints.zip /kaggle/working/xai/checkpoints