Skip to content

Commit

Permalink
[RLlib] Fix Single/MultiAgentEnvRunner missing env-to-module connecto…
Browse files Browse the repository at this point in the history
…r call in `_sample_episodes()`. (#45517)
  • Loading branch information
sven1977 committed May 23, 2024
1 parent 2bd35d7 commit 4ebbcbf
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 22 deletions.
22 changes: 17 additions & 5 deletions rllib/algorithms/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,12 +384,24 @@ def on_episode_end(
The exact time of the call of this callback is after `env.step([action])` and
also after the results of this step (observation, reward, terminated, truncated,
infos) have been logged to the given `episode` object, where either terminated
or truncated were True.
or truncated were True:
Note that on the new API stack, this callback is always preceeded by an
`on_episode_step` call, which comes before the call to this method, but is
provided with the non-finalized episode object (meaning the data has NOT
been converted to numpy arrays yet).
- The env is stepped: `final_obs, rewards, ... = env.step([action])`
- The step results are logged `episode.add_env_step(final_obs, rewards)`
- Callback `on_episode_step` is fired.
- Another env-to-module connector call is made (even though we won't need any
RLModule forward pass anymore). We make this additional call to ensure that in
case users use the connector pipeline to process observations (and write them
back into the episode), the episode object has all observations - even the
terminal one - properly processed.
- ---> This callback `on_episode_end()` is fired. <---
- The episode is finalized (i.e. lists of obs/rewards/actions/etc.. are
converted into numpy arrays).
Args:
episode: The terminated/truncated SingleAgent- or MultiAgentEpisode object
Expand Down
48 changes: 34 additions & 14 deletions rllib/env/multi_agent_env_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,10 @@ def _sample_timesteps(
extra_model_outputs=extra_model_outputs,
)

# Make the `on_episode_step` callback (before finalizing the episode
# object).
self._make_on_episode_callback("on_episode_step")

# Episode is done for all agents. Wrap up the old one and create a new
# one (and reset it) to continue.
if self._episode.is_done:
Expand All @@ -329,17 +333,20 @@ def _sample_timesteps(
# a call and in case the structure of the observations change
# sufficiently, the following `finalize()` call on the episode will
# fail.
self._env_to_module(
episodes=[self._episode],
explore=explore,
rl_module=self.module,
shared_data=self._shared_data,
)
if self.module is not None:
self._env_to_module(
episodes=[self._episode],
explore=explore,
rl_module=self.module,
shared_data=self._shared_data,
)

# Make the `on_episode_step` and `on_episode_end` callbacks (before
# finalizing the episode object).
self._make_on_episode_callback("on_episode_step")
# Make the `on_episode_end` callback (before finalizing the episode,
# but after(!) the last env-to-module connector call has been made.
# -> All obs (even the terminal one) should have been processed now (by
# the connector, if applicable).
self._make_on_episode_callback("on_episode_end")

# Finalize (numpy'ize) the episode.
self._episode.finalize(drop_zero_len_single_agent_episodes=True)
done_episodes_to_return.append(self._episode)
Expand All @@ -356,10 +363,6 @@ def _sample_timesteps(
# Make the `on_episode_start` callback.
self._make_on_episode_callback("on_episode_start")

else:
# Make the `on_episode_step` callback.
self._make_on_episode_callback("on_episode_step")

# Already perform env-to-module connector call for next call to
# `_sample_timesteps()`. See comment in c'tor for `self._cached_to_module`.
if self.module is not None:
Expand Down Expand Up @@ -531,7 +534,24 @@ def _sample_episodes(
# Increase episode count.
eps += 1

# Make `on_episode_end` callback before finalizing the episode.
# We have to perform an extra env-to-module pass here, just in case
# the user's connector pipeline performs (permanent) transforms
# on each observation (including this final one here). Without such
# a call and in case the structure of the observations change
# sufficiently, the following `finalize()` call on the episode will
# fail.
if self.module is not None:
self._env_to_module(
episodes=[_episode],
explore=explore,
rl_module=self.module,
shared_data=_shared_data,
)

# Make the `on_episode_end` callback (before finalizing the episode,
# but after(!) the last env-to-module connector call has been made.
# -> All obs (even the terminal one) should have been processed now (by
# the connector, if applicable).
self._make_on_episode_callback("on_episode_end")

# Finish the episode.
Expand Down
27 changes: 24 additions & 3 deletions rllib/env/single_agent_env_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,10 @@ def _sample_timesteps(
truncated=truncateds[env_index],
extra_model_outputs=extra_model_output,
)
# Make the `on_episode_step` and `on_episode_end` callbacks (before
# finalizing the episode object).
self._make_on_episode_callback("on_episode_step", env_index)

# We have to perform an extra env-to-module pass here, just in case
# the user's connector pipeline performs (permanent) transforms
# on each observation (including this final one here). Without such
Expand All @@ -347,9 +351,7 @@ def _sample_timesteps(
rl_module=self.module,
shared_data=self._shared_data,
)
# Make the `on_episode_step` and `on_episode_end` callbacks (before
# finalizing the episode object).
self._make_on_episode_callback("on_episode_step", env_index)

self._make_on_episode_callback("on_episode_end", env_index)

# Then finalize (numpy'ize) the episode.
Expand Down Expand Up @@ -524,6 +526,25 @@ def _sample_episodes(
self._make_on_episode_callback(
"on_episode_step", env_index, episodes
)

# We have to perform an extra env-to-module pass here, just in case
# the user's connector pipeline performs (permanent) transforms
# on each observation (including this final one here). Without such
# a call and in case the structure of the observations change
# sufficiently, the following `finalize()` call on the episode will
# fail.
if self.module is not None:
self._env_to_module(
episodes=[episodes[env_index]],
explore=explore,
rl_module=self.module,
shared_data=_shared_data,
)

# Make the `on_episode_end` callback (before finalizing the episode,
# but after(!) the last env-to-module connector call has been made.
# -> All obs (even the terminal one) should have been processed now
# (by the connector, if applicable).
self._make_on_episode_callback(
"on_episode_end", env_index, episodes
)
Expand Down

0 comments on commit 4ebbcbf

Please sign in to comment.