
Add autograd hook for python rpc call #27576


Closed
wants to merge 10 commits

Conversation

zhaojuanmao
Contributor

@zhaojuanmao zhaojuanmao commented Oct 8, 2019

Stack from ghstack:

  1. Currently, even when tensors do not require grads and no grad functions are attached, an RPC is still sent with autograd meta as long as the autograd context is valid. This is not ideal.

This diff makes changes so that an RPC with autograd meta is sent only if the autograd context is valid and the tensors require grads (see the Python sketch below).

  2. Meanwhile, create a utility to attach autograd info and functions as needed.

  3. Add autograd send/recv functions for python rpc call.

  4. Make changes to support nested python rpc calls.

  5. Disallow nested dist autograd context (was landed in Distributed Autograd - FAST mode backward pass implementation. #27022).

Differential Revision: D17819153
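
For illustration, the condition described in point 1 boils down to the following check. This is a hypothetical pure-Python mirror of the logic; the real check lives in the C++ send path of this diff.

import torch

# Hypothetical helper: autograd metadata is attached to an outgoing RPC only
# when a dist autograd context is valid AND at least one tensor requires grad.
def needs_autograd_metadata(tensors, has_valid_context):
    return has_valid_context and any(t.requires_grad for t in tensors)

# Plain tensors never trigger autograd metadata, even inside a valid context.
assert not needs_autograd_metadata([torch.ones(2), torch.zeros(2)], has_valid_context=True)
assert needs_autograd_metadata([torch.ones(2, requires_grad=True)], has_valid_context=True)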

1. add autograd send/recv functions for python rpc call
2. make changes to support nested python rpc calls
3. disallow nested dist autograd context

Differential Revision: [D17819153](https://our.internmc.facebook.com/intern/diff/D17819153/)

[ghstack-poisoned]
@pytorchbot pytorchbot added the oncall: distributed and module: pybind labels Oct 8, 2019
zhaojuanmao added a commit that referenced this pull request Oct 8, 2019
Pull Request resolved: #27576

1. add autograd send/recv functions for python rpc call
2. make changes to support nested python rpc calls
3. disallow nested dist autograd context
ghstack-source-id: 91555159

Differential Revision: [D17819153](https://our.internmc.facebook.com/intern/diff/D17819153/)
@@ -52,6 +59,33 @@ DistAutogradContext* addRecvRpcBackward(
return nullptr;
}

std::shared_ptr<FutureMessage> sendMessageWithAutograd(
Member

it looks like this function doesn't always send a message with autograd (if there is no valid context). If we want to name this function sendMessageWithAutograd, should we refactor to remove this check and do the check before invoking sendMessageWithAutograd?

Contributor

@pritamdamania87 pritamdamania87 left a comment

Overall changes look great! I have mostly comments around additional tests and some code structure.


# Get send function.
self._verify_send_recv_functions_in_client(context_id, t1, t2, ret)
Contributor

We should call this verify_current_rank_context

# Now verify the autograd graph.
ctx = dist_autograd._retrieve_context(prev_rank_context_id)

self._verify_send_recv_functions_in_tensor_run(ctx)
Contributor

This should be called verify_prev_rank_context

Contributor Author

But it is not prev_rank; for a nested call, it is prev_prev_rank. Let me think about the naming more.

self.assertEqual(t2, next_funcs[1][0].variable)
self.assertEqual(0, next_funcs[1][1])
@dist_init
def test_autograd_functions_for_python_nested_call(self):
Contributor

Can we add a test where the nested RPC calls itself? Basically A->B->A?

Contributor

Also, can we add a test for more than 1 layer of nested calls? (could extend this test itself too). Something like A->B->C->D.

Contributor Author

will add them

with self.assertRaises(RuntimeError):
with dist_autograd.context() as context_id_1:
with dist_autograd.context() as context_id_2:
a = 1
Contributor

nit: you can use pass if you don't want to do anything in a block.

@@ -85,6 +85,9 @@ DistAutogradContext& DistAutogradContainer::getOrCreateContext(

const DistAutogradContext& DistAutogradContainer::newContext() {
std::lock_guard<std::mutex> guard(autograd_context_lock_);
TORCH_CHECK(
!hasValidContext(),
"Next context can be created only when there is no valid context.");
Contributor

nit: New instead of Next

@@ -146,6 +149,10 @@ int64_t DistAutogradContainer::getMaxId() {
return max_id_;
}

void DistAutogradContainer::set_current_context_id(int64_t context_id) {
Contributor

Should we add a TORCH_INTERNAL_ASSERT here such that the current_context_id_ isn't set already?

Comment on lines 154 to 159
// After processRpc() is done, clean up current_context_id_ to be -1,
// for a recv thread, current_context_id_ should always be invalid after
// processRpc() is done.
if (autogradContext != nullptr) {
autogradContainer.set_current_context_id(-1);
}
Contributor

We should clear the context id in request_callback.cpp in the operator() function (since we're guaranteed every RPC would go through here). We can do something like this:

Message RequestCallback::operator()(Message& request) const {
  ClearAutogradContextGuard guard;
  try {
    return processMessage(request);
  } catch (std::exception& e) {
    LOG(ERROR) << "Received error while processing request type "
               << request.type() << ": " << e.what();
    return createException(request, e);
  }
}

Basically in the destructor of ClearAutogradContextGuard we set the current context id to -1. Also, we shouldn't pass -1 here like this. The value -1 is an internal detail of the DistAutogradContext class. We should add a method called clearCurrentContext() which sets the value to -1.

// Wrap the response with autograd, need a new autograd message id for
// each send/recv pair.
auto& autogradContainer = DistAutogradContainer::getInstance();
AutogradMetadata responseAutogradMetadata(
Contributor

We should use sendMessageWithAutograd here.

// passed in the chain calls.
auto& autogradContainer = DistAutogradContainer::getInstance();
if (autogradContext != nullptr) {
autogradContainer.set_current_context_id(autogradContext->context_id());
Contributor

What prevents another thread setting the context id to a different value immediately after this line?

Contributor

current_context_id is a thread local variable.
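
A rough Python analogue of the thread-local behavior being described (a sketch, not the actual DistAutogradContainer API): each recv thread sees its own slot, so another thread setting a different context id cannot clobber this one.

import threading

_tls = threading.local()  # per-thread storage, analogous to thread_local in C++

def set_current_context_id(context_id):
    _tls.context_id = context_id

def current_context_id():
    # -1 mirrors the "no valid context" sentinel discussed above
    return getattr(_tls, "context_id", -1)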

// rpc call in python rpc call, original context_id from client can be
// passed in the chain calls.
auto& autogradContainer = DistAutogradContainer::getInstance();
if (autogradContext != nullptr) {
Contributor

is it expected to see MESSAGE_WITH_AUTOGRAD_REQ if there is no valid context?

Contributor Author

autogradContext == nullptr here means either that there is no valid context or that the tensors do not require grads


# Wait for the prev rank to be done with rpc.
while not prev_rank_rpc_done:
time.sleep(0.1)
Contributor

Shall we add a timeout for this?
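
One possible shape for such a wait-with-timeout, as a minimal sketch (not part of this PR): poll a predicate and fail loudly instead of spinning forever.

import time

def wait_until(predicate, timeout=10.0, interval=0.1):
    deadline = time.time() + timeout
    while not predicate():
        if time.time() >= deadline:
            raise RuntimeError("timed out waiting for condition")
        time.sleep(interval)

# usage in the test would look roughly like:
# wait_until(lambda: prev_rank_rpc_done)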

self.assertEqual(1, len(recv_functions))
self.assertEqual(ret.grad_fn, list(recv_functions.values())[0])

# Host receives tensors and actually runs tensor operations, return tensor
Contributor

Could you please elaborate a bit more in the comments? What does _in_tensor_run mean? Isn't this function verifying the autograd graph structure?

Contributor Author

yes, the naming is bad, I will change it

next_funcs = list(send_functions.values())[0].next_functions
self.assertEqual(1, len(next_funcs))
add_backward_fn = next_funcs[0][0]
self.assertEqual("AddBackward0", add_backward_fn.name())
Contributor

This assumes a very specific autograd graph structure, but the name of this function seems to suggest it could work for any ctx. Let's add more descriptions in the function comments to describe what is the expected structure, and what functions are called in the forward pass.

next_funcs = list(send_functions.values())[1].next_functions
self.assertEqual(1, len(next_funcs))
self.assertEqual(
"torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
Contributor

(This does not need to be done in this PR)

It seems autograd function checks are scattered in several places, making it a little difficult to track. Is it possible to implement one general autograd function checking utility method which takes a ctx and a nested list/tuple/dict (or we could just add our own graph structure) expected_graph? Then a test could do the forward pass, construct expected_graph right next to the forward pass (so that it is much easier to verify they match), and then call the generic autograd function checking method to verify correctness.
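
A minimal sketch of what such a checker could look like; the helper name and the (name_substring, [children]) encoding of expected_graph are hypothetical, and the distributed Send/Recv nodes would be matched by the same name substrings used elsewhere in these tests.

import torch

def verify_graph(grad_fn, expected):
    # `expected` is a nested (name_substring, [children]) structure; a node
    # matches if name_substring appears in grad_fn.name(); None entries in
    # next_functions are skipped.
    name, children = expected
    assert grad_fn is not None and name in grad_fn.name(), grad_fn
    next_fns = [fn for fn, _ in grad_fn.next_functions if fn is not None]
    assert len(next_fns) == len(children), next_fns
    for fn, child in zip(next_fns, children):
        verify_graph(fn, child)

# Local (non-RPC) usage example:
t1 = torch.rand(3, requires_grad=True)
t2 = torch.rand(3, requires_grad=True)
loss = torch.add(t1, t2).sum()
verify_graph(
    loss.grad_fn,
    ("SumBackward", [("AddBackward", [("AccumulateGrad", []), ("AccumulateGrad", [])])]),
)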


@dist_init
def test_nested_contex(self):
with self.assertRaises(RuntimeError):
Contributor

Will the RuntimeError also come with a string message? Let's also verify it is the exact RuntimeError we are expecting.

Q: is this just a temporary limitation that we do not support nested context or will it always be in that way?

Contributor Author

It is always like this; nested contexts do not work right now, as current_context_id will be cleared when the inner context exits. I'm just adding checks to make it more explicit.

@@ -146,6 +149,10 @@ int64_t DistAutogradContainer::getMaxId() {
return max_id_;
}

void DistAutogradContainer::set_current_context_id(int64_t context_id) {
Contributor

nit: Use camelCase everywhere.

pass

# Now verify the autograd graph.
ctx = dist_autograd._retrieve_context(prev_rank_context_id)
Member

This is not going to succeed with the changes in #27951, because it is possible that the prev_rank (i.e. the rank with context_id prev_rank_context_id) would have exited the autograd context manager in their thread, and thus destroyed the autograd context. Instead of using time.sleep() as above, I think we need to do something like rpc.sync_rpc(), though @pritamdamania87 mentioned that this is going away. Is there any other way to wait for all outstanding RPCs to complete?

Contributor Author

Thanks for the heads up. If your diff lands first, I will rebase on it; otherwise I will leave this implementation as it is for now, and your diff can rebase on mine and add the change properly. Does that sound good to you?

Member

Sounds good!

@@ -160,3 +268,11 @@ def test_rpc_complex_args(self):
self.assertEqual(tensors[i], next_funcs[i][0].variable)
else:
self.assertIsNone(next_funcs[i][0])

@dist_init
def test_nested_contex(self):
Member

nit: test_nested_context

zhaojuanmao added a commit that referenced this pull request Oct 16, 2019
Pull Request resolved: #27576

1. add autograd send/recv functions for python rpc call
2. make changes to support nested python rpc calls
3. disallow nested dist autograd context
ghstack-source-id: 92041048

Differential Revision: [D17819153](https://our.internmc.facebook.com/intern/diff/D17819153/)
zhaojuanmao added a commit that referenced this pull request Oct 17, 2019
Pull Request resolved: #27576

1. add autograd send/recv functions for python rpc call
2. make changes to support nested python rpc calls
3. disallow nested dist autograd context
ghstack-source-id: 92067068

Differential Revision: [D17819153](https://our.internmc.facebook.com/intern/diff/D17819153/)
zhaojuanmao added a commit that referenced this pull request Oct 17, 2019
Pull Request resolved: #27576

1. currently, even when tensors do not require grads and no grad functions are attached,
an rpc is still sent with autograd meta as long as the autograd context is valid. This is not ideal.
This diff makes changes to make sure an rpc with autograd meta is sent only if the autograd context is valid and the tensors require grads

2. meanwhile create a utility to attach autograd info and functions as needed

3. add autograd send/recv functions for python rpc call

4. make changes to support nested python rpc calls

5. disallow nested dist autograd context (was landed in #27022)
ghstack-source-id: 92090804

Differential Revision: [D17819153](https://our.internmc.facebook.com/intern/diff/D17819153/)
@zhaojuanmao
Contributor Author

Did more refactoring, addressed comments, and added more unit tests. The summary is also updated.

It is ready for another round of review now.

One pending issue: although local tests pass, CI tests still have a linker issue. Looking into it.

zhaojuanmao added a commit that referenced this pull request Oct 17, 2019
Pull Request resolved: #27576

1. currently, even when tensors do not require grads and no grad functions are attached,
an rpc is still sent with autograd meta as long as the autograd context is valid. This is not ideal.
This diff makes changes to make sure an rpc with autograd meta is sent only if the autograd context is valid and the tensors require grads

2. meanwhile create a utility to attach autograd info and functions as needed

3. add autograd send/recv functions for python rpc call

4. make changes to support nested python rpc calls

5. disallow nested dist autograd context (was landed in #27022)
ghstack-source-id: 92123039

Differential Revision: [D17819153](https://our.internmc.facebook.com/intern/diff/D17819153/)
@zhaojuanmao
Contributor Author

Added TORCH_API to the newly added functions to fix the linker issue in the CI tests.

# respectively.
# [RankDistance.PREV_PREV] represents for prev of prev rank.
# [RankDistance.PREV_PREV_PREV] represents for prev of prev of prev rank.
class RankDistance(IntEnum):
Contributor

Why don't we just use integers for the distance instead of an enum like this? That way we can extend this to multiple nested levels later on easily.

Contributor Author

I thought an enum is more descriptive, but yes, I can change it back to integers so that more levels can be added easily.
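
For illustration, the module-level bookkeeping with plain integer distances might look like this (a sketch; rpc_done and ctx_ids mirror the names that show up later in this review):

MAX_NESTED_HOPS = 3
rpc_done = [False] * (MAX_NESTED_HOPS + 1)  # rpc_done[d]: rpc finished on the rank d hops back
ctx_ids = [-1] * (MAX_NESTED_HOPS + 1)      # ctx_ids[d]: context id seen on the rank d hops back

def _set_rpc_done(ctx_id, rank_distance):
    rpc_done[rank_distance] = True
    ctx_ids[rank_distance] = ctx_id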

return torch.add(t1, t2)


def my_py_nested_call(t1, t2, dst, world_size, ttl):
Contributor

nit: instead of ttl use hops

rpc::MessageType msgType);

// Send message after autograd checking
TORCH_API std::shared_ptr<torch::distributed::rpc::FutureMessage> sendMessage(
Contributor

nit: call this sendMessageWithAutograd since we might attach autograd information here.

// and attach autograd function for each type of rpc call if it has valid
// context and tensors require grads, in this case, return RpcWithAutograd
// message; otherwise return original rpc message.
TORCH_API rpc::Message getMessageWithAutogradCheck(
Contributor

nit: remove the Check at the end.

@@ -20,9 +20,27 @@ Message createException(const Message& request, const std::exception& e) {
request.id());
}

struct ClearAutogradContextGuard {
ClearAutogradContextGuard() {
clear();
Contributor

Do we need to do this in the constructor? Destructor is enough right?

std::unique_ptr<RpcCommandBase> rpcCommand) {
Message getMessageWithAutogradCheck(
const rpc::worker_id_t dstId,
torch::distributed::rpc::Message&& wrappedRpcMsg,
Contributor

This method should take the wrappedRpc and not wrappedMsg

std::shared_ptr<FutureMessage> sendMessage(
RpcAgent& agent,
const WorkerInfo& dst,
torch::distributed::rpc::Message&& wrappedRpcMsg,
Contributor

Can we pass in wrappedRpc in both sendMessage() and getMessageWithAutogradCheck with above and inside getMessageWithAutogradCheck call toMessage on the rpc? That seems cleaner from an API standpoint.

Contributor Author

In request_callback_impl.cpp, getMessageWithAutogradCheck is called, and there we can only pass the wrappedRpcResponse message to it.

RpcAgent& agent,
const WorkerInfo& dst,
torch::distributed::rpc::Message&& wrappedRpcMsg,
MessageType msgType) {
Contributor

Why do we have msgType here? Isn't it always FORWARD_AUTOGRAD_REQ?

self._verify_graph_for_rpc_call_exec(list(send_functions.values())[0])

# Rank0->Rank1->Rank0
@dist_init
Contributor

Can we also add a test where none of the tensors require grad and verify that we don't attach send/recv functions anywhere?

Contributor Author

Yes, and we need to verify that no prev context is passed over when there are no tensors requiring grads.

MessageType msgType) {
auto& autogradContainer = DistAutogradContainer::getInstance();

if (!autogradContainer.hasValidContext() ||
Contributor

nit: Add a small comment here explaining why we do this.

@zhaojuanmao
Contributor Author

addressed comments

zhaojuanmao added a commit that referenced this pull request Oct 17, 2019
Pull Request resolved: #27576

1. currently, even when tensors do not require grads and no grad functions are attached,
an rpc is still sent with autograd meta as long as the autograd context is valid. This is not ideal.
This diff makes changes to make sure an rpc with autograd meta is sent only if the autograd context is valid and the tensors require grads

2. meanwhile create a utility to attach autograd info and functions as needed

3. add autograd send/recv functions for python rpc call

4. make changes to support nested python rpc calls

5. disallow nested dist autograd context (was landed in #27022)
ghstack-source-id: 92154535

Differential Revision: [D17819153](https://our.internmc.facebook.com/intern/diff/D17819153/)
Contributor

@mrshenli mrshenli left a comment

Looks good from my side; all my comments are nits or follow-ups.

prev_rank_rpc_done = False
prev_rank_context_id = 0
# Right now we test up to 3-layer nested rpc calls.
# rpc_done[1] and ctx_ids[1] represent rpc is done in prev rank, and context id
Contributor

rpc_done[0] is the current rank?

self.assertEqual(ret.grad_fn, recv_function)

# For a context passed from previous nested chain calls, this rank
# recevied two tensors t1 and t2, execute torch.add(t1, t2) and send result
Contributor

receives, executes, sends

self.assertEqual(next_funcs[0][0], next_funcs[1][0])

# For a context passed from previous nested chain calls, this rank
# recevied two tensors t1 and t2, forwarding t1 and t2 tensors using
Contributor

recevie -> receive

receives, and forwards

# nested rpc call to next dst. In return route, receive result tensor t3
# from next dst and forwarding t3 back to previous calls.
# For this context in this rank, it expects graph like this:
# send and recv functions while recevive and forward t1 and t2:
Contributor

recevive -> receive

while recevive and forward t1 and t2 -> for receving and forwarding t1 and t2?

# rpcSendBackward
# / \
# t1.recvRpcBackward t2.recvRpcBackward
# send and recv functions while receive and forward t3:
Contributor

ditto

def my_py_nested_call(t1, t2, dst, world_size, hops):
next_dst = (dst + 1) % world_size
if hops > 0:
return rpc.rpc_sync("worker{}".format(next_dst), my_py_nested_call,
Contributor

(Does not need to be in this PR). Let's also test async rpc calls, e.g., making multiple async calls, collect all futures in a list, and wait on all list in the end.
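
A rough fragment of the suggested async variant (assuming it runs inside a @dist_init test where the RPC agent is initialized, and that rpc_async is available alongside the rpc_sync used above):

import torch
import torch.distributed.rpc as rpc

def _async_add_many(dst_rank, n=4):
    t1 = torch.ones(3, 3, requires_grad=True)
    t2 = torch.ones(3, 3, requires_grad=True)
    # fire off several async calls and keep the futures in a list...
    futs = [
        rpc.rpc_async("worker{}".format(dst_rank), torch.add, args=(t1, t2))
        for _ in range(n)
    ]
    # ...then wait on all of them at the end, as suggested
    return [fut.wait() for fut in futs]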

} // anonymous namespace

Message RequestCallback::operator()(Message& request) const {
// For a rev thread, current context id should be invalid outside
Contributor

rev -> recv

} // anonymous namespace

Message RequestCallback::operator()(Message& request) const {
// For a rev thread, current context id should be invalid outside
// processMessage().
ClearAutogradContextGuard guard;
Contributor

(does not need to be done in this PR) It looks a little weird that we have a guard that only clears the context but does not set it. I wonder if it would be better to let the guard govern both setting and clearing the context?

Contributor Author

@mrshenli it is possibly hard to set the context here, as we can only set it after addRecvBackward() in FORWARD_AUTOGRAD_REQ

Contributor

@mrshenli mrshenli left a comment

Test failure looks unrelated:

02:16:34 Failed to reproduce exception. Expected: 
02:16:34 Traceback (most recent call last):
02:16:34   File "/var/lib/jenkins/.local/lib/python2.7/site-packages/hypothesis/core.py", line 669, in evaluate_test_data
02:16:34     result = self.execute(data, collect=True)
02:16:34   File "/var/lib/jenkins/.local/lib/python2.7/site-packages/hypothesis/core.py", line 584, in execute
02:16:34     result = self.test_runner(data, run)
02:16:34   File "/var/lib/jenkins/.local/lib/python2.7/site-packages/hypothesis/executors.py", line 58, in default_new_style_executor
02:16:34     return function(data)
02:16:34   File "/var/lib/jenkins/.local/lib/python2.7/site-packages/hypothesis/core.py", line 580, in run
02:16:34     return test(*args, **kwargs)
02:16:34   File "/var/lib/jenkins/.local/lib/python2.7/site-packages/caffe2/python/operator_test/unique_ops_test.py", line 39, in test_unique_op
02:16:34     X=hu.tensor1d(
02:16:34   File "/var/lib/jenkins/.local/lib/python2.7/site-packages/hypothesis/core.py", line 524, in test
02:16:34     result = self.test(*args, **kwargs)
02:16:34   File "/var/lib/jenkins/.local/lib/python2.7/site-packages/caffe2/python/operator_test/unique_ops_test.py", line 62, in test_unique_op
02:16:34     outputs_to_check=[0, 1] if return_remapping else [0]
02:16:34   File "/var/lib/jenkins/.local/lib/python2.7/site-packages/caffe2/python/hypothesis_test_util.py", line 417, in assertDeviceChecks
02:16:34     dc.CheckSimple(op, inputs, outputs_to_check, input_device_options)
02:16:34   File "/usr/lib64/python2.7/unittest/case.py", line 462, in assertTrue
02:16:34     raise self.failureException(msg)
02:16:34 AssertionError: False is not true

@@ -62,12 +62,15 @@ DistAutogradContext* addRecvRpcBackward(
return &autogradContext;
}

Message getMessageWithAutogradCheck(
Message getMessageWithAutograd(
const rpc::worker_id_t dstId,
torch::distributed::rpc::Message&& wrappedRpcMsg,
MessageType msgType) {
Contributor

Why do we pass msgType here? We just need to directly set FORWARD_AUTOGRAD_REQ on line 87

Contributor Author

@zhaojuanmao zhaojuanmao Oct 18, 2019

because it could be passed with FORWARD_AUTOGRAD_REQ or FORWARD_AUTOGRAD_RESP

@zhaojuanmao
Contributor Author

Yeah, the test failures are not relevant; we skipped ROCm tests for distributed as well.

# prev context id is not passed over as tensors do not require grads
with self.assertRaises(RuntimeError):
ctx = dist_autograd._retrieve_context(ctx_ids[1])

Member

I think we need a dist.barrier() here, since this test is modified to test for state on a process that is set by an RPC from another process (previously this was just testing local state). Without dist.barrier(), for example, worker 0 could run the 2nd portion of this test (where the tensors do require gradients), and create the context on worker 1. Then worker 1 can run in the 1st portion of the test, and the assert would fail.

Contributor Author

Ah, my bad, I forgot to call "self._check_rpc_done(1)" before calling _retrieve_context(ctx_ids[1]).

@facebook-github-bot
Contributor

This pull request has been merged in 56c4215.

@facebook-github-bot facebook-github-bot deleted the gh/zhaojuanmao/10/head branch October 28, 2019 22:23
thiagocrepaldi pushed a commit to thiagocrepaldi/pytorch that referenced this pull request Feb 4, 2020
Summary:
Pull Request resolved: pytorch#27576

1. currently, even when tensors do not require grads and no grad functions are attached,
an rpc is still sent with autograd meta as long as the autograd context is valid. This is not ideal.
This diff makes changes to make sure an rpc with autograd meta is sent only if the autograd context is valid and the tensors require grads

2. meanwhile create a utility to attach autograd info and functions as needed

3. add autograd send/recv functions for python rpc call

4. make changes to support nested python rpc calls

5. disallow nested dist autograd context (was landed in pytorch#27022)
ghstack-source-id: 92154535

Test Plan: unit tests

Differential Revision: D17819153

fbshipit-source-id: 37d8a85855bf591f2f2da48d475a06e870a30ea1