diff --git a/rllib/env/multi_agent_episode.py b/rllib/env/multi_agent_episode.py
index 50add0b4aaa..8679d87d318 100644
--- a/rllib/env/multi_agent_episode.py
+++ b/rllib/env/multi_agent_episode.py
@@ -324,7 +324,6 @@ def get_observations(
     # TODO (simon): Users might want to receive only actions that have a
     # corresponding 'next observation' (i.e. no buffered actions). Take care of this.
     # Also in the `extra_model_outputs`.
-
     def get_actions(
         self,
         indices: Union[int, List[int]] = -1,
@@ -408,6 +407,7 @@ def get_actions(

     def get_extra_model_outputs(
         self,
+        key: str,
         indices: Union[int, List[int]] = -1,
         global_ts: bool = True,
         as_list: bool = False,
@@ -418,6 +418,8 @@ def get_extra_model_outputs(
         during the given index range.

         Args:
+            key: A string determining the key in the extra model outputs
+                dictionary to return. This parameter is mandatory.
             indices: Either a single index or a list of indices. The indices
                 can be reversed (e.g. [-1, -2]) or absolute (e.g. [98, 99]).
                 This defines the time indices for which the actions
@@ -430,6 +432,7 @@ def get_extra_model_outputs(
                 timestep, actions are returned (i.e. not all agent ids are
                 necessarily in the keys).
         """
+        assert key, "ERROR: When requesting extra model outputs a `key` is needed."
         buffered_outputs = {}

         if global_ts:
@@ -462,8 +465,9 @@ def get_extra_model_outputs(
                     # Then the buffer must be full and needs to be accessed.
                     # Note, we do not want to empty the buffer, but only read it.
                     buffered_outputs[agent_id] = [
-                        self.agent_buffers[agent_id]["extra_model_outputs"].queue[0]
+                        self.agent_buffers[agent_id]["extra_model_outputs"].queue[0][key]
                     ]
+
                 else:
                     buffered_outputs[agent_id] = []

@@ -471,6 +475,7 @@ def get_extra_model_outputs(
         extra_model_outputs = self._getattr_by_index(
             "extra_model_outputs",
             indices=indices,
+            key=key,
             has_initial_value=True,
             global_ts=global_ts,
             global_ts_mapping=self.global_actions_t,
@@ -1554,9 +1559,10 @@ def _generate_single_agent_episode(
             # Convert `extra_model_outputs` for this agent from list of dicts to dict
             # of lists.
             agent_extra_model_outputs = defaultdict(list)
-            for _model_out in _agent_extra_model_outputs:
-                for key, val in _model_out.items():
-                    agent_extra_model_outputs[key].append(val)
+            if _agent_extra_model_outputs:
+                for _model_out in _agent_extra_model_outputs:
+                    for key, val in _model_out.items():
+                        agent_extra_model_outputs[key].append(val)

             agent_is_terminated = terminateds.get(agent_id, False)
             agent_is_truncated = truncateds.get(agent_id, False)
@@ -1614,7 +1620,9 @@ def _generate_single_agent_episode(
                 # TODO (simon): Check, if we need to use here also
                 # `ts_carriage_return`.
                 partial_agent_rewards_t.append(t + self.ts_carriage_return + 1)
-                if (t + 1) in self.global_t_to_local_t[agent_id][1:]:
+                if (
+                    t + self.ts_carriage_return + 1
+                ) in self.global_t_to_local_t[agent_id][1:]:
                     agent_rewards.append(agent_reward)
                     agent_reward = 0.0

@@ -1645,6 +1653,7 @@ def _getattr_by_index(
         self,
         attr: str = "observations",
         indices: Union[int, List[int]] = -1,
+        key: Optional[str] = None,
         has_initial_value=False,
         global_ts: bool = True,
         global_ts_mapping: Optional[MultiAgentDict] = None,
@@ -1687,14 +1696,12 @@ def _getattr_by_index(
         # If a list should be returned.
         if as_list:
             if buffered_values:
-                # Note, for addition we have to ensure that both elements are lists
-                # and terminated/truncated agents have numpy arrays.
                 return [
                     {
                         agent_id: (
-                            list(getattr(agent_eps, attr))
+                            getattr(agent_eps, attr).get(key)
                             + buffered_values[agent_id]
-                        )[global_ts_mapping[agent_id].find_indices([idx], shift)[0]]
+                        )[global_ts_mapping[agent_id].find_indices([idx])[0]]
                         for agent_id, agent_eps in self.agent_episodes.items()
                         if global_ts_mapping[agent_id].find_indices([idx], shift)
                     }
@@ -1722,7 +1729,7 @@ def _getattr_by_index(
                     agent_id: list(
                         map(
                             (
-                                list(getattr(agent_eps, attr))
+                                getattr(agent_eps, attr).get(key)
                                 + buffered_values[agent_id]
                             ).__getitem__,
                             global_ts_mapping[agent_id].find_indices(
@@ -1755,12 +1762,15 @@ def _getattr_by_index(
         if not isinstance(indices, list):
             indices = [indices]

+        # If we have buffered values for the attribute, we want to concatenate
+        # them while searching for the indices.
         if buffered_values:
             return {
                 agent_id: list(
                     map(
                         (
-                            getattr(agent_eps, attr) + buffered_values[agent_id]
+                            getattr(agent_eps, attr).get(key)
+                            + buffered_values[agent_id]
                         ).__getitem__,
                         set(indices).intersection(
                             set(
diff --git a/rllib/env/single_agent_env_runner.py b/rllib/env/single_agent_env_runner.py
index b2c3f137701..8cc20bb79ec 100644
--- a/rllib/env/single_agent_env_runner.py
+++ b/rllib/env/single_agent_env_runner.py
@@ -436,7 +436,7 @@ def _sample_episodes(
                     extra_model_outputs=extra_model_output,
                 )

-                done_episodes_to_return.append(episodes[i])
+                done_episodes_to_return.append(episodes[i].finalize())

                 # Also early-out if we reach the number of episodes within this
                 # for-loop.
diff --git a/rllib/env/tests/test_multi_agent_episode.py b/rllib/env/tests/test_multi_agent_episode.py
index a53f3ffa16e..a0bfac04e94 100644
--- a/rllib/env/tests/test_multi_agent_episode.py
+++ b/rllib/env/tests/test_multi_agent_episode.py
@@ -258,9 +258,7 @@ def test_init(self):
         self.assertEqual(len(episode.agent_episodes["agent_2"]), 1)
         self.assertEqual(len(episode.agent_episodes["agent_3"]), 1)
         self.assertEqual(len(episode.agent_episodes["agent_4"]), 1)
-        # Assert now that applying length on agent 5's episode raises an error.
-        with self.assertRaises(AssertionError):
-            len(episode.agent_episodes["agent_5"])
+        self.assertEqual(len(episode.agent_episodes["agent_5"]), 0)

         # TODO (simon): Also test the other structs inside the MAE for agent 5 and
         # the other agents.
@@ -790,7 +788,7 @@ def test_getters(self):
         # Test with initial observations only.
         episode_init_only = MultiAgentEpisode(agent_ids=agent_ids)
         episode_init_only.add_env_reset(
-            observation=observations[0],
+            observations=observations[0],
             infos=infos[0],
         )
         # Get the last observation for agents and assert that its correct.
@@ -846,94 +844,122 @@ def test_getters(self):
         self.assertEqual(last_actions["agent_3"][1], actions[-2]["agent_3"])

         # --- extra_model_outputs ---
-        last_extra_model_outputs = episode.get_extra_model_outputs()
-        self.assertDictEqual(
-            last_extra_model_outputs["agent_1"][0], extra_model_outputs[-1]["agent_1"]
+        last_extra_model_outputs = episode.get_extra_model_outputs("extra")
+        self.assertEqual(
+            last_extra_model_outputs["agent_1"][0],
+            extra_model_outputs[-1]["agent_1"]["extra"],
         )
-        self.assertDictEqual(
-            last_extra_model_outputs["agent_3"][0], extra_model_outputs[-1]["agent_3"]
+        self.assertEqual(
+            last_extra_model_outputs["agent_3"][0],
+            extra_model_outputs[-1]["agent_3"]["extra"],
         )
-        self.assertDictEqual(
-            last_extra_model_outputs["agent_4"][0], extra_model_outputs[-1]["agent_4"]
+        self.assertEqual(
+            last_extra_model_outputs["agent_4"][0],
+            extra_model_outputs[-1]["agent_4"]["extra"],
         )

         # Request the last two outputs.
-        last_extra_model_outputs = episode.get_extra_model_outputs(indices=[-1, -2])
-        self.assertDictEqual(
-            last_extra_model_outputs["agent_1"][0], extra_model_outputs[-1]["agent_1"]
+        last_extra_model_outputs = episode.get_extra_model_outputs(
+            "extra", indices=[-1, -2]
         )
-        self.assertDictEqual(
-            last_extra_model_outputs["agent_3"][0], extra_model_outputs[-1]["agent_3"]
+        self.assertEqual(
+            last_extra_model_outputs["agent_1"][0],
+            extra_model_outputs[-1]["agent_1"]["extra"],
         )
-        self.assertDictEqual(
-            last_extra_model_outputs["agent_4"][0], extra_model_outputs[-1]["agent_4"]
+        self.assertEqual(
+            last_extra_model_outputs["agent_3"][0],
+            extra_model_outputs[-1]["agent_3"]["extra"],
         )
-        self.assertDictEqual(
-            last_extra_model_outputs["agent_1"][1], extra_model_outputs[-2]["agent_1"]
+        self.assertEqual(
+            last_extra_model_outputs["agent_4"][0],
+            extra_model_outputs[-1]["agent_4"]["extra"],
         )
-        self.assertDictEqual(
-            last_extra_model_outputs["agent_2"][0], extra_model_outputs[-2]["agent_2"]
+        self.assertEqual(
+            last_extra_model_outputs["agent_1"][1],
+            extra_model_outputs[-2]["agent_1"]["extra"],
+        )
+        self.assertEqual(
+            last_extra_model_outputs["agent_2"][0],
+            extra_model_outputs[-2]["agent_2"]["extra"],
         )
-        self.assertDictEqual(
-            last_extra_model_outputs["agent_3"][1], extra_model_outputs[-2]["agent_3"]
+        self.assertEqual(
+            last_extra_model_outputs["agent_3"][1],
+            extra_model_outputs[-2]["agent_3"]["extra"],
         )

         # Now request lists.
-        last_extra_model_outputs = episode.get_extra_model_outputs(as_list=True)
-        self.assertDictEqual(
-            last_extra_model_outputs[0]["agent_1"], extra_model_outputs[-1]["agent_1"]
+        last_extra_model_outputs = episode.get_extra_model_outputs(
+            "extra", as_list=True
         )
-        self.assertDictEqual(
-            last_extra_model_outputs[0]["agent_3"], extra_model_outputs[-1]["agent_3"]
+        self.assertEqual(
+            last_extra_model_outputs[0]["agent_1"],
+            extra_model_outputs[-1]["agent_1"]["extra"],
         )
-        self.assertDictEqual(
-            last_extra_model_outputs[0]["agent_4"], extra_model_outputs[-1]["agent_4"]
+        self.assertEqual(
+            last_extra_model_outputs[0]["agent_3"],
+            extra_model_outputs[-1]["agent_3"]["extra"],
+        )
+        self.assertEqual(
+            last_extra_model_outputs[0]["agent_4"],
+            extra_model_outputs[-1]["agent_4"]["extra"],
         )

         # Request the last two extra model outputs and return as a list.
         last_extra_model_outputs = episode.get_extra_model_outputs(
-            [-1, -2], as_list=True
+            "extra", [-1, -2], as_list=True
         )
-        self.assertDictEqual(
-            last_extra_model_outputs[0]["agent_1"], extra_model_outputs[-1]["agent_1"]
+        self.assertEqual(
+            last_extra_model_outputs[0]["agent_1"],
+            extra_model_outputs[-1]["agent_1"]["extra"],
         )
-        self.assertDictEqual(
-            last_extra_model_outputs[0]["agent_3"], extra_model_outputs[-1]["agent_3"]
+        self.assertEqual(
+            last_extra_model_outputs[0]["agent_3"],
+            extra_model_outputs[-1]["agent_3"]["extra"],
         )
-        self.assertDictEqual(
-            last_extra_model_outputs[0]["agent_4"], extra_model_outputs[-1]["agent_4"]
+        self.assertEqual(
+            last_extra_model_outputs[0]["agent_4"],
+            extra_model_outputs[-1]["agent_4"]["extra"],
         )
-        self.assertDictEqual(
-            last_extra_model_outputs[1]["agent_1"], extra_model_outputs[-2]["agent_1"]
+        self.assertEqual(
+            last_extra_model_outputs[1]["agent_1"],
+            extra_model_outputs[-2]["agent_1"]["extra"],
         )
-        self.assertDictEqual(
-            last_extra_model_outputs[1]["agent_2"], extra_model_outputs[-2]["agent_2"]
+        self.assertEqual(
+            last_extra_model_outputs[1]["agent_2"],
+            extra_model_outputs[-2]["agent_2"]["extra"],
         )
-        self.assertDictEqual(
-            last_extra_model_outputs[1]["agent_3"], extra_model_outputs[-2]["agent_3"]
+        self.assertEqual(
+            last_extra_model_outputs[1]["agent_3"],
+            extra_model_outputs[-2]["agent_3"]["extra"],
         )

         # Now request the last extra model outputs at the local timesteps, i.e.
         # for each agent its last two actions.
         last_extra_model_outputs = episode.get_extra_model_outputs(
-            [-1, -2], global_ts=False
+            "extra", [-1, -2], global_ts=False
         )
-        self.assertDictEqual(
-            last_extra_model_outputs["agent_1"][0], extra_model_outputs[-1]["agent_1"]
+        self.assertEqual(
+            last_extra_model_outputs["agent_1"][0],
+            extra_model_outputs[-1]["agent_1"]["extra"],
         )
-        self.assertDictEqual(
-            last_extra_model_outputs["agent_3"][0], extra_model_outputs[-1]["agent_3"]
+        self.assertEqual(
+            last_extra_model_outputs["agent_3"][0],
+            extra_model_outputs[-1]["agent_3"]["extra"],
         )
-        self.assertDictEqual(
-            last_extra_model_outputs["agent_4"][0], extra_model_outputs[-1]["agent_4"]
+        self.assertEqual(
+            last_extra_model_outputs["agent_4"][0],
+            extra_model_outputs[-1]["agent_4"]["extra"],
         )
-        self.assertDictEqual(
-            last_extra_model_outputs["agent_1"][1], extra_model_outputs[-2]["agent_1"]
+        self.assertEqual(
+            last_extra_model_outputs["agent_1"][1],
+            extra_model_outputs[-2]["agent_1"]["extra"],
         )
-        self.assertDictEqual(
-            last_extra_model_outputs["agent_2"][0], extra_model_outputs[-2]["agent_2"]
+        self.assertEqual(
+            last_extra_model_outputs["agent_2"][0],
+            extra_model_outputs[-2]["agent_2"]["extra"],
         )
-        self.assertDictEqual(
-            last_extra_model_outputs["agent_3"][1], extra_model_outputs[-2]["agent_3"]
+        self.assertEqual(
+            last_extra_model_outputs["agent_3"][1],
+            extra_model_outputs[-2]["agent_3"]["extra"],
         )

         # TODO (simon): Not tested with `global_ts=False`.
@@ -1073,7 +1099,7 @@ def test_getters(self):
         self.assertListEqual(episode_1.global_t_to_local_t["agent_5"][-2:], [99, 100])
         # Agent 5 has already died, so we need to convert back to list.
         self.assertListEqual(
-            episode_1.agent_episodes["agent_5"].rewards.tolist()[-2:],
+            episode_1.agent_episodes["agent_5"].rewards[-2:],
             last_rewards["agent_5"],
         )
         self.assertIn("agent_2", last_rewards)
@@ -1299,7 +1325,7 @@ def test_getters(self):
         self.assertIn("agent_5", last_rewards)
         # Agent 5 already died, so we need to convert to list first.
         self.assertListEqual(
-            episode_1.agent_episodes["agent_5"].rewards.tolist()[-1:-3:-1],
+            episode_1.agent_episodes["agent_5"].rewards[-1:-3:-1],
             last_rewards["agent_5"],
         )
         self.assertIn("agent_8", last_rewards)
diff --git a/rllib/env/utils.py b/rllib/env/utils.py
index 30642462bbc..1bdd41e2c10 100644
--- a/rllib/env/utils.py
+++ b/rllib/env/utils.py
@@ -300,6 +300,33 @@ def get(

         return data

+    def __add__(self, other):
+        """Adds another buffer or list to the end of this one.
+
+        Args:
+            other: Either a `BufferWithInfiniteLookback` or a `list`.
+                If a `BufferWithInfiniteLookback`, its data gets
+                concatenated. If a `list`, the list is concatenated to
+                `self.data`.
+
+        Returns:
+            A new `BufferWithInfiniteLookback` instance containing the
+            concatenated data from `self` and `other`.
+        """
+
+        if self.finalized:
+            raise RuntimeError(f"Cannot `add` to a finalized {type(self).__name__}.")
+        else:
+            if isinstance(other, BufferWithInfiniteLookback):
+                data = self.data + other.data
+            else:
+                data = self.data + other
+            return BufferWithInfiniteLookback(
+                data=data,
+                lookback=self.lookback,
+                space=self.space,
+            )
+
     def __getitem__(self, item):
         """Support squared bracket syntax, e.g. buffer[:5]."""
         return self.get(item)
diff --git a/rllib/evaluation/postprocessing_v2.py b/rllib/evaluation/postprocessing_v2.py
index a50da0718e5..98a5058330f 100644
--- a/rllib/evaluation/postprocessing_v2.py
+++ b/rllib/evaluation/postprocessing_v2.py
@@ -72,8 +72,8 @@ def compute_gae_for_episode(
     # them now in the training_step.
     episode = compute_bootstrap_value(episode, module)

-    vf_preds = episode.extra_model_outputs[SampleBatch.VF_PREDS]
-    rewards = episode.rewards
+    vf_preds = episode.get_extra_model_outputs(SampleBatch.VF_PREDS)
+    rewards = episode.get_rewards()

     # TODO (simon): In case of recurrent models sequeeze out time dimension.

@@ -99,6 +99,8 @@ def compute_bootstrap_value(
         last_r = 0.0
     else:
         # TODO (simon): This has to be made multi-agent ready.
+        # TODO (sven, simon): We have to change this as soon as the
+        # Connector API is ready. Episodes do not have states anymore.
         initial_states = module.get_initial_state()
         state = {
             k: initial_states[k] if episode.states is None else episode.states[k]
@@ -163,9 +165,9 @@ def compute_advantages(
     last_r = convert_to_numpy(last_r)

     if rewards is None:
-        rewards = episode.rewards
+        rewards = episode.get_rewards()
     if vf_preds is None:
-        vf_preds = episode.extra_model_outs[SampleBatch.VF_PREDS]
+        vf_preds = episode.get_extra_model_outputs(SampleBatch.VF_PREDS)

     if use_gae:
         vpred_t = np.concatenate([vf_preds, np.array([last_r])])
@@ -197,6 +199,7 @@ def compute_advantages(
         episode.extra_model_outputs[Postprocessing.ADVANTAGES]
     )

+    # TODO (sven, simon): Maybe change to `BufferWithInfiniteLookback`.
     episode.extra_model_outputs[
         Postprocessing.ADVANTAGES
     ] = episode.extra_model_outputs[Postprocessing.ADVANTAGES].astype(np.float32)
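
Usage note (not part of the diff): the sketch below illustrates the new `BufferWithInfiniteLookback.__add__` from the rllib/env/utils.py hunk above. It assumes the class is importable as `ray.rllib.env.utils.BufferWithInfiniteLookback`, that the constructor's `lookback` and `space` arguments can be left at their defaults, and that a freshly built buffer is not finalized; variable names are illustrative only. The test hunks above show the corresponding new episode call pattern, e.g. `episode.get_extra_model_outputs("extra", indices=[-1, -2])` with the now-mandatory `key` argument.

    # Hypothetical sketch, assuming default `lookback`/`space` in the constructor.
    from ray.rllib.env.utils import BufferWithInfiniteLookback

    buf = BufferWithInfiniteLookback(data=[1, 2, 3])
    other = BufferWithInfiniteLookback(data=[4, 5])

    # Adding a list or another buffer returns a *new* buffer holding the
    # concatenated data; adding to a finalized buffer raises a RuntimeError.
    combined = buf + [4, 5]
    combined_buffers = buf + other

    # Square-bracket access goes through `get()`, per the `__getitem__` above.
    print(combined[:])  # expected: [1, 2, 3, 4, 5]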