In [None]:
# Re-import project modules (e.g. rhmag.*) automatically whenever their source changes.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Environment configuration for JAX/XLA. This is set before `import jax` below so that it
# is in place when JAX initialises its backends.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # only expose GPU 0 to this process
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"  # don't preallocate a large block of GPU memory up front

import pathlib
import glob
from copy import deepcopy

import seaborn as sns
import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import jax
# Run JAX on the CPU. NOTE(review): this makes the CUDA settings above irrelevant unless this line is removed.
jax.config.update("jax_platform_name", "cpu")
# Enable 64-bit (float64/int64) types in JAX.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import equinox as eqx 

In [None]:
from rhmag.utils.final_data_evaluation import (
    FINAL_MATERIALS, TestSet, ResultSet, predict_test_scenarios, validate_result_set, visualize_result_set
)
from rhmag.utils.model_evaluation import reconstruct_model_from_file, get_exp_ids

In [None]:
import matplotlib as mpl
from matplotlib import rc
# Figure styling: render all text through LaTeX (with bm, amsmath and upgreek available
# in labels) at a 10 pt base font size.
# NOTE(review): 'Helvetica' is a sans-serif typeface but is listed under font.serif while
# font.family is 'serif'; confirm which typeface the figures are actually meant to use.
mpl.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Helvetica'],
    'text.usetex': True,
    'font.size': 10,
    'text.latex.preamble': r"\usepackage{bm}\usepackage{amsmath}\usepackage{upgreek}",
})

---

## Gather experiment IDs and test data

In [None]:
# Materials evaluated in this notebook; this collection drives the per-material loops below.
FINAL_MATERIALS

In [None]:
# Collect, per material, the experiment IDs of all trained models (all seeds):
# the models from the "pareto-front-f32" run (model_type=None, presumably all model
# types) plus the JA models, which are taken from the "pareto-front-f64" run.
exp_ids_all_seeds = {}
exp_name = "pareto-front-f32"
exp_name_ja = "pareto-front-f64"  # JA models come from a separate run
for material_name in FINAL_MATERIALS:
    print("MATERIAL:", material_name)
    mat_ids = sorted(get_exp_ids(material_name=material_name, model_type=None, exp_name=exp_name))
    mat_ids_ja = sorted(get_exp_ids(material_name=material_name, model_type="JA", exp_name=exp_name_ja))
    mat_ids = mat_ids_ja + mat_ids

    # De-duplicate while keeping a deterministic order (JA IDs first, each group sorted).
    # `list(set(...))` would order the IDs by string hash, which changes between kernel
    # restarts because of Python's hash randomization.
    mat_ids_unique = list(dict.fromkeys(mat_ids))

    for exp_id in mat_ids_unique:
        print(f"    '{exp_id}'")
    print()

    exp_ids_all_seeds[material_name] = mat_ids_unique

In [None]:
# Sanity check: how many experiment IDs were collected for each material.
for material_name, mat_ids_unique in exp_ids_all_seeds.items():
    n_models = len(mat_ids_unique)
    print(f"Material '{material_name}': {n_models} models found.")

In [None]:
# Load the test set of every final material once, keyed by material name.
test_data = {
    mat: TestSet.from_material_name(mat)
    for mat in FINAL_MATERIALS
}

## Update the Pareto results table

In [None]:
from rhmag.utils.final_data_evaluation import update_pareto_df

In [None]:
# The Pareto results live in an on-disk parquet file: `update_pareto_df` is handed that
# path together with the collected experiment IDs and test sets (presumably so it can
# reuse already-evaluated entries — TODO confirm), and the returned frame is written
# back to the same file. The path is kept in one constant so read and write cannot diverge.
PARETO_RESULTS_FILE = "pareto_results.parquet"
df_results = update_pareto_df(PARETO_RESULTS_FILE, exp_ids_all_seeds, test_data)
df_results.to_parquet(PARETO_RESULTS_FILE)

