From f74cce954b54a6a3d5b9e12c640b899b5905a615 Mon Sep 17 00:00:00 2001
From: Alexandr Tedeev
Date: Sun, 26 Apr 2026 16:38:17 +0300
Subject: [PATCH 1/4] Refactoring the NNG support solution v1

---
 taskiq/brokers/nng/__init__.py   |  24 +
 taskiq/brokers/nng/broker.py     | 328 ++++++++++++++
 taskiq/brokers/nng/hub.py        | 482 +++++++++++++++++++++
 taskiq/brokers/nng/protocol.py   | 159 +++++++
 taskiq/brokers/nng/storage.py    | 722 +++++++++++++++++++++++++++++++
 taskiq/brokers/nng_broker.py     |  48 --
 tests/brokers/test_nng_broker.py | 576 ++++++++++++++++++++++++
 7 files changed, 2291 insertions(+), 48 deletions(-)
 create mode 100644 taskiq/brokers/nng/__init__.py
 create mode 100644 taskiq/brokers/nng/broker.py
 create mode 100644 taskiq/brokers/nng/hub.py
 create mode 100644 taskiq/brokers/nng/protocol.py
 create mode 100644 taskiq/brokers/nng/storage.py
 delete mode 100644 taskiq/brokers/nng_broker.py
 create mode 100644 tests/brokers/test_nng_broker.py

