Skip to content

Commit

Permalink
Support numpy.uint8 in some classes (#455)
Browse files Browse the repository at this point in the history
- Support numpy.uint8 in some classes
- Grayscale should only work on np.uint8.
- garage.spaces.Box should have dtype np.float32 by default.
- Use dummyEnv for unit test.
  • Loading branch information
ahtsan committed Jan 17, 2019
1 parent afbeee8 commit 582c3c5
Show file tree
Hide file tree
Showing 19 changed files with 319 additions and 141 deletions.
16 changes: 13 additions & 3 deletions garage/envs/wrappers/grayscale.py
@@ -1,8 +1,11 @@
"""Grayscale wrapper for gym.Env."""
import warnings

import gym
from gym.spaces import Box
import numpy as np
from skimage import color
from skimage import img_as_ubyte


class Grayscale(gym.Wrapper):
Expand Down Expand Up @@ -40,11 +43,13 @@ def __init__(self, env):

_low = env.observation_space.low.flatten()[0]
_high = env.observation_space.high.flatten()[0]
assert _low == 0
assert _high == 255
self._observation_space = Box(
_low,
_high,
shape=env.observation_space.shape[:-1],
dtype=np.float32)
dtype=np.uint8)

@property
def observation_space(self):
Expand All @@ -56,8 +61,13 @@ def observation_space(self, observation_space):
self._observation_space = observation_space

def _observation(self, obs):
    """Convert an RGB observation into a uint8 grayscale image.

    Args:
        obs: RGB observation whose last axis holds the color channels.
            The wrapper's observation space is declared as np.uint8, so
            the input is expected to be a uint8 image.

    Returns:
        The grayscale image produced by skimage's rgb2gray, converted
        to uint8 with img_as_ubyte.
    """
    # rgb2gray yields a float64 image; img_as_ubyte warns about possible
    # precision loss when converting it to uint8. The warning is expected
    # here, so it is silenced for the duration of this call only.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return img_as_ubyte(color.rgb2gray(obs))

def reset(self):
"""gym.Env reset function."""
Expand Down
18 changes: 16 additions & 2 deletions garage/envs/wrappers/resize.py
@@ -1,7 +1,10 @@
"""Resize wrapper for gym.Env."""
import warnings

import gym
from gym.spaces import Box
import numpy as np
from skimage import img_as_ubyte
from skimage.transform import resize


Expand Down Expand Up @@ -40,8 +43,9 @@ def __init__(self, env, width, height):

_low = env.observation_space.low.flatten()[0]
_high = env.observation_space.high.flatten()[0]
self._dtype = env.observation_space.dtype
self._observation_space = Box(
_low, _high, shape=[width, height], dtype=np.float32)
_low, _high, shape=[width, height], dtype=self._dtype)

self._width = width
self._height = height
Expand All @@ -56,7 +60,17 @@ def observation_space(self, observation_space):
self._observation_space = observation_space

def _observation(self, obs):
    """Resize an observation to (self._width, self._height).

    Args:
        obs: Image observation to resize.

    Returns:
        The resized observation. skimage's resize produces a float
        image; when the wrapped environment's observation dtype is
        np.uint8 the result is converted back to uint8 so it matches
        the declared observation space.
    """
    with warnings.catch_warnings():
        # Silence two expected skimage warnings for this call only:
        # 1. possible precision loss when converting float64 to uint8
        # 2. anti-aliasing will be enabled by default in skimage 0.15
        warnings.simplefilter("ignore")
        resized = resize(obs, (self._width, self._height))  # float output
        if self._dtype == np.uint8:
            resized = img_as_ubyte(resized)
    return resized

def reset(self):
"""gym.Env reset function."""
Expand Down
5 changes: 4 additions & 1 deletion garage/envs/wrappers/stack_frames.py
Expand Up @@ -40,7 +40,10 @@ def __init__(self, env, n_frames):
_low = env.observation_space.low.flatten()[0]
_high = env.observation_space.high.flatten()[0]
self._observation_space = Box(
_low, _high, shape=new_obs_space_shape, dtype=np.float32)
_low,
_high,
shape=new_obs_space_shape,
dtype=env.observation_space.dtype)

@property
def observation_space(self):
Expand Down
10 changes: 8 additions & 2 deletions garage/replay_buffer/base.py
Expand Up @@ -18,7 +18,11 @@
class ReplayBuffer(metaclass=abc.ABCMeta):
"""Abstract class for Replay Buffer."""

