In [None]:
# Reload imported modules automatically before each cell is executed, so that
# edits to imported project code (e.g. `mc2`) are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Restrict this process to the first CUDA device.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Ask JAX/XLA not to preallocate GPU memory up front. Both variables are set
# here, before `jax` is imported further down in the notebook.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
import glob
import tqdm
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
import matplotlib as mpl
from matplotlib import rc
# Global figure styling for all plots in this notebook.
# NOTE(review): 'Helvetica' is registered under the *serif* family although it
# is a sans-serif typeface -- confirm that this is the intended font setup.
rc('font',**{'family':'serif','serif':['Helvetica']})
# Render all text through LaTeX (requires a working LaTeX installation).
mpl.rcParams['text.usetex'] = True
mpl.rcParams.update({'font.size': 10})
# LaTeX packages made available to the labels; `upgreek` provides the \upmu
# used in the time-axis labels below.
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}\usepackage{amsmath}\usepackage{upgreek}"

## Prepare data

In [None]:
from mc2.utils.data_plotting import plot_single_sequence, plot_hysteresis
from mc2.data_management import FrequencySet, MaterialSet, DataSet

In [None]:
# NOTE: the dataset is stored as a pickle file; only load files from a trusted source.
dataset = DataSet.load_from_file(pathlib.Path("../../data/processed") / "ten_mat_data.pickle")

# N49 is excluded from the dataset for now, since its data is incomplete:
#   - 50 kHz and 80 kHz are missing
#   - 320 kHz has no data at 25 degrees
EXCLUDED_MATERIALS = ("N49",)

print(dataset.material_names)
print(len(dataset.material_names))

# Build the reduced list by filtering rather than `list.remove`, so this step
# does not raise when an excluded material is already absent from the data.
available_materials = [
    name for name in dataset.material_names if name not in EXCLUDED_MATERIALS
]
print(available_materials)
print(len(available_materials))

dataset = dataset.filter_materials(available_materials)
assert dataset.material_names == available_materials

# Single operating point (material 3C90, 50 kHz, 25 degrees) used by the
# example figures below.
all_relevant_data = dataset.at_material("3C90").at_frequency(50_000).filter_temperatures([25])
all_relevant_data

In [None]:
# Display the selected data (the same object is already shown at the end of the previous cell).
all_relevant_data

In [None]:
# Spacing between consecutive samples on the time axis; the plots below label
# that axis in microseconds.
tau = 1 / 16

# Sample indices delimiting the example window: H is taken from t0 up to t1,
# B from t0 up to t2.
t0, t1, t2 = 600, 2900, 7000

# All example signals are taken from the first sequence of the selection.
first_sequence = 0
B = all_relevant_data.B[first_sequence, t0:t2]
H = all_relevant_data.H[first_sequence, t0:t1]
T = all_relevant_data.T[first_sequence]

# Time axis for B, starting at zero at sample t0.
n_samples = B.shape[0]
t = np.linspace(0, (n_samples - 1) * tau, n_samples)

## Plot:

In [None]:
fig, axs = plt.subplots(2, 1, figsize=(3.5, 2.5), sharex=True)

colors = plt.rcParams["axes.prop_cycle"]()
c1 = next(colors)["color"]
c2 = next(colors)["color"]

axs[0].plot(t, B, color=c1)
axs[1].plot(t[:-(t2-t1)], H, color=c2)

axs[1].plot(t[-(t2-t1):], all_relevant_data.H[0, t1:t2], color=c2, alpha=0.95, linestyle="dashed")

axs[0].set_ylabel("$B \mathrm{\; in \; T}$")
axs[1].set_ylabel("$H \mathrm{\; in \; A/m}$")

for ax in axs:
    ax.grid(True, which="both", alpha=0.3)
    ax.tick_params(which='both', axis="y", direction='in')
    ax.tick_params(which='both', axis="x", direction='in') 
    ax.set_xlim([-15, 410])
    
axs[-1].set_xlabel("$t \mathrm{\; in \; \\upmu s}$")

for ax, perc in zip(axs, [0.06, 0.08]):
    ax.vlines(
        x=[0.0, (t1-t0) * tau, (t2-t0) * tau] ,
        ymin = ax.get_ylim()[0] - (perc * ax.get_ylim()[0]),
        ymax=ax.get_ylim()[1],
        color="k",
        linestyles="dashed",
        linewidth=1,
    )

