Add save_metrics flag to disable logging
takuseno committed Oct 21, 2020
1 parent 4bd8312 commit 2fec466
Showing 4 changed files with 48 additions and 21 deletions.
14 changes: 10 additions & 4 deletions d3rlpy/base.py
@@ -244,6 +244,7 @@ def load_model(self, fname):

def fit(self,
episodes,
save_metrics=True,
experiment_name=None,
with_timestamp=True,
logdir='d3rlpy_logs',
@@ -261,6 +262,9 @@ def fit(self,
Args:
episodes (list(d3rlpy.dataset.Episode)): list of episodes to train.
save_metrics (bool): flag to record metrics in files. If False,
the log directory is not created and the model parameters are
not saved during training.
experiment_name (str): experiment name for logging. If not passed,
the directory name will be `{class name}_{timestamp}`.
with_timestamp (bool): flag to add timestamp string to the last of
@@ -298,8 +302,9 @@ def fit(self,
self.create_impl(observation_shape, action_size)

# setup logger
-        logger = self._prepare_logger(experiment_name, with_timestamp, logdir,
-                                      verbose, tensorboard)
+        logger = self._prepare_logger(save_metrics, experiment_name,
+                                      with_timestamp, logdir, verbose,
+                                      tensorboard)

# save hyperparameters
self._save_params(logger)
@@ -392,12 +397,13 @@ def _generate_new_data(self, transitions):
def _get_loss_labels(self):
raise NotImplementedError

-    def _prepare_logger(self, experiment_name, with_timestamp, logdir, verbose,
-                        tensorboard):
+    def _prepare_logger(self, save_metrics, experiment_name, with_timestamp,
+                        logdir, verbose, tensorboard):
if experiment_name is None:
experiment_name = self.__class__.__name__

logger = D3RLPyLogger(experiment_name,
save_metrics=save_metrics,
root_dir=logdir,
verbose=verbose,
tensorboard=tensorboard,
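For reference, this is how the new flag is used from the offline API. A minimal sketch, assuming the bundled CartPole dataset helper (`get_cartpole`) is available; the dataset choice is illustrative only:

    from d3rlpy.algos import DQN
    from d3rlpy.datasets import get_cartpole

    # any MDPDataset works here; CartPole is just a small example
    dataset, _ = get_cartpole()

    dqn = DQN()

    # with save_metrics=False, no d3rlpy_logs/ directory is created and no
    # params.json, CSV metrics or model_*.pt files are written during training
    dqn.fit(dataset.episodes, save_metrics=False)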
41 changes: 24 additions & 17 deletions d3rlpy/logger.py
@@ -20,10 +20,12 @@ def default_json_encoder(obj):
class D3RLPyLogger:
def __init__(self,
experiment_name,
save_metrics=True,
root_dir='logs',
verbose=True,
tensorboard=True,
with_timestamp=True):
self.save_metrics = save_metrics
self.verbose = verbose

# add timestamp to prevent unintentional overwrites
@@ -34,16 +36,18 @@ def __init__(self,
else:
self.experiment_name = experiment_name

-            self.logdir = os.path.join(root_dir, self.experiment_name)
-
-            if not os.path.exists(self.logdir):
-                os.makedirs(self.logdir)
-                break
-            else:
-                if with_timestamp:
-                    time.sleep(1.0)
-                else:
-                    raise ValueError('%s already exists.' % self.logdir)
+            if self.save_metrics:
+                self.logdir = os.path.join(root_dir, self.experiment_name)
+                if not os.path.exists(self.logdir):
+                    os.makedirs(self.logdir)
+                    break
+                else:
+                    if with_timestamp:
+                        time.sleep(1.0)
+                    else:
+                        raise ValueError('%s already exists.' % self.logdir)
+            else:
+                break

self.metrics_buffer = {}

@@ -59,9 +63,10 @@ def __init__(self,
def add_params(self, params):
assert self.params is None, 'add_params can be called only once.'

-        # save dictionary as json file
-        with open(os.path.join(self.logdir, 'params.json'), 'w') as f:
-            f.write(json.dumps(params, default=default_json_encoder))
+        if self.save_metrics:
+            # save dictionary as json file
+            with open(os.path.join(self.logdir, 'params.json'), 'w') as f:
+                f.write(json.dumps(params, default=default_json_encoder))

if self.verbose:
for key, val in params.items():
@@ -80,8 +85,9 @@ def commit(self, epoch, step):
for name, buffer in self.metrics_buffer.items():
metric = sum(buffer) / len(buffer)

-            with open(os.path.join(self.logdir, name + '.csv'), 'a') as f:
-                print('%d,%d,%f' % (epoch, step, metric), file=f)
+            if self.save_metrics:
+                with open(os.path.join(self.logdir, name + '.csv'), 'a') as f:
+                    print('%d,%d,%f' % (epoch, step, metric), file=f)

if self.verbose:
print('epoch=%d step=%d %s=%f' % (epoch, step, name, metric))
@@ -99,6 +105,7 @@ def commit(self, epoch, step):
global_step=epoch)

def save_model(self, epoch, algo):
-        # save entire model
-        model_path = os.path.join(self.logdir, 'model_%d.pt' % epoch)
-        algo.save_model(model_path)
+        if self.save_metrics:
+            # save entire model
+            model_path = os.path.join(self.logdir, 'model_%d.pt' % epoch)
+            algo.save_model(model_path)
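The effect is easiest to see on the logger class itself. A rough sketch of the intended behaviour, not a test taken from the repository; `tensorboard=False` is passed only to keep the example self-contained:

    from d3rlpy.logger import D3RLPyLogger

    # save_metrics=False: no log directory is created at construction time
    logger = D3RLPyLogger('example', save_metrics=False, tensorboard=False)

    # printed to stdout (verbose=True by default) but not written to params.json
    logger.add_params({'learning_rate': 1e-3})

    # buffered metrics are averaged and printed, but no CSV file is appended
    logger.add_metric('loss', 0.5)
    logger.commit(epoch=1, step=100)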
4 changes: 4 additions & 0 deletions d3rlpy/online/iterators.py
@@ -15,6 +15,7 @@ def train(env,
update_start_step=0,
eval_env=None,
eval_epsilon=0.05,
save_metrics=True,
experiment_name=None,
with_timestamp=True,
logdir='d3rlpy_logs',
@@ -38,6 +39,8 @@ def train(env,
skipped.
eval_epsilon (float): :math:`\\epsilon`-greedy factor during
evaluation.
save_metrics (bool): flag to record metrics. If False, the log
directory is not created and the model parameters are not saved.
experiment_name (str): experiment name for logging. If not passed,
the directory name will be `{class name}_online_{timestamp}`.
with_timestamp (bool): flag to add timestamp string to the last of
@@ -55,6 +58,7 @@
if experiment_name is None:
experiment_name = algo.__class__.__name__ + '_online'
logger = D3RLPyLogger(experiment_name,
save_metrics=save_metrics,
root_dir=logdir,
verbose=verbose,
tensorboard=tensorboard,
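The online loop accepts the same flag. A sketch of a possible call, assuming the online utilities of this release (`ReplayBuffer`, `LinearDecayEpsilonGreedy`) and that `env`, `algo`, `buffer` and `explorer` are the leading positional arguments of `train`; only `save_metrics` is taken from the diff above:

    import gym

    from d3rlpy.algos import DQN
    from d3rlpy.online.buffers import ReplayBuffer
    from d3rlpy.online.explorers import LinearDecayEpsilonGreedy
    from d3rlpy.online.iterators import train

    env = gym.make('CartPole-v0')

    dqn = DQN()
    buffer = ReplayBuffer(maxlen=100000, env=env)
    explorer = LinearDecayEpsilonGreedy()

    # metrics are still printed, but nothing is written to disk and
    # no log directory is created
    train(env, dqn, buffer, explorer, save_metrics=False)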
10 changes: 10 additions & 0 deletions docs/references/logging.rst
@@ -23,10 +23,20 @@ You can designate the directory.
# the directory will be `custom_logs/custom_YYYYMMDDHHmmss`
dqn.fit(dataset.episodes, logdir='custom_logs', experiment_name='custom')
If you want to disable all logging, you can pass `save_metrics=False`.

.. code-block:: python

    dqn.fit(dataset.episodes, save_metrics=False)

TensorBoard
-----------

The same information is also automatically saved for tensorboard under the `runs`
directory, so you can interactively visualize training metrics.


.. code-block:: shell

    $ pip install tensorboard
