## Manual data loading

In [None]:
import torch
# Show which GPU is in use. Guarded so the cell does not raise on a CPU-only
# machine (the model-loading cell further down already supports a CPU fallback).
torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CUDA not available (running on CPU)"

In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import cv2
import pandas as pd
import random
import torch
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
from statistics import mean

from tqdm import tqdm
from torch.nn.functional import threshold, normalize
# Work from the parent of the notebook's directory (relative data paths below
# are resolved against it). The root is computed only once per kernel session:
# a bare os.chdir('..') would climb one level further every time this cell is re-run.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = Path.cwd().parent
os.chdir(PROJECT_ROOT)


In [None]:
# Dataset split to load; used as a key of name_mapper ('TRAIN', 'VAL' or 'TEST')
# and as the sub-directory name under Images/ and Labels/.
Div="TRAIN"

## Preprocess data

We load the precomputed bounding-box coordinates for each frame; they are passed to SAM as box prompts.

In [None]:
# File list indexed by FileName; its "Split" column is used below to build the TRAIN/VAL/TEST lists.
df=pd.read_csv("FileList.csv").set_index("FileName")

In [None]:
# Partition the file names according to the "Split" column of the original file list
Train, Val, Test = (
    list(df[df["Split"] == split_name].index)
    for split_name in ("TRAIN", "VAL", "TEST")
)

In [None]:
# Frame names of the training split: every file name contributes an "_ED" and an "_ES" frame.
# NOTE(review): `z` is rebuilt from name_mapper[Div] in the next cell, so the value
# assigned here is only used if that cell is skipped.
Train0=[s + "_ED" for s in Train]
Train1=[s + "_ES" for s in Train]
z=sorted(Train0+Train1)

In [None]:
# Map each split label to its file names, then build the sorted "_ED"/"_ES"
# frame names of the split selected by Div.
name_mapper = {'TRAIN': Train, 'VAL': Val, 'TEST': Test}
z = sorted(
    file_name + suffix
    for suffix in ("_ED", "_ES")
    for file_name in name_mapper[Div]
)

In [None]:
#Adapted for EchoNet dataset
import pandas as pd
# One box prompt per frame, keyed by the "Frame" id and stored as a
# [xmin, ymin, xmax, ymax] array. The four columns are extracted in a single
# vectorised call instead of four scalar .loc lookups per row.
bbox_df = pd.read_excel("bbox_coords.xlsx").set_index("Frame")
bbox_coords = dict(
    zip(bbox_df.index, bbox_df[["xmin", "ymin", "xmax", "ymax"]].to_numpy())
)

In [None]:
# Preview the selection instead of dumping every frame name into the output
print(len(z), z[:5])

We load the frames and extract the ground-truth segmentation masks (after dropping a few excluded frames).

In [None]:
# Frames excluded from the selection; each is dropped from z if it is present.
remove_ls = [
    '0X234005774F4CB5CD_ED', '0X234005774F4CB5CD_ES',
    '0X2DC68261CBCC04AE_ED', '0X2DC68261CBCC04AE_ES',
    '0X35291BE9AB90FB89_ED', '0X35291BE9AB90FB89_ES',
    '0X5515B0BD077BE68A_ED', '0X5515B0BD077BE68A_ES',
    '0X6C435C1B417FDE8A_ED', '0X6C435C1B417FDE8A_ES',
]
for i in remove_ls:
    try:
        z.remove(i)
    except ValueError:
        # this frame is not part of the current selection
        pass

In [None]:
# Display the working directory that the relative Images/ and Labels/ paths are resolved against.
os.getcwd()

In [None]:
from collections import defaultdict

import torch

from PIL import Image


from segment_anything.utils.transforms import ResizeLongestSide

transformed_data = defaultdict(dict)

# (width, height) the training masks are resized to; it has to match the spatial
# size of the predicted masks they are compared with in the training loss.
MASK_SIZE = (256, 256)

