In [114]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [330]:
def _load_eval_rows(csv_path):
    """Read one ``*_eval.csv`` log and return its ``[epoch, acc_avg, acc_wg]`` rows.

    Returns ``None`` (after printing ``missing <path>``) when the file cannot be
    opened or is empty, so callers can skip runs that have not produced logs
    yet. Any other failure (malformed CSV, absent column) is left to propagate
    because it signals a real problem rather than a pending run.
    """
    try:
        frame = pd.read_csv(csv_path)
    except (OSError, pd.errors.EmptyDataError):
        print("missing " + csv_path)
        return None
    return frame[["epoch", "acc_avg", "acc_wg"]].values.tolist()


def get_result(log_path, early_stop=False, use_train=False, expected_epochs=None,
               dataset=None, log_root="../logs"):
    """Return selection-split and test accuracies of one run at its chosen epoch.

    The run's logs are read from ``{log_root}/{dataset}/{log_path}/`` as
    ``train_eval.csv``, ``val_eval.csv`` and ``test_eval.csv``; each must provide
    the columns ``epoch``, ``acc_avg`` and ``acc_wg``.

    Parameters
    ----------
    log_path : str
        Run directory name below the dataset directory.
    early_stop : bool
        If True, choose the epoch with the highest ``acc_wg`` on the selection
        split (ties go to the row appearing last in the log). If False, use the
        last logged row.
    use_train : bool
        Select the epoch on the train log instead of the validation log.
    expected_epochs : int or None
        When given, print ``log_path`` and the actual number of train rows if
        that number differs (the result is still computed).
    dataset : str or None
        Dataset directory name. ``None`` falls back to the notebook-level
        global ``ds_name``, which is what existing callers rely on.
    log_root : str
        Directory that contains the per-dataset log folders.

    Returns
    -------
    numpy.ndarray of shape (1, 4) or None
        ``[[sel_acc_avg, sel_acc_wg, test_acc_avg, test_acc_wg]]`` where ``sel``
        is the train split when ``use_train`` is set and the validation split
        otherwise. ``None`` if a log file is missing or empty, or if the chosen
        epoch does not appear in the test log.
    """
    if dataset is None:
        dataset = ds_name  # notebook global assigned by the calling cell
    run_dir = f"{log_root}/{dataset}/{log_path}"

    train_res = _load_eval_rows(f"{run_dir}/train_eval.csv")
    if train_res is None:
        return None
    if expected_epochs is not None and len(train_res) != expected_epochs:
        # Flag runs that stopped early or are still in progress.
        print(log_path, len(train_res))

    val_res = _load_eval_rows(f"{run_dir}/val_eval.csv")
    test_res = _load_eval_rows(f"{run_dir}/test_eval.csv")
    if val_res is None or test_res is None:
        return None

    selection = train_res if use_train else val_res
    if not selection:
        print(f"{log_path}: no logged epochs to select from")
        return None
    if early_stop:
        # sorted() is stable, so among equal acc_wg values the row that comes
        # last in the log stays last and is the one selected.
        selection = sorted(selection, key=lambda row: row[2])
    best_epoch_no, sel_avg, sel_wg = selection[-1]

    for epoch, test_avg, test_wg in test_res:
        if epoch == best_epoch_no:
            return np.array([[sel_avg, sel_wg, test_avg, test_wg]])

    print(f"{log_path}: epoch {best_epoch_no} not found in test_eval.csv")
    return None

In [439]:
# Dataset whose log directory the cells below read from.
ds_name = "celebA"

