Skip to content
This repository has been archived by the owner on Nov 23, 2017. It is now read-only.

Add asyncio.run() and asyncio.run_forever() functions. #465

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .futures import *
from .locks import *
from .protocols import *
from .runners import *
from .queues import *
from .streams import *
from .subprocess import *
Expand All @@ -36,6 +37,7 @@
futures.__all__ +
locks.__all__ +
protocols.__all__ +
runners.__all__ +
queues.__all__ +
streams.__all__ +
subprocess.__all__ +
Expand Down
148 changes: 148 additions & 0 deletions asyncio/runners.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""asyncio.run() and asyncio.run_forever() functions."""

__all__ = ['run', 'run_forever']

import inspect
import threading

from . import coroutines
from . import events


def _cleanup(loop):
try:
# `shutdown_asyncgens` was added in Python 3.6; not all
# event loops might support it.
shutdown_asyncgens = loop.shutdown_asyncgens
except AttributeError:
pass
else:
loop.run_until_complete(shutdown_asyncgens())
finally:
events.set_event_loop(None)
loop.close()


def run(main, *, debug=False):
"""Run a coroutine.

This function runs the passed coroutine, taking care of
managing the asyncio event loop and finalizing asynchronous
generators.

This function must be called from the main thread, and it
cannot be called when another asyncio event loop is running.

If debug is True, the event loop will be run in debug mode.

This function should be used as a main entry point for
asyncio programs, and should not be used to call asynchronous
APIs.

Example:

async def main():
await asyncio.sleep(1)
print('hello')

asyncio.run(main())
"""
if events._get_running_loop() is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe these two checks (no running loop and main thread) that appear here and in run_forever() could be factored out to a helper function like you did for _cleanup(loop)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to customize the error message for each function, so I guess a little bit of copy/paste is fine.

raise RuntimeError(
"asyncio.run() cannot be called from a running event loop")
if not isinstance(threading.current_thread(), threading._MainThread):
raise RuntimeError(
"asyncio.run() must be called from the main thread")
if not coroutines.iscoroutine(main):
raise ValueError("a coroutine was expected, got {!r}".format(main))

loop = events.new_event_loop()
try:
events.set_event_loop(loop)

if debug:
loop.set_debug(True)

return loop.run_until_complete(main)
finally:
_cleanup(loop)


def run_forever(main, *, debug=False):
"""Run asyncio loop.

main must be an asynchronous generator with one yield, separating
program initialization from cleanup logic.

If debug is True, the event loop will be run in debug mode.

This function should be used as a main entry point for
asyncio programs, and should not be used to call asynchronous
APIs.

Example:

async def main():
server = await asyncio.start_server(...)
try:
yield # <- Let event loop run forever.
except KeyboardInterrupt:
print('^C received; exiting.')
finally:
server.close()
await server.wait_closed()

