Skip to content

Commit

Permalink
Add with_ctx keyword argument to darq.task decorator. It tells worker…
Browse files Browse the repository at this point in the history
… to pass its ctx as a first argument to task function
  • Loading branch information
kindermax authored and seedofjoy committed May 29, 2022
1 parent 0bfebc8 commit bac6e12
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 2 deletions.
4 changes: 4 additions & 0 deletions darq/app.py
Expand Up @@ -185,6 +185,7 @@ def task(
max_tries: t.Optional[int] = None,
queue: t.Optional[str] = None,
expires: t.Union[None, AnyTimedelta, UNSET_ARG] = unset_arg,
with_ctx: bool = False,
) -> t.Callable[
[WrappingFunc],
DarqTask[WrappingFunc],
Expand All @@ -200,6 +201,7 @@ def task( # type: ignore[no-untyped-def]
max_tries=None,
queue=None,
expires=unset_arg,
with_ctx=False,
):
"""
:param func: coroutine function
Expand All @@ -212,6 +214,7 @@ def task( # type: ignore[no-untyped-def]
different queue
:param expires: if the task still hasn't started after this
duration, do not run it
:param with_ctx: pass context to the task as first argument
"""
task_queue = queue
task_expires = expires
Expand Down Expand Up @@ -292,6 +295,7 @@ async def delay(*args: t.Any, **kwargs: t.Any) -> t.Optional[Job]:
self.registry.add(Task.new(
coroutine=function, name=name,
keep_result=keep_result, timeout=timeout, max_tries=max_tries,
with_ctx=with_ctx
))

return function
Expand Down
10 changes: 9 additions & 1 deletion darq/worker.py
Expand Up @@ -53,6 +53,7 @@ class Task(t.NamedTuple):
timeout_s: t.Optional[float]
keep_result_s: t.Optional[float]
max_tries: t.Optional[int]
with_ctx: bool

@classmethod
def new(
Expand All @@ -62,10 +63,12 @@ def new(
timeout: t.Optional[SecondsTimedelta] = None,
keep_result: t.Optional[SecondsTimedelta] = None,
max_tries: t.Optional[int] = None,
with_ctx: bool = False,
) -> 'Task':
return cls(
name=name, coroutine=coroutine, timeout_s=to_seconds(timeout),
keep_result_s=to_seconds(keep_result), max_tries=max_tries,
with_ctx=with_ctx
)


Expand Down Expand Up @@ -473,11 +476,16 @@ async def job_failed(exc: Exception) -> None:
'%6.2fs → %s(%s)%s',
(start_ms - enqueue_time_ms) / 1000, ref, s, extra,
)

coro = function.coroutine
if function.with_ctx:
coro = partial(coro, ctx)

# run repr(result) and extra inside try/except as they can
# raise exceptions
try:
async with async_timeout.timeout(timeout_s):
result = await function.coroutine(*args, **kwargs)
result = await coro(*args, **kwargs)
except Exception as e:
exc_extra = getattr(e, 'extra', None)
if callable(exc_extra):
Expand Down
50 changes: 49 additions & 1 deletion tests/test_app.py
Expand Up @@ -110,9 +110,10 @@ async def test_task_parametrized(darq):
keep_result = 0.5
max_tries = 92
queue = 'my_queue'
with_ctx = True
foobar_task = darq.task(
keep_result=keep_result, timeout=timeout,
max_tries=max_tries, queue=queue,
max_tries=max_tries, queue=queue, with_ctx=with_ctx
)(foobar)

task_name = 'tests.test_app.foobar'
Expand All @@ -123,6 +124,7 @@ async def test_task_parametrized(darq):
assert task.coroutine == foobar_task
assert task.timeout_s == timeout
assert task.keep_result_s == keep_result
assert task.with_ctx == with_ctx


async def test_task_self_enqueue(darq, caplog, worker_factory):
Expand Down Expand Up @@ -385,3 +387,49 @@ async def test_expires_param(
**expected_kwargs,
)
await darq.disconnect()


async def foobar_with_ctx(ctx, a: int) -> int:
return 42 + a + ctx['b']


@pytest.mark.parametrize('func_args,func_kwargs,func_ctx,result', [
((1,), {}, {'b': 1}, 44),
((), {'a': 2}, {'b': 1}, 45),
])
async def test_run_task_with_ctx(
func_args, func_kwargs, func_ctx, result,
arq_redis, caplog, worker_factory
):
caplog.set_level(logging.INFO)

async def on_worker_startup(ctx):
ctx.update(func_ctx)

darq = Darq(
redis_settings=redis_settings,
burst=True,
on_startup=on_worker_startup,
)

foobar_with_ctx_task = darq.task(foobar_with_ctx, with_ctx=True)

await darq.connect()

job_id = 'testing'
function_name = 'tests.test_app.foobar_with_ctx'
await foobar_with_ctx_task.apply_async(
func_args, func_kwargs, job_id=job_id
)

worker = worker_factory(darq)
await worker.main()

assert_worker_job_finished(
records=caplog.records,
job_id=job_id,
function_name=function_name,
result=result,
args=func_args,
kwargs=func_kwargs,
)

0 comments on commit bac6e12

Please sign in to comment.