### “Zero-Shot” General Purpose Segmentation Models Have Comparable Efficacy to Trained Models: An Analysis of the Meta “Segment Anything Model” on Meningioma MRI

#### Requirements:
```
PyTorch
Segment-Anything
Roboflow
Supervision
Wget
SimpleITK
NumPy
OpenCV (opencv-python)
```

#### Install requirements

In [None]:
!pip install torch
!pip install 'git+https://github.com/facebookresearch/segment-anything.git'
!pip install -q roboflow supervision
!pip install wget
!wget -q 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
!pip install SimpleITK

In [None]:
import os
import SimpleITK as sitk
import numpy as np
import cv2
import torch
from segment_anything import SamPredictor, sam_model_registry

# Path to the SAM ViT-H weights. The install cell downloads
# ``sam_vit_h_4b8939.pth`` into the working directory; change this if the
# checkpoint lives somewhere else. (An empty string makes the model builder
# try to open a file named "" and fail.)
CHECKPOINT = "sam_vit_h_4b8939.pth"  # checkpoint used for model
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MODEL_TYPE = "vit_h"  # must match the architecture of the checkpoint above

# Fail early with an actionable message instead of an error from deep inside
# the model builder.
if not os.path.isfile(CHECKPOINT):
    raise FileNotFoundError(
        "SAM checkpoint not found at %r; download it or update CHECKPOINT"
        % CHECKPOINT
    )

sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT)
sam.to(device=DEVICE)
# Module-level predictor shared by every segmentation function below.
predictor = SamPredictor(sam)
print(DEVICE)

os.makedirs("./segments/", exist_ok=True)

base_path = "/meningioma-segments-chunked/"  # path to the dataset


def read_dataset(start_path):
    """Collect every ``.nii`` file under *start_path* and the patient IDs they cover.

    The patient ID is taken to be the third hyphen-separated field of the file
    name (for example ``AAA-BBB-00004-000-t1c.nii`` yields ``"00004"``).

    Args:
        start_path: Directory that is walked recursively.

    Returns:
        A tuple ``(all_patient_ids, all_files)`` where ``all_patient_ids`` is a
        sorted list of unique ID strings and ``all_files`` is the list of full
        paths to the ``.nii`` files, in ``os.walk`` order.

    Raises:
        IndexError: If a ``.nii`` file name has fewer than three
            hyphen-separated fields.
    """
    all_files = []
    for root, _, names in os.walk(start_path):
        all_files.extend(
            os.path.join(root, name) for name in names if name.endswith(".nii")
        )

    # Use basename rather than splitting on "/" so the ID is taken from the
    # file name whatever path separator os.path.join produced.
    all_patient_ids = sorted(
        {os.path.basename(path).split("-")[2] for path in all_files}
    )
    return all_patient_ids, all_files


# Index the dataset once when the cell runs; get_imgs_for_patient() filters
# the module-level `all_files` list built here.
all_patient_ids, all_files = read_dataset(base_path)


