In [1]:
import sys, os

env_root = '/N/project/baby_vision_curriculum/pythonenvs/hfenv/lib/python3.10/site-packages/'
sys.path.insert(0, env_root)

In [2]:
import numpy as np
# import torch, torchvision
# from torchvision import transforms as tr
from tqdm import tqdm
from pathlib import Path
# import math
import argparse
import pandas as pd
import warnings

from copy import deepcopy

from sklearn import svm, preprocessing
from sklearn.svm import LinearSVC
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split

In [3]:
from sklearn.metrics.pairwise import cosine_distances, euclidean_distances
import json 
import logging
logger = logging.getLogger(__name__)

In [4]:
logger



In [5]:
def topk_retrieval(feature_dir, cfg):
    """Extract features from test split and search on train split features."""
    logger.info('Loading local .npy files...')
    fold = cfg.get_int('dataset.fold')

    X_train = np.load(os.path.join(feature_dir, f'train_fold{fold}_feats.npy'))
    y_train = np.load(os.path.join(feature_dir, f'train_fold{fold}_labels.npy'))

    X_test = np.load(os.path.join(feature_dir, f'test_fold{fold}_feats.npy'))
    y_test = np.load(os.path.join(feature_dir, f'test_fold{fold}_labels.npy'))

    ks = [1, 5, 10, 20, 50]
    topk_correct = {k:0 for k in ks}

    distances = cosine_distances(X_test, X_train)
    indices = np.argsort(distances)

    for k in ks:
        # print(k)
        top_k_indices = indices[:, :k]
        # print(top_k_indices.shape, y_test.shape)
        for ind, test_label in zip(top_k_indices, y_test):
            labels = y_train[ind]
            if test_label in labels:
                # print(test_label, labels)
                topk_correct[k] += 1

    for k in ks:
        correct = topk_correct[k]
        total = len(X_test)
        logger.info('Top-{}, correct = {:.2f}, total = {}, acc = {:.3f}'.format(k, correct, total, correct/total))

    with open(os.path.join(feature_dir, f'topk_correct_fold{fold}.json'), 'w') as fp:
        json.dump(topk_correct, fp)
        
def get_separability_score_old(df, label, method='sgd', ret_preds=False):
    # method: sgd or svm
    le = preprocessing.LabelEncoder()
    y = le.fit_transform(df[label])
    
    X_cols = ['dim'+str(i)
              for i in range(768)]
    X = df[X_cols]
    
    
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
    
    if method=='svm':
        clf = make_pipeline(StandardScaler(),
                            LinearSVC(random_state=0, tol=1e-4))
    elif method=='sgd':
        clf = make_pipeline(StandardScaler(),
                            SGDClassifier(max_iter=5000, tol=1e-4, n_jobs=20))#, loss='log_loss'))
    else:
        raise ValueError()
    clf.fit(X_train, y_train)
    train_score = clf.score(X_train,y_train)
    test_score = clf.score(X_test,y_test)
    if ret_preds:
        preds = clf.predict(X_test)
        return train_score, test_score, preds, y_test
    return train_score, test_score

In [6]:
def get_nn_score(df_train, df_test, label, metric='cosine',
                 savedir=None, run_id=None):
#     Get nearest neighbor score
    
    le = preprocessing.LabelEncoder()
    y_train = le.fit_transform(df_train[label])
    
    X_cols = [col for col in df_train.columns if 'dim' in col]
    x_train = df_train[X_cols]
    
    if df_test is not None:
        x_test = df_test[X_cols]
        y_test = le.transform(df_test[label])
    
    ks = [1, 5, 10, 20, 50]
    topk_correct = {k:0 for k in ks}

    if metric=='cosine':
        distances = cosine_distances(x_test, x_train)
    else:
        distances = euclidean_distances(x_test, x_train)
    indices = np.argsort(distances)
    
    for k in ks:
        # print(k)
        top_k_indices = indices[:, :k]
        # print(top_k_indices.shape, y_test.shape)
        for ind, test_label in zip(top_k_indices, y_test):
            labels = y_train[ind]
            if test_label in labels:
                # print(test_label, labels)
                topk_correct[k] += 1
        topk_correct[k] = topk_correct[k]/len(y_test)
    
    for k in ks:
        correct = topk_correct[k]
        total = len(x_test)
        logger.info('Top-{}, correct = {:.2f}, total = {}, acc = {:.3f}'.format(k, correct, total, correct/total))
        print('Top-{}, correct = {:.2f}, total = {}, acc = {:.3f}'.format(k, correct, total, correct/total))

    if savedir is not None:
        if run_id is None:
            raise ValueError
            # run_id = ''
        with open(os.path.join(savedir, 
                               run_id+'_topk_correct.json'), 'w') as fp:
            json.dump(topk_correct, fp)

    return topk_correct#[1], topk_correct[5], topk_correct[10]

def get_separability_score(df_train, df_test, label, 
                           method='sgd', ret_preds=False,
                          n_jobs=80):
    # method: sgd or svm
    le = preprocessing.LabelEncoder()
    y_train = le.fit_transform(df_train[label])
    
    X_cols = [col for col in df_train.columns if 'dim' in col]
    x_train = df_train[X_cols]
    
    if df_test is not None:
        x_test = df_test[X_cols]
        y_test = le.transform(df_test[label])
    else:
        x_train, x_test, y_train, y_test = train_test_split(
            x_train, y_train, test_size=0.33, random_state=42)
    
    if method=='svm':
        clf = make_pipeline(StandardScaler(),
                            LinearSVC(random_state=0, tol=1e-4))
    elif method=='sgd':
        clf = make_pipeline(StandardScaler(),
                            SGDClassifier(max_iter=5000, tol=1e-4, n_jobs=n_jobs))#, loss='log_loss'))
    else:
        raise ValueError()
    clf.fit(x_train, y_train)
    train_score = clf.score(x_train,y_train)
    test_score = clf.score(x_test,y_test)
    if ret_preds:
        preds = clf.predict(x_test)
        return train_score, test_score, preds, y_test
    return train_score, test_score