# Algorithm label -> run directory names, one per swept hyper-parameter value.
# Each run name is a fixed template with the swept value substituted in, so the
# table lists (label, template, values) and expands it; the order of labels and
# of values is significant for the report printed later.
# Not included: the un-suffixed baseline runs (e.g. "erm-resnet50") and the
# "weightederm-dp_resnet50-dpsgd_1e-5_<v>_0.1_0.0001" sweep.
algos = {
    label: [template.format(value) for value in values]
    for label, template, values in (
        ("ERM", "erm-resnet50_wd{}", ("0.001", "0.01", "0.1", "1.0")),
        ("ISERM", "erm_reweight-resnet50-lr1e-3_wd{}", ("0.001", "0.01", "0.1", "1.0")),
        ("IWERM", "iwerm-resnet50_wd{}", ("0.001", "0.01", "0.1", "1.0")),
        ("gDRO", "groupDRO-resnet50-lr1e-3-wd{}", ("0.001", "0.01", "0.1", "1.0")),
        # This sweep is listed from the largest value down to the smallest.
        ("DP IW", "weightederm-dp_resnet50-dpsgd_1e-5_{}_1.0_0.0001", ("1.0", "0.1", "0.01", "0.001")),
        ("ISERM-n", "erm_reweight-resnet50-lr1e-3-noisesgd_{}", ("0.001", "0.01", "0.1", "1.0")),
        ("IWERM-n", "iwerm-resnet50-lr1e-3-noisesgd_1e-5_{}", ("0.001", "0.01", "0.1", "1.0")),
        ("gDRO-n", "groupDRO-resnet50-lr1e-3-noisesgd_{}", ("0.001", "0.01", "0.1", "1.0")),
    )
}

# Metric column names and their shorter display names, position for position.
cols = ["acc_avg", "acc_wg"] + [
    f"acc_y:{group}"
    for group in ("notblond_male:0", "notblond_male:1", "blond_male:0", "blond_male:1")
]
col_names = [
    "acc",
    "acc_wg",
    "notblond_male:0",
    "notblond_male:1",
    "blond_male:0",
    "blond_male:1",
]

# Pool the runs of every celebA configuration and report, per algorithm, the
# configuration with the best mean validation worst-group accuracy.
#
# Each row appended to best_test_results[name] has 8 entries: the mean over the
# available runs of get_result's 4 values (with use_train left at False these
# are val acc_avg, val acc_wg, test acc_avg, test acc_wg), followed by the 4
# matching standard deviations.
best_test_results = {}
# run_counts[name][k] is how many runs went into best_test_results[name][k];
# the standard error below needs the real count, not an assumed one.
run_counts = {}
for name, log_paths in algos.items():
    for log_path in log_paths:
        # The "<log_path>_sp1" / "<log_path>_sp2" runs are pooled with the
        # un-suffixed run of the same configuration.
        runs = [
            get_result(log_path + f"_sp{i}", early_stop=True, expected_epochs=50)
            for i in range(1, 3)
        ]
        runs.append(get_result(log_path, early_stop=True))
        # get_result returns None for a run it could not load; drop those
        # instead of letting None reach np.concatenate.
        runs = [run for run in runs if run is not None]
        if not runs:
            print("no results for", name, log_path)
            continue
        # With a single run the mean is the run itself and the std is 0, so no
        # special case is required.
        stacked = np.concatenate(runs, axis=0)
        summary = np.concatenate((stacked.mean(0), stacked.std(0)))
        best_test_results.setdefault(name, list()).append(summary)
        run_counts.setdefault(name, list()).append(len(runs))

for name in algos.keys():
    if name not in best_test_results:
        # Every run of this algorithm was missing; nothing to rank.
        print(name, "no results")
        continue
    res = np.array(best_test_results[name])
    best = res[:, 1].argmax()  # column 1: mean validation acc_wg
    # Mean test acc_wg (column 3) and its standard error: the std across runs
    # (column 7) divided by sqrt(number of runs actually found).
    print(name, res[best][3], res[best][7] / np.sqrt(run_counts[name][best]))

