In [1]:
import sys
sys.path.append('/home/sheng/OSZSL/')
sys.path.append('/home/sheng/OSZSL/CLIP/')

import os
import json
import re
import time
import pickle
from typing import Union, List
from pprint import pprint
from tqdm import tqdm
from copy import deepcopy
import random
import itertools
import numpy as np
from functools import reduce, partial
from itertools import zip_longest
import seaborn as sns
from collections import Counter, defaultdict, OrderedDict
import matplotlib.pyplot as plt
import heapq
from wordnet_utils import *

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from ipynb_utils import get_hier_datasets, get_classifier, MCMF_assign_labels
import clip
from data.datasets import get_datasets_oszsl, build_transform, get_hier_datasets, Vocab


In [10]:
class Config:
    """Notebook-wide experiment settings, read everywhere through the `args` instance."""
    # Torch device for the model and batches (used as `.to(args.device)`).
    device: str = 'cuda:1'
    # CLIP backbone name passed to `clip.load`.
    arch: str = 'ViT-B/16'
    # Dataset key; also part of the cached clustering filename (pred_clu-<name>-...).
    dataset_name: str = 'make_nonliving26'
    # Consumed by project helpers not visible in this notebook — TODO confirm meaning.
    n_sampled_classes: int = 100
    input_size: int = 224

    # DataLoader batch size.
    batch_size: int = 512
    # Consumed by project helpers (presumably: use WordNet definitions) — TODO confirm.
    use_def: bool = False
    # Optional fine-tuned checkpoint overlaid on the CLIP weights by `load_clip`.
    clip_checkpoint: Union[str, None] = None
    # Text-classifier file; the classifier-building cell writes to this path.
    f_classifier: str = './cache/wordnet_classifier_in21k_word.pth'
    # Prompt template file stem, read as ../<templates_name>.json by `load_templates`.
    templates_name: str = 'templates_small'

args = Config()

In [11]:
def load_templates(args):
    """
    Read the CLIP prompt templates used for ImageNet-style classes.

    Loads ``../<args.templates_name>.json`` (relative to the current working
    directory) and returns its 'imagenet' entry.
    """
    template_path = f'../{args.templates_name}.json'
    with open(template_path, 'rb') as fp:
        all_templates = json.load(fp)
    return all_templates['imagenet']