all_images = []   # RGB frames at native resolution
all_masks = []    # binary masks resized to MASK_SIZE (training targets)
all_boxes = []    # box prompts, one [xmin, ymin, xmax, ymax] array per frame
orig_masks = []   # binary masks at native resolution
for k in tqdm(z[:100]):
    image_path = f'Images/{Div}/{k}.jpg'
    label_path = f'Labels/{Div}/{k}.png'

    # cv2.imread returns None instead of raising when a file is missing or
    # unreadable; fail here with the offending path rather than in cvtColor.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f'Could not read image file: {image_path}')
    gt_grayscale = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
    if gt_grayscale is None:
        raise FileNotFoundError(f'Could not read label file: {label_path}')

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    all_images.append(image)
    # (H, W) of the frame; kept as a notebook global because later cells read it.
    original_image_size = image.shape[:2]

    # Foreground = pixels at the brightest grey level of the label image.
    # NOTE(review): an all-black label would yield an all-foreground mask - confirm
    # that such labels do not occur in the selected frames.
    binary_mask = (gt_grayscale == gt_grayscale.max()).astype(np.uint8)
    orig_masks.append(torch.from_numpy(binary_mask).float())

    # Resize the *binary* mask with nearest-neighbour. Resizing the grey-level
    # label with an interpolating filter and then testing `== max` can drop
    # boundary pixels whose value was blended; this also matches the
    # INTER_NEAREST resize used by the 256-px evaluation cell.
    resized_mask = cv2.resize(binary_mask, MASK_SIZE, interpolation=cv2.INTER_NEAREST)
    all_masks.append(torch.from_numpy(resized_mask).float())

    all_boxes.append(bbox_coords[k])

print(len(all_images), len(all_masks))

In [None]:
# Show which GPU is in use. Guarded so the cell does not raise on a CPU-only
# machine (the model-loading cell further down already supports a CPU fallback).
torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CUDA not available (running on CPU)"

In [None]:
# Visual sanity check: show the first loaded frame.
plt.imshow(all_images[0])

In [None]:
# Collapse the per-frame lists into single arrays/tensors with a new leading
# sample axis (stacking requires every frame to have the same shape).
# NOTE(review): the list names are rebound to arrays/tensors here, so this cell
# depends on the loading cell having been run immediately before it.
all_images = np.stack(all_images, axis=0)
all_masks = torch.stack(all_masks, dim=0)
orig_masks = torch.stack(orig_masks, dim=0)
all_boxes = np.stack(all_boxes, axis=0)
print(all_images.shape, all_masks.shape,all_boxes.shape)

In [None]:
# Settings for the segment_anything model built in the next cell:
# registry key and checkpoint file passed to sam_model_registry.
model_type = 'vit_b'
checkpoint = 'sam_vit_b_01ec64.pth'
# NOTE(review): `device` is reassigned (with a CPU fallback) in the model-loading
# cell of the "Train the model" section before it is first used.
device = 'cuda:0'

In [None]:
from segment_anything import SamPredictor, sam_model_registry
# Build the original segment_anything model from the local checkpoint.
# In the visible notebook it is only consulted for `image_encoder.img_size`
# (to configure ResizeLongestSide); training and inference use the
# Hugging Face `SamModel` loaded further down.
sam_model = sam_model_registry[model_type](checkpoint=checkpoint)
#sam_model.to(device)    # SAM is on CUDA
sam_model.train();    # Tells sam_model that we are running under training mode
#sam_model.eval();    # Tells sam_model that we are running under testing mode

In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader

# Number of frames per training batch (passed to the DataLoader below).
batch_size = 8

# Resizes images (and box coordinates) so that the longest side equals the
# image encoder's input size.
transform = ResizeLongestSide(sam_model.image_encoder.img_size)

In [None]:
import os
import numpy as np
import torch


class EchoDataset(Dataset):
    """Serve (image, mask, box) training samples from pre-loaded frames.

    Args:
        all_images: indexable collection of RGB frames; each frame must expose
            ``.shape`` with the two spatial dimensions first (assumes H x W x 3
            numpy arrays - TODO confirm, the original comment said "N, W, H, 3").
        all_masks: indexable collection of ground-truth masks, returned as stored.
        all_boxes: indexable collection of box prompts, one per frame, in the
            pixel coordinates of the un-resized frame.
        transform: object providing ``apply_image(image)`` and
            ``apply_boxes(boxes, original_size)`` (e.g. ``ResizeLongestSide``).
    """

    def __init__(self, all_images, all_masks, all_boxes, transform):
        self.all_images = all_images
        self.all_masks = all_masks
        self.all_boxes = all_boxes
        self.transform = transform

    def __len__(self):
        """Number of frames in the dataset."""
        return len(self.all_images)

    def __getitem__(self, index):
        """Return ``(image, mask, box)`` for the frame at ``index``.

        ``image`` is the resized frame in channels-first layout, ``mask`` is the
        stored mask, and ``box`` is the first transformed box with a leading
        axis of length 1.
        """
        frame = self.all_images[index]
        # Take the size from the frame itself. The previous version read the
        # notebook-level global `original_image_size`, which only exists (and is
        # only correct) if the loading cell happened to run last.
        original_size = frame.shape[:2]

        image = torch.as_tensor(self.transform.apply_image(frame))
        image = image.permute(2, 0, 1).contiguous()  # channels-last -> channels-first

        mask = self.all_masks[index]

        box = self.transform.apply_boxes(self.all_boxes[index], original_size)
        # Keep the first transformed box and restore a leading axis of length 1.
        box = torch.as_tensor(box, dtype=torch.float)[0].unsqueeze(0)

        return image, mask, box

