In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Imports and process-level configuration for this evaluation notebook.
import os
# These environment variables are set before `jax` is imported below so that they are in
# place when JAX initialises its backends.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Disable XLA's up-front preallocation of GPU memory.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
import glob
from copy import deepcopy

import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import jax
# Make the CPU the default JAX platform; the GPU settings above only matter if this line
# is changed. NOTE(review): confirm CPU-only evaluation is intended.
jax.config.update("jax_platform_name", "cpu")
# Enable 64-bit dtypes (JAX defaults to 32-bit).
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import equinox as eqx 

In [None]:
from rhmag.utils.final_data_evaluation import (
    FINAL_MATERIALS, TestSet, ResultSet, predict_test_scenarios, validate_result_set, visualize_result_set
)
from rhmag.utils.model_evaluation import reconstruct_model_from_file, get_exp_ids

In [None]:
import matplotlib as mpl
from matplotlib import rc

# Publication styling: LaTeX text rendering, 10pt base font, and extra LaTeX packages for
# bold math (bm), amsmath and upright Greek letters (upgreek).
# NOTE(review): Helvetica is a sans-serif face but is registered here as the serif family;
# with text.usetex=True the font is chosen on the LaTeX side, so this entry may not take
# effect — confirm it is intended.
mpl.rcParams.update(
    {
        "font.family": "serif",
        "font.serif": ["Helvetica"],
        "text.usetex": True,
        "font.size": 10,
        "text.latex.preamble": r"\usepackage{bm}\usepackage{amsmath}\usepackage{upgreek}",
    }
)

---

## Gather data:

In [None]:
# Show the materials that all following cells iterate over.
FINAL_MATERIALS

In [None]:
# Collect, per material, the de-duplicated ids of every model (all seeds) trained in the
# experiment `exp_name`, and print them for inspection.
exp_ids_all_seeds = {}
exp_name = "pareto-front-f32"
for material_name in FINAL_MATERIALS:
    print("MATERIAL:", material_name)
    # De-duplicate first, then sort: sorting before `set()` is lost again because set
    # iteration order is arbitrary (and varies between interpreter runs for strings).
    mat_ids_unique = sorted(
        set(get_exp_ids(material_name=material_name, model_type=None, exp_name=exp_name))
    )

    for element in mat_ids_unique:
        print(f"    '{element}'")
    print()

    exp_ids_all_seeds[material_name] = mat_ids_unique

In [None]:
# Restrict the comparison to GRU-based models (ids containing "GRU").
exp_ids_all_seeds = {
    material: [exp_id for exp_id in material_ids if "GRU" in exp_id]
    for material, material_ids in exp_ids_all_seeds.items()
}

In [None]:
# Report how many models remain per material after filtering.
for material, material_ids in exp_ids_all_seeds.items():
    n_models = len(material_ids)
    print(f"Material '{material}': {n_models} models found.")

In [None]:
# Load each material's test set once; the evaluation loop below reuses these objects.
test_data = {}
for material in FINAL_MATERIALS:
    test_data[material] = TestSet.from_material_name(material)

In [None]:
from rhmag.utils.pretest_evaluation import create_multilevel_df
from rhmag.utils.final_data_evaluation import evaluate_test_scenarios

In [None]:
# Evaluate every model (all seeds) on its material's test set and collect one row per model:
# metadata parsed from the experiment id plus the metrics averaged over all test sequences.
METRIC_NAMES = ["sre_avg", "sre_95th", "nere_avg", "nere_95th"]

all_results = []
for material_name, exp_ids in exp_ids_all_seeds.items():
    test_set = test_data[material_name]
    for exp_id in tqdm.tqdm(exp_ids, desc=f"material {material_name}"):
        model = reconstruct_model_from_file(exp_id)
        # NOTE(review): field positions below (model type at 1, experiment name at 2, numeric id
        # at -2, seed after "seed") are taken from the original parsing — confirm against
        # the id format produced by get_exp_ids.
        id_parts = exp_id.split("_")
        metrics_per_sequence = evaluate_test_scenarios(model, test_set)
        if not metrics_per_sequence:
            raise ValueError(f"No test sequences were evaluated for '{exp_id}'.")
        n_sequences = len(metrics_per_sequence)
        averages = {
            m: sum(d[m] for d in metrics_per_sequence.values()) / n_sequences
            for m in METRIC_NAMES
        }
        all_results.append(
            {
                "exp_id_without_seed": exp_id.rpartition("_")[0],
                "exp_id": exp_id,
                # Parsed per id; no longer overwrites the notebook-level `exp_name`.
                "exp_name": id_parts[2],
                "num_id": id_parts[-2],
                "material": material_name,
                "model_type": id_parts[1],
                "seed": exp_id.split("seed")[-1],
                "n_params": model.n_params,
                # Same column order as METRIC_NAMES: sre_avg, sre_95th, nere_avg, nere_95th.
                **averages,
            }
        )
