The original task arithmetic paper uses CLIP and builds zero-shot classification heads from text prompt templates.
In our case, we have some special characteristics to keep in mind:
- We are operating on 3D medical data.
- Our objective is semantic labeling, not simple classification.

To deal with this, we have a few options to try:
- Copy the same flow using CLIP trained on ImageNet and slice the 3D images into multiple 2D slices.
    - Since we need to do segmentation, we need to find a model that can perform segmentation and be finetuned.
    - Even better if we can access the segmentation head and the encoder separately.
- Use a 3D ResNet pretrained on MedicalNet (provided by MONAI) and perform a few short training loops with a frozen encoder to train the segmentation head.
- Use MedSAM/MedSAM2?

In Kaggle, add the following to the dependencies:
```
pip install 'torch>=1.12.1'
torchvision
numpy
pydicom
Pillow
matplotlib
transformers>=4.20.0
git+https://github.com/bowang-lab/MedSAM2.git
```
Also remember to enable file persistence and internet access.

In [None]:
# Uncomment on Kaggle
# !git clone https://github.com/parmigggiana/xai /kaggle/working/xai
# %cd xai
# !git pull
# %pip install 'napari[pyqt6,optional]==0.6.2a1' 'monai[einops,nibabel]>=1.1.0' open-clip-torch

In [None]:
# Uncomment on Colab
# !git clone https://github.com/parmigggiana/xai /content/xai
# %cd /content/xai
# !git fetch
# !git reset --hard origin/main
# %pip install -r requirements.txt

In [None]:
from src.datasets.registry import get_dataset
from src.datasets.common import BaseDataset
from pathlib import Path
import json
from src.task_vector import TaskVector
from src.utils import download_and_extract_dataset

In [None]:
# Experiment grid: the loops in the later cells pair every dataset listed here
# with every imaging domain.
dataset_names = [
    "CHAOS",
    "MMWHS",
]
domains = [
    "MR",
    "CT",
]

# Locations for the raw data, the saved checkpoints and the metrics output,
# given as plain strings here.
data_path, checkpoint_path, outputs_path = "data/", "checkpoints/", "outputs/"

# True: work on 3D volumes. False: the datasets are requested with
# slice_2d=True and the "2d" tag is used in file names and metric keys.
use_3d = True

In [None]:
# Turn the configured locations into Path objects so the following cells can
# build file names with the `/` operator.
checkpoint_path, outputs_path, data_path = map(
    Path, (checkpoint_path, outputs_path, data_path)
)

# Only the checkpoint and output directories are created here; data_path is
# left untouched.
outputs_path.mkdir(parents=True, exist_ok=True)
checkpoint_path.mkdir(parents=True, exist_ok=True)

In [None]:
import torch


def update_metrics(name, new_metrics):
    """Store ``new_metrics`` under ``name`` in ``outputs_path / "metrics.json"``.

    The JSON file is read if it exists (otherwise an empty mapping is used),
    the entry for ``name`` is added or overwritten, and the whole mapping is
    written back with 4-space indentation.

    Args:
        name: Key under which the metrics are stored.
        new_metrics: JSON-serialisable value to store for ``name``.
    """
    target = outputs_path / "metrics.json"

    # Start from what has already been recorded, if anything.
    recorded = {}
    if target.exists():
        with open(target, "r") as handle:
            recorded = json.load(handle)

    recorded[name] = new_metrics

    with open(target, "w") as handle:
        json.dump(recorded, handle, indent=4)


def get_preprocessing(domain):
    """Build the per-sample preprocessing function handed to the dataset.

    The returned function resamples a sample's image (and its label, when
    there is one) so that the last three dimensions are (96, 128, 128).
    Samples that already have that size only get the dtype casts. Any other
    size is resampled, so smaller volumes are upsampled as well as larger
    ones being downsampled.

    Args:
        domain: Imaging domain name. Currently unused: the same transform is
            returned for every domain. The parameter is kept so existing
            call sites keep working.

    Returns:
        A function taking a sample dict with an ``"image"`` tensor and an
        optional ``"label"`` tensor, and returning a new dict with exactly
        the keys ``"image"`` (float tensor) and ``"label"`` (long tensor, or
        ``None`` when the sample has no label).
    """
    target_shape = (96, 128, 128)

    def preprocess(x):
        img = x["image"].float()
        # An unlabelled sample may carry label=None or omit the key entirely.
        lbl = x.get("label")
        if lbl is not None:
            lbl = lbl.float()

        if img.shape[-3:] != target_shape:
            # A leading axis is added for interpolate() and removed again
            # afterwards. Assumes samples are laid out as (C, D, H, W) --
            # TODO confirm against the dataset classes.
            img = torch.nn.functional.interpolate(
                img.unsqueeze(0),
                size=target_shape,
                mode="trilinear",
                align_corners=False,
            ).squeeze(0)
            if lbl is not None:
                # Nearest-neighbour resampling keeps label values discrete.
                lbl = torch.nn.functional.interpolate(
                    lbl.unsqueeze(0), size=target_shape, mode="nearest"
                ).squeeze(0)

        if lbl is not None:
            lbl = lbl.long()

        return {"image": img, "label": lbl}

    return preprocess

