forked from dask/distributed
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_semaphore.py
445 lines (341 loc) · 13.3 KB
/
test_semaphore.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
import pickle
import dask
import pytest
from dask.distributed import Client
from time import sleep
from distributed import Semaphore
from distributed.comm import Comm
from distributed.core import ConnectionPool
from distributed.metrics import time
from distributed.utils_test import ( # noqa: F401
client,
cluster,
async_wait_for,
captured_logger,
cluster_fixture,
gen_cluster,
slowidentity,
loop,
)
@gen_cluster(client=True)
async def test_semaphore_trivial(c, s, a, b):
    """Acquire both leases of a two-lease semaphore, then fail on a third."""
    sem = await Semaphore(max_leases=2, name="resource_we_want_to_limit")

    # Free leases go 2 -> 1 -> 0 over the next two calls.
    first = await sem.acquire()
    assert first is True
    second = await sem.acquire()
    assert second is True

    # Nothing is left, so the third attempt must give up after the short
    # timeout instead of blocking for long.
    began = time()
    third = await sem.acquire(timeout=0.025)
    elapsed = time() - began
    assert elapsed < 0.2
    assert third is False
@gen_cluster(client=True)
async def test_serializable(c, s, a, b):
    """Check that a pickled copy of a semaphore shares the original's leases."""

    def lease_count():
        # Leases the scheduler's semaphore extension currently holds for "x".
        return len(s.extensions["semaphores"].leases["x"])

    original = await Semaphore(max_leases=2, name="x")
    acquired = await original.acquire()
    assert lease_count() == 1
    assert acquired

    clone = pickle.loads(pickle.dumps(original))
    assert clone.name == original.name
    assert clone.client.scheduler.address == original.client.scheduler.address

    # The pickle round trip itself must not touch the lease bookkeeping.
    assert lease_count() == 1

    acquired = await clone.acquire()
    assert acquired
    assert lease_count() == 2

    # Both objects address the same semaphore, so with two leases taken
    # neither of them can obtain another one.
    acquired = await original.acquire(timeout=0.025)
    assert not acquired
    acquired = await clone.acquire(timeout=0.025)
    assert not acquired
@gen_cluster(client=True)
async def test_release_simple(c, s, a, b):
    """Run ten tasks that each enter a shared semaphore as a context manager."""

    def add_one_with_lease(x, semaphore):
        # The semaphore arrives in the task as an argument and is used as a
        # context manager around the computation.
        with semaphore:
            assert semaphore.name == "x"
            incremented = x + 1
        return incremented

    limiter = await Semaphore(max_leases=2, name="x")
    pending = c.map(add_one_with_lease, list(range(10)), semaphore=limiter)
    await c.gather(pending)
@gen_cluster(client=True)
async def test_acquires_with_timeout(c, s, a, b):
    """Check that a timed acquire fails while held and works after release."""
    sem = await Semaphore(1, "x")

    got_lease = await sem.acquire(timeout=0.025)
    assert got_lease
    # The lease is still held, so a second timed attempt has to fail.
    got_lease = await sem.acquire(timeout=0.025)
    assert not got_lease

    await sem.release()
    got_lease = await sem.acquire(timeout=0.025)
    assert got_lease
    await sem.release()
def test_timeout_sync(client):
    """Check that a synchronous acquire times out inside the context manager."""
    sem = Semaphore(name="x")
    # Entering the context manager already takes a lease, so the explicit
    # acquire below cannot obtain another one.
    with sem:
        outcome = sem.acquire(timeout=0.025)
        assert outcome is False
