11#include < ATen/cuda/CUDAGreenContext.h>
22
3- namespace at ::cuda {
4- GreenContext::GreenContext (uint32_t device_id, uint32_t num_sms) {
5- #if CUDA_HAS_GREEN_CONTEXT
6- int driver_version;
7- C10_CUDA_CHECK (cudaDriverGetVersion (&driver_version));
8- TORCH_CHECK (
9- driver_version >= 12080 , " cuda driver too old to use green context!" );
10- CUcontext pctx = nullptr ;
11- C10_CUDA_DRIVER_CHECK (c10::cuda::DriverAPI::get ()->cuCtxGetCurrent_ (&pctx));
12- if (C10_UNLIKELY (!pctx)) {
13- TORCH_WARN (
14- " Attempted to create a green context but"
15- " there was no primary context! Creating a primary context..." );
16-
17- cudaFree (0 );
18- }
19-
20- CUdevice device;
21- device_id_ = device_id;
22- C10_CUDA_DRIVER_CHECK (
23- c10::cuda::DriverAPI::get ()->cuDeviceGet_ (&device, device_id));
24-
25- // Get device resources
26- CUdevResource device_resource;
27- C10_CUDA_DRIVER_CHECK (c10::cuda::DriverAPI::get ()->cuDeviceGetDevResource_ (
28- device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
29-
30- // Split resources
31- std::vector<CUdevResource> result (1 );
32- auto result_data = result.data ();
33- unsigned int nb_groups = 1 ;
34- CUdevResource remaining;
3+ #if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
4+ #include < c10/cuda/driver_api.h>
5+ #include < stdexcept>
6+ #include < vector>
7+ #define HAS_CUDA_GREEN_CONTEXT () 1
8+ #else
9+ #define HAS_CUDA_GREEN_CONTEXT () 0
10+ // Suppress unsued private field warnings as this class is not supposed to be called
11+ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED (" -Wunused-private-field" )
12+ #endif
3513
36- C10_CUDA_DRIVER_CHECK (
37- c10::cuda::DriverAPI::get ()->cuDevSmResourceSplitByCount_ (
38- result_data,
39- &nb_groups,
40- &device_resource,
41- &remaining,
42- 0 , // default flags
43- num_sms));
44-
45- TORCH_CHECK (nb_groups == 1 , " Failed to create single resource group" );
46-
47- // Generate resource descriptor
48- CUdevResourceDesc desc;
49- C10_CUDA_DRIVER_CHECK (
50- c10::cuda::DriverAPI::get ()->cuDevResourceGenerateDesc_ (
51- &desc, result_data, 1 ));
14+ namespace at ::cuda {
5215
53- // Create green context
54- // CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
55- // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
56- C10_CUDA_DRIVER_CHECK (c10::cuda::DriverAPI::get ()->cuGreenCtxCreate_ (
57- &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
16+ GreenContext::GreenContext (uint32_t device_id, uint32_t num_sms) {
17+ #if HAS_CUDA_GREEN_CONTEXT()
18+ int driver_version;
19+ C10_CUDA_CHECK (cudaDriverGetVersion (&driver_version));
20+ TORCH_CHECK (
21+ driver_version >= 12080 , " cuda driver too old to use green context!" );
22+ CUcontext pctx = nullptr ;
23+ C10_CUDA_DRIVER_CHECK (c10::cuda::DriverAPI::get ()->cuCtxGetCurrent_ (&pctx));
24+ if (C10_UNLIKELY (!pctx)) {
25+ TORCH_WARN (
26+ " Attempted to create a green context but"
27+ " there was no primary context! Creating a primary context..." );
28+
29+ cudaFree (0 );
30+ }
5831
59- // Convert to regular context
60- C10_CUDA_DRIVER_CHECK (
61- c10::cuda::DriverAPI::get ()->cuCtxFromGreenCtx_ (&context_, green_ctx_));
62- TORCH_CHECK (context_, " Green ctx conversion to regular ctx failed!" );
32+ CUdevice device;
33+ device_id_ = device_id;
34+ C10_CUDA_DRIVER_CHECK (
35+ c10::cuda::DriverAPI::get ()->cuDeviceGet_ (&device, device_id));
36+
37+ // Get device resources
38+ CUdevResource device_resource;
39+ C10_CUDA_DRIVER_CHECK (c10::cuda::DriverAPI::get ()->cuDeviceGetDevResource_ (
40+ device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
41+
42+ // Split resources
43+ std::vector<CUdevResource> result (1 );
44+ auto result_data = result.data ();
45+ unsigned int nb_groups = 1 ;
46+ CUdevResource remaining;
47+
48+ C10_CUDA_DRIVER_CHECK (
49+ c10::cuda::DriverAPI::get ()->cuDevSmResourceSplitByCount_ (
50+ result_data,
51+ &nb_groups,
52+ &device_resource,
53+ &remaining,
54+ 0 , // default flags
55+ num_sms));
56+
57+ TORCH_CHECK (nb_groups == 1 , " Failed to create single resource group" );
58+
59+ // Generate resource descriptor
60+ CUdevResourceDesc desc;
61+ C10_CUDA_DRIVER_CHECK (
62+ c10::cuda::DriverAPI::get ()->cuDevResourceGenerateDesc_ (
63+ &desc, result_data, 1 ));
64+
65+ // Create green context
66+ // CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
67+ // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
68+ C10_CUDA_DRIVER_CHECK (c10::cuda::DriverAPI::get ()->cuGreenCtxCreate_ (
69+ &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
70+
71+ // Convert to regular context
72+ C10_CUDA_DRIVER_CHECK (
73+ c10::cuda::DriverAPI::get ()->cuCtxFromGreenCtx_ (&context_, green_ctx_));
74+ TORCH_CHECK (context_, " Green ctx conversion to regular ctx failed!" );
6375#else
64- TORCH_CHECK (false , " Green Context is only supported on CUDA 12.8+!" );
76+ TORCH_CHECK (false , " Green Context is only supported on CUDA 12.8+!" );
6577#endif
6678 }
6779
6880 std::unique_ptr<GreenContext> GreenContext::create (
6981 uint32_t num_sms,
7082 std::optional<uint32_t > device_id) {
71- #if CUDA_HAS_GREEN_CONTEXT
83+ #if HAS_CUDA_GREEN_CONTEXT()
7284 if (!device_id.has_value ()) {
7385 device_id = at::cuda::current_device ();
7486 }
75- return std::make_unique <GreenContext>(device_id.value (), num_sms);
87+ return std::unique_ptr <GreenContext>(new GreenContext ( device_id.value (), num_sms) );
7688#else
7789 TORCH_CHECK (false , " Green Context is only supported on CUDA 12.8+!" );
7890#endif
7991 }
8092
8193 // Implement move operations
8294 GreenContext::GreenContext (GreenContext&& other) noexcept {
83- #if CUDA_HAS_GREEN_CONTEXT
95+ #if HAS_CUDA_GREEN_CONTEXT()
8496 device_id_ = std::exchange (other.device_id_ , -1 );
8597 green_ctx_ = std::exchange (other.green_ctx_ , nullptr );
8698 context_ = std::exchange (other.context_ , nullptr );
@@ -91,7 +103,7 @@ namespace at::cuda {
91103 }
92104
93105 GreenContext& GreenContext::operator =(GreenContext&& other) noexcept {
94- #if CUDA_HAS_GREEN_CONTEXT
106+ #if HAS_CUDA_GREEN_CONTEXT()
95107 if (this != &other) {
96108 // Clean up current resources
97109 if (green_ctx_) {
@@ -120,33 +132,17 @@ namespace at::cuda {
120132 }
121133
122134 GreenContext::~GreenContext () noexcept {
123- #if CUDA_HAS_GREEN_CONTEXT
135+ #if HAS_CUDA_GREEN_CONTEXT()
124136 C10_CUDA_DRIVER_CHECK (
125137 c10::cuda::DriverAPI::get ()->cuGreenCtxDestroy_ (green_ctx_));
126138#else
127139 TORCH_CHECK (false , " Green Context is only supported on CUDA 12.8+!" );
128140#endif
129141 }
130142
131- // Get the underlying CUDA context
132- CUcontext GreenContext::getContext () const {
133- #if CUDA_HAS_GREEN_CONTEXT
134- return context_;
135- #else
136- TORCH_CHECK (false , " Green Context is only supported on CUDA 12.8+!" );
137- #endif
138- }
139-
140- // Get the underlying green context
141- #if CUDA_HAS_GREEN_CONTEXT
142- CUgreenCtx GreenContext::getGreenContext () const {
143- return green_ctx_;
144- }
145- #endif
146-
147143 // Make this context current
148144 void GreenContext::setContext () {
149- #if CUDA_HAS_GREEN_CONTEXT
145+ #if HAS_CUDA_GREEN_CONTEXT()
150146 auto current_stream = c10::cuda::getCurrentCUDAStream ();
151147 parent_stream_ = current_stream.stream ();
152148
@@ -175,7 +171,7 @@ namespace at::cuda {
175171 }
176172
177173 void GreenContext::popContext () {
178- #if CUDA_HAS_GREEN_CONTEXT
174+ #if HAS_CUDA_GREEN_CONTEXT()
179175 // see above note about stream being hardcoded to the default stream
180176 at::cuda::CUDAEvent ev;
181177 ev.record (c10::cuda::getCurrentCUDAStream ());
0 commit comments