# Evaluate MSA finetuning on iqs_dv
A notebook that evaluates a fine tuned MSA model on the EVICAN test sets. Metrics are saved to a .json file and Model predictions are saved as plots to .pngs.

## Imports

In [2]:
# Import
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.transforms as transforms

# Prompting imports
from prompter import *

# Metric imports
from torchmetrics import Accuracy, Precision, Recall, JaccardIndex, F1Score
from metrics import MeanIoU, PanopticQuality
from torchmetrics.detection.mean_ap import MeanAveragePrecision
calc_iou_matrix = MeanIoU._calc_iou_matrix

# Augmentation imports
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2

# Model
from models.sam import SamPredictor, sam_model_registry

# Data
from dataset import *
from torch.utils.data import DataLoader

from utils import *


## Path setup

In [3]:
# Put the original SAM weights here. 
# They can be downloaded from: 
# https://github.com/facebookresearch/segment-anything#model-checkpoints
SAM_CKPT = '/home/zozchaab/checkpoint/sam/sam_vit_b_01ec64.pth'

# Put your trained SAM model here
TUNED_MODEL_CKPT = '/home/zozchaab/Medical-SAM-Adapter/logs/random_sampling_per_component_2024_01_14_21_06_28/Model/checkpoint_best.pth'

# Data directory needed, where EVICAN should be saved to /loaded from
DATA_DIR = '/home/zozchaab/data'

# Define your save directory here
SAVE_ROOT = '/home/zozchaab/Medical_SAM_Adapter_Training/Evaluation/iqs_dv'
os.makedirs(SAVE_ROOT, exist_ok=True)

## Model Setup

In [4]:
# Make sure GPU is available
device = 0
torch.cuda.get_device_properties(device)

_CudaDeviceProperties(name='GeForce RTX 3090', major=8, minor=6, total_memory=24268MB, multi_processor_count=82)

Please choose the checkpoint to be used here !!

In [11]:
# This is needed to circumvent the argparser in the original repo
class Args:
    thd = False
    image_size = 120
    crop_size = 120
    data_path = "/home/zozchaab/data/deepvision/deepvision"
    b = 2
    w=0
args = Args()

#tuned_ckpt = torch.load(TUNED_MODEL_CKPT)
sam = sam_model_registry['vit_b'](args, checkpoint=SAM_CKPT)
#sam.load_state_dict(tuned_ckpt['state_dict'])
sam.eval()
sam = sam.to(device)
predictor = SamPredictor(sam)

## Data Setup

In [12]:

transform_2d = transforms.Compose([

lambda x: x.expand(3, -1, -1),
transforms.Lambda(lambda x: x / 65535.0)

])

test_dataset = iqs_dv(data_path=os.path.join(args.data_path,'iqs_dv_test'),crop_size=args.crop_size,transform_2D=transform_2d)

nice_test_loader = DataLoader(
test_dataset,
batch_size=args.b,
shuffle=False,
num_workers=args.w,
collate_fn=collate_fn
) 


## Prompt Setup

In [9]:
N_POINTS_MAX = 1 # Define the number of intial prompts generated using gaussian sampling
N_MAX_ITER_PROMPTS = 9 # Define the number of iterative prompts

## Test Loop

In [17]:
metrics_out = {}
vis_path = "home/zozchaab/vis"
# Iterate over the three dataset difficulties


folder = os.path.join(SAVE_ROOT, f'{N_POINTS_MAX}_{N_MAX_ITER_PROMPTS}')
os.makedirs(folder, exist_ok=True)