axs[0].text(5, -0.1, '$t_0$', color='k', verticalalignment="top")
axs[0].text((t1-t0) * tau + 5, -0.1, '$t_1$', color='k', verticalalignment="top")
axs[0].text((t2-t0) * tau -20, -0.1, '$t_2$', color='k', verticalalignment="top")

axs[1].text(5, -1, '$t_0$', color='k', verticalalignment="top")
axs[1].text((t1-t0) * tau + 5, -1, '$t_1$', color='k', verticalalignment="top")
axs[1].text((t2-t0) * tau - 20, -1, '$t_2$', color='k', verticalalignment="top")

axs[1].text(((t2-t0) + (t1-t0)) * tau / 2 - 5, -70, '$\\textbf{?}$', fontsize=20, color=c2, verticalalignment="center")

fig.tight_layout(h_pad=0.2, w_pad=0)
fig.align_ylabels(axs)

fig.savefig("B_H_prediction_example.pdf", bbox_inches='tight')
fig.savefig("B_H_prediction_example.png", bbox_inches='tight')
fig.savefig("B_H_prediction_example.svg", bbox_inches='tight')
plt.show()

In [None]:
fig, axs = plt.subplots(2, 1, figsize=(3.5, 2.5), sharex=True)

colors = plt.rcParams["axes.prop_cycle"]()
c1 = next(colors)["color"]
c2 = next(colors)["color"]

axs[0].plot(t, B, color=c1)
axs[1].plot(t[:-(t2-t1)], H, color=c2)

axs[1].plot(t[-(t2-t1):], all_relevant_data.H[0, t1:t2], color=c2, alpha=0, linestyle="dashed")

axs[0].set_ylabel("$B \mathrm{\; in \; T}$")
axs[1].set_ylabel("$H \mathrm{\; in \; A/m}$")

for ax in axs:
    ax.grid(True, which="both", alpha=0.3)
    ax.tick_params(which='both', axis="y", direction='in')
    ax.tick_params(which='both', axis="x", direction='in') 
    ax.set_xlim([-15, 410])
    
axs[-1].set_xlabel("$t \mathrm{\; in \; \\upmu s}$")

for ax, perc in zip(axs, [0.06, 0.08]):
    ax.vlines(
        x=[0.0, (t1-t0) * tau, (t2-t0) * tau] ,
        ymin = ax.get_ylim()[0] - (perc * ax.get_ylim()[0]),
        ymax=ax.get_ylim()[1],
        color="k",
        linestyles="dashed",
        linewidth=1,
    )

axs[0].text(5, -0.1, '$t_0$', color='k', verticalalignment="top")
axs[0].text((t1-t0) * tau + 5, -0.1, '$t_1$', color='k', verticalalignment="top")
axs[0].text((t2-t0) * tau -20, -0.1, '$t_2$', color='k', verticalalignment="top")

axs[1].text(5, -1, '$t_0$', color='k', verticalalignment="top")
axs[1].text((t1-t0) * tau + 5, -1, '$t_1$', color='k', verticalalignment="top")
axs[1].text((t2-t0) * tau - 20, -1, '$t_2$', color='k', verticalalignment="top")

axs[1].text(((t2-t0) + (t1-t0)) * tau / 2 - 5, -70, '$\\textbf{?}$', fontsize=20, color=c2, verticalalignment="center")

fig.tight_layout(h_pad=0.2, w_pad=0)
fig.align_ylabels(axs)

fig.savefig("B_H_prediction_example_without_pred.pdf", bbox_inches='tight')
fig.savefig("B_H_prediction_example_without_pred.png", bbox_inches='tight')
fig.savefig("B_H_prediction_example_without_pred.svg", bbox_inches='tight')
plt.show()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(3.5, 3.5))

colors = plt.rcParams["axes.prop_cycle"]()
c1 = next(colors)["color"]
c2 = next(colors)["color"]

ax.plot(all_relevant_data.H[0, t0:t1], all_relevant_data.B[0, t0:t1], color=c1)
ax.plot(all_relevant_data.H[0, t1:t2], all_relevant_data.B[0, t1:t2], color=c2, alpha=0, linestyle="dashed")

ax.set_ylabel("$B \mathrm{\; in \; T}$")
ax.set_xlabel("$H \mathrm{\; in \; A/m}$")

ax.grid(True, which="both", alpha=0.3)
ax.tick_params(which='both', axis="y", direction='in')
ax.tick_params(which='both', axis="x", direction='in') 

