Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Issue 31525: Observation Space Dict w/ Box+Discrete elements errors out. #31560

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2142,6 +2142,13 @@ py_test(
srcs = ["utils/exploration/tests/test_random_encoder.py"]
)

py_test(
name = "utils/tests/test_tf_utils",
tags = ["team:rllib", "utils"],
size = "small",
srcs = ["utils/tests/test_tf_utils.py"]
)

py_test(
name = "utils/tests/test_torch_utils",
tags = ["team:rllib", "utils"],
Expand Down
16 changes: 11 additions & 5 deletions rllib/models/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,20 +340,21 @@ def get_action_dist(
@staticmethod
@DeveloperAPI
def get_action_shape(
action_space: gym.Space, framework: str = "tf"
action_space: gym.Space, framework: str = "tf", one_hot: bool = False
) -> (np.dtype, List[int]):
"""Returns action tensor dtype and shape for the action space.

Args:
action_space: Action space of the target gym env.
framework: The framework identifier. One of "tf" or "torch".
one_hot: Whether to use one-hot encoding for discrete (sub-)spaces.

Returns:
(dtype, shape): Dtype and shape of the actions tensor.
"""
dl_lib = torch if framework == "torch" else tf
if isinstance(action_space, Discrete):
return action_space.dtype, (None,)
return action_space.dtype, (None, action_space.n) if one_hot else (None,)
elif isinstance(action_space, (Box, Simplex)):
if np.issubdtype(action_space.dtype, np.floating):
return dl_lib.float32, (None,) + action_space.shape
Expand All @@ -369,7 +370,7 @@ def get_action_shape(
all_discrete = True
for i in range(len(flat_action_space)):
if isinstance(flat_action_space[i], Discrete):
size += 1
size += flat_action_space[i].n if one_hot else 1
else:
all_discrete = False
size += np.product(flat_action_space[i].shape)
Expand All @@ -383,19 +384,24 @@ def get_action_shape(
@staticmethod
@DeveloperAPI
def get_action_placeholder(
action_space: gym.Space, name: str = "action"
action_space: gym.Space, name: str = "action", one_hot: bool = False
) -> TensorType:
"""Returns an action placeholder consistent with the action space

Args:
action_space: Action space of the target gym env.
name: An optional string to name the placeholder by.
Default: "action".
one_hot: Whether to use one-hot encoding for discrete (sub-)spaces.

Returns:
action_placeholder: A placeholder for the actions
"""
dtype, shape = ModelCatalog.get_action_shape(action_space, framework="tf")
dtype, shape = ModelCatalog.get_action_shape(
action_space,
framework="tf",
one_hot=one_hot,
)

return tf1.placeholder(dtype, shape=shape, name=name)

Expand Down
17 changes: 12 additions & 5 deletions rllib/policy/dynamic_tf_policy_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,25 +530,32 @@ def _create_input_dict_and_dummy_batch(self, view_requirements, existing_inputs)
# Create a +time-axis placeholder if the shift is not an
# int (range or list of ints).
# Do not flatten actions if action flattening disabled.
if self.config.get("_disable_action_flattening") and view_col in [
SampleBatch.ACTIONS,
SampleBatch.PREV_ACTIONS,
]:
flatten = False
if view_col in [SampleBatch.ACTIONS, SampleBatch.PREV_ACTIONS]:
if self.config.get("_disable_action_flattening"):
flatten = False
one_hot = False
else:
flatten = True
one_hot = False
# Do not flatten observations if no preprocessor API used.
elif (
view_col in [SampleBatch.OBS, SampleBatch.NEXT_OBS]
and self.config["_disable_preprocessor_api"]
):
flatten = False
one_hot = False
# Flatten everything else.
else:
flatten = True
one_hot = True

input_dict[view_col] = get_placeholder(
space=view_req.space,
name=view_col,
time_axis=time_axis,
flatten=flatten,
one_hot=one_hot,

)
dummy_batch = self._get_dummy_batch_from_view_requirements(batch_size=32)

Expand Down
77 changes: 77 additions & 0 deletions rllib/utils/tests/test_tf_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import unittest

import gymnasium as gym
import numpy as np

import ray
from ray.rllib.utils.tf_utils import get_placeholder


class TestTfUtils(unittest.TestCase):
    """Unit tests for `rllib.utils.tf_utils.get_placeholder`."""

    @classmethod
    def setUpClass(cls) -> None:
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_get_placeholder_w_discrete(self):
        """Tests whether `get_placeholder` works as expected on Discrete spaces."""
        # Plain Discrete (no flattening, no one-hot): scalar int per batch item.
        space = gym.spaces.Discrete(2)
        placeholder = get_placeholder(
            space=space, time_axis=False, flatten=False, one_hot=False
        )
        self.assertTrue(placeholder.shape.as_list() == [None])
        self.assertTrue(placeholder.dtype.name == "int64")

        # `flatten=True` without one-hot should not change the shape for a
        # single Discrete space.
        space = gym.spaces.Discrete(3)
        placeholder = get_placeholder(
            space=space, time_axis=False, flatten=True, one_hot=False
        )
        self.assertTrue(placeholder.shape.as_list() == [None])
        self.assertTrue(placeholder.dtype.name == "int64")

        # `one_hot=True`: Discrete(4) expands to a float32 (None, 4) tensor.
        space = gym.spaces.Discrete(4)
        placeholder = get_placeholder(
            space=space, time_axis=False, flatten=False, one_hot=True
        )
        self.assertTrue(placeholder.shape.as_list() == [None, 4])
        self.assertTrue(placeholder.dtype.name == "float32")

    def test_get_placeholder_w_box(self):
        """Tests whether `get_placeholder` works as expected on Box spaces."""
        space = gym.spaces.Box(-1.0, 1.0, (2,), dtype=np.float32)
        placeholder = get_placeholder(
            space=space, time_axis=False, flatten=False, one_hot=False
        )
        self.assertTrue(placeholder.shape.as_list() == [None, 2])
        self.assertTrue(placeholder.dtype.name == "float32")

        space = gym.spaces.Box(-1.0, 1.0, (2, 3), dtype=np.float32)
        placeholder = get_placeholder(
            space=space, time_axis=False, flatten=False, one_hot=False
        )
        self.assertTrue(placeholder.shape.as_list() == [None, 2, 3])
        self.assertTrue(placeholder.dtype.name == "float32")

        # Neither `flatten` nor `one_hot` should affect a Box space's
        # placeholder shape or dtype.
        space = gym.spaces.Box(-1.0, 1.0, (2, 3), dtype=np.float32)
        placeholder = get_placeholder(
            space=space, time_axis=False, flatten=True, one_hot=False
        )
        self.assertTrue(placeholder.shape.as_list() == [None, 2, 3])
        self.assertTrue(placeholder.dtype.name == "float32")

        space = gym.spaces.Box(-1.0, 1.0, (2, 3), dtype=np.float32)
        placeholder = get_placeholder(
            space=space, time_axis=False, flatten=True, one_hot=True
        )
        self.assertTrue(placeholder.shape.as_list() == [None, 2, 3])
        self.assertTrue(placeholder.dtype.name == "float32")


if __name__ == "__main__":
    # Run this file's tests through pytest when executed directly.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
13 changes: 9 additions & 4 deletions rllib/utils/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ def get_placeholder(
value: Optional[Any] = None,
name: Optional[str] = None,
time_axis: bool = False,
flatten: bool = True
flatten: bool = True,
one_hot: bool = False,
) -> "tf1.placeholder":
"""Returns a tf1.placeholder object given optional hints, such as a space.

Expand All @@ -193,6 +194,8 @@ def get_placeholder(
dimension (None).
flatten: Whether to flatten the given space into a plain Box space
and then create the placeholder from the resulting space.
one_hot: Whether to one-hot discrete (sub-)spaces. For example, Discrete(2)
would yield a placeholder of Box((2,)).

Returns:
The tf1 placeholder.
Expand All @@ -202,18 +205,20 @@ def get_placeholder(
if space is not None:
if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
if flatten:
return ModelCatalog.get_action_placeholder(space, None)
return ModelCatalog.get_action_placeholder(space, None, one_hot=one_hot)
else:
return tree.map_structure_with_path(
lambda path, component: get_placeholder(
space=component,
name=name + "." + ".".join([str(p) for p in path]),
one_hot=one_hot,
),
get_base_struct_from_space(space),
)
_, shape = ModelCatalog.get_action_shape(space, framework="tf", one_hot=one_hot)
return tf1.placeholder(
shape=(None,) + ((None,) if time_axis else ()) + space.shape,
dtype=tf.float32 if space.dtype == np.float64 else space.dtype,
shape=(None,) + ((None,) if time_axis else ()) + shape[1:],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we skipping that first index of that shape returned from the action?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great question. We are now getting the shape from the call to _, shape = ModelCatalog.get_action_shape(space, framework="tf", one_hot=one_hot) instead of from the space directly. That's why we have to cut the batch dim (which was not part of the space shape before, but it is part of the shape returned by get_action_shape).

dtype=tf.float32 if (space.dtype == np.float64 or one_hot) else space.dtype,
name=name,
)
else:
Expand Down