In [1]:
import joblib
import os
import sys
from visualization import *

%matplotlib inline

In [2]:
# Aggregate run results: run agg_stats.py as a shell command, using the same
# Python interpreter as this kernel (sys.executable).
!{sys.executable} agg_stats.py

In [3]:
# Dataset titles used as labels in the result tables, keyed by short dataset id.
nm_d = dict([
    ('caltech', 'Caltech (101 classes)'),
    ('nuswide', 'NUS-WIDE (81 labels)'),
    ('mnist', 'MNIST (10 classes)'),
    ('mnistreg', 'MNIST-REG (24 tasks)'),
])

# Method titles used as labels in the result tables, keyed by short method id.
st_d = dict([
    # baselines
    ('cb', 'CatBoost (multioutput)'),
    ('gbdtmo', 'GBDTMO-Full'),
    ('gbdtso', 'GBDTMO-Sparse'),
    ('pb', 'SketchBoost Full'),
    # sketching strategies
    ('random', 'Random Sampling'),
    ('proj', 'Random Projection'),
])

# NOTE(review): HYPERPARAMS and k_limit are not referenced by the visible cells.
HYPERPARAMS = 'baselines_and_params.pkl'
k_limit = 20

# Orderings passed to the table builders below.
dataset_order = ['caltech', 'nuswide', 'mnist', 'mnistreg']
strategy_order = ['random', 'proj']
baselines = ['pb', 'cb', 'gbdtmo', 'gbdtso']

# Directory handed to the table builders as output_dir; created up front.
OUTPUT_DIR = 'output/OUTPUT_GBDTMO'
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [4]:
# Summary tables (one row per dataset) for the final test score and its std.
summary_kwargs = dict(baselines=baselines, sorter='sorter_def', output_dir=OUTPUT_DIR)
scores = get_summary_table(
    dataset_order, strategy_order, 'final_test_score', st_d, nm_d, **summary_kwargs
)
stds = get_summary_table(
    dataset_order, strategy_order, 'std_final_test_score', st_d, nm_d, **summary_kwargs
)

# Cast both frames to str, indexed by dataset, so that "+" concatenates the
# cells element-wise into "score+-std".
scores = scores.set_index('Dataset').astype(str)
stds_as_text = stds.set_index('Dataset').astype(str)

# A cell that is '-' in both tables becomes '-+--' after concatenation;
# collapse it back to a single '-'.
total = (scores + '+-' + stds_as_text).replace('-+--', '-').reset_index()

summary_csv = os.path.join(OUTPUT_DIR, 'summary_final_score_w_std.csv')
total.to_csv(summary_csv, index=False)

total

Unnamed: 0,Dataset,SketchBoost Full,CatBoost (multioutput),GBDTMO-Full,GBDTMO-Sparse,Random Sampling,Random Projection
0,Caltech (101 classes),0.5549+-0.008,0.5049+-0.0167,0.4469+-0.059,0.4796+-0.0375,0.571+-0.0209,0.5623+-0.0159
1,NUS-WIDE (81 labels),0.9893+-0.0002,0.9893+-0.0001,0.9891+-0.0002,0.9892+-0.0006,0.9893+-0.0004,0.9897+-0.0003
2,MNIST (10 classes),0.973+-0.0028,0.9684+-0.004,0.976+-0.004,0.9758+-0.0048,0.9755+-0.0042,0.974+-0.0032
3,MNIST-REG (24 tasks),0.266+-0.0019,0.2708+-0.0023,0.2723+-0.0026,0.2736+-0.0017,0.2661+-0.0019,0.2654+-0.0012


In [5]:
# Summary table of the 'train_time' metric, one row per dataset.
get_summary_table(
    dataset_order,
    strategy_order,
    'train_time',
    st_d,
    nm_d,
    baselines=baselines,
    sorter='sorter_def',
    output_dir=OUTPUT_DIR,
)

Unnamed: 0,Dataset,SketchBoost Full,CatBoost (multioutput),GBDTMO-Full,GBDTMO-Sparse,Random Sampling,Random Projection
0,Caltech (101 classes),13.8488,136.435,776.9003,1312.5386,12.488,16.0879
1,NUS-WIDE (81 labels),87.5762,13857.0014,2606.9869,3660.2744,45.2191,72.6631
2,MNIST (10 classes),46.6257,156.2955,362.6884,399.7232,102.8937,66.8702
3,MNIST-REG (24 tasks),90.3452,964.1152,210.9613,163.1199,120.8742,45.9758


In [6]:
# Summary table of the 'best_iter' metric, one row per dataset.
get_summary_table(
    dataset_order,
    strategy_order,
    'best_iter',
    st_d,
    nm_d,
    baselines=baselines,
    sorter='sorter_def',
    output_dir=OUTPUT_DIR,
)

Unnamed: 0,Dataset,SketchBoost Full,CatBoost (multioutput),GBDTMO-Full,GBDTMO-Sparse,Random Sampling,Random Projection
0,Caltech (101 classes),128.2,219.6,329.6,451.8,257.2,185.8
1,NUS-WIDE (81 labels),441.0,1103.8,789.2,878.2,458.4,450.4
2,MNIST (10 classes),249.6,299.4,424.2,457.6,199.2,323.6
3,MNIST-REG (24 tasks),847.8,3026.2,994.8,871.8,666.8,963.0


