The original task arithmetic paper uses CLIP and builds classification heads with zero-shot training based on some templates.
In our case, we have some special characteristics to keep in mind:
- We are operating on medical, 3D, data.
- Our objective is semantic labeling, not simple classification.

To deal with this we have a couple of options to try:
- Copy the same flow using CLIP trained on imagenet and slice the 3d images into multiple 2d.
    - Since we need to do segmentation, we need to find a model that can perform segmentation and be finetuned
    - Even better if we access separately the segmentation head and the encoder
- Use 3d resnet pretrained on medicalnet by monai, perform a few short training loops with frozen encoder to train the segmentation head. 
- Use medsam/medsam2?

In [1]:
from src.datasets.registry import get_dataset
from src.datasets.common import BaseDataset
from pathlib import Path
import json
from src.task_vector import TaskVector

In [2]:
dataset_names = ["CHAOS", "MMWHS"]
domains = ["CT", "MR"]
data_path = 'data/'
checkpoint_path = "checkpoints/"
outputs_path = "outputs/"
use_3d = True

In [3]:
checkpoint_path = Path(checkpoint_path)
outputs_path = Path(outputs_path)
data_path = Path(data_path)
checkpoint_path.mkdir(parents=True, exist_ok=True)
outputs_path.mkdir(parents=True, exist_ok=True)

In [4]:
import torch

def update_metrics(name, new_metrics):
    metrics_file = outputs_path / "metrics.json"

    if not metrics_file.exists():
        metrics = {}
    else:
        with open(metrics_file, "r") as f:
                metrics = json.load(f)

    metrics[name] = new_metrics
    with open(metrics_file, "w") as f:
        json.dump(metrics, f, indent=4)

In [None]:
# Finetuning loop using the new hybrid approach

for dataset_name in dataset_names:
    for domain in domains:
        filename = f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_finetuned.pth"
        filename = checkpoint_path / filename
        # Check if the finetuned checkpoint already exists
        if filename.exists():
            print(
                f"Finetuned model for {dataset_name} in {domain} domain with {'3d' if use_3d else '2d'} images already exists at {filename}. Skipping finetuning."
            )
            continue

        print(
            f"Finetuning on {dataset_name} dataset in {domain} domain with {'3d' if use_3d else '2d'} images (hybrid approach)"
        )
        dataset: BaseDataset = get_dataset(
            dataset_name=dataset_name,
            domain=domain,
            base_path=data_path,
            batch_size=1,
            num_workers=0,
            slice_2d=not use_3d,
        )

        # Use the new hybrid model method
        model = dataset.get_hybrid_model(
            encoder_type="resnet",
            use_semantic_head=True,
        )

        head_filename =  checkpoint_path / f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_baselined.pth"
        segmentation_head = model.semantic_head
        torch.save(segmentation_head, head_filename)

        # Save the baseline model's state_dict before finetuning
        baseline_filename =  checkpoint_path / f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_baselined.pth"
        torch.save(model.encoder, baseline_filename)
        model_metrics = model.evaluate()
        update_metrics(
            f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_baseline_hybrid",
            model_metrics,
        )

        model.finetune(
            epochs=5,
        )
        # Save the finetuned model's state_dict

        torch.save(model.encoder, filename)
        model_metrics = model.evaluate()
        update_metrics(
            f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_finetuned_hybrid",
            model_metrics,
        )

Finetuning on CHAOS dataset in CT domain with 3d images (hybrid approach)
Batch size: torch.Size([1, 1, 257, 512, 512]) - RAM: 2538.2MB
Batch size: torch.Size([1, 1, 257, 512, 512]) - RAM: 2538.2MB


# Domain adaptation

In [None]:
# Build Task Vectors for each dataset and domain
task_vectors = {}
for dataset_name in dataset_names:
    for domain in domains:
        print(f"Building task vector for {dataset_name} dataset in {domain} domain with {'3d' if use_3d else '2d'} images")
        baseline_checkpoint = checkpoint_path / f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_baseline.pth"
        finetuned_checkpoint = checkpoint_path / f"{dataset_name}_{domain}_{'3d' if use_3d else '2d'}_finetuned.pth"
        task_vector = TaskVector(baseline_checkpoint, finetuned_checkpoint)
        task_vectors[f"{dataset_name}_{domain}"] = task_vector

Building task vector for CHAOS dataset in CT domain with 3d images


FileNotFoundError: [Errno 2] No such file or directory: 'checkpoints/CHAOS_CT_3d_baseline.pth'

