In Kaggle, add the following to the dependencies:
```
pip install torch
pip install numpy
pip install pydicom
pip install Pillow
pip install matplotlib
```
Enable file persistence and internet access.
Remember that you can run the whole notebook and close the runtime without wasting resources: go to File > Save Version > Save & Run All (double-check that GPU is selected in the advanced settings).
Later, by going to 'File' > 'Version history' you can view the full logs and download the output files.

In [None]:
# Check if running in Kaggle
import os

IN_KAGGLE = False
if os.environ.get("KAGGLE_URL_BASE", ""):
    IN_KAGGLE = True
    !git clone https://github.com/parmigggiana/xai /kaggle/working/xai
    %cd xai
    !git fetch
    !git reset --hard origin/main
    %pip install 'monai[einops,itk,nibabel]>=1.5.0' git+https://github.com/timojl/clipseg.git

In [None]:
# Check if running in Google Colab
IN_COLAB = False
if not IN_KAGGLE:
    try:
        import google.colab
        from google.colab import drive

        IN_COLAB = True
        import os

        drive.mount("/content/drive")
        os.makedirs("/content/drive/MyDrive/xai", exist_ok=True)
        !git clone https://github.com/parmigggiana/xai /content/xai
        %cd /content/xai
        !git fetch
        !git reset --hard origin/main
        %pip install -r requirements.txt
    except Exception:
        pass

In [None]:
from src.datasets.registry import get_dataset
from src.datasets.common import BaseDataset
from pathlib import Path
import json
from src.task_vector import TaskVector
from src.utils import download_and_extract_dataset

In [None]:
# Datasets and imaging domains covered by the experiments.
DATASET_NAMES = ["MMWHS", "CHAOS"]
DOMAINS = ["MR", "CT"]
# Relative locations for the datasets, the saved checkpoints and metrics.json.
# They are converted to Path objects in the next cell.
DATA_PATH = "data/"
CHECKPOINT_PATH = "checkpoints/"
OUTPUTS_PATH = "outputs/"
# Selects the encoder ("swin_unetr" when True, "clipseg" otherwise) and the
# 3D vs. 2D branch of the preprocessing pipelines.
USE_3D = False
# Number of epochs passed to model.finetune() for each (dataset, domain) pair.
# The finetuning loop iterates over exactly these keys.
TRAINING_EPOCHS = {
    ("CHAOS", "MR"): 30,
    ("CHAOS", "CT"): 10,
    ("MMWHS", "MR"): 30,
    ("MMWHS", "CT"): 20,
}
# spatial_size given to the Resize transform of both images and labels.
SPATIAL_SIZE = 64
# Forwarded to model.finetune() as learning_rate and weight_decay.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-5

In [None]:
# Turn the configured string locations into Path objects.
DATA_PATH = Path(DATA_PATH)
OUTPUTS_PATH = Path(OUTPUTS_PATH)
CHECKPOINT_PATH = Path(CHECKPOINT_PATH)

# Checkpoints and outputs are written by this notebook, so make sure their
# directories exist (DATA_PATH is not created here).
CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)
OUTPUTS_PATH.mkdir(parents=True, exist_ok=True)

# 3D runs use the Swin UNETR encoder; 2D runs use CLIPSeg.
encoder_type = "swin_unetr" if USE_3D else "clipseg"

In [None]:
from typing import Callable
import torch
from monai import transforms

from src.semantic_segmentation import MedicalSegmenter


def update_metrics(name, new_metrics):
    """Store ``new_metrics`` under ``name`` in ``OUTPUTS_PATH / "metrics.json"``.

    The existing file is loaded when present (an empty mapping is used
    otherwise), the entry for ``name`` is added or replaced, and the whole
    mapping is written back as JSON with 4-space indentation.

    Args:
        name: Key of the entry to add or overwrite.
        new_metrics: JSON-serializable value stored under ``name``.
    """
    target = OUTPUTS_PATH / "metrics.json"

    stored = {}
    if target.exists():
        with open(target, "r") as handle:
            stored = json.load(handle)

    stored[name] = new_metrics
    with open(target, "w") as handle:
        json.dump(stored, handle, indent=4)


