In [None]:
# Reload edited project modules (e.g. rhmag) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# These environment variables are read when JAX/XLA initializes, so they must be set
# before `import jax` below.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # choose cuda-device
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"  # disable preallocation of memory

import jax
# Makes CPU the default JAX platform. While this line is active, the GPU-related
# settings above presumably have no effect on where computations run; remove it to use
# the selected CUDA device.
jax.config.update("jax_platform_name", "cpu")

In [None]:
from rhmag.runners.rnn_training_jax import train_model_jax
from rhmag.utils.model_evaluation import get_exp_ids, reconstruct_model_from_file

In [None]:
import optax
import equinox as eqx
from rhmag.losses import adapted_RMS_loss
from rhmag.model_setup import setup_dataset
from rhmag.runners.rnn_training_jax import train_model

from rhmag.data_management import DataSet, FINAL_MATERIALS
from rhmag.utils.final_data_evaluation import (
    FINAL_MATERIALS, TestSet, ResultSet, predict_test_scenarios, validate_result_set, visualize_result_set, evaluate_test_scenarios
)

In [None]:
# Experiment ids of the final per-material GRU8 models (reduced features, f32),
# keyed by material name.
exp_ids = dict(
    A='A_GRU8_reduced-features-f32_2a1473b6_seed12',
    B='B_GRU8_reduced-features-f32_c785b2c3_seed12',
    C='C_GRU8_reduced-features-f32_348e220c_seed12',
    D='D_GRU8_reduced-features-f32_b6ac55b5_seed12',
    E='E_GRU8_reduced-features-f32_e88a2583_seed12',
)

# Rebuild every model from its stored experiment files, keeping the same keys.
models = {
    name: reconstruct_model_from_file(run_id)
    for name, run_id in exp_ids.items()
}

In [None]:
# One test set per final material, keyed by material name.
test_data = {
    name: TestSet.from_material_name(name)
    for name in FINAL_MATERIALS
}

---

- **TODO:** let `setup_model` automatically pick up (load) the pretrained model, instead of loading it manually in the cells below.

In [None]:
# Show the available cross-material ("X") pretraining runs.
# NOTE(review): this query filters on model_type="GRU", whereas the `pretrained_id`
# lookup further down filters on "GRU8" — confirm the broader filter is intended.
get_exp_ids(material_name="X", model_type="GRU", exp_name="pretraining")

In [None]:
# Load a specific cross-material ("X") GRU8 pretraining run by its hard-coded id.
# This model (its network and featurizer) is the basis of `frankenstein_model`, which is
# the model fine-tuned by the training cell.
cross_material_model = reconstruct_model_from_file('X_GRU8_pretraining_11ff4081_seed42')

In [None]:
from rhmag.model_interfaces.rnn_interfaces import RNNwInterface

In [None]:
# "Frankenstein" model: the network and featurizer of the cross-material (X) pretrained
# model, combined with the normalizer of the dedicated material-B model.
frankenstein_model = RNNwInterface(
    model=cross_material_model.model,
    normalizer=models["B"].normalizer,
    featurize=cross_material_model.featurize
)

# Compare all three candidates on the material-B test set. Each printout is labelled so
# the results can be told apart; the evaluation order is fixed by the dict, so `metrics`
# ends up holding the dedicated B model's result.
comparison_candidates = {
    "X-pretrained + B normalizer (frankenstein)": frankenstein_model,
    "X-pretrained (as loaded)": cross_material_model,
    "dedicated B model": models["B"],
}
for candidate_name, candidate_model in comparison_candidates.items():
    metrics = evaluate_test_scenarios(candidate_model, test_data["B"])
    print(f"{candidate_name}: {metrics}")

In [None]:
# --- Reproducibility: derive all PRNG keys from one seed ---
seed = 44
key = jax.random.PRNGKey(seed)
key, training_key, model_key = jax.random.split(key, 3)  # model_key currently appears unused

# --- Material to fine-tune on ---
material = "B"

# --- Optimizer: Adam driven by an exponentially decaying learning-rate schedule ---
# Held at init_value for the first `transition_begin` steps, then decayed by
# `decay_rate` per `transition_steps`, clipped at `end_value`.
lr_params = {
    "init_value": 1e-3,
    "transition_steps": 1_000_000,
    "transition_begin": 2_000,
    "decay_rate": 0.1,
    "end_value": 1e-4,
}
lr_schedule = optax.schedules.exponential_decay(**lr_params)
# inject_hyperparams exposes the learning rate in the optimizer state.
optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=lr_schedule)

# --- Loss (value + gradient w.r.t. the model's array leaves) and training data ---
loss_function = eqx.filter_value_and_grad(adapted_RMS_loss)
data_tuple = setup_dataset(
    material_name=material,
    subsampling_freq=1,
    use_all_data=True,
)

In [None]:
# Look up the cross-material ("X") GRU8 pretraining runs and select the first one.
# Fails loudly if none exist, instead of raising a bare IndexError.
# NOTE(review): the ordering of get_exp_ids' result is not visible here, so "first" may be
# arbitrary when several runs match — the notice below lists them; confirm the choice.
pretrained_exp_ids = get_exp_ids(
    material_name="X",
    model_type="GRU8",
    exp_name="pretraining",
)
if len(pretrained_exp_ids) == 0:
    raise LookupError(
        "No pretraining experiment found for material 'X' with model type 'GRU8'."
    )
if len(pretrained_exp_ids) > 1:
    print(
        f"Found {len(pretrained_exp_ids)} pretraining runs, using "
        f"{pretrained_exp_ids[0]!r}: {list(pretrained_exp_ids)}"
    )
pretrained_id = pretrained_exp_ids[0]

In [None]:
# Load the pretraining run selected by `pretrained_id`.
# NOTE(review): `pretrained_model` is not used by any visible cell — the training cell
# fine-tunes `frankenstein_model`, which wraps the hard-coded run
# 'X_GRU8_pretraining_11ff4081_seed42' instead. Confirm which pretrained model is intended.
pretrained_model = reconstruct_model_from_file(pretrained_id)

In [None]:
# run training
logs, model = train_model(
    model=frankenstein_model,
    loss_function=loss_function,
    optimizer=optimizer,
    material_name=material,
    data_tuple=data_tuple,
    key=training_key,
    seed=seed,
    n_steps=0,
    n_epochs=20_000,  # for 100 it showed improvements
    val_every=1000,
    tbptt_size=128,
    past_size=28,
    batch_size=512,
    time_shift=0,
    noise_on_data=0.0,
)

In [None]:
# Evaluate a previously saved post-training run on the material's test set.
# Loaded under its own name so that `model`, returned by the training cell, is not
# overwritten and discarded.
# NOTE(review): the id suffix "seed42" suggests a different run than the one trained with
# `seed = 44` in this notebook — confirm which run should be evaluated here.
posttrained_model = reconstruct_model_from_file('B_GRU8_posttraining-f32_6c918b11_seed42')

metrics = evaluate_test_scenarios(posttrained_model, test_data[material])
metrics

In [None]:
# Baseline for comparison: the dedicated per-material model loaded at the top.
metrics = evaluate_test_scenarios(
    models[material],
    test_data[material],
)
metrics