In [2]:
%matplotlib inline
import os
import sys
from functools import partial

import torch
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from open_clip import create_model_and_transforms, get_tokenizer

import numpy as np
from torchmetrics import JaccardIndex

from dataset import PartImageNetWithMask, PredictedMask
from utils import TextFeatures, get_masked_pred_c, get_masked_pred_f

In [3]:
# Make cuda:0 the default CUDA device for this kernel.
device = "cuda:0"
torch.cuda.set_device(device)

# OpenAI-pretrained CLIP ViT-B/16. open_clip returns (model, train transform,
# eval transform); the train transform is discarded.
clip, _, clip_transform = create_model_and_transforms('ViT-B-16', pretrained='openai')
clip = clip.to(device)
tokenizer = get_tokenizer('ViT-B-16')

# Last step of CLIP's eval preprocessing pipeline, kept so it can be applied separately
# (presumably the Normalize step — TODO confirm).
normalize = clip_transform.transforms[-1]

# Shared geometry for images and segmentation maps: resize the shorter side to 224,
# then center-crop to 224x224. Images use bicubic interpolation; segmentation maps
# use nearest so label ids are never blended into new values.
img_transform, seg_transform = (
    T.Compose([
        T.Resize(224, interpolation=mode),
        T.CenterCrop([224, 224]),
    ])
    for mode in (InterpolationMode.BICUBIC, InterpolationMode.NEAREST)
)

In [4]:
DATA_ROOT = '../data/PartImageNet/'
SAVE_ROOT = '../pred_segs/'

# Which model's saved segment predictions to evaluate under SAVE_ROOT.
# Alternative: "vit_base".
model_name = "cast_base"

# PartImageNet validation images and their annotation file.
img_root, ano_root = (
    os.path.join(DATA_ROOT, sub_path)
    for sub_path in ('images/val', 'annotations/val.json')
)

# Predicted segments: level4 is scored against the coarse classes,
# level3 against the fine (part) classes.
pred_c_root, pred_f_root = (
    os.path.join(SAVE_ROOT, model_name, level)
    for level in ('level4', 'level3')
)

# Ground truth; items are (image, seg_c, seg_f).
dataset = PartImageNetWithMask(img_root, ano_root, clip_transform, seg_transform)

# Segment masks predicted by CAST or ViT, indexed in parallel with `dataset`.
mask_dataset_c = PredictedMask(pred_c_root, ano_root)
mask_dataset_f = PredictedMask(pred_f_root, ano_root)

loading annotations into memory...
Done (t=0.31s)
creating index...
index created!
loading annotations into memory...
Done (t=0.13s)
creating index...
index created!
loading annotations into memory...
Done (t=0.25s)
creating index...
index created!


In [5]:
# Show the coarse class names, then the fine (part) class names, one list per line.
print(dataset.classname_c, dataset.classname_f, sep="\n")

# CLIP text embeddings for both label vocabularies.
text_features = TextFeatures(
    clip,
    tokenizer,
    dataset.classname_c,
    dataset.classname_f,
)

# Optional (disabled): group fine part names under their coarse class to build
# colormaps with create_colormap (not imported in this notebook):
#   names = {c: [f for f in dataset.classname_f if c in f] for c in dataset.classname_c}
#   cmap_c, cmap_f = create_colormap(names)

['Quadruped', 'Biped', 'Fish', 'Bird', 'Snake', 'Reptile', 'Car', 'Bicycle', 'Boat', 'Aeroplane', 'Bottle']
['Quadruped Head', 'Quadruped Body', 'Quadruped Foot', 'Quadruped Tail', 'Biped Head', 'Biped Body', 'Biped Hand', 'Biped Foot', 'Biped Tail', 'Fish Head', 'Fish Body', 'Fish Fin', 'Fish Tail', 'Bird Head', 'Bird Body', 'Bird Wing', 'Bird Foot', 'Bird Tail', 'Snake Head', 'Snake Body', 'Reptile Head', 'Reptile Body', 'Reptile Foot', 'Reptile Tail', 'Car Body', 'Car Tier', 'Car Side Mirror', 'Bicycle Body', 'Bicycle Head', 'Bicycle Seat', 'Bicycle Tier', 'Boat Body', 'Boat Sail', 'Aeroplane Head', 'Aeroplane Body', 'Aeroplane Engine', 'Aeroplane Wing', 'Aeroplane Tail', 'Bottle Mouth', 'Bottle Body']


