In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Restrict this process to GPU 0 and let XLA allocate GPU memory on demand
# instead of preallocating it. These are set before jax is imported (later cell).
# NOTE(review): a later cell forces JAX onto the CPU via
# jax.config.update("jax_platform_name", "cpu"), so these GPU settings may have
# no effect — confirm which backend is intended.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
import glob
import tqdm
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
import matplotlib as mpl
from matplotlib import rc
# Global figure styling: 10 pt serif text rendered through LaTeX, with the
# bm / amsmath / upgreek packages available for axis labels and legends.
# NOTE(review): "Helvetica" is a sans-serif face but is listed under
# font.serif; with text.usetex enabled LaTeX may not honor it — confirm the
# intended typeface.
rc('font', family='serif', serif=['Helvetica'])
mpl.rcParams.update({
    'text.usetex': True,
    'font.size': 10,
    'text.latex.preamble': r"\usepackage{bm}\usepackage{amsmath}\usepackage{upgreek}",
})

In [None]:
import jax
# Run all JAX computations on the CPU backend.
# NOTE(review): this overrides the CUDA_VISIBLE_DEVICES / XLA GPU settings made
# in an earlier cell — confirm whether GPU execution was intended.
jax.config.update("jax_platform_name", "cpu")
import equinox as eqx

from mc2.utils.model_evaluation import reconstruct_model_from_exp_id
from mc2.model_interfaces.model_interface import count_model_parameters
from mc2.utils.pretest_evaluation import produce_pretest_histograms, SCENARIO_LABELS, DETAILED_SCENARIO_LABELS, store_pretest_results_to_csv, load_hdf5_pretest_data

In [None]:
# Experiment IDs of the trained GRU models, one per material. The material name
# is the prefix before the first underscore of each ID.
exp_ids = [
    '3C90_GRU_97c4047f-c2d8-48',
    'N87_GRU_8ba07f4f-c59a-42',
    '3C94_GRU_b7cf990c-33b5-49',
    '3E6_GRU_e45054a0-3df2-4c',
    '3F4_GRU_6d364e15-88db-46',
    '77_GRU_db53aa04-5f27-43',
    '78_GRU_69379b64-4b79-42',
    'N27_GRU_41070eb0-0850-4f',
    'N30_GRU_9eca1a85-eaec-4a',
    'N49_GRU_9d34af17-f5bc-46',
]
material_names = [exp_id.split("_")[0] for exp_id in exp_ids]
# Two experiments for the same material would silently collapse into one dict
# entry, and the export cell zips exp_ids against this dict's items.
assert len(set(material_names)) == len(material_names), "duplicate material in exp_ids"

# Map material name -> reconstructed model, in exp_ids order.
wrapped_models = {
    material_name: reconstruct_model_from_exp_id(exp_id)
    for material_name, exp_id in zip(material_names, exp_ids)
}

print("n_models:", len(wrapped_models))
for model in wrapped_models.values():
    print("parameters_per_model:", model.n_params)

In [None]:
# Output locations (relative to the notebook); create them so the first
# savefig / CSV write does not fail on a fresh checkout.
histogram_dir = pathlib.Path("histograms")
csv_dir = pathlib.Path("csvs")
histogram_dir.mkdir(parents=True, exist_ok=True)
csv_dir.mkdir(parents=True, exist_ok=True)

# wrapped_models was built from exp_ids in order; zip would silently truncate
# if the two ever got out of sync.
assert len(wrapped_models) == len(exp_ids)

for exp_id, (material_name, wrapped_model) in zip(exp_ids, wrapped_models.items()):
    assert exp_id.split("_")[0] == material_name

    print("Number of parameters:", wrapped_model.n_params)

    # loss_short is returned by the loader but not used in this cell.
    B, T, H_init, H_true, loss, loss_short, msks_scenarios_N_tup = load_hdf5_pretest_data(material_name)

    fig, axs = produce_pretest_histograms(
        material_name,
        wrapped_model,
        B,
        T,
        H_init,
        H_true,
        loss,
        msks_scenarios_N_tup,
        scenario_labels=SCENARIO_LABELS,
        adapted_scenario_labels=DETAILED_SCENARIO_LABELS,
        show_plots=False,
    )
    # Save the returned figure explicitly (not pyplot's implicit "current"
    # figure) and close it so figures do not accumulate across the loop.
    fig.savefig(histogram_dir / f"{exp_id}_preevaluation_results.pdf", bbox_inches="tight")
    plt.close(fig)

    store_pretest_results_to_csv(
        csv_dir,
        exp_id,
        wrapped_model,
        B,
        T,
        H_init,
        H_true,
        loss,
        msks_scenarios_N_tup,
        scenario_labels=SCENARIO_LABELS,
    )
    

## Verify the format of the exported CSVs and reproduce the metrics

