Skip to content

Commit

Permalink
Add TimeStep/Trajectory, StepRunner/EpisodeRunner
Browse files Browse the repository at this point in the history
  • Loading branch information
zuoxingdong committed Aug 7, 2019
1 parent 2c280e2 commit 9211b8e
Show file tree
Hide file tree
Showing 5 changed files with 309 additions and 92 deletions.
14 changes: 14 additions & 0 deletions docs/source/lagom.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@ Agent

.. autoclass:: RandomAgent
:members:

Data
----------------------------
.. autoclass:: StepType
:members:

.. autoclass:: TimeStep
:members:

.. autoclass:: Trajectory
:members:

Logger
----------------------------
Expand All @@ -28,6 +39,9 @@ Runner
.. autoclass:: EpisodeRunner
:members:

.. autoclass:: StepRunner
:members:

Evolution Strategies
----------------------------
.. autoclass:: BaseES
Expand Down
5 changes: 5 additions & 0 deletions lagom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from .agent import BaseAgent
from .agent import RandomAgent

from .data import StepType
from .data import TimeStep
from .data import Trajectory

from .engine import BaseEngine

from .es import BaseES
Expand All @@ -13,3 +17,4 @@

from .runner import BaseRunner
from .runner import EpisodeRunner
from .runner import StepRunner
121 changes: 121 additions & 0 deletions lagom/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from enum import IntEnum
from dataclasses import dataclass
import numpy as np


class StepType(IntEnum):
    r"""Marks where a :class:`TimeStep` sits within an episode.

    The members are plain integers, so they compare equal to ``0``, ``1`` and ``2``.
    """
    # Initial step of an episode: it carries an observation only.
    FIRST = 0
    # Any step that is neither the first nor the last one.
    MID = 1
    # Final step of an episode: its ``done`` flag is set.
    LAST = 2


@dataclass
class TimeStep:
    r"""One step of agent-environment interaction.

    Args:
        step_type (StepType): position of this step within the episode.
        observation (object): observation attached to this step.
        reward (float): reward of this step, ``None`` for a first step.
        done (bool): done flag of this step, ``None`` for a first step.
        info (dict): extra information of this step, ``None`` for a first step.
    """
    step_type: StepType
    observation: object
    reward: float
    done: bool
    info: dict

    def __getitem__(self, key):
        r"""Shortcut for ``self.info[key]``."""
        return self.info[key]

    def first(self):
        r"""Returns ``True`` if this is the first step of an episode.

        A first step must not carry a reward, a done flag or an info dict.
        """
        is_first = self.step_type == StepType.FIRST
        if is_first:
            assert all(field is None for field in (self.reward, self.done, self.info))
        return is_first

    def mid(self):
        r"""Returns ``True`` if this step is neither the first nor the last one."""
        is_mid = self.step_type == StepType.MID
        if is_mid:
            assert not (self.first() or self.last())
        return is_mid

    def last(self):
        r"""Returns ``True`` if this is the last step of an episode.

        A last step must have its done flag set.
        """
        is_last = self.step_type == StepType.LAST
        if is_last:
            assert self.done is not None and self.done
        return is_last

    def time_limit(self):
        r"""Tells whether a last step is flagged ``'TimeLimit.truncated'`` in its info."""
        if not self.last():
            return False
        return self.info.get('TimeLimit.truncated', False)

    def terminal(self):
        r"""Tells whether this is a last step that was not cut by the time limit."""
        if not self.last():
            return False
        return not self.time_limit()

    def __repr__(self):
        class_name = self.__class__.__name__
        type_name = self.step_type.name
        return f'{class_name}({type_name})'


class Trajectory(object):
    r"""Ordered record of the time steps of one episode (or episode segment).

    The first stored time step is the initial one and is stored without an action.
    Every later time step is stored together with the action that led to it, so a
    trajectory holding ``T + 1`` time steps holds ``T`` actions, rewards, dones and infos.
    """
    def __init__(self):
        self.timesteps = []
        self._actions = []
        # Cursor used only by the legacy ``next(trajectory)`` protocol.
        self.i = 0

    def __len__(self):
        r"""Number of stored time steps, the initial one included."""
        return len(self.timesteps)

    @property
    def T(self):
        r"""Number of transitions, i.e. stored time steps minus the initial one."""
        return max(0, len(self) - 1)

    def __getitem__(self, index):
        return self.timesteps[index]

    def __iter__(self):
        r"""Returns a fresh iterator over the stored time steps.

        Each call yields an independent iterator, so nested or interleaved loops
        over the same trajectory do not disturb one another.
        """
        # Rewind the legacy cursor so that ``iter(traj)`` followed by
        # ``next(traj)`` still starts from the first time step.
        self.i = 0
        return iter(self.timesteps)

    def __next__(self):
        r"""Legacy cursor-based access; ``for`` loops no longer go through it."""
        if self.i < len(self):
            timestep = self.timesteps[self.i]
            self.i += 1
            return timestep
        raise StopIteration

    @property
    def finished(self):
        r"""``True`` if the most recent time step is a last step."""
        return len(self) > 0 and self.timesteps[-1].last()

    @property
    def reach_time_limit(self):
        r"""``True`` if the most recent time step reports a time limit."""
        return len(self) > 0 and self.timesteps[-1].time_limit()

    @property
    def reach_terminal(self):
        r"""``True`` if the most recent time step reports a terminal state."""
        return len(self) > 0 and self.timesteps[-1].terminal()

    def add(self, timestep, action):
        r"""Appends a time step and the action that led to it.

        Args:
            timestep (TimeStep): time step to store; the very first one must be a first step.
            action (object): action that produced ``timestep``; must be ``None`` for the
                very first time step and not ``None`` afterwards.

        Raises:
            AssertionError: if the trajectory is already finished or the rules above are violated.
        """
        assert not self.finished, 'cannot add to a finished trajectory'
        if len(self) == 0:
            assert timestep.first(), 'the initial time step must be a first step'
            assert action is None, 'the initial time step takes no action'
        else:
            assert action is not None, 'an action is required after the initial time step'
            self._actions.append(action)
        self.timesteps.append(timestep)

    @property
    def observations(self):
        r"""Observations of all time steps, the initial one included."""
        return [timestep.observation for timestep in self.timesteps]

    @property
    def actions(self):
        r"""Actions in the order they were added."""
        return self._actions

    @property
    def rewards(self):
        r"""Rewards of all time steps after the initial one."""
        return [timestep.reward for timestep in self.timesteps[1:]]

    @property
    def dones(self):
        r"""Done flags of all time steps after the initial one."""
        return [timestep.done for timestep in self.timesteps[1:]]

    @property
    def infos(self):
        r"""Info dicts of all time steps after the initial one."""
        return [timestep.info for timestep in self.timesteps[1:]]

    def get_infos(self, key):
        r"""Collects ``info[key]`` from every non-initial time step whose info has that key."""
        return [timestep.info[key] for timestep in self.timesteps[1:] if key in timestep.info]

    def __repr__(self):
        return f'Trajectory(T: {self.T}, Finished: {self.finished}, Reach time limit: {self.reach_time_limit}, Reach terminal: {self.reach_terminal})'
