From fbad869a43afc14b85f33361fb9b91c528f7a3a5 Mon Sep 17 00:00:00 2001
From: Teng Li
Date: Thu, 17 May 2018 13:38:06 -0700
Subject: [PATCH] C10D: Added TCPStore to support C10D store interface (#7560)

Reference: https://github.com/pytorch/pytorch/issues/7434

* C10D: Added TCPStore to support C10D store interface

* Used pipe to terminate the store daemon and addressed all comments

* Used notify/wake for wait and addressed all comments

* Clean up nits

* Clean up all socket states when the socket is closed
---
Review notes (not part of the commit message):
  - TCPStoreDaemon: the control pipe is now created in the constructor, so
    stop()/join() can no longer race with the daemon thread start-up.
  - TCPStoreDaemon::run: revents are reset for every polled fd.
  - TCPStoreDaemon::waitHandler: only keys that are still missing are awaited.
  - convertToPort: port 65535 is accepted; ResourceGuard is non-copyable.
  - TCPStoreTest: each thread now really checks the other threads' keys.

 torch/lib/c10d/CMakeLists.txt           |   2 +-
 torch/lib/c10d/TCPStore.cpp             | 385 ++++++++++++++++++++++++
 torch/lib/c10d/TCPStore.hpp             |  86 ++++++
 torch/lib/c10d/Utils.cpp                | 276 +++++++++++++++++
 torch/lib/c10d/Utils.hpp                | 194 ++++++++++++
 torch/lib/c10d/test/CMakeLists.txt      |   1 +
 torch/lib/c10d/test/FileStoreTest.cpp   |  76 ++---
 torch/lib/c10d/test/StoreTestCommon.hpp |  55 ++++
 torch/lib/c10d/test/TCPStoreTest.cpp    |  94 ++++++
 9 files changed, 1110 insertions(+), 59 deletions(-)
 create mode 100644 torch/lib/c10d/TCPStore.cpp
 create mode 100644 torch/lib/c10d/TCPStore.hpp
 create mode 100644 torch/lib/c10d/Utils.cpp
 create mode 100644 torch/lib/c10d/Utils.hpp
 create mode 100644 torch/lib/c10d/test/StoreTestCommon.hpp
 create mode 100644 torch/lib/c10d/test/TCPStoreTest.cpp

diff --git a/torch/lib/c10d/CMakeLists.txt b/torch/lib/c10d/CMakeLists.txt
index e8f762ff4b44..763131091bec 100644
--- a/torch/lib/c10d/CMakeLists.txt
+++ b/torch/lib/c10d/CMakeLists.txt
@@ -1,7 +1,7 @@
 cmake_minimum_required(VERSION 3.2 FATAL_ERROR)
 
 set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake ${CMAKE_MODULE_PATH})
-add_library(store Store.cpp FileStore.cpp)
+add_library(store Utils.cpp Store.cpp FileStore.cpp TCPStore.cpp)
 target_compile_options(store PUBLIC "-std=c++11")
 
 enable_testing()
