# Training a reflectorch model

First, we import the necessary functions and classes from the reflectorch package, as well as other basic Python packages:

```python
from reflectorch import get_trainer_by_name, SAVED_MODELS_DIR, StepLR, SaveBestModel
from reflectorch.extensions.jupyter import JPlotLoss

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
```

:::{tip}
:class: dropdown

Alternatively, we can import everything from reflectorch with `from reflectorch import *`.
:::

## The training loop

The *Trainer* class is the central object of the training process: it encapsulates both the simulation of the training data and the neural network architecture. We can initialize a trainer according to the specifications defined in a YAML configuration file using the *get_trainer_by_name* function, which takes the name of the configuration file as input. The `load_weights` parameter should be set to `False`, since we want the neural network weights to be randomly initialized when training from scratch.

```python
config_name = 'time_val_sim_L5_q256_d300_r60_s25_bs4_budist_noise-poisson02_LONGER'
trainer = get_trainer_by_name(config_name, load_weights=False)
```

```text
Model time_val_sim_L5_q256_d300_r60_s25_bs4_budist_noise-poisson02_LONGER loaded. Number of parameters: 13.85 M
```


The trainer contains several important attributes which we can access:

  1. The PyTorch optimizer. We can see that the optimizer specified in the configuration is AdamW:

```python
trainer.optim
```

```text
AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 0.01
)
```

:::{note}
The learning rate can be easily changed using `trainer.set_lr(new_lr)`.
:::

  2. The batch size:

```python
trainer.batch_size
```

```text
4096
```

  3. The neural network architecture. We see that the model belongs to the class *SubPriorConvFCEncoder_V2*, which consists of a 1D convolutional (CNN) embedding network and a multilayer perceptron (MLP) with residual connections, batch normalization layers and GELU activations.

```python
trainer.model
```

```text
SubPriorConvFCEncoder_V2(
  (conv): ConvEncoder(
    (core): Sequential(
      (0): Sequential(
        (0): Conv1d(1, 32, kernel_size=(3,), stride=(2,), padding=(1,))
        (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): GELU(approximate='none')
      )
      (1): Sequential(
        (0): Conv1d(32, 64, kernel_size=(3,), stride=(2,), padding=(1,))
        (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): GELU(approximate='none')
      )
      (2): Sequential(
        (0): Conv1d(64, 128, kernel_size=(3,), stride=(2,), padding=(1,))
        (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): GELU(approximate='none')
      )
      (3): Sequential(
        (0): Conv1d(128, 256, kernel_size=(3,), stride=(2,), padding=(1,))
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): GELU(approximate
...
```

We can query the location of the directory where the models are saved using the *SAVED_MODELS_DIR* global variable:

```python
SAVED_MODELS_DIR
```

```text
PosixPath('/home/vmunteanu/Jupyter Notebooks/Reflectivity/reflectorch/saved_models')
```

Next, we choose a file name for saving the weights of our model during training; together with the directory, it defines the *save_path*.

```python
save_model_name = 'model_' + config_name + '.pt'
save_path = str(SAVED_MODELS_DIR / save_model_name)
```

We can extend the training process with several callbacks, which are grouped together in a Python tuple.

```python
callbacks = (
    JPlotLoss(frequency=10),
    StepLR(step_size=5000, gamma=0.1, last_epoch=-1),
    SaveBestModel(path=save_path, freq=100),
)
```

The callbacks defined above provide the following functionality:

  1. JPlotLoss - enables the interactive visualization of the loss curve when training inside a Jupyter notebook; the *frequency* parameter sets the refresh rate of the interactive widget (the plot is updated every *frequency* iterations).

  2. StepLR - implements a learning rate scheduler which decreases the learning rate in steps: every *step_size* iterations, the learning rate is multiplied by the factor *gamma*. Other types of learning rate schedulers can be used instead.

  3. SaveBestModel - enables the periodic saving of the neural network weights during training. Every *freq* iterations, the weights are saved at the specified *path* if the current value of the loss is lower than the loss at the previous save.
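The scheduling logic of these two callbacks can be illustrated without reflectorch. The plain-Python sketch below assumes that StepLR follows the usual step schedule, `lr_0 * gamma ** (i // step_size)` at iteration `i`, and models SaveBestModel as a simplified best-loss tracker. `step_lr` and `BestLossTracker` are hypothetical helpers, not part of the reflectorch API, and the real callbacks may differ in details such as how the loss is averaged before the comparison.

```python
import math


def step_lr(initial_lr: float, iteration: int, step_size: int, gamma: float) -> float:
    """Learning rate at `iteration` under a step schedule (assumed form):
    the rate is multiplied by `gamma` once every `step_size` iterations."""
    return initial_lr * gamma ** (iteration // step_size)


class BestLossTracker:
    """Hypothetical, simplified stand-in for the SaveBestModel logic:
    every `freq` iterations, 'save' if the loss beats the last saved loss."""

    def __init__(self, freq: int):
        self.freq = freq
        self.best_loss = math.inf
        self.saved_at = []  # iterations at which a save would happen

    def update(self, iteration: int, loss: float) -> None:
        if iteration % self.freq == 0 and loss < self.best_loss:
            self.best_loss = loss
            # a real callback would write the model weights to `path` here
            self.saved_at.append(iteration)


# Values taken from the optimizer and the callbacks tuple shown above
initial_lr, step_size, gamma = 1e-4, 5000, 0.1
for it in (0, 4999, 5000, 10000, 14999):
    print(it, step_lr(initial_lr, it, step_size, gamma))

# Toy loss values: decreasing overall, with bumps that must not trigger a save
tracker = BestLossTracker(freq=100)
toy_losses = [1.0, 0.8, 0.9, 0.5, 0.6, 0.4]
for i, loss in enumerate(toy_losses):
    tracker.update(i * 100, loss)
print(tracker.saved_at)
```

With the values used above, the learning rate drops from 1e-4 to 1e-5 after 5000 iterations and to 1e-6 after 10000 iterations, so the last third of a 15000-iteration run uses the smallest rate.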

The training process is started by calling the *train_epoch* method of the trainer. This method takes as parameters the previously defined tuple of callbacks and the number of batches (iterations). Notably, a new batch of data is generated at each iteration, so training takes place in a "one-epoch regime" in which no training curve is used more than once.

```python
trainer.train_epoch(num_batches=15000, callbacks=callbacks)
```

![Training loss curve](training_curve_reflectorch_example.png)
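The "one-epoch regime" can be illustrated with a toy problem: instead of iterating repeatedly over a fixed dataset, every iteration draws a fresh batch from a simulator, so no sample is reused. In the sketch below, `simulate_batch` is a hypothetical stand-in for the reflectivity simulation and a two-parameter linear model replaces the neural network; only the structure of the loop is meant to mirror what `train_epoch` does.

```python
import random


def simulate_batch(batch_size, rng):
    """Hypothetical simulator: a fresh batch of (x, y) pairs from y = 2x + 1 + noise."""
    xs = [rng.uniform(-1.0, 1.0) for _ in range(batch_size)]
    ys = [2.0 * x + 1.0 + rng.gauss(0.0, 0.1) for x in xs]
    return xs, ys


def train_one_epoch(num_batches, batch_size, lr=0.1, seed=0):
    rng = random.Random(seed)
    w, b = 0.0, 0.0  # parameters of the toy model y_hat = w * x + b
    samples_seen = 0
    for _ in range(num_batches):
        xs, ys = simulate_batch(batch_size, rng)  # new data at every iteration
        residuals = [w * x + b - y for x, y in zip(xs, ys)]
        # gradients of the mean squared error with respect to w and b
        grad_w = 2.0 * sum(r * x for r, x in zip(residuals, xs)) / batch_size
        grad_b = 2.0 * sum(residuals) / batch_size
        w -= lr * grad_w  # plain gradient descent step
        b -= lr * grad_b
        samples_seen += batch_size
    return w, b, samples_seen


w, b, samples_seen = train_one_epoch(num_batches=500, batch_size=64)
print(f"w={w:.3f}, b={b:.3f}, samples (each used once): {samples_seen}")
```

For the `train_epoch` call above, 15000 iterations with a batch size of 4096 correspond to 61,440,000 generated curves, each used for a single gradient step.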

## Customizing the YAML configuration for training

In the following, we show how the YAML configuration can be customized.

```{dropdown}

general:
  name: val_sim_L2_q256_d300_r60_s25_bs4_budist_noise-poisson02
  root_dir: null
  
dset:
  prior_sampler:
    cls: SubpriorParametricSampler
    kwargs:
      param_ranges:
        thicknesses: [1., 300.]
        roughnesses: [0., 60.]
        slds: [0., 25.]
      bound_width_ranges:
        thicknesses: [1.0e-2, 300.]
        roughnesses: [1.0e-2, 60.]
        slds: [ 1.0e-2, 4.]
      model_name: standard_model
      max_num_layers: 2
      constrained_roughness: true
      max_thickness_share: 0.5
      logdist: false
      
  q_generator:
    cls: ConstantQ
    kwargs:
      q: [0.02, 0.3, 256]
      remove_zero: false
      fixed_zero: true
      

  intensity_noise: 
    cls: BasicExpIntensityNoise
    kwargs:
      relative_errors: [0.0, 0.2]
      abs_errors: 0.0
      consistent_rel_err: false
      logdist: false
      apply_shift: false
      shift_range: [-0.001, 0.001]
      apply_scaling: false
      scale_range: [-0.001, 0.001]

  q_noise:
    cls: BasicQNoiseGenerator
    kwargs:
      shift_std: 1.0e-7
      noise_std: [0., 1.0e-6]
      
  curves_scaler:
    cls: LogAffineCurvesScaler
    kwargs:
      weight: 0.2 #0.2
      bias: 1.0 #1.0
      eps: 1.0e-10

model:
  encoder:
    cls: SubPriorConvFCEncoder_V2
    pretrained_name: null
    kwargs:
       hidden_dims: [32, 64, 128, 256, 512]
       latent_dim: 8
       conv_latent_dim: 128
       avpool: 8
       use_batch_norm: true
       in_features: 256
       prior_in_features: 16
       hidden_features: 1024
       num_blocks: 6  #3
       fc_activation: 'gelu'
       conv_activation: 'gelu' #'lrelu'
       pass_bounds: false
       pretrained_conv: null
training:
  train_with_q_input: False
  num_iterations: 2000
  batch_size: 4096
  lr: 1.0e-4
  update_tqdm_freq: 1
  grad_accumulation_steps: 1
  optimizer: AdamW

  callbacks:
    save_best_model:
      enable: true
      freq: 500
```

The `general` key contains the following lower-level keys:

- `name` - the name of the particular model / configuration file
- `root_dir` - the path to the root directory; the default (`null`) is the current directory

```yaml
general:
  name: val_sim_L2_q256_d300_r60_s25_bs4_budist_noise-poisson02
  root_dir: null
```

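The `model` key specifies the neural network: `cls` is the name of the network class (here *SubPriorConvFCEncoder_V2*, whose structure was printed above), while `kwargs` contains its constructor arguments, for example the numbers of channels of the successive convolutional layers (`hidden_dims`) and the number of points of the input reflectivity curve (`in_features`):

```yaml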
model:
  encoder:
    cls: SubPriorConvFCEncoder_V2
    pretrained_name: null
    kwargs:
       hidden_dims: [32, 64, 128, 256, 512]
       latent_dim: 8
       conv_latent_dim: 128
       avpool: 8
       use_batch_norm: true
       in_features: 256
       prior_in_features: 16
       hidden_features: 1024
       num_blocks: 6  #3
       fc_activation: 'gelu'
       conv_activation: 'gelu'
       pass_bounds: false
       pretrained_conv: null
    
```
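To see how `in_features` and `hidden_dims` determine the shape of the convolutional embedding, the sketch below applies the output-length formula of PyTorch's `Conv1d` layer. It assumes, as in the model printed above, that every convolutional layer uses `kernel_size=3`, `stride=2` and `padding=1`, so that each layer roughly halves the length of the curve. The helper names are hypothetical, and what the encoder does with the result afterwards (`avpool`, `conv_latent_dim`) is not modeled here.

```python
def conv1d_out_len(length, kernel_size=3, stride=2, padding=1, dilation=1):
    """Output length of a 1D convolution (formula from the PyTorch Conv1d docs)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def conv_stack_shapes(in_features, hidden_dims):
    """(channels, length) after each conv layer, assuming every layer uses
    kernel_size=3, stride=2, padding=1 as in the printed model."""
    shapes = []
    length = in_features  # the input is a single-channel curve of in_features points
    for out_channels in hidden_dims:
        length = conv1d_out_len(length)
        shapes.append((out_channels, length))
    return shapes


# Values from the `model` section of the example configuration
for shape in conv_stack_shapes(in_features=256, hidden_dims=[32, 64, 128, 256, 512]):
    print(shape)
```

For the example configuration, the 256-point curve is thus reduced to 8 positions with 512 channels after the fifth convolutional layer.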

The `training` key can be used to customize the training settings:

- `num_iterations` - the total number of iterations the network is trained for
- `batch_size` - the batch size (number of curves generated at each iteration)
- `optimizer` - the [PyTorch optimizer](https://pytorch.org/docs/stable/optim) to use. Default is `AdamW`
- `lr` - the initial learning rate
- `grad_accumulation_steps` - if larger than 1, training is performed with gradient accumulation over the chosen number of steps
- `update_tqdm_freq` - the frequency for updating the [tqdm progress bar](https://tqdm.github.io/)
- `train_with_q_input` - must be set to `True` if the q-values are used as input (i.e. when the )
- `callbacks` - (optional) the callback classes together with their arguments. Callbacks can also be passed directly to *train_epoch*, as shown in the previous section.

```yaml

training:
  train_with_q_input: False
  num_iterations: 2000
  batch_size: 4096
  lr: 1.0e-4
  update_tqdm_freq: 1
  grad_accumulation_steps: 1
  optimizer: AdamW
  callbacks:
    save_best_model:
      enable: true
      freq: 500
        
```
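Gradient accumulation computes the gradients of several smaller micro-batches, averages them and then performs a single optimizer step; for a mean loss and equally sized micro-batches, this gives the same gradient as one large batch while requiring less memory. The plain-Python sketch below checks this equivalence for a toy linear model with a mean squared error loss. The helpers are hypothetical, and the sketch does not assume how reflectorch splits `batch_size` across `grad_accumulation_steps`.

```python
import random


def mse_grad(w, b, xs, ys):
    """Gradient of the mean squared error of y_hat = w*x + b over one batch."""
    n = len(xs)
    residuals = [w * x + b - y for x, y in zip(xs, ys)]
    grad_w = 2.0 * sum(r * x for r, x in zip(residuals, xs)) / n
    grad_b = 2.0 * sum(residuals) / n
    return grad_w, grad_b


def accumulated_grad(w, b, xs, ys, accumulation_steps):
    """Split the batch into equal micro-batches and average their gradients,
    i.e. the gradient a single optimizer step would use after accumulation."""
    if len(xs) % accumulation_steps:
        raise ValueError("batch size must be divisible by accumulation_steps")
    micro = len(xs) // accumulation_steps
    total_w = total_b = 0.0
    for k in range(accumulation_steps):
        sl = slice(k * micro, (k + 1) * micro)
        gw, gb = mse_grad(w, b, xs[sl], ys[sl])
        # scale each micro-batch contribution by 1 / accumulation_steps
        total_w += gw / accumulation_steps
        total_b += gb / accumulation_steps
    return total_w, total_b


rng = random.Random(0)
xs = [rng.uniform(-1.0, 1.0) for _ in range(4096)]
ys = [2.0 * x + 1.0 + rng.gauss(0.0, 0.1) for x in xs]

full = mse_grad(0.5, -0.2, xs, ys)
accum = accumulated_grad(0.5, -0.2, xs, ys, accumulation_steps=4)
print(full, accum)  # equal up to floating-point rounding
```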