In [2]:
import sys
sys.path.append('/home/sheng/sssa/')
sys.path.append('/home/sheng/sssa/')
# sys.path.append('/home/sheng/sssa/CLIP/')

import os
import json
import re
import time
import pickle
from typing import Union, List
from pprint import pprint
from tqdm import tqdm
from copy import deepcopy
import random
import itertools
import numpy as np
from functools import reduce, partial
from itertools import zip_longest
import seaborn as sns
from collections import Counter, defaultdict, OrderedDict
import matplotlib.pyplot as plt
import heapq
from wordnet_utils import *

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode

from ipynb_utils import get_hier_datasets, get_classifier, MCMF_assign_labels
# import clip
import model as clip
from data.datasets import build_transform, get_hier_datasets, Vocab
from data.imagenet_datasets import get_datasets_oszsl


In [3]:
class Config:
    # Notebook-wide settings; the cells below read them as attributes of `args`.
    device = 'cuda:0'  # device the model and image batches are moved to
    arch = 'ViT-B/16'  # architecture name handed to `clip.load` in `load_clip2`
    dataset = 'imagenet21k_1'  # compared against 'lvis' and embedded in cache file names
    n_sampled_classes = 100
    input_size = 224  # resize / center-crop size of the transform built in the dataset cell
    seed = 0
    
    batch_size = 512
    use_def = False
    clip_checkpoint = None  # when truthy, `load_clip2` loads the 'model' entry of this checkpoint
    f_classifier = './cache/wordnet_classifier_in21k_word.pth'
    templates_name = 'templates_small'  # stem of the JSON file read by `load_templates`
    
# Single shared settings object passed to the helpers below.
args = Config()

In [4]:
def load_templates(args):
    """Return the 'imagenet' entry of the JSON file ``../<args.templates_name>.json``."""
    template_path = f'../{args.templates_name}.json'
    with open(template_path, 'rb') as handle:
        payload = json.load(handle)
    return payload['imagenet']

