In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
# These must be set before JAX initializes its GPU backend (i.e. before the first JAX use below).
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "1",  # expose only the GPU with index 1 to this process
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",  # allocate GPU memory on demand instead of preallocating a large fraction
})

import pathlib
import glob
import tqdm
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

## Differentiable Preisach model

**Warning**: This notebook implements the standard formulation $B = f(H)$, which is not the formulation ultimately needed!

In [None]:
import jax
import jax.numpy as jnp
import jax.nn as jnn

import equinox as eqx
import optax

In [None]:
from mc2.utils.data_plotting import plot_single_sequence, plot_hysteresis
from mc2.data_management import FrequencySet, MaterialSet, DataSet
from mc2.training.data_sampling import draw_data_uniformly

In [None]:
# Split configuration (previously inline magic numbers).
N_TRAIN_SEQUENCES = 200  # first 200 sequences train the model, the remaining ones test it
DOWNSAMPLE_STEP = 10  # keep every 10th sample along each sequence
H_SCALE = 100  # H is divided by this factor before being fed to the model

# NOTE(review): a .pickle file is presumably unpickled here -- only load trusted data.
dataset = DataSet.load_from_file(pathlib.Path("../../../data/processed") / "ten_mat_data.pickle")
# 3C90 material, frequency 50_000 (presumably Hz), temperature 25 (presumably °C).
all_relevant_data = dataset.at_material("3C90").at_frequency(50_000).filter_temperatures([25])


def _subset(data, rows):
    """Build a FrequencySet from the selected sequences, downsampled in time with H rescaled.

    Args:
        data: Source FrequencySet-like object exposing material_name, frequency, H, B, T.
            H and B are indexed as [sequence, sample] (presumably 2-D arrays).
        rows: Slice selecting which sequences to keep.

    Returns:
        A new FrequencySet for the selected sequences.
    """
    return FrequencySet(
        data.material_name,
        data.frequency,
        data.H[rows, ::DOWNSAMPLE_STEP] / H_SCALE,
        data.B[rows, ::DOWNSAMPLE_STEP],
        data.T[rows],
    )


training_data = _subset(all_relevant_data, slice(None, N_TRAIN_SEQUENCES))
testing_data = _subset(all_relevant_data, slice(N_TRAIN_SEQUENCES, None))

In [None]:
def test_on_evalset(evaluation_data, model, alpha_beta_grid, n_plots=3):
    """Plot model predictions against measured B for the first few sequences of a data set.

    Args:
        evaluation_data: FrequencySet-like object exposing ``H``, ``B`` and ``T``;
            ``H``/``B`` are presumably 2-D ``[sequence, sample]`` arrays.
        model: Callable ``model(H, alpha_beta_grid)`` predicting B for a single sequence.
        alpha_beta_grid: Grid passed unchanged to every ``model`` call.
        n_plots: Maximum number of sequences to plot (default 3).
    """
    # trailing singleton axis so each sequence matches what the model is trained on
    batched_H = evaluation_data.H[..., None]
    batched_B = evaluation_data.B[..., None]
    # Use the temperatures of the set being evaluated, not those of a global variable.
    temperatures = jnp.unique(evaluation_data.T)

    # map the single-sequence model over the batch axis; the grid is shared
    pred_B = jax.vmap(model, in_axes=(0, None))(batched_H, alpha_beta_grid)

    for i in range(min(batched_H.shape[0], n_plots)):
        fig, axs = plot_single_sequence(batched_B[i], batched_H[i], temperatures)
        axs[0].plot(pred_B[i], label="pred")
        fig.legend()
        plt.show()

In [None]:
from mc2.models.preisach_utils import build_alpha_beta_grid
from mc2.models.preisach import ArrayPreisach

In [None]:
@eqx.filter_jit
@eqx.filter_value_and_grad
def compute_loss_and_grad(model, H_trajectory, B_trajectory, alpha_beta_grid):
    """Mean squared error between predicted and measured B over a batch of H trajectories.

    The decorators make the call return ``(loss, grads)``, with gradients taken
    with respect to the first argument, ``model``, and the whole computation jitted.

    Args:
        model: Callable ``model(H, alpha_beta_grid)`` predicting B for one trajectory.
        H_trajectory: Batch of H trajectories; the leading axis is mapped over.
        B_trajectory: Measured B, expected to match the batched prediction's shape.
        alpha_beta_grid: Grid shared by every trajectory in the batch.
    """
    predict_batch = jax.vmap(model, in_axes=(0, None))
    residual = predict_batch(H_trajectory, alpha_beta_grid) - B_trajectory
    return jnp.mean(residual**2)

In [None]:
# Training hyperparameters.
sequence_length = 1000  # length of each sampled training window
training_batch_size = 64  # windows per optimizer step

# Preisach model plus its alpha/beta grid, discretized with 150 points per dimension.
model, alpha_beta_grid = ArrayPreisach.from_parameters(points_per_dim=150)

# Seeded PRNG; `loader_key` drives batch sampling, the middle subkey is discarded.
key = jax.random.key(1)
key, _, loader_key = jax.random.split(key, 3)

# Adam over the model's floating-point array leaves only.
optim = optax.adam(learning_rate=1e-4)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

In [None]:
n_training_steps = 20_000
eval_every = 10_000  # plot predictions on a few test sequences at this interval
log_every = 100  # refresh the displayed loss at this interval (float() forces a device sync)

progress_bar = tqdm.tqdm(range(n_training_steps))
for n in progress_bar:
    # sample a fresh batch of training windows and take one optimizer step
    batched_H, batched_B, loader_key = draw_data_uniformly(training_data, sequence_length, training_batch_size, loader_key)
    loss, grads = compute_loss_and_grad(model, batched_H, batched_B, alpha_beta_grid)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)

    # Surface the training loss instead of silently discarding it.
    if n % log_every == 0:
        progress_bar.set_postfix(loss=f"{float(loss):.3e}")

    if n % eval_every == 0 and n > 0:
        test_on_evalset(testing_data, model, alpha_beta_grid)

In [None]:
# Add a trailing singleton axis to every test sequence, as done for training inputs.
batched_H = testing_data.H[..., None]
batched_B = testing_data.B[..., None]

# Predict B for all test sequences at once; the alpha/beta grid is shared across the batch.
pred_B = jax.vmap(model, in_axes=(0, None))(batched_H, alpha_beta_grid)

In [None]:
# Compare measured and predicted B over time for (up to) the first 20 test sequences.
for idx in range(min(len(batched_H), 20)):
    measured_B, H_seq, predicted_B = batched_B[idx], batched_H[idx], pred_B[idx]
    fig, axs = plot_single_sequence(measured_B, H_seq, jnp.unique(testing_data.T))
    axs[0].plot(predicted_B, label="pred")
    fig.legend()

In [None]:
# Overlay the predicted B-H curve on the measured one for (up to) the first 20 test sequences.
for idx in range(min(len(batched_H), 20)):
    temperatures = jnp.unique(testing_data.T)
    fig, axs = plot_hysteresis(batched_B[idx], batched_H[idx], temperatures)
    fig, axs = plot_hysteresis(pred_B[idx], batched_H[idx], temperatures, fig=fig, axs=axs)
    fig.legend()