Skip to content

Commit

Permalink
Saving videos when playing from checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
sharif1093 committed Apr 18, 2019
1 parent bae6119 commit 3c9d761
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 1 deletion.
8 changes: 8 additions & 0 deletions digideep/environment/dmc2gym/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,18 @@ def step(self, action):
def render(self, mode='human', **render_kwargs):
"""Render function which supports two modes: ``rgb_array`` and ``human``.
If ``mode`` is ``rgb_array``, it will return the image in pixels format.
Args:
render_kwargs: Check ``dm_control/mujoco/engine.py``.
Defaults: ``render(height=240, width=320, camera_id=-1, overlays=(), depth=False, segmentation=False, scene_option=None)``
"""

# render_kwargs = { 'height', 'width', 'camera_id', 'overlays', 'depth', 'scene_option'}
if mode == 'rgb_array':
if not "width" in render_kwargs:
render_kwargs["width"] = 640
if not "height" in render_kwargs:
render_kwargs["height"] = 480
pixels = self._get_viewer(mode)(**render_kwargs)
return pixels
elif mode == 'human':
Expand Down
8 changes: 8 additions & 0 deletions digideep/environment/make_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from digideep.environment.wrappers import VecNormalize
from digideep.environment.wrappers import VecSaveState

from gym.wrappers.monitor import Monitor as MonitorVideoRecorder

from digideep.utility.toolbox import get_module

from digideep.utility.logging import logger
Expand Down Expand Up @@ -126,6 +128,12 @@ def _f():
if not force_no_monitor and self.params["wrappers"]["add_monitor"]:
log_dir = os.path.join(self.session["path_monitor"], str(rank))
env = Monitor(env, log_dir, **self.params["wrappers_args"]["Monitor"])

if self.mode == "eval":
videos_dir = os.path.join(self.session["path_videos"], str(rank))
env = MonitorVideoRecorder(env, videos_dir, video_callable=lambda id:True)
# elif self.mode == "train":
# env = MonitorVideoRecorder(env, videos_dir)

if is_atari and len(env.observation_space.shape) == 3:
env = wrap_deepmind(env)
Expand Down
6 changes: 5 additions & 1 deletion digideep/pipeline/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from copy import deepcopy

def generateTimestamp():
# Always uses UTC as timezone
now = datetime.datetime.now()
timestamp = '{:%Y%m%d%H%M%S}'.format(now)
return timestamp
Expand Down Expand Up @@ -92,6 +93,7 @@ def __init__(self, root_path):
self.state['path_session'] = os.path.join(self.state['path_base_sessions'], 'session_' + generateTimestamp())
self.state['path_checkpoints'] = os.path.join(self.state['path_session'], 'checkpoints')
self.state['path_monitor'] = os.path.join(self.state['path_session'], 'monitor')
self.state['path_videos'] = os.path.join(self.state['path_session'], 'videos')
# Hyper-parameters basically is a snapshot of intial parameter engine's state.
self.state['file_cpanel'] = os.path.join(self.state['path_session'], 'cpanel.json')
self.state['file_params'] = os.path.join(self.state['path_session'], 'params.yaml')
Expand Down Expand Up @@ -130,7 +132,9 @@ def __init__(self, root_path):
except Exception as ex:
logger.fatal("While importing user-specified params:", ex)
exit()

if self.is_loading:
logger.warn("Loading from:", self.args["load_checkpoint"])

print(':: The session will be stored in ' + self.state['path_session'])

def initLogger(self):
Expand Down

0 comments on commit 3c9d761

Please sign in to comment.