Skip to content

Commit

Permalink
Auto merge of #119849 - lcnr:eagerly-instantiate-binders, r=<try>
Browse files Browse the repository at this point in the history
eagerly instantiate binders to avoid relying on `sub`

The old solver sometimes incorrectly used `sub`, change it to explicitly instantiate binders and use `eq` instead. While doing so I also moved the instantiation before the normalize calls. This caused some observable changes, will explain these inline. This PR therefore requires a crater run and an FCP.

r? types
  • Loading branch information
bors committed Jan 12, 2024
2 parents 5431404 + 23e69c0 commit 21bc403
Show file tree
Hide file tree
Showing 46 changed files with 312 additions and 435 deletions.
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/traits/mod.rs
Expand Up @@ -620,6 +620,7 @@ pub enum SelectionError<'tcx> {
OpaqueTypeAutoTraitLeakageUnknown(DefId),
}

// FIXME(@lcnr): The `Binder` here should be unnecessary. Just use `TraitRef` instead.
#[derive(Clone, Debug, TypeVisitable)]
pub struct SelectionOutputTypeParameterMismatch<'tcx> {
pub found_trait_ref: ty::PolyTraitRef<'tcx>,
Expand Down
Expand Up @@ -3430,6 +3430,8 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
err
}

// FIXME(@lcnr): This function could be changed to trait `TraitRef` directly
// instead of using a `Binder`.
fn report_type_parameter_mismatch_error(
&self,
obligation: &PredicateObligation<'tcx>,
Expand Down
Expand Up @@ -150,9 +150,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
_ => return,
}

let result = self
.infcx
.probe(|_| self.match_projection_obligation_against_definition_bounds(obligation));
let result = self.match_projection_obligation_against_definition_bounds(obligation);

candidates.vec.extend(result.into_iter().map(|idx| ProjectionCandidate(idx)));
}
Expand Down Expand Up @@ -708,14 +706,11 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let candidate_supertraits = util::supertraits(self.tcx(), principal_trait_ref)
.enumerate()
.filter(|&(_, upcast_trait_ref)| {
self.infcx.probe(|_| {
self.match_normalize_trait_ref(
obligation,
upcast_trait_ref,
placeholder_trait_predicate.trait_ref,
)
.is_ok()
})
self.matches_trait_ref(
obligation,
placeholder_trait_predicate.trait_ref,
upcast_trait_ref,
)
})
.map(|(idx, _)| ObjectCandidate(idx));

