1. Set the path passed to `sys.path.append` to your project directory.

2. Set `Config.dataset` to the name of the dataset you want to run on.

3. Set `Config.f_classifier` to the path of the precomputed vocabulary classifier file.

4. For the ViT-L architecture, set `Config.arch='ViT-L/14'` and `Config.f_classifier='./cache/vocabulary_classifier_L.pth'`.

In [None]:
import sys
sys.path.append('/home/sheng/sheng-eatamath/S3A/')

import os
import json
import re
import time
import pickle
from typing import Union, List
from pprint import pprint
from tqdm import tqdm
from copy import deepcopy
import random
import itertools
import numpy as np
from functools import reduce
import seaborn as sns
from collections import Counter, defaultdict, OrderedDict
import scipy.io
from nltk.corpus import wordnet as wn

from sklearn.cluster import KMeans, MiniBatchKMeans
from my_util_package.evaluate import cluster_acc
from scipy.optimize import linear_sum_assignment as linear_assignment

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode
from torchvision.datasets import ImageFolder

import model as clip
from model import tokenize
from data.build_dataset import build_transform
from data.imagenet_datasets import get_datasets_rzsc
from data.vocab import get_vocab, Vocab, get_classifier

    
class Config:
    """Static settings shared by every cell of this notebook."""

    # --- runtime / model ---
    device = 'cuda:1'
    arch = 'ViT-B/16'
    clip_checkpoint = None  # optional checkpoint path; falsy -> keep the weights loaded for `arch`
    seed = 0

    # --- vocabulary / prompts ---
    # 'in21k' for ViT-B/16; the original note says 'in21k-L' for ViT-H/14
    # NOTE(review): presumably ViT-L/14 was meant — confirm.
    vocab_name = 'in21k'
    f_classifier = './cache/vocabulary_classifier.pth'  # precomputed 21k CLIP vocabulary classifier
    templates_name = 'templates'  # CLIP template file name (without extension)
    n_repeat = 3  # number of times each prompt is sent

    # --- data ---
    dataset = 'make_living17'  # dataset name
    n_sampled_classes = 100  # number of sampled classes for ImageNet-100
    input_size = 224
    batch_size = 256
    image_mean = (0.48145466, 0.4578275, 0.40821073)
    image_std = (0.26862954, 0.26130258, 0.27577711)


# Single shared settings object; later cells also attach computed fields to it.
args = Config()

def load_templates(args):
    """Read ``../<args.templates_name>.json`` and return its 'imagenet' entry."""
    template_path = f'../{args.templates_name}.json'
    with open(template_path, 'rb') as fh:
        payload = json.load(fh)
    return payload['imagenet']

templates = load_templates(args)
vocab = get_vocab()


def mapping_ids_synset(x):
    """Look up the WordNet noun synset whose offset is the integer in ``x[1:]``."""
    offset = int(x[1:])
    return wn.synset_from_pos_and_offset('n', offset)


def mapping_vocidx_to_synsets(anchor, vocab):
    """Return the synsets of every global index listed under ``vocab.mapping_idx_global_idx[anchor]``."""
    global_indices = vocab.mapping_idx_global_idx[anchor]
    return [mapping_ids_synset(vocab.mapping_global_idx_ids[g]) for g in global_indices]

In [None]:
def load_clip2(args):
    """Load the CLIP model named by ``args.arch`` onto ``args.device`` in eval mode.

    When ``args.clip_checkpoint`` is truthy, the checkpoint's 'model_ema' state
    dict is loaded non-strictly after dropping the first ``len('model.')``
    characters of every key. A short model summary is printed.

    Returns:
        the loaded model
    """
    model = clip.load(args.arch, device=args.device)
    if args.clip_checkpoint:
        checkpoint = torch.load(args.clip_checkpoint, map_location='cpu')
        n_skip = len('model.')
        # NOTE(review): the slice is applied to every key, whether or not it starts with 'model.'
        ema_state = {name[n_skip:]: weight for name, weight in checkpoint['model_ema'].items()}
        model.load_state_dict(ema_state, strict=False)
    model.to(args.device).eval()

    summary = (
        model.visual.input_resolution,
        model.context_length,
        model.vocab_size,
    )
    n_params = np.sum([int(np.prod(p.shape)) for p in model.parameters()])

    print("Model parameters:", f"{n_params:,}")
    for caption, value in zip(("Input resolution:", "Context length:", "Vocab size:"), summary):
        print(caption, value)
    return model

