In [None]:
# Reload edited project modules (e.g. mc2.*) before each cell executes, so
# library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# GPU/XLA environment. These are presumably only honored when set before JAX
# initializes its backends, which is why they sit above the `import jax` cell.
# - CUDA_VISIBLE_DEVICES: expose only GPU 1 to this process.
# - XLA_PYTHON_CLIENT_PREALLOCATE=false: allocate GPU memory on demand instead
#   of preallocating a large share up front.
# NOTE(review): the later `jax.config.update("jax_platform_name", "cpu")` makes
# the CPU the default backend, so these GPU settings may have little effect.
os.environ.update(
    CUDA_VISIBLE_DEVICES="1",
    XLA_PYTHON_CLIENT_PREALLOCATE="false",
)

import pathlib
import glob
from tqdm.notebook import tqdm
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from mc2.utils.data_inspection import (
    get_available_material_names, get_file_overview, load_and_process_single_from_full_file_overview,
)
from mc2.utils.data_plotting import plot_single_sequence, plot_hysteresis
from mc2.data_management import FrequencySet, MaterialSet, DataSet

In [None]:
import jax
# Make the CPU JAX's default backend and keep 32-bit precision
# (x64 is disabled by default in JAX; it is set explicitly here for clarity).
for _option, _value in (("jax_platform_name", "cpu"), ("jax_enable_x64", False)):
    jax.config.update(_option, _value)
import jax.numpy as jnp
import equinox as eqx
import optax

In [None]:
from mc2.data_management import load_data_into_pandas_df, MaterialSet

In [None]:
from mc2.training.data_sampling import draw_data_uniformly
from mc2.runners.model_setup_jax import setup_model
from mc2.utils.model_evaluation import reconstruct_model_from_exp_id, get_exp_ids

In [None]:
# Build the HNODE model for material 3C90. `setup_model` returns the wrapped
# model, an optimizer, params and a (train, eval, test) split. The splits bound
# here are rebound by the `get_normalizer` cell further down.
hnode_setup_kwargs = dict(
    model_label="HNODE",
    material_name="3C90",
    model_key=jax.random.PRNGKey(0),
    n_epochs=300,
    tbptt_size=128,
    batch_size=512,
)
wrapped_model, optimizer, params, (train_set, eval_set, test_set) = setup_model(**hnode_setup_kwargs)

In [None]:
from mc2.model_interfaces.model_interface import count_model_parameters

In [None]:
# Parameter count of the wrapped model, shown as the cell output.
n_model_params = count_model_parameters(wrapped_model)
n_model_params

In [None]:
# Display the wrapped model's repr to inspect how it is composed.
wrapped_model

In [None]:
from mc2.training.jax_routine import val_test
from mc2.runners.model_setup_jax import get_normalizer

from mc2.utils.model_evaluation import plot_first_predictions

In [None]:
# Rebuild the 3C90 train/eval/test splits with the model's own featurization and
# normalization enabled (subsampling_freq=1 presumably means no subsampling).
# The first return value (presumably the normalizer) is not used here.
_, (train_set, eval_set, test_set) = get_normalizer(
    "3C90",
    wrapped_model.featurize,
    subsampling_freq=1,
    do_normalization=True,
)

In [None]:
# Evaluate on the test split: returns the loss plus indexable collections of
# predictions and ground truths (the first entry of each is plotted below).
test_loss, test_pred_l, test_gt_l = val_test(
    test_set,
    wrapped_model,
    past_size=1,
)

In [None]:
# Compare the first ground-truth entry against the first model prediction.
first_gt, first_pred = test_gt_l[0], test_pred_l[0]
plot_first_predictions(first_gt, first_pred);

In [None]:
def correlate_B_and_H(B, H):
    """Full cross-correlation of the mean-centred signals ``B`` and ``H``.

    Uses the ``numpy.correlate`` convention ``c[k] = sum_n B[n + k] * H[n]``
    with ``mode="full"``: the output holds one value per lag ``k`` from
    ``-(len(H) - 1)`` to ``len(B) - 1`` (length ``len(B) + len(H) - 1``).

    Args:
        B: 1-D signal.
        H: 1-D signal.

    Returns:
        1-D array of correlation values ordered by increasing lag.
    """
    # Centring keeps a constant offset in either signal from dominating the peak.
    return jnp.correlate(
        B - jnp.mean(B),
        H - jnp.mean(H),
        mode="full",
    )


def best_correlation_point(B, H):
    """Lag (in samples) at which ``B`` and ``H`` are most strongly correlated.

    A positive result ``d`` means ``B`` lags ``H`` by ``d`` samples, i.e.
    ``B[n + d]`` best matches ``H[n]``.

    Args:
        B: 1-D signal.
        H: 1-D signal; may differ in length from ``B``.

    Returns:
        Scalar integer lag in ``[-(len(H) - 1), len(B) - 1]``. Ties resolve to
        the smallest lag because ``jnp.argmax`` returns the first maximum.
    """
    correlation_values = correlate_B_and_H(B, H)
    # Lag axis matching the "full" correlation output. It is derived from both
    # lengths so the index -> lag mapping stays correct when len(B) != len(H).
    lags = jnp.arange(-(H.shape[0] - 1), B.shape[0])
    return lags[jnp.argmax(correlation_values)]

