Patch notes (free text ahead of the first "diff --git" header; patch tools
skip it):

  * tensorboard/data/BUILD: add the data-provider protos as a dependency of
    the library that already depends on ":ingester".
  * tensorboard/data/server_ingester.py: replace `_make_provider` with
    `_make_stub`, which returns only the gRPC stub; callers now build the
    `GrpcDataProvider` themselves. In `start()`, issue a `GetExperiment`
    request with a 5-second timeout against the freshly opened channel and
    raise `DataServerStartupError` (chained from the `grpc.RpcError`) if it
    fails, so `self._data_provider` is only assigned after a successful
    round trip.
  * tensorboard/version.py: bump VERSION from "2.5.0a0" to "2.5.0".

Review fix folded in: the probe-failure warning is emitted through the
module-level `logger` (the same object used by the neighbouring
`logger.info` calls) instead of the bare `logging` module.

diff --git a/tensorboard/data/BUILD b/tensorboard/data/BUILD
index 77cff74f58..bacffc3da4 100644
--- a/tensorboard/data/BUILD
+++ b/tensorboard/data/BUILD
@@ -95,6 +95,7 @@ py_library(
         ":ingester",
         "//tensorboard:expect_grpc_installed",
         "//tensorboard:expect_pkg_resources_installed",
+        "//tensorboard/data/proto:protos_all_py_pb2",
         "//tensorboard/util:tb_logging",
     ],
 )
diff --git a/tensorboard/data/server_ingester.py b/tensorboard/data/server_ingester.py
index de74057e71..fb0d2fc59c 100644
--- a/tensorboard/data/server_ingester.py
+++ b/tensorboard/data/server_ingester.py
@@ -26,6 +26,7 @@
 
 from tensorboard.data import grpc_provider
 from tensorboard.data import ingester
+from tensorboard.data.proto import data_provider_pb2
 from tensorboard.util import tb_logging
 
 
@@ -47,7 +48,8 @@ def __init__(self, address, *, channel_creds_type):
           channel_creds_type: `grpc_util.ChannelCredsType`, as passed to
             `--grpc_creds_type`.
         """
-        self._data_provider = _make_provider(address, channel_creds_type)
+        stub = _make_stub(address, channel_creds_type)
+        self._data_provider = grpc_provider.GrpcDataProvider(address, stub)
 
     @property
     def data_provider(self):
@@ -170,13 +172,23 @@ def start(self):
             )
 
         addr = "localhost:%d" % port
-        self._data_provider = _make_provider(addr, self._channel_creds_type)
+        stub = _make_stub(addr, self._channel_creds_type)
         logger.info(
-            "Established connection to data server at pid %d via %s",
+            "Opened channel to data server at pid %d via %s",
             popen.pid,
             addr,
         )
 
+        req = data_provider_pb2.GetExperimentRequest()
+        try:
+            stub.GetExperiment(req, timeout=5)  # should be near-instant
+        except grpc.RpcError as e:
+            msg = "Failed to communicate with data server at %s: %s" % (addr, e)
+            logger.warning("%s", msg)
+            raise DataServerStartupError(msg) from e
+        logger.info("Got valid response from data server")
+        self._data_provider = grpc_provider.GrpcDataProvider(addr, stub)
+
 
 def _maybe_read_file(path):
     """Read a file, or return `None` on ENOENT specifically."""
@@ -189,12 +201,11 @@ def _maybe_read_file(path):
         raise
 
 
-def _make_provider(addr, channel_creds_type):
+def _make_stub(addr, channel_creds_type):
     (creds, options) = channel_creds_type.channel_config()
     options.append(("grpc.max_receive_message_length", 1024 * 1024 * 256))
     channel = grpc.secure_channel(addr, creds, options=options)
-    stub = grpc_provider.make_stub(channel)
-    return grpc_provider.GrpcDataProvider(addr, stub)
+    return grpc_provider.make_stub(channel)
 
 
 class NoDataServerError(RuntimeError):
diff --git a/tensorboard/version.py b/tensorboard/version.py
index 94d570a915..7ecaad5974 100644
--- a/tensorboard/version.py
+++ b/tensorboard/version.py
@@ -15,4 +15,4 @@
 
 """Contains the version string."""
 
-VERSION = "2.5.0a0"
+VERSION = "2.5.0"