This repository has been archived by the owner on Sep 24, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
67f181b
commit d2a6096
Showing
2 changed files
with
231 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
import re | ||
import six | ||
import time | ||
import wandb | ||
|
||
|
||
# We have atleast the default namestep and a global step to track | ||
# TODO: reset this structure on wandb.join | ||
STEPS = {"": {"step": 0}, "global": {"step": 0, "last_log": None}} | ||
# TODO(cling): Set these when tensorboard behavior is configured. | ||
# We support rate limited logging by setting this to number of seconds, | ||
# can be a floating point. | ||
RATE_LIMIT_SECONDS = None | ||
IGNORE_KINDS = [] | ||
tensor_util = wandb.util.get_module("tensorboard.util.tensor_util") | ||
|
||
|
||
pb = wandb.util.get_module( | ||
"tensorboard.compat.proto.summary_pb2" | ||
) or wandb.util.get_module("tensorflow.summary") | ||
if pb: | ||
Summary = pb.Summary | ||
else: | ||
Summary = None | ||
|
||
|
||
def make_ndarray(tensor): | ||
if tensor_util: | ||
res = tensor_util.make_ndarray(tensor) | ||
# Tensorboard can log generic objects and we don't want to save them | ||
if res.dtype == "object": | ||
return None | ||
else: | ||
return res | ||
else: | ||
wandb.termwarn( | ||
"Can't convert tensor summary, upgrade tensorboard with `pip" | ||
" install tensorboard --upgrade`" | ||
) | ||
return None | ||
|
||
|
||
def namespaced_tag(tag, namespace=""): | ||
if not namespace: | ||
return tag | ||
elif tag in namespace: | ||
# This happens with tensorboardX | ||
return namespace | ||
else: | ||
return namespace + "/" + tag | ||
|
||
|
||
def history_image_key(key, namespace=""): | ||
"""Converts invalid filesystem characters to _ for use in History keys. | ||
Unfortunately this means currently certain image keys will collide silently. We | ||
implement this mapping up here in the TensorFlow stuff rather than in the History | ||
stuff so that we don't have to store a mapping anywhere from the original keys to | ||
the safe ones. | ||
""" | ||
return namespaced_tag(re.sub(r"[/\\]", "_", key), namespace) | ||
|
||
|
||
def tf_summary_to_dict(tf_summary_str_or_pb, namespace=""): # noqa: C901 | ||
"""Convert a Tensorboard Summary to a dictionary | ||
Accepts either a tensorflow.summary.Summary | ||
or one encoded as a string. | ||
""" | ||
values = {} | ||
if hasattr(tf_summary_str_or_pb, "summary"): | ||
summary_pb = tf_summary_str_or_pb.summary | ||
values[namespaced_tag("global_step", namespace)] = tf_summary_str_or_pb.step | ||
values["_timestamp"] = tf_summary_str_or_pb.wall_time | ||
elif isinstance(tf_summary_str_or_pb, (str, bytes, bytearray)): | ||
summary_pb = Summary() | ||
summary_pb.ParseFromString(tf_summary_str_or_pb) | ||
else: | ||
summary_pb = tf_summary_str_or_pb | ||
|
||
if not hasattr(summary_pb, "value") or len(summary_pb.value) == 0: | ||
# Ignore these, caller is responsible for handling None | ||
return None | ||
|
||
for value in summary_pb.value: | ||
kind = value.WhichOneof("value") | ||
if kind in IGNORE_KINDS: | ||
continue | ||
if kind == "simple_value": | ||
values[namespaced_tag(value.tag, namespace)] = value.simple_value | ||
elif kind == "tensor": | ||
values[namespaced_tag(value.tag, namespace)] = make_ndarray(value.tensor) | ||
elif kind == "image": | ||
from PIL import Image | ||
|
||
img_str = value.image.encoded_image_string | ||
# Supports gifs from TboardX | ||
if img_str.startswith(b"GIF"): | ||
image = wandb.Video(six.BytesIO(img_str), format="gif") | ||
else: | ||
image = wandb.Image(Image.open(six.BytesIO(img_str))) | ||
tag_idx = value.tag.rsplit("/", 1) | ||
if len(tag_idx) > 1 and tag_idx[1].isdigit(): | ||
tag, idx = tag_idx | ||
values.setdefault(history_image_key(tag, namespace), []).append(image) | ||
else: | ||
values[history_image_key(value.tag, namespace)] = [image] | ||
# Coming soon... | ||
# elif kind == "audio": | ||
# audio = wandb.Audio( | ||
# six.BytesIO(value.audio.encoded_audio_string), | ||
# sample_rate=value.audio.sample_rate, | ||
# content_type=value.audio.content_type, | ||
# ) | ||
elif kind == "histo": | ||
tag = namespaced_tag(value.tag, namespace) | ||
if len(value.histo.bucket_limit) >= 3: | ||
first = ( | ||
value.histo.bucket_limit[0] | ||
+ value.histo.bucket_limit[0] # noqa: W503 | ||
- value.histo.bucket_limit[1] # noqa: W503 | ||
) | ||
last = ( | ||
value.histo.bucket_limit[-2] | ||
+ value.histo.bucket_limit[-2] # noqa: W503 | ||
- value.histo.bucket_limit[-3] # noqa: W503 | ||
) | ||
np_histogram = ( | ||
list(value.histo.bucket), | ||
[first] + value.histo.bucket_limit[:-1] + [last], | ||
) | ||
try: | ||
# TODO: we should just re-bin if there are too many buckets | ||
values[tag] = wandb.Histogram(np_histogram=np_histogram) | ||
except ValueError: | ||
wandb.termwarn( | ||
'Not logging key "{}". ' | ||
"Histograms must have fewer than {} bins".format( | ||
tag, wandb.Histogram.MAX_LENGTH | ||
), | ||
repeat=False, | ||
) | ||
else: | ||
# TODO: is there a case where we can render this? | ||
wandb.termwarn( | ||
'Not logging key "{}". Found a histogram with only 2 bins.'.format( | ||
tag | ||
), | ||
repeat=False, | ||
) | ||
elif value.tag == "_hparams_/session_start_info": | ||
if wandb.util.get_module("tensorboard.plugins.hparams"): | ||
from tensorboard.plugins.hparams import plugin_data_pb2 | ||
|
||
plugin_data = plugin_data_pb2.HParamsPluginData() | ||
plugin_data.ParseFromString(value.metadata.plugin_data.content) | ||
for key, param in six.iteritems(plugin_data.session_start_info.hparams): | ||
if not wandb.run.config.get(key): | ||
wandb.run.config[key] = ( | ||
param.number_value or param.string_value or param.bool_value | ||
) | ||
else: | ||
wandb.termerror( | ||
"Received hparams tf.summary, but could not import " | ||
"the hparams plugin from tensorboard" | ||
) | ||
return values | ||
|
||
|
||
def reset_state(): | ||
"""Internal method for reseting state, called by wandb.join""" | ||
global STEPS | ||
STEPS = {"": {"step": 0}, "global": {"step": 0, "last_log": None}} | ||
|
||
|
||
def log(tf_summary_str_or_pb, history=None, step=0, namespace="", **kwargs): | ||
"""Logs a tfsummary to wandb | ||
Can accept a tf summary string or parsed event. Will use wandb.run.history unless a | ||
history object is passed. Can optionally namespace events. Results are commited | ||
when step increases for this namespace. | ||
NOTE: This assumes that events being passed in are in chronological order | ||
""" | ||
global STEPS | ||
global RATE_LIMIT | ||
history = history or wandb.run.history | ||
# To handle multiple global_steps, we keep track of them here instead | ||
# of the global log | ||
last_step = STEPS.get(namespace, {"step": 0}) | ||
|
||
# Commit our existing data if this namespace increased its step | ||
commit = False | ||
if last_step["step"] < step: | ||
commit = True | ||
|
||
log_dict = tf_summary_to_dict(tf_summary_str_or_pb, namespace) | ||
if log_dict is None: | ||
# not an event, just return | ||
return | ||
|
||
# Pass timestamp to history for loading historic data | ||
timestamp = log_dict.get("_timestamp", time.time()) | ||
# Store our initial timestamp | ||
if STEPS["global"]["last_log"] is None: | ||
STEPS["global"]["last_log"] = timestamp | ||
# Rollup events that share the same step across namespaces | ||
if commit and step == STEPS["global"]["step"]: | ||
commit = False | ||
# Always add the biggest global_step key for non-default namespaces | ||
if step > STEPS["global"]["step"]: | ||
STEPS["global"]["step"] = step | ||
if namespace != "": | ||
log_dict["global_step"] = STEPS["global"]["step"] | ||
|
||
# Keep internal step counter | ||
STEPS[namespace] = {"step": step} | ||
|
||
if commit: | ||
# Only commit our data if we're below the rate limit or don't have one | ||
if ( | ||
RATE_LIMIT_SECONDS is None | ||
or timestamp - STEPS["global"]["last_log"] >= RATE_LIMIT_SECONDS # noqa: W503 E501 | ||
): | ||
history.add({}, **kwargs) | ||
STEPS["global"]["last_log"] = timestamp | ||
history._row_update(log_dict) | ||
|
||
|
||
__all__ = ['log', 'reset_state', 'tf_summary_to_dict'] |