In [1]:
import sys
sys.path.append('/home/sheng/sssa/')
sys.path.append('/home/sheng/sssa/')
# sys.path.append('/home/sheng/sssa/CLIP/')

import os
import json
import re
import time
import pickle
from typing import Union, List
from pprint import pprint
from tqdm import tqdm
from copy import deepcopy
import random
import itertools
import numpy as np
from functools import reduce, partial
from itertools import zip_longest
import seaborn as sns
from collections import Counter, defaultdict, OrderedDict
import matplotlib.pyplot as plt
import heapq
from wordnet_utils import *
import scipy.io
from PIL import Image

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode
from torchvision.datasets import ImageFolder

from ipynb_utils import get_hier_datasets, get_classifier, MCMF_assign_labels
# import clip
import model as clip
from data.datasets import build_transform, get_hier_datasets
from data.vocab import get_vocab_with_classnames
# from data.imagenet_datasets import get_datasets_oszsl
from data.imagenet_datasets_namevocab import get_datasets_oszsl




In [2]:
class Config:
    exp = 'classifier_3d'
    vocabname = 'concat2' ### ['in21k', 'concat3', 'concat3+lvis']
    
    device = 'cuda:3'
    arch = 'ViT-B/16'
    
    dataset = 'make_entity30'
    n_sampled_classes = 100
    input_size = 224
    estimate_k = -1
    
    batch_size = 512
    use_def = False
    clip_checkpoint = None
    # f_classifier = './cache/wordnet_classifier_in21k_word.pth'
    f_classifier = './cache/classifier_3d-concat2.pth'
    templates_name = 'templates_small'
    seed = 0
    
args = Config()

In [3]:
def load_templates(args):
    with open(f'../{args.templates_name}.json', 'rb') as f:
        templates = json.load(f)['imagenet']
    return templates

templates = load_templates(args)

def load_clip(args):
    model, preprocess = clip.load(args.arch)
    if args.clip_checkpoint:
        model.load_state_dict({k[len('model.'):]:v for k, v in torch.load(args.clip_checkpoint, map_location='cpu')['model_ema'].items()}, strict=False)
    model.to(args.device).eval()
    input_resolution = model.visual.input_resolution
    context_length = model.context_length
    vocab_size = model.vocab_size

    print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
    print("Input resolution:", input_resolution)
    print("Context length:", context_length)
    print("Vocab size:", vocab_size)
    return model, preprocess

def load_clip2(args):
    model = clip.load(args.arch)
    if args.clip_checkpoint:
        model.load_state_dict({k[len('model.'):]:v for k, v in torch.load(args.clip_checkpoint, map_location='cpu')['model_ema'].items()}, strict=False)
    model.to(args.device).eval()
    input_resolution = model.visual.input_resolution
    context_length = model.context_length
    vocab_size = model.vocab_size

    print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
    print("Input resolution:", input_resolution)
    print("Context length:", context_length)
    print("Vocab size:", vocab_size)
    return model

def load_mixture_clip(args, decay=1.0):
    model1 = clip.load(args.arch)
    if args.clip_checkpoint:
        model1.load_state_dict({k[len('model.'):]:v for k, v in torch.load(args.clip_checkpoint, map_location='cpu')['model_ema'].items()}, strict=False)
    model1.to(args.device).eval()
    model2 = clip.load(args.arch)
    model2.to(args.device).eval()
    with torch.no_grad():
        msd = model1.state_dict()
        for k, ema_v in model2.state_dict().items():
            # if needs_module:
            #     k = 'module.' + k
            model_v = msd[k].detach()
            ema_v.copy_(ema_v * decay + (1. - decay) * model_v)
    return model2

def topk_acc(all_pred_voc_topk, all_gt_voc):
    acc = []
    ### topK accuracy
    for i in range(all_pred_voc_topk.size(1)):
        vec = torch.zeros(all_pred_voc_topk.size(0)).bool()
        for j in range(i+1):
            vec |= (all_pred_voc_topk[:, j]==all_gt_voc)
        print(f'k={i} acc={vec.float().mean()}')
        acc.append(vec.float().mean().item())
    return acc

def semantic_acc(y_pred, y_true, metrics={}):
    """ compute soft semantic acc for @y_pred and @y_true """
    assert len(metrics)>0
    assert y_pred.size(0)==y_true.size(0)
    scores = {m:[] for m in metrics.keys()}
    with tqdm(total=y_pred.size(0)) as pbar:
        for i in range(y_pred.size(0)):
            syn_pred = mapping_vocidx_to_synsets(y_pred[i].item(), vocab)
            syn_true = mapping_vocidx_to_synsets(y_true[i].item(), vocab)
            pairs = list(itertools.product(range(len(syn_pred)), range(len(syn_true))))
            for m_name, m in metrics.items():
                scores[m_name].append( max([ m(syn_pred[p[0]], syn_true[p[1]]) for p in pairs ]) )
            pbar.update(1)
    for m_name in metrics.keys():
        scores[m_name] = np.array(scores[m_name]).mean()
    return scores
    
""" from MUST """
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
_tokenizer = _Tokenizer()

def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result

In [4]:
""" prepare dataset and load CLIP """
vocab = get_vocab_with_classnames(args.vocabname)

transform_val = build_transform(is_train=False, args=args, train_config=None)
print('get dataset', args.dataset)
dataset = get_datasets_oszsl(args, vocab, is_train=True, transform=transform_val, seed=0)
loader_val = torch.utils.data.DataLoader(dataset, num_workers=8, batch_size=args.batch_size, shuffle=False)
print('dataset size', len(dataset))

# model, preprocess = load_clip(args)
model = load_clip2(args)

mapping_vocidx_to_synsets = lambda anchor, vocab: list(filter(lambda x: (x.name().split('.')[1]=='n') and (x.name().split('.')[0]==vocab.mapping_idx_names[anchor]), wn.synsets( vocab.mapping_idx_names[anchor] )))
# mapping_vocidx_to_synsets = lambda anchor, vocab: list(filter(lambda x: (x.name().split('.')[1]=='n') and (x.name().split('.')[0] in vocab.mapping_names_idx.keys()), wn.synsets( vocab.mapping_idx_names[anchor] )))

get_vocab concat2
get dataset make_entity30
dataset size 307835
missing keys:
['visual.projection_head.0.weight', 'visual.projection_head.0.bias', 'visual.projection_head.2.weight', 'visual.projection_head.2.bias']
Model parameters: 150,408,193
Input resolution: 224
Context length: 77
Vocab size: 49408


#### build classifier

In [6]:
""" from MUST """
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
_tokenizer = _Tokenizer()

def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)
    return result


