Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions gloo/common/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <climits>
#include <exception>
#include <functional>
#include <iostream>
#include <limits>
#include <vector>

Expand Down Expand Up @@ -156,4 +157,7 @@ BINARY_COMP_HELPER(LessEquals, <=)
#define GLOO_ENFORCE_GT(x, y, ...) \
GLOO_ENFORCE_THAT_IMPL(Greater((x), (y)), #x " > " #y, __VA_ARGS__)

#define GLOO_ERROR(...) \
std::cerr << "Gloo error: " << ::gloo::MakeString(__VA_ARGS__) << std::endl

} // namespace gloo
10 changes: 10 additions & 0 deletions gloo/common/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,14 @@ bool isStoreExtendedApiEnabled() {
(std::string(res) == "True" || std::string(res) == "1");
}

bool disableConnectionRetries() {
// use meyer singleton to only compute this exactly once.
static bool disable = []() {
const auto& res = std::getenv("GLOO_DISABLE_CONNECTION_RETRIES");
return res != nullptr &&
(std::string(res) == "True" || std::string(res) == "1");
}();
return disable;
}

} // namespace gloo
2 changes: 2 additions & 0 deletions gloo/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ bool useRankAsSeqNumber();

bool isStoreExtendedApiEnabled();

bool disableConnectionRetries();

} // namespace gloo
1 change: 1 addition & 0 deletions gloo/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ if(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
"${CMAKE_CURRENT_SOURCE_DIR}/linux_test.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/multiproc_test.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/transport_test.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/tcp_test.cc"
)
list(APPEND GLOO_TEST_LIBRARIES rt)
endif()
Expand Down
36 changes: 36 additions & 0 deletions gloo/test/tcp_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include <gtest/gtest.h>

#include <gloo/transport/tcp/helpers.h>
#include <gloo/transport/tcp/loop.h>

namespace gloo {
namespace transport {
namespace tcp {

TEST(TcpTest, ConnectTimeout) {
auto loop = std::make_shared<Loop>();

std::mutex m;
std::condition_variable cv;
bool done = false;

// Use bad address
auto remote = Address("::1", 10);
auto timeout = std::chrono::milliseconds(100);
auto fn = [&](std::shared_ptr<Socket>, const Error& e) {
std::lock_guard<std::mutex> lock(m);
done = true;
cv.notify_all();

EXPECT_TRUE(e);
EXPECT_TRUE(dynamic_cast<const TimeoutError*>(&e));
};
connectLoop(loop, remote, timeout, std::move(fn));

std::unique_lock<std::mutex> lock(m);
cv.wait(lock, [&] { return done; });
}

} // namespace tcp
} // namespace transport
} // namespace gloo
1 change: 1 addition & 0 deletions gloo/transport/tcp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ else()
"${CMAKE_CURRENT_SOURCE_DIR}/context.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/device.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/error.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/helpers.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/listener.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/loop.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/pair.cc"
Expand Down
23 changes: 23 additions & 0 deletions gloo/transport/tcp/address.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,29 @@ Address::Address(const struct sockaddr* addr, size_t addrlen) {
memcpy(&impl_.ss, addr, addrlen);
}

Address::Address(const std::string& ip, uint16_t port, sequence_number_t seq) {
if (ip.empty()) {
throw std::invalid_argument("Invalid IP address");
}
sockaddr_in* addr4 = reinterpret_cast<sockaddr_in*>(&impl_.ss);
sockaddr_in6* addr6 = reinterpret_cast<sockaddr_in6*>(&impl_.ss);
// Check if the IP address is an IPv4 or IPv6 address
if (inet_pton(AF_INET, ip.c_str(), &addr4->sin_addr) == 1) {
// IPv4 address
addr4->sin_family = AF_INET;
addr4->sin_port = htons(port);
} else if (inet_pton(AF_INET6, ip.c_str(), &addr6->sin6_addr) == 1) {
// IPv6 address
addr6->sin6_family = AF_INET6;
addr6->sin6_port = htons(port);
} else {
throw std::invalid_argument("Invalid IP address");
}

// Store sequence number
impl_.seq = seq;
}

Address& Address::operator=(Address&& other) {
std::lock_guard<std::mutex> lock(m_);
impl_.ss = std::move(other.impl_.ss);
Expand Down
13 changes: 11 additions & 2 deletions gloo/transport/tcp/address.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@

#pragma once

#include <sys/socket.h>
#include <unistd.h>
#include <mutex>

#ifdef _WIN32
#include "gloo/common/win.h" // @manual
#else
#include <sys/socket.h>
#endif

#include "gloo/transport/address.h"

namespace gloo {
Expand All @@ -32,6 +36,11 @@ class Address : public ::gloo::transport::Address {

explicit Address(const std::vector<char>&);

explicit Address(
const std::string& ip,
uint16_t port,
sequence_number_t seq = -1);

Address(const Address& other);

Address& operator=(Address&& other);
Expand Down
46 changes: 33 additions & 13 deletions gloo/transport/tcp/device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "gloo/common/error.h"
#include "gloo/common/linux.h"
#include "gloo/common/logging.h"
#include "gloo/common/utils.h"
#include "gloo/transport/tcp/context.h"
#include "gloo/transport/tcp/helpers.h"
#include "gloo/transport/tcp/pair.h"
Expand Down Expand Up @@ -334,20 +335,39 @@ void Device::connectAsListener(
//
void Device::connectAsInitiator(
const Address& remote,
std::chrono::milliseconds /* unused */,
std::chrono::milliseconds timeout,
connect_callback_t fn) {
const auto& sockaddr = remote.getSockaddr();

// Create new socket to connect to peer.
auto socket = Socket::createForFamily(sockaddr.ss_family);
socket->reuseAddr(true);
socket->noDelay(true);
socket->connect(sockaddr);

// Write sequence number for peer to new socket.
// TODO(pietern): Use timeout.
write<sequence_number_t>(
loop_, std::move(socket), remote.getSeq(), std::move(fn));
auto writeSeq = [loop = loop_, seq = remote.getSeq()](
std::shared_ptr<Socket> socket, connect_callback_t fn) {
// Write sequence number for peer to new socket.
write<sequence_number_t>(loop, std::move(socket), seq, std::move(fn));
};

if (disableConnectionRetries()) {
const auto& sockaddr = remote.getSockaddr();

// Create new socket to connect to peer.
auto socket = Socket::createForFamily(sockaddr.ss_family);
socket->reuseAddr(true);
socket->noDelay(true);
socket->connect(sockaddr);

writeSeq(std::move(socket), std::move(fn));
} else {
connectLoop(
loop_,
remote,
timeout,
[loop = loop_, fn = std::move(fn), writeSeq = std::move(writeSeq)](
std::shared_ptr<Socket> socket, const Error& error) {
if (error) {
fn(socket, error);
return;
}

writeSeq(std::move(socket), std::move(fn));
});
}
}

} // namespace tcp
Expand Down
14 changes: 11 additions & 3 deletions gloo/transport/tcp/error.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,32 @@ std::string Error::what() const {

std::string SystemError::what() const {
std::ostringstream ss;
ss << syscall_ << ": " << strerror(error_);
ss << syscall_ << ": " << strerror(error_) << ", remote=" << remote_.str();
return ss.str();
}

std::string ShortReadError::what() const {
std::ostringstream ss;
ss << "short read: got " << actual_ << " bytes while expecting to read "
<< expected_ << " bytes";
<< expected_ << " bytes, remote=" << remote_.str();
return ss.str();
}

std::string ShortWriteError::what() const {
std::ostringstream ss;
ss << "short write: wrote " << actual_ << " bytes while expecting to write "
<< expected_ << " bytes";
<< expected_ << " bytes, remote=" << remote_.str();
return ss.str();
}

std::string TimeoutError::what() const {
return msg_;
}

std::string LoopError::what() const {
return msg_;
}

} // namespace tcp
} // namespace transport
} // namespace gloo
45 changes: 39 additions & 6 deletions gloo/transport/tcp/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#pragma once

#include <gloo/transport/tcp/address.h>
#include <string>

namespace gloo {
Expand Down Expand Up @@ -52,38 +53,70 @@ class Error {

class SystemError : public Error {
public:
explicit SystemError(const char* syscall, int error)
: Error(true), syscall_(syscall), error_(error) {}
explicit SystemError(const char* syscall, int error, Address remote)
: Error(true),
syscall_(syscall),
error_(error),
remote_(std::move(remote)) {}

std::string what() const override;

private:
const char* syscall_;
const int error_;
const Address remote_;
};

class ShortReadError : public Error {
public:
ShortReadError(ssize_t expected, ssize_t actual)
: Error(true), expected_(expected), actual_(actual) {}
ShortReadError(ssize_t expected, ssize_t actual, Address remote)
: Error(true),
expected_(expected),
actual_(actual),
remote_(std::move(remote)) {}

std::string what() const override;

private:
const ssize_t expected_;
const ssize_t actual_;
const Address remote_;
};

class ShortWriteError : public Error {
public:
ShortWriteError(ssize_t expected, ssize_t actual)
: Error(true), expected_(expected), actual_(actual) {}
ShortWriteError(ssize_t expected, ssize_t actual, Address remote)
: Error(true),
expected_(expected),
actual_(actual),
remote_(std::move(remote)) {}

std::string what() const override;

private:
const ssize_t expected_;
const ssize_t actual_;
const Address remote_;
};

class TimeoutError : public Error {
public:
explicit TimeoutError(std::string msg) : Error(true), msg_(std::move(msg)) {}

std::string what() const override;

private:
const std::string msg_;
};

class LoopError : public Error {
public:
explicit LoopError(std::string msg) : Error(true), msg_(std::move(msg)) {}

std::string what() const override;

private:
const std::string msg_;
};

} // namespace tcp
Expand Down
19 changes: 19 additions & 0 deletions gloo/transport/tcp/helpers.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include <gloo/transport/tcp/helpers.h>

namespace gloo {
namespace transport {
namespace tcp {

void connectLoop(
std::shared_ptr<Loop> loop,
const Address& remote,
std::chrono::milliseconds timeout,
typename ConnectOperation::callback_t fn) {
auto x = std::make_shared<ConnectOperation>(
std::move(loop), remote, timeout, std::move(fn));
x->run();
}

} // namespace tcp
} // namespace transport
} // namespace gloo
Loading
Loading