forked from dask/distributed
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ucx.py
331 lines (267 loc) · 10.1 KB
/
ucx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
"""
`UCX`_ based communications for distributed.
See :ref:`communications` for more.
.. _UCX: https://github.com/openucx/ucx
"""
import ucp
import logging
import concurrent
import dask
import numpy as np
from .addressing import parse_host_port, unparse_host_port
from .core import Comm, Connector, Listener, CommClosedError
from .registry import Backend, backends
from .utils import ensure_concrete_host, to_frames, from_frames
from ..utils import ensure_ip, get_ip, get_ipv6, nbytes, log_errors
import dask
import numpy as np
logger = logging.getLogger(__name__)
# In order to avoid double init when forking/spawning new processes (multiprocess),
# we make sure only to import and initialize UCX once at first use.
# `init_once()` rebinds both placeholders below on its first call.
# NOTE(review): this file also has a module-level `import ucp` among its
# imports, which loads UCX eagerly and appears to work against the lazy
# import described here -- confirm whether that import is intended.
ucp = None
cuda_array = None
def init_once():
    """Import and initialize UCX on first use; later calls are no-ops.

    Rebinds the module-level ``ucp`` name to the imported ``ucp`` module,
    initializes it with the ``"ucx"`` section of the dask configuration, and
    rebinds the module-level ``cuda_array`` to a one-argument allocator
    ``cuda_array(n)`` producing ``n`` bytes (``uint8``) of device memory.

    The allocator is backed by RMM when it can be imported, otherwise by
    Numba. When neither is importable, ``cuda_array`` is bound to a function
    that raises ``RuntimeError`` when called.
    """
    global ucp, cuda_array
    if ucp is not None:
        # A previous call already imported UCX.
        return

    import ucp as _ucp

    ucp = _ucp
    ucx_options = dask.config.get("ucx")
    ucp.init(options=ucx_options, env_takes_precedence=True)

    # Find the function, `cuda_array()`, to use when allocating new CUDA
    # arrays: prefer RMM, then fall back to Numba.
    try:
        import rmm
    except ImportError:
        pass
    else:

        def cuda_array(n):
            return rmm.device_array(n, dtype=np.uint8)

        return

    try:
        import numba.cuda
    except ImportError:
        pass
    else:

        def cuda_array(n):
            return numba.cuda.device_array((n,), dtype=np.uint8)

        return

    # Neither allocator is available: fail only when a CUDA array is needed.
    def cuda_array(n):
        raise RuntimeError(
            "In order to send/recv CUDA arrays, Numba or RMM is required"
        )