# Wrap the loaded frames in a dataset and serve them in shuffled batches of `batch_size`.
dataset = EchoDataset(all_images, all_masks, all_boxes, transform)
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Pull one batch to sanity-check the shapes produced by the dataset/dataloader.
batch = next(iter(train_dataloader))
print(type(batch))
batch_images, batch_masks, batch_boxes = batch
print('batch_images: ', batch_images.shape)
print("batch_masks: ", batch_masks.shape)
print("batch_boxes: ", batch_boxes)
# Back to channels-last for display.
plt.imshow(batch_images[0].permute(1, 2, 0))

## Load the model (Note that you will create a new model!)

from transformers import SamModel

model = SamModel.from_pretrained("facebook/sam-vit-base")

# make sure we only compute gradients for mask decoder
for name, param in model.named_parameters():
  if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
    param.requires_grad_(False)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
pass

## Train the model

In [None]:
from transformers import SamModel

model = SamModel.from_pretrained("facebook/sam-vit-base")

# Freeze the vision and prompt encoders so that gradients are only computed
# for the remaining parameters (the mask decoder).
for name, param in model.named_parameters():
    if name.startswith(("vision_encoder", "prompt_encoder")):
        param.requires_grad_(False)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# Resume from the fine-tuned checkpoint. map_location follows `device`, so the
# CPU fallback above also works for a checkpoint that was saved from a GPU
# (it was previously hard-coded to 'cuda').
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
model.load_state_dict(torch.load('7.6noon_model.pth', map_location=torch.device(device)))