def get_imgs_for_patient(patient_id):
    """Load the five volumes (four MRI modalities plus segmentation) of a patient.

    Files are selected from the module-level ``all_files`` list by matching the
    third-from-last hyphen-separated field of the path against *patient_id*.
    If more than five files match, only the first five are used.

    Args:
        patient_id: Patient ID string as produced by ``read_dataset``.

    Returns:
        Dict mapping modality name (``"t1n"``, ``"t1c"``, ``"t2f"``, ``"t2w"``,
        ``"seg"``) to a ``uint8`` NumPy array. The segmentation keeps its label
        values; every other modality is min-max rescaled to 0..255.

    Raises:
        AssertionError: If five files are not available, a modality name is
            unexpected, or an image is not 240 x 240 x 155.
    """
    files = [f for f in all_files if f.split("-")[-3] == patient_id][:5]
    assert len(files) == 5

    imgs = {}
    for file in files:
        # File names end in "-<modality>.nii".
        modality = file.split("-")[-1].split(".")[0]
        assert modality in ["t1n", "t1c", "t2f", "t2w", "seg"]

        img = sitk.ReadImage(file)
        assert img is not None
        assert img.GetSize() == (240, 240, 155)
        # NOTE(review): SimpleITK returns the array with the axis order
        # reversed relative to GetSize(), i.e. shape (155, 240, 240) - confirm
        # if axis naming matters downstream.
        img = sitk.GetArrayFromImage(img)

        if modality == "seg":
            img = img.astype(np.uint8)
        else:
            i_min, i_max = np.min(img), np.max(img)
            if i_max > i_min:
                img = (img - i_min) / (i_max - i_min) * 255
                img = np.round(img).astype(np.uint8)
            else:
                # A constant volume has zero intensity range; rescaling would
                # divide by zero and cast NaN to uint8. Map it to all zeros.
                img = np.zeros(img.shape, dtype=np.uint8)

        imgs[modality] = img

    return imgs


def download_object(obj, f_name):
    """Persist *obj* under the name *f_name* (placeholder, currently a no-op).

    Replace the body with whatever storage call fits your architecture, e.g.
    pickling to the local file system or uploading to a bucket.
    """


def get_bounding_box(seg):
    """Return the inclusive index bounds of the enhancing tumor (label 3).

    Args:
        seg: 3-D label array.

    Returns:
        ``(min_x, min_y, min_z, max_x, max_y, max_z)`` where x, y and z refer
        to array axes 0, 1 and 2 respectively. Both ends are inclusive: the
        ``max_*`` indices are themselves tumor voxels.

    Raises:
        ValueError: If *seg* contains no voxel labelled 3.
    """
    tumor_voxels = np.argwhere(seg == 3)  # just the enhancing bit
    if tumor_voxels.shape[0] == 0:
        # Without this the reduction below fails with an opaque
        # "zero-size array" message.
        raise ValueError("segmentation contains no enhancing tumor (label 3)")

    min_x, min_y, min_z = tumor_voxels.min(axis=0)
    max_x, max_y, max_z = tumor_voxels.max(axis=0)

    return min_x, min_y, min_z, max_x, max_y, max_z


def get_set_of_slices_that_need_segmentation_and_corresponding_ground_truth(imgs):
    """Slice every volume along the axis where the enhancing tumor is thinnest.

    The bounding box of label 3 in ``imgs["seg"]`` is computed, the array axis
    with the smallest extent is chosen, and each volume in *imgs* is cut into
    2-D slices covering the tumor along that axis.

    Args:
        imgs: Dict mapping modality name to a 3-D array; must contain ``"seg"``.

    Returns:
        Dict mapping each modality name to a list of 2-D array views, one per
        slice index from the bounding-box minimum to its maximum inclusive.
    """
    min_x, min_y, min_z, max_x, max_y, max_z = get_bounding_box(imgs["seg"])
    lows = (min_x, min_y, min_z)
    highs = (max_x, max_y, max_z)
    axis = int(np.argmin([high - low for low, high in zip(lows, highs)]))

    # The bounding-box maximum is itself a tumor slice, so the range must
    # include it; range(low, high) would drop the last slice and return
    # nothing at all for a tumor that spans a single slice.
    first, last = lows[axis], highs[axis]

    img_slices = {}
    for modality, img in imgs.items():
        prefix = (slice(None),) * axis  # keep every index on the leading axes
        img_slices[modality] = [
            img[prefix + (i,)] for i in range(first, last + 1)
        ]

    return img_slices


