In [None]:
# Automatically reload edited modules (e.g. rhmag) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Configure the device and XLA memory behaviour before jax is imported, so the
# settings are already in the environment when JAX initialises its backend.
os.environ.update(
    CUDA_VISIBLE_DEVICES="0",  # choose cuda-device
    XLA_PYTHON_CLIENT_PREALLOCATE="false",  # disable preallocation of memory
)

import jax

In [None]:
from rhmag.runners.rnn_training_jax import train_model_jax
from rhmag.utils.model_evaluation import get_exp_ids, reconstruct_model_from_file

In [None]:
from rhmag.data_management import DataSet, FINAL_MATERIALS
from rhmag.utils.final_data_evaluation import (
    FINAL_MATERIALS, TestSet, ResultSet, predict_test_scenarios, validate_result_set, visualize_result_set, evaluate_test_scenarios
)

---

In [None]:
# Load the data set for every material listed in FINAL_MATERIALS.
data_set = DataSet.from_material_names(FINAL_MATERIALS)

In [None]:
# Inspect the loaded data set.
data_set

In [None]:
# One TestSet per material, keyed by material name (same order as FINAL_MATERIALS).
test_data = dict(
    (name, TestSet.from_material_name(name)) for name in FINAL_MATERIALS
)

In [None]:
# Experiment IDs of the per-material reference models; every ID follows the
# pattern "<material>_GRU8_reduced-features-f32_<hash>_seed12".
exp_ids = dict(
    A="A_GRU8_reduced-features-f32_2a1473b6_seed12",
    B="B_GRU8_reduced-features-f32_c785b2c3_seed12",
    C="C_GRU8_reduced-features-f32_348e220c_seed12",
    D="D_GRU8_reduced-features-f32_b6ac55b5_seed12",
    E="E_GRU8_reduced-features-f32_e88a2583_seed12",
)
# Rebuild each model from its experiment ID (reconstruct_model_from_file presumably
# loads saved artifacts for that run — confirm), keyed by material name.
models = {name: reconstruct_model_from_file(exp_ids[name]) for name in exp_ids}

In [None]:
# Baseline: the material-B model evaluated on the material-B test scenarios.
metrics = evaluate_test_scenarios(models["B"], test_data["B"])
metrics

In [None]:
# Cross-material check: the material-C model evaluated on the material-B test scenarios.
metrics = evaluate_test_scenarios(models["C"], test_data["B"])
metrics

In [None]:
# Evaluate every per-material model on every material's test set; results are keyed
# "model_<model material>_data_<data material>".
all_metrics = {}
for material_name_model in models:
    model = models[material_name_model]
    for material_name_data in test_data:
        material_data = test_data[material_name_data]
        metrics = evaluate_test_scenarios(model, material_data)
        all_metrics[f"model_{material_name_model}_data_{material_name_data}"] = metrics

In [None]:
# Metrics for every (model material, data material) pair.
all_metrics

## Train a model on material C

In [None]:
# Train a GRU8 model on material C only, tagged with exp_name "crossmaterial".
c_training_config = dict(
    material_names=["C"],
    model_types=["GRU8"],
    seeds=[12],
    exp_name="crossmaterial",
    loss_type="adapted_RMS",
    gpu_id=0,
    epochs=1000,
    batch_size=512,
    tbptt_size=156,
    past_size=28,
    time_shift=0,
    noise_on_data=0.0,
    tbptt_size_start=None,
    dyn_avg_kernel_size=11,
    disable_f64=True,
    disable_features="reduce",
    transform_H=False,
    use_all_data=False,
)
train_model_jax(**c_training_config)

In [None]:
# Load the material-C model (presumably the "crossmaterial" seed-12 run from the
# training cell above) and evaluate it on the material-B test scenarios.
# NOTE(review): the run hash "bed6f10f" is hard-coded; update this ID if the C model is retrained.
C_model = reconstruct_model_from_file('C_GRU8_crossmaterial_bed6f10f_seed12')
metrics = evaluate_test_scenarios(C_model, test_data["B"])
metrics

In [None]:
from rhmag.model_interfaces.rnn_interfaces import RNNwInterface

