
Commit 79a9d6d

[RLlib] Issues 16287 and 16200: RLlib not rendering custom multi-agent Envs. (#16428)

sven1977 authored Jun 19, 2021
1 parent 853caea commit 79a9d6d
Showing 6 changed files with 99 additions and 46 deletions.
10 changes: 4 additions & 6 deletions rllib/env/base_env.py
@@ -453,12 +453,10 @@ def get_unwrapped(self) -> List[EnvType]:

@override(BaseEnv)
def try_render(self, env_id: Optional[EnvID] = None) -> None:
-        if env_id is not None:
-            assert isinstance(env_id, int)
-            self.envs[env_id].render()
-        else:
-            for e in self.envs:
-                e.render()
+        if env_id is None:
+            env_id = 0
+        assert isinstance(env_id, int)
+        return self.envs[env_id].render()


class _MultiAgentEnvState:
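Net effect: `try_render()` no longer loops over all sub-envs and discards the frames; it renders exactly one sub-env (index 0 unless an `env_id` is given) and hands that sub-env's `render()` return value back to the caller. A self-contained sketch of the new contract (both classes here are stand-ins for illustration, not RLlib APIs):

import numpy as np


class FakeSubEnv:
    # Stand-in sub-env whose render() returns an RGB frame.
    def render(self):
        return np.zeros((300, 400, 3), dtype=np.uint8)


class FakeBaseEnv:
    # Mirrors the patched try_render() above.
    def __init__(self, envs):
        self.envs = envs

    def try_render(self, env_id=None):
        if env_id is None:
            env_id = 0  # default to the first sub-env
        return self.envs[env_id].render()


frame = FakeBaseEnv([FakeSubEnv(), FakeSubEnv()]).try_render()
assert frame.shape == (300, 400, 3)  # caller now gets the frame back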
2 changes: 1 addition & 1 deletion rllib/env/multi_agent_env.py
@@ -192,6 +192,6 @@ def step(self, action_dict):

@override(MultiAgentEnv)
def render(self, mode=None):
-            return [a.render(mode) for a in self.agents]
+            return self.agents[0].render(mode)

return MultiEnv
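The auto-generated MultiEnv previously returned a list of per-agent frames, which downstream rendering code could not display; it now returns a single frame (agent 0's). A quick sketch of that behavior, assuming gym's CartPole-v0 is installed and a display is available for its renderer:

import gym
from ray.rllib.env.multi_agent_env import make_multi_agent

# Same pattern as in the example file changed further below.
MultiCartPole = make_multi_agent(lambda config: gym.make("CartPole-v0"))
env = MultiCartPole({"num_agents": 2})
env.reset()
# A single frame (agent 0's), not a list of frames:
frame = env.render(mode="rgb_array")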
47 changes: 47 additions & 0 deletions rllib/env/utils.py
@@ -1,3 +1,7 @@
+from gym import wrappers
+import os
+import re

from ray.rllib.env.env_context import EnvContext


@@ -64,3 +68,46 @@ def gym_env_creator(env_context: EnvContext, env_descriptor: str):
`ray.rllib.examples.env.repeat_after_me_env.RepeatAfterMeEnv`
"""
raise gym.error.Error(error_msg)


+class VideoMonitor(wrappers.Monitor):
+    # Same as original method, but doesn't use the StatsRecorder as it will
+    # try to add up multi-agent rewards dicts, which throws errors.
+    def _after_step(self, observation, reward, done, info):
+        if not self.enabled:
+            return done
+
+        # Use done["__all__"] b/c this is a multi-agent dict.
+        if done["__all__"] and self.env_semantics_autoreset:
+            # For envs with BlockingReset wrapping VNCEnv, this observation
+            # will be the first one of the new episode
+            self.reset_video_recorder()
+            self.episode_id += 1
+            self._flush()
+
+        # Record video
+        self.video_recorder.capture_frame()
+
+        return done
+
+
+def record_env_wrapper(env, record_env, log_dir, policy_config):
+    if record_env:
+        path_ = record_env if isinstance(record_env, str) else log_dir
+        # Relative path: Add logdir here, otherwise, this would
+        # not work for non-local workers.
+        if not re.search("[/\\\]", path_):
+            path_ = os.path.join(log_dir, path_)
+        print(f"Setting the path for recording to {path_}")
+        from ray.rllib.env.multi_agent_env import MultiAgentEnv
+        wrapper_cls = VideoMonitor if isinstance(env, MultiAgentEnv) \
+            else wrappers.Monitor
+        env = wrapper_cls(
+            env,
+            path_,
+            resume=True,
+            force=True,
+            video_callable=lambda _: True,
+            mode="evaluation"
+            if policy_config["in_evaluation"] else "training")
+    return env
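`record_env_wrapper()` centralizes the Monitor setup that used to be duplicated inside `rollout_worker.py`, and picks `VideoMonitor` for multi-agent envs because gym's StatsRecorder would otherwise try to sum up multi-agent reward dicts. A hedged usage sketch (the paths are made up; `record_env` may also simply be `True` to record into `log_dir`):

import gym
from ray.rllib.env.utils import record_env_wrapper

env = record_env_wrapper(
    env=gym.make("CartPole-v0"),
    record_env="videos",  # relative -> joined onto log_dir
    log_dir="/tmp/rllib_logs",  # hypothetical worker log dir
    policy_config={"in_evaluation": False})  # -> mode="training"
# Single-agent envs get gym.wrappers.Monitor; MultiAgentEnv instances
# get the VideoMonitor subclass defined above.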
71 changes: 33 additions & 38 deletions rllib/evaluation/rollout_worker.py
@@ -5,7 +5,6 @@
import pickle
import platform
import os
-import re
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, \
TYPE_CHECKING, Union

@@ -15,6 +14,7 @@
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
+from ray.rllib.env.utils import record_env_wrapper
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind, is_atari
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
@@ -382,14 +382,38 @@ def gen_rollouts():
# Create an env for this worker.
else:
self.env = _validate_env(env_creator(env_context))
# Validate environment, if validation function provided.
if validate_env is not None:
validate_env(self.env, self.env_context)

-            if isinstance(self.env, (BaseEnv, MultiAgentEnv)):
+            # MultiAgentEnv (a gym.Env) -> Wrap and make
+            # the wrapped Env yet another MultiAgentEnv.
+            if isinstance(self.env, MultiAgentEnv):

                def wrap(env):
-                    return env  # we can't auto-wrap these env types
+                    cls = env.__class__
+                    # Add gym.Env as mixin parent to the env's class
+                    # (so it can be wrapped).
+                    env.__class__ = \
+                        type(env.__class__.__name__, (type(env), gym.Env), {})
+                    # Wrap the (now gym.Env) env with our (multi-agent capable)
+                    # recording wrapper.
+                    env = record_env_wrapper(env, record_env, log_dir,
+                                             policy_config)
+                    # Make sure, we make the wrapped object a member of the
+                    # original MultiAgentEnv sub-class again.
+                    if type(env) is not cls:
+                        env.__class__ = \
+                            type(cls.__name__, (type(env), cls), {})
+                    return env
+
+            # We can't auto-wrap a BaseEnv.
+            elif isinstance(self.env, BaseEnv):
+
+                def wrap(env):
+                    return env
+
+            # Atari type env and "deepmind" preprocessor pref.
elif is_atari(self.env) and \
not model_config.get("custom_preprocessor") and \
preprocessor_pref == "deepmind":
@@ -422,45 +446,16 @@ def wrap(env):
dim=model_config.get("dim"),
framestack=framestack,
framestack_via_traj_view_api=framestack_traj_view)
-                    if record_env:
-                        from gym import wrappers
-                        path_ = record_env if isinstance(record_env,
-                                                         str) else log_dir
-                        # Relative path: Add logdir here, otherwise, this would
-                        # not work for non-local workers.
-                        if not re.search("[/\\\]", path_):
-                            path_ = os.path.join(log_dir, path_)
-                        print(f"Setting the path for recording to {path_}")
-                        env = wrappers.Monitor(
-                            env,
-                            path_,
-                            resume=True,
-                            force=True,
-                            video_callable=lambda _: True,
-                            mode="evaluation"
-                            if policy_config["in_evaluation"] else "training")
+                    env = record_env_wrapper(env, record_env, log_dir,
+                                             policy_config)
return env

# gym.Env -> Wrap with gym Monitor.
else:

def wrap(env):
-                    if record_env:
-                        from gym import wrappers
-                        path_ = record_env if isinstance(record_env,
-                                                         str) else log_dir
-                        # Relative path: Add logdir here, otherwise, this would
-                        # not work for non-local workers.
-                        if not re.search("[/\\\]", path_):
-                            path_ = os.path.join(log_dir, path_)
-                        print(f"Setting the path for recording to {path_}")
-                        env = wrappers.Monitor(
-                            env,
-                            path_,
-                            resume=True,
-                            force=True,
-                            video_callable=lambda _: True,
-                            mode="evaluation"
-                            if policy_config["in_evaluation"] else "training")
-                    return env
+                    return record_env_wrapper(env, record_env, log_dir,
+                                              policy_config)

self.env: EnvType = wrap(self.env)

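The interesting piece above is the double re-typing: gym's Monitor insists on wrapping a gym.Env, but a MultiAgentEnv is not one, so `wrap()` grafts gym.Env onto the instance's class, wraps it, then grafts the original MultiAgentEnv subclass back onto the wrapper. The trick in isolation (`MyMAEnv` is a stand-in, not an RLlib class):

import gym


class MyMAEnv:
    # Stand-in for a MultiAgentEnv subclass that is not a gym.Env.
    pass


env = MyMAEnv()
cls = env.__class__
# Graft gym.Env onto the instance's class so gym wrappers accept it.
env.__class__ = type(cls.__name__, (cls, gym.Env), {})
assert isinstance(env, gym.Env) and isinstance(env, MyMAEnv)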
6 changes: 6 additions & 0 deletions rllib/evaluation/sampler.py
@@ -672,6 +672,12 @@ def new_episode(env_id):
"rendering! Try `pip install gym[all]`.")
if simple_image_viewer:
simple_image_viewer.imshow(rendered)
+            elif rendered not in [True, False, None]:
+                raise ValueError(
+                    f"The env's ({base_env}) `try_render()` method returned an"
+                    " unsupported value! Make sure you either return a "
+                    "uint8/w x h x 3 (RGB) image or handle rendering in a "
+                    "window and then return `True`.")
perf_stats.env_render_time += time.time() - t5


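This guard spells out the rendering contract `try_render()` must satisfy: return a uint8 w x h x 3 (RGB) frame for RLlib's SimpleImageViewer, `True` if the env rendered into its own window, or `False`/`None` to skip the frame. A compliant `render()` might look like this (illustrative only, not RLlib code):

import numpy as np


class HumanOrRGBEnv:
    # Illustration of the contract enforced by the check above.
    def render(self, mode=None):
        if mode == "rgb_array":
            # uint8 RGB image -> displayed via SimpleImageViewer.
            return np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8)
        # We rendered into our own window -> signal "handled".
        return True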
9 changes: 8 additions & 1 deletion rllib/examples/env_rendering_and_recording.py
@@ -12,14 +12,17 @@
import numpy as np
import ray
from gym.spaces import Box, Discrete

from ray import tune
+from ray.rllib.env.multi_agent_env import make_multi_agent

parser = argparse.ArgumentParser()
parser.add_argument(
"--framework",
choices=["tf", "tf2", "tfe", "torch"],
default="tf",
help="The DL framework specifier.")
parser.add_argument("--multi-agent", action="store_true")
parser.add_argument("--stop-iters", type=int, default=10)
parser.add_argument("--stop-timesteps", type=int, default=10000)
parser.add_argument("--stop-reward", type=float, default=9.0)
@@ -86,6 +89,9 @@ def render(self, mode="rgb"):
return np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)


+MultiAgentCustomRenderedEnv = make_multi_agent(
+    lambda config: CustomRenderedEnv(config))

if __name__ == "__main__":
# Note: Recording and rendering in this example
# should work for both local_mode=True|False.
@@ -95,7 +101,8 @@ def render(self, mode="rgb"):
# Example config causing
config = {
# Also try common gym envs like: "CartPole-v0" or "Pendulum-v0".
"env": CustomRenderedEnv,
"env": (MultiAgentCustomRenderedEnv
if args.multi_agent else CustomRenderedEnv),
"env_config": {
"corridor_length": 10,
"max_steps": 100,
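With the new flag, the example exercises the multi-agent rendering and recording paths added above: running `python rllib/examples/env_rendering_and_recording.py --multi-agent` switches the config's "env" entry to MultiAgentCustomRenderedEnv, while omitting the flag keeps the original single-agent CustomRenderedEnv behavior.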
