In [None]:
# Automatically reload edited modules (e.g. the local `mc2` package) before executing each cell.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Make only CUDA device 0 visible and disable XLA's up-front GPU memory preallocation.
# These environment variables must be set before JAX initialises its GPU backend;
# JAX is imported further down in this notebook, so setting them here is early enough.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
import glob
import tqdm
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from mc2.utils.data_inspection import (
    get_available_material_names, get_file_overview, load_and_process_single_from_full_file_overview,
)
from mc2.utils.data_plotting import plot_single_sequence, plot_hysteresis
from mc2.data_management import FrequencySet, MaterialSet, DataSet

# [WIP]
**A bare-bones baseline RNN implementation.**

## TODOS:
Needs to be checked for errors, refined, and optimized.

## Load dataset

In [None]:
processed_data_dir = pathlib.Path("../../data/processed")
dataset = DataSet.load_from_file(processed_data_dir / "ten_mat_data.pickle")

# N49 is excluded from the dataset for now because its data is incomplete:
#   - 50 kHz and 80 kHz are missing
#   - 320 kHz has no data at 25 degrees
excluded_material = "N49"


def _report_materials(materials):
    """Print the given material names, followed by how many there are."""
    print(materials)
    print(len(materials))


# Work on a copy so the loaded dataset's own material list is left untouched.
available_materials = deepcopy(dataset.material_names)
_report_materials(available_materials)

available_materials.remove(excluded_material)
_report_materials(available_materials)

dataset = dataset.filter_materials(available_materials)
# Sanity check: the filtered dataset contains exactly the remaining materials, in order.
assert dataset.material_names == available_materials

## Preliminaries

In [None]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

# jax.config.update("jax_enable_x64", True)  # ?

In [None]:
# Training source: one material at one frequency and one temperature.
# Previously used alternative:
# all_relevant_data = dataset.at_material("N27").at_frequency(50_000).filter_temperatures([25])
train_material = "3C90"
train_frequency = 50_000
train_temperatures = [25]

all_relevant_data = (
    dataset
    .at_material(train_material)
    .at_frequency(train_frequency)
    .filter_temperatures(train_temperatures)
)
all_relevant_data

In [None]:
# Positional train/test split of the selected sequences: the first `n_train_sequences`
# sequences are used for training, all remaining ones for testing.
# NOTE(review): the split is not shuffled — this assumes the stored sequence order
# carries no systematic bias; confirm against the data pipeline.
n_train_sequences = 200
n_available_sequences = all_relevant_data.H.shape[0]
assert 0 < n_train_sequences < n_available_sequences, (
    f"need more than {n_train_sequences} sequences for a train/test split, "
    f"got {n_available_sequences}"
)


def _slice_frequency_set(freq_set, index):
    """Return a new FrequencySet containing only the sequences selected by `index`.

    Args:
        freq_set: FrequencySet whose H/B are indexed as [sequence, time] and T as [sequence].
        index: Any basic index (e.g. a slice) applied to the sequence axis.
    """
    return FrequencySet(
        freq_set.material_name,
        freq_set.frequency,
        freq_set.H[index, :],
        freq_set.B[index, :],
        freq_set.T[index],
    )


training_data = _slice_frequency_set(all_relevant_data, slice(None, n_train_sequences))
testing_data = _slice_frequency_set(all_relevant_data, slice(n_train_sequences, None))

In [None]:
# Display the training split.
training_data

In [None]:
# Display the testing split.
testing_data

---

In [None]:
from mc2.models.RNN import BaseRNN
from mc2.training.optimization import make_step
from mc2.training.data_sampling import draw_data_uniformly

Potential approach for sampling:
1. First, sample a set of starting points.
2. Then, sample a set of sequences.


Questions:
- Do we want different frequencies in a single dataset? Probably yes, at least at some point.
- Probably also yes for different temperatures?
- Also yes for different materials?
---

A batch would likely be its own Dataclass or just an array containing:

`n_batches` batches, each with a given `sequence_length` for H and B. Each batch would also carry a frequency, temperature, and material.

However, will frequency, temperature, and material information actually be available at test time? If not, should we withhold this information from the model and instead estimate it from the data?

