
[RLlib] Using (gym) discrete and box spaces inside dict observation space throws ValueError: Expected flattened obs shape ... #31525

Open
Naton1 opened this issue Jan 8, 2023 · 2 comments
Labels
bug: Something that is supposed to be working; but isn't
P1: Issue that should be fixed within a few weeks
repro-script-confirmed
rllib: RLlib related issues
rllib-catalog: issues related to RLlib's model catalog
rllib-models: An issue related to RLlib (default or custom) Models.
rllib-tf: The problems that only happen in tf on rllib

Comments


Naton1 commented Jan 8, 2023

What happened + What you expected to happen

Using a gym Dict observation space that contains both a Box and a Discrete space (see the reproduction script below) throws the following error:

(RolloutWorker pid=1404)   File "python\ray\_raylet.pyx", line 830, in ray._raylet.execute_task
(RolloutWorker pid=1404)   File "python\ray\_raylet.pyx", line 834, in ray._raylet.execute_task
(RolloutWorker pid=1404)   File "python\ray\_raylet.pyx", line 780, in ray._raylet.execute_task.function_executor
(RolloutWorker pid=1404)   File "C:\Users\Nate\rl-ge\venv\lib\site-packages\ray\_private\function_manager.py", line 674, in actor_method_executor
(RolloutWorker pid=1404)     return method(__ray_actor, *args, **kwargs)
(RolloutWorker pid=1404)   File "C:\Users\Nate\rl-ge\venv\lib\site-packages\ray\util\tracing\tracing_helper.py", line 466, in _resume_span
(RolloutWorker pid=1404)     return method(self, *_args, **_kwargs)
(RolloutWorker pid=1404)   File "C:\Users\Nate\rl-ge\venv\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 712, in __init__
(RolloutWorker pid=1404)     self._build_policy_map(
(RolloutWorker pid=1404)   File "C:\Users\Nate\rl-ge\venv\lib\site-packages\ray\util\tracing\tracing_helper.py", line 466, in _resume_span
(RolloutWorker pid=1404)     return method(self, *_args, **_kwargs)
(RolloutWorker pid=1404)   File "C:\Users\Nate\rl-ge\venv\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1970, in _build_policy_map
(RolloutWorker pid=1404)     self.policy_map.create_policy(
(RolloutWorker pid=1404)   File "C:\Users\Nate\rl-ge\venv\lib\site-packages\ray\rllib\policy\policy_map.py", line 146, in create_policy
(RolloutWorker pid=1404)     policy = create_policy_for_framework(
(RolloutWorker pid=1404)   File "C:\Users\Nate\rl-ge\venv\lib\site-packages\ray\rllib\utils\policy.py", line 117, in create_policy_for_framework
(RolloutWorker pid=1404)     return policy_class(
(RolloutWorker pid=1404)   File "C:\Users\Nate\rl-ge\venv\lib\site-packages\ray\rllib\algorithms\ppo\ppo_tf_policy.py", line 83, in __init__
(RolloutWorker pid=1404)     base.__init__(
(RolloutWorker pid=1404)   File "C:\Users\Nate\rl-ge\venv\lib\site-packages\ray\rllib\policy\dynamic_tf_policy_v2.py", line 95, in __init__
(RolloutWorker pid=1404)     ) = self._init_action_fetches(timestep, explore)
(RolloutWorker pid=1404)   File "C:\Users\Nate\rl-ge\venv\lib\site-packages\ray\rllib\policy\dynamic_tf_policy_v2.py", line 616, in _init_action_fetches
(RolloutWorker pid=1404)     dist_inputs, self._state_out = self.model(self._input_dict)
(RolloutWorker pid=1404)   File "C:\Users\Nate\rl-ge\venv\lib\site-packages\ray\rllib\models\modelv2.py", line 247, in __call__
(RolloutWorker pid=1404)     restored["obs"] = restore_original_dimensions(
(RolloutWorker pid=1404)   File "C:\Users\Nate\rl-ge\venv\lib\site-packages\ray\rllib\models\modelv2.py", line 411, in restore_original_dimensions
(RolloutWorker pid=1404)     return _unpack_obs(obs, original_space, tensorlib=tensorlib)
(RolloutWorker pid=1404)   File "C:\Users\Nate\rl-ge\venv\lib\site-packages\ray\rllib\models\modelv2.py", line 445, in _unpack_obs
(RolloutWorker pid=1404)     raise ValueError(
(RolloutWorker pid=1404) ValueError: Expected flattened obs shape of [..., 82], got (?, 81)
(RolloutWorker pid=1404) 
(RolloutWorker pid=37048) 

Process finished with exit code 1

This also happens if replacing the Dict with a Tuple.
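
For reference, a Tuple version of the observation space from the reproduction script below (a hypothetical variant, shown here only to illustrate the remark above) would look like this:

import gym
import numpy as np

# Hypothetical Tuple variant of the Dict observation space used in the
# reproduction script; per the report above, it triggers the same ValueError.
observation_space = gym.spaces.Tuple((
    gym.spaces.Box(low=-np.inf, high=np.inf, shape=(8, 10)),
    gym.spaces.Discrete(2),
))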

Please let me know if there's any other information needed, or if I missed anything. Thanks!

Versions / Dependencies

Windows 10 OS
Python 3.10

absl-py==1.3.0
aiohttp==3.8.3
aiosignal==1.3.1
ale-py==0.8.0
astunparse==1.6.3
async-timeout==4.0.2
attrs==22.2.0
cachetools==5.2.0
certifi==2022.12.7
charset-normalizer==2.1.1
click==8.1.3
cloudpickle==2.2.0
colorama==0.4.6
commonmark==0.9.1
contourpy==1.0.6
cycler==0.11.0
decorator==5.1.1
distlib==0.3.6
dm-tree==0.1.8
filelock==3.9.0
flatbuffers==23.1.4
fonttools==4.38.0
frozenlist==1.3.3
gast==0.4.0
google-auth==2.15.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
grpcio==1.51.1
gym==0.23.1
gym-notices==0.0.8
gymnasium==0.27.0
gymnasium-notices==0.0.1
h5py==3.7.0
idna==3.4
imageio==2.23.0
importlib-resources==5.10.2
jax-jumpy==0.2.0
jsonschema==4.17.3
keras==2.11.0
kiwisolver==1.4.4
libclang==14.0.6
lz4==4.3.2
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib==3.6.2
msgpack==1.0.4
multidict==6.0.4
networkx==2.8.8
numpy==1.23.5
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==22.0
pandas==1.5.2
Pillow==9.4.0
platformdirs==2.6.2
protobuf==3.19.6
pyasn1==0.4.8
pyasn1-modules==0.2.8
Pygments==2.14.0
pyparsing==3.0.9
pyrsistent==0.19.3
python-dateutil==2.8.2
pytz==2022.7
PyWavelets==1.4.1
PyYAML==6.0
ray==2.2.0
requests==2.28.1
requests-oauthlib==1.3.1
rich==13.0.1
rsa==4.9
scikit-image==0.19.3
scipy==1.10.0
Shimmy==0.2.0
six==1.16.0
tabulate==0.9.0
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorboardX==2.5.1
tensorflow-estimator==2.11.0
tensorflow-intel==2.11.0
tensorflow-io-gcs-filesystem==0.29.0
tensorflow-probability==0.19.0
termcolor==2.2.0
tifffile==2022.10.10
torch==1.13.1
typer==0.7.0
typing_extensions==4.4.0
urllib3==1.26.13
virtualenv==20.17.1
Werkzeug==2.2.2
wrapt==1.14.1
yarl==1.8.2

Reproduction script

import gym
import numpy as np
import ray
from ray.rllib.algorithms import ppo


class MyEnv(gym.Env):
    def __init__(self, env_config):
        self.observation_space = gym.spaces.Dict(
            test=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(8, 10)),
            test_discrete=gym.spaces.Discrete(2)
        )
        self.action_space = gym.spaces.Discrete(3)

    def reset(self):
        return {'test': np.zeros((8, 10)), 'test_discrete': 1}

    def step(self, action):
        return {'test': np.zeros((8, 10)), 'test_discrete': 1}, 0, False, {}


ray.init()
algo = ppo.PPO(env=MyEnv)

while True:
    print(algo.train())

Issue Severity

None

@Naton1 Naton1 added bug Something that is supposed to be working; but isn't triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Jan 8, 2023
@sven1977 sven1977 self-assigned this Jan 9, 2023
@sven1977 sven1977 added P1 Issue that should be fixed within a few weeks repro-script-confirmed rllib RLlib related issues rllib-models An issue related to RLlib (default or custom) Models. rllib-tf The problems that only happen in tf on rllib rllib-catalog issues related to RLlib's model catalog and removed triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Jan 9, 2023

LCarmi commented Mar 18, 2023

Hey, I just want to add that this bug is also reproducible in RLlib 2.3.0, and with torch as the framework as well, contrary to what was assumed in #31560.

I attach a script to reproduce it:

import gymnasium as gym
import numpy as np
import ray
from ray.rllib.algorithms.ppo import PPOConfig


class MyEnv(gym.Env):
    def __init__(self, env_config):
        self.observation_space = gym.spaces.Dict(
            test=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(8, 10)), test_discrete=gym.spaces.Discrete(5)
        )
        self.action_space = gym.spaces.Discrete(3)

    def reset(self, *, seed=None, options=None):
        return {"test": np.zeros((8, 10)), "test_discrete": 1}, {}

    def step(self, action):
        return {"test": np.zeros((8, 10)), "test_discrete": 1}, 0, False, False, {}


ray.init()
config = (
    PPOConfig()
    .environment(env=MyEnv)
    .rollouts(
        num_rollout_workers=0,  # local ray worker
        num_envs_per_worker=1,  # no vectorization
    )
    .framework("torch")
    .training(model={"fcnet_hiddens": [32, 16]}) 
)

algo = config.build()

while True:
    print(algo.train())


LCarmi commented Mar 18, 2023

I took some time to look into the code, and I think I was able to pin down the source of the bug.

The following code is responsible for creating the dummy data and packing it into a SampleBatch:

def _get_dummy_batch_from_view_requirements(
    self, batch_size: int = 1
) -> SampleBatch:
    """Creates a numpy dummy batch based on the Policy's view requirements.

    Args:
        batch_size: The size of the batch to create.

    Returns:
        Dict[str, TensorType]: The dummy batch containing all zero values.
    """
    ret = {}
    for view_col, view_req in self.view_requirements.items():
        data_col = view_req.data_col or view_col
        # Flattened dummy batch.
        if (isinstance(view_req.space, (gym.spaces.Tuple, gym.spaces.Dict))) and (
            (
                data_col == SampleBatch.OBS
                and not self.config["_disable_preprocessor_api"]
            )
            or (
                data_col == SampleBatch.ACTIONS
                and not self.config.get("_disable_action_flattening")
            )
        ):
            _, shape = ModelCatalog.get_action_shape(
                view_req.space, framework=self.config["framework"]
            )
            ret[view_col] = np.zeros((batch_size,) + shape[1:], np.float32)
        # Non-flattened dummy batch.
        else:
            # Range of indices on time-axis, e.g. "-50:-1".
            if isinstance(view_req.space, gym.spaces.Space):
                time_size = (
                    len(view_req.shift_arr) if len(view_req.shift_arr) > 1 else None
                )
                ret[view_col] = get_dummy_batch_for_space(
                    view_req.space, batch_size=batch_size, time_size=time_size
                )
            else:
                ret[view_col] = [view_req.space for _ in range(batch_size)]

    # Due to different view requirements for the different columns,
    # columns in the resulting batch may not all have the same batch size.
    return SampleBatch(ret)

The problem

The code relies on the self.view_requirements dict to generate a SampleBatch in the format that the model expects.

However, both the OBS and ACTIONS fields get preprocessed so that they are stored in a SampleBatch as flat numpy arrays (assuming no experimental flags are used).
In the normal RLlib flow, each SampleBatch is used to compute actions from a Policy, which in turn relies on ModelV2.__call__(input_dict: SampleBatch) to restore the expected gym.Space structure and then call ModelV2.forward().

The exception is raised in ModelV2.__call__ once the structured tensor is reconstructed from the flat observations mocked by _get_dummy_batch_from_view_requirements.

While _get_dummy_batch_from_view_requirements tries to mock the preprocessing for OBS and ACTIONS, in order to provide a SampleBatch indistinguishable from a real one, it does so incorrectly: it flattens both the OBS and ACTIONS spaces (contained in their view requirements) as if they were action spaces, while action spaces and observation spaces have different flattening rules (in particular, Discrete spaces are one-hot encoded for observations, while actions are not).
This is confirmed by the exception message: with a few tests and a bit of arithmetic, you can see that the mismatch comes from each Discrete space being counted as size 1 when it should be counted as its one-hot size.
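
To make the arithmetic concrete, here is a minimal sketch (using gym's flattening utilities, which, like RLlib's observation preprocessor, one-hot encode Discrete sub-spaces) of where the 82 vs. 81 in the original error comes from:

import gym
import numpy as np

# Observation space from the original reproduction script.
obs_space = gym.spaces.Dict(
    test=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(8, 10)),
    test_discrete=gym.spaces.Discrete(2),
)

# Observation-style flattening (what the model expects):
# Box(8, 10) -> 80 floats, Discrete(2) -> one-hot of length 2  => 82
print(gym.spaces.flatdim(obs_space))  # 82

# Action-style flattening (what the dummy batch effectively produces):
# Box(8, 10) -> 80 floats, Discrete(2) -> a single int         => 81
print(int(np.prod(obs_space.spaces["test"].shape)) + 1)  # 81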


The solution

I don't know which option is best for the API, but transforming ModelCatalog.get_action_shape(...) into a ModelCatalog.get_space_shape(..., one_hot=False) seems reasonable (and then differentiating the call made for SampleBatch.ACTIONS from the one made for SampleBatch.OBS); see the sketch below.

Otherwise, we could rely on the ModelV2.processed_obs_space attribute and use the shape of that space for SampleBatch.OBS.
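
For illustration, here is a minimal sketch of what such a helper could look like (get_space_shape and _flat_size are hypothetical names, not part of RLlib's current API; the return format mirrors ModelCatalog.get_action_shape's (dtype, shape) tuple):

import gym
import numpy as np


def _flat_size(space, one_hot):
    """Hypothetical helper: flattened size of `space`.

    With one_hot=True, Discrete/MultiDiscrete sub-spaces are counted by their
    one-hot size (observation-style flattening); with one_hot=False they are
    counted as plain ints (action-style flattening).
    """
    if isinstance(space, gym.spaces.Discrete):
        return space.n if one_hot else 1
    if isinstance(space, gym.spaces.MultiDiscrete):
        return int(np.sum(space.nvec)) if one_hot else len(space.nvec)
    if isinstance(space, gym.spaces.Box):
        return int(np.prod(space.shape))
    if isinstance(space, gym.spaces.Tuple):
        return sum(_flat_size(s, one_hot) for s in space.spaces)
    if isinstance(space, gym.spaces.Dict):
        return sum(_flat_size(s, one_hot) for s in space.spaces.values())
    raise NotImplementedError(space)


def get_space_shape(space, framework=None, one_hot=False):
    """Hypothetical replacement for ModelCatalog.get_action_shape.

    Same (dtype, shape) return format; the dummy-batch code would pass
    one_hot=True for SampleBatch.OBS and one_hot=False for SampleBatch.ACTIONS.
    """
    return np.float32, (None, _flat_size(space, one_hot))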


Extra

As a side note, I have a question about the NEXT_OBS field in the view requirements: given that the OBS field gets updated from the model's view requirements, which by default use the original observation space, why does NEXT_OBS not follow the same path? At the moment its view requirement corresponds to the preprocessed observation space. Could this be a bug as well?

@Rohan138 Rohan138 self-assigned this May 24, 2023