In [None]:
# Finetuning loop using the new hybrid approach.
#
# For each (dataset, domain) pair:
#   1. skip the pair if its finetuned checkpoint is already on disk;
#   2. build the dataset and the hybrid model;
#   3. save the encoder and record metrics before finetuning ("baseline");
#   4. finetune, then save the encoder and record metrics again ("finetuned").

# "3d" or "2d": shared by checkpoint file names, metric keys and log messages.
dim_tag = "3d" if use_3d else "2d"

for dataset_name in dataset_names:
    download_and_extract_dataset(dataset_name, data_path)

    for domain in domains:
        run_name = f"{dataset_name}_{domain}_{dim_tag}"

        filename = checkpoint_path / f"{run_name}_finetuned.pth"
        # Check if the finetuned checkpoint already exists
        if filename.exists():
            print(
                f"Finetuned model for {dataset_name} in {domain} domain with {dim_tag} images already exists at {filename}. Skipping finetuning."
            )
            continue

        print(
            f"Finetuning on {dataset_name} dataset in {domain} domain with {dim_tag} images (hybrid approach)"
        )
        preprocess = get_preprocessing(domain)
        dataset: BaseDataset = get_dataset(
            dataset_name=dataset_name,
            domain=domain,
            preprocess=preprocess,
            base_path=data_path,
            batch_size=1,
            num_workers=1,
            slice_2d=not use_3d,
        )

        # Use the new hybrid model method
        model = dataset.get_hybrid_model(
            encoder_type="swin_unetr",
            use_semantic_head=False,
        )

        # Save the encoder as it is before finetuning. model.encoder itself is
        # passed to torch.save (not model.encoder.state_dict()), so the file
        # holds the pickled object.
        baseline_filename = checkpoint_path / f"{run_name}_baseline.pth"
        torch.save(model.encoder, baseline_filename)
        print(
            f"Processing {dataset_name} in {domain} domain with {dim_tag} images (hybrid approach)"
        )
        model_metrics = model.evaluate()
        update_metrics(f"{run_name}_baseline_hybrid", model_metrics)

        # The return value is kept so it can be inspected after the cell runs.
        history = model.finetune(
            epochs=15,
            learning_rate=5e-4,
            weight_decay=1e-5,
        )

        # Save the finetuned encoder, again as a whole object.
        torch.save(model.encoder, filename)
        model_metrics = model.evaluate()
        update_metrics(f"{run_name}_finetuned_hybrid", model_metrics)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

# Domain adaptation

In [None]:
from monai.networks.nets import SwinUNETR
from monai.networks.nets.swin_unetr import SwinTransformer
from monai.networks.blocks.patchembedding import PatchEmbed
from torch.nn.modules.conv import Conv3d
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.container import ModuleList
from monai.networks.nets.swin_unetr import BasicLayer
from monai.networks.nets.swin_unetr import SwinTransformerBlock
from torch.nn.modules.normalization import LayerNorm
from monai.networks.nets.swin_unetr import WindowAttention
from torch.nn.modules.linear import Linear
from torch.nn.modules.activation import Softmax
from torch.nn.modules.linear import Identity
from monai.networks.blocks.mlp import MLPBlock
from torch.nn.modules.activation import GELU
from monai.networks.nets.swin_unetr import PatchMerging
from monai.networks.blocks.unetr_block import UnetrBasicBlock
from monai.networks.blocks.dynunet_block import UnetResBlock
from monai.networks.blocks.convolutions import Convolution
from torch.nn.modules.activation import LeakyReLU
from torch.nn.modules.instancenorm import InstanceNorm3d
from monai.networks.blocks.unetr_block import UnetrUpBlock
from torch.nn.modules.conv import ConvTranspose3d
from monai.networks.blocks.dynunet_block import UnetOutBlock

# Build one TaskVector per (dataset, domain) pair from the baseline and
# finetuned checkpoints written by the finetuning cell.

# Classes allow-listed through torch.serialization.safe_globals while the
# checkpoints are read. Presumably TaskVector loads both files with torch.load
# inside that context -- confirm in src/task_vector.py.
checkpoint_classes = [
    SwinUNETR,
    SwinTransformer,
    PatchEmbed,
    Conv3d,
    Dropout,
    ModuleList,
    BasicLayer,
    SwinTransformerBlock,
    LayerNorm,
    WindowAttention,
    Linear,
    Softmax,
    Identity,
    MLPBlock,
    GELU,
    PatchMerging,
    UnetrBasicBlock,
    UnetResBlock,
    Convolution,
    LeakyReLU,
    InstanceNorm3d,
    UnetrUpBlock,
    ConvTranspose3d,
    UnetOutBlock,
]
size_tag = "3d" if use_3d else "2d"

