In [2]:
import glob
import numpy as np
import matplotlib.pyplot as plt

import csv
import os
import re

In [3]:
%matplotlib inline

In [4]:
def read_to_dict(file_name):
    """Load a comma-separated file of numbers into a dict of columns.

    The first row supplies the column names; every later row is parsed with
    ``float`` and appended, cell by cell, to the list of the column at the
    same position.

    Args:
        file_name: path of the CSV file to read.

    Returns:
        dict mapping each header name to the list of float values read for
        that column (an empty dict for an empty file).

    Note:
        A row with more cells than the header prints ``'error'`` and the
        surplus cells of that row are dropped; shorter rows simply leave the
        trailing columns one value short.
    """
    columns = {}
    header = []
    with open(file_name, newline='') as handle:
        for line_no, fields in enumerate(csv.reader(handle, delimiter=',')):
            if line_no == 0:
                header = list(fields)
                for name in header:
                    columns[name] = []
                continue
            for pos, cell in enumerate(fields):
                if pos >= len(header):
                    # more cells than header names: give up on this row
                    print('error')
                    break
                columns[header[pos]].append(float(cell))
    return columns

In [5]:
def early_stop(result_dict, key1, key2, max_ep):
    """Return the index of the maximum value of two metric curves.

    Args:
        result_dict: mapping from metric name to a sequence of values.
        key1: first metric name.
        key2: second metric name.
        max_ep: if not None, only the first ``max_ep`` entries of each
            sequence are searched.

    Returns:
        ``(best_for_key1, best_for_key2)``: the ``np.argmax`` position of
        each metric (first occurrence on ties).
    """
    def _best_index(metric):
        curve = result_dict[metric]
        if max_ep is not None:
            curve = curve[0:max_ep]
        return np.argmax(np.array(curve))

    return _best_index(key1), _best_index(key2)

In [29]:
def collect_results(entries, res_keys, which, max_ep):
    """Gather metric values of every run at its early-stopping epochs.

    For each run in ``entries`` two epochs are chosen with ``early_stop`` on
    the split stored at position ``which``: one maximising ``res_keys[0]``
    and one maximising ``res_keys[1]``.  For each chosen epoch, the value of
    every metric in ``res_keys`` is appended for each of the run's first four
    splits.

    Args:
        entries: mapping from run id to an indexable whose positions 0..3
            hold the result dicts reported as "test", "val", "id_test" and
            "train"; a position holding ``None`` is skipped.
        res_keys: metric names; the first two drive the epoch selection.
        which: position of the split used to select the epochs.
        max_ep: forwarded to ``early_stop``.

    Returns:
        Nested dict ``res[c][split][metric]`` -> list of collected values,
        where ``c`` is 0 (epoch chosen by ``res_keys[0]``) or 1 (epoch chosen
        by ``res_keys[1]``).
    """
    split_names = ("test", "val", "id_test", "train")
    res = {
        crit: {name: {metric: [] for metric in res_keys} for name in split_names}
        for crit in range(2)
    }
    for run_id in entries:
        run = entries[run_id]
        first_best, second_best = early_stop(run[which], res_keys[0], res_keys[1], max_ep)
        for crit, epoch in enumerate((first_best, second_best)):
            for metric in res_keys:
                for pos, name in enumerate(split_names):
                    split = run[pos]
                    if split is None:
                        continue
                    history = split[metric]
                    if epoch < len(history):
                        value = history[epoch]
                    else:
                        # The selected epoch was never logged for this split:
                        # report it and fall back to the last logged value.
                        print("Error", metric, len(history), name)
                        value = history[-1]
                    res[crit][name][metric].append(value)

    return res
def print_row(row, res_keys):
    """Format mean and standard deviation, in percent, of selected metrics.

    Args:
        row: mapping from metric name to a sequence of values.
        res_keys: metric names to report, in output order.

    Returns:
        A string holding, for each metric, `` mean std`` (each scaled by 100
        and printed with two decimals), concatenated in order.
    """
    cells = []
    for metric in res_keys:
        values = row[metric]
        cells.append(" {:04.2f}".format(100*np.mean(values)))
        cells.append(" {:04.2f}".format(100*np.std(values)))
    return "".join(cells)

In [7]:
def parse_config(log_di, my_keys):
    """Extract configuration fields from the header of a run's ``log.txt``.

    Lines are read until the first one containing ``'Epoch [0]'``.  Each
    line is split on ``':'``:

    * a line containing ``'Dataset kwargs'`` but not ``'{}'`` stores the
      third piece (the text after the second colon) under the first piece,
      whether or not that name is listed in ``my_keys``;
    * any other line whose first piece is listed in ``my_keys`` stores the
      second piece.

    Stored values are not stripped, so they keep the leading space that
    follows the colon in the log.

    Args:
        log_di: path prefix of the run; ``'log.txt'`` is appended verbatim,
            so a directory must end with a path separator.
        my_keys: collection of field names to extract.

    Returns:
        dict mapping field name to its raw string value.

    Raises:
        OSError: if the log file cannot be opened.
        IndexError: if a matching line has fewer colon-separated pieces
            than the rule above needs.
    """
    dat_conf = {}
    # The context manager closes the file on every exit path, and iterating
    # the handle stops reading as soon as the header is over.
    with open(log_di+'log.txt') as log_file:
        for raw_line in log_file:
            ll = raw_line.rstrip('\n')
            if 'Epoch [0]' in ll:
                break
            pieces = ll.split(':')
            if 'Dataset kwargs' in ll and '{}' not in ll:
                dat_conf[pieces[0]] = pieces[2]
            elif pieces[0] in my_keys:
                dat_conf[pieces[0]] = pieces[1]
    return dat_conf

def parse_config_orig(log_di, my_keys):
    """Extract configuration fields from the header of a run's ``log.txt``.

    Lines are read until the first one containing ``'Epoch [0]'``.  For each
    line whose text before the first ``':'`` is listed in ``my_keys``, the
    text between the first and second colon is stored under that name.
    Stored values are not stripped, so they keep the leading space that
    follows the colon in the log.

    Args:
        log_di: path prefix of the run; ``'log.txt'`` is appended verbatim,
            so a directory must end with a path separator.
        my_keys: collection of field names to extract.

    Returns:
        dict mapping field name to its raw string value.

    Raises:
        OSError: if the log file cannot be opened.
        IndexError: if a line equal to a listed key contains no colon.
    """
    dat_conf = {}
    # The context manager closes the file on every exit path, and iterating
    # the handle stops reading as soon as the header is over.
    with open(log_di+'log.txt') as log_file:
        for raw_line in log_file:
            pieces = raw_line.rstrip('\n').split(':')
            if 'Epoch [0]' in raw_line:
                break
            if pieces[0] in my_keys:
                dat_conf[pieces[0]] = pieces[1]
    return dat_conf