In [7]:
# Detailed tables (one row per method and K) for the final test score and its std.
total_kwargs = dict(r=4, baselines=baselines, output_dir=OUTPUT_DIR)
scores = get_total_table(
    dataset_order, strategy_order, 'final_test_score', st_d, nm_d, **total_kwargs
)
stds = get_total_table(
    dataset_order, strategy_order, 'std_final_test_score', st_d, nm_d, **total_kwargs
)

# Index both frames by (Dataset, K) and cast to str so that "+" concatenates
# the cells element-wise into "score+-std".
row_key = ['Dataset', 'K']
scores = scores.set_index(row_key).astype(str)
stds = stds.set_index(row_key).astype(str)

# A cell that is '-' in both tables becomes '-+--' after concatenation;
# collapse it back to a single '-'.
total = (scores + '+-' + stds).replace('-+--', '-').reset_index()

detailed_csv = os.path.join(OUTPUT_DIR, 'detailed_final_test_score_w_std.csv')
total.to_csv(detailed_csv, index=False)
total

Unnamed: 0,Dataset,K,Caltech (101 classes),NUS-WIDE (81 labels),MNIST (10 classes),MNIST-REG (24 tasks)
0,SketchBoost Full,-,0.5549+-0.008,0.9893+-0.0002,0.973+-0.0028,0.266+-0.0019
1,GBDTMO-Full,-,0.4469+-0.059,0.9891+-0.0002,0.976+-0.004,0.2723+-0.0026
2,GBDTMO-Sparse,-,0.4796+-0.0375,0.9892+-0.0006,0.9758+-0.0048,0.2736+-0.0017
3,CatBoost (multioutput),-,0.5049+-0.0167,0.9893+-0.0001,0.9684+-0.004,0.2708+-0.0023
4,Random Sampling,1,0.5377+-0.0158,0.9892+-0.0003,0.973+-0.0045,0.2671+-0.0011
5,Random Sampling,2,0.5704+-0.0174,0.9891+-0.0003,0.975+-0.0034,0.2678+-0.0015
6,Random Sampling,5,0.5599+-0.0146,0.9887+-0.0002,0.9755+-0.0042,0.2671+-0.0012
7,Random Sampling,10,0.571+-0.0209,0.9891+-0.0002,0.9753+-0.0007,0.2661+-0.0019
8,Random Sampling,20,0.5691+-0.0127,0.9893+-0.0004,-,0.2667+-0.001
9,Random Projection,1,0.5623+-0.0159,0.9897+-0.0003,0.9737+-0.0023,0.2657+-0.0018


In [8]:
# Detailed table of the 'train_time' metric, one row per method and K.
get_total_table(
    dataset_order,
    strategy_order,
    'train_time',
    st_d,
    nm_d,
    r=4,
    baselines=baselines,
    output_dir=OUTPUT_DIR,
)

Unnamed: 0,Dataset,K,Caltech (101 classes),NUS-WIDE (81 labels),MNIST (10 classes),MNIST-REG (24 tasks)
0,SketchBoost Full,-,13.8488,87.5762,46.6257,90.3452
1,GBDTMO-Full,-,776.9003,2606.9869,362.688,210.9613
2,GBDTMO-Sparse,-,1312.5386,3660.2744,399.723,163.1199
3,CatBoost (multioutput),-,136.435,13857.0014,156.296,964.1152
4,Random Sampling,1,14.0101,36.4078,66.8061,110.8892
5,Random Sampling,2,42.9496,145.0915,99.5403,85.4452
6,Random Sampling,5,40.407,148.7295,102.894,98.5611
7,Random Sampling,10,12.488,40.0959,88.9745,120.8742
8,Random Sampling,20,40.8565,45.2191,-,112.0635
9,Random Projection,1,16.0879,72.6631,45.5606,38.3287


In [9]:
!rm {os.path.join(OUTPUT_DIR, 'experiments_gbdtmo.tar')}
!tar --totals -cvf {os.path.join(OUTPUT_DIR, 'experiments_gbdtmo.tar')} {OUTPUT_DIR}

output/OUTPUT_GBDTMO/
output/OUTPUT_GBDTMO/summary_final_test_score.csv
output/OUTPUT_GBDTMO/summary_std_final_test_score.csv
output/OUTPUT_GBDTMO/summary_final_score_w_std.csv
output/OUTPUT_GBDTMO/summary_train_time.csv
output/OUTPUT_GBDTMO/summary_best_iter.csv
output/OUTPUT_GBDTMO/detailed_final_test_score.csv
output/OUTPUT_GBDTMO/detailed_std_final_test_score.csv
output/OUTPUT_GBDTMO/detailed_final_test_score_w_std.csv
output/OUTPUT_GBDTMO/detailed_train_time.csv
tar: output/OUTPUT_GBDTMO/experiments_gbdtmo.tar: file is the archive; not dumped
Total bytes written: 20480 (20KiB, 6.0MiB/s)