class UCX(Comm):
    """Comm object using UCP.

    Parameters
    ----------
    ep : ucp.Endpoint
        The UCP endpoint.
    local_addr : str
        Address of this end of the connection. When truthy it must start
        with ``ucx``; a falsy value (e.g. ``None``) skips that check.
    peer_addr : str
        Address of the remote end; must start with ``ucx``.
    deserialize : bool, default True
        Whether to deserialize data in :meth:`distributed.protocol.loads`

    Notes
    -----
    The read-write cycle uses the following pattern:

    Each msg is serialized into a number of "data" frames. We prepend these
    real frames with three metadata frames, each sent as a NumPy array:

    1. nframes : a single ``uint64`` holding the number of data frames.
    2. is_cudas : one boolean per data frame, true when the frame exposes
       ``__cuda_array_interface__``; such frames are received into a buffer
       obtained from ``cuda_array``.
    3. sizes : one ``uint64`` per data frame giving its size in bytes.

    Data frames of size zero are described in the metadata but never sent.

    The expected read cycle is

    1. Read the frame describing number of frames
    2. Read the frame describing whether each data frame is gpu-bound
    3. Read the frame describing the size of each data frame
    4. Read all the non-empty data frames.
    """

    def __init__(self, ep, local_addr: str, peer_addr: str, deserialize=True):
        Comm.__init__(self)
        self._ep = ep
        if local_addr:
            assert local_addr.startswith("ucx")
        assert peer_addr.startswith("ucx")
        self._local_addr = local_addr
        self._peer_addr = peer_addr
        self.deserialize = deserialize
        self.comm_flag = None
        logger.debug("UCX.__init__ %s", self)

    @property
    def local_address(self) -> str:
        """The local address given at construction."""
        return self._local_addr

    @property
    def peer_address(self) -> str:
        """The peer address given at construction."""
        return self._peer_addr

    async def write(
        self,
        msg: dict,
        serializers=("cuda", "dask", "pickle", "error"),
        on_error: str = "message",
    ):
        """Serialize ``msg`` into frames and send them over the endpoint.

        Parameters
        ----------
        msg : dict
            The message to send (may also be a list of dicts for batched
            messages).
        serializers : tuple of str, optional
            Passed to ``to_frames``; ``None`` selects the default tuple.
        on_error : str
            Passed to ``to_frames``.

        Returns
        -------
        int
            Total number of bytes in the data frames.

        Raises
        ------
        CommClosedError
            If the comm is already closed, or if UCX raises while sending
            (in which case the endpoint is aborted first).
        """
        with log_errors():
            if self.closed():
                raise CommClosedError("Endpoint is closed -- unable to send message")
            try:
                if serializers is None:
                    serializers = ("cuda", "dask", "pickle", "error")
                # msg can also be a list of dicts when sending batched messages
                frames = await to_frames(
                    msg, serializers=serializers, on_error=on_error
                )
                # Compute each frame's size once; it is needed for the
                # metadata, the empty-frame check and the return value.
                sizes = [nbytes(f) for f in frames]

                # Send meta data
                await self.ep.send(np.array([len(frames)], dtype=np.uint64))
                await self.ep.send(
                    np.array(
                        [hasattr(f, "__cuda_array_interface__") for f in frames],
                        # Builtin ``bool``: the ``np.bool`` alias was
                        # deprecated and is missing from NumPy 1.24+.
                        dtype=bool,
                    )
                )
                await self.ep.send(np.array(sizes, dtype=np.uint64))

                # Send frames; empty frames are skipped because the reader
                # reconstructs them from the size metadata alone.
                for frame, size in zip(frames, sizes):
                    if size > 0:
                        await self.ep.send(frame)
                return sum(sizes)
            except (ucp.exceptions.UCXBaseException):
                self.abort()
                raise CommClosedError("While writing, the connection was closed")

    async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
        """Receive one message from the endpoint and rebuild it.

        Parameters
        ----------
        deserializers : tuple of str, optional
            Passed to ``from_frames``; ``None`` selects the default tuple.

        Returns
        -------
        The value produced by ``from_frames`` for the received frames.

        Raises
        ------
        CommClosedError
            If the comm is already closed, or if receiving the metadata
            frames fails with a UCX error or is cancelled (in which case the
            endpoint is aborted first).
        """
        # Import the submodule explicitly: a bare ``import concurrent`` does
        # not guarantee that ``concurrent.futures`` is loaded.
        import concurrent.futures

        with log_errors():
            if self.closed():
                raise CommClosedError("Endpoint is closed -- unable to read message")
            if deserializers is None:
                deserializers = ("cuda", "dask", "pickle", "error")
            try:
                # Recv meta data
                nframes = np.empty(1, dtype=np.uint64)
                await self.ep.recv(nframes)
                is_cudas = np.empty(nframes[0], dtype=bool)
                await self.ep.recv(is_cudas)
                sizes = np.empty(nframes[0], dtype=np.uint64)
                await self.ep.recv(sizes)
            except (
                ucp.exceptions.UCXBaseException,
                # Public name for the same class as
                # ``concurrent.futures._base.CancelledError``.
                concurrent.futures.CancelledError,
            ):
                self.abort()
                raise CommClosedError("While reading, the connection was closed")
            else:
                # Recv frames
                frames = []
                for is_cuda, size in zip(is_cudas.tolist(), sizes.tolist()):
                    if size > 0:
                        if is_cuda:
                            frame = cuda_array(size)
                        else:
                            frame = np.empty(size, dtype=np.uint8)
                        await self.ep.recv(frame)
                        frames.append(frame)
                    else:
                        # Empty frames were never sent; recreate them locally.
                        if is_cuda:
                            frames.append(cuda_array(size))
                        else:
                            frames.append(b"")
                msg = await from_frames(
                    frames, deserialize=self.deserialize, deserializers=deserializers
                )
                return msg

    async def close(self):
        """Close the endpoint, if any, and mark this comm as closed."""
        if self._ep is not None:
            await self._ep.close()
            self._ep = None

    def abort(self):
        """Abort the endpoint, if any, and mark this comm as closed."""
        if self._ep is not None:
            self._ep.abort()
            self._ep = None

    @property
    def ep(self):
        """The underlying endpoint.

        Raises
        ------
        CommClosedError
            If the endpoint has already been closed or aborted.
        """
        if self._ep is not None:
            return self._ep
        else:
            raise CommClosedError("UCX Endpoint is closed")

    def closed(self):
        """Return True once the endpoint has been closed or aborted."""
        return self._ep is None
class UCXConnector(Connector):
    """Connector that opens outgoing ``ucx://`` connections."""

    prefix = "ucx://"
    comm_class = UCX
    encrypted = False

    async def connect(self, address: str, deserialize=True, **connection_args) -> UCX:
        """Create an endpoint to ``address`` and wrap it in ``comm_class``.

        ``address`` is parsed with ``parse_host_port``; UCX is initialized
        via ``init_once()`` before the endpoint is created. The resulting
        comm has no local address, and its peer address is ``prefix``
        followed by ``address``. ``connection_args`` are accepted but unused.
        """
        logger.debug("UCXConnector.connect: %s", address)
        host, port_number = parse_host_port(address)
        init_once()
        endpoint = await ucp.create_endpoint(host, port_number)
        comm_kwargs = dict(
            local_addr=None,
            peer_addr=self.prefix + address,
            deserialize=deserialize,
        )
        return self.comm_class(endpoint, **comm_kwargs)
