Skip to content

Commit

Permalink
Skips wait for remote tensor handle to be ready based on an environme…
Browse files Browse the repository at this point in the history
…nt variable

This wait is needed when both a tensor handle (an input to a function) and a component function can be remote, to ensure that the tensor handle is ready before the function executes. In cases where it can be guaranteed that the function will never be sent before its input, this wait can be skipped.

PiperOrigin-RevId: 612221424
  • Loading branch information
anshumang authored and tensorflower-gardener committed Mar 3, 2024
1 parent 6999f26 commit c5f5900
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 3 deletions.
1 change: 1 addition & 0 deletions tensorflow/core/common_runtime/eager/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ tf_cuda_library(
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@local_tsl//tsl/util:env_var",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
Expand Down
11 changes: 11 additions & 0 deletions tensorflow/core/common_runtime/eager/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ limitations under the License.
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tsl/platform/refcount.h"
#include "tsl/util/env_var.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
Expand Down Expand Up @@ -105,6 +106,16 @@ auto* eager_context_created =

const int64_t EagerContext::kGlobalRendezvousId = -1;

// Reports the boolean setting read from the environment variable
// `TF_REMOTE_HANDLE_SKIP_WAIT_FOR_READY`. The variable is read only once, when
// the function-local static below is initialized on the first call; every
// later call returns that stored value.
bool SkipRemoteHandleWaitReady() {
  static const bool kSkipWait = [] {
    bool env_value = false;
    // `false` is passed as the default value; the status returned by the read
    // is checked with TF_CHECK_OK.
    TF_CHECK_OK(tsl::ReadBoolFromEnvVar("TF_REMOTE_HANDLE_SKIP_WAIT_FOR_READY",
                                        false, &env_value));
    return env_value;
  }();
  return kSkipWait;
}

// Find the rendezvous instance corresponding to the step id, or create a
// new instance if not existing.
tsl::core::RefCountPtr<IntraProcessRendezvous>
Expand Down
5 changes: 5 additions & 0 deletions tensorflow/core/common_runtime/eager/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ namespace eager {
class RemoteMgr;
} // namespace eager

// Returns whether the environment variable
// `TF_REMOTE_HANDLE_SKIP_WAIT_FOR_READY` requests skipping the wait for a
// remote tensor handle to be ready. The variable is read on the first call and
// the result is cached in memory; later calls return the cached copy.
bool SkipRemoteHandleWaitReady();

class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
public:
static constexpr uint64 kInvalidContextId = 0;
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/common_runtime/eager/execute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1921,7 +1921,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
// guarantee that the input generation request is processed before the
// function execution request, so wait until the remote input is ready
// before sending it to the multi-device function device.
const bool wait_until_ready = op->is_function();
bool wait_until_ready =
SkipRemoteHandleWaitReady() ? false : op->is_function();
TF_RETURN_IF_ERROR(ctx.RemoteMgr()->SerializeRemoteTensorHandle(
input, wait_until_ready, input_handle, input_device,
*input_device_name, serialize_resource_dtype_and_shape));
Expand Down
6 changes: 4 additions & 2 deletions tensorflow/core/common_runtime/eager/execute_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/execute_node.h"

#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tsl/util/env_var.h"

namespace tensorflow {

Expand Down Expand Up @@ -112,9 +114,9 @@ Status ExecuteNodeArgs::Init(
// worker before a remote execution request which produces an input of the
// component function. So we wait until the remote input is ready before
// serializing it.
const bool wait_util_ready = is_function;
bool wait_until_ready = SkipRemoteHandleWaitReady() ? false : is_function;
return ctx->RemoteMgr()->SerializeRemoteTensorHandle(
h, wait_util_ready, handle, device, device->name());
h, wait_until_ready, handle, device, device->name());
};
}
#endif // !IS_MOBILE_PLATFORM
Expand Down

0 comments on commit c5f5900

Please sign in to comment.