@gen_cluster(
    client=True,
    timeout=20,
    config={
        "distributed.scheduler.locks.lease-validation-interval": "500ms",
        "distributed.scheduler.locks.lease-timeout": "500ms",
    },
)
async def test_release_semaphore_after_timeout(c, s, a, b):
    """Check that leases taken through a second client are freed after it closes."""
    sem_x = await Semaphore(name="x", max_leases=2)
    await sem_x.acquire()  # "x": one of its two leases is now in use
    sem_y = await Semaphore(name="y")

    async with Client(s.address, asynchronous=True, name="ClientB") as other_client:
        other_x = await Semaphore(name="x", max_leases=2, client=other_client)
        other_y = await Semaphore(name="y", client=other_client)

        took_x = await other_x.acquire()  # "x": both leases are in use
        assert took_x
        took_y = await other_y.acquire()
        assert took_y

        # Everything is taken, so each further attempt must time out.
        for handle in (sem_x, other_x, other_y):
            assert not (await handle.acquire(timeout=0.01))

    # `ClientB` went out of scope, so its leases should be released.
    # "x" and "y" can each be acquired once more.
    assert await sem_x.acquire()
    assert await sem_y.acquire()

    for handle in (sem_y, sem_x):
        assert not (await handle.acquire(timeout=0.5))
@gen_cluster()
async def test_async_ctx(s, a, b):
    """Check that `async with` holds a lease and frees it again on exit."""
    sem = await Semaphore(name="x")
    async with sem:
        # The context manager holds the lease, so this attempt fails.
        inside = await sem.acquire(timeout=0.025)
        assert not inside
    after_exit = await sem.acquire()
    assert after_exit
@pytest.mark.slow
def test_worker_dies():
    """Check all tasks finish although one worker kills itself holding the lease."""
    with cluster(disconnect_timeout=10) as (scheduler, workers):
        with Client(scheduler["address"]) as sync_client:
            limiter = Semaphore(name="x", max_leases=1)

            def guarded_identity(x, sem, kill_address):
                with sem:
                    from distributed.worker import get_worker

                    current = get_worker()
                    if current.address == kill_address:
                        import os

                        # Signal 15 to our own process while the lease is
                        # still held (15 is SIGTERM on POSIX systems).
                        os.kill(os.getpid(), 15)
                    return x

            doomed_address = workers[0]["address"]
            pending = sync_client.map(
                guarded_identity,
                range(100),
                sem=limiter,
                kill_address=doomed_address,
            )
            outcome = sync_client.gather(pending)

            assert sorted(outcome) == list(range(100))
@gen_cluster(client=True)
async def test_access_semaphore_by_name(c, s, a, b):
    """Check that tasks reach the same semaphore by constructing it by name."""

    def try_lease(x, release=True):
        # Report whether a lease could be obtained within 0.1s; on success
        # the lease is given back unless `release` is false.
        by_name = Semaphore(name="x")
        acquired = by_name.acquire(timeout=0.1)
        if not acquired:
            return False
        if release:
            by_name.release()
        return True

    def lease_count():
        return len(s.extensions["semaphores"].leases["x"])

    owner = await Semaphore(name="x")
    pending = c.map(try_lease, list(range(10)))
    assert all(await c.gather(pending))

    # Clean-up the state, otherwise we would get the same result when calling
    # `try_lease` with the same arguments
    del pending

    assert lease_count() == 0
    assert await owner.acquire()
    assert lease_count() == 1
    # The lease is held here, so none of the tasks can get one.
    pending = c.map(try_lease, list(range(10)))
    assert not any(await c.gather(pending))
    await owner.release()

    del pending

    # Nobody releases this time: exactly one task gets the lease.
    pending = c.map(try_lease, list(range(10)), release=False)
    outcomes = await c.gather(pending)
    assert outcomes.count(True) == 1
    assert outcomes.count(False) == 9
@gen_cluster(client=True)
async def test_close_async(c, s, a, b):
    """Check that closing with a held lease warns and clears scheduler state."""
    sem = await Semaphore(name="test")
    assert await sem.acquire()

    leftover_warning = "Closing semaphore .* but there remain unreleased leases .*"
    with pytest.warns(RuntimeWarning, match=leftover_warning):
        await sem.close()

    closed_error = "Semaphore `test` not known or already closed."
    with pytest.raises(RuntimeError, match=closed_error):
        await sem.acquire()

    # After closing, the scheduler extension must hold nothing any more.
    extension = s.extensions["semaphores"]
    for attribute in ("max_leases", "leases", "events"):
        assert not getattr(extension, attribute)
def test_close_sync(client):
    """Check that acquiring a closed semaphore raises with the sync API."""
    closed = Semaphore()
    closed.close()

    expected_message = "Semaphore .* not known or already closed."
    with pytest.raises(RuntimeError, match=expected_message):
        closed.acquire()