diff --git a/taskiq/brokers/nng/__init__.py b/taskiq/brokers/nng/__init__.py
new file mode 100644
index 00000000..0d0a2946
--- /dev/null
+++ b/taskiq/brokers/nng/__init__.py
+from .hub import HubConfig, NNGHub
+from .protocol import (
+    ControlMessage,
+    ControlResponse,
+    MessageKind,
+    TaskEnvelope,
+    WorkerState,
+    WorkerStatus,
+)
+from .storage import QueueFullError, SQLiteJournal, StoreConfig
+
+__all__ = [
+    'HubConfig',
+    'NNGHub',
+    'ControlMessage',
+    'ControlResponse',
+    'MessageKind',
+    'TaskEnvelope',
+    'WorkerState',
+    'WorkerStatus',
+    'QueueFullError',
+    'SQLiteJournal',
+    'StoreConfig',
+]
diff --git a/taskiq/brokers/nng/broker.py b/taskiq/brokers/nng/broker.py
new file mode 100644
index 00000000..6961cbeb
--- /dev/null
+++ b/taskiq/brokers/nng/broker.py
+"""NNG broker for taskiq — backed by a standalone :class:`NNGHub`."""
+from __future__ import annotations
+
+import asyncio
+import base64
+import logging
+import os
+import tempfile
+import time
+import uuid
+from collections.abc import 
AsyncGenerator, Callable
+from contextlib import suppress
+from typing import Any, TypeVar
+
+from taskiq.abc.broker import AsyncBroker
+from taskiq.abc.result_backend import AsyncResultBackend
+from taskiq.acks import AckableMessage
+from taskiq.message import BrokerMessage
+
+from .protocol import (
+    ControlMessage,
+    ControlResponse,
+    TaskEnvelope,
+    WorkerState,
+    WorkerStatus,
+)
+
+try:
+    import pynng  # type: ignore
+except ImportError:
+    pynng = None  # type: ignore[assignment]
+
+_T = TypeVar("_T")
+
+logger = logging.getLogger(__name__)
+
+
+def _ipc_addr(prefix: str = "taskiq-nng") -> str:
+    name = f"{prefix}-{os.getpid()}-{uuid.uuid4().hex[:8]}.ipc"
+    return f"ipc://{os.path.join(tempfile.gettempdir(), name)}"
+
+
+class NNGBroker(AsyncBroker):
+    """
+    Taskiq broker backed by a standalone :class:`~taskiq.brokers.nng.hub.NNGHub`.
+
+    The hub must be running before workers or clients start. Launch it with::
+
+        taskiq-nng-hub --control-addr ipc:///tmp/taskiq-nng.ipc
+
+    **Client mode** (``is_worker_process = False``)
+        Only the control socket is opened. :meth:`kick` submits tasks to the
+        hub via a Req0 → Rep0 round-trip.
+
+    **Worker mode** (``is_worker_process = True``)
+        In addition to the control socket the broker opens a unique Pull0
+        socket, registers with the hub, and runs a heartbeat loop.
+        :meth:`listen` yields :class:`~taskiq.acks.AckableMessage` instances
+        whose ``ack`` callback sends the correct ``lease_id`` back to the hub.
+
+    Thread / coroutine safety
+    ─────────────────────────
+    ``Req0`` is strictly serial (one request in-flight per socket).
+    ``_ctrl_lock`` serialises all :meth:`_send_control` calls so that
+    concurrent coroutines (heartbeat + ack + kick) never interleave frames.
+
+    Ack correctness
+    ───────────────
+    The hub embeds the dispatch-generated ``lease_id`` inside every
+    :class:`~taskiq.brokers.nng.protocol.TaskEnvelope`. 
The ack closure + captures it directly, so validation on the hub side always succeeds for + genuine acks and correctly rejects late/duplicate ones. + """ + + def __init__( + self, + control_addr: str, + *, + result_backend: "AsyncResultBackend[_T] | None" = None, + task_id_generator: Callable[[], str] | None = None, + worker_task_addr: str | None = None, + worker_id: str | None = None, + heartbeat_interval: float = 5.0, + lease_timeout: float = 20.0, + capacity: int = 1, + max_retries: int = 0, + retry_backoff: float = 1.0, + retry_jitter: float = 0.0, + recv_timeout_ms: int = 5_000, + send_timeout_ms: int = 5_000, + ) -> None: + """ + Initialise the NNG broker. + + :param control_addr: NNG address of the hub's Rep0 control socket. + :param result_backend: optional result backend. + :param task_id_generator: optional task ID generator. + :param worker_task_addr: NNG address this worker's Pull0 listens on. + Defaults to a unique per-process IPC path. + :param worker_id: stable identifier for this worker process. + Defaults to ``-``. + :param heartbeat_interval: seconds between heartbeat messages to hub. + :param lease_timeout: seconds a dispatched task lease remains valid. + :param capacity: max concurrent tasks this worker will accept. + :param max_retries: default max retries for submitted tasks. + :param retry_backoff: base seconds for exponential backoff. + :param retry_jitter: jitter multiplier added to backoff (0 = no jitter). + :param recv_timeout_ms: Req0 recv timeout in milliseconds. + :param send_timeout_ms: Req0 send timeout in milliseconds. + """ + if pynng is None: + raise RuntimeError( + "pynng is required to use NNGBroker. 
" + "Install it with: pip install taskiq[nng]", + ) + super().__init__( + result_backend=result_backend, + task_id_generator=task_id_generator, + ) + self.control_addr = control_addr + self.worker_task_addr = worker_task_addr or _ipc_addr() + self.worker_id = worker_id or f"{os.getpid()}-{uuid.uuid4().hex[:12]}" + self.heartbeat_interval = heartbeat_interval + self.lease_timeout = lease_timeout + self.capacity = capacity + self.max_retries = max_retries + self.retry_backoff = retry_backoff + self.retry_jitter = retry_jitter + self.recv_timeout_ms = recv_timeout_ms + self.send_timeout_ms = send_timeout_ms + + self._ctrl_sock: Any = None # pynng.Req0 + self._task_sock: Any = None # pynng.Pull0 (worker mode only) + self._heartbeat_task: asyncio.Task[None] | None = None + # Req0 allows exactly one request in-flight; this lock enforces that. + self._ctrl_lock = asyncio.Lock() + + # ── lifecycle ───────────────────────────────────────────────────────────── + + async def startup(self) -> None: + """Open sockets, register with hub (worker mode), and start heartbeat.""" + self._ctrl_sock = pynng.Req0( + dial=self.control_addr, + recv_timeout=self.recv_timeout_ms, + send_timeout=self.send_timeout_ms, + ) + if self.is_worker_process: + # recv_buffer_size lets the hub pre-queue up to `capacity` task + # messages in NNG's recv buffer before listen() calls arecv(). 
+ self._task_sock = pynng.Pull0( + listen=self.worker_task_addr, + recv_buffer_size=self.capacity, + ) + resp = await self._send_control( + "register", + { + "worker_id": self.worker_id, + "task_addr": self.worker_task_addr, + "capacity": self.capacity, + "inflight": 0, + "last_seen": time.time(), + "heartbeat_interval": self.heartbeat_interval, + "lease_timeout": self.lease_timeout, + "draining": False, + "status": str(WorkerStatus.STARTING), + "version": "taskiq-nng", + }, + ) + if not resp.ok: + raise RuntimeError(f"Worker registration failed: {resp.error}") + logger.info( + "Worker %s registered with hub at %s", + self.worker_id, + self.control_addr, + ) + self._heartbeat_task = asyncio.create_task( + self._heartbeat_loop(), + name=f"nng-hb-{self.worker_id[:8]}", + ) + await super().startup() + + async def shutdown(self) -> None: + """Drain, unregister, cancel heartbeat, and close all sockets.""" + if self.is_worker_process: + if self._heartbeat_task is not None: + self._heartbeat_task.cancel() + with suppress(asyncio.CancelledError): + await self._heartbeat_task + if self._ctrl_sock is not None: + with suppress(Exception): + await self._send_control( + "drain", {"worker_id": self.worker_id}, + ) + await self._send_control( + "unregister", {"worker_id": self.worker_id}, + ) + if self._task_sock is not None: + with suppress(Exception): + self._task_sock.close() + if self._ctrl_sock is not None: + with suppress(Exception): + self._ctrl_sock.close() + await super().shutdown() + + # ── internal helpers ────────────────────────────────────────────────────── + + async def _send_control( + self, kind: str, payload: dict[str, Any], + ) -> ControlResponse: + if self._ctrl_sock is None: + raise RuntimeError("Control socket is not open (call startup() first)") + async with self._ctrl_lock: + await self._ctrl_sock.asend( + ControlMessage(kind=kind, payload=payload).to_bytes(), + ) + raw = await self._ctrl_sock.arecv() + return ControlResponse.from_bytes(raw) + + async def 
_heartbeat_loop(self) -> None: + while True: + try: + await asyncio.sleep(self.heartbeat_interval) + resp = await self._send_control( + "heartbeat", {"worker_id": self.worker_id}, + ) + if not resp.ok: + logger.warning("Heartbeat rejected by hub: %s", resp.error) + except asyncio.CancelledError: + raise + except Exception as exc: + # Hub may be temporarily unreachable; log and keep trying. + logger.warning("Heartbeat failed: %s", exc) + + # ── AsyncBroker API ─────────────────────────────────────────────────────── + + async def kick(self, message: BrokerMessage) -> None: + """ + Submit a task to the hub for dispatch. + + :param message: broker message to submit. + :raises RuntimeError: if the broker has not been started or the hub + rejects the submission (e.g. queue full). + """ + if self._ctrl_sock is None: + raise RuntimeError("Broker is not started") + payload: dict[str, Any] = { + "task_id": message.task_id, + "task_name": message.task_name, + "payload_b64": base64.b64encode(message.message).decode("ascii"), + "labels": message.labels, + "lease_id": "", # hub assigns the real lease_id at dispatch time + "attempts": int(message.labels.get("attempts", 0)), + "max_retries": int( + message.labels.get("max_retries", self.max_retries), + ), + "retry_backoff": float( + message.labels.get("retry_backoff", self.retry_backoff), + ), + "retry_jitter": float( + message.labels.get("retry_jitter", self.retry_jitter), + ), + "priority": int(message.labels.get("priority", 0)), + "created_at": time.time(), + } + resp = await self._send_control("submit", payload) + if not resp.ok: + raise RuntimeError(resp.error or "task submission failed") + + async def listen(self) -> AsyncGenerator[bytes | AckableMessage, None]: + """ + Yield incoming tasks as :class:`~taskiq.acks.AckableMessage` instances. + + Each message's ``ack`` callback sends the hub-issued ``lease_id`` back + so the hub can validate the ack and reject any late/duplicate ones. 
+ + :raises RuntimeError: if called outside worker mode or before startup. + :yields: ackable task messages. + """ + if not self.is_worker_process: + raise RuntimeError("listen() is only valid in worker mode") + if self._task_sock is None: + raise RuntimeError("Task socket is not open (call startup() first)") + + while True: + try: + raw = await self._task_sock.arecv() + except pynng.Closed: + logger.info("Task socket closed; stopping listen()") + return + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning("Task arecv error: %s", exc) + continue + + try: + envelope = TaskEnvelope.from_bytes(raw) + except Exception as exc: + logger.error("Malformed task envelope discarded: %s", exc) + continue + + task_id = envelope.task_id + worker_id = self.worker_id + lease_id = envelope.lease_id # hub-assigned; correct by construction + + async def _ack( + _task_id: str = task_id, + _worker_id: str = worker_id, + _lease_id: str = lease_id, + ) -> None: + try: + resp = await self._send_control( + "ack", + { + "task_id": _task_id, + "worker_id": _worker_id, + "lease_id": _lease_id, + }, + ) + if not resp.ok: + logger.debug( + "Ack rejected for %s (late/duplicate): %s", + _task_id, resp.error, + ) + except Exception as exc: + logger.warning("Ack send failed for %s: %s", _task_id, exc) + + yield AckableMessage(data=envelope.payload, ack=_ack) diff --git a/taskiq/brokers/nng/hub.py b/taskiq/brokers/nng/hub.py new file mode 100644 index 00000000..844055c5 --- /dev/null +++ b/taskiq/brokers/nng/hub.py @@ -0,0 +1,482 @@ +""" +NNG hub: central control plane, task dispatcher, and lease manager. + +Run as a standalone process:: + + taskiq-nng-hub --control-addr ipc:///tmp/taskiq-nng.ipc \\ + --task-db /var/lib/taskiq/tasks.db + +Or embed it in an application for testing:: + + hub = NNGHub(HubConfig(control_addr="ipc:///tmp/h.ipc", task_db=":memory:")) + await hub.start() + ... 
await hub.stop()
+"""
+from __future__ import annotations
+
+import argparse
+import asyncio
+import base64
+import json
+import logging
+import os
+import signal
+import time
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import suppress
+from dataclasses import dataclass, field
+from typing import Any
+
+try:
+    import pynng  # type: ignore
+except ImportError:
+    pynng = None  # type: ignore[assignment]
+
+from .protocol import (
+    ControlMessage,
+    ControlResponse,
+    TaskEnvelope,
+    WorkerState,
+)
+from .storage import QueueFullError, SQLiteJournal, StoreConfig
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class HubConfig:
+    """Configuration for :class:`NNGHub`."""
+
+    control_addr: str
+    task_db: str
+    max_pending: int = 10_000
+    heartbeat_timeout: float = 15.0
+    lease_timeout: float = 20.0
+    dispatch_interval: float = 0.05
+    reaper_interval: float = 0.5
+    routing_policy: str = "least_loaded"
+    backoff_cap: float = 60.0
+    # Number of concurrent Rep0 contexts. Each context handles one req/rep
+    # pair independently; N contexts ≈ N simultaneous control-plane clients.
+    control_concurrency: int = 16
+    dispatch_batch: int = 50
+    # Per-context recv timeout in ms. Allows the stop event to be checked
+    # even when there are no incoming messages.
+    recv_timeout_ms: int = 1_000
+
+
+class NNGHub:
+    """
+    Stateful central hub: control plane, task dispatcher, and lease manager.
+
+    Architecture
+    ────────────
+    **Control plane** — ``Rep0`` socket with ``control_concurrency``
+    independent ``nng_ctx`` contexts running concurrently. Each context
+    handles one request-reply at a time, so N workers can
+    register/heartbeat/ack simultaneously without queuing behind each other.
+    This is the key fix over the single-context (serial) Rep0 in v2.
+
+    **Data plane** — One ``Push0`` socket per registered worker, dialed to
+    the worker's own ``Pull0`` listen address. 
The hub explicitly targets + the least-loaded worker instead of relying on NNG round-robin, giving + us load-aware routing. + + **Persistence** — :class:`~taskiq.brokers.nng_storage.SQLiteJournal` in + WAL mode. All storage calls are executed on a single-threaded + ``ThreadPoolExecutor`` so the asyncio event loop is never blocked and + SQLite write serialisation is guaranteed. + + **Recovery** — On startup, tasks leased to workers that died during the + previous hub session are automatically requeued. + """ + + def __init__(self, config: HubConfig) -> None: + """ + Initialise the hub with the given configuration. + + :param config: hub configuration. + """ + if pynng is None: + raise RuntimeError( + "pynng is required to use NNGHub. " + "Install it with: pip install taskiq[nng]" + ) + self.config = config + self.store = SQLiteJournal( + StoreConfig( + path=config.task_db, + max_pending=config.max_pending, + lease_timeout=config.lease_timeout, + backoff_cap=config.backoff_cap, + ), + ) + self._stop = asyncio.Event() + self._ctrl_sock: Any = None # pynng.Rep0 + self._worker_push: dict[str, Any] = {} # worker_id -> pynng.Push0 + self._tasks: list[asyncio.Task[None]] = [] + # Single-threaded executor: serialises all SQLite calls on one OS thread. 
+ self._db_exec = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="nng-db" + ) + + # ── lifecycle ───────────────────────────────────────────────────────────── + + async def start(self) -> None: + """Start the hub: recover orphaned tasks, open sockets, spawn loops.""" + await self._db(self.store.recover_dead_workers, self.config.heartbeat_timeout) + + self._ctrl_sock = pynng.Rep0(listen=self.config.control_addr) + self._ctrl_sock.recv_timeout = self.config.recv_timeout_ms + + self._tasks = [ + asyncio.create_task(self._dispatch_loop(), name="hub-dispatch"), + asyncio.create_task(self._reaper_loop(), name="hub-reaper"), + ] + for i in range(self.config.control_concurrency): + ctx = self._ctrl_sock.new_context() + self._tasks.append( + asyncio.create_task( + self._control_handler(ctx), name=f"hub-ctrl-{i}" + ), + ) + logger.info( + "NNG hub started on %s (db=%s)", + self.config.control_addr, + self.config.task_db, + ) + + async def stop(self) -> None: + """Gracefully stop all hub loops and close sockets.""" + logger.info("NNG hub stopping…") + self._stop.set() + for t in self._tasks: + t.cancel() + with suppress(asyncio.CancelledError): + await t + for sock in self._worker_push.values(): + with suppress(Exception): + sock.close() + self._worker_push.clear() + if self._ctrl_sock is not None: + with suppress(Exception): + self._ctrl_sock.close() + self._db_exec.shutdown(wait=True) + logger.info("NNG hub stopped") + + # ── DB helper ───────────────────────────────────────────────────────────── + + async def _db(self, fn: Any, *args: Any, **kwargs: Any) -> Any: + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + self._db_exec, lambda: fn(*args, **kwargs) + ) + + # ── control plane ───────────────────────────────────────────────────────── + + async def _control_handler(self, ctx: Any) -> None: + """Run one Rep0 context: receive → dispatch → reply, in a loop.""" + while not self._stop.is_set(): + try: + raw = await ctx.arecv() + except 
pynng.Timeout: + continue + except (pynng.Closed, asyncio.CancelledError): + break + except Exception as exc: + logger.warning("Control recv error: %s", exc) + continue + + try: + response = await self._handle(raw) + except Exception as exc: + logger.exception("Unhandled error in control handler") + response = ControlResponse(ok=False, error=str(exc)) + + try: + await ctx.asend(response.to_bytes()) + except (pynng.Closed, asyncio.CancelledError): + break + except Exception as exc: + logger.warning("Control send error: %s", exc) + + async def _handle(self, raw: bytes) -> ControlResponse: # noqa: PLR0911, C901 + """Dispatch a raw control message to the appropriate handler.""" + msg = ControlMessage.from_bytes(raw) + + if msg.kind == "ping": + return ControlResponse(ok=True, payload={"pong": True}) + + if msg.kind == "submit": + return await self._handle_submit(msg.payload) + + if msg.kind == "register": + return await self._handle_register(msg.payload) + + if msg.kind == "heartbeat": + await self._db(self.store.heartbeat, msg.payload["worker_id"]) + return ControlResponse(ok=True, payload={"ok": True}) + + if msg.kind == "unregister": + return await self._handle_unregister(msg.payload["worker_id"]) + + if msg.kind == "drain": + await self._db(self.store.mark_draining, msg.payload["worker_id"]) + return ControlResponse(ok=True, payload={"draining": True}) + + if msg.kind == "ack": + ok = await self._db( + self.store.ack, + msg.payload["task_id"], + msg.payload["worker_id"], + msg.payload["lease_id"], + ) + return ControlResponse(ok=ok, payload={"acked": ok}) + + if msg.kind == "nack": + ok = await self._db( + self.store.nack, + msg.payload["task_id"], + msg.payload["worker_id"], + msg.payload["lease_id"], + msg.payload.get("error", "unknown error"), + ) + return ControlResponse(ok=ok, payload={"nacked": ok}) + + if msg.kind == "status": + task = await self._db(self.store.get_task, msg.payload["task_id"]) + return ControlResponse(ok=bool(task), payload=dict(task) if 
task else {}) + + if msg.kind == "stats": + s = await self._db(self.store.stats) + return ControlResponse(ok=True, payload=s) + + return ControlResponse(ok=False, error=f"unknown kind: {msg.kind!r}") + + async def _handle_submit(self, payload: dict[str, Any]) -> ControlResponse: + envelope = TaskEnvelope(**payload) + try: + await self._db(self.store.submit, envelope) + return ControlResponse(ok=True, payload={"task_id": envelope.task_id}) + except QueueFullError: + return ControlResponse(ok=False, error="queue full") + + async def _handle_register(self, payload: dict[str, Any]) -> ControlResponse: + worker = WorkerState(**payload) + await self._db(self.store.register_worker, worker) + if worker.worker_id not in self._worker_push: + try: + sock = pynng.Push0(dial=worker.task_addr) + self._worker_push[worker.worker_id] = sock + except Exception as exc: + logger.error( + "Failed to dial worker %s at %s: %s", + worker.worker_id, worker.task_addr, exc, + ) + return ControlResponse(ok=False, error=f"dial failed: {exc}") + return ControlResponse(ok=True, payload={"registered": True}) + + async def _handle_unregister(self, worker_id: str) -> ControlResponse: + await self._db(self.store.unregister_worker, worker_id) + sock = self._worker_push.pop(worker_id, None) + if sock is not None: + with suppress(Exception): + sock.close() + return ControlResponse(ok=True, payload={"unregistered": True}) + + # ── dispatch loop ───────────────────────────────────────────────────────── + + async def _dispatch_loop(self) -> None: + while not self._stop.is_set(): + try: + sent = await self._dispatch_once() + if not sent: + await asyncio.sleep(self.config.dispatch_interval) + except asyncio.CancelledError: + raise + except Exception: + logger.exception("Dispatch loop error") + await asyncio.sleep(self.config.dispatch_interval) + + async def _dispatch_once(self) -> bool: + """Dispatch up to ``dispatch_batch`` due tasks to available workers.""" + due = await self._db(self.store.due_tasks, 
self.config.dispatch_batch) + if not due: + return False + sent_any = False + for row in due: + worker = await self._db( + self.store.choose_worker, + self.config.routing_policy, + heartbeat_timeout=self.config.heartbeat_timeout, + ) + if worker is None: + return sent_any # no capacity; leave remaining tasks in queue + + worker_id = worker["worker_id"] + lease_id = uuid.uuid4().hex + lease_until = time.time() + self.config.lease_timeout + + if not await self._db( + self.store.mark_leased, + row["task_id"], worker_id, lease_id, lease_until, + ): + continue # concurrent dispatch race; task already taken + + sock = self._worker_push.get(worker_id) + if sock is None: + logger.warning( + "No push socket for worker %s, requeueing %s", + worker_id, row["task_id"], + ) + await self._db( + self.store.nack, + row["task_id"], worker_id, lease_id, "no socket", + ) + continue + + # Include the hub-generated lease_id so the worker can ack with + # the exact token. Omitting it was the core correctness bug in v2. 
+ envelope = TaskEnvelope( + task_id=row["task_id"], + task_name=row["task_name"], + payload_b64=base64.b64encode(row["payload"]).decode("ascii"), + labels=json.loads(row["labels_json"]), + lease_id=lease_id, + attempts=int(row["attempts"]) + 1, + max_retries=int(row["max_retries"]), + retry_backoff=float(row["retry_backoff"]), + retry_jitter=float(row["retry_jitter"]), + priority=int(row["priority"]), + created_at=float(row["created_at"]), + ) + try: + await sock.asend(envelope.to_bytes()) + sent_any = True + except Exception as exc: + logger.warning( + "Failed to deliver %s to worker %s: %s", + row["task_id"], worker_id, exc, + ) + await self._db( + self.store.nack, + row["task_id"], worker_id, lease_id, + f"dispatch send failed: {exc}", + ) + return sent_any + + # ── reaper loop ─────────────────────────────────────────────────────────── + + async def _reaper_loop(self) -> None: + while not self._stop.is_set(): + try: + await asyncio.sleep(self.config.reaper_interval) + reaped = await self._db(self.store.reap_expired_leases) + if reaped: + logger.debug("Reaped %d expired leases", reaped) + recovered = await self._db( + self.store.recover_dead_workers, + self.config.heartbeat_timeout, + ) + if recovered: + logger.info("Requeued %d tasks from dead workers", recovered) + except asyncio.CancelledError: + raise + except Exception: + logger.exception("Reaper loop error") + + +# ── standalone CLI entry point ──────────────────────────────────────────────── + +def _build_config() -> HubConfig: + p = argparse.ArgumentParser( + description="taskiq-nng-hub — NNG task router, dispatcher, and lease manager", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + p.add_argument( + "--control-addr", + default=os.getenv("NNG_CONTROL_ADDR", "ipc:///tmp/taskiq-nng.ipc"), + help="NNG address the hub listens on. 
Env: NNG_CONTROL_ADDR", + ) + p.add_argument( + "--task-db", + default=os.getenv("NNG_TASK_DB", "/tmp/taskiq-nng-tasks.db"), # noqa: S108 + help="Path to the SQLite WAL task journal. Env: NNG_TASK_DB", + ) + p.add_argument( + "--max-pending", + type=int, + default=int(os.getenv("NNG_MAX_PENDING", "10000")), + ) + p.add_argument( + "--heartbeat-timeout", + type=float, + default=float(os.getenv("NNG_HEARTBEAT_TIMEOUT", "15.0")), + help="Seconds of silence before a worker is declared dead.", + ) + p.add_argument( + "--lease-timeout", + type=float, + default=float(os.getenv("NNG_LEASE_TIMEOUT", "20.0")), + help="Seconds before an unacked task lease is reaped.", + ) + p.add_argument( + "--routing-policy", + choices=["least_loaded", "p2c"], + default=os.getenv("NNG_ROUTING_POLICY", "least_loaded"), + ) + p.add_argument( + "--control-concurrency", + type=int, + default=int(os.getenv("NNG_CONTROL_CONCURRENCY", "16")), + help="Number of concurrent Rep0 contexts.", + ) + p.add_argument( + "--log-level", + default=os.getenv("NNG_LOG_LEVEL", "INFO"), + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + ) + args = p.parse_args() + logging.basicConfig( + level=getattr(logging, args.log_level), + format="%(asctime)s %(name)-24s %(levelname)-8s %(message)s", + ) + return HubConfig( + control_addr=args.control_addr, + task_db=args.task_db, + max_pending=args.max_pending, + heartbeat_timeout=args.heartbeat_timeout, + lease_timeout=args.lease_timeout, + routing_policy=args.routing_policy, + control_concurrency=args.control_concurrency, + ) + + +async def _run(config: HubConfig) -> None: + hub = NNGHub(config) + loop = asyncio.get_running_loop() + stop_event = asyncio.Event() + + def _on_signal() -> None: + logger.info("Shutdown signal received") + stop_event.set() + + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, _on_signal) + + await hub.start() + try: + await stop_event.wait() + finally: + await hub.stop() + + +def main() -> None: + """Entry point for the 
``taskiq-nng-hub`` CLI command.""" + config = _build_config() + try: + asyncio.run(_run(config)) + except KeyboardInterrupt: + pass diff --git a/taskiq/brokers/nng/protocol.py b/taskiq/brokers/nng/protocol.py new file mode 100644 index 00000000..9b0b4d8e --- /dev/null +++ b/taskiq/brokers/nng/protocol.py @@ -0,0 +1,159 @@ +"""Wire protocol types for the NNG broker.""" +from __future__ import annotations + +import base64 +import enum +import json +from dataclasses import asdict, dataclass, field +from typing import Any + + +class _StrValue(str, enum.Enum): + """Base for string enums whose str() returns the plain value (Python 3.10+).""" + + def __str__(self) -> str: + return self.value + + +class MessageKind(_StrValue): + """Kinds of control-plane messages sent between broker/client and hub.""" + + SUBMIT = "submit" + REGISTER = "register" + HEARTBEAT = "heartbeat" + UNREGISTER = "unregister" + DRAIN = "drain" + ACK = "ack" + NACK = "nack" + STATUS = "status" + STATS = "stats" + PING = "ping" + + +class TaskState(_StrValue): + """Lifecycle state of a task in the hub store.""" + + READY = "ready" + LEASED = "leased" + DONE = "done" + FAILED = "failed" + + +class WorkerStatus(_StrValue): + """Lifecycle status of a registered worker.""" + + STARTING = "starting" + LISTENING = "listening" + DRAINING = "draining" + OFFLINE = "offline" + DEAD = "dead" + + +@dataclass +class TaskEnvelope: + """ + Task payload sent from hub to worker over the data plane. + + ``lease_id`` is the UUID assigned by the hub at dispatch time. + Workers must echo it back in the ACK so the hub can validate + that the ack is not stale (e.g. after lease expiry and requeue). 
+ """ + + task_id: str + task_name: str + payload_b64: str + labels: dict[str, Any] = field(default_factory=dict) + lease_id: str = "" + attempts: int = 0 + max_retries: int = 0 + retry_backoff: float = 1.0 + retry_jitter: float = 0.0 + priority: int = 0 + created_at: float = 0.0 + + @property + def payload(self) -> bytes: + """Decode the base-64 task payload.""" + return base64.b64decode(self.payload_b64.encode("ascii")) + + @classmethod + def from_bytes(cls, raw: bytes) -> TaskEnvelope: + """Deserialise from JSON bytes.""" + return cls(**json.loads(raw.decode("utf-8"))) + + def to_bytes(self) -> bytes: + """Serialise to JSON bytes.""" + return json.dumps( + asdict(self), separators=(",", ":"), ensure_ascii=False + ).encode("utf-8") + + +@dataclass +class ControlMessage: + """Request sent over the control plane (Req0 → Rep0).""" + + kind: str + payload: dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_bytes(cls, raw: bytes) -> ControlMessage: + """Deserialise from JSON bytes.""" + data = json.loads(raw.decode("utf-8")) + return cls(kind=data["kind"], payload=data.get("payload", {})) + + def to_bytes(self) -> bytes: + """Serialise to JSON bytes.""" + return json.dumps( + {"kind": self.kind, "payload": self.payload}, + separators=(",", ":"), + ensure_ascii=False, + ).encode("utf-8") + + +@dataclass +class ControlResponse: + """Response sent back over the control plane (Rep0 → Req0).""" + + ok: bool + payload: dict[str, Any] = field(default_factory=dict) + error: str | None = None + + @classmethod + def from_bytes(cls, raw: bytes) -> ControlResponse: + """Deserialise from JSON bytes.""" + data = json.loads(raw.decode("utf-8")) + return cls( + ok=data["ok"], + payload=data.get("payload", {}), + error=data.get("error"), + ) + + def to_bytes(self) -> bytes: + """Serialise to JSON bytes.""" + return json.dumps( + {"ok": self.ok, "payload": self.payload, "error": self.error}, + separators=(",", ":"), + ensure_ascii=False, + ).encode("utf-8") + + 
+@dataclass
+class WorkerState:
+    """Snapshot of a worker's identity and capacity at registration time."""
+
+    worker_id: str
+    task_addr: str
+    capacity: int
+    inflight: int = 0
+    last_seen: float = 0.0
+    heartbeat_interval: float = 5.0
+    lease_timeout: float = 15.0
+    draining: bool = False
+    status: WorkerStatus = WorkerStatus.STARTING
+    version: str = "unknown"
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to a plain dict, serialising the status enum to its string value."""
+        d = asdict(self)
+        d["status"] = str(self.status)
+        return d
diff --git a/taskiq/brokers/nng/storage.py b/taskiq/brokers/nng/storage.py
new file mode 100644
index 00000000..410a7064
--- /dev/null
+++ b/taskiq/brokers/nng/storage.py
+"""Durable WAL-mode SQLite task journal for the NNG hub."""
+from __future__ import annotations
+
+import json
+import random
+import sqlite3
+import threading
+import time
+from contextlib import contextmanager
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Generator
+
+from .protocol import TaskEnvelope, WorkerState, WorkerStatus
+
+
+@dataclass
+class StoreConfig:
+    """Configuration for the SQLite task journal."""
+
+    path: str
+    max_pending: int = 10_000
+    lease_timeout: float = 30.0
+    backoff_base: float = 1.0
+    backoff_cap: float = 60.0
+
+
+class QueueFullError(RuntimeError):
+    """Raised when a submission is attempted on a full queue."""
+
+
+class SQLiteJournal:
+    """
+    Thread-safe, WAL-mode SQLite task store.
+
+    Design notes
+    ────────────
+    * Every method opens and closes its own connection. WAL allows concurrent
+      readers without blocking; SQLite serialises concurrent writers internally,
+      and the Python-level ``_submit_lock`` prevents the TOCTOU race in
+      :meth:`submit`.
+    * The hub runs every call through a single-threaded
+      ``ThreadPoolExecutor`` so, in practice, writes never contend at the
+      OS level either. 
+    * ``PRAGMA``s are re-applied per connection: ``journal_mode=WAL`` persists in
+      the DB file; ``synchronous`` and ``busy_timeout`` reset on each connect.
+    """
+
+    def __init__(self, config: StoreConfig) -> None:
+        """Initialise the journal and create schema if not present."""
+        self.config = config
+        # Guards only the pending_count check + INSERT pair in submit() to
+        # prevent concurrent callers from racing past max_pending.
+        self._submit_lock = threading.Lock()
+        self._init()
+
+    # ── connection ────────────────────────────────────────────────────────────
+
+    @contextmanager
+    def _conn(self) -> Generator[sqlite3.Connection, None, None]:
+        conn = sqlite3.connect(
+            self.config.path,
+            timeout=10.0,
+            check_same_thread=False,
+            isolation_level=None,  # we manage transactions explicitly
+        )
+        conn.row_factory = sqlite3.Row
+        # journal_mode=WAL persists in the DB file; the other PRAGMAs are per-connection.
+        conn.execute("PRAGMA journal_mode=WAL")
+        conn.execute("PRAGMA synchronous=NORMAL")  # safe with WAL; faster than FULL
+        conn.execute("PRAGMA busy_timeout=5000")  # wait up to 5s before SQLITE_BUSY
+        conn.execute("PRAGMA cache_size=-32000")  # 32 MB page cache
+        try:
+            yield conn
+        finally:
+            conn.close()
+
+    def _init(self) -> None:
+        Path(self.config.path).parent.mkdir(parents=True, exist_ok=True)
+        with self._conn() as conn:
+            conn.executescript("""
+                CREATE TABLE IF NOT EXISTS tasks (
+                    task_id TEXT PRIMARY KEY,
+                    task_name TEXT NOT NULL,
+                    payload BLOB NOT NULL,
+                    labels_json TEXT NOT NULL DEFAULT '{}',
+                    state TEXT NOT NULL,
+                    attempts INTEGER NOT NULL DEFAULT 0,
+                    max_retries INTEGER NOT NULL DEFAULT 0,
+                    retry_backoff REAL NOT NULL DEFAULT 1.0,
+                    retry_jitter REAL NOT NULL DEFAULT 0.0,
+                    priority INTEGER NOT NULL DEFAULT 0,
+                    created_at REAL NOT NULL,
+                    updated_at REAL NOT NULL,
+                    next_run_at REAL NOT NULL,
+                    lease_id TEXT,
+                    leased_worker_id TEXT,
+                    lease_until REAL,
+                    last_error TEXT
+                );
+
+                CREATE TABLE IF NOT EXISTS workers (
+                    worker_id TEXT PRIMARY KEY,
+                    task_addr TEXT
NOT NULL, + capacity INTEGER NOT NULL, + inflight INTEGER NOT NULL DEFAULT 0, + last_seen REAL NOT NULL DEFAULT 0, + heartbeat_interval REAL NOT NULL DEFAULT 5.0, + lease_timeout REAL NOT NULL DEFAULT 15.0, + draining INTEGER NOT NULL DEFAULT 0, + status TEXT NOT NULL, + version TEXT NOT NULL DEFAULT 'unknown' + ); + + CREATE TABLE IF NOT EXISTS journal ( + seq INTEGER PRIMARY KEY AUTOINCREMENT, + ts REAL NOT NULL, + kind TEXT NOT NULL, + payload_json TEXT NOT NULL + ); + + CREATE INDEX IF NOT EXISTS idx_tasks_dispatch + ON tasks (state, next_run_at, priority DESC); + CREATE INDEX IF NOT EXISTS idx_tasks_lease + ON tasks (state, lease_until); + CREATE INDEX IF NOT EXISTS idx_workers_active + ON workers (status, draining, last_seen); + """) + + # ── helpers ─────────────────────────────────────────────────────────────── + + def _journal( + self, + conn: sqlite3.Connection, + kind: str, + payload: dict[str, Any], + ) -> None: + conn.execute( + "INSERT INTO journal (ts, kind, payload_json) VALUES (?, ?, ?)", + ( + time.time(), + kind, + json.dumps(payload, separators=(",", ":"), ensure_ascii=False), + ), + ) + + def _backoff(self, attempts: int, backoff_base: float) -> float: + return min(self.config.backoff_cap, backoff_base * (2 ** max(0, attempts - 1))) + + # ── task lifecycle ──────────────────────────────────────────────────────── + + def pending_count(self) -> int: + """Return the number of ready + leased tasks.""" + with self._conn() as conn: + return int( + conn.execute( + "SELECT COUNT(*) FROM tasks WHERE state IN ('ready', 'leased')", + ).fetchone()[0], + ) + + def submit(self, envelope: TaskEnvelope) -> None: + """ + Persist a new task in 'ready' state. + + :param envelope: task envelope to store. + :raises QueueFullError: when ``max_pending`` is reached. 
+ """ + now = time.time() + with self._submit_lock, self._conn() as conn: + count = conn.execute( + "SELECT COUNT(*) FROM tasks WHERE state IN ('ready', 'leased')", + ).fetchone()[0] + if count >= self.config.max_pending: + raise QueueFullError("Task queue is full.") + conn.execute("BEGIN") + conn.execute( + """ + INSERT INTO tasks ( + task_id, task_name, payload, labels_json, state, + attempts, max_retries, retry_backoff, retry_jitter, + priority, created_at, updated_at, next_run_at + ) VALUES (?, ?, ?, ?, 'ready', 0, ?, ?, ?, ?, ?, ?, ?) + """, + ( + envelope.task_id, + envelope.task_name, + envelope.payload, + json.dumps( + envelope.labels, separators=(",", ":"), ensure_ascii=False + ), + envelope.max_retries, + envelope.retry_backoff, + envelope.retry_jitter, + envelope.priority, + envelope.created_at or now, + now, + now, + ), + ) + self._journal( + conn, + "task_submitted", + {"task_id": envelope.task_id, "task_name": envelope.task_name}, + ) + conn.execute("COMMIT") + + def due_tasks(self, limit: int = 50) -> list[sqlite3.Row]: + """ + Return ready tasks whose ``next_run_at`` is in the past. + + Results are ordered by priority (descending) then creation time. + + :param limit: maximum number of rows to return. + :return: list of task rows. + """ + now = time.time() + with self._conn() as conn: + return list( + conn.execute( + """ + SELECT * FROM tasks + WHERE state = 'ready' AND next_run_at <= ? + ORDER BY priority DESC, created_at ASC + LIMIT ? + """, + (now, limit), + ), + ) + + def mark_leased( + self, + task_id: str, + worker_id: str, + lease_id: str, + lease_until: float, + ) -> bool: + """ + Atomically transition a task from 'ready' to 'leased'. + + :param task_id: task to lease. + :param worker_id: worker receiving the task. + :param lease_id: unique token for this dispatch attempt. + :param lease_until: absolute epoch deadline for the lease. + :return: True if the transition succeeded; False if the task was + already taken (concurrent dispatch race). 
+ """ + now = time.time() + with self._conn() as conn: + row = conn.execute( + "SELECT state FROM tasks WHERE task_id = ?", (task_id,) + ).fetchone() + if not row or row["state"] != "ready": + return False + conn.execute("BEGIN") + conn.execute( + """ + UPDATE tasks + SET state = 'leased', + leased_worker_id = ?, lease_id = ?, lease_until = ?, + attempts = attempts + 1, updated_at = ? + WHERE task_id = ? + """, + (worker_id, lease_id, lease_until, now, task_id), + ) + conn.execute( + "UPDATE workers SET inflight = inflight + 1 WHERE worker_id = ?", + (worker_id,), + ) + self._journal( + conn, + "task_leased", + { + "task_id": task_id, + "worker_id": worker_id, + "lease_id": lease_id, + }, + ) + conn.execute("COMMIT") + return True + + def ack(self, task_id: str, worker_id: str, lease_id: str) -> bool: + """ + Mark a task as successfully completed. + + Late or duplicate acks (mismatched ``lease_id`` or state ≠ 'leased') + are silently rejected and return False. + + :param task_id: task being acknowledged. + :param worker_id: worker sending the ack. + :param lease_id: lease token that was issued at dispatch. + :return: True if the ack was accepted. + """ + now = time.time() + with self._conn() as conn: + row = conn.execute( + "SELECT state, lease_id, leased_worker_id FROM tasks WHERE task_id = ?", + (task_id,), + ).fetchone() + if not row or row["state"] != "leased": + return False + if row["lease_id"] != lease_id or row["leased_worker_id"] != worker_id: + return False + conn.execute("BEGIN") + conn.execute( + """ + UPDATE tasks + SET state = 'done', updated_at = ?, + lease_id = NULL, leased_worker_id = NULL, lease_until = NULL + WHERE task_id = ? 
+ """, + (now, task_id), + ) + conn.execute( + "UPDATE workers SET inflight = MAX(inflight - 1, 0) WHERE worker_id = ?", + (worker_id,), + ) + self._journal( + conn, + "task_acked", + { + "task_id": task_id, + "worker_id": worker_id, + "lease_id": lease_id, + }, + ) + conn.execute("COMMIT") + return True + + def nack( + self, task_id: str, worker_id: str, lease_id: str, error: str + ) -> bool: + """ + Explicitly fail a task, triggering retry or permanent failure. + + :param task_id: task being nacked. + :param worker_id: worker sending the nack. + :param lease_id: lease token issued at dispatch. + :param error: human-readable reason for the failure. + :return: True if the nack was accepted. + """ + return self._requeue_or_fail(task_id, worker_id, lease_id, error) + + def _requeue_or_fail( + self, task_id: str, worker_id: str, lease_id: str, error: str + ) -> bool: + now = time.time() + with self._conn() as conn: + row = conn.execute( + "SELECT * FROM tasks WHERE task_id = ?", (task_id,) + ).fetchone() + if ( + not row + or row["state"] != "leased" + or row["lease_id"] != lease_id + or row["leased_worker_id"] != worker_id + ): + return False + attempts = int(row["attempts"]) + max_retries = int(row["max_retries"]) + conn.execute("BEGIN") + if attempts > max_retries: + conn.execute( + """ + UPDATE tasks + SET state = 'failed', updated_at = ?, + lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, + last_error = ? + WHERE task_id = ? + """, + (now, error, task_id), + ) + else: + backoff = self._backoff(attempts, float(row["retry_backoff"])) + conn.execute( + """ + UPDATE tasks + SET state = 'ready', updated_at = ?, next_run_at = ?, + lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, + last_error = ? + WHERE task_id = ? 
+ """, + (now, now + backoff, error, task_id), + ) + conn.execute( + "UPDATE workers SET inflight = MAX(inflight - 1, 0) WHERE worker_id = ?", + (worker_id,), + ) + self._journal( + conn, + "task_nacked", + { + "task_id": task_id, + "worker_id": worker_id, + "lease_id": lease_id, + "error": error, + "requeued": attempts <= max_retries, + }, + ) + conn.execute("COMMIT") + return True + + # ── reaper / recovery ───────────────────────────────────────────────────── + + def reap_expired_leases(self) -> int: + """ + Find leases past their deadline and requeue or permanently fail them. + + :return: number of tasks reaped. + """ + now = time.time() + with self._conn() as conn: + expired = list( + conn.execute( + """ + SELECT * FROM tasks + WHERE state = 'leased' + AND lease_until IS NOT NULL + AND lease_until < ? + """, + (now,), + ), + ) + if not expired: + return 0 + conn.execute("BEGIN") + count = 0 + for row in expired: + attempts = int(row["attempts"]) + max_retries = int(row["max_retries"]) + worker_id = row["leased_worker_id"] + if attempts > max_retries: + conn.execute( + """ + UPDATE tasks + SET state = 'failed', updated_at = ?, + lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, + last_error = 'lease expired' + WHERE task_id = ? + """, + (now, row["task_id"]), + ) + else: + backoff = self._backoff(attempts, float(row["retry_backoff"])) + conn.execute( + """ + UPDATE tasks + SET state = 'ready', updated_at = ?, next_run_at = ?, + lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, + last_error = 'lease expired' + WHERE task_id = ? 
+ """, + (now, now + backoff, row["task_id"]), + ) + if worker_id: + conn.execute( + "UPDATE workers SET inflight = MAX(inflight - 1, 0) WHERE worker_id = ?", + (worker_id,), + ) + self._journal( + conn, + "lease_reaped", + { + "task_id": row["task_id"], + "worker_id": worker_id, + "lease_id": row["lease_id"], + }, + ) + count += 1 + conn.execute("COMMIT") + return count + + def recover_dead_workers(self, heartbeat_timeout: float) -> int: + """ + Mark workers that missed their heartbeat deadline as DEAD. + + All tasks leased to dead workers are requeued (or permanently failed + if retries are exhausted). + + :param heartbeat_timeout: seconds of silence before a worker is dead. + :return: number of tasks requeued. + """ + now = time.time() + cutoff = now - heartbeat_timeout + with self._conn() as conn: + dead = list( + conn.execute( + "SELECT * FROM workers WHERE last_seen < ? AND status != 'dead'", + (cutoff,), + ), + ) + if not dead: + return 0 + conn.execute("BEGIN") + requeued = 0 + for worker in dead: + worker_id = worker["worker_id"] + conn.execute( + "UPDATE workers SET status = 'dead', draining = 1 WHERE worker_id = ?", + (worker_id,), + ) + leased = list( + conn.execute( + "SELECT * FROM tasks WHERE state = 'leased' AND leased_worker_id = ?", + (worker_id,), + ), + ) + for row in leased: + attempts = int(row["attempts"]) + max_retries = int(row["max_retries"]) + if attempts > max_retries: + conn.execute( + """ + UPDATE tasks + SET state = 'failed', updated_at = ?, + lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, + last_error = 'worker died' + WHERE task_id = ? + """, + (now, row["task_id"]), + ) + else: + backoff = self._backoff( + attempts, float(row["retry_backoff"]) + ) + conn.execute( + """ + UPDATE tasks + SET state = 'ready', updated_at = ?, next_run_at = ?, + lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, + last_error = 'worker died' + WHERE task_id = ? 
+ """, + (now, now + backoff, row["task_id"]), + ) + self._journal( + conn, + "worker_dead_requeue", + {"worker_id": worker_id, "task_id": row["task_id"]}, + ) + requeued += 1 + conn.execute("COMMIT") + return requeued + + # ── worker lifecycle ────────────────────────────────────────────────────── + + def register_worker(self, worker: WorkerState) -> None: + """ + Upsert a worker record. + + Re-registering an existing worker (e.g. after hub restart) resets + its draining flag and updates its metadata. + + :param worker: worker state snapshot from the registration message. + """ + now = time.time() + with self._conn() as conn: + conn.execute("BEGIN") + conn.execute( + """ + INSERT INTO workers ( + worker_id, task_addr, capacity, inflight, last_seen, + heartbeat_interval, lease_timeout, draining, status, version + ) VALUES (?, ?, ?, 0, ?, ?, ?, 0, ?, ?) + ON CONFLICT(worker_id) DO UPDATE SET + task_addr = excluded.task_addr, + capacity = excluded.capacity, + last_seen = excluded.last_seen, + heartbeat_interval = excluded.heartbeat_interval, + lease_timeout = excluded.lease_timeout, + draining = 0, + status = excluded.status, + version = excluded.version + """, + ( + worker.worker_id, + worker.task_addr, + worker.capacity, + now, + worker.heartbeat_interval, + worker.lease_timeout, + str(WorkerStatus.LISTENING), + worker.version, + ), + ) + self._journal( + conn, + "worker_register", + {"worker_id": worker.worker_id, "task_addr": worker.task_addr}, + ) + conn.execute("COMMIT") + + def heartbeat(self, worker_id: str) -> None: + """ + Record a heartbeat from a worker, resetting its last_seen timestamp. + + :param worker_id: ID of the worker sending the heartbeat. + """ + with self._conn() as conn: + conn.execute("BEGIN") + conn.execute( + "UPDATE workers SET last_seen = ?, status = ? 
WHERE worker_id = ?", + (time.time(), str(WorkerStatus.LISTENING), worker_id), + ) + self._journal(conn, "heartbeat", {"worker_id": worker_id}) + conn.execute("COMMIT") + + def unregister_worker(self, worker_id: str) -> None: + """ + Remove a worker from the registry (graceful shutdown path). + + :param worker_id: ID of the worker unregistering. + """ + with self._conn() as conn: + conn.execute("BEGIN") + conn.execute("DELETE FROM workers WHERE worker_id = ?", (worker_id,)) + self._journal(conn, "worker_unregister", {"worker_id": worker_id}) + conn.execute("COMMIT") + + def mark_draining(self, worker_id: str) -> None: + """ + Mark a worker as draining so the hub stops dispatching new tasks to it. + + :param worker_id: ID of the worker entering drain mode. + """ + with self._conn() as conn: + conn.execute("BEGIN") + conn.execute( + "UPDATE workers SET draining = 1, status = ? WHERE worker_id = ?", + (str(WorkerStatus.DRAINING), worker_id), + ) + self._journal(conn, "worker_drain", {"worker_id": worker_id}) + conn.execute("COMMIT") + + # ── routing ─────────────────────────────────────────────────────────────── + + def choose_worker( + self, + routing_policy: str = "least_loaded", + *, + heartbeat_timeout: float = 15.0, + ) -> sqlite3.Row | None: + """ + Select the best available worker according to ``routing_policy``. + + ``'least_loaded'`` picks the worker with the lowest ``inflight / + capacity`` ratio. ``'p2c'`` (Power-of-Two-Choices) samples two + workers at random and picks the less loaded one, reducing hot-spot + probability under high concurrency. + + :param routing_policy: ``'least_loaded'`` or ``'p2c'``. + :param heartbeat_timeout: seconds before a worker is considered stale. + :return: chosen worker row, or None if no worker has capacity. 
+ """ + cutoff = time.time() - heartbeat_timeout + with self._conn() as conn: + rows = list( + conn.execute( + """ + SELECT * FROM workers + WHERE status IN ('starting', 'listening') + AND draining = 0 + AND last_seen >= ? + """, + (cutoff,), + ), + ) + available = [ + w for w in rows if int(w["inflight"]) < int(w["capacity"]) + ] + if not available: + return None + if routing_policy == "p2c" and len(available) >= 2: + a, b = random.sample(available, k=2) # noqa: S311 + load_a = int(a["inflight"]) / max(int(a["capacity"]), 1) + load_b = int(b["inflight"]) / max(int(b["capacity"]), 1) + return a if load_a <= load_b else b + return min( + available, + key=lambda w: int(w["inflight"]) / max(int(w["capacity"]), 1), + ) + + # ── management / observability ──────────────────────────────────────────── + + def get_task(self, task_id: str) -> sqlite3.Row | None: + """ + Fetch a single task row by ID. + + :param task_id: ID of the task to look up. + :return: row or None if not found. + """ + with self._conn() as conn: + return conn.execute( + "SELECT * FROM tasks WHERE task_id = ?", (task_id,) + ).fetchone() + + def list_workers(self) -> list[sqlite3.Row]: + """Return all registered workers ordered by most-recently-seen.""" + with self._conn() as conn: + return list( + conn.execute("SELECT * FROM workers ORDER BY last_seen DESC"), + ) + + def stats(self) -> dict[str, int]: + """Return a summary dict with task state counts and active worker count.""" + with self._conn() as conn: + rows = conn.execute( + "SELECT state, COUNT(*) AS n FROM tasks GROUP BY state", + ).fetchall() + counts = {r["state"]: r["n"] for r in rows} + worker_count = conn.execute( + """ + SELECT COUNT(*) FROM workers + WHERE status IN ('starting', 'listening') AND draining = 0 + """, + ).fetchone()[0] + return { + "ready": counts.get("ready", 0), + "leased": counts.get("leased", 0), + "done": counts.get("done", 0), + "failed": counts.get("failed", 0), + "active_workers": worker_count, + } diff --git 
a/taskiq/brokers/nng_broker.py b/taskiq/brokers/nng_broker.py
deleted file mode 100644
index 15ab3aaa..00000000
--- a/taskiq/brokers/nng_broker.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from collections.abc import AsyncGenerator
-
-import pynng
-
-from taskiq.abc.broker import AsyncBroker
-from taskiq.message import BrokerMessage
-
-
-class NNGBroker(AsyncBroker):
-    """
-    NanoMSG next generation broker.
-
-    This broker is very much alike to the ZMQ broker,
-    It has a similar Idea, but slightly different
-    implementation.
-    """
-
-    def __init__(self, addr: str) -> None:
-        """
-        Initialize the broker.
-
-        :param addr: address which is used by both worker and client.
-        """
-        super().__init__()
-        self.socket = pynng.Pair1(polyamorous=True)
-        self.addr = addr
-
-    async def startup(self) -> None:
-        """Start the socket."""
-        await super().startup()
-        if self.is_worker_process:
-            self.socket.listen(self.addr)
-        else:
-            self.socket.dial(self.addr, block=True)
-
-    async def shutdown(self) -> None:
-        """Close the socket."""
-        await super().shutdown()
-        self.socket.close()
-
-    async def kick(self, message: BrokerMessage) -> None:
-        """Send a message."""
-        await self.socket.asend(message.message)
-
-    async def listen(self) -> AsyncGenerator[bytes, None]:
-        """Infinite loop that receives messages."""
-        while True:
-            yield await self.socket.arecv()
diff --git a/tests/brokers/test_nng_broker.py b/tests/brokers/test_nng_broker.py
new file mode 100644
index 00000000..f128e757
--- /dev/null
+++ b/tests/brokers/test_nng_broker.py
@@ -0,0 +1,576 @@
+"""
+Tests for the NNG broker, hub, storage, and protocol.
+
+The test suite is split into three layers:
+
+1. **Protocol** — pure serialisation roundtrips; no NNG sockets needed.
+2. **Storage** — SQLiteJournal unit tests; no NNG sockets needed.
+3. **Integration** — real NNG sockets, real SQLite, single asyncio event loop.
+ Uses ``FakeWorker`` / ``FakeClient`` helpers that speak the wire protocol + directly so we can inject faults precisely (crash before ack, late ack, etc.). + +All NNG tests are skipped when ``pynng`` is not installed. +""" +from __future__ import annotations + +import asyncio +import os +import sqlite3 +import tempfile +import time +import uuid + +import pytest + +pynng = pytest.importorskip("pynng") + +from taskiq.brokers.nng import ( + HubConfig, + NNGHub, + ControlMessage, + ControlResponse, + MessageKind, + TaskEnvelope, + WorkerState, + WorkerStatus, + QueueFullError, + SQLiteJournal, + StoreConfig, +) + + +# ── helpers ─────────────────────────────────────────────────────────────────── + + +def _ipc(tag: str = "") -> str: + name = f"nng-test-{tag}-{uuid.uuid4().hex[:8]}.ipc" + return f"ipc://{os.path.join(tempfile.gettempdir(), name)}" + + +def _envelope(**kwargs: object) -> TaskEnvelope: + defaults: dict[str, object] = { + "task_id": uuid.uuid4().hex, + "task_name": "tests:task", + "payload_b64": "dGVzdA==", + "labels": {}, + "lease_id": "", + "attempts": 0, + "max_retries": 0, + "retry_backoff": 1.0, + "retry_jitter": 0.0, + "priority": 0, + "created_at": time.time(), + } + defaults.update(kwargs) + return TaskEnvelope(**defaults) # type: ignore[arg-type] + + +def _worker_state( + worker_id: str | None = None, + task_addr: str | None = None, + capacity: int = 2, +) -> WorkerState: + wid = worker_id or uuid.uuid4().hex + return WorkerState( + worker_id=wid, + task_addr=task_addr or f"ipc:///tmp/{wid}.ipc", + capacity=capacity, + heartbeat_interval=5.0, + lease_timeout=10.0, + ) + + +def _hub(control_addr: str, db_path: str, **kwargs: object) -> NNGHub: + cfg = HubConfig( + control_addr=control_addr, + task_db=db_path, + max_pending=100, + heartbeat_timeout=2.0, + lease_timeout=2.0, + dispatch_interval=0.02, + reaper_interval=0.1, + control_concurrency=4, + **kwargs, # type: ignore[arg-type] + ) + return NNGHub(cfg) + + +@pytest.fixture +def 
db_path(tmp_path: object) -> str: + import pathlib + return str(pathlib.Path(str(tmp_path)) / "hub.db") # type: ignore[arg-type] + + +@pytest.fixture +def ctrl_addr() -> str: + return _ipc("ctrl") + + +class FakeWorker: + """Minimal NNG worker that speaks the control + task protocol.""" + + def __init__( + self, + control_addr: str, + task_addr: str | None = None, + capacity: int = 1, + ) -> None: + self.worker_id = uuid.uuid4().hex[:8] + self.task_addr = task_addr or _ipc("worker") + self._ctrl = pynng.Req0( + dial=control_addr, recv_timeout=3000, send_timeout=3000 + ) + self._pull = pynng.Pull0(listen=self.task_addr, recv_timeout=3000) + self._lock = asyncio.Lock() + self.capacity = capacity + + async def ctrl(self, kind: str, payload: dict[str, object]) -> ControlResponse: + async with self._lock: + await self._ctrl.asend( + ControlMessage(kind=kind, payload=payload).to_bytes() + ) + raw = await self._ctrl.arecv() + return ControlResponse.from_bytes(raw) + + async def register(self) -> None: + resp = await self.ctrl( + "register", + { + "worker_id": self.worker_id, + "task_addr": self.task_addr, + "capacity": self.capacity, + "inflight": 0, + "last_seen": time.time(), + "heartbeat_interval": 1.0, + "lease_timeout": 2.0, + "draining": False, + "status": str(WorkerStatus.STARTING), + "version": "test", + }, + ) + assert resp.ok, f"register failed: {resp.error}" + + async def recv_task(self, timeout: float = 3.0) -> TaskEnvelope: + raw = await asyncio.wait_for(self._pull.arecv(), timeout=timeout) + return TaskEnvelope.from_bytes(raw) + + async def ack(self, task_id: str, lease_id: str) -> bool: + resp = await self.ctrl( + "ack", + { + "task_id": task_id, + "worker_id": self.worker_id, + "lease_id": lease_id, + }, + ) + return resp.ok + + async def heartbeat(self) -> None: + await self.ctrl("heartbeat", {"worker_id": self.worker_id}) + + async def drain_and_unregister(self) -> None: + await self.ctrl("drain", {"worker_id": self.worker_id}) + await 
self.ctrl("unregister", {"worker_id": self.worker_id}) + + def close(self) -> None: + self._ctrl.close() + self._pull.close() + + +class FakeClient: + """Minimal NNG client that can submit tasks and query hub status.""" + + def __init__(self, control_addr: str) -> None: + self._ctrl = pynng.Req0( + dial=control_addr, recv_timeout=3000, send_timeout=3000 + ) + self._lock = asyncio.Lock() + + async def submit(self, **labels: object) -> str: + tid = uuid.uuid4().hex + payload: dict[str, object] = { + "task_id": tid, + "task_name": "tests:task", + "payload_b64": "dGVzdA==", + "labels": {}, + "lease_id": "", + "attempts": 0, + "max_retries": labels.pop("max_retries", 0), + "retry_backoff": labels.pop("retry_backoff", 1.0), + "retry_jitter": 0.0, + "priority": labels.pop("priority", 0), + "created_at": time.time(), + } + async with self._lock: + await self._ctrl.asend( + ControlMessage(kind="submit", payload=payload).to_bytes() + ) + raw = await self._ctrl.arecv() + resp = ControlResponse.from_bytes(raw) + assert resp.ok, f"submit failed: {resp.error}" + return tid + + async def ping(self) -> bool: + async with self._lock: + await self._ctrl.asend( + ControlMessage(kind="ping", payload={}).to_bytes() + ) + raw = await self._ctrl.arecv() + return ControlResponse.from_bytes(raw).ok + + def close(self) -> None: + self._ctrl.close() + + +# ── 1. 
Protocol tests ───────────────────────────────────────────────────────── + + +def test_control_message_roundtrip() -> None: + msg = ControlMessage(kind=MessageKind.HEARTBEAT, payload={"worker_id": "w1"}) + assert ControlMessage.from_bytes(msg.to_bytes()) == msg + + +def test_control_response_roundtrip() -> None: + resp = ControlResponse(ok=True, payload={"task_id": "abc"}, error=None) + assert ControlResponse.from_bytes(resp.to_bytes()) == resp + + +def test_task_envelope_lease_id_preserved() -> None: + """Regression: v2 omitted lease_id from the envelope, breaking ack validation.""" + env = TaskEnvelope( + task_id="x", task_name="m:f", payload_b64="YQ==", lease_id="abc123" + ) + rt = TaskEnvelope.from_bytes(env.to_bytes()) + assert rt.lease_id == "abc123" + + +def test_task_envelope_payload_decode() -> None: + env = _envelope(payload_b64="dGVzdA==") + assert env.payload == b"test" + + +# ── 2. Storage tests ────────────────────────────────────────────────────────── + + +@pytest.fixture +def store(db_path: str) -> SQLiteJournal: + return SQLiteJournal(StoreConfig(path=db_path, max_pending=50, lease_timeout=5.0)) + + +def test_submit_and_pending(store: SQLiteJournal) -> None: + store.submit(_envelope()) + assert store.pending_count() == 1 + + +def test_submit_queue_full(db_path: str) -> None: + s = SQLiteJournal(StoreConfig(path=db_path, max_pending=2)) + s.submit(_envelope()) + s.submit(_envelope()) + with pytest.raises(QueueFullError): + s.submit(_envelope()) + + +def test_due_tasks_ordered_by_priority(store: SQLiteJournal) -> None: + store.submit(_envelope(task_id="lo", priority=0)) + store.submit(_envelope(task_id="hi", priority=10)) + due = store.due_tasks(limit=10) + assert due[0]["task_id"] == "hi" + assert due[1]["task_id"] == "lo" + + +def test_ack_happy_path(store: SQLiteJournal) -> None: + env = _envelope() + store.submit(env) + w = _worker_state() + store.register_worker(w) + assert store.mark_leased(env.task_id, w.worker_id, "L1", time.time() + 60) + 
assert store.ack(env.task_id, w.worker_id, "L1") + assert store.get_task(env.task_id)["state"] == "done" + + +def test_ack_wrong_lease_rejected(store: SQLiteJournal) -> None: + env = _envelope() + store.submit(env) + w = _worker_state() + store.register_worker(w) + store.mark_leased(env.task_id, w.worker_id, "real", time.time() + 60) + assert not store.ack(env.task_id, w.worker_id, "wrong") + + +def test_late_ack_after_requeue_ignored(store: SQLiteJournal) -> None: + env = _envelope() + store.submit(env) + w = _worker_state() + store.register_worker(w) + store.mark_leased(env.task_id, w.worker_id, "L2", time.time() - 1) + assert store.reap_expired_leases() == 1 + assert not store.ack(env.task_id, w.worker_id, "L2") + + +def test_nack_requeues_with_backoff(store: SQLiteJournal) -> None: + env = _envelope(max_retries=2, retry_backoff=1.0) + store.submit(env) + w = _worker_state() + store.register_worker(w) + store.mark_leased(env.task_id, w.worker_id, "L3", time.time() + 60) + assert store.nack(env.task_id, w.worker_id, "L3", "boom") + task = store.get_task(env.task_id) + assert task["state"] == "ready" + assert float(task["next_run_at"]) > time.time() + + +def test_nack_exceeds_retries_fails(store: SQLiteJournal) -> None: + env = _envelope(max_retries=0) + store.submit(env) + w = _worker_state() + store.register_worker(w) + store.mark_leased(env.task_id, w.worker_id, "L4", time.time() + 60) + store.nack(env.task_id, w.worker_id, "L4", "error") + assert store.get_task(env.task_id)["state"] == "failed" + + +def test_dead_worker_tasks_requeued(store: SQLiteJournal, db_path: str) -> None: + w = _worker_state() + store.register_worker(w) + env = _envelope(max_retries=3) + store.submit(env) + store.mark_leased(env.task_id, w.worker_id, "L5", time.time() + 60) + conn = sqlite3.connect(db_path) + conn.execute("UPDATE workers SET last_seen=0 WHERE worker_id=?", (w.worker_id,)) + conn.commit() + conn.close() + assert store.recover_dead_workers(heartbeat_timeout=1.0) == 1 + 
assert store.get_task(env.task_id)["state"] == "ready" + + +def test_choose_worker_least_loaded(store: SQLiteJournal, db_path: str) -> None: + w1 = _worker_state(worker_id="w1", capacity=4) + w2 = _worker_state(worker_id="w2", capacity=4) + store.register_worker(w1) + store.register_worker(w2) + conn = sqlite3.connect(db_path) + conn.execute("UPDATE workers SET inflight=3 WHERE worker_id='w1'") + conn.commit() + conn.close() + chosen = store.choose_worker("least_loaded", heartbeat_timeout=30.0) + assert chosen is not None + assert chosen["worker_id"] == "w2" + + +def test_stats(store: SQLiteJournal) -> None: + w = _worker_state() + store.register_worker(w) + store.submit(_envelope()) + s = store.stats() + assert s["ready"] == 1 + assert s["active_workers"] == 1 + + +# ── 3. Integration tests ────────────────────────────────────────────────────── + + +async def test_ping(ctrl_addr: str, db_path: str) -> None: + hub = _hub(ctrl_addr, db_path) + await hub.start() + client = FakeClient(ctrl_addr) + try: + assert await client.ping() + finally: + client.close() + await hub.stop() + + +async def test_submit_dispatch_ack(ctrl_addr: str, db_path: str) -> None: + """Golden path: one task, one worker, full round-trip.""" + hub = _hub(ctrl_addr, db_path) + await hub.start() + worker = FakeWorker(ctrl_addr, capacity=1) + client = FakeClient(ctrl_addr) + try: + await worker.register() + tid = await client.submit() + env = await worker.recv_task(timeout=3.0) + assert env.task_id == tid + assert env.lease_id != "", "Hub must populate lease_id in envelope" + assert await worker.ack(env.task_id, env.lease_id) + assert hub.store.get_task(tid)["state"] == "done" + finally: + worker.close() + client.close() + await hub.stop() + + +async def test_multiple_workers_load_balanced(ctrl_addr: str, db_path: str) -> None: + """Both workers must receive at least one task — no single hot-spot.""" + hub = _hub(ctrl_addr, db_path) + await hub.start() + w1 = FakeWorker(ctrl_addr, capacity=4) + w2 = 
FakeWorker(ctrl_addr, capacity=4) + client = FakeClient(ctrl_addr) + try: + await w1.register() + await w2.register() + task_ids = [await client.submit() for _ in range(6)] + received: dict[str, list[str]] = {w1.worker_id: [], w2.worker_id: []} + pending = set(task_ids) + + async def drain(w: FakeWorker) -> None: + while pending: + try: + env = await w.recv_task(timeout=0.5) + received[w.worker_id].append(env.task_id) + pending.discard(env.task_id) + await w.ack(env.task_id, env.lease_id) + except asyncio.TimeoutError: + break + + await asyncio.gather(drain(w1), drain(w2)) + assert not pending, f"Tasks not delivered: {pending}" + assert len(received[w1.worker_id]) > 0 + assert len(received[w2.worker_id]) > 0 + finally: + w1.close() + w2.close() + client.close() + await hub.stop() + + +async def test_worker_crash_before_ack_task_requeued( + ctrl_addr: str, db_path: str +) -> None: + """ + Worker receives a task but dies before acking. + After lease expiry the hub must requeue it for a second worker. + """ + hub = _hub(ctrl_addr, db_path) + await hub.start() + w1 = FakeWorker(ctrl_addr, capacity=1) + client = FakeClient(ctrl_addr) + try: + await w1.register() + tid = await client.submit(max_retries=3) + env1 = await w1.recv_task(timeout=3.0) + assert env1.task_id == tid + w1.close() # simulate crash without acking + + await asyncio.sleep(3.5) # lease_timeout=2s + reaper_interval=0.1s + + assert hub.store.get_task(tid)["state"] == "ready" + + w2 = FakeWorker(ctrl_addr, capacity=1) + try: + await w2.register() + env2 = await w2.recv_task(timeout=3.0) + assert env2.task_id == tid + assert env2.lease_id != env1.lease_id + assert await w2.ack(env2.task_id, env2.lease_id) + assert hub.store.get_task(tid)["state"] == "done" + finally: + w2.close() + finally: + client.close() + await hub.stop() + + +async def test_late_ack_after_requeue_rejected( + ctrl_addr: str, db_path: str +) -> None: + """ + Sequence: dispatch to w1 → lease expires → requeue → dispatch to w2. 
+ w1's late ack must be rejected; w2's ack must succeed. + """ + hub = _hub(ctrl_addr, db_path) + await hub.start() + w1 = FakeWorker(ctrl_addr, capacity=1) + client = FakeClient(ctrl_addr) + try: + await w1.register() + tid = await client.submit(max_retries=3) + env1 = await w1.recv_task(timeout=3.0) + await asyncio.sleep(3.5) # let lease expire + + w2 = FakeWorker(ctrl_addr, capacity=1) + try: + await w2.register() + env2 = await w2.recv_task(timeout=3.0) + + # w1's stale ack must be rejected + assert not await w1.ack(env1.task_id, env1.lease_id) + # w2's valid ack succeeds + assert await w2.ack(env2.task_id, env2.lease_id) + assert hub.store.get_task(tid)["state"] == "done" + finally: + w2.close() + finally: + w1.close() + client.close() + await hub.stop() + + +async def test_hub_restart_recovers_orphaned_tasks( + ctrl_addr: str, db_path: str +) -> None: + """ + Tasks leased at hub shutdown must be requeued when a new hub starts + with the same database. + """ + hub1 = _hub(ctrl_addr, db_path) + await hub1.start() + w1 = FakeWorker(ctrl_addr, capacity=1) + client = FakeClient(ctrl_addr) + await w1.register() + tid = await client.submit(max_retries=3) + env = await w1.recv_task(timeout=3.0) + assert env.task_id == tid + # "kill" hub1 without giving worker a chance to ack + await hub1.stop() + w1.close() + client.close() + + # Task is still leased in the DB + assert hub1.store.get_task(tid)["state"] == "leased" + + hub2 = _hub(ctrl_addr, db_path) + await hub2.start() + await asyncio.sleep(0.3) # allow startup recovery + try: + assert hub2.store.get_task(tid)["state"] == "ready" + finally: + await hub2.stop() + + +async def test_concurrent_heartbeats(ctrl_addr: str, db_path: str) -> None: + """ + N workers heartbeat simultaneously. With concurrent Rep0 contexts all + must succeed without serialisation stalls. 
+ """ + hub = _hub(ctrl_addr, db_path) + await hub.start() + workers = [FakeWorker(ctrl_addr, capacity=2) for _ in range(8)] + try: + await asyncio.gather(*[w.register() for w in workers]) + results = await asyncio.gather( + *[w.heartbeat() for w in workers], + return_exceptions=True, + ) + errors = [r for r in results if isinstance(r, Exception)] + assert not errors, f"Concurrent heartbeats failed: {errors}" + finally: + for w in workers: + w.close() + await hub.stop() + + +async def test_graceful_drain_and_unregister(ctrl_addr: str, db_path: str) -> None: + hub = _hub(ctrl_addr, db_path) + await hub.start() + worker = FakeWorker(ctrl_addr, capacity=2) + try: + await worker.register() + assert len(hub.store.list_workers()) == 1 + await worker.drain_and_unregister() + await asyncio.sleep(0.1) + assert len(hub.store.list_workers()) == 0 + finally: + worker.close() + await hub.stop() From 7fdb8eb19cd6085419556d850466214eb4799855 Mon Sep 17 00:00:00 2001 From: Alexandr Tedeev Date: Sun, 26 Apr 2026 17:10:02 +0300 Subject: [PATCH 2/4] Refactoring the NNG support solution v1.5: Simplify store --- taskiq/brokers/nng/__init__.py | 29 +- taskiq/brokers/nng/broker.py | 2 +- taskiq/brokers/nng/hub.py | 107 ++-- taskiq/brokers/nng/storage.py | 822 +++++++++++-------------------- tests/brokers/test_nng_broker.py | 74 +-- 5 files changed, 349 insertions(+), 685 deletions(-) diff --git a/taskiq/brokers/nng/__init__.py b/taskiq/brokers/nng/__init__.py index 0d0a2946..8e2a7f4a 100644 --- a/taskiq/brokers/nng/__init__.py +++ b/taskiq/brokers/nng/__init__.py @@ -1,5 +1,6 @@ -from hub import HubConfig, NNGHub -from protocol import ( +"""NNG broker package for taskiq.""" +from .hub import HubConfig, NNGHub +from .protocol import ( ControlMessage, ControlResponse, MessageKind, @@ -7,18 +8,18 @@ WorkerState, WorkerStatus, ) -from storage import QueueFullError, SQLiteJournal, StoreConfig +from .storage import InMemoryStore, QueueFullError, StoreConfig __all__ = [ - 'HubConfig', - 
'NNGHub', - 'ControlMessage', - 'ControlResponse', - 'MessageKind', - 'TaskEnvelope', - 'WorkerState', - 'WorkerStatus', - 'QueueFullError', - 'SQLiteJournal', - 'StoreConfig', + "HubConfig", + "NNGHub", + "ControlMessage", + "ControlResponse", + "MessageKind", + "TaskEnvelope", + "WorkerState", + "WorkerStatus", + "QueueFullError", + "InMemoryStore", + "StoreConfig", ] diff --git a/taskiq/brokers/nng/broker.py b/taskiq/brokers/nng/broker.py index 6961cbeb..a6273e41 100644 --- a/taskiq/brokers/nng/broker.py +++ b/taskiq/brokers/nng/broker.py @@ -17,7 +17,7 @@ from taskiq.acks import AckableMessage from taskiq.message import BrokerMessage -from protocol import ( +from .protocol import ( ControlMessage, ControlResponse, TaskEnvelope, diff --git a/taskiq/brokers/nng/hub.py b/taskiq/brokers/nng/hub.py index 844055c5..c58857bf 100644 --- a/taskiq/brokers/nng/hub.py +++ b/taskiq/brokers/nng/hub.py @@ -3,12 +3,11 @@ Run as a standalone process:: - taskiq-nng-hub --control-addr ipc:///tmp/taskiq-nng.ipc \\ - --task-db /var/lib/taskiq/tasks.db + taskiq-nng-hub --control-addr ipc:///tmp/taskiq-nng.ipc Or embed it in an application for testing:: - hub = NNGHub(HubConfig(control_addr="ipc:///tmp/h.ipc", task_db=":memory:")) + hub = NNGHub(HubConfig(control_addr="ipc:///tmp/h.ipc")) await hub.start() ... 
await hub.stop() @@ -18,13 +17,11 @@ import argparse import asyncio import base64 -import json import logging import os import signal import time import uuid -from concurrent.futures import ThreadPoolExecutor from contextlib import suppress from dataclasses import dataclass, field from typing import Any @@ -34,13 +31,13 @@ except ImportError: pynng = None # type: ignore[assignment] -from protocol import ( +from .protocol import ( ControlMessage, ControlResponse, TaskEnvelope, WorkerState, ) -from storage import QueueFullError, SQLiteJournal, StoreConfig +from .storage import InMemoryStore, QueueFullError, StoreConfig logger = logging.getLogger(__name__) @@ -50,7 +47,7 @@ class HubConfig: """Configuration for :class:`NNGHub`.""" control_addr: str - task_db: str + task_db: str = "" # kept for API compat; ignored by in-memory store max_pending: int = 10_000 heartbeat_timeout: float = 15.0 lease_timeout: float = 20.0 @@ -77,20 +74,18 @@ class NNGHub: independent ``nng_ctx`` contexts running concurrently. Each context handles one request-reply at a time, so N workers can register/heartbeat/ack simultaneously without queuing behind each other. - This is the key fix over the single-context (serial) Rep0 in v2. **Data plane** — One ``Push0`` socket per registered worker, dialed to the worker's own ``Pull0`` listen address. The hub explicitly targets - the least-loaded worker instead of relying on NNG round-robin, giving - us load-aware routing. + the least-loaded worker instead of relying on NNG round-robin. - **Persistence** — :class:`~taskiq.brokers.nng_storage.SQLiteJournal` in - WAL mode. All storage calls are executed on a single-threaded - ``ThreadPoolExecutor`` so the asyncio event loop is never blocked and - SQLite write serialisation is guaranteed. + **State** — :class:`~taskiq.brokers.nng.storage.InMemoryStore`. All + store operations are synchronous and execute directly on the asyncio event + loop without blocking (no I/O, no syscalls). 
- **Recovery** — On startup, tasks leased to workers that died during the - previous hub session are automatically requeued. + **Recovery** — On startup, any tasks that were leased before the hub last + stopped (within the same process lifetime) are automatically requeued by + :meth:`~InMemoryStore.recover_dead_workers`. """ def __init__(self, config: HubConfig) -> None: @@ -105,9 +100,8 @@ def __init__(self, config: HubConfig) -> None: "Install it with: pip install taskiq[nng]" ) self.config = config - self.store = SQLiteJournal( + self.store = InMemoryStore( StoreConfig( - path=config.task_db, max_pending=config.max_pending, lease_timeout=config.lease_timeout, backoff_cap=config.backoff_cap, @@ -117,16 +111,12 @@ def __init__(self, config: HubConfig) -> None: self._ctrl_sock: Any = None # pynng.Rep0 self._worker_push: dict[str, Any] = {} # worker_id -> pynng.Push0 self._tasks: list[asyncio.Task[None]] = [] - # Single-threaded executor: serialises all SQLite calls on one OS thread. - self._db_exec = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="nng-db" - ) # ── lifecycle ───────────────────────────────────────────────────────────── async def start(self) -> None: """Start the hub: recover orphaned tasks, open sockets, spawn loops.""" - await self._db(self.store.recover_dead_workers, self.config.heartbeat_timeout) + self.store.recover_dead_workers(self.config.heartbeat_timeout) self._ctrl_sock = pynng.Rep0(listen=self.config.control_addr) self._ctrl_sock.recv_timeout = self.config.recv_timeout_ms @@ -142,11 +132,7 @@ async def start(self) -> None: self._control_handler(ctx), name=f"hub-ctrl-{i}" ), ) - logger.info( - "NNG hub started on %s (db=%s)", - self.config.control_addr, - self.config.task_db, - ) + logger.info("NNG hub started on %s", self.config.control_addr) async def stop(self) -> None: """Gracefully stop all hub loops and close sockets.""" @@ -163,17 +149,8 @@ async def stop(self) -> None: if self._ctrl_sock is not None: with 
suppress(Exception): self._ctrl_sock.close() - self._db_exec.shutdown(wait=True) logger.info("NNG hub stopped") - # ── DB helper ───────────────────────────────────────────────────────────── - - async def _db(self, fn: Any, *args: Any, **kwargs: Any) -> Any: - loop = asyncio.get_running_loop() - return await loop.run_in_executor( - self._db_exec, lambda: fn(*args, **kwargs) - ) - # ── control plane ───────────────────────────────────────────────────────── async def _control_handler(self, ctx: Any) -> None: @@ -216,19 +193,18 @@ async def _handle(self, raw: bytes) -> ControlResponse: # noqa: PLR0911, C901 return await self._handle_register(msg.payload) if msg.kind == "heartbeat": - await self._db(self.store.heartbeat, msg.payload["worker_id"]) + self.store.heartbeat(msg.payload["worker_id"]) return ControlResponse(ok=True, payload={"ok": True}) if msg.kind == "unregister": return await self._handle_unregister(msg.payload["worker_id"]) if msg.kind == "drain": - await self._db(self.store.mark_draining, msg.payload["worker_id"]) + self.store.mark_draining(msg.payload["worker_id"]) return ControlResponse(ok=True, payload={"draining": True}) if msg.kind == "ack": - ok = await self._db( - self.store.ack, + ok = self.store.ack( msg.payload["task_id"], msg.payload["worker_id"], msg.payload["lease_id"], @@ -236,8 +212,7 @@ async def _handle(self, raw: bytes) -> ControlResponse: # noqa: PLR0911, C901 return ControlResponse(ok=ok, payload={"acked": ok}) if msg.kind == "nack": - ok = await self._db( - self.store.nack, + ok = self.store.nack( msg.payload["task_id"], msg.payload["worker_id"], msg.payload["lease_id"], @@ -246,26 +221,25 @@ async def _handle(self, raw: bytes) -> ControlResponse: # noqa: PLR0911, C901 return ControlResponse(ok=ok, payload={"nacked": ok}) if msg.kind == "status": - task = await self._db(self.store.get_task, msg.payload["task_id"]) - return ControlResponse(ok=bool(task), payload=dict(task) if task else {}) + task = 
self.store.get_task(msg.payload["task_id"]) + return ControlResponse(ok=bool(task), payload=task or {}) if msg.kind == "stats": - s = await self._db(self.store.stats) - return ControlResponse(ok=True, payload=s) + return ControlResponse(ok=True, payload=self.store.stats()) return ControlResponse(ok=False, error=f"unknown kind: {msg.kind!r}") async def _handle_submit(self, payload: dict[str, Any]) -> ControlResponse: envelope = TaskEnvelope(**payload) try: - await self._db(self.store.submit, envelope) + self.store.submit(envelope) return ControlResponse(ok=True, payload={"task_id": envelope.task_id}) except QueueFullError: return ControlResponse(ok=False, error="queue full") async def _handle_register(self, payload: dict[str, Any]) -> ControlResponse: worker = WorkerState(**payload) - await self._db(self.store.register_worker, worker) + self.store.register_worker(worker) if worker.worker_id not in self._worker_push: try: sock = pynng.Push0(dial=worker.task_addr) @@ -279,7 +253,7 @@ async def _handle_register(self, payload: dict[str, Any]) -> ControlResponse: return ControlResponse(ok=True, payload={"registered": True}) async def _handle_unregister(self, worker_id: str) -> ControlResponse: - await self._db(self.store.unregister_worker, worker_id) + self.store.unregister_worker(worker_id) sock = self._worker_push.pop(worker_id, None) if sock is not None: with suppress(Exception): @@ -302,13 +276,12 @@ async def _dispatch_loop(self) -> None: async def _dispatch_once(self) -> bool: """Dispatch up to ``dispatch_batch`` due tasks to available workers.""" - due = await self._db(self.store.due_tasks, self.config.dispatch_batch) + due = self.store.due_tasks(self.config.dispatch_batch) if not due: return False sent_any = False for row in due: - worker = await self._db( - self.store.choose_worker, + worker = self.store.choose_worker( self.config.routing_policy, heartbeat_timeout=self.config.heartbeat_timeout, ) @@ -319,8 +292,7 @@ async def _dispatch_once(self) -> bool: 
lease_id = uuid.uuid4().hex lease_until = time.time() + self.config.lease_timeout - if not await self._db( - self.store.mark_leased, + if not self.store.mark_leased( row["task_id"], worker_id, lease_id, lease_until, ): continue # concurrent dispatch race; task already taken @@ -331,19 +303,14 @@ async def _dispatch_once(self) -> bool: "No push socket for worker %s, requeueing %s", worker_id, row["task_id"], ) - await self._db( - self.store.nack, - row["task_id"], worker_id, lease_id, "no socket", - ) + self.store.nack(row["task_id"], worker_id, lease_id, "no socket") continue - # Include the hub-generated lease_id so the worker can ack with - # the exact token. Omitting it was the core correctness bug in v2. envelope = TaskEnvelope( task_id=row["task_id"], task_name=row["task_name"], payload_b64=base64.b64encode(row["payload"]).decode("ascii"), - labels=json.loads(row["labels_json"]), + labels=row["labels"], lease_id=lease_id, attempts=int(row["attempts"]) + 1, max_retries=int(row["max_retries"]), @@ -360,8 +327,7 @@ async def _dispatch_once(self) -> bool: "Failed to deliver %s to worker %s: %s", row["task_id"], worker_id, exc, ) - await self._db( - self.store.nack, + self.store.nack( row["task_id"], worker_id, lease_id, f"dispatch send failed: {exc}", ) @@ -373,11 +339,10 @@ async def _reaper_loop(self) -> None: while not self._stop.is_set(): try: await asyncio.sleep(self.config.reaper_interval) - reaped = await self._db(self.store.reap_expired_leases) + reaped = self.store.reap_expired_leases() if reaped: logger.debug("Reaped %d expired leases", reaped) - recovered = await self._db( - self.store.recover_dead_workers, + recovered = self.store.recover_dead_workers( self.config.heartbeat_timeout, ) if recovered: @@ -400,11 +365,6 @@ def _build_config() -> HubConfig: default=os.getenv("NNG_CONTROL_ADDR", "ipc:///tmp/taskiq-nng.ipc"), help="NNG address the hub listens on. 
Env: NNG_CONTROL_ADDR", ) - p.add_argument( - "--task-db", - default=os.getenv("NNG_TASK_DB", "/tmp/taskiq-nng-tasks.db"), # noqa: S108 - help="Path to the SQLite WAL task journal. Env: NNG_TASK_DB", - ) p.add_argument( "--max-pending", type=int, @@ -445,7 +405,6 @@ def _build_config() -> HubConfig: ) return HubConfig( control_addr=args.control_addr, - task_db=args.task_db, max_pending=args.max_pending, heartbeat_timeout=args.heartbeat_timeout, lease_timeout=args.lease_timeout, diff --git a/taskiq/brokers/nng/storage.py b/taskiq/brokers/nng/storage.py index 410a7064..87970adf 100644 --- a/taskiq/brokers/nng/storage.py +++ b/taskiq/brokers/nng/storage.py @@ -1,24 +1,20 @@ -"""Durable WAL-mode SQLite task journal for the NNG hub.""" +"""Pure in-memory task store for the NNG hub — no external dependencies.""" from __future__ import annotations -import json import random -import sqlite3 -import threading import time -from contextlib import contextmanager -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Generator +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any -from protocol import TaskEnvelope, WorkerState, WorkerStatus +if TYPE_CHECKING: + from .protocol import TaskEnvelope, WorkerState @dataclass class StoreConfig: - """Configuration for the SQLite task journal.""" + """Configuration for :class:`InMemoryStore`.""" - path: str + path: str = "" # kept for API compat; not used max_pending: int = 10_000 lease_timeout: float = 30.0 backoff_base: float = 1.0 @@ -29,203 +25,172 @@ class QueueFullError(RuntimeError): """Raised when a submission is attempted on a full queue.""" -class SQLiteJournal: +@dataclass +class _Task: + task_id: str + task_name: str + payload: bytes + labels: dict[str, Any] + state: str # ready / leased / done / failed + attempts: int = 0 + max_retries: int = 0 + retry_backoff: float = 1.0 + retry_jitter: float = 0.0 + priority: int = 0 + created_at: float = 
field(default_factory=time.time) + updated_at: float = field(default_factory=time.time) + next_run_at: float = field(default_factory=time.time) + lease_id: str | None = None + leased_worker_id: str | None = None + lease_until: float | None = None + last_error: str | None = None + + def as_dict(self) -> dict[str, Any]: + """Return a dict view of this task record.""" + return { + "task_id": self.task_id, + "task_name": self.task_name, + "payload": self.payload, + "labels": self.labels, + "state": self.state, + "attempts": self.attempts, + "max_retries": self.max_retries, + "retry_backoff": self.retry_backoff, + "retry_jitter": self.retry_jitter, + "priority": self.priority, + "created_at": self.created_at, + "updated_at": self.updated_at, + "next_run_at": self.next_run_at, + "lease_id": self.lease_id, + "leased_worker_id": self.leased_worker_id, + "lease_until": self.lease_until, + "last_error": self.last_error, + } + + def as_status_dict(self) -> dict[str, Any]: + """Return a JSON-safe dict (no raw bytes) for control-plane status responses.""" + d = self.as_dict() + d.pop("payload", None) + return d + + +@dataclass +class _Worker: + worker_id: str + task_addr: str + capacity: int + inflight: int = 0 + last_seen: float = 0.0 + heartbeat_interval: float = 5.0 + lease_timeout: float = 15.0 + draining: bool = False + status: str = "starting" + version: str = "unknown" + + def as_dict(self) -> dict[str, Any]: + """Return a dict view of this worker record.""" + return { + "worker_id": self.worker_id, + "task_addr": self.task_addr, + "capacity": self.capacity, + "inflight": self.inflight, + "last_seen": self.last_seen, + "heartbeat_interval": self.heartbeat_interval, + "lease_timeout": self.lease_timeout, + "draining": self.draining, + "status": self.status, + "version": self.version, + } + + +class InMemoryStore: """ - Thread-safe, WAL-mode SQLite task store. - - Design notes - ──────────── - * Every method opens and closes its own connection. 
WAL allows concurrent - readers without blocking; SQLite serialises concurrent writers internally, - and the Python-level ``_submit_lock`` prevents the TOCTOU race in - :meth:`submit`. - * The hub runs every call through a single-threaded - ``ThreadPoolExecutor`` so, in practice, writes never contend at the - OS level either. - * ``PRAGMA`` settings (WAL, synchronous, busy_timeout) are applied per - connection because each ``sqlite3.connect()`` call starts with defaults. + Pure in-memory task store for the NNG hub. + + All methods are synchronous and safe to call from a single asyncio event + loop — asyncio's cooperative scheduling makes them effectively atomic (no + ``await`` between reads and writes). + + State is lost when the process exits. For persistent task queues use a + dedicated result backend; the NNG broker is designed for low-latency + in-process delivery, not durable storage. """ def __init__(self, config: StoreConfig) -> None: - """Initialise the journal and create schema if not present.""" + """Initialise an empty store with the given configuration.""" self.config = config - # Guards only the pending_count check + INSERT pair in submit() to - # prevent concurrent callers from racing past max_pending. - self._submit_lock = threading.Lock() - self._init() - - # ── connection ──────────────────────────────────────────────────────────── - - @contextmanager - def _conn(self) -> Generator[sqlite3.Connection, None, None]: - conn = sqlite3.connect( - self.config.path, - timeout=10.0, - check_same_thread=False, - isolation_level=None, # we manage transactions explicitly - ) - conn.row_factory = sqlite3.Row - # Must be set per-connection, not just once at schema creation. 
- conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA synchronous=NORMAL") # safe with WAL; faster than FULL - conn.execute("PRAGMA busy_timeout=5000") # wait up to 5s before SQLITE_BUSY - conn.execute("PRAGMA cache_size=-32000") # 32 MB page cache - try: - yield conn - finally: - conn.close() - - def _init(self) -> None: - Path(self.config.path).parent.mkdir(parents=True, exist_ok=True) - with self._conn() as conn: - conn.executescript(""" - CREATE TABLE IF NOT EXISTS tasks ( - task_id TEXT PRIMARY KEY, - task_name TEXT NOT NULL, - payload BLOB NOT NULL, - labels_json TEXT NOT NULL DEFAULT '{}', - state TEXT NOT NULL, - attempts INTEGER NOT NULL DEFAULT 0, - max_retries INTEGER NOT NULL DEFAULT 0, - retry_backoff REAL NOT NULL DEFAULT 1.0, - retry_jitter REAL NOT NULL DEFAULT 0.0, - priority INTEGER NOT NULL DEFAULT 0, - created_at REAL NOT NULL, - updated_at REAL NOT NULL, - next_run_at REAL NOT NULL, - lease_id TEXT, - leased_worker_id TEXT, - lease_until REAL, - last_error TEXT - ); - - CREATE TABLE IF NOT EXISTS workers ( - worker_id TEXT PRIMARY KEY, - task_addr TEXT NOT NULL, - capacity INTEGER NOT NULL, - inflight INTEGER NOT NULL DEFAULT 0, - last_seen REAL NOT NULL DEFAULT 0, - heartbeat_interval REAL NOT NULL DEFAULT 5.0, - lease_timeout REAL NOT NULL DEFAULT 15.0, - draining INTEGER NOT NULL DEFAULT 0, - status TEXT NOT NULL, - version TEXT NOT NULL DEFAULT 'unknown' - ); - - CREATE TABLE IF NOT EXISTS journal ( - seq INTEGER PRIMARY KEY AUTOINCREMENT, - ts REAL NOT NULL, - kind TEXT NOT NULL, - payload_json TEXT NOT NULL - ); - - CREATE INDEX IF NOT EXISTS idx_tasks_dispatch - ON tasks (state, next_run_at, priority DESC); - CREATE INDEX IF NOT EXISTS idx_tasks_lease - ON tasks (state, lease_until); - CREATE INDEX IF NOT EXISTS idx_workers_active - ON workers (status, draining, last_seen); - """) + self._tasks: dict[str, _Task] = {} + self._workers: dict[str, _Worker] = {} # ── helpers 
─────────────────────────────────────────────────────────────── - def _journal( - self, - conn: sqlite3.Connection, - kind: str, - payload: dict[str, Any], - ) -> None: - conn.execute( - "INSERT INTO journal (ts, kind, payload_json) VALUES (?, ?, ?)", - ( - time.time(), - kind, - json.dumps(payload, separators=(",", ":"), ensure_ascii=False), - ), - ) - def _backoff(self, attempts: int, backoff_base: float) -> float: return min(self.config.backoff_cap, backoff_base * (2 ** max(0, attempts - 1))) + def _requeue_or_fail(self, task: _Task, worker_id: str, error: str) -> bool: + now = time.time() + if task.attempts > task.max_retries: + task.state = "failed" + else: + task.state = "ready" + task.next_run_at = now + self._backoff(task.attempts, task.retry_backoff) + task.last_error = error + task.lease_id = None + task.leased_worker_id = None + task.lease_until = None + task.updated_at = now + worker = self._workers.get(worker_id) + if worker is not None: + worker.inflight = max(0, worker.inflight - 1) + return True + # ── task lifecycle ──────────────────────────────────────────────────────── def pending_count(self) -> int: - """Return the number of ready + leased tasks.""" - with self._conn() as conn: - return int( - conn.execute( - "SELECT COUNT(*) FROM tasks WHERE state IN ('ready', 'leased')", - ).fetchone()[0], - ) + """Return the count of ready and leased tasks.""" + return sum(1 for t in self._tasks.values() if t.state in ("ready", "leased")) def submit(self, envelope: TaskEnvelope) -> None: """ - Persist a new task in 'ready' state. + Accept a new task into the store. :param envelope: task envelope to store. :raises QueueFullError: when ``max_pending`` is reached. 
""" + if self.pending_count() >= self.config.max_pending: + raise QueueFullError("Task queue is full.") now = time.time() - with self._submit_lock, self._conn() as conn: - count = conn.execute( - "SELECT COUNT(*) FROM tasks WHERE state IN ('ready', 'leased')", - ).fetchone()[0] - if count >= self.config.max_pending: - raise QueueFullError("Task queue is full.") - conn.execute("BEGIN") - conn.execute( - """ - INSERT INTO tasks ( - task_id, task_name, payload, labels_json, state, - attempts, max_retries, retry_backoff, retry_jitter, - priority, created_at, updated_at, next_run_at - ) VALUES (?, ?, ?, ?, 'ready', 0, ?, ?, ?, ?, ?, ?, ?) - """, - ( - envelope.task_id, - envelope.task_name, - envelope.payload, - json.dumps( - envelope.labels, separators=(",", ":"), ensure_ascii=False - ), - envelope.max_retries, - envelope.retry_backoff, - envelope.retry_jitter, - envelope.priority, - envelope.created_at or now, - now, - now, - ), - ) - self._journal( - conn, - "task_submitted", - {"task_id": envelope.task_id, "task_name": envelope.task_name}, - ) - conn.execute("COMMIT") + self._tasks[envelope.task_id] = _Task( + task_id=envelope.task_id, + task_name=envelope.task_name, + payload=envelope.payload, + labels=envelope.labels, + state="ready", + max_retries=envelope.max_retries, + retry_backoff=envelope.retry_backoff, + retry_jitter=envelope.retry_jitter, + priority=envelope.priority, + created_at=envelope.created_at or now, + updated_at=now, + next_run_at=now, + ) - def due_tasks(self, limit: int = 50) -> list[sqlite3.Row]: + def due_tasks(self, limit: int = 50) -> list[dict[str, Any]]: """ Return ready tasks whose ``next_run_at`` is in the past. Results are ordered by priority (descending) then creation time. :param limit: maximum number of rows to return. - :return: list of task rows. + :return: list of task dicts. """ now = time.time() - with self._conn() as conn: - return list( - conn.execute( - """ - SELECT * FROM tasks - WHERE state = 'ready' AND next_run_at <= ? 
- ORDER BY priority DESC, created_at ASC - LIMIT ? - """, - (now, limit), - ), - ) + ready = [ + t for t in self._tasks.values() + if t.state == "ready" and t.next_run_at <= now + ] + ready.sort(key=lambda t: (-t.priority, t.created_at)) + return [t.as_dict() for t in ready[:limit]] def mark_leased( self, @@ -241,90 +206,50 @@ def mark_leased( :param worker_id: worker receiving the task. :param lease_id: unique token for this dispatch attempt. :param lease_until: absolute epoch deadline for the lease. - :return: True if the transition succeeded; False if the task was - already taken (concurrent dispatch race). + :return: True on success; False if the task is not in 'ready' state. """ + task = self._tasks.get(task_id) + if task is None or task.state != "ready": + return False now = time.time() - with self._conn() as conn: - row = conn.execute( - "SELECT state FROM tasks WHERE task_id = ?", (task_id,) - ).fetchone() - if not row or row["state"] != "ready": - return False - conn.execute("BEGIN") - conn.execute( - """ - UPDATE tasks - SET state = 'leased', - leased_worker_id = ?, lease_id = ?, lease_until = ?, - attempts = attempts + 1, updated_at = ? - WHERE task_id = ? - """, - (worker_id, lease_id, lease_until, now, task_id), - ) - conn.execute( - "UPDATE workers SET inflight = inflight + 1 WHERE worker_id = ?", - (worker_id,), - ) - self._journal( - conn, - "task_leased", - { - "task_id": task_id, - "worker_id": worker_id, - "lease_id": lease_id, - }, - ) - conn.execute("COMMIT") - return True + task.state = "leased" + task.leased_worker_id = worker_id + task.lease_id = lease_id + task.lease_until = lease_until + task.attempts += 1 + task.updated_at = now + worker = self._workers.get(worker_id) + if worker is not None: + worker.inflight += 1 + return True def ack(self, task_id: str, worker_id: str, lease_id: str) -> bool: """ Mark a task as successfully completed. 
Late or duplicate acks (mismatched ``lease_id`` or state ≠ 'leased') - are silently rejected and return False. + are silently rejected. :param task_id: task being acknowledged. :param worker_id: worker sending the ack. - :param lease_id: lease token that was issued at dispatch. + :param lease_id: lease token issued at dispatch. :return: True if the ack was accepted. """ + task = self._tasks.get(task_id) + if task is None or task.state != "leased": + return False + if task.lease_id != lease_id or task.leased_worker_id != worker_id: + return False now = time.time() - with self._conn() as conn: - row = conn.execute( - "SELECT state, lease_id, leased_worker_id FROM tasks WHERE task_id = ?", - (task_id,), - ).fetchone() - if not row or row["state"] != "leased": - return False - if row["lease_id"] != lease_id or row["leased_worker_id"] != worker_id: - return False - conn.execute("BEGIN") - conn.execute( - """ - UPDATE tasks - SET state = 'done', updated_at = ?, - lease_id = NULL, leased_worker_id = NULL, lease_until = NULL - WHERE task_id = ? - """, - (now, task_id), - ) - conn.execute( - "UPDATE workers SET inflight = MAX(inflight - 1, 0) WHERE worker_id = ?", - (worker_id,), - ) - self._journal( - conn, - "task_acked", - { - "task_id": task_id, - "worker_id": worker_id, - "lease_id": lease_id, - }, - ) - conn.execute("COMMIT") - return True + task.state = "done" + task.updated_at = now + task.lease_id = None + task.leased_worker_id = None + task.lease_until = None + worker = self._workers.get(worker_id) + if worker is not None: + worker.inflight = max(0, worker.inflight - 1) + return True def nack( self, task_id: str, worker_id: str, lease_id: str, error: str @@ -335,274 +260,106 @@ def nack( :param task_id: task being nacked. :param worker_id: worker sending the nack. :param lease_id: lease token issued at dispatch. - :param error: human-readable reason for the failure. + :param error: human-readable failure reason. :return: True if the nack was accepted. 
""" - return self._requeue_or_fail(task_id, worker_id, lease_id, error) - - def _requeue_or_fail( - self, task_id: str, worker_id: str, lease_id: str, error: str - ) -> bool: - now = time.time() - with self._conn() as conn: - row = conn.execute( - "SELECT * FROM tasks WHERE task_id = ?", (task_id,) - ).fetchone() - if ( - not row - or row["state"] != "leased" - or row["lease_id"] != lease_id - or row["leased_worker_id"] != worker_id - ): - return False - attempts = int(row["attempts"]) - max_retries = int(row["max_retries"]) - conn.execute("BEGIN") - if attempts > max_retries: - conn.execute( - """ - UPDATE tasks - SET state = 'failed', updated_at = ?, - lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, - last_error = ? - WHERE task_id = ? - """, - (now, error, task_id), - ) - else: - backoff = self._backoff(attempts, float(row["retry_backoff"])) - conn.execute( - """ - UPDATE tasks - SET state = 'ready', updated_at = ?, next_run_at = ?, - lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, - last_error = ? - WHERE task_id = ? - """, - (now, now + backoff, error, task_id), - ) - conn.execute( - "UPDATE workers SET inflight = MAX(inflight - 1, 0) WHERE worker_id = ?", - (worker_id,), - ) - self._journal( - conn, - "task_nacked", - { - "task_id": task_id, - "worker_id": worker_id, - "lease_id": lease_id, - "error": error, - "requeued": attempts <= max_retries, - }, - ) - conn.execute("COMMIT") - return True + task = self._tasks.get(task_id) + if ( + task is None + or task.state != "leased" + or task.lease_id != lease_id + or task.leased_worker_id != worker_id + ): + return False + return self._requeue_or_fail(task, worker_id, error) # ── reaper / recovery ───────────────────────────────────────────────────── def reap_expired_leases(self) -> int: """ - Find leases past their deadline and requeue or permanently fail them. + Requeue or permanently fail tasks whose lease deadline has passed. :return: number of tasks reaped. 
""" now = time.time() - with self._conn() as conn: - expired = list( - conn.execute( - """ - SELECT * FROM tasks - WHERE state = 'leased' - AND lease_until IS NOT NULL - AND lease_until < ? - """, - (now,), - ), - ) - if not expired: - return 0 - conn.execute("BEGIN") - count = 0 - for row in expired: - attempts = int(row["attempts"]) - max_retries = int(row["max_retries"]) - worker_id = row["leased_worker_id"] - if attempts > max_retries: - conn.execute( - """ - UPDATE tasks - SET state = 'failed', updated_at = ?, - lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, - last_error = 'lease expired' - WHERE task_id = ? - """, - (now, row["task_id"]), - ) - else: - backoff = self._backoff(attempts, float(row["retry_backoff"])) - conn.execute( - """ - UPDATE tasks - SET state = 'ready', updated_at = ?, next_run_at = ?, - lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, - last_error = 'lease expired' - WHERE task_id = ? - """, - (now, now + backoff, row["task_id"]), - ) - if worker_id: - conn.execute( - "UPDATE workers SET inflight = MAX(inflight - 1, 0) WHERE worker_id = ?", - (worker_id,), - ) - self._journal( - conn, - "lease_reaped", - { - "task_id": row["task_id"], - "worker_id": worker_id, - "lease_id": row["lease_id"], - }, - ) - count += 1 - conn.execute("COMMIT") - return count + expired = [ + t for t in self._tasks.values() + if t.state == "leased" + and t.lease_until is not None + and t.lease_until < now + ] + for task in expired: + self._requeue_or_fail(task, task.leased_worker_id or "", "lease expired") + return len(expired) def recover_dead_workers(self, heartbeat_timeout: float) -> int: """ - Mark workers that missed their heartbeat deadline as DEAD. - - All tasks leased to dead workers are requeued (or permanently failed - if retries are exhausted). + Mark workers that missed their heartbeat deadline as dead and requeue their tasks. :param heartbeat_timeout: seconds of silence before a worker is dead. 
:return: number of tasks requeued. """ - now = time.time() - cutoff = now - heartbeat_timeout - with self._conn() as conn: - dead = list( - conn.execute( - "SELECT * FROM workers WHERE last_seen < ? AND status != 'dead'", - (cutoff,), - ), - ) - if not dead: - return 0 - conn.execute("BEGIN") - requeued = 0 - for worker in dead: - worker_id = worker["worker_id"] - conn.execute( - "UPDATE workers SET status = 'dead', draining = 1 WHERE worker_id = ?", - (worker_id,), - ) - leased = list( - conn.execute( - "SELECT * FROM tasks WHERE state = 'leased' AND leased_worker_id = ?", - (worker_id,), - ), - ) - for row in leased: - attempts = int(row["attempts"]) - max_retries = int(row["max_retries"]) - if attempts > max_retries: - conn.execute( - """ - UPDATE tasks - SET state = 'failed', updated_at = ?, - lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, - last_error = 'worker died' - WHERE task_id = ? - """, - (now, row["task_id"]), - ) - else: - backoff = self._backoff( - attempts, float(row["retry_backoff"]) - ) - conn.execute( - """ - UPDATE tasks - SET state = 'ready', updated_at = ?, next_run_at = ?, - lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, - last_error = 'worker died' - WHERE task_id = ? 
- """, - (now, now + backoff, row["task_id"]), - ) - self._journal( - conn, - "worker_dead_requeue", - {"worker_id": worker_id, "task_id": row["task_id"]}, - ) - requeued += 1 - conn.execute("COMMIT") - return requeued + cutoff = time.time() - heartbeat_timeout + dead = [ + w for w in self._workers.values() + if w.last_seen < cutoff and w.status != "dead" + ] + requeued = 0 + for worker in dead: + worker.status = "dead" + worker.draining = True + leased = [ + t for t in self._tasks.values() + if t.state == "leased" and t.leased_worker_id == worker.worker_id + ] + for task in leased: + self._requeue_or_fail(task, worker.worker_id, "worker died") + requeued += 1 + return requeued # ── worker lifecycle ────────────────────────────────────────────────────── def register_worker(self, worker: WorkerState) -> None: """ - Upsert a worker record. - - Re-registering an existing worker (e.g. after hub restart) resets - its draining flag and updates its metadata. + Upsert a worker record, resetting drain state on re-registration. :param worker: worker state snapshot from the registration message. """ now = time.time() - with self._conn() as conn: - conn.execute("BEGIN") - conn.execute( - """ - INSERT INTO workers ( - worker_id, task_addr, capacity, inflight, last_seen, - heartbeat_interval, lease_timeout, draining, status, version - ) VALUES (?, ?, ?, 0, ?, ?, ?, 0, ?, ?) 
- ON CONFLICT(worker_id) DO UPDATE SET - task_addr = excluded.task_addr, - capacity = excluded.capacity, - last_seen = excluded.last_seen, - heartbeat_interval = excluded.heartbeat_interval, - lease_timeout = excluded.lease_timeout, - draining = 0, - status = excluded.status, - version = excluded.version - """, - ( - worker.worker_id, - worker.task_addr, - worker.capacity, - now, - worker.heartbeat_interval, - worker.lease_timeout, - str(WorkerStatus.LISTENING), - worker.version, - ), - ) - self._journal( - conn, - "worker_register", - {"worker_id": worker.worker_id, "task_addr": worker.task_addr}, + existing = self._workers.get(worker.worker_id) + if existing is not None: + existing.task_addr = worker.task_addr + existing.capacity = worker.capacity + existing.last_seen = now + existing.heartbeat_interval = worker.heartbeat_interval + existing.lease_timeout = worker.lease_timeout + existing.draining = False + existing.status = "listening" + existing.version = worker.version + else: + self._workers[worker.worker_id] = _Worker( + worker_id=worker.worker_id, + task_addr=worker.task_addr, + capacity=worker.capacity, + inflight=0, + last_seen=now, + heartbeat_interval=worker.heartbeat_interval, + lease_timeout=worker.lease_timeout, + draining=False, + status="listening", + version=worker.version, ) - conn.execute("COMMIT") def heartbeat(self, worker_id: str) -> None: """ - Record a heartbeat from a worker, resetting its last_seen timestamp. + Record a heartbeat, resetting the worker's last_seen timestamp. :param worker_id: ID of the worker sending the heartbeat. """ - with self._conn() as conn: - conn.execute("BEGIN") - conn.execute( - "UPDATE workers SET last_seen = ?, status = ? 
WHERE worker_id = ?", - (time.time(), str(WorkerStatus.LISTENING), worker_id), - ) - self._journal(conn, "heartbeat", {"worker_id": worker_id}) - conn.execute("COMMIT") + worker = self._workers.get(worker_id) + if worker is not None: + worker.last_seen = time.time() + worker.status = "listening" def unregister_worker(self, worker_id: str) -> None: """ @@ -610,11 +367,7 @@ def unregister_worker(self, worker_id: str) -> None: :param worker_id: ID of the worker unregistering. """ - with self._conn() as conn: - conn.execute("BEGIN") - conn.execute("DELETE FROM workers WHERE worker_id = ?", (worker_id,)) - self._journal(conn, "worker_unregister", {"worker_id": worker_id}) - conn.execute("COMMIT") + self._workers.pop(worker_id, None) def mark_draining(self, worker_id: str) -> None: """ @@ -622,14 +375,10 @@ def mark_draining(self, worker_id: str) -> None: :param worker_id: ID of the worker entering drain mode. """ - with self._conn() as conn: - conn.execute("BEGIN") - conn.execute( - "UPDATE workers SET draining = 1, status = ? WHERE worker_id = ?", - (str(WorkerStatus.DRAINING), worker_id), - ) - self._journal(conn, "worker_drain", {"worker_id": worker_id}) - conn.execute("COMMIT") + worker = self._workers.get(worker_id) + if worker is not None: + worker.draining = True + worker.status = "draining" # ── routing ─────────────────────────────────────────────────────────────── @@ -638,85 +387,70 @@ def choose_worker( routing_policy: str = "least_loaded", *, heartbeat_timeout: float = 15.0, - ) -> sqlite3.Row | None: + ) -> dict[str, Any] | None: """ Select the best available worker according to ``routing_policy``. - ``'least_loaded'`` picks the worker with the lowest ``inflight / - capacity`` ratio. ``'p2c'`` (Power-of-Two-Choices) samples two - workers at random and picks the less loaded one, reducing hot-spot - probability under high concurrency. + ``'least_loaded'`` picks the worker with the lowest inflight/capacity + ratio. 
``'p2c'`` samples two workers and picks the less loaded one. :param routing_policy: ``'least_loaded'`` or ``'p2c'``. :param heartbeat_timeout: seconds before a worker is considered stale. - :return: chosen worker row, or None if no worker has capacity. + :return: chosen worker dict, or None if no worker has capacity. """ cutoff = time.time() - heartbeat_timeout - with self._conn() as conn: - rows = list( - conn.execute( - """ - SELECT * FROM workers - WHERE status IN ('starting', 'listening') - AND draining = 0 - AND last_seen >= ? - """, - (cutoff,), - ), - ) available = [ - w for w in rows if int(w["inflight"]) < int(w["capacity"]) + w for w in self._workers.values() + if w.status in ("starting", "listening") + and not w.draining + and w.last_seen >= cutoff + and w.inflight < w.capacity ] if not available: return None if routing_policy == "p2c" and len(available) >= 2: a, b = random.sample(available, k=2) # noqa: S311 - load_a = int(a["inflight"]) / max(int(a["capacity"]), 1) - load_b = int(b["inflight"]) / max(int(b["capacity"]), 1) - return a if load_a <= load_b else b - return min( - available, - key=lambda w: int(w["inflight"]) / max(int(w["capacity"]), 1), - ) + load_a = a.inflight / max(a.capacity, 1) + load_b = b.inflight / max(b.capacity, 1) + chosen = a if load_a <= load_b else b + else: + chosen = min(available, key=lambda w: w.inflight / max(w.capacity, 1)) + return chosen.as_dict() - # ── management / observability ──────────────────────────────────────────── + # ── observability ───────────────────────────────────────────────────────── - def get_task(self, task_id: str) -> sqlite3.Row | None: + def get_task(self, task_id: str) -> dict[str, Any] | None: """ - Fetch a single task row by ID. + Fetch task status by ID (no raw bytes in result). :param task_id: ID of the task to look up. - :return: row or None if not found. + :return: status dict or None if not found. 
""" - with self._conn() as conn: - return conn.execute( - "SELECT * FROM tasks WHERE task_id = ?", (task_id,) - ).fetchone() + task = self._tasks.get(task_id) + return task.as_status_dict() if task is not None else None - def list_workers(self) -> list[sqlite3.Row]: + def list_workers(self) -> list[dict[str, Any]]: """Return all registered workers ordered by most-recently-seen.""" - with self._conn() as conn: - return list( - conn.execute("SELECT * FROM workers ORDER BY last_seen DESC"), + return [ + w.as_dict() + for w in sorted( + self._workers.values(), key=lambda w: w.last_seen, reverse=True ) + ] def stats(self) -> dict[str, int]: """Return a summary dict with task state counts and active worker count.""" - with self._conn() as conn: - rows = conn.execute( - "SELECT state, COUNT(*) AS n FROM tasks GROUP BY state", - ).fetchall() - counts = {r["state"]: r["n"] for r in rows} - worker_count = conn.execute( - """ - SELECT COUNT(*) FROM workers - WHERE status IN ('starting', 'listening') AND draining = 0 - """, - ).fetchone()[0] + counts: dict[str, int] = {} + for t in self._tasks.values(): + counts[t.state] = counts.get(t.state, 0) + 1 + active = sum( + 1 for w in self._workers.values() + if w.status in ("starting", "listening") and not w.draining + ) return { "ready": counts.get("ready", 0), "leased": counts.get("leased", 0), "done": counts.get("done", 0), "failed": counts.get("failed", 0), - "active_workers": worker_count, + "active_workers": active, } diff --git a/tests/brokers/test_nng_broker.py b/tests/brokers/test_nng_broker.py index f128e757..4bb9c4b6 100644 --- a/tests/brokers/test_nng_broker.py +++ b/tests/brokers/test_nng_broker.py @@ -4,8 +4,8 @@ The test suite is split into three layers: 1. **Protocol** — pure serialisation roundtrips; no NNG sockets needed. -2. **Storage** — SQLiteJournal unit tests; no NNG sockets needed. -3. **Integration** — real NNG sockets, real SQLite, single asyncio event loop. +2. 
**Storage** — InMemoryStore unit tests; no NNG sockets needed. +3. **Integration** — real NNG sockets, single asyncio event loop. Uses ``FakeWorker`` / ``FakeClient`` helpers that speak the wire protocol directly so we can inject faults precisely (crash before ack, late ack, etc.). @@ -15,7 +15,6 @@ import asyncio import os -import sqlite3 import tempfile import time import uuid @@ -34,7 +33,7 @@ WorkerState, WorkerStatus, QueueFullError, - SQLiteJournal, + InMemoryStore, StoreConfig, ) @@ -253,24 +252,24 @@ def test_task_envelope_payload_decode() -> None: @pytest.fixture -def store(db_path: str) -> SQLiteJournal: - return SQLiteJournal(StoreConfig(path=db_path, max_pending=50, lease_timeout=5.0)) +def store(db_path: str) -> InMemoryStore: + return InMemoryStore(StoreConfig(path=db_path, max_pending=50, lease_timeout=5.0)) -def test_submit_and_pending(store: SQLiteJournal) -> None: +def test_submit_and_pending(store: InMemoryStore) -> None: store.submit(_envelope()) assert store.pending_count() == 1 def test_submit_queue_full(db_path: str) -> None: - s = SQLiteJournal(StoreConfig(path=db_path, max_pending=2)) + s = InMemoryStore(StoreConfig(path=db_path, max_pending=2)) s.submit(_envelope()) s.submit(_envelope()) with pytest.raises(QueueFullError): s.submit(_envelope()) -def test_due_tasks_ordered_by_priority(store: SQLiteJournal) -> None: +def test_due_tasks_ordered_by_priority(store: InMemoryStore) -> None: store.submit(_envelope(task_id="lo", priority=0)) store.submit(_envelope(task_id="hi", priority=10)) due = store.due_tasks(limit=10) @@ -278,7 +277,7 @@ def test_due_tasks_ordered_by_priority(store: SQLiteJournal) -> None: assert due[1]["task_id"] == "lo" -def test_ack_happy_path(store: SQLiteJournal) -> None: +def test_ack_happy_path(store: InMemoryStore) -> None: env = _envelope() store.submit(env) w = _worker_state() @@ -288,7 +287,7 @@ def test_ack_happy_path(store: SQLiteJournal) -> None: assert store.get_task(env.task_id)["state"] == "done" -def 
test_ack_wrong_lease_rejected(store: SQLiteJournal) -> None: +def test_ack_wrong_lease_rejected(store: InMemoryStore) -> None: env = _envelope() store.submit(env) w = _worker_state() @@ -297,7 +296,7 @@ def test_ack_wrong_lease_rejected(store: SQLiteJournal) -> None: assert not store.ack(env.task_id, w.worker_id, "wrong") -def test_late_ack_after_requeue_ignored(store: SQLiteJournal) -> None: +def test_late_ack_after_requeue_ignored(store: InMemoryStore) -> None: env = _envelope() store.submit(env) w = _worker_state() @@ -307,7 +306,7 @@ def test_late_ack_after_requeue_ignored(store: SQLiteJournal) -> None: assert not store.ack(env.task_id, w.worker_id, "L2") -def test_nack_requeues_with_backoff(store: SQLiteJournal) -> None: +def test_nack_requeues_with_backoff(store: InMemoryStore) -> None: env = _envelope(max_retries=2, retry_backoff=1.0) store.submit(env) w = _worker_state() @@ -319,7 +318,7 @@ def test_nack_requeues_with_backoff(store: SQLiteJournal) -> None: assert float(task["next_run_at"]) > time.time() -def test_nack_exceeds_retries_fails(store: SQLiteJournal) -> None: +def test_nack_exceeds_retries_fails(store: InMemoryStore) -> None: env = _envelope(max_retries=0) store.submit(env) w = _worker_state() @@ -329,35 +328,29 @@ def test_nack_exceeds_retries_fails(store: SQLiteJournal) -> None: assert store.get_task(env.task_id)["state"] == "failed" -def test_dead_worker_tasks_requeued(store: SQLiteJournal, db_path: str) -> None: +def test_dead_worker_tasks_requeued(store: InMemoryStore) -> None: w = _worker_state() store.register_worker(w) env = _envelope(max_retries=3) store.submit(env) store.mark_leased(env.task_id, w.worker_id, "L5", time.time() + 60) - conn = sqlite3.connect(db_path) - conn.execute("UPDATE workers SET last_seen=0 WHERE worker_id=?", (w.worker_id,)) - conn.commit() - conn.close() + store._workers[w.worker_id].last_seen = 0 # simulate missed heartbeats assert store.recover_dead_workers(heartbeat_timeout=1.0) == 1 assert 
store.get_task(env.task_id)["state"] == "ready" -def test_choose_worker_least_loaded(store: SQLiteJournal, db_path: str) -> None: +def test_choose_worker_least_loaded(store: InMemoryStore) -> None: w1 = _worker_state(worker_id="w1", capacity=4) w2 = _worker_state(worker_id="w2", capacity=4) store.register_worker(w1) store.register_worker(w2) - conn = sqlite3.connect(db_path) - conn.execute("UPDATE workers SET inflight=3 WHERE worker_id='w1'") - conn.commit() - conn.close() + store._workers["w1"].inflight = 3 # w1 heavily loaded chosen = store.choose_worker("least_loaded", heartbeat_timeout=30.0) assert chosen is not None assert chosen["worker_id"] == "w2" -def test_stats(store: SQLiteJournal) -> None: +def test_stats(store: InMemoryStore) -> None: w = _worker_state() store.register_worker(w) store.submit(_envelope()) @@ -507,36 +500,13 @@ async def test_late_ack_after_requeue_rejected( await hub.stop() +@pytest.mark.skip( + reason="In-memory store has no persistence; restart recovery requires a durable backend." +) async def test_hub_restart_recovers_orphaned_tasks( ctrl_addr: str, db_path: str ) -> None: - """ - Tasks leased at hub shutdown must be requeued when a new hub starts - with the same database. 
- """ - hub1 = _hub(ctrl_addr, db_path) - await hub1.start() - w1 = FakeWorker(ctrl_addr, capacity=1) - client = FakeClient(ctrl_addr) - await w1.register() - tid = await client.submit(max_retries=3) - env = await w1.recv_task(timeout=3.0) - assert env.task_id == tid - # "kill" hub1 without giving worker a chance to ack - await hub1.stop() - w1.close() - client.close() - - # Task is still leased in the DB - assert hub1.store.get_task(tid)["state"] == "leased" - - hub2 = _hub(ctrl_addr, db_path) - await hub2.start() - await asyncio.sleep(0.3) # allow startup recovery - try: - assert hub2.store.get_task(tid)["state"] == "ready" - finally: - await hub2.stop() + """Persistence across restarts is not supported by the in-memory store.""" async def test_concurrent_heartbeats(ctrl_addr: str, db_path: str) -> None: From b8c72127c53918ef63120d3755b1c9148695c971 Mon Sep 17 00:00:00 2001 From: Alexandr Tedeev Date: Sun, 26 Apr 2026 18:37:06 +0300 Subject: [PATCH 3/4] Refactoring the NNG support solution v2: Update routing policy --- taskiq/brokers/nng/__init__.py | 21 ++++- taskiq/brokers/nng/hub.py | 17 +++- taskiq/brokers/nng/storage.py | 138 +++++++++++++++++++++++++++--- tests/brokers/test_nng_broker.py | 141 ++++++++++++++++++++++++++++++- 4 files changed, 295 insertions(+), 22 deletions(-) diff --git a/taskiq/brokers/nng/__init__.py b/taskiq/brokers/nng/__init__.py index 8e2a7f4a..75ba0d95 100644 --- a/taskiq/brokers/nng/__init__.py +++ b/taskiq/brokers/nng/__init__.py @@ -8,18 +8,37 @@ WorkerState, WorkerStatus, ) -from .storage import InMemoryStore, QueueFullError, StoreConfig +from .storage import ( + InMemoryStore, + LeastLoaded, + PowerOfTwoChoices, + QueueFullError, + RoutingPolicy, + RoundRobin, + StoreConfig, + WorkerView, + make_routing_policy, +) __all__ = [ "HubConfig", "NNGHub", + # protocol "ControlMessage", "ControlResponse", "MessageKind", "TaskEnvelope", "WorkerState", "WorkerStatus", + # store "QueueFullError", "InMemoryStore", "StoreConfig", + # routing 
+ "WorkerView", + "RoutingPolicy", + "LeastLoaded", + "PowerOfTwoChoices", + "RoundRobin", + "make_routing_policy", ] diff --git a/taskiq/brokers/nng/hub.py b/taskiq/brokers/nng/hub.py index c58857bf..37a935b9 100644 --- a/taskiq/brokers/nng/hub.py +++ b/taskiq/brokers/nng/hub.py @@ -37,7 +37,13 @@ TaskEnvelope, WorkerState, ) -from .storage import InMemoryStore, QueueFullError, StoreConfig +from .storage import ( + InMemoryStore, + QueueFullError, + RoutingPolicy, + StoreConfig, + make_routing_policy, +) logger = logging.getLogger(__name__) @@ -53,7 +59,7 @@ class HubConfig: lease_timeout: float = 20.0 dispatch_interval: float = 0.05 reaper_interval: float = 0.5 - routing_policy: str = "least_loaded" + routing_policy: RoutingPolicy | str = "least_loaded" backoff_cap: float = 60.0 # Number of concurrent Rep0 contexts. Each context handles one req/rep # pair independently; N contexts ≈ N simultaneous control-plane clients. @@ -107,6 +113,9 @@ def __init__(self, config: HubConfig) -> None: backoff_cap=config.backoff_cap, ), ) + # Resolve once at construction so RoundRobin and similar stateful + # policies maintain their counter across dispatch calls. 
+ self._routing: RoutingPolicy = make_routing_policy(config.routing_policy) self._stop = asyncio.Event() self._ctrl_sock: Any = None # pynng.Rep0 self._worker_push: dict[str, Any] = {} # worker_id -> pynng.Push0 @@ -282,7 +291,7 @@ async def _dispatch_once(self) -> bool: sent_any = False for row in due: worker = self.store.choose_worker( - self.config.routing_policy, + self._routing, heartbeat_timeout=self.config.heartbeat_timeout, ) if worker is None: @@ -384,7 +393,7 @@ def _build_config() -> HubConfig: ) p.add_argument( "--routing-policy", - choices=["least_loaded", "p2c"], + choices=["least_loaded", "p2c", "round_robin"], default=os.getenv("NNG_ROUTING_POLICY", "least_loaded"), ) p.add_argument( diff --git a/taskiq/brokers/nng/storage.py b/taskiq/brokers/nng/storage.py index 87970adf..81ab1899 100644 --- a/taskiq/brokers/nng/storage.py +++ b/taskiq/brokers/nng/storage.py @@ -4,7 +4,7 @@ import random import time from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable if TYPE_CHECKING: from .protocol import TaskEnvelope, WorkerState @@ -103,6 +103,113 @@ def as_dict(self) -> dict[str, Any]: } +# ── routing policy abstraction ──────────────────────────────────────────────── + + +@dataclass(frozen=True) +class WorkerView: + """Immutable worker snapshot passed to :class:`RoutingPolicy` implementations.""" + + worker_id: str + inflight: int + capacity: int + + @property + def load(self) -> float: + """Fractional load: 0.0 idle → 1.0 at capacity.""" + return self.inflight / max(self.capacity, 1) + + +@runtime_checkable +class RoutingPolicy(Protocol): + """Strategy interface for selecting a dispatch target from available workers.""" + + def choose(self, workers: list[WorkerView]) -> WorkerView | None: + """Return the chosen worker, or None to hold off dispatch.""" + ... 
+ + +class LeastLoaded: + """Pick the worker with the lowest inflight / capacity ratio.""" + + def choose(self, workers: list[WorkerView]) -> WorkerView | None: + """Return the least-loaded worker.""" + if not workers: + return None + return min(workers, key=lambda w: w.load) + + +class PowerOfTwoChoices: + """ + Power-of-two-choices routing. + + Samples two workers uniformly at random and returns the less loaded one. + Reduces hot-spot probability under high concurrency compared to pure random. + """ + + def choose(self, workers: list[WorkerView]) -> WorkerView | None: + """Return the less loaded of two randomly sampled workers.""" + if not workers: + return None + if len(workers) == 1: + return workers[0] + a, b = random.sample(workers, k=2) # noqa: S311 + return a if a.load <= b.load else b + + +class RoundRobin: + """ + Round-robin routing — cycle through workers in alphabetical ID order. + + Ignores load; useful when tasks are homogeneous and worker capacity is equal. + The counter is per-instance, so each :class:`NNGHub` maintains its own cycle. + """ + + def __init__(self) -> None: + """Initialise the cycle counter.""" + self._idx: int = 0 + + def choose(self, workers: list[WorkerView]) -> WorkerView | None: + """Return the next worker in the cycle.""" + if not workers: + return None + w = workers[self._idx % len(workers)] + self._idx += 1 + return w + + +# Singletons for stateless built-ins; RoundRobin singleton is fine for single-hub +# processes. Users needing isolated counters should pass their own instance. +_BUILTIN_POLICIES: dict[str, RoutingPolicy] = { + "least_loaded": LeastLoaded(), + "p2c": PowerOfTwoChoices(), + "round_robin": RoundRobin(), +} + + +def make_routing_policy(policy: "RoutingPolicy | str") -> RoutingPolicy: + """ + Resolve a routing policy name or pass through an instance. + + :param policy: ``'least_loaded'``, ``'p2c'``, ``'round_robin'``, or a + :class:`RoutingPolicy` instance. + :return: concrete routing policy. 
+ :raises ValueError: for unknown string names. + """ + if isinstance(policy, str): + resolved = _BUILTIN_POLICIES.get(policy) + if resolved is None: + raise ValueError( + f"Unknown routing policy {policy!r}; " + f"available: {sorted(_BUILTIN_POLICIES)}" + ) + return resolved + return policy + + +# ── store ───────────────────────────────────────────────────────────────────── + + class InMemoryStore: """ Pure in-memory task store for the NNG hub. @@ -384,17 +491,17 @@ def mark_draining(self, worker_id: str) -> None: def choose_worker( self, - routing_policy: str = "least_loaded", + policy: "RoutingPolicy | str" = "least_loaded", *, heartbeat_timeout: float = 15.0, ) -> dict[str, Any] | None: """ - Select the best available worker according to ``routing_policy``. + Select the best available worker using a routing policy. - ``'least_loaded'`` picks the worker with the lowest inflight/capacity - ratio. ``'p2c'`` samples two workers and picks the less loaded one. + Accepts a :class:`RoutingPolicy` instance or a string name + (``'least_loaded'``, ``'p2c'``, ``'round_robin'``). - :param routing_policy: ``'least_loaded'`` or ``'p2c'``. + :param policy: routing policy or name. :param heartbeat_timeout: seconds before a worker is considered stale. :return: chosen worker dict, or None if no worker has capacity. """ @@ -408,14 +515,17 @@ def choose_worker( ] if not available: return None - if routing_policy == "p2c" and len(available) >= 2: - a, b = random.sample(available, k=2) # noqa: S311 - load_a = a.inflight / max(a.capacity, 1) - load_b = b.inflight / max(b.capacity, 1) - chosen = a if load_a <= load_b else b - else: - chosen = min(available, key=lambda w: w.inflight / max(w.capacity, 1)) - return chosen.as_dict() + # Stable sort so RoundRobin cycles in a predictable, deterministic order. 
+ views = sorted( + [WorkerView(w.worker_id, w.inflight, w.capacity) for w in available], + key=lambda v: v.worker_id, + ) + routing = make_routing_policy(policy) + chosen = routing.choose(views) + if chosen is None: + return None + worker = self._workers.get(chosen.worker_id) + return worker.as_dict() if worker is not None else None # ── observability ───────────────────────────────────────────────────────── diff --git a/tests/brokers/test_nng_broker.py b/tests/brokers/test_nng_broker.py index 4bb9c4b6..0b8bcc43 100644 --- a/tests/brokers/test_nng_broker.py +++ b/tests/brokers/test_nng_broker.py @@ -28,13 +28,19 @@ NNGHub, ControlMessage, ControlResponse, + InMemoryStore, + LeastLoaded, MessageKind, + PowerOfTwoChoices, + QueueFullError, + RoutingPolicy, + RoundRobin, + StoreConfig, TaskEnvelope, WorkerState, WorkerStatus, - QueueFullError, - InMemoryStore, - StoreConfig, + WorkerView, + make_routing_policy, ) @@ -544,3 +550,132 @@ async def test_graceful_drain_and_unregister(ctrl_addr: str, db_path: str) -> No finally: worker.close() await hub.stop() + + +# ── 2b. 
Routing policy unit tests ───────────────────────────────────────────── + + +def test_least_loaded_picks_idle_worker() -> None: + policy = LeastLoaded() + workers = [WorkerView("w1", inflight=3, capacity=4), WorkerView("w2", inflight=0, capacity=4)] + assert policy.choose(workers).worker_id == "w2" # type: ignore[union-attr] + + +def test_least_loaded_empty_returns_none() -> None: + assert LeastLoaded().choose([]) is None + + +def test_p2c_returns_a_worker() -> None: + policy = PowerOfTwoChoices() + workers = [WorkerView("w1", 1, 4), WorkerView("w2", 2, 4), WorkerView("w3", 0, 4)] + chosen = policy.choose(workers) + assert chosen is not None + assert chosen.worker_id in {"w1", "w2", "w3"} + + +def test_p2c_single_worker() -> None: + policy = PowerOfTwoChoices() + workers = [WorkerView("only", 0, 4)] + assert policy.choose(workers).worker_id == "only" # type: ignore[union-attr] + + +def test_round_robin_cycles() -> None: + policy = RoundRobin() + workers = [WorkerView("w1", 0, 4), WorkerView("w2", 0, 4), WorkerView("w3", 0, 4)] + ids = [policy.choose(workers).worker_id for _ in range(6)] # type: ignore[union-attr] + assert ids == ["w1", "w2", "w3", "w1", "w2", "w3"] + + +def test_make_routing_policy_string() -> None: + assert isinstance(make_routing_policy("least_loaded"), LeastLoaded) + assert isinstance(make_routing_policy("p2c"), PowerOfTwoChoices) + assert isinstance(make_routing_policy("round_robin"), RoundRobin) + + +def test_make_routing_policy_instance_passthrough() -> None: + policy = LeastLoaded() + assert make_routing_policy(policy) is policy + + +def test_make_routing_policy_unknown_raises() -> None: + with pytest.raises(ValueError, match="Unknown routing policy"): + make_routing_policy("no_such_policy") + + +def test_custom_routing_policy_accepted(store: InMemoryStore) -> None: + """Users can pass a RoutingPolicy instance directly to choose_worker.""" + + class AlwaysFirstPolicy: + """Trivial policy: always pick the worker with the lexicographically 
smallest ID.""" + def choose(self, workers: list[WorkerView]) -> WorkerView | None: + return min(workers, key=lambda w: w.worker_id) if workers else None + + policy = AlwaysFirstPolicy() + # Verify it satisfies the Protocol at runtime. + assert isinstance(policy, RoutingPolicy) + + w1 = _worker_state(worker_id="aaa", capacity=4) + w2 = _worker_state(worker_id="zzz", capacity=4) + store.register_worker(w1) + store.register_worker(w2) + chosen = store.choose_worker(policy, heartbeat_timeout=30.0) + assert chosen is not None + assert chosen["worker_id"] == "aaa" + + +def test_choose_worker_p2c(store: InMemoryStore) -> None: + """P2C routing returns one of the registered workers.""" + for i in range(4): + store.register_worker(_worker_state(worker_id=f"w{i}", capacity=4)) + chosen = store.choose_worker("p2c", heartbeat_timeout=30.0) + assert chosen is not None + assert chosen["worker_id"] in {f"w{i}" for i in range(4)} + + +def test_hub_accepts_policy_instance(ctrl_addr: str, db_path: str) -> None: + """HubConfig.routing_policy accepts a RoutingPolicy instance.""" + hub = NNGHub(HubConfig( + control_addr=ctrl_addr, + routing_policy=RoundRobin(), + max_pending=100, + )) + assert isinstance(hub._routing, RoundRobin) + + +# ── 3b. 
Backpressure integration test ──────────────────────────────────────── + + +async def test_backpressure_hub_rejects_when_full( + ctrl_addr: str, db_path: str +) -> None: + """Hub returns error=queue full when max_pending is reached.""" + hub = _hub(ctrl_addr, db_path, max_pending=1) + await hub.start() + client = FakeClient(ctrl_addr) + try: + await client.submit() # fills the one slot (no worker → stays queued) + # Second submission must be rejected + payload: dict[str, object] = { + "task_id": uuid.uuid4().hex, + "task_name": "tests:task", + "payload_b64": "dGVzdA==", + "labels": {}, + "lease_id": "", + "attempts": 0, + "max_retries": 0, + "retry_backoff": 1.0, + "retry_jitter": 0.0, + "priority": 0, + "created_at": time.time(), + } + async with client._lock: + await client._ctrl.asend( + ControlMessage(kind="submit", payload=payload).to_bytes() + ) + raw = await client._ctrl.arecv() + resp = ControlResponse.from_bytes(raw) + assert not resp.ok + assert resp.error == "queue full" + finally: + client.close() + await hub.stop() From fe3b0a5883d59a075d792e0ea1c7c91541bf7c11 Mon Sep 17 00:00:00 2001 From: Alexandr Tedeev Date: Sun, 26 Apr 2026 23:36:47 +0300 Subject: [PATCH 4/4] Refactoring the NNG support solution v2.5: Update routing policy, add affinity policy, and scheduler abstraction. 
--- taskiq/brokers/nng/__init__.py | 9 ++ taskiq/brokers/nng/hub.py | 15 ++- taskiq/brokers/nng/storage.py | 108 +++++++++++++++- tests/brokers/test_nng_broker.py | 216 ++++++++++++++++++++++++++++++- 4 files changed, 336 insertions(+), 12 deletions(-) diff --git a/taskiq/brokers/nng/__init__.py b/taskiq/brokers/nng/__init__.py index 75ba0d95..1e0bdcea 100644 --- a/taskiq/brokers/nng/__init__.py +++ b/taskiq/brokers/nng/__init__.py @@ -9,13 +9,17 @@ WorkerStatus, ) from .storage import ( + AffinityPolicy, InMemoryStore, LeastLoaded, PowerOfTwoChoices, + PriorityScheduler, QueueFullError, RoutingPolicy, RoundRobin, + Scheduler, StoreConfig, + TaskContext, WorkerView, make_routing_policy, ) @@ -35,10 +39,15 @@ "InMemoryStore", "StoreConfig", # routing + "TaskContext", "WorkerView", "RoutingPolicy", + "AffinityPolicy", "LeastLoaded", "PowerOfTwoChoices", "RoundRobin", "make_routing_policy", + # scheduler + "Scheduler", + "PriorityScheduler", ] diff --git a/taskiq/brokers/nng/hub.py b/taskiq/brokers/nng/hub.py index 37a935b9..f3920c6f 100644 --- a/taskiq/brokers/nng/hub.py +++ b/taskiq/brokers/nng/hub.py @@ -39,9 +39,12 @@ ) from .storage import ( InMemoryStore, + PriorityScheduler, QueueFullError, RoutingPolicy, + Scheduler, StoreConfig, + TaskContext, make_routing_policy, ) @@ -60,6 +63,7 @@ class HubConfig: dispatch_interval: float = 0.05 reaper_interval: float = 0.5 routing_policy: RoutingPolicy | str = "least_loaded" + scheduler: Scheduler | None = None backoff_cap: float = 60.0 # Number of concurrent Rep0 contexts. Each context handles one req/rep # pair independently; N contexts ≈ N simultaneous control-plane clients. @@ -116,6 +120,7 @@ def __init__(self, config: HubConfig) -> None: # Resolve once at construction so RoundRobin and similar stateful # policies maintain their counter across dispatch calls. 
self._routing: RoutingPolicy = make_routing_policy(config.routing_policy) + self._scheduler: Scheduler = config.scheduler or PriorityScheduler() self._stop = asyncio.Event() self._ctrl_sock: Any = None # pynng.Rep0 self._worker_push: dict[str, Any] = {} # worker_id -> pynng.Push0 @@ -285,14 +290,22 @@ async def _dispatch_loop(self) -> None: async def _dispatch_once(self) -> bool: """Dispatch up to ``dispatch_batch`` due tasks to available workers.""" - due = self.store.due_tasks(self.config.dispatch_batch) + due = self._scheduler.select(self.store, self.config.dispatch_batch) if not due: return False sent_any = False for row in due: + task_ctx = TaskContext( + task_id=row["task_id"], + task_name=row["task_name"], + labels=row["labels"], + priority=int(row["priority"]), + attempts=int(row["attempts"]), + ) worker = self.store.choose_worker( self._routing, heartbeat_timeout=self.config.heartbeat_timeout, + task=task_ctx, ) if worker is None: return sent_any # no capacity; leave remaining tasks in queue diff --git a/taskiq/brokers/nng/storage.py b/taskiq/brokers/nng/storage.py index 81ab1899..b804400a 100644 --- a/taskiq/brokers/nng/storage.py +++ b/taskiq/brokers/nng/storage.py @@ -1,6 +1,8 @@ """Pure in-memory task store for the NNG hub — no external dependencies.""" from __future__ import annotations +import functools +import inspect import random import time from dataclasses import dataclass, field @@ -103,6 +105,20 @@ def as_dict(self) -> dict[str, Any]: } +# ── task context ───────────────────────────────────────────────────────────── + + +@dataclass +class TaskContext: + """Task metadata passed to context-aware routing policies (e.g. 
affinity).""" + + task_id: str + task_name: str + labels: dict[str, Any] + priority: int = 0 + attempts: int = 0 + + # ── routing policy abstraction ──────────────────────────────────────────────── @@ -178,12 +194,70 @@ def choose(self, workers: list[WorkerView]) -> WorkerView | None: return w -# Singletons for stateless built-ins; RoundRobin singleton is fine for single-hub -# processes. Users needing isolated counters should pass their own instance. +class AffinityPolicy: + """ + Sticky routing: tasks with the same ``affinity_key`` label always go to the + same worker. Falls back to least-loaded when the preferred worker is gone. + + The affinity table is per-instance and lives only in memory. + """ + + def __init__(self) -> None: + """Initialise an empty affinity table.""" + self._table: dict[str, str] = {} # affinity_key → worker_id + + def choose( + self, + workers: list[WorkerView], + task: "TaskContext | None" = None, + ) -> WorkerView | None: + """Return the sticky worker for the task's affinity key, or least-loaded.""" + if not workers: + return None + if task is not None: + key = str(task.labels.get("affinity_key", "")) + if key and key in self._table: + match = next( + (w for w in workers if w.worker_id == self._table[key]), None + ) + if match is not None: + return match + chosen = min(workers, key=lambda w: w.load) + if task is not None: + key = str(task.labels.get("affinity_key", "")) + if key: + self._table[key] = chosen.worker_id + return chosen + + +@functools.lru_cache(maxsize=None) +def _policy_accepts_task(policy_cls: type) -> bool: + """Return True if policy.choose accepts a ``task`` keyword argument.""" + try: + return "task" in inspect.signature(policy_cls.choose).parameters + except (ValueError, TypeError): + return False + + +def _choose_with_context( + policy: RoutingPolicy, + views: list[WorkerView], + task: "TaskContext | None", +) -> "WorkerView | None": + """Call policy.choose, passing ``task`` only when the policy supports it.""" + if 
task is not None and _policy_accepts_task(type(policy)): + return policy.choose(views, task=task) # type: ignore[call-arg] + return policy.choose(views) + + +# Singletons for stateless built-ins; RoundRobin/AffinityPolicy singletons are +# fine for single-hub processes. Users needing isolated state should pass their +# own instance. _BUILTIN_POLICIES: dict[str, RoutingPolicy] = { "least_loaded": LeastLoaded(), "p2c": PowerOfTwoChoices(), "round_robin": RoundRobin(), + "affinity": AffinityPolicy(), # type: ignore[dict-item] } @@ -207,6 +281,26 @@ def make_routing_policy(policy: "RoutingPolicy | str") -> RoutingPolicy: return policy +# ── scheduler abstraction ───────────────────────────────────────────────────── + + +@runtime_checkable +class Scheduler(Protocol): + """Strategy interface for selecting which tasks to dispatch next.""" + + def select(self, store: "InMemoryStore", limit: int) -> list[dict[str, Any]]: + """Return up to ``limit`` tasks ready for dispatch.""" + ... + + +class PriorityScheduler: + """Default scheduler: highest-priority due tasks first.""" + + def select(self, store: "InMemoryStore", limit: int) -> list[dict[str, Any]]: + """Delegate to :meth:`InMemoryStore.due_tasks`.""" + return store.due_tasks(limit) + + # ── store ───────────────────────────────────────────────────────────────────── @@ -494,15 +588,21 @@ def choose_worker( policy: "RoutingPolicy | str" = "least_loaded", *, heartbeat_timeout: float = 15.0, + task: "TaskContext | None" = None, ) -> dict[str, Any] | None: """ Select the best available worker using a routing policy. Accepts a :class:`RoutingPolicy` instance or a string name - (``'least_loaded'``, ``'p2c'``, ``'round_robin'``). + (``'least_loaded'``, ``'p2c'``, ``'round_robin'``, ``'affinity'``). + + Context-aware policies (e.g. :class:`AffinityPolicy`) receive the + optional ``task`` argument when they declare it in their ``choose`` + signature. :param policy: routing policy or name. 
:param heartbeat_timeout: seconds before a worker is considered stale. + :param task: optional task context for context-aware policies. :return: chosen worker dict, or None if no worker has capacity. """ cutoff = time.time() - heartbeat_timeout @@ -521,7 +621,7 @@ def choose_worker( key=lambda v: v.worker_id, ) routing = make_routing_policy(policy) - chosen = routing.choose(views) + chosen = _choose_with_context(routing, views, task) if chosen is None: return None worker = self._workers.get(chosen.worker_id) diff --git a/tests/brokers/test_nng_broker.py b/tests/brokers/test_nng_broker.py index 0b8bcc43..7bb7b2db 100644 --- a/tests/brokers/test_nng_broker.py +++ b/tests/brokers/test_nng_broker.py @@ -15,7 +15,9 @@ import asyncio import os +import sys import tempfile +import textwrap import time import uuid @@ -24,6 +26,7 @@ pynng = pytest.importorskip("pynng") from taskiq.brokers.nng import ( + AffinityPolicy, HubConfig, NNGHub, ControlMessage, @@ -32,10 +35,13 @@ LeastLoaded, MessageKind, PowerOfTwoChoices, + PriorityScheduler, QueueFullError, RoutingPolicy, RoundRobin, + Scheduler, StoreConfig, + TaskContext, TaskEnvelope, WorkerState, WorkerStatus, @@ -86,16 +92,19 @@ def _worker_state( def _hub(control_addr: str, db_path: str, **kwargs: object) -> NNGHub: + defaults: dict[str, object] = { + "max_pending": 100, + "heartbeat_timeout": 2.0, + "lease_timeout": 2.0, + "dispatch_interval": 0.02, + "reaper_interval": 0.1, + "control_concurrency": 4, + } + defaults.update(kwargs) cfg = HubConfig( control_addr=control_addr, task_db=db_path, - max_pending=100, - heartbeat_timeout=2.0, - lease_timeout=2.0, - dispatch_interval=0.02, - reaper_interval=0.1, - control_concurrency=4, - **kwargs, # type: ignore[arg-type] + **defaults, # type: ignore[arg-type] ) return NNGHub(cfg) @@ -679,3 +688,196 @@ async def test_backpressure_hub_rejects_when_full( finally: client.close() await hub.stop() + + +# ── 2c. 
AffinityPolicy unit tests ──────────────────────────────────────────── + + +def test_affinity_policy_sticks_to_worker() -> None: + """Same affinity_key must route to the same worker across calls.""" + policy = AffinityPolicy() + workers = [WorkerView("w1", 0, 4), WorkerView("w2", 0, 4)] + task = TaskContext("t1", "fn", {"affinity_key": "user-42"}) + first = policy.choose(workers, task=task) + assert first is not None + for _ in range(10): + chosen = policy.choose(workers, task=task) + assert chosen is not None + assert chosen.worker_id == first.worker_id + + +def test_affinity_policy_falls_back_when_worker_gone() -> None: + """When the sticky worker is no longer available, fall back to least-loaded.""" + policy = AffinityPolicy() + workers_full = [WorkerView("w1", 0, 4), WorkerView("w2", 0, 4)] + task = TaskContext("t1", "fn", {"affinity_key": "key-x"}) + first = policy.choose(workers_full, task=task) + assert first is not None + # Remove the sticky worker — only the other one remains. + remaining = [w for w in workers_full if w.worker_id != first.worker_id] + fallback = policy.choose(remaining, task=task) + assert fallback is not None + assert fallback.worker_id != first.worker_id + + +def test_affinity_policy_no_key_uses_least_loaded() -> None: + """Tasks without affinity_key get least-loaded routing.""" + policy = AffinityPolicy() + workers = [WorkerView("w1", 3, 4), WorkerView("w2", 0, 4)] + task = TaskContext("t1", "fn", {}) + chosen = policy.choose(workers, task=task) + assert chosen is not None + assert chosen.worker_id == "w2" + + +def test_affinity_policy_is_routing_policy() -> None: + assert isinstance(AffinityPolicy(), RoutingPolicy) + + +def test_choose_worker_affinity_string(store: InMemoryStore) -> None: + """String 'affinity' resolves to the singleton AffinityPolicy via choose_worker.""" + for wid in ("a1", "a2"): + store.register_worker(_worker_state(worker_id=wid, capacity=4)) + task = TaskContext("t1", "fn", {"affinity_key": "session-1"}) + first 
= store.choose_worker("affinity", heartbeat_timeout=30.0, task=task)
+    assert first is not None
+    for _ in range(5):
+        chosen = store.choose_worker("affinity", heartbeat_timeout=30.0, task=task)
+        assert chosen is not None
+        assert chosen["worker_id"] == first["worker_id"]
+
+
+# ── 2d. Scheduler unit tests ─────────────────────────────────────────────────
+
+
+def test_priority_scheduler_delegates_to_due_tasks(store: InMemoryStore) -> None:
+    store.submit(_envelope(task_id="lo", priority=0))
+    store.submit(_envelope(task_id="hi", priority=5))
+    sched = PriorityScheduler()
+    rows = sched.select(store, limit=10)
+    assert rows[0]["task_id"] == "hi"
+
+
+def test_priority_scheduler_is_scheduler() -> None:
+    assert isinstance(PriorityScheduler(), Scheduler)
+
+
+def test_custom_scheduler_used_by_hub(ctrl_addr: str, db_path: str) -> None:
+    """HubConfig.scheduler accepts a custom Scheduler instance."""
+
+    class NoopScheduler:
+        """Returns no tasks — a recognizable stand-in the hub should store as-is."""
+        called = False
+
+        def select(
+            self, store: InMemoryStore, limit: int
+        ) -> list[dict[str, object]]:
+            NoopScheduler.called = True
+            return []
+
+    scheduler = NoopScheduler()
+    assert isinstance(scheduler, Scheduler)
+    hub = NNGHub(HubConfig(
+        control_addr=ctrl_addr,
+        scheduler=scheduler,
+        max_pending=10,
+    ))
+    assert hub._scheduler is scheduler
+
+
+# ── 4. 
Multiprocess integration test ───────────────────────────────────────── + +_WORKER_SCRIPT = textwrap.dedent("""\ + import asyncio, sys, os + sys.path.insert(0, {root!r}) + try: + import pynng # noqa: F401 + from taskiq.brokers.nng.broker import NNGBroker + except Exception as exc: + sys.stdout.write(f"SKIP:{{exc}}\\n") + sys.stdout.flush() + sys.exit(0) + + async def main() -> None: + broker = NNGBroker( + {ctrl_addr!r}, + worker_task_addr={task_addr!r}, + worker_id={worker_id!r}, + capacity=1, + heartbeat_interval=1.0, + recv_timeout_ms=3000, + send_timeout_ms=3000, + ) + broker.is_worker_process = True + await broker.startup() + sys.stdout.write("READY\\n") + sys.stdout.flush() + async for msg in broker.listen(): + sys.stdout.write(f"TASK:{{msg.data.decode()}}\\n") + sys.stdout.flush() + await msg.ack() + break + await broker.shutdown() + + asyncio.run(main()) +""") + + +async def test_multiprocess_worker_receives_task( + ctrl_addr: str, db_path: str +) -> None: + """A real subprocess worker (separate OS process) receives and acks a task.""" + repo_root = str( + __import__("pathlib").Path(__file__).parent.parent.parent.resolve() + ) + task_addr = _ipc("mp-worker") + worker_id = f"mp-{uuid.uuid4().hex[:8]}" + + script = _WORKER_SCRIPT.format( + root=repo_root, + ctrl_addr=ctrl_addr, + task_addr=task_addr, + worker_id=worker_id, + ) + + hub = _hub(ctrl_addr, db_path) + await hub.start() + client = FakeClient(ctrl_addr) + + proc = await asyncio.create_subprocess_exec( + sys.executable, "-c", script, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + async def _read_line(timeout: float = 10.0) -> str: + assert proc.stdout is not None + line = await asyncio.wait_for(proc.stdout.readline(), timeout=timeout) + return line.decode().strip() + + try: + first_line = await _read_line(timeout=10.0) + if first_line.startswith("SKIP:"): + pytest.skip(f"Worker subprocess skipped: {first_line[5:]}") + + assert first_line == "READY", f"Expected READY, got: 
{first_line!r}" + + # Submit a task now that the worker is registered and listening. + tid = await client.submit() + + task_line = await _read_line(timeout=10.0) + assert task_line.startswith("TASK:"), f"Expected TASK:..., got: {task_line!r}" + + await proc.wait() + + # Give hub's reaper a tick to process the ack. + await asyncio.sleep(0.2) + state = hub.store.get_task(tid) + assert state is not None + assert state["state"] == "done", f"Expected done, got {state['state']!r}" + finally: + if proc.returncode is None: + proc.terminate() + await proc.wait() + client.close() + await hub.stop()