fig.savefig("B_H_hysteresis_example_without_pred.pdf", bbox_inches='tight')
fig.savefig("B_H_hysteresis_example_without_pred.png", bbox_inches='tight')
fig.savefig("B_H_hysteresis_example_without_pred.svg", bbox_inches='tight')
plt.show()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(3.5, 3.5))

colors = plt.rcParams["axes.prop_cycle"]()
c1 = next(colors)["color"]
c2 = next(colors)["color"]

ax.plot(all_relevant_data.H[0, t0:t1], all_relevant_data.B[0, t0:t1], color=c1)
ax.plot(all_relevant_data.H[0, t1:t2], all_relevant_data.B[0, t1:t2], color=c2, alpha=0.95, linestyle="dashed")

ax.set_ylabel("$B \mathrm{\; in \; T}$")
ax.set_xlabel("$H \mathrm{\; in \; A/m}$")

ax.grid(True, which="both", alpha=0.3)
ax.tick_params(which='both', axis="y", direction='in')
ax.tick_params(which='both', axis="x", direction='in') 

fig.savefig("B_H_hysteresis_example.pdf", bbox_inches='tight')
fig.savefig("B_H_hysteresis_example.png", bbox_inches='tight')
fig.savefig("B_H_hysteresis_example.svg", bbox_inches='tight')
plt.show()

In [None]:
# One B-over-H figure per sequence for (at most) the first ten sequences of
# the selection; each figure is saved as an SVG file.
colors = plt.rcParams["axes.prop_cycle"]()
c1 = next(colors)["color"]  # every curve uses the first colour of the cycle

# Never index past the number of available sequences.
n_examples = min(10, all_relevant_data.H.shape[0])

for i in range(n_examples):
    fig, ax = plt.subplots(1, 1, figsize=(3.5, 3.5))

    ax.plot(all_relevant_data.H[i], all_relevant_data.B[i], color=c1)

    # Raw strings keep the LaTeX backslashes intact without relying on
    # invalid escape sequences such as "\m" or "\;".
    ax.set_ylabel(r"$B \mathrm{\; in \; T}$")
    ax.set_xlabel(r"$H \mathrm{\; in \; A/m}$")

    ax.grid(True, which="both", alpha=0.3)
    ax.tick_params(which='both', axis="y", direction='in')
    ax.tick_params(which='both', axis="x", direction='in')

    fig.savefig(f"B_H_hysteresis_example_{i}.svg", bbox_inches='tight')
    plt.show()

In [None]:
np.unique(all_temp_data.T)

In [None]:
all_temp_data = dataset.at_material("N30").at_frequency(80_000)
fig, ax = plt.subplots(1, 1, figsize=(3.5, 3.5))
for i in range(all_temp_data.B.shape[0]):
   
    colors = plt.rcParams["axes.prop_cycle"]()
    c1 = next(colors)["color"]
    c2 = next(colors)["color"]
    
    ax.plot(all_temp_data.H[i], all_temp_data.B[i], color=c1)
    
ax.set_ylabel("$B \mathrm{\; in \; T}$")
ax.set_xlabel("$H \mathrm{\; in \; A/m}$")

ax.grid(True, which="both", alpha=0.3)
ax.tick_params(which='both', axis="y", direction='in')
ax.tick_params(which='both', axis="x", direction='in') 
    
fig.savefig(f"B_H_hysteresis_example_all_data.svg", bbox_inches='tight')
plt.show()

## Model plots

In [None]:
# Keep every 5th sample of H and B along the second axis; material name,
# frequency and T are passed through unchanged.
subsample = slice(None, None, 5)

test_data = FrequencySet(
    all_relevant_data.material_name,
    all_relevant_data.frequency,
    all_relevant_data.H[:, subsample],
    all_relevant_data.B[:, subsample],
    all_relevant_data.T[:],
)

# Third argument of the model calls below.
# NOTE(review): presumably the integration step size -- confirm against the model.
tau_ = 1

In [None]:
import jax
import jax.numpy as jnp
import equinox as eqx
from mc2.models.NODE import HiddenStateNeuralEulerODE
from mc2.features.features_jax import add_fe as add_features

In [None]:
# Template with the same structure as the stored model; its freshly
# initialised leaves are replaced by the deserialised ones below.
model_template = HiddenStateNeuralEulerODE(
    obs_dim=1,
    state_dim=10,
    action_dim=5,
    width_size=32,
    depth=2,
    obs_func=lambda x: x[0],
    key=jax.random.key(0),
)