def __init__(self, env_spec, size_in_transitions, time_horizon):
def __init__(self,
env_spec,
size_in_transitions,
time_horizon,
dtype=np.float32):
"""
Initialize the data used in ReplayBuffer.
Expand All @@ -33,6 +37,7 @@ def __init__(self, env_spec, size_in_transitions, time_horizon):
self._initialized_buffer = False
self._buffer = {}
self._episode_buffer = {}
self._dtype = dtype

def store_episode(self):
"""Add an episode to the buffer."""
Expand Down Expand Up @@ -66,7 +71,8 @@ def _initialize_buffer(self, **kwargs):
for key, value in kwargs.items():
self._episode_buffer[key] = list()
self._buffer[key] = np.zeros(
[self._size, self._time_horizon, *np.array(value).shape[1:]])
[self._size, self._time_horizon, *np.array(value).shape[1:]],
dtype=self._dtype)
self._initialized_buffer = True

def _get_storage_idx(self, size_increment=1):
Expand Down
30 changes: 26 additions & 4 deletions garage/spaces/box.py
@@ -1,3 +1,5 @@
import warnings

import numpy as np

from garage.spaces import Space
Expand All @@ -9,13 +11,18 @@ class Box(Space):
I.e., each coordinate is bounded.
"""

def __init__(self, low, high, shape=None):
def __init__(self, low, high, shape=None, dtype=np.float32):
"""
Two kinds of valid input:
Box(-1.0, 1.0, (3,4)) # low and high are scalars, and shape is
provided
Box(np.array([-1.0,-2.0]), np.array([2.0,4.0])) # low and high are
arrays of the same shape
If dtype is not specified, it defaults to np.float32. However, when
low=0 and high=255 the space most likely holds np.uint8 data, so we
detect that case and warn the user. This differs from gym's Box, which
warns whenever dtype is not specified.
"""
if shape is None:
assert low.shape == high.shape
Expand All @@ -26,10 +33,25 @@ def __init__(self, low, high, shape=None):
self.low = low + np.zeros(shape)
self.high = high + np.zeros(shape)

if (self.low == 0).all() and (
self.high == 255).all() and dtype != np.uint8:
warnings.warn("Creating a garage.spaces.Box with low=0, high=255 "
"and dtype=np.float32.")

self.dtype = dtype

def sample(self):
    """Draw a random element of the box, in the box's dtype.

    Returns:
        numpy.ndarray: Array with the same shape as ``self.low``.
        For np.uint8 boxes each entry is an integer drawn uniformly
        from the inclusive range [low, high] of that entry; otherwise
        entries are drawn with np.random.uniform between low and high
        and cast to ``self.dtype``.
    """
    if self.dtype == np.uint8:
        # self.low / self.high are stored as float arrays; randint
        # needs integer bounds. Passing the full arrays (instead of
        # only their first element) honors per-element bounds.
        lower = self.low.astype(np.int64)
        upper = self.high.astype(np.int64)
        # randint's upper bound is exclusive, hence the + 1.
        draw = np.random.randint(
            low=lower, high=upper + 1, size=self.low.shape)
    else:
        draw = np.random.uniform(
            low=self.low, high=self.high, size=self.low.shape)
    return draw.astype(self.dtype, copy=False)