def get_vocab(path='/home/sheng/dataset/wordnet_nouns_with_synset_4.pkl'):
    """
    Load the pickled WordNet noun vocabulary.

    Args:
        path: pickle file to read (defaults to the project's vocabulary file).
    Returns:
        dict of parallel lists (one entry per vocabulary item). Keys read in this
        notebook include 'synsets' (names accepted by `wn.synset`), 'names',
        'parents' and 'def'; per the author's notes 'parents' holds sets of
        synset ids — verify against the file.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as f:
        return pickle.load(f)

def get_subsample_vocab(sample_synset_id: set, vocab=None):
    """
    Restrict a vocabulary to the entries whose synset is in `sample_synset_id`.

    Args:
        sample_synset_id: set of synset names to keep.
        vocab: vocabulary dict of parallel lists (keyed like `get_vocab()`'s
            result); loaded with `get_vocab()` when None. It is not modified.
    Returns:
        New dict with the same keys, each list filtered to the kept positions
        (original order preserved).
    """
    if vocab is None:
        vocab = get_vocab()
    # Plain list indexing instead of np.array(...)[index]: numpy cannot build
    # arrays from ragged nested lists and silently coerces mixed-type columns.
    keep = [i for i, synset in enumerate(vocab['synsets']) if synset in sample_synset_id]
    return {key: [values[i] for i in keep] for key, values in vocab.items()}

def read_imagenet21k_classes(path='/home/sheng/dataset/imagenet21k/imagenet21k_wordnet_ids.txt'):
    """
    Read the ImageNet-21k class list, one WordNet id per line.

    Surrounding whitespace (e.g. a trailing carriage return) is stripped and
    blank or whitespace-only lines are skipped, so every returned id can be
    parsed by callers such as `int(x[1:])`.

    Args:
        path: text file to read (defaults to the project's ImageNet-21k id list).
    Returns:
        list of id strings in file order.
    """
    with open(path, 'r') as f:
        stripped = (line.strip() for line in f)
        return [wnid for wnid in stripped if wnid]

# Prompt templates plus the full WordNet noun vocabulary and its per-entry fields.
templates = load_templates(args)
vocab = get_vocab()
nouns = list(map(wn.synset, vocab['synsets']))
classnames = vocab['names']
parents = vocab['parents']
defs = vocab['def']


In [None]:
# """ candidate subclasses """
# from data.datasets import read_breeds_class_hierarchy
# superclass = get_hier_datasets(args.dataset_name)[0][0]
# ### get all superclass synset
# parents_breeds = read_breeds_class_hierarchy()
# parents_breeds = { line[1]: line[0] for line in parents_breeds }
# superclass_synset = []
# for s in superclass:
#     if s[:len('dummy')]=='dummy':
#         superclass_synset.append( wn.synset_from_pos_and_offset('n', int(parents_breeds[s][1:])) )
#     else:
#         sup = wn.synset_from_pos_and_offset('n', int(s[1:]))
#         superclass_synset.append(sup)
# ### get all subclass synset id
# closeset_synsets = set()
# for s in superclass_synset:
#     closeset_synsets |= successor_set(G, source=s.name())

# print('candidate number', len(closeset_synsets))

# from robustness.tools.breeds_helpers import ClassHierarchy
# info_dir = '/home/sheng/OSZSL/breeds_hier/modified'
# dataset_hier, hier = get_hier_datasets(args.dataset_name)

In [4]:
""" build entire wn-graph """
from nxgraph_model import *

# Load the full (unfiltered) noun vocabulary and build the WordNet graph over it.
with open('/home/sheng/dataset/wordnet_nouns_with_synset.pkl', 'rb') as f:
    entire_vocab = pickle.load(f)

G = create_graph(
    [wn.synset(synset_name) for synset_name in entire_vocab['synsets']],
    entire_vocab['ids'],
    entire_vocab['names'],
    entire_vocab['def'],
)

In [13]:
""" prepare dataset and load CLIP """
# Candidate vocabulary: ImageNet-21k WordNet ids plus the ImageNet folder names.
# Each entry's first character is dropped and the rest parsed as a noun offset.
classes = read_imagenet21k_classes() + os.listdir('/home/sheng/dataset/imagenet-img/')
# Convert offsets to synset names and deduplicate.
classes = [wn.synset_from_pos_and_offset('n', int(x[1:])).name() for x in classes]
classes = set(classes)
# Restrict the full noun vocabulary to those synsets and wrap it in the project's Vocab.
vocab = get_subsample_vocab(classes)
vocab = Vocab(vocab=vocab)

# NOTE(review): evaluation uses the training split (is_train=True) with the
# validation transform; confirm this is intended.
transform_val = build_transform(is_train=False, args=args, train_config=None)
dataset = get_datasets_oszsl(args, vocab, is_train=True, transform=transform_val, seed=1)
loader_val = torch.utils.data.DataLoader(dataset, num_workers=8, batch_size=args.batch_size, shuffle=False)
print('dataset size', len(dataset))

# NOTE(review): `load_clip` is defined in a later cell (performance test / utils).
# On a fresh kernel, run that cell first or move the definition above; otherwise
# this line raises NameError.
model, preprocess = load_clip(args)

dataset size 132765
Model parameters: 149,620,737
Input resolution: 224
Context length: 77
Vocab size: 49408


#### Build the zero-shot text classifier

In [63]:
""" from MUST """
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
_tokenizer = _Tokenizer()

def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """
    Encode text(s) with CLIP's BPE tokenizer into a zero-padded token matrix.

    Each row is <|startoftext|> + BPE tokens + <|endoftext|>, padded with zeros
    to `context_length`.

    Args:
        texts: a single string or a list of strings.
        context_length: number of columns in the output.
        truncate: cut over-long sequences (keeping an end token last) instead of raising.
    Returns:
        LongTensor of shape (len(texts), context_length).
    Raises:
        RuntimeError: a text does not fit and `truncate` is False.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    encoded = [[sot_token, *_tokenizer.encode(text), eot_token] for text in texts]
    result = torch.zeros(len(encoded), context_length, dtype=torch.long)

    for row_idx, token_ids in enumerate(encoded):
        if len(token_ids) > context_length:
            if not truncate:
                raise RuntimeError(f"Input {texts[row_idx]} is too long for context length {context_length}")
            token_ids = token_ids[:context_length]
            token_ids[-1] = eot_token
        result[row_idx, :len(token_ids)] = torch.tensor(token_ids)

    return result

In [64]:
batch_size = 64
# np.array_split needs >= 1 section; guard vocabularies smaller than one batch
# (len // batch_size would be 0 and raise).
n_chunks = max(1, len(vocab.classnames) // batch_size)
with torch.no_grad():
    zeroshot_weights = []
    with tqdm(total=n_chunks) as pbar:
        for classname_set in np.array_split(vocab.classnames, n_chunks):
            # Texts are ordered class-major (all templates of class 0, then class 1, ...),
            # which the view below relies on.
            texts = [template.format(classname) for classname in classname_set for template in templates]
            texts = tokenize(texts).to(args.device)
            class_embeddings = model.encode_text(texts).float()
            class_embeddings = class_embeddings.view(-1, len(templates), class_embeddings.size(-1))
            class_embeddings = F.normalize(class_embeddings, dim=-1)
            # Average the normalized per-template embeddings, then re-normalize.
            class_embedding = class_embeddings.mean(dim=1)
            class_embedding /= class_embedding.norm(dim=-1, keepdim=True)
            zeroshot_weights.append(class_embedding.cpu())

            pbar.update(1)

classifier = torch.cat(zeroshot_weights, dim=0)
# Save to the configured path (the same file the rest of the notebook reads).
os.makedirs(os.path.dirname(args.f_classifier) or '.', exist_ok=True)
torch.save(classifier, args.f_classifier)

100%|██████████| 313/313 [00:41<00:00,  7.54it/s]


### Performance test

#### utils

In [12]:
def load_clip(args):
    """
    Load a CLIP model, optionally overlay fine-tuned weights, and put it on `args.device` in eval mode.

    Args:
        args: needs `arch` (name passed to `clip.load`), `clip_checkpoint`
            (optional path to a checkpoint whose 'model' entry is a state dict)
            and `device`.
    Returns:
        (model, preprocess) as returned by `clip.load`.
    """
    model, preprocess = clip.load(args.arch)
    if args.clip_checkpoint:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt_state = torch.load(args.clip_checkpoint, map_location='cpu')['model']
        prefix = 'model.'
        # Strip the wrapper's 'model.' prefix only where present. Slicing every key
        # unconditionally mangled keys without it, and strict=False then dropped them silently.
        state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in ckpt_state.items()}
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(f'checkpoint: {len(incompatible.missing_keys)} missing keys, '
                  f'{len(incompatible.unexpected_keys)} unexpected keys')
    model.to(args.device).eval()
    input_resolution = model.visual.input_resolution
    context_length = model.context_length
    vocab_size = model.vocab_size

    print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
    print("Input resolution:", input_resolution)
    print("Context length:", context_length)
    print("Vocab size:", vocab_size)
    return model, preprocess

def topk_acc(all_pred_voc_topk, all_gt_voc):
    """
    Cumulative top-k accuracy for k = 1..K, printed and returned.

    Args:
        all_pred_voc_topk: tensor([N x K]) ranked predicted indices per sample.
        all_gt_voc: tensor([N]) ground-truth index per sample.
    Returns:
        list of K floats; entry i is the fraction of samples whose ground truth
        appears among the first i+1 predictions.
    """
    hit_so_far = torch.zeros(all_pred_voc_topk.size(0)).bool()
    accuracies = []
    for rank in range(all_pred_voc_topk.size(1)):
        # A sample counts as a hit at rank r once any of its first r+1 guesses matched.
        hit_so_far = hit_so_far | (all_pred_voc_topk[:, rank] == all_gt_voc)
        mean_hit = hit_so_far.float().mean()
        print(f'k={rank} acc={mean_hit}')
        accuracies.append(mean_hit.item())
    return accuracies

def semantic_acc(y_pred, y_true, metrics={}):
    """
    Soft (semantic) accuracy of predicted vs. true vocabulary indices.

    Each index is expanded to its synsets with `mapping_vocidx_to_synsets`
    (against the notebook-global `vocab`). Per sample and metric, the score is
    the best value over all (predicted synset, true synset) pairs. The result
    is the mean score per metric.

    Args:
        y_pred: tensor([N]) predicted vocabulary indices.
        y_true: tensor([N]) ground-truth vocabulary indices.
        metrics: non-empty {name: callable(synset_a, synset_b) -> float}.
    Returns:
        {name: mean score}.
    """
    assert len(metrics)>0
    assert y_pred.size(0)==y_true.size(0)
    per_sample = {name: [] for name in metrics}
    with tqdm(total=y_pred.size(0)) as pbar:
        for pred_idx, true_idx in zip(y_pred, y_true):
            syn_pred = mapping_vocidx_to_synsets(pred_idx.item(), vocab)
            syn_true = mapping_vocidx_to_synsets(true_idx.item(), vocab)
            for name, metric in metrics.items():
                best = max(metric(a, b) for a, b in itertools.product(syn_pred, syn_true))
                per_sample[name].append(best)
            pbar.update(1)
    return {name: np.array(values).mean() for name, values in per_sample.items()}
    

#### Naive zero-shot inference

In [11]:
classifier = get_classifier(args)
amp_autocast = torch.cuda.amp.autocast
classifier = classifier / classifier.norm(dim=-1, keepdim=True)

# Zero-shot prediction for every image: argmax and top-5 over the full vocabulary.
all_pred_voc, all_gt_voc, all_pred_voc_topk = [], [], []
with tqdm(total=len(loader_val)) as pbar:
    model.eval()
    for images, label_voc, label_clu, idx_img in loader_val:
        images = images.to(args.device)
        with amp_autocast(), torch.no_grad():
            img_feats = model.visual(images)
            img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True)
            prob = (model.logit_scale.exp() * img_feats @ classifier.t()).softmax(-1)
            all_pred_voc.append(prob.argmax(dim=-1).cpu())
            all_gt_voc.append(label_voc)
            all_pred_voc_topk.append(prob.topk(k=5, dim=-1).indices.cpu())
        pbar.update(1)

