
[RLlib] When using gym.spaces.Dict as observation_space the method export_model() breaks #26782

Open
Stefan-1313 opened this issue Jul 20, 2022 · 5 comments
Labels
bug Something that is supposed to be working; but isn't P1 Issue that should be fixed within a few weeks rllib RLlib related issues rllib-env rllib env related issues rllib-torch The problems that only happen in torch on rllib

Comments


Stefan-1313 commented Jul 20, 2022

What happened + What you expected to happen

When I create a custom environment that uses a gym.spaces.Dict as its observation_space, I cannot export the model to TorchScript.

For this I use:

agent = SimpleQTrainer(config=config_simple, env=env_name)
agent.get_policy().export_model(export_dir)

See below for full code to reproduce the issue.
When running it, I get the following error:

2022-07-20 16:19:23,831	WARNING env.py:135 -- Your env doesn't have a .spec.max_episode_steps attribute. This is fine if you have set 'horizon' in your config dictionary, or `soft_horizon`. However, if you haven't, 'horizon' will default to infinity, and your environment will not be reset.
Traceback (most recent call last):
  File "C:\Users\Stefan\Documents\Dev\Ray RLlib\minimal_ray_working_example (with export error)\main_MinWorkingExample.py", line 79, in <module>
    agent.get_policy().export_model(export_dir)
  File "C:\ProgramData\Miniconda3\envs\cenv39rl\lib\site-packages\ray\rllib\policy\torch_policy.py", line 912, in export_model
    traced = torch.jit.trace(self.model, (dummy_inputs, state_ins, seq_lens))
  File "C:\ProgramData\Miniconda3\envs\cenv39rl\lib\site-packages\torch\jit\_trace.py", line 741, in trace
    return trace_module(
  File "C:\ProgramData\Miniconda3\envs\cenv39rl\lib\site-packages\torch\jit\_trace.py", line 958, in trace_module
    module._c._create_method_from_trace(
RuntimeError: Tracer cannot infer type of ({'obs': {'obs_1': tensor([[0., 0., 0., 0.],
        ...,
        [0., 0., 0., 0.]])}, 'new_obs': {'obs_1': tensor([[0., 0., 0., 0.], ...])},
'actions': tensor([0, 0, ..., 0]), 'prev_actions': tensor([0, 0, ..., 0]),
'rewards': tensor([0., 0., ..., 0.]), 'prev_rewards': tensor([0., 0., ..., 0.]),
'dones': tensor([0., 0., ..., 0.]), 'infos': tensor([0., 0., ..., 0.]),
'eps_id': tensor([0., 0., ..., 0.]), 'unroll_id': tensor([0., 0., ..., 0.]),
'agent_index': tensor([0., 0., ..., 0.]), 't': tensor([0., 0., ..., 0.]),
'q_values': tensor([[0., 0., 0., 0.], ...]), 'action_dist_inputs': tensor([[0., 0., 0., 0.], ...]),
'action_prob': tensor([0., 0., ..., 0.]), 'action_logp': tensor([0., 0., ..., 0.]),
'state_in_0': tensor([1.]), 'seq_lens': tensor([1.])}, [tensor([1.])], tensor([1.]))
:Dictionary inputs to traced functions must have consistent type. Found Dict[str, Tensor] and Tensor

Process finished with exit code 1

The reproduction script also includes commented-out lines that use a gym.spaces.Box as observation_space instead; the inline comments indicate which lines to toggle. With the Box space, the issue no longer occurs.

Versions / Dependencies

Python v3.9.12
pip install Ray==1.13.0
pip install Ray[tune]==1.13.0
pip install Ray[rllib]==1.13.0

Reproduction script

from typing import Dict, List, Tuple
import torch.nn as nn
import gym
import numpy as np
from ray.rllib.agents.dqn.simple_q import DEFAULT_CONFIG, SimpleQTrainer
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.typing import TensorType, ModelConfigDict
from ray.tune.registry import register_env


class SimpleEnv(gym.Env):
    def __init__(self):
        self.shape = 4
        self.action_space = gym.spaces.Discrete(self.shape)
        self.observation_space = gym.spaces.Dict({"obs_1": gym.spaces.Box(low=np.float32(0),
                                                                          high=np.float32(1),
                                                                          shape=(self.shape,),
                                                                          dtype=np.float32)})
        # self.observation_space = gym.spaces.Box(low=np.float32(0),
        #                                         high=np.float32(1),
        #                                         shape=(self.shape,),
        #                                         dtype=np.float32)

    def reset(self):
        return {"obs_1": np.zeros(self.shape, dtype=np.float32)}  # Comment out if only using Box gym space.
        # return np.zeros(self.shape, dtype=np.float32)  # Uncomment out if only using Box gym space.

    def step(self, action):
        state = {"obs_1": np.zeros(self.shape, dtype=np.float32)}  # Comment out if only using Box gym space.
        # state = np.zeros(self.shape, dtype=np.float32)  # Uncomment out if only using Box gym space.
        return state, 1, False, {}


class SimpleNetwork(TorchModelV2, nn.Module):
    def __init__(self,
                 obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 num_outputs: int,
                 model_config: ModelConfigDict,
                 name: str,
                 **customized_model_kwargs):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        self.layers = nn.Sequential(nn.Linear(in_features=4, out_features=4, bias=True),
                                    nn.ReLU(),
                                    nn.Linear(in_features=4, out_features=num_outputs, bias=True))

    def forward(self, input_dict: Dict[str, TensorType], state: List[TensorType], seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        obs = input_dict["obs"]
        obs_1 = obs["obs_1"]  # Comment out if using the Box space.
        # obs_1 = obs  # Uncomment if using the Box space.

        q_values = self.layers(obs_1)
        return q_values, state

# Define names.
env_name = "simple_env"
model_name = "simple_network"

# Create experiment config.
config_simple = DEFAULT_CONFIG.copy()
config_simple["model"] = {"custom_model": model_name}
config_simple["framework"] = "torch"
config_simple["env"] = env_name
config_simple["_disable_preprocessor_api"] = True
config_simple["num_gpus"] = 0

# Register custom model and environment.
register_env("simple_env", lambda config: SimpleEnv())
ModelCatalog.register_custom_model(model_name, SimpleNetwork)

# Create agent.
agent = SimpleQTrainer(config=config_simple, env=env_name)

# Train network.
for n in range(2):
    result = agent.train()
agent.stop()

# Export network.
export_dir = r'D:\Default_Folders\Documents\_TEMP\U-SKU-test-runs\TESTER'
agent.get_policy().export_model(export_dir)

Issue Severity

High: It blocks me from completing my task.

@Stefan-1313 Stefan-1313 added bug Something that is supposed to be working; but isn't triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Jul 20, 2022

Stefan-1313 commented Jul 20, 2022

I am aware of issue #26593, filed by a co-worker. It seems totally different, but that issue also only occurs when using a gym.spaces.Dict as observation_space, and its reproduction code is practically identical.

Because of that issue I tried setting config_simple["_disable_preprocessor_api"] = False.

With that setting, the export works as expected in both cases (gym.spaces.Dict and gym.spaces.Box as observation_space).
I get this output, in case it is relevant:

2022-07-20 16:41:05,577	WARNING env.py:135 -- Your env doesn't have a .spec.max_episode_steps attribute. This is fine if you have set 'horizon' in your config dictionary, or `soft_horizon`. However, if you haven't, 'horizon' will default to infinity, and your environment will not be reset.
C:\ProgramData\Miniconda3\envs\cenv39rl\lib\site-packages\ray\rllib\models\modelv2.py:444: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if len(obs.shape) < 2 or obs.shape[-1] != prep.shape[0]:
C:\ProgramData\Miniconda3\envs\cenv39rl\lib\site-packages\torch\jit\_trace.py:958: TracerWarning: Encountering a list at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.
  module._c._create_method_from_trace(

Process finished with exit code 0

Why this does not solve my problem

It is worth mentioning that I need to set config_simple["_disable_preprocessor_api"] = True because I have a custom Gym space that cannot be used otherwise.
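
To illustrate the kind of space I mean (a hypothetical minimal sketch, not our actual space): any gym.spaces.Space subclass that the built-in preprocessors cannot flatten, for example one with variable-length samples.

import gym
import numpy as np

class VariableLengthSpace(gym.spaces.Space):
    """Hypothetical space whose samples vary in length.

    The built-in preprocessors have no flattening rule for a space
    like this, hence "_disable_preprocessor_api" = True is required.
    """

    def __init__(self, max_len: int):
        super().__init__(shape=None, dtype=np.float32)
        self.max_len = max_len

    def sample(self):
        # Each sample can have a different length.
        n = np.random.randint(1, self.max_len + 1)
        return np.zeros(n, dtype=np.float32)

    def contains(self, x):
        return isinstance(x, np.ndarray) and x.ndim == 1 and x.size <= self.max_len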

@Stefan-1313 Stefan-1313 changed the title [RLlib] When using gym.spaces.Dict as observation_space export_model breaks [RLlib] When using gym.spaces.Dict as observation_space the method export_model() breaks Jul 20, 2022

Stefan-1313 commented Jul 20, 2022

To summarize, the problem occurs only when combining both of the following:

  • gym.spaces.Dict as observation_space
  • config_simple["_disable_preprocessor_api"] = True

@kouroshHakha kouroshHakha self-assigned this Jul 28, 2022
@kouroshHakha kouroshHakha added rllib RLlib related issues P0 Issue that must be fixed in short order and removed triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Jul 28, 2022
kouroshHakha (Contributor) commented

I have looked into the issue, and it is rooted in torch.jit.trace not accepting dicts with mixed value types as inputs to the nn.Modules being traced. You can find more info here: pytorch/pytorch#16847 (comment)
The reason it gets solved with "_disable_preprocessor_api" = False is that, under the hood, the input sample_batch then only contains flattened tensors, so you don't get nested dictionaries with tensor leaves. I cannot think of any more elegant solution here. @Stefan-1313, do you have any suggestions? I don't see why trace does not support this common use case by default.
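
For reference, the limitation reproduces outside of RLlib with plain PyTorch. A minimal sketch (ToyModel is made up for illustration; exact tracer behavior may vary across torch versions):

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def forward(self, input_dict):
        # Only the nested observation is actually used.
        return input_dict["obs"]["obs_1"] * 2.0

model = ToyModel()
obs = {"obs_1": torch.zeros(32, 4)}

# Consistent value types (Dict[str, Dict[str, Tensor]]): traces fine.
torch.jit.trace(model, ({"obs": obs, "new_obs": obs},))

# Mixed value types (a Dict[str, Tensor] next to a plain Tensor) fail with:
# RuntimeError: ... Dictionary inputs to traced functions must have
# consistent type. Found Dict[str, Tensor] and Tensor
torch.jit.trace(model, ({"obs": obs, "seq_lens": torch.ones(1)},))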

@kouroshHakha kouroshHakha added P1 Issue that should be fixed within a few weeks and removed P0 Issue that must be fixed in short order labels Aug 4, 2022
@kouroshHakha kouroshHakha removed their assignment Aug 5, 2022

Stefan-1313 commented Aug 10, 2022

Hi @kouroshHakha,
Thanks for looking into it.

Why we cannot use "_disable_preprocessor_api" = False:
We use a CustomGymSpace, which is why "_disable_preprocessor_api" = True is required. Setting "_disable_preprocessor_api" = False is therefore not feasible for us.

Best solution?
The best solution would be for torch.jit.trace to support this. I believe that under the hood this is all implemented in C++, and the exported modules can also be imported in C++ (using LibTorch), which we actually do for production scenarios. So I can imagine it is difficult (impossible?) for C++ to handle list/dict/vector types whose elements have different types.
Maybe that is why torch.jit.trace does not support this. Otherwise, this would be the best solution!

Our temporary solution:
We only use the "obs" key-values (the observation from the environment) in the input_dict. We don't use anything else in there, so maybe the rest can simply be left out of the input_dict. I don't know whether this holds in general, so it may not be a feasible solution.

I do have a workaround, which works but which I don't think is great: I manually call torch.jit.trace on my custom_model and provide as input an input_dict containing only the entries I need (and thus only consistent data types); see the sketch below.

I'm afraid this will cause a lot of maintenance issues and may break with future versions of Ray and with different models/architectures, and it may be hard to generalize. It may provide a hint for you, however.
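
A rough sketch of this workaround, using the reproduction script above (the dummy shapes and strict=False are my own choices, not something RLlib prescribes):

import torch

# The trained policy model is a plain nn.Module under the hood.
model = agent.get_policy().model

# Provide ONLY the keys that SimpleNetwork.forward() actually reads,
# so all dict values share one type (Dict[str, Tensor]).
dummy_input_dict = {"obs": {"obs_1": torch.zeros(1, 4)}}
state_ins = [torch.ones(1)]
seq_lens = torch.ones(1)

# strict=False because forward() returns the state as a Python list.
traced = torch.jit.trace(model, (dummy_input_dict, state_ins, seq_lens), strict=False)
traced.save("simple_network_traced.pt")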

A feasible general solution?
The forward() method of a custom model gets three inputs:

  • input_dict: Dict[str, Any] ==> but we want this to be input_dict: Dict[str, Dict[str, TensorType]], so that the dictionary has consistent value types, which solves the problem.
  • state: List[TensorType]
  • seq_lens: TensorType

From my observation, the input_dict holds not only the observation (which has the consistent type Dict[str, TensorType]) but ALSO another copy of state and seq_lens! Both have a type other than Dict[str, TensorType], which causes the problem. I think neither has to be inside the input_dict, as both are already supplied to the forward() method as separate inputs.

So maybe the solution is simple?
Of course, there is more I don't know than I do know, so it is very possible that I am overlooking something.
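
To make the idea concrete, this is a sketch of the trace inputs I would like RLlib to build (key names taken from the error message above; shapes are just examples):

import torch

# Every value in input_dict has the same type, Dict[str, Tensor]:
input_dict = {
    "obs": {"obs_1": torch.zeros(32, 4)},
    "new_obs": {"obs_1": torch.zeros(32, 4)},
}
# state and seq_lens stay separate arguments, exactly as forward()
# already receives them:
state = [torch.ones(1)]
seq_lens = torch.ones(1)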

Missourl commented

Hello,

I am also using a dict observation space, but I don't understand how to customize my model, given that I am using DDPG.

Would anybody please give me a hint?

Thank you

@Rohan138 Rohan138 added rllib-torch The problems that only happen in torch on rllib rllib-env rllib env related issues labels Apr 20, 2023