-
Notifications
You must be signed in to change notification settings - Fork 5.8k
/
utils.py
101 lines (84 loc) · 3.57 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gym
from gym import wrappers
import os
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils import add_mixins
from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError
def gym_env_creator(env_context: EnvContext, env_descriptor: str) -> gym.Env:
    """Tries to create a gym env given an EnvContext object and descriptor.

    Note: This function tries to construct the env from a string descriptor
    only using possibly installed RL env packages (such as gym, pybullet_envs,
    vizdoomgym, etc..). These packages are no installation requirements for
    RLlib. In case you would like to support more such env packages, add the
    necessary imports and construction logic below.

    Args:
        env_context: The env context object to configure the env.
            Note that this is a config dict, plus the properties:
            `worker_index`, `vector_index`, and `remote`.
        env_descriptor: The env descriptor, e.g. CartPole-v0,
            MsPacmanNoFrameskip-v4, VizdoomBasic-v0, or
            CartPoleContinuousBulletEnv-v0.

    Returns:
        The actual gym environment object.

    Raises:
        EnvError: If the env cannot be constructed from the descriptor.
    """
    # Allow for PyBullet or VizdoomGym envs to be used as well
    # (via string). This allows for doing things like
    # `env=CartPoleContinuousBulletEnv-v0` or
    # `env=VizdoomBasic-v0`.
    # Importing these packages registers their env IDs with gym as a
    # side effect; failure to import simply means fewer registered envs.
    try:
        import pybullet_envs

        pybullet_envs.getList()
    except (ModuleNotFoundError, ImportError):
        pass
    try:
        import vizdoomgym

        vizdoomgym.__name__  # trick LINTer.
    except (ModuleNotFoundError, ImportError):
        pass

    # Try creating a gym env. If this fails we can output a
    # decent error message.
    try:
        return gym.make(env_descriptor, **env_context)
    except gym.error.Error as e:
        # Chain the original gym error so the root cause stays visible
        # in the traceback instead of being silently discarded.
        raise EnvError(
            ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_descriptor)
        ) from e
class VideoMonitor(wrappers.Monitor):
    """A gym Monitor variant for multi-agent envs.

    Identical to ``wrappers.Monitor`` except that ``_after_step`` skips the
    StatsRecorder, which would otherwise try to add up multi-agent reward
    dicts and throw errors.
    """

    def _after_step(self, observation, reward, done, info):
        # Recording disabled -> nothing to do.
        if not self.enabled:
            return done

        # `done` is a multi-agent dict; the special "__all__" key signals
        # that the whole episode is over.
        episode_over = done["__all__"]
        if episode_over and self.env_semantics_autoreset:
            # For envs with BlockingReset wrapping VNCEnv, this observation
            # will be the first one of the new episode
            self.reset_video_recorder()
            self.episode_id += 1
            self._flush()

        # Record video
        self.video_recorder.capture_frame()

        return done
def record_env_wrapper(env, record_env, log_dir, policy_config):
    """Wraps an env in a video-recording Monitor, if requested.

    Args:
        env: The environment to (maybe) wrap.
        record_env: If falsy, no wrapping happens. If a string, it is used
            as the recording output path; any other truthy value means
            "record into `log_dir`".
        log_dir: Base directory used both as the fallback output path and
            as the prefix for relative paths.
        policy_config: Config dict; its "in_evaluation" key selects the
            Monitor's mode.

    Returns:
        The (possibly wrapped) environment.
    """
    # No recording requested -> hand the env back untouched.
    if not record_env:
        return env

    path_ = record_env if isinstance(record_env, str) else log_dir
    # Relative path: Add logdir here, otherwise, this would
    # not work for non-local workers.
    if not os.path.isabs(path_):
        path_ = os.path.join(log_dir, path_)
    print(f"Setting the path for recording to {path_}")

    # Multi-agent envs need the VideoMonitor (no StatsRecorder) plus the
    # MultiAgentEnv mixin so the wrapper still looks like a MultiAgentEnv.
    multi_agent = isinstance(env, MultiAgentEnv)
    wrapper_cls = VideoMonitor if multi_agent else wrappers.Monitor
    if multi_agent:
        wrapper_cls = add_mixins(wrapper_cls, [MultiAgentEnv], reversed=True)

    mode = "evaluation" if policy_config["in_evaluation"] else "training"
    return wrapper_cls(
        env,
        path_,
        resume=True,
        force=True,
        video_callable=lambda _: True,
        mode=mode,
    )