
Make CUDAFuture remember and restore current device in callback #48789

Closed
wants to merge 4 commits into gh/lw/101/base from gh/lw/101/head

Conversation

@lw (Contributor) commented Dec 3, 2020

Stack from ghstack:

CUDAFuture aims to "capture" the current state of CUDA-related stuff when the future is marked complete (e.g., by looking at current streams and recording events on them) and then "replicate" a similar state when users synchronize with the result of the future (by synchronizing the current streams with these events).

However, one "contextual" aspect of CUDA that we weren't capturing/replicating was the current device. This diff tries to fix that. Note that we can only do this for callbacks, not for the wait() method. I don't know whether such a discrepancy between the two actually makes the overall behavior worse. I'd love to hear people's opinions on this.

Differential Revision: D25210335
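
To make the intent concrete, here is a minimal, hedged C++ sketch of the capture-and-restore idea. The struct and helper names (DeviceCapturingFuture, onMarkCompleted, wrapCallback) are hypothetical and not the actual CUDAFuture code; only c10::cuda::current_device() and c10::cuda::CUDAGuard are real APIs.

```cpp
// Minimal sketch of the capture/restore idea described above.
// Hypothetical type and method names; the real change lives in CUDAFuture.
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>

#include <functional>
#include <utility>

struct DeviceCapturingFuture {
  c10::DeviceIndex currentDevice_ = 0;

  // Mirrors the postMarkCompletedHook step: remember which device was
  // current at the moment the future was marked complete.
  void onMarkCompleted() {
    currentDevice_ = c10::cuda::current_device();
  }

  // Wrap a user callback so that it runs with the captured device set as
  // the current device; the guard restores the previous device on exit.
  std::function<void()> wrapCallback(std::function<void()> callback) const {
    const c10::DeviceIndex capturedDevice = currentDevice_;
    return [capturedDevice, callback = std::move(callback)]() {
      c10::cuda::CUDAGuard deviceGuard(capturedDevice);
      callback();
    };
  }
};
```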

@codecov

codecov bot commented Dec 3, 2020

Codecov Report

Merging #48789 (ae4e7f3) into gh/lw/101/base (b726a1b) will increase coverage by 0.02%.
The diff coverage is 64.86%.

@@                Coverage Diff                 @@
##           gh/lw/101/base   #48789      +/-   ##
==================================================
+ Coverage           80.79%   80.81%   +0.02%     
==================================================
  Files                1865     1863       -2     
  Lines              201074   200922     -152     
==================================================
- Hits               162456   162383      -73     
+ Misses              38618    38539      -79     

@dr-ci

dr-ci bot commented Dec 4, 2020

💊 CI failures summary and remediations

As of commit 4f76484 (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_py3_clang7_onnx_ort_test2 (1/1)

Step: "Run tests"

Dec 09 17:55:31 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_keypoint_rcnn FAILED [ 55%]
Dec 09 17:55:11 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_index_select_scaler_index PASSED [ 53%] 
Dec 09 17:55:11 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_inplace_arithmetic PASSED [ 54%] 
Dec 09 17:55:11 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_inplace_fill PASSED [ 54%] 
Dec 09 17:55:11 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_inplace_list PASSED [ 54%] 
Dec 09 17:55:11 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_inplace_zero PASSED [ 54%] 
Dec 09 17:55:11 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_interpolate_adaptive_pooling_error PASSED [ 54%] 
Dec 09 17:55:12 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_interpolate_downsample PASSED [ 55%] 
Dec 09 17:55:12 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_interpolate_function_substitution PASSED [ 55%] 
Dec 09 17:55:12 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_interpolate_no_shape PASSED [ 55%] 
Dec 09 17:55:13 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_interpolate_upsample PASSED [ 55%] 
Dec 09 17:55:31 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_keypoint_rcnn FAILED [ 55%] 
Dec 09 17:55:31 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_kldiv_loss PASSED [ 56%] 
Dec 09 17:55:31 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_l1_norm PASSED [ 56%] 
Dec 09 17:55:31 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_l2_norm PASSED [ 56%] 
Dec 09 17:55:31 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_layer_norm PASSED [ 56%] 
Dec 09 17:55:31 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_le PASSED [ 56%] 
Dec 09 17:55:31 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_le_scalar PASSED [ 56%] 
Dec 09 17:55:31 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_len PASSED [ 57%] 
Dec 09 17:55:31 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_len_list PASSED [ 57%] 
Dec 09 17:55:31 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_list PASSED [ 57%] 
Dec 09 17:55:31 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference::test_list_pass PASSED [ 57%] 

This comment was automatically generated by Dr. CI.

@@ -31,6 +31,8 @@ struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future {
   }
 
   void postMarkCompletedHook(const at::IValue& value) override {
+    currentDevice_ = c10::cuda::current_device();
Contributor

Hmm, why are we recording the device when the future is marked as completed, instead of remembering the device when the callback was inserted (through then/addCallback)?

Contributor Author

I guess those are two different approaches to this. The "philosophy" I was following is this: then() and addCallback() are used to perform a computation after another one is complete. If we were dealing with sync operations, one would do this:

do_sth_sync()
do_sth_later()

but with async ops this needs to change and become

fut = do_sth_async()
fut.then(do_sth_later)

In the sync scenario, do_sth_later() runs in the same "environment"/"context" as do_sth_sync() (same current device, same current streams, ...). In this diff I was trying to recreate this in the async case, by "recording" the environment's state at the end of the async operation, and then "recreating" it in the callback.

Note that the approach we have for streams is somewhat similar. We do not run the callback in the streams that were current when the callback was inserted. (To be fair, we also do not run it in the streams that were current when the async op finished, as that would mean running computations in the I/O streams, but we do synchronize the "fresh" streams with those streams.)

Note also that I am not opposed to changing this behavior and replicating the "environment" that was current when the user inserted the callback. That has its own set of advantages. For streams, for example, it allows the user to very precisely control what streams the callback will use.

Contributor

> For streams, for example, it allows the user to very precisely control what streams the callback will use.

Yep, the same goes for the device, especially when the callback is an imported function that users cannot easily change. The current behavior should be sufficient to unblock RPC use cases, so I am OK to land this PR and modify it later if necessary.
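
For reference, here is a minimal sketch of the alternative policy discussed in this thread: capturing the device when the callback is inserted via then()/addCallback() rather than when the future is marked complete. The helper name bindCurrentDevice is hypothetical and not part of this PR.

```cpp
// Hypothetical sketch of the insertion-time alternative: snapshot the device
// that is current on the thread calling then()/addCallback(), and restore it
// around the callback wherever it eventually runs.
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>

#include <functional>
#include <utility>

inline std::function<void()> bindCurrentDevice(std::function<void()> callback) {
  const c10::DeviceIndex insertionDevice = c10::cuda::current_device();
  return [insertionDevice, callback = std::move(callback)]() {
    // The guard makes insertionDevice current and restores the previous
    // device when the callback returns.
    c10::cuda::CUDAGuard deviceGuard(insertionDevice);
    callback();
  };
}
```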

@facebook-github-bot (Contributor)

This pull request has been merged in 5ab90b2.

@facebook-github-bot deleted the gh/lw/101/head branch December 14, 2020 15:17
Labels: cla signed, Merged, oncall: distributed