[RLlib] Issue 21629: Video recorder env wrapper not working. Added test case. (#21670)
sven1977 committed Jan 24, 2022
1 parent 2010f13 commit c288b97
Showing 9 changed files with 91 additions and 22 deletions.
1 change: 1 addition & 0 deletions python/requirements/ml/requirements_rllib.txt
@@ -26,6 +26,7 @@ tensorflow_estimator==2.6.0
 higher==0.2.1
 # For auto-generating an env-rendering Window.
 pyglet==1.5.15
+imageio-ffmpeg==0.4.5
 # For JSON reader/writer.
 smart_open==5.0.0
 # Ray Serve example
9 changes: 7 additions & 2 deletions rllib/env/base_env.py
@@ -79,7 +79,7 @@ class BaseEnv:
 
     def to_base_env(
         self,
-        make_env: Callable[[int], EnvType] = None,
+        make_env: Optional[Callable[[int], EnvType]] = None,
         num_envs: int = 1,
         remote_envs: bool = False,
         remote_env_batch_wait_ms: int = 0,
@@ -729,7 +729,12 @@ def convert_to_base_env(
 
     # Given `env` is already a BaseEnv -> Return as is.
     if isinstance(env, (BaseEnv, MultiAgentEnv, VectorEnv, ExternalEnv)):
-        return env.to_base_env()
+        return env.to_base_env(
+            make_env=make_env,
+            num_envs=num_envs,
+            remote_envs=remote_envs,
+            remote_env_batch_wait_ms=remote_env_batch_wait_ms,
+        )
     # `env` is not a BaseEnv yet -> Need to convert/vectorize.
     else:
         # Sub-environments are ray.remote actors:
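
This forwarding is the substantive fix in base_env.py: an env that was already a MultiAgentEnv, VectorEnv, or ExternalEnv used to be converted via a bare to_base_env() call, silently dropping all caller-supplied options. A minimal sketch of the now-correct path, using one of this repo's own example envs; this is illustrative only and not part of the diff:

from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
from ray.rllib.examples.env.multi_agent import BasicMultiAgent

# BasicMultiAgent is already a MultiAgentEnv, so the isinstance() branch
# above is taken; since this fix, the kwargs actually reach to_base_env().
env = convert_to_base_env(
    BasicMultiAgent(2),
    num_envs=1,
    remote_envs=False,
    remote_env_batch_wait_ms=0,
)
assert isinstance(env, BaseEnv)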
2 changes: 1 addition & 1 deletion rllib/env/external_env.py
@@ -190,7 +190,7 @@ def _get(self, episode_id: str) -> "_ExternalEnvEpisode":
 
     def to_base_env(
         self,
-        make_env: Callable[[int], EnvType] = None,
+        make_env: Optional[Callable[[int], EnvType]] = None,
         num_envs: int = 1,
         remote_envs: bool = False,
         remote_env_batch_wait_ms: int = 0,
2 changes: 1 addition & 1 deletion rllib/env/multi_agent_env.py
@@ -263,7 +263,7 @@ def with_agent_groups(
     @PublicAPI
     def to_base_env(
         self,
-        make_env: Callable[[int], EnvType] = None,
+        make_env: Optional[Callable[[int], EnvType]] = None,
         num_envs: int = 1,
        remote_envs: bool = False,
         remote_env_batch_wait_ms: int = 0,
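
The same annotation fix, spelling out the Optional[...] that a None default only implied (per PEP 484), recurs once more in vector_env.py below.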
61 changes: 50 additions & 11 deletions rllib/env/tests/test_record_env_wrapper.py
@@ -1,51 +1,90 @@
-from gym import wrappers
-import tempfile
+import glob
+import gym
+import numpy as np
+import os
+import shutil
+import sys
 import unittest
 
 from ray.rllib.env.utils import VideoMonitor, record_env_wrapper
 from ray.rllib.examples.env.mock_env import MockEnv2
 from ray.rllib.examples.env.multi_agent import BasicMultiAgent
+from ray.rllib.utils.test_utils import check
 
 
 class TestRecordEnvWrapper(unittest.TestCase):
     def test_wrap_gym_env(self):
+        record_env_dir = os.popen("mktemp -d").read()[:-1]
+        print(f"tmp dir for videos={record_env_dir}")
+
+        if not os.path.exists(record_env_dir):
+            sys.exit(1)
+
+        num_steps_per_episode = 10
         wrapped = record_env_wrapper(
-            env=MockEnv2(10),
-            record_env=tempfile.gettempdir(),
+            env=MockEnv2(num_steps_per_episode),
+            record_env=record_env_dir,
             log_dir="",
             policy_config={
                 "in_evaluation": False,
             })
-        # Type is wrappers.Monitor.
-        self.assertTrue(isinstance(wrapped, wrappers.Monitor))
+        # Non MultiAgentEnv: Wrapper's type is wrappers.Monitor.
+        self.assertTrue(isinstance(wrapped, gym.wrappers.Monitor))
         self.assertFalse(isinstance(wrapped, VideoMonitor))
 
         wrapped.reset()
+        # Expect one video file to have been produced in the tmp dir.
+        os.chdir(record_env_dir)
+        ls = glob.glob("*.mp4")
+        self.assertTrue(len(ls) == 1)
         # 10 steps for a complete episode.
-        for i in range(10):
+        for i in range(num_steps_per_episode):
             wrapped.step(0)
+        # Another episode.
+        wrapped.reset()
+        for i in range(num_steps_per_episode):
+            wrapped.step(0)
+        # Expect another video file to have been produced (2nd episode).
+        ls = glob.glob("*.mp4")
+        self.assertTrue(len(ls) == 2)
+
         # MockEnv2 returns a reward of 100.0 every step.
-        # So total reward is 1000.0.
-        self.assertEqual(wrapped.get_episode_rewards(), [1000.0])
+        # So total reward is 1000.0 per episode (10 steps).
+        check(
+            np.array([100.0, 100.0]) * num_steps_per_episode,
+            wrapped.get_episode_rewards())
+        # Erase all generated files and the temp path just in case,
+        # as to not disturb further CI-tests.
+        shutil.rmtree(record_env_dir)
 
     def test_wrap_multi_agent_env(self):
+        record_env_dir = os.popen("mktemp -d").read()[:-1]
+        print(f"tmp dir for videos={record_env_dir}")
+
+        if not os.path.exists(record_env_dir):
+            sys.exit(1)
+
         wrapped = record_env_wrapper(
             env=BasicMultiAgent(3),
-            record_env=tempfile.gettempdir(),
+            record_env=record_env_dir,
             log_dir="",
             policy_config={
                 "in_evaluation": False,
             })
         # Type is VideoMonitor.
-        self.assertTrue(isinstance(wrapped, wrappers.Monitor))
+        self.assertTrue(isinstance(wrapped, gym.wrappers.Monitor))
         self.assertTrue(isinstance(wrapped, VideoMonitor))
 
         wrapped.reset()
+
+        # BasicMultiAgent is hardcoded to run 25-step episodes.
         for i in range(25):
             wrapped.step({0: 0, 1: 0, 2: 0})
 
+        # Expect one video file to have been produced in the tmp dir.
+        os.chdir(record_env_dir)
+        ls = glob.glob("*.mp4")
+        self.assertTrue(len(ls) == 1)
+
         # However VideoMonitor's _after_step is overwritten to not
         # use stats_recorder. So nothing to verify here, except that
         # it runs fine.
14 changes: 8 additions & 6 deletions rllib/env/utils.py
@@ -1,3 +1,4 @@
+import gym
 from gym import wrappers
 import os
 
@@ -7,7 +8,7 @@
 from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError
 
 
-def gym_env_creator(env_context: EnvContext, env_descriptor: str):
+def gym_env_creator(env_context: EnvContext, env_descriptor: str) -> gym.Env:
     """Tries to create a gym env given an EnvContext object and descriptor.
 
     Note: This function tries to construct the env from a string descriptor
@@ -17,20 +18,19 @@ def gym_env_creator(env_context: EnvContext, env_descriptor: str):
     necessary imports and construction logic below.
 
     Args:
-        env_context (EnvContext): The env context object to configure the env.
+        env_context: The env context object to configure the env.
             Note that this is a config dict, plus the properties:
             `worker_index`, `vector_index`, and `remote`.
-        env_descriptor (str): The env descriptor, e.g. CartPole-v0,
+        env_descriptor: The env descriptor, e.g. CartPole-v0,
             MsPacmanNoFrameskip-v4, VizdoomBasic-v0, or
             CartPoleContinuousBulletEnv-v0.
 
     Returns:
-        gym.Env: The actual gym environment object.
+        The actual gym environment object.
 
     Raises:
         gym.error.Error: If the env cannot be constructed.
     """
-    import gym
 
     # Allow for PyBullet or VizdoomGym envs to be used as well
     # (via string). This allows for doing things like
     # `env=CartPoleContinuousBulletEnv-v0` or
@@ -85,7 +85,9 @@ def record_env_wrapper(env, record_env, log_dir, policy_config):
         print(f"Setting the path for recording to {path_}")
         wrapper_cls = VideoMonitor if isinstance(env, MultiAgentEnv) \
             else wrappers.Monitor
-        wrapper_cls = add_mixins(wrapper_cls, [MultiAgentEnv], reversed=True)
+        if isinstance(env, MultiAgentEnv):
+            wrapper_cls = add_mixins(
+                wrapper_cls, [MultiAgentEnv], reversed=True)
         env = wrapper_cls(
             env,
             path_,
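
record_env_wrapper is the function the new test exercises; rollout workers apply it to each sub-env when recording is requested. A rough sketch of how it is typically reached, assuming the record_env common-config key of this Ray release (a bool, or an output directory as here); the algorithm choice and path are illustrative, not part of the diff:

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(config={
    "env": "CartPole-v0",
    # Becomes the `record_env` argument of record_env_wrapper() above; each
    # sub-env is then wrapped in wrappers.Monitor (or VideoMonitor).
    "record_env": "/tmp/rllib-videos",
    "num_workers": 0,
})
trainer.train()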
2 changes: 1 addition & 1 deletion rllib/env/vector_env.py
@@ -138,7 +138,7 @@ def get_unwrapped(self) -> List[EnvType]:
     @PublicAPI
     def to_base_env(
         self,
-        make_env: Callable[[int], EnvType] = None,
+        make_env: Optional[Callable[[int], EnvType]] = None,
         num_envs: int = 1,
         remote_envs: bool = False,
         remote_env_batch_wait_ms: int = 0,
11 changes: 11 additions & 0 deletions rllib/examples/env/mock_env.py
@@ -1,4 +1,5 @@
 import gym
+import numpy as np
 
 from ray.rllib.env.vector_env import VectorEnv
 from ray.rllib.utils.annotations import override
@@ -34,6 +35,10 @@ class MockEnv2(gym.Env):
     configurable. Actions are ignored.
     """
 
+    metadata = {
+        "render.modes": ["rgb_array"],
+    }
+
     def __init__(self, episode_length):
         self.episode_length = episode_length
         self.i = 0
@@ -52,6 +57,12 @@ def step(self, action):
     def seed(self, rng_seed):
         self.rng_seed = rng_seed
 
+    def render(self, mode="rgb_array"):
+        # Just generate a random image here for demonstration purposes.
+        # Also see `gym/envs/classic_control/cartpole.py` for
+        # an example on how to use a Viewer object.
+        return np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)
+
 
 class MockEnv3(gym.Env):
     """Mock environment for testing purposes.
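
The metadata and render() additions are what make MockEnv2 recordable: gym's Monitor only records when metadata["render.modes"] lists "rgb_array", and it pulls frames by calling render(mode="rgb_array"). A minimal sketch wrapping the env directly; the output directory is illustrative:

import gym
from ray.rllib.examples.env.mock_env import MockEnv2

# Monitor records the first episode by default; frames from render() are
# encoded to .mp4 via ffmpeg (hence the new imageio-ffmpeg requirement).
env = gym.wrappers.Monitor(MockEnv2(10), "/tmp/mockenv-videos", force=True)
env.reset()
for _ in range(10):
    env.step(0)
env.close()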
11 changes: 11 additions & 0 deletions rllib/examples/env/multi_agent.py
@@ -1,4 +1,5 @@
 import gym
+import numpy as np
 import random
 
 from ray.rllib.env.multi_agent_env import MultiAgentEnv, make_multi_agent
@@ -18,6 +19,10 @@ def make_multiagent(env_name_or_creator):
 class BasicMultiAgent(MultiAgentEnv):
     """Env of N independent agents, each of which exits after 25 steps."""
 
+    metadata = {
+        "render.modes": ["rgb_array"],
+    }
+
     def __init__(self, num):
         super().__init__()
         self.agents = [MockEnv(25) for _ in range(num)]
@@ -40,6 +45,12 @@ def step(self, action_dict):
         done["__all__"] = len(self.dones) == len(self.agents)
         return obs, rew, done, info
 
+    def render(self, mode="rgb_array"):
+        # Just generate a random image here for demonstration purposes.
+        # Also see `gym/envs/classic_control/cartpole.py` for
+        # an example on how to use a Viewer object.
+        return np.random.randint(0, 256, size=(200, 300, 3), dtype=np.uint8)
+
 
 class EarlyDoneMultiAgent(MultiAgentEnv):
     """Env for testing when the env terminates (after agent 0 does)."""
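
Same recipe on the multi-agent side: with metadata and render() in place, record_env_wrapper can wrap BasicMultiAgent in VideoMonitor, mirroring the new test above. A minimal sketch; the output directory is illustrative:

from ray.rllib.env.utils import VideoMonitor, record_env_wrapper
from ray.rllib.examples.env.multi_agent import BasicMultiAgent

env = record_env_wrapper(
    env=BasicMultiAgent(2),
    record_env="/tmp/multiagent-videos",
    log_dir="",
    policy_config={"in_evaluation": False},
)
assert isinstance(env, VideoMonitor)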