df_results = pd.DataFrame(all_results)

In [None]:
# Show every evaluated model ordered by model size. A stable sort keeps rows with equal
# `n_params` (different seeds/materials) in evaluation order, so the table is reproducible;
# `np.argsort`'s default quicksort gives no guarantee for ties.
sorted_df_results = df_results.sort_values("n_params", kind="stable")
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    display(sorted_df_results)

In [None]:
# Same size-ordered table, restricted to material "A". The stable sort makes the order of
# rows with equal `n_params` reproducible.
sorted_df_results = df_results.sort_values("n_params", kind="stable")
A_sorted_df_results = sorted_df_results.loc[sorted_df_results["material"] == "A"]
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    display(A_sorted_df_results)

In [None]:
# df_best_seeds = df_results.loc[df_results.groupby('exp_id_without_seed')['sre_95th'].idxmin()]

# df_best_seeds = df_best_seeds.reset_index(drop=True)
# df_model_comparison_best = df_best_seeds.groupby("model_type").mean(numeric_only=True).reset_index()

# df_results_spec=df_results[(df_results["n_params"] < 500) & (df_results["material"] != "A")]
# df_model_comparison_spec = df_results_spec.groupby("model_type").mean(numeric_only=True).reset_index()

In [None]:
# Average every numeric column (n_params and metrics) over all rows that share a model type,
# then turn `model_type` back into a regular column.
results_by_type = df_results.groupby("model_type")
df_model_comparison = results_by_type.mean(numeric_only=True).reset_index()

In [None]:
# Per-model-type averages over all materials and seeds.
df_model_comparison

### Other Teams

In [None]:
from rhmag.utils.provided_final_results import EXTERNAL_TEAMS_AVG, EXTERNAL_TEAMS_PER_MATERIAL

In [None]:
# Results reported by the other teams, averaged over materials. Presumably they use the same
# column names as df_model_comparison (model_type, n_params, *_95th), since the frames are
# concatenated and plotted together below — verify.
df_external = pd.DataFrame(data=EXTERNAL_TEAMS_AVG)

In [None]:
# Stack our per-type averages and the external teams' rows into one table with a fresh index.
df_combined = pd.concat(objs=[df_model_comparison, df_external], ignore_index=True)

In [None]:
# Our per-type averages followed by the external teams' rows.
df_combined

## Plot:

In [None]:
import seaborn as sns

