-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[CUDA graphs] Cuda RNG-safe graph capture and replay bindings #48875
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
652f7e2
8b64bb6
26b07da
820c753
31a26ee
dbb24ba
160b42e
9ae4fd9
1386e1b
e0bb708
abe0289
f61bbe2
768123e
85f767f
fbb4c43
346279a
158223b
6980b84
fbf0089
dd604c5
c7c87d4
83f0e81
bd44b17
49d8a19
fa1ac41
1d9a903
12ac954
b374e76
1d69b27
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
#include <ATen/cuda/Exceptions.h> | ||
#include <ATen/cuda/CUDAGraph.h> | ||
#include <ATen/Functions.h> | ||
#include <c10/cuda/CUDAFunctions.h> | ||
|
||
namespace at { | ||
namespace cuda { | ||
|
||
/** | ||
* Note [CUDA Graph Wrapper Class] | ||
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
* Q: Why do we need graph capture and launch bindings in Pytorch? | ||
* Why can't they live in a user extension, for example? | ||
* | ||
* A1: Convenience. | ||
* A2: To ensure valid numerics on replay, some native CUDA ops (like RNG ops with | ||
* CPU statefulness) need cooperation from the capture and replay bindings | ||
* (see Note [CUDA Graph-safe RNG states] in CUDAGeneratorImpl.h). | ||
* | ||
* We can't expect users to know about this cooperation. If users write capture | ||
* bindings naively in an extension, they likely won't interact with the native | ||
* ops properly. Their graphs would yield invalid numerics on replay. | ||
*/ | ||
|
||
// CUDAStream has no default constructor, so capture_stream_ must be
// initialized in the member-init list rather than in the body.
CUDAGraph::CUDAGraph()
  : capture_stream_(at::cuda::getCurrentCUDAStream()) {
#if CUDA_VERSION < 11000
  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
#endif
}
|
||
// Begins stream capture on the current stream, after arming the default CUDA
// generator so RNG ops captured in the region read their Philox offset from
// a graph-owned tensor (see Note [CUDA Graph-safe RNG states]).
void CUDAGraph::capture_begin() {
#if CUDA_VERSION >= 11000
  TORCH_CHECK(!has_graph_exec_,
              "This CUDAGraph instance already owns a captured graph. "
              "To capture a new graph, create a new instance.");

  // A capture cooperates only with the default generator on the device that's
  // current when capture begins. Any op in the captured region that consults
  // a non-default generator, or a generator on another device, throws from the
  // generator side. These restrictions keep CUDAGraph simple and could be
  // relaxed later: the underlying Cuda calls do permit cross-device capture.
  auto* default_gen = get_generator_or_default<CUDAGeneratorImpl>(
      c10::nullopt, cuda::detail::getDefaultCUDAGenerator());

  // Scratch tensor the replay-time RNG offset is written into. Deliberately
  // uninitialized: capture bakes in pointers, not data, and starting from
  // garbage helps surface any unwanted data dependence in the graphed region.
  auto opts = TensorOptions().device(at::kCUDA).dtype(at::kLong);
  offset_extragraph_ = at::empty({1}, opts);

  default_gen->capture_prologue(offset_extragraph_.data_ptr<int64_t>());

  auto current_stream = at::cuda::getCurrentCUDAStream();

  TORCH_CHECK(current_stream != at::cuda::getDefaultCUDAStream(),
              "CUDA graphs must be captured on a non-default stream. "
              "(However, after capture, it's ok to replay them on the "
              "default stream.)");

  capture_stream_ = current_stream;
  capture_gen_ = default_gen;

  // cudaStreamCaptureModeGlobal is the most conservative option to
  // prevent potentially unsafe CUDA API calls during capture. See
  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
  AT_CUDA_CHECK(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal));

  // Stashes the uuid Cuda assigned to this capture.
  cudaStreamCaptureStatus capture_status;
  AT_CUDA_CHECK(cudaStreamGetCaptureInfo(current_stream, &capture_status, &id_));
  TORCH_INTERNAL_ASSERT(capture_status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive);
#else
  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
#endif
}
|
||
// Ends stream capture, instantiates the captured graph into an executable
// graph, and retires the RNG cooperation begun in capture_begin().
// Must be called on the same stream capture_begin() ran on, with the same
// default generator current.
void CUDAGraph::capture_end() {
#if CUDA_VERSION >= 11000
  auto stream = at::cuda::getCurrentCUDAStream();

  TORCH_CHECK(stream == capture_stream_,
              "Capture must end on the same stream it began on.");

  AT_CUDA_CHECK(cudaStreamEndCapture(capture_stream_, &graph_));
  // cudaStreamEndCapture can succeed yet hand back a null graph on an
  // invalidated capture; guard before instantiating.
  TORCH_CHECK(graph_ != nullptr, "Invalid capture.");
  has_graph_ = true;

  // Trailing nullptr, nullptr, 0 arguments were recommended by Cuda driver people,
  // who prefer not to report error message through these arguments moving forward
  // (they prefer return value, or errors on api calls internal to the capture)
  AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, nullptr, nullptr, 0));
  has_graph_exec_ = true;

  auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
      c10::nullopt, cuda::detail::getDefaultCUDAGenerator());
  TORCH_CHECK(gen == capture_gen_,
              "Default CUDA RNG generator on current device at capture end "
              "is different from default generator on current device "
              "when capture began");
  // Total Philox offset consumed by the whole captured region; replay()
  // reserves this much from the generator before each launch.
  wholegraph_increment_ = gen->capture_epilogue();

  // Now that we've instantiated graph_ into graph_exec_,
  // we don't need graph_ anymore.
  AT_CUDA_CHECK(cudaGraphDestroy(graph_));
  has_graph_ = false;
