In [None]:
import os
import sys
import importlib
import time
from contextlib import redirect_stdout
import json

import numpy as np
import pickle
from skimage import transform

# Put the parent of the current working directory on the import path so the
# `src.*` imports that follow can be resolved; skipped when already present.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
already_importable = project_root in sys.path
if not already_importable:
    sys.path.append(project_root)

import importlib
from src.medsam_segmentation import MedSAMTool
from src.data_io import ImageData

In [2]:
def _get_binary_masks(nonbinary_mask):
    """ 
    Given nonbinary mask which encodes N masks, return N binary masks which
    should encode the same information.
    
    Parameters:
        - nonbinary_mask: ndarray of shape (H, W)
    Returns:
        - binary_masks: ndarray of shape (N, H, W)
    """
    binary_masks = []
    for i in np.unique(nonbinary_mask)[1:]:
        binary_mask = (nonbinary_mask == i).astype(np.uint8)
        binary_masks.append(binary_mask.copy())
    binary_masks = np.stack(binary_masks, axis=0)
    return binary_masks

In [10]:
def get_data(modality, exp_type):
    """
    Sample (image, box, ground-truth mask) triples for one experiment, resize
    the images and boxes to 1024x1024, and pickle the result.

    The 2D files of the requested modality are split in half into a validation
    bank and a test bank using a fixed seed. Every (box, mask) pair of every
    file in the chosen bank becomes one candidate triple, and ``sample_size``
    triples are then drawn without replacement.

    Parameters:
        - modality: "xray" selects files whose name contains "X-Ray"; any other
          value selects files whose name contains "Dermoscopy".
        - exp_type: e.g. "val_filenames_25". A value starting with "val" uses
          the validation bank, anything else the test bank; the integer after
          the last "_" is the number of triples to sample.
    Returns:
        None. Writes ../data/resized_{modality}_{exp_type}.pkl (relative to the
        working directory) holding the tuple
        (resized_imgs, resized_boxes, resized_gts). Images are float arrays in
        [0, 1] of size 1024x1024, boxes are (1, 4) arrays scaled to that size,
        and ground-truth masks are stored without resizing.
    Raises:
        - ValueError: if more triples are requested than are available.
    """
    target_size = 1024

    query_name = "X-Ray" if modality == "xray" else "Dermoscopy"
    # os.listdir returns entries in arbitrary order; sort them so the seeded
    # split below selects the same files on every machine and every run.
    imgs_2d_and_3d = sorted(os.listdir(os.path.join(os.getcwd(), "../data/imgs")))
    imgs_2d = [f for f in imgs_2d_and_3d if f.startswith('2D')]  # author's count: 2803
    imgs_2d_modality = [f for f in imgs_2d if query_name in f]   # author's count: 50
    print(f"{len(imgs_2d_modality)} images in {modality} modality")

    # Re-seeding on every call keeps the val/test split identical across all
    # experiments of the same modality.
    np.random.seed(42)
    split_resulting_len = len(imgs_2d_modality) // 2
    val_filenames_bank = np.random.choice(
        imgs_2d_modality, size=split_resulting_len, replace=False
    ).tolist()
    val_filenames = set(val_filenames_bank)
    test_filenames_bank = [f for f in imgs_2d_modality if f not in val_filenames]

    file_str = f"{modality}_{exp_type}"
    print("Starting experiment", file_str)

    filebank = val_filenames_bank if exp_type.startswith("val") else test_filenames_bank
    sample_size = int(exp_type.split("_")[-1])

    # ========================== Unpack ==========================
    print(f"\nUnpacking {len(filebank)} images...")
    raw_imgs, raw_boxes, raw_gts = [], [], []
    for i, img_filename in enumerate(filebank):
        img_data = np.load(os.path.join(os.getcwd(), f"../data/imgs/{img_filename}"))
        mask_data = np.load(os.path.join(os.getcwd(), f"../data/gts/{img_filename}"))

        image, boxes, nonbinary_mask = img_data['imgs'], img_data["boxes"], mask_data['gts']
        num_masks = 0
        # NOTE(review): boxes are paired positionally with the per-label masks
        # and zip stops at the shorter sequence — confirm that the stored boxes
        # follow the same (ascending label) order as the masks.
        for box, mask in zip(boxes, _get_binary_masks(nonbinary_mask)):
            # The same image object is referenced once per (box, mask) pair.
            raw_imgs.append(image)
            raw_boxes.append(box)
            raw_gts.append(mask)
            num_masks += 1

        print(f"Processed idx {i}: {img_filename} -> {num_masks} masks")
    print(f"Finished unpacking images into {len(raw_imgs)}.\n")

    # ========================== Resize ==========================
    print("Resizing images...")
    if sample_size > len(raw_imgs):
        raise ValueError(
            f"Requested {sample_size} samples but only {len(raw_imgs)} "
            f"(image, box, mask) triples are available for {file_str}"
        )
    # Drawn from the global RNG seeded above, so the sample is deterministic.
    sampled_indices = np.random.choice(len(raw_imgs), size=sample_size, replace=False)

    resized_imgs, resized_boxes, resized_gts = [], [], []
    for i, sample_idx in enumerate(sampled_indices):
        img_np = raw_imgs[sample_idx]
        resized_gts.append(raw_gts[sample_idx])

        # Grayscale images are replicated to three channels.
        if len(img_np.shape) == 2:
            img_3c = np.repeat(img_np[:, :, None], 3, axis=-1)
        else:
            img_3c = img_np

        H, W, _ = img_3c.shape

        # Resize image to 1024x1024, quantise to uint8, then scale to [0, 1].
        img_1024 = transform.resize(
            img_3c, (target_size, target_size), order=3, preserve_range=True, anti_aliasing=True
        ).astype(np.uint8)
        resized_imgs.append(img_1024 / 255.0)

        # Scale box to 1024x1024. The numeric box is used directly (no string
        # round-trip), so non-integer coordinates are handled as well.
        x1, y1, x2, y2 = raw_boxes[sample_idx]
        box_str = f"[{x1},{y1},{x2},{y2}]"
        box_np = np.array([[x1, y1, x2, y2]], dtype=np.float64)
        box_scaled = box_np / np.array([W, H, W, H]) * target_size
        resized_boxes.append(box_scaled)

        print(f"file {i} | og img shape {img_np.shape} | box_str {box_str} -> {box_scaled.shape}")

    resized_filepath = os.path.join(os.getcwd(), f"../data/resized_{file_str}.pkl")
    # Make sure the output directory exists; a no-op when it already does.
    os.makedirs(os.path.dirname(resized_filepath), exist_ok=True)
    with open(resized_filepath, "wb") as f:
        pickle.dump((resized_imgs, resized_boxes, resized_gts), f)
    print(f"Saved resized data to {resized_filepath}\n")

