In [1]:
import glob
import numpy as np
import matplotlib.pyplot as plt

import csv
import os
import re

In [2]:
%matplotlib inline

In [3]:
def read_to_dict(file_name):
    res_dict = {}
    keym = {}
    with open(file_name, newline='') as csvfile:
        csv_reader = csv.reader(csvfile, delimiter=',')
        for i, row in enumerate(csv_reader):
            if i == 0:
                for j, rr in enumerate(row):
                    res_dict[rr] = []
                    keym[j] = rr
            else:
                for j, rr in enumerate(row):
                    if j not in keym:
                        print('error')
                        break
                    res_dict[keym[j]].append(float(rr))
    return res_dict

In [4]:
def early_stop(result_dict, key1, key2, max_ep):
    if max_ep is None:
        best_ep_0 = np.argmax(np.array(result_dict[key1]))
        best_ep_1 = np.argmax(np.array(result_dict[key2]))
    else:
        best_ep_0 = np.argmax(np.array(result_dict[key1][0:max_ep]))
        best_ep_1 = np.argmax(np.array(result_dict[key2][0:max_ep]))
    return best_ep_0, best_ep_1

In [5]:
def collect_results(entries, res_keys, which, max_ep):
    map_k = {0:"test", 1:"val", 2:"id_test", 3: "train"}
    res = {}
    for i in range(2):
        res[i] = {"test":{}, "val":{}, "id_test":{}, "train":{}}
        for dic in ["test", "val", "id_test", "train"]:
            for res_k in res_keys:
                res[i][dic][res_k] = []
    for entry in entries:
        e1,e2 = early_stop(entries[entry][which],res_keys[0], res_keys[1], max_ep)
        for i, e_v in enumerate([e1, e2]):
            for res_k in res_keys:
                for j_id in range(4):
                    if entries[entry][j_id] is None:
                        continue
                    if e_v >= len(entries[entry][j_id][res_k]):
                        print("Error", res_k, len(entries[entry][j_id][res_k]), map_k[j_id])
                        res[i][map_k[j_id]][res_k].append(entries[entry][j_id][res_k][-1])
                    else:
                        res[i][map_k[j_id]][res_k].append(entries[entry][j_id][res_k][e_v])

    return res
def print_row(row, res_keys):
    final_str = ""
    for res_k in res_keys:
        final_str += " {:04.2f}".format(100*np.mean(row[res_k]))
        final_str += " {:04.2f}".format(100*np.std(row[res_k]))
    return final_str

In [6]:
def parse_config(log_di, my_keys):
    config = open(log_di+'log.txt').read().split('\n')
    dat_conf = {}
    for ll in config:
        if 'Epoch [0]' in ll:
            break
        if 'Dataset kwargs' in ll and '{}' not in ll:
            dat_conf[ll.split(':')[0]] = ll.split(':')[2]
        elif ll.split(':')[0] in my_keys:
            dat_conf[ll.split(':')[0]] = ll.split(':')[1]
    return dat_conf

def parse_config_orig(log_di, my_keys):
    config = open(log_di+'log.txt').read().split('\n')
    dat_conf = {}
    for ll in config:
        if 'Epoch [0]' in ll:
            break
        if ll.split(':')[0] in my_keys:
            dat_conf[ll.split(':')[0]] = ll.split(':')[1]
    return dat_conf

In [7]:
np.exp(np.linspace(np.log(0.0001),np.log(10),7))

array([1.00000000e-04, 6.81292069e-04, 4.64158883e-03, 3.16227766e-02,
       2.15443469e-01, 1.46779927e+00, 1.00000000e+01])

# iWildCam

In [8]:
log_dicts = sorted(glob.glob('./AMASK_*/'), key= lambda x:int(x.split('_')[-1].split('/')[0]))

In [10]:
my_keys = ['Dataset', 'Algorithm', 'Uniform over groups', 'Distinct groups', 'N groups per batch']
my_keys += ['Batch size', 'Rd type', 'Warm start epoch', 'Control only direction', 'Only inconsistent']
my_keys += ['Without sampling', 'Lr', 'Weight decay', 'Seed']