def build_classifier(args, model, templates, vocab_classnames, parent_classnames=None):
    batch_size = 64
    with torch.no_grad():
        zeroshot_weights = []
        assert parent_classnames is None
        with tqdm(total=len(vocab_classnames)//batch_size) as pbar:
            for classname_set in np.array_split(vocab_classnames, len(vocab_classnames)//batch_size):
                texts = [template.format(classname) for classname in classname_set for template in templates] #format with class
                texts = tokenize(texts).to(args.device) #tokenize
                class_embeddings = model.encode_text(texts).float() #embed with text encoder
                class_embeddings = class_embeddings.view(-1, len(templates), class_embeddings.size(-1))
                class_embeddings = F.normalize(class_embeddings, dim=-1)
                class_embedding = class_embeddings.mean(dim=1)
                class_embedding /= class_embedding.norm(dim=-1, keepdim=True)
                zeroshot_weights.append(class_embedding.cpu())
                pbar.update(1)
        # else:
        #     with tqdm(total=len(vocab_classnames)//batch_size) as pbar:
        #         for classname_set, parentname_set in zip(
        #             np.array_split(vocab_classnames, len(vocab_classnames)//batch_size),
        #             np.array_split(parent_classnames, len(parent_classnames)//batch_size),
        #         ):
        #             texts = [template.format(classname)+f' A type of {pname}.' for classname, pname in zip(classname_set, parentname_set) for template in templates] #format with class
        #             texts = tokenize(texts).to(args.device) #tokenize
        #             class_embeddings = model.encode_text(texts).float() #embed with text encoder
        #             class_embeddings = class_embeddings.view(-1, len(templates), class_embeddings.size(-1))
        #             class_embeddings = F.normalize(class_embeddings, dim=-1)
        #             class_embedding = class_embeddings.mean(dim=1)
        #             class_embedding /= class_embedding.norm(dim=-1, keepdim=True)
        #             zeroshot_weights.append(class_embedding.cpu())
        #             pbar.update(1)
    classifier = torch.cat(zeroshot_weights, dim=0)
    return classifier

In [7]:
classifier = build_classifier(args, model, templates, vocab.classnames)
torch.save(classifier, f'./cache/{args.exp}-{args.vocabname}.pth')

100%|██████████| 316/316 [01:01<00:00,  5.12it/s]


### performance test

#### naive inference

In [None]:
classifier = get_classifier(args)
amp_autocast = torch.cuda.amp.autocast
classifier = classifier/classifier.norm(dim=-1, keepdim=True)

all_pred_voc = []
all_gt_voc = []
all_pred_voc_topk = []
all_vfeatures = []
with tqdm(total=len(loader_val)) as pbar:
    model.eval()
    for idx_batch, batch in enumerate(loader_val):
        images, label_voc, label_clu, idx_img = batch
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                logits = model.visual.extract_features(images)
                logits = logits/logits.norm(dim=-1, keepdim=True)
                similarity = model.logit_scale.exp() * logits @ classifier.t()
                prob = similarity.softmax(-1)
                all_pred_voc.append(prob.argmax(dim=-1).cpu())
                all_gt_voc.append(label_voc)
                all_pred_voc_topk.append(prob.topk(k=5, dim=-1).indices.cpu())
                all_vfeatures.append(logits.cpu().numpy())
        pbar.update(1)

all_pred_voc = torch.cat(all_pred_voc, dim=0)
all_gt_voc = torch.cat(all_gt_voc, dim=0)
all_pred_voc_topk = torch.cat(all_pred_voc_topk, dim=0)
all_vfeatures = np.concatenate(all_vfeatures)

In [10]:
print(f'acc={(all_pred_voc == all_gt_voc).float().mean()}')
n_missing = len(set(all_gt_voc.unique().numpy()) - set(all_pred_voc.unique().numpy()))
print(f'n_missing={n_missing}')

acc=0.33346137404441833
n_missing=2


#### SCD

In [5]:
from sklearn.cluster import KMeans
from my_util_package_oszsl.evaluation import cluster_acc
from scipy.optimize import linear_sum_assignment as linear_assignment

subset = ['train', 'val'][0]
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)

""" load dataset """
transform_f = transforms.Compose([
    transforms.Resize(args.input_size, interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(args.input_size),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=torch.tensor(mean),
        std=torch.tensor(std))
])

# dataset_f = get_datasets_oszsl(args, vocab, is_train=False, transform=transform_f)
if subset == 'train':
    dataset_f = get_datasets_oszsl(args, vocab, is_train=True, transform=transform_f, seed=0)
elif subset == 'val':
    dataset_f = get_datasets_oszsl(args, vocab, is_train=False, transform=transform_f, seed=0)
args.nb_classes = dataset_f.num_classes
loader_f = torch.utils.data.DataLoader(dataset_f, num_workers=4, batch_size=args.batch_size, shuffle=False)

In [6]:
def agg_by_pred_cluster(args, pred_kmeans, all_topk_voc, voc_size):
    """
    Args:
        pred_kmeans: np.array([N])
        all_topk_voc: np.array([N x K])
        voc_size: int
    Returns:
        all_clu_pred: tensor([C x V])
    """
    print('agg_by_pred_cluster')
    all_clu_pred = []
    n_count = []
    for i in np.unique(pred_kmeans):
        selected = (pred_kmeans==i)
        n_count.append( selected.sum().item() )
        counter_voc_ind, counter_val = np.unique((all_topk_voc[selected, :]).ravel(), return_counts=True)
        # counter_val = counter_val/(n_count+1e-20) # L1 norm
        clu_pred = torch.zeros(args.num_voc) # cluster-wise prob
        clu_pred[torch.from_numpy(counter_voc_ind).long()] = torch.from_numpy(counter_val).float()
        # clu_pred = F.normalize(all_topk_voc[selected].sum(dim=0), dim=-1, p=1)
        all_clu_pred.append(clu_pred)
    all_clu_pred = torch.stack(all_clu_pred, dim=0).cpu()
    n_count = torch.tensor(n_count).cpu()
    
    # all_clu_pred = setdiff_assignment(all_clu_pred)
    
    all_clu_pred = all_clu_pred/(n_count.view(-1, 1) + 1e-20)
    
    print('is mutex assignment::', all_clu_pred.argmax(dim=-1).size(0)==all_clu_pred.argmax(dim=-1).unique().size(0))
    print('assignment collision num::', len(list(filter(lambda x: x>1, Counter(all_clu_pred.argmax(dim=-1).numpy()).values()))))
    return all_clu_pred

def linear_assign(all_clu_pred, pred_kmeans, all_gt_voc, return_results=False):
    print('linear_assign')
    cost_mat = all_clu_pred.cpu().numpy()
    print(f'assignment shape={cost_mat.shape}')
    res_ass = linear_assignment(cost_mat.max() - cost_mat)
    label_voc_kmeans = torch.tensor([res_ass[1][x.item()] for x in pred_kmeans])
    inst_acc = (label_voc_kmeans==all_gt_voc).float().mean().item()
    print('instance label acc::', inst_acc)
    if return_results:
        return label_voc_kmeans, res_ass, inst_acc
    return label_voc_kmeans, res_ass

def reassign_by_pred_cluster(label_voc_kmeans, model, classifier, device, 
                             all_prob=None, 
                             instance_selected=None, 
                             classifier_selected=None):
    """
    Args:
        classifier_selected: tensor([C2])
    """
    print('reassign_by_pred_cluster')
    amp_autocast = torch.cuda.amp.autocast
    label_voc_kmeans = label_voc_kmeans.to(device)
    if all_prob is None:
        cluster_ind = []
        with tqdm(total=len(loader_f)) as pbar:
            if hasattr(model, 'eval'):
                model.eval()
            for idx_batch, batch in enumerate(loader_f):
                images, label_voc, label_clu, idx_img = batch[:4]
                images = images.to(device)
                if (instance_selected is not None) and ((~instance_selected[idx_img]).all()):
                    continue
                with amp_autocast():
                    with torch.no_grad():
                        if (instance_selected is not None):
                            logits = model.visual(images[instance_selected[idx_img]])
                        else:
                            logits = model.visual(images)
                            
                        logits = logits/logits.norm(dim=-1, keepdim=True)
                        if classifier_selected is not None:
                            similarity = 100 * logits @ classifier[classifier_selected].t()
                            prob = classifier_selected[similarity.softmax(-1)]
                            cluster_ind.append(prob.cpu().argmax(dim=-1))
                        else:
                            similarity = 100 * logits @ classifier.t()
                            prob = similarity.softmax(-1)
                            cluster_ind.append(prob[:, label_voc_kmeans].cpu().argmax(dim=-1))
                pbar.update(1)
        cluster_ind = torch.cat(cluster_ind, dim=0)
    else:
        all_prob = all_prob[:, label_voc_kmeans]
        cluster_ind = all_prob.argmax(dim=-1)
        
    if classifier_selected is not None:
        cluster_ind_voc = classifier_selected[cluster_ind]
    else:
        cluster_ind_voc = label_voc_kmeans[cluster_ind]
    mapping_ind = dict(zip(cluster_ind.unique().numpy(), torch.arange(cluster_ind.unique().size(0)).numpy()))
    cluster_ind = torch.tensor([mapping_ind[x.item()] for x in cluster_ind])
    return cluster_ind, cluster_ind_voc


def reassign_by_pred_cluster(label_voc_kmeans, loader_f, model, classifier, device, 
                             preextracted_vfeatures=None):
    """ given vocab label set @label_voc_kmeans, 
    Args:
        label_voc_kmeans: cluster-assigned label on vocab
        ...
        preextracted_vfeatures: np.array([N x D])
    Returns:
        cluster_ind: tensor([N]): re-ordered cluster assignment
        cluster_ind_voc: tensor([N]): cluster assignment indiced by vocab
    """
    print('reassign_by_pred_cluster')
    amp_autocast = torch.cuda.amp.autocast
    label_voc_kmeans = label_voc_kmeans.to(device).unique()
    cluster_ind = []
    with tqdm(total=len(loader_f)) as pbar:
        if hasattr(model, 'eval'):
            model.eval()
        if preextracted_vfeatures is not None:
            N = len(loader_f.dataset)
            batch_size = min(10000, N)
            indices = np.array_split(np.arange(N), N//batch_size)
            with torch.no_grad():
                for group in indices:
                    logits = torch.from_numpy(preextracted_vfeatures[group]).float()
                    logits = logits/logits.norm(dim=-1, keepdim=True)
                    similarity = 100 * logits@classifier.t().cpu()
                    prob = similarity.softmax(-1)
                    cluster_ind.append(prob[:, label_voc_kmeans.cpu()].argmax(dim=-1))
        else:
            for idx_batch, batch in enumerate(loader_f):
                images, label_voc, label_clu, idx_img = batch[:4]
                images = images.to(device)
                with amp_autocast():
                    with torch.no_grad():
                        if preextracted_vfeatures is not None:
                            logits = torch.from_numpy(preextracted_vfeatures[idx_img.cpu().numpy()]).float().to(device)
                        else:
                            logits = model.ema.extract_vfeatures(images)
                        logits = logits/logits.norm(dim=-1, keepdim=True)
                        similarity = 100 * logits @ classifier.t()
                        prob = similarity.softmax(-1)
                        cluster_ind.append(prob[:, label_voc_kmeans].cpu().argmax(dim=-1))
                pbar.update(1)
    cluster_ind = torch.cat(cluster_ind, dim=0)
    cluster_ind_voc = label_voc_kmeans[cluster_ind]
    mapping_ind = dict(zip(cluster_ind.unique().numpy(), torch.arange(cluster_ind.unique().size(0)).numpy()))
    cluster_ind = torch.tensor([mapping_ind[x.item()] for x in cluster_ind])
    return cluster_ind, cluster_ind_voc


@torch.no_grad()
def computation_reassign_by_pred_cluster(row, idx, args, model, classifier, candidate_classifier_ind):
    """
    candidate_classifier_ind = label_voc_kmeans.unique().to(args.device)
    """
    images, label_voc, label_clu, idx_img = row[:4]
    images = images.to(args.device)
    with amp_autocast():
        vfeatures = model.visual(images).float()
        # vfeatures = vfeatures/vfeatures.norm(dim=-1, keepdim=True)
    vfeatures = F.normalize(vfeatures, dim=-1)
    batch_sim = 100*vfeatures@classifier[candidate_classifier_ind].t()
    cluster_ind = batch_sim.argmax(dim=-1)
    cluster_ind_voc = candidate_classifier_ind[cluster_ind].cpu()
    return cluster_ind_voc

def aggregation_reassign_by_pred_cluster(r, candidate_classifier_ind):
    cluster_ind_voc = torch.cat(r, dim=0)
    mapping_ind = dict(zip(cluster_ind_voc.unique().numpy(), torch.arange(cluster_ind_voc.unique().size(0)).numpy()))
    cluster_ind = torch.tensor([mapping_ind[x.item()] for x in cluster_ind_voc])
    return cluster_ind, cluster_ind_voc


@torch.no_grad()
def extract_vfeatures(model, data_loader, device):
    amp_autocast = torch.cuda.amp.autocast
    all_vfeatures = []
    with tqdm(total=len(data_loader)) as pbar:
        if hasattr(model, 'eval'):
            model.eval()
        for idx_batch, batch in enumerate(data_loader):
            images, label_voc, label_clu, idx_img = batch[:4]
            images = images.to(device)
            with amp_autocast():
                vfeatures = model.visual(images).float()
            vfeatures = vfeatures/vfeatures.norm(dim=-1, keepdim=True)
            all_vfeatures.append(vfeatures.cpu().numpy())
            pbar.update(1)
    all_vfeatures = np.concatenate(all_vfeatures)
    return all_vfeatures


@torch.no_grad()
def loop_row_collect_results_nograd(obj_iter, computations={}, aggregations={}):
    """ compute and aggregate results, looping over @obj_iter 
    func_computation(@row, @index_row)
    aggregations(list(@results_computation))
    """
    assert set(list(computations.keys())) == set(list(aggregations.keys()))
    collector = { k:[] for k in computations }
    with tqdm(total=len(obj_iter)) as pbar:
        for i, row in enumerate(obj_iter):
            ### apply computations
            for k, func in computations.items():
                collector[k].append(func(row, i))
            pbar.update(1)
    ### aggregate results
    results = {}
    for k, func_agg in aggregations.items():
        results[k] = func_agg(collector[k])
    return results

In [7]:
loader_f = torch.utils.data.DataLoader(dataset_f, num_workers=4, batch_size=args.batch_size, shuffle=False)
classifier = get_classifier(args)
classifier = classifier/classifier.norm(dim=-1, keepdim=True)
args.num_voc = classifier.size(0)
amp_autocast = torch.cuda.amp.autocast
### collect variables
prob_k = 1
all_topk_voc = []
all_gt_voc = []
all_label_clu = []
all_vfeatures = []
with tqdm(total=len(loader_f)) as pbar:
    if hasattr(model, 'eval'):
        model.eval()
    for idx_batch, batch in enumerate(loader_f):
        images, label_voc, label_clu, idx_img = batch[:4]
        images = images.to(args.device)
        with amp_autocast():
            with torch.no_grad():
                logits = model.visual.extract_features(images)
                # logits = model.extract_vfeatures(images)
                logits = logits/logits.norm(dim=-1, keepdim=True)
                similarity = 100 * logits @ classifier.t()
                prob = similarity.softmax(-1)
                prob_topk_ind = prob.topk(k=prob_k, dim=-1).indices
                all_topk_voc.append(prob_topk_ind.cpu().numpy())
                all_gt_voc.append(label_voc)
                all_label_clu.append(label_clu)
                all_vfeatures.append(logits.cpu().numpy())
        pbar.update(1)

all_topk_voc = np.concatenate(all_topk_voc)
all_gt_voc = torch.cat(all_gt_voc, dim=0)
all_label_clu = torch.cat(all_label_clu, dim=0)
all_vfeatures = np.concatenate(all_vfeatures)

100%|██████████| 602/602 [09:04<00:00,  1.11it/s]


In [8]:
# pred_kmeans = torch.from_numpy(np.load(f'./pred_clu-{args.dataset}-train-clip.npy'))
pred_kmeans = torch.from_numpy(np.load(f'./cache/cluster/kmeans-{args.dataset}.npy'))
# pred_kmeans = torch.from_numpy(np.load('/home/sheng/MUST-output/make_nonliving26/baseline-04_22_1/pred_kmeans_t.npy'))
# pred_kmeans = torch.from_numpy(np.load('/home/sheng/MUST-output/make_nonliving26/chatgpt_init-warmup=2/pred_kmeans_t.npy'))
pred_kmeans_t = pred_kmeans
history_set_pred = []
for t in range(3):
    record_pred_kmeans_t = pred_kmeans_t
    all_clu_pred = agg_by_pred_cluster(args, pred_kmeans_t.numpy(), all_topk_voc, voc_size=args.num_voc)
    label_voc_kmeans, res_ass = linear_assign(all_clu_pred, pred_kmeans_t, all_gt_voc)
    pred_kmeans_t, cluster_ind_voc = reassign_by_pred_cluster(label_voc_kmeans, loader_f, model, classifier, args.device, preextracted_vfeatures=all_vfeatures)
    set_pred = set(res_ass[1].tolist())
    set_gt = set(all_gt_voc.unique().numpy().tolist())
    print('missing label::', len(set_gt - set_pred))
    print('cluster acc', cluster_acc(y_true=all_label_clu.numpy(), y_pred=pred_kmeans_t.numpy()))
    history_set_pred.append(set_pred)

agg_by_pred_cluster
is mutex assignment:: False
assignment collision num:: 19
linear_assign
assignment shape=(240, 20079)
instance label acc:: 0.36364611983299255
reassign_by_pred_cluster


  0%|          | 0/602 [00:10<?, ?it/s]


missing label:: 106
cluster acc 0.7095944255851349
agg_by_pred_cluster
is mutex assignment:: True
assignment collision num:: 0
linear_assign
assignment shape=(240, 20079)
instance label acc:: 0.46046096086502075
reassign_by_pred_cluster


  0%|          | 0/602 [00:10<?, ?it/s]


missing label:: 103
cluster acc 0.7106501859762535
agg_by_pred_cluster
is mutex assignment:: True
assignment collision num:: 0
linear_assign
assignment shape=(240, 20079)
instance label acc:: 0.4646807610988617
reassign_by_pred_cluster


  0%|          | 0/602 [00:10<?, ?it/s]


missing label:: 103
cluster acc 0.7106501859762535


In [9]:
torch.save({
    'all_clu_pred': all_clu_pred,
    'label_voc_kmeans': label_voc_kmeans,
    'pred_kmeans_t': pred_kmeans_t,
    'record_pred_kmeans_t': record_pred_kmeans_t,
    'all_gt_voc': all_gt_voc,
    'all_label_clu': all_label_clu,
    'all_topk_voc': all_topk_voc,
    'cluster_ind_voc': cluster_ind_voc,
    'all_vfeatures': torch.from_numpy(all_vfeatures),
}, f'./cache/scd/{args.exp}-{args.vocabname}-{args.dataset}-scd.pth')

In [12]:
np.save(f'/home/sheng/sssa/ipynb/cache/cluster/topk=1-cache-inov-{args.dataset}-clip-scd.pth', pred_kmeans_t.cpu().numpy())

### Multi Agent Game

In [9]:
import openai
def openai_chatgpt_post(content, parameters={'temperature': 0.7}, verbose=False):
    openai.api_key = "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"
    completion = openai.ChatCompletion.create(
      model="gpt-3.5-turbo",
      messages=[
        {"role": "user", "content": content},
      ],
    **parameters,
    )
    if verbose:
        print(completion)
    result = completion['choices'][0]['message']['content']
    return result

def openai_chatgpt_post_multirounds(content, parameters={'temperature': 0.7}):
    openai.api_key = "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"
    completion = openai.ChatCompletion.create(
      model="gpt-3.5-turbo",
      messages=content,
    **parameters,
    )
    return completion

def save_results(res, fpath='test.pkl'):
    with open(f'./cache/openai/MAG/{fpath}', 'wb') as f:
        pickle.dump(res, f)
    return 

def load_results(fpath='test.pkl'):
    with open(f'./cache/openai/MAG/{fpath}', 'rb') as f:
        res = pickle.load(f)
    return res

In [10]:
k_1 = 3

def generate_concepts(record_pred_kmeans_t, all_gt_voc, k_1=3):
    all_clu_gt_voc = []
    for c in record_pred_kmeans_t.unique():
        select = (record_pred_kmeans_t==c)
        all_clu_gt_voc.append(all_gt_voc[select].mode().values)

    all_clu_gt_voc = torch.tensor(all_clu_gt_voc)
    topk_all_clu_pred = all_clu_pred.topk(k=k_1).indices
    cluster_is_correct = torch.zeros(topk_all_clu_pred.size(0)).bool()
    for i in range(k_1):
        cluster_is_correct |= (topk_all_clu_pred[:, i]==all_clu_gt_voc)

    print(f'recall@{k_1} = {cluster_is_correct.float().mean()}')

    """ gather concepts """
    to_name = lambda x: [ s.name() + ': ' + s.definition() for s in x ]
    cluster_row_synsets = []
    for row in topk_all_clu_pred:
        row_synsets = [to_name(mapping_vocidx_to_synsets(voc_idx.item(), vocab)) for voc_idx in row]
        cluster_row_synsets.append(row_synsets)
    return cluster_row_synsets, topk_all_clu_pred

cluster_row_synsets, topk_all_clu_pred = generate_concepts(record_pred_kmeans_t, all_gt_voc, k_1=k_1)

recall@3 = 0.7041666507720947


In [11]:
""" generate concept requests """
def format_concept_request_with_def(cluster_row_synsets):
    concept_request = []
    for row in cluster_row_synsets:
        ccpts = reduce(lambda x, y: x+y, row)
        names = list(map(lambda x: "'"+x.split(':')[0]+"'", ccpts))
        ccpts = list(map(lambda x: "'"+x+".'", ccpts))
        ccpts = ', '.join(ccpts)
        concept_request.append((', '.join(names), ccpts))
    return concept_request

def format_concept_request(cluster_row_synsets):
    concept_request = []
    for row in cluster_row_synsets:
        row_names = []
        row_names = list(map(lambda x: x[0].split('.')[0], row))
        concept_request.append((', '.join(row_names), None))
    return concept_request


def clean_round_1(all_chatgpt_res):
    invalid_inds = []
    clean_all_chatgpt_res = [[] for _ in range(len(all_chatgpt_res))]
    for i in range(len(all_chatgpt_res)):
        for j, row in enumerate(all_chatgpt_res[i]):
            lines = row.split('\n')
            if len(lines)<10:
                invalid_inds.append((i,j))
            for l in lines[:10]:
                re_match_res = re.match('[0-9]{1,2}\..*', l)
                if re_match_res is None:
                    invalid_inds.append((i,j))
            clean_all_chatgpt_res[i].append(lines[:10])
    invalid_inds = list(set(invalid_inds))
    return clean_all_chatgpt_res, invalid_inds


In [12]:
# concept_request = format_concept_request_with_def(cluster_row_synsets)
concept_request = format_concept_request(cluster_row_synsets)
n_repeat = 3

#### round 1

In [13]:
template_round_1_with_def = lambda concepts, concepts_with_def: "Let's play a game. You are given three category names (" + concepts_with_def + "). GOAL: to visually discriminate " + concepts + ". Please ask ten questions to distinguish which category is presented in an imaginary image. Rule: you can only ask about their visual appearance, visual features, or visual characteristics. Please ask all questions at once and list each in a row sequentially."
template_round_1 = lambda concepts, concepts_with_def: "Let's play a game. You are given three category names (" + concepts + "). GOAL: to visually discriminate " + concepts + ". Please ask ten questions to distinguish which category is presented in an imaginary image. Rule: you can only ask about their visual appearance, visual features, or visual characteristics. Please ask all questions at once and list each in a row sequentially."

template_in_use = template_round_1
concept_templates = []
for row in concept_request:
    concept_templates.append(template_in_use(*row))

In [14]:
""" collect chatgpt res """
all_chatgpt_res = [[] for _ in range(n_repeat)]
with tqdm(total=len(concept_templates)*n_repeat) as pbar:
    for i in range(n_repeat):
        for row in concept_templates:
            while 1:
                try:
                    all_chatgpt_res[i].append(openai_chatgpt_post(row))
                    break
                except Exception as e:
                    print(e)

            pbar.update(1)

  3%|▎         | 20/720 [02:21<1:18:12,  6.70s/it]

Bad gateway. {"error":{"code":502,"message":"Bad gateway.","param":null,"type":"cf_bad_gateway"}} 502 {'error': {'code': 502, 'message': 'Bad gateway.', 'param': None, 'type': 'cf_bad_gateway'}} {'Date': 'Wed, 14 Jun 2023 07:57:50 GMT', 'Content-Type': 'application/json', 'Content-Length': '84', 'Connection': 'keep-alive', 'X-Frame-Options': 'SAMEORIGIN', 'Referrer-Policy': 'same-origin', 'Cache-Control': 'private, max-age=0, no-store, no-cache, must-revalidate, post-check=0, pre-check=0', 'Expires': 'Thu, 01 Jan 1970 00:00:01 GMT', 'Server': 'cloudflare', 'CF-RAY': '7d70fea1bbc6d64e-CDG', 'alt-svc': 'h3=":443"; ma=86400'}


 12%|█▏        | 88/720 [15:36<1:12:09,  6.85s/it]  

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 8cb693618010271902c2a6bebf7d3be0 in your message.)


 16%|█▋        | 118/720 [19:36<1:08:21,  6.81s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 75d8efcacab8ca6667ee088101462ffd in your message.)


 27%|██▋       | 193/720 [28:59<1:00:52,  6.93s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID afdd426a1138be4cd2f8fa3d705786a3 in your message.)


 33%|███▎      | 237/720 [34:32<59:28,  7.39s/it]  

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 5498717e44b536a0468bf906bd1f078b in your message.)


 57%|█████▋    | 407/720 [54:45<33:57,  6.51s/it]  

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 47a6b2995cd182d4d430a23746f13277 in your message.)


 71%|███████   | 511/720 [1:07:20<23:49,  6.84s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID e2ade8a28af5e242798c03d8981e991d in your message.)


 76%|███████▌  | 545/720 [1:11:49<22:53,  7.85s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 74dc6c1ba85b467b7670002c3d220dec in your message.)


