Skip to content

Commit

Permalink
Compatibility with dm_env and libEGL
Browse files Browse the repository at this point in the history
  • Loading branch information
sharif1093 committed Aug 13, 2019
1 parent a1857b9 commit a728c3d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 12 deletions.
10 changes: 5 additions & 5 deletions digideep/environment/dmc2gym/spec2space.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from gym import spaces
# spaces.Dict | spaces.Tuple | spaces.Box | spaces.Discrete | spaces.MultiDiscrete | spaces.MultiBinary
from dm_control.rl import specs
from dm_env import specs
import numpy as np
import collections
# import warnings
Expand All @@ -19,22 +19,22 @@ def spec2space_single(spec):
:obj:`gym.spaces`: The ``gym`` equivalent ``spaces``.
"""

if (type(spec) is specs.BoundedArraySpec) and (spec.dtype == np.int):
if (type(spec) is specs.BoundedArray) and (spec.dtype == np.int):
# Discrete
# warnings.warn("The DMC environment uses a discrete action space!")
if spec.minimum == 0:
return spaces.Discrete(spec.maximum)
else:
raise ValueError("The environment's minimum values must be zero in the Discrete case!")
# Box
elif type(spec) is specs.BoundedArraySpec:
elif type(spec) is specs.BoundedArray:
_min = np.broadcast_to(spec.minimum, shape=spec.shape)
_max = np.broadcast_to(spec.maximum, shape=spec.shape)
# if clip_inf:
# _min = np.clip(_min, -sys.float_info.max, sys.float_info.max)
# _max = np.clip(_max, -sys.float_info.max, sys.float_info.max)
return spaces.Box(_min, _max, dtype=np.float32)
elif type(spec) is specs.ArraySpec:
elif type(spec) is specs.Array:
return spaces.Box(-np.inf, np.inf, shape=spec.shape, dtype=np.float32)
else:
raise ValueError('Unknown spec in spec2space_single!')
Expand All @@ -45,7 +45,7 @@ def spec2space(spec):
Caution:
Currently it supports ``spaces.Discrete``, ``spaces.Box``, and ``spaces.Dict`` as outputs.
"""
if isinstance(spec, specs.ArraySpec) or isinstance(spec, specs.BoundedArraySpec):
if isinstance(spec, specs.Array) or isinstance(spec, specs.BoundedArray):
return spec2space_single(spec)
elif isinstance(spec, collections.OrderedDict):
space = collections.OrderedDict()
Expand Down
17 changes: 13 additions & 4 deletions digideep/environment/dmc2gym/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@
import six

from .spec2space import spec2space
from .viewer import Viewer

from dm_control._render import BACKEND
from dm_control._render import constants
if BACKEND==constants.GLFW:
from .viewer import Viewer
else:
Viewer=None

# TODO: flatten_observation does not work with nested dictionaries.
from dm_control.rl.control import flatten_observation
# from dm_control.rl.control import FLAT_OBSERVATION_KEY
from dm_control.rl.control import PhysicsError
from dm_control.rl import specs
from dm_env import specs


def _spec_from_observation(observation):
Expand All @@ -31,7 +37,7 @@ def _spec_from_observation(observation):
elif isinstance(value, dict):
raise NotImplementedError("'dict' types in observations are not supported as they may not preserve order. Use OrderedDict instead.")
else:
result[key] = specs.ArraySpec(value.shape, value.dtype, name=key)
result[key] = specs.Array(value.shape, value.dtype, name=key)
return result


Expand Down Expand Up @@ -217,7 +223,10 @@ def _get_viewer(self, mode):
if mode == "rgb_array":
self.viewer = self.dmcenv.physics.render
elif mode == "human":
self.viewer = Viewer(dmcenv=self.dmcenv, width=640, height=480)
if Viewer:
self.viewer = Viewer(dmcenv=self.dmcenv, width=640, height=480)
else:
self.viewer = None
self._viewers[mode] = self.viewer
return self.viewer

Expand Down
7 changes: 4 additions & 3 deletions digideep/pipeline/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ def __init__(self, root_path):
self.state['path_base_sessions'] = os.path.join(os.path.split(checkpoint_path)[0], "evaluations")
else:
# OK, we are loading from a checkpoint, just create session from scratch.
self.state['path_root_session'] = self.args["session_path"]
self.state['path_base_sessions'] = os.path.join(self.state['path_root_session'], 'digideep_sessions')
# self.state['path_root_session'] = self.args["session_path"]
# self.state['path_base_sessions'] = os.path.join(self.state['path_root_session'], 'digideep_sessions')
self.state['path_base_sessions'] = self.args["session_path"]


# 1. Creating 'path_base_sessions', i.e. '/tmp/digideep_sessions':
Expand Down Expand Up @@ -334,7 +335,7 @@ def parse_arguments(self):
parser.add_argument('--play', action="store_true", help="Will play the stored policy.")
parser.add_argument('--dry-run', action="store_true", help="If used no footprints will be stored on disc whatsoever.")
## Session
parser.add_argument('--session-path', metavar=('<path>'), default='/tmp', type=str, help="The path to store the sessions. Default is in /tmp")
parser.add_argument('--session-path', metavar=('<path>'), default='/tmp/digideep_sessions', type=str, help="The path to store the sessions. Default is in /tmp")
parser.add_argument('--session-name', metavar=('<name>'), default='', type=str, help="A default name for the session. Random name if not provided.")
parser.add_argument('--save-modules', metavar=('<path>'), default=[], nargs='+', type=str, help="The modules to be stored in the session.")
parser.add_argument('--log-level', metavar=('<n>'), default=1, type=int, help="The logging level: 0 (debug and above), 1 (info and above), 2 (warn and above), 3 (error and above), 4 (fatal and above)")
Expand Down

0 comments on commit a728c3d

Please sign in to comment.