# Evaluation-time preprocessing (is_train=False), configured from `args`.
transform_val = build_transform(is_train=False, args=args, train_config=None)
# NOTE(review): the split is requested with is_train=True while the eval transform is used —
# presumably intentional (labels are discovered on the training split); confirm.
dataset = get_datasets_rzsc(args, vocab, is_train=True, transform=transform_val, seed=0)
# shuffle=False keeps the sample order identical on every pass over the loader.
loader_val = torch.utils.data.DataLoader(dataset, num_workers=8, batch_size=args.batch_size, shuffle=False)
print('dataset size', len(dataset))
model = load_clip2(args)

    

#### CVPR method

##### clustering & voting

In [None]:
# Split to process: index 0 selects 'train', index 1 selects 'val'.
subset = ['train', 'val'][0]
transform_f = build_transform(is_train=False, args=args, train_config=None)
# Both branches use the same eval transform and seed; only `is_train` differs.
if subset == 'train':
    dataset_f = get_datasets_rzsc(args, vocab, is_train=True, transform=transform_f, seed=0)
elif subset == 'val':
    dataset_f = get_datasets_rzsc(args, vocab, is_train=False, transform=transform_f, seed=0)
# Store the class count reported by the dataset on the shared config object.
args.nb_classes = dataset_f.num_classes


In [None]:
def agg_by_pred_cluster(args, pred_kmeans, all_topk_voc, voc_size):
    """ aggregate topk predictions of each sample within each partition
    Args:
        args: unused; kept so that existing call sites keep working
        pred_kmeans: np.array([N]): previous partition indice of each sample
        all_topk_voc: np.array([N x K]): topk prediction indice of each sample
        voc_size: int: total vocabulary size (number of columns of the result)
    Returns:
        all_clu_pred: tensor([C x V]): class-vocabulary probability matrix; row c holds,
            for the c-th partition id in sorted order, the number of top-k votes each
            vocabulary entry received divided by the number of samples in that partition
    """
    print('agg_by_pred_cluster')
    cluster_rows = []
    cluster_sizes = []
    for cluster_id in np.unique(pred_kmeans):
        members = (pred_kmeans == cluster_id)
        cluster_sizes.append(int(members.sum()))
        voc_ids, votes = np.unique(all_topk_voc[members, :].ravel(), return_counts=True)
        # size the row from the `voc_size` argument (previously ignored in favour of args.num_voc)
        clu_pred = torch.zeros(voc_size) # cluster-wise prob
        clu_pred[torch.from_numpy(voc_ids).long()] = torch.from_numpy(votes).float()
        cluster_rows.append(clu_pred)
    all_clu_pred = torch.stack(cluster_rows, dim=0).cpu()
    n_count = torch.tensor(cluster_sizes).cpu()
    # the tiny epsilon guards against division by zero
    all_clu_pred = all_clu_pred/(n_count.view(-1, 1) + 1e-20)
    # diagnostics: do all partitions have a distinct top-1 vocabulary entry?
    top1 = all_clu_pred.argmax(dim=-1)
    print('is mutex assignment::', top1.size(0)==top1.unique().size(0))
    n_collisions = sum(1 for n in Counter(top1.numpy()).values() if n > 1)
    print('assignment collision num::', n_collisions)
    return all_clu_pred

def linear_assign(all_clu_pred, pred_kmeans, all_gt_voc, return_results=False):
    """ align each cluster with a category name by solving linear assignment problem
    Args:
        all_clu_pred: tensor([C x V]): class-vocabulary probability matrix
        pred_kmeans: tensor([N]) or np.array([N]): previous partition indice of each sample;
            values are used directly as row indices into @all_clu_pred
        all_gt_voc: tensor([N]): ground-truth label of each sample
        return_results: bool: also return the instance accuracy
    Returns:
        label_voc_kmeans: tensor([N]): prediction of vocabulary indices of each sample
            (-1 for samples whose cluster received no assignment, which can only
            happen when C > V)
        res_ass: tuple(np.array([M]), np.array([M])): assigned (cluster indices, vocabulary indices)
        inst_acc: float: instance-wise accuracy after assignment (only if @return_results)
    """
    print('linear_assign')
    cost_mat = all_clu_pred.cpu().numpy()
    print(f'assignment shape={cost_mat.shape}')
    # the solver minimises cost, so flip the scores to maximise the total score
    res_ass = linear_assignment(cost_mat.max() - cost_mat)
    # Row -> column lookup. With C <= V every row is assigned and this equals res_ass[1];
    # with C > V the unassigned rows stay at -1 instead of being mis-indexed.
    row_to_voc = np.full(cost_mat.shape[0], -1, dtype=np.int64)
    row_to_voc[res_ass[0]] = res_ass[1]
    # one vectorised gather instead of a per-sample Python loop
    cluster_ids = torch.as_tensor(pred_kmeans).long().cpu()
    label_voc_kmeans = torch.from_numpy(row_to_voc)[cluster_ids]
    inst_acc = (label_voc_kmeans==all_gt_voc).float().mean().item()
    print('instance label acc::', inst_acc)
    if return_results:
        return label_voc_kmeans, res_ass, inst_acc
    return label_voc_kmeans, res_ass