#else
  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
#endif
}
|
||
// Launches the captured graph on the current stream, first refreshing the
// RNG offset the captured kernels read from.
void CUDAGraph::replay() {
#if CUDA_VERSION >= 11000
  TORCH_CHECK(has_graph_exec_,
              "Called CUDAGraph::replay without a preceding successful capture.");

  {
    c10::OptionalDeviceGuard device_guard{capture_stream_.device()};

    // Behave just like any RNG consumer kernel: reserve the whole graph's
    // offset increment from the default generator under its mutex...
    auto* default_gen = get_generator_or_default<CUDAGeneratorImpl>(
        c10::nullopt, cuda::detail::getDefaultCUDAGenerator());
    PhiloxCudaState rng_state;
    {
      std::lock_guard<std::mutex> lock(default_gen->mutex_);
      rng_state = default_gen->philox_cuda_state(wholegraph_increment_);
    }
    // ...then write the granted offset into the tensor whose pointer was
    // baked into the graph during capture.
    offset_extragraph_.fill_(int64_t(rng_state.offset_.val));

    // graph_exec_ may be replayed in any stream.
    AT_CUDA_CHECK(cudaGraphLaunch(graph_exec_, at::cuda::getCurrentCUDAStream()));
  }
#else
  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
#endif
}
|
||
// Frees the captured graph and its executable instantiation, warning (not
// throwing) if the Cuda calls fail.
void CUDAGraph::reset() {
#if CUDA_VERSION >= 11000
  // I'd prefer these checks throw exceptions, not print warnings,
  // but the destructor calls reset(), and at least one CI build
  // refuses to compile with a throwing destructor.
  //
  // Instead of calling reset() in the destructor to clean up, I could
  // call reset() in the __del__ method of a thin Python wrapper,
  // in which case reset would be allowed to throw exceptions.
  // But Stackoverflow does not like user-defined __del__.
  // __del__ prevents Graph instances from EVER being garbage collected
  // if they participate in a reference cycle.
  // And exceptions thrown in __del__ only print a warning anyway.
  //
  // Calling reset() in the C++ destructor, with warnings instead of exceptions
  // if calls fail, is the compromise we chose.
  //
  // Clear the ownership flags after each destroy so reset() is idempotent:
  // without this, a manual reset() followed by the destructor (which also
  // calls reset()) would destroy the same handles twice, and replay() after
  // reset() would launch a destroyed graph_exec_.
  if (has_graph_) {
    C10_CUDA_CHECK_WARN(cudaGraphDestroy(graph_));
    has_graph_ = false;
  }
  if (has_graph_exec_) {
    C10_CUDA_CHECK_WARN(cudaGraphExecDestroy(graph_exec_));
    has_graph_exec_ = false;
  }
#else
  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
#endif
}
|
||
// Cleanup lives in reset(), which warns rather than throws on failure,
// since destructors must not throw (see the comment block in reset()).
CUDAGraph::~CUDAGraph() {
  reset();
}
|
||
} // namespace cuda | ||
} // namespace at |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#include <ATen/Tensor.h> | ||
#include <c10/core/Device.h> | ||
#include <c10/cuda/CUDAStream.h> | ||
#include <ATen/CUDAGeneratorImpl.h> | ||
|
||
namespace at { | ||
namespace cuda { | ||
|
||
struct TORCH_CUDA_API CUDAGraph { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably should have a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
CUDAGraph(); | ||
~CUDAGraph(); | ||
|
||
void capture_begin(); | ||
void capture_end(); | ||
void replay(); | ||
void reset(); | ||
|
||
protected: | ||
#if CUDA_VERSION >= 11000 | ||
cudaGraph_t graph_ = NULL; | ||
cudaGraphExec_t graph_exec_ = NULL; | ||
#endif | ||
|
||
// internal states for error checking | ||
bool has_graph_ = false; | ||
bool has_graph_exec_ = false; | ||
|
||
// uuid, retrieved from Cuda | ||
unsigned long long id_; | ||
|
||
// Stream on which capture began | ||
at::cuda::CUDAStream capture_stream_; | ||
|
||
// Default generator on device where capture began | ||
at::CUDAGeneratorImpl* capture_gen_; | ||
|
||
// RNG state trackers | ||
at::Tensor offset_extragraph_; | ||
uint64_t wholegraph_increment_; | ||
}; | ||
|
||
} // namespace cuda | ||
} // namespace at |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if you have multiple captures going on at the same time?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point.
I think the best approach is to make graph_expects_this_gen_ an argument to those consumer-called members — nvm, my idea is nonsense. There are ways I can try to make this safe which we can discuss in more detail, but I can't think of a simple one. The simplest answer is to disallow that usage: only one capture may be underway at a time.