In [None]:
%matplotlib inline
import os
import sys
from functools import partial

import torch
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from open_clip import create_model_and_transforms, get_tokenizer

import numpy as np
from torchmetrics import JaccardIndex

from dataset import PartImageNetWithMask, PredictedMask
from utils import TextFeatures, get_masked_pred_c, get_masked_pred_f

In [None]:
# Prefer the first CUDA GPU, but fall back to the CPU so that the notebook can
# still be executed (slowly) on a machine without CUDA instead of raising in
# torch.cuda.set_device.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
if device != "cpu":
    torch.cuda.set_device(device)

# OpenAI-pretrained CLIP ViT-B/16 and its matching tokenizer. The second value
# returned by create_model_and_transforms is not used in this notebook.
clip, _, clip_transform = create_model_and_transforms('ViT-B-16', pretrained='openai')
tokenizer = get_tokenizer('ViT-B-16')

# The notebook only runs inference; open_clip documents that models are
# returned in train mode, so switch to eval mode explicitly.
clip = clip.to(device).eval()

# Last step of the CLIP preprocessing pipeline
# (presumably the Normalize op -- confirm against clip_transform).
normalize = clip_transform.transforms[-1]

# Geometric part of the preprocessing only: resize the shorter side to 224
# with bicubic interpolation, then take the central 224x224 crop.
img_transform = T.Compose([
    T.Resize(224, interpolation=InterpolationMode.BICUBIC),
    T.CenterCrop([224, 224]),
])

# Same geometry for segmentation maps, but with nearest-neighbour resizing
# (presumably so that label ids are never blended -- confirm).
seg_transform = T.Compose([
    T.Resize(224, interpolation=InterpolationMode.NEAREST),
    T.CenterCrop([224, 224]),
])

In [None]:
# Root folders: the PartImageNet data and the previously saved predicted segments.
DATA_ROOT = '../data/PartImageNet/'
SAVE_ROOT = '../pred_segs/'

# Model whose predicted segments are evaluated ("vit_base" is the alternative).
model_name = "cast_base"

# Validation images and the matching annotation file.
img_root, ano_root = [
    os.path.join(DATA_ROOT, rel_path)
    for rel_path in ('images/val', 'annotations/val.json')
]

# Predicted-segment folders used for the "c" and "f" predictions respectively.
pred_c_root, pred_f_root = [
    os.path.join(SAVE_ROOT, model_name, level_dir)
    for level_dir in ('level4', 'level3')
]

# Each item of `dataset` is unpacked as (image, seg_c, seg_f).
dataset = PartImageNetWithMask(img_root, ano_root, clip_transform, seg_transform)

# Segments predicted by CAST or ViT, one dataset per level folder.
mask_dataset_c, mask_dataset_f = [
    PredictedMask(pred_root, ano_root)
    for pred_root in (pred_c_root, pred_f_root)
]

In [None]:
# Show both class-name lists of the dataset, one per line.
print(dataset.classname_c, dataset.classname_f, sep="\n")

# Build the text-side features from CLIP, its tokenizer and both class-name lists.
text_features = TextFeatures(
    clip,
    tokenizer,
    dataset.classname_c,
    dataset.classname_f,
)

# Disabled colormap setup, kept for reference. NOTE(review): create_colormap is
# not imported in the cells shown here, so it must be imported before enabling.
# names = {}
# for c in dataset.classname_c:
#     names[c] = [f for f in dataset.classname_f if c in f]
# cmap_c, cmap_f = create_colormap(names)

In [None]:
def print_values():
    """Print evaluation progress and the running mean scores.

    Reads the notebook globals ``index``, ``dataset``, ``accs_c`` and
    ``accs_f`` and prints ``<done>/<total>    <mean_c>/<mean_f>``, where the
    two means are the averages of the collected per-sample scores expressed
    as percentages with two decimals.
    """
    processed = index + 1
    total = len(dataset)
    mean_c = np.mean(accs_c) * 100
    mean_f = np.mean(accs_f) * 100
    print("{:d}/{:d}    {:.2f}/{:.2f}".format(processed, total, mean_c, mean_f))

# Multiclass Jaccard index (IoU) metrics for the two label sets.
# NOTE(review): 11+1 and 40+1 presumably count one extra background class on
# top of the dataset's class lists -- confirm against the annotations.
jaccard_c = JaccardIndex(task="multiclass", num_classes=11+1)
jaccard_f = JaccardIndex(task="multiclass", num_classes=40+1)

# Per-sample scores; both lists always have one entry per successfully
# evaluated sample, so they stay aligned with each other.
accs_c, accs_f = [], []
# (index, repr(exception)) for every sample that could not be evaluated.
failed_samples = []

for index in range(len(dataset)):
    try:
        img, seg_c, seg_f = dataset[index]

        mask_c = mask_dataset_c[index]
        mask_f = mask_dataset_f[index]

        # The "f" prediction is conditioned on the "c" prediction.
        pred_c = get_masked_pred_c(clip, text_features, img, mask_c)
        pred_f = get_masked_pred_f(clip, text_features, img, mask_f, pred_c)

        score_c = jaccard_c(pred_c, seg_c).item()
        score_f = jaccard_f(pred_f, seg_f).item()
    except Exception as exc:
        # Best effort: skip the sample but remember why. Catching Exception
        # (not a bare except) keeps Ctrl-C / kernel interrupt working.
        failed_samples.append((index, repr(exc)))
    else:
        # Append only when both scores exist so the lists never diverge.
        accs_c.append(score_c)
        accs_f.append(score_f)

    if (index + 1) % 100 == 0:
        print_values()

# Final summary over the whole dataset (also covers an empty dataset, where
# the loop never assigned `index`).
index = len(dataset) - 1
print_values()

if failed_samples:
    first_index, first_error = failed_samples[0]
    print("{:d} sample(s) skipped due to errors; first failure at index {:d}: {}".format(
        len(failed_samples), first_index, first_error,
    ))