[RLlib] Issue 21629: Video recorder env wrapper not working. Added test case. #21670

Merged
1 change: 1 addition & 0 deletions python/requirements/ml/requirements_rllib.txt
@@ -26,6 +26,7 @@ tensorflow_estimator==2.6.0
 higher==0.2.1
 # For auto-generating an env-rendering Window.
 pyglet==1.5.15
+imageio-ffmpeg==0.4.5
 # For JSON reader/writer.
 smart_open==5.0.0
 # Ray Serve example
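Worth noting why this pin matters for the video-recorder fix: gym's video recorder shells out to an ffmpeg binary to encode `.mp4` files and, in recent gym versions, falls back to the binary bundled with `imageio-ffmpeg` when none is on the `PATH`. A minimal sanity check (not part of this PR; uses only the documented `get_ffmpeg_exe` API):

```python
# Minimal sketch: confirm an ffmpeg binary is available for video encoding.
# get_ffmpeg_exe() returns the path to the bundled executable and raises
# RuntimeError if no usable binary can be found.
import imageio_ffmpeg

print(imageio_ffmpeg.get_ffmpeg_exe())
```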
9 changes: 7 additions & 2 deletions rllib/env/base_env.py
@@ -79,7 +79,7 @@ class BaseEnv:

     def to_base_env(
         self,
-        make_env: Callable[[int], EnvType] = None,
+        make_env: Optional[Callable[[int], EnvType]] = None,
         num_envs: int = 1,
         remote_envs: bool = False,
         remote_env_batch_wait_ms: int = 0,
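The same `Optional[...]` fix recurs in `external_env.py`, `multi_agent_env.py`, and `vector_env.py` below. A short illustration of what it buys, as a sketch with a stand-in for RLlib's `EnvType` alias:

```python
from typing import Callable, Optional

EnvType = object  # hypothetical stand-in for RLlib's EnvType alias


# Without Optional, a strict PEP 484 checker (e.g. mypy with
# --no-implicit-optional) rejects the `= None` default as incompatible
# with Callable[[int], EnvType].
def to_base_env(
    make_env: Optional[Callable[[int], EnvType]] = None,
    num_envs: int = 1,
) -> None:
    """Signature sketch only; the real method lives on BaseEnv."""
    ...
```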
@@ -729,7 +729,12 @@ def convert_to_base_env(

     # Given `env` is already a BaseEnv -> Return as is.
     if isinstance(env, (BaseEnv, MultiAgentEnv, VectorEnv, ExternalEnv)):
-        return env.to_base_env()
+        return env.to_base_env(
+            make_env=make_env,
+            num_envs=num_envs,
+            remote_envs=remote_envs,
+            remote_env_batch_wait_ms=remote_env_batch_wait_ms,
+        )
     # `env` is not a BaseEnv yet -> Need to convert/vectorize.
     else:
         # Sub-environments are ray.remote actors:

Review comments on this hunk (both from Contributors):

> lol whoops my bad

> I wonder if this adds any speedup
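The forwarding above is the substantive fix in this file: previously, an env that was already a `BaseEnv`, `MultiAgentEnv`, `VectorEnv`, or `ExternalEnv` was converted via a bare `env.to_base_env()`, so every caller-supplied setting was silently dropped. A hedged usage sketch (argument names taken from this diff; behavior of the defaults assumed):

```python
from ray.rllib.env.base_env import convert_to_base_env
from ray.rllib.examples.env.multi_agent import BasicMultiAgent

# With the fix, these settings now reach BasicMultiAgent.to_base_env();
# before it, the call below behaved exactly like the defaults (num_envs=1).
base_env = convert_to_base_env(
    BasicMultiAgent(2),
    num_envs=4,
    remote_envs=False,
)
```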
2 changes: 1 addition & 1 deletion rllib/env/external_env.py
@@ -190,7 +190,7 @@ def _get(self, episode_id: str) -> "_ExternalEnvEpisode":

     def to_base_env(
         self,
-        make_env: Callable[[int], EnvType] = None,
+        make_env: Optional[Callable[[int], EnvType]] = None,
         num_envs: int = 1,
         remote_envs: bool = False,
         remote_env_batch_wait_ms: int = 0,
2 changes: 1 addition & 1 deletion rllib/env/multi_agent_env.py
@@ -263,7 +263,7 @@ def with_agent_groups(
     @PublicAPI
     def to_base_env(
         self,
-        make_env: Callable[[int], EnvType] = None,
+        make_env: Optional[Callable[[int], EnvType]] = None,
         num_envs: int = 1,
         remote_envs: bool = False,
         remote_env_batch_wait_ms: int = 0,
61 changes: 50 additions & 11 deletions rllib/env/tests/test_record_env_wrapper.py
@@ -1,51 +1,90 @@
-from gym import wrappers
-import tempfile
+import glob
+import gym
+import numpy as np
+import os
+import shutil
+import sys
 import unittest
 
 from ray.rllib.env.utils import VideoMonitor, record_env_wrapper
 from ray.rllib.examples.env.mock_env import MockEnv2
 from ray.rllib.examples.env.multi_agent import BasicMultiAgent
+from ray.rllib.utils.test_utils import check
 
 
 class TestRecordEnvWrapper(unittest.TestCase):
     def test_wrap_gym_env(self):
+        record_env_dir = os.popen("mktemp -d").read()[:-1]
+        print(f"tmp dir for videos={record_env_dir}")
+
+        if not os.path.exists(record_env_dir):
+            sys.exit(1)
+
+        num_steps_per_episode = 10
         wrapped = record_env_wrapper(
-            env=MockEnv2(10),
-            record_env=tempfile.gettempdir(),
+            env=MockEnv2(num_steps_per_episode),
+            record_env=record_env_dir,
             log_dir="",
             policy_config={
                 "in_evaluation": False,
             })
-        # Type is wrappers.Monitor.
-        self.assertTrue(isinstance(wrapped, wrappers.Monitor))
+        # Non MultiAgentEnv: Wrapper's type is wrappers.Monitor.
+        self.assertTrue(isinstance(wrapped, gym.wrappers.Monitor))
         self.assertFalse(isinstance(wrapped, VideoMonitor))
 
         wrapped.reset()
+        # Expect one video file to have been produced in the tmp dir.
+        os.chdir(record_env_dir)
+        ls = glob.glob("*.mp4")
+        self.assertTrue(len(ls) == 1)
         # 10 steps for a complete episode.
-        for i in range(10):
+        for i in range(num_steps_per_episode):
             wrapped.step(0)
+        # Another episode.
+        wrapped.reset()
+        for i in range(num_steps_per_episode):
+            wrapped.step(0)
+        # Expect another video file to have been produced (2nd episode).
+        ls = glob.glob("*.mp4")
+        self.assertTrue(len(ls) == 2)
 
         # MockEnv2 returns a reward of 100.0 every step.
-        # So total reward is 1000.0.
-        self.assertEqual(wrapped.get_episode_rewards(), [1000.0])
+        # So total reward is 1000.0 per episode (10 steps).
+        check(
+            np.array([100.0, 100.0]) * num_steps_per_episode,
+            wrapped.get_episode_rewards())
+        # Erase all generated files and the temp path just in case,
+        # so as not to disturb further CI tests.
+        shutil.rmtree(record_env_dir)
 
     def test_wrap_multi_agent_env(self):
+        record_env_dir = os.popen("mktemp -d").read()[:-1]
+        print(f"tmp dir for videos={record_env_dir}")
+
+        if not os.path.exists(record_env_dir):
+            sys.exit(1)
+
         wrapped = record_env_wrapper(
             env=BasicMultiAgent(3),
-            record_env=tempfile.gettempdir(),
+            record_env=record_env_dir,
             log_dir="",
             policy_config={
                 "in_evaluation": False,
             })
         # Type is VideoMonitor.
-        self.assertTrue(isinstance(wrapped, wrappers.Monitor))
+        self.assertTrue(isinstance(wrapped, gym.wrappers.Monitor))
         self.assertTrue(isinstance(wrapped, VideoMonitor))
 
         wrapped.reset()
 
         # BasicMultiAgent is hardcoded to run 25-step episodes.
         for i in range(25):
             wrapped.step({0: 0, 1: 0, 2: 0})
 
+        # Expect one video file to have been produced in the tmp dir.
+        os.chdir(record_env_dir)
+        ls = glob.glob("*.mp4")
+        self.assertTrue(len(ls) == 1)
+
         # However VideoMonitor's _after_step is overwritten to not
         # use stats_recorder. So nothing to verify here, except that
         # it runs fine.
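If this file follows the pattern of RLlib's other test modules, it presumably ends with the usual pytest entry point (elided from the diff above); something like:

```python
# Assumed test-module tail, matching the convention in other RLlib tests.
if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))
```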
14 changes: 8 additions & 6 deletions rllib/env/utils.py
@@ -1,3 +1,4 @@
+import gym
 from gym import wrappers
 import os
 
@@ -7,7 +8,7 @@
 from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError
 
 
-def gym_env_creator(env_context: EnvContext, env_descriptor: str):
+def gym_env_creator(env_context: EnvContext, env_descriptor: str) -> gym.Env:
     """Tries to create a gym env given an EnvContext object and descriptor.
 
     Note: This function tries to construct the env from a string descriptor
@@ -17,20 +18,19 @@ def gym_env_creator(env_context: EnvContext, env_descriptor: str):
     necessary imports and construction logic below.
 
     Args:
-        env_context (EnvContext): The env context object to configure the env.
+        env_context: The env context object to configure the env.
             Note that this is a config dict, plus the properties:
             `worker_index`, `vector_index`, and `remote`.
-        env_descriptor (str): The env descriptor, e.g. CartPole-v0,
+        env_descriptor: The env descriptor, e.g. CartPole-v0,
             MsPacmanNoFrameskip-v4, VizdoomBasic-v0, or
             CartPoleContinuousBulletEnv-v0.
 
     Returns:
-        gym.Env: The actual gym environment object.
+        The actual gym environment object.
 
     Raises:
         gym.error.Error: If the env cannot be constructed.
     """
-    import gym
     # Allow for PyBullet or VizdoomGym envs to be used as well
     # (via string). This allows for doing things like
     # `env=CartPoleContinuousBulletEnv-v0` or
@@ -85,7 +85,9 @@ def record_env_wrapper(env, record_env, log_dir, policy_config):
     print(f"Setting the path for recording to {path_}")
     wrapper_cls = VideoMonitor if isinstance(env, MultiAgentEnv) \
         else wrappers.Monitor
-    wrapper_cls = add_mixins(wrapper_cls, [MultiAgentEnv], reversed=True)
+    if isinstance(env, MultiAgentEnv):
+        wrapper_cls = add_mixins(
+            wrapper_cls, [MultiAgentEnv], reversed=True)
     env = wrapper_cls(
         env,
         path_,
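This conditional is plausibly the core of the issue #21629 fix: unconditionally mixing `MultiAgentEnv` into the wrapper class made every recorded env, single-agent ones included, pass `isinstance(..., MultiAgentEnv)` checks downstream. A sketch of the behavior the new code guarantees (output path hypothetical):

```python
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.utils import record_env_wrapper
from ray.rllib.examples.env.mock_env import MockEnv2

# Wrap a plain (single-agent) gym env for video recording.
wrapped = record_env_wrapper(
    env=MockEnv2(10),
    record_env="/tmp/rllib_videos",  # hypothetical output dir
    log_dir="",
    policy_config={"in_evaluation": False},
)
# A plain gym env must NOT be detected as multi-agent after wrapping;
# with the old unconditional add_mixins() call, this assertion failed.
assert not isinstance(wrapped, MultiAgentEnv)
```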
2 changes: 1 addition & 1 deletion rllib/env/vector_env.py
@@ -138,7 +138,7 @@ def get_unwrapped(self) -> List[EnvType]:
     @PublicAPI
     def to_base_env(
         self,
-        make_env: Callable[[int], EnvType] = None,
+        make_env: Optional[Callable[[int], EnvType]] = None,
         num_envs: int = 1,
         remote_envs: bool = False,
         remote_env_batch_wait_ms: int = 0,
11 changes: 11 additions & 0 deletions rllib/examples/env/mock_env.py
@@ -1,4 +1,5 @@
 import gym
+import numpy as np
 
 from ray.rllib.env.vector_env import VectorEnv
 from ray.rllib.utils.annotations import override
@@ -34,6 +35,10 @@ class MockEnv2(gym.Env):
     configurable. Actions are ignored.
     """
 
+    metadata = {
+        "render.modes": ["rgb_array"],
+    }
+
     def __init__(self, episode_length):
         self.episode_length = episode_length
         self.i = 0
@@ -52,6 +57,12 @@ def step(self, action):
     def seed(self, rng_seed):
         self.rng_seed = rng_seed
 
+    def render(self, mode="rgb_array"):
+        # Just generate a random image here for demonstration purposes.
+        # Also see `gym/envs/classic_control/cartpole.py` for
+        # an example on how to use a Viewer object.
+        return np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)
+
 
 class MockEnv3(gym.Env):
     """Mock environment for testing purposes.
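A quick sketch of what the new `render()` provides (shape and dtype as written in the diff):

```python
import numpy as np

from ray.rllib.examples.env.mock_env import MockEnv2

env = MockEnv2(10)
frame = env.render(mode="rgb_array")
# A 300x400 RGB frame, as produced by the np.random.randint call above.
assert frame.shape == (300, 400, 3)
assert frame.dtype == np.uint8
```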
11 changes: 11 additions & 0 deletions rllib/examples/env/multi_agent.py
@@ -1,4 +1,5 @@
 import gym
+import numpy as np
 import random
 
 from ray.rllib.env.multi_agent_env import MultiAgentEnv, make_multi_agent
@@ -18,6 +19,10 @@ def make_multiagent(env_name_or_creator):
 class BasicMultiAgent(MultiAgentEnv):
     """Env of N independent agents, each of which exits after 25 steps."""
 
+    metadata = {
+        "render.modes": ["rgb_array"],
+    }
+
     def __init__(self, num):
         super().__init__()
         self.agents = [MockEnv(25) for _ in range(num)]
@@ -40,6 +45,12 @@ def step(self, action_dict):
         done["__all__"] = len(self.dones) == len(self.agents)
         return obs, rew, done, info
 
+    def render(self, mode="rgb_array"):
+        # Just generate a random image here for demonstration purposes.
+        # Also see `gym/envs/classic_control/cartpole.py` for
+        # an example on how to use a Viewer object.
+        return np.random.randint(0, 256, size=(200, 300, 3), dtype=np.uint8)
+
 
 class EarlyDoneMultiAgent(MultiAgentEnv):
     """Env for testing when the env terminates (after agent 0 does)."""
Expand Down