In [None]:
def visualize_pareto_final(
    df,
    metrics,
    color_own="blue",
    color_others="gray",
    scale_log_metric=True,
    scale_log_size=True,
    highlighted_type="GRU",
    sharex="col",
    sharey="row",
    xlim=None,
    line_plot=False,
    show_median=False,
):
    """Plot model size against the 95th-percentile error, one panel per metric.

    Rows whose ``model_type`` contains ``highlighted_type`` ("own" rows) are drawn in
    ``color_own``. All other rows are drawn as squares in ``color_others`` and labelled with
    their ``model_type``.

    Args:
        df: Table with columns ``model_type``, ``n_params`` and ``f"{metric}_95th"`` for every
            entry of ``metrics``. It is copied, not modified.
        metrics: Metric prefixes to plot (e.g. ``["sre", "nere"]``), one subplot each.
        color_own: Color of the highlighted rows.
        color_others: Color of all other rows.
        scale_log_metric: Use a log x axis (metric).
        scale_log_size: Use a log y axis (``n_params``).
        highlighted_type: Substring of ``model_type`` that marks a row as "own".
        sharex, sharey: Forwarded to ``plt.subplots``.
        xlim: Optional sequence of ``(left, right)`` x-limits, one per metric panel. Panels
            without an entry keep automatic limits.
        line_plot: If True, own rows are joined by a line ordered by ``n_params`` instead of
            being scattered.
        show_median: In scatter mode, fade own points and overlay a line through the
            per-``model_type`` medians.

    Returns:
        ``(fig, axs)``, where ``axs`` is a ``(1, len(metrics))`` array of Axes.
    """
    df = df.copy()
    df["IsOwn"] = df["model_type"].apply(lambda x: highlighted_type in str(x))

    fig, axs = plt.subplots(
        nrows=1, ncols=len(metrics), sharex=sharex, sharey=sharey, figsize=(7.167, 7.167 / 2), squeeze=False
    )

    # Loop-invariant axis data: y ticks sit exactly at the model sizes present in `df`.
    unique_params = sorted(df["n_params"].unique().astype(int))
    label_map = {"sre": "95-th percentile SRE", "nere": "95-th percentile NERE"}

    def annotate_rows(ax, group_df, target_col, color):
        # Write each row's model_type slightly to the right of its marker.
        for _, row in group_df.iterrows():
            ax.text(
                row[target_col] * 1.05,
                row["n_params"],
                str(row["model_type"]),
                fontsize=10,
                alpha=1,
                va="center",
                color=color,
                fontweight="normal",
            )

    for i, metric in enumerate(metrics):
        ax = axs[0, i]
        target_col = f"{metric}_95th"

        for is_own, group_df in df.groupby("IsOwn"):
            current_color = color_own if is_own else color_others

            if is_own and line_plot:
                # Join all own runs in order of model size.
                sns.lineplot(
                    data=group_df.sort_values("n_params").reset_index(),
                    x=target_col,
                    y="n_params",
                    color=current_color,
                    alpha=0.8,
                    ax=ax,
                    sort=False,
                )
            else:
                sns.scatterplot(
                    data=group_df,
                    x=target_col,
                    y="n_params",
                    color=current_color,
                    marker="o" if is_own else "s",
                    s=20,
                    # Fade the individual own points when the median line is drawn on top.
                    alpha=0.3 if (is_own and show_median) else 0.8,
                    ax=ax,
                    zorder=3,
                )
                if is_own and show_median:
                    med_df = (
                        group_df.groupby("model_type")
                        .median(numeric_only=True)
                        .sort_values("n_params", kind="stable")
                    )
                    ax.plot(med_df[target_col], med_df["n_params"], c=current_color, alpha=0.8)

            if not is_own:
                annotate_rows(ax, group_df, target_col, current_color)

        if scale_log_metric:
            ax.set_xscale("log")
        if scale_log_size:
            ax.set_yscale("log")

        ax.set_yticks(unique_params)
        ax.yaxis.set_major_formatter(plt.ScalarFormatter())
        ax.xaxis.set_major_formatter(plt.ScalarFormatter())
        # Explicit labels at the fixed y ticks (these supersede the y formatter set above).
        ax.set_yticklabels(unique_params, rotation=0, fontsize=9)
        ax.tick_params(which="major", axis="y", direction="in")
        ax.tick_params(which="both", axis="x", direction="in")
        ax.yaxis.minorticks_off()
        ax.xaxis.minorticks_off()
        ax.xaxis.set_minor_formatter(plt.ScalarFormatter())

        # One optional (left, right) pair per panel; panels without an entry autoscale.
        if xlim is not None and i < len(xlim):
            ax.set_xlim(xlim[i][0], xlim[i][1])

        if i == 0:
            # Raw string: "\#" escapes the hash for LaTeX (text.usetex) and is not a Python escape.
            ax.set_ylabel(r"\# Model params.", fontsize=9)
        ax.set_xlabel(label_map.get(metric, metric), fontsize=9)
        ax.grid(True, which="both", ls="--", alpha=0.3)

    fig.tight_layout()
    return fig, axs

In [None]:
# Compare our per-model-type averages against the external teams. `highlighted_type` must
# match the model types kept by the "GRU" filter applied after gathering the experiment ids.
# Otherwise none of our rows are highlighted, no median line is drawn, and our rows are
# labelled like external entries.
fig, axs = visualize_pareto_final(
    df_combined,
    ["sre", "nere"],
    color_own="orange",
    color_others="black",
    scale_log_metric=True,
    scale_log_size=True,
    highlighted_type="GRU",
    sharex="col",
    sharey="row",
    xlim=None,  # e.g. ((0.08, 0.75), (0.013, 0.22)) to fix the x ranges
    line_plot=False,
    show_median=True,
)
fig.savefig("median.png", bbox_inches="tight", dpi=250)

### Plot per material

In [None]:
# Split our evaluation results per material. Groups are keyed by their own "material" label
# rather than zipped with FINAL_MATERIALS by position (groupby yields groups in sorted-key
# order, which need not match FINAL_MATERIALS).
results_by_material = {name: group for name, group in df_results.groupby("material")}
missing_materials = set(FINAL_MATERIALS) - set(results_by_material)
assert not missing_materials, f"no evaluation results for materials: {sorted(missing_materials)}"
dfs_per_material = {material_name: results_by_material[material_name] for material_name in FINAL_MATERIALS}

