Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wg-async-await] Drop async fn arguments in async block #59135

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/librustc/hir/intravisit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,9 @@ pub trait Visitor<'v> : Sized {
fn visit_pat(&mut self, p: &'v Pat) {
walk_pat(self, p)
}
fn visit_argument_source(&mut self, s: &'v ArgSource) {
walk_argument_source(self, s)
}
fn visit_anon_const(&mut self, c: &'v AnonConst) {
walk_anon_const(self, c)
}
Expand Down Expand Up @@ -391,10 +394,17 @@ pub fn walk_body<'v, V: Visitor<'v>>(visitor: &mut V, body: &'v Body) {
for argument in &body.arguments {
visitor.visit_id(argument.hir_id);
visitor.visit_pat(&argument.pat);
visitor.visit_argument_source(&argument.source);
}
visitor.visit_expr(&body.value);
}

pub fn walk_argument_source<'v, V: Visitor<'v>>(visitor: &mut V, source: &'v ArgSource) {
if let ArgSource::AsyncFn(pat) = source {
visitor.visit_pat(pat);
}
}

pub fn walk_local<'v, V: Visitor<'v>>(visitor: &mut V, local: &'v Local) {
// Intentionally visiting the expr first - the initialization expr
// dominates the local's definition.
Expand Down
131 changes: 93 additions & 38 deletions src/librustc/hir/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2025,10 +2025,17 @@ impl<'a> LoweringContext<'a> {
init: l.init.as_ref().map(|e| P(self.lower_expr(e))),
span: l.span,
attrs: l.attrs.clone(),
source: hir::LocalSource::Normal,
source: self.lower_local_source(l.source),
}, ids)
}

fn lower_local_source(&mut self, ls: LocalSource) -> hir::LocalSource {
match ls {
LocalSource::Normal => hir::LocalSource::Normal,
LocalSource::AsyncFn => hir::LocalSource::AsyncFn,
}
}