In [8]:
# Seven values spaced evenly on a log scale between 1e-4 and 10 (inclusive).
log_lo, log_hi = np.log(0.0001), np.log(10)
np.exp(np.linspace(log_lo, log_hi, 7))

array([1.00000000e-04, 6.81292069e-04, 4.64158883e-03, 3.16227766e-02,
       2.15443469e-01, 1.46779927e+00, 1.00000000e+01])

# iWildCam

In [50]:
# iWildCam run directories, ordered by the integer that follows the last
# underscore of each path.
log_dicts = sorted(
    glob.glob('./iWildCamResults/logs_iwild_f*/'),
    key=lambda run_dir: int(run_dir.rsplit('_', 1)[-1].partition('/')[0]),
)

In [52]:
# Fields to extract from the header of each run's log.txt.
my_keys = [
    'Dataset', 'Algorithm', 'Uniform over groups', 'Distinct groups', 'N groups per batch',
    'Batch size', 'Rd type', 'Warm start epoch', 'Control only direction', 'Only inconsistent',
    'Without sampling', 'Lr', 'Weight decay', 'Seed',
]

# exps[experiment name][seed] -> (test, val, id_test, train) result dicts,
# where the experiment name is "<warm start epoch>_<rd type>_<lr>".
exps = {}
for log_di in log_dicts:
    dat_conf = parse_config_orig(log_di, my_keys)
    if dat_conf['Dataset'] != ' iwildcam':
        continue
    exp_name = "{}_{}_{}".format(
        *(dat_conf[field] for field in ('Warm start epoch', 'Rd type', 'Lr'))
    )
    seed_runs = exps.setdefault(exp_name, {})
    seed_runs[dat_conf['Seed']] = tuple(
        read_to_dict(log_di + csv_name)
        for csv_name in ('test_eval.csv', 'val_eval.csv', 'id_test_eval.csv', 'train_eval.csv')
    )
    # Every run is loaded; only those whose name contains '3e-05' and '2_'
    # but not '-1_ ' are echoed.
    if '3e-05' in exp_name and '2_' in exp_name and '-1_ ' not in exp_name:
        print(log_di, exp_name)

./iWildCamResults/logs_iwild_final_36/  2_ 0_ 3e-05
./iWildCamResults/logs_iwild_final_37/  2_ 0_ 3e-05
./iWildCamResults/logs_iwild_final_38/  2_ 0_ 3e-05
./iWildCamResults/logs_iwild_final_39/  2_ 1_ 3e-05
./iWildCamResults/logs_iwild_final_40/  2_ 1_ 3e-05
./iWildCamResults/logs_iwild_final_41/  2_ 1_ 3e-05
./iWildCamResults/logs_iwild_final_42/  2_ 2_ 3e-05
./iWildCamResults/logs_iwild_final_43/  2_ 2_ 3e-05
./iWildCamResults/logs_iwild_final_44/  2_ 2_ 3e-05
./iWildCamResults/logs_iwild_final_45/  2_ 3_ 3e-05
./iWildCamResults/logs_iwild_final_46/  2_ 3_ 3e-05
./iWildCamResults/logs_iwild_final_47/  2_ 3_ 3e-05


In [12]:
# Report mean/std over seeds for every experiment whose name contains '3e-05'.
# Epochs are selected on split index 2 (the id_test results).
metrics = ('acc_avg', 'F1-macro_all')
for exp_id in exps:
    if '3e-05' not in exp_id:
        continue
    resO = collect_results(exps[exp_id], metrics, 2, None)
    print(exp_id)

    # alg 0: epoch chosen by metrics[0]; alg 1: epoch chosen by metrics[1]
    for alg in resO:
        prefix = "ES_{} ".format(alg)
        res_test_str = prefix + print_row(resO[alg]['test'], metrics)
        res_id_t_str = prefix + print_row(resO[alg]['id_test'], metrics)
        # computed but not printed
        res_train_str = prefix + print_row(resO[alg]['train'], metrics)

        print("Te\t", res_test_str)
        print("ID \t", res_id_t_str)
    print('\n')

 -1_ 0_ 3e-05
Te	 ES_0  69.74 3.41 29.77 2.79
ID 	 ES_0  74.49 0.42 44.67 0.48
Te	 ES_1  71.43 3.24 29.14 3.14
ID 	 ES_1  74.16 0.65 45.00 0.65


 -1_ 1_ 3e-05
Te	 ES_0  63.94 1.80 22.96 1.18
ID 	 ES_0  71.90 1.45 42.63 1.28
Te	 ES_1  64.01 1.68 24.52 2.24
ID 	 ES_1  71.33 0.95 44.65 1.40


 -1_ 2_ 3e-05
Te	 ES_0  67.54 0.75 26.13 1.85
ID 	 ES_0  72.42 0.31 38.64 1.48
Te	 ES_1  66.97 1.22 26.69 1.71
ID 	 ES_1  71.34 1.40 40.26 0.40


 -1_ 3_ 3e-05
Te	 ES_0  69.44 1.43 28.45 0.83
ID 	 ES_0  74.12 0.31 46.55 3.21
Te	 ES_1  68.28 1.31 27.99 1.26
ID 	 ES_1  73.55 0.54 47.41 2.61


 2_ 0_ 3e-05
Te	 ES_0  73.81 2.87 33.85 1.42
ID 	 ES_0  76.88 0.53 48.18 2.24
Te	 ES_1  73.26 2.60 31.32 1.88
ID 	 ES_1  75.01 1.70 48.92 1.57


 2_ 1_ 3e-05
Te	 ES_0  76.08 0.81 34.31 0.91
ID 	 ES_0  76.54 0.62 47.29 3.33
Te	 ES_1  72.65 4.48 32.52 3.19
ID 	 ES_1  76.11 1.08 48.73 1.68


 2_ 2_ 3e-05
Te	 ES_0  76.00 0.58 34.46 1.00
ID 	 ES_0  76.73 0.49 49.04 1.49
Te	 ES_1  75.59 0.74 34.36 0.87
ID 	 ES_1  76.69

# Py150

In [53]:
# Py150 run directories, ordered by the integer that follows the last
# underscore of each path.
log_dicts = sorted(
    glob.glob('./py150Results/rd_p_all_seeds*/'),
    key=lambda run_dir: int(run_dir.rsplit('_', 1)[-1].partition('/')[0]),
)

In [54]:
# Fields to extract from the header of each run's log.txt.
my_keys = [
    'Dataset', 'Algorithm', 'Uniform over groups', 'Distinct groups', 'N groups per batch',
    'Batch size', 'Rd type', 'Warm start epoch', 'Control only direction', 'Only inconsistent',
    'Without sampling', 'Lr', 'Weight decay', 'Seed',
]