@gen_cluster(client=True)
async def test_release_once_too_many(c, s, a, b):
    """Check that an extra release raises and the semaphore still works."""
    sem = await Semaphore(name="x")
    got_lease = await sem.acquire()
    assert got_lease
    await sem.release()

    # No lease is held any more, so this release is one too many.
    with pytest.raises(RuntimeError, match="Released too often"):
        await sem.release()

    # The failed release must not have broken the semaphore.
    got_lease = await sem.acquire()
    assert got_lease
    await sem.release()
@gen_cluster(client=True)
async def test_release_once_too_many_resilience(c, s, a, b):
    """Check that over-releasing inside tasks leaves the lease state consistent."""

    def release_twice(x, sem):
        sem.acquire()
        sem.release()
        # The second release has nothing left to give back.
        with pytest.raises(RuntimeError, match="Released too often"):
            sem.release()
        return x

    def active_leases():
        return s.extensions["semaphores"].leases["x"]

    limiter = await Semaphore(max_leases=3, name="x")

    values = list(range(20))
    pending = c.map(release_twice, values, sem=limiter)
    assert sorted(await c.gather(pending)) == values

    assert not active_leases()
    await limiter.acquire()
    assert len(active_leases()) == 1
class BrokenComm(Comm):
    """A `Comm` that reports itself closed and raises on every read or write."""

    peer_address = None
    local_address = None

    def close(self):
        """Do nothing."""

    def closed(self):
        """Always report the comm as closed."""
        return True

    def abort(self):
        """Do nothing."""

    def read(self, deserializers=None):
        """Always fail; `OSError` is what the name `EnvironmentError` aliases."""
        raise OSError

    def write(self, msg, serializers=None, on_error=None):
        """Always fail; `OSError` is what the name `EnvironmentError` aliases."""
        raise OSError
class FlakyConnectionPool(ConnectionPool):
    """Connection pool that returns a limited number of broken comms once activated.

    After `activate()` has been called, the first `failing_connections`
    calls to `connect` return a `BrokenComm`. All other calls are passed
    to the parent class.
    """

    def __init__(self, *args, failing_connections=0, **kwargs):
        # Set the counters before the parent initialiser runs.
        self.cnn_count = 0  # broken comms handed out so far
        self.failing_connections = failing_connections
        self._flaky_active = False
        super().__init__(*args, **kwargs)

    def activate(self):
        """Switch on the failure injection."""
        self._flaky_active = True

    async def connect(self, *args, **kwargs):
        """Return a `BrokenComm` while failures remain, else a real connection."""
        failures_left = self.cnn_count < self.failing_connections
        if self._flaky_active and failures_left:
            self.cnn_count += 1
            return BrokenComm()
        return await super().connect(*args, **kwargs)
@gen_cluster(client=True)
async def test_retry_acquire(c, s, a, b):
    """Check that acquire works when the pool's first connection is broken."""
    with dask.config.set({"distributed.comm.retry.count": 1}):
        flaky_pool = await FlakyConnectionPool(failing_connections=1)

        # Route the client's scheduler traffic through the flaky pool.
        c.scheduler = flaky_pool(s.address)

        sem = await Semaphore(
            max_leases=2, name="resource_we_want_to_limit", client=c
        )
        # Failures are injected only from this point on.
        flaky_pool.activate()

        first = await sem.acquire()
        assert first is True
        second = await sem.acquire()
        assert second is True

        # Both leases are held, so a third attempt times out quickly.
        began = time()
        third = await sem.acquire(timeout=0.025)
        elapsed = time() - began
        assert elapsed < 0.2
        assert third is False