def debug_metadata(data):
    """Print diagnostic details about ``data`` and return it unchanged.

    The Python type is always printed. The ``meta`` content (keys and full
    value), ``shape`` and ``dtype`` are printed only when ``data`` has the
    corresponding attribute. Since the input is returned as-is, the function
    can be used as a pass-through step in a transform pipeline.
    """
    print(f"🔍 DEBUG - Data type: {type(data)}")

    if hasattr(data, "meta"):
        # An empty/falsy meta is reported as 'No meta' instead of an empty list.
        key_summary = list(data.meta.keys()) if data.meta else "No meta"
        print(f"🔍 DEBUG - Metadata keys: {key_summary}")
        print(f"🔍 DEBUG - Full metadata: {data.meta}")

    for attribute, label in (("shape", "Shape"), ("dtype", "Dtype")):
        if hasattr(data, attribute):
            print(f"🔍 DEBUG - {label}: {getattr(data, attribute)}")

    print("🔍 DEBUG - " + "=" * 50)
    return data


class MetadataAwareTransform:
    """Adapter that lets a plain transform operate on ``(data, metadata)`` pairs.

    For a 2-tuple input only the first element is transformed and the second
    is passed through untouched. Any other input is handed to the wrapped
    transform directly.
    """

    def __init__(self, transform):
        # The wrapped callable; also accessed directly by MetadataCompose.
        self.transform = transform

    def __call__(self, data_tuple):
        is_pair = isinstance(data_tuple, tuple) and len(data_tuple) == 2
        if not is_pair:
            # Not a (data, metadata) pair: behave like the wrapped transform.
            # Errors are NOT caught on this path.
            return self.transform(data_tuple)

        data, metadata = data_tuple
        try:
            result = self.transform(data)
        except Exception as e:
            # Best effort: report the failure and hand back the untransformed
            # data so the rest of the pipeline can continue.
            print(f"Error in transform {self.transform}: {e}")
            return data, metadata
        return result, metadata


class MetadataCompose:
    """Sequential pipeline that understands ``(data, metadata)`` pairs.

    The pipeline can be called either with a single argument (a pair or a
    plain object) or with two positional arguments, which are treated as the
    pair ``(data, metadata)``.
    """

    def __init__(self, transforms_list):
        self.transforms_list = transforms_list

    def __call__(self, *args, **kwargs):
        """Run every transform in order and return the result.

        Returns ``(data, metadata)`` when the input was a pair, otherwise the
        transformed object. Raises ``ValueError`` for any other call shape.
        """
        if len(args) == 2 and not kwargs:
            # The caller unpacked the pair into two positional arguments.
            payload = args
        elif len(args) == 1:
            payload = args[0]
        else:
            raise ValueError(f"Unexpected arguments: args={args}, kwargs={kwargs}")

        if not (isinstance(payload, tuple) and len(payload) == 2):
            # Plain input: bypass the wrappers and call the underlying
            # transforms directly.
            current = payload
            for step in self.transforms_list:
                inner = (
                    step.transform
                    if isinstance(step, MetadataAwareTransform)
                    else step
                )
                current = inner(current)
            return current

        data, metadata = payload
        for step in self.transforms_list:
            if isinstance(step, MetadataAwareTransform):
                # Wrapped steps receive and return the full pair.
                data, metadata = step((data, metadata))
            else:
                # Unwrapped steps only ever see the data element.
                data = step(data)
        return data, metadata

    def set_random_state(self, seed=None):
        """Forward ``seed`` to every step (or wrapped step) that accepts one."""
        for step in self.transforms_list:
            if hasattr(step, "set_random_state"):
                step.set_random_state(seed=seed)
                continue
            # Otherwise look one level down, at the callable held by a wrapper.
            if hasattr(step, "transform") and hasattr(
                step.transform, "set_random_state"
            ):
                step.transform.set_random_state(seed=seed)


