In [None]:
# -*- coding: utf-8 -*-
"""
Inference of SAMRI and the SAM (ViT-B, ViT-H) / MedSAM baselines on the NIfTI datasets.
"""

import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry
from utils.visual import get_dice_from_ds
from utils.utils import *
from utils.dataloader import NiiDataset

from tqdm import tqdm
from utils.losses import dice_similarity

In [None]:
# All models below are evaluated on a single test file: entry 15 of TEST_IMAGE_PATH.
file_path = [TEST_IMAGE_PATH[15]]  # NiiDataset is given a list of paths
test_dataset = NiiDataset(
    file_path,
    multi_mask=True,
)
len(test_dataset)  # number of samples in the dataset

# SAMRI

In [None]:
# Load the SAMRI checkpoint and switch it to inference mode.
# NOTE(review): hard-coded cluster path; the file name suggests ViT-B weights saved at epoch 45.
model_type = 'samri'  # one of: vit_b, vit_h, samri, med_sam
encoder_tpye = ENCODER_TYPE[model_type]
checkpoint = "/scratch/project/samri//Model_save/mult/samri_vitb_mult_45.pth"
device = DEVICE

samri_model = sam_model_registry[encoder_tpye](checkpoint).to(device)
samri_model.eval()

In [None]:
# SAMRI Dice scores on the test set: point prompts (p_*) and box prompts (b_*).
p_record_samri, b_record_samri = get_dice_from_ds(
    model=samri_model,
    test_dataset=test_dataset,
)

In [None]:
# Distribution of SAMRI Dice scores with point prompts.
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(p_record_samri)
ax.set_title("Point prompt, SAMRI")
plt.show()

In [None]:
# Distribution of SAMRI Dice scores with bounding-box prompts.
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(b_record_samri)
ax.set_title("BBox prompt, SAMRI")
plt.show()

# SAM ViT-B

In [None]:
model_type = 'vit_b'  # Choose one from vit_b, vit_h, samri, and med_sam
encoder_tpye = ENCODER_TYPE[model_type]
checkpoint = SAM_CHECKPOINT[model_type]
device = DEVICE

# Register the SAM ViT-B baseline. Put it in inference mode, as the SAMRI cell does,
# so every model in the comparison is evaluated under the same module mode.
sam_vitb = sam_model_registry[encoder_tpye](checkpoint)
sam_vitb = sam_vitb.to(device).eval()

In [None]:
# SAM ViT-B Dice scores on the same test set (point prompts -> p_*, box prompts -> b_*).
p_record_vitb, b_record_vitb = get_dice_from_ds(
    model=sam_vitb,
    test_dataset=test_dataset,
)

In [None]:
# Distribution of SAM ViT-B Dice scores with point prompts.
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(p_record_vitb)
ax.set_title("Point prompt, SAM vit-b")
plt.show()

In [None]:
# Distribution of SAM ViT-B Dice scores with bounding-box prompts.
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(b_record_vitb)
ax.set_title("BBox prompt, SAM vit-b")
plt.show()

# SAM ViT-H

In [None]:
model_type = 'vit_h'  # Choose one from vit_b, vit_h, samri, and med_sam
encoder_tpye = ENCODER_TYPE[model_type]
checkpoint = SAM_CHECKPOINT[model_type]
device = DEVICE

# Register the SAM ViT-H baseline. Put it in inference mode, as the SAMRI cell does,
# so every model in the comparison is evaluated under the same module mode.
sam_vith = sam_model_registry[encoder_tpye](checkpoint)
sam_vith = sam_vith.to(device).eval()

In [None]:
# SAM ViT-H Dice scores on the same test set (point prompts -> p_*, box prompts -> b_*).
p_record_vith, b_record_vith = get_dice_from_ds(
    model=sam_vith,
    test_dataset=test_dataset,
)

In [None]:
# Distribution of SAM ViT-H Dice scores with point prompts.
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(p_record_vith)
ax.set_title("Point prompt, SAM vit-h")
plt.show()

In [None]:
# Distribution of SAM ViT-H Dice scores with bounding-box prompts.
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(b_record_vith)
ax.set_title("Bbox prompt, SAM vit-h")
plt.show()

# MedSAM

