The original task arithmetic paper uses CLIP and builds zero-shot classification heads from a set of text templates.
In our case, we have some special characteristics to keep in mind:
- We are operating on 3D medical data.
- Our objective is semantic labeling, not simple classification.

To deal with this, we have a few options to try:
- Copy the same flow using CLIP trained on ImageNet and slice the 3D images into multiple 2D slices.
    - Since we need to do segmentation, we need to find a model that can perform segmentation and be finetuned.
    - Even better if we can access the segmentation head and the encoder separately.
- Use a 3D ResNet pretrained on MedicalNet (provided through MONAI), and perform a few short training loops with the encoder frozen to train the segmentation head.
- Use MedSAM/MedSAM2?

In [None]:
from src.datasets.registry import get_dataset, split_train_into_train_val
from src.modeling import Classifier, get_encoder
from src.head import get_classification_head, build_classification_head
from pathlib import Path
import torch
from src.evaluation import evaluate_segmentation_performance
import json
from src.task_vector import TaskVector

In [None]:
# Experiment configuration shared by the cells below.

# True -> work on whole 3D volumes, False -> work on 2D slices.
# Also selects the "3d"/"2d" tag used in checkpoint and metric names.
use_3d = True

# Every (dataset, domain) combination of these two lists is processed.
dataset_names = ["CHAOS", "MM-WHS"]
domains = ["CT", "MR"]

# Output locations, given as plain strings here.
save_path = "checkpoints/"  # model checkpoints
outputs_path = "out/"  # metrics.json

In [None]:
# Load every (dataset, domain) combination without 2D slicing.
# Keys are (name, domain) tuples; the domain varies in the outer loop.
datasets_3d = {
    (name, domain): get_dataset(
        dataset_name=name,
        domain=domain,
        location=f"data/{name}/",
        slice_2d=False,
        batch_size=1,
        num_workers=0,
    )
    for domain in domains
    for name in dataset_names
}

# Convert the configured locations to Path objects and make sure both
# directories exist before anything is written to them.
save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True)
outputs_path = Path(outputs_path)
outputs_path.mkdir(parents=True, exist_ok=True)

In [None]:
def update_metrics(name, new_metrics):
    """Store ``new_metrics`` under ``name`` in ``outputs_path / "metrics.json"``.

    The JSON file is read, the entry for ``name`` is added or overwritten, and
    the whole mapping is written back with 4-space indentation. When the file
    does not exist yet (first call of a fresh run) the mapping starts empty
    instead of raising ``FileNotFoundError``.

    Args:
        name: Key under which the metrics are stored.
        new_metrics: JSON-serialisable metrics to record.
    """
    metrics_file = outputs_path / "metrics.json"
    try:
        with open(metrics_file, "r") as f:
            metrics = json.load(f)
    except FileNotFoundError:
        # Nothing has been recorded yet: start a new metrics mapping.
        metrics = {}
    metrics[name] = new_metrics
    with open(metrics_file, "w") as f:
        json.dump(metrics, f, indent=4)

In [None]:
# Fine-tuning loop: one run per (dataset, domain) pair.
# A pair is skipped when its fine-tuned checkpoint is already on disk, so the
# cell can be re-run without repeating finished work.

dim_tag = '3d' if use_3d else '2d'  # does not change inside the loop

for dataset_name in dataset_names:
    for domain in domains:
        run_name = f"{dataset_name}_{domain}_{dim_tag}"
        filename = save_path / f"{run_name}_finetuned.pth"

        # Check if the finetuned checkpoint already exists
        if filename.exists():
            print(
                f"Finetuned model for {dataset_name} in {domain} domain "
                f"with {dim_tag} images already exists at {filename}. "
                "Skipping finetuning."
            )
            continue

        print(f"Finetuning on {dataset_name} dataset in {domain} domain with {dim_tag} images")
        dataset = get_dataset(
            dataset_name,
            domain,
            is_train=True,
            base_path="data/",
            batch_size=1,
            num_workers=0,
            slice_2d=not use_3d,
        )
        model = dataset.get_model()

        # Metrics before any fine-tuning, stored under the "_baseline" key.
        model_metrics = model.evaluate()
        update_metrics(f"{run_name}_baseline", model_metrics)

        tuned_model = model.finetune(
            epochs=100,
            save_path=filename,
        )

        # NOTE(review): this evaluates `model`, not the returned `tuned_model`
        # -- confirm that finetune() updates `model` in place.
        model_metrics = model.evaluate()
        update_metrics(f"{run_name}_finetuned", model_metrics)


NameError: name 'dataset_names' is not defined

# Domain adaptation

In [None]:
# Build one task vector per (dataset, domain) pair from its baseline and
# fine-tuned checkpoints; results are keyed "<dataset>_<domain>".
# NOTE(review): no visible cell writes the *_baseline.pth files read here --
# confirm they are produced elsewhere (e.g. inside get_model() or finetune()).
task_vectors = {}
for dataset_name in dataset_names:
    for domain in domains:
        dim_tag = '3d' if use_3d else '2d'
        print(f"Building task vector for {dataset_name} dataset in {domain} domain with {dim_tag} images")

        checkpoint_stem = f"{dataset_name}_{domain}_{dim_tag}"
        baseline_checkpoint = save_path / f"{checkpoint_stem}_baseline.pth"
        finetuned_checkpoint = save_path / f"{checkpoint_stem}_finetuned.pth"

        task_vector = TaskVector(baseline_checkpoint, finetuned_checkpoint)
        task_vectors[f"{dataset_name}_{domain}"] = task_vector

## CHAOS 
CHAOS MRI has labels for multiple organs, but CHAOS CT has labels for only one organ (liver).
For testing, we will use the liver label from CHAOS CT and the liver label from CHAOS MRI.

In [None]:
# Cross-domain transfer on CHAOS, in both directions (CT -> MR and MR -> CT):
# load the task vector built on the source domain into a model obtained from
# the target-domain dataset (requested with liver_only=True), then record the
# resulting metrics.
for source_domain, target_domain in (("CT", "MR"), ("MR", "CT")):
    task_vector: TaskVector = task_vectors[f"CHAOS_{source_domain}"]

    chaos_target = get_dataset(
        "CHAOS",
        domain=target_domain,
        base_path="data/",
        batch_size=1,
        num_workers=0,
        slice_2d=not use_3d,
        liver_only=True,
    )
    model = chaos_target.get_model()
    model.load_task_vector(task_vector)

    metrics_key = f"CHAOS_{target_domain}_from_{source_domain}"
    model_metrics = model.evaluate()
    update_metrics(metrics_key, model_metrics)

## MM-WHS
Labels are the same in both domains

In [None]:
# Cross-domain transfer on MM-WHS, in both directions (CT -> MR and MR -> CT):
# load the task vector built on the source domain into a model obtained from
# the target-domain dataset, then record the resulting metrics.
for source_domain, target_domain in (("CT", "MR"), ("MR", "CT")):
    task_vector: TaskVector = task_vectors[f"MM-WHS_{source_domain}"]

    mmwhs_target = get_dataset(
        "MM-WHS",
        domain=target_domain,
        base_path="data/",
        batch_size=1,
        num_workers=0,
        slice_2d=not use_3d,
    )
    model = mmwhs_target.get_model()
    model.load_task_vector(task_vector)

    metrics_key = f"MM-WHS_{target_domain}_from_{source_domain}"
    model_metrics = model.evaluate()
    update_metrics(metrics_key, model_metrics)