In [None]:
# External teams' per-material results, grouped by their own "material" label. Zipping the
# groups with FINAL_MATERIALS by position would silently attach a team's numbers to the wrong
# material whenever the sorted group order or the set of materials differs.
external_df = pd.DataFrame(EXTERNAL_TEAMS_PER_MATERIAL)
external_df_per_material = {name: group for name, group in external_df.groupby("material")}

In [None]:
# dfs_per_material = [x for _, x in df_results.groupby("material")]
# dfs_per_material = {material_name: df for material_name, df in zip(["A", "B", "E"], dfs_per_material)}

In [None]:
# Hand-tuned x-axis limits per material: ((SRE left, right), (NERE left, right)).
xlim_per_material = {
    "A": ((0.1, 0.80), (0.007, 0.22)),
    "B": ((0.05, 0.4), (0.008, 0.22)),
    "C": ((0.11, 1.5), (0.03, 0.35)),
    "D": ((0.05, 0.65), (0.015, 0.36)),
    "E": ((0.08, 0.45), (0.005, 0.12)),
}

for material_name, df in dfs_per_material.items():
    # Plot every individual run (all seeds) of our models next to the external teams.
    # Per-model-type averaging is intentionally not applied here (see "each_trial" file name).
    fig, axs = visualize_pareto_final(
        pd.concat([df, external_df_per_material[material_name]], ignore_index=True),
        ["sre", "nere"],
        color_own="orange",
        color_others="black",
        scale_log_metric=True,
        scale_log_size=True,
        highlighted_type="GRU",
        sharex="col",
        sharey="row",
        xlim=xlim_per_material[material_name],
        line_plot=False,
        show_median=True,
    )
    fig.suptitle(f"Pareto investigation for material '{material_name}'")
    fig.tight_layout()
    fig.savefig(f"pareto_investigation_material_each_trial_{material_name}.png", dpi=250, facecolor='white', transparent=False)

## Mean ± standard deviation plots

