Skip to content

Commit 037713b

Browse files
authored
Update to support CUDA context warnings for all protocols (#1548)
Update to support CUDA context warnings for all protocols, even after removal of `distributed.comm.ucx`. Depends on #1546. Partially addresses #1517. Authors: - Peter Andreas Entschev (https://github.com/pentschev) Approvers: - Tom Augspurger (https://github.com/TomAugspurger) URL: #1548
1 parent ecedb34 commit 037713b

File tree

6 files changed

+291
-93
lines changed

6 files changed

+291
-93
lines changed

dask_cuda/initialize.py

Lines changed: 90 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,52 @@
88
import cuda.core.experimental
99

1010
import dask
11-
from distributed.diagnostics.nvml import get_device_index_and_uuid, has_cuda_context
11+
from distributed.diagnostics.nvml import (
12+
CudaDeviceInfo,
13+
get_device_index_and_uuid,
14+
has_cuda_context,
15+
)
1216

1317
from .utils import get_ucx_config
1418

1519
logger = logging.getLogger(__name__)
1620

1721

22+
pre_existing_cuda_context = None
23+
cuda_context_created = None
24+
25+
26+
_warning_suffix = (
27+
"This is often the result of a CUDA-enabled library calling a CUDA runtime "
28+
"function before Dask-CUDA can spawn worker processes. Please make sure any such "
29+
"function calls don't happen at import time or in the global scope of a program."
30+
)
31+
32+
33+
def _get_device_and_uuid_str(device_info: CudaDeviceInfo) -> str:
34+
return f"{device_info.device_index} ({str(device_info.uuid)})"
35+
36+
37+
def _warn_existing_cuda_context(device_info: CudaDeviceInfo, pid: int) -> None:
38+
device_uuid_str = _get_device_and_uuid_str(device_info)
39+
logger.warning(
40+
f"A CUDA context for device {device_uuid_str} already exists "
41+
f"on process ID {pid}. {_warning_suffix}"
42+
)
43+
44+
45+
def _warn_cuda_context_wrong_device(
46+
device_info_expected: CudaDeviceInfo, device_info_actual: CudaDeviceInfo, pid: int
47+
) -> None:
48+
expected_device_uuid_str = _get_device_and_uuid_str(device_info_expected)
49+
actual_device_uuid_str = _get_device_and_uuid_str(device_info_actual)
50+
logger.warning(
51+
f"Worker with process ID {pid} should have a CUDA context assigned to device "
52+
f"{expected_device_uuid_str}, but instead the CUDA context is on device "
53+
f"{actual_device_uuid_str}. {_warning_suffix}"
54+
)
55+
56+
1857
def _create_cuda_context_handler():
1958
if int(os.environ.get("DASK_CUDA_TEST_SINGLE_GPU", "0")) != 0:
2059
try:
@@ -25,98 +64,76 @@ def _create_cuda_context_handler():
2564
cuda.core.experimental.Device().set_current()
2665

2766

28-
def _warn_generic():
29-
try:
30-
# TODO: update when UCX-Py is removed, see
31-
# https://github.com/rapidsai/dask-cuda/issues/1517
32-
import distributed.comm.ucx
67+
def _create_cuda_context_and_warn():
68+
"""Create CUDA context and warn depending on certain conditions.
3369
34-
# Added here to ensure the parent `LocalCUDACluster` process creates the CUDA
35-
# context directly from the UCX module, thus avoiding a similar warning there.
36-
cuda_visible_device = get_device_index_and_uuid(
37-
os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
38-
)
39-
ctx = has_cuda_context()
40-
if (
41-
ctx.has_context
42-
and not distributed.comm.ucx.cuda_context_created.has_context
43-
):
44-
distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
45-
46-
_create_cuda_context_handler()
47-
48-
if (
49-
distributed.comm.ucx.cuda_context_created is False
50-
or distributed.comm.ucx.cuda_context_created.has_context
51-
):
52-
ctx = has_cuda_context()
53-
if ctx.has_context and ctx.device_info != cuda_visible_device:
54-
distributed.comm.ucx._warn_cuda_context_wrong_device(
55-
cuda_visible_device, ctx.device_info, os.getpid()
56-
)
70+
Warns if a pre-existing CUDA context already existed or if the resulting CUDA
71+
context was created in the wrong device.
5772
58-
except Exception:
59-
logger.error("Unable to start CUDA Context", exc_info=True)
73+
This function is almost an identical duplicate from
74+
`distributed_ucxx.ucxx.init_once`, the duplication is necessary because Dask-CUDA
75+
needs to support `protocol="tcp"` as well, even when distributed-ucxx is not
76+
installed, but this here runs _after_ comms have started, which is fine for TCP
77+
because the time when CUDA context is created is not important. The code needs to
78+
live also in distributed-ucxx because there the time when a CUDA context is created
79+
matters, and it needs to happen _before_ UCX is initialized, but comms in
80+
Distributed is initialized before preload, and thus only after this function
81+
executes.
6082
83+
Raises
84+
------
85+
Exception
86+
If anything wrong happened during context initialization.
6187
62-
def _initialize_ucx():
63-
try:
64-
import distributed.comm.ucx
88+
Returns
89+
-------
90+
None
91+
"""
92+
global pre_existing_cuda_context, cuda_context_created
6593

66-
distributed.comm.ucx.init_once()
67-
except ModuleNotFoundError:
68-
# UCX initialization has to be delegated to Distributed, it will take care
69-
# of setting correct environment variables and importing `ucp` after that.
70-
# Therefore if ``import ucp`` fails we can just continue here.
71-
pass
94+
cuda_visible_device = get_device_index_and_uuid(
95+
os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
96+
)
97+
pre_existing_cuda_context = has_cuda_context()
98+
if pre_existing_cuda_context.has_context:
99+
_warn_existing_cuda_context(pre_existing_cuda_context.device_info, os.getpid())
100+
101+
_create_cuda_context_handler()
102+
103+
cuda_context_created = has_cuda_context()
104+
if (
105+
cuda_context_created.has_context
106+
and cuda_context_created.device_info.uuid != cuda_visible_device.uuid
107+
):
108+
_warn_cuda_context_wrong_device(
109+
cuda_visible_device, cuda_context_created.device_info, os.getpid()
110+
)
72111

73112

74-
def _initialize_ucxx():
113+
def _create_cuda_context():
75114
try:
76115
# Added here to ensure the parent `LocalCUDACluster` process creates the CUDA
77116
# context directly from the UCX module, thus avoiding a similar warning there.
78117
import distributed_ucxx.ucxx
118+
except ImportError:
119+
pass
120+
else:
121+
if distributed_ucxx.ucxx.ucxx is not None:
122+
# UCXX has already initialized (and warned if necessary)
123+
return
79124

80-
distributed_ucxx.ucxx.init_once()
81-
82-
cuda_visible_device = get_device_index_and_uuid(
83-
os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
84-
)
85-
ctx = has_cuda_context()
86-
if (
87-
ctx.has_context
88-
and not distributed_ucxx.ucxx.cuda_context_created.has_context
89-
):
90-
distributed_ucxx.ucxx._warn_existing_cuda_context(ctx, os.getpid())
91-
92-
_create_cuda_context_handler()
93-
94-
if not distributed_ucxx.ucxx.cuda_context_created.has_context:
95-
ctx = has_cuda_context()
96-
if ctx.has_context and ctx.device_info != cuda_visible_device:
97-
distributed_ucxx.ucxx._warn_cuda_context_wrong_device(
98-
cuda_visible_device, ctx.device_info, os.getpid()
99-
)
100-
125+
try:
126+
_create_cuda_context_and_warn()
101127
except Exception:
102128
logger.error("Unable to start CUDA Context", exc_info=True)
103129

104130

105-
def _create_cuda_context(protocol="ucx"):
106-
if protocol in ("ucx", "ucxx"):
107-
_initialize_ucxx()
108-
else:
109-
# Not a UCX protocol, just raise CUDA context warnings if needed.
110-
_warn_generic()
111-
112-
113131
def initialize(
114132
create_cuda_context=True,
115133
enable_tcp_over_ucx=None,
116134
enable_infiniband=None,
117135
enable_nvlink=None,
118136
enable_rdmacm=None,
119-
protocol="ucx",
120137
):
121138
"""Create CUDA context and initialize UCXX configuration.
122139
@@ -167,12 +184,11 @@ def initialize(
167184
enable_infiniband=enable_infiniband,
168185
enable_nvlink=enable_nvlink,
169186
enable_rdmacm=enable_rdmacm,
170-
protocol=protocol,
171187
)
172188
dask.config.set({"distributed-ucxx": ucx_config})
173189

174190
if create_cuda_context:
175-
_create_cuda_context(protocol=protocol)
191+
_create_cuda_context()
176192

177193

178194
@click.command()
@@ -185,6 +201,5 @@ def dask_setup(
185201
worker,
186202
create_cuda_context,
187203
):
188-
protocol = worker._protocol.split("://")[0]
189204
if create_cuda_context:
190-
_create_cuda_context(protocol=protocol)
205+
_create_cuda_context()

dask_cuda/tests/test_dask_setup.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,10 @@
1717
from dask_cuda.utils import wait_workers
1818

1919

20-
@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
21-
def test_dask_setup_function_with_mock_worker(protocol):
20+
def test_dask_setup_function_with_mock_worker():
2221
"""Test the dask_setup function directly with mock worker."""
2322
# Create a mock worker object
2423
mock_worker = Mock()
25-
mock_worker._protocol = protocol
2624

2725
with patch("dask_cuda.initialize._create_cuda_context") as mock_create_context:
2826
# Test with create_cuda_context=True
@@ -33,7 +31,7 @@ def test_dask_setup_function_with_mock_worker(protocol):
3331
create_cuda_context=True,
3432
)
3533

36-
mock_create_context.assert_called_once_with(protocol=protocol)
34+
mock_create_context.assert_called_once_with()
3735

3836
mock_create_context.reset_mock()
3937

dask_cuda/tests/test_dgx.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ def _test_ucx_infiniband_nvlink(
152152
cm_tls_priority = ["tcp"]
153153

154154
initialize(
155-
protocol="ucx",
156155
enable_tcp_over_ucx=enable_tcp_over_ucx,
157156
enable_infiniband=enable_infiniband,
158157
enable_nvlink=enable_nvlink,

0 commit comments

Comments
 (0)