Merge pull request #167 from zuoxingdong/step_info_trajectory
update runner: adapt to VecStepInfo
zuoxingdong committed May 7, 2019
2 parents 0499bbd + 5ecb3b1 commit c35c50f
Showing 5 changed files with 113 additions and 101 deletions.
8 changes: 7 additions & 1 deletion lagom/envs/vec_env.py
@@ -210,7 +210,13 @@ def __init__(self, env):
assert isinstance(env, VecEnv)
self.env = env
self.metadata = env.metadata
super().__init__(list_make_env=env.list_make_env)

self.list_make_env = env.list_make_env
self.list_env = env.list_env
self.observation_space = env.observation_space
self.action_space = env.action_space
self.reward_range = env.reward_range
self.spec = env.spec

def step(self, actions):
return self.env.step(actions)
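A note on this hunk: the wrapper no longer calls the parent constructor (which would rebuild the sub-environments from list_make_env) and instead forwards the already-constructed attributes of the wrapped VecEnv. Below is a minimal, self-contained sketch of that forwarding pattern; the class names are placeholders rather than lagom's actual classes, and only the attribute names visible in the diff are assumed.

class VecEnvSketch:
    """Stand-in for a vectorized env; only the attributes forwarded below."""
    def __init__(self, list_make_env):
        self.list_make_env = list_make_env
        self.list_env = [make() for make in list_make_env]  # sub-envs built once here
        self.observation_space = None
        self.action_space = None
        self.reward_range = (-float('inf'), float('inf'))
        self.spec = None
        self.metadata = {}

class VecEnvWrapperSketch:
    """Forwards the wrapped env's attributes instead of re-running the parent
    constructor, so the sub-environments are not re-created when wrapping."""
    def __init__(self, env):
        self.env = env
        self.metadata = env.metadata
        self.list_make_env = env.list_make_env
        self.list_env = env.list_env
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        self.reward_range = env.reward_range
        self.spec = env.spec

    def step(self, actions):
        return self.env.step(actions)

Forwarding keeps wrapping cheap: each make_env callable runs exactly once, and the wrapper shares the same sub-environment instances as the env it wraps.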
23 changes: 12 additions & 11 deletions lagom/runner/episode_runner.py
@@ -1,4 +1,5 @@
from lagom.envs import VecEnv
from lagom.envs.wrappers import VecStepInfo

from .base_runner import BaseRunner
from .trajectory import Trajectory
@@ -11,32 +12,32 @@ def __init__(self, reset_on_call=True):

def __call__(self, agent, env, T, **kwargs):
assert isinstance(env, VecEnv)
assert isinstance(env, VecStepInfo)
assert len(env) == 1, 'for cleaner API, one should use single VecEnv'

D = [Trajectory()]
if self.reset_on_call:
observation = env.reset()
observation, _ = env.reset()
else:
if self.observation is None:
self.observation = env.reset()
self.observation, _ = env.reset()
observation = self.observation
D[-1].add_observation(observation)
for t in range(T):
out_agent = agent.choose_action(observation, **kwargs)
action = out_agent.pop('raw_action')
next_observation, reward, done, info = env.step(action)
# unbatched for [reward, done, info]
reward, done, info = map(lambda x: x[0], [reward, done, info])
info = {**info, **out_agent}
if done:
D[-1].add_observation([info['last_observation']]) # add a batch dim
next_observation, reward, step_info = env.step(action)
# unbatch for [reward, step_info]
reward, step_info = map(lambda x: x[0], [reward, step_info])
step_info.info = {**step_info.info, **out_agent}
if step_info.last:
D[-1].add_observation([step_info['last_observation']]) # add a batch dim
else:
D[-1].add_observation(next_observation)
D[-1].add_action(action)
D[-1].add_reward(reward)
D[-1].add_info(info)
D[-1].add_done(done)
if done:
D[-1].add_step_info(step_info)
if step_info.last:
assert D[-1].completed
D.append(Trajectory())
D[-1].add_observation(next_observation) # initial observation
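Since the runner asserts len(env) == 1, the batched step outputs are unbatched by taking element 0, as in reward, step_info = map(lambda x: x[0], [reward, step_info]). A self-contained illustration of that idiom, with plain Python stand-ins for the batched outputs (none of these names come from lagom):

from types import SimpleNamespace

# Stand-ins for the batched outputs of a length-1 vectorized env step.
rewards = [1.0]                                      # one reward per sub-environment
step_infos = [SimpleNamespace(done=False, info={})]  # one step-info per sub-environment

# Same unbatching idiom as the runner: take the single element of each list.
reward, step_info = map(lambda x: x[0], [rewards, step_infos])

assert reward == 1.0 and step_info.done is False

When an episode ends, the runner restores the batch dimension by wrapping the stored terminal observation in a list, mirroring the [step_info['last_observation']] line above.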
38 changes: 17 additions & 21 deletions lagom/runner/trajectory.py
@@ -6,38 +6,40 @@ def __init__(self):
self.observations = []
self.actions = []
self.rewards = []
self.dones = []
self.infos = []
self.step_infos = []

self.completed = False
@property
def completed(self):
return len(self.step_infos) > 0 and self.step_infos[-1].last

def add_observation(self, observation):
assert not self.completed
self.observations.append(observation)

@property
def numpy_observations(self):
out = np.concatenate(np.asarray(self.observations), axis=0)
assert out.shape[0] == len(self) + 1 # plus initial observation
out = np.concatenate(self.observations, axis=0)
return out

@property
def last_observation(self):
return self.observations[-1]

@property
def reach_time_limit(self):
return self.step_infos[-1].time_limit

@property
def reach_terminal(self):
# TODO: handle TimeLimit seems not give performance boost
# return self.dones[-1] and 'TimeLimit.truncated' not in self.infos[-1]
return self.dones[-1]
return self.step_infos[-1].terminal

def add_action(self, action):
assert not self.completed
self.actions.append(action)

@property
def numpy_actions(self):
return np.concatenate(np.asarray(self.actions), axis=0)
return np.concatenate(self.actions, axis=0)

def add_reward(self, reward):
assert not self.completed
@@ -47,29 +49,23 @@ def add_reward(self, reward):
def numpy_rewards(self):
return np.asarray(self.rewards)

def add_done(self, done):
def add_step_info(self, step_info):
assert not self.completed
self.dones.append(done)
if done:
self.completed = True

self.step_infos.append(step_info)

@property
def numpy_dones(self):
return np.asarray(self.dones)
return np.asarray([step_info.done for step_info in self.step_infos])

@property
def numpy_masks(self):
return 1. - self.numpy_dones

def add_info(self, info):
assert not self.completed
self.infos.append(info)

def get_all_info(self, key):
return [info[key] for info in self.infos]
return [step_info[key] for step_info in self.step_infos]

def __len__(self):
return len(self.dones)
return len(self.step_infos)

def __repr__(self):
return f'Trajectory({len(self)})'
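The rewritten Trajectory assumes step-info objects that expose done, last, terminal, and time_limit attributes, carry an info dict, and support dict-style indexing for get_all_info. Below is a minimal stand-in consistent with that usage; it is a sketch, not lagom's actual StepInfo class, and the terminal/time_limit relationship is inferred from the old TimeLimit comment removed in this diff.

class StepInfoSketch:
    """Minimal object matching how Trajectory consumes its step infos."""
    def __init__(self, done, info, time_limit=False):
        self.done = done                          # episode ended at this step
        self.info = info                          # raw info dict from the env
        self.time_limit = time_limit              # ended by a time limit, not a true terminal
        self.terminal = done and not time_limit   # genuine terminal state of the environment
        self.last = done                          # marks the final step of a trajectory

    def __getitem__(self, key):
        # Dict-style access used by Trajectory.get_all_info()
        return self.info[key]

# Example: a time-limit ending is 'last' but not 'terminal', which is exactly
# the distinction the reach_terminal and reach_time_limit properties expose.
s = StepInfoSketch(done=True, info={'last_observation': None}, time_limit=True)
assert s.last and s.time_limit and not s.terminal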