100%|██████████| 720/720 [1:32:52<00:00,  7.74s/it]


In [15]:
save_results(all_chatgpt_res, fpath=f'{args.dataset}-round=1-no_def.pkl')

In [23]:
# all_chatgpt_res = load_results(f'{args.dataset}-round=1-no_def.pkl')

In [16]:
### repair_r1
while 1:
    all_chatgpt_res_clean, invalid_inds = clean_round_1(all_chatgpt_res)
    if len(invalid_inds)==0:
        break
    else:
        for item in invalid_inds:
            while 1:
                try:
                    all_chatgpt_res[item[0]][item[1]] = openai_chatgpt_post(concept_templates[item[1]])
                    break
                except Exception as e:
                    print(e)
                    
all_chatgpt_res = all_chatgpt_res_clean

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID c45cb20ea6f9b7884dfba052a37504dc in your message.)
That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 85ddff2e09a9c6a6516c3b8cc0072afb in your message.)


In [17]:
save_results(all_chatgpt_res, fpath=f'{args.dataset}-round=1-no_def.pkl')

#### round 2

In [18]:
template_round_2_with_def = lambda concepts, concepts_with_def: "Let's play a game. GOAL: to visually discriminate " + concepts + ". You are given three category names with definitions (" + concepts_with_def + "). I will give you a number of questions. Please answer these questions concisely and accurately for each category. Imagine you are given an imagenery image. For each category name, please answer all questions at once and list each in a row sequentially. I will give you the questions now."
template_round_2 = lambda concepts, concepts_with_def: "Let's play a game. GOAL: to visually discriminate " + concepts + ". You are given three category names (" + concepts + "). I will give you a number of questions. Please answer these questions concisely and accurately for each category. Imagine you are given an imagenery image. For each category name, please answer all questions at once and list each in a row sequentially. I will give you the questions now."

template_in_use_r2 = template_round_2
concept_templates_r2 = [[] for _ in range(n_repeat)]
all_chatgpt_res_r2 = [[] for _ in range(n_repeat)]
with tqdm(total=n_repeat*len(concept_request)) as pbar:
    for i in range(n_repeat):
        for j, row in enumerate(concept_request):
            ### prepare template
            content = \
                [
                    {'role': 'user', 'content': template_in_use_r2(*row)},
                    {'role': 'system', 'content': "Sure, I'm ready to play the game. Please go ahead and provide me with the questions"},
                    {'role': 'user', 'content': '\n'.join(all_chatgpt_res[i][j]) + 'Please mention the category name before your listed answers.'}
                ]
            concept_templates_r2[i].append(content)
            ### make request
            while 1:
                try:
                    ### collect result
                    all_chatgpt_res_r2[i].append(openai_chatgpt_post_multirounds(content)["choices"][0].message.content)
                    break
                except Exception as e:
                    print(e)
            pbar.update(1)
            


 40%|███▉      | 285/720 [1:20:09<1:54:57, 15.86s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 041e32ee368f27742264c2ad41dc3567 in your message.)


 45%|████▌     | 326/720 [1:32:50<1:48:20, 16.50s/it]

Bad gateway. {"error":{"code":502,"message":"Bad gateway.","param":null,"type":"cf_bad_gateway"}} 502 {'error': {'code': 502, 'message': 'Bad gateway.', 'param': None, 'type': 'cf_bad_gateway'}} {'Date': 'Wed, 14 Jun 2023 11:06:38 GMT', 'Content-Type': 'application/json', 'Content-Length': '84', 'Connection': 'keep-alive', 'X-Frame-Options': 'SAMEORIGIN', 'Referrer-Policy': 'same-origin', 'Cache-Control': 'private, max-age=0, no-store, no-cache, must-revalidate, post-check=0, pre-check=0', 'Expires': 'Thu, 01 Jan 1970 00:00:01 GMT', 'Server': 'cloudflare', 'CF-RAY': '7d72132788333cfe-CDG', 'alt-svc': 'h3=":443"; ma=86400'}


 62%|██████▏   | 447/720 [2:13:26<1:11:41, 15.76s/it]  

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID a15c1129c58ba77c70f92e82a1255b71 in your message.)


 64%|██████▍   | 464/720 [2:18:19<1:01:35, 14.43s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID d8e74cee6e5eff41acf117d639bdad19 in your message.)


 69%|██████▉   | 498/720 [2:28:57<1:09:07, 18.68s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 5e078fed8db2871ff60f8ca41fd38249 in your message.)


 94%|█████████▍| 676/720 [3:19:25<12:03, 16.44s/it]  

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 5d1072283686c2e3f11e8faac7b4a1ea in your message.)


100%|██████████| 720/720 [3:31:24<00:00, 17.62s/it]


In [19]:
save_results(all_chatgpt_res_r2, fpath=f'{args.dataset}-round=2-no_def.pkl')

In [20]:
all_chatgpt_res = np.array(all_chatgpt_res).reshape(3, -1, 10).tolist()

for i in range(n_repeat):
    for j, row in enumerate(concept_request):
        all_chatgpt_res[i][j] = '\n'.join(all_chatgpt_res[i][j])

In [21]:
""" check: missing concepts """
while 1:
    ### check validity
    invalid_inds = []
    for i in range(n_repeat):
        for j, row in enumerate(concept_request):
            concepts = list(map(lambda x: x.strip('\''), concept_request[j][0].split(', ')))
            answers = all_chatgpt_res_r2[i][j].split('\n\n')[-len(concepts):]
            if len(answers) not in [len(concepts), len(concepts)+1]:
                invalid_inds.append((i,j))
            # for k in range(len(concepts)):
                # concepts[k] == answers[k][:len(concepts[k])]
    ### request
    if len(invalid_inds)==0:
        break
    with tqdm(total=len(invalid_inds)) as pbar:
        for ind in invalid_inds:
            row = concept_request[ind[1]]
            ### prepare template
            content = \
                [
                    {'role': 'user', 'content': template_in_use_r2(*row)},
                    {'role': 'system', 'content': "Sure, I'm ready to play the game. Please go ahead and provide me with the questions"},
                    {'role': 'user', 'content': all_chatgpt_res[ind[0]][ind[1]] + 'Please mention the category name before your listed answers.'}
                ]
            concept_templates_r2[ind[0]][ind[1]] = content
            ### make request
            while 1:
                try:
                    ### collect result
                    all_chatgpt_res_r2[ind[0]][ind[1]] = openai_chatgpt_post_multirounds(content)["choices"][0].message.content
                    break
                except Exception as e:
                    print(e)

  0%|          | 0/2 [00:23<?, ?it/s]


