# Model training

In [None]:
# Optional setup: the autoreload extension re-imports edited modules before each
# cell is executed, so changes to the project source are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

import os
# This environment variable is set ahead of the jax import below so that it is in place when JAX starts up.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"  # disable preallocation of memory

import jax
# Uncomment the following line to optionally run on the CPU:
#jax.config.update("jax_platform_name", "cpu")  # optionally run on cpu

In [None]:
from mc2.runners.rnn_training_jax import train_model_jax
from mc2.utils.model_evaluation import get_exp_ids, reconstruct_model_from_file

## Training function:

More in-depth descriptions of each parameter are given in the function's docstring. Either look up the implementation in the source code (at `mc2.runners.rnn_training_jax`) or press `shift+tab` while your caret is within the function name to show the docstring (this works at least in `jupyter-lab` and `jupyter-notebook`).

In [None]:
train_model_jax(
    # --- experiment identity ---
    exp_name="demonstration",
    material="A",  # train on material "A"
    model_type=["GRU4", "JA"],  # every model type in this list is trained with all of the seeds below
    seeds=[1, 2, 3],  # one training run per seed in this list
    # --- hardware and numerics ---
    gpu_id=0,  # the GPU at index 0
    disable_f64=True,  # with 'True' the training is performed in f32 accuracy
    # --- optimization ---
    loss_type="adapted_RMS",  # the default loss function
    epochs=10,  # kept very short, for demonstration only
    batch_size=512,
    # --- sequence handling ---
    tbptt_size=156,  # sequence length through which the loss is backpropagated
    tbptt_size_start=None,  # allows specifying a pretraining phase, e.g., with shorter sequences
    past_size=28,  # how much of the sequence is used for warmup
    # --- features and data ---
    disable_features="reduce",  # one of (True, False, "reduce"): True -> no features, False -> all default features, "reduce" -> dB/dt and d^2 B / dt^2 as features
    dyn_avg_kernel_size=11,  # kernel size of the dynamic averaging (unused for the "reduced" feature set)
    time_shift=0,  # adds a feature in which the B trajectory is shifted by {time_shift} steps
    noise_on_data=0.0,  # noise added to the B data during training
    transform_H=False,  # whether a tanh transform of H should be utilized
    use_all_data=False,  # whether all data is used for training, or a train, eval, test split is performed instead
)

The training can also be run from the command line:
- `python mc2/runners/rnn_training_jax.py -h` shows all options
- Example call: `python mc2/runners/rnn_training_jax.py --material "A" --model_type "GRU4" "JA" --seeds 1 2 3 --loss_type "adapted_RMS" --exp_name "demonstration" --gpu_id 0 -e 10 -b 512 -t 156 -p 28 --disable_f64`
- Due to the number of options, it is generally easier to create a script in which all parameters are already set (i.e., copy the cell above into a Python script).

In [None]:
# Look up the experiment IDs stored under the experiment name that was passed to the training call above.
# NOTE(review): presumably one ID per (model type, seed) run — confirm against `get_exp_ids`.
exp_ids = get_exp_ids(exp_name="demonstration")
exp_ids  # bare last expression: displays the IDs as the cell output

In [None]:
# Reconstruct a model for one of the experiments found above; index 1 selects the second ID in the list.
model = reconstruct_model_from_file(exp_ids[1])
model  # bare last expression: displays the reconstructed model as the cell output