all_pred_voc = torch.cat(all_pred_voc, dim=0)
all_gt_voc = torch.cat(all_gt_voc, dim=0)
all_pred_voc_topk = torch.cat(all_pred_voc_topk, dim=0)

100%|██████████| 260/260 [02:35<00:00,  1.68it/s]


In [12]:
# Top-1 accuracy, and how many ground-truth labels are never predicted at all.
print(f'acc={(all_gt_voc == all_pred_voc).float().mean()}')
n_missing = len(set(all_gt_voc.unique().numpy()).difference(all_pred_voc.unique().numpy()))
print(f'n_missing={n_missing}')

acc=0.33346137404441833
n_missing=2


In [13]:
# Soft accuracy of the naive predictions under two WordNet similarity measures.
score_baseline = semantic_acc(
    all_pred_voc,
    all_gt_voc,
    metrics=dict(wup_similarity=wn.wup_similarity, lch_similarity=wn.lch_similarity),
)
print(score_baseline)

100%|██████████| 132765/132765 [01:03<00:00, 2096.07it/s]

{'wup_similarity': 0.751388662643735, 'lch_similarity': 2.3188224544809772}





In [14]:
# Cumulative top-k accuracy (k = 1..5) of the naive zero-shot predictions.
topk_acc(all_pred_voc_topk=all_pred_voc_topk, all_gt_voc=all_gt_voc)