def get_preprocessing(dataset_name: str, domain: str, is_training=True):
    """
    Build the image and segmentation preprocessing pipelines for one dataset/domain.

    The pipelines depend on the global ``USE_3D`` flag (3D vs. 2D layout and
    orientation codes), on ``domain`` (fixed-window intensity scaling for
    ``"CT"``, nonzero channel-wise normalization for anything else) and on
    ``is_training`` (random noise/contrast augmentations are added to the
    image pipeline only). Both pipelines end with a ``Resize`` to
    ``SPATIAL_SIZE`` (longest side), ``ToTensor`` and a cast to float32; the
    segmentation pipeline additionally applies the label decoder returned by
    ``get_decode_func`` before resizing.

    Dataset file formats as noted by the original author (not visible in this
    function):
    - CHAOS: DICOM images (directories), PNG labels (directories)
    - MMWHS: NIfTI images and labels

    Args:
        dataset_name: Dataset identifier, forwarded to ``get_decode_func``.
        domain: ``"CT"`` selects ``ScaleIntensityRange``; every other value
            takes the MR branch.
        is_training: When true, append the random image augmentations.

    Returns:
        ``(image_transform, seg_transform)``: two ``MetadataCompose`` pipelines
        in which every step is wrapped in ``MetadataAwareTransform``.
    """
    # Image-specific transforms (applied to image files)

    decode_func = get_decode_func(dataset_name, domain)

    if USE_3D:
        image_transforms = [
            transforms.EnsureChannelFirst(channel_dim="no_channel"),
            transforms.Orientation(axcodes="RAS"),
        ]
    else:
        image_transforms = [
            # Debug aid: prints the shape of every image passing through.
            # print() returns None, so `or x` hands the input back unchanged.
            transforms.Lambda(lambda x: print(f"Image: {x.shape}") or x),
            # Drop the trailing axis (assumes it has size 1 for 2D slices — TODO confirm).
            transforms.Lambda(lambda x: x.squeeze(-1)),
            transforms.EnsureChannelFirst(channel_dim="no_channel"),
            transforms.Orientation(axcodes="RA"),
        ]

    # Domain-specific intensity normalization for images
    if domain == "CT":
        # Clip to the [-57, 164] window and rescale it to [0, 1].
        image_transforms.append(
            transforms.ScaleIntensityRange(
                a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True
            ),
        )
    else:  # MR
        image_transforms.append(
            transforms.NormalizeIntensity(nonzero=True, channel_wise=True),
        )

    # Training-specific augmentations for images only
    if is_training:
        # Intensity-only augmentations: they never move pixels, so the
        # segmentation pipeline does not need a matching step.
        augmentation_transforms = [
            transforms.RandGaussianNoise(prob=0.2, std=0.05),
            transforms.RandAdjustContrast(prob=0.2, gamma=(0.9, 1.1)),
        ]
        image_transforms.extend(augmentation_transforms)

    if not USE_3D:
        # 2D mode: replicate the single channel three times (presumably because
        # the 2D encoder expects 3-channel input — TODO confirm).
        image_transforms.append(transforms.RepeatChannel(repeats=3))

    # Resize to SPATIAL_SIZE, then convert to a float32 tensor
    image_transforms.extend(
        [
            transforms.Resize(spatial_size=SPATIAL_SIZE, size_mode="longest", mode="area", anti_aliasing=True),
            transforms.ToTensor(),
            transforms.EnsureType(dtype=torch.float32),
        ]
    )

    # Segmentation-specific transforms (applied to segmentation files)
    # The label decoding step is added below, right before the resize.

    if not USE_3D:
        seg_transforms = [
            transforms.Lambda(lambda x: x.squeeze(-1)),
            transforms.EnsureChannelFirst(
                channel_dim="no_channel"
            ),  # Ensure channel-first format
            transforms.Orientation(axcodes="RA"),
        ]
    else:
        seg_transforms = [
            # Debug aid: prints the label shape (the message has no text before
            # the colon) and returns the input unchanged.
            transforms.Lambda(lambda x: print(f": {x.shape}") or x),
            transforms.EnsureChannelFirst(channel_dim="no_channel"),
            transforms.Orientation(axcodes="RAS"),
        ]

    seg_transforms.extend(
        [
            # Map raw label values to class indices before resizing.
            transforms.Lambda(lambda x: decode_func(x)),
            # Nearest-neighbour resize so no new label values are interpolated.
            transforms.Resize(spatial_size=SPATIAL_SIZE, size_mode="longest", mode="nearest"),
            transforms.ToTensor(),
            transforms.EnsureType(dtype=torch.float32),
        ]
    )

    # NOTE: Claude's suggestion, probably there's better approaches here than creating custom classes but I need to read the MONAI docs better
    # Wrap transforms to handle metadata properly
    metadata_aware_image_transforms = [
        MetadataAwareTransform(t) for t in image_transforms
    ]
    metadata_aware_seg_transforms = [MetadataAwareTransform(t) for t in seg_transforms]

    # Create separate transform pipelines that handle metadata
    image_transform = MetadataCompose(metadata_aware_image_transforms)
    seg_transform = MetadataCompose(metadata_aware_seg_transforms)

    return image_transform, seg_transform