def get_pos_neg_points_from_ground_truth(seg_slice, img_slice, num_points):
    """Sample prompt points for each tumor compartment from a labelled slice.

    For every compartment (``"nonenh"`` = label 1, ``"edema"`` = label 2,
    ``"enh"`` = label 3) positive points are drawn from inside the compartment
    and negative points from the two other compartments.

    Args:
        seg_slice: 2-D label array.
        img_slice: Unused; kept for interface compatibility.
        num_points: Number of points to draw per compartment.

    Returns:
        ``(pos_points, neg_points)``: two dicts keyed by compartment name whose
        values are integer arrays of ``(x, y)`` = ``(column, row)`` pairs.
        A compartment with fewer than *num_points* pixels contributes the
        single placeholder point ``[0, 0]`` instead.
    """
    compartments = {
        "nonenh": np.where(seg_slice == 1),
        "edema": np.where(seg_slice == 2),
        "enh": np.where(seg_slice == 3),
    }
    # Compartments that supply the negative points for each target, in the
    # order they are sampled.
    negatives_for = {
        "nonenh": ("enh", "edema"),
        "edema": ("enh", "nonenh"),
        "enh": ("nonenh", "edema"),
    }

    def randomly_pick_n_points(n, compartment):
        rows, cols = compartment[0], compartment[1]
        try:
            chosen = np.random.choice(rows.shape[0], n, replace=False)
        except ValueError:
            # Raised when the compartment is empty or has fewer than n pixels;
            # fall back to the placeholder point.
            return np.array([[0, 0]])
        return np.array([cols[chosen], rows[chosen]]).T

    pos_points = {
        name: randomly_pick_n_points(num_points, compartments[name])
        for name in ("nonenh", "edema", "enh")
    }

    neg_points = {
        name: np.vstack(
            [
                randomly_pick_n_points(num_points, compartments[other])
                for other in negatives_for[name]
            ]
        )
        for name in ("nonenh", "edema", "enh")
    }

    return pos_points, neg_points


def sample_points_from_mask(mask, num_points):
    """Draw up to *num_points* distinct foreground pixels from *mask*.

    Foreground means ``mask > 0.5``. When the mask has fewer foreground pixels
    than requested, all of them are returned (in random order).

    Returns:
        ``{"enh": points}`` where ``points`` holds one ``(x, y)`` =
        ``(column, row)`` pair per sampled pixel.
    """
    foreground = np.nonzero(mask > 0.5)
    available = foreground[0].shape[0]
    wanted = available if available < num_points else num_points
    picked = np.random.choice(available, wanted, replace=False)
    xs = foreground[1][picked]
    ys = foreground[0][picked]
    return {"enh": np.column_stack((xs, ys))}