@gen_cluster(
    client=True,
    config={
        "distributed.scheduler.locks.lease-timeout": "100ms",
        "distributed.scheduler.locks.lease-validation-interval": "100ms",
    },
)
async def test_oversubscribing_leases(c, s, a, b):
    """
    This test ensures that we detect oversubscription scenarios and will not
    accept new leases as long as the semaphore is oversubscribed.

    Oversubscription may occur if tasks hold the GIL for a longer time than the
    lease-timeout is configured causing the lease refresh to go stale and timeout.

    We cannot protect ourselves entirely from this but we can ensure that while
    a task with a timed out lease is still running, we block further
    acquisitions until we return to normal.

    An example would be a task which continuously locks the GIL for a longer
    time than the lease timeout but this continuous lock only makes up a
    fraction of the tasks runtime.
    """
    # GH3705

    from distributed.worker import Worker, get_client

    # Using the metadata as a crude "asyncio.Event" since the proper event
    # implementation cannot be serialized. For the purpose of this test a
    # metadata check with a sleep loop is not elegant but practical.
    await c.set_metadata("release", False)
    sem = await Semaphore()
    sem.refresh_callback.stop()

    def guaranteed_lease_timeout(x, sem):
        """
        This function simulates a payload computation with some GIL
        locking in the beginning.

        To simulate this we will manually disable the refresh callback, i.e.
        all leases will eventually timeout. The function will only
        release/return once the "Event" is set, i.e. our observer is done.
        """
        sem.refresh_leases = False
        client = get_client()

        with sem:
            # This simulates a task which holds the GIL for longer than the
            # lease-timeout. This is twice the lease timeout to ensure that the
            # leases are actually timed out
            slowidentity(delay=0.2)

            assert sem._leases
            # Now the GIL is free again, i.e. we enable the callback again
            sem.refresh_leases = True
            sleep(0.1)

            # This is the poor man's Event.wait()
            while client.get_metadata("release") is not True:
                sleep(0.05)

            assert sem.get_value() >= 1
            return x

    def observe_state(sem):
        """
        This function is 100% artificial and acts as an observer to verify
        our assumptions. The function will wait until both payload tasks are
        executing, i.e. we're in an oversubscription scenario. It will then
        try to acquire and hopefully fail showing that the semaphore is
        protected if the oversubscription is recognized.
        """
        sem.refresh_callback.stop()
        # We wait until we're in an oversubscribed state, i.e. both tasks
        # are executed although there should only be one allowed
        while not sem.get_value() > 1:
            sleep(0.2)

        # Once we're in an oversubscribed state, we must not be able to
        # acquire a lease.
        assert not sem.acquire(timeout=0)

        client = get_client()
        client.set_metadata("release", True)

    # The observer runs on its own worker. Managing that worker with
    # `async with` closes it when the block exits, including on failure,
    # so the test does not leave a running worker behind.
    async with Worker(s.address) as observer:
        futures = c.map(
            guaranteed_lease_timeout,
            range(2),
            sem=sem,
            workers=[a.address, b.address],
        )
        fut_observe = c.submit(observe_state, sem=sem, workers=[observer.address])

        with captured_logger("distributed.semaphore") as caplog:
            # The observer task's result is not used; unpacking it into `_`
            # keeps `observer` bound to the worker.
            payload, _ = await c.gather([futures, fut_observe])

    logs = caplog.getvalue().split("\n")
    timeouts = [log for log in logs if "timed out" in log]
    refresh_unknown = [log for log in logs if "Refreshing an unknown lease ID" in log]
    assert len(timeouts) == 2
    assert len(refresh_unknown) == 2

    assert sorted(payload) == [0, 1]
    # Back to normal
    assert await sem.get_value() == 0
@gen_cluster(client=True)
async def test_timeout_zero(c, s, a, b):
    """Check that `timeout=0` works both when a lease is free and when not."""
    # Depending on the internals a timeout zero cannot work, e.g. when the
    # initial try already includes a wait. Since some test cases use this, it
    # is worth testing against.
    sem = await Semaphore()

    granted = await sem.acquire(timeout=0)
    assert granted
    refused = await sem.acquire(timeout=0)
    assert not refused

    await sem.release()
@gen_cluster(client=True)
async def test_getvalue(c, s, a, b):
    """Check that `get_value` reads 0, then 1 after acquire, then 0 after release."""
    sem = await Semaphore()

    before = await sem.get_value()
    assert before == 0

    await sem.acquire()
    while_held = await sem.get_value()
    assert while_held == 1

    await sem.release()
    after = await sem.get_value()
    assert after == 0