def get_decode_func(dataset_name, domain):
    """Return a callable that converts raw label values into class indices.

    Supported combinations:
    - ``"CHAOS"`` with ``"MR"``/``"MRI"``: floor-divide the label values by 63.
    - ``"CHAOS"`` with ``"CT"``: binarize, 1.0 where the label is > 0 and 0.0
      elsewhere.
    - ``"MMWHS"`` (any domain): replace each value that is a key of
      ``mmwhs_labels`` with its position in that mapping; all other values
      become 0.

    For every other combination a warning is printed and an identity function
    is returned.

    Args:
        dataset_name: Dataset identifier (``"CHAOS"`` or ``"MMWHS"``).
        domain: Imaging domain (``"MR"``, ``"MRI"`` or ``"CT"``).

    Returns:
        A function taking a label tensor and returning the decoded labels.
    """
    if dataset_name == "CHAOS":
        if domain in ("MR", "MRI"):

            def decode_chaos_mr(labels):
                # Convert intensity values to class indices (keep as float32)
                return labels // 63

            return decode_chaos_mr

        if domain == "CT":

            def decode_chaos_ct(labels):
                return torch.where(labels > 0, 1.0, 0.0)

            return decode_chaos_ct

    elif dataset_name == "MMWHS":
        # Imported only when needed so that decoding the other datasets does
        # not depend on the MMWHS module being importable.
        from src.datasets.mmwhs import mmwhs_labels

        def decode_mmwhs(labels):
            decoded_labels = torch.zeros_like(labels, dtype=torch.float32)
            for i, label_val in enumerate(mmwhs_labels):
                decoded_labels[labels == label_val] = i
            return decoded_labels

        return decode_mmwhs

    # Unknown dataset, or CHAOS with an unsupported domain.
    print(f"Warning: No decode function defined for {dataset_name} in {domain}. Returning labels unchanged.")

    def decode_identity(labels):
        return labels

    return decode_identity


In [None]:
# Finetuning loop
# For every (dataset, domain) pair in TRAINING_EPOCHS: save and evaluate the
# baseline encoder, train the head with the body frozen, then finetune with
# the head frozen, saving checkpoints and recording metrics after each stage.

