diff --git a/tests/abc/test_broker.py b/tests/abc/test_broker.py index 8d39ea50..636f9576 100644 --- a/tests/abc/test_broker.py +++ b/tests/abc/test_broker.py @@ -83,103 +83,77 @@ async def test_task() -> None: ... @pytest.mark.anyio -async def test_async_context_manager_enter() -> None: - """Test that __aenter__ calls startup.""" +@pytest.mark.parametrize( + ("is_worker_process", "startup", "shutdown"), + [ + (True, TaskiqEvents.WORKER_STARTUP, TaskiqEvents.WORKER_SHUTDOWN), + (False, TaskiqEvents.CLIENT_STARTUP, TaskiqEvents.CLIENT_SHUTDOWN), + ], +) +async def test_async_context_manager_enter( + *, + is_worker_process: bool, + startup: TaskiqEvents, + shutdown: TaskiqEvents, +) -> None: + """Test that `__aenter__` and `__aexit__` calls work.""" broker = _TestBroker() + broker.is_worker_process = is_worker_process startup_called = False + shutdown_called = False - @broker.on_event(TaskiqEvents.CLIENT_STARTUP) + @broker.on_event(startup) async def track_startup(state: TaskiqState) -> None: nonlocal startup_called startup_called = True - async with broker: - assert startup_called is True - - -@pytest.mark.anyio -async def test_async_context_manager_exit() -> None: - """Test that __aexit__ calls shutdown.""" - broker = _TestBroker() - shutdown_called = False - - @broker.on_event(TaskiqEvents.CLIENT_SHUTDOWN) + @broker.on_event(shutdown) async def track_shutdown(state: TaskiqState) -> None: nonlocal shutdown_called shutdown_called = True - async with broker: - pass + async with broker as ctx: + assert ctx is None + assert startup_called is True + assert shutdown_called is False assert shutdown_called is True @pytest.mark.anyio -async def test_async_context_manager_enter_worker() -> None: - """Test that __aenter__ calls worker startup when is_worker_process is True.""" +@pytest.mark.parametrize( + ("is_worker_process", "startup", "shutdown"), + [ + (True, TaskiqEvents.WORKER_STARTUP, TaskiqEvents.WORKER_SHUTDOWN), + (False, TaskiqEvents.CLIENT_STARTUP, TaskiqEvents.CLIENT_SHUTDOWN), + ], +) +async def test_async_context_manager_exit_on_exception( + *, + is_worker_process: bool, + startup: TaskiqEvents, + shutdown: TaskiqEvents, +) -> None: + """Test that __aexit__ calls shutdown even if exception is raised.""" broker = _TestBroker() - broker.is_worker_process = True + broker.is_worker_process = is_worker_process startup_called = False + shutdown_called = False - @broker.on_event(TaskiqEvents.WORKER_STARTUP) + @broker.on_event(startup) async def track_startup(state: TaskiqState) -> None: nonlocal startup_called startup_called = True - async with broker: - assert startup_called is True - - -@pytest.mark.anyio -async def test_async_context_manager_exit_worker() -> None: - """Test that __aexit__ calls worker shutdown when is_worker_process is True.""" - broker = _TestBroker() - broker.is_worker_process = True - shutdown_called = False - - @broker.on_event(TaskiqEvents.WORKER_SHUTDOWN) - async def track_shutdown(state: TaskiqState) -> None: - nonlocal shutdown_called - shutdown_called = True - - async with broker: - pass - - assert shutdown_called is True - - -@pytest.mark.anyio -async def test_async_context_manager_exit_on_exception() -> None: - """Test that __aexit__ calls shutdown even if exception is raised.""" - broker = _TestBroker() - shutdown_called = False - - @broker.on_event(TaskiqEvents.CLIENT_SHUTDOWN) - async def track_shutdown(state: TaskiqState) -> None: - nonlocal shutdown_called - shutdown_called = True - - with pytest.raises(ValueError, match="Test exception"): - async with broker: - raise ValueError("Test exception") - - assert shutdown_called is True - - -@pytest.mark.anyio -async def test_async_context_manager_exit_worker_on_exception() -> None: - """Test that __aexit__ calls worker shutdown even if exception is raised.""" - broker = _TestBroker() - broker.is_worker_process = True - shutdown_called = False - - @broker.on_event(TaskiqEvents.WORKER_SHUTDOWN) + @broker.on_event(shutdown) async def track_shutdown(state: TaskiqState) -> None: nonlocal shutdown_called shutdown_called = True with pytest.raises(ValueError, match="Test exception"): async with broker: + assert startup_called is True + assert shutdown_called is False raise ValueError("Test exception") assert shutdown_called is True