Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 45 additions & 70 deletions tests/test_async.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import time
import threading

import anyio

from pytest import mark, raises
Expand All @@ -12,13 +15,19 @@ class GPUPromise(BaseGPUPromise):
# Subclass with each own set of unresolved promise instances
_UNRESOLVED = set()

def _sync_wait(self):
# Same implementation as the wgpu_native backend.
# If we have a test that has not polling thread, and sync_wait() is called
# when the promise is still pending, this will hang.
self._thread_event.wait()


class SillyLoop:
def __init__(self):
self._pending_calls = []
self.errors = []

def call_soon(self, f, *args):
def call_soon_threadsafe(self, f, *args):
self._pending_calls.append((f, args))

def process_events(self):
Expand Down Expand Up @@ -73,23 +82,18 @@ def test_promise_basics():
# %%%%% Promise using sync_wait


def test_promise_sync_need_poll():
promise = GPUPromise("test", None)

with raises(RuntimeError): # cannot poll without poll function
promise.sync_wait()
def run_in_thread(callable):
t = threading.Thread(target=callable)
t.start()


def test_promise_sync_simple():
count = 0

@run_in_thread
def poller():
nonlocal count
count += 1
if count > 5:
promise._wgpu_set_input(42)
time.sleep(0.1)
promise._wgpu_set_input(42)

promise = GPUPromise("test", None, poller=poller)
promise = GPUPromise("test", None)

result = promise.sync_wait()
assert result == 42
Expand All @@ -99,15 +103,12 @@ def test_promise_sync_normal():
def handler(input):
return input * 2

count = 0

@run_in_thread
def poller():
nonlocal count
count += 1
if count > 5:
promise._wgpu_set_input(42)
time.sleep(0.1)
promise._wgpu_set_input(42)

promise = GPUPromise("test", handler, poller=poller)
promise = GPUPromise("test", handler)

result = promise.sync_wait()
assert result == 84
Expand All @@ -117,15 +118,12 @@ def test_promise_sync_fail1():
def handler(input):
return input * 2

count = 0

@run_in_thread
def poller():
nonlocal count
count += 1
if count > 5:
promise._wgpu_set_error(ZeroDivisionError())
time.sleep(0.1)
promise._wgpu_set_error(ZeroDivisionError())

promise = GPUPromise("test", handler, poller=poller)
promise = GPUPromise("test", handler)

with raises(ZeroDivisionError):
promise.sync_wait()
Expand All @@ -135,15 +133,12 @@ def test_promise_sync_fail2():
def handler(input):
return input / 0

count = 0

@run_in_thread
def poller():
nonlocal count
count += 1
if count > 5:
promise._wgpu_set_input(42)
time.sleep(0.1)
promise._wgpu_set_input(42)

promise = GPUPromise("test", handler, poller=poller)
promise = GPUPromise("test", handler)

with raises(ZeroDivisionError):
promise.sync_wait()
Expand All @@ -152,25 +147,14 @@ def poller():
# %% Promise using await with poll and loop


@mark.anyio
async def test_promise_async_need_poll_or_loop():
promise = GPUPromise("test", None)

with raises(RuntimeError): # cannot poll without poll function
await promise


@mark.anyio
async def test_promise_async_poll_simple():
count = 0

@run_in_thread
def poller():
nonlocal count
count += 1
if count > 5:
promise._wgpu_set_input(42)
time.sleep(0.1)
promise._wgpu_set_input(42)

promise = GPUPromise("test", None, poller=poller)
promise = GPUPromise("test", None)

result = await promise
assert result == 42
Expand All @@ -181,15 +165,12 @@ async def test_promise_async_poll_normal():
def handler(input):
return input * 2

count = 0

@run_in_thread
def poller():
nonlocal count
count += 1
if count > 5:
promise._wgpu_set_input(42)
time.sleep(0.1)
promise._wgpu_set_input(42)

promise = GPUPromise("test", handler, poller=poller)
promise = GPUPromise("test", handler)

result = await promise
assert result == 84
Expand All @@ -200,15 +181,12 @@ async def test_promise_async_poll_fail1():
def handler(input):
return input * 2

count = 0

@run_in_thread
def poller():
nonlocal count
count += 1
if count > 5:
promise._wgpu_set_error(ZeroDivisionError())
time.sleep(0.1)
promise._wgpu_set_error(ZeroDivisionError())

