In [None]:
# -*- coding: utf-8 -*-
"""
Inference of the SAMRI on the nifti datasets.
"""

import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor
from utils.visual import *
from utils.utils import *
from utils.dataloader import NiiDataset
from utils.prompt import *
from tqdm import tqdm
from utils.losses import dice_similarity

In [None]:
# Evaluate on a single source: only the first entry of TEST_IMAGE_PATH is used.
file_path = [TEST_IMAGE_PATH[0]]
# NOTE(review): multi_mask is switched off here; its exact effect is defined by
# NiiDataset (utils.dataloader) — confirm there.
test_dataset = NiiDataset(file_path, multi_mask= False)
# Last expression of the cell: shows the number of items in the dataset.
len(test_dataset)

In [None]:
# Prompt settings read by the inference loops of every model section below.
# num_points is handed to gen_points(mask, num_points=...).
num_points = 1
# NOTE(review): num_bboxes is not read by any cell visible in this notebook.
num_bboxes = 1
# jitter is forwarded to gen_bboxes(mask, jitter=...).
jitter = JITTER

# SAMRI

In [None]:
# Model variant evaluated in this section (one of: vit_b, vit_h, samri, med_sam).
model_type = 'samri'
device = DEVICE
encoder_tpye = ENCODER_TYPE[model_type]

# The registry default is bypassed in favour of an explicit checkpoint file.
# checkpoint = SAM_CHECKPOINT[model_type]
checkpoint = "/scratch/project/samri//Model_save/mult/samri_vitb_mult_45.pth"

# Build the network from the checkpoint, place it on the target device, put it
# in evaluation mode and wrap it in a predictor.
mri_sam_model = sam_model_registry[encoder_tpye](checkpoint).to(device)
mri_sam_model.eval()
predictor = SamPredictor(mri_sam_model)

In [None]:
# Dice scores for SAMRI, one entry per dataset item:
# p_record_samri -> point-prompt predictions, b_record_samri -> box-prompt predictions.
p_record_samri = []
b_record_samri = []

for image, mask in tqdm(test_dataset):
    # Compute the image embedding; both predict() calls below reuse it.
    predictor.set_image(image)

    # NOTE(review): `name` is not used for scoring in this cell.
    name = test_dataset.get_name()
    # Keep index 0 along the first axis as the 2-D ground-truth mask.
    mask = mask[0, :, :]

    # Generate the prompts once per sample. The point prompt is driven by the
    # configured `num_points` (previously a separate default-sized sample was
    # fed to the model and the configured one was discarded).
    # Assumes gen_points returns an (num_points, 2) coordinate array — TODO confirm.
    points = gen_points(mask, num_points=num_points)
    # Label 1 marks every generated point as a foreground point.
    points_label = np.array([1] * num_points)
    bbox = gen_bboxes(mask, jitter=jitter)

    # Predict one mask from the point prompt ...
    pre_mask_p, _, _ = predictor.predict(
        point_coords=points,
        point_labels=points_label,
        multimask_output=False,
    )

    # ... and one mask from the bounding-box prompt.
    pre_mask_b, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=bbox[None, :],
        multimask_output=False,
    )

    # multimask_output=False yields a single mask, taken at index 0.
    p_record_samri.append(dice_similarity(mask, pre_mask_p[0, :, :]))
    b_record_samri.append(dice_similarity(mask, pre_mask_b[0, :, :]))



In [None]:
# Box plot of the Dice scores collected for SAMRI with point prompts.
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(p_record_samri)
ax.set_title("Point prompt, SAMRI")
plt.show()

In [None]:
# Box plot of the Dice scores collected for SAMRI with bounding-box prompts.
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(b_record_samri)
ax.set_title("BBox prompt, SAMRI")
plt.show()

# Vit_b

