Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions gloo/transport/tcp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ else()
"${CMAKE_CURRENT_SOURCE_DIR}/context.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/device.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/error.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/listener.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/loop.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/pair.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/socket.cc"
Expand All @@ -19,6 +20,8 @@ else()
"${CMAKE_CURRENT_SOURCE_DIR}/context.h"
"${CMAKE_CURRENT_SOURCE_DIR}/device.h"
"${CMAKE_CURRENT_SOURCE_DIR}/error.h"
"${CMAKE_CURRENT_SOURCE_DIR}/helpers.h"
"${CMAKE_CURRENT_SOURCE_DIR}/listener.h"
"${CMAKE_CURRENT_SOURCE_DIR}/loop.h"
"${CMAKE_CURRENT_SOURCE_DIR}/pair.h"
"${CMAKE_CURRENT_SOURCE_DIR}/socket.h"
Expand Down
101 changes: 101 additions & 0 deletions gloo/transport/tcp/device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "gloo/common/logging.h"
#include "gloo/common/error.h"
#include "gloo/transport/tcp/context.h"
#include "gloo/transport/tcp/helpers.h"
#include "gloo/transport/tcp/pair.h"

namespace gloo {
Expand Down Expand Up @@ -217,6 +218,7 @@ const std::string sockaddrToInterfaceName(const struct attr& attr) {
Device::Device(const struct attr& attr)
: attr_(attr),
loop_(std::make_shared<Loop>()),
listener_(std::make_shared<Listener>(loop_, attr)),
interfaceName_(sockaddrToInterfaceName(attr_)),
interfaceSpeedMbps_(getInterfaceSpeedByName(interfaceName_)),
pciBusID_(interfaceToBusID(interfaceName_)) {
Expand Down Expand Up @@ -257,6 +259,105 @@ void Device::unregisterDescriptor(int fd, Handler* h) {
loop_->unregisterDescriptor(fd, h);
}

// Return a new `Address` for a pair by delegating to the listener
// owned by this device.
Address Device::nextAddress() {
  auto& listener = *listener_;
  return listener.nextAddress();
}

// Decide whether this side of a pair actively connects (initiator) or
// passively waits for the peer (listener).
//
// The remote side of a pair calls this with the same two addresses,
// but in reverse. There should only be a single connection between
// the two, so exactly one side must come out as the initiator. This
// is done by ordering the two addresses: first by IP address bytes,
// then by port, and finally by sequence number. Only the sign of the
// ordering matters, which is why address and port are compared as
// stored, without any byte order conversion.
//
// Returns true if `local` orders after `remote`. Enforces that both
// addresses share the same (known) address family and that they are
// not identical.
bool Device::isInitiator(
    const Address& local,
    const Address& remote) const {
  const auto& ss1 = local.getSockaddr();
  const auto& ss2 = remote.getSockaddr();
  GLOO_ENFORCE_EQ(ss1.ss_family, ss2.ss_family);
  const int family = ss1.ss_family;

  int rv = 0;
  if (family == AF_INET) {
    const auto* sa = reinterpret_cast<const struct sockaddr_in*>(&ss1);
    const auto* sb = reinterpret_cast<const struct sockaddr_in*>(&ss2);
    rv = memcmp(&sa->sin_addr, &sb->sin_addr, sizeof(struct in_addr));
    if (rv == 0) {
      rv = sa->sin_port - sb->sin_port;
    }
  } else if (family == AF_INET6) {
    const auto* sa = reinterpret_cast<const struct sockaddr_in6*>(&ss1);
    const auto* sb = reinterpret_cast<const struct sockaddr_in6*>(&ss2);
    rv = memcmp(&sa->sin6_addr, &sb->sin6_addr, sizeof(struct in6_addr));
    if (rv == 0) {
      rv = sa->sin6_port - sb->sin6_port;
    }
  } else {
    GLOO_ENFORCE(false, "Unknown address family: ", family);
  }

  // If both sides of the pair use the same address and port, they are
  // sharing the same device instance. This happens in tests. Compare
  // sequence number to allow pairs to connect.
  //
  // Compare instead of subtracting: the result must not depend on the
  // difference of two sequence numbers fitting in an `int`. A narrowed
  // difference could become zero or take a sign that the peer does
  // not mirror.
  if (rv == 0) {
    const auto localSeq = local.getSeq();
    const auto remoteSeq = remote.getSeq();
    if (localSeq != remoteSeq) {
      rv = (localSeq > remoteSeq) ? 1 : -1;
    }
  }
  GLOO_ENFORCE_NE(rv, 0, "Cannot connect to self");
  return rv > 0;
}

void Device::connect(
const Address& local,
const Address& remote,
std::chrono::milliseconds timeout,
connect_callback_t fn) {
auto initiator = isInitiator(local, remote);

if (initiator) {
connectAsInitiator(remote, timeout, std::move(fn));
return;
}
connectAsListener(local, timeout, std::move(fn));
}

// Passive side of connecting a pair.
//
// Nothing is sent from here. The connect callback is handed to the
// listener, keyed by the sequence number of the local address, so
// that it can be executed once the other side of the pair has
// connected and identified itself as destined for this address. If
// that connection already exists, the listener is expected to deal
// with it as part of this call.
//
void Device::connectAsListener(
    const Address& local,
    std::chrono::milliseconds /* unused */,
    connect_callback_t fn) {
  // TODO(pietern): Use timeout.
  const auto seq = local.getSeq();
  listener_->waitForConnection(seq, std::move(fn));
}

// Connecting as initiator is active.
//
// The connect callback is fired when the connection to the other side
// of the pair has been made, and the sequence number for this
// connection has been written. If an error occurs at any time, the
// callback is called with an associated error event.
//
void Device::connectAsInitiator(
const Address& remote,
std::chrono::milliseconds /* unused */,
connect_callback_t fn) {
const auto& sockaddr = remote.getSockaddr();

// Create new socket to connect to peer.
auto socket = Socket::createForFamily(sockaddr.ss_family);
socket->reuseAddr(true);
socket->noDelay(true);
socket->connect(sockaddr);

// Write sequence number for peer to new socket.
// TODO(pietern): Use timeout.
write<sequence_number_t>(
loop_, std::move(socket), remote.getSeq(), std::move(fn));
}

} // namespace tcp
} // namespace transport
} // namespace gloo
50 changes: 50 additions & 0 deletions gloo/transport/tcp/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

