[RLlib] agent index is missing from policy input dict on environment reset #37521

Closed
@cassidylaidlaw

Description

What happened + What you expected to happen

If I write a custom policy class that overrides compute_actions_from_input_dict(input_dict, ...), then in older versions of RLlib this input_dict always had SampleBatch.AGENT_INDEX set correctly. With the latest version, it is set correctly except on the first timestep after an environment reset. This is because EnvRunnerV2.__process_resetted_obs_for_eval does not add an agent_index key to the input dict, unlike EnvRunnerV2._process_observations, which does. It should be very simple to fix this bug: just change

    acd_list: List[AgentConnectorDataType] = [
        AgentConnectorDataType(
            env_id,
            agent_id,
            {
                SampleBatch.NEXT_OBS: obs,
                SampleBatch.INFOS: infos,
                SampleBatch.T: episode.length,
            },
        )
        for agent_id, obs in agents_obs
    ]

to

    acd_list: List[AgentConnectorDataType] = [
        AgentConnectorDataType(
            env_id,
            agent_id,
            {
                SampleBatch.NEXT_OBS: obs,
                SampleBatch.INFOS: infos,
                SampleBatch.T: episode.length,
                SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
            },
        )
        for agent_id, obs in agents_obs
    ]

in EnvRunnerV2.__process_resetted_obs_for_eval.

I would submit a pull request myself, but in the past my PRs have been ignored.

Versions / Dependencies

Ray: 2.5.1
Python: 3.9.16
OS: Ubuntu 22.04.2 LTS

Reproduction script

This gives an error with Ray 2.5.1, but it is fixed if I change the code snippet mentioned above in EnvRunnerV2:

from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.policy.sample_batch import SampleBatch


class MyCustomPolicy(PPOTorchPolicy):
    def compute_actions_from_input_dict(
        self, input_dict, explore=None, timestep=None, **kwargs
    ):
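        # With Ray 2.5.1 this check errors on the first timestep after an
        # environment reset, because EnvRunnerV2.__process_resetted_obs_for_eval
        # does not fill in the agent index.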
        assert all(
            agent_index > -1 for agent_index in input_dict[SampleBatch.AGENT_INDEX]
        )
        return super().compute_actions_from_input_dict(
            input_dict, explore, timestep, **kwargs
        )


class MyCustomAlgorithm(PPO):
    @classmethod
    def get_default_policy_class(cls, config):
        return MyCustomPolicy


algo = MyCustomAlgorithm(
    PPOConfig()
    .rollouts(num_rollout_workers=1)
    .resources(num_gpus=0)
    .environment(env="CartPole-v1")
)

algo.train()

Issue Severity

Medium: It is a significant difficulty but I can work around it.
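
One possible workaround, as a minimal sketch (MyWorkaroundPolicy is illustrative, and it assumes input_dict is dict-like, so .get() returns None when the agent-index key is absent), is to treat the agent index as optional inside the custom policy:

from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.policy.sample_batch import SampleBatch


class MyWorkaroundPolicy(PPOTorchPolicy):
    def compute_actions_from_input_dict(
        self, input_dict, explore=None, timestep=None, **kwargs
    ):
        # The agent index may be missing (or left at -1) on the first timestep
        # after an environment reset, so only rely on it when it is valid.
        agent_indices = input_dict.get(SampleBatch.AGENT_INDEX)
        if agent_indices is not None and all(i > -1 for i in agent_indices):
            # Per-agent logic that needs the index would go here.
            ...
        return super().compute_actions_from_input_dict(
            input_dict, explore, timestep, **kwargs
        )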

Metadata

Labels

P1: Issue that should be fixed within a few weeks
bug: Something that is supposed to be working; but isn't
rllib: RLlib related issues
rllib-envrunners: Issues around the sampling backend of RLlib