In [None]:
# Model variant evaluated in this section (one of: vit_b, vit_h, samri, med_sam).
model_type = 'vit_b'
encoder_tpye = ENCODER_TYPE[model_type]
checkpoint = SAM_CHECKPOINT[model_type]
device = DEVICE

# Build the model from the registry and move it to the target device.
mri_sam_model = sam_model_registry[encoder_tpye](checkpoint)
mri_sam_model = mri_sam_model.to(device)
# Switch to evaluation mode explicitly, as the SAMRI section does, so every
# model is compared under the same inference setting.
mri_sam_model.eval()
predictor = SamPredictor(mri_sam_model)

In [None]:
# Dice scores for SAM vit-b, one entry per dataset item:
# p_record_vitb -> point-prompt predictions, b_record_vitb -> box-prompt predictions.
p_record_vitb = []
b_record_vitb = []

for image, mask in tqdm(test_dataset):
    # Compute the image embedding; both predict() calls below reuse it.
    predictor.set_image(image)

    # NOTE(review): `name` is not used for scoring in this cell.
    name = test_dataset.get_name()
    # Keep index 0 along the first axis as the 2-D ground-truth mask.
    mask = mask[0, :, :]

    # Generate the prompts once per sample. The point prompt is driven by the
    # configured `num_points` (previously a separate default-sized sample was
    # fed to the model and the configured one was discarded).
    # Assumes gen_points returns an (num_points, 2) coordinate array — TODO confirm.
    points = gen_points(mask, num_points=num_points)
    # Label 1 marks every generated point as a foreground point.
    points_label = np.array([1] * num_points)
    bbox = gen_bboxes(mask, jitter=jitter)

    # Predict one mask from the point prompt ...
    pre_mask_p, _, _ = predictor.predict(
        point_coords=points,
        point_labels=points_label,
        multimask_output=False,
    )

    # ... and one mask from the bounding-box prompt.
    pre_mask_b, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=bbox[None, :],
        multimask_output=False,
    )

    # multimask_output=False yields a single mask, taken at index 0.
    p_record_vitb.append(dice_similarity(mask, pre_mask_p[0, :, :]))
    b_record_vitb.append(dice_similarity(mask, pre_mask_b[0, :, :]))


In [None]:
# Box plot of the Dice scores collected for SAM vit-b with point prompts.
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(p_record_vitb)
ax.set_title("Point prompt, SAM vit-b")
plt.show()

In [None]:
# Box plot of the Dice scores collected for SAM vit-b with bounding-box prompts.
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(b_record_vitb)
ax.set_title("BBox prompt, SAM vit-b")
plt.show()

# Vit_h

In [None]:
# Model variant evaluated in this section (one of: vit_b, vit_h, samri, med_sam).
model_type = 'vit_h'
encoder_tpye = ENCODER_TYPE[model_type]
checkpoint = SAM_CHECKPOINT[model_type]
device = DEVICE

# Build the model from the registry and move it to the target device.
mri_sam_model = sam_model_registry[encoder_tpye](checkpoint)
mri_sam_model = mri_sam_model.to(device)
# Switch to evaluation mode explicitly, as the SAMRI section does, so every
# model is compared under the same inference setting.
mri_sam_model.eval()
predictor = SamPredictor(mri_sam_model)

In [None]:
# Dice scores for SAM vit-h, one entry per dataset item:
# p_record_vith -> point-prompt predictions, b_record_vith -> box-prompt predictions.
p_record_vith = []
b_record_vith = []

