examples now support gym 0.26 #215

Merged (7 commits, Oct 31, 2022)
13 changes: 11 additions & 2 deletions examples/acme_examples/helpers.py
@@ -40,13 +40,15 @@
 from acme.jax.types import PRNGKey
 from acme.utils import loggers
 from acme.utils.loggers import aggregators, base, filters, terminal
+from packaging import version
 from stable_baselines3.common.env_util import make_vec_env
 from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv

 import envpool
 from envpool.python.protocol import EnvPool

 logging.getLogger().setLevel(logging.INFO)
+is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0")


 class TimeStep(dm_env.TimeStep):
@@ -276,7 +278,10 @@ def __init__(

   def reset(self) -> TimeStep:
     self._reset_next_step = False
-    observation = self._environment.reset()
+    if is_legacy_gym:
+      observation = self._environment.reset()
+    else:
+      observation, _ = self._environment.reset()
     ts = TimeStep(
       step_type=np.full(self._num_envs, dm_env.StepType.FIRST, dtype="int32"),
       reward=np.zeros(self._num_envs, dtype="float32"),
@@ -289,7 +294,11 @@ def step(self, action: types.NestedArray) -> TimeStep:
     if self._reset_next_step:
       return self.reset()
     if self._use_env_pool:
-      observation, reward, done, _ = self._environment.step(action)
+      if is_legacy_gym:
+        observation, reward, done, _ = self._environment.step(action)
+      else:
+        observation, reward, term, trunc, _ = self._environment.step(action)
+        done = term + trunc
     else:
       self._environment.step_async(action)
       observation, reward, done, _ = self._environment.step_wait()
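The pattern above repeats across every example touched by this PR: gym releases before 0.26 return a bare observation from `reset()` and a 4-tuple from `step()`, while 0.26+ returns `(obs, info)` from `reset()` and a 5-tuple whose `terminated`/`truncated` flags must be recombined into `done`. A minimal standalone sketch of that version gate (the helper names `reset_compat`/`step_compat` are hypothetical, not from the PR):

```python
# Minimal sketch of the version gate used throughout this PR.
# `reset_compat` / `step_compat` are hypothetical helper names.
import gym
import numpy as np
from packaging import version

is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0")


def reset_compat(env):
  """Return only the observation, whichever gym version is installed."""
  if is_legacy_gym:
    return env.reset()  # gym < 0.26: obs
  obs, _info = env.reset()  # gym >= 0.26: (obs, info)
  return obs


def step_compat(env, action):
  """Return the legacy (obs, reward, done, info) 4-tuple."""
  if is_legacy_gym:
    return env.step(action)  # gym < 0.26: 4-tuple
  obs, rew, term, trunc, info = env.step(action)
  return obs, rew, np.logical_or(term, trunc), info
```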
19 changes: 16 additions & 3 deletions examples/cleanrl_examples/ppo_atari_envpool.py
@@ -31,11 +31,14 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
+from packaging import version
 from torch.distributions.categorical import Categorical
 from torch.utils.tensorboard import SummaryWriter

 import envpool

+is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0")
+

 def parse_args():
   # fmt: off
@@ -221,7 +224,10 @@ def __init__(self, env, deque_size=100):
       print("env has lives")

   def reset(self, **kwargs):
-    observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
+    if is_legacy_gym:
+      observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
+    else:
+      observations, _ = super(RecordEpisodeStatistics, self).reset(**kwargs)
     self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
     self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
     self.lives = np.zeros(self.num_envs, dtype=np.int32)
@@ -230,8 +236,15 @@ def reset(self, **kwargs):
     return observations

   def step(self, action):
-    observations, rewards, dones, infos = super(RecordEpisodeStatistics,
-                                                self).step(action)
+    if is_legacy_gym:
+      observations, rewards, dones, infos = super(
+        RecordEpisodeStatistics, self
+      ).step(action)
+    else:
+      observations, rewards, term, trunc, infos = super(
+        RecordEpisodeStatistics, self
+      ).step(action)
+      dones = term + trunc
     self.episode_returns += infos["reward"]
     self.episode_lengths += 1
     self.returned_episode_returns[:] = self.episode_returns
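One detail worth noting about `dones = term + trunc` here (and in the other examples): EnvPool returns `term` and `trunc` as batched NumPy boolean arrays, so addition acts as an element-wise logical OR. A tiny illustration, not from the PR:

```python
# Tiny illustration (not from the PR): adding boolean NumPy arrays behaves
# like an element-wise OR, which is what `dones = term + trunc` relies on.
import numpy as np

term = np.array([True, False, False])
trunc = np.array([False, True, False])
dones = term + trunc
print(dones)  # [ True  True False]
assert (dones == np.logical_or(term, trunc)).all()
```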
10 changes: 9 additions & 1 deletion examples/ppo_atari/ppo.py
@@ -15,16 +15,20 @@
 import argparse
 from typing import Any, Dict, Tuple, Type

+import gym
 import numpy as np
 import torch
 import torch.nn.functional as F
 import tqdm
 from gae import compute_gae
+from packaging import version
 from torch import nn
 from torch.utils.tensorboard import SummaryWriter

 import envpool

+is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0")
+

 class CnnActorCritic(nn.Module):

@@ -228,7 +232,11 @@ def run(self) -> None:
     while t.n < self.config.step_per_epoch:
       # collect
       for _ in range(self.config.step_per_collect // self.config.waitnum):
-        obs, rew, done, info = self.train_envs.recv()
+        if is_legacy_gym:
+          obs, rew, done, info = self.train_envs.recv()
+        else:
+          obs, rew, term, trunc, info = self.train_envs.recv()
+          done = term + trunc
         env_id = info["env_id"]
         obs = torch.tensor(obs, device="cuda")
         self.obs_batch.append(obs)
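For context, this example uses EnvPool's asynchronous gym interface, where `recv()` returns a batch from whichever environments have finished stepping and `send()` dispatches the next actions to exactly those environments. A hedged sketch of that collect loop with a dummy policy; the task id, sizes, and the use of `async_reset()` are illustrative assumptions rather than values taken from the PR:

```python
# Hedged sketch of the send/recv collect pattern, with a dummy policy.
# Task id, num_envs, batch_size and async_reset() usage are assumptions.
import gym
import numpy as np
from packaging import version

import envpool

is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0")

envs = envpool.make_gym("Pong-v5", num_envs=8, batch_size=4)
envs.async_reset()  # start the first batch; results come back via recv()
for _ in range(100):
  if is_legacy_gym:
    obs, rew, done, info = envs.recv()
  else:
    obs, rew, term, trunc, info = envs.recv()
    done = np.logical_or(term, trunc)
  env_id = info["env_id"]  # ids of the envs contained in this batch
  action = np.zeros(len(env_id), dtype=int)  # dummy policy: always action 0
  envs.send(action, env_id)
```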
40 changes: 35 additions & 5 deletions examples/sb3_examples/ppo.py
@@ -17,6 +17,7 @@
 import gym
 import numpy as np
 import torch as th
+from packaging import version
 from stable_baselines3 import PPO
 from stable_baselines3.common.env_util import make_vec_env
 from stable_baselines3.common.evaluation import evaluate_policy
@@ -38,6 +39,7 @@
 seed = 0
 use_env_pool = True  # whether to use EnvPool or Gym for training
 render = False  # whether to render final policy using Gym
+is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0")


 class VecAdapter(VecEnvWrapper):
@@ -56,14 +58,21 @@ def step_async(self, actions: np.ndarray) -> None:
     self.actions = actions

   def reset(self) -> VecEnvObs:
-    return self.venv.reset()
+    if is_legacy_gym:
+      return self.venv.reset()
+    else:
+      return self.venv.reset()[0]

   def seed(self, seed: Optional[int] = None) -> None:
     # You can only seed EnvPool env by calling envpool.make()
     pass

   def step_wait(self) -> VecEnvStepReturn:
-    obs, rewards, dones, info_dict = self.venv.step(self.actions)
+    if is_legacy_gym:
+      obs, rewards, dones, info_dict = self.venv.step(self.actions)
+    else:
+      obs, rewards, terms, truncs, info_dict = self.venv.step(self.actions)
+      dones = terms + truncs
     infos = []
     # Convert dict to list of dict
     # and add terminal observation
@@ -77,8 +86,10 @@
       )
       if dones[i]:
         infos[i]["terminal_observation"] = obs[i]
-        obs[i] = self.venv.reset(np.array([i]))
-
+        if is_legacy_gym:
+          obs[i] = self.venv.reset(np.array([i]))
+        else:
+          obs[i] = self.venv.reset(np.array([i]))[0]
     return obs, rewards, dones, infos
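For readers unfamiliar with the adapter: EnvPool hands back `info` as a single dict of batched arrays, while SB3's `VecEnv` API expects one info dict per sub-environment; the collapsed body of `step_wait` performs that conversion. A rough illustration of the idea, not the PR's exact code:

```python
# Rough illustration (not the PR's exact code): turn a dict of batched
# arrays into the per-env list of dicts that SB3's VecEnv API expects.
import numpy as np

info_dict = {"env_id": np.array([0, 1, 2]), "lives": np.array([3, 0, 2])}
infos = [
  {key: value[i] for key, value in info_dict.items()}
  for i in range(len(info_dict["env_id"]))
]
print(infos[1])  # info dict for sub-env 1
```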


@@ -115,7 +126,26 @@ def step_wait(self) -> VecEnvStepReturn:
   pass

 # Agent trained on envpool version should also perform well on regular Gym env
-test_env = gym.make(env_id)
+if not is_legacy_gym:
+
+  def legacy_wrap(env):
+    env.reset_fn = env.reset
+    env.step_fn = env.step
+
+    def legacy_reset():
+      return env.reset_fn()[0]
+
+    def legacy_step(action):
+      obs, rew, term, trunc, info = env.step_fn(action)
+      return obs, rew, term + trunc, info
+
+    env.reset = legacy_reset
+    env.step = legacy_step
+    return env
+
+  test_env = legacy_wrap(gym.make(env_id))
Review comment (Collaborator): btw is it possible to work out-of-box if we don't add legacy_wrap?

Reply (@51616, Contributor Author, Oct 31, 2022): Without the wrap, it wouldn't work with gym=0.26 using the latest version of sb3 on pip.
+else:
+  test_env = gym.make(env_id)

 # Test with EnvPool
 mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20)
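As the review thread notes, `legacy_wrap` monkey-patches `reset`/`step` on a plain Gym 0.26 environment so that the stable-baselines3 release on pip at the time, which still expects the pre-0.26 API, can evaluate the trained agent. The same shim could also be written as a `gym.Wrapper` subclass; a hedged alternative sketch, not part of the PR:

```python
# Alternative sketch (not part of the PR): the same old-API shim as a
# gym.Wrapper subclass instead of monkey-patched reset/step methods.
import gym


class LegacyAPIWrapper(gym.Wrapper):
  """Expose the pre-0.26 reset/step signatures on top of a 0.26+ env."""

  def reset(self, **kwargs):
    obs, _info = self.env.reset(**kwargs)
    return obs

  def step(self, action):
    obs, rew, term, trunc, info = self.env.step(action)
    return obs, rew, term or trunc, info


# Usage would mirror the PR: test_env = LegacyAPIWrapper(gym.make(env_id))
```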
14 changes: 12 additions & 2 deletions examples/xla_step.py
@@ -16,11 +16,15 @@
 See https://envpool.readthedocs.io/en/latest/content/xla_interface.html
 """

+import gym
 import jax.numpy as jnp
 from jax import jit, lax
+from packaging import version

 import envpool

+is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0")
+

 def policy(states: jnp.ndarray) -> jnp.ndarray:
   return jnp.zeros(states.shape[0], dtype=jnp.int32)
@@ -35,14 +39,20 @@ def gym_sync_step() -> None:
   def actor_step(iter, loop_var):
     handle0, states = loop_var
     action = policy(states)
-    handle1, (new_states, rew, done, info) = step(handle0, action)
+    if is_legacy_gym:
+      handle1, (new_states, rew, done, info) = step(handle0, action)
+    else:
+      handle1, (new_states, rew, term, trunc, info) = step(handle0, action)
Review comment (Collaborator): have you tested this for new gym version?

Reply (@51616, Contributor Author, Oct 31, 2022): Yes. It works with both gym=0.26 and gym=0.21.
     return (handle1, new_states)

   @jit
   def run_actor_loop(num_steps, init_var):
     return lax.fori_loop(0, num_steps, actor_step, init_var)

-  states = env.reset()
+  if is_legacy_gym:
+    states = env.reset()
+  else:
+    states, _ = env.reset()
   run_actor_loop(100, (handle, states))
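A final note on the XLA example: `is_legacy_gym` is a plain Python boolean fixed at import time, so the `if`/`else` inside `actor_step` is resolved while JAX traces the function and never becomes a traced conditional, which is why it composes cleanly with `jit` and `lax.fori_loop`. A minimal sketch of that trace-time behaviour, not from the PR:

```python
# Minimal sketch (not from the PR): a Python-level flag is evaluated while
# JAX traces the function, so branching on it inside jitted code is fine.
import jax.numpy as jnp
from jax import jit

IS_LEGACY = False  # plain Python bool, known at trace time


@jit
def bump(x):
  if IS_LEGACY:  # resolved during tracing, not executed per call
    return x - 1
  return x + 1


print(bump(jnp.arange(3)))  # [1 2 3]
```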