# Dataset

In [3]:
import sys
sys.path.append('..')
import torch
import matplotlib.pyplot as plt
import numpy as np
from utils.datasets import COCOSegmentation

In [4]:
# Root of the COCO-2017 dataset and the split used in the "Dataset" section.
dataDir='../Datasets/coco-2017/'
dataType='val2017'
# dataDir already ends with '/', so strip it before formatting to avoid a
# doubled separator ("coco-2017//annotations/...") in the annotation path.
annFile='{}/annotations/instances_{}.json'.format(dataDir.rstrip('/'),dataType)

In [None]:
dataset = COCOSegmentation(dataDir, 'val', crop_size=0)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True, worker_init_fn=None)

In [None]:
i, l, n = dataset[0]
i.shape, l.shape, n

In [None]:
for (i, l, n) in dataloader:
    print(i.shape, l.shape, n)
    break

In [None]:
plt.imshow(i[0])
plt.imshow(l[0], alpha=0.5)
plt.axis('off')
plt.show()
l.unique()

# Model

In [5]:
import requests

import matplotlib.pyplot as plt
from PIL import Image

import numpy as np
import torch
from utils.mobile_sam import sam_model_registry, SamPredictor

from utils import *

### Load model

In [6]:
# Seed every RNG the notebook draws from. The stdlib `random` module is used
# by get_prompt() when CENTER is False, so it has to be seeded as well.
import random

torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

# Preferred GPU index.
GPU = 3

# Only pick `cuda:{GPU}` when that index actually exists on this machine;
# otherwise fall back to the first GPU, or to the CPU when CUDA is unavailable.
if torch.cuda.is_available() and GPU < torch.cuda.device_count():
    device = torch.device(f"cuda:{GPU}")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [11]:
# Build the SAM variant registered under `model_type` from the checkpoint and
# wrap it in a SamPredictor for prompt-based inference.
# NOTE(review): the recorded output of this cell is a RuntimeError - the
# checkpoint has no "prompt_encoder.point_embeddings.4.weight" entry. Presumably
# the model definition in utils.mobile_sam expects a parameter this checkpoint
# does not contain; confirm which checkpoint matches before relying on the
# cells below (they all need `model` / `predictor`).
model_type = "vit_t"
sam_checkpoint = "../bin/mobile_sam.pt"

# Move to the selected device and switch to eval mode before building the predictor.
model = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device).eval()
predictor = SamPredictor(model)

RuntimeError: Error(s) in loading state_dict for Sam:
	Missing key(s) in state_dict: "prompt_encoder.point_embeddings.4.weight". 

In [10]:
model

