Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 42 additions & 68 deletions tests/abc/test_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,103 +83,77 @@ async def test_task() -> None: ...


@pytest.mark.anyio
async def test_async_context_manager_enter() -> None:
"""Test that __aenter__ calls startup."""
@pytest.mark.parametrize(
("is_worker_process", "startup", "shutdown"),
[
(True, TaskiqEvents.WORKER_STARTUP, TaskiqEvents.WORKER_SHUTDOWN),
(False, TaskiqEvents.CLIENT_STARTUP, TaskiqEvents.CLIENT_SHUTDOWN),
],
)
async def test_async_context_manager_enter(
*,
is_worker_process: bool,
startup: TaskiqEvents,
shutdown: TaskiqEvents,
) -> None:
"""Test that `__aenter__` and `__aexit__` calls work."""
broker = _TestBroker()
broker.is_worker_process = is_worker_process
startup_called = False
shutdown_called = False

@broker.on_event(TaskiqEvents.CLIENT_STARTUP)
@broker.on_event(startup)
async def track_startup(state: TaskiqState) -> None:
nonlocal startup_called
startup_called = True

async with broker:
assert startup_called is True


@pytest.mark.anyio
async def test_async_context_manager_exit() -> None:
"""Test that __aexit__ calls shutdown."""
broker = _TestBroker()
shutdown_called = False

@broker.on_event(TaskiqEvents.CLIENT_SHUTDOWN)
@broker.on_event(shutdown)
async def track_shutdown(state: TaskiqState) -> None:
nonlocal shutdown_called
shutdown_called = True

async with broker:
pass
async with broker as ctx:
assert ctx is None
assert startup_called is True
assert shutdown_called is False

assert shutdown_called is True


@pytest.mark.anyio
async def test_async_context_manager_enter_worker() -> None:
"""Test that __aenter__ calls worker startup when is_worker_process is True."""
@pytest.mark.parametrize(
("is_worker_process", "startup", "shutdown"),
[
(True, TaskiqEvents.WORKER_STARTUP, TaskiqEvents.WORKER_SHUTDOWN),
(False, TaskiqEvents.CLIENT_STARTUP, TaskiqEvents.CLIENT_SHUTDOWN),
],
)
async def test_async_context_manager_exit_on_exception(
*,
is_worker_process: bool,
startup: TaskiqEvents,
shutdown: TaskiqEvents,
) -> None:
"""Test that __aexit__ calls shutdown even if exception is raised."""
broker = _TestBroker()
broker.is_worker_process = True
broker.is_worker_process = is_worker_process
startup_called = False
shutdown_called = False

@broker.on_event(TaskiqEvents.WORKER_STARTUP)
@broker.on_event(startup)
async def track_startup(state: TaskiqState) -> None:
nonlocal startup_called
startup_called = True

async with broker:
assert startup_called is True


@pytest.mark.anyio
async def test_async_context_manager_exit_worker() -> None:
"""Test that __aexit__ calls worker shutdown when is_worker_process is True."""
broker = _TestBroker()
broker.is_worker_process = True
shutdown_called = False

@broker.on_event(TaskiqEvents.WORKER_SHUTDOWN)
async def track_shutdown(state: TaskiqState) -> None:
nonlocal shutdown_called
shutdown_called = True

async with broker:
pass

assert shutdown_called is True


@pytest.mark.anyio
async def test_async_context_manager_exit_on_exception() -> None:
"""Test that __aexit__ calls shutdown even if exception is raised."""
broker = _TestBroker()
shutdown_called = False

@broker.on_event(TaskiqEvents.CLIENT_SHUTDOWN)
async def track_shutdown(state: TaskiqState) -> None:
nonlocal shutdown_called
shutdown_called = True

with pytest.raises(ValueError, match="Test exception"):
async with broker:
raise ValueError("Test exception")

assert shutdown_called is True


@pytest.mark.anyio
async def test_async_context_manager_exit_worker_on_exception() -> None:
"""Test that __aexit__ calls worker shutdown even if exception is raised."""
broker = _TestBroker()
broker.is_worker_process = True
shutdown_called = False

@broker.on_event(TaskiqEvents.WORKER_SHUTDOWN)
@broker.on_event(shutdown)
async def track_shutdown(state: TaskiqState) -> None:
nonlocal shutdown_called
shutdown_called = True

with pytest.raises(ValueError, match="Test exception"):
async with broker:
assert startup_called is True
assert shutdown_called is False
raise ValueError("Test exception")

assert shutdown_called is True
Loading