Skip to content

Commit

Permalink
Auto merge of #16630 - ShoyuVanilla:fix-closure-kind-inference, r=Vey…
Browse files Browse the repository at this point in the history
…kril

fix: Wrong closure kind deduction for closures with predicates

Completes #16472, fixes #16421

The changed closure kind deduction is mostly simlar to `rustc_hir_typeck/src/closure.rs`.
Porting closure sig deduction from it seems possible too and I'm considering doing it with another PR
  • Loading branch information
bors committed Feb 27, 2024
2 parents c031246 + a4021f6 commit a3236be
Show file tree
Hide file tree
Showing 8 changed files with 269 additions and 36 deletions.
88 changes: 81 additions & 7 deletions crates/hir-ty/src/infer/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{cmp, convert::Infallible, mem};
use chalk_ir::{
cast::Cast,
fold::{FallibleTypeFolder, TypeFoldable},
AliasEq, AliasTy, BoundVar, DebruijnIndex, FnSubst, Mutability, TyKind, WhereClause,
BoundVar, DebruijnIndex, FnSubst, Mutability, TyKind,
};
use either::Either;
use hir_def::{
Expand All @@ -22,13 +22,14 @@ use stdx::never;

use crate::{
db::{HirDatabase, InternedClosure},
from_placeholder_idx, make_binders,
from_chalk_trait_id, from_placeholder_idx, make_binders,
mir::{BorrowKind, MirSpan, MutBorrowKind, ProjectionElem},
static_lifetime, to_chalk_trait_id,
traits::FnTrait,
utils::{self, generics, Generics},
Adjust, Adjustment, Binders, BindingMode, ChalkTraitId, ClosureId, DynTy, FnAbi, FnPointer,
FnSig, Interner, Substitution, Ty, TyExt,
utils::{self, elaborate_clause_supertraits, generics, Generics},
Adjust, Adjustment, AliasEq, AliasTy, Binders, BindingMode, ChalkTraitId, ClosureId, DynTy,
DynTyExt, FnAbi, FnPointer, FnSig, Interner, OpaqueTy, ProjectionTyExt, Substitution, Ty,
TyExt, WhereClause,
};

use super::{Expectation, InferenceContext};
Expand All @@ -47,6 +48,15 @@ impl InferenceContext<'_> {
None => return,
};

if let TyKind::Closure(closure_id, _) = closure_ty.kind(Interner) {
if let Some(closure_kind) = self.deduce_closure_kind_from_expectations(&expected_ty) {
self.result
.closure_info
.entry(*closure_id)
.or_insert_with(|| (Vec::new(), closure_kind));
}
}

// Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here.
let _ = self.coerce(Some(closure_expr), closure_ty, &expected_ty);

Expand All @@ -65,6 +75,60 @@ impl InferenceContext<'_> {
}
}

// Closure kind deductions are mostly from `rustc_hir_typeck/src/closure.rs`.
// Might need to port closure sig deductions too.
fn deduce_closure_kind_from_expectations(&mut self, expected_ty: &Ty) -> Option<FnTrait> {
match expected_ty.kind(Interner) {
TyKind::Alias(AliasTy::Opaque(OpaqueTy { .. })) | TyKind::OpaqueType(..) => {
let clauses = expected_ty
.impl_trait_bounds(self.db)
.into_iter()
.flatten()
.map(|b| b.into_value_and_skipped_binders().0);
self.deduce_closure_kind_from_predicate_clauses(clauses)
}
TyKind::Dyn(dyn_ty) => dyn_ty.principal().and_then(|trait_ref| {
self.fn_trait_kind_from_trait_id(from_chalk_trait_id(trait_ref.trait_id))
}),
TyKind::InferenceVar(ty, chalk_ir::TyVariableKind::General) => {
let clauses = self.clauses_for_self_ty(*ty);
self.deduce_closure_kind_from_predicate_clauses(clauses.into_iter())
}
TyKind::Function(_) => Some(FnTrait::Fn),
_ => None,
}
}

fn deduce_closure_kind_from_predicate_clauses(
&self,
clauses: impl DoubleEndedIterator<Item = WhereClause>,
) -> Option<FnTrait> {
let mut expected_kind = None;

for clause in elaborate_clause_supertraits(self.db, clauses.rev()) {
let trait_id = match clause {
WhereClause::AliasEq(AliasEq {
alias: AliasTy::Projection(projection), ..
}) => Some(projection.trait_(self.db)),
WhereClause::Implemented(trait_ref) => {
Some(from_chalk_trait_id(trait_ref.trait_id))
}
_ => None,
};
if let Some(closure_kind) =
trait_id.and_then(|trait_id| self.fn_trait_kind_from_trait_id(trait_id))
{
// `FnX`'s variants order is opposite from rustc, so use `cmp::max` instead of `cmp::min`
expected_kind = Some(
expected_kind
.map_or_else(|| closure_kind, |current| cmp::max(current, closure_kind)),
);
}
}

expected_kind
}

