|
9 | 9 |
|
10 | 10 | from taskiq.abc.broker import AckableMessage, AsyncBroker
|
11 | 11 | from taskiq.abc.middleware import TaskiqMiddleware
|
| 12 | +from taskiq.acks import AckableMessageWithDeliveryCount |
12 | 13 | from taskiq.brokers.inmemory_broker import InMemoryBroker
|
13 | 14 | from taskiq.exceptions import NoResultError, TaskiqResultTimeoutError
|
14 |
| -from taskiq.message import TaskiqMessage |
| 15 | +from taskiq.message import DeliveryCountMessage, TaskiqMessage |
15 | 16 | from taskiq.receiver import Receiver
|
16 | 17 | from taskiq.result import TaskiqResult
|
17 | 18 | from tests.utils import AsyncQueueBroker
|
@@ -359,6 +360,127 @@ async def test_callback_unknown_task() -> None:
|
359 | 360 | await receiver.callback(broker_message.message)
|
360 | 361 |
|
361 | 362 |
|
| 363 | +@pytest.mark.anyio |
| 364 | +@pytest.mark.parametrize("delivery_count", [2, None]) |
| 365 | +async def test_callback_max_attempts_at_message_not_exceeded( |
| 366 | + delivery_count: Optional[int], |
| 367 | +) -> None: |
| 368 | + """ |
| 369 | + Test that callback function calls the task if `max_attempts_at_message` |
| 370 | + is not exceeded. |
| 371 | + """ |
| 372 | + broker = InMemoryBroker() |
| 373 | + called_times = 0 |
| 374 | + |
| 375 | + @broker.task |
| 376 | + async def my_task() -> int: |
| 377 | + nonlocal called_times |
| 378 | + called_times += 1 |
| 379 | + return 1 |
| 380 | + |
| 381 | + receiver = get_receiver(broker) |
| 382 | + receiver.max_attempts_at_message = 3 |
| 383 | + |
| 384 | + broker_message = broker.formatter.dumps( |
| 385 | + TaskiqMessage( |
| 386 | + task_id="task_id", |
| 387 | + task_name=my_task.task_name, |
| 388 | + labels={}, |
| 389 | + args=[], |
| 390 | + kwargs={}, |
| 391 | + ), |
| 392 | + ) |
| 393 | + |
| 394 | + await receiver.callback( |
| 395 | + DeliveryCountMessage( |
| 396 | + data=broker_message.message, |
| 397 | + delivery_count=delivery_count, |
| 398 | + ), |
| 399 | + ) |
| 400 | + assert called_times == 1 |
| 401 | + |
| 402 | + |
| 403 | +@pytest.mark.anyio |
| 404 | +async def test_callback_max_attempts_at_message_exceeded() -> None: |
| 405 | + """ |
| 406 | + Test that callback function does not call the task if `max_attempts_at_message` |
| 407 | + is exceeded. |
| 408 | + """ |
| 409 | + broker = InMemoryBroker() |
| 410 | + called_times = 0 |
| 411 | + |
| 412 | + @broker.task |
| 413 | + async def my_task() -> int: |
| 414 | + nonlocal called_times |
| 415 | + called_times += 1 |
| 416 | + return 1 |
| 417 | + |
| 418 | + receiver = get_receiver(broker) |
| 419 | + receiver.max_attempts_at_message = 3 |
| 420 | + |
| 421 | + broker_message = broker.formatter.dumps( |
| 422 | + TaskiqMessage( |
| 423 | + task_id="task_id", |
| 424 | + task_name=my_task.task_name, |
| 425 | + labels={}, |
| 426 | + args=[], |
| 427 | + kwargs={}, |
| 428 | + ), |
| 429 | + ) |
| 430 | + |
| 431 | + await receiver.callback( |
| 432 | + DeliveryCountMessage( |
| 433 | + data=broker_message.message, |
| 434 | + delivery_count=3, |
| 435 | + ), |
| 436 | + ) |
| 437 | + assert called_times == 0 |
| 438 | + |
| 439 | + |
| 440 | +@pytest.mark.anyio |
| 441 | +async def test_callback_max_attempts_at_message_exceeded_ackable() -> None: |
| 442 | + """ |
| 443 | + Test that callback function does not call the task if `max_attempts_at_message` |
| 444 | + is exceeded and acks the message. |
| 445 | + """ |
| 446 | + broker = InMemoryBroker() |
| 447 | + called_times = 0 |
| 448 | + acked = False |
| 449 | + |
| 450 | + @broker.task |
| 451 | + async def my_task() -> int: |
| 452 | + nonlocal called_times |
| 453 | + called_times += 1 |
| 454 | + return 1 |
| 455 | + |
| 456 | + async def ack_callback() -> None: |
| 457 | + nonlocal acked |
| 458 | + acked = True |
| 459 | + |
| 460 | + receiver = get_receiver(broker) |
| 461 | + receiver.max_attempts_at_message = 3 |
| 462 | + |
| 463 | + broker_message = broker.formatter.dumps( |
| 464 | + TaskiqMessage( |
| 465 | + task_id="task_id", |
| 466 | + task_name=my_task.task_name, |
| 467 | + labels={}, |
| 468 | + args=[], |
| 469 | + kwargs={}, |
| 470 | + ), |
| 471 | + ) |
| 472 | + |
| 473 | + await receiver.callback( |
| 474 | + AckableMessageWithDeliveryCount( |
| 475 | + data=broker_message.message, |
| 476 | + delivery_count=3, |
| 477 | + ack=ack_callback, |
| 478 | + ), |
| 479 | + ) |
| 480 | + assert called_times == 0 |
| 481 | + assert acked |
| 482 | + |
| 483 | + |
362 | 484 | @pytest.mark.anyio
|
363 | 485 | async def test_custom_ctx() -> None:
|
364 | 486 | """Tests that run_task can run sync tasks."""
|
|
0 commit comments