Skip to content

Commit 1adb5a5

Browse files
committedDec 29, 2024
Add max_attempts_at_message
1 parent 49c0408 commit 1adb5a5

File tree

12 files changed

+185
-10
lines changed

12 files changed

+185
-10
lines changed
 

‎pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ ignore = [
159159
"ANN401", # typing.Any are disallowed in `**kwargs
160160
"PLR0913", # Too many arguments for function call
161161
"D106", # Missing docstring in public nested class
162+
"D205", # 1 blank line required between summary line and description
162163
]
163164
exclude = [".venv/"]
164165
mccabe = { max-complexity = 10 }

‎taskiq/__init__.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Distributed task manager."""
2+
23
from importlib.metadata import version
34

45
from taskiq_dependencies import Depends as TaskiqDepends
@@ -8,7 +9,7 @@
89
from taskiq.abc.middleware import TaskiqMiddleware
910
from taskiq.abc.result_backend import AsyncResultBackend
1011
from taskiq.abc.schedule_source import ScheduleSource
11-
from taskiq.acks import AckableMessage
12+
from taskiq.acks import AckableMessage, AckableMessageWithDeliveryCount
1213
from taskiq.brokers.inmemory_broker import InMemoryBroker
1314
from taskiq.brokers.shared_broker import async_shared_broker
1415
from taskiq.brokers.zmq_broker import ZeroMQBroker
@@ -24,7 +25,7 @@
2425
TaskiqResultTimeoutError,
2526
)
2627
from taskiq.funcs import gather
27-
from taskiq.message import BrokerMessage, TaskiqMessage
28+
from taskiq.message import BrokerMessage, DeliveryCountMessage, TaskiqMessage
2829
from taskiq.middlewares.prometheus_middleware import PrometheusMiddleware
2930
from taskiq.middlewares.retry_middleware import SimpleRetryMiddleware
3031
from taskiq.result import TaskiqResult
@@ -53,6 +54,8 @@
5354
"NoResultError",
5455
"SendTaskError",
5556
"AckableMessage",
57+
"DeliveryCountMessage",
58+
"AckableMessageWithDeliveryCount",
5659
"InMemoryBroker",
5760
"ScheduleSource",
5861
"TaskiqScheduler",

‎taskiq/abc/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Abstract classes for taskiq."""
2+
23
from taskiq.abc.broker import AsyncBroker
34
from taskiq.abc.result_backend import AsyncResultBackend
45

‎taskiq/abc/broker.py

+2
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __init__(
7777
self,
7878
result_backend: "Optional[AsyncResultBackend[_T]]" = None,
7979
task_id_generator: Optional[Callable[[], str]] = None,
80+
max_attempts_at_message: Optional[int] = None,
8081
) -> None:
8182
if result_backend is None:
8283
result_backend = DummyResultBackend()
@@ -113,6 +114,7 @@ def __init__(
113114
self.state = TaskiqState()
114115
self.custom_dependency_context: Dict[Any, Any] = {}
115116
self.dependency_overrides: Dict[Any, Any] = {}
117+
self.max_attempts_at_message = max_attempts_at_message
116118
# True only if broker runs in worker process.
117119
self.is_worker_process: bool = False
118120
# True only if broker runs in scheduler process.

‎taskiq/acks.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import enum
22
from typing import Awaitable, Callable, Union
33

4-
from pydantic import BaseModel
4+
from taskiq.message import DeliveryCountMessage, WrappedMessage
55

66

77
@enum.unique
@@ -20,7 +20,7 @@ class AcknowledgeType(str, enum.Enum):
2020
WHEN_SAVED = "when_saved"
2121

2222

23-
class AckableMessage(BaseModel):
23+
class AckableMessage(WrappedMessage):
2424
"""
2525
Message that can be acknowledged.
2626
@@ -33,5 +33,8 @@ class AckableMessage(BaseModel):
3333
as a whole.
3434
"""
3535

36-
data: bytes
3736
ack: Callable[[], Union[None, Awaitable[None]]]
37+
38+
39+
class AckableMessageWithDeliveryCount(AckableMessage, DeliveryCountMessage):
40+
"""Message that can be acknowledged and has a delivery count."""

‎taskiq/cli/worker/run.py

+1
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
143143
ack_type=args.ack_type,
144144
max_tasks_to_execute=args.max_tasks_per_child,
145145
wait_tasks_timeout=args.wait_tasks_timeout,
146+
max_attempts_at_message=broker.max_attempts_at_message,
146147
**receiver_kwargs, # type: ignore
147148
)
148149
loop.run_until_complete(receiver.listen())

