Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use generator resume arguments in the async/await lowering #69033

Merged
merged 12 commits into from
Mar 21, 2020
78 changes: 78 additions & 0 deletions src/libcore/future/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,84 @@

//! Asynchronous values.

#[cfg(not(bootstrap))]
use crate::{
ops::{Generator, GeneratorState},
pin::Pin,
ptr::NonNull,
task::{Context, Poll},
};

mod future;
#[stable(feature = "futures_api", since = "1.36.0")]
pub use self::future::Future;

/// This type is needed because:
///
/// a) Generators cannot implement `for<'a, 'b> Generator<&'a mut Context<'b>>`, so we need to pass
jonas-schievink marked this conversation as resolved.
Show resolved Hide resolved
/// a raw pointer (see https://github.com/rust-lang/rust/issues/68923).
/// b) Raw pointers and `NonNull` aren't `Send` or `Sync`, so that would make every single future
/// non-Send/Sync as well, and we don't want that.
///
/// It also simplifies the HIR lowering of `.await`.
#[doc(hidden)]
#[unstable(feature = "gen_future", issue = "50547")]
#[cfg(not(bootstrap))]
#[derive(Debug, Copy, Clone)]
pub struct ResumeTy(NonNull<Context<'static>>);

#[unstable(feature = "gen_future", issue = "50547")]
#[cfg(not(bootstrap))]
unsafe impl Send for ResumeTy {}

#[unstable(feature = "gen_future", issue = "50547")]
#[cfg(not(bootstrap))]
unsafe impl Sync for ResumeTy {}

/// Wrap a generator in a future.
///
/// This function returns a `GenFuture` underneath, but hides it in `impl Trait` to give
/// better error messages (`impl Future` rather than `GenFuture<[closure.....]>`).
// This is `const` to avoid extra errors after we recover from `const async fn`
#[doc(hidden)]
#[unstable(feature = "gen_future", issue = "50547")]
#[cfg(not(bootstrap))]
#[inline]
pub const fn from_generator<T>(gen: T) -> impl Future<Output = T::Return>
where
T: Generator<ResumeTy, Yield = ()>,
{
struct GenFuture<T: Generator<ResumeTy, Yield = ()>>(T);

// We rely on the fact that async/await futures are immovable in order to create
// self-referential borrows in the underlying generator.
impl<T: Generator<ResumeTy, Yield = ()>> !Unpin for GenFuture<T> {}

impl<T: Generator<ResumeTy, Yield = ()>> Future for GenFuture<T> {
type Output = T::Return;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Safety: Safe because we're !Unpin + !Drop, and this is just a field projection.
let gen = unsafe { Pin::map_unchecked_mut(self, |s| &mut s.0) };

// Resume the generator, turning the `&mut Context` into a `NonNull` raw pointer. The
// `.await` lowering will safely cast that back to a `&mut Context`.
match gen.resume(ResumeTy(NonNull::from(cx).cast::<Context<'static>>())) {
tmandry marked this conversation as resolved.
Show resolved Hide resolved
GeneratorState::Yielded(()) => Poll::Pending,
GeneratorState::Complete(x) => Poll::Ready(x),
}
}
}

GenFuture(gen)
}

#[doc(hidden)]
#[unstable(feature = "gen_future", issue = "50547")]
#[cfg(not(bootstrap))]
#[inline]
pub unsafe fn poll_with_context<F>(f: Pin<&mut F>, mut cx: ResumeTy) -> Poll<F::Output>
where
F: Future,
{
F::poll(f, cx.0.as_mut())
}
tmandry marked this conversation as resolved.
Show resolved Hide resolved
101 changes: 79 additions & 22 deletions src/librustc_ast_lowering/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,15 @@ impl<'hir> LoweringContext<'_, 'hir> {
}
}

