# GRU Pareto Front

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
os.environ["JAX_PLATFORMS"]="cpu"
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import pickle
import collections
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import pathlib
from typing import Type
import json

import jax
import jax.numpy as jnp
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", False)
import equinox as eqx

#from mc2.runners.model_setup_jax import get_normalizer
from mc2.data_management import MaterialSet, FINAL_MATERIALS, EXPERIMENT_LOGS_ROOT, MODEL_DUMP_ROOT, NORMALIZATION_ROOT
from mc2.models.jiles_atherton import JAStatic, JAWithGRU
from mc2.models.RNN import GRU
from mc2.model_interfaces.model_interface import ModelInterface, load_model, count_model_parameters

In [None]:
from mc2.utils.model_evaluation import reconstruct_model_from_exp_id, get_exp_ids, evaluate_cross_validation, load_parameterization
from mc2.utils.final_data_evaluation import FINAL_SCENARIOS_PER_MATERIAL
from IPython.display import display, HTML
from mc2.utils.pretest_evaluation import HOSTS_VALUES_DICT, evaluate_pretest_scenarios, create_multilevel_df, SCENARIO_LABELS

---

In [None]:
# Display FINAL_MATERIALS: the materials iterated over by the cells below.
FINAL_MATERIALS

In [None]:
# For each material, collect the experiment IDs of this study with the part
# after the last "_" stripped, so that runs differing only in that suffix
# (presumably the seed — TODO confirm against the exp_id naming scheme)
# collapse into one entry.
exp_ids_no_seed = {}
exp_name = "final-comp-reduced"
for material_name in FINAL_MATERIALS:
    print("MATERIAL:", material_name)
    # A sorted set gives a deterministic order; list(set(...)) depends on
    # string-hash randomization and changes between interpreter sessions.
    mat_ids_unique = sorted({
        exp_id.rpartition('_')[0]
        for exp_id in get_exp_ids(material_name=material_name, model_type=None)
        if exp_name in exp_id
    })
    for exp_id in mat_ids_unique:
        print("    " + f"'{exp_id}'")
    print()
    exp_ids_no_seed[material_name] = mat_ids_unique

## Compute metrics for trained models

In [None]:
from mc2.utils.final_data_evaluation import generate_metrics_from_exp_ids_without_seed

In [None]:
# Compute evaluation metrics for every trained model of each material and
# combine them into one table (df_all).
save_results = False  # set True to pickle per-material and combined metrics

dfs = {}
for material_name in FINAL_MATERIALS:
    if not exp_ids_no_seed[material_name]:
        print(f"Skipping Material {material_name} due to missing exp_ids.")
        continue
    dfs[material_name], _ = generate_metrics_from_exp_ids_without_seed(
        exp_ids_without_seed=exp_ids_no_seed[material_name],
        material_name=material_name,
        loader_key=jax.random.PRNGKey(99),  # same fixed key for every material
    )
    if save_results:
        file_path_pickle = f'results__dump_material_reduced_{material_name}.pkl'
        with open(file_path_pickle, 'wb') as f:
            pickle.dump(dfs[material_name], f)

# Concatenate whichever materials were actually evaluated. Skipped materials
# have no entry in dfs, so indexing a fixed "A".."E" list would raise KeyError.
df_all = pd.concat(list(dfs.values()), ignore_index=True)
if save_results:
    df_all.to_pickle('results_pareto_front.pkl')

## Plotting

In [None]:
# Scenario names to aggregate over for each material, taken from the keys of
# FINAL_SCENARIOS_PER_MATERIAL (materials A-E, in that order).
scenarios_per_material = {
    mat: list(FINAL_SCENARIOS_PER_MATERIAL[mat].keys())
    for mat in ("A", "B", "C", "D", "E")
}

# Load the combined metrics table from disk rather than using df_all in memory.
# NOTE(review): this file is only produced if df_all was pickled to
# 'results_pareto_front.pkl'; confirm it is up to date before plotting.
df = pd.read_pickle('results_pareto_front.pkl')

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

