In [1]:
import os
import re
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from scipy import stats
from tensorboard.backend.event_processing import event_accumulator
import pickle

In [2]:
def parse_tensorboard(path, scalars=None):
    """Load scalar curves from a TensorBoard log directory.

    Args:
        path: run directory (or event file) handed to ``EventAccumulator``.
        scalars: iterable of scalar tags to load; ``None`` loads every scalar tag in the log.

    Returns:
        dict mapping each requested tag that exists in the log to a ``pd.DataFrame`` built
        from TensorBoard's ``ScalarEvent`` records (one row per logged event). Requested tags
        that are not in the log are skipped.
    """
    ea = event_accumulator.EventAccumulator(
        path,
        # 0 = keep every scalar event instead of TensorBoard's default sampling
        size_guidance={event_accumulator.SCALARS: 0},
    )
    ea.Reload()  # read the event files; the returned accumulator is not needed
    available = ea.Tags()["scalars"]
    wanted = available if scalars is None else scalars
    available_set = set(available)
    # Skip missing tags explicitly rather than with a bare `except:`, which also swallowed
    # unrelated failures (corrupt logs, pandas errors, KeyboardInterrupt).
    return {s: pd.DataFrame(ea.Scalars(s)) for s in wanted if s in available_set}

In [3]:
# Substrings used to select which TensorBoard tags are kept in `results_prep`:
# a logged tag is kept when it CONTAINS one of these (e.g. "val/acc_task" also keeps
# "val/acc_task_eval" and "val/acc_task_debiased_gender").
# Older runs used the adv-attack suffixes adv_attack_unbiased / adv_attack_biased.
scalars = [
    f"val/{metric}"
    for metric in ("acc_task", "loss_task", "balanced_acc_adv_attack", "loss_adv_attack")
]

In [23]:
# Root holding one TensorBoard run directory per (experiment, seed).
folder = "/share/home/lukash/pan16/bertl4/logs_merged_masks"
N_SEEDS = 5  # every experiment is expected to have runs for seeds 0..N_SEEDS-1

# Collapse per-seed directory names into templates, e.g. "adv_0.05_seed3" -> "adv_0.05_seed{}".
# `\d+` (not the character class `[\d+]`, which matches ONE digit or a literal "+") so
# multi-digit seeds are replaced as a whole.
experiment_names = {re.sub(r"(?<=seed)\d+", "{}", n) for n in os.listdir(folder)}

# results[template] -> list (seed order) of {tag: DataFrame} dicts.
# Sorted iteration: set order of strings is hash-randomized, which would make every
# downstream printout order change between kernel restarts.
results = {}
for n in sorted(experiment_names):
    results[n] = [
        parse_tensorboard(os.path.join(folder, n.format(i)))
        for i in range(N_SEEDS)
    ]

In [5]:
# with open("results.pkl", "wb") as f:
#     pickle.dump(results, f)

In [6]:
# with open("results.pkl", "rb") as f:
#     results = pickle.load(f)

In [7]:
#results['adverserial-baseline-bert_uncased_L-4_H-256_A-4-64-2e-05-weighted_loss_prot-gender-seed{}'][0]

In [25]:
# Regroup as results_prep[experiment][tag] -> list of per-seed DataFrames, keeping only
# tags that contain one of the `scalars` substrings.
results_prep = {}
for exp, data in results.items():
    results_prep[exp] = {}
    for res_seed in data:
        for k, v in res_seed.items():
            # any(): a tag matching several substrings must still be appended only once per
            # seed, otherwise its list gets longer than the number of seeds and no longer lines
            # up with the other tags' lists.
            if any(s in k for s in scalars):
                results_prep[exp].setdefault(k, []).append(v)

In [17]:
# Strip the "-cp_init" marker from experiment names.
renamed = {k: k.replace("-cp_init", "") for k in results_prep}
# Two experiments mapping onto the same name would otherwise overwrite each other silently.
if len(set(renamed.values())) != len(renamed):
    raise ValueError("removing '-cp_init' makes experiment names collide")
