## Imports

In [None]:
import json
import os
import re
import sys
from os.path import join
from pprint import pprint

import numpy as np
import pandas as pd
sys.path.append('../bias_probing')
import bias_probing.config as project_config

import pickle
import matplotlib.pyplot as plt
import matplotlib as mpl
from IPython.display import display, HTML

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Notebook Setup

In [None]:
import matplotlib.font_manager as fm
# Collect all the font names available to matplotlib
font_names = [f.name for f in fm.fontManager.ttflist]
print([f for f in font_names if 'serif' in f.lower()])

In [None]:
# Notebook settings, do not modify
mpl.rcParams['figure.dpi'] = 300
# mpl.rcParams['font.family'] = 'Microsoft Sans Serif'
plt.rcParams['font.size'] = 30
plt.rcParams['axes.linewidth'] = 2
plt.rcParams['figure.figsize'] = (4.7747, 3.5)
plt.style.use('ggplot')
mpl.use('pgf')
mpl.rcParams.update({
    'pgf.texsystem': "pdflatex",
    'font.family': 'serif',
    'text.usetex': True,
    'pgf.rcfonts': False,
    'pgf.preamble': [
        '\DeclareUnicodeCharacter{2212}{-}'
    ]
})

pd.set_option('display.max_colwidth', 40)

%matplotlib inline
# %matplotlib widget
%load_ext autoreload
%autoreload 2

figures_dir = project_config.FIGURES_DIR

## Setup

In [None]:
# Group by all models of the same type
def results_for_task_all_seeds(results, task_config, probe_type='linear'):
    """Aggregate probing metrics over the seeds of each model.

    Rows of ``results`` are filtered to the given ``task_config`` file and
    ``probe_type``; the ``_seed_<...>`` suffix is stripped from
    ``model_name_or_path`` and the remaining name is used as the grouping key.

    Args:
        results: DataFrame with the columns ``task_config_file``,
            ``probe_type``, ``model_name_or_path``, ``compression``,
            ``online_cdl`` and ``eval_accuracy``.
        task_config: Value of ``task_config_file`` to keep.
        probe_type: Value of ``probe_type`` to keep.

    Returns:
        DataFrame indexed by the seed-less model name with ``std``, ``mean``
        and ``count`` of each metric (rounded to 3 decimals), sorted by the
        mean ``online_cdl`` in descending order.
    """
    selected = results[
        (results['task_config_file'] == task_config)
        & (results['probe_type'] == probe_type)
    ]
    # regex=True is required explicitly: since pandas 2.0 ``str.replace``
    # treats the pattern as a literal string by default, which would leave the
    # seed suffix in place and put every seed in its own group.
    model_family = selected['model_name_or_path'].str.replace('_seed_.*', '', regex=True)
    return (
        selected[['compression', 'online_cdl', 'eval_accuracy']]
        .groupby(model_family)
        .agg(['std', 'mean', 'count'])
        .round(3)
        .sort_values([('online_cdl', 'mean')], ascending=False)
    )

#         .drop('name.1', axis=1, errors='ignore')\
#         .drop('output_dir', axis=1, errors='ignore')\
#         .drop('type', axis=1, errors='ignore')\
#         .drop('seed', axis=1, errors='ignore')\
#         .drop('embedding_size', axis=1, errors='ignore')\

def load_experiment_results(output_dir):
    """Return the object unpickled from ``<output_dir>/results.pkl``.

    NOTE: unpickling executes arbitrary code; only use this on result
    directories produced by trusted runs.
    """
    results_path = os.path.join(output_dir, 'results.pkl')
    with open(results_path, 'rb') as handle:
        return pickle.load(handle)


def plot_mean_with_std(results, name_pretty_dict=None, replace='_[0-9]+', title=None):
    """Plot the mean and std of the online-coding eval loss per model group.

    For every row of ``results`` the pickled experiment in ``output_dir`` is
    loaded and the ``eval_loss`` of each entry of
    ``exp.report['online_coding_list']`` is collected. Runs are grouped by
    their ``name`` with the ``replace`` regex removed, and one error-bar curve
    (mean +/- std across the group) is drawn per group.

    Args:
        results: DataFrame with ``output_dir`` and ``name`` columns.
        name_pretty_dict: Optional mapping from group name to legend label.
            When given, only groups present in the mapping are drawn; when
            ``None`` every group is drawn under its own name.
        replace: Regex removed from the run name to obtain the group name.
        title: If given, the figure is saved as ``<title>.png`` (spaces
            replaced by underscores) in ``figures_dir``.

    Raises:
        ValueError: If ``results`` contains no rows.
    """
    fractions = None
    loss_rows = []
    for out_dir, name in zip(results.output_dir, results.name):
        exp = load_experiment_results(out_dir)
        if fractions is None:
            # Tick labels are taken from the first experiment only.
            fractions = exp.fractions
        losses = [step['eval_loss'] for step in exp.report['online_coding_list']]
        loss_rows.append(pd.Series(losses, name=name))
    if not loss_rows:
        raise ValueError('results contains no experiments to plot')

    # One row per run, one column per online-coding step. Built in one go:
    # DataFrame.append no longer exists in pandas >= 2.0.
    df = pd.DataFrame(loss_rows)
    df.index.name = 'name'

    # regex=True must be explicit; it is no longer the default in pandas >= 2.0.
    group_names = df.index.str.replace(replace, '', regex=True)
    df_mean = df.groupby(group_names).agg(['mean', 'std'])
    means = df_mean.xs('mean', axis='columns', level=1)
    stds = df_mean.xs('std', axis='columns', level=1)
    x = list(means.columns)

    if name_pretty_dict is None:
        labels = {name: name for name in df_mean.index}
    else:
        labels = {name: name_pretty_dict[name]
                  for name in df_mean.index if name in name_pretty_dict}

    fig, ax = plt.subplots()
    for name, label in labels.items():
        ax.errorbar(x, means.loc[name], stds.loc[name],
                    linestyle='-', marker='^', label=label)

    ax.set_xticks(x)
    # The last fraction is skipped for the tick labels, as in the original code.
    ax.set_xticklabels([str(i) for i in fractions[:-1]])
    ax.set_xlabel('% of training data')
    ax.set_ylabel('Test error')
    # Build the legend once, then style its text. Creating a second legend
    # afterwards (as was done before) discards the font setting.
    legend = ax.legend(borderpad=1, loc='lower center', bbox_to_anchor=(0.5, -0.3),
                       ncol=max(len(labels), 1))
    plt.setp(legend.get_texts(), family='sans-serif')
    if title is not None:
        fig.savefig(join(figures_dir, title.replace(' ', '_') + '.png'),
                    bbox_inches='tight', dpi=300)
    plt.show()
    