In [7]:
class SSv2Eval():
    
    def __init__(self, label_paths=None):
        if label_paths is None:
            label_paths={'train':'/N/project/baby_vision_curriculum/benchmarks/ssv2/easy_labels/train_easy10.csv',
                         'test':'/N/project/baby_vision_curriculum/benchmarks/ssv2/easy_labels/val_easy10.csv'}        
        self.labels_df = {x: pd.read_csv(label_paths[x])
                         for x in ['train','test']}
        for phase in ['train', 'test']:
            self.labels_df[phase].set_index('fname', inplace=True)

    def get_categorylabel(self, fname, phase):
        return self.labels_df[phase].loc[str(fname)+'.webm', 'label']

    def add_labels_to_df(self, df, labels, phase):
        if 'category' in labels:
            df['category'] = df['fnames'].apply(
                lambda fname: self.get_categorylabel(fname, phase))
        return df
    
    def proc_train_test(self, data_fpaths, score_type, eval_type='linear', n_jobs=80):
        if score_type!='category':
            raise ValueError
        method='sgd'
#         data_fpaths = {'train':fp_train, }
        data_df = {}
        for phase in ['train', 'test']:
            data_df[phase] = pd.read_csv(
                data_fpaths[phase])
            data_df[phase] = self.add_labels_to_df(data_df[phase], ['category'], phase)
        
        if eval_type=='linear':
            train_score, test_score, preds, targets = get_separability_score(
                data_df['train'], data_df['test'], score_type, method=method, 
                ret_preds=True, n_jobs=n_jobs)
        else:
            test_score = get_nn_score(
                data_df['train'], data_df['test'], score_type, metric='cosine',
                 savedir=None, run_id=None)
        
        return test_score

In [22]:
class UCF101Eval():
    
    def __init__(self):
        pass
    
    def add_labels_to_df(self, df):
        df['category'] = df['fnames']#.apply(lambda fname: self.get_categorylabel(fname))
        return df
    
    def proc_train_test(self, data_fpaths, score_type, eval_type='linear', n_jobs=80):
        if score_type!='category':
            raise ValueError
        method='sgd'
#         data_fpaths = {'train':fp_train, }
        data_df = {}
        for phase in ['train', 'test']:
            data_df[phase] = pd.read_csv(
                data_fpaths[phase])
            data_df[phase] = self.add_labels_to_df(data_df[phase])
        
        if eval_type=='linear':
            train_score, test_score, preds, targets = get_separability_score(
                data_df['train'], data_df['test'], score_type, method=method, 
                ret_preds=True, n_jobs=n_jobs)
        else:
            test_score = get_nn_score(
                data_df['train'], data_df['test'], score_type, metric='cosine',
                 savedir=None, run_id=None)
        
        return test_score

In [23]:
def get_traingroups(curr, stage):
    if curr=='dev':
        return 'g0g1g2'[:2*stage]
    elif curr=='adev':
        return 'g2g1g0'[:2*stage]
    else:
        return 'na'
    
def parse_fname(fp):
# embeddings_adev_1_g2_default_0_246.csv'
    parts = Path(fp).stem.split('_')
    prefix, curr, stage, current_gr, cond, fold, seed = parts
    stage=int(stage)
    train_gr = get_traingroups(curr, stage)
    
    tag_dict = {
        'Curriculum':curr,
        'Stage': stage,
        'Condition':cond,
        'Seed': seed,
        'Train Groups': train_gr,
        'data_id': '_'.join([curr, seed, cond])
    }
    return tag_dict

In [24]:
def proc_result_folder(emb_root, ds_task, iter_per_stage, eval_type, 
                       n_jobs=80):
    if ds_task=='ssv2':
        evaluator = SSv2Eval()
        ds_task = 'category'
    elif ds_task=='tb_cat':
        evaluator = ToyBoxEval()
        ds_task = 'category'
    elif ds_task=='tb_trans':
        evaluator = ToyBoxEval()
        ds_task = 'transformation'
    elif ds_task=='cifar10':
        evaluator = Cifar10Eval()
        ds_task = 'category'
    elif ds_task=='ucf101':
        evaluator = UCF101Eval()
        ds_task = 'category'
        n_jobs_external=n_jobs
        n_jobs_internal=1
    else:
        raise ValueError
    record_list = []
    
    train_test_fp_list = []
    for fname in os.listdir(emb_root):
        if Path(emb_root+fname).suffix!='.csv':
            continue
        train_fp = emb_root+fname
        test_fp = str(Path(emb_root, 'test/', fname))
        
        if not os.path.exists(test_fp):
            print(test_fp, 'does not exist')
            continue
        train_test_fp_list.append(
            {'train':train_fp,
             'test':test_fp})
    

    for fp_dict in tqdm(train_test_fp_list):
#             print(fp)
        record = parse_fname(fp_dict['train'])
#         if '141' in fp:
#             continue
        if eval_type=='linear':
            record[ds_task] = evaluator.proc_train_test(fp_dict, ds_task,
                                                       eval_type=eval_type, 
                                                        n_jobs=n_jobs)
        else:
            topkcorrect = evaluator.proc_train_test(fp_dict, ds_task,
                                                       eval_type=eval_type, 
                                                        n_jobs=n_jobs)
            record['Top1'] = topkcorrect[1]
            record['Top5'] = topkcorrect[5]
            record['Top10'] = topkcorrect[10]
        record['Iteration']=iter_per_stage*record['Stage']
        record_list.append(deepcopy(record))
        
    df = pd.DataFrame.from_records(record_list)         
    return df

