From 964caa419210f71978830bbe553fc37070f73b6b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maciej=20Wo=C5=82czyk?=
Date: Thu, 2 Jul 2020 18:08:13 +0200
Subject: [PATCH] Add tests for deterministic policy eval

---
 tests/fixtures/policies/dummy_policy.py | 3 ++-
 tests/garage/sampler/test_utils.py      | 7 +++++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/tests/fixtures/policies/dummy_policy.py b/tests/fixtures/policies/dummy_policy.py
index 1e1c03e760..8c977edb14 100644
--- a/tests/fixtures/policies/dummy_policy.py
+++ b/tests/fixtures/policies/dummy_policy.py
@@ -17,6 +17,7 @@ def __init__(
             self,
             env_spec,
     ):
+        # pylint: disable=super-init-not-called
         self._env_spec = env_spec
         self._param = []
         self._param_values = np.random.uniform(-1, 1, 1000)
@@ -32,7 +33,7 @@ def get_action(self, observation):
             dict: Distribution parameters.
 
         """
-        return self.action_space.sample(), dict(dummy='dummy')
+        return self.action_space.sample(), dict(dummy='dummy', mean=0.)
 
     def get_actions(self, observations):
         """Get multiple actions from this policy for the input observations.
diff --git a/tests/garage/sampler/test_utils.py b/tests/garage/sampler/test_utils.py
index e3bfe1ddf7..807d9c47e8 100644
--- a/tests/garage/sampler/test_utils.py
+++ b/tests/garage/sampler/test_utils.py
@@ -32,6 +32,13 @@ def test_does_flatten(self):
         assert path['observations'][0].shape == (16, )
         assert path['actions'][0].shape == (2, 2)
 
+    def test_deterministic_action(self):
+        path = utils.rollout(self.env,
+                             self.policy,
+                             max_path_length=5,
+                             deterministic=True)
+        assert (path['actions'] == 0.).all()
+
 
 class TestTruncatePaths: