In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell execution so edits to the imported project code are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Both variables are set before `jax` is imported further down in the notebook.
# Expose only GPU 0 to CUDA-based libraries.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Ask the XLA client not to preallocate GPU memory up front (see the JAX GPU memory allocation docs).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
import glob
import tqdm
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
import matplotlib as mpl
from matplotlib import rc
# Global figure styling: serif font family, LaTeX text rendering, base font size 10.
rc('font', family='serif', serif=['Helvetica'])
mpl.rcParams.update({
    'text.usetex': True,
    'font.size': 10,
    # LaTeX packages made available to every text element rendered via usetex.
    'text.latex.preamble': r"\usepackage{bm}\usepackage{amsmath}\usepackage{upgreek}",
})

In [None]:
import jax
# Select the CPU as JAX platform for everything evaluated in this notebook.
# NOTE(review): CUDA_VISIBLE_DEVICES is also set at the top of the notebook — confirm that CPU-only evaluation is intended.
jax.config.update("jax_platform_name", "cpu")
import equinox as eqx

from mc2.utils.model_evaluation import reconstruct_model_from_exp_id
from mc2.model_interfaces.model_interface import count_model_parameters
from mc2.utils.pretest_evaluation import produce_pretest_histograms, SCENARIO_LABELS, DETAILED_SCENARIO_LABELS, store_pretest_results_to_csv, load_hdf5_pretest_data


from mc2.utils.final_data_evaluation import FINAL_MATERIALS

# visualize cross validation trajectories
from mc2.utils.model_evaluation import plot_model_frequency_sweep, get_mixed_frequency_arrays
from mc2.runners.model_setup_jax import get_normalizer

In [None]:
# Experiment id of the trained model to evaluate, keyed by material name.
exp_ids = dict(
    A='A_GRU8_reduced-features-f32_2a1473b6_seed12',
    B='B_GRU8_reduced-features-f32_c785b2c3_seed12',
    C='C_GRU8_reduced-features-f32_348e220c_seed12',
    D='D_GRU8_reduced-features-f32_b6ac55b5_seed12',
    E='E_GRU8_reduced-features-f32_e88a2583_seed12',
)

# Rebuild one model per material; only element 0 of the helper's return value is kept.
models = {
    name: reconstruct_model_from_exp_id(run_id)[0]
    for name, run_id in exp_ids.items()
}

In [None]:
from mc2.utils.model_evaluation import (
    load_gt_and_pred, plot_worst_predictions, plot_first_predictions, plot_loss_trends
)

In [None]:
# Loss trends (and ground truth / prediction data) for one exemplary material.
material = "C"
exp_id = exp_ids[material]
# The seed is the suffix after "seed" in the experiment id; it stays a string here.
seed = exp_id.split("seed")[-1]
gt, pred = load_gt_and_pred(exp_id=exp_id, seed=seed, freq_idx=0)

# 252*0.0138889 is roughly 3.5 — presumably a 252 pt column width in inches (1/72 in per pt); confirm.
fig, axs = plot_loss_trends(exp_id, seed, plot_together=True, figsize=(252*0.0138889, 3))

# savefig does not create missing directories, so make sure the target folder exists first.
pathlib.Path("fig").mkdir(parents=True, exist_ok=True)
plt.savefig("fig/exemplary_losses.pdf", bbox_inches="tight")

In [None]:
# Frequency-sweep plot on the test split, one material after the other.
for material_name in FINAL_MATERIALS:
    exp_id = exp_ids[material_name]
    wrapped_model = models[material_name]
    loader_key = jax.random.key(12)

    # The first return value is discarded (presumably the normalizer itself); only the splits are kept.
    _, (train_set, eval_set, test_set) = get_normalizer(
        material_name,
        wrapped_model.featurize,
        subsampling_freq=1,
        do_normalization=True,
        transform_H=False,
    )

    plot_model_frequency_sweep(wrapped_model, test_set, loader_key, past_size=1)

In [None]:
# Material E: load the data splits, draw mixed-frequency test sequences and predict H from a warm-up window.
material_name = "E"
exp_id = exp_ids[material_name]
wrapped_model = models[material_name]

