8
8
import cuda .core .experimental
9
9
10
10
import dask
11
- from distributed .diagnostics .nvml import get_device_index_and_uuid , has_cuda_context
11
+ from distributed .diagnostics .nvml import (
12
+ CudaDeviceInfo ,
13
+ get_device_index_and_uuid ,
14
+ has_cuda_context ,
15
+ )
12
16
13
17
from .utils import get_ucx_config
14
18
15
19
logger = logging .getLogger (__name__ )
16
20
17
21
22
+ pre_existing_cuda_context = None
23
+ cuda_context_created = None
24
+
25
+
26
+ _warning_suffix = (
27
+ "This is often the result of a CUDA-enabled library calling a CUDA runtime "
28
+ "function before Dask-CUDA can spawn worker processes. Please make sure any such "
29
+ "function calls don't happen at import time or in the global scope of a program."
30
+ )
31
+
32
+
33
+ def _get_device_and_uuid_str (device_info : CudaDeviceInfo ) -> str :
34
+ return f"{ device_info .device_index } ({ str (device_info .uuid )} )"
35
+
36
+
37
def _warn_existing_cuda_context(device_info: CudaDeviceInfo, pid: int) -> None:
    """Log a warning that a CUDA context already exists in a process.

    Parameters
    ----------
    device_info
        Device on which the context was found; rendered through
        ``_get_device_and_uuid_str``.
    pid
        ID of the process that holds the context.
    """
    device_uuid_str = _get_device_and_uuid_str(device_info)
    # The sentence ends directly after the PID ("... process ID 123. This is ...");
    # no space before the period and no trailing space after the suffix.
    logger.warning(
        f"A CUDA context for device {device_uuid_str} already exists "
        f"on process ID {pid}. {_warning_suffix}"
    )
43
+
44
+
45
def _warn_cuda_context_wrong_device(
    device_info_expected: CudaDeviceInfo, device_info_actual: CudaDeviceInfo, pid: int
) -> None:
    """Log a warning that a process holds a CUDA context on an unexpected device.

    Parameters
    ----------
    device_info_expected
        Device the worker's CUDA context should have been assigned to.
    device_info_actual
        Device the CUDA context was actually found on.
    pid
        ID of the affected worker process.
    """
    expected_device_uuid_str = _get_device_and_uuid_str(device_info_expected)
    actual_device_uuid_str = _get_device_and_uuid_str(device_info_actual)
    # Punctuation follows the device descriptions directly ("device 0 (uuid), but"
    # and "device 1 (uuid). This is ..."), and the message carries no trailing space.
    logger.warning(
        f"Worker with process ID {pid} should have a CUDA context assigned to device "
        f"{expected_device_uuid_str}, but instead the CUDA context is on device "
        f"{actual_device_uuid_str}. {_warning_suffix}"
    )