In [6]:
def print_values(num_done=None, scores_c=None, scores_f=None):
    """Print progress and the running mean per-image IoU (in %) for coarse / fine.

    Output format: ``<done>/<total>    <mean coarse IoU>/<mean fine IoU>``.

    Args:
        num_done: number of samples processed so far; defaults to the global
            ``index + 1`` (the original zero-argument behaviour).
        scores_c: per-image coarse IoU scores; defaults to the global ``accs_c``.
        scores_f: per-image fine IoU scores; defaults to the global ``accs_f``.
    """
    if num_done is None:
        num_done = index + 1
    scores_c = accs_c if scores_c is None else scores_c
    scores_f = accs_f if scores_f is None else scores_f
    # np.mean([]) emits a RuntimeWarning and returns nan; report nan explicitly.
    mean_c = np.mean(scores_c) * 100 if len(scores_c) else float("nan")
    mean_f = np.mean(scores_f) * 100 if len(scores_f) else float("nan")
    print("{:d}/{:d}    {:.2f}/{:.2f}".format(
        num_done, len(dataset),
        mean_c, mean_f,
    ))

# Number of classes = named classes + 1 extra label (presumably background —
# TODO confirm against dataset.py). Gives 11+1 / 40+1 for PartImageNet.
jaccard_c = JaccardIndex(task="multiclass", num_classes=len(dataset.classname_c) + 1)
jaccard_f = JaccardIndex(task="multiclass", num_classes=len(dataset.classname_f) + 1)
# NOTE(review): the metrics live on CPU; if the helpers return CUDA tensors,
# consider `.to(device)` here. Consider torch.no_grad() around the forward
# passes if the helpers don't already disable autograd.

accs_c, accs_f = [], []  # per-image IoU scores (not accuracies): coarse / fine
failures = []            # (index, repr(exception)) for samples that could not be scored
for index in range(len(dataset)):
    try:
        img, seg_c, seg_f = dataset[index]

        mask_c = mask_dataset_c[index]
        mask_f = mask_dataset_f[index]

        pred_c = get_masked_pred_c(clip, text_features, img, mask_c)
        pred_f = get_masked_pred_f(clip, text_features, img, mask_f, pred_c)

        # Score both levels before appending anything, so a failure in either
        # keeps accs_c and accs_f aligned sample-by-sample.
        iou_c = jaccard_c(pred_c, seg_c).item()
        iou_f = jaccard_f(pred_f, seg_f).item()
    except Exception as exc:  # not a bare except: KeyboardInterrupt must still stop the loop
        # Store the repr, not the exception, so tracebacks don't pin tensors in memory.
        failures.append((index, repr(exc)))
    else:
        accs_c.append(iou_c)
        accs_f.append(iou_f)

    if (index + 1) % 100 == 0:
        print_values(index + 1, accs_c, accs_f)

print_values(len(dataset), accs_c, accs_f)
if failures:
    print("skipped {:d}/{:d} samples that raised; first failure at index {:d}: {}".format(
        len(failures), len(dataset), failures[0][0], failures[0][1]))

100/2957    35.98/16.58
200/2957    34.37/15.45
300/2957    29.89/13.23
400/2957    27.48/12.28
500/2957    26.26/11.86
600/2957    25.83/12.38
700/2957    25.73/12.83
800/2957    25.83/13.35
900/2957    25.42/13.66
1000/2957    25.21/13.46
1100/2957    24.49/13.00
1200/2957    24.17/12.77
1300/2957    23.80/12.52
1400/2957    24.17/12.59
1500/2957    24.47/12.64
1600/2957    24.83/12.69
1700/2957    25.41/12.78
1800/2957    26.01/12.85
1900/2957    26.60/12.91
2000/2957    27.19/13.02
2100/2957    27.68/13.08
2200/2957    28.24/13.16
2300/2957    28.54/13.16
2400/2957    29.03/13.25
2500/2957    29.37/13.28
2600/2957    29.62/13.25
2700/2957    29.71/13.23
2800/2957    30.03/13.30
2900/2957    29.83/13.23
2957/2957    29.67/13.21
