Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Wrong closure kind deduction for closures with predicates #16630

Merged
merged 4 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
88 changes: 81 additions & 7 deletions crates/hir-ty/src/infer/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{cmp, convert::Infallible, mem};
use chalk_ir::{
cast::Cast,
fold::{FallibleTypeFolder, TypeFoldable},
AliasEq, AliasTy, BoundVar, DebruijnIndex, FnSubst, Mutability, TyKind, WhereClause,
BoundVar, DebruijnIndex, FnSubst, Mutability, TyKind,
};
use either::Either;
use hir_def::{
Expand All @@ -22,13 +22,14 @@ use stdx::never;

use crate::{
db::{HirDatabase, InternedClosure},
from_placeholder_idx, make_binders,
from_chalk_trait_id, from_placeholder_idx, make_binders,
mir::{BorrowKind, MirSpan, MutBorrowKind, ProjectionElem},
static_lifetime, to_chalk_trait_id,
traits::FnTrait,
utils::{self, generics, Generics},
Adjust, Adjustment, Binders, BindingMode, ChalkTraitId, ClosureId, DynTy, FnAbi, FnPointer,
FnSig, Interner, Substitution, Ty, TyExt,
utils::{self, elaborate_clause_supertraits, generics, Generics},
Adjust, Adjustment, AliasEq, AliasTy, Binders, BindingMode, ChalkTraitId, ClosureId, DynTy,
DynTyExt, FnAbi, FnPointer, FnSig, Interner, OpaqueTy, ProjectionTyExt, Substitution, Ty,
TyExt, WhereClause,
};

use super::{Expectation, InferenceContext};
Expand All @@ -47,6 +48,15 @@ impl InferenceContext<'_> {
None => return,
};

if let TyKind::Closure(closure_id, _) = closure_ty.kind(Interner) {
if let Some(closure_kind) = self.deduce_closure_kind_from_expectations(&expected_ty) {
self.result
.closure_info
.entry(*closure_id)
.or_insert_with(|| (Vec::new(), closure_kind));
}
}

// Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here.
let _ = self.coerce(Some(closure_expr), closure_ty, &expected_ty);

Expand All @@ -65,6 +75,60 @@ impl InferenceContext<'_> {
}
}

// Closure kind deductions are mostly from `rustc_hir_typeck/src/closure.rs`.
// Might need to port closure sig deductions too.
fn deduce_closure_kind_from_expectations(&mut self, expected_ty: &Ty) -> Option<FnTrait> {
match expected_ty.kind(Interner) {
TyKind::Alias(AliasTy::Opaque(OpaqueTy { .. })) | TyKind::OpaqueType(..) => {
let clauses = expected_ty
.impl_trait_bounds(self.db)
.into_iter()
.flatten()
.map(|b| b.into_value_and_skipped_binders().0);
self.deduce_closure_kind_from_predicate_clauses(clauses)
}
TyKind::Dyn(dyn_ty) => dyn_ty.principal().and_then(|trait_ref| {
self.fn_trait_kind_from_trait_id(from_chalk_trait_id(trait_ref.trait_id))
}),
TyKind::InferenceVar(ty, chalk_ir::TyVariableKind::General) => {
let clauses = self.clauses_for_self_ty(*ty);
self.deduce_closure_kind_from_predicate_clauses(clauses.into_iter())
}
TyKind::Function(_) => Some(FnTrait::Fn),
_ => None,
}
}

fn deduce_closure_kind_from_predicate_clauses(
&self,
clauses: impl DoubleEndedIterator<Item = WhereClause>,
) -> Option<FnTrait> {
let mut expected_kind = None;

for clause in elaborate_clause_supertraits(self.db, clauses.rev()) {
let trait_id = match clause {
WhereClause::AliasEq(AliasEq {
alias: AliasTy::Projection(projection), ..
}) => Some(projection.trait_(self.db)),
WhereClause::Implemented(trait_ref) => {
Some(from_chalk_trait_id(trait_ref.trait_id))
}
_ => None,
};
if let Some(closure_kind) =
trait_id.and_then(|trait_id| self.fn_trait_kind_from_trait_id(trait_id))
{
// `FnX`'s variants order is opposite from rustc, so use `cmp::max` instead of `cmp::min`
expected_kind = Some(
expected_kind
.map_or_else(|| closure_kind, |current| cmp::max(current, closure_kind)),
);
}
}

expected_kind
}

fn deduce_sig_from_dyn_ty(&self, dyn_ty: &DynTy) -> Option<FnPointer> {
// Search for a predicate like `<$self as FnX<Args>>::Output == Ret`

Expand Down Expand Up @@ -111,6 +175,10 @@ impl InferenceContext<'_> {

None
}