In [None]:
def test_on_evalset(evaluation_data, model, n_plots=3):
    """Plot the model's H prediction against the measured H for a few evaluation sequences.

    Args:
        evaluation_data: FrequencySet-like object with ``H`` and ``B`` indexed as
            ``[sequence, time]`` and a per-sequence ``T``.
        model: Callable applied to a single B sequence with a trailing feature axis of
            size 1; it is ``jax.vmap``-ed over the sequence axis.
        n_plots: Maximum number of sequences to predict and plot (default 3).
    """
    n_shown = min(evaluation_data.H.shape[0], n_plots)
    if n_shown <= 0:
        return

    # Only the plotted sequences are needed, so predict just those instead of the
    # whole evaluation set (sequences are processed independently under vmap).
    batched_H = evaluation_data.H[:n_shown][..., None]
    batched_B = evaluation_data.B[:n_shown][..., None]
    pred_H = jax.vmap(model)(batched_B)

    # Temperatures of *this* evaluation set, not of the global `testing_data`.
    temperatures = jnp.unique(evaluation_data.T)

    for i in range(n_shown):
        fig, axs = plot_single_sequence(batched_B[i], batched_H[i], temperatures)
        axs[-1].plot(pred_H[i], label="pred")
        fig.legend()
        plt.show()
        # Release the figure: this function is called repeatedly during training.
        plt.close(fig)

In [None]:
# --- Training hyperparameters -------------------------------------------------
# `n_sequences` and `full_sequence_length` are not used elsewhere in this notebook.
n_sequences, full_sequence_length = training_data.H.shape
training_batch_size = 64
sequence_length = 1000  # presumably the length of each drawn training sub-sequence

# --- Randomness: separate keys for model initialisation and data loading -------
key = jax.random.key(1)
key, model_key, loader_key = jax.random.split(key, 3)

# NOTE(review): positional args presumably (input size, output size, hidden size) —
# confirm against mc2.models.RNN.BaseRNN.
model = BaseRNN(1, 1, 256, key=model_key)

# --- Optimiser -------------------------------------------------------------------
# NOTE(review): with transition_steps=50_000 and decay_rate=0.995 the learning rate
# decays by only ~0.5% over a 50_000-step run, so end_value=1e-4 is never reached.
# Confirm this schedule is intended.
lr = optax.schedules.exponential_decay(
    init_value=1e-3,
    decay_rate=0.995,
    transition_steps=50_000,
    transition_begin=2_000,
    end_value=1e-4,
)
optim = optax.adam(lr)

# NOTE(review): equinox examples usually initialise with eqx.filter(model, eqx.is_array);
# passing the whole model only works if every leaf is an array — confirm against make_step.
opt_state = optim.init(model)

In [None]:
# Training loop. Each step draws a batch via `draw_data_uniformly` (presumably random
# sub-sequences of length `sequence_length`) and takes one optimiser step.
# Every `eval_every` steps the log training loss and a few test predictions are plotted.
# NOTE(review): re-running this cell resets `losses` but keeps training the existing
# `model` / `opt_state`; re-run the setup cell first for a fresh training run.
n_training_steps = 50_000
eval_every = 1_000

losses = []

for step in tqdm.tqdm(range(n_training_steps)):

    # Optional curriculum (disabled): gradually grow the training sequence length.
    # if step % 500 == 0 and step > 0 and sequence_length < 5000:
    #     sequence_length = sequence_length + 400
    #     print("Momentary sequence length:", sequence_length)

    batched_H, batched_B, loader_key = draw_data_uniformly(
        training_data, sequence_length, training_batch_size, loader_key
    )
    loss, model, opt_state = make_step(model, batched_B, batched_H, optim, opt_state)
    losses.append(loss)

    if step % eval_every == 0 and step > 0:
        plt.suptitle(f"Training loss over training steps at {step} steps")
        plt.plot(np.log(losses))
        plt.xlabel("training step")
        plt.ylabel("log(training loss)")
        plt.show()

        test_on_evalset(testing_data, model)

In [None]:
# Log training loss over the whole run (title typo "over over" fixed, axis labels added).
plt.suptitle("Final training loss over training steps")
plt.plot(np.log(losses))
plt.xlabel("training step")
plt.ylabel("log(training loss)")
plt.show()