# Create metrics
metrics = [
    Accuracy(task='binary').to(device), 
    Precision(task='binary').to(device), 
    Recall(task='binary').to(device), 
    F1Score(task='binary').to(device), 
    JaccardIndex(task='binary').to(device),
]
miou = MeanIoU('optimal', False, False).to(device)
pq = PanopticQuality().to(device)
ind =0
with torch.no_grad():
    with tqdm(total=len(test_dataset), unit='img') as pbar:
        for pack in nice_test_loader:
            preds = []
            scores = []
            prompts = []
            original_preds = []
            imgs = pack['image'].to(dtype = torch.float32, device = device)
            targets = pack['label'].to(dtype = torch.float32, device = device)
            names = pack['metadata']
            
            for img, mask in zip(imgs, targets):
                img_emb = sam.image_encoder(img.unsqueeze(0))
                # Create Prompts
            
                # Randomly sample number of prompts
                #n_points = np.random.randint(1, N_POINTS_MAX)
                #n_pos = np.random.randint(1, n_points) if n_points > 1 else 1
                #n_neg = np.random.randint(0, n_points-n_pos) if (n_points - n_pos) > 0 else 0
                n_neg = 0
                n_pos = 5
                pts, lbls = sample_from_mask(mask.squeeze(0), mode="random", n_pos=n_pos,n_neg = n_neg) 
    

                user_iter = 0 
                # Randomly add pseudo user input 
                #user_iter = np.random.randint(N_MAX_ITER_PROMPTS)
                for i in range(user_iter):
                    # print(f'User interaction {i+1}/{user_iter}')
                    
                    # Set prompt
                    prompt = (pts.unsqueeze(0).to(device), lbls.unsqueeze(0).to(device))
                    se, de = sam.prompt_encoder(
                        points=prompt,
                        boxes=None,
                        masks=None,
                    ) # type: ignore
                    
                    # Predict Mask
                    pred, _ = sam.mask_decoder(
                        image_embeddings=img_emb,
                        image_pe=sam.prompt_encoder.get_dense_pe(),  # type: ignore
                        sparse_prompt_embeddings=se,
                        dense_prompt_embeddings=de, 
                        multimask_output=False,
                    ) # type: ignore
                    # Compare Prediction to GT
                    pred = F.interpolate(pred, mask.shape[-2:]) # This is a bit cumbersome, but the easiest fix for now
                    pred = pred.squeeze() > 0 #check this!!!
                    clusters = pred.cpu() != mask
                    # Sample point from largest error cluster 
                    new_prompt = find_best_new_prompt(clusters)
                    new_label = mask[new_prompt[0, 1], new_prompt[0, 0]].to(torch.int64)
                    pts = torch.concatenate([pts, new_prompt])
                    lbls = torch.concatenate([lbls, torch.tensor([new_label])])

                # Final Mask inference
                prompts.append([pts,lbls])
                prompt = (pts.unsqueeze(0).to(device), lbls.unsqueeze(0).to(device))

                # Set Prompt
                
                se, de = sam.prompt_encoder(
                    points=prompt,
                    boxes=None,
                    masks=None,
                ) # type: ignore

                # Predict Mask
                pred, score = sam.mask_decoder(
                    image_embeddings=img_emb,
                    image_pe=sam.prompt_encoder.get_dense_pe(),  # type: ignore
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de, 
                    multimask_output=False,
                ) # type: ignore
                original_preds.append((pred.squeeze(0) > 0).float())
                pred = F.interpolate(pred, mask.shape[-2:]).squeeze(0) # This is a bit cumbersome, but the easiest fix for now
                preds.append((pred > 0).float())
                scores.append(score)
                
                for m in metrics:
                    m.update(pred, mask.to(torch.uint8))
                pq.update(pred, mask)
                miou.update(pred, mask)
                
            scores = torch.stack(scores)
            
           
            vis = 0
            if vis:
                if ind % vis == 0:
                    visualize_batch(imgs=imgs, masks=targets, pred_masks=preds, names=names, prompts=prompts,original_preds=original_preds,save_path=vis_path)
            ind+=1

            pbar.update()

ms = {str(m): m.compute().item() for m in metrics}
metrics_out = {
    **ms,
    'miou': {k:v.item() for k, v in miou.compute().items()},
    'pq': {k:v.item() for k, v in pq.compute().items()},
}

import json
with open(os.path.join(SAVE_ROOT, f'MSA_{N_POINTS_MAX}_{N_MAX_ITER_PROMPTS}_metrics.json'), 'w') as f:
    json.dump(metrics_out, f)

 50%|████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                           | 87/173 [06:26<06:21,  4.44s/img]


In [16]:
metrics_out

{'BinaryAccuracy()': 0.8903571367263794,
 'BinaryPrecision()': 0.03192711994051933,
 'BinaryRecall()': 0.3515303134918213,
 'BinaryF1Score()': 0.05853765457868576,
 'BinaryJaccardIndex()': 0.03015132248401642,
 'miou': {'mIoU_micro': 0.010290836072619538,
  'mIoU_macro': 0.010290835984051228,
  'n_instances': 20760,
  'n_images': 20760},
 'pq': {'panoptic_quality': 0.0001517037017038092,
  'recognition_quality': 0.00028901733458042145,
  'segmentation_quality': 0.0001517037017038092}}

In [None]:
def crop(image,crop_size,top_left_x,top_left_y):
    # Crop the image
    return image[top_left_y:top_left_y + crop_size, top_left_x:top_left_x + crop_size]

def crop_image_and_mask(image, mask, crop_size):
    # Pad or crop the image to the target size
    h, w = image.shape[:2]

    if h < crop_size or w < crop_size:
        # Calculate padding needed
        pad_h = max(0, crop_size - h)
        pad_w = max(0, crop_size - w)

        # Calculate padding on each side
        top_pad = pad_h // 2
        bottom_pad = pad_h - top_pad
        left_pad = pad_w // 2
        right_pad = pad_w - left_pad

        # Pad the image
        image = F.pad(image, (left_pad, right_pad, top_pad, bottom_pad), mode='constant', value=0)
        
    elif h > crop_size or w > crop_size:
        # Randomly choose the top-left corner of the crop
        top_left_x = np.random.randint(0, w - crop_size + 1)
        top_left_y = np.random.randint(0, h - crop_size + 1)
  
        cropped_image = crop(image, crop_size,top_left_x,top_left_y)
        cropped_mask = crop(mask, crop_size, top_left_x, top_left_y)

    return cropped_image.unsqueeze(0), cropped_mask.unsqueeze(0)