def get_vocab(path='/home/sheng/dataset/wordnet_nouns_with_synset_4.pkl'):
    """Load the pickled WordNet noun vocabulary.

    Args:
        path: pickle file to read. Defaults to the file this notebook has
            always used, so existing zero-argument calls are unaffected.

    Returns:
        The unpickled object. The original docstring described it as
        {`names`: list, `ids`: synset ids, `parents`: [{synset ids}]}; the
        rest of this notebook additionally reads the keys `synsets` and `def`.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file, so `path` must only ever point at a trusted file.
    with open(path, 'rb') as f:
        vocab = pickle.load(f)
    return vocab

def get_subsample_vocab(sample_synset_id: set):
    """Load the vocabulary and keep only the entries whose synset is in `sample_synset_id`.

    The positions of the matching synsets are computed once and then applied to
    the list stored under every key of the vocabulary dict.
    """
    vocab = get_vocab()
    kept_positions = [pos for pos, synset in enumerate(vocab['synsets']) if synset in sample_synset_id]
    index = np.array(kept_positions).astype(np.int32)
    for field in list(vocab):
        column = np.array(vocab[field])
        vocab[field] = column[index].tolist()
    return vocab

def _read_nonempty_lines(path):
    """Return the whitespace-stripped, non-blank lines of the text file at `path`."""
    with open(path, 'r') as f:
        stripped = (line.strip() for line in f.read().split('\n'))
        # Dropping blank / whitespace-only lines keeps empty ids away from the
        # caller, which parses every entry.
        return [line for line in stripped if line]

def read_imagenet21k_classes(path='/home/sheng/dataset/imagenet21k/imagenet21k_wordnet_ids.txt'):
    """Read the ImageNet-21k class-id list, one entry per line.

    Args:
        path: text file to read; defaults to the notebook's original location.

    Returns:
        list[str] of the non-blank lines with surrounding whitespace removed.
    """
    return _read_nonempty_lines(path)

def read_lvis_imagenet21k_classes(path='/home/sheng/dataset/imagenet21k/lvis_imagenet21k_wordnet_ids.txt'):
    """Read the LVIS subset class list, one entry per line.

    Args:
        path: text file to read; defaults to the notebook's original location.

    Returns:
        list[str] of the non-blank lines with surrounding whitespace removed.
    """
    # Previously disabled post-processing, kept for reference:
    # data = list(map(lambda x: x.split('.')[0], data))
    return _read_nonempty_lines(path)

def load_clip2(args):
    """Build the model named by `args.arch`, put it on `args.device` in eval mode, and return it.

    When `args.clip_checkpoint` is truthy, the checkpoint's 'model' entry is
    loaded non-strictly after dropping the first ``len('model.')`` characters
    of every key. A short summary of the model is printed before returning.
    """
    model = clip.load(args.arch)
    if args.clip_checkpoint:
        checkpoint = torch.load(args.clip_checkpoint, map_location='cpu')
        renamed_weights = {}
        for key, tensor in checkpoint['model'].items():
            renamed_weights[key[len('model.'):]] = tensor
        model.load_state_dict(renamed_weights, strict=False)
    model.to(args.device).eval()

    # Read the attributes first, as before, so a missing one fails before anything is printed.
    summary = (
        ("Input resolution:", model.visual.input_resolution),
        ("Context length:", model.context_length),
        ("Vocab size:", model.vocab_size),
    )
    n_params = np.sum([int(np.prod(weight.shape)) for weight in model.parameters()])
    print("Model parameters:", f"{n_params:,}")
    for label, value in summary:
        print(label, value)
    return model

# Module-level state shared by the cells below.
templates = load_templates(args)
vocab = get_vocab()
# `wn` is not imported explicitly in this notebook; presumably it comes from the
# star-import of wordnet_utils -- TODO confirm.
nouns = [ wn.synset(s) for s in vocab['synsets'] ]
classnames = vocab['names']
parents = vocab['parents']
defs = vocab['def']


In [5]:
""" build entire wn-graph """
# Star-import; `create_graph` used below is presumably provided by it -- TODO confirm.
from nxgraph_model import *

# Load the full noun vocabulary (a different pickle from the one read by get_vocab).
with open('/home/sheng/dataset/wordnet_nouns_with_synset.pkl', 'rb') as f:
    entire_vocab = pickle.load(f)
    
# Build G from every synset together with its id, name and definition.
G = create_graph([wn.synset(x) for x in entire_vocab['synsets']], entire_vocab['ids'], entire_vocab['names'], entire_vocab['def'])

In [6]:
""" prepare dataset and load CLIP """
# Class ids from the ImageNet-21k list plus the directory names under imagenet-img,
# converted from 'n<offset>' strings to WordNet synset names and de-duplicated.
classes = read_imagenet21k_classes() + os.listdir('/home/sheng/dataset/imagenet-img/')
classes = [wn.synset_from_pos_and_offset('n', int(x[1:])).name() for x in classes]
classes = set(classes)
# For the 'lvis' dataset the set computed above is discarded in favour of the LVIS list.
if args.dataset == 'lvis':
    classes = read_lvis_imagenet21k_classes()
    classes = set(classes)
# Rebinds the module-level `vocab` (a dict until now) to a `Vocab` object.
vocab = get_subsample_vocab(classes)
vocab = Vocab(vocab=vocab)

# NOTE: `transform_val` is built here but this cell passes `transform_f` to the dataset.
transform_val = build_transform(is_train=False, args=args, train_config=None)
# Per-channel normalisation constants (presumably the CLIP preprocessing statistics -- TODO confirm).
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)

""" load dataset """
transform_f = transforms.Compose([
    transforms.Resize(args.input_size, interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(args.input_size),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=torch.tensor(mean),
        std=torch.tensor(std))
])
# NOTE(review): the loader is named `loader_val` but the dataset is requested with
# is_train=True, and the batch size is a literal 256 rather than args.batch_size.
dataset = get_datasets_oszsl(args, vocab, is_train=True, transform=transform_f, seed=0)
loader_val = torch.utils.data.DataLoader(dataset, num_workers=4, batch_size=256, shuffle=False)
print('dataset size', len(dataset))

# model, preprocess = load_clip(args)
model = load_clip2(args)

dataset size 102006
missing keys:
[]
Model parameters: 149,620,737
Input resolution: 224
Context length: 77
Vocab size: 49408


In [7]:
# Run the visual encoder over loader_val and collect L2-normalised features
# together with the per-image cluster labels, both as numpy arrays.
amp_autocast = torch.cuda.amp.autocast

all_vfeatures = []
all_clu_label = []
with tqdm(total=len(loader_val)) as pbar:
    model.eval()
    for idx_batch, batch in enumerate(loader_val):
        images, label_voc, label_clu, idx_img = batch
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                logits = model.visual.extract_features(images)
                # Normalise each feature vector to unit length.
                logits = logits/logits.norm(dim=-1, keepdim=True)
                # The arrays are deep-copied before being stored (presumably so no
                # reference to the batch tensors is kept -- TODO confirm intent).
                all_vfeatures.append(deepcopy(logits.cpu().numpy()))
                all_clu_label.append(deepcopy(label_clu.numpy()))
        pbar.update(1)

# Stack the per-batch arrays along the first axis.
all_vfeatures = np.concatenate(all_vfeatures)
all_clu_label = np.concatenate(all_clu_label)

100%|██████████| 399/399 [02:03<00:00,  3.22it/s]


In [8]:
# Cache the extracted visual features, keyed by the dataset name.
np.save(f'./cache/features/vfeatures-{args.dataset}.npy', all_vfeatures)

In [9]:
from sklearn.cluster import KMeans, MiniBatchKMeans
# NOTE(review): later cells import cluster_acc from `my_util_package.evaluation`
# instead; confirm which package is the intended one.
from my_util_package_oszsl.evaluation import cluster_acc
# Number of clusters = number of classes reported by the dataset.
K = dataset.num_classes
print(f'K={K}')
# Sanity check: number of distinct ground-truth cluster labels.
print(np.unique(all_clu_label).shape)
# kmeans = MiniBatchKMeans(n_clusters=10*K, batch_size=2048, random_state=0, n_init=10, max_iter=500, verbose=2).fit(all_vfeatures)
# preds = kmeans.labels_

K=100
(100,)


In [None]:
# NOTE(review): `preds` is only assigned in cells further down, so on a fresh
# top-to-bottom run this cell executes before `preds` exists.
cluster_acc(all_clu_label, preds)

In [10]:
# Full K-means on the cached features with a fixed random_state; `preds` holds
# the cluster id assigned to each image.
kmeans = KMeans(n_clusters=K, random_state=0, n_init=20, max_iter=1000, verbose=1).fit(all_vfeatures)
preds = kmeans.labels_

Initialization complete
Iteration 0, inertia 46416.8671875
Iteration 1, inertia 28931.96484375
Iteration 2, inertia 28316.2734375
Iteration 3, inertia 28164.236328125
Iteration 4, inertia 28106.234375
Iteration 5, inertia 28070.68359375
Iteration 6, inertia 28047.736328125
Iteration 7, inertia 28035.853515625
Iteration 8, inertia 28028.47265625
Iteration 9, inertia 28023.251953125
Iteration 10, inertia 28020.169921875
Iteration 11, inertia 28018.150390625
Iteration 12, inertia 28016.9921875
Iteration 13, inertia 28016.537109375
Iteration 14, inertia 28015.81640625
Iteration 15, inertia 28015.486328125
Iteration 16, inertia 28015.080078125
Iteration 17, inertia 28014.76171875
Iteration 18, inertia 28014.544921875
Iteration 19, inertia 28014.580078125
Iteration 20, inertia 28014.484375
Iteration 21, inertia 28014.451171875
Iteration 22, inertia 28014.45703125
Iteration 23, inertia 28014.431640625
Iteration 24, inertia 28014.359375
Iteration 25, inertia 28014.38671875
Iteration 26, inerti

In [11]:
# Score the K-means assignment against the ground-truth cluster labels.
cluster_acc(all_clu_label, preds)

0.7017822481030528

In [12]:
# Cache the K-means cluster assignment, keyed by the dataset name.
np.save(f'./cache/cluster/kmeans-{args.dataset}.npy', preds)

In [None]:
# For each listed dataset: rebuild the dataset/loader, extract normalised visual
# features, cluster them with MiniBatchKMeans and full KMeans, and save the
# full-KMeans assignment.
for dataset_name in ['make_entity13', 'imagenet', 'make_entity30', 'make_living17']:
    print(f'cluster {dataset_name}')
    # NOTE(review): this sets `args.dataset_name`, while Config defines `dataset`;
    # which attribute get_datasets_oszsl reads is not visible here -- TODO confirm.
    args.dataset_name = dataset_name
    dataset = get_datasets_oszsl(args, vocab, is_train=True, transform=transform_val, seed=0)
    loader_val = torch.utils.data.DataLoader(dataset, num_workers=8, batch_size=args.batch_size, shuffle=False)
    print('dataset size', len(dataset))
    
    ### inference
    amp_autocast = torch.cuda.amp.autocast
    all_vfeatures = []
    all_clu_label = []
    with tqdm(total=len(loader_val)) as pbar:
        model.eval()
        for idx_batch, batch in enumerate(loader_val):
            images, label_voc, label_clu, idx_img = batch
            images = images.to(args.device)
            with amp_autocast():
                with torch.no_grad():
                    logits = model.visual.extract_features(images)
                    # Normalise each feature vector to unit length.
                    logits = logits/logits.norm(dim=-1, keepdim=True)
                    all_vfeatures.append(logits.cpu().numpy())
                    all_clu_label.append(label_clu.numpy())
            pbar.update(1)

    all_vfeatures = np.concatenate(all_vfeatures)
    all_clu_label = np.concatenate(all_clu_label)
    
    
    # The cluster count is fixed at 2000 here instead of the dataset's class count.
    # K = dataset.num_classes
    K = 2000
    # Mini-batch K-means: its accuracy is printed, then `preds` is overwritten below.
    kmeans = MiniBatchKMeans(n_clusters=K, batch_size=2048, 
                             random_state=0, n_init=10, max_iter=500, verbose=2).fit(all_vfeatures)
    preds = kmeans.labels_
    print(cluster_acc(all_clu_label, preds))
    
    # Full K-means: this is the assignment that gets saved.
    kmeans = KMeans(n_clusters=K, random_state=0, n_init=3, max_iter=500, verbose=2).fit(all_vfeatures)
    preds = kmeans.labels_
    print(cluster_acc(all_clu_label, preds))
    np.save(f'./cache/cluster/kmeans-{args.dataset_name}-2k.npy', preds)

In [None]:
# Histogram (100 bins) of the top `voc_beta * num_classes` values of `vec_count`.
# NOTE(review): `vec_count` and `voc_beta` are not defined in the visible part of
# this notebook; they must come from earlier kernel state.
sns.distplot(vec_count.topk(k=voc_beta*dataset.num_classes).values.cpu().numpy(), bins=100)

In [None]:
# Fraction of images whose predicted vocabulary index equals the ground truth.
# NOTE(review): `all_pred_voc` / `all_gt_voc` are assigned only in cells further down.
print(f'acc={(all_pred_voc == all_gt_voc).float().mean()}')

Upper bound

In [None]:
# Zero-shot classification restricted to the vocabulary rows that actually occur
# as labels in `dataset`, followed by top-1 accuracy.
classifier = get_classifier(args)
use_norm = True
amp_autocast = torch.cuda.amp.autocast
classifier = classifier/classifier.norm(dim=-1, keepdim=True) if use_norm else classifier

# Build a 0/1 mask over classifier rows that is 1 at every label index present in
# the dataset, then keep only those rows. `mapping_classifier` is sorted, so
# position j of the reduced classifier corresponds to mapping_classifier[j].
mask = torch.zeros(classifier.size(0), device=args.device)
mapping_classifier = torch.tensor(sorted(set(dataset.labels)), device=args.device)
mask = torch.scatter(mask, 0, mapping_classifier, 1)
classifier = classifier[mask.bool()]

all_pred_voc = []
all_gt_voc = []
with tqdm(total=len(loader_val)) as pbar:
    model.eval()
    for idx_batch, batch in enumerate(loader_val):
        images, label_voc, label_clu, idx_img = batch
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                logits = model.visual(images)
                logits = logits/logits.norm(dim=-1, keepdim=True) if use_norm else logits
                # Scaled image-to-classifier similarity, turned into probabilities.
                similarity = model.logit_scale.exp() * logits @ classifier.t()
                prob = similarity.softmax(-1)
                all_pred_voc.append(prob.argmax(dim=-1).cpu())
                all_gt_voc.append(label_voc)
        pbar.update(1)

all_pred_voc = torch.cat(all_pred_voc, dim=0)
all_gt_voc = torch.cat(all_gt_voc, dim=0)

# Map positions in the reduced classifier back to the original label indices.
all_pred_voc = torch.gather(mapping_classifier.cpu(), 0, all_pred_voc)

print(f'acc={(all_pred_voc == all_gt_voc).float().mean()}')

hierarchy accuracy

In [None]:
# Top-1 prediction over the full classifier for every image in `loader_f`.
# NOTE(review): `loader_f` is not created anywhere in the visible part of this
# notebook; it must already exist in the kernel.
classifier = get_classifier(args)
use_norm = True
amp_autocast = torch.cuda.amp.autocast
classifier = classifier/classifier.norm(dim=-1, keepdim=True) if use_norm else classifier


all_pred_voc = []
all_gt_voc = []
with tqdm(total=len(loader_f)) as pbar:
    model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                logits = model.visual(images)
                logits = logits/logits.norm(dim=-1, keepdim=True) if use_norm else logits
                # Scaled image-to-classifier similarity, turned into probabilities.
                similarity = model.logit_scale.exp() * logits @ classifier.t()
                prob = similarity.softmax(-1)
                all_pred_voc.append(prob.argmax(dim=-1).cpu())
                # label_voc = torch.tensor(list(map(lambda x: vocab.mapping_names_idx[x], label_voc)))
                all_gt_voc.append(label_voc)
        pbar.update(1)

# Concatenate the per-batch predictions and labels.
all_pred_voc = torch.cat(all_pred_voc, dim=0)
all_gt_voc = torch.cat(all_gt_voc, dim=0)

In [None]:
# For every image, decide whether prediction and ground truth are related through
# the `parents` table (in either direction).
mapping_i2gi = vocab.mapping_idx_global_idx
# Element-wise membership test: one boolean per element of x.
isin = lambda x, y: np.array([xx in y for xx in x])
all_pred_hier = []
for i in range(len(all_gt_voc)):
    # cond1: some global index of the ground truth lies in the union of
    # parents[...] over the prediction's global indices.
    # NOTE(review): `reduce` has no initial value, so an empty mapping entry raises.
    cond1 = isin(mapping_i2gi[all_gt_voc[i].item()], reduce(lambda x,y: x|y, [parents[p] for p in mapping_i2gi[all_pred_voc[i].item()]])).any()
    # cond2: the same test with the roles of prediction and ground truth swapped.
    cond2 = isin(mapping_i2gi[all_pred_voc[i].item()], reduce(lambda x,y: x|y, [parents[p] for p in mapping_i2gi[all_gt_voc[i].item()]])).any()
    pred_hier = cond1 | cond2
    all_pred_hier.append(pred_hier)

In [None]:
# Fraction of images for which the hierarchy test above succeeded.
np.array(all_pred_hier).mean()

In [None]:
# Exact-match top-1 accuracy, for comparison with the hierarchy-based figure.
(all_pred_voc==all_gt_voc).float().mean()

KNN performance investigation

In [None]:
# Reload the classifier and scale every row to unit L2 norm.
classifier = get_classifier(args)
classifier = classifier/classifier.norm(dim=-1, keepdim=True)
# classifier = F.normalize(classifier, dim=-1)


In [None]:
# Pairwise dot products between all (unit-norm) classifier rows.
similarity = classifier@classifier.T

In [None]:
# K+1 highest-similarity columns per row (presumably the extra one accounts for
# the row itself -- TODO confirm). NOTE: this rebinds the global `K`.
K = 5
topk_ind = similarity.topk(k=K+1).indices

In [None]:
# Translate every row of neighbour indices into the corresponding class names.
list(map(lambda x: list(map(lambda y: classnames[y], x)), topk_ind.cpu().numpy().tolist()))

In [None]:
# Size of the class-name list taken from the vocabulary.
len(classnames)

similarity inspection

In [None]:
# For every image served by loader_val, collect its 10 best-matching classifier
# rows (indices and similarity values), its vocabulary label and its unit-norm
# visual feature.
classifier = get_classifier(args)
use_norm = True  # not read in this cell; kept because a later cell reads `use_norm` from kernel state
amp_autocast = torch.cuda.amp.autocast
classifier = classifier/classifier.norm(dim=-1, keepdim=True)

all_sim_topk = []
all_sim_topk_val = []
all_gt_label_voc = []
all_features = []
with tqdm(total=len(loader_val)) as pbar:
    model.eval()
    for idx_batch, batch in enumerate(loader_val):
        images, label_voc, label_clu, idx_img = batch
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                logits = model.visual(images)
                # Normalise each feature vector to unit length.
                logits = logits/logits.norm(dim=-1, keepdim=True)
                similarity = logits @ classifier.t()
                sim_topk = similarity.topk(k=10)
                all_sim_topk.append(sim_topk.indices.cpu())
                all_sim_topk_val.append(sim_topk.values.cpu())
                all_gt_label_voc.append(label_voc)
                all_features.append(logits.cpu())
        pbar.update(1)

all_sim_topk = torch.cat(all_sim_topk, dim=0)
all_gt_label_voc = torch.cat(all_gt_label_voc, dim=0)
all_sim_topk_val = torch.cat(all_sim_topk_val, dim=0)
# Fix: this was np.concatenate, which produced a numpy array even though the
# per-batch items are torch tensors and the cells that consume `all_features`
# call tensor methods on it (`.t()`, `.topk()`). torch.cat keeps it a tensor,
# like the other accumulators above.
# NOTE(review): the features are produced under autocast, so they may be
# half precision -- confirm whether a `.float()` is wanted before CPU matmuls.
all_features = torch.cat(all_features, dim=0)

In [None]:
# Persist the cluster predictions.
# NOTE(review): the bare name `arch` (not `args.arch`) and `pred_clu` are not
# assigned anywhere above this cell in the visible notebook.
np.save(f'./pred_clu-{args.dataset_name}-train-{arch}.npy', pred_clu)
# pred_clu = np.load(f'./pred_clu-{args.dataset_name}-train-{arch}.npy')

In [None]:
# Pairwise feature similarity; the indices of the 300 highest-similarity columns
# of every row are saved. Uses tensor methods, so `all_features` must be a
# torch tensor here.
# all_features = all_features.to(args.device)
sim = all_features@all_features.t()
np.save(f'./knn_ind-{args.dataset_name}-train-{arch}.npy', sim.topk(k=300).indices.cpu().numpy())

In [None]:
# Score `pred_clu` against `all_labels` (both come from other cells' kernel state).
from my_util_package.evaluation import cluster_acc

cluster_acc(pred_clu, all_labels.numpy())

In [None]:
# NOTE(review): this cell is an exact duplicate of the previous one.
from my_util_package.evaluation import cluster_acc

cluster_acc(pred_clu, all_labels.numpy())

CLIP clustering

In [None]:
# Reload previously saved cluster predictions (relies on the bare name `arch`).
pred_clu = np.load(f'./pred_clu-{args.dataset_name}-train-{arch}.npy')

In [None]:
""" CLIP clustering acc """

from sklearn.cluster import KMeans
from my_util_package.evaluation import cluster_acc

# Extract unit-norm float features for every image in loader_f, then run K-means
# with as many clusters as there are distinct cluster labels.
amp_autocast = torch.cuda.amp.autocast
all_features = []
all_labels = []
all_idx_img = []
model.eval()
with tqdm(total=len(loader_f)) as pbar:
    for batch in loader_f:
        images, label_voc, label_clu, idx_img = batch
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                features = model.visual(images)
                features = F.normalize(features, dim=-1).float()
                # NOTE(review): the features were already normalised on the line
                # above, so this second normalisation looks redundant.
                features = features/features.norm(dim=-1, keepdim=True)
        all_features.append(features.detach().cpu())
        all_labels.append(label_clu)
        all_idx_img.append(idx_img)
        
        pbar.update(1)
        
all_features = torch.cat(all_features, dim=0)
all_labels = torch.cat(all_labels, dim=0)
all_idx_img = torch.cat(all_idx_img, dim=0)

# cluster_acc(pred_clu, all_labels.numpy())

# One cluster per distinct ground-truth cluster label; fixed seed.
kmeans = KMeans(n_clusters=len(all_labels.unique()), n_init=100, max_iter=1000, random_state=43)
pred_clu = kmeans.fit_predict(all_features.numpy())

In [None]:
# Persist the K-means assignment computed from the CLIP features.
np.save(f'./pred_clu-{args.dataset_name}-train-clip.npy', pred_clu)

Cluster navigation (depends on clustering)

In [None]:
""" topk prediction from CLIP """
# For every image in loader_f, build a multi-hot row over the classifier marking
# its `prob_k` most probable classes.
classifier = get_classifier(args)
use_norm = True
amp_autocast = torch.cuda.amp.autocast
classifier = classifier/classifier.norm(dim=-1, keepdim=True) if use_norm else classifier

prob_k = 5
all_topk_voc = []
all_gt_voc = []
all_labels = []  # NOTE: rebinds `all_labels` to an empty list; nothing is appended in this cell
with tqdm(total=len(loader_f)) as pbar:
    model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                logits = model.visual(images)
                logits = logits/logits.norm(dim=-1, keepdim=True) if use_norm else logits
                similarity = model.logit_scale.exp() * logits @ classifier.t()
                prob = similarity.softmax(-1)
                prob_topk_ind = prob.topk(k=prob_k, dim=-1).indices
                # Scatter 1s into a zero matrix (batch x classifier rows) at the top-k positions.
                pred_topk_scattered = torch.scatter(torch.zeros([images.size(0), classifier.size(0)], 
                                                                device=args.device), 1, prob_topk_ind, 1)
                all_topk_voc.append(pred_topk_scattered.cpu())
                all_gt_voc.append(label_voc)
        pbar.update(1)

all_topk_voc = torch.cat(all_topk_voc, dim=0)
all_gt_voc = torch.cat(all_gt_voc, dim=0)

In [None]:
from scipy.optimize import linear_sum_assignment as linear_assignment

### per predicted-cluster voting
pred_kmeans = torch.from_numpy(np.load(f'./pred_clu-{args.dataset_name}-train-{arch}.npy'))

pred_kmeans_t = pred_kmeans
# for it in range(10):
#     print(f'iteration {it}')

# cluster agg
# For each cluster id, sum the multi-hot top-k rows of its members and L1-normalise
# the result. The loop assumes cluster ids run from 0 to the number of distinct
# ground-truth labels minus one -- TODO confirm.
all_clu_pred = []
for i in range(len(all_gt_voc.unique())):
    selected = (pred_kmeans==i)
    clu_pred = F.normalize(all_topk_voc[selected].sum(dim=0), dim=-1, p=1)
    all_clu_pred.append(clu_pred)
all_clu_pred = torch.stack(all_clu_pred, dim=0)

# linear assignment
# Diagnostics: would a plain per-cluster argmax already give distinct classes?
print('is mutex assignment::', all_clu_pred.argmax(dim=-1).size(0)==all_clu_pred.argmax(dim=-1).unique().size(0))
print('assignment collision num::', len(list(filter(lambda x: x>1, Counter(all_clu_pred.argmax(dim=-1).numpy()).values()))))

# Convert the score matrix to a cost matrix (max - score) and solve the assignment.
cost_mat = all_clu_pred.cpu().numpy()
res_ass = linear_assignment(cost_mat.max() - cost_mat)
# Indexing the column result by cluster id presumably relies on the row result
# being 0..n_clusters-1 in order -- TODO confirm.
label_kmeans_voc = torch.tensor([res_ass[1][x.item()] for x in pred_kmeans])

print('instance label acc::', (label_kmeans_voc==all_gt_voc).float().mean().item())

In [None]:
# Same top-k multi-hot extraction loop as the "topk prediction from CLIP" cell,
# but without reloading the classifier: `classifier`, `use_norm` and
# `amp_autocast` are read from kernel state.
prob_k = 5
all_topk_voc = []
all_gt_voc = []
all_labels = []
with tqdm(total=len(loader_f)) as pbar:
    model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                logits = model.visual(images)
                logits = logits/logits.norm(dim=-1, keepdim=True) if use_norm else logits
                similarity = model.logit_scale.exp() * logits @ classifier.t()
                prob = similarity.softmax(-1)
                prob_topk_ind = prob.topk(k=prob_k, dim=-1).indices
                # Scatter 1s into a zero matrix (batch x classifier rows) at the top-k positions.
                pred_topk_scattered = torch.scatter(torch.zeros([images.size(0), classifier.size(0)], 
                                                                device=args.device), 1, prob_topk_ind, 1)
                all_topk_voc.append(pred_topk_scattered.cpu())
                all_gt_voc.append(label_voc)
        pbar.update(1)

all_topk_voc = torch.cat(all_topk_voc, dim=0)
all_gt_voc = torch.cat(all_gt_voc, dim=0)

In [None]:
### subset vocab
# Sorted, de-duplicated column indices that are non-zero in at least one row of all_clu_pred.
col_subset = all_clu_pred.nonzero()[:, 1]
col_subset = col_subset.unique().sort().values

KNN investigation

In [None]:
# For growing neighbourhood sizes, report the average fraction of each image's
# nearest neighbours that share its label.
neighborhood_size = np.arange(5, 500, 5)
similarity = all_features@all_features.T
# Boolean matrix: entry (i, j) is True when images i and j have the same label.
label_match = all_labels.view(-1, 1)==all_labels.view(1, -1)
for K in neighborhood_size:
    topk_res = similarity.topk(k=K+1)
    # Drop the first column (presumably the image itself -- TODO confirm).
    topk_ind = topk_res.indices[:, 1:]
    topk_match = torch.gather(label_match, 1, topk_ind)
    topk_acc = topk_match.float().mean(dim=-1).mean()
    print(f'K={K} acc={topk_acc}')

In [None]:
from my_util_package.graph import compute_consensus_on_features
# Same label-purity measurement, but on the affinity returned by
# compute_consensus_on_features with q=0.8.
neighborhood_size = np.arange(5, 50, 5)
similarity = all_features@all_features.T  # NOTE: not read again in this cell
label_match = all_labels.view(-1, 1)==all_labels.view(1, -1)
for K in neighborhood_size:
    _, _, pred_affinity = compute_consensus_on_features(all_features, k=K+1, q=0.8)
    # Per row: share of predicted neighbours with a matching label (1e-10 guards empty rows).
    acc = ((pred_affinity & label_match).float().sum(1)/(pred_affinity.float().sum(1)+1e-10)).mean()
    # Average number of predicted neighbours per row.
    n_nn = pred_affinity.float().sum(1).mean()
    print(f'K={K} acc={acc} n_nn={n_nn}')

In [None]:
# Identical to the previous cell except for q=0.5.
neighborhood_size = np.arange(5, 50, 5)
similarity = all_features@all_features.T  # NOTE: not read again in this cell
label_match = all_labels.view(-1, 1)==all_labels.view(1, -1)
for K in neighborhood_size:
    _, _, pred_affinity = compute_consensus_on_features(all_features, k=K+1, q=0.5)
    # Per row: share of predicted neighbours with a matching label (1e-10 guards empty rows).
    acc = ((pred_affinity & label_match).float().sum(1)/(pred_affinity.float().sum(1)+1e-10)).mean()
    # Average number of predicted neighbours per row.
    n_nn = pred_affinity.float().sum(1).mean()
    print(f'K={K} acc={acc} n_nn={n_nn}')

In [None]:
""" KNN matrix output """
# Save, for every image, the indices of its K+1 highest-similarity images
# (the first column is kept here, unlike in the purity cells above).
neighborhood_size = 315
similarity = all_features@all_features.T
label_match = (all_labels.view(-1, 1)==all_labels.view(1, -1))  # NOTE: not read again in this cell
K = neighborhood_size
topk_res = similarity.topk(k=K+1)
topk_ind = topk_res.indices

torch.save(topk_ind, f'./cache/{args.dataset_name}-clip-knn-{neighborhood_size}.pth')

In [None]:
# Divide every row of a 2-D tensor by that row's first element.
weight_normalize = lambda x: x/x[:, 0].view(-1, 1)

In [None]:
# Mean normalised weight at column `idx`, over (a) rows whose first prediction is
# correct and (b) rows whose `idx`-th prediction is correct.
# NOTE(review): `instance_weight` and `instance_pred` are not defined in the
# visible part of this notebook.
idx = 3
weight_normalize(instance_weight)[(instance_pred[:, 0]==all_gt_voc)][:, idx].mean(), \
weight_normalize(instance_weight)[(instance_pred[:, idx]==all_gt_voc)][:, idx].mean()

spatial features reweighting and clustering

In [None]:
""" visual reranking computation based on spatial features """
classifier = get_classifier(args)
amp_autocast = torch.cuda.amp.autocast
classifier = classifier/classifier.norm(dim=-1, keepdim=True)
# Top-5 vocabulary indices of each image's cluster.
# NOTE(review): `record_pred_kmeans_t` is not defined in the visible notebook.
instance_topk_voclabel_by_scd = all_clu_pred.topk(k=5).indices.index_select(0, record_pred_kmeans_t)
all_spatial_label_pred = []
all_label_voc = []
all_label_match_rerank = []
all_label_pred = []
all_rerank_pred_voc = []
with tqdm(total=len(loader_f)) as pbar:
    model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                # Token-level features (position 0 is used as the global token below).
                features = model.visual(images, return_spatial=True)
                features = features/features.norm(dim=-1, keepdim=True)
                # Similarity of every token to the image's 5 candidate classifier rows.
                spatial_similarity = model.logit_scale.exp() * (features @ classifier[instance_topk_voclabel_by_scd[idx_img]].permute(0,2,1))
                # Per candidate: mean of its 10 largest similarities over tokens 1.., then argmax.
                spatial_label_pred = spatial_similarity[:, 1:, :].topk(k=10, dim=1).values.mean(dim=1).argmax(dim=-1)
                all_spatial_label_pred.append(spatial_label_pred.cpu())
                all_label_voc.append(label_voc)
                
                ### global-spatial reranking
                # global_label_attn = spatial_similarity[:, 0, :].softmax(dim=-1)
                # global_spatial_mixed_sim_after_scaling = (spatial_similarity[:, 0, :].unsqueeze(1)*spatial_similarity[:, 1:, :])/100
                # topk_spatial_sim_ind = spatial_similarity[:, 1:, :].topk(k=10, dim=1).indices
                # spatial_label_attn = torch.gather(global_spatial_mixed_sim_after_scaling , 1, topk_spatial_sim_ind ).mean(dim=1).softmax(dim=-1)
                # ind_increment = torch.arange(idx_img.size(0), device=args.device)
                # global_spatial_mixed_sim_argmax = (global_label_attn.pow(0.75)*spatial_label_attn.pow(0.25)).argmax(dim=-1)
                # GSRerank_pred_voc = instance_topk_voclabel_by_scd[idx_img][ind_increment, global_spatial_mixed_sim_argmax]
                # label_match_rerank = GSRerank_pred_voc==label_voc
                ### global-attention based spatial voting
                global_spatial_attn = features[:, 0, :].unsqueeze(1) @ features[:, 1:, :].permute(0,2,1)
                topk_spatial_ind = global_spatial_attn.topk(k=10).indices
                topk_spatial_features = torch.gather(features[:, 1:, :], 1, topk_spatial_ind)
                sim_topk_spatial_features = \
                    model.logit_scale.exp() * (topk_spatial_features @ classifier.unsqueeze(0).permute(0,2,1))
                
                
            # NOTE(review): `label_match_rerank`, `global_label_attn` and
            # `GSRerank_pred_voc` are assigned only in the commented-out block above,
            # so within this cell the three appends below depend on stale kernel
            # state (or raise NameError on a fresh run).
            all_label_match_rerank.append(label_match_rerank.cpu())
            all_label_pred.append(global_label_attn.argmax(dim=-1).cpu())
            all_rerank_pred_voc.append(GSRerank_pred_voc.cpu())
        pbar.update(1)

all_spatial_label_pred = torch.cat(all_spatial_label_pred, dim=0)
all_label_voc = torch.cat(all_label_voc, dim=0)
all_label_match_rerank = torch.cat(all_label_match_rerank, dim=0)
all_label_pred = torch.cat(all_label_pred, dim=0)
all_rerank_pred_voc = torch.cat(all_rerank_pred_voc, dim=0)

In [None]:
# Summary: reranked accuracy, accuracy of the first cluster-level candidate, and
# the number of ground-truth classes never produced by the reranked prediction.
print(f'GSReranked instance Acc:: all_label_match_rerank={all_label_match_rerank.float().mean()}')
print(f'SCD:: instance_topk_voclabel_by_scd={(instance_topk_voclabel_by_scd[:, 0]==all_label_voc).float().mean()}')
print(f'GSR missing label:: N={len(set(all_gt_voc.unique().cpu().numpy()) - set(all_rerank_pred_voc.unique().cpu().numpy()))}')

In [None]:
""" visual spatial features 
- KNN difference between CLS and tokens
"""
classifier = get_classifier(args)
amp_autocast = torch.cuda.amp.autocast
classifier = classifier/classifier.norm(dim=-1, keepdim=True)
instance_topk_voclabel_by_scd = all_clu_pred.topk(k=5).indices.index_select(0, record_pred_kmeans_t)

# Select the experiment variant; the trailing index picks one of the three names.
method = ['cls-spatial-voting', 
          'cls-spatial-classifier-similarity-inspect',
          'cls'][1]
all_global_spatial_features = []
all_labels_voc_gt = []
all_scdknn_classifier_features = []
all_entire_spatial_voting = []
all_all_voting_voc_ind = []
all_pred_voc_label = []
with tqdm(total=len(loader_f)) as pbar:
    model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                features = model.visual(images, return_spatial=True)
                features = features/features.norm(dim=-1, keepdim=True)
                
                if method == 'cls-spatial-voting':
                    # NOTE(review): here the token indices are computed on
                    # features[:, 1:] but used without the +1 offset that the
                    # "corrected" branch below applies.
                    # scdknn_classifier_features = classifier[instance_topk_voclabel_by_scd[idx_img]]
                    cls_spatial_sim = (features[:, 0, :].unsqueeze(1) @ features[:, 1:, :].permute(0,2,1))
                    cls_knn_token = cls_spatial_sim.topk(k=10).indices
                    token_similarity = model.logit_scale.exp() * (features @ classifier.unsqueeze(0).permute(0,2,1))
                    knn_token = token_similarity.topk(k=5).indices
                    voting_voc_ind = torch.gather(knn_token, 1, cls_knn_token.permute(0,2,1).repeat(1,1,5)).flatten(1)
                    all_voting_voc_ind = []
                    for i in range(voting_voc_ind.size(0)):
                        # `val` are the distinct voted indices, `ind` their counts.
                        val, ind = voting_voc_ind[i].unique(return_counts=True)
                        all_voting_voc_ind.append(val[ind.topk(k=5).indices].cpu())
                    all_voting_voc_ind = torch.stack(all_voting_voc_ind, dim=0)
                    all_all_voting_voc_ind.append(all_voting_voc_ind)
                elif method == 'cls-spatial-classifier-similarity-inspect': ### corrected, consider projection head
                    n_similar_token = 20
                    n_vote = 5
                    token_similarity = model.logit_scale.exp() * (features @ classifier.unsqueeze(0).permute(0,2,1)) ### B x L+1 x V
                    knn_token = token_similarity.topk(k=n_vote).indices ### B x L+1 x n_vote
                    cls_spatial_sim = (features[:, 0, :].unsqueeze(1) @ features[:, 1:, :].permute(0,2,1))
                    # +1 converts indices into features[:, 1:] back to positions in `features`.
                    cls_knn_token = cls_spatial_sim.topk(k=n_similar_token).indices + 1 ### B x 1 x n_similar_token
                    # knn_token.gather(1, cls_knn_token)
                    all_voting_voc_ind = []
                    for i in range(idx_img.size(0)):
                        # Count how often each vocabulary index is voted for by the selected tokens.
                        val, count = knn_token[i, cls_knn_token[i, 0, :], :].flatten().unique(return_counts=True) ### n_similar_token x n_vote
                        all_voting_voc_ind.append(val[count.topk(k=n_vote).indices].cpu())
                    all_voting_voc_ind = torch.stack(all_voting_voc_ind, dim=0)
                    all_all_voting_voc_ind.append(all_voting_voc_ind)
                    
                    # Plain prediction from the token at position 0, for comparison.
                    all_pred_voc_label.append((features[:, 0, :]@classifier.t()).argmax(dim=-1).cpu())
                    
                elif method == 'cls':
                    similarity = features[:, 0, :]@classifier.t()
                    all_pred_voc_label.append(similarity.argmax(dim=-1).cpu())
                # entire_spatial_voting = (features[:, 0:, :] @ classifier.unsqueeze(0).permute(0,2,1)).topk(k=5).indices
                # torch.scatter_add(p, 1, entire_spatial_voting.flatten(1), torch.ones_like(entire_spatial_voting.flatten(1)).float()).argmax(dim=-1)
#             all_global_spatial_features.append(features.cpu())
            all_labels_voc_gt.append(label_voc)
#             all_scdknn_classifier_features.append(scdknn_classifier_features.cpu().numpy())
#             # all_entire_spatial_voting.append(entire_spatial_voting.cpu())
        pbar.update(1)

# all_global_spatial_features = torch.cat(all_global_spatial_features, dim=0)
all_labels_voc_gt = torch.cat(all_labels_voc_gt, dim=0)
# all_scdknn_classifier_features = np.concatenate(all_scdknn_classifier_features)
# # all_entire_spatial_voting = torch.cat(all_entire_spatial_voting, dim=0)
# NOTE(review): not every `method` fills both lists below ('cls' never appends to
# all_all_voting_voc_ind, 'cls-spatial-voting' never appends to
# all_pred_voc_label), so one of these concatenations receives an empty list
# for those variants.
all_all_voting_voc_ind = torch.cat(all_all_voting_voc_ind, dim=0)
all_pred_voc_label = torch.cat(all_pred_voc_label, dim=0)

In [None]:
""" offline clustering for visual spatial features """
# Cluster the token features of every image twice (spectral clustering on the
# token self-similarity, K-means on the raw token features) and store both
# assignments per image index.
# Imported locally because, in notebook order, this cell comes before the cell
# that imports SpectralClustering.
from sklearn.cluster import SpectralClustering, KMeans

classifier = get_classifier(args)
amp_autocast = torch.cuda.amp.autocast
classifier = classifier/classifier.norm(dim=-1, keepdim=True)

cluster_kmeans = {}
cluster_spectral = {}
n_clusters = 10
spectral = SpectralClustering(n_clusters=n_clusters, affinity='precomputed')
kmeans = KMeans(n_clusters=n_clusters)
with tqdm(total=len(loader_f)) as pbar:
    model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                features = model.visual(images, return_spatial=True)
                features = features/features.norm(dim=-1, keepdim=True)
        batch_size = features.size(0)
        for i in range(batch_size):
            # Token-by-token similarity of image i, used as the precomputed affinity.
            self_sim = features[i]@features[i].t()
            pred_spectral = spectral.fit_predict(self_sim.cpu().numpy())
            pred_kmeans = kmeans.fit_predict(features[i].cpu().numpy())
            # Fix: the two stores used to sit outside this loop, so only the last
            # image of each batch was recorded. They also used the tensor
            # `idx_img[i]` as the dict key, which cannot be looked up by value;
            # a plain int is used instead.
            image_key = int(idx_img[i])
            cluster_kmeans[image_key] = pred_kmeans
            cluster_spectral[image_key] = pred_spectral

        pbar.update(1)



In [None]:
from sklearn.cluster import SpectralClustering, KMeans
from sklearn.manifold import TSNE

In [None]:
""" clustering performance comparison """
# Time spectral clustering vs. K-means on one randomly chosen item and plot both
# assignments on a shared t-SNE embedding.
# NOTE(review): `self_sim` and `features` come from earlier kernel state, and
# `idx` is a length-1 array (size=1), so `x[idx]` keeps a leading axis of
# length 1 -- confirm this matches the shapes the estimators expect.
n_clusters = 10
idx = np.random.randint(low=0, high=510, size=1)
spectral = SpectralClustering(n_clusters=n_clusters, affinity='precomputed', n_init=30)
begin = time.time()
pred_spectral = spectral.fit_predict(self_sim[idx].cpu().numpy())
end = time.time()
print(f'time={end-begin}')

kmeans = KMeans(n_clusters=n_clusters)
begin = time.time()
pred_kmeans = kmeans.fit_predict(features[idx].cpu().numpy())
end = time.time()
print(f'time={end-begin}')


# 2-D embedding of the same features for plotting.
tsne = TSNE(n_components=2, verbose=0)
tsne_features_tr = tsne.fit_transform(features[idx].cpu().numpy())

# Spectral result: all points coloured by cluster, then the cluster of point 0 overdrawn in red.
plt.figure(dpi=64)
plt.scatter(x=tsne_features_tr[:, 0], y=tsne_features_tr[:, 1], c=pred_spectral, s=5, alpha=0.6) ### SPATIAL
plt.scatter(x=tsne_features_tr[pred_spectral==pred_spectral[0], 0], y=tsne_features_tr[pred_spectral==pred_spectral[0], 1], c='r', s=5) ### CLS
plt.show()

# K-means result, drawn the same way.
plt.figure(dpi=64)
plt.scatter(x=tsne_features_tr[:, 0], y=tsne_features_tr[:, 1], c=pred_kmeans, s=5, alpha=0.6) ### SPATIAL
plt.scatter(x=tsne_features_tr[pred_kmeans==pred_kmeans[0], 0], y=tsne_features_tr[pred_kmeans==pred_kmeans[0], 1], c='r', s=5) ### CLS
plt.show()

In [None]:
# Mean / std statistics of the token self-similarity and of the classifier
# self-similarity. The dim=[1,2] reduction requires a 3-D `self_sim`;
# `self_sim_classifier` is not defined in the visible notebook.
self_sim.mean(dim=[1,2]).mean(), self_sim.std(dim=[1,2]).mean(), self_sim_classifier.mean(dim=[0,1]), self_sim_classifier.std(dim=[0,1])

In [None]:
# Vote counting: add 1 to column v of row b for every occurrence of v in
# entire_spatial_voting[b], then take the most-voted column per row.
# NOTE(review): the row count 512 is a literal, and `entire_spatial_voting` is
# only assigned in commented-out code in this notebook.
p = torch.zeros([512, classifier.size(0)], device=args.device)
torch.scatter_add(p, 1, entire_spatial_voting.flatten(1), torch.ones_like(entire_spatial_voting.flatten(1)).float()).argmax(dim=-1), entire_spatial_voting.device

In [None]:
### dimensionality reduction with TSNE
from sklearn.manifold import TSNE

# Pick one image (fixed seed) and embed its token features together with its
# candidate classifier features in a single 2-D t-SNE space.
# NOTE(review): both source lists are only filled by lines that are commented
# out in the extraction cell above.
np.random.seed(2)
idx = np.random.randint(low=0, high=len(all_global_spatial_features), size=[1])[0]
image_features = all_global_spatial_features[idx]
knn_classifier_features = torch.from_numpy(all_scdknn_classifier_features[idx])
Nimg = image_features.size(0)
tsne_features_input = torch.cat([image_features, knn_classifier_features], dim=0)
tsne = TSNE(n_components=2, verbose=2)
tsne_features_tr = tsne.fit_transform(tsne_features_input.numpy())
# Split the joint embedding back into its image and classifier parts.
image_features_tr = tsne_features_tr[:Nimg]
knn_classifier_features_tr = tsne_features_tr[Nimg:]

In [None]:
# Scatter the embedded points: rows 1.. of the image features in green, row 0
# in red, and the KNN classifier features in blue.
plt.figure(dpi=128)
for points, colour, opacity in (
    (image_features_tr[1:], 'g', 0.6),           ### SPATIAL
    (image_features_tr[0:1], 'r', None),         ### CLS
    (knn_classifier_features_tr, 'b', None),
):
    plt.scatter(x=points[:, 0], y=points[:, 1], c=colour, s=5, alpha=opacity)
plt.show()

In [None]:
# Pick one sample (fixed seed) and fetch its image features and its
# cluster-KNN classifier features.
np.random.seed(0)
idx = np.random.randint(low=0, high=len(all_global_spatial_features), size=[1])[0]
image_features = all_global_spatial_features[idx]
knn_classifier_features = torch.from_numpy(all_scdknn_classifier_features[idx])

Nimg, N2 = image_features.size(0), knn_classifier_features.size(0)

# Union of the 5 most similar classifier rows of every image-feature row.
all_knn_classifier_features = classifier[
    (image_features.to(args.device) @ classifier.t()).topk(k=5).indices.flatten().unique()
].cpu()

# Joint t-SNE over the three feature groups, then split the embedding back
# into the same three groups by row offset.
tsne_features_input = torch.cat(
    (image_features, knn_classifier_features, all_knn_classifier_features), dim=0
)
tsne = TSNE(verbose=2, n_components=2)
tsne_features_tr = tsne.fit_transform(tsne_features_input.numpy())
image_features_tr = tsne_features_tr[:Nimg]
knn_classifier_features_tr = tsne_features_tr[Nimg:Nimg + N2]
all_knn_classifier_features_tr = tsne_features_tr[Nimg + N2:]

plt.figure(dpi=128)
for points, colour, opacity in (
    (image_features_tr[1:], 'g', 0.6),             ### SPATIAL
    (image_features_tr[0:1], 'r', None),           ### CLS
    (all_knn_classifier_features_tr, 'y', None),   ### spatial KNN
    (knn_classifier_features_tr, 'b', None),       ### cluster KNN
):
    plt.scatter(x=points[:, 0], y=points[:, 1], c=colour, s=5, alpha=opacity)
plt.show()

In [None]:
# Distinct top-1 classifier indices over the rows of `image_features` (column 0 of
# the top-5 result) with how many rows picked each of them.
spatial_knn, spatial_knn_counts = (image_features.to(args.device) @ classifier.t()).topk(k=5).indices[:, 0].unique(return_counts=True)
# Displays the position(s) where a picked index equals `all_labels_voc_gt[idx]`,
# plus the counts. NOTE(review): `idx` is whatever an earlier cell left behind.
(spatial_knn==all_labels_voc_gt[idx]).nonzero(), spatial_knn_counts

SCD with shrunken vocab

In [None]:
""" SCD with shrinked vocab """
# Rebuild the row-normalised text classifier and record the vocabulary size.
classifier = get_classifier(args)
classifier = classifier/classifier.norm(dim=-1, keepdim=True)
args.num_voc = classifier.size(0)
amp_autocast = torch.cuda.amp.autocast

# Index tensor of the vocabulary rows to score against. It defaults to the full
# vocabulary; replace it with a LongTensor of row indices to shrink the vocab.
# The former `None` placeholder could not work: `classifier[None]` adds a
# leading axis (so `.t()` fails) and `None[...]` is not subscriptable.
classifier_selected = torch.arange(args.num_voc, device=classifier.device)
# Loop-invariant, so gathered once instead of once per batch.
classifier_shrunk = classifier[classifier_selected]

### collect variables
prob_k = 5
all_topk_voc = []
all_gt_voc = []
all_label_clu = []
with tqdm(total=len(loader_f)) as pbar:
    if hasattr(model, 'eval'):
        model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch[:4]
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                logits = model.visual(images)
                logits = logits/logits.norm(dim=-1, keepdim=True)
                similarity = 100 * logits @ classifier_shrunk.t()
                prob = similarity.softmax(-1)
                prob_topk_ind = prob.topk(k=prob_k, dim=-1).indices
                ### mapping @selected to vocab ind
                # top-k positions index the shrunk classifier; translate them
                # back to row indices of the full vocabulary.
                B, C = prob_topk_ind.shape
                prob_topk_ind = classifier_selected[prob_topk_ind.view(-1)].view(B, C)
                all_topk_voc.append(prob_topk_ind.cpu().numpy())
                all_gt_voc.append(label_voc)
                all_label_clu.append(label_clu)
        pbar.update(1)

all_topk_voc = np.concatenate(all_topk_voc)
all_gt_voc = torch.cat(all_gt_voc, dim=0)
all_label_clu = torch.cat(all_label_clu, dim=0)

# pred_kmeans = torch.from_numpy(np.load(f'/home/sheng/OSZSL/ipynb/pred_clu-{args.dataset_name}-train-vit_dino_stage1.npy'))
# NOTE(review): `args.dataset_name` is not set by the Config class at the top of
# the notebook — presumably assigned in another cell; confirm.
pred_kmeans = torch.from_numpy(np.load(f'./pred_clu-{args.dataset_name}-train-vit_dino.npy'))
pred_kmeans_t = pred_kmeans
history_set_pred = []
# Three rounds of: aggregate top-k votes per cluster -> assign one vocabulary
# label per cluster -> reassign every sample using those labels.
for t in range(3):
    all_clu_pred = agg_by_pred_cluster(args, pred_kmeans_t.numpy(), all_topk_voc, voc_size=args.num_voc)
    label_voc_kmeans, res_ass = linear_assign(all_clu_pred, pred_kmeans_t, all_gt_voc)
    pred_kmeans_t, cluster_ind_voc = reassign_by_pred_cluster(label_voc_kmeans, loader_f, model, classifier, args.device, all_prob=None)
    set_pred = set(res_ass[1].tolist())
    set_gt = set(all_gt_voc.unique().numpy().tolist())
    # Ground-truth labels that no cluster was assigned to in this round.
    print('missing label::', len(set_gt - set_pred))
    print('cluster acc', cluster_acc(y_true=all_label_clu.numpy(), y_pred=pred_kmeans_t.numpy()))
    history_set_pred.append(set_pred)


In [None]:
# final_all_clu_pred = all_clu_pred
# len(set_gt - set(final_all_clu_pred.topk(k=2).indices.flatten().unique().numpy().tolist()))

# len(set_gt - set(all_clu_pred.topk(k=3).indices.flatten().unique().numpy().tolist()))

# select_correct = (cluster_ind_voc.cpu()==all_gt_voc)

# all_topk_val = torch.from_numpy(all_topk_val)#[select_correct]
# prob_all_topk_val = torch.cat([all_topk_val, 1-all_topk_val.sum(dim=-1, keepdim=True)], dim=-1)

# ent = - (prob_all_topk_val * (prob_all_topk_val+1e-30).log()).sum(dim=-1)

# # import seaborn as sns
# # sns.distplot(prob_all_topk_val[select_correct, 0], bins=100)
# # sns.distplot(prob_all_topk_val[~select_correct, 0], bins=100)
# # sns.scatterplot(x=prob_all_topk_val[:, 0], y=select_correct.float(), s=3, alpha=0.6)

collect variables

In [None]:
# Row-normalise the text classifier and remember the vocabulary size on `args`.
classifier = get_classifier(args)
classifier = classifier / classifier.norm(dim=-1, keepdim=True)
args.num_voc = classifier.size(0)
amp_autocast = torch.cuda.amp.autocast
### collect variables
prob_k = 5
all_topk_voc, all_gt_voc, all_label_clu, all_topk_val = [], [], [], []
with tqdm(total=len(loader_f)) as pbar:
    if hasattr(model, 'eval'):
        model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch[:4]
        images = images.to(args.device)
        with amp_autocast(), torch.no_grad():
            logits = model.visual(images)
            logits = logits / logits.norm(dim=-1, keepdim=True)
            similarity = 100 * logits @ classifier.t()
            prob = similarity.softmax(-1)
            # top-`prob_k` probabilities and their vocabulary indices per image
            topk_res = prob.topk(k=prob_k, dim=-1)
            prob_topk_ind = topk_res.indices
        # Per-batch results are kept on the host until concatenation below.
        all_topk_voc.append(prob_topk_ind.cpu().numpy())
        all_topk_val.append(topk_res.values.cpu().numpy())
        all_gt_voc.append(label_voc)
        all_label_clu.append(label_clu)
        pbar.update(1)

# Stitch the per-batch pieces into dataset-level arrays / tensors.
all_topk_voc = np.concatenate(all_topk_voc)
all_gt_voc = torch.cat(all_gt_voc, dim=0)
all_label_clu = torch.cat(all_label_clu, dim=0)
all_topk_val = np.concatenate(all_topk_val)

confidence threshold

In [None]:
# classifier = get_classifier(args)
# classifier = classifier/classifier.norm(dim=-1, keepdim=True)
# args.num_voc = classifier.size(0)
# amp_autocast = torch.cuda.amp.autocast
# ### collect variables
# prob_k = 5
# all_topk_voc = []
# all_gt_voc = []
# all_label_clu = []
# with tqdm(total=len(loader_f)) as pbar:
#     if hasattr(model, 'eval'):
#         model.eval()
#     for idx_batch, batch in enumerate(loader_f):
#         images, label_voc, label_clu, idx_img = batch[:4]
#         images = images.to(args.device)
#         with amp_autocast():
#             with torch.no_grad():
#                 logits = model.visual(images)
#                 logits = logits/logits.norm(dim=-1, keepdim=True)
#                 similarity = 100 * logits @ classifier.t()
#                 prob = similarity.softmax(-1)
#                 prob_topk_ind = prob.topk(k=prob_k, dim=-1).indices
#                 all_topk_voc.append(prob_topk_ind.cpu().numpy())
#                 all_gt_voc.append(label_voc)
#                 all_label_clu.append(label_clu)
#         pbar.update(1)

# all_topk_voc = np.concatenate(all_topk_voc)
# all_gt_voc = torch.cat(all_gt_voc, dim=0)
# all_label_clu = torch.cat(all_label_clu, dim=0)

# Iterative cluster -> vocabulary assignment, optionally restricted to the samples
# whose confidence score `all_prob` lies above the median.
# NOTE(review): `th_confidence` is never read in this cell; the quantile level below
# is a literal 0.5 — confirm whether the two were meant to be the same knob.
use_confidence = True
th_confidence = 0.5
pred_kmeans = torch.from_numpy(np.load(f'/home/sheng/OSZSL/ipynb/pred_clu-{args.dataset_name}-train-vit_dino.npy'))

if use_confidence:
    # NOTE(review): within this cell `all_prob` is only assigned in the commented-out
    # code below, so this branch relies on a value already present in the kernel.
    # ### SSL feature extraction
    # ssl_prototypes = torch.zeros([pred_kmeans.unique().size(0), 768], device=args.device, dtype=torch.float64) ### C x D
    # ssl_counter = torch.zeros(pred_kmeans.unique().size(0))
    # with tqdm(total=len(loader_f)) as pbar:
    #     modelf.eval()
    #     modelf.to(args.device)
    #     for idx_batch, batch in enumerate(loader_f):
    #         images, label_voc, label_clu, idx_img = batch[:4]
    #         images = images.to(args.device)
    #         with torch.no_grad():
    #             features = modelf(images.float())
    #             features = F.normalize(features, dim=-1)
    #             for p in range(idx_img.size(0)):
    #                 ssl_prototypes[pred_kmeans[p].long()] += features.to(torch.float64)[p]
    #             # ssl_prototypes = torch.scatter_add(ssl_prototypes, 0, pred_kmeans[idx_img.long()].to(args.device).long(), features.to(torch.float64))
    #             counter_voc_ind, counter_val = pred_kmeans[idx_img].unique(return_counts=True)
    #             ssl_counter[counter_voc_ind.long()] += counter_val
    #         pbar.update(1)
    # ssl_prototypes = ssl_prototypes/ssl_counter.to(args.device).unsqueeze(-1)
    # ssl_prototypes = F.normalize(ssl_prototypes, dim=-1)
    # ### select confident instances
    # all_prob = []
    # all_sim = []
    # with tqdm(total=len(loader_f)) as pbar:
    #     modelf.eval()
    #     modelf.to(args.device)
    #     for idx_batch, batch in enumerate(loader_f):
    #         images, label_voc, label_clu, idx_img = batch[:4]
    #         images = images.to(args.device)
    #         with torch.no_grad():
    #             features = modelf(images)
    #             features = F.normalize(features, dim=-1)
    #             sim = features@ssl_prototypes.float().t()
    #             prob = (sim/1.0).amax(dim=-1)
    #             all_prob.append(prob.cpu())
    #             all_sim.append(sim.cpu())
    #         pbar.update(1)
    # all_prob = torch.cat(all_prob, dim=0)
    # all_sim = torch.cat(all_sim, dim=0)
    ### confidence thresholding
    # Keep the samples whose score is strictly above the 0.5-quantile (median).
    q = np.quantile(all_prob.numpy(), q=0.5)
    selected = (all_prob>q)
    ### computing
    pred_kmeans_t = pred_kmeans[selected]
    # Three rounds of vote aggregation -> label assignment -> reassignment, using
    # only the selected samples.
    for t in range(3):
        all_clu_pred = agg_by_pred_cluster(args, pred_kmeans_t.numpy(), all_topk_voc[selected], voc_size=args.num_voc)
        label_voc_kmeans, res_ass = linear_assign(all_clu_pred, pred_kmeans_t, all_gt_voc[selected])
        pred_kmeans_t, cluster_ind_voc = reassign_by_pred_cluster(label_voc_kmeans, loader_f, model, classifier, args.device, 
                                                                  all_prob=None, instance_selected=selected)
        set_pred = set(res_ass[1].tolist())
        # Ground-truth label set is taken over ALL samples, not only the selected ones.
        set_gt = set(all_gt_voc.unique().numpy().tolist())
        print('missing label::', len(set_gt - set_pred))
        print('cluster acc', cluster_acc(y_true=all_label_clu[selected].numpy(), y_pred=pred_kmeans_t.numpy()))
else:
    # Same three-round procedure over every sample.
    pred_kmeans_t = pred_kmeans
    for t in range(3):
        all_clu_pred = agg_by_pred_cluster(args, pred_kmeans_t.numpy(), all_topk_voc, voc_size=args.num_voc)
        label_voc_kmeans, res_ass = linear_assign(all_clu_pred, pred_kmeans_t, all_gt_voc)
        pred_kmeans_t, cluster_ind_voc = reassign_by_pred_cluster(label_voc_kmeans, loader_f, model, classifier, args.device, all_prob=None)
        set_pred = set(res_ass[1].tolist())
        set_gt = set(all_gt_voc.unique().numpy().tolist())
        print('missing label::', len(set_gt - set_pred))
        print('cluster acc', cluster_acc(y_true=all_label_clu.numpy(), y_pred=pred_kmeans_t.numpy()))

In [None]:
# """ inspect cluster topk assigned classes """
# topk_cluster_label = all_clu_pred.topk(k=5).indices

In [None]:
# # pred_kmeans_t = pred_kmeans
# # for t in range(5):
# #     all_clu_pred = agg_by_pred_cluster(args, pred_kmeans_t, all_topk_voc)
# #     label_voc_kmeans, res_ass = linear_assign(all_clu_pred, pred_kmeans_t, all_gt_voc)
# #     pred_kmeans_t = reassign_by_pred_cluster(label_voc_kmeans, loader_f, model, classifier, args, all_prob=None)
# #     set_pred = set(res_ass[1].tolist())
# #     set_gt = set(all_gt_voc.unique().numpy().tolist())
# #     print('missing label::', len(set_gt - set_pred))
# #     print('cluster acc', cluster_acc(y_true=all_label_clu.numpy(), y_pred=pred_kmeans_t.numpy()))


# """ get confident prediction """
# th = 0.5
# amp_autocast = torch.cuda.amp.autocast
# label_voc_kmeans_t = label_voc_kmeans_t.to(args.device)
# cluster_ind = []
# selected_ind = []
# with tqdm(total=len(loader_f)) as pbar:
#     if hasattr(model, 'eval'):
#         model.eval()
#     for idx_batch, batch in enumerate(loader_f):
#         images, label_voc, label_clu, idx_img = batch[:4]
#         images = images.to(device)
#         with amp_autocast():
#             with torch.no_grad():
#                 logits = model.visual(images)
#                 logits = logits/logits.norm(dim=-1, keepdim=True)
#                 similarity = 100 * logits @ classifier.t()
#                 prob = similarity[:, label_voc_kmeans_t].softmax(dim=-1)
#                 selected = (prob.amax(dim=-1)>th)
#                 selected_ind.append(selected.cpu())
#         pbar.update(1)
# selected_ind = torch.cat(selected_ind, dim=0)



# # precision = cluster_acc(y_true=all_label_clu[selected_ind].numpy(), y_pred=pred_kmeans_t[selected_ind].numpy())
# # recall = selected_ind.mean()
# print(f'confidence selection precision={precision} recall={recall}')

# # np.save(f'./pred_clu_clip-{args.dataset_name}-train-{arch}.npy', pred_kmeans_t.cpu().numpy())

In [None]:
# """ 
# 1. inverse entropy of prototype