results_prep = {renamed[k]: v for k, v in results_prep.items()}

In [27]:
def get_scalar_stats(
    exp_results,
    strategy="max",  # max, min, last, argmin, argmax
    idx_results=None,
    return_stats=True,
):
    """Reduce one scalar curve per seed to a single number and aggregate over seeds.

    Args:
        exp_results: list with one DataFrame per seed, each holding a ``"value"`` column
            (as returned by ``parse_tensorboard``).
        strategy: per-seed reduction.
            ``"max"`` / ``"min"`` / ``"last"``: max, min or final entry of the seed's ``"value"``.
            ``"argmax"`` / ``"argmin"``: position where the matching ``idx_results`` curve is
            maximal / minimal; ``exp_results`` is read at that same position.
        idx_results: list with one DataFrame per seed, in the same order as ``exp_results``;
            required for ``"argmin"`` / ``"argmax"``, ignored otherwise.
        return_stats: if True return ``np.array([mean, std])`` over seeds, otherwise the
            array of per-seed values.

    Returns:
        ``np.ndarray``. The std is the population std (numpy default, ``ddof=0``).
        An empty ``exp_results`` yields NaN statistics.

    Raises:
        ValueError: unknown ``strategy``, missing ``idx_results`` for an arg-strategy, or a
            different number of seeds in ``idx_results`` and ``exp_results``.
    """
    if strategy in ("argmin", "argmax"):
        if idx_results is None:
            raise ValueError(f"strategy {strategy} requires idx_results")
        if len(idx_results) != len(exp_results):
            # zip() would otherwise silently drop the surplus seeds.
            raise ValueError(
                f"idx_results has {len(idx_results)} seeds but exp_results has {len(exp_results)}"
            )
        pick = (lambda x: x.argmax()) if strategy == "argmax" else (lambda x: x.argmin())
        # Series.argmin/argmax return integer positions, hence .iloc below.
        # NOTE(review): assumes both curves were logged at the same evaluation steps
        # (same row order) — confirm for each (metric, loss) tag pair.
        idx_list = [pick(x["value"]) for x in idx_results]
        res = np.array([x["value"].iloc[i] for i, x in zip(idx_list, exp_results)])
    else:
        reducers = {
            "max": lambda x: x.max(),
            "min": lambda x: x.min(),
            "last": lambda x: x.iloc[-1],
        }
        if strategy not in reducers:
            raise ValueError(f"strategy {strategy} undefined")
        res = np.array([reducers[strategy](x["value"]) for x in exp_results])

    if return_stats:
        return np.array([res.mean(), res.std()])
    return res


def get_scalar_stats_wrapper_max(results, exp_name, scalar):
    """[mean, std] over seeds of each seed's maximum of tag `scalar` in experiment `exp_name`.

    `results` is nested as results[exp_name][scalar] -> list of per-seed DataFrames.
    """
    return get_scalar_stats(results[exp_name][scalar], strategy="max")


def get_scalar_stats_wrapper_argmin(results, exp_name, scalar, scalar_idx):
    """[mean, std] over seeds of tag `scalar`, read where tag `scalar_idx` is minimal.

    Per seed, the position of the smallest `scalar_idx` value selects which `scalar` value
    is used (e.g. attacker balanced accuracy at the lowest attacker loss).
    `results` is nested as results[exp_name][tag] -> list of per-seed DataFrames.
    """
    per_tag = results[exp_name]
    return get_scalar_stats(
        per_tag[scalar],
        strategy="argmin",
        idx_results=per_tag[scalar_idx],
    )

In [28]:
# Experiment name templates currently loaded, alphabetically.
sorted(results_prep)

['adv_0.05_seed{}',
 'adv_0.1_seed{}',
 'modular_0.05_seed{}',
 'modular_0.1_seed{}']

In [31]:
# Every scalar tag kept for at least one experiment.
{tag for per_exp in results_prep.values() for tag in per_exp}