In [None]:
import jax.numpy as jnp

from mc2.metrics import sre, nere

In [None]:
# Pretest CSV layout: rows are grouped into consecutive scenario blocks (in
# PRETEST_SCENARIOS order), each block holding the same number of sequences.
# Every row has SEQUENCE_LENGTH samples of H; the first `warm_up_len` are known.
SEQUENCE_LENGTH = 1000
DEFAULT_N_SEQUENCES_PER_SCENARIO = 2100
# Materials whose pretest set has a different number of sequences per scenario.
N_SEQUENCES_PER_SCENARIO = {"N49": 1500}
# (label, warm_up_len) per scenario; warm_up_len = number of known leading H samples.
PRETEST_SCENARIOS = (
    ("90% unknown", 100),
    ("50% unknown", 500),
    ("10% unknown", 900),
)


def _scenario_metrics(preds, H_gt, B_scenario, true_core_loss, warm_up_len):
    """Compute SRE / NERE summary statistics for one pretest scenario.

    Args:
        preds: predicted H after the warm-up, one row per sequence.
        H_gt: ground-truth H aligned with ``preds``.
        B_scenario: full-length B rows (warm-up included) for the same sequences.
        true_core_loss: per-sequence core loss values; their absolute value is used.
        warm_up_len: number of leading samples excluded from ``preds`` / ``H_gt``.

    Returns:
        dict with keys ``sre_avg``, ``sre_95th``, ``nere_avg``, ``nere_95th``
        (insertion order matches the print order used below).
    """
    sre_per_sequence = eqx.filter_vmap(sre)(preds, H_gt)
    # Differentiate the full B rows before cropping so the gradient at the first
    # predicted sample is still a central difference using its warm-up neighbour.
    dbdt = np.gradient(B_scenario, axis=1)[:, warm_up_len:]
    nere_per_sequence = eqx.filter_vmap(nere)(preds, H_gt, dbdt, np.abs(true_core_loss))
    return {
        "sre_avg": np.mean(sre_per_sequence),
        "sre_95th": np.percentile(sre_per_sequence, 95),
        "nere_avg": np.mean(nere_per_sequence),
        "nere_95th": np.percentile(nere_per_sequence, 95),
    }


for exp_id in exp_ids:
    predictions = pd.read_csv(f"csvs/{exp_id}_pred.csv", header=None)

    material_name = exp_id.split("_")[0]
    print(material_name)

    n_per_scenario = N_SEQUENCES_PER_SCENARIO.get(material_name, DEFAULT_N_SEQUENCES_PER_SCENARIO)
    predictions = jnp.array(predictions)
    assert predictions.shape == (len(PRETEST_SCENARIOS) * n_per_scenario, SEQUENCE_LENGTH)

    B, T, H_init, H_true, loss, loss_short, msks_scenarios_N_tup = load_hdf5_pretest_data(material_name)

    # Row range of each scenario block within the CSV.
    scenario_rows = [
        slice(i * n_per_scenario, (i + 1) * n_per_scenario)
        for i in range(len(PRETEST_SCENARIOS))
    ]

    # The known warm-up part of every prediction must be copied verbatim from H_init.
    for rows, (_, warm_up_len) in zip(scenario_rows, PRETEST_SCENARIOS):
        assert jnp.all(predictions[rows, :warm_up_len] == H_init[rows, :warm_up_len])

    # Reproduce the metric values from the exported predictions.
    for rows, (label, warm_up_len) in zip(scenario_rows, PRETEST_SCENARIOS):
        print(label)
        metrics = _scenario_metrics(
            preds=predictions[rows, warm_up_len:],
            H_gt=H_true[rows, warm_up_len:],
            B_scenario=B[rows, :],
            true_core_loss=jnp.squeeze(loss[rows]),
            warm_up_len=warm_up_len,
        )
        for metric_name, value in metrics.items():
            print(f"{metric_name}:", value)

In [None]:
# Peek at the raw prediction CSV of the first experiment.
predictions = pd.read_csv(f"csvs/{exp_ids[0]}_pred.csv", header=None)
# Print the experiment ID itself; indexing the leftover loop variable `exp_id`
# would print only the first character of the last experiment's ID.
print(exp_ids[0])
display(predictions)
predictions = jnp.array(predictions)

In [None]:
# Show the scenario labels passed to the histogram and CSV-export helpers above.
SCENARIO_LABELS

In [None]:
# Load one material's pretest data for interactive inspection. Name the material
# explicitly (the last entry of exp_ids, which is what the verification loop ends
# on) instead of relying on a variable leaked from a previous cell's loop.
material_name = exp_ids[-1].split("_")[0]
B, T, H_init, H_true, loss, loss_short, msks_scenarios_N_tup = load_hdf5_pretest_data(material_name)