k=0 acc=0.33335593342781067
k=1 acc=0.45885586738586426
k=2 acc=0.5238654613494873
k=3 acc=0.5691183805465698
k=4 acc=0.6037208437919617


[0.33335593342781067,
 0.45885586738586426,
 0.5238654613494873,
 0.5691183805465698,
 0.6037208437919617]

#### SCD

In [15]:
from sklearn.cluster import KMeans
from my_util_package.evaluation import cluster_acc
from scipy.optimize import linear_sum_assignment as linear_assignment

subset = ['train', 'val'][0]

""" load dataset """
# NOTE(review): these are ImageNet mean/std, while CLIP's own preprocessing uses
# (0.48145466, 0.4578275, 0.40821073) / (0.26862954, 0.26130258, 0.27577711).
# The images here go through `model.visual`; confirm the mismatch is intentional
# (cached artifacts such as pred_clu-*.npy may depend on it).
transform_f = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

if subset == 'train':
    dataset_f = get_datasets_oszsl(args, vocab, is_train=True, transform=transform_f, seed=1)
elif subset == 'val':
    dataset_f = get_datasets_oszsl(args, vocab, is_train=False, transform=transform_f, seed=1)
else:
    # Fail loudly instead of leaving `dataset_f` undefined (or stale from a prior run).
    raise ValueError(f"subset must be 'train' or 'val', got {subset!r}")
args.nb_classes = dataset_f.num_classes
loader_f = torch.utils.data.DataLoader(dataset_f, num_workers=4, batch_size=args.batch_size, shuffle=False)

In [16]:
def agg_by_pred_cluster(args, pred_kmeans, all_topk_voc, voc_size):
    """
    Build a per-cluster vocabulary histogram from instance-level top-k predictions.

    For each cluster id (in sorted order), count how often each vocabulary index
    appears among its members' top-k predictions and divide by the cluster size.

    Args:
        args: only consulted for `args.num_voc` when `voc_size` is None.
        pred_kmeans: np.array([N]) cluster id per instance.
        all_topk_voc: np.array([N x K]) top-k vocabulary indices per instance.
        voc_size: int, vocabulary size V (number of histogram bins).
    Returns:
        all_clu_pred: tensor([C x V]); row r belongs to the r-th smallest cluster id.
    """
    print('agg_by_pred_cluster')
    # Honour the explicit argument (previously ignored in favour of args.num_voc).
    n_voc = voc_size if voc_size is not None else args.num_voc
    all_clu_pred = []
    n_count = []
    for i in np.unique(pred_kmeans):
        selected = (pred_kmeans == i)
        n_count.append(selected.sum().item())
        counter_voc_ind, counter_val = np.unique(all_topk_voc[selected].ravel(), return_counts=True)
        clu_pred = torch.zeros(n_voc)  # cluster-wise counts
        clu_pred[torch.from_numpy(counter_voc_ind).long()] = torch.from_numpy(counter_val).float()
        all_clu_pred.append(clu_pred)
    all_clu_pred = torch.stack(all_clu_pred, dim=0).cpu()
    n_count = torch.tensor(n_count).cpu()

    # Normalize counts by cluster size (epsilon only guards against division by zero).
    all_clu_pred = all_clu_pred / (n_count.view(-1, 1) + 1e-20)

    clu_argmax = all_clu_pred.argmax(dim=-1)
    print('is mutex assignment::', clu_argmax.size(0) == clu_argmax.unique().size(0))
    print('assignment collision num::', sum(1 for c in Counter(clu_argmax.numpy()).values() if c > 1))
    return all_clu_pred

