
Update TrajectoryRunner with its docs and test
zuoxingdong committed Sep 6, 2018
1 parent 78afd12 commit 6884940
Showing 6 changed files with 207 additions and 133 deletions.
4 changes: 4 additions & 0 deletions lagom/agents/base_agent.py
@@ -41,6 +41,10 @@ def choose_action(self, obs):
Returns:
output (dict): a dictionary of action selection output.
NOTE: everything should be batched, e.g. scalar loss -> [loss]
Possible keys: ['action', 'action_logprob', 'state_value', 'Q_value']
"""
raise NotImplementedError
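As a rough illustration of the batched-output convention documented above, here is a minimal sketch of a hypothetical agent. The class, its toy networks, and the discrete-action assumption are purely illustrative and are not part of this commit:

import torch
import torch.nn as nn
from torch.distributions import Categorical

class SketchAgent:
    """Hypothetical agent, only to illustrate the batched choose_action output."""
    def __init__(self, obs_dim, num_actions):
        self.policy = nn.Linear(obs_dim, num_actions)  # toy policy head
        self.V = nn.Linear(obs_dim, 1)                 # toy state-value head

    def choose_action(self, obs):
        obs = torch.as_tensor(obs, dtype=torch.float32)  # batched: [num_env, obs_dim]
        dist = Categorical(logits=self.policy(obs))
        action = dist.sample()
        # Everything stays batched: one entry per environment in the VecEnv
        return {'action': action,                         # [num_env]
                'action_logprob': dist.log_prob(action),  # [num_env]
                'state_value': self.V(obs).squeeze(-1)}   # [num_env]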
1 change: 1 addition & 0 deletions lagom/runner/__init__.py
@@ -4,5 +4,6 @@
from .trajectory import Trajectory
from .segment import Segment

from .base_runner import BaseRunner
from .trajectory_runner import TrajectoryRunner
from .segment_runner import SegmentRunner
37 changes: 37 additions & 0 deletions lagom/runner/base_runner.py
@@ -0,0 +1,37 @@
from lagom.envs.vec_env import VecEnv


class BaseRunner(object):
r"""Base class for all runners.
Any runner should subclass this class.
A runner is a data collection interface between the agent and the environment.
Each call of the runner makes the agent take actions in, and receive observations from, the
environment for a certain number of trajectories/segments and a certain number of time steps.
.. note::
By default, the agent handles batched data returned from :class:`VecEnv` type of environment.
And the collected data should use either :class:`Trajectory` or :class:`Segment`.
"""
def __init__(self, agent, env, gamma):
r"""Initialize the runner.
Args:
agent (BaseAgent): agent
env (VecEnv): VecEnv type of environment
gamma (float): discount factor
"""
assert isinstance(env, VecEnv), f'expected VecEnv, got {type(env)}'

self.agent = agent
self.env = env
self.gamma = gamma

def __call__(self, N, T):
r"""Run the agent in the environment and collect all necessary interaction data as a batch. """
raise NotImplementedError

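As a hedged sketch of what a concrete subclass might look like, the hypothetical runner below collects raw (obs, action, reward, done) tuples for N rollouts of at most T steps; the class and its behaviour are illustrative only and not part of lagom:

class EchoRunner(BaseRunner):
    """Hypothetical runner: collects raw (obs, action, reward, done) tuples."""
    def __call__(self, N, T):
        D = []
        for _ in range(N):
            rollout = []
            obs = self.env.reset()
            for _ in range(T):
                out = self.agent.choose_action(obs)
                action = out['action']
                # Assumed: the VecEnv accepts the batched action as returned by the agent
                obs_next, reward, done, info = self.env.step(action)
                rollout.append((obs, action, reward, done))
                obs = obs_next
                if done[0]:  # single-environment convention, as in TrajectoryRunner
                    break
            D.append(rollout)
        return D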
82 changes: 35 additions & 47 deletions lagom/runner/segment_runner.py
@@ -3,73 +3,61 @@
from lagom.runner import Transition
from lagom.runner import Segment

from lagom.envs.vec_env import VecEnv
from .base_runner import BaseRunner


class SegmentRunner(object):
"""
Batched data collection for an agent in one or multiple environments for a certain time steps.
It includes successive transitions (observation, action, reward, next observation, done) and
additional data useful for training the agent such as the action log-probabilities, policy entropies,
Q values etc.
class SegmentRunner(BaseRunner):
r"""Define a data collection interface by running the agent in an environment and collecting a batch
of segments for a certain number of time steps.
.. note::
By default, the agent handles batched data returned from :class:`VecEnv` type of environment.
And the collected data is a list of :class:`Segment`.
The collected data in each environment will be wrapped in an individual Segment object. Each call
of the runner will return a list of Segment objects each with length same as the number of time steps.
Each :class:`Segment` should store successive transitions i.e. :math:`(s_t, a_t, r_t, s_{t+1}, \text{done})`
and all corresponding useful information such as log-probabilities of actions, state values, Q values etc.
Note that we allow transitions coming from multiple episodes successively.
For example, for a Segment with length 4, it can have either of following cases:
The collected transitions in each :class:`Segment` come from one or multiple episodes. The first state is not
restricted to be the initial observation of an episode; this allows rolling segments of episodic transitions.
And all segments have the same number of transitions (time steps). For detailed description, see the docstring
in :class:`Segment`.
Let s_t^i be state at time step t in episode i and s_T^i be terminal state in episode i.
Be aware that if the transitions come from more than one episode, the succeeding transitions for the next episode
start from initial observation.
1. Part of single episode from initial observation:
s_0^1 -> s_1^1 -> s_2^1 -> s_3^1
2. A complete episode:
s_0^1 -> s_1^1 -> s_2^1 -> s_T^1
3. Intermediate part of a single episode:
s_5^1 -> s_6^1 -> s_7^1 -> s_8^1
4. Two complete episodes:
s_0^1 -> s_T^1 -> s_0^2 -> s_T^2
5. Parts from two episodes:
s_3^1 -> s_T^1 -> s_0^2 -> s_1^2
.. note::
Be aware that if the transitions come from more than one episode, the episodes should be
in successive order: every episode except the last must reach a terminal state before the
next episode begins, and each succeeding episode starts from its initial observation.
For collecting a batch of trajectories, one should use :class:`TrajectoryRunner` instead.
Example::
In order to make such data collection possible, the environment must be of type VecEnv to support
batched data. VecEnv will continuously collect data in all environment, for each occurrence of
`done=True`, the environment will be automatically reset and continue. So if we want to collect data
from initial observation in all environments, the method `reset` should be called.
The SegmentRunner is very general, for runner that only collects transitions from a single
episode (start from initial observation) one can use TrajectoryRunner instead.
"""
def __init__(self, agent, env, gamma):
self.agent = agent
self.env = env
assert isinstance(self.env, VecEnv), 'The environment must be of type VecEnv. '
self.gamma = gamma
super().__init__(agent=agent, env=env, gamma=gamma)

# Buffer for observation (continuous with next call)
self.obs_buffer = None

def __call__(self, T, reset=False):
"""
Run the agent in the batched environments and collect all necessary data for given number of
time steps for each Segment (one Segment for each environment).
Note that we do not reset all environments for each call as it does in TrajectoryRunner.
An option `reset` is provided to decide if reset all environment before data collection.
r"""Run the agent in the vectorized environment (one or multiple environments) and collect
a number of segments each with exactly T time steps.
This is because for SegmentRunner, we often need to continuously collect a batched data
with small time steps, so each `__call__` will continuously collect data until `reset=True`.
.. note::
One can call this method repeatedly to collect rolling segments of episodic transitions
until :attr:`reset` is set to ``True``.
Args:
T (int): Number of time steps
reset (bool): Whether to reset all environments (in VecEnv).
T (int): number of time steps to collect
reset (bool, optional): If ``True``, then reset all internal environments in VecEnv.
Default: ``False``
Returns:
D (list of Segment): list of collected segments.
Returns
-------
D : list
a list of collected :class:`Segment`
"""
# Initialize all Segment for each environment
D = [Segment(gamma=self.gamma) for _ in range(self.env.num_env)]
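A usage sketch of the rolling-segment collection described above: the environment setup mirrors the TrajectoryRunner docstring example further down, and the helper names (make_envs, make_gym_env, SerialVecEnv, EnvSpec, RandomAgent) are taken from that example rather than from this diff:

list_make_env = make_envs(make_env=make_gym_env,
                          env_id='CartPole-v1',
                          num_env=2,
                          init_seed=0)
env = SerialVecEnv(list_make_env=list_make_env)
env_spec = EnvSpec(env)
agent = RandomAgent(env_spec=env_spec)

runner = SegmentRunner(agent=agent, env=env, gamma=1.0)
D1 = runner(T=4, reset=True)   # start from initial observations in all environments
D2 = runner(T=4)               # rolling: continues from where the previous call stopped
assert len(D1) == env.num_env  # one Segment per environment, each with exactly T transitions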
133 changes: 69 additions & 64 deletions lagom/runner/trajectory_runner.py
@@ -3,60 +3,67 @@
from .transition import Transition
from .trajectory import Trajectory

from lagom.envs.vec_env import VecEnv
from lagom.envs.spaces import Discrete
from .base_runner import BaseRunner


class TrajectoryRunner(object):
"""
Batched data collection for an agent in one environment for a number of trajectories and a certain time steps.
It includes successive transitions (observation, action, reward, next observation, done) and
additional data useful for training the agent such as the action log-probabilities, policy entropies,
Q values etc.
class TrajectoryRunner(BaseRunner):
r"""Define a data collection interface by running the agent in an environment and collecting a batch of
trajectories for a maximum number of allowed time steps.
The collected data in each trajectory will be wrapped in an individual Trajectory object. Each call
of the runner will return a list of Trajectory objects.
.. note::
By default, the agent handles batched data returned from :class:`VecEnv` type of environment.
And the collected data is a list of :class:`Trajectory`.
Note that the transitions in a Trajectory should come from a single episode and started from initial observation.
The length of the trajectory can maximally be the allowed time steps or can be the time steps until it reaches
terminal state.
Each :class:`Trajectory` should store successive transitions i.e. :math:`(s_t, a_t, r_t, s_{t+1}, \text{done})`
and all corresponding useful information such as log-probabilities of actions, state values, Q values etc.
For example, for a Trajectory with length 4, it can have either of following cases:
The collected transitions in each :class:`Trajectory` come from a single episode, starting from the initial observation
until reaching the maximum allowed time steps or a terminal state. For a detailed description, see the docstring
of :class:`Trajectory`.
Let s_t be state at time step t and s_T be terminal state.
.. note::
For collecting a batch of segments, one should use :class:`SegmentRunner` instead.
1. Part of single episode from initial observation:
s_0 -> s_1 -> s_2 -> s_3
2. A complete episode:
s_0 -> s_1 -> s_2 -> s_T
Example::
For runner that collects transitions from multiple episodes, one can use SegmentRunner instead.
>>> list_make_env = make_envs(make_env=make_gym_env,
env_id='CartPole-v1',
num_env=1,
init_seed=0)
>>> env = SerialVecEnv(list_make_env=list_make_env)
>>> env_spec = EnvSpec(env)
>>> agent = RandomAgent(env_spec=env_spec)
>>> runner = TrajectoryRunner(agent=agent, env=env, gamma=1.0)
>>> runner(N=2, T=3)
[Trajectory:
Transition: (s=[-0.04002427 0.00464987 -0.01704236 -0.03673052], a=1, r=1.0, s_next=[-0.03993127 0.20001201 -0.01777697 -0.33474139], done=False)
Transition: (s=[-0.03993127 0.20001201 -0.01777697 -0.33474139], a=1, r=1.0, s_next=[-0.03593103 0.39538239 -0.0244718 -0.63297681], done=False)
Transition: (s=[-0.03593103 0.39538239 -0.0244718 -0.63297681], a=1, r=1.0, s_next=[-0.02802339 0.59083704 -0.03713133 -0.93326499], done=False),
Trajectory:
Transition: (s=[-0.04892357 0.02011271 0.02775732 -0.04547827], a=1, r=1.0, s_next=[-0.04852131 0.21482587 0.02684775 -0.3292759 ], done=False)
Transition: (s=[-0.04852131 0.21482587 0.02684775 -0.3292759 ], a=0, r=1.0, s_next=[-0.0442248 0.01933221 0.02026223 -0.0282488 ], done=False)
Transition: (s=[-0.0442248 0.01933221 0.02026223 -0.0282488 ], a=1, r=1.0, s_next=[-0.04383815 0.21415782 0.01969726 -0.31447053], done=False)]
"""
def __init__(self, agent, env, gamma):
"""
Args:
agent (BaseAgent): agent
env (Env): environment
gamma (float): discount factor
"""
self.agent = agent
self.env = env
assert isinstance(self.env, VecEnv), 'The environment must be of type VecEnv. '
msg = f'expected only one environment for TrajectoryRunner, got {self.env.num_env}'
assert self.env.num_env == 1, msg
self.gamma = gamma
super().__init__(agent=agent, env=env, gamma=gamma)
assert self.env.num_env == 1, f'expected a single environment, got {self.env.num_env}'

def __call__(self, N, T):
"""
Run the agent in the environment and collect all necessary data for given number of trajectories
and horizon (time steps) for each trajectory.
r"""Run the agent in the environment and collect N trajectories each with maximally T time steps.
Args:
N (int): Number of trajectories
T (int): Number of time steps
N (int): number of trajectories to collect.
T (int): maximum number of allowed time steps per trajectory.
Returns:
D (list of Trajectory): list of collected trajectories.
Returns
-------
D : list
a list of collected :class:`Trajectory`
"""
D = []

@@ -68,39 +75,38 @@ def __call__(self, N, T):
obs = self.env.reset()

for t in range(T): # Iterate over the number of time steps
# Action selection by the agent
# Not using numpy because we don't know exact dtype, all Agent should handle batched data
output_agent = self.agent.choose_action(obs)
# Action selection by the agent (handles batched data)
out_agent = self.agent.choose_action(obs)

# Unpack action from output.
# We record Tensor dtype for backprop (propagate via Transitions)
action = output_agent.pop('action') # pop-out
state_value = output_agent.pop('state_value', None)

# Obtain raw action from Tensor for environment to execute
# Unpack action
action = out_agent.pop('action') # pop-out
# Convert a Tensor action to a raw value the environment can execute
if torch.is_tensor(action):
raw_action = action.detach().cpu().numpy()
raw_action = list(raw_action)
raw_action = list(action.detach().cpu().numpy())
else: # Non Tensor action, e.g. from RandomAgent
raw_action = action
# Execute the action

# Execute the agent in the environment
obs_next, reward, done, info = self.env.step(raw_action)

# Create and record a Transition
# Take out first element because we only have one environment wrapped for TrajectoryRunner
# Take out the first element because the VecEnv wraps a single environment (no batch dim in Transition)
# Note that action can be Tensor type (can be used for backprop)
transition = Transition(s=obs[0],
a=action[0],
r=reward[0],
s_next=obs_next[0],
done=done[0])
# Record state value if required

# Handle state value if available
state_value = out_agent.pop('state_value', None)
if state_value is not None:
transition.add_info('V_s', state_value[0])
# Record additional information from output_agent
# Note that 'action' and 'state_value' already poped out
for key, val in output_agent.items():
transition.add_info(key, val[0])

# Record additional information from out_agent in the transition
# Note that 'action' and 'state_value' have already been popped out
[transition.add_info(key, val[0]) for key, val in out_agent.items()]

# Add transition to Trajectory
trajectory.add_transition(transition)

Expand All @@ -110,16 +116,15 @@ def __call__(self, N, T):
# Terminate if episode finishes
if done[0]:
break

# Call agent again to compute state value for final observation in collected trajectory
# If state value available, calculate state value for final observation
if state_value is not None:
V_s_next = self.agent.choose_action(obs)['state_value']
# We do not set zero even if it is terminal state
# Because it should be handled in Trajectory e.g. compute TD errors
# Return original Tensor in general can help backprop to work properly, e.g. learning value function
# Add to the final transition as 'V_s_next'
# Add to final transition in the trajectory
# Do not set zero for terminal state, useful for backprop to learn value function
# It will be handled in Trajectory or Segment method when calculating things like returns/TD errors
trajectory.transitions[-1].add_info('V_s_next', V_s_next[0])

# Append trajectory to data
D.append(trajectory)

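To connect the V_s_next bookkeeping above with how it might be consumed downstream, here is a rough sketch of computing a bootstrapped discounted return from a collected trajectory. The attribute access (transition.r, transition.done, transition.info) is an assumption based on the Transition fields shown in the docstring example, and lagom's Trajectory class may already provide an equivalent method:

def bootstrapped_return(trajectory, gamma):
    """Hedged sketch: discounted return bootstrapped with V(s_next) of the last transition."""
    last = trajectory.transitions[-1]
    # If the episode terminated there is nothing to bootstrap from;
    # otherwise use the recorded 'V_s_next' (kept as a Tensor by the runner above).
    G = 0.0 if last.done else last.info.get('V_s_next', 0.0)
    for transition in reversed(trajectory.transitions):
        G = transition.r + gamma * G
    return G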
