Add TensorBoard Support
Adds TensorBoard support for basic key-value pairs. Anything logged via
`logger.record_tabular()` is also written to TensorBoard as a scalar summary.
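
For orientation, a minimal usage sketch of the logger hooks this commit adds. The directory names, the `Iteration` step key, and the placeholder metric are illustrative assumptions, not part of the commit.

```python
# Minimal sketch (illustrative paths and values) of how the new TensorBoard
# hooks plug into the existing tabular logging flow.
from rllab.misc import logger

logger.add_tabular_output('data/progress.csv')    # existing CSV output
logger.set_tensorboard_dir('data/progress')       # new: directory for TF event files
logger.set_tensorboard_step_key('Iteration')      # new: use this key as the global step

for itr in range(10):
    logger.record_tabular('Iteration', itr)
    logger.record_tabular('AverageReturn', 100.0 - itr)  # placeholder metric
    logger.dump_tabular()  # writes a CSV row and one TensorBoard scalar per key
```

The resulting event files can then be viewed with `tensorboard --logdir data/progress`.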
cjcchen authored and ryanjulian committed Mar 31, 2018
1 parent b3a2899 commit 77ef3e0
Showing 3 changed files with 155 additions and 42 deletions.
1 change: 1 addition & 0 deletions environment.yml
@@ -62,3 +62,4 @@ dependencies:
- pylru==1.0.9
- hyperopt
- polling
- tensorboard
69 changes: 62 additions & 7 deletions rllab/misc/logger.py
@@ -15,6 +15,7 @@
import json
import pickle
import base64
import tensorflow as tf

_prefixes = []
_prefix_str = ''
@@ -31,13 +32,17 @@
_tabular_fds = {}
_tabular_header_written = set()

_tensorboard_writer = None
_snapshot_dir = None
_snapshot_mode = 'all'
_snapshot_gap = 1

_log_tabular_only = False
_header_printed = False

_tensorboard_default_step = 0
_tensorboard_step_key = None


def _add_output(file_name, arr, fds, mode='a'):
if file_name not in arr:
@@ -77,6 +82,20 @@ def remove_tabular_output(file_name):
_remove_output(file_name, _tabular_outputs, _tabular_fds)


def set_tensorboard_dir(dir_name):
    global _tensorboard_writer, _tensorboard_default_step  # both are reassigned below
if not dir_name:
if _tensorboard_writer:
_tensorboard_writer.close()
_tensorboard_writer = None
else:
mkdir_p(os.path.dirname(dir_name))
_tensorboard_writer = tf.summary.FileWriter(dir_name)
_tensorboard_default_step = 0
assert _tensorboard_writer is not None
print("tensorboard data will be logged into:", dir_name)


def set_snapshot_dir(dir_name):
global _snapshot_dir
_snapshot_dir = dir_name
@@ -94,18 +113,26 @@ def set_snapshot_mode(mode):
global _snapshot_mode
_snapshot_mode = mode


def get_snapshot_gap():
return _snapshot_gap


def set_snapshot_gap(gap):
global _snapshot_gap
_snapshot_gap = gap


def set_log_tabular_only(log_tabular_only):
global _log_tabular_only
_log_tabular_only = log_tabular_only


def set_tensorboard_step_key(key):
global _tensorboard_step_key
_tensorboard_step_key = key


def get_log_tabular_only():
return _log_tabular_only

@@ -186,6 +213,23 @@ def refresh(self):
table_printer = TerminalTablePrinter()


def dump_tensorboard(*args, **kwargs):
if len(_tabular) > 0 and _tensorboard_writer:
tabular_dict = dict(_tabular)
if _tensorboard_step_key and _tensorboard_step_key in tabular_dict:
step = tabular_dict[_tensorboard_step_key]
else:
global _tensorboard_default_step
step = _tensorboard_default_step
_tensorboard_default_step += 1

summary = tf.Summary()
for k, v in tabular_dict.items():
summary.value.add(tag=k, simple_value=float(v))
_tensorboard_writer.add_summary(summary, int(step))
_tensorboard_writer.flush()


def dump_tabular(*args, **kwargs):
wh = kwargs.pop("write_header", None)
if len(_tabular) > 0:
@@ -195,11 +239,18 @@ def dump_tabular(*args, **kwargs):
for line in tabulate(_tabular).split('\n'):
log(line, *args, **kwargs)
tabular_dict = dict(_tabular)

        # Also write the tabular data to TensorBoard (no-op unless a
        # tensorboard directory has been configured)
        dump_tensorboard(*args, **kwargs)

# Also write to the csv files
# This assumes that the keys in each iteration won't change!
for tabular_fd in list(_tabular_fds.values()):
writer = csv.DictWriter(tabular_fd, fieldnames=list(tabular_dict.keys()))
if wh or (wh is None and tabular_fd not in _tabular_header_written):
writer = csv.DictWriter(
tabular_fd, fieldnames=list(tabular_dict.keys()))
if wh or (wh is None
and tabular_fd not in _tabular_header_written):
writer.writeheader()
_tabular_header_written.add(tabular_fd)
writer.writerow(tabular_dict)
@@ -245,7 +296,8 @@ def log_parameters(log_file, args, classes):
log_params[name] = params
else:
log_params[name] = getattr(cls, "__kwargs", dict())
log_params[name]["_name"] = cls.__module__ + "." + cls.__class__.__name__
log_params[name][
"_name"] = cls.__module__ + "." + cls.__class__.__name__
mkdir_p(os.path.dirname(log_file))
with open(log_file, "w") as f:
json.dump(log_params, f, indent=2, sort_keys=True)
@@ -258,13 +310,13 @@ def stub_to_json(stub_sth):
data = dict()
for k, v in stub_sth.kwargs.items():
data[k] = stub_to_json(v)
data["_name"] = stub_sth.proxy_class.__module__ + "." + stub_sth.proxy_class.__name__
data[
"_name"] = stub_sth.proxy_class.__module__ + "." + stub_sth.proxy_class.__name__
return data
elif isinstance(stub_sth, instrument.StubAttr):
return dict(
obj=stub_to_json(stub_sth.obj),
attr=stub_to_json(stub_sth.attr_name)
)
attr=stub_to_json(stub_sth.attr_name))
elif isinstance(stub_sth, instrument.StubMethodCall):
return dict(
obj=stub_to_json(stub_sth.obj),
@@ -294,7 +346,10 @@ def default(self, o):
if isinstance(o, type):
return {'$class': o.__module__ + "." + o.__name__}
elif isinstance(o, Enum):
return {'$enum': o.__module__ + "." + o.__class__.__name__ + '.' + o.name}
return {
'$enum':
o.__module__ + "." + o.__class__.__name__ + '.' + o.name
}
return json.JSONEncoder.default(self, o)


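As a quick sanity check (not part of the commit), the event files written by `dump_tensorboard` above can be read back with the TensorFlow 1.x summary iterator; the `data/progress` path matches the illustrative sketch earlier.

```python
# Sketch: read back the scalar summaries emitted by the logger (TF 1.x API).
import glob
import tensorflow as tf

for event_file in glob.glob('data/progress/events.out.tfevents.*'):
    for event in tf.train.summary_iterator(event_file):
        for value in event.summary.value:
            print(event.step, value.tag, value.simple_value)
```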
127 changes: 92 additions & 35 deletions scripts/run_experiment_lite.py
@@ -29,41 +29,95 @@ def run_experiment(argv):

default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
parser = argparse.ArgumentParser()
parser.add_argument('--n_parallel', type=int, default=1,
help='Number of parallel workers to perform rollouts. 0 => don\'t start any workers')
parser.add_argument(
'--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
parser.add_argument('--log_dir', type=str, default=None,
help='Path to save the log and iteration snapshot.')
parser.add_argument('--snapshot_mode', type=str, default='all',
help='Mode to save the snapshot. Can be either "all" '
'(all iterations will be saved), "last" (only '
'the last iteration will be saved), "gap" (every'
'`snapshot_gap` iterations are saved), or "none" '
'(do not save snapshots)')
parser.add_argument('--snapshot_gap', type=int, default=1,
help='Gap between snapshot iterations.')
parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
help='Name of the tabular log file (in csv).')
parser.add_argument('--text_log_file', type=str, default='debug.log',
help='Name of the text log file (in pure text).')
parser.add_argument('--params_log_file', type=str, default='params.json',
help='Name of the parameter log file (in json).')
parser.add_argument('--variant_log_file', type=str, default='variant.json',
help='Name of the variant log file (in json).')
parser.add_argument('--resume_from', type=str, default=None,
help='Name of the pickle file to resume experiment from.')
parser.add_argument('--plot', type=ast.literal_eval, default=False,
help='Whether to plot the iteration results')
parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
help='Whether to only print the tabular log information (in a horizontal format)')
parser.add_argument('--seed', type=int,
help='Random seed for numpy')
parser.add_argument('--args_data', type=str,
help='Pickled data for stub objects')
parser.add_argument('--variant_data', type=str,
help='Pickled data for variant configuration')
parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False)
parser.add_argument(
'--n_parallel',
type=int,
default=1,
help=
'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
)
parser.add_argument(
'--exp_name',
type=str,
default=default_exp_name,
help='Name of the experiment.')
parser.add_argument(
'--log_dir',
type=str,
default=None,
help='Path to save the log and iteration snapshot.')
parser.add_argument(
'--snapshot_mode',
type=str,
default='all',
help='Mode to save the snapshot. Can be either "all" '
'(all iterations will be saved), "last" (only '
'the last iteration will be saved), "gap" (every'
'`snapshot_gap` iterations are saved), or "none" '
'(do not save snapshots)')
parser.add_argument(
'--snapshot_gap',
type=int,
default=1,
help='Gap between snapshot iterations.')
parser.add_argument(
'--tabular_log_file',
type=str,
default='progress.csv',
help='Name of the tabular log file (in csv).')
parser.add_argument(
'--text_log_file',
type=str,
default='debug.log',
help='Name of the text log file (in pure text).')
parser.add_argument(
'--tensorboard_log_dir',
type=str,
default='progress',
help='Name of the folder for tensorboard_summary.')
parser.add_argument(
'--tensorboard_step_key',
type=str,
default=None,
help=
'Name of the step key in log data which shows the step in tensorboard_summary.'
)
parser.add_argument(
'--params_log_file',
type=str,
default='params.json',
help='Name of the parameter log file (in json).')
parser.add_argument(
'--variant_log_file',
type=str,
default='variant.json',
help='Name of the variant log file (in json).')
parser.add_argument(
'--resume_from',
type=str,
default=None,
help='Name of the pickle file to resume experiment from.')
parser.add_argument(
'--plot',
type=ast.literal_eval,
default=False,
help='Whether to plot the iteration results')
parser.add_argument(
'--log_tabular_only',
type=ast.literal_eval,
default=False,
help=
'Whether to only print the tabular log information (in a horizontal format)'
)
parser.add_argument('--seed', type=int, help='Random seed for numpy')
parser.add_argument(
'--args_data', type=str, help='Pickled data for stub objects')
parser.add_argument(
'--variant_data',
type=str,
help='Pickled data for variant configuration')
parser.add_argument(
'--use_cloudpickle', type=ast.literal_eval, default=False)

args = parser.parse_args(argv[1:])

@@ -87,6 +141,7 @@ def run_experiment(argv):
tabular_log_file = osp.join(log_dir, args.tabular_log_file)
text_log_file = osp.join(log_dir, args.text_log_file)
params_log_file = osp.join(log_dir, args.params_log_file)
tensorboard_log_dir = osp.join(log_dir, args.tensorboard_log_dir)

if args.variant_data is not None:
variant_data = pickle.loads(base64.b64decode(args.variant_data))
@@ -100,12 +155,14 @@

logger.add_text_output(text_log_file)
logger.add_tabular_output(tabular_log_file)
logger.set_tensorboard_dir(tensorboard_log_dir)
prev_snapshot_dir = logger.get_snapshot_dir()
prev_mode = logger.get_snapshot_mode()
logger.set_snapshot_dir(log_dir)
logger.set_snapshot_mode(args.snapshot_mode)
logger.set_snapshot_gap(args.snapshot_gap)
logger.set_log_tabular_only(args.log_tabular_only)
logger.set_tensorboard_step_key(args.tensorboard_step_key)
logger.push_prefix("[%s] " % args.exp_name)

if args.resume_from is not None:
