From 7acda4e19514c67df755e5b47fa9124281f99338 Mon Sep 17 00:00:00 2001 From: Mark Gordon Date: Thu, 10 Jun 2021 23:18:39 -0700 Subject: [PATCH] Added support for manipulating context in asyncio tasks Signed-off-by: Mark Gordon --- Doc/library/asyncio-task.rst | 7 +-- Lib/asyncio/runners.py | 4 +- Lib/asyncio/tasks.py | 33 +++++++++++-- Lib/test/test_asyncio/test_runners.py | 26 ++++++++++ Lib/test/test_asyncio/test_tasks.py | 70 ++++++++++++++++++++++++++- Modules/_asynciomodule.c | 44 +++++++++++++++-- Modules/clinic/_asynciomodule.c.h | 56 ++++++++++++++++++--- 7 files changed, 218 insertions(+), 22 deletions(-) diff --git a/Doc/library/asyncio-task.rst b/Doc/library/asyncio-task.rst index bbdef3345a4d42f..e8d162ed6d5841e 100644 --- a/Doc/library/asyncio-task.rst +++ b/Doc/library/asyncio-task.rst @@ -847,9 +847,10 @@ Task Object APIs except :meth:`Future.set_result` and :meth:`Future.set_exception`. - Tasks support the :mod:`contextvars` module. When a Task - is created it copies the current context and later runs its - coroutine in the copied context. + Tasks support the :mod:`contextvars` module. Tasks can be run under + any context, defaulting to a copy of the context that created them. This + context will later be used to run its coroutines. The context associated + with a task can be modified using `:meth:`asyncio.run_in_context`. .. versionchanged:: 3.7 Added support for the :mod:`contextvars` module. diff --git a/Lib/asyncio/runners.py b/Lib/asyncio/runners.py index 9a5e9a48479ef78..6417e38b5cbd300 100644 --- a/Lib/asyncio/runners.py +++ b/Lib/asyncio/runners.py @@ -5,7 +5,7 @@ from . import tasks -def run(main, *, debug=None): +def run(main, *, debug=None, **task_kwargs): """Execute the coroutine and return the result. This function runs the passed coroutine, taking care of @@ -41,7 +41,7 @@ async def main(): events.set_event_loop(loop) if debug is not None: loop.set_debug(debug) - return loop.run_until_complete(main) + return loop.run_until_complete(tasks.Task(main, loop=loop, **task_kwargs)) finally: try: _cancel_all_tasks(loop) diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py index 9a9d0d6e3cc269d..39b0c138a59fe0b 100644 --- a/Lib/asyncio/tasks.py +++ b/Lib/asyncio/tasks.py @@ -5,7 +5,7 @@ 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', 'wait', 'wait_for', 'as_completed', 'sleep', 'gather', 'shield', 'ensure_future', 'run_coroutine_threadsafe', - 'current_task', 'all_tasks', + 'current_task', 'all_tasks', 'run_in_context', '_register_task', '_unregister_task', '_enter_task', '_leave_task', ) @@ -71,6 +71,24 @@ def _set_task_name(task, name): set_name(name) +async def run_in_context(context, coro): + """Run the coroutine coro in the passed context. + + This method can be used to run coro in an alternate context within the + calling Task. This is the asyncio analog of contextvars.Context.run. + """ + task = current_task() + if task is None: + raise RuntimeError("No running task") + prev_context = task._set_context(context) + await __sleep0() + try: + return await coro + finally: + task._set_context(prev_context) + await __sleep0() + + class Task(futures._PyFuture): # Inherit Python Task implementation # from a Python Future implementation. @@ -89,7 +107,7 @@ class Task(futures._PyFuture): # Inherit Python Task implementation # status is still pending _log_destroy_pending = True - def __init__(self, coro, *, loop=None, name=None): + def __init__(self, coro, *, loop=None, name=None, context=None): super().__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] @@ -107,7 +125,11 @@ def __init__(self, coro, *, loop=None, name=None): self._must_cancel = False self._fut_waiter = None self._coro = coro - self._context = contextvars.copy_context() + + if context is None: + self._context = contextvars.copy_context() + else: + self._context = context self._loop.call_soon(self.__step, context=self._context) _register_task(self) @@ -129,6 +151,11 @@ def __class_getitem__(cls, type): def _repr_info(self): return base_tasks._task_repr_info(self) + def _set_context(self, context): + prev_context = self._context + self._context = context + return prev_context + def get_coro(self): return self._coro diff --git a/Lib/test/test_asyncio/test_runners.py b/Lib/test/test_asyncio/test_runners.py index b9ae02dc3c04e09..8e42ca3f6f9874f 100644 --- a/Lib/test/test_asyncio/test_runners.py +++ b/Lib/test/test_asyncio/test_runners.py @@ -1,4 +1,5 @@ import asyncio +import contextvars import unittest from unittest import mock @@ -180,3 +181,28 @@ async def main(): self.assertIsNone(spinner.ag_frame) self.assertFalse(spinner.ag_running) + + def test_asyncio_run_task_creation(self): + cvar = contextvars.ContextVar('cvar', default='nope') + + context = contextvars.Context() + context.run(cvar.set, 'maybe') + + async def check_explicit(name_expect, var_expect, var_update): + self.assertEqual(name_expect, asyncio.current_task().get_name()) + self.assertEqual(var_expect, cvar.get()) + cvar.set(var_update) + + # Verify name and context passed to task + asyncio.run(check_explicit('my-task', 'maybe', 'sometimes'), name='my-task', context=context) + self.assertEqual(context.run(cvar.get), 'sometimes') + + async def check_default(var_expect, var_update): + self.assertTrue(asyncio.current_task().get_name().startswith("Task-")) + self.assertEqual(var_expect, cvar.get()) + cvar.set(var_update) + + # Verify default name and context (copy of current context) used otherwise + cvar.set('seldom') + asyncio.run(check_default('seldom', 'often')) + self.assertEqual(cvar.get(), 'seldom') diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py index a9e4cf53566ca90..546f40af20bff94 100644 --- a/Lib/test/test_asyncio/test_tasks.py +++ b/Lib/test/test_asyncio/test_tasks.py @@ -108,8 +108,8 @@ class BaseTaskTests: Task = None Future = None - def new_task(self, loop, coro, name='TestTask'): - return self.__class__.Task(coro, loop=loop, name=name) + def new_task(self, loop, coro, name='TestTask', context=None): + return self.__class__.Task(coro, loop=loop, name=name, context=context) def new_future(self, loop): return self.__class__.Future(loop=loop) @@ -2860,6 +2860,72 @@ async def main(): self.assertEqual(cvar.get(), -1) + def test_context_4(self): + # Test specifying context + cvar = contextvars.ContextVar('cvar', default='nope') + + context = contextvars.Context() + context.run(cvar.set, 'maybe') + + async def sub(expect, update): + self.assertEqual(cvar.get(), expect) + cvar.set(update) + + async def main(): + self.assertEqual(cvar.get(), 'maybe') + + await sub('maybe', 'always') + self.assertEqual(cvar.get(), 'always') + + await self.new_task(loop, sub('always', 'never')) + self.assertEqual(cvar.get(), 'always') + + await self.new_task(loop, sub('always', 'never'), context=context) + self.assertEqual(cvar.get(), 'never') + + loop = asyncio.new_event_loop() + try: + task = self.new_task(loop, main(), context=context) + loop.run_until_complete(task) + finally: + loop.close() + + self.assertEqual(cvar.get(), 'nope') + self.assertEqual(context.run(cvar.get), 'never') + + def test_run_in_context(self): + # Test run_in_context behavior + cvar = contextvars.ContextVar('cvar', default='nope') + + context = contextvars.Context() + context.run(cvar.set, 'maybe') + + async def sub(update, parent_task): + self.assertIs(parent_task, asyncio.current_task()) + value = cvar.get() + cvar.set(update) + return value + + + async def main(): + self.assertEqual(cvar.get(), 'maybe') + sub_context = context.copy() + + cvar.set('never') + self.assertEqual(await asyncio.run_in_context(sub_context, sub('always', asyncio.current_task())), 'maybe') + self.assertEqual(cvar.get(), 'never') + self.assertEqual(sub_context.run(cvar.get), 'always') + + loop = asyncio.new_event_loop() + try: + task = self.new_task(loop, main(), context=context) + loop.run_until_complete(task) + finally: + loop.close() + + self.assertEqual(cvar.get(), 'nope') + self.assertEqual(context.run(cvar.get), 'never') + def test_get_coro(self): loop = asyncio.new_event_loop() coro = coroutine_function() diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c index a4d5d4551e9b0af..1715169bd10a655 100644 --- a/Modules/_asynciomodule.c +++ b/Modules/_asynciomodule.c @@ -2009,14 +2009,15 @@ _asyncio.Task.__init__ * loop: object = None name: object = None + context: object = None A coroutine wrapped in a Future. [clinic start generated code]*/ static int _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop, - PyObject *name) -/*[clinic end generated code: output=88b12b83d570df50 input=352a3137fe60091d]*/ + PyObject *name, PyObject *context) +/*[clinic end generated code: output=49ac96fe33d0e5c7 input=924522490c8ce825]*/ { if (future_init((FutureObj*)self, loop)) { return -1; @@ -2034,9 +2035,14 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop, return -1; } - Py_XSETREF(self->task_context, PyContext_CopyCurrent()); - if (self->task_context == NULL) { - return -1; + if (context == Py_None) { + Py_XSETREF(self->task_context, PyContext_CopyCurrent()); + if (self->task_context == NULL) { + return -1; + } + } else { + Py_INCREF(context); + Py_XSETREF(self->task_context, context); } Py_CLEAR(self->task_fut_waiter); @@ -2379,6 +2385,33 @@ _asyncio_Task_set_name(TaskObj *self, PyObject *value) Py_RETURN_NONE; } +/*[clinic input] +_asyncio.Task._set_context + + context: object + +Set the context associated with the task. + +This does not change the current thread context and only affects the thread +context of later callbacks. Returns the previously context attached to the task. +[clinic start generated code]*/ + +static PyObject * +_asyncio_Task__set_context_impl(TaskObj *self, PyObject *context) +/*[clinic end generated code: output=46ac5ea28ccfac3b input=36a0652ac2b5f671]*/ +{ + if (context == Py_None) { + PyErr_SetString( + PyExc_RuntimeError, "expected valid context"); + return NULL; + } + + PyObject *prev_context = self->task_context; + Py_INCREF(context); + self->task_context = context; + return prev_context; +} + static void TaskObj_finalize(TaskObj *task) { @@ -2471,6 +2504,7 @@ static PyMethodDef TaskType_methods[] = { _ASYNCIO_TASK__REPR_INFO_METHODDEF _ASYNCIO_TASK_GET_NAME_METHODDEF _ASYNCIO_TASK_SET_NAME_METHODDEF + _ASYNCIO_TASK__SET_CONTEXT_METHODDEF _ASYNCIO_TASK_GET_CORO_METHODDEF {"__class_getitem__", task_cls_getitem, METH_O|METH_CLASS, NULL}, {NULL, NULL} /* Sentinel */ diff --git a/Modules/clinic/_asynciomodule.c.h b/Modules/clinic/_asynciomodule.c.h index c472e652fb7c566..bd437201eb360b4 100644 --- a/Modules/clinic/_asynciomodule.c.h +++ b/Modules/clinic/_asynciomodule.c.h @@ -310,28 +310,29 @@ _asyncio_Future__repr_info(FutureObj *self, PyObject *Py_UNUSED(ignored)) } PyDoc_STRVAR(_asyncio_Task___init____doc__, -"Task(coro, *, loop=None, name=None)\n" +"Task(coro, *, loop=None, name=None, context=None)\n" "--\n" "\n" "A coroutine wrapped in a Future."); static int _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop, - PyObject *name); + PyObject *name, PyObject *context); static int _asyncio_Task___init__(PyObject *self, PyObject *args, PyObject *kwargs) { int return_value = -1; - static const char * const _keywords[] = {"coro", "loop", "name", NULL}; + static const char * const _keywords[] = {"coro", "loop", "name", "context", NULL}; static _PyArg_Parser _parser = {NULL, _keywords, "Task", 0}; - PyObject *argsbuf[3]; + PyObject *argsbuf[4]; PyObject * const *fastargs; Py_ssize_t nargs = PyTuple_GET_SIZE(args); Py_ssize_t noptargs = nargs + (kwargs ? PyDict_GET_SIZE(kwargs) : 0) - 1; PyObject *coro; PyObject *loop = Py_None; PyObject *name = Py_None; + PyObject *context = Py_None; fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 1, 1, 0, argsbuf); if (!fastargs) { @@ -347,9 +348,15 @@ _asyncio_Task___init__(PyObject *self, PyObject *args, PyObject *kwargs) goto skip_optional_kwonly; } } - name = fastargs[2]; + if (fastargs[2]) { + name = fastargs[2]; + if (!--noptargs) { + goto skip_optional_kwonly; + } + } + context = fastargs[3]; skip_optional_kwonly: - return_value = _asyncio_Task___init___impl((TaskObj *)self, coro, loop, name); + return_value = _asyncio_Task___init___impl((TaskObj *)self, coro, loop, name, context); exit: return return_value; @@ -611,6 +618,41 @@ PyDoc_STRVAR(_asyncio_Task_set_name__doc__, #define _ASYNCIO_TASK_SET_NAME_METHODDEF \ {"set_name", (PyCFunction)_asyncio_Task_set_name, METH_O, _asyncio_Task_set_name__doc__}, +PyDoc_STRVAR(_asyncio_Task__set_context__doc__, +"_set_context($self, /, context)\n" +"--\n" +"\n" +"Set the context associated with the task.\n" +"\n" +"This does not change the current thread context and only affects the thread\n" +"context of later callbacks. Returns the previously context attached to the task."); + +#define _ASYNCIO_TASK__SET_CONTEXT_METHODDEF \ + {"_set_context", (PyCFunction)(void(*)(void))_asyncio_Task__set_context, METH_FASTCALL|METH_KEYWORDS, _asyncio_Task__set_context__doc__}, + +static PyObject * +_asyncio_Task__set_context_impl(TaskObj *self, PyObject *context); + +static PyObject * +_asyncio_Task__set_context(TaskObj *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) +{ + PyObject *return_value = NULL; + static const char * const _keywords[] = {"context", NULL}; + static _PyArg_Parser _parser = {NULL, _keywords, "_set_context", 0}; + PyObject *argsbuf[1]; + PyObject *context; + + args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, 1, 1, 0, argsbuf); + if (!args) { + goto exit; + } + context = args[0]; + return_value = _asyncio_Task__set_context_impl(self, context); + +exit: + return return_value; +} + PyDoc_STRVAR(_asyncio__get_running_loop__doc__, "_get_running_loop($module, /)\n" "--\n" @@ -871,4 +913,4 @@ _asyncio__leave_task(PyObject *module, PyObject *const *args, Py_ssize_t nargs, exit: return return_value; } -/*[clinic end generated code: output=0d127162ac92e0c0 input=a9049054013a1b77]*/ +/*[clinic end generated code: output=02e21cc39188c337 input=a9049054013a1b77]*/