for (dataset_name, domain), epochs in TRAINING_EPOCHS.items():
    # NOTE(review): the download and the preprocessing setup run before the
    # "checkpoint already exists" check below, i.e. also for skipped pairs.
    download_and_extract_dataset(dataset_name, DATA_PATH)

    image_transform, seg_transform = get_preprocessing(dataset_name, domain,  is_training=True)

    filename = f"{dataset_name}_{domain}_{'3d' if USE_3D else '2d'}_finetuned.pth"
    filename = CHECKPOINT_PATH / filename
    # Check if the finetuned checkpoint already exists
    if filename.exists():
        print(
            f"Finetuned model for {dataset_name} in {domain} domain with {'3d' if USE_3D else '2d'} images already exists at {filename}. Skipping finetuning."
        )
        continue

    print(
        f"Finetuning on {dataset_name} dataset in {domain} domain with {'3d' if USE_3D else '2d'} images "
    )
    dataset: BaseDataset = get_dataset(
        dataset_name=dataset_name,
        domain=domain,
        transform=image_transform,  # Use transform instead of preprocess
        seg_transform=seg_transform,  # Pass seg_transform too
        base_path=DATA_PATH,
        batch_size=1,
        num_workers=0,
        slice_2d=not USE_3D,
    )

    # Ensure the dataset is loaded correctly
    if not isinstance(dataset, BaseDataset):
        raise TypeError(
            f"Expected dataset to be an instance of BaseDataset, got {type(dataset)}"
        )
    # Print dataset information
    # (indexing train_dataset[0] twice loads the first sample twice)
    print()
    print(f"Dataset: {dataset_name}, Domain: {domain}")
    print(f"Number of training samples: {len(dataset.train_dataset)}")
    print(f"Number of validation samples: {len(dataset.val_dataset)}")
    print(f"Number of test samples: {len(dataset.test_dataset)}")
    print(f"Image shape: {dataset.train_dataset[0][0].shape}")
    print(f"Segmentation shape: {dataset.train_dataset[0][1].shape}")
    print(f"Number of classes: {dataset.num_classes}")
    print()


    model = dataset.get_model(
        encoder_type=encoder_type,
    )

    # Save the baseline encoder before finetuning. torch.save receives
    # model.encoder itself, not model.encoder.state_dict().
    baseline_filename = (
        CHECKPOINT_PATH
        / f"{dataset_name}_{domain}_{'3d' if USE_3D else '2d'}_baseline.pth"
    )
    torch.save(model.encoder, baseline_filename)
    print(
        f"Processing {dataset_name} in {domain} domain with {'3d' if USE_3D else '2d'} images"
    )
    model_metrics = model.evaluate()
    update_metrics(
        f"{dataset_name}_{domain}_{'3d' if USE_3D else '2d'}_baseline",
        model_metrics,
    )

    # Train the segmentation head
    # For the proposal of the paper this part should not be done like this!
    # This requires data - the original task arithmetic paper builds classification heads using no data, only templates.
    # If we have time this point should be addressed; otherwise at least mention that the technical problem
    # requires more development time and that the proof of concept should still be somewhat solid.
    model.freeze_body()
    model.finetune(
        epochs=epochs, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )
    # Save the head (model.head itself, not a state_dict)
    torch.save(
        model.head,
        CHECKPOINT_PATH
        / f"{dataset_name}_{domain}_{'3d' if USE_3D else '2d'}_head.pth",
    )

    metrics = model.evaluate()
    update_metrics(
        f"{dataset_name}_{domain}_{'3d' if USE_3D else '2d'}_head",
        metrics,
    )

    # Finetune the encoder-decoder: unfreeze everything, then freeze the head again
    model.unfreeze()
    model.freeze_head()
    # NOTE(review): `history` is not used anywhere else in this cell.
    history = model.finetune(
        epochs=epochs,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
    )

    # Save the finetuned encoder (model.encoder itself, not a state_dict).
    # The existence of this file is what makes later runs skip this pair.
    torch.save(model.encoder, filename)
    model_metrics = model.evaluate()
    update_metrics(
        f"{dataset_name}_{domain}_{'3d' if USE_3D else '2d'}_finetuned",
        model_metrics,
    )
    # Release cached GPU memory before moving on to the next pair
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

# Domain adaptation

In [None]:
# SWIN UNETR Task Vectors
from monai.networks.nets import SwinUNETR
from monai.networks.nets.swin_unetr import SwinTransformer
from monai.networks.blocks.patchembedding import PatchEmbed
from torch.nn.modules.conv import Conv3d
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.container import ModuleList
from monai.networks.nets.swin_unetr import BasicLayer
from monai.networks.nets.swin_unetr import SwinTransformerBlock
from torch.nn.modules.normalization import LayerNorm
from monai.networks.nets.swin_unetr import WindowAttention
from torch.nn.modules.linear import Linear
from torch.nn.modules.activation import Softmax
from torch.nn.modules.linear import Identity
from monai.networks.blocks.mlp import MLPBlock
from torch.nn.modules.activation import GELU
from monai.networks.nets.swin_unetr import PatchMerging
from monai.networks.blocks.unetr_block import UnetrBasicBlock
from monai.networks.blocks.dynunet_block import UnetResBlock
from monai.networks.blocks.convolutions import Convolution
from torch.nn.modules.activation import LeakyReLU
from torch.nn.modules.instancenorm import InstanceNorm3d
from monai.networks.blocks.unetr_block import UnetrUpBlock
from torch.nn.modules.conv import ConvTranspose3d
from monai.networks.blocks.dynunet_block import UnetOutBlock

