# Training a neural posterior estimator of simulation parameters

__Author:__ Sebastian Wagner-Carena

__Goals: Understand how to use the tools built into paltas to train a neural posterior estimator.__

__If you have not already done so, you will have to install `tensorflow` to run the Analysis module tools in this notebook.__

1. Understand how to train a model using an analysis configuration file.
2. Understand the components of the analysis configuration file.

### Table of Contents

1. [Training a Model](#train_model) 
2. [Building an Analysis Config](#analysis_config)

## Training a Model <a class="anchor" id="train_model"></a>

`paltas` provides a few tools that make it easy to train a neural posterior estimator on the datasets it generates. At the highest level, this works the same way as generating a dataset. To train a model, run:

```
python Analysis/train_model.py path/to/analysis/config/file.py
```

All of the hard work is done by `train_model.py` and the configuration file you pass to it. If you want the training and validation losses logged in finer detail, you can also pass a path to a TensorBoard directory:

```
python Analysis/train_model.py path/to/analysis/config/file.py --tensorboard_dir path/to/save/tensorboard/outputs
```
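The configuration file passed to `train_model.py` is an ordinary Python file whose top-level variables hold the settings (as in the example config later in this notebook). The sketch below shows one way a script can load such a file with the standard-library `importlib`. It illustrates the idea only and is not necessarily how `train_model.py` itself reads the config. The helper name `load_config` is hypothetical.

```python
import importlib.util
import os
import tempfile


def load_config(config_path):
    """Import a Python config file as a module (hypothetical helper)."""
    # Use the file name without its extension as the module name.
    module_name = os.path.splitext(os.path.basename(config_path))[0]
    spec = importlib.util.spec_from_file_location(module_name, config_path)
    config = importlib.util.module_from_spec(spec)
    # Run the file so its top-level assignments become module attributes.
    spec.loader.exec_module(config)
    return config


if __name__ == '__main__':
    # Write a tiny stand-in config to a temporary file and load it back.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, 'train_config_demo.py')
        with open(path, 'w') as f:
            f.write("batch_size = 256\nloss_function = 'full'\n")
        config = load_config(path)
        print(config.batch_size, config.loss_function)
```

Settings are then read as attributes, for example `config.batch_size`.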

One final note before we go into the analysis config: `train_model.py` pulls the training data from a TFRecord file. If that file doesn't exist in a training directory, the script will generate it. When multiple training directories are passed in, this can add a large one-time cost before training starts. To avoid it, you can have a TFRecord file written as each training set is generated by passing `--tf_record` to `generate.py`:

```
python generate.py path/to/config/file path/to/output/folder --n 100 --tf_record
```
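To check in advance whether that one-time conversion will be triggered, you can list the training folders that do not yet contain a TFRecord. This sketch assumes the file is named `data.tfrecord` inside each folder, matching the paths built in the example analysis config. The helper name `folders_missing_tfrecord` is hypothetical.

```python
import os
import tempfile


def folders_missing_tfrecord(npy_folders, tfrecord_name='data.tfrecord'):
    """Return the training folders that do not yet contain a TFRecord file.

    Assumes each folder's TFRecord lives at <folder>/<tfrecord_name>.
    """
    return [folder for folder in npy_folders
        if not os.path.isfile(os.path.join(folder, tfrecord_name))]


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tmp_dir:
        done = os.path.join(tmp_dir, 'train_0')
        todo = os.path.join(tmp_dir, 'train_1')
        os.makedirs(done)
        os.makedirs(todo)
        # Simulate a folder whose TFRecord was already written.
        open(os.path.join(done, 'data.tfrecord'), 'w').close()
        # Only the train_1 path is listed.
        print(folders_missing_tfrecord([done, todo]))
```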

## Building an Analysis Config <a class="anchor" id="analysis_config"></a>

Here we'll reproduce the analysis config you can find in `Analysis/AnalysisConfig/train_config_examp.py`:

```python
import os

# The batch size for each training step
batch_size = 256

# The number of epochs to train for
n_epochs = 200

# The size of the images in the training set
img_size = (170,170,1)

# A random seed to use
random_seed = 2

# The list of learning parameters to pull
learning_params = ['main_deflector_parameters_theta_E',
    'main_deflector_parameters_gamma1','main_deflector_parameters_gamma2',
    'main_deflector_parameters_gamma','main_deflector_parameters_e1',
    'main_deflector_parameters_e2','main_deflector_parameters_center_x',
    'main_deflector_parameters_center_y',
    'subhalo_parameters_sigma_sub']

# The paths to the folders containing the npy images for training
npy_folders_train = ['list','of','folder','paths']
# The paths to the TFRecord files for the training images
tfr_train_paths = [
    os.path.join(path,'data.tfrecord') for path in npy_folders_train]
# The paths to the training metadata
metadata_paths_train = [
    os.path.join(path,'metadata.csv') for path in npy_folders_train]

# The path to the folder containing the npy images for validation
npy_folder_val = ('validation_folder_path')
# The path to the TFRecord file for the validation images
tfr_val_path = os.path.join(npy_folder_val,'data.tfrecord')
# The path to the validation metadata
metadata_path_val = os.path.join(npy_folder_val,'metadata.csv')

# The path to the csv file to read from / write to for normalization
# of learning parameters.
input_norm_path = os.path.join(npy_folders_train[0],'norms.csv')

# The detector kwargs to use for on-the-fly noise generation
kwargs_detector = None

# Whether or not to normalize the images by the standard deviation
norm_images = True

# A string with which loss function to use.
loss_function = 'full'

# A string specifying which model to use
model_type = 'xresnet34'

# A string specifying which optimizer to use
optimizer = 'Adam'
# The learning rate for the model
learning_rate = 5e-3

# Where to save the model weights
model_weights = ('path_to_model_weights_{epoch:02d}-{val_loss:.2f}.h5')
model_weights_init = ('path_to_initial_weights.h5')

# Whether or not to use random rotation of the input images
random_rotation = True

# Only train the head
train_only_head = False
```

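The comment on `input_norm_path` in the config above says the csv file is used to normalize the learning parameters. The sketch below shows what such a normalization could look like, assuming a standard mean/standard-deviation scaling computed from `metadata.csv`. The exact scheme and the column layout of `norms.csv` that `train_model.py` uses may differ. The helpers `compute_norms`, `write_norms`, and `normalize` are hypothetical.

```python
import csv
import os
import statistics
import tempfile


def compute_norms(metadata_path, learning_params):
    """Return {parameter: (mean, std)} computed over a metadata csv.

    Sketch only: assumes each learning parameter is scaled as
    (value - mean) / std, using the population standard deviation.
    """
    with open(metadata_path, newline='') as f:
        rows = list(csv.DictReader(f))
    norms = {}
    for param in learning_params:
        values = [float(row[param]) for row in rows]
        norms[param] = (statistics.mean(values), statistics.pstdev(values))
    return norms


def write_norms(norms, norm_path):
    """Save the norms with hypothetical columns: parameter, mean, std."""
    with open(norm_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['parameter', 'mean', 'std'])
        for param, (mean, std) in norms.items():
            writer.writerow([param, mean, std])


def normalize(value, param, norms):
    """Scale one parameter value with its stored mean and std."""
    mean, std = norms[param]
    return (value - mean) / std


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tmp_dir:
        # A stand-in metadata.csv describing three images.
        metadata_path = os.path.join(tmp_dir, 'metadata.csv')
        with open(metadata_path, 'w', newline='') as f:
            f.write('main_deflector_parameters_theta_E\n1.0\n2.0\n3.0\n')
        norms = compute_norms(metadata_path,
            ['main_deflector_parameters_theta_E'])
        write_norms(norms, os.path.join(tmp_dir, 'norms.csv'))
        print(norms)
```

Computing the statistics on the training set and reusing them for validation keeps both sets on the same scale.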
A number of the parameters above are self-explanatory, but we'll dig a bit deeper into a few here:

1. `img_size`: The first two values are set by the dimensions of the images you've generated, but the final axis must always have dimension 1. Note that the `.npy` files saved by paltas are only two-dimensional for now, but the `train_model.py` pipeline adds the third dimension so the images are compatible with the model architecture.
2. `learning_params`: This is a list of the parameters in `metadata.csv` that you want the model to predict. Each name combines the parameter group the parameter belongs to with the parameter's own name, as in `main_deflector_parameters_theta_E` above. For example, if we wanted to predict `center_x` for the source, we would add `source_parameters_center_x` to the list.
3. `npy_folders_train`: A list of the paths to all the training folders from which training examples will be pulled. If you have a very large training set (>10k images), it's best to divide the images among multiple folders. If there is only one training folder, this can be a list of length 1.
4. `kwargs_detector`: The training pipeline can add noise on the fly to noiseless images. This functionality is somewhat limited, but it lets you pass the same `kwargs_detector` you would specify in the config file given to `generate.py`. Note that it cannot perform drizzling or PSF convolution; those must be done when the dataset is generated.
5. `loss_function`: Currently this accepts three options: `mse`, which trains the network to return the maximum likelihood estimate; `diag`, which trains the network to predict a diagonal covariance matrix; and `full`, which trains the network to predict a full covariance matrix.
6. `model_type`: The model to use. The only option implemented by default is `xresnet34`, but you can implement your own model in `conv_models` and add it as an option in `train_model.py`.
7. `model_weights`: The `{epoch:02d}-{val_loss:.2f}` string in the name will be populated dynamically with the epoch and the validation loss, so a new model file is saved at the end of each epoch. If a fixed string is given, the newest model will always overwrite the previous one.
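The three `loss_function` options differ in how many numbers the network must output per image. This sketch assumes a common parameterization: one mean per parameter, plus one (log) variance per parameter for `diag`, or the n(n+1)/2 lower-triangular entries of a matrix factor for `full`. The exact output layout paltas uses internally may differ. `n_outputs` is a hypothetical helper.

```python
def n_outputs(loss_function, n_params):
    """Number of network outputs per image for each loss option (assumed layout)."""
    if loss_function == 'mse':
        # One point estimate per learning parameter.
        return n_params
    if loss_function == 'diag':
        # A mean and a (log) variance per learning parameter.
        return 2 * n_params
    if loss_function == 'full':
        # A mean per parameter plus the lower triangle of an n x n matrix
        # (for example a Cholesky factor) that fixes the full covariance.
        return n_params + n_params * (n_params + 1) // 2
    raise ValueError(f'Unknown loss_function: {loss_function}')


# With the nine learning_params in the example config:
for loss in ('mse', 'diag', 'full'):
    print(loss, n_outputs(loss, 9))
# prints: mse 9 / diag 18 / full 54
```

The quadratic growth of `full` is the price of modeling correlations between parameters.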
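The `{epoch:02d}` and `{val_loss:.2f}` fields in `model_weights` are ordinary Python format specifiers. This is the same named-field syntax that the Keras `ModelCheckpoint` callback fills in from the epoch number and the logged metrics. The snippet below only illustrates the string formatting with made-up values. It does not train or save anything.

```python
# Template from the example config.
model_weights = 'path_to_model_weights_{epoch:02d}-{val_loss:.2f}.h5'

# Made-up values for illustration: epoch 7 with a validation loss of 1.2345.
print(model_weights.format(epoch=7, val_loss=1.2345))
# prints: path_to_model_weights_07-1.23.h5

# A template without format fields always yields the same file name,
# so each save overwrites the previous one.
fixed = 'path_to_model_weights.h5'
print(fixed.format(epoch=7, val_loss=1.2345))
# prints: path_to_model_weights.h5
```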