In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Both variables are set here, before JAX is imported further down in the notebook.
# Only expose the GPU with index 1 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# Ask XLA not to preallocate GPU memory up front (allocate on demand instead).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
import glob
import tqdm
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from mc2.utils.data_inspection import (
    get_available_material_names, get_file_overview, load_and_process_single_from_full_file_overview,
)
from mc2.utils.data_plotting import plot_single_sequence, plot_hysteresis
from mc2.data_management import FrequencySet, MaterialSet, DataSet

# [WIP]
**Bare-bones RNN implementation that learns to predict the H sequence from the B sequence.**

## TODOs
This notebook still needs to be checked for errors, refined, and optimized.

## Load dataset

In [None]:
# Load the processed data set from disk.
# NOTE(review): the file name suggests a pickle -- only load files from a trusted source.
dataset = DataSet.load_from_file(pathlib.Path("../data/processed") / "ten_mat_data.pickle")

# Remove N49 from the data set for now, since its data is incomplete:
#   - 50 kHz and 80 kHz are missing
#   - 320 kHz has no data at 25 degrees

# Work on a copy so that `remove` below does not mutate the list held by `dataset`.
available_materials = deepcopy(dataset.material_names)
print(available_materials)
print(len(available_materials))

# NOTE(review): presumably a list, in which case this raises ValueError if "N49" is absent.
available_materials.remove("N49")
print(available_materials)
print(len(available_materials))


dataset = dataset.filter_materials(available_materials)
# Sanity check: the filtered data set reports exactly the remaining materials.
assert dataset.material_names == available_materials

## Preliminaries

In [None]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

In [None]:
# Narrow the data set down to one material ("3C90"), one frequency (50_000, presumably Hz)
# and one temperature (25).
all_relevant_data = dataset.at_material("3C90").at_frequency(50_000).filter_temperatures([25])
all_relevant_data

In [None]:
# Number of leading sequences used for training; all remaining sequences are used for testing.
N_TRAIN_SEQUENCES = 200


def _select_sequences(data, rows):
    """Build a new `FrequencySet` that contains only the sequences selected by `rows`.

    Args:
        data: Source set; its `material_name` and `frequency` are carried over unchanged.
        rows: Index (here a `slice`) applied to the first axis of `H`, `B` and `T`.

    Returns:
        A `FrequencySet` holding the selected rows of `H`, `B` and `T`.
    """
    return FrequencySet(
        data.material_name,
        data.frequency,
        data.H[rows, :],
        data.B[rows, :],
        data.T[rows],
    )


training_data = _select_sequences(all_relevant_data, slice(None, N_TRAIN_SEQUENCES))

testing_data = _select_sequences(all_relevant_data, slice(N_TRAIN_SEQUENCES, None))

In [None]:
# Display the training split.
training_data

In [None]:
# Display the testing split.
testing_data

---

In [None]:
from mc2.models.RNN import BaseRNN

In [None]:
@eqx.filter_value_and_grad
def compute_loss(model, x, y):
    """Mean squared error between the batched model output and the targets.

    Because of the `eqx.filter_value_and_grad` decorator, calling this function
    returns a `(loss, grads)` pair, with gradients taken with respect to `model`.

    Args:
        model: Callable applied to every element along the leading axis of `x`.
        x: Batched model inputs.
        y: Targets, compared element-wise against the model output.
    """
    residual = jax.vmap(model)(x) - y
    return jnp.mean(residual**2)

@eqx.filter_jit
def make_step(model, x, y, opt_state):
    """Run one optimisation step and return the loss together with the updated state.

    The optimiser is taken from the notebook-level variable `optim`, so that
    variable has to be defined before this function is called for the first time.

    Args:
        model: Model to be updated.
        x: Batched model inputs.
        y: Batched targets.
        opt_state: Current optimiser state.

    Returns:
        Tuple `(loss, model, opt_state)` holding the loss before the update,
        the updated model and the updated optimiser state.
    """
    loss_value, gradients = compute_loss(model, x, y)
    param_updates, next_opt_state = optim.update(gradients, opt_state)
    updated_model = eqx.apply_updates(model, param_updates)
    return loss_value, updated_model, next_opt_state