‎taskiq/message.py

+12
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,15 @@ class BrokerMessage(BaseModel):
4242
task_name: str
4343
message: bytes
4444
labels: Dict[str, Any]
45+
46+
47+
class WrappedMessage(BaseModel):
48+
"""Abstraction for an incoming message in a wrapper."""
49+
50+
data: bytes
51+
52+
53+
class DeliveryCountMessage(WrappedMessage):
54+
"""Message with a present delivery count."""
55+
56+
delivery_count: Optional[int] = None

‎taskiq/receiver/receiver.py

+29-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from taskiq.acks import AcknowledgeType
1414
from taskiq.context import Context
1515
from taskiq.exceptions import NoResultError
16-
from taskiq.message import TaskiqMessage
16+
from taskiq.message import DeliveryCountMessage, TaskiqMessage, WrappedMessage
1717
from taskiq.receiver.params_parser import parse_params
1818
from taskiq.result import TaskiqResult
1919
from taskiq.state import TaskiqState
@@ -58,6 +58,7 @@ def __init__(
5858
on_exit: Optional[Callable[["Receiver"], None]] = None,
5959
max_tasks_to_execute: Optional[int] = None,
6060
wait_tasks_timeout: Optional[float] = None,
61+
max_attempts_at_message: Optional[int] = None,
6162
) -> None:
6263
self.broker = broker
6364
self.executor = executor
@@ -72,6 +73,7 @@ def __init__(
7273
self.known_tasks: Set[str] = set()
7374
self.max_tasks_to_execute = max_tasks_to_execute
7475
self.wait_tasks_timeout = wait_tasks_timeout
76+
self.max_attempts_at_message = max_attempts_at_message
7577
for task in self.broker.get_all_tasks().values():
7678
self._prepare_task(task.task_name, task.original_func)
7779
self.sem: "Optional[asyncio.Semaphore]" = None
@@ -86,7 +88,7 @@ def __init__(
8688

8789
async def callback( # noqa: C901, PLR0912
8890
self,
89-
message: Union[bytes, AckableMessage],
91+
message: Union[bytes, WrappedMessage],
9092
raise_err: bool = False,
9193
) -> None:
9294
"""
@@ -101,7 +103,31 @@ async def callback( # noqa: C901, PLR0912
101103
:param raise_err: raise an error if cannot save result in
102104
result_backend.
103105
"""
104-
message_data = message.data if isinstance(message, AckableMessage) else message
106+
message_data = message.data if isinstance(message, WrappedMessage) else message
107+
108+
delivery_count = (
109+
message.delivery_count
110+
if isinstance(message, DeliveryCountMessage)
111+
else None
112+
)
113+
if (
114+
delivery_count
115+
and self.max_attempts_at_message
116+
and delivery_count >= self.max_attempts_at_message
117+
):
118+
logger.error(
119+
"Permitted number of attempts at processing message %s "
120+
"has been exhausted after %s attempts.",
121+
message_data,
122+
self.max_attempts_at_message,
123+
)
124+
if isinstance(
125+
message,
126+
AckableMessage,
127+
):
128+
await maybe_awaitable(message.ack())
129+
return
130+
105131
try:
106132
taskiq_msg = self.broker.formatter.loads(message=message_data)
107133
taskiq_msg.parse_labels()

‎taskiq/schedule_sources/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Package for schedule sources."""
2+
23
from taskiq.schedule_sources.label_based import LabelScheduleSource
34

45
__all__ = [

‎taskiq/scheduler/created_schedule.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ async def kiq(
3232
...
3333

3434
@overload
35-
async def kiq(self: "CreatedSchedule[_ReturnType]") -> AsyncTaskiqTask[_ReturnType]:
35+
async def kiq(
36+
self: "CreatedSchedule[_ReturnType]",
37+
) -> AsyncTaskiqTask[_ReturnType]:
3638
...
3739

3840
async def kiq(self) -> Any:

‎taskiq/serializers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Taskiq serializers."""
2+
23
from .cbor_serializer import CBORSerializer
34
from .json_serializer import JSONSerializer
45
from .msgpack_serializer import MSGPackSerializer

‎tests/receiver/test_receiver.py

+123-1
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99

1010
from taskiq.abc.broker import AckableMessage, AsyncBroker
1111
from taskiq.abc.middleware import TaskiqMiddleware
12+
from taskiq.acks import AckableMessageWithDeliveryCount
1213
from taskiq.brokers.inmemory_broker import InMemoryBroker
1314
from taskiq.exceptions import NoResultError, TaskiqResultTimeoutError
14-
from taskiq.message import TaskiqMessage
15+
from taskiq.message import DeliveryCountMessage, TaskiqMessage
1516
from taskiq.receiver import Receiver
1617
from taskiq.result import TaskiqResult
1718
from tests.utils import AsyncQueueBroker
@@ -359,6 +360,127 @@ async def test_callback_unknown_task() -> None:
359360
await receiver.callback(broker_message.message)
360361

361362

363+
@pytest.mark.anyio
364+
@pytest.mark.parametrize("delivery_count", [2, None])
365+
async def test_callback_max_attempts_at_message_not_exceeded(
366+
delivery_count: Optional[int],
367+
) -> None:
368+
"""
369+
Test that callback function calls the task if `max_attempts_at_message`
370+
is not exceeded.
371+
"""
372+
broker = InMemoryBroker()
373+
called_times = 0
374+
375+
@broker.task
376+
async def my_task() -> int:
377+
nonlocal called_times
378+
called_times += 1
379+
return 1
380+
381+
receiver = get_receiver(broker)
382+
receiver.max_attempts_at_message = 3
383+
384+
broker_message = broker.formatter.dumps(
385+
TaskiqMessage(
386+
task_id="task_id",
387+
task_name=my_task.task_name,
388+
labels={},
389+
args=[],
390+
kwargs={},
391+
),
392+
)
393+
394+
await receiver.callback(
395+
DeliveryCountMessage(
396+
data=broker_message.message,
397+
delivery_count=delivery_count,
398+
),
399+
)
400+
assert called_times == 1
401+
402+
403+
@pytest.mark.anyio
404+
async def test_callback_max_attempts_at_message_exceeded() -> None:
405+
"""
406+
Test that callback function does not call the task if `max_attempts_at_message`
407+
is exceeded.
408+
"""
409+
broker = InMemoryBroker()
410+
called_times = 0
411+
412+
@broker.task
413+
async def my_task() -> int:
414+
nonlocal called_times
415+
called_times += 1
416+
return 1
417+
418+
receiver = get_receiver(broker)
419+
receiver.max_attempts_at_message = 3
420+
421+
broker_message = broker.formatter.dumps(
422+
TaskiqMessage(
423+
task_id="task_id",
424+
task_name=my_task.task_name,
425+
labels={},
426+
args=[],
427+
kwargs={},
428+
),
429+
)
430+
431+
await receiver.callback(
432+
DeliveryCountMessage(
433+
data=broker_message.message,
434+
delivery_count=3,
435+
),
436+
)
437+
assert called_times == 0
438+
439+
440+
@pytest.mark.anyio
441+
async def test_callback_max_attempts_at_message_exceeded_ackable() -> None:
442+
"""
443+
Test that callback function does not call the task if `max_attempts_at_message`
444+
is exceeded and acks the message.
445+
"""
446+
broker = InMemoryBroker()
447+
called_times = 0
448+
acked = False
449+
450+
@broker.task
451+
async def my_task() -> int:
452+
nonlocal called_times
453+
called_times += 1
454+
return 1
455+
456+
async def ack_callback() -> None:
457+
nonlocal acked
458+
acked = True
459+
460+
receiver = get_receiver(broker)
461+
receiver.max_attempts_at_message = 3
462+
463+
broker_message = broker.formatter.dumps(
464+
TaskiqMessage(
465+
task_id="task_id",
466+
task_name=my_task.task_name,
467+
labels={},
468+
args=[],
469+
kwargs={},
470+
),
471+
)
472+
473+
await receiver.callback(
474+
AckableMessageWithDeliveryCount(
475+
data=broker_message.message,
476+
delivery_count=3,
477+
ack=ack_callback,
478+
),
479+
)
480+
assert called_times == 0
481+
assert acked
482+
483+
362484
@pytest.mark.anyio
363485
async def test_custom_ctx() -> None:
364486
"""Tests that run_task can run sync tasks."""

0 commit comments

Comments
 (0)
Failed to load comments.