def show_results(df: pd.DataFrame, tasks):
    """Display the per-seed summary table and the loss plot of each task.

    For every task config file name in ``tasks`` the name is printed, the
    linear-probe summary table is displayed and the mean/std curves of the
    matching rows are plotted (legend labels come from the module-level
    ``name_pretty_dict``).

    Returns:
        An ``HTML`` object wrapping a ``<style>`` element with an empty
        ``.output`` rule.
    """
    for task in tasks:
        print(f'{task}')
        summary = results_for_task_all_seeds(df, task, probe_type='linear')
        display(summary)
        task_rows = df[df['task_config_file'] == task]
        plot_mean_with_std(task_rows, name_pretty_dict)

    CSS = """
    .output {
    }
    """
    style_tag = '<style>{}</style>'.format(CSS)
    return HTML(style_tag)


## Experiment Results

In [None]:
# Probing task config files grouped by bias family. Both lists are only
# consumed by (currently commented-out) ``show_results`` calls.
# presumably the "neg_words" tasks probe negation-word bias -- TODO confirm
tasks_hypo = [
    "mnli_neg_words_config.json",
    "mnli_neg_words_hypo_only_config.json",
    "snli_neg_words_config.json",
    "snli_neg_words_hypo_only_config.json",
    "fever_neg_words_config.json",
    "fever_neg_words_hypo_only_config.json",
]

# presumably the "hclass" tasks probe HANS-style heuristics -- TODO confirm
tasks_hans = [
    "mnli_hclass_lex_config.json",
    "mnli_hclass_sub_config.json",
    "snli_hclass_lex_config.json",
    "snli_hclass_sub_config.json",
]

# Maps a seed-less model/run name to the label shown in plot legends
# (read by ``show_results`` and passed on to ``plot_mean_with_std``).
name_pretty_dict = {
    'baseline': 'Base',
    'dfl': 'DFL',
    'pretrained': 'Pretrained',
    'random': 'Random',
    'bert_confreg': 'ConfReg',
    'bert_implicit': 'Implicit',
    'bert_tiny': 'TinyBERT'
}

In [None]:
# results.groupby('task_config_file').count().mean(axis=1).rename('Number of experiments').to_frame()

In [None]:
# results_hans = pd.read_csv(join(project_config.TEMP_DIR, 'results_hans.csv'))
# len(results_hans)

In [None]:
# show_results(results_hans, tasks_hans)

In [None]:
# results_hypo = pd.read_csv(join(project_config.TEMP_DIR, 'results_hypo.csv'))
# len(results_hypo)

In [None]:
# show_results(results_hypo, tasks_hypo)

In [None]:
# results_hypo_bigrams = pd.read_csv(join(project_config.TEMP_DIR, 'results_hypo_bigrams.csv'))
# display(results_for_task_all_seeds(results_hypo_bigrams, 'mnli_neg_bigrams_hypo_only_config.json', probe_type='linear'))
# plot_mean_with_std(results_hypo_bigrams, name_pretty_dict=name_pretty_dict)

### Analysis

In [None]:
# results_hypo_dfl_gamma = pd.read_csv(join(project_config.TEMP_DIR, 'results_hypo_dfl_gamma.csv'))
# results_hypo_dfl_gamma[:5]

In [None]:
# tasks = tasks_mnli
# for task_config in tasks:
#     print(f'{task_config}')
#     plot_dfl_gamma_experiment_results(task_config, results_hypo_dfl_gamma)

In [None]:
# results_hans_dfl_gamma = pd.read_csv(join(project_config.TEMP_DIR, 'results_hans_dfl_gamma_new.csv'))
# len(results_hans_dfl_gamma)

In [None]:
# tasks = tasks_hans
# for task_config in tasks:
#     print(f'{task_config}')
#     plot_dfl_gamma_experiment_results(task_config, results_hans_dfl_gamma)

## LMI Analysis

In [None]:
# from experiments.calculate_lmi import *
# task = jiant_create_task_from_config_path(os.path.join(project_config.DATA_DIR, 'datasets/configs/fever_nli_config.json'))
# dataset = task.get_train_examples()

In [None]:
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# def tokenize(sent):
#     return tokenizer.convert_ids_to_tokens(tokenizer(sent, add_special_tokens=False)['input_ids'])

In [None]:
# dataset_size = len(dataset)
# print(f'Dataset size: {dataset_size}')

In [None]:
# def print_top_refutes(ngram_size, top_n=30):
#     ngram_lmi_dict = get_ngram_label_lmi(dataset, n=ngram_size)
#     display(highest_lmi_words('REFUTES', ngram_lmi_dict, n=30))
#     return ngram_lmi_dict

In [None]:
# unigrams = print_top_refutes(1)

In [None]:
# bigrams = print_top_refutes(2)

In [None]:
# trigrams = print_top_refutes(3)

## W&B Results

### 11/02/2021 (9:16)

In [None]:
# results_hans = pd.read_csv(join(project_config.TEMP_DIR, 'wandb_export_2021-02-11T09_16_04.782+02_00.csv'))
# len(results_hans)

In [None]:
# plot_hyperparam_experiment_results({
#     'mnli_hclass_lex_config.json': 'LexClass-MNLI [Lex. Overlap]',
#     'mnli_hclass_sub_config.json': 'LexClass-MNLI [Subsequence]'
# }, results_hans, metric='online_cdl', title='MDL vs. DFL focusing parameter $\gamma$')
# plot_hyperparam_experiment_results({
#     'snli_hclass_lex_config.json': 'LexClass-SNLI [Lex. Overlap]',
#     'snli_hclass_sub_config.json': 'LexClass-SNLI [Subsequence]'
# }, results_hans, metric='online_cdl', title='MDL vs. DFL focusing parameter $\gamma$')

In [None]:
# plot_hyperparam_experiment_results({
#     'mnli_hclass_lex_config.json': 'LexClass-MNLI [Lex. Overlap]',
#     'mnli_hclass_sub_config.json': 'LexClass-MNLI [Subsequence]'
# }, results_hans, 
#     metric='online_cdl', title=r'MDL vs. PoE weighing parameter $\alpha$',
#     key='alpha', x_label=r'$\alpha$')
# plot_hyperparam_experiment_results({
#     'snli_hclass_lex_config.json': 'LexClass-SNLI [Lex. Overlap]',
#     'snli_hclass_sub_config.json': 'LexClass-SNLI [Subsequence]'
# }, results_hans, 
#     metric='online_cdl', title=r'MDL vs. PoE weighing parameter $\alpha$',
#     key='alpha', x_label=r'$\alpha$')

In [None]:
# show_results(results_hans, tasks_hans)