In [22]:
# template_round_2_with_def = lambda concepts, concepts_with_def: "Let's play a game. GOAL: to visually discriminate " + concepts + ". You are given three category names with definitions (" + concepts_with_def + "). I will give you a number of questions. Please answer all these questions concisely and accurately for each category based on your knowledge. Imagine you are given an imagenery image. For each category name, please answer all questions at once and list each in a row sequentially. I will give you the questions now."
""" check: missing answer """
i_iter = 0
while 1:
    all_qa_pairs = [[] for _ in range(n_repeat)] ### N x R x C x P
    invalid_inds = []
    for i in range(n_repeat):
        for j, row in enumerate(concept_request):
            concepts = list(map(lambda x: x.strip('\''), concept_request[j][0].split(', ')))
            answers = all_chatgpt_res_r2[i][j].split('\n\n')[-len(concepts):]
            answers = answers[-len(concepts): ]
            names = [item.strip("'") for item in row[0].split(', ')]
            names_def = [item.strip("'") for item in row[1].split(', ')] if row[1] is not None else [None]*len(row[0].split(', '))

            qa_pairs = []
            q = [' '.join(item.split(' ')[1:]) for item in all_chatgpt_res[i][j].split('\n')]
            for k in range(len(concepts)):
                extract_lines = lambda x: list(filter(lambda y: len(y), x.split('\n')))
                extract_ans = lambda x: ' '.join(x.split(' ')[1:])
                try:
                    a = [extract_ans(item) for item in extract_lines(answers[k])[1:]]
                    qa_pairs.append([names[k], names_def[k], q, a])
                    if len(q)!=len(a):
                        invalid_inds.append((i, j))
                except Exception as e:
                    print(e)
                    invalid_inds.append((i, j))
            all_qa_pairs[i].append(qa_pairs)
    invalid_inds = list(set(invalid_inds))
    if len(invalid_inds) == 0:
        break

    with tqdm(total=len(invalid_inds)) as pbar:
        for ind in invalid_inds:
            row = concept_request[ind[1]]
            ### prepare template
            content = \
                [
                    {'role': 'user', 'content': template_in_use_r2(*row)},
                    {'role': 'system', 'content': "Sure, I'm ready to play the game. Please go ahead and provide me with the questions"},
                    {'role': 'user', 'content': all_chatgpt_res[ind[0]][ind[1]] + 'Please mention the category name before your listed answers.'}
                ]
            ### update template
            concept_templates_r2[ind[0]][ind[1]] = content
            ### make request
            while 1:
                try:
                    ### update result
                    all_chatgpt_res_r2[ind[0]][ind[1]] = openai_chatgpt_post_multirounds(content)["choices"][0].message.content
                    break
                except Exception as e:
                    print(e)
            pbar.update(1)
            
    if i_iter>50:
        for ind in invalid_inds:
            key1 = random.choice(range(n_repeat))
            all_chatgpt_res_r2[ind[0]][ind[1]] = all_chatgpt_res_r2[key1][ind[1]]
            
    i_iter += 1

100%|██████████| 23/23 [07:23<00:00, 19.30s/it]
100%|██████████| 3/3 [00:52<00:00, 17.54s/it]
100%|██████████| 1/1 [00:15<00:00, 15.38s/it]


In [23]:
save_results(all_chatgpt_res_r2, fpath=f'{args.dataset}-round=2-no_def.pkl')

#### round 3

In [24]:
template_round_3_with_def = lambda concepts, concepts_with_def, qa, query: "GOAL: to visually discriminate " + concepts + ". Their definitions are given as (" + concepts_with_def + "). Please generate a consise descriptive image caption for " + query + " only based on the information of this Q&A " + qa + ". Please answer in template \"caption: {caption}\"."
template_round_3 = lambda concepts, concepts_with_def, qa, query: "GOAL: to visually discriminate " + concepts + ". Please generate a consise descriptive image caption for " + query + " only based on the information of this Q&A " + qa + ". Please answer in template \"caption: {caption}\"."
synthesize_qa = lambda q, a: [item_q + ' ' + item_a for item_q, item_a in zip(q, a)]

template_in_use_r3 = template_round_3
concept_templates_r3 = [[[] for _ in range(len(concept_request))] for _ in range(n_repeat)]
all_chatgpt_res_r3 = [[[] for _ in range(len(concept_request))] for _ in range(n_repeat)]
with tqdm(total=n_repeat*len(concept_request)*10*len(concepts)) as pbar:
    for i in range(n_repeat):
        for j, row in enumerate(concept_request):
            concepts = list(map(lambda x: x.strip("'"), concept_request[j][0].split(', ')))
            concept_templates_r3[i][j] = [ [] for _ in range(len(concepts)) ]
            all_chatgpt_res_r3[i][j] = [ [] for _ in range(len(concepts)) ]
            for k in range(len(concepts)):
                qas = synthesize_qa(*all_qa_pairs[i][j][k][-2:])
                for n in range(10):
                    ### prepare template
                    content = template_in_use_r3(row[0], row[1], qas[n], all_qa_pairs[i][j][k][0])
                    concept_templates_r3[i][j][k].append(content)
                    ### make request
                    while 1:
                        try:
                            ### collect result
                            all_chatgpt_res_r3[i][j][k].append(openai_chatgpt_post(content, verbose=False))
                            break
                        except Exception as e:
                            print(e)
                    pbar.update(1)
            


  0%|          | 17/21600 [00:23<8:43:59,  1.46s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 868bf3c76c38750d67b1f687489b2fb6 in your message.)


  0%|          | 26/21600 [01:03<10:07:19,  1.69s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID ac82b64f2fe113a829abca9548519c98 in your message.)


  0%|          | 32/21600 [01:41<15:44:28,  2.63s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 6c09d2799c3b981144467b13cc9056f3 in your message.)


  0%|          | 46/21600 [02:31<7:41:41,  1.29s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 09327bdd4e619ffc47c9b072df66c598 in your message.)


  0%|          | 73/21600 [03:41<8:40:31,  1.45s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 855343ed1cf5a37874b1f9c3a2fa244c in your message.)


  1%|          | 188/21600 [07:06<6:29:09,  1.09s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 238a2dafd2808dcd1c0f8a1707d64527 in your message.)


  1%|          | 221/21600 [08:24<7:29:49,  1.26s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 8f766056e412053457063daef34e2a1a in your message.)


  1%|          | 265/21600 [09:56<7:21:31,  1.24s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID c9928094a1ce6af56ba779dcad30a291 in your message.)


  3%|▎         | 584/21600 [17:23<6:06:43,  1.05s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 76a89d0e20ff1d384f7561658830311c in your message.)


  3%|▎         | 627/21600 [18:52<9:01:45,  1.55s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 3d79866305e8e8368587faca16ff67cf in your message.)


  3%|▎         | 739/21600 [21:35<7:21:06,  1.27s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID f4c17d89a480ac3df4374d1049081f9d in your message.)


  4%|▎         | 802/21600 [23:28<7:49:49,  1.36s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 7e13a40d914c0fefbd78ef2aa0ad4faa in your message.)


  4%|▍         | 822/21600 [24:28<7:51:25,  1.36s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 1495fef2fb0b455820583655ea47543e in your message.)


  4%|▍         | 870/21600 [26:02<7:06:00,  1.23s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID da7c77ee982da875b7492838cfa53458 in your message.)


  5%|▍         | 1026/21600 [29:50<6:05:18,  1.07s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 236dc18e1eec18ae639e4ba6efb8214a in your message.)


  5%|▌         | 1106/21600 [32:06<7:35:14,  1.33s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 5f1cc42034b476692f41fee6cba43cb9 in your message.)


  6%|▌         | 1209/21600 [34:47<7:04:59,  1.25s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 005365435e0f7260564247079b79a5b5 in your message.)


  6%|▌         | 1313/21600 [37:13<6:10:25,  1.10s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 3f1ddb8aca89bb30f98a6b699e60860b in your message.)


  7%|▋         | 1527/21600 [42:34<9:40:36,  1.74s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID ff3ff9d606c64bf262677ef45e34d54e in your message.)


  7%|▋         | 1544/21600 [43:24<6:39:32,  1.20s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 243b4162b818e79105b980bc8e8b039c in your message.)


  8%|▊         | 1703/21600 [47:03<6:45:25,  1.22s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID b6ef1e97a77f318b2413988a34575b4f in your message.)


  9%|▉         | 2019/21600 [54:47<5:42:06,  1.05s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID ca2971e94f8cafa573fb877d991aba93 in your message.)


  9%|▉         | 2032/21600 [55:32<6:38:36,  1.22s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID f20e48a1e774de993147b68c79f0f187 in your message.)


 10%|▉         | 2095/21600 [57:23<6:38:01,  1.22s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID fc0afdc30a8c0d7bcee71514700ba1d6 in your message.)


 10%|▉         | 2130/21600 [58:38<8:27:14,  1.56s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 300889de2cf7f775095388348272eb77 in your message.)
That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 38f6b797f7169e9c9b6f32e1c53ffc97 in your message.)


 11%|█▏        | 2482/21600 [1:06:37<5:51:30,  1.10s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID ffa44b76e3f752c7ea3909924482114d in your message.)


 12%|█▏        | 2500/21600 [1:07:27<5:44:25,  1.08s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID d70f43335b8c498fe3c4c443dba60176 in your message.)


 12%|█▏        | 2577/21600 [1:09:27<5:57:30,  1.13s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID ff9ccec4dd6c1233259898f71cb5591c in your message.)


 12%|█▏        | 2614/21600 [1:10:44<6:31:49,  1.24s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 7888ebdb958161f0f635641477e25d62 in your message.)


 13%|█▎        | 2715/21600 [1:13:25<6:53:37,  1.31s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID f2892fcc3fcb90ac08626eb256001a42 in your message.)


 13%|█▎        | 2823/21600 [1:16:02<6:24:26,  1.23s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID cc3670fd718b33d4d3e1be558faa6029 in your message.)


 14%|█▍        | 3059/21600 [1:21:23<5:45:57,  1.12s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 8d1f239d09d5bd35ac866b8551dd6281 in your message.)


 14%|█▍        | 3109/21600 [1:22:53<6:11:33,  1.21s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 133d758e5c86fee33471b99a82c73400 in your message.)


 16%|█▌        | 3460/21600 [1:30:25<5:35:13,  1.11s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 926e294525cb272cb93c0f57c84144cc in your message.)


 16%|█▋        | 3548/21600 [1:32:41<5:50:38,  1.17s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID febfcf5eefb947eca4cd994e42442ad4 in your message.)


 17%|█▋        | 3655/21600 [1:35:25<5:43:14,  1.15s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 23b31e39e80ce8347038b9c502950f21 in your message.)


 17%|█▋        | 3761/21600 [1:38:09<6:26:58,  1.30s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID df8b25b66142b1e2f74b3ff68d52c8a4 in your message.)


 18%|█▊        | 3835/21600 [1:40:12<4:57:25,  1.00s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID a7c91d67bd54ad923babf771640769d8 in your message.)


 18%|█▊        | 3857/21600 [1:41:08<5:22:41,  1.09s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 6939828acebf5d2e7e48e182d2a3e73b in your message.)


 18%|█▊        | 3951/21600 [1:43:49<6:49:39,  1.39s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID d1451c3e70af2b46dd0cfce2cc1321d2 in your message.)


 19%|█▉        | 4167/21600 [1:48:43<6:24:08,  1.32s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 730c624803c2413f0a6b35fb62e27a25 in your message.)


 19%|█▉        | 4210/21600 [1:50:15<6:08:05,  1.27s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 74ef73adfe58198c278f6f57f351dc6e in your message.)


 20%|██        | 4322/21600 [1:53:10<4:58:55,  1.04s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID ac90f1fbcd651a93ce24b48b7e64b6f8 in your message.)


 20%|██        | 4328/21600 [1:53:48<12:51:11,  2.68s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 1cbd188c8ef1a4e2d3329e3a210ad8b8 in your message.)


 20%|██        | 4378/21600 [1:55:15<5:32:38,  1.16s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID d4f0bb68b2a956b99bf805fd76207737 in your message.)


 21%|██        | 4432/21600 [1:56:52<4:47:37,  1.01s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 0d77694791d66832e8d5fba4a4d37ffc in your message.)


 21%|██        | 4449/21600 [1:57:41<5:49:04,  1.22s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 66bf25f2bb515f8f2026dd91f6e510fe in your message.)


 21%|██        | 4458/21600 [1:58:21<7:58:25,  1.67s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 41632355b8d56f16ce1d97cdf8e53d25 in your message.)


 21%|██        | 4579/21600 [2:01:10<4:20:44,  1.09it/s] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID d439fe8016f8ce753707a1a3f0823aa5 in your message.)


 21%|██▏       | 4637/21600 [2:02:47<4:27:17,  1.06it/s] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 070c4547db4a40660d90fddaa39736ef in your message.)


 23%|██▎       | 4910/21600 [2:08:50<6:25:55,  1.39s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 66b0138c6226d5a0b5de216f1385bbe0 in your message.)


 23%|██▎       | 5008/21600 [2:11:19<4:49:26,  1.05s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 6a3701e2a4bd0b8f8afc680a26666ff0 in your message.)


 24%|██▎       | 5104/21600 [2:13:49<5:33:54,  1.21s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 88864bcd69bd2d772b6cfad632968f4b in your message.)


 24%|██▍       | 5177/21600 [2:15:43<4:39:59,  1.02s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 2856f63e88f43e041b0bcb0787085726 in your message.)


 26%|██▌       | 5521/21600 [2:23:19<6:03:15,  1.36s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID c256ea3a8ecd3300cc6d67dcdfa28ec7 in your message.)


 26%|██▌       | 5623/21600 [2:25:52<5:07:59,  1.16s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID cac5f94bb5e5227952e169b67a95192b in your message.)


 27%|██▋       | 5803/21600 [2:31:51<7:58:29,  1.82s/it]  