/// Lower an `async` construct to a generator that is then wrapped so it implements `Future`.
///
/// This results in:
///
/// ```text
/// std::future::from_generator(static move? |_task_context| -> <ret_ty> {
/// <body>
/// })
/// ```
pub(super) fn make_async_expr(
&mut self,
capture_clause: CaptureBy,
Expand All @@ -480,17 +489,42 @@ impl<'hir> LoweringContext<'_, 'hir> {
body: impl FnOnce(&mut Self) -> hir::Expr<'hir>,
) -> hir::ExprKind<'hir> {
let output = match ret_ty {
Some(ty) => FnRetTy::Ty(ty),
None => FnRetTy::Default(span),
Some(ty) => hir::FnRetTy::Return(self.lower_ty(&ty, ImplTraitContext::disallowed())),
None => hir::FnRetTy::DefaultReturn(span),
};
let ast_decl = FnDecl { inputs: vec![], output };
let decl = self.lower_fn_decl(&ast_decl, None, /* impl trait allowed */ false, None);
let body_id = self.lower_fn_body(&ast_decl, |this| {

// Resume argument type. We let the compiler infer this to simplify the lowering. It is
// fully constrained by `future::from_generator`.
let input_ty = hir::Ty { hir_id: self.next_id(), kind: hir::TyKind::Infer, span };

// The closure/generator `FnDecl` takes a single (resume) argument of type `input_ty`.
let decl = self.arena.alloc(hir::FnDecl {
inputs: arena_vec![self; input_ty],
output,
c_variadic: false,
implicit_self: hir::ImplicitSelfKind::None,
});

// Lower the argument pattern/ident. The ident is used again in the `.await` lowering.
let (pat, task_context_hid) = self.pat_ident_binding_mode(
span,
Ident::with_dummy_span(sym::_task_context),
hir::BindingAnnotation::Mutable,
);
let param = hir::Param { attrs: &[], hir_id: self.next_id(), pat, span };
let params = arena_vec![self; param];

let body_id = self.lower_body(move |this| {
this.generator_kind = Some(hir::GeneratorKind::Async(async_gen_kind));
body(this)

let old_ctx = this.task_context;
this.task_context = Some(task_context_hid);
let res = body(this);
this.task_context = old_ctx;
(params, res)
});

// `static || -> <ret_ty> { body }`:
// `static |_task_context| -> <ret_ty> { body }`:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reading just the comments, I am a bit confused by the leading _ here -- shouldn't this be the binder for the task_context below?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable in the desugaring has an underscore because it can trigger the unused variable lint otherwise. I didn't mirror that in the lowering code.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah so it should be _task_context in both places? Makes sense, thanks.

let generator_kind = hir::ExprKind::Closure(
capture_clause,
decl,
Expand Down Expand Up @@ -523,13 +557,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
/// ```rust
/// match <expr> {
/// mut pinned => loop {
/// match ::std::future::poll_with_tls_context(unsafe {
/// <::std::pin::Pin>::new_unchecked(&mut pinned)
/// }) {
/// match unsafe { ::std::future::poll_with_context(
/// <::std::pin::Pin>::new_unchecked(&mut pinned),
/// task_context,
/// ) } {
/// ::std::task::Poll::Ready(result) => break result,
/// ::std::task::Poll::Pending => {}
/// }
jonas-schievink marked this conversation as resolved.
Show resolved Hide resolved
/// yield ();
/// task_context = yield ();
/// }
/// }
/// ```
Expand Down Expand Up @@ -561,12 +596,23 @@ impl<'hir> LoweringContext<'_, 'hir> {
let (pinned_pat, pinned_pat_hid) =
self.pat_ident_binding_mode(span, pinned_ident, hir::BindingAnnotation::Mutable);

// ::std::future::poll_with_tls_context(unsafe {
// ::std::pin::Pin::new_unchecked(&mut pinned)
// })`
let task_context_ident = Ident::with_dummy_span(sym::_task_context);

// unsafe {
// ::std::future::poll_with_context(
// ::std::pin::Pin::new_unchecked(&mut pinned),
// task_context,
// )
// }
let poll_expr = {
let pinned = self.expr_ident(span, pinned_ident, pinned_pat_hid);
let ref_mut_pinned = self.expr_mut_addr_of(span, pinned);
let task_context = if let Some(task_context_hid) = self.task_context {
self.expr_ident_mut(span, task_context_ident, task_context_hid)
} else {
// Use of `await` outside of an async context, we cannot use `task_context` here.
self.expr_err(span)
};
let pin_ty_id = self.next_id();
let new_unchecked_expr_kind = self.expr_call_std_assoc_fn(
pin_ty_id,
Expand All @@ -575,14 +621,13 @@ impl<'hir> LoweringContext<'_, 'hir> {
"new_unchecked",
arena_vec![self; ref_mut_pinned],
);
let new_unchecked =
self.arena.alloc(self.expr(span, new_unchecked_expr_kind, ThinVec::new()));
let unsafe_expr = self.expr_unsafe(new_unchecked);
self.expr_call_std_path(
let new_unchecked = self.expr(span, new_unchecked_expr_kind, ThinVec::new());
let call = self.expr_call_std_path(
gen_future_span,
&[sym::future, sym::poll_with_tls_context],
arena_vec![self; unsafe_expr],
)
&[sym::future, sym::poll_with_context],
arena_vec![self; new_unchecked, task_context],
);
self.arena.alloc(self.expr_unsafe(call))
};

// `::std::task::Poll::Ready(result) => break result`
Expand Down Expand Up @@ -622,14 +667,26 @@ impl<'hir> LoweringContext<'_, 'hir> {
self.stmt_expr(span, match_expr)
};

// task_context = yield ();
let yield_stmt = {
let unit = self.expr_unit(span);
let yield_expr = self.expr(
span,
hir::ExprKind::Yield(unit, hir::YieldSource::Await),
ThinVec::new(),
);
self.stmt_expr(span, yield_expr)
let yield_expr = self.arena.alloc(yield_expr);

if let Some(task_context_hid) = self.task_context {
let lhs = self.expr_ident(span, task_context_ident, task_context_hid);
let assign =
self.expr(span, hir::ExprKind::Assign(lhs, yield_expr, span), AttrVec::new());
self.stmt_expr(span, assign)
} else {
// Use of `await` outside of an async context. Return `yield_expr` so that we can
// proceed with type checking.
self.stmt(span, hir::StmtKind::Semi(yield_expr))
}
};

let loop_block = self.block_all(span, arena_vec![self; inner_match_stmt, yield_stmt], None);
Expand Down
4 changes: 2 additions & 2 deletions src/librustc_ast_lowering/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
}

/// Construct `ExprKind::Err` for the given `span`.
fn expr_err(&mut self, span: Span) -> hir::Expr<'hir> {
crate fn expr_err(&mut self, span: Span) -> hir::Expr<'hir> {
self.expr(span, hir::ExprKind::Err, AttrVec::new())
}