def reassign_by_pred_cluster(label_voc_kmeans, loader_f, model, classifier, device, 
                             preextracted_vfeatures=None):
    """ given assigned vocab label set @set(label_voc_kmeans), reassign each sample to its nearest category prototype
    Args:
        label_voc_kmeans: tensor([N]): cluster-assigned label on vocab; only its unique values are used
        loader_f: data loader; batches are read only when no features are pre-extracted,
            otherwise only len(loader_f.dataset) is used
        model: feature extractor; switched to eval mode when it has an `eval` attribute
        classifier: tensor([V x D]): vocabulary classifier weights
        device: device used for the candidate labels and for image batches
        preextracted_vfeatures: np.array([N x D]): if the visual features are pre-computed to save time for multiple inference
    Returns:
        cluster_ind: tensor([N]): re-ordered cluster assignment, consecutive ids starting at 0
        cluster_ind_voc: tensor([N]): cluster assignment indiced by vocab
    """
    print('reassign_by_pred_cluster')
    label_voc_kmeans = label_voc_kmeans.to(device).unique()
    if hasattr(model, 'eval'):
        model.eval()
    cluster_ind = []
    if preextracted_vfeatures is not None: ### speed-up with pre-extracted visual features by avoiding looping iteratively
        candidates_cpu = label_voc_kmeans.cpu()
        classifier_t = classifier.t().cpu()  # hoisted: identical for every chunk
        N = len(loader_f.dataset)
        # max(1, ...) avoids a ZeroDivisionError / zero sections for an empty dataset
        batch_size = max(1, min(N, 10000))
        groups = np.array_split(np.arange(N), max(1, N//batch_size))
        with torch.no_grad():
            for group in groups:
                logits = torch.from_numpy(preextracted_vfeatures[group]).float()
                logits = logits/logits.norm(dim=-1, keepdim=True)
                similarity = 100 * logits@classifier_t
                prob = similarity.softmax(-1)
                # nearest prototype among the candidate vocabulary entries only
                cluster_ind.append(prob[:, candidates_cpu].argmax(dim=-1))
    else:
        amp_autocast = torch.cuda.amp.autocast
        with tqdm(total=len(loader_f)) as pbar:
            for batch in loader_f:
                images = batch[0].to(device)
                with amp_autocast():
                    with torch.no_grad():
                        # NOTE(review): this path calls model.ema.extract_vfeatures while the
                        # feature-collection cell calls model.visual.extract_features — confirm
                        # which attribute the model passed in actually exposes.
                        logits = model.ema.extract_vfeatures(images)
                        logits = logits/logits.norm(dim=-1, keepdim=True)
                        similarity = 100 * logits @ classifier.t()
                        prob = similarity.softmax(-1)
                        cluster_ind.append(prob[:, label_voc_kmeans].cpu().argmax(dim=-1))
                pbar.update(1)
    cluster_ind = torch.cat(cluster_ind, dim=0)
    cluster_ind_voc = label_voc_kmeans[cluster_ind]
    # relabel the candidates that were actually chosen as 0..C'-1, in sorted order
    _, cluster_ind = torch.unique(cluster_ind, return_inverse=True)
    return cluster_ind, cluster_ind_voc


In [None]:
# One pass over `dataset_f`: collect per-sample top-k vocabulary predictions, labels and
# L2-normalised visual features for the later voting / realignment cells.
loader_f = torch.utils.data.DataLoader(dataset_f, num_workers=4, batch_size=args.batch_size, shuffle=False)
classifier = get_classifier(args)
# unit-normalise every classifier row so the dot product below is a cosine similarity
classifier = classifier/classifier.norm(dim=-1, keepdim=True)
args.num_voc = classifier.size(0)
amp_autocast = torch.cuda.amp.autocast
### collect variables
prob_k = 1
all_topk_voc = []
all_gt_voc = []
all_label_clu = []
all_vfeatures = []
with tqdm(total=len(loader_f)) as pbar:
    if hasattr(model, 'eval'):
        model.eval()
    for idx_batch, batch in enumerate(loader_f):
        # only the first four fields of a batch are used
        images, label_voc, label_clu, idx_img = batch[:4]
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                logits = model.visual.extract_features(images)
                logits = logits/logits.norm(dim=-1, keepdim=True)
                # assumes `classifier` already lives on args.device — TODO confirm
                similarity = 100 * logits @ classifier.t()
                prob = similarity.softmax(-1)
                prob_topk_ind = prob.topk(k=prob_k, dim=-1).indices
                all_topk_voc.append(deepcopy(prob_topk_ind.cpu().numpy()))
                all_gt_voc.append(deepcopy(label_voc))
                all_label_clu.append(deepcopy(label_clu))
                # despite the name, `logits` holds the normalised visual features here
                all_vfeatures.append(deepcopy(logits.cpu().numpy()))
        pbar.update(1)

# concatenate the per-batch pieces along the sample axis
all_topk_voc = np.concatenate(all_topk_voc)
all_gt_voc = torch.cat(all_gt_voc, dim=0)
all_label_clu = torch.cat(all_label_clu, dim=0)
all_vfeatures = np.concatenate(all_vfeatures)

In [None]:
# Start from a cached k-means partition and alternate three times between
# (vote -> assign names -> reassign samples to the nearest assigned name).
pred_kmeans = torch.from_numpy(np.load(f'./cache/cluster/kmeans-{args.dataset}.npy'))
pred_kmeans_t = pred_kmeans
history_set_pred = []  # NOTE(review): never appended to in this cell
for t in range(3):
    # remember the partition that `all_clu_pred` is computed from; after the loop this is
    # the partition used by the last voting step, not the final reassignment
    record_pred_kmeans_t = pred_kmeans_t
    all_clu_pred = agg_by_pred_cluster(args, pred_kmeans_t.numpy(), all_topk_voc, voc_size=args.num_voc)
    label_voc_kmeans, res_ass = linear_assign(all_clu_pred, pred_kmeans_t, all_gt_voc)
    pred_kmeans_t, cluster_ind_voc = reassign_by_pred_cluster(label_voc_kmeans, loader_f, model, classifier, args.device, preextracted_vfeatures=all_vfeatures)
    
    ### evaluation
    print('cluster acc', cluster_acc(y_true=all_label_clu.numpy(), y_pred=pred_kmeans_t.numpy()))


##### prompting

In [None]:
import openai
def openai_chatgpt_post(content, parameters=None):
    """ send one user message to the chat completion endpoint and return the reply text
    Args:
        content: str: message sent with role "user"
        parameters: dict or None: extra keyword arguments for `openai.ChatCompletion.create`;
            defaults to {'temperature': 0.7}
    Returns:
        result: str: content of the first returned choice
    Raises:
        RuntimeError: if no API key is configured

    The request is retried until it succeeds; each failure is printed and followed
    by an exponentially growing pause (capped at 60 s).
    """
    if parameters is None: ### avoid a shared mutable default argument
        parameters = {'temperature': 0.7}
    # SECURITY: a secret key used to be hard-coded here. It must be treated as leaked and
    # revoked; the key is now taken from the environment (or a value already set on the
    # openai module) so it never lives in the notebook.
    api_key = os.environ.get('OPENAI_API_KEY') or getattr(openai, 'api_key', None)
    if not api_key:
        raise RuntimeError('OpenAI API key not configured: set the OPENAI_API_KEY environment variable')
    openai.api_key = api_key
    delay = 1.0
    while True:
        try:
            completion = openai.ChatCompletion.create(
                model="gpt-3.5-turbo-0301",
                messages=[
                    {"role": "user", "content": content},
                ],
                **parameters,
            )
            return completion['choices'][0]['message']['content']
        except Exception as e: ### best-effort: report and retry, as before
            print(e)
            time.sleep(delay) ### back off instead of hammering the API
            delay = min(delay * 2, 60.0)

# --- topK candidate generation ---
# most frequent ground-truth vocabulary index inside each partition
all_clu_gt_voc = torch.tensor([
    all_gt_voc[record_pred_kmeans_t == c].mode().values
    for c in record_pred_kmeans_t.unique()
])
k_1 = 3
topk_all_clu_pred = all_clu_pred.topk(k=k_1).indices
# a partition counts as recalled when its majority label is among its k_1 top-voted entries
cluster_is_correct = (topk_all_clu_pred == all_clu_gt_voc.view(-1, 1)).any(dim=1)
print(f'recall@{k_1} = {cluster_is_correct.float().mean()}')

""" gather candidate concepts (synsets) with name and definitions """
# render every synset as "<synset name>: <definition>"
to_name = lambda x: [ s.name() + ': ' + s.definition() for s in x ]
# cluster_row_synsets[c][r] = rendered synsets of the r-th top-voted vocabulary entry of
# partition c (one vocabulary entry can yield several synsets)
cluster_row_synsets = []
for row in topk_all_clu_pred:
    row_synsets = [to_name(mapping_vocidx_to_synsets(voc_idx.item(), vocab)) for voc_idx in row]
    cluster_row_synsets.append(row_synsets)

""" generate concept lists with candidate concepts """
# one string per partition: every candidate wrapped as '<text>.' and joined with ", "
concept_request = []
for row in cluster_row_synsets:
    ccpts = reduce(lambda x, y: x+y, row)
    ccpts = list(map(lambda x: "'"+x+".'", ccpts))
    ccpts = ', '.join(ccpts)
    concept_request.append(ccpts)
    
""" generate prompts with concept lists """
# template = lambda concept_list: "Given visual concepts: "+ concept_list + "Please list all possible visual descriptive phrases for each visual concept without duplication. Please list in the format \"{concept name}: {all phrases deliminated by semicolons}.\" for each concept. No duplication."
# One-shot prompt: a worked Q/A example followed by the actual concept list.
# NOTE(review): the example answer labels its third concept `parrot.n.01` although the
# example question lists it as `parrot.n.02`; and `concept_list` is concatenated directly
# with "Goal:" without a separating space. Both are left untouched because changing the
# prompt changes the model output — confirm whether they are intended.
template2 = lambda concept_list: """Q: 'Given visual concepts: \'african_grey.n.01: commonly domesticated grey parrot with red-and-black tail and white face; native to equatorial Africa.\', \'parrot.n.01: usually brightly colored zygodactyl tropical birds with short hooked beaks and the ability to mimic sounds.\', \'parrot.n.02: a copycat who does not understand the words or acts being imitated.\', \'cockatoo.n.01: white or light-colored crested parrot of the Australian region; often kept as cage birds.\' Goal: to discriminate these visual concepts in a photo. Please list all possible visual descriptive phrases for each visual concept in bullet points. Please list in the format "{concept name}: {all phrases deliminated by semicolons}." for each concept. No duplication.'\nA:african_grey.n.01: grey parrot;red-and-black tail; white face; pale yellow eye; black beak; zygodactyl feet; intelligent and expressive eyes; predominantly grey plumage; native to equatorial Africa\nparrot.n.01: usually brightly colored; zygodactyl; tropical birds; short hooked beaks; broad range of colors and patterns\nparrot.n.01: copycat; imitation of words or actions without understanding; non-visual concept; metaphorical representation of mimicking behavior\ncockatoo.n.01: white or light-colored; often with yellow, pink, or red highlights in the crest or tail; crested parrot; strong, curved beak; zygodactyl feet; native to the Australian region; often kept as cage birds; expressive crest that can be raised or lowered; large, broad wings; known for loud and raucous calls.\n\nGiven visual concepts: """+ concept_list + "Goal: to discriminate these visual concepts in a photo. Please list all possible visual descriptive phrases for each visual concept in bullet points. Please list in the format \"{concept name}: {all phrases deliminated by semicolons}.\" for each concept."
    
template_in_use = template2
# one fully rendered prompt per partition, in the same order as `concept_request`
concept_templates = []
for row in concept_request:
    concept_templates.append(template_in_use(row))
    
n_repeat = args.n_repeat

In [None]:
""" collect chatgpt results """
# all_chatgpt_res[i][j] = reply of the i-th repetition for the j-th prompt
all_chatgpt_res = [[] for _ in range(n_repeat)]
with tqdm(total=len(concept_templates)*n_repeat) as pbar:
    for i in range(n_repeat):
        for row in concept_templates:
            all_chatgpt_res[i].append(openai_chatgpt_post(row))

            pbar.update(1)

# Alternative to querying: reload replies cached by an earlier run.
# with open(f'./cache/openai/prompting-data={args.dataset}.pkl', 'rb') as f:
#     all_chatgpt_res = pickle.load(f)

In [None]:
# Repeat (validate replies -> parse -> de-duplicate) until every synset of every partition
# keeps at least one candidate phrase; partitions that fail are re-prompted.
while 1:
    """ response integrity check """
    # inner loop: re-prompt until every reply lists exactly the expected synset ids
    while 1:
        invalid_res = []
        for i in range(n_repeat):
            for j, row in enumerate(all_chatgpt_res[i]):
                # text before the first ': ' of each line
                extract_synsetid = lambda r: list(map(lambda x: x.split(': ')[0], r))
                # drop empty strings
                remove_space = lambda r: list(filter(lambda x: len(x), r))
                synsets = extract_synsetid(remove_space(row.lower().replace('\n\n', '\n').split('\n')))
                gt_synsets = extract_synsetid(reduce(lambda x,y: x+y, cluster_row_synsets[j]))
                try:
                    # the k-th reply line is matched against the k-th expected id, so the reply
                    # must keep the prompt order; the id is cut out of the line prefix to
                    # tolerate leading decoration
                    start_idx = [ synsets[k].find(s) for k, s in enumerate(gt_synsets) ]
                    synsets = [ synsets[k][start_idx[k]:start_idx[k]+len(gt_synsets[k])] for k, s in enumerate(synsets) ]
                    assert set(synsets)==set(gt_synsets)
                except Exception as e: ### missing information, to re-prompt
                    print(i, j)
                    print(synsets, gt_synsets)
                    invalid_res.append((i,j))

        if len(invalid_res)==0:
            break
        else:
            ### re-prompt
            for i,j in invalid_res:
                print(f'repair {(i,j)}')
                content = concept_templates[j]
                while 1:
                    try:
                        res = openai_chatgpt_post(content)
                        break
                    except Exception as e:
                        print(e)
                all_chatgpt_res[i][j] = res



    """ extract key-value-list from @chatgpt-res """
    # extracted_chatgpt_res[j] = {synset id: [phrase list of repetition 0, repetition 1, ...]}
    extracted_chatgpt_res = []
    for j, row in enumerate(all_chatgpt_res[0]):
        chatgpt_row_res = {}
        extract_synsetid = lambda r: list(map(lambda x: x.split(': ')[0], r))
        remove_space = lambda r: list(filter(lambda x: len(x), r))
        # NOTE(review): raises IndexError on a line without ': ' — the integrity check above
        # does not test for that separator explicitly
        extract_synnames = lambda r: list(map(lambda x: x.split(': ')[1].split('; '), r))
        for i in range(n_repeat):
            row = all_chatgpt_res[i][j]
            row_data = remove_space(row.lower().replace('\n\n', '\n').split('\n'))
            synsets = extract_synsetid(row_data)
            synnames = extract_synnames(row_data)
            gt_synsets = extract_synsetid(reduce(lambda x,y: x+y, cluster_row_synsets[j]))
            start_idx = [ synsets[k].find(s) for k, s in enumerate(gt_synsets) ]
            synsets = [ synsets[k][start_idx[k]:start_idx[k]+len(gt_synsets[k])] for k, s in enumerate(synsets) ]
            for idx_s, s in enumerate(synsets):
                chatgpt_row_res.setdefault(s, [])
                chatgpt_row_res[s].append( remove_space(synnames[idx_s]) )
        extracted_chatgpt_res.append(chatgpt_row_res)

    """ deduplication """
    use_dedup = True
    all_candidates = []
    all_candidates_set = []
    for i, row in enumerate(extracted_chatgpt_res):
        ### flatten multiple results
        # synset ids are reduced to the part before the first '.', so ids sharing that
        # prefix (e.g. 'parrot.n.01' / 'parrot.n.02') are merged under one key
        row_all_synset_names = list(map(lambda x: x.split('.')[0], row.keys()))
        row_candidates = {}
        row_candidates_set = {}
        for k, v in row.items():
            candidates = list(reduce(lambda x, y: x+y, v))
            candidates = [c for c in candidates if c not in row_all_synset_names] ### remove competing synset names
            set_candidates = set(candidates)
            k = k.split('.')[0] ### key synset name
            row_candidates.setdefault(k, [])
            row_candidates_set.setdefault(k, set([]))
            row_candidates[k].extend(candidates)
            row_candidates_set[k] |= set_candidates
        ### collect duplicates
        # phrases that occur under two different names cannot discriminate between them
        duplicates = set()
        for k1, v1 in row.items():
            k1 = k1.split('.')[0]
            for k2, v2 in row.items():
                k2 = k2.split('.')[0]
                if k1!=k2:
                    duplicates |= row_candidates_set[k1]&row_candidates_set[k2]
        ### remove duplication with synset-names (keys)
        row_candidates_update = {}
        row_candidates_set_update = {}
        for k1, v1 in row.items():
            k1 = k1.split('.')[0]
            # NOTE(review): this inner loop only rebinds k2 and has no effect on the result
            for k2, v2 in row.items():
                k2 = k2.split('.')[0]
            row_candidates_set_update[k1] = row_candidates_set[k1] - duplicates if use_dedup else row_candidates_set[k1]
            # NOTE(review): the condition tests the candidate set, not `use_dedup` as the line
            # above does; with use_dedup=True both lines remove the duplicates — confirm intent
            row_candidates_update[k1] = [item for item in row_candidates[k1] if item not in duplicates ] if row_candidates_set[k1] else row_candidates[k1]

        all_candidates.append(row_candidates_update)
        all_candidates_set.append(row_candidates_set_update)


    ### check non-empty
    # a name left without any candidate phrase invalidates its partition: all n_repeat replies
    # of that partition are requested again and the outer loop starts over
    empty_list = []
    for i, line in enumerate(all_candidates_set):
        for k, v in line.items():
            if len(v)==0:
                for j in range(n_repeat):
                    empty_list.append(j)
                    print(f'repair {i} {j}')
                    while 1:
                        try:
                            res = openai_chatgpt_post(concept_templates[i])
                            break
                        except Exception as e:
                            print(e)
                    all_chatgpt_res[j][i] = res

    if len(empty_list)==0:
        break

In [None]:
# Cache the validated replies for this dataset (pickle format) so they can be reloaded
# without querying the API again.
with open(f'./cache/openai/prompting-data={args.dataset}.pkl', 'wb') as f:
    pickle.dump(all_chatgpt_res, f)

In [None]:
data = {
    'all_candidates': all_candidates,
    'all_candidates_set': all_candidates_set,
}

# --- counter sorting ---
# per partition: {name: {phrase: occurrence count}}, names and phrases in sorted order
all_candidates = data['all_candidates']
all_counter_candidates = []
all_number_candidates = []
for row in all_candidates:
    sorted_counts = OrderedDict(
        (syn_name, OrderedDict(sorted(Counter(row[syn_name]).items())))
        for syn_name in sorted(row)
    )
    all_counter_candidates.append(sorted_counts)
    # total number of phrase occurrences in this partition
    all_number_candidates.append(sum(len(phrases) for phrases in row.values()))

# --- flatten ---
# per partition, parallel lists over all phrases: owning name index, phrase, count, name
all_row_mapping_idx_synset_name = []
all_row_chatgpt_names = []
all_row_i_syn = []
all_row_weight = []
all_row_key_name = []
for row_counter in all_counter_candidates:
    syn_names = list(row_counter)
    idx_to_name = dict(enumerate(syn_names))
    syn_indices, phrases, weights = [], [], []
    for i_syn, syn_name in enumerate(syn_names):
        phrase_counts = row_counter[syn_name]
        syn_indices += [i_syn] * len(phrase_counts)
        phrases += list(phrase_counts)
        weights += list(phrase_counts.values())

    all_row_mapping_idx_synset_name.append(idx_to_name)
    all_row_chatgpt_names.append(phrases)
    all_row_i_syn.append(syn_indices)
    all_row_weight.append(weights)
    all_row_key_name.append([idx_to_name[idx] for idx in syn_indices])

In [None]:
@torch.no_grad()
def build_classifier_chatgpt(all_row_chatgpt_names, model, all_row_key_name=None):
    """ build one text classifier per row of generated phrases
    Args:
        all_row_chatgpt_names: [[names]]: phrases of each row
        model: model providing `encode_text`
        all_row_key_name: [[names]] or None: key name paired with each phrase; when given,
            the templates stored under the dataset name are used and formatted with
            (key name, phrase), otherwise the 'imagenet' templates are formatted with the phrase
    Returns:
        row_classifier: list of cpu tensors, one per row, with one unit-norm embedding per phrase
    """
    single_name = all_row_key_name is None
    ### template 1 ('imagenet') for single names, template 2 (dataset specific) otherwise
    template_key = 'imagenet' if single_name else f'{args.dataset}'
    with open('../templates_small.json', 'rb') as f:
        templates = json.load(f)[template_key]

    n_templates = len(templates)
    row_classifier = []
    with tqdm(total=len(all_row_chatgpt_names)) as pbar:
        for idx, row in enumerate(all_row_chatgpt_names):
            if single_name:
                prompts = [ t.format(name) for name in row for t in templates ]
            else:
                prompts = [ t.format(pname, name) for pname, name in zip(all_row_key_name[idx], row) for t in templates ]
            tokens = tokenize(prompts).to(args.device)
            # regroup the flat prompt embeddings as (phrase, template, dim)
            emb = model.encode_text(tokens).view(len(row), n_templates, -1).float()
            emb = emb/emb.norm(dim=-1, keepdim=True)
            ### average over templates, then renormalise
            emb = emb.mean(dim=1)
            emb = emb/emb.norm(dim=-1, keepdim=True)
            row_classifier.append(emb.cpu())

            pbar.update(1)
    return row_classifier
    

In [None]:
# One phrase classifier per partition; passing the key names selects the dataset-specific templates.
all_row_classifier = build_classifier_chatgpt(all_row_chatgpt_names, model, all_row_key_name=all_row_key_name)

##### realignment

In [None]:
# Re-vote the name of each partition with the phrase classifiers: every member sample votes
# for its k_2 most similar phrases, and votes are aggregated per candidate name.
vfeatures = all_vfeatures
all_clu_pred_chatgpt = torch.zeros_like(all_clu_pred)
is_correct = []
k_2 = 3
enable_weight = True ### enable weighting for imbalanced partitions
instance_pred_voc = torch.zeros_like(record_pred_kmeans_t)
# NOTE(review): partition ids are assumed to be 0..len(all_row_classifier)-1 — confirm
for c in range(len(all_row_classifier)):
    select = (record_pred_kmeans_t==c)
    row_classifier = all_row_classifier[c]
    sim = torch.from_numpy(vfeatures[select, ...]).to(args.device)@row_classifier.to(args.device).t()
    # requires at least k_2 phrases in this partition's classifier
    sim_topk = sim.topk(k=k_2)
    ind, val = sim_topk.indices.flatten().cpu().unique(return_counts=True)
    count_names = torch.zeros(row_classifier.size(0)).long()
    count_names[ind] = val ### count of each name
    count_smask = []
    smask = np.array(all_row_i_syn[c]) ### partition mask
    for s in np.unique(smask):
        if enable_weight:
            # phrase weights start from their occurrence counts, are normalised within
            # name `s`, then over the whole row, before weighting the votes of name `s`
            row_weight = torch.tensor(all_row_weight[c]).float()
            row_weight[smask==s] = row_weight[smask==s] / row_weight[(smask==s)].sum()
            row_weight /= row_weight.sum()
            count_smask.append((row_weight[smask==s]*count_names[smask==s]).sum().item())
        else:
            count_smask.append(count_names[smask==s].sum())
    # predicted name = candidate with the largest aggregated vote
    name_pred = all_row_mapping_idx_synset_name[c][np.argmax(count_smask)]
    # reference name = majority ground-truth label of the partition
    name_gt = all_gt_voc[select].mode().values
    name_gt = vocab.mapping_idx_names[name_gt.item()]
    is_correct.append(name_pred==name_gt)
    instance_pred_voc[select] = vocab.mapping_names_idx[name_pred]
    
    val_count = torch.tensor(count_smask)
    # NOTE(review): assumes every partition has exactly k_1 distinct candidate names; a
    # different number would make this lookup or the assignment below fail — confirm
    ind_count = [ all_row_mapping_idx_synset_name[c][ii] for ii in range(k_1) ]
    ind_count = torch.tensor([vocab.mapping_names_idx[xx] for xx in ind_count])
    all_clu_pred_chatgpt[c, ind_count] = val_count
    
# Assign names from the re-voted scores, reassign samples, and report label-set overlap
# with the ground truth for the realigned result and for the earlier voting result.
classifier = get_classifier(args)
classifier = classifier/classifier.norm(dim=-1, keepdim=True)
args.num_voc = classifier.size(0)
a, res_ass = linear_assign(all_clu_pred_chatgpt, record_pred_kmeans_t, all_gt_voc)
r_pred_kmeans_t, r_cluster_ind_voc = reassign_by_pred_cluster(a, loader_f, model, classifier, args.device, preextracted_vfeatures=all_vfeatures)
# NOTE(review): set_pred and set_gt are not used below
set_pred = set(res_ass[1].tolist())
set_gt = set(all_gt_voc.unique().numpy().tolist())
# --- realigned result ---
# n_inter: distinct ground-truth labels predicted correctly for at least one sample
n_inter = all_gt_voc[r_cluster_ind_voc.cpu()==all_gt_voc].unique().shape[0]
# n_union: distinct labels appearing in either the predictions or the ground truth
n_union = torch.cat([r_cluster_ind_voc.cpu(), all_gt_voc]).unique().shape[0]
iou_voc = n_inter/n_union
n_missing_label = all_gt_voc.unique().shape[0] - n_inter
print('missing label::', n_missing_label)
print('iou voc::', iou_voc)
print('cluster acc', cluster_acc(y_true=all_label_clu.numpy(), y_pred=r_pred_kmeans_t.numpy()))
# --- result of the earlier voting loop (`cluster_ind_voc`), for comparison ---
# NOTE(review): iou_voc is recomputed here but only the missing-label count is printed
n_inter = all_gt_voc[cluster_ind_voc.cpu()==all_gt_voc].unique().shape[0]
n_union = torch.cat([cluster_ind_voc.cpu(), all_gt_voc]).unique().shape[0]
iou_voc = n_inter/n_union
n_missing_label = all_gt_voc.unique().shape[0] - n_inter
print('missing label::', n_missing_label)

In [None]:
# Realigned assignments, moved to CPU before serialisation.
result_data = dict(
    r_pred_kmeans_t=r_pred_kmeans_t.cpu(),
    r_cluster_ind_voc=r_cluster_ind_voc.cpu(),
)
# The output file name depends on the backbone; other architectures are not supported.
if args.arch == 'ViT-B/16':
    f_result = f'./cache/training/cvpr_result-data={args.dataset}-clip.pth'
elif args.arch == 'ViT-L/14':
    f_result = f'./cache/training/cvpr_result-data={args.dataset}-clip-L.pth'
else:
    raise NotImplementedError()
torch.save(result_data, f_result)

Finally, the computed `cvpr_results` are saved and used for semantic alignment self-training.