
Commit

Fix a typo to allow evaluating algos deterministically (#1617) (#1714)
* Fix deterministic policy evaluation
* Add tests for deterministic policy eval
* Fix formatting in dummy policy init

Co-authored-by: Maciej Wołczyk <raihid888@gmail.com>
2 people authored and irisliucy committed Aug 18, 2020
1 parent 1c5e74a commit 0c03385
Showing 2 changed files with 47 additions and 37 deletions.
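For context, the typo in question is a one-character name mix-up inside rollout(): agent_info is the dict returned by the current get_action() call, while agent_infos is the list that accumulates those dicts over the episode. Checking 'mean' in agent_infos tests membership of a string against whole dicts in a list, which is never true, so the deterministic branch was silently skipped. A minimal standalone sketch of the difference (plain Python; the numeric values are made up for illustration):

    agent_info = {'mean': 0.0, 'log_std': -1.0}  # info dict for the current step
    agent_infos = [agent_info]                   # list that accumulates past info dicts

    print('mean' in agent_infos)  # False: compares the string against whole dicts
    print('mean' in agent_info)   # True: checks the keys of the per-step dict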
77 changes: 40 additions & 37 deletions src/garage/sampler/utils.py
@@ -1,5 +1,7 @@
 """Utility functions related to sampling."""
 
+import time
+
 import numpy as np
 
 from garage.misc import tensor_utils
@@ -8,24 +10,22 @@
 def rollout(env,
             agent,
             *,
-            max_episode_length=np.inf,
+            max_path_length=np.inf,
             animated=False,
             speedup=1,
             deterministic=False):
-    """Sample a single episode of the agent in the environment.
+    """Sample a single rollout of the agent in the environment.
     Args:
-        agent (Policy): Agent used to select actions.
-        env (Environment): Environment to perform actions in.
-        max_episode_length (int): If the episode reaches this many timesteps,
-            it is truncated.
-        animated (bool): If true, render the environment after each step.
-        speedup (float): Factor by which to decrease the wait time between
+        agent(Policy): Agent used to select actions.
+        env(gym.Env): Environment to perform actions in.
+        max_path_length(int): If the rollout reaches this many timesteps, it is
+            terminated.
+        animated(bool): If true, render the environment after each step.
+        speedup(float): Factor by which to decrease the wait time between
             rendered steps. Only relevant, if animated == true.
         deterministic (bool): If true, use the mean action returned by the
             stochastic policy instead of sampling from the returned action
            distribution.
     Returns:
         dict[str, np.ndarray or dict]: Dictionary, with keys:
             * observations(np.array): Flattened array of observations.
@@ -43,47 +43,53 @@ def rollout(env,
             * env_infos(Dict[str, np.array]): Dictionary of stacked,
                 non-flattened `env_info` arrays.
             * dones(np.array): Array of termination signals.
     """
-    del speedup
-    env_steps = []
-    agent_infos = []
     observations = []
-    last_obs = env.reset()[0]
+    actions = []
+    rewards = []
+    agent_infos = []
+    env_infos = []
+    dones = []
+    o = env.reset()
     agent.reset()
-    episode_length = 0
+    path_length = 0
     if animated:
-        env.visualize()
-    while episode_length < (max_episode_length or np.inf):
-        a, agent_info = agent.get_action(last_obs)
-        if deterministic and 'mean' in agent_info:
+        env.render()
+    while path_length < (max_path_length or np.inf):
+        o = env.observation_space.flatten(o)
+        a, agent_info = agent.get_action(o)
+        if deterministic and 'mean' in agent_infos:
             a = agent_info['mean']
-        es = env.step(a)
-        env_steps.append(es)
-        observations.append(last_obs)
+        next_o, r, d, env_info = env.step(a)
+        observations.append(o)
+        rewards.append(r)
+        actions.append(a)
         agent_infos.append(agent_info)
-        episode_length += 1
-        if es.last:
+        env_infos.append(env_info)
+        dones.append(d)
+        path_length += 1
+        if d:
             break
-        last_obs = es.observation
+        o = next_o
+        if animated:
+            env.render()
+            timestep = 0.05
+            time.sleep(timestep / speedup)
 
     return dict(
         observations=np.array(observations),
-        actions=np.array([es.action for es in env_steps]),
-        rewards=np.array([es.reward for es in env_steps]),
+        actions=np.array(actions),
+        rewards=np.array(rewards),
         agent_infos=tensor_utils.stack_tensor_dict_list(agent_infos),
-        env_infos=tensor_utils.stack_tensor_dict_list(
-            [es.env_info for es in env_steps]),
-        dones=np.array([es.terminal for es in env_steps]),
+        env_infos=tensor_utils.stack_tensor_dict_list(env_infos),
+        dones=np.array(dones),
     )
 
 
 def truncate_paths(paths, max_samples):
     """Truncate the paths so that the total number of samples is max_samples.
     This is done by removing extra paths at the end of
     the list, and make the last path shorter if necessary
     Args:
         paths (list[dict[str, np.ndarray]]): Samples, items with keys:
             * observations (np.ndarray): Enviroment observations
@@ -92,15 +98,12 @@ def truncate_paths(paths, max_samples):
             * env_infos (dict): Environment state information
             * agent_infos (dict): Agent state information
         max_samples(int) : Maximum number of samples allowed.
     Returns:
         list[dict[str, np.ndarray]]: A list of paths, truncated so that the
             number of samples adds up to max-samples
     Raises:
         ValueError: If key a other than 'observations', 'actions', 'rewards',
             'env_infos' and 'agent_infos' is found.
     """
     # chop samples collected by extra paths
     # make a copy
@@ -127,4 +130,4 @@ def truncate_paths(paths, max_samples):
                     'Unexpected key {} found in path. Valid keys: {}'.format(
                         k, valid_keys))
         paths.append(truncated_last_path)
-    return paths
+    return paths
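To make the deterministic branch concrete, the following is a small self-contained sketch of the action-selection step that rollout() performs when deterministic=True. DummyPolicy and select_action are illustrative stand-ins written for this note, not classes or helpers from garage; the only assumption carried over from the diff is that a stochastic policy's get_action() returns an action together with an info dict that exposes a 'mean' entry.

    import numpy as np


    class DummyPolicy:
        """Stand-in stochastic policy whose mean action is always zero."""

        def get_action(self, observation):
            mean = np.zeros(1)
            noisy_action = mean + np.random.normal(size=1)  # sampled action
            return noisy_action, {'mean': mean}


    def select_action(policy, observation, deterministic=False):
        """Mirror the branch inside rollout(): fall back to the mean action."""
        action, agent_info = policy.get_action(observation)
        if deterministic and 'mean' in agent_info:  # per-step dict, not the list
            action = agent_info['mean']
        return action


    policy = DummyPolicy()
    assert (select_action(policy, np.zeros(1), deterministic=True) == 0.).all()

The unit test added below exercises the same idea end to end: with a dummy policy whose mean action is presumably zero, every entry of path['actions'] is expected to be exactly 0 once deterministic=True actually takes effect.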
7 changes: 7 additions & 0 deletions tests/garage/sampler/test_utils.py
@@ -30,6 +30,13 @@ def test_deterministic_action(self):
                              deterministic=True)
         assert (path['actions'] == 0.).all()
 
+    def test_deterministic_action(self):
+        path = utils.rollout(self.env,
+                             self.policy,
+                             max_path_length=5,
+                             deterministic=True)
+        assert (path['actions'] == 0.).all()
+
 
 class TestTruncatePaths:
 
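Assuming a standard pytest-based setup for this repository (the bare asserts and test_* methods suggest pytest, but the exact invocation is not part of this commit), the new check could be run in isolation with something like:

    import pytest

    # Run only the deterministic-rollout test touched by this commit.
    pytest.main(['tests/garage/sampler/test_utils.py', '-k', 'deterministic'])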
