
KeyError when trying to use a trained PPO agent with LSTMs (defined in the config) with the PyTorch framework on the same env it was trained on #13026

Closed
avacaondata opened this issue Dec 22, 2020 · 13 comments
Assignees: sven1977
Labels: enhancement (Request for new feature and/or capability), P3 (Issue moderate in impact or severity)

Comments

@avacaondata

What is the problem?

I've successfully trained a PPO agent with an LSTM model (enabled via the model config) on my environment. But when I try to test the trained agent by loading it from the last checkpoint, it cannot take the first action and fails with a KeyError:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-34-94f7b1289e4c> in <module>
      1 while not done:
----> 2     action = agent.compute_action(observation=obs)
      3     obs, reward, done, info = env.step(action)
      4     episode_reward += reward
      5     print(info)

~/miniconda/envs/env_trade/lib/python3.7/site-packages/ray/rllib/agents/trainer.py in compute_action(self, observation, state, prev_action, prev_reward, info, policy_id, full_fetch, explore)
    829             info,
    830             clip_actions=self.config["clip_actions"],
--> 831             explore=explore)
    832 
    833         if state or full_fetch:

~/miniconda/envs/env_trade/lib/python3.7/site-packages/ray/rllib/policy/policy.py in compute_single_action(self, obs, state, prev_action, prev_reward, info, episode, clip_actions, explore, timestep, **kwargs)
    205             episodes=episodes,
    206             explore=explore,
--> 207             timestep=timestep)
    208 
    209         # Some policies don't return a tuple, but always just a single action.

~/miniconda/envs/env_trade/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py in compute_actions(self, obs_batch, state_batches, prev_action_batch, prev_reward_batch, info_batch, episodes, explore, timestep, **kwargs)
    167             actions, state_out, extra_fetches, logp = \
    168                 self._compute_action_helper(
--> 169                     input_dict, state_batches, seq_lens, explore, timestep)
    170 
    171             # Action-logp and action-prob.

~/miniconda/envs/env_trade/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py in _compute_action_helper(self, input_dict, state_batches, seq_lens, explore, timestep)
    247                 dist_class = self.dist_class
    248                 dist_inputs, state_out = self.model(input_dict, state_batches,
--> 249                                                     seq_lens)
    250 
    251             if not (isinstance(dist_class, functools.partial)

~/miniconda/envs/env_trade/lib/python3.7/site-packages/ray/rllib/models/modelv2.py in __call__(self, input_dict, state, seq_lens)
    206             restored["obs_flat"] = input_dict["obs"]
    207         with self.context():
--> 208             res = self.forward(restored, state or [], seq_lens)
    209         if ((not isinstance(res, list) and not isinstance(res, tuple))
    210                 or len(res) != 2):

~/miniconda/envs/env_trade/lib/python3.7/site-packages/ray/rllib/models/torch/recurrent_net.py in forward(self, input_dict, state, seq_lens)
    162                 [
    163                     wrapped_out,
--> 164                     torch.reshape(input_dict[SampleBatch.PREV_ACTIONS].float(),
    165                                   [-1, self.action_dim]),
    166                     torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(),

~/miniconda/envs/env_trade/lib/python3.7/site-packages/ray/rllib/utils/tracking_dict.py in __getitem__(self, key)
     25     def __getitem__(self, key):
     26         self.accessed_keys.add(key)
---> 27         value = dict.__getitem__(self, key)
     28         if self.get_interceptor:
     29             if key not in self.intercepted_values:

KeyError: 'prev_actions'

Ray version and other system information (Python version, TensorFlow version, OS):
System information:

  • Ray version: 1.0.1.post1
  • Python version: 3.7
  • PyTorch version: 1.7.1

Reproduction (REQUIRED)

To reproduce it, a checkpoint has to be created first, but here is the configuration I trained with:

CONFIG = {
    'num_workers': 30,
    'num_envs_per_worker': 3,
    'create_env_on_driver': False,
    'rollout_fragment_length': 1024,
    'batch_mode': 'truncate_episodes',
    'num_gpus': 2,
    'train_batch_size': 30720,
    'model': {"fcnet_hiddens": [2048, 1024],
    "fcnet_activation": "relu",
    'conv_filters': None,
    'conv_activation': 'relu',
    'free_log_std': False,
    'no_final_linear': False,
    'vf_share_layers': True,
    "use_lstm": True,
    "max_seq_len": 256,
    "lstm_cell_size": 512,
    "lstm_use_prev_action_reward": True,
    '_time_major': False,
    'framestack': True,
    'dim': 84,
    'grayscale': False,
    'zero_mean': True,
    'custom_model': None,
    'custom_model_config': {},
    'custom_action_dist': None,
    'custom_preprocessor': None},
    'optimizer': {},
    'gamma': 0.99,
    'horizon': None,
    'soft_horizon': False,
    'no_done_at_end': False,
    'env_config': {},
    'env': "BasicEnv",
    'normalize_actions': False,
    'clip_rewards': True,
    'clip_actions': True,
    'preprocessor_pref': 'deepmind',
    'lr': 5e-05,
    'monitor': False,
    'log_level': 'WARN',
    'ignore_worker_failures': True,
    'log_sys_usage': True,
    'fake_sampler': False,
    'framework': 'torch',
    'eager_tracing': False,
    'explore': True,
    'exploration_config': {'type': 'StochasticSampling'},
    'evaluation_interval': None,
    'evaluation_num_episodes': 10,
    'in_evaluation': False,
    'evaluation_config': {},
    'evaluation_num_workers': 0,
    'custom_eval_function': None,
    'sample_async': False,
    '_use_trajectory_view_api': True,
    'observation_filter': 'MeanStdFilter',
    'synchronize_filters': True,
    'tf_session_args': {
        'intra_op_parallelism_threads': 2,
        'inter_op_parallelism_threads': 2,
        'gpu_options': {'allow_growth': True},
        'log_device_placement': False,
        'device_count': {'CPU': 1},
        'allow_soft_placement': True,
    },
    'local_tf_session_args': {
        'intra_op_parallelism_threads': 8,
        'inter_op_parallelism_threads': 8,
    },
    'compress_observations': False,
    'collect_metrics_timeout': 180,
    'metrics_smoothing_episodes': 5,
    'remote_worker_envs': False,
    'remote_env_batch_wait_ms': 0,
    'min_iter_time_s': 0,
    'timesteps_per_iteration': 0,
    'seed': None,
    'extra_python_environs_for_driver': {},
    'extra_python_environs_for_worker': {},
    'num_cpus_per_worker': 1,
    'num_gpus_per_worker': 0,
    'custom_resources_per_worker': {},
    'num_cpus_for_driver': 1,
    'memory': 0,
    'object_store_memory': 0,
    'memory_per_worker': 0,
    'object_store_memory_per_worker': 0,
    'input': 'sampler',
    'input_evaluation': ['is', 'wis'],
    'postprocess_inputs': False,
    'shuffle_buffer_size': 0,
    'output': None,
    'output_compress_columns': ['obs', 'new_obs'],
    'output_max_file_size': 67108864,
    'multiagent': {
        'policies': {},
        'policy_mapping_fn': None,
        'policies_to_train': None,
        'observation_fn': None,
        'replay_mode': 'independent',
    },
    'logger_config': None,
    'replay_sequence_length': 1,
    'use_critic': True,
    'use_gae': True,
    'lambda': 1.0,
    'kl_coeff': 0.2,
    'sgd_minibatch_size': 1024,
    'shuffle_sequences': True,
    'num_sgd_iter': 10,
    'lr_schedule': None,
    'vf_share_layers': False,
    'vf_loss_coeff': 1.0,
    'entropy_coeff': 0.0,
    'entropy_coeff_schedule': None,
    'clip_param': 0.3,
    'vf_clip_param': 30000,
    "grad_clip": 1.0,
    'kl_target': 0.01,
    'simple_optimizer': False,
    '_fake_gpus': False
}

checkpoint = "" # my checkpoint goes here

agent = ppo.PPOTrainer(config=config, env="BasicEnv")

agent.restore(checkpoint)

episode_reward = 0
done = False
obs = env.reset()
while not done:
    action = agent.compute_action(observation=obs, prev_action=np.array([0.0]*6), 
                                 prev_reward = [0.0], full_fetch=True)
    obs, reward, done, info = env.step(action)
    episode_reward += reward
    print(info)

If the code snippet cannot be run by itself, the issue will be closed with "needs-repro-script".

  • [x] I have verified my script runs in a clean environment and reproduces the issue.
  • [ ] I have verified the issue also occurs with the latest wheels.

@deanwampler @ericl @rshin @yaroslavvb

@avacaondata added labels: bug (Something that is supposed to be working; but isn't), triage (Needs triage (eg: priority, bug/not-bug, and owning component)) on Dec 22, 2020
@51616

51616 commented Dec 22, 2020

I have the same issue when using a custom RNN model. It seems like some keys in input_dict (e.g., prev_actions and prev_rewards) are not present during the first forward pass. This did not occur in a previous release, though (the same code runs fine on 0.8.5).

Edit: mine happened during training initialization, but with the same error as above.

Ray 1.0.1.post1
tensorflow 2.4.0
python 3.7.9

@avacaondata
Author

@51616 What was the last Ray version your code worked with? Downgrading to it could be a workaround until this issue is solved.

@51616

51616 commented Dec 22, 2020

@alexvaca0 I upgraded directly from 0.8.5 to master, so I have no idea which versions in between work apart from 0.8.5. The APIs have changed a lot since then, though, so you may have to fix a lot of your code just to make it work with the old ones. It could be a band-aid solution for now.

@avacaondata
Author

@deanwampler @rshin @ericl Could someone please take a look at this?

@51616

51616 commented Dec 22, 2020

@alexvaca0 Actually, I managed to run code similar to your example without the error. The return value of compute_action is a 3-tuple (action, state, logits); you should use only the action, not the whole tuple, in your example.
You should also try passing prev_reward and prev_action as plain scalars and adding a state argument to the compute_action call, without full_fetch.

I did not load a checkpoint, though, which might cause different behaviour. You should try it with your checkpoint.

@avacaondata
Author

Could you post a code example of how you'd add those parameters to the function call in agent.compute_action()? @51616

@avacaondata
Author

I'm trying this:

action=agent.compute_action(obs, prev_action=np.array([0.0]*6), prev_reward=0, state=np.zeros((1,1,256, 512)))

But it gives the following error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-24-d643ff346dc1> in <module>
      2     #action = agent.compute_action(observation=np.array([[o for o in obs], [0.0]*obs.shape[0]]), prev_action=np.array([[0.0]*6]*256), prev_reward=np.array([0.0]*256),)
      3     #                             #prev_reward = [0.0], full_fetch=True)
----> 4     action=agent.compute_action(obs, prev_action=np.array([0.0]*6), prev_reward=0, state=np.zeros((1,1,256, 512)))
      5     obs, reward, done, info = env.step(action)
      6     episode_reward += reward

~/miniconda/envs/env_trade/lib/python3.7/site-packages/ray/rllib/agents/trainer.py in compute_action(self, observation, state, prev_action, prev_reward, info, policy_id, full_fetch, explore)
    843             info,
    844             clip_actions=self.config["clip_actions"],
--> 845             explore=explore)
    846 
    847         if state or full_fetch:

~/miniconda/envs/env_trade/lib/python3.7/site-packages/ray/rllib/policy/policy.py in compute_single_action(self, obs, state, prev_action, prev_reward, info, episode, clip_actions, explore, timestep, **kwargs)
    215             episodes=episodes,
    216             explore=explore,
--> 217             timestep=timestep)
    218 
    219         # Some policies don't return a tuple, but always just a single action.

~/miniconda/envs/env_trade/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py in compute_actions(self, obs_batch, state_batches, prev_action_batch, prev_reward_batch, info_batch, episodes, explore, timestep, **kwargs)
    168             ]
    169             return self._compute_action_helper(input_dict, state_batches,
--> 170                                                seq_lens, explore, timestep)
    171 
    172     @override(Policy)

~/miniconda/envs/env_trade/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py in _compute_action_helper(self, input_dict, state_batches, seq_lens, explore, timestep)
    230                 dist_class = self.dist_class
    231                 dist_inputs, state_out = self.model(input_dict, state_batches,
--> 232                                                     seq_lens)
    233 
    234             if not (isinstance(dist_class, functools.partial)

~/miniconda/envs/env_trade/lib/python3.7/site-packages/ray/rllib/models/modelv2.py in __call__(self, input_dict, state, seq_lens)
    208             restored["obs_flat"] = input_dict["obs"]
    209         with self.context():
--> 210             res = self.forward(restored, state or [], seq_lens)
    211         if ((not isinstance(res, list) and not isinstance(res, tuple))
    212                 or len(res) != 2):

~/miniconda/envs/env_trade/lib/python3.7/site-packages/ray/rllib/models/torch/recurrent_net.py in forward(self, input_dict, state, seq_lens)
    192         # Then through our LSTM.
    193         input_dict["obs_flat"] = wrapped_out
--> 194         return super().forward(input_dict, state, seq_lens)
    195 
    196     @override(RecurrentNetwork)

~/miniconda/envs/env_trade/lib/python3.7/site-packages/ray/rllib/models/torch/recurrent_net.py in forward(self, input_dict, state, seq_lens)
     81             time_major=self.time_major,
     82         )
---> 83         output, new_state = self.forward_rnn(inputs, state, seq_lens)
     84         output = torch.reshape(output, [-1, self.num_outputs])
     85         return output, new_state

~/miniconda/envs/env_trade/lib/python3.7/site-packages/ray/rllib/models/torch/recurrent_net.py in forward_rnn(self, inputs, state, seq_lens)
    209             inputs,
    210             [torch.unsqueeze(state[0], 0),
--> 211              torch.unsqueeze(state[1], 0)])
    212         # Re-apply paddings.
    213         # if time_major and max_seq_len > 1:

IndexError: list index out of range

@51616

51616 commented Dec 22, 2020

agent = ppo.PPOTrainer(config=config, env="env_name")
state = agent.get_policy().model.get_initial_state()
env = create_env()
obs = env.reset()
done = False
while not done:
    # compute_action returns (action, state_out, logits) when a state is passed.
    action, state, logit = agent.compute_action(observation=obs, prev_action=1.0,
                                                prev_reward=0.0, state=state)
    obs, reward, done, info = env.step(action)

@avacaondata
Author

Thank you so much!! :) @51616

@51616

51616 commented Dec 22, 2020

I found a solution for my custom model too. Here's what's needed for the model to use prev_rewards and prev_actions in the forward call:

self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
    ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
                    shift=-1)
self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
    ViewRequirement(SampleBatch.REWARDS, shift=-1)

Taken from the default LSTM model in recurrent_net.py.
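
For reference, here is a minimal sketch of where those two lines go in a custom model. This is not the model from this thread: the class name is hypothetical, and it assumes Ray 1.0.x (where the attribute is still called inference_view_requirements; later releases renamed it to view_requirements), a flat Box observation space and a Box action space.

import numpy as np
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()


class PrevActionRewardLSTM(TorchRNN, nn.Module):
    """Hypothetical custom LSTM that feeds prev. action/reward into the cell."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, cell_size=256):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        self.cell_size = cell_size
        self.action_dim = int(np.prod(action_space.shape))
        obs_dim = int(np.prod(obs_space.shape))
        # LSTM input = flat obs + flat prev. action + prev. reward.
        self.lstm = nn.LSTM(obs_dim + self.action_dim + 1, self.cell_size,
                            batch_first=True)
        self._logits = nn.Linear(self.cell_size, num_outputs)
        self._value_branch = nn.Linear(self.cell_size, 1)
        self._features = None

        # The crucial part: declare that this model needs PREV_ACTIONS and
        # PREV_REWARDS; without these entries they never appear in
        # `input_dict` at inference time (hence the KeyError above).
        self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
            ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
                            shift=-1)
        self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
            ViewRequirement(SampleBatch.REWARDS, shift=-1)

    @override(TorchRNN)
    def forward(self, input_dict, state, seq_lens):
        # Concatenate obs, prev. action and prev. reward into one flat vector,
        # then let TorchRNN.forward() add the time dimension and call
        # forward_rnn().
        input_dict["obs_flat"] = torch.cat(
            [input_dict["obs_flat"].float(),
             torch.reshape(input_dict[SampleBatch.PREV_ACTIONS].float(),
                           [-1, self.action_dim]),
             torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(),
                           [-1, 1])],
            dim=1)
        return super().forward(input_dict, state, seq_lens)

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        self._features, [h, c] = self.lstm(
            inputs,
            [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)])
        return self._logits(self._features), [h.squeeze(0), c.squeeze(0)]

    @override(ModelV2)
    def get_initial_state(self):
        # Zero-filled [h, c] for one sequence.
        return [self._logits.weight.new_zeros(self.cell_size),
                self._logits.weight.new_zeros(self.cell_size)]

    @override(ModelV2)
    def value_function(self):
        return torch.reshape(self._value_branch(self._features), [-1])

Such a model would then be registered with ModelCatalog.register_custom_model(...) and selected via 'custom_model' in the model config.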

@avacaondata
Author

Yeah, I took a look at that too, but the thing is that internally there should be some place where, if PREV_ACTIONS or PREV_REWARDS is None, it gets filled with zeros (the same as with state). There probably is such a place in the trainer (or somewhere else) for the training path, since training worked properly in my case.
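
Until RLlib fills these in automatically, the zero-filling can be done on the caller side. A small sketch, assuming the Box(6,)-shaped continuous action space implied by the np.array([0.0]*6) call above and the agent/env objects from the reproduction script:

import numpy as np

# Zero-filled "previous" values for the very first step.
state = agent.get_policy().get_initial_state()  # zero-filled [h, c]
prev_action = np.zeros(env.action_space.shape, dtype=np.float32)
prev_reward = 0.0

obs = env.reset()
done = False
while not done:
    # With `state` passed in, compute_action returns (action, state_out, extra).
    action, state, _ = agent.compute_action(
        observation=obs, state=state,
        prev_action=prev_action, prev_reward=prev_reward)
    obs, prev_reward, done, info = env.step(action)
    prev_action = action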

@sven1977 sven1977 self-assigned this Dec 30, 2020
@sven1977 added labels: P1 (Issue that should be fixed within a few weeks), rllib; removed label: triage (Needs triage (eg: priority, bug/not-bug, and owning component)) on Dec 30, 2020
@ericl ericl added this to the RLlib Bugs milestone Mar 11, 2021
@ericl ericl removed the rllib label Mar 11, 2021
@sven1977 added labels: enhancement (Request for new feature and/or capability), P3 (Issue moderate in impact or severity); removed labels: P1 (Issue that should be fixed within a few weeks), bug (Something that is supposed to be working; but isn't) on Mar 22, 2021
@sven1977
Contributor

sven1977 commented Mar 22, 2021

Downgrading this to P3.

  1. Models have to declare that they require PREV_ACTIONS/PREV_REWARDS via their self.view_requirements dict.
    See the examples here:
    rllib/models/tf/recurrent_net.py::LSTMWrapper (most recent previous action/reward)
    rllib/models/torch/recurrent_net.py::LSTMWrapper
    rllib/models/tf/attention_net.py::AttentionWrapper (range of the previous n actions/rewards)
    rllib/models/torch/attention_net.py::AttentionWrapper
    rllib/examples/models/trajectory_view_utilizing_models.py
  2. When calling Trainer.compute_action() with an RNN model (or any model that requires recent actions/rewards), one has to provide the internal state and the previous action/reward. I.e., the above reproduction script works fine when the required information is passed on each Trainer.compute_action() call:
...

agent = ppo.PPOTrainer(config=config)
env = StatelessCartPole()
agent.restore(checkpoint)

episode_reward = 0
done = False
obs = env.reset()
state = agent.get_policy().get_initial_state()
reward = 0.0
action = 0
while not done:
    action, state, _ = agent.compute_action(observation=obs, state=state, prev_action=action,
                                         prev_reward=reward, full_fetch=True)
    obs, reward, done, info = env.step(action)
    print(info)
    episode_reward += reward
    if done:
        obs = env.reset()
        state = agent.get_policy().get_initial_state()
        reward = 0.0
        action = 0

@amanarora28

amanarora28 commented Feb 19, 2024

I'm facing the same issue, but the KeyError comes for the key '1', and compute_single_action does not work (Ray version 2.9.2).

Running this code

done = False
obs = env.reset()
state = algo.get_policy().model.get_initial_state()

while not done:
    action = algo.compute_single_action(obs, prev_action=0, prev_reward=0.0,state=state)
    obs, reward, done, info = env.step(action)

getting this error

KeyError                                  Traceback (most recent call last)
Cell In[11], line 64
     61 obs = env.reset()
     63 while not done:
---> 64     action = algo.compute_single_action(obs)
     65     obs, reward, done, info = env.step(action)
     66     episode_reward += reward

File ~/Projects/ving/code/pynotebook/ving_venv/lib/python3.11/site-packages/ray/util/tracing/tracing_helper.py:467, in _inject_tracing_into_class.<locals>.span_wrapper.<locals>._resume_span(self, _ray_trace_ctx, *_args, **_kwargs)
    465 # If tracing feature flag is not on, perform a no-op
    466 if not _is_tracing_enabled() or _ray_trace_ctx is None:
--> 467     return method(self, *_args, **_kwargs)
    469 tracer: _opentelemetry.trace.Tracer = _opentelemetry.trace.get_tracer(
    470     __name__
    471 )
    473 # Retrieves the context from the _ray_trace_ctx dictionary we
    474 # injected.

File ~/Projects/ving/code/pynotebook/ving_venv/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py:1836, in Algorithm.compute_single_action(self, observation, state, prev_action, prev_reward, info, input_dict, policy_id, full_fetch, explore, timestep, episode, unsquash_action, clip_action, **kwargs)
   1828     action, state, extra = policy.compute_single_action(
   1829         input_dict=input_dict,
   1830         explore=explore,
   1831         timestep=timestep,
   1832         episode=episode,
   1833     )
   1834 # Individual args.
   1835 else:
-> 1836     action, state, extra = policy.compute_single_action(
   1837         obs=observation,
   1838         state=state,
   1839         prev_action=prev_action,
   1840         prev_reward=prev_reward,
   1841         info=info,
   1842         explore=explore,
   1843         timestep=timestep,
   1844         episode=episode,
   1845     )
   1847 # If we work in normalized action space (normalize_actions=True),
   1848 # we re-translate here into the env's action space.
   1849 if unsquash_action:

File ~/Projects/ving/code/pynotebook/ving_venv/lib/python3.11/site-packages/ray/rllib/policy/policy.py:559, in Policy.compute_single_action(self, obs, state, prev_action, prev_reward, info, input_dict, episode, explore, timestep, **kwargs)
    556 if episode is not None:
    557     episodes = [episode]
--> 559 out = self.compute_actions_from_input_dict(
    560     input_dict=SampleBatch(input_dict),
    561     episodes=episodes,
    562     explore=explore,
    563     timestep=timestep,
    564 )
    566 # Some policies don't return a tuple, but always just a single action.
    567 # E.g. ES and ARS.
    568 if not isinstance(out, tuple):

File ~/Projects/ving/code/pynotebook/ving_venv/lib/python3.11/site-packages/ray/rllib/policy/torch_policy_v2.py:572, in TorchPolicyV2.compute_actions_from_input_dict(self, input_dict, explore, timestep, **kwargs)
    565 if state_batches:
    566     seq_lens = torch.tensor(
    567         [1] * len(state_batches[0]),
    568         dtype=torch.long,
    569         device=state_batches[0].device,
    570     )
--> 572 return self._compute_action_helper(
    573     input_dict, state_batches, seq_lens, explore, timestep
    574 )

File ~/Projects/ving/code/pynotebook/ving_venv/lib/python3.11/site-packages/ray/rllib/utils/threading.py:24, in with_lock.<locals>.wrapper(self, *a, **k)
     22 try:
     23     with self._lock:
---> 24         return func(self, *a, **k)
     25 except AttributeError as e:
     26     if "has no attribute '_lock'" in e.args[0]:

File ~/Projects/ving/code/pynotebook/ving_venv/lib/python3.11/site-packages/ray/rllib/policy/torch_policy_v2.py:1293, in TorchPolicyV2._compute_action_helper(self, input_dict, state_batches, seq_lens, explore, timestep)
   1291 else:
   1292     dist_class = self.dist_class
-> 1293     dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)
   1295 if not (
   1296     isinstance(dist_class, functools.partial)
   1297     or issubclass(dist_class, TorchDistributionWrapper)
   1298 ):
   1299     raise ValueError(
   1300         "`dist_class` ({}) not a TorchDistributionWrapper "
   1301         "subclass! Make sure your `action_distribution_fn` or "
   1302         "`make_model_and_action_dist` return a correct "
   1303         "distribution class.".format(dist_class.__name__)
   1304     )

File ~/Projects/ving/code/pynotebook/ving_venv/lib/python3.11/site-packages/ray/rllib/models/modelv2.py:263, in ModelV2.__call__(self, input_dict, state, seq_lens)
    260         restored["obs_flat"] = input_dict["obs"]
    262 with self.context():
--> 263     res = self.forward(restored, state or [], seq_lens)
    265 if isinstance(input_dict, SampleBatch):
    266     input_dict.accessed_keys = restored.accessed_keys - {"obs_flat"}

File ~/Projects/ving/code/pynotebook/ving_venv/lib/python3.11/site-packages/ray/rllib/models/torch/complex_input_net.py:211, in ComplexInputNetwork.forward(self, input_dict, state, seq_lens)
    209         outs.append(one_hot_out)
    210     else:
--> 211         nn_out, _ = self.flatten[i](
    212             SampleBatch(
    213                 {
    214                     SampleBatch.OBS: torch.reshape(
    215                         component, [-1, self.flatten_dims[i]]
    216                     )
    217                 }
    218             )
    219         )
    220         outs.append(nn_out)
    222 # Concat all outputs and the non-image inputs.

File ~/Projects/ving/code/pynotebook/ving_venv/lib/python3.11/site-packages/torch/nn/modules/container.py:461, in ModuleDict.__getitem__(self, key)
    459 @_copy_to_script_wrapper
    460 def __getitem__(self, key: str) -> Module:
--> 461     return self._modules[key]

KeyError: '1'