Expand Down Expand Up @@ -955,7 +955,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
id
}

fn lower_body(
pub(super) fn lower_body(
&mut self,
f: impl FnOnce(&mut Self) -> (&'hir [hir::Param<'hir>], hir::Expr<'hir>),
) -> hir::BodyId {
Expand Down
5 changes: 5 additions & 0 deletions src/librustc_ast_lowering/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ struct LoweringContext<'a, 'hir: 'a> {

generator_kind: Option<hir::GeneratorKind>,

/// When inside an `async` context, this is the `HirId` of the
/// `task_context` local bound to the resume argument of the generator.
task_context: Option<hir::HirId>,

/// Used to get the current `fn`'s def span to point to when using `await`
/// outside of an `async fn`.
current_item: Option<Span>,
Expand Down Expand Up @@ -295,6 +299,7 @@ pub fn lower_crate<'a, 'hir>(
item_local_id_counters: Default::default(),
node_id_to_hir_id: IndexVec::new(),
generator_kind: None,
task_context: None,
current_item: None,
lifetimes_to_define: Vec::new(),
is_collecting_in_band_lifetimes: false,
Expand Down
7 changes: 5 additions & 2 deletions src/librustc_mir/borrow_check/type_check/input_output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,16 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
};

debug!(
"equate_inputs_and_outputs: normalized_input_tys = {:?}, local_decls = {:?}",
normalized_input_tys, body.local_decls
);

// Equate expected input tys with those in the MIR.
for (&normalized_input_ty, argument_index) in normalized_input_tys.iter().zip(0..) {
// In MIR, argument N is stored in local N+1.
let local = Local::new(argument_index + 1);

debug!("equate_inputs_and_outputs: normalized_input_ty = {:?}", normalized_input_ty);

let mir_input_ty = body.local_decls[local].ty;
let mir_input_span = body.local_decls[local].source_info.span;
self.equate_normalized_input_or_output(
Expand Down
3 changes: 2 additions & 1 deletion src/librustc_span/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ symbols! {
plugin_registrar,
plugins,
Poll,
poll_with_tls_context,
poll_with_context,
powerpc_target_feature,
precise_pointer_size_matching,
pref_align_of,
Expand Down Expand Up @@ -720,6 +720,7 @@ symbols! {
target_has_atomic_load_store,
target_thread_local,
task,
_task_context,
tmandry marked this conversation as resolved.
Show resolved Hide resolved
tbm_target_feature,
termination_trait,
termination_trait_test,
Expand Down
25 changes: 18 additions & 7 deletions src/libstd/future.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
//! Asynchronous values.

use core::cell::Cell;
use core::marker::Unpin;
use core::ops::{Drop, Generator, GeneratorState};
use core::option::Option;
use core::pin::Pin;
use core::ptr::NonNull;
use core::task::{Context, Poll};
#[cfg(bootstrap)]
use core::{
cell::Cell,
marker::Unpin,
ops::{Drop, Generator, GeneratorState},
pin::Pin,
ptr::NonNull,
task::{Context, Poll},
};

#[doc(inline)]
#[stable(feature = "futures_api", since = "1.36.0")]
Expand All @@ -17,22 +19,26 @@ pub use core::future::*;
/// This function returns a `GenFuture` underneath, but hides it in `impl Trait` to give
/// better error messages (`impl Future` rather than `GenFuture<[closure.....]>`).
// This is `const` to avoid extra errors after we recover from `const async fn`
#[cfg(bootstrap)]
#[doc(hidden)]
#[unstable(feature = "gen_future", issue = "50547")]
pub const fn from_generator<T: Generator<Yield = ()>>(x: T) -> impl Future<Output = T::Return> {
GenFuture(x)
}

/// A wrapper around generators used to implement `Future` for `async`/`await` code.
#[cfg(bootstrap)]
#[doc(hidden)]
#[unstable(feature = "gen_future", issue = "50547")]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
struct GenFuture<T: Generator<Yield = ()>>(T);

// We rely on the fact that async/await futures are immovable in order to create
// self-referential borrows in the underlying generator.
#[cfg(bootstrap)]
impl<T: Generator<Yield = ()>> !Unpin for GenFuture<T> {}

#[cfg(bootstrap)]
#[doc(hidden)]
#[unstable(feature = "gen_future", issue = "50547")]
impl<T: Generator<Yield = ()>> Future for GenFuture<T> {
Expand All @@ -48,12 +54,15 @@ impl<T: Generator<Yield = ()>> Future for GenFuture<T> {
}
}

#[cfg(bootstrap)]
thread_local! {
static TLS_CX: Cell<Option<NonNull<Context<'static>>>> = Cell::new(None);
}

#[cfg(bootstrap)]
struct SetOnDrop(Option<NonNull<Context<'static>>>);

#[cfg(bootstrap)]
impl Drop for SetOnDrop {
fn drop(&mut self) {
TLS_CX.with(|tls_cx| {
Expand All @@ -64,13 +73,15 @@ impl Drop for SetOnDrop {

// Safety: the returned guard must drop before `cx` is dropped and before
// any previous guard is dropped.
#[cfg(bootstrap)]
unsafe fn set_task_context(cx: &mut Context<'_>) -> SetOnDrop {
// transmute the context's lifetime to 'static so we can store it.
let cx = core::mem::transmute::<&mut Context<'_>, &mut Context<'static>>(cx);
let old_cx = TLS_CX.with(|tls_cx| tls_cx.replace(Some(NonNull::from(cx))));
SetOnDrop(old_cx)
}

#[cfg(bootstrap)]
#[doc(hidden)]
#[unstable(feature = "gen_future", issue = "50547")]
/// Polls a future in the current thread-local task waker.
Expand Down
Loading