Commit 88eb550

Move old EpisodeRunner and Trajectory to legacy

zuoxingdong committed Aug 7, 2019
1 parent 9211b8e commit 88eb550
Showing 3 changed files with 126 additions and 2 deletions.
2 changes: 0 additions & 2 deletions lagom/metric/__init__.py
@@ -1,5 +1,3 @@
-from .trajectory import Trajectory
-
 from .returns import returns
 from .returns import bootstrapped_returns
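After this change, Trajectory can no longer be imported from lagom.metric. A minimal before/after sketch, assuming the legacy directory is importable as a package (the import path is inferred from this commit's file layout, not confirmed by the diff):

# Before this commit:
# from lagom.metric import Trajectory

# After this commit (path assumed from the file layout above):
from legacy.trajectory import Trajectory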
74 changes: 74 additions & 0 deletions legacy/episode_runner.py
@@ -0,0 +1,74 @@
# Imports reconstructed for readability; the module paths for lagom
# internals (BaseRunner, VecEnv, VecStepInfo, make_vec_env, RandomAgent,
# SanityEnv) are assumptions inferred from the code below, not shown in
# this diff.
import numpy as np
import pytest

import gym
from gym.wrappers import TimeLimit

from lagom import RandomAgent
from lagom.envs import make_vec_env, VecEnv
from lagom.envs.wrappers import VecStepInfo
from lagom.runner import BaseRunner

from .trajectory import Trajectory
from .sanity_env import SanityEnv


class EpisodeRunner(BaseRunner):
    def __init__(self, reset_on_call=True):
        self.reset_on_call = reset_on_call
        self.observation = None  # carried across calls when reset_on_call=False

    def __call__(self, agent, env, T, **kwargs):
        # Requires a single-environment VecEnv wrapped with VecStepInfo.
        assert isinstance(env, VecEnv) and isinstance(env, VecStepInfo) and len(env) == 1

        D = [Trajectory()]
        if self.reset_on_call:
            observation, _ = env.reset()
        else:
            if self.observation is None:
                self.observation, _ = env.reset()
            observation = self.observation
        D[-1].add_observation(observation)
        for t in range(T):
            out_agent = agent.choose_action(observation, **kwargs)
            action = out_agent.pop('raw_action')
            next_observation, [reward], [step_info] = env.step(action)
            step_info.info = {**step_info.info, **out_agent}
            if step_info.last:
                # Record the true terminal observation, not the auto-reset one.
                D[-1].add_observation([step_info['last_observation']])  # add a batch dim
            else:
                D[-1].add_observation(next_observation)
            D[-1].add_action(action)
            D[-1].add_reward(reward)
            D[-1].add_step_info(step_info)
            if step_info.last:
                assert D[-1].completed
                # Episode boundary: start a fresh trajectory whose initial
                # observation is the post-reset observation.
                D.append(Trajectory())
                D[-1].add_observation(next_observation)  # initial observation
            observation = next_observation
        if len(D[-1]) == 0:
            # Drop a trailing empty trajectory (T ended exactly on an episode boundary).
            D = D[:-1]
        self.observation = observation
        return D


@pytest.mark.parametrize('env_id', ['Sanity', 'CartPole-v1', 'Pendulum-v0', 'Pong-v0'])
@pytest.mark.parametrize('num_env', [1, 3])
@pytest.mark.parametrize('init_seed', [0, 10])
@pytest.mark.parametrize('T', [1, 5, 100])
def test_episode_runner(env_id, num_env, init_seed, T):
    if env_id == 'Sanity':
        make_env = lambda: TimeLimit(SanityEnv())
    else:
        make_env = lambda: gym.make(env_id)
    env = make_vec_env(make_env, num_env, init_seed)
    env = VecStepInfo(env)
    agent = RandomAgent(None, env, None)
    runner = EpisodeRunner()

    if num_env > 1:
        with pytest.raises(AssertionError):
            D = runner(agent, env, T)
    else:
        with pytest.raises(AssertionError):
            runner(agent, env.env, T)  # must be VecStepInfo
        D = runner(agent, env, T)
        for traj in D:
            assert isinstance(traj, Trajectory)
            assert len(traj) <= env.spec.max_episode_steps
            assert traj.numpy_observations.shape == (len(traj) + 1, *env.observation_space.shape)
            if isinstance(env.action_space, gym.spaces.Discrete):
                assert traj.numpy_actions.shape == (len(traj),)
            else:
                assert traj.numpy_actions.shape == (len(traj), *env.action_space.shape)
            assert traj.numpy_rewards.shape == (len(traj),)
            assert traj.numpy_dones.shape == (len(traj),)
            assert traj.numpy_masks.shape == (len(traj),)
            assert len(traj.step_infos) == len(traj)
            if traj.completed:
                assert np.allclose(traj.observations[-1], traj.step_infos[-1]['last_observation'])
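For orientation, a minimal sketch of driving this legacy runner, mirroring the test above. The import paths for make_vec_env, VecStepInfo, and RandomAgent are assumptions about lagom's layout, not confirmed by this diff:

import gym

from lagom import RandomAgent                 # assumed path
from lagom.envs import make_vec_env           # assumed path
from lagom.envs.wrappers import VecStepInfo   # assumed path

from legacy.episode_runner import EpisodeRunner

env = make_vec_env(lambda: gym.make('CartPole-v1'), 1, 0)  # single env required
env = VecStepInfo(env)
agent = RandomAgent(None, env, None)
runner = EpisodeRunner(reset_on_call=False)

# With reset_on_call=False, the second call resumes from the stored
# observation instead of resetting the environment mid-episode.
D1 = runner(agent, env, 30)
D2 = runner(agent, env, 30)
print([len(traj) for traj in D1], [len(traj) for traj in D2])

The reset_on_call=False path is the design point here: the runner stashes the last observation on self.observation, so fixed-length rollouts can be stitched across calls without truncating episodes.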
52 changes: 52 additions & 0 deletions lagom/metric/trajectory.py → legacy/trajectory.py
@@ -74,3 +74,55 @@ def get_all_info(self, key):

    def __repr__(self):
        return f'Trajectory(T: {len(self)}, Completed: {self.completed}, Reach time limit: {self.reach_time_limit}, Reach terminal: {self.reach_terminal})'



@pytest.mark.parametrize('init_seed', [0, 10])
@pytest.mark.parametrize('T', [1, 5, 100])
def test_trajectory(init_seed, T):
    make_env = lambda: TimeLimit(SanityEnv())
    env = make_vec_env(make_env, 1, init_seed)  # single environment
    env = VecStepInfo(env)
    D = Trajectory()
    assert len(D) == 0
    assert not D.completed

    observation, _ = env.reset()
    D.add_observation(observation)
    for t in range(T):
        action = [env.action_space.sample()]
        next_observation, reward, [step_info] = env.step(action)
        if step_info.last:
            D.add_observation([step_info['last_observation']])
        else:
            D.add_observation(next_observation)
        D.add_action(action)
        D.add_reward(reward)
        D.add_step_info(step_info)
        observation = next_observation
        if step_info.last:
            # A completed trajectory rejects further observations.
            with pytest.raises(AssertionError):
                D.add_observation(observation)
            break
    assert len(D) > 0
    assert len(D) <= T
    assert len(D) + 1 == len(D.observations)
    assert len(D) + 1 == len(D.numpy_observations)
    assert len(D) == len(D.actions)
    assert len(D) == len(D.numpy_actions)
    assert len(D) == len(D.rewards)
    assert len(D) == len(D.numpy_rewards)
    assert len(D) == len(D.numpy_dones)
    assert len(D) == len(D.numpy_masks)
    assert np.allclose(np.logical_not(D.numpy_dones), D.numpy_masks)
    assert len(D) == len(D.step_infos)
    if len(D) < T:
        # The loop broke early, so the episode must have terminated.
        assert step_info.last
        assert D.completed
        assert D.reach_terminal
        assert not D.reach_time_limit
        assert np.allclose(D.observations[-1], [step_info['last_observation']])
    if not step_info.last:
        assert not D.completed
        assert not D.reach_terminal
        assert not D.reach_time_limit
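Since lagom.metric still exports returns and bootstrapped_returns (see the __init__.py diff above), a collected trajectory typically feeds a return computation. A standalone sketch of the discounted-return recursion over traj.numpy_rewards, deliberately not tied to lagom's exact function signatures:

import numpy as np

def discounted_returns(rewards, gamma):
    # Backward pass: G_t = r_t + gamma * G_{t+1}
    out = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        out[t] = running
    return out

# e.g. discounted_returns(traj.numpy_rewards, 0.99)
print(discounted_returns(np.array([1.0, 1.0, 1.0]), 0.99))  # [2.9701 1.99 1.]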