# 2. top1 sim of proto-image

# 3. top1 sim of image-proto

# """
# # candidate_ind = res_ass[1].unique()
# # cls_proto_similarity = torch.zeros([len(dataset_f), candidate_ind.size()])
# all_sim_proto_image_pred = []
# all_sim_proto_image_gt = []
# with tqdm(total=len(loader_f)) as pbar:
#     model.eval()
#     for idx_batch, batch in enumerate(loader_f):
#         images, label_voc, label_clu, idx_img = batch
#         images = images.to(args.device)
#         label_voc = label_voc.to(args.device)
#         with amp_autocast():
#             with torch.no_grad():
#                 logits = model.visual(images)
#                 logits = logits/logits.norm(dim=-1, keepdim=True)
#                 similarity = model.logit_scale.exp() * logits @ classifier.t()
#                 prob = similarity.softmax(dim=-1)
#                 all_sim_proto_image_pred.append(similarity[:, prob.argmax(dim=-1)].cpu())
#                 all_sim_proto_image_gt.append(similarity[:, label_voc].cpu())
#         pbar.update(1)
        
# all_sim_proto_image_pred = torch.cat(all_sim_proto_image_pred, dim=0)
# all_sim_proto_image_gt = torch.cat(all_sim_proto_image_gt, dim=0)

In [None]:
# all_sim_proto_image_pred = torch.cat(all_sim_proto_image_pred, dim=0)
# all_sim_proto_image_gt = torch.cat(all_sim_proto_image_gt, dim=0)



# label_match = all_label_clu.view(-1, 1)@all_label_clu.view(1, -1)
# pred_match_init = pred_kmeans.view(-1, 1)@pred_kmeans.view(1, -1)
# pred_match = pred_kmeans_t.view(-1, 1)@pred_kmeans_t.view(1, -1)

# pred_consensus = (pred_match_init==pred_match) 
# ((pred_consensus & label_match).float().sum(dim=-1) / (pred_consensus.sum(dim=-1)+1e-20)).mean()

# (pred_consensus & label_match).float().sum(dim=-1).bool().float().mean()

# all_clu_pred

In [None]:
# Row-normalise the text classifier and remember the vocabulary size on `args`.
classifier = get_classifier(args)
classifier = classifier / classifier.norm(dim=-1, keepdim=True)
args.num_voc = classifier.size(0)
amp_autocast = torch.cuda.amp.autocast
### collect variables
prob_k = 5
all_topk_voc, all_gt_voc, all_label_clu = [], [], []
with tqdm(total=len(loader_f)) as pbar:
    if hasattr(model, 'eval'):
        model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch[:4]
        images = images.to(args.device)
        with amp_autocast(), torch.no_grad():
            logits = model.visual(images)
            logits = logits / logits.norm(dim=-1, keepdim=True)
            similarity = 100 * logits @ classifier.t()
            prob = similarity.softmax(-1)
            # vocabulary indices of the `prob_k` most probable entries per image
            prob_topk_ind = prob.topk(k=prob_k, dim=-1).indices
        all_topk_voc.append(prob_topk_ind.cpu().numpy())
        all_gt_voc.append(label_voc)
        all_label_clu.append(label_clu)
        pbar.update(1)

# Stitch the per-batch pieces into dataset-level arrays / tensors.
all_topk_voc = np.concatenate(all_topk_voc)
all_gt_voc = torch.cat(all_gt_voc, dim=0)
all_label_clu = torch.cat(all_label_clu, dim=0)

# ### MCMF
# Aggregate the top-k votes per stored cluster id, then request K=2 labels per
# cluster from MCMF_assign_labels.
pred_kmeans = torch.from_numpy(np.load(f'./pred_clu-{args.dataset_name}-train-vit_dino.npy'))
all_clu_pred = agg_by_pred_cluster(
    args, pred_kmeans.numpy(), all_topk_voc, voc_size=args.num_voc
)
class_topk_assignment = MCMF_assign_labels(all_clu_pred.cpu().numpy(), K=2)

