In [None]:
# Move into the BOOM benchmark directory so the relative paths used below
# (`utils` imports, `results/`, `leaderboards/`, `*_properties.json`) resolve.
# Guarded so that re-running this cell from an already-correct working
# directory does not climb two more levels and break every later path.
import os

if os.path.basename(os.path.normpath(os.getcwd())) != 'boom':
    os.chdir('../../boom')
print(os.getcwd())

In [6]:
import os
import numpy as np
import pandas as pd
import json

from utils.leaderboard import (
    NON_ZERO_METRICS,
    ZERO_METRICS,
    load_model_results,
    separate_zero_inflated_data,
    process_benchmark_model_results,
    get_separate_zero_inflated_leaderboard,
    shifted_gmean,
)

from utils.breakdown import (
    METRIC_NAMES,
    add_agg_columns,
    get_breakdown_table,
    expand_complex_column,
)

# Benchmark variant switch: False -> full BOOM, True -> BOOMlet.
# It is forwarded to `load_model_results`, and it also selects the properties
# JSON (boom_ vs boomlet_) and the leaderboard CSV filename in the cells below.
BOOMLET_BENCHMARK = False


Full benchmark leaderboard

In [None]:
# Load each model's results for the selected benchmark variant, then split the
# series into non-zero-inflated and zero-inflated groups; the two groups are
# processed with different metric sets below.
dfs, dfs_names = load_model_results('results/', BOOMLET_BENCHMARK)
non_zero_dfs, zero_dfs = separate_zero_inflated_data(dfs)

# Non-zero-inflated group: NON_ZERO_METRICS with is_scale_by_naive=True
# (presumably normalisation against a naive baseline — confirm in utils.leaderboard).
non_zero_dfs = process_benchmark_model_results(
    dfs=non_zero_dfs,
    metrics=NON_ZERO_METRICS,
    is_scale_by_naive=True,
)

# Zero-inflated group: ZERO_METRICS, without naive scaling.
zero_dfs = process_benchmark_model_results(
    dfs=zero_dfs,
    metrics=ZERO_METRICS,
    is_scale_by_naive=False,
)

os.makedirs('leaderboards/', exist_ok=True)

# Combine both groups into a single leaderboard, aggregating with the
# shifted geometric mean.
leaderboard = get_separate_zero_inflated_leaderboard(
    non_zero_dfs=non_zero_dfs,
    zero_dfs=zero_dfs,
    dfs_names=dfs_names,
    non_zero_metrics=NON_ZERO_METRICS,
    zero_metrics=ZERO_METRICS,
    agg_func=shifted_gmean,
)

# Saved as leaderboards/boom_leaderboard.csv or leaderboards/boomlet_leaderboard.csv.
leaderboard_suffix = "let_" if BOOMLET_BENCHMARK else "_"
leaderboard.to_csv(f'leaderboards/boom{leaderboard_suffix}leaderboard.csv')


Breakdown tables

In [None]:
# Breakdown tables for the non-zero-inflated results: one CSV per
# (grouping column, metric), aggregated with the shifted geometric mean.
agg_method = 'shifted_gmean'
# NOTE(review): directory name is misspelled ("tabels"); kept unchanged because
# other tooling may already read from this path.
out_dir = 'breakdown_tabels'
agg_columns = ['real_term', 'type', 'domain']

# Properties file for the selected benchmark variant. Opened in a `with` block
# so the handle is closed, and decoded as UTF-8 (the JSON standard encoding)
# instead of the platform's locale default.
properties_path = 'boomlet_properties.json' if BOOMLET_BENCHMARK else 'boom_properties.json'
with open(properties_path, "r", encoding="utf-8") as properties_file:
    boom_properties = json.load(properties_file)
dfs = add_agg_columns(non_zero_dfs, agg_columns, boom_properties)

os.makedirs(out_dir, exist_ok=True)

def save_breakdown(dfs_input, agg_column):
    """Write one breakdown CSV per metric for a single grouping column.

    Args:
        dfs_input: Model results carrying ``agg_column`` (the output of
            ``add_agg_columns`` or ``expand_complex_column`` in this notebook).
        agg_column: Name of the column to break results down by.

    Reads the notebook globals ``dfs_names``, ``agg_method`` and ``out_dir``.
    Each table from ``get_breakdown_table`` is rounded to 3 decimals and saved
    as ``{out_dir}/{agg_column}_{metric}_{agg_method}.csv``. ``metric`` is the
    display name from ``METRIC_NAMES``; otherwise it is the metric key with
    ``/`` replaced by ``_`` so it cannot create a sub-path.
    """
    tables = get_breakdown_table(dfs_input, dfs_names, agg_column, NON_ZERO_METRICS, agg_method)
    for metric_key, table in tables.items():
        # Progress log: one line per table written.
        print(agg_column, metric_key, agg_method)
        metric_name = METRIC_NAMES.get(metric_key, metric_key.replace('/', '_'))
        table.round(3).to_csv(
            f"{out_dir}/{agg_column}_{metric_name}_{agg_method}.csv"
        )

# These grouping columns are used as-is.
for col in ['full_benchmark', 'real_term', 'type']:
    save_breakdown(dfs, col)

# 'domain' goes through expand_complex_column first — presumably because one
# series can map to several domains (TODO confirm in utils.breakdown).
for col in ['domain']:
    save_breakdown(expand_complex_column(dfs, col), col)
