diff --git a/gloo/common/store.h b/gloo/common/store.h index 010d81587..cff42572c 100644 --- a/gloo/common/store.h +++ b/gloo/common/store.h @@ -25,6 +25,21 @@ class IStore { virtual void wait( const std::vector& keys, const std::chrono::milliseconds& timeout) = 0; + + // Extended 2.0 API support + virtual bool has_v2_support() = 0; + + virtual std::vector> multi_get( + const std::vector& keys) = 0; + + virtual void multi_set( + const std::vector& keys, + const std::vector>& values) = 0; + + virtual void append( + const std::string& key, + const std::vector& value) = 0; + virtual int64_t add(const std::string& key, int64_t value) = 0; }; } // namespace gloo diff --git a/gloo/common/utils.cc b/gloo/common/utils.cc index 7b3172d0b..d543d2a0f 100644 --- a/gloo/common/utils.cc +++ b/gloo/common/utils.cc @@ -36,4 +36,10 @@ bool useRankAsSeqNumber() { (std::string(res) == "True" || std::string(res) == "1"); } +bool isStoreExtendedApiEnabled() { + const auto& res = std::getenv("GLOO_ENABLE_STORE_V2_API"); + return res != nullptr && + (std::string(res) == "True" || std::string(res) == "1"); +} + } // namespace gloo diff --git a/gloo/common/utils.h b/gloo/common/utils.h index 343c3fab9..185ebaf19 100644 --- a/gloo/common/utils.h +++ b/gloo/common/utils.h @@ -16,4 +16,6 @@ std::string getHostname(); bool useRankAsSeqNumber(); +bool isStoreExtendedApiEnabled(); + } // namespace gloo diff --git a/gloo/transport/tcp/context.cc b/gloo/transport/tcp/context.cc index 20cf97e29..5140fd2fe 100644 --- a/gloo/transport/tcp/context.cc +++ b/gloo/transport/tcp/context.cc @@ -8,10 +8,12 @@ #include "gloo/transport/tcp/context.h" +#include +#include #include #include +#include -#include "gloo/common/error.h" #include "gloo/common/logging.h" #include "gloo/common/utils.h" #include "gloo/transport/tcp/device.h" @@ -22,6 +24,8 @@ namespace gloo { namespace transport { namespace tcp { +constexpr int kDefaultBatchSize = 128; + Context::Context(std::shared_ptr device, int rank, int size) : ::gloo::transport::Context(rank, size), device_(std::move(device)) {} @@ -78,12 +82,36 @@ void Context::createAndConnectAllPairs(IStore& store) { // which does not have the rank info hosted at a higher `Pair` level). // So better safe than sorry for now we try to minimize the changeset needed. const auto& currentRankPair = getPair(rank); - auto deviceAddress = Address( + const auto& deviceAddress = Address( static_cast(currentRankPair.get())->address().getSockaddr()); Rank currentRankInfo( localHostName, deviceAddress.bytes(), std::move(pairIdentifiers)); store.set(std::to_string(rank), currentRankInfo.bytes()); + std::vector> remoteRankInfos; + int key = 0; + if (isStoreExtendedApiEnabled() && store.has_v2_support()) { + auto sizeRemaining = size; + while (sizeRemaining > 0) { + const auto batchKeys = std::min(kDefaultBatchSize, sizeRemaining); + std::vector keys(batchKeys); + std::generate_n( + keys.begin(), batchKeys, [&] { return std::to_string(key++); }); + const auto& batchRemoteInfos = store.multi_get(keys); + remoteRankInfos.insert( + remoteRankInfos.end(), + batchRemoteInfos.begin(), + batchRemoteInfos.end()); + sizeRemaining -= batchKeys; + } + } else { + std::generate_n(std::back_inserter(remoteRankInfos), size, [&] { + const auto& keyStr = std::to_string(key++); + store.wait({keyStr.c_str()}, getTimeout()); + return store.get(keyStr); + }); + } + // Connect every pair for (int i = 0; i < size; i++) { if (i == rank) { @@ -95,16 +123,9 @@ void Context::createAndConnectAllPairs(IStore& store) { continue; } - // Wait for address of other side of this pair to become available - std::ostringstream key; - key << i; - store.wait({key.str()}, getTimeout()); + Rank remoteRankInfo(remoteRankInfos[i]); - // Connect to other side of this pair - std::vector rankInfoBytes = store.get(key.str()); - Rank remoteRankInfo(rankInfoBytes); - const auto& remoteHostname = remoteRankInfo.hostname; - if (!localRankSet && remoteHostname == localHostName) { + if (!localRankSet && remoteRankInfo.hostname == localHostName) { ++localRank; } @@ -112,7 +133,8 @@ void Context::createAndConnectAllPairs(IStore& store) { auto remoteDeviceAddr = Address(remoteRankInfo.addressBytes).getSockaddr(); auto remoteAddr = Address( remoteDeviceAddr, - useRankAsSeqNum ? (ssize_t)rank : remoteRankInfo.pairIdentifiers[rank]); + useRankAsSeqNum ? (sequence_number_t)rank + : remoteRankInfo.pairIdentifiers[rank]); pair->connect(remoteAddr.bytes()); } diff --git a/gloo/transport/tcp/listener.cc b/gloo/transport/tcp/listener.cc index 038f376ea..5e514c0e6 100644 --- a/gloo/transport/tcp/listener.cc +++ b/gloo/transport/tcp/listener.cc @@ -15,6 +15,7 @@ #include #include #include +# namespace gloo { namespace transport {