def linear_assign(all_clu_pred, pred_kmeans, all_gt_voc, return_results=False):
    """
    Match clusters to vocabulary entries with the Hungarian algorithm and label
    every instance with its cluster's matched vocabulary index.

    Args:
        all_clu_pred: tensor([C x V]) cluster-to-vocabulary scores (higher is better).
            Row r is assumed to belong to the r-th smallest cluster id in
            `pred_kmeans`, as produced by `agg_by_pred_cluster`.
        pred_kmeans: tensor/array([N]) cluster id per instance; ids need not be contiguous.
        all_gt_voc: tensor([N]) ground-truth vocabulary index per instance.
        return_results: also return the instance accuracy.
    Returns:
        label_voc_kmeans: tensor([N]) matched vocabulary index per instance (-1 if the
            instance's cluster got no match, which can only happen when C > V).
        res_ass: (row_ind, col_ind) from `linear_sum_assignment`.
        inst_acc: float, only when `return_results` is True.
    """
    print('linear_assign')
    cost_mat = all_clu_pred.cpu().numpy()
    print(f'assignment shape={cost_mat.shape}')
    # Maximize the total score by minimizing (max - score).
    res_ass = linear_assignment(cost_mat.max() - cost_mat)
    row_ind, col_ind = res_ass
    # Vocabulary index per cluster row; -1 marks rows left unmatched.
    voc_of_row = np.full(cost_mat.shape[0], -1, dtype=np.int64)
    voc_of_row[row_ind] = col_ind
    # Map raw cluster ids to row positions (sorted-unique order), instead of
    # assuming the ids are already 0..C-1.
    cluster_ids = torch.as_tensor(pred_kmeans).cpu().numpy()
    row_of_instance = np.searchsorted(np.unique(cluster_ids), cluster_ids)
    label_voc_kmeans = torch.from_numpy(voc_of_row[row_of_instance])
    inst_acc = (label_voc_kmeans == all_gt_voc).float().mean().item()
    print('instance label acc::', inst_acc)
    if return_results:
        return label_voc_kmeans, res_ass, inst_acc
    return label_voc_kmeans, res_ass

def reassign_by_pred_cluster(label_voc_kmeans, model, classifier, device, 
                             all_prob=None, 
                             instance_selected=None, 
                             classifier_selected=None,
                             data_loader=None):
    """
    Re-assign every instance to one of the candidate vocabulary labels by nearest
    text embedding, then relabel the chosen candidates as contiguous cluster ids.

    Args:
        label_voc_kmeans: tensor([M]) candidate vocabulary indices used as columns
            (may contain duplicates; argmax then picks the first occurrence).
        model: CLIP-like model exposing `.visual`; only used when `all_prob` is None.
        classifier: tensor([V x D]) text embeddings; only used when `all_prob` is None.
            Assumed L2-normalized — TODO confirm (callers in this notebook normalize it).
        device: device for images and `label_voc_kmeans`.
        all_prob: optional precomputed tensor([N x V]) probabilities; when given,
            no images are processed.
        instance_selected: optional bool tensor indexed by the loader's `idx_img`;
            only selected images are classified (loader path only).
        classifier_selected: tensor([C2]) optional vocabulary indices; when given,
            images are classified against these rows of `classifier` instead.
            NOTE(review): with `all_prob`, the argmax indexes `label_voc_kmeans`
            columns but is still mapped through `classifier_selected` — confirm intent.
        data_loader: loader yielding (images, label_voc, label_clu, idx_img, ...);
            defaults to the notebook-global `loader_f` for backward compatibility.
    Returns:
        cluster_ind: CPU tensor of contiguous ids 0..K-1, one per distinct chosen column.
        cluster_ind_voc: tensor with the vocabulary index chosen for each instance.
    """
    print('reassign_by_pred_cluster')
    amp_autocast = torch.cuda.amp.autocast
    label_voc_kmeans = label_voc_kmeans.to(device)
    if all_prob is None:
        loader = loader_f if data_loader is None else data_loader
        cluster_ind = []
        with tqdm(total=len(loader)) as pbar:
            if hasattr(model, 'eval'):
                model.eval()
            for idx_batch, batch in enumerate(loader):
                images, label_voc, label_clu, idx_img = batch[:4]
                images = images.to(device)
                if (instance_selected is not None) and ((~instance_selected[idx_img]).all()):
                    pbar.update(1)  # keep the progress bar in step with skipped batches
                    continue
                with amp_autocast():
                    with torch.no_grad():
                        if instance_selected is not None:
                            logits = model.visual(images[instance_selected[idx_img]])
                        else:
                            logits = model.visual(images)

                        logits = logits/logits.norm(dim=-1, keepdim=True)
                        # NOTE: fixed logit scale of 100, not model.logit_scale.
                        if classifier_selected is not None:
                            similarity = 100 * logits @ classifier[classifier_selected].t()
                            # argmax is a position in `classifier_selected`; it is mapped
                            # back to vocabulary indices after the loop.
                            prob = similarity.softmax(-1)
                            cluster_ind.append(prob.cpu().argmax(dim=-1))
                        else:
                            similarity = 100 * logits @ classifier.t()
                            prob = similarity.softmax(-1)
                            cluster_ind.append(prob[:, label_voc_kmeans].cpu().argmax(dim=-1))
                pbar.update(1)
        cluster_ind = torch.cat(cluster_ind, dim=0)
    else:
        # Index on all_prob's own device to avoid a CPU/GPU mismatch.
        all_prob = all_prob[:, label_voc_kmeans.to(all_prob.device)]
        cluster_ind = all_prob.argmax(dim=-1)

    if classifier_selected is not None:
        cluster_ind_voc = classifier_selected[cluster_ind]
    else:
        cluster_ind_voc = label_voc_kmeans[cluster_ind]
    # Relabel chosen column positions as contiguous ids 0..K-1 (sorted order).
    _, cluster_ind = torch.unique(cluster_ind, sorted=True, return_inverse=True)
    return cluster_ind.cpu(), cluster_ind_voc