In [None]:
@eqx.filter_jit
def precompute_starting_points(n_train_steps, k, sequence_length, training_batch_size, loader_key):
    """Sample random window starting indices for a number of training steps.

    Args:
        n_train_steps: Number of training steps to sample starting points for.
        k: Length of the full sequences the windows are cut from.
        sequence_length: Length of one training window.
        training_batch_size: Number of starting points per training step.
        loader_key: PRNG key; it is split, never used for sampling directly.

    Returns:
        Tuple `(starting_points, loader_key)` where `starting_points` is an int32
        array of shape `(n_train_steps, training_batch_size)` with values in
        `[0, k - sequence_length]`, and `loader_key` is a fresh key for the next call.
    """
    # Split first: a key that has been consumed for sampling must not be reused or
    # handed back to the caller, otherwise the random streams become correlated.
    loader_key, sample_key = jax.random.split(loader_key)

    # `maxval` is exclusive, so the largest start is `k - sequence_length` and the
    # window `[start, start + sequence_length)` always fits into the full sequence.
    starting_points = jax.random.randint(
        sample_key,
        shape=(n_train_steps, training_batch_size),
        minval=0,
        maxval=k + 1 - sequence_length,
        dtype=jnp.int32,
    )

    return starting_points, loader_key


@eqx.filter_jit
def load_single_batch(dataset, starting_points, sequence_length):
    """Cut windows of `sequence_length` consecutive samples out of `dataset.H` and `dataset.B`.

    Args:
        dataset: Object whose `H` and `B` arrays are indexed as
            `[sequence, time]`.
        starting_points: Integer array (or scalar) of window start indices.
        sequence_length: Number of consecutive samples per window.

    Returns:
        Tuple `(batched_H, batched_B)`. Each has the shape
        `(n_sequences, *reversed(starting_points.shape), sequence_length)`, i.e. every
        window is taken from every sequence.
    """
    starting_points = jnp.asarray(starting_points, dtype=int)

    # Consecutive indices start, start + 1, ..., start + sequence_length - 1.
    # (An endpoint-inclusive linspace over [start, start + sequence_length] would use a
    # step larger than one, skip a sample and end one index past the window.)
    window_offsets = jnp.arange(sequence_length, dtype=int)
    indices = starting_points.T[..., None] + window_offsets

    batched_H = dataset.H[:, indices]
    batched_B = dataset.B[:, indices]
    return batched_H, batched_B


@eqx.filter_jit
def get_data(dataset, sequence_length, training_batch_size, loader_key):
    """Draw one random training batch of H/B windows from `dataset`.

    Args:
        dataset: Object with 2-D arrays `H` and `B` of shape `(n_sequences, full_length)`.
        sequence_length: Length of the windows to cut out.
        training_batch_size: Number of random starting points to draw.
        loader_key: PRNG key used for drawing the starting points.

    Returns:
        Tuple `(batched_H, batched_B, loader_key)`. Both arrays have the shape
        `(n_sequences * training_batch_size, sequence_length, 1)`; `loader_key` is the
        key to pass to the next call.
    """
    _, full_sequence_length = dataset.H.shape
    starting_points, loader_key = precompute_starting_points(
        1, full_sequence_length, sequence_length, training_batch_size, loader_key
    )
    # Load from the data set that was passed in, not from a notebook-level variable.
    batched_H, batched_B = load_single_batch(dataset, starting_points, sequence_length)

    # Merge all leading axes into one batch axis and append a feature axis. Unlike a bare
    # `squeeze`, this keeps the batch axis when there is only a single sequence and
    # also yields a flat batch for `training_batch_size > 1`.
    batched_H = batched_H.reshape(-1, sequence_length, 1)
    batched_B = batched_B.reshape(-1, sequence_length, 1)

    # return a batched dataset ?

    return batched_H, batched_B, loader_key

In [None]:
# --- Hyper-parameters -------------------------------------------------------
n_sequences, full_sequence_length = training_data.H.shape
training_batch_size = 1
sequence_length = 5000

# Training windows are cut out of the full sequences, so they cannot be longer than those.
assert sequence_length <= full_sequence_length, (
    f"sequence_length={sequence_length} exceeds the available "
    f"sequence length {full_sequence_length}"
)

learning_rate = 1e-3


# --- Random keys ------------------------------------------------------------
# Fixed seed for reproducibility; separate keys for model initialisation and data loading.
key = jax.random.key(1)
key, model_key, loader_key = jax.random.split(key, 3)

# --- Model and optimiser ----------------------------------------------------
# NOTE(review): the positional arguments are presumably (input size, output size,
# hidden size) -- confirm against `BaseRNN`.
model = BaseRNN(1, 1, 256, key=model_key)
optim = optax.adam(learning_rate)
# Initialise the optimiser on the array leaves only, so that non-array leaves of the
# model (if there are any) do not end up in the optimiser state.
opt_state = optim.init(eqx.filter(model, eqx.is_array))