ERM 0.7300747831662496 0.021666751073132276
ISERM 0.8240715463956197 0.0045184053176029046
IWERM 0.8901897271474203 0.009103023475341636
gDRO 0.8449074228604635 0.008186285338262732
DP IW 0.8595568736394247 0.008080739898989674
ISERM-n 0.849392036596934 0.009701063993641876
IWERM-n 0.8854014078776041 0.003549978178748661
gDRO-n 0.8333333532015482 0.004536087790310597


In [417]:
# Display the rows currently stored under the "ERM" key of best_test_results.
best_test_results["ERM"]

[array([0.94612507, 0.65042241, 0.94668202, 0.64376527, 0.00444627,
        0.03816489, 0.00271142, 0.0186927 ]),
 array([0.91003507, 0.74469958, 0.91330192, 0.72649195, 0.03024529,
        0.07874851, 0.02311135, 0.11103518]),
 array([0.8528045 , 0.51947274, 0.8555088 , 0.4911859 , 0.00480203,
        0.36773473, 0.00818055, 0.34921376]),
 array([0.85023743, 0.        , 0.85592628, 0.        , 0.00315797,
        0.        , 0.00780757, 0.        ])]

In [315]:
# Display the whole best_test_results mapping.
# NOTE(review): the saved output below lists five ERM rows while the celebA
# config above enables four ERM runs, so this output looks stale - re-run to refresh.
best_test_results