@torch.no_grad()
def computation_reassign_by_pred_cluster(row, idx, args, model, classifier, candidate_classifier_ind):
    """
    Per-batch step: assign each image to its nearest candidate text embedding.

    Meant as a `computations` entry for `loop_row_collect_results_nograd`, with
    e.g. candidate_classifier_ind = label_voc_kmeans.unique().to(args.device).

    Args:
        row: loader batch; the first element is the image tensor.
        idx: batch index (unused; part of the computation-callback signature).
        args: needs `device`.
        model: exposes `.visual(images)` returning image features.
        classifier: tensor([V x D]) text embeddings, indexable by `candidate_classifier_ind`.
        candidate_classifier_ind: tensor of vocabulary indices to choose from.
    Returns:
        CPU tensor with one chosen vocabulary index (from the candidates) per image.
    """
    images = row[0].to(args.device)
    # Autocast is created locally so this function does not depend on a
    # notebook-global `amp_autocast` having been defined by another cell.
    with torch.cuda.amp.autocast():
        vfeatures = model.visual(images).float()
    vfeatures = F.normalize(vfeatures, dim=-1)
    batch_sim = 100 * vfeatures @ classifier[candidate_classifier_ind].t()
    cluster_ind = batch_sim.argmax(dim=-1)
    return candidate_classifier_ind[cluster_ind].cpu()

def aggregation_reassign_by_pred_cluster(r, candidate_classifier_ind):
    """
    Aggregate per-batch outputs of `computation_reassign_by_pred_cluster`.

    Args:
        r: list of tensors of chosen vocabulary indices, one per batch.
        candidate_classifier_ind: unused; kept for the aggregation-callback signature.
    Returns:
        cluster_ind: int64 tensor of contiguous ids 0..K-1 (rank of each index among
            the sorted distinct chosen indices).
        cluster_ind_voc: concatenated vocabulary indices.
    """
    cluster_ind_voc = torch.cat(r, dim=0)
    # Vectorized relabel; also keeps the int64 dtype for empty input.
    _, cluster_ind = torch.unique(cluster_ind_voc, sorted=True, return_inverse=True)
    return cluster_ind, cluster_ind_voc


@torch.no_grad()
def extract_vfeatures(model, data_loader, device):
    """
    Run `model.visual` over a loader and return L2-normalized image features.

    Args:
        model: exposes `.visual(images)`; put in eval mode when it has `.eval`.
        data_loader: yields batches whose first four items are
            (images, label_voc, label_clu, idx_img).
        device: device the images are moved to.
    Returns:
        np.ndarray of row-normalized features, batches concatenated in loader order.
    """
    amp_autocast = torch.cuda.amp.autocast
    feature_chunks = []
    with tqdm(total=len(data_loader)) as progress:
        if hasattr(model, 'eval'):
            model.eval()
        for batch in data_loader:
            images, _label_voc, _label_clu, _idx_img = batch[:4]
            with amp_autocast():
                feats = model.visual(images.to(device)).float()
            feats = feats / feats.norm(dim=-1, keepdim=True)
            feature_chunks.append(feats.cpu().numpy())
            progress.update(1)
    return np.concatenate(feature_chunks)


