Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 16 additions & 26 deletions gloo/barrier_all_to_all.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,43 +15,33 @@ namespace gloo {
class BarrierAllToAll : public Barrier {
public:
explicit BarrierAllToAll(const std::shared_ptr<Context>& context)
: Barrier(context) {
: Barrier(context) {}

void run() {
// Create send/recv buffers for every peer
auto slot = this->context_->nextSlot();

auto buffer = this->context_->createUnboundBuffer(nullptr, 0);
auto timeout = this->context_->getTimeout();

for (auto i = 0; i < this->contextSize_; i++) {
// Skip self
if (i == this->contextRank_) {
continue;
}

auto& pair = this->getPair(i);
auto sdata = std::unique_ptr<int>(new int);
auto sbuf = pair->createSendBuffer(slot, sdata.get(), sizeof(int));
sendBuffersData_.push_back(std::move(sdata));
sendBuffers_.push_back(std::move(sbuf));
auto rdata = std::unique_ptr<int>(new int);
auto rbuf = pair->createRecvBuffer(slot, rdata.get(), sizeof(int));
recvBuffersData_.push_back(std::move(rdata));
recvBuffers_.push_back(std::move(rbuf));
buffer->send(i, slot);
buffer->recv(i, slot);
}
}

void run() {
// Notify peers
for (auto& buffer : sendBuffers_) {
buffer->send();
}
// Wait for notification from peers
for (auto& buffer : recvBuffers_) {
buffer->waitRecv();
for (auto i = 0; i < this->contextSize_; i++) {
// Skip self
if (i == this->contextRank_) {
continue;
}
buffer->waitSend(timeout);
buffer->waitRecv(timeout);
}
}

protected:
std::vector<std::unique_ptr<int>> sendBuffersData_;
std::vector<std::unique_ptr<transport::Buffer>> sendBuffers_;
std::vector<std::unique_ptr<int>> recvBuffersData_;
std::vector<std::unique_ptr<transport::Buffer>> recvBuffers_;
};

} // namespace gloo
65 changes: 21 additions & 44 deletions gloo/barrier_all_to_one.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,65 +17,42 @@ class BarrierAllToOne : public Barrier {
explicit BarrierAllToOne(
const std::shared_ptr<Context>& context,
int rootRank = 0)
: Barrier(context), rootRank_(rootRank) {
: Barrier(context), rootRank_(rootRank) {}

void run() {
auto slot = this->context_->nextSlot();
auto timeout = this->context_->getTimeout();

auto buffer = this->context_->createUnboundBuffer(nullptr, 0);

if (this->contextRank_ == rootRank_) {
// Create send/recv buffers for every peer
for (int i = 0; i < this->contextSize_; i++) {
// Skip self
if (i == this->contextRank_) {
continue;
}

auto& pair = this->getPair(i);
auto sdata = std::unique_ptr<int>(new int);
auto sbuf = pair->createSendBuffer(slot, sdata.get(), sizeof(int));
sendBuffersData_.push_back(std::move(sdata));
sendBuffers_.push_back(std::move(sbuf));
auto rdata = std::unique_ptr<int>(new int);
auto rbuf = pair->createRecvBuffer(slot, rdata.get(), sizeof(int));
recvBuffersData_.push_back(std::move(rdata));
recvBuffers_.push_back(std::move(rbuf));
buffer->recv(i, slot);
buffer->waitRecv(timeout);
}
} else {
// Create send/recv buffers to/from the root
auto& pair = this->getPair(rootRank_);
auto sdata = std::unique_ptr<int>(new int);
auto sbuf = pair->createSendBuffer(slot, sdata.get(), sizeof(int));
sendBuffersData_.push_back(std::move(sdata));
sendBuffers_.push_back(std::move(sbuf));
auto rdata = std::unique_ptr<int>(new int);
auto rbuf = pair->createRecvBuffer(slot, rdata.get(), sizeof(int));
recvBuffersData_.push_back(std::move(rdata));
recvBuffers_.push_back(std::move(rbuf));
}
}

void run() {
if (this->contextRank_ == rootRank_) {
// Wait for message from all peers
for (auto& b : recvBuffers_) {
b->waitRecv();
}
// Notify all peers
for (auto& b : sendBuffers_) {
b->send();
for (int i = 0; i < this->contextSize_; i++) {
// Skip self
if (i == this->contextRank_) {
continue;
}
buffer->send(i, slot);
buffer->waitSend(timeout);
}

} else {
// Send message to root
sendBuffers_[0]->send();
// Wait for acknowledgement from root
recvBuffers_[0]->waitRecv();
buffer->send(rootRank_, slot);
buffer->waitSend(timeout);
buffer->recv(rootRank_, slot);
buffer->waitRecv(timeout);
}
}

protected:
const int rootRank_;

std::vector<std::unique_ptr<int>> sendBuffersData_;
std::vector<std::unique_ptr<transport::Buffer>> sendBuffers_;
std::vector<std::unique_ptr<int>> recvBuffersData_;
std::vector<std::unique_ptr<transport::Buffer>> recvBuffers_;
};

} // namespace gloo
63 changes: 14 additions & 49 deletions gloo/broadcast_one_to_all.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,32 +36,6 @@ class BroadcastOneToAll : public Algorithm {
GLOO_ENFORCE_LT(rootRank_, contextSize_);
GLOO_ENFORCE_GE(rootPointerRank_, 0);
GLOO_ENFORCE_LT(rootPointerRank_, ptrs_.size());

// Setup pairs/buffers for sender/receivers
if (contextSize_ > 1) {
auto ptr = ptrs_[rootPointerRank_];
auto slot = context_->nextSlot();
if (contextRank_ == rootRank_) {
sender_.resize(contextSize_);
for (auto i = 0; i < contextSize_; i++) {
if (i == contextRank_) {
continue;
}

sender_[i] = make_unique<forSender>();
auto& pair = context_->getPair(i);
sender_[i]->clearToSendBuffer = pair->createRecvBuffer(
slot, &sender_[i]->dummy, sizeof(sender_[i]->dummy));
sender_[i]->sendBuffer = pair->createSendBuffer(slot, ptr, bytes_);
}
} else {
receiver_ = make_unique<forReceiver>();
auto& rootPair = context_->getPair(rootRank_);
receiver_->clearToSendBuffer = rootPair->createSendBuffer(
slot, &receiver_->dummy, sizeof(receiver_->dummy));
receiver_->recvBuffer = rootPair->createRecvBuffer(slot, ptr, bytes_);
}
}
}

void run() {
Expand All @@ -70,14 +44,21 @@ class BroadcastOneToAll : public Algorithm {
return;
}

auto clearToSendBuffer = context_->createUnboundBuffer(nullptr, 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

N00b question why this is even needed? Is it because we need to wait for the receiver to be ready?

auto buffer =
context_->createUnboundBuffer(ptrs_[rootPointerRank_], bytes_);
auto slot = context_->nextSlot();
auto timeout = context_->getTimeout();

if (contextRank_ == rootRank_) {
// Fire off send operations after receiving clear to send
for (auto i = 0; i < contextSize_; i++) {
if (i == contextRank_) {
continue;
}
sender_[i]->clearToSendBuffer->waitRecv();
sender_[i]->sendBuffer->send();
clearToSendBuffer->recv(i, slot);
clearToSendBuffer->waitRecv(timeout);
buffer->send(i, slot);
}

// Broadcast locally while sends are happening
Expand All @@ -88,11 +69,13 @@ class BroadcastOneToAll : public Algorithm {
if (i == contextRank_) {
continue;
}
sender_[i]->sendBuffer->waitSend();
buffer->waitSend(timeout);
}
} else {
receiver_->clearToSendBuffer->send();
receiver_->recvBuffer->waitRecv();
clearToSendBuffer->send(rootRank_, slot);
clearToSendBuffer->waitSend(timeout);
buffer->recv(rootRank_, slot);
buffer->waitRecv(timeout);

// Broadcast locally after receiving from root
broadcastLocally();
Expand All @@ -116,24 +99,6 @@ class BroadcastOneToAll : public Algorithm {
const size_t bytes_;
const int rootRank_;
const int rootPointerRank_;

// For the sender (root)
struct forSender {
int dummy;
std::unique_ptr<transport::Buffer> clearToSendBuffer;
std::unique_ptr<transport::Buffer> sendBuffer;
};

std::vector<std::unique_ptr<forSender>> sender_;

// For all receivers
struct forReceiver {
int dummy;
std::unique_ptr<transport::Buffer> clearToSendBuffer;
std::unique_ptr<transport::Buffer> recvBuffer;
};

std::unique_ptr<forReceiver> receiver_;
};

} // namespace gloo
Loading