In [14]:
learner = 'generative/v3/'#'contrastive/v1/'#'predictive/v1/'#
ds_task = 'ssv2'
date = 'aug11'#'aug1'#'jul315'#
num_ep,iter_per_ep = 20, 1500 #5, 2000#2, 5000

eval_type='nn'

emb_root = '/N/project/baby_vision_curriculum/trained_models/'+learner+date+'/benchmarks/'+ds_task+'/'
# emb_root = '/N/project/baby_vision_curriculum/trained_models/generative/v3/jul28dev/benchmarks/ssv2/'

iter_per_stage = num_ep*iter_per_ep

df_ss = proc_result_folder(emb_root, ds_task, iter_per_stage, eval_type, n_jobs=23)
# df_ss.to_csv(date+'_'+ds_task+'_'+eval_type+'_score.csv', index=False)

/N/project/baby_vision_curriculum/trained_models/generative/v3/aug11/benchmarks/ssv2/test/embeddings_adult_0_na_default_0_112.csv does not exist
/N/project/baby_vision_curriculum/trained_models/generative/v3/aug11/benchmarks/ssv2/test/embeddings_adev_0_na_default_0_111.csv does not exist
/N/project/baby_vision_curriculum/trained_models/generative/v3/aug11/benchmarks/ssv2/test/embeddings_adult_0_na_default_0_111.csv does not exist
/N/project/baby_vision_curriculum/trained_models/generative/v3/aug11/benchmarks/ssv2/test/embeddings_adev_0_na_default_0_113.csv does not exist
/N/project/baby_vision_curriculum/trained_models/generative/v3/aug11/benchmarks/ssv2/test/embeddings_dev_0_na_default_0_112.csv does not exist
/N/project/baby_vision_curriculum/trained_models/generative/v3/aug11/benchmarks/ssv2/test/embeddings_rnd_0_na_default_0_112.csv does not exist
/N/project/baby_vision_curriculum/trained_models/generative/v3/aug11/benchmarks/ssv2/test/embeddings_dev_0_na_default_0_113.csv does not

  3%|█▏                                          | 1/36 [00:02<01:24,  2.41s/it]

Top-1, correct = 0.24, total = 1000, acc = 0.000
Top-5, correct = 0.59, total = 1000, acc = 0.001
Top-10, correct = 0.79, total = 1000, acc = 0.001
Top-20, correct = 0.91, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


  6%|██▍                                         | 2/36 [00:04<01:17,  2.28s/it]

Top-1, correct = 0.24, total = 1000, acc = 0.000
Top-5, correct = 0.61, total = 1000, acc = 0.001
Top-10, correct = 0.77, total = 1000, acc = 0.001
Top-20, correct = 0.89, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


  8%|███▋                                        | 3/36 [00:07<01:17,  2.35s/it]

Top-1, correct = 0.21, total = 1000, acc = 0.000
Top-5, correct = 0.55, total = 1000, acc = 0.001
Top-10, correct = 0.74, total = 1000, acc = 0.001
Top-20, correct = 0.89, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


 11%|████▉                                       | 4/36 [00:09<01:16,  2.38s/it]

Top-1, correct = 0.12, total = 1000, acc = 0.000
Top-5, correct = 0.43, total = 1000, acc = 0.000
Top-10, correct = 0.66, total = 1000, acc = 0.001
Top-20, correct = 0.84, total = 1000, acc = 0.001
Top-50, correct = 0.98, total = 1000, acc = 0.001


 14%|██████                                      | 5/36 [00:12<01:17,  2.49s/it]

Top-1, correct = 0.22, total = 1000, acc = 0.000
Top-5, correct = 0.57, total = 1000, acc = 0.001
Top-10, correct = 0.76, total = 1000, acc = 0.001
Top-20, correct = 0.90, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


 17%|███████▎                                    | 6/36 [00:14<01:14,  2.47s/it]

Top-1, correct = 0.29, total = 1000, acc = 0.000
Top-5, correct = 0.67, total = 1000, acc = 0.001
Top-10, correct = 0.80, total = 1000, acc = 0.001
Top-20, correct = 0.90, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


 19%|████████▌                                   | 7/36 [00:17<01:11,  2.47s/it]

Top-1, correct = 0.26, total = 1000, acc = 0.000
Top-5, correct = 0.63, total = 1000, acc = 0.001
Top-10, correct = 0.79, total = 1000, acc = 0.001
Top-20, correct = 0.90, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


 22%|█████████▊                                  | 8/36 [00:19<01:09,  2.50s/it]

Top-1, correct = 0.28, total = 1000, acc = 0.000
Top-5, correct = 0.65, total = 1000, acc = 0.001
Top-10, correct = 0.79, total = 1000, acc = 0.001
Top-20, correct = 0.91, total = 1000, acc = 0.001
Top-50, correct = 0.98, total = 1000, acc = 0.001


 25%|███████████                                 | 9/36 [00:22<01:07,  2.49s/it]

Top-1, correct = 0.22, total = 1000, acc = 0.000
Top-5, correct = 0.58, total = 1000, acc = 0.001
Top-10, correct = 0.76, total = 1000, acc = 0.001
Top-20, correct = 0.91, total = 1000, acc = 0.001
Top-50, correct = 0.98, total = 1000, acc = 0.001


 28%|███████████▉                               | 10/36 [00:24<01:04,  2.49s/it]