{'val/acc_task_eval',
 'val/balanced_acc_adv_attack_age',
 'val/balanced_acc_adv_attack_gender',
 'val/loss_adv_attack_age',
 'val/loss_adv_attack_gender',
 'val/loss_task_eval'}

In [32]:
# Tags to report per experiment. Each entry is
#   [task-accuracy tag, [attacker balanced-acc tag, attacker loss tag] (gender), (... age)]
# The attacker pairs are meant for get_scalar_stats_wrapper_argmin: balanced accuracy is read
# where the attacker's validation loss is lowest.
_MODEL_TAG = "bert_uncased_L-4_H-256_A-4-64-2e-05-weighted_loss_prot"
_VARIANTS = ("baseline", "diff_pruning_0.05", "diff_pruning_0.1")
# Protected-attribute setting of an adversarial run -> embedding name used in its attack tags.
_ADV_EMB = {"age": "adv_emb_age", "gender": "adv_emb_gender", "gender_age": "adv_emb_all"}


def _attack_keys(emb, target):
    """[balanced-accuracy tag, loss tag] of the attacker probing embedding `emb` for `target`."""
    return [
        f"val/balanced_acc_adv_attack_{emb}_target_key_{target}",
        f"val/loss_adv_attack_{emb}_target_key_{target}",
    ]


def _task_name(variant):
    """Seed-agnostic name template of a task-only run."""
    return f"task-{variant}-{_MODEL_TAG}-seed{{}}"


def _adv_name(variant, prot):
    """Seed-agnostic name template of an adversarial run protecting `prot`."""
    return f"adverserial-{variant}-{_MODEL_TAG}-{prot}-seed{{}}"


key_map_pan16 = {
    _task_name(variant): [
        "val/acc_task",
        _attack_keys("task_emb", "gender"),
        _attack_keys("task_emb", "age"),
    ]
    for variant in _VARIANTS
}
key_map_pan16.update({
    _adv_name(variant, prot): [
        "val/acc_task_debiased",
        _attack_keys(emb, "gender"),
        _attack_keys(emb, "age"),
    ]
    for variant in _VARIANTS
    for prot, emb in _ADV_EMB.items()
})

# BIOS only has gender; its adversarial runs log the attack under the "adv_emb_all" embedding.
key_map_bios = {
    _task_name(variant): ["val/acc_task", _attack_keys("task_emb", "gender")]
    for variant in _VARIANTS
}
key_map_bios.update({
    _adv_name(variant, "gender"): ["val/acc_task_debiased", _attack_keys("adv_emb_all", "gender")]
    for variant in _VARIANTS
})


keys_merged_masks = [
    "val/acc_task_eval",
    *(
        [f"val/balanced_acc_adv_attack_{target}", f"val/loss_adv_attack_{target}"]
        for target in ("gender", "age")
    ),
]


keys_modular = [
    "val/acc_task",
    "val/acc_task_debiased_gender",
    "val/acc_task_debiased_age",
    *(
        _attack_keys(emb, target)
        for emb in ("task_emb", "adv_emb_gender", "adv_emb_age")
        for target in ("gender", "age")
    ),
]

In [21]:
# PAN16 summary per loaded experiment: best task accuracy, and attacker balanced accuracy
# (gender, age) read where the corresponding attacker validation loss is lowest.
for exp, keys in key_map_pan16.items():
    if exp not in results_prep:
        continue
    task_key, gender_keys, age_keys = keys[:3]
    acc_task = get_scalar_stats_wrapper_max(results_prep, exp, task_key)
    acc_g = get_scalar_stats_wrapper_argmin(results_prep, exp, *gender_keys)
    acc_a = get_scalar_stats_wrapper_argmin(results_prep, exp, *age_keys)
    print(exp)
    print(f"acc task: {acc_task[0]:.3f} +- {acc_task[1]:.3f}")
    print(f"bacc attack gender: {acc_g[0]:.3f} +- {acc_g[1]:.3f}")
    print(f"bacc attack age: {acc_a[0]:.3f} +- {acc_a[1]:.3f}")
    print("\n")
    