Bad gateway. {"error":{"code":502,"message":"Bad gateway.","param":null,"type":"cf_bad_gateway"}} 502 {'error': {'code': 502, 'message': 'Bad gateway.', 'param': None, 'type': 'cf_bad_gateway'}} {'Date': 'Wed, 14 Jun 2023 15:45:54 GMT', 'Content-Type': 'application/json', 'Content-Length': '84', 'Connection': 'keep-alive', 'X-Frame-Options': 'SAMEORIGIN', 'Referrer-Policy': 'same-origin', 'Cache-Control': 'private, max-age=0, no-store, no-cache, must-revalidate, post-check=0, pre-check=0', 'Expires': 'Thu, 01 Jan 1970 00:00:01 GMT', 'Server': 'cloudflare', 'CF-RAY': '7d73ac62696fd2a3-CDG', 'alt-svc': 'h3=":443"; ma=86400'}


 27%|██▋       | 5844/21600 [2:37:44<5:02:14,  1.15s/it]  

Bad gateway. {"error":{"code":502,"message":"Bad gateway.","param":null,"type":"cf_bad_gateway"}} 502 {'error': {'code': 502, 'message': 'Bad gateway.', 'param': None, 'type': 'cf_bad_gateway'}} {'Date': 'Wed, 14 Jun 2023 15:51:48 GMT', 'Content-Type': 'application/json', 'Content-Length': '84', 'Connection': 'keep-alive', 'X-Frame-Options': 'SAMEORIGIN', 'Referrer-Policy': 'same-origin', 'Cache-Control': 'private, max-age=0, no-store, no-cache, must-revalidate, post-check=0, pre-check=0', 'Expires': 'Thu, 01 Jan 1970 00:00:01 GMT', 'Server': 'cloudflare', 'CF-RAY': '7d73b504ac82d2a3-CDG', 'alt-svc': 'h3=":443"; ma=86400'}


 27%|██▋       | 5867/21600 [2:43:23<4:46:51,  1.09s/it]  

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID ed51eaab7783442254779fb20e393701 in your message.)


 27%|██▋       | 5921/21600 [2:44:52<5:09:44,  1.19s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID efad41c7c3fd8d92a914282ccfa075d4 in your message.)


 28%|██▊       | 5949/21600 [2:45:52<4:30:11,  1.04s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 9ed3170f70cdc7bd9524bfdbc6df7eea in your message.)


 28%|██▊       | 5995/21600 [2:47:17<5:24:47,  1.25s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 3a56222b4e8d61026848304c5c073dcb in your message.)


 28%|██▊       | 6004/21600 [2:47:58<7:17:58,  1.68s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 8d465496378e2950699acff8d5061192 in your message.)


 28%|██▊       | 6007/21600 [2:48:32<25:08:35,  5.80s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID f90de5135bb2cd49973f184cbd54041e in your message.)


 28%|██▊       | 6129/21600 [2:51:26<5:00:50,  1.17s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID a7b1622ba66c1ee81b504d772aee7fef in your message.)


 28%|██▊       | 6144/21600 [2:52:14<5:55:12,  1.38s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID ac3ccc6bf79ee7fffb00c948c4319e0c in your message.)


 29%|██▉       | 6257/21600 [2:54:59<4:40:28,  1.10s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 65c1c4db9cfac93aa228cdc0684e7d5b in your message.)


 30%|██▉       | 6450/21600 [2:59:07<6:07:38,  1.46s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 3147110ba9ea00a6160ca82cbd812353 in your message.)


 30%|███       | 6549/21600 [3:01:35<4:51:38,  1.16s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 5063c06fb6aa66666cea515f00cb467a in your message.)


 31%|███       | 6630/21600 [3:03:31<4:10:15,  1.00s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID f7c76dfe82dadb79f78c3e425c988298 in your message.)


 31%|███       | 6728/21600 [3:05:55<5:17:17,  1.28s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID fdbb71d09b5f4589c526bc8a47d57aca in your message.)


 31%|███▏      | 6753/21600 [3:06:58<5:36:35,  1.36s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID e9c5b4c082e7958112716915fcc6eaac in your message.)


 31%|███▏      | 6771/21600 [3:07:58<6:57:05,  1.69s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID ce92955bb7e43720fa451011116a639e in your message.)


 32%|███▏      | 6872/21600 [3:10:25<4:47:09,  1.17s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID a13861e3745ef44b91166629c833849e in your message.)


 33%|███▎      | 7030/21600 [3:13:54<4:40:45,  1.16s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID f7d4330bb501128f7aae362505a21ee3 in your message.)


 34%|███▎      | 7237/21600 [3:18:11<4:20:21,  1.09s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 387deb11c249cbbec83f153d97dd4e62 in your message.)


 34%|███▎      | 7286/21600 [3:19:42<4:34:42,  1.15s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 677a35e4cc667bff1e53a3a95c5b1409 in your message.)


 34%|███▍      | 7324/21600 [3:20:53<4:37:27,  1.17s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 23d12470c09dc5e4e1e04bb1a2e3d2ab in your message.)


 35%|███▌      | 7569/21600 [3:26:19<4:18:26,  1.11s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID b73efc8654eb6eb0d8b7518465fac954 in your message.)


 36%|███▌      | 7705/21600 [3:29:23<4:01:40,  1.04s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 2b9a1685c7a2df97f784b63cdb4f54da in your message.)


 36%|███▌      | 7719/21600 [3:30:10<4:51:20,  1.26s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 2ffe0d8b04c167c46f18fe5e7d11649f in your message.)


 36%|███▌      | 7769/21600 [3:31:52<5:44:14,  1.49s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 0ef57f69e69f71c52e509062538b32c3 in your message.)


 36%|███▋      | 7869/21600 [3:34:29<4:06:25,  1.08s/it] 

Bad gateway. {"error":{"code":502,"message":"Bad gateway.","param":null,"type":"cf_bad_gateway"}} 502 {'error': {'code': 502, 'message': 'Bad gateway.', 'param': None, 'type': 'cf_bad_gateway'}} {'Date': 'Wed, 14 Jun 2023 16:48:32 GMT', 'Content-Type': 'application/json', 'Content-Length': '84', 'Connection': 'keep-alive', 'X-Frame-Options': 'SAMEORIGIN', 'Referrer-Policy': 'same-origin', 'Cache-Control': 'private, max-age=0, no-store, no-cache, must-revalidate, post-check=0, pre-check=0', 'Expires': 'Thu, 01 Jan 1970 00:00:01 GMT', 'Server': 'cloudflare', 'CF-RAY': '7d7408230e0099c8-CDG', 'alt-svc': 'h3=":443"; ma=86400'}


 37%|███▋      | 7886/21600 [3:40:00<6:04:32,  1.59s/it]  

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID a6b510cff1c6b865250c77c256b09eaf in your message.)


 37%|███▋      | 7969/21600 [3:42:09<5:11:07,  1.37s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 168af7babd6264e4a41a941d044749b6 in your message.)


 38%|███▊      | 8105/21600 [3:45:15<4:43:51,  1.26s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 51799ddf2073327d58bf1f5caa435278 in your message.)


 38%|███▊      | 8224/21600 [3:47:53<4:21:19,  1.17s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 2f772fd5ea553a8014bf70d0cb940273 in your message.)
That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID ad10d6aa507d20fef9f57ccf202629df in your message.)


 38%|███▊      | 8247/21600 [3:49:18<4:04:12,  1.10s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 00ff3c517b60d74b67a74441a5e76d0b in your message.)


 38%|███▊      | 8250/21600 [3:49:52<20:39:06,  5.57s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID a67ea1a4a2ada0313677b087432c688d in your message.)


 38%|███▊      | 8261/21600 [3:50:36<5:44:53,  1.55s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 33cbfcd96a4c84f45640577ac9599693 in your message.)


 39%|███▉      | 8440/21600 [3:54:26<4:09:04,  1.14s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 3fbcb80ede54f784b79d8f31e9530eea in your message.)


 39%|███▉      | 8444/21600 [3:55:01<15:43:24,  4.30s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID e3107181ae268ed693f6bcbca14027dc in your message.)


 39%|███▉      | 8470/21600 [3:55:58<3:42:57,  1.02s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID b58ffa72a88907ccb2a7d2c3004e572e in your message.)


 41%|████      | 8882/21600 [4:04:15<3:22:52,  1.04it/s] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 85973e9c3aa597a34e9dd4d40aeebfac in your message.)


 41%|████      | 8895/21600 [4:04:59<4:22:27,  1.24s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID a3e8a3042bb6c8d11b0d51bfd4c9f734 in your message.)


 41%|████▏     | 8948/21600 [4:06:29<4:12:24,  1.20s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 957728b93116fd8851acfd829c92d0a7 in your message.)


 42%|████▏     | 9101/21600 [4:09:52<4:09:53,  1.20s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 1c61c16bdda0a77930337a4e0bb66721 in your message.)


 42%|████▏     | 9103/21600 [4:10:28<28:39:19,  8.25s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 3c7e12027e86dd2fb5b0b4207e8f613c in your message.)


 43%|████▎     | 9267/21600 [4:14:09<4:31:18,  1.32s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID ab3207f4d081afd16dccf7c1f87aa38e in your message.)


 44%|████▍     | 9460/21600 [4:18:09<3:26:52,  1.02s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID f671665cdcfc415c825359398e491fec in your message.)


 45%|████▍     | 9636/21600 [4:21:53<3:11:00,  1.04it/s] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID d64e2e8a3f05aeb58589bcc7fea4fe7b in your message.)


 45%|████▍     | 9689/21600 [4:23:20<3:38:31,  1.10s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 406135a4cf608c2d2c976367d6124365 in your message.)


 45%|████▌     | 9823/21600 [4:26:17<4:09:20,  1.27s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID fe315f250a64fb52c62bdf18c7a58a67 in your message.)


 48%|████▊     | 10261/21600 [4:35:28<3:40:06,  1.16s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID b8dac2e6e792a4e78e69a2a4b77220ec in your message.)


 48%|████▊     | 10359/21600 [4:37:52<3:03:35,  1.02it/s] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 4d2b9c8bba549b50e56dd2a350f9a037 in your message.)


 49%|████▊     | 10484/21600 [4:40:42<3:26:12,  1.11s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID c0a1aa33d396c7819feab4fbdd7ae8d7 in your message.)


 49%|████▉     | 10599/21600 [4:43:25<3:07:42,  1.02s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 08d0405a7f64f79b6c2764e598e049d1 in your message.)


 49%|████▉     | 10629/21600 [4:44:31<3:12:11,  1.05s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 3a6d9e99a119903fce4fd98dce89836e in your message.)


 50%|████▉     | 10706/21600 [4:46:23<2:59:05,  1.01it/s] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 7f195eb93290d32c45fceda0fa310b61 in your message.)


 50%|████▉     | 10757/21600 [4:47:51<3:23:05,  1.12s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID a511454d593191885a7c444b6f6cbbc4 in your message.)


 51%|█████     | 10916/21600 [4:51:32<3:30:56,  1.18s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 1188b3bd00ac8455e9b3fdfc76020364 in your message.)


 52%|█████▏    | 11237/21600 [4:58:20<3:06:14,  1.08s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID e94dfcf68eef6bbae2ab73eeaa479ad8 in your message.)


 52%|█████▏    | 11242/21600 [4:58:56<9:32:40,  3.32s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 235ed7333506ca84ce950ca10e2819ae in your message.)


 52%|█████▏    | 11254/21600 [4:59:39<3:30:05,  1.22s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 2b6c99bba283de67c73fa8727b7dc059 in your message.)


 52%|█████▏    | 11295/21600 [5:01:00<3:46:57,  1.32s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 8218bd97822e17877821b76281dc32d5 in your message.)


 53%|█████▎    | 11453/21600 [5:04:34<3:08:29,  1.11s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID a479984dbf208e1c706e53ea1c2305f8 in your message.)


 53%|█████▎    | 11539/21600 [5:06:37<2:34:51,  1.08it/s] 

