Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions gloo/common/store.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,21 @@ class IStore {
// Waits on the given `keys`, bounded by `timeout`.
// NOTE(review): presumably blocks until every key is present in the store and
// reports an error once `timeout` elapses — confirm against implementations.
virtual void wait(
const std::vector<std::string>& keys,
const std::chrono::milliseconds& timeout) = 0;

// Extended 2.0 API support
//
// Reports whether this store implements the extended operations declared
// below. The TCP transport checks this (together with
// isStoreExtendedApiEnabled()) before it calls multi_get().
virtual bool has_v2_support() = 0;

// Fetches the values stored under several keys with a single call.
// The TCP transport indexes the result by key position, so implementations
// are expected to return one value per key, in the same order as `keys`.
// NOTE(review): whether this waits for missing keys is not visible here —
// confirm against implementations.
virtual std::vector<std::vector<char>> multi_get(
const std::vector<std::string>& keys) = 0;

// Stores several key/value pairs with a single call.
// NOTE(review): presumably keys[i] is paired with values[i] and both vectors
// must have the same length — confirm against implementations.
virtual void multi_set(
const std::vector<std::string>& keys,
const std::vector<std::vector<char>>& values) = 0;

// NOTE(review): presumably appends `value` to the bytes already stored under
// `key` — confirm against implementations.
virtual void append(
const std::string& key,
const std::vector<char>& value) = 0;
// NOTE(review): presumably adds `value` to an integer counter stored under
// `key` and returns the updated counter — confirm against implementations.
virtual int64_t add(const std::string& key, int64_t value) = 0;
};

} // namespace gloo
6 changes: 6 additions & 0 deletions gloo/common/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,10 @@ bool useRankAsSeqNumber() {
(std::string(res) == "True" || std::string(res) == "1");
}

// Reports whether the extended (v2) store API has been opted into through the
// GLOO_ENABLE_STORE_V2_API environment variable. Only the exact values "True"
// and "1" enable it; an unset variable or any other value yields false.
bool isStoreExtendedApiEnabled() {
  const char* rawValue = std::getenv("GLOO_ENABLE_STORE_V2_API");
  if (rawValue == nullptr) {
    return false;
  }

  const std::string flag(rawValue);
  return flag == "True" || flag == "1";
}

} // namespace gloo
2 changes: 2 additions & 0 deletions gloo/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@ std::string getHostname();

bool useRankAsSeqNumber();

// Returns true when the GLOO_ENABLE_STORE_V2_API environment variable is set
// to exactly "True" or "1"; returns false otherwise, including when unset.
bool isStoreExtendedApiEnabled();

} // namespace gloo
46 changes: 34 additions & 12 deletions gloo/transport/tcp/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@

#include "gloo/transport/tcp/context.h"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>

#include "gloo/common/error.h"
#include "gloo/common/logging.h"
#include "gloo/common/utils.h"
#include "gloo/transport/tcp/device.h"
Expand All @@ -22,6 +24,8 @@ namespace gloo {
namespace transport {
namespace tcp {

// Upper bound on the number of keys requested per store.multi_get() call when
// createAndConnectAllPairs fetches rank info through the extended store API.
constexpr int kDefaultBatchSize = 128;

// Constructs a TCP transport context. `rank` and `size` are forwarded to the
// base ::gloo::transport::Context, and the shared device handle is moved into
// device_.
Context::Context(std::shared_ptr<Device> device, int rank, int size)
: ::gloo::transport::Context(rank, size), device_(std::move(device)) {}

Expand Down Expand Up @@ -78,12 +82,36 @@ void Context::createAndConnectAllPairs(IStore& store) {
// which does not have the rank info hosted at a higher `Pair` level).
// So, to be on the safe side, we try to minimize the changeset needed for now.
const auto& currentRankPair = getPair(rank);
auto deviceAddress = Address(
const auto& deviceAddress = Address(
static_cast<const Pair*>(currentRankPair.get())->address().getSockaddr());
Rank currentRankInfo(
localHostName, deviceAddress.bytes(), std::move(pairIdentifiers));
store.set(std::to_string(rank), currentRankInfo.bytes());

std::vector<std::vector<char>> remoteRankInfos;
int key = 0;
if (isStoreExtendedApiEnabled() && store.has_v2_support()) {
auto sizeRemaining = size;
while (sizeRemaining > 0) {
const auto batchKeys = std::min(kDefaultBatchSize, sizeRemaining);
std::vector<std::string> keys(batchKeys);
std::generate_n(
keys.begin(), batchKeys, [&] { return std::to_string(key++); });
const auto& batchRemoteInfos = store.multi_get(keys);
remoteRankInfos.insert(
remoteRankInfos.end(),
batchRemoteInfos.begin(),
batchRemoteInfos.end());
sizeRemaining -= batchKeys;
}
} else {
std::generate_n(std::back_inserter(remoteRankInfos), size, [&] {
const auto& keyStr = std::to_string(key++);
store.wait({keyStr.c_str()}, getTimeout());
return store.get(keyStr);
});
}

// Connect every pair
for (int i = 0; i < size; i++) {
if (i == rank) {
Expand All @@ -95,24 +123,18 @@ void Context::createAndConnectAllPairs(IStore& store) {
continue;
}

// Wait for address of other side of this pair to become available
std::ostringstream key;
key << i;
store.wait({key.str()}, getTimeout());
Rank remoteRankInfo(remoteRankInfos[i]);

// Connect to other side of this pair
std::vector<char> rankInfoBytes = store.get(key.str());
Rank remoteRankInfo(rankInfoBytes);
const auto& remoteHostname = remoteRankInfo.hostname;
if (!localRankSet && remoteHostname == localHostName) {
if (!localRankSet && remoteRankInfo.hostname == localHostName) {
++localRank;
}

const auto& pair = getPair(i);
auto remoteDeviceAddr = Address(remoteRankInfo.addressBytes).getSockaddr();
auto remoteAddr = Address(
remoteDeviceAddr,
useRankAsSeqNum ? (ssize_t)rank : remoteRankInfo.pairIdentifiers[rank]);
useRankAsSeqNum ? (sequence_number_t)rank
: remoteRankInfo.pairIdentifiers[rank]);
pair->connect(remoteAddr.bytes());
}

Expand Down
1 change: 1 addition & 0 deletions gloo/transport/tcp/listener.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <gloo/common/logging.h>
#include <gloo/common/utils.h>
#include <gloo/transport/tcp/helpers.h>
#

namespace gloo {
namespace transport {
Expand Down
Loading