In [None]:
# Inspect the updated Pareto results table (the same frame was just written to pareto_results.parquet).
df_results

## Results of other teams

In [None]:
from rhmag.utils.provided_final_results import EXTERNAL_TEAMS_AVG, EXTERNAL_TEAMS_PER_MATERIAL

In [None]:
# Reference results of the other teams (the AVG variant, presumably aggregated over
# materials; the per-material variant is used further below).
df_external = pd.DataFrame(data=EXTERNAL_TEAMS_AVG)
df_external

## Visualize Pareto fronts across models

In [None]:
from rhmag.utils.pareto_plot_functions import visualize_pareto_cross_model

In [None]:
def _split_by_material(df: pd.DataFrame, materials) -> dict:
    """Split ``df`` into one sub-frame per material, keyed by its ``material`` column.

    Groups are matched to material names by their actual ``material`` value, not by
    position, so the result does not rely on ``groupby``'s sorted key order lining up
    with ``materials``.

    Args:
        df: frame with a ``material`` column.
        materials: material names to extract; the returned dict follows this order.

    Returns:
        ``{material_name: sub_frame}`` for every entry of ``materials``. Groups whose
        material is not listed in ``materials`` are dropped.

    Raises:
        ValueError: if any entry of ``materials`` has no rows in ``df``.
    """
    groups = dict(tuple(df.groupby("material")))
    missing = [name for name in materials if name not in groups]
    if missing:
        raise ValueError(f"No rows for material(s) {missing}; found {list(groups)}.")
    return {name: groups[name] for name in materials}


# Our results and the other teams' per-material results, both keyed by material name.
dfs_per_material = _split_by_material(df_results, FINAL_MATERIALS)

external_df = pd.DataFrame(EXTERNAL_TEAMS_PER_MATERIAL)
external_df_per_material = _split_by_material(external_df, FINAL_MATERIALS)

In [None]:
# Hand-tuned axis limits per material (presumably one (lo, hi) pair per entry of
# `metrics` — TODO confirm against `visualize_pareto_cross_model`). They are only
# applied when USE_FIXED_XLIM is True; with the default (False) every plot is
# auto-scaled, as before.
USE_FIXED_XLIM = False
xlim_per_material = {
    "A": ((0.1, 0.80), (0.007, 0.22)),
    "B": ((0.05, 0.4), (0.008, 0.22)),
    "C": ((0.11, 1.5), (0.03, 0.35)),
    "D": ((0.05, 0.65), (0.015, 0.36)),
    "E": ((0.08, 0.45), (0.005, 0.12)),
}

MODEL_COLORS = {
    "GRU": "tab:blue",
    "LSTM": "tab:orange",
    "JA": "tab:purple",
}

# One Pareto figure per material, compared against the other teams' results for that material.
for material_name, df_material in dfs_per_material.items():
    fig, axs = visualize_pareto_cross_model(
        df_material,
        external_df_per_material[material_name],
        metrics=["sre", "nere"],
        colors=MODEL_COLORS,
        color_others="black",
        sharex="col",
        sharey="row",
        xlim=xlim_per_material.get(material_name) if USE_FIXED_XLIM else None,
        show_median=True,
    )
    fig.suptitle(f"Pareto investigation for material '{material_name}'")
    # Operate on the returned figure explicitly rather than on pyplot's "current figure".
    fig.tight_layout()
    fig.savefig(
        f"pareto_investigation_material_each_trial_{material_name}.png",
        dpi=250,
        facecolor='white',
        transparent=False,
    )

In [None]:
# Overview across all materials: every model in `df_results` against the other teams'
# AVG results in `df_external`.
overview_plot_kwargs = dict(
    metrics=["sre", "nere"],
    colors={"GRU": "tab:blue", "LSTM": "tab:orange", "JA": "tab:purple"},
    color_others="black",
    sharex="col",
    sharey="row",
    xlim=None,
    show_median=True,
)
fig, axs = visualize_pareto_cross_model(df_results, df_external, **overview_plot_kwargs)