diff --git a/gloo/common/logging.h b/gloo/common/logging.h index 2d4cce014..7f04700c3 100644 --- a/gloo/common/logging.h +++ b/gloo/common/logging.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -156,4 +157,7 @@ BINARY_COMP_HELPER(LessEquals, <=) #define GLOO_ENFORCE_GT(x, y, ...) \ GLOO_ENFORCE_THAT_IMPL(Greater((x), (y)), #x " > " #y, __VA_ARGS__) +#define GLOO_ERROR(...) \ + std::cerr << "Gloo error: " << ::gloo::MakeString(__VA_ARGS__) << std::endl + } // namespace gloo diff --git a/gloo/common/utils.cc b/gloo/common/utils.cc index d543d2a0f..ab075776c 100644 --- a/gloo/common/utils.cc +++ b/gloo/common/utils.cc @@ -42,4 +42,14 @@ bool isStoreExtendedApiEnabled() { (std::string(res) == "True" || std::string(res) == "1"); } +bool disableConnectionRetries() { + // use meyer singleton to only compute this exactly once. + static bool disable = []() { + const auto& res = std::getenv("GLOO_DISABLE_CONNECTION_RETRIES"); + return res != nullptr && + (std::string(res) == "True" || std::string(res) == "1"); + }(); + return disable; +} + } // namespace gloo diff --git a/gloo/common/utils.h b/gloo/common/utils.h index 185ebaf19..e7b477df9 100644 --- a/gloo/common/utils.h +++ b/gloo/common/utils.h @@ -18,4 +18,6 @@ bool useRankAsSeqNumber(); bool isStoreExtendedApiEnabled(); +bool disableConnectionRetries(); + } // namespace gloo diff --git a/gloo/test/CMakeLists.txt b/gloo/test/CMakeLists.txt index ea47fa238..bafe449fa 100644 --- a/gloo/test/CMakeLists.txt +++ b/gloo/test/CMakeLists.txt @@ -24,6 +24,7 @@ if(${CMAKE_SYSTEM_NAME} STREQUAL "Linux") "${CMAKE_CURRENT_SOURCE_DIR}/linux_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/multiproc_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/transport_test.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/tcp_test.cc" ) list(APPEND GLOO_TEST_LIBRARIES rt) endif() diff --git a/gloo/test/tcp_test.cc b/gloo/test/tcp_test.cc new file mode 100644 index 000000000..33e34f0b6 --- /dev/null +++ b/gloo/test/tcp_test.cc @@ -0,0 +1,36 @@ +#include + +#include +#include + +namespace gloo { +namespace transport { +namespace tcp { + +TEST(TcpTest, ConnectTimeout) { + auto loop = std::make_shared(); + + std::mutex m; + std::condition_variable cv; + bool done = false; + + // Use bad address + auto remote = Address("::1", 10); + auto timeout = std::chrono::milliseconds(100); + auto fn = [&](std::shared_ptr, const Error& e) { + std::lock_guard lock(m); + done = true; + cv.notify_all(); + + EXPECT_TRUE(e); + EXPECT_TRUE(dynamic_cast(&e)); + }; + connectLoop(loop, remote, timeout, std::move(fn)); + + std::unique_lock lock(m); + cv.wait(lock, [&] { return done; }); +} + +} // namespace tcp +} // namespace transport +} // namespace gloo diff --git a/gloo/transport/tcp/CMakeLists.txt b/gloo/transport/tcp/CMakeLists.txt index 9cb6535af..206fd3e09 100644 --- a/gloo/transport/tcp/CMakeLists.txt +++ b/gloo/transport/tcp/CMakeLists.txt @@ -7,6 +7,7 @@ else() "${CMAKE_CURRENT_SOURCE_DIR}/context.cc" "${CMAKE_CURRENT_SOURCE_DIR}/device.cc" "${CMAKE_CURRENT_SOURCE_DIR}/error.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/helpers.cc" "${CMAKE_CURRENT_SOURCE_DIR}/listener.cc" "${CMAKE_CURRENT_SOURCE_DIR}/loop.cc" "${CMAKE_CURRENT_SOURCE_DIR}/pair.cc" diff --git a/gloo/transport/tcp/address.cc b/gloo/transport/tcp/address.cc index 0b2e976df..e37f7b493 100644 --- a/gloo/transport/tcp/address.cc +++ b/gloo/transport/tcp/address.cc @@ -28,6 +28,29 @@ Address::Address(const struct sockaddr* addr, size_t addrlen) { memcpy(&impl_.ss, addr, addrlen); } +Address::Address(const std::string& ip, uint16_t port, sequence_number_t seq) { + if (ip.empty()) { + throw std::invalid_argument("Invalid IP address"); + } + sockaddr_in* addr4 = reinterpret_cast(&impl_.ss); + sockaddr_in6* addr6 = reinterpret_cast(&impl_.ss); + // Check if the IP address is an IPv4 or IPv6 address + if (inet_pton(AF_INET, ip.c_str(), &addr4->sin_addr) == 1) { + // IPv4 address + addr4->sin_family = AF_INET; + addr4->sin_port = htons(port); + } else if (inet_pton(AF_INET6, ip.c_str(), &addr6->sin6_addr) == 1) { + // IPv6 address + addr6->sin6_family = AF_INET6; + addr6->sin6_port = htons(port); + } else { + throw std::invalid_argument("Invalid IP address"); + } + + // Store sequence number + impl_.seq = seq; +} + Address& Address::operator=(Address&& other) { std::lock_guard lock(m_); impl_.ss = std::move(other.impl_.ss); diff --git a/gloo/transport/tcp/address.h b/gloo/transport/tcp/address.h index 19bae6fcc..bc6c98c10 100644 --- a/gloo/transport/tcp/address.h +++ b/gloo/transport/tcp/address.h @@ -8,10 +8,14 @@ #pragma once -#include -#include #include +#ifdef _WIN32 +#include "gloo/common/win.h" // @manual +#else +#include +#endif + #include "gloo/transport/address.h" namespace gloo { @@ -32,6 +36,11 @@ class Address : public ::gloo::transport::Address { explicit Address(const std::vector&); + explicit Address( + const std::string& ip, + uint16_t port, + sequence_number_t seq = -1); + Address(const Address& other); Address& operator=(Address&& other); diff --git a/gloo/transport/tcp/device.cc b/gloo/transport/tcp/device.cc index d7e725e74..f3f2b950f 100644 --- a/gloo/transport/tcp/device.cc +++ b/gloo/transport/tcp/device.cc @@ -17,6 +17,7 @@ #include "gloo/common/error.h" #include "gloo/common/linux.h" #include "gloo/common/logging.h" +#include "gloo/common/utils.h" #include "gloo/transport/tcp/context.h" #include "gloo/transport/tcp/helpers.h" #include "gloo/transport/tcp/pair.h" @@ -334,20 +335,39 @@ void Device::connectAsListener( // void Device::connectAsInitiator( const Address& remote, - std::chrono::milliseconds /* unused */, + std::chrono::milliseconds timeout, connect_callback_t fn) { - const auto& sockaddr = remote.getSockaddr(); - - // Create new socket to connect to peer. - auto socket = Socket::createForFamily(sockaddr.ss_family); - socket->reuseAddr(true); - socket->noDelay(true); - socket->connect(sockaddr); - - // Write sequence number for peer to new socket. - // TODO(pietern): Use timeout. - write( - loop_, std::move(socket), remote.getSeq(), std::move(fn)); + auto writeSeq = [loop = loop_, seq = remote.getSeq()]( + std::shared_ptr socket, connect_callback_t fn) { + // Write sequence number for peer to new socket. + write(loop, std::move(socket), seq, std::move(fn)); + }; + + if (disableConnectionRetries()) { + const auto& sockaddr = remote.getSockaddr(); + + // Create new socket to connect to peer. + auto socket = Socket::createForFamily(sockaddr.ss_family); + socket->reuseAddr(true); + socket->noDelay(true); + socket->connect(sockaddr); + + writeSeq(std::move(socket), std::move(fn)); + } else { + connectLoop( + loop_, + remote, + timeout, + [loop = loop_, fn = std::move(fn), writeSeq = std::move(writeSeq)]( + std::shared_ptr socket, const Error& error) { + if (error) { + fn(socket, error); + return; + } + + writeSeq(std::move(socket), std::move(fn)); + }); + } } } // namespace tcp diff --git a/gloo/transport/tcp/error.cc b/gloo/transport/tcp/error.cc index 43d105b23..c36ded046 100644 --- a/gloo/transport/tcp/error.cc +++ b/gloo/transport/tcp/error.cc @@ -23,24 +23,32 @@ std::string Error::what() const { std::string SystemError::what() const { std::ostringstream ss; - ss << syscall_ << ": " << strerror(error_); + ss << syscall_ << ": " << strerror(error_) << ", remote=" << remote_.str(); return ss.str(); } std::string ShortReadError::what() const { std::ostringstream ss; ss << "short read: got " << actual_ << " bytes while expecting to read " - << expected_ << " bytes"; + << expected_ << " bytes, remote=" << remote_.str(); return ss.str(); } std::string ShortWriteError::what() const { std::ostringstream ss; ss << "short write: wrote " << actual_ << " bytes while expecting to write " - << expected_ << " bytes"; + << expected_ << " bytes, remote=" << remote_.str(); return ss.str(); } +std::string TimeoutError::what() const { + return msg_; +} + +std::string LoopError::what() const { + return msg_; +} + } // namespace tcp } // namespace transport } // namespace gloo diff --git a/gloo/transport/tcp/error.h b/gloo/transport/tcp/error.h index 48c3f07a9..08b244569 100644 --- a/gloo/transport/tcp/error.h +++ b/gloo/transport/tcp/error.h @@ -8,6 +8,7 @@ #pragma once +#include #include namespace gloo { @@ -52,38 +53,70 @@ class Error { class SystemError : public Error { public: - explicit SystemError(const char* syscall, int error) - : Error(true), syscall_(syscall), error_(error) {} + explicit SystemError(const char* syscall, int error, Address remote) + : Error(true), + syscall_(syscall), + error_(error), + remote_(std::move(remote)) {} std::string what() const override; private: const char* syscall_; const int error_; + const Address remote_; }; class ShortReadError : public Error { public: - ShortReadError(ssize_t expected, ssize_t actual) - : Error(true), expected_(expected), actual_(actual) {} + ShortReadError(ssize_t expected, ssize_t actual, Address remote) + : Error(true), + expected_(expected), + actual_(actual), + remote_(std::move(remote)) {} std::string what() const override; private: const ssize_t expected_; const ssize_t actual_; + const Address remote_; }; class ShortWriteError : public Error { public: - ShortWriteError(ssize_t expected, ssize_t actual) - : Error(true), expected_(expected), actual_(actual) {} + ShortWriteError(ssize_t expected, ssize_t actual, Address remote) + : Error(true), + expected_(expected), + actual_(actual), + remote_(std::move(remote)) {} std::string what() const override; private: const ssize_t expected_; const ssize_t actual_; + const Address remote_; +}; + +class TimeoutError : public Error { + public: + explicit TimeoutError(std::string msg) : Error(true), msg_(std::move(msg)) {} + + std::string what() const override; + + private: + const std::string msg_; +}; + +class LoopError : public Error { + public: + explicit LoopError(std::string msg) : Error(true), msg_(std::move(msg)) {} + + std::string what() const override; + + private: + const std::string msg_; }; } // namespace tcp diff --git a/gloo/transport/tcp/helpers.cc b/gloo/transport/tcp/helpers.cc new file mode 100644 index 000000000..d30a813e5 --- /dev/null +++ b/gloo/transport/tcp/helpers.cc @@ -0,0 +1,19 @@ +#include + +namespace gloo { +namespace transport { +namespace tcp { + +void connectLoop( + std::shared_ptr loop, + const Address& remote, + std::chrono::milliseconds timeout, + typename ConnectOperation::callback_t fn) { + auto x = std::make_shared( + std::move(loop), remote, timeout, std::move(fn)); + x->run(); +} + +} // namespace tcp +} // namespace transport +} // namespace gloo diff --git a/gloo/transport/tcp/helpers.h b/gloo/transport/tcp/helpers.h index 81c3a8df0..af6da8d37 100644 --- a/gloo/transport/tcp/helpers.h +++ b/gloo/transport/tcp/helpers.h @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -59,13 +60,17 @@ class ReadValueOperation final // Read T. auto rv = socket_->read(&t_, sizeof(t_)); if (rv == -1) { - fn_(socket_, SystemError("read", errno), std::move(t_)); + fn_(socket_, + SystemError("read", errno, socket_->peerName()), + std::move(t_)); return; } // Check for short read (assume we can read in a single call). if (rv < sizeof(t_)) { - fn_(socket_, ShortReadError(rv, sizeof(t_)), std::move(t_)); + fn_(socket_, + ShortReadError(rv, sizeof(t_), socket_->peerName()), + std::move(t_)); return; } @@ -133,13 +138,13 @@ class WriteValueOperation final // Write T. auto rv = socket_->write(&t_, sizeof(t_)); if (rv == -1) { - fn_(socket_, SystemError("write", errno)); + fn_(socket_, SystemError("write", errno, socket_->peerName())); return; } // Check for short write (assume we can write in a single call). if (rv < sizeof(t_)) { - fn_(socket_, ShortWriteError(rv, sizeof(t_))); + fn_(socket_, ShortWriteError(rv, sizeof(t_), socket_->peerName())); return; } @@ -166,6 +171,107 @@ void write( x->run(); } +class ConnectOperation final + : public Handler, + public std::enable_shared_from_this { + public: + using callback_t = + std::function, const Error& error)>; + ConnectOperation( + std::shared_ptr loop, + const Address& remote, + std::chrono::milliseconds timeout, + callback_t fn) + : remote_(remote), + deadline_(std::chrono::steady_clock::now() + timeout), + loop_(std::move(loop)), + fn_(std::move(fn)) {} + + void run() { + // Cannot initialize leak until after the object has been + // constructed, because the std::make_shared initialization + // doesn't run after construction of the underlying object. + leak_ = this->shared_from_this(); + + const auto& sockaddr = remote_.getSockaddr(); + + // Create new socket to connect to peer. + socket_ = Socket::createForFamily(sockaddr.ss_family); + socket_->reuseAddr(true); + socket_->noDelay(true); + socket_->connect(sockaddr); + + // Register with loop only after we've leaked the shared_ptr, + // because we unleak it when the event loop thread calls. + // Register for EPOLLOUT, because we want to be notified when + // the connect completes. EPOLLERR is also necessary because + // connect() can fail. + if (auto loop = loop_.lock()) { + loop->registerDescriptor( + socket_->fd(), EPOLLOUT | EPOLLERR | EPOLLONESHOT, this); + } else { + fn_(socket_, LoopError("loop is gone")); + } + } + + void handleEvents(int events) override { + // Move leaked shared_ptr to the stack so that this object + // destroys itself once this function returns. + auto leak = std::move(this->leak_); + + int result; + socklen_t result_len = sizeof(result); + if (getsockopt(socket_->fd(), SOL_SOCKET, SO_ERROR, &result, &result_len) < + 0) { + fn_(socket_, SystemError("getsockopt", errno, remote_)); + return; + } + if (result != 0) { + SystemError e("SO_ERROR", result, remote_); + bool willRetry = std::chrono::steady_clock::now() < deadline_ && + retry_++ < maxRetries_; + GLOO_ERROR( + "failed to connect, willRetry=", + willRetry, + ", retry=", + retry_, + ", remote=", + remote_.str(), + ", error=", + e.what()); + // check deadline + if (willRetry) { + run(); + } else { + fn_(socket_, TimeoutError("timed out connecting: " + e.what())); + } + return; + } + + fn_(socket_, Error::kSuccess); + } + + private: + const Address remote_; + const std::chrono::time_point deadline_; + const int maxRetries_{3}; + + int retry_{0}; + + // We use a weak_ptr to the loop to avoid a reference cycle when an error + // occurs. + std::weak_ptr loop_; + std::shared_ptr socket_; + callback_t fn_; + std::shared_ptr leak_; +}; + +void connectLoop( + std::shared_ptr loop, + const Address& remote, + std::chrono::milliseconds timeout, + typename ConnectOperation::callback_t fn); + } // namespace tcp } // namespace transport } // namespace gloo