From ff1062d39dfbe12511a7859d221d4a7157df66eb Mon Sep 17 00:00:00 2001
From: nagadomi
Date: Mon, 28 Jun 2021 17:48:07 +0900
Subject: [PATCH] Add --reset_learning_rate option to lstmtraining (#3470)

When the --reset_learning_rate option is specified, the learning rate
stored in each layer of the network loaded with --continue_from is reset
to the value given by the --learning_rate option. If a checkpoint is
available, the option has no effect.
---
 src/lstm/lstmrecognizer.h     | 20 ++++++++++++++++++++
 src/lstm/plumbing.h           |  8 ++++++++
 src/training/lstmtraining.cpp |  6 ++++++
 3 files changed, 34 insertions(+)

diff --git a/src/lstm/lstmrecognizer.h b/src/lstm/lstmrecognizer.h
index 2892c17027..c1659502c2 100644
--- a/src/lstm/lstmrecognizer.h
+++ b/src/lstm/lstmrecognizer.h
@@ -157,6 +157,26 @@ class TESS_API LSTMRecognizer {
     series->ScaleLayerLearningRate(&id[1], factor);
   }
 
+  // Sets all the learning rates to the given value.
+  void SetLearningRate(float learning_rate)
+  {
+    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
+    learning_rate_ = learning_rate;
+    if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
+      for (auto &id : EnumerateLayers()) {
+        SetLayerLearningRate(id, learning_rate);
+      }
+    }
+  }
+  // Sets the learning rate of the layer with the given id to the given value.
+  void SetLayerLearningRate(const std::string &id, float learning_rate)
+  {
+    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
+    ASSERT_HOST(id.length() > 1 && id[0] == ':');
+    auto *series = static_cast<Series *>(network_);
+    series->SetLayerLearningRate(&id[1], learning_rate);
+  }
+
   // Converts the network to int if not already.
   void ConvertToInt() {
     if ((training_flags_ & TF_INT_MODE) == 0) {
diff --git a/src/lstm/plumbing.h b/src/lstm/plumbing.h
index 1c65fe9f61..fe0f499b21 100644
--- a/src/lstm/plumbing.h
+++ b/src/lstm/plumbing.h
@@ -120,6 +120,14 @@ class Plumbing : public Network {
     ASSERT_HOST(lr_ptr != nullptr);
     *lr_ptr *= factor;
   }
+
+  // Sets the learning rate for a specific layer of the stack to the given value.
+  void SetLayerLearningRate(const char *id, float learning_rate) {
+    float *lr_ptr = LayerLearningRatePtr(id);
+    ASSERT_HOST(lr_ptr != nullptr);
+    *lr_ptr = learning_rate;
+  }
+
   // Returns a pointer to the learning rate for the given layer id.
   TESS_API float *LayerLearningRatePtr(const char *id);
 
diff --git a/src/training/lstmtraining.cpp b/src/training/lstmtraining.cpp
index 297870d7dd..fd365423ef 100644
--- a/src/training/lstmtraining.cpp
+++ b/src/training/lstmtraining.cpp
@@ -36,6 +36,8 @@ static INT_PARAM_FLAG(perfect_sample_delay, 0, "How many imperfect samples betwe
 static DOUBLE_PARAM_FLAG(target_error_rate, 0.01, "Final error rate in percent.");
 static DOUBLE_PARAM_FLAG(weight_range, 0.1, "Range of initial random weights.");
 static DOUBLE_PARAM_FLAG(learning_rate, 10.0e-4, "Weight factor for new deltas.");
+static BOOL_PARAM_FLAG(reset_learning_rate, false,
+                       "Resets all stored learning rates to the value specified by --learning_rate.");
 static DOUBLE_PARAM_FLAG(momentum, 0.5, "Decay factor for repeating deltas.");
 static DOUBLE_PARAM_FLAG(adam_beta, 0.999, "Decay factor for repeating deltas.");
 static INT_PARAM_FLAG(max_image_MB, 6000, "Max memory to use for images.");
@@ -157,6 +159,10 @@ int main(int argc, char **argv) {
       return EXIT_FAILURE;
     }
     tprintf("Continuing from %s\n", FLAGS_continue_from.c_str());
+    if (FLAGS_reset_learning_rate) {
+      trainer.SetLearningRate(FLAGS_learning_rate);
+      tprintf("Set learning rate to %f\n", static_cast<float>(FLAGS_learning_rate));
+    }
     trainer.InitIterations();
   }
   if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
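A typical invocation with the new flag might look like the sketch below. The
model, traineddata, and list-file paths are illustrative only, not taken from
the patch; the other flags are existing lstmtraining options:

    lstmtraining \
      --continue_from eng.lstm \
      --traineddata eng/eng.traineddata \
      --train_listfile train_files.txt \
      --model_output output/resumed \
      --reset_learning_rate \
      --learning_rate 1e-4

Without --reset_learning_rate, resuming from a recognizer model keeps whatever
per-layer learning rates were serialized with it; with the flag, they are all
forced back to --learning_rate before training resumes.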