In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
import glob
from tqdm.notebook import tqdm
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from mc2.utils.data_inspection import (
    get_available_material_names, get_file_overview, load_and_process_single_from_full_file_overview,
)
from mc2.utils.data_plotting import plot_single_sequence, plot_hysteresis
from mc2.data_management import FrequencySet, MaterialSet, DataSet

In [None]:
import jax
jax.config.update("jax_platform_name", "cpu")
import jax.numpy as jnp
import equinox as eqx
import optax

In [None]:
from mc2.training.data_sampling import draw_data_uniformly
from mc2.runners.model_setup_jax import setup_model
from mc2.utils.model_evaluation import reconstruct_model_from_exp_id, get_exp_ids

In [None]:
# Build the GRU model, its optimizer and parameters, plus the train/eval/test
# splits for material 3C90.
wrapped_model, optimizer, params, (train_set, eval_set, test_set) = setup_model(
    model_label="GRU",
    material_name="3C90",
    model_key=jax.random.PRNGKey(0),
    n_epochs=300,
    tbptt_size=128,
    batch_size=512,
)

In [None]:
def correlate_B_and_H(B, H):
    """Full cross-correlation of the mean-removed sequences ``B`` and ``H``.

    Args:
        B: 1-D array (``jnp.correlate`` only accepts 1-D inputs).
        H: 1-D array; its length may differ from ``B``.

    Returns:
        Array of length ``len(B) + len(H) - 1``. Entry ``i`` holds
        ``sum_n B_c[n + k] * H_c[n]`` for lag ``k = i - (len(H) - 1)``, where
        ``B_c`` / ``H_c`` are the inputs with their means subtracted.
    """
    return jnp.correlate(
        B - jnp.mean(B),
        H - jnp.mean(H),
        mode="full",
    )


def best_correlation_point(B, H):
    """Lag (in samples) at which the mean-removed ``B`` and ``H`` correlate most.

    A positive result ``k`` means ``B[n + k]`` lines up best with ``H[n]``,
    i.e. ``B`` trails ``H`` by ``k`` samples; a negative result means ``B``
    leads. On ties the most negative lag wins, because ``jnp.argmax`` returns
    the first maximum.

    Args:
        B: 1-D array.
        H: 1-D array; its length may differ from ``B``.

    Returns:
        Scalar integer array holding the best lag.
    """
    correlation_values = correlate_B_and_H(B, H)
    # The lag axis of a "full" correlation runs from -(len(H) - 1) to len(B) - 1.
    # Deriving it from both lengths keeps it the same size as
    # correlation_values even when B and H have different lengths.
    lags = jnp.arange(-(H.shape[0] - 1), B.shape[0])
    return lags[jnp.argmax(correlation_values)]

In [None]:
# For each loader seed, draw one sample per test-set frequency and record the
# best-correlation lag between B and H, averaged over the frequencies.
best_matches = []

# NOTE(review): past_size is only printed; nothing in this cell uses it.
past_size = 999
print("past_size:", past_size)

# The vmapped function does not depend on the seed, so build it once.
batched_best_correlation_point = eqx.filter_vmap(best_correlation_point)

# Plain Python ints give the same PRNG keys as iterating a jnp.arange, but
# avoid pulling device scalars out of a JAX array one at a time.
for seed in tqdm(range(1, 1000)):
    loader_key = jax.random.PRNGKey(seed)
    H_list, B_list, T_list = [], [], []

    for frequency in test_set.frequencies:
        test_set_at_frequency = test_set.at_frequency(frequency)
        # NOTE(review): 2000 and 1 are positional args of draw_data_uniformly,
        # presumably sequence length and number of sequences; confirm.
        H, B, T, _, loader_key = draw_data_uniformly(test_set_at_frequency, 2000, 1, loader_key)

        H_list.append(H[None, ...])
        B_list.append(B[None, ...])
        T_list.append(T[None, ...])

    # One leading row per frequency.
    H = jnp.concatenate(H_list, axis=0)
    B = jnp.concatenate(B_list, axis=0)
    T = jnp.concatenate(T_list, axis=0)

    # Reuse the per-frequency lags for the mean instead of running the
    # vmapped correlation a second time.
    best_matches_per_frequency = batched_best_correlation_point(B, H)
    mean_best_match = jnp.mean(best_matches_per_frequency)
    best_matches.append(mean_best_match)

In [None]:
# Histogram over loader seeds of the frequency-averaged B/H best-correlation lag.
fig, ax = plt.subplots()
ax.hist(best_matches)
ax.set(
    title="Mean B/H best-correlation lag across loader seeds",
    xlabel="mean lag (samples)",
    ylabel="count",
)
plt.show()

In [None]:
# For each training frequency: the mean B/H best-correlation lag over the
# first 100 sequences. Keys are str(frequency); each value is a one-element list.
best_matches = dict((str(f), []) for f in train_set.frequencies)

for freq_set in train_set:
    print(freq_set.frequency)
    lags_per_sequence = eqx.filter_vmap(best_correlation_point)(
        freq_set.B[:100],
        freq_set.H[:100],
    )
    mean_best_match = jnp.mean(lags_per_sequence)
    best_matches[str(freq_set.frequency)].append(mean_best_match)

In [None]:
# Display the per-frequency results when the cell above ran last:
# str(frequency) -> [mean best-correlation lag over the first 100 training sequences].
best_matches