diff --git a/garage/envs/wrappers/grayscale.py b/garage/envs/wrappers/grayscale.py index ab83864cf6..11a72e9159 100644 --- a/garage/envs/wrappers/grayscale.py +++ b/garage/envs/wrappers/grayscale.py @@ -1,8 +1,11 @@ """Grayscale wrapper for gym.Env.""" +import warnings + import gym from gym.spaces import Box import numpy as np from skimage import color +from skimage import img_as_ubyte class Grayscale(gym.Wrapper): @@ -40,11 +43,13 @@ def __init__(self, env): _low = env.observation_space.low.flatten()[0] _high = env.observation_space.high.flatten()[0] + assert _low == 0 + assert _high == 255 self._observation_space = Box( _low, _high, shape=env.observation_space.shape[:-1], - dtype=np.float32) + dtype=np.uint8) @property def observation_space(self): @@ -56,8 +61,13 @@ def observation_space(self, observation_space): self._observation_space = observation_space def _observation(self, obs): - obs = color.rgb2gray(np.asarray(obs, dtype=np.uint8)) - return obs + with warnings.catch_warnings(): + """ + Suppressing warning for possible precision loss + when converting from float64 to uint8 + """ + warnings.simplefilter("ignore") + return img_as_ubyte(color.rgb2gray((obs))) def reset(self): """gym.Env reset function.""" diff --git a/garage/envs/wrappers/resize.py b/garage/envs/wrappers/resize.py index aa94d90aab..ea15b06c75 100644 --- a/garage/envs/wrappers/resize.py +++ b/garage/envs/wrappers/resize.py @@ -1,7 +1,10 @@ """Resize wrapper for gym.Env.""" +import warnings + import gym from gym.spaces import Box import numpy as np +from skimage import img_as_ubyte from skimage.transform import resize @@ -40,8 +43,9 @@ def __init__(self, env, width, height): _low = env.observation_space.low.flatten()[0] _high = env.observation_space.high.flatten()[0] + self._dtype = env.observation_space.dtype self._observation_space = Box( - _low, _high, shape=[width, height], dtype=np.float32) + _low, _high, shape=[width, height], dtype=self._dtype) self._width = width self._height = height @@ -56,7 +60,17 @@ def observation_space(self, observation_space): self._observation_space = observation_space def _observation(self, obs): - return resize(obs, (self._width, self._height)) + with warnings.catch_warnings(): + """ + Suppressing warnings for + 1. possible precision loss when converting from float64 to uint8 + 2. anti-aliasing will be enabled by default in skimage 0.15 + """ + warnings.simplefilter("ignore") + obs = resize(obs, (self._width, self._height)) # now it's float + if self._dtype == np.uint8: + obs = img_as_ubyte(obs) + return obs def reset(self): """gym.Env reset function.""" diff --git a/garage/envs/wrappers/stack_frames.py b/garage/envs/wrappers/stack_frames.py index d6087f0e45..73f33b7840 100644 --- a/garage/envs/wrappers/stack_frames.py +++ b/garage/envs/wrappers/stack_frames.py @@ -40,7 +40,10 @@ def __init__(self, env, n_frames): _low = env.observation_space.low.flatten()[0] _high = env.observation_space.high.flatten()[0] self._observation_space = Box( - _low, _high, shape=new_obs_space_shape, dtype=np.float32) + _low, + _high, + shape=new_obs_space_shape, + dtype=env.observation_space.dtype) @property def observation_space(self): diff --git a/garage/replay_buffer/base.py b/garage/replay_buffer/base.py index d2c9772c6e..311692740d 100644 --- a/garage/replay_buffer/base.py +++ b/garage/replay_buffer/base.py @@ -18,7 +18,11 @@ class ReplayBuffer(metaclass=abc.ABCMeta): """Abstract class for Replay Buffer.""" - def __init__(self, env_spec, size_in_transitions, time_horizon): + def __init__(self, + env_spec, + size_in_transitions, + time_horizon, + dtype=np.float32): """ Initialize the data used in ReplayBuffer. @@ -33,6 +37,7 @@ def __init__(self, env_spec, size_in_transitions, time_horizon): self._initialized_buffer = False self._buffer = {} self._episode_buffer = {} + self._dtype = dtype def store_episode(self): """Add an episode to the buffer.""" @@ -66,7 +71,8 @@ def _initialize_buffer(self, **kwargs): for key, value in kwargs.items(): self._episode_buffer[key] = list() self._buffer[key] = np.zeros( - [self._size, self._time_horizon, *np.array(value).shape[1:]]) + [self._size, self._time_horizon, *np.array(value).shape[1:]], + dtype=self._dtype) self._initialized_buffer = True def _get_storage_idx(self, size_increment=1): diff --git a/garage/spaces/box.py b/garage/spaces/box.py index c1fbd813cd..7c6e98bc13 100644 --- a/garage/spaces/box.py +++ b/garage/spaces/box.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np from garage.spaces import Space @@ -9,13 +11,18 @@ class Box(Space): I.e., each coordinate is bounded. """ - def __init__(self, low, high, shape=None): + def __init__(self, low, high, shape=None, dtype=np.float32): """ Two kinds of valid input: Box(-1.0, 1.0, (3,4)) # low and high are scalars, and shape is provided Box(np.array([-1.0,-2.0]), np.array([2.0,4.0])) # low and high are arrays of the same shape + + If dtype is not specified, we assume dtype to be np.float32, + but when low=0 and high=255, it is very likely to be np.uint8. + We autodetect this case and warn user. It is different from gym.Box, + where they warn user as long as dtype is not specified. """ if shape is None: assert low.shape == high.shape @@ -26,10 +33,25 @@ def __init__(self, low, high, shape=None): self.low = low + np.zeros(shape) self.high = high + np.zeros(shape) + if (self.low == 0).all() and ( + self.high == 255).all() and dtype != np.uint8: + warnings.warn("Creating a garage.spaces.Box with low=0, high=255 " + "and dtype=np.float32.") + + self.dtype = dtype + def sample(self): - return np.random.uniform( - low=self.low, high=self.high, - size=self.low.shape).astype(np.float32) + if self.dtype == np.uint8: + # since np.random.randint() does not accept array as input + low = np.take(self.low, 0) + high = np.take(self.high, 0) + return np.random.randint( + low=low, high=high + 1, size=self.low.shape).astype( + self.dtype, copy=False) + else: + return np.random.uniform( + low=self.low, high=self.high, size=self.low.shape).astype( + self.dtype, copy=False) def contains(self, x): return x.shape == self.shape and (x >= self.low).all() and ( diff --git a/garage/tf/envs/base.py b/garage/tf/envs/base.py index fadb7e4726..20ffc9f979 100644 --- a/garage/tf/envs/base.py +++ b/garage/tf/envs/base.py @@ -42,7 +42,7 @@ def _to_garage_space(self, space): space (garage.tf.spaces) """ if isinstance(space, GymBox): - return Box(low=space.low, high=space.high) + return Box(low=space.low, high=space.high, dtype=space.dtype) elif isinstance(space, GymDict): return Dict(space.spaces) elif isinstance(space, GymDiscrete): diff --git a/garage/tf/spaces/box.py b/garage/tf/spaces/box.py index 281ae2eb29..c33ae1d35c 100644 --- a/garage/tf/spaces/box.py +++ b/garage/tf/spaces/box.py @@ -18,19 +18,10 @@ def new_tensor_variable(self, name, extra_dims, flatten=True): """ if flatten: return tf.placeholder( - tf.float32, + self.dtype, shape=[None] * extra_dims + [self.flat_dim], name=name) return tf.placeholder( - tf.float32, + self.dtype, shape=[None] * extra_dims + list(self.shape), name=name) - - @property - def dtype(self): - """ - Return the Tensor element's type. - - :return: data type of the Tensor element - """ - return tf.float32 diff --git a/tests/fixtures/envs/dummy/__init__.py b/tests/fixtures/envs/dummy/__init__.py index 5bf1d9bc62..b87513b143 100644 --- a/tests/fixtures/envs/dummy/__init__.py +++ b/tests/fixtures/envs/dummy/__init__.py @@ -1,5 +1,12 @@ +from tests.fixtures.envs.dummy.base import DummyEnv from tests.fixtures.envs.dummy.dummy_box_env import DummyBoxEnv from tests.fixtures.envs.dummy.dummy_dict_env import DummyDictEnv +from tests.fixtures.envs.dummy.dummy_discrete_2d_env import DummyDiscrete2DEnv from tests.fixtures.envs.dummy.dummy_discrete_env import DummyDiscreteEnv +from tests.fixtures.envs.dummy.dummy_discrete_pixel_env import ( + DummyDiscretePixelEnv) -__all__ = ["DummyBoxEnv", "DummyDictEnv", "DummyDiscreteEnv"] +__all__ = [ + "DummyEnv", "DummyBoxEnv", "DummyDictEnv", "DummyDiscrete2DEnv", + "DummyDiscreteEnv", "DummyDiscretePixelEnv" +] diff --git a/tests/fixtures/envs/dummy/base.py b/tests/fixtures/envs/dummy/base.py new file mode 100644 index 0000000000..a70b472fbd --- /dev/null +++ b/tests/fixtures/envs/dummy/base.py @@ -0,0 +1,27 @@ +import gym + + +class DummyEnv(gym.Env): + """Base dummy environment.""" + + def __init__(self, random): + self.random = random + self.state = None + + @property + def observation_space(self): + """Return an observation space.""" + raise NotImplementedError + + @property + def action_space(self): + """Return an action space.""" + raise NotImplementedError + + def reset(self): + """Reset the environment.""" + raise NotImplementedError + + def step(self, action): + """Step the environment.""" + raise NotImplementedError diff --git a/tests/fixtures/envs/dummy/dummy_box_env.py b/tests/fixtures/envs/dummy/dummy_box_env.py index f335fd473f..53c6181a6c 100644 --- a/tests/fixtures/envs/dummy/dummy_box_env.py +++ b/tests/fixtures/envs/dummy/dummy_box_env.py @@ -1,15 +1,19 @@ import gym import numpy as np +from tests.fixtures.envs.dummy import DummyEnv -class DummyBoxEnv(gym.Env): + +class DummyBoxEnv(DummyEnv): """A dummy box environment.""" + def __init__(self, random=True): + super().__init__(random) + @property def observation_space(self): """Return an observation space.""" - return gym.spaces.Box( - low=-np.inf, high=np.inf, shape=(1, ), dtype=np.float32) + return gym.spaces.Box(low=-1, high=1, shape=(1, ), dtype=np.float32) @property def action_space(self): @@ -23,4 +27,4 @@ def reset(self): def step(self, action): """Step the environment.""" - return np.zeros(1), 0, True, dict() + return self.observation_space.sample(), 0, True, dict() diff --git a/tests/fixtures/envs/dummy/dummy_dict_env.py b/tests/fixtures/envs/dummy/dummy_dict_env.py index d7d8766df2..355d87a574 100644 --- a/tests/fixtures/envs/dummy/dummy_dict_env.py +++ b/tests/fixtures/envs/dummy/dummy_dict_env.py @@ -1,10 +1,15 @@ import gym import numpy as np +from tests.fixtures.envs.dummy import DummyEnv -class DummyDictEnv(gym.Env): + +class DummyDictEnv(DummyEnv): """A dummy dict environment.""" + def __init__(self, random=True): + super().__init__(random) + @property def observation_space(self): """Return an observation space.""" diff --git a/tests/fixtures/envs/dummy/dummy_discrete_2d_env.py b/tests/fixtures/envs/dummy/dummy_discrete_2d_env.py new file mode 100644 index 0000000000..9b174bfd45 --- /dev/null +++ b/tests/fixtures/envs/dummy/dummy_discrete_2d_env.py @@ -0,0 +1,40 @@ +import gym +import numpy as np + +from tests.fixtures.envs.dummy import DummyEnv + + +class DummyDiscrete2DEnv(DummyEnv): + """A dummy discrete environment.""" + + def __init__(self, random=True): + super().__init__(random) + + @property + def observation_space(self): + """Return an observation space.""" + self.shape = (2, 2) + return gym.spaces.Box( + low=-1, high=1, shape=self.shape, dtype=np.float32) + + @property + def action_space(self): + """Return an action space.""" + return gym.spaces.Discrete(2) + + def reset(self): + """Reset the environment.""" + self.state = np.zeros(self.shape) + return self.state + + def step(self, action): + """Step the environment.""" + if self.state is not None: + if self.random: + obs = self.observation_space.sample() + else: + obs = self.state + action / 10. + else: + raise RuntimeError( + "DummyEnv: reset() must be called before step()!") + return obs, 0, True, dict() diff --git a/tests/fixtures/envs/dummy/dummy_discrete_env.py b/tests/fixtures/envs/dummy/dummy_discrete_env.py index 7d6d7f631e..735f095f3b 100644 --- a/tests/fixtures/envs/dummy/dummy_discrete_env.py +++ b/tests/fixtures/envs/dummy/dummy_discrete_env.py @@ -1,15 +1,19 @@ import gym import numpy as np +from tests.fixtures.envs.dummy import DummyEnv -class DummyDiscreteEnv(gym.Env): + +class DummyDiscreteEnv(DummyEnv): """A dummy discrete environment.""" + def __init__(self, random=True): + super().__init__(random) + @property def observation_space(self): """Return an observation space.""" - return gym.spaces.Box( - low=-np.inf, high=np.inf, shape=(1, ), dtype=np.float32) + return gym.spaces.Box(low=-1, high=1, shape=(1, ), dtype=np.float32) @property def action_space(self): @@ -18,8 +22,17 @@ def action_space(self): def reset(self): """Reset the environment.""" - return np.zeros(1) + self.state = np.zeros(1) + return self.state def step(self, action): """Step the environment.""" - return np.zeros(1), 0, True, dict() + if self.state is not None: + if self.random: + obs = self.observation_space.sample() + else: + obs = self.state + action / 10. + else: + raise RuntimeError( + "DummyEnv: reset() must be called before step()!") + return obs, 0, True, dict() diff --git a/tests/fixtures/envs/dummy/dummy_discrete_pixel_env.py b/tests/fixtures/envs/dummy/dummy_discrete_pixel_env.py new file mode 100644 index 0000000000..9e23528be7 --- /dev/null +++ b/tests/fixtures/envs/dummy/dummy_discrete_pixel_env.py @@ -0,0 +1,49 @@ +import gym +import numpy as np + +from tests.fixtures.envs.dummy import DummyEnv + + +class DummyDiscretePixelEnv(DummyEnv): + """A dummy discrete environment.""" + + def __init__(self, random=True): + super().__init__(random) + + @property + def observation_space(self): + """Return an observation space.""" + self.shape = (10, 10, 3) + self._observation_space = gym.spaces.Box( + low=0, high=255, shape=self.shape, dtype=np.uint8) + return self._observation_space + + @property + def action_space(self): + """Return an action space.""" + return gym.spaces.Discrete(2) + + def reset(self): + """Reset the environment.""" + self.state = np.zeros(self.shape, dtype=np.uint8) + return self.state + + def step(self, action): + """ + Step the environment. + + Before gym fixed overflow issue for sample() in + np.uint8 environment, we will handle the sampling here. + We need high=256 since np.random.uniform sample from [low, high) + (includes low, but excludes high). + """ + if self.state is not None: + if self.random: + obs = np.random.uniform( + low=0, high=256, size=self.shape).astype(np.uint8) + else: + obs = self.state + action + else: + raise RuntimeError( + "DummyEnv: reset() must be called before step()!") + return obs, 0, True, dict() diff --git a/tests/garage/envs/test_grayscale_env.py b/tests/garage/envs/test_grayscale_env.py index 4cff4834be..670e5afa78 100644 --- a/tests/garage/envs/test_grayscale_env.py +++ b/tests/garage/envs/test_grayscale_env.py @@ -1,5 +1,4 @@ import unittest -from unittest import mock from gym.spaces import Box from gym.spaces import Discrete @@ -7,25 +6,15 @@ from garage.envs.wrappers import Grayscale from garage.misc.overrides import overrides +from garage.tf.envs import TfEnv +from tests.fixtures.envs.dummy import DummyDiscretePixelEnv class TestGrayscale(unittest.TestCase): @overrides def setUp(self): - self.shape = (50, 50, 3) - self.env = mock.Mock() - self.env.observation_space = Box( - low=0, high=255, shape=self.shape, dtype=np.uint8) - self.env.reset.return_value = np.zeros(self.shape) - self.env.step.side_effect = self._step - - self.env_g = Grayscale(self.env) - - self.obs = self.env.reset() - self.obs_g = self.env_g.reset() - - def _step(self, action): - return np.full(self.shape, 125), 0, False, dict() + self.env = TfEnv(DummyDiscretePixelEnv(random=False)) + self.env_g = TfEnv(Grayscale(DummyDiscretePixelEnv(random=False))) def test_gray_scale_invalid_environment_type(self): with self.assertRaises(ValueError): @@ -39,7 +28,8 @@ def test_gray_scale_invalid_environment_shape(self): Grayscale(self.env) def test_grayscale_observation_space(self): - assert self.env_g.observation_space.shape == self.shape[:-1] + assert self.env_g.observation_space.shape == ( + self.env.observation_space.shape[:-1]) def test_grayscale_reset(self): """ @@ -51,14 +41,18 @@ def test_grayscale_reset(self): Reference: http://scikit-image.org/docs/dev/api/skimage.color.html#skimage.color.rgb2grey """ - gray_scale_output = np.dot(self.obs[:, :, :3], - [0.2125, 0.7154, 0.0721]) / 255.0 - np.testing.assert_array_almost_equal(gray_scale_output, self.obs_g) + gray_scale_output = np.round( + np.dot(self.env.reset()[:, :, :3], + [0.2125, 0.7154, 0.0721])).astype(np.uint8) + np.testing.assert_array_almost_equal(gray_scale_output, + self.env_g.reset()) def test_grayscale_step(self): - obs, _, _, _ = self.env.step(0) - obs_g, _, _, _ = self.env_g.step(0) + self.env.reset() + self.env_g.reset() + obs, _, _, _ = self.env.step(1) + obs_g, _, _, _ = self.env_g.step(1) - gray_scale_output = np.dot(obs[:, :, :3], - [0.2125, 0.7154, 0.0721]) / 255.0 + gray_scale_output = np.round( + np.dot(obs[:, :, :3], [0.2125, 0.7154, 0.0721])).astype(np.uint8) np.testing.assert_array_almost_equal(gray_scale_output, obs_g) diff --git a/tests/garage/envs/test_repeat_action_env.py b/tests/garage/envs/test_repeat_action_env.py index 913bf7d3b9..d2c432f63c 100644 --- a/tests/garage/envs/test_repeat_action_env.py +++ b/tests/garage/envs/test_repeat_action_env.py @@ -1,41 +1,26 @@ import unittest -from unittest import mock -from gym.spaces import Box import numpy as np from garage.envs.wrappers import RepeatAction from garage.misc.overrides import overrides +from garage.tf.envs import TfEnv +from tests.fixtures.envs.dummy import DummyDiscreteEnv class TestRepeatAction(unittest.TestCase): @overrides def setUp(self): - self.shape = (16, ) - self.env = mock.Mock() - self.env.observation_space = Box( - low=0, high=255, shape=self.shape, dtype=np.float32) - self.env.reset.return_value = np.zeros(self.shape) - self.env.step.side_effect = self._step - - self.env_r = RepeatAction(self.env, n_frame_to_repeat=4) - - self.obs = self.env.reset() - self.obs_r = self.env_r.reset() - - def _step(self, action): - def generate(): - for i in range(0, 255): - yield np.full(self.shape, i) - - generator = generate() - - return next(generator), 0, False, dict() + self.env = TfEnv(DummyDiscreteEnv(random=False)) + self.env_r = TfEnv( + RepeatAction(DummyDiscreteEnv(random=False), n_frame_to_repeat=4)) def test_repeat_action_reset(self): - np.testing.assert_array_equal(self.obs, self.obs_r) + np.testing.assert_array_equal(self.env.reset(), self.env_r.reset()) def test_repeat_action_step(self): + self.env.reset() + self.env_r.reset() obs_repeat, _, _, _ = self.env_r.step(1) for i in range(4): obs, _, _, _ = self.env.step(1) diff --git a/tests/garage/envs/test_resize_env.py b/tests/garage/envs/test_resize_env.py index 54b720e5ed..a604b284e2 100644 --- a/tests/garage/envs/test_resize_env.py +++ b/tests/garage/envs/test_resize_env.py @@ -1,5 +1,4 @@ import unittest -from unittest import mock from gym.spaces import Box from gym.spaces import Discrete @@ -7,47 +6,37 @@ from garage.envs.wrappers import Resize from garage.misc.overrides import overrides +from garage.tf.envs import TfEnv +from tests.fixtures.envs.dummy import DummyDiscrete2DEnv class TestResize(unittest.TestCase): @overrides def setUp(self): - self.shape = (50, 50) - self.env = mock.Mock() - self.env.observation_space = Box( - low=0, high=255, shape=self.shape, dtype=np.uint8) - self.env.reset.return_value = np.zeros(self.shape) - self.env.step.side_effect = self._step - - self._width = 16 - self._height = 16 - self.env_r = Resize(self.env, width=self._width, height=self._height) - - self.obs = self.env.reset() - self.obs_r = self.env_r.reset() - - def _step(self, action): - return np.full(self.shape, 125), 0, False, dict() + self.width = 16 + self.height = 16 + self.env = TfEnv(DummyDiscrete2DEnv()) + self.env_r = TfEnv( + Resize(DummyDiscrete2DEnv(), width=self.width, height=self.height)) def test_resize_invalid_environment_type(self): with self.assertRaises(ValueError): self.env.observation_space = Discrete(64) - Resize(self.env, width=self._width, height=self._height) + Resize(self.env, width=self.width, height=self.height) def test_resize_invalid_environment_shape(self): with self.assertRaises(ValueError): self.env.observation_space = Box( low=0, high=255, shape=(4, ), dtype=np.uint8) - Resize(self.env, width=self._width, height=self._height) + Resize(self.env, width=self.width, height=self.height) def test_resize_output_observation_space(self): - assert self.env_r.observation_space.shape == (self._width, - self._height) + assert self.env_r.observation_space.shape == (self.width, self.height) def test_resize_output_reset(self): - assert self.obs_r.shape == (self._width, self._height) + assert self.env_r.reset().shape == (self.width, self.height) def test_resize_output_step(self): - obs_r, _, _, _ = self.env_r.step(0) - - assert obs_r.shape == (self._width, self._height) + self.env_r.reset() + obs_r, _, _, _ = self.env_r.step(1) + assert obs_r.shape == (self.width, self.height) diff --git a/tests/garage/envs/test_stack_frames_env.py b/tests/garage/envs/test_stack_frames_env.py index e2285a6761..a2f48ec6c4 100644 --- a/tests/garage/envs/test_stack_frames_env.py +++ b/tests/garage/envs/test_stack_frames_env.py @@ -1,5 +1,4 @@ import unittest -from unittest import mock from gym.spaces import Box from gym.spaces import Discrete @@ -7,34 +6,19 @@ from garage.envs.wrappers import StackFrames from garage.misc.overrides import overrides +from garage.tf.envs import TfEnv +from tests.fixtures.envs.dummy import DummyDiscrete2DEnv class TestStackFrames(unittest.TestCase): @overrides def setUp(self): - self.shape = (50, 50) - self.env = mock.Mock() - self.env.observation_space = Box( - low=0, high=255, shape=self.shape, dtype=np.uint8) - self.env.reset.return_value = np.zeros(self.shape) - self.env.step.side_effect = self._step - - self._n_frames = 4 - self.env_s = StackFrames(self.env, n_frames=self._n_frames) - - self.obs = self.env.reset() - self.obs_s = self.env_s.reset() - self.frame_width = self.env.observation_space.shape[0] - self.frame_height = self.env.observation_space.shape[1] - - def _step(self, action): - def generate(): - for i in range(0, 255): - yield np.full(self.shape, i) - - generator = generate() - - return next(generator), 0, False, dict() + self.n_frames = 4 + self.env = TfEnv(DummyDiscrete2DEnv(random=False)) + self.env_s = TfEnv( + StackFrames( + DummyDiscrete2DEnv(random=False), n_frames=self.n_frames)) + self.width, self.height = self.env.observation_space.shape def test_stack_frames_invalid_environment_type(self): with self.assertRaises(ValueError): @@ -48,24 +32,25 @@ def test_stack_frames_invalid_environment_shape(self): StackFrames(self.env, n_frames=4) def test_stack_frames_output_observation_space(self): - assert self.env_s.observation_space.shape == (self.frame_width, - self.frame_height, - self._n_frames) + assert self.env_s.observation_space.shape == (self.width, self.height, + self.n_frames) def test_stack_frames_for_reset(self): - frame_stack = self.obs - for i in range(self._n_frames - 1): - frame_stack = np.dstack((frame_stack, self.obs)) + frame_stack = self.env.reset() + for i in range(self.n_frames - 1): + frame_stack = np.dstack((frame_stack, self.env.reset())) - np.testing.assert_array_equal(self.obs_s, frame_stack) + np.testing.assert_array_equal(self.env_s.reset(), frame_stack) def test_stack_frames_for_step(self): - frame_stack = np.empty((self.frame_width, self.frame_height, - self._n_frames)) + self.env.reset() + self.env_s.reset() + + frame_stack = np.empty((self.width, self.height, self.n_frames)) for i in range(10): frame_stack = frame_stack[:, :, 1:] - obs, _, _, _ = self.env.step(0) + obs, _, _, _ = self.env.step(1) frame_stack = np.dstack((frame_stack, obs)) + obs_stack, _, _, _ = self.env_s.step(1) - obs_stack, _, _, _ = self.env_s.step(0) np.testing.assert_array_equal(obs_stack, frame_stack) diff --git a/tests/garage/spaces/test_box.py b/tests/garage/spaces/test_box.py index 8fe3fd8ba5..f357c82ca8 100644 --- a/tests/garage/spaces/test_box.py +++ b/tests/garage/spaces/test_box.py @@ -14,3 +14,37 @@ def test_pickleable(self): assert round_trip.shape == obj.shape assert np.array_equal(round_trip.bounds[0], obj.bounds[0]) assert np.array_equal(round_trip.bounds[1], obj.bounds[1]) + + def test_same_dtype(self): + type1 = np.float32 + box = Box(low=0, high=255, shape=(3, 4), dtype=type1) + assert box.dtype == type1 + + type2 = np.uint8 + box = Box(low=0, high=255, shape=(3, 4), dtype=type2) + assert box.dtype == type2 + + def test_invalid_env(self): + with self.assertRaises(AttributeError): + Box(low=0.0, high=1.0) + + with self.assertRaises(AssertionError): + Box(low=np.array([-1.0, -2.0]), + high=np.array([1.0, 2.0]), + shape=(2, 2)) + + def test_default_float32_env(self): + box = Box(low=0.0, high=1.0, shape=(3, 4)) + assert box.dtype == np.float32 + + box = Box(low=np.array([-1.0, -2.0]), high=np.array([1.0, 2.0])) + assert box.dtype == np.float32 + + def test_uint8_warning_env(self): + with self.assertWarns(UserWarning): + box = Box(low=0, high=255, shape=(3, 4)) + assert box.dtype == np.float32 + + with self.assertWarns(UserWarning): + box = Box(low=np.array([0, 0]), high=np.array([255, 255])) + assert box.dtype == np.float32