In [None]:
import os
import sys
import importlib
import time
import json

import numpy as np
import pickle
from skimage import transform

# Make the repository root (the parent of the notebook's working directory)
# importable, so the `src` package imported below can be resolved.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if project_root not in sys.path:
    sys.path.append(project_root)

import importlib
from src.medsam_segmentation import MedSAMTool
from src.data_io import ImageData

In [None]:
def get_baseline(modality, gpu_id=3, checkpoint_path=None, data_dir="../data"):
    """Run the MedSAM baseline on the val and test splits of ``modality``.

    For each split this loads the pickled resized images/boxes/ground truth,
    predicts masks with ``MedSAMTool`` and scores them. The score for a split
    is ``dsc_metric + nsd_metric``. Both scores are written to
    ``<data_dir>/<modality>_baseline_expert.json``.

    Args:
        modality: Dataset modality name, e.g. ``"dermoscopy"`` or ``"xray"``.
            Only ``"dermoscopy"`` is run with ``is_rgb=True``.
        gpu_id: GPU index passed to ``MedSAMTool`` (default 3, as before).
        checkpoint_path: MedSAM checkpoint. Defaults to
            ``<data_dir>/medsam_vit_b.pth``.
        data_dir: Directory holding the input pickles and receiving the JSON.

    Returns:
        The dict that was written to the JSON file.
    """
    if checkpoint_path is None:
        checkpoint_path = os.path.join(data_dir, "medsam_vit_b.pth")

    # Dict comprehensions preserve insertion order, so val runs before test.
    combined = {
        exp_type: _evaluate_split(modality, exp_type, data_dir, gpu_id, checkpoint_path)
        for exp_type in ("val", "test")
    }

    json_output = {
        "expert_baseline_val_avg_metric": combined["val"],
        "expert_baseline_test_avg_metric": combined["test"],
    }

    # Save the metrics to a file
    metrics_filepath = os.path.join(data_dir, f"{modality}_baseline_expert.json")
    with open(metrics_filepath, "w") as f:
        json.dump(json_output, f)
    return json_output


def _evaluate_split(modality, exp_type, data_dir, gpu_id, checkpoint_path):
    """Run the baseline on one split and return ``dsc_metric + nsd_metric`` as a float."""
    np.random.seed(42)

    print(f"\n================== Running baseline for {modality}, {exp_type} set ==================")

    resized_filepath = os.path.join(data_dir, f"resized_{modality}_{exp_type}_filenames_25.pkl")

    # ================ Get baseline ================
    baseline_start_time = time.time()
    print(f"Reading from {resized_filepath}")
    # NOTE(review): pickle.load can execute arbitrary code; only use on trusted, locally produced files.
    with open(resized_filepath, "rb") as f:
        resized_imgs, resized_boxes, resized_gts = pickle.load(f)

    # A fresh segmenter per split. Val and test both use image ids 0..n-1, so
    # sharing one instance could leak any per-id state between splits. Whether
    # MedSAMTool keeps such state is not visible here, so stay conservative.
    segmenter = MedSAMTool(gpu_id=gpu_id, checkpoint_path=checkpoint_path)

    images = ImageData(
        raw=resized_imgs,
        batch_size=min(8, len(resized_imgs)),
        image_ids=list(range(len(resized_imgs))),
        masks=resized_gts,
        # NOTE(review): ground-truth masks are also passed as predicted_masks;
        # presumably a placeholder in baseline mode — confirm.
        predicted_masks=resized_gts,
    )
    # Only the dermoscopy modality is fed to the model as RGB.
    is_rgb = modality == "dermoscopy"
    pred_masks = segmenter.predict(images, resized_boxes, used_for_baseline=True, is_rgb=is_rgb)
    metrics_dict = segmenter.evaluate(pred_masks, resized_gts)

    # float() keeps the value JSON-serialisable if evaluate returns numpy/tensor scalars.
    combined = float(metrics_dict['dsc_metric'] + metrics_dict['nsd_metric'])
    print(f"Combined metric: {combined:.4f}")

    print(f"Finished running {exp_type} baseline in {time.time() - baseline_start_time:.2f} seconds")
    return combined

In [None]:
from contextlib import redirect_stdout

# Send every print() from the baseline runs to a log file rather than the
# notebook output. Uncaught exception tracebacks still go to stderr.
with open("../data/get_baselines_output.log", "w") as f:
    with redirect_stdout(f):
        for modality in ("dermoscopy", "xray"):
            get_baseline(modality)