Expand Down
35 changes: 25 additions & 10 deletions compiler/rustc_trait_selection/src/traits/select/confirmation.rs
Expand Up @@ -9,7 +9,7 @@
use rustc_ast::Mutability;
use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_hir::lang_items::LangItem;
use rustc_infer::infer::BoundRegionConversionTime::HigherRankedType;
use rustc_infer::infer::HigherRankedType;
use rustc_infer::infer::{DefineOpaqueTypes, InferOk};
use rustc_middle::traits::{BuiltinImplSource, SelectionOutputTypeParameterMismatch};
use rustc_middle::ty::{
Expand Down Expand Up @@ -152,7 +152,6 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let placeholder_trait_predicate =
self.infcx.instantiate_binder_with_placeholders(trait_predicate).trait_ref;
let placeholder_self_ty = placeholder_trait_predicate.self_ty();
let placeholder_trait_predicate = ty::Binder::dummy(placeholder_trait_predicate);
let (def_id, args) = match *placeholder_self_ty.kind() {
// Excluding IATs and type aliases here as they don't have meaningful item bounds.
ty::Alias(ty::Projection | ty::Opaque, ty::AliasTy { def_id, args, .. }) => {
Expand All @@ -167,6 +166,11 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
.as_trait_clause()
.expect("projection candidate is not a trait predicate")
.map_bound(|t| t.trait_ref);
let candidate = self.infcx.instantiate_binder_with_fresh_vars(
obligation.cause.span,
HigherRankedType,
candidate,
);
let mut obligations = Vec::new();
let candidate = normalize_with_depth_to(
self,
Expand All @@ -180,7 +184,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
obligations.extend(self.infcx.commit_if_ok(|_| {
self.infcx
.at(&obligation.cause, obligation.param_env)
.sup(DefineOpaqueTypes::No, placeholder_trait_predicate, candidate)
.eq(DefineOpaqueTypes::No, placeholder_trait_predicate, candidate)
.map(|InferOk { obligations, .. }| obligations)
.map_err(|_| Unimplemented)
})?);
Expand Down Expand Up @@ -486,7 +490,6 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {

let trait_predicate = self.infcx.instantiate_binder_with_placeholders(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(trait_predicate.self_ty());
let obligation_trait_ref = ty::Binder::dummy(trait_predicate.trait_ref);
let ty::Dynamic(data, ..) = *self_ty.kind() else {
span_bug!(obligation.cause.span, "object candidate with non-object");
};
Expand All @@ -507,19 +510,24 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let unnormalized_upcast_trait_ref =
supertraits.nth(index).expect("supertraits iterator no longer has as many elements");

let upcast_trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
obligation.cause.span,
HigherRankedType,
unnormalized_upcast_trait_ref,
);
let upcast_trait_ref = normalize_with_depth_to(
self,
obligation.param_env,
obligation.cause.clone(),
obligation.recursion_depth + 1,
unnormalized_upcast_trait_ref,
upcast_trait_ref,
&mut nested,
);

nested.extend(self.infcx.commit_if_ok(|_| {
self.infcx
.at(&obligation.cause, obligation.param_env)
.sup(DefineOpaqueTypes::No, obligation_trait_ref, upcast_trait_ref)
.eq(DefineOpaqueTypes::No, trait_predicate.trait_ref, upcast_trait_ref)
.map(|InferOk { obligations, .. }| obligations)
.map_err(|_| Unimplemented)
})?);
Expand Down Expand Up @@ -900,7 +908,14 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
obligation: &PolyTraitObligation<'tcx>,
self_ty_trait_ref: ty::PolyTraitRef<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
let obligation_trait_ref = obligation.predicate.to_poly_trait_ref();
let obligation_trait_ref = self
.infcx
.instantiate_binder_with_placeholders(obligation.predicate.to_poly_trait_ref());
let self_ty_trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
obligation.cause.span,
HigherRankedType,
self_ty_trait_ref,
);
// Normalize the obligation and expected trait refs together, because why not
let Normalized { obligations: nested, value: (obligation_trait_ref, expected_trait_ref) } =
ensure_sufficient_stack(|| {
Expand All @@ -916,15 +931,15 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
// needed to define opaque types for tests/ui/type-alias-impl-trait/assoc-projection-ice.rs
self.infcx
.at(&obligation.cause, obligation.param_env)
.sup(DefineOpaqueTypes::Yes, obligation_trait_ref, expected_trait_ref)
.eq(DefineOpaqueTypes::Yes, obligation_trait_ref, expected_trait_ref)
.map(|InferOk { mut obligations, .. }| {
obligations.extend(nested);
obligations
})
.map_err(|terr| {
OutputTypeParameterMismatch(Box::new(SelectionOutputTypeParameterMismatch {
expected_trait_ref: obligation_trait_ref,
found_trait_ref: expected_trait_ref,
expected_trait_ref: ty::Binder::dummy(obligation_trait_ref),
found_trait_ref: ty::Binder::dummy(expected_trait_ref),
terr,
}))
})
Expand Down
95 changes: 42 additions & 53 deletions compiler/rustc_trait_selection/src/traits/select/mod.rs
Expand Up @@ -33,6 +33,7 @@ use rustc_errors::Diagnostic;
use rustc_hir as hir;
use rustc_hir::def_id::DefId;
use rustc_infer::infer::BoundRegionConversionTime;
use rustc_infer::infer::BoundRegionConversionTime::HigherRankedType;
use rustc_infer::infer::DefineOpaqueTypes;
use rustc_infer::traits::TraitObligation;
use rustc_middle::dep_graph::dep_kinds;
Expand All @@ -43,7 +44,7 @@ use rustc_middle::ty::abstract_const::NotConstEvaluatable;
use rustc_middle::ty::fold::BottomUpFolder;
use rustc_middle::ty::relate::TypeRelation;
use rustc_middle::ty::GenericArgsRef;
use rustc_middle::ty::{self, EarlyBinder, PolyProjectionPredicate, ToPolyTraitRef, ToPredicate};
use rustc_middle::ty::{self, EarlyBinder, PolyProjectionPredicate, ToPredicate};
use rustc_middle::ty::{Ty, TyCtxt, TypeFoldable, TypeVisitableExt};
use rustc_span::symbol::sym;
use rustc_span::Symbol;
Expand Down Expand Up @@ -1627,33 +1628,18 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
};
let bounds = tcx.item_bounds(def_id).instantiate(tcx, args);

// The bounds returned by `item_bounds` may contain duplicates after
// normalization, so try to deduplicate when possible to avoid
// unnecessary ambiguity.
let mut distinct_normalized_bounds = FxHashSet::default();

bounds
.iter()
.enumerate()
.filter_map(|(idx, bound)| {
let bound_predicate = bound.kind();
if let ty::ClauseKind::Trait(pred) = bound_predicate.skip_binder() {
let bound = bound_predicate.rebind(pred.trait_ref);
if self.infcx.probe(|_| {
match self.match_normalize_trait_ref(
obligation,
bound,
placeholder_trait_predicate.trait_ref,
) {
Ok(None) => true,
Ok(Some(normalized_trait))
if distinct_normalized_bounds.insert(normalized_trait) =>
{
true
}
_ => false,
}
}) {
if self.matches_trait_ref(
obligation,
placeholder_trait_predicate.trait_ref,
bound,
) {
return Some(idx);
}
}
Expand All @@ -1662,43 +1648,40 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
.collect()
}

/// Equates the trait in `obligation` with trait bound. If the two traits
/// can be equated and the normalized trait bound doesn't contain inference
/// variables or placeholders, the normalized bound is returned.
fn match_normalize_trait_ref(
/// Equates the trait in `obligation` with trait bound and returns
/// true if the two traits can be equated.
fn matches_trait_ref(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
trait_bound: ty::PolyTraitRef<'tcx>,
placeholder_trait_ref: ty::TraitRef<'tcx>,
) -> Result<Option<ty::PolyTraitRef<'tcx>>, ()> {
trait_bound: ty::PolyTraitRef<'tcx>,
) -> bool {
debug_assert!(!placeholder_trait_ref.has_escaping_bound_vars());
if placeholder_trait_ref.def_id != trait_bound.def_id() {
// Avoid unnecessary normalization
return Err(());
return false;
}

let Normalized { value: trait_bound, obligations: _ } = ensure_sufficient_stack(|| {
project::normalize_with_depth(
self,
obligation.param_env,
obligation.cause.clone(),
obligation.recursion_depth + 1,
self.infcx.probe(|_| {
let trait_bound = self.infcx.instantiate_binder_with_fresh_vars(
obligation.cause.span,
HigherRankedType,
trait_bound,
)
});
self.infcx
.at(&obligation.cause, obligation.param_env)
.sup(DefineOpaqueTypes::No, ty::Binder::dummy(placeholder_trait_ref), trait_bound)
.map(|InferOk { obligations: _, value: () }| {
// This method is called within a probe, so we can't have
// inference variables and placeholders escape.
if !trait_bound.has_infer() && !trait_bound.has_placeholders() {
Some(trait_bound)
} else {
None
}
})
.map_err(|_| ())
);
let Normalized { value: trait_bound, obligations: _ } = ensure_sufficient_stack(|| {
project::normalize_with_depth(
self,
obligation.param_env,
obligation.cause.clone(),
obligation.recursion_depth + 1,
trait_bound,
)
});
self.infcx
.at(&obligation.cause, obligation.param_env)
.eq(DefineOpaqueTypes::No, placeholder_trait_ref, trait_bound)
.is_ok()
})
}

fn where_clause_may_apply<'o>(
Expand Down Expand Up @@ -1750,7 +1733,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let is_match = self
.infcx
.at(&obligation.cause, obligation.param_env)
.sup(DefineOpaqueTypes::No, obligation.predicate, infer_projection)
.eq(DefineOpaqueTypes::No, obligation.predicate, infer_projection)
.is_ok_and(|InferOk { obligations, value: () }| {
self.evaluate_predicates_recursively(
TraitObligationStackList::empty(&ProvisionalEvaluationCache::default()),
Expand Down Expand Up @@ -2532,7 +2515,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
nested.extend(
self.infcx
.at(&obligation.cause, obligation.param_env)
.sup(
.eq(
DefineOpaqueTypes::No,
upcast_principal.map_bound(|trait_ref| {
ty::ExistentialTraitRef::erase_self_ty(tcx, trait_ref)
Expand Down Expand Up @@ -2570,7 +2553,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
nested.extend(
self.infcx
.at(&obligation.cause, obligation.param_env)
.sup(DefineOpaqueTypes::No, source_projection, target_projection)
.eq(DefineOpaqueTypes::No, source_projection, target_projection)
.map_err(|_| SelectionError::Unimplemented)?
.into_obligations(),
);
Expand Down Expand Up @@ -2614,9 +2597,15 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
obligation: &PolyTraitObligation<'tcx>,
poly_trait_ref: ty::PolyTraitRef<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, ()> {
let predicate = self.infcx.instantiate_binder_with_placeholders(obligation.predicate);
let trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
obligation.cause.span,
HigherRankedType,
poly_trait_ref,
);
self.infcx
.at(&obligation.cause, obligation.param_env)
.sup(DefineOpaqueTypes::No, obligation.predicate.to_poly_trait_ref(), poly_trait_ref)
.eq(DefineOpaqueTypes::No, predicate.trait_ref, trait_ref)
.map(|InferOk { obligations, .. }| obligations)
.map_err(|_| ())
}
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_trait_selection/src/traits/vtable.rs
Expand Up @@ -320,6 +320,7 @@ fn vtable_entries<'tcx>(
}

/// Find slot base for trait methods within vtable entries of another trait
// FIXME(@lcnr): This isn't a query, so why does it take a tuple as its argument.
pub(super) fn vtable_trait_first_method_offset<'tcx>(
tcx: TyCtxt<'tcx>,
key: (
Expand Down
Expand Up @@ -7,7 +7,7 @@ LL | #![feature(return_type_notation)]
= note: see issue #109417 <https://github.com/rust-lang/rust/issues/109417> for more information
= note: `#[warn(incomplete_features)]` on by default

error[E0308]: mismatched types
error: implementation of `Send` is not general enough
--> $DIR/issue-110963-early.rs:14:5
|
LL | / spawn(async move {
Expand All @@ -16,17 +16,12 @@ LL | | if !hc.check().await {
LL | | log_health_check_failure().await;
LL | | }
LL | | });
| |______^ one type is more general than the other
| |______^ implementation of `Send` is not general enough
|
= note: expected trait `Send`
found trait `for<'a> Send`
note: the lifetime requirement is introduced here
--> $DIR/issue-110963-early.rs:34:17
|
LL | F: Future + Send + 'static,
| ^^^^
= note: `Send` would have to be implemented for the type `impl Future<Output = bool> { <HC as HealthCheck>::check<'0>() }`, for any two lifetimes `'0` and `'1`...
= note: ...but `Send` is actually implemented for the type `impl Future<Output = bool> { <HC as HealthCheck>::check<'2>() }`, for some specific lifetime `'2`

error[E0308]: mismatched types
error: implementation of `Send` is not general enough
--> $DIR/issue-110963-early.rs:14:5
|
LL | / spawn(async move {
Expand All @@ -35,17 +30,11 @@ LL | | if !hc.check().await {
LL | | log_health_check_failure().await;
LL | | }
LL | | });
| |______^ one type is more general than the other
|
= note: expected trait `Send`
found trait `for<'a> Send`
note: the lifetime requirement is introduced here
--> $DIR/issue-110963-early.rs:34:17
| |______^ implementation of `Send` is not general enough
|
LL | F: Future + Send + 'static,
| ^^^^
= note: `Send` would have to be implemented for the type `impl Future<Output = bool> { <HC as HealthCheck>::check<'0>() }`, for any two lifetimes `'0` and `'1`...
= note: ...but `Send` is actually implemented for the type `impl Future<Output = bool> { <HC as HealthCheck>::check<'2>() }`, for some specific lifetime `'2`
= note: duplicate diagnostic emitted due to `-Z deduplicate-diagnostics=no`

error: aborting due to 2 previous errors; 1 warning emitted

For more information about this error, try `rustc --explain E0308`.
2 changes: 1 addition & 1 deletion tests/ui/closures/multiple-fn-bounds.stderr
Expand Up @@ -7,7 +7,7 @@ LL | foo(move |x| v);
| expected due to this
|
= note: expected closure signature `fn(_) -> _`
found closure signature `for<'a> fn(&'a _) -> _`
found closure signature `fn(&_) -> _`
note: closure inferred to have a different signature due to this bound
--> $DIR/multiple-fn-bounds.rs:1:11
|
Expand Down

0 comments on commit 21bc403

Please sign in to comment.