Skip to content

Commit

Permalink
manager: add I/O for TensorboardInfo files (#1806)
Browse files Browse the repository at this point in the history
Summary:
This commit implements functions `write_info_file`, `remove_info_file`,
and `get_all` on the `tensorboard.manager` module. See docs for details.

Supersedes part of #1795.

Test Plan:
Integration tests included; run `bazel test //tensorboard:manager_test`.

wchargin-branch: tensorboardinfo-io
  • Loading branch information
wchargin committed Feb 7, 2019
1 parent ba54805 commit 20435a1
Show file tree
Hide file tree
Showing 3 changed files with 215 additions and 2 deletions.
2 changes: 2 additions & 0 deletions tensorboard/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ py_library(
visibility = ["//tensorboard:internal"],
deps = [
":version",
"//tensorboard/util:tb_logging",
"@org_pythonhosted_six",
],
)
Expand All @@ -89,6 +90,7 @@ py_test(
deps = [
":manager",
":version",
"//tensorboard/util:tb_logging",
"//tensorboard:expect_tensorflow_installed",
"@org_pythonhosted_six",
],
Expand Down
105 changes: 105 additions & 0 deletions tensorboard/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@
import base64
import collections
import datetime
import errno
import json
import os
import tempfile

import six

from tensorboard import version
from tensorboard.util import tb_logging


# Type descriptors for `TensorboardInfo` fields.
Expand Down Expand Up @@ -199,3 +203,104 @@ def cache_key(working_directory, arguments, configure_kwargs):
# `raw` is of type `bytes`, even though it only contains ASCII
# characters; we want it to be `str` in both Python 2 and 3.
return str(raw.decode("ascii"))


def _get_info_dir():
"""Get path to directory in which to store info files.
The directory returned by this function is "owned" by this module. If
the contents of the directory are modified other than via the public
functions of this module, subsequent behavior is undefined.
The directory will be created if it does not exist.
"""
path = os.path.join(tempfile.gettempdir(), ".tensorboard-info")
try:
os.makedirs(path)
except OSError as e:
if e.errno == errno.EEXIST and os.path.isdir(path):
pass
else:
raise
return path


def _get_info_file_path():
"""Get path to info file for the current process.
As with `_get_info_dir`, the info directory will be created if it does
not exist.
"""
return os.path.join(_get_info_dir(), "pid-%d.info" % os.getpid())


def write_info_file(tensorboard_info):
"""Write TensorboardInfo to the current process's info file.
This should be called by `main` once the server is ready. When the
server shuts down, `remove_info_file` should be called.
Args:
tensorboard_info: A valid `TensorboardInfo` object.
Raises:
ValueError: If any field on `info` is not of the correct type.
"""
payload = "%s\n" % _info_to_string(tensorboard_info)
with open(_get_info_file_path(), "w") as outfile:
outfile.write(payload)


def remove_info_file():
"""Remove the current process's TensorboardInfo file, if it exists.
If the file does not exist, no action is taken and no error is raised.
"""
try:
os.unlink(_get_info_file_path())
except OSError as e:
if e.errno == errno.ENOENT:
# The user may have wiped their temporary directory or something.
# Not a problem: we're already in the state that we want to be in.
pass
else:
raise


def get_all():
"""Return TensorboardInfo values for running TensorBoard processes.
This function may not provide a perfect snapshot of the set of running
processes. Its result set may be incomplete if the user has cleaned
their /tmp/ directory while TensorBoard processes are running. It may
contain extraneous entries if TensorBoard processes exited uncleanly
(e.g., with SIGKILL or SIGQUIT).
Returns:
A fresh list of `TensorboardInfo` objects.
"""
info_dir = _get_info_dir()
results = []
for filename in os.listdir(info_dir):
filepath = os.path.join(info_dir, filename)
try:
with open(filepath) as infile:
contents = infile.read()
except IOError as e:
if e.errno == errno.EACCES:
# May have been written by this module in a process whose
# `umask` includes some bits of 0o444.
continue
else:
raise
try:
info = _info_from_string(contents)
except ValueError:
tb_logging.get_logger().warning(
"invalid info file: %r",
filepath,
exc_info=True,
)
else:
results.append(info)
return results
110 changes: 108 additions & 2 deletions tensorboard/manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,24 @@
from __future__ import print_function

import datetime
import errno
import json
import os
import re
import tempfile

import six
import tensorflow as tf

try:
# python version >= 3.3
from unittest import mock # pylint: disable=g-import-not-at-top
except ImportError:
import mock # pylint: disable=g-import-not-at-top,unused-import

from tensorboard import manager
from tensorboard import version
from tensorboard.util import tb_logging


def _make_info(i=0):
Expand Down Expand Up @@ -54,8 +64,8 @@ class TensorboardInfoTest(tf.test.TestCase):
"""Unit tests for TensorboardInfo typechecking and serialization."""

def test_roundtrip_serialization(self):
# This will also be tested indirectly as part of `manager`
# integration tests.
# This is also tested indirectly as part of `manager` integration
# tests, in `test_get_all`.
info = _make_info()
also_info = manager._info_from_string(manager._info_to_string(info))
self.assertEqual(also_info, info)
Expand Down Expand Up @@ -235,5 +245,101 @@ def test_arguments_list_vs_tuple_irrelevant(self):
self.assertEqual(with_list, with_tuple)


class TensorboardInfoIoTest(tf.test.TestCase):
"""Tests for `write_info_file`, `remove_info_file`, and `get_all`."""

def setUp(self):
super(TensorboardInfoIoTest, self).setUp()
patcher = mock.patch.dict(os.environ, {"TMPDIR": self.get_temp_dir()})
patcher.start()
self.addCleanup(patcher.stop)
tempfile.tempdir = None # force `gettempdir` to reinitialize from env
self.info_dir = manager._get_info_dir() # ensure that directory exists

def _list_info_dir(self):
return os.listdir(self.info_dir)

def test_fails_if_info_dir_name_is_taken_by_a_regular_file(self):
os.rmdir(self.info_dir)
with open(self.info_dir, "w") as outfile:
pass
with self.assertRaises(OSError) as cm:
manager._get_info_dir()
self.assertEqual(cm.exception.errno, errno.EEXIST, cm.exception)

@mock.patch("os.getpid", lambda: 76540)
def test_write_remove_info_file(self):
info = _make_info()
self.assertEqual(self._list_info_dir(), [])
manager.write_info_file(info)
filename = "pid-76540.info"
expected_filepath = os.path.join(self.info_dir, filename)
self.assertEqual(self._list_info_dir(), [filename])
with open(expected_filepath) as infile:
self.assertEqual(manager._info_from_string(infile.read()), info)
manager.remove_info_file()
self.assertEqual(self._list_info_dir(), [])

def test_write_info_file_rejects_bad_types(self):
# The particulars of validation are tested more thoroughly in
# `TensorboardInfoTest` above.
info = _make_info()._replace(start_time=1549061116)
with six.assertRaisesRegex(
self,
ValueError,
"expected 'start_time' of type.*datetime.*, but found: 1549061116"):
manager.write_info_file(info)
self.assertEqual(self._list_info_dir(), [])

def test_write_info_file_rejects_wrong_version(self):
# The particulars of validation are tested more thoroughly in
# `TensorboardInfoTest` above.
info = _make_info()._replace(version="reversion")
with six.assertRaisesRegex(
self,
ValueError,
"expected 'version' to be '.*', but found: 'reversion'"):
manager.write_info_file(info)
self.assertEqual(self._list_info_dir(), [])

def test_remove_nonexistent(self):
# Should be a no-op, except to create the info directory if
# necessary. In particular, should not raise any exception.
manager.remove_info_file()

def test_get_all(self):
def add_info(i):
with mock.patch("os.getpid", lambda: 76540 + i):
manager.write_info_file(_make_info(i))
def remove_info(i):
with mock.patch("os.getpid", lambda: 76540 + i):
manager.remove_info_file()
self.assertItemsEqual(manager.get_all(), [])
add_info(1)
self.assertItemsEqual(manager.get_all(), [_make_info(1)])
add_info(2)
self.assertItemsEqual(manager.get_all(), [_make_info(1), _make_info(2)])
remove_info(1)
self.assertItemsEqual(manager.get_all(), [_make_info(2)])
add_info(3)
self.assertItemsEqual(manager.get_all(), [_make_info(2), _make_info(3)])
remove_info(3)
self.assertItemsEqual(manager.get_all(), [_make_info(2)])
remove_info(2)
self.assertItemsEqual(manager.get_all(), [])

def test_get_all_ignores_bad_files(self):
with open(os.path.join(self.info_dir, "pid-1234.info"), "w") as outfile:
outfile.write("good luck parsing this\n")
with open(os.path.join(self.info_dir, "pid-5678.info"), "w") as outfile:
outfile.write('{"valid_json":"yes","valid_tbinfo":"no"}\n')
with open(os.path.join(self.info_dir, "pid-9012.info"), "w") as outfile:
outfile.write('if a tbinfo has st_mode==0, does it make a sound?\n')
os.chmod(os.path.join(self.info_dir, "pid-9012.info"), 0o000)
with mock.patch.object(tb_logging.get_logger(), "warning") as fn:
self.assertEqual(manager.get_all(), [])
self.assertEqual(fn.call_count, 2) # 2 invalid, 1 unreadable (silent)


if __name__ == "__main__":
tf.test.main()

0 comments on commit 20435a1

Please sign in to comment.