exps = {}
for log_di in log_dicts:
    dat_conf = parse_config_orig(log_di, my_keys)
    if not dat_conf['Dataset'] == ' iwildcam':
        continue
    exp_name = "{}_{}_{}".format(dat_conf['Warm start epoch'], dat_conf['Rd type'], dat_conf['Lr'])
    if exp_name not in exps:
        exps[exp_name] = {}
    exps[exp_name][dat_conf['Seed']] = (read_to_dict(log_di+'test_eval.csv'), 
                                        read_to_dict(log_di+'val_eval.csv'), 
                                        read_to_dict(log_di+'id_test_eval.csv'),
                                        read_to_dict(log_di+'train_eval.csv'))
    if '3e-05' not in exp_name:
        continue
    if '2_' not in exp_name or '-1_ ' in exp_name:
        continue
    print(log_di, exp_name)

In [12]:
for exp_id in exps:
    #if '3e-05' not in exp_id:
    #    continue
    #if '2_' not in exp_id or '-1_ ' in exp_id:
    #    continue
    resO = collect_results(exps[exp_id], ('acc_avg', 'F1-macro_all'),0 , None)
    print(exp_id)

    for alg in resO:
        res_test_str = "ES_{} ".format(alg) +print_row(resO[alg]['test'], ('acc_avg', 'F1-macro_all'))
        res_id_t_str = "ES_{} ".format(alg) +print_row(resO[alg]['id_test'], ('acc_avg', 'F1-macro_all'))
        res_train_str = "ES_{} ".format(alg) +print_row(resO[alg]['train'], ('acc_avg', 'F1-macro_all'))

        print("Te\t", res_test_str)
        print("ID \t", res_id_t_str)
        #print("Tr \t", res_train_str)
    print('\n')

 -1_ 42_ 3e-05
Te	 ES_0  38.12 2.15 2.05 0.68
ID 	 ES_0  34.19 2.37 2.30 1.08
Te	 ES_1  37.94 1.89 2.08 0.72
ID 	 ES_1  33.51 1.53 2.36 1.16




# Py150

In [13]:
log_dicts = sorted(glob.glob('./AMASK_*/'), key= lambda x:int(x.split('_')[-1].split('/')[0]))

In [14]:
my_keys = ['Dataset', 'Algorithm', 'Uniform over groups', 'Distinct groups', 'N groups per batch']
my_keys += ['Batch size', 'Rd type', 'Warm start epoch', 'Control only direction', 'Only inconsistent']
my_keys += ['Without sampling', 'Lr', 'Weight decay', 'Seed']

In [15]:
exps = {}
for log_di in log_dicts:
    dat_conf = parse_config_orig(log_di, my_keys)
    if not dat_conf['Dataset'] == ' py150':
        continue
    print(log_di, dat_conf['Lr'])
    exp_name = "{}_{}_{}_{}_{}".format(dat_conf['Warm start epoch'], dat_conf['Rd type'], 
                                 dat_conf['Lr'], dat_conf['N groups per batch'],
                                 dat_conf['Batch size'])
    if exp_name not in exps:
        exps[exp_name] = {}
    exps[exp_name][dat_conf['Seed']] = (read_to_dict(log_di+'test_eval.csv'), 
                                        read_to_dict(log_di+'val_eval.csv'), 
                                        read_to_dict(log_di+'id_test_eval.csv'),
                                        read_to_dict(log_di+'train_eval.csv'),
                                        log_di)

./AMASK_27/  8e-05
./AMASK_28/  8e-05
./AMASK_29/  8e-05


