[BUG] Vectorized Environment Autoreset Incompatible with openai/baselines' API #194
Comments
I noticed this API mismatch when reading the autoreset documentation yesterday. |
Hi @pseudo-rnd-thoughts, could you elaborate on why you think the envpool implementation is better? I believe quite strongly that it will introduce off-by-one errors and overhead in almost all deep RL frameworks. |
I agree with @edbeeching. I'm actually trying to put together a PR to include the final observation on reset; I'm just trying to track down where the autoreset actually happens.
Personally, I find it quite unnatural to have the final observation in `info` rather than in `obs`. Currently, the vector implementations in EnvPool and Gym differ, so we should probably converge on a single API. @edbeeching or @DavidSlayback, what are the advantages of the Gym version with the last observation in `info`? |
Hey @edbeeching, @DavidSlayback and @pseudo-rnd-thoughts, can I play devil's advocate here? The current envpool API might be better than the openai/baselines design (I still need to think about the details). The main problem with putting the final observation in `info` is that bootstrapping on truncation then needs code like this (from Stable-Baselines3):

```python
# Excerpt from Stable-Baselines3's on-policy rollout collection:
# `self`, `dones`, `infos` and `rewards` come from the surrounding method; `th` is torch.
# Handle timeout by bootstrapping with value function
# see GitHub issue #633
for idx, done_ in enumerate(dones):
    if (
        done_
        and infos[idx].get("terminal_observation") is not None
        and infos[idx].get("TimeLimit.truncated", False)
    ):
        # One extra forward pass per truncated env to get V(final observation)
        terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
        with th.no_grad():
            terminal_value = self.policy.predict_values(terminal_obs)[0]
        rewards[idx] += self.gamma * terminal_value
```

However, such an implementation might not even be possible for JAX-based implementations (including those using the XLA interface): we cannot really use Python control flow like the `if` above inside jitted code, so we would have to do something like `jnp.where(info["truncated"], rewards[idx] + self.gamma * terminal_value, rewards[idx])`, which is incredibly inefficient because we are doing an extra forward pass per step. An alternative is to use the current envpool design: we can save all of the values computed during the rollout and then do `rewards[truncations] = jnp.where(truncations, rewards + values, rewards)`. Some indexing might be off in this example, but the spirit is to reuse the stored values, which were already calculated for both truncated and normal states. Regarding the off-by-one errors, we can simply mask out the truncated observation somehow. Maybe it's still worth having the |
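A minimal NumPy sketch of the "reuse the stored values" idea, assuming envpool-style autoreset (the step that reports `done=True` returns the final observation, so its value is produced by the regular forward pass on the next iteration and lands in `values[t + 1]`) and a per-step `truncated` mask. The function name and array layout are hypothetical, not envpool's API; `jax.numpy.where` has the same semantics as `np.where` here, so the same line should also work in jitted JAX code.

```python
import numpy as np


def bootstrap_truncations(rewards, truncated, values, gamma):
    """Add gamma * V(final obs) to the reward of every truncated step.

    rewards, truncated: shape (T, N), one row per step, one column per env.
    values: shape (T + 1, N); values[t + 1] is the value of the observation
    returned by step t, which (envpool-style) is the final observation
    whenever step t ended an episode. No extra forward pass is needed.
    """
    # Branch-free selection instead of a Python `if` per env.
    return np.where(truncated, rewards + gamma * values[1:], rewards)


rewards = np.ones((3, 2))
truncated = np.array([[False, False], [True, False], [False, True]])
values = np.array([[10.0, 1.0], [20.0, 2.0], [30.0, 3.0], [40.0, 4.0]])
out = bootstrap_truncations(rewards, truncated, values, gamma=0.5)
print(out)
```

Only the truncated entries change: env 0 at step 1 gets `1 + 0.5 * 30 = 16` and env 1 at step 2 gets `1 + 0.5 * 4 = 3`. The dummy reset step that follows each `done=True` still has to be masked out of the loss, which is the off-by-one concern raised in this thread.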
I agree with @vwxyzjn and think that the EnvPool API > Gym Vector API |
Hmm... to be honest, I'm mostly defending the final observation for correctness, but I'm also not sure how much that one state matters, let alone how often people actually use it. Basically:
Considering the possibility of JAX/XLA implementations with limited control flow, the inefficiency might be too much. What if @vwxyzjn Where do the values in |
What do you mean? The final observation is present in EnvPool's current API design.
I might be understanding this wrong, but maybe it's the same as my proposal (see draft below).
Here is a preliminary draft: https://www.diffchecker.com/U5VW3TTH (which does not work but demonstrates my point)
@araffin, have you had situations where correctly handling time-limit truncation made a difference in performance? |
What I find so unusual about the envpool autoreset API is that the environment is not actually reset automatically: you have to call an additional step() with dummy actions for certain indices in order to actually reset those environments. If this happens often, say with a large number of parallel environments, it adds overhead.
Perhaps I am missing something, so please feel free to rewrite the above snippet based on the current envpool API. For me, all of this stems from the truncated vs. terminated issue and from algorithms that require `next_obs` to perform a value estimate when an episode is truncated. I consider truncated episodes to be the exception, not the rule. While it is good to account for truncated episodes, the envpool API is not the right way to do this (in my opinion, of course). |
Just run SAC/PPO on PyBullet with and without the time feature or timeout handling, and you will see the difference ;). Also related: hill-a/stable-baselines#120 (comment) |
@DavidSlayback I had a chance to further look into this and prototyped a snippet https://www.diffchecker.com/U5VW3TTH. You can run it with
|
Before Gymnasium makes any changes to the vector environment to use EnvPool's API rather than the Gym API |
I did a series of benchmarks and found no significant performance difference when properly handling truncation. So it appears to me the current envpool design is actually quite good because everything has the same shapes. https://github.com/vwxyzjn/ppo-atari-metrics#generated-results |
In short, autoreset in gym:
In envpool:
The autoreset solution in envpool results in two controversial transitions when someone uses the env to collect transitions.
These two transitions, especially the second, are not logically valid (especially when modeling state transitions in model-based RL) and should not be included in the replay buffer. Any suggestions for addressing these issues when a Collector follows the gym.vec.env autoreset protocol? |
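One of the problematic transitions under envpool-style autoreset is the reset step itself: the step call right after `done=True` ignores its action and only jumps from the final observation to the first observation of the next episode. A minimal NumPy sketch of dropping it before anything reaches a replay buffer, assuming that behavior; `filter_reset_transitions` and the `(T, N)` layout are hypothetical.

```python
import numpy as np


def filter_reset_transitions(obs, actions, rewards, next_obs, dones, prev_done):
    """Drop transitions produced by envpool-style reset steps.

    All arrays have shape (T, N, ...): one row per step, one column per env.
    prev_done (shape (N,)) is the done flag of the step before the first row.
    Step t is a reset step when the env reported done at step t - 1: its
    action was ignored and obs[t] -> next_obs[t] crosses an episode boundary.
    """
    dones = np.asarray(dones, dtype=bool)
    prev_done = np.asarray(prev_done, dtype=bool)
    # Done flag of the previous step, aligned with each row.
    started_from_done = np.concatenate([prev_done[None], dones[:-1]], axis=0)
    keep = ~started_from_done  # boolean mask of shape (T, N)
    # Boolean-mask indexing flattens the (T, N) axes into one batch axis.
    return obs[keep], actions[keep], rewards[keep], next_obs[keep], dones[keep]


# One env whose episode ends after 3 steps (final obs 3), then one reset step.
obs = np.array([[0], [1], [2], [3], [0]])
next_obs = np.array([[1], [2], [3], [0], [1]])
actions = np.zeros((5, 1), dtype=np.int64)
rewards = np.array([[1.0], [1.0], [1.0], [0.0], [1.0]])
dones = np.array([[False], [False], [True], [False], [False]])
kept = filter_reset_transitions(obs, actions, rewards, next_obs, dones, prev_done=np.array([False]))
print(kept[0], kept[3])
```

The transition `3 -> 0` is removed, while the terminal transition `2 -> 3` (with `done=True`) is kept, so the buffer only contains real environment dynamics.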
Coming back to this after quite a long time: can you confirm this API mismatch still exists? I think the Gym API is better than the EnvPool API, and I realized this while playing with the PPO+LSTM implementation from CleanRL. You can see in the following code how, when done=True, the LSTM expects to combine the hidden features of the first observation of the new episode with a freshly initialized LSTM state full of zeros. Under the EnvPool API, this code would instead combine the new LSTM state with the (old) observation from the previous episode, introducing a bug:
So I think this is the natural way to think about observations when done=True, and it leads to clearer implementations. |
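A toy sketch of that masking pattern, with a scalar recurrence standing in for the LSTM: the state is multiplied by `(1 - flag)` before the current observation is consumed, the same idea CleanRL's PPO+LSTM applies to its LSTM state. The functions and numbers are hypothetical. With Gym-style flags the state is zeroed exactly at the first observation of the new episode; with envpool-style flags used directly it is zeroed at the final observation instead, and shifting the flags by one step recovers the Gym behavior.

```python
import numpy as np


def rollout_hidden(obs, reset_flags, decay=0.5):
    """Toy recurrent state: h_t = decay * h_{t-1} * (1 - reset_t) + obs_t.

    reset_flags[t] should be 1 exactly when obs[t] is the first observation
    of a new episode, so the state is cleared before consuming it.
    """
    h = 0.0
    hidden = []
    for x, r in zip(obs, reset_flags):
        h = decay * h * (1.0 - r) + x
        hidden.append(h)
    return np.array(hidden)


def episode_starts_from_envpool_dones(dones, prev_done=0.0):
    """Envpool-style: the first obs of a new episode arrives one step after
    done=True, so shift the flags right by one."""
    return np.concatenate([[prev_done], dones[:-1]])


# Gym-style stream: done=1 arrives together with the new episode's first obs (5).
gym_obs = np.array([1.0, 1.0, 1.0, 5.0, 5.0])
gym_dones = np.array([0.0, 0.0, 0.0, 1.0, 0.0])
# Envpool-style stream: final obs (9) arrives with done=1, the reset obs one step later.
envpool_obs = np.array([1.0, 1.0, 1.0, 9.0, 5.0, 5.0])
envpool_dones = np.array([0.0, 0.0, 0.0, 1.0, 0.0, 0.0])

gym_h = rollout_hidden(gym_obs, gym_dones)
naive_h = rollout_hidden(envpool_obs, envpool_dones)  # new episode sees leftover state
fixed_h = rollout_hidden(envpool_obs, episode_starts_from_envpool_dones(envpool_dones))
print(gym_h, naive_h, fixed_h)
```

In the naive case the new episode's states (9.5, 9.75) still carry the old episode's final observation, while the shifted flags reproduce the Gym-style values (5, 7.5).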
Describe the bug
Related to #33.
When an environment is "done", the autoreset feature in openai/gym's API resets this environment and returns the initial observation of the next episode. Here is a simple demonstration of how it works with `gym==0.23.1`:

Note that `done=True` and `obs=0` are returned in this example, and the truncated observation is put in the `info` dict. However, envpool does not implement this behavior and only returns the initial observation of the next episode after an additional step. See reproduction below.
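The difference can be illustrated without either library. A minimal pure-Python sketch, assuming a hypothetical `CountingEnv` whose observation is the step index within the episode and two hypothetical wrappers that mimic the two autoreset conventions described in this issue (the `terminal_observation` key and the 0.0 reward on the reset-only step are assumptions of the sketch):

```python
class CountingEnv:
    """Hypothetical single env: the observation is the number of steps taken
    in the current episode, and every episode ends after `length` steps."""

    def __init__(self, length=3):
        self.length = length
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= self.length, {}


class GymStyleAutoreset:
    """Gym-style: the step that ends an episode already returns the first
    observation of the next one; the final observation goes into info."""

    def __init__(self, env):
        self.env = env

    def reset(self):
        return self.env.reset()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        if done:
            info["terminal_observation"] = obs  # key name is an assumption
            obs = self.env.reset()
        return obs, reward, done, info


class EnvpoolStyleAutoreset:
    """Envpool-style (as described in this issue): the final observation comes
    with done=True; the next step call ignores its action and resets."""

    def __init__(self, env):
        self.env = env
        self.needs_reset = False

    def reset(self):
        self.needs_reset = False
        return self.env.reset()

    def step(self, action):
        if self.needs_reset:
            # Reward 0.0 on this reset-only step is an assumption of the sketch.
            return self.reset(), 0.0, False, {}
        obs, reward, done, info = self.env.step(action)
        self.needs_reset = done
        return obs, reward, done, info


def trace(wrapper, n_steps=5):
    """Return (obs, done) for each step after an initial reset."""
    wrapper.reset()
    return [wrapper.step(0)[::2] for _ in range(n_steps)]


print(trace(GymStyleAutoreset(CountingEnv())))
# prints [(1, False), (2, False), (0, True), (1, False), (2, False)]
print(trace(EnvpoolStyleAutoreset(CountingEnv())))
# prints [(1, False), (2, False), (3, True), (0, False), (1, False)]
```

The first observation of the second episode (0) shows up at step 3 with the Gym-style wrapper, together with `done=True`, and only at step 4 with the envpool-style one.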
To Reproduce
With `stable_baselines3==1.2.0`:

It can be observed from the picture that envpool returns the initial observation of the new episode at step 4, whereas gym's vec env returns it at step 3, the same step at which `done=True` happens.

Expected behavior
In the screenshot above, envpool should return the initial observation of the new episode at step 3.
This is highly relevant to return calculation, as it causes an off-by-one error. Consider the following return calculation:
which calculates the returns for two trajectories correctly.
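A minimal NumPy sketch of such a return calculation, with hypothetical rewards of 1 per step, `gamma=0.5`, two trajectories in one 8-step rollout, and no bootstrap value after the last step. It uses the usual backward recursion, where `done_t = 1` marks step `t` as the last step of an episode; `dones_ok` and `dones_shifted` are illustrative names.

```python
import numpy as np


def discounted_returns(rewards, dones, gamma):
    """G_t = r_t + gamma * (1 - done_t) * G_{t+1}, computed backwards.

    done_t = 1 cuts the running return, so it never leaks from one
    episode into the previous one.
    """
    returns = np.zeros(rewards.shape, dtype=np.float64)
    running = np.zeros(rewards.shape[1:], dtype=np.float64)
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * (1.0 - dones[t]) * running
        returns[t] = running
    return returns


rewards = np.ones((8, 1))
dones_ok = np.array([0, 0, 0, 1, 0, 0, 0, 0]).reshape(-1, 1)       # episode ends at index 3
dones_shifted = np.array([0, 0, 0, 0, 1, 0, 0, 0]).reshape(-1, 1)  # same flag, one step late
returns_ok = discounted_returns(rewards, dones_ok, gamma=0.5)
returns_shifted = discounted_returns(rewards, dones_shifted, gamma=0.5)
print(returns_ok.ravel())
print(returns_shifted.ravel())
```

With `dones_ok` the first trajectory's returns are 1.875, 1.75, 1.5 and 1; with `dones_shifted` the first five steps are treated as a single episode, giving 1.9375, 1.875, 1.75, 1.5 and 1, and the second trajectory loses one step.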
If the `dones` are off by one, like `np.array([0, 0, 0, 0, 1, 0, 0, 0]).reshape(-1, 1)`, the results will be quite different.

System info
Describe the characteristic of your environment:
Additional context
There are ways to manually trigger a reset by calling `env.reset(done_env_ids)` as follows, but this is supported by neither the XLA API nor the async API.

Reason and Possible fixes
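On the synchronous path, where a partial reset like `env.reset(done_env_ids)` is available, Gym-style output can be rebuilt on top of envpool-style stepping. A minimal NumPy sketch, assuming a hypothetical `ToyBatchEnv` whose `reset(env_ids)` resets only the listed envs, and storing the final observations under a `last_observation` key; none of these names are envpool's actual API.

```python
import numpy as np


class ToyBatchEnv:
    """Hypothetical batched env: obs[i] is the number of steps env i has taken
    in its current episode, which ends after lengths[i] steps.
    reset(env_ids) resets only the listed envs (a partial reset)."""

    def __init__(self, lengths):
        self.lengths = np.asarray(lengths)
        self.t = np.zeros(len(self.lengths), dtype=np.int64)

    def reset(self, env_ids=None):
        if env_ids is None:
            env_ids = np.arange(len(self.t))
        self.t[env_ids] = 0
        return self.t[env_ids].copy()

    def step(self, actions):
        # Actions are ignored by this toy env.
        self.t += 1
        done = self.t >= self.lengths
        return self.t.copy(), np.ones(len(self.t)), done, {}


def step_with_last_observation(env, actions):
    """Gym-style autoreset built on a partial reset: finished envs are reset
    within the same call, and the pre-reset observations are kept in info."""
    obs, reward, done, info = env.step(actions)
    done_ids = np.flatnonzero(done)
    if done_ids.size > 0:
        info["last_observation"] = obs.copy()  # rows in done_ids hold final obs
        obs[done_ids] = env.reset(done_ids)
    return obs, reward, done, info


env = ToyBatchEnv(lengths=[3, 2])
env.reset()
for _ in range(3):
    obs, reward, done, info = step_with_last_observation(env, np.zeros(2))
    print(obs, done, info.get("last_observation"))
```

This costs an extra reset call on the Python side whenever some env finishes, and, as noted in this issue, it is not an option for envpool's XLA and async APIs.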
A possible solution is to add a `last_observation` key, like in gym's API.

Checklist