In [55]:
# exps[experiment name][seed] -> (test, val, id_test, train, run directory),
# where the experiment name is
# "<warm start epoch>_<rd type>_<lr>_<n groups per batch>_<batch size>".
exps = {}
name_fields = ('Warm start epoch', 'Rd type', 'Lr', 'N groups per batch', 'Batch size')
for log_di in log_dicts:
    dat_conf = parse_config_orig(log_di, my_keys)
    if dat_conf['Dataset'] != ' py150':
        continue
    print(log_di, dat_conf['Lr'])
    exp_name = "{}_{}_{}_{}_{}".format(*(dat_conf[field] for field in name_fields))
    seed_runs = exps.setdefault(exp_name, {})
    seed_runs[dat_conf['Seed']] = tuple(
        read_to_dict(log_di + csv_name)
        for csv_name in ('test_eval.csv', 'val_eval.csv', 'id_test_eval.csv', 'train_eval.csv')
    ) + (log_di,)

./py150Results/rd_p_all_seeds_0/  1e-05
./py150Results/rd_p_all_seeds_1/  1e-05
./py150Results/rd_p_all_seeds_2/  1e-05
./py150Results/rd_p_all_seeds_3/  1e-05
./py150Results/rd_p_all_seeds_4/  1e-05
./py150Results/rd_p_all_seeds_5/  1e-05
./py150Results/rd_p_all_seeds_6/  1e-05
./py150Results/rd_p_all_seeds_7/  1e-05
./py150Results/rd_p_all_seeds_8/  1e-05
./py150Results/rd_p_all_seeds_9/  8e-05
./py150Results/rd_p_all_seeds_10/  8e-05
./py150Results/rd_p_all_seeds_11/  8e-05
./py150Results/rd_p_all_seeds_12/  8e-05
./py150Results/rd_p_all_seeds_13/  8e-05
./py150Results/rd_p_all_seeds_14/  8e-05
./py150Results/rd_p_all_seeds_15/  8e-05
./py150Results/rd_p_all_seeds_16/  8e-05
./py150Results/rd_p_all_seeds_17/  8e-05
./py150Results/rd_p_all_seeds_18/  1e-05
./py150Results/rd_p_all_seeds_19/  1e-05
./py150Results/rd_p_all_seeds_20/  1e-05
./py150Results/rd_p_all_seeds_21/  1e-05
./py150Results/rd_p_all_seeds_22/  1e-05
./py150Results/rd_p_all_seeds_23/  1e-05
./py150Results/rd_p_all_se

In [48]:
# For every experiment: list each run's last logged test epoch and directory,
# then report mean/std over seeds.
# NOTE(review): epochs are selected on split index 0, i.e. the test results;
# confirm that selecting on the test split is intended.
metrics = ('Acc (Class-Method)', 'Acc (Overall)')
for exp_id in exps:
    print(exp_id)
    runs = exps[exp_id]
    for kk in runs:
        print(np.max(runs[kk][0]['epoch']))
        print(runs[kk][4])
    resO = collect_results(runs, metrics, 0, None)

    # alg 0: epoch chosen by metrics[0]; alg 1: epoch chosen by metrics[1]
    for alg in resO:
        prefix = "ES_{} ".format(alg)
        res_test_str = prefix + print_row(resO[alg]['test'], metrics)
        res_id_t_str = prefix + print_row(resO[alg]['id_test'], metrics)
        # computed but not printed
        res_train_str = prefix + print_row(resO[alg]['train'], metrics)

        print("Te\t", res_test_str)
        print("ID \t", res_id_t_str)
    print('\n')

 -1_ 0_ 1e-05_ 2_ 6
11.0
./py150Results/rd_p_all_seeds_0/
11.0
./py150Results/rd_p_all_seeds_1/
5.0
./py150Results/rd_p_all_seeds_2/
Te	 ES_0  66.38 0.11 68.14 0.04
ID 	 ES_0  68.58 0.48 69.43 0.25
Te	 ES_1  66.25 0.11 68.18 0.01
ID 	 ES_1  68.48 0.17 69.49 0.11


 -1_ 1_ 1e-05_ 2_ 6
12.0
./py150Results/rd_p_all_seeds_3/
11.0
./py150Results/rd_p_all_seeds_4/
5.0
./py150Results/rd_p_all_seeds_5/
Te	 ES_0  66.44 0.09 68.12 0.04
ID 	 ES_0  69.33 0.81 69.91 0.46
Te	 ES_1  66.21 0.09 68.18 0.02
ID 	 ES_1  68.89 0.52 69.87 0.39


 -1_ 3_ 1e-05_ 2_ 6
11.0
./py150Results/rd_p_all_seeds_6/
9.0
./py150Results/rd_p_all_seeds_7/
5.0
./py150Results/rd_p_all_seeds_8/
Te	 ES_0  66.36 0.09 68.13 0.05
ID 	 ES_0  68.83 0.24 69.65 0.08
Te	 ES_1  66.31 0.04 68.16 0.02
ID 	 ES_1  68.64 0.19 69.57 0.12


 -1_ 0_ 8e-05_ 2_ 6
3.0
./py150Results/rd_p_all_seeds_9/
7.0
./py150Results/rd_p_all_seeds_10/
3.0
./py150Results/rd_p_all_seeds_11/
Te	 ES_0  64.34 0.34 66.55 0.07
ID 	 ES_0  68.18 0.24 69.10 0.18
Te	 ES_1

# Mol

In [56]:
# Run directories matching ./RD_mol*/, ordered by the integer that follows
# the last underscore of each path.
log_dicts = sorted(
    glob.glob('./RD_mol*/'),
    key=lambda run_dir: int(run_dir.rsplit('_', 1)[-1].partition('/')[0]),
)

In [57]:
# Fields to extract from the header of each run's log.txt.
my_keys = [
    'Dataset', 'Algorithm', 'Uniform over groups', 'Distinct groups', 'N groups per batch',
    'Batch size', 'Rd type', 'Warm start epoch', 'Control only direction', 'Only inconsistent',
    'Without sampling', 'Lr', 'Weight decay', 'Seed',
]

In [59]:
# exps[experiment name][seed] -> (test, val, None, train, run directory).
# No id_test file is read here, so that slot is None.  The experiment name is
# "<warm start epoch>_<rd type>_<batch size>_<n groups per batch>_<lr>".
exps = {}
name_fields = ('Warm start epoch', 'Rd type', 'Batch size', 'N groups per batch', 'Lr')
for log_di in log_dicts:
    dat_conf = parse_config(log_di, my_keys)
    print(log_di, dat_conf['Lr'])
    exp_name = "{}_{}_{}_{}_{}".format(*(dat_conf[field] for field in name_fields))
    seed_runs = exps.setdefault(exp_name, {})
    test_res = read_to_dict(log_di + 'test_eval.csv')
    val_res = read_to_dict(log_di + 'val_eval.csv')
    train_res = read_to_dict(log_di + 'train_eval.csv')
    seed_runs[dat_conf['Seed']] = (test_res, val_res, None, train_res, log_di)