89 changes: 47 additions & 42 deletions lagom/runner.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,76 @@
from abc import ABC
from abc import abstractmethod

from lagom.metric import Trajectory
from lagom.envs import VecEnv
from lagom.envs.wrappers import VecStepInfo
from lagom.data import StepType
from lagom.data import TimeStep
from lagom.data import Trajectory
from lagom.envs.timestep_env import TimeStepEnv


class BaseRunner(ABC):
    r"""Base class for all runners.

    A runner is a data collection interface between the agent and the environment.
    For each calling of the runner, the agent will take actions and receive observation
    in and from an environment for a certain number of trajectories/segments and a certain
    number of time steps.
    """
    @abstractmethod
    def __call__(self, agent, env, **kwargs):
        r"""Defines data collection via interactions between the agent and the environment.

        Subclasses must override this method; the base class cannot be instantiated.

        Args:
            agent (BaseAgent): agent
            env (Env): environment
            **kwargs: keyword arguments for more specifications.
        """
        pass


class EpisodeRunner(BaseRunner):
    r"""Runner that collects a given number of complete episodes."""
    def __call__(self, agent, env, N, **kwargs):
        r"""Runs the agent in the environment for ``N`` episodes.

        Args:
            agent (BaseAgent): agent; its ``choose_action`` output must contain ``'raw_action'``.
            env (TimeStepEnv): environment whose ``reset``/``step`` return time steps.
            N (int): number of episodes to collect.
            **kwargs: keyword arguments forwarded to ``agent.choose_action``.

        Returns:
            list: one :class:`Trajectory` per episode.
        """
        assert isinstance(env, TimeStepEnv)
        return [self._run_episode(agent, env, kwargs) for _ in range(N)]

    def _run_episode(self, agent, env, kwargs):
        r"""Rolls out a single episode from reset until a last time step is reached."""
        episode = Trajectory()
        current = env.reset()
        episode.add(current, None)
        while not current.last():
            agent_out = agent.choose_action(current, **kwargs)
            action = agent_out.pop('raw_action')
            current = env.step(action)
            # Keep whatever else the agent returned alongside the environment info.
            current.info = {**current.info, **agent_out}
            episode.add(current, action)
        return episode


class StepRunner(BaseRunner):
    r"""Runner that collects a fixed number of time steps per call.

    Args:
        reset_on_call (bool): if ``True``, every call starts by resetting the environment.
            If ``False``, the observation reached at the end of a call is remembered and
            the next call continues from it.
    """
    def __init__(self, reset_on_call=True):
        self.reset_on_call = reset_on_call
        # Observation to resume from; only updated when ``reset_on_call`` is False.
        self.observation = None

    def __call__(self, agent, env, T, **kwargs):
        r"""Runs the agent in the environment for ``T`` time steps.

        Args:
            agent (BaseAgent): agent; its ``choose_action`` output must contain ``'raw_action'``.
            env (TimeStepEnv): environment whose ``reset``/``step`` return time steps.
            T (int): number of environment steps to take.
            **kwargs: keyword arguments forwarded to ``agent.choose_action``.

        Returns:
            list: trajectories in collection order. Every episode that ended within the
            call is included, followed by the trailing unfinished segment if it holds
            at least one transition.
        """
        assert isinstance(env, TimeStepEnv)
        D = []
        traj = Trajectory()
        if self.reset_on_call or self.observation is None:
            timestep = env.reset()
        else:
            # Resume from the remembered observation by wrapping it as a first step.
            timestep = TimeStep(StepType.FIRST, observation=self.observation, reward=None, done=None, info=None)
        traj.add(timestep, None)
        for _ in range(T):
            out_agent = agent.choose_action(timestep, **kwargs)
            action = out_agent.pop('raw_action')
            timestep = env.step(action)
            # Keep whatever else the agent returned alongside the environment info.
            timestep.info = {**timestep.info, **out_agent}
            traj.add(timestep, action)
            if timestep.last():
                # Episode over: store it and start a new trajectory from a reset.
                D.append(traj)
                traj = Trajectory()
                timestep = env.reset()
                traj.add(timestep, None)
        # A trailing trajectory with only its initial time step carries no data.
        if traj.T > 0:
            D.append(traj)
        if not self.reset_on_call:
            self.observation = timestep.observation
        return D

0 comments on commit 9211b8e

Please sign in to comment.