_, (train_set, eval_set, test_set) = get_normalizer(
    material_name,
    wrapped_model.featurize,
    subsampling_freq=1,
    do_normalization=True,
    transform_H=False,
)

past_size = 100  # number of leading samples passed to the model as known history
figsize = (7 * 252*0.0138889, 3 * 4)  # size of the overview grid figure
loader_key = jax.random.key(13)

H, B, T = get_mixed_frequency_arrays(test_set, sequence_length=1000, batch_size=1, key=loader_key)

# Split along axis 1 into the warm-up part and the part the model has to predict.
H_past, H_future = H[:, :past_size], H[:, past_size:]
B_past, B_future = B[:, :past_size], B[:, past_size:]

H_pred = wrapped_model(B_past, H_past, B_future, T)

# Overview grid with one column per test frequency:
#   row 0: B over sample index, row 1: H (reference, prediction, error), row 2: H over B.
# The column count follows the number of frequencies instead of a fixed 7, so the
# loop below can never index past the grid; squeeze=False keeps `axs` two-dimensional.
n_freqs = len(test_set.frequencies)
fig, axs = plt.subplots(3, n_freqs, figsize=figsize, squeeze=False)
for freq_idx in range(n_freqs):
    ax_b, ax_h, ax_loop = axs[:, freq_idx]

    ax_b.plot(B_future[freq_idx])

    ax_h.plot(H_future[freq_idx])
    ax_h.plot(H_pred[freq_idx])
    # prediction error as dashed red line
    ax_h.plot(H_future[freq_idx] - H_pred[freq_idx], color="tab:red", linestyle="--")

    ax_loop.plot(B_future[freq_idx], H_future[freq_idx])
    ax_loop.plot(B_future[freq_idx], H_pred[freq_idx])

    # identical cosmetics for all three rows; only the axis labels differ
    for ax, (xlabel, ylabel) in zip((ax_b, ax_h, ax_loop), (("k", "B"), ("k", "H"), ("B", "H"))):
        ax.grid(True, alpha=0.3)
        ax.set_ylabel(ylabel)
        ax.set_xlabel(xlabel)

fig.tight_layout(pad=-0.2)
plt.show()

# ======================================================================================================
# Detail figure for a single sequence: H over sample index (top) and H over B (bottom).
freq_idx = 3

# 252*0.0138889 is roughly 3.5 — presumably a 252 pt column width in inches (1/72 in per pt); confirm.
fig, axs = plt.subplots(2, 1, figsize=(252*0.0138889, 6))

# Sample index over the whole sequence (warm-up + predicted part). NumPy is
# sufficient for a plain index vector and is imported at the top of the notebook.
length = H.shape[-1]
k = np.arange(length)

ax = axs[0]
ax.plot(k, H[freq_idx], color="tab:blue", label="gt", linewidth=1)
ax.plot(k[:past_size], H_past[freq_idx], color="tab:orange", linestyle="-", label="warmup", linewidth=1)
ax.plot(k[past_size:], H_pred[freq_idx], color="tab:orange", linestyle="--", label="pred", linewidth=1)
ax.set_ylabel("$H$")
ax.set_xlabel("$k$")
ax.legend()
ax.grid(True, alpha=0.3)

ax = axs[1]
ax.plot(B[freq_idx], H[freq_idx], color="tab:blue", label="gt", linewidth=1)
ax.plot(B_past[freq_idx], H_past[freq_idx], color="tab:orange", linestyle="-", label="warmup", linewidth=1)
ax.plot(B_future[freq_idx], H_pred[freq_idx], color="tab:orange", linestyle="--", label="pred", linewidth=1)
ax.set_ylabel("$H$")
ax.set_xlabel("$B$")
ax.legend()
ax.grid(True, alpha=0.3)

# savefig does not create missing directories, so make sure the target folder exists first.
pathlib.Path("fig").mkdir(parents=True, exist_ok=True)
fig.savefig("fig/exemplary_trajectory_matE.pdf", bbox_inches="tight")

In [None]:
# Material B: pick the trained model and load the corresponding data splits.
material_name = "B"
exp_id = exp_ids[material_name]
wrapped_model = models[material_name]