Request failed due to server shutdown {
  "error": {
    "message": "Request failed due to server shutdown",
    "type": "server_error",
    "param": null,
    "code": null
  }
}
 500 {'error': {'message': 'Request failed due to server shutdown', 'type': 'server_error', 'param': None, 'code': None}} {'Date': 'Wed, 14 Jun 2023 18:15:38 GMT', 'Content-Type': 'application/json', 'Content-Length': '141', 'Connection': 'keep-alive', 'access-control-allow-origin': '*', 'openai-model': 'gpt-3.5-turbo-0301', 'openai-organization': 'mbzuai-2', 'openai-processing-ms': '9171', 'openai-version': '2020-10-01', 'strict-transport-security': 'max-age=15724800; includeSubDomains', 'x-ratelimit-limit-requests': '3500', 'x-ratelimit-limit-tokens': '90000', 'x-ratelimit-remaining-requests': '3499', 'x-ratelimit-remaining-tokens': '89901', 'x-ratelimit-reset-requests': '17ms', 'x-ratelimit-reset-tokens': '66ms', 'x-request-id': 'ea94045b78c9d3fb62b4eee51d51f633', 'CF-Cache-Status': 'DYNAMIC', 'Server': 'cl

 56%|█████▌    | 12048/21600 [5:16:26<2:57:03,  1.11s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID bdc86e48cf40f08df03bebdfd4b4a001 in your message.)
That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 11349523cf0560a51df7338b80da4674 in your message.)


 56%|█████▌    | 12099/21600 [5:18:24<2:49:49,  1.07s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 4d526da3bca6e0073e015333fdcc9503 in your message.)


 56%|█████▌    | 12103/21600 [5:18:59<11:20:32,  4.30s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 5cf531c74b4d7c8109d5af7c95d10311 in your message.)


 57%|█████▋    | 12307/21600 [5:23:15<2:57:45,  1.15s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 5caaa3cf4cb16c835f9fd05f968acaf8 in your message.)


 57%|█████▋    | 12314/21600 [5:23:55<6:03:37,  2.35s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 10cd2b0be501107dff9c5dd72bcd336f in your message.)


 57%|█████▋    | 12333/21600 [5:24:53<3:23:34,  1.32s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 28a164a9b195884bba8f328034d90de5 in your message.)


 58%|█████▊    | 12495/21600 [5:28:26<2:26:17,  1.04it/s] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 847c5888d9fde4fe2069674211b587a6 in your message.)


 59%|█████▉    | 12722/21600 [5:33:11<2:45:40,  1.12s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 5283a0e6f7dc39f4ef7567136f52ba93 in your message.)


 60%|█████▉    | 12900/21600 [5:36:59<2:30:22,  1.04s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 57b1cc6b008bea6a361232c84cac9f8e in your message.)


 61%|██████    | 13099/21600 [5:41:10<2:16:20,  1.04it/s] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 75ea274daaabe81c9515470e1ddea3b9 in your message.)


 62%|██████▏   | 13373/21600 [5:46:31<2:01:26,  1.13it/s] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID b91e7122197955a0b6bfd4743e594643 in your message.)


 63%|██████▎   | 13644/21600 [5:51:49<2:54:38,  1.32s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID c024b99c9e098a14a78ec91c10cee1c0 in your message.)


 63%|██████▎   | 13654/21600 [5:52:29<2:53:08,  1.31s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID a9f04827aa14bbae7151d9e082472874 in your message.)


 64%|██████▍   | 13821/21600 [5:56:09<2:27:14,  1.14s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 66880878ce53daca368b58a750dc4ba3 in your message.)


 64%|██████▍   | 13833/21600 [5:56:54<2:23:27,  1.11s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 1b7c358daa162e831f9d7496b7aeab04 in your message.)


 66%|██████▋   | 14328/21600 [6:06:49<2:26:58,  1.21s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID b8b555dc7893f28f1ad17927224cc522 in your message.)


 66%|██████▋   | 14363/21600 [6:07:59<2:27:49,  1.23s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 6b1c43ebcb5ea79e70c00e40df0834cf in your message.)


 68%|██████▊   | 14582/21600 [6:12:25<1:58:24,  1.01s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 9f604ee376b16ba50f870acd2fc00712 in your message.)


 68%|██████▊   | 14588/21600 [6:13:00<4:35:42,  2.36s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID d75859b073b12ca522200f031082e80f in your message.)


 68%|██████▊   | 14601/21600 [6:13:44<2:51:34,  1.47s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID f8010109ea89b4b5d333f7473159278d in your message.)


 68%|██████▊   | 14712/21600 [6:16:22<2:20:52,  1.23s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 762b0902c077b1c957578fc13c97c55e in your message.)


 69%|██████▊   | 14799/21600 [6:18:29<1:53:38,  1.00s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID a891fc59df10b8c28cfec495cf5e5f02 in your message.)


 69%|██████▉   | 14968/21600 [6:22:12<2:11:42,  1.19s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID b623734f3ab8b02154c7a1c377afabd8 in your message.)


 70%|██████▉   | 15023/21600 [6:23:45<2:24:53,  1.32s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 3a093512bf770d6ee780c7020ef85a76 in your message.)


 70%|███████   | 15126/21600 [6:26:08<2:00:04,  1.11s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID a5d5ef4db35084ebfd0c5d7545acd328 in your message.)


 71%|███████   | 15256/21600 [6:29:11<1:51:50,  1.06s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 9597fc96e35d4594273b5d323e8a1b28 in your message.)


 71%|███████   | 15276/21600 [6:30:04<1:52:46,  1.07s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID e742303b6c34327da8ef9b6a6d046939 in your message.)


 72%|███████▏  | 15596/21600 [6:36:27<1:38:16,  1.02it/s] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 1ceb9e4bc9dea9d67f01c1ce0a30190c in your message.)


 72%|███████▏  | 15631/21600 [6:37:35<1:54:19,  1.15s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 66f6dda6899c99db41fee29cf215f689 in your message.)


 73%|███████▎  | 15800/21600 [6:41:16<1:34:58,  1.02it/s] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 39e7f9d5e6c49aaddc142f59872670f2 in your message.)


 73%|███████▎  | 15840/21600 [6:42:30<1:52:57,  1.18s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 3b7f47e6024762345f30d6e95fa51b90 in your message.)


 74%|███████▍  | 16019/21600 [6:46:26<1:52:24,  1.21s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 9b0bd1c3b15be3b95fe9145aff1857ef in your message.)


 75%|███████▌  | 16253/21600 [6:51:27<1:53:09,  1.27s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID e0b28628f9236f7799d36a812ee00a5a in your message.)


 76%|███████▌  | 16441/21600 [6:55:30<1:37:01,  1.13s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID aec2cefc30ca1de9c9f446a03480eea2 in your message.)


 77%|███████▋  | 16551/21600 [6:58:02<1:37:40,  1.16s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 31a1ea3cb1d797005a0d45f3038d5e25 in your message.)


 77%|███████▋  | 16618/21600 [6:59:41<1:24:22,  1.02s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID a03908e9c2d72b373d5f010ba72d9761 in your message.)


 77%|███████▋  | 16692/21600 [7:01:40<2:04:28,  1.52s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 8bdec41e8cd9dc11007524175e28f581 in your message.)


 78%|███████▊  | 16895/21600 [7:05:50<1:23:25,  1.06s/it] 

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 4fbd30a604c9b6907f78dc86d2182abd in your message.)


 78%|███████▊  | 16902/21600 [7:06:28<2:44:06,  2.10s/it] IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [25]:
save_results(all_chatgpt_res_r3, fpath=f'{args.dataset}-round=3-no_def.pkl')

In [26]:
### validate round 3
all_qa_captions = deepcopy(all_chatgpt_res_r3)
# all_qa_captions = np.empty([len(all_chatgpt_res_r3[0]), len(all_chatgpt_res_r3[0][0]), len(all_chatgpt_res_r3), len(all_chatgpt_res_r3[0][0][0][0])]).tolist()
while 1:
    invalid_inds = []
    for i in range(n_repeat):
        for j in range(len(concept_request)):
            concepts = list(map(lambda x: x.strip("'"), concept_request[j][0].split(', ')))
            for k in range(len(concepts)):
                for i_c, cap in enumerate(all_chatgpt_res_r3[i][j][k]):
                    if (not cap[:len('caption:')].lower() == 'caption:') or ('sorry' in cap):
                        invalid_inds.append((i,j,k,i_c))
                    else:
                        extract_caption = lambda x: x.lower().split('caption: ')[-1].strip('{}\"')
                        all_qa_captions[i][j][k][i_c] = extract_caption(all_chatgpt_res_r3[i][j][k][i_c])
                        # all_qa_captions[j][k][i][i_c] = extract_caption(all_chatgpt_res_r3[i][j][k][i_c])
    if len(invalid_inds)==0:
        break
        
    with tqdm(total=len(invalid_inds)) as pbar:
        for row in invalid_inds:
            i, j, k, i_cap = row
            qas = synthesize_qa(*all_qa_pairs[i][j][k][-2:])
            content = template_in_use_r3(concept_request[j][0], concept_request[j][1], qas[i_cap], all_qa_pairs[i][j][k][0])
            concept_templates_r3[i][j][k][i_cap] = content
            ### make request
            while 1:
                try:
                    ### collect result
                    all_chatgpt_res_r3[i][j][k][i_cap] = openai_chatgpt_post(content, verbose=False)
                    break
                except Exception as e:
                    print(e)
            pbar.update(1)

100%|██████████| 3/3 [00:03<00:00,  1.31s/it]
100%|██████████| 1/1 [00:00<00:00,  1.24it/s]
100%|██████████| 1/1 [00:02<00:00,  2.08s/it]
100%|██████████| 1/1 [00:01<00:00,  1.99s/it]
100%|██████████| 1/1 [00:02<00:00,  2.09s/it]
100%|██████████| 1/1 [00:01<00:00,  1.74s/it]
100%|██████████| 1/1 [00:02<00:00,  2.19s/it]
100%|██████████| 1/1 [00:01<00:00,  1.19s/it]
100%|██████████| 1/1 [00:01<00:00,  1.51s/it]
100%|██████████| 1/1 [00:01<00:00,  1.51s/it]
100%|██████████| 1/1 [00:02<00:00,  2.74s/it]
100%|██████████| 1/1 [00:01<00:00,  1.45s/it]
100%|██████████| 1/1 [00:01<00:00,  1.15s/it]


In [28]:
save_results(all_chatgpt_res_r3, fpath=f'{args.dataset}-round=3-no_def.pkl')

In [None]:
all_qa_captions = np.transpose(np.array(all_qa_captions), (1,2,0,3))
all_qa_captions = all_qa_captions.reshape(all_qa_captions.shape[0], all_qa_captions.shape[1], np.prod(all_qa_captions.shape[-2:])).tolist() ### row x [concept x repeat x caption]

In [None]:
@torch.no_grad()
def build_classifier_qa_captions(all_qa_captions, model, all_row_key_name=None):
    row_classifier = []
    with tqdm(total=len(all_qa_captions)) as pbar:
        for idx, row in enumerate(all_qa_captions):
            shape_row = np.array(row).shape ### 3 x 30
            row = np.array(row).ravel().tolist()
            if all_row_key_name is not None:
                pass
            row_t = tokenize(row).to(args.device)
            features = model.encode_text(row_t)
            features = features/features.norm(dim=-1, keepdim=True)
            row_classifier.append(features.cpu())
            
            pbar.update(1)
    return row_classifier
    

In [None]:
qa_cap_classifiers = build_classifier_qa_captions(all_qa_captions, model)

In [None]:
candidate_inds_qa_cap = [torch.arange(90).int().div(30, rounding_mode='floor') for _ in range(len(qa_cap_classifiers))]

#### naive ensembling

In [91]:
vfeatures = all_vfeatures
k_2 = 1
instance_pred_voc = torch.zeros_like(record_pred_kmeans_t)
all_clu_pred_qa_cap = torch.zeros_like(all_clu_pred[:, 0])
topk_all_clu_pred = all_clu_pred.topk(k=k_1).indices
for c in range(len(qa_cap_classifiers)):
    ### selection 
    select = (record_pred_kmeans_t==c)
    row_classifier = qa_cap_classifiers[c]
    ### prediction 
    sim = torch.from_numpy(vfeatures[select, ...]).to(args.device)@row_classifier.to(args.device).t()
    sim_topk = sim.topk(k=k_2)
    ind, val = sim_topk.indices.flatten().cpu().unique(return_counts=True)
    ### counting
    count_names = torch.zeros(row_classifier.size(0)).long()
    count_names[ind] = val ### count of each ind
    count_smask = []
    smask = np.array(candidate_inds_qa_cap[c]) ### partition mask
    for s in np.unique(smask):
        count_smask.append(count_names[smask==s].sum().item())
    prediction = torch.tensor(count_smask).argmax(dim=-1)
    instance_pred_voc[select] = topk_all_clu_pred[c, prediction]
    all_clu_pred_qa_cap[c] = topk_all_clu_pred[c, prediction]

In [92]:
(instance_pred_voc == all_gt_voc).float().mean()

tensor(0.4842)

#### logical ensembling