def segment_single_slice(
    patient_id, slice_index, img_slice, seg_slice, num_points=1, mask_guess=None
):
    """Segment the enhancing tumor on one 2-D slice with SAM and score it.

    Prompts are built in one of two ways:

    * ``mask_guess`` given: the guess is binarised at 0.5; its bounding box is
      used as the box prompt, positive points are sampled inside it and
      negative points anywhere outside it.
    * ``mask_guess`` is ``None``: points are sampled from the ground-truth
      compartments of *seg_slice* and no box prompt is used.

    Args:
        patient_id: Unused here; accepted so callers can pass provenance.
        slice_index: Unused here; accepted so callers can pass provenance.
        img_slice: 2-D single-channel image, converted with
            ``cv2.COLOR_GRAY2BGR`` before being handed to the predictor.
        seg_slice: 2-D label array (1 = non-enhancing, 2 = edema,
            3 = enhancing).
        num_points: Number of prompt points to sample per group.
        mask_guess: Optional prior mask. Either 2-D, or ``(1, H, W)`` as
            obtained from a previous result's ``"mask"`` entry.

    Returns:
        Dict with keys

        * ``"mask"``: the predictor's mask converted with ``.tolist()``.
        * ``"score"``: dict with ``"dice"``, ``"pred"`` (predicted pixel
          count), ``"truth"`` (ground-truth pixel count), ``"inter"``
          (overlap) and ``"union"`` (``pred + truth``, i.e. the Dice
          denominator rather than a set union).
        * ``"compartment_sizes"``: pixel counts of the three labels in
          *seg_slice*.

    Raises:
        ValueError: If *mask_guess* is given but has no pixel above 0.5.
    """
    label_of = {"nonenh": 1, "edema": 2, "enh": 3}

    def get_bounding_box_from_mask(mask):
        # Returns [x_min, y_min, x_max, y_max] with x = column, y = row.
        rows, cols = np.where(mask > 0.5)
        return np.array([np.min(cols), np.min(rows), np.max(cols), np.max(rows)])

    input_box = None
    if mask_guess is not None:
        mask_guess = np.asarray(mask_guess)
        # NOTE(review): SamPredictor.predict documents its masks as (C, H, W);
        # with multimask_output=False that is (1, H, W). Drop the singleton
        # axis so rows/columns are read from the right dimensions.
        if mask_guess.ndim == 3 and mask_guess.shape[0] == 1:
            mask_guess = mask_guess[0]
        mask_guess = np.where(mask_guess > 0.5, 1, 0)
        if not mask_guess.any():
            raise ValueError("mask_guess has no foreground pixel to prompt from")
        input_box = get_bounding_box_from_mask(mask_guess)[None, :]
        pos_points = sample_points_from_mask(mask_guess, num_points)
        neg_points = sample_points_from_mask(1 - mask_guess, num_points)
    else:
        pos_points, neg_points = get_pos_neg_points_from_ground_truth(
            seg_slice, img_slice, num_points
        )
    predictor.set_image(cv2.cvtColor(img_slice, cv2.COLOR_GRAY2BGR))

    def dice_score_components(pred_mask, ground_truth):
        pred_mask = pred_mask > 0.5
        ground_truth = ground_truth > 0
        intersection = np.sum(pred_mask * ground_truth)
        union = np.sum(pred_mask) + np.sum(ground_truth)
        return {
            "dice": 2 * intersection / union if union > 0 else 0,
            "pred": np.sum(pred_mask),
            "truth": np.sum(ground_truth),
            "inter": intersection,
            "union": union,
        }

    def segment_compartment_on_single_slice(compartment_name):
        positives = pos_points[compartment_name]
        negatives = neg_points[compartment_name]
        input_points = np.vstack([positives, negatives])
        # Label 1 marks a positive prompt point, 0 a negative one.
        input_labels = np.array([1] * len(positives) + [0] * len(negatives))

        mask, _, _ = predictor.predict(
            point_coords=input_points,
            point_labels=input_labels,
            mask_input=None,
            multimask_output=False,
            box=input_box,
        )

        val = label_of.get(compartment_name, 3)
        ground_truth = np.where(seg_slice == val, 255, 0).astype(np.uint8)

        # Return the full component dict: callers accumulate "inter" and
        # "union" across slices and cannot do so from the dice value alone.
        return {
            "mask": mask.tolist(),
            "score": dice_score_components(mask, ground_truth),
        }

    segmentation_results = segment_compartment_on_single_slice("enh")
    return {
        "mask": segmentation_results["mask"],
        "score": segmentation_results["score"],
        "compartment_sizes": {
            name: np.sum(np.where(seg_slice == label, 1, 0))
            for name, label in label_of.items()
        },
    }