promise = GPUPromise("test", handler, poller=poller)
promise = GPUPromise("test", handler)

with raises(ZeroDivisionError):
await promise
Expand All @@ -219,15 +197,12 @@ async def test_promise_async_poll_fail2():
def handler(input):
return input / 0

count = 0

@run_in_thread
def poller():
nonlocal count
count += 1
if count > 5:
promise._wgpu_set_input(42)
time.sleep(0.1)
promise._wgpu_set_input(42)

promise = GPUPromise("test", handler, poller=poller)
promise = GPUPromise("test", handler)

with raises(ZeroDivisionError):
await promise
Expand Down
188 changes: 188 additions & 0 deletions tests/test_wgpu_native_poller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import gc
import time
import queue

import wgpu
from wgpu.backends.wgpu_native._poller import PollThread, PollToken

from testutils import can_use_wgpu_lib, run_tests, is_pypy
from pytest import mark


def test_poll_thread():
# A timeout to give polling thread time to progress. The GIL switches
# threads about every 5ms, but in this cases likely faster, because it also switches
# when it goes to sleep on a blocking call. So 50ms seems plenty.
timeout = 0.05

count = 0
gpu_work_done_queue = queue.SimpleQueue()

def reset():
nonlocal count
ref_count = count
# Make sure the poller is not waiting in poll_func
gpu_work_done_queue.put(None)
gpu_work_done_queue.put(None)
# Give it time
time.sleep(timeout)
# Check that it did not enter again, i.e. is waiting for tokens
assert count == ref_count, "Looks like a token is still active"
# Reset
count = 0
while True:
try:
gpu_work_done_queue.get(False)
except queue.Empty:
break

def finish_tokens(*tokens):
# This mimics the GPU finishing an async task, and invoking its
# callback that sets the token to done.
gpu_work_done_queue.put(None)
for token in tokens:
assert not token.is_done()
token.set_done()

def poll_func(block):
# This mimics the wgpuDevicePoll.
nonlocal count
count += 1
if block:
gpu_work_done_queue.get() # blocking
else:
try:
gpu_work_done_queue.get(False)
except queue.Empty:
pass

# Start the poller
t = PollThread(poll_func)
t.start()

reset()

# == Normal behavior

token = t.get_token()
assert isinstance(token, PollToken)
time.sleep(timeout)
assert count == 2

finish_tokens(token)

time.sleep(timeout)
assert count == 2

reset()

# == Always at least one poll

token = t.get_token()
token.set_done()
time.sleep(timeout)
assert count in (1, 2) # typically 1, but can sometimes be 2

reset()

# == Mark done through deletion

token = t.get_token()
time.sleep(timeout)
assert count == 2

finish_tokens()

time.sleep(timeout)
assert count == 3

finish_tokens()

time.sleep(timeout)
assert count == 4

del token
gc.collect()
gc.collect()

finish_tokens()

time.sleep(timeout)
assert count == 4

reset()

# More tasks

token1 = t.get_token()
time.sleep(timeout)
assert count == 2

token2 = t.get_token()
time.sleep(timeout)
assert count == 2

token3 = t.get_token()
token4 = t.get_token()
time.sleep(timeout)
assert count == 2

finish_tokens(token1)
time.sleep(timeout)
assert count == 3

finish_tokens(token2, token3)
time.sleep(timeout)
assert count == 4

finish_tokens() # can actually bump more unrelated works
finish_tokens()
time.sleep(timeout)
assert count == 6

token5 = t.get_token()
finish_tokens(token4)
time.sleep(timeout)
assert count == 7

finish_tokens(token5)
time.sleep(timeout)
assert count == 8

reset()

# Shut it down

t.stop()
time.sleep(0.1)
assert not t.is_alive()


@mark.skipif(not can_use_wgpu_lib, reason="Needs wgpu lib")
def test_poller_stops_when_device_gone():
device = wgpu.gpu.request_adapter_sync().request_device_sync()

t = device._poller
assert t.is_alive()
device.__del__()
time.sleep(0.1)

assert not t.is_alive()

device = wgpu.gpu.request_adapter_sync().request_device_sync()

t = device._poller
assert t.is_alive()
del device
gc.collect()
gc.collect()
if is_pypy:
gc.collect()
gc.collect()
time.sleep(0.1)

assert not t.is_alive()


if __name__ == "__main__":
run_tests(globals())
Loading
Loading