Top-1, correct = 0.25, total = 1000, acc = 0.000
Top-5, correct = 0.63, total = 1000, acc = 0.001
Top-10, correct = 0.79, total = 1000, acc = 0.001
Top-20, correct = 0.91, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


 31%|█████████████▏                             | 11/36 [00:27<01:03,  2.54s/it]

Top-1, correct = 0.20, total = 1000, acc = 0.000
Top-5, correct = 0.54, total = 1000, acc = 0.001
Top-10, correct = 0.73, total = 1000, acc = 0.001
Top-20, correct = 0.87, total = 1000, acc = 0.001
Top-50, correct = 0.98, total = 1000, acc = 0.001


 33%|██████████████▎                            | 12/36 [00:29<01:01,  2.56s/it]

Top-1, correct = 0.22, total = 1000, acc = 0.000
Top-5, correct = 0.60, total = 1000, acc = 0.001
Top-10, correct = 0.78, total = 1000, acc = 0.001
Top-20, correct = 0.89, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


 36%|███████████████▌                           | 13/36 [00:32<00:59,  2.59s/it]

Top-1, correct = 0.17, total = 1000, acc = 0.000
Top-5, correct = 0.52, total = 1000, acc = 0.001
Top-10, correct = 0.73, total = 1000, acc = 0.001
Top-20, correct = 0.87, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


 39%|████████████████▋                          | 14/36 [00:35<00:56,  2.57s/it]

Top-1, correct = 0.22, total = 1000, acc = 0.000
Top-5, correct = 0.58, total = 1000, acc = 0.001
Top-10, correct = 0.75, total = 1000, acc = 0.001
Top-20, correct = 0.89, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


 42%|█████████████████▉                         | 15/36 [00:37<00:52,  2.49s/it]

Top-1, correct = 0.26, total = 1000, acc = 0.000
Top-5, correct = 0.62, total = 1000, acc = 0.001
Top-10, correct = 0.78, total = 1000, acc = 0.001
Top-20, correct = 0.89, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


 44%|███████████████████                        | 16/36 [00:39<00:48,  2.45s/it]

Top-1, correct = 0.16, total = 1000, acc = 0.000
Top-5, correct = 0.50, total = 1000, acc = 0.000
Top-10, correct = 0.69, total = 1000, acc = 0.001
Top-20, correct = 0.86, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


 47%|████████████████████▎                      | 17/36 [00:41<00:45,  2.40s/it]

Top-1, correct = 0.31, total = 1000, acc = 0.000
Top-5, correct = 0.65, total = 1000, acc = 0.001
Top-10, correct = 0.80, total = 1000, acc = 0.001
Top-20, correct = 0.91, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


 50%|█████████████████████▌                     | 18/36 [00:44<00:43,  2.40s/it]

Top-1, correct = 0.24, total = 1000, acc = 0.000
Top-5, correct = 0.59, total = 1000, acc = 0.001
Top-10, correct = 0.77, total = 1000, acc = 0.001
Top-20, correct = 0.90, total = 1000, acc = 0.001
Top-50, correct = 0.98, total = 1000, acc = 0.001


 53%|██████████████████████▋                    | 19/36 [00:46<00:40,  2.36s/it]

Top-1, correct = 0.18, total = 1000, acc = 0.000
Top-5, correct = 0.52, total = 1000, acc = 0.001
Top-10, correct = 0.71, total = 1000, acc = 0.001
Top-20, correct = 0.88, total = 1000, acc = 0.001
Top-50, correct = 0.98, total = 1000, acc = 0.001


 56%|███████████████████████▉                   | 20/36 [00:48<00:37,  2.36s/it]

Top-1, correct = 0.21, total = 1000, acc = 0.000
Top-5, correct = 0.55, total = 1000, acc = 0.001
Top-10, correct = 0.75, total = 1000, acc = 0.001
Top-20, correct = 0.89, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


 58%|█████████████████████████                  | 21/36 [00:51<00:35,  2.34s/it]

Top-1, correct = 0.25, total = 1000, acc = 0.000
Top-5, correct = 0.60, total = 1000, acc = 0.001
Top-10, correct = 0.80, total = 1000, acc = 0.001
Top-20, correct = 0.90, total = 1000, acc = 0.001
Top-50, correct = 0.98, total = 1000, acc = 0.001


 61%|██████████████████████████▎                | 22/36 [00:53<00:33,  2.38s/it]

Top-1, correct = 0.27, total = 1000, acc = 0.000
Top-5, correct = 0.64, total = 1000, acc = 0.001
Top-10, correct = 0.78, total = 1000, acc = 0.001
Top-20, correct = 0.91, total = 1000, acc = 0.001
Top-50, correct = 0.98, total = 1000, acc = 0.001


 64%|███████████████████████████▍               | 23/36 [00:56<00:30,  2.37s/it]

Top-1, correct = 0.28, total = 1000, acc = 0.000
Top-5, correct = 0.63, total = 1000, acc = 0.001
Top-10, correct = 0.78, total = 1000, acc = 0.001
Top-20, correct = 0.90, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


 67%|████████████████████████████▋              | 24/36 [00:58<00:28,  2.33s/it]

Top-1, correct = 0.12, total = 1000, acc = 0.000
Top-5, correct = 0.44, total = 1000, acc = 0.000
Top-10, correct = 0.65, total = 1000, acc = 0.001
Top-20, correct = 0.86, total = 1000, acc = 0.001
Top-50, correct = 0.99, total = 1000, acc = 0.001


 69%|█████████████████████████████▊             | 25/36 [01:00<00:25,  2.35s/it]

Top-1, correct = 0.24, total = 1000, acc = 0.000
Top-5, correct = 0.62, total = 1000, acc = 0.001
Top-10, correct = 0.79, total = 1000, acc = 0.001
Top-20, correct = 0.91, total = 1000, acc = 0.001
Top-50, correct = 0.98, total = 1000, acc = 0.001


 72%|███████████████████████████████            | 26/36 [01:02<00:23,  2.32s/it]

