diff --git a/elasticdl/python/master/learning_rate_modulator.py b/elasticdl/python/master/learning_rate_modulator.py
new file mode 100644
index 000000000..202c0f2b9
--- /dev/null
+++ b/elasticdl/python/master/learning_rate_modulator.py
@@ -0,0 +1,60 @@
+import threading
+
+
+class LearningRateModulator:
+    """Modulates the learning rate with a multiplier.
+
+    Note:
+        This class supports concurrent usage by using
+        thread local storage.
+    """
+
+    def __init__(self, learning_rate):
+        """Constructs a `LearningRateModulator` instance.
+
+        Args:
+            learning_rate: The learning rate to be modulated.
+                This can be either a numeric value or a callable.
+        """
+        self._learning_rate = learning_rate
+        self._tls = threading.local()
+        self._tls.multiplier = 1
+
+    def set_multiplier(self, multiplier):
+        """Sets the multiplier.
+
+        Args:
+            multiplier: The multiplier used to modulate the learning rate.
+        """
+        self._tls.multiplier = multiplier
+
+    def get_learning_rate(self):
+        """Gets the modulated learning rate.
+
+        Returns:
+            The learning rate modulated by the multiplier.
+        """
+        lr = self._learning_rate
+        if callable(lr):
+            lr = lr()
+        lr *= self._tls.multiplier
+        return lr
+
+
+def add_lr_modulation_to_optimizer(optimizer):
+    """Adds learning rate modulation to the given optimizer.
+
+    Args:
+        optimizer: The optimizer to add learning rate modulation to.
+
+    Returns:
+        A `LearningRateModulator` instance.
+    """
+    # Get learning rate from optimizer
+    learning_rate = optimizer._hyper["learning_rate"]
+
+    # Replace the learning rate in optimizer with a callable
+    lr_modulation = LearningRateModulator(learning_rate)
+    optimizer.learning_rate = lr_modulation.get_learning_rate
+
+    return lr_modulation
diff --git a/elasticdl/python/master/lr_modulation.py b/elasticdl/python/master/lr_modulation.py
deleted file mode 100644
index d8f235880..000000000
--- a/elasticdl/python/master/lr_modulation.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import threading
-
-
-class LearningRateModulation:
-    """
-    Modify learning rate with a multiplier.
-    Support concurrent usage by using thread local storage.
-    Arguments
-        learning rate: can be a value or a callable.
-    """
-
-    def __init__(self, learning_rate):
-        self._learning_rate = learning_rate
-        self._tls = threading.local()
-        self._tls.multiplier = 1
-
-    def set_multiplier(self, multiplier):
-        self._tls.multiplier = multiplier
-
-    def get_learning_rate(self):
-        lr = self._learning_rate
-        if callable(lr):
-            lr = lr()
-        lr *= self._tls.multiplier
-        return lr
-
-
-def add_lr_modulation_to_optimizer(optimizer):
-    """
-    Add lr modulation feature in optimizer
-    Argument:
-        optimizer: the optimizer to add lr modulation feature
-    Return:
-        LearningRateModulation instance
-    """
-    # Get learning rate from optimizer
-    learning_rate = optimizer._hyper["learning_rate"]
-
-    # Replace the learning rate in optimizer with a calllable
-    lr_modulation = LearningRateModulation(learning_rate)
-    optimizer.learning_rate = lr_modulation.get_learning_rate
-
-    return lr_modulation
diff --git a/elasticdl/python/master/servicer.py b/elasticdl/python/master/servicer.py
index fe2132606..464b97974 100644
--- a/elasticdl/python/master/servicer.py
+++ b/elasticdl/python/master/servicer.py
@@ -16,7 +16,7 @@
 from elasticdl.python.elasticdl.layers.embedding import Embedding
 from elasticdl.python.master.checkpoint_service import CheckpointService
 from elasticdl.python.master.embedding_service import EmbeddingService
-from elasticdl.python.master.lr_modulation import (
+from elasticdl.python.master.learning_rate_modulator import (
     add_lr_modulation_to_optimizer,
 )
 from elasticdl.python.master.optimizer_wrapper import OptimizerWrapper
diff --git a/elasticdl/python/tests/staleness_aware_test.py b/elasticdl/python/tests/staleness_aware_test.py
index fe7943631..8582ddc9f 100644
--- a/elasticdl/python/tests/staleness_aware_test.py
+++ b/elasticdl/python/tests/staleness_aware_test.py
@@ -4,7 +4,7 @@
 import tensorflow as tf
 
-from elasticdl.python.master.lr_modulation import (
+from elasticdl.python.master.learning_rate_modulator import (
     add_lr_modulation_to_optimizer,
 )