Skip to content

Commit

Permalink
Compare other infos when test env pickling (#2033)
Browse files Browse the repository at this point in the history
  • Loading branch information
yeukfu committed Sep 11, 2020
1 parent bc8822c commit d8f0235
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/garage/envs/env_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,16 @@ def observation_space(self, observation_space):
"""
self._output_space = observation_space

def __eq__(self, other):
"""See :meth:`object.__eq__`.
Args:
other (EnvSpec): :class:`~EnvSpec` to compare with.
Returns:
bool: Whether these :class:`~EnvSpec` instances are equal.
"""
return (self.observation_space == other.observation_space
and self.action_space == other.action_space)
6 changes: 6 additions & 0 deletions tests/garage/tf/envs/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ class TestGarageEnv:
def test_is_pickleable(self):
env = GarageEnv(env_name='CartPole-v1')
round_trip = pickle.loads(pickle.dumps(env))
assert round_trip.spec == env.spec
assert round_trip.env.spec.id == env.env.spec.id
assert (round_trip.env.spec.max_episode_steps ==
env.env.spec.max_episode_steps)

@pytest.mark.nightly
@pytest.mark.parametrize('spec', list(gym.envs.registry.all()))
Expand All @@ -29,6 +32,9 @@ def test_all_gym_envs_pickleable(self, spec):
if spec._env_name.startswith('Defender'):
pytest.skip(
'Defender-* envs bundled in atari-py 0.2.x don\'t load')
if spec.id == 'KellyCoinflipGeneralized-v0':
pytest.skip(
'KellyCoinflipGeneralized-v0\'s action space is random')
env = GarageEnv(env_name=spec.id)
step_env_with_gym_quirks(env,
spec,
Expand Down
6 changes: 6 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ def step_env_with_gym_quirks(env, spec, n=10, render=True,
if serialize_env:
# Roundtrip serialization
round_trip = pickle.loads(pickle.dumps(env))
assert round_trip.spec == env.spec
assert round_trip.env.spec.id == env.env.spec.id
assert (round_trip.env.spec.max_episode_steps ==
env.env.spec.max_episode_steps)
env = round_trip

env.reset()
Expand All @@ -63,7 +66,10 @@ def step_env_with_gym_quirks(env, spec, n=10, render=True,
if serialize_env:
# Roundtrip serialization
round_trip = pickle.loads(pickle.dumps(env))
assert round_trip.spec == env.spec
assert round_trip.env.spec.id == env.env.spec.id
assert (round_trip.env.spec.max_episode_steps ==
env.env.spec.max_episode_steps)


def convolve(_input, filter_weights, filter_bias, strides, filters,
Expand Down

0 comments on commit d8f0235

Please sign in to comment.