Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --plugins option to uploader #3402

Merged
merged 4 commits into from
Mar 21, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
41 changes: 27 additions & 14 deletions tensorboard/uploader/proto/server_info.proto
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,37 @@ package tensorboard.service;
message ServerInfoRequest {
// Client-side TensorBoard version, per `tensorboard.version.VERSION`.
string version = 1;
// Information about the plugins for which the client wishes to upload data.
//
// If specified then the list of plugins will be confirmed by the server and
// echoed in the PluginControl.allowed_plugins field. Otherwise the server
// will return the default set of plugins it supports.
//
// If one of the plugins is not supported by the server then it will respond
// with compatibility verdict VERDICT_ERROR.
PluginSpecification plugin_specification = 2;
}

message ServerInfoResponse {
// Primary bottom-line: is the server compatible with the client, and is
// there anything that the end user should be aware of?
// Primary bottom-line: is the server compatible with the client, can it
// serve its request, and is there anything that the end user should be
// aware of?
Compatibility compatibility = 1;
// Identifier for a gRPC server providing the `TensorBoardExporterService` and
// `TensorBoardWriterService` services (under the `tensorboard.service` proto
// package).
ApiServer api_server = 2;
// How to generate URLs to experiment pages.
ExperimentUrlFormat url_format = 3;
// For which plugins should we upload data? (Even if the uploader is
// structurally capable of uploading data from many plugins, we only actually
// upload data that can be currently displayed in TensorBoard.dev. Otherwise,
// users may be surprised to see that experiments that they uploaded a while
// ago and have since shared or published now have extra information that
// they didn't realize had been uploaded.)
// Information about the plugins for which data should be uploaded.
//
// The client may always choose to upload less data than is permitted by this
// field: e.g., if the end user specifies not to upload data for a given
// plugin, or the client does not yet support uploading some kind of data.
// If PluginSpecification.requested_plugins is specified then
// that list of plugins will be confirmed by the server and echoed in the
// the response. Otherwise the server will return the default set of
// plugins it supports.
//
// If this field is omitted, there are no upfront restrictions on what the
// client may send.
// The client should only upload data for the plugins in the response even
// if it is capable of uploading more data.
PluginControl plugin_control = 4;
}

Expand Down Expand Up @@ -74,8 +80,15 @@ message ExperimentUrlFormat {
string id_placeholder = 2;
}

message PluginSpecification {
// Plugins for which the client wishes to upload data. These are plugin names
// as stored in the the `SummaryMetadata.plugin_data.plugin_name` proto
// field.
repeated string upload_plugins = 2;
}

message PluginControl {
// Only send data from plugins with these names. These are plugin names as
// Plugins for which data should be uploaded. These are plugin names as
// stored in the the `SummaryMetadata.plugin_data.plugin_name` proto field.
repeated string allowed_plugins = 1;
}
26 changes: 22 additions & 4 deletions tensorboard/uploader/server_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from google.protobuf import message
import requests

from absl import logging
from tensorboard import version
from tensorboard.plugins.scalar import metadata as scalars_metadata
from tensorboard.uploader.proto import server_info_pb2
Expand All @@ -30,19 +31,31 @@
_REQUEST_TIMEOUT_SECONDS = 10


def _server_info_request():
def _server_info_request(upload_plugins):
"""Generates a ServerInfoRequest

Args:
upload_plugins: List of plugin names requested by the user and to be
verified by the server.

Returns:
A `server_info_pb2.ServerInfoRequest` message.
"""
request = server_info_pb2.ServerInfoRequest()
request.version = version.VERSION
request.plugin_specification.upload_plugins[:] = upload_plugins
return request


def fetch_server_info(origin):
def fetch_server_info(origin, upload_plugins):
"""Fetches server info from a remote server.

Args:
origin: The server with which to communicate. Should be a string
like "https://tensorboard.dev", including protocol, host, and (if
needed) port.
upload_plugins: List of plugins names requested by the user and to be
verified by the server.

Returns:
A `server_info_pb2.ServerInfoResponse` message.
Expand All @@ -52,7 +65,9 @@ def fetch_server_info(origin):
communicate with the remote server.
"""
endpoint = "%s/api/uploader" % origin
post_body = _server_info_request().SerializeToString()
server_info_request = _server_info_request(upload_plugins)
post_body = server_info_request.SerializeToString()
logging.info("Requested server info: <%r>", server_info_request)
try:
response = requests.post(
endpoint,
Expand All @@ -75,13 +90,15 @@ def fetch_server_info(origin):
)