## Consider Testing data

In [None]:
# Model inputs/targets for the full test set: a trailing feature axis of size 1 is
# appended to the 2-D (sequence, time) arrays.
batched_B = testing_data.B[..., None]
batched_H = testing_data.H[..., None]

# The model handles one sequence at a time; vmap maps it over the sequence axis.
pred_H = jax.vmap(model)(batched_B)

In [None]:
# Measured vs. predicted H for up to 20 test sequences.
temperatures = jnp.unique(testing_data.T)
n_to_plot = min(batched_H.shape[0], 20)
for idx in range(n_to_plot):
    fig, axs = plot_single_sequence(batched_B[idx], batched_H[idx], temperatures)
    axs[-1].plot(pred_H[idx], label="pred")
    fig.legend()

In [None]:
# Zoom into samples [start, end) of each test sequence.
start = 1000
end = 2000
window = slice(start, end)

pred_H = jax.vmap(model)(batched_B)

temperatures = jnp.unique(testing_data.T)
for idx in range(min(batched_H.shape[0], 20)):
    fig, axs = plot_single_sequence(batched_B[idx, window], batched_H[idx, window], temperatures)
    axs[-1].plot(pred_H[idx, window], label="pred")
    fig.legend()

### Hysteresis plots:

In [None]:
# B-H loops for up to 20 test sequences: measured loop first, predicted loop on the same axes.
temperatures = jnp.unique(testing_data.T)
for idx in range(min(batched_H.shape[0], 20)):
    fig, axs = plot_hysteresis(batched_B[idx], batched_H[idx], temperatures)
    fig, axs = plot_hysteresis(batched_B[idx], pred_H[idx], temperatures, fig=fig, axs=axs)
    fig.legend()

- The inaccurate initial state seems to result in a strong value drift?
- Initially, the magnetization is unknown, which seems to lead to parallel lines in the B-H plane?

In [None]:
# B-H loops restricted to the first 1_000 samples, i.e. the start of each sequence
# (relevant to the initial-state notes above).
head = slice(None, 1_000)
temperatures = jnp.unique(testing_data.T)
for idx in range(min(batched_H.shape[0], 20)):
    fig, axs = plot_hysteresis(batched_B[idx, head], batched_H[idx, head], temperatures)
    fig, axs = plot_hysteresis(batched_B[idx, head], pred_H[idx, head], temperatures, fig=fig, axs=axs)
    fig.legend()

In [None]:
# B-H loops for a later window [start, end), well after the initial transient.
start = 5_000
end = 6_000
window = slice(start, end)

temperatures = jnp.unique(testing_data.T)
for idx in range(min(batched_H.shape[0], 20)):
    measured_B = batched_B[idx, window]
    fig, axs = plot_hysteresis(measured_B, batched_H[idx, window], temperatures)
    fig, axs = plot_hysteresis(measured_B, pred_H[idx, window], temperatures, fig=fig, axs=axs)
    fig.legend()

In [None]:
# NOTE(review): this cell repeats the earlier "first 1_000 samples" hysteresis cell
# exactly; consider removing it or plotting a different window instead.
n_head = 1_000
temperatures = jnp.unique(testing_data.T)
for idx in range(min(batched_H.shape[0], 20)):
    b_head = batched_B[idx, :n_head]
    fig, axs = plot_hysteresis(b_head, batched_H[idx, :n_head], temperatures)
    fig, axs = plot_hysteresis(b_head, pred_H[idx, :n_head], temperatures, fig=fig, axs=axs)
    fig.legend()

### Different material data:

- Cross-validation on data from a different material:

In [None]:
# Cross-validation source: same frequency and temperature as the training data, but a
# *different* material. Selecting the training material ("3C90") here would re-evaluate
# on the training sequences themselves rather than test cross-material generalisation.
different_material_data = dataset.at_material("N27").at_frequency(50_000).filter_temperatures([25])
assert different_material_data.material_name != all_relevant_data.material_name, (
    "cross-validation material must differ from the training material"
)

In [None]:
# Display the cross-validation data.
different_material_data

