Complete WordNet (WN) noun vocabulary

In [2]:
import sys
sys.path.append('/home/sheng/OSZSL/')
sys.path.append('/home/sheng/OSZSL/CLIP/')

import os
import json
import pickle
import math
from typing import Union, List
from pprint import pprint
from tqdm import tqdm
from copy import deepcopy
import numpy as np
from functools import reduce
from itertools import zip_longest
import seaborn as sns
from collections import Counter, defaultdict

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from ipynb_utils import get_hier_datasets, get_classifier
import clip
from data.datasets import get_datasets_oszsl, build_transform, Vocab

from nltk.corpus import wordnet as wn
from wordnet_utils import *

In [3]:
class Config:
    """Notebook-wide settings, passed as `args` to the project helpers below."""

    # --- model / runtime ---
    arch = 'ViT-B/16'
    device = 'cuda:3'
    clip_checkpoint = None  # optional path to fine-tuned CLIP weights

    # --- data ---
    dataset_name = 'make_nonliving26'
    n_sampled_classes = 100
    input_size = 224
    batch_size = 1024
    map_vocab_name = False

    # --- vocabulary / text classifier ---
    use_def = False
    vocab_fpath = 'wordnet_nouns_complete4:'  # pickle stem (no `.pkl`) passed to get_vocab
    f_classifier = './cache/wordnet_classifier_complete4:_small_def.pth'


args = Config()

In [4]:
# Prompt templates (JSON) used to build the text classifier.
with open('../templates_small_def.json', 'rb') as template_file:
    templates = json.load(template_file)

def get_vocab(fpath=None, root='/home/sheng/dataset'):
    """Load a pickled WordNet-noun vocabulary.

    Args:
        fpath: optional file stem (without ``.pkl``) under ``root``. When None,
            the default ``wordnet_nouns.pkl`` is loaded.
        root: directory holding the vocabulary pickles.

    Returns:
        The unpickled vocabulary object. Downstream cells index it with the
        keys ``'names'``, ``'ids'``, ``'parents'`` and ``'def'``.
    """
    # Only the requested file is read; previously the default pickle was always
    # loaded first and then discarded when `fpath` was given.
    stem = 'wordnet_nouns' if fpath is None else fpath
    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
    with open(f'{root}/{stem}.pkl', 'rb') as f:
        return pickle.load(f)

vocab = get_vocab(args.vocab_fpath)
# Keep only the 'imagenet' template set.
templates = templates['imagenet']
# Per-class vocabulary fields used by later cells.
classnames, parents, defs = vocab['names'], vocab['parents'], vocab['def']

In [5]:
# Wrap the raw vocabulary in the project's Vocab helper. Guarded so that
# re-running this cell does not wrap an already-wrapped Vocab a second time.
if not isinstance(vocab, Vocab):
    vocab = Vocab(vocab=vocab)

#### model

In [6]:
# NOTE(review): clip.load is called without `device`, so the weights may first be
# placed on its default device before the move to args.device below -- confirm intended.
model, preprocess = clip.load(args.arch)
if args.clip_checkpoint:
    # The checkpoint keeps CLIP weights under 'model' with keys prefixed 'model.'.
    # Strip the prefix only where it is present: slicing every key unconditionally
    # mangles keys without it, and strict=False would then silently skip them.
    prefix = 'model.'
    checkpoint_state = torch.load(args.clip_checkpoint, map_location='cpu')['model']
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint_state.items()
    }
    incompatible = model.load_state_dict(state_dict, strict=False)
    # strict=False tolerates mismatches; surface them instead of hiding them.
    print(f"Checkpoint: {len(incompatible.missing_keys)} missing, "
          f"{len(incompatible.unexpected_keys)} unexpected keys")
model.to(args.device).eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

n_params = sum(p.numel() for p in model.parameters())
print("Model parameters:", f"{n_params:,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

Model parameters: 149,620,737
Input resolution: 224
Context length: 77
Vocab size: 49408


#### dataset

In [7]:
# Evaluation-time transform, shared by both dataset views below.
transform_val = build_transform(is_train=False, args=args, train_config=None)
_dataset_kwargs = dict(
    is_train=True,
    transform=transform_val,
    seed=1,
    map_vocab_name=args.map_vocab_name,
)
# Same split twice: once without a vocabulary, once with `vocab` passed in.
dataset_raw = get_datasets_oszsl(args, None, **_dataset_kwargs)
dataset = get_datasets_oszsl(args, vocab, **_dataset_kwargs)

# shuffle=False: batches come out in dataset order.
loader_f = torch.utils.data.DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=8,
)