In [None]:
# def fever_neg_words_results(task_name, **kwargs):
#     task_config_file = f'{task_name}.json'
#     header_mapping = {
# #         'online_cdl': '$L_{\mathrm{online}}$',
#         'compression': '$\mathcal{C}$',
#         'eval_accuracy_probing': 'Acc.',
# #         'eval_fever_symmetric.accuracy_0': 'Symmetric (Supports)',
# #         'eval_fever_symmetric.accuracy_1': 'Symmetric (Refutes)',
# #         'eval_fever_symmetric.accuracy': 'Symmetric',
#         'eval_fever_symmetric_hard.accuracy': 'Hard',
#         'eval_fever_symmetric_hard.accuracy_0': 'Hard (Supports)',
#         'eval_fever_symmetric_hard.accuracy_1': 'Hard (Refutes)',
#     }
#     return get_results_table(task_config_file, header_mapping, index_mapping_fever, 
#                              drop=['tiny'], **kwargs)

# def multi_nli_lex_class_results(task_keys: TaskKeys, **kwargs):
#     task_name = task_keys.task_name
#     task_config_file = f'{task_name}.json'
#     header_mapping = {
# #         'online_cdl': '$L_{\mathrm{online}}$',
#         'compression': '\multicolumn{1}{c}{$\mathcal{C}$}',
#         'eval_accuracy_probing': '\multicolumn{1}{c}{Acc.}',
#         task_keys.anti_bias_key: task_keys.anti_bias_label
#     }
#     results = get_results_table(task_config_file, header_mapping, index_mapping, 
#                              drop=['tiny', 'hans_poe_32b', 'hans_dfl_32b', 
#                                    'hans_confreg', 'hypo_confreg', 'confreg', 'hypo_only'], **kwargs)
#     return results


# def multi_nli_neg_words_results(task_keys: TaskKeys, **kwargs):
#     task_name = task_keys.task_name
#     task_config_file = f'{task_name}.json'
#     header_mapping = {
#         'compression': '\multicolumn{1}{c}{$\mathcal{C}$}',
#         'eval_accuracy_probing': 'Acc.',
#         task_keys.anti_bias_key: task_keys.anti_bias_label
#     }
#     results = get_results_table(task_config_file, header_mapping, index_mapping, 
#                              drop=['tiny', 'hans_poe_32b', 'hans_dfl_32b', 
#                                    'hans_confreg', 'hypo_confreg', 'confreg', 'hypo_only'], **kwargs)
#     return results

# def snli_lex_class_results(task_keys, **kwargs):
#     task_name = task_keys.task_name
#     task_config_file = f'{task_name}.json'
#     header_mapping = {
# #         'online_cdl': '$L_{\mathrm{online}}$',
#         'compression': '$\mathcal{C}$',
#         'eval_accuracy_probing': 'Acc.',
#         task_keys.anti_bias_key: task_keys.anti_bias_label
#     }
#     return get_results_table(task_config_file, header_mapping, index_mapping, 
#                              drop=['tiny', 'hans_confreg', 'confreg'], **kwargs)


# def snli_neg_words_results(task_name, **kwargs):
#     task_config_file = f'{task_name}.json'
#     header_mapping = {
#         'compression': '$\mathcal{C}$',
#         'eval_accuracy_probing': 'Acc.',
#         'eval_snli_hard.accuracy': 'Hard'
#     }
#     return get_results_table(task_config_file, header_mapping, index_mapping, 
#                              drop=['tiny', 'hans_confreg', 'confreg'], **kwargs)

## From JSON

In [None]:
def drop_columns_starting_with(df, s):
    """Return ``df`` without the columns whose name starts with ``s``."""
    has_prefix = df.columns.str.startswith(s)
    doomed = df.columns[has_prefix]
    return df.drop(columns=doomed)


def parse_json_str(s):
    """Parse a dict stored as text, in Python-repr or JSON form.

    Args:
        s: ``None``, the string ``'None'``, a Python literal such as
            ``"{'a': 1, 'b': None}"`` or a JSON document.

    Returns:
        An empty dict for ``None`` / ``'None'``, otherwise the parsed value.

    Raises:
        json.JSONDecodeError: If the text is neither a Python literal nor
            (after the legacy quote substitution) valid JSON.
    """
    if s is None or s == 'None':
        return dict()

    import ast  # local import: only this helper needs it

    # Try a real Python-literal parse first. Unlike the textual substitution
    # below it copes with True/False, with apostrophes inside values and with
    # values that merely contain the text "None".
    try:
        return ast.literal_eval(s)
    except (ValueError, SyntaxError, TypeError):
        pass

    # Legacy path, kept for genuine JSON input (null/true/false/NaN).
    return json.loads(s.replace("'", '"').replace('None', 'null'))


def get_row_json_keys(df, name):
    """Return the dict keys found in the first non-null cell of column ``name``.

    Only the first valid cell is inspected; keys that appear solely in later
    rows are not reported.

    Returns:
        A list of keys, or an empty list when the column has no non-null cell.
    """
    valid_cells = df[name].dropna()
    if valid_cells.empty:
        return []
    # Position 0 of the null-free series is the first valid cell whatever the
    # index labels are. The previous ``df.iloc[first_valid_index()]`` mixed a
    # label with positional indexing and was only right for a 0..n-1 index.
    return list(parse_json_str(valid_cells.iloc[0]).keys())


def normalize_columns(col):
    """Scale a column by 100 when ``'accuracy' in col.name`` holds.

    For a string name this is a substring test. For a tuple name (MultiIndex
    column) it is a membership test against the tuple's elements, so such
    columns are only scaled when one element equals ``'accuracy'`` exactly.
    """
    return col * 100 if 'accuracy' in col.name else col


def online_code_results():
    """Load and tidy the online-code (probing) results table.

    Reads ``results_online_code.csv`` from ``project_config.RESULTS_DIR``,
    removes bookkeeping columns, renames ``name.1`` -> ``name`` and
    ``model_name_or_path`` -> ``model_name``, strips the ``seed:<n>/`` part
    from the model name, converts ``eval_accuracy`` to percent and drops
    ``online_cdl``.

    Returns:
        DataFrame rounded to 2 decimals, with ``eval_accuracy`` rounded to 1.
    """
    unused_columns = [
        'Unnamed: 0', 'cache_only', 'config_dir', 'logging_dir', 'learning_rate',
        'task_type', 'batch_size',
        'mdl_fractions', 'early_stopping', 'max_seq_length', 'hypothesis_only',
        'new_split_ratio', 'overwrite_cache', 'checkpoint_steps', 'min_dataset_size',
        'num_train_epochs', 'train_batch_size', 'task_mapper_kwargs', 'wandb_project_name',
        'early_stopping_tolerance', 'task_heuristic', 'task_n', 'task_joint', 'task_negative_vocab',
        '_step', 'fraction', 'embedding_size', '_timestamp', '_runtime', 'name',
    ]
    new_names = {'name': 'run_name', 'name.1': 'name', 'model_name_or_path': 'model_name'}

    csv_path = join(project_config.RESULTS_DIR, 'results_online_code.csv')
    df_oc = pd.read_csv(csv_path)
    # 'name' is dropped here, so only the other two renames can take effect.
    df_oc = df_oc.drop(unused_columns, axis='columns', errors='ignore')
    df_oc = df_oc.rename(new_names, axis='columns')

    seed_part = re.compile('seed:[0-9]+/')
    df_oc['model_name'] = df_oc['model_name'].map(lambda path: seed_part.sub('', path))
    df_oc['eval_accuracy'] = df_oc['eval_accuracy'] * 100
    df_oc = df_oc.drop('online_cdl', axis='columns')
    # df_oc['online_cdl'] = df_oc['online_cdl'] / 1000
    df_oc = df_oc.round(2)
    df_oc['eval_accuracy'] = df_oc['eval_accuracy'].round(1)
    return df_oc


