Training

Apart from active learning, small-text includes several helpers for classification.

Overview


Early Stopping

Early stopping is a mechanism that tries to avoid overfitting when training a model. For this purpose, an early stopping handler monitors certain metrics during the training process (usually after each epoch) in order to decide whether an early stop should be triggered. If the handler deems an early stop to be necessary according to the given constraints, it returns True when check_early_stop() is called. This response then has to be handled by the respective classifier.
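As an illustration of how patience-based monitoring works, here is a minimal, self-contained sketch of the idea for a lower-is-better metric. The function below is hypothetical and not part of small-text's API:

```python
def should_stop_early(history, patience):
    """Illustrative patience logic for a lower-is-better metric such as a
    validation loss: return True once more than `patience` consecutive
    checks have failed to improve on the best value seen so far."""
    best = float('inf')
    checks_without_improvement = 0
    for value in history:
        if value < best:
            best = value
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1
            if checks_without_improvement > patience:
                return True
    return False
```

With patience=2, the loss sequence 0.060, 0.061, 0.060, 0.060 triggers a stop at the fourth check, since three consecutive checks fail to beat the best value of 0.060.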

Interface

../../small_text/training/early_stopping.py

Example Usage

  1. Monitoring validation loss (lower is better):

from small_text.training.early_stopping import EarlyStopping
from small_text.training.metrics import Metric

early_stopping = EarlyStopping(Metric('val_loss'), patience=2)

print(early_stopping.check_early_stop(1, {'val_loss': 0.060}))
print(early_stopping.check_early_stop(2, {'val_loss': 0.061}))  # no improvement, don't stop
print(early_stopping.check_early_stop(3, {'val_loss': 0.060}))  # no improvement, don't stop
print(early_stopping.check_early_stop(4, {'val_loss': 0.060}))  # no improvement, stop

Output:

False
False
False
True

  2. Monitoring validation accuracy (higher is better) with `patience=1`:

from small_text.training.early_stopping import EarlyStopping
from small_text.training.metrics import Metric

early_stopping = EarlyStopping(Metric('val_acc', lower_is_better=False), patience=1)

print(early_stopping.check_early_stop(1, {'val_acc': 0.80}))
print(early_stopping.check_early_stop(2, {'val_acc': 0.79}))  # no improvement, don't stop
print(early_stopping.check_early_stop(3, {'val_acc': 0.81}))  # improvement
print(early_stopping.check_early_stop(4, {'val_acc': 0.81}))  # no improvement, don't stop
print(early_stopping.check_early_stop(5, {'val_acc': 0.80}))  # no improvement, stop

Output:

False
False
False
False
True

Combining Early Stopping Conditions

What if we want to stop early based on either one of two conditions, for example, when the validation loss has not improved during the last 3 checks or the training accuracy crosses 0.99? This can easily be done using EarlyStoppingOrCondition, which sequentially applies a list of early stopping handlers.

from small_text.training.early_stopping import EarlyStopping, EarlyStoppingOrCondition
from small_text.training.metrics import Metric

early_stopping = EarlyStoppingOrCondition([
    EarlyStopping(Metric('val_loss'), patience=3),
    EarlyStopping(Metric('train_acc', lower_is_better=False), threshold=0.99)
])

EarlyStoppingOrCondition returns True, i.e. triggers an early stop, if and only if at least one of the early stopping handlers within the given list returns True. Similarly, EarlyStoppingAndCondition stops only when all of the given early stopping handlers return True.
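The combination semantics can be sketched in a few lines of plain Python. This is an illustrative reimplementation, not the library source; note that every handler is evaluated on each check so that each one can update its internal state:

```python
class OrCondition:
    """Stops as soon as at least one handler votes to stop."""

    def __init__(self, early_stopping_handlers):
        self.early_stopping_handlers = early_stopping_handlers

    def check_early_stop(self, epoch, measured_values):
        # Evaluate every handler first so all of them see the new values,
        # then combine the individual votes.
        results = [handler.check_early_stop(epoch, measured_values)
                   for handler in self.early_stopping_handlers]
        return any(results)


class AndCondition:
    """Stops only when all handlers vote to stop."""

    def __init__(self, early_stopping_handlers):
        self.early_stopping_handlers = early_stopping_handlers

    def check_early_stop(self, epoch, measured_values):
        results = [handler.check_early_stop(epoch, measured_values)
                   for handler in self.early_stopping_handlers]
        return all(results)
```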

Implementations

EarlyStopping

Note

Currently, the supported metrics are validation accuracy (val_acc), validation loss (val_loss), training accuracy (train_acc), and training loss (train_loss). For the accuracy metrics, a higher value is better, i.e. the patience counter increases only when the respective metric has not exceeded the previous best value; for the loss metrics, it increases when the metric has not fallen below the previous best value.

EarlyStoppingAndCondition

EarlyStoppingOrCondition

NoopEarlyStopping


Model Selection

Given a set of models that have been trained on the same data, model selection chooses the model that is considered best according to some criterion. In the context of neural networks, typical use cases are the training process, where the candidate set consists of the model states after each epoch, and hyperparameter search, where one model is trained per hyperparameter configuration.
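The selection criterion itself reduces to an argmax over the measured values. Here is a minimal sketch of that idea; the function is hypothetical (not the ModelSelection implementation) and assumes higher values are better for every given metric, as with accuracy:

```python
def select_best(models, select_by):
    """models: list of (model_id, measured_values) tuples.
    select_by: a metric name or a list of names; later names break ties.
    Assumes higher values are better for every given metric."""
    metrics = select_by if isinstance(select_by, list) else [select_by]
    best = max(models, key=lambda entry: tuple(entry[1][m] for m in metrics))
    return best[0]
```

For accuracy metrics this matches intuition; a loss metric would need its sign flipped before being used as a sort key.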

Interface

../../small_text/training/model_selection.py

Example Usage

from small_text.training.model_selection import ModelSelection

model_selection = ModelSelection()

measured_values = {'val_acc': 0.87, 'train_acc': 0.89, 'val_loss': 0.123}
model_selection.add_model('model_id_1', 'model_1.bin', measured_values)
measured_values = {'val_acc': 0.88, 'train_acc': 0.91, 'val_loss': 0.091}
model_selection.add_model('model_id_2', 'model_2.bin', measured_values)
measured_values = {'val_acc': 0.87, 'train_acc': 0.92, 'val_loss': 0.101}
model_selection.add_model('model_id_3', 'model_3.bin', measured_values)

print(model_selection.select(select_by='val_acc'))
print(model_selection.select(select_by='train_acc'))
print(model_selection.select(select_by=['val_acc', 'train_acc']))

Output:

ModelSelectionResult('model_id_2', 'model_2.bin', {'val_loss': 0.091, 'val_acc': 0.88, 'train_loss': nan, 'train_acc': 0.91}, {'early_stop': False})
ModelSelectionResult('model_id_3', 'model_3.bin', {'val_loss': 0.101, 'val_acc': 0.87, 'train_loss': nan, 'train_acc': 0.92}, {'early_stop': False})
ModelSelectionResult('model_id_2', 'model_2.bin', {'val_loss': 0.091, 'val_acc': 0.88, 'train_loss': nan, 'train_acc': 0.91}, {'early_stop': False})

Implementations

ModelSelection

NoopModelSelection