Cherrypick: add an enter_master_device flag in tf.config.experimental_connect_to_cluster API. #32061

Merged · 3 commits · Aug 29, 2019
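This cherrypick adds a `make_master_device_default` flag (the "enter_master_device" flag of the title) to `tf.config.experimental_connect_to_cluster`: when a `ClusterResolver` is passed and the flag is True (the default), the call enters the master task's device scope, so subsequent ops run on the master without an explicit `tf.device`. A minimal usage sketch, assuming TF 2.0 eager execution; the in-process server, port 8470, and the `SimpleClusterResolver` setup are illustrative assumptions, not part of this PR:

```python
import tensorflow as tf

# Stand up a single in-process "worker" server so the sketch is self-contained.
cluster_spec = tf.train.ClusterSpec({"worker": ["localhost:8470"]})
server = tf.distribute.Server(cluster_spec, job_name="worker", task_index=0)

resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
    cluster_spec, master="grpc://localhost:8470", rpc_layer="grpc")

# With make_master_device_default=True (the new default), this call enters
# the master's device scope: /job:worker/replica:0/task:0.
tf.config.experimental_connect_to_cluster(resolver,
                                          make_master_device_default=True)

# No explicit tf.device() needed; these ops are placed on the master task.
y = tf.matmul(tf.ones([2, 2]), tf.ones([2, 2]))
```

Passing a plain `ClusterSpec` instead of a resolver leaves the flag inert, as the docstring in the diff below notes.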
56 changes: 47 additions & 9 deletions tensorflow/python/eager/remote.py
@@ -19,11 +19,13 @@
from __future__ import print_function

import os
from absl import logging

from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.platform import remote_utils
from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
@@ -36,30 +38,25 @@
@tf_export("config.experimental_connect_to_host")
def connect_to_remote_host(remote_host=None, job_name="worker"):
"""Connects to a single machine to enable remote execution on it.

Will make devices on the remote host available to use. Note that calling this
more than once will work, but will invalidate any tensor handles on the old
remote devices.

Using the default job_name of worker, you can schedule ops to run remotely as
follows:
```python
# Enable eager execution, and connect to the remote host.
tf.compat.v1.enable_eager_execution()
tf.contrib.eager.connect_to_remote_host("exampleaddr.com:9876")

with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
# The following tensors should be resident on the remote device, and the op
# will also execute remotely.
x1 = array_ops.ones([2, 2])
x2 = array_ops.ones([2, 2])
y = math_ops.matmul(x1, x2)
```

Args:
remote_host: a single or a list the remote server addr in host-port format.
job_name: The job name under which the new server will be accessible.

Raises:
ValueError: if remote_host is None.
"""
@@ -77,23 +74,26 @@ def connect_to_remote_host(remote_host=None, job_name="worker"):
def connect_to_cluster(cluster_spec_or_resolver,
                       job_name="localhost",
                       task_index=0,
                       protocol=None,
                       make_master_device_default=True):
  """Connects to the given cluster.

  Will make devices on the cluster available to use. Note that calling this
  more than once will work, but will invalidate any tensor handles on the old
  remote devices.

  If the given local job name is not present in the cluster specification, it
  will be automatically added, using an unused port on the localhost.

  Args:
    cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
      the cluster.
    job_name: The name of the local job.
    task_index: The local task index.
    protocol: The communication protocol, such as `"grpc"`. If unspecified, will
      use the default from `python/platform/remote_utils.py`.
    make_master_device_default: If True and a cluster resolver is passed,
      automatically enter the master task's device scope, so that the master
      becomes the default device for running ops. Has no effect if a cluster
      spec is passed. Raises a ValueError if the caller is already inside a
      device scope.
  """
  protocol = protocol or remote_utils.get_default_communication_protocol()
  if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
@@ -124,6 +124,44 @@ def connect_to_cluster(cluster_spec_or_resolver,
  os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
  context.set_server_def(server_def)

  if make_master_device_default and isinstance(
      cluster_spec_or_resolver,
      cluster_resolver.ClusterResolver) and cluster_spec_or_resolver.master():
    master = cluster_spec_or_resolver.master()
    # Locate the (job, task) whose address matches the resolver's master;
    # substring matching tolerates an rpc-layer prefix such as "grpc://".
    master_job_name = None
    master_task_id = None
    for job_name in cluster_spec.jobs:
      for task_id in cluster_spec.task_indices(job_name):
        task_address = cluster_spec.task_address(job_name, task_id)
        if master in task_address or task_address in master:
          master_job_name = job_name
          master_task_id = task_id
          break

    if not master_job_name:
      raise ValueError(
          "`make_master_device_default` is set to True but cannot find "
          "master %s in the cluster" % master)

    master_device = "/job:{}/replica:0/task:{}".format(master_job_name,
                                                       master_task_id)
    if not _device_stack_is_empty():
      raise ValueError("`connect_to_cluster` should not be called inside "
                       "an existing device scope")
    logging.info("Entering into master device scope: %s", master_device)
    # TODO(b/138389076): Think of the entering device scope behavior in the
    # failure recovery case when dealing with preemptions.
    ops.device(master_device).__enter__()


def _strip_prefix(s, prefix):
  return s[len(prefix):] if s.startswith(prefix) else s


def _device_stack_is_empty():
  if context.executing_eagerly():
    return not bool(context.context().device_name)
  # pylint: disable=protected-access
  device_stack = ops.get_default_graph()._device_functions_outer_to_inner
  # pylint: enable=protected-access
  return not bool(device_stack)
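The `_device_stack_is_empty` check above guards against calling `connect_to_cluster` from inside an existing device scope when the new flag is in effect. A minimal sketch of that failure mode, reusing the same illustrative single-worker setup as before (port and resolver are assumptions, not from this PR):

```python
import tensorflow as tf

cluster_spec = tf.train.ClusterSpec({"worker": ["localhost:8470"]})
server = tf.distribute.Server(cluster_spec, job_name="worker", task_index=0)
resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
    cluster_spec, master="grpc://localhost:8470", rpc_layer="grpc")

with tf.device("/job:worker/replica:0/task:0"):
  # Raises ValueError: `connect_to_cluster` should not be called inside
  # an existing device scope.
  tf.config.experimental_connect_to_cluster(resolver)
```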
2 changes: 1 addition & 1 deletion tensorflow/tools/api/golden/v1/tensorflow.config.pbtxt
@@ -14,7 +14,7 @@ tf_module {
  }
  member_method {
    name: "experimental_connect_to_cluster"
    argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\'], "
  }
  member_method {
    name: "experimental_connect_to_host"
2 changes: 1 addition & 1 deletion tensorflow/tools/api/golden/v2/tensorflow.config.pbtxt
@@ -14,7 +14,7 @@ tf_module {
  }
  member_method {
    name: "experimental_connect_to_cluster"
    argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\'], "
  }
  member_method {
    name: "experimental_connect_to_host"
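The golden-file updates in both v1 and v2 record the new keyword with its default of `True`. Opting out with `make_master_device_default=False` preserves the old behavior, in which placement stays explicit; a sketch under the same illustrative single-worker setup as above:

```python
import tensorflow as tf

cluster_spec = tf.train.ClusterSpec({"worker": ["localhost:8470"]})
server = tf.distribute.Server(cluster_spec, job_name="worker", task_index=0)
resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
    cluster_spec, master="grpc://localhost:8470", rpc_layer="grpc")

# Opt out of the new default: no master device scope is entered.
tf.config.experimental_connect_to_cluster(resolver,
                                          make_master_device_default=False)

# Placement must now be explicit, as before this change.
with tf.device("/job:worker/replica:0/task:0"):
  y = tf.matmul(tf.ones([2, 2]), tf.ones([2, 2]))
```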