From 065331f0eaa5e2d216c6f3fea6b99cf956e9d2a1 Mon Sep 17 00:00:00 2001 From: sweeneyde Date: Tue, 6 Apr 2021 17:01:36 -0400 Subject: [PATCH 01/15] Add another test for anext --- Lib/test/test_asyncgen.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py index 99464e3d0929fd..0521a9c585049e 100644 --- a/Lib/test/test_asyncgen.py +++ b/Lib/test/test_asyncgen.py @@ -388,6 +388,26 @@ async def consume(): with self.assertRaises(StopAsyncIteration): self.loop.run_until_complete(consume()) + async def test_2(): + g1 = gen() + self.assertEqual(await anext(g1), 1) + self.assertEqual(await anext(g1), 2) + with self.assertRaises(StopAsyncIteration): + await anext(g1) + with self.assertRaises(StopAsyncIteration): + await anext(g1) + + g2 = gen() + self.assertEqual(await anext(g2, "default"), 1) + self.assertEqual(await anext(g2, "default"), 2) + self.assertEqual(await anext(g2, "default"), "default") + self.assertEqual(await anext(g2, "default"), "default") + + return "completed" + + result = self.loop.run_until_complete(test_2()) + self.assertEqual(result, "completed") + def test_async_gen_aiter(self): async def gen(): yield 1 From ddeac2d32e629e80f067b14e00042ef044852a4a Mon Sep 17 00:00:00 2001 From: sweeneyde Date: Tue, 6 Apr 2021 17:20:35 -0400 Subject: [PATCH 02/15] Fix anext(ait, default) returning None --- Objects/iterobject.c | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Objects/iterobject.c b/Objects/iterobject.c index f0c6b799176804..42cfc4ea22f0fb 100644 --- a/Objects/iterobject.c +++ b/Objects/iterobject.c @@ -316,7 +316,10 @@ anextawaitable_traverse(anextawaitableobject *obj, visitproc visit, void *arg) static PyObject * anextawaitable_iternext(anextawaitableobject *obj) { - PyObject *result = PyIter_Next(obj->wrapped); + assert(obj->wrapped != NULL); + unaryfunc getter = Py_TYPE(obj->wrapped)->tp_iternext; + assert(getter != NULL); + PyObject *result = getter(obj->wrapped); if (result != NULL) { return result; } From 134449b2d9b7b079abe2b767dc0c4d9e5ab7126a Mon Sep 17 00:00:00 2001 From: sweeneyde Date: Tue, 6 Apr 2021 20:16:19 -0400 Subject: [PATCH 03/15] Add null check for getter --- Objects/iterobject.c | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/Objects/iterobject.c b/Objects/iterobject.c index 42cfc4ea22f0fb..5f0d1ab89920c0 100644 --- a/Objects/iterobject.c +++ b/Objects/iterobject.c @@ -318,7 +318,11 @@ anextawaitable_iternext(anextawaitableobject *obj) { assert(obj->wrapped != NULL); unaryfunc getter = Py_TYPE(obj->wrapped)->tp_iternext; - assert(getter != NULL); + if (getter == NULL) { + PyErr_SetString(PyExc_TypeError, + "anext() argument was not async iterable."); + return NULL; + } PyObject *result = getter(obj->wrapped); if (result != NULL) { return result; From bfd57ada5be42b627b5259792cc8524b2d8b6630 Mon Sep 17 00:00:00 2001 From: sweeneyde Date: Tue, 6 Apr 2021 20:26:12 -0400 Subject: [PATCH 04/15] Test anext() on generator and python-implemented async iterator class --- Lib/test/test_asyncgen.py | 77 +++++++++++++++++++++++---------------- 1 file changed, 46 insertions(+), 31 deletions(-) diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py index 0521a9c585049e..9d5f9721137f02 100644 --- a/Lib/test/test_asyncgen.py +++ b/Lib/test/test_asyncgen.py @@ -373,40 +373,55 @@ def tearDown(self): asyncio.set_event_loop_policy(None) def test_async_gen_anext(self): - async def gen(): + async def agen(): yield 1 yield 2 - g = gen() - async def consume(): - results = [] - results.append(await anext(g)) - results.append(await anext(g)) - results.append(await anext(g, 'buckle my shoe')) - return results - res = self.loop.run_until_complete(consume()) - self.assertEqual(res, [1, 2, 'buckle my shoe']) - with self.assertRaises(StopAsyncIteration): - self.loop.run_until_complete(consume()) - - async def test_2(): - g1 = gen() - self.assertEqual(await anext(g1), 1) - self.assertEqual(await anext(g1), 2) - with self.assertRaises(StopAsyncIteration): - await anext(g1) - with self.assertRaises(StopAsyncIteration): - await anext(g1) - - g2 = gen() - self.assertEqual(await anext(g2, "default"), 1) - self.assertEqual(await anext(g2, "default"), 2) - self.assertEqual(await anext(g2, "default"), "default") - self.assertEqual(await anext(g2, "default"), "default") - - return "completed" - result = self.loop.run_until_complete(test_2()) - self.assertEqual(result, "completed") + class MyAsyncIter: + """Asynchronously yield 1, then 2.""" + def __init__(self): + self.yielded = 0 + def __aiter__(self): + return self + async def __anext__(self): + if self.yielded >= 2: + raise StopAsyncIteration() + else: + self.yielded += 1 + return self.yielded + + for gen in (agen, MyAsyncIter): + g = gen() + async def consume(): + results = [] + results.append(await anext(g)) + results.append(await anext(g)) + results.append(await anext(g, 'buckle my shoe')) + return results + res = self.loop.run_until_complete(consume()) + self.assertEqual(res, [1, 2, 'buckle my shoe']) + with self.assertRaises(StopAsyncIteration): + self.loop.run_until_complete(consume()) + + async def test_2(): + g1 = gen() + self.assertEqual(await anext(g1), 1) + self.assertEqual(await anext(g1), 2) + with self.assertRaises(StopAsyncIteration): + await anext(g1) + with self.assertRaises(StopAsyncIteration): + await anext(g1) + + g2 = gen() + self.assertEqual(await anext(g2, "default"), 1) + self.assertEqual(await anext(g2, "default"), 2) + self.assertEqual(await anext(g2, "default"), "default") + self.assertEqual(await anext(g2, "default"), "default") + + return "completed" + + result = self.loop.run_until_complete(test_2()) + self.assertEqual(result, "completed") def test_async_gen_aiter(self): async def gen(): From d8c56ffec50586f83eb721167ee09632d2044391 Mon Sep 17 00:00:00 2001 From: sweeneyde Date: Wed, 7 Apr 2021 13:56:45 -0400 Subject: [PATCH 05/15] fix whitespace --- Lib/test/test_asyncgen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py index 9d5f9721137f02..8a2707711d6fe7 100644 --- a/Lib/test/test_asyncgen.py +++ b/Lib/test/test_asyncgen.py @@ -389,7 +389,7 @@ async def __anext__(self): else: self.yielded += 1 return self.yielded - + for gen in (agen, MyAsyncIter): g = gen() async def consume(): From ab54eceb3c7dc2ec37899caa77fd803d6c33123d Mon Sep 17 00:00:00 2001 From: "blurb-it[bot]" <43283697+blurb-it[bot]@users.noreply.github.com> Date: Wed, 7 Apr 2021 18:00:08 +0000 Subject: [PATCH 06/15] =?UTF-8?q?=F0=9F=93=9C=F0=9F=A4=96=20Added=20by=20b?= =?UTF-8?q?lurb=5Fit.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Core and Builtins/2021-04-07-18-00-05.bpo-43751.8fHsqQ.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 Misc/NEWS.d/next/Core and Builtins/2021-04-07-18-00-05.bpo-43751.8fHsqQ.rst diff --git a/Misc/NEWS.d/next/Core and Builtins/2021-04-07-18-00-05.bpo-43751.8fHsqQ.rst b/Misc/NEWS.d/next/Core and Builtins/2021-04-07-18-00-05.bpo-43751.8fHsqQ.rst new file mode 100644 index 00000000000000..75951ae794d106 --- /dev/null +++ b/Misc/NEWS.d/next/Core and Builtins/2021-04-07-18-00-05.bpo-43751.8fHsqQ.rst @@ -0,0 +1 @@ +Fixed a bug where ``anext(ait, default)`` would erroneously return None. \ No newline at end of file From 056d9c61a4b8ac92fb8ea87c33bfa2a54ed1403a Mon Sep 17 00:00:00 2001 From: sweeneyde Date: Wed, 7 Apr 2021 18:01:18 -0400 Subject: [PATCH 07/15] Explicity call am_await() when applicable. --- Objects/iterobject.c | 35 +++++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/Objects/iterobject.c b/Objects/iterobject.c index 5f0d1ab89920c0..0b30fd86609dce 100644 --- a/Objects/iterobject.c +++ b/Objects/iterobject.c @@ -317,10 +317,41 @@ static PyObject * anextawaitable_iternext(anextawaitableobject *obj) { assert(obj->wrapped != NULL); - unaryfunc getter = Py_TYPE(obj->wrapped)->tp_iternext; + PyTypeObject *type = Py_TYPE(obj->wrapped); + /* Consider the following class: + * + * class A: + * async def __anext__(self): + * ... + * a = A() + * + * Then anext(a) should call + * a.__anext__().__await__().__next__() + * + * On the other hand, given + * + * async def agen(): + * yield 1 + * yield 2 + * gen = agen() + * + * Then anext(g) can just call + * g.__anext__().__next__() + */ + if (type->tp_as_async && type->tp_as_async->am_await) { + unaryfunc await_getter = type->tp_as_async->am_await; + PyObject *result = await_getter(obj->wrapped); + if (result == NULL) { + return NULL; + } + type = Py_TYPE(result); + Py_SETREF(obj->wrapped, result); + } + + unaryfunc getter = type->tp_iternext; if (getter == NULL) { PyErr_SetString(PyExc_TypeError, - "anext() argument was not async iterable."); + "__await__() did not return an iterator."); return NULL; } PyObject *result = getter(obj->wrapped); From 58a56fb223d17686223ec5aa24eb6a54a22791a1 Mon Sep 17 00:00:00 2001 From: sweeneyde Date: Thu, 8 Apr 2021 12:53:44 -0400 Subject: [PATCH 08/15] Refactor tests --- Lib/test/test_asyncgen.py | 71 ++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 34 deletions(-) diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py index 8a2707711d6fe7..509bfb142f41d2 100644 --- a/Lib/test/test_asyncgen.py +++ b/Lib/test/test_asyncgen.py @@ -372,11 +372,46 @@ def tearDown(self): self.loop = None asyncio.set_event_loop_policy(None) - def test_async_gen_anext(self): + def check_async_iterator_anext(self, ait_class): + g = ait_class() + async def consume(): + results = [] + results.append(await anext(g)) + results.append(await anext(g)) + results.append(await anext(g, 'buckle my shoe')) + return results + res = self.loop.run_until_complete(consume()) + self.assertEqual(res, [1, 2, 'buckle my shoe']) + with self.assertRaises(StopAsyncIteration): + self.loop.run_until_complete(consume()) + + async def test_2(): + g1 = ait_class() + self.assertEqual(await anext(g1), 1) + self.assertEqual(await anext(g1), 2) + with self.assertRaises(StopAsyncIteration): + await anext(g1) + with self.assertRaises(StopAsyncIteration): + await anext(g1) + + g2 = ait_class() + self.assertEqual(await anext(g2, "default"), 1) + self.assertEqual(await anext(g2, "default"), 2) + self.assertEqual(await anext(g2, "default"), "default") + self.assertEqual(await anext(g2, "default"), "default") + + return "completed" + + result = self.loop.run_until_complete(test_2()) + self.assertEqual(result, "completed") + + def test_async_generator_anext(self): async def agen(): yield 1 yield 2 + self.check_async_iterator_anext(agen) + def test_python_async_iterator_anext(self): class MyAsyncIter: """Asynchronously yield 1, then 2.""" def __init__(self): @@ -389,39 +424,7 @@ async def __anext__(self): else: self.yielded += 1 return self.yielded - - for gen in (agen, MyAsyncIter): - g = gen() - async def consume(): - results = [] - results.append(await anext(g)) - results.append(await anext(g)) - results.append(await anext(g, 'buckle my shoe')) - return results - res = self.loop.run_until_complete(consume()) - self.assertEqual(res, [1, 2, 'buckle my shoe']) - with self.assertRaises(StopAsyncIteration): - self.loop.run_until_complete(consume()) - - async def test_2(): - g1 = gen() - self.assertEqual(await anext(g1), 1) - self.assertEqual(await anext(g1), 2) - with self.assertRaises(StopAsyncIteration): - await anext(g1) - with self.assertRaises(StopAsyncIteration): - await anext(g1) - - g2 = gen() - self.assertEqual(await anext(g2, "default"), 1) - self.assertEqual(await anext(g2, "default"), 2) - self.assertEqual(await anext(g2, "default"), "default") - self.assertEqual(await anext(g2, "default"), "default") - - return "completed" - - result = self.loop.run_until_complete(test_2()) - self.assertEqual(result, "completed") + self.check_async_iterator_anext(MyAsyncIter) def test_async_gen_aiter(self): async def gen(): From a113e19ec4165b73fe8652259c4351452fbd760b Mon Sep 17 00:00:00 2001 From: sweeneyde Date: Thu, 8 Apr 2021 13:15:15 -0400 Subject: [PATCH 09/15] Add test for async generator with __anext__ = types.coroutine(...) --- Lib/test/test_asyncgen.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py index 509bfb142f41d2..ac3d52d26bde7b 100644 --- a/Lib/test/test_asyncgen.py +++ b/Lib/test/test_asyncgen.py @@ -426,6 +426,25 @@ async def __anext__(self): return self.yielded self.check_async_iterator_anext(MyAsyncIter) + def test_python_async_iterator_types_coroutine_anext(self): + import types + class MyAsyncIterWithTypesCoro: + """Asynchronously yield 1, then 2.""" + def __init__(self): + self.yielded = 0 + def __aiter__(self): + return self + @types.coroutine + def __anext__(self): + if False: + yield "this is a generator-based coroutine" + if self.yielded >= 2: + raise StopAsyncIteration() + else: + self.yielded += 1 + return self.yielded + self.check_async_iterator_anext(MyAsyncIterWithTypesCoro) + def test_async_gen_aiter(self): async def gen(): yield 1 From 403917057b68840e400848c908d71fdd97fba75b Mon Sep 17 00:00:00 2001 From: Dennis Sweeney <36520290+sweeneyde@users.noreply.github.com> Date: Thu, 8 Apr 2021 20:33:53 -0400 Subject: [PATCH 10/15] Update Objects/iterobject.c Fix inconsistent variable names in comment Co-authored-by: Joshua Bronson --- Objects/iterobject.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Objects/iterobject.c b/Objects/iterobject.c index 0b30fd86609dce..c1bd184d1d59b9 100644 --- a/Objects/iterobject.c +++ b/Objects/iterobject.c @@ -335,8 +335,8 @@ anextawaitable_iternext(anextawaitableobject *obj) * yield 2 * gen = agen() * - * Then anext(g) can just call - * g.__anext__().__next__() + * Then anext(gen) can just call + * gen.__anext__().__next__() */ if (type->tp_as_async && type->tp_as_async->am_await) { unaryfunc await_getter = type->tp_as_async->am_await; From 0130217476aa3e41876c36a7320b8b6a0603a137 Mon Sep 17 00:00:00 2001 From: sweeneyde Date: Fri, 9 Apr 2021 23:00:23 -0400 Subject: [PATCH 11/15] fix minor issues; add a test case for bad awaitables --- Lib/test/test_asyncgen.py | 21 +++++++++++++++++++++ Objects/iterobject.c | 7 +++---- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py index ac3d52d26bde7b..a20ebc23657d0c 100644 --- a/Lib/test/test_asyncgen.py +++ b/Lib/test/test_asyncgen.py @@ -495,6 +495,27 @@ async def call_with_wrong_type_args(): with self.assertRaises(TypeError): self.loop.run_until_complete(call_with_wrong_type_args()) + class BadAwaitable: + def __await__(self): + return 42 + class MyAsyncIter: + def __aiter__(self): + return self + def __anext__(self): + return BadAwaitable() + + async def call_not_awaitable(): + regex = r"__await__.*iterator" + awaitable = anext(MyAsyncIter(), "default") + with self.assertRaisesRegex(TypeError, regex): + await awaitable + awaitable = anext(MyAsyncIter()) + with self.assertRaisesRegex(TypeError, regex): + await awaitable + return "completed" + result = self.loop.run_until_complete(call_not_awaitable()) + self.assertEqual(result, "completed") + def test_aiter_bad_args(self): async def gen(): yield 1 diff --git a/Objects/iterobject.c b/Objects/iterobject.c index 0b30fd86609dce..a6337532314a7d 100644 --- a/Objects/iterobject.c +++ b/Objects/iterobject.c @@ -325,7 +325,7 @@ anextawaitable_iternext(anextawaitableobject *obj) * ... * a = A() * - * Then anext(a) should call + * Then `await anext(a)` should call * a.__anext__().__await__().__next__() * * On the other hand, given @@ -335,12 +335,11 @@ anextawaitable_iternext(anextawaitableobject *obj) * yield 2 * gen = agen() * - * Then anext(g) can just call + * Then `await anext(g)` can just call * g.__anext__().__next__() */ if (type->tp_as_async && type->tp_as_async->am_await) { - unaryfunc await_getter = type->tp_as_async->am_await; - PyObject *result = await_getter(obj->wrapped); + PyObject *result = type->tp_as_async->am_await(obj->wrapped); if (result == NULL) { return NULL; } From 5607d839eed58e20dd6b4dc8ba50b34f0afefce7 Mon Sep 17 00:00:00 2001 From: sweeneyde Date: Sat, 10 Apr 2021 00:39:33 -0400 Subject: [PATCH 12/15] Add more tests for corner cases --- Lib/test/test_asyncgen.py | 76 ++++++++++++++++++++++++++++++++------- 1 file changed, 64 insertions(+), 12 deletions(-) diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py index a20ebc23657d0c..77c15c02bc8914 100644 --- a/Lib/test/test_asyncgen.py +++ b/Lib/test/test_asyncgen.py @@ -488,23 +488,27 @@ async def call_with_too_many_args(): await anext(gen(), 1, 3) async def call_with_wrong_type_args(): await anext(1, gen()) + async def call_with_kwarg(): + await anext(aiterator=gen()) with self.assertRaises(TypeError): self.loop.run_until_complete(call_with_too_few_args()) with self.assertRaises(TypeError): self.loop.run_until_complete(call_with_too_many_args()) with self.assertRaises(TypeError): self.loop.run_until_complete(call_with_wrong_type_args()) - - class BadAwaitable: - def __await__(self): - return 42 - class MyAsyncIter: - def __aiter__(self): - return self - def __anext__(self): - return BadAwaitable() - - async def call_not_awaitable(): + with self.assertRaises(TypeError): + self.loop.run_until_complete(call_with_kwarg()) + + def test_anext_bad_await(self): + async def bad_awaitable(): + class BadAwaitable: + def __await__(self): + return 42 + class MyAsyncIter: + def __aiter__(self): + return self + def __anext__(self): + return BadAwaitable() regex = r"__await__.*iterator" awaitable = anext(MyAsyncIter(), "default") with self.assertRaisesRegex(TypeError, regex): @@ -513,7 +517,55 @@ async def call_not_awaitable(): with self.assertRaisesRegex(TypeError, regex): await awaitable return "completed" - result = self.loop.run_until_complete(call_not_awaitable()) + result = self.loop.run_until_complete(bad_awaitable()) + self.assertEqual(result, "completed") + + async def check_anext_returning_iterator(self, aiter_class): + awaitable = anext(aiter_class(), "default") + with self.assertRaises(TypeError): + await awaitable + awaitable = anext(aiter_class()) + with self.assertRaises(TypeError): + await awaitable + return "completed" + + def test_anext_return_iterator(self): + class WithIterAnext: + def __aiter__(self): + return self + def __anext__(self): + return iter("abc") + result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithIterAnext)) + self.assertEqual(result, "completed") + + def test_anext_return_generator(self): + class WithGenAnext: + def __aiter__(self): + return self + def __anext__(self): + yield + result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithGenAnext)) + self.assertEqual(result, "completed") + + def test_anext_await_raises(self): + class RaisingAwaitable: + def __await__(self): + raise ZeroDivisionError() + yield + class WithRaisingAwaitableAnext: + def __aiter__(self): + return self + def __anext__(self): + return RaisingAwaitable() + async def do_test(): + awaitable = anext(WithRaisingAwaitableAnext()) + with self.assertRaises(ZeroDivisionError): + await awaitable + awaitable = anext(WithRaisingAwaitableAnext(), "default") + with self.assertRaises(ZeroDivisionError): + await awaitable + return "completed" + result = self.loop.run_until_complete(do_test()) self.assertEqual(result, "completed") def test_aiter_bad_args(self): From 11fe9c2a53990d91dc5e41a39c16a06a08bd5eee Mon Sep 17 00:00:00 2001 From: sweeneyde Date: Sat, 10 Apr 2021 01:56:55 -0400 Subject: [PATCH 13/15] make the corner-case tests pass --- Objects/iterobject.c | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/Objects/iterobject.c b/Objects/iterobject.c index e467a514a3c4b4..dc0a85bf76c49c 100644 --- a/Objects/iterobject.c +++ b/Objects/iterobject.c @@ -316,8 +316,6 @@ anextawaitable_traverse(anextawaitableobject *obj, visitproc visit, void *arg) static PyObject * anextawaitable_iternext(anextawaitableobject *obj) { - assert(obj->wrapped != NULL); - PyTypeObject *type = Py_TYPE(obj->wrapped); /* Consider the following class: * * class A: @@ -338,22 +336,29 @@ anextawaitable_iternext(anextawaitableobject *obj) * Then `await anext(gen)` can just call * gen.__anext__().__next__() */ - if (type->tp_as_async && type->tp_as_async->am_await) { - PyObject *result = type->tp_as_async->am_await(obj->wrapped); - if (result == NULL) { + assert(obj->wrapped != NULL); + PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped); + if (awaitable == NULL) { + return NULL; + } + if (Py_TYPE(awaitable)->tp_iternext == NULL) { + if (Py_TYPE(awaitable)->tp_as_async == NULL || + Py_TYPE(awaitable)->tp_as_async->am_await == NULL) + { + PyErr_SetString(PyExc_TypeError, + "__anext__ returned a non-awaitable."); + return NULL; + } + unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await; + PyObject *new_awaitable = getter(awaitable); + Py_SETREF(awaitable, new_awaitable); + if (Py_TYPE(awaitable)->tp_iternext == NULL) { + PyErr_SetString(PyExc_TypeError, + "__await__ returned a non-iterable"); return NULL; } - type = Py_TYPE(result); - Py_SETREF(obj->wrapped, result); - } - - unaryfunc getter = type->tp_iternext; - if (getter == NULL) { - PyErr_SetString(PyExc_TypeError, - "__await__() did not return an iterator."); - return NULL; } - PyObject *result = getter(obj->wrapped); + PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable); if (result != NULL) { return result; } From 6a6fa66b6917d0d08f20df7a214b8053b36d75a1 Mon Sep 17 00:00:00 2001 From: sweeneyde Date: Sat, 10 Apr 2021 21:15:14 -0400 Subject: [PATCH 14/15] Add null check for new_awaitable; remove null check for am_await because awaitable is always a coroutine in that case. --- Objects/iterobject.c | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/Objects/iterobject.c b/Objects/iterobject.c index dc0a85bf76c49c..62bb1ac72d9931 100644 --- a/Objects/iterobject.c +++ b/Objects/iterobject.c @@ -342,15 +342,16 @@ anextawaitable_iternext(anextawaitableobject *obj) return NULL; } if (Py_TYPE(awaitable)->tp_iternext == NULL) { - if (Py_TYPE(awaitable)->tp_as_async == NULL || - Py_TYPE(awaitable)->tp_as_async->am_await == NULL) - { - PyErr_SetString(PyExc_TypeError, - "__anext__ returned a non-awaitable."); - return NULL; - } + /* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator, + * or an iterator. Of these, only coroutines lack tp_iternext. + */ + assert(PyCoro_CheckExact(awaitable)); unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await; PyObject *new_awaitable = getter(awaitable); + if (new_awaitable == NULL) { + Py_DECREF(awaitable); + return NULL; + } Py_SETREF(awaitable, new_awaitable); if (Py_TYPE(awaitable)->tp_iternext == NULL) { PyErr_SetString(PyExc_TypeError, From 857a2b0c605897e284115197263a70a1d57fb031 Mon Sep 17 00:00:00 2001 From: sweeneyde Date: Sat, 10 Apr 2021 23:03:39 -0400 Subject: [PATCH 15/15] Avoid leaking awaitable --- Objects/iterobject.c | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Objects/iterobject.c b/Objects/iterobject.c index 62bb1ac72d9931..c316de4e32c6b2 100644 --- a/Objects/iterobject.c +++ b/Objects/iterobject.c @@ -356,10 +356,12 @@ anextawaitable_iternext(anextawaitableobject *obj) if (Py_TYPE(awaitable)->tp_iternext == NULL) { PyErr_SetString(PyExc_TypeError, "__await__ returned a non-iterable"); + Py_DECREF(awaitable); return NULL; } } PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable); + Py_DECREF(awaitable); if (result != NULL) { return result; }