In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
import glob
import tqdm
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

## Differentiable Preisach model

**Warning**: This is the standard forward formulation, $B = f(H)$ (input $H$, output $B$), which is not what we actually need in the end!

In [None]:
import jax
import jax.numpy as jnp
import jax.nn as jnn

import equinox as eqx
import optax

In [None]:
from mc2.models.preisach import hysteron_operator, DifferentiablePreisach, ArrayPreisach, estimate_B

In [None]:
def build_grid(dim, low, high, points_per_dim):
    """Return every point of a uniform Cartesian grid as one row of a 2-D array.

    Args:
        dim: Number of coordinates per grid point.
        low: Lower bound shared by all axes (inclusive).
        high: Upper bound shared by all axes (inclusive).
        points_per_dim: Number of equally spaced samples along each axis.

    Returns:
        Array of shape ``(points_per_dim**dim, dim)``; each row is one grid point.
    """
    axis_samples = jnp.linspace(low, high, points_per_dim)

    # All axes share the same samples; meshgrid expands them to the full grid.
    coordinate_planes = jnp.meshgrid(*([axis_samples] * dim))
    points = jnp.stack(coordinate_planes, axis=-1).reshape(-1, dim)

    assert points.shape[0] == points_per_dim**dim
    return points

def filter_function(x):
    """Measure by how much a point violates ``x[1] <= x[0]``.

    Returns ``relu(x[1] - x[0])``: zero when the second coordinate does not
    exceed the first, and the positive difference otherwise. For the
    (alpha, beta) rows built in this notebook this is zero when beta <= alpha.
    """
    first, second = x[0], x[1]
    return jnn.relu(second - first)
    
def filter_grid(x):
    """Keep only the rows of ``x`` for which ``filter_function`` returns zero.

    Args:
        x: Array whose rows are points accepted by ``filter_function``.

    Returns:
        The rows of ``x`` with a zero filter value, in their original order.
    """
    violations = jax.vmap(filter_function)(x)
    keep_mask = violations == 0
    return x[keep_mask]

def build_alpha_beta_grid(points_per_dim):
    """Build the filtered two-dimensional grid on ``[-1, 1] x [-1, 1]``.

    A uniform grid with ``points_per_dim`` samples per axis is created with
    ``build_grid`` and then reduced by ``filter_grid``.
    """
    full_grid = build_grid(dim=2, low=-1, high=1, points_per_dim=points_per_dim)
    return filter_grid(full_grid)

In [None]:
# Build the filtered grid with 50 samples per axis and show which points survive the filter.
alpha_beta_grid = build_alpha_beta_grid(50)

# Transposing yields the two coordinate columns, which are unpacked as x and y.
plt.scatter(*alpha_beta_grid.T)

In [None]:
# Investigate hysteron operator behavior on a hand-crafted, piecewise-linear field path.

# Each (start, stop) pair is one linear ramp of 1000 samples; ramps are chained in order.
ramp_endpoints = [
    (-0.3, 0.3),
    (0.3, -0.3),
    (-0.3, -0.5),
    (-0.5, 0.15),
    (0.15, -0.05),
    (-0.049, 0.1),
]
H = jnp.concatenate(
    [jnp.linspace(start, stop, 1000)[..., None] for start, stop in ramp_endpoints],
    axis=0,
)

outs = []
positive_direction = True

initial_output = jnp.array([-1.])
initial_field = jnp.array([-10.1])
H_last = initial_field  # only used for sign change detection

for idx, H_in in enumerate(H):
    # A reversal is a step against the direction that is currently being tracked.
    if positive_direction:
        direction_reversed = H_in < H_last
    else:
        direction_reversed = H_in > H_last

    if direction_reversed:
        # Restart from the state reached just before the reversal.
        initial_output = output
        initial_field = H_last
        positive_direction = not positive_direction

    # NOTE(review): the last two arguments are presumably the hysteron's
    # (alpha, beta) thresholds and the parameter T — confirm against
    # mc2.models.preisach.hysteron_operator.
    output = hysteron_operator(
        H_in,
        initial_field,
        initial_output,
        jnp.array([0.2, -0.15]),
        1e-3,
    )
    H_last = H_in

    outs.append(output)

# Operator output over the applied field.
plt.plot(H, outs)
plt.grid()
plt.show()

# Field and operator output over the sample index.
plt.plot(H)
plt.plot(outs)