#include <gloo/transport/device.h>
#include <gloo/transport/tcp/attr.h>
#include <gloo/transport/tcp/error.h>
#include <gloo/transport/tcp/listener.h>
#include <gloo/transport/tcp/loop.h>
#include <gloo/transport/tcp/socket.h>

namespace gloo {
namespace transport {
Expand Down Expand Up @@ -50,14 +53,61 @@ class Device : public ::gloo::transport::Device,
void registerDescriptor(int fd, int events, Handler* h);
void unregisterDescriptor(int fd, Handler* h);

// TCP is bidirectional so when we connect two ends of a pair,
// one side is the connection initiator and the other is the listener.
bool isInitiator(
const Address& local,
const Address& remote) const;

protected:
const struct attr attr_;

// Return a new `Address` instance.
//
// This is called by the constructor of the `Pair` class. It gives
// the pair a uniquely identifying address even though the device
// uses a shared listening socket.
//
Address nextAddress();

// Connect a pair to a remote.
//
// This is performed by the device instance because we use a single
// listening socket for all inbound pair connections.
//
// Matching these connections with pairs is done with a handshake.
// The remote side of the connection writes a sequence number (see
// `Address::sequence_t`) to the stream that identifies the pair
// it wants to connect to. On the local side, this sequence number
// is read and used as key in a map with callbacks. If the callback
// is found, it is called. If the callback is not found, the
// connection is cached in a map, using the sequence number.
//
using connect_callback_t =
std::function<void(std::shared_ptr<Socket> socket, Error error)>;

void connect(
const Address& local,
const Address& remote,
std::chrono::milliseconds timeout,
connect_callback_t fn);

void connectAsListener(
const Address& local,
std::chrono::milliseconds timeout,
connect_callback_t fn);

void connectAsInitiator(
const Address& remote,
std::chrono::milliseconds timeout,
connect_callback_t fn);

friend class Pair;
friend class Buffer;

private:
std::shared_ptr<Loop> loop_;
std::shared_ptr<Listener> listener_;

std::string interfaceName_;
int interfaceSpeedMbps_;
Expand Down
171 changes: 171 additions & 0 deletions gloo/transport/tcp/helpers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/**
* Copyright (c) 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstddef>
#include <functional>
#include <memory>

#include <gloo/transport/tcp/error.h>
#include <gloo/transport/tcp/loop.h>
#include <gloo/transport/tcp/socket.h>

namespace gloo {
namespace transport {
namespace tcp {

// ReadValueOperation asynchronously reads a value of type T from the
// socket specified at construction. Upon completion or error, the
// callback is called. Its lifetime is coupled with completion of the
// operation, so the caller doesn't need to hold on to the instance.
// It does so by storing a shared_ptr to itself (effectively a leak)
// until the event loop calls back.
//
// The value is read as raw bytes straight into a T with a single
// read call; receiving fewer than sizeof(T) bytes is reported to the
// callback as a short read error.
template <typename T>
class ReadValueOperation final
    : public Handler,
      public std::enable_shared_from_this<ReadValueOperation<T>> {
 public:
  // Called exactly once per handled event with the socket, the
  // outcome, and the value that was read. On error the value is
  // whatever the buffer held at that point (zero-initialized bytes,
  // possibly partially overwritten by a short read).
  using callback_t =
      std::function<void(std::shared_ptr<Socket>, const Error& error, T&& t)>;

  ReadValueOperation(
      std::shared_ptr<Loop> loop,
      std::shared_ptr<Socket> socket,
      callback_t fn)
      : loop_(std::move(loop)),
        socket_(std::move(socket)),
        fn_(std::move(fn)) {}

  // Start the operation. Must be called on an instance that is
  // already owned by a shared_ptr (see the `read` helper).
  void run() {
    // The self reference cannot be taken in the constructor:
    // shared_from_this() only works once a shared_ptr owns the
    // object, which is the case only after std::make_shared returns.
    leak_ = this->shared_from_this();
    // Register with loop only after we've leaked the shared_ptr,
    // because we unleak it when the event loop thread calls.
    loop_->registerDescriptor(socket_->fd(), EPOLLIN | EPOLLONESHOT, this);
  }

  void handleEvents(int /* events */) override {
    // Move leaked shared_ptr to the stack so that this object
    // destroys itself once this function returns.
    auto self = std::move(this->leak_);

    // Read T.
    auto rv = socket_->read(&t_, sizeof(t_));
    if (rv == -1) {
      fn_(socket_, SystemError("read", errno), std::move(t_));
      return;
    }

    // Check for short read (assume we can read in a single call).
    // The error case was handled above, so rv is a byte count here
    // and the conversion to an unsigned type is made explicit.
    if (static_cast<std::size_t>(rv) < sizeof(t_)) {
      fn_(socket_, ShortReadError(rv, sizeof(t_)), std::move(t_));
      return;
    }

    fn_(socket_, Error::kSuccess, std::move(t_));
  }