Top-1, correct = 0.22, total = 1000, acc = 0.000
Top-5, correct = 0.60, total = 1000, acc = 0.001
Top-10, correct = 0.78, total = 1000, acc = 0.001
Top-20, correct = 0.89, total = 1000, acc = 0.001
Top-50, correct = 0.98, total = 1000, acc = 0.001


 75%|████████████████████████████████▎          | 27/36 [01:05<00:20,  2.30s/it]

Top-1, correct = 0.25, total = 1000, acc = 0.000
Top-5, correct = 0.61, total = 1000, acc = 0.001
Top-10, correct = 0.79, total = 1000, acc = 0.001
Top-20, correct = 0.90, total = 1000, acc = 0.001
Top-50, correct = 0.96, total = 1000, acc = 0.001


 78%|█████████████████████████████████▍         | 28/36 [01:07<00:18,  2.32s/it]

Top-1, correct = 0.11, total = 1000, acc = 0.000
Top-5, correct = 0.41, total = 1000, acc = 0.000
Top-10, correct = 0.64, total = 1000, acc = 0.001
Top-20, correct = 0.85, total = 1000, acc = 0.001
Top-50, correct = 0.99, total = 1000, acc = 0.001


 81%|██████████████████████████████████▋        | 29/36 [01:09<00:16,  2.31s/it]

Top-1, correct = 0.23, total = 1000, acc = 0.000
Top-5, correct = 0.60, total = 1000, acc = 0.001
Top-10, correct = 0.78, total = 1000, acc = 0.001
Top-20, correct = 0.90, total = 1000, acc = 0.001
Top-50, correct = 0.98, total = 1000, acc = 0.001


 83%|███████████████████████████████████▊       | 30/36 [01:12<00:13,  2.32s/it]

Top-1, correct = 0.21, total = 1000, acc = 0.000
Top-5, correct = 0.60, total = 1000, acc = 0.001
Top-10, correct = 0.76, total = 1000, acc = 0.001
Top-20, correct = 0.89, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


 86%|█████████████████████████████████████      | 31/36 [01:14<00:11,  2.32s/it]

Top-1, correct = 0.24, total = 1000, acc = 0.000
Top-5, correct = 0.63, total = 1000, acc = 0.001
Top-10, correct = 0.78, total = 1000, acc = 0.001
Top-20, correct = 0.90, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


 89%|██████████████████████████████████████▏    | 32/36 [01:16<00:09,  2.33s/it]

Top-1, correct = 0.20, total = 1000, acc = 0.000
Top-5, correct = 0.59, total = 1000, acc = 0.001
Top-10, correct = 0.77, total = 1000, acc = 0.001
Top-20, correct = 0.89, total = 1000, acc = 0.001
Top-50, correct = 0.98, total = 1000, acc = 0.001


 92%|███████████████████████████████████████▍   | 33/36 [01:19<00:06,  2.32s/it]

Top-1, correct = 0.23, total = 1000, acc = 0.000
Top-5, correct = 0.59, total = 1000, acc = 0.001
Top-10, correct = 0.77, total = 1000, acc = 0.001
Top-20, correct = 0.89, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


 94%|████████████████████████████████████████▌  | 34/36 [01:21<00:04,  2.37s/it]

Top-1, correct = 0.23, total = 1000, acc = 0.000
Top-5, correct = 0.58, total = 1000, acc = 0.001
Top-10, correct = 0.76, total = 1000, acc = 0.001
Top-20, correct = 0.89, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


 97%|█████████████████████████████████████████▊ | 35/36 [01:24<00:02,  2.37s/it]

Top-1, correct = 0.22, total = 1000, acc = 0.000
Top-5, correct = 0.58, total = 1000, acc = 0.001
Top-10, correct = 0.77, total = 1000, acc = 0.001
Top-20, correct = 0.90, total = 1000, acc = 0.001
Top-50, correct = 0.97, total = 1000, acc = 0.001


100%|███████████████████████████████████████████| 36/36 [01:26<00:00,  2.40s/it]

Top-1, correct = 0.25, total = 1000, acc = 0.000
Top-5, correct = 0.61, total = 1000, acc = 0.001
Top-10, correct = 0.79, total = 1000, acc = 0.001
Top-20, correct = 0.91, total = 1000, acc = 0.001
Top-50, correct = 0.98, total = 1000, acc = 0.001





In [16]:
df_ss.groupby(['Stage', 'Condition', 'Curriculum']).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Top1,Top5,Top10,Iteration
Stage,Condition,Curriculum,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1,default,adev,0.116667,0.427,0.65,30000.0
1,default,adult,0.19,0.535,0.731333,30000.0
1,default,dev,0.217333,0.579,0.756,30000.0
1,default,rnd,0.190667,0.530333,0.724333,30000.0
2,default,adev,0.224333,0.602333,0.773667,60000.0
2,default,adult,0.222,0.580333,0.760333,60000.0
2,default,dev,0.253667,0.625,0.779667,60000.0
2,default,rnd,0.235,0.594,0.778667,60000.0
3,default,adev,0.241333,0.609,0.783333,90000.0
3,default,adult,0.245333,0.610333,0.79,90000.0


In [17]:
df_ss.to_csv(date+'_'+ds_task+'_'+eval_type+'_score.csv', index=False)

In [25]:
learner = 'generative/v3/'#'contrastive/v1/'#'predictive/v1/'#
ds_task = 'ucf101'
date = 'aug11'#'aug1'#'jul315'#
num_ep,iter_per_ep = 20, 1500 #5, 2000#2, 5000

