Add autograd hook for python rpc call #27576
Conversation
1. Add autograd send/recv functions for python rpc call.
2. Make changes to support nested python rpc calls.
3. Disallow nested dist autograd context.

Differential Revision: [D17819153](https://our.internmc.facebook.com/intern/diff/D17819153/) [ghstack-poisoned]
Pull Request resolved: #27576
ghstack-source-id: 91555159
Differential Revision: [D17819153](https://our.internmc.facebook.com/intern/diff/D17819153/)
@@ -52,6 +59,33 @@ DistAutogradContext* addRecvRpcBackward(
  return nullptr;
}

std::shared_ptr<FutureMessage> sendMessageWithAutograd(
It looks like this function doesn't always send a message with autograd (if there is no valid context). If we want to name this function `sendMessageWithAutograd`, should we refactor to remove this check and do the check before invoking `sendMessageWithAutograd`?
Overall changes look great! I have mostly comments around additional tests and some code structure.
test/dist_autograd_test.py
Outdated
# Get send function.
self._verify_send_recv_functions_in_client(context_id, t1, t2, ret)
We should call this `verify_current_rank_context`.
test/dist_autograd_test.py
Outdated
# Now verify the autograd graph.
ctx = dist_autograd._retrieve_context(prev_rank_context_id)

self._verify_send_recv_functions_in_tensor_run(ctx)
This should be called `verify_prev_rank_context`.
But it is not prev_rank; for a nested call, it is for prev_prev_rank. Let me think about the naming more.
test/dist_autograd_test.py
Outdated
self.assertEqual(t2, next_funcs[1][0].variable)
self.assertEqual(0, next_funcs[1][1])
@dist_init
def test_autograd_functions_for_python_nested_call(self):
Can we add a test where the nested RPC calls itself? Basically A->B->A?
Also, can we add a test for more than 1 layer of nested calls? (could extend this test itself too). Something like A->B->C->D.
will add them
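For illustration, such a nested-call test might be driven roughly as below. This is only a hedged sketch, not the code added in this PR; it assumes the `dist_init` decorator, the `my_py_nested_call(t1, t2, dst, world_size, hops)` helper quoted later in this conversation, and the `_verify_send_recv_functions_in_client` helper from this test file.

```python
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

@dist_init
def test_autograd_functions_for_python_nested_call_sketch(self):
    # Hypothetical test name; the hop count controls the depth of the chain
    # (a larger value gives A->B->C->D style nesting, and choosing dst/hops so
    # the chain wraps back to the caller gives the A->B->A case).
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            my_py_nested_call,
            args=(t1, t2, dst_rank, self.world_size, 2),
        )
        # The send/recv functions recorded in the caller's context should look
        # the same as for a non-nested call.
        self._verify_send_recv_functions_in_client(context_id, t1, t2, ret)
```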
test/dist_autograd_test.py
Outdated
with self.assertRaises(RuntimeError):
    with dist_autograd.context() as context_id_1:
        with dist_autograd.context() as context_id_2:
            a = 1
nit: you can use `pass` if you don't want to do anything in a block.
@@ -85,6 +85,9 @@ DistAutogradContext& DistAutogradContainer::getOrCreateContext(

const DistAutogradContext& DistAutogradContainer::newContext() {
  std::lock_guard<std::mutex> guard(autograd_context_lock_);
  TORCH_CHECK(
      !hasValidContext(),
      "Next context can be created only when there is no valid context.");
nit: New instead of Next
@@ -146,6 +149,10 @@ int64_t DistAutogradContainer::getMaxId() {
  return max_id_;
}

void DistAutogradContainer::set_current_context_id(int64_t context_id) {
Should we add a `TORCH_INTERNAL_ASSERT` here such that the `current_context_id_` isn't set already?
// After processRpc() is done, clean up current_context_id_ to be -1,
// for a recv thread, current_context_id_ should always be invalid after
// processRpc() is done.
if (autogradContext != nullptr) {
  autogradContainer.set_current_context_id(-1);
}
We should clear the context id in `request_callback.cpp` in the `operator()` function (since we're guaranteed every RPC would go through here). We can do something like this:
Message RequestCallback::operator()(Message& request) const {
ClearAutogradContextGuard guard;
try {
return processMessage(request);
} catch (std::exception& e) {
LOG(ERROR) << "Received error while processing request type "
<< request.type() << ": " << e.what();
return createException(request, e);
}
}
Basically, in the destructor of `ClearAutogradContextGuard` we set the current context id to -1. Also, we shouldn't pass -1 here like this. The value -1 is an internal detail of the `DistAutogradContext` class. We should add a method called `clearCurrentContext()` which sets the value to -1.
// Wrap the response with autograd, need a new autograd message id for
// each send/recv pair.
auto& autogradContainer = DistAutogradContainer::getInstance();
AutogradMetadata responseAutogradMetadata(
We should use `sendMessageWithAutograd` here.
// passed in the chain calls.
auto& autogradContainer = DistAutogradContainer::getInstance();
if (autogradContext != nullptr) {
  autogradContainer.set_current_context_id(autogradContext->context_id());
What prevents another thread from setting the context id to a different value immediately after this line?
`current_context_id` is a thread local variable.
// rpc call in python rpc call, original context_id from client can be
// passed in the chain calls.
auto& autogradContainer = DistAutogradContainer::getInstance();
if (autogradContext != nullptr) {
Is it expected to see `MESSAGE_WITH_AUTOGRAD_REQ` if there is no valid context?
`autogradContext == nullptr` here means: no valid context, or tensors do not require grads.
test/dist_autograd_test.py
Outdated
# Wait for the prev rank to be done with rpc.
while not prev_rank_rpc_done:
    time.sleep(0.1)
Shall we add a timeout for this?
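A minimal sketch of such a bounded wait, assuming an arbitrary 10-second limit (not something decided in this PR):

```python
import time

# Wait for the prev rank to be done with rpc, but give up after a deadline
# instead of looping forever.
timeout_seconds = 10  # assumed value for illustration
start = time.time()
while not prev_rank_rpc_done:
    if time.time() - start > timeout_seconds:
        raise RuntimeError("Timed out waiting for the prev rank rpc to finish")
    time.sleep(0.1)
```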
test/dist_autograd_test.py
Outdated
self.assertEqual(1, len(recv_functions))
self.assertEqual(ret.grad_fn, list(recv_functions.values())[0])

# Host receives tensors and actually runs tensor operations, return tensor
Could you please elaborate a bit more in the comments? What does `_in_tensor_run` mean? Isn't this function verifying the autograd graph structure?
yes, the naming is bad, I will change it
next_funcs = list(send_functions.values())[0].next_functions
self.assertEqual(1, len(next_funcs))
add_backward_fn = next_funcs[0][0]
self.assertEqual("AddBackward0", add_backward_fn.name())
This assumes a very specific autograd graph structure, but the name of this function seems to suggest it could work for any `ctx`. Let's add more description in the function comments about the expected structure and which functions are called in the forward pass.
test/dist_autograd_test.py
Outdated
next_funcs = list(send_functions.values())[1].next_functions
self.assertEqual(1, len(next_funcs))
self.assertEqual(
    "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
(This does not need to be done in this PR.) It seems autograd function checks are scattered in several places, making it a little difficult to track. Is it possible to implement one general autograd function checking utility method which takes a `ctx` and a nested list/tuple/dict (or we could just add our own graph structure) `expected_graph`? Then a test could do the forward pass, construct `expected_graph` right next to the forward pass (so that it will be much easier to verify they match), then call the generic autograd function checking method to verify correctness.
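A rough sketch of what such a utility could look like (a hypothetical helper, not part of this PR), where `expected_graph` is a nested `(function_name, [children])` tuple compared against `grad_fn.next_functions`:

```python
def _verify_graph(self, grad_fn, expected_graph):
    # expected_graph is a (name, [child_expected_graph, ...]) tuple describing
    # the autograd graph rooted at grad_fn.
    expected_name, expected_children = expected_graph
    self.assertEqual(expected_name, grad_fn.name())
    next_funcs = [fn for fn, _ in grad_fn.next_functions if fn is not None]
    self.assertEqual(len(expected_children), len(next_funcs))
    for child_fn, child_expected in zip(next_funcs, expected_children):
        self._verify_graph(child_fn, child_expected)
```

A test could then build `expected_graph` right next to its forward pass and make a single `_verify_graph` call instead of repeating per-node assertions.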
test/dist_autograd_test.py
Outdated
@dist_init
def test_nested_contex(self):
    with self.assertRaises(RuntimeError):
Will the `RuntimeError` also come with a string message? Let's also verify it is the exact RuntimeError we are expecting.
Q: is this just a temporary limitation that we do not support nested contexts, or will it always be that way?
It is always like this; nested contexts do not work right now, as current_context_id will be cleared when the inner context exits. I'm just adding checks to make it more explicit.
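For the exactness check, one option is `assertRaisesRegex` against the `TORCH_CHECK` message shown in the diff above (a hedged sketch; the final wording may change per the New/Next nit, which is why only a substring is matched):

```python
@dist_init
def test_nested_context(self):
    with self.assertRaisesRegex(
        RuntimeError,
        "context can be created only when there is no valid context",
    ):
        with dist_autograd.context() as context_id_1:
            with dist_autograd.context() as context_id_2:
                pass
```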
@@ -146,6 +149,10 @@ int64_t DistAutogradContainer::getMaxId() {
  return max_id_;
}

void DistAutogradContainer::set_current_context_id(int64_t context_id) {
nit: Use `camelCase` everywhere.
test/dist_autograd_test.py
Outdated
pass

# Now verify the autograd graph.
ctx = dist_autograd._retrieve_context(prev_rank_context_id)
This is not going to succeed with the changes in #27951, because it is possible that the prev_rank (i.e. the rank with context_id `prev_rank_context_id`) would have exited the autograd context manager in their thread, and thus destroyed the autograd context. Instead of using `time.sleep()` as above, I think we need to do something like `rpc.sync_rpc()`, though @pritamdamania87 mentioned that this is going away. Is there any other way to wait for all outstanding RPCs to complete?
Thanks for the heads up. If your diff lands first, I will rebase on it; otherwise I will leave this implementation as it is right now, and your diff can rebase on mine and add the change properly. Does that sound good to you?
Sounds good!
test/dist_autograd_test.py
Outdated
@@ -160,3 +268,11 @@ def test_rpc_complex_args(self):
        self.assertEqual(tensors[i], next_funcs[i][0].variable)
    else:
        self.assertIsNone(next_funcs[i][0])

@dist_init
def test_nested_contex(self):
nit: `test_nested_context`
Pull Request resolved: #27576
1. Currently, even if the autograd context is valid, an rpc is sent with autograd meta when tensors do not require grads and grad functions are not attached. This is not ideal. This diff makes some changes so that an rpc with autograd meta is sent only if the autograd context is valid and the tensors require grads.
2. Meanwhile, create a utility to attach autograd info and functions as needed.
3. Add autograd send/recv functions for python rpc call.
4. Make changes to support nested python rpc calls.
5. Disallow nested dist autograd context (was landed in #27022).

ghstack-source-id: 92090804
Differential Revision: [D17819153](https://our.internmc.facebook.com/intern/diff/D17819153/)
Did more refactoring, addressed comments, and added more unit tests. The summary is also updated. It is ready for another round of review now. One pending issue: although local tests passed, CI tests still have a link issue. Looking into it.
Added TORCH_API for the newly added functions to fix the link issue on CI tests.
test/dist_autograd_test.py
Outdated
# respectively.
# [RankDistance.PREV_PREV] represents for prev of prev rank.
# [RankDistance.PREV_PREV_PREV] represents for prev of prev of prev rank.
class RankDistance(IntEnum):
Why don't we just use integers for the distance instead of an enum like this? That way we can extend this to multiple nested levels later on easily.
I thought an enum is more descriptive, but yes, I can change it back to integers to add more levels more easily.
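A minimal sketch of the integer-distance bookkeeping (the exact globals and helper name are assumptions for illustration, though the `rpc_done[1]`/`ctx_ids[1]` indexing does appear later in this conversation):

```python
# Index i means "the rank i hops before the current one": rpc_done[1] and
# ctx_ids[1] are for the prev rank, index 2 for prev-prev, and so on, which
# extends to deeper nesting without touching an enum.
rpc_done = [False, False, False, False]
ctx_ids = [-1, -1, -1, -1]

def _set_rpc_done(ctx_id, rank_distance):
    global rpc_done
    global ctx_ids
    rpc_done[rank_distance] = True
    ctx_ids[rank_distance] = ctx_id
```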
test/dist_autograd_test.py
Outdated
    return torch.add(t1, t2)


def my_py_nested_call(t1, t2, dst, world_size, ttl):
nit: instead of `ttl` use `hops`.
rpc::MessageType msgType);

// Send message after autograd checking
TORCH_API std::shared_ptr<torch::distributed::rpc::FutureMessage> sendMessage(
nit: call this `sendMessageWithAutograd` since we might attach autograd information here.
// and attach autograd function for each type of rpc call if it has valid
// context and tensors require grads, in this case, return RpcWithAutograd
// message; otherwise return original rpc message.
TORCH_API rpc::Message getMessageWithAutogradCheck(
nit: remove the `Check` at the end.
@@ -20,9 +20,27 @@ Message createException(const Message& request, const std::exception& e) {
      request.id());
}

struct ClearAutogradContextGuard {
  ClearAutogradContextGuard() {
    clear();
Do we need to do this in the constructor? The destructor is enough, right?
    std::unique_ptr<RpcCommandBase> rpcCommand) {
Message getMessageWithAutogradCheck(
    const rpc::worker_id_t dstId,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
This method should take the `wrappedRpc` and not `wrappedMsg`.
std::shared_ptr<FutureMessage> sendMessage(
    RpcAgent& agent,
    const WorkerInfo& dst,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
Can we pass in `wrappedRpc` in both `sendMessage()` and `getMessageWithAutogradCheck` above, and inside `getMessageWithAutogradCheck` call `toMessage` on the rpc? That seems cleaner from an API standpoint.
In request_callback_impl.cpp, getMessageWithAutogradCheck is called where we can only pass the wrappedRpcResponse message to it.
    RpcAgent& agent,
    const WorkerInfo& dst,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
    MessageType msgType) {
Why do we have `msgType` here? Isn't it always `FORWARD_AUTOGRAD_REQ`?
self._verify_graph_for_rpc_call_exec(list(send_functions.values())[0])

# Rank0->Rank1->Rank0
@dist_init
Can we also add a test where none of the tensors require grad and verify that we don't attach send/recv functions anywhere?
Yes, and we need to verify that no prev context is passed over when there are no tensors requiring grads.
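A hedged sketch of what that test might look like; it assumes `dist_autograd._current_context()` and the context's `_send_functions()`/`_recv_functions()` accessors used elsewhere in this test file, and is not the exact test added in this PR:

```python
@dist_init
def test_no_graph_with_tensors_not_require_grad(self):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=False)
        t2 = torch.zeros(3, 3, requires_grad=False)
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank), torch.add, args=(t1, t2)
        )

        # No send/recv functions should be attached locally ...
        ctx = dist_autograd._current_context()
        self.assertEqual(0, len(ctx._send_functions()))
        self.assertEqual(0, len(ctx._recv_functions()))
        # ... and the remote side should not be able to retrieve a context
        # for this id, which is what the _retrieve_context check further
        # down in this conversation covers.
```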
    MessageType msgType) {
  auto& autogradContainer = DistAutogradContainer::getInstance();

  if (!autogradContainer.hasValidContext() ||
nit: Add a small comment here explaining why we do this.
Addressed comments.
Looks good from my side; all my comments are nits or follow-ups.
prev_rank_rpc_done = False
prev_rank_context_id = 0
# Right now we test up to 3-layer nested rpc calls.
# rpc_done[1] and ctx_ids[1] represent rpc is done in prev rank, and context id
rpc_done[0] is the current rank?
self.assertEqual(ret.grad_fn, recv_function)

# For a context passed from previous nested chain calls, this rank
# recevied two tensors t1 and t2, execute torch.add(t1, t2) and send result
receives, executes, sends
self.assertEqual(next_funcs[0][0], next_funcs[1][0])

# For a context passed from previous nested chain calls, this rank
# recevied two tensors t1 and t2, forwarding t1 and t2 tensors using
recevie -> receive
receives, and forwards
# nested rpc call to next dst. In return route, receive result tensor t3
# from next dst and forwarding t3 back to previous calls.
# For this context in this rank, it expects graph like this:
# send and recv functions while recevive and forward t1 and t2:
recevive -> receive
while recevive and forward t1 and t2 -> for receiving and forwarding t1 and t2?
#             rpcSendBackward
#            /               \
#  t1.recvRpcBackward   t2.recvRpcBackward
# send and recv functions while receive and forward t3:
ditto
def my_py_nested_call(t1, t2, dst, world_size, hops):
    next_dst = (dst + 1) % world_size
    if hops > 0:
        return rpc.rpc_sync("worker{}".format(next_dst), my_py_nested_call,
(Does not need to be in this PR.) Let's also test async rpc calls, e.g., making multiple async calls, collecting all the futures in a list, and waiting on them all in the end.
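A hedged sketch of that follow-up, using `rpc.rpc_async` and waiting on the collected futures at the end (hypothetical test, not part of this PR):

```python
@dist_init
def test_async_rpc_calls_sketch(self):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        # Fire several async calls and keep the futures.
        futs = [
            rpc.rpc_async("worker{}".format(dst_rank), torch.add, args=(t1, t2))
            for _ in range(5)
        ]
        # Wait on all of them in the end.
        results = [fut.wait() for fut in futs]
        for ret in results:
            self.assertEqual(torch.add(t1, t2), ret)
```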
} // anonymous namespace

Message RequestCallback::operator()(Message& request) const {
  // For a rev thread, current context id should be invalid outside
rev -> recv
} // anonymous namespace

Message RequestCallback::operator()(Message& request) const {
  // For a rev thread, current context id should be invalid outside
  // processMessage().
  ClearAutogradContextGuard guard;
(Does not need to be done in this PR.) It looks a little weird that we have a guard that only clears the context but does not set it. I wonder if it would be better to let the guard govern both setting and clearing the context.
@mrshenli it is possibly hard to set the context here, as we can only set it after addRecvRpcBackward() in FORWARD_AUTOGRAD_REQ.
Test failure looks irrelevant:
02:16:34 Failed to reproduce exception. Expected:
02:16:34 Traceback (most recent call last):
02:16:34 File "/var/lib/jenkins/.local/lib/python2.7/site-packages/hypothesis/core.py", line 669, in evaluate_test_data
02:16:34 result = self.execute(data, collect=True)
02:16:34 File "/var/lib/jenkins/.local/lib/python2.7/site-packages/hypothesis/core.py", line 584, in execute
02:16:34 result = self.test_runner(data, run)
02:16:34 File "/var/lib/jenkins/.local/lib/python2.7/site-packages/hypothesis/executors.py", line 58, in default_new_style_executor
02:16:34 return function(data)
02:16:34 File "/var/lib/jenkins/.local/lib/python2.7/site-packages/hypothesis/core.py", line 580, in run
02:16:34 return test(*args, **kwargs)
02:16:34 File "/var/lib/jenkins/.local/lib/python2.7/site-packages/caffe2/python/operator_test/unique_ops_test.py", line 39, in test_unique_op
02:16:34 X=hu.tensor1d(
02:16:34 File "/var/lib/jenkins/.local/lib/python2.7/site-packages/hypothesis/core.py", line 524, in test
02:16:34 result = self.test(*args, **kwargs)
02:16:34 File "/var/lib/jenkins/.local/lib/python2.7/site-packages/caffe2/python/operator_test/unique_ops_test.py", line 62, in test_unique_op
02:16:34 outputs_to_check=[0, 1] if return_remapping else [0]
02:16:34 File "/var/lib/jenkins/.local/lib/python2.7/site-packages/caffe2/python/hypothesis_test_util.py", line 417, in assertDeviceChecks
02:16:34 dc.CheckSimple(op, inputs, outputs_to_check, input_device_options)
02:16:34 File "/usr/lib64/python2.7/unittest/case.py", line 462, in assertTrue
02:16:34 raise self.failureException(msg)
02:16:34 AssertionError: False is not true
@@ -62,12 +62,15 @@ DistAutogradContext* addRecvRpcBackward(
  return &autogradContext;
}

Message getMessageWithAutogradCheck(
Message getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
    MessageType msgType) {
Why do we pass msgType here? We just need to directly set `FORWARD_AUTOGRAD_REQ` on line 87.
Because it could be passed with FORWARD_AUTOGRAD_REQ or FORWARD_AUTOGRAD_RESP.
Yeah, the test failures are not relevant; we skipped rocm tests for distributed as well.
# prev context id is not passed over as tensors do not require grads
with self.assertRaises(RuntimeError):
    ctx = dist_autograd._retrieve_context(ctx_ids[1])
I think we need a `dist.barrier()` here, since this test is modified to test for state on a process that is set by an RPC from another process (previously this was just testing local state). Without `dist.barrier()`, for example, worker 0 could run the 2nd portion of this test (where the tensors do require gradients) and create the context on worker 1. Then worker 1 could run the 1st portion of the test, and the assert would fail.
Ah, my bad, I forgot to call "self._check_rpc_done(1)" before calling _retrieve_context(ctx_ids[1]).
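In other words, the quoted block would become something like the following (sketch, assuming the `_check_rpc_done` helper and `ctx_ids` bookkeeping from this test file):

```python
# Make sure the prev rank has finished its rpc before probing its context.
self._check_rpc_done(1)
# prev context id is not passed over as tensors do not require grads
with self.assertRaises(RuntimeError):
    ctx = dist_autograd._retrieve_context(ctx_ids[1])
```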
This pull request has been merged in 56c4215.
Summary: Pull Request resolved: pytorch#27576
1. Currently, even if the autograd context is valid, an rpc is sent with autograd meta when tensors do not require grads and grad functions are not attached. This is not ideal. This diff makes some changes so that an rpc with autograd meta is sent only if the autograd context is valid and the tensors require grads.
2. Meanwhile, create a utility to attach autograd info and functions as needed.
3. Add autograd send/recv functions for python rpc call.
4. Make changes to support nested python rpc calls.
5. Disallow nested dist autograd context (was landed in pytorch#27022).

ghstack-source-id: 92154535
Test Plan: unit tests
Differential Revision: D17819153
fbshipit-source-id: 37d8a85855bf591f2f2da48d475a06e870a30ea1
Stack from ghstack:
- Currently, if the autograd context is valid, an rpc is sent with autograd meta even when tensors do not require grads and grad functions are not attached. This is not ideal. This diff makes some changes so that an rpc with autograd meta is sent only if the autograd context is valid and tensors require grads.
- Meanwhile, create a utility to attach autograd info and functions as needed.
- Add autograd send/recv functions for python rpc call.
- Make changes to support nested python rpc calls.
- Disallow nested dist autograd context (was landed in Distributed Autograd - FAST mode backward pass implementation. #27022).
Differential Revision: D17819153