def segment_entire_tumor(patient_id, num_points):
    """Segment every tumor slice of a patient independently and store the results.

    Each slice of the T1c volume is segmented with prompts taken from that
    slice's ground truth. The per-slice result and a per-patient summary are
    handed to ``download_object``.

    Args:
        patient_id: Patient ID string as produced by ``read_dataset``.
        num_points: Number of prompt points per group for each slice.

    Returns:
        None. The summary passed to ``download_object`` holds
        ``"volumetric_dice"``, ``"inter"``, ``"union"`` and ``"true_sizes"``.
    """
    imgs = get_imgs_for_patient(patient_id)
    img_slices = (
        get_set_of_slices_that_need_segmentation_and_corresponding_ground_truth(imgs)
    )
    img_seg = imgs["seg"]
    total_true_size = {
        "enh": np.sum(np.where(img_seg == 3, 1, 0)),
        "nonenh": np.sum(np.where(img_seg == 1, 1, 0)),
        "edema": np.sum(np.where(img_seg == 2, 1, 0)),
    }

    total_enh_intersection = 0
    total_enh_union = 0
    slice_pairs = zip(img_slices["t1c"], img_slices["seg"])
    for slice_index, (img_slice, seg_slice) in enumerate(slice_pairs):
        slice_performance = segment_single_slice(
            patient_id, slice_index, img_slice, seg_slice, num_points=num_points
        )

        slice_segment_filename = "/base_segment_path/%s/slice%s_%spoints.pkl" % (
            patient_id,
            slice_index,
            num_points,
        )
        download_object(slice_performance, slice_segment_filename)

        total_enh_intersection += slice_performance["score"]["inter"]
        total_enh_union += slice_performance["score"]["union"]

    # Dice over the whole volume from the summed per-slice components; report
    # 0 when nothing was predicted and nothing was labelled (or there were no
    # slices) instead of dividing by zero.
    volumetric_dice = (
        2 * total_enh_intersection / total_enh_union if total_enh_union > 0 else 0
    )

    patient_summary_totals = {
        "volumetric_dice": volumetric_dice,
        "inter": total_enh_intersection,
        "union": total_enh_union,
        "true_sizes": total_true_size,
    }
    patient_summary_filename = "/base_segment_path/%s/%s_%spoints.pkl" % (
        patient_id,
        "total",
        num_points,
    )
    download_object(patient_summary_totals, patient_summary_filename)


def segment_entire_tumor_using_slice_carry_through(patient_id, num_points):
    """Segment a tumor slice by slice, prompting each slice with the previous mask.

    The slice with the largest enhancing (label 3) area is the seed: its
    ground-truth mask is the first prompt. The tumor is then traversed from
    the seed to the last slice and again from the seed backwards to the first
    slice, each time feeding the mask predicted for one slice as the guess for
    the next. Per-slice results and a per-patient summary are handed to
    ``download_object``.

    Args:
        patient_id: Patient ID string as produced by ``read_dataset``.
        num_points: Number of prompt points per group for each slice.

    Returns:
        None. The summary passed to ``download_object`` holds
        ``"volumetric_dice"``, ``"inter"``, ``"union"`` and ``"true_sizes"``.

    Raises:
        ValueError: If none of the selected slices contains enhancing tumor.
    """
    imgs = get_imgs_for_patient(patient_id)
    seg = imgs["seg"]
    total_true_size = {
        "enh": np.sum(np.where(seg == 3, 1, 0)),
        "nonenh": np.sum(np.where(seg == 1, 1, 0)),
        "edema": np.sum(np.where(seg == 2, 1, 0)),
    }

    # Same slicing as segment_entire_tumor: cut along the thinnest tumor axis.
    img_slices = (
        get_set_of_slices_that_need_segmentation_and_corresponding_ground_truth(imgs)
    )
    t1c_slices = img_slices["t1c"]
    seg_slices = img_slices["seg"]

    # Seed = first slice with the largest enhancing area.
    enh_sizes = [np.sum(np.where(s == 3, 1, 0)) for s in seg_slices]
    if not enh_sizes or max(enh_sizes) == 0:
        raise ValueError(
            "no slice with enhancing tumor found for patient %s" % patient_id
        )
    max_seg_index = int(np.argmax(enh_sizes))
    seed_mask = np.where(seg_slices[max_seg_index] == 3, 1, 0)

    totals = {"inter": 0, "union": 0}

    def carry_through(slice_indices):
        # Walk the given slice indices, chaining each prediction into the next.
        previous_mask = seed_mask
        for slice_index in slice_indices:
            seg_perf = segment_single_slice(
                patient_id,
                slice_index,
                t1c_slices[slice_index],
                seg_slices[slice_index],
                num_points=num_points,
                mask_guess=previous_mask,
            )
            slice_segment_filename = (
                "/base_segment_path/%s/slice%s_%spoints_carry.pkl"
                % (patient_id, slice_index, num_points)
            )
            download_object(seg_perf, slice_segment_filename)

            previous_mask = np.array(seg_perf["mask"])
            # NOTE(review): SAM masks are documented as (C, H, W); keep the
            # carried mask 2-D so it lines up with the next slice.
            if previous_mask.ndim == 3 and previous_mask.shape[0] == 1:
                previous_mask = previous_mask[0]

            totals["inter"] += seg_perf["score"]["inter"]
            totals["union"] += seg_perf["score"]["union"]

    # march through the tumor first forwards then backwards
    carry_through(range(max_seg_index, len(seg_slices)))
    carry_through(range(max_seg_index - 1, -1, -1))

    total_enh_intersection = totals["inter"]
    total_enh_union = totals["union"]

    # compute the volumetric dice score (0 rather than a division by zero
    # when there is nothing predicted and nothing labelled)
    volumetric_dice = (
        2 * total_enh_intersection / total_enh_union if total_enh_union > 0 else 0
    )

    patient_summary_totals = {
        "volumetric_dice": volumetric_dice,
        "inter": total_enh_intersection,
        "union": total_enh_union,
        "true_sizes": total_true_size,
    }
    patient_summary_filename = "/base_segment_path/%s/%s_%spoints_carry.pkl" % (
        patient_id,
        "total",
        num_points,
    )
    download_object(patient_summary_totals, patient_summary_filename)


