diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index 99298af2..f54f2259 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -4,7 +4,7 @@ import inspect import sys from collections.abc import Callable -from concurrent.futures import Executor +from concurrent.futures import Executor, ProcessPoolExecutor from logging import getLogger from time import time from typing import Any, get_type_hints @@ -28,6 +28,24 @@ QUEUE_DONE = b"-1" +def _execute_sync_task_in_executor( + target: Callable[..., Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> Any: + """Execute a sync task. + + This is a wrapper to ensure we pass the target function directly + to the executor, avoiding issues with pickling bound methods like ctx.run. + + :param target: function to execute + :param args: positional arguments + :param kwargs: keyword arguments + :return: result of the function call + """ + return target(*args, **kwargs) + + class Receiver: """Class that uses as a callback handler.""" @@ -69,6 +87,7 @@ def __init__( "can result in undefined behavior", ) self.sem_prefetch = asyncio.Semaphore(max_prefetch) + self.is_process_pool = isinstance(executor, ProcessPoolExecutor) async def callback( # noqa: C901, PLR0912 self, @@ -245,15 +264,28 @@ async def run_task( # noqa: C901, PLR0912, PLR0915 target_future = target(*message.args, **kwargs) else: is_coroutine = False - # If this is a synchronous function, we - # run it in executor and preserve the context. - ctx = contextvars.copy_context() - func = functools.partial(target, *message.args, **kwargs) - target_future = loop.run_in_executor( - self.executor, - ctx.run, - func, - ) + if self.is_process_pool: + # For ProcessPoolExecutor, we can't use ctx.run because it contains + # a reference to contextvars.Context which cannot be pickled. + # Instead, we call the target function directly in the executor. + # Each worker process starts with its own context, so we don't need + # to preserve the parent context. + target_future = loop.run_in_executor( + self.executor, + _execute_sync_task_in_executor, + target, + tuple(message.args), + kwargs, + ) + else: + # For ThreadPoolExecutor, we can use ctx.run with functools.partial + ctx = contextvars.copy_context() + func = functools.partial(target, *message.args, **kwargs) + target_future = loop.run_in_executor( + self.executor, + ctx.run, + func, + ) timeout = message.labels.get("timeout") if timeout is not None: if not is_coroutine: