Skip to content

Commit

Permalink
Used notify/wake for wait and addressed all comments
Browse files Browse the repository at this point in the history
Reference: pytorch#7434
  • Loading branch information
teng-li committed May 16, 2018
1 parent 50e1362 commit b9294ee
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 145 deletions.
2 changes: 1 addition & 1 deletion torch/lib/c10d/CMakeLists.txt
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.2 FATAL_ERROR)
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake ${CMAKE_MODULE_PATH})

add_library(store Utils.cpp Store.cpp FileStore.cpp TcpStore.cpp)
add_library(store Utils.cpp Store.cpp FileStore.cpp TCPStore.cpp)
target_compile_options(store PUBLIC "-std=c++11")

enable_testing()
Expand Down
168 changes: 98 additions & 70 deletions torch/lib/c10d/TcpStore.cpp → torch/lib/c10d/TCPStore.cpp
@@ -1,25 +1,31 @@
#include "TcpStore.hpp"
#include "TCPStore.hpp"

#include <poll.h>
#include <unistd.h>
#include <system_error>
#include <algorithm>

namespace c10d {

namespace {

enum class QueryType : std::uint8_t {
enum class QueryType : uint8_t {
SET,
GET,
ADD,
CHECK
CHECK,
WAIT
};

enum class CheckResponseType : std::uint8_t {
enum class CheckResponseType : uint8_t {
READY,
NOT_READY
};

enum class WaitResponseType : uint8_t {
STOP_WAITING
};

} // anonymous namespace

// TCPStoreDaemon class methods
Expand All @@ -31,13 +37,10 @@ TCPStoreDaemon::TCPStoreDaemon(int storeListenSocket) :
}

TCPStoreDaemon::~TCPStoreDaemon() {

// Stop the run
stop();

// Join the thread
join();

// Close unclosed sockets
for (auto socket : sockets_) {
if (socket != -1) {
Expand All @@ -57,9 +60,7 @@ void TCPStoreDaemon::join() {
}

void TCPStoreDaemon::run() {

// Create the control pipe
controlPipeFd_ = std::vector<int>{-1, -1};
if (pipe(controlPipeFd_.data()) == -1) {
throw std::runtime_error("Failed to create the control pipe to start the "
"TCPStoreDaemon run");
Expand All @@ -79,10 +80,8 @@ void TCPStoreDaemon::run() {

SYSCHECK(::poll(fds.data(), fds.size(), -1));

/**
* TCPStore's listening socket has an event and it should now be able to
* accept new connections.
*/
// TCPStore's listening socket has an event and it should now be able to
// accept new connections.
if (fds[0].revents != 0) {
if (fds[0].revents ^ POLLIN) {
throw std::system_error(ECONNABORTED, std::system_category(),
Expand All @@ -93,9 +92,7 @@ void TCPStoreDaemon::run() {
sockets_.push_back(sockFd);
fds.push_back({ .fd = sockFd, .events = POLLIN });
}
/**
* The pipe receives an event which tells us to shutdown the daemon
*/
// The pipe receives an event which tells us to shutdown the daemon
if (fds[1].revents != 0) {
// Will be POLLUP when the pipe is closed
if (fds[1].revents ^ POLLHUP) {
Expand All @@ -106,11 +103,9 @@ void TCPStoreDaemon::run() {
finished = true;
break;
}
/**
* Skipping the fds[0] and fds[1],
* fds[0] is master's listening socket
* fds[1] is control pipe's reading fd
*/
// Skipping the fds[0] and fds[1],
// fds[0] is master's listening socket
// fds[1] is control pipe's reading fd
for (size_t fdIdx = 2; fdIdx < fds.size(); ++fdIdx) {
if (fds[fdIdx].revents == 0) {
continue;
Expand All @@ -126,15 +121,13 @@ void TCPStoreDaemon::run() {
try {
query(fds[fdIdx].fd);
} catch (...) {
/**
* There was an error when processing query. Probably an exception
* occurred in recv/send what would indicate that socket on the other
* side has been closed. If the closing was due to normal exit, then the
* store should continue executing. Otherwise, if it was different
* exception, other connections will get an exception once they try to
* use the store. We will go ahead and close this connection whenever
* we hit an exception here.
*/
// There was an error when processing query. Probably an exception
// occurred in recv/send what would indicate that socket on the other
// side has been closed. If the closing was due to normal exit, then
// the store should continue executing. Otherwise, if it was different
// exception, other connections will get an exception once they try to
// use the store. We will go ahead and close this connection whenever
// we hit an exception here.
::close(fds[fdIdx].fd);
fds.erase(fds.begin() + fdIdx);
sockets_.erase(sockets_.begin() + fdIdx - 2);
Expand All @@ -146,22 +139,20 @@ void TCPStoreDaemon::run() {
}

void TCPStoreDaemon::stop() {
if (controlPipeFd_.size() == 2 && controlPipeFd_[1] != -1) {
if (controlPipeFd_[1] != -1) {
// close the write end of the pipe
::close(controlPipeFd_[1]);
controlPipeFd_[1] = -1;
}
}

/**
* query communicates with the worker. The format
* of the query is as follows:
* type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...
* or, in the case of wait
* type of query | number of args | size of arg1 | arg1 | ...
*/
void TCPStoreDaemon::query(int socket) {

// query communicates with the worker. The format
// of the query is as follows:
// type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...
// or, in the case of wait
// type of query | number of args | size of arg1 | arg1 | ...
void TCPStoreDaemon::query(int socket) {
QueryType qt;
tcputil::recvBytes<QueryType>(socket, &qt, 1);

Expand All @@ -177,14 +168,32 @@ void TCPStoreDaemon::query(int socket) {
} else if (qt == QueryType::CHECK) {
checkHandler(socket);

} else if (qt == QueryType::WAIT) {
waitHandler(socket);

} else {
throw std::runtime_error("Unexpected query type");
}
}

void TCPStoreDaemon::wakeupWaitingClients(const std::string& key) {
auto socketsToWait = waitingSockets_.find(key);
if (socketsToWait != waitingSockets_.end()) {
for (int socket : socketsToWait->second) {
if (--keysAwaited_[socket] == 0) {
tcputil::sendValue<WaitResponseType>(socket,
WaitResponseType::STOP_WAITING);
}
}
waitingSockets_.erase(socketsToWait);
}
}

void TCPStoreDaemon::setHandler(int socket) {
std::string key = tcputil::recvString(socket);
tcpStore_[key] = tcputil::recvVector<uint8_t>(socket);
// On "set", wake up all clients that have been waiting
wakeupWaitingClients(key);
}

void TCPStoreDaemon::addHandler(int socket) {
Expand All @@ -200,40 +209,55 @@ void TCPStoreDaemon::addHandler(int socket) {
tcpStore_[key] = std::vector<uint8_t>(addValStr.begin(), addValStr.end());
// Now send the new value
tcputil::sendValue<int64_t>(socket, addVal);
// On "add", wake up all clients that have been waiting
wakeupWaitingClients(key);
}

void TCPStoreDaemon::getHandler(int socket) {
void TCPStoreDaemon::getHandler(int socket) const {
std::string key = tcputil::recvString(socket);
auto data = tcpStore_.at(key);
tcputil::sendVector<uint8_t>(socket, data);
}

void TCPStoreDaemon::checkHandler(int socket) {
void TCPStoreDaemon::checkHandler(int socket) const {
SizeType nargs;
tcputil::recvBytes<SizeType>(socket, &nargs, 1);
std::vector<std::string> keys(nargs);
for (size_t i = 0; i < nargs; i++) {
keys[i] = tcputil::recvString(socket);
}
// Now we have received all the keys
if (checkAndUpdate(keys)) {
if (checkKeys(keys)) {
tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::READY);
} else {
tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::NOT_READY);
}
}

bool TCPStoreDaemon::checkAndUpdate(std::vector<std::string>& keys) const {
bool ret = true;
for (auto it = keys.begin(); it != keys.end();) {
if (tcpStore_.count(*it) == 0) {
ret = false;
it++;
} else {
it = keys.erase(it);
void TCPStoreDaemon::waitHandler(int socket) {
SizeType nargs;
tcputil::recvBytes<SizeType>(socket, &nargs, 1);
std::vector<std::string> keys(nargs);
for (size_t i = 0; i < nargs; i++) {
keys[i] = tcputil::recvString(socket);
}
if (checkKeys(keys)) {
tcputil::sendValue<WaitResponseType>(socket,
WaitResponseType::STOP_WAITING);
} else {
for (auto& key : keys) {
waitingSockets_[key].push_back(socket);
}
keysAwaited_[socket] = keys.size();
}
return ret;
}

bool TCPStoreDaemon::
checkKeys(const std::vector<std::string>& keys) const {
return std::all_of(keys.begin(), keys.end(),
[this](const std::string& s) {
return tcpStore_.count(s) > 0;
});
}

// TCPStore class methods
Expand All @@ -243,7 +267,6 @@ TCPStore::TCPStore(const std::string& masterAddr,
: isServer_(isServer)
, tcpStoreAddr_(masterAddr)
, tcpStorePort_(masterPort)

{
if (isServer_) {
// Opening up the listening socket
Expand All @@ -260,10 +283,8 @@ TCPStore::TCPStore(const std::string& masterAddr,
TCPStore::~TCPStore() {
::close(storeSocket_);
if (isServer_) {
/**
* Store daemon should end because of closed connection.
* daemon destructor should join the thread
*/
// Store daemon should end because of closed connection.
// daemon destructor should join the thread
tcpStoreDaemon_.reset(nullptr);
::close(masterListenSocket_);
}
Expand All @@ -290,7 +311,6 @@ int64_t TCPStore::add(const std::string& key, int64_t value) {
}

bool TCPStore::check(const std::vector<std::string>& keys) {

tcputil::sendValue<QueryType>(storeSocket_, QueryType::CHECK);
SizeType nkeys = keys.size();
tcputil::sendBytes<SizeType>(storeSocket_, &nkeys, 1, (nkeys > 0));
Expand All @@ -307,19 +327,27 @@ bool TCPStore::check(const std::vector<std::string>& keys) {
}
}

void TCPStore::wait(
const std::vector<std::string>& keys,
const std::chrono::milliseconds& timeout) {

const auto start = std::chrono::steady_clock::now();
while (!check(keys)) {
const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - start);
if (timeout != kNoTimeout && elapsed > timeout) {
throw std::runtime_error("Wait timeout");
}
/* sleep override */
std::this_thread::sleep_for(std::chrono::milliseconds(10));
void TCPStore::wait(const std::vector<std::string>& keys,
const std::chrono::milliseconds& timeout) {
// Set the socket timeout if there is a wait timeout
if (timeout != kNoTimeout) {
struct timeval timeoutTV = {.tv_sec = timeout.count() / 1000,
.tv_usec = (timeout.count() % 1000) * 1000};
SYSCHECK(::setsockopt(storeSocket_,
SOL_SOCKET,
SO_RCVTIMEO,
reinterpret_cast<char*>(&timeoutTV),
sizeof(timeoutTV)));
}
tcputil::sendValue<QueryType>(storeSocket_, QueryType::WAIT);
SizeType nkeys = keys.size();
tcputil::sendBytes<SizeType>(storeSocket_, &nkeys, 1, (nkeys > 0));
for (size_t i = 0; i < nkeys; i++) {
tcputil::sendString(storeSocket_, keys[i], (i != (nkeys - 1)));
}
auto waitResponse = tcputil::recvValue<WaitResponseType>(storeSocket_);
if (waitResponse != WaitResponseType::STOP_WAITING) {
throw std::runtime_error("Stop_waiting response is expected");
}
}

Expand Down
14 changes: 10 additions & 4 deletions torch/lib/c10d/TcpStore.hpp → torch/lib/c10d/TCPStore.hpp
Expand Up @@ -27,17 +27,23 @@ class TCPStoreDaemon {

void setHandler(int socket);
void addHandler(int socket);
void getHandler(int socket);
void checkHandler(int socket);
void getHandler(int socket) const;
void checkHandler(int socket) const;
void waitHandler(int socket);

bool checkAndUpdate(std::vector<std::string>& keys) const;
bool checkKeys(const std::vector<std::string>& keys) const;
void wakeupWaitingClients(const std::string &key);

std::thread daemonThread_;
std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_;
// From key -> the list of sockets waiting on it
std::unordered_map<std::string, std::vector<int>> waitingSockets_;
// From socket -> number of keys awaited
std::unordered_map<int, size_t> keysAwaited_;

std::vector<int> sockets_;
int storeListenSocket_;
std::vector<int> controlPipeFd_;
std::vector<int> controlPipeFd_ {-1, -1};
};

class TCPStore : public Store {
Expand Down

0 comments on commit b9294ee

Please sign in to comment.