Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Verbs fix: Removed Dependency on Duplicate Recv Flag #12705

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
254 changes: 154 additions & 100 deletions tensorflow/contrib/verbs/rdma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,10 @@ void RdmaAdapter::Process_CQ() {
RdmaBuffer* ab = rc->tx_ack_buffer_;
ab->SendNextItem();
// find buffer
RdmaBuffer* tb = rc->FindBuffer(rm.name_);
RdmaTensorBuffer* tb =
reinterpret_cast<RdmaTensorBuffer*>(rc->FindBuffer(rm.name_));
tb->SetBufferStatus(remote, idle);
worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
worker_env_->compute_pool->Schedule([tb]() { tb->ReSendNextItem(); });
} else if (rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) {
// remote host requests to create a tensor buffer;
// send ack to release remote tx message buffer
Expand Down Expand Up @@ -198,7 +199,8 @@ void RdmaAdapter::Process_CQ() {
RdmaBuffer* ab = rc->tx_ack_buffer_;
ab->SendNextItem();
// find buffer
RdmaBuffer* tb = rc->FindBuffer(rm.name_);
RdmaTensorBuffer* tb =
reinterpret_cast<RdmaTensorBuffer*>(rc->FindBuffer(rm.name_));
CHECK(rm.buffer_size_ == tb->size_)
<< "rm.buffer_size = " << rm.buffer_size_
<< "tb->size_ = " << tb->size_ << "rm.name_ = " << rm.name_;
Expand All @@ -208,7 +210,7 @@ void RdmaAdapter::Process_CQ() {
tb->SetRemoteMR(rmr, true);
tb->SetBufferStatus(local, idle);
tb->SetBufferStatus(remote, idle);
worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
worker_env_->compute_pool->Schedule([tb]() { tb->ReSendNextItem(); });
} else if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
// tensor RDMA write completed
worker_env_->compute_pool->Schedule([rm, rc]() {
Expand Down Expand Up @@ -624,6 +626,12 @@ RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name)
// Tensor-buffer constructor: delegates to the RdmaBuffer base, which
// records the owning channel and this buffer's name (the rendezvous key
// it serves).
RdmaTensorBuffer::RdmaTensorBuffer(RdmaChannel* channel, string name)
    : RdmaBuffer(channel, name) {}

// Destructor: release any saved re-send contexts that were never
// consumed by ReSendNextItem() (e.g. channel teardown while sends were
// still deferred).  Each ReItem's own destructor drops the references it
// holds on the device contexts.
RdmaTensorBuffer::~RdmaTensorBuffer() {
  // Range-for over the map instead of manual iterator bookkeeping.
  for (const auto& entry : retable) {
    delete entry.second;
  }
}

// Send the next ack from the buffer's job queue.
void RdmaAckBuffer::SendNextItem() {
uint32_t imm_data = LookupBufferIndex("rx_ack_buffer");
Expand Down Expand Up @@ -655,6 +663,99 @@ void RdmaMessageBuffer::SendNextItem() {
}
}

// Builds the Rendezvous::DoneCallback that serializes a locally produced
// tensor into this RDMA buffer once RecvLocalAsync delivers it.
//
// The callback:
//   * looks up the source device and validates its incarnation number,
//   * if the tensor lives on GPU memory, asynchronously copies it to the
//     host (raw memcpy into pinned memory for memcpy-able dtypes,
//     TensorProto serialization otherwise),
//   * if the tensor is already in CPU memory, serializes it directly,
//   * then calls PostCopyOperations(), forwarding send_args/recv_args so
//     the operation can later be replayed by ReSendNextItem() without a
//     duplicate recv.
//
// Args:
//   key_with_step_id: rendezvous key combined with the step id; used as
//     the lookup key for the requeue/retable re-send bookkeeping.
//   key: plain rendezvous key (callers CHECK it equals this buffer's name_).
//   step_id: step the tensor belongs to.
//   parsed: parsed form of `key`, used for the device lookups here.
Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback(
    const string& key_with_step_id, const string& key, int64 step_id,
    const Rendezvous::ParsedKey& parsed) {
  Rendezvous::DoneCallback cb = [this, key_with_step_id, key, step_id, parsed](
      const Status& status, const Rendezvous::Args& send_args,
      const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
    CHECK(status.ok()) << "RecvLocalAsync was not ok, key" << key_with_step_id
                       << " error message: " << status.error_message();
    // Bytes to transmit: fixed-size message header plus the tensor payload
    // (payload size is added below once it is known).
    size_t buffer_size = RdmaMessage::kMessageTotalBytes;
    size_t tensor_bytes = 0;
    // Figures out which device the tensor is hosted on.
    Device* src_dev = nullptr;
    Status s = channel_->adapter_->worker_env_->device_mgr->LookupDevice(
        parsed.src_device, &src_dev);
    CHECK(s.ok()) << "src device not found";
    // Does the device have the right incarnation number we expect?
    CHECK(src_dev->attributes().incarnation() == parsed.src_incarnation)
        << "RecvTensor expects a different device incarnation: "
        << parsed.src_incarnation << " vs. "
        << src_dev->attributes().incarnation()
        << ". Your worker job was probably restarted. Check your "
        << "worker job for the reason why it was restarted.";
    Device* dst_dev = nullptr;
    // destination is on CPU.
    s = channel_->adapter_->worker_env_->device_mgr->LookupDevice("CPU:0",
                                                                  &dst_dev);
    CHECK(s.ok()) << "dst device not found";
    AllocatorAttributes dst_alloc_attr;
    dst_alloc_attr.set_on_host(true);

    bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
    // string tensor needs to be serialized
    Tensor copy;
    TensorProto proto;
    if (src_dev->tensorflow_gpu_device_info() &&
        (!send_args.alloc_attrs.on_host())) {
      CHECK(send_args.device_context)
          << "send dev name: " << src_dev->name()
          << " gpu_info: " << src_dev->tensorflow_gpu_device_info();

      if (can_memcpy) {
        // Memcpy-able dtype: stage through CUDA host (pinned) memory so
        // the bytes are readable on the host after the async D2H copy.
        AllocatorAttributes host_alloc_attrs;
        host_alloc_attrs.set_gpu_compatible(true);
        host_alloc_attrs.set_on_host(true);
        Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
        copy = Tensor(alloc, in.dtype(), in.shape());
        tensor_bytes = in.TotalBytes();
        buffer_size += tensor_bytes;
        // The completion lambda captures send_args/recv_args so that
        // PostCopyOperations can save them for a possible re-send.
        GPUUtil::CopyGPUTensorToCPU(
            src_dev, send_args.device_context, &in, &copy,
            [this, copy, tensor_bytes, buffer_size, key, in, step_id,
             key_with_step_id, is_dead, send_args, recv_args](const Status& s) {
              CHECK(s.ok()) << "copy tensor from gpu sync";
              StringPiece copy_buf;
              copy_buf = copy.tensor_data();
              PostCopyOperations(true, buffer_size, tensor_bytes, key, in,
                                 step_id, is_dead, key_with_step_id, &copy,
                                 NULL, &copy_buf, send_args, recv_args);
            });
      } else {
        // "val" is on a GPU. No longer uses GPUUtil to fill the proto, use
        // the async variant instead.
        GPUUtil::SetProtoFromGPU(
            in, src_dev, send_args.device_context, &proto, is_dead,
            [this, proto, buffer_size, key, in, step_id, key_with_step_id,
             is_dead, send_args, recv_args](const Status& s) mutable {
              CHECK(s.ok()) << "copy proto from gpu sync";
              // Payload size only known after serialization completes.
              auto tensor_bytes = proto.ByteSize();
              buffer_size += tensor_bytes;
              PostCopyOperations(false, buffer_size, tensor_bytes, key, in,
                                 step_id, is_dead, key_with_step_id, NULL,
                                 &proto, NULL, send_args, recv_args);
            });
      }
    } else {
      // tensor is in CPU memory.
      StringPiece copy_buf;
      if (can_memcpy) {
        copy_buf = in.tensor_data();
        tensor_bytes = in.TotalBytes();
      } else {
        in.AsProtoTensorContent(&proto);
        tensor_bytes = proto.ByteSize();
      }
      buffer_size += tensor_bytes;
      PostCopyOperations(can_memcpy, buffer_size, tensor_bytes, key, in,
                         step_id, is_dead, key_with_step_id, &copy, &proto,
                         &copy_buf, send_args, recv_args);
    }
  };
  return cb;
}

// Send the next tensor from the buffer's job queue.
void RdmaTensorBuffer::SendNextItem() {
// get the key
Expand All @@ -666,6 +767,7 @@ void RdmaTensorBuffer::SendNextItem() {
queue_.pop();
}
}

// send the tensor if a key is acquired.
if (key_with_step_id != "") {
VLOG(2) << "try to send tensor: " << key_with_step_id;
Expand All @@ -675,107 +777,54 @@ void RdmaTensorBuffer::SendNextItem() {
CHECK(key.compare(name_) == 0);
Rendezvous::ParsedKey parsed;
Rendezvous::ParseKey(key, &parsed);
Rendezvous::DoneCallback cb = [this, key_with_step_id, key, step_id,
parsed](const Status& status,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& in, bool is_dead) {
CHECK(status.ok()) << "RecvLocalAsync was not ok, key" << key_with_step_id
<< " error message: " << status.error_message();
size_t buffer_size = RdmaMessage::kMessageTotalBytes;
size_t tensor_bytes = 0;
// Figures out which device the tensor is hosted on.
Device* src_dev = nullptr;
Status s = channel_->adapter_->worker_env_->device_mgr->LookupDevice(
parsed.src_device, &src_dev);
CHECK(s.ok()) << "src device not found";
// Does the device have the right incarnation number we expect?
CHECK(src_dev->attributes().incarnation() == parsed.src_incarnation)
<< "RecvTensor expects a different device incarnation: "
<< parsed.src_incarnation << " vs. "
<< src_dev->attributes().incarnation()
<< ". Your worker job was probably restarted. Check your "
<< "worker job for the reason why it was restarted.";
Device* dst_dev = nullptr;
// destination is on CPU.
s = channel_->adapter_->worker_env_->device_mgr->LookupDevice("CPU:0",
&dst_dev);
CHECK(s.ok()) << "dst device not found";
AllocatorAttributes dst_alloc_attr;
dst_alloc_attr.set_on_host(true);

bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
// string tensor needs to be serialized
Tensor copy;
TensorProto proto;
if (src_dev->tensorflow_gpu_device_info() &&
(!send_args.alloc_attrs.on_host())) {
CHECK(send_args.device_context)
<< "send dev name: " << src_dev->name()
<< " gpu_info: " << src_dev->tensorflow_gpu_device_info();

if (can_memcpy) {
AllocatorAttributes host_alloc_attrs;
host_alloc_attrs.set_gpu_compatible(true);
host_alloc_attrs.set_on_host(true);
Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
copy = Tensor(alloc, in.dtype(), in.shape());
tensor_bytes = in.TotalBytes();
buffer_size += tensor_bytes;
GPUUtil::CopyGPUTensorToCPU(
src_dev, send_args.device_context, &in, &copy,
[this, copy, tensor_bytes, buffer_size, key, in, step_id,
key_with_step_id, is_dead](const Status& s) {
CHECK(s.ok()) << "copy tensor from gpu sync";
StringPiece copy_buf;
copy_buf = copy.tensor_data();
PostCopyOperations(true, buffer_size, tensor_bytes, key, in,
step_id, is_dead, key_with_step_id, &copy,
NULL, &copy_buf);
});
} else {
// "val" is on a GPU. No longer uses GPUUtil to fill the proto, use
// aync instead
GPUUtil::SetProtoFromGPU(
in, src_dev, send_args.device_context, &proto, is_dead,
[this, proto, buffer_size, key, in, step_id, key_with_step_id,
is_dead](const Status& s) mutable {
CHECK(s.ok()) << "copy proto from gpu sync";
auto tensor_bytes = proto.ByteSize();
buffer_size += tensor_bytes;
PostCopyOperations(false, buffer_size, tensor_bytes, key, in,
step_id, is_dead, key_with_step_id, NULL,
&proto, NULL);
});
}
} else {
// tensor is in CPU memory.
StringPiece copy_buf;
if (can_memcpy) {
copy_buf = in.tensor_data();
tensor_bytes = in.TotalBytes();
} else {
in.AsProtoTensorContent(&proto);
tensor_bytes = proto.ByteSize();
}
buffer_size += tensor_bytes;
PostCopyOperations(can_memcpy, buffer_size, tensor_bytes, key, in,
step_id, is_dead, key_with_step_id, &copy, &proto,
&copy_buf);
}
// maybe some margin for string tensor?
};

Rendezvous::DoneCallback cb =
getRecvTensorCallback(key_with_step_id, key, step_id, parsed);
channel_->adapter_->worker_env_->rendezvous_mgr->RecvLocalAsync(step_id,
parsed, cb);
}
}

// Replay the next deferred send from the re-send queue (`requeue`).
// Unlike SendNextItem(), which issues a fresh RecvLocalAsync, this path
// reuses the tensor and rendezvous args that were saved in `retable` when
// the first send attempt had to be deferred — this is what removes the
// dependency on a duplicate recv.
void RdmaTensorBuffer::ReSendNextItem() {
  // get the key
  string key_with_step_id = "";
  {
    mutex_lock lock{mu_};
    if (!requeue.empty()) {
      key_with_step_id = requeue.front();
      requeue.pop();
    }
  }

  // send the tensor if a key is acquired.
  if (key_with_step_id != "") {
    VLOG(2) << "try to send tensor: " << key_with_step_id;
    string key;
    int64 step_id;
    VerbsUtil::GetKeyAndStepId(key_with_step_id, key, step_id);
    // This buffer only ever serves its own key.
    CHECK(key.compare(name_) == 0);
    Rendezvous::ParsedKey parsed;
    Rendezvous::ParseKey(key, &parsed);
    Rendezvous::DoneCallback cb =
        getRecvTensorCallback(key_with_step_id, key, step_id, parsed);
    ReItem* item;
    {
      // Remove the saved context under the lock; the callback below is
      // invoked outside the critical section.
      mutex_lock lock{mu_};
      Itable it = retable.find(key_with_step_id);
      CHECK(it != retable.end()) << "Could not find dup-recv context";
      item = it->second;
      retable.erase(it);
    }
    // Drive the callback synchronously with the saved tensor/args instead
    // of performing a second recv for the same key.
    cb(Status::OK(), item->send_args, item->recv_args, item->in, item->is_dead);
    // ReItem's destructor drops the device-context references it holds.
    delete (item);
  }
}

void RdmaTensorBuffer::PostCopyOperations(
bool can_memcpy, size_t buffer_size, size_t tensor_bytes, const string& key,
const Tensor& in, int64 step_id, bool is_dead,
const string& key_with_step_id, const Tensor* copy,
const TensorProto* proto, const StringPiece* copy_buf) {
const TensorProto* proto, const StringPiece* copy_buf,
const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args) {
// prepare message
RdmaMessage rm;
rm.name_size_ = key.size();
Expand All @@ -793,9 +842,12 @@ void RdmaTensorBuffer::PostCopyOperations(
VLOG(2) << "Extend RDMA buffer from " << size_ << " to " << buffer_size;
}
CreateCPUBuffer(buffer_size, false);
// Need to be received again, put into the re-recv queue and the table
requeue.push(key_with_step_id);
ReItem* item = new ReItem(send_args, recv_args, in, is_dead);
retable.insert(std::pair<string, ReItem*>(key_with_step_id, item));
mu_.unlock();
// put back the key since it is not sent;
EnqueueItem(key_with_step_id);
// no longer used: put back the key since it is not sent;
// ask the remote to create the same buffer
rm.type_ = RDMA_MESSAGE_BUFFER_REQUEST;
rm.remote_addr_ = reinterpret_cast<uint64_t>(buffer_);
Expand Down Expand Up @@ -841,9 +893,11 @@ void RdmaTensorBuffer::PostCopyOperations(
}
Write(imm_data, buffer_size);
} else {
// Need to be received again, put into the re-recv queue and the table
requeue.push(key_with_step_id);
ReItem* item = new ReItem(send_args, recv_args, in, is_dead);
retable.insert(std::pair<string, ReItem*>(key_with_step_id, item));
mu_.unlock();
// put back the key since it is not sent;
EnqueueItem(key_with_step_id);
}
}

Expand Down
50 changes: 47 additions & 3 deletions tensorflow/contrib/verbs/rdma.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License.
#include <vector>

#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
Expand Down Expand Up @@ -224,14 +225,57 @@ class RdmaMessageBuffer : public RdmaBuffer {
class RdmaTensorBuffer : public RdmaBuffer {
public:
explicit RdmaTensorBuffer(RdmaChannel* channel, string name);
virtual ~RdmaTensorBuffer() override {}
virtual ~RdmaTensorBuffer() override;
void SendNextItem() override;
void PostCopyOperations(bool can_memcpy, size_t buffer_size,
size_t tensor_bytes, const string& key,
const Tensor& in, int64 step_id, bool is_dead,
const string& key_with_step_id, const Tensor* copy,
const TensorProto* proto,
const StringPiece* copy_buf);
const TensorProto* proto, const StringPiece* copy_buf,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args);

void ReSendNextItem();

private:
Rendezvous::DoneCallback getRecvTensorCallback(
const string& key_with_step_id, const string& key, int64 step_id,
const Rendezvous::ParsedKey& parsed);

struct ReItem {
Rendezvous::Args send_args;
Rendezvous::Args recv_args;
Tensor in;
bool is_dead;

ReItem(const Rendezvous::Args& send_args_,
const Rendezvous::Args& recv_args_, const Tensor& in_, bool is_dead_)
: send_args(send_args_),
recv_args(recv_args_),
in(in_),
is_dead(is_dead_) {
if (send_args.device_context) {
send_args.device_context->Ref();
}
if (recv_args.device_context) {
recv_args.device_context->Ref();
}
}

~ReItem() {
if (send_args.device_context) {
send_args.device_context->Unref();
}
if (recv_args.device_context) {
recv_args.device_context->Unref();
}
}
};
typedef std::map<string, ReItem*> Table;
typedef Table::iterator Itable;

std::queue<string> requeue GUARDED_BY(mu_);
Table retable GUARDED_BY(mu_);
};

struct RdmaMessage {
Expand Down