 private:
  std::shared_ptr<Loop> loop_;
  std::shared_ptr<Socket> socket_;
  callback_t fn_;
  std::shared_ptr<ReadValueOperation<T>> leak_;

  // Value-initialized so the callback never receives indeterminate
  // bytes when the read fails or comes up short.
  T t_{};
};

template <typename T>
void read(
std::shared_ptr<Loop> loop,
std::shared_ptr<Socket> socket,
typename ReadValueOperation<T>::callback_t fn) {
auto x = std::make_shared<ReadValueOperation<T>>(
std::move(loop), std::move(socket), std::move(fn));
x->run();
}

// WriteValueOperation asynchronously writes a value of type T to the
// socket specified at construction. Upon completion or error, the
// callback is called. Its lifetime is coupled with completion of the
// operation, so the caller doesn't need to hold on to the instance.
// It does so by storing a shared_ptr to itself (effectively a leak)
// until the event loop calls back.
//
// The value is written as the raw bytes of a T with a single write
// call; writing fewer than sizeof(T) bytes is reported to the
// callback as a short write error.
template <typename T>
class WriteValueOperation final
    : public Handler,
      public std::enable_shared_from_this<WriteValueOperation<T>> {
 public:
  // Called exactly once per handled event with the socket and the
  // outcome of the write.
  using callback_t =
      std::function<void(std::shared_ptr<Socket>, const Error& error)>;

  WriteValueOperation(
      std::shared_ptr<Loop> loop,
      std::shared_ptr<Socket> socket,
      T t,
      callback_t fn)
      : loop_(std::move(loop)),
        socket_(std::move(socket)),
        fn_(std::move(fn)),
        t_(std::move(t)) {}

  // Start the operation. Must be called on an instance that is
  // already owned by a shared_ptr (see the `write` helper).
  void run() {
    // The self reference cannot be taken in the constructor:
    // shared_from_this() only works once a shared_ptr owns the
    // object, which is the case only after std::make_shared returns.
    leak_ = this->shared_from_this();
    // Register with loop only after we've leaked the shared_ptr,
    // because we unleak it when the event loop thread calls.
    loop_->registerDescriptor(socket_->fd(), EPOLLOUT | EPOLLONESHOT, this);
  }

  void handleEvents(int /* events */) override {
    // Move leaked shared_ptr to the stack so that this object
    // destroys itself once this function returns.
    auto self = std::move(this->leak_);

    // Write T.
    auto rv = socket_->write(&t_, sizeof(t_));
    if (rv == -1) {
      fn_(socket_, SystemError("write", errno));
      return;
    }

    // Check for short write (assume we can write in a single call).
    // The error case was handled above, so rv is a byte count here
    // and the conversion to an unsigned type is made explicit.
    if (static_cast<std::size_t>(rv) < sizeof(t_)) {
      fn_(socket_, ShortWriteError(rv, sizeof(t_)));
      return;
    }

    fn_(socket_, Error::kSuccess);
  }

 private:
  std::shared_ptr<Loop> loop_;
  std::shared_ptr<Socket> socket_;
  callback_t fn_;
  std::shared_ptr<WriteValueOperation<T>> leak_;

  T t_;
};

template <typename T>
void write(
std::shared_ptr<Loop> loop,
std::shared_ptr<Socket> socket,
T t,
typename WriteValueOperation<T>::callback_t fn) {
auto x = std::make_shared<WriteValueOperation<T>>(
std::move(loop), std::move(socket), std::move(t), std::move(fn));
x->run();
}

} // namespace tcp
} // namespace transport
} // namespace gloo
Loading