In [None]:
def visualize_mean_var_pareto(
    df,
    metrics,
    color_own="blue",
    color_others="gray",
    scale_log_metric=True,
    scale_log_size=True,
    highlighted_type="GRU",
    sharex="col",
    sharey="row",
    xlim=None,
    line_plot=False,
):
    """Plot mean ± one standard deviation of the 95th-percentile error per own model type.

    Own rows (``highlighted_type`` contained in ``model_type``) are aggregated per
    ``model_type``. Each type gets a line through its mean error (against its mean
    ``n_params``) and a shaded ±1 sample-standard-deviation band. All other rows are drawn as
    squares in ``color_others`` and labelled with their ``model_type``.

    Args:
        df: Table with columns ``model_type``, ``n_params`` and ``f"{metric}_95th"`` for every
            entry of ``metrics``. It is copied, not modified.
        metrics: Metric prefixes to plot (e.g. ``["sre", "nere"]``), one subplot each.
        color_own: Color of the aggregated own line and band.
        color_others: Color of all other rows.
        scale_log_metric: Use a log x axis (metric).
        scale_log_size: Use a log y axis (``n_params``).
        highlighted_type: Substring of ``model_type`` that marks a row as "own".
        sharex, sharey: Forwarded to ``plt.subplots``.
        xlim: Optional sequence of ``(left, right)`` x-limits, one per metric panel. Panels
            without an entry keep automatic limits.
        line_plot: Accepted for call compatibility with ``visualize_pareto_final``; unused.

    Returns:
        ``(fig, axs)``, where ``axs`` is a ``(1, len(metrics))`` array of Axes.
    """
    df = df.copy()
    df["IsOwn"] = df["model_type"].apply(lambda x: highlighted_type in str(x))

    fig, axs = plt.subplots(
        nrows=1, ncols=len(metrics), sharex=sharex, sharey=sharey, figsize=(7.167, 7.167 / 2), squeeze=False
    )

    # Loop-invariant axis data: y ticks sit exactly at the model sizes present in `df`.
    unique_params = sorted(df["n_params"].unique().astype(int))
    label_map = {"sre": "95-th percentile SRE", "nere": "95-th percentile NERE"}

    for i, metric in enumerate(metrics):
        ax = axs[0, i]
        target_col = f"{metric}_95th"

        for is_own, group_df in df.groupby("IsOwn"):
            current_color = color_own if is_own else color_others

            if not is_own:
                sns.scatterplot(
                    data=group_df,
                    x=target_col,
                    y="n_params",
                    color=current_color,
                    marker="s",
                    s=20,
                    alpha=0.8,
                    ax=ax,
                    zorder=3,
                )
                for _, row in group_df.iterrows():
                    ax.text(
                        row[target_col] * 1.05,
                        row["n_params"],
                        str(row["model_type"]),
                        fontsize=10,
                        alpha=1,
                        va="center",
                        color=current_color,
                        fontweight="normal",
                    )
                continue

            per_type = group_df.groupby("model_type")
            avg_df = per_type.mean(numeric_only=True).sort_values("n_params", kind="stable")
            # Align std rows to the mean rows by model_type label, not by position. A type with
            # a single run has an undefined (NaN) sample std; treat it as 0 so the band narrows
            # to the line there instead of leaving a gap around that point.
            std_df = per_type.std(numeric_only=True).reindex(avg_df.index).fillna(0.0)

            mean = avg_df[target_col]
            std = std_df[target_col]
            ax.plot(mean, avg_df["n_params"], c=current_color, alpha=0.8)
            ax.fill_betweenx(avg_df["n_params"], mean - std, mean + std, color=current_color, alpha=0.1)

        if scale_log_metric:
            ax.set_xscale("log")
        if scale_log_size:
            ax.set_yscale("log")

        ax.set_yticks(unique_params)
        ax.yaxis.set_major_formatter(plt.ScalarFormatter())
        ax.xaxis.set_major_formatter(plt.ScalarFormatter())
        # Explicit labels at the fixed y ticks (these supersede the y formatter set above).
        ax.set_yticklabels(unique_params, rotation=0, fontsize=9)
        ax.tick_params(which="major", axis="y", direction="in")
        ax.tick_params(which="both", axis="x", direction="in")
        ax.yaxis.minorticks_off()
        ax.xaxis.minorticks_off()
        ax.xaxis.set_minor_formatter(plt.ScalarFormatter())

        # One optional (left, right) pair per panel; panels without an entry autoscale.
        if xlim is not None and i < len(xlim):
            ax.set_xlim(xlim[i][0], xlim[i][1])

        if i == 0:
            # Raw string: "\#" escapes the hash for LaTeX (text.usetex) and is not a Python escape.
            ax.set_ylabel(r"\# Model params.", fontsize=9)
        ax.set_xlabel(label_map.get(metric, metric), fontsize=9)
        ax.grid(True, which="both", ls="--", alpha=0.3)

    fig.tight_layout()
    return fig, axs

In [None]:
# Hand-tuned x-axis limits per material: ((SRE left, right), (NERE left, right)).
xlim_per_material = dict(
    A=((0.1, 0.80), (0.007, 0.22)),
    B=((0.05, 0.4), (0.008, 0.22)),
    C=((0.11, 1.5), (0.03, 0.35)),
    D=((0.05, 0.65), (0.015, 0.36)),
    E=((0.08, 0.45), (0.005, 0.12)),
)

# One mean ± std figure per material: our runs (aggregated inside the plot function)
# next to the external teams' results for the same material.
for mat, material_df in dfs_per_material.items():
    combined_df = pd.concat([material_df, external_df_per_material[mat]], ignore_index=True)
    fig, axs = visualize_mean_var_pareto(
        combined_df,
        ["sre", "nere"],
        color_own="orange",
        color_others="black",
        scale_log_metric=True,
        scale_log_size=True,
        highlighted_type="GRU",
        sharex="col",
        sharey="row",
        xlim=xlim_per_material[mat],
        line_plot=True,
    )
    fig.suptitle(f"Pareto investigation for material '{mat}'")
    fig.tight_layout()
    fig.savefig(f"pareto_investigation_material_{mat}.png", dpi=250, facecolor='white', transparent=False)

In [None]:
# Pool all materials: every one of our runs (aggregated per model type inside the plot)
# against the external teams' material-averaged results.
pooled_df = pd.concat([df_results, df_external], ignore_index=True)
fig, axs = visualize_mean_var_pareto(
    pooled_df,
    metrics=["sre", "nere"],
    color_own="orange",
    color_others="black",
    scale_log_metric=True,
    scale_log_size=True,
    highlighted_type="GRU",
    sharex="col",
    sharey="row",
    xlim=((0.05, 1.2), (0.005, 0.36)),
    line_plot=True,
)
fig.suptitle("Pareto investigation averaged for all materials")
fig.tight_layout()
fig.savefig("pareto_investigation_material_average.png", dpi=250, facecolor="white", transparent=False)