./RD_mol_0/  0.001
./RD_mols_0/  0.001
./RD_mols_1/  0.001
./RD_mol_1/  0.001
./RD_mol_2/  0.001
./RD_mol_3/  0.001
./RD_mols_3/  0.001
./RD_mol_4/  0.001
./RD_mols_4/  0.001
./RD_mols_5/  0.001
./RD_mol_5/  0.001
./RD_mol_6/  0.001
./RD_mols_7/  0.001
./RD_mol_7/  0.001
./RD_mol_8/  0.0005
./RD_mol_9/  0.0005
./RD_mol_10/  0.0005
./RD_mol_11/  0.0005
./RD_mol_12/  0.0005
./RD_mol_13/  0.0005
./RD_mol_14/  0.0005
./RD_mol_15/  0.0005
./RD_mol_16/  0.001
./RD_mols_16/  0.001
./RD_mol_17/  0.001
./RD_mol_18/  0.001
./RD_mols_18/  0.001
./RD_mol_19/  0.001
./RD_mols_19/  0.001
./RD_mol_20/  0.001
./RD_mols_20/  0.001
./RD_mol_21/  0.001
./RD_mols_22/  0.001
./RD_mol_22/  0.001
./RD_mol_23/  0.001
./RD_mols_23/  0.001
./RD_mols_24/  0.001
./RD_mol_24/  0.001
./RD_mol_25/  0.001
./RD_mols_26/  0.001
./RD_mol_26/  0.001
./RD_mols_27/  0.001
./RD_mol_27/  0.001
./RD_mol_28/  0.001
./RD_mols_28/  0.001
./RD_mol_29/  0.001
./RD_mol_30/  0.001
./RD_mols_30/  0.001
./RD_mol_31/  0.001
./RD_mols_3

In [60]:
# For every experiment: list each run's last logged test epoch and directory,
# then report mean/std over seeds.  Epochs are selected on split index 1
# (the val results).
metrics = ('ap', 'ap')
for exp_id in exps:
    print("'{}'".format(exp_id))
    runs = exps[exp_id]
    for kk in runs:
        print(np.max(runs[kk][0]['epoch']))
        print(runs[kk][4])

    resO = collect_results(runs, metrics, 1, None)

    for alg in resO:
        prefix = "ES_{} ".format(alg)
        res_test_str = prefix + print_row(resO[alg]['test'], metrics)
        # the row printed as "ID" is built from the 'val' results
        res_id_t_str = prefix + print_row(resO[alg]['val'], metrics)
        res_train_str = prefix + print_row(resO[alg]['train'], metrics)

        print("Te\t", res_test_str)
        print("ID \t", res_id_t_str)
        print("Tr \t", res_train_str)
        # both selection metrics are 'ap', so only the first criterion is shown
        break
    print('\n')

' -1_ 0_ 32_ 32_ 0.001'
10.0
./RD_mol_0/
8.0
./RD_mols_0/
Te	 ES_0  15.44 0.34 15.44 0.34
ID 	 ES_0  15.48 0.33 15.48 0.33
Tr 	 ES_0  13.99 0.31 13.99 0.31


' -1_ 1_ 32_ 32_ 0.001'
7.0
./RD_mols_1/
14.0
./RD_mol_1/
Te	 ES_0  16.34 1.37 16.34 1.37
ID 	 ES_0  16.17 1.36 16.17 1.36
Tr 	 ES_0  14.57 1.22 14.57 1.22


' -1_ 2_ 32_ 32_ 0.001'
2.0
./RD_mol_2/
Te	 ES_0  11.61 0.00 11.61 0.00
ID 	 ES_0  10.96 0.00 10.96 0.00
Tr 	 ES_0  9.19 0.00 9.19 0.00


' -1_ 3_ 32_ 32_ 0.001'
14.0
./RD_mol_3/
7.0
./RD_mols_3/
Te	 ES_0  15.72 0.92 15.72 0.92
ID 	 ES_0  15.76 0.92 15.76 0.92
Tr 	 ES_0  14.53 1.56 14.53 1.56


' 40_ 0_ 32_ 32_ 0.001'
55.0
./RD_mol_4/
47.0
./RD_mols_4/
Te	 ES_0  24.62 0.12 24.62 0.12
ID 	 ES_0  24.99 0.05 24.99 0.05
Tr 	 ES_0  28.74 0.22 28.74 0.22


' 40_ 1_ 32_ 32_ 0.001'
48.0
./RD_mols_5/
55.0
./RD_mol_5/
Te	 ES_0  24.43 0.09 24.43 0.09
ID 	 ES_0  25.19 0.04 25.19 0.04
Tr 	 ES_0  28.72 0.57 28.72 0.57


' 40_ 2_ 32_ 32_ 0.001'
43.0
./RD_mol_6/
Te	 ES_0  24.32 0.00 24.32 0.

In [30]:
# List the experiment names currently held in `exps`.
# NOTE(review): the saved output below is empty and this cell's execution
# count (30) is lower than that of the loading cell above (59) -- it looks
# like a leftover from an earlier, out-of-order run; confirm on a fresh run.
exps.keys()

dict_keys([])

# FMOW

In [61]:
# FMOW run directories, ordered by the integer that follows the last
# underscore of each path.
log_dicts = sorted(
    glob.glob('./FowResults/WBN*/'),
    key=lambda run_dir: int(run_dir.rsplit('_', 1)[-1].partition('/')[0]),
)

In [62]:
# Fields to extract from the header of each run's log.txt.
my_keys = [
    'Dataset', 'Algorithm', 'Uniform over groups', 'Distinct groups', 'N groups per batch',
    'Batch size', 'Rd type', 'Warm start epoch', 'Control only direction', 'Only inconsistent',
    'Without sampling', 'Lr', 'Weight decay', 'Seed',
]

In [63]:
# exps[experiment name][seed] -> (test, val, id_test, train, run directory),
# where the experiment name is
# "<warm start epoch>_<rd type>_<batch size>_<n groups per batch>_<lr>".
exps = {}
name_fields = ('Warm start epoch', 'Rd type', 'Batch size', 'N groups per batch', 'Lr')
for log_di in log_dicts:
    dat_conf = parse_config(log_di, my_keys)
    if dat_conf['Dataset'] != ' fmow':
        continue
    print(log_di, dat_conf['Lr'])
    exp_name = "{}_{}_{}_{}_{}".format(*(dat_conf[field] for field in name_fields))
    seed_runs = exps.setdefault(exp_name, {})
    seed_runs[dat_conf['Seed']] = tuple(
        read_to_dict(log_di + csv_name)
        for csv_name in ('test_eval.csv', 'val_eval.csv', 'id_test_eval.csv', 'train_eval.csv')
    ) + (log_di,)

