# GRU Pareto Front:

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Environment configuration. These assignments sit ahead of the `import jax`
# further down, so the variables are already set when JAX is first imported.
os.environ["JAX_PLATFORMS"]="cpu"  # ask JAX for the CPU backend
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # expose only CUDA device 0 to this process
# NOTE(review): presumably turns off XLA's up-front GPU memory preallocation —
# confirm against the JAX GPU memory allocation docs.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import pickle
import collections
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import pathlib
from typing import Type
import json

import jax
import jax.numpy as jnp
# Pin JAX to the CPU platform and leave 64-bit mode switched off.
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", False)
import equinox as eqx

#from rhmag.runners.model_setup_jax import get_normalizer
from rhmag.data_management import MaterialSet, FINAL_MATERIALS, EXPERIMENT_LOGS_ROOT, MODEL_DUMP_ROOT, NORMALIZATION_ROOT
from rhmag.models.jiles_atherton import JAStatic, JAWithGRU
from rhmag.models.RNN import GRU
from rhmag.model_interfaces.model_interface import ModelInterface, load_model, count_model_parameters

In [None]:
from rhmag.utils.model_evaluation import reconstruct_model_from_exp_id, get_exp_ids, evaluate_cross_validation, load_parameterization
from rhmag.utils.final_data_evaluation import FINAL_SCENARIOS_PER_MATERIAL
from IPython.display import display, HTML
from rhmag.utils.pretest_evaluation import HOSTS_VALUES_DICT, evaluate_pretest_scenarios, create_multilevel_df, SCENARIO_LABELS

In [None]:
import matplotlib as mpl
from matplotlib import rc
# Figure styling: serif font family, LaTeX text rendering, 10 pt base font
# size, and the LaTeX packages made available to every text element.
rc('font', family='serif', serif=['Helvetica'])
mpl.rcParams.update({
    'text.usetex': True,
    'font.size': 10,
    'text.latex.preamble': r"\usepackage{bm}\usepackage{amsmath}\usepackage{upgreek}",
})

---

In [None]:
# Show FINAL_MATERIALS (imported from rhmag.data_management) as the cell output.
FINAL_MATERIALS

## Plotting

In [None]:
# Scenario names of each final material (A-E), keyed by material letter.
scenarios_per_material = {
    material: list(FINAL_SCENARIOS_PER_MATERIAL[material].keys())
    for material in ("A", "B", "C", "D", "E")
}

# Pre-computed evaluation results that the cells below filter and plot.
df = pd.read_pickle('results_pareto_front.pkl')

In [None]:
# Show the loaded results table as the cell output.
df

In [None]:
# Select material D / model GRU6 / scenario "50%known_50%unknown".
# All three conditions are built on `df` itself and combined into one mask, so
# the mask and the frame it filters always share the same index and no
# implicit index alignment is needed.
subset_df = df.loc[
    (df["material"] == "D")
    & (df["model_type"] == "GRU6")
    & (df["scenario"] == "50%known_50%unknown")
]

print(subset_df.sre_avg)

# Quick look at the average SRE of the selected rows, plotted against their index.
plt.plot(subset_df.sre_avg)

In [None]:
# Select material D / model GRU8 / scenario "50%known_50%unknown".
# All three conditions are built on `df` itself and combined into one mask, so
# the mask and the frame it filters always share the same index and no
# implicit index alignment is needed.
subset_df = df.loc[
    (df["material"] == "D")
    & (df["model_type"] == "GRU8")
    & (df["scenario"] == "50%known_50%unknown")
]

print(subset_df.sre_avg)

# Quick look at the average SRE of the selected rows, plotted against their index.
plt.plot(subset_df.sre_avg)

In [None]:
# Select material D / model GRU10 / scenario "50%known_50%unknown".
# All three conditions are built on `df` itself and combined into one mask, so
# the mask and the frame it filters always share the same index and no
# implicit index alignment is needed.
subset_df = df.loc[
    (df["material"] == "D")
    & (df["model_type"] == "GRU10")
    & (df["scenario"] == "50%known_50%unknown")
]

print(subset_df.sre_avg)

# Quick look at the average SRE of the selected rows, plotted against their index.
plt.plot(subset_df.sre_avg)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D

