Support device map for distributed autograd while using TensorPipe.
Pull Request resolved: #44859

TensorPipe's `set_device_map` option was only applied during the forward
pass. However, when running the backward pass for the graph, the reverse
device mapping was not picked up automatically.

As a result, users had to specify both the forward and backward device
mappings, which is tedious.

In this PR, I've added support so that TensorPipe automatically picks up the
reverse device mapping during the backward pass. This is done by storing the
appropriate device mapping in the "recv" autograd function for distributed
autograd.
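
To make the mechanism concrete, here is a minimal sketch of the inversion the callee performs before constructing the `RecvRpcBackward` node. The helper name is illustrative; the real code does this inline in `RequestCallbackNoPython::processRpc` (see the hunk below).

#include <unordered_map>
#include <c10/core/Device.h>

// Illustrative helper: invert the caller's forward map (e.g. {0 -> 1}) into
// the reverse map ({1 -> 0}) that is used when sending gradients back to the
// caller during the backward pass.
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> reverseDeviceMap(
    const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& forwardMap) {
  std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> reversed;
  for (const auto& entry : forwardMap) {
    reversed.insert({entry.second, entry.first});
  }
  return reversed;
}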

Closes: #44170
ghstack-source-id: 112351599

Differential Revision: [D23751975](https://our.internmc.facebook.com/intern/diff/D23751975/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D23751975/)!
pritamdamania committed Sep 18, 2020
1 parent cce7680 commit 6f7280f
Showing 19 changed files with 204 additions and 53 deletions.
8 changes: 8 additions & 0 deletions test/cpp/rpc/e2e_test_base.h
@@ -40,6 +40,14 @@ class TestE2EBase : public ::testing::Test {
RpcAgent::setCurrentRpcAgent(rpcAgent);
std::shared_ptr<TypeResolver> typeResolver =
std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) {
// For Dict that is used for device map.
auto pos = qn.name().find("Dict");
if (pos != std::string::npos) {
return c10::StrongTypePtr(
nullptr,
c10::DictType::create(
c10::IntType::create(), c10::IntType::create()));
}
return c10::StrongTypePtr(
nullptr, c10::TensorType::create(at::Tensor()));
});
11 changes: 8 additions & 3 deletions torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp
@@ -13,10 +13,12 @@ using torch::autograd::variable_list;
RecvRpcBackward::RecvRpcBackward(
const AutogradMetadata& autogradMetadata,
ContextPtr autogradContext,
rpc::worker_id_t fromWorkerId)
rpc::worker_id_t fromWorkerId,
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap)
: autogradMetadata_(autogradMetadata),
autogradContext_(std::move(autogradContext)),
fromWorkerId_(fromWorkerId) {}
fromWorkerId_(fromWorkerId),
deviceMap_(std::move(deviceMap)) {}

variable_list RecvRpcBackward::apply(variable_list&& grads) {
std::vector<Variable> outputGrads;
@@ -48,7 +50,10 @@ variable_list RecvRpcBackward::apply(variable_list&& grads) {
// Send the gradients over to the appropriate node.
auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent();
auto futureMessage = rpcAgent->send(
rpcAgent->getWorkerInfo(fromWorkerId_), std::move(gradCall).toMessage());
rpcAgent->getWorkerInfo(fromWorkerId_),
std::move(gradCall).toMessage(),
rpc::kUnsetRpcTimeout,
deviceMap_);

// Record the future in the context.
sharedContext->addOutstandingRpc(futureMessage);
6 changes: 5 additions & 1 deletion torch/csrc/distributed/autograd/functions/recvrpc_backward.h
@@ -22,7 +22,8 @@ class TORCH_API RecvRpcBackward : public torch::autograd::Node {
explicit RecvRpcBackward(
const AutogradMetadata& autogradMetadata,
std::shared_ptr<DistAutogradContext> autogradContext,
rpc::worker_id_t fromWorkerId);
rpc::worker_id_t fromWorkerId,
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap);

torch::autograd::variable_list apply(
torch::autograd::variable_list&& grads) override;
@@ -38,6 +39,9 @@ class TORCH_API RecvRpcBackward : public torch::autograd::Node {
// The worker id from which the RPC was received. During the backward pass,
// we need to propagate the gradients to this workerId.
rpc::worker_id_t fromWorkerId_;

// Device mapping for tensors sent over RPC.
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap_;
};

} // namespace autograd
38 changes: 31 additions & 7 deletions torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp
@@ -18,11 +18,13 @@ RpcWithAutograd::RpcWithAutograd(
worker_id_t fromWorkerId,
MessageType messageType,
const AutogradMetadata& autogradMetadata,
rpc::Message&& wrappedMessage)
rpc::Message&& wrappedMessage,
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap)
: fromWorkerId_(fromWorkerId),
messageType_(messageType),
autogradMetadata_(autogradMetadata),
wrappedMessage_(std::move(wrappedMessage)) {
wrappedMessage_(std::move(wrappedMessage)),
deviceMap_(std::move(deviceMap)) {
TORCH_INTERNAL_ASSERT(
messageType_ == MessageType::FORWARD_AUTOGRAD_REQ ||
messageType_ == MessageType::FORWARD_AUTOGRAD_RESP);
@@ -36,13 +38,15 @@ RpcWithAutograd::RpcWithAutograd(
const AutogradMetadata& autogradMetadata,
std::unique_ptr<RpcCommandBase> wrappedRpc,
MessageType wrappedMessageType,
std::vector<torch::Tensor> tensors)
std::vector<torch::Tensor> tensors,
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap)
: fromWorkerId_(fromWorkerId),
messageType_(messageType),
autogradMetadata_(autogradMetadata),
wrappedRpc_(std::move(wrappedRpc)),
wrappedMessageType_(wrappedMessageType),
tensors_(std::move(tensors)) {
tensors_(std::move(tensors)),
deviceMap_(std::move(deviceMap)) {
TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
TORCH_INTERNAL_ASSERT(
messageType_ == MessageType::FORWARD_AUTOGRAD_REQ ||
@@ -56,10 +60,17 @@ Message RpcWithAutograd::toMessageImpl() && {
auto payload = std::move(wrappedMessage_).movePayload();
TORCH_INTERNAL_ASSERT(!payload.empty());

// Convert deviceMap to c10::Dict for serialization.
c10::Dict<int64_t, int64_t> deviceMap;
for (const auto& mapEntry : deviceMap_) {
deviceMap.insert(mapEntry.first, mapEntry.second);
}

std::vector<at::IValue> ivalues{wrappedMessageType,
autogradMetadata_.autogradContextId,
autogradMetadata_.autogradMessageId,
fromWorkerId_};
fromWorkerId_,
deviceMap};

// Now pickle using JIT pickler.
std::vector<torch::Tensor> tensorTable;
@@ -92,12 +103,19 @@ std::unique_ptr<RpcWithAutograd> RpcWithAutograd::fromMessage(
auto tupleElements = rpc::readWrappedPayload(payload, message);

// Gather all the fields.
TORCH_INTERNAL_ASSERT(tupleElements.size() == 4);
TORCH_INTERNAL_ASSERT(tupleElements.size() == 5);
MessageType wrappedMessageType =
static_cast<MessageType>(tupleElements[0].toInt());
AutogradMetadata autogradMetadata(
tupleElements[1].toInt(), tupleElements[2].toInt());
worker_id_t workerId = tupleElements[3].toInt();
auto c10DeviceMap = tupleElements[4].to<c10::Dict<int64_t, int64_t>>();

// Convert to regular map.
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap;
for (const auto& mapEntry : c10DeviceMap) {
deviceMap.insert({mapEntry.key(), mapEntry.value()});
}

// Create new message type and build wrapped RPC.
Message wrappedMessage(
@@ -116,7 +134,8 @@ std::unique_ptr<RpcWithAutograd> RpcWithAutograd::fromMessage(
autogradMetadata,
std::move(wrappedRpc),
wrappedMessageType,
wrappedMessage.tensors());
wrappedMessage.tensors(),
deviceMap);
}

std::vector<torch::Tensor>& RpcWithAutograd::tensors() {
@@ -150,6 +169,11 @@ rpc::worker_id_t RpcWithAutograd::fromWorkerId() const {
return fromWorkerId_;
}

const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& RpcWithAutograd::
deviceMap() {
return deviceMap_;
}

} // namespace autograd
} // namespace distributed
} // namespace torch
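
For context, the hunks above widen each `c10::DeviceIndex` to `int64_t` so the device map can be pickled as a `c10::Dict` alongside the other autograd metadata. Below is a standalone sketch of that round trip; the helper names are illustrative, and the real code appends the dict as the fifth pickled IValue in `toMessageImpl` and reads it back in `fromMessage`.

#include <ATen/core/Dict.h>
#include <c10/core/Device.h>
#include <unordered_map>

// Serialize: widen DeviceIndex keys/values to int64_t for the JIT pickler.
c10::Dict<int64_t, int64_t> serializeDeviceMap(
    const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap) {
  c10::Dict<int64_t, int64_t> dict;
  for (const auto& entry : deviceMap) {
    dict.insert(entry.first, entry.second);
  }
  return dict;
}

// Deserialize: narrow back to DeviceIndex on the receiving side.
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deserializeDeviceMap(
    const c10::Dict<int64_t, int64_t>& dict) {
  std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap;
  for (const auto& entry : dict) {
    deviceMap.insert(
        {static_cast<c10::DeviceIndex>(entry.key()),
         static_cast<c10::DeviceIndex>(entry.value())});
  }
  return deviceMap;
}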
12 changes: 10 additions & 2 deletions torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h
@@ -18,7 +18,8 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
rpc::worker_id_t fromWorkerId,
rpc::MessageType messageType,
const AutogradMetadata& autogradMetadata,
rpc::Message&& wrappedMessage);
rpc::Message&& wrappedMessage,
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap = {});

// Used when receiving an RPC over the wire.
RpcWithAutograd(
@@ -27,7 +28,8 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
const AutogradMetadata& autogradMetadata,
std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
rpc::MessageType wrappedMessageType,
std::vector<torch::Tensor> tensors);
std::vector<torch::Tensor> tensors,
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap = {});

rpc::Message toMessageImpl() && override;

@@ -52,6 +54,9 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
// Retrieve the worker id from which the RPC originated.
rpc::worker_id_t fromWorkerId() const;

// Retrieve the device map.
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap();

private:
// WorkerId from which this RPC originated. This is necessary for knowing
// which worker we need to contact during the backward pass.
@@ -83,6 +88,9 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {

// Tensors part of the wrappedRpc that need to be considered for autograd.
std::vector<torch::Tensor> tensors_;

// Device mapping for tensors that are sent across an RPC to another node.
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap_;
};

} // namespace autograd
14 changes: 9 additions & 5 deletions torch/csrc/distributed/autograd/utils.cpp
@@ -52,7 +52,8 @@ void addSendRpcBackward(
ContextPtr addRecvRpcBackward(
const AutogradMetadata& autogradMetadata,
std::vector<torch::Tensor>& tensors,
rpc::worker_id_t fromWorkerId) {
rpc::worker_id_t fromWorkerId,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap) {
// Initialize autograd context if necessary.
auto& autogradContainer = DistAutogradContainer::getInstance();
auto autogradContext =
@@ -61,7 +62,7 @@ ContextPtr addRecvRpcBackward(
if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
// Attach the tensors as inputs to the autograd function.
auto grad_fn = std::make_shared<RecvRpcBackward>(
autogradMetadata, autogradContext, fromWorkerId);
autogradMetadata, autogradContext, fromWorkerId, deviceMap);
for (auto& tensor : tensors) {
if (tensor.requires_grad()) {
torch::autograd::set_history(tensor, grad_fn);
@@ -102,7 +103,8 @@ Message getMessageWithAutograd(
const rpc::worker_id_t dstId,
torch::distributed::rpc::Message&& wrappedRpcMsg,
MessageType msgType,
bool forceGradRecording) {
bool forceGradRecording,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap) {
auto& autogradContainer = DistAutogradContainer::getInstance();

// If there is no valid context and no tensor requires grads, send original
@@ -125,7 +127,8 @@ Message getMessageWithAutograd(
RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
msgType,
autogradMetadata,
std::move(wrappedRpcMsg));
std::move(wrappedRpcMsg),
deviceMap);

if (tensorsRequireGrad) {
// Record autograd information for 'send'.
@@ -148,7 +151,8 @@ std::shared_ptr<FutureMessage> sendMessageWithAutograd(
dst.id_,
std::move(wrappedRpcMsg),
MessageType::FORWARD_AUTOGRAD_REQ,
forceGradRecording);
forceGradRecording,
agent.getDeviceMap(dst));

std::shared_ptr<FutureMessage> fut;
// If profiler is enabled, wrap this message with profiling metadata that will
7 changes: 5 additions & 2 deletions torch/csrc/distributed/autograd/utils.h
@@ -30,7 +30,8 @@ TORCH_API void addSendRpcBackward(
TORCH_API ContextPtr addRecvRpcBackward(
const AutogradMetadata& autogradMetadata,
std::vector<torch::Tensor>& tensors,
rpc::worker_id_t fromWorkerId);
rpc::worker_id_t fromWorkerId,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap);

// This method is a wrapper utility used internally to wrap autograd info
// and attach autograd function for each type of rpc call if it has valid
@@ -42,7 +43,9 @@ TORCH_API rpc::Message getMessageWithAutograd(
const rpc::worker_id_t dstId,
rpc::Message&& wrappedRpcMsg,
rpc::MessageType msgType,
bool forceGradRecording = false);
bool forceGradRecording = false,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap =
{});

// Send message after autograd checking
TORCH_API std::shared_ptr<torch::distributed::rpc::FutureMessage>
3 changes: 2 additions & 1 deletion torch/csrc/distributed/rpc/process_group_agent.cpp
@@ -294,7 +294,8 @@ void ProcessGroupAgent::shutdownImpl() {
std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds) {
const float rpcTimeoutSeconds,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap) {
// Throw if we previously encountered an exception in ::listenLoop.
{
std::unique_lock<std::mutex> guard(listenLoopExceptionMutex_);
4 changes: 3 additions & 1 deletion torch/csrc/distributed/rpc/process_group_agent.h
@@ -91,7 +91,9 @@ class TORCH_API ProcessGroupAgent : public RpcAgent {
std::shared_ptr<FutureMessage> send(
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds = kUnsetRpcTimeout) override;
const float rpcTimeoutSeconds = kUnsetRpcTimeout,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap =
{}) override;

// put SendWork into a queue and notify the worker thread
virtual void enqueueSend(SendWork work);
10 changes: 9 additions & 1 deletion torch/csrc/distributed/rpc/request_callback_no_python.cpp
@@ -366,11 +366,19 @@ void RequestCallbackNoPython::processRpc(
case MessageType::FORWARD_AUTOGRAD_REQ: {
auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);

// Need to reverse the device map for the backward pass of distributed
// autograd.
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> reverseDeviceMap;
for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
}

// Attach 'recv' autograd function.
auto autogradContext = addRecvRpcBackward(
rpcWithAutograd.autogradMetadata(),
rpcWithAutograd.tensors(),
rpcWithAutograd.fromWorkerId());
rpcWithAutograd.fromWorkerId(),
reverseDeviceMap);
// For this recv thread on server side, before processRpc(),
// set current_context_id_ to be context_id passed from client.
// In this way, if there is nested rpc call in python rpc call, original
6 changes: 6 additions & 0 deletions torch/csrc/distributed/rpc/rpc_agent.cpp
@@ -286,6 +286,12 @@ bool RpcAgent::isGILProfilingEnabled() {
return profilingEnabled_.load();
}

std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> RpcAgent::getDeviceMap(
const WorkerInfo& dest) {
// Default implementation has no device map.
return {};
}

std::unordered_map<std::string, std::string> RpcAgent::getDebugInfo() {
/* This would later include more info other than metrics for eg: may include
stack traces for the threads owned by the agent */
8 changes: 7 additions & 1 deletion torch/csrc/distributed/rpc/rpc_agent.h
@@ -157,7 +157,9 @@ class TORCH_API RpcAgent {
virtual std::shared_ptr<FutureMessage> send(
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds = kUnsetRpcTimeout) = 0;
const float rpcTimeoutSeconds = kUnsetRpcTimeout,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap =
{}) = 0;

// Retries sending the message up to maxRetries times until an ACK is
// received. The duration between consecutive sends is increased over
@@ -256,6 +258,10 @@ class TORCH_API RpcAgent {
// Get the type resolver
std::shared_ptr<TypeResolver> getTypeResolver();

// Retrieves the device map for the provided destination worker.
virtual std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> getDeviceMap(
const WorkerInfo& dest);

protected:
const WorkerInfo workerInfo_;
const std::unique_ptr<RequestCallback> cb_;
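
As a rough sketch of how a concrete agent can use the new hook: the TensorPipe agent's override is part of this PR but not shown in the hunks above, so the class name, field name, and the way the per-destination maps are populated below are assumptions for illustration only.

// Hypothetical agent that keeps a per-destination forward device map,
// e.g. populated from its backend options at construction time.
class MyRpcAgent : public torch::distributed::rpc::RpcAgent {
 public:
  std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> getDeviceMap(
      const torch::distributed::rpc::WorkerInfo& dest) override {
    auto it = deviceMaps_.find(dest.name_);
    if (it == deviceMaps_.end()) {
      return {}; // same behaviour as the RpcAgent default
    }
    return it->second;
  }

  // ... constructor, send(), and the other pure-virtual overrides omitted ...

 private:
  std::unordered_map<
      std::string,
      std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>>
      deviceMaps_;
};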
