From d352b1d70784b50ee16ef2fb28cc4ae8b1308844 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Fri, 22 Mar 2024 20:38:20 +0100 Subject: [PATCH] [RLlib] `MultiAgentEpisode` add `module_for` API. (#44241) --- rllib/env/multi_agent_episode.py | 36 +++++++++++++++++++++++++++---- rllib/env/single_agent_episode.py | 31 ++++++++++++++------------ 2 files changed, 49 insertions(+), 18 deletions(-) diff --git a/rllib/env/multi_agent_episode.py b/rllib/env/multi_agent_episode.py index 08541a36248632..648c6a0f223b90 100644 --- a/rllib/env/multi_agent_episode.py +++ b/rllib/env/multi_agent_episode.py @@ -177,6 +177,11 @@ def __init__( AlgorithmConfig.DEFAULT_AGENT_TO_MODULE_MAPPING_FN ) self.agent_to_module_mapping_fn = agent_to_module_mapping_fn + # In case a user - e.g. via callbacks - already forces a mapping to happen + # via the `module_for()` API even before the agent has entered the episode + # (and has its SingleAgentEpisode created), we store all aldeary done mappings + # in this dict here. + self._agent_to_module_mapping: Dict[AgentID, ModuleID] = {} # Lookback buffer length is not provided. Interpret all provided data as # lookback buffer. @@ -305,7 +310,7 @@ def add_env_reset( if agent_id not in self.agent_episodes: self.agent_episodes[agent_id] = SingleAgentEpisode( agent_id=agent_id, - module_id=self.agent_to_module_mapping_fn(agent_id, self), + module_id=self.module_for(agent_id), multi_agent_episode_id=self.id_, observation_space=self.observation_space.get(agent_id), action_space=self.action_space.get(agent_id), @@ -408,7 +413,7 @@ def add_env_step( if agent_id not in self.agent_episodes: self.agent_episodes[agent_id] = SingleAgentEpisode( agent_id=agent_id, - module_id=self.agent_to_module_mapping_fn(agent_id, self), + module_id=self.module_for(agent_id), multi_agent_episode_id=self.id_, observation_space=self.observation_space.get(agent_id), action_space=self.action_space.get(agent_id), @@ -663,7 +668,7 @@ def finalize( Note that Columns.INFOS are NEVER numpy'ized and will remain a list (normally, a list of the original, env-returned dicts). This is due to the - herterogenous nature of INFOS returned by envs, which would make it unwieldy to + heterogeneous nature of INFOS returned by envs, which would make it unwieldy to convert this information to numpy arrays. After calling this method, no further data may be added to this episode via @@ -865,6 +870,28 @@ def agent_episode_ids(self) -> MultiAgentDict: for agent_id, agent_eps in self.agent_episodes.items() } + def module_for(self, agent_id: AgentID) -> Optional[ModuleID]: + """Returns the ModuleID for a given AgentID. + + Forces the agent-to-module mapping to be performed (via + `self.agent_to_module_mapping_fn`), if this has not been done yet. + Note that all such mappings are stored in the `self._agent_to_module_mapping` + property. + + Args: + agent_id: The AgentID to get a mapped ModuleID for. + + Returns: + The ModuleID mapped to from the given `agent_id`. + """ + if agent_id not in self._agent_to_module_mapping: + module_id = self._agent_to_module_mapping[ + agent_id + ] = self.agent_to_module_mapping_fn(agent_id, self) + return module_id + else: + return self._agent_to_module_mapping[agent_id] + def get_observations( self, indices: Optional[Union[int, List[int], slice]] = None, @@ -1595,7 +1622,8 @@ def _init_single_agent_episodes( ) ) # Try to figure out the module ID for this agent. - # If not provided explicitly, try the mapping function (if provided). + # If not provided explicitly by the user that initializes this episode + # object, try our mapping function. module_id = agent_module_ids.get( agent_id, self.agent_to_module_mapping_fn(agent_id, self) ) diff --git a/rllib/env/single_agent_episode.py b/rllib/env/single_agent_episode.py index 6157b38d9635e9..cce1d0934c3709 100644 --- a/rllib/env/single_agent_episode.py +++ b/rllib/env/single_agent_episode.py @@ -756,7 +756,10 @@ def get_observations( .. testcode:: + import gymnasium as gym + from ray.rllib.env.single_agent_episode import SingleAgentEpisode + from ray.rllib.utils.test_utils import check episode = SingleAgentEpisode( # Discrete(4) observations (ints between 0 and 4 (excl.)) @@ -766,29 +769,29 @@ def get_observations( len_lookback_buffer=0, # no lookback; all data is actually "in" episode ) # Plain usage (`indices` arg only). - episode.get_observations(-1) # 3 - episode.get_observations(0) # 0 - episode.get_observations([0, 2]) # [0, 2] - episode.get_observations([-1, 0]) # [3, 0] - episode.get_observations(slice(None, 2)) # [0, 1] - episode.get_observations(slice(-2, None)) # [2, 3] + check(episode.get_observations(-1), 3) + check(episode.get_observations(0), 0) + check(episode.get_observations([0, 2]), [0, 2]) + check(episode.get_observations([-1, 0]), [3, 0]) + check(episode.get_observations(slice(None, 2)), [0, 1]) + check(episode.get_observations(slice(-2, None)), [2, 3]) # Using `fill=...` (requesting slices beyond the boundaries). - episode.get_observations(slice(-6, -2), fill=-9) # [-9, -9, 0, 1] - episode.get_observations(slice(2, 5), fill=-7) # [2, 3, -7] + check(episode.get_observations(slice(-6, -2), fill=-9), [-9, -9, 0, 1]) + check(episode.get_observations(slice(2, 5), fill=-7), [2, 3, -7]) # Using `one_hot_discrete=True`. - episode.get_observations(2, one_hot_discrete=True) # [0 0 1 0] - episode.get_observations(3, one_hot_discrete=True) # [0 0 0 1] - episode.get_observations( + check(episode.get_observations(2, one_hot_discrete=True), [0, 0, 1, 0]) + check(episode.get_observations(3, one_hot_discrete=True), [0, 0, 0, 1]) + check(episode.get_observations( slice(0, 3), one_hot_discrete=True, - ) # [[1 0 0 0], [0 1 0 0], [0 0 1 0]] + ), [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]) # Special case: Using `fill=0.0` AND `one_hot_discrete=True`. - episode.get_observations( + check(episode.get_observations( -1, neg_indices_left_of_zero=True, # -1 means one left of ts=0 fill=0.0, one_hot_discrete=True, - ) # [0 0 0 0] <- all 0s one-hot tensor (note difference to [1 0 0 0]!) + ), [0, 0, 0, 0]) # <- all 0s one-hot tensor (note difference to [1 0 0 0]!) Returns: The collected observations.