def fine_tuning_results():
    """Load the debiasing (fine-tuning) results and flatten the eval metrics.

    Reads ``results_debiasing.csv`` from ``project_config.RESULTS_DIR``, drops
    gradient and configuration columns, expands every dict-valued ``eval/...``
    and ``eval_...`` column into one ``eval_<dataset>.<key>`` column per key,
    and keeps only ``tag`` (renamed to ``model_name``), ``seed`` and the
    remaining ``eval_`` columns. Columns whose name contains ``accuracy`` are
    multiplied by 100 by ``normalize_columns``.

    Returns:
        The tidied DataFrame.
    """
    df = pd.read_csv(join(project_config.RESULTS_DIR, 'results_debiasing.csv'))
    df = drop_columns_starting_with(df, 'gradients/')
    # Configuration columns that are not needed for the analysis.
    df = df.drop(['Unnamed: 0', 'fp16', 'debug', 'id2label', 
                  'label2id', 'top_k', 'top_p', 'adafactor', 'deepspeed',
                 'do_sample', 'num_beams', 'report_to', 'use_cache', 'adam_beta1', 'adam_beta2',
                 'do_predict', 'hidden_act', 'is_decoder', 'local_rank', 'max_length', 'min_length',
                 'model_type', 'past_index', 'save_steps', 'vocab_size', 'hidden_size', 'label_names',
                 'logging_dir', 'return_dict', 'sharded_ddp', 'temperature', 'torchscript', 'bos_token_id',
                 'disable_tqdm', 'eos_token_id', 'fp16_backend', 'pad_token_id', 'pruned_heads', 'sep_token_id',
                 'use_bfloat16', 'bad_words_ids', 'mp_parameters', 'output_scores', 'save_strategy',
                 'tpu_num_cores', 'early_stopping', 'fp16_full_eval', 'fp16_opt_level', 'layer_norm_eps',
                 'length_penalty', 'finetuning_task', 'group_by_length', 'num_beam_groups', 'overwrite_cache',
                 'tokenizer_class', 'type_vocab_size', 'eval_datasets'], axis='columns', errors='ignore')
    
    # Remove columns that are empty for every run.
    df = df.drop(df.columns[df.isnull().all()], axis='columns')

    # Pass 1: expand dict-valued 'eval/<dataset>' columns into
    # 'eval_<dataset>.<key>' columns, one per key of the first valid row.
    columns = df.columns[df.columns.str.startswith('eval/')]
    for col in columns:
        if df[col].dtype == object:
            keys = get_row_json_keys(df, col)
            # For a name without any 'eval_' substring old_key equals col.
            old_key = col.replace("eval_", "eval/")
            new_key = col.replace("eval/", "eval_")
            for key in keys:
                # Extracts `key` from one cell; NaN for non-string cells or
                # cells whose dict lacks the key. Called within this iteration,
                # so it sees the current `key`.
                def dict_mapper(d):
                    if not isinstance(d, str):
                        return np.nan
                    res = parse_json_str(d)
                    if not key in res:
                        return np.nan
                    return res[key]
                if old_key in df.columns:
                    # Fill gaps of `col` with the value parsed from `old_key`.
                    df[f'{new_key}.{key}'] = \
                    df[col]\
                    .map(dict_mapper)\
                    .combine_first(df[old_key].map(dict_mapper))
                else:
                    df[f'{new_key}.{key}'] = \
                    df[col]\
                    .map(dict_mapper)
            # The dict-valued source column is no longer needed.
            df = df.drop(col, axis='columns')
    
    # Legacy fix: dict-valued columns already named 'eval_<dataset>' are
    # expanded the same way; columns created in pass 1 are kept and only
    # have their missing values filled in.
    columns = df.columns[df.columns.str.startswith('eval_')]
    for col in columns:
        if df[col].dtype == object:
            keys = get_row_json_keys(df, col)
            for key in keys:
                def dict_mapper(d):
                    if not isinstance(d, str):
                        return np.nan
                    res = parse_json_str(d)
                    if not key in res:
                        return np.nan
                    return res[key]
                new_key = f'{col}.{key}'
                if new_key in df.columns:
                    df[f'{col}.{key}'] = df[f'{col}.{key}'].combine_first(df[col].map(dict_mapper))
                else:
                    df[f'{col}.{key}'] = df[col].map(dict_mapper)
            df = df.drop(col, axis='columns')
    
    # Expansion may have produced columns that are empty for every run.
    df = df.drop(df.columns[df.isnull().all()], axis='columns')

    # Keep the run identifiers and the eval metrics; drop trainer settings
    # that share the 'eval_' prefix, the hard_mismatched / eval_report columns
    # and the f1 / recall / precision metrics, then scale accuracies to percent.
    df = df[['tag', 'seed'] + list(df.columns[df.columns.str.startswith('eval_')])]\
        .drop(['eval_steps', 'eval_batch_size', 'eval_accumulation_steps', 'eval_datasets', 'eval_report'], axis='columns', errors='ignore')\
        .drop(df.columns[df.columns.str.contains('hard_mismatched|eval_report')], axis='columns', errors='ignore')\
        .drop(df.columns[df.columns.str.contains('f1|recall|precision')], axis='columns', errors='ignore')\
        .rename({'tag': 'model_name'}, axis='columns')\
        .apply(normalize_columns)

    return df


def merge_results(online_code_results, fine_tuning_results):
    """Outer-join the fine-tuning and probing tables on model name and seed.

    The fine-tuning frame is the left side of the join. Columns present in
    both frames get the suffix ``_downstream`` (fine-tuning) or ``_probing``
    (online code).
    """
    join_keys = ['model_name', 'seed']
    merged = fine_tuning_results.merge(
        online_code_results,
        how='outer',
        on=join_keys,
        suffixes=('_downstream', '_probing'),
    )
    return merged
    
    
