In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

from mc2.data_management import EXPERIMENT_LOGS_ROOT

In [None]:
import jax.numpy as jnp
import jax

# Sanity check: list the CPU devices visible to JAX (shown as the cell output).
jax.devices(backend="cpu")

In [None]:
def load_preds_n_gt(folder_name, root=None, seq_idx=0):
    """Load the ground-truth and predicted trajectories of one experiment.

    Looks in ``root / folder_name`` for ``seed_*_seq_<seq_idx>_preds.parquet``
    and the ground-truth file of the *same* seed
    (``seed_<seed>_seq_<seq_idx>_gt.parquet``). If several seeds are present,
    the lexicographically first predictions file is used, so repeated calls
    pick the same seed.

    Parameters
    ----------
    folder_name : str
        Experiment folder name, e.g. ``"3C90_c0f35141"``.
    root : path-like, optional
        Directory containing the experiment folders. Defaults to
        ``EXPERIMENT_LOGS_ROOT``.
    seq_idx : int, optional
        Sequence index encoded in the file names (default 0).

    Returns
    -------
    tuple of numpy.ndarray
        ``(gt_MS, preds_MS)``. Ground truth comes first, predictions second.

    Raises
    ------
    FileNotFoundError
        If no predictions file matches, or the matching seed has no
        ground-truth file.
    """
    exp_dir = Path(EXPERIMENT_LOGS_ROOT if root is None else root) / folder_name
    pred_pattern = f"seed_*_seq_{seq_idx}_preds.parquet"
    # glob order is filesystem-dependent; sort so the chosen seed is deterministic.
    pred_files = sorted(exp_dir.glob(pred_pattern))
    if not pred_files:
        raise FileNotFoundError(f"No file matching '{pred_pattern}' in {exp_dir}")
    fp_pred = pred_files[0]
    # Derive the gt path from the chosen preds file so both belong to the same seed.
    fp_gt = fp_pred.with_name(fp_pred.name[: -len("_preds.parquet")] + "_gt.parquet")
    if not fp_gt.is_file():
        raise FileNotFoundError(f"Ground-truth file '{fp_gt.name}' missing in {exp_dir}")

    preds_MS = pd.read_parquet(fp_pred).to_numpy()
    gt_MS = pd.read_parquet(fp_gt).to_numpy()
    return gt_MS, preds_MS


In [None]:

def plot_worst_predictions(folder_name, n_worst=5):
    """Plot the trajectories with the largest worst-case error (WCE) of one experiment.

    Loads ground truth and predictions with ``load_preds_n_gt``. It prints the
    experiment-level MAE and MSE (averaged over trajectories) and the WCE
    (maximum over all trajectories). It then draws one subplot per selected
    trajectory, ordered from the largest to the smallest WCE.

    Parameters
    ----------
    folder_name : str
        Experiment folder name of the form ``"<material>_<exp_id>"``. The part
        before the first underscore is used as the material in the title.
    n_worst : int, optional
        Number of worst trajectories to plot (default 5). Clipped to the
        number of available trajectories.

    Raises
    ------
    ValueError
        If ground truth and predictions differ in shape, there are no
        trajectories, or ``n_worst < 1``.
    """
    gt_MS, preds_MS = load_preds_n_gt(folder_name)
    if gt_MS.shape != preds_MS.shape:
        # Otherwise numpy broadcasting would silently yield meaningless errors.
        raise ValueError(
            f"Shape mismatch for {folder_name}: gt {gt_MS.shape} vs. preds {preds_MS.shape}")
    if gt_MS.shape[0] == 0:
        raise ValueError(f"No trajectories found for {folder_name}")
    if n_worst < 1:
        raise ValueError(f"n_worst must be >= 1, got {n_worst}")
    material = folder_name.split("_")[0]

    # Per-trajectory metrics, reduced over the last (sequence) axis.
    err_MS = gt_MS - preds_MS
    mae_M = np.mean(np.abs(err_MS), axis=-1)
    mse_M = np.mean(err_MS**2, axis=-1)
    wce_M = np.max(np.abs(err_MS), axis=-1)
    print(f"MAE {mae_M.mean():.1f} A/m | MSE {mse_M.mean():.1f} (A/m)² | WCE {wce_M.max():.1f} A/m")

    # Full descending sort: np.argpartition only isolates the top-k without
    # ordering them, so the first subplot would not reliably be the worst case.
    n_show = min(n_worst, wce_M.shape[0])
    worst_idx = np.argsort(wce_M)[::-1][:n_show]

    fig, axes = plt.subplots(n_show, 1, sharex=True, sharey="col",
                             figsize=(10, 3 * n_show), squeeze=False)
    axes = axes[:, 0]  # squeeze=False keeps an array even when n_show == 1
    for ax, tst_idx in zip(axes, worst_idx):
        ax.plot(gt_MS[tst_idx], label='gt')
        ax.plot(preds_MS[tst_idx], label='pred', ls='dashed')
        ax.annotate(f"MAE {mae_M[tst_idx]:.1f} A/m | "
                    f"MSE {mse_M[tst_idx]:.1f} (A/m)² | "
                    f"WCE {wce_M[tst_idx]:.1f} A/m",
                    (0.3, 0.1), xycoords=ax.transAxes)
        ax.grid(alpha=0.3)
        ax.set_ylabel("H in A/m")

    axes[0].set_title(f"Worst-case predictions for {material}")
    axes[0].legend()
    axes[-1].set_xlabel("Sequence step")
    fig.tight_layout()


