Skip to content

Commit

Permalink
Merge pull request python-trio#397 from njsmith/simplify-traps
Browse files Browse the repository at this point in the history
Simplify implementation of primitive traps like wait_task_rescheduled
  • Loading branch information
touilleMan committed Jan 10, 2018
2 parents a351a2d + 8ca7ca8 commit 07d144e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 17 deletions.
1 change: 1 addition & 0 deletions newsfragments/395.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Simplify implementation of primitive traps like wait_task_rescheduled
31 changes: 14 additions & 17 deletions trio/_core/_traps.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,16 @@
__all__ = ["cancel_shielded_checkpoint", "Abort", "wait_task_rescheduled"]


# Decorator to turn a generator into a well-behaved async function:
def asyncfunction(fn):
# Set the coroutine flag
fn = types.coroutine(fn)
# Then wrap it in an 'async def', to enable the "coroutine was not
# awaited" warning
@wraps(fn)
async def wrapper(*args, **kwargs):
return await fn(*args, **kwargs)

return wrapper
# Helper for the bottommost 'yield'. You can't use 'yield' inside an async
# function, but you can inside a generator, and if you decorate your generator
# with @types.coroutine, then it's even awaitable. However, it's still not a
# real async function: in particular, it isn't recognized by
# inspect.iscoroutinefunction, and it doesn't trigger the unawaited coroutine
# tracking machinery. Since our traps are public APIs, we make them real async
# functions, and then this helper takes care of the actual yield:
@types.coroutine
def _async_yield(obj):
return (yield obj)


# This class object is used as a singleton.
Expand All @@ -28,8 +27,7 @@ class CancelShieldedCheckpoint:
pass


@asyncfunction
def cancel_shielded_checkpoint():
async def cancel_shielded_checkpoint():
"""Introduce a schedule point, but not a cancel point.
This is *not* a :ref:`checkpoint <checkpoints>`, but it is half of a
Expand All @@ -42,7 +40,7 @@ def cancel_shielded_checkpoint():
await trio.hazmat.checkpoint()
"""
return (yield CancelShieldedCheckpoint).unwrap()
return (await _async_yield(CancelShieldedCheckpoint)).unwrap()


# Return values for abort functions
Expand All @@ -65,8 +63,7 @@ class WaitTaskRescheduled:
abort_func = attr.ib()


@asyncfunction
def wait_task_rescheduled(abort_func):
async def wait_task_rescheduled(abort_func):
"""Put the current task to sleep, with cancellation support.
This is the lowest-level API for blocking in trio. Every time a
Expand Down Expand Up @@ -159,4 +156,4 @@ def abort(inner_raise_cancel):
above about how you should use a higher-level API if at all possible?
"""
return (yield WaitTaskRescheduled(abort_func)).unwrap()
return (await _async_yield(WaitTaskRescheduled(abort_func))).unwrap()

0 comments on commit 07d144e

Please sign in to comment.