In [22]:
# Modular runs: task accuracy (plain, debiased-gender, debiased-age) plus attacker balanced
# accuracy on each embedding, read where that attacker's validation loss is lowest.
for exp in results_prep:
    acc_task, acc_task_g, acc_task_a = [
        get_scalar_stats_wrapper_max(results_prep, exp, acc_key)
        for acc_key in keys_modular[:3]
    ]
    (
        bacc_g_task_emb, bacc_a_task_emb,
        bacc_g_gender_emb, bacc_a_gender_emb,
        bacc_g_age_emb, bacc_a_age_emb,
    ) = [
        get_scalar_stats_wrapper_argmin(results_prep, exp, *attack_keys)
        for attack_keys in keys_modular[3:9]
    ]
    print(exp)
    print(f"acc task: {acc_task[0]:.3f} +- {acc_task[1]:.3f}")
    print(f"acc task debiased gender: {acc_task_g[0]:.3f} +- {acc_task_g[1]:.3f}")
    print(f"acc task debiased age: {acc_task_a[0]:.3f} +- {acc_task_a[1]:.3f}")
    print(f"bacc attack gender - task emb: {bacc_g_task_emb[0]:.3f} +- {bacc_g_task_emb[1]:.3f}")
    print(f"bacc attack age - task emb: {bacc_a_task_emb[0]:.3f} +- {bacc_a_task_emb[1]:.3f}")
    print(f"bacc attack gender - g emb: {bacc_g_gender_emb[0]:.3f} +- {bacc_g_gender_emb[1]:.3f}")
    print(f"bacc attack age - g emb: {bacc_a_gender_emb[0]:.3f} +- {bacc_a_gender_emb[1]:.3f}")
    print(f"bacc attack gender - a emb: {bacc_g_age_emb[0]:.3f} +- {bacc_g_age_emb[1]:.3f}")
    print(f"bacc attack age - a emb: {bacc_a_age_emb[0]:.3f} +- {bacc_a_age_emb[1]:.3f}")
    print("\n")

modular-diff_pruning_0.1-bert_uncased_L-4_H-256_A-4-64-2e-05-weighted_loss_prot-gender_age-seed{}
acc task: 0.919 +- 0.001
acc task debiased gender: 0.919 +- 0.001
acc task debiased age: 0.917 +- 0.001
bacc attack gender - task emb: 0.612 +- 0.004
bacc attack age - task emb: 0.421 +- 0.003
bacc attack gender - g emb: 0.560 +- 0.004
bacc attack age - g emb: 0.406 +- 0.011
bacc attack gender - a emb: 0.611 +- 0.002
bacc attack age - a emb: 0.370 +- 0.016


modular-diff_pruning_0.05-bert_uncased_L-4_H-256_A-4-64-2e-05-weighted_loss_prot-gender_age-seed{}
acc task: 0.919 +- 0.001
acc task debiased gender: 0.919 +- 0.001
acc task debiased age: 0.917 +- 0.001
bacc attack gender - task emb: 0.613 +- 0.004
bacc attack age - task emb: 0.423 +- 0.004
bacc attack gender - g emb: 0.562 +- 0.010
bacc attack age - g emb: 0.410 +- 0.009
bacc attack gender - a emb: 0.612 +- 0.003
bacc attack age - a emb: 0.362 +- 0.007




In [36]:
# Merged-mask runs: best task accuracy, and attacker balanced accuracy for gender and age
# read where the corresponding attacker validation loss is lowest.
task_key, gender_keys, age_keys = keys_merged_masks
for exp in results_prep:
    acc_task = get_scalar_stats_wrapper_max(results_prep, exp, task_key)
    bacc_gender = get_scalar_stats_wrapper_argmin(results_prep, exp, *gender_keys)
    bacc_age = get_scalar_stats_wrapper_argmin(results_prep, exp, *age_keys)
    print(exp)
    print(f"acc task: {acc_task[0]:.3f} +- {acc_task[1]:.3f}")
    print(f"bacc attack gender: {bacc_gender[0]:.3f} +- {bacc_gender[1]:.3f}")
    print(f"bacc attack age: {bacc_age[0]:.3f} +- {bacc_age[1]:.3f}")
    print("\n")

