Description
What happened + What you expected to happen
If I write a custom policy class that overrides compute_actions_from_input_dict(input_dict, ...), then in older versions of RLlib this input_dict always had SampleBatch.AGENT_INDEX set correctly. With the latest version, it is set correctly except on the first timestep after an environment reset. This is because EnvRunnerV2.__process_resetted_obs_for_eval does not add an agent_index key to the input dict, unlike EnvRunnerV2._process_observations, which does add this key. It should be very simple to fix this bug: just change
acd_list: List[AgentConnectorDataType] = [
    AgentConnectorDataType(
        env_id,
        agent_id,
        {
            SampleBatch.NEXT_OBS: obs,
            SampleBatch.INFOS: infos,
            SampleBatch.T: episode.length,
        },
    )
    for agent_id, obs in agents_obs
]
to
acd_list: List[AgentConnectorDataType] = [
    AgentConnectorDataType(
        env_id,
        agent_id,
        {
            SampleBatch.NEXT_OBS: obs,
            SampleBatch.INFOS: infos,
            SampleBatch.T: episode.length,
            # Added: populate AGENT_INDEX on the reset step, too.
            SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
        },
    )
    for agent_id, obs in agents_obs
]
in EnvRunnerV2.__process_resetted_obs_for_eval, mirroring what _process_observations already does on every subsequent step.
I would submit a pull request myself, but in the past my PRs have been ignored.
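Until this is fixed upstream, I can work around it on the policy side. The sketch below is only illustrative: it assumes the reset-step indices surface as -1 (as the reproduction script below suggests), and on_valid_agent_indices is a hypothetical stand-in for whatever index-dependent logic the policy actually runs.

from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.policy.sample_batch import SampleBatch


class WorkaroundPolicy(PPOTorchPolicy):
    def compute_actions_from_input_dict(
        self, input_dict, explore=None, timestep=None, **kwargs
    ):
        indices = input_dict.get(SampleBatch.AGENT_INDEX)
        # On the first post-reset step the indices are missing or still
        # at -1, so only trust them when every entry is a real index.
        if indices is not None and all(i > -1 for i in indices):
            self.on_valid_agent_indices(indices)
        return super().compute_actions_from_input_dict(
            input_dict, explore, timestep, **kwargs
        )

    def on_valid_agent_indices(self, indices):
        # Hypothetical placeholder for index-dependent logic.
        pass

This just skips the index-dependent step on the one affected timestep instead of asserting.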
Versions / Dependencies
Ray: 2.5.1
Python: 3.9.16
OS: Ubuntu 22.04.2 LTS
Reproduction script
This script raises an AssertionError with Ray 2.5.1 but runs cleanly if I apply the change described above to EnvRunnerV2:
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.policy.sample_batch import SampleBatch


class MyCustomPolicy(PPOTorchPolicy):
    def compute_actions_from_input_dict(
        self, input_dict, explore=None, timestep=None, **kwargs
    ):
        # Fails on the first timestep after an environment reset, where
        # AGENT_INDEX is not populated by __process_resetted_obs_for_eval.
        assert all(
            agent_index > -1 for agent_index in input_dict[SampleBatch.AGENT_INDEX]
        )
        return super().compute_actions_from_input_dict(
            input_dict, explore, timestep, **kwargs
        )


class MyCustomAlgorithm(PPO):
    @classmethod
    def get_default_policy_class(cls, config):
        return MyCustomPolicy


algo = MyCustomAlgorithm(
    PPOConfig()
    .rollouts(num_rollout_workers=1)
    .resources(num_gpus=0)
    .environment(env="CartPole-v1")
)
algo.train()
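With stock Ray 2.5.1 the assertion should trip on the very first sampling step after reset, during the first algo.train() call; with the one-line AGENT_INDEX addition applied, training runs through.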
Issue Severity
Medium: It is a significant difficulty but I can work around it.