def results_for_task(df, task_config_file: str, agg=['mean', 'std']):
    """Aggregate the numeric metrics of one probing task per model name.

    Args:
        df: Merged results with ``task_config_file``, ``name`` and a
            ``compression`` column, plus any number of further metrics.
        task_config_file: The task whose rows are returned.
        agg: Aggregations passed to ``groupby().agg``. The default list is
            never mutated here. The result is sorted by
            ``('compression', 'mean')``, so ``'mean'`` must be included.

    Returns:
        DataFrame indexed by ``(task_config_file, name)`` with
        ``(metric, aggregation)`` columns, sorted by mean compression and
        rounded to 2 decimals. Columns that are entirely NaN or entirely zero
        are removed and remaining NaNs are shown as ``'-'``.

    Raises:
        KeyError: If ``task_config_file`` does not occur in ``df``.
    """
    group_keys = ['task_config_file', 'name']
    # Aggregate numeric columns only. pandas >= 2.0 raises on string columns
    # (e.g. the model name) instead of silently dropping them as older
    # versions did. Non-numeric columns, including booleans, are left out.
    metric_columns = [c for c in df.select_dtypes(include='number').columns
                      if c not in group_keys]
    temp = (
        df[group_keys + metric_columns]
        .groupby(group_keys)
        .agg(agg)
        .loc[[task_config_file]]
        .sort_values(('compression', 'mean'))
        .dropna(how='all', axis='columns')
        .round(2)
    )
    temp = temp.fillna('-')
    # Drop columns that are zero for every model of this task.
    all_zero = (temp == 0).all(axis=0)
    temp = temp.drop([col for col in all_zero.index if all_zero[col]], axis='columns')
    return temp


def to_latex(results):
    """Print ``results`` as a LaTeX table.

    Cell contents are emitted unescaped. The column format is two
    left-aligned columns followed by one right-aligned column per data
    column, and multi-column headers are centred.
    """
    value_column_count = len(results.columns)
    latex_source = results.to_latex(
        escape=False,
        index=True,
        column_format='ll' + 'r' * value_column_count,
        multicolumn_format='c',
    )
    print(latex_source)

### Finetuning

In [None]:
df_ft = fine_tuning_results()
df_ft

### Online Code

In [None]:
df_oc = online_code_results()
df_oc

In [None]:
df_merged = merge_results(df_oc, df_ft)
df_merged

### Camera Ready

In [None]:
from dataclasses import dataclass

# Short labels for annotating individual points in the delta scatter plots
# (read by the annotation branch of ``delta_graph``).
annotation_dict = {
    'implicit_poe': 'Impl-PoE',
    'implicit_dfl': 'Impl-DFL',
    'implicit_confreg': 'ConfReg',
    'hans_poe': 'HANS-PoE',
    'hans_dfl': 'HANS-DFL',
    'implicit_poe_e2e': 'Impl-PoE-E2E',
    'implicit_dfl_e2e': 'Impl-DFL-E2E',
    'implicit_dfl_2k': 'Impl-DFL-Subset',
    'hypo_poe': 'Hypo-PoE',
    'hypo_dfl': 'Hypo-DFL'
    # 'implicit_poe_2k': 'Impl-PoE-Subset',
}

# Maps a run name to the two-level table index (Bias, Model) that
# ``get_results_table`` assigns to its rows. Every run name that survives the
# ``drop`` filter must have an entry here, otherwise the lookup raises KeyError.
index_mapping = {
    'random': ('', 'Random'),
    'pretrained': ('', 'Pretrained'),
    'baseline': ('', 'Base'),
    'implicit_poe': ('Impl. [TinyBERT]', 'PoE'),
    'implicit_dfl': ('Impl. [TinyBERT]', 'DFL'),
    'implicit_poe_e2e': ('Impl. [TinyBERT]', 'PoE [E2E]'),
    'implicit_dfl_e2e': ('Impl. [TinyBERT]', 'DFL [E2E]'),
    'tiny': ('', 'TinyBERT'),
    'hans_poe': ('HANS', 'PoE [E2E]'),
    'hans_dfl': ('HANS', 'DFL [E2E]'),
    'hypo_poe': ('Hypothesis', 'PoE [E2E]'),
    'hypo_dfl': ('Hypothesis', 'DFL [E2E]'),
    'implicit_confreg': ('Impl. [Subset]', 'ConfReg'),
    'implicit_dfl_2k': ('Impl. [Subset]', 'DFL'),
    'implicit_poe_2k': ('Impl. [Subset]', 'PoE'),
}

# Same mapping, with the 'Hypothesis' bias group relabelled as 'Claim'.
# NOTE(review): only referenced from commented-out code in this notebook.
index_mapping_fever = {
    **index_mapping,
    'hypo_poe': ('Claim', 'PoE [E2E]'),
    'hypo_dfl': ('Claim', 'DFL [E2E]'),
}

@dataclass
class TaskKeys:
    """Identifies one probing task and the robustness metric paired with it."""
    # Task name; '<task_name>.json' is the value matched against the
    # 'task_config_file' column.
    task_name: str
    # Column holding the anti-bias (robustness) evaluation metric.
    anti_bias_key: str
    # Header shown for that metric in the result tables.
    anti_bias_label: str
    # NOTE(review): not read anywhere in the visible code.
    scatter_title: str = None

In [None]:
# Camera Ready Tables
INDEX_NAME = 'Model'

def combine_mean_std(values):
    r"""Format the string cells of one metric as a single LaTeX table cell.

    Args:
        values: Series of strings (mean first, then std) in which ``'-'``
            marks a missing value.

    Returns:
        ``'-'`` when every entry is missing, ``'$<first>$'`` when some are
        missing (the first entry is emitted as is), otherwise
        ``'$<mean> \pm <std>$'`` with two decimals each.
    """
    # Compare for equality. A substring test for '-' would also flag negative
    # numbers (and exponents such as '1e-05') as missing.
    missing = values == '-'
    if missing.all():
        return '-'
    if missing.any():
        # .iloc: the series is indexed by column label, not by position.
        return f'${values.iloc[0]}$'

    formatted = values.apply(lambda x: '{0:.2f}'.format(float(x)))
    return '$' + r' \pm '.join(formatted) + '$'