def segment_entire_tumor_using_trislice_consensus_with_carrythrough(
    patient_id, num_points
):
    """Segment a tumor along all three array axes and score the voted consensus.

    For each direction (``"axi"`` = array axis 0, ``"cor"`` = axis 1,
    ``"sag"`` = axis 2) the tumor's bounding-box slices are segmented with the
    carry-through scheme: the slice with the largest enhancing area seeds the
    traversal with its ground-truth mask, and each predicted mask becomes the
    guess for the next slice, first forwards and then backwards from the seed.
    Every predicted slice casts one vote per voxel into a 3-D consensus
    volume. Volumetric Dice is reported for voxels with at least 1, 2 and 3
    votes. Per-slice results and the per-patient summary are handed to
    ``download_object``.

    Args:
        patient_id: Patient ID string as produced by ``read_dataset``.
        num_points: Number of prompt points per group for each slice.

    Returns:
        None. The summary passed to ``download_object`` holds
        ``"volumetric_dice_1"``, ``"volumetric_dice_2"``,
        ``"volumetric_dice_3"``, ``"true_sizes"`` and ``"bounding_box"``.
    """
    imgs = get_imgs_for_patient(patient_id)
    seg = imgs["seg"]
    total_true_size = {
        "enh": np.sum(np.where(seg == 3, 1, 0)),
        "nonenh": np.sum(np.where(seg == 1, 1, 0)),
        "edema": np.sum(np.where(seg == 2, 1, 0)),
    }

    min_x, min_y, min_z, max_x, max_y, max_z = get_bounding_box(seg)
    # direction -> (array axis, first tumor index, last tumor index)
    extents = {
        "axi": (0, min_x, max_x),
        "cor": (1, min_y, max_y),
        "sag": (2, min_z, max_z),
    }

    def plane(volume, axis, index):
        # 2-D view of `volume` at `index` along `axis` (no copy).
        return volume[(slice(None),) * axis + (index,)]

    # The bounding-box maxima are tumor slices themselves, so the ranges are
    # inclusive; leaving the last slice out would mean its voxels never
    # receive a vote from that direction.
    img_slices = {}
    for modality, img in imgs.items():
        img_slices[modality] = {
            direction: [plane(img, axis, i) for i in range(first, last + 1)]
            for direction, (axis, first, last) in extents.items()
        }

    # Per direction: index and ground-truth mask of the first slice with the
    # largest enhancing area.
    seeds = {}
    for direction in extents:
        seg_slices = img_slices["seg"][direction]
        enh_sizes = [np.sum(np.where(s == 3, 1, 0)) for s in seg_slices]
        seed_index = int(np.argmax(enh_sizes))
        seeds[direction] = {
            "max_seg_index": seed_index,
            "max_seg": np.where(seg_slices[seed_index] == 3, 1, 0),
        }

    # reconstruct the final segmentation by taking the consensus of the three
    # segmentations in 3D
    consensus = np.zeros_like(imgs["seg"], dtype="uint8")

    def add_vote(slice_segmentation, slice_index, direction):
        # Slice indices are relative to the bounding box; shift back into
        # volume coordinates before voting.
        axis, first, _ = extents[direction]
        target = plane(consensus, axis, slice_index + first)
        target += np.where(slice_segmentation > 0.5, 1, 0).astype(consensus.dtype)

    def carry_through(direction, slice_indices):
        previous_mask = seeds[direction]["max_seg"]
        for slice_index in slice_indices:
            seg_perf = segment_single_slice(
                patient_id,
                slice_index,
                img_slices["t1c"][direction][slice_index],
                img_slices["seg"][direction][slice_index],
                num_points=num_points,
                mask_guess=previous_mask,
            )

            slice_segment_filename = (
                "/base_segment_path/%s/%s_slice%s_%spoints_trislice_carry.pkl"
                % (patient_id, direction, slice_index, num_points)
            )
            download_object(seg_perf, slice_segment_filename)

            previous_mask = np.array(seg_perf["mask"])
            # NOTE(review): SAM masks are documented as (C, H, W); keep the
            # carried mask 2-D so it matches the consensus plane.
            if previous_mask.ndim == 3 and previous_mask.shape[0] == 1:
                previous_mask = previous_mask[0]
            add_vote(previous_mask, slice_index, direction)

    for direction in ["axi", "cor", "sag"]:
        seed_index = seeds[direction]["max_seg_index"]
        slice_count = len(img_slices["seg"][direction])
        # go through the slices forwards first, then backwards from the seed
        carry_through(direction, range(seed_index, slice_count))
        carry_through(direction, range(seed_index - 1, -1, -1))

    consensus_with_1 = np.where(consensus >= 1, 1, 0)
    consensus_with_2 = np.where(consensus >= 2, 1, 0)
    consensus_with_3 = np.where(consensus >= 3, 1, 0)

    bounding_box = {
        "min_x": min_x,
        "min_y": min_y,
        "min_z": min_z,
        "max_x": max_x,
        "max_y": max_y,
        "max_z": max_z,
    }

    ground_truth = np.where(imgs["seg"] == 3, 1, 0)

    def compute_volumetric_dice(proposed_seg, real_seg):
        intersection = np.sum(proposed_seg * real_seg)
        union = np.sum(proposed_seg) + np.sum(real_seg)
        dice = 2 * intersection / union if union > 0 else 0
        return dice

    volumetric_dice_1 = compute_volumetric_dice(consensus_with_1, ground_truth)
    volumetric_dice_2 = compute_volumetric_dice(consensus_with_2, ground_truth)
    volumetric_dice_3 = compute_volumetric_dice(consensus_with_3, ground_truth)

    patient_summary_totals = {
        "volumetric_dice_1": volumetric_dice_1,
        "volumetric_dice_2": volumetric_dice_2,
        "volumetric_dice_3": volumetric_dice_3,
        "true_sizes": total_true_size,
        "bounding_box": bounding_box,
    }

    patient_summary_filename = (
        "/base_segment_path/%s/%s_%spoints_trislice_carry.pkl"
        % (patient_id, "total", num_points)
    )
    download_object(patient_summary_totals, patient_summary_filename)