In [None]:
# This cell takes 10 mins to run
# Build every resized sample file. Everything get_data prints is redirected
# into ../data/output.log instead of the notebook output; redirect_stdout comes
# from the import cell at the top of the notebook.
with open("../data/output.log", "w") as f, redirect_stdout(f):
    for modality in ["xray", "dermoscopy"]:
        for exp_type in ["val_filenames_5", "test_filenames_25", "val_filenames_25"]:
            get_data(modality, exp_type)

In [None]:
def get_baseline(modality, gpu_id=3, checkpoint_path="../data/medsam_vit_b.pth"):
    """
    Run the MedSAM baseline on the 25-sample val and test sets of a modality
    and save the combined metrics to JSON.

    For each of the "val" and "test" splits this loads
    ../data/resized_{modality}_{split}_filenames_25.pkl, predicts masks with
    MedSAMTool, evaluates them against the ground truth, and records
    ``dsc_metric + nsd_metric``.

    Parameters:
        - modality: modality name used in the file names; "dermoscopy" makes
          the predictor treat the images as RGB (``is_rgb=True``).
        - gpu_id: GPU index passed to MedSAMTool (defaults to the previously
          hard-coded value 3).
        - checkpoint_path: MedSAM checkpoint passed to MedSAMTool.
    Returns:
        None. Writes ../data/{modality}_baseline_expert.json with the keys
        "expert_baseline_val_avg_metric" and "expert_baseline_test_avg_metric";
        each value is the sum dsc_metric + nsd_metric for that split.
    """
    combined_metrics = {"val": None, "test": None}
    for exp_type in ["val", "test"]:
        np.random.seed(42)

        print(f"\n================== Running baseline for {modality}, {exp_type} set ==================")

        resized_filepath = f"../data/resized_{modality}_{exp_type}_filenames_25.pkl"

        # ================ Get baseline ================
        baseline_start_time = time.time()
        print(f"Reading from {resized_filepath}")
        # pickle.load can execute arbitrary code: only load files that were
        # produced locally by get_data, never pickles from an untrusted source.
        with open(resized_filepath, "rb") as f:
            resized_imgs, resized_boxes, resized_gts = pickle.load(f)

        # Re-created for every split, as before.
        # NOTE(review): could be hoisted out of the loop if MedSAMTool keeps no
        # state between predict calls — confirm before changing.
        segmenter = MedSAMTool(gpu_id=gpu_id, checkpoint_path=checkpoint_path)

        images = ImageData(
            raw=resized_imgs,
            batch_size=min(8, len(resized_imgs)),
            image_ids=list(range(len(resized_imgs))),
            masks=resized_gts,
            # NOTE(review): the ground truth is also passed as predicted_masks;
            # presumably a placeholder that predict() overwrites — confirm.
            predicted_masks=resized_gts,
        )
        is_rgb = modality == "dermoscopy"
        pred_masks = segmenter.predict(images, resized_boxes, used_for_baseline=True, is_rgb=is_rgb)
        metrics_dict = segmenter.evaluate(pred_masks, resized_gts)

        # float() turns numpy scalars into plain Python floats so that
        # json.dump below cannot fail on a non-serialisable scalar type.
        combined = float(metrics_dict['dsc_metric'] + metrics_dict['nsd_metric'])
        combined_metrics[exp_type] = combined
        print(f"Combined metric: {combined:.4f}")

        print(f"Finished running {exp_type} baseline in {time.time() - baseline_start_time:.2f} seconds")

    # Save the metrics to a file
    metrics_filepath = f"../data/{modality}_baseline_expert.json"
    with open(metrics_filepath, "w") as f:
        json_output = {
            "expert_baseline_val_avg_metric": combined_metrics["val"],
            "expert_baseline_test_avg_metric": combined_metrics["test"],
        }
        json.dump(json_output, f)

In [None]:
# Run the baseline for both modalities, sending everything printed along the
# way to a log file rather than to the notebook output.
with open("../data/get_baselines_output.log", "w") as log_file:
    with redirect_stdout(log_file):
        for modality in ("dermoscopy", "xray"):
            get_baseline(modality)