From 9c4a00998e44cc4c1b8c4ffd056e65f4103671bb Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Tue, 20 Feb 2024 01:15:00 +0000 Subject: [PATCH] Combine Eq and Sub --- .../rustc_infer/src/infer/relate/combine.rs | 12 +- .../rustc_infer/src/infer/relate/equate.rs | 198 ------------------ compiler/rustc_infer/src/infer/relate/mod.rs | 4 +- .../infer/relate/{sub.rs => type_relating.rs} | 147 ++++++++----- 4 files changed, 104 insertions(+), 257 deletions(-) delete mode 100644 compiler/rustc_infer/src/infer/relate/equate.rs rename compiler/rustc_infer/src/infer/relate/{sub.rs => type_relating.rs} (54%) diff --git a/compiler/rustc_infer/src/infer/relate/combine.rs b/compiler/rustc_infer/src/infer/relate/combine.rs index d528def322b6a..ce1ec05a87453 100644 --- a/compiler/rustc_infer/src/infer/relate/combine.rs +++ b/compiler/rustc_infer/src/infer/relate/combine.rs @@ -22,10 +22,9 @@ //! [TypeRelation::a_is_expected], so when dealing with contravariance //! this should be correctly updated. -use super::equate::Equate; use super::glb::Glb; use super::lub::Lub; -use super::sub::Sub; +use super::type_relating::TypeRelating; use crate::infer::{DefineOpaqueTypes, InferCtxt, TypeTrace}; use crate::traits::{Obligation, PredicateObligations}; use rustc_middle::infer::canonical::OriginalQueryValues; @@ -34,7 +33,6 @@ use rustc_middle::ty::error::{ExpectedFound, TypeError}; use rustc_middle::ty::relate::{RelateResult, TypeRelation}; use rustc_middle::ty::{self, InferConst, ToPredicate, Ty, TyCtxt, TypeVisitableExt}; use rustc_middle::ty::{IntType, UintType}; -use rustc_middle::ty::{RelationDirection, TyVar}; use rustc_span::Span; #[derive(Clone)] @@ -304,12 +302,12 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> { self.infcx.tcx } - pub fn equate<'a>(&'a mut self, a_is_expected: bool) -> Equate<'a, 'infcx, 'tcx> { - Equate::new(self, a_is_expected) + pub fn equate<'a>(&'a mut self, a_is_expected: bool) -> TypeRelating<'a, 'infcx, 'tcx> { + TypeRelating::new(self, a_is_expected, ty::RelationDirection::Equate) } - pub fn sub<'a>(&'a mut self, a_is_expected: bool) -> Sub<'a, 'infcx, 'tcx> { - Sub::new(self, a_is_expected) + pub fn sub<'a>(&'a mut self, a_is_expected: bool) -> TypeRelating<'a, 'infcx, 'tcx> { + TypeRelating::new(self, a_is_expected, ty::RelationDirection::Subtype) } pub fn lub<'a>(&'a mut self, a_is_expected: bool) -> Lub<'a, 'infcx, 'tcx> { diff --git a/compiler/rustc_infer/src/infer/relate/equate.rs b/compiler/rustc_infer/src/infer/relate/equate.rs deleted file mode 100644 index 482e866c2c686..0000000000000 --- a/compiler/rustc_infer/src/infer/relate/equate.rs +++ /dev/null @@ -1,198 +0,0 @@ -use super::combine::{CombineFields, ObligationEmittingRelation}; -use crate::infer::{DefineOpaqueTypes, SubregionOrigin}; -use crate::traits::PredicateObligations; - -use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation}; -use rustc_middle::ty::GenericArgsRef; -use rustc_middle::ty::TyVar; -use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt}; - -use rustc_hir::def_id::DefId; -use rustc_span::Span; - -/// Ensures `a` is made equal to `b`. Returns `a` on success. -pub struct Equate<'combine, 'infcx, 'tcx> { - fields: &'combine mut CombineFields<'infcx, 'tcx>, - a_is_expected: bool, -} - -impl<'combine, 'infcx, 'tcx> Equate<'combine, 'infcx, 'tcx> { - pub fn new( - fields: &'combine mut CombineFields<'infcx, 'tcx>, - a_is_expected: bool, - ) -> Equate<'combine, 'infcx, 'tcx> { - Equate { fields, a_is_expected } - } -} - -impl<'tcx> TypeRelation<'tcx> for Equate<'_, '_, 'tcx> { - fn tag(&self) -> &'static str { - "Equate" - } - - fn tcx(&self) -> TyCtxt<'tcx> { - self.fields.tcx() - } - - fn a_is_expected(&self) -> bool { - self.a_is_expected - } - - fn relate_item_args( - &mut self, - _item_def_id: DefId, - a_arg: GenericArgsRef<'tcx>, - b_arg: GenericArgsRef<'tcx>, - ) -> RelateResult<'tcx, GenericArgsRef<'tcx>> { - // N.B., once we are equating types, we don't care about - // variance, so don't try to lookup the variance here. This - // also avoids some cycles (e.g., #41849) since looking up - // variance requires computing types which can require - // performing trait matching (which then performs equality - // unification). - - relate::relate_args_invariantly(self, a_arg, b_arg) - } - - fn relate_with_variance>( - &mut self, - _: ty::Variance, - _info: ty::VarianceDiagInfo<'tcx>, - a: T, - b: T, - ) -> RelateResult<'tcx, T> { - self.relate(a, b) - } - - #[instrument(skip(self), level = "debug")] - fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> { - if a == b { - return Ok(a); - } - - trace!(a = ?a.kind(), b = ?b.kind()); - - let infcx = self.fields.infcx; - - let a = infcx.inner.borrow_mut().type_variables().replace_if_possible(a); - let b = infcx.inner.borrow_mut().type_variables().replace_if_possible(b); - - match (a.kind(), b.kind()) { - (&ty::Infer(TyVar(a_id)), &ty::Infer(TyVar(b_id))) => { - infcx.inner.borrow_mut().type_variables().equate(a_id, b_id); - } - - (&ty::Infer(TyVar(a_vid)), _) => { - infcx.instantiate_ty_var(self, self.a_is_expected, a_vid, ty::Invariant, b)?; - } - - (_, &ty::Infer(TyVar(b_vid))) => { - infcx.instantiate_ty_var(self, !self.a_is_expected, b_vid, ty::Invariant, a)?; - } - - (&ty::Error(e), _) | (_, &ty::Error(e)) => { - infcx.set_tainted_by_errors(e); - return Ok(Ty::new_error(self.tcx(), e)); - } - - ( - &ty::Alias(ty::Opaque, ty::AliasTy { def_id: a_def_id, .. }), - &ty::Alias(ty::Opaque, ty::AliasTy { def_id: b_def_id, .. }), - ) if a_def_id == b_def_id => { - self.fields.infcx.super_combine_tys(self, a, b)?; - } - (&ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. }), _) - | (_, &ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. })) - if self.fields.define_opaque_types == DefineOpaqueTypes::Yes - && def_id.is_local() - && !self.fields.infcx.next_trait_solver() => - { - self.fields.obligations.extend( - infcx - .handle_opaque_type( - a, - b, - self.a_is_expected(), - &self.fields.trace.cause, - self.param_env(), - )? - .obligations, - ); - } - _ => { - self.fields.infcx.super_combine_tys(self, a, b)?; - } - } - - Ok(a) - } - - fn regions( - &mut self, - a: ty::Region<'tcx>, - b: ty::Region<'tcx>, - ) -> RelateResult<'tcx, ty::Region<'tcx>> { - debug!("{}.regions({:?}, {:?})", self.tag(), a, b); - let origin = SubregionOrigin::Subtype(Box::new(self.fields.trace.clone())); - self.fields - .infcx - .inner - .borrow_mut() - .unwrap_region_constraints() - .make_eqregion(origin, a, b); - Ok(a) - } - - fn consts( - &mut self, - a: ty::Const<'tcx>, - b: ty::Const<'tcx>, - ) -> RelateResult<'tcx, ty::Const<'tcx>> { - self.fields.infcx.super_combine_consts(self, a, b) - } - - fn binders( - &mut self, - a: ty::Binder<'tcx, T>, - b: ty::Binder<'tcx, T>, - ) -> RelateResult<'tcx, ty::Binder<'tcx, T>> - where - T: Relate<'tcx>, - { - // A binder is equal to itself if it's structurally equal to itself - if a == b { - return Ok(a); - } - - if a.skip_binder().has_escaping_bound_vars() || b.skip_binder().has_escaping_bound_vars() { - self.fields.higher_ranked_sub(a, b, self.a_is_expected)?; - self.fields.higher_ranked_sub(b, a, self.a_is_expected)?; - } else { - // Fast path for the common case. - self.relate(a.skip_binder(), b.skip_binder())?; - } - Ok(a) - } -} - -impl<'tcx> ObligationEmittingRelation<'tcx> for Equate<'_, '_, 'tcx> { - fn span(&self) -> Span { - self.fields.trace.span() - } - - fn param_env(&self) -> ty::ParamEnv<'tcx> { - self.fields.param_env - } - - fn register_predicates(&mut self, obligations: impl IntoIterator>) { - self.fields.register_predicates(obligations); - } - - fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) { - self.fields.register_obligations(obligations); - } - - fn alias_relate_direction(&self) -> ty::RelationDirection { - ty::RelationDirection::Equate - } -} diff --git a/compiler/rustc_infer/src/infer/relate/mod.rs b/compiler/rustc_infer/src/infer/relate/mod.rs index 1207377e85712..7fad4f24f62fa 100644 --- a/compiler/rustc_infer/src/infer/relate/mod.rs +++ b/compiler/rustc_infer/src/infer/relate/mod.rs @@ -2,10 +2,10 @@ //! (except for some relations used for diagnostics and heuristics in the compiler). pub(super) mod combine; -mod equate; mod generalize; mod glb; mod higher_ranked; mod lattice; mod lub; -mod sub; +pub mod nll; +mod type_relating; diff --git a/compiler/rustc_infer/src/infer/relate/sub.rs b/compiler/rustc_infer/src/infer/relate/type_relating.rs similarity index 54% rename from compiler/rustc_infer/src/infer/relate/sub.rs rename to compiler/rustc_infer/src/infer/relate/type_relating.rs index b30192e94d9d3..e7924cfb2fd39 100644 --- a/compiler/rustc_infer/src/infer/relate/sub.rs +++ b/compiler/rustc_infer/src/infer/relate/type_relating.rs @@ -9,18 +9,20 @@ use rustc_middle::ty::{self, Ty, TyCtxt}; use rustc_span::Span; use std::mem; -/// Ensures `a` is made a subtype of `b`. Returns `a` on success. -pub struct Sub<'combine, 'a, 'tcx> { +/// Enforce that `a` is equal to or a subtype of `b`. Returns `a` on success. +pub struct TypeRelating<'combine, 'a, 'tcx> { fields: &'combine mut CombineFields<'a, 'tcx>, a_is_expected: bool, + relation_direction: ty::RelationDirection, } -impl<'combine, 'infcx, 'tcx> Sub<'combine, 'infcx, 'tcx> { +impl<'combine, 'infcx, 'tcx> TypeRelating<'combine, 'infcx, 'tcx> { pub fn new( f: &'combine mut CombineFields<'infcx, 'tcx>, a_is_expected: bool, - ) -> Sub<'combine, 'infcx, 'tcx> { - Sub { fields: f, a_is_expected } + relation_direction: ty::RelationDirection, + ) -> TypeRelating<'combine, 'infcx, 'tcx> { + TypeRelating { fields: f, a_is_expected, relation_direction } } fn with_expected_switched R>(&mut self, f: F) -> R { @@ -31,7 +33,7 @@ impl<'combine, 'infcx, 'tcx> Sub<'combine, 'infcx, 'tcx> { } } -impl<'tcx> TypeRelation<'tcx> for Sub<'_, '_, 'tcx> { +impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> { fn tag(&self) -> &'static str { "Sub" } @@ -63,11 +65,14 @@ impl<'tcx> TypeRelation<'tcx> for Sub<'_, '_, 'tcx> { a: T, b: T, ) -> RelateResult<'tcx, T> { - match variance { - ty::Invariant => self.fields.equate(self.a_is_expected).relate(a, b), - ty::Covariant => self.relate(a, b), - ty::Bivariant => Ok(a), - ty::Contravariant => self.with_expected_switched(|this| this.relate(b, a)), + match self.relation_direction { + ty::RelationDirection::Subtype => match variance { + ty::Invariant => self.fields.equate(self.a_is_expected).relate(a, b), + ty::Covariant => self.relate(a, b), + ty::Bivariant => Ok(a), + ty::Contravariant => self.with_expected_switched(|this| this.relate(b, a)), + }, + ty::RelationDirection::Equate => self.relate(a, b), } } @@ -82,40 +87,59 @@ impl<'tcx> TypeRelation<'tcx> for Sub<'_, '_, 'tcx> { let b = infcx.inner.borrow_mut().type_variables().replace_if_possible(b); match (a.kind(), b.kind()) { - (&ty::Infer(TyVar(_)), &ty::Infer(TyVar(_))) => { + (&ty::Infer(TyVar(a_id)), &ty::Infer(TyVar(b_id))) => { // Shouldn't have any LBR here, so we can safely put // this under a binder below without fear of accidental // capture. assert!(!a.has_escaping_bound_vars()); assert!(!b.has_escaping_bound_vars()); - // can't make progress on `A <: B` if both A and B are - // type variables, so record an obligation. - self.fields.obligations.push(Obligation::new( - self.tcx(), - self.fields.trace.cause.clone(), - self.fields.param_env, - ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate { - a_is_expected: self.a_is_expected, - a, - b, - })), - )); - - Ok(a) - } - (&ty::Infer(TyVar(a_vid)), _) => { - infcx.instantiate_ty_var(self, self.a_is_expected, a_vid, ty::Covariant, b)?; - Ok(a) - } - (_, &ty::Infer(TyVar(b_vid))) => { - infcx.instantiate_ty_var(self, !self.a_is_expected, b_vid, ty::Contravariant, a)?; - Ok(a) + match self.relation_direction { + ty::RelationDirection::Subtype => { + // can't make progress on `A <: B` if both A and B are + // type variables, so record an obligation. + self.fields.obligations.push(Obligation::new( + self.tcx(), + self.fields.trace.cause.clone(), + self.fields.param_env, + ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate { + a_is_expected: self.a_is_expected, + a, + b, + })), + )); + } + ty::RelationDirection::Equate => { + infcx.inner.borrow_mut().type_variables().equate(a_id, b_id); + } + } } + (&ty::Infer(TyVar(a_vid)), _) => match self.relation_direction { + ty::RelationDirection::Subtype => { + infcx.instantiate_ty_var(self, self.a_is_expected, a_vid, ty::Covariant, b)?; + } + ty::RelationDirection::Equate => { + infcx.instantiate_ty_var(self, self.a_is_expected, a_vid, ty::Invariant, b)?; + } + }, + (_, &ty::Infer(TyVar(b_vid))) => match self.relation_direction { + ty::RelationDirection::Subtype => { + infcx.instantiate_ty_var( + self, + !self.a_is_expected, + b_vid, + ty::Contravariant, + a, + )?; + } + ty::RelationDirection::Equate => { + infcx.instantiate_ty_var(self, !self.a_is_expected, b_vid, ty::Invariant, a)?; + } + }, (&ty::Error(e), _) | (_, &ty::Error(e)) => { infcx.set_tainted_by_errors(e); - Ok(Ty::new_error(self.tcx(), e)) + return Ok(Ty::new_error(self.tcx(), e)); } ( @@ -123,8 +147,8 @@ impl<'tcx> TypeRelation<'tcx> for Sub<'_, '_, 'tcx> { &ty::Alias(ty::Opaque, ty::AliasTy { def_id: b_def_id, .. }), ) if a_def_id == b_def_id => { self.fields.infcx.super_combine_tys(self, a, b)?; - Ok(a) } + (&ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. }), _) | (_, &ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. })) if self.fields.define_opaque_types == DefineOpaqueTypes::Yes @@ -142,13 +166,14 @@ impl<'tcx> TypeRelation<'tcx> for Sub<'_, '_, 'tcx> { )? .obligations, ); - Ok(a) } + _ => { self.fields.infcx.super_combine_tys(self, a, b)?; - Ok(a) } } + + Ok(a) } fn regions( @@ -162,13 +187,27 @@ impl<'tcx> TypeRelation<'tcx> for Sub<'_, '_, 'tcx> { // from the "cause" field, we could perhaps give more tailored // error messages. let origin = SubregionOrigin::Subtype(Box::new(self.fields.trace.clone())); - // Subtype(&'a u8, &'b u8) => Outlives('a: 'b) => SubRegion('b, 'a) - self.fields - .infcx - .inner - .borrow_mut() - .unwrap_region_constraints() - .make_subregion(origin, b, a); + + match self.relation_direction { + // Subtype(&'a u8, &'b u8) => Outlives('a: 'b) => SubRegion('b, 'a) + ty::RelationDirection::Subtype => { + self.fields + .infcx + .inner + .borrow_mut() + .unwrap_region_constraints() + .make_subregion(origin, b, a); + } + ty::RelationDirection::Equate => { + // The order of `make_eqregion` apparently matters. + self.fields + .infcx + .inner + .borrow_mut() + .unwrap_region_constraints() + .make_eqregion(origin, a, b); + } + } Ok(a) } @@ -189,17 +228,25 @@ impl<'tcx> TypeRelation<'tcx> for Sub<'_, '_, 'tcx> { where T: Relate<'tcx>, { - // A binder is always a subtype of itself if it's structurally equal to itself if a == b { - return Ok(a); + // Do nothing + } else { + match self.relation_direction { + ty::RelationDirection::Subtype => { + self.fields.higher_ranked_sub(a, b, self.a_is_expected)?; + } + ty::RelationDirection::Equate => { + self.fields.higher_ranked_sub(a, b, self.a_is_expected)?; + self.fields.higher_ranked_sub(b, a, !self.a_is_expected)?; + } + } } - self.fields.higher_ranked_sub(a, b, self.a_is_expected)?; Ok(a) } } -impl<'tcx> ObligationEmittingRelation<'tcx> for Sub<'_, '_, 'tcx> { +impl<'tcx> ObligationEmittingRelation<'tcx> for TypeRelating<'_, '_, 'tcx> { fn span(&self) -> Span { self.fields.trace.span() } @@ -217,6 +264,6 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for Sub<'_, '_, 'tcx> { } fn alias_relate_direction(&self) -> ty::RelationDirection { - ty::RelationDirection::Subtype + self.relation_direction } }