In [None]:
# Helper functions provided in https://github.com/facebookresearch/segment-anything/blob/9e8f1309c94f1128a6e5c047a10fdcb02fc8d651/notebooks/predictor_example.ipynb
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a translucent RGBA overlay.

    Args:
        mask: array whose last two dimensions are (height, width) and which
            holds exactly height * width elements.
        ax: object with an ``imshow`` method (e.g. a matplotlib Axes).
        random_color: if True use a random RGB colour, otherwise a fixed blue;
            the alpha channel is 0.6 in both cases.
    """
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) against (1, 1, 4): every mask value scales the colour.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_box(box, ax):
    """Outline ``box`` = (x0, y0, x1, y1) on ``ax`` as an unfilled green rectangle."""
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0,0,0,0),
        lw=2,
    )
    ax.add_patch(outline)

In [None]:
from tqdm import tqdm
from statistics import mean
import torch
from torch.nn.functional import threshold, normalize
# Release cached CUDA blocks before training and report the memory readings
# from before and after. Only the last expression of a cell is displayed, so
# the first reading used to be computed and silently discarded.
mem_before = torch.cuda.mem_get_info()
torch.cuda.empty_cache()
mem_after = torch.cuda.mem_get_info()
print(f"mem_get_info before empty_cache: {mem_before}")
mem_after


In [None]:
from torch.optim import Adam
import monai

# Note: Hyperparameter tuning could improve performance here
# Only the mask decoder's parameters are handed to the optimizer.
optimizer = Adam(model.mask_decoder.parameters(), lr=8e-6, weight_decay=0)

# Combined Dice + cross-entropy loss; sigmoid=True means the loss is fed raw
# mask logits (which is why inference applies torch.sigmoid explicitly).
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

In [None]:
num_epochs = 150
model.train()

# NOTE(review): batches are fed to the model after ResizeLongestSide + float()
# only, whereas the inference cells go through SamProcessor - confirm that both
# paths apply the same pixel preprocessing.
for epoch in range(num_epochs):
    epoch_losses = []
    for batch_images, batch_masks, batch_boxes in tqdm(train_dataloader):
        # Move each tensor to the training device exactly once.
        batch_images = batch_images.float().to(device)
        batch_boxes = batch_boxes.float().to(device)
        # Add a channel axis at dim 1 so the target lines up with the prediction.
        target_masks = batch_masks.unsqueeze(1).to(device)

        # forward pass: box-prompted, a single mask per prompt
        outputs = model(pixel_values=batch_images,
                        input_boxes=batch_boxes,
                        multimask_output=False)

        # compute loss (squeeze(1) drops the singleton axis at dim 1 of pred_masks)
        predicted_masks = outputs.pred_masks.squeeze(1)
        loss = seg_loss(predicted_masks, target_masks)

        # backward pass (compute gradients of parameters w.r.t. loss)
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')

## Inference

Important note here: as we used the Dice loss with `sigmoid=True`, we need to make sure to appropriately apply a sigmoid activation function to the predicted masks. Hence we won't use the processor's `post_process_masks` method here.

In [None]:
import numpy as np
from PIL import Image

# Index of the training example to inspect (fixed to the first loaded frame, not random)
idx = 0

# load image
image = all_images[idx]
print(image.shape)
plt.imshow(image)

In [None]:
from transformers import SamProcessor

# Hugging Face processor that turns an image plus box prompts into model inputs.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

In [None]:
# Ground-truth mask (native resolution) and precomputed box prompt of the
# selected example. Indexed with `idx` so that they always belong to the same
# frame as `image` (they were previously hard-coded to element 0).
ground_truth_mask = orig_masks[idx].numpy()
print(ground_truth_mask.shape)
prompt = all_boxes[idx].tolist()
print(prompt)

# prepare image + box prompt for the model
inputs = processor(image, input_boxes=[[prompt]], return_tensors="pt").to(device)
for input_name, input_tensor in inputs.items():
    print(input_name, input_tensor.shape)

In [None]:
# Switch to inference mode before predicting.
model.eval()

# forward pass (no gradients needed; a single mask is requested for the box prompt)
with torch.no_grad():
  outputs = model(**inputs, multimask_output=False)

In [None]:
# Display the raw model output.
# NOTE(review): this prints full tensors; consider showing only their shapes.
outputs

In [None]:
# Logits -> probabilities -> hard binary mask (threshold 0.5)
mask_logits = outputs.pred_masks.squeeze(1)
medsam_seg_prob = torch.sigmoid(mask_logits).cpu().numpy().squeeze()
medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)

In [None]:
# A comparison of predicted and ground truth
# Bring the predicted mask to the resolution of the displayed image rather
# than a hard-coded 112 x 112 (cv2.resize expects (width, height)).
image_height, image_width = image.shape[:2]
medsam_seg = cv2.resize(medsam_seg, (image_width, image_height), interpolation = cv2.INTER_NEAREST)

# One figure per overlay, both drawn with the same box prompt.
for title, overlay in (("Predicted mask", medsam_seg), ("Ground truth mask", ground_truth_mask)):
    fig, axes = plt.subplots()
    axes.imshow(np.array(image))
    show_mask(overlay, axes)
    show_box(prompt, axes)
    axes.title.set_text(title)
    axes.axis("off")

# Evaluation

In [None]:
def dice_metric(inputs, target):
    """Dice coefficient between two masks: 2 * sum(target * inputs) / (sum(target) + sum(inputs)).

    Returns 1.0 when both masks sum to zero (two empty masks agree and the
    ratio would otherwise be 0/0). Both arguments must support elementwise
    ``*`` with each other and ``.sum()``.
    """
    overlap = 2.0 * (target * inputs).sum()
    target_total = target.sum()
    inputs_total = inputs.sum()
    if target_total == 0 and inputs_total == 0:
        return 1.0
    return overlap / (target_total + inputs_total)

In [None]:
#Adapted for EchoNet dataset; Load zero-shot masks
from sklearn.metrics import jaccard_score


# Evaluate on the TEST split from here on.
Div="TEST"
from tqdm import tqdm


# ED/ES frame names of the *validation* split; Val0/Val1 are only touched by the
# exclusion cell that follows.
# NOTE(review): the `z` built from them is overwritten two lines further down.
Val0=[s + "_ED" for s in Val]
Val1=[s + "_ES" for s in Val]
z=sorted(Val0+Val1)

# Frames actually evaluated: the "_ED"/"_ES" names of the split selected by Div.
name_mapper = {'TRAIN': Train, 'VAL': Val, 'TEST': Test}
z=sorted([s + "_ED" for s in name_mapper[Div]]+[s + "_ES" for s in name_mapper[Div]])

In [None]:
# Frames to exclude; each is dropped from Val0 and Val1 if present.
# NOTE(review): only Val0/Val1 are filtered here - `z`, which the loading cell
# below iterates over, is left untouched. Confirm that this is intended.
remove_ls = [
    '0X234005774F4CB5CD_ED', '0X234005774F4CB5CD_ES',
    '0X2DC68261CBCC04AE_ED', '0X2DC68261CBCC04AE_ES',
    '0X35291BE9AB90FB89_ED', '0X35291BE9AB90FB89_ES',
    '0X5515B0BD077BE68A_ED', '0X5515B0BD077BE68A_ES',
    '0X6C435C1B417FDE8A_ED', '0X6C435C1B417FDE8A_ES',
    '0X5DD5283AC43CCDD1_ED', '0X5DD5283AC43CCDD1_ES',
]
for frame_list in (Val0, Val1):
    for i in remove_ls:
        if i in frame_list:
            frame_list.remove(i)

In [None]:
from collections import defaultdict

import torch

from PIL import Image

from segment_anything.utils.transforms import ResizeLongestSide

transformed_data = defaultdict(dict)

val_images = []   # RGB frames at native resolution
val_masks = []    # binary ground-truth masks at native resolution
val_boxes = []    # box prompts, one per frame
val_names = []    # frame ids, used to name the saved prediction files

# Iterate over every selected frame. The previous slice bound (2576) was a magic
# number at least as large as the list, so it selected everything anyway.
for k in tqdm(z):
    image_path = f'Images/{Div}/{k}.jpg'
    label_path = f'Labels/{Div}/{k}.png'

    # cv2.imread returns None instead of raising when a file cannot be read.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f'Could not read image file: {image_path}')
    gt_grayscale = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
    if gt_grayscale is None:
        raise FileNotFoundError(f'Could not read label file: {label_path}')

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    val_images.append(image)
    # (H, W) of the frame; kept as a notebook global because other cells read it.
    original_image_size = image.shape[:2]

    # Foreground = pixels at the brightest grey level of the label image.
    mask = torch.from_numpy(gt_grayscale == gt_grayscale.max()).float()
    val_masks.append(mask)

    val_boxes.append(bbox_coords[k])

    val_names.append(k)

print(len(val_images), len(val_masks))

In [None]:
from PIL import Image

In [None]:
# Spot-check the first frame id that was loaded.
print(val_names[0])

In [None]:
# Eval in 112 (i.e. at the native resolution of the ground-truth masks)

ground_truth_masks = []
pred_masks = []

# PIL's Image.save does not create missing directories.
os.makedirs('masks/test', exist_ok=True)

model.eval()  # inference mode; set once instead of on every iteration
# Loop over whatever was loaded rather than a hard-coded frame count.
for i in tqdm(range(len(val_images))):
    image = val_images[i]
    mask = val_masks[i]
    name = val_names[i]
    ground_truth_masks.append(mask)
    prompt = val_boxes[i].tolist()
    inputs = processor(image, input_boxes=[[prompt]], return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # apply sigmoid
    medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    # convert soft mask to hard mask
    medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
    medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)
    # Resize the prediction to the ground-truth mask's own size so that the two
    # can be compared pixel by pixel (cv2.resize expects (width, height)).
    mask_height, mask_width = mask.shape[-2:]
    medsam_seg = cv2.resize(medsam_seg, (mask_width, mask_height), interpolation = cv2.INTER_NEAREST)
    pred_masks.append(medsam_seg)

    # Save the prediction as a white-on-black RGB image.
    rgb_data = np.zeros((mask_height, mask_width, 3), dtype=np.uint8)
    rgb_data[medsam_seg == 1] = [255, 255, 255]  # White

    img = Image.fromarray(rgb_data)
    img.save(f'masks/test/{name}.png')

'''
  if dice - 0.958 < 0.005 and 0.958 - dice < 0.005:
    print("i: ", i)
    print("iou: ", iou)
    print("dice: ", dice)
    fig, axes = plt.subplots()
    axes.imshow(np.array(image))
    show_mask(medsam_seg, axes)
    show_box(prompt, axes)
    axes.title.set_text(f"Predicted mask")
    axes.axis("off")
    
    fig, axes = plt.subplots()
    axes.imshow(np.array(image))
    show_mask(mask, axes)
    show_box(prompt, axes)
    axes.title.set_text(f"Ground truth mask")
    axes.axis("off")
    
    print_idx += 1
    
    %matplotlib inline 
    _, axs = plt.subplots(1, 2, figsize=(25, 25))


    axs[0].imshow(np.array(image))
    show_mask(medsam_seg, axs[0])
    show_box(prompt, axs[0])
    axs[0].set_title("Predicted mask", fontsize=26)
    axs[0].axis('off')


    axs[1].imshow(np.array(image))
    show_mask(mask, axs[1])
    show_box(prompt, axs[1])
    axs[1].set_title("Ground truth mask", fontsize=26)
    axs[1].axis('off')

    plt.show()  
'''

# Per-frame IoU and Dice between predictions and ground truth. Iterating over
# the collected masks avoids the IndexError that a hard-coded count raises when
# it exceeds the number of predictions actually produced.
iou=[]
dice=[]
for predictions, label in tqdm(zip(pred_masks, ground_truth_masks), total=len(pred_masks)):
    iou.append(jaccard_score(predictions, label, average="micro"))
    # dice_metric may return a Python float or a 0-d array/tensor; store plain floats.
    dice.append(float(dice_metric(predictions, label)))

print("mean IoU:" f'{np.mean(np.array(iou))}')
print("mean Dice:" f'{np.mean(np.array(dice))}')

In [None]:
# Spot-check a single frame id of the current selection.
print(z[123])

In [None]:
# Close the current matplotlib figure.
plt.close()

In [None]:
# Mean and 2.5th / 97.5th percentiles of the per-frame scores
iou_scores = np.array(iou)
dice_scores = np.array(dice)
iou_interval = (np.percentile(iou_scores, 2.5), np.percentile(iou_scores, 97.5))
dice_interval = (np.percentile(dice_scores, 2.5), np.percentile(dice_scores, 97.5))
print(f"mean IoU:{np.mean(iou_scores)}")
print(f"mean Dice:{np.mean(dice_scores)}")
print(f"Iou 2.5%, 97.5:{iou_interval}")
print(f"Dice 2.5%, 97.5:{dice_interval}")


In [None]:
# Eval in 256
from collections import defaultdict

import torch

from PIL import Image

from segment_anything.utils.transforms import ResizeLongestSide

transformed_data = defaultdict(dict)

EVAL_SIZE = 256          # side length (pixels) at which masks are compared in this cell
MAX_EVAL_FRAMES = 1000   # evaluate at most this many frames of the selection

val_images = []
val_masks = []
val_boxes = []

for k in tqdm(z[:MAX_EVAL_FRAMES]):
    image_path = f'Images/{Div}/{k}.jpg'
    label_path = f'Labels/{Div}/{k}.png'

    # cv2.imread returns None instead of raising when a file cannot be read.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f'Could not read image file: {image_path}')
    gt_grayscale = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
    if gt_grayscale is None:
        raise FileNotFoundError(f'Could not read label file: {label_path}')

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    val_images.append(image)
    # (H, W) of the frame; kept as a notebook global because other cells read it.
    original_image_size = image.shape[:2]

    # Nearest-neighbour keeps the label's grey levels intact while resizing.
    gt_grayscale = cv2.resize(gt_grayscale, (EVAL_SIZE, EVAL_SIZE), interpolation = cv2.INTER_NEAREST)
    mask = torch.from_numpy(gt_grayscale == gt_grayscale.max()).float()
    val_masks.append(mask)

    val_boxes.append(bbox_coords[k])

print(len(val_images), len(val_masks))
ground_truth_masks = []
pred_masks = []

model.eval()  # inference mode; set once instead of on every iteration
print_idx = 0
# Loop over what was actually loaded (the selection may hold fewer than MAX_EVAL_FRAMES frames).
for i in tqdm(range(len(val_images))):
    image = val_images[i]
    mask = val_masks[i]
    ground_truth_masks.append(mask)
    prompt = val_boxes[i].tolist()
    inputs = processor(image, input_boxes=[[prompt]], return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # apply sigmoid
    medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    # convert soft mask to hard mask (kept at the model's output resolution)
    medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
    medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)
    pred_masks.append(medsam_seg)

    if print_idx < 10:
        # The box prompt is in the frame's native pixel coordinates; rescale it
        # to the EVAL_SIZE canvas so that it lines up with the resized image.
        native_height, native_width = image.shape[:2]
        scale_x = EVAL_SIZE / native_width
        scale_y = EVAL_SIZE / native_height
        display_box = [prompt[0] * scale_x, prompt[1] * scale_y,
                       prompt[2] * scale_x, prompt[3] * scale_y]
        display_image = cv2.resize(image, (EVAL_SIZE, EVAL_SIZE), interpolation = cv2.INTER_NEAREST)
        for title, overlay in (("Predicted mask", medsam_seg), ("Ground truth mask", mask)):
            fig, axes = plt.subplots()
            axes.imshow(np.array(display_image))
            show_mask(overlay, axes)
            show_box(display_box, axes)
            axes.title.set_text(title)
            axes.axis("off")
        print_idx += 1

# Per-frame IoU and Dice (the per-sample debug shape prints were removed).
iou=[]
dice=[]
for predictions, label in tqdm(zip(pred_masks, ground_truth_masks), total=len(pred_masks)):
    iou.append(jaccard_score(predictions, label, average="micro"))
    # dice_metric may return a Python float or a 0-d array/tensor; store plain floats.
    dice.append(float(dice_metric(predictions, label)))

print("mean IoU:" f'{np.mean(np.array(iou))}')
print("mean Dice:" f'{np.mean(np.array(dice))}')

# Separate Eval

In [None]:
#Adapted for EchoNet dataset; Load zero-shot masks
from sklearn.metrics import jaccard_score


# Separate ED / ES evaluation, run on the TRAIN split.
Div="TRAIN"
from tqdm import tqdm
ground_truth_masks = []
pred_masks = []

# ED and ES frame names of the *training* split
# (the Val0/Val1 names are reused here even though they are built from Train).
Val0=[s + "_ED" for s in Train]
Val1=[s + "_ES" for s in Train]
z=sorted(Val0+Val1)

# z0: ED frames, z1: ES frames - loaded and evaluated separately below.
z0 = sorted(Val0)
z1 = sorted(Val1)

In [None]:
# Frames to exclude; each is dropped from the ED list (z0) and the ES list (z1) if present.
remove_ls = [
    '0X234005774F4CB5CD_ED', '0X234005774F4CB5CD_ES',
    '0X2DC68261CBCC04AE_ED', '0X2DC68261CBCC04AE_ES',
    '0X35291BE9AB90FB89_ED', '0X35291BE9AB90FB89_ES',
    '0X5515B0BD077BE68A_ED', '0X5515B0BD077BE68A_ES',
    '0X6C435C1B417FDE8A_ED', '0X6C435C1B417FDE8A_ES',
    '0X5DD5283AC43CCDD1_ED', '0X5DD5283AC43CCDD1_ES',
]
for frame_list in (z0, z1):
    for i in remove_ls:
        if i in frame_list:
            frame_list.remove(i)

In [None]:
from collections import defaultdict

import torch

from PIL import Image

from segment_anything.utils.transforms import ResizeLongestSide

transformed_data = defaultdict(dict)

# ED frames (names in z0): RGB images, binary masks and box prompts
val_images0 = []
val_masks0 = []
val_boxes0 = []

for k in tqdm(z0):
    image_path = f'Images/{Div}/{k}.jpg'
    label_path = f'Labels/{Div}/{k}.png'

    # cv2.imread returns None instead of raising when a file cannot be read.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f'Could not read image file: {image_path}')
    gt_grayscale = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
    if gt_grayscale is None:
        raise FileNotFoundError(f'Could not read label file: {label_path}')

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    val_images0.append(image)
    # (H, W) of the frame; kept as a notebook global because other cells read it.
    original_image_size = image.shape[:2]

    # Foreground = pixels at the brightest grey level of the label image.
    mask = torch.from_numpy(gt_grayscale == gt_grayscale.max()).float()
    val_masks0.append(mask)

    val_boxes0.append(bbox_coords[k])

print(len(val_images0), len(val_masks0))

In [None]:
transformed_data = defaultdict(dict)

# ES frames (names in z1): RGB images, binary masks and box prompts
val_images1 = []
val_masks1 = []
val_boxes1 = []

for k in tqdm(z1):
    image_path = f'Images/{Div}/{k}.jpg'
    label_path = f'Labels/{Div}/{k}.png'

    # cv2.imread returns None instead of raising when a file cannot be read.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f'Could not read image file: {image_path}')
    gt_grayscale = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
    if gt_grayscale is None:
        raise FileNotFoundError(f'Could not read label file: {label_path}')

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    val_images1.append(image)
    # (H, W) of the frame; kept as a notebook global because other cells read it.
    original_image_size = image.shape[:2]

    # Foreground = pixels at the brightest grey level of the label image.
    mask = torch.from_numpy(gt_grayscale == gt_grayscale.max()).float()
    val_masks1.append(mask)

    val_boxes1.append(bbox_coords[k])

print(len(val_images1), len(val_masks1))

In [None]:
# 0 - ED, 1-ES: choose which phase to evaluate by pointing at the matching lists
val_images = val_images1
val_masks = val_masks1
val_boxes = val_boxes1

# Eval in 112 (i.e. at the native resolution of the ground-truth masks)

ground_truth_masks = []
pred_masks = []

model.eval()  # inference mode; set once instead of on every iteration
print_idx = 0
# Loop over whatever was loaded rather than a hard-coded frame count.
for i in tqdm(range(len(val_images))):
    image = val_images[i]
    mask = val_masks[i]
    ground_truth_masks.append(mask)
    prompt = val_boxes[i].tolist()
    inputs = processor(image, input_boxes=[[prompt]], return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # apply sigmoid
    medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    # convert soft mask to hard mask
    medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
    medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)
    # Resize the prediction to the ground-truth mask's own size so that the two
    # can be compared pixel by pixel (cv2.resize expects (width, height)).
    mask_height, mask_width = mask.shape[-2:]
    medsam_seg = cv2.resize(medsam_seg, (mask_width, mask_height), interpolation = cv2.INTER_NEAREST)
    pred_masks.append(medsam_seg)

    # Show the first ten predictions next to their ground truth.
    if print_idx < 10:
        for title, overlay in (("Predicted mask", medsam_seg), ("Ground truth mask", mask)):
            fig, axes = plt.subplots()
            axes.imshow(np.array(image))
            show_mask(overlay, axes)
            show_box(prompt, axes)
            axes.title.set_text(title)
            axes.axis("off")
        print_idx += 1

iou=[]
dice=[]
for predictions, label in tqdm(zip(pred_masks, ground_truth_masks), total=len(pred_masks)):
    iou.append(jaccard_score(predictions, label, average="micro"))
    # dice_metric may return a Python float or a 0-d array/tensor; store plain floats.
    dice.append(float(dice_metric(predictions, label)))

print("mean IoU:" f'{np.mean(np.array(iou))}')
print("mean Dice:" f'{np.mean(np.array(dice))}')
print("Iou 2.5%, 97.5:" f'{np.percentile(np.array(iou), 2.5),np.percentile(np.array(iou), 97.5)}')
print("Dice 2.5%, 97.5:" f'{np.percentile(np.array(dice), 2.5),np.percentile(np.array(dice), 97.5)}')


In [None]:
# Mean and 2.5th / 97.5th percentiles of the per-frame scores
iou_scores = np.array(iou)
dice_scores = np.array(dice)
iou_interval = (np.percentile(iou_scores, 2.5), np.percentile(iou_scores, 97.5))
dice_interval = (np.percentile(dice_scores, 2.5), np.percentile(dice_scores, 97.5))
print(f"mean IoU:{np.mean(iou_scores)}")
print(f"mean Dice:{np.mean(dice_scores)}")
print(f"Iou 2.5%, 97.5:{iou_interval}")
print(f"Dice 2.5%, 97.5:{dice_interval}")

## Legacy

The code below was used while developing this notebook but is no longer used.

In [None]:
import torch.nn.functional as F
from typing import Tuple
from torch.nn import MSELoss

# Legacy loss object; in the visible notebook it is only referenced by the commented-out code at the end.
loss_fn = MSELoss()

def postprocess_masks(masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...], image_size=1024) -> torch.Tensor:
    """
    Upscale masks to the padded model resolution, crop away the padding and
    resize the result to the original image size.

    Args:
      masks (torch.Tensor):
        Batched masks in BxCxHxW format.
      input_size (tuple(int, int)):
        (H, W) of the un-padded region inside the image_size x image_size
        canvas; everything beyond it is cropped.
      original_size (tuple(int, int)):
        (H, W) that the cropped masks are finally resized to.
      image_size (int):
        Side length of the square canvas the masks are first upscaled to.

    Returns:
      (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
        is given by original_size.
    """
    bilinear = dict(mode="bilinear", align_corners=False)
    canvas = F.interpolate(masks, (image_size, image_size), **bilinear)
    # Keep only the rows/columns that held image content before padding.
    valid_height, valid_width = input_size[0], input_size[1]
    cropped = canvas[..., :valid_height, :valid_width]
    return F.interpolate(cropped, original_size, **bilinear)

In [None]:
# upscaled_masks = postprocess_masks(low_res_masks.squeeze(1), batch["reshaped_input_sizes"][0].tolist(), batch["original_sizes"][0].tolist()).to(device)
# predicted_masks = normalize(threshold(upscaled_masks, 0.0, 0)).squeeze(1)
# loss = loss_fn(predicted_masks, ground_truth_masks)