From b0b3bdf3a8e1e7be4a9e1a6f2feb69e50d69343a Mon Sep 17 00:00:00 2001 From: Shawn Xu Date: Sun, 14 May 2023 16:52:28 -0700 Subject: [PATCH] Use a single listening socket per device Summary: For listening port, the mesh connection currently leads to O(mn) port usage where m is the num of ranks per host and n is the total num of ranks. Even when `SO_REUSEADDR` is set, this only allows those used by sockets in `ESTABLISHED` or `TIME-WAIT` state to be reused. Hence in large training jobs, or even testing env where a lot of processes are packed on the same machine, we would soon run out of ephemeral ports (e.g. a local 200-process would need 40k ephemeral ports just for listening which is obviously very inefficient and most likely outside the range of allowed ephemeral ports in linux systems, which is typically around 32K). We fix this by using a single listening socket per device instance instead of using one per pair. Connections to all pair instances are multiplexed on a single listening socket by adding a sequence number to the address struct. For ranks packed on the same host with the same interface address, we use a seq number to differentiate between those so each would have a unique `Address` object assoc. During actual connection, each pair would have one side as `Initiator` and the other as `Listener`. We assign the roles purely based on arbitrary address comparison logic. The exact result doesn't matter since TCP is bidirectional, so long as they are consistent for a pair. The initiator will connect to the listed address and write a few bytes containing the sequence number. The listener waits for a connection to the shared listening socket where it can read that same sequence number. Once the listener side establishes the connection, that `Pair` would get promoted via the deferred callback to handle the actual connection post rendezvous. Credit to original author: Pieter Noordhuis pietern This diff cleans up a few things and resolves conflicts. Differential Revision: D45437709 fbshipit-source-id: 193fecb7d58e1d3a3acce82614f62d56865c2251 --- gloo/transport/tcp/CMakeLists.txt | 3 + gloo/transport/tcp/device.cc | 101 ++++++++++ gloo/transport/tcp/device.h | 50 +++++ gloo/transport/tcp/helpers.h | 171 +++++++++++++++++ gloo/transport/tcp/listener.cc | 121 ++++++++++++ gloo/transport/tcp/listener.h | 74 ++++++++ gloo/transport/tcp/pair.cc | 297 +++++++----------------------- gloo/transport/tcp/pair.h | 29 +-- gloo/transport/tcp/tls/pair.cc | 4 +- 9 files changed, 596 insertions(+), 254 deletions(-) create mode 100644 gloo/transport/tcp/helpers.h create mode 100644 gloo/transport/tcp/listener.cc create mode 100644 gloo/transport/tcp/listener.h diff --git a/gloo/transport/tcp/CMakeLists.txt b/gloo/transport/tcp/CMakeLists.txt index 9bf72b83f..9cb6535af 100644 --- a/gloo/transport/tcp/CMakeLists.txt +++ b/gloo/transport/tcp/CMakeLists.txt @@ -7,6 +7,7 @@ else() "${CMAKE_CURRENT_SOURCE_DIR}/context.cc" "${CMAKE_CURRENT_SOURCE_DIR}/device.cc" "${CMAKE_CURRENT_SOURCE_DIR}/error.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/listener.cc" "${CMAKE_CURRENT_SOURCE_DIR}/loop.cc" "${CMAKE_CURRENT_SOURCE_DIR}/pair.cc" "${CMAKE_CURRENT_SOURCE_DIR}/socket.cc" @@ -19,6 +20,8 @@ else() "${CMAKE_CURRENT_SOURCE_DIR}/context.h" "${CMAKE_CURRENT_SOURCE_DIR}/device.h" "${CMAKE_CURRENT_SOURCE_DIR}/error.h" + "${CMAKE_CURRENT_SOURCE_DIR}/helpers.h" + "${CMAKE_CURRENT_SOURCE_DIR}/listener.h" "${CMAKE_CURRENT_SOURCE_DIR}/loop.h" "${CMAKE_CURRENT_SOURCE_DIR}/pair.h" "${CMAKE_CURRENT_SOURCE_DIR}/socket.h" diff --git a/gloo/transport/tcp/device.cc b/gloo/transport/tcp/device.cc index 481d67f7d..dd970b60f 100644 --- a/gloo/transport/tcp/device.cc +++ b/gloo/transport/tcp/device.cc @@ -19,6 +19,7 @@ #include "gloo/common/logging.h" #include "gloo/common/error.h" #include "gloo/transport/tcp/context.h" +#include "gloo/transport/tcp/helpers.h" #include "gloo/transport/tcp/pair.h" namespace gloo { @@ -217,6 +218,7 @@ const std::string sockaddrToInterfaceName(const struct attr& attr) { Device::Device(const struct attr& attr) : attr_(attr), loop_(std::make_shared()), + listener_(std::make_shared(loop_, attr)), interfaceName_(sockaddrToInterfaceName(attr_)), interfaceSpeedMbps_(getInterfaceSpeedByName(interfaceName_)), pciBusID_(interfaceToBusID(interfaceName_)) { @@ -257,6 +259,105 @@ void Device::unregisterDescriptor(int fd, Handler* h) { loop_->unregisterDescriptor(fd, h); } +Address Device::nextAddress() { + return listener_->nextAddress(); +} + +bool Device::isInitiator( + const Address& local, + const Address& remote) const { + int rv = 0; + // The remote side of a pair will be called with the same + // addresses, but in reverse. There should only be a single + // connection between the two, so we pick one side as the listener + // and the other side as the connector. + const auto& ss1 = local.getSockaddr(); + const auto& ss2 = remote.getSockaddr(); + GLOO_ENFORCE_EQ(ss1.ss_family, ss2.ss_family); + const int family = ss1.ss_family; + if (family == AF_INET) { + const struct sockaddr_in* sa = (struct sockaddr_in*)&ss1; + const struct sockaddr_in* sb = (struct sockaddr_in*)&ss2; + rv = memcmp(&sa->sin_addr, &sb->sin_addr, sizeof(struct in_addr)); + if (rv == 0) { + rv = sa->sin_port - sb->sin_port; + } + } else if (family == AF_INET6) { + const struct sockaddr_in6* sa = (struct sockaddr_in6*)&ss1; + const struct sockaddr_in6* sb = (struct sockaddr_in6*)&ss2; + rv = memcmp(&sa->sin6_addr, &sb->sin6_addr, sizeof(struct in6_addr)); + if (rv == 0) { + rv = sa->sin6_port - sb->sin6_port; + } + } else { + GLOO_ENFORCE(false, "Unknown address family: ", family); + } + + // If both sides of the pair use the same address and port, they are + // sharing the same device instance. This happens in tests. Compare + // sequence number to allow pairs to connect. + if (rv == 0) { + rv = local.getSeq() - remote.getSeq(); + } + GLOO_ENFORCE_NE(rv, 0, "Cannot connect to self"); + return rv > 0; +} + +void Device::connect( + const Address& local, + const Address& remote, + std::chrono::milliseconds timeout, + connect_callback_t fn) { + auto initiator = isInitiator(local, remote); + + if (initiator) { + connectAsInitiator(remote, timeout, std::move(fn)); + return; + } + connectAsListener(local, timeout, std::move(fn)); +} + +// Connecting as listener is passive. +// +// Register the connect callback to be executed when the other side of +// the pair has connected and identified itself as destined for this +// address. To do so, we register the callback for the sequence number +// associated with the address. If this connection already exists, +// deal with it here. +// +void Device::connectAsListener( + const Address& local, + std::chrono::milliseconds /* unused */, + connect_callback_t fn) { + // TODO(pietern): Use timeout. + listener_->waitForConnection(local.getSeq(), std::move(fn)); +} + +// Connecting as initiator is active. +// +// The connect callback is fired when the connection to the other side +// of the pair has been made, and the sequence number for this +// connection has been written. If an error occurs at any time, the +// callback is called with an associated error event. +// +void Device::connectAsInitiator( + const Address& remote, + std::chrono::milliseconds /* unused */, + connect_callback_t fn) { + const auto& sockaddr = remote.getSockaddr(); + + // Create new socket to connect to peer. + auto socket = Socket::createForFamily(sockaddr.ss_family); + socket->reuseAddr(true); + socket->noDelay(true); + socket->connect(sockaddr); + + // Write sequence number for peer to new socket. + // TODO(pietern): Use timeout. + write( + loop_, std::move(socket), remote.getSeq(), std::move(fn)); +} + } // namespace tcp } // namespace transport } // namespace gloo diff --git a/gloo/transport/tcp/device.h b/gloo/transport/tcp/device.h index ef0ffcaba..1a6f8e249 100644 --- a/gloo/transport/tcp/device.h +++ b/gloo/transport/tcp/device.h @@ -17,7 +17,10 @@ #include #include +#include +#include #include +#include namespace gloo { namespace transport { @@ -50,14 +53,61 @@ class Device : public ::gloo::transport::Device, void registerDescriptor(int fd, int events, Handler* h); void unregisterDescriptor(int fd, Handler* h); + // TCP is bidirectional so when we connect two ends of a pair, + // one side is the connection initiator and the other is the listener. + bool isInitiator( + const Address& local, + const Address& remote) const; + protected: const struct attr attr_; + // Return a new `Address` instance. + // + // This is called by the constructor of the `Pair` class. It gives + // the pair a uniquely identifying address even though the device + // uses a shared listening socket. + // + Address nextAddress(); + + // Connect a pair to a remote. + // + // This is performed by the device instance because we use a single + // listening socket for all inbound pair connections. + // + // Matching these connections with pairs is done with a handshake. + // The remote side of the connection writes a sequence number (see + // `Address::sequence_t`) to the stream that identifies the pair + // it wants to connect to. On the local side, this sequence number + // is read and used as key in a map with callbacks. If the callback + // is found, it is called. If the callback is not found, the + // connection is cached in a map, using the sequence number. + // + using connect_callback_t = + std::function socket, Error error)>; + + void connect( + const Address& local, + const Address& remote, + std::chrono::milliseconds timeout, + connect_callback_t fn); + + void connectAsListener( + const Address& local, + std::chrono::milliseconds timeout, + connect_callback_t fn); + + void connectAsInitiator( + const Address& remote, + std::chrono::milliseconds timeout, + connect_callback_t fn); + friend class Pair; friend class Buffer; private: std::shared_ptr loop_; + std::shared_ptr listener_; std::string interfaceName_; int interfaceSpeedMbps_; diff --git a/gloo/transport/tcp/helpers.h b/gloo/transport/tcp/helpers.h new file mode 100644 index 000000000..81c3a8df0 --- /dev/null +++ b/gloo/transport/tcp/helpers.h @@ -0,0 +1,171 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include +#include + +namespace gloo { +namespace transport { +namespace tcp { + +// ReadValueOperation asynchronously reads a value of type T from the +// socket specified at construction. Upon completion or error, the +// callback is called. Its lifetime is coupled with completion of the +// operation, so the called doesn't need to hold on to the instance. +// It does so by storing a shared_ptr to itself (effectively a leak) +// until the event loop calls back. +template +class ReadValueOperation final + : public Handler, + public std::enable_shared_from_this> { + public: + using callback_t = + std::function, const Error& error, T&& t)>; + + ReadValueOperation( + std::shared_ptr loop, + std::shared_ptr socket, + callback_t fn) + : loop_(std::move(loop)), + socket_(std::move(socket)), + fn_(std::move(fn)) {} + + void run() { + // Cannot initialize leak until after the object has been + // constructed, because the std::make_shared initialization + // doesn't run after construction of the underlying object. + leak_ = this->shared_from_this(); + // Register with loop only after we've leaked the shared_ptr, + // because we unleak it when the event loop thread calls. + loop_->registerDescriptor(socket_->fd(), EPOLLIN | EPOLLONESHOT, this); + } + + void handleEvents(int events) override { + // Move leaked shared_ptr to the stack so that this object + // destroys itself once this function returns. + auto self = std::move(this->leak_); + + // Read T. + auto rv = socket_->read(&t_, sizeof(t_)); + if (rv == -1) { + fn_(socket_, SystemError("read", errno), std::move(t_)); + return; + } + + // Check for short read (assume we can read in a single call). + if (rv < sizeof(t_)) { + fn_(socket_, ShortReadError(rv, sizeof(t_)), std::move(t_)); + return; + } + + fn_(socket_, Error::kSuccess, std::move(t_)); + } + + private: + std::shared_ptr loop_; + std::shared_ptr socket_; + callback_t fn_; + std::shared_ptr> leak_; + + T t_; +}; + +template +void read( + std::shared_ptr loop, + std::shared_ptr socket, + typename ReadValueOperation::callback_t fn) { + auto x = std::make_shared>( + std::move(loop), std::move(socket), std::move(fn)); + x->run(); +} + +// WriteValueOperation asynchronously writes a value of type T to the +// socket specified at construction. Upon completion or error, the +// callback is called. Its lifetime is coupled with completion of the +// operation, so the called doesn't need to hold on to the instance. +// It does so by storing a shared_ptr to itself (effectively a leak) +// until the event loop calls back. +template +class WriteValueOperation final + : public Handler, + public std::enable_shared_from_this> { + public: + using callback_t = + std::function, const Error& error)>; + + WriteValueOperation( + std::shared_ptr loop, + std::shared_ptr socket, + T t, + callback_t fn) + : loop_(std::move(loop)), + socket_(std::move(socket)), + fn_(std::move(fn)), + t_(std::move(t)) {} + + void run() { + // Cannot initialize leak until after the object has been + // constructed, because the std::make_shared initialization + // doesn't run after construction of the underlying object. + leak_ = this->shared_from_this(); + // Register with loop only after we've leaked the shared_ptr, + // because we unleak it when the event loop thread calls. + loop_->registerDescriptor(socket_->fd(), EPOLLOUT | EPOLLONESHOT, this); + } + + void handleEvents(int events) override { + // Move leaked shared_ptr to the stack so that this object + // destroys itself once this function returns. + auto leak = std::move(this->leak_); + + // Write T. + auto rv = socket_->write(&t_, sizeof(t_)); + if (rv == -1) { + fn_(socket_, SystemError("write", errno)); + return; + } + + // Check for short write (assume we can write in a single call). + if (rv < sizeof(t_)) { + fn_(socket_, ShortWriteError(rv, sizeof(t_))); + return; + } + + fn_(socket_, Error::kSuccess); + } + + private: + std::shared_ptr loop_; + std::shared_ptr socket_; + callback_t fn_; + std::shared_ptr> leak_; + + T t_; +}; + +template +void write( + std::shared_ptr loop, + std::shared_ptr socket, + T t, + typename WriteValueOperation::callback_t fn) { + auto x = std::make_shared>( + std::move(loop), std::move(socket), std::move(t), std::move(fn)); + x->run(); +} + +} // namespace tcp +} // namespace transport +} // namespace gloo diff --git a/gloo/transport/tcp/listener.cc b/gloo/transport/tcp/listener.cc new file mode 100644 index 000000000..2caa7364a --- /dev/null +++ b/gloo/transport/tcp/listener.cc @@ -0,0 +1,121 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include + +#include +#include +#include + +namespace gloo { +namespace transport { +namespace tcp { + +Listener::Listener(std::shared_ptr loop, const attr& attr) + : loop_(std::move(loop)) { + listener_ = Socket::createForFamily(attr.ai_addr.ss_family); + listener_->reuseAddr(true); + listener_->bind(attr.ai_addr); + listener_->listen(kBacklog); + addr_ = listener_->sockName(); + + // Register with loop for readability events. + loop_->registerDescriptor(listener_->fd(), EPOLLIN, this); +} + +Listener::~Listener() { + if (listener_) { + loop_->unregisterDescriptor(listener_->fd(), this); + } +} + +void Listener::handleEvents(int /* unused */) { + std::lock_guard guard(mutex_); + + for (;;) { + auto sock = listener_->accept(); + if (!sock) { + // Let the loop try again on the next tick. + if (errno == EAGAIN) { + return; + } + // Actual error. + GLOO_ENFORCE(false, "accept: ", strerror(errno)); + } + + sock->reuseAddr(true); + sock->noDelay(true); + + // Read sequence number. + read( + loop_, + sock, + [this]( + std::shared_ptr socket, + const Error& error, + sequence_number_t&& seq) { + // If there was an error reading from the socket, the + // sequence number will be bogus, and we can't route it to + // the right callback. Ignore it. + if (error) { + return; + } + + haveConnection(std::move(socket), seq); + }); + } +} + +Address Listener::nextAddress() { + std::lock_guard guard(mutex_); + return Address(addr_.getSockaddr(), seq_++); +} + +void Listener::waitForConnection(sequence_number_t seq, connect_callback_t fn) { + std::unique_lock lock(mutex_); + + // If we don't yet have an fd for this sequence number, persist callback. + auto it = seqToSocket_.find(seq); + if (it == seqToSocket_.end()) { + seqToCallback_.emplace(seq, std::move(fn)); + return; + } + + // If we already have an fd for this sequence number, schedule callback. + auto socket = std::move(it->second); + seqToSocket_.erase(it); + loop_->defer([fn, socket]() { fn(socket, Error::kSuccess); }); +} + +void Listener::haveConnection( + std::shared_ptr socket, + sequence_number_t seq) { + std::unique_lock lock(mutex_); + + // If we don't yet have a callback for this sequence number, persist socket. + auto it = seqToCallback_.find(seq); + if (it == seqToCallback_.end()) { + seqToSocket_.emplace(seq, std::move(socket)); + return; + } + + // If we already have a callback for this sequence number, trigger it. + auto fn = std::move(it->second); + seqToCallback_.erase(it); + lock.unlock(); + fn(std::move(socket), Error::kSuccess); +} + +} // namespace tcp +} // namespace transport +} // namespace gloo diff --git a/gloo/transport/tcp/listener.h b/gloo/transport/tcp/listener.h new file mode 100644 index 000000000..e7aa3c711 --- /dev/null +++ b/gloo/transport/tcp/listener.h @@ -0,0 +1,74 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +namespace gloo { +namespace transport { +namespace tcp { + +// Listener deals with incoming connections. Incoming connections +// write a few bytes containing a sequence number. This sequence +// number is read off the socket and matched to a local sequence +// number. If there is a match, the socket is passed to the +// corresponding pair. If it can't be matched, it is stashed until a +// pair with the sequence number calls `waitForConnection`. +class Listener final : public Handler { + public: + using connect_callback_t = + std::function socket, Error error)>; + + static constexpr int kBacklog = 2048; + + Listener(std::shared_ptr loop, const attr& attr); + + ~Listener() override; + + void handleEvents(int events) override; + + Address nextAddress(); + + // Wait for connection with sequence number `seq`. The callback is + // always called from a different thread (the event loop thread), + // even if the connection is already available. + void waitForConnection(sequence_number_t seq, connect_callback_t fn); + + private: + std::mutex mutex_; + std::shared_ptr loop_; + std::shared_ptr listener_; + + // Address of this listener and the sequence number for the next + // connection. Sequence numbers are written by a peer right after + // establishing a new connection and used locally to match a new + // connection to a pair instance. + Address addr_; + sequence_number_t seq_{0}; + + // Called when we've read a sequence number from a new socket. + void haveConnection(std::shared_ptr socket, sequence_number_t seq); + + // Callbacks by sequence number (while waiting for a connection). + std::unordered_map seqToCallback_; + + // Sockets by sequence number (while waiting for a pair to call). + std::unordered_map> seqToSocket_; +}; + +} // namespace tcp +} // namespace transport +} // namespace gloo diff --git a/gloo/transport/tcp/pair.cc b/gloo/transport/tcp/pair.cc index 91bb7c1ef..d6cd1baa6 100644 --- a/gloo/transport/tcp/pair.cc +++ b/gloo/transport/tcp/pair.cc @@ -8,8 +8,8 @@ #include "gloo/transport/tcp/pair.h" -#include #include +#include #include #include @@ -61,10 +61,8 @@ Pair::Pair( busyPoll_(false), fd_(FD_INVALID), sendBufferSize_(0), - is_client_(false), - ex_(nullptr) { - listen(); -} + self_(device_->nextAddress()), + ex_(nullptr) {} // Destructor performs a "soft" close. Pair::~Pair() { @@ -99,8 +97,73 @@ const Address& Pair::address() const { } void Pair::connect(const std::vector& bytes) { - auto peer = Address(bytes); - connect(peer); + const auto peer = Address(bytes); + + std::unique_lock lock(m_); + GLOO_ENFORCE_EQ(state_, INITIALIZING); + state_ = CONNECTING; + + // Both processes call the `Pair::connect` function with the address + // of the other. The device instance associated with both `Pair` + // instances is responsible for establishing the actual connection, + // seeing as it owns the listening socket. + // + // One side takes a passive role and the other side takes an active + // role in establishing the connection. The passive role means + // waiting for an incoming connection that identifies itself with a + // specific sequence number (encoded in the `Address`). The active + // role means creating a connection to a specific address, and + // writing out a specific sequence number. Once the process for + // either role succeeds, the connection callback for the pair gets + // called with the file descriptor for the underlying connection. + // + device_->connect( + self_, + peer, + timeout_, + std::bind( + &Pair::connectCallback, + this, + std::placeholders::_1, + std::placeholders::_2)); + + // Wait for connection to be made. + // + // NOTE(pietern): This can be split out to a separate function so + // that we first initiate all connections and then wait on all of + // them. This should make context initialization a bit faster. It + // requires a change to the base class though, so let's so it after + // this new transport has been merged. + // + waitUntilConnected(lock, true); +} + +void Pair::connectCallback(std::shared_ptr socket, Error error) { + std::lock_guard lock(m_); + if (error) { + signalException(GLOO_ERROR_MSG(error.what())); + return; + } + + // Finalize setup. + socket->block(false); + socket->noDelay(true); + socket->sendTimeout(timeout_); + socket->recvTimeout(timeout_); + + // Reset addresses. + self_ = socket->sockName(); + peer_ = socket->peerName(); + + // Take over ownership of the socket's file descriptor. The code in + // this class works directly with file descriptor directly. + fd_ = socket->release(); + + // Register with loop for socket readability. + device_->registerDescriptor(fd_, EPOLLIN, this); + + // We're done: update state and wake up waiting threads. + changeState(CONNECTED); } static void setSocketBlocking(int fd, bool enable) { @@ -152,133 +215,6 @@ void Pair::setSync(bool sync, bool busyPoll) { busyPoll_ = busyPoll; } -void Pair::listen() { - std::lock_guard lock(m_); - int rv; - - const auto& attr = device_->attr_; - auto fd = socket(attr.ai_family, attr.ai_socktype, attr.ai_protocol); - if (fd == -1) { - signalAndThrowException(GLOO_ERROR_MSG("socket: ", strerror(errno))); - } - - // Set SO_REUSEADDR to signal that reuse of the listening port is OK. - int on = 1; - rv = setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); - if (rv == -1) { - ::close(fd); - signalAndThrowException(GLOO_ERROR_MSG("setsockopt: ", strerror(errno))); - } - - rv = bind(fd, (const sockaddr*)&attr.ai_addr, attr.ai_addrlen); - if (rv == -1) { - ::close(fd); - signalAndThrowException(GLOO_ERROR_MSG("bind: ", strerror(errno))); - } - - // listen(2) on socket - fd_ = fd; - rv = ::listen(fd_, 1024); - if (rv == -1) { - ::close(fd_); - fd_ = FD_INVALID; - signalAndThrowException(GLOO_ERROR_MSG("listen: ", strerror(errno))); - } - - // Keep copy of address - self_ = Address::fromSockName(fd); - - // Register with device so we're called when peer connects - changeState(LISTENING); - device_->registerDescriptor(fd_, EPOLLIN, this); - - return; -} - -void Pair::connect(const Address& peer) { - std::unique_lock lock(m_); - int rv; - socklen_t addrlen; - throwIfException(); - - peer_ = peer; - - const auto& selfAddr = self_.getSockaddr(); - const auto& peerAddr = peer_.getSockaddr(); - - // Addresses have to have same family - if (selfAddr.ss_family != peerAddr.ss_family) { - GLOO_THROW_INVALID_OPERATION_EXCEPTION("address family mismatch"); - } - - if (selfAddr.ss_family == AF_INET) { - struct sockaddr_in* sa = (struct sockaddr_in*)&selfAddr; - struct sockaddr_in* sb = (struct sockaddr_in*)&peerAddr; - addrlen = sizeof(struct sockaddr_in); - rv = memcmp(&sa->sin_addr, &sb->sin_addr, sizeof(struct in_addr)); - if (rv == 0) { - rv = sa->sin_port - sb->sin_port; - } - } else if (peerAddr.ss_family == AF_INET6) { - struct sockaddr_in6* sa = (struct sockaddr_in6*)&selfAddr; - struct sockaddr_in6* sb = (struct sockaddr_in6*)&peerAddr; - addrlen = sizeof(struct sockaddr_in6); - rv = memcmp(&sa->sin6_addr, &sb->sin6_addr, sizeof(struct in6_addr)); - if (rv == 0) { - rv = sa->sin6_port - sb->sin6_port; - } - } else { - GLOO_THROW_INVALID_OPERATION_EXCEPTION("unknown sa_family"); - } - - if (rv == 0) { - GLOO_THROW_INVALID_OPERATION_EXCEPTION("cannot connect to self"); - } - - is_client_ = rv > 0; - - // self_ < peer_; we are listening side. - if (!is_client_) { - waitUntilConnected(lock, true); - return; - } - - // self_ > peer_; we are connecting side. - // First destroy listening socket. - device_->unregisterDescriptor(fd_, this); - ::close(fd_); - - // Create new socket to connect to peer. - fd_ = socket(peerAddr.ss_family, SOCK_STREAM | SOCK_NONBLOCK, 0); - if (fd_ == -1) { - signalAndThrowException(GLOO_ERROR_MSG("socket: ", strerror(errno))); - } - - // Set SO_REUSEADDR to signal that reuse of the source port is OK. - int on = 1; - rv = setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); - if (rv == -1) { - ::close(fd_); - fd_ = FD_INVALID; - signalAndThrowException(GLOO_ERROR_MSG("setsockopt: ", strerror(errno))); - } - - // Connect to peer - rv = ::connect(fd_, (struct sockaddr*)&peerAddr, addrlen); - if (rv == -1 && errno != EINPROGRESS) { - ::close(fd_); - fd_ = FD_INVALID; - signalAndThrowException(GLOO_ERROR_MSG("connect: ", strerror(errno))); - } - - // Register with device so we're called when connection completes. - changeState(CONNECTING); - device_->registerDescriptor(fd_, EPOLLIN | EPOLLOUT, this); - - // Wait for connection to complete - waitUntilConnected(lock, true); -} - ssize_t Pair::prepareWrite( Op& op, const NonOwningPtr& buf, @@ -723,16 +659,6 @@ void Pair::handleEvents(int events) { return; } - if (state_ == LISTENING) { - handleListening(); - return; - } - - if (state_ == CONNECTING) { - handleConnecting(); - return; - } - GLOO_ENFORCE(false, "Unexpected state: ", state_); } @@ -761,77 +687,6 @@ void Pair::handleReadWrite(int events) { } } -void Pair::handleListening() { - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - int rv; - - rv = accept(fd_, (struct sockaddr*)&addr, &addrlen); - - // Close the listening file descriptor whether we've successfully connected - // or run into an error and will throw an exception. - device_->unregisterDescriptor(fd_, this); - ::close(fd_); - fd_ = FD_INVALID; - - if (rv == -1) { - signalException(GLOO_ERROR_MSG("accept: ", strerror(errno))); - return; - } - - // Connected, replace file descriptor - fd_ = rv; - - // Common connection-made code - handleConnected(); -} - -void Pair::handleConnecting() { - int optval; - socklen_t optlen = sizeof(optval); - int rv; - - // Verify that connecting was successful - rv = getsockopt(fd_, SOL_SOCKET, SO_ERROR, &optval, &optlen); - GLOO_ENFORCE_NE(rv, -1); - if (optval != 0) { - signalException( - GLOO_ERROR_MSG("connect ", peer_.str(), ": ", strerror(optval))); - return; - } - - // Common connection-made code - handleConnected(); -} - -void Pair::handleConnected() { - int rv; - - // Reset addresses - self_ = Address::fromSockName(fd_); - peer_ = Address::fromPeerName(fd_); - - // Make sure socket is non-blocking - setSocketBlocking(fd_, false); - - int flag = 1; - socklen_t optlen = sizeof(flag); - rv = setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, optlen); - GLOO_ENFORCE_NE(rv, -1); - - // Set timeout - struct timeval tv = {}; - tv.tv_sec = timeout_.count() / 1000; - tv.tv_usec = (timeout_.count() % 1000) * 1000; - rv = setsockopt(fd_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); - GLOO_ENFORCE_NE(rv, -1); - rv = setsockopt(fd_, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)); - GLOO_ENFORCE_NE(rv, -1); - - device_->registerDescriptor(fd_, EPOLLIN, this); - changeState(CONNECTED); -} - // getBuffer must only be called when holding lock. Buffer* Pair::getBuffer(int slot) { for (;;) { @@ -874,19 +729,9 @@ void Pair::changeState(state nextState) noexcept { if (nextState == CLOSED) { switch (state_) { case INITIALIZING: - // This state persists from construction up to the point where - // Pair::listen sets fd_ and calls listen(2). If this fails, - // it takes care of cleaning up the socket itself. + // Initial state upon construction. // There is no additional cleanup needed here. break; - case LISTENING: - // The pair may be in the LISTENING state when it is destructed. - if (fd_ != FD_INVALID) { - device_->unregisterDescriptor(fd_, this); - ::close(fd_); - fd_ = FD_INVALID; - } - break; case CONNECTING: // The pair may be in the CONNECTING state when it is destructed. if (fd_ != FD_INVALID) { diff --git a/gloo/transport/tcp/pair.h b/gloo/transport/tcp/pair.h index a81270061..033301a0e 100644 --- a/gloo/transport/tcp/pair.h +++ b/gloo/transport/tcp/pair.h @@ -30,6 +30,8 @@ #include "gloo/transport/pair.h" #include "gloo/transport/tcp/address.h" #include "gloo/transport/tcp/device.h" +#include "gloo/transport/tcp/error.h" +#include "gloo/transport/tcp/socket.h" namespace gloo { namespace transport { @@ -83,7 +85,6 @@ class Pair : public ::gloo::transport::Pair, public Handler { protected: enum state { INITIALIZING = 1, - LISTENING = 2, CONNECTING = 3, CONNECTED = 4, CLOSED = 5, @@ -170,7 +171,6 @@ class Pair : public ::gloo::transport::Pair, public Handler { Address self_; Address peer_; - bool is_client_; std::mutex m_; std::condition_variable cv_; @@ -191,8 +191,7 @@ class Pair : public ::gloo::transport::Pair, public Handler { void sendNotifyRecvReady(uint64_t slot, size_t nbytes); void sendNotifySendReady(uint64_t slot, size_t nbytes); - void listen(); - void connect(const Address& peer); + void connectCallback(std::shared_ptr socket, Error error); Buffer* getBuffer(int slot); void registerBuffer(Buffer* buf); @@ -277,28 +276,6 @@ class Pair : public ::gloo::transport::Pair, public Handler { // virtual void handleReadWrite(int events); - // Finishes connection setup if this side of the pair is on the - // listening side of connection initiation. This is called from - // `handleEvents` if the listening file descriptor is readable, i.e. - // if there is an incoming connection. - // - // The pair mutex is expected to be held when called. - // - void handleListening(); - - // Finishes connection setup if this side of the pair is on the - // connecting side of the connection initiation. This is called from - // `handleEvents` if the file descriptor associated with the - // connection is writable or in an error state, i.e. the connection - // has been established or failed to establish. - // - // The pair mutex is expected to be held when called. - // - void handleConnecting(); - - // Helper function called from `handleListening` or `handleConnecting`. - void handleConnected(); - // Advances this pair's state. See the `Pair::state` enum for // possible states. State can only move forward, i.e. from // initializing, to connected, to closed. diff --git a/gloo/transport/tcp/tls/pair.cc b/gloo/transport/tcp/tls/pair.cc index 176fa5e91..ce49cf735 100644 --- a/gloo/transport/tcp/tls/pair.cc +++ b/gloo/transport/tcp/tls/pair.cc @@ -194,7 +194,7 @@ bool Pair::read() { } void Pair::handleReadWrite(int events) { - if (!is_ssl_connected_ && !is_client_) { + if (!is_ssl_connected_ && !device_->isInitiator(self_, peer_)) { if (ssl_ == nullptr) { GLOO_ENFORCE(ssl_ctx_ != nullptr); ssl_ = _glootls::SSL_new(ssl_ctx_); @@ -239,7 +239,7 @@ void Pair::waitUntilConnected(std::unique_lock &lock, ::gloo::transport::tcp::Pair::waitUntilConnected(lock, useTimeout); if (!is_ssl_connected_) { - if (is_client_) { + if (device_->isInitiator(self_, peer_)) { GLOO_ENFORCE(ssl_ == nullptr); GLOO_ENFORCE(ssl_ctx_ != nullptr); ssl_ = _glootls::SSL_new(ssl_ctx_);