From 54cb0b5d4314ec282f3d813348a1069599f832b6 Mon Sep 17 00:00:00 2001 From: Teng Li Date: Mon, 14 May 2018 15:50:19 -0700 Subject: [PATCH 1/5] C10D: Added TCPStore to support C10D store interface --- torch/lib/c10d/CMakeLists.txt | 2 +- torch/lib/c10d/TcpStore.cpp | 274 +++++++++++++++++++++++ torch/lib/c10d/TcpStore.hpp | 75 +++++++ torch/lib/c10d/Utils.cpp | 282 ++++++++++++++++++++++++ torch/lib/c10d/Utils.hpp | 192 ++++++++++++++++ torch/lib/c10d/test/CMakeLists.txt | 2 + torch/lib/c10d/test/FileStoreTest.cpp | 75 ++----- torch/lib/c10d/test/StoreTestCommon.hpp | 54 +++++ torch/lib/c10d/test/TcpStoreTest.cpp | 63 ++++++ 9 files changed, 961 insertions(+), 58 deletions(-) create mode 100644 torch/lib/c10d/TcpStore.cpp create mode 100644 torch/lib/c10d/TcpStore.hpp create mode 100644 torch/lib/c10d/Utils.cpp create mode 100644 torch/lib/c10d/Utils.hpp create mode 100644 torch/lib/c10d/test/StoreTestCommon.hpp create mode 100644 torch/lib/c10d/test/TcpStoreTest.cpp diff --git a/torch/lib/c10d/CMakeLists.txt b/torch/lib/c10d/CMakeLists.txt index e8f762ff4b44c..b7d577b625679 100644 --- a/torch/lib/c10d/CMakeLists.txt +++ b/torch/lib/c10d/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.2 FATAL_ERROR) set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake ${CMAKE_MODULE_PATH}) -add_library(store Store.cpp FileStore.cpp) +add_library(store Utils.cpp Store.cpp FileStore.cpp TcpStore.cpp) target_compile_options(store PUBLIC "-std=c++11") enable_testing() diff --git a/torch/lib/c10d/TcpStore.cpp b/torch/lib/c10d/TcpStore.cpp new file mode 100644 index 0000000000000..91c0ffb1241a9 --- /dev/null +++ b/torch/lib/c10d/TcpStore.cpp @@ -0,0 +1,274 @@ +#include "TcpStore.hpp" + +#include +#include +#include + + +namespace c10d { + + +namespace { + +enum class QueryType : std::uint8_t { + SET, + GET, + ADD, + CHECK, + STOP_WAITING, + KEEP_WAITING +}; + +} // anonymous namespace + + +// TcpStoreDaemon class methods + +// Simply start the daemon thread +TcpStoreDaemon::TcpStoreDaemon(int storeListenSocket) : + storeListenSocket_(storeListenSocket) +{ + daemonThread_ = std::thread(&TcpStoreDaemon::run, this); +} + +TcpStoreDaemon::~TcpStoreDaemon() { + for (auto socket : sockets_) { + if (socket != -1) { + ::close(socket); + } + } + // Join the thread + join(); +} + +void TcpStoreDaemon::join() { + daemonThread_.join(); +} + +void TcpStoreDaemon::run() { + + std::vector fds; + fds.push_back({ .fd = storeListenSocket_, .events = POLLIN }); + + // receive the queries + bool finished = false; + while (!finished) { + for (size_t i = 0; i < sockets_.size(); i++) { + fds[i].revents = 0; + } + + SYSCHECK(::poll(fds.data(), fds.size(), -1)); + + if (fds[0].revents != 0) { + if (fds[0].revents ^ POLLIN) { + throw std::system_error(ECONNABORTED, std::system_category()); + } + int sockFd = std::get<0>(tcputil::accept(storeListenSocket_)); + sockets_.push_back(sockFd); + keysAwaited_.push_back(0); + fds.push_back({ .fd = sockFd, .events = POLLIN }); + } + for (size_t rank = 0; rank < sockets_.size(); rank++) { + if (fds[rank + 1].revents == 0) { + continue; + } + + if (fds[rank + 1].revents ^ POLLIN) { + throw std::system_error(ECONNABORTED, std::system_category()); + } + try { + query(rank); + } catch (std::exception& ex) { + /** + * There was an error when processing query. Probably an exception + * occurred in recv/send what would indicate that socket on the other + * side has been closed. If the closing was due to normal exit, then the + * store should exit too. Otherwise, if it was different exception, + * other processes will get an exception once they try to use the store. + */ + finished = true; + break; + } + } + } +} + +void TcpStoreDaemon::wakeUpWaitingRanks(const std::string& key) { + auto toWake = waiting_.find(key); + if (toWake != waiting_.end()) { + for (int proc : toWake->second) { + if (--keysAwaited_[proc] == 0) { + tcputil::sendValue(sockets_[proc], + QueryType::STOP_WAITING); + } + } + waiting_.erase(toWake); + } +} + +/** + * query communicates with the worker. The format + * of the query is as follows: + * type of query | size of arg1 | arg1 | size of arg2 | arg2 | ... + * or, in the case of wait + * type of query | number of args | size of arg1 | arg1 | ... + */ +void TcpStoreDaemon::query(RankType rank) { + + int socket = sockets_[rank]; + QueryType qt; + tcputil::recvBytes(socket, &qt, 1); + + if (qt == QueryType::SET) { + std::string key = tcputil::recvString(socket); + tcpStore_[key] = tcputil::recvVector(socket); + // On "set", wake up all of the processes that wait + // for keys already in the store + wakeUpWaitingRanks(key); + + } else if (qt == QueryType::ADD) { + std::string key = tcputil::recvString(socket); + int64_t addVal = tcputil::recvValue(socket); + + if (tcpStore_.find(key) != tcpStore_.end()) { + auto buf = reinterpret_cast(tcpStore_[key].data()); + auto len = tcpStore_[key].size(); + addVal += std::stoll(std::string(buf, len)); + } + auto addValStr = std::to_string(addVal); + tcpStore_[key] = std::vector(addValStr.begin(), addValStr.end()); + // Now send the new value + tcputil::sendValue(socket, addVal); + // On "add", wake up all of the processes that wait + // for keys already in the store + wakeUpWaitingRanks(key); + + } else if (qt == QueryType::GET) { + std::string key = tcputil::recvString(socket); + auto data = tcpStore_.at(key); + tcputil::sendVector(socket, data); + + } else if (qt == QueryType::CHECK) { + SizeType nargs; + tcputil::recvBytes(socket, &nargs, 1); + std::vector keys(nargs); + for (size_t i = 0; i < nargs; i++) { + keys[i] = tcputil::recvString(socket); + } + // Now we have received all the keys + if (checkAndUpdate(keys)) { + tcputil::sendValue(socket, QueryType::STOP_WAITING); + } else { + for (auto& key : keys) { + waiting_[key].push_back(rank); + } + keysAwaited_[rank] = keys.size(); + tcputil::sendValue(socket, QueryType::KEEP_WAITING); + } + } else { + throw std::runtime_error("expected a query type"); + } +} + +bool TcpStoreDaemon::checkAndUpdate(std::vector& keys) const { + bool ret = true; + for (auto it = keys.begin(); it != keys.end();) { + if (tcpStore_.count(*it) == 0) { + ret = false; + it++; + } else { + it = keys.erase(it); + } + } + return ret; +} + +// TcpStore class methods + +TcpStore::TcpStore(const std::string& masterAddr, + PortType masterPort, + bool isServer) + : isServer_(isServer) + , tcpStoreAddr_(masterAddr) + , tcpStorePort_(masterPort) + +{ + if (isServer_) { + // Openning up the listening socket + std::tie(masterListenSocket_, std::ignore) = tcputil::listen(masterPort); + // Now start the daemon + tcpStoreDaemon_ = std::unique_ptr( + new TcpStoreDaemon(masterListenSocket_) + ); + } + // Connect to the daemon + storeSocket_ = tcputil::connect(tcpStoreAddr_, tcpStorePort_); +} + +TcpStore::~TcpStore() { + ::close(storeSocket_); + if (isServer_) { + ::close(masterListenSocket_); + /** + * Store daemon should end because of closed connection. + * daemon destructor should join the thread + */ + tcpStoreDaemon_.reset(nullptr); + } +} + +void TcpStore::set(const std::string& key, const std::vector& data) { + tcputil::sendValue(storeSocket_, QueryType::SET); + tcputil::sendString(storeSocket_, key, true); + tcputil::sendVector(storeSocket_, data); +} + +std::vector TcpStore::get(const std::string& key) { + wait({key}); + tcputil::sendValue(storeSocket_, QueryType::GET); + tcputil::sendString(storeSocket_, key); + return tcputil::recvVector(storeSocket_); +} + +int64_t TcpStore::add(const std::string& key, int64_t value) { + tcputil::sendValue(storeSocket_, QueryType::ADD); + tcputil::sendString(storeSocket_, key, true); + tcputil::sendValue(storeSocket_, value); + return tcputil::recvValue(storeSocket_); +} + +bool TcpStore::check(const std::vector& keys) { + + tcputil::sendValue(storeSocket_, QueryType::CHECK); + SizeType nkeys = keys.size(); + tcputil::sendBytes(storeSocket_, &nkeys, 1, (nkeys > 0)); + for (size_t i = 0; i < nkeys; i++) { + tcputil::sendString(storeSocket_, keys[i], (i != (nkeys - 1))); + } + auto checkResponse = tcputil::recvValue(storeSocket_); + if (checkResponse == QueryType::STOP_WAITING) { + return true; + } else if (checkResponse == QueryType::KEEP_WAITING) { + return false; + } else { + throw std::runtime_error("stop_waiting or keep_waiting response expected"); + } +} + +void TcpStore::wait( + const std::vector& keys, + const std::chrono::milliseconds& timeout) { + + const auto start = std::chrono::steady_clock::now(); + while (!check(keys)) { + const auto elapsed = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start); + if (timeout != kNoTimeout && elapsed > timeout) { + throw std::runtime_error("Wait timeout"); + } + /* sleep override */ + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } +} + +} // namespace c10d diff --git a/torch/lib/c10d/TcpStore.hpp b/torch/lib/c10d/TcpStore.hpp new file mode 100644 index 0000000000000..9e9d6e34b9a4f --- /dev/null +++ b/torch/lib/c10d/TcpStore.hpp @@ -0,0 +1,75 @@ +#pragma once + +#include "Store.hpp" +#include "Utils.hpp" + +#include +#include +#include + + +namespace c10d { + +class TcpStoreDaemon { + + public: + + explicit TcpStoreDaemon(int storeListenSocket); + ~TcpStoreDaemon(); + + void join(); + + protected: + + void run(); + void query(RankType rank); + bool checkAndUpdate(std::vector& keys) const; + void wakeUpWaitingRanks(const std::string& key); + + std::thread daemonThread_; + std::unordered_map> tcpStore_; + std::unordered_map> waiting_; + std::vector keysAwaited_; + std::vector sockets_; + + int storeListenSocket_; +}; + +class TcpStore : public Store { + + public: + + explicit TcpStore(const std::string& masterAddr, + PortType masterPort, + bool isServer = false); + + virtual ~TcpStore(); + + void set( + const std::string& key, + const std::vector& value) override; + + std::vector get(const std::string& key) override; + + int64_t add(const std::string& key, int64_t value) override; + + bool check(const std::vector& keys) override; + + void wait( + const std::vector& keys, + const std::chrono::milliseconds& timeout = kDefaultTimeout) override; + + protected: + + bool isServer_; + int storeSocket_ = -1; + int masterListenSocket_ = -1; + + std::string tcpStoreAddr_; + PortType tcpStorePort_; + + // Only needs to be launched on master rank + std::unique_ptr tcpStoreDaemon_ = nullptr; +}; + +} // namespace c10d diff --git a/torch/lib/c10d/Utils.cpp b/torch/lib/c10d/Utils.cpp new file mode 100644 index 0000000000000..56c1bc0c3bd16 --- /dev/null +++ b/torch/lib/c10d/Utils.cpp @@ -0,0 +1,282 @@ +#include "Utils.hpp" + +#include +#include + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include + +namespace c10d { +namespace tcputil { + + +namespace { + +constexpr int LISTEN_QUEUE_SIZE = 64; + +void setSocketNoDelay(int socket) { + int flag = 1; + socklen_t optlen = sizeof(flag); + SYSCHECK(setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, optlen)); +} + +PortType getSocketPort(int fd) { + PortType listenPort; + struct ::sockaddr_storage addrStorage; + socklen_t addrLen = sizeof(addrStorage); + SYSCHECK(getsockname(fd, + reinterpret_cast(&addrStorage), &addrLen)); + + if (addrStorage.ss_family == AF_INET) { + struct ::sockaddr_in *addr = + reinterpret_cast(&addrStorage); + listenPort = ntohs(addr->sin_port); + + } else if (addrStorage.ss_family == AF_INET6) { // AF_INET6 + struct ::sockaddr_in6 *addr = + reinterpret_cast(&addrStorage); + listenPort = ntohs(addr->sin6_port); + + } else { + throw std::runtime_error("unsupported protocol"); + } + return listenPort; +} + +} + +std::string sockaddrToString(struct ::sockaddr *addr) { + char address[INET6_ADDRSTRLEN + 1]; + if (addr->sa_family == AF_INET) { + struct ::sockaddr_in *s = reinterpret_cast(addr); + SYSCHECK(::inet_ntop(AF_INET, &(s->sin_addr), address, INET_ADDRSTRLEN)) + address[INET_ADDRSTRLEN] = '\0'; + } else if (addr->sa_family == AF_INET6) { + struct ::sockaddr_in6 *s = reinterpret_cast(addr); + SYSCHECK(::inet_ntop(AF_INET6, &(s->sin6_addr), address, INET6_ADDRSTRLEN)) + address[INET6_ADDRSTRLEN] = '\0'; + } else { + throw std::runtime_error("unsupported protocol"); + } + return address; +} + +// listen, connect and accept +std::pair listen(PortType port) { + struct ::addrinfo hints, *res = NULL; + + std::memset(&hints, 0x00, sizeof(hints)); + hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; + hints.ai_family = AF_UNSPEC; // either IPv4 or IPv6 + hints.ai_socktype = SOCK_STREAM; // TCP + + // `getaddrinfo` will sort addresses according to RFC 3484 and can be tweeked + // by editing `/etc/gai.conf`. so there is no need to manual sorting + // or protocol preference. + int err = ::getaddrinfo(nullptr, std::to_string(port).data(), &hints, &res); + if (err != 0 || !res) { + throw std::invalid_argument("cannot find host to listen on: " + + std::string(gai_strerror(err))); + } + + std::shared_ptr addresses(res, [](struct ::addrinfo* p) { + ::freeaddrinfo(p); + }); + + struct ::addrinfo *nextAddr = addresses.get(); + int socket; + while (true) { + try { + SYSCHECK(socket = ::socket(nextAddr->ai_family, + nextAddr->ai_socktype, + nextAddr->ai_protocol)) + + int optval = 1; + SYSCHECK(::setsockopt(socket, + SOL_SOCKET, + SO_REUSEADDR, + &optval, + sizeof(int))) + + SYSCHECK(::bind(socket, nextAddr->ai_addr, nextAddr->ai_addrlen)) + SYSCHECK(::listen(socket, LISTEN_QUEUE_SIZE)) + break; + + } catch (const std::system_error& e) { + ::close(socket); + nextAddr = nextAddr->ai_next; + + /** + * we have tried all addresses but could not start + * listening on any of them + */ + if (!nextAddr) { + throw; + } + } + } + + // get listen port and address + return {socket, getSocketPort(socket)}; +} + +int connect(const std::string& address, + PortType port, + bool wait, + int timeout) { + + struct ::addrinfo hints, *res = NULL; + + std::memset(&hints, 0x00, sizeof(hints)); + hints.ai_flags = AI_NUMERICSERV; // specifies that port (service) is numeric + hints.ai_family = AF_UNSPEC; // either IPv4 or IPv6 + hints.ai_socktype = SOCK_STREAM; // TCP + + // `getaddrinfo` will sort addresses according to RFC 3484 and can be tweeked + // by editing `/etc/gai.conf`. so there is no need to manual sorting + // or protcol preference. + int err = ::getaddrinfo(address.data(), + std::to_string(port).data(), + &hints, + &res); + if (err != 0 || !res) { + throw std::invalid_argument("host not found: " + + std::string(gai_strerror(err))); + } + + std::shared_ptr addresses(res, [](struct ::addrinfo* p) { + ::freeaddrinfo(p); + }); + + struct ::addrinfo *nextAddr = addresses.get(); + int socket; + // we'll loop over the addresses only if at least of them gave us ECONNREFUSED + // Maybe the host was up, but the server wasn't running. + bool anyRefused = false; + while (true) { + try { + SYSCHECK(socket = ::socket(nextAddr->ai_family, + nextAddr->ai_socktype, + nextAddr->ai_protocol)) + + ResourceGuard socketGuard([socket]() { ::close(socket); }); + + // We need to connect in non-blocking mode, so we can use a timeout + SYSCHECK(::fcntl(socket, F_SETFL, O_NONBLOCK)); + + int ret = ::connect(socket, nextAddr->ai_addr, nextAddr->ai_addrlen); + + if (ret != 0 && errno != EINPROGRESS) { + throw std::system_error(errno, std::system_category()); + } + + struct ::pollfd pfd; + pfd.fd = socket; + pfd.events = POLLOUT; + + int numReady = ::poll(&pfd, 1, timeout); + if (numReady < 0) { + throw std::system_error(errno, std::system_category()); + } else if (numReady == 0) { + errno = 0; + throw std::runtime_error("connect() timed out"); + } + + socklen_t errLen = sizeof(errno); + errno = 0; + ::getsockopt(socket, SOL_SOCKET, SO_ERROR, &errno, &errLen); + /** + * `errno` is set when: + * 1. `getsockopt` has failed + * 2. there is awaiting error in the socket + * (the error is saved to the `errno` variable) + */ + if (errno != 0) { + throw std::system_error(errno, std::system_category()); + } + + // Disable non-blocking mode + int flags; + SYSCHECK(flags = ::fcntl(socket, F_GETFL)); + SYSCHECK(::fcntl(socket, F_SETFL, flags & (~O_NONBLOCK))); + socketGuard.release(); + break; + + } catch (std::exception& e) { + if (errno == ECONNREFUSED) { + anyRefused = true; + } + + // We need to move to the next address because this was not available + // to connect or to create a socket. + nextAddr = nextAddr->ai_next; + + // We have tried all addresses but could not connect to any of them. + if (!nextAddr) { + if (!wait || !anyRefused) { + throw; + } + std::this_thread::sleep_for(std::chrono::seconds(1)); + anyRefused = false; + nextAddr = addresses.get(); + } + } + } + + setSocketNoDelay(socket); + + return socket; +} + +std::tuple accept(int listenSocket, int timeout) { + + // poll on listen socket, it allows to make timeout + std::unique_ptr events(new struct ::pollfd[1]); + events[0] = {.fd = listenSocket, .events = POLLIN}; + + while (true) { + int res = ::poll(events.get(), 1, timeout); + if (res == 0) { + throw std::runtime_error("waiting for processes to " + "connect has timed out"); + } else if (res == -1) { + if (errno == EINTR) { + continue; + } + throw std::system_error(errno, std::system_category()); + } else { + if (!(events[0].revents & POLLIN)) + throw std::system_error(ECONNABORTED, std::system_category()); + break; + } + } + + int socket; + SYSCHECK(socket = ::accept(listenSocket, NULL, NULL)) + + // Get address of the connecting process + struct ::sockaddr_storage addr; + socklen_t addrLen = sizeof(addr); + SYSCHECK(::getpeername(socket, + reinterpret_cast(&addr), + &addrLen)) + + setSocketNoDelay(socket); + + return std::make_tuple( + socket, + sockaddrToString(reinterpret_cast(&addr))); +} + +} // namespace tcputil +} // namespace c10d diff --git a/torch/lib/c10d/Utils.hpp b/torch/lib/c10d/Utils.hpp new file mode 100644 index 0000000000000..77905b783a4a3 --- /dev/null +++ b/torch/lib/c10d/Utils.hpp @@ -0,0 +1,192 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10d { + + +using RankType = uint32_t; +using PortType = uint16_t; +using SizeType = uint64_t; + +#define SYSCHECK(expr) { \ + errno = 0; auto ___output = (expr); (void)___output; \ + if (errno != 0) throw std::system_error(errno, std::system_category()); \ +} + +inline PortType convertToPort(int64_t port) { + if ((port < 0) || (port >= std::numeric_limits::max())) + throw std::domain_error("invalid port (value out of range)"); + + return static_cast(port); +} + +inline RankType convertToRank(int64_t rank, int64_t min = 0) { + if ((rank < min) || (rank >= std::numeric_limits::max())) + throw std::domain_error("invalid rank (value out of range)"); + + return static_cast(rank); +} + +// TCP util namespace +namespace tcputil { + + +// Send and receive +template +void sendBytes(int socket, + const T* buffer, + size_t length, + bool moreData = false) { + + size_t bytesToSend = sizeof(T) * length; + if (bytesToSend == 0) { + return; + } + + auto bytes = reinterpret_cast(buffer); + uint8_t* currentBytes = const_cast(bytes); + + int flags = 0; + +#ifdef MSG_MORE + if (moreData) { // there is more data to send + flags |= MSG_MORE; + } +#endif + + while (bytesToSend > 0) { + ssize_t bytesSent; + SYSCHECK(bytesSent = ::send(socket, currentBytes, bytesToSend, flags)) + if (bytesSent == 0) { + throw std::system_error(ECONNRESET, std::system_category()); + } + + bytesToSend -= bytesSent; + currentBytes += bytesSent; + } +} + +template +void recvBytes(int socket, T* buffer, std::size_t length) { + + size_t bytesToReceive = sizeof(T) * length; + if (bytesToReceive == 0) { + return; + } + + auto bytes = reinterpret_cast(buffer); + uint8_t *currentBytes = bytes; + + while (bytesToReceive > 0) { + ssize_t bytesReceived; + SYSCHECK(bytesReceived = ::recv(socket, currentBytes, bytesToReceive, 0)) + if (bytesReceived == 0) { + throw std::system_error(ECONNRESET, std::system_category()); + } + + bytesToReceive -= bytesReceived; + currentBytes += bytesReceived; + } +} + +// send a vector's length and data +template +void sendVector(int socket, + const std::vector& vec, + bool moreData = false) { + + SizeType size = vec.size(); + sendBytes(socket, &size, 1, true); + sendBytes(socket, vec.data(), size, moreData); +} + +// receive a vector as sent in sendVector +template +std::vector recvVector(int socket) { + SizeType valueSize; + recvBytes(socket, &valueSize, 1); + std::vector value(valueSize); + recvBytes(socket, value.data(), value.size()); + return value; +} + +// this is only for convenience when sending rvalues +template +void sendValue(int socket, const T& value, bool moreData = false) { + sendBytes(socket, &value, 1, moreData); +} + +template +T recvValue(int socket) { + T value; + recvBytes(socket, &value, 1); + return value; +} + +// send a string's length and data +inline void sendString(int socket, + const std::string& str, + bool moreData = false) { + + SizeType size = str.size(); + sendBytes(socket, &size, 1, true); + sendBytes(socket, str.data(), size, moreData); +} + +// receive a string as sent in sendString +inline std::string recvString(int socket) { + SizeType valueSize; + recvBytes(socket, &valueSize, 1); + std::vector value(valueSize); + recvBytes(socket, value.data(), value.size()); + return std::string(value.data(), value.size()); +} + +// Other helpers +std::string sockaddrToString(struct sockaddr *addr); + +std::pair listen(PortType port); + +int connect(const std::string& address, + PortType port, + bool wait = true, + int timeout = -1); + +std::tuple accept(int listenSocket, int timeout = -1); + +// Helper resource guard class +class ResourceGuard { + + public: + ResourceGuard(std::function destructor) + : destructor_(std::move(destructor)) + , released_(false) {} + + ~ResourceGuard() { + if (!released_) { + destructor_(); + } + } + + void release() { + released_ = true; + } + + private: + std::function destructor_; + bool released_; +}; + +} // namespace tcputil +} // namespace c10d diff --git a/torch/lib/c10d/test/CMakeLists.txt b/torch/lib/c10d/test/CMakeLists.txt index 65c56f51fbaac..400b606c48944 100644 --- a/torch/lib/c10d/test/CMakeLists.txt +++ b/torch/lib/c10d/test/CMakeLists.txt @@ -1,5 +1,6 @@ set(test_srcs FileStoreTest.cpp + TcpStoreTest.cpp ) set(test_libraries @@ -13,4 +14,5 @@ foreach(test_src ${test_srcs}) target_include_directories(${test_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..) target_link_libraries(${test_name} ${test_libraries}) add_test(NAME ${test_name} COMMAND $) + endforeach() diff --git a/torch/lib/c10d/test/FileStoreTest.cpp b/torch/lib/c10d/test/FileStoreTest.cpp index 64f98009924af..7f2d520db8e87 100644 --- a/torch/lib/c10d/test/FileStoreTest.cpp +++ b/torch/lib/c10d/test/FileStoreTest.cpp @@ -1,49 +1,10 @@ -#include +#include "StoreTestCommon.hpp" +#include "FileStore.hpp" -#include +#include #include -#include -#include #include -#include "FileStore.hpp" - -using namespace c10d; - -class Semaphore { - public: - void post(int n = 1) { - std::unique_lock lock(m_); - n_ += n; - cv_.notify_all(); - } - - void wait(int n = 1) { - std::unique_lock lock(m_); - while (n_ < n) { - cv_.wait(lock); - } - n_ -= n; - } - - protected: - int n_ = 0; - std::mutex m_; - std::condition_variable cv_; -}; - -void set(Store& store, const std::string& key, const std::string& value) { - std::vector data(value.begin(), value.end()); - store.set(key, data); -} - -void check(Store& store, const std::string& key, const std::string& expected) { - auto tmp = store.get(key); - auto actual = std::string((const char*) tmp.data(), tmp.size()); - if (actual != expected) { - throw std::runtime_error("Expected " + expected + ", got " + actual); - } -} std::string tmppath() { const char* tmpdir = getenv("TMPDIR"); @@ -71,29 +32,29 @@ int main(int argc, char** argv) { // Basic set/get { - FileStore store(path); - set(store, "key0", "value0"); - set(store, "key1", "value1"); - set(store, "key2", "value2"); - check(store, "key0", "value0"); - check(store, "key1", "value1"); - check(store, "key2", "value2"); + c10d::FileStore store(path); + c10d::test::set(store, "key0", "value0"); + c10d::test::set(store, "key1", "value1"); + c10d::test::set(store, "key2", "value2"); + c10d::test::check(store, "key0", "value0"); + c10d::test::check(store, "key1", "value1"); + c10d::test::check(store, "key2", "value2"); } // Perform get on new instance { - FileStore store(path); - check(store, "key0", "value0"); + c10d::FileStore store(path); + c10d::test::check(store, "key0", "value0"); } // Hammer on FileStore#add std::vector threads; const auto numThreads = 4; const auto numIterations = 100; - Semaphore sem1, sem2; + c10d::test::Semaphore sem1, sem2; for (auto i = 0; i < numThreads; i++) { threads.push_back(std::move(std::thread([&] { - FileStore store(path); + c10d::FileStore store(path); sem1.post(); sem2.wait(); for (auto j = 0; j < numIterations; j++) { @@ -109,12 +70,12 @@ int main(int argc, char** argv) { // Check that the counter has the expected value { - FileStore store(path); - std::stringstream ss; - ss << (numThreads * numIterations); - check(store, "counter", ss.str()); + c10d::FileStore store(path); + std::string expected = std::to_string(numThreads * numIterations); + c10d::test::check(store, "counter", expected); } unlink(path.c_str()); + std::cout << "Test succeeded" << std::endl; return 0; } diff --git a/torch/lib/c10d/test/StoreTestCommon.hpp b/torch/lib/c10d/test/StoreTestCommon.hpp new file mode 100644 index 0000000000000..cf2d7a90a5a04 --- /dev/null +++ b/torch/lib/c10d/test/StoreTestCommon.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include "Store.hpp" + +#include +#include +#include +#include + +namespace c10d { +namespace test { + +class Semaphore { + public: + void post(int n = 1) { + std::unique_lock lock(m_); + n_ += n; + cv_.notify_all(); + } + + void wait(int n = 1) { + std::unique_lock lock(m_); + while (n_ < n) { + cv_.wait(lock); + } + n_ -= n; + } + + protected: + int n_ = 0; + std::mutex m_; + std::condition_variable cv_; +}; + + +inline void set(Store& store, + const std::string& key, + const std::string& value) { + std::vector data(value.begin(), value.end()); + store.set(key, data); +} + +inline void check(Store& store, + const std::string& key, + const std::string& expected) { + auto tmp = store.get(key); + auto actual = std::string((const char*) tmp.data(), tmp.size()); + if (actual != expected) { + throw std::runtime_error("Expected " + expected + ", got " + actual); + } +} + +} // namespace test +} // namespace c10d diff --git a/torch/lib/c10d/test/TcpStoreTest.cpp b/torch/lib/c10d/test/TcpStoreTest.cpp new file mode 100644 index 0000000000000..9d76e21a523c5 --- /dev/null +++ b/torch/lib/c10d/test/TcpStoreTest.cpp @@ -0,0 +1,63 @@ +#include "StoreTestCommon.hpp" +#include "TcpStore.hpp" + +#include +#include +#include + + +int main(int argc, char** argv) { + + // Master store + c10d::TcpStore masterStore("127.0.0.1", 29500, true); + + // Basic set/get on the master store + c10d::test::set(masterStore, "key0", "value0"); + c10d::test::set(masterStore, "key1", "value1"); + c10d::test::set(masterStore, "key2", "value2"); + c10d::test::check(masterStore, "key0", "value0"); + c10d::test::check(masterStore, "key1", "value1"); + c10d::test::check(masterStore, "key2", "value2"); + + // Hammer on TcpStore + std::vector threads; + const auto numThreads = 16; + const auto numIterations = 1000; + c10d::test::Semaphore sem1, sem2; + + // Each thread will have a slave store to send/recv data + std::vector> slaveStores; + for (auto i = 0; i < numThreads; i++) { + slaveStores.push_back(std::unique_ptr(new + c10d::TcpStore("127.0.0.1", 29500, false))); + } + + for (auto i = 0; i < numThreads; i++) { + threads.push_back(std::move(std::thread([&sem1, &sem2, &slaveStores, i] { + sem1.post(); + sem2.wait(); + for (auto j = 0; j < numIterations; j++) { + slaveStores[i]->add("counter", 1); + } + // Let each thread set and get key on its slave store + std::string key = "thread_" + std::to_string(i); + std::string val = "thread_val_" + std::to_string(i); + for (auto j = 0; j < numIterations; j++) { + c10d::test::set(*slaveStores[i], key, val); + c10d::test::check(*slaveStores[i], key, val); + } + }))); + } + sem1.wait(numThreads); + sem2.post(numThreads); + for (auto& thread : threads) { + thread.join(); + } + + // Check that the counter has the expected value + std::string expected = std::to_string(numThreads * numIterations); + c10d::test::check(masterStore, "counter", expected); + + std::cout << "Test succeeded" << std::endl; + return EXIT_SUCCESS; +} From 50e1362121ebf7c31269eb9b3d5f4e1ec40dd481 Mon Sep 17 00:00:00 2001 From: Teng Li Date: Tue, 15 May 2018 18:38:12 -0700 Subject: [PATCH 2/5] Used pipe to terminate the store daemon and addressed all comments Please see: https://github.com/pytorch/pytorch/issues/7434 --- torch/lib/c10d/TcpStore.cpp | 248 +++++++++++------- torch/lib/c10d/TcpStore.hpp | 33 ++- torch/lib/c10d/Utils.cpp | 11 +- torch/lib/c10d/Utils.hpp | 10 +- torch/lib/c10d/test/CMakeLists.txt | 3 +- torch/lib/c10d/test/FileStoreTest.cpp | 1 - .../{TcpStoreTest.cpp => TCPStoreTest.cpp} | 40 ++- 7 files changed, 215 insertions(+), 131 deletions(-) rename torch/lib/c10d/test/{TcpStoreTest.cpp => TCPStoreTest.cpp} (60%) diff --git a/torch/lib/c10d/TcpStore.cpp b/torch/lib/c10d/TcpStore.cpp index 91c0ffb1241a9..ff8045f0d833a 100644 --- a/torch/lib/c10d/TcpStore.cpp +++ b/torch/lib/c10d/TcpStore.cpp @@ -1,54 +1,74 @@ #include "TcpStore.hpp" #include -#include #include - +#include namespace c10d { - namespace { enum class QueryType : std::uint8_t { SET, GET, ADD, - CHECK, - STOP_WAITING, - KEEP_WAITING + CHECK }; -} // anonymous namespace - +enum class CheckResponseType : std::uint8_t { + READY, + NOT_READY +}; -// TcpStoreDaemon class methods +} // anonymous namespace +// TCPStoreDaemon class methods // Simply start the daemon thread -TcpStoreDaemon::TcpStoreDaemon(int storeListenSocket) : +TCPStoreDaemon::TCPStoreDaemon(int storeListenSocket) : storeListenSocket_(storeListenSocket) { - daemonThread_ = std::thread(&TcpStoreDaemon::run, this); + daemonThread_ = std::thread(&TCPStoreDaemon::run, this); } -TcpStoreDaemon::~TcpStoreDaemon() { +TCPStoreDaemon::~TCPStoreDaemon() { + + // Stop the run + stop(); + + // Join the thread + join(); + + // Close unclosed sockets for (auto socket : sockets_) { if (socket != -1) { ::close(socket); } } - // Join the thread - join(); + // Now close the rest control pipe + for (auto fd : controlPipeFd_) { + if (fd != -1) { + ::close(fd); + } + } } -void TcpStoreDaemon::join() { +void TCPStoreDaemon::join() { daemonThread_.join(); } -void TcpStoreDaemon::run() { +void TCPStoreDaemon::run() { + + // Create the control pipe + controlPipeFd_ = std::vector{-1, -1}; + if (pipe(controlPipeFd_.data()) == -1) { + throw std::runtime_error("Failed to create the control pipe to start the " + "TCPStoreDaemon run"); + } std::vector fds; fds.push_back({ .fd = storeListenSocket_, .events = POLLIN }); + // Push the read end of the pipe to signal the stopping of the daemon run + fds.push_back({ .fd = controlPipeFd_[0], .events = POLLHUP }); // receive the queries bool finished = false; @@ -59,50 +79,77 @@ void TcpStoreDaemon::run() { SYSCHECK(::poll(fds.data(), fds.size(), -1)); + /** + * TCPStore's listening socket has an event and it should now be able to + * accept new connections. + */ if (fds[0].revents != 0) { if (fds[0].revents ^ POLLIN) { - throw std::system_error(ECONNABORTED, std::system_category()); + throw std::system_error(ECONNABORTED, std::system_category(), + "Unexpected poll revent on the master's listening socket: " + + std::to_string(fds[0].revents)); } int sockFd = std::get<0>(tcputil::accept(storeListenSocket_)); sockets_.push_back(sockFd); - keysAwaited_.push_back(0); fds.push_back({ .fd = sockFd, .events = POLLIN }); } - for (size_t rank = 0; rank < sockets_.size(); rank++) { - if (fds[rank + 1].revents == 0) { + /** + * The pipe receives an event which tells us to shutdown the daemon + */ + if (fds[1].revents != 0) { + // Will be POLLUP when the pipe is closed + if (fds[1].revents ^ POLLHUP) { + throw std::system_error(ECONNABORTED, std::system_category(), + "Unexpected poll revent on the control pipe's reading fd: " + + std::to_string(fds[1].revents)); + } + finished = true; + break; + } + /** + * Skipping the fds[0] and fds[1], + * fds[0] is master's listening socket + * fds[1] is control pipe's reading fd + */ + for (size_t fdIdx = 2; fdIdx < fds.size(); ++fdIdx) { + if (fds[fdIdx].revents == 0) { continue; } - if (fds[rank + 1].revents ^ POLLIN) { - throw std::system_error(ECONNABORTED, std::system_category()); + if (fds[fdIdx].revents ^ POLLIN) { + throw std::system_error(ECONNABORTED, std::system_category(), + "Unexpected poll revent: " + + std::to_string(fds[fdIdx].revents) + " on socket: " + + std::to_string(fds[fdIdx].fd)); } + // Now query the socket that has the event try { - query(rank); - } catch (std::exception& ex) { + query(fds[fdIdx].fd); + } catch (...) { /** * There was an error when processing query. Probably an exception * occurred in recv/send what would indicate that socket on the other * side has been closed. If the closing was due to normal exit, then the - * store should exit too. Otherwise, if it was different exception, - * other processes will get an exception once they try to use the store. + * store should continue executing. Otherwise, if it was different + * exception, other connections will get an exception once they try to + * use the store. We will go ahead and close this connection whenever + * we hit an exception here. */ - finished = true; - break; + ::close(fds[fdIdx].fd); + fds.erase(fds.begin() + fdIdx); + sockets_.erase(sockets_.begin() + fdIdx - 2); + --fdIdx; + continue; } } } } -void TcpStoreDaemon::wakeUpWaitingRanks(const std::string& key) { - auto toWake = waiting_.find(key); - if (toWake != waiting_.end()) { - for (int proc : toWake->second) { - if (--keysAwaited_[proc] == 0) { - tcputil::sendValue(sockets_[proc], - QueryType::STOP_WAITING); - } - } - waiting_.erase(toWake); +void TCPStoreDaemon::stop() { + if (controlPipeFd_.size() == 2 && controlPipeFd_[1] != -1) { + // close the write end of the pipe + ::close(controlPipeFd_[1]); + controlPipeFd_[1] = -1; } } @@ -113,64 +160,70 @@ void TcpStoreDaemon::wakeUpWaitingRanks(const std::string& key) { * or, in the case of wait * type of query | number of args | size of arg1 | arg1 | ... */ -void TcpStoreDaemon::query(RankType rank) { +void TCPStoreDaemon::query(int socket) { - int socket = sockets_[rank]; QueryType qt; tcputil::recvBytes(socket, &qt, 1); if (qt == QueryType::SET) { - std::string key = tcputil::recvString(socket); - tcpStore_[key] = tcputil::recvVector(socket); - // On "set", wake up all of the processes that wait - // for keys already in the store - wakeUpWaitingRanks(key); + setHandler(socket); } else if (qt == QueryType::ADD) { - std::string key = tcputil::recvString(socket); - int64_t addVal = tcputil::recvValue(socket); - - if (tcpStore_.find(key) != tcpStore_.end()) { - auto buf = reinterpret_cast(tcpStore_[key].data()); - auto len = tcpStore_[key].size(); - addVal += std::stoll(std::string(buf, len)); - } - auto addValStr = std::to_string(addVal); - tcpStore_[key] = std::vector(addValStr.begin(), addValStr.end()); - // Now send the new value - tcputil::sendValue(socket, addVal); - // On "add", wake up all of the processes that wait - // for keys already in the store - wakeUpWaitingRanks(key); + addHandler(socket); } else if (qt == QueryType::GET) { - std::string key = tcputil::recvString(socket); - auto data = tcpStore_.at(key); - tcputil::sendVector(socket, data); + getHandler(socket); } else if (qt == QueryType::CHECK) { - SizeType nargs; - tcputil::recvBytes(socket, &nargs, 1); - std::vector keys(nargs); - for (size_t i = 0; i < nargs; i++) { - keys[i] = tcputil::recvString(socket); - } - // Now we have received all the keys - if (checkAndUpdate(keys)) { - tcputil::sendValue(socket, QueryType::STOP_WAITING); - } else { - for (auto& key : keys) { - waiting_[key].push_back(rank); - } - keysAwaited_[rank] = keys.size(); - tcputil::sendValue(socket, QueryType::KEEP_WAITING); - } + checkHandler(socket); + + } else { + throw std::runtime_error("Unexpected query type"); + } +} + +void TCPStoreDaemon::setHandler(int socket) { + std::string key = tcputil::recvString(socket); + tcpStore_[key] = tcputil::recvVector(socket); +} + +void TCPStoreDaemon::addHandler(int socket) { + std::string key = tcputil::recvString(socket); + int64_t addVal = tcputil::recvValue(socket); + + if (tcpStore_.find(key) != tcpStore_.end()) { + auto buf = reinterpret_cast(tcpStore_[key].data()); + auto len = tcpStore_[key].size(); + addVal += std::stoll(std::string(buf, len)); + } + auto addValStr = std::to_string(addVal); + tcpStore_[key] = std::vector(addValStr.begin(), addValStr.end()); + // Now send the new value + tcputil::sendValue(socket, addVal); +} + +void TCPStoreDaemon::getHandler(int socket) { + std::string key = tcputil::recvString(socket); + auto data = tcpStore_.at(key); + tcputil::sendVector(socket, data); +} + +void TCPStoreDaemon::checkHandler(int socket) { + SizeType nargs; + tcputil::recvBytes(socket, &nargs, 1); + std::vector keys(nargs); + for (size_t i = 0; i < nargs; i++) { + keys[i] = tcputil::recvString(socket); + } + // Now we have received all the keys + if (checkAndUpdate(keys)) { + tcputil::sendValue(socket, CheckResponseType::READY); } else { - throw std::runtime_error("expected a query type"); + tcputil::sendValue(socket, CheckResponseType::NOT_READY); } } -bool TcpStoreDaemon::checkAndUpdate(std::vector& keys) const { +bool TCPStoreDaemon::checkAndUpdate(std::vector& keys) const { bool ret = true; for (auto it = keys.begin(); it != keys.end();) { if (tcpStore_.count(*it) == 0) { @@ -183,9 +236,8 @@ bool TcpStoreDaemon::checkAndUpdate(std::vector& keys) const { return ret; } -// TcpStore class methods - -TcpStore::TcpStore(const std::string& masterAddr, +// TCPStore class methods +TCPStore::TCPStore(const std::string& masterAddr, PortType masterPort, bool isServer) : isServer_(isServer) @@ -194,50 +246,50 @@ TcpStore::TcpStore(const std::string& masterAddr, { if (isServer_) { - // Openning up the listening socket + // Opening up the listening socket std::tie(masterListenSocket_, std::ignore) = tcputil::listen(masterPort); // Now start the daemon - tcpStoreDaemon_ = std::unique_ptr( - new TcpStoreDaemon(masterListenSocket_) + tcpStoreDaemon_ = std::unique_ptr( + new TCPStoreDaemon(masterListenSocket_) ); } // Connect to the daemon storeSocket_ = tcputil::connect(tcpStoreAddr_, tcpStorePort_); } -TcpStore::~TcpStore() { +TCPStore::~TCPStore() { ::close(storeSocket_); if (isServer_) { - ::close(masterListenSocket_); /** * Store daemon should end because of closed connection. * daemon destructor should join the thread */ tcpStoreDaemon_.reset(nullptr); + ::close(masterListenSocket_); } } -void TcpStore::set(const std::string& key, const std::vector& data) { +void TCPStore::set(const std::string& key, const std::vector& data) { tcputil::sendValue(storeSocket_, QueryType::SET); tcputil::sendString(storeSocket_, key, true); tcputil::sendVector(storeSocket_, data); } -std::vector TcpStore::get(const std::string& key) { +std::vector TCPStore::get(const std::string& key) { wait({key}); tcputil::sendValue(storeSocket_, QueryType::GET); tcputil::sendString(storeSocket_, key); return tcputil::recvVector(storeSocket_); } -int64_t TcpStore::add(const std::string& key, int64_t value) { +int64_t TCPStore::add(const std::string& key, int64_t value) { tcputil::sendValue(storeSocket_, QueryType::ADD); tcputil::sendString(storeSocket_, key, true); tcputil::sendValue(storeSocket_, value); return tcputil::recvValue(storeSocket_); } -bool TcpStore::check(const std::vector& keys) { +bool TCPStore::check(const std::vector& keys) { tcputil::sendValue(storeSocket_, QueryType::CHECK); SizeType nkeys = keys.size(); @@ -245,17 +297,17 @@ bool TcpStore::check(const std::vector& keys) { for (size_t i = 0; i < nkeys; i++) { tcputil::sendString(storeSocket_, keys[i], (i != (nkeys - 1))); } - auto checkResponse = tcputil::recvValue(storeSocket_); - if (checkResponse == QueryType::STOP_WAITING) { + auto checkResponse = tcputil::recvValue(storeSocket_); + if (checkResponse == CheckResponseType::READY) { return true; - } else if (checkResponse == QueryType::KEEP_WAITING) { + } else if (checkResponse == CheckResponseType::NOT_READY) { return false; } else { - throw std::runtime_error("stop_waiting or keep_waiting response expected"); + throw std::runtime_error("ready or not_ready response expected"); } } -void TcpStore::wait( +void TCPStore::wait( const std::vector& keys, const std::chrono::milliseconds& timeout) { diff --git a/torch/lib/c10d/TcpStore.hpp b/torch/lib/c10d/TcpStore.hpp index 9e9d6e34b9a4f..ac4bcf6c1ffbe 100644 --- a/torch/lib/c10d/TcpStore.hpp +++ b/torch/lib/c10d/TcpStore.hpp @@ -7,43 +7,48 @@ #include #include - namespace c10d { -class TcpStoreDaemon { +class TCPStoreDaemon { public: - explicit TcpStoreDaemon(int storeListenSocket); - ~TcpStoreDaemon(); + explicit TCPStoreDaemon(int storeListenSocket); + ~TCPStoreDaemon(); void join(); protected: void run(); - void query(RankType rank); + void stop(); + + void query(int socket); + + void setHandler(int socket); + void addHandler(int socket); + void getHandler(int socket); + void checkHandler(int socket); + bool checkAndUpdate(std::vector& keys) const; - void wakeUpWaitingRanks(const std::string& key); std::thread daemonThread_; std::unordered_map> tcpStore_; - std::unordered_map> waiting_; - std::vector keysAwaited_; - std::vector sockets_; + std::vector sockets_; int storeListenSocket_; + std::vector controlPipeFd_; }; -class TcpStore : public Store { +class TCPStore : public Store { public: - explicit TcpStore(const std::string& masterAddr, + explicit TCPStore(const std::string& masterAddr, PortType masterPort, bool isServer = false); - virtual ~TcpStore(); + virtual ~TCPStore(); void set( const std::string& key, @@ -68,8 +73,8 @@ class TcpStore : public Store { std::string tcpStoreAddr_; PortType tcpStorePort_; - // Only needs to be launched on master rank - std::unique_ptr tcpStoreDaemon_ = nullptr; + // Only needs to be launched as the server + std::unique_ptr tcpStoreDaemon_ = nullptr; }; } // namespace c10d diff --git a/torch/lib/c10d/Utils.cpp b/torch/lib/c10d/Utils.cpp index 56c1bc0c3bd16..6a999887620ed 100644 --- a/torch/lib/c10d/Utils.cpp +++ b/torch/lib/c10d/Utils.cpp @@ -19,7 +19,6 @@ namespace c10d { namespace tcputil { - namespace { constexpr int LISTEN_QUEUE_SIZE = 64; @@ -133,7 +132,7 @@ std::pair listen(PortType port) { int connect(const std::string& address, PortType port, bool wait, - int timeout) { + const std::chrono::milliseconds& timeout) { struct ::addrinfo hints, *res = NULL; @@ -184,7 +183,7 @@ int connect(const std::string& address, pfd.fd = socket; pfd.events = POLLOUT; - int numReady = ::poll(&pfd, 1, timeout); + int numReady = ::poll(&pfd, 1, timeout.count()); if (numReady < 0) { throw std::system_error(errno, std::system_category()); } else if (numReady == 0) { @@ -238,14 +237,16 @@ int connect(const std::string& address, return socket; } -std::tuple accept(int listenSocket, int timeout) { +std::tuple accept( + int listenSocket, + const std::chrono::milliseconds& timeout) { // poll on listen socket, it allows to make timeout std::unique_ptr events(new struct ::pollfd[1]); events[0] = {.fd = listenSocket, .events = POLLIN}; while (true) { - int res = ::poll(events.get(), 1, timeout); + int res = ::poll(events.get(), 1, timeout.count()); if (res == 0) { throw std::runtime_error("waiting for processes to " "connect has timed out"); diff --git a/torch/lib/c10d/Utils.hpp b/torch/lib/c10d/Utils.hpp index 77905b783a4a3..c2639b17c45eb 100644 --- a/torch/lib/c10d/Utils.hpp +++ b/torch/lib/c10d/Utils.hpp @@ -11,10 +11,10 @@ #include #include #include +#include namespace c10d { - using RankType = uint32_t; using PortType = uint16_t; using SizeType = uint64_t; @@ -41,6 +41,8 @@ inline RankType convertToRank(int64_t rank, int64_t min = 0) { // TCP util namespace namespace tcputil { +constexpr std::chrono::milliseconds kNoTimeout = + std::chrono::milliseconds(-1); // Send and receive template @@ -161,9 +163,11 @@ std::pair listen(PortType port); int connect(const std::string& address, PortType port, bool wait = true, - int timeout = -1); + const std::chrono::milliseconds& timeout = kNoTimeout); -std::tuple accept(int listenSocket, int timeout = -1); +std::tuple +accept(int listenSocket, + const std::chrono::milliseconds& timeout = kNoTimeout); // Helper resource guard class class ResourceGuard { diff --git a/torch/lib/c10d/test/CMakeLists.txt b/torch/lib/c10d/test/CMakeLists.txt index 400b606c48944..b19a00db888b2 100644 --- a/torch/lib/c10d/test/CMakeLists.txt +++ b/torch/lib/c10d/test/CMakeLists.txt @@ -1,6 +1,6 @@ set(test_srcs FileStoreTest.cpp - TcpStoreTest.cpp + TCPStoreTest.cpp ) set(test_libraries @@ -14,5 +14,4 @@ foreach(test_src ${test_srcs}) target_include_directories(${test_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..) target_link_libraries(${test_name} ${test_libraries}) add_test(NAME ${test_name} COMMAND $) - endforeach() diff --git a/torch/lib/c10d/test/FileStoreTest.cpp b/torch/lib/c10d/test/FileStoreTest.cpp index 7f2d520db8e87..c75b9c0e86e9c 100644 --- a/torch/lib/c10d/test/FileStoreTest.cpp +++ b/torch/lib/c10d/test/FileStoreTest.cpp @@ -5,7 +5,6 @@ #include #include - std::string tmppath() { const char* tmpdir = getenv("TMPDIR"); if (tmpdir == nullptr) { diff --git a/torch/lib/c10d/test/TcpStoreTest.cpp b/torch/lib/c10d/test/TCPStoreTest.cpp similarity index 60% rename from torch/lib/c10d/test/TcpStoreTest.cpp rename to torch/lib/c10d/test/TCPStoreTest.cpp index 9d76e21a523c5..8ebcc3e3593f7 100644 --- a/torch/lib/c10d/test/TcpStoreTest.cpp +++ b/torch/lib/c10d/test/TCPStoreTest.cpp @@ -9,7 +9,7 @@ int main(int argc, char** argv) { // Master store - c10d::TcpStore masterStore("127.0.0.1", 29500, true); + c10d::TCPStore masterStore("127.0.0.1", 29500, true); // Basic set/get on the master store c10d::test::set(masterStore, "key0", "value0"); @@ -19,45 +19,69 @@ int main(int argc, char** argv) { c10d::test::check(masterStore, "key1", "value1"); c10d::test::check(masterStore, "key2", "value2"); - // Hammer on TcpStore + // Hammer on TCPStore std::vector threads; const auto numThreads = 16; const auto numIterations = 1000; c10d::test::Semaphore sem1, sem2; // Each thread will have a slave store to send/recv data - std::vector> slaveStores; + std::vector> slaveStores; for (auto i = 0; i < numThreads; i++) { - slaveStores.push_back(std::unique_ptr(new - c10d::TcpStore("127.0.0.1", 29500, false))); + slaveStores.push_back(std::unique_ptr(new + c10d::TCPStore("127.0.0.1", 29500, false))); } for (auto i = 0; i < numThreads; i++) { threads.push_back(std::move(std::thread([&sem1, &sem2, &slaveStores, i] { - sem1.post(); - sem2.wait(); for (auto j = 0; j < numIterations; j++) { slaveStores[i]->add("counter", 1); } // Let each thread set and get key on its slave store std::string key = "thread_" + std::to_string(i); - std::string val = "thread_val_" + std::to_string(i); for (auto j = 0; j < numIterations; j++) { + std::string val = "thread_val_" + std::to_string(j); c10d::test::set(*slaveStores[i], key, val); c10d::test::check(*slaveStores[i], key, val); } + + sem1.post(); + sem2.wait(); + + // Now check other threads' written data + for (auto j = 0; j < numThreads; j++) { + if (j == i) { + continue; + } + std::string key = "thread_" + std::to_string(i); + std::string val = "thread_val_" + + std::to_string(numIterations - 1); + c10d::test::check(*slaveStores[i], key, val); + } }))); } + sem1.wait(numThreads); sem2.post(numThreads); + for (auto& thread : threads) { thread.join(); } + // Clear the store to test that slave disconnect won't shutdown the store + slaveStores.clear(); + // Check that the counter has the expected value std::string expected = std::to_string(numThreads * numIterations); c10d::test::check(masterStore, "counter", expected); + // Check that each threads' written data from the main thread + for (auto i = 0; i < numThreads; i++) { + std::string key = "thread_" + std::to_string(i); + std::string val = "thread_val_" + std::to_string(numIterations - 1); + c10d::test::check(masterStore, key, val); + } + std::cout << "Test succeeded" << std::endl; return EXIT_SUCCESS; } From b9294ee1b6ca9c5095be7ce19586f7e812687868 Mon Sep 17 00:00:00 2001 From: Teng Li Date: Wed, 16 May 2018 15:11:11 -0700 Subject: [PATCH 3/5] Used notify/wake for wait and addressed all comments Reference: https://github.com/pytorch/pytorch/issues/7434 --- torch/lib/c10d/CMakeLists.txt | 2 +- torch/lib/c10d/{TcpStore.cpp => TCPStore.cpp} | 168 ++++++++++-------- torch/lib/c10d/{TcpStore.hpp => TCPStore.hpp} | 14 +- torch/lib/c10d/Utils.cpp | 25 +-- torch/lib/c10d/Utils.hpp | 51 +++--- torch/lib/c10d/test/TCPStoreTest.cpp | 57 +++--- 6 files changed, 172 insertions(+), 145 deletions(-) rename torch/lib/c10d/{TcpStore.cpp => TCPStore.cpp} (64%) rename torch/lib/c10d/{TcpStore.hpp => TCPStore.hpp} (75%) diff --git a/torch/lib/c10d/CMakeLists.txt b/torch/lib/c10d/CMakeLists.txt index b7d577b625679..763131091becb 100644 --- a/torch/lib/c10d/CMakeLists.txt +++ b/torch/lib/c10d/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.2 FATAL_ERROR) set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake ${CMAKE_MODULE_PATH}) -add_library(store Utils.cpp Store.cpp FileStore.cpp TcpStore.cpp) +add_library(store Utils.cpp Store.cpp FileStore.cpp TCPStore.cpp) target_compile_options(store PUBLIC "-std=c++11") enable_testing() diff --git a/torch/lib/c10d/TcpStore.cpp b/torch/lib/c10d/TCPStore.cpp similarity index 64% rename from torch/lib/c10d/TcpStore.cpp rename to torch/lib/c10d/TCPStore.cpp index ff8045f0d833a..20525c0fe14b8 100644 --- a/torch/lib/c10d/TcpStore.cpp +++ b/torch/lib/c10d/TCPStore.cpp @@ -1,25 +1,31 @@ -#include "TcpStore.hpp" +#include "TCPStore.hpp" #include #include #include +#include namespace c10d { namespace { -enum class QueryType : std::uint8_t { +enum class QueryType : uint8_t { SET, GET, ADD, - CHECK + CHECK, + WAIT }; -enum class CheckResponseType : std::uint8_t { +enum class CheckResponseType : uint8_t { READY, NOT_READY }; +enum class WaitResponseType : uint8_t { + STOP_WAITING +}; + } // anonymous namespace // TCPStoreDaemon class methods @@ -31,13 +37,10 @@ TCPStoreDaemon::TCPStoreDaemon(int storeListenSocket) : } TCPStoreDaemon::~TCPStoreDaemon() { - // Stop the run stop(); - // Join the thread join(); - // Close unclosed sockets for (auto socket : sockets_) { if (socket != -1) { @@ -57,9 +60,7 @@ void TCPStoreDaemon::join() { } void TCPStoreDaemon::run() { - // Create the control pipe - controlPipeFd_ = std::vector{-1, -1}; if (pipe(controlPipeFd_.data()) == -1) { throw std::runtime_error("Failed to create the control pipe to start the " "TCPStoreDaemon run"); @@ -79,10 +80,8 @@ void TCPStoreDaemon::run() { SYSCHECK(::poll(fds.data(), fds.size(), -1)); - /** - * TCPStore's listening socket has an event and it should now be able to - * accept new connections. - */ + // TCPStore's listening socket has an event and it should now be able to + // accept new connections. if (fds[0].revents != 0) { if (fds[0].revents ^ POLLIN) { throw std::system_error(ECONNABORTED, std::system_category(), @@ -93,9 +92,7 @@ void TCPStoreDaemon::run() { sockets_.push_back(sockFd); fds.push_back({ .fd = sockFd, .events = POLLIN }); } - /** - * The pipe receives an event which tells us to shutdown the daemon - */ + // The pipe receives an event which tells us to shutdown the daemon if (fds[1].revents != 0) { // Will be POLLUP when the pipe is closed if (fds[1].revents ^ POLLHUP) { @@ -106,11 +103,9 @@ void TCPStoreDaemon::run() { finished = true; break; } - /** - * Skipping the fds[0] and fds[1], - * fds[0] is master's listening socket - * fds[1] is control pipe's reading fd - */ + // Skipping the fds[0] and fds[1], + // fds[0] is master's listening socket + // fds[1] is control pipe's reading fd for (size_t fdIdx = 2; fdIdx < fds.size(); ++fdIdx) { if (fds[fdIdx].revents == 0) { continue; @@ -126,15 +121,13 @@ void TCPStoreDaemon::run() { try { query(fds[fdIdx].fd); } catch (...) { - /** - * There was an error when processing query. Probably an exception - * occurred in recv/send what would indicate that socket on the other - * side has been closed. If the closing was due to normal exit, then the - * store should continue executing. Otherwise, if it was different - * exception, other connections will get an exception once they try to - * use the store. We will go ahead and close this connection whenever - * we hit an exception here. - */ + // There was an error when processing query. Probably an exception + // occurred in recv/send what would indicate that socket on the other + // side has been closed. If the closing was due to normal exit, then + // the store should continue executing. Otherwise, if it was different + // exception, other connections will get an exception once they try to + // use the store. We will go ahead and close this connection whenever + // we hit an exception here. ::close(fds[fdIdx].fd); fds.erase(fds.begin() + fdIdx); sockets_.erase(sockets_.begin() + fdIdx - 2); @@ -146,22 +139,20 @@ void TCPStoreDaemon::run() { } void TCPStoreDaemon::stop() { - if (controlPipeFd_.size() == 2 && controlPipeFd_[1] != -1) { + if (controlPipeFd_[1] != -1) { // close the write end of the pipe ::close(controlPipeFd_[1]); controlPipeFd_[1] = -1; } } -/** - * query communicates with the worker. The format - * of the query is as follows: - * type of query | size of arg1 | arg1 | size of arg2 | arg2 | ... - * or, in the case of wait - * type of query | number of args | size of arg1 | arg1 | ... - */ -void TCPStoreDaemon::query(int socket) { +// query communicates with the worker. The format +// of the query is as follows: +// type of query | size of arg1 | arg1 | size of arg2 | arg2 | ... +// or, in the case of wait +// type of query | number of args | size of arg1 | arg1 | ... +void TCPStoreDaemon::query(int socket) { QueryType qt; tcputil::recvBytes(socket, &qt, 1); @@ -177,14 +168,32 @@ void TCPStoreDaemon::query(int socket) { } else if (qt == QueryType::CHECK) { checkHandler(socket); + } else if (qt == QueryType::WAIT) { + waitHandler(socket); + } else { throw std::runtime_error("Unexpected query type"); } } +void TCPStoreDaemon::wakeupWaitingClients(const std::string& key) { + auto socketsToWait = waitingSockets_.find(key); + if (socketsToWait != waitingSockets_.end()) { + for (int socket : socketsToWait->second) { + if (--keysAwaited_[socket] == 0) { + tcputil::sendValue(socket, + WaitResponseType::STOP_WAITING); + } + } + waitingSockets_.erase(socketsToWait); + } +} + void TCPStoreDaemon::setHandler(int socket) { std::string key = tcputil::recvString(socket); tcpStore_[key] = tcputil::recvVector(socket); + // On "set", wake up all clients that have been waiting + wakeupWaitingClients(key); } void TCPStoreDaemon::addHandler(int socket) { @@ -200,15 +209,17 @@ void TCPStoreDaemon::addHandler(int socket) { tcpStore_[key] = std::vector(addValStr.begin(), addValStr.end()); // Now send the new value tcputil::sendValue(socket, addVal); + // On "add", wake up all clients that have been waiting + wakeupWaitingClients(key); } -void TCPStoreDaemon::getHandler(int socket) { +void TCPStoreDaemon::getHandler(int socket) const { std::string key = tcputil::recvString(socket); auto data = tcpStore_.at(key); tcputil::sendVector(socket, data); } -void TCPStoreDaemon::checkHandler(int socket) { +void TCPStoreDaemon::checkHandler(int socket) const { SizeType nargs; tcputil::recvBytes(socket, &nargs, 1); std::vector keys(nargs); @@ -216,24 +227,37 @@ void TCPStoreDaemon::checkHandler(int socket) { keys[i] = tcputil::recvString(socket); } // Now we have received all the keys - if (checkAndUpdate(keys)) { + if (checkKeys(keys)) { tcputil::sendValue(socket, CheckResponseType::READY); } else { tcputil::sendValue(socket, CheckResponseType::NOT_READY); } } -bool TCPStoreDaemon::checkAndUpdate(std::vector& keys) const { - bool ret = true; - for (auto it = keys.begin(); it != keys.end();) { - if (tcpStore_.count(*it) == 0) { - ret = false; - it++; - } else { - it = keys.erase(it); +void TCPStoreDaemon::waitHandler(int socket) { + SizeType nargs; + tcputil::recvBytes(socket, &nargs, 1); + std::vector keys(nargs); + for (size_t i = 0; i < nargs; i++) { + keys[i] = tcputil::recvString(socket); + } + if (checkKeys(keys)) { + tcputil::sendValue(socket, + WaitResponseType::STOP_WAITING); + } else { + for (auto& key : keys) { + waitingSockets_[key].push_back(socket); } + keysAwaited_[socket] = keys.size(); } - return ret; +} + +bool TCPStoreDaemon:: +checkKeys(const std::vector& keys) const { + return std::all_of(keys.begin(), keys.end(), + [this](const std::string& s) { + return tcpStore_.count(s) > 0; + }); } // TCPStore class methods @@ -243,7 +267,6 @@ TCPStore::TCPStore(const std::string& masterAddr, : isServer_(isServer) , tcpStoreAddr_(masterAddr) , tcpStorePort_(masterPort) - { if (isServer_) { // Opening up the listening socket @@ -260,10 +283,8 @@ TCPStore::TCPStore(const std::string& masterAddr, TCPStore::~TCPStore() { ::close(storeSocket_); if (isServer_) { - /** - * Store daemon should end because of closed connection. - * daemon destructor should join the thread - */ + // Store daemon should end because of closed connection. + // daemon destructor should join the thread tcpStoreDaemon_.reset(nullptr); ::close(masterListenSocket_); } @@ -290,7 +311,6 @@ int64_t TCPStore::add(const std::string& key, int64_t value) { } bool TCPStore::check(const std::vector& keys) { - tcputil::sendValue(storeSocket_, QueryType::CHECK); SizeType nkeys = keys.size(); tcputil::sendBytes(storeSocket_, &nkeys, 1, (nkeys > 0)); @@ -307,19 +327,27 @@ bool TCPStore::check(const std::vector& keys) { } } -void TCPStore::wait( - const std::vector& keys, - const std::chrono::milliseconds& timeout) { - - const auto start = std::chrono::steady_clock::now(); - while (!check(keys)) { - const auto elapsed = std::chrono::duration_cast( - std::chrono::steady_clock::now() - start); - if (timeout != kNoTimeout && elapsed > timeout) { - throw std::runtime_error("Wait timeout"); - } - /* sleep override */ - std::this_thread::sleep_for(std::chrono::milliseconds(10)); +void TCPStore::wait(const std::vector& keys, + const std::chrono::milliseconds& timeout) { + // Set the socket timeout if there is a wait timeout + if (timeout != kNoTimeout) { + struct timeval timeoutTV = {.tv_sec = timeout.count() / 1000, + .tv_usec = (timeout.count() % 1000) * 1000}; + SYSCHECK(::setsockopt(storeSocket_, + SOL_SOCKET, + SO_RCVTIMEO, + reinterpret_cast(&timeoutTV), + sizeof(timeoutTV))); + } + tcputil::sendValue(storeSocket_, QueryType::WAIT); + SizeType nkeys = keys.size(); + tcputil::sendBytes(storeSocket_, &nkeys, 1, (nkeys > 0)); + for (size_t i = 0; i < nkeys; i++) { + tcputil::sendString(storeSocket_, keys[i], (i != (nkeys - 1))); + } + auto waitResponse = tcputil::recvValue(storeSocket_); + if (waitResponse != WaitResponseType::STOP_WAITING) { + throw std::runtime_error("Stop_waiting response is expected"); } } diff --git a/torch/lib/c10d/TcpStore.hpp b/torch/lib/c10d/TCPStore.hpp similarity index 75% rename from torch/lib/c10d/TcpStore.hpp rename to torch/lib/c10d/TCPStore.hpp index ac4bcf6c1ffbe..00e0fe88c7826 100644 --- a/torch/lib/c10d/TcpStore.hpp +++ b/torch/lib/c10d/TCPStore.hpp @@ -27,17 +27,23 @@ class TCPStoreDaemon { void setHandler(int socket); void addHandler(int socket); - void getHandler(int socket); - void checkHandler(int socket); + void getHandler(int socket) const; + void checkHandler(int socket) const; + void waitHandler(int socket); - bool checkAndUpdate(std::vector& keys) const; + bool checkKeys(const std::vector& keys) const; + void wakeupWaitingClients(const std::string &key); std::thread daemonThread_; std::unordered_map> tcpStore_; + // From key -> the list of sockets waiting on it + std::unordered_map> waitingSockets_; + // From socket -> number of keys awaited + std::unordered_map keysAwaited_; std::vector sockets_; int storeListenSocket_; - std::vector controlPipeFd_; + std::vector controlPipeFd_ {-1, -1}; }; class TCPStore : public Store { diff --git a/torch/lib/c10d/Utils.cpp b/torch/lib/c10d/Utils.cpp index 6a999887620ed..33c1d8b8b783e 100644 --- a/torch/lib/c10d/Utils.cpp +++ b/torch/lib/c10d/Utils.cpp @@ -52,7 +52,7 @@ PortType getSocketPort(int fd) { return listenPort; } -} +} // namespace std::string sockaddrToString(struct ::sockaddr *addr) { char address[INET6_ADDRSTRLEN + 1]; @@ -73,14 +73,13 @@ std::string sockaddrToString(struct ::sockaddr *addr) { // listen, connect and accept std::pair listen(PortType port) { struct ::addrinfo hints, *res = NULL; - std::memset(&hints, 0x00, sizeof(hints)); hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; hints.ai_family = AF_UNSPEC; // either IPv4 or IPv6 hints.ai_socktype = SOCK_STREAM; // TCP // `getaddrinfo` will sort addresses according to RFC 3484 and can be tweeked - // by editing `/etc/gai.conf`. so there is no need to manual sorting + // by editing `/etc/gai.conf`. so there is no need to manual sorting // or protocol preference. int err = ::getaddrinfo(nullptr, std::to_string(port).data(), &hints, &res); if (err != 0 || !res) { @@ -115,10 +114,8 @@ std::pair listen(PortType port) { ::close(socket); nextAddr = nextAddr->ai_next; - /** - * we have tried all addresses but could not start - * listening on any of them - */ + // we have tried all addresses but could not start + // listening on any of them if (!nextAddr) { throw; } @@ -133,9 +130,7 @@ int connect(const std::string& address, PortType port, bool wait, const std::chrono::milliseconds& timeout) { - struct ::addrinfo hints, *res = NULL; - std::memset(&hints, 0x00, sizeof(hints)); hints.ai_flags = AI_NUMERICSERV; // specifies that port (service) is numeric hints.ai_family = AF_UNSPEC; // either IPv4 or IPv6 @@ -194,12 +189,11 @@ int connect(const std::string& address, socklen_t errLen = sizeof(errno); errno = 0; ::getsockopt(socket, SOL_SOCKET, SO_ERROR, &errno, &errLen); - /** - * `errno` is set when: - * 1. `getsockopt` has failed - * 2. there is awaiting error in the socket - * (the error is saved to the `errno` variable) - */ + + // `errno` is set when: + // 1. `getsockopt` has failed + // 2. there is awaiting error in the socket + // (the error is saved to the `errno` variable) if (errno != 0) { throw std::system_error(errno, std::system_category()); } @@ -240,7 +234,6 @@ int connect(const std::string& address, std::tuple accept( int listenSocket, const std::chrono::milliseconds& timeout) { - // poll on listen socket, it allows to make timeout std::unique_ptr events(new struct ::pollfd[1]); events[0] = {.fd = listenSocket, .events = POLLIN}; diff --git a/torch/lib/c10d/Utils.hpp b/torch/lib/c10d/Utils.hpp index c2639b17c45eb..53ea35fa24314 100644 --- a/torch/lib/c10d/Utils.hpp +++ b/torch/lib/c10d/Utils.hpp @@ -38,7 +38,27 @@ inline RankType convertToRank(int64_t rank, int64_t min = 0) { return static_cast(rank); } -// TCP util namespace +// Helper resource guard class +class ResourceGuard { + public: + ResourceGuard(std::function destructor) + : destructor_(std::move(destructor)) + , released_(false) {} + + ~ResourceGuard() { + if (!released_) { + destructor_(); + } + } + + void release() { + released_ = true; + } + private: + std::function destructor_; + bool released_; +}; + namespace tcputil { constexpr std::chrono::milliseconds kNoTimeout = @@ -50,7 +70,6 @@ void sendBytes(int socket, const T* buffer, size_t length, bool moreData = false) { - size_t bytesToSend = sizeof(T) * length; if (bytesToSend == 0) { return; @@ -80,8 +99,7 @@ void sendBytes(int socket, } template -void recvBytes(int socket, T* buffer, std::size_t length) { - +void recvBytes(int socket, T* buffer, size_t length) { size_t bytesToReceive = sizeof(T) * length; if (bytesToReceive == 0) { return; @@ -107,7 +125,6 @@ template void sendVector(int socket, const std::vector& vec, bool moreData = false) { - SizeType size = vec.size(); sendBytes(socket, &size, 1, true); sendBytes(socket, vec.data(), size, moreData); @@ -140,7 +157,6 @@ T recvValue(int socket) { inline void sendString(int socket, const std::string& str, bool moreData = false) { - SizeType size = str.size(); sendBytes(socket, &size, 1, true); sendBytes(socket, str.data(), size, moreData); @@ -169,28 +185,5 @@ std::tuple accept(int listenSocket, const std::chrono::milliseconds& timeout = kNoTimeout); -// Helper resource guard class -class ResourceGuard { - - public: - ResourceGuard(std::function destructor) - : destructor_(std::move(destructor)) - , released_(false) {} - - ~ResourceGuard() { - if (!released_) { - destructor_(); - } - } - - void release() { - released_ = true; - } - - private: - std::function destructor_; - bool released_; -}; - } // namespace tcputil } // namespace c10d diff --git a/torch/lib/c10d/test/TCPStoreTest.cpp b/torch/lib/c10d/test/TCPStoreTest.cpp index 8ebcc3e3593f7..26d815fcac053 100644 --- a/torch/lib/c10d/test/TCPStoreTest.cpp +++ b/torch/lib/c10d/test/TCPStoreTest.cpp @@ -1,5 +1,5 @@ #include "StoreTestCommon.hpp" -#include "TcpStore.hpp" +#include "TCPStore.hpp" #include #include @@ -8,16 +8,16 @@ int main(int argc, char** argv) { - // Master store - c10d::TCPStore masterStore("127.0.0.1", 29500, true); + // server store + c10d::TCPStore serverStore("127.0.0.1", 29500, true); - // Basic set/get on the master store - c10d::test::set(masterStore, "key0", "value0"); - c10d::test::set(masterStore, "key1", "value1"); - c10d::test::set(masterStore, "key2", "value2"); - c10d::test::check(masterStore, "key0", "value0"); - c10d::test::check(masterStore, "key1", "value1"); - c10d::test::check(masterStore, "key2", "value2"); + // Basic set/get on the server store + c10d::test::set(serverStore, "key0", "value0"); + c10d::test::set(serverStore, "key1", "value1"); + c10d::test::set(serverStore, "key2", "value2"); + c10d::test::check(serverStore, "key0", "value0"); + c10d::test::check(serverStore, "key1", "value1"); + c10d::test::check(serverStore, "key2", "value2"); // Hammer on TCPStore std::vector threads; @@ -25,29 +25,36 @@ int main(int argc, char** argv) { const auto numIterations = 1000; c10d::test::Semaphore sem1, sem2; - // Each thread will have a slave store to send/recv data - std::vector> slaveStores; + // Each thread will have a client store to send/recv data + std::vector> clientStores; for (auto i = 0; i < numThreads; i++) { - slaveStores.push_back(std::unique_ptr(new + clientStores.push_back(std::unique_ptr(new c10d::TCPStore("127.0.0.1", 29500, false))); } + std::string expectedCounterRes = std::to_string(numThreads * numIterations); + for (auto i = 0; i < numThreads; i++) { - threads.push_back(std::move(std::thread([&sem1, &sem2, &slaveStores, i] { + threads.push_back(std::move(std::thread([&sem1, + &sem2, + &clientStores, + i, + &expectedCounterRes] { for (auto j = 0; j < numIterations; j++) { - slaveStores[i]->add("counter", 1); + clientStores[i]->add("counter", 1); } - // Let each thread set and get key on its slave store + // Let each thread set and get key on its client store std::string key = "thread_" + std::to_string(i); for (auto j = 0; j < numIterations; j++) { std::string val = "thread_val_" + std::to_string(j); - c10d::test::set(*slaveStores[i], key, val); - c10d::test::check(*slaveStores[i], key, val); + c10d::test::set(*clientStores[i], key, val); + c10d::test::check(*clientStores[i], key, val); } sem1.post(); sem2.wait(); - + // Check the counter results + c10d::test::check(*clientStores[i], "counter", expectedCounterRes); // Now check other threads' written data for (auto j = 0; j < numThreads; j++) { if (j == i) { @@ -56,8 +63,9 @@ int main(int argc, char** argv) { std::string key = "thread_" + std::to_string(i); std::string val = "thread_val_" + std::to_string(numIterations - 1); - c10d::test::check(*slaveStores[i], key, val); + c10d::test::check(*clientStores[i], key, val); } + }))); } @@ -68,18 +76,17 @@ int main(int argc, char** argv) { thread.join(); } - // Clear the store to test that slave disconnect won't shutdown the store - slaveStores.clear(); + // Clear the store to test that client disconnect won't shutdown the store + clientStores.clear(); // Check that the counter has the expected value - std::string expected = std::to_string(numThreads * numIterations); - c10d::test::check(masterStore, "counter", expected); + c10d::test::check(serverStore, "counter", expectedCounterRes); // Check that each threads' written data from the main thread for (auto i = 0; i < numThreads; i++) { std::string key = "thread_" + std::to_string(i); std::string val = "thread_val_" + std::to_string(numIterations - 1); - c10d::test::check(masterStore, key, val); + c10d::test::check(serverStore, key, val); } std::cout << "Test succeeded" << std::endl; From 346b6c575c98256c6c7265ab001f494a0be3ec58 Mon Sep 17 00:00:00 2001 From: Teng Li Date: Wed, 16 May 2018 16:57:45 -0700 Subject: [PATCH 4/5] Clean up nits --- torch/lib/c10d/TCPStore.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torch/lib/c10d/TCPStore.cpp b/torch/lib/c10d/TCPStore.cpp index 20525c0fe14b8..7557ca3dc0dbe 100644 --- a/torch/lib/c10d/TCPStore.cpp +++ b/torch/lib/c10d/TCPStore.cpp @@ -146,7 +146,6 @@ void TCPStoreDaemon::stop() { } } - // query communicates with the worker. The format // of the query is as follows: // type of query | size of arg1 | arg1 | size of arg2 | arg2 | ... @@ -252,8 +251,7 @@ void TCPStoreDaemon::waitHandler(int socket) { } } -bool TCPStoreDaemon:: -checkKeys(const std::vector& keys) const { +bool TCPStoreDaemon::checkKeys(const std::vector& keys) const { return std::all_of(keys.begin(), keys.end(), [this](const std::string& s) { return tcpStore_.count(s) > 0; From ba48f051925b6a9d49113e5d8ca43ebd3f551993 Mon Sep 17 00:00:00 2001 From: Teng Li Date: Thu, 17 May 2018 10:58:43 -0700 Subject: [PATCH 5/5] Clean up all socket states when the socket is closed Reference: https://github.com/pytorch/pytorch/issues/7434 --- torch/lib/c10d/TCPStore.cpp | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/torch/lib/c10d/TCPStore.cpp b/torch/lib/c10d/TCPStore.cpp index 7557ca3dc0dbe..370cb832758ea 100644 --- a/torch/lib/c10d/TCPStore.cpp +++ b/torch/lib/c10d/TCPStore.cpp @@ -129,6 +129,29 @@ void TCPStoreDaemon::run() { // use the store. We will go ahead and close this connection whenever // we hit an exception here. ::close(fds[fdIdx].fd); + + // Remove all the tracking state of the close FD + for (auto it = waitingSockets_.begin(); it != waitingSockets_.end(); ) { + for (auto vecIt = it->second.begin(); vecIt != it->second.end(); ) { + if (*vecIt == fds[fdIdx].fd) { + vecIt = it->second.erase(vecIt); + } else { + ++vecIt; + } + } + if (it->second.size() == 0) { + it = waitingSockets_.erase(it); + } else { + ++it; + } + } + for (auto it = keysAwaited_.begin(); it != keysAwaited_.end(); ) { + if (it->first == fds[fdIdx].fd) { + it = keysAwaited_.erase(it); + } else { + ++it; + } + } fds.erase(fds.begin() + fdIdx); sockets_.erase(sockets_.begin() + fdIdx - 2); --fdIdx;