diff --git a/gloo/transport/tcp/CMakeLists.txt b/gloo/transport/tcp/CMakeLists.txt index 9bf72b83f..9cb6535af 100644 --- a/gloo/transport/tcp/CMakeLists.txt +++ b/gloo/transport/tcp/CMakeLists.txt @@ -7,6 +7,7 @@ else() "${CMAKE_CURRENT_SOURCE_DIR}/context.cc" "${CMAKE_CURRENT_SOURCE_DIR}/device.cc" "${CMAKE_CURRENT_SOURCE_DIR}/error.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/listener.cc" "${CMAKE_CURRENT_SOURCE_DIR}/loop.cc" "${CMAKE_CURRENT_SOURCE_DIR}/pair.cc" "${CMAKE_CURRENT_SOURCE_DIR}/socket.cc" @@ -19,6 +20,8 @@ else() "${CMAKE_CURRENT_SOURCE_DIR}/context.h" "${CMAKE_CURRENT_SOURCE_DIR}/device.h" "${CMAKE_CURRENT_SOURCE_DIR}/error.h" + "${CMAKE_CURRENT_SOURCE_DIR}/helpers.h" + "${CMAKE_CURRENT_SOURCE_DIR}/listener.h" "${CMAKE_CURRENT_SOURCE_DIR}/loop.h" "${CMAKE_CURRENT_SOURCE_DIR}/pair.h" "${CMAKE_CURRENT_SOURCE_DIR}/socket.h" diff --git a/gloo/transport/tcp/device.cc b/gloo/transport/tcp/device.cc index 481d67f7d..dd970b60f 100644 --- a/gloo/transport/tcp/device.cc +++ b/gloo/transport/tcp/device.cc @@ -19,6 +19,7 @@ #include "gloo/common/logging.h" #include "gloo/common/error.h" #include "gloo/transport/tcp/context.h" +#include "gloo/transport/tcp/helpers.h" #include "gloo/transport/tcp/pair.h" namespace gloo { @@ -217,6 +218,7 @@ const std::string sockaddrToInterfaceName(const struct attr& attr) { Device::Device(const struct attr& attr) : attr_(attr), loop_(std::make_shared()), + listener_(std::make_shared(loop_, attr)), interfaceName_(sockaddrToInterfaceName(attr_)), interfaceSpeedMbps_(getInterfaceSpeedByName(interfaceName_)), pciBusID_(interfaceToBusID(interfaceName_)) { @@ -257,6 +259,105 @@ void Device::unregisterDescriptor(int fd, Handler* h) { loop_->unregisterDescriptor(fd, h); } +Address Device::nextAddress() { + return listener_->nextAddress(); +} + +bool Device::isInitiator( + const Address& local, + const Address& remote) const { + int rv = 0; + // The remote side of a pair will be called with the same + // addresses, but in reverse. 
There should only be a single + // connection between the two, so we pick one side as the listener + // and the other side as the connector. + const auto& ss1 = local.getSockaddr(); + const auto& ss2 = remote.getSockaddr(); + GLOO_ENFORCE_EQ(ss1.ss_family, ss2.ss_family); + const int family = ss1.ss_family; + if (family == AF_INET) { + const struct sockaddr_in* sa = (struct sockaddr_in*)&ss1; + const struct sockaddr_in* sb = (struct sockaddr_in*)&ss2; + rv = memcmp(&sa->sin_addr, &sb->sin_addr, sizeof(struct in_addr)); + if (rv == 0) { + rv = sa->sin_port - sb->sin_port; + } + } else if (family == AF_INET6) { + const struct sockaddr_in6* sa = (struct sockaddr_in6*)&ss1; + const struct sockaddr_in6* sb = (struct sockaddr_in6*)&ss2; + rv = memcmp(&sa->sin6_addr, &sb->sin6_addr, sizeof(struct in6_addr)); + if (rv == 0) { + rv = sa->sin6_port - sb->sin6_port; + } + } else { + GLOO_ENFORCE(false, "Unknown address family: ", family); + } + + // If both sides of the pair use the same address and port, they are + // sharing the same device instance. This happens in tests. Compare + // sequence number to allow pairs to connect. + if (rv == 0) { + rv = local.getSeq() - remote.getSeq(); + } + GLOO_ENFORCE_NE(rv, 0, "Cannot connect to self"); + return rv > 0; +} + +void Device::connect( + const Address& local, + const Address& remote, + std::chrono::milliseconds timeout, + connect_callback_t fn) { + auto initiator = isInitiator(local, remote); + + if (initiator) { + connectAsInitiator(remote, timeout, std::move(fn)); + return; + } + connectAsListener(local, timeout, std::move(fn)); +} + +// Connecting as listener is passive. +// +// Register the connect callback to be executed when the other side of +// the pair has connected and identified itself as destined for this +// address. To do so, we register the callback for the sequence number +// associated with the address. If this connection already exists, +// deal with it here. 
+// +void Device::connectAsListener( + const Address& local, + std::chrono::milliseconds /* unused */, + connect_callback_t fn) { + // TODO(pietern): Use timeout. + listener_->waitForConnection(local.getSeq(), std::move(fn)); +} + +// Connecting as initiator is active. +// +// The connect callback is fired when the connection to the other side +// of the pair has been made, and the sequence number for this +// connection has been written. If an error occurs at any time, the +// callback is called with an associated error event. +// +void Device::connectAsInitiator( + const Address& remote, + std::chrono::milliseconds /* unused */, + connect_callback_t fn) { + const auto& sockaddr = remote.getSockaddr(); + + // Create new socket to connect to peer. + auto socket = Socket::createForFamily(sockaddr.ss_family); + socket->reuseAddr(true); + socket->noDelay(true); + socket->connect(sockaddr); + + // Write sequence number for peer to new socket. + // TODO(pietern): Use timeout. + write( + loop_, std::move(socket), remote.getSeq(), std::move(fn)); +} + } // namespace tcp } // namespace transport } // namespace gloo diff --git a/gloo/transport/tcp/device.h b/gloo/transport/tcp/device.h index ef0ffcaba..1a6f8e249 100644 --- a/gloo/transport/tcp/device.h +++ b/gloo/transport/tcp/device.h @@ -17,7 +17,10 @@ #include #include +#include +#include #include +#include namespace gloo { namespace transport { @@ -50,14 +53,61 @@ class Device : public ::gloo::transport::Device, void registerDescriptor(int fd, int events, Handler* h); void unregisterDescriptor(int fd, Handler* h); + // TCP is bidirectional so when we connect two ends of a pair, + // one side is the connection initiator and the other is the listener. + bool isInitiator( + const Address& local, + const Address& remote) const; + protected: const struct attr attr_; + // Return a new `Address` instance. + // + // This is called by the constructor of the `Pair` class. 
It gives + // the pair a uniquely identifying address even though the device + // uses a shared listening socket. + // + Address nextAddress(); + + // Connect a pair to a remote. + // + // This is performed by the device instance because we use a single + // listening socket for all inbound pair connections. + // + // Matching these connections with pairs is done with a handshake. + // The remote side of the connection writes a sequence number (see + // `Address::sequence_t`) to the stream that identifies the pair + // it wants to connect to. On the local side, this sequence number + // is read and used as key in a map with callbacks. If the callback + // is found, it is called. If the callback is not found, the + // connection is cached in a map, using the sequence number. + // + using connect_callback_t = + std::function socket, Error error)>; + + void connect( + const Address& local, + const Address& remote, + std::chrono::milliseconds timeout, + connect_callback_t fn); + + void connectAsListener( + const Address& local, + std::chrono::milliseconds timeout, + connect_callback_t fn); + + void connectAsInitiator( + const Address& remote, + std::chrono::milliseconds timeout, + connect_callback_t fn); + friend class Pair; friend class Buffer; private: std::shared_ptr loop_; + std::shared_ptr listener_; std::string interfaceName_; int interfaceSpeedMbps_; diff --git a/gloo/transport/tcp/helpers.h b/gloo/transport/tcp/helpers.h new file mode 100644 index 000000000..81c3a8df0 --- /dev/null +++ b/gloo/transport/tcp/helpers.h @@ -0,0 +1,171 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +#include +#include +#include + +namespace gloo { +namespace transport { +namespace tcp { + +// ReadValueOperation asynchronously reads a value of type T from the +// socket specified at construction. Upon completion or error, the +// callback is called. Its lifetime is coupled with completion of the +// operation, so the caller doesn't need to hold on to the instance. +// It does so by storing a shared_ptr to itself (effectively a leak) +// until the event loop calls back. +template +class ReadValueOperation final + : public Handler, + public std::enable_shared_from_this> { + public: + using callback_t = + std::function, const Error& error, T&& t)>; + + ReadValueOperation( + std::shared_ptr loop, + std::shared_ptr socket, + callback_t fn) + : loop_(std::move(loop)), + socket_(std::move(socket)), + fn_(std::move(fn)) {} + + void run() { + // Cannot initialize leak until after the object has been + // constructed, because the std::make_shared initialization + // doesn't run until after construction of the underlying object. + leak_ = this->shared_from_this(); + // Register with loop only after we've leaked the shared_ptr, + // because we unleak it when the event loop thread calls. + loop_->registerDescriptor(socket_->fd(), EPOLLIN | EPOLLONESHOT, this); + } + + void handleEvents(int events) override { + // Move leaked shared_ptr to the stack so that this object + // destroys itself once this function returns. + auto self = std::move(this->leak_); + + // Read T. + auto rv = socket_->read(&t_, sizeof(t_)); + if (rv == -1) { + fn_(socket_, SystemError("read", errno), std::move(t_)); + return; + } + + // Check for short read (assume we can read in a single call). 
+ if (rv < sizeof(t_)) { + fn_(socket_, ShortReadError(rv, sizeof(t_)), std::move(t_)); + return; + } + + fn_(socket_, Error::kSuccess, std::move(t_)); + } + + private: + std::shared_ptr loop_; + std::shared_ptr socket_; + callback_t fn_; + std::shared_ptr> leak_; + + T t_; +}; + +template +void read( + std::shared_ptr loop, + std::shared_ptr socket, + typename ReadValueOperation::callback_t fn) { + auto x = std::make_shared>( + std::move(loop), std::move(socket), std::move(fn)); + x->run(); +} + +// WriteValueOperation asynchronously writes a value of type T to the +// socket specified at construction. Upon completion or error, the +// callback is called. Its lifetime is coupled with completion of the +// operation, so the caller doesn't need to hold on to the instance. +// It does so by storing a shared_ptr to itself (effectively a leak) +// until the event loop calls back. +template +class WriteValueOperation final + : public Handler, + public std::enable_shared_from_this> { + public: + using callback_t = + std::function, const Error& error)>; + + WriteValueOperation( + std::shared_ptr loop, + std::shared_ptr socket, + T t, + callback_t fn) + : loop_(std::move(loop)), + socket_(std::move(socket)), + fn_(std::move(fn)), + t_(std::move(t)) {} + + void run() { + // Cannot initialize leak until after the object has been + // constructed, because the std::make_shared initialization + // doesn't run until after construction of the underlying object. + leak_ = this->shared_from_this(); + // Register with loop only after we've leaked the shared_ptr, + // because we unleak it when the event loop thread calls. + loop_->registerDescriptor(socket_->fd(), EPOLLOUT | EPOLLONESHOT, this); + } + + void handleEvents(int events) override { + // Move leaked shared_ptr to the stack so that this object + // destroys itself once this function returns. + auto leak = std::move(this->leak_); + + // Write T. 
+ auto rv = socket_->write(&t_, sizeof(t_)); + if (rv == -1) { + fn_(socket_, SystemError("write", errno)); + return; + } + + // Check for short write (assume we can write in a single call). + if (rv < sizeof(t_)) { + fn_(socket_, ShortWriteError(rv, sizeof(t_))); + return; + } + + fn_(socket_, Error::kSuccess); + } + + private: + std::shared_ptr loop_; + std::shared_ptr socket_; + callback_t fn_; + std::shared_ptr> leak_; + + T t_; +}; + +template +void write( + std::shared_ptr loop, + std::shared_ptr socket, + T t, + typename WriteValueOperation::callback_t fn) { + auto x = std::make_shared>( + std::move(loop), std::move(socket), std::move(t), std::move(fn)); + x->run(); +} + +} // namespace tcp +} // namespace transport +} // namespace gloo diff --git a/gloo/transport/tcp/listener.cc b/gloo/transport/tcp/listener.cc new file mode 100644 index 000000000..2caa7364a --- /dev/null +++ b/gloo/transport/tcp/listener.cc @@ -0,0 +1,121 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include + +#include +#include +#include + +namespace gloo { +namespace transport { +namespace tcp { + +Listener::Listener(std::shared_ptr loop, const attr& attr) + : loop_(std::move(loop)) { + listener_ = Socket::createForFamily(attr.ai_addr.ss_family); + listener_->reuseAddr(true); + listener_->bind(attr.ai_addr); + listener_->listen(kBacklog); + addr_ = listener_->sockName(); + + // Register with loop for readability events. 
+ loop_->registerDescriptor(listener_->fd(), EPOLLIN, this); +} + +Listener::~Listener() { + if (listener_) { + loop_->unregisterDescriptor(listener_->fd(), this); + } +} + +void Listener::handleEvents(int /* unused */) { + std::lock_guard guard(mutex_); + + for (;;) { + auto sock = listener_->accept(); + if (!sock) { + // Let the loop try again on the next tick. + if (errno == EAGAIN) { + return; + } + // Actual error. + GLOO_ENFORCE(false, "accept: ", strerror(errno)); + } + + sock->reuseAddr(true); + sock->noDelay(true); + + // Read sequence number. + read( + loop_, + sock, + [this]( + std::shared_ptr socket, + const Error& error, + sequence_number_t&& seq) { + // If there was an error reading from the socket, the + // sequence number will be bogus, and we can't route it to + // the right callback. Ignore it. + if (error) { + return; + } + + haveConnection(std::move(socket), seq); + }); + } +} + +Address Listener::nextAddress() { + std::lock_guard guard(mutex_); + return Address(addr_.getSockaddr(), seq_++); +} + +void Listener::waitForConnection(sequence_number_t seq, connect_callback_t fn) { + std::unique_lock lock(mutex_); + + // If we don't yet have an fd for this sequence number, persist callback. + auto it = seqToSocket_.find(seq); + if (it == seqToSocket_.end()) { + seqToCallback_.emplace(seq, std::move(fn)); + return; + } + + // If we already have an fd for this sequence number, schedule callback. + auto socket = std::move(it->second); + seqToSocket_.erase(it); + loop_->defer([fn, socket]() { fn(socket, Error::kSuccess); }); +} + +void Listener::haveConnection( + std::shared_ptr socket, + sequence_number_t seq) { + std::unique_lock lock(mutex_); + + // If we don't yet have a callback for this sequence number, persist socket. + auto it = seqToCallback_.find(seq); + if (it == seqToCallback_.end()) { + seqToSocket_.emplace(seq, std::move(socket)); + return; + } + + // If we already have a callback for this sequence number, trigger it. 
+ auto fn = std::move(it->second); + seqToCallback_.erase(it); + lock.unlock(); + fn(std::move(socket), Error::kSuccess); +} + +} // namespace tcp +} // namespace transport +} // namespace gloo diff --git a/gloo/transport/tcp/listener.h b/gloo/transport/tcp/listener.h new file mode 100644 index 000000000..e7aa3c711 --- /dev/null +++ b/gloo/transport/tcp/listener.h @@ -0,0 +1,74 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +namespace gloo { +namespace transport { +namespace tcp { + +// Listener deals with incoming connections. Incoming connections +// write a few bytes containing a sequence number. This sequence +// number is read off the socket and matched to a local sequence +// number. If there is a match, the socket is passed to the +// corresponding pair. If it can't be matched, it is stashed until a +// pair with the sequence number calls `waitForConnection`. +class Listener final : public Handler { + public: + using connect_callback_t = + std::function socket, Error error)>; + + static constexpr int kBacklog = 2048; + + Listener(std::shared_ptr loop, const attr& attr); + + ~Listener() override; + + void handleEvents(int events) override; + + Address nextAddress(); + + // Wait for connection with sequence number `seq`. The callback is + // always called from a different thread (the event loop thread), + // even if the connection is already available. + void waitForConnection(sequence_number_t seq, connect_callback_t fn); + + private: + std::mutex mutex_; + std::shared_ptr loop_; + std::shared_ptr listener_; + + // Address of this listener and the sequence number for the next + // connection. 
Sequence numbers are written by a peer right after + // establishing a new connection and used locally to match a new + // connection to a pair instance. + Address addr_; + sequence_number_t seq_{0}; + + // Called when we've read a sequence number from a new socket. + void haveConnection(std::shared_ptr socket, sequence_number_t seq); + + // Callbacks by sequence number (while waiting for a connection). + std::unordered_map seqToCallback_; + + // Sockets by sequence number (while waiting for a pair to call). + std::unordered_map> seqToSocket_; +}; + +} // namespace tcp +} // namespace transport +} // namespace gloo diff --git a/gloo/transport/tcp/pair.cc b/gloo/transport/tcp/pair.cc index 91bb7c1ef..d6cd1baa6 100644 --- a/gloo/transport/tcp/pair.cc +++ b/gloo/transport/tcp/pair.cc @@ -8,8 +8,8 @@ #include "gloo/transport/tcp/pair.h" -#include #include +#include #include #include @@ -61,10 +61,8 @@ Pair::Pair( busyPoll_(false), fd_(FD_INVALID), sendBufferSize_(0), - is_client_(false), - ex_(nullptr) { - listen(); -} + self_(device_->nextAddress()), + ex_(nullptr) {} // Destructor performs a "soft" close. Pair::~Pair() { @@ -99,8 +97,73 @@ const Address& Pair::address() const { } void Pair::connect(const std::vector& bytes) { - auto peer = Address(bytes); - connect(peer); + const auto peer = Address(bytes); + + std::unique_lock lock(m_); + GLOO_ENFORCE_EQ(state_, INITIALIZING); + state_ = CONNECTING; + + // Both processes call the `Pair::connect` function with the address + // of the other. The device instance associated with both `Pair` + // instances is responsible for establishing the actual connection, + // seeing as it owns the listening socket. + // + // One side takes a passive role and the other side takes an active + // role in establishing the connection. The passive role means + // waiting for an incoming connection that identifies itself with a + // specific sequence number (encoded in the `Address`). 
The active + // role means creating a connection to a specific address, and + // writing out a specific sequence number. Once the process for + // either role succeeds, the connection callback for the pair gets + // called with the file descriptor for the underlying connection. + // + device_->connect( + self_, + peer, + timeout_, + std::bind( + &Pair::connectCallback, + this, + std::placeholders::_1, + std::placeholders::_2)); + + // Wait for connection to be made. + // + // NOTE(pietern): This can be split out to a separate function so + // that we first initiate all connections and then wait on all of + // them. This should make context initialization a bit faster. It + // requires a change to the base class though, so let's do it after + // this new transport has been merged. + // + waitUntilConnected(lock, true); +} + +void Pair::connectCallback(std::shared_ptr socket, Error error) { + std::lock_guard lock(m_); + if (error) { + signalException(GLOO_ERROR_MSG(error.what())); + return; + } + + // Finalize setup. + socket->block(false); + socket->noDelay(true); + socket->sendTimeout(timeout_); + socket->recvTimeout(timeout_); + + // Reset addresses. + self_ = socket->sockName(); + peer_ = socket->peerName(); + + // Take over ownership of the socket's file descriptor. The code in + // this class works with the file descriptor directly. + fd_ = socket->release(); + + // Register with loop for socket readability. + device_->registerDescriptor(fd_, EPOLLIN, this); + + // We're done: update state and wake up waiting threads. 
+ changeState(CONNECTED); } static void setSocketBlocking(int fd, bool enable) { @@ -152,133 +215,6 @@ void Pair::setSync(bool sync, bool busyPoll) { busyPoll_ = busyPoll; } -void Pair::listen() { - std::lock_guard lock(m_); - int rv; - - const auto& attr = device_->attr_; - auto fd = socket(attr.ai_family, attr.ai_socktype, attr.ai_protocol); - if (fd == -1) { - signalAndThrowException(GLOO_ERROR_MSG("socket: ", strerror(errno))); - } - - // Set SO_REUSEADDR to signal that reuse of the listening port is OK. - int on = 1; - rv = setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); - if (rv == -1) { - ::close(fd); - signalAndThrowException(GLOO_ERROR_MSG("setsockopt: ", strerror(errno))); - } - - rv = bind(fd, (const sockaddr*)&attr.ai_addr, attr.ai_addrlen); - if (rv == -1) { - ::close(fd); - signalAndThrowException(GLOO_ERROR_MSG("bind: ", strerror(errno))); - } - - // listen(2) on socket - fd_ = fd; - rv = ::listen(fd_, 1024); - if (rv == -1) { - ::close(fd_); - fd_ = FD_INVALID; - signalAndThrowException(GLOO_ERROR_MSG("listen: ", strerror(errno))); - } - - // Keep copy of address - self_ = Address::fromSockName(fd); - - // Register with device so we're called when peer connects - changeState(LISTENING); - device_->registerDescriptor(fd_, EPOLLIN, this); - - return; -} - -void Pair::connect(const Address& peer) { - std::unique_lock lock(m_); - int rv; - socklen_t addrlen; - throwIfException(); - - peer_ = peer; - - const auto& selfAddr = self_.getSockaddr(); - const auto& peerAddr = peer_.getSockaddr(); - - // Addresses have to have same family - if (selfAddr.ss_family != peerAddr.ss_family) { - GLOO_THROW_INVALID_OPERATION_EXCEPTION("address family mismatch"); - } - - if (selfAddr.ss_family == AF_INET) { - struct sockaddr_in* sa = (struct sockaddr_in*)&selfAddr; - struct sockaddr_in* sb = (struct sockaddr_in*)&peerAddr; - addrlen = sizeof(struct sockaddr_in); - rv = memcmp(&sa->sin_addr, &sb->sin_addr, sizeof(struct in_addr)); - if (rv == 0) { - rv = 
sa->sin_port - sb->sin_port; - } - } else if (peerAddr.ss_family == AF_INET6) { - struct sockaddr_in6* sa = (struct sockaddr_in6*)&selfAddr; - struct sockaddr_in6* sb = (struct sockaddr_in6*)&peerAddr; - addrlen = sizeof(struct sockaddr_in6); - rv = memcmp(&sa->sin6_addr, &sb->sin6_addr, sizeof(struct in6_addr)); - if (rv == 0) { - rv = sa->sin6_port - sb->sin6_port; - } - } else { - GLOO_THROW_INVALID_OPERATION_EXCEPTION("unknown sa_family"); - } - - if (rv == 0) { - GLOO_THROW_INVALID_OPERATION_EXCEPTION("cannot connect to self"); - } - - is_client_ = rv > 0; - - // self_ < peer_; we are listening side. - if (!is_client_) { - waitUntilConnected(lock, true); - return; - } - - // self_ > peer_; we are connecting side. - // First destroy listening socket. - device_->unregisterDescriptor(fd_, this); - ::close(fd_); - - // Create new socket to connect to peer. - fd_ = socket(peerAddr.ss_family, SOCK_STREAM | SOCK_NONBLOCK, 0); - if (fd_ == -1) { - signalAndThrowException(GLOO_ERROR_MSG("socket: ", strerror(errno))); - } - - // Set SO_REUSEADDR to signal that reuse of the source port is OK. - int on = 1; - rv = setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); - if (rv == -1) { - ::close(fd_); - fd_ = FD_INVALID; - signalAndThrowException(GLOO_ERROR_MSG("setsockopt: ", strerror(errno))); - } - - // Connect to peer - rv = ::connect(fd_, (struct sockaddr*)&peerAddr, addrlen); - if (rv == -1 && errno != EINPROGRESS) { - ::close(fd_); - fd_ = FD_INVALID; - signalAndThrowException(GLOO_ERROR_MSG("connect: ", strerror(errno))); - } - - // Register with device so we're called when connection completes. 
- changeState(CONNECTING); - device_->registerDescriptor(fd_, EPOLLIN | EPOLLOUT, this); - - // Wait for connection to complete - waitUntilConnected(lock, true); -} - ssize_t Pair::prepareWrite( Op& op, const NonOwningPtr& buf, @@ -723,16 +659,6 @@ void Pair::handleEvents(int events) { return; } - if (state_ == LISTENING) { - handleListening(); - return; - } - - if (state_ == CONNECTING) { - handleConnecting(); - return; - } - GLOO_ENFORCE(false, "Unexpected state: ", state_); } @@ -761,77 +687,6 @@ void Pair::handleReadWrite(int events) { } } -void Pair::handleListening() { - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - int rv; - - rv = accept(fd_, (struct sockaddr*)&addr, &addrlen); - - // Close the listening file descriptor whether we've successfully connected - // or run into an error and will throw an exception. - device_->unregisterDescriptor(fd_, this); - ::close(fd_); - fd_ = FD_INVALID; - - if (rv == -1) { - signalException(GLOO_ERROR_MSG("accept: ", strerror(errno))); - return; - } - - // Connected, replace file descriptor - fd_ = rv; - - // Common connection-made code - handleConnected(); -} - -void Pair::handleConnecting() { - int optval; - socklen_t optlen = sizeof(optval); - int rv; - - // Verify that connecting was successful - rv = getsockopt(fd_, SOL_SOCKET, SO_ERROR, &optval, &optlen); - GLOO_ENFORCE_NE(rv, -1); - if (optval != 0) { - signalException( - GLOO_ERROR_MSG("connect ", peer_.str(), ": ", strerror(optval))); - return; - } - - // Common connection-made code - handleConnected(); -} - -void Pair::handleConnected() { - int rv; - - // Reset addresses - self_ = Address::fromSockName(fd_); - peer_ = Address::fromPeerName(fd_); - - // Make sure socket is non-blocking - setSocketBlocking(fd_, false); - - int flag = 1; - socklen_t optlen = sizeof(flag); - rv = setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, optlen); - GLOO_ENFORCE_NE(rv, -1); - - // Set timeout - struct timeval tv = {}; - tv.tv_sec = timeout_.count() 
/ 1000; - tv.tv_usec = (timeout_.count() % 1000) * 1000; - rv = setsockopt(fd_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); - GLOO_ENFORCE_NE(rv, -1); - rv = setsockopt(fd_, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)); - GLOO_ENFORCE_NE(rv, -1); - - device_->registerDescriptor(fd_, EPOLLIN, this); - changeState(CONNECTED); -} - // getBuffer must only be called when holding lock. Buffer* Pair::getBuffer(int slot) { for (;;) { @@ -874,19 +729,9 @@ void Pair::changeState(state nextState) noexcept { if (nextState == CLOSED) { switch (state_) { case INITIALIZING: - // This state persists from construction up to the point where - // Pair::listen sets fd_ and calls listen(2). If this fails, - // it takes care of cleaning up the socket itself. + // Initial state upon construction. // There is no additional cleanup needed here. break; - case LISTENING: - // The pair may be in the LISTENING state when it is destructed. - if (fd_ != FD_INVALID) { - device_->unregisterDescriptor(fd_, this); - ::close(fd_); - fd_ = FD_INVALID; - } - break; case CONNECTING: // The pair may be in the CONNECTING state when it is destructed. 
if (fd_ != FD_INVALID) { diff --git a/gloo/transport/tcp/pair.h b/gloo/transport/tcp/pair.h index a81270061..033301a0e 100644 --- a/gloo/transport/tcp/pair.h +++ b/gloo/transport/tcp/pair.h @@ -30,6 +30,8 @@ #include "gloo/transport/pair.h" #include "gloo/transport/tcp/address.h" #include "gloo/transport/tcp/device.h" +#include "gloo/transport/tcp/error.h" +#include "gloo/transport/tcp/socket.h" namespace gloo { namespace transport { @@ -83,7 +85,6 @@ class Pair : public ::gloo::transport::Pair, public Handler { protected: enum state { INITIALIZING = 1, - LISTENING = 2, CONNECTING = 3, CONNECTED = 4, CLOSED = 5, @@ -170,7 +171,6 @@ class Pair : public ::gloo::transport::Pair, public Handler { Address self_; Address peer_; - bool is_client_; std::mutex m_; std::condition_variable cv_; @@ -191,8 +191,7 @@ class Pair : public ::gloo::transport::Pair, public Handler { void sendNotifyRecvReady(uint64_t slot, size_t nbytes); void sendNotifySendReady(uint64_t slot, size_t nbytes); - void listen(); - void connect(const Address& peer); + void connectCallback(std::shared_ptr socket, Error error); Buffer* getBuffer(int slot); void registerBuffer(Buffer* buf); @@ -277,28 +276,6 @@ class Pair : public ::gloo::transport::Pair, public Handler { // virtual void handleReadWrite(int events); - // Finishes connection setup if this side of the pair is on the - // listening side of connection initiation. This is called from - // `handleEvents` if the listening file descriptor is readable, i.e. - // if there is an incoming connection. - // - // The pair mutex is expected to be held when called. - // - void handleListening(); - - // Finishes connection setup if this side of the pair is on the - // connecting side of the connection initiation. This is called from - // `handleEvents` if the file descriptor associated with the - // connection is writable or in an error state, i.e. the connection - // has been established or failed to establish. 
- // - // The pair mutex is expected to be held when called. - // - void handleConnecting(); - - // Helper function called from `handleListening` or `handleConnecting`. - void handleConnected(); - // Advances this pair's state. See the `Pair::state` enum for // possible states. State can only move forward, i.e. from // initializing, to connected, to closed. diff --git a/gloo/transport/tcp/tls/pair.cc b/gloo/transport/tcp/tls/pair.cc index 176fa5e91..ce49cf735 100644 --- a/gloo/transport/tcp/tls/pair.cc +++ b/gloo/transport/tcp/tls/pair.cc @@ -194,7 +194,7 @@ bool Pair::read() { } void Pair::handleReadWrite(int events) { - if (!is_ssl_connected_ && !is_client_) { + if (!is_ssl_connected_ && !device_->isInitiator(self_, peer_)) { if (ssl_ == nullptr) { GLOO_ENFORCE(ssl_ctx_ != nullptr); ssl_ = _glootls::SSL_new(ssl_ctx_); @@ -239,7 +239,7 @@ void Pair::waitUntilConnected(std::unique_lock &lock, ::gloo::transport::tcp::Pair::waitUntilConnected(lock, useTimeout); if (!is_ssl_connected_) { - if (is_client_) { + if (device_->isInitiator(self_, peer_)) { GLOO_ENFORCE(ssl_ == nullptr); GLOO_ENFORCE(ssl_ctx_ != nullptr); ssl_ = _glootls::SSL_new(ssl_ctx_);