# Model training: 

In [1]:
# optional setup
%load_ext autoreload
%autoreload 2

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # select the CUDA device
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # disable preallocation of GPU memory

import jax
#jax.config.update("jax_platform_name", "cpu")  # optionally run on cpu
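
The two environment variables in the setup cell must be in place before JAX initializes its GPU backend, so setting them before `import jax`, as done above, is the safe choice. A minimal sketch of a helper that sets them and reports whether `jax` was already imported. The name `configure_jax_env` is hypothetical and not part of `rhmag`:

```python
import os
import sys


def configure_jax_env(cuda_device: str = "0", preallocate: bool = False) -> bool:
    """Set the GPU/memory environment variables used in the setup cell.

    Returns True if `jax` was already imported, in which case the settings
    may come too late if a backend has already been initialized.
    """
    already_imported = "jax" in sys.modules
    # Restrict the process to the chosen CUDA device(s), e.g. "0" or "0,1".
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device
    # "false" makes XLA allocate GPU memory on demand instead of reserving
    # a large fraction of it up front.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    return already_imported
```

Call it at the very top of a script, before `import jax`. If it returns `True` in a notebook, restarting the kernel is the reliable way to make the settings take effect.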

In [2]:
from rhmag.runners.rnn_training_jax import train_model_jax
from rhmag.utils.model_evaluation import get_exp_ids, reconstruct_model_from_file

INFO:2026-01-27 18:06:15,927:jax._src.xla_bridge:834: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


## Training function:

More in-depth descriptions of each parameter are given in the function's docstring. Either look up the implementation in the source code (`rhmag.runners.rnn_training_jax`) or press `Shift+Tab` while the cursor is inside the function name to show the docstring (at least in `jupyter-lab` and `jupyter-notebook`).

In [3]:
train_model_jax(
    material_name="A",  # perform training for material "A"
    model_types=["GRU4", "JA"],  # all seeds below are run for every model type in this list
    seeds=[1, 2],  # one training run per seed
    exp_name="demonstration",
    loss_type="adapted_RMS",  # default loss function
    gpu_id=0,  # GPU at index 0
    epochs=10,  # very short, for demonstration only
    batch_size=512,
    tbptt_size=156,  # length of a sequence through which the loss is backpropagated
    past_size=28,  # how much of the sequence is used for warm-up
    time_shift=0,  # adds a feature in which the B trajectory is shifted by {time_shift} steps
    noise_on_data=0.0,  # adds noise to the B data during training
    tbptt_size_start=None,  # can be used to specify a pretraining phase, e.g., with shorter sequences
    dyn_avg_kernel_size=11,  # kernel size for the dynamic averaging (unused for the "reduce" feature set)
    disable_f64=True,  # if True, training is performed in f32 precision
    disable_features="reduce",  # one of (True, False, "reduce"): True uses no features, False uses all default features, "reduce" uses dB/dt and d^2B/dt^2 as features
    transform_H=False,  # whether a tanh transform of H is applied
    use_all_data=False,  # whether all data is used for training (True) or a train/val/test split is performed (False)
)

2026-01-27 18:06:16 | INFO : Starting experiments for 2 model type(s) and 2 seeds: ['GRU4', 'JA'], [1, 2]
2026-01-27 18:06:16 | INFO : --- Starting experiments for Model Type: GRU4 ---


Loading data for A: 100%|██████████| 21/21 [00:04<00:00,  5.07it/s]


2026-01-27 18:06:28 | INFO : Training starting. Experiment ID is 'A_GRU4_demonstration_a7989e94_seed1'.
2026-01-27 18:06:28 | INFO : train size: 63, val size: 21, test size: 21
2026-01-27 18:06:32 | INFO : Test loss seed 1: 7.699741 A/m



Seed 1:  10%|█         | 1/10 [00:09<01:25,  9.51s/epoch, Loss 3.54e-01| val loss 1.11e+01]

2026-01-27 18:06:47 | INFO : Test loss seed 1: 3.197392 A/m
2026-01-27 18:06:47 | INFO : Training done. Proceeding with evaluation..
2026-01-27 18:07:55 | INFO : Evaluation done. Proceeding with storing experiment data..
RNNwInterface(
  model=GRU(
    hidden_size=4,
    cell=GRUCell(
      weight_ih=f32[12,4],
      weight_hh=f32[12,4],
      bias=f32[12],
      bias_n=f32[4],
      input_size=4,
      hidden_size=4,
      use_bias=True
    )
  ),
  normalizer=Normalizer(
    B_max=0.468600869178772,
    H_max=261.848388671875,
    T_max=70,
    norm_fe_max=[0.00948895514011383, 0.005385361611843109],
    H_transform=<function FrequencySet.normalize.<locals>.<lambda>>,
    H_inverse_transform=<function FrequencySet.normalize.<locals>.<lambda>>
  ),
  featurize=partial(<function setup_featurize.<locals>.featurize>, time_shift=0)
)
2026-01-27 18:08:04 | INFO : Experiment with id 'A_GRU4_demonstration_a7989e94_seed1' finished successfully. Parameters, logs, evaluation metrics, and the mo

Loading data for A: 100%|██████████| 21/21 [00:04<00:00,  5.17it/s]


2026-01-27 18:08:16 | INFO : Training starting. Experiment ID is 'A_GRU4_demonstration_a7989e94_seed2'.
2026-01-27 18:08:16 | INFO : train size: 63, val size: 21, test size: 21
2026-01-27 18:08:18 | INFO : Test loss seed 2: 3.331386 A/m




Seed 2:  10%|█         | 1/10 [00:07<01:04,  7.14s/epoch, Loss 1.17e-01| val loss 3.70e+00]

2026-01-27 18:08:31 | INFO : Test loss seed 2: 2.068132 A/m
2026-01-27 18:08:31 | INFO : Training done. Proceeding with evaluation..
2026-01-27 18:09:27 | INFO : Evaluation done. Proceeding with storing experiment data..
RNNwInterface(
  model=GRU(
    hidden_size=4,
    cell=GRUCell(
      weight_ih=f32[12,4],
      weight_hh=f32[12,4],
      bias=f32[12],
      bias_n=f32[4],
      input_size=4,
      hidden_size=4,
      use_bias=True
    )
  ),
  normalizer=Normalizer(
    B_max=0.468600869178772,
    H_max=261.848388671875,
    T_max=70,
    norm_fe_max=[0.00948895514011383, 0.005385361611843109],
    H_transform=<function FrequencySet.normalize.<locals>.<lambda>>,
    H_inverse_transform=<function FrequencySet.normalize.<locals>.<lambda>>
  ),
  featurize=partial(<function setup_featurize.<locals>.featurize>, time_shift=0)
)
2026-01-27 18:09:37 | INFO : Experiment with id 'A_GRU4_demonstration_a7989e94_seed2' finished successfully. Parameters, logs, evaluation metrics, and the mo

Loading data for A: 100%|██████████| 21/21 [00:04<00:00,  4.93it/s]


2026-01-27 18:09:48 | INFO : Training starting. Experiment ID is 'A_JA_demonstration_49516c69_seed1'.
2026-01-27 18:09:48 | INFO : train size: 63, val size: 21, test size: 21
2026-01-27 18:09:50 | INFO : Test loss seed 1: 0.447387 A/m



Seed 1:  10%|█         | 1/10 [00:07<01:04,  7.21s/epoch, Loss 1.08e-02| val loss 5.64e-01]

2026-01-27 18:10:03 | INFO : Test loss seed 1: 0.454425 A/m
2026-01-27 18:10:03 | INFO : Training done. Proceeding with evaluation..
2026-01-27 18:10:55 | INFO : Evaluation done. Proceeding with storing experiment data..
JAwInterface(
  model=JAStatic(
    Ms_param=f32[],
    a_param=f32[],
    alpha_param=f32[],
    k_param=f32[],
    c_param=f32[]
  ),
  normalizer=Normalizer(
    B_max=0.468600869178772,
    H_max=261.848388671875,
    T_max=70,
    norm_fe_max=[0.00948895514011383, 0.005385361611843109],
    H_transform=<function FrequencySet.normalize.<locals>.<lambda>>,
    H_inverse_transform=<function FrequencySet.normalize.<locals>.<lambda>>
  ),
  featurize=partial(<function setup_featurize.<locals>.featurize>, time_shift=0)
)
2026-01-27 18:11:05 | INFO : Experiment with id 'A_JA_demonstration_49516c69_seed1' finished successfully. Parameters, logs, evaluation metrics, and the model have been stored successfully.


Loading data for A: 100%|██████████| 21/21 [00:04<00:00,  4.53it/s]


2026-01-27 18:11:17 | INFO : Training starting. Experiment ID is 'A_JA_demonstration_49516c69_seed2'.
2026-01-27 18:11:17 | INFO : train size: 63, val size: 21, test size: 21
2026-01-27 18:11:19 | INFO : Test loss seed 2: 0.448741 A/m




Seed 2:  10%|█         | 1/10 [00:07<01:06,  7.41s/epoch, Loss 1.08e-02| val loss 5.71e-01]

2026-01-27 18:11:32 | INFO : Test loss seed 2: 0.437880 A/m
2026-01-27 18:11:32 | INFO : Training done. Proceeding with evaluation..
2026-01-27 18:12:20 | INFO : Evaluation done. Proceeding with storing experiment data..
JAwInterface(
  model=JAStatic(
    Ms_param=f32[],
    a_param=f32[],
    alpha_param=f32[],
    k_param=f32[],
    c_param=f32[]
  ),
  normalizer=Normalizer(
    B_max=0.468600869178772,
    H_max=261.848388671875,
    T_max=70,
    norm_fe_max=[0.00948895514011383, 0.005385361611843109],
    H_transform=<function FrequencySet.normalize.<locals>.<lambda>>,
    H_inverse_transform=<function FrequencySet.normalize.<locals>.<lambda>>
  ),
  featurize=partial(<function setup_featurize.<locals>.featurize>, time_shift=0)
)
2026-01-27 18:12:29 | INFO : Experiment with id 'A_JA_demonstration_49516c69_seed2' finished successfully. Parameters, logs, evaluation metrics, and the model have been stored successfully.
2026-01-27 18:12:30 | INFO : All scheduled experiments complete
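
To make `tbptt_size` and `past_size` from the training call more concrete, here is a minimal NumPy sketch of one plausible reading of these parameters. This reading is an assumption, not taken from the `rhmag` source: each long trajectory is cut into windows of `tbptt_size` steps, the first `past_size` steps of every window only warm up the hidden state, and the loss is evaluated on the remaining steps. The function names are hypothetical, non-overlapping windows are used for simplicity, and `masked_rms` is a plain RMS error, not the `adapted_RMS` loss:

```python
import numpy as np


def split_into_tbptt_windows(sequence: np.ndarray, tbptt_size: int, past_size: int):
    """Cut a (time, ...) array into non-overlapping windows of tbptt_size steps.

    Returns the windows with shape (n_windows, tbptt_size, ...) and a boolean
    mask of shape (tbptt_size,) that is False on the warm-up steps.
    """
    if not 0 <= past_size < tbptt_size:
        raise ValueError("past_size must be smaller than tbptt_size")
    n_windows = sequence.shape[0] // tbptt_size  # trailing remainder is dropped
    trimmed = sequence[: n_windows * tbptt_size]
    windows = trimmed.reshape(n_windows, tbptt_size, *sequence.shape[1:])
    loss_mask = np.arange(tbptt_size) >= past_size
    return windows, loss_mask


def masked_rms(prediction: np.ndarray, target: np.ndarray, loss_mask: np.ndarray) -> float:
    """Plain RMS error over the non-warm-up steps of every window."""
    squared_error = (prediction - target)[:, loss_mask] ** 2
    return float(np.sqrt(squared_error.mean()))


# Example with the values from the training call and a synthetic 1000-step trajectory.
trajectory = np.sin(np.linspace(0.0, 20.0 * np.pi, 1000))
windows, mask = split_into_tbptt_windows(trajectory, tbptt_size=156, past_size=28)
print(windows.shape, int(mask.sum()))
```

Under this reading, gradients flow through at most `tbptt_size` steps per window, while the warm-up steps only serve to bring the recurrent state into a realistic condition before the loss is computed.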

This can also be run from the command line:
- `python rhmag/runners/rnn_training_jax.py -h` shows all options
- example call: `python rhmag/runners/rnn_training_jax.py --material "A" --model_types "GRU4" "JA" --seeds 1 2 3 --loss_type "adapted_RMS" --exp_name "demonstration" --gpu_id 0 -e 10 -b 512 -t 156 -p 28 --disable_f64`
- Because of the number of options, it is generally easier to create a script in which all parameters are already set (e.g., copy the cell above into a Python script)
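
If you prefer to keep launching through the CLI (for instance, to run each experiment in a fresh process), the command can also be assembled from Python and started with `subprocess`. A sketch that uses only the flags from the example call; the mapping of `-e`, `-b`, `-t`, `-p` to epochs, batch size, TBPTT size, and past size is inferred from the values and should be checked against `-h`. Both helper names are hypothetical:

```python
import subprocess
import sys


def build_training_command(material, model_types, seeds, loss_type, exp_name,
                           gpu_id, epochs, batch_size, tbptt_size, past_size,
                           disable_f64=True,
                           script="rhmag/runners/rnn_training_jax.py"):
    """Assemble the argument list for the training CLI (path relative to the repo root)."""
    cmd = [sys.executable, script,  # same interpreter/environment as the caller
           "--material", material,
           "--model_types", *model_types,
           "--seeds", *map(str, seeds),
           "--loss_type", loss_type,
           "--exp_name", exp_name,
           "--gpu_id", str(gpu_id),
           "-e", str(epochs),       # assumed: number of epochs
           "-b", str(batch_size),   # assumed: batch size
           "-t", str(tbptt_size),   # assumed: TBPTT sequence length
           "-p", str(past_size)]    # assumed: warm-up length
    if disable_f64:
        cmd.append("--disable_f64")
    return cmd


def run_training(cmd):
    """Launch the training script; raises CalledProcessError if it fails."""
    subprocess.run(cmd, check=True)


cmd = build_training_command("A", ["GRU4", "JA"], [1, 2, 3], "adapted_RMS",
                             "demonstration", gpu_id=0, epochs=10, batch_size=512,
                             tbptt_size=156, past_size=28)
print(" ".join(cmd))
```

Passing a list (rather than one shell string) avoids quoting issues with the multi-valued `--model_types` and `--seeds` arguments.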

In [4]:
exp_ids = get_exp_ids(exp_name="demonstration")
exp_ids

['A_JA_demonstration_c48c1f4f_seed2',
 'A_JA_demonstration_c48c1f4f_seed1',
 'A_JA_demonstration_49516c69_seed2',
 'A_GRU4_demonstration_a7989e94_seed2',
 'A_JA_demonstration_c48c1f4f_seed3',
 'A_JA_demonstration_d2f87cb8_seed1',
 'A_JA_demonstration_49516c69_seed1',
 'A_GRU4_demonstration_a7989e94_seed1',
 'A_GRU4_demonstration_a29a66a2_seed1',
 'A_GRU4_demonstration_dad948ee_seed1',
 'A_GRU4_demonstration_a29a66a2_seed2',
 'A_GRU4_demonstration_dad948ee_seed2',
 'A_GRU4_demonstration_dad948ee_seed3']
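
The list also contains IDs from earlier runs under the same `exp_name` (other hex tags such as `c48c1f4f`, and a seed 3) and is not sorted, so `exp_ids[1]` in the next cell is not necessarily one of the models trained above. A sketch for picking out specific runs, assuming the ID format `<material>_<model type>_<exp_name>_<tag>_seed<n>` visible in the output, where the 8-character tag presumably identifies the training configuration. `parse_exp_id` and `select` are hypothetical helpers:

```python
from typing import NamedTuple


class ExpId(NamedTuple):
    material: str
    model_type: str
    exp_name: str
    run_tag: str
    seed: int


def parse_exp_id(exp_id: str) -> ExpId:
    """Split an ID like 'A_GRU4_demonstration_a7989e94_seed1' into its parts.

    Assumes material and model type contain no underscores; exp_name may.
    """
    material, model_type, *name_parts, run_tag, seed_part = exp_id.split("_")
    if not seed_part.startswith("seed") or not name_parts:
        raise ValueError(f"unexpected experiment ID format: {exp_id!r}")
    return ExpId(material, model_type, "_".join(name_parts), run_tag,
                 int(seed_part[len("seed"):]))


def select(exp_ids, model_type=None, run_tag=None):
    """Return the IDs matching the given model type and/or tag, sorted by seed."""
    matches = []
    for exp_id in exp_ids:
        parts = parse_exp_id(exp_id)
        if model_type is not None and parts.model_type != model_type:
            continue
        if run_tag is not None and parts.run_tag != run_tag:
            continue
        matches.append((parts.seed, exp_id))
    return [exp_id for _, exp_id in sorted(matches)]


sample_ids = [
    "A_JA_demonstration_c48c1f4f_seed1",
    "A_JA_demonstration_49516c69_seed2",
    "A_GRU4_demonstration_a7989e94_seed2",
    "A_JA_demonstration_49516c69_seed1",
    "A_GRU4_demonstration_a7989e94_seed1",
]
print(select(sample_ids, model_type="JA", run_tag="49516c69"))
# In the notebook: select(exp_ids, model_type="JA", run_tag="49516c69")
```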

In [5]:
model = reconstruct_model_from_file(exp_ids[1])
model

Found model file at '/home/hvater/projects/RHINO-MAG/data/models/A_JA_demonstration_c48c1f4f_seed1.eqx'. Loading model..


JAwInterface(
  model=JAStatic(
    Ms_param=f32[],
    a_param=f32[],
    alpha_param=f32[],
    k_param=f32[],
    c_param=f32[]
  ),
  normalizer=Normalizer(
    B_max=0.468600869178772,
    H_max=261.848388671875,
    T_max=70,
    norm_fe_max=[0.00948895514011383, 0.005385361611843109],
    H_transform=<function Normalizer.from_dict.<locals>.<lambda>>,
    H_inverse_transform=<function Normalizer.from_dict.<locals>.<lambda>>
  ),
  featurize=partial(<function setup_featurize.<locals>.featurize>, time_shift=0)
)
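
The printed `Normalizer` stores the maximum magnitudes `B_max`, `H_max`, and `T_max` together with a forward and inverse transform for H. A minimal sketch of what such a normalization could look like, assuming plain division by the stored maxima and, for the `transform_H=True` case, a `tanh` applied to the scaled H. The class `SimpleNormalizer` is hypothetical and does not reproduce `rhmag`'s `Normalizer`; the maxima are the values printed above for material A:

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class SimpleNormalizer:
    B_max: float
    H_max: float
    tanh_H: bool = False  # mimics the transform_H flag (assumed semantics)

    def normalize_B(self, B):
        return np.asarray(B) / self.B_max

    def normalize_H(self, H):
        h = np.asarray(H) / self.H_max
        return np.tanh(h) if self.tanh_H else h

    def denormalize_H(self, h):
        h = np.asarray(h)
        if self.tanh_H:
            h = np.arctanh(h)  # inverse of tanh, valid for |h| < 1
        return h * self.H_max


norm = SimpleNormalizer(B_max=0.468600869178772, H_max=261.848388671875, tanh_H=True)
H = np.array([-261.8, -50.0, 0.0, 120.0, 261.8])  # field strength in A/m
h = norm.normalize_H(H)
print(h)
print(np.allclose(norm.denormalize_H(h), H))
```

With the `tanh` variant, the normalized H stays inside (-1, 1) even for fields beyond `H_max`; predictions then have to pass through the inverse transform before errors are reported in A/m.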