In [None]:
model_type = 'med_sam'  # Choose one from vit_b, vit_h, samri, and med_sam
encoder_tpye = ENCODER_TYPE[model_type]
checkpoint = SAM_CHECKPOINT[model_type]
device = DEVICE

# Register the MedSAM baseline. Put it in inference mode, as the SAMRI cell does,
# so every model in the comparison is evaluated under the same module mode.
medsam_model = sam_model_registry[encoder_tpye](checkpoint)
medsam_model = medsam_model.to(device).eval()

In [None]:
# MedSAM Dice scores on the same test set (point prompts -> p_*, box prompts -> b_*).
# Must evaluate `medsam_model`: passing `sam_vith` here would silently report SAM ViT-H
# scores under the MedSAM label in every later plot.
p_record_medsam, b_record_medsam = get_dice_from_ds(model=medsam_model, test_dataset=test_dataset)

In [None]:
# Distribution of MedSAM Dice scores with point prompts.
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(p_record_medsam)
ax.set_title("Point prompt, Med_SAM")
plt.show()

In [None]:
# Distribution of MedSAM Dice scores with bounding-box prompts.
fig, ax = plt.subplots(figsize=(8, 8))
ax.boxplot(b_record_medsam)
ax.set_title("Bbox prompt, Med_SAM")
plt.show()

# Visualization

### 1. Compare performance across models

In [None]:
# Parallel lists with one entry per model, in plotting order:
# SAM ViT-B, SAM ViT-H, MedSAM, SAMRI.
labels = ["SAM Vit-b", "SAM Vit-h", "MedSAM", "SAMRI"]
model_color = ["#DCD7C1", "#BFB1D0", "#A7C0DE", "#6C91C2"]
p_record = [p_record_vitb, p_record_vith, p_record_medsam, p_record_samri]  # point-prompt Dice
b_record = [b_record_vitb, b_record_vith, b_record_medsam, b_record_samri]  # box-prompt Dice

In [None]:
# Point-prompt Dice distribution for every model, one coloured box per model.
plt.figure(figsize = (8,10))
# NOTE: matplotlib >= 3.9 renames boxplot's `labels` argument to `tick_labels`.
bp = plt.boxplot(p_record, labels=labels)
for box, c in zip(bp["boxes"], model_color):
    box.set(color=c)
plt.title("Point prompt", size=20)
plt.ylabel("DSC", size=20)
plt.xlabel("Models",size=20)
# Pass the coloured box artists as explicit handles. With labels alone, matplotlib
# pairs them with whichever artists the axes created first (not necessarily the
# four boxes), so the legend swatches would not match the models.
plt.legend(bp["boxes"], labels, loc="upper center", labelcolor=model_color)
plt.show()

In [None]:
# Box-prompt Dice distribution for every model, one coloured box per model.
plt.figure(figsize = (8,10))
# NOTE: matplotlib >= 3.9 renames boxplot's `labels` argument to `tick_labels`.
bp = plt.boxplot(b_record, labels=labels)
for box, c in zip(bp["boxes"], model_color):
    box.set(color=c)
plt.title("Box prompt", size=20)
plt.ylabel("DSC", size=20)
plt.xlabel("Models",size=20)
# Pass the coloured box artists as explicit handles. With labels alone, matplotlib
# pairs them with whichever artists the axes created first (not necessarily the
# four boxes), so the legend swatches would not match the models.
plt.legend(bp["boxes"], labels, loc="lower center", labelcolor=model_color)
plt.show()

### 2. Compare SAMRI performance across training epochs

In [None]:
# Training epochs whose saved SAMRI checkpoints are compared below.
model_no_list = [1, 10, 20, 30, 40, 48]
# x-axis labels, one per epoch.
label_list = [f"epoch{i}" for i in model_no_list]
# Checkpoint path for each epoch (hard-coded cluster location).
cp_list = [
    f"/scratch/project/samri//Model_save/mult/samri_vitb_mult_{no}.pth"
    for no in model_no_list
]

In [None]:
# Evaluate every saved SAMRI checkpoint in `cp_list` on the same test dataset.
# The loop uses its own names (epoch_model, p_epoch, b_epoch) so that running this
# cell does not overwrite `samri_model`, `p_record_samri` or `b_record_samri`
# from the SAMRI section, which the model-comparison plots rely on.
model_type = 'samri'  # Choose one from vit_b, vit_h, samri, and med_sam
encoder_tpye = ENCODER_TYPE[model_type]
device = DEVICE

