## Plotting

In [16]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
import sys
from omegaconf import OmegaConf
sys.path.insert(0, "..")
from floral.utils.plotting import (
    OUTPUT_DIR, PLOTS_DIR,
    load_runs,
    histories_to_df,
    setup_experiment_plotting_and_variables,
    variables_metrics_to_csv,
)

# Load the recorded runs once for the whole notebook. OUTPUT_DIR is joined onto
# ".." so the path is resolved from the parent of the working directory.
# NOTE(review): load_runs comes from floral.utils.plotting; the structure of the
# value it returns is not visible in this notebook.
HISTORIES = load_runs(output_dir=os.path.join("..", OUTPUT_DIR))

In [17]:
# ========== CHOOSE EXPERIMENT ========== #
# Names of the experiments to summarize. Each entry is used by the next cell as
# the `experiment` filter value and as the name of its results sub-directory.
# Commented-out entries are skipped.
# NOTE(review): the "XXX" tags presumably flag experiments that are not ready to
# be summarized yet — confirm before re-enabling them.
EXPERIMENTS = [
    # "run_methods_synthetic_linear",
    # "run_methods_synthetic_mlp",
    "run_methods_mnist_rotate",
    "run_methods_mnist_label_shift",
    "run_methods_cifar10_rotate",
    "run_methods_cifar10_label_shift",
    "run_methods_cifar100",
    "run_methods_mnist_rotate_reduced",
    "run_methods_mnist_label_shift_reduced",
    "run_methods_cifar10_rotate_reduced",
    "run_methods_cifar10_label_shift_reduced",
    "run_methods_cifar100_reduced",
    # "run_methods_shakespeare",    # XXX
    # "run_methods_emnist",  # XXX
    # "run_methods_stackoverflow",  # XXX
]

In [18]:
# Build one summary DataFrame per experiment, keyed by the experiment name.
summary_dfs = {}
for experiment in EXPERIMENTS:
    # Both snippets are handed to OmegaConf.create below: the filter selects
    # this experiment only, and the ignore snippet is intentionally blank.
    filter_values = f"""
    experiment: [{experiment}]
    """
    ignore_values = """
    """
    print(f"\n{experiment}...")
    history_df = histories_to_df(
        HISTORIES,
        filter_values=OmegaConf.create(filter_values),
        ignore_values=OmegaConf.create(ignore_values),
        #  downsampled_len=500,
        hide_na=True,
    )
    # The results directory is created even when no run matches the filter.
    results_dir = os.path.join("..", PLOTS_DIR, f"{experiment}")
    os.makedirs(results_dir, exist_ok=True)
    if len(history_df) > 0:
        history_df, plot_opts, variables = setup_experiment_plotting_and_variables(history_df, experiment)
        summary_dfs[experiment] = df = variables_metrics_to_csv(history_df, variables, results_dir)
        print("Ok")
    else:
        print("Failed to find valid runs")


run_methods_mnist_rotate...
Ok

run_methods_mnist_label_shift...
Ok

run_methods_cifar10_rotate...
Ok

run_methods_cifar10_label_shift...
Ok

run_methods_cifar100...
Ok

run_methods_mnist_rotate_reduced...
Ok

run_methods_mnist_label_shift_reduced...
Ok

run_methods_cifar10_rotate_reduced...
Ok

run_methods_cifar10_label_shift_reduced...
Ok

run_methods_cifar100_reduced...
Ok


In [19]:
# Separator placed between the cells of a printed table row.
COL_SEP = " & "
# Default number of decimals used by metric_to_latex for both mean and std.
DECIMALS = 1
# Columns of the summary frames that identify a row: the method name and whether
# the optimal router was used.
METHOD_COLNAME = "Method"
OPTIMAL_ROUTER_COLNAME = "Optimal $\\pi$"
# Width to which every metric cell is padded.
FIELD_SIZE = 25
# Placeholder cell: "-" followed by spaces up to FIELD_SIZE characters.
EMPTY_FIELD = f"{'-':{FIELD_SIZE}s}"

# NOTE(review): the former TODO "mark best and second best metrics" looks done —
# see the `marker` argument of metric_to_latex and get_topk_methods / get_marker
# later in this cell. Remove this note once confirmed.

# Table rows, in print order, as (method name, optimal router used) pairs.
SORTED_METHODS = [
    ("FedAvg", False),
    ("Local Adaptor", False),
    ("Ensemble", False),
    ("Ensemble", True),
    ("FLoRAL(1%)", False),
    ("FLoRAL(1%)", True),
    ("FLoRAL(10%)", False),
    ("FLoRAL(10%)", True),
]
# Length of the longest method name; used to pad the first column.
METHOD_FIELD_SIZE = max(len(m) for m, _ in SORTED_METHODS)

