diff --git a/buildPIPandDeploy.sh b/buildPIPandDeploy.sh
new file mode 100644
index 0000000000..5a673d8a4c
--- /dev/null
+++ b/buildPIPandDeploy.sh
@@ -0,0 +1,7 @@
+# remove later
+bazel clean
+pip uninstall tensorboard -y
+rm -rf /tmp/tensorboard/
+bazel run //tensorboard/pip_package:build_pip_package
+pip install /tmp/tensorboard/dist/tensorboard-1.13.0a0-py3-none-any.whl
+
diff --git a/tensorboard/compat/BUILD b/tensorboard/compat/BUILD
index 9a12c44bdb..0f6d5a865b 100644
--- a/tensorboard/compat/BUILD
+++ b/tensorboard/compat/BUILD
@@ -44,6 +44,7 @@ py_library(
     deps = [
         ":compat",
         "//tensorboard/compat/tensorflow_stub",
+        "//tensorboard/compat/tensorboard",
     ],
 )
 
diff --git a/tensorboard/compat/proto/BUILD b/tensorboard/compat/proto/BUILD
index 647829f982..612b4699db 100644
--- a/tensorboard/compat/proto/BUILD
+++ b/tensorboard/compat/proto/BUILD
@@ -40,6 +40,9 @@ tb_proto_library(
         "types.proto",
         "verifier_config.proto",
         "versions.proto",
+        "plugin_text.proto",  # A relative dir is not accepted, so the proto is copied here.
+        "plugin_pr_curve.proto",  # A relative dir is not accepted, so the proto is copied here.
+        "layout.proto",  # A relative dir is not accepted, so the proto is copied here.
     ],
     visibility = ["//visibility:public"],
 )
diff --git a/tensorboard/compat/proto/layout.proto b/tensorboard/compat/proto/layout.proto
new file mode 100644
index 0000000000..0a3342644f
--- /dev/null
+++ b/tensorboard/compat/proto/layout.proto
@@ -0,0 +1,96 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorboard;
+
+
+/**
+ * Encapsulates information on a single chart. Many charts appear in a category.
+ */
+message Chart {
+  // The title shown atop this chart. Optional. Defaults to 'untitled'.
+  string title = 1;
+
+  // The content of the chart. This depends on the type of the chart.
+  oneof content {
+    MultilineChartContent multiline = 2;
+    MarginChartContent margin = 3;
+  }
+}
+
+/**
+ * Encapsulates information on a single line chart. This line chart may have
+ * lines associated with several tags.
+ */
+message MultilineChartContent {
+  // A list of regular expressions for tags that should appear in this chart.
+  // Tags are matched from beginning to end. Each regex captures a set of tags.
+  repeated string tag = 1;
+}
+
+/**
+ * Encapsulates information on a single margin chart. A margin chart uses fill
+ * area to visualize lower and upper bounds that surround a value.
+ */
+message MarginChartContent {
+  /**
+   * Encapsulates a tag of data for the chart.
+   */
+  message Series {
+    // The exact tag string associated with the scalar summaries making up the
+    // main value between the bounds.
+    string value = 1;
+
+    // The exact tag string associated with the scalar summaries making up the
+    // lower bound.
+    string lower = 2;
+
+    // The exact tag string associated with the scalar summaries making up the
+    // upper bound.
+    string upper = 3;
+  }
+
+  // A list of data series to include within this margin chart.
+  repeated Series series = 1;
+}
+
+/**
+ * A category contains a group of charts. Each category maps to a collapsible
+ * within the dashboard.
+ */
+message Category {
+  // This string appears atop each grouping of charts within the dashboard.
+  string title = 1;
+
+  // Encapsulates data on charts to be shown in the category.
+  repeated Chart chart = 2;
+
+  // Whether this category should be initially closed. False by default.
+  bool closed = 3;
+}
+
+/**
+ * A layout encapsulates how charts are laid out within the custom scalars
+ * dashboard.
+ */
+message Layout {
+  // Version `0` is the only supported version.
+  int32 version = 1;
+
+  // The categories here are rendered from top to bottom.
+  repeated Category category = 2;
+}
diff --git a/tensorboard/compat/proto/plugin_pr_curve.proto b/tensorboard/compat/proto/plugin_pr_curve.proto
new file mode 100644
index 0000000000..33e0f91641
--- /dev/null
+++ b/tensorboard/compat/proto/plugin_pr_curve.proto
@@ -0,0 +1,25 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorboard;
+
+message PrCurvePluginData {
+  // Version `0` is the only supported version.
+  int32 version = 1;
+
+  uint32 num_thresholds = 2;
+}
diff --git a/tensorboard/compat/proto/plugin_text.proto b/tensorboard/compat/proto/plugin_text.proto
new file mode 100644
index 0000000000..54cb053304
--- /dev/null
+++ b/tensorboard/compat/proto/plugin_text.proto
@@ -0,0 +1,27 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorboard;
+
+// Text summaries created by the `tensorboard.plugins.text.summary`
+// module will include `SummaryMetadata` whose `plugin_data` field has
+// as `content` a binary string that is the encoding of a
+// `TextPluginData` proto.
+message TextPluginData {
+  // Version `0` is the only supported version.
+  int32 version = 1;
+}
diff --git a/tensorboard/compat/tensorboard/BUILD b/tensorboard/compat/tensorboard/BUILD
new file mode 100644
index 0000000000..51c2a63b42
--- /dev/null
+++ b/tensorboard/compat/tensorboard/BUILD
@@ -0,0 +1,22 @@
+# Description:
+#   TensorBoard, a dashboard for investigating TensorFlow
+
+package(default_visibility = ["//tensorboard:internal"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "tensorboard",
+    srcs = glob([
+        "*.py",
+    ]),
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorboard:expect_numpy_installed",
+        "//tensorboard/compat/proto:protos_all_py_pb2",
+        "@org_pythonhosted_six",
+    ],
+)
\ No newline at end of file
diff --git a/tensorboard/compat/tensorboard/__init__.py b/tensorboard/compat/tensorboard/__init__.py
new file mode 100644
index 0000000000..1014fa8d1b
--- /dev/null
+++ b/tensorboard/compat/tensorboard/__init__.py
@@ -0,0 +1,37 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Compatibility interfaces for TensorBoard."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+USING_TF = True
+
+# Don't attempt to use TF at all if this import exists due to build rules.
+try:
+    from tensorboard.compat import notf
+    USING_TF = False
+except ImportError:
+    pass
+
+if USING_TF:
+    try:
+        import tensorflow as tf
+    except ImportError:
+        USING_TF = False
+
+if not USING_TF:
+    from tensorboard.compat import tensorflow_stub as tf
diff --git a/tensorboard/compat/tensorboard/crc32c.py b/tensorboard/compat/tensorboard/crc32c.py
new file mode 100644
index 0000000000..461334cd15
--- /dev/null
+++ b/tensorboard/compat/tensorboard/crc32c.py
@@ -0,0 +1,138 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import array
+
+
+CRC_TABLE = (
+    0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4,
+    0xc79a971f, 0x35f1141c, 0x26a1e7e8, 0xd4ca64eb,
+    0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b,
+    0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24,
+    0x105ec76f, 0xe235446c, 0xf165b798, 0x030e349b,
+    0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384,
+    0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54,
+    0x5d1d08bf, 0xaf768bbc, 0xbc267848, 0x4e4dfb4b,
+    0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a,
+    0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35,
+    0xaa64d611, 0x580f5512, 0x4b5fa6e6, 0xb93425e5,
+    0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa,
+    0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45,
+    0xf779deae, 0x05125dad, 0x1642ae59, 0xe4292d5a,
+    0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a,
+    0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595,
+    0x417b1dbc, 0xb3109ebf, 0xa0406d4b, 0x522bee48,
+    0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957,
+    0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687,
+    0x0c38d26c, 0xfe53516f, 0xed03a29b, 0x1f682198,
+    0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927,
+    0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38,
+    0xdbfc821c, 0x2997011f, 0x3ac7f2eb, 0xc8ac71e8,
+    0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7,
+    0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096,
+    0xa65c047d, 0x5437877e, 0x4767748a, 0xb50cf789,
+    0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859,
+    0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46,
+    0x7198540d, 0x83f3d70e, 0x90a324fa, 0x62c8a7f9,
+    0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6,
+    0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36,
+    0x3cdb9bdd, 0xceb018de, 0xdde0eb2a, 0x2f8b6829,
+    0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c,
+    0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93,
+    0x082f63b7, 0xfa44e0b4, 0xe9141340, 0x1b7f9043,
+    0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c,
+    0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3,
+    0x55326b08, 0xa759e80b, 0xb4091bff, 0x466298fc,
+    0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c,
+    0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033,
+    0xa24bb5a6, 0x502036a5, 0x4370c551, 0xb11b4652,
+    0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d,
+    0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d,
+    0xef087a76, 0x1d63f975, 0x0e330a81, 0xfc588982,
+    0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d,
+    0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622,
+    0x38cc2a06, 0xcaa7a905, 0xd9f75af1, 0x2b9cd9f2,
+    0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed,
+    0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530,
+    0x0417b1db, 0xf67c32d8, 0xe52cc12c, 0x1747422f,
+    0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff,
+    0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0,
+    0xd3d3e1ab, 0x21b862a8, 0x32e8915c, 0xc083125f,
+    0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540,
+    0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90,
+    0x9e902e7b, 0x6cfbad78, 0x7fab5e8c, 0x8dc0dd8f,
+    0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee,
+    0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1,
+    0x69e9f0d5, 0x9b8273d6, 0x88d28022, 0x7ab90321,
+    0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e,
+    0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81,
+    0x34f4f86a, 0xc69f7b69, 0xd5cf889d, 0x27a40b9e,
+    0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e,
+    0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351,
+)
+
+
+CRC_INIT = 0
+
+_MASK = 0xFFFFFFFF
+
+
+def crc_update(crc, data):
+    """Update CRC-32C checksum with data.
+
+    Args:
+      crc: 32-bit checksum to update as long.
+      data: byte array, string or iterable over bytes.
+
+    Returns:
+      32-bit updated CRC-32C as long.
+    """
+
+    if type(data) != array.array or data.itemsize != 1:
+        buf = array.array("B", data)
+    else:
+        buf = data
+
+    crc ^= _MASK
+    for b in buf:
+        table_index = (crc ^ b) & 0xff
+        crc = (CRC_TABLE[table_index] ^ (crc >> 8)) & _MASK
+    return crc ^ _MASK
+
+
+def crc_finalize(crc):
+    """Finalize CRC-32C checksum.
+
+    This function should be called as last step of crc calculation.
+
+    Args:
+      crc: 32-bit checksum as long.
+
+    Returns:
+      finalized 32-bit checksum as long
+    """
+    return crc & _MASK
+
+
+def crc32c(data):
+    """Compute CRC-32C checksum of the data.
+
+    Args:
+      data: byte array, string or iterable over bytes.
+
+    Returns:
+      32-bit CRC-32C checksum of data as long.
+    """
+    return crc_finalize(crc_update(CRC_INIT, data))
diff --git a/tensorboard/compat/tensorboard/event_file_writer.py b/tensorboard/compat/tensorboard/event_file_writer.py
new file mode 100644
index 0000000000..59097a9abb
--- /dev/null
+++ b/tensorboard/compat/tensorboard/event_file_writer.py
@@ -0,0 +1,231 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Writes events to disk in a logdir."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+import socket
+import threading
+import time
+
+import six
+
+from ..proto import event_pb2
+from .record_writer import RecordWriter, directory_check
+
+
+class EventsWriter(object):
+    '''Writes `Event` protocol buffers to an event file.'''
+
+    def __init__(self, file_prefix, filename_suffix=''):
+        '''
+        Events files have a name of the form
+        '/some/file/path/events.out.tfevents.[timestamp].[hostname]'
+        '''
+        self._file_name = file_prefix + ".out.tfevents." + str(time.time())[:10] + "." +\
+            socket.gethostname() + filename_suffix
+
+        self._num_outstanding_events = 0
+
+        self._py_recordio_writer = RecordWriter(self._file_name)
+
+        # Initialize an event instance.
+        self._event = event_pb2.Event()
+
+        self._event.wall_time = time.time()
+
+        self._lock = threading.Lock()
+
+        self.write_event(self._event)
+
+    def write_event(self, event):
+        '''Append "event" to the file.'''
+
+        # Check if event is of type event_pb2.Event proto.
+        if not isinstance(event, event_pb2.Event):
+            raise TypeError("Expected an event_pb2.Event proto, "
+                            "but got %s" % type(event))
+        return self._write_serialized_event(event.SerializeToString())
+
+    def _write_serialized_event(self, event_str):
+        with self._lock:
+            self._num_outstanding_events += 1
+            self._py_recordio_writer.write(event_str)
+
+    def flush(self):
+        '''Flushes the event file to disk.'''
+        with self._lock:
+            self._num_outstanding_events = 0
+            self._py_recordio_writer.flush()
+        return True
+
+    def close(self):
+        '''Call self.flush().'''
+        return_value = self.flush()
+        with self._lock:
+            self._py_recordio_writer.close()
+        return return_value
+
+
+class EventFileWriter(object):
+    """Writes `Event` protocol buffers to an event file.
+    The `EventFileWriter` class creates an event file in the specified directory,
+    and asynchronously writes Event protocol buffers to the file. The Event file
+    is encoded using the tfrecord format, which is similar to RecordIO.
+    @@__init__
+    @@add_event
+    @@flush
+    @@close
+    """
+
+    def __init__(self, logdir, max_queue=10, flush_secs=120, filename_suffix=''):
+        """Creates an `EventFileWriter` and an event file to write to.
+        On construction the summary writer creates a new event file in `logdir`.
+        This event file will contain `Event` protocol buffers, which are written to
+        disk via the add_event method.
+        The other arguments to the constructor control the asynchronous writes to
+        the event file:
+        * `flush_secs`: How often, in seconds, to flush the added summaries
+          and events to disk.
+        * `max_queue`: Maximum number of summaries or events pending to be
+          written to disk before one of the 'add' calls blocks.
+        Args:
+          logdir: A string. Directory where event file will be written.
+          max_queue: Integer. Size of the queue for pending events and summaries.
+          flush_secs: Number. How often, in seconds, to flush the
+            pending events and summaries to disk.
+        """
+        self._logdir = logdir
+        directory_check(self._logdir)
+        self._event_queue = six.moves.queue.Queue(max_queue)
+        self._ev_writer = EventsWriter(os.path.join(
+            self._logdir, "events"), filename_suffix)
+        self._flush_secs = flush_secs
+        self._closed = False
+        self._worker = _EventLoggerThread(self._event_queue, self._ev_writer,
+                                          flush_secs)
+
+        self._worker.start()
+
+    def get_logdir(self):
+        """Returns the directory where event file will be written."""
+        return self._logdir
+
+    def reopen(self):
+        """Reopens the EventFileWriter.
+        Can be called after `close()` to add more events in the same directory.
+        The events will go into a new events file and a new write/flush worker
+        is created. Does nothing if the EventFileWriter was not closed.
+        """
+        if self._closed:
+            self._closed = False
+            self._worker = _EventLoggerThread(
+                self._event_queue, self._ev_writer, self._flush_secs
+            )
+            self._worker.start()
+
+    def add_event(self, event):
+        """Adds an event to the event file.
+        Args:
+          event: An `Event` protocol buffer.
+        """
+        if not self._closed:
+            self._event_queue.put(event)
+
+    def flush(self):
+        """Flushes the event file to disk.
+        Call this method to make sure that all pending events have been written to
+        disk.
+        """
+        if not self._closed:
+            self._event_queue.join()
+            self._ev_writer.flush()
+
+    def close(self):
+        """Performs a final flush of the event file to disk, stops the
+        write/flush worker and closes the file. Call this method when you do not
+        need the summary writer anymore.
+        """
+        if not self._closed:
+            self.flush()
+            self._worker.stop()
+            self._ev_writer.close()
+            self._closed = True
+
+
+class _EventLoggerThread(threading.Thread):
+    """Thread that logs events."""
+
+    def __init__(self, queue, ev_writer, flush_secs):
+        """Creates an _EventLoggerThread.
+        Args:
+          queue: A Queue from which to dequeue events.
+          ev_writer: An event writer. Used to log brain events for
+            the visualizer.
+          flush_secs: How often, in seconds, to flush the
+            pending file to disk.
+        """
+        threading.Thread.__init__(self)
+        self.daemon = True
+        self._queue = queue
+        self._ev_writer = ev_writer
+        self._flush_secs = flush_secs
+        # The first event will be flushed immediately.
+        self._next_event_flush_time = 0
+        self._has_pending_events = False
+        self._shutdown_signal = object()
+
+    def stop(self):
+        self._queue.put(self._shutdown_signal)
+        self.join()
+
+    def run(self):
+        # Here wait on the queue until an event appears, or till the next
+        # time to flush the writer, whichever is earlier. If we have an
+        # event, write it. If not, an empty queue exception will be raised
+        # and we can proceed to flush the writer.
+        while True:
+            now = time.time()
+            queue_wait_duration = self._next_event_flush_time - now
+            event = None
+            try:
+                if queue_wait_duration > 0:
+                    event = self._queue.get(True, queue_wait_duration)
+                else:
+                    event = self._queue.get(False)
+
+                if event == self._shutdown_signal:
+                    return
+                self._ev_writer.write_event(event)
+                self._has_pending_events = True
+            except six.moves.queue.Empty:
+                pass
+            finally:
+                if event:
+                    self._queue.task_done()
+
+            now = time.time()
+            if now > self._next_event_flush_time:
+                if self._has_pending_events:
+                    # Small optimization - if there are no pending events,
+                    # there's no need to flush, since each flush can be
+                    # expensive (e.g. uploading a new file to a server).
+                    self._ev_writer.flush()
+                    self._has_pending_events = False
+                # Do it again in flush_secs.
+                self._next_event_flush_time = now + self._flush_secs
diff --git a/tensorboard/compat/tensorboard/record_writer.py b/tensorboard/compat/tensorboard/record_writer.py
new file mode 100644
index 0000000000..cca51f8317
--- /dev/null
+++ b/tensorboard/compat/tensorboard/record_writer.py
@@ -0,0 +1,158 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""
+Writes data in the TFRecord format to a file. Here we use it for TensorBoard's event writing.
+The code was borrowed from https://github.com/TeamHG-Memex/tensorboard_logger
+"""
+
+import copy
+import io
+import os.path
+import re
+import struct
+try:
+    import boto3
+    S3_ENABLED = True
+except ImportError:
+    S3_ENABLED = False
+
+from .crc32c import crc32c
+
+
+_VALID_OP_NAME_START = re.compile('^[A-Za-z0-9.]')
+_VALID_OP_NAME_PART = re.compile('[A-Za-z0-9_.\\-/]+')
+
+# Registry of writer factories, keyed by path prefix.
+#
+# Currently supports "s3://" URLs for S3 based on boto and falls
+# back to the local filesystem.
+REGISTERED_FACTORIES = {}
+
+
+def register_writer_factory(prefix, factory):
+    if ':' in prefix:
+        raise ValueError('prefix cannot contain a :')
+    REGISTERED_FACTORIES[prefix] = factory
+
+
+def directory_check(path):
+    '''Initialize the directory for log files.'''
+    try:
+        prefix = path.split(':')[0]
+        factory = REGISTERED_FACTORIES[prefix]
+        return factory.directory_check(path)
+    except KeyError:
+        if not os.path.exists(path):
+            os.makedirs(path)
+
+
+def open_file(path):
+    '''Open a writer for outputting event files.'''
+    try:
+        prefix = path.split(':')[0]
+        factory = REGISTERED_FACTORIES[prefix]
+        return factory.open(path)
+    except KeyError:
+        return open(path, 'wb')
+
+
+class S3RecordWriter(object):
+    """Writes tensorboard protocol buffer files to S3."""
+
+    def __init__(self, path):
+        if not S3_ENABLED:
+            raise ImportError("boto3 must be installed for S3 support.")
+        self.path = path
+        self.buffer = io.BytesIO()
+
+    def __del__(self):
+        self.close()
+
+    def bucket_and_path(self):
+        path = self.path
+        if path.startswith("s3://"):
+            path = path[len("s3://"):]
+        bp = path.split("/")
+        bucket = bp[0]
+        path = path[1 + len(bucket):]
+        return bucket, path
+
+    def write(self, val):
+        self.buffer.write(val)
+
+    def flush(self):
+        s3 = boto3.client('s3')
+        bucket, path = self.bucket_and_path()
+        upload_buffer = copy.copy(self.buffer)
+        upload_buffer.seek(0)
+        s3.upload_fileobj(upload_buffer, bucket, path)
+
+    def close(self):
+        self.flush()
+
+
+class S3RecordWriterFactory(object):
+    """Factory for event protocol buffer files to S3."""
+
+    def open(self, path):
+        return S3RecordWriter(path)
+
+    def directory_check(self, path):
+        # S3 doesn't need directories created before files are added
+        # so we can just skip this check
+        pass
+
+
+register_writer_factory("s3", S3RecordWriterFactory())
+
+
+class RecordWriter(object):
+    def __init__(self, path):
+        self._name_to_tf_name = {}
+        self._tf_names = set()
+        self.path = path
+        self._writer = None
+        self._writer = open_file(path)
+
+    def write(self, event_str):
+        w = self._writer.write
+        # TFRecord framing: length and CRCs are little-endian on disk.
+        header = struct.pack('<Q', len(event_str))
+        w(header)
+        w(struct.pack('<I', masked_crc32c(header)))
+        w(event_str)
+        w(struct.pack('<I', masked_crc32c(event_str)))
+
+    def flush(self):
+        self._writer.flush()
+
+    def close(self):
+        self._writer.close()
+
+
+def masked_crc32c(data):
+    x = u32(crc32c(data))
+    return u32(((x >> 15) | u32(x << 17)) + 0xa282ead8)
+
+
+def u32(x):
+    return x & 0xffffffff
+
+
+def make_valid_tf_name(name):
+    if not _VALID_OP_NAME_START.match(name):
+        # Must make it valid somehow, but don't want to remove stuff
+        name = '.' + name
+    return '_'.join(_VALID_OP_NAME_PART.findall(name))
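For trying the pure-Python writer added above without TensorFlow installed, a minimal, untested sketch follows. It assumes the wheel built by buildPIPandDeploy.sh is installed and that the generated `tensorboard.compat.proto.event_pb2` module is importable; the logdir and tag names are illustrative only, not part of this change.

```python
# Sketch only: exercise EventFileWriter from tensorboard/compat/tensorboard.
from tensorboard.compat.proto import event_pb2
from tensorboard.compat.tensorboard.event_file_writer import EventFileWriter

# Creates events.out.tfevents.* under the (hypothetical) logdir and starts
# the background writer thread.
writer = EventFileWriter("/tmp/compat_demo", flush_secs=1)

# Build a scalar summary event by hand.
event = event_pb2.Event(step=1)
value = event.summary.value.add()
value.tag = "demo/scalar"
value.simple_value = 0.5

writer.add_event(event)
writer.flush()   # waits until the worker thread has drained the queue
writer.close()   # final flush, stops the worker, closes the file
```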