
Remove waiting for the remote address of a TensorHandle from within the scope of acquiring a shared lock in `RemoteMgr`

Waiting for the remote address itself involves acquiring a lock and is independent of acquiring the `RemoteMgr`'s shared lock. Nesting the first acquisition inside the second may lead to deadlocks.
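For illustration, here is a minimal sketch of the ordering hazard in plain C++17 — not TensorFlow code; `registry_mu`, `handle_mu`, and the function names below are hypothetical stand-ins. A reader that waits for readiness while holding the shared lock deadlocks against a writer that must take the exclusive lock before it can signal readiness; waiting first and taking the shared lock afterwards cannot deadlock.

// Hypothetical model of the hazard this commit removes (plain C++17, not
// TensorFlow code). Thread A models SerializeRemoteTensorHandle; thread B
// models the producer that registers the handle and then marks it ready.
#include <condition_variable>
#include <mutex>
#include <shared_mutex>
#include <thread>

std::shared_mutex registry_mu;   // stands in for RemoteMgr's shared lock
std::mutex handle_mu;            // stands in for the handle's own lock
std::condition_variable handle_cv;
bool handle_ready = false;

// Pre-fix ordering: wait for readiness *while holding* the shared lock.
// The producer needs the exclusive lock before it can mark the handle
// ready, so neither thread can make progress -- a deadlock.
void SerializeBefore() {
  std::shared_lock<std::shared_mutex> reg(registry_mu);
  std::unique_lock<std::mutex> h(handle_mu);
  handle_cv.wait(h, [] { return handle_ready; });  // blocks holding registry_mu
}

// Post-fix ordering: finish the wait first, then take the shared lock only
// for the non-blocking lookup/validation.
void SerializeAfter() {
  {
    std::unique_lock<std::mutex> h(handle_mu);
    handle_cv.wait(h, [] { return handle_ready; });
  }
  std::shared_lock<std::shared_mutex> reg(registry_mu);
  // ... validate the handle's registry entry here ...
}

void Produce() {
  {
    // Would block forever against SerializeBefore(); succeeds against
    // SerializeAfter() because no reader is parked on registry_mu.
    std::unique_lock<std::shared_mutex> reg(registry_mu);
    // ... register the handle's op_id/output_num ...
  }
  {
    std::lock_guard<std::mutex> h(handle_mu);
    handle_ready = true;
  }
  handle_cv.notify_all();
}

int main() {
  std::thread a(SerializeAfter);  // swap in SerializeBefore to reproduce the hang
  std::thread b(Produce);
  a.join();
  b.join();
  return 0;
}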

PiperOrigin-RevId: 642385183
anshumang authored and tensorflower-gardener committed Jun 11, 2024
1 parent a415715 commit c2e7e9f
Showing 4 changed files with 80 additions and 15 deletions.
3 changes: 3 additions & 0 deletions tensorflow/core/distributed_runtime/eager/BUILD
@@ -120,8 +120,10 @@ cc_library(
"//tensorflow/core/nccl:collective_communicator",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:optional",
"@local_xla//xla/tsl/distributed_runtime/preemption:preemption_notifier",
] + tf_grpc_cc_dependencies(),
@@ -192,6 +194,7 @@ tf_cc_test(
"//tensorflow/core/common_runtime/eager:core",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/platform:error_payloads",
"@com_google_absl//absl/time",
],
)

28 changes: 18 additions & 10 deletions tensorflow/core/distributed_runtime/eager/remote_mgr.cc
@@ -125,18 +125,15 @@ Status RemoteMgr::GetMirroredResourceShape(
   return absl::OkStatus();
 }
 
-Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
-                                        const bool wait_until_ready,
-                                        int64_t* op_id, int32* output_num) {
-  TF_RETURN_IF_ERROR(handle->RemoteAddress(handle->device(), wait_until_ready,
-                                           op_id, output_num));
+Status RemoteMgr::ValidateRemoteTensorHandle(
+    const tensorflow::TensorHandle* handle, int64_t op_id, int32 output_num) {
   tensorflow::TensorHandle* h;
   TF_RETURN_IF_ERROR(
-      GetTensorHandleImpl(RemoteTensorHandleInternal(*op_id, *output_num), &h));
+      GetTensorHandleImpl(RemoteTensorHandleInternal(op_id, output_num), &h));
   if (handle != h) {
     return WithErrorSourcePayload(errors::Internal(
-        "Found two different tensor handles with the same op_id:", *op_id,
-        " and output_num:", *output_num));
+        "Found two different tensor handles with the same op_id:", op_id,
+        " and output_num:", output_num));
   }
   return absl::OkStatus();
 }
@@ -177,9 +174,20 @@ Status RemoteMgr::SerializeRemoteTensorHandle(
     LOG(ERROR)
         << "Failed to get remote address for tensor handle with given device "
         << device->name() << " error " << status.message();
-    tf_shared_lock l(remote_tensor_handle_mu_);
+    DCHECK(in->Type() == TensorHandle::REMOTE);
+    // `device` passed as an argument to this function may or may not be the
+    // same as the device associated with the handle, `in->device()`. It could
+    // be used to obtain the `op_id` and `output_num` for the given handle by
+    // using its mirrors. But if this handle is not present in the other
+    // device's mirrors, then we could have to use `in->device()` anyway. By
+    // adding this check, the only other reason `RemoteAddress` can fail with
+    // `in->device()` is if the handle is poisoned.
     TF_RETURN_IF_ERROR(
-        GetRemoteTensorHandle(in, wait_until_ready, &op_id, &output_num));
+        in->RemoteAddress(in->device(), wait_until_ready, &op_id, &output_num));
+    {
+      tf_shared_lock l(remote_tensor_handle_mu_);
+      TF_RETURN_IF_ERROR(ValidateRemoteTensorHandle(in, op_id, output_num));
+    }
   }
   out->Clear();
   out->set_op_id(op_id);
10 changes: 5 additions & 5 deletions tensorflow/core/distributed_runtime/eager/remote_mgr.h
@@ -85,11 +85,11 @@ class RemoteMgr {
   uint64 next_op_id_ TF_GUARDED_BY(next_id_mutex_) = 1;
 
  private:
-  // Returns the op_id and output_num if the given local TensorHandle exists in
-  // remote_tensor_handle_map_.
-  Status GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
-                               const bool wait_until_ready, int64_t* op_id,
-                               int32* output_num)
+  // Checks whether the given local `TensorHandle` has a different entry in
+  // `remote_tensor_handle_map_` for its `op_id` and `output_num`; it should
+  // not.
+  Status ValidateRemoteTensorHandle(const tensorflow::TensorHandle* handle,
+                                    int64_t op_id, int32 output_num)
       TF_SHARED_LOCKS_REQUIRED(remote_tensor_handle_mu_);
 
   Status GetTensorHandleImpl(const RemoteTensorHandleInternal& remote_handle,
54 changes: 54 additions & 0 deletions tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include <utility>
 #include <vector>
 
+#include "absl/time/clock.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -51,6 +52,9 @@ class RemoteMgrTest : public ::testing::Test {
     devices.push_back(
         DeviceFactory::NewDevice("CPU", {}, "/job:worker/replica:0/task:0"));
     remote_device_ = devices.back().get();
+    devices.push_back(
+        DeviceFactory::NewDevice("CPU", {}, "/job:worker/replica:0/task:1"));
+    another_remote_device_ = devices.back().get();
     auto device_mgr = std::make_unique<StaticDeviceMgr>(std::move(devices));
     auto rendezvous = tsl::core::RefCountPtr<tensorflow::Rendezvous>(
         new tensorflow::IntraProcessRendezvous(device_mgr.get()));
@@ -65,6 +69,7 @@
 
   Device* local_device_;
   Device* remote_device_;
+  Device* another_remote_device_;
   EagerContext* ctx_;
 };
 
@@ -91,6 +96,55 @@ TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) {
   handle->Unref();
 }
 
+TEST_F(RemoteMgrTest, SerializeByWaitingDeadlockAvoided) {
+  RemoteMgr remote_mgr(false, ctx_);
+
+  const uint64 op_id = 1;
+  const int output_num = 1;
+  // `SerializeRemoteTensorHandle` is later called on `another_remote_device_`
+  // instead of the device used to create the handle (that is,
+  // `remote_device_`) to trigger a second call to `RemoteAddress` inside
+  // `SerializeRemoteTensorHandle`.
+  TensorHandle* handle = TensorHandle::CreateLazyRemoteHandle(
+      op_id, output_num, DT_FLOAT, remote_device_, /*is_ready=*/false, ctx_);
+
+  std::unique_ptr<Thread> thread_worker_1;
+  thread_worker_1.reset(tsl::Env::Default()->StartThread(
+      tensorflow::ThreadOptions(), "thread_worker2",
+      [&remote_mgr, &handle, this]() {
+        // Grab the tensor handle's lock for reading and then block because
+        // the tensor handle is not ready, but do not grab the remote mgr's
+        // lock for reading (which the pre-fix code did).
+        RemoteTensorHandle remote_handle;
+        TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
+            handle, /*wait_until_ready=*/true, &remote_handle,
+            another_remote_device_, another_remote_device_->name()));
+      }));
+
+  std::unique_ptr<Thread> thread_worker_2;
+  thread_worker_2.reset(tsl::Env::Default()->StartThread(
+      tensorflow::ThreadOptions(), "thread_worker3",
+      [&remote_mgr, &handle, this]() {
+        // This 5 s sleep ensures that `AddOperationOutput` cannot get the
+        // remote mgr's lock before `SerializeRemoteTensorHandle` has had a
+        // chance to reach the blocked state.
+        absl::SleepFor(absl::Seconds(5));
+        // Grab the remote mgr's lock for writing (which would get stuck
+        // before this fix) and release it.
+        remote_mgr.AddOperationOutput(handle, op_id, output_num);
+        // Set the tensor handle to ready (which would not happen before
+        // this fix because `AddOperationOutput` would be stuck) so that the
+        // other thread is unblocked.
+        TF_ASSERT_OK(handle->SetRemoteShape(TensorShape({0}), remote_device_,
+                                            ctx_->GetContextViewId()));
+      }));
+
+  thread_worker_1.reset();
+  thread_worker_2.reset();
+
+  handle->Unref();
+}
+
 TEST_F(RemoteMgrTest, SerializeRemoteTensorHandle) {
   RemoteMgr remote_mgr(false, ctx_);
 