fn lower_mutability(&mut self, m: Mutability) -> hir::Mutability {
match m {
Mutability::Mutable => hir::MutMutable,
Expand All @@ -2041,6 +2048,14 @@ impl<'a> LoweringContext<'a> {
hir::Arg {
hir_id,
pat: self.lower_pat(&arg.pat),
source: self.lower_arg_source(&arg.source),
}
}

fn lower_arg_source(&mut self, source: &ArgSource) -> hir::ArgSource {
match source {
ArgSource::Normal => hir::ArgSource::Normal,
ArgSource::AsyncFn(pat) => hir::ArgSource::AsyncFn(self.lower_pat(pat)),
}
}

Expand Down Expand Up @@ -2809,15 +2824,21 @@ impl<'a> LoweringContext<'a> {
fn lower_async_body(
&mut self,
decl: &FnDecl,
asyncness: IsAsync,
asyncness: &IsAsync,
body: &Block,
) -> hir::BodyId {
self.lower_body(Some(decl), |this| {
if let IsAsync::Async { closure_id, .. } = asyncness {
self.lower_body(Some(&decl), |this| {
if let IsAsync::Async { closure_id, ref arguments, .. } = asyncness {
let mut body = body.clone();

for a in arguments.iter().rev() {
body.stmts.insert(0, a.stmt.clone());
}

let async_expr = this.make_async_expr(
CaptureBy::Value, closure_id, None,
CaptureBy::Value, *closure_id, None,
|this| {
let body = this.lower_block(body, false);
let body = this.lower_block(&body, false);
this.expr_block(body, ThinVec::new())
});
this.expr(body.span, async_expr, ThinVec::new())
Expand Down Expand Up @@ -2876,26 +2897,42 @@ impl<'a> LoweringContext<'a> {
value
)
}
ItemKind::Fn(ref decl, header, ref generics, ref body) => {
ItemKind::Fn(ref decl, ref header, ref generics, ref body) => {
let fn_def_id = self.resolver.definitions().local_def_id(id);
self.with_new_scopes(|this| {
// Note: we don't need to change the return type from `T` to
// `impl Future<Output = T>` here because lower_body
// only cares about the input argument patterns in the function
// declaration (decl), not the return types.
let body_id = this.lower_async_body(decl, header.asyncness.node, body);
let mut lower_fn = |decl: &FnDecl| {
// Note: we don't need to change the return type from `T` to
// `impl Future<Output = T>` here because lower_body
// only cares about the input argument patterns in the function
// declaration (decl), not the return types.
let body_id = this.lower_async_body(&decl, &header.asyncness.node, body);

let (generics, fn_decl) = this.add_in_band_defs(
generics,
fn_def_id,
AnonymousLifetimeMode::PassThrough,
|this, idty| this.lower_fn_decl(
&decl,
Some((fn_def_id, idty)),
true,
header.asyncness.node.opt_return_id()
),
);

let (generics, fn_decl) = this.add_in_band_defs(
generics,
fn_def_id,
AnonymousLifetimeMode::PassThrough,
|this, idty| this.lower_fn_decl(
decl,
Some((fn_def_id, idty)),
true,
header.asyncness.node.opt_return_id()
),
);
(body_id, generics, fn_decl)
};

let (body_id, generics, fn_decl) = if let IsAsync::Async {
arguments, ..
} = &header.asyncness.node {
let mut decl = decl.clone();
// Replace the arguments of this async function with the generated
// arguments that will be moved into the closure.
decl.inputs = arguments.clone().drain(..).map(|a| a.arg).collect();
lower_fn(&decl)
} else {
lower_fn(decl)
};

hir::ItemKind::Fn(
fn_decl,
Expand Down Expand Up @@ -3384,15 +3421,33 @@ impl<'a> LoweringContext<'a> {
)
}
ImplItemKind::Method(ref sig, ref body) => {
let body_id = self.lower_async_body(&sig.decl, sig.header.asyncness.node, body);
let impl_trait_return_allow = !self.is_in_trait_impl;
let (generics, sig) = self.lower_method_sig(
&i.generics,
sig,
impl_item_def_id,
impl_trait_return_allow,
sig.header.asyncness.node.opt_return_id(),
);
let mut lower_method = |sig: &MethodSig| {
let body_id = self.lower_async_body(
&sig.decl, &sig.header.asyncness.node, body
);
let impl_trait_return_allow = !self.is_in_trait_impl;
let (generics, sig) = self.lower_method_sig(
&i.generics,
sig,
impl_item_def_id,
impl_trait_return_allow,
sig.header.asyncness.node.opt_return_id(),
);
(body_id, generics, sig)
};

let (body_id, generics, sig) = if let IsAsync::Async {
ref arguments, ..
} = sig.header.asyncness.node {
let mut sig = sig.clone();
// Replace the arguments of this async function with the generated
// arguments that will be moved into the closure.
sig.decl.inputs = arguments.clone().drain(..).map(|a| a.arg).collect();
lower_method(&sig)
} else {
lower_method(sig)
};

(generics, hir::ImplItemKind::Method(sig, body_id))
}
ImplItemKind::Type(ref ty) => (
Expand Down Expand Up @@ -3582,7 +3637,7 @@ impl<'a> LoweringContext<'a> {
impl_trait_return_allow: bool,
is_async: Option<NodeId>,
) -> (hir::Generics, hir::MethodSig) {
let header = self.lower_fn_header(sig.header);
let header = self.lower_fn_header(&sig.header);
let (generics, decl) = self.add_in_band_defs(
generics,
fn_def_id,
Expand All @@ -3604,10 +3659,10 @@ impl<'a> LoweringContext<'a> {
}
}

fn lower_fn_header(&mut self, h: FnHeader) -> hir::FnHeader {
fn lower_fn_header(&mut self, h: &FnHeader) -> hir::FnHeader {
hir::FnHeader {
unsafety: self.lower_unsafety(h.unsafety),
asyncness: self.lower_asyncness(h.asyncness.node),
asyncness: self.lower_asyncness(&h.asyncness.node),
constness: self.lower_constness(h.constness),
abi: h.abi,
}
Expand All @@ -3627,7 +3682,7 @@ impl<'a> LoweringContext<'a> {
}
}

fn lower_asyncness(&mut self, a: IsAsync) -> hir::IsAsync {
fn lower_asyncness(&mut self, a: &IsAsync) -> hir::IsAsync {
match a {
IsAsync::Async { .. } => hir::IsAsync::Async,
IsAsync::NotAsync => hir::IsAsync::NotAsync,
Expand Down Expand Up @@ -3940,7 +3995,7 @@ impl<'a> LoweringContext<'a> {
})
}
ExprKind::Closure(
capture_clause, asyncness, movability, ref decl, ref body, fn_decl_span
capture_clause, ref asyncness, movability, ref decl, ref body, fn_decl_span
) => {
if let IsAsync::Async { closure_id, .. } = asyncness {
let outer_decl = FnDecl {
Expand Down Expand Up @@ -3978,7 +4033,7 @@ impl<'a> LoweringContext<'a> {
Some(&**ty)
} else { None };
let async_body = this.make_async_expr(
capture_clause, closure_id, async_ret_ty,
capture_clause, *closure_id, async_ret_ty,
|this| {
this.with_new_scopes(|this| this.lower_expr(body))
});
Expand Down
39 changes: 27 additions & 12 deletions src/librustc/hir/map/def_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,17 @@ impl<'a> DefCollector<'a> {
id: NodeId,
name: Name,
span: Span,
header: &FnHeader,
header: &'a FnHeader,
generics: &'a Generics,
decl: &'a FnDecl,
body: &'a Block,
) {
let (closure_id, return_impl_trait_id) = match header.asyncness.node {
let (closure_id, return_impl_trait_id, arguments) = match &header.asyncness.node {
IsAsync::Async {
closure_id,
return_impl_trait_id,
} => (closure_id, return_impl_trait_id),
arguments,
} => (closure_id, return_impl_trait_id, arguments),
_ => unreachable!(),
};

Expand All @@ -86,17 +87,31 @@ impl<'a> DefCollector<'a> {
let fn_def_data = DefPathData::ValueNs(name.as_interned_str());
let fn_def = self.create_def(id, fn_def_data, ITEM_LIKE_SPACE, span);
return self.with_parent(fn_def, |this| {
this.create_def(return_impl_trait_id, DefPathData::ImplTrait, REGULAR_SPACE, span);
this.create_def(*return_impl_trait_id, DefPathData::ImplTrait, REGULAR_SPACE, span);

visit::walk_generics(this, generics);
visit::walk_fn_decl(this, decl);

let closure_def = this.create_def(closure_id,
DefPathData::ClosureExpr,
REGULAR_SPACE,
span);
// Walk the generated arguments for the `async fn`.
for a in arguments {
use visit::Visitor;
this.visit_ty(&a.arg.ty);
}

// We do not invoke `walk_fn_decl` as this will walk the arguments that are being
// replaced.
visit::walk_fn_ret_ty(this, &decl.output);

let closure_def = this.create_def(
*closure_id, DefPathData::ClosureExpr, REGULAR_SPACE, span,
);
this.with_parent(closure_def, |this| {
visit::walk_block(this, body);
for a in arguments {
use visit::Visitor;
// Walk each of the generated statements before the regular block body.
this.visit_stmt(&a.stmt);
}

visit::walk_block(this, &body);
})
})
}
Expand Down Expand Up @@ -288,7 +303,7 @@ impl<'a> visit::Visitor<'a> for DefCollector<'a> {

match expr.node {
ExprKind::Mac(..) => return self.visit_macro_invoc(expr.id),
ExprKind::Closure(_, asyncness, ..) => {
ExprKind::Closure(_, ref asyncness, ..) => {
let closure_def = self.create_def(expr.id,
DefPathData::ClosureExpr,
REGULAR_SPACE,
Expand All @@ -298,7 +313,7 @@ impl<'a> visit::Visitor<'a> for DefCollector<'a> {
// Async closures desugar to closures inside of closures, so
// we must create two defs.
if let IsAsync::Async { closure_id, .. } = asyncness {
let async_def = self.create_def(closure_id,
let async_def = self.create_def(*closure_id,
DefPathData::ClosureExpr,
REGULAR_SPACE,
expr.span);
Expand Down
31 changes: 31 additions & 0 deletions src/librustc/hir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1583,6 +1583,17 @@ pub enum LocalSource {
Normal,
/// A desugared `for _ in _ { .. }` loop.
ForLoopDesugar,
/// When lowering async functions, we create locals within the `async move` so that
/// all arguments are dropped after the future is polled.
///
/// ```ignore (pseudo-Rust)
/// async fn foo(<pattern> @ x: Type) {
/// async move {
/// let <pattern> = x;
/// }
/// }
/// ```
AsyncFn,
}

/// Hints at the original code for a `match _ { .. }`.
Expand Down Expand Up @@ -1871,6 +1882,26 @@ pub struct InlineAsm {
pub struct Arg {
pub pat: P<Pat>,
pub hir_id: HirId,
pub source: ArgSource,
}

impl Arg {
/// Returns the pattern representing the original binding for this argument.
pub fn original_pat(&self) -> &P<Pat> {
match &self.source {
ArgSource::Normal => &self.pat,
ArgSource::AsyncFn(pat) => &pat,
}
}
}

/// Represents the source of an argument in a function header.
#[derive(Clone, RustcEncodable, RustcDecodable, Debug, HashStable)]
pub enum ArgSource {
/// Argument as specified by the user.
Normal,
/// Generated argument from `async fn` lowering, contains the original binding pattern.
AsyncFn(P<Pat>),
}

/// Represents the header (not the body) of a function declaration.
Expand Down
1 change: 0 additions & 1 deletion src/librustc/ich/impls_hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,4 +435,3 @@ impl<'hir> HashStable<StableHashingContext<'hir>> for attr::OptimizeAttr {
mem::discriminant(self).hash_stable(hcx, hasher);
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,16 @@ impl<'a, 'gcx, 'tcx> NiceRegionError<'a, 'gcx, 'tcx> {
let sub_is_ret_type =
self.is_return_type_anon(scope_def_id_sub, bregion_sub, ty_fndecl_sub);

let span_label_var1 = if let Some(simple_ident) = anon_arg_sup.pat.simple_ident() {
format!(" from `{}`", simple_ident)
} else {
String::new()
let span_label_var1 = match anon_arg_sup.original_pat().simple_ident() {
Some(simple_ident) => format!(" from `{}`", simple_ident),
None => String::new(),
};

let span_label_var2 = if let Some(simple_ident) = anon_arg_sub.pat.simple_ident() {
format!(" into `{}`", simple_ident)
} else {
String::new()
let span_label_var2 = match anon_arg_sub.original_pat().simple_ident() {
Some(simple_ident) => format!(" into `{}`", simple_ident),
None => String::new(),
};


let (span_1, span_2, main_label, span_label) = match (sup_is_ret_type, sub_is_ret_type) {
(None, None) => {
let (main_label_1, span_label_1) = if ty_sup.hir_id == ty_sub.hir_id {
Expand Down
Loading