Support device map for distributed autograd while using TensorPipe. #44859 (Closed)
8 changes: 8 additions & 0 deletions test/cpp/rpc/e2e_test_base.h
@@ -40,6 +40,14 @@ class TestE2EBase : public ::testing::Test {
RpcAgent::setCurrentRpcAgent(rpcAgent);
std::shared_ptr<TypeResolver> typeResolver =
std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) {
// For Dict that is used for device map.
auto pos = qn.name().find("Dict");
if (pos != std::string::npos) {
return c10::StrongTypePtr(
nullptr,
c10::DictType::create(
c10::IntType::create(), c10::IntType::create()));
}
return c10::StrongTypePtr(
nullptr, c10::TensorType::create(at::Tensor()));
});
11 changes: 8 additions & 3 deletions torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp
@@ -13,10 +13,12 @@ using torch::autograd::variable_list;
RecvRpcBackward::RecvRpcBackward(
const AutogradMetadata& autogradMetadata,
ContextPtr autogradContext,
rpc::worker_id_t fromWorkerId)
rpc::worker_id_t fromWorkerId,
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap)
: autogradMetadata_(autogradMetadata),
autogradContext_(std::move(autogradContext)),
fromWorkerId_(fromWorkerId) {}
fromWorkerId_(fromWorkerId),
deviceMap_(std::move(deviceMap)) {}

variable_list RecvRpcBackward::apply(variable_list&& grads) {
std::vector<Variable> outputGrads;
@@ -48,7 +50,10 @@ variable_list RecvRpcBackward::apply(variable_list&& grads) {
// Send the gradients over to the appropriate node.
auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent();
auto futureMessage = rpcAgent->send(
rpcAgent->getWorkerInfo(fromWorkerId_), std::move(gradCall).toMessage());
rpcAgent->getWorkerInfo(fromWorkerId_),
std::move(gradCall).toMessage(),
rpc::kUnsetRpcTimeout,
deviceMap_);

// Record the future in the context.
sharedContext->addOutstandingRpc(futureMessage);
6 changes: 5 additions & 1 deletion torch/csrc/distributed/autograd/functions/recvrpc_backward.h
@@ -22,7 +22,8 @@ class TORCH_API RecvRpcBackward : public torch::autograd::Node {
explicit RecvRpcBackward(
const AutogradMetadata& autogradMetadata,
std::shared_ptr<DistAutogradContext> autogradContext,
rpc::worker_id_t fromWorkerId);
rpc::worker_id_t fromWorkerId,
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap);

torch::autograd::variable_list apply(
torch::autograd::variable_list&& grads) override;
@@ -38,6 +39,9 @@ class TORCH_API RecvRpcBackward : public torch::autograd::Node {
// The worker id from which the RPC was received. During the backward pass,
// we need to propagate the gradients to this workerId.
rpc::worker_id_t fromWorkerId_;

// Device mapping for tensors sent over RPC.
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap_;
};

} // namespace autograd
38 changes: 31 additions & 7 deletions torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp
@@ -18,11 +18,13 @@ RpcWithAutograd::RpcWithAutograd(
worker_id_t fromWorkerId,
MessageType messageType,
const AutogradMetadata& autogradMetadata,
rpc::Message&& wrappedMessage)
rpc::Message&& wrappedMessage,
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap)
: fromWorkerId_(fromWorkerId),
messageType_(messageType),
autogradMetadata_(autogradMetadata),
wrappedMessage_(std::move(wrappedMessage)) {
wrappedMessage_(std::move(wrappedMessage)),
deviceMap_(std::move(deviceMap)) {
TORCH_INTERNAL_ASSERT(
messageType_ == MessageType::FORWARD_AUTOGRAD_REQ ||
messageType_ == MessageType::FORWARD_AUTOGRAD_RESP);
@@ -36,13 +38,15 @@ RpcWithAutograd::RpcWithAutograd(
const AutogradMetadata& autogradMetadata,
std::unique_ptr<RpcCommandBase> wrappedRpc,
MessageType wrappedMessageType,
std::vector<torch::Tensor> tensors)
std::vector<torch::Tensor> tensors,
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap)
: fromWorkerId_(fromWorkerId),
messageType_(messageType),
autogradMetadata_(autogradMetadata),
wrappedRpc_(std::move(wrappedRpc)),
wrappedMessageType_(wrappedMessageType),
tensors_(std::move(tensors)) {
tensors_(std::move(tensors)),
deviceMap_(std::move(deviceMap)) {
TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
TORCH_INTERNAL_ASSERT(
messageType_ == MessageType::FORWARD_AUTOGRAD_REQ ||
@@ -56,10 +60,17 @@ Message RpcWithAutograd::toMessageImpl() && {
auto payload = std::move(wrappedMessage_).movePayload();
TORCH_INTERNAL_ASSERT(!payload.empty());

// Convert deviceMap to c10::Dict for serialization.
c10::Dict<int64_t, int64_t> deviceMap;
for (const auto& mapEntry : deviceMap_) {
deviceMap.insert(mapEntry.first, mapEntry.second);
}

std::vector<at::IValue> ivalues{wrappedMessageType,
autogradMetadata_.autogradContextId,
autogradMetadata_.autogradMessageId,
fromWorkerId_};
fromWorkerId_,
deviceMap};

// Now pickle using JIT pickler.
std::vector<torch::Tensor> tensorTable;
@@ -92,12 +103,19 @@ std::unique_ptr<RpcWithAutograd> RpcWithAutograd::fromMessage(
auto tupleElements = rpc::readWrappedPayload(payload, message);

// Gather all the fields.
TORCH_INTERNAL_ASSERT(tupleElements.size() == 4);
TORCH_INTERNAL_ASSERT(tupleElements.size() == 5);
MessageType wrappedMessageType =
static_cast<MessageType>(tupleElements[0].toInt());
AutogradMetadata autogradMetadata(
tupleElements[1].toInt(), tupleElements[2].toInt());
worker_id_t workerId = tupleElements[3].toInt();
auto c10DeviceMap = tupleElements[4].to<c10::Dict<int64_t, int64_t>>();

// Convert to regular map.
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap;
for (const auto& mapEntry : c10DeviceMap) {
deviceMap.insert({mapEntry.key(), mapEntry.value()});
}

// Create new message type and build wrapped RPC.
Message wrappedMessage(
@@ -116,7 +134,8 @@ std::unique_ptr<RpcWithAutograd> RpcWithAutograd::fromMessage(
autogradMetadata,
std::move(wrappedRpc),
wrappedMessageType,
wrappedMessage.tensors());
wrappedMessage.tensors(),
deviceMap);
}

std::vector<torch::Tensor>& RpcWithAutograd::tensors() {
@@ -150,6 +169,11 @@ rpc::worker_id_t RpcWithAutograd::fromWorkerId() const {
return fromWorkerId_;
}

const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& RpcWithAutograd::
deviceMap() {
return deviceMap_;
}

} // namespace autograd
} // namespace distributed
} // namespace torch
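
For readers following the serialization change above, the sketch below shows the round trip the device map now makes through the JIT pickler, mirroring toMessageImpl() and fromMessage(). It is a minimal standalone illustration assuming a program linked against libtorch; the variable names and the printed output are made up and this is not code from the PR.

```cpp
// Minimal sketch (not part of this PR): the device-map round trip performed by
// toMessageImpl()/fromMessage(), assuming a program linked against libtorch.
#include <torch/torch.h>

#include <iostream>
#include <unordered_map>

int main() {
  // Forward device map configured on the sender, e.g. its cuda:0 -> peer's cuda:1.
  std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap{{0, 1}};

  // Sender side: unordered_map -> c10::Dict -> IValue, so the JIT pickler can
  // serialize it alongside the other autograd metadata ivalues.
  c10::Dict<int64_t, int64_t> dict;
  for (const auto& entry : deviceMap) {
    dict.insert(entry.first, entry.second);
  }
  at::IValue pickled(dict);

  // Receiver side: IValue -> c10::Dict -> unordered_map.
  auto restoredDict = pickled.to<c10::Dict<int64_t, int64_t>>();
  std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> restored;
  for (const auto& entry : restoredDict) {
    restored.emplace(
        static_cast<c10::DeviceIndex>(entry.key()),
        static_cast<c10::DeviceIndex>(entry.value()));
  }

  std::cout << "cuda:0 maps to cuda:" << static_cast<int>(restored.at(0)) << "\n";
  return 0;
}
```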
12 changes: 10 additions & 2 deletions torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h
@@ -18,7 +18,8 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
rpc::worker_id_t fromWorkerId,
rpc::MessageType messageType,
const AutogradMetadata& autogradMetadata,
rpc::Message&& wrappedMessage);
rpc::Message&& wrappedMessage,
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap = {});

// Used when receiving an RPC over the wire.
RpcWithAutograd(
@@ -27,7 +28,8 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
const AutogradMetadata& autogradMetadata,
std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
rpc::MessageType wrappedMessageType,
std::vector<torch::Tensor> tensors);
std::vector<torch::Tensor> tensors,
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap = {});

rpc::Message toMessageImpl() && override;

@@ -52,6 +54,9 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
// Retrieve the worker id from which the RPC originated.
rpc::worker_id_t fromWorkerId() const;

// Retrieve the device map.
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap();

private:
// WorkerId from which this RPC originated. This is necessary for knowing
// which worker we need to contact during the backward pass.
@@ -83,6 +88,9 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {

// Tensors part of the wrappedRpc that need to be considered for autograd.
std::vector<torch::Tensor> tensors_;

// Device mapping for tensors that are sent across an RPC to another node.
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap_;
};

} // namespace autograd
14 changes: 9 additions & 5 deletions torch/csrc/distributed/autograd/utils.cpp
@@ -52,7 +52,8 @@ void addSendRpcBackward(
ContextPtr addRecvRpcBackward(
const AutogradMetadata& autogradMetadata,
std::vector<torch::Tensor>& tensors,
rpc::worker_id_t fromWorkerId) {
rpc::worker_id_t fromWorkerId,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap) {
// Initialize autograd context if necessary.
auto& autogradContainer = DistAutogradContainer::getInstance();
auto autogradContext =
@@ -61,7 +62,7 @@ ContextPtr addRecvRpcBackward(
if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
// Attach the tensors as inputs to the autograd function.
auto grad_fn = std::make_shared<RecvRpcBackward>(
autogradMetadata, autogradContext, fromWorkerId);
autogradMetadata, autogradContext, fromWorkerId, deviceMap);
for (auto& tensor : tensors) {
if (tensor.requires_grad()) {
torch::autograd::set_history(tensor, grad_fn);
@@ -102,7 +103,8 @@ Message getMessageWithAutograd(
const rpc::worker_id_t dstId,
torch::distributed::rpc::Message&& wrappedRpcMsg,
MessageType msgType,
bool forceGradRecording) {
bool forceGradRecording,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap) {
auto& autogradContainer = DistAutogradContainer::getInstance();

// If there is no valid context and no tensor requires grads, send original
@@ -125,7 +127,8 @@ Message getMessageWithAutograd(
RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
msgType,
autogradMetadata,
std::move(wrappedRpcMsg));
std::move(wrappedRpcMsg),
deviceMap);

if (tensorsRequireGrad) {
// Record autograd information for 'send'.
@@ -149,7 +152,8 @@ std::shared_ptr<FutureMessage> sendMessageWithAutograd(
dst.id_,
std::move(wrappedRpcMsg),
MessageType::FORWARD_AUTOGRAD_REQ,
forceGradRecording);
forceGradRecording,
agent.getDeviceMap(dst));

std::shared_ptr<FutureMessage> fut;
// If profiler is enabled, wrap this message with profiling metadata that will
7 changes: 5 additions & 2 deletions torch/csrc/distributed/autograd/utils.h
@@ -30,7 +30,8 @@ TORCH_API void addSendRpcBackward(
TORCH_API ContextPtr addRecvRpcBackward(
const AutogradMetadata& autogradMetadata,
std::vector<torch::Tensor>& tensors,
rpc::worker_id_t fromWorkerId);
rpc::worker_id_t fromWorkerId,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap);

// This method is a wrapper utility used internally to wrap autograd info
// and attach autograd function for each type of rpc call if it has valid
@@ -42,7 +43,9 @@ TORCH_API rpc::Message getMessageWithAutograd(
const rpc::worker_id_t dstId,
rpc::Message&& wrappedRpcMsg,
rpc::MessageType msgType,
bool forceGradRecording = false);
bool forceGradRecording = false,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap =
{});

// Send message after autograd checking
TORCH_API std::shared_ptr<torch::distributed::rpc::FutureMessage>
3 changes: 2 additions & 1 deletion torch/csrc/distributed/rpc/process_group_agent.cpp
@@ -290,7 +290,8 @@ void ProcessGroupAgent::shutdownImpl() {
std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds) {
const float rpcTimeoutSeconds,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap) {
// Throw if we previously encountered an exception in ::listenLoop.
{
std::unique_lock<std::mutex> guard(listenLoopExceptionMutex_);
4 changes: 3 additions & 1 deletion torch/csrc/distributed/rpc/process_group_agent.h
@@ -91,7 +91,9 @@ class TORCH_API ProcessGroupAgent : public RpcAgent {
std::shared_ptr<FutureMessage> send(
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds = kUnsetRpcTimeout) override;
const float rpcTimeoutSeconds = kUnsetRpcTimeout,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap =
{}) override;

// put SendWork into a queue and notify the worker thread
virtual void enqueueSend(SendWork work);
11 changes: 10 additions & 1 deletion torch/csrc/distributed/rpc/request_callback_no_python.cpp
@@ -343,11 +343,20 @@ void RequestCallbackNoPython::processForwardAutogradReq(
const std::shared_ptr<FutureMessage>& responseFuture) const {
auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);

// Need to reverse the device map for the backward pass of distributed
// autograd.
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> reverseDeviceMap;
for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
}


// Attach 'recv' autograd function.
auto autogradContext = addRecvRpcBackward(
rpcWithAutograd.autogradMetadata(),
rpcWithAutograd.tensors(),
rpcWithAutograd.fromWorkerId());
rpcWithAutograd.fromWorkerId(),
reverseDeviceMap);
// For this recv thread on server side, before processRpc(),
// set current_context_id_ to be context_id passed from client.
// In this way, if there is nested rpc call in python rpc call, original
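
A minimal sketch of the reversal idea used in processForwardAutogradReq() above, under the assumption that the forward map is one-to-one (otherwise entries would collide when swapped); the device numbers are made up and this is not code from the PR. The forward map describes caller device -> callee device, so swapping keys and values yields the routing for gradients flowing back.

```cpp
// Minimal sketch (not part of this PR) of reversing a forward device map for
// the backward pass; assumes the forward map is one-to-one.
#include <c10/core/Device.h>

#include <cassert>
#include <unordered_map>

using DeviceMap = std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>;

int main() {
  // Forward map sent by the caller: its cuda:0 -> our cuda:1, its cuda:1 -> our cuda:2.
  DeviceMap forwardMap{{0, 1}, {1, 2}};

  // Reverse it for the backward pass: a gradient living on our cuda:1 must go
  // back to the caller's cuda:0, and so on.
  DeviceMap reverseMap;
  for (const auto& entry : forwardMap) {
    reverseMap.emplace(entry.second, entry.first);
  }

  assert(reverseMap.at(1) == 0);
  assert(reverseMap.at(2) == 1);
  return 0;
}
```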
6 changes: 6 additions & 0 deletions torch/csrc/distributed/rpc/rpc_agent.cpp
@@ -286,6 +286,12 @@ bool RpcAgent::isGILProfilingEnabled() {
return profilingEnabled_.load();
}

std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> RpcAgent::getDeviceMap(
const WorkerInfo& dest) {
// Default implementation has no device map.
return {};
}

std::unordered_map<std::string, std::string> RpcAgent::getDebugInfo() {
/* This would later include more info other than metrics for eg: may include
stack traces for the threads owned by the agent */
8 changes: 7 additions & 1 deletion torch/csrc/distributed/rpc/rpc_agent.h
@@ -160,7 +160,9 @@ class TORCH_API RpcAgent {
virtual std::shared_ptr<FutureMessage> send(
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds = kUnsetRpcTimeout) = 0;
const float rpcTimeoutSeconds = kUnsetRpcTimeout,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap =
Contributor:

I thought we had agreed that in this initial version of CUDA support we would not allow specifying a per-RPC-call mapping, but would instead always use the constant global one. It's true that this isn't exposed at the Python layer, but introducing such an ability on the agent would add complexity (we'd probably need to attach the map to the message in case the receiver wants to access it and reverse it) and should probably be discussed.

Contributor:

Also, @mrshenli, hadn't we said that we should use c10::Device rather than c10::DeviceIndex, as the latter implicitly limits us to CUDA and won't allow us (one day) to have host-to-device maps or handle AMD GPUs...?

Contributor (author):

The only user-facing API we support today is the Python one. The RpcAgent interface can be thought of as an internal API that we have complete control over. In this PR we do attach this map to the message and actually reverse it for distributed autograd.

I went with DeviceIndex here to be consistent with the rest of the device mapping code. I agree with Shen that this should be Device, but that is a much more involved change for 1.7. We control this interface and all its implementations, so it shouldn't be a big deal to change this parameter slightly in the future.

Contributor:

Does this mean we are confident that we will soon add support for per-RPC device map arguments? If that's the case, adding it to recv backward LGTM. If we don't see that coming in the near future, I am not sure it is worth introducing the additional complexity. But since the device map will be a beta feature anyway, I think it should be fine either way from a performance perspective. If we decide to keep the current version, then in order to address code complexity concerns we can create an issue/reminder to revisit this and see whether a global map would be enough before 1.9.

Contributor (author):

I don't think this is tied to whether or not we want to support per-RPC device map arguments. This is not the public API that users see; it is a private one for now. If we do end up building a C++ API, at that point we can evaluate what to do with this extra argument.

Regarding complexity, I'm not sure there is a simpler way to address this issue holistically. A global map for the backward pass wouldn't work in all cases. For example, if nodes 1 and 2 perform RPCs on node3 with separate global device maps for the forward pass, there can't be a global backward map defined on node3 that handles both. It seems we do need to do this at a per-RPC level to handle it generically.

{}) = 0;

// Retries sending the message up to maxRetries times until an ACK is
// received. The duration between consecutive sends is increased over
@@ -259,6 +261,10 @@ class TORCH_API RpcAgent {
// Get the type resolver
std::shared_ptr<TypeResolver> getTypeResolver();

// Retrieves the device map for the provided destination worker.
virtual std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> getDeviceMap(
const WorkerInfo& dest);

protected:
const WorkerInfo workerInfo_;
const std::unique_ptr<RequestCallback> cb_;
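
To make the scenario from the review discussion above concrete, here is a small sketch under made-up assumptions (the worker names and device numbers are illustrative, and this is not code from the PR). Two callers map different source devices onto the same device of worker3, so their reversed maps disagree about where a gradient on that device should be sent; only a map carried with each individual RPC, as this change does, can disambiguate the two recv functions.

```cpp
// Sketch (not part of this PR): why the backward device map has to be derived
// per RPC from the forward map attached to the request.
#include <c10/core/Device.h>

#include <cassert>
#include <unordered_map>

using DeviceMap = std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>;

int main() {
  // Forward maps configured independently by the two callers of worker3:
  //   worker1: {0 -> 1}  (worker1's cuda:0 -> worker3's cuda:1)
  //   worker2: {3 -> 1}  (worker2's cuda:3 -> worker3's cuda:1)
  // Reversing each one per RPC, as processForwardAutogradReq() does, yields two
  // different routings for the same local device:
  DeviceMap backwardForWorker1{{1, 0}}; // worker3's cuda:1 -> worker1's cuda:0
  DeviceMap backwardForWorker2{{1, 3}}; // worker3's cuda:1 -> worker2's cuda:3

  // A single map configured statically on worker3 cannot hold both entries for
  // key 1 at once, so one of the routings would be lost.
  DeviceMap merged = backwardForWorker1;
  bool inserted = merged.emplace(1, 3).second;
  assert(!inserted);

  assert(backwardForWorker1.at(1) == 0);
  assert(backwardForWorker2.at(1) == 3);
  return 0;
}
```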