# Requires a template of the proper form, i.e. the form that was used when the
# model was stored (to be extended/fixed...).
model_file = pathlib.Path("../../data/models") / "NODE.eqx"
model = eqx.tree_deserialise_leaves(path_or_file=model_file, like=model_template)

In [None]:
# Display the loaded model.
model

In [None]:
def evaluate_on_test_data(testing_data, model, max_plots=20):
    """Run ``model`` on all sequences of ``testing_data`` and plot its predictions.

    For every sequence the model is called (vectorised over the first axis
    with ``jax.vmap``) with the first ``H`` sample, the features computed by
    ``add_features`` from the ``B`` samples after the first one, and the
    notebook-level ``tau_``. Each plotted sequence gets one figure created by
    ``plot_single_sequence``, with the prediction added to its last axis.

    Args:
        testing_data: Object providing the arrays ``H``, ``B`` and ``T``.
            ``H`` and ``B`` are indexed as two-dimensional arrays with the
            sequences along the first axis.
        model: Callable taking ``(initial_H, features, tau_)`` for a single
            sequence and returning a pair whose second entry is the predicted
            ``H`` trajectory.
        max_plots: Upper bound on the number of sequences that are plotted.
    """
    # Add a trailing axis of size one to H and B.
    batched_H = testing_data.H[:, :, None]
    batched_B = testing_data.B[:, :, None]

    _, pred_H = jax.vmap(model, in_axes=(0, 0, None))(
        batched_H[:, 0, :],
        add_features(batched_B[:, 1:, 0], n_s=10),
        tau_,
    )

    # Identical for every figure, so it is computed once outside the loop.
    temperatures = jnp.unique(testing_data.T)

    for i in range(min(batched_H.shape[0], max_plots)):
        fig, axs = plot_single_sequence(batched_B[i], batched_H[i], temperatures)
        axs[-1].plot(pred_H[i], label="pred")
        fig.legend()
        plt.show()

In [None]:
# Plot measured data and model prediction for the first sequences of the downsampled test data.
evaluate_on_test_data(test_data, model)

In [None]:
# Window of the downsampled test sequences used for the figures below:
# `sequence_length` samples starting at sample index `start`.
start = 500
sequence_length = 1000

In [None]:
# Cut the evaluation window out of every test sequence and add a trailing
# axis of size one.
window = slice(start, start + sequence_length)
batched_H = test_data.H[:, window][..., None]
batched_B = test_data.B[:, window][..., None]

# Model inputs: the first H sample of each window and the features computed
# from the B samples after the first one.
initial_H = batched_H[:, 0, :]
B_features = add_features(batched_B[:, 1:, 0], n_s=10)

_, pred_H = jax.vmap(model, in_axes=(0, 0, None))(initial_H, B_features, tau_)
pred_H.shape

In [None]:
# Inspect the predicted H trajectory of the first sequence.
pred_H[0]

In [None]:
t = np.linspace(0, (batched_B.shape[1] -1) * tau * 5, batched_B.shape[1])