def get_results_table(task_config_file, header_mapping, index_mapping, drop=[], rename=True, std=True):
    """Build the per-model results table of one probing task.

    Args:
        task_config_file: Task to select from the module-level ``df_merged``.
        header_mapping: Maps metric column names to table headers; its keys
            also select (and order) the metrics that are kept.
        index_mapping: Maps a run name to its ``(Bias, Model)`` index tuple.
        drop: Run names to remove from the table (never mutated here).
        rename: When true, apply ``header_mapping`` to the columns and
            ``index_mapping`` to the index, then sort by the new index.
        std: When true, merge each metric's mean/std pair into one formatted
            cell; otherwise keep the flat ``<metric>::<stat>`` columns.

    Returns:
        DataFrame indexed by run name (``rename=False``) or by
        ``(Bias, Model)`` (``rename=True``).

    Raises:
        ValueError: If ``header_mapping`` is empty.
    """
    if len(header_mapping) == 0:
        raise ValueError('header_mapping must contain at least one key to map')
    # A concrete list: a dict_keys view is not a dependable column selector.
    keys = list(header_mapping)

    # NOTE(review): the column labels are still (metric, stat) tuples here, so
    # the "'accuracy' in name" test of normalize_columns compares against whole
    # tuple elements and leaves the values unscaled -- confirm this is intended.
    results = results_for_task(df_merged, task_config_file)[keys].apply(normalize_columns)
    results.columns = results.columns.map('::'.join)

    if std:
        # One formatted column per metric, built from its '<metric>::<stat>'
        # columns. Replaces the deprecated groupby(..., axis=1).
        combined = {}
        for key in keys:
            stat_columns = [c for c in results.columns if c.split('::')[0] == key]
            combined[key] = results[stat_columns].astype(str).apply(combine_mean_std, axis=1)
        results = pd.DataFrame(combined, index=results.index)[keys]

    results = results.drop([(task_config_file, x) for x in drop], errors='ignore')

    if rename:
        results = results.rename(header_mapping, axis='columns')

    # Drop the task level of the index; rows are now keyed by run name.
    results = results.loc[task_config_file]
    if rename:
        results = results.rename({'name': INDEX_NAME}, axis='columns')
        results.index = results.index.map(lambda x: index_mapping[x])
        results.index.names = ['Bias', INDEX_NAME]
        results = results.sort_index()
    return results


def task_results(task_keys: TaskKeys, **kwargs):
    """Return the camera-ready results table of the task in ``task_keys``.

    The table holds the compression, the probing accuracy and the task's
    anti-bias metric. A fixed set of run names is always excluded. Extra
    keyword arguments (``rename``, ``std``) go to ``get_results_table``.
    """
    task_name = task_keys.task_name
    task_config_file = f'{task_name}.json'
    # Raw strings: the LaTeX backslashes must not be read as Python escape
    # sequences (invalid escapes such as '\m' are deprecated).
    header_mapping = {
#         'online_cdl': '$L_{\mathrm{online}}$',
        'compression': r'\multicolumn{1}{c}{$\mathcal{C}$}',
        'eval_accuracy_probing': r'\multicolumn{1}{c}{Acc.}',
        task_keys.anti_bias_key: task_keys.anti_bias_label
    }
    results = get_results_table(task_config_file, header_mapping, index_mapping,
                                drop=['tiny', 'hans_poe_32b', 'hans_dfl_32b',
                                      'hans_confreg', 'hypo_confreg', 'confreg', 'hypo_only'], **kwargs)
    return results


# Polynomial Regression
def polyfit(x, y, degree):
    """Fit a polynomial of the given degree and report its fit quality.

    Returns:
        Dict with ``'polynomial'`` (coefficients, highest power first) and
        ``'determination'`` (regression sum of squares divided by the total
        sum of squares).
    """
    coeffs = np.polyfit(x, y, degree)
    model = np.poly1d(coeffs)

    fitted = model(x)
    mean_y = np.sum(y) / len(y)
    explained = np.sum((fitted - mean_y) ** 2)
    total = np.sum((y - mean_y) ** 2)

    return {
        'polynomial': coeffs.tolist(),
        'determination': explained / total,
    }


def delta_graph(task_keys: TaskKeys, figure_name, return_points=False, annotate=False):
    """Scatter each model's change in compression against its change in the
    anti-bias metric, both measured relative to the ``baseline`` run.

    Args:
        task_keys: Task and anti-bias metric to plot.
        figure_name: Base file name (no extension) of the saved figure; it is
            not used when ``return_points`` is true.
        return_points: When true, nothing is plotted and the data is returned.
        annotate: When true, label the points listed in the module-level
            ``annotation_dict``.

    Returns:
        ``((x, y), (baseline_x, baseline_y))`` when ``return_points`` is true,
        where ``x`` / ``y`` are the deltas and the second pair holds the
        baseline's absolute values. Otherwise ``None``; the figure is written
        to ``project_config.FIGURES_DIR`` as ``.pgf`` (best effort) and ``.png``.
    """
    x_key = task_keys.anti_bias_key + '::mean'
    y_key = 'compression::mean'

    temp = task_results(task_keys, rename=False, std=False)
    # Missing cells arrive as the string '-'; coerce them to NaN so that every
    # column is numeric. (Replaces the deprecated DataFrame.applymap.)
    temp = temp.apply(pd.to_numeric, errors='coerce')
    baseline_results = temp.loc['baseline']
    # Express every model as a difference to the baseline row.
    temp = (temp - baseline_results).drop(['baseline'], axis='index')

    temp = temp.reset_index()
    if temp.columns.nlevels > 1:
        temp = temp.droplevel(1, axis='columns')

    temp = temp[[c for c in temp.columns if '::std' not in c]]
    temp = temp.dropna()
    x = temp[x_key].dropna()
    y = temp[y_key].dropna()
    correlation = x.corr(y)
    if return_points:
        return (x, y), (baseline_results[x_key], baseline_results[y_key])

    print(f'Pearson = {correlation}\n')

    ax = temp.plot.scatter(x=x_key, y=y_key, c='forestgreen', s=80, marker='^', figsize=(4.7747, 3.5))
    # The baseline sits at the origin of the delta plot.
    ax.scatter(0, 0, c='tab:red', s=80, marker='*')
    ax.text(0.1, 0.9, r'$\rho = {:.3f}$'.format(correlation),
            horizontalalignment='center',
            verticalalignment='center',
            transform=ax.transAxes,
            backgroundcolor='white')

    # Raw strings keep the LaTeX backslashes from being read as escapes.
    ax.set_xlabel(r'$\Delta\ \mathrm{Robustness}$')
    ax.set_ylabel(r'$\Delta\ \mathrm{Bias\ Extractability}$')

    pf_results = polyfit(x, y, 1)
    print('Linear Regression:\n')
    pprint(pf_results)

    if annotate and annotation_dict is not None:
        named = temp.set_index('name')
        for key, annotation in annotation_dict.items():
            if key not in named.index:
                continue
            row = named.loc[key]
            ax.annotate(annotation,
                        xy=(row[x_key], row[y_key]),
                        xycoords='data',
                        xytext=(1.0, 0.0),
                        textcoords='offset points')

    # LaTeX export is best effort (it needs a working TeX installation), but
    # only ordinary errors are swallowed and the reason is reported.
    try:
        ax.figure.savefig(join(project_config.FIGURES_DIR, f'{figure_name}.pgf'), dpi=300, bbox_inches='tight')
    except Exception as exc:
        print(f'Latex Error: {exc}')
    # PNG
    ax.figure.savefig(join(project_config.FIGURES_DIR, f'{figure_name}.png'), dpi=300, bbox_inches='tight')

    