adv_0.05_seed{}
acc task: 0.917 +- 0.002
bacc attack gender: 0.598 +- 0.007
bacc attack age: 0.390 +- 0.012


adv_0.1_seed{}
acc task: 0.917 +- 0.001
bacc attack gender: 0.596 +- 0.007
bacc attack age: 0.384 +- 0.004


modular_0.05_seed{}
acc task: 0.916 +- 0.000
bacc attack gender: 0.583 +- 0.009
bacc attack age: 0.367 +- 0.017


modular_0.1_seed{}
acc task: 0.917 +- 0.001
bacc attack gender: 0.583 +- 0.006
bacc attack age: 0.360 +- 0.009




In [93]:
# PAN16 summary that tolerates runs without an age attacker: age is reported only when the
# key map has an age entry AND both age tags were logged for the experiment. The age line is
# never printed from a value that was not computed for this experiment.
for exp, keys in key_map_pan16.items():
    if exp not in results_prep:
        continue
    acc_task = get_scalar_stats_wrapper_max(results_prep, exp, keys[0])
    acc_g = get_scalar_stats_wrapper_argmin(results_prep, exp, *keys[1])
    has_age = len(keys) > 2 and all(tag in results_prep[exp] for tag in keys[2])
    acc_a = get_scalar_stats_wrapper_argmin(results_prep, exp, *keys[2]) if has_age else None
    print(exp)
    print(f"acc task: {acc_task[0]:.3f} +- {acc_task[1]:.3f}")
    print(f"bacc attack gender: {acc_g[0]:.3f} +- {acc_g[1]:.3f}")
    if acc_a is not None:
        print(f"bacc attack age: {acc_a[0]:.3f} +- {acc_a[1]:.3f}")
    print("\n")
    

'adverserial-diff_pruning_0.1-bert_uncased_L-4_H-256_A-4-64-2e-05-weighted_loss_prot-gender-seed{}'

In [47]:
# Name the experiment explicitly instead of relying on the loop variable `exp` leaked from
# an earlier cell. NOTE(review): name taken from the `exp` value displayed in this notebook.
get_scalar_stats_wrapper_max(
    results_prep,
    'adverserial-diff_pruning_0.1-bert_uncased_L-4_H-256_A-4-64-2e-05-weighted_loss_prot-gender-seed{}',
    'val/acc_task_debiased',
)

array([0.91883999, 0.00183696])

In [48]:
# Gender attacker on the gender-debiased embedding, read at its lowest validation loss.
# Experiment named explicitly instead of relying on `exp` leaked from an earlier cell.
# NOTE(review): name taken from the `exp` value displayed in this notebook.
get_scalar_stats_wrapper_argmin(
    results_prep,
    'adverserial-diff_pruning_0.1-bert_uncased_L-4_H-256_A-4-64-2e-05-weighted_loss_prot-gender-seed{}',
    'val/balanced_acc_adv_attack_adv_emb_gender_target_key_gender',
    'val/loss_adv_attack_adv_emb_gender_target_key_gender',
)

array([0.54591186, 0.00614799])

In [49]:
# Age attacker on the gender-debiased embedding, read at its lowest validation loss.
# Experiment named explicitly instead of relying on `exp` leaked from an earlier cell.
# NOTE(review): name taken from the `exp` value displayed in this notebook.
get_scalar_stats_wrapper_argmin(
    results_prep,
    'adverserial-diff_pruning_0.1-bert_uncased_L-4_H-256_A-4-64-2e-05-weighted_loss_prot-gender-seed{}',
    'val/balanced_acc_adv_attack_adv_emb_gender_target_key_age',
    'val/loss_adv_attack_adv_emb_gender_target_key_age',
)

