Commit

Update render interface
takuseno committed Jul 21, 2023
1 parent c27f0a4 commit a7207cd
Showing 5 changed files with 12 additions and 140 deletions.
27 changes: 9 additions & 18 deletions d3rlpy/cli.py
@@ -9,6 +9,7 @@
import click
import gym
import numpy as np
from gym.wrappers import RecordVideo

from ._version import __version__
from .algos import (
@@ -17,7 +18,6 @@
TransformerAlgoBase,
)
from .base import load_learnable
from .envs import Monitor
from .metrics.utility import (
evaluate_qlearning_with_environment,
evaluate_transformer_with_environment,
@@ -249,8 +249,6 @@ def _exec_to_create_env(code: str) -> gym.Env[Any, Any]:
@click.option(
"--n-episodes", default=3, help="the number of episodes to record."
)
@click.option("--frame-rate", default=60, help="video frame rate.")
@click.option("--record-rate", default=1, help="record frame rate.")
@click.option(
"--target-return",
default=None,
@@ -262,8 +260,6 @@ def record(
env_header: Optional[str],
out: str,
n_episodes: int,
frame_rate: float,
record_rate: int,
target_return: Optional[float],
) -> None:
# load saved model
@@ -273,32 +269,27 @@ def record(
# wrap environment with Monitor
env: gym.Env[Any, Any]
if env_id is not None:
env = gym.make(env_id)
env = gym.make(env_id, render_mode="rgb_array")
elif env_header is not None:
env = _exec_to_create_env(env_header)
else:
raise ValueError("env_id or env_header must be provided.")

wrapped_env = Monitor(
wrapped_env = RecordVideo(
env,
out,
video_callable=lambda ep: ep % 1 == 0,
frame_rate=float(frame_rate),
record_rate=int(record_rate),
episode_trigger=lambda ep: True,
)

# run episodes
if isinstance(algo, QLearningAlgoBase):
evaluate_qlearning_with_environment(
algo, wrapped_env, n_episodes, render=True
)
evaluate_qlearning_with_environment(algo, wrapped_env, n_episodes)
elif isinstance(algo, TransformerAlgoBase):
assert target_return is not None, "--target-return must be specified."
evaluate_transformer_with_environment(
StatefulTransformerWrapper(algo, float(target_return)),
wrapped_env,
n_episodes,
render=True,
)
else:
raise ValueError("invalid algo type.")
@@ -330,22 +321,22 @@ def play(
# wrap environment with Monitor
env: gym.Env[Any, Any]
if env_id is not None:
env = gym.make(env_id)
env = gym.make(env_id, render_mode="human")
elif env_header is not None:
env = _exec_to_create_env(env_header)
else:
raise ValueError("env_id or env_header must be provided.")

# run episodes
if isinstance(algo, QLearningAlgoBase):
evaluate_qlearning_with_environment(algo, env, n_episodes, render=True)
score = evaluate_qlearning_with_environment(algo, env, n_episodes)
elif isinstance(algo, TransformerAlgoBase):
assert target_return is not None, "--target-return must be specified."
evaluate_transformer_with_environment(
score = evaluate_transformer_with_environment(
StatefulTransformerWrapper(algo, float(target_return)),
env,
n_episodes,
render=True,
)
else:
raise ValueError("invalid algo type.")
print(f"Score: {score}")
102 changes: 1 addition & 101 deletions d3rlpy/envs/wrappers.py
@@ -1,7 +1,5 @@
import json
import os
from collections import deque
from typing import Any, Callable, Deque, Dict, List, Optional, Tuple, TypeVar
from typing import Any, Deque, Dict, Optional, Tuple, TypeVar

import gym
import numpy as np
@@ -19,7 +17,6 @@
"FrameStack",
"AtariPreprocessing",
"Atari",
"Monitor",
]

_ObsType = TypeVar("_ObsType")
@@ -338,100 +335,3 @@ def __init__(
else:
env = ChannelFirst(env)
super().__init__(env)


class Monitor(gym.Wrapper[_ObsType, _ActType]):
"""gym.wrappers.Monitor-style Monitor wrapper.
Args:
env (gym.Env): gym environment.
directory (str): directory to save.
video_callable (callable): callable function that takes episode counter
to control record frequency.
force (bool): flag to allow existing directory.
frame_rate (float): video frame rate.
record_rate (int): images are record every ``record_rate`` frames.
"""

_directory: str
_video_callable: Callable[[int], bool]
_frame_rate: float
_record_rate: int
_episode: int
_episode_return: float
_episode_step: int
_buffer: List[np.ndarray]

def __init__(
self,
env: gym.Env[_ObsType, _ActType],
directory: str,
video_callable: Optional[Callable[[int], bool]] = None,
force: bool = False,
frame_rate: float = 30.0,
record_rate: int = 1,
):
super().__init__(env)
# prepare directory
if os.path.exists(directory) and not force:
raise ValueError(f"{directory} already exists.")
os.makedirs(directory, exist_ok=True)
self._directory = directory

if video_callable:
self._video_callable = video_callable
else:
self._video_callable = lambda ep: ep % 10 == 0

self._frame_rate = frame_rate
self._record_rate = record_rate

self._episode = 0
self._episode_return = 0.0
self._episode_step = 0
self._buffer = []

def step(
self, action: _ActType
) -> Tuple[_ObsType, float, bool, bool, Dict[str, Any]]:
obs, reward, done, truncated, info = super().step(action)

if self._video_callable(self._episode):
# store rendering
frame = cv2.cvtColor(super().render("rgb_array"), cv2.COLOR_BGR2RGB)
self._buffer.append(frame)
self._episode_step += 1
self._episode_return += reward
if done:
self._save_video()
self._save_stats()

return obs, reward, done, truncated, info

def reset(self, **kwargs: Any) -> Tuple[_ObsType, Dict[str, Any]]:
self._episode += 1
self._episode_return = 0.0
self._episode_step = 0
self._buffer = []
return super().reset(**kwargs)

def _save_video(self) -> None:
height, width = self._buffer[0].shape[:2]
path = os.path.join(self._directory, f"video{self._episode}.avi")
fmt = cv2.VideoWriter_fourcc(*"MJPG")
writer = cv2.VideoWriter(path, fmt, self._frame_rate, (width, height))
print(f"Saving a recorded video to {path}...")
for i, frame in enumerate(self._buffer):
if i % self._record_rate == 0:
writer.write(frame)
writer.release()

def _save_stats(self) -> None:
path = os.path.join(self._directory, f"stats{self._episode}.json")
stats = {
"episode_step": self._episode_step,
"return": self._episode_return,
}
with open(path, "w") as f:
json_str = json.dumps(stats, indent=2)
f.write(json_str)
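The removed Monitor wrapper did two jobs: writing videos (with frame_rate and record_rate) and dumping per-episode stats to JSON. Both are roughly covered by standard Gym wrappers; a sketch under that assumption (env id and output directory are placeholders; RecordVideo takes its frame rate from the environment's render_fps metadata rather than a frame_rate argument):

# A sketch, not the removed implementation: similar behavior with Gym's own wrappers.
import gym
from gym.wrappers import RecordEpisodeStatistics, RecordVideo

env = gym.make("CartPole-v1", render_mode="rgb_array")
# video recording; episode_trigger mirrors Monitor's old default of every 10th episode
env = RecordVideo(env, "videos", episode_trigger=lambda ep: ep % 10 == 0)
# per-episode return/length, replacing Monitor._save_stats
env = RecordEpisodeStatistics(env)

observation, _ = env.reset()
while True:
    observation, reward, done, truncated, info = env.step(env.action_space.sample())
    if done or truncated:
        print(info["episode"])  # {"r": episode return, "l": episode length, "t": elapsed time}
        break
env.close()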
5 changes: 0 additions & 5 deletions d3rlpy/metrics/evaluators.py
@@ -516,24 +516,20 @@ class EnvironmentEvaluator(EvaluatorProtocol):
env: Gym environment.
n_trials: Number of episodes to evaluate.
epsilon: Probability of random action.
render: Flag to turn on rendering.
"""
_env: gym.Env[Any, Any]
_n_trials: int
_epsilon: float
_render: bool

def __init__(
self,
env: gym.Env[Any, Any],
n_trials: int = 10,
epsilon: float = 0.0,
render: bool = False,
):
self._env = env
self._n_trials = n_trials
self._epsilon = epsilon
self._render = render

def __call__(
self, algo: QLearningAlgoProtocol, dataset: ReplayBuffer
@@ -543,5 +539,4 @@ def __call__(
env=self._env,
n_trials=self._n_trials,
epsilon=self._epsilon,
render=self._render,
)
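With the render flag removed from EnvironmentEvaluator, watching an evaluation is now a property of the environment itself. A short sketch of the updated usage (assuming d3rlpy's documented d3rlpy.metrics import path; the environment id is a placeholder):

# A sketch of the updated evaluator usage.
import gym
from d3rlpy.metrics import EnvironmentEvaluator

env = gym.make("CartPole-v1", render_mode="human")  # replaces render=True
env_evaluator = EnvironmentEvaluator(env, n_trials=10, epsilon=0.0)
# typically passed to training via the evaluators argument, e.g.
# algo.fit(dataset, n_steps=100000, evaluators={"environment": env_evaluator})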
10 changes: 0 additions & 10 deletions d3rlpy/metrics/utility.py
@@ -16,7 +16,6 @@ def evaluate_qlearning_with_environment(
env: gym.Env[Any, Any],
n_trials: int = 10,
epsilon: float = 0.0,
render: bool = False,
) -> float:
"""Returns average environment score.
@@ -39,7 +38,6 @@ def evaluate_qlearning_with_environment(
env: gym-styled environment.
n_trials: the number of trials.
epsilon: noise factor for epsilon-greedy policy.
render: flag to render environment.
Returns:
average score.
@@ -59,9 +57,6 @@ def evaluate_qlearning_with_environment(
observation, reward, done, truncated, _ = env.step(action)
episode_reward += reward

if render:
env.render()

if done or truncated:
break
episode_rewards.append(episode_reward)
@@ -72,7 +67,6 @@ def evaluate_transformer_with_environment(
algo: StatefulTransformerAlgoProtocol,
env: gym.Env[Any, Any],
n_trials: int = 10,
render: bool = False,
) -> float:
"""Returns average environment score.
@@ -94,7 +88,6 @@ def evaluate_transformer_with_environment(
alg: algorithm object.
env: gym-styled environment.
n_trials: the number of trials.
render: flag to render environment.
Returns:
average score.
@@ -112,9 +105,6 @@ def evaluate_transformer_with_environment(
observation, reward, done, truncated, _ = env.step(action)
episode_reward += reward

if render:
env.render()

if done or truncated:
break
episode_rewards.append(episode_reward)
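Dropping the render flag leaves these helpers as plain scoring loops. A condensed sketch of what evaluate_qlearning_with_environment does after this change (simplified, not the library's exact code; algo is any object with a batched predict method):

# Simplified sketch of the epsilon-greedy evaluation loop after this change.
import numpy as np

def rollout_score(algo, env, n_trials=10, epsilon=0.0):
    episode_rewards = []
    for _ in range(n_trials):
        observation, _ = env.reset()
        episode_reward = 0.0
        while True:
            if np.random.random() < epsilon:
                action = env.action_space.sample()
            else:
                action = algo.predict(np.expand_dims(observation, axis=0))[0]
            observation, reward, done, truncated, _ = env.step(action)
            episode_reward += float(reward)
            if done or truncated:
                break
        episode_rewards.append(episode_reward)
    return float(np.mean(episode_rewards))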
8 changes: 2 additions & 6 deletions docs/cli.rst
@@ -83,10 +83,6 @@ Record evaluation episodes as videos with the saved model::
- Output directory.
* - ``--n-episodes``
- The number of episodes to record.
* - ``--frame-rate``
- Video frame rate.
* - ``--record-rate``
- Images are recored every ``record-rate`` frames.
* - ``--epsilon``
- :math:`\epsilon`-greedy evaluation.
* - ``--target-return``
@@ -99,7 +95,7 @@ example::

# record wrapped environment
$ d3rlpy record d3rlpy_logs/Discrete_CQL_20201224224314/model_100.d3 \
--env-header 'import gym; from d3rlpy.envs import Atari; env = Atari(gym.make("BreakoutNoFrameskip-v4"), is_eval=True)'
--env-header 'import gym; from d3rlpy.envs import Atari; env = Atari(gym.make("BreakoutNoFrameskip-v4", render_mode="rgb_array"), is_eval=True)'

play
----
@@ -129,4 +125,4 @@ example::

# record wrapped environment
$ d3rlpy play d3rlpy_logs/Discrete_CQL_20201224224314/model_100.d3 \
--env-header 'import gym; from d3rlpy.envs import Atari; env = Atari(gym.make("BreakoutNoFrameskip-v4"), is_eval=True)'
--env-header 'import gym; from d3rlpy.envs import Atari; env = Atari(gym.make("BreakoutNoFrameskip-v4", render_mode="human"), is_eval=True)'