def create_server_info(frontend_origin, api_endpoint):
def create_server_info(frontend_origin, api_endpoint, upload_plugins):
"""Manually creates server info given a frontend and backend.

Args:
frontend_origin: The origin of the TensorBoard.dev frontend, like
"https://tensorboard.dev" or "http://localhost:8000".
api_endpoint: As to `server_info_pb2.ApiServer.endpoint`.
upload_plugins: List of plugin names requested by the user and to be
verified by the server.

Returns:
A `server_info_pb2.ServerInfoResponse` message.
Expand All @@ -95,6 +112,7 @@ def create_server_info(frontend_origin, api_endpoint):
placeholder = "{%s}" % placeholder
url_format.template = "%s/experiment/%s/" % (frontend_origin, placeholder)
url_format.id_placeholder = placeholder
result.plugin_control.allowed_plugins[:] = upload_plugins
return result


Expand Down
46 changes: 39 additions & 7 deletions tensorboard/uploader/server_info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,39 @@ def app(request):
body = request.get_data()
request_pb = server_info_pb2.ServerInfoRequest.FromString(body)
self.assertEqual(request_pb.version, version.VERSION)
self.assertEqual(request_pb.plugin_specification.upload_plugins, [])
return wrappers.BaseResponse(expected_result.SerializeToString())

origin = self._start_server(app)
result = server_info.fetch_server_info(origin)
result = server_info.fetch_server_info(origin, [])
self.assertEqual(result, expected_result)

def test_fetches_with_plugins(self):
@wrappers.BaseRequest.application
def app(request):
body = request.get_data()
request_pb = server_info_pb2.ServerInfoRequest.FromString(body)
self.assertEqual(request_pb.version, version.VERSION)
self.assertEqual(
request_pb.plugin_specification.upload_plugins,
["plugin1", "plugin2"],
)
return wrappers.BaseResponse(
server_info_pb2.ServerInfoResponse().SerializeToString()
)

origin = self._start_server(app)
result = server_info.fetch_server_info(origin, ["plugin1", "plugin2"])
self.assertIsNotNone(result)

def test_econnrefused(self):
(family, localhost) = _localhost()
s = socket.socket(family)
s.bind((localhost, 0))
self.addCleanup(s.close)
port = s.getsockname()[1]
with self.assertRaises(server_info.CommunicationError) as cm:
server_info.fetch_server_info("http://localhost:%d" % port)
server_info.fetch_server_info("http://localhost:%d" % port, [])
msg = str(cm.exception)
self.assertIn("Failed to connect to backend", msg)
if os.name != "nt":
Expand All @@ -97,7 +116,7 @@ def app(request):

origin = self._start_server(app)
with self.assertRaises(server_info.CommunicationError) as cm:
server_info.fetch_server_info(origin)
server_info.fetch_server_info(origin, [])
msg = str(cm.exception)
self.assertIn("Non-OK status from backend (502 Bad Gateway)", msg)
self.assertIn("very sad", msg)
Expand All @@ -110,7 +129,7 @@ def app(request):

origin = self._start_server(app)
with self.assertRaises(server_info.CommunicationError) as cm:
server_info.fetch_server_info(origin)
server_info.fetch_server_info(origin, [])
msg = str(cm.exception)
self.assertIn("Corrupt response from backend", msg)
self.assertIn("an unlikely proto", msg)
Expand All @@ -123,18 +142,18 @@ def app(request):
return wrappers.BaseResponse(result.SerializeToString())

origin = self._start_server(app)
result = server_info.fetch_server_info(origin)
result = server_info.fetch_server_info(origin, [])
expected_user_agent = "tensorboard/%s" % version.VERSION
self.assertEqual(result.compatibility.details, expected_user_agent)


class CreateServerInfoTest(tb_test.TestCase):
"""Tests for `create_server_info`."""

def test(self):
def test_response(self):
frontend = "http://localhost:8080"
backend = "localhost:10000"
result = server_info.create_server_info(frontend, backend)
result = server_info.create_server_info(frontend, backend, [])

expected_compatibility = server_info_pb2.Compatibility()
expected_compatibility.verdict = server_info_pb2.VERDICT_OK
Expand All @@ -152,6 +171,19 @@ def test(self):
expected_url = "http://localhost:8080/experiment/123/"
self.assertEqual(actual_url, expected_url)

self.assertEqual(result.plugin_control.allowed_plugins, [])

def test_response_with_plugins(self):
frontend = "http://localhost:8080"
backend = "localhost:10000"
result = server_info.create_server_info(
frontend, backend, ["plugin1", "plugin2"]
)

self.assertEqual(
result.plugin_control.allowed_plugins, ["plugin1", "plugin2"]
)


class ExperimentUrlTest(tb_test.TestCase):
"""Tests for `experiment_url`."""
Expand Down
18 changes: 16 additions & 2 deletions tensorboard/uploader/uploader_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,16 @@ def _define_flags(parser):
help="Experiment description. Markdown format. Max 600 characters.",
)

upload.add_argument(
bmd3k marked this conversation as resolved.
Show resolved Hide resolved
"--plugins",
type=str,
nargs="*",
default=[],
help="List of plugins for which data should be uploaded. If "
"unspecified then data will be uploaded for all plugins supported by "
"the server.",
)

update_metadata = subparsers.add_parser(
"update-metadata",
help="change the name, description, or other user "
Expand Down Expand Up @@ -733,9 +743,13 @@ def _get_intent(flags):

def _get_server_info(flags):
bmd3k marked this conversation as resolved.
Show resolved Hide resolved
origin = flags.origin or _DEFAULT_ORIGIN
plugins = flags.plugins if hasattr(flags, "plugins") else []
bmd3k marked this conversation as resolved.
Show resolved Hide resolved

if flags.api_endpoint and not flags.origin:
return server_info_lib.create_server_info(origin, flags.api_endpoint)
server_info = server_info_lib.fetch_server_info(origin)
return server_info_lib.create_server_info(
origin, flags.api_endpoint, plugins
)
server_info = server_info_lib.fetch_server_info(origin, plugins)
# Override with any API server explicitly specified on the command
# line, but only if the server accepted our initial handshake.
if flags.api_endpoint and server_info.api_server.endpoint:
Expand Down