@torch.no_grad()
def loop_row_collect_results_nograd(obj_iter, computations={}, aggregations={}):
    """
    Iterate once over `obj_iter`, apply every computation to each row, then aggregate.

    Args:
        obj_iter: sized iterable (e.g. a DataLoader).
        computations: {name: func(row, row_index)}; outputs are collected per name.
        aggregations: {name: func(list_of_outputs)}; must have the same names.
    Returns:
        {name: aggregated result}.
    """
    assert set(computations) == set(aggregations)
    collected = {name: [] for name in computations}
    with tqdm(total=len(obj_iter)) as pbar:
        for row_index, row in enumerate(obj_iter):
            for name, compute in computations.items():
                collected[name].append(compute(row, row_index))
            pbar.update(1)
    return {name: aggregate(collected[name]) for name, aggregate in aggregations.items()}

In [None]:
loader_f = torch.utils.data.DataLoader(dataset_f, num_workers=1, batch_size=args.batch_size, shuffle=False)

# Collect per-image top-k vocabulary predictions plus ground-truth vocab/cluster labels.
# This was commented out before, so `all_topk_voc`, `all_label_clu` and
# `args.num_voc` only existed as leftover kernel state; a fresh run needs them.
classifier = get_classifier(args)
classifier = classifier/classifier.norm(dim=-1, keepdim=True)
args.num_voc = classifier.size(0)
amp_autocast = torch.cuda.amp.autocast
prob_k = 5
all_topk_voc = []
all_gt_voc = []
all_label_clu = []
with tqdm(total=len(loader_f)) as pbar:
    if hasattr(model, 'eval'):
        model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch[:4]
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                logits = model.visual(images)
                logits = logits/logits.norm(dim=-1, keepdim=True)
                similarity = 100 * logits @ classifier.t()
                prob = similarity.softmax(-1)
                prob_topk_ind = prob.topk(k=prob_k, dim=-1).indices
                all_topk_voc.append(prob_topk_ind.cpu().numpy())
                all_gt_voc.append(label_voc)
                all_label_clu.append(label_clu)
        pbar.update(1)

all_topk_voc = np.concatenate(all_topk_voc)
all_gt_voc = torch.cat(all_gt_voc, dim=0)
all_label_clu = torch.cat(all_label_clu, dim=0)

# Initial clustering cached from a previous step (train split).
pred_kmeans = torch.from_numpy(np.load(f'./pred_clu-{args.dataset_name}-train-clip.npy'))
assert pred_kmeans.size(0) == all_gt_voc.size(0), 'cached clustering does not match the dataset size'

# Alternate: cluster -> vocabulary matching, then re-cluster by nearest matched label.
n_refine_iters = 3
set_gt = set(all_gt_voc.unique().numpy().tolist())
pred_kmeans_t = pred_kmeans
history_set_pred = []
for t in range(n_refine_iters):
    all_clu_pred = agg_by_pred_cluster(args, pred_kmeans_t.numpy(), all_topk_voc, voc_size=args.num_voc)
    label_voc_kmeans, res_ass = linear_assign(all_clu_pred, pred_kmeans_t, all_gt_voc)
    candidate_ind = label_voc_kmeans.unique().to(args.device)
    results = loop_row_collect_results_nograd(
        loader_f,
        computations={
            'reassign_by_pred_cluster': partial(
                computation_reassign_by_pred_cluster,
                args=args,
                model=model,
                classifier=classifier,
                candidate_classifier_ind=candidate_ind,
            ),
        },
        aggregations={
            'reassign_by_pred_cluster': partial(
                aggregation_reassign_by_pred_cluster,
                candidate_classifier_ind=candidate_ind,
            ),
        },
    )
    pred_kmeans_t, cluster_ind_voc = results['reassign_by_pred_cluster']
    # Labels matched this round (before re-clustering) vs. labels present in the ground truth.
    set_pred = set(res_ass[1].tolist())
    print('missing label::', len(set_gt - set_pred))
    print('cluster acc', cluster_acc(y_true=all_label_clu.numpy(), y_pred=pred_kmeans_t.numpy()))
    history_set_pred.append(set_pred)

    


agg_by_pred_cluster
is mutex assignment:: False
assignment collision num:: 6
linear_assign
assignment shape=(104, 20071)
instance label acc:: 0.3845817744731903


 91%|█████████ | 237/260 [13:17<01:20,  3.48s/it]

#### ChatGPT requests for alternative concept names