# Table columns, in print order. Only names listed here are ever printed; the
# two shakespeare entries are the separate top-1 / top-5 accuracy columns.
SORTED_DATASETS = [
    "mnist_rotate",
    "mnist_label_shift",
    "mnist_rotate_reduced",
    "mnist_label_shift_reduced",
    "cifar10_rotate",
    "cifar10_label_shift",
    "cifar10_rotate_reduced",
    "cifar10_label_shift_reduced",
    "cifar100",
    "cifar100_reduced",
    "shakespeare_top1",
    "shakespeare_top5",
]


def metric_to_latex(mean, std, decimals=DECIMALS, field_size=FIELD_SIZE, marker=None):
    """Render a mean/std pair as one padded LaTeX table cell.

    Args:
        mean: Value printed at normal size.
        std: Value printed inside a ``\\tiny`` group after the mean.
        decimals: Number of decimals used for both values.
        field_size: Minimum width of the returned string (left-aligned).
        marker: ``1`` wraps the cell in ``\\bf``, ``2`` wraps it in ``\\it``;
            any other value leaves the cell unwrapped.

    Returns:
        The formatted cell as a string padded to ``field_size`` characters.
    """
    mean_text = f"{mean:.{decimals}f}"
    std_text = f"{std:.{decimals}f}"
    cell = mean_text + " {\\tiny " + std_text + "}"

    # Pick the opening of the emphasis group, if any.
    if marker == 1:
        opening = "{\\bf "
    elif marker == 2:
        opening = "{\\it "
    else:
        opening = None
    if opening is not None:
        cell = opening + cell + "}"

    return format(cell, f"{field_size}s")


# Get best and second best methods
def get_topk_methods(df_means, k=2):
    """Return the ``k`` highest-scoring methods for every metric column.

    The previous implementation "removed" each winner by subtracting its value
    from itself (i.e. setting it to 0). That only hides the winner when every
    other score is positive; with non-positive scores, or with fewer than ``k``
    methods, the same method was reported again at the next rank. Winners are
    now dropped from the candidate set instead.

    Args:
        df_means: DataFrame containing the METHOD_COLNAME and
            OPTIMAL_ROUTER_COLNAME columns plus one column per metric. It is
            not modified.
        k: Number of ranks to report; must be non-negative.

    Returns:
        dict mapping rank (1 = best) to a Series indexed by metric column. Each
        value is the ``(method, optimal_router)`` tuple holding that rank, or
        ``None`` when the column has fewer than ``rank`` non-NaN entries.

    Note:
        Larger values always rank higher. NOTE(review): for columns where lower
        is better (e.g. a loss) the ranking is therefore inverted — callers
        should negate such columns before calling, or ignore their markers.
    """
    # Local import: the notebook itself never imports pandas directly.
    import pandas as pd

    assert k >= 0
    scores = df_means.set_index([METHOD_COLNAME, OPTIMAL_ROUTER_COLNAME])
    winners_by_rank = {rank: [] for rank in range(1, k + 1)}
    for col in scores.columns:
        # NaN scores can never win, matching idxmax's default skipna behaviour.
        candidates = scores[col].dropna()
        for rank in range(1, k + 1):
            if candidates.empty:
                winners_by_rank[rank].append(None)
                continue
            best = candidates.idxmax()  # first occurrence wins on ties
            winners_by_rank[rank].append(best)
            # A tuple is treated as a single MultiIndex label by drop().
            candidates = candidates.drop(index=best)
    return {
        rank: pd.Series(winners, index=scores.columns, dtype=object)
        for rank, winners in winners_by_rank.items()
    }


def get_marker(method, optimal_router, topk_methods, metric):
    """Return the best rank held by ``(method, optimal_router)`` for ``metric``.

    Args:
        method: Method name to look up.
        optimal_router: Optimal-router flag paired with the method name.
        topk_methods: Mapping from rank to a per-metric lookup (indexable by
            ``metric``) whose values are ``(method, optimal_router)`` tuples.
        metric: Name of the metric column to inspect.

    Returns:
        The smallest rank whose entry for ``metric`` equals
        ``(method, optimal_router)``, or ``None`` if there is no such rank.
        Previously the last matching rank was returned, so a method listed at
        rank 1 and again at a later rank was demoted to the later marker.
    """
    target = (method, optimal_router)
    matching_ranks = [
        rank
        for rank, best_method in topk_methods.items()
        if best_method[metric] == target
    ]
    return min(matching_ranks, default=None)


from collections import defaultdict