# Maps "<dataset>_<domain>" to its TaskVector.
task_vectors = {}
for dataset_name in dataset_names:
    for domain in domains:
        pair_name = f"{dataset_name}_{domain}"
        print(
            f"Building task vector for {dataset_name} dataset in {domain} domain with {size_tag} images"
        )
        baseline_checkpoint = checkpoint_path / f"{pair_name}_{size_tag}_baseline.pth"
        finetuned_checkpoint = checkpoint_path / f"{pair_name}_{size_tag}_finetuned.pth"
        with torch.serialization.safe_globals(checkpoint_classes):
            task_vectors[pair_name] = TaskVector(
                baseline_checkpoint, finetuned_checkpoint
            )

In [None]:
# Task Vector Cross-Domain Evaluation (merged version, with nested loops for all configs)
#
# For every dataset and every ordered pair of distinct domains
# (source -> target): build the target-domain dataset and model, apply the
# task vector stored for the source domain, evaluate, and record the metrics
# under "<dataset>_<target>_from_<source>_hybrid".
print("\n🔄 Task Vector Cross-Domain Adaptation Experiments")
print("=" * 80)

for dataset_name in dataset_names:
    # Only add extra_kwarg 'liver_only' for CHAOS dataset
    extra_kwargs = {"liver_only": True} if dataset_name == "CHAOS" else {}

    for source_domain in domains:
        task_vector_key = f"{dataset_name}_{source_domain}"
        if task_vector_key not in task_vectors:
            # No task vector is available for this source; nothing to apply.
            continue
        source_task_vector = task_vectors[task_vector_key]

        for target_domain in domains:
            if source_domain == target_domain:
                continue
            metrics_key = f"{dataset_name}_{target_domain}_from_{source_domain}_hybrid"
            print(f"\n{dataset_name}: {source_domain} → {target_domain} adaptation")

            # One failing configuration is reported and skipped so that the
            # remaining ones still run.
            try:
                # Load target domain dataset.
                # slice_2d follows use_3d, matching how the checkpoints behind
                # the task vectors are named and how the finetuning cell
                # builds its datasets.
                # NOTE(review): the finetuning cell also passes
                # preprocess=get_preprocessing(domain) to get_dataset; this
                # cell does not -- confirm whether that is intentional.
                target_dataset = get_dataset(
                    dataset_name=dataset_name,
                    domain=target_domain,
                    base_path=data_path,
                    batch_size=1,
                    num_workers=1,
                    slice_2d=not use_3d,
                    **extra_kwargs,
                )
                target_model = target_dataset.get_hybrid_model(
                    encoder_type="swin_unetr",
                    use_semantic_head=False,
                )

                # Apply task vector from source domain
                target_model.load_task_vector(source_task_vector)

                # Evaluate cross-domain performance
                cross_domain_metrics = target_model.evaluate()
                update_metrics(metrics_key, cross_domain_metrics)

                print(
                    f"   ✅ {source_domain}→{target_domain}: Dice={cross_domain_metrics.get('dice', 0):.3f}"
                )
                print(
                    f"   📊 Hausdorff Distance: {cross_domain_metrics.get('hausdorff', 0):.3f}"
                )
            except Exception as e:
                print(
                    f"   ❌ {dataset_name} {source_domain}→{target_domain} error: {e}"
                )

In [None]:
# Load and display all metrics
#
# Prints two sections from outputs_path / "metrics.json": entries whose key
# contains "baseline" and entries whose key contains "from". In both cases the
# key must also contain "hybrid".
metrics_file = outputs_path / "metrics.json"
if not metrics_file.exists():
    print("No metrics file found. Run the experiments first.")
else:
    with open(metrics_file, "r") as f:
        all_metrics = json.load(f)

    print("\n📊 COMPREHENSIVE RESULTS ANALYSIS")
    print("=" * 80)

    # (section heading, substring a key must contain to be listed there)
    report_sections = (
        ("\n🏁 Baseline Performance (Hybrid Semantic-Guided):", "baseline"),
        ("\n🔄 Cross-Domain Adaptation Results:", "from"),
    )
    for heading, marker in report_sections:
        print(heading)
        for key, metrics in all_metrics.items():
            if not (marker in key and "hybrid" in key):
                continue
            # Missing values are shown as 0.
            dice = metrics.get("dice", 0)
            hausdorff = metrics.get("hausdorff", 0)
            print(f"   {key}: Dice={dice:.3f}, HD={hausdorff:.3f}")