From 0643db5ac215460258a1b00a4fa3ae3a0fe31568 Mon Sep 17 00:00:00 2001 From: tsan Date: Fri, 1 Mar 2019 23:41:58 -0800 Subject: [PATCH] Add test --- .../replay_buffer/test_replay_buffer.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 tests/garage/replay_buffer/test_replay_buffer.py diff --git a/tests/garage/replay_buffer/test_replay_buffer.py b/tests/garage/replay_buffer/test_replay_buffer.py new file mode 100644 index 0000000000..614d8f0175 --- /dev/null +++ b/tests/garage/replay_buffer/test_replay_buffer.py @@ -0,0 +1,20 @@ +import unittest + +from garage.replay_buffer import SimpleReplayBuffer +from tests.fixtures.envs.dummy import DummyDiscreteEnv + + +class TestReplayBuffer(unittest.TestCase): + def test_replay_buffer_dtype(self): + env = DummyDiscreteEnv() + obs = env.reset() + replay_buffer = SimpleReplayBuffer( + env_spec=env, size_in_transitions=100, time_horizon=1) + replay_buffer.add_transition( + observation=[obs], action=[env.action_space.sample()]) + sample = replay_buffer.sample(1) + sample_obs = sample['observation'] + sample_action = sample['action'] + + assert sample_obs.dtype == env.observation_space.dtype + assert sample_action.dtype == env.action_space.dtype