# results[(method, optimal_router)][table column] -> formatted LaTeX cell.
results = defaultdict(dict)
# Names that the printing loop may look up in SORTED_DATASETS.
available_datasets = []
for experiment, df in summary_dfs.items():
    dataset = experiment[len("run_methods_"):]
    # Map every table column produced by this experiment to the metric it shows.
    if dataset == "shakespeare":
        metric_by_column = {
            "shakespeare_top1": "accuracy_top1_distributed",
            "shakespeare_top5": "accuracy_top5_distributed",
        }
    elif "synthetic" in dataset:
        # NOTE(review): this metric is a loss, but get_topk_methods ranks larger
        # values first, so the best/second-best markers are unreliable here.
        metric_by_column = {dataset: "loss_distributed"}
    else:
        metric_by_column = {dataset: "acc_distributed"}
    # Register the dataset name and, for shakespeare, its two column names too.
    # Previously only "shakespeare" was recorded, so the "shakespeare_top1" /
    # "shakespeare_top5" columns were always skipped when printing the table.
    available_datasets.append(dataset)
    available_datasets.extend(column for column in metric_by_column if column != dataset)

    if OPTIMAL_ROUTER_COLNAME not in df.columns:
        df[OPTIMAL_ROUTER_COLNAME] = False
    # Aggregate all rows sharing a (method, optimal router) pair; presumably
    # these are the individual seeds — TODO confirm.
    df_groupedby_seed = df.groupby([METHOD_COLNAME, OPTIMAL_ROUTER_COLNAME])
    df_means = df_groupedby_seed.mean().reset_index()
    df_stds = df_groupedby_seed.std().reset_index()
    topk_methods = get_topk_methods(df_means, k=2)
    for method, optimal_router in SORTED_METHODS:
        mean_row = df_means[(df_means[METHOD_COLNAME] == method) & (df_means[OPTIMAL_ROUTER_COLNAME] == optimal_router)]
        std_row = df_stds[(df_stds[METHOD_COLNAME] == method) & (df_stds[OPTIMAL_ROUTER_COLNAME] == optimal_router)]
        if len(mean_row) == 0 or len(std_row) == 0:
            continue
        mean_row = mean_row.iloc[0]
        std_row = std_row.iloc[0]
        for column, metric in metric_by_column.items():
            mean, std = mean_row[metric], std_row[metric]
            marker = get_marker(method, optimal_router, topk_methods, metric)
            results[(method, optimal_router)][column] = metric_to_latex(mean, std, marker=marker)
            if dataset == "shakespeare":
                # Show "-" rather than NaN in the optimal-router row, but never
                # overwrite a real value: the old unconditional assignment
                # clobbered the cell just computed when optimal_router was True.
                results[(method, True)].setdefault(column, EMPTY_FIELD)

# Print one LaTeX table row per method: name, router mark, then one cell for
# every SORTED_DATASETS column that has data available.
for method, optimal_router in SORTED_METHODS:
    # '%' starts a comment in LaTeX, so it has to be escaped.
    method_latex = method.replace('%', '\\%')
    router_mark = "\\cmark" if optimal_router else "\\xmark"
    row = [format(method_latex, f"{METHOD_FIELD_SIZE}s"), router_mark]

    key = (method, optimal_router)
    # Look the key up without indexing, so the defaultdict gains no new entry.
    method_cells = results[key] if key in results else {}
    for dataset in SORTED_DATASETS:
        if dataset in available_datasets:
            if dataset in method_cells:
                row.append(method_cells[dataset])
            else:
                # No result for this method/dataset pair: emit a NaN cell.
                row.append(metric_to_latex(float('nan'), float('nan')))
    print(COL_SEP.join(row) + "\\\\")

FedAvg        & \xmark & 91.5 {\tiny 0.6}          & 25.8 {\tiny 2.4}          & 78.2 {\tiny 0.6}          & 23.2 {\tiny 0.9}          & 64.4 {\tiny 0.3}          & 21.9 {\tiny 0.4}          & 45.6 {\tiny 0.3}          & 18.7 {\tiny 0.4}          & 29.2 {\tiny 1.8}          & 20.7 {\tiny 1.4}         \\
Local Adaptor & \xmark & 86.6 {\tiny 0.3}          & 84.5 {\tiny 1.8}          & 47.4 {\tiny 5.4}          & 32.0 {\tiny 2.3}          & 66.3 {\tiny 0.5}          & 68.8 {\tiny 0.5}          & 33.5 {\tiny 0.5}          & 30.8 {\tiny 0.8}          & 85.1 {\tiny 0.8}          & 39.5 {\tiny 2.8}         \\
Ensemble      & \xmark & 92.0 {\tiny 0.1}          & 93.8 {\tiny 0.5}          & 66.7 {\tiny 5.3}          & 86.4 {\tiny 0.4}          & {\it 71.0 {\tiny 2.8}}    & 46.4 {\tiny 9.2}          & 42.4 {\tiny 0.9}          & 41.7 {\tiny 4.6}          & 86.2 {\tiny 0.0}          & 43.7 {\tiny 3.2}         \\
Ensemble      & \cmark & {\bf 95.8 {\tiny 0.3}}    & {\bf 95.6 {\tiny 0.3}}    & {\bf