
[RLlib] agent index is missing from policy input dict on environment reset #37521

Closed
cassidylaidlaw opened this issue Jul 18, 2023 · 2 comments
Labels: bug, P1, rllib, rllib-samplingbackend

@cassidylaidlaw (Contributor)

What happened + What you expected to happen

If I write a custom policy class that overrides compute_actions_from_input_dict(input_dict, ...), then in older versions of RLlib this input_dict always had SampleBatch.AGENT_INDEX set correctly. With the latest version, it is set correctly except on the first timestep after an environment reset. The reason is that EnvRunnerV2.__process_resetted_obs_for_eval does not add an agent_index key to the input dict, unlike EnvRunnerV2._process_observations, which does. The fix should be very simple: just change

    acd_list: List[AgentConnectorDataType] = [
        AgentConnectorDataType(
            env_id,
            agent_id,
            {
                SampleBatch.NEXT_OBS: obs,
                SampleBatch.INFOS: infos,
                SampleBatch.T: episode.length,
            },
        )
        for agent_id, obs in agents_obs
    ]

to

    acd_list: List[AgentConnectorDataType] = [
        AgentConnectorDataType(
            env_id,
            agent_id,
            {
                SampleBatch.NEXT_OBS: obs,
                SampleBatch.INFOS: infos,
                SampleBatch.T: episode.length,
                SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
            },
        )
        for agent_id, obs in agents_obs
    ]

in EnvRunnerV2.__process_resetted_obs_for_eval.

I would submit a pull request myself, but in the past my PRs have been ignored.
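
In the meantime, a user-side workaround is possible. Here is a minimal sketch, assuming a single-agent environment (so the only valid agent index is 0); the PatchedPolicy name and the fallback logic are illustrative, not part of any RLlib API:

import numpy as np

from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.policy.sample_batch import SampleBatch


class PatchedPolicy(PPOTorchPolicy):
    def compute_actions_from_input_dict(
        self, input_dict, explore=None, timestep=None, **kwargs
    ):
        # On the first timestep after a reset, AGENT_INDEX may be missing
        # or left at a negative placeholder; fall back to index 0, which
        # is only valid for single-agent environments.
        agent_index = input_dict.get(SampleBatch.AGENT_INDEX)
        if agent_index is None or np.any(np.asarray(agent_index) < 0):
            input_dict[SampleBatch.AGENT_INDEX] = np.zeros(
                len(input_dict[SampleBatch.OBS]), dtype=np.int64
            )
        return super().compute_actions_from_input_dict(
            input_dict, explore, timestep, **kwargs
        )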

Versions / Dependencies

Ray: 2.5.1
Python: 3.9.16
OS: Ubuntu 22.04.2 LTS

Reproduction script

The following script raises an assertion error with Ray 2.5.1, but passes if I apply the EnvRunnerV2 change described above:

from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.policy.sample_batch import SampleBatch


class MyCustomPolicy(PPOTorchPolicy):
    def compute_actions_from_input_dict(
        self, input_dict, explore=None, timestep=None, **kwargs
    ):
        # Every row should carry a valid (non-negative) agent index; this
        # check fails on the first timestep after an environment reset.
        assert all(
            agent_index > -1 for agent_index in input_dict[SampleBatch.AGENT_INDEX]
        )
        return super().compute_actions_from_input_dict(
            input_dict, explore, timestep, **kwargs
        )


class MyCustomAlgorithm(PPO):
    @classmethod
    def get_default_policy_class(cls, config):
        return MyCustomPolicy


algo = MyCustomAlgorithm(
    PPOConfig()
    .rollouts(num_rollout_workers=1)
    .resources(num_gpus=0)
    .environment(env="CartPole-v1")
)

algo.train()
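
For reference, the assertion trips during the first algo.train() call, right after the initial environment reset; with the one-line EnvRunnerV2 change applied, it passes on every timestep.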

Issue Severity

Medium: It is a significant difficulty but I can work around it.

@cassidylaidlaw added the bug and triage labels on Jul 18, 2023
@ArturNiederfahrenhorst (Contributor)

@cassidylaidlaw Thank you very much for opening this issue and providing a concise and actionable report.
I've opened a PR.

@ArturNiederfahrenhorst added the P1 and rllib labels and removed the triage label on Jul 19, 2023
@ArturNiederfahrenhorst added the rllib-samplingbackend label on Jul 28, 2023
@ArturNiederfahrenhorst (Contributor)

@cassidylaidlaw Thanks again and please reopen this if there is an issue with the current solution.