In [None]:
def analyticalPreisachFunction2(A: float, Hc: float, sigma: float, beta: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """
    Evaluate the analytical Preisach distribution on a (beta, alpha) grid.

    Function based on Paper 'Removing numerical instabilities in the Preisach model identification
    using genetic algorithms' by G. Consolo G. Finocchio, M. Carpentieri, B. Azzerboni.

    The value at each grid point is
    ``A / (1 + ((beta - Hc) * sigma / Hc)**2) / (1 + ((alpha + Hc) * sigma / Hc)**2)``.
    Afterwards, in row ``i`` the last ``i + 1`` entries are set to zero.

    Args:
        A: Overall scale factor.
        Hc: Centre/scale parameter of both factors (must be non-zero).
        sigma: Width parameter of both factors.
        beta: Grid of beta values.
        alpha: Grid of alpha values, same shape as ``beta``.

    Returns:
        Array with the same shape as the inputs holding the masked distribution.
    """
    beta_term = 1 / (1 + ((beta - Hc) * sigma / Hc) ** 2)
    alpha_term = 1 / (1 + ((alpha + Hc) * sigma / Hc) ** 2)
    density = A * beta_term * alpha_term

    # set lower right diagonal to zero (each row is a view, so this edits `density`)
    for row_index, row in enumerate(density):
        row[-row_index - 1:] = 0
    return density

def preisachIntegration(w: float, Z: np.ndarray) -> np.ndarray:
    """
    Perform 2D- integration of the Preisach distribution function.

    Entry ``[i, j]`` of the result is the sum of ``w * Z[k, l]`` over all
    ``k >= i`` and ``l >= j``.

    Args:
        w: Weight applied to every entry of ``Z``.
        Z: Two-dimensional array of distribution values.

    Returns:
        Array of the same shape as ``Z`` with the accumulated sums.
    """
    weighted = w * Z
    # Reverse both axes so that a forward cumulative sum accumulates from the far corner.
    reversed_cumulative = np.cumsum(np.cumsum(weighted[::-1, ::-1], axis=0), axis=1)
    # Undo the reversal to restore the original orientation.
    return reversed_cumulative[::-1, ::-1]

In [None]:
# Evaluate the analytical Preisach function on a uniform 100 x 100 grid over [-1, 1]^2,
# integrate it, and reduce both the values and the grid to the points accepted by
# filter_function. The cell leaves `preisach` (column vector) and `alpha_beta_grid`
# (matching rows) for the model cells below.
dim = 2
low=-1
high=1
points_per_dim=100

xs = [jnp.linspace(low, high, points_per_dim) for _ in range(dim)]
# The first meshgrid output is used as alpha, the second as beta.
alpha_grid, beta_grid = jnp.meshgrid(*xs)

# Flattened grid: column 0 holds alpha, column 1 holds beta.
alpha_beta_grid = jnp.concatenate([alpha_grid.flatten()[..., None], beta_grid.flatten()[..., None]], axis=-1)

# np.array(...) copies the JAX arrays into NumPy arrays for the NumPy-based helpers.
preisach = analyticalPreisachFunction2(A=1, Hc=0.01, sigma=0.03, beta=np.array(beta_grid), alpha=np.array(alpha_grid))
# w numerically equals the grid spacing (high - low) / (points_per_dim - 1) for the [-1, 1] range.
preisach = preisachIntegration(w=2 * 1 / (points_per_dim - 1), Z=preisach)

# Normalize so that the largest entry is 1.
preisach = preisach / np.max(preisach)


# Show the integrated surface before and after the horizontal flip.
plt.imshow(preisach)
plt.show()

plt.imshow(jnp.fliplr(preisach))
plt.show()

# NOTE(review): the flip presumably aligns the surface with the alpha/beta
# orientation assumed by filter_function — confirm.
preisach = jnp.fliplr(preisach)

# Row-major flattening, matching the flattening used for alpha_beta_grid above.
preisach = preisach.flatten()


# Keep only the entries whose grid point has a zero filter value, as a column vector.
valid_points = jax.vmap(filter_function)(alpha_beta_grid) == 0
preisach = preisach[jnp.where(valid_points == True)][:, None]
preisach = jnp.array(preisach)

# filter_grid applies the same criterion, so grid rows and `preisach` rows stay aligned.
alpha_beta_grid = filter_grid(alpha_beta_grid)

In [None]:
# Roll the array-based Preisach model along a hand-crafted field path, one sample at a time.
model = ArrayPreisach(
    hysteron_density=preisach
    # alternative: hysteron_density=jnp.ones((alpha_beta_grid.shape[0], 1))
)

# Alternative (network-based density), kept for reference:
# alpha_beta_grid = build_alpha_beta_grid(50)
# model = DifferentiablePreisach(
#     width_size=128,
#     depth=3,
#     model_key=jax.random.PRNGKey(15)
# )

# (start, stop, number of samples) per linear ramp; the first ramp is sampled ten times finer.
ramp_specs = [
    (-0.3, 0.3, 10_000),
    (0.3, -0.3, 1000),
    (-0.3, -0.5, 1000),
    (-0.5, 0.15, 1000),
    (0.15, -0.05, 1000),
    (-0.05, 0.1, 1000),
]
H = jnp.concatenate(
    [jnp.linspace(start, stop, num)[..., None] for start, stop, num in ramp_specs],
    axis=0,
)

# Double the amplitude of the whole path.
H = H*2

B_traj = []
positive_direction = True

# Start with every operator value at -1 and a field far below the path.
initial_operator_values = - jnp.ones((alpha_beta_grid.shape[0], 1))
initial_field = jnp.array([-10.1])
H_last = initial_field  # only used for sign change detection

for idx, H_in in enumerate(H):
    # A reversal is a step against the direction that is currently being tracked.
    if positive_direction:
        direction_reversed = H_in < H_last
    else:
        direction_reversed = H_in > H_last

    if direction_reversed:
        # Restart from the operator state reached just before the reversal.
        initial_operator_values = operator_values
        initial_field = H_last
        positive_direction = not positive_direction

    B, operator_values = model(
        H=H_in,
        initial_field=initial_field,
        initial_operator_values=initial_operator_values,
        alpha_beta_grid=alpha_beta_grid,
        T=1e-3
    )
    H_last = H_in

    B_traj.append(B)

# Model output over the applied field.
plt.plot(H, B_traj)
plt.grid()
plt.show()

# Field and model output over the sample index.
plt.plot(H)
plt.grid()
plt.plot(B_traj)

In [None]:
# Evaluate the same field path with the library helper estimate_B (same T as the manual loop).
# NOTE(review): presumably this should reproduce B_traj from the step-by-step loop — confirm visually.
B_est = estimate_B(H, model, alpha_beta_grid, T=1e-3)

# Applied field over the sample index.
plt.plot(H)
plt.grid()
plt.show()

# Estimated output over the applied field.
plt.plot(H, B_est)
plt.grid()
plt.show()

# Field and estimated output over the sample index.
plt.plot(H)
plt.grid()
plt.plot(B_est)

In [None]:
# Display the current model object.
model

## Real Data:

In [None]:
from mc2.data_management import FrequencySet, MaterialSet, DataSet

In [None]:
# Load the processed data set; the path is relative to the notebook's working directory.
# NOTE(review): the file appears to be a pickle (by extension) — only load it from a trusted source.
dataset = DataSet.load_from_file(pathlib.Path("../../../data/processed") / "ten_mat_data.pickle")

In [None]:
# Take row 0 of B and H from dataset[0][0] and append a trailing axis.
# NOTE(review): dataset[0][0] is presumably the first material / first frequency set — confirm.
# NOTE(review): the scale factors (10 / 4 and 1 / 100) are undocumented; presumably they bring
# the signals into the [-1, 1] range covered by the alpha/beta grid — confirm.
B_trajectory = dataset[0][0].B[0, :][..., None] * 10 / 4
H_trajectory = dataset[0][0].H[0, :][..., None] / 100


# Keep every 10th sample of both trajectories.
B_trajectory = B_trajectory[::10]
H_trajectory = H_trajectory[::10]

In [None]:
# Visual check of the down-sampled field trajectory over the sample index.
plt.plot(H_trajectory)

In [None]:
# Sign of (previous sample - current sample): +1 where H decreases, -1 where it increases.
# jnp.roll wraps around, so the first entry compares the last sample with the first one.
# Assumes H_trajectory has shape (N, 1), so rolling the flattened array shifts along time — TODO confirm.
signs = jnp.sign(jnp.roll(H_trajectory, shift=1) - H_trajectory)
signs  # no visible effect: only a cell's final expression is displayed

plt.plot(signs)

# sign_changes = jnp.diff(signs, n=1, axis=0)
# sign_changes

In [None]:
# Value passed as `T` to estimate_B for every evaluation on the measured data, including the loss.
T_training = 1e-3

In [None]:
# Measured field trajectory over the sample index.
plt.plot(H_trajectory)
plt.show()

# Baseline: evaluate the current model on the measured field, starting from an explicit
# initial state (field 100, every operator value +1).
# NOTE(review): this presumably represents a positively saturated start — confirm.
B_est = estimate_B(
    H_trajectory,
    model,
    alpha_beta_grid,
    T=T_training,
    initial_field=jnp.array([100.]),
    initial_operator_values=jnp.ones((alpha_beta_grid.shape[0], 1)),
)
# Estimate and measurement over the sample index.
plt.plot(B_est)
plt.plot(B_trajectory)

In [None]:
@eqx.filter_jit
@eqx.filter_value_and_grad
def compute_loss_and_grad(model, H_trajectory, B_trajectory, alpha_beta_grid):
    """Mean squared error between the estimated and the reference B trajectory.

    Because of the decorators, calling this returns ``(loss, grads)``, where
    the gradients are taken with respect to ``model`` (the first argument).

    NOTE(review): ``T_training`` is read from the notebook's globals inside a
    jitted function; its value is likely fixed when the function is first
    traced, so re-run this cell after changing ``T_training``.

    Args:
        model: Model passed on to ``estimate_B``.
        H_trajectory: Input field trajectory.
        B_trajectory: Reference trajectory compared against the estimate.
        alpha_beta_grid: Grid passed on to ``estimate_B``.
    """
    B_pred = estimate_B(H_trajectory, model, alpha_beta_grid, T=T_training)
    residual = B_pred - B_trajectory

    return jnp.mean(residual**2)

In [None]:
# Inspect the filtered (alpha, beta) grid.
# NOTE(review): this displays the whole array; `alpha_beta_grid.shape` may be enough.
alpha_beta_grid

In [None]:
# Re-create the model from the analytical density so training starts from a known state.
model = ArrayPreisach(
    hysteron_density=preisach
    # hysteron_density=jnp.ones((alpha_beta_grid.shape[0], 1)) / alpha_beta_grid.shape[0]
)
# Alternative (network-based density), kept for reference:
# alpha_beta_grid = build_alpha_beta_grid(200)
# model = DifferentiablePreisach(
#     width_size=128,
#     depth=3,
#     model_key=jax.random.PRNGKey(15)
# )

# jax.vmap(model.hysteron_density)(alpha_beta_grid).reshape(20,20)

# plt.show()

# Estimate of the untrained model on the measured field trajectory.
B_est = estimate_B(H_trajectory, model, alpha_beta_grid, T=T_training)

# Measurement and estimate over the sample index.
plt.plot(B_trajectory)
plt.plot(B_est)
plt.show()

# Measured B over H.
plt.plot(H_trajectory, B_trajectory)
plt.show()


# Estimated B over H.
plt.plot(H_trajectory, B_est)
plt.show()

In [None]:
# Fit the model to the measured trajectory with Adam.
optim = optax.adam(learning_rate=1e-3)
# Optimizer state is created only for the inexact (floating-point) array leaves of the model.
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

# Loss recorded at every plotting step, so convergence can be inspected afterwards.
loss_history = []

progress_bar = tqdm.tqdm(range(100_000))
for n in progress_bar:

    loss, grads = compute_loss_and_grad(model, H_trajectory, B_trajectory, alpha_beta_grid)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)

    if n % 100 == 0:
        # `loss` was evaluated before this iteration's update was applied.
        loss_value = float(loss)
        loss_history.append(loss_value)
        progress_bar.set_postfix(loss=loss_value)

        # Estimate of the updated model vs. measurement over the sample index.
        B_est = estimate_B(H_trajectory, model, alpha_beta_grid, T=T_training)
        plt.plot(B_trajectory, label="True traj")
        plt.plot(B_est, label="Est traj")
        plt.title(f"Iteration {n}, loss {loss_value:.3e}")
        plt.legend()
        plt.show()

        # Measured B over H.
        plt.plot(H_trajectory, B_trajectory)
        plt.show()

        # Estimated B over H.
        plt.plot(H_trajectory, B_est)
        plt.show()

In [None]:
# Final evaluation of the trained model on the measured field trajectory.
B_est = estimate_B(H_trajectory, model, alpha_beta_grid, T=T_training)

# Measurement and estimate over the sample index.
plt.plot(B_trajectory, label="True traj")
plt.plot(B_est, label="Est traj")
plt.legend()
plt.show()

# Measured B over H.
plt.plot(H_trajectory, B_trajectory)
plt.show()


# Estimated B over H.
plt.plot(H_trajectory, B_est)
plt.show()

# Both curves in one figure for direct comparison.
plt.plot(H_trajectory, B_trajectory)
plt.plot(H_trajectory, B_est)
plt.show()