Skip to content

Commit

Permalink
Used pipe to terminate the store daemon and addressed all comments
Browse files Browse the repository at this point in the history
Please see: pytorch#7434
  • Loading branch information
teng-li committed May 16, 2018
1 parent 54cb0b5 commit 50e1362
Show file tree
Hide file tree
Showing 7 changed files with 215 additions and 131 deletions.
248 changes: 150 additions & 98 deletions torch/lib/c10d/TcpStore.cpp
@@ -1,54 +1,74 @@
#include "TcpStore.hpp"

#include <poll.h>
#include <system_error>
#include <unistd.h>

#include <system_error>

namespace c10d {


namespace {

enum class QueryType : std::uint8_t {
SET,
GET,
ADD,
CHECK,
STOP_WAITING,
KEEP_WAITING
CHECK
};

} // anonymous namespace

enum class CheckResponseType : std::uint8_t {
READY,
NOT_READY
};

// TcpStoreDaemon class methods
} // anonymous namespace

// TCPStoreDaemon class methods
// Simply start the daemon thread
TcpStoreDaemon::TcpStoreDaemon(int storeListenSocket) :
TCPStoreDaemon::TCPStoreDaemon(int storeListenSocket) :
storeListenSocket_(storeListenSocket)
{
daemonThread_ = std::thread(&TcpStoreDaemon::run, this);
daemonThread_ = std::thread(&TCPStoreDaemon::run, this);
}