# ### collect variables
# prob_k = 5
# all_mcmf_rerank_pred = []
# all_gt_voc = []
# all_label_clu = []
# with tqdm(total=len(loader_f)) as pbar:
#     if hasattr(model, 'eval'):
#         model.eval()
#     for idx_batch, batch in enumerate(loader_f):
#         images, label_voc, label_clu, idx_img = batch[:4]
#         images = images.to(args.device)
#         with amp_autocast():
#             with torch.no_grad():
#                 logits = model.visual(images)
#                 logits = logits/logits.norm(dim=-1, keepdim=True)
                
#                 valid_classifier_ind = class_topk_assignment[pred_kmeans[idx_img].long()].to(args.device)
#                 bb, kk = valid_classifier_ind.size()
#                 valid_classifier = classifier[valid_classifier_ind.flatten()].view(bb, kk, -1).permute(0,2,1)
                
#                 similarity = 100 * logits.unsqueeze(1) @ valid_classifier
#                 prob = similarity.softmax(-1)
                
#                 all_mcmf_rerank_pred.append(valid_classifier_ind[prob.argmax(dim=-1)].cpu().numpy())
#                 all_gt_voc.append(label_voc)
#                 all_label_clu.append(label_clu)
#         pbar.update(1)

# all_mcmf_rerank_pred = np.concatenate(all_mcmf_rerank_pred)
# all_gt_voc = torch.cat(all_gt_voc, dim=0)
# all_label_clu = torch.cat(all_label_clu, dim=0)
        
# instance_assignment_pred = torch.zeros(all_mcmf_rerank_pred.shape[0])
# for c in pred_kmeans.unique():
#     select = (pred_kmeans==c)
#     unique_ind, unique_count = torch.from_numpy(all_mcmf_rerank_pred[select]).unique(return_counts=True)
#     instance_assignment_pred[select] = unique_ind[unique_count.argsort()[-1]].item()
    

In [None]:
### class-wise assignment to instance prediction
instance_assignment_pred = torch.zeros(pred_kmeans.size(0), class_topk_assignment.size(1)).to(class_topk_assignment.device).long()
for c in pred_kmeans.unique():
    select = (pred_kmeans==c)
    instance_assignment_pred[select] = class_topk_assignment[c].view(-1, class_topk_assignment.size(1))

all_mcmf_instance_pred = []
all_gt_voc = []
all_label_clu = []
with tqdm(total=len(loader_f)) as pbar:
    if hasattr(model, 'eval'):
        model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch[:4]
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                logits = model.visual(images)
                logits = logits/logits.norm(dim=-1, keepdim=True)
                
                valid_classifier_ind = instance_assignment_pred[idx_img].long().to(args.device)
                bb, kk = valid_classifier_ind.size()
                valid_classifier = classifier[valid_classifier_ind.flatten()].view(bb, kk, -1).permute(0,2,1)
                
                similarity = 100 * logits.unsqueeze(1) @ valid_classifier
                prob = similarity.softmax(-1)
                
                all_mcmf_instance_pred.append(valid_classifier_ind[torch.arange(valid_classifier_ind.size(0)), 
                                                                   prob.argmax(dim=-1).squeeze(-1)].cpu().numpy())
                all_gt_voc.append(label_voc)
                all_label_clu.append(label_clu)
        pbar.update(1)
    
all_mcmf_instance_pred = np.concatenate(all_mcmf_instance_pred)
all_gt_voc = torch.cat(all_gt_voc, dim=0)
all_label_clu = torch.cat(all_label_clu, dim=0)

# pred_kmeans_t = pred_kmeans

# history_set_pred = []
# for t in range(3):
#     record_pred_kmeans_t = pred_kmeans_t
#     all_clu_pred = agg_by_pred_cluster(args, pred_kmeans_t.numpy(), all_topk_voc, voc_size=args.num_voc)
#     label_voc_kmeans, res_ass = linear_assign(all_clu_pred, pred_kmeans_t, all_gt_voc)
#     pred_kmeans_t, cluster_ind_voc = reassign_by_pred_cluster(label_voc_kmeans, loader_f, model, classifier, args.device, all_prob=None)
#     set_pred = set(res_ass[1].tolist())
#     set_gt = set(all_gt_voc.unique().numpy().tolist())
#     print('missing label::', len(set_gt - set_pred))
#     print('cluster acc', cluster_acc(y_true=all_label_clu.numpy(), y_pred=pred_kmeans_t.numpy()))
#     history_set_pred.append(set_pred)

In [None]:
# Give every sample the arg-max vocabulary entry of its cluster's vote row.
cluster_assignment_argmax = all_clu_pred.argmax(dim=-1)
instance_assignment_pred = torch.zeros(
    pred_kmeans.size(0), dtype=torch.long, device=pred_kmeans.device
)
for c in pred_kmeans.unique():
    select = (pred_kmeans==c)
    instance_assignment_pred[select] = cluster_assignment_argmax[c]

In [None]:
# Give every sample the full row of candidate labels assigned to its cluster
# (one column per candidate).
instance_assignment_pred = torch.zeros(
    (pred_kmeans.size(0), class_topk_assignment.size(1)),
    dtype=torch.long,
    device=class_topk_assignment.device,
)
for c in pred_kmeans.unique():
    select = (pred_kmeans==c)
    instance_assignment_pred[select] = class_topk_assignment[c].view(-1, class_topk_assignment.size(1))

In [None]:
# Fraction of samples whose ground-truth label equals column 0 or column 1 of
# `instance_assignment_pred` (requires the 2-D version of that tensor).
((instance_assignment_pred[:, 0]==all_gt_voc) | (instance_assignment_pred[:, 1]==all_gt_voc)).float().mean()

In [None]:
# Accuracy of the reranked MCMF prediction next to the accuracy of `instance_assignment_pred`.
# NOTE(review): the second term compares element-wise with `all_gt_voc`, so it presumably
# expects the 1-D (arg-max) version of `instance_assignment_pred` — depends on which cell ran last.
(torch.from_numpy(all_mcmf_instance_pred)==all_gt_voc).float().mean(), (instance_assignment_pred==all_gt_voc).float().mean()

In [None]:
# Accuracy of `instance_assignment_pred`, and the number of ground-truth labels that do not
# occur anywhere in `class_topk_assignment`.
(instance_assignment_pred==all_gt_voc).float().mean(), len(set(all_gt_voc.unique().numpy()) - set(class_topk_assignment.unique().numpy()))

In [None]:
# Distinct labels used by the cluster assignment, shown beside the ground-truth labels.
class_topk_assignment.flatten().unique().long(), all_gt_voc

In [None]:
# Shape of the de-duplicated rerank predictions.
# NOTE(review): in the visible part of this notebook `all_mcmf_rerank_pred` is only
# assigned in commented-out code — this cell relies on leftover kernel state.
np.unique(all_mcmf_rerank_pred.squeeze(1), axis=0).shape

In [None]:
# Number of distinct ground-truth labels that appear among the assigned labels.
torch.isin(all_gt_voc.unique(), instance_assignment_pred.unique().long()).sum()

In [None]:
# Label histogram of the rerank predictions under the mask `select`.
# NOTE(review): `select` is the mask leaked from the last iteration of an earlier
# cluster loop, and `all_mcmf_rerank_pred` comes from commented-out code — confirm both.
unique_ind, unique_count = torch.from_numpy(all_mcmf_rerank_pred[select]).unique(return_counts=True)

In [None]:
# Most frequent label in the histogram above and how often it occurs.
unique_ind[unique_count.argsort()[-1]].item(), unique_count.max()

In [None]:
# Debug check: shape of `prob` left over from the last batch of a previous loop.
prob.shape

test for MCMF

In [None]:
# Re-run the MCMF label assignment on the cluster vote matrix with K=3.
# NOTE(review): this overwrites the K=2 `class_topk_assignment` from the earlier cell.
K = 3
class_topk_assignment = MCMF_assign_labels(all_clu_pred.cpu().numpy(), K=K)


In [None]:
""" overlap with SCD linear assignment prediction 
NOTE: MCMF is not ordered prediction
"""
# For every MCMF column, print how many clusters received exactly the
# arg-max label of their vote row.
for i in range(K):
    matches_argmax = all_clu_pred.argmax(dim=-1).cpu() == class_topk_assignment[:, i]
    overlap = matches_argmax.sum()
    print(overlap.item())

In [None]:
# Ground-truth labels that do not occur anywhere in the MCMF assignment.
# `set_gt` is the value left behind by an earlier assignment loop.
print('missing label:', len(set_gt - set(class_topk_assignment.unique().numpy())))

In [None]:
""" reranking with voting similarity """
# Look up the vote score of each assigned label, then reorder every cluster's
# labels by descending score.
class_topk_assignment_ordered = torch.gather(
    class_topk_assignment,
    1,
    torch.gather(all_clu_pred, 1, class_topk_assignment).argsort(descending=True),
)

In [None]:
# Accuracy of the second-ranked (column 1) label after mapping each sample to its cluster row.
# NOTE(review): in the visible part of this notebook `record_pred_kmeans_t` is only assigned
# in commented-out code — this cell relies on leftover kernel state.
(class_topk_assignment_ordered[record_pred_kmeans_t][:, 1]==all_gt_voc).float().mean() #, (class_topk_assignment_ordered[record_pred_kmeans_t][:, 0]==cluster_ind_voc.cpu()).float().mean()

In [None]:
from sklearn.cluster import KMeans
from my_util_package.evaluation import cluster_acc
from scipy.optimize import linear_sum_assignment as linear_assignment

# Which split to analyse; index 0 selects 'train'.
subset = ['train', 'val'][0]
# DINO ViT-B/16 backbone fetched through torch.hub (may download on first use).
modelf = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
arch = 'vit_dino'