eval_type='nn'

emb_root = '/N/project/baby_vision_curriculum/trained_models/'+learner+date+'/benchmarks/'+ds_task+'/'
# emb_root = '/N/project/baby_vision_curriculum/trained_models/generative/v3/jul28dev/benchmarks/ssv2/'

iter_per_stage = num_ep*iter_per_ep

df_ucf = proc_result_folder(emb_root, ds_task, iter_per_stage, eval_type, n_jobs=23)
# df_ucf.to_csv(date+'_'+ds_task+'_'+eval_type+'_score.csv', index=False)

  3%|█▏                                          | 1/36 [00:04<02:45,  4.73s/it]

Top-1, correct = 0.28, total = 3747, acc = 0.000
Top-5, correct = 0.43, total = 3747, acc = 0.000
Top-10, correct = 0.51, total = 3747, acc = 0.000
Top-20, correct = 0.59, total = 3747, acc = 0.000
Top-50, correct = 0.73, total = 3747, acc = 0.000


  6%|██▍                                         | 2/36 [00:09<02:39,  4.69s/it]

Top-1, correct = 0.25, total = 3758, acc = 0.000
Top-5, correct = 0.42, total = 3758, acc = 0.000
Top-10, correct = 0.51, total = 3758, acc = 0.000
Top-20, correct = 0.60, total = 3758, acc = 0.000
Top-50, correct = 0.73, total = 3758, acc = 0.000


  8%|███▋                                        | 3/36 [00:13<02:33,  4.64s/it]

Top-1, correct = 0.26, total = 3755, acc = 0.000
Top-5, correct = 0.38, total = 3755, acc = 0.000
Top-10, correct = 0.46, total = 3755, acc = 0.000
Top-20, correct = 0.56, total = 3755, acc = 0.000
Top-50, correct = 0.70, total = 3755, acc = 0.000


 11%|████▉                                       | 4/36 [00:18<02:29,  4.67s/it]

Top-1, correct = 0.17, total = 3747, acc = 0.000
Top-5, correct = 0.29, total = 3747, acc = 0.000
Top-10, correct = 0.36, total = 3747, acc = 0.000
Top-20, correct = 0.46, total = 3747, acc = 0.000
Top-50, correct = 0.62, total = 3747, acc = 0.000


 14%|██████                                      | 5/36 [00:23<02:22,  4.58s/it]

Top-1, correct = 0.25, total = 3750, acc = 0.000
Top-5, correct = 0.39, total = 3750, acc = 0.000
Top-10, correct = 0.48, total = 3750, acc = 0.000
Top-20, correct = 0.58, total = 3750, acc = 0.000
Top-50, correct = 0.71, total = 3750, acc = 0.000


 17%|███████▎                                    | 6/36 [00:27<02:15,  4.53s/it]

Top-1, correct = 0.28, total = 3751, acc = 0.000
Top-5, correct = 0.45, total = 3751, acc = 0.000
Top-10, correct = 0.54, total = 3751, acc = 0.000
Top-20, correct = 0.64, total = 3751, acc = 0.000
Top-50, correct = 0.77, total = 3751, acc = 0.000


 19%|████████▌                                   | 7/36 [00:32<02:11,  4.52s/it]

Top-1, correct = 0.26, total = 3757, acc = 0.000
Top-5, correct = 0.42, total = 3757, acc = 0.000
Top-10, correct = 0.52, total = 3757, acc = 0.000
Top-20, correct = 0.62, total = 3757, acc = 0.000
Top-50, correct = 0.75, total = 3757, acc = 0.000


 22%|█████████▊                                  | 8/36 [00:36<02:05,  4.46s/it]

Top-1, correct = 0.28, total = 3749, acc = 0.000
Top-5, correct = 0.44, total = 3749, acc = 0.000
Top-10, correct = 0.54, total = 3749, acc = 0.000
Top-20, correct = 0.64, total = 3749, acc = 0.000
Top-50, correct = 0.77, total = 3749, acc = 0.000


 25%|███████████                                 | 9/36 [00:40<02:00,  4.45s/it]

Top-1, correct = 0.27, total = 3744, acc = 0.000
Top-5, correct = 0.43, total = 3744, acc = 0.000
Top-10, correct = 0.52, total = 3744, acc = 0.000
Top-20, correct = 0.62, total = 3744, acc = 0.000
Top-50, correct = 0.74, total = 3744, acc = 0.000


 28%|███████████▉                               | 10/36 [00:45<01:55,  4.43s/it]

Top-1, correct = 0.28, total = 3749, acc = 0.000
Top-5, correct = 0.44, total = 3749, acc = 0.000
Top-10, correct = 0.53, total = 3749, acc = 0.000
Top-20, correct = 0.62, total = 3749, acc = 0.000
Top-50, correct = 0.75, total = 3749, acc = 0.000


 31%|█████████████▏                             | 11/36 [00:49<01:50,  4.41s/it]

Top-1, correct = 0.21, total = 3754, acc = 0.000
Top-5, correct = 0.34, total = 3754, acc = 0.000
Top-10, correct = 0.42, total = 3754, acc = 0.000
Top-20, correct = 0.52, total = 3754, acc = 0.000
Top-50, correct = 0.66, total = 3754, acc = 0.000


 33%|██████████████▎                            | 12/36 [00:53<01:45,  4.38s/it]

Top-1, correct = 0.27, total = 3756, acc = 0.000
Top-5, correct = 0.42, total = 3756, acc = 0.000
Top-10, correct = 0.51, total = 3756, acc = 0.000
Top-20, correct = 0.61, total = 3756, acc = 0.000
Top-50, correct = 0.73, total = 3756, acc = 0.000


 36%|███████████████▌                           | 13/36 [00:58<01:43,  4.50s/it]