diff --git a/torch/lib/c10d/TCPStore.cpp b/torch/lib/c10d/TCPStore.cpp
new file mode 100644
index 000000000000..370cb832758e
--- /dev/null
+++ b/torch/lib/c10d/TCPStore.cpp
@@ -0,0 +1,385 @@
+#include "TCPStore.hpp"
+
+#include <poll.h>
+#include <unistd.h>
+#include <algorithm>
+#include <system_error>
+
+namespace c10d {
+
+namespace {
+
+enum class QueryType : uint8_t {
+  SET,
+  GET,
+  ADD,
+  CHECK,
+  WAIT
+};
+
+enum class CheckResponseType : uint8_t {
+  READY,
+  NOT_READY
+};
+
+enum class WaitResponseType : uint8_t {
+  STOP_WAITING
+};
+
+} // anonymous namespace
+
+// TCPStoreDaemon class methods
+// Create the control pipe, then start the daemon thread
+TCPStoreDaemon::TCPStoreDaemon(int storeListenSocket) :
+  storeListenSocket_(storeListenSocket)
+{
+  // The pipe must exist before the thread is launched: stop() runs on the
+  // owning thread and has to find a valid write end to close, otherwise the
+  // destructor could join() a daemon that is never told to exit.
+  if (::pipe(controlPipeFd_.data()) == -1) {
+    throw std::runtime_error("Failed to create the control pipe to start the "
+                             "TCPStoreDaemon run");
+  }
+  daemonThread_ = std::thread(&TCPStoreDaemon::run, this);
+}
+
+TCPStoreDaemon::~TCPStoreDaemon() {
+  // Stop the run
+  stop();
+  // Join the thread
+  join();
+  // Close unclosed sockets
+  for (auto socket : sockets_) {
+    if (socket != -1) {
+      ::close(socket);
+    }
+  }
+  // Now close the remaining ends of the control pipe
+  for (auto fd : controlPipeFd_) {
+    if (fd != -1) {
+      ::close(fd);
+    }
+  }
+}
+
+void TCPStoreDaemon::join() {
+  daemonThread_.join();
+}
+
+// Main loop of the daemon thread: polls the listening socket, the control
+// pipe and every client socket, and serves the queries of readable clients.
+void TCPStoreDaemon::run() {
+  std::vector<struct pollfd> fds;
+  fds.push_back({ .fd = storeListenSocket_, .events = POLLIN });
+  // Push the read end of the pipe to signal the stopping of the daemon run
+  fds.push_back({ .fd = controlPipeFd_[0], .events = POLLHUP });
+
+  // receive the queries
+  bool finished = false;
+  while (!finished) {
+    for (size_t i = 0; i < fds.size(); i++) {
+      fds[i].revents = 0;
+    }
+
+    SYSCHECK(::poll(fds.data(), fds.size(), -1));
+
+    // TCPStore's listening socket has an event and it should now be able to
+    // accept new connections.
+    if (fds[0].revents != 0) {
+      if (fds[0].revents != POLLIN) {
+        throw std::system_error(ECONNABORTED, std::system_category(),
+            "Unexpected poll revent on the master's listening socket: " +
+            std::to_string(fds[0].revents));
+      }
+      int sockFd = std::get<0>(tcputil::accept(storeListenSocket_));
+      sockets_.push_back(sockFd);
+      fds.push_back({ .fd = sockFd, .events = POLLIN });
+    }
+    // The pipe receives an event which tells us to shutdown the daemon
+    if (fds[1].revents != 0) {
+      // Will be POLLHUP when the pipe is closed
+      if (fds[1].revents != POLLHUP) {
+        throw std::system_error(ECONNABORTED, std::system_category(),
+            "Unexpected poll revent on the control pipe's reading fd: " +
+            std::to_string(fds[1].revents));
+      }
+      finished = true;
+      break;
+    }
+    // Skipping the fds[0] and fds[1],
+    // fds[0] is master's listening socket
+    // fds[1] is control pipe's reading fd
+    for (size_t fdIdx = 2; fdIdx < fds.size(); ++fdIdx) {
+      if (fds[fdIdx].revents == 0) {
+        continue;
+      }
+
+      if (fds[fdIdx].revents != POLLIN) {
+        throw std::system_error(ECONNABORTED, std::system_category(),
+            "Unexpected poll revent: " +
+            std::to_string(fds[fdIdx].revents) + " on socket: " +
+            std::to_string(fds[fdIdx].fd));
+      }
+      // Now query the socket that has the event
+      try {
+        query(fds[fdIdx].fd);
+      } catch (...) {
+        // There was an error when processing query. Probably an exception
+        // occurred in recv/send what would indicate that socket on the other
+        // side has been closed. If the closing was due to normal exit, then
+        // the store should continue executing. Otherwise, if it was different
+        // exception, other connections will get an exception once they try to
+        // use the store. We will go ahead and close this connection whenever
+        // we hit an exception here.
+        ::close(fds[fdIdx].fd);
+
+        // Remove all the tracking state of the closed FD
+        for (auto it = waitingSockets_.begin(); it != waitingSockets_.end(); ) {
+          for (auto vecIt = it->second.begin(); vecIt != it->second.end(); ) {
+            if (*vecIt == fds[fdIdx].fd) {
+              vecIt = it->second.erase(vecIt);
+            } else {
+              ++vecIt;
+            }
+          }
+          if (it->second.size() == 0) {
+            it = waitingSockets_.erase(it);
+          } else {
+            ++it;
+          }
+        }
+        for (auto it = keysAwaited_.begin(); it != keysAwaited_.end(); ) {
+          if (it->first == fds[fdIdx].fd) {
+            it = keysAwaited_.erase(it);
+          } else {
+            ++it;
+          }
+        }
+        fds.erase(fds.begin() + fdIdx);
+        sockets_.erase(sockets_.begin() + fdIdx - 2);
+        --fdIdx;
+        continue;
+      }
+    }
+  }
+}
+
+void TCPStoreDaemon::stop() {
+  if (controlPipeFd_[1] != -1) {
+    // close the write end of the pipe
+    ::close(controlPipeFd_[1]);
+    controlPipeFd_[1] = -1;
+  }
+}
+
+// query communicates with the worker. The format
+// of the query is as follows:
+// type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...
+// or, in the case of wait
+// type of query | number of args | size of arg1 | arg1 | ...
+void TCPStoreDaemon::query(int socket) {
+  QueryType qt;
+  tcputil::recvBytes<QueryType>(socket, &qt, 1);
+
+  if (qt == QueryType::SET) {
+    setHandler(socket);
+
+  } else if (qt == QueryType::ADD) {
+    addHandler(socket);
+
+  } else if (qt == QueryType::GET) {
+    getHandler(socket);
+
+  } else if (qt == QueryType::CHECK) {
+    checkHandler(socket);
+
+  } else if (qt == QueryType::WAIT) {
+    waitHandler(socket);
+
+  } else {
+    throw std::runtime_error("Unexpected query type");
+  }
+}
+
+void TCPStoreDaemon::wakeupWaitingClients(const std::string& key) {
+  auto socketsToWait = waitingSockets_.find(key);
+  if (socketsToWait != waitingSockets_.end()) {
+    for (int socket : socketsToWait->second) {
+      if (--keysAwaited_[socket] == 0) {
+        tcputil::sendValue<WaitResponseType>(socket,
+                                             WaitResponseType::STOP_WAITING);
+      }
+    }
+    waitingSockets_.erase(socketsToWait);
+  }
+}
+
+void TCPStoreDaemon::setHandler(int socket) {
+  std::string key = tcputil::recvString(socket);
+  tcpStore_[key] = tcputil::recvVector<uint8_t>(socket);
+  // On "set", wake up all clients that have been waiting
+  wakeupWaitingClients(key);
+}
+
+void TCPStoreDaemon::addHandler(int socket) {
+  std::string key = tcputil::recvString(socket);
+  int64_t addVal = tcputil::recvValue<int64_t>(socket);
+
+  if (tcpStore_.find(key) != tcpStore_.end()) {
+    auto buf = reinterpret_cast<const char*>(tcpStore_[key].data());
+    auto len = tcpStore_[key].size();
+    addVal += std::stoll(std::string(buf, len));
+  }
+  auto addValStr = std::to_string(addVal);
+  tcpStore_[key] = std::vector<uint8_t>(addValStr.begin(), addValStr.end());
+  // Now send the new value
+  tcputil::sendValue<int64_t>(socket, addVal);
+  // On "add", wake up all clients that have been waiting
+  wakeupWaitingClients(key);
+}
+
+void TCPStoreDaemon::getHandler(int socket) const {
+  std::string key = tcputil::recvString(socket);
+  const auto& data = tcpStore_.at(key);
+  tcputil::sendVector<uint8_t>(socket, data);
+}
+
+void TCPStoreDaemon::checkHandler(int socket) const {
+  SizeType nargs;
+  tcputil::recvBytes<SizeType>(socket, &nargs, 1);
+  std::vector<std::string> keys(nargs);
+  for (size_t i = 0; i < nargs; i++) {
+    keys[i] = tcputil::recvString(socket);
+  }
+  // Now we have received all the keys
+  if (checkKeys(keys)) {
+    tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::READY);
+  } else {
+    tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::NOT_READY);
+  }
+}
+
+void TCPStoreDaemon::waitHandler(int socket) {
+  SizeType nargs;
+  tcputil::recvBytes<SizeType>(socket, &nargs, 1);
+  std::vector<std::string> keys(nargs);
+  for (size_t i = 0; i < nargs; i++) {
+    keys[i] = tcputil::recvString(socket);
+  }
+  if (checkKeys(keys)) {
+    tcputil::sendValue<WaitResponseType>(socket,
+                                         WaitResponseType::STOP_WAITING);
+  } else {
+    // Only register for the keys that are still missing: a key that is
+    // already present is not guaranteed to be written again, so counting it
+    // would leave the client blocked even after the missing keys arrive.
+    size_t numKeysToAwait = 0;
+    for (auto& key : keys) {
+      if (tcpStore_.find(key) == tcpStore_.end()) {
+        waitingSockets_[key].push_back(socket);
+        ++numKeysToAwait;
+      }
+    }
+    keysAwaited_[socket] = numKeysToAwait;
+  }
+}
+
+bool TCPStoreDaemon::checkKeys(const std::vector<std::string>& keys) const {
+  return std::all_of(keys.begin(), keys.end(),
+      [this](const std::string& s) {
+        return tcpStore_.count(s) > 0;
+      });
+}
+
+// TCPStore class methods
+TCPStore::TCPStore(const std::string& masterAddr,
+                   PortType masterPort,
+                   bool isServer)
+ : isServer_(isServer)
+ , tcpStoreAddr_(masterAddr)
+ , tcpStorePort_(masterPort)
+{
+  if (isServer_) {
+    // Opening up the listening socket
+    std::tie(masterListenSocket_, std::ignore) = tcputil::listen(masterPort);
+    // Now start the daemon
+    tcpStoreDaemon_ = std::unique_ptr<TCPStoreDaemon>(
+        new TCPStoreDaemon(masterListenSocket_)
+    );
+  }
+  // Connect to the daemon
+  storeSocket_ = tcputil::connect(tcpStoreAddr_, tcpStorePort_);
+}
+
+TCPStore::~TCPStore() {
+  ::close(storeSocket_);
+  if (isServer_) {
+    // Store daemon should end because of closed connection.
+    // daemon destructor should join the thread
+    tcpStoreDaemon_.reset(nullptr);
+    ::close(masterListenSocket_);
+  }
+}
+
+void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) {
+  tcputil::sendValue<QueryType>(storeSocket_, QueryType::SET);
+  tcputil::sendString(storeSocket_, key, true);
+  tcputil::sendVector<uint8_t>(storeSocket_, data);
+}
+
+std::vector<uint8_t> TCPStore::get(const std::string& key) {
+  wait({key});
+  tcputil::sendValue<QueryType>(storeSocket_, QueryType::GET);
+  tcputil::sendString(storeSocket_, key);
+  return tcputil::recvVector<uint8_t>(storeSocket_);
+}
+
+int64_t TCPStore::add(const std::string& key, int64_t value) {
+  tcputil::sendValue<QueryType>(storeSocket_, QueryType::ADD);
+  tcputil::sendString(storeSocket_, key, true);
+  tcputil::sendValue<int64_t>(storeSocket_, value);
+  return tcputil::recvValue<int64_t>(storeSocket_);
+}
+
+bool TCPStore::check(const std::vector<std::string>& keys) {
+  tcputil::sendValue<QueryType>(storeSocket_, QueryType::CHECK);
+  SizeType nkeys = keys.size();
+  tcputil::sendBytes<SizeType>(storeSocket_, &nkeys, 1, (nkeys > 0));
+  for (size_t i = 0; i < nkeys; i++) {
+    tcputil::sendString(storeSocket_, keys[i], (i != (nkeys - 1)));
+  }
+  auto checkResponse = tcputil::recvValue<CheckResponseType>(storeSocket_);
+  if (checkResponse == CheckResponseType::READY) {
+    return true;
+  } else if (checkResponse == CheckResponseType::NOT_READY) {
+    return false;
+  } else {
+    throw std::runtime_error("ready or not_ready response expected");
+  }
+}
+
+void TCPStore::wait(const std::vector<std::string>& keys,
+                    const std::chrono::milliseconds& timeout) {
+  // Set the socket timeout if there is a wait timeout
+  if (timeout != kNoTimeout) {
+    struct timeval timeoutTV = {.tv_sec = timeout.count() / 1000,
+                                .tv_usec = (timeout.count() % 1000) * 1000};
+    SYSCHECK(::setsockopt(storeSocket_,
+                          SOL_SOCKET,
+                          SO_RCVTIMEO,
+                          reinterpret_cast<char*>(&timeoutTV),
+                          sizeof(timeoutTV)));
+  }
+  tcputil::sendValue<QueryType>(storeSocket_, QueryType::WAIT);
+  SizeType nkeys = keys.size();
+  tcputil::sendBytes<SizeType>(storeSocket_, &nkeys, 1, (nkeys > 0));
+  for (size_t i = 0; i < nkeys; i++) {
+    tcputil::sendString(storeSocket_, keys[i], (i != (nkeys - 1)));
+  }
+  auto waitResponse = tcputil::recvValue<WaitResponseType>(storeSocket_);
+  if (waitResponse != WaitResponseType::STOP_WAITING) {
+    throw std::runtime_error("Stop_waiting response is expected");
+  }
+}
+
+} // namespace c10d
diff --git a/torch/lib/c10d/TCPStore.hpp b/torch/lib/c10d/TCPStore.hpp
new file mode 100644
index 000000000000..00e0fe88c782
--- /dev/null
+++ b/torch/lib/c10d/TCPStore.hpp
@@ -0,0 +1,86 @@
+#pragma once
+
+#include "Store.hpp"
+#include "Utils.hpp"
+
+#include <memory>
+#include <thread>
+#include <unordered_map>
+
+namespace c10d {
+
+class TCPStoreDaemon {
+
+ public:
+
+  explicit TCPStoreDaemon(int storeListenSocket);
+  ~TCPStoreDaemon();
+
+  void join();
+
+ protected:
+
+  void run();
+  void stop();
+
+  void query(int socket);
+
+  void setHandler(int socket);
+  void addHandler(int socket);
+  void getHandler(int socket) const;
+  void checkHandler(int socket) const;
+  void waitHandler(int socket);
+
+  bool checkKeys(const std::vector<std::string>& keys) const;
+  void wakeupWaitingClients(const std::string &key);
+
+  std::thread daemonThread_;
+  std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_;
+  // From key -> the list of sockets waiting on it
+  std::unordered_map<std::string, std::vector<int>> waitingSockets_;
+  // From socket -> number of keys awaited
+  std::unordered_map<int, size_t> keysAwaited_;
+
+  std::vector<int> sockets_;
+  int storeListenSocket_;
+  std::vector<int> controlPipeFd_ {-1, -1};
+};
+
+class TCPStore : public Store {
+
+ public:
+
+  explicit TCPStore(const std::string& masterAddr,
+                    PortType masterPort,
+                    bool isServer = false);
+
+  virtual ~TCPStore();
+
+  void set(
+      const std::string& key,
+      const std::vector<uint8_t>& value) override;
+
+  std::vector<uint8_t> get(const std::string& key) override;
+
+  int64_t add(const std::string& key, int64_t value) override;
+
+  bool check(const std::vector<std::string>& keys) override;
+
+  void wait(
+      const std::vector<std::string>& keys,
+      const std::chrono::milliseconds& timeout = kDefaultTimeout) override;
+
+ protected:
+
+  bool isServer_;
+  int storeSocket_ = -1;
+  int masterListenSocket_ = -1;
+
+  std::string tcpStoreAddr_;
+  PortType tcpStorePort_;
+
+  // Only needs to be launched as the server
+  std::unique_ptr<TCPStoreDaemon> tcpStoreDaemon_ = nullptr;
+};
+
+} // namespace c10d
diff --git a/torch/lib/c10d/Utils.cpp b/torch/lib/c10d/Utils.cpp
new file mode 100644
index 000000000000..33c1d8b8b783
--- /dev/null
+++ b/torch/lib/c10d/Utils.cpp
@@ -0,0 +1,276 @@
+#include "Utils.hpp"
+
+#include <netdb.h>
+#include <sys/poll.h>
+
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+
+#include <fcntl.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <thread>
+
+namespace c10d {
+namespace tcputil {
+
+namespace {
+
+constexpr int LISTEN_QUEUE_SIZE = 64;
+
+void setSocketNoDelay(int socket) {
+  int flag = 1;
+  socklen_t optlen = sizeof(flag);
+  SYSCHECK(setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, optlen));
+}
+
+PortType getSocketPort(int fd) {
+  PortType listenPort;
+  struct ::sockaddr_storage addrStorage;
+  socklen_t addrLen = sizeof(addrStorage);
+  SYSCHECK(getsockname(fd,
+      reinterpret_cast<struct ::sockaddr*>(&addrStorage), &addrLen));
+
+  if (addrStorage.ss_family == AF_INET) {
+    struct ::sockaddr_in *addr =
+      reinterpret_cast<struct ::sockaddr_in*>(&addrStorage);
+    listenPort = ntohs(addr->sin_port);
+
+  } else if (addrStorage.ss_family == AF_INET6) { // AF_INET6
+    struct ::sockaddr_in6 *addr =
+      reinterpret_cast<struct ::sockaddr_in6*>(&addrStorage);
+    listenPort = ntohs(addr->sin6_port);
+
+  } else {
+    throw std::runtime_error("unsupported protocol");
+  }
+  return listenPort;
+}
+
+} // namespace
+
+std::string sockaddrToString(struct ::sockaddr *addr) {
+  char address[INET6_ADDRSTRLEN + 1];
+  if (addr->sa_family == AF_INET) {
+    struct ::sockaddr_in *s = reinterpret_cast<struct ::sockaddr_in*>(addr);
+    SYSCHECK(::inet_ntop(AF_INET, &(s->sin_addr), address, INET_ADDRSTRLEN))
+    address[INET_ADDRSTRLEN] = '\0';
+  } else if (addr->sa_family == AF_INET6) {
+    struct ::sockaddr_in6 *s = reinterpret_cast<struct ::sockaddr_in6*>(addr);
+    SYSCHECK(::inet_ntop(AF_INET6, &(s->sin6_addr), address, INET6_ADDRSTRLEN))
+    address[INET6_ADDRSTRLEN] = '\0';
+  } else {
+    throw std::runtime_error("unsupported protocol");
+  }
+  return address;
+}
+
+// listen, connect and accept
+std::pair<int, PortType> listen(PortType port) {
+  struct ::addrinfo hints, *res = NULL;
+  std::memset(&hints, 0x00, sizeof(hints));
+  hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
+  hints.ai_family = AF_UNSPEC; // either IPv4 or IPv6
+  hints.ai_socktype = SOCK_STREAM; // TCP
+
+  // `getaddrinfo` will sort addresses according to RFC 3484 and can be tweaked
+  // by editing `/etc/gai.conf`. so there is no need for manual sorting
+  // or protocol preference.
+  int err = ::getaddrinfo(nullptr, std::to_string(port).data(), &hints, &res);
+  if (err != 0 || !res) {
+    throw std::invalid_argument("cannot find host to listen on: " +
+                                std::string(gai_strerror(err)));
+  }
+
+  std::shared_ptr<struct ::addrinfo> addresses(res, [](struct ::addrinfo* p) {
+    ::freeaddrinfo(p);
+  });
+
+  struct ::addrinfo *nextAddr = addresses.get();
+  int socket;
+  while (true) {
+    try {
+      SYSCHECK(socket = ::socket(nextAddr->ai_family,
+                                 nextAddr->ai_socktype,
+                                 nextAddr->ai_protocol))
+
+      int optval = 1;
+      SYSCHECK(::setsockopt(socket,
+                            SOL_SOCKET,
+                            SO_REUSEADDR,
+                            &optval,
+                            sizeof(int)))
+
+      SYSCHECK(::bind(socket, nextAddr->ai_addr, nextAddr->ai_addrlen))
+      SYSCHECK(::listen(socket, LISTEN_QUEUE_SIZE))
+      break;
+
+    } catch (const std::system_error& e) {
+      ::close(socket);
+      nextAddr = nextAddr->ai_next;
+
+      // we have tried all addresses but could not start
+      // listening on any of them
+      if (!nextAddr) {
+        throw;
+      }
+    }
+  }
+
+  // get listen port and address
+  return {socket, getSocketPort(socket)};
+}
+
+int connect(const std::string& address,
+            PortType port,
+            bool wait,
+            const std::chrono::milliseconds& timeout) {
+  struct ::addrinfo hints, *res = NULL;
+  std::memset(&hints, 0x00, sizeof(hints));
+  hints.ai_flags = AI_NUMERICSERV; // specifies that port (service) is numeric
+  hints.ai_family = AF_UNSPEC; // either IPv4 or IPv6
+  hints.ai_socktype = SOCK_STREAM; // TCP
+
+  // `getaddrinfo` will sort addresses according to RFC 3484 and can be tweaked
+  // by editing `/etc/gai.conf`. so there is no need for manual sorting
+  // or protocol preference.
+  int err = ::getaddrinfo(address.data(),
+                          std::to_string(port).data(),
+                          &hints,
+                          &res);
+  if (err != 0 || !res) {
+    throw std::invalid_argument("host not found: " +
+                                std::string(gai_strerror(err)));
+  }
+
+  std::shared_ptr<struct ::addrinfo> addresses(res, [](struct ::addrinfo* p) {
+    ::freeaddrinfo(p);
+  });
+
+  struct ::addrinfo *nextAddr = addresses.get();
+  int socket;
+  // we'll loop over the addresses only if at least one of them gave us
+  // ECONNREFUSED. Maybe the host was up, but the server wasn't running.
+  bool anyRefused = false;
+  while (true) {
+    try {
+      SYSCHECK(socket = ::socket(nextAddr->ai_family,
+                                 nextAddr->ai_socktype,
+                                 nextAddr->ai_protocol))
+
+      ResourceGuard socketGuard([socket]() { ::close(socket); });
+
+      // We need to connect in non-blocking mode, so we can use a timeout
+      SYSCHECK(::fcntl(socket, F_SETFL, O_NONBLOCK));
+
+      int ret = ::connect(socket, nextAddr->ai_addr, nextAddr->ai_addrlen);
+
+      if (ret != 0 && errno != EINPROGRESS) {
+        throw std::system_error(errno, std::system_category());
+      }
+
+      struct ::pollfd pfd;
+      pfd.fd = socket;
+      pfd.events = POLLOUT;
+
+      int numReady = ::poll(&pfd, 1, timeout.count());
+      if (numReady < 0) {
+        throw std::system_error(errno, std::system_category());
+      } else if (numReady == 0) {
+        errno = 0;
+        throw std::runtime_error("connect() timed out");
+      }
+
+      socklen_t errLen = sizeof(errno);
+      errno = 0;
+      ::getsockopt(socket, SOL_SOCKET, SO_ERROR, &errno, &errLen);
+
+      // `errno` is set when:
+      //  1. `getsockopt` has failed
+      //  2. there is a pending error on the socket
+      //     (the error is saved to the `errno` variable)
+      if (errno != 0) {
+        throw std::system_error(errno, std::system_category());
+      }
+
+      // Disable non-blocking mode
+      int flags;
+      SYSCHECK(flags = ::fcntl(socket, F_GETFL));
+      SYSCHECK(::fcntl(socket, F_SETFL, flags & (~O_NONBLOCK)));
+      socketGuard.release();
+      break;
+
+    } catch (std::exception& e) {
+      if (errno == ECONNREFUSED) {
+        anyRefused = true;
+      }
+
+      // We need to move to the next address because this was not available
+      // to connect or to create a socket.
+      nextAddr = nextAddr->ai_next;
+
+      // We have tried all addresses but could not connect to any of them.
+      if (!nextAddr) {
+        if (!wait || !anyRefused) {
+          throw;
+        }
+        std::this_thread::sleep_for(std::chrono::seconds(1));
+        anyRefused = false;
+        nextAddr = addresses.get();
+      }
+    }
+  }
+
+  setSocketNoDelay(socket);
+
+  return socket;
+}
+
+std::tuple<int, std::string> accept(
+    int listenSocket,
+    const std::chrono::milliseconds& timeout) {
+  // poll on the listen socket, which makes it possible to apply a timeout
+  std::unique_ptr<struct ::pollfd[]> events(new struct ::pollfd[1]);
+  events[0] = {.fd = listenSocket, .events = POLLIN};
+
+  while (true) {
+    int res = ::poll(events.get(), 1, timeout.count());
+    if (res == 0) {
+      throw std::runtime_error("waiting for processes to "
+                               "connect has timed out");
+    } else if (res == -1) {
+      if (errno == EINTR) {
+        continue;
+      }
+      throw std::system_error(errno, std::system_category());
+    } else {
+      if (!(events[0].revents & POLLIN))
+        throw std::system_error(ECONNABORTED, std::system_category());
+      break;
+    }
+  }
+
+  int socket;
+  SYSCHECK(socket = ::accept(listenSocket, NULL, NULL))
+
+  // Get address of the connecting process
+  struct ::sockaddr_storage addr;
+  socklen_t addrLen = sizeof(addr);
+  SYSCHECK(::getpeername(socket,
+                         reinterpret_cast<struct ::sockaddr*>(&addr),
+                         &addrLen))
+
+  setSocketNoDelay(socket);
+
+  return std::make_tuple(
+      socket,
+      sockaddrToString(reinterpret_cast<struct ::sockaddr*>(&addr)));
+}
+
+} // namespace tcputil
+} // namespace c10d
diff --git a/torch/lib/c10d/Utils.hpp b/torch/lib/c10d/Utils.hpp
new file mode 100644
index 000000000000..53ea35fa2431
--- /dev/null
+++ b/torch/lib/c10d/Utils.hpp
@@ -0,0 +1,194 @@
+#pragma once
+
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include <chrono>
+#include <cstdint>
+#include <cstdlib>
+#include <functional>
+#include <limits>
+#include <string>
+#include <system_error>
+#include <tuple>
+#include <vector>
+
+namespace c10d {
+
+using RankType = uint32_t;
+using PortType = uint16_t;
+using SizeType = uint64_t;
+
+#define SYSCHECK(expr) { \
+  errno = 0; auto ___output = (expr); (void)___output; \
+  if (errno != 0) throw std::system_error(errno, std::system_category()); \
+}
+
+inline PortType convertToPort(int64_t port) {
+  // Valid ports are 0..65535, both ends included
+  if ((port < 0) || (port > std::numeric_limits<PortType>::max()))
+    throw std::domain_error("invalid port (value out of range)");
+
+  return static_cast<PortType>(port);
+}
+
+inline RankType convertToRank(int64_t rank, int64_t min = 0) {
+  if ((rank < min) || (rank >= std::numeric_limits<RankType>::max()))
+    throw std::domain_error("invalid rank (value out of range)");
+
+  return static_cast<RankType>(rank);
+}
+
+// Helper resource guard class
+class ResourceGuard {
+ public:
+  ResourceGuard(std::function<void()> destructor)
+    : destructor_(std::move(destructor))
+    , released_(false) {}
+
+  // Non-copyable: a copy would run the cleanup a second time
+  ResourceGuard(const ResourceGuard&) = delete;
+  ResourceGuard& operator=(const ResourceGuard&) = delete;
+
+  ~ResourceGuard() {
+    if (!released_) {
+      destructor_();
+    }
+  }
+
+  void release() {
+    released_ = true;
+  }
+ private:
+  std::function<void()> destructor_;
+  bool released_;
+};
+
+namespace tcputil {
+
+constexpr std::chrono::milliseconds kNoTimeout =
+  std::chrono::milliseconds(-1);
+
+// Send and receive
+template <typename T>
+void sendBytes(int socket,
+               const T* buffer,
+               size_t length,
+               bool moreData = false) {
+  size_t bytesToSend = sizeof(T) * length;
+  if (bytesToSend == 0) {
+    return;
+  }
+
+  auto bytes = reinterpret_cast<const uint8_t*>(buffer);
+  uint8_t* currentBytes = const_cast<uint8_t*>(bytes);
+
+  int flags = 0;
+
+#ifdef MSG_MORE
+  if (moreData) { // there is more data to send
+    flags |= MSG_MORE;
+  }
+#endif
+
+  while (bytesToSend > 0) {
+    ssize_t bytesSent;
+    SYSCHECK(bytesSent = ::send(socket, currentBytes, bytesToSend, flags))
+    if (bytesSent == 0) {
+      throw std::system_error(ECONNRESET, std::system_category());
+    }
+
+    bytesToSend -= bytesSent;
+    currentBytes += bytesSent;
+  }
+}
+
+template <typename T>
+void recvBytes(int socket, T* buffer, size_t length) {
+  size_t bytesToReceive = sizeof(T) * length;
+  if (bytesToReceive == 0) {
+    return;
+  }
+
+  auto bytes = reinterpret_cast<uint8_t*>(buffer);
+  uint8_t *currentBytes = bytes;
+
+  while (bytesToReceive > 0) {
+    ssize_t bytesReceived;
+    SYSCHECK(bytesReceived = ::recv(socket, currentBytes, bytesToReceive, 0))
+    if (bytesReceived == 0) {
+      throw std::system_error(ECONNRESET, std::system_category());
+    }
+
+    bytesToReceive -= bytesReceived;
+    currentBytes += bytesReceived;
+  }
+}
+
+// send a vector's length and data
+template <typename T>
+void sendVector(int socket,
+                const std::vector<T>& vec,
+                bool moreData = false) {
+  SizeType size = vec.size();
+  sendBytes<SizeType>(socket, &size, 1, true);
+  sendBytes<T>(socket, vec.data(), size, moreData);
+}
+
+// receive a vector as sent in sendVector
+template <typename T>
+std::vector<T> recvVector(int socket) {
+  SizeType valueSize;
+  recvBytes<SizeType>(socket, &valueSize, 1);
+  std::vector<T> value(valueSize);
+  recvBytes<T>(socket, value.data(), value.size());
+  return value;
+}
+
+// this is only for convenience when sending rvalues
+template <typename T>
+void sendValue(int socket, const T& value, bool moreData = false) {
+  sendBytes<T>(socket, &value, 1, moreData);
+}
+
+template <typename T>
+T recvValue(int socket) {
+  T value;
+  recvBytes<T>(socket, &value, 1);
+  return value;
+}
+
+// send a string's length and data
+inline void sendString(int socket,
+                       const std::string& str,
+                       bool moreData = false) {
+  SizeType size = str.size();
+  sendBytes<SizeType>(socket, &size, 1, true);
+  sendBytes<char>(socket, str.data(), size, moreData);
+}
+
+// receive a string as sent in sendString
+inline std::string recvString(int socket) {
+  SizeType valueSize;
+  recvBytes<SizeType>(socket, &valueSize, 1);
+  std::vector<char> value(valueSize);
+  recvBytes<char>(socket, value.data(), value.size());
+  return std::string(value.data(), value.size());
+}
+
+// Other helpers
+std::string sockaddrToString(struct sockaddr *addr);
+
+std::pair<int, PortType> listen(PortType port);
+
+int connect(const std::string& address,
+            PortType port,
+            bool wait = true,
+            const std::chrono::milliseconds& timeout = kNoTimeout);
+
+std::tuple<int, std::string>
+accept(int listenSocket,
+       const std::chrono::milliseconds& timeout = kNoTimeout);
+
+} // namespace tcputil
+} // namespace c10d
diff --git a/torch/lib/c10d/test/CMakeLists.txt b/torch/lib/c10d/test/CMakeLists.txt
index 65c56f51fbaa..b19a00db888b 100644
--- a/torch/lib/c10d/test/CMakeLists.txt
+++ b/torch/lib/c10d/test/CMakeLists.txt
@@ -1,5 +1,6 @@
 set(test_srcs
   FileStoreTest.cpp
+  TCPStoreTest.cpp
   )
 
 set(test_libraries
diff --git a/torch/lib/c10d/test/FileStoreTest.cpp b/torch/lib/c10d/test/FileStoreTest.cpp
index 64f98009924a..c75b9c0e86e9 100644
--- a/torch/lib/c10d/test/FileStoreTest.cpp
+++ b/torch/lib/c10d/test/FileStoreTest.cpp
@@ -1,50 +1,10 @@
-#include <unistd.h>
+#include "StoreTestCommon.hpp"
+#include "FileStore.hpp"
 
-#include <condition_variable>
+#include <unistd.h>
 #include <iostream>
-#include <mutex>
-#include <sstream>
 #include <thread>
 
-#include "FileStore.hpp"
-
-using namespace c10d;
-
-class Semaphore {
- public:
-  void post(int n = 1) {
-    std::unique_lock<std::mutex> lock(m_);
-    n_ += n;
-    cv_.notify_all();
-  }
-
-  void wait(int n = 1) {
-    std::unique_lock<std::mutex> lock(m_);
-    while (n_ < n) {
-      cv_.wait(lock);
-    }
-    n_ -= n;
-  }
-
- protected:
-  int n_ = 0;
-  std::mutex m_;
-  std::condition_variable cv_;
-};
-
-void set(Store& store, const std::string& key, const std::string& value) {
-  std::vector<uint8_t> data(value.begin(), value.end());
-  store.set(key, data);
-}
-
-void check(Store& store, const std::string& key, const std::string& expected) {
-  auto tmp = store.get(key);
-  auto actual = std::string((const char*) tmp.data(), tmp.size());
-  if (actual != expected) {
-    throw std::runtime_error("Expected " + expected + ", got " + actual);
-  }
-}
-
 std::string tmppath() {
   const char* tmpdir = getenv("TMPDIR");
   if (tmpdir == nullptr) {
@@ -71,29 +31,29 @@ int main(int argc, char** argv) {
 
   // Basic set/get
   {
-    FileStore store(path);
-    set(store, "key0", "value0");
-    set(store, "key1", "value1");
-    set(store, "key2", "value2");
-    check(store, "key0", "value0");
-    check(store, "key1", "value1");
-    check(store, "key2", "value2");
+    c10d::FileStore store(path);
+    c10d::test::set(store, "key0", "value0");
+    c10d::test::set(store, "key1", "value1");
+    c10d::test::set(store, "key2", "value2");
+    c10d::test::check(store, "key0", "value0");
+    c10d::test::check(store, "key1", "value1");
+    c10d::test::check(store, "key2", "value2");
   }
 
   // Perform get on new instance
   {
-    FileStore store(path);
-    check(store, "key0", "value0");
+    c10d::FileStore store(path);
+    c10d::test::check(store, "key0", "value0");
   }
 
   // Hammer on FileStore#add
   std::vector<std::thread> threads;
   const auto numThreads = 4;
   const auto numIterations = 100;
-  Semaphore sem1, sem2;
+  c10d::test::Semaphore sem1, sem2;
   for (auto i = 0; i < numThreads; i++) {
     threads.push_back(std::move(std::thread([&] {
-      FileStore store(path);
+      c10d::FileStore store(path);
       sem1.post();
       sem2.wait();
       for (auto j = 0; j < numIterations; j++) {
@@ -109,12 +69,12 @@ int main(int argc, char** argv) {
 
   // Check that the counter has the expected value
   {
-    FileStore store(path);
-    std::stringstream ss;
-    ss << (numThreads * numIterations);
-    check(store, "counter", ss.str());
+    c10d::FileStore store(path);
+    std::string expected = std::to_string(numThreads * numIterations);
+    c10d::test::check(store, "counter", expected);
   }
 
   unlink(path.c_str());
+  std::cout << "Test succeeded" << std::endl;
   return 0;
 }
diff --git a/torch/lib/c10d/test/StoreTestCommon.hpp b/torch/lib/c10d/test/StoreTestCommon.hpp
new file mode 100644
index 000000000000..cf2d7a90a5a0
--- /dev/null
+++ b/torch/lib/c10d/test/StoreTestCommon.hpp
@@ -0,0 +1,55 @@
+#pragma once
+
+#include "Store.hpp"
+
+#include <condition_variable>
+#include <mutex>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+namespace c10d {
+namespace test {
+
+class Semaphore {
+ public:
+  void post(int n = 1) {
+    std::unique_lock<std::mutex> lock(m_);
+    n_ += n;
+    cv_.notify_all();
+  }
+
+  void wait(int n = 1) {
+    std::unique_lock<std::mutex> lock(m_);
+    while (n_ < n) {
+      cv_.wait(lock);
+    }
+    n_ -= n;
+  }
+
+ protected:
+  int n_ = 0;
+  std::mutex m_;
+  std::condition_variable cv_;
+};
+
+
+inline void set(Store& store,
+                const std::string& key,
+                const std::string& value) {
+  std::vector<uint8_t> data(value.begin(), value.end());
+  store.set(key, data);
+}
+
+inline void check(Store& store,
+                  const std::string& key,
+                  const std::string& expected) {
+  auto tmp = store.get(key);
+  auto actual = std::string((const char*) tmp.data(), tmp.size());
+  if (actual != expected) {
+    throw std::runtime_error("Expected " + expected + ", got " + actual);
+  }
+}
+
+} // namespace test
+} // namespace c10d
diff --git a/torch/lib/c10d/test/TCPStoreTest.cpp b/torch/lib/c10d/test/TCPStoreTest.cpp
new file mode 100644
index 000000000000..26d815fcac05
--- /dev/null
+++ b/torch/lib/c10d/test/TCPStoreTest.cpp
@@ -0,0 +1,94 @@
+#include "StoreTestCommon.hpp"
+#include "TCPStore.hpp"
+
+#include <cstdlib>
+#include <iostream>
+#include <thread>
+
+
+int main(int argc, char** argv) {
+
+  // server store
+  c10d::TCPStore serverStore("127.0.0.1", 29500, true);
+
+  // Basic set/get on the server store
+  c10d::test::set(serverStore, "key0", "value0");
+  c10d::test::set(serverStore, "key1", "value1");
+  c10d::test::set(serverStore, "key2", "value2");
+  c10d::test::check(serverStore, "key0", "value0");
+  c10d::test::check(serverStore, "key1", "value1");
+  c10d::test::check(serverStore, "key2", "value2");
+
+  // Hammer on TCPStore
+  std::vector<std::thread> threads;
+  const auto numThreads = 16;
+  const auto numIterations = 1000;
+  c10d::test::Semaphore sem1, sem2;
+
+  // Each thread will have a client store to send/recv data
+  std::vector<std::unique_ptr<c10d::TCPStore>> clientStores;
+  for (auto i = 0; i < numThreads; i++) {
+    clientStores.push_back(std::unique_ptr<c10d::TCPStore>(new
+          c10d::TCPStore("127.0.0.1", 29500, false)));
+  }
+
+  std::string expectedCounterRes = std::to_string(numThreads * numIterations);
+
+  for (auto i = 0; i < numThreads; i++) {
+    threads.push_back(std::move(std::thread([&sem1,
+                                             &sem2,
+                                             &clientStores,
+                                             i,
+                                             &expectedCounterRes] {
+      for (auto j = 0; j < numIterations; j++) {
+        clientStores[i]->add("counter", 1);
+      }
+      // Let each thread set and get key on its client store
+      std::string key = "thread_" + std::to_string(i);
+      for (auto j = 0; j < numIterations; j++) {
+        std::string val = "thread_val_" + std::to_string(j);
+        c10d::test::set(*clientStores[i], key, val);
+        c10d::test::check(*clientStores[i], key, val);
+      }
+
+      sem1.post();
+      sem2.wait();
+      // Check the counter results
+      c10d::test::check(*clientStores[i], "counter", expectedCounterRes);
+      // Now check other threads' written data
+      for (auto j = 0; j < numThreads; j++) {
+        if (j == i) {
+          continue;
+        }
+        std::string key = "thread_" + std::to_string(j);
+        std::string val = "thread_val_" +
+          std::to_string(numIterations - 1);
+        c10d::test::check(*clientStores[i], key, val);
+      }
+
+    })));
+  }
+
+  sem1.wait(numThreads);
+  sem2.post(numThreads);
+
+  for (auto& thread : threads) {
+    thread.join();
+  }
+
+  // Clear the store to test that client disconnect won't shutdown the store
+  clientStores.clear();
+
+  // Check that the counter has the expected value
+  c10d::test::check(serverStore, "counter", expectedCounterRes);
+
+  // Check each thread's written data from the main thread
+  for (auto i = 0; i < numThreads; i++) {
+    std::string key = "thread_" + std::to_string(i);
+    std::string val = "thread_val_" + std::to_string(numIterations - 1);
+    c10d::test::check(serverStore, key, val);
+  }
+
+  std::cout << "Test succeeded" << std::endl;
+  return EXIT_SUCCESS;
+}