Skip to content

Commit

Permalink
[RLlib] Fix broken Torch LR schedule and add a test case. (#12396)
Browse files Browse the repository at this point in the history
  • Loading branch information
sven1977 committed Nov 26, 2020
1 parent d521574 commit 6475297
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
30 changes: 29 additions & 1 deletion rllib/agents/ppo/tests/test_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import unittest

import ray
from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.agents.ppo as ppo
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae as \
postprocess_ppo_gae_tf, ppo_surrogate_loss as ppo_surrogate_loss_tf
Expand Down Expand Up @@ -36,6 +37,30 @@
}


class MyCallbacks(DefaultCallbacks):
    """Verifies LR scheduling after each training iteration.

    After every train step, checks that the learning rate actually set
    on each policy's optimizer(s) equals the policy's scheduled
    ``cur_lr`` value, for both torch and tf frameworks.
    """

    @staticmethod
    def _check_lr_torch(policy, policy_id):
        # Torch: the scheduled LR must have been pushed into every
        # param group of every optimizer.
        for optimizer in policy._optimizers:
            for param_group in optimizer.param_groups:
                assert param_group["lr"] == policy.cur_lr, \
                    "LR scheduling error!"

    @staticmethod
    def _check_lr_tf(policy, policy_id):
        # TF: compare the scheduled LR tensor/variable against the
        # optimizer's own LR. Graph mode needs a session run; eager
        # mode reads the values directly via .numpy().
        sess = policy.get_session()
        if sess:
            current_lr = sess.run(policy.cur_lr)
            optimizer_lr = sess.run(policy._optimizer._lr)
        else:
            current_lr = policy.cur_lr.numpy()
            optimizer_lr = policy._optimizer.lr.numpy()
        assert current_lr == optimizer_lr, "LR scheduling error!"

    def on_train_result(self, *, trainer, result: dict, **kwargs):
        # Pick the framework-appropriate checker and run it on every
        # policy of every worker.
        if trainer.config["framework"] == "torch":
            check_fn = self._check_lr_torch
        else:
            check_fn = self._check_lr_tf
        trainer.workers.foreach_policy(check_fn)


class TestPPO(unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand All @@ -45,9 +70,12 @@ def setUpClass(cls):
def tearDownClass(cls):
ray.shutdown()

def test_ppo_compilation(self):
def test_ppo_compilation_and_lr_schedule(self):
"""Test whether a PPOTrainer can be built with all frameworks."""
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
# for checking lr-schedule correctness
config["callbacks"] = MyCallbacks

config["num_workers"] = 1
config["num_sgd_iter"] = 2
# Settings in case we use an LSTM.
Expand Down
6 changes: 1 addition & 5 deletions rllib/policy/torch_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,15 +612,11 @@ def __init__(self, lr, lr_schedule):

@override(Policy)
def on_global_var_update(self, global_vars):
super(LearningRateSchedule, self).on_global_var_update(global_vars)
super().on_global_var_update(global_vars)
self.cur_lr = self.lr_schedule.value(global_vars["timestep"])

@override(TorchPolicy)
def optimizer(self):
    """Return the policy's optimizers, with the scheduled LR applied.

    Before handing the optimizers back, writes the current scheduled
    learning rate (``self.cur_lr``) into every param group of every
    optimizer so subsequent steps use the up-to-date LR.
    """
    for torch_optimizer in self._optimizers:
        for param_group in torch_optimizer.param_groups:
            param_group["lr"] = self.cur_lr
    return self._optimizers


@DeveloperAPI
Expand Down

0 comments on commit 6475297

Please sign in to comment.