array([0.42163634, 0.00561892])

In [26]:
# Seed-agnostic experiment name templates found in `folder` (seed number replaced by "{}").
experiment_names

{'only_adv_attack_task-baseline-bert_uncased_L-4_H-256_A-4-64-2e-05-seed{}-64-0.0001-weighted_loss_prot-prot_idx_1-seed{}',
 'only_adv_attack_task-diff_pruning_0.05-bert_uncased_L-4_H-256_A-4-64-2e-05-seed{}-64-0.0001-weighted_loss_prot-prot_idx_1-seed{}',
 'only_adv_attack_task-diff_pruning_0.1-bert_uncased_L-4_H-256_A-4-64-2e-05-seed{}-64-0.0001-weighted_loss_prot-prot_idx_1-seed{}'}

In [27]:
# Print the available [mean, std] stats of selected scalars for every experiment.
# NOTE(review): `scalar_stats` is not defined anywhere in this notebook; from its use it
# appears to map scalar tag -> {experiment template -> np.array([mean, std])}.
# TODO: restore the cell that builds it, otherwise this cell fails on a fresh kernel.
scalars_to_print = ['val/acc_task', 'val/balanced_acc_adv_attack', 'val/acc_task_debiased', 'val/balanced_acc_adv_attack_debiased']


def format_stats(arr):
    """Format a [mean, std] pair as 'mean +- std' with three decimals."""
    return f"{arr[0]:.3f} +- {arr[1]:.3f}"


for n in sorted(experiment_names):
    # Direct lookup instead of scanning v.items() for the key equal to n.
    res = {k: v[n] for k, v in scalar_stats.items() if k in scalars_to_print}
    # NaN stats are skipped (presumably scalars not logged for this experiment).
    res_format = {k: format_stats(v) for k, v in res.items() if not np.isnan(v).any()}
    print(n)
    print(res_format)
    print("\n")

only_adv_attack_task-baseline-bert_uncased_L-4_H-256_A-4-64-2e-05-seed{}-64-0.0001-weighted_loss_prot-prot_idx_1-seed{}
{'val/balanced_acc_adv_attack': '0.463 +- 0.018'}


only_adv_attack_task-diff_pruning_0.05-bert_uncased_L-4_H-256_A-4-64-2e-05-seed{}-64-0.0001-weighted_loss_prot-prot_idx_1-seed{}
{'val/balanced_acc_adv_attack': '0.497 +- 0.011'}


only_adv_attack_task-diff_pruning_0.1-bert_uncased_L-4_H-256_A-4-64-2e-05-seed{}-64-0.0001-weighted_loss_prot-prot_idx_1-seed{}
{'val/balanced_acc_adv_attack': '0.495 +- 0.007'}




In [None]:
# One LaTeX table row for a single experiment: cells "{$mean \pm std$}" in percent, joined by " & ".
# NOTE(review): `scalar_stats` is not defined in this notebook, and `n` lacks the "-seed{}"
# suffix used by the experiment templates elsewhere — confirm it matches the keys in `scalar_stats`.
n = 'adverserial-diff_pruning_0.05-bert_uncased_L-4_H-256_A-4-64-2e-05-weighted_loss_prot-gender'
# For task-only runs use ['val/acc_task', 'val/balanced_acc_adv_attack'] instead.
scalars_to_print = ['val/acc_task_debiased', 'val/balanced_acc_adv_attack_debiased']
# Raw f-string: "\pm" is LaTeX and must stay literal; in a normal string "\p" is an invalid
# escape sequence (SyntaxWarning on Python 3.12+). The produced text is unchanged.
format_stats = lambda arr: rf"{{${arr[0]*100:.1f} \pm {arr[1]*100:.1f}$}}"

# Look up only the printed scalars, so an unrelated scalar that lacks this experiment
# cannot make the cell fail.
res = {k: scalar_stats[k][n] for k in scalars_to_print}
s = " & ".join(format_stats(res[k]) for k in scalars_to_print)
print(s)