Sam(
  (image_encoder): TinyViT(
    (patch_embed): PatchEmbed(
      (seq): Sequential(
        (0): Conv2d_BN(
          (c): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): GELU(approximate='none')
        (2): Conv2d_BN(
          (c): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (layers): ModuleList(
      (0): ConvLayer(
        (blocks): ModuleList(
          (0-1): 2 x MBConv(
            (conv1): Conv2d_BN(
              (c): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
            (act1): GELU(approximate='none')
            (conv2): Conv2d_BN(
 

In [None]:
# Persist the current weights of `model`.
# NOTE(review): the checkpoint above is loaded from "../bin/", while this writes
# to "bin/" relative to the notebook's working directory - confirm that this is
# the intended location and that the directory exists.
torch.save(model.state_dict(), 'bin/distilled_mobile_sam_online.pt')

### Get Input Image

In [None]:
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = np.array(Image.open(requests.get(img_url, stream=True).raw).convert("RGB"))
# plt.imshow(raw_image)
# plt.show()

In [None]:
raw_image.shape

### Run Inference with Prompt

In [None]:
input_points = np.array([[450, 600]])
input_label = np.array([1])

In [None]:
input_points = np.array([[450, 600]])
input_label = np.array([1])



with torch.no_grad():
    predictor.set_image(raw_image)
    output = predictor.predict(input_points, input_label, return_logits=True)

In [None]:
masks, scores, lowres = output

In [None]:
masks.shape, scores.shape, lowres.shape

In [None]:
plt.imshow(raw_image)
plt.imshow(masks[0], alpha=0.5)
plt.imshow(masks[1], alpha=0.5)
plt.imshow(masks[2], alpha=0.5)
plt.axis('off')
plt.show()

In [None]:
print(masks[0].min(), masks[0].max())

### Decode Output

In [None]:
plt.imshow(masks[np.argmax(scores)])
plt.show()

# Outputs

In [None]:
import random
import torch

In [None]:
CENTER = True

In [None]:
def get_output_masks(processor, model, i, input_points, device):
    """Run a point-prompted prediction on the first image of a batch.

    Args:
        processor: Unused; kept so existing call sites keep working.
        model: Predictor object exposing ``set_image(image)`` and
            ``predict(point_coords, point_labels)``. The notebook passes its
            ``SamPredictor`` here.
        i: Batched image tensor; only ``i[0]`` is used and it is cast to
            ``uint8`` (assumes an HxWxC image per sample - TODO confirm).
        input_points: Nested prompt as produced by ``get_prompt``;
            ``input_points[0]`` is the list of ``[x, y]`` points for the image.
        device: Unused; kept so existing call sites keep working.

    Returns:
        ``(masks, scores)`` exactly as returned by ``model.predict`` (its
        third output is discarded).
    """
    image = i[0].detach().cpu().numpy().astype(np.uint8)
    # Use the predictor that was passed in rather than a notebook-level global,
    # so the function works with whichever predictor the caller provides.
    model.set_image(image)
    point_coords = np.array(input_points[0])
    # One positive (foreground) label per prompt point.
    point_labels = np.ones(len(point_coords), dtype=int)
    masks, scores, _ = model.predict(point_coords, point_labels)
    return masks, scores
    
def get_prompt(name, label):
    """Pick a random foreground class in ``label`` and return one point on it.

    Args:
        name: Sample identifier. Currently unused: loading pre-computed
            prompts by name is not implemented yet.
        label: 2-D class-id map. Class id 0 is treated as background (it is
            ``'background'`` in this notebook's CLASSES list). Must be a torch
            tensor when ``CENTER`` is true.

    Returns:
        ``([[[x, y]]], c)`` where ``c`` is the sampled class id and the point is
        given column-first, i.e. swapped with respect to the (row, col) array
        index, which is what the predictor expects.

    Raises:
        ValueError: If ``label`` contains no foreground class.
    """
    # TODO: load pre-computed prompts for `name` instead of sampling.

    present = np.unique(label)
    # Drop background explicitly. Slicing off the first unique value would
    # discard a real class whenever the image has no background pixel.
    candidates = present[present != 0]
    if candidates.size == 0:
        raise ValueError("label contains no foreground class to prompt")
    c = np.random.choice(candidates)

    if CENTER:
        # Centroid of the class pixels. For non-convex regions the centroid
        # may fall outside the region itself.
        coords = torch.argwhere(label == int(c))
        x, y = coords.sum(0) / coords.shape[0]
        x, y = int(x), int(y)
    else:
        x_v, y_v = np.where(label == int(c))
        # randrange excludes the upper bound; randint(0, len) could return
        # len(x_v) and index one past the end.
        r = random.randrange(len(x_v))
        x, y = int(x_v[r]), int(y_v[r])
    return [[[y, x]]], c  # inverted to compensate different indexing

In [None]:
def get_masks(max_samples=3):
    """Sample one prompt per image and collect the best predicted mask.

    Iterates the notebook-level ``dataloader``, draws a prompt with
    ``get_prompt`` and runs ``get_output_masks`` with the notebook-level
    ``predictor``.

    Args:
        max_samples: Maximum number of batches to process. The default of 3
            matches the previous hard-coded limit; pass ``None`` to process
            the whole dataloader.

    Returns:
        ``(name_list, prompt_list, p_class_list, mask_list, score_list)``,
        one entry per processed image: integer image name, the ``[x, y]``
        prompt point, the prompted class id, the highest-scoring mask and its
        score.
    """
    from itertools import islice

    name_list, mask_list, score_list, prompt_list, p_class_list = [], [], [], [], []
    # islice stops after `max_samples` batches without fetching an extra one.
    for i, l, n in islice(dataloader, max_samples):

        prompt, p_class = get_prompt(n, l[0])
        # show_points_on_image(i[0], input_points[0])

        masks, scores = get_output_masks(None, predictor, i, prompt, device)
        # show_masks_on_image(i[0], masks, scores)
        best = int(scores.argmax())
        # Flatten any leading dimensions to (num_masks, H, W) before picking
        # the best one. A blanket squeeze() would also drop the mask axis when
        # only a single mask is returned.
        candidate_masks = masks.reshape(-1, *masks.shape[-2:])

        name_list.append(int(n[0]))
        mask_list.append(candidate_masks[best])
        score_list.append(float(scores.max()))
        prompt_list.append(prompt[0][0])
        p_class_list.append(int(p_class))

    return name_list, prompt_list, p_class_list, mask_list, score_list

In [None]:
name, prompt, p_class, mask, score = get_masks()

In [None]:
i = 1
name[i], prompt[i], p_class[i], mask[i].shape, score[i]

In [None]:
im = Image.open('../Datasets/coco-2017/val2017/' + str(name[i]).zfill(12) + '.jpg')
plt.imshow(im)
plt.imshow(mask[i], alpha=0.5)
print(dataset.classes[p_class[i]])
print(name[i])
plt.scatter(*prompt[i])
plt.show()

### Save DataFrame

In [None]:
import pandas as pd

In [None]:
df = pd.DataFrame({'name': name, 'prompt': prompt, 'class': p_class, 'mask': mask, 'score': score})

In [None]:
df.head()

In [None]:
df.info()

In [None]:
df.hist(column='class')

In [None]:
# The frame built above stores the prompt under 'prompt', while the prompt
# pickles are consumed with a 'point' column (see the lookup on the reloaded
# frame below). Rename on export so this selection does not raise a KeyError.
df.rename(columns={'prompt': 'point'})[['name', 'point', 'class']].to_pickle("results/coco_prompts.pkl")

In [None]:
# The in-memory frame names the prompt column 'prompt'; expose it as 'point'
# for display so this selection does not raise a KeyError.
df.rename(columns={'prompt': 'point'})[['name', 'point', 'class']]

In [None]:
df = pd.read_pickle("results/cityscapes_prompts.pkl")

In [None]:
df.head()

In [None]:
df[df['name']==632][['point', 'class']].values[0][1]

### Predicted Classes

In [None]:
N_CLASSES = 92

In [None]:
def get_instance(label, c=None):
    """Return the binary mask of one class in ``label``.

    Args:
        label: Class-id map. Class id 0 is treated as background.
        c: Class id to extract. When ``None``, a foreground class present in
            ``label`` is drawn with ``np.random.choice``.

    Returns:
        ``(label == c, c)``.

    Raises:
        ValueError: If ``c`` is ``None`` and ``label`` has no foreground class.
    """
    if c is None:
        present = np.unique(label)
        # Filter background by value. Dropping the first unique entry would
        # lose a real class whenever no background pixel is present.
        candidates = present[present != 0]
        if candidates.size == 0:
            raise ValueError("label contains no foreground class")
        c = np.random.choice(candidates)
    return label == c, c

def get_pred_classes(inst, label, n_classes, threshold=0.01):
    """List the classes that cover a relevant share of a predicted instance.

    Args:
        inst: Boolean instance mask (tensor or array); non-zero means inside.
        label: Class-id map, broadcastable against ``inst``.
        n_classes: Number of valid class ids; ids outside ``[0, n_classes)``
            are ignored.
        threshold: Minimum fraction of the valid in-mask pixels a class must
            exceed (strictly) to be reported.

    Returns:
        List of class ids (NumPy integers) in ascending order. Background
        (id 0) is reported like any other class. Empty when the mask contains
        no valid pixel.
    """
    inst_t = torch.as_tensor(inst).bool().cpu()
    label_t = torch.as_tensor(label).cpu()
    inst_t, label_t = torch.broadcast_tensors(inst_t, label_t)

    # Count labels only inside the mask. The previous approach added
    # `n_classes` to the outside pixels of a uint8 map, which wraps around
    # when n_classes + label > 255 (for example with a 255 "ignore" id) and
    # counts those outside pixels as in-mask classes.
    inside = label_t[inst_t].to(torch.int64)
    inside = inside[(inside >= 0) & (inside < n_classes)]

    clean_h = torch.bincount(inside, minlength=n_classes).numpy()
    mask_tot = np.sum(clean_h)
    classes = np.where(clean_h > threshold * mask_tot)[0]
    return list(classes)

### Test class threshold

In [None]:
# Label
l = torch.zeros((224,224), dtype=torch.uint8)
l[100:150, 50:100] = 35
l[145:150, 95:100] = 91
l[100:140, 160:200] = 60
l[100:115, 50:65] = 0

plt.imshow(l)
plt.show()

In [None]:
# Predicted instance
i = torch.zeros((224,224), dtype=bool)
i[100:150, 50:100] = True
plt.imshow(i)
plt.show()

In [None]:
get_pred_classes(i, l, N_CLASSES, 0.01)

In [None]:
# Modified instance
im = torch.logical_not(i).to(torch.uint8)
im[im==1] = N_CLASSES

plt.imshow(im)
plt.show()

In [None]:
im.min(), im.max()

In [None]:
# Mask (intersection)
m = im + l

plt.imshow(m)
plt.show()

In [None]:
m.unique()

In [None]:
h, _ = np.histogram(m, bins=256, range=(0,255))
h

In [None]:
clean_h = h[:N_CLASSES]
clean_h

In [None]:
mask_tot = np.sum(clean_h)
mask_tot

In [None]:
np.where(clean_h > 0.01 * mask_tot)

# Metrics

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import PIL.Image as Image
import matplotlib.pyplot as plt

from utils import show_points_and_masks_on_image

%matplotlib inline
%load_ext autoreload

In [None]:
# Result files are read from results/{EXPERIMENT}{DATASET}_*.pkl.
EXPERIMENT = ''
DATASET = 'coco'
MODEL = 'FastSAM'
# Image root used by get_image(); anything other than 'coco' selects Cityscapes.
ROOT = Path("../Datasets/coco-2017/val2017/") if DATASET == 'coco' else Path("../Datasets/Cityscapes/leftImg8bit/val/")
SPARSITY = 50
# COCO category names indexed by class id (0 is background, 92 entries in total).
# Index 37 is 'sports ball' in the COCO category list (it was truncated to 'sports').
CLASSES = ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
           'street sign', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
           'hat', 'backpack', 'umbrella', 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
           'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass', 'cup', 'fork', 'knife',
           'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
           'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk', 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
           'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
           'hair drier', 'toothbrush', 'hair brush']

In [None]:
def calculate_metrics(target, pred, eps=1e-5, verbose=False):
    """Compute overlap metrics between two binary masks.

    Args:
        target: Reference mask. Any array-like; non-zero counts as foreground.
        pred: Predicted mask with the same number of elements as ``target``.
        eps: Smoothing term added to every numerator and denominator so that
            empty masks do not divide by zero.
        verbose: When true, plot both masks side by side and print the metrics.

    Returns:
        Tuple ``(iou, pixel_acc, dice, precision, specificity, recall)``.
        Note that specificity comes before recall.

    Raises:
        ValueError: If the two masks do not have the same number of elements.
    """

    if verbose:
        plt.subplot(1, 2, 1)
        plt.imshow(target)
        plt.subplot(1, 2, 2)
        plt.imshow(pred)
        plt.show()

    # Coerce to bool before counting. On integer 0/1 masks `~` is a bitwise
    # NOT (giving 254/255) and `+` counts the intersection twice, so the
    # arithmetic formulation is only correct for genuine boolean arrays.
    output = np.asarray(pred).astype(bool).reshape(-1)
    target = np.asarray(target).astype(bool).reshape(-1)
    if output.shape != target.shape:
        raise ValueError(
            f"mask size mismatch: pred has {output.size} elements, target has {target.size}"
        )

    tp = np.sum(output & target)  # TP (Intersection)
    un = np.sum(output | target)  # Union
    fp = np.sum(output & ~target)  # FP
    fn = np.sum(~output & target)  # FN
    tn = np.sum(~output & ~target)  # TN

    iou = (tp + eps) / (un + eps)
    pixel_acc = (tp + tn + eps) / (tp + tn + fp + fn + eps)
    dice = (2 * tp + eps) / (2 * tp + fp + fn + eps)
    precision = (tp + eps) / (tp + fp + eps)
    recall = (tp + eps) / (tp + fn + eps)
    specificity = (tn + eps) / (tn + fp + eps)

    if verbose:
        print(f"IoU: {iou:.4f}, Pixel Acc: {pixel_acc:.4f}, Dice: {dice:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, Specificity: {specificity:.4f}")

    return iou, pixel_acc, dice, precision, specificity, recall

def get_analytics(target_df, pred_df):
    """Build a per-sample comparison table between two result frames.

    Rows are paired by position, so both frames are expected to list the same
    samples in the same order (NOTE(review): this is not checked here).

    Args:
        target_df: Reference results with columns ``name``, ``prompt``,
            ``class``, ``s_class``, ``mask`` and ``score``.
        pred_df: Results under evaluation with columns ``s_class``, ``mask``
            and ``score``.

    Returns:
        DataFrame with one row per row of ``target_df`` holding the sample
        metadata, the relative score and mask-size differences, and the
        metrics returned by ``calculate_metrics``.
    """
    metrics = {k: [] for k in ['name', 'prompt', 'class', 't_class', 's_class', 'score', 'score_diff', 'mask_size',
                               'mask_size_diff', 'iou', 'pixel_acc', 'dice', 'precision', 'recall', 'specificity']}
    for i in range(len(target_df)):
        # Positional access: `.loc[i]` is label-based and fails (or picks the
        # wrong row) when a frame does not carry the default 0..n-1 index,
        # for example after filtering.
        target = target_df.iloc[i]
        pred = pred_df.iloc[i]

        iou, pixel_acc, dice, precision, specificity, recall = calculate_metrics(target['mask'], pred['mask'])

        metrics['name'].append(target['name'])
        metrics['prompt'].append(target['prompt'])
        metrics['class'].append(target['class'])
        metrics['t_class'].append(target['s_class'])
        metrics['s_class'].append(pred['s_class'])
        metrics['score'].append(pred['score'])
        # Relative differences; the small constants guard against division by zero.
        metrics['score_diff'].append((pred['score'] - target['score']) / (target['score'] + 1e-5))
        # Mask size as the fraction of foreground pixels.
        p_size = np.mean(pred['mask'].astype('float'))
        t_size = np.mean(target['mask'].astype('float'))
        metrics['mask_size'].append(p_size)
        metrics['mask_size_diff'].append((p_size - t_size) / (t_size + 1e-3))
        metrics['iou'].append(iou)
        metrics['pixel_acc'].append(pixel_acc)
        metrics['dice'].append(dice)
        metrics['precision'].append(precision)
        metrics['recall'].append(recall)
        metrics['specificity'].append(specificity)

    return pd.DataFrame(metrics)

def get_labels(name):
    """Translate a class id, or a (possibly nested) list of ids, to title-cased names.

    A single id is looked up in ``CLASSES``; a list is mapped element by
    element and returned as a list of the same structure.
    """
    if not isinstance(name, list):
        return CLASSES[name].title()
    return list(map(get_labels, name))

def get_image(name):
    """Load the image for sample ``name`` from ``ROOT`` as an RGB array.

    For the 'coco' dataset the file is the name zero-padded to 12 characters
    plus '.jpg'. For any other dataset ``name`` is a file name stored in a
    sub-folder named after the part before its first underscore.
    """
    if DATASET != 'coco':
        image_path = ROOT.joinpath(f"{name.split('_')[0]}/{name}")
    else:
        image_path = ROOT.joinpath(f'{str(name).zfill(12)}.jpg')
    rgb = Image.open(image_path).convert("RGB")
    return np.array(rgb)

def show_entry(row, target_df, pred_df):
    """Plot one analytics row with both masks and print its summary.

    The image is loaded by the row's name. The first mask stored under that
    name in each frame is overlaid (prediction first, then target) together
    with the prompt point, followed by two summary lines.
    """
    sample_id = row['name']

    def _first_mask(frame):
        # First mask whose 'name' matches this row.
        return frame[frame['name'] == sample_id]['mask'].values[0]

    image = get_image(sample_id)
    target_mask = _first_mask(target_df)
    pred_mask = _first_mask(pred_df)
    show_points_and_masks_on_image(image, [pred_mask, target_mask], [row['prompt']])
    print(f'ID: {row["name"]}, PromptClass: {get_labels(row["class"])}, TargetClass: {get_labels(row["t_class"])}, PredClass: {get_labels(row["s_class"])},')
    print(f'ScoreDiff: {row["score_diff"]:.4f}, MaskSizeDiff: {row["mask_size_diff"]:.4f}, IoU: {row["iou"]:.4f}')
    
def show_samples(pie_df, target_df, pred_df, n=5):
    """Show the first ``n`` rows of ``pie_df`` through ``show_entry``."""
    print('Legend: Target -> Orange, Prediction -> Blue')
    for _, entry in pie_df.iloc[:n].iterrows():
        show_entry(entry, target_df, pred_df)

In [None]:
df_p = pd.read_pickle(f"results/{EXPERIMENT}{DATASET}_prompts.pkl")
df_0 = pd.read_pickle(f"results/{EXPERIMENT}{DATASET}_SAM_0.pkl")
df_s = pd.read_pickle(f"results/{EXPERIMENT}{DATASET}_{MODEL}_0.pkl")
df_0.head()

In [None]:
df_0s = get_analytics(df_0, df_s)
df_0s.head()

In [None]:
min_size = df_0s.nsmallest(25, ['mask_size_diff'])
max_size = df_0s.nlargest(25, ['mask_size_diff'])
min_score = df_0s.nsmallest(25, ['score_diff']) # not very useful
max_score = df_0s.nlargest(25, ['score_diff']) # not very useful
min_iou = df_0s.nsmallest(25, ['iou'])
max_iou = df_0s.nlargest(25, ['iou'])
min_size.head()

In [None]:
max_iou.head()

In [None]:
show_samples(min_size, df_0, df_s, 20)

In [None]:
df_s[df_s['score']<=0.1]

In [None]:
df_s.hist(column='score')

# Test

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch

In [None]:
m = torch.zeros((224,224), dtype=bool)
o = torch.ones((50,50), dtype=bool)
m[100:150, :50] = o

In [None]:
plt.imshow(m)
x, y = torch.argwhere(m==1).sum(0)/torch.sum(m)
x, y = int(x), int(y)
print(x, y)
plt.scatter(y, x, color='red')
plt.show()