fn deduce_sig_from_dyn_ty(&self, dyn_ty: &DynTy) -> Option<FnPointer> {
// Search for a predicate like `<$self as FnX<Args>>::Output == Ret`

Expand Down Expand Up @@ -111,6 +175,10 @@ impl InferenceContext<'_> {

None
}

fn fn_trait_kind_from_trait_id(&self, trait_id: hir_def::TraitId) -> Option<FnTrait> {
FnTrait::from_lang_item(self.db.lang_attr(trait_id.into())?)
}
}

// The below functions handle capture and closure kind (Fn, FnMut, ..)
Expand Down Expand Up @@ -962,8 +1030,14 @@ impl InferenceContext<'_> {
}
}
self.restrict_precision_for_unsafe();
// closure_kind should be done before adjust_for_move_closure
let closure_kind = self.closure_kind();
// `closure_kind` should be done before adjust_for_move_closure
// If there exists pre-deduced kind of a closure, use it instead of one determined by capture, as rustc does.
// rustc also does diagnostics here if the latter is not a subtype of the former.
let closure_kind = self
.result
.closure_info
.get(&closure)
.map_or_else(|| self.closure_kind(), |info| info.1);
match capture_by {
CaptureBy::Value => self.adjust_for_move_closure(),
CaptureBy::Ref => (),
Expand Down
73 changes: 70 additions & 3 deletions crates/hir-ty/src/infer/unify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@ use chalk_solve::infer::ParameterEnaVariableExt;
use either::Either;
use ena::unify::UnifyKey;
use hir_expand::name;
use smallvec::SmallVec;
use triomphe::Arc;

use super::{InferOk, InferResult, InferenceContext, TypeError};
use crate::{
consteval::unknown_const, db::HirDatabase, fold_tys_and_consts, static_lifetime,
to_chalk_trait_id, traits::FnTrait, AliasEq, AliasTy, BoundVar, Canonical, Const, ConstValue,
DebruijnIndex, GenericArg, GenericArgData, Goal, Guidance, InEnvironment, InferenceVar,
Interner, Lifetime, ParamKind, ProjectionTy, ProjectionTyExt, Scalar, Solution, Substitution,
TraitEnvironment, Ty, TyBuilder, TyExt, TyKind, VariableKind,
DebruijnIndex, DomainGoal, GenericArg, GenericArgData, Goal, GoalData, Guidance, InEnvironment,
InferenceVar, Interner, Lifetime, ParamKind, ProjectionTy, ProjectionTyExt, Scalar, Solution,
Substitution, TraitEnvironment, Ty, TyBuilder, TyExt, TyKind, VariableKind, WhereClause,
};

impl InferenceContext<'_> {
Expand All @@ -31,6 +32,72 @@ impl InferenceContext<'_> {
{
self.table.canonicalize(t)
}

pub(super) fn clauses_for_self_ty(
&mut self,
self_ty: InferenceVar,
) -> SmallVec<[WhereClause; 4]> {
self.table.resolve_obligations_as_possible();

let root = self.table.var_unification_table.inference_var_root(self_ty);
let pending_obligations = mem::take(&mut self.table.pending_obligations);
let obligations = pending_obligations
.iter()
.filter_map(|obligation| match obligation.value.value.goal.data(Interner) {
GoalData::DomainGoal(DomainGoal::Holds(
clause @ WhereClause::AliasEq(AliasEq {
alias: AliasTy::Projection(projection),
..
}),
)) => {
let projection_self = projection.self_type_parameter(self.db);
let uncanonical = chalk_ir::Substitute::apply(
&obligation.free_vars,
projection_self,
Interner,
);
if matches!(
self.resolve_ty_shallow(&uncanonical).kind(Interner),
TyKind::InferenceVar(iv, TyVariableKind::General) if *iv == root,
) {
Some(chalk_ir::Substitute::apply(
&obligation.free_vars,
clause.clone(),
Interner,
))
} else {
None
}
}
GoalData::DomainGoal(DomainGoal::Holds(
clause @ WhereClause::Implemented(trait_ref),
)) => {
let trait_ref_self = trait_ref.self_type_parameter(Interner);
let uncanonical = chalk_ir::Substitute::apply(
&obligation.free_vars,
trait_ref_self,
Interner,
);
if matches!(
self.resolve_ty_shallow(&uncanonical).kind(Interner),
TyKind::InferenceVar(iv, TyVariableKind::General) if *iv == root,
) {
Some(chalk_ir::Substitute::apply(
&obligation.free_vars,
clause.clone(),
Interner,
))
} else {
None
}
}
_ => None,
})
.collect();
self.table.pending_obligations = pending_obligations;

obligations
}
}