def visualize_pareto_front_all_mats(df, scenarios_per_materials, metrics, scale_log=False, scale_log_x=False, sharex='col', sharey='row', chosen_param_size=-1):
    """Draw a grid of Pareto-front panels, one row per metric and one column per material.

    Args:
        df: Results table. Must contain a "material" column; the remaining
            columns are consumed by ``visualize_single_pareto_front_sns``.
        scenarios_per_materials: Mapping from material name to the scenarios
            that are included for that material.
        metrics: Metric name prefixes (e.g. "sre"); each one yields a row.
        scale_log: Forwarded to the panel helper (logarithmic y axis).
        scale_log_x: Forwarded to the panel helper (logarithmic x axis).
        sharex: Passed to ``plt.subplots``.
        sharey: Passed to ``plt.subplots``.
        chosen_param_size: Parameter count whose markers the panel helper
            highlights.

    Returns:
        The ``(fig, axs)`` pair created by ``plt.subplots``; ``axs`` is always
        two-dimensional because ``squeeze=False`` is used.
    """
    material_names = sorted(df["material"].unique())
    n_rows = len(metrics)
    n_cols = len(material_names)
    color_cycle = plt.rcParams["axes.prop_cycle"]()

    fig, axs = plt.subplots(
        nrows=n_rows,
        ncols=n_cols,
        sharex=sharex,
        sharey=sharey,
        figsize=(7.167, 7.167 / 2),
        squeeze=False,
    )

    # The first two colours of the active cycle encode average / 95th percentile.
    avg_color = next(color_cycle)["color"]
    p95_color = next(color_cycle)["color"]
    bottom_row = n_rows - 1

    for col, name in enumerate(material_names):
        material_scenarios = scenarios_per_materials[name]
        material_rows = df[df["material"] == name]
        for row, metric in enumerate(metrics):
            panel = axs[row, col]
            # Only the top row carries the material name as a column header.
            if row == 0:
                panel.set_title(name, fontweight='bold', fontsize=10)
            visualize_single_pareto_front_sns(
                material_rows,
                panel,
                material_scenarios,
                metric,
                colors=(avg_color, p95_color),
                show_x_label=(row == bottom_row),
                show_y_label=(col == 0),
                scale_log=scale_log,
                scale_log_x=scale_log_x,
                chosen_param_size=chosen_param_size,
            )

    # Proxy handles: one figure-level legend instead of one legend per panel.
    legend_specs = (
        ('o', 'Avg', avg_color),
        ('^', '95-th percentile', p95_color),
    )
    legend_handles = [
        Line2D([0], [0], marker=marker, color='w', label=text,
               markerfacecolor=face, markersize=9, linestyle='None')
        for marker, text, face in legend_specs
    ]
    fig.legend(
        handles=legend_handles,
        loc='center',
        bbox_to_anchor=(0.5, 0.0),
        ncol=2,
        frameon=True,
        shadow=False,
        fontsize=9,
    )
    fig.tight_layout(w_pad=0.5, h_pad=0.6)
    return fig, axs


def visualize_single_pareto_front_sns(df, ax, scenarios, metric, colors, show_x_label, show_y_label, scale_log, scale_log_x, chosen_param_size=-1):
    """Draw one Pareto-front panel (metric value over parameter count) on ``ax``.

    The rows of ``df`` whose "scenario" is in ``scenarios`` are averaged per
    ("n_params", "material") group, and the "<metric>_avg" and "<metric>_95th"
    means are drawn as two strip plots. Points whose "n_params" equals
    ``chosen_param_size`` get a lime marker outline.

    Args:
        df: Results table with the columns "scenario", "n_params", "material",
            "<metric>_avg" and "<metric>_95th".
        ax: Matplotlib axes to draw on; modified in place.
        scenarios: Scenario names to include.
        metric: Metric name prefix, e.g. "sre" or "nere".
        colors: Two colours, used for the average and the 95th-percentile
            markers, in that order.
        show_x_label: Whether to label the x axis (cleared otherwise).
        show_y_label: Whether to label the y axis (cleared otherwise).
        scale_log: Use a logarithmic y axis.
        scale_log_x: Use a logarithmic x axis.
        chosen_param_size: Parameter count to highlight; the default -1 is
            not expected to match any model.

    Returns:
        None.
    """
    value_cols = [f"{metric}_avg", f"{metric}_95th"]
    id_cols = ["n_params", "material"]

    df_plot = df[df["scenario"].isin(scenarios)].copy()
    df_agg = df_plot.groupby(id_cols)[value_cols].mean().reset_index()
    df_melted = df_agg.melt(
        id_vars=id_cols,
        value_vars=value_cols,
        var_name="Type", value_name="Value",
    )

    # Arguments shared by the regular and the highlighted strip plots.
    strip_kwargs = dict(
        x="n_params", y="Value",
        alpha=1, size=5, jitter=0.05,
        ax=ax, native_scale=True,
    )
    is_chosen = df_melted["n_params"] == chosen_param_size

    for label, m, c in zip(value_cols, ["o", "^"], colors):
        is_label = df_melted["Type"] == label

        sns.stripplot(
            data=df_melted[is_label & ~is_chosen],
            color=c, marker=m,
            **strip_kwargs,
        )

        # Same markers with a lime outline for the selected model size.
        highlight_df = df_melted[is_label & is_chosen]
        if not highlight_df.empty:
            sns.stripplot(
                data=highlight_df,
                color=c, marker=m,
                edgecolor="lime", linewidth=1.5,
                **strip_kwargs,
            )

    if scale_log:
        ax.set_yscale("log")
    if scale_log_x:
        ax.set_xscale("log")

    # One tick per parameter count that is present, labelled with its value.
    unique_params = sorted(df_plot["n_params"].unique())
    ax.set_xticks(unique_params)
    ax.set_xticklabels(unique_params, rotation=90, fontsize=9)
    ax.tick_params(which='major', axis="y", direction='in')
    ax.tick_params(which='major', axis="x", direction='in')
    ax.minorticks_off()

    if show_y_label:
        label_map = {"sre": "SRE", "nere": "NERE"}
        ax.set_ylabel(label_map.get(metric, metric), fontsize=9)
    else:
        ax.set_ylabel("")

    if show_x_label:
        # Raw string: the backslash belongs to the LaTeX label text and must
        # not be read as a Python escape sequence.
        ax.set_xlabel(r"\# Model params.", fontsize=9)
    else:
        ax.set_xlabel("")

    ax.grid(True, alpha=0.3, which="major")

# Render the SRE and NERE Pareto fronts for every material on log-log axes,
# outline the 325-parameter model, and export the figure as a PDF.
pareto_plot_options = dict(
    metrics=["sre", "nere"],
    scale_log=True,
    scale_log_x=True,
    sharex=True,
    sharey='row',
    chosen_param_size=325,
)
fig, axs = visualize_pareto_front_all_mats(df, scenarios_per_material, **pareto_plot_options)
plt.savefig("pareto_front_mc2_submission.pdf", bbox_inches="tight")