From eaf10dcb70718b810aaa74eb1b13d87a89612117 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Tue, 11 Apr 2023 19:37:31 +0000 Subject: [PATCH 1/2] Normalize types in writeback results with new solver --- compiler/rustc_hir_typeck/src/writeback.rs | 36 ++++++++++++---------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/compiler/rustc_hir_typeck/src/writeback.rs b/compiler/rustc_hir_typeck/src/writeback.rs index a4c6dd4332a60..2daec205cfc28 100644 --- a/compiler/rustc_hir_typeck/src/writeback.rs +++ b/compiler/rustc_hir_typeck/src/writeback.rs @@ -9,7 +9,6 @@ use rustc_errors::ErrorGuaranteed; use rustc_hir as hir; use rustc_hir::intravisit::{self, Visitor}; use rustc_infer::infer::error_reporting::TypeAnnotationNeeded::E0282; -use rustc_infer::infer::InferCtxt; use rustc_middle::hir::place::Place as HirPlace; use rustc_middle::mir::FakeReadCause; use rustc_middle::ty::adjustment::{Adjust, Adjustment, PointerCast}; @@ -737,8 +736,7 @@ impl Locatable for hir::HirId { /// The Resolver. This is the type folding engine that detects /// unresolved types and so forth. struct Resolver<'cx, 'tcx> { - tcx: TyCtxt<'tcx>, - infcx: &'cx InferCtxt<'tcx>, + fcx: &'cx FnCtxt<'cx, 'tcx>, span: &'cx dyn Locatable, body: &'tcx hir::Body<'tcx>, @@ -752,18 +750,18 @@ impl<'cx, 'tcx> Resolver<'cx, 'tcx> { span: &'cx dyn Locatable, body: &'tcx hir::Body<'tcx>, ) -> Resolver<'cx, 'tcx> { - Resolver { tcx: fcx.tcx, infcx: fcx, span, body, replaced_with_error: None } + Resolver { fcx, span, body, replaced_with_error: None } } fn report_error(&self, p: impl Into>) -> ErrorGuaranteed { - match self.tcx.sess.has_errors() { + match self.fcx.tcx.sess.has_errors() { Some(e) => e, None => self - .infcx + .fcx .err_ctxt() .emit_inference_failure_err( - self.tcx.hir().body_owner_def_id(self.body.id()), - self.span.to_span(self.tcx), + self.fcx.tcx.hir().body_owner_def_id(self.body.id()), + self.span.to_span(self.fcx.tcx), p.into(), E0282, false, @@ -795,40 +793,46 @@ impl<'tcx> TypeFolder> for EraseEarlyRegions<'tcx> { impl<'cx, 'tcx> TypeFolder> for Resolver<'cx, 'tcx> { fn interner(&self) -> TyCtxt<'tcx> { - self.tcx + self.fcx.tcx } fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> { - match self.infcx.fully_resolve(t) { + match self.fcx.fully_resolve(t) { + Ok(t) if self.fcx.tcx.trait_solver_next() => { + // We must normalize erasing regions here, since later lints + // expect that types that show up in the typeck are fully + // normalized. + self.fcx.tcx.try_normalize_erasing_regions(self.fcx.param_env, t).unwrap_or(t) + } Ok(t) => { // Do not anonymize late-bound regions // (e.g. keep `for<'a>` named `for<'a>`). // This allows NLL to generate error messages that // refer to the higher-ranked lifetime names written by the user. - EraseEarlyRegions { tcx: self.tcx }.fold_ty(t) + EraseEarlyRegions { tcx: self.fcx.tcx }.fold_ty(t) } Err(_) => { debug!("Resolver::fold_ty: input type `{:?}` not fully resolvable", t); let e = self.report_error(t); self.replaced_with_error = Some(e); - self.interner().ty_error(e) + self.fcx.tcx.ty_error(e) } } } fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> { debug_assert!(!r.is_late_bound(), "Should not be resolving bound region."); - self.tcx.lifetimes.re_erased + self.fcx.tcx.lifetimes.re_erased } fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> { - match self.infcx.fully_resolve(ct) { - Ok(ct) => self.tcx.erase_regions(ct), + match self.fcx.fully_resolve(ct) { + Ok(ct) => self.fcx.tcx.erase_regions(ct), Err(_) => { debug!("Resolver::fold_const: input const `{:?}` not fully resolvable", ct); let e = self.report_error(ct); self.replaced_with_error = Some(e); - self.interner().const_error(ct.ty(), e) + self.fcx.tcx.const_error(ct.ty(), e) } } } From 4cfafb275e8d9049c26ab58831d58254a09b9f61 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Sun, 23 Apr 2023 19:58:24 +0000 Subject: [PATCH 2/2] Structurally normalize in the new solver --- compiler/rustc_hir_analysis/src/autoderef.rs | 76 +++++++++++++------ .../rustc_hir_typeck/src/fn_ctxt/_impl.rs | 29 ++++++- .../rustc_trait_selection/src/traits/mod.rs | 2 + .../src/traits/structural_normalize.rs | 55 ++++++++++++++ .../new-solver/normalize-rcvr-for-inherent.rs | 25 ++++++ .../new-solver/structural-resolve-field.rs | 13 ++++ 6 files changed, 176 insertions(+), 24 deletions(-) create mode 100644 compiler/rustc_trait_selection/src/traits/structural_normalize.rs create mode 100644 tests/ui/traits/new-solver/normalize-rcvr-for-inherent.rs create mode 100644 tests/ui/traits/new-solver/structural-resolve-field.rs diff --git a/compiler/rustc_hir_analysis/src/autoderef.rs b/compiler/rustc_hir_analysis/src/autoderef.rs index 1cf93c86f4f8a..d6d1498d708ed 100644 --- a/compiler/rustc_hir_analysis/src/autoderef.rs +++ b/compiler/rustc_hir_analysis/src/autoderef.rs @@ -1,6 +1,5 @@ use crate::errors::AutoDerefReachedRecursionLimit; use crate::traits::query::evaluate_obligation::InferCtxtExt; -use crate::traits::NormalizeExt; use crate::traits::{self, TraitEngine, TraitEngineExt}; use rustc_infer::infer::InferCtxt; use rustc_middle::ty::TypeVisitableExt; @@ -9,6 +8,7 @@ use rustc_session::Limit; use rustc_span::def_id::LocalDefId; use rustc_span::def_id::LOCAL_CRATE; use rustc_span::Span; +use rustc_trait_selection::traits::StructurallyNormalizeExt; #[derive(Copy, Clone, Debug)] pub enum AutoderefKind { @@ -66,14 +66,27 @@ impl<'a, 'tcx> Iterator for Autoderef<'a, 'tcx> { } // Otherwise, deref if type is derefable: - let (kind, new_ty) = - if let Some(mt) = self.state.cur_ty.builtin_deref(self.include_raw_pointers) { - (AutoderefKind::Builtin, mt.ty) - } else if let Some(ty) = self.overloaded_deref_ty(self.state.cur_ty) { - (AutoderefKind::Overloaded, ty) + let (kind, new_ty) = if let Some(ty::TypeAndMut { ty, .. }) = + self.state.cur_ty.builtin_deref(self.include_raw_pointers) + { + debug_assert_eq!(ty, self.infcx.resolve_vars_if_possible(ty)); + // NOTE: we may still need to normalize the built-in deref in case + // we have some type like `&::Assoc`, since users of + // autoderef expect this type to have been structurally normalized. + if self.infcx.tcx.trait_solver_next() + && let ty::Alias(ty::Projection, _) = ty.kind() + { + let (normalized_ty, obligations) = self.structurally_normalize(ty)?; + self.state.obligations.extend(obligations); + (AutoderefKind::Builtin, normalized_ty) } else { - return None; - }; + (AutoderefKind::Builtin, ty) + } + } else if let Some(ty) = self.overloaded_deref_ty(self.state.cur_ty) { + (AutoderefKind::Overloaded, ty) + } else { + return None; + }; if new_ty.references_error() { return None; @@ -119,14 +132,11 @@ impl<'a, 'tcx> Autoderef<'a, 'tcx> { fn overloaded_deref_ty(&mut self, ty: Ty<'tcx>) -> Option> { debug!("overloaded_deref_ty({:?})", ty); - let tcx = self.infcx.tcx; // let trait_ref = ty::TraitRef::new(tcx, tcx.lang_items().deref_trait()?, [ty]); - let cause = traits::ObligationCause::misc(self.span, self.body_id); - let obligation = traits::Obligation::new( tcx, cause.clone(), @@ -138,26 +148,48 @@ impl<'a, 'tcx> Autoderef<'a, 'tcx> { return None; } - let normalized_ty = self + let (normalized_ty, obligations) = + self.structurally_normalize(tcx.mk_projection(tcx.lang_items().deref_target()?, [ty]))?; + debug!("overloaded_deref_ty({:?}) = ({:?}, {:?})", ty, normalized_ty, obligations); + self.state.obligations.extend(obligations); + + Some(self.infcx.resolve_vars_if_possible(normalized_ty)) + } + + #[instrument(level = "debug", skip(self), ret)] + pub fn structurally_normalize( + &self, + ty: Ty<'tcx>, + ) -> Option<(Ty<'tcx>, Vec>)> { + let tcx = self.infcx.tcx; + let mut fulfill_cx = >::new_in_snapshot(tcx); + + let cause = traits::ObligationCause::misc(self.span, self.body_id); + let normalized_ty = match self .infcx .at(&cause, self.param_env) - .normalize(tcx.mk_projection(tcx.lang_items().deref_target()?, trait_ref.substs)); - let mut fulfillcx = >::new_in_snapshot(tcx); - let normalized_ty = - normalized_ty.into_value_registering_obligations(self.infcx, &mut *fulfillcx); - let errors = fulfillcx.select_where_possible(&self.infcx); + .structurally_normalize(ty, &mut *fulfill_cx) + { + Ok(normalized_ty) => normalized_ty, + Err(errors) => { + // This shouldn't happen, except for evaluate/fulfill mismatches, + // but that's not a reason for an ICE (`predicate_may_hold` is conservative + // by design). + debug!(?errors, "encountered errors while fulfilling"); + return None; + } + }; + + let errors = fulfill_cx.select_where_possible(&self.infcx); if !errors.is_empty() { // This shouldn't happen, except for evaluate/fulfill mismatches, // but that's not a reason for an ICE (`predicate_may_hold` is conservative // by design). - debug!("overloaded_deref_ty: encountered errors {:?} while fulfilling", errors); + debug!(?errors, "encountered errors while fulfilling"); return None; } - let obligations = fulfillcx.pending_obligations(); - debug!("overloaded_deref_ty({:?}) = ({:?}, {:?})", ty, normalized_ty, obligations); - self.state.obligations.extend(obligations); - Some(self.infcx.resolve_vars_if_possible(normalized_ty)) + Some((normalized_ty, fulfill_cx.pending_obligations())) } /// Returns the final type we ended up with, which may be an inference diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs index 039316c74dd4c..9721e3b427d2b 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs @@ -35,7 +35,9 @@ use rustc_span::symbol::{kw, sym, Ident}; use rustc_span::Span; use rustc_target::abi::FieldIdx; use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt as _; -use rustc_trait_selection::traits::{self, NormalizeExt, ObligationCauseCode, ObligationCtxt}; +use rustc_trait_selection::traits::{ + self, NormalizeExt, ObligationCauseCode, ObligationCtxt, StructurallyNormalizeExt, +}; use std::collections::hash_map::Entry; use std::slice; @@ -1460,10 +1462,33 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { } /// Resolves `typ` by a single level if `typ` is a type variable. + /// + /// When the new solver is enabled, this will also attempt to normalize + /// the type if it's a projection (note that it will not deeply normalize + /// projections within the type, just the outermost layer of the type). + /// /// If no resolution is possible, then an error is reported. /// Numeric inference variables may be left unresolved. pub fn structurally_resolved_type(&self, sp: Span, ty: Ty<'tcx>) -> Ty<'tcx> { - let ty = self.resolve_vars_with_obligations(ty); + let mut ty = self.resolve_vars_with_obligations(ty); + + if self.tcx.trait_solver_next() + && let ty::Alias(ty::Projection, _) = ty.kind() + { + match self + .at(&self.misc(sp), self.param_env) + .structurally_normalize(ty, &mut **self.fulfillment_cx.borrow_mut()) + { + Ok(normalized_ty) => { + ty = normalized_ty; + }, + Err(errors) => { + let guar = self.err_ctxt().report_fulfillment_errors(&errors); + return self.tcx.ty_error(guar); + } + } + } + if !ty.is_ty_var() { ty } else { diff --git a/compiler/rustc_trait_selection/src/traits/mod.rs b/compiler/rustc_trait_selection/src/traits/mod.rs index 28dad8592a855..f265230ff772d 100644 --- a/compiler/rustc_trait_selection/src/traits/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/mod.rs @@ -17,6 +17,7 @@ pub mod query; mod select; mod specialize; mod structural_match; +mod structural_normalize; mod util; mod vtable; pub mod wf; @@ -62,6 +63,7 @@ pub use self::specialize::{ pub use self::structural_match::{ search_for_adt_const_param_violation, search_for_structural_match_violation, }; +pub use self::structural_normalize::StructurallyNormalizeExt; pub use self::util::elaborate; pub use self::util::{expand_trait_aliases, TraitAliasExpander}; pub use self::util::{get_vtable_index_of_object_method, impl_item_is_final, upcast_choices}; diff --git a/compiler/rustc_trait_selection/src/traits/structural_normalize.rs b/compiler/rustc_trait_selection/src/traits/structural_normalize.rs new file mode 100644 index 0000000000000..af8dd0da5792a --- /dev/null +++ b/compiler/rustc_trait_selection/src/traits/structural_normalize.rs @@ -0,0 +1,55 @@ +use rustc_infer::infer::at::At; +use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind}; +use rustc_infer::traits::{FulfillmentError, TraitEngine}; +use rustc_middle::ty::{self, Ty}; + +use crate::traits::{query::evaluate_obligation::InferCtxtExt, NormalizeExt, Obligation}; + +pub trait StructurallyNormalizeExt<'tcx> { + fn structurally_normalize( + &self, + ty: Ty<'tcx>, + fulfill_cx: &mut dyn TraitEngine<'tcx>, + ) -> Result, Vec>>; +} + +impl<'tcx> StructurallyNormalizeExt<'tcx> for At<'_, 'tcx> { + fn structurally_normalize( + &self, + mut ty: Ty<'tcx>, + fulfill_cx: &mut dyn TraitEngine<'tcx>, + ) -> Result, Vec>> { + assert!(!ty.is_ty_var(), "should have resolved vars before calling"); + + if self.infcx.tcx.trait_solver_next() { + while let ty::Alias(ty::Projection, projection_ty) = *ty.kind() { + let new_infer_ty = self.infcx.next_ty_var(TypeVariableOrigin { + kind: TypeVariableOriginKind::NormalizeProjectionType, + span: self.cause.span, + }); + let obligation = Obligation::new( + self.infcx.tcx, + self.cause.clone(), + self.param_env, + ty::Binder::dummy(ty::ProjectionPredicate { + projection_ty, + term: new_infer_ty.into(), + }), + ); + if self.infcx.predicate_may_hold(&obligation) { + fulfill_cx.register_predicate_obligation(self.infcx, obligation); + let errors = fulfill_cx.select_where_possible(self.infcx); + if !errors.is_empty() { + return Err(errors); + } + ty = self.infcx.resolve_vars_if_possible(new_infer_ty); + } else { + break; + } + } + Ok(ty) + } else { + Ok(self.normalize(ty).into_value_registering_obligations(self.infcx, fulfill_cx)) + } + } +} diff --git a/tests/ui/traits/new-solver/normalize-rcvr-for-inherent.rs b/tests/ui/traits/new-solver/normalize-rcvr-for-inherent.rs new file mode 100644 index 0000000000000..d70534feb072f --- /dev/null +++ b/tests/ui/traits/new-solver/normalize-rcvr-for-inherent.rs @@ -0,0 +1,25 @@ +// compile-flags: -Ztrait-solver=next +// check-pass + +// Verify that we can assemble inherent impl candidates on a possibly +// unnormalized self type. + +trait Foo { + type Assoc; +} +impl Foo for i32 { + type Assoc = Bar; +} + +struct Bar; +impl Bar { + fn method(&self) {} +} + +fn build(_: T) -> T::Assoc { + todo!() +} + +fn main() { + build(1i32).method(); +} diff --git a/tests/ui/traits/new-solver/structural-resolve-field.rs b/tests/ui/traits/new-solver/structural-resolve-field.rs new file mode 100644 index 0000000000000..01899c9ad645f --- /dev/null +++ b/tests/ui/traits/new-solver/structural-resolve-field.rs @@ -0,0 +1,13 @@ +// compile-flags: -Ztrait-solver=next +// check-pass + +#[derive(Default)] +struct Foo { + x: i32, +} + +fn main() { + let mut xs = <[Foo; 1]>::default(); + xs[0].x = 1; + (&mut xs[0]).x = 2; +}