In [1]:
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader as TorchDataset
from torch.optim import Adam
import wandb
import monai
from tqdm import tqdm
import numpy as np
import datasets
from transformers import SamProcessor, SamModel
from statistics import mean
import torch
from torch.nn.utils.rnn import pad_sequence 
import torch.nn.functional as F
import cv2
import random
import time
from scipy.ndimage import label
import evaluate
# Model info
# base_models = ["facebook/sam-vit-base", "facebook/sam-vit-huge", "facebook/sam-vit-large", "wanglab/medsam-vit-base"]
base_model = "facebook/sam-vit-base"
checkpoint_path = "/vol/data/models/custom_24-01-24_23.49.46"
dataset_path = "/vol/data/datasets/processed/custom/default_preprocessed_at_24-01-10_13.41.28"
pseudocolor = "grayscale"
prompt_type = "bboxes"


###
def evaluate_metrics(model, dataset, config, processor):
    model.eval()
    metric = evaluate.load("mean_iou")
    segmentations = []
    ground_truths = []
    for i in tqdm(range(3)):#len(dataset))):
        with torch.no_grad():
            if (config["prompt_type"]=="points"):
                image, points, gt_masks, mask_values = dataset[i]
                inputs = processor(image, input_points= [points], return_tensors="pt")
            else:
                image, bboxes, gt_masks, mask_values = dataset[i]
                inputs = processor(image, input_boxes=[bboxes], return_tensors="pt")
            outputs= model(**inputs, multimask_output=False)
            masks = torch.zeros(1, 14, 496, 512)
            masks = F.interpolate(outputs.pred_masks.squeeze(2), (1024,1024), mode="bilinear", align_corners=False)
            masks = masks[..., : inputs["reshaped_input_sizes"][0,0], : inputs["reshaped_input_sizes"][0,1]]
            masks = F.interpolate(masks, (inputs["original_sizes"][0,0],inputs["original_sizes"][0,1]), mode="bilinear", align_corners=False)
            masks = torch.sigmoid(masks).squeeze().numpy()
            binary_masks = (masks > 0.5).astype(np.uint8)
            for c in range(len(mask_values)):
                if mask_values[c] == 0 and c > 0:
                    break
                segmentations.append(binary_masks[c]* (mask_values[c]+1))
                ground_truths.append(gt_masks[c]* (mask_values[c]+1))
        
    metric_output = metric.compute(
        predictions=segmentations,
        references=ground_truths,
        ignore_index=255,
        num_labels=14,
        reduce_labels=True,
    )
    print(metric_output)
    f = open(config["results_path"], "w")
    f.write(str(metric_output))
    f.close()
###

class SAMDataset(TorchDataset):
    def __init__(self, dataset, processor, config):
        self.dataset = dataset
        self.processor = processor
        self.config = config

    def __len__(self):
        return len(self.dataset)

    def get_bboxes_and_gt_masks(self, ground_truth_mask):
        # get bounding boxes from mask
        structure = np.ones((3, 3), dtype=np.int32)
        bboxes, gt_masks = [],[]
        mask_values= np.unique(ground_truth_mask)
        final_mask_values = []
        #Comment for background prediction
        #mask_values = mask_values[1:]
        for v in mask_values: 
            binary_gt_mask = np.where(ground_truth_mask == v, 1.0, 0.0)
            labeled_gt_mask, ncomponents = label(binary_gt_mask, structure)
            for c in range(ncomponents):
                final_mask_values.append(v)
                x_indices, y_indices = np.where(labeled_gt_mask== c+1)
                x_min, x_max = np.min(x_indices), np.max(x_indices)
                y_min, y_max = np.min(y_indices), np.max(y_indices)
                # add perturbation to bounding box coordinates
                H, W = ground_truth_mask.shape
                x_min = max(0, x_min + np.random.randint(-10, 10))
                x_max = min(W, x_max + np.random.randint(-10, 10))
                y_min = max(0, y_min + np.random.randint(-10, 10))
                y_max = min(H, y_max + np.random.randint(-10, 10))
                bbox = [x_min, y_min, x_max, y_max]
                bboxes.append(bbox)
                gt_mask = np.where(labeled_gt_mask== c+1, 1.0, 0.0)
                gt_masks.append(gt_mask)
        return bboxes, gt_masks, final_mask_values

    def get_points_and_gt_masks(self, ground_truth_mask):
        structure = np.ones((3, 3), dtype=np.int32)
        points, gt_masks = [],[]
        mask_values= np.unique(ground_truth_mask)
        final_mask_values = []
        #Comment for background prediction
        #mask_values= mask_values[1:]
        for v in mask_values: 
            binary_gt_mask = np.where(ground_truth_mask == v, 1.0, 0.0)
            labeled_gt_mask, ncomponents = label(binary_gt_mask, structure)
            for c in range(ncomponents):
                final_mask_values.append(v)
                x_indices, y_indices = np.where(labeled_gt_mask== c+1)
                rand_idx = random.randrange(0, len(x_indices))
                points.append([[x_indices[rand_idx], y_indices[rand_idx]]])
                gt_mask = np.where(labeled_gt_mask== c+1, 1.0, 0.0)
                gt_masks.append(gt_mask)
        return points, gt_masks, final_mask_values

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = np.array(item["image"])
        if (self.config["pseudocolor"] != None):
            image = cv2.applyColorMap(image[:, :, 0], self.config["pseudocolor"])
        ground_truth_mask = np.array(item["label"])
        if (self.config["prompt_type"]=="points"):
            points, gt_masks, mask_values= self.get_points_and_gt_masks(ground_truth_mask)
            return [image, points, gt_masks, mask_values]
        else:
            bboxes, gt_masks, mask_values= self.get_bboxes_and_gt_masks(ground_truth_mask)
            return [image, bboxes, gt_masks, mask_values]

