Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 42 additions & 10 deletions taskiq/receiver/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import inspect
import sys
from collections.abc import Callable
from concurrent.futures import Executor
from concurrent.futures import Executor, ProcessPoolExecutor
from logging import getLogger
from time import time
from typing import Any, get_type_hints
Expand All @@ -28,6 +28,24 @@
QUEUE_DONE = b"-1"


def _execute_sync_task_in_executor(
target: Callable[..., Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> Any:
"""Execute a sync task.

This is a wrapper to ensure we pass the target function directly
to the executor, avoiding issues with pickling bound methods like ctx.run.

:param target: function to execute
:param args: positional arguments
:param kwargs: keyword arguments
:return: result of the function call
"""
return target(*args, **kwargs)


class Receiver:
"""Class that uses as a callback handler."""

Expand Down Expand Up @@ -69,6 +87,7 @@ def __init__(
"can result in undefined behavior",
)
self.sem_prefetch = asyncio.Semaphore(max_prefetch)
self.is_process_pool = isinstance(executor, ProcessPoolExecutor)

async def callback( # noqa: C901, PLR0912
self,
Expand Down Expand Up @@ -245,15 +264,28 @@ async def run_task( # noqa: C901, PLR0912, PLR0915
target_future = target(*message.args, **kwargs)
else:
is_coroutine = False
# If this is a synchronous function, we
# run it in executor and preserve the context.
ctx = contextvars.copy_context()
func = functools.partial(target, *message.args, **kwargs)
target_future = loop.run_in_executor(
self.executor,
ctx.run,
func,
)
if self.is_process_pool:
# For ProcessPoolExecutor, we can't use ctx.run because it contains
# a reference to contextvars.Context which cannot be pickled.
# Instead, we call the target function directly in the executor.
# Each worker process starts with its own context, so we don't need
# to preserve the parent context.
target_future = loop.run_in_executor(
self.executor,
_execute_sync_task_in_executor,
target,
tuple(message.args),
kwargs,
)
else:
# For ThreadPoolExecutor, we can use ctx.run with functools.partial
ctx = contextvars.copy_context()
func = functools.partial(target, *message.args, **kwargs)
target_future = loop.run_in_executor(
self.executor,
ctx.run,
func,
)
timeout = message.labels.get("timeout")
if timeout is not None:
if not is_coroutine:
Expand Down
Loading