In [None]:
# Task Vector Cross-Domain Evaluation (merged version, with nested loops for all configs)
print("\n🔄 Task Vector Cross-Domain Adaptation Experiments")
print("=" * 80)

for dataset_name in dataset_names:
    for source_domain in domains:
        for target_domain in domains:
            if source_domain == target_domain:
                continue
            key = (dataset_name, source_domain, target_domain)
            metrics_key = f"{dataset_name}_{target_domain}_from_{source_domain}_hybrid"
            # Only add extra_kwarg 'liver_only' for CHAOS dataset
            extra_kwargs = {}
            if dataset_name == "CHAOS":
                extra_kwargs["liver_only"] = True
            try:
                task_vector_key = f"{dataset_name}_{source_domain}"
                if task_vector_key in task_vectors:
                    print(f"\n{dataset_name}: {source_domain} → {target_domain} adaptation")
                    # Load target domain dataset
                    dataset_kwargs = dict(
                        dataset_name=dataset_name,
                        domain=target_domain,
                        base_path=data_path,
                        batch_size=1,
                        num_workers=0,
                        slice_2d=False,
                    )
                    dataset_kwargs.update(extra_kwargs)
                    target_dataset = get_dataset(**dataset_kwargs)
                    target_model = target_dataset.get_model()

                    # Apply task vector from source domain
                    source_task_vector = task_vectors[task_vector_key]
                    target_model.load_task_vector(source_task_vector)

                    # Evaluate cross-domain performance
                    cross_domain_metrics = target_model.evaluate()
                    update_metrics(metrics_key, cross_domain_metrics)

                    print(f"   ✅ {source_domain}→{target_domain}: Dice={cross_domain_metrics.get('dice', 0):.3f}")
                    print(f"   📊 Hausdorff Distance: {cross_domain_metrics.get('hausdorff', 0):.3f}")
            except Exception as e:
                print(f"   ❌ {dataset_name} {source_domain}→{target_domain} error: {e}")


In [None]:
# Load and display all metrics
metrics_file = outputs_path / "metrics.json"
if metrics_file.exists():
    with open(metrics_file, 'r') as f:
        all_metrics = json.load(f)

    print("\n📊 COMPREHENSIVE RESULTS ANALYSIS")
    print("=" * 80)

    # Baseline performance
    print("\n🏁 Baseline Performance (Hybrid Semantic-Guided):")
    for key, metrics in all_metrics.items():
        if 'baseline' in key and 'hybrid' in key:
            dice = metrics.get('dice', 0)
            hausdorff = metrics.get('hausdorff', 0)
            print(f"   {key}: Dice={dice:.3f}, HD={hausdorff:.3f}")

    # Cross-domain adaptation results
    print("\n🔄 Cross-Domain Adaptation Results:")
    for key, metrics in all_metrics.items():
        if 'from' in key and 'hybrid' in key:
            dice = metrics.get('dice', 0)
            hausdorff = metrics.get('hausdorff', 0)
            print(f"   {key}: Dice={dice:.3f}, HD={hausdorff:.3f}")
else:
    print("No metrics file found. Run the experiments first.")

In [None]:
# Test the evaluation method with fixed tensor handling
import torch
from pathlib import Path

print("Testing evaluation method with hybrid model...")

try:
    # Get a small dataset for testing
    base_path = Path("data")
    test_dataset = get_dataset('CHAOS', base_path, use_3d=True, train=True, domain='CT')

    # Create hybrid model without semantic head to avoid transformer dependencies
    hybrid_model = test_dataset.get_hybrid_model(encoder_type="swin_unetr", use_semantic_head=False)

    print(f"Model created: {type(hybrid_model)}")
    print(f"Model device: {next(hybrid_model.parameters()).device}")

    # Create a minimal test by getting one batch from the loader
    test_loader = test_dataset.train_loader
    batch = next(iter(test_loader))
    print(f"Batch image shape: {batch['image'].shape}")
    print(f"Batch label shape: {batch.get('label', torch.empty(0)).shape}")

    # Test just the forward pass without evaluation metrics
    with torch.no_grad():
        images = batch['image']
        # Check if we need to slice for SwinUNETR dimension requirements
        if images.shape[2] >= 32:  # Depth dimension
            sliced_images = images[:, :, :32, :, :].contiguous()
            print(f"Sliced images shape: {sliced_images.shape}")
            output = hybrid_model(sliced_images)
            print(f"✓ Forward pass successful! Output shape: {output.shape}")
        else:
            print(f"Image depth {images.shape[2]} < 32, cannot test with SwinUNETR")

except Exception as e:
    print(f"Error during test: {e}")
    import traceback
    traceback.print_exc()