In [93]:
# row x [(concept x repeat x caption)] -> row x [caption x (concept x repeat)]
dim_concept, dim_repeat, dim_caption, dim_feature = 3, 3, 10, 512
qa_cap_classifiers = [x.view(dim_concept, dim_repeat, dim_caption, dim_feature).permute(2,0,1,3).view(dim_caption, -1, dim_feature) for x in qa_cap_classifiers]

In [94]:
candidate_inds_qa_cap = [torch.arange(9).int().div(3, rounding_mode='floor') for _ in range(len(qa_cap_classifiers))] ### row x 

In [95]:
vfeatures = all_vfeatures
k_2 = 2
N = record_pred_kmeans_t.shape[0]
R = all_clu_pred.shape[0]
instance_pred_voc = torch.zeros(dim_caption, N)
all_clu_pred_qa_cap = torch.zeros(dim_caption, R)
topk_all_clu_pred = all_clu_pred.topk(k=k_1).indices
for i_cap in range(dim_caption):
    for c in range(len(qa_cap_classifiers)):
        ### selection 
        select = (record_pred_kmeans_t==c)
        row_classifier = qa_cap_classifiers[c][i_cap]
        ### prediction 
        sim = torch.from_numpy(vfeatures[select, ...]).to(args.device)@row_classifier.to(args.device).t()
        sim_topk = sim.topk(k=k_2)
        ind, val = sim_topk.indices.flatten().cpu().unique(return_counts=True)
        ### counting
        count_names = torch.zeros(row_classifier.size(0)).long()
        count_names[ind] = val ### count of each ind
        count_smask = []
        smask = np.array(candidate_inds_qa_cap[c]) ### partition mask
        for s in np.unique(smask):
            count_smask.append(count_names[smask==s].sum().item())
        prediction = torch.tensor(count_smask).argmax(dim=-1)
        instance_pred_voc[i_cap, select] = topk_all_clu_pred[c, prediction]
        all_clu_pred_qa_cap[i_cap, c] = topk_all_clu_pred[c, prediction]

In [96]:
instance_pred_voc = instance_pred_voc.mode(dim=0).values.int()

In [97]:
all_clu_pred_qa_cap = all_clu_pred_qa_cap.mode(dim=0).values.int()

In [98]:
(instance_pred_voc == all_gt_voc).float().mean()

tensor(0.4870)

torch.Size([10, 100])

### CHATGPT request

In [10]:
import openai
def openai_chatgpt_post(content, parameters={'temperature': 0.7}):
    openai.api_key = "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"
    completion = openai.ChatCompletion.create(
      model="gpt-3.5-turbo",
      messages=[
        {"role": "user", "content": content},
      ],
    **parameters,
    )
    result = completion['choices'][0]['message']['content']
    # completion = openai.Completion.create(
    #     model="text-davinci-003",
    #     prompt=content,  
    #     temperature=0.7,
    #     max_tokens=256,
    #     top_p=1,
    #     frequency_penalty=0,
    #     presence_penalty=0,
    # )
    # result = completion['choices'][0]['text']
    return result

In [11]:
all_clu_gt_voc = []
for c in record_pred_kmeans_t.unique():
    select = (record_pred_kmeans_t==c)
    all_clu_gt_voc.append(all_gt_voc[select].mode().values)

all_clu_gt_voc = torch.tensor(all_clu_gt_voc)
k_1 = 3
topk_all_clu_pred = all_clu_pred.topk(k=k_1).indices
cluster_is_correct = torch.zeros(topk_all_clu_pred.size(0)).bool()
for i in range(k_1):
    cluster_is_correct |= (topk_all_clu_pred[:, i]==all_clu_gt_voc)

print(f'recall@{k_1} = {cluster_is_correct.float().mean()}')

recall@3 = 0.7041666507720947


In [12]:
""" gather concepts """
to_name = lambda x: [ s.name() + ': ' + s.definition() for s in x ]
cluster_row_synsets = []
for row in topk_all_clu_pred:
    row_synsets = [to_name(mapping_vocidx_to_synsets(voc_idx.item(), vocab)) for voc_idx in row]
    cluster_row_synsets.append(row_synsets)


In [14]:
""" generate concept requests """
concept_request = []
for row in cluster_row_synsets:
    ccpts = reduce(lambda x, y: x+y, row)
    ccpts = list(map(lambda x: "'"+x+".'", ccpts))
    ccpts = ', '.join(ccpts)
    concept_request.append(ccpts)
    
""" generate concept templates """
template_1 = lambda concept_list: "Given visual concepts: "+ concept_list + "List all alternative concept names for each visual concept. List in the format \"{concept name}: {list of names separated by ';'}.\""
with open('/home/sheng/OSZSL/templates_chatgpt.json', 'r') as f:
    template_chatgpt = json.load(f)
template_2 = lambda concept_list: template_chatgpt['pictionary-long'].format(concept_list)
template_3 = lambda concept_list: template_chatgpt['pictionary-short'].format(concept_list)
template_4 = lambda concept_list: template_chatgpt['direct'].format(concept_list)
template_5 = lambda concept_list: "Given visual concepts: "+ concept_list + "List all synonym concept names for each visual concept. List in the format \"{concept name}: {list of names separated by ';'}.\""
template_6 = lambda concept_list: "Given visual concepts: "+ concept_list + "List all category names for each visual concept. List in the format \"{concept name}: {list of names separated by ';'}.\""
template_7 = lambda concept_list: "Given visual concepts: "+ concept_list + "List all parent-type category names for each visual concept. List in the format \"{concept name}: {list of names separated by ';'}.\""
template_8 = lambda concept_list: "Given visual concepts: "+ concept_list + "List all possible descriptive phrases of image captions for each visual concept. List in the format \"{concept name}: {all phrases deliminated by semicolons}.\" for each concept. No duplication."
template_9 = lambda concept_list: "Given visual concepts: "+ concept_list + "List all possible visiual descriptive phrases for each visual concept without duplication. List in the format \"{concept name}: {all phrases deliminated by semicolons}.\" for each concept. No duplication."
# template_10 = lambda concept_list: "Given visual concepts: "+ concept_list + "List all possible visiual descriptive phrases for each visual concept without duplication. List in the format \"{concept name}: {all phrases deliminated by semicolons}.\" for each concept. No duplication."
template_13 = lambda concept_list: "Given visual concepts: "+ concept_list + "Please list all possible adjective phrases of visual descriptions for each visual concept without duplication. Please list in the format \"{concept name}: {all phrases deliminated by semicolons}.\" for each concept. No duplication."
template_9_1 = lambda concept_list: "Given visual concepts: "+ concept_list + "Please list all possible visual descriptive phrases for each visual concept without duplication. Please list in the format \"{concept name}: {all phrases deliminated by semicolons}.\" for each concept. No duplication."

    
template_in_use = template_9_1
concept_templates = []
for row in concept_request:
    concept_templates.append(template_in_use(row))
    
n_repeat = 3

In [None]:
""" collect chatgpt res """
all_chatgpt_res = [[] for _ in range(n_repeat)]
with tqdm(total=len(concept_templates)*n_repeat) as pbar:
    for i in range(n_repeat):
        for row in concept_templates:
            while 1:
                try:
                    all_chatgpt_res[i].append(openai_chatgpt_post(row))
                    break
                except Exception as e:
                    print(e)

            pbar.update(1)

  4%|▍         | 32/720 [02:50<1:04:38,  5.64s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID dcf2dc940fa4dacc4cd59d592b5c73be in your message.)


  8%|▊         | 60/720 [06:08<1:13:10,  6.65s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 82a510114f9717a1fe0dd2a0bcb8ca72 in your message.)


 12%|█▏        | 83/720 [08:46<1:05:42,  6.19s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 8a12edd13d868cc6781e9e0053d362e4 in your message.)


 16%|█▌        | 112/720 [11:52<1:05:15,  6.44s/it]

In [30]:
with open(f'./cache/openai/topk=1-visual-inov-template=9_1-k_1={k_1}-repeat={n_repeat}-data={args.dataset}.pkl', 'wb') as f:
    pickle.dump(all_chatgpt_res, f)
    
# with open(f'./cache/openai/visual-inov-template=9-k_1={k_1}-repeat={n_repeat}-data={args.dataset}-iter=1.pkl', 'wb') as f:
#     pickle.dump(all_chatgpt_res, f)

# with open(f'./cache/openai/visual-inov-template=9-k_1={k_1}-repeat={n_repeat}-vocab.pkl', 'wb') as f:
#     pickle.dump(data, f)

# with open(f'./cache/openai/visual-inov-template=5-k_1={k_1}-repeat={n_repeat}.pkl', 'wb') as f:
#     pickle.dump(all_chatgpt_res, f)

# with open(f'./cache/openai/visual-inov-template=5-k_1={k_1}-repeat={n_repeat}.pkl', 'rb') as f:
#     all_chatgpt_res = pickle.load(f)

In [47]:
with open(f'./cache/openai/topk=1-visual-inov-template=9_1-k_1={k_1}-repeat={n_repeat}-data={args.dataset}-uk{args.estimate_k}.pkl', 'rb') as f:
    all_chatgpt_res = pickle.load(f)

In [15]:
with open(f'/home/sheng/sssa/ipynb/cache/openai/VDE/{args.exp}-{args.vocabname}-template=9_1-k_1={k_1}-repeat={n_repeat}-data={args.dataset}.pkl', 'rb') as f:
    all_chatgpt_res = pickle.load(f)

In [None]:
while 1:
    """ integrity check """
    while 1:
        invalid_res = []
        for i in range(n_repeat):
            for j, row in enumerate(all_chatgpt_res[i]):
                extract_synsetid = lambda r: list(map(lambda x: x.split(': ')[0], r))
                remove_space = lambda r: list(filter(lambda x: len(x), r))
                synsets = extract_synsetid(remove_space(row.lower().replace('\n\n', '\n').split('\n')))
                gt_synsets = extract_synsetid(reduce(lambda x,y: x+y, cluster_row_synsets[j]))
                try:
                    start_idx = [ synsets[k].find(s) for k, s in enumerate(gt_synsets) ]
                    synsets = [ synsets[k][start_idx[k]:start_idx[k]+len(gt_synsets[k])] for k, s in enumerate(synsets) ]
                    assert set(synsets)==set(gt_synsets)
                except Exception as e:
                    print(i, j)
                    print(synsets, gt_synsets)
                    invalid_res.append((i,j))

        if len(invalid_res)==0:
            break
        else:
            for i,j in invalid_res:
                print(f'repair {(i,j)}')
                content = concept_templates[j]
                while 1:
                    try:
                        res = openai_chatgpt_post(content)
                        break
                    except Exception as e:
                        print(e)
                all_chatgpt_res[i][j] = res



    """ extract key-value-list from @chatgpt-res """
    extracted_chatgpt_res = []
    for j, row in enumerate(all_chatgpt_res[0]):
        # all_chatgpt_res[0][j] = 
        chatgpt_row_res = {}
        extract_synsetid = lambda r: list(map(lambda x: x.split(': ')[0], r))
        remove_space = lambda r: list(filter(lambda x: len(x), r))
        extract_synnames = lambda r: list(map(lambda x: x.split(': ')[1].split('; '), r))
        for i in range(n_repeat):
            row = all_chatgpt_res[i][j]
            row_data = remove_space(row.lower().replace('\n\n', '\n').split('\n'))
            synsets = extract_synsetid(row_data)
            synnames = extract_synnames(row_data)
            gt_synsets = extract_synsetid(reduce(lambda x,y: x+y, cluster_row_synsets[j]))
            start_idx = [ synsets[k].find(s) for k, s in enumerate(gt_synsets) ]
            synsets = [ synsets[k][start_idx[k]:start_idx[k]+len(gt_synsets[k])] for k, s in enumerate(synsets) ]
            for idx_s, s in enumerate(synsets):
                chatgpt_row_res.setdefault(s, [])
                chatgpt_row_res[s].append( remove_space(synnames[idx_s]) )
        extracted_chatgpt_res.append(chatgpt_row_res)

    """ deduplication """
    use_dedup = True
    all_candidates = []
    all_candidates_set = []
    for i, row in enumerate(extracted_chatgpt_res):
        ### flatten multiple results
        row_all_synset_names = list(map(lambda x: x.split('.')[0], row.keys()))
        row_candidates = {}
        row_candidates_set = {}
        for k, v in row.items():
            candidates = list(reduce(lambda x, y: x+y, v))
            candidates = [c for c in candidates if c not in row_all_synset_names] ### remove competing synset names
            set_candidates = set(candidates)
            k = k.split('.')[0] ### key synset name
            row_candidates.setdefault(k, [])
            row_candidates_set.setdefault(k, set([]))
            row_candidates[k].extend(candidates)
            row_candidates_set[k] |= set_candidates
        ### collect duplicates
        duplicates = set()
        for k1, v1 in row.items():
            k1 = k1.split('.')[0]
            for k2, v2 in row.items():
                k2 = k2.split('.')[0]
                if k1!=k2:
                    duplicates |= row_candidates_set[k1]&row_candidates_set[k2]
        ### remove duplication with synset-names (keys)
        row_candidates_update = {}
        row_candidates_set_update = {}
        for k1, v1 in row.items():
            k1 = k1.split('.')[0]
            for k2, v2 in row.items():
                k2 = k2.split('.')[0]
            row_candidates_set_update[k1] = row_candidates_set[k1] - duplicates if use_dedup else row_candidates_set[k1]
            row_candidates_update[k1] = [item for item in row_candidates[k1] if item not in duplicates ] if row_candidates_set[k1] else row_candidates[k1]

        all_candidates.append(row_candidates_update)
        all_candidates_set.append(row_candidates_set_update)


    ### check non-empty
    empty_list = []
    for i, line in enumerate(all_candidates_set):
        for k, v in line.items():
            if len(v)==0:
                for j in range(n_repeat):
                    empty_list.append(j)
                    print(f'repair {i} {j}')
                    while 1:
                        try:
                            res = openai_chatgpt_post(concept_templates[i])
                            break
                        except Exception as e:
                            print(e)
                    all_chatgpt_res[j][i] = res

    if len(empty_list)==0:
        break

