In [1]:
import os
import gzip
import itertools
import sys
import json
import glob
from tqdm.auto import tqdm
import numpy as np
import pycocotools.mask as mask_util
import io
import contextlib
from multiprocessing import Pool
import pickle
from scipy.stats import gmean
from scipy.special import softmax
from collections import defaultdict

In [2]:
PATH_TO_ANNS = "../data/val_annot.json"
PATH_TO_PREDS = "/scratch/shared/beegfs/prannay/ego4d_data/ckpt/results/traces_v2_all/vq_stats_val_{}.json.gz"
PATH_TO_DINO = "/scratch/shared/beegfs/prannay/ego4d_data/ckpt/results/traces_v2_all/vq_stats_val_{}_dino_scores_all.pkl"
assert os.path.exists(PATH_TO_ANNS)

In [3]:
with open(PATH_TO_ANNS, "r") as f:
    anns = json.load(f)

In [4]:
PRED_IDXS = range(0, 200)
path_to_preds = [PATH_TO_PREDS.format(i) for i in PRED_IDXS]
assert all([os.path.exists(p) for p in path_to_preds])
path_to_dinos = [PATH_TO_DINO.format(i) for i in PRED_IDXS]
assert all([os.path.exists(p) for p in path_to_dinos])

In [5]:
def preprocess_data(path):
    with gzip.open(path, "r") as f:
        data = json.load(f)
    data = data['predictions']
    # convert generic Dict[List] to List[Dict]
    out_list = []
    keys = sorted(data.keys())
    for idx in range(len(data[keys[0]])):
        save_dict = {k: data[k][idx] for k in keys}
        save_dict['dataset_uid'] = save_dict['dataset_uids']
        del save_dict['dataset_uids'] 
        out_list.append(save_dict)
    return out_list

In [6]:
def preprocess_dino(path):
    with open(path, "rb") as f:
        dino_data = pickle.load(f)
    return dict(dino_data)

In [11]:
def run_inner_episode(inputs):
    pred, gt, dino = inputs
    gt_fnos = [a['frame_number'] for a in gt['response_track']]
    pred_fnos = [a['frame_number'] for a in pred['groundtruth_response_tracks']]
    # print(len(pred['groundtruth_response_tracks']))
    assert gt_fnos == pred_fnos
    assert gt_fnos == list(range(min(gt_fnos), max(gt_fnos) + 1))
    assert (gt['dataset_uid'] == pred['dataset_uid'])
    # dino is a list
    import pdb; pdb.set_trace()
    gt_fnos = set(gt_fnos)
    gt_list = defaultdict(dict)
    assert len(dino) == len(pred['predicted_bboxes']) == len(pred['predicted_scores'])
    for d, bb, sc in zip(dino, pred['predicted_bboxes'], pred['predicted_scores']):
        if d['frame_number'] in gt_fnos:
            gt_list.append((d['frame_number'], bb, sc, d['dino_scores']))

In [12]:
def run_outer_episode(i, duid2gt=None):
    pred_path = PATH_TO_PREDS.format(i)
    dino_path = PATH_TO_DINO.format(i)
    save_path = pred_path.replace(".json.gz", "_sc_dino_gmean_gt_peak_tracker_init.json")
    if os.path.exists(save_path):
        return
    pred = preprocess_data(pred_path)
    dino = preprocess_dino(dino_path)
    duid2pred = {p['dataset_uid']: p for p in pred}
    valid_duids = set([ann['dataset_uid'] for ann in duid2gt.values() if ann['clip_uid'] is not None])
    duid2pred = {k: v for k, v in duid2pred.items() if k in valid_duids}
    duid2gt = {k: v for k, v in duid2gt.items() if k in duid2pred.keys()}
    duid2dino = {duid: dino[duid] for duid in duid2pred}
    assert len(duid2pred) == len(duid2gt) == len(duid2dino)
    inputs = [(duid2pred[k], duid2gt[k], duid2dino[k]) for k in sorted(duid2pred.keys())]
    outputs = [run_inner_episode(inp) for inp in inputs]
    with open(save_path, "w") as f:
        json.dump(outputs, f)

In [13]:
duid2gt = {g['dataset_uid']: g for g in anns if g['clip_uid'] is not None}
run_outer_episode(0, duid2gt=duid2gt)

--Return--
None
> [0;32m/tmp/ipykernel_66517/704524619.py[0m(10)[0;36mrun_inner_episode[0;34m()[0m
[0;32m      6 [0;31m    [0;32massert[0m [0mgt_fnos[0m [0;34m==[0m [0mpred_fnos[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      7 [0;31m    [0;32massert[0m [0mgt_fnos[0m [0;34m==[0m [0mlist[0m[0;34m([0m[0mrange[0m[0;34m([0m[0mmin[0m[0;34m([0m[0mgt_fnos[0m[0;34m)[0m[0;34m,[0m [0mmax[0m[0;34m([0m[0mgt_fnos[0m[0;34m)[0m [0;34m+[0m [0;36m1[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      8 [0;31m    [0;32massert[0m [0;34m([0m[0mgt[0m[0;34m[[0m[0;34m'dataset_uid'[0m[0;34m][0m [0;34m==[0m [0mpred[0m[0;34m[[0m[0;34m'dataset_uid'[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m    [0;31m# dino is a list[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[