TcpStoreDaemon::~TcpStoreDaemon() {
TCPStoreDaemon::~TCPStoreDaemon() {

// Stop the run
stop();

// Join the thread
join();

// Close unclosed sockets
for (auto socket : sockets_) {
if (socket != -1) {
::close(socket);
}
}
// Join the thread
join();
// Now close the rest control pipe
for (auto fd : controlPipeFd_) {
if (fd != -1) {
::close(fd);
}
}
}

void TcpStoreDaemon::join() {
void TCPStoreDaemon::join() {
daemonThread_.join();
}

void TcpStoreDaemon::run() {
void TCPStoreDaemon::run() {

// Create the control pipe
controlPipeFd_ = std::vector<int>{-1, -1};
if (pipe(controlPipeFd_.data()) == -1) {
throw std::runtime_error("Failed to create the control pipe to start the "
"TCPStoreDaemon run");
}

std::vector<struct pollfd> fds;
fds.push_back({ .fd = storeListenSocket_, .events = POLLIN });
// Push the read end of the pipe to signal the stopping of the daemon run
fds.push_back({ .fd = controlPipeFd_[0], .events = POLLHUP });

// receive the queries
bool finished = false;
Expand All @@ -59,50 +79,77 @@ void TcpStoreDaemon::run() {

SYSCHECK(::poll(fds.data(), fds.size(), -1));

/**
* TCPStore's listening socket has an event and it should now be able to
* accept new connections.
*/
if (fds[0].revents != 0) {
if (fds[0].revents ^ POLLIN) {
throw std::system_error(ECONNABORTED, std::system_category());
throw std::system_error(ECONNABORTED, std::system_category(),
"Unexpected poll revent on the master's listening socket: " +
std::to_string(fds[0].revents));
}
int sockFd = std::get<0>(tcputil::accept(storeListenSocket_));
sockets_.push_back(sockFd);
keysAwaited_.push_back(0);
fds.push_back({ .fd = sockFd, .events = POLLIN });
}
for (size_t rank = 0; rank < sockets_.size(); rank++) {
if (fds[rank + 1].revents == 0) {
/**
* The pipe receives an event which tells us to shutdown the daemon
*/
if (fds[1].revents != 0) {
// Will be POLLUP when the pipe is closed
if (fds[1].revents ^ POLLHUP) {
throw std::system_error(ECONNABORTED, std::system_category(),
"Unexpected poll revent on the control pipe's reading fd: " +
std::to_string(fds[1].revents));
}
finished = true;
break;
}
/**
* Skipping the fds[0] and fds[1],
* fds[0] is master's listening socket
* fds[1] is control pipe's reading fd
*/
for (size_t fdIdx = 2; fdIdx < fds.size(); ++fdIdx) {
if (fds[fdIdx].revents == 0) {
continue;
}

if (fds[rank + 1].revents ^ POLLIN) {
throw std::system_error(ECONNABORTED, std::system_category());
if (fds[fdIdx].revents ^ POLLIN) {
throw std::system_error(ECONNABORTED, std::system_category(),
"Unexpected poll revent: " +
std::to_string(fds[fdIdx].revents) + " on socket: " +
std::to_string(fds[fdIdx].fd));
}
// Now query the socket that has the event
try {
query(rank);
} catch (std::exception& ex) {
query(fds[fdIdx].fd);
} catch (...) {
/**
* There was an error when processing query. Probably an exception
* occurred in recv/send what would indicate that socket on the other
* side has been closed. If the closing was due to normal exit, then the
* store should exit too. Otherwise, if it was different exception,
* other processes will get an exception once they try to use the store.
* store should continue executing. Otherwise, if it was different
* exception, other connections will get an exception once they try to
* use the store. We will go ahead and close this connection whenever
* we hit an exception here.
*/
finished = true;
break;
::close(fds[fdIdx].fd);
fds.erase(fds.begin() + fdIdx);
sockets_.erase(sockets_.begin() + fdIdx - 2);
--fdIdx;
continue;
}
}
}
}

void TcpStoreDaemon::wakeUpWaitingRanks(const std::string& key) {
auto toWake = waiting_.find(key);
if (toWake != waiting_.end()) {
for (int proc : toWake->second) {
if (--keysAwaited_[proc] == 0) {
tcputil::sendValue<QueryType>(sockets_[proc],
QueryType::STOP_WAITING);
}
}
waiting_.erase(toWake);
void TCPStoreDaemon::stop() {
if (controlPipeFd_.size() == 2 && controlPipeFd_[1] != -1) {
// close the write end of the pipe
::close(controlPipeFd_[1]);
controlPipeFd_[1] = -1;
}
}

Expand All @@ -113,64 +160,70 @@ void TcpStoreDaemon::wakeUpWaitingRanks(const std::string& key) {
* or, in the case of wait
* type of query | number of args | size of arg1 | arg1 | ...
*/
void TcpStoreDaemon::query(RankType rank) {
void TCPStoreDaemon::query(int socket) {

int socket = sockets_[rank];
QueryType qt;
tcputil::recvBytes<QueryType>(socket, &qt, 1);

if (qt == QueryType::SET) {
std::string key = tcputil::recvString(socket);
tcpStore_[key] = tcputil::recvVector<uint8_t>(socket);
// On "set", wake up all of the processes that wait
// for keys already in the store
wakeUpWaitingRanks(key);
setHandler(socket);

} else if (qt == QueryType::ADD) {
std::string key = tcputil::recvString(socket);
int64_t addVal = tcputil::recvValue<int64_t>(socket);

if (tcpStore_.find(key) != tcpStore_.end()) {
auto buf = reinterpret_cast<const char*>(tcpStore_[key].data());
auto len = tcpStore_[key].size();
addVal += std::stoll(std::string(buf, len));
}
auto addValStr = std::to_string(addVal);
tcpStore_[key] = std::vector<uint8_t>(addValStr.begin(), addValStr.end());
// Now send the new value
tcputil::sendValue<int64_t>(socket, addVal);
// On "add", wake up all of the processes that wait
// for keys already in the store
wakeUpWaitingRanks(key);
addHandler(socket);

} else if (qt == QueryType::GET) {
std::string key = tcputil::recvString(socket);
auto data = tcpStore_.at(key);
tcputil::sendVector<uint8_t>(socket, data);
getHandler(socket);

} else if (qt == QueryType::CHECK) {
SizeType nargs;
tcputil::recvBytes<SizeType>(socket, &nargs, 1);
std::vector<std::string> keys(nargs);
for (size_t i = 0; i < nargs; i++) {
keys[i] = tcputil::recvString(socket);
}
// Now we have received all the keys
if (checkAndUpdate(keys)) {
tcputil::sendValue<QueryType>(socket, QueryType::STOP_WAITING);
} else {
for (auto& key : keys) {
waiting_[key].push_back(rank);
}
keysAwaited_[rank] = keys.size();
tcputil::sendValue<QueryType>(socket, QueryType::KEEP_WAITING);
}
checkHandler(socket);

} else {
throw std::runtime_error("Unexpected query type");
}
}

void TCPStoreDaemon::setHandler(int socket) {
std::string key = tcputil::recvString(socket);
tcpStore_[key] = tcputil::recvVector<uint8_t>(socket);
}

void TCPStoreDaemon::addHandler(int socket) {
std::string key = tcputil::recvString(socket);
int64_t addVal = tcputil::recvValue<int64_t>(socket);

if (tcpStore_.find(key) != tcpStore_.end()) {
auto buf = reinterpret_cast<const char*>(tcpStore_[key].data());
auto len = tcpStore_[key].size();
addVal += std::stoll(std::string(buf, len));
}
auto addValStr = std::to_string(addVal);
tcpStore_[key] = std::vector<uint8_t>(addValStr.begin(), addValStr.end());
// Now send the new value
tcputil::sendValue<int64_t>(socket, addVal);
}

void TCPStoreDaemon::getHandler(int socket) {
std::string key = tcputil::recvString(socket);
auto data = tcpStore_.at(key);
tcputil::sendVector<uint8_t>(socket, data);
}

void TCPStoreDaemon::checkHandler(int socket) {
SizeType nargs;
tcputil::recvBytes<SizeType>(socket, &nargs, 1);
std::vector<std::string> keys(nargs);
for (size_t i = 0; i < nargs; i++) {
keys[i] = tcputil::recvString(socket);
}
// Now we have received all the keys
if (checkAndUpdate(keys)) {
tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::READY);
} else {
throw std::runtime_error("expected a query type");
tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::NOT_READY);
}
}

bool TcpStoreDaemon::checkAndUpdate(std::vector<std::string>& keys) const {
bool TCPStoreDaemon::checkAndUpdate(std::vector<std::string>& keys) const {
bool ret = true;
for (auto it = keys.begin(); it != keys.end();) {
if (tcpStore_.count(*it) == 0) {
Expand All @@ -183,9 +236,8 @@ bool TcpStoreDaemon::checkAndUpdate(std::vector<std::string>& keys) const {
return ret;
}

// TcpStore class methods

TcpStore::TcpStore(const std::string& masterAddr,
// TCPStore class methods
TCPStore::TCPStore(const std::string& masterAddr,
PortType masterPort,
bool isServer)
: isServer_(isServer)
Expand All @@ -194,68 +246,68 @@ TcpStore::TcpStore(const std::string& masterAddr,

{
if (isServer_) {
// Openning up the listening socket
// Opening up the listening socket
std::tie(masterListenSocket_, std::ignore) = tcputil::listen(masterPort);
// Now start the daemon
tcpStoreDaemon_ = std::unique_ptr<TcpStoreDaemon>(
new TcpStoreDaemon(masterListenSocket_)
tcpStoreDaemon_ = std::unique_ptr<TCPStoreDaemon>(
new TCPStoreDaemon(masterListenSocket_)
);
}
// Connect to the daemon
storeSocket_ = tcputil::connect(tcpStoreAddr_, tcpStorePort_);
}

TcpStore::~TcpStore() {
TCPStore::~TCPStore() {
::close(storeSocket_);
if (isServer_) {
::close(masterListenSocket_);
/**
* Store daemon should end because of closed connection.
* daemon destructor should join the thread
*/
tcpStoreDaemon_.reset(nullptr);
::close(masterListenSocket_);
}
}

void TcpStore::set(const std::string& key, const std::vector<uint8_t>& data) {
void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) {
tcputil::sendValue<QueryType>(storeSocket_, QueryType::SET);
tcputil::sendString(storeSocket_, key, true);
tcputil::sendVector<uint8_t>(storeSocket_, data);
}

std::vector<uint8_t> TcpStore::get(const std::string& key) {
std::vector<uint8_t> TCPStore::get(const std::string& key) {
wait({key});
tcputil::sendValue<QueryType>(storeSocket_, QueryType::GET);
tcputil::sendString(storeSocket_, key);
return tcputil::recvVector<uint8_t>(storeSocket_);
}

int64_t TcpStore::add(const std::string& key, int64_t value) {
int64_t TCPStore::add(const std::string& key, int64_t value) {
tcputil::sendValue<QueryType>(storeSocket_, QueryType::ADD);
tcputil::sendString(storeSocket_, key, true);
tcputil::sendValue<int64_t>(storeSocket_, value);
return tcputil::recvValue<int64_t>(storeSocket_);
}

bool TcpStore::check(const std::vector<std::string>& keys) {
bool TCPStore::check(const std::vector<std::string>& keys) {

tcputil::sendValue<QueryType>(storeSocket_, QueryType::CHECK);
SizeType nkeys = keys.size();
tcputil::sendBytes<SizeType>(storeSocket_, &nkeys, 1, (nkeys > 0));
for (size_t i = 0; i < nkeys; i++) {
tcputil::sendString(storeSocket_, keys[i], (i != (nkeys - 1)));
}
auto checkResponse = tcputil::recvValue<QueryType>(storeSocket_);
if (checkResponse == QueryType::STOP_WAITING) {
auto checkResponse = tcputil::recvValue<CheckResponseType>(storeSocket_);
if (checkResponse == CheckResponseType::READY) {
return true;
} else if (checkResponse == QueryType::KEEP_WAITING) {
} else if (checkResponse == CheckResponseType::NOT_READY) {
return false;
} else {
throw std::runtime_error("stop_waiting or keep_waiting response expected");
throw std::runtime_error("ready or not_ready response expected");
}
}

void TcpStore::wait(
void TCPStore::wait(
const std::vector<std::string>& keys,
const std::chrono::milliseconds& timeout) {

Expand Down

0 comments on commit 50e1362

Please sign in to comment.