./FowResults/WBN_rd_fow_0/  0.0001
./FowResults/WBN_rd_fow_1/  0.0001
./FowResults/WBN_rd_fow_2/  0.0001
./FowResults/WBN_rd_fow_3/  0.0001
./FowResults/WBN_rd_fow_4/  0.0001
./FowResults/WBN_rd_fow_5/  0.0001
./FowResults/WBN_rd_fow_6/  0.0001
./FowResults/WBN_rd_fow_7/  0.0001
./FowResults/WBN_rd_fow_8/  0.0005
./FowResults/WBN_rd_fow_9/  0.0005
./FowResults/WBN_rd_fow_10/  0.0005
./FowResults/WBN_rd_fow_11/  0.0005
./FowResults/WBN_rd_fow_12/  0.0005
./FowResults/WBN_rd_fow_13/  0.0005
./FowResults/WBN_rd_fow_14/  0.0005
./FowResults/WBN_rd_fow_15/  0.0005
./FowResults/WBN_rd_fow_16/  0.0005
./FowResults/WBN_rd_fow_17/  0.0005
./FowResults/WBN_rd_fow_18/  0.0005
./FowResults/WBN_rd_fow_19/  0.0005
./FowResults/WBN_rd_fow_20/  0.0005
./FowResults/WBN_rd_fow_21/  0.0005
./FowResults/WBN_rd_fow_22/  0.0005
./FowResults/WBN_rd_fow_23/  0.0005
./FowResults/WBN_rd_fow_50/  0.0001
./FowResults/WBN_rd_fow_51/  0.0001
./FowResults/WBN_rd_fow_52/  0.0001
./FowResults/WBN_rd_fow_53/  0.0001
./

In [70]:
# Report the selected FMOW experiments.  An experiment is selected by its
# name with the last space-separated token (the learning rate) removed.
# NOTE(review): epochs are selected on split index 0, i.e. the test results;
# confirm that selecting on the test split is intended.
metrics = ('acc_avg', 'acc_worst_region')
selected_prefixes = [' 3_ 0_ 32_ 4_', ' 3_ 2_ 32_ 4_', ' 3_ 0_ 64_ 4_']
for exp_id in exps:
    name_without_lr = ' '.join(exp_id.split(' ')[:-1])
    if name_without_lr not in selected_prefixes:
        continue
    print("'{}'".format(exp_id))
    runs = exps[exp_id]
    for kk in runs:
        print(np.max(runs[kk][0]['epoch']))
        print(runs[kk][4])

    resO = collect_results(runs, metrics, 0, None)

    # alg 0: epoch chosen by metrics[0]; alg 1: epoch chosen by metrics[1]
    for alg in resO:
        prefix = "ES_{} ".format(alg)
        res_test_str = prefix + print_row(resO[alg]['test'], metrics)
        res_id_t_str = prefix + print_row(resO[alg]['id_test'], metrics)
        res_train_str = prefix + print_row(resO[alg]['train'], metrics)

        print("Te\t", res_test_str)
        print("ID \t", res_id_t_str)
        print("Tr \t", res_train_str)
    print('\n')

' 3_ 0_ 32_ 4_ 0.0001'
46.0
./FowResults/WBN_rd_fow_4/
46.0
./FowResults/WBN_rd_fow_54/
46.0
./FowResults/WBN_rd_fow_154/
Te	 ES_0  52.77 0.29 34.44 1.07
ID 	 ES_0  59.19 0.08 57.68 0.40
Tr 	 ES_0  99.77 0.21 99.71 0.19
Te	 ES_1  52.27 0.64 35.36 0.41
ID 	 ES_1  58.96 0.52 57.63 0.19
Tr 	 ES_1  99.21 1.01 98.93 1.35


' 3_ 2_ 32_ 4_ 0.0001'
46.0
./FowResults/WBN_rd_fow_6/
46.0
./FowResults/WBN_rd_fow_56/
49.0
./FowResults/WBN_rd_fow_156/
Te	 ES_0  53.00 0.22 32.66 0.72
ID 	 ES_0  59.59 0.51 57.60 0.81
Tr 	 ES_0  99.96 0.04 99.94 0.05
Te	 ES_1  52.28 0.80 33.08 0.62
ID 	 ES_1  58.89 0.54 56.85 0.93
Tr 	 ES_1  99.26 0.98 99.02 1.27


' 3_ 0_ 64_ 4_ 0.0005'
49.0
./FowResults/WBN_rd_fow_16/
49.0
./FowResults/WBN_rd_fow_62/
49.0
./FowResults/WBN_rd_fow_162/
Te	 ES_0  52.51 0.17 33.19 0.56
ID 	 ES_0  59.31 0.23 57.61 0.87
Tr 	 ES_0  99.92 0.05 99.86 0.09
Te	 ES_1  51.83 0.60 33.72 0.24
ID 	 ES_1  58.44 0.63 56.64 0.52
Tr 	 ES_1  99.54 0.46 99.45 0.48




# Camelyon

In [9]:
# Camelyon run directories.  Every 's' is removed from the path before the
# integer after the last underscore is parsed, so 'rd_cam_ss3/' sorts with
# 'rd_cam_3/'.
log_dicts = sorted(
    glob.glob('./CamelyonResults/rd_cam*/'),
    key=lambda run_dir: int(run_dir.replace('s', '').rsplit('_', 1)[-1].partition('/')[0]),
)

In [10]:
# Fields to extract from the header of each run's log.txt.
my_keys = [
    'Dataset', 'Algorithm', 'Uniform over groups', 'Distinct groups', 'N groups per batch',
    'Batch size', 'Rd type', 'Warm start epoch', 'Control only direction', 'Only inconsistent',
    'Without sampling', 'Lr', 'Weight decay', 'Seed',
]

In [11]:
# exps[experiment name][seed] -> (test, val, id_val, train, run directory).
# The third slot is read from id_val_eval.csv.  The experiment name is
# "<warm start epoch>_<rd type>_<lr>_<n groups per batch>_<batch size>".
exps = {}
name_fields = ('Warm start epoch', 'Rd type', 'Lr', 'N groups per batch', 'Batch size')
for log_di in log_dicts:
    dat_conf = parse_config(log_di, my_keys)
    print(log_di, dat_conf['Lr'])
    exp_name = "{}_{}_{}_{}_{}".format(*(dat_conf[field] for field in name_fields))
    seed_runs = exps.setdefault(exp_name, {})
    seed_runs[dat_conf['Seed']] = tuple(
        read_to_dict(log_di + csv_name)
        for csv_name in ('test_eval.csv', 'val_eval.csv', 'id_val_eval.csv', 'train_eval.csv')
    ) + (log_di,)