In [None]:
# Predict H for every cross-validation sequence and compare up to 20 of them with the
# measured H.
batched_H = different_material_data.H[:, :][..., None]
batched_B = different_material_data.B[:, :][..., None]

pred_H = jax.vmap(model)(batched_B)

# Temperatures of the cross-validation data (previously wrongly taken from `testing_data`).
temperatures = jnp.unique(different_material_data.T)
for i in range(min(batched_H.shape[0], 20)):
    fig, axs = plot_single_sequence(batched_B[i], batched_H[i], temperatures)
    axs[-1].plot(pred_H[i], label="pred")
    fig.legend()

In [None]:
# Zoom into samples [start, end) of each cross-validation sequence.
start = 1000
end = 2000

batched_H = different_material_data.H[:, :][..., None]
batched_B = different_material_data.B[:, :][..., None]

pred_H = jax.vmap(model)(batched_B)

# Temperatures of the cross-validation data (previously wrongly taken from `testing_data`).
temperatures = jnp.unique(different_material_data.T)
for i in range(min(batched_H.shape[0], 20)):
    fig, axs = plot_single_sequence(batched_B[i, start:end], batched_H[i, start:end], temperatures)
    axs[-1].plot(pred_H[i, start:end], label="pred")
    fig.legend()

In [None]:
# B-H loops for the cross-validation data. This uses `batched_B`, `batched_H` and `pred_H`
# as left by the preceding cell.
# Temperatures of the cross-validation data (previously wrongly taken from `testing_data`).
temperatures = jnp.unique(different_material_data.T)
for i in range(min(batched_H.shape[0], 20)):
    fig, axs = plot_hysteresis(batched_B[i], batched_H[i], temperatures)
    fig, axs = plot_hysteresis(batched_B[i], pred_H[i], temperatures, fig=fig, axs=axs)
    fig.legend()

### Energy loss estimation

In [None]:
from mc2.metrics import get_energy_loss

In [None]:
# Measured energy loss per test sequence, and the energy loss implied by the predicted H.
# Both must be computed per sequence (vmapped over the leading sequence axis) so that the
# two arrays line up element-wise; previously the estimate called get_energy_loss on the
# whole 2-D batch at once.
energy_loss = jax.vmap(get_energy_loss)(
    b=testing_data.B,
    h=testing_data.H,
)

pred_H = jax.vmap(model)(testing_data.B[..., None])

est_energy_loss = jax.vmap(get_energy_loss)(
    b=testing_data.B,
    h=pred_H[..., 0],
)
assert energy_loss.shape == est_energy_loss.shape

In [None]:
# Signed and absolute energy-loss error per test sequence.
energy_loss_error = energy_loss - est_energy_loss
plt.plot(energy_loss_error, label="value")
plt.plot(jnp.abs(energy_loss_error), label="abs value")

plt.xlabel("sequence idx")
plt.ylabel("error for energy loss in J")
plt.legend()
plt.grid()

plt.savefig("absolute errors.png", dpi=200)

In [None]:
# Relative absolute error per test sequence: |E_measured - E_estimated| / E_measured.
# This ratio is dimensionless, so the y-axis no longer claims a unit of J.
relative_error = jnp.abs(energy_loss - est_energy_loss) / energy_loss
plt.plot(relative_error, label="relative abs error")

plt.xlabel("sequence idx")
plt.ylabel("relative error for energy loss")
plt.legend()
plt.grid()

plt.savefig("relative errors.png", dpi=200)

### Save model

In [None]:
# model_path = pathlib.Path("../../data/models") / "baseRNN_long_train.eqx"

In [None]:
# eqx.tree_serialise_leaves(model_path , model)

### Load model:

In [None]:
# Restore a stored model. `like=model` supplies the pytree structure to deserialise into,
# so `model` must have been built with the same architecture as the stored one
# (presumably BaseRNN(1, 1, 256), judging by the file name). To be extended/fixed...
models_dir = pathlib.Path("../../data/models")
loaded_model = eqx.tree_deserialise_leaves(
    path_or_file=models_dir / "baseRNN_256_5000steps.eqx",
    like=model,
)

In [None]:
# Display the restored model.
loaded_model