p_record = []
b_record = []
for i, ckpt in enumerate(cp_list):
    # 1-based progress counter, e.g. "2/6".
    print(f"Testing model {i + 1}/{len(cp_list)}: {ckpt}")
    epoch_model = sam_model_registry[encoder_tpye](ckpt).to(device)
    epoch_model.eval()
    p_epoch, b_epoch = get_dice_from_ds(model=epoch_model, test_dataset=test_dataset)
    p_record.append(p_epoch)
    b_record.append(b_epoch)
    # Drop the reference so this checkpoint's weights can be freed before the next one loads.
    del epoch_model

In [None]:
# Point-prompt Dice distribution for each SAMRI checkpoint.
fig, ax = plt.subplots(figsize=(10, 8))
bp = ax.boxplot(p_record, labels=label_list)
ax.set_title("Point prompt", size=20)
ax.set_ylabel("DSC", size=20)
ax.set_xlabel("Epochs", size=20)
plt.show()

In [None]:
# Box-prompt Dice distribution for each SAMRI checkpoint.
fig, ax = plt.subplots(figsize=(10, 8))
bp = ax.boxplot(b_record, labels=label_list)
ax.set_title("Box prompt", size=20)
ax.set_ylabel("DSC", size=20)
ax.set_xlabel("Epochs", size=20)
plt.show()

### 3. Compare SAM ViT-B performance across the test datasets

In [None]:
# Evaluate the SAM ViT-B baseline separately on every test dataset.
file_paths = TEST_IMAGE_PATH
# Dataset label = the name of each file's grandparent directory (assumes '/'-separated paths).
label_list = [path.split("/")[-3] for path in file_paths]

model_type = 'vit_b'  # Choose one from vit_b, vit_h, samri, and med_sam
encoder_tpye = ENCODER_TYPE[model_type]
checkpoint = SAM_CHECKPOINT[model_type]
device = DEVICE

# Register the SAM ViT-B baseline. Put it in inference mode, as the SAMRI cells do,
# so the evaluation mode matches the rest of the notebook.
sam_model = sam_model_registry[encoder_tpye](checkpoint)
sam_model = sam_model.to(device).eval()


In [None]:
# One Dice record per test dataset, in `file_paths` order.
# The loop uses its own names (ds_path, ds_dataset, p_ds, b_ds) so that running this
# cell does not overwrite `file_path`, `test_dataset`, `p_record_vitb` or
# `b_record_vitb` from the earlier sections.
p_record = []
b_record = []
for ds_path in file_paths:
    print("Processing the dataset: ", ds_path)
    ds_dataset = NiiDataset([ds_path], multi_mask=True)
    p_ds, b_ds = get_dice_from_ds(model=sam_model, test_dataset=ds_dataset)
    p_record.append(p_ds)
    b_record.append(b_ds)

In [None]:
# Point-prompt Dice distribution of SAM ViT-B, one box per test dataset.
fig, ax = plt.subplots(figsize=(10, 8))
bp = ax.boxplot(p_record, labels=label_list)
ax.set_title("Point prompt", size=20)
ax.set_ylabel("DSC", size=20)
ax.set_xlabel("Dataset", size=20)
plt.setp(ax.get_xticklabels(), rotation=45)  # dataset names are long; tilt them
plt.show()

In [None]:
# Box-prompt Dice distribution of SAM ViT-B, one box per test dataset.
fig, ax = plt.subplots(figsize=(10, 8))
bp = ax.boxplot(b_record, labels=label_list)
ax.set_title("Box prompt", size=20)
ax.set_ylabel("DSC", size=20)
ax.set_xlabel("Dataset", size=20)
plt.setp(ax.get_xticklabels(), rotation=45)  # dataset names are long; tilt them
plt.show()

In [None]:
import pickle
from pathlib import Path

# Persist the per-dataset SAM ViT-B Dice records ("p": point prompt, "b": box prompt),
# one entry per dataset in `file_paths` order.
final_record = {"p":p_record,"b":b_record}
out_path = Path("/scratch/project/samri/Eval_results/vitb_test_ds")
# Create the results directory if missing so the dump cannot fail after the long evaluation.
out_path.parent.mkdir(parents=True, exist_ok=True)
with open(out_path, "wb") as f:
    pickle.dump(final_record, f)