#[derive(Debug, Clone)]
Expand Down
8 changes: 4 additions & 4 deletions crates/hir-ty/src/tests/patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -702,25 +702,25 @@ fn test() {
51..58 'loop {}': !
56..58 '{}': ()
72..171 '{ ... x); }': ()
78..81 'foo': fn foo<&(i32, &str), i32, impl Fn(&(i32, &str)) -> i32>(&(i32, &str), impl Fn(&(i32, &str)) -> i32) -> i32
78..81 'foo': fn foo<&(i32, &str), i32, impl FnOnce(&(i32, &str)) -> i32>(&(i32, &str), impl FnOnce(&(i32, &str)) -> i32) -> i32
78..105 'foo(&(...y)| x)': i32
82..91 '&(1, "a")': &(i32, &str)
83..91 '(1, "a")': (i32, &str)
84..85 '1': i32
87..90 '"a"': &str
93..104 '|&(x, y)| x': impl Fn(&(i32, &str)) -> i32
93..104 '|&(x, y)| x': impl FnOnce(&(i32, &str)) -> i32
94..101 '&(x, y)': &(i32, &str)
95..101 '(x, y)': (i32, &str)
96..97 'x': i32
99..100 'y': &str
103..104 'x': i32
142..145 'foo': fn foo<&(i32, &str), &i32, impl Fn(&(i32, &str)) -> &i32>(&(i32, &str), impl Fn(&(i32, &str)) -> &i32) -> &i32
142..145 'foo': fn foo<&(i32, &str), &i32, impl FnOnce(&(i32, &str)) -> &i32>(&(i32, &str), impl FnOnce(&(i32, &str)) -> &i32) -> &i32
142..168 'foo(&(...y)| x)': &i32
146..155 '&(1, "a")': &(i32, &str)
147..155 '(1, "a")': (i32, &str)
148..149 '1': i32
151..154 '"a"': &str
157..167 '|(x, y)| x': impl Fn(&(i32, &str)) -> &i32
157..167 '|(x, y)| x': impl FnOnce(&(i32, &str)) -> &i32
158..164 '(x, y)': (i32, &str)
159..160 'x': &i32
162..163 'y': &&str
Expand Down
2 changes: 1 addition & 1 deletion crates/hir-ty/src/tests/regression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ fn main() {
123..126 'S()': S<i32>
132..133 's': S<i32>
132..144 's.g(|_x| {})': ()
136..143 '|_x| {}': impl Fn(&i32)
136..143 '|_x| {}': impl FnOnce(&i32)
137..139 '_x': &i32
141..143 '{}': ()
150..151 's': S<i32>
Expand Down
41 changes: 39 additions & 2 deletions crates/hir-ty/src/tests/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2190,9 +2190,9 @@ fn main() {
149..151 'Ok': extern "rust-call" Ok<(), ()>(()) -> Result<(), ()>
149..155 'Ok(())': Result<(), ()>
152..154 '()': ()
167..171 'test': fn test<(), (), impl Fn() -> impl Future<Output = Result<(), ()>>, impl Future<Output = Result<(), ()>>>(impl Fn() -> impl Future<Output = Result<(), ()>>)
167..171 'test': fn test<(), (), impl FnMut() -> impl Future<Output = Result<(), ()>>, impl Future<Output = Result<(), ()>>>(impl FnMut() -> impl Future<Output = Result<(), ()>>)
167..228 'test(|... })': ()
172..227 '|| asy... }': impl Fn() -> impl Future<Output = Result<(), ()>>
172..227 '|| asy... }': impl FnMut() -> impl Future<Output = Result<(), ()>>
175..227 'async ... }': impl Future<Output = Result<(), ()>>
191..205 'return Err(())': !
198..201 'Err': extern "rust-call" Err<(), ()>(()) -> Result<(), ()>
Expand Down Expand Up @@ -2886,6 +2886,43 @@ fn f() {
)
}

#[test]
fn closure_kind_with_predicates() {
check_types(
r#"
//- minicore: fn
#![feature(unboxed_closures)]
struct X<T: FnOnce()>(T);
fn f1() -> impl FnOnce() {
|| {}
// ^^^^^ impl FnOnce()
}
fn f2(c: impl FnOnce<(), Output = i32>) {}
fn test {
let x1 = X(|| {});
let c1 = x1.0;
// ^^ impl FnOnce()
let c2 = || {};
// ^^ impl Fn()
let x2 = X(c2);
let c3 = x2.0
// ^^ impl Fn()
let c4 = f1();
// ^^ impl FnOnce() + ?Sized
f2(|| { 0 });
// ^^^^^^^^ impl FnOnce() -> i32
}
"#,
)
}

#[test]
fn derive_macro_should_work_for_associated_type() {
check_types(
Expand Down
Loading

0 comments on commit a3236be

Please sign in to comment.