In [None]:
# Hybrid ("frankenstein") model: recurrent core and featurizer taken from the C model,
# normalizer taken from the material-B reference model; scored on material-B data.
core_from_C = C_model.model
normalizer_from_B = models["B"].normalizer
featurizer_from_C = C_model.featurize
frankenstein_model = RNNwInterface(
    model=core_from_C,
    normalizer=normalizer_from_B,
    featurize=featurizer_from_C,
)
metrics = evaluate_test_scenarios(frankenstein_model, test_data["B"])
metrics

## Post-training (fine-tuning) on material-B data

In [None]:
# Inspect the loaded C model.
C_model

In [None]:
# Inspect the hybrid model (C core/featurizer, B normalizer).
frankenstein_model

In [None]:
import optax
import equinox as eqx
from rhmag.losses import adapted_RMS_loss
from rhmag.model_setup import setup_dataset
from rhmag.runners.rnn_training_jax import train_model

In [None]:
material = "B"  # material whose data the hybrid model is fine-tuned on

# PRNG setup. NOTE: model_key is not used below, because fine-tuning starts from an
# existing model; the 3-way split is kept so training_key stays reproducible.
seed = 44
key = jax.random.PRNGKey(seed)
key, training_key, model_key = jax.random.split(key, 3)

# Exponential LR decay: held at 1e-3 for the first 2_000 steps, then decays by a
# factor of 0.1 per 1_000_000 steps, bounded below by 1e-4.
lr_params = {
    "init_value": 1e-3,
    "transition_steps": 1_000_000,
    "transition_begin": 2_000,
    "decay_rate": 0.1,
    "end_value": 1e-4,
}
lr_schedule = optax.schedules.exponential_decay(**lr_params)
# inject_hyperparams exposes the scheduled learning rate in the optimizer state.
optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=lr_schedule)

loss_function = eqx.filter_value_and_grad(adapted_RMS_loss)
# NOTE(review): use_all_data=True here (the C-model training used False); confirm this
# data does not overlap the scenarios in test_data["B"] that are evaluated afterwards.
data_tuple = setup_dataset(
    material_name=material,
    subsampling_freq=1,
    use_all_data=True,
)

In [None]:
# Fine-tune the hybrid model (C core + B normalizer) on material-B data.
# train_model returns (logs, model); `model` rebinds the loop variable left over
# from the cross-evaluation cell earlier.
logs, model = train_model(
    model=frankenstein_model,
    loss_function=loss_function,
    optimizer=optimizer,
    material_name=material,
    data_tuple=data_tuple,
    key=training_key,
    seed=seed,
    n_steps=0,  # NOTE(review): presumably the run length is governed by n_epochs when n_steps=0 — confirm
    n_epochs=1000,  # a 100-epoch run already showed improvements
    val_every=1000,
    tbptt_size=128,
    past_size=28,
    batch_size=512,
    time_shift=0,
    noise_on_data=0.0,
)

In [None]:
# Fine-tuned hybrid model on the material-B test scenarios.
metrics = evaluate_test_scenarios(model, test_data["B"])
metrics

In [None]:
# Reference for comparison: the material-B model on the same material-B test scenarios.
metrics = evaluate_test_scenarios(models["B"], test_data["B"])
metrics

## Drafting a cross-material model

In [None]:
# Load the "pretraining" run with material tag "X" (presumably trained across several
# materials — confirm) to evaluate it on every material below.
cross_material_model = reconstruct_model_from_file('X_GRU8_pretraining_1a23d676_seed42')

In [None]:
# Inspect the cross-material model.
cross_material_model

In [None]:
# Evaluate the cross-material model on every material's test set.
# NOTE: this rebinds all_metrics, discarding the per-material cross-evaluation results.
all_metrics = {}
for material_name_data in test_data:
    material_data = test_data[material_name_data]
    metrics = evaluate_test_scenarios(cross_material_model, material_data)
    all_metrics[f"model_X_data_{material_name_data}"] = metrics

In [None]:
# Metrics of the cross-material model, one entry per test material.
all_metrics