In [17]:
for exp_id in exps:
    print(exp_id)
    #if not ('3_ 0_ 1e-05_ 2_ 6' in exp_id or '3_ 1_ 1e-05_ 2_ 6' in exp_id):
    #    continue
    for kk in exps[exp_id]:
        print(np.max(exps[exp_id][kk][0]['epoch']))
        print(exps[exp_id][kk][4])
    resO = collect_results(exps[exp_id], ('Acc (Class-Method)', 'Acc (Overall)'),1, None)

    for alg in resO:
        res_test_str = "ES_{} ".format(alg) +print_row(resO[alg]['test'], ('Acc (Class-Method)', 'Acc (Overall)'))
        res_id_t_str = "ES_{} ".format(alg) +print_row(resO[alg]['id_test'], ('Acc (Class-Method)', 'Acc (Overall)'))
        res_train_str = "ES_{} ".format(alg) +print_row(resO[alg]['train'], ('Acc (Class-Method)', 'Acc (Overall)'))

        print("Te\t", res_test_str)
        #print("ID \t", res_id_t_str)
        #print("Tr \t", res_train_str)
    print('\n')

 -1_ 42_ 8e-05_ 2_ 6
7.0
./AMASK_27/
6.0
./AMASK_28/
3.0
./AMASK_29/
Error Acc (Class-Method) 6 test
Error Acc (Class-Method) 6 id_test
Error Acc (Overall) 6 test
Error Acc (Overall) 6 id_test
Te	 ES_0  62.63 0.30 65.82 0.16
Te	 ES_1  62.37 0.07 65.79 0.07




# Mol

In [18]:
log_dicts = sorted(glob.glob('./AMASK_*/'), key= lambda x:int(x.split('_')[-1].split('/')[0]))

In [19]:
my_keys = ['Dataset', 'Algorithm', 'Uniform over groups', 'Distinct groups', 'N groups per batch']
my_keys += ['Batch size', 'Rd type', 'Warm start epoch', 'Control only direction', 'Only inconsistent']
my_keys += ['Without sampling', 'Lr', 'Weight decay', 'Seed']

In [29]:
exps = {}
for log_di in log_dicts:
    dat_conf = parse_config(log_di, my_keys)
    if not dat_conf['Dataset'] == ' ogb-molpcba':
        continue
    print(log_di, dat_conf['Lr'])
    exp_name = "{}_{}_{}_{}_{}".format(dat_conf['Warm start epoch'], dat_conf['Rd type'], 
                                 dat_conf['Batch size'], dat_conf['N groups per batch'],
                                 dat_conf['Batch size'])
    if exp_name not in exps:
        exps[exp_name] = {}
    exps[exp_name][dat_conf['Seed']] = (read_to_dict(log_di+'test_eval.csv'), 
                                        read_to_dict(log_di+'val_eval.csv'), 
                                        None,
                                        read_to_dict(log_di+'train_eval.csv'),
                                        log_di)

./AMASK_13/  0.001
./AMASK_14/  0.001
./AMASK_15/  0.001


In [31]:
for exp_id in exps:
    print("'{}'".format(exp_id))
    for kk in exps[exp_id]:
        print(np.max(exps[exp_id][kk][0]['epoch']))
        print(exps[exp_id][kk][4])

    resO = collect_results(exps[exp_id], ('ap', 'ap'), 0, None)

    for alg in resO:
        res_test_str = "ES_{} ".format(alg) +print_row(resO[alg]['test'], ('ap', 'ap'))
        res_id_t_str = "ES_{} ".format(alg) +print_row(resO[alg]['id_test'], ('ap', 'ap'))
        res_train_str = "ES_{} ".format(alg) +print_row(resO[alg]['train'], ('ap', 'ap'))

        print("Te\t", res_test_str)
        #print("ID \t", res_id_t_str)
        print("Tr \t", res_train_str)
        break
    print('\n')

' -1_ 42_ 32_ 4_ 32'
4.0
./AMASK_13/
4.0
./AMASK_14/
4.0
./AMASK_15/
Te	 ES_0  4.87 0.10 4.87 0.10
Tr 	 ES_0  3.20 0.15 3.20 0.15




In [30]:
exps.keys()

dict_keys([])

# FMOW

In [32]:
log_dicts = sorted(glob.glob('./AMASK_*/'), key= lambda x:int(x.split('_')[-1].split('/')[0]))