Top-1, correct = 0.20, total = 3760, acc = 0.000
Top-5, correct = 0.33, total = 3760, acc = 0.000
Top-10, correct = 0.41, total = 3760, acc = 0.000
Top-20, correct = 0.51, total = 3760, acc = 0.000
Top-50, correct = 0.66, total = 3760, acc = 0.000


 39%|████████████████▋                          | 14/36 [01:02<01:37,  4.44s/it]

Top-1, correct = 0.26, total = 3741, acc = 0.000
Top-5, correct = 0.40, total = 3741, acc = 0.000
Top-10, correct = 0.49, total = 3741, acc = 0.000
Top-20, correct = 0.60, total = 3741, acc = 0.000
Top-50, correct = 0.73, total = 3741, acc = 0.000


 42%|█████████████████▉                         | 15/36 [01:07<01:32,  4.43s/it]

Top-1, correct = 0.27, total = 3748, acc = 0.000
Top-5, correct = 0.42, total = 3748, acc = 0.000
Top-10, correct = 0.50, total = 3748, acc = 0.000
Top-20, correct = 0.60, total = 3748, acc = 0.000
Top-50, correct = 0.73, total = 3748, acc = 0.000


 44%|███████████████████                        | 16/36 [01:11<01:27,  4.38s/it]

Top-1, correct = 0.19, total = 3750, acc = 0.000
Top-5, correct = 0.31, total = 3750, acc = 0.000
Top-10, correct = 0.39, total = 3750, acc = 0.000
Top-20, correct = 0.49, total = 3750, acc = 0.000
Top-50, correct = 0.64, total = 3750, acc = 0.000


 47%|████████████████████▎                      | 17/36 [01:15<01:22,  4.35s/it]

Top-1, correct = 0.31, total = 3754, acc = 0.000
Top-5, correct = 0.46, total = 3754, acc = 0.000
Top-10, correct = 0.56, total = 3754, acc = 0.000
Top-20, correct = 0.65, total = 3754, acc = 0.000
Top-50, correct = 0.78, total = 3754, acc = 0.000


 50%|█████████████████████▌                     | 18/36 [01:20<01:17,  4.33s/it]

Top-1, correct = 0.26, total = 3744, acc = 0.000
Top-5, correct = 0.41, total = 3744, acc = 0.000
Top-10, correct = 0.50, total = 3744, acc = 0.000
Top-20, correct = 0.60, total = 3744, acc = 0.000
Top-50, correct = 0.73, total = 3744, acc = 0.000


 53%|██████████████████████▋                    | 19/36 [01:24<01:13,  4.31s/it]

Top-1, correct = 0.20, total = 3756, acc = 0.000
Top-5, correct = 0.32, total = 3756, acc = 0.000
Top-10, correct = 0.41, total = 3756, acc = 0.000
Top-20, correct = 0.50, total = 3756, acc = 0.000
Top-50, correct = 0.64, total = 3756, acc = 0.000


 56%|███████████████████████▉                   | 20/36 [01:28<01:09,  4.32s/it]

Top-1, correct = 0.22, total = 3755, acc = 0.000
Top-5, correct = 0.34, total = 3755, acc = 0.000
Top-10, correct = 0.42, total = 3755, acc = 0.000
Top-20, correct = 0.52, total = 3755, acc = 0.000
Top-50, correct = 0.67, total = 3755, acc = 0.000


 58%|█████████████████████████                  | 21/36 [01:33<01:04,  4.30s/it]

Top-1, correct = 0.27, total = 3755, acc = 0.000
Top-5, correct = 0.41, total = 3755, acc = 0.000
Top-10, correct = 0.52, total = 3755, acc = 0.000
Top-20, correct = 0.61, total = 3755, acc = 0.000
Top-50, correct = 0.74, total = 3755, acc = 0.000


 61%|██████████████████████████▎                | 22/36 [01:37<01:00,  4.30s/it]

Top-1, correct = 0.27, total = 3749, acc = 0.000
Top-5, correct = 0.42, total = 3749, acc = 0.000
Top-10, correct = 0.51, total = 3749, acc = 0.000
Top-20, correct = 0.60, total = 3749, acc = 0.000
Top-50, correct = 0.73, total = 3749, acc = 0.000


 64%|███████████████████████████▍               | 23/36 [01:41<00:55,  4.29s/it]

Top-1, correct = 0.26, total = 3754, acc = 0.000
Top-5, correct = 0.42, total = 3754, acc = 0.000
Top-10, correct = 0.51, total = 3754, acc = 0.000
Top-20, correct = 0.60, total = 3754, acc = 0.000
Top-50, correct = 0.74, total = 3754, acc = 0.000


 67%|████████████████████████████▋              | 24/36 [01:45<00:51,  4.26s/it]

Top-1, correct = 0.11, total = 3735, acc = 0.000
Top-5, correct = 0.20, total = 3735, acc = 0.000
Top-10, correct = 0.26, total = 3735, acc = 0.000
Top-20, correct = 0.36, total = 3735, acc = 0.000
Top-50, correct = 0.51, total = 3735, acc = 0.000


 69%|█████████████████████████████▊             | 25/36 [01:50<00:46,  4.27s/it]

Top-1, correct = 0.27, total = 3760, acc = 0.000
Top-5, correct = 0.43, total = 3760, acc = 0.000
Top-10, correct = 0.52, total = 3760, acc = 0.000
Top-20, correct = 0.61, total = 3760, acc = 0.000
Top-50, correct = 0.75, total = 3760, acc = 0.000


 72%|███████████████████████████████            | 26/36 [01:54<00:42,  4.25s/it]