In [14]:
import openai
def openai_chatgpt_post(content, parameters=None):
    """
    Send `content` as a single user message to gpt-3.5-turbo and return the reply text.

    Uses the legacy `openai.ChatCompletion` API. The API key is read from the
    OPENAI_API_KEY environment variable. Never hardcode it in the notebook; the
    key previously committed here must be treated as leaked and revoked.

    Args:
        content: prompt text.
        parameters: extra keyword arguments for `ChatCompletion.create`
            (default: {'temperature': 0.7}).
    Returns:
        The assistant message content of the first choice.
    Raises:
        RuntimeError: OPENAI_API_KEY is not set.
    """
    if parameters is None:
        parameters = {'temperature': 0.7}
    api_key = os.environ.get('OPENAI_API_KEY')
    if not api_key:
        raise RuntimeError('OPENAI_API_KEY environment variable is not set')
    openai.api_key = api_key
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "user", "content": content},
        ],
        **parameters,
    )
    return completion['choices'][0]['message']['content']

In [15]:
k_1 = 3
classifier = get_classifier(args).cpu()
# Normalize rows (a no-op for already-normalized rows) so dot products are cosine similarities.
classifier = classifier / classifier.norm(dim=-1, keepdim=True)
sim_classifier = classifier @ classifier.t()
# Exclude each class from its own neighbour list explicitly. Slicing off column 0 of a
# top-(k+1) drops the wrong entry when a duplicate embedding ties with the diagonal.
sim_classifier.fill_diagonal_(float('-inf'))
topk_all_clu_pred = sim_classifier.topk(k=k_1).indices

In [23]:
""" gather concepts """
# For each class, describe its k_1 nearest-neighbour classes: every vocab index is
# expanded to its synsets and rendered as "synset_name: definition" strings.
to_name = lambda x: [ s.name() + ': ' + s.definition() for s in x ]
cluster_row_synsets = []
for row in topk_all_clu_pred:
    row_synsets = [to_name(mapping_vocidx_to_synsets(voc_idx.item(), vocab)) for voc_idx in row]
    cluster_row_synsets.append(row_synsets)

    
""" generate concept requests """
# Flatten each row's per-neighbour lists and join the items as "'<item>.'" separated by ', '.
concept_request = []
for row in cluster_row_synsets:
    ccpts = reduce(lambda x, y: x+y, row)
    ccpts = list(map(lambda x: "'"+x+".'", ccpts))
    ccpts = ', '.join(ccpts)
    concept_request.append(ccpts)
    
""" generate concept templates """
# NOTE(review): template_1 joins the concept list and "List all ..." with no separating
# space (the list ends in ".'"). It is kept unchanged so prompts match previously cached
# responses; confirm whether the missing space is intended.
template_1 = lambda concept_list: "Given visual concepts: "+ concept_list + "List all alternative concept names for each visual concept. List in the format \"{concept name}: {list of names separated by ';'}.\""
with open('/home/sheng/OSZSL/templates_chatgpt.json', 'r') as f:
    template_chatgpt = json.load(f)
# Alternative prompt styles from templates_chatgpt.json, filled via str.format(concept_list).
template_2 = lambda concept_list: template_chatgpt['pictionary-long'].format(concept_list)
template_3 = lambda concept_list: template_chatgpt['pictionary-short'].format(concept_list)
template_4 = lambda concept_list: template_chatgpt['direct'].format(concept_list)
    
# One prompt per class using the selected template.
template_in_use = template_1
concept_templates = []
for row in concept_request:
    concept_templates.append(template_in_use(row))

In [24]:
# Prompts for this run's slice. Rebuilt from `concept_request` so re-running the cell
# is idempotent; slicing `concept_templates` in place shrank it on every re-run.
concept_templates = [template_in_use(request) for request in concept_request][7000:14000]

In [None]:
""" collect chatgpt res """
n_repeat = 3
# Bounded retries with exponential backoff (1s, 2s, ... capped at 60s). The previous
# unbounded busy retry loop spun forever on persistent errors (bad key, exhausted quota).
max_retries = 8
all_chatgpt_res = [[] for _ in range(n_repeat)]
with tqdm(total=len(concept_templates)*n_repeat) as pbar:
    for i in range(n_repeat):
        for row in concept_templates:
            for attempt in range(max_retries):
                try:
                    all_chatgpt_res[i].append(openai_chatgpt_post(row))
                    break
                except Exception as e:
                    print(f'attempt {attempt + 1}/{max_retries} failed: {e}')
                    if attempt + 1 < max_retries:
                        time.sleep(min(60, 2 ** attempt))
            else:
                # Responses collected so far remain in all_chatgpt_res.
                raise RuntimeError(f'ChatGPT request failed {max_retries} times; giving up')

            pbar.update(1)

In [26]:
# Persist the raw responses. The filename hardcodes template=1 and part=2 (presumably
# the [7000:14000] slice); update these by hand if the template or slice changes.
with open(f'./cache/openai/visual-inov-template=1-k_1={k_1}-repeat={n_repeat}-part=2.pkl', 'wb') as f_out:
    pickle.dump(all_chatgpt_res, f_out)

In [None]:
# Summarize rather than dumping every response: responses per repeat, plus a small sample.
[len(res) for res in all_chatgpt_res], all_chatgpt_res[0][:3]