Skip to content

Commit

Permalink
[RLlib] APEX_DDPG (PyTorch) test case and docs. (#8288)
Browse files Browse the repository at this point in the history
APEX_DDPG (PyTorch) test case and docs.
  • Loading branch information
sven1977 committed May 4, 2020
1 parent 5f351a0 commit b95e28f
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 23 deletions.
2 changes: 1 addition & 1 deletion doc/source/rllib-algorithms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Algorithm Frameworks Discrete Actions Continuous Actions Multi-
`ARS`_ tf + torch **Yes** **Yes** No
`ES`_ tf + torch **Yes** **Yes** No
`DDPG`_, `TD3`_ tf + torch No **Yes** **Yes**
`APEX-DDPG`_ tf No **Yes** **Yes**
`APEX-DDPG`_ tf + torch No **Yes** **Yes**
`DQN`_, `Rainbow`_ tf + torch **Yes** `+parametric`_ No **Yes**
`APEX-DQN`_ tf + torch **Yes** `+parametric`_ No **Yes**
`IMPALA`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
Expand Down
24 changes: 16 additions & 8 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,22 @@ py_test(
srcs = ["agents/a3c/tests/test_a2c.py"]
)

# APEXTrainer (DQN)
py_test(
name = "test_apex_dqn",
tags = ["agents_dir"],
size = "large",
srcs = ["agents/dqn/tests/test_apex_dqn.py"]
)

# APEXDDPGTrainer
py_test(
name = "test_apex_ddpg",
tags = ["agents_dir"],
size = "small",
srcs = ["agents/ddpg/tests/test_apex_ddpg.py"]
)

# DDPGTrainer
py_test(
name = "test_ddpg",
Expand All @@ -121,14 +137,6 @@ py_test(
srcs = ["agents/dqn/tests/test_simple_q.py"]
)

# APEXTrainer
py_test(
name = "test_apex",
tags = ["agents_dir"],
size = "large",
srcs = ["agents/dqn/tests/test_apex.py"]
)

# IMPALA
py_test(
name = "test_vtrace",
Expand Down
56 changes: 56 additions & 0 deletions rllib/agents/ddpg/tests/test_apex_ddpg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest
import unittest

import ray
import ray.rllib.agents.ddpg.apex as apex_ddpg
from ray.rllib.utils.test_utils import check, framework_iterator


class TestApexDDPG(unittest.TestCase):
    """Compilation and per-worker exploration-scale test for APEX-DDPG."""

    def setUp(self):
        # Each test gets its own local Ray instance.
        ray.init(num_cpus=4)

    def tearDown(self):
        ray.shutdown()

    def test_apex_ddpg_compilation_and_per_worker_epsilon_values(self):
        """Test whether an APEX-DDPGTrainer can be built on all frameworks.

        Besides building and training the trainer for one iteration, this
        checks the `cur_scale` value reported by every policy's exploration
        info, both before and after training (it must not change).

        Note: the method name says "epsilon" for symmetry with the APEX-DQN
        test; the quantity actually checked here is the noise scale.
        """
        import copy

        # Deep-copy the defaults: a shallow `.copy()` shares nested dicts
        # (e.g. "optimizer") with the module-level default config, so the
        # nested assignment below would otherwise modify the defaults too.
        config = copy.deepcopy(apex_ddpg.APEX_DDPG_DEFAULT_CONFIG)
        config["num_workers"] = 3
        config["prioritized_replay"] = True
        config["timesteps_per_iteration"] = 100
        config["min_iter_time_s"] = 1
        config["learning_starts"] = 0
        config["optimizer"]["num_replay_buffer_shards"] = 1
        num_iterations = 1

        # Expected scale per remote worker (worker indices 1..num_workers).
        # The formula mirrors the one this test was written against:
        # 0.4 ** (1 + worker_index / (num_workers - 1) * 7).
        num_workers = config["num_workers"]
        expected = [
            0.4**(1 + (i + 1) / float(num_workers - 1) * 7)
            for i in range(num_workers)
        ]
        # The first reported policy is expected to have a scale of 0.0
        # (presumably the local worker -- verify against WorkerSet order).
        expected_scales = [0.0] + expected

        for _ in framework_iterator(config, ("torch", "tf")):
            # Give every framework run its own fully independent config.
            plain_config = copy.deepcopy(config)
            trainer = apex_ddpg.ApexDDPGTrainer(
                config=plain_config, env="Pendulum-v0")
            try:
                # Test per-worker scale distribution.
                infos = trainer.workers.foreach_policy(
                    lambda p, _: p.get_exploration_info())
                scale = [i["cur_scale"] for i in infos]
                check(scale, expected_scales)

                for _ in range(num_iterations):
                    print(trainer.train())

                # Test again per-worker scale distribution
                # (should not have changed).
                infos = trainer.workers.foreach_policy(
                    lambda p, _: p.get_exploration_info())
                scale = [i["cur_scale"] for i in infos]
                check(scale, expected_scales)
            finally:
                # Always release the trainer, even if a check above failed.
                trainer.stop()


if __name__ == "__main__":
    import sys

    # Run this file's tests verbosely and hand pytest's status code back
    # to the shell as the process exit code.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
2 changes: 1 addition & 1 deletion rllib/agents/dqn/apex.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
def defer_make_workers(trainer, env_creator, policy, config):
# Hack to workaround https://github.com/ray-project/ray/issues/2541
# The workers will be created later, after the optimizer is created
return trainer._make_workers(env_creator, policy, config, 0)
return trainer._make_workers(env_creator, policy, config, num_workers=0)


def make_async_optimizer(workers, config):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
import numpy as np
import pytest
import unittest

import ray
import ray.rllib.agents.dqn.apex as apex
from ray.rllib.utils.test_utils import framework_iterator
from ray.rllib.utils.test_utils import check, framework_iterator


class TestApex(unittest.TestCase):
class TestApexDQN(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=4)

def tearDown(self):
ray.shutdown()

def test_apex_compilation_and_per_worker_epsilon_values(self):
def test_apex_dqn_compilation_and_per_worker_epsilon_values(self):
"""Test whether an APEX-DQNTrainer can be built on all frameworks."""
config = apex.APEX_DEFAULT_CONFIG.copy()
config["num_workers"] = 3
Expand All @@ -30,14 +29,20 @@ def test_apex_compilation_and_per_worker_epsilon_values(self):
# Test per-worker epsilon distribution.
infos = trainer.workers.foreach_policy(
lambda p, _: p.get_exploration_info())
eps = [i["cur_epsilon"] for i in infos]
assert np.allclose(eps, [0.0, 0.4, 0.016190862, 0.00065536])
expected = [0.4, 0.016190862, 0.00065536]
check([i["cur_epsilon"] for i in infos], [0.0] + expected)

# TODO(ekl) fix iterator metrics bugs w/multiple trainers.
# for i in range(1):
# results = trainer.train()
# print(results)

# Test again per-worker epsilon distribution
# (should not have changed).
infos = trainer.workers.foreach_policy(
lambda p, _: p.get_exploration_info())
check([i["cur_epsilon"] for i in infos], [0.0] + expected)

trainer.stop()


Expand Down
6 changes: 2 additions & 4 deletions rllib/evaluation/worker_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def __init__(self,
self._env_creator = env_creator
self._policy = policy
self._remote_config = trainer_config
self._num_workers = num_workers
self._logdir = logdir

if _setup:
Expand All @@ -62,7 +61,7 @@ def __init__(self,

# Create a number of remote workers
self._remote_workers = []
self.add_workers(self._num_workers)
self.add_workers(num_workers)

def local_worker(self):
"""Return the local rollout worker."""
Expand All @@ -86,7 +85,6 @@ def add_workers(self, num_workers):
num_workers (int): The number of remote Workers to add to this
WorkerSet.
"""
self._num_workers = num_workers
remote_args = {
"num_cpus": self._remote_config["num_cpus_per_worker"],
"num_gpus": self._remote_config["num_gpus_per_worker"],
Expand Down Expand Up @@ -266,7 +264,7 @@ def session_creator():
model_config=config["model"],
policy_config=config,
worker_index=worker_index,
num_workers=self._num_workers,
num_workers=config["num_workers"],
monitor_path=self._logdir if config["monitor"] else None,
log_dir=self._logdir,
log_level=config["log_level"],
Expand Down
2 changes: 1 addition & 1 deletion rllib/utils/exploration/per_worker_epsilon_greedy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, action_space, *, framework, num_workers, worker_index,
# From page 5 of https://arxiv.org/pdf/1803.00933.pdf
alpha, eps, i = 7, 0.4, worker_index - 1
epsilon_schedule = ConstantSchedule(
eps**(1 + i / (num_workers - 1) * alpha),
eps**(1 + i / float(num_workers - 1) * alpha),
framework=framework)
# Local worker should have zero exploration so that eval
# rollouts run properly.
Expand Down
2 changes: 1 addition & 1 deletion rllib/utils/exploration/per_worker_gaussian_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, action_space, *, framework, num_workers, worker_index,
scale_schedule = None
# Use a fixed, different noise scale per worker. See: Ape-X paper.
if num_workers > 0:
if worker_index >= 0:
if worker_index > 0:
exponent = (1 + worker_index / float(num_workers - 1) * 7)
scale_schedule = ConstantSchedule(
0.4**exponent, framework=framework)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self, action_space, *, framework, num_workers, worker_index,
scale_schedule = None
# Use a fixed, different noise scale per worker. See: Ape-X paper.
if num_workers > 0:
if worker_index >= 0:
if worker_index > 0:
exponent = (1 + worker_index / float(num_workers - 1) * 7)
scale_schedule = ConstantSchedule(
0.4**exponent, framework=framework)
Expand Down

0 comments on commit b95e28f

Please sign in to comment.