Skip to content

Commit

Permalink
Add wrapper RecordEpisodeStatistics
Browse files Browse the repository at this point in the history
  • Loading branch information
zuoxingdong committed Aug 2, 2019
1 parent 238d3e5 commit 91de829
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 0 deletions.
3 changes: 3 additions & 0 deletions docs/source/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ lagom.envs
.. automodule:: lagom.envs
.. currentmodule:: lagom.envs

.. autoclass:: RecordEpisodeStatistics
:members:

.. autoclass:: VecEnv
:members:

Expand Down
2 changes: 2 additions & 0 deletions lagom/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .vec_env import VecEnv
from .vec_env import VecEnvWrapper
from .make_vec_env import make_vec_env

from .record_episode_statistics import RecordEpisodeStatistics
34 changes: 34 additions & 0 deletions lagom/envs/record_episode_statistics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import time
from collections import deque

import gym


class RecordEpisodeStatistics(gym.Wrapper):
    """Keep running statistics of the episodes played in the wrapped environment.

    While an episode is running, the wrapper accumulates the sum of rewards
    (``episode_return``) and the number of steps taken (``episode_horizon``).
    On the step that reports ``done``, an ``'episode'`` entry is stored in the
    ``info`` dictionary returned by that step, holding:

    * ``'return'``: sum of the rewards collected during the episode,
    * ``'horizon'``: number of steps of the episode,
    * ``'time'``: seconds elapsed since this wrapper was constructed,
      rounded to 4 decimals (the reference time is not refreshed between
      episodes).

    The return and horizon of every finished episode are also appended to
    ``return_queue`` and ``horizon_queue``, which retain at most
    ``deque_size`` most recent entries.

    Args:
        env: environment to wrap.
        deque_size (int): maximum length of ``return_queue`` and
            ``horizon_queue``.
    """

    def __init__(self, env, deque_size=100):
        super().__init__(env)
        # Reference timestamp for the 'time' statistic; set once here.
        self.t0 = time.perf_counter()
        # Bounded histories: the oldest entries are dropped once full.
        self.return_queue = deque(maxlen=deque_size)
        self.horizon_queue = deque(maxlen=deque_size)
        self.episode_return, self.episode_horizon = 0.0, 0

    def reset(self, **kwargs):
        """Reset the wrapped environment and clear the running counters.

        Args:
            **kwargs: forwarded unchanged to the wrapped ``reset``.

        Returns:
            The observation returned by the wrapped ``reset``.
        """
        initial_observation = super().reset(**kwargs)
        self.episode_return, self.episode_horizon = 0.0, 0
        return initial_observation

    def step(self, action):
        """Step the wrapped environment and update the episode statistics.

        Args:
            action: action forwarded to the wrapped ``step``.

        Returns:
            The ``(observation, reward, done, info)`` tuple of the wrapped
            ``step``; when ``done`` is true, ``info`` additionally carries
            the ``'episode'`` statistics.
        """
        observation, reward, done, info = super().step(action)
        transition = (observation, reward, done, info)
        self.episode_return += reward
        self.episode_horizon += 1
        if not done:
            return transition

        # Episode finished: publish its statistics through ``info`` ...
        elapsed = round(time.perf_counter() - self.t0, 4)
        info['episode'] = {
            'return': self.episode_return,
            'horizon': self.episode_horizon,
            'time': elapsed,
        }
        # ... archive them in the bounded histories ...
        self.return_queue.append(self.episode_return)
        self.horizon_queue.append(self.episode_horizon)
        # ... and start counting afresh, even if ``reset`` is not called.
        self.episode_return, self.episode_horizon = 0.0, 0
        return transition
21 changes: 21 additions & 0 deletions test/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from gym.spaces import Dict
from gym.wrappers import ClipReward

from lagom.envs import RecordEpisodeStatistics
from lagom.envs import make_vec_env
from lagom.envs import VecEnv
from lagom.envs.wrappers import get_wrapper
Expand Down Expand Up @@ -54,6 +55,26 @@ def make_env():
assert isinstance(infos, list) and len(infos) == num_env
env.close()
assert env.closed


@pytest.mark.parametrize('env_id', ['CartPole-v0', 'Pendulum-v0'])
@pytest.mark.parametrize('deque_size', [2, 5])
def test_record_episode_statistics(env_id, deque_size):
    """Check counters, the ``info['episode']`` entry and the bounded queues.

    Five episodes are played with random actions. Counters must be zero
    right after each reset, a finished episode must expose all statistics
    keys, and both queues must end up holding exactly ``deque_size`` entries.
    """
    expected_keys = ('return', 'horizon', 'time')
    env = RecordEpisodeStatistics(gym.make(env_id), deque_size)

    for _episode in range(5):
        env.reset()
        assert env.episode_return == 0.0
        assert env.episode_horizon == 0
        for _step in range(env.spec.max_episode_steps):
            *_, done, info = env.step(env.action_space.sample())
            if not done:
                continue
            # The terminal step must carry the complete statistics record.
            assert 'episode' in info
            assert all(key in info['episode'] for key in expected_keys)
            break
    # Five episodes were played and deque_size <= 5, so both queues are full.
    assert len(env.return_queue) == deque_size
    assert len(env.horizon_queue) == deque_size


@pytest.mark.parametrize('env_id', ['CartPole-v0', 'Pendulum-v0'])
Expand Down

0 comments on commit 91de829

Please sign in to comment.