55
+
56
+
18
57
def _create_cuda_context_handler ():
19
58
if int (os .environ .get ("DASK_CUDA_TEST_SINGLE_GPU" , "0" )) != 0 :
20
59
try :
@@ -25,98 +64,76 @@ def _create_cuda_context_handler():
25
64
cuda .core .experimental .Device ().set_current ()
26
65
27
66
28
- def _warn_generic ():
29
- try :
30
- # TODO: update when UCX-Py is removed, see
31
- # https://github.com/rapidsai/dask-cuda/issues/1517
32
- import distributed .comm .ucx
67
def _create_cuda_context_and_warn():
    """Create the CUDA context and emit warnings about unexpected context state.

    A warning is logged when a CUDA context already exists before creation is
    attempted, and another when the context found afterwards lives on a device
    whose UUID differs from the first entry of ``CUDA_VISIBLE_DEVICES``
    (device ``"0"`` when the variable is unset).

    The results of both ``has_cuda_context()`` probes are stored in the
    module-level ``pre_existing_cuda_context`` and ``cuda_context_created``.

    Notes
    -----
    This is a near-duplicate of ``distributed_ucxx.ucxx.init_once``. The copy here
    exists because Dask-CUDA must also support ``protocol="tcp"`` when
    distributed-ucxx is not installed. This function runs *after* comms have
    started, which is fine for TCP because the moment the CUDA context is created
    does not matter there. distributed-ucxx keeps its own copy because, for UCX,
    the context must exist *before* UCX is initialized; Distributed initializes
    comms before preload, so this function only executes after that point.

    Raises
    ------
    Exception
        If anything goes wrong during context initialization.

    Returns
    -------
    None
    """
    global pre_existing_cuda_context, cuda_context_created

    # Only the first visible device is the one this process is expected to use.
    first_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
    expected_device = get_device_index_and_uuid(first_visible)

    # Probe before touching CUDA so a context created by someone else is detected.
    pre_existing_cuda_context = has_cuda_context()
    if pre_existing_cuda_context.has_context:
        _warn_existing_cuda_context(pre_existing_cuda_context.device_info, os.getpid())

    _create_cuda_context_handler()

    # Probe again to verify where the context ended up.
    cuda_context_created = has_cuda_context()
    if not cuda_context_created.has_context:
        return

    actual_device = cuda_context_created.device_info
    if actual_device.uuid != expected_device.uuid:
        _warn_cuda_context_wrong_device(expected_device, actual_device, os.getpid())
72
111
73
112
74
- def _initialize_ucxx ():
113
def _create_cuda_context():
    """Create the CUDA context unless distributed-ucxx has already initialized.

    If ``distributed_ucxx.ucxx`` can be imported and its ``ucxx`` attribute is not
    ``None``, UCXX is treated as already initialized (and as having warned if
    necessary), so nothing is done. Otherwise ``_create_cuda_context_and_warn()``
    is called; any exception it raises is logged rather than propagated.

    Returns
    -------
    None
    """
    try:
        import distributed_ucxx.ucxx as ucxx_comm
    except ImportError:
        # distributed-ucxx is optional; without it there is nothing to defer to.
        ucxx_comm = None

    if ucxx_comm is not None and ucxx_comm.ucxx is not None:
        # UCXX has already initialized (and warned if necessary)
        return

    try:
        _create_cuda_context_and_warn()
    except Exception:
        # Best-effort: report the failure with traceback and carry on.
        logger.error("Unable to start CUDA Context", exc_info=True)
103
129
104
130
105
- def _create_cuda_context (protocol = "ucx" ):
106
- if protocol in ("ucx" , "ucxx" ):
107
- _initialize_ucxx ()
108
- else :
109
- # Not a UCX protocol, just raise CUDA context warnings if needed.
110
- _warn_generic ()
111
-
112
-
113
131
def initialize (
114
132
create_cuda_context = True ,
115
133
enable_tcp_over_ucx = None ,
116
134
enable_infiniband = None ,
117
135
enable_nvlink = None ,
118
136
enable_rdmacm = None ,
119
- protocol = "ucx" ,
120
137
):
121
138
"""Create CUDA context and initialize UCXX configuration.
122
139
@@ -167,12 +184,11 @@ def initialize(
167
184
enable_infiniband = enable_infiniband ,
168
185
enable_nvlink = enable_nvlink ,
169
186
enable_rdmacm = enable_rdmacm ,
170
- protocol = protocol ,
171
187
)
172
188
dask .config .set ({"distributed-ucxx" : ucx_config })
173
189
174
190
if create_cuda_context :
175
- _create_cuda_context (protocol = protocol )
191
+ _create_cuda_context ()
176
192
177
193
178
194
@click .command ()
@@ -185,6 +201,5 @@ def dask_setup(
185
201
worker ,
186
202
create_cuda_context ,
187
203
):
188
- protocol = worker ._protocol .split ("://" )[0 ]
189
204
if create_cuda_context :
190
- _create_cuda_context (protocol = protocol )
205
+ _create_cuda_context ()
0 commit comments