Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 20 additions & 36 deletions tensorpipe/channel/basic/channel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,11 @@ class Channel::Impl : public std::enable_shared_from_this<Channel::Impl> {
void init();

void send(
const void* ptr,
size_t length,
CpuBuffer buffer,
TDescriptorCallback descriptorCallback,
TSendCallback callback);

void recv(
TDescriptor descriptor,
void* ptr,
size_t length,
TRecvCallback callback);
void recv(TDescriptor descriptor, CpuBuffer buffer, TRecvCallback callback);

// Tell the channel what its identifier is.
void setId(std::string id);
Expand All @@ -53,16 +48,14 @@ class Channel::Impl : public std::enable_shared_from_this<Channel::Impl> {
OnDemandLoop loop_;

void sendFromLoop_(
const void* ptr,
size_t length,
CpuBuffer buffer,
TDescriptorCallback descriptorCallback,
TSendCallback callback);

// Receive memory region from peer.
void recvFromLoop_(
TDescriptor descriptor,
void* ptr,
size_t length,
CpuBuffer buffer,
TRecvCallback callback);

void setIdFromLoop_(std::string id);
Expand Down Expand Up @@ -126,32 +119,27 @@ Channel::Impl::Impl(
id_(std::move(id)) {}

void Channel::send(
const void* ptr,
size_t length,
CpuBuffer buffer,
TDescriptorCallback descriptorCallback,
TSendCallback callback) {
impl_->send(ptr, length, std::move(descriptorCallback), std::move(callback));
impl_->send(buffer, std::move(descriptorCallback), std::move(callback));
}

void Channel::Impl::send(
const void* ptr,
size_t length,
CpuBuffer buffer,
TDescriptorCallback descriptorCallback,
TSendCallback callback) {
loop_.deferToLoop([this,
ptr,
length,
buffer,
descriptorCallback{std::move(descriptorCallback)},
callback{std::move(callback)}]() mutable {
sendFromLoop_(
ptr, length, std::move(descriptorCallback), std::move(callback));
sendFromLoop_(buffer, std::move(descriptorCallback), std::move(callback));
});
}

// Send memory region to peer.
void Channel::Impl::sendFromLoop_(
const void* ptr,
size_t length,
CpuBuffer buffer,
TDescriptorCallback descriptorCallback,
TSendCallback callback) {
TP_DCHECK(loop_.inLoop());
Expand Down Expand Up @@ -191,8 +179,8 @@ void Channel::Impl::sendFromLoop_(
TP_VLOG(6) << "Channel " << id_ << " is writing payload (#" << sequenceNumber
<< ")";
connection_->write(
ptr,
length,
buffer.ptr,
buffer.length,
eagerCallbackWrapper_(
[sequenceNumber, callback{std::move(callback)}](Impl& impl) {
TP_VLOG(6) << "Channel " << impl.id_ << " done writing payload (#"
Expand All @@ -206,30 +194,26 @@ void Channel::Impl::sendFromLoop_(
// Receive memory region from peer.
void Channel::recv(
TDescriptor descriptor,
void* ptr,
size_t length,
CpuBuffer buffer,
TRecvCallback callback) {
impl_->recv(std::move(descriptor), ptr, length, std::move(callback));
impl_->recv(std::move(descriptor), buffer, std::move(callback));
}

void Channel::Impl::recv(
TDescriptor descriptor,
void* ptr,
size_t length,
CpuBuffer buffer,
TRecvCallback callback) {
loop_.deferToLoop([this,
descriptor{std::move(descriptor)},
ptr,
length,
buffer,
callback{std::move(callback)}]() mutable {
recvFromLoop_(std::move(descriptor), ptr, length, std::move(callback));
recvFromLoop_(std::move(descriptor), buffer, std::move(callback));
});
}

void Channel::Impl::recvFromLoop_(
TDescriptor descriptor,
void* ptr,
size_t length,
CpuBuffer buffer,
TRecvCallback callback) {
TP_DCHECK(loop_.inLoop());

Expand Down Expand Up @@ -257,8 +241,8 @@ void Channel::Impl::recvFromLoop_(
TP_VLOG(6) << "Channel " << id_ << " is reading payload (#" << sequenceNumber
<< ")";
connection_->read(
ptr,
length,
buffer.ptr,
buffer.length,
eagerCallbackWrapper_(
[sequenceNumber, callback{std::move(callback)}](
Impl& impl, const void* /* unused */, size_t /* unused */) {
Expand Down
14 changes: 5 additions & 9 deletions tensorpipe/channel/basic/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
#include <memory>

#include <tensorpipe/channel/basic/context.h>
#include <tensorpipe/channel/channel.h>
#include <tensorpipe/channel/cpu_context.h>

namespace tensorpipe {
namespace channel {
namespace basic {

class Channel : public channel::Channel {
class Channel : public channel::CpuChannel {
// Use the passkey idiom to allow make_shared to call what should be a private
// constructor. See https://abseil.io/tips/134 for more information.
struct ConstructorToken {};
Expand All @@ -31,17 +31,13 @@ class Channel : public channel::Channel {

// Send memory region to peer.
void send(
const void* ptr,
size_t length,
CpuBuffer buffer,
TDescriptorCallback descriptorCallback,
TSendCallback callback) override;

// Receive memory region from peer.
void recv(
TDescriptor descriptor,
void* ptr,
size_t length,
TRecvCallback callback) override;
void recv(TDescriptor descriptor, CpuBuffer buffer, TRecvCallback callback)
override;

// Tell the channel what its identifier is.
void setId(std::string id) override;
Expand Down
6 changes: 3 additions & 3 deletions tensorpipe/channel/basic/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Context::Impl : public Context::PrivateIface,

const std::string& domainDescriptor() const;

std::shared_ptr<channel::Channel> createChannel(
std::shared_ptr<channel::CpuChannel> createChannel(
std::shared_ptr<transport::Connection>,
Endpoint);

Expand Down Expand Up @@ -88,13 +88,13 @@ const std::string& Context::Impl::domainDescriptor() const {
return domainDescriptor_;
}

std::shared_ptr<channel::Channel> Context::createChannel(
std::shared_ptr<channel::CpuChannel> Context::createChannel(
std::shared_ptr<transport::Connection> connection,
Endpoint endpoint) {
return impl_->createChannel(std::move(connection), endpoint);
}

std::shared_ptr<channel::Channel> Context::Impl::createChannel(
std::shared_ptr<channel::CpuChannel> Context::Impl::createChannel(
std::shared_ptr<transport::Connection> connection,
Endpoint /* unused */) {
std::string channelId = id_ + ".c" + std::to_string(channelCounter_++);
Expand Down
6 changes: 3 additions & 3 deletions tensorpipe/channel/basic/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,20 @@
#include <memory>
#include <string>

#include <tensorpipe/channel/context.h>
#include <tensorpipe/channel/cpu_context.h>
#include <tensorpipe/common/callback.h>

namespace tensorpipe {
namespace channel {
namespace basic {

class Context : public channel::Context {
class Context : public channel::CpuContext {
public:
Context();

const std::string& domainDescriptor() const override;

std::shared_ptr<Channel> createChannel(
std::shared_ptr<CpuChannel> createChannel(
std::shared_ptr<transport::Connection>,
Endpoint) override;

Expand Down
7 changes: 3 additions & 4 deletions tensorpipe/channel/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,19 @@ using TSendCallback = std::function<void(const Error&)>;
using TRecvCallback = std::function<void(const Error&)>;

// Abstract base class for channel classes.
template <typename TBuffer>
class Channel {
public:
// Send memory region to peer.
virtual void send(
const void* ptr,
size_t length,
TBuffer buffer,
TDescriptorCallback descriptorCallback,
TSendCallback callback) = 0;

// Receive memory region from peer.
virtual void recv(
TDescriptor descriptor,
void* ptr,
size_t length,
TBuffer buffer,
TRecvCallback callback) = 0;

// Tell the channel what its identifier is.
Expand Down
54 changes: 19 additions & 35 deletions tensorpipe/channel/cma/channel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,11 @@ class Channel::Impl : public std::enable_shared_from_this<Channel::Impl> {
void init();

void send(
const void* ptr,
size_t length,
CpuBuffer buffer,
TDescriptorCallback descriptorCallback,
TSendCallback callback);

void recv(
TDescriptor descriptor,
void* ptr,
size_t length,
TRecvCallback callback);
void recv(TDescriptor descriptor, CpuBuffer buffer, TRecvCallback callback);

// Tell the channel what its identifier is.
void setId(std::string id);
Expand All @@ -78,16 +73,14 @@ class Channel::Impl : public std::enable_shared_from_this<Channel::Impl> {

// Send memory region to peer.
void sendFromLoop_(
const void* ptr,
size_t length,
CpuBuffer buffer,
TDescriptorCallback descriptorCallback,
TSendCallback callback);

// Receive memory region from peer.
void recvFromLoop_(
TDescriptor descriptor,
void* ptr,
size_t length,
CpuBuffer buffer,
TRecvCallback callback);

void setIdFromLoop_(std::string id);
Expand Down Expand Up @@ -158,31 +151,26 @@ void Channel::Impl::initFromLoop_() {
}

void Channel::send(
const void* ptr,
size_t length,
CpuBuffer buffer,
TDescriptorCallback descriptorCallback,
TSendCallback callback) {
impl_->send(ptr, length, std::move(descriptorCallback), std::move(callback));
impl_->send(buffer, std::move(descriptorCallback), std::move(callback));
}

void Channel::Impl::send(
const void* ptr,
size_t length,
CpuBuffer buffer,
TDescriptorCallback descriptorCallback,
TSendCallback callback) {
loop_.deferToLoop([this,
ptr,
length,
buffer,
descriptorCallback{std::move(descriptorCallback)},
callback{std::move(callback)}]() mutable {
sendFromLoop_(
ptr, length, std::move(descriptorCallback), std::move(callback));
sendFromLoop_(buffer, std::move(descriptorCallback), std::move(callback));
});
}

void Channel::Impl::sendFromLoop_(
const void* ptr,
size_t length,
CpuBuffer buffer,
TDescriptorCallback descriptorCallback,
TSendCallback callback) {
TP_DCHECK(loop_.inLoop());
Expand Down Expand Up @@ -236,38 +224,34 @@ void Channel::Impl::sendFromLoop_(
NopHolder<Descriptor> nopHolder;
Descriptor& nopDescriptor = nopHolder.getObject();
nopDescriptor.pid = getpid();
nopDescriptor.ptr = reinterpret_cast<uint64_t>(ptr);
nopDescriptor.ptr = reinterpret_cast<uint64_t>(buffer.ptr);

descriptorCallback(Error::kSuccess, saveDescriptor(nopHolder));
}

// Receive memory region from peer.
void Channel::recv(
TDescriptor descriptor,
void* ptr,
size_t length,
CpuBuffer buffer,
TRecvCallback callback) {
impl_->recv(std::move(descriptor), ptr, length, std::move(callback));
impl_->recv(std::move(descriptor), buffer, std::move(callback));
}

void Channel::Impl::recv(
TDescriptor descriptor,
void* ptr,
size_t length,
CpuBuffer buffer,
TRecvCallback callback) {
loop_.deferToLoop([this,
descriptor{std::move(descriptor)},
ptr,
length,
buffer,
callback{std::move(callback)}]() mutable {
recvFromLoop_(std::move(descriptor), ptr, length, std::move(callback));
recvFromLoop_(std::move(descriptor), buffer, std::move(callback));
});
}

void Channel::Impl::recvFromLoop_(
TDescriptor descriptor,
void* ptr,
size_t length,
CpuBuffer buffer,
TRecvCallback callback) {
TP_DCHECK(loop_.inLoop());

Expand Down Expand Up @@ -301,8 +285,8 @@ void Channel::Impl::recvFromLoop_(
context_->requestCopy(
remotePid,
remotePtr,
ptr,
length,
buffer.ptr,
buffer.length,
eagerCallbackWrapper_([sequenceNumber,
callback{std::move(callback)}](Impl& impl) {
TP_VLOG(6) << "Channel " << impl.id_ << " done copying payload (#"
Expand Down
Loading