Skip to content

Commit

Permalink
[RLlib] Add support for custom MultiActionDistributions. (#11311)
Browse files Browse the repository at this point in the history
  • Loading branch information
sven1977 committed Oct 12, 2020
1 parent 0c0f67c commit 1ebcdf2
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 22 deletions.
62 changes: 41 additions & 21 deletions rllib/models/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import numpy as np
import tree
from typing import List
from typing import List, Optional, Type, Union

from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
RLLIB_ACTION_DIST, _global_registry
Expand Down Expand Up @@ -110,18 +110,20 @@ class ModelCatalog:

@staticmethod
@DeveloperAPI
def get_action_dist(action_space: gym.Space,
config: ModelConfigDict,
dist_type: str = None,
framework: str = "tf",
**kwargs) -> (type, int):
def get_action_dist(
action_space: gym.Space,
config: ModelConfigDict,
dist_type: Optional[Union[str, Type[ActionDistribution]]] = None,
framework: str = "tf",
**kwargs) -> (type, int):
"""Returns a distribution class and size for the given action space.
Args:
action_space (Space): Action space of the target gym env.
config (Optional[dict]): Optional model config.
dist_type (Optional[str]): Identifier of the action distribution
type (str) interpreted as a hint.
dist_type (Optional[Union[str, Type[ActionDistribution]]]):
Identifier of the action distribution (str) interpreted as a
hint or the actual ActionDistribution class to use.
framework (str): One of "tf", "tfe", or "torch".
kwargs (dict): Optional kwargs to pass on to the Distribution's
constructor.
Expand All @@ -143,6 +145,9 @@ def get_action_dist(action_space: gym.Space,
"Using custom action distribution {}".format(action_dist_name))
dist_cls = _global_registry.get(RLLIB_ACTION_DIST,
action_dist_name)
dist_cls = ModelCatalog._get_multi_action_distribution(
dist_cls, action_space, {}, framework)

# Dist_type is given directly as a class.
elif type(dist_type) is type and \
issubclass(dist_type, ActionDistribution) and \
Expand Down Expand Up @@ -173,18 +178,10 @@ def get_action_dist(action_space: gym.Space,
elif dist_type in (MultiActionDistribution,
TorchMultiActionDistribution) or \
isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)):
flat_action_space = flatten_space(action_space)
child_dists_and_in_lens = tree.map_structure(
lambda s: ModelCatalog.get_action_dist(
s, config, framework=framework), flat_action_space)
child_dists = [e[0] for e in child_dists_and_in_lens]
input_lens = [int(e[1]) for e in child_dists_and_in_lens]
return partial(
(TorchMultiActionDistribution
if framework == "torch" else MultiActionDistribution),
action_space=action_space,
child_distributions=child_dists,
input_lens=input_lens), int(sum(input_lens))
return ModelCatalog._get_multi_action_distribution(
(MultiActionDistribution
if framework == "tf" else TorchMultiActionDistribution),
action_space, config, framework)
# Simplex -> Dirichlet.
elif isinstance(action_space, Simplex):
if framework == "torch":
Expand Down Expand Up @@ -422,7 +419,8 @@ def track_var_creation(next_creator, **kw):

@staticmethod
@DeveloperAPI
def get_preprocessor(env: gym.Env, options: dict = None) -> Preprocessor:
def get_preprocessor(env: gym.Env,
options: Optional[dict] = None) -> Preprocessor:
"""Returns a suitable preprocessor for the given env.
This is a wrapper for get_preprocessor_for_space().
Expand Down Expand Up @@ -552,3 +550,25 @@ def _get_v2_model_class(input_space, model_config, framework="tf"):
# Default Conv2D net.
else:
return VisionNet

@staticmethod
def _get_multi_action_distribution(dist_class, action_space, config,
framework):
# In case the custom distribution is a child of MultiActionDistr.
# If users want to completely ignore the suggested child
# distributions, they should simply do so in their custom class'
# constructor.
if issubclass(dist_class,
(MultiActionDistribution, TorchMultiActionDistribution)):
flat_action_space = flatten_space(action_space)
child_dists_and_in_lens = tree.map_structure(
lambda s: ModelCatalog.get_action_dist(
s, config, framework=framework), flat_action_space)
child_dists = [e[0] for e in child_dists_and_in_lens]
input_lens = [int(e[1]) for e in child_dists_and_in_lens]
return partial(
dist_class,
action_space=action_space,
child_distributions=child_dists,
input_lens=input_lens), int(sum(input_lens))
return dist_class
2 changes: 1 addition & 1 deletion rllib/tests/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def logp(self, x):
return tf.zeros(self.output_shape)


class ModelCatalogTest(unittest.TestCase):
class TestModelCatalog(unittest.TestCase):
def tearDown(self):
ray.shutdown()

Expand Down

0 comments on commit 1ebcdf2

Please sign in to comment.