[neural-networks] extended documentation of training functions
HenKlei committed Apr 9, 2021
1 parent 5a3fb25 commit d653805
Showing 1 changed file with 31 additions and 5 deletions.
36 changes: 31 additions & 5 deletions src/pymor/reductors/neural_network.py
@@ -436,22 +436,46 @@ def train_neural_network(training_data, validation_data, neural_network,
     Parameters
     ----------
     training_data
-        Data to use during the training phase.
+        Data to use during the training phase. Has to be a list of tuples,
+        where each tuple consists of two PyTorch tensors (`torch.DoubleTensor`).
+        The first tensor contains the input data, the second tensor contains
+        the target values.
     validation_data
-        Data to use during the validation phase.
+        Data to use during the validation phase. Has to be a list of tuples,
+        where each tuple consists of two PyTorch tensors (`torch.DoubleTensor`).
+        The first tensor contains the input data, the second tensor contains
+        the target values.
     neural_network
         The neural network to train (can also be a pre-trained model).
         Has to be a PyTorch module.
     training_parameters
         Dictionary with additional parameters for the training routine like
-        the type of the optimizer, the batch size, the learning rate or the
-        loss function to use.
+        the type of the optimizer, the (maximum) number of epochs, the batch
+        size, the learning rate, or the loss function to use.
+        Possible keys are `'optimizer'` (an optimizer from the PyTorch `optim`
+        package; if not provided, the LBFGS optimizer is used by default),
+        `'epochs'` (an integer that determines the maximum number of epochs
+        used for training the neural network, unless training is interrupted
+        prematurely by early stopping; if not provided, 1000 is used by
+        default), `'batch_size'` (an integer that determines the number
+        of samples to pass to the optimizer at once; if not provided, 20 is
+        used by default; not used for the LBFGS optimizer since LBFGS does
+        not support mini-batching), `'learning_rate'` (a positive real number
+        used as the (initial) step size of the optimizer; if not provided,
+        1 is used by default; thus far, no learning rate schedulers are
+        supported in this implementation), and `'loss_function'` (a loss
+        function from PyTorch; if not provided, the MSE loss is used by
+        default).
+
     Returns
     -------
     best_neural_network
         The best trained neural network with respect to validation loss.
     losses
-        The corresponding losses.
+        The corresponding losses as a dictionary with keys `'full'` (for the
+        full loss, comprising the average training and validation losses),
+        `'train'` (for the average loss on the training set), and `'val'`
+        (for the average loss on the validation set).
     """
     assert isinstance(neural_network, nn.Module)
     for data in training_data, validation_data:
@@ -553,6 +577,8 @@ def multiple_restarts_training(training_data, validation_data, neural_network,
     the best trained network or tries to reach a given target loss and
     stops training when the target loss is reached.
+    See :func:`train_neural_network` for more information on the parameters.
+
     Parameters
     ----------
     training_data
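
To make the documented format concrete, here is a minimal sketch of how the pieces described in the new docstring fit together. The toy tensors, sample counts, and network architecture are made up for illustration; and since the hunk header truncates the function signature, the assumption that `training_parameters` is passed as the fourth positional argument, and that `'optimizer'` expects the optimizer class rather than an instance, may not match the actual code.

    import torch
    import torch.nn as nn
    import torch.optim as optim

    from pymor.reductors.neural_network import train_neural_network

    # Hypothetical toy data: 10 samples mapping 3 inputs to 2 outputs.
    inputs = torch.rand(10, 3, dtype=torch.float64)
    targets = torch.rand(10, 2, dtype=torch.float64)

    # training_data/validation_data in the documented format: lists of
    # (input, target) tuples of double-precision tensors.
    training_data = [(inputs[i], targets[i]) for i in range(8)]
    validation_data = [(inputs[i], targets[i]) for i in range(8, 10)]

    # Any nn.Module can serve as the network to train; a small stand-in here.
    neural_network = nn.Sequential(
        nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 2)
    ).double()

    # training_parameters with the keys listed in the docstring; omitted keys
    # fall back to the documented defaults (LBFGS, 1000 epochs, batch size 20,
    # learning rate 1, MSE loss).
    training_parameters = {
        'optimizer': optim.LBFGS,   # assumption: the class, not an instance
        'epochs': 500,
        'batch_size': 20,           # ignored for LBFGS, which does not mini-batch
        'learning_rate': 1.,
        'loss_function': nn.MSELoss(),
    }

    # Assumed call pattern (the full signature is cut off in the hunk header).
    best_neural_network, losses = train_neural_network(
        training_data, validation_data, neural_network, training_parameters)
    print(losses['full'], losses['train'], losses['val'])

Note that with the default LBFGS optimizer the `'batch_size'` entry has no effect, as the docstring points out, so mini-batching only comes into play when a stochastic optimizer such as Adam or SGD is supplied instead.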