./CamelyonResults/rd_cam_0/  0.005
./CamelyonResults/rd_cam_sss1/  0.005
./CamelyonResults/rd_cam_s1/  0.005
./CamelyonResults/rd_cam_1/  0.005
./CamelyonResults/rd_cam_ss1/  0.005
./CamelyonResults/rd_cam_2/  0.005
./CamelyonResults/rd_cam_s3/  0.005
./CamelyonResults/rd_cam_3/  0.005
./CamelyonResults/rd_cam_ss3/  0.005
./CamelyonResults/rd_cam_4/  0.005
./CamelyonResults/rd_cam_ss5/  0.005
./CamelyonResults/rd_cam_s5/  0.005
./CamelyonResults/rd_cam_5/  0.005
./CamelyonResults/rd_cam_6/  0.005
./CamelyonResults/rd_cam_7/  0.005
./CamelyonResults/rd_cam_s7/  0.005
./CamelyonResults/rd_cam_ss7/  0.005
./CamelyonResults/rd_cam_8/  0.005
./CamelyonResults/rd_cam_s9/  0.005
./CamelyonResults/rd_cam_9/  0.005
./CamelyonResults/rd_cam_ss9/  0.005
./CamelyonResults/rd_cam_10/  0.005
./CamelyonResults/rd_cam_11/  0.005
./CamelyonResults/rd_cam_ss11/  0.005
./CamelyonResults/rd_cam_12/  0.01
./CamelyonResults/rd_cam_s13/  0.01
./CamelyonResults/rd_cam_13/  0.01
./CamelyonResults/rd_cam_14/  0

In [23]:
# Show the column names of the test results (tuple index 0) of the run
# stored under key ' 9' in one experiment.
# NOTE(review): `exp_id` is not assigned in this cell; it relies on the value
# left behind by an earlier `for exp_id in exps` loop, so the cell fails on a
# fresh kernel -- confirm which experiment was meant.
exps[exp_id][' 9'][0].keys()