In [None]:
n_train_steps = 5_000  # total number of optimisation steps
plot_every = 100  # show the intermediate loss curve every `plot_every` steps

losses = []

# NOTE: `model`, `opt_state` and `loader_key` are rebound in every iteration, so re-running
# this cell continues training from the current state instead of starting from scratch.
for step in tqdm.tqdm(range(n_train_steps)):

    # Draw a new random window from the training sequences.
    batched_H, batched_B, loader_key = get_data(training_data, sequence_length, training_batch_size, loader_key)
    # B is the model input, H is the target.
    loss, model, opt_state = make_step(model, batched_B, batched_H, opt_state)
    losses.append(loss)

    if step > 0 and step % plot_every == 0:
        plt.suptitle(f"Training loss over training steps at {step} steps")
        plt.plot(losses)
        plt.show()

plt.suptitle("Final Training loss over training steps")
plt.plot(losses)
plt.show()

## Evaluate on the testing data

In [None]:
# Append a trailing feature axis to the full test sequences.
batched_H = testing_data.H[..., None]
batched_B = testing_data.B[..., None]

# Apply the trained model to every test sequence: predict H from B.
pred_H = jax.vmap(model)(batched_B)

In [None]:
# `testing_data.T` does not change inside the loop, so reduce it to its unique values once.
unique_T = jnp.unique(testing_data.T)

# Plot at most the first 20 test sequences and overlay the prediction on the last axis.
for i in range(min(batched_H.shape[0], 20)):
    fig, axs = plot_single_sequence(batched_B[i], batched_H[i], unique_T)
    axs[-1].plot(pred_H[i], label="pred")
    fig.legend()

In [None]:
# `testing_data.T` does not change inside the loop, so reduce it to its unique values once.
unique_T = jnp.unique(testing_data.T)

# For at most the first 20 test sequences, draw the measured hysteresis curve and then
# the predicted one into the same figure.
for i in range(min(batched_H.shape[0], 20)):
    fig, axs = plot_hysteresis(batched_B[i], batched_H[i], unique_T)
    fig, axs = plot_hysteresis(batched_B[i], pred_H[i], unique_T, fig=fig, axs=axs)
    fig.legend()

In [None]:
from mc2.metrics import get_energy_loss

In [None]:
# Reference energy loss, evaluated per test sequence from the measured B and H.
energy_loss = jax.vmap(get_energy_loss)(
    b=testing_data.B,
    h=testing_data.H,
)

# Model prediction of H for every test sequence (a trailing feature axis is added for the model).
pred_H = jax.vmap(model)(testing_data.B[..., None])

# Estimated energy loss from the predicted H. It is evaluated per sequence in exactly the
# same way as the reference above, so that both results are directly comparable.
est_energy_loss = jax.vmap(get_energy_loss)(
    b=testing_data.B,
    h=pred_H[..., 0],
)

In [None]:
# Signed and absolute deviation between the reference and the estimated energy loss.
energy_loss_error = energy_loss - est_energy_loss

plt.plot(energy_loss_error, label="value")
plt.plot(jnp.abs(energy_loss_error), label="abs value")

plt.xlabel("sequence idx")
plt.ylabel("error for energy loss in J")
plt.grid()
plt.legend()

plt.savefig("absolute errors.png", dpi=200)

In [None]:
# Relative error: absolute deviation divided by the magnitude of the reference value.
relative_energy_loss_error = jnp.abs(energy_loss - est_energy_loss) / jnp.abs(energy_loss)

plt.plot(relative_energy_loss_error, label="abs value")

plt.xlabel("sequence idx")
# A relative error is a ratio and therefore carries no unit.
plt.ylabel("relative error for energy loss")
plt.legend()
plt.grid()

plt.savefig("relative errors.png", dpi=200)

### Save model

In [None]:
# File the model is written to and read back from in the cells below.
model_path = pathlib.Path("../data/models") / "baseRNN.eqx"

In [None]:
# Make sure the target directory exists before the model file is written into it.
model_path.parent.mkdir(parents=True, exist_ok=True)
eqx.tree_serialise_leaves(model_path, model)

### Load model:

In [None]:
# The second argument serves as a template: `model` must already exist and have the same
# form as the model that was stored (to be extended/fixed).
model = eqx.tree_deserialise_leaves(model_path, model)

In [None]:
# Display the loaded model.
model