
[BUG] Vectorized Environment Autoreset Incompatible with openai/baselines' API #194

vwxyzjn opened this issue Sep 14, 2022 · 16 comments

@vwxyzjn
Collaborator
vwxyzjn commented Sep 14, 2022

Describe the bug

Related to #33.

When an environment is "done", the autoreset feature in openai/gym's API resets that environment and returns the initial observation of the next episode in the same step. Here is a simple demonstration of how it works with gym==0.23.1:

import gym

class TestEnv(gym.Env):
    def __init__(self):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Discrete(10)
        
    def reset(self):
        self.obs = 0
        return self.obs

    def step(self, action):
        self.obs += 1
        return self.obs, 0, False, {}

def thunk():
    env = TestEnv()
    env = gym.wrappers.TimeLimit(env, max_episode_steps=4)
    return env

env = gym.vector.SyncVectorEnv([thunk])
env.reset()
print(env.step([0]))
print(env.step([0]))
print(env.step([0]))
print(env.step([0]))
(array([1]), array([0.]), array([False]), [{}])
(array([2]), array([0.]), array([False]), [{}])
(array([3]), array([0.]), array([False]), [{}])
(array([0]), array([0.]), array([ True]), [{'TimeLimit.truncated': True, 'terminal_observation': 4}])

Note that done=True and obs=0 are returned in this example, and the truncated (final) observation is placed in the info dict.

However, envpool does not implement this behavior and only returns the initial observation of the next episode after an additional step. See the reproduction below.

To Reproduce

import envpool
import numpy as np
import matplotlib.pyplot as plt

# make gym env
env = envpool.make(
    "Breakout-v5",
    env_type="gym",
    num_envs=1,
    max_episode_steps=4
)
fig, axs = plt.subplots(ncols=4, nrows=2, figsize=(5.5, 3.5), dpi=200)

obs = env.reset()
plt.imshow(obs[0,3].reshape(84, 84, 1), interpolation='nearest')
plt.savefig("static/reset.png")
for i, ax in zip(range(8), axs.flatten()):
    act = np.zeros(1, dtype=int) + 2
    obs, rew, done, info = env.step(act)
    ax.imshow(obs[0,3].reshape(84, 84, 1), interpolation='nearest')
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_title(f"step {i}, d={int(done[0])}")
    print(i, done, info['TimeLimit.truncated'])
plt.savefig(f"static/envpool.png")

The gym-based equivalent, using the Atari wrappers from stable_baselines3==1.2.0:

import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3.common.atari_wrappers import (  # isort:skip
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)
import gym

def thunk():
    env = gym.make("BreakoutNoFrameskip-v4")
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    # env = EpisodicLifeEnv(env) # have to comment this out due to how timelimit works
    if "FIRE" in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = ClipRewardEnv(env)
    env = gym.wrappers.ResizeObservation(env, (84, 84))
    env = gym.wrappers.GrayScaleObservation(env)
    env = gym.wrappers.FrameStack(env, 4)
    env = gym.wrappers.TimeLimit(env, max_episode_steps=4)
    return env
env = gym.vector.SyncVectorEnv([thunk])
fig, axs = plt.subplots(ncols=4, nrows=2, figsize=(5.5, 3.5), dpi=200)

obs = env.reset()
plt.imshow(obs[0,3].reshape(84, 84, 1), interpolation='nearest')
plt.savefig("static/gym-reset.png")
for i, ax in zip(range(8), axs.flatten()):
    act = np.zeros(1, dtype=int) + 2
    obs, rew, done, info = env.step(act)
    ax.imshow(obs[0,3].reshape(84, 84, 1), interpolation='nearest')
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_title(f"step {i}, d={int(done[0])}")
    print(i, done, info)

plt.savefig(f"static/gym.png")

[image: per-step frames saved to static/envpool.png (envpool) and static/gym.png (gym), each panel titled with the step index and done flag]

It can be observed from the figure that envpool returns the initial observation of the new episode at step 4, whereas gym's vector env returns it at step 3, the same step at which done=True is reported.

Expected behavior

In the screenshot above, envpool should return the initial observation of the new episode at step 3.

This is highly relevant to return calculation, as it causes an off-by-one error. Consider the following return calculation:

import numpy as np

# assume the game terminated, resulting in the terminal observation obs3
rewards = np.array([1, 0.1, 0.01, 2, 0.1, 0.01, 0.001, 0.0001]).reshape(-1, 1)
dones = np.array([0, 0, 0, 1, 0, 0, 0, 0]).reshape(-1, 1)
gamma = 1.0
num_steps = 8
next_done = 0
next_value = 0.0005 # value of obs8
returns = np.zeros_like(rewards)
for t in reversed(range(num_steps)): 
    if t == num_steps - 1: 
        nextnonterminal = 1.0 - next_done 
        next_return = next_value 
    else: 
        nextnonterminal = 1.0 - dones[t + 1] 
        next_return = returns[t + 1] 
    returns[t] = rewards[t] + gamma * nextnonterminal * next_return
print(list(returns))


# assume the game was truncated, resulting in the truncated observation obs3
rewards = np.array([1, 0.1, 0.01, 2, 0.1, 0.01, 0.001, 0.0001]).reshape(-1, 1)
dones = np.array([0, 0, 0, 1, 0, 0, 0, 0]).reshape(-1, 1)
v_obs3 = 0.008
rewards[2] += v_obs3
next_done = 0
next_value = 0.0005 # value of obs8
returns = np.zeros_like(rewards)
for t in reversed(range(num_steps)): 
    if t == num_steps - 1: 
        nextnonterminal = 1.0 - next_done 
        next_return = next_value 
    else: 
        nextnonterminal = 1.0 - dones[t + 1] 
        next_return = returns[t + 1] 
    returns[t] = rewards[t] + gamma * nextnonterminal * next_return
print(list(returns))
[array([1.11]), array([0.11]), array([0.01]), array([2.1116]), array([0.1116]), array([0.0116]), array([0.0016]), array([0.0006])]
[array([1.118]), array([0.118]), array([0.018]), array([2.1116]), array([0.1116]), array([0.0116]), array([0.0016]), array([0.0006])]

which calculates the returns for two trajectories correctly.

If the dones are off by one, like np.array([0, 0, 0, 0, 1, 0, 0, 0]).reshape(-1, 1), the results are quite different.

[array([3.11]), array([2.11]), array([2.01]), array([2.]), array([0.1116]), array([0.0116]), array([0.0016]), array([0.0006])]
[array([3.118]), array([2.118]), array([2.018]), array([2.]), array([0.1116]), array([0.0116]), array([0.0016]), array([0.0006])]
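For reference, the two shifted-dones lists above come from rerunning the same loop with the dones array moved one step later; a self-contained version of that rerun:

import numpy as np

gamma, num_steps = 1.0, 8
next_done, next_value = 0, 0.0005  # value of obs8
dones = np.array([0, 0, 0, 0, 1, 0, 0, 0]).reshape(-1, 1)  # done flag shifted one step later

for v_obs3 in (0.0, 0.008):  # 0.0: terminated case; 0.008: truncated case with the bootstrap added to rewards[2]
    rewards = np.array([1, 0.1, 0.01, 2, 0.1, 0.01, 0.001, 0.0001]).reshape(-1, 1)
    rewards[2] += v_obs3
    returns = np.zeros_like(rewards)
    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            nextnonterminal = 1.0 - next_done
            next_return = next_value
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            next_return = returns[t + 1]
        returns[t] = rewards[t] + gamma * nextnonterminal * next_return
    print(list(returns))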

System info

Describe the characteristics of your environment:

  • Describe how the library was installed (pip, source, ...)
  • Python version
  • Versions of any other relevant libraries
import envpool, numpy, sys
print(envpool.__version__, numpy.__version__, sys.version, sys.platform)
0.6.4 1.21.6 3.7.8 (default, Mar 30 2022, 09:38:46) 
[GCC 11.2.0] linux

Additional context

There are ways to manually trigger a reset by calling env.reset(done_env_ids) as follows, but this is supported in neither the XLA nor the async API.

import envpool
import numpy as np
import matplotlib.pyplot as plt

# make gym env
env = envpool.make(
    "Breakout-v5",
    env_type="gym",
    num_envs=1,
    max_episode_steps=4
)
fig, axs = plt.subplots(ncols=4, nrows=2, figsize=(5.5, 3.5), dpi=200)

obs = env.reset()
plt.imshow(obs[0,3].reshape(84, 84, 1), interpolation='nearest')
plt.savefig("static/reset.png")
for i, ax in zip(range(8), axs.flatten()):
    act = np.zeros(1, dtype=int) + 2
    obs, rew, done, info = env.step(act)

    # manually trigger the auto-reset for the sub-environments that just finished
    auto_reset_ids = np.where(info['TimeLimit.truncated'] | info['terminated'])[0]
    obs[auto_reset_ids] = env.reset(auto_reset_ids)

    ax.imshow(obs[0,3].reshape(84, 84, 1), interpolation='nearest')
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_title(f"step {i}, d={int(done[0])}")
    print(i, done, info['TimeLimit.truncated'], info['terminated'])
plt.savefig(f"static/envpool_mannual.png")

Reason and Possible fixes

A possible solution is to add a last_observation key to the info dict, similar to gym's API (which exposes the final observation as terminal_observation in info).
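For illustration, a minimal sketch of how downstream training code could consume such a key; the last_observation name, the truncation flag layout, and the critic values here are assumptions for the sketch, not envpool's actual API:

import numpy as np

gamma = 0.99

# Pretend these came out of one vectorized step, assuming a hypothetical
# info["last_observation"] that always holds the pre-reset observation.
rewards = np.array([0.0, 1.0, 0.0, 0.5], dtype=np.float32)
truncated = np.array([False, True, False, False])                    # e.g. info["TimeLimit.truncated"]
last_obs_values = np.array([0.1, 0.8, 0.3, 0.2], dtype=np.float32)   # critic values of the assumed info["last_observation"]

# Bootstrap only the truncated sub-environments with the value of the
# observation that was cut off, instead of treating truncation as terminal.
rewards = np.where(truncated, rewards + gamma * last_obs_values, rewards)
print(rewards)  # only index 1 receives the bootstrap term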

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@edbeeching

I noticed this API mismatch when reading the autoreset documentation yesterday.
Do you have any idea when this will be fixed?

@pseudo-rnd-thoughts

Hi from the Gymnasium team. We are planning to reimplement the vector API to include envpool and therefore change the vector function calls to match EnvPool's, as it seems better than the Gym / Gymnasium implementation.

@vwxyzjn Do you agree?

@edbeeching

Hi @pseudo-rnd-thoughts, could you elaborate on why you think the envpool implementation is better? I believe quite strongly that this will introduce off-by-one errors and overhead in almost all deep RL frameworks.

@DavidSlayback

I agree with @edbeeching. I'm actually trying to put together a PR to include the final observation on reset; I'm just trying to track down where the auto-reset actually happens.

@pseudo-rnd-thoughts

At least to me, it seems quite weird to have the final observation in info rather than in obs; it just feels unnatural.
However, my research is not in algorithm design, so you are all probably right about the off-by-one issues.

Currently, the vector implementations of EnvPool and Gym differ, so we should probably converge on a single API.
On the elegance side, and because info then only contains the actual environment info, I prefer the EnvPool style.

@edbeeching or @DavidSlayback, what are the advantages of the Gym version, with the last obs in info?

@vwxyzjn
Collaborator Author

vwxyzjn commented Oct 27, 2022

Hey @edbeeching, @DavidSlayback, and @pseudo-rnd-thoughts, can I play devil's advocate here? The current envpool API might be better than openai/baselines' design (I still need to think through the details).

The main problem with openai/baselines' design is that the final observation is put into info, which is very hacky, and the learning library needs to do additional forward passes to calculate the values of truncated states (e.g., see SB3's implementation here), which looks like:

# Handle timeout by bootstraping with value function
# see GitHub issue #633
for idx, done_ in enumerate(dones):
    if (
        done_
        and infos[idx].get("terminal_observation") is not None
        and infos[idx].get("TimeLimit.truncated", False)
    ):
        terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
        with th.no_grad():
            terminal_value = self.policy.predict_values(terminal_obs)[0]
        rewards[idx] += self.gamma * terminal_value

However, such an implementation might not even be possible for JAX-based implementations (including ones using the XLA interface), since we cannot really use if statements like that in a JAX loop. As a result, if we were to do this in JAX, we might have to always do a forward pass on the terminal_observation regardless of whether it exists, and do something like

jnp.where(info["truncated"], rewards[idx] + self.gamma * terminal_value, rewards[idx])

which is incredibly inefficient because we are doing an extra forward pass per step.

An alternative is to use envpool's current design: we can save all of the info["truncated"] flags and do something like

rewards[truncations] = jnp.where(truncations, rewards + values, rewards)

Some indexing might be off in this example, but the spirit is to reuse the stored values which were already calculated for both truncated and normal states. Regarding the off-by-one errors, we can simply mask out the truncated observation somehow.
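A self-contained sketch of that idea on made-up arrays (plain jax.numpy, mirroring the pseudo-code above; discounting is omitted as in the pseudo-code, and nothing here is envpool's API):

import jax.numpy as jnp

# Made-up rollout buffers for one sub-environment over 8 steps; the values
# were already computed by the critic during the rollout, so no extra
# forward pass is needed here.
rewards     = jnp.array([1.0, 0.1, 0.01, 2.0, 0.1, 0.01, 0.001, 0.0001])
values      = jnp.array([0.26, 0.16, 0.18, 0.17, 0.25, 0.51, 0.34, 0.17])
truncations = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0])

# Add the stored value back into the reward only at truncation steps.
bootstrapped_rewards = jnp.where(truncations == 1.0, rewards + values, rewards)
print(bootstrapped_rewards)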

Maybe it's still worth having the openai/baselines-style vectorized environment as a toggleable option?

@pseudo-rnd-thoughts

I agree with @vwxyzjn and think that the EnvPool API > Gym vector API.
My primary concern is backward compatibility, as the two APIs can look almost identical to each other.

@DavidSlayback

Hmmm... to be honest, I'm mostly defending the final observation for correctness, but in that same spirit, I'm also not sure how much that one state matters, let alone how often people actually use it. Basically:

  • We autoreset when terminated or truncated, returning the first observation for the next episode
  • Algorithms that look at truncation want to bootstrap value using the last observation instead of the first observation (i.e., imagining future rewards)
  • Algorithms that look at termination don't really care (bootstrap with 0, assume the last state is absorbing and has no reward), except maybe some exploration-based ones
  • Regardless of algorithm, we need the first observation

Considering the possibility of JAX/XLA implementations with limited control flow, the inefficiency might be too much. What if terminal_observation always exists, but, if we're not actually terminating, it's just a reference to the observation we already get? Then you could always do your bootstrap value passes based on that single observation, but otherwise rely on the normal one.
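A rough, self-contained sketch of this suggestion on dummy arrays (final_obs, value_fn, and the flags here are stand-ins, not an existing API): the hypothetical final-observation field is always populated, equal to the current observation when nothing ended, so a single batched value computation works without any control flow.

import numpy as np

gamma = 0.99
num_envs, obs_dim = 4, 3
rng = np.random.default_rng(0)

obs = rng.random((num_envs, obs_dim)).astype(np.float32)        # post-reset observations returned by step()
truncated = np.array([False, True, False, False])                # which sub-envs were cut off this step
pre_reset_obs = rng.random((num_envs, obs_dim)).astype(np.float32)

# Hypothetical field: always defined, and simply aliases obs where nothing ended.
final_obs = np.where(truncated[:, None], pre_reset_obs, obs)

def value_fn(x):  # stand-in critic
    return x.mean(axis=-1)

# One unconditional batched pass over final_obs; masking decides where it matters.
rewards = np.zeros(num_envs, dtype=np.float32)
rewards = np.where(truncated, rewards + gamma * value_fn(final_obs), rewards)
print(rewards)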

@vwxyzjn Where do the values in rewards[truncations] = jnp.where(truncations, rewards + values, rewards) come from? We would have them for everything except the final state. Also, you've done a ton of experiments within the CleanRL repo. I've seen in PRs that this distinction doesn't seem to matter for stuff like PPO; are there examples of not handling this correctly being a major problem?

@vwxyzjn
Collaborator Author

vwxyzjn commented Oct 27, 2022

I'm mostly defending the final observation for correctness

What do you mean? The final observation is present in EnvPool's current API design.

What if terminal_observation always exists, but if we're not actually terminating, it's just a reference to the observation we already get? Then you could always do your bootstrap value passes based on that single observation, but otherwise rely on the normal one

I might be understanding this wrong, but maybe it's the same as my proposal (see draft below).

Where do the values in rewards[truncations] = jnp.where(truncations, rewards + values, rewards) come from?

Here is a preliminary draft: https://www.diffchecker.com/U5VW3TTH (which does not work but demonstrates my point)

Also, you've done a ton of experiments within the CleanRL repo. I've seen in PRs that this distinction doesn't seem to matter for stuff like PPO; are there examples of not handling this correctly being a major problem?

@araffin, have you had situations where correctly handling the time limit / truncation made a difference in performance?

@edbeeching

What I find so unusual about the envpool auto-reset API is that the environment is not actually automatically reset; you have to call an additional step() with dummy actions for certain indices in order to actually reset those environments. If this happens often, say when there is a large number of parallel environments, it adds overhead.
Imagine that I want to, for example, evaluate a trained policy in a batched setting. I would have to create many workarounds with dummy actions, etc., when it should just be:

obs, info = env.reset()

while some_condition:
    actions = policy.get_action(obs)
    obs, reward, term, trunc, info = env.step(actions)
    # logging of rewards, etc

Perhaps I am missing something, so please feel free to rewrite the above snippet based on the current envpool API.

For me, all of this stems from the truncated vs. terminated issue and from algorithms that require the next_obs in order to perform a value estimate when an episode is truncated. I consider truncated episodes to be the exception, not the rule. While it is good to account for truncated episodes, the envpool API is not the right way to do this (in my opinion, of course).

@araffin
Contributor

araffin commented Oct 27, 2022

@araffin have you had situations where correctly handling truncation limit has made a different in the performance?

Just run SAC/PPO on PyBullet with/without the time feature / timeout handling and you will see the difference ;).
I have some results in appendix A.8 of https://arxiv.org/abs/2005.05719

Also related: hill-a/stable-baselines#120 (comment)

@vwxyzjn
Collaborator Author

vwxyzjn commented Nov 1, 2022

@DavidSlayback I had a chance to look into this further and prototyped a snippet: https://www.diffchecker.com/U5VW3TTH. You can run it with python -i ppo_atari_envpool_xla_jax_truncation.py --env-id Breakout-v5 --num-envs 1 --num-steps 8 --num-minibatches 1 --update-epochs 1, which generates the following output.

storage.dones.flatten():
 [0. 0. 0. 0. 0. 0. 1. 0.]
storage.truncations.flatten():
 [0. 0. 0. 0. 0. 0. 1. 0.]
storage.rewards.flatten():
 [1. 0. 0. 0. 0. 0. 0. 0.]
storage.values.flatten():
 [0.26382822 0.16091006 0.18179433 0.17108421 0.25206307 0.5066296
 0.33542693 0.1653153 ]
NOTE: bootstrap value as below:
jnp.where(storage.truncations, storage.rewards + storage.values, storage.rewards).flatten():
 [1.         0.         0.         0.         0.         0.
 0.33542693 0.        ]

@pseudo-rnd-thoughts

pseudo-rnd-thoughts commented Nov 10, 2022

Before Gymnasium makes any changes to the vector environment to use EnvPool's API rather than the Gym API:
@DavidSlayback @edbeeching @araffin @vwxyzjn, does anyone think this is a bad idea, or that the Gym API is better, etc.?

@vwxyzjn
Collaborator Author

vwxyzjn commented Nov 16, 2022

I did a series of benchmarks and found no significant performance difference when properly handling truncation. So it appears to me the current envpool design is actually quite good because everything has the same shapes.

https://github.com/vwxyzjn/ppo-atari-metrics#generated-results

@walkacross

In short, auto-reset in gym:

reset_obs = env.reset()
action_result_in_done = policy(reset_obs)
a_reset_obs, reward, done, info = env.step(action_result_in_done)

and in envpool:

reset_obs = envpool_env.reset()
action_result_in_done = policy(reset_obs)
not_a_reset_obs, reward1, done, info = envpool_env.step(action_result_in_done)

action = policy(not_a_reset_obs)
a_reset_obs, reward2, done, info = envpool_env.step(action)

The auto-reset solution in envpool results in two questionable transitions when someone uses the env to collect transitions:

reset_obs, action_result_in_done, reward1, not_a_reset_obs

not_a_reset_obs, action, reward2, a_reset_obs

These two transitions, especially the second, are not logically valid (especially when modeling the state transition in the model-based RL case) and should not be added to the ReplayBuffer.

Any suggestions for addressing these issues when a Collector follows the gym vector-env auto-reset protocol?
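One possible workaround, sketched here on made-up arrays rather than any particular Collector implementation (actions omitted for brevity): remember which sub-environments reported done on the previous step, and drop the transition produced by the following (reset) step before it reaches the replay buffer.

import numpy as np

num_envs = 2
replay_buffer = []

# Fake a short envpool-style rollout; env 1 finishes at the first step, so its
# second step only performs the auto-reset and yields no real transition.
steps = [
    # (obs, reward, done) as returned by envpool_env.step(actions)
    (np.array([1, 5]), np.array([0.1, 0.2]), np.array([False, True])),
    (np.array([2, 0]), np.array([0.0, 0.0]), np.array([False, False])),  # env 1's obs here is the reset obs
    (np.array([3, 1]), np.array([0.3, 0.1]), np.array([False, False])),
]

prev_obs = np.array([0, 4])                 # observations from the previous step
prev_done = np.zeros(num_envs, dtype=bool)  # nothing was done before the first step

for obs, rew, done in steps:
    for i in range(num_envs):
        # Skip the boundary transition: if the previous step ended the episode,
        # this step's (prev_obs -> obs) pair crosses the reset and is not a real transition.
        if not prev_done[i]:
            replay_buffer.append((prev_obs[i], rew[i], obs[i], done[i]))
    prev_obs, prev_done = obs, done

print(replay_buffer)  # the (5 -> 0) pair for env 1 never enters the buffer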

@roger-creus

roger-creus commented Aug 29, 2023

Coming back to this after quite a long time. Can you confirm that this API mismatch still exists?

I think the Gym API > EnvPool API, and I realized this when playing with the PPO+LSTM from CleanRL:

You can see in the following code that, when done=True, the LSTM expects to combine the hidden state corresponding to the first observation of the new episode with a freshly initialized LSTM hidden state full of zeros. Under the EnvPool API, by contrast, this code would combine the new LSTM hidden state with the (old) observation from the previous episode, introducing a bug:

def get_states(self, x, lstm_state, done):
        hidden = self.network(x / 255.0)

        # LSTM logic
        batch_size = lstm_state[0].shape[1]
        hidden = hidden.reshape((-1, batch_size, self.lstm.input_size))
        done = done.reshape((-1, batch_size))
        new_hidden = []
        for h, d in zip(hidden, done):
            h, lstm_state = self.lstm(
                h.unsqueeze(0),
                (
                    (1.0 - d).view(1, -1, 1) * lstm_state[0],
                    (1.0 - d).view(1, -1, 1) * lstm_state[1],
                ),
            )
            new_hidden += [h]
        new_hidden = torch.flatten(torch.cat(new_hidden), 0, 1)
        return new_hidden, lstm_state

So I think this is the natural way to think about observations when done=True for clean implementations.