# The first return value is discarded (presumably the normalizer itself); only the splits are kept.
_, (train_set, eval_set, test_set) = get_normalizer(
    material_name,
    wrapped_model.featurize,
    subsampling_freq=1,
    do_normalization=True,
    transform_H=False,
)

In [None]:
past_size = 500  # number of leading samples passed to the model as known history
figsize = (7 * 252*0.0138889, 3 * 4)  # size of the overview grid figure
loader_key = jax.random.key(13)

# Entry 1 of the test set, truncated to the first 2000 samples along axis 1.
subset = test_set[1]
H = subset.H[:, :2000]
B = subset.B[:, :2000]
T = subset.T[:]

# Split along axis 1 into the warm-up part and the part the model has to predict.
H_past, H_future = H[:, :past_size], H[:, past_size:]
B_past, B_future = B[:, :past_size], B[:, past_size:]

H_pred = wrapped_model(B_past, H_past, B_future, T)

# Overview grid with one column per loop index:
#   row 0: B over sample index, row 1: H (reference, prediction, error), row 2: H over B.
# The column count follows the loop range instead of a fixed 7, so the loop below
# can never index past the grid; squeeze=False keeps `axs` two-dimensional.
n_freqs = len(test_set.frequencies)
fig, axs = plt.subplots(3, n_freqs, figsize=figsize, squeeze=False)
for freq_idx in range(n_freqs):
    ax_b, ax_h, ax_loop = axs[:, freq_idx]

    ax_b.plot(B_future[freq_idx])

    ax_h.plot(H_future[freq_idx])
    ax_h.plot(H_pred[freq_idx])
    # prediction error as dashed red line
    ax_h.plot(H_future[freq_idx] - H_pred[freq_idx], color="tab:red", linestyle="--")

    ax_loop.plot(B_future[freq_idx], H_future[freq_idx])
    ax_loop.plot(B_future[freq_idx], H_pred[freq_idx])

    # identical cosmetics for all three rows; only the axis labels differ
    for ax, (xlabel, ylabel) in zip((ax_b, ax_h, ax_loop), (("k", "B"), ("k", "H"), ("B", "H"))):
        ax.grid(True, alpha=0.3)
        ax.set_ylabel(ylabel)
        ax.set_xlabel(xlabel)

fig.tight_layout(pad=-0.2)

In [None]:
# Detail figure for a single sequence: H over sample index (top) and H over B (bottom).
freq_idx = 0

# 252*0.0138889 is roughly 3.5 — presumably a 252 pt column width in inches (1/72 in per pt); confirm.
fig, axs = plt.subplots(2, 1, figsize=(252*0.0138889, 6))

# Sample index over the whole sequence (warm-up + predicted part). NumPy is
# sufficient for a plain index vector and is imported at the top of the notebook.
length = H.shape[-1]
k = np.arange(length)

ax = axs[0]
ax.plot(k, H[freq_idx], color="tab:blue", label="gt", linewidth=1)
ax.plot(k[:past_size], H_past[freq_idx], color="tab:orange", linestyle="-", label="warmup", linewidth=1)
ax.plot(k[past_size:], H_pred[freq_idx], color="tab:orange", linestyle="--", label="pred", linewidth=1)
ax.set_ylabel("$H$")
ax.set_xlabel("$k$")
ax.legend()
ax.grid(True, alpha=0.3)

ax = axs[1]
ax.plot(B[freq_idx], H[freq_idx], color="tab:blue", label="gt", linewidth=1)
ax.plot(B_past[freq_idx], H_past[freq_idx], color="tab:orange", linestyle="-", label="warmup", linewidth=1)
ax.plot(B_future[freq_idx], H_pred[freq_idx], color="tab:orange", linestyle="--", label="pred", linewidth=1)
ax.set_ylabel("$H$")
ax.set_xlabel("$B$")
ax.legend()
ax.grid(True, alpha=0.3)

# savefig does not create missing directories, so make sure the target folder exists first.
pathlib.Path("fig").mkdir(parents=True, exist_ok=True)
fig.savefig("fig/exemplary_trajectory_matB.pdf", bbox_inches="tight")