In [33]:
my_keys = ['Dataset', 'Algorithm', 'Uniform over groups', 'Distinct groups', 'N groups per batch']
my_keys += ['Batch size', 'Rd type', 'Warm start epoch', 'Control only direction', 'Only inconsistent']
my_keys += ['Without sampling', 'Lr', 'Weight decay', 'Seed']

In [34]:
exps = {}
for log_di in log_dicts:
    dat_conf = parse_config(log_di, my_keys)
    if not dat_conf['Dataset'] == ' fmow':
        continue
    print(log_di, dat_conf['Lr'])
    exp_name = "{}_{}_{}_{}_{}".format(dat_conf['Warm start epoch'], dat_conf['Rd type'], 
                                 dat_conf['Batch size'], dat_conf['N groups per batch'],
                                 dat_conf['Batch size'])
    if exp_name not in exps:
        exps[exp_name] = {}
    exps[exp_name][dat_conf['Seed']] = (read_to_dict(log_di+'test_eval.csv'), 
                                        read_to_dict(log_di+'val_eval.csv'), 
                                        read_to_dict(log_di+'id_test_eval.csv'),
                                        read_to_dict(log_di+'train_eval.csv'),
                                        log_di)

./AMASK_16/  0.0005
./AMASK_17/  0.0005
./AMASK_18/  0.0005


In [36]:
for exp_id in exps:
    print("'{}'".format(exp_id))
    for kk in exps[exp_id]:
        print(np.max(exps[exp_id][kk][0]['epoch']))
        print(exps[exp_id][kk][4])

    resO = collect_results(exps[exp_id], ('acc_avg', 'acc_worst_region'), 0, None)

    for alg in resO:
        res_test_str = "ES_{} ".format(alg) +print_row(resO[alg]['test'], ('acc_avg', 'acc_worst_region'))
        res_id_t_str = "ES_{} ".format(alg) +print_row(resO[alg]['id_test'], ('acc_avg', 'acc_worst_region'))
        res_train_str = "ES_{} ".format(alg) +print_row(resO[alg]['train'], ('acc_avg', 'acc_worst_region'))

        print("Te\t", res_test_str)
        print("ID \t", res_id_t_str)
        print("Tr \t", res_train_str)
    print('\n')

' -1_ 42_ 64_ 8_ 64'
25.0
./AMASK_16/
29.0
./AMASK_17/
29.0
./AMASK_18/
Te	 ES_0  26.35 1.36 15.67 0.42
ID 	 ES_0  31.62 1.42 28.04 1.53
Tr 	 ES_0  39.47 1.63 29.27 1.64
Te	 ES_1  25.46 0.66 17.70 0.75
ID 	 ES_1  30.81 1.15 27.85 1.77
Tr 	 ES_1  37.92 1.77 27.85 2.63




# Cam

In [37]:
log_dicts = sorted(glob.glob('./AMASK_*/'), key= lambda x:int(x.split('_')[-1].split('/')[0]))

In [38]:
my_keys = ['Dataset', 'Algorithm', 'Uniform over groups', 'Distinct groups', 'N groups per batch']
my_keys += ['Batch size', 'Rd type', 'Warm start epoch', 'Control only direction', 'Only inconsistent']
my_keys += ['Without sampling', 'Lr', 'Weight decay', 'Seed']

In [42]:
exps = {}
for log_di in log_dicts:
    dat_conf = parse_config(log_di, my_keys)
    if not dat_conf['Dataset'] == ' camelyon17':
        continue
    print(log_di, dat_conf['Lr'])
    exp_name = "{}_{}_{}_{}_{}".format(dat_conf['Warm start epoch'], dat_conf['Rd type'], 
                                 dat_conf['Lr'], dat_conf['N groups per batch'],
                                 dat_conf['Batch size'])
    if exp_name not in exps:
        exps[exp_name] = {}
    exps[exp_name][dat_conf['Seed']] = (read_to_dict(log_di+'test_eval.csv'), 
                                        read_to_dict(log_di+'val_eval.csv'), 
                                        read_to_dict(log_di+'id_val_eval.csv'),
                                        read_to_dict(log_di+'train_eval.csv'),
                                        log_di)

