Closed

29 commits
652f7e2 possibly best approach (Nov 24, 2020)
8b64bb6 Merge remote-tracking branch 'upstream/master' into cudagraphs_bindings (Nov 24, 2020)
26b07da Tentative files i'll need (Nov 25, 2020)
820c753 Works on my machine (for minimal example) (Nov 26, 2020)
31a26ee Tentative tests (Dec 1, 2020)
dbb24ba Merge remote-tracking branch 'upstream/master' into cudagraphs_bindings (Dec 1, 2020)
160b42e steps in the right direction (Dec 1, 2020)
9ae4fd9 compiles (Dec 3, 2020)
1386e1b Tests work! (Dec 3, 2020)
e0bb708 Distributions test passes (Dec 4, 2020)
abe0289 merging master (Dec 5, 2020)
f61bbe2 no need to include cuda.h in CUDAGeneratorImpl.cpp (Dec 5, 2020)
768123e Error-check gen, skip tests for Cuda < 11 (Dec 8, 2020)
85f767f Merge remote-tracking branch 'upstream/master' into cudagraphs_bindings (Dec 8, 2020)
fbb4c43 Only import torch.cuda.Graph for cuda>=11, maybe this clears some tes… (Dec 8, 2020)
346279a two ed comments (Dec 9, 2020)
158223b clever hansing CudaGraphBase import failures in CI (Dec 9, 2020)
6980b84 more clever hans-ing: move test_graph* to TestCuda (Dec 9, 2020)
fbf0089 clever hans round 3 (Dec 9, 2020)
dd604c5 clever hans round 4 (Dec 9, 2020)
c7c87d4 clever hans round 5: add reset() to __init__.pyi.in (Dec 10, 2020)
83f0e81 reset takes self (Dec 10, 2020)
bd44b17 Merge remote-tracking branch 'upstream/master' into cudagraphs_bindings (Dec 10, 2020)
49d8a19 Merge remote-tracking branch 'upstream/master' into cudagraphs_bindings (Dec 10, 2020)
fa1ac41 looks like skipIfRocm did not, in fact, skip if rocm (Dec 10, 2020)
1d9a903 explain __del__ (Dec 10, 2020)
12ac954 graph destruction in destructor with warnings (Dec 10, 2020)
b374e76 typo (Dec 11, 2020)
1d69b27 s/module/module_ (Dec 11, 2020)
3 changes: 2 additions & 1 deletion BUILD.bazel
@@ -339,7 +339,8 @@ filegroup(
"aten/src/ATen/cuda/CUDABlas.cpp",
"aten/src/ATen/cuda/CUDASolver.cpp",
"aten/src/ATen/cuda/CUDAContext.cpp",
"aten/src/ATen/cuda/CUDAGenerator.cpp",
"aten/src/ATen/cuda/CUDAGeneratorImpl.cpp",
"aten/src/ATen/cuda/CUDAGraph.cpp",
"aten/src/ATen/cuda/CuSparseHandlePool.cpp",
"aten/src/ATen/cuda/CublasHandlePool.cpp",
"aten/src/ATen/cuda/CusolverDnHandlePool.cpp",
5 changes: 3 additions & 2 deletions aten/src/ATen/CUDAGeneratorImpl.h
@@ -131,8 +131,8 @@ struct TORCH_CUDA_API CUDAGeneratorImpl : public c10::GeneratorImpl {
uint64_t seed() override;
void set_philox_offset_per_thread(uint64_t offset);
uint64_t philox_offset_per_thread();
- void graph_prologue(int64_t* offset_extragraph);
- uint64_t graph_epilogue();
+ void capture_prologue(int64_t* offset_extragraph);
+ uint64_t capture_epilogue();
PhiloxCudaState philox_cuda_state(uint64_t increment);

// Temporarily accommodates call sites that use philox_engine_inputs.
@@ -147,6 +147,7 @@ struct TORCH_CUDA_API CUDAGeneratorImpl : public c10::GeneratorImpl {
uint64_t philox_offset_per_thread_ = 0;
int64_t* offset_extragraph_;
uint32_t offset_intragraph_ = 0;
+ bool graph_expects_this_gen_ = false;
};

namespace cuda {
45 changes: 21 additions & 24 deletions aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
@@ -84,7 +84,7 @@ Generator createCUDAGenerator(DeviceIndex device_index) {
*/
CUDAGeneratorImpl::CUDAGeneratorImpl(DeviceIndex device_index)
: c10::GeneratorImpl{Device(DeviceType::CUDA, device_index),
- DispatchKeySet(c10::DispatchKey::CUDA)} {
+ DispatchKeySet(c10::DispatchKey::CUDA)} {
at::cuda::assertNotCapturing("Cannot construct a new CUDAGeneratorImpl");
}

@@ -101,20 +101,18 @@ void CUDAGeneratorImpl::set_current_seed(uint64_t seed) {
}

#define CAPTURE_DEFAULT_GENS_MSG \
"Non-default (user-constructed) CUDA RNG generators cannot be used " \
"in regions captured by CUDA graphs. " \
"If you need a non-default CUDA generator in a captured region, " \
"please file an issue."
"In regions captured by CUDA graphs, you may only use the default CUDA RNG " \
"generator on the device that's current when capture begins. " \
"If you need a non-default (user-supplied) generator, or a generator on another " \
"device, please file an issue."

/**
* Gets the current seed of CUDAGeneratorImpl.
*/
uint64_t CUDAGeneratorImpl::current_seed() const {
- TORCH_CHECK((at::cuda::currentStreamCaptureStatus() ==
-             at::cuda::CaptureStatus::None) ||
-             ((void*)this ==
-             (void*)&at::cuda::detail::getDefaultCUDAGenerator(device_.index())),
-             CAPTURE_DEFAULT_GENS_MSG);
+ // Debatable if current_seed() should be allowed in captured regions.
+ // Conservatively disallow it for now.
+ at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::current_seed");
return seed_;
}

@@ -151,25 +149,21 @@ uint64_t CUDAGeneratorImpl::philox_offset_per_thread() {
}

/**
- * Prepares this instance for a cuda graph capture region.
+ * Called by CUDAGraph to prepare this instance for a graph capture region.
* offset_extragraph is the initial offset at the start of the graphed region.
* offset_intragraph tracks the offset in the graphed region.
*/
-void CUDAGeneratorImpl::graph_prologue(int64_t* offset_extragraph) {
-  TORCH_CHECK((void*)this ==
-              (void*)&at::cuda::detail::getDefaultCUDAGenerator(device_.index()),
-              CAPTURE_DEFAULT_GENS_MSG);
+void CUDAGeneratorImpl::capture_prologue(int64_t* offset_extragraph) {
offset_extragraph_ = offset_extragraph;
offset_intragraph_ = 0;
+  graph_expects_this_gen_ = true;
Review thread:

Contributor: What if you have multiple captures going on at the same time?

mcarilli (Collaborator, Author), Dec 8, 2020: Good point. I think the best approach is to make graph_expects_this_gen_ an argument to those consumer-called members. nvm my idea is nonsense. There are ways I can try to make this safe which we can discuss in more detail, but I can't think of a simple one. The simplest answer is to disallow that usage: only one capture may be underway at a time.

}

/**
- * Finalizes a cuda graph capture region for this instance.
+ * Called by CUDAGraph to finalize a graph capture region for this instance.
*/
-uint64_t CUDAGeneratorImpl::graph_epilogue() {
-  TORCH_CHECK((void*)this ==
-              (void*)&at::cuda::detail::getDefaultCUDAGenerator(device_.index()),
-              CAPTURE_DEFAULT_GENS_MSG);
+uint64_t CUDAGeneratorImpl::capture_epilogue() {
+  graph_expects_this_gen_ = false;
return offset_intragraph_;
}

@@ -187,7 +181,7 @@ uint64_t CUDAGeneratorImpl::graph_epilogue()
* it intends to generate.
*
* Increment should be at least the number of curand() random numbers used in
- * each thread. It is the user's responsibility to make sure that the increment
+ * each thread. It is the user's responsibility to make sure the increment
* for philox is never smaller than the number of curand() calls. Increment
* value > the number of curand() calls won't harm but anything less would mean
* that you would be reusing random values from previous calls.
@@ -196,17 +190,20 @@
*/
PhiloxCudaState CUDAGeneratorImpl::philox_cuda_state(uint64_t increment) {
if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
-    TORCH_CHECK((void*)this ==
-                (void*)&at::cuda::detail::getDefaultCUDAGenerator(device_.index()),
+    TORCH_CHECK(graph_expects_this_gen_,
+                "philox_cuda_state for an unexpected CUDA generator used during capture. "
CAPTURE_DEFAULT_GENS_MSG);
uint32_t offset = this->offset_intragraph_;
TORCH_INTERNAL_ASSERT(this->offset_intragraph_ <=
-                          std::numeric_limits<uint32_t>::max() - increment);
+                          std::numeric_limits<uint32_t>::max() - increment);
this->offset_intragraph_ += increment;
return PhiloxCudaState(this->seed_,
this->offset_extragraph_,
offset);
} else {
+    TORCH_CHECK(!graph_expects_this_gen_,
+                "CUDA generator expects graph capture to be underway, "
+                "but the current stream is not capturing.");
uint64_t offset = this->philox_offset_per_thread_;
this->philox_offset_per_thread_ += increment;
return PhiloxCudaState(this->seed_, offset);
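The two branches above are what RNG consumer kernels ultimately see. As a hedged illustration (not part of this diff), here is roughly how a consumer kernel might resolve a PhiloxCudaState into a Philox seed and effective offset; the at::cuda::philox::unpack helper name and signature are assumptions based on the graph-safe RNG groundwork this PR builds on:

    #include <tuple>
    #include <curand_kernel.h>
    #include <ATen/CUDAGeneratorImpl.h>

    // Sketch only. For a captured kernel, the effective offset is the value
    // behind the baked-in offset_extragraph pointer (filled at replay time)
    // plus this kernel's fixed intragraph offset; outside capture it is the
    // ordinary per-thread offset. unpack is assumed to do that resolution.
    __global__ void rng_consumer_sketch(at::PhiloxCudaState philox_args,
                                        float* out, int n) {
      auto seed_offset = at::cuda::philox::unpack(philox_args);
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      curandStatePhilox4_32_10_t state;
      curand_init(std::get<0>(seed_offset),  // seed
                  idx,                       // subsequence: one per thread
                  std::get<1>(seed_offset),  // effective offset
                  &state);
      if (idx < n) {
        out[idx] = curand_uniform(&state);
      }
    }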
168 changes: 168 additions & 0 deletions aten/src/ATen/cuda/CUDAGraph.cpp
@@ -0,0 +1,168 @@
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDAGraph.h>
#include <ATen/Functions.h>
#include <c10/cuda/CUDAFunctions.h>

namespace at {
namespace cuda {

/**
* Note [CUDA Graph Wrapper Class]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* Q: Why do we need graph capture and launch bindings in Pytorch?
* Why can't they live in a user extension, for example?
*
* A1: Convenience.
* A2: To ensure valid numerics on replay, some native CUDA ops (like RNG ops with
* CPU statefulness) need cooperation from the capture and replay bindings
* (see Note [CUDA Graph-safe RNG states] in CUDAGeneratorImpl.h).
*
* We can't expect users to know about this cooperation. If users write capture
* bindings naively in an extension, they likely won't interact with the native
* ops properly. Their graphs would yield invalid numerics on replay.
*/
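
As a concrete illustration of the call sequence these bindings enforce, here is a minimal usage sketch assembled from the API in this diff. The stream helpers are existing c10/ATen utilities; error handling and the captured workload are elided, and this sketch is not code from the PR:

    #include <ATen/cuda/CUDAGraph.h>
    #include <c10/cuda/CUDAGuard.h>
    #include <c10/cuda/CUDAStream.h>

    void capture_and_replay_sketch() {
      at::cuda::CUDAGraph graph;

      // capture_begin() requires that the current stream is not the
      // default stream, so capture runs on a pool stream.
      at::cuda::CUDAStream side_stream = at::cuda::getStreamFromPool();
      {
        c10::cuda::CUDAStreamGuard guard(side_stream);
        graph.capture_begin();
        // ... enqueue the work to be captured, including RNG ops ...
        graph.capture_end();  // instantiates the executable graph
      }

      // Replay is allowed on any stream, including the default stream.
      graph.replay();
    }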

CUDAGraph::CUDAGraph()
// CUDAStreams may not be default-constructed.
: capture_stream_(at::cuda::getCurrentCUDAStream()) {
#if CUDA_VERSION < 11000
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
#endif
}

void CUDAGraph::capture_begin() {
#if CUDA_VERSION >= 11000
TORCH_CHECK(!has_graph_exec_,
"This CUDAGraph instance already owns a captured graph. "
"To capture a new graph, create a new instance.");

// For now, a CUDAGraph instance only accommodates the default generator on the device that's
// current when capture begins. If any op in the captured region uses a non-default generator,
// or a generator on another device, the offending generator will throw an error.
// These restrictions simplify CUDAGraph, but could be relaxed in the future:
// in principle, the underlying Cuda calls do permit cross-device ops to be captured.
auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
c10::nullopt, cuda::detail::getDefaultCUDAGenerator());

auto options = TensorOptions().device(at::kCUDA).dtype(at::kLong);
offset_extragraph_ = at::empty({1}, options);
Review thread:

Contributor: you sure you want empty here?

mcarilli (Collaborator, Author): Yes, it doesn't matter if we run on garbage; all we're trying to do is bake in the right pointers. If there's data-dependent control flow in the graphed region, we have no business graphing it anyway.

mcarilli (Collaborator, Author): Also, deliberately beginning capture with garbage here will stress any unwanted/unexpected data dependence, and help catch failures.

ezyang (Contributor), Dec 10, 2020: I mean, half the time the garbage is going to be zeros; if you really want to find unexpected data dependence, you should fill it with some sentinel. (But duly taken)


gen->capture_prologue(offset_extragraph_.data_ptr<int64_t>());

auto stream = at::cuda::getCurrentCUDAStream();

TORCH_CHECK(stream != at::cuda::getDefaultCUDAStream(),
"CUDA graphs must be captured on a non-default stream. "
"(However, after capture, it's ok to replay them on the "
"default stream.)");

capture_stream_ = stream;
capture_gen_ = gen;

// cudaStreamCaptureModeGlobal is the most conservative option to
// prevent potentially unsafe CUDA API calls during capture. See
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
AT_CUDA_CHECK(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal));

// Stashes the current graph's uuid.
cudaStreamCaptureStatus status;
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &id_));
TORCH_INTERNAL_ASSERT(status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive);
#else
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
#endif
}

void CUDAGraph::capture_end() {
#if CUDA_VERSION >= 11000
auto stream = at::cuda::getCurrentCUDAStream();

TORCH_CHECK(stream == capture_stream_,
"Capture must end on the same stream it began on.");

AT_CUDA_CHECK(cudaStreamEndCapture(capture_stream_, &graph_));
TORCH_CHECK(graph_ != NULL, "Invalid capture.");
has_graph_ = true;

// Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
// who prefer not to report error message through these arguments moving forward
// (they prefer return value, or errors on api calls internal to the capture)
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
has_graph_exec_ = true;

auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
c10::nullopt, cuda::detail::getDefaultCUDAGenerator());
TORCH_CHECK(gen == capture_gen_,
"Default CUDA RNG generator on current device at capture end "
"is different from default generator on current device "
"when capture began");
wholegraph_increment_ = gen->capture_epilogue();

// Now that we've instantiated graph_ into graph_exec_,
// we don't need graph_ anymore.
AT_CUDA_CHECK(cudaGraphDestroy(graph_));
Review thread:

Collaborator: maybe wrap graph into some RAII guard? Otherwise this will be leaking if some of the functions above error out. Or is it handled in the CUDAGraph destructor?

mcarilli (Collaborator, Author), Dec 8, 2020: It's also handled in the CUDAGraph destructor, but I don't like that (and neither does the compiler) because it can throw an exception. A local RAII guard here for graph_ is a good idea and lets me take cudaGraphDestroy out of the destructor. I'll do it. But graph_exec_ is a different story. I don't see how to clean up graph_exec_ without cudaGraphExecDestroy in either the destructor or a __del__ method in Python, and neither is a great option.

mcarilli (Collaborator, Author), Dec 9, 2020: Uh oh, using an RAII guard for graph_ means I call AT_CUDA_CHECK(cudaGraphDestroy(graph_)); in the guard destructor instead of the CUDAGraph destructor. I can't win here 😂 🔫

mcarilli (Collaborator, Author), Dec 10, 2020: Cuda 11_1 build refused to compile with exceptions in destructors. Moved graph cleanup to a reset() method called by __del__ in a thin Python wrapper (__del__ may or may not be a lesser evil than a throwing destructor, but it's the evil CI allows).

Contributor: Yeah, just ignore the error code, we do this in a few other places too.

mcarilli (Collaborator, Author): ? Do you mean I should go back to cleanup in the destructor, rather than __del__? If so, how do I tell the CI build (I believe it was pytorch_libtorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_build, possibly others) not to error on the throwing destructor? Or should I suppress the exception with a try/catch? I don't want to do the latter: silently ignoring errors is worse than catastrophically failing.

Collaborator: Yeah, other places where we ignore the error happen during final cleanup when the process exits (cublas handles and the like), so who cares; whereas the graph destructor will be called during runtime more or less regularly, so I agree ignoring the code is not ideal. I vote for __del__ unless there's a better solution.

mcarilli (Collaborator, Author): Per Slack discussion, moved cleanup back to the C++ destructor, but with warnings instead of throwing checks.

(A sketch of the RAII-guard idea from this thread appears after this file's diff.)

has_graph_ = false;
#else
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
#endif
}

void CUDAGraph::replay() {
#if CUDA_VERSION >= 11000
TORCH_CHECK(has_graph_exec_,
"Called CUDAGraph::replay without a preceding successful capture.");

{
c10::OptionalDeviceGuard device_guard{capture_stream_.device()};

// Just like any RNG consumer kernel!
auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
c10::nullopt, cuda::detail::getDefaultCUDAGenerator());
PhiloxCudaState rng_engine_inputs;
{
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_cuda_state(wholegraph_increment_);
}
offset_extragraph_.fill_(int64_t(rng_engine_inputs.offset_.val));

// graph_exec_ may be replayed in any stream.
AT_CUDA_CHECK(cudaGraphLaunch(graph_exec_, at::cuda::getCurrentCUDAStream()));
}
#else
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
#endif
}

void CUDAGraph::reset() {
#if CUDA_VERSION >= 11000
// I'd prefer these checks throw exceptions, not print warnings,
// but the destructor calls reset(), and at least one CI build
// refuses to compile with a throwing destructor.
//
// Instead of calling reset() in the destructor to clean up, I could
// call reset() in the __del__ method of a thin Python wrapper,
// in which case reset would be allowed to throw exceptions.
// But Stackoverflow does not like user-defined __del__.
// __del__ prevents Graph instances from EVER being garbage collected
// if they participate in a reference cycle.
// And exceptions thrown in __del__ only print a warning anyway.
//
// Calling reset() in the C++ destructor, with warnings instead of exceptions
// if calls fail, is the compromise we chose.
if (has_graph_) {
C10_CUDA_CHECK_WARN(cudaGraphDestroy(graph_));
}
if (has_graph_exec_) {
C10_CUDA_CHECK_WARN(cudaGraphExecDestroy(graph_exec_));
}
#else
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
#endif
}

CUDAGraph::~CUDAGraph() {
reset();
}

} // namespace cuda
} // namespace at
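
For reference, a sketch of the local RAII-guard idea floated in the capture_end review thread above. This is not code from the PR (the PR settled on warn-instead-of-throw cleanup in the destructor), and the guard's name and shape are made up:

    #include <cuda_runtime.h>
    #include <c10/cuda/CUDAException.h>

    // Hypothetical scope guard: destroys a cudaGraph_t if the enclosing
    // scope (e.g. capture_end) exits early, warning rather than throwing
    // so the cleanup path stays destructor-safe.
    struct ScopedCudaGraph {
      cudaGraph_t graph = nullptr;
      ScopedCudaGraph() = default;
      ScopedCudaGraph(const ScopedCudaGraph&) = delete;
      ScopedCudaGraph& operator=(const ScopedCudaGraph&) = delete;
      ~ScopedCudaGraph() {
        if (graph != nullptr) {
          C10_CUDA_CHECK_WARN(cudaGraphDestroy(graph));
        }
      }
    };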
43 changes: 43 additions & 0 deletions aten/src/ATen/cuda/CUDAGraph.h
@@ -0,0 +1,43 @@
#include <ATen/Tensor.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDAStream.h>
#include <ATen/CUDAGeneratorImpl.h>

namespace at {
namespace cuda {

struct TORCH_CUDA_API CUDAGraph {
Review comment (Contributor): Probably should have a // See Note [CUDA Graph... reference somewhere here; probably framed as "why does PyTorch need its own CUDAGraph rep"

CUDAGraph();
~CUDAGraph();

void capture_begin();
void capture_end();
void replay();
void reset();

protected:
#if CUDA_VERSION >= 11000
cudaGraph_t graph_ = NULL;
cudaGraphExec_t graph_exec_ = NULL;
#endif

// internal states for error checking
bool has_graph_ = false;
bool has_graph_exec_ = false;

// uuid, retrieved from Cuda
unsigned long long id_;

// Stream on which capture began
at::cuda::CUDAStream capture_stream_;

// Default generator on device where capture began
at::CUDAGeneratorImpl* capture_gen_;

// RNG state trackers
at::Tensor offset_extragraph_;
uint64_t wholegraph_increment_;
};

} // namespace cuda
} // namespace at
4 changes: 3 additions & 1 deletion aten/src/ATen/cuda/detail/CUDAHooks.cpp
@@ -163,7 +163,9 @@ bool CUDAHooks::hasPrimaryContext(int64_t device_index) const {
TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(),
"hasPrimaryContext expects a valid device index, but got device_index=", device_index);
unsigned int ctx_flags;
- int ctx_is_active;
+ // In standalone tests of cuDevicePrimaryCtxGetState, I've seen the "active" argument end up with weird
+ // (garbage-looking nonzero) values when the context is not active, unless I initialize it to zero.
+ int ctx_is_active = 0;
AT_CUDA_DRIVER_CHECK(CUDAHooks::nvrtc().cuDevicePrimaryCtxGetState(device_index, &ctx_flags, &ctx_is_active));
return ctx_is_active == 1;
}
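
The new comment refers to standalone tests of cuDevicePrimaryCtxGetState. A minimal sketch of such a standalone probe, assuming the CUDA driver API and device 0:

    #include <cuda.h>
    #include <cstdio>

    int main() {
      unsigned int ctx_flags = 0;
      // Zero-initialize: the comment above reports garbage-looking nonzero
      // values here when the primary context is not active.
      int ctx_is_active = 0;
      if (cuInit(0) != CUDA_SUCCESS) return 1;
      if (cuDevicePrimaryCtxGetState(/*dev=*/0, &ctx_flags, &ctx_is_active) != CUDA_SUCCESS) return 1;
      std::printf("flags=%u active=%d\n", ctx_flags, ctx_is_active);
      return 0;
    }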