Top-1, correct = 0.25, total = 3758, acc = 0.000
Top-5, correct = 0.41, total = 3758, acc = 0.000
Top-10, correct = 0.50, total = 3758, acc = 0.000
Top-20, correct = 0.60, total = 3758, acc = 0.000
Top-50, correct = 0.73, total = 3758, acc = 0.000


 75%|████████████████████████████████▎          | 27/36 [01:58<00:38,  4.24s/it]

Top-1, correct = 0.27, total = 3753, acc = 0.000
Top-5, correct = 0.42, total = 3753, acc = 0.000
Top-10, correct = 0.51, total = 3753, acc = 0.000
Top-20, correct = 0.61, total = 3753, acc = 0.000
Top-50, correct = 0.72, total = 3753, acc = 0.000


 78%|█████████████████████████████████▍         | 28/36 [02:02<00:34,  4.27s/it]

Top-1, correct = 0.12, total = 3749, acc = 0.000
Top-5, correct = 0.21, total = 3749, acc = 0.000
Top-10, correct = 0.28, total = 3749, acc = 0.000
Top-20, correct = 0.38, total = 3749, acc = 0.000
Top-50, correct = 0.54, total = 3749, acc = 0.000


 81%|██████████████████████████████████▋        | 29/36 [02:07<00:30,  4.29s/it]

Top-1, correct = 0.28, total = 3752, acc = 0.000
Top-5, correct = 0.43, total = 3752, acc = 0.000
Top-10, correct = 0.52, total = 3752, acc = 0.000
Top-20, correct = 0.62, total = 3752, acc = 0.000
Top-50, correct = 0.75, total = 3752, acc = 0.000


 83%|███████████████████████████████████▊       | 30/36 [02:11<00:25,  4.30s/it]

Top-1, correct = 0.26, total = 3756, acc = 0.000
Top-5, correct = 0.40, total = 3756, acc = 0.000
Top-10, correct = 0.48, total = 3756, acc = 0.000
Top-20, correct = 0.57, total = 3756, acc = 0.000
Top-50, correct = 0.71, total = 3756, acc = 0.000


 86%|█████████████████████████████████████      | 31/36 [02:15<00:21,  4.29s/it]

Top-1, correct = 0.26, total = 3743, acc = 0.000
Top-5, correct = 0.41, total = 3743, acc = 0.000
Top-10, correct = 0.50, total = 3743, acc = 0.000
Top-20, correct = 0.59, total = 3743, acc = 0.000
Top-50, correct = 0.72, total = 3743, acc = 0.000


 89%|██████████████████████████████████████▏    | 32/36 [02:20<00:17,  4.28s/it]

Top-1, correct = 0.25, total = 3750, acc = 0.000
Top-5, correct = 0.39, total = 3750, acc = 0.000
Top-10, correct = 0.49, total = 3750, acc = 0.000
Top-20, correct = 0.59, total = 3750, acc = 0.000
Top-50, correct = 0.73, total = 3750, acc = 0.000


 92%|███████████████████████████████████████▍   | 33/36 [02:24<00:12,  4.27s/it]

Top-1, correct = 0.23, total = 3752, acc = 0.000
Top-5, correct = 0.37, total = 3752, acc = 0.000
Top-10, correct = 0.46, total = 3752, acc = 0.000
Top-20, correct = 0.55, total = 3752, acc = 0.000
Top-50, correct = 0.69, total = 3752, acc = 0.000


 94%|████████████████████████████████████████▌  | 34/36 [02:28<00:08,  4.26s/it]

Top-1, correct = 0.27, total = 3760, acc = 0.000
Top-5, correct = 0.40, total = 3760, acc = 0.000
Top-10, correct = 0.48, total = 3760, acc = 0.000
Top-20, correct = 0.57, total = 3760, acc = 0.000
Top-50, correct = 0.70, total = 3760, acc = 0.000


 97%|█████████████████████████████████████████▊ | 35/36 [02:32<00:04,  4.27s/it]

Top-1, correct = 0.26, total = 3758, acc = 0.000
Top-5, correct = 0.40, total = 3758, acc = 0.000
Top-10, correct = 0.49, total = 3758, acc = 0.000
Top-20, correct = 0.58, total = 3758, acc = 0.000
Top-50, correct = 0.72, total = 3758, acc = 0.000


100%|███████████████████████████████████████████| 36/36 [02:37<00:00,  4.36s/it]

Top-1, correct = 0.25, total = 3744, acc = 0.000
Top-5, correct = 0.42, total = 3744, acc = 0.000
Top-10, correct = 0.51, total = 3744, acc = 0.000
Top-20, correct = 0.62, total = 3744, acc = 0.000
Top-50, correct = 0.75, total = 3744, acc = 0.000





In [26]:
df_ucf.to_csv(date+'_'+ds_task+'_'+eval_type+'_score.csv', index=False)

In [27]:
df_ucf.groupby(['Stage', 'Condition', 'Curriculum']).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Top1,Top5,Top10,Iteration
Stage,Condition,Curriculum,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1,default,adev,0.132735,0.233246,0.301628,30000.0
1,default,adult,0.219533,0.34816,0.431955,30000.0
1,default,dev,0.262619,0.393484,0.47502,30000.0
1,default,rnd,0.204273,0.328796,0.41042,30000.0
2,default,adev,0.247142,0.391388,0.483072,60000.0
2,default,adult,0.264342,0.410753,0.498722,60000.0
2,default,dev,0.26538,0.427513,0.519373,60000.0
2,default,rnd,0.264032,0.414881,0.504221,60000.0
3,default,adev,0.269254,0.41956,0.503772,90000.0
3,default,adult,0.268915,0.423831,0.518645,90000.0


In [None]:
# 30s for linear evaluation
# 2s for nn