In [None]:
def show_trends(exp_id, show_train=False, root=None):
    """Plot the loss curve(s) of one experiment on a logarithmic y-axis.

    Searches ``root`` for folders whose name ends with ``exp_id`` and loads
    ``seed_*_loss_trends.parquet`` from them. If several files match, the
    lexicographically first path is used, so repeated calls plot the same file.

    Parameters
    ----------
    exp_id : str
        Experiment hash (e.g. ``"c0f35141"``) or full folder name
        (e.g. ``"3C90_db54d1df"``). It is matched as a folder-name suffix.
    show_train : bool, optional
        Also plot the ``"train"`` column. Default False (validation only).
    root : path-like, optional
        Directory containing the experiment folders. Defaults to
        ``EXPERIMENT_LOGS_ROOT``.

    Raises
    ------
    FileNotFoundError
        If no matching loss-trends file exists.
    """
    logs_root = Path(EXPERIMENT_LOGS_ROOT if root is None else root)
    pattern = f"*{exp_id}/seed_*_loss_trends.parquet"
    # glob order is filesystem-dependent; sort so the chosen file is deterministic.
    trend_files = sorted(logs_root.glob(pattern))
    if not trend_files:
        raise FileNotFoundError(f"No file matching '{pattern}' in {logs_root}")
    trend_fp = trend_files[0]
    trends_df = pd.read_parquet(trend_fp)
    # Folder names look like "<material>_<exp_id>". Split at the first underscore
    # only, so a differently structured name does not raise, and do not shadow
    # the exp_id argument.
    material, _, folder_exp_id = trend_fp.parent.name.partition("_")

    fig, ax = plt.subplots(figsize=(8, 3))
    if show_train:
        ax.plot(trends_df["train"], label="train")
    ax.plot(trends_df["val"], label="val")
    ax.set_title(f"Material {material} {folder_exp_id}")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss in (A/m)²")
    ax.set_yscale("log")
    ax.legend()
    ax.grid(alpha=0.3)
    fig.tight_layout()
    plt.show()

In [None]:
# Validation-loss trends: 3 hidden units, hidden states initialised with 0.
# Each experiment is listed once; "497a5aa4" previously appeared twice and was plotted twice.
# NOTE(review): confirm the duplicate was not meant to be a different experiment id.
for e in ["1b95979d", "3f47f52b", "3385f2b4", "497a5aa4", "ac76da17",
          "65e4d4fa", "e04e1d4f", "b7274ece", "c0f35141"]:
    show_trends(e)

In [None]:
# Worst-case predictions for two of the zero-initialised runs listed above.
relevant_exps_l = [
    "3C90_c0f35141",  # material 3C90
    "3E6_e04e1d4f",   # material 3E6
]
for e in relevant_exps_l:
    plot_worst_predictions(e)

In [None]:
# Validation-loss trends: 3 hidden units, hidden states initialised with the
# first ground-truth H value. Sorting the folder names groups the plots by material.
for e in sorted((
    "3C90_db54d1df", "3C94_e3ea619b", "3E6_e8cdb81c",
    "3F4_c0aa52f3", "77_bf033c05", "78_44b5e1be",
    "N27_70f4bad3", "N30_55ca3feb", "N87_f1c24305",
)):
    show_trends(e)

In [None]:
# Worst-case predictions of the runs initialised with the first ground-truth H value.
relevant_exps_l = ["3C90_db54d1df",
                   "3E6_e8cdb81c"]
for e in relevant_exps_l:
    plot_worst_predictions(e)

In [None]:
# 3C90, 8 hidden units: standard model trained for 1k vs. 2k epochs.
# The variant with the max/min BH-curve silhouette at 25°C ("3C90_a4cd8d8f")
# is currently left out of the comparison.
relevant_exps_l = [
    "3C90_17a1f985",  # standard model, 1k epochs
    "3C90_74bd22c0",  # standard model, 2k epochs
]
# Two separate passes: all loss curves first, then all worst-case figures.
for e in relevant_exps_l:
    show_trends(e)
for e in relevant_exps_l:
    plot_worst_predictions(e)


# Further investigations