In [None]:
for i in range(20):
    fig, axs = plt.subplots(2, 1, figsize=(3.5, 2.5), sharex=True)
    
    colors = plt.rcParams["axes.prop_cycle"]()
    c1 = next(colors)["color"]
    c2 = next(colors)["color"]
    c3 = next(colors)["color"]
    c4 = next(colors)["color"]
    
    axs[0].plot(t, batched_B[i], color=c1)
    axs[1].plot(t, batched_H[i], color=c2, alpha=0.7, label=r"$H_{\mathrm{meas}}$")
    axs[1].plot(t, pred_H[i], color=c4, linestyle="dashed", label=r"$\hat{H}$")
    
    axs[0].set_ylabel("$B \mathrm{\; in \; T}$")
    axs[1].set_ylabel("$H \mathrm{\; in \; A/m}$")
    
    for ax in axs:
        ax.grid(True, which="both", alpha=0.3)
        ax.tick_params(which='both', axis="y", direction='in')
        ax.tick_params(which='both', axis="x", direction='in') 
        # ax.set_xlim([-15, 410])
        
    axs[-1].set_xlabel("$t \mathrm{\; in \; \\upmu s}$")
    
    # for ax, perc in zip(axs, [0.06, 0.08]):
    #     ax.vlines(
    #         x=[0.0, (t1-t0) * tau, (t2-t0) * tau] ,
    #         ymin = ax.get_ylim()[0] - (perc * ax.get_ylim()[0]),
    #         ymax=ax.get_ylim()[1],
    #         color="k",
    #         linestyles="dashed",
    #         linewidth=1,
    #     )
    
    # axs[0].text(5, -0.1, '$t_0$', color='k', verticalalignment="top")
    # axs[0].text((t1-t0) * tau + 5, -0.1, '$t_1$', color='k', verticalalignment="top")
    # axs[0].text((t2-t0) * tau -20, -0.1, '$t_2$', color='k', verticalalignment="top")
    
    # axs[1].text(5, -1, '$t_0$', color='k', verticalalignment="top")
    # axs[1].text((t1-t0) * tau + 5, -1, '$t_1$', color='k', verticalalignment="top")
    # axs[1].text((t2-t0) * tau - 20, -1, '$t_2$', color='k', verticalalignment="top")
    
    # axs[1].text(((t2-t0) + (t1-t0)) * tau / 2 - 5, -70, '$\\textbf{?}$', fontsize=20, color=c2, verticalalignment="center")
    
    fig.tight_layout(h_pad=0.2, w_pad=0)
    fig.align_ylabels(axs)

    axs[-1].legend()
    # fig.savefig("model_performance_time_series.pdf", bbox_inches='tight')
    # fig.savefig("model_performance_time_series.png", bbox_inches='tight')
    fig.savefig(f"model_performance_time_series_{i}.svg", bbox_inches='tight')
    plt.show()

In [None]:
for i in range(10):
    fig, axs = plt.subplots(2, 1, figsize=(3.5, 2.5), sharex=True)
    
    colors = plt.rcParams["axes.prop_cycle"]()
    c1 = next(colors)["color"]
    c2 = next(colors)["color"]
    c3 = next(colors)["color"]
    c4 = next(colors)["color"]
    
    axs[0].plot(t, batched_B[i], color=c1)
    axs[1].plot(t, batched_H[i], color=c2, alpha=0.7, label=r"$H_{\mathrm{meas}}$")
    axs[1].plot(t, pred_H[i], color=c4, linestyle="dashed", label=r"$\hat{H}$")
    
    axs[0].set_ylabel("$B \mathrm{\; in \; T}$")
    axs[1].set_ylabel("$H \mathrm{\; in \; A/m}$")
    
    for ax in axs:
        ax.grid(True, which="both", alpha=0.3)
        ax.tick_params(which='both', axis="y", direction='in')
        ax.tick_params(which='both', axis="x", direction='in') 
        # ax.set_xlim([-15, 410])
        
    axs[-1].set_xlabel("$t \mathrm{\; in \; \\upmu s}$")
    
    fig.tight_layout(h_pad=0.2, w_pad=0)
    fig.align_ylabels(axs)

    axs[-1].legend()
    # fig.savefig("model_performance_time_series.pdf", bbox_inches='tight')
    # fig.savefig("model_performance_time_series.png", bbox_inches='tight')
    fig.savefig(f"model_performance_time_series_{i}.svg", bbox_inches='tight')
    plt.show()

In [None]:
# Measured versus estimated B-over-H curve for (at most) the first five
# sequences of the evaluated window; each figure is saved as an SVG file.

# The first and fourth colour of the cycle are used.
color_cycle = plt.rcParams["axes.prop_cycle"]()
c1, _c2, _c3, c4 = [next(color_cycle)["color"] for _ in range(4)]

# Never index past the number of available sequences.
n_examples = min(5, batched_H.shape[0])

for i in range(n_examples):
    fig, ax = plt.subplots(1, 1, figsize=(3.5, 3.5))

    ax.plot(batched_H[i], batched_B[i], color=c1, alpha=0.7, label=r"${\mathrm{meas}}$")
    ax.plot(pred_H[i], batched_B[i], color=c4, linestyle="dashed", alpha=1, label=r"${\mathrm{est}}$")

    # Raw strings keep the LaTeX backslashes intact without relying on
    # invalid escape sequences such as "\m" or "\;".
    ax.set_ylabel(r"$B \mathrm{\; in \; T}$")
    ax.set_xlabel(r"$H \mathrm{\; in \; A/m}$")

    ax.grid(True, which="both", alpha=0.3)
    ax.tick_params(which='both', axis="y", direction='in')
    ax.tick_params(which='both', axis="x", direction='in')
    ax.legend()
    fig.savefig(f"model_performance_hysteresis{i}.svg", bbox_inches='tight')
    plt.show()