In [None]:
# -*- coding: utf-8 -*-
"""
Inference of the MRI-SAM on the nifti datasets.
"""

import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor
from utils.visual import *
from utils.utils import *
from utils.dataloader import NiiDataset
from utils.prompt import *
from tqdm import tqdm
from utils.losses import dice_similarity

In [None]:
# Pick the backbone by key: one of 'vit_b', 'vit_h' or 'med_sam'.
# The checkpoint path is looked up with the resulting encoder type.
encoder_type = ENCODER_TYPE['vit_b']
checkpoint = SAM_CHECKPOINT[encoder_type]
device = DEVICE

# Build the model from the registry, move it to the target device,
# and wrap it in a SamPredictor for prompt-based inference.
mri_sam_model = sam_model_registry[encoder_type](checkpoint).to(device)
predictor = SamPredictor(mri_sam_model)

# Load the test split from TEST_IMAGE_PATH.
file_path = TEST_IMAGE_PATH
test_dataset = NiiDataset(file_path)


In [None]:
# setup essential parameters.
num_points = NUM_POINTS
num_bboxes = NUM_BBOXES  # not used in this cell
jitter = JITTER
record = {}
save_path = SAVE_PATH  # not used in this cell

# For every volume, compute per-label Dice scores for three prompt strategies:
#   "p"  - a single positive point from gen_points(mask),
#   "b"  - one bounding box from gen_bboxes(mask, jitter=jitter),
#   "mp" - gen_points(mask, num_points=num_points) positive points.
# Layout: record[name][key] is a one-element list wrapping the per-label Dice
# list, i.e. record[name][key][0][label_idx]. The summary cells rely on this.
# NOTE(review): if gen_points / gen_bboxes sample randomly, seed the RNG before
# this loop so results are reproducible. TODO confirm what utils.prompt uses.
for image, mask in tqdm(test_dataset):
    # Image embedding inference (once per image, shared by all prompts below).
    predictor.set_image(image)

    name = test_dataset.get_name()
    # Split the multi-labeled mask into one mask per label.
    masks = MaskSplit(mask)

    dice_by_prompt = {"p": [], "b": [], "mp": []}

    for each_label_mask in masks:  # presumably (H, W); original note said (255, 255)
        # generate prompts
        point = gen_points(each_label_mask)
        points = gen_points(each_label_mask, num_points=num_points)
        bbox = gen_bboxes(each_label_mask, jitter=jitter)

        # All points are foreground (label 1). Labels are sized from the points
        # actually returned, so coordinates and labels always have equal length.
        prompt_kwargs = {
            "p": dict(point_coords=point,
                      point_labels=np.ones(len(point), dtype=int)),
            "b": dict(point_coords=None,
                      point_labels=None,
                      box=bbox[None, :]),
            "mp": dict(point_coords=points,
                       point_labels=np.ones(len(points), dtype=int)),
        }

        # Predict in the order p, b, mp (dict insertion order).
        for key, kwargs in prompt_kwargs.items():
            pred_masks, _, _ = predictor.predict(multimask_output=False, **kwargs)
            # With multimask_output=False the single predicted mask is at index 0.
            dice_by_prompt[key].append(
                dice_similarity(each_label_mask, pred_masks[0, :, :])
            )

    record[name] = {key: [scores] for key, scores in dice_by_prompt.items()}


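# SAM ViT-B results
These cells summarise `record` as left by the inference cell above. They show ViT-B scores only if that cell was last run with `ENCODER_TYPE['vit_b']`.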

In [None]:
# Regroup Dice scores by label: <prompt>_record[i] collects label i's Dice
# across all volumes in `record` (as produced by the inference cell).
# The label count follows LABEL_LIST, which also labels the boxplots below.
n_labels = len(LABEL_LIST)
p_record = [[] for _ in range(n_labels)]
mp_record = [[] for _ in range(n_labels)]
b_record = [[] for _ in range(n_labels)]
for scores in record.values():
    for i in range(n_labels):
        p_record[i].append(scores['p'][0][i])
        b_record[i].append(scores['b'][0][i])
        mp_record[i].append(scores['mp'][0][i])

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(p_record)
# Name the boxes through the axis: boxplot(labels=...) was deprecated in
# Matplotlib 3.9 (renamed tick_labels=); this form works on all versions.
ax.set_xticks(range(1, len(p_record) + 1))
ax.set_xticklabels(LABEL_LIST)
ax.set_title("Point prompt, SAM vit-b")
ax.set_ylabel("Dice similarity")
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(mp_record)
# Name the boxes through the axis: boxplot(labels=...) was deprecated in
# Matplotlib 3.9 (renamed tick_labels=); this form works on all versions.
ax.set_xticks(range(1, len(mp_record) + 1))
ax.set_xticklabels(LABEL_LIST)
ax.set_title("Multi-points prompt, SAM vit-b")
ax.set_ylabel("Dice similarity")
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(b_record)
# Name the boxes through the axis: boxplot(labels=...) was deprecated in
# Matplotlib 3.9 (renamed tick_labels=); this form works on all versions.
ax.set_xticks(range(1, len(b_record) + 1))
ax.set_xticklabels(LABEL_LIST)
ax.set_title("Box prompt, SAM vit-b")
ax.set_ylabel("Dice similarity")
plt.show()

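# SAM ViT-H results
These cells summarise `record` as left by the inference cell above. They show ViT-H scores only if that cell was last run with `ENCODER_TYPE['vit_h']`.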

In [None]:
# Regroup Dice scores by label: <prompt>_record[i] collects label i's Dice
# across all volumes in `record` (as produced by the inference cell).
# The label count follows LABEL_LIST, which also labels the boxplots below.
n_labels = len(LABEL_LIST)
p_record = [[] for _ in range(n_labels)]
mp_record = [[] for _ in range(n_labels)]
b_record = [[] for _ in range(n_labels)]
for scores in record.values():
    for i in range(n_labels):
        p_record[i].append(scores['p'][0][i])
        b_record[i].append(scores['b'][0][i])
        mp_record[i].append(scores['mp'][0][i])

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(p_record)
# Name the boxes through the axis: boxplot(labels=...) was deprecated in
# Matplotlib 3.9 (renamed tick_labels=); this form works on all versions.
ax.set_xticks(range(1, len(p_record) + 1))
ax.set_xticklabels(LABEL_LIST)
ax.set_title("Point prompt, SAM vit-h")
ax.set_ylabel("Dice similarity")
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(mp_record)
# Name the boxes through the axis: boxplot(labels=...) was deprecated in
# Matplotlib 3.9 (renamed tick_labels=); this form works on all versions.
ax.set_xticks(range(1, len(mp_record) + 1))
ax.set_xticklabels(LABEL_LIST)
ax.set_title("Multi-points prompt, SAM vit-h")
ax.set_ylabel("Dice similarity")
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(b_record)
# Name the boxes through the axis: boxplot(labels=...) was deprecated in
# Matplotlib 3.9 (renamed tick_labels=); this form works on all versions.
ax.set_xticks(range(1, len(b_record) + 1))
ax.set_xticklabels(LABEL_LIST)
ax.set_title("Box prompt, SAM vit-h")
ax.set_ylabel("Dice similarity")
plt.show()

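# MedSAM results
These cells summarise `record` as left by the inference cell above. They show MedSAM scores only if that cell was last run with `ENCODER_TYPE['med_sam']`.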

In [None]:
# Regroup Dice scores by label: <prompt>_record[i] collects label i's Dice
# across all volumes in `record` (as produced by the inference cell).
# The label count follows LABEL_LIST, which also labels the boxplots below.
n_labels = len(LABEL_LIST)
p_record = [[] for _ in range(n_labels)]
mp_record = [[] for _ in range(n_labels)]
b_record = [[] for _ in range(n_labels)]
for scores in record.values():
    for i in range(n_labels):
        p_record[i].append(scores['p'][0][i])
        b_record[i].append(scores['b'][0][i])
        mp_record[i].append(scores['mp'][0][i])

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(p_record)
# Name the boxes through the axis: boxplot(labels=...) was deprecated in
# Matplotlib 3.9 (renamed tick_labels=); this form works on all versions.
ax.set_xticks(range(1, len(p_record) + 1))
ax.set_xticklabels(LABEL_LIST)
ax.set_title("Point prompt, MedSAM vit-b")
ax.set_ylabel("Dice similarity")
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(mp_record)
# Name the boxes through the axis: boxplot(labels=...) was deprecated in
# Matplotlib 3.9 (renamed tick_labels=); this form works on all versions.
ax.set_xticks(range(1, len(mp_record) + 1))
ax.set_xticklabels(LABEL_LIST)
ax.set_title("Multi-points prompt, MedSAM vit-b")
ax.set_ylabel("Dice similarity")
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(b_record)
# Name the boxes through the axis: boxplot(labels=...) was deprecated in
# Matplotlib 3.9 (renamed tick_labels=); this form works on all versions.
ax.set_xticks(range(1, len(b_record) + 1))
ax.set_xticklabels(LABEL_LIST)
ax.set_title("Box prompt, MedSAM vit-b")
ax.set_ylabel("Dice similarity")
plt.show()