-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add TimeStep/Trajectory, StepRunner/EpisodeRunner
- Loading branch information
1 parent
2c280e2
commit 9211b8e
Showing
5 changed files
with
309 additions
and
92 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
from enum import IntEnum | ||
from dataclasses import dataclass | ||
import numpy as np | ||
|
||
|
||
class StepType(IntEnum): | ||
FIRST = 0 | ||
MID = 1 | ||
LAST = 2 | ||
|
||
|
||
@dataclass | ||
class TimeStep: | ||
step_type: StepType | ||
observation: object | ||
reward: float | ||
done: bool | ||
info: dict | ||
|
||
def __getitem__(self, key): | ||
return self.info[key] | ||
|
||
def first(self): | ||
if self.step_type == StepType.FIRST: | ||
assert all([x is None for x in [self.reward, self.done, self.info]]) | ||
return self.step_type == StepType.FIRST | ||
|
||
def mid(self): | ||
if self.step_type == StepType.MID: | ||
assert not self.first() and not self.last() | ||
return self.step_type == StepType.MID | ||
|
||
def last(self): | ||
if self.step_type == StepType.LAST: | ||
assert self.done is not None and self.done | ||
return self.step_type == StepType.LAST | ||
|
||
def time_limit(self): | ||
return self.last() and self.info.get('TimeLimit.truncated', False) | ||
|
||
def terminal(self): | ||
return self.last() and not self.time_limit() | ||
|
||
def __repr__(self): | ||
return f'{self.__class__.__name__}({self.step_type.name})' | ||
|
||
|
||
class Trajectory(object): | ||
def __init__(self): | ||
self.timesteps = [] | ||
self._actions = [] | ||
|
||
def __len__(self): | ||
return len(self.timesteps) | ||
|
||
@property | ||
def T(self): | ||
return max(0, len(self) - 1) | ||
|
||
def __getitem__(self, index): | ||
return self.timesteps[index] | ||
|
||
def __iter__(self): | ||
self.i = 0 | ||
return self | ||
|
||
def __next__(self): | ||
if self.i < len(self): | ||
timestep = self.timesteps[self.i] | ||
self.i += 1 | ||
return timestep | ||
else: | ||
raise StopIteration | ||
|
||
@property | ||
def finished(self): | ||
return len(self) > 0 and self.timesteps[-1].last() | ||
|
||
@property | ||
def reach_time_limit(self): | ||
return len(self) > 0 and self.timesteps[-1].time_limit() | ||
|
||
@property | ||
def reach_terminal(self): | ||
return len(self) > 0 and self.timesteps[-1].terminal() | ||
|
||
def add(self, timestep, action): | ||
assert not self.finished | ||
if len(self) == 0: | ||
assert timestep.first() | ||
assert action is None | ||
else: | ||
assert action is not None | ||
self._actions.append(action) | ||
self.timesteps.append(timestep) | ||
|
||
@property | ||
def observations(self): | ||
return [timestep.observation for timestep in self.timesteps] | ||
|
||
@property | ||
def actions(self): | ||
return self._actions | ||
|
||
@property | ||
def rewards(self): | ||
return [timestep.reward for timestep in self.timesteps[1:]] | ||
|
||
@property | ||
def dones(self): | ||
return [timestep.done for timestep in self.timesteps[1:]] | ||
|
||
@property | ||
def infos(self): | ||
return [timestep.info for timestep in self.timesteps[1:]] | ||
|
||
def get_infos(self, key): | ||
return [timestep.info[key] for timestep in self.timesteps[1:] if key in timestep.info] | ||
|
||
def __repr__(self): | ||
return f'Trajectory(T: {self.T}, Finished: {self.finished}, Reach time limit: {self.reach_time_limit}, Reach terminal: {self.reach_terminal})' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,71 +1,76 @@ | ||
from abc import ABC | ||
from abc import abstractmethod | ||
|
||
from lagom.metric import Trajectory | ||
from lagom.envs import VecEnv | ||
from lagom.envs.wrappers import VecStepInfo | ||
from lagom.data import StepType | ||
from lagom.data import TimeStep | ||
from lagom.data import Trajectory | ||
from lagom.envs.timestep_env import TimeStepEnv | ||
|
||
|
||
class BaseRunner(ABC): | ||
r"""Base class for all runners. | ||
A runner is a data collection interface between the agent and the environment. | ||
For each calling of the runner, the agent will take actions and receive observation | ||
in and from an environment for a certain number of trajectories/segments and a certain | ||
number of time steps. | ||
.. note:: | ||
By default, the agent handles batched data returned from :class:`VecEnv` type of environment. | ||
""" | ||
@abstractmethod | ||
def __call__(self, agent, env, T, **kwargs): | ||
r"""Run the agent in the environment for a number of time steps and collect all necessary interaction data. | ||
def __call__(self, agent, env, **kwargs): | ||
r"""Defines data collection via interactions between the agent and the environment. | ||
Args: | ||
agent (BaseAgent): agent | ||
env (VecEnv): VecEnv type of environment | ||
T (int): number of time steps | ||
env (Env): environment | ||
**kwargs: keyword arguments for more specifications. | ||
""" | ||
pass | ||
|
||
|
||
class EpisodeRunner(BaseRunner): | ||
def __call__(self, agent, env, N, **kwargs): | ||
assert isinstance(env, TimeStepEnv) | ||
D = [] | ||
for _ in range(N): | ||
traj = Trajectory() | ||
timestep = env.reset() | ||
traj.add(timestep, None) | ||
while not timestep.last(): | ||
out_agent = agent.choose_action(timestep, **kwargs) | ||
action = out_agent.pop('raw_action') | ||
timestep = env.step(action) | ||
timestep.info = {**timestep.info, **out_agent} | ||
traj.add(timestep, action) | ||
D.append(traj) | ||
return D | ||
|
||
|
||
class StepRunner(BaseRunner): | ||
def __init__(self, reset_on_call=True): | ||
self.reset_on_call = reset_on_call | ||
self.observation = None | ||
|
||
def __call__(self, agent, env, T, **kwargs): | ||
assert isinstance(env, VecEnv) and isinstance(env, VecStepInfo) and len(env) == 1 | ||
|
||
D = [Trajectory()] | ||
if self.reset_on_call: | ||
observation, _ = env.reset() | ||
def __call__(self, agent, env, T, **kwargs): | ||
assert isinstance(env, TimeStepEnv) | ||
D = [] | ||
traj = Trajectory() | ||
if self.reset_on_call or self.observation is None: | ||
timestep = env.reset() | ||
else: | ||
if self.observation is None: | ||
self.observation, _ = env.reset() | ||
observation = self.observation | ||
D[-1].add_observation(observation) | ||
timestep = TimeStep(StepType.FIRST, observation=self.observation, reward=None, done=None, info=None) | ||
traj.add(timestep, None) | ||
for t in range(T): | ||
out_agent = agent.choose_action(observation, **kwargs) | ||
out_agent = agent.choose_action(timestep, **kwargs) | ||
action = out_agent.pop('raw_action') | ||
next_observation, [reward], [step_info] = env.step(action) | ||
step_info.info = {**step_info.info, **out_agent} | ||
if step_info.last: | ||
D[-1].add_observation([step_info['last_observation']]) # add a batch dim | ||
else: | ||
D[-1].add_observation(next_observation) | ||
D[-1].add_action(action) | ||
D[-1].add_reward(reward) | ||
D[-1].add_step_info(step_info) | ||
if step_info.last: | ||
assert D[-1].completed | ||
D.append(Trajectory()) | ||
D[-1].add_observation(next_observation) # initial observation | ||
observation = next_observation | ||
if len(D[-1]) == 0: | ||
D = D[:-1] | ||
self.observation = observation | ||
timestep = env.step(action) | ||
timestep.info = {**timestep.info, **out_agent} | ||
traj.add(timestep, action) | ||
if timestep.last(): | ||
D.append(traj) | ||
traj = Trajectory() | ||
timestep = env.reset() | ||
traj.add(timestep, None) | ||
if traj.T > 0: | ||
D.append(traj) | ||
if not self.reset_on_call: | ||
self.observation = timestep.observation | ||
return D |
Oops, something went wrong.