From c353df09b846fce6d09064504d9ed22314c1c90f Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Thu, 8 Feb 2024 17:19:33 +0000 Subject: [PATCH] Normalize opaques eagerly --- compiler/rustc_hir_typeck/src/check.rs | 14 +--- compiler/rustc_hir_typeck/src/closure.rs | 11 --- .../rustc_infer/src/infer/opaque_types.rs | 57 +------------ compiler/rustc_middle/src/traits/mod.rs | 4 +- .../src/solve/normalize.rs | 5 +- .../rustc_trait_selection/src/traits/mod.rs | 2 +- .../src/traits/project.rs | 80 ++++++++----------- .../src/traits/query/normalize.rs | 10 +-- 8 files changed, 49 insertions(+), 134 deletions(-) diff --git a/compiler/rustc_hir_typeck/src/check.rs b/compiler/rustc_hir_typeck/src/check.rs index aab78465f8c85..9367e07d463ec 100644 --- a/compiler/rustc_hir_typeck/src/check.rs +++ b/compiler/rustc_hir_typeck/src/check.rs @@ -38,15 +38,7 @@ pub(super) fn check_fn<'a, 'tcx>( let tcx = fcx.tcx; let hir = tcx.hir(); - let declared_ret_ty = fn_sig.output(); - - let ret_ty = - fcx.register_infer_ok_obligations(fcx.infcx.replace_opaque_types_with_inference_vars( - declared_ret_ty, - fn_def_id, - decl.output.span(), - fcx.param_env, - )); + let ret_ty = fn_sig.output(); fcx.coroutine_types = coroutine_types; fcx.ret_coercion = Some(RefCell::new(CoerceMany::new(ret_ty))); @@ -109,7 +101,7 @@ pub(super) fn check_fn<'a, 'tcx>( hir::FnRetTy::DefaultReturn(_) => body.value.span, hir::FnRetTy::Return(ty) => ty.span, }; - fcx.require_type_is_sized(declared_ret_ty, return_or_body_span, traits::SizedReturnType); + fcx.require_type_is_sized(ret_ty, return_or_body_span, traits::SizedReturnType); fcx.is_whole_body.set(true); fcx.check_return_expr(body.value, false); @@ -120,7 +112,7 @@ pub(super) fn check_fn<'a, 'tcx>( let coercion = fcx.ret_coercion.take().unwrap().into_inner(); let mut actual_return_ty = coercion.complete(fcx); debug!("actual_return_ty = {:?}", actual_return_ty); - if let ty::Dynamic(..) = declared_ret_ty.kind() { + if let ty::Dynamic(..) = ret_ty.kind() { // We have special-cased the case where the function is declared // `-> dyn Foo` and we don't actually relate it to the // `fcx.ret_coercion`, so just substitute a type variable. diff --git a/compiler/rustc_hir_typeck/src/closure.rs b/compiler/rustc_hir_typeck/src/closure.rs index a985fa201d071..13ec9d2ae0ae6 100644 --- a/compiler/rustc_hir_typeck/src/closure.rs +++ b/compiler/rustc_hir_typeck/src/closure.rs @@ -879,17 +879,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { let output_ty = self.normalize(closure_span, output_ty); - // async fn that have opaque types in their return type need to redo the conversion to inference variables - // as they fetch the still opaque version from the signature. - let InferOk { value: output_ty, obligations } = self - .replace_opaque_types_with_inference_vars( - output_ty, - body_def_id, - closure_span, - self.param_env, - ); - self.register_predicates(obligations); - Some(output_ty) } diff --git a/compiler/rustc_infer/src/infer/opaque_types.rs b/compiler/rustc_infer/src/infer/opaque_types.rs index 9a9d13ed60825..74a67b0c1d4e4 100644 --- a/compiler/rustc_infer/src/infer/opaque_types.rs +++ b/compiler/rustc_infer/src/infer/opaque_types.rs @@ -1,4 +1,3 @@ -use super::type_variable::{TypeVariableOrigin, TypeVariableOriginKind}; use super::{DefineOpaqueTypes, InferResult}; use crate::errors::OpaqueHiddenTypeDiag; use crate::infer::{InferCtxt, InferOk}; @@ -36,60 +35,6 @@ pub struct OpaqueTypeDecl<'tcx> { } impl<'tcx> InferCtxt<'tcx> { - /// This is a backwards compatibility hack to prevent breaking changes from - /// lazy TAIT around RPIT handling. - pub fn replace_opaque_types_with_inference_vars>>( - &self, - value: T, - body_id: LocalDefId, - span: Span, - param_env: ty::ParamEnv<'tcx>, - ) -> InferOk<'tcx, T> { - // We handle opaque types differently in the new solver. - if self.next_trait_solver() { - return InferOk { value, obligations: vec![] }; - } - - if !value.has_opaque_types() { - return InferOk { value, obligations: vec![] }; - } - - let mut obligations = vec![]; - let replace_opaque_type = |def_id: DefId| { - def_id.as_local().is_some_and(|def_id| self.opaque_type_origin(def_id).is_some()) - }; - let value = value.fold_with(&mut BottomUpFolder { - tcx: self.tcx, - lt_op: |lt| lt, - ct_op: |ct| ct, - ty_op: |ty| match *ty.kind() { - ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. }) - if replace_opaque_type(def_id) && !ty.has_escaping_bound_vars() => - { - let def_span = self.tcx.def_span(def_id); - let span = if span.contains(def_span) { def_span } else { span }; - let code = traits::ObligationCauseCode::OpaqueReturnType(None); - let cause = ObligationCause::new(span, body_id, code); - // FIXME(compiler-errors): We probably should add a new TypeVariableOriginKind - // for opaque types, and then use that kind to fix the spans for type errors - // that we see later on. - let ty_var = self.next_ty_var(TypeVariableOrigin { - kind: TypeVariableOriginKind::OpaqueTypeInference(def_id), - span, - }); - obligations.extend( - self.handle_opaque_type(ty, ty_var, true, &cause, param_env) - .unwrap() - .obligations, - ); - ty_var - } - _ => ty, - }, - }); - InferOk { value, obligations } - } - pub fn handle_opaque_type( &self, a: Ty<'tcx>, @@ -515,7 +460,7 @@ impl UseKind { impl<'tcx> InferCtxt<'tcx> { #[instrument(skip(self), level = "debug")] - fn register_hidden_type( + pub fn register_hidden_type( &self, opaque_type_key: OpaqueTypeKey<'tcx>, cause: ObligationCause<'tcx>, diff --git a/compiler/rustc_middle/src/traits/mod.rs b/compiler/rustc_middle/src/traits/mod.rs index 894acf3c2aa5f..b5f754cead0e2 100644 --- a/compiler/rustc_middle/src/traits/mod.rs +++ b/compiler/rustc_middle/src/traits/mod.rs @@ -1004,14 +1004,14 @@ pub enum CodegenObligationError { #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, HashStable, TypeFoldable, TypeVisitable)] pub enum DefiningAnchor { /// Define opaques which are in-scope of the `LocalDefId`. Also, eagerly - /// replace opaque types in `replace_opaque_types_with_inference_vars`. + /// replace opaque types in normalization. Bind(LocalDefId), /// In contexts where we don't currently know what opaques are allowed to be /// defined, such as (old solver) canonical queries, we will simply allow /// opaques to be defined, but "bubble" them up in the canonical response or /// otherwise treat them to be handled later. /// - /// We do not eagerly replace opaque types in `replace_opaque_types_with_inference_vars`, + /// We do not eagerly replace opaque types in normalization, /// which may affect what predicates pass and fail in the old trait solver. Bubble, /// Do not allow any opaques to be defined. This is used to catch type mismatch diff --git a/compiler/rustc_trait_selection/src/solve/normalize.rs b/compiler/rustc_trait_selection/src/solve/normalize.rs index d87cc89954a56..b07702e842105 100644 --- a/compiler/rustc_trait_selection/src/solve/normalize.rs +++ b/compiler/rustc_trait_selection/src/solve/normalize.rs @@ -1,6 +1,6 @@ use crate::traits::error_reporting::TypeErrCtxtExt; use crate::traits::query::evaluate_obligation::InferCtxtExt; -use crate::traits::{needs_normalization, BoundVarReplacer, PlaceholderReplacer}; +use crate::traits::{BoundVarReplacer, PlaceholderReplacer}; use rustc_data_structures::stack::ensure_sufficient_stack; use rustc_infer::infer::at::At; use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind}; @@ -205,10 +205,9 @@ impl<'tcx> FallibleTypeFolder> for NormalizationFolder<'_, 'tcx> { } fn try_fold_const(&mut self, ct: ty::Const<'tcx>) -> Result, Self::Error> { - let reveal = self.at.param_env.reveal(); let infcx = self.at.infcx; debug_assert_eq!(ct, infcx.shallow_resolve(ct)); - if !needs_normalization(&ct, reveal) { + if !ct.has_projections() { return Ok(ct); } diff --git a/compiler/rustc_trait_selection/src/traits/mod.rs b/compiler/rustc_trait_selection/src/traits/mod.rs index a7f6021d57a96..a50196ea12b44 100644 --- a/compiler/rustc_trait_selection/src/traits/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/mod.rs @@ -40,7 +40,7 @@ use rustc_span::Span; use std::fmt::Debug; use std::ops::ControlFlow; -pub(crate) use self::project::{needs_normalization, BoundVarReplacer, PlaceholderReplacer}; +pub(crate) use self::project::{BoundVarReplacer, PlaceholderReplacer}; pub use self::coherence::{add_placeholder_note, orphan_check, overlapping_impls}; pub use self::coherence::{OrphanCheckErr, OverlapResult}; diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs index 955c81eee6be3..5722aabb8b803 100644 --- a/compiler/rustc_trait_selection/src/traits/project.rs +++ b/compiler/rustc_trait_selection/src/traits/project.rs @@ -27,6 +27,7 @@ use rustc_data_structures::stack::ensure_sufficient_stack; use rustc_errors::ErrorGuaranteed; use rustc_hir::def::DefKind; use rustc_hir::lang_items::LangItem; +use rustc_hir::OpaqueTyOrigin; use rustc_infer::infer::at::At; use rustc_infer::infer::resolve::OpportunisticRegionResolver; use rustc_infer::infer::DefineOpaqueTypes; @@ -318,17 +319,6 @@ fn project_and_unify_type<'cx, 'tcx>( }; debug!(?normalized, ?obligations, "project_and_unify_type result"); let actual = obligation.predicate.term; - // For an example where this is necessary see tests/ui/impl-trait/nested-return-type2.rs - // This allows users to omit re-mentioning all bounds on an associated type and just use an - // `impl Trait` for the assoc type to add more bounds. - let InferOk { value: actual, obligations: new } = - selcx.infcx.replace_opaque_types_with_inference_vars( - actual, - obligation.cause.body_id, - obligation.cause.span, - obligation.param_env, - ); - obligations.extend(new); // Need to define opaque types to support nested opaque types like `impl Fn() -> impl Trait` match infcx.at(&obligation.cause, obligation.param_env).eq( @@ -409,25 +399,6 @@ where result } -pub(crate) fn needs_normalization<'tcx, T: TypeVisitable>>( - value: &T, - reveal: Reveal, -) -> bool { - match reveal { - Reveal::UserFacing => value.has_type_flags( - ty::TypeFlags::HAS_TY_PROJECTION - | ty::TypeFlags::HAS_TY_INHERENT - | ty::TypeFlags::HAS_CT_PROJECTION, - ), - Reveal::All => value.has_type_flags( - ty::TypeFlags::HAS_TY_PROJECTION - | ty::TypeFlags::HAS_TY_INHERENT - | ty::TypeFlags::HAS_TY_OPAQUE - | ty::TypeFlags::HAS_CT_PROJECTION, - ), - } -} - struct AssocTypeNormalizer<'a, 'b, 'tcx> { selcx: &'a mut SelectionContext<'b, 'tcx>, param_env: ty::ParamEnv<'tcx>, @@ -488,11 +459,7 @@ impl<'a, 'b, 'tcx> AssocTypeNormalizer<'a, 'b, 'tcx> { "Normalizing {value:?} without wrapping in a `Binder`" ); - if !needs_normalization(&value, self.param_env.reveal()) { - value - } else { - value.fold_with(self) - } + if !value.has_projections() { value } else { value.fold_with(self) } } } @@ -512,7 +479,7 @@ impl<'a, 'b, 'tcx> TypeFolder> for AssocTypeNormalizer<'a, 'b, 'tcx } fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> { - if !needs_normalization(&ty, self.param_env.reveal()) { + if !ty.has_projections() { return ty; } @@ -548,7 +515,36 @@ impl<'a, 'b, 'tcx> TypeFolder> for AssocTypeNormalizer<'a, 'b, 'tcx ty::Opaque => { // Only normalize `impl Trait` outside of type inference, usually in codegen. match self.param_env.reveal() { - Reveal::UserFacing => ty.super_fold_with(self), + Reveal::UserFacing => { + if !data.has_escaping_bound_vars() + && let Some(def_id) = data.def_id.as_local() + && let Some( + OpaqueTyOrigin::TyAlias { in_assoc_ty: true } + | OpaqueTyOrigin::AsyncFn(_) + | OpaqueTyOrigin::FnReturn(_), + ) = self.selcx.infcx.opaque_type_origin(def_id) + { + let infer = self.selcx.infcx.next_ty_var(TypeVariableOrigin { + kind: TypeVariableOriginKind::OpaqueTypeInference(data.def_id), + span: self.cause.span, + }); + let InferOk { value: (), obligations } = self + .selcx + .infcx + .register_hidden_type( + ty::OpaqueTypeKey { def_id, args: data.args }, + self.cause.clone(), + self.param_env, + infer, + true, + ) + .expect("uwu"); + self.obligations.extend(obligations); + self.selcx.infcx.resolve_vars_if_possible(infer) + } else { + ty.super_fold_with(self) + } + } Reveal::All => { let recursion_limit = self.interner().recursion_limit(); @@ -752,9 +748,7 @@ impl<'a, 'b, 'tcx> TypeFolder> for AssocTypeNormalizer<'a, 'b, 'tcx #[instrument(skip(self), level = "debug")] fn fold_const(&mut self, constant: ty::Const<'tcx>) -> ty::Const<'tcx> { let tcx = self.selcx.tcx(); - if tcx.features().generic_const_exprs - || !needs_normalization(&constant, self.param_env.reveal()) - { + if tcx.features().generic_const_exprs || !constant.has_projections() { constant } else { let constant = constant.super_fold_with(self); @@ -770,11 +764,7 @@ impl<'a, 'b, 'tcx> TypeFolder> for AssocTypeNormalizer<'a, 'b, 'tcx #[inline] fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> { - if p.allow_normalization() && needs_normalization(&p, self.param_env.reveal()) { - p.super_fold_with(self) - } else { - p - } + if p.allow_normalization() && p.has_projections() { p.super_fold_with(self) } else { p } } } diff --git a/compiler/rustc_trait_selection/src/traits/query/normalize.rs b/compiler/rustc_trait_selection/src/traits/query/normalize.rs index 0b73fefd2da9d..dd01710e83cae 100644 --- a/compiler/rustc_trait_selection/src/traits/query/normalize.rs +++ b/compiler/rustc_trait_selection/src/traits/query/normalize.rs @@ -6,7 +6,7 @@ use crate::infer::at::At; use crate::infer::canonical::OriginalQueryValues; use crate::infer::{InferCtxt, InferOk}; use crate::traits::error_reporting::TypeErrCtxtExt; -use crate::traits::project::{needs_normalization, BoundVarReplacer, PlaceholderReplacer}; +use crate::traits::project::{BoundVarReplacer, PlaceholderReplacer}; use crate::traits::{ObligationCause, PredicateObligation, Reveal}; use rustc_data_structures::sso::SsoHashMap; use rustc_data_structures::stack::ensure_sufficient_stack; @@ -89,7 +89,7 @@ impl<'cx, 'tcx> QueryNormalizeExt<'tcx> for At<'cx, 'tcx> { } } - if !needs_normalization(&value, self.param_env.reveal()) { + if !value.has_projections() { return Ok(Normalized { value, obligations: vec![] }); } @@ -198,7 +198,7 @@ impl<'cx, 'tcx> FallibleTypeFolder> for QueryNormalizer<'cx, 'tcx> #[instrument(level = "debug", skip(self))] fn try_fold_ty(&mut self, ty: Ty<'tcx>) -> Result, Self::Error> { - if !needs_normalization(&ty, self.param_env.reveal()) { + if !ty.has_projections() { return Ok(ty); } @@ -335,7 +335,7 @@ impl<'cx, 'tcx> FallibleTypeFolder> for QueryNormalizer<'cx, 'tcx> &mut self, constant: ty::Const<'tcx>, ) -> Result, Self::Error> { - if !needs_normalization(&constant, self.param_env.reveal()) { + if !constant.has_projections() { return Ok(constant); } @@ -354,7 +354,7 @@ impl<'cx, 'tcx> FallibleTypeFolder> for QueryNormalizer<'cx, 'tcx> &mut self, p: ty::Predicate<'tcx>, ) -> Result, Self::Error> { - if p.allow_normalization() && needs_normalization(&p, self.param_env.reveal()) { + if p.allow_normalization() && p.has_projections() { p.try_super_fold_with(self) } else { Ok(p)