[RLlib] Using (gym) discrete and box spaces inside dict observation space throws ValueError: Expected flattened obs shape ... #31525
Hey, I just want to add that this bug is also reproducible in RLlib 2.3.0 with the gymnasium-based script below:

```python
import gymnasium as gym
import numpy as np
import ray
from ray.rllib.algorithms.ppo import PPOConfig


class MyEnv(gym.Env):
    def __init__(self, env_config):
        # Dict observation space mixing a Box and a Discrete component.
        self.observation_space = gym.spaces.Dict(
            test=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(8, 10)),
            test_discrete=gym.spaces.Discrete(5),
        )
        self.action_space = gym.spaces.Discrete(3)

    def reset(self, *, seed=None, options=None):
        return {"test": np.zeros((8, 10)), "test_discrete": 1}, {}

    def step(self, action):
        # obs, reward, terminated, truncated, info
        return {"test": np.zeros((8, 10)), "test_discrete": 1}, 0, False, False, {}


ray.init()
config = (
    PPOConfig()
    .environment(env=MyEnv)
    .rollouts(
        num_rollout_workers=0,  # local ray worker
        num_envs_per_worker=1,  # no vectorization
    )
    .framework("torch")
    .training(model={"fcnet_hiddens": [32, 16]})
)
algo = config.build()
while True:
    print(algo.train())
```
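The script feeds the Dict observation `{"test": np.zeros((8, 10)), "test_discrete": 1}`, and the title error is about a "flattened obs shape". A hypothetical, pure-Python sketch of what such a flattened vector could look like, assuming (not taken from RLlib's source) that Box components are flattened row-major, that a `Discrete(n)` component is one-hot encoded into `n` slots, and that Dict components are visited in sorted key order. The `("box", shape)` / `("discrete", n)` tuples are illustrative stand-ins, not gymnasium objects.

```python
# Hypothetical sketch: flattening the script's observation into one vector.
# Assumed rules: Box -> row-major flattening, Discrete(n) -> one-hot of length n,
# Dict -> components in sorted key order. Not RLlib's actual implementation.


def flatten_box(value):
    """Row-major flattening of an arbitrarily nested list of numbers."""
    if isinstance(value, (list, tuple)):
        out = []
        for item in value:
            out.extend(flatten_box(item))
        return out
    return [float(value)]


def one_hot(index, n):
    """One-hot encode a Discrete(n) value."""
    vec = [0.0] * n
    vec[index] = 1.0
    return vec


def flatten_component(value, spec):
    kind, arg = spec  # ("box", shape) or ("discrete", n), hypothetical stand-ins
    return one_hot(value, arg) if kind == "discrete" else flatten_box(value)


def flatten_dict_obs(obs, spec):
    out = []
    for key in sorted(spec):  # assumed key order
        out.extend(flatten_component(obs[key], spec[key]))
    return out


box_value = [[0.0] * 10 for _ in range(8)]  # stands in for np.zeros((8, 10))
dict_spec = {"test": ("box", (8, 10)), "test_discrete": ("discrete", 5)}

flat = flatten_dict_obs({"test": box_value, "test_discrete": 1}, dict_spec)
print(len(flat))  # 85 = 8 * 10 + 5
print(flat[80:])  # [0.0, 1.0, 0.0, 0.0, 0.0]
```

Under these assumptions the model would see 85 features per observation, with the Discrete value `1` occupying the second of the last five slots.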
I took some time to look into the code, and I think I was able to pin down the source of the bug. The following code is responsible for creating the dummy data and packing it into a `SampleBatch`:

```python
def _get_dummy_batch_from_view_requirements(
    self, batch_size: int = 1
) -> SampleBatch:
    """Creates a numpy dummy batch based on the Policy's view requirements.

    Args:
        batch_size: The size of the batch to create.

    Returns:
        Dict[str, TensorType]: The dummy batch containing all zero values.
    """
    ret = {}
    for view_col, view_req in self.view_requirements.items():
        data_col = view_req.data_col or view_col
        # Flattened dummy batch.
        if (isinstance(view_req.space, (gym.spaces.Tuple, gym.spaces.Dict))) and (
            (
                data_col == SampleBatch.OBS
                and not self.config["_disable_preprocessor_api"]
            )
            or (
                data_col == SampleBatch.ACTIONS
                and not self.config.get("_disable_action_flattening")
            )
        ):
            _, shape = ModelCatalog.get_action_shape(
                view_req.space, framework=self.config["framework"]
            )
            ret[view_col] = np.zeros((batch_size,) + shape[1:], np.float32)
        # Non-flattened dummy batch.
        else:
            # Range of indices on time-axis, e.g. "-50:-1".
            if isinstance(view_req.space, gym.spaces.Space):
                time_size = (
                    len(view_req.shift_arr) if len(view_req.shift_arr) > 1 else None
                )
                ret[view_col] = get_dummy_batch_for_space(
                    view_req.space, batch_size=batch_size, time_size=time_size
                )
            else:
                ret[view_col] = [view_req.space for _ in range(batch_size)]

    # Due to different view requirements for the different columns,
    # columns in the resulting batch may not all have the same batch size.
    return SampleBatch(ret)
```

The problem

The code relies on the … However, both the OBS and ACTION fields get preprocessed to be stored in a … The exception is raised in … While …

The solution

I don't know which option is best for the API, but I think that transforming … Otherwise, we may rely on …

Extra

As a side note, I have a question about the NEXT_OBS field in view_requirement: given that the OBS field gets updated from the model's view_requirement, which by default uses the original observation space, why does NEXT_OBS not follow the same path? At the moment its view_requirement corresponds to the preprocessed observation space. Could this be a bug as well?
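The mismatch this comment points at can be illustrated with a hypothetical, simplified model of two sizing rules for the same Dict space. It assumes (to be checked against the RLlib source) that the observation preprocessor one-hot encodes each `Discrete(n)` component into `n` slots, while `ModelCatalog.get_action_shape`, which the quoted code uses for the flattened OBS dummy, counts each Discrete component as a single integer slot. The `("box", shape)` / `("discrete", n)` tuples are illustrative stand-ins, not RLlib or gymnasium objects.

```python
from math import prod

# Hypothetical stand-ins for gym spaces: ("box", shape) or ("discrete", n).
# Only the sizing logic matters; these are not RLlib or gymnasium objects.
OBS_SPACE = {"test": ("box", (8, 10)), "test_discrete": ("discrete", 5)}


def preprocessed_flat_size(space_dict):
    """Assumed preprocessor rule: Discrete(n) is one-hot encoded (n slots)."""
    return sum(arg if kind == "discrete" else prod(arg) for kind, arg in space_dict.values())


def action_shape_flat_size(space_dict):
    """Assumed get_action_shape rule: Discrete counts as one integer slot."""
    return sum(1 if kind == "discrete" else prod(arg) for kind, arg in space_dict.values())


expected = preprocessed_flat_size(OBS_SPACE)  # 8 * 10 + 5 = 85
dummy = action_shape_flat_size(OBS_SPACE)     # 8 * 10 + 1 = 81
print(f"model expects [..., {expected}], dummy batch has [..., {dummy}]")
```

Under these assumptions the two sizes agree only when the space has no Discrete components, which would explain why a Box-only Dict works while Box plus Discrete fails, and why a Tuple with the same components fails the same way.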
What happened + What you expected to happen
Using a gym Dict observation space that contains both a Box and a Discrete space (see the reproduction script below) throws the following error:
This also happens when the Dict is replaced with a Tuple.
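Whether the components sit in a Dict or a Tuple should not change the flattened length. A hypothetical sketch, assuming Box components are flattened row-major, `Discrete(n)` components are one-hot encoded, and Dict components are taken in sorted key order; the spec tuples and the sample Box shape `(8, 10)` / `Discrete(5)` are illustrative assumptions, not values from this report.

```python
# Hypothetical sketch: a Dict and a Tuple with the same components flatten identically.
# Assumed rules: Box -> row-major, Discrete(n) -> one-hot, Dict keys in sorted order.


def flatten_component(value, spec):
    kind, arg = spec  # ("box", shape) or ("discrete", n), illustrative stand-ins
    if kind == "discrete":
        vec = [0.0] * arg
        vec[value] = 1.0
        return vec
    return [float(x) for row in value for x in row]  # 2-D Box only, for brevity


def flatten_dict(obs, spec):
    return [x for key in sorted(spec) for x in flatten_component(obs[key], spec[key])]


def flatten_tuple(obs, spec):
    return [x for value, s in zip(obs, spec) for x in flatten_component(value, s)]


box_value = [[0.0] * 10 for _ in range(8)]
dict_spec = {"a_box": ("box", (8, 10)), "b_discrete": ("discrete", 5)}
tuple_spec = (("box", (8, 10)), ("discrete", 5))

flat_dict = flatten_dict({"a_box": box_value, "b_discrete": 3}, dict_spec)
flat_tuple = flatten_tuple((box_value, 3), tuple_spec)
print(len(flat_dict), flat_dict == flat_tuple)  # 85 True
```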
Please let me know if there's any other information needed, or if I missed anything. Thanks!
Versions / Dependencies
Windows 10 OS
Python 3.10
absl-py==1.3.0
aiohttp==3.8.3
aiosignal==1.3.1
ale-py==0.8.0
astunparse==1.6.3
async-timeout==4.0.2
attrs==22.2.0
cachetools==5.2.0
certifi==2022.12.7
charset-normalizer==2.1.1
click==8.1.3
cloudpickle==2.2.0
colorama==0.4.6
commonmark==0.9.1
contourpy==1.0.6
cycler==0.11.0
decorator==5.1.1
distlib==0.3.6
dm-tree==0.1.8
filelock==3.9.0
flatbuffers==23.1.4
fonttools==4.38.0
frozenlist==1.3.3
gast==0.4.0
google-auth==2.15.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
grpcio==1.51.1
gym==0.23.1
gym-notices==0.0.8
gymnasium==0.27.0
gymnasium-notices==0.0.1
h5py==3.7.0
idna==3.4
imageio==2.23.0
importlib-resources==5.10.2
jax-jumpy==0.2.0
jsonschema==4.17.3
keras==2.11.0
kiwisolver==1.4.4
libclang==14.0.6
lz4==4.3.2
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib==3.6.2
msgpack==1.0.4
multidict==6.0.4
networkx==2.8.8
numpy==1.23.5
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==22.0
pandas==1.5.2
Pillow==9.4.0
platformdirs==2.6.2
protobuf==3.19.6
pyasn1==0.4.8
pyasn1-modules==0.2.8
Pygments==2.14.0
pyparsing==3.0.9
pyrsistent==0.19.3
python-dateutil==2.8.2
pytz==2022.7
PyWavelets==1.4.1
PyYAML==6.0
ray==2.2.0
requests==2.28.1
requests-oauthlib==1.3.1
rich==13.0.1
rsa==4.9
scikit-image==0.19.3
scipy==1.10.0
Shimmy==0.2.0
six==1.16.0
tabulate==0.9.0
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorboardX==2.5.1
tensorflow-estimator==2.11.0
tensorflow-intel==2.11.0
tensorflow-io-gcs-filesystem==0.29.0
tensorflow-probability==0.19.0
termcolor==2.2.0
tifffile==2022.10.10
torch==1.13.1
typer==0.7.0
typing_extensions==4.4.0
urllib3==1.26.13
virtualenv==20.17.1
Werkzeug==2.2.2
wrapt==1.14.1
yarl==1.8.2
Reproduction script
Issue Severity
None