def visualize_pareto_front_all_mats(df, scenarios_per_materials, metrics, scale_log=False, scale_log_x=False, sharex='col', sharey='row', chosen_param_size=-1):
    """Draw a grid of Pareto-front panels: one row per metric, one column per material.

    Args:
        df: Results table with a "material" column plus the columns read by
            visualize_single_pareto_front_sns.
        scenarios_per_materials: Mapping material name -> scenario names that
            are averaged over in that material's panels.
        metrics: Metric prefixes; the "<metric>_avg" and "<metric>_95th"
            columns are plotted.
        scale_log: Use a log-scaled y axis.
        scale_log_x: Use a log-scaled x axis.
        sharex, sharey: Forwarded to plt.subplots.
        chosen_param_size: n_params value whose points are highlighted; the
            default -1 presumably matches no model.

    Returns:
        (fig, axs), where axs is a 2-D array of shape (len(metrics), n_materials)
        because squeeze=False.
    """
    sns.set_theme(style="whitegrid")
    material_list = sorted(df["material"].unique())
    n_rows, n_cols = len(metrics), len(material_list)

    fig, axs = plt.subplots(
        nrows=n_rows,
        ncols=n_cols,
        sharex=sharex,
        sharey=sharey,
        figsize=(3.5 * n_cols, 4 * n_rows),
        squeeze=False,
    )

    deep = sns.color_palette("deep")
    avg_color = deep[0]
    q95_color = deep[1]

    for col, material in enumerate(material_list):
        material_scenarios = scenarios_per_materials[material]
        material_df = df.loc[df["material"] == material]
        for row, metric in enumerate(metrics):
            ax = axs[row, col]
            # Only the top row carries the material name as a column title.
            if row == 0:
                ax.set_title(material, fontweight='bold', fontsize=14)
            visualize_single_pareto_front_sns(
                material_df,
                ax,
                material_scenarios,
                metric,
                colors=(avg_color, q95_color),
                show_x_label=row == n_rows - 1,
                show_y_label=col == 0,
                scale_log=scale_log,
                scale_log_x=scale_log_x,
                chosen_param_size=chosen_param_size,
            )

    # One figure-level legend explaining the two marker shapes.
    handles = [
        Line2D([0], [0], marker=marker, color='w', label=label,
               markerfacecolor=face, markersize=10, linestyle='None')
        for marker, label, face in (
            ('o', 'Average', avg_color),
            ('^', '95th Quantile', q95_color),
        )
    ]
    fig.legend(
        handles=handles,
        loc='upper center',
        bbox_to_anchor=(0.5, 1.05),
        ncol=2,
        frameon=True,
        shadow=False,
        fontsize=14,
    )
    fig.tight_layout()
    return fig, axs

def visualize_single_pareto_front_sns(df, ax, scenarios, metric, colors, show_x_label, show_y_label, scale_log, scale_log_x, chosen_param_size=-1):
    """Plot one Pareto-front panel (error vs. model size) onto ``ax``.

    Rows of ``df`` whose "scenario" is in ``scenarios`` are averaged per
    (n_params, exp_id_without_seed). Each aggregated model is then drawn twice:
    its "<metric>_avg" value as a circle and its "<metric>_95th" value as a
    triangle. Points with n_params == chosen_param_size get a lime edge.

    Args:
        df: Results table with "scenario", "n_params", "exp_id_without_seed",
            "<metric>_avg" and "<metric>_95th" columns.
        ax: Matplotlib axes to draw into.
        scenarios: Scenario names to include in the average.
        metric: Metric prefix, e.g. "sre" or "nere".
        colors: (average color, 95th-quantile color).
        show_x_label, show_y_label: Whether to label the respective axis.
        scale_log, scale_log_x: Use log scale on the y / x axis.
        chosen_param_size: n_params value to highlight (-1 presumably matches
            no model).
    """
    avg_col, q95_col = f"{metric}_avg", f"{metric}_95th"
    df_plot = df[df["scenario"].isin(scenarios)].copy()
    df_agg = (
        df_plot.groupby(["n_params", "exp_id_without_seed"])[[avg_col, q95_col]]
        .mean()
        .reset_index()
    )
    df_melted = df_agg.melt(
        id_vars=["n_params", "exp_id_without_seed"],
        value_vars=[avg_col, q95_col],
        var_name="Type",
        value_name="Value",
    )

    is_chosen = df_melted["n_params"] == chosen_param_size
    # Settings shared by the regular and the highlighted point layers.
    strip_kwargs = dict(
        x="n_params", y="Value", alpha=1, size=8, jitter=0.05,
        ax=ax, native_scale=True,
    )
    for label, marker, color in zip([avg_col, q95_col], ["o", "^"], colors):
        is_label = df_melted["Type"] == label
        sns.stripplot(
            data=df_melted[is_label & ~is_chosen],
            color=color, marker=marker, **strip_kwargs,
        )
        highlight_df = df_melted[is_label & is_chosen]
        if not highlight_df.empty:
            sns.stripplot(
                data=highlight_df,
                edgecolor="lime", linewidth=1.5,
                color=color, marker=marker, **strip_kwargs,
            )

    if scale_log:
        ax.set_yscale("log")
    if scale_log_x:
        ax.set_xscale("log")
    # Label every model size explicitly.
    unique_params = sorted(df_plot["n_params"].unique())
    ax.set_xticks(unique_params)
    ax.set_xticklabels(unique_params, rotation=45, ha='right', fontsize=12)
    if scale_log_x:
        # matplotlib's default log-axis minor formatter labels minor ticks when
        # the axis spans about a decade or less. Those labels would overlap
        # the explicit model-size labels, so hide minor x labels.
        ax.tick_params(axis="x", which="minor", labelbottom=False)

    y_label_map = {"sre": "Mean Seq. Rel. Error", "nere": "Mean Norm. Energy Rel. Error"}
    if show_y_label:
        ax.set_ylabel(y_label_map.get(metric, metric), fontsize=14)
    else:
        ax.set_ylabel("", fontsize=14)

    if show_x_label:
        ax.set_xlabel("Model size", fontsize=14)
    else:
        ax.set_xlabel("")

In [None]:
# Pareto fronts for both error metrics across all materials. Both axes are
# log-scaled, and models with n_params == 325 are highlighted.
fig, axs = visualize_pareto_front_all_mats(
    df,
    scenarios_per_material,
    metrics=["sre", "nere"],
    scale_log=True,
    scale_log_x=True,
    sharey='row',
    sharex=True,
    chosen_param_size=325,
)