# Classes allow-listed for checkpoint loading: this list is handed to
# torch.serialization.safe_globals() when the task vectors are built below.
# LayerNorm here is torch.nn's LayerNorm (imported above); the name is rebound
# later in this cell by the clip.model import.
safe_globals = [
    SwinUNETR,
    SwinTransformer,
    PatchEmbed,
    Conv3d,
    Dropout,
    ModuleList,
    BasicLayer,
    SwinTransformerBlock,
    LayerNorm,
    WindowAttention,
    Linear,
    Softmax,
    Identity,
    MLPBlock,
    GELU,
    PatchMerging,
    UnetrBasicBlock,
    UnetResBlock,
    Convolution,
    LeakyReLU,
    InstanceNorm3d,
    UnetrUpBlock,
    ConvTranspose3d,
    UnetOutBlock,
]
##

## CLIPSeg Task Vectors
from src.CLIPSeg import CLIPSeg
from clipseg.clipseg import CLIPDensePredT
from clip.model import (
    CLIP,
    VisionTransformer,
    LayerNorm,
    Transformer,
    ResidualAttentionBlock,
    QuickGELU,
)
from torch.nn.modules.conv import Conv2d, ConvTranspose2d
from torch.nn.modules.container import Sequential
from torch.nn.modules.activation import MultiheadAttention, ReLU
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.modules.sparse import Embedding
from torch.nn.modules.transformer import (
    TransformerEncoderLayer,
    TransformerEncoder,
    TransformerDecoderLayer,
    TransformerDecoder,
)
from torch.nn.functional import relu
from torch.nn.modules.container import ModuleDict

# Extend the allow-list with the names imported for the CLIPSeg task vectors.
# LayerNorm now refers to clip.model.LayerNorm, because the import above
# rebinds the name; the torch.nn LayerNorm was already captured in the list.
safe_globals.extend(
    [
        CLIPSeg,
        CLIPDensePredT,
        CLIP,
        VisionTransformer,
        Conv2d,
        LayerNorm,
        Transformer,
        Sequential,
        ResidualAttentionBlock,
        MultiheadAttention,
        NonDynamicallyQuantizableLinear,
        QuickGELU,
        Embedding,
        ReLU,
        ConvTranspose2d,
        TransformerEncoderLayer,
        TransformerEncoder,
        TransformerDecoderLayer,
        TransformerDecoder,
        relu,
        ModuleDict,
    ]
)
##

# Build one task vector per (dataset, domain) pair from its baseline and
# finetuned checkpoints, keyed as "<dataset>_<domain>".
task_vectors = {}
for dataset_name in DATASET_NAMES:
    for domain in DOMAINS:
        dimensionality = "3d" if USE_3D else "2d"
        checkpoint_stem = f"{dataset_name}_{domain}_{dimensionality}"
        print(
            f"Building task vector for {dataset_name} dataset in {domain} domain with {dimensionality} images"
        )
        baseline_checkpoint = CHECKPOINT_PATH / f"{checkpoint_stem}_baseline.pth"
        finetuned_checkpoint = CHECKPOINT_PATH / f"{checkpoint_stem}_finetuned.pth"

        # The allow-listed classes are only active inside this context, so the
        # task vector is created (and pruned) while it is open.
        with torch.serialization.safe_globals(safe_globals=safe_globals):
            task_vector = TaskVector(baseline_checkpoint, finetuned_checkpoint)
            # Drop every entry whose key contains "out." (the output layer).
            out_layer_keys = [k for k in task_vector.keys() if "out." in k]
            for k in out_layer_keys:
                del task_vector[k]

        task_vectors[f"{dataset_name}_{domain}"] = task_vector

In [None]:
# Build composite task vectors using arithmetic.
# The vector for a target (dataset, domain) is assembled without using the
# target's own task vector:
#   same dataset / other domain  +  other dataset / same domain
#   -  other dataset / other domain
composite_task_vectors = {
    f"{target_dataset_name}_{target_dom}": (
        task_vectors[f"{target_dataset_name}_{other_dom}"]
        + task_vectors[f"{other_dataset_name}_{target_dom}"]
        - task_vectors[f"{other_dataset_name}_{other_dom}"]
    )
    for target_dataset_name, other_dataset_name in (
        ("MMWHS", "CHAOS"),
        ("CHAOS", "MMWHS"),
    )
    for target_dom, other_dom in (("CT", "MR"), ("MR", "CT"))
}

