In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Expose only the first CUDA device to this process.
# NOTE(review): JAX is pinned to the CPU platform further down in this cell,
# so this setting presumably only matters for other CUDA users -- confirm it is still needed.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Ask XLA not to preallocate device memory; set here, before `import jax` below.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
import glob
from copy import deepcopy

import seaborn as sns
import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import jax
# Select the CPU platform for JAX in this notebook.
jax.config.update("jax_platform_name", "cpu")
# Enable 64-bit types in JAX.
# NOTE(review): presumably required for the "pareto-front-f64" models loaded below -- confirm.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import equinox as eqx 

In [None]:
from rhmag.utils.final_data_evaluation import (
    FINAL_MATERIALS, TestSet, ResultSet, predict_test_scenarios, validate_result_set, visualize_result_set
)
from rhmag.utils.model_evaluation import reconstruct_model_from_file, get_exp_ids

In [None]:
import matplotlib as mpl
from matplotlib import rc
# Figure styling: serif font family, LaTeX text rendering, 10 pt base font size
# and the LaTeX packages used in labels.
rc("font", family="serif", serif=["Helvetica"])
mpl.rcParams.update(
    {
        "text.usetex": True,
        "font.size": 10,
        "text.latex.preamble": r"\usepackage{bm}\usepackage{amsmath}\usepackage{upgreek}",
    }
)

## Gather data:

In [None]:
# Show the material names that the cells below iterate over.
FINAL_MATERIALS

In [None]:
# Collect the experiment IDs found for every material in FINAL_MATERIALS.
# Two queries are combined per material:
#   * exp_name "pareto-front-f32" with model_type=None
#   * exp_name "pareto-front-f64" with model_type="JA"
# The result maps material name -> list of unique experiment IDs.
exp_ids_all_seeds = {}
exp_name = "pareto-front-f32"
for material_name in FINAL_MATERIALS:
    print("MATERIAL:", material_name)
    mat_ids = sorted(get_exp_ids(material_name=material_name, model_type=None, exp_name=exp_name))
    mat_ids_ja = sorted(get_exp_ids(material_name=material_name, model_type="JA", exp_name="pareto-front-f64"))
    mat_ids = mat_ids_ja + mat_ids

    # Deduplicate while keeping the "JA first, each group sorted" order.
    # list(set(...)) would discard that order and, because string hashing is
    # randomized per process, list the IDs differently on every kernel start.
    mat_ids_unique = list(dict.fromkeys(mat_ids))

    for element in mat_ids_unique:
        print("    " + f"'{element}'")
    print()

    exp_ids_all_seeds[material_name] = mat_ids_unique

In [None]:
# Summarize how many experiment IDs were collected for each material.
for material_name in exp_ids_all_seeds:
    mat_ids_unique = exp_ids_all_seeds[material_name]
    print(f"Material '{material_name}': {len(mat_ids_unique)} models found.")

In [None]:
# Build one TestSet per material, keyed by material name.
test_data = {
    name: TestSet.from_material_name(name)
    for name in FINAL_MATERIALS
}

## Check out models:
What exactly is going on in the test data?

Why are the models so much worse on 'B' and 'D'?
Are they simply bad at extrapolation?

In [None]:
# Choose the material, model type and experiment to inspect, then list the
# experiment IDs that match this selection.
material_name, model_type, exp_name = "D", "JA", "pareto-front-f64"

get_exp_ids(
    material_name=material_name,
    model_type=model_type,
    exp_name=exp_name,
)

In [None]:
# Alternative experiment to inspect:
# exp_id = 'E_GRUwLinearModel_demonstration_0a8ae335_seed1'
exp_id = 'D_JA_pareto-front-f64_838e8e74_seed1'
model = reconstruct_model_from_file(exp_id)

# The first two "_"-separated fields of the ID are read as material name and model type.
material_name, model_type, *_ = exp_id.split("_")

In [None]:
# Pick the test set that belongs to the material of the loaded model and display it.
test_set = test_data[material_name]
test_set

In [None]:
from rhmag.utils.final_data_evaluation import evaluate_test_scenarios

In [None]:
# Evaluate the loaded model on the selected test set with reduce=False.
# NOTE(review): reduce=False presumably returns unreduced, per-sequence metrics -- confirm.
metrics = evaluate_test_scenarios(model, test_set, reduce=False)

In [None]:
# for scenario_key, scenario_metrics in metrics.items():
#     print(scenario_key)
#     plt.plot(scenario_metrics["sre"])
#     plt.show()

In [None]:
# For every test scenario, run the model and plot each sequence in three rows:
#   row 0: B over step k
#   row 1: H over step k (measured, predicted, and their difference as dashed red)
#   row 2: H over B (measured and predicted)
# `n_plots` sequences are shown side by side per figure.
n_plots = 5

for scenario in test_set.scenarios:
    H_pred = model(
        B_past=scenario.B_past,
        H_past=scenario.H_past,
        B_future=scenario.B_future,
        T=jnp.squeeze(scenario.T),
    )

    B_future = scenario.B_future
    H_future = scenario.H_future
    B_past = scenario.B_past

    # The leading axis of H_pred is treated as the sequence index.
    n_sequences = H_pred.shape[0]

    for start_idx in range(0, n_sequences, n_plots):
        # squeeze=False keeps `axs` two-dimensional even if n_plots is set to 1.
        fig, axs = plt.subplots(3, n_plots, figsize=(12, 7), squeeze=False)

        for idx in range(n_plots):
            seq_idx = start_idx + idx

            if seq_idx >= n_sequences:
                # The last figure is only partly filled when the number of
                # sequences is not a multiple of n_plots. Indexing past the
                # end would raise for NumPy arrays and silently repeat the
                # last sequence for JAX arrays, so hide the unused column.
                for ax in axs[:, idx]:
                    ax.set_axis_off()
                continue

            axs[0, idx].plot(B_future[seq_idx])

            axs[1, idx].plot(H_future[seq_idx])
            axs[1, idx].plot(H_pred[seq_idx])
            axs[1, idx].plot(H_future[seq_idx] - H_pred[seq_idx], color="tab:red", linestyle="--")

            axs[2, idx].plot(B_future[seq_idx], H_future[seq_idx])
            axs[2, idx].plot(B_future[seq_idx], H_pred[seq_idx])

            # Same grid style on all rows; per-row axis labels.
            for row, (x_label, y_label) in enumerate([("k", "B"), ("k", "H"), ("B", "H")]):
                axs[row, idx].grid(True, alpha=0.3)
                axs[row, idx].set_ylabel(y_label)
                axs[row, idx].set_xlabel(x_label)

        fig.tight_layout(pad=-0.2)
        plt.show()
        # Release the figure so that many scenarios do not accumulate open figures.
        plt.close(fig)