### construct classifier

#### wordnet classifier with definition

In [74]:
all_class_names = vocab.vocab['names']
all_class_defs = vocab.vocab['def']

# Encode every (class, template) prompt with CLIP's text encoder, average the
# per-template embeddings into one unit-norm vector per class, and cache the
# resulting classifier (one row per vocabulary entry, in vocabulary order).
batch_size = 128
# np.array_split rejects 0 sections, so a vocabulary smaller than one batch still
# gets a single chunk. The same count sizes the progress bar.
n_chunks = max(1, len(all_class_names) // batch_size)
zeroshot_weights = []
with torch.no_grad():
    for idx_set in tqdm(np.array_split(np.arange(len(all_class_names)), n_chunks), total=n_chunks):
        # Each template is filled with the class name and its definition. The
        # 77-character cut is only a heuristic for CLIP's 77-*token* context;
        # NOTE(review): clip.tokenize may still reject a prompt that encodes to
        # more tokens than its context length.
        texts = [
            template.format(all_class_names[idx], all_class_defs[idx])[:77]
            for idx in idx_set
            for template in templates
        ]
        tokens = clip.tokenize(texts).to(args.device)
        class_embeddings = model.encode_text(tokens).float()
        # (chunk_classes * n_templates, dim) -> (chunk_classes, n_templates, dim)
        class_embeddings = class_embeddings.view(-1, len(templates), class_embeddings.size(-1))
        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings.mean(dim=1)
        class_embedding /= class_embedding.norm(dim=-1, keepdim=True)
        zeroshot_weights.append(class_embedding.cpu())

classifier = torch.cat(zeroshot_weights, dim=0)
# Save to the path the evaluation cells read back through args.f_classifier.
torch.save(classifier, args.f_classifier)

641it [04:57,  2.15it/s]                         


#### LLM-augmented classifier

In [13]:
all_class_names = vocab.vocab['names']
all_class_defs = vocab.vocab['def']

# Cached text embeddings of LLM-augmented prompts, produced outside this notebook.
data_augmented_classifier = torch.load('./cache/all_aug_prompts_embed-wn-gpt3-c-2023_02_26.pth')
all_augmented_classifier, class_name_key_mapping = (
    data_augmented_classifier[key]
    for key in ('all_augmented_classifier', 'class_name_key_mapping')
)

### Baseline

#### raw evaluation

In [38]:
args.f_classifier = './cache/wordnet_classifier_complete4:_small_def.pth'
classifier = get_classifier(args)
# Unit-normalise rows so the dot product below is a cosine similarity. The
# explicit .to() is a no-op if get_classifier already placed it on args.device.
classifier = (classifier / classifier.norm(dim=-1, keepdim=True)).to(args.device)
amp_autocast = torch.cuda.amp.autocast

### collect variables
prob_k = 5
all_instance_voc_topk_ind = []
all_gt_label_voc = []
all_gt_label_clu = []
all_features = []
model.eval()
with torch.no_grad():
    for batch in tqdm(loader_f):
        images, label_voc, label_clu, idx_img = batch[:4]
        images = images.to(args.device)
        with amp_autocast():
            # `logits` here are L2-normalised image embeddings, not class scores.
            logits = model.visual(images)
            logits = logits / logits.norm(dim=-1, keepdim=True)
            similarity = 100 * logits @ classifier.t()
            prob = similarity.softmax(-1)
            prob_topk_ind = prob.topk(k=prob_k, dim=-1).indices
        all_instance_voc_topk_ind.append(prob_topk_ind.cpu().numpy())
        all_gt_label_voc.append(label_voc)
        all_gt_label_clu.append(label_clu)
        all_features.append(logits.cpu().numpy())

all_instance_voc_topk_ind = np.concatenate(all_instance_voc_topk_ind)
# Flatten per-batch label lists in one linear pass (reduce(+) re-copies the
# accumulated list at every step).
all_gt_label_voc = np.array([label for batch_labels in all_gt_label_voc for label in batch_labels])
all_gt_label_clu = torch.cat(all_gt_label_clu, dim=0)
all_features = np.concatenate(all_features)

100%|██████████| 130/130 [04:38<00:00,  2.15s/it]


In [40]:
# Top-1 accuracy: predicted class name (first top-k index) vs. ground-truth name.
top1_names = np.array(classnames)[all_instance_voc_topk_ind[:, 0]]
np.mean(top1_names == np.array(all_gt_label_voc))

0.24331713930629306

In [35]:
# Flatten per-batch label lists into one flat list. Guarded: the evaluation
# cells already flatten `all_gt_label_voc`, and re-flattening a flat sequence
# of strings with `+` would concatenate every label into a single string.
if len(all_gt_label_voc) and isinstance(all_gt_label_voc[0], (list, tuple)):
    all_gt_label_voc = [label for batch_labels in all_gt_label_voc for label in batch_labels]
else:
    all_gt_label_voc = list(all_gt_label_voc)

#### raw evaluation with LLM-augmented classifier

In [14]:
def get_aug_classifier(all_synset_names, all_augmented_classifier):
    """Stack per-synset augmented text embeddings in `all_synset_names` order.

    Args:
        all_synset_names: iterable of keys into `all_augmented_classifier`
            (synset names in the calling cell).
        all_augmented_classifier: mapping from name to tensor; all looked-up
            tensors must share one shape for stacking.

    Returns:
        A tensor whose new leading dimension indexes `all_synset_names`.
    """
    per_synset = [all_augmented_classifier[name] for name in all_synset_names]
    return torch.stack(per_synset, dim=0)

In [15]:
amp_autocast = torch.cuda.amp.autocast
### get aug classifier in sequential order (same order as vocab.vocab['ids'])
all_synset_names = [mapping_ids_synset(synset_id).name() for synset_id in vocab.vocab['ids']]
all_aug_classifier = get_aug_classifier(all_synset_names, all_augmented_classifier).float()
### normalization: unit-normalise each augmented embedding, average over dim 1
### (presumably the per-class prompt axis -- confirm), then re-normalise the mean.
all_aug_classifier = F.normalize(F.normalize(all_aug_classifier, dim=-1).mean(dim=1), dim=-1)
all_aug_classifier = all_aug_classifier.to(args.device)

### collect variables
prob_k = 5
all_instance_voc_topk_ind = []
all_gt_label_voc = []
all_gt_label_clu = []
all_features = []
model.eval()
with torch.no_grad():
    for batch in tqdm(loader_f):
        images, label_voc, label_clu, idx_img = batch[:4]
        images = images.to(args.device)
        with amp_autocast():
            # `logits` here are L2-normalised image embeddings, not class scores.
            logits = model.visual(images)
            logits = logits / logits.norm(dim=-1, keepdim=True)
            similarity = 100 * logits.float() @ all_aug_classifier.t()
            prob = similarity.softmax(-1)
            prob_topk_ind = prob.topk(k=prob_k, dim=-1).indices
        all_instance_voc_topk_ind.append(prob_topk_ind.cpu().numpy())
        all_gt_label_voc.append(label_voc)
        all_gt_label_clu.append(label_clu)
        all_features.append(logits.cpu().numpy())

all_instance_voc_topk_ind = np.concatenate(all_instance_voc_topk_ind)
# Flatten per-batch label lists in one linear pass (reduce(+) re-copies the
# accumulated list at every step).
all_gt_label_voc = np.array([label for batch_labels in all_gt_label_voc for label in batch_labels])
all_gt_label_clu = torch.cat(all_gt_label_clu, dim=0)
all_features = np.concatenate(all_features)

100%|██████████| 130/130 [04:39<00:00,  2.15s/it]


In [16]:
# Top-1 accuracy of the LLM-augmented classifier: first top-k prediction vs. ground truth.
top1_names = np.array(classnames)[all_instance_voc_topk_ind[:, 0]]
np.mean(top1_names == np.array(all_gt_label_voc))

0.21282717583700522