[CUDA graphs] Cuda RNG-safe graph capture and replay bindings #48875
Conversation
💊 CI failures summary and remediations
As of commit 1d69b27 (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patterns
The following CI failures do not appear to be due to upstream breakages: pytorch_linux_xenial_py3_6_gcc5_4_test (1/1), Step: "Run tests" (full log | diagnosis details | 🔁 rerun)
Lotta test failures lol
namespace at {
namespace cuda {

struct TORCH_CUDA_API CUDAGraph {
Probably should have a // See Note [CUDA Graph... reference somewhere here; probably framed as "why does PyTorch need its own CUDAGraph rep".
c10::nullopt, cuda::detail::getDefaultCUDAGenerator());

auto options = TensorOptions().device(at::kCUDA).dtype(at::kLong);
offset_extragraph_ = at::empty({1}, options);
you sure you want empty here?
Yes, it doesn't matter if we run on garbage; all we're trying to do is bake in the right pointers. If there's data-dependent control flow in the graphed region, we have no business graphing it anyway.
Also, deliberately beginning capture with garbage here will stress any unwanted/unexpected data dependence, and help catch failures.
I mean, half the time the garbage is going to be zeros; if you really want to find unexpected data dependence, you should fill it with some sentinel. (But point duly taken.)
offset_extragraph_ = offset_extragraph;
offset_intragraph_ = 0;
graph_expects_this_gen_ = true;
What if you have multiple captures going on at the same time?
Good point. My first thought was to make `graph_expects_this_gen_` an argument to those consumer-called members, but nvm, that idea is nonsense. There are ways I can try to make this safe which we can discuss in more detail, but I can't think of a simple one. The simplest answer is to disallow that usage: only one capture may be underway at a time.
test/test_cuda.py
Outdated
a = torch.zeros((1000,), device="cuda")
a += 1
g = torch.cuda.Graph()
g.capture_begin()
you sure you don't want a context manager for this?
The manager could also take care of getting onto a non-default stream
I think eventually a context manager is the right exposure, but right now, this won't be documented, and exists to serve experiments. In experiments, I don't want to exit gracefully from graphed regions if capture fails. I want all hell to break loose so we see it.
I'm hacking on a local UX that wraps the exposures here and takes care of non-default stream. The goal of this PR is to introduce a minimal capture exposure to enable flexible UX hacking.
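For what it's worth, here is a minimal sketch of the kind of wrapper being discussed (purely hypothetical, not part of this PR), assuming the `torch.cuda.Graph` binding exercised in the tests above and that capture must run on a non-default stream:

```python
import contextlib
import torch

@contextlib.contextmanager
def graph_capture(graph, stream=None):
    # Capture has to happen on a non-default stream; create one unless the caller supplies it.
    stream = stream if stream is not None else torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        graph.capture_begin()
        try:
            yield graph
        finally:
            graph.capture_end()
    torch.cuda.current_stream().wait_stream(stream)

# Hypothetical usage:
#   g = torch.cuda.Graph()
#   with graph_capture(g):
#       a += 1        # work recorded into the graph
#   g.replay()
```

Whether `capture_end()` should still run when the body raises is exactly the fail-loudly concern above; a real wrapper might skip the `finally` and let everything blow up instead.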
test/test_cuda.py
Outdated
self.assertTrue(a.sum().item() == 3000.)

def test_graph_rng_functional(self):
    # The caching allocator isn't yet graph-safe.
Yeah, so what's the plan for this anyway? Thread local allocator override?
https://www.youtube.com/watch?v=Lf6aMcnR9WQ
Long term (many months), I'm working with the cudaMallocAsync people to ensure cudaMallocAsync plays well with CUDA graphs. Hopefully, once we integrate cudaMallocAsync, capturing allocations will "just work."
Short term, Arslan Zulfiqar, one of our people, hacked the PyTorch allocator to request and reserve a stream silo for the graph(s): graph capture gets its own memory silo/pool in which it can reuse memory, and it won't be affected by other eager allocations. My very next task after this PR is to look into upstreaming his changes or some variation.
So, are you saying, when cudaMallocAsync becomes a thing, we get rid of PyTorch's CUDA caching allocator?
I think that would be a sad waste, not to mention a loss of control for corner cases when you need it. In a perfect world, cudaMallocAsync and the PyTorch allocator would be alternative, user-selectable backends for a common allocator interface (which could also accept external allocators that support the same interface, #43144). I do think cudaMallocAsync should be the default.
The implementation bits seem fine. But the UX may need some smithing. Is there a plan on record here? (@ngimel?)
Codecov Report
@@           Coverage Diff            @@
##           master   #48875    +/-  ##
========================================
  Coverage   80.69%   80.70%
========================================
  Files        1871     1871
  Lines      202062   202064      +2
========================================
+ Hits       163064   163071      +7
+ Misses      38998    38993      -5
test/test_cuda.py
Outdated
with torch.cuda.stream(s1):
    a = torch.zeros((1000,), device="cuda")
    a += 1
    g = torch.cuda.Graph()
This API name doesn't look very private to me! 🤣
🤫
Alright, I can hold off on importing it as Graph for now, so usage (for experimenters) would be:
g = torch._C._CudaGraphBase()
g.capture_begin/capture_end/replay()
Just make it torch.cuda._Graph for now; rename it when we have a better idea what we want to do with it.
I guess I'm OK with merging this on an experimental basis, but right now it's too public-looking for my taste.
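For reference, an end-to-end sketch of how the renamed experimental binding might be exercised. This is an illustration only: the `torch.cuda._Graph` name follows the suggestion above rather than anything merged, and the captured region sticks to in-place ops since the caching allocator isn't graph-safe yet.

```python
import torch

side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())

static_input = torch.zeros(1000, device="cuda")
static_output = torch.zeros(1000, device="cuda")

with torch.cuda.stream(side_stream):
    g = torch.cuda._Graph()             # hypothetical name, per the rename suggested above
    g.capture_begin()
    static_output.copy_(static_input)   # in-place only: no allocations during capture
    static_output.add_(1)
    g.capture_end()
torch.cuda.current_stream().wait_stream(side_stream)

# Replays re-run the captured kernels against the same baked-in buffers,
# so refresh static_input before each replay and read static_output after.
static_input.fill_(2.0)
g.replay()
print(static_output[:3])   # expected: tensor([3., 3., 3.], device='cuda:0')
```

As the thread notes, none of this is documented UX yet; it exists to serve experiments.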
this seems ok. @ngimel do you want to land it?
Let me take a last look and I will.
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Implements #51075 (comment) and additions discussed offline with ezyang, ngimel. (Calling it "simple" is charitable, but it's not too bad.)

[High level strategy](https://github.com/pytorch/pytorch/pull/51436/files#diff-acc6337586bf9cdcf0a684380779300ec171897d05b8569bf439820dc8c93bd5R57-R82)

The current design aggregates stats from private pools with the ordinary pools, which may or may not be what we want. Instead of adding PrivatePools as an internal feature of DeviceAllocator, I could inherit from DeviceAllocator (eg `DevicePrivateAllocator : public DeviceAllocator`) and create separate per-graph instances of the inherited class. I'm not sure if that would be better.

Graph bindings in Python are almost unchanged from #48875:

```python
# Same bindings as 48875, but now implicitly grabs a private mempool
graph1.capture_begin()
graph1.capture_end()

# pool=... is new. It hints that allocations during graph2's capture may share graph1's mempool
graph2.capture_begin(pool=graph1.pool())
graph2.capture_end()

# graph3 also implicitly creates its own mempool
graph3.capture_begin()
graph3.capture_end()
```

Test plan (other suggestions appreciated):
- [x] Stop maintaining manual references for all the tensors in my existing graphs+RNG tests. If private pools somehow give bad allocations, they should start failing intermittently. They run eager ops and eager allocations mixed with graph replays, so they may expose if eager ops and replays corrupt each other.
- [x] `test_graph_two_successive`: Capture successive graphs, with the second graph using the first graph's result. Try with and without sharing a pool. Check results, also check memory stats to confirm sharing a pool saves memory.
- [x] `test_graph_concurrent_replay`: Capture some graphs in separate private pools, replay them concurrently in different streams, check the results to make sure they don't corrupt each other's memory. Capture some graphs with a shared pool, replay them concurrently in different streams, check results, confirm they DO corrupt each other's memory.
- [x] `test_graph_three_successive`: A three-graph case, checking the safe and unsafe replay patterns in [Restrictions of the Strawman API](#51075).
- [x] `test_graph_memory_stats_and_use_result_after_destroy_graph`: Comprehensively check torch.cuda.memory_stats() changes that result from graph capture and delete. Check that a tensor ref created during capture and held after graph delete stays valid until the tensor itself is deleted.

Pull Request resolved: #51436
Reviewed By: mruberry
Differential Revision: D26993790
Pulled By: ngimel
fbshipit-source-id: a992eaee1b8c23628e7b388a5a3c26e0f80e54da
Part 2 of #46148 refactor. (part 1 was #48694.)
Contains
Diffs compile and tests pass on my machine (Ubuntu 20.04, CUDA 11.0), but it needs fine-tuning for many CI builds.
See Note [CUDA Graph-safe RNG states] for the strategy, based on #46148 (comment).
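As a rough illustration of what "RNG-safe" buys users here (a hypothetical sketch rather than a test from this PR, assuming the experimental torch.cuda._Graph binding and an allocation-free captured region):

```python
import torch

torch.cuda.manual_seed(5)

side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())

buf = torch.zeros(1000, device="cuda")
with torch.cuda.stream(side_stream):
    g = torch.cuda._Graph()     # hypothetical experimental name
    g.capture_begin()
    buf.uniform_()              # RNG op captured into the graph; in-place, so no allocation
    g.capture_end()
torch.cuda.current_stream().wait_stream(side_stream)

# The captured kernels read their Philox offset from a device tensor rather than a
# baked-in constant, so each replay can draw fresh randomness from the global CUDA
# generator instead of repeating the values produced during capture.
g.replay()
first = buf.clone()
g.replay()
second = buf.clone()
assert not torch.equal(first, second)
```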