Skip to content

Commit

Permalink
Fix/deterministic action space sampling in SubprocVectorEnv (#1103)
Browse files (browse the repository at this point in the history)
  • Loading branch information
maxhuettenrauch committed Apr 18, 2024
1 parent 6935a11 commit a043711
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
57 changes: 57 additions & 0 deletions test/base/test_action_space_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import gymnasium as gym

from tianshou.env import DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv


def test_gym_env_action_space() -> None:
    """Check that re-seeding a gymnasium action space reproduces the same sample."""
    env = gym.make("Pendulum-v1")
    try:
        env.action_space.seed(0)
        action1 = env.action_space.sample()

        env.action_space.seed(0)
        action2 = env.action_space.sample()
    finally:
        # Release the environment even when seeding or sampling raises.
        env.close()

    # The samples are arrays; reduce the element-wise comparison explicitly
    # instead of relying on the truthiness of a one-element array, which
    # would raise ValueError as soon as the action has more than one entry.
    assert (action1 == action2).all()


def test_dummy_vec_env_action_space() -> None:
    """Check that re-seeding a DummyVectorEnv reproduces every env's action sample."""
    num_envs = 4
    envs = DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(num_envs)])
    try:
        envs.seed(0)
        action1 = [ac_space.sample() for ac_space in envs.action_space]

        envs.seed(0)
        action2 = [ac_space.sample() for ac_space in envs.action_space]
    finally:
        # Always tear the wrapped environments down, even on failure.
        envs.close()

    # Guard against a vacuous pass: one sample per environment is expected.
    assert len(action1) == len(action2) == num_envs
    # Compare array-wise; plain list equality would depend on the truthiness
    # of each element-wise comparison result.
    assert all((first == second).all() for first, second in zip(action1, action2))


def test_subproc_vec_env_action_space() -> None:
    """Check that re-seeding a SubprocVectorEnv reproduces every env's action sample.

    Regression test for deterministic action space sampling in
    SubprocVectorEnv (#1103).
    """
    num_envs = 4
    envs = SubprocVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(num_envs)])
    try:
        envs.seed(0)
        action1 = [ac_space.sample() for ac_space in envs.action_space]

        envs.seed(0)
        action2 = [ac_space.sample() for ac_space in envs.action_space]
    finally:
        # Shut the workers down even on failure so a failing test does not
        # leave worker processes behind.
        envs.close()

    # Guard against a vacuous pass: one sample per environment is expected.
    assert len(action1) == len(action2) == num_envs
    # Compare array-wise; plain list equality would depend on the truthiness
    # of each element-wise comparison result.
    assert all((first == second).all() for first, second in zip(action1, action2))


def test_shmem_vec_env_action_space() -> None:
    """Check that re-seeding a ShmemVectorEnv reproduces every env's action sample."""
    num_envs = 4
    envs = ShmemVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(num_envs)])
    try:
        envs.seed(0)
        action1 = [ac_space.sample() for ac_space in envs.action_space]

        envs.seed(0)
        action2 = [ac_space.sample() for ac_space in envs.action_space]
    finally:
        # Shut the vector env down even on failure so its resources are released.
        envs.close()

    # Guard against a vacuous pass: one sample per environment is expected.
    assert len(action1) == len(action2) == num_envs
    # Compare array-wise; plain list equality would depend on the truthiness
    # of each element-wise comparison result.
    assert all((first == second).all() for first, second in zip(action1, action2))


if __name__ == "__main__":
    # Allow running this module directly: execute each check in declaration order.
    for run_check in (
        test_gym_env_action_space,
        test_dummy_vec_env_action_space,
        test_subproc_vec_env_action_space,
        test_shmem_vec_env_action_space,
    ):
        run_check()
1 change: 1 addition & 0 deletions tianshou/env/worker/subproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def _encode_obs(
if hasattr(env, "seed"):
p.send(env.seed(data))
else:
env.action_space.seed(seed=data)
env.reset(seed=data)
p.send(None)
elif cmd == "getattr":
Expand Down

0 comments on commit a043711

Please sign in to comment.