#Colormap
OCV_COLORMAPS = {
    "Autumn": cv2.COLORMAP_AUTUMN, 
    "Bone": cv2.COLORMAP_BONE,
    "Cividis": cv2.COLORMAP_CIVIDIS, 
    "Cool": cv2.COLORMAP_COOL, 
    "Deepgreen": cv2.COLORMAP_DEEPGREEN,
    "Hot": cv2.COLORMAP_HOT,
    "HSV": cv2.COLORMAP_HSV,
    "Inferno": cv2.COLORMAP_INFERNO,
    "Jet": cv2.COLORMAP_JET,
    "Magma": cv2.COLORMAP_MAGMA,
    "Ocean": cv2.COLORMAP_OCEAN,
    "Parula": cv2.COLORMAP_PARULA,
    "Pink": cv2.COLORMAP_PINK,
    "Plasma": cv2.COLORMAP_PLASMA,
    "Rainbow": cv2.COLORMAP_RAINBOW,
    "Viridis": cv2.COLORMAP_VIRIDIS,
    "Winter": cv2.COLORMAP_WINTER,
    "Spring": cv2.COLORMAP_SPRING,
    "Summer": cv2.COLORMAP_SUMMER,
    "Twilight shifted": cv2.COLORMAP_TWILIGHT_SHIFTED,
    "Twilight": cv2.COLORMAP_TWILIGHT,
    "Turbo": cv2.COLORMAP_TURBO,
    "grayscale": None
}
# mask_dict
mask_dict = (
    "background",
    "epiretinal membrane",
    "neurosensory retina",
    "intraretinal fluid",
    "subretinal fluid",
    "subretinal hyperreflective material",
    "retinal pigment epithelium",
    "pigment epithelial detachment",
    "posterior hyaloid membrane",
    "choroid border",
    "imaging artifacts",
    "fibrosis",
    "vitreous body",
    "image padding" 
)

processor = SamProcessor.from_pretrained(base_model)
model = SamModel.from_pretrained(base_model)
model.load_state_dict(torch.load(checkpoint_path +".pt"))
dataset = datasets.load_from_disk(dataset_path)["test"]
config ={
    "pseudocolor": OCV_COLORMAPS[pseudocolor],
    "prompt_type": prompt_type,
    "mask_dict": mask_dict,
    "results_path": checkpoint_path + ".txt"
}
dataset = SAMDataset(dataset=dataset, processor=processor, config=config)



In [46]:
model.eval()
metric = evaluate.load("mean_iou")
segmentations = []
ground_truths = []
for i in range(14):
    segmentations.append([])
    ground_truths.append([])
for i in tqdm(range(1)):#len(dataset))):
    with torch.no_grad():
        if (config["prompt_type"]=="points"):
            image, points, gt_masks, mask_values = dataset[i]
            inputs = processor(image, input_points= [points], return_tensors="pt")
        else:
            image, bboxes, gt_masks, mask_values = dataset[i]
            inputs = processor(image, input_boxes=[bboxes], return_tensors="pt")
        outputs= model(**inputs, multimask_output=False)
        masks = torch.zeros(1, 14, 496, 512)
        masks = F.interpolate(outputs.pred_masks.squeeze(2), (1024,1024), mode="bilinear", align_corners=False)
        masks = masks[..., : inputs["reshaped_input_sizes"][0,0], : inputs["reshaped_input_sizes"][0,1]]
        masks = F.interpolate(masks, (inputs["original_sizes"][0,0],inputs["original_sizes"][0,1]), mode="bilinear", align_corners=False)
        masks = torch.sigmoid(masks).squeeze().numpy()
        binary_masks = (masks > 0.5).astype(np.uint8)
        for c in range(len(mask_values)):
            if mask_values[c] == 0 and c > 0:
                break
            segmentations[mask_values[c]].append(binary_masks[c])
            ground_truths[mask_values[c]].append(gt_masks[c])