def delta_points(task_keys: TaskKeys):
    """Return the delta points and baseline values of a task without plotting.

    Thin wrapper around ``delta_graph`` with ``return_points=True``.
    """
    points_and_baseline = delta_graph(task_keys, None, return_points=True)
    return points_and_baseline

#### FEVER

In [None]:
# FEVER task keys; the anti-bias metric column is
# 'eval_fever_symmetric_hard.accuracy'.
fever_neg_words = TaskKeys(
    task_name='fever_neg_words',
    anti_bias_key='eval_fever_symmetric_hard.accuracy',
    anti_bias_label='Symmetric'
)

# Show the table, then print its LaTeX source.
results = task_results(fever_neg_words)
display(results)
to_latex(results)

#### MNLI

In [None]:
# MNLI lexical-overlap task keys; the anti-bias metric column is
# 'eval_hans_lexical_overlap.accuracy_1'.
mnli_lex_class = TaskKeys(
    task_name='mnli_lex_class',
    anti_bias_key='eval_hans_lexical_overlap.accuracy_1',
    # Raw string: '\m' is not a valid Python escape sequence.
    anti_bias_label=r'$\mathrm{HANS}^-$'
)

results_mnli_lex_class = task_results(mnli_lex_class)

In [None]:
# MNLI subsequence task keys; the anti-bias metric column is
# 'eval_hans_subsequence.accuracy_1'.
mnli_lex_class_sub = TaskKeys(
    task_name='mnli_lex_class_sub',
    anti_bias_key='eval_hans_subsequence.accuracy_1',
    # Raw string: '\m' is not a valid Python escape sequence.
    anti_bias_label=r'$\mathrm{HANS}^-$'
)

results_mnli_lex_class_sub = task_results(mnli_lex_class_sub)

In [None]:
results = pd.concat([results_mnli_lex_class, results_mnli_lex_class_sub], axis=1)
display(results)
to_latex(results)

In [None]:
# MNLI negation-words task keys; the anti-bias metric column is
# 'eval_multi_nli_hard_matched.accuracy'.
mnli_neg_words = TaskKeys(
    task_name='mnli_neg_words',
    anti_bias_key='eval_multi_nli_hard_matched.accuracy',
    anti_bias_label='Hard'
)

# Show the table, then print its LaTeX source.
results = task_results(mnli_neg_words)
display(results)
to_latex(results)

#### SNLI

In [None]:
# SNLI lexical-overlap task keys; the anti-bias metric column is
# 'eval_hans_lexical_overlap.accuracy_1'.
snli_lex_class = TaskKeys(
    task_name='snli_lex_class',
    anti_bias_key='eval_hans_lexical_overlap.accuracy_1',
#     anti_bias_key='eval_hans.accuracy_1',
    # Raw string: '\m' is not a valid Python escape sequence.
    anti_bias_label=r'$\mathrm{HANS}^-$'
)

results_snli_lex_class = task_results(snli_lex_class)

In [None]:
# SNLI subsequence task keys; the anti-bias metric column is
# 'eval_hans_subsequence.accuracy_1'.
snli_lex_class_sub = TaskKeys(
    task_name='snli_lex_class_sub',
    anti_bias_key='eval_hans_subsequence.accuracy_1',
#     anti_bias_key='eval_hans.accuracy_1',
    # Raw string: '\m' is not a valid Python escape sequence.
    anti_bias_label=r'$\mathrm{HANS}^-$'
)

results_snli_lex_class_sub = task_results(snli_lex_class_sub)

In [None]:
results = pd.concat([results_snli_lex_class, results_snli_lex_class_sub], axis=1)
display(results)
to_latex(results)

In [None]:
# SNLI negation-words task keys; the anti-bias metric column is
# 'eval_snli_hard.accuracy'.
snli_neg_words = TaskKeys(
    task_name='snli_neg_words',
    anti_bias_key='eval_snli_hard.accuracy',
    anti_bias_label='Hard'
)

# Show the table, then print its LaTeX source.
results = task_results(snli_neg_words)
display(results)
to_latex(results)

### Delta Graphs

In [None]:
delta_graph(mnli_lex_class, 'scatter_mnli_lex_class')

In [None]:
delta_graph(mnli_lex_class_sub, 'scatter_mnli_lex_class_sub')

In [None]:
delta_graph(mnli_neg_words, 'scatter_mnli_neg_words')

In [None]:
delta_graph(fever_neg_words, 'scatter_fever_neg_words')

In [None]:
delta_graph(snli_lex_class, 'scatter_snli_lex_class')

In [None]:
delta_graph(snli_lex_class_sub, 'scatter_snli_lex_class_sub')

In [None]:
delta_graph(snli_neg_words, 'scatter_snli_neg_words')

In [None]:
# Pool the delta points of all tasks into one scatter plot. Each task's deltas
# are divided by that task's baseline value, i.e. plotted as relative changes.
frames = []
for task in [
    mnli_lex_class,
    mnli_lex_class_sub,
    snli_lex_class,
    snli_lex_class_sub,
    fever_neg_words,
    mnli_neg_words,
    snli_neg_words
]:
    (x, y), (bx, by) = delta_points(task)
    nx = (x).div(bx).rename('x')
    ny = (y).div(by).rename('y')

    frames.append(pd.concat([nx, ny], axis=1))

df = pd.concat(frames, axis=0).reset_index()
# display(df)
ax = df.plot.scatter(x='x', y='y', c='royalblue', s=120, marker='.')
# The baselines sit at the origin.
ax.scatter(0, 0, c='tab:red', marker='*', s=120)
# Raw strings: the LaTeX backslashes are not valid Python escape sequences.
# NOTE(review): the axes are labelled [%], but the plotted values are plain
# ratios (delta / baseline) that are not multiplied by 100 -- confirm.
ax.set_xlabel(r'$\Delta\ \mathrm{Robustness}\ [\%]$')
ax.set_ylabel(r'$\Delta\ \mathrm{Extractability}\ [\%]$')
figure_name = 'scatter_full'
print(f'Pearson = {df["x"].corr(df["y"])}')
# LaTeX
plt.savefig(join(project_config.FIGURES_DIR, f'{figure_name}.pgf'), dpi=300, bbox_inches='tight')
# PNG
plt.savefig(join(project_config.FIGURES_DIR, f'{figure_name}.png'), dpi=300, bbox_inches='tight')

## Retraining

