Skip to content

Commit

Permalink
Add tests for deterministic policy eval
Browse files Browse the repository at this point in the history
  • Loading branch information
maciejwolczyk authored and mergify-bot committed Jul 3, 2020
1 parent f148dc6 commit 964caa4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tests/fixtures/policies/dummy_policy.py
Expand Up @@ -17,6 +17,7 @@ def __init__(
self,
env_spec,
):
# pylint: disable=super-init-not-called
self._env_spec = env_spec
self._param = []
self._param_values = np.random.uniform(-1, 1, 1000)
Expand All @@ -32,7 +33,7 @@ def get_action(self, observation):
dict: Distribution parameters.
"""
return self.action_space.sample(), dict(dummy='dummy')
return self.action_space.sample(), dict(dummy='dummy', mean=0.)

def get_actions(self, observations):
"""Get multiple actions from this policy for the input observations.
Expand Down
7 changes: 7 additions & 0 deletions tests/garage/sampler/test_utils.py
Expand Up @@ -32,6 +32,13 @@ def test_does_flatten(self):
assert path['observations'][0].shape == (16, )
assert path['actions'][0].shape == (2, 2)

def test_deterministic_action(self):
path = utils.rollout(self.env,
self.policy,
max_path_length=5,
deterministic=True)
assert (path['actions'] == 0.).all()


class TestTruncatePaths:

Expand Down

0 comments on commit 964caa4

Please sign in to comment.