100%|██████████| 1/1 [00:06<00:00,  6.26s/it]


In [27]:
metric_output = metric.compute(
    predictions=segmentations[2],
    references=ground_truths[2],
    ignore_index=255,
    num_labels=1,
    reduce_labels=True,
)

In [26]:
metric_output

{'mean_iou': 0.8037700037802419,
 'mean_accuracy': 0.8037700037802419,
 'overall_accuracy': 0.8037700037802419,
 'per_category_iou': array([0.80377]),
 'per_category_accuracy': array([0.80377])}

In [47]:
for i in segmentations:
    print(len(i))

1
1
1
0
0
0
1
6
0
0
1
0
1
1


In [50]:
for i in range(14):
    print(metric.compute(
        predictions=segmentations[i],
        references=ground_truths[i],
        ignore_index=255,
        num_labels=1,
        reduce_labels=True,
    ))

{'mean_iou': 0.17971943715306166, 'mean_accuracy': 0.17971943715306166, 'overall_accuracy': 0.17971943715306166, 'per_category_iou': array([0.17971944]), 'per_category_accuracy': array([0.17971944])}
{'mean_iou': 1.0, 'mean_accuracy': 1.0, 'overall_accuracy': 1.0, 'per_category_iou': array([1.]), 'per_category_accuracy': array([1.])}
{'mean_iou': 1.0, 'mean_accuracy': 1.0, 'overall_accuracy': 1.0, 'per_category_iou': array([1.]), 'per_category_accuracy': array([1.])}
{'mean_iou': nan, 'mean_accuracy': nan, 'overall_accuracy': nan, 'per_category_iou': array([nan]), 'per_category_accuracy': array([nan])}
{'mean_iou': nan, 'mean_accuracy': nan, 'overall_accuracy': nan, 'per_category_iou': array([nan]), 'per_category_accuracy': array([nan])}
{'mean_iou': nan, 'mean_accuracy': nan, 'overall_accuracy': nan, 'per_category_iou': array([nan]), 'per_category_accuracy': array([nan])}


  all_acc = total_area_intersect.sum() / total_area_label.sum()
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  metrics["mean_iou"] = np.nanmean(iou)
  metrics["mean_accuracy"] = np.nanmean(acc)


{'mean_iou': 1.0, 'mean_accuracy': 1.0, 'overall_accuracy': 1.0, 'per_category_iou': array([1.]), 'per_category_accuracy': array([1.])}
{'mean_iou': 1.0, 'mean_accuracy': 1.0, 'overall_accuracy': 1.0, 'per_category_iou': array([1.]), 'per_category_accuracy': array([1.])}
{'mean_iou': nan, 'mean_accuracy': nan, 'overall_accuracy': nan, 'per_category_iou': array([nan]), 'per_category_accuracy': array([nan])}
{'mean_iou': nan, 'mean_accuracy': nan, 'overall_accuracy': nan, 'per_category_iou': array([nan]), 'per_category_accuracy': array([nan])}
{'mean_iou': 0.9261353957275732, 'mean_accuracy': 0.9261353957275732, 'overall_accuracy': 0.9261353957275732, 'per_category_iou': array([0.9261354]), 'per_category_accuracy': array([0.9261354])}
{'mean_iou': nan, 'mean_accuracy': nan, 'overall_accuracy': nan, 'per_category_iou': array([nan]), 'per_category_accuracy': array([nan])}
{'mean_iou': 0.4174541164423965, 'mean_accuracy': 0.4174541164423965, 'overall_accuracy': 0.4174541164423965, 'per_cate

In [53]:
(0.63437679+0.00345155+0.01173586+0.01183259+0.02157136+0.13978959+0.00600943+0.14160961+0.01343778+0.02645307+0.07045111+0.0290135+0.30911137+0.39061849) /14

0.12924729285714287

In [54]:
metric_output

{'mean_iou': 0.8037700037802419,
 'mean_accuracy': 0.8037700037802419,
 'overall_accuracy': 0.8037700037802419,
 'per_category_iou': array([0.80377]),
 'per_category_accuracy': array([0.80377])}

In [56]:
metric_output['per_category_accuracy'][0]

0.8037700037802419

In [60]:
x = np.zeros(14)

In [61]:
x[0]

0.0

In [65]:
str(list(x))

'[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]'