In [None]:
def retraining_results(dataset_name, prefix, title=None):
    """Bar-plot HANS accuracy of the original models next to their retrained
    counterparts and save the figure.

    Retrained runs are the fine-tuning results whose ``model_name`` starts
    with ``'<prefix>/'`` and contains ``dataset_name``. The matching original
    runs are those whose ``model_name`` equals the retrained name without the
    prefix. Accuracies are averaged over all rows sharing a model name.

    Args:
        dataset_name: Text (regex pattern) a model name must contain; it is
            also stripped from the bar labels.
        prefix: Directory-like prefix marking the retrained runs.
        title: Optional axes title.

    Returns:
        ``None``. The figure is written to
        ``<FIGURES_DIR>/retrained_<dataset_name>_<prefix>.png``.
    """
    hans_columns = ['eval_hans.accuracy_0', 'eval_hans.accuracy_1']
    df_ft = fine_tuning_results()

    # na=False: rows without a model name simply do not match.
    is_retrained = (
        df_ft.model_name.str.startswith(f'{prefix}/', na=False)
        & df_ft.model_name.str.contains(dataset_name, na=False)
    )
    # Select the HANS columns BEFORE averaging: pandas >= 2.0 raises when
    # mean() meets a non-numeric column instead of dropping it silently.
    retrained_df = df_ft[is_retrained]\
        .groupby('model_name')[hans_columns]\
        .mean().rename({
            'eval_hans.accuracy_0': 'hans-ent-retrained',
            'eval_hans.accuracy_1': 'hans-non-ent-retrained'
        }, axis='columns')

    retrained_df.index = retrained_df.index.map(lambda s: s.replace(f'{prefix}/', ''))
    original_models = list(retrained_df.index.values)
    original_df = df_ft[df_ft.model_name.isin(original_models)]\
        .groupby('model_name')[hans_columns]\
        .mean().rename({
            'eval_hans.accuracy_0': 'hans-ent',
            'eval_hans.accuracy_1': 'hans-non-ent'
        }, axis='columns')

    df = pd.concat([original_df, retrained_df], axis=1)
    # Shorten the model names for the tick labels.
    df.index = df.index.map(lambda s: s.replace(f'{dataset_name}_', '').replace('bert_', '').replace('/', '-').replace('_', '-')).rename('name')
    df = df.drop(df.index[df.index.str.contains('hypo')])
    ax = df.plot.bar()
    if title is not None:
        ax.set_title(title)
    ax.set_xticklabels(labels=ax.get_xticklabels(), rotation=45, ha='right')
    ax.legend(bbox_to_anchor=(1.03, 1), loc='upper left')

    # Reference line at the retrained baseline's non-entailment accuracy;
    # skipped (rather than crashing) when no baseline run is present.
    baseline_rows = df[df.index.str.contains('baseline')]
    if not baseline_rows.empty:
        ax.axhline(y=baseline_rows.iloc[0]['hans-non-ent-retrained'])

    figure_name = f'retrained_{dataset_name}_{prefix}'
    return ax.figure.savefig(join(project_config.FIGURES_DIR, f'{figure_name}.png'), dpi=300, bbox_inches='tight')


retraining_results('multi_nli', 'retrain2', 'MNLI')

In [None]:
retraining_results('multi_nli', 'retrain3', 'MNLI')

## Experiment: Varying the Debiasing Effect

### Setup

In [None]:
# def plot_hyperparam_experiment_results(task_dict, results: pd.DataFrame, logy=False,
#                                        probe_type='linear',
#                                        key='dfl_gamma',
#                                        title='',
#                                        x_label='$\gamma$',
#                                        y_label='MDL',
#                                        metric='eval_accuracy'):
#     fig0, ax0 = plt.subplots()
#     # ax1 = ax0.twinx()
    
#     for task_name, display_name in task_dict.items():
#         df = results_for_task_all_seeds(results, task_name, probe_type=probe_type).sort_index()
#         df = df[df.index.str.contains(key)]
#         df.index = df.index.map(lambda s: float(s.split('_')[-1]))

#         width = 0.2
#         df[(metric, 'mean')].plot(kind='line', marker='v',
#                                   logy=logy, stacked=True, ax=ax0, label=f'{display_name}', rot=0,
#                                   yerr=df[(metric, 'std')])
#     #     df[('eval_accuracy', 'mean')].plot(kind='line', color='tab:orange', marker='^', \
#     #                                        logy=logy, secondary_y=True, ax=ax1, label='Acc.', \
#     #                                        yerr=df[('eval_accuracy', 'std')])
    
#     plt.title(title)
#     ax0.set_xlabel(x_label)
#     ax0.set_ylabel(y_label)
#     # ax1.set_xlabel('$\gamma$')
#     plt.legend()
#     ax0.legend()
#     # ax0.grid(False)
#     # ax1.grid(False)
#     base_dir = join(figures_dir, key)
#     os.makedirs(base_dir, exist_ok=True)
#     filename = re.sub(r'_|(config.json)', ' ', '-'.join(task_dict.keys())) + f'-{key}-{metric}'
#     plt.savefig(join(base_dir, filename.replace(' ', '_') + '.png'), dpi=300)
#     plt.show()
#     plt.close()

In [None]:
# Legend label of each probing task shown in the gamma sweep.
name_to_title = {
    'mnli-lex-class': r'Overlap', 
    'mnli-lex-class-sub': r'Subsequence'
}

plt.figure(figsize=(4.7747, 3.5))
for task_name in name_to_title.keys():
    is_gamma_run = (
        df_merged.model_name.str.contains('dfl_gamma', na=False)
        & (df_merged.task_name == task_name)
    )
    # .copy(): the columns added below must not be written into a view of
    # df_merged (avoids SettingWithCopyWarning).
    gamma_runs = df_merged.loc[is_gamma_run, ['model_name', 'seed', 'compression', 'task_name']].copy()
    # The model name ends in '_<gamma>'; split that value off.
    gamma_runs['gamma'] = gamma_runs['model_name'].str.split('_').str[-1].astype(float)
    gamma_runs['model_name'] = gamma_runs['model_name'].str.split('_').apply(lambda x: '_'.join(x[:-1]))
    # Aggregate numeric columns only; the string column 'task_name' cannot be
    # averaged and makes pandas >= 2.0 raise.
    gamma_stats = gamma_runs.groupby(['model_name', 'gamma'])[['seed', 'compression']].agg(['mean', 'std'])
    display(gamma_stats)

    x = gamma_stats.index.get_level_values('gamma')
    y = gamma_stats[('compression', 'mean')].to_numpy()
    yerr = gamma_stats[('compression', 'std')].to_numpy()

    # plt.plot(x, y)
    plt.errorbar(x=x, y=y, yerr=yerr, label=name_to_title[task_name])
    # Raw string: '\g' is not a valid Python escape sequence.
    plt.xlabel(r'$\gamma$')
    plt.ylabel('Compression')
    plt.legend()

figure_name = 'exp_dfl_gamma'
# plt.show()
plt.savefig(join(project_config.FIGURES_DIR, f'{figure_name}.pgf'), dpi=300, bbox_inches='tight')
plt.savefig(join(project_config.FIGURES_DIR, f'{figure_name}.png'), dpi=300, bbox_inches='tight')