dict_keys(['epoch', 'acc_avg', 'acc_slide:0', 'count_slide:0', 'acc_slide:1', 'count_slide:1', 'acc_slide:2', 'count_slide:2', 'acc_slide:3', 'count_slide:3', 'acc_slide:4', 'count_slide:4', 'acc_slide:5', 'count_slide:5', 'acc_slide:6', 'count_slide:6', 'acc_slide:7', 'count_slide:7', 'acc_slide:8', 'count_slide:8', 'acc_slide:9', 'count_slide:9', 'acc_slide:10', 'count_slide:10', 'acc_slide:11', 'count_slide:11', 'acc_slide:12', 'count_slide:12', 'acc_slide:13', 'count_slide:13', 'acc_slide:14', 'count_slide:14', 'acc_slide:15', 'count_slide:15', 'acc_slide:16', 'count_slide:16', 'acc_slide:17', 'count_slide:17', 'acc_slide:18', 'count_slide:18', 'acc_slide:19', 'count_slide:19', 'acc_slide:20', 'count_slide:20', 'acc_slide:21', 'count_slide:21', 'acc_slide:22', 'count_slide:22', 'acc_slide:23', 'count_slide:23', 'acc_slide:24', 'count_slide:24', 'acc_slide:25', 'count_slide:25', 'acc_slide:26', 'count_slide:26', 'acc_slide:27', 'count_slide:27', 'acc_slide:28', 'count_slide:28', 'ac

In [30]:
# Report the selected Camelyon experiments.  Epochs are selected on split
# index 1 (the val results).
metrics = ('acc_avg', 'acc_avg')
selected_exps = [
    ' -1_ 1_ 0.005_ 3_ 30',
    ' -1_ 1_ 0.005_ 3_ 60',
    ' -1_ 3_ 0.005_ 3_ 60',
    ' -1_ 3_ 0.01_ 3_ 60',
    ' 1_ 2_ 0.01_ 3_ 60',
]
for exp_id in exps:
    if exp_id not in selected_exps:
        continue
    print("'{}'".format(exp_id))
    runs = exps[exp_id]
    skip_exp = False
    for kk in runs:
        if 'epoch' not in runs[kk][0]:
            # a run whose test results have no 'epoch' column disqualifies
            # the whole experiment
            skip_exp = True
            break
        print(np.max(runs[kk][0]['epoch']))
        print(runs[kk][4])
    if skip_exp:
        continue

    resO = collect_results(runs, metrics, 1, None)

    for alg in resO:
        prefix = "ES_{} ".format(alg)
        res_test_str = prefix + print_row(resO[alg]['test'], metrics)
        res_id_t_str = prefix + print_row(resO[alg]['id_test'], metrics)
        # computed but not printed
        res_train_str = prefix + print_row(resO[alg]['train'], metrics)

        print("Te\t", res_test_str)
        print("ID \t", res_id_t_str)
    print('\n')

' -1_ 1_ 0.005_ 3_ 30'
2.0
./CamelyonResults/rd_cam_sss1/
7.0
./CamelyonResults/rd_cam_s1/
7.0
./CamelyonResults/rd_cam_1/
1.0
./CamelyonResults/rd_cam_ss1/
1.0
./CamelyonResults/rd_cam_ss5/
7.0
./CamelyonResults/rd_cam_s5/
7.0
./CamelyonResults/rd_cam_5/
4.0
./CamelyonResults/rd_cam_s9/
5.0
./CamelyonResults/rd_cam_9/
' -1_ 1_ 0.005_ 3_ 60'
6.0
./CamelyonResults/rd_cam_sss25/
7.0
./CamelyonResults/rd_cam_ss25/
7.0
./CamelyonResults/rd_cam_s25/
7.0
./CamelyonResults/rd_cam_25/
7.0
./CamelyonResults/rd_cam_29/
7.0
./CamelyonResults/rd_cam_ss29/
7.0
./CamelyonResults/rd_cam_s29/
7.0
./CamelyonResults/rd_cam_ss33/
7.0
./CamelyonResults/rd_cam_s33/
7.0
./CamelyonResults/rd_cam_33/
Te	 ES_0  73.95 9.37 73.95 9.37
ID 	 ES_0  94.83 1.94 94.83 1.94
Te	 ES_1  73.95 9.37 73.95 9.37
ID 	 ES_1  94.83 1.94 94.83 1.94


' -1_ 3_ 0.005_ 3_ 60'
7.0
./CamelyonResults/rd_cam_s27/
7.0
./CamelyonResults/rd_cam_sss27/
7.0
./CamelyonResults/rd_cam_ss27/
7.0
./CamelyonResults/rd_cam_27/
7.0
./CamelyonResults

# Poverty

In [72]:
# Poverty run directories.  Every 's' is removed from the path before the
# integer after the last underscore is parsed.
log_dicts = sorted(
    glob.glob('./PovertyResults/POV_RD*/'),
    key=lambda run_dir: int(run_dir.replace('s', '').rsplit('_', 1)[-1].partition('/')[0]),
)

In [73]:
# Fields to extract from the header of each run's log.txt.
my_keys = [
    'Dataset', 'Algorithm', 'Uniform over groups', 'Distinct groups', 'N groups per batch',
    'Batch size', 'Rd type', 'Warm start epoch', 'Control only direction', 'Only inconsistent',
    'Without sampling', 'Lr', 'Weight decay', 'Seed', 'Dataset kwargs',
]

In [74]:
# exps[experiment name][dataset kwargs] -> (test, val, id_test, id_val, run directory).
# Runs are keyed by their 'Dataset kwargs' value instead of the seed, and the
# fourth slot is read from id_val_eval.csv.  The experiment name is
# "<warm start epoch>_<rd type>_<n groups per batch>_<lr>_<batch size>".
exps = {}
name_fields = ('Warm start epoch', 'Rd type', 'N groups per batch', 'Lr', 'Batch size')
for log_di in log_dicts:
    dat_conf = parse_config(log_di, my_keys)
    print(log_di, dat_conf['Lr'])
    exp_name = "{}_{}_{}_{}_{}".format(*(dat_conf[field] for field in name_fields))
    fold_runs = exps.setdefault(exp_name, {})
    fold_runs[dat_conf['Dataset kwargs']] = tuple(
        read_to_dict(log_di + csv_name)
        for csv_name in ('test_eval.csv', 'val_eval.csv', 'id_test_eval.csv', 'id_val_eval.csv')
    ) + (log_di,)

./PovertyResults/POV_RDC_0/  0.0001
./PovertyResults/POV_RDA_0/  0.0001
./PovertyResults/POV_RDE_0/  0.0001
./PovertyResults/POV_RD_0/  0.0001
./PovertyResults/POV_RDD_0/  0.0001
./PovertyResults/POV_RDA_1/  0.0001
./PovertyResults/POV_RDC_1/  0.0001
./PovertyResults/POV_RDE_1/  0.0001
./PovertyResults/POV_RDD_1/  0.0001
./PovertyResults/POV_RD_1/  0.0001
./PovertyResults/POV_RDE_2/  0.0001
./PovertyResults/POV_RD_2/  0.0001
./PovertyResults/POV_RDA_2/  0.0001
./PovertyResults/POV_RDD_2/  0.0001
./PovertyResults/POV_RDC_2/  0.0001
./PovertyResults/POV_RDD_3/  0.0001
./PovertyResults/POV_RD_3/  0.0001
./PovertyResults/POV_RDA_3/  0.0001
./PovertyResults/POV_RDC_3/  0.0001
./PovertyResults/POV_RDE_3/  0.0001
./PovertyResults/POV_RDA_4/  0.0001
./PovertyResults/POV_RDE_4/  0.0001
./PovertyResults/POV_RDC_4/  0.0001
./PovertyResults/POV_RDD_4/  0.0001
./PovertyResults/POV_RD_4/  0.0001
./PovertyResults/POV_RDD_5/  0.0001
./PovertyResults/POV_RDA_5/  0.0001
./PovertyResults/POV_RDC_5/  0.00

In [75]:
# For every Poverty experiment: list each run's last logged test epoch and
# directory, then report mean/std over runs.
# NOTE(review): epochs are selected on split index 0, i.e. the test results;
# confirm that selecting on the test split is intended.
metrics = ('r_all', 'r_wg')
for exp_id in exps:
    print("'{}'".format(exp_id))
    for kk in exps[exp_id]:
        print(np.max(exps[exp_id][kk][0]['epoch']))
        print(exps[exp_id][kk][4])

    resO = collect_results(exps[exp_id], metrics, 0, None)

    # alg 0: epoch chosen by metrics[0]; alg 1: epoch chosen by metrics[1]
    for alg in resO:
        res_test_str = "OOD_{} ".format(alg) + print_row(resO[alg]['test'], metrics)
        res_id_t_str = "ID_{} ".format(alg) + print_row(resO[alg]['id_test'], metrics)
        # The loading cell fills the fourth tuple slot from id_val_eval.csv,
        # which collect_results files under 'train'; label the row by what it
        # really contains instead of "TR_".
        res_train_str = "ID_VAL_{} ".format(alg) + print_row(resO[alg]['train'], metrics)
        res_val_str = "VAL_{} ".format(alg) + print_row(resO[alg]['val'], metrics)

        print(res_test_str)
        print(res_id_t_str)
        print(res_train_str)
        print(res_val_str)
    print('\n')

' -1_ 0_ 4_ 0.0001_ 32'
199.0
./PovertyResults/POV_RDC_0/
199.0
./PovertyResults/POV_RDA_0/
199.0
./PovertyResults/POV_RDE_0/
199.0
./PovertyResults/POV_RD_0/
199.0
./PovertyResults/POV_RDD_0/
OOD_0  79.66 3.63 47.61 4.59
ID_0  81.90 1.26 55.09 5.52
TR_0  80.69 1.96 53.70 3.85
VAL_0  77.23 3.44 39.20 6.99
OOD_1  79.28 3.26 48.45 3.75
ID_1  81.75 1.11 55.43 5.71
TR_1  80.55 1.92 53.64 4.43
VAL_1  76.91 3.38 41.07 4.25


' -1_ 1_ 4_ 0.0001_ 32'
199.0
./PovertyResults/POV_RDA_1/
199.0
./PovertyResults/POV_RDC_1/
199.0
./PovertyResults/POV_RDE_1/
199.0
./PovertyResults/POV_RDD_1/
199.0
./PovertyResults/POV_RD_1/
OOD_0  79.19 3.66 46.62 3.14
ID_0  81.81 1.94 54.46 6.37
TR_0  80.64 1.39 51.94 3.46
VAL_0  77.31 3.93 41.02 9.44
OOD_1  78.65 3.62 47.44 2.74
ID_1  82.33 1.53 55.41 5.58
TR_1  81.42 2.02 55.16 5.60
VAL_1  77.46 2.34 43.72 6.46


' -1_ 2_ 4_ 0.0001_ 32'
199.0
./PovertyResults/POV_RDE_2/
199.0
./PovertyResults/POV_RD_2/
199.0
./PovertyResults/POV_RDA_2/
199.0
./PovertyResults/POV_RD

# Amazon

In [40]:
# Amazon run directories.  Every 's' is removed from the path before the
# integer after the last underscore is parsed.
log_dicts = sorted(
    glob.glob('./AmazonResults/AMA_*/'),
    key=lambda run_dir: int(run_dir.replace('s', '').rsplit('_', 1)[-1].partition('/')[0]),
)

In [41]:
# Fields to extract from the header of each run's log.txt.
my_keys = [
    'Dataset', 'Algorithm', 'Uniform over groups', 'Distinct groups', 'N groups per batch',
    'Batch size', 'Rd type', 'Warm start epoch', 'Control only direction', 'Only inconsistent',
    'Without sampling', 'Lr', 'Weight decay', 'Seed', 'Dataset kwargs',
]

In [42]:
# exps[experiment name][seed] -> (test, val, id_test, train, run directory),
# where the experiment name is
# "<warm start epoch>_<rd type>_<lr>_<n groups per batch>_<batch size>".
exps = {}
for log_di in log_dicts:
    dat_conf = parse_config_orig(log_di, my_keys)
    print(log_di, dat_conf['Lr'])
    exp_name = "{}_{}_{}_{}_{}".format(dat_conf['Warm start epoch'], dat_conf['Rd type'],
                                 dat_conf['Lr'], dat_conf['N groups per batch'],
                                 dat_conf['Batch size'])
    test_res = read_to_dict(log_di+'test_eval.csv')
    if 'epoch' not in test_res:
        # No usable test results for this run: skip it.  The experiment is
        # registered only once a run is actually stored, so an experiment
        # whose runs are all skipped no longer appears as an empty group
        # (which would show up as NaN rows in the report).
        continue
    exps.setdefault(exp_name, {})[dat_conf['Seed']] = (test_res,
                                        read_to_dict(log_di+'val_eval.csv'),
                                        read_to_dict(log_di+'id_test_eval.csv'),
                                        read_to_dict(log_di+'train_eval.csv'),
                                        log_di)


./AmazonResults/AMA_RD_0/  5e-06
./AmazonResults/AMA_RD2_0/  5e-06
./AmazonResults/AMA_RD1_0/  5e-06
./AmazonResults/AMA_RD_1/  5e-06
./AmazonResults/AMA_RD1_1/  5e-06
./AmazonResults/AMA_RD2_1/  5e-06
./AmazonResults/AMA_RD1_2/  5e-06
./AmazonResults/AMA_RD_2/  5e-06
./AmazonResults/AMA_RD2_2/  5e-06
./AmazonResults/AMA_RD_3/  5e-06
./AmazonResults/AMA_RD1_3/  5e-06
./AmazonResults/AMA_RD2_3/  5e-06
./AmazonResults/AMA_RD_4/  5e-06
./AmazonResults/AMA_RD1_4/  5e-06
./AmazonResults/AMA_RD2_4/  5e-06
./AmazonResults/AMA_RD1_5/  5e-06
./AmazonResults/AMA_RD_5/  5e-06
./AmazonResults/AMA_RD2_5/  5e-06
./AmazonResults/AMA_RD2_6/  5e-06
./AmazonResults/AMA_RD1_6/  5e-06
./AmazonResults/AMA_RD_6/  5e-06
./AmazonResults/AMA_RD_7/  5e-06
./AmazonResults/AMA_RD2_7/  5e-06
./AmazonResults/AMA_RD1_7/  5e-06
./AmazonResults/AMA_RD2_8/  1e-05
./AmazonResults/AMA_RD1_8/  1e-05
./AmazonResults/AMA_RD_8/  1e-05
./AmazonResults/AMA_RD_9/  1e-05
./AmazonResults/AMA_RD2_9/  1e-05
./AmazonResults/AMA_RD1_

In [44]:
# For every Amazon experiment: list each run's last logged test epoch and
# directory, then report mean/std over seeds.
# NOTE(review): epochs are selected on split index 0, i.e. the test results;
# confirm that selecting on the test split is intended.
metrics = ('acc_avg', '10th_percentile_acc')
for exp_id in exps:
    print("'{}'".format(exp_id))
    runs = exps[exp_id]
    # set when a run has no 'epoch' column; recorded but not acted upon
    breakl = False
    for kk in runs:
        if 'epoch' in runs[kk][0]:
            print(np.max(runs[kk][0]['epoch']))
            print(runs[kk][4])
        else:
            breakl = True
    resO = collect_results(runs, metrics, 0, None)

    # alg 0: epoch chosen by metrics[0]; alg 1: epoch chosen by metrics[1]
    for alg in resO:
        res_test_str = "OOD_{} ".format(alg) + print_row(resO[alg]['test'], metrics)
        res_id_t_str = "ID_{} ".format(alg) + print_row(resO[alg]['id_test'], metrics)
        # the next two rows are computed but not printed
        res_train_str = "TR_{} ".format(alg) + print_row(resO[alg]['train'], metrics)
        res_val_str = "VAL_{} ".format(alg) + print_row(resO[alg]['val'], metrics)

        print(res_test_str)
        print(res_id_t_str)
    print('\n')

' -1_ 0_ 5e-06_ 2_ 8'
2.0
./AmazonResults/AMA_RD_0/
1.0
./AmazonResults/AMA_RD2_0/
2.0
./AmazonResults/AMA_RD1_0/
OOD_0  70.26 0.11 52.89 0.63
ID_0  72.29 0.06 53.78 0.63
OOD_1  70.13 0.27 52.89 0.63
ID_1  71.95 0.42 53.33 1.09


' -1_ 1_ 5e-06_ 2_ 8'
2.0
./AmazonResults/AMA_RD_1/
2.0
./AmazonResults/AMA_RD1_1/
1.0
./AmazonResults/AMA_RD2_1/
OOD_0  71.29 0.17 53.33 1.09
ID_0  73.55 0.14 56.00 0.00
OOD_1  71.05 0.34 53.78 0.63
ID_1  73.04 0.47 55.56 0.63


' -1_ 2_ 5e-06_ 2_ 8'
2.0
./AmazonResults/AMA_RD1_2/
2.0
./AmazonResults/AMA_RD_2/
2.0
./AmazonResults/AMA_RD2_2/
OOD_0  70.71 0.12 53.02 0.44
ID_0  73.09 0.13 55.11 0.63
OOD_1  70.35 0.27 53.02 0.44
ID_1  72.61 0.36 55.56 0.63


' -1_ 3_ 5e-06_ 2_ 8'
2.0
./AmazonResults/AMA_RD_3/
2.0
./AmazonResults/AMA_RD1_3/
2.0
./AmazonResults/AMA_RD2_3/
OOD_0  71.18 0.09 52.00 1.09
ID_0  73.07 0.05 53.78 0.63
OOD_1  70.75 0.31 53.78 0.63
ID_1  72.98 0.30 55.56 0.63


' 1_ 0_ 5e-06_ 2_ 8'
2.0
./AmazonResults/AMA_RD_4/
2.0
./AmazonResults/AMA_RD1_4