Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

projector: read run names from data provider #4494

Merged
merged 7 commits
Jan 5, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions tensorboard/plugins/projector/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ py_library(
deps = [
":metadata",
":protos_all_py_pb2",
"//tensorboard:context",
"//tensorboard:expect_numpy_installed",
"//tensorboard/backend:http_util",
"//tensorboard/backend/event_processing:plugin_asset_util",
"//tensorboard/compat:tensorflow",
"//tensorboard/plugins:base_plugin",
"//tensorboard/util:tb_logging",
Expand Down Expand Up @@ -62,6 +64,7 @@ py_test(
"//tensorboard:expect_numpy_installed",
"//tensorboard:expect_tensorflow_installed",
"//tensorboard/backend:application",
"//tensorboard/backend/event_processing:data_provider",
"//tensorboard/backend/event_processing:event_multiplexer",
"//tensorboard/compat/proto:protos_all_py_pb2",
"//tensorboard/plugins:base_plugin",
Expand All @@ -82,6 +85,7 @@ py_test(
"//tensorboard:expect_numpy_installed",
"//tensorboard:expect_tensorflow_installed",
"//tensorboard/backend:application",
"//tensorboard/backend/event_processing:data_provider",
"//tensorboard/backend/event_processing:event_multiplexer",
"//tensorboard/compat:no_tensorflow",
"//tensorboard/compat/proto:protos_all_py_pb2",
Expand Down
95 changes: 50 additions & 45 deletions tensorboard/plugins/projector/projector_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from google.protobuf import json_format
from google.protobuf import text_format

from tensorboard import context
from tensorboard.backend.event_processing import plugin_asset_util
from tensorboard.backend.http_util import Respond
from tensorboard.compat import tf
from tensorboard.plugins import base_plugin
Expand Down Expand Up @@ -239,12 +241,11 @@ def __init__(self, context):
Args:
context: A base_plugin.TBContext instance.
"""
self.multiplexer = context.multiplexer
self.data_provider = context.data_provider
self.logdir = context.logdir
self.readers = {}
self.run_paths = None
self._run_paths = None
self._configs = {}
self.old_num_run_paths = None
self.config_fpaths = None
self.tensor_cache = LRUCache(_TENSOR_CACHE_CAPACITY)

Expand All @@ -257,9 +258,6 @@ def __init__(self, context):
# active. If such a thread exists, do not start a duplicate thread.
self._thread_for_determining_is_active = None

if self.multiplexer:
self.run_paths = self.multiplexer.RunPaths()

def get_plugin_apps(self):
asset_prefix = "tf_projector_plugin"
return {
Expand All @@ -286,12 +284,13 @@ def get_plugin_apps(self):
def is_active(self):
"""Determines whether this plugin is active.

This plugin is only active if any run has an embedding.
This plugin is only active if any run has an embedding, and only
when running against a local log directory.

Returns:
Whether any run has embedding data to show in the projector.
"""
if not self.multiplexer:
if not self.data_provider or not self.logdir:
return False

if self._is_active:
Expand Down Expand Up @@ -329,37 +328,41 @@ def _determine_is_active(self):
offer an immediate response to whether it is active and
determine whether it should be active in a separate thread.
"""
if self.configs:
self._update_configs()
if self._configs:
self._is_active = True
self._thread_for_determining_is_active = None

@property
def configs(self):
"""Returns a map of run paths to `ProjectorConfig` protos."""
run_path_pairs = list(self.run_paths.items())
def _update_configs(self):
"""Updates `self._configs` and `self._run_paths`."""
if self.data_provider and self.logdir:
# Create a background context; we may not be in a request.
ctx = context.RequestContext()
run_paths = {
run.run_name: os.path.join(self.logdir, run.run_name)
for run in self.data_provider.list_runs(ctx, experiment_id="")
}
else:
run_paths = {}
run_paths_changed = run_paths != self._run_paths
self._run_paths = run_paths

run_path_pairs = list(self._run_paths.items())
self._append_plugin_asset_directories(run_path_pairs)
# Also accept the root logdir as a model checkpoint directory,
# so that the projector still works when there are no runs.
# (Case on `run` rather than `path` to avoid issues with
# absolute/relative paths on any filesystems.)
if not any(run == "." for (run, path) in run_path_pairs):
if "." not in self._run_paths:
run_path_pairs.append((".", self.logdir))
if self._run_paths_changed() or _latest_checkpoints_changed(
if run_paths_changed or _latest_checkpoints_changed(
self._configs, run_path_pairs
):
self.readers = {}
self._configs, self.config_fpaths = self._read_latest_config_files(
run_path_pairs
)
self._augment_configs_with_checkpoint_info()
return self._configs

def _run_paths_changed(self):
num_run_paths = len(list(self.run_paths.keys()))
if num_run_paths != self.old_num_run_paths:
self.old_num_run_paths = num_run_paths
return True
return False

def _augment_configs_with_checkpoint_info(self):
for run, config in self._configs.items():
Expand Down Expand Up @@ -518,18 +521,20 @@ def _get_embedding(self, tensor_name, config):
return None

def _append_plugin_asset_directories(self, run_path_pairs):
for run, assets in self.multiplexer.PluginAssets(
metadata.PLUGIN_ASSETS_NAME
).items():
extra = []
plugin_assets_name = metadata.PLUGIN_ASSETS_NAME
for (run, logdir) in run_path_pairs:
assets = plugin_asset_util.ListAssets(logdir, plugin_assets_name)
if metadata.PROJECTOR_FILENAME not in assets:
continue
assets_dir = os.path.join(
self.run_paths[run],
self._run_paths[run],
metadata.PLUGINS_DIR,
metadata.PLUGIN_ASSETS_NAME,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe replace this with plugin_assets_name, since we've extracted it in L525?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, indeed. Thanks.

)
assets_path_pair = (run, os.path.abspath(assets_dir))
run_path_pairs.append(assets_path_pair)
extra.append(assets_path_pair)
run_path_pairs.extend(extra)

@wrappers.Request.application
def _serve_file(self, file_path, request):
Expand All @@ -542,7 +547,8 @@ def _serve_file(self, file_path, request):
@wrappers.Request.application
def _serve_runs(self, request):
"""Returns a list of runs that have embeddings."""
return Respond(request, list(self.configs.keys()), "application/json")
self._update_configs()
return Respond(request, list(self._configs.keys()), "application/json")

@wrappers.Request.application
def _serve_config(self, request):
Expand All @@ -551,12 +557,12 @@ def _serve_config(self, request):
return Respond(
request, 'query parameter "run" is required', "text/plain", 400
)
if run not in self.configs:
self._update_configs()
config = self._configs.get(run)
if config is None:
return Respond(
request, 'Unknown run: "%s"' % run, "text/plain", 400
)

config = self.configs[run]
return Respond(
request, json_format.MessageToJson(config), "application/json"
)
Expand Down Expand Up @@ -584,12 +590,12 @@ def _serve_metadata(self, request):
400,
)

if run not in self.configs:
self._update_configs()
config = self._configs.get(run)
if config is None:
return Respond(
request, 'Unknown run: "%s"' % run, "text/plain", 400
)

config = self.configs[run]
fpath = self._get_metadata_file_for_tensor(name, config)
if not fpath:
return Respond(
Expand Down Expand Up @@ -644,13 +650,12 @@ def _serve_tensor(self, request):
400,
)

if run not in self.configs:
self._update_configs()
config = self._configs.get(run)
if config is None:
return Respond(
request, 'Unknown run: "%s"' % run, "text/plain", 400
)

config = self.configs[run]

tensor = self.tensor_cache.get((run, name))
if tensor is None:
# See if there is a tensor file in the config.
Expand Down Expand Up @@ -711,12 +716,12 @@ def _serve_bookmarks(self, request):
request, 'query parameter "name" is required', "text/plain", 400
)

if run not in self.configs:
self._update_configs()
config = self._configs.get(run)
if config is None:
return Respond(
request, 'Unknown run: "%s"' % run, "text/plain", 400
)

config = self.configs[run]
fpath = self._get_bookmarks_file_for_tensor(name, config)
if not fpath:
return Respond(
Expand Down Expand Up @@ -754,14 +759,14 @@ def _serve_sprite_image(self, request):
request, 'query parameter "name" is required', "text/plain", 400
)

if run not in self.configs:
self._update_configs()
config = self._configs.get(run)
if config is None:
return Respond(
request, 'Unknown run: "%s"' % run, "text/plain", 400
)

config = self.configs[run]
embedding_info = self._get_embedding(name, config)

if not embedding_info or not embedding_info.sprite.image_path:
return Respond(
request,
Expand Down
7 changes: 4 additions & 3 deletions tensorboard/plugins/projector/projector_plugin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from google.protobuf import text_format

from tensorboard.backend import application
from tensorboard.backend.event_processing import data_provider
from tensorboard.backend.event_processing import (
plugin_event_multiplexer as event_multiplexer,
)
Expand Down Expand Up @@ -274,10 +275,10 @@ def testPluginIsNotActive(self):
self.assertEqual(2, mock.call_count)

def _SetupWSGIApp(self):
logdir = self.log_dir
multiplexer = event_multiplexer.EventMultiplexer()
context = base_plugin.TBContext(
logdir=self.log_dir, multiplexer=multiplexer
)
provider = data_provider.MultiplexerDataProvider(multiplexer, logdir)
context = base_plugin.TBContext(logdir=logdir, data_provider=provider)
self.plugin = projector_plugin.ProjectorPlugin(context)
wsgi_app = application.TensorBoardWSGI([self.plugin])
self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse)
Expand Down