./AMASK_3/  0.005
./AMASK_4/  0.005
./AMASK_5/  0.005
./AMASK_6/  0.005
./AMASK_7/  0.005
./AMASK_8/  0.005
./AMASK_9/  0.005
./AMASK_10/  0.005
./AMASK_11/  0.005
./AMASK_12/  0.005


In [47]:
for exp_id in exps:
    print("'{}'".format(exp_id))
    skip_exp = False
    for kk in exps[exp_id]:
        print(np.max(exps[exp_id][kk][0]['epoch']))
        print(exps[exp_id][kk][4])
    if skip_exp:
        continue
        
    resO = collect_results(exps[exp_id], ('acc_avg', 'acc_wg'), 1, None)

    for alg in resO:
        res_test_str = "ES_{} ".format(alg) +print_row(resO[alg]['test'], ('acc_avg', 'acc_wg'))
        res_id_t_str = "ES_{} ".format(alg) +print_row(resO[alg]['id_test'], ('acc_avg', 'acc_wg'))
        res_train_str = "ES_{} ".format(alg) +print_row(resO[alg]['train'], ('acc_avg', 'acc_wg'))

        print("Te\t", res_test_str)
        #print("ID \t", res_id_t_str)
        #print("Tr \t", res_train_str)
    print('\n')

' -1_ 42_ 0.005_ 3_ 30'
4.0
./AMASK_3/
4.0
./AMASK_4/
4.0
./AMASK_5/
4.0
./AMASK_6/
4.0
./AMASK_7/
4.0
./AMASK_8/
4.0
./AMASK_9/
4.0
./AMASK_10/
4.0
./AMASK_11/
4.0
./AMASK_12/
Te	 ES_0  70.24 10.92 48.90 12.66
Te	 ES_1  69.05 9.90 47.99 11.41




# Poverty

In [48]:
log_dicts = sorted(glob.glob('./AMASK_*/'), key= lambda x:int(x.split('_')[-1].split('/')[0]))

In [49]:
my_keys = ['Dataset', 'Algorithm', 'Uniform over groups', 'Distinct groups', 'N groups per batch']
my_keys += ['Batch size', 'Rd type', 'Warm start epoch', 'Control only direction', 'Only inconsistent']
my_keys += ['Without sampling', 'Lr', 'Weight decay', 'Seed', 'Dataset kwargs']

In [53]:
exps = {}
for log_di in log_dicts:
    #if 'RDE' not in log_di:
    #    continue
    dat_conf = parse_config(log_di, my_keys)
    if not dat_conf['Dataset'] == ' poverty':
        continue
    print(log_di, dat_conf['Lr'])
    exp_name = "{}_{}_{}_{}_{}".format(dat_conf['Warm start epoch'], dat_conf['Rd type'], 
                                 dat_conf['Lr'], dat_conf['Lr'],
                                 dat_conf['Batch size'])
    if exp_name not in exps:
        exps[exp_name] = {}
    exps[exp_name][dat_conf['Dataset kwargs']] = (read_to_dict(log_di+'test_eval.csv'), 
                                        read_to_dict(log_di+'val_eval.csv'), 
                                        None,
                                        read_to_dict(log_di+'id_val_eval.csv'),
                                        log_di)

./AMASK_22/  0.0005
./AMASK_23/  0.0005
./AMASK_24/  0.0005
./AMASK_25/  0.0005
./AMASK_26/  0.0005


In [54]:
for exp_id in exps:
    print("'{}'".format(exp_id))
    for kk in exps[exp_id]:
        print(np.max(exps[exp_id][kk][0]['epoch']))
        print(exps[exp_id][kk][4])
        
    resO = collect_results(exps[exp_id], ('r_all', 'r_wg'), 0, None)

    for alg in resO:
        res_test_str = "OOD_{} ".format(alg) +print_row(resO[alg]['test'], ('r_all', 'r_wg'))
        res_id_t_str = "ID_{} ".format(alg) +print_row(resO[alg]['id_test'], ('r_all', 'r_wg'))
        res_train_str = "TR_{} ".format(alg) +print_row(resO[alg]['train'], ('r_all', 'r_wg'))
        res_val_str = "VAL_{} ".format(alg) +print_row(resO[alg]['val'], ('r_all', 'r_wg'))

        print(res_test_str)
        print(res_id_t_str)
        print(res_train_str)
        print(res_val_str)
    print('\n')

