Skip to content

Commit

Permalink
Rename LearningRateModulation to LearningRateModulator and reformat d…
Browse files Browse the repository at this point in the history
…ocstring (#1225)

* Rename LearningRateModulation to LearningRateModulator and reformat docstring

Signed-off-by: terrytangyuan <terrytangyuan@gmail.com>

* Fix typos

Signed-off-by: terrytangyuan <terrytangyuan@gmail.com>
  • Loading branch information
terrytangyuan committed Sep 20, 2019
1 parent a25baf2 commit a224932
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 45 deletions.
60 changes: 60 additions & 0 deletions elasticdl/python/master/learning_rate_modulator.py
@@ -0,0 +1,60 @@
import threading


class LearningRateModulator:
"""Modulates the learning rate with a multiplier.
Note:
This class supports concurrent usage by using
thread local storage.
"""

def __init__(self, learning_rate):
"""Constructs a `LearningRateModulator` instance.
Args:
learning_rate: The learning rate to be modulated.
This can be either a numeric value or a callable.
"""
self._learning_rate = learning_rate
self._tls = threading.local()
self._tls.multiplier = 1

def set_multiplier(self, multiplier):
"""Sets the multiplier.
Args:
multiplier: The multiplier used to modulate the learning rate.
"""
self._tls.multiplier = multiplier

def get_learning_rate(self):
"""Gets the modulated learning rate.
Returns:
The learning rate modulated by the multiplier.
"""
lr = self._learning_rate
if callable(lr):
lr = lr()
lr *= self._tls.multiplier
return lr


def add_lr_modulation_to_optimizer(optimizer):
"""Adds learning rate modulation to the given optimizer.
Args:
optimizer: The optimizer to add learning rate modulation to.
Returns:
A `LearningRateModulator` instance.
"""
# Get learning rate from optimizer
learning_rate = optimizer._hyper["learning_rate"]

# Replace the learning rate in optimizer with a callable
lr_modulation = LearningRateModulator(learning_rate)
optimizer.learning_rate = lr_modulation.get_learning_rate

return lr_modulation
43 changes: 0 additions & 43 deletions elasticdl/python/master/lr_modulation.py

This file was deleted.

2 changes: 1 addition & 1 deletion elasticdl/python/master/servicer.py
Expand Up @@ -16,7 +16,7 @@
from elasticdl.python.elasticdl.layers.embedding import Embedding
from elasticdl.python.master.checkpoint_service import CheckpointService
from elasticdl.python.master.embedding_service import EmbeddingService
from elasticdl.python.master.lr_modulation import (
from elasticdl.python.master.learning_rate_modulator import (
add_lr_modulation_to_optimizer,
)
from elasticdl.python.master.optimizer_wrapper import OptimizerWrapper
Expand Down
2 changes: 1 addition & 1 deletion elasticdl/python/tests/staleness_aware_test.py
Expand Up @@ -4,7 +4,7 @@

import tensorflow as tf

from elasticdl.python.master.lr_modulation import (
from elasticdl.python.master.learning_rate_modulator import (
add_lr_modulation_to_optimizer,
)

Expand Down

0 comments on commit a224932

Please sign in to comment.