fn fn_trait_kind_from_trait_id(&self, trait_id: hir_def::TraitId) -> Option<FnTrait> {
FnTrait::from_lang_item(self.db.lang_attr(trait_id.into())?)
}
}

// The below functions handle capture and closure kind (Fn, FnMut, ..)
Expand Down Expand Up @@ -962,8 +1030,14 @@ impl InferenceContext<'_> {
}
}
self.restrict_precision_for_unsafe();
// closure_kind should be done before adjust_for_move_closure
let closure_kind = self.closure_kind();
// `closure_kind` should be done before adjust_for_move_closure
// If there exists pre-deduced kind of a closure, use it instead of one determined by capture, as rustc does.
// rustc also does diagnostics here if the latter is not a subtype of the former.
let closure_kind = self
.result
.closure_info
.get(&closure)
.map_or_else(|| self.closure_kind(), |info| info.1);
match capture_by {
CaptureBy::Value => self.adjust_for_move_closure(),
CaptureBy::Ref => (),
Expand Down
73 changes: 70 additions & 3 deletions crates/hir-ty/src/infer/unify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@ use chalk_solve::infer::ParameterEnaVariableExt;
use either::Either;
use ena::unify::UnifyKey;
use hir_expand::name;
use smallvec::SmallVec;
use triomphe::Arc;

use super::{InferOk, InferResult, InferenceContext, TypeError};
use crate::{
consteval::unknown_const, db::HirDatabase, fold_tys_and_consts, static_lifetime,
to_chalk_trait_id, traits::FnTrait, AliasEq, AliasTy, BoundVar, Canonical, Const, ConstValue,
DebruijnIndex, GenericArg, GenericArgData, Goal, Guidance, InEnvironment, InferenceVar,
Interner, Lifetime, ParamKind, ProjectionTy, ProjectionTyExt, Scalar, Solution, Substitution,
TraitEnvironment, Ty, TyBuilder, TyExt, TyKind, VariableKind,
DebruijnIndex, DomainGoal, GenericArg, GenericArgData, Goal, GoalData, Guidance, InEnvironment,
InferenceVar, Interner, Lifetime, ParamKind, ProjectionTy, ProjectionTyExt, Scalar, Solution,
Substitution, TraitEnvironment, Ty, TyBuilder, TyExt, TyKind, VariableKind, WhereClause,
};

impl InferenceContext<'_> {
Expand All @@ -31,6 +32,72 @@ impl InferenceContext<'_> {
{
self.table.canonicalize(t)
}

pub(super) fn clauses_for_self_ty(
&mut self,
self_ty: InferenceVar,
) -> SmallVec<[WhereClause; 4]> {
self.table.resolve_obligations_as_possible();

let root = self.table.var_unification_table.inference_var_root(self_ty);
let pending_obligations = mem::take(&mut self.table.pending_obligations);
let obligations = pending_obligations
.iter()
.filter_map(|obligation| match obligation.value.value.goal.data(Interner) {
Comment on lines +44 to +46
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we are not mutating pending_obligations - it should be returned back to its original place self.table.pending_obligations - I think that to use Vec::retain(), we should clone the whole pending_obligations

GoalData::DomainGoal(DomainGoal::Holds(
clause @ WhereClause::AliasEq(AliasEq {
alias: AliasTy::Projection(projection),
..
}),
)) => {
let projection_self = projection.self_type_parameter(self.db);
let uncanonical = chalk_ir::Substitute::apply(
&obligation.free_vars,
projection_self,
Interner,
);
if matches!(
self.resolve_ty_shallow(&uncanonical).kind(Interner),
TyKind::InferenceVar(iv, TyVariableKind::General) if *iv == root,
) {
Some(chalk_ir::Substitute::apply(
&obligation.free_vars,
clause.clone(),
Interner,
))
} else {
None
}
}
GoalData::DomainGoal(DomainGoal::Holds(
clause @ WhereClause::Implemented(trait_ref),
)) => {
let trait_ref_self = trait_ref.self_type_parameter(Interner);
let uncanonical = chalk_ir::Substitute::apply(
&obligation.free_vars,
trait_ref_self,
Interner,
);
if matches!(
self.resolve_ty_shallow(&uncanonical).kind(Interner),
TyKind::InferenceVar(iv, TyVariableKind::General) if *iv == root,
) {
Some(chalk_ir::Substitute::apply(
&obligation.free_vars,
clause.clone(),
Interner,
))
} else {
None
}
}
_ => None,
})
.collect();
self.table.pending_obligations = pending_obligations;

obligations
}
}

