Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Fixed errors in 'MultiAgentEpisode' tests. Due to the use of 'BufferW… #41631

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 22 additions & 12 deletions rllib/env/multi_agent_episode.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,6 @@ def get_observations(
# TODO (simon): Users might want to receive only actions that have a
# corresponding 'next observation' (i.e. no buffered actions). Take care of this.
# Also in the `extra_model_outputs`.

def get_actions(
self,
indices: Union[int, List[int]] = -1,
Expand Down Expand Up @@ -408,6 +407,7 @@ def get_actions(

def get_extra_model_outputs(
self,
key: str,
indices: Union[int, List[int]] = -1,
global_ts: bool = True,
as_list: bool = False,
Expand All @@ -418,6 +418,8 @@ def get_extra_model_outputs(
during the given index range.

Args:
key: A string determining the key in the extra model outputs
dictionary to return. This parameter is mandatory.
indices: Either a single index or a list of indices. The indices
can be reversed (e.g. [-1, -2]) or absolute (e.g. [98, 99]).
This defines the time indices for which the actions
Expand All @@ -430,6 +432,7 @@ def get_extra_model_outputs(
timestep, actions are returned (i.e. not all agent ids are
necessarily in the keys).
"""
assert key, "ERROR: When requesting extra model outputs a `key` is needed."
buffered_outputs = {}

if global_ts:
Expand Down Expand Up @@ -462,15 +465,17 @@ def get_extra_model_outputs(
# Then the buffer must be full and needs to be accessed.
# Note, we do not want to empty the buffer, but only read it.
buffered_outputs[agent_id] = [
self.agent_buffers[agent_id]["extra_model_outputs"].queue[0]
self.agent_buffers[agent_id]["extra_model_outputs"].queue[0][key]
]

else:
buffered_outputs[agent_id] = []

# Now, get the actions.
extra_model_outputs = self._getattr_by_index(
"extra_model_outputs",
indices=indices,
key=key,
has_initial_value=True,
global_ts=global_ts,
global_ts_mapping=self.global_actions_t,
Expand Down Expand Up @@ -1554,9 +1559,10 @@ def _generate_single_agent_episode(
# Convert `extra_model_outputs` for this agent from list of dicts to dict
# of lists.
agent_extra_model_outputs = defaultdict(list)
for _model_out in _agent_extra_model_outputs:
for key, val in _model_out.items():
agent_extra_model_outputs[key].append(val)
if _agent_extra_model_outputs:
for _model_out in _agent_extra_model_outputs:
for key, val in _model_out.items():
agent_extra_model_outputs[key].append(val)

agent_is_terminated = terminateds.get(agent_id, False)
agent_is_truncated = truncateds.get(agent_id, False)
Expand Down Expand Up @@ -1614,7 +1620,9 @@ def _generate_single_agent_episode(
# TODO (simon): Check, if we need to use here also
# `ts_carriage_return`.
partial_agent_rewards_t.append(t + self.ts_carriage_return + 1)
if (t + 1) in self.global_t_to_local_t[agent_id][1:]:
if (
t + self.ts_carriage_return + 1
) in self.global_t_to_local_t[agent_id][1:]:
agent_rewards.append(agent_reward)
agent_reward = 0.0

Expand Down Expand Up @@ -1645,6 +1653,7 @@ def _getattr_by_index(
self,
attr: str = "observations",
indices: Union[int, List[int]] = -1,
key: Optional[str] = None,
has_initial_value=False,
global_ts: bool = True,
global_ts_mapping: Optional[MultiAgentDict] = None,
Expand Down Expand Up @@ -1687,14 +1696,12 @@ def _getattr_by_index(
# If a list should be returned.
if as_list:
if buffered_values:
# Note, for addition we have to ensure that both elements are lists
# and terminated/truncated agents have numpy arrays.
return [
{
agent_id: (
list(getattr(agent_eps, attr))
getattr(agent_eps, attr).get(key)
+ buffered_values[agent_id]
)[global_ts_mapping[agent_id].find_indices([idx], shift)[0]]
)[global_ts_mapping[agent_id].find_indices([idx])[0]]
for agent_id, agent_eps in self.agent_episodes.items()
if global_ts_mapping[agent_id].find_indices([idx], shift)
}
Expand Down Expand Up @@ -1722,7 +1729,7 @@ def _getattr_by_index(
agent_id: list(
map(
(
list(getattr(agent_eps, attr))
getattr(agent_eps, attr).get(key)
+ buffered_values[agent_id]
).__getitem__,
global_ts_mapping[agent_id].find_indices(
Expand Down Expand Up @@ -1755,12 +1762,15 @@ def _getattr_by_index(
if not isinstance(indices, list):
indices = [indices]

# If we have buffered values for the attribute we want to concatenate
# while searching for the indices.
if buffered_values:
return {
agent_id: list(
map(
(
getattr(agent_eps, attr) + buffered_values[agent_id]
getattr(agent_eps, attr).get(key)
+ buffered_values[agent_id]
).__getitem__,
set(indices).intersection(
set(
Expand Down
2 changes: 1 addition & 1 deletion rllib/env/single_agent_env_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def _sample_episodes(
extra_model_outputs=extra_model_output,
)

done_episodes_to_return.append(episodes[i])
done_episodes_to_return.append(episodes[i].finalize())

# Also early-out if we reach the number of episodes within this
# for-loop.
Expand Down
144 changes: 85 additions & 59 deletions rllib/env/tests/test_multi_agent_episode.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,7 @@ def test_init(self):
self.assertEqual(len(episode.agent_episodes["agent_2"]), 1)
self.assertEqual(len(episode.agent_episodes["agent_3"]), 1)
self.assertEqual(len(episode.agent_episodes["agent_4"]), 1)
# Assert now that applying length on agent 5's episode raises an error.
with self.assertRaises(AssertionError):
len(episode.agent_episodes["agent_5"])
self.assertEqual(len(episode.agent_episodes["agent_5"]), 0)

# TODO (simon): Also test the other structs inside the MAE for agent 5 and
# the other agents.
Expand Down Expand Up @@ -790,7 +788,7 @@ def test_getters(self):
# Test with initial observations only.
episode_init_only = MultiAgentEpisode(agent_ids=agent_ids)
episode_init_only.add_env_reset(
observation=observations[0],
observations=observations[0],
infos=infos[0],
)
# Get the last observation for agents and assert that its correct.
Expand Down Expand Up @@ -846,94 +844,122 @@ def test_getters(self):
self.assertEqual(last_actions["agent_3"][1], actions[-2]["agent_3"])

# --- extra_model_outputs ---
last_extra_model_outputs = episode.get_extra_model_outputs()
self.assertDictEqual(
last_extra_model_outputs["agent_1"][0], extra_model_outputs[-1]["agent_1"]
last_extra_model_outputs = episode.get_extra_model_outputs("extra")
self.assertEqual(
last_extra_model_outputs["agent_1"][0],
extra_model_outputs[-1]["agent_1"]["extra"],
)
self.assertDictEqual(
last_extra_model_outputs["agent_3"][0], extra_model_outputs[-1]["agent_3"]
self.assertEqual(
last_extra_model_outputs["agent_3"][0],
extra_model_outputs[-1]["agent_3"]["extra"],
)
self.assertDictEqual(
last_extra_model_outputs["agent_4"][0], extra_model_outputs[-1]["agent_4"]
self.assertEqual(
last_extra_model_outputs["agent_4"][0],
extra_model_outputs[-1]["agent_4"]["extra"],
)

# Request the last two outputs.
last_extra_model_outputs = episode.get_extra_model_outputs(indices=[-1, -2])
self.assertDictEqual(
last_extra_model_outputs["agent_1"][0], extra_model_outputs[-1]["agent_1"]
last_extra_model_outputs = episode.get_extra_model_outputs(
"extra", indices=[-1, -2]
)
self.assertDictEqual(
last_extra_model_outputs["agent_3"][0], extra_model_outputs[-1]["agent_3"]
self.assertEqual(
last_extra_model_outputs["agent_1"][0],
extra_model_outputs[-1]["agent_1"]["extra"],
)
self.assertDictEqual(
last_extra_model_outputs["agent_4"][0], extra_model_outputs[-1]["agent_4"]
self.assertEqual(
last_extra_model_outputs["agent_3"][0],
extra_model_outputs[-1]["agent_3"]["extra"],
)
self.assertDictEqual(
last_extra_model_outputs["agent_1"][1], extra_model_outputs[-2]["agent_1"]
self.assertEqual(
last_extra_model_outputs["agent_4"][0],
extra_model_outputs[-1]["agent_4"]["extra"],
)
self.assertDictEqual(
last_extra_model_outputs["agent_2"][0], extra_model_outputs[-2]["agent_2"]
self.assertEqual(
last_extra_model_outputs["agent_1"][1],
extra_model_outputs[-2]["agent_1"]["extra"],
)
self.assertEqual(
last_extra_model_outputs["agent_2"][0],
extra_model_outputs[-2]["agent_2"]["extra"],
)
self.assertDictEqual(
last_extra_model_outputs["agent_3"][1], extra_model_outputs[-2]["agent_3"]
self.assertEqual(
last_extra_model_outputs["agent_3"][1],
extra_model_outputs[-2]["agent_3"]["extra"],
)

# Now request lists.
last_extra_model_outputs = episode.get_extra_model_outputs(as_list=True)
self.assertDictEqual(
last_extra_model_outputs[0]["agent_1"], extra_model_outputs[-1]["agent_1"]
last_extra_model_outputs = episode.get_extra_model_outputs(
"extra", as_list=True
)
self.assertDictEqual(
last_extra_model_outputs[0]["agent_3"], extra_model_outputs[-1]["agent_3"]
self.assertEqual(
last_extra_model_outputs[0]["agent_1"],
extra_model_outputs[-1]["agent_1"]["extra"],
)
self.assertDictEqual(
last_extra_model_outputs[0]["agent_4"], extra_model_outputs[-1]["agent_4"]
self.assertEqual(
last_extra_model_outputs[0]["agent_3"],
extra_model_outputs[-1]["agent_3"]["extra"],
)
self.assertEqual(
last_extra_model_outputs[0]["agent_4"],
extra_model_outputs[-1]["agent_4"]["extra"],
)
# Request the last two extra model outputs and return as a list.
last_extra_model_outputs = episode.get_extra_model_outputs(
[-1, -2], as_list=True
"extra", [-1, -2], as_list=True
)
self.assertDictEqual(
last_extra_model_outputs[0]["agent_1"], extra_model_outputs[-1]["agent_1"]
self.assertEqual(
last_extra_model_outputs[0]["agent_1"],
extra_model_outputs[-1]["agent_1"]["extra"],
)
self.assertDictEqual(
last_extra_model_outputs[0]["agent_3"], extra_model_outputs[-1]["agent_3"]
self.assertEqual(
last_extra_model_outputs[0]["agent_3"],
extra_model_outputs[-1]["agent_3"]["extra"],
)
self.assertDictEqual(
last_extra_model_outputs[0]["agent_4"], extra_model_outputs[-1]["agent_4"]
self.assertEqual(
last_extra_model_outputs[0]["agent_4"],
extra_model_outputs[-1]["agent_4"]["extra"],
)
self.assertDictEqual(
last_extra_model_outputs[1]["agent_1"], extra_model_outputs[-2]["agent_1"]
self.assertEqual(
last_extra_model_outputs[1]["agent_1"],
extra_model_outputs[-2]["agent_1"]["extra"],
)
self.assertDictEqual(
last_extra_model_outputs[1]["agent_2"], extra_model_outputs[-2]["agent_2"]
self.assertEqual(
last_extra_model_outputs[1]["agent_2"],
extra_model_outputs[-2]["agent_2"]["extra"],
)
self.assertDictEqual(
last_extra_model_outputs[1]["agent_3"], extra_model_outputs[-2]["agent_3"]
self.assertEqual(
last_extra_model_outputs[1]["agent_3"],
extra_model_outputs[-2]["agent_3"]["extra"],
)

# Now request the last extra model outputs at the local timesteps, i.e.
# for each agent its last two actions.
last_extra_model_outputs = episode.get_extra_model_outputs(
[-1, -2], global_ts=False
"extra", [-1, -2], global_ts=False
)
self.assertDictEqual(
last_extra_model_outputs["agent_1"][0], extra_model_outputs[-1]["agent_1"]
self.assertEqual(
last_extra_model_outputs["agent_1"][0],
extra_model_outputs[-1]["agent_1"]["extra"],
)
self.assertDictEqual(
last_extra_model_outputs["agent_3"][0], extra_model_outputs[-1]["agent_3"]
self.assertEqual(
last_extra_model_outputs["agent_3"][0],
extra_model_outputs[-1]["agent_3"]["extra"],
)
self.assertDictEqual(
last_extra_model_outputs["agent_4"][0], extra_model_outputs[-1]["agent_4"]
self.assertEqual(
last_extra_model_outputs["agent_4"][0],
extra_model_outputs[-1]["agent_4"]["extra"],
)
self.assertDictEqual(
last_extra_model_outputs["agent_1"][1], extra_model_outputs[-2]["agent_1"]
self.assertEqual(
last_extra_model_outputs["agent_1"][1],
extra_model_outputs[-2]["agent_1"]["extra"],
)
self.assertDictEqual(
last_extra_model_outputs["agent_2"][0], extra_model_outputs[-2]["agent_2"]
self.assertEqual(
last_extra_model_outputs["agent_2"][0],
extra_model_outputs[-2]["agent_2"]["extra"],
)
self.assertDictEqual(
last_extra_model_outputs["agent_3"][1], extra_model_outputs[-2]["agent_3"]
self.assertEqual(
last_extra_model_outputs["agent_3"][1],
extra_model_outputs[-2]["agent_3"]["extra"],
)

# TODO (simon): Not tested with `global_ts=False`.
Expand Down Expand Up @@ -1073,7 +1099,7 @@ def test_getters(self):
self.assertListEqual(episode_1.global_t_to_local_t["agent_5"][-2:], [99, 100])
# Agent 5 has already died, so we need to convert back to list.
self.assertListEqual(
episode_1.agent_episodes["agent_5"].rewards.tolist()[-2:],
episode_1.agent_episodes["agent_5"].rewards[-2:],
last_rewards["agent_5"],
)
self.assertIn("agent_2", last_rewards)
Expand Down Expand Up @@ -1299,7 +1325,7 @@ def test_getters(self):
self.assertIn("agent_5", last_rewards)
# Agent 5 already died, so we need to convert to list first.
self.assertListEqual(
episode_1.agent_episodes["agent_5"].rewards.tolist()[-1:-3:-1],
episode_1.agent_episodes["agent_5"].rewards[-1:-3:-1],
last_rewards["agent_5"],
)
self.assertIn("agent_8", last_rewards)
Expand Down
27 changes: 27 additions & 0 deletions rllib/env/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,33 @@ def get(

return data

def __add__(self, other):
    """Adds another buffer or list to the end of this one.

    Note, this does not mutate `self`; it builds and returns a new buffer.

    Args:
        other: Either `BufferWithInfiniteLookback` or `list`.
            If a `BufferWithInfiniteLookback`, its underlying data gets
            concatenated. If a `list`, the list is concatenated to
            `self.data`.

    Returns:
        A new `BufferWithInfiniteLookback` instance containing the
        concatenated data from `self` and `other` (same `lookback` and
        `space` as `self`).

    Raises:
        RuntimeError: If this buffer has already been finalized (its data
            was converted to an immutable/numpy form and can no longer be
            extended).
    """
    if self.finalized:
        raise RuntimeError(f"Cannot `add` to a finalized {type(self).__name__}.")

    # Concatenate raw data; unwrap the other buffer's data if needed.
    if isinstance(other, BufferWithInfiniteLookback):
        data = self.data + other.data
    else:
        data = self.data + other
    return BufferWithInfiniteLookback(
        data=data,
        lookback=self.lookback,
        space=self.space,
    )

def __getitem__(self, item):
    """Make the buffer indexable and sliceable via ``[]``.

    ``buffer[i]`` and ``buffer[a:b]`` are equivalent to calling
    ``buffer.get(i)`` and ``buffer.get(slice(a, b))``, respectively.

    Args:
        item: An int index or a slice object to look up.

    Returns:
        Whatever ``self.get`` returns for the given index or slice.
    """
    return self.get(item)
Expand Down