' -1_ 42_ 0.0005_ 0.0005_ 64'
164.0
./AMASK_22/
166.0
./AMASK_23/
166.0
./AMASK_24/
173.0
./AMASK_25/
171.0
./AMASK_26/
OOD_0  80.86 3.85 50.44 4.11
ID_0  0nan 0nan 0nan 0nan
TR_0  81.19 1.57 54.04 2.09
VAL_0  79.59 2.88 46.97 4.94
OOD_1  80.33 3.94 51.45 4.03
ID_1  0nan 0nan 0nan 0nan
TR_1  81.34 1.29 55.09 1.65
VAL_1  78.53 3.68 44.52 6.03




# Amazon

In [55]:
log_dicts = sorted(glob.glob('./AMASK_*/'), key= lambda x:int(x.split('_')[-1].split('/')[0]))

In [56]:
my_keys = ['Dataset', 'Algorithm', 'Uniform over groups', 'Distinct groups', 'N groups per batch']
my_keys += ['Batch size', 'Rd type', 'Warm start epoch', 'Control only direction', 'Only inconsistent']
my_keys += ['Without sampling', 'Lr', 'Weight decay', 'Seed', 'Dataset kwargs']

In [59]:
exps = {}
for log_di in log_dicts:
    #if 'RDE' not in log_di:
    #    continue
    dat_conf = parse_config_orig(log_di, my_keys)
    if not dat_conf['Dataset'] == ' amazon':
        continue
    print(log_di, dat_conf['Lr'])
    exp_name = "{}_{}_{}_{}_{}".format(dat_conf['Warm start epoch'], dat_conf['Rd type'], 
                                 dat_conf['Lr'], dat_conf['N groups per batch'],
                                 dat_conf['Batch size'])
    if exp_name not in exps:
        exps[exp_name] = {}
    test_res = read_to_dict(log_di+'test_eval.csv')
    if 'epoch' not in test_res:
        continue
    exps[exp_name][dat_conf['Seed']] = (test_res, 
                                        read_to_dict(log_di+'val_eval.csv'), 
                                        read_to_dict(log_di+'id_test_eval.csv'),
                                        read_to_dict(log_di+'train_eval.csv'),
                                        log_di)


./AMASK_19/  1e-05
./AMASK_20/  1e-05
./AMASK_21/  1e-05


In [60]:
for exp_id in exps:
    print("'{}'".format(exp_id))
    breakl = False
    for kk in exps[exp_id]:
        if 'epoch' not in exps[exp_id][kk][0]:
            breakl = True
            continue
        print(np.max(exps[exp_id][kk][0]['epoch']))
        print(exps[exp_id][kk][4])
    #if breakl:
    #    continue
    resO = collect_results(exps[exp_id], ('acc_avg', '10th_percentile_acc'), 0, None)

    for alg in resO:
        res_test_str = "OOD_{} ".format(alg) +print_row(resO[alg]['test'], ('acc_avg', '10th_percentile_acc'))
        res_id_t_str = "ID_{} ".format(alg) +print_row(resO[alg]['id_test'], ('acc_avg', '10th_percentile_acc'))
        res_train_str = "TR_{} ".format(alg) +print_row(resO[alg]['train'], ('acc_avg', '10th_percentile_acc'))
        res_val_str = "VAL_{} ".format(alg) +print_row(resO[alg]['val'], ('acc_avg', '10th_percentile_acc'))

        print(res_test_str)
        #print(res_id_t_str)
        #print(res_train_str)
        #print(res_val_str)
    print('\n')

' -1_ 42_ 1e-05_ 2_ 8'
2.0
./AMASK_19/
2.0
./AMASK_20/
2.0
./AMASK_21/
OOD_0  70.30 0.39 51.56 0.63
OOD_1  69.44 0.83 52.00 0.00


