In [None]:
%matplotlib inline
import os
import sys
from functools import partial

import torch
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from open_clip import create_model_and_transforms, get_tokenizer

import numpy as np
from torchmetrics import JaccardIndex
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator

from dataset import PartImageNetWithMask, PredictedMask
from utils import TextFeatures, get_masked_pred_sam_c, get_masked_pred_sam_f
from utils import create_colormap, visualize_img, visualize_seg

In [None]:
# All models in this notebook run on the first CUDA device.
device = "cuda:0"
torch.cuda.set_device(device)

# OpenAI-pretrained CLIP ViT-B/16 plus its tokenizer. The second return value
# of create_model_and_transforms is discarded; the third is kept as the
# preprocessing pipeline applied to every image before it is fed to CLIP.
clip, _, clip_transform = create_model_and_transforms('ViT-B-16', pretrained='openai')
tokenizer = get_tokenizer('ViT-B-16')
clip = clip.to(device)

# Last stage of the CLIP preprocessing pipeline
# (presumably the Normalize op -- TODO confirm).
normalize = clip_transform.transforms[-1]


def _resize_center_crop(interpolation):
    """Build a 224-resize followed by a 224x224 center crop with the given interpolation."""
    return T.Compose([
        T.Resize(224, interpolation=interpolation),
        T.CenterCrop([224, 224]),
    ])


# Images are resampled bicubically; segmentation maps use nearest-neighbour so
# that resizing never blends two label ids into a new value.
img_transform = _resize_center_crop(InterpolationMode.BICUBIC)
seg_transform = _resize_center_crop(InterpolationMode.NEAREST)

In [None]:
# Which SAM backbone to load and where its checkpoint lives.
# Lighter alternative (swap with the two assignments below):
# SAM_MODEL = "vit_b"
# SAM_CKPT_PATH = os.path.join('../sam_ckpt', 'sam_vit_b_01ec64.pth')
SAM_MODEL = "vit_h"
SAM_CKPT_PATH = os.path.join('../sam_ckpt', 'sam_vit_h_4b8939.pth')

# Look up the builder for the chosen backbone, load the checkpoint, and move
# the resulting model onto the same device as CLIP.
build_sam = sam_model_registry[SAM_MODEL]
sam = build_sam(checkpoint=SAM_CKPT_PATH)
sam.to(device=device)

# Two front-ends sharing the same SAM weights: a predictor and an automatic
# mask generator (the latter is what the evaluation loop calls).
mask_generator = SamAutomaticMaskGenerator(sam)
predictor = SamPredictor(sam)

In [None]:
# Locations of the PartImageNet data and of the directory for predicted masks.
DATA_ROOT = '../data/PartImageNet/'
SAVE_ROOT = '../pred_segs/'

# Validation images and their annotation file, both relative to DATA_ROOT.
img_root, ano_root = (
    os.path.join(DATA_ROOT, relative)
    for relative in ('images/val', 'annotations/val.json')
)

# Each item unpacks as (image, seg_c, seg_f). Only the base resize/crop
# transforms are applied here; CLIP preprocessing happens later, per sample.
dataset = PartImageNetWithMask(img_root, ano_root, img_transform, seg_transform)

In [None]:
# Show the coarse and fine class-name lists carried by the dataset.
print(dataset.classname_c)
print(dataset.classname_f)

# Text-side features for both label sets, built with the CLIP model and tokenizer.
text_features = TextFeatures(clip, tokenizer, dataset.classname_c, dataset.classname_f)

# Map every coarse class name to the fine class names that contain it as a
# substring, then derive one colormap per granularity from that grouping.
names = {
    coarse: [fine for fine in dataset.classname_f if coarse in fine]
    for coarse in dataset.classname_c
}
cmap_c, cmap_f = create_colormap(names)

In [None]:
def print_values():
    """Print a one-line progress report.

    Reads the notebook-level names ``index``, ``dataset``, ``accs_c`` and
    ``accs_f`` and prints ``<processed>/<total>    <coarse>/<fine>``, where the
    last two values are the means of the collected scores scaled to percent.
    """
    processed = index + 1
    total = len(dataset)
    mean_c = np.mean(accs_c) * 100
    mean_f = np.mean(accs_f) * 100
    line = "{:d}/{:d}    {:.2f}/{:.2f}".format(processed, total, mean_c, mean_f)
    print(line)

# Jaccard (IoU) metrics for the coarse and fine label sets. The "+1" is
# presumably an extra background class -- TODO confirm against the dataset.
jaccard_c = JaccardIndex(task="multiclass", num_classes=11+1)
jaccard_f = JaccardIndex(task="multiclass", num_classes=40+1)

accs_c, accs_f = [], []
failures = []  # (dataset index, repr of the error) for every sample that was skipped
for index in range(len(dataset)):
    try:
        img_base, seg_c, seg_f = dataset[index]
        img = clip_transform(img_base)

        # SAM proposes masks on the base (resized/cropped, not CLIP-preprocessed)
        # image; they are ordered from largest to smallest area.
        masks = mask_generator.generate(np.array(img_base))
        masks = sorted(masks, key=lambda m: m["area"], reverse=True)
        masks = [torch.from_numpy(m['segmentation']).unsqueeze(0) for m in masks]
        mask_c = mask_f = masks

        # The fine prediction receives the coarse prediction as an extra input.
        pred_c = get_masked_pred_sam_c(clip, text_features, img, mask_c)
        pred_f = get_masked_pred_sam_f(clip, text_features, img, mask_f, pred_c)

        score_c = jaccard_c(pred_c, seg_c).item()
        score_f = jaccard_f(pred_f, seg_f).item()
    except Exception as exc:
        # Best-effort evaluation: skip the sample but remember why. Catching
        # Exception (not a bare except) keeps Ctrl-C / kernel interrupt working.
        failures.append((index, repr(exc)))
    else:
        # Record both scores together so the two running means always cover
        # exactly the same set of samples.
        accs_c.append(score_c)
        accs_f.append(score_f)

    if (index + 1) % 100 == 0:
        print_values()

# print_values reads the notebook-level `index`; pin it to the last sample so
# the final report is correct even if the loop body never ran.
index = len(dataset) - 1
print_values()

if failures:
    first_index, first_error = failures[0]
    print("{:d} sample(s) skipped due to errors; first at index {:d}: {}".format(
        len(failures), first_index, first_error,
    ))