From 0bf1d73d229fdd0c22ef87b1c764c88cf35dd616 Mon Sep 17 00:00:00 2001 From: Matthew Jasper Date: Fri, 12 Feb 2021 22:07:46 +0000 Subject: [PATCH] Don't go through TraitRef to relate projections --- compiler/rustc_infer/src/infer/at.rs | 25 ++++++++++++++++++- .../src/traits/project.rs | 15 ++++++----- .../src/traits/select/mod.rs | 20 ++++++++------- 3 files changed, 42 insertions(+), 18 deletions(-) diff --git a/compiler/rustc_infer/src/infer/at.rs b/compiler/rustc_infer/src/infer/at.rs index a7749d33b7c13..11ee8fb17ad1b 100644 --- a/compiler/rustc_infer/src/infer/at.rs +++ b/compiler/rustc_infer/src/infer/at.rs @@ -55,6 +55,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { pub trait ToTrace<'tcx>: Relate<'tcx> + Copy { fn to_trace( + tcx: TyCtxt<'tcx>, cause: &ObligationCause<'tcx>, a_is_expected: bool, a: Self, @@ -178,7 +179,7 @@ impl<'a, 'tcx> At<'a, 'tcx> { where T: ToTrace<'tcx>, { - let trace = ToTrace::to_trace(self.cause, a_is_expected, a, b); + let trace = ToTrace::to_trace(self.infcx.tcx, self.cause, a_is_expected, a, b); Trace { at: self, trace, a_is_expected } } } @@ -251,6 +252,7 @@ impl<'a, 'tcx> Trace<'a, 'tcx> { impl<'tcx> ToTrace<'tcx> for Ty<'tcx> { fn to_trace( + _: TyCtxt<'tcx>, cause: &ObligationCause<'tcx>, a_is_expected: bool, a: Self, @@ -262,6 +264,7 @@ impl<'tcx> ToTrace<'tcx> for Ty<'tcx> { impl<'tcx> ToTrace<'tcx> for ty::Region<'tcx> { fn to_trace( + _: TyCtxt<'tcx>, cause: &ObligationCause<'tcx>, a_is_expected: bool, a: Self, @@ -273,6 +276,7 @@ impl<'tcx> ToTrace<'tcx> for ty::Region<'tcx> { impl<'tcx> ToTrace<'tcx> for &'tcx Const<'tcx> { fn to_trace( + _: TyCtxt<'tcx>, cause: &ObligationCause<'tcx>, a_is_expected: bool, a: Self, @@ -284,6 +288,7 @@ impl<'tcx> ToTrace<'tcx> for &'tcx Const<'tcx> { impl<'tcx> ToTrace<'tcx> for ty::TraitRef<'tcx> { fn to_trace( + _: TyCtxt<'tcx>, cause: &ObligationCause<'tcx>, a_is_expected: bool, a: Self, @@ -298,6 +303,7 @@ impl<'tcx> ToTrace<'tcx> for ty::TraitRef<'tcx> { impl<'tcx> ToTrace<'tcx> for ty::PolyTraitRef<'tcx> { fn to_trace( + _: TyCtxt<'tcx>, cause: &ObligationCause<'tcx>, a_is_expected: bool, a: Self, @@ -309,3 +315,20 @@ impl<'tcx> ToTrace<'tcx> for ty::PolyTraitRef<'tcx> { } } } + +impl<'tcx> ToTrace<'tcx> for ty::ProjectionTy<'tcx> { + fn to_trace( + tcx: TyCtxt<'tcx>, + cause: &ObligationCause<'tcx>, + a_is_expected: bool, + a: Self, + b: Self, + ) -> TypeTrace<'tcx> { + let a_ty = tcx.mk_projection(a.item_def_id, a.substs); + let b_ty = tcx.mk_projection(b.item_def_id, b.substs); + TypeTrace { + cause: cause.clone(), + values: Types(ExpectedFound::new(a_is_expected, a_ty, b_ty)), + } + } +} diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs index 6908480f431e6..a38e3817a95e5 100644 --- a/compiler/rustc_trait_selection/src/traits/project.rs +++ b/compiler/rustc_trait_selection/src/traits/project.rs @@ -921,8 +921,7 @@ fn assemble_candidates_from_predicates<'cx, 'tcx>( && infcx.probe(|_| { selcx.match_projection_projections( obligation, - obligation_trait_ref, - &data, + data, potentially_unnormalized_candidates, ) }); @@ -1344,25 +1343,25 @@ fn confirm_param_env_candidate<'cx, 'tcx>( poly_cache_entry, ); - let cache_trait_ref = cache_entry.projection_ty.trait_ref(infcx.tcx); - let obligation_trait_ref = obligation.predicate.trait_ref(infcx.tcx); + let cache_projection = cache_entry.projection_ty; + let obligation_projection = obligation.predicate; let mut nested_obligations = Vec::new(); - let cache_trait_ref = if potentially_unnormalized_candidate { + let cache_projection = if potentially_unnormalized_candidate { ensure_sufficient_stack(|| { normalize_with_depth_to( selcx, obligation.param_env, obligation.cause.clone(), obligation.recursion_depth + 1, - cache_trait_ref, + cache_projection, &mut nested_obligations, ) }) } else { - cache_trait_ref + cache_projection }; - match infcx.at(cause, param_env).eq(cache_trait_ref, obligation_trait_ref) { + match infcx.at(cause, param_env).eq(cache_projection, obligation_projection) { Ok(InferOk { value: _, obligations }) => { nested_obligations.extend(obligations); assoc_ty_own_obligations(selcx, obligation, &mut nested_obligations); diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs index 87c8099dc3a51..f0970a48808ad 100644 --- a/compiler/rustc_trait_selection/src/traits/select/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs @@ -32,6 +32,7 @@ use rustc_errors::ErrorReported; use rustc_hir as hir; use rustc_hir::def_id::DefId; use rustc_hir::Constness; +use rustc_infer::infer::LateBoundRegionConversionTime; use rustc_middle::dep_graph::{DepKind, DepNodeIndex}; use rustc_middle::mir::interpret::ErrorHandled; use rustc_middle::ty::fast_reject; @@ -1254,32 +1255,33 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { pub(super) fn match_projection_projections( &mut self, obligation: &ProjectionTyObligation<'tcx>, - obligation_trait_ref: &ty::TraitRef<'tcx>, - data: &PolyProjectionPredicate<'tcx>, + env_predicate: PolyProjectionPredicate<'tcx>, potentially_unnormalized_candidates: bool, ) -> bool { let mut nested_obligations = Vec::new(); - let projection_ty = if potentially_unnormalized_candidates { + let (infer_predicate, _) = self.infcx.replace_bound_vars_with_fresh_vars( + obligation.cause.span, + LateBoundRegionConversionTime::HigherRankedType, + env_predicate, + ); + let infer_projection = if potentially_unnormalized_candidates { ensure_sufficient_stack(|| { project::normalize_with_depth_to( self, obligation.param_env, obligation.cause.clone(), obligation.recursion_depth + 1, - data.map_bound(|data| data.projection_ty), + infer_predicate.projection_ty, &mut nested_obligations, ) }) } else { - data.map_bound(|data| data.projection_ty) + infer_predicate.projection_ty }; - // FIXME(generic_associated_types): Compare the whole projections - let data_poly_trait_ref = projection_ty.map_bound(|proj| proj.trait_ref(self.tcx())); - let obligation_poly_trait_ref = ty::Binder::dummy(*obligation_trait_ref); self.infcx .at(&obligation.cause, obligation.param_env) - .sup(obligation_poly_trait_ref, data_poly_trait_ref) + .sup(obligation.predicate, infer_projection) .map_or(false, |InferOk { obligations, value: () }| { self.evaluate_predicates_recursively( TraitObligationStackList::empty(&ProvisionalEvaluationCache::default()),