class UCXListener(Listener):
    """Listener that accepts incoming UCX connections.

    Each accepted endpoint is wrapped in a ``UCX`` comm and handed to
    ``comm_handler`` when one was supplied.
    """

    prefix = UCXConnector.prefix
    comm_class = UCXConnector.comm_class
    encrypted = UCXConnector.encrypted

    def __init__(
        self, address: str, comm_handler: None, deserialize=False, **connection_args
    ):
        # NOTE(review): `comm_handler` is annotated as None, yet `start()`
        # awaits it with the new comm when it is truthy -- the annotation
        # looks wrong; confirm the intended type.
        full_address = address if address.startswith("ucx") else "ucx://" + address
        host_and_port = parse_host_port(full_address, default_port=0)
        self.ip, self._input_port = host_and_port
        self.comm_handler = comm_handler
        self.deserialize = deserialize
        self._ep = None  # type: ucp.Endpoint
        self.ucp_server = None  # set by start(), cleared by stop()
        self.connection_args = connection_args

    @property
    def port(self):
        """Port reported by the UCX listener created in ``start()``."""
        server = self.ucp_server
        return server.port

    @property
    def address(self):
        """This listener's ``ucx://ip:port`` address."""
        pieces = ("ucx://", self.ip, ":", str(self.port))
        return "".join(pieces)

    def start(self):
        """Initialize UCX and create the UCX listener on the requested port."""

        async def serve_forever(client_ep):
            # NOTE(review): the listener's own address is used for both the
            # local and the peer address of the accepted comm.
            own_address = self.address
            comm = UCX(
                client_ep,
                local_addr=own_address,
                peer_addr=own_address,
                deserialize=self.deserialize,
            )
            handler = self.comm_handler
            if handler:
                await handler(comm)

        init_once()
        self.ucp_server = ucp.create_listener(serve_forever, port=self._input_port)

    def stop(self):
        """Drop the reference to the UCX listener."""
        self.ucp_server = None

    def get_host_port(self):
        """Return the ``(ip, port)`` pair of this listener."""
        # TODO: TCP raises if this hasn't started yet.
        return (self.ip, self.port)

    @property
    def listen_address(self):
        """``prefix`` followed by the unparsed ``(ip, port)`` pair."""
        host, port = self.get_host_port()
        return self.prefix + unparse_host_port(host, port)

    @property
    def contact_address(self):
        """Like ``listen_address`` but with the host made concrete."""
        raw_host, port = self.get_host_port()
        concrete_host = ensure_concrete_host(raw_host)  # TODO: ensure_concrete_host
        return self.prefix + unparse_host_port(concrete_host, port)

    @property
    def bound_address(self):
        """The ``(ip, port)`` pair, same as ``get_host_port()``."""
        # TODO: Does this become part of the base API? Kinda hazy, since
        # we exclude in for inproc.
        return self.get_host_port()
class UCXBackend(Backend):
    """Backend tying together the UCX connector, listener and address helpers."""

    # I / O

    def get_connector(self):
        """Return a new ``UCXConnector``."""
        connector = UCXConnector()
        return connector

    def get_listener(self, loc, handle_comm, deserialize, **connection_args):
        """Return a new ``UCXListener`` for ``loc``."""
        listener = UCXListener(loc, handle_comm, deserialize, **connection_args)
        return listener

    # Address handling
    # This duplicates BaseTCPBackend

    def get_address_host(self, loc):
        """Return the host part of ``loc``."""
        host_port = parse_host_port(loc)
        return host_port[0]

    def get_address_host_port(self, loc):
        """Return ``loc`` parsed by ``parse_host_port``."""
        return parse_host_port(loc)

    def resolve_address(self, loc):
        """Return ``loc`` with its host passed through ``ensure_ip``."""
        host, port = parse_host_port(loc)
        ip = ensure_ip(host)
        return unparse_host_port(ip, port)

    def get_local_address_for(self, loc):
        """Return a port-less local address suitable for reaching ``loc``.

        A colon in the resolved host selects ``get_ipv6``; otherwise
        ``get_ip`` is used.
        """
        host, _port = parse_host_port(loc)
        ip = ensure_ip(host)
        lookup = get_ipv6 if ":" in ip else get_ip
        return unparse_host_port(lookup(ip), None)
# Register an instance of this backend under the "ucx" key of the `backends`
# mapping imported from `.registry`.
backends["ucx"] = UCXBackend()