diff --git a/crates/hir-def/src/body/lower.rs b/crates/hir-def/src/body/lower.rs index 66f9c24e8724..98ca70010929 100644 --- a/crates/hir-def/src/body/lower.rs +++ b/crates/hir-def/src/body/lower.rs @@ -466,6 +466,7 @@ impl ExprCollector<'_> { arg_types: arg_types.into(), ret_type, body, + is_async: e.async_token().is_some(), }, syntax_ptr, ) diff --git a/crates/hir-def/src/expr.rs b/crates/hir-def/src/expr.rs index c1b3788acb7d..ef63e2b6ffd3 100644 --- a/crates/hir-def/src/expr.rs +++ b/crates/hir-def/src/expr.rs @@ -196,6 +196,7 @@ pub enum Expr { arg_types: Box<[Option>]>, ret_type: Option>, body: ExprId, + is_async: bool, }, Tuple { exprs: Box<[ExprId]>, diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs index 2a13106390d9..489bb99a49a5 100644 --- a/crates/hir-ty/src/infer/expr.rs +++ b/crates/hir-ty/src/infer/expr.rs @@ -218,7 +218,7 @@ impl<'a> InferenceContext<'a> { self.diverges = Diverges::Maybe; TyBuilder::unit() } - Expr::Closure { body, args, ret_type, arg_types } => { + Expr::Closure { body, args, ret_type, arg_types, is_async } => { assert_eq!(args.len(), arg_types.len()); let mut sig_tys = Vec::new(); @@ -262,18 +262,46 @@ impl<'a> InferenceContext<'a> { ); // Now go through the argument patterns - for (arg_pat, arg_ty) in args.iter().zip(sig_tys) { + for (arg_pat, arg_ty) in args.iter().zip(&sig_tys) { self.infer_pat(*arg_pat, &arg_ty, BindingMode::default()); } let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); let prev_ret_ty = mem::replace(&mut self.return_ty, ret_ty.clone()); - self.infer_expr_coerce(*body, &Expectation::has_type(ret_ty)); + let inner_ty = self.infer_expr_coerce(*body, &Expectation::has_type(ret_ty)); + + + let inner_ty = if *is_async { + // Use the first type parameter as the output type of future. + // existential type AsyncBlockImplTrait: Future + let impl_trait_id = crate::ImplTraitId::AsyncBlockTypeImplTrait(self.owner, *body); + let opaque_ty_id = self.db.intern_impl_trait_id(impl_trait_id).into(); + TyKind::OpaqueType(opaque_ty_id, Substitution::from1(Interner, inner_ty)) + .intern(Interner) + } else { + inner_ty + }; self.diverges = prev_diverges; self.return_ty = prev_ret_ty; + sig_tys.pop(); + sig_tys.push(inner_ty); + + let sig_ty = TyKind::Function(FnPointer { + num_binders: 0, + sig: FnSig { abi: (), safety: chalk_ir::Safety::Safe, variadic: false }, + substitution: FnSubst( + Substitution::from_iter(Interner, sig_tys.clone()).shifted_in(Interner), + ), + }) + .intern(Interner); + + let closure_ty = + TyKind::Closure(closure_id, Substitution::from1(Interner, sig_ty.clone())) + .intern(Interner); + closure_ty } Expr::Call { callee, args, .. } => {