#[derive(Debug, Clone)]
Expand Down
8 changes: 4 additions & 4 deletions crates/hir-ty/src/tests/patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -702,25 +702,25 @@ fn test() {
51..58 'loop {}': !
56..58 '{}': ()
72..171 '{ ... x); }': ()
78..81 'foo': fn foo<&(i32, &str), i32, impl Fn(&(i32, &str)) -> i32>(&(i32, &str), impl Fn(&(i32, &str)) -> i32) -> i32
78..81 'foo': fn foo<&(i32, &str), i32, impl FnOnce(&(i32, &str)) -> i32>(&(i32, &str), impl FnOnce(&(i32, &str)) -> i32) -> i32
78..105 'foo(&(...y)| x)': i32
82..91 '&(1, "a")': &(i32, &str)
83..91 '(1, "a")': (i32, &str)
84..85 '1': i32
87..90 '"a"': &str
93..104 '|&(x, y)| x': impl Fn(&(i32, &str)) -> i32
93..104 '|&(x, y)| x': impl FnOnce(&(i32, &str)) -> i32
94..101 '&(x, y)': &(i32, &str)
95..101 '(x, y)': (i32, &str)
96..97 'x': i32
99..100 'y': &str
103..104 'x': i32
142..145 'foo': fn foo<&(i32, &str), &i32, impl Fn(&(i32, &str)) -> &i32>(&(i32, &str), impl Fn(&(i32, &str)) -> &i32) -> &i32
142..145 'foo': fn foo<&(i32, &str), &i32, impl FnOnce(&(i32, &str)) -> &i32>(&(i32, &str), impl FnOnce(&(i32, &str)) -> &i32) -> &i32
142..168 'foo(&(...y)| x)': &i32
146..155 '&(1, "a")': &(i32, &str)
147..155 '(1, "a")': (i32, &str)
148..149 '1': i32
151..154 '"a"': &str
157..167 '|(x, y)| x': impl Fn(&(i32, &str)) -> &i32
157..167 '|(x, y)| x': impl FnOnce(&(i32, &str)) -> &i32
158..164 '(x, y)': (i32, &str)
159..160 'x': &i32
162..163 'y': &&str
Expand Down
2 changes: 1 addition & 1 deletion crates/hir-ty/src/tests/regression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ fn main() {
123..126 'S()': S<i32>
132..133 's': S<i32>
132..144 's.g(|_x| {})': ()
136..143 '|_x| {}': impl Fn(&i32)
136..143 '|_x| {}': impl FnOnce(&i32)
137..139 '_x': &i32
141..143 '{}': ()
150..151 's': S<i32>
Expand Down
41 changes: 39 additions & 2 deletions crates/hir-ty/src/tests/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2190,9 +2190,9 @@ fn main() {
149..151 'Ok': extern "rust-call" Ok<(), ()>(()) -> Result<(), ()>
149..155 'Ok(())': Result<(), ()>
152..154 '()': ()
167..171 'test': fn test<(), (), impl Fn() -> impl Future<Output = Result<(), ()>>, impl Future<Output = Result<(), ()>>>(impl Fn() -> impl Future<Output = Result<(), ()>>)
167..171 'test': fn test<(), (), impl FnMut() -> impl Future<Output = Result<(), ()>>, impl Future<Output = Result<(), ()>>>(impl FnMut() -> impl Future<Output = Result<(), ()>>)
167..228 'test(|... })': ()
172..227 '|| asy... }': impl Fn() -> impl Future<Output = Result<(), ()>>
172..227 '|| asy... }': impl FnMut() -> impl Future<Output = Result<(), ()>>
175..227 'async ... }': impl Future<Output = Result<(), ()>>
191..205 'return Err(())': !
198..201 'Err': extern "rust-call" Err<(), ()>(()) -> Result<(), ()>
Expand Down Expand Up @@ -2886,6 +2886,43 @@ fn f() {
)
}

#[test]
fn closure_kind_with_predicates() {
check_types(
r#"
//- minicore: fn
#![feature(unboxed_closures)]

struct X<T: FnOnce()>(T);

fn f1() -> impl FnOnce() {
|| {}
// ^^^^^ impl FnOnce()
}

fn f2(c: impl FnOnce<(), Output = i32>) {}

fn test {
let x1 = X(|| {});
let c1 = x1.0;
// ^^ impl FnOnce()

let c2 = || {};
// ^^ impl Fn()
let x2 = X(c2);
let c3 = x2.0
// ^^ impl Fn()

let c4 = f1();
// ^^ impl FnOnce() + ?Sized

f2(|| { 0 });
// ^^^^^^^^ impl FnOnce() -> i32
}
"#,
)
}

#[test]
fn derive_macro_should_work_for_associated_type() {
check_types(
Expand Down