asyncio.run_forever(main())
"""
if not hasattr(inspect, 'isasyncgen'):
raise NotImplementedError

if events._get_running_loop() is not None:
raise RuntimeError(
"asyncio.run_forever() cannot be called from a running event loop")
if not isinstance(threading.current_thread(), threading._MainThread):
raise RuntimeError(
"asyncio.run_forever() must be called from the main thread")
if not inspect.isasyncgen(main):
raise ValueError(
"an asynchronous generator was expected, got {!r}".format(main))

one_yield_msg = ("asyncio.run_forever() supports only "
"asynchronous generators with one empty yield")
loop = events.new_event_loop()
try:
events.set_event_loop(loop)
if debug:
loop.set_debug(True)

ret = None
try:
ret = loop.run_until_complete(main.asend(None))
except StopAsyncIteration as ex:
return
if ret is not None:
raise RuntimeError(one_yield_msg)

yielded_twice = False
try:
loop.run_forever()
except BaseException as ex:
try:
loop.run_until_complete(main.athrow(ex))
except StopAsyncIteration as ex:
pass
else:
yielded_twice = True
else:
try:
loop.run_until_complete(main.asend(None))
except StopAsyncIteration as ex:
pass
else:
yielded_twice = True

if yielded_twice:
raise RuntimeError(one_yield_msg)

finally:
_cleanup(loop)
4 changes: 4 additions & 0 deletions runtests.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ def list_dir(prefix, dir):
print("Skipping '{0}': need at least Python 3.5".format(modname),
file=sys.stderr)
continue
if modname == 'test_runner' and (sys.version_info < (3, 6)):
print("Skipping '{0}': need at least Python 3.6".format(modname),
file=sys.stderr)
continue
try:
loader = importlib.machinery.SourceFileLoader(modname, sourcefile)
mods.append((loader.load_module(), sourcefile))
Expand Down
223 changes: 223 additions & 0 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
"""Tests asyncio.run() and asyncio.run_forever()."""

import asyncio
import unittest
import sys

from unittest import mock


class TestPolicy(asyncio.AbstractEventLoopPolicy):

def __init__(self, loop_factory):
self.loop_factory = loop_factory
self.loop = None

def get_event_loop(self):
# shouldn't ever be called by asyncio.run()
# or asyncio.run_forever()
raise RuntimeError

def new_event_loop(self):
return self.loop_factory()

def set_event_loop(self, loop):
if loop is not None:
# we want to check if the loop is closed
# in BaseTest.tearDown
self.loop = loop


class BaseTest(unittest.TestCase):

def new_loop(self):
loop = asyncio.BaseEventLoop()
loop._process_events = mock.Mock()
loop._selector = mock.Mock()
loop._selector.select.return_value = ()
loop.shutdown_ag_run = False

async def shutdown_asyncgens():
loop.shutdown_ag_run = True
loop.shutdown_asyncgens = shutdown_asyncgens

return loop

def setUp(self):
super().setUp()

policy = TestPolicy(self.new_loop)
asyncio.set_event_loop_policy(policy)

def tearDown(self):
policy = asyncio.get_event_loop_policy()
if policy.loop is not None:
self.assertTrue(policy.loop.is_closed())
self.assertTrue(policy.loop.shutdown_ag_run)

asyncio.set_event_loop_policy(None)
super().tearDown()


class RunTests(BaseTest):

def test_asyncio_run_return(self):
async def main():
await asyncio.sleep(0)
return 42

self.assertEqual(asyncio.run(main()), 42)

def test_asyncio_run_raises(self):
async def main():
await asyncio.sleep(0)
raise ValueError('spam')

with self.assertRaisesRegex(ValueError, 'spam'):
asyncio.run(main())

def test_asyncio_run_only_coro(self):
for o in {1, lambda: None}:
with self.subTest(obj=o), \
self.assertRaisesRegex(ValueError,
'a coroutine was expected'):
asyncio.run(o)

def test_asyncio_run_debug(self):
async def main(expected):
loop = asyncio.get_event_loop()
self.assertIs(loop.get_debug(), expected)

asyncio.run(main(False))
asyncio.run(main(True), debug=True)

def test_asyncio_run_from_running_loop(self):
async def main():
asyncio.run(main())

with self.assertRaisesRegex(RuntimeError,
'cannot be called from a running'):
asyncio.run(main())


class RunForeverTests(BaseTest):

def stop_soon(self, *, exc=None):
loop = asyncio.get_event_loop()

if exc:
def throw():
raise exc
loop.call_later(0.01, throw)
else:
loop.call_later(0.01, loop.stop)

def test_asyncio_run_forever_return(self):
async def main():
if 0:
yield
return

self.assertIsNone(asyncio.run_forever(main()))

def test_asyncio_run_forever_non_none_yield(self):
async def main():
yield 1

with self.assertRaisesRegex(RuntimeError, 'one empty yield'):
self.assertIsNone(asyncio.run_forever(main()))

def test_asyncio_run_forever_try_finally(self):
DONE = 0

async def main():
nonlocal DONE
self.stop_soon()
try:
yield
finally:
DONE += 1

asyncio.run_forever(main())
self.assertEqual(DONE, 1)

def test_asyncio_run_forever_raises_before_yield(self):
async def main():
await asyncio.sleep(0)
raise ValueError('spam')
yield

with self.assertRaisesRegex(ValueError, 'spam'):
asyncio.run_forever(main())

def test_asyncio_run_forever_raises_after_yield(self):
async def main():
self.stop_soon()
yield
raise ValueError('spam')

with self.assertRaisesRegex(ValueError, 'spam'):
asyncio.run_forever(main())

def test_asyncio_run_forever_two_yields(self):
async def main():
self.stop_soon()
yield
yield
raise ValueError('spam')

with self.assertRaisesRegex(RuntimeError, 'one empty yield'):
asyncio.run_forever(main())

def test_asyncio_run_forever_only_ag(self):
async def coro():
pass

for o in {1, lambda: None, coro()}:
with self.subTest(obj=o), \
self.assertRaisesRegex(ValueError,
'an asynchronous.*was expected'):
asyncio.run_forever(o)

def test_asyncio_run_forever_debug(self):
async def main(expected):
loop = asyncio.get_event_loop()
self.assertIs(loop.get_debug(), expected)
if 0:
yield

asyncio.run_forever(main(False))
asyncio.run_forever(main(True), debug=True)

def test_asyncio_run_forever_from_running_loop(self):
async def main():
asyncio.run_forever(main())
if 0:
yield

with self.assertRaisesRegex(RuntimeError,
'cannot be called from a running'):
asyncio.run_forever(main())

def test_asyncio_run_forever_base_exception(self):
vi = sys.version_info
if vi[:2] != (3, 6) or vi.releaselevel == 'beta' and vi.serial < 4:
# See http://bugs.python.org/issue28721 for details.
raise unittest.SkipTest(
'this test requires Python 3.6b4 or greater')

DONE = 0

class MyExc(BaseException):
pass

async def main():
nonlocal DONE
self.stop_soon(exc=MyExc)
try:
yield
except MyExc:
DONE += 1

asyncio.run_forever(main())
self.assertEqual(DONE, 1)