In [None]:
# 🔄 Task Vector Cross-Domain Adaptation Experiments
# For every (dataset, target domain) pair: build a model through the dataset,
# load the matching composite task vector into it, evaluate, and record the
# metrics under "<dataset>_<domain>_adaptation".
print("🔄 Task Vector Cross-Domain Adaptation Experiments")
print("=" * 80)

for dataset_name in DATASET_NAMES:
    for target_domain in DOMAINS:
        print(f"\n{dataset_name}: {target_domain} adaptation")

        # Preprocessing is domain-specific (CT and MR are normalized
        # differently), so it must follow the domain being evaluated. The
        # earlier code passed the global `domain`, a stale value left over
        # from a previous cell's loop.
        image_transform, seg_transform = get_preprocessing(
            dataset_name, target_domain, is_training=False
        )

        # NOTE(review): the finetuning cell also passes slice_2d=not USE_3D to
        # get_dataset and never passes liver_only; confirm that these
        # differences between training and evaluation are intended.
        dataset_kwargs = {
            "dataset_name": dataset_name,
            "base_path": DATA_PATH,
            "domain": target_domain,
            "transform": image_transform,  # Use transform instead of preprocess
            "seg_transform": seg_transform,  # Pass seg_transform too
            "batch_size": 1,
            "num_workers": 0,
        }
        extra_kwargs = {}
        if dataset_name == "CHAOS":
            extra_kwargs["liver_only"] = True

        # No try/except here: a failure for one pair aborts the whole cell.
        target_dataset = get_dataset(**dataset_kwargs, **extra_kwargs)

        composite_key = f"{dataset_name}_{target_domain}"
        if composite_key in composite_task_vectors:
            composite_task_vector = composite_task_vectors[composite_key]

            target_model = target_dataset.get_model(encoder_type=encoder_type)
            target_model.load_task_vector(composite_task_vector)

            metrics = target_model.evaluate()
            update_metrics(f"{composite_key}_adaptation", metrics)
            # The summary line reports the Dice stored under the "train" key.
            print(
                f"   ✅ {composite_key}: Dice={metrics.get('train', {}).get('dice', 0):.3f}"
            )
        else:
            print(f"   ⚠️ No composite task vector found for {composite_key}")

print("=" * 80)

In [None]:
# Load and display all metrics, grouped by experiment stage
metrics_file = OUTPUTS_PATH / "metrics.json"
if metrics_file.exists():
    with open(metrics_file, "r") as f:
        all_metrics = json.load(f)

    print("\n📊 COMPREHENSIVE RESULTS ANALYSIS")
    print("=" * 80)

    # (section heading, substring that selects the entries shown under it)
    report_sections = [
        ("\n🏁 Baseline Performance:", "baseline"),
        ("\n🏋️‍♂️ After Head-Training Performance:", "head"),
        ("\n🏆 Finetuned Performance:", "finetuned"),
        ("\n🔄 Cross-Domain Adaptation Results:", "adaptation"),
    ]
    for heading, tag in report_sections:
        print(heading)
        for key, metrics in all_metrics.items():
            if tag not in key:
                continue
            # Tolerate entries without a "train" section instead of failing
            # with AttributeError on None; missing values are shown as 0.
            train_metrics = metrics.get("train") or {}
            dice = train_metrics.get("dice", 0)
            hausdorff = train_metrics.get("hausdorff", 0)
            print(f"   {key}: Dice={dice:.3f}, HD={hausdorff:.3f}")
else:
    print("No metrics file found. Run the experiments first.")

In [None]:
if IN_COLAB:
    import shutil

    # Copy checkpoints.zip to Google Drive
    !zip -r /content/checkpoints.zip /content/xai/checkpoints
    shutil.copy(
        "/content/checkpoints.zip", "/content/drive/MyDrive/xai/checkpoints.zip"
    )

    # Copy metrics.json to Google Drive
    shutil.copy(
        "/content/xai/outputs/metrics.json", "/content/drive/MyDrive/xai/metrics.json"
    )

In [None]:
if IN_KAGGLE:
    # Bundle the checkpoints directory into a single archive in /kaggle/working.
    !zip -r /kaggle/working/checkpoints.zip /kaggle/working/xai/checkpoints