In [None]:
data = \
{
    'all_candidates': all_candidates,
    'all_candidates_set': all_candidates_set,
}

In [None]:
""" counter sorting """
all_candidates = data['all_candidates']
all_counter_candidates = []
all_number_candidates = []
for row in all_candidates:
    row_counter = {}
    total_num = 0
    for k, v in row.items():
        ct = Counter(v)
        row_counter[k] = OrderedDict(sorted(ct.items())) ### order key
        total_num += sum(ct.values())
    all_counter_candidates.append(OrderedDict(sorted(row_counter.items()))) ### order key
    all_number_candidates.append(total_num)

In [None]:
### flatten
all_row_mapping_idx_synset_name = []
all_row_chatgpt_names = []
all_row_i_syn = []
all_row_weight = []
all_row_key_name = []
for i in range(len(all_counter_candidates)):
    row_synset_names = all_counter_candidates[i].keys()
    row_mapping_idx_synset_name = dict(zip(range(len(row_synset_names)), row_synset_names))
    row_i_syn = []
    row_chatgpt_names = []
    row_weight = []
    for i_syn, syn in enumerate(row_synset_names):
        row_i_syn.extend([i_syn for _ in range(len(all_counter_candidates[i][syn]))])
        row_chatgpt_names.extend(list(all_counter_candidates[i][syn]))
        row_weight.extend(list(all_counter_candidates[i][syn].values()))
    
    all_row_mapping_idx_synset_name.append(row_mapping_idx_synset_name)
    all_row_chatgpt_names.append(row_chatgpt_names)
    all_row_i_syn.append(row_i_syn)
    all_row_weight.append(row_weight)
    all_row_key_name.append(list(map(lambda x: row_mapping_idx_synset_name[x], row_i_syn)))

In [None]:
@torch.no_grad()
def build_classifier_chatgpt(all_row_chatgpt_names, model, all_row_key_name=None):
    """ build classifier for chatgpt
    Args:
        all_row_chatgpt_names: [[names]]
    """
    if all_row_key_name is None: ### single name
        with open('../templates_small.json', 'rb') as f: ### template 1
            templates = json.load(f)['imagenet']
    else:
        with open('../templates_small.json', 'rb') as f: ### template 2
            templates = json.load(f)[f'{args.dataset}-parent-3']
            
    len_t = len(templates)
    row_classifier = []
    with tqdm(total=len(all_row_chatgpt_names)) as pbar:
        for idx, row in enumerate(all_row_chatgpt_names):
            len_row = len(row)
            if all_row_key_name is None:
                row_t = [ t.format(name) for name in row for t in templates ]
            else:
                row_t = [ t.format(pname, name) for pname, name in zip(all_row_key_name[idx], row) for t in templates ]
            row_t = tokenize(row_t).to(args.device)
            features = model.encode_text(row_t)
            features = features.view(len_row, len_t, -1).float()
            features = features/features.norm(dim=-1, keepdim=True)
            features = features.mean(dim=1)
            features = features/features.norm(dim=-1, keepdim=True)
            row_classifier.append(features.cpu())
            
            pbar.update(1)
    return row_classifier
    

In [None]:
all_row_classifier = build_classifier_chatgpt(all_row_chatgpt_names, model, all_row_key_name=all_row_key_name)

In [None]:
# vfeatures = np.load(f'./cache/features/vfeatures-{args.dataset}.npy')
vfeatures = all_vfeatures
all_clu_pred_chatgpt = torch.zeros_like(all_clu_pred)
is_correct = []
k_2 = 3
enable_weight = True
instance_pred_voc = torch.zeros_like(record_pred_kmeans_t)
for c in range(len(all_row_classifier)):
    select = (record_pred_kmeans_t==c)
    row_classifier = all_row_classifier[c]
    sim = torch.from_numpy(vfeatures[select, ...]).to(args.device)@row_classifier.to(args.device).t()
    sim_topk = sim.topk(k=k_2)
    ind, val = sim_topk.indices.flatten().cpu().unique(return_counts=True)
    count_names = torch.zeros(row_classifier.size(0)).long()
    count_names[ind] = val ### count of each name
    count_smask = []
    smask = np.array(all_row_i_syn[c]) ### partition mask
    for s in np.unique(smask):
        if enable_weight:
            row_weight = torch.tensor(all_row_weight[c]).float()
            row_weight[smask==s] = row_weight[smask==s] / row_weight[(smask==s)].sum()
            row_weight /= row_weight.sum()
            count_smask.append((row_weight[smask==s]*count_names[smask==s]).sum().item())
        else:
            count_smask.append(count_names[smask==s].sum())
    name_pred = all_row_mapping_idx_synset_name[c][np.argmax(count_smask)]
    name_gt = all_gt_voc[select].mode().values
    name_gt = vocab.mapping_idx_names[name_gt.item()]
    is_correct.append(name_pred==name_gt)
    instance_pred_voc[select] = vocab.mapping_names_idx[name_pred]
    
    val_count = torch.tensor(count_smask)
    ind_count = [ all_row_mapping_idx_synset_name[c][ii] for ii in range(val_count.shape[0]) ]
    ind_count = torch.tensor([vocab.mapping_names_idx[xx] for xx in ind_count])
    all_clu_pred_chatgpt[c, ind_count] = val_count

In [None]:
name_acc = np.array(is_correct).mean().item()
instance_acc = (instance_pred_voc==all_gt_voc).float().mean().item()
missing = all_gt_voc.unique().size(0) - all_gt_voc[(instance_pred_voc==all_gt_voc)].unique().size(0)

print(f'name_acc={name_acc}, instance_acc={instance_acc}, missing={missing}')
instance_pred_voc.unique().shape, all_gt_voc.unique().shape

In [56]:
# mapping_voc_clu = dict(zip(instance_pred_voc.unique().numpy().tolist(), range(len(instance_pred_voc))))
# r_pred_kmeans_t = np.array([mapping_voc_clu[item.item()] for item in instance_pred_voc])

# np.save(f'/home/sheng/sssa/ipynb/cache/cluster/topk=1-cache-inov-{args.dataset}-clip-chatgpt-uk{args.estimate_k}.pth', r_pred_kmeans_t)

In [38]:
# a, b = linear_assign(all_clu_pred_chatgpt, record_pred_kmeans_t, all_gt_voc)

# # with open(f'./cache/openai/inov-cluster_visual_chatgpt-repeat={n_repeat}-k_1={k_1}-dataset={args.dataset}.pkl', 'wb') as f:
# #     pickle.dump(all_chatgpt_res, f)

# instance_acc = ((instance_pred_voc==all_gt_voc) | (cluster_ind_voc.cpu()==all_gt_voc)).float().mean().item()
# print(instance_acc)

In [None]:
classifier = get_classifier(args)
classifier = classifier/classifier.norm(dim=-1, keepdim=True)
args.num_voc = classifier.size(0)
a, res_ass = linear_assign(all_clu_pred_chatgpt, record_pred_kmeans_t, all_gt_voc)
r_pred_kmeans_t, r_cluster_ind_voc = reassign_by_pred_cluster(a, loader_f, model, classifier, args.device, preextracted_vfeatures=all_vfeatures)

In [31]:
set_pred = set(res_ass[1].tolist())
set_gt = set(all_gt_voc.unique().numpy().tolist())
n_inter = all_gt_voc[cluster_ind_voc.cpu()==all_gt_voc].unique().shape[0]
n_union = torch.cat([cluster_ind_voc.cpu(), all_gt_voc]).unique().shape[0]
iou_voc = n_inter/n_union
n_missing_label = all_gt_voc.unique().shape[0] - n_inter
print('missing label::', n_missing_label)
print('iou voc::', iou_voc)
print('cluster acc', cluster_acc(y_true=all_label_clu.numpy(), y_pred=r_pred_kmeans_t.numpy()))

missing label:: 31
iou voc:: 0.37373737373737376
cluster acc 0.6826953531514411


In [26]:
with open(f'/home/sheng/sssa/ipynb/cache/openai/VDE/{args.exp}-{args.vocabname}-template=9_1-k_1={k_1}-repeat={n_repeat}-data={args.dataset}.pkl', 'wb') as f:
    pickle.dump(all_chatgpt_res, f)

In [60]:
np.save(f'/home/sheng/sssa/ipynb/cache/cluster/topk=1-cache-inov-{args.dataset}-clip-chatgpt-uk206.pth', r_pred_kmeans_t.cpu().numpy())

In [29]:
np.save(f'/home/sheng/sssa/ipynb/cache/cluster/topk=1-cache-inov-{args.dataset}-clip-chatgpt.pth', r_pred_kmeans_t.cpu().numpy())
# np.save(f'/home/sheng/sssa/ipynb/cache/cluster/cache-inov-{args.dataset}-clip-chatgpt-iter=1.pth', r_pred_kmeans_t.cpu().numpy())

In [30]:
np.load(f'/home/sheng/sssa/ipynb/cache/cluster/cache-inov-{args.dataset}-clip-chatgpt.pth.npy')

array([167, 167, 167, ..., 232, 232, 232])

In [53]:
k_1 = 3
topk_all_clu_pred = (classifier@classifier.t()).topk(k=k_1).indices

In [54]:
""" gather concepts """
to_name = lambda x: [ s.name() + ': ' + s.definition() for s in x ]
cluster_row_synsets = []
for row in topk_all_clu_pred:
    row_synsets = [to_name(mapping_vocidx_to_synsets(voc_idx.item(), vocab)) for voc_idx in row]
    cluster_row_synsets.append(row_synsets)

In [58]:
""" generate concept requests """
concept_request = []
for row in cluster_row_synsets:
    ccpts = reduce(lambda x, y: x+y, row)
    ccpts = list(map(lambda x: "'"+x+".'", ccpts))
    ccpts = ', '.join(ccpts)
    concept_request.append(ccpts)
    
""" generate concept templates """
template_1 = lambda concept_list: "Given visual concepts: "+ concept_list + "List all alternative concept names for each visual concept. List in the format \"{concept name}: {list of names separated by ';'}.\""
with open('/home/sheng/sssa/templates_chatgpt.json', 'r') as f:
    template_chatgpt = json.load(f)
template_2 = lambda concept_list: template_chatgpt['pictionary-long'].format(concept_list)
template_3 = lambda concept_list: template_chatgpt['pictionary-short'].format(concept_list)
template_4 = lambda concept_list: template_chatgpt['direct'].format(concept_list)
template_5 = lambda concept_list: "Given visual concepts: "+ concept_list + "List all synonym concept names for each visual concept. List in the format \"{concept name}: {list of names separated by ';'}.\""
template_6 = lambda concept_list: "Given visual concepts: "+ concept_list + "List all category names for each visual concept. List in the format \"{concept name}: {list of names separated by ';'}.\""
template_7 = lambda concept_list: "Given visual concepts: "+ concept_list + "List all parent-type category names for each visual concept. List in the format \"{concept name}: {list of names separated by ';'}.\""
template_8 = lambda concept_list: "Given visual concepts: "+ concept_list + "List all possible descriptive phrases of image captions for each visual concept. List in the format \"{concept name}: {all phrases deliminated by semicolons}.\" for each concept. No duplication."
template_9 = lambda concept_list: "Given visual concepts: "+ concept_list + "List all possible visiual descriptive phrases for each visual concept without duplication. List in the format \"{concept name}: {all phrases deliminated by semicolons}.\" for each concept. No duplication."
# template_10 = lambda concept_list: "Given visual concepts: "+ concept_list + "List all possible visiual descriptive phrases for each visual concept without duplication. List in the format \"{concept name}: {all phrases deliminated by semicolons}.\" for each concept. No duplication."

    
template_in_use = template_9
concept_templates = []
for row in concept_request:
    concept_templates.append(template_in_use(row))
    
n_repeat = 1

In [None]:
""" collect chatgpt res """
all_chatgpt_res = [[] for _ in range(n_repeat)]
with tqdm(total=len(concept_templates)*n_repeat) as pbar:
    for i in range(n_repeat):
        for row in concept_templates:
            while 1:
                try:
                    all_chatgpt_res[i].append(openai_chatgpt_post(row))
                    break
                except Exception as e:
                    print(e)

            pbar.update(1)

  0%|          | 60/20071 [06:52<38:59:28,  7.01s/it]

That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 8c6d3b0b746aa518c8842e4e54019d28 in your message.)


  1%|          | 126/20071 [15:12<47:16:45,  8.53s/it]

0