def contains(self, x):
return x.shape == self.shape and (x >= self.low).all() and (
Expand Down
2 changes: 1 addition & 1 deletion garage/tf/envs/base.py
Expand Up @@ -42,7 +42,7 @@ def _to_garage_space(self, space):
space (garage.tf.spaces)
"""
if isinstance(space, GymBox):
return Box(low=space.low, high=space.high)
return Box(low=space.low, high=space.high, dtype=space.dtype)
elif isinstance(space, GymDict):
return Dict(space.spaces)
elif isinstance(space, GymDiscrete):
Expand Down
13 changes: 2 additions & 11 deletions garage/tf/spaces/box.py
Expand Up @@ -18,19 +18,10 @@ def new_tensor_variable(self, name, extra_dims, flatten=True):
"""
if flatten:
return tf.placeholder(
tf.float32,
self.dtype,
shape=[None] * extra_dims + [self.flat_dim],
name=name)
return tf.placeholder(
tf.float32,
self.dtype,
shape=[None] * extra_dims + list(self.shape),
name=name)

@property
def dtype(self):
"""
Return the Tensor element's type.
:return: data type of the Tensor element
"""
return tf.float32
9 changes: 8 additions & 1 deletion tests/fixtures/envs/dummy/__init__.py
@@ -1,5 +1,12 @@
from tests.fixtures.envs.dummy.base import DummyEnv
from tests.fixtures.envs.dummy.dummy_box_env import DummyBoxEnv
from tests.fixtures.envs.dummy.dummy_dict_env import DummyDictEnv
from tests.fixtures.envs.dummy.dummy_discrete_2d_env import DummyDiscrete2DEnv
from tests.fixtures.envs.dummy.dummy_discrete_env import DummyDiscreteEnv
from tests.fixtures.envs.dummy.dummy_discrete_pixel_env import (
DummyDiscretePixelEnv)

__all__ = ["DummyBoxEnv", "DummyDictEnv", "DummyDiscreteEnv"]
__all__ = [
"DummyEnv", "DummyBoxEnv", "DummyDictEnv", "DummyDiscrete2DEnv",
"DummyDiscreteEnv", "DummyDiscretePixelEnv"
]
27 changes: 27 additions & 0 deletions tests/fixtures/envs/dummy/base.py
@@ -0,0 +1,27 @@
import gym


class DummyEnv(gym.Env):
    """Common base for the dummy environments used as test fixtures.

    Stores the ``random`` flag and an initially empty ``state``; every
    part of the environment interface must be supplied by a subclass.
    """

    def __init__(self, random):
        # Flag kept for subclasses to consult when producing observations.
        self.random = random
        # No state exists until a subclass's reset() assigns one.
        self.state = None

    @property
    def observation_space(self):
        """Observation space; must be overridden by subclasses."""
        raise NotImplementedError

    @property
    def action_space(self):
        """Action space; must be overridden by subclasses."""
        raise NotImplementedError

    def reset(self):
        """Reset the environment; must be overridden by subclasses."""
        raise NotImplementedError

    def step(self, action):
        """Advance the environment; must be overridden by subclasses."""
        raise NotImplementedError
12 changes: 8 additions & 4 deletions tests/fixtures/envs/dummy/dummy_box_env.py
@@ -1,15 +1,19 @@
import gym
import numpy as np

from tests.fixtures.envs.dummy import DummyEnv

class DummyBoxEnv(gym.Env):

class DummyBoxEnv(DummyEnv):
"""A dummy box environment."""

def __init__(self, random=True):
super().__init__(random)

@property
def observation_space(self):
"""Return an observation space."""
return gym.spaces.Box(
low=-np.inf, high=np.inf, shape=(1, ), dtype=np.float32)
return gym.spaces.Box(low=-1, high=1, shape=(1, ), dtype=np.float32)

@property
def action_space(self):
Expand All @@ -23,4 +27,4 @@ def reset(self):

def step(self, action):
"""Step the environment."""
return np.zeros(1), 0, True, dict()
return self.observation_space.sample(), 0, True, dict()
7 changes: 6 additions & 1 deletion tests/fixtures/envs/dummy/dummy_dict_env.py
@@ -1,10 +1,15 @@
import gym
import numpy as np

from tests.fixtures.envs.dummy import DummyEnv

class DummyDictEnv(gym.Env):

class DummyDictEnv(DummyEnv):
"""A dummy dict environment."""

def __init__(self, random=True):
super().__init__(random)

@property
def observation_space(self):
"""Return an observation space."""
Expand Down
40 changes: 40 additions & 0 deletions tests/fixtures/envs/dummy/dummy_discrete_2d_env.py
@@ -0,0 +1,40 @@
import gym
import numpy as np

from tests.fixtures.envs.dummy import DummyEnv


class DummyDiscrete2DEnv(DummyEnv):
    """A dummy environment with 2x2 Box observations and discrete actions."""

    def __init__(self, random=True):
        super().__init__(random)
        # Set the shape up front so reset() works even when
        # observation_space has not been accessed yet.
        self.shape = (2, 2)

    @property
    def observation_space(self):
        """Return the observation space: a float32 Box in [-1, 1]."""
        return gym.spaces.Box(
            low=-1, high=1, shape=self.shape, dtype=np.float32)

    @property
    def action_space(self):
        """Return the action space: two discrete actions."""
        return gym.spaces.Discrete(2)

    def reset(self):
        """Reset the environment and return the all-zero initial state."""
        self.state = np.zeros(self.shape)
        return self.state

    def step(self, action):
        """Step the environment.

        Args:
            action: Action to apply. It only affects the observation
                when ``self.random`` is false.

        Returns:
            tuple: ``(observation, reward, done, info)`` where reward is
            always 0, done is always True and info is an empty dict.

        Raises:
            RuntimeError: If called before reset().
        """
        if self.state is None:
            raise RuntimeError(
                "DummyEnv: reset() must be called before step()!")
        if self.random:
            obs = self.observation_space.sample()
        else:
            # Deterministic mode: offset the stored state by the action.
            obs = self.state + action / 10.
        return obs, 0, True, dict()
23 changes: 18 additions & 5 deletions tests/fixtures/envs/dummy/dummy_discrete_env.py
@@ -1,15 +1,19 @@
import gym
import numpy as np

from tests.fixtures.envs.dummy import DummyEnv

class DummyDiscreteEnv(gym.Env):

class DummyDiscreteEnv(DummyEnv):
"""A dummy discrete environment."""

def __init__(self, random=True):
super().__init__(random)

@property
def observation_space(self):
"""Return an observation space."""
return gym.spaces.Box(
low=-np.inf, high=np.inf, shape=(1, ), dtype=np.float32)
return gym.spaces.Box(low=-1, high=1, shape=(1, ), dtype=np.float32)

@property
def action_space(self):
Expand All @@ -18,8 +22,17 @@ def action_space(self):

def reset(self):
"""Reset the environment."""
return np.zeros(1)
self.state = np.zeros(1)
return self.state

def step(self, action):
"""Step the environment."""
return np.zeros(1), 0, True, dict()
if self.state is not None:
if self.random:
obs = self.observation_space.sample()
else:
obs = self.state + action / 10.
else:
raise RuntimeError(
"DummyEnv: reset() must be called before step()!")
return obs, 0, True, dict()

0 comments on commit 582c3c5

Please sign in to comment.