In [None]:
# Lag sanity check on random test draws: for each seed, draw one sample per test
# frequency (draw_data_uniformly(..., 2000, 1, key) -- presumably sequence
# length 2000, batch size 1; TODO confirm), estimate the B-vs-H lag per
# frequency, and record the mean lag across frequencies.
batched_best_correlation_point = eqx.filter_vmap(best_correlation_point)

best_matches = []
for seed in tqdm(range(1, 1000)):
    loader_key = jax.random.PRNGKey(seed)
    H_list, B_list = [], []

    for frequency in test_set.frequencies:
        test_set_at_frequency = test_set.at_frequency(frequency)
        H, B, _, _, loader_key = draw_data_uniformly(test_set_at_frequency, 2000, 1, loader_key)
        H_list.append(H)
        B_list.append(B)

    # Leading axis = frequency, so the lag search is vmapped over frequencies.
    H = jnp.stack(H_list, axis=0)
    B = jnp.stack(B_list, axis=0)

    best_matches_per_frequency = batched_best_correlation_point(B, H)
    best_matches.append(jnp.mean(best_matches_per_frequency))

In [None]:
# Distribution, over random seeds, of the frequency-averaged B-vs-H lag.
fig, ax = plt.subplots()
ax.hist(best_matches)
ax.set(
    xlabel="mean best-correlation lag [samples]",
    ylabel="count (seeds)",
    title="Lag between B and H on random test draws",
)
ax.grid(True, alpha=0.3)
plt.show()

In [None]:
from mc2.data_management import AVAILABLE_MATERIALS, load_data_into_pandas_df, MaterialSet

In [None]:
def compute_best_lags_per_frequency(mat_set, chunk_size=400, n_samples=2000):
    """Best-correlation lag of every B/H sequence pair, grouped by frequency.

    Args:
        mat_set: material data set; iterating it yields per-frequency sets with
            ``frequency``, ``B`` and ``H`` attributes. ``B``/``H`` are 2-D and
            chunked along their first axis (assumed ``[sequence, time]``).
        chunk_size: sequences per vmapped call (presumably to bound memory).
        n_samples: only the first ``n_samples`` entries along the second axis
            are used.

    Returns:
        Dict mapping ``str(frequency)`` (in ``mat_set.frequencies`` order) to a
        1-D array with one lag per sequence.
    """
    batched_best_lag = eqx.filter_vmap(best_correlation_point)
    chunks = {str(freq): [] for freq in mat_set.frequencies}
    for freq_set in mat_set:
        n_sequences = freq_set.B.shape[0]
        for start in range(0, n_sequences, chunk_size):
            # Slices past the end are clamped, so the last chunk may be shorter.
            stop = start + chunk_size
            chunks[str(freq_set.frequency)].append(
                batched_best_lag(
                    freq_set.B[start:stop, :n_samples],
                    freq_set.H[start:stop, :n_samples],
                )
            )
    # A frequency without sequences yields an empty array instead of crashing
    # jnp.concatenate on an empty list.
    return {
        freq: jnp.concatenate(lags) if lags else jnp.zeros((0,), dtype=jnp.int32)
        for freq, lags in chunks.items()
    }


def plot_lag_histograms(best_lags, title):
    """Side-by-side histograms of the lags, one panel per frequency; returns the figure."""
    # squeeze=False keeps a 2-D axes grid even for a single frequency.
    fig, axs = plt.subplots(1, len(best_lags), figsize=(16, 4), squeeze=False)
    for ax, (freq, lags) in zip(axs[0], best_lags.items()):
        ax.hist(lags)
        ax.grid(True, alpha=0.3)
        ax.set_title(freq)
    # Set the suptitle before tight_layout so the layout reserves space for it.
    fig.suptitle(title)
    fig.tight_layout()
    return fig


SAVE_LAG_FIGURES = False
chunk_size = 400

best_matches_per_material = dict()
for material_name in tqdm(AVAILABLE_MATERIALS):
    mat_set = MaterialSet.from_pandas_dict(load_data_into_pandas_df(material_name))

    matches_per_freq = compute_best_lags_per_frequency(mat_set, chunk_size=chunk_size, n_samples=2000)
    best_matches_per_material[material_name] = matches_per_freq

    fig = plot_lag_histograms(matches_per_freq, title=material_name)
    if SAVE_LAG_FIGURES:
        # Must happen before plt.show(), which hands the figure off to the backend.
        fig.savefig(f"{material_name}_time_shift_over_frequency.png", dpi=300, bbox_inches="tight")
    plt.show()

In [None]:
# Inspect the raw per-material, per-frequency lag arrays (can be a large output).
best_matches_per_material

In [None]:
# Reduce each per-frequency lag array to its mean, giving a nested mapping
# material -> frequency -> mean lag (plain Python float).
data_dict = {name: {} for name in AVAILABLE_MATERIALS}
for material_name, mat_best_match in best_matches_per_material.items():
    data_dict[material_name].update(
        (freq, jnp.mean(lags).item()) for freq, lags in mat_best_match.items()
    )

In [None]:
# Outer dict keys become columns: rows are frequencies, columns are materials.
data_df = pd.DataFrame.from_dict(data_dict, orient="columns")
data_df.round(decimals=2)

In [None]:
# Mean B-vs-H lag over frequency, one line per material (legend identifies them).
fig, ax = plt.subplots()
data_df.plot(ax=ax, marker="o")
ax.set(
    xlabel="frequency",
    ylabel="mean best-correlation lag [samples]",
    title="Mean lag between B and H per material",
)
ax.grid(True, alpha=0.3)
ax.legend(title="material")
plt.show()