""" load dataset """
# Evaluation preprocessing: bicubic resize to 256, 224x224 centre crop,
# tensor conversion, then per-channel normalisation.
transform_f = transforms.Compose(
    [
        transforms.Resize(256, interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

# dataset_f = get_datasets_oszsl(args, vocab, is_train=False, transform=transform_f)
if subset == 'train':
    dataset_f = get_datasets_oszsl(args, vocab, is_train=True, transform=transform_f, seed=1)
elif subset == 'val':
    dataset_f = get_datasets_oszsl(args, vocab, is_train=False, transform=transform_f, seed=1)
args.nb_classes = dataset_f.num_classes
# shuffle=False keeps the iteration order fixed across passes.
loader_f = torch.utils.data.DataLoader(
    dataset_f,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=8,
)
# Untransformed copy of the training split.
dataset_r = get_datasets_oszsl(args, vocab, is_train=True, transform=None, seed=1)

# Stored cluster predictions for the chosen split.
if subset == 'train':
    pred_kmeans = torch.from_numpy(np.load(f'./pred_clu-{args.dataset_name}-train-{arch}.npy'))
elif subset == 'val':
    pred_kmeans = torch.from_numpy(np.load(f'./pred_clu-{args.dataset_name}-val-{arch}.npy'))

# CLIP model, optionally overridden by a fine-tuned checkpoint whose keys carry
# a leading 'model.' prefix that is stripped before loading.
model, preprocess = clip.load(args.arch)
if args.clip_checkpoint:
    checkpoint = torch.load(args.clip_checkpoint, map_location='cpu')
    model.load_state_dict(
        {name[len('model.'):]: weight for name, weight in checkpoint['model'].items()},
        strict=False,
    )
model.to(args.device).eval()

In [None]:
# (Re)load CLIP, optionally apply a checkpoint whose keys carry a leading
# 'model.' prefix, and move the model to the target device in eval mode.
model, preprocess = clip.load(args.arch)
if args.clip_checkpoint:
    checkpoint = torch.load(args.clip_checkpoint, map_location='cpu')
    model.load_state_dict(
        {name[len('model.'):]: weight for name, weight in checkpoint['model'].items()},
        strict=False,
    )
model.to(args.device).eval()

input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

# Report the model size and its main hyper-parameters.
param_counts = [int(np.prod(p.shape)) for p in model.parameters()]
print("Model parameters:", f"{np.sum(param_counts):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

#### collect variables

In [None]:
""" topk prediction from CLIP """
classifier = get_classifier(args)
use_norm = True
amp_autocast = torch.cuda.amp.autocast
# Row-normalise the classifier only when `use_norm` is set.
classifier = classifier/classifier.norm(dim=-1, keepdim=True) if use_norm else classifier

# initial topK prediction from CLIP
prob_k = 5
all_topk_voc, all_gt_voc, all_prob, all_max_ind = [], [], [], []
all_topk_vocinds, all_label_clu, all_topk_vals, all_topk_inds = [], [], [], []
with tqdm(total=len(loader_f)) as pbar:
    model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch
        images = images.to(args.device)
        with amp_autocast(), torch.no_grad():
            logits = model.visual(images)
            logits = logits/logits.norm(dim=-1, keepdim=True) if use_norm else logits
            similarity = model.logit_scale.exp() * logits @ classifier.t()
            prob = similarity.softmax(-1)

            # Multi-hot encoding of the top-`prob_k` vocabulary entries per image.
            prob_topk_ind = prob.topk(k=prob_k, dim=-1).indices
            pred_topk_scattered = torch.scatter(
                torch.zeros([images.size(0), classifier.size(0)], device=args.device),
                1,
                prob_topk_ind,
                1,
            )
            all_topk_voc.append(pred_topk_scattered.cpu())
            all_gt_voc.append(label_voc)
            all_label_clu.append(label_clu)
            all_max_ind.append(prob.argmax(dim=-1).cpu())
            all_topk_vocinds.append(prob.topk(k=10, dim=-1).indices.cpu())

            # Top-20 probabilities with their indices.
            batch_topk_res = prob.topk(k=20, dim=-1)
            all_topk_vals.append(batch_topk_res.values.cpu())
            all_topk_inds.append(batch_topk_res.indices.cpu())
        pbar.update(1)

# all_prob = torch.cat(all_prob, dim=0)
# Stitch the per-batch tensors into dataset-level tensors.
all_topk_voc = torch.cat(all_topk_voc, dim=0)
all_gt_voc = torch.cat(all_gt_voc, dim=0)
all_label_clu = torch.cat(all_label_clu, dim=0)
all_max_ind = torch.cat(all_max_ind, dim=0)
all_topk_vocinds = torch.cat(all_topk_vocinds, dim=0)
all_topk_vals = torch.cat(all_topk_vals, dim=0)
all_topk_inds = torch.cat(all_topk_inds, dim=0)

#### text proto inspection

In [None]:
### KNN proto analysis
# Similarity of every ground-truth label's classifier row to all classifier rows.
text_sim = classifier[all_gt_voc.unique(), :]@classifier.t()

In [None]:
# 20 most similar classifier rows for each ground-truth label.
text_topk = text_sim.topk(k=20)

In [None]:
# Pretty-print the vocabulary names of those neighbours, one list per ground-truth label.
pprint(np.array([vocab.mapping_idx_names[t.item()] for t in text_topk.indices[:, :].flatten().cpu()]).reshape(text_sim.size(0), -1).tolist(), compact=True)

In [None]:
import seaborn as sns 

# Distributions of the top-1, top-2 and top-3 probabilities, overlaid in one figure.
# NOTE(review): `distplot` is deprecated in recent seaborn releases — confirm the installed version.
sns.distplot(all_topk_vals[:, 0].cpu().numpy(), bins=200)
sns.distplot(all_topk_vals[:, 1].cpu().numpy(), bins=200)
sns.distplot(all_topk_vals[:, 2].cpu().numpy(), bins=200)

In [None]:
# Select the samples whose top-1 probability exceeds 0.7 and display labelled statistics:
#   'top1 acc'            - accuracy of the arg-max prediction on the selected samples
#   'selected percentile' - fraction of samples selected
#   'class diversity'     - number of ground-truth labels never predicted among the selected samples
#   'topk inclusion'      - fraction of selected samples whose label is in `all_topk_vocinds`
#   'selected sample pred voc size' - number of distinct predicted labels among the selected samples
#   'average selected instance number per class' - selected count / number of ground-truth labels
selected = (all_topk_vals[:, 0]>0.7)
'top1 acc', (all_gt_voc[selected] == all_max_ind[selected]).float().mean(), \
'selected percentile', selected.float().mean(), \
'class diversity', len(set(all_gt_voc.unique().numpy()) - set(all_max_ind[selected].unique().numpy())), \
'topk inclusion', torch.stack([all_topk_vocinds[selected, i]==all_gt_voc[selected] for i in range(all_topk_vocinds.size(1))], dim=1).float().sum(dim=-1).bool().float().mean(), \
'selected sample pred voc size', len(all_max_ind[selected].unique()), \
'average selected instance number per class', selected.sum()/len(all_gt_voc.unique()),


In [None]:
""" initial clip assignment """
# Ground-truth labels that CLIP's arg-max predicts for fewer than 100 samples,
# as (label, predicted_count) pairs.
[
    (label, n_pred)
    for label, n_pred in (
        (i.item(), (all_max_ind==i).sum().item()) for i in all_gt_voc.unique()
    )
    if n_pred < 100
]

#### text to image entropy

In [None]:
amp_autocast = torch.cuda.amp.autocast
classifier = get_classifier(args)
classifier = classifier/classifier.norm(dim=-1, keepdim=True)

### get all label, predicted label and image-text similarity in ONE pass
# The previous version ran the whole dataset through the model twice and
# recomputed identical logits the second time. The full similarity matrix is
# stored anyway, so the columns needed for the entropy are sliced from it below.
all_label_voc = []
all_pred_voc = []
all_label_clu = []
all_similarity = []
with tqdm(total=len(loader_f)) as pbar:
    model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                logits = model.visual(images)
                logits = logits/logits.norm(dim=-1, keepdim=True)
                similarity = model.logit_scale.exp() * logits @ classifier.t()
                prob = similarity.softmax(-1)
                pred_voc = prob.argmax(dim=-1)
        all_label_voc.append(label_voc)
        all_pred_voc.append(pred_voc.cpu())
        all_label_clu.append(label_clu)
        all_similarity.append(similarity.cpu().numpy())
        pbar.update(1)
all_label_voc = torch.cat(all_label_voc, dim=0)
all_pred_voc = torch.cat(all_pred_voc, dim=0)
all_label_clu = torch.cat(all_label_clu, dim=0)
all_similarity = np.concatenate(all_similarity)


### compute entropy
# Vocabulary entries that occur either as a ground-truth label or as a prediction.
set_all_label_voc = all_label_voc.unique()
set_all_pred_voc = all_pred_voc.unique()
selected_classifier_ind = torch.cat([set_all_label_voc, set_all_pred_voc], dim=0).unique()
# Similarity of every image to the selected entries only (a copy, not a view).
all_selected_sim = torch.from_numpy(all_similarity)[:, selected_classifier_ind]

In [None]:
# Per-cluster mean similarity to every selected vocabulary entry, followed by the
# entropy (base 2) of each entry's distribution over clusters.
pred_kmeans_t = torch.from_numpy(np.load(f'./pred_clu_clip-{args.dataset_name}-train-{arch}.npy'))
classwise_all_selected_sim = []
for c in pred_kmeans_t.unique():
    # Named `in_cluster` rather than `subset`, so the split name ('train'/'val')
    # held in the global `subset` is no longer overwritten by a boolean mask.
    in_cluster = (pred_kmeans_t==c)
    # Average in float32: the similarities may have been stored in half
    # precision, where a long reduction loses accuracy.
    classwise_all_selected_sim.append(all_selected_sim[in_cluster, :].float().mean(dim=0).cpu())
classwise_all_selected_sim = torch.stack(classwise_all_selected_sim, dim=0)
# Softmax over clusters (dim 0): one distribution per vocabulary entry.
p = classwise_all_selected_sim.float().softmax(dim=0)
# 1e-10 guards log2(0).
ent = (-p*(p+1e-10).log2()).sum(dim=0)

In [None]:
from scipy.special import softmax

# Entropy over clusters for EVERY vocabulary classifier (not only the selected ones).
# The earlier version averaged `all_selected_sim` here, which made `ent_all` a copy
# of `ent`; it also evaluated `softmax(all_similarity, axis=0)` and threw the
# result away. Both are fixed below.
pred_kmeans_t = torch.from_numpy(np.load(f'./pred_clu_clip-{args.dataset_name}-train-{arch}.npy'))
# from_numpy shares memory with the numpy array, so no copy of the full matrix is made.
all_similarity_t = torch.from_numpy(all_similarity)
classwise_all_sim = []
for c in pred_kmeans_t.unique():
    subset = (pred_kmeans_t==c)
    # Cast to float32 before averaging so a half-precision store does not lose accuracy.
    classwise_all_sim.append(all_similarity_t[subset, :].float().mean(dim=0))
classwise_all_sim = torch.stack(classwise_all_sim, dim=0)
# softmax over dim=0 normalises across clusters, separately for each classifier column.
p_all = classwise_all_sim.float().softmax(dim=0)
# Entropy (in bits) per classifier; 1e-10 guards log2(0).
ent_all = (-p_all*(p_all+1e-10).log2()).sum(dim=0)

In [None]:
# Replace any NaN entropy values before splitting.
ent = torch.nan_to_num(ent)
# `ent` has one entry per element of selected_classifier_ind; pick the entries whose
# vocabulary index is a ground-truth label / a top-1 prediction respectively.
# An index that is both ends up in both groups.
ent_gt = ent[torch.isin(selected_classifier_ind, set_all_label_voc)]
ent_pred = ent[torch.isin(selected_classifier_ind, set_all_pred_voc)]

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
plt.figure(dpi=128)
sns.distplot(ent_gt.numpy(), bins=100)
sns.distplot(ent_pred.numpy(), bins=100)

In [None]:
(ent_gt>2).sum(), (ent_gt>3).sum(), \
(ent_pred<2).sum(), (ent_pred<3).sum()

In [None]:
(ent_gt>2).sum(), (ent_gt>3).sum(), \
(ent_pred<2).sum(), (ent_pred<3).sum()

#### class-wise distribution

collect variables

In [None]:
""" collect variables """
# L2-normalised text classifier, one row per vocabulary entry.
classifier = get_classifier(args)
amp_autocast = torch.cuda.amp.autocast
classifier = classifier/classifier.norm(dim=-1, keepdim=True)

# initial topK prediction from CLIP
prob_k = 5
all_gt_label_voc = []
all_gt_label_clu = []
all_inst_topk_ind_voc = []
all_inst_topk_val_voc = []
all_inst_max_pred = []
all_img_idx = []
with tqdm(total=len(loader_f)) as pbar:
    model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                logits = model.visual(images)
                logits = logits/logits.norm(dim=-1, keepdim=True)
                
                # Fixed temperature of 100 here (other cells use model.logit_scale.exp()).
                similarity = 100 * logits @ classifier.t()
                prob = similarity.softmax(-1)
                prob_topk = prob.topk(k=prob_k, dim=-1)
                
                all_gt_label_voc.append(label_voc)
                all_gt_label_clu.append(label_clu)
                # topk already returns exactly prob_k columns, so the `:prob_k+1`
                # slice keeps all of them.
                all_inst_topk_ind_voc.append(prob_topk.indices[:, :prob_k+1].cpu())
                all_inst_topk_val_voc.append(prob_topk.values[:, :prob_k+1].cpu())
                all_inst_max_pred.append(prob.argmax(dim=-1).cpu())
                all_img_idx.append(idx_img)
                
        pbar.update(1)


# Concatenate the per-batch pieces into one tensor per quantity (loader order).
all_gt_label_voc = torch.cat(all_gt_label_voc, dim=0)
all_gt_label_clu = torch.cat(all_gt_label_clu, dim=0)
all_inst_topk_ind_voc = torch.cat(all_inst_topk_ind_voc, dim=0)
all_inst_topk_val_voc = torch.cat(all_inst_topk_val_voc, dim=0)
all_inst_max_pred = torch.cat(all_inst_max_pred, dim=0)
all_img_idx = torch.cat(all_img_idx, dim=0)

# res = torch.load(f'./cache-{args.dataset_name}.pth')
# all_clu_pred = res['all_clu_pred']
# label_voc_kmeans = res['label_voc_kmeans']
# pred_kmeans_t = res['pred_kmeans_t']
# cluster_ind_voc = res['cluster_ind_voc']
# record_pred_kmeans_t = res['record_pred_kmeans_t']
# all_gt_voc = res['all_gt_voc']
# all_label_clu = res['all_label_clu']

In [None]:
# Vocabulary size = number of classifier rows.
args.num_voc = classifier.size(0)

# Initial per-instance cluster ids loaded from disk.
pred_kmeans = torch.from_numpy(np.load(f'./pred_clu-{args.dataset_name}-train-vit_dino.npy'))
pred_kmeans_t = pred_kmeans
history_set_pred = []
history_mapping_assignment_clu = []
# Three rounds of: aggregate top-k votes per cluster -> one-to-one cluster/vocabulary
# assignment -> re-cluster the instances with the assigned vocabulary entries.
# NOTE(review): agg_by_pred_cluster / linear_assign / reassign_by_pred_cluster /
# cluster_acc are not defined above this cell; they have to be in the kernel already.
for t in range(3):
    record_pred_kmeans_t = pred_kmeans_t
    all_clu_pred = agg_by_pred_cluster(args, pred_kmeans_t.numpy(), all_inst_topk_ind_voc, voc_size=args.num_voc)
    label_voc_kmeans, res_ass = linear_assign(all_clu_pred, pred_kmeans_t, all_gt_label_voc)
    pred_kmeans_t, cluster_ind_voc = reassign_by_pred_cluster(label_voc_kmeans, loader_f, model, classifier, args.device, all_prob=None)
    # Vocabulary indices chosen by the assignment vs. the ground-truth vocabulary set.
    set_pred = set(res_ass[1].tolist())
    set_gt = set(all_gt_label_voc.unique().numpy().tolist())
    print('missing label::', len(set_gt - set_pred))
    # NOTE(review): uses `all_label_clu` rather than `all_gt_label_clu` — confirm both
    # refer to the same instances in the same order.
    print('cluster acc', cluster_acc(y_true=all_label_clu.numpy(), y_pred=pred_kmeans_t.numpy()))
    history_set_pred.append(set_pred)
    history_mapping_assignment_clu.append(pred_kmeans_t)

##### class-wise feature space with KNN prototypes

In [None]:
# np.random.seed(1)
# Pick one ground-truth cluster label at random. With the filter below commented out,
# `c` and `select` are not used in this cell; features of ALL images are collected.
c = np.random.choice(all_gt_label_clu.unique().numpy())
select = (all_gt_label_clu==c.item())

all_class_features = []
with tqdm(total=len(loader_f)) as pbar:
    model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                features = model.visual(images)
                # L2-normalise the image embeddings.
                features = features/features.norm(dim=-1, keepdim=True)
                
                # if select[idx_img].sum().item()==0:
                #     pbar.update(1)
                #     continue
                all_class_features.append(features.cpu().numpy())
                
        pbar.update(1)

# One row per image, in loader order.
all_class_features = np.concatenate(all_class_features)

In [None]:
### randomly select a class of features 
np.random.seed(5)
c = np.random.choice(all_gt_label_clu.unique().numpy())
select = (all_gt_label_clu==c.item())
selected_all_class_features = torch.from_numpy(all_class_features)[select]

In [None]:
### compute scd confusing classifier
i = torch.arange(select.size(0))[select]
for x in history_mapping_assignment_clu[:-1]:
    i = x[i]
ind, counts = i.unique(return_counts=True)
all_confusing_classifier_ind = all_clu_pred[ind[counts.argmax()]].topk(k=3).indices

confusing_classifier = classifier[all_confusing_classifier_ind]

In [None]:
### distances among confusing classifiers
triu_confusing_classifier = (confusing_classifier@confusing_classifier.t()).triu(1)
dist_confusing_classifier = triu_confusing_classifier[triu_confusing_classifier.nonzero(as_tuple=True)]
print(dist_confusing_classifier.mean(), dist_confusing_classifier.max(), dist_confusing_classifier.min())

plt.figure()
sns.distplot(dist_confusing_classifier.cpu().numpy(), bins=10)
plt.show()

### distances among gt classifier and confusing classifiers
gt_voc_label = all_gt_label_voc[select].unique()[0].item()
classifier[gt_voc_label].view(1, -1)@confusing_classifier.t()

In [None]:
gt_voc_label in all_confusing_classifier_ind

visualization

In [None]:
# Pick one ground-truth cluster at random and gather the features of its images.
c = np.random.choice(all_gt_label_clu.unique().numpy())
select = (all_gt_label_clu==c.item())
selected_all_class_features = all_class_features[select.numpy()]

### compute scd confusing classifier
i = torch.arange(select.size(0))[select]
for x in history_mapping_assignment_clu[:-1]:
    i = x[i]
ind, counts = i.unique(return_counts=True)
all_confusing_classifier_ind = list(set(all_clu_pred[ind[counts.argmax()]].topk(k=5).indices.flatten().numpy()) )
                                    # | set(all_inst_topk_ind_voc[select].unique().numpy()))

all_confusing_classifier = classifier[torch.tensor(all_confusing_classifier_ind), :].cpu().numpy()
# `classifier` rows are indexed by VOCABULARY index, while `c` is a cluster label.
# Look up the vocabulary label of the selected images instead of indexing with `c`.
gt_voc_label = all_gt_label_voc[select].unique()[0].item()
# Layout: [image features of the class | confusing classifiers | ground-truth classifier].
all_features_vis = np.concatenate([selected_all_class_features, all_confusing_classifier, classifier[gt_voc_label].view(1, -1).cpu().numpy()])

In [None]:
from sklearn.manifold import TSNE

tsne = TSNE(n_components=2, 
            n_iter=1000, 
            # perplexity=10,
            method='exact',
           )
tr_all_features_vis = tsne.fit_transform(all_features_vis)

In [None]:
import matplotlib.pyplot as plt
plt.figure(dpi=128)
plt.scatter(x=tr_all_features_vis[:selected_all_class_features.shape[0], 0], y=tr_all_features_vis[:selected_all_class_features.shape[0], 1], s=5, c='b', alpha=0.6)
plt.scatter(x=tr_all_features_vis[selected_all_class_features.shape[0]:-1, 0], y=tr_all_features_vis[selected_all_class_features.shape[0]:-1, 1], s=5, c='g', alpha=0.8)
plt.scatter(x=tr_all_features_vis[-1, 0], y=tr_all_features_vis[-1, 1], s=8, c='r', alpha=1.0)
plt.show()

#### spatial region visualization

In [None]:
from wordnet_utils import *
from PIL import Image
from pprint import pprint

In [None]:
"""
1. collect variables
upper bound visualization test
"""
args.num_voc = classifier.size(0)
pred_kmeans = torch.from_numpy(np.load(f'./pred_clu-{args.dataset_name}-train-vit_dino.npy'))
all_clu_pred = agg_by_pred_cluster(args, pred_kmeans.numpy(), all_inst_topk_ind_voc, voc_size=args.num_voc)

In [None]:
### cluster to instance topk
# Top-5 vocabulary candidates of every cluster, broadcast to the instances of that cluster.
cluster_topk_ind = all_clu_pred.topk(k=5).indices
inst_topk_ind = cluster_topk_ind[pred_kmeans.long(), ...]
### sample one cluster 
# Draw from the labels that actually occur. `torch.randint` excludes `high`, so
# sampling in [0, max) could never return the last cluster, and a label with no
# instances would leave nothing to choose from below.
present_clusters = all_gt_label_clu.unique()
sampled_cluster_idx = present_clusters[torch.randint(low=0, high=present_clusters.size(0), size=[1])]
inst_select = all_gt_label_clu==sampled_cluster_idx.item()

# 10 instance positions of the sampled cluster (drawn with replacement).
sampled_img_idx = np.random.choice(inst_select.nonzero().flatten().numpy(), 10)

In [None]:
# Position inside sampled_img_idx of the image to inspect.
idx = 9
model.eval()
with amp_autocast():
    with torch.no_grad():
        image, label_voc, label_clu, _ = dataset_f[sampled_img_idx[idx]]
        image = image.to(args.device)
        
        # Candidate vocabulary entries inherited from the image's cluster, and their synset names.
        candidate_voc_labels = inst_topk_ind[sampled_img_idx[idx]]
        candidate_voc_synsets = [ mapping_vocidx_to_synsets(x.item(), vocab)[0].name() for x in candidate_voc_labels ]
        gt_class_label = mapping_vocidx_to_synsets(label_voc, vocab)[0].name()
        
        # Token-level features; index 0 is used as the global token and 1: as the patch tokens.
        # NOTE(review): the 14x14 reshapes below assume 196 patch tokens — confirm for args.arch / input size.
        features = model.visual(image.unsqueeze(0), return_spatial=True)
        features = F.normalize(features, dim=-1)
        
        # Per-patch distribution restricted to the candidate classifiers.
        spatial_sim = 100 * features[0, 1:, :]@classifier[candidate_voc_labels].t()
        spatial_softmax = spatial_sim.softmax(dim=-1)
        spatial_ind = spatial_softmax.argmax(dim=-1).reshape(14, 14)
        spatial_score = spatial_softmax.amax(dim=-1).reshape(14, 14)
        
        # Per-patch distribution over the entire vocabulary.
        spatial_sim_entire = 100 * features[0, 1:, :]@classifier.t()
        spatial_softmax_entire = spatial_sim_entire.softmax(dim=-1)
        # The candidate-column slice is taken after the softmax and is not read again in this cell.
        spatial_sim_entire = spatial_sim_entire[:, candidate_voc_labels]
        spatial_ind_entire = spatial_softmax_entire.argmax(dim=-1).reshape(14, 14)
        spatial_score_entire = spatial_softmax_entire.amax(dim=-1).reshape(14, 14)
        
        # Global-token-to-patch similarity; this value is overwritten further down.
        sim_self = 10 * features[0, 0, :]@features[0, 1:, :].t()
        sim_self = sim_self.softmax(dim=-1).reshape(14, 14)
        
        # Top-10 vocabulary entries for the global token.
        candidate_voc_topk_labels = (100 * features[0, 0, :]@classifier.t()).topk(k=10).indices
        candidate_voc_topk_synsets = [ mapping_vocidx_to_synsets(x.item(), vocab)[0].name() for x in candidate_voc_topk_labels ]
        
        # Recompute the token features with with_proj=False and redo the self-similarity with them.
        features = model.visual(image.unsqueeze(0), return_spatial=True, with_proj=False)
        features = F.normalize(features, dim=-1)
        
        # Reload the same sample from dataset_r; `image` now refers to that dataset's item
        # (the plotting cell calls image.resize on it).
        image, label_voc, label_clu, _ = dataset_r[sampled_img_idx[idx]]
        sim_self = 10 * features[0, 0, :]@features[0, 1:, :].t()
        sim_self = sim_self.softmax(dim=-1).reshape(14, 14)

In [None]:
spatial_ind, spatial_score, spatial_ind_entire, spatial_score_entire

In [None]:
pprint([gt_class_label, candidate_voc_synsets, candidate_voc_topk_synsets], compact=True)

# Vocabulary index of the image being shown. The previous version compared against
# `gt_voc_label`, which is set by an unrelated cell for a different random class.
gt_voc_idx = int(label_voc)

fig, ax = plt.subplots(nrows=2, ncols=4, dpi=128)
# Row 0: predictions restricted to the cluster's candidate classifiers.
ax[0,0].imshow(np.array(image.resize([256,256])))
ax[0,0].axis(False)
ax[0,1].imshow(spatial_score.cpu().numpy())
ax[0,1].axis(False)
ax[0,2].imshow(spatial_ind.cpu().numpy())
ax[0,2].axis(False)
ax[0,3].imshow(spatial_score.cpu().numpy()>0.6)
ax[0,3].axis(False)

# Row 1: predictions over the entire vocabulary; panel 2 marks the patches whose
# top-1 vocabulary entry equals the ground-truth entry of this image.
ax[1,0].imshow(np.array(image.resize([256,256])))
ax[1,0].axis(False)
ax[1,1].imshow(spatial_score_entire.cpu().numpy())
ax[1,1].axis(False)
ax[1,2].imshow(spatial_ind_entire.cpu().numpy()==gt_voc_idx)
ax[1,2].axis(False)
ax[1,3].imshow(spatial_score_entire.cpu().numpy()>0.5)
ax[1,3].axis(False)

plt.show()

# Patches whose global-token similarity is strictly above the 10th largest value.
plt.figure()
plt.imshow(sim_self.cpu().numpy()>sim_self.flatten().topk(k=10).values[-1].item())
plt.show()

In [None]:
[ mapping_vocidx_to_synsets(x.item(), vocab)[0].definition() for x in candidate_voc_labels ]

In [None]:
s = (features[0, 1:, :]@features[0, 1:, :].t())[(sim_self.cpu().numpy()>sim_self.flatten().topk(k=10).values[-1].item()).flatten()]
for i in range(s.size(0)):
    plt.imshow(s[i].reshape(14, 14).cpu().numpy()>s[i].topk(k=20).values[-1].item())
    plt.show()

### Linguistic Prompt

In [None]:
from wordnet_utils import *
from PIL import Image
from pprint import pprint

In [None]:
def build_mutex_prompt(args, model, pairs):
    """
    Encode two-name prompts for a list of class-name pairs.

    Every template stored under the ``imagenet`` key of
    ``../templates_small_mutex.json`` is formatted with the two names of a pair,
    all resulting prompts are encoded with the text encoder, and the normalised
    per-template embeddings of a pair are averaged and normalised again.

    args:
        args: only ``args.device`` is read
        model: object providing ``encode_text``
        pairs: sequence of two-item sequences ``[name_a, name_b]``
    return:
        class_embedding: tensor([P x D])
        all_prompts: np.array([P x T]) with the formatted prompt strings
    """
    with open('../templates_small_mutex.json', 'rb') as f:
        templates = json.load(f)['imagenet']

    # One row of formatted prompts per pair, one column per template.
    all_prompts = np.array([
        [template.format(pair[0], pair[1]) for template in templates]
        for pair in pairs
    ])
    n_pairs, n_templates = all_prompts.shape

    tokens = tokenize(all_prompts.ravel()).to(args.device)
    encoded = model.encode_text(tokens)
    class_embeddings = encoded.view(n_pairs, n_templates, -1)
    # Normalise every template embedding, average over templates, normalise the mean.
    class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
    class_embedding = class_embeddings.mean(dim=1)
    class_embedding /= class_embedding.norm(dim=-1, keepdim=True)
    return class_embedding, all_prompts

In [None]:
model, preprocess = clip.load(args.arch)
model.to(args.device).eval()

res = torch.load(f'./cache-{args.dataset_name}.pth')
all_clu_pred = res['all_clu_pred']
label_voc_kmeans = res['label_voc_kmeans']
pred_kmeans_t = res['pred_kmeans_t']
cluster_ind_voc = res['cluster_ind_voc']
record_pred_kmeans_t = res['record_pred_kmeans_t']
all_gt_voc = res['all_gt_voc']
all_label_clu = res['all_label_clu']

In [None]:
class_topk_pred_voc_ind = all_clu_pred.topk(k=5).indices
class_topk_class_names = np.array([ mapping_vocidx_to_synsets(x, vocab)[0].name().split('.')[0] 
                                   for x in class_topk_pred_voc_ind.flatten().numpy() ]).reshape(-1, 5)
class_mutex_templates = []
for row in class_topk_class_names:
    class_mutex_templates.append([[x, y] for i, x in enumerate(row) for j, y in enumerate(row) if i!=j])

idx_class = 12
class_embedding_pair, prompts_pair = build_mutex_prompt(args, model, class_mutex_templates[idx_class])
classifier = get_classifier(args)
class_embedding_single = classifier[class_topk_pred_voc_ind[idx_class]]

In [None]:
self_sim = class_embedding_single@class_embedding_single.t()
sim_pair = class_embedding_pair.float() @ class_embedding_pair.float().t()

In [None]:
# Mean similarity between the prompt for (a, b) and the prompt for (b, a).
# The pair list was built as [[x, y] for i, x ... for j, y ... if i != j], i.e. in the
# order of itertools.permutations(range(n), 2). The flat index therefore skips the
# diagonal, and decoding it as [idx // 4, idx % 4] (as done before) picks the wrong
# partner pairs. Use an explicit (a, b) -> flat index table instead.
n_names = class_embedding_single.size(0)
ordered_pairs = list(itertools.permutations(range(n_names), 2))
assert len(ordered_pairs) == sim_pair.size(0), 'sim_pair does not match the pair enumeration'
pair_position = {pair: pos for pos, pair in enumerate(ordered_pairs)}

res = [sim_pair[pos, pair_position[(b, a)]].item() for pos, (a, b) in enumerate(ordered_pairs)]
print(np.mean(res))
# Reference value: mean off-diagonal similarity between the single-name classifiers.
print(self_sim.triu(1)[self_sim.triu(1).nonzero(as_tuple=True)].mean())

### Feature Extraction

In [None]:
from sklearn.cluster import KMeans
from my_util_package.evaluation import cluster_acc
from scipy.optimize import linear_sum_assignment as linear_assignment

subset = ['train', 'val'][0]
modelf = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
arch = 'vit_dino'

""" load dataset """
transform_f = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# dataset_f = get_datasets_oszsl(args, vocab, is_train=False, transform=transform_f)
if subset == 'train':
    dataset_f = get_datasets_oszsl(args, vocab, is_train=True, transform=transform_f, seed=1)
elif subset == 'val':
    dataset_f = get_datasets_oszsl(args, vocab, is_train=False, transform=transform_f, seed=1)
args.nb_classes = dataset_f.num_classes
loader_f = torch.utils.data.DataLoader(dataset_f, num_workers=8, batch_size=args.batch_size, shuffle=False)
dataset_r = get_datasets_oszsl(args, vocab, is_train=True, transform=None, seed=1)

if subset == 'train':
    pred_kmeans = torch.from_numpy(np.load(f'./pred_clu-{args.dataset_name}-train-{arch}.npy'))
elif subset == 'val':
    pred_kmeans = torch.from_numpy(np.load(f'./pred_clu-{args.dataset_name}-val-{arch}.npy'))
    
model, preprocess = clip.load(args.arch)
if args.clip_checkpoint:
    model.load_state_dict({k[len('model.'):]:v for k, v in torch.load(args.clip_checkpoint, map_location='cpu')['model'].items()}, strict=False)
model.to(args.device).eval()

In [None]:
### collect features and labels
amp_autocast = torch.cuda.amp.autocast
all_gt_label_voc = []
all_gt_label_clu = []
all_img_idx = []
all_features = []
with tqdm(total=len(loader_f)) as pbar:
    model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                logits = model.visual(images)
                logits = logits/logits.norm(dim=-1, keepdim=True)
                
                all_gt_label_voc.append(label_voc)
                all_gt_label_clu.append(label_clu)
                all_img_idx.append(idx_img)
                all_features.append(logits.cpu().numpy())
                
        pbar.update(1)

all_gt_label_voc = torch.cat(all_gt_label_voc, dim=0)
all_gt_label_clu = torch.cat(all_gt_label_clu, dim=0)
all_img_idx = torch.cat(all_img_idx, dim=0)
all_features = np.concatenate(all_features)

In [None]:
clip_store = {
    'all_gt_label_voc': all_gt_label_voc.cpu(),
    'all_gt_label_clu': all_gt_label_clu.cpu(),
    'all_img_idx': all_img_idx.cpu(),
    'all_features': all_features,
}
torch.save(clip_store, f'./cache/clip_store-{args.dataset_name}.pth')

In [None]:
args.dataset_name = 'make_entity13'


### Classifier Extraction

In [None]:
args.dataset_name = 'make_nonliving26'

In [None]:
from sklearn.cluster import KMeans
from my_util_package.evaluation import cluster_acc
from scipy.optimize import linear_sum_assignment as linear_assignment

subset = ['train', 'val'][0]
modelf = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
arch = 'vit_dino'

""" load dataset """
transform_f = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# dataset_f = get_datasets_oszsl(args, vocab, is_train=False, transform=transform_f)
if subset == 'train':
    dataset_f = get_datasets_oszsl(args, vocab, is_train=True, transform=transform_f, seed=1)
elif subset == 'val':
    dataset_f = get_datasets_oszsl(args, vocab, is_train=False, transform=transform_f, seed=1)
args.nb_classes = dataset_f.num_classes
loader_f = torch.utils.data.DataLoader(dataset_f, num_workers=8, batch_size=args.batch_size, shuffle=False)
dataset_r = get_datasets_oszsl(args, vocab, is_train=True, transform=None, seed=1)

if subset == 'train':
    pred_kmeans = torch.from_numpy(np.load(f'./pred_clu-{args.dataset_name}-train-{arch}.npy'))
elif subset == 'val':
    pred_kmeans = torch.from_numpy(np.load(f'./pred_clu-{args.dataset_name}-val-{arch}.npy'))
    
model, preprocess = clip.load(args.arch)
if args.clip_checkpoint:
    model.load_state_dict({k[len('model.'):]:v for k, v in torch.load(args.clip_checkpoint, map_location='cpu')['model'].items()}, strict=False)
model.to(args.device).eval()

classifier = get_classifier(args)
classifier = F.normalize(classifier.float(), dim=-1)

collect results

In [None]:
# Reload the cached CLIP image features and labels for the current dataset.
clip_store = torch.load(f'./cache/clip_store-{args.dataset_name}.pth')
all_gt_label_voc = clip_store['all_gt_label_voc']
all_gt_label_clu = clip_store['all_gt_label_clu']
all_img_idx = clip_store['all_img_idx']
all_features = clip_store['all_features']


### collect features and labels
# For every image: the 5 classifier rows with the highest similarity.
all_knn_classifier_ind = []
all_knn_classifier_val = []
with tqdm(total=len(loader_f)) as pbar:
    # NOTE(review): the loader is iterated only to obtain idx_img; the image tensors it
    # yields are discarded and the cached features are used instead.
    for idx_batch, batch in enumerate(loader_f):
        _, label_voc, label_clu, idx_img = batch
        # assumes row r of all_features belongs to the sample with index r — TODO confirm
        batch_features = torch.from_numpy(all_features[idx_img.numpy()])
        batch_features = F.normalize(batch_features.float(), dim=-1).to(args.device)
        sim = batch_features@classifier.t()
        sim_topk = sim.topk(k=5)
        all_knn_classifier_ind.append(sim_topk.indices.cpu())
        all_knn_classifier_val.append(sim_topk.values.cpu())
        
        pbar.update(1)

all_knn_classifier_ind = torch.cat(all_knn_classifier_ind, dim=0)
all_knn_classifier_val = torch.cat(all_knn_classifier_val, dim=0)

In [None]:
(100*all_knn_classifier_val).softmax(dim=-1)

In [None]:
torch.save(all_knn_classifier_ind, f'./cache/clip_img_knn-{args.dataset_name}.pth')

#### Classifier study

In [None]:
all_knn_classifier_ind = torch.load(f'./cache/clip_img_knn-{args.dataset_name}.pth')

In [None]:
all_knn_classifier_ind

In [None]:
ind, val = all_knn_classifier_ind.unique(return_counts=True)

In [None]:
### load SCD results
res = torch.load(f'./cache-{args.dataset_name}.pth')
all_clu_pred = res['all_clu_pred']
label_voc_kmeans = res['label_voc_kmeans']
pred_kmeans_t = res['pred_kmeans_t']
cluster_ind_voc = res['cluster_ind_voc']
record_pred_kmeans_t = res['record_pred_kmeans_t']
all_gt_voc = res['all_gt_voc']
all_label_clu = res['all_label_clu']

In [None]:
pred_kmeans = torch.from_numpy(np.load(f'./pred_clu-{args.dataset_name}-train-clip.npy'))
args.num_voc = classifier.size(0)
all_clu_pred = agg_by_pred_cluster(args, pred_kmeans.numpy(), all_knn_classifier_ind, voc_size=args.num_voc)

In [None]:
# Give every instance the top-K vocabulary candidates of the cluster it belongs to.
N = pred_kmeans.size(0)
K = all_knn_classifier_ind.size(1)
instance_assigned_pred = torch.zeros([N, K]).long()
for c in pred_kmeans.unique():
    select = (pred_kmeans==c)
    # Use K (the buffer width) instead of a hard-coded 5 so that the two cannot drift
    # apart. NOTE(review): cluster ids are used directly as row indices of
    # all_clu_pred — confirm they are contiguous from 0.
    instance_assigned_pred[select] = all_clu_pred[c].topk(k=K).indices

# Accuracy of the top-1 cluster candidate against the ground-truth vocabulary label.
print('acc', (instance_assigned_pred[:, 0]==all_gt_voc).float().mean().item())

In [None]:
class_select_idx = 2
class_select = (pred_kmeans==class_select_idx)
study_candidates = all_clu_pred.topk(k=3).indices[class_select_idx]
ind, val = all_gt_voc[class_select].unique(return_counts=True)

In [None]:
from wordnet_utils import *

print([mapping_vocidx_to_synsets(c.item(), vocab)[0].name().split('.')[0] for c in study_candidates])
print([mapping_vocidx_to_synsets(c.item(), vocab)[0].definition() for c in study_candidates])

# descriptions = \
# [
#     "A photo of an aircraft_carrier. It has a long, flat deck, with a large superstructure at the back, and a tall tower at the front. It is usually painted grey, and has multiple aircrafts parked on the deck.",
#     "A photo of a parking_meter. A metal box with a coin slot, a digital display, and a lever or button to activate the timer.",
#     "A photo of an ambulance. A vehicle typically characterized by a bright red or orange color, a siren, and a medical cross symbol.",
# ]
descriptions = \
[
    "A photo of a Winnebago. The Winnebago language is characterized by a distinct set of phonemes and a unique set of grammatical structures.",
    "A photo of an ambulance. A vehicle typically characterized by a bright red or orange color, a siren, and a medical cross symbol.",
    "A photo of police_van. A large, box-shaped vehicle with a distinctive black and white paint job and a barred window in the back."
]
descriptions = \
[
    "A photo of a gobiesox. Small, slender fish with a laterally compressed body and two separate dorsal fins.",
    "A photo of a sock. A foot covering that is typically made of cloth, reaching from the ankle to the knee.",
    "A photo of a athletic_sock. A sock typically made of a lightweight, breathable material with a reinforced heel and toe for added durability.",
]

In [None]:
study_classifier = tokenize(descriptions, truncate=True).to(args.device)
study_classifier = model.encode_text(study_classifier)
study_classifier = study_classifier/study_classifier.norm(dim=-1, keepdim=True)

In [None]:
clip_store = torch.load(f'./cache/clip_store-{args.dataset_name}.pth')
all_gt_label_voc = clip_store['all_gt_label_voc']
all_gt_label_clu = clip_store['all_gt_label_clu']
all_img_idx = clip_store['all_img_idx']
all_features = clip_store['all_features']

In [None]:
class_features = torch.from_numpy(all_features[class_select]).to(args.device)
class_label = all_gt_label_voc[class_select]

In [None]:
ind, val = class_label.unique(return_counts=True)
mapping_vocidx_to_synsets(ind[val.argmax(dim=-1)].item(), vocab)

In [None]:
(study_candidates[(class_features.float() @ study_classifier.float().t()).argmax(dim=-1)]==class_label).float().mean()

In [None]:
(class_label==study_candidates[2]).float().mean()

## Baseline

In [None]:
from wordnet_utils import *

def compute_knn_batch(tensor, k=5, exclude_self=False, batch_size=1024, device='cpu'):
    """Batched top-k neighbour search under dot-product similarity.

    Each batch of rows is compared against all rows (``batch @ tensor.T``) and
    the ``k`` largest entries per row are kept.

    Args:
        tensor: 2-D tensor ([N x D]).
        k: number of neighbours returned per row.
        exclude_self: if True, the best hit of every row is dropped and the next
            ``k`` hits are returned. This removes the row itself only when the
            self-similarity is the row maximum (e.g. L2-normalised rows).
        batch_size: number of query rows processed at once.
        device: device on which the similarity is computed.
    Returns:
        (all_topk_ind, all_topk_val): CPU tensors of shape [N x k]. With
        ``exclude_self=True`` and fewer than ``k + 1`` rows there are only
        ``N - 1`` columns.
    """
    n_rows = tensor.size(0)
    if n_rows == 0:
        # Nothing to search; torch.cat on an empty list would raise.
        return torch.empty(0, k, dtype=torch.long), torch.empty(0, k, dtype=tensor.dtype)

    # Query batches and the reference matrix must be on the same device.
    reference = tensor.to(device)
    # One extra hit is needed when the first one (self) is dropped; previously only
    # k hits were requested, so k - 1 neighbours came back.
    n_query = min(k + 1, n_rows) if exclude_self else k
    first = 1 if exclude_self else 0

    all_topk_ind = []
    all_topk_val = []
    for start in range(0, n_rows, batch_size):
        batch_sim = reference[start:start + batch_size, :] @ reference.t()
        batch_sim_topk = batch_sim.topk(k=n_query)
        all_topk_ind.append(batch_sim_topk.indices[:, first:].cpu())
        all_topk_val.append(batch_sim_topk.values[:, first:].cpu())
    all_topk_ind = torch.cat(all_topk_ind, dim=0)
    all_topk_val = torch.cat(all_topk_val, dim=0)
    return all_topk_ind, all_topk_val


def compute_similarity_with_augmented_classifier(features, candidate_names, 
                                                 class_name_key_mapping, all_augmented_classifier, 
                                                 method='ensemble', agg_func=None, return_type='max', **kwargs):
    """ Score one feature vector against the augmented classifiers of candidate names.

    Only a single instance is supported.

    Args:
        features: tensor([D])
        candidate_names: list of class names to score
        class_name_key_mapping: {`class_name`: [`synsets`]}
        all_augmented_classifier: {`synset`: tensor([M x D])}
        method: only 'ensemble' is implemented (mean over the M rows of a synset)
        agg_func: optional reducer applied to the list of per-synset scores of a name
        return_type: with `agg_func` set, 'max' returns the best name and 'topk'
            returns the `kwargs['k']` best names; any other value returns the dict
    Returns:
        Without `agg_func`: {`class_name`: [score per synset]}. With `agg_func`:
        see `return_type`.
    Raises:
        NotImplementedError: when a synset is scored with an unsupported `method`.
    """
    scores = {}
    for name in candidate_names:
        name_scores = scores.setdefault(name, [])
        for synset in class_name_key_mapping[name]:
            if method != 'ensemble':
                raise NotImplementedError()
            # Average the M augmented classifiers of the synset into one prototype.
            prototype = all_augmented_classifier[synset].to(features.device).float().mean(dim=0)
            score = 100 * features.view(1, -1) @ prototype.view(1, -1).t()
            name_scores.append(score.item())

    if agg_func is None:
        return scores

    for name in scores:
        scores[name] = agg_func(scores[name])
    if return_type == 'max':
        return max(scores, key=lambda name: scores[name])
    if return_type == 'topk':
        return heapq.nlargest(kwargs['k'], scores, key=scores.get)
    return scores
    
    
def agg_by_pred_cluster(args, pred_kmeans, all_topk_voc, voc_size):
    """Aggregate per-instance top-k vocabulary votes into per-cluster scores.

    For every cluster id in `pred_kmeans` the vocabulary indices in the top-k
    lists of its instances are counted and divided by the number of instances
    in the cluster.

    Args:
        args: kept for interface compatibility; no longer read (the vocabulary
            size comes from `voc_size`, which used to be ignored in favour of
            `args.num_voc`).
        pred_kmeans: np.array([N]) cluster id per instance
        all_topk_voc: np.array([N x K]) (a CPU tensor is accepted as well)
        voc_size: int, number of vocabulary entries V
    Returns:
        all_clu_pred: tensor([C x V]); row order follows the sorted unique
        cluster ids.
    """
    print('agg_by_pred_cluster')
    pred_kmeans = np.asarray(pred_kmeans)
    all_topk_voc = np.asarray(all_topk_voc)
    all_clu_pred = []
    n_count = []
    for i in np.unique(pred_kmeans):
        selected = (pred_kmeans==i)
        n_count.append(int(selected.sum()))
        # Vote histogram over the vocabulary for this cluster.
        counter_voc_ind, counter_val = np.unique(all_topk_voc[selected].ravel(), return_counts=True)
        clu_pred = torch.zeros(voc_size) # cluster-wise vote counts
        clu_pred[torch.from_numpy(counter_voc_ind).long()] = torch.from_numpy(counter_val).float()
        all_clu_pred.append(clu_pred)
    all_clu_pred = torch.stack(all_clu_pred, dim=0).cpu()
    n_count = torch.tensor(n_count).cpu()

    # Normalise the counts by cluster size.
    all_clu_pred = all_clu_pred/(n_count.view(-1, 1) + 1e-20)

    # Diagnostics: do two clusters share the same top-1 vocabulary entry?
    top1 = all_clu_pred.argmax(dim=-1)
    print('is mutex assignment::', top1.size(0)==top1.unique().size(0))
    print('assignment collision num::', len(list(filter(lambda x: x>1, Counter(top1.numpy()).values()))))
    return all_clu_pred

def linear_assign(all_clu_pred, pred_kmeans, all_gt_voc):
    """One-to-one assignment of clusters to vocabulary entries.

    Args:
        all_clu_pred: tensor([C x V]) cluster-to-vocabulary scores (higher is better)
        pred_kmeans: tensor([N]) cluster id per instance, used as row index into
            the assignment
        all_gt_voc: tensor([N]) ground-truth vocabulary index, only used for the
            printed accuracy
    Returns:
        label_voc_kmeans: tensor([N]) vocabulary index assigned to each instance
        res_ass: the (row_ind, col_ind) result of the assignment solver
    """
    print('linear_assign')
    score_mat = all_clu_pred.cpu().numpy()
    print(f'assignment shape={score_mat.shape}')
    # The solver minimises cost, so flip the scores around their maximum.
    res_ass = linear_assignment(score_mat.max() - score_mat)
    assigned_voc = res_ass[1]
    label_voc_kmeans = torch.tensor([assigned_voc[cluster.item()] for cluster in pred_kmeans])
    instance_acc = (label_voc_kmeans==all_gt_voc).float().mean().item()
    print('instance label acc::', instance_acc)
    return label_voc_kmeans, res_ass

def reassign_by_pred_cluster(label_voc_kmeans, loader_f, model, classifier, device, all_prob=None, 
                             instance_selected=None, 
                             classifier_selected=None):
    """Re-cluster instances by their most similar classifier among a restricted set.

    Args:
        label_voc_kmeans: tensor of vocabulary indices; when `classifier_selected`
            is None the prediction is restricted to these classifier columns
        loader_f: iterable of batches whose first four items are
            (images, label_voc, label_clu, idx_img); only used if `all_prob` is None
        model: object providing `visual(images)`; only used if `all_prob` is None
        classifier: tensor([V x D]) of classifier weights
        device: device used for the forward pass
        all_prob: optional precomputed tensor([N x V]); skips the forward pass
        instance_selected: optional boolean tensor indexed by `idx_img`; only
            selected instances are processed
        classifier_selected: tensor([C2]) of vocabulary indices; when given, the
            prediction is restricted to these classifiers instead
    Returns:
        cluster_ind: tensor of cluster ids, relabelled to 0..n_used-1 in sorted order
        cluster_ind_voc: vocabulary index assigned to each instance
    """
    print('reassign_by_pred_cluster')
    amp_autocast = torch.cuda.amp.autocast
    label_voc_kmeans = label_voc_kmeans.to(device)
    if all_prob is None:
        cluster_ind = []
        with tqdm(total=len(loader_f)) as pbar:
            if hasattr(model, 'eval'):
                model.eval()
            for idx_batch, batch in enumerate(loader_f):
                images, label_voc, label_clu, idx_img = batch[:4]
                if instance_selected is not None:
                    keep = instance_selected[idx_img]
                    if not keep.any():
                        # Nothing selected in this batch: skip it before moving
                        # data to the device, but still advance the progress bar.
                        pbar.update(1)
                        continue
                    images = images[keep]
                images = images.to(device)
                with amp_autocast():
                    with torch.no_grad():
                        logits = model.visual(images)
                        logits = logits/logits.norm(dim=-1, keepdim=True)

                        if classifier_selected is not None:
                            similarity = 100 * logits @ classifier[classifier_selected].t()
                            # The argmax is a position inside classifier_selected and is
                            # mapped back to a vocabulary index after the loop.
                            # (Indexing classifier_selected with the float softmax
                            # output, as done before, raises an IndexError.)
                            prob = similarity.softmax(-1)
                            cluster_ind.append(prob.cpu().argmax(dim=-1))
                        else:
                            similarity = 100 * logits @ classifier.t()
                            prob = similarity.softmax(-1)
                            cluster_ind.append(prob[:, label_voc_kmeans].cpu().argmax(dim=-1))
                pbar.update(1)
        cluster_ind = torch.cat(cluster_ind, dim=0)
    else:
        all_prob = all_prob[:, label_voc_kmeans]
        cluster_ind = all_prob.argmax(dim=-1)

    if classifier_selected is not None:
        cluster_ind_voc = classifier_selected[cluster_ind]
    else:
        cluster_ind_voc = label_voc_kmeans[cluster_ind]
    # Relabel the used ids to 0..n_used-1 (rank among the sorted unique values).
    _, cluster_ind = cluster_ind.cpu().unique(return_inverse=True)
    return cluster_ind, cluster_ind_voc


def row_wise_isin(a, b):
    """Cumulative row-wise membership test.

    Args:
        a: tensor([N]) values to look up
        b: tensor([N x K]) candidates per row
    Returns:
        CPU bool tensor([N x K]); entry [r, i] is True when a[r] equals any of
        b[r, 0..i].
    """
    n, k = b.size()
    results = []
    # Carry a running OR over the columns instead of recomputing the first i
    # comparisons for every i (which was quadratic in K).
    hit_so_far = torch.zeros_like(a).bool()
    for j in range(k):
        hit_so_far = hit_so_far | (a==b[:, j])
        results.append(hit_so_far)
    results = torch.stack(results, dim=1).cpu()
    return results

import openai
def request_gpt(prompt, model_name='text-davinci-003', max_tokens=400, temperature=0.01, best_of=1,
                max_retries=None, retry_delay=1.0):
    """Send a completion request to the OpenAI API, retrying on failure.

    The API key is read from the ``OPENAI_API_KEY`` environment variable (or an
    already configured ``openai.api_key``). It must not be stored in the
    notebook: the key that used to be hard-coded here should be treated as
    leaked and revoked.

    Args:
        prompt: prompt text
        model_name: completion model name
        max_tokens: completion length limit
        temperature: sampling temperature
        best_of: number of server-side candidates
        max_retries: maximum number of retries after a failed request; None
            (default) keeps retrying indefinitely, as before
        retry_delay: base delay in seconds between retries; doubled after every
            failure and capped at 32x the base
    Returns:
        The response object returned by ``openai.Completion.create``.
    Raises:
        RuntimeError: if no API key is configured.
        Exception: the last request error, once `max_retries` is exhausted.
    """
    api_key = os.environ.get('OPENAI_API_KEY')
    if api_key:
        openai.api_key = api_key
    elif not getattr(openai, 'api_key', None):
        raise RuntimeError('OpenAI API key is not configured; set the OPENAI_API_KEY environment variable.')

    n_failed = 0
    while True:
        try:
            response = openai.Completion.create(
                model=model_name,
                prompt=prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=1,
                best_of=best_of,
                frequency_penalty=0,
                presence_penalty=0,
            )
            return response
        except Exception as e:
            n_failed += 1
            print(f'e={e}')
            if max_retries is not None and n_failed > max_retries:
                raise
            # Back off instead of hammering the API in a tight loop.
            time.sleep(retry_delay * 2 ** min(n_failed - 1, 5))

def get_prompt_candidate_discrimination(candidates, attributes='color and shape'):
    """Build the first prompt variant asking GPT to describe each candidate word.

    Args:
        candidates: non-empty sequence of class-name strings.
        attributes: accepted for signature parity with the other variants but
            not used; this variant always asks for color and texture.

    Returns:
        The prompt string, with the candidates single-quoted and joined as
        ``'a', 'b', and 'c'``.
    """
    quoted = ["'" + word + "'" for word in candidates]
    leading, last = quoted[:-1], quoted[-1]
    candidate_string = ', and '.join((', '.join(leading), last))
    ### 42
    prompt = f"Precisely describe discriminative visual features of each word in {candidate_string}. Describe the color and texture. In two bullet points, each uses the template \"{'name'}: {'description'}\"."
    return prompt

def get_prompt_candidate_discrimination_v2(candidates, attributes='color and shape'):
    """Build the second prompt variant, which names the visual attributes to contrast.

    Args:
        candidates: non-empty sequence of class-name strings.
        attributes: free-text attribute list inserted after "e.g.,".

    Returns:
        The prompt string, with the candidates single-quoted and joined as
        ``'a', 'b', and 'c'``.
    """
    quoted = ["'" + word + "'" for word in candidates]
    leading, last = quoted[:-1], quoted[-1]
    candidate_string = ', and '.join((', '.join(leading), last))
    prompt = f"Precisely distinguish discriminative visual features (e.g., {attributes}) of each category in {candidate_string}. Each category is elaborated in separate sentence with template \"category_name: description\". Do not use comparative degree."
    return prompt

def get_prompt_candidate_discrimination_v3(candidates, attributes='color and shape'):
    """Build the third prompt variant asking for differences between the categories.

    The prompt text says "five categories" regardless of how many candidates
    are passed, and ``attributes`` is accepted but not used.

    Args:
        candidates: non-empty sequence of class-name strings.
        attributes: ignored.

    Returns:
        The prompt string (it ends with a trailing space).
    """
    quoted = ["'" + word + "'" for word in candidates]
    leading, last = quoted[:-1], quoted[-1]
    candidate_string = ', and '.join((', '.join(leading), last))
    prompt = f"There are five categories: {candidate_string}. Closely and Precisely mention all discriminative visual differences between each category and others. Each category is described in a caption with template \"category_name: description\". "
    return prompt

def get_prompt_candidate_discrimination_v4(candidates, attributes='color and shape'):
    """Build the fourth prompt variant: one distinguishing sentence per category.

    The prompt text says "three category nouns" and "three lines" regardless of
    how many candidates are passed, and ``attributes`` is accepted but not used.

    Args:
        candidates: non-empty sequence of class-name strings.
        attributes: ignored.

    Returns:
        The prompt string.
    """
    quoted = ["'" + word + "'" for word in candidates]
    leading, last = quoted[:-1], quoted[-1]
    candidate_string = ', and '.join((', '.join(leading), last))
    ### 44.8
    # Earlier wording kept for reference (five categories, no output template):
    # prompt = f"Please generate visual descriptions based on the following five category nouns (taken from WordNet): {candidate_string}. For each category, provide a separate sentence that highlights representative visual features that distinguish it from the others. Be precise and concise in your descriptions, using vivid language to help bring each category to life. Please make sure to emphasize the visual differences between the categories, and describe each category in a way that highlights how it differs from the others visually."
    prompt = f"Please generate visual descriptions based on the following three category nouns (taken from WordNet): {candidate_string}. For each category, sequentially provide a separate sentence that highlights representative visual features that distinguish it from the others. Be precise and concise in your descriptions, using vivid language to help bring each category to life. Please make sure to emphasize the visual differences between the categories, and describe each category in a way that highlights how it differs from the others. Write in three lines, use template \"{'name'}: {'description'}\"."
    return prompt

def get_prompt_candidate_discrimination_pair(candidates, attributes, index, enable_def=False):
    """Build the pairwise prompt contrasting the candidates selected by ``index``.

    Args:
        candidates: sequence of class-name strings.
        attributes: accepted for signature parity; not used by this prompt.
        index: iterable of positions into ``candidates`` (two are expected by
            the prompt wording); the names are listed in this order.
        enable_def: kept for backward compatibility. The previous
            implementation rebuilt a byte-identical prompt when this flag was
            set, so it never had an effect; the duplicated branch was removed.

    Returns:
        The prompt string, with the selected names single-quoted and joined
        with " and ".
    """
    pair = [candidates[i] for i in index]
    candidate_string = '\'' + '\' and \''.join(pair) + '\''
    prompt = f"Given two category nouns (taken from WordNet): {candidate_string}. For each category, please generate a separate phrase that highlights representative visual features that distinguish it from the others. Be precise and concise in your descriptions, using vivid language to help bring each category to life. Please make sure to only emphasize the visual differences between the categories, and describe each category in a way that highlights how it differs from the others in appearance. Write in two lines, use template \"name: description\"."
    return prompt

def get_prompt_candidate_discrimination_pair_caption(candidates, attributes, index, enable_def=False):
    """Build the pairwise "co-occurring caption words" prompt.

    The selected names are shuffled in place with the global ``random`` state
    before being written into the prompt, so their order in the prompt is not
    the order of ``index``.

    Args:
        candidates: sequence of class-name strings.
        attributes: accepted for signature parity; not used.
        index: iterable of positions into ``candidates``.
        enable_def: accepted for signature parity; not used.

    Returns:
        The prompt string.
    """
    selected = [candidates[position] for position in index]
    random.shuffle(selected)
    candidate_string = "'{}'".format("' and '".join(selected))
    prompt = f"Given two category nouns (taken from WordNet): {candidate_string}. For each category, imagine given one photo, please generate a list of frequently co-occurred related words in image caption dataset in descending order. Write in one line for each category. Only mention the word that the other do not have."
    return prompt

def build_classifier_from_prompt_response(args, model, response):
    """Turn ``"name: description"`` lines into one normalised text embedding each.

    Every line is split at its first ``': '`` into a lower-cased name and the
    remaining description, both are substituted into each template read from
    ``../templates_small_description.json`` (key ``'imagenet'``), the filled
    templates are encoded with ``model.encode_text`` and the per-template
    embeddings are L2-normalised, averaged and L2-normalised again.

    Args:
        args: object providing ``device``.
        model: object providing ``encode_text``.
        response: iterable of strings, one per class.

    Returns:
        Tensor stacking one embedding per response line along dim 0, on
        ``args.device``.
    """
    with open('../templates_small_description.json', 'rb') as f:
        templates_small = json.load(f)['imagenet']

    def _fill_templates(line):
        # partition() splits on the first ': ' only; with no separator the description is ''.
        name, _, description = line.partition(': ')
        return [t.format(name.lower(), description) for t in templates_small]

    all_prompts = [_fill_templates(r) for r in response]

    all_aug_classifiers = []
    for filled in all_prompts:
        tokens = tokenize(filled, truncate=True).to(args.device)
        with torch.no_grad():
            embeddings = model.encode_text(tokens)
        embeddings = embeddings/embeddings.norm(dim=-1, keepdim=True)
        ensembled = embeddings.mean(dim=0) ### ensembling over templates
        ensembled = ensembled/ensembled.norm(dim=-1, keepdim=True)
        all_aug_classifiers.append(ensembled)
    return torch.stack(all_aug_classifiers, dim=0).to(args.device)
    


In [None]:
from sklearn.cluster import KMeans
from my_util_package.evaluation import cluster_acc
from scipy.optimize import linear_sum_assignment as linear_assignment

# Which split to build: index 0 selects 'train', index 1 would select 'val'.
subset = ['train', 'val'][0]
# DINO ViT-B/16 backbone fetched through torch.hub.
modelf = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
arch = 'vit_dino'

# --- load dataset ---
# Resize -> center crop -> tensor -> per-channel normalisation.
transform_f = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Both splits use the same call and differ only in `is_train`, so derive the flag from `subset`.
if subset in ('train', 'val'):
    dataset_f = get_datasets_oszsl(args, vocab, is_train=(subset == 'train'), transform=transform_f, seed=1)
args.nb_classes = dataset_f.num_classes
loader_f = torch.utils.data.DataLoader(dataset_f, batch_size=args.batch_size, shuffle=False, num_workers=8)

### Baseline Classifier Build

In [None]:
""" from MUST """
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
_tokenizer = _Tokenizer()

def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """Encode text(s) into a zero-padded matrix of token ids.

    Each text is wrapped as ``<|startoftext|> ... <|endoftext|>`` using the
    module-level ``_tokenizer``.

    Args:
        texts: a single string or a list of strings.
        context_length: number of columns of the returned tensor.
        truncate: when True, over-long sequences are cut to ``context_length``
            and their last token is replaced by the end-of-text token.

    Returns:
        ``torch.long`` tensor of shape ``(len(texts), context_length)``.

    Raises:
        RuntimeError: if a sequence exceeds ``context_length`` and ``truncate``
            is False.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token, *_tokenizer.encode(text), eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for row, tokens in enumerate(all_tokens):
        too_long = len(tokens) > context_length
        if too_long and not truncate:
            raise RuntimeError(f"Input {texts[row]} is too long for context length {context_length}")
        if too_long:
            # Keep the end-of-text marker as the final token of a truncated sequence.
            tokens = tokens[:context_length]
            tokens[-1] = eot_token
        result[row, :len(tokens)] = torch.tensor(tokens)

    return result

In [None]:
### load prompts and templates
# NOTE(review): the path ends in `.json` but the file is read with pickle.load —
# confirm which format the cache was written in. pickle must only be used on trusted files.
output_fpath = './cache/parsed-wn-gpt3-d-2023_02_26.json'
with open(output_fpath, 'rb') as f:
    all_parse_results = pickle.load(f)

# Prompt templates: the 'imagenet' entry of the JSON file.
with open('../templates_small.json', 'rb') as f:
    templates_small = json.load(f)['imagenet']

In [None]:
### synset and prompts
# For every key of all_parse_results, fill each template with the part of the key
# before the first '.', then append the parsed text after a single space.
text_inputs = {}
for k, v in all_parse_results.items():
    class_name = k.split('.')[0]
    text_inputs[k] = [f'{t.format(class_name)} {v}' for t in templates_small]

In [None]:
### class name to synsets mapping
# Group the keys of text_inputs by the part before their first '.'.
class_name_key_mapping = {}
for k in text_inputs:
    class_name_key_mapping.setdefault(k.split('.')[0], []).append(k)

In [None]:
### extract prompt embeddings
# Encode every list of filled templates with the text encoder and keep the
# L2-normalised, per-template embeddings on CPU, keyed like text_inputs.
# `model` is defined outside this part of the notebook.
all_augmented_classifier = {}
with tqdm(total=len(text_inputs)) as pbar:
    for k, v in text_inputs.items():
        aug_classifier = tokenize(v, truncate=True).to(args.device)
        with torch.no_grad():
            aug_classifier = model.encode_text(aug_classifier)
        aug_classifier = aug_classifier/aug_classifier.norm(dim=-1, keepdim=True)
        all_augmented_classifier[k] = aug_classifier.cpu()
        
        pbar.update(1)

In [None]:
# Bundle the per-key prompt embeddings with the class-name -> keys mapping and cache them on disk.
data_augmented_classifier = {
    'all_augmented_classifier': all_augmented_classifier,
    'class_name_key_mapping': class_name_key_mapping,
}
torch.save(data_augmented_classifier, './cache/all_aug_prompts_embed-wn-gpt3-d-2023_02_26.pth')

### classifier statistics

In [None]:
### load SCD results
# NOTE(review): `args.dataset_name` is not an attribute of the Config class at the top of
# the notebook (it defines `dataset`); it must be assigned elsewhere before this cell runs.
# NOTE(review): this first cache lives next to the notebook (`./cache-...`), the second one
# inside `./cache/` — confirm both locations are intended.
res = torch.load(f'./cache-{args.dataset_name}.pth')
all_clu_pred = res['all_clu_pred']
label_voc_kmeans = res['label_voc_kmeans']
pred_kmeans_t = res['pred_kmeans_t']
cluster_ind_voc = res['cluster_ind_voc']
record_pred_kmeans_t = res['record_pred_kmeans_t']
all_gt_voc = res['all_gt_voc']
all_label_clu = res['all_label_clu']

# Cached image features and ground-truth labels.
clip_store = torch.load(f'./cache/clip_store-{args.dataset_name}.pth')
all_gt_label_voc = clip_store['all_gt_label_voc']
all_gt_label_clu = clip_store['all_gt_label_clu']
all_img_idx = clip_store['all_img_idx']
all_features = clip_store['all_features']

In [None]:
# Text classifier built by the imported helper get_classifier; indexed by vocabulary id below.
classifier = get_classifier(args)

In [None]:
# 5 nearest neighbours of every classifier row among the classifier rows themselves.
# compute_knn_batch is defined outside this part of the notebook.
classifier_all_topk_ind, classifier_all_topk_val = compute_knn_batch(classifier.to(args.device), 
                                                                     k=5, exclude_self=True, batch_size=512, 
                                                                     device=args.device)


In [None]:
# Display: overall mean of the neighbour values, and their mean along dim 0.
classifier_all_topk_val.mean(dim=-1).mean(), classifier_all_topk_val.mean(dim=0)

In [None]:
# For every ground-truth vocabulary label: print the mean similarity between the class's
# image features and its classifier row, then query k-NN inside the class.
for c in all_gt_label_voc.unique():
    select = (all_gt_label_voc==c)
    subset_features = torch.from_numpy(all_features[select]).to(args.device)
    subset_features = subset_features/subset_features.norm(dim=-1, keepdim=True)
    
    target_classifier = classifier[c].to(args.device).view(1, -1)
    target_classifier = target_classifier/target_classifier.norm(dim=-1, keepdim=True)
    
    sim_intra_img_text = subset_features@target_classifier.t()
    print(sim_intra_img_text.mean())
    
    # NOTE(review): `sim_intra` and `mask` are computed but not read again in this cell.
    sim_intra = subset_features@subset_features.t()
    mask = torch.ones_like(sim_intra)
    mask = torch.scatter(mask, 1, torch.arange(mask.size(0), device=args.device).view(-1, 1), 0).bool()
    # NOTE(review): these two results are overwritten on every iteration, so only the
    # last class's neighbours survive the loop.
    intra_topk_ind, intra_topk_val = compute_knn_batch(subset_features.to(args.device), 
                                                       k=5, exclude_self=True, batch_size=512, 
                                                       device=args.device)

In [None]:
# 5 nearest neighbours of every image feature among all image features
# (replaces the per-class values left behind by the previous cell).
intra_topk_ind, intra_topk_val = compute_knn_batch(torch.from_numpy(all_features).to(args.device), 
                                                   k=5, exclude_self=True, batch_size=512, 
                                                   device=args.device)

In [None]:
# For every ground-truth label, print the mean pairwise dot product between that class's
# features, excluding each sample's product with itself.
for c in all_gt_label_voc.unique():
    select = (all_gt_label_voc==c)
    class_features = torch.from_numpy(all_features)[select].to(args.device)
    sim = class_features@class_features.t()
    # Boolean mask that is False on the diagonal and True elsewhere.
    mask = torch.ones_like(sim)
    mask = torch.scatter(mask, 1, torch.arange(mask.size(0), device=args.device).view(-1, 1), 0).bool()
    # sim[mask] is already flattened, so the inner mean reduces it to a scalar.
    print(sim[mask].mean(dim=-1).mean())
    # intra_topk_ind, intra_topk_val = compute_knn_batch(torch.from_numpy(all_features)[select].to(args.device), 
    #                                                    k=5, exclude_self=True, batch_size=512, 
    #                                                    device=args.device)

In [None]:
# Display: overall mean of the image-feature neighbour values, and their mean along dim 0.
intra_topk_val.mean(dim=-1).mean(), intra_topk_val.mean(dim=0)

### Method test

In [None]:
# Reload the cached prompt embeddings and the class-name -> keys mapping from disk.
data_augmented_classifier = torch.load('./cache/all_aug_prompts_embed-wn-gpt3-d-2023_02_26.pth')
all_augmented_classifier = data_augmented_classifier['all_augmented_classifier']
class_name_key_mapping = data_augmented_classifier['class_name_key_mapping']


#### naive method

In [None]:
# Run the visual encoder over loader_f and collect, per image: the top-`prob_k` vocabulary
# indices, both ground-truth labels, and the L2-normalised image feature.
classifier = get_classifier(args)
classifier = classifier/classifier.norm(dim=-1, keepdim=True)
args.num_voc = classifier.size(0)
amp_autocast = torch.cuda.amp.autocast
### collect variables
prob_k = 5
all_instance_voc_topk_ind = []
all_gt_label_voc = []
all_gt_label_clu = []
all_features = []
with tqdm(total=len(loader_f)) as pbar:
    if hasattr(model, 'eval'):
        model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch[:4]
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                # `logits` holds the image embeddings, normalised on the next line.
                logits = model.visual(images)
                logits = logits/logits.norm(dim=-1, keepdim=True)
                # NOTE(review): requires `classifier` to already be on the same device as `logits`.
                similarity = 100 * logits @ classifier.t()
                prob = similarity.softmax(-1)
                prob_topk_ind = prob.topk(k=prob_k, dim=-1).indices
                all_instance_voc_topk_ind.append(prob_topk_ind.cpu().numpy())
                all_gt_label_voc.append(label_voc)
                all_gt_label_clu.append(label_clu)
                all_features.append(logits.cpu().numpy())
        pbar.update(1)

# Concatenate the per-batch pieces into dataset-level arrays/tensors.
all_instance_voc_topk_ind = np.concatenate(all_instance_voc_topk_ind)
all_gt_label_voc = torch.cat(all_gt_label_voc, dim=0)
all_gt_label_clu = torch.cat(all_gt_label_clu, dim=0)
all_features = np.concatenate(all_features)

In [None]:
# Cluster assignments loaded from disk; the file name embeds args.dataset_name.
pred_kmeans = torch.from_numpy(np.load(f'./pred_clu-{args.dataset_name}-train-clip.npy'))
# all_clu_pred = agg_by_pred_cluster(args, pred_kmeans.numpy(), all_instance_voc_topk_ind, voc_size=args.num_voc)

# Three rounds of: aggregate instance predictions per cluster -> assign a vocabulary label
# per cluster -> reassign instances. The three helpers are defined outside this part of the notebook.
pred_kmeans_t = pred_kmeans
history_set_pred = []
for t in range(3):
    # Remember the assignment that all_clu_pred is computed from in this round.
    record_pred_kmeans_t = pred_kmeans_t
    all_clu_pred = agg_by_pred_cluster(args, pred_kmeans_t.numpy(), all_instance_voc_topk_ind, voc_size=args.num_voc)
    label_voc_kmeans, res_ass = linear_assign(all_clu_pred, pred_kmeans_t, all_gt_label_voc)
    pred_kmeans_t, cluster_ind_voc = reassign_by_pred_cluster(label_voc_kmeans, loader_f, model, classifier, args.device, all_prob=None)
    set_pred = set(res_ass[1].tolist())
    set_gt = set(all_gt_label_voc.unique().numpy().tolist())
    # Number of ground-truth labels absent from the assigned label set.
    print('missing label::', len(set_gt - set_pred))
    print('cluster acc', cluster_acc(y_true=all_gt_label_clu.numpy(), y_pred=pred_kmeans_t.numpy()))
    history_set_pred.append(set_pred)
    
    

In [None]:
# For each attribute phrase: ask GPT to describe the top-K vocabulary candidates of every
# cluster, embed the descriptions, and let the cluster's images vote for one candidate.
# NOTE(review): the v4 prompt builder used below does not read `attributes`, so the phrase
# only affects the printed header.
for attributes in [
    # 'texture and shape',
    # 'shape and texture',
    # 'color and shape',
    # 'shape and color',
    # 'components and color',
    # 'color and components',
    # 'components and texture',
    # 'texture and components',
    # 'components and shape',
    # 'shape and components',
    # 'texture and color',
    'color and texture',
    # 'components, shape, and color',
    # 'shape, color, and components',
    # 'color, components, and shape',
    # 'components, color, and shape',
    # 'shape, components, and color',
    # 'color, shape, and components',
]:
    topK = 3
    print(attributes)
    cluster_topk_voc_ind = all_clu_pred.topk(k=topK).indices.cpu()
    class_prediction = []
    record_response = []
    all_prompt_response = []
    all_aug_classifiers = []
    with tqdm(total=len(cluster_topk_voc_ind)) as pbar:
        for idx, row in enumerate(cluster_topk_voc_ind):
            # Candidate names: first synset of each vocabulary index, text before the first '.'.
            candidates = [mapping_vocidx_to_synsets(x, vocab)[0].name().split('.')[0] for x in row.numpy()[:topK]]
            # candidates_def = [mapping_vocidx_to_synsets(x, vocab)[0].name().split('.')[0] for x in row.numpy()[:topK]]
            # Retry until the response yields exactly topK description lines; there is no retry limit.
            while 1:
                try:
                    prompt = get_prompt_candidate_discrimination_v4(candidates, attributes)
                    response = request_gpt(prompt, model_name='text-davinci-003')
                    # Keep the non-empty lines of the first completion.
                    response = list(filter(lambda x: len(x), response['choices'][0]['text'].lstrip('\n').split('\n')))
                    # response = record_response[idx]
                    aug_classifiers = build_classifier_from_prompt_response(args, model, response) ### K x D
                    assert aug_classifiers.size(0)==topK
                    all_prompt_response.append(response)
                    break
                except Exception as e:
                    print(e)
                    print(response)
            
            # Majority vote of the cluster's images over the K description embeddings.
            subset_features = torch.from_numpy(all_features[record_pred_kmeans_t==idx]).to(args.device)
            sim = subset_features.float() @ aug_classifiers.float().t()
            ind, count = sim.argmax(dim=-1).unique(return_counts=True)
            class_prediction.append(cluster_topk_voc_ind[idx, ind[count.argmax()]].item())
            record_response.append(response)
            all_aug_classifiers.append(aug_classifiers)
            pbar.update(1)
    class_prediction = torch.tensor(class_prediction)
    all_aug_classifiers = torch.cat(all_aug_classifiers, dim=0)

    # Broadcast each cluster's predicted label to its member images and score against ground truth.
    N = pred_kmeans.size(0)
    instance_assigned_pred = torch.zeros(N).long()
    for c in record_pred_kmeans_t.unique():
        select = (record_pred_kmeans_t==c)
        instance_assigned_pred[select] = class_prediction[c]
    print('acc', (instance_assigned_pred==all_gt_label_voc).float().mean().item())
    # 'conflict' = number of distinct ground-truth labels minus number of distinct predicted labels.
    print('conflict', len(all_gt_label_voc.unique()) - len(instance_assigned_pred.unique()))

In [None]:
def _normalize_class_name(name):
    """Lower-case ``name``, turn underscores into spaces and trim non-alphanumeric edges (quotes, bullets)."""
    return re.sub(r'^[^a-z0-9]+|[^a-z0-9]+$', '', name.lower().replace('_', ' '))


# Pairwise "tournament" over the top-K vocabulary candidates of every cluster: the current
# winner is compared with the next candidate, three GPT samples are drawn per comparison and
# the cluster's images vote for one of the two descriptions.
#
# Fix: the vote used to be the ROW index (0/1) of the two-line response and was then reused
# as a CANDIDATE index, so candidate 2 could never win; the prompt helper also shuffles the
# pair, so rows need not follow the order of `index`. Each response row is now mapped back
# to its candidate through the name in front of ': '.
topK = 3
cluster_topk_voc_ind = all_clu_pred.topk(k=topK).indices.cpu()
class_prediction = []
record_response = []
all_prompt_response = []
all_aug_classifiers = []
with tqdm(total=len(cluster_topk_voc_ind)) as pbar:
    for idx, row in enumerate(cluster_topk_voc_ind):
        ### parse candidate class names
        candidates = [mapping_vocidx_to_synsets(x, vocab)[0].name().split('.')[0] for x in row.numpy()[:topK]]
        ### get subset features
        subset_features = torch.from_numpy(all_features[record_pred_kmeans_t==idx]).to(args.device)
        ### candidate index flag
        curr_candidate_idx_1 = 0 ### head (current winner, index into `candidates`)
        curr_candidate_idx_2 = 1 ### tail (next challenger, index into `candidates`)
        ### record
        pair_prompts = []
        while curr_candidate_idx_2<topK:
            pair_idx = [curr_candidate_idx_1, curr_candidate_idx_2]
            name_to_candidate = {_normalize_class_name(candidates[i]): i for i in pair_idx}
            if len(name_to_candidate) < 2:
                # Both candidates share one surface name: the prompt cannot tell them apart,
                # so keep the current head and move on to the next challenger.
                pair_prompts.append([])
                curr_candidate_idx_2 = curr_candidate_idx_2 + 1
                continue
            ### get pair prompts (the helper ignores its `attributes` argument, so pass None
            ### instead of relying on a variable leaked from another cell)
            prompt = get_prompt_candidate_discrimination_pair_caption(candidates, None,
                                                              index=pair_idx,
                                                             )
            ### record
            pred_ind = []
            pair_repeat_prompts = []
            ### repeat
            for _ in range(3):
                # Retry until the response has two lines naming both candidates; there is no
                # retry limit, as before.
                while 1:
                    response = None
                    try:
                        response = request_gpt(prompt, model_name='text-davinci-003', max_tokens=200, temperature=0.7, best_of=1)
                        response = list(filter(lambda x: len(x), response['choices'][0]['text'].lstrip('\n').split('\n')))
                        aug_classifiers = build_classifier_from_prompt_response(args, model, response) ### K x D
                        ### constraint
                        assert aug_classifiers.size(0)==2
                        ### map each response row back to the candidate it describes
                        row_to_candidate = [name_to_candidate[_normalize_class_name(r.split(': ')[0])] for r in response]
                        assert sorted(row_to_candidate) == sorted(pair_idx)
                        ### record
                        pair_repeat_prompts.append(response)
                        all_aug_classifiers.append(aug_classifiers)
                        break
                    except Exception as e:
                        print(e)
                        print(response)
                sim = subset_features.float() @ aug_classifiers.float().t()
                ind, count = sim.argmax(dim=-1).unique(return_counts=True)
                pred_ind.append(row_to_candidate[ind[count.argmax()].item()])
            ind, count = torch.tensor(pred_ind).unique(return_counts=True)
            curr_candidate_idx_1 = ind[count.argmax()].item() ### winner (candidate index)
            curr_candidate_idx_2 = curr_candidate_idx_2 + 1
            pair_prompts.append(pair_repeat_prompts)
        ### results
        class_prediction.append(cluster_topk_voc_ind[idx, curr_candidate_idx_1].item())
        all_prompt_response.append(pair_prompts)
        pbar.update(1)

class_prediction = torch.tensor(class_prediction)
all_aug_classifiers = torch.cat(all_aug_classifiers, dim=0)

# Broadcast each cluster's predicted label to its member images and score against ground truth.
N = pred_kmeans.size(0)
instance_assigned_pred = torch.zeros(N).long()
for c in record_pred_kmeans_t.unique():
    select = (record_pred_kmeans_t==c)
    instance_assigned_pred[select] = class_prediction[c]
print('acc', (instance_assigned_pred==all_gt_label_voc).float().mean().item())
print('conflict', len(all_gt_label_voc.unique()) - len(instance_assigned_pred.unique()))


In [None]:
# torch.save({'all_prompt_response': all_prompt_response}, f'./cache/all_prompt_response-{args.dataset_name}-pair.pth')

In [None]:
### upperbound performance of cluster-wise assignment
# Give every cluster its most frequent ground-truth label: the best accuracy reachable when
# all members of a cluster receive the same label.
N = pred_kmeans.size(0)
instance_assigned_pred = torch.zeros(N).long()
for c in record_pred_kmeans_t.unique():
    select = (record_pred_kmeans_t==c)
    ind_gt, count_gt = all_gt_label_voc[record_pred_kmeans_t==c].unique(return_counts=True) 
    instance_assigned_pred[select] = ind_gt[count_gt.argmax()]
print('acc', (instance_assigned_pred==all_gt_label_voc).float().mean().item())
print('conflict', len(all_gt_label_voc.unique()) - len(instance_assigned_pred.unique()))
### class recall performance of SCD topK predictions
# For each cluster, is its majority ground-truth label among the first 3 candidates?
# NOTE(review): the slice `:3` is hard-coded rather than tied to `topK`.
recall = []
all_gtlbl = []
for idx in range(len(cluster_topk_voc_ind)):
    ind_gt, count_gt = all_gt_label_voc[record_pred_kmeans_t==idx].unique(return_counts=True) 
    gtlbl = ind_gt[count_gt.argmax()]
    recall.append(torch.isin(gtlbl, cluster_topk_voc_ind[idx, :3]).item())
    all_gtlbl.append(gtlbl)

recall = torch.tensor(recall)
all_gtlbl = torch.tensor(all_gtlbl)
recall.float().mean()

entropy partition experiment

In [None]:
# Compare the softmax-ed top-1 aggregated score of clusters whose top-1 candidate matches the
# majority ground-truth label ('true') with the clusters where it does not ('false').
N = pred_kmeans.size(0)
instance_assigned_pred = torch.zeros(N).long()
# normalize_sum = lambda x: x/x.sum(dim=-1, keepdim=True)
# `entropy` is only referenced by the commented-out alternatives below.
entropy = lambda p, a=1: -((a*p)*((a*p)+1e-20).log()).sum()
record_true = []
record_false = []
# Softmax over each cluster's top-K values; `topK` comes from an earlier cell.
cluster_topk_voc_val = (1 * all_clu_pred.topk(k=topK).values.cpu()).softmax(-1)
for c in record_pred_kmeans_t.unique():
    select = (record_pred_kmeans_t==c)
    ind_gt, count_gt = all_gt_label_voc[record_pred_kmeans_t==c].unique(return_counts=True) 
    if (ind_gt[count_gt.argmax()]==cluster_topk_voc_ind[c][0]).item():
        record_true.append(cluster_topk_voc_val[c][0])
        # record_true.append(entropy(cluster_topk_voc_val[c]))
    else:
        record_false.append(cluster_topk_voc_val[c][0])
        # record_false.append(entropy(cluster_topk_voc_val[c]))
    # break
    
# NOTE(review): sns.distplot is deprecated in recent seaborn releases — confirm the installed version.
plt.figure()
sns.distplot(torch.tensor(record_true).numpy(), bins=100)
sns.distplot(torch.tensor(record_false).numpy(), bins=100)
plt.legend(['true', 'false'])
plt.show()

In [None]:
# Same two distributions as the previous cell, drawn without an explicit figure or legend.
sns.distplot(torch.tensor(record_true).numpy(), bins=100)
sns.distplot(torch.tensor(record_false).numpy(), bins=100)

In [None]:
# Distribution of the softmax-ed top-1 value over all clusters.
sns.distplot(cluster_topk_voc_val[:, 0].numpy())

In [None]:
# with open(f'./all_prompt_response_v4-{args.dataset_name}.pkl', 'wb') as f:
#     pickle.dump(all_prompt_response, f)

In [None]:
# with open(f'./all_prompt_response_v4-{args.dataset_name}.pkl', 'rb') as f:
#     all_prompt_response = pickle.load(f)

In [None]:
# Collect the clusters whose predicted label differs from their majority ground-truth label.
false_pred_idx_list = []
for idx in range(len(cluster_topk_voc_ind)):
    ind_gt, count_gt = all_gt_label_voc[record_pred_kmeans_t==idx].unique(return_counts=True) 
    correct_pred = (ind_gt[count_gt.argmax()] == class_prediction[idx]).item()
    if not correct_pred:
        false_pred_idx_list.append(idx)

In [None]:
# Inspect one randomly chosen mis-predicted cluster: rebuild its description classifiers from
# the stored response and repeat the image vote.
# NOTE(review): this expects all_prompt_response[idx] to be a flat list of "name: description"
# lines; the pairwise-tournament cell stores nested lists there instead — confirm which cell
# produced all_prompt_response before running.
# np.random.seed(1)
idx = np.random.choice(false_pred_idx_list)
# idx = idx + 1
response = all_prompt_response[idx]
aug_classifiers = build_classifier_from_prompt_response(args, model, response)
subset_features = torch.from_numpy(all_features[record_pred_kmeans_t==idx]).to(args.device)
sim = subset_features.float() @ aug_classifiers.float().t()
ind, count = sim.argmax(dim=-1).unique(return_counts=True)
ind_gt, count_gt = all_gt_label_voc[record_pred_kmeans_t==idx].unique(return_counts=True) 

print(all_prompt_response[idx])
print(f'ind={ind}, count={count}')
print(f'cand={cluster_topk_voc_ind[idx]}, prev_pred={class_prediction[idx]}')
print(f'pred={cluster_topk_voc_ind[idx,ind[count.argmax()]]}, gt={ind_gt[count_gt.argmax()]}')
print('synset=', mapping_vocidx_to_synsets(ind_gt[count_gt.argmax()].item(), vocab))

In [None]:
# Ensemble over several attribute phrasings: one GPT request per phrase and cluster, the
# image-description similarities are averaged over the phrases, and the cluster's images
# vote for one of the top-5 candidates.
#
# Fixes:
#   * the similarity used a classifier matrix computed once (from the last response) for
#     every ensemble member, so the "ensemble" averaged identical matrices; each member's
#     own classifiers are used now;
#   * the whole `attributes` list was passed to the prompt builder instead of the phrase `a`.
# NOTE(review): get_prompt_candidate_discrimination does not read its `attributes` argument,
# so all requests still share one prompt; the _v2 builder is the variant that inserts it.
attributes = [
    'components, shape, and color',
    'shape, color, and components',
    'color, components, and shape',
    'components, color, and shape',
    'shape, components, and color',
    'color, shape, and components',
]
print(attributes)
cluster_topk_voc_ind = all_clu_pred.topk(k=5).indices.cpu()
class_prediction = []
record_response = []
all_prompt_response = []
with tqdm(total=len(cluster_topk_voc_ind)) as pbar:
    for idx, row in enumerate(cluster_topk_voc_ind):
        candidates = [mapping_vocidx_to_synsets(x, vocab)[0].name().split('.')[0] for x in row.numpy()[:5]]
        ensembled_response = []
        ensembled_classifier = []
        for a in attributes:
            prompt = get_prompt_candidate_discrimination(candidates, a)
            response = request_gpt(prompt, model_name='text-davinci-003')
            # Descriptions are expected to be separated by blank lines.
            response = response['choices'][0]['text'].lstrip('\n\n').split('\n\n')
            aug_classifiers = build_classifier_from_prompt_response(args, model, response) ### K x D
            ensembled_classifier.append(aug_classifiers)
            ensembled_response.append(response)
        all_prompt_response.append(ensembled_response)
        
        ### similarity average ensemble
        subset_features = torch.from_numpy(all_features[record_pred_kmeans_t==idx]).float()#.to(args.device).float()
        ensembled_sim = []
        for aug_classifiers in ensembled_classifier:
            sim = 100 * subset_features @ aug_classifiers.float().t().cpu()
            ensembled_sim.append(sim)
        # torch.stack requires every member to have produced the same number of descriptions.
        ensembled_sim = torch.stack(ensembled_sim, dim=0).mean(dim=0) ### average
        ind, count = ensembled_sim.argmax(dim=-1).unique(return_counts=True)
        class_prediction.append(cluster_topk_voc_ind[idx, ind[count.argmax()]].item())
        record_response.append(response)
        pbar.update(1)
class_prediction = torch.tensor(class_prediction)

# Broadcast each cluster's predicted label to its member images and score against ground truth.
N = pred_kmeans.size(0)
instance_assigned_pred = torch.zeros(N).long()
for c in record_pred_kmeans_t.unique():
    select = (record_pred_kmeans_t==c)
    instance_assigned_pred[select] = class_prediction[c]
print('acc', (instance_assigned_pred==all_gt_label_voc).float().mean().item())
print('conflict', len(all_gt_label_voc.unique()) - len(instance_assigned_pred.unique()))

In [None]:
# Persist the collected GPT responses.
# Fix: the f-string was missing its opening quote, which made this cell a SyntaxError.
# NOTE(review): the directory part `request/equest/` looks like an editing accident — confirm
# the intended location; the path is otherwise kept exactly as it was typed.
with open(f'./cache/request/equest/ensmbled_prompts-{args.dataset_name}.pkl', 'wb') as f:
    pickle.dump(all_prompt_response, f)

In [None]:
# Display every stored GPT response (the whole list is printed).
all_prompt_response

#### iterative

#### misc

In [None]:
# Accuracy when every image receives the label in class_prediction for its cluster.
N = pred_kmeans.size(0)
instance_assigned_pred = torch.zeros(N).long()
for c in record_pred_kmeans_t.unique():
    select = (record_pred_kmeans_t==c)
    instance_assigned_pred[select] = class_prediction[c]
print('acc', (instance_assigned_pred==all_gt_label_voc).float().mean().item())

In [None]:
# Accuracy when every image receives the argmax vocabulary index of its cluster's row of all_clu_pred.
N = pred_kmeans.size(0)
instance_assigned_pred_scd = torch.zeros(N).long()
for c in record_pred_kmeans_t.unique():
    select = (record_pred_kmeans_t==c)
    instance_assigned_pred_scd[select] = all_clu_pred[c].argmax(dim=-1)
print('acc', (instance_assigned_pred_scd==all_gt_label_voc).float().mean().item())

In [None]:
# Images that the argmax assignment labels correctly but the updated prediction labels wrongly.
cond_fn = (instance_assigned_pred!=all_gt_label_voc) & (instance_assigned_pred_scd==all_gt_label_voc)


In [None]:
# Fraction of clusters whose prediction differs from the argmax of all_clu_pred.
# Fix: the value was computed with `==`, i.e. it reported the share of UNCHANGED clusters
# under the label 'flip rate'; the selections below already use `!=`.
print('flip rate', (class_prediction != all_clu_pred.argmax(dim=-1)).float().mean())
# Class names (first synset, text before the first '.') before and after the update.
scd_names = np.array([mapping_vocidx_to_synsets(x.item(), vocab)[0].name().split('.')[0] for x in all_clu_pred.argmax(dim=-1)])
updated_names = np.array([mapping_vocidx_to_synsets(x.item(), vocab)[0].name().split('.')[0] for x in class_prediction])

# Display: responses of the flipped clusters, their new names and their previous names.
np.array(record_response)[class_prediction!=all_clu_pred.argmax(dim=-1)].tolist(), \
updated_names[(scd_names!=updated_names)], scd_names[(scd_names!=updated_names)]

In [None]:
# NOTE(review): other cells index all_clu_pred by cluster id while all_gt_label_voc is indexed
# by image; comparing them elementwise only makes sense if both have the same length — confirm.
print('acc', (all_clu_pred.argmax(dim=-1)==all_gt_label_voc).float().mean().item())

In [None]:
# Give every image the top-5 vocabulary indices of its cluster, then report top-1 accuracy
# and how many ground-truth labels are never retrieved / never predicted.
N = pred_kmeans.size(0)
K = prob_k
instance_assigned_pred = torch.zeros([N, K]).long()
for c in record_pred_kmeans_t.unique():
    select = (record_pred_kmeans_t==c)
    # NOTE(review): k=5 is hard-coded and only fits the (N, K) buffer while prob_k == 5.
    instance_assigned_pred[select] = all_clu_pred[c].topk(k=5).indices

print('acc', (instance_assigned_pred[:, 0]==all_gt_label_voc).float().mean().item())
print('acc instance topk', (torch.from_numpy(all_instance_voc_topk_ind)[:, 0]==all_gt_label_voc).float().mean().item())
# Labels that appear as a correct top-1 prediction at least once.
retrieved_labels = instance_assigned_pred[instance_assigned_pred[:, 0]==all_gt_label_voc][:, 0].unique().numpy()
pred_labels = instance_assigned_pred[:, 0].unique().numpy()
gt_labels = all_gt_label_voc.unique().numpy()
print(f'missing label of retrieval:: {len(set(gt_labels) - set(retrieved_labels))}')
print(f'missing label of predict:: {len(set(gt_labels) - set(pred_labels))}')
# range(1, K) stops at K-1, so the full top-K set is not reported.
for k in range(1, K):
    retrieved_labels_topk = instance_assigned_pred[instance_assigned_pred[:, 0]==all_gt_label_voc][:, :k].flatten().unique().numpy()
    print(f'missing label of retieval at k={k}:: {len(set(gt_labels) - set(retrieved_labels_topk))}')

In [None]:
# Cumulative hit indicator of the ground-truth label within each image's own top-k predictions.
isin_instance_topk = row_wise_isin(all_gt_label_voc, torch.from_numpy(all_instance_voc_topk_ind))

# Display: mean of the indicator over images, one value per k.
isin_instance_topk.float().mean(dim=0)

In [None]:
# Same cumulative hit indicator, measured on the cluster-level candidates given to each image.
isin_instance_topk = row_wise_isin(all_gt_label_voc, instance_assigned_pred)
isin_instance_topk.float().mean(dim=0)

instance based

In [None]:
all_instance_voc_topk_ind

# TODO: unfinished scratch. The assignment below has no right-hand side and `features` is not
# defined in this cell, so the cell raised a SyntaxError and blocked "Run All". It is kept as
# a comment until the instance-based variant is completed.
# candidate_names = 
# compute_similarity_with_augmented_classifier(features, candidate_names, 
#                                              class_name_key_mapping, all_augmented_classifier, 
#                                              method='ensemble', agg_func=max)

cluster based

In [None]:

# Give every image the top-K vocabulary indices of its cluster.
# Fix: the buffer has K = prob_k columns but was filled with topk(k=10), which cannot be
# assigned into an (n_selected, K) slice unless K == 10; the number of candidates now
# follows K.
N = record_pred_kmeans_t.size(0)
K = prob_k
instance_assigned_pred = torch.zeros([N, K]).long()
for c in record_pred_kmeans_t.unique():
    select = (record_pred_kmeans_t==c)
    instance_assigned_pred[select] = all_clu_pred[c].topk(k=K).indices

In [None]:

# Re-rank, image by image, the first three cluster-level candidates with the augmented
# (GPT-description) classifiers and keep the position of the winning candidate.
# compute_similarity_with_augmented_classifier is defined outside this part of the notebook.
all_sample_ind = []
with tqdm(total=all_features.shape[0]) as pbar:
    for i in range(all_features.shape[0]):
        img_feature = torch.from_numpy(all_features[i])#.to(args.device)
        candidate_names = [mapping_vocidx_to_synsets(x.item(), vocab)[0].name().split('.')[0] 
                           for x in instance_assigned_pred[i, :3]]
        # NOTE(review): the call passes return_indices=True, yet the result is looked up in
        # candidate_names as if it were a name — confirm what the helper returns.
        max_k = \
            compute_similarity_with_augmented_classifier(img_feature, candidate_names, 
                                                         class_name_key_mapping, all_augmented_classifier, 
                                                         method='ensemble', agg_func=np.mean, return_indices=True)
        # list.index returns the first match, so duplicated candidate names resolve to the earlier one.
        idx_max = candidate_names.index(max_k)
        
        all_sample_ind.append(idx_max)
        
        pbar.update(1)

# Pick, for every image, the vocabulary index at the winning candidate position.
all_sample_ind = torch.tensor(all_sample_ind)
baseline_instance_pred = instance_assigned_pred.gather(1, all_sample_ind.view(-1, 1))

In [None]:
# Display: accuracy of the re-ranked per-image predictions.
(baseline_instance_pred.flatten() == all_gt_label_voc).float().mean()

In [None]:
# Display: accuracy of the un-re-ranked first candidate, for comparison.
(instance_assigned_pred[:, 0] == all_gt_label_voc).float().mean()

In [None]:
# Inspect leftovers of the selection loop: the candidates and the pick of the
# last processed sample `i`, next to the last prediction and ground-truth label.
(
    candidate_names,
    max_k,
    mapping_vocidx_to_synsets(all_gt_label_voc[i].item(), vocab),
    baseline_instance_pred[-1],
    all_gt_label_voc[-1],
)

In [None]:
# Majority vote inside each k-means cluster: every member of a cluster gets
# the value that occurs most often among that cluster's per-instance predictions.
baseline_instance_pred = baseline_instance_pred.flatten()
baseline_instance_pred_clu = torch.zeros_like(baseline_instance_pred)
for c in record_pred_kmeans_t.unique():
    in_cluster = record_pred_kmeans_t == c
    val, count = torch.unique(baseline_instance_pred[in_cluster], return_counts=True)
    # index of the largest count -> most frequent predicted value
    baseline_instance_pred_clu[in_cluster] = val[count.argmax(dim=-1)].item()

In [None]:
# Accuracy after the per-cluster majority vote.
baseline_instance_pred_clu.eq(all_gt_label_voc).float().mean()

#### reranked KNN

In [None]:
""" compute top-20 predictions """
args.num_voc = classifier.size(0)
amp_autocast = torch.cuda.amp.autocast
### collect variables
prob_k = 5
all_instance_voc_topk_ind = []
all_gt_label_voc = []
all_gt_label_clu = []
all_features = []
with tqdm(total=len(loader_f)) as pbar:
    if hasattr(model, 'eval'):
        model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch[:4]
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                logits = model.visual(images)
                logits = logits/logits.norm(dim=-1, keepdim=True)
                similarity = 100 * logits @ classifier.t()
                prob = similarity.softmax(-1)
                prob_topk_ind = prob.topk(k=prob_k, dim=-1).indices
                all_instance_voc_topk_ind.append(prob_topk_ind.cpu().numpy())
                all_gt_label_voc.append(label_voc)
                all_gt_label_clu.append(label_clu)
                all_features.append(logits.cpu().numpy())
        pbar.update(1)

all_instance_voc_topk_ind = np.concatenate(all_instance_voc_topk_ind)
all_gt_label_voc = torch.cat(all_gt_label_voc, dim=0)
all_gt_label_clu = torch.cat(all_gt_label_clu, dim=0)
all_features = np.concatenate(all_features)

In [None]:
# Rerank the first K instance-level candidates of every sample with the
# augmented classifier and store the reranked vocabulary indices.
K = 3
N = all_instance_voc_topk_ind.shape[0]
# The candidate slice below silently shrinks if fewer than K predictions exist.
assert K <= all_instance_voc_topk_ind.shape[1], 'K exceeds the number of stored top-k predictions'
all_instance_voc_topk_ind_rerank = torch.zeros(N, K).long()
for i in tqdm(range(N)):
    feature = torch.from_numpy(all_features[i])
    # Use K (not a literal 3) so the candidate count, the requested top-k and
    # the width of the result table always agree.
    candidate_names = [
        mapping_vocidx_to_synsets(x.item(), vocab)[0].name().split('.')[0]
        for x in all_instance_voc_topk_ind[i, :K]
    ]
    topk_candidate_voc_ind = \
    compute_similarity_with_augmented_classifier(feature, candidate_names,
                                                 class_name_key_mapping, all_augmented_classifier,
                                                 method='ensemble', agg_func=max, return_type='topk', k=K)
    # The helper's results are used as keys of vocab.mapping_names_idx.
    all_instance_voc_topk_ind_rerank[i] = torch.tensor([ vocab.mapping_names_idx[x] for x in topk_candidate_voc_ind ])
    

In [None]:
# Spot check: rerank ALL stored candidates of one random sample.
i = np.random.permutation(N)[0]
feature = torch.from_numpy(all_features[i])
candidate_names = [ mapping_vocidx_to_synsets(x.item(), vocab)[0].name().split('.')[0] for x in all_instance_voc_topk_ind[i] ]
topk_candidate_voc_ind = \
compute_similarity_with_augmented_classifier(feature, candidate_names,
                                             class_name_key_mapping, all_augmented_classifier,
                                             method='ensemble', agg_func=max, return_type='topk',
                                             k=len(candidate_names))
# Keep the result in its own variable.  The previous version wrote it into
# all_instance_voc_topk_ind_rerank[i], whose rows have only K (= 3) columns:
# that is a shape mismatch for a longer ranking, and it also altered the table
# that the accuracy cell evaluates.
sample_rerank_voc_ind = torch.tensor([ vocab.mapping_names_idx[x] for x in topk_candidate_voc_ind ])

In [None]:
# Top-1 accuracy: reranked candidates first, original classifier ranking second.
(
    all_instance_voc_topk_ind_rerank[:, 0].eq(all_gt_label_voc).float().mean(),
    torch.from_numpy(all_instance_voc_topk_ind)[:, 0].eq(all_gt_label_voc).float().mean(),
)

In [None]:
# Smoke test of the helper with a random 512-d feature and a fixed name list.
# NOTE(review): the result is unpacked into two values here, while other cells
# iterate over the 'topk' result directly -- confirm the helper's return shape.
probe_names = ['cat', 'dog', 'frog', 'shirt', 'man', 'swarm', 'liquid']
q, d = compute_similarity_with_augmented_classifier(
    torch.rand(512), probe_names,
    class_name_key_mapping, all_augmented_classifier,
    method='ensemble', agg_func=max, return_type='topk', k=5,
)

In [None]:
# Load the cached parse results.
# NOTE(review): the path ends in ".json" but the file is read with pickle in
# binary mode -- confirm how it was written.
# pickle.load can execute arbitrary code; only use it on files you trust.
output_fpath = './cache/parsed-wn-gpt3-d-2023_02_26.json'
with open(output_fpath, mode='rb') as fh:
    all_parse_results = pickle.load(fh)

In [None]:
# For sample `i`, show: the ground-truth synset name, the reranked candidates,
# the synset names of the first five instance-level predictions, and the
# parse-result entry stored under each of those synset names.
(
    mapping_vocidx_to_synsets(all_gt_label_voc[i].item(), vocab)[0].name(),
    topk_candidate_voc_ind,
    [mapping_vocidx_to_synsets(x, vocab)[0].name() for x in all_instance_voc_topk_ind[i][:5]],
    [ all_parse_results[mapping_vocidx_to_synsets(x, vocab)[0].name()] for x in all_instance_voc_topk_ind[i][:5] ],
)

#### basic observations

In [None]:
""" compute top-20 predictions """
args.num_voc = classifier.size(0)
amp_autocast = torch.cuda.amp.autocast
### collect variables
prob_k = 5
all_instance_voc_topk_ind = []
all_gt_label_voc = []
all_gt_label_clu = []
all_features = []
with tqdm(total=len(loader_f)) as pbar:
    if hasattr(model, 'eval'):
        model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch[:4]
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                logits = model.visual(images)
                logits = logits/logits.norm(dim=-1, keepdim=True)
                similarity = 100 * logits @ classifier.t()
                prob = similarity.softmax(-1)
                prob_topk_ind = prob.topk(k=prob_k, dim=-1).indices
                all_instance_voc_topk_ind.append(prob_topk_ind.cpu().numpy())
                all_gt_label_voc.append(label_voc)
                all_gt_label_clu.append(label_clu)
                all_features.append(logits.cpu().numpy())
        pbar.update(1)

all_instance_voc_topk_ind = np.concatenate(all_instance_voc_topk_ind)
all_gt_label_voc = torch.cat(all_gt_label_voc, dim=0)
all_gt_label_clu = torch.cat(all_gt_label_clu, dim=0)
all_features = np.concatenate(all_features)