{'ERM': [array([0.9470982 , 0.58791208, 0.9498046 , 0.59444445, 0.        ,
         0.        , 0.        , 0.        ]),
  array([0.94050437, 0.69230771, 0.94389337, 0.66666669, 0.        ,
         0.        , 0.        , 0.        ]),
  array([0.91003507, 0.74469958, 0.91330192, 0.72649195, 0.03024529,
         0.07874851, 0.02311135, 0.11103518]),
  array([0.8528045 , 0.51947274, 0.8555088 , 0.4911859 , 0.00480203,
         0.36773473, 0.00818055, 0.34921376]),
  array([0.85023743, 0.        , 0.85592628, 0.        , 0.00315797,
         0.        , 0.00780757, 0.        ])],
 'ISERM': [array([0.94397748, 0.65934068, 0.94760042, 0.62222224, 0.        ,
         0.        , 0.        , 0.        ]),
  array([0.94125935, 0.73403701, 0.94086266, 0.71251526, 0.00543615,
         0.00772122, 0.00107706, 0.0347375 ]),
  array([0.92223284, 0.83486566, 0.9197976 , 0.7976535 , 0.00377509,
         0.04365686, 0.00886685, 0.06987572]),
  array([0.85780442, 0.84628004, 0.86995292, 0.8277778 

In [442]:
# Dataset whose log directory the cells below read from.
ds_name = "utkface"

# Algorithm label -> run directory names, one per swept hyper-parameter value.
# Each run name is a fixed template with the swept value substituted in; the
# order of labels and of values is significant for the report printed later.
# Not included: the un-suffixed baseline runs (e.g. "erm-resnet50"), the
# 0.0001 / 10.0 end points of the IWERM-n sweep, and for "DP IW" the
# "cw0.001" / "cw0.002", "lr1e-4", "lr1e-2" and "..._10.0_0.001" variants as
# well as the 0.0001 end point of the active sweep.
algos = {
    label: [template.format(value) for value in values]
    for label, template, values in (
        ("ERM", "erm-resnet50_wd{}", ("0.001", "0.01", "0.1", "1.0")),
        ("ISERM", "erm_reweight-resnet50_wd{}", ("0.001", "0.01", "0.1", "1.0")),
        ("IWERM", "iwerm-resnet50_wd{}", ("0.001", "0.01", "0.1", "1.0")),
        ("gDRO", "groupDRO-resnet50_wd{}", ("0.001", "0.01", "0.1", "1.0")),
        ("ISERM-n", "erm_reweight-resnet50-lr1e-3-noisesgd_{}", ("0.001", "0.01", "0.1", "1.0")),
        ("gDRO-n", "groupDRO-resnet50-lr1e-3-noisesgd_{}", ("0.001", "0.01", "0.1", "1.0")),
        ("IWERM-n", "iwerm-resnet50-lr1e-3-noisesgd_1e-5_{}__", ("0.001", "0.01", "0.1", "1.0")),
        ("DP IW", "weightederm-dp_resnet50-lr1e-3-dpsgd_1e-5_{}_1.0_0.001", ("0.001", "0.01", "0.1", "1.0")),
        # Only a single run is listed for this algorithm.
        ("DP gDRO", "groupdro-dp_resnet50-dpsgd_1e-5_{}_1.0_0.001", ("0.01",)),
    )
}

# Collect, for every utkface run, the metrics at its early-stopped epoch and
# print the best configuration per algorithm.
#
# Each stored row is get_result's single result row; with use_train left at
# False that is [val acc_avg, val acc_wg, test acc_avg, test acc_wg].
best_test_results = {}
for name, log_paths in algos.items():
    for log_path in log_paths:
        res = get_result(log_path, early_stop=True, expected_epochs=100)
        # get_result returns None for a run it could not load; skip those.
        if res is not None:
            best_test_results.setdefault(name, list()).append(res[0])

for name in algos.keys():
    if name not in best_test_results:
        # Every run of this algorithm was skipped above; without this guard
        # the lookup below raises KeyError and hides the remaining algorithms.
        print(name, "no results")
        continue
    res = np.array(best_test_results[name])
    # Rank configurations by validation acc_wg (column 1).
    print(name, res[res[:, 1].argmax()])

ERM [0.91949999 0.88976377 0.92502123 0.86301368]
ISERM [0.93400002 0.90163934 0.91291422 0.85808581]
IWERM [0.94300002 0.88524592 0.9267205  0.86468649]
gDRO [0.91649997 0.85245901 0.92502123 0.85215056]
ISERM-n [0.92799997 0.89873415 0.90314358 0.8547855 ]
gDRO-n [0.91850001 0.89051098 0.92459643 0.875     ]
IWERM-n [0.91900003 0.89204544 0.91694987 0.88499999]
DP IW [0.89649999 0.84810126 0.87829226 0.82456142]
DP gDRO [0.89249998 0.85245901 0.88423961 0.76897693]


In [443]:
# Dataset whose log directory the cells below read from.
ds_name = "inaturalist"

# Algorithm label -> run directory names, one per swept hyper-parameter value.
# Each run name is a fixed template with the swept value substituted in; the
# order of labels and of values is significant for the report printed later.
# Not included: the un-suffixed baseline runs (e.g. "erm-resnet18-lr1e-3") and,
# for "DP IW", the "..._<v>_1.0_0.0001" and "..._<v>_100.0_0.0001" sweeps.
algos = {
    label: [template.format(value) for value in values]
    for label, template, values in (
        ("ERM", "erm-resnet18-lr1e-3_wd{}", ("0.001", "0.01", "0.1", "1.0")),
        ("ISERM", "erm_reweight-resnet18-lr1e-3_wd{}", ("0.001", "0.01", "0.1", "1.0")),
        ("IWERM", "iwerm-resnet18-lr1e-3_wd{}", ("0.001", "0.01", "0.1", "1.0")),
        ("gDRO", "groupDRO-resnet18-lr1e-3_wd{}", ("0.001", "0.01", "0.1", "1.0")),
        # This sweep covers a much smaller value range than the others.
        ("DP IW", "weightederm-dp_resnet18-dpsgd_1e-5_{}_10.0_0.0001", ("0.0000001", "0.000001", "0.00001", "0.0001")),
        ("ISERM-n", "erm_reweight-resnet18-lr1e-3-noisesgd_{}", ("0.001", "0.01", "0.1", "1.0")),
        ("IWERM-n", "iwerm-resnet18-lr1e-3-noisesgd_1e-5_{}", ("0.001", "0.01", "0.1", "1.0")),
        ("gDRO-n", "groupDRO-resnet18-lr1e-3-noisesgd_{}__", ("0.001", "0.01", "0.1", "1.0")),
    )
}

# Metric column names and their shorter display names, position for position.
# NOTE(review): the group names mention blond/male attributes and look copied
# from the celebA cell; nothing visible in this cell reads them - confirm
# before relying on them for this dataset.
cols = ["acc_avg", "acc_wg"] + [
    f"acc_y:{group}"
    for group in ("notblond_male:0", "notblond_male:1", "blond_male:0", "blond_male:1")
]
col_names = [
    "acc",
    "acc_wg",
    "notblond_male:0",
    "notblond_male:1",
    "blond_male:0",
    "blond_male:1",
]

# Collect, for every inaturalist run, the metrics at its early-stopped epoch
# and print the best configuration per algorithm.
#
# Each stored row is get_result's single result row; with use_train left at
# False that is [val acc_avg, val acc_wg, test acc_avg, test acc_wg].
best_test_results = {}
for name, log_paths in algos.items():
    for log_path in log_paths:
        res = get_result(log_path, early_stop=True, expected_epochs=20)
        # get_result returns None for a run it could not load; skip those.
        if res is not None:
            best_test_results.setdefault(name, list()).append(res[0])

for name in algos.keys():
    if name not in best_test_results:
        # Every run of this algorithm was skipped above; without this guard
        # the lookup below raises KeyError and hides the remaining algorithms.
        print(name, "no results")
        continue
    res = np.array(best_test_results[name])
    # Rank configurations by validation acc_wg (column 1).
    print(name, res[res[:, 1].argmax()])

ERM [0.90399998 0.55769229 0.90554744 0.41818181]
ISERM [0.84227997 0.74134421 0.84210455 0.71099049]
IWERM [0.87085998 0.72708756 0.87208992 0.67567569]
gDRO [0.83797997 0.65714288 0.83750165 0.67272729]
DP IW [0.74822003 0.60000002 0.74943459 0.51351351]
ISERM-n [0.84289998 0.71428573 0.84276974 0.70963365]
IWERM-n [0.84460002 0.73076922 0.84235734 0.70909089]
gDRO-n [0.88792002 0.60000002 0.88942397 0.56363636]


In [311]:
# Display the whole best_test_results mapping.
# NOTE(review): the saved output below lists five ERM rows while the configs
# above enable four ERM runs, so this output looks stale - re-run to refresh.
best_test_results

{'ERM': [array([0.90215999, 0.55769229, 0.90355194, 0.41818181]),
  array([0.90399998, 0.55769229, 0.90554744, 0.41818181]),
  array([0.85795999, 0.        , 0.86150062, 0.        ]),
  array([0.69995999, 0.        , 0.7025941 , 0.        ]),
  array([0.34994   , 0.        , 0.34922177, 0.        ])],
 'gDRO': [array([0.88472003, 0.59615386, 0.88539308, 0.56756759]),
  array([0.83797997, 0.65714288, 0.83750165, 0.67272729]),
  array([0.83921999, 0.60000002, 0.8431555 , 0.61818182]),
  array([0.63998002, 0.48571429, 0.63738191, 0.51351351]),
  array([0.01386   , 0.        , 0.01406146, 0.        ])],
 'ISERM': [array([0.84276003, 0.74285716, 0.84308898, 0.70556307]),
  array([0.84227997, 0.74134421, 0.84210455, 0.71099049]),
  array([0.74934   , 0.37244385, 0.75246775, 0.39252338]),
  array([0.00922  , 0.       , 0.0087801, 0.       ])],
 'ISERM-n': [array([0.84289998, 0.71428573, 0.84276974, 0.70963365]),
  array([0.87128001, 0.63461536, 0.87157112, 0.61818182]),
  array([0.58771998, 0