From ecef296a038772f04ed3c10530b123d68718ad20 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Sat, 6 Apr 2024 02:05:17 -0400 Subject: [PATCH] Simplify shallow resolver to just fold ty/consts --- .../src/type_check/relate_tys.rs | 2 +- .../src/infer/canonical/canonicalizer.rs | 2 +- compiler/rustc_infer/src/infer/mod.rs | 166 ++++++++---------- .../rustc_infer/src/infer/relate/combine.rs | 4 +- compiler/rustc_infer/src/infer/resolve.rs | 14 +- .../src/solve/normalize.rs | 2 +- .../src/traits/coherence.rs | 2 +- .../rustc_trait_selection/src/traits/mod.rs | 2 +- .../src/traits/select/candidate_assembly.rs | 6 +- .../src/traits/select/confirmation.rs | 9 +- .../src/traits/select/mod.rs | 8 +- .../rustc_trait_selection/src/traits/util.rs | 2 +- .../rustc_trait_selection/src/traits/wf.rs | 2 +- 13 files changed, 96 insertions(+), 125 deletions(-) diff --git a/compiler/rustc_borrowck/src/type_check/relate_tys.rs b/compiler/rustc_borrowck/src/type_check/relate_tys.rs index 78609a482ed22..1c6eccbdce754 100644 --- a/compiler/rustc_borrowck/src/type_check/relate_tys.rs +++ b/compiler/rustc_borrowck/src/type_check/relate_tys.rs @@ -437,7 +437,7 @@ impl<'bccx, 'tcx> TypeRelation<'tcx> for NllTypeRelating<'_, 'bccx, 'tcx> { a: ty::Const<'tcx>, b: ty::Const<'tcx>, ) -> RelateResult<'tcx, ty::Const<'tcx>> { - let a = self.type_checker.infcx.shallow_resolve(a); + let a = self.type_checker.infcx.shallow_resolve_const(a); assert!(!a.has_non_region_infer(), "unexpected inference var {:?}", a); assert!(!b.has_non_region_infer(), "unexpected inference var {:?}", b); diff --git a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs index 825c3bf82fc89..4d712e9ffd372 100644 --- a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs +++ b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs @@ -802,7 +802,7 @@ impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> { const_var: ty::Const<'tcx>, ) -> ty::Const<'tcx> { debug_assert!( - !self.infcx.is_some_and(|infcx| const_var != infcx.shallow_resolve(const_var)) + !self.infcx.is_some_and(|infcx| const_var != infcx.shallow_resolve_const(const_var)) ); let var = self.canonical_var(info, const_var.into()); ty::Const::new_bound(self.tcx, self.binder_index, var, self.fold_ty(const_var.ty())) diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs index f2fd50a47d51b..e7c68d848b785 100644 --- a/compiler/rustc_infer/src/infer/mod.rs +++ b/compiler/rustc_infer/src/infer/mod.rs @@ -1253,19 +1253,76 @@ impl<'tcx> InferCtxt<'tcx> { } } - /// Resolve any type variables found in `value` -- but only one - /// level. So, if the variable `?X` is bound to some type - /// `Foo`, then this would return `Foo` (but `?Y` may - /// itself be bound to a type). - /// - /// Useful when you only need to inspect the outermost level of - /// the type and don't care about nested types (or perhaps you - /// will be resolving them as well, e.g. in a loop). - pub fn shallow_resolve(&self, value: T) -> T - where - T: TypeFoldable>, - { - value.fold_with(&mut ShallowResolver { infcx: self }) + pub fn shallow_resolve(&self, ty: Ty<'tcx>) -> Ty<'tcx> { + if let ty::Infer(v) = ty.kind() { self.fold_infer_ty(*v).unwrap_or(ty) } else { ty } + } + + // This is separate from `shallow_resolve` to keep that method small and inlinable. + #[inline(never)] + fn fold_infer_ty(&self, v: InferTy) -> Option> { + match v { + ty::TyVar(v) => { + // Not entirely obvious: if `typ` is a type variable, + // it can be resolved to an int/float variable, which + // can then be recursively resolved, hence the + // recursion. Note though that we prevent type + // variables from unifying to other type variables + // directly (though they may be embedded + // structurally), and we prevent cycles in any case, + // so this recursion should always be of very limited + // depth. + // + // Note: if these two lines are combined into one we get + // dynamic borrow errors on `self.inner`. + let known = self.inner.borrow_mut().type_variables().probe(v).known(); + known.map(|t| self.shallow_resolve(t)) + } + + ty::IntVar(v) => self + .inner + .borrow_mut() + .int_unification_table() + .probe_value(v) + .map(|v| v.to_type(self.tcx)), + + ty::FloatVar(v) => self + .inner + .borrow_mut() + .float_unification_table() + .probe_value(v) + .map(|v| v.to_type(self.tcx)), + + ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) => None, + } + } + + pub fn shallow_resolve_const(&self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> { + match ct.kind() { + ty::ConstKind::Infer(infer_ct) => match infer_ct { + InferConst::Var(vid) => self + .inner + .borrow_mut() + .const_unification_table() + .probe_value(vid) + .known() + .unwrap_or(ct), + InferConst::EffectVar(vid) => self + .inner + .borrow_mut() + .effect_unification_table() + .probe_value(vid) + .known() + .unwrap_or(ct), + InferConst::Fresh(_) => ct, + }, + ty::ConstKind::Param(_) + | ty::ConstKind::Bound(_, _) + | ty::ConstKind::Placeholder(_) + | ty::ConstKind::Unevaluated(_) + | ty::ConstKind::Value(_) + | ty::ConstKind::Error(_) + | ty::ConstKind::Expr(_) => ct, + } } pub fn root_var(&self, var: ty::TyVid) -> ty::TyVid { @@ -1782,89 +1839,6 @@ impl<'tcx> TypeFolder> for InferenceLiteralEraser<'tcx> { } } -struct ShallowResolver<'a, 'tcx> { - infcx: &'a InferCtxt<'tcx>, -} - -impl<'a, 'tcx> TypeFolder> for ShallowResolver<'a, 'tcx> { - fn interner(&self) -> TyCtxt<'tcx> { - self.infcx.tcx - } - - /// If `ty` is a type variable of some kind, resolve it one level - /// (but do not resolve types found in the result). If `typ` is - /// not a type variable, just return it unmodified. - #[inline] - fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> { - if let ty::Infer(v) = ty.kind() { self.fold_infer_ty(*v).unwrap_or(ty) } else { ty } - } - - fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> { - match ct.kind() { - ty::ConstKind::Infer(InferConst::Var(vid)) => self - .infcx - .inner - .borrow_mut() - .const_unification_table() - .probe_value(vid) - .known() - .unwrap_or(ct), - ty::ConstKind::Infer(InferConst::EffectVar(vid)) => self - .infcx - .inner - .borrow_mut() - .effect_unification_table() - .probe_value(vid) - .known() - .unwrap_or(ct), - _ => ct, - } - } -} - -impl<'a, 'tcx> ShallowResolver<'a, 'tcx> { - // This is separate from `fold_ty` to keep that method small and inlinable. - #[inline(never)] - fn fold_infer_ty(&mut self, v: InferTy) -> Option> { - match v { - ty::TyVar(v) => { - // Not entirely obvious: if `typ` is a type variable, - // it can be resolved to an int/float variable, which - // can then be recursively resolved, hence the - // recursion. Note though that we prevent type - // variables from unifying to other type variables - // directly (though they may be embedded - // structurally), and we prevent cycles in any case, - // so this recursion should always be of very limited - // depth. - // - // Note: if these two lines are combined into one we get - // dynamic borrow errors on `self.inner`. - let known = self.infcx.inner.borrow_mut().type_variables().probe(v).known(); - known.map(|t| self.fold_ty(t)) - } - - ty::IntVar(v) => self - .infcx - .inner - .borrow_mut() - .int_unification_table() - .probe_value(v) - .map(|v| v.to_type(self.infcx.tcx)), - - ty::FloatVar(v) => self - .infcx - .inner - .borrow_mut() - .float_unification_table() - .probe_value(v) - .map(|v| v.to_type(self.infcx.tcx)), - - ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) => None, - } - } -} - impl<'tcx> TypeTrace<'tcx> { pub fn span(&self) -> Span { self.cause.span diff --git a/compiler/rustc_infer/src/infer/relate/combine.rs b/compiler/rustc_infer/src/infer/relate/combine.rs index 28b7db275a307..8a3125f9dedf8 100644 --- a/compiler/rustc_infer/src/infer/relate/combine.rs +++ b/compiler/rustc_infer/src/infer/relate/combine.rs @@ -155,8 +155,8 @@ impl<'tcx> InferCtxt<'tcx> { return Ok(a); } - let a = self.shallow_resolve(a); - let b = self.shallow_resolve(b); + let a = self.shallow_resolve_const(a); + let b = self.shallow_resolve_const(b); // We should never have to relate the `ty` field on `Const` as it is checked elsewhere that consts have the // correct type for the generic param they are an argument for. However there have been a number of cases diff --git a/compiler/rustc_infer/src/infer/resolve.rs b/compiler/rustc_infer/src/infer/resolve.rs index d5999331dfab6..758aac004dcfd 100644 --- a/compiler/rustc_infer/src/infer/resolve.rs +++ b/compiler/rustc_infer/src/infer/resolve.rs @@ -12,21 +12,19 @@ use rustc_middle::ty::{self, Const, InferConst, Ty, TyCtxt, TypeFoldable}; /// useful for printing messages etc but also required at various /// points for correctness. pub struct OpportunisticVarResolver<'a, 'tcx> { - // The shallow resolver is used to resolve inference variables at every - // level of the type. - shallow_resolver: crate::infer::ShallowResolver<'a, 'tcx>, + infcx: &'a InferCtxt<'tcx>, } impl<'a, 'tcx> OpportunisticVarResolver<'a, 'tcx> { #[inline] pub fn new(infcx: &'a InferCtxt<'tcx>) -> Self { - OpportunisticVarResolver { shallow_resolver: crate::infer::ShallowResolver { infcx } } + OpportunisticVarResolver { infcx } } } impl<'a, 'tcx> TypeFolder> for OpportunisticVarResolver<'a, 'tcx> { fn interner(&self) -> TyCtxt<'tcx> { - TypeFolder::interner(&self.shallow_resolver) + self.infcx.tcx } #[inline] @@ -34,7 +32,7 @@ impl<'a, 'tcx> TypeFolder> for OpportunisticVarResolver<'a, 'tcx> { if !t.has_non_region_infer() { t // micro-optimize -- if there is nothing in this type that this fold affects... } else { - let t = self.shallow_resolver.fold_ty(t); + let t = self.infcx.shallow_resolve(t); t.super_fold_with(self) } } @@ -43,7 +41,7 @@ impl<'a, 'tcx> TypeFolder> for OpportunisticVarResolver<'a, 'tcx> { if !ct.has_non_region_infer() { ct // micro-optimize -- if there is nothing in this const that this fold affects... } else { - let ct = self.shallow_resolver.fold_const(ct); + let ct = self.infcx.shallow_resolve_const(ct); ct.super_fold_with(self) } } @@ -160,7 +158,7 @@ impl<'a, 'tcx> FallibleTypeFolder> for FullTypeResolver<'a, 'tcx> { if !c.has_infer() { Ok(c) // micro-optimize -- if there is nothing in this const that this fold affects... } else { - let c = self.infcx.shallow_resolve(c); + let c = self.infcx.shallow_resolve_const(c); match c.kind() { ty::ConstKind::Infer(InferConst::Var(vid)) => { return Err(FixupError::UnresolvedConst(vid)); diff --git a/compiler/rustc_trait_selection/src/solve/normalize.rs b/compiler/rustc_trait_selection/src/solve/normalize.rs index 5b45e1a34e485..43f36d18cfc96 100644 --- a/compiler/rustc_trait_selection/src/solve/normalize.rs +++ b/compiler/rustc_trait_selection/src/solve/normalize.rs @@ -203,7 +203,7 @@ impl<'tcx> FallibleTypeFolder> for NormalizationFolder<'_, 'tcx> { #[instrument(level = "debug", skip(self), ret)] fn try_fold_const(&mut self, ct: ty::Const<'tcx>) -> Result, Self::Error> { let infcx = self.at.infcx; - debug_assert_eq!(ct, infcx.shallow_resolve(ct)); + debug_assert_eq!(ct, infcx.shallow_resolve_const(ct)); if !ct.has_aliases() { return Ok(ct); } diff --git a/compiler/rustc_trait_selection/src/traits/coherence.rs b/compiler/rustc_trait_selection/src/traits/coherence.rs index 77eaa4fd03eb8..36028b516591b 100644 --- a/compiler/rustc_trait_selection/src/traits/coherence.rs +++ b/compiler/rustc_trait_selection/src/traits/coherence.rs @@ -501,7 +501,7 @@ fn plug_infer_with_placeholders<'tcx>( } fn visit_const(&mut self, ct: ty::Const<'tcx>) { - let ct = self.infcx.shallow_resolve(ct); + let ct = self.infcx.shallow_resolve_const(ct); if ct.is_ct_infer() { let Ok(InferOk { value: (), obligations }) = self.infcx.at(&ObligationCause::dummy(), ty::ParamEnv::empty()).eq( diff --git a/compiler/rustc_trait_selection/src/traits/mod.rs b/compiler/rustc_trait_selection/src/traits/mod.rs index 98d5b466cd00d..8f5a30c436d33 100644 --- a/compiler/rustc_trait_selection/src/traits/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/mod.rs @@ -162,7 +162,7 @@ fn pred_known_to_hold_modulo_regions<'tcx>( let errors = ocx.select_all_or_error(); match errors.as_slice() { // Only known to hold if we did no inference. - [] => infcx.shallow_resolve(goal) == goal, + [] => infcx.resolve_vars_if_possible(goal) == goal, errors => { debug!(?errors); diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs index 974e5ef0e166a..c544a0b790fa1 100644 --- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs +++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs @@ -1165,8 +1165,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { return; } - let self_ty = self.infcx.shallow_resolve(obligation.self_ty()); - match self_ty.skip_binder().kind() { + let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder()); + match self_ty.kind() { ty::Alias(..) | ty::Dynamic(..) | ty::Error(_) @@ -1317,7 +1317,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { obligation: &PolyTraitObligation<'tcx>, candidates: &mut SelectionCandidateSet<'tcx>, ) { - let self_ty = self.infcx.shallow_resolve(obligation.self_ty()); + let self_ty = self.infcx.resolve_vars_if_possible(obligation.self_ty()); match self_ty.skip_binder().kind() { ty::FnPtr(_) => candidates.vec.push(BuiltinCandidate { has_nested: false }), diff --git a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs index 716b9a49ab543..4fa2455c42de1 100644 --- a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs +++ b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs @@ -157,10 +157,9 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { ) -> Result>, SelectionError<'tcx>> { let tcx = self.tcx(); - let trait_predicate = self.infcx.shallow_resolve(obligation.predicate); let placeholder_trait_predicate = - self.infcx.enter_forall_and_leak_universe(trait_predicate).trait_ref; - let placeholder_self_ty = placeholder_trait_predicate.self_ty(); + self.infcx.enter_forall_and_leak_universe(obligation.predicate).trait_ref; + let placeholder_self_ty = self.infcx.shallow_resolve(placeholder_trait_predicate.self_ty()); let candidate_predicate = self .for_each_item_bound( placeholder_self_ty, @@ -422,7 +421,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { ) -> Result>, SelectionError<'tcx>> { debug!(?obligation, "confirm_auto_impl_candidate"); - let self_ty = self.infcx.shallow_resolve(obligation.predicate.self_ty()); + let self_ty = obligation.predicate.self_ty().map_bound(|ty| self.infcx.shallow_resolve(ty)); let types = self.constituent_types_for_ty(self_ty)?; Ok(self.vtable_auto_impl(obligation, obligation.predicate.def_id(), types)) } @@ -1378,7 +1377,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { let drop_trait = self.tcx().require_lang_item(LangItem::Drop, None); let tcx = self.tcx(); - let self_ty = self.infcx.shallow_resolve(obligation.self_ty()); + let self_ty = obligation.self_ty().map_bound(|ty| self.infcx.shallow_resolve(ty)); let mut nested = vec![]; let cause = obligation.derived_cause(BuiltinDerivedObligation); diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs index e363119393ad1..10370c7898b0d 100644 --- a/compiler/rustc_trait_selection/src/traits/select/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs @@ -571,7 +571,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { )?; // If the predicate has done any inference, then downgrade the // result to ambiguous. - if this.infcx.shallow_resolve(goal) != goal { + if this.infcx.resolve_vars_if_possible(goal) != goal { result = result.max(EvaluatedToAmbig); } Ok(result) @@ -1774,9 +1774,9 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { // that means that we must have newly inferred something about the GAT. // We should give up in that case. if !generics.params.is_empty() - && obligation.predicate.args[generics.parent_count..] - .iter() - .any(|&p| p.has_non_region_infer() && self.infcx.shallow_resolve(p) != p) + && obligation.predicate.args[generics.parent_count..].iter().any(|&p| { + p.has_non_region_infer() && self.infcx.resolve_vars_if_possible(p) != p + }) { ProjectionMatchesProjection::Ambiguous } else { diff --git a/compiler/rustc_trait_selection/src/traits/util.rs b/compiler/rustc_trait_selection/src/traits/util.rs index d29fc7921bcbd..a8ca7d164a092 100644 --- a/compiler/rustc_trait_selection/src/traits/util.rs +++ b/compiler/rustc_trait_selection/src/traits/util.rs @@ -649,7 +649,7 @@ impl<'tcx> TypeFolder> for PlaceholderReplacer<'_, 'tcx> { } fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> { - let ct = self.infcx.shallow_resolve(ct); + let ct = self.infcx.shallow_resolve_const(ct); if let ty::ConstKind::Placeholder(p) = ct.kind() { let replace_var = self.mapped_consts.get(&p); match replace_var { diff --git a/compiler/rustc_trait_selection/src/traits/wf.rs b/compiler/rustc_trait_selection/src/traits/wf.rs index 5553490542b75..f1c24b6adc1ac 100644 --- a/compiler/rustc_trait_selection/src/traits/wf.rs +++ b/compiler/rustc_trait_selection/src/traits/wf.rs @@ -44,7 +44,7 @@ pub fn obligations<'tcx>( GenericArgKind::Const(ct) => { match ct.kind() { ty::ConstKind::Infer(_) => { - let resolved = infcx.shallow_resolve(ct); + let resolved = infcx.shallow_resolve_const(ct); if resolved == ct { // No progress. return None;