for image, mask in tqdm(test_dataset):
    # Compute the image embedding; both predict() calls below reuse it.
    predictor.set_image(image)

    # NOTE(review): `name` is not used for scoring in this cell.
    name = test_dataset.get_name()
    # Keep index 0 along the first axis as the 2-D ground-truth mask.
    mask = mask[0, :, :]

    # Generate the prompts once per sample. The point prompt is driven by the
    # configured `num_points` (previously a separate default-sized sample was
    # fed to the model and the configured one was discarded).
    # Assumes gen_points returns an (num_points, 2) coordinate array — TODO confirm.
    points = gen_points(mask, num_points=num_points)
    # Label 1 marks every generated point as a foreground point.
    points_label = np.array([1] * num_points)
    bbox = gen_bboxes(mask, jitter=jitter)

    # Predict one mask from the point prompt ...
    pre_mask_p, _, _ = predictor.predict(
        point_coords=points,
        point_labels=points_label,
        multimask_output=False,
    )

    # ... and one mask from the bounding-box prompt.
    pre_mask_b, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=bbox[None, :],
        multimask_output=False,
    )

    # multimask_output=False yields a single mask, taken at index 0.
    p_record_vith.append(dice_similarity(mask, pre_mask_p[0, :, :]))
    b_record_vith.append(dice_similarity(mask, pre_mask_b[0, :, :]))


In [None]:
# Box plot of the Dice scores collected for SAM vit-h with point prompts.
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(p_record_vith)
ax.set_title("Point prompt, SAM vit-h")
plt.show()

In [None]:
# Box plot of the Dice scores collected for SAM vit-h with bounding-box prompts.
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(b_record_vith)
ax.set_title("Bbox prompt, SAM vit-h")
plt.show()

# MedSAM

In [None]:
# Model variant evaluated in this section (one of: vit_b, vit_h, samri, med_sam).
model_type = 'med_sam'
encoder_tpye = ENCODER_TYPE[model_type]
checkpoint = SAM_CHECKPOINT[model_type]
device = DEVICE

# Build the model from the registry and move it to the target device.
mri_sam_model = sam_model_registry[encoder_tpye](checkpoint)
mri_sam_model = mri_sam_model.to(device)
# Switch to evaluation mode explicitly, as the SAMRI section does, so every
# model is compared under the same inference setting.
mri_sam_model.eval()
predictor = SamPredictor(mri_sam_model)

In [None]:
# Dice scores for MedSAM, one entry per dataset item:
# p_record_medsam -> point-prompt predictions, b_record_medsam -> box-prompt predictions.
p_record_medsam = []
b_record_medsam = []

for image, mask in tqdm(test_dataset):
    # Compute the image embedding; both predict() calls below reuse it.
    predictor.set_image(image)

    # NOTE(review): `name` is not used for scoring in this cell.
    name = test_dataset.get_name()
    # Keep index 0 along the first axis as the 2-D ground-truth mask.
    mask = mask[0, :, :]

    # Generate the prompts once per sample. The point prompt is driven by the
    # configured `num_points` (previously a separate default-sized sample was
    # fed to the model and the configured one was discarded).
    # Assumes gen_points returns an (num_points, 2) coordinate array — TODO confirm.
    points = gen_points(mask, num_points=num_points)
    # Label 1 marks every generated point as a foreground point.
    points_label = np.array([1] * num_points)
    bbox = gen_bboxes(mask, jitter=jitter)

    # Predict one mask from the point prompt ...
    pre_mask_p, _, _ = predictor.predict(
        point_coords=points,
        point_labels=points_label,
        multimask_output=False,
    )

    # ... and one mask from the bounding-box prompt.
    pre_mask_b, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=bbox[None, :],
        multimask_output=False,
    )

    # multimask_output=False yields a single mask, taken at index 0.
    p_record_medsam.append(dice_similarity(mask, pre_mask_p[0, :, :]))
    b_record_medsam.append(dice_similarity(mask, pre_mask_b[0, :, :]))


In [None]:
# Box plot of the Dice scores collected for MedSAM with point prompts.
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(p_record_medsam)
ax.set_title("Point prompt, Med_SAM")
plt.show()

In [None]:
# Box plot of the Dice scores collected for MedSAM with bounding-box prompts.
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(b_record_medsam)
ax.set_title("Bbox prompt, Med_SAM")
plt.show()