Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Rename LearningRateModulation to LearningRateModulator and reformat d…
…ocstring (#1225) * Rename LearningRateModulation to LearningRateModulator and reformat docstring Signed-off-by: terrytangyuan <terrytangyuan@gmail.com> * Fix typos Signed-off-by: terrytangyuan <terrytangyuan@gmail.com>
- Loading branch information
1 parent
a25baf2
commit a224932
Showing
4 changed files
with
62 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import threading | ||
|
||
|
||
class LearningRateModulator: | ||
"""Modulates the learning rate with a multiplier. | ||
Note: | ||
This class supports concurrent usage by using | ||
thread local storage. | ||
""" | ||
|
||
def __init__(self, learning_rate): | ||
"""Constructs a `LearningRateModulator` instance. | ||
Args: | ||
learning_rate: The learning rate to be modulated. | ||
This can be either a numeric value or a callable. | ||
""" | ||
self._learning_rate = learning_rate | ||
self._tls = threading.local() | ||
self._tls.multiplier = 1 | ||
|
||
def set_multiplier(self, multiplier): | ||
"""Sets the multiplier. | ||
Args: | ||
multiplier: The multiplier used to modulate the learning rate. | ||
""" | ||
self._tls.multiplier = multiplier | ||
|
||
def get_learning_rate(self): | ||
"""Gets the modulated learning rate. | ||
Returns: | ||
The learning rate modulated by the multiplier. | ||
""" | ||
lr = self._learning_rate | ||
if callable(lr): | ||
lr = lr() | ||
lr *= self._tls.multiplier | ||
return lr | ||
|
||
|
||
def add_lr_modulation_to_optimizer(optimizer): | ||
"""Adds learning rate modulation to the given optimizer. | ||
Args: | ||
optimizer: The optimizer to add learning rate modulation to. | ||
Returns: | ||
A `LearningRateModulator` instance. | ||
""" | ||
# Get learning rate from optimizer | ||
learning_rate = optimizer._hyper["learning_rate"] | ||
|
||
# Replace the learning rate in optimizer with a callable | ||
lr_modulation = LearningRateModulator(learning_rate) | ||
optimizer.learning_rate = lr_modulation.get_learning_rate | ||
|
||
return lr_modulation |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters