From 9f05da73f722b00fe54def1ef353de2b84208144 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Mon, 19 Feb 2024 20:36:17 +0000 Subject: [PATCH 1/2] Preserve variance on error in generalizer --- compiler/rustc_infer/src/infer/relate/generalize.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/rustc_infer/src/infer/relate/generalize.rs b/compiler/rustc_infer/src/infer/relate/generalize.rs index 2e1ea19078cdd..1dadd399c1c8a 100644 --- a/compiler/rustc_infer/src/infer/relate/generalize.rs +++ b/compiler/rustc_infer/src/infer/relate/generalize.rs @@ -404,9 +404,9 @@ impl<'tcx> TypeRelation<'tcx> for Generalizer<'_, 'tcx> { debug!(?self.ambient_variance, "new ambient variance"); // Recursive calls to `relate` can overflow the stack. For example a deeper version of // `ui/associated-consts/issue-93775.rs`. - let r = ensure_sufficient_stack(|| self.relate(a, b))?; + let r = ensure_sufficient_stack(|| self.relate(a, b)); self.ambient_variance = old_ambient_variance; - Ok(r) + r } #[instrument(level = "debug", skip(self, t2), ret)] From 3034b2b0c27d022f461015b541abce4ea1989614 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Thu, 22 Feb 2024 17:18:35 +0000 Subject: [PATCH 2/2] Combine sub and eq --- .../src/type_check/relate_tys.rs | 13 +- .../rustc_infer/src/infer/relate/combine.rs | 30 +- .../rustc_infer/src/infer/relate/equate.rs | 198 ----------- compiler/rustc_infer/src/infer/relate/glb.rs | 10 +- compiler/rustc_infer/src/infer/relate/lub.rs | 10 +- compiler/rustc_infer/src/infer/relate/mod.rs | 3 +- compiler/rustc_infer/src/infer/relate/sub.rs | 222 ------------ .../src/infer/relate/type_relating.rs | 317 ++++++++++++++++++ 8 files changed, 348 insertions(+), 455 deletions(-) delete mode 100644 compiler/rustc_infer/src/infer/relate/equate.rs delete mode 100644 compiler/rustc_infer/src/infer/relate/sub.rs create mode 100644 compiler/rustc_infer/src/infer/relate/type_relating.rs diff --git a/compiler/rustc_borrowck/src/type_check/relate_tys.rs b/compiler/rustc_borrowck/src/type_check/relate_tys.rs index 61b803ea38d4b..06555202c6f85 100644 --- a/compiler/rustc_borrowck/src/type_check/relate_tys.rs +++ b/compiler/rustc_borrowck/src/type_check/relate_tys.rs @@ -348,12 +348,15 @@ impl<'bccx, 'tcx> TypeRelation<'tcx> for NllTypeRelating<'_, 'bccx, 'tcx> { debug!(?self.ambient_variance); // In a bivariant context this always succeeds. - let r = - if self.ambient_variance == ty::Variance::Bivariant { a } else { self.relate(a, b)? }; + let r = if self.ambient_variance == ty::Variance::Bivariant { + Ok(a) + } else { + self.relate(a, b) + }; self.ambient_variance = old_ambient_variance; - Ok(r) + r } #[instrument(skip(self), level = "debug")] @@ -576,10 +579,6 @@ impl<'bccx, 'tcx> ObligationEmittingRelation<'tcx> for NllTypeRelating<'_, 'bccx ); } - fn alias_relate_direction(&self) -> ty::AliasRelationDirection { - unreachable!("manually overridden to handle ty::Variance::Contravariant ambient variance") - } - fn register_type_relate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) { self.register_predicates([ty::Binder::dummy(match self.ambient_variance { ty::Variance::Covariant => ty::PredicateKind::AliasRelate( diff --git a/compiler/rustc_infer/src/infer/relate/combine.rs b/compiler/rustc_infer/src/infer/relate/combine.rs index f7690831c2acc..f4528c1d09c5c 100644 --- a/compiler/rustc_infer/src/infer/relate/combine.rs +++ b/compiler/rustc_infer/src/infer/relate/combine.rs @@ -1,4 +1,6 @@ -//! There are four type combiners: [Equate], [Sub], [Lub], and [Glb]. +//! There are four type combiners: [TypeRelating], [Lub], and [Glb], +//! (and `NllTypeRelating` in rustc_borrowck, which is only used for NLL). +//! //! Each implements the trait [TypeRelation] and contains methods for //! combining two instances of various things and yielding a new instance. //! These combiner methods always yield a `Result`. To relate two @@ -22,10 +24,9 @@ //! [TypeRelation::a_is_expected], so when dealing with contravariance //! this should be correctly updated. -use super::equate::Equate; use super::glb::Glb; use super::lub::Lub; -use super::sub::Sub; +use super::type_relating::TypeRelating; use crate::infer::{DefineOpaqueTypes, InferCtxt, TypeTrace}; use crate::traits::{Obligation, PredicateObligations}; use rustc_middle::infer::canonical::OriginalQueryValues; @@ -303,12 +304,12 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> { self.infcx.tcx } - pub fn equate<'a>(&'a mut self, a_is_expected: bool) -> Equate<'a, 'infcx, 'tcx> { - Equate::new(self, a_is_expected) + pub fn equate<'a>(&'a mut self, a_is_expected: bool) -> TypeRelating<'a, 'infcx, 'tcx> { + TypeRelating::new(self, a_is_expected, ty::Invariant) } - pub fn sub<'a>(&'a mut self, a_is_expected: bool) -> Sub<'a, 'infcx, 'tcx> { - Sub::new(self, a_is_expected) + pub fn sub<'a>(&'a mut self, a_is_expected: bool) -> TypeRelating<'a, 'infcx, 'tcx> { + TypeRelating::new(self, a_is_expected, ty::Covariant) } pub fn lub<'a>(&'a mut self, a_is_expected: bool) -> Lub<'a, 'infcx, 'tcx> { @@ -343,19 +344,8 @@ pub trait ObligationEmittingRelation<'tcx>: TypeRelation<'tcx> { /// be used if control over the obligation causes is required. fn register_predicates(&mut self, obligations: impl IntoIterator>); - /// Register an obligation that both types must be related to each other according to - /// the [`ty::AliasRelationDirection`] given by [`ObligationEmittingRelation::alias_relate_direction`] - fn register_type_relate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) { - self.register_predicates([ty::Binder::dummy(ty::PredicateKind::AliasRelate( - a.into(), - b.into(), - self.alias_relate_direction(), - ))]); - } - - /// Relation direction emitted for `AliasRelate` predicates, corresponding to the direction - /// of the relation. - fn alias_relate_direction(&self) -> ty::AliasRelationDirection; + /// Register `AliasRelate` obligation(s) that both types must be related to each other. + fn register_type_relate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>); } fn int_unification_error<'tcx>( diff --git a/compiler/rustc_infer/src/infer/relate/equate.rs b/compiler/rustc_infer/src/infer/relate/equate.rs deleted file mode 100644 index aefa9a5a0d61c..0000000000000 --- a/compiler/rustc_infer/src/infer/relate/equate.rs +++ /dev/null @@ -1,198 +0,0 @@ -use super::combine::{CombineFields, ObligationEmittingRelation}; -use crate::infer::{DefineOpaqueTypes, SubregionOrigin}; -use crate::traits::PredicateObligations; - -use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation}; -use rustc_middle::ty::GenericArgsRef; -use rustc_middle::ty::TyVar; -use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt}; - -use rustc_hir::def_id::DefId; -use rustc_span::Span; - -/// Ensures `a` is made equal to `b`. Returns `a` on success. -pub struct Equate<'combine, 'infcx, 'tcx> { - fields: &'combine mut CombineFields<'infcx, 'tcx>, - a_is_expected: bool, -} - -impl<'combine, 'infcx, 'tcx> Equate<'combine, 'infcx, 'tcx> { - pub fn new( - fields: &'combine mut CombineFields<'infcx, 'tcx>, - a_is_expected: bool, - ) -> Equate<'combine, 'infcx, 'tcx> { - Equate { fields, a_is_expected } - } -} - -impl<'tcx> TypeRelation<'tcx> for Equate<'_, '_, 'tcx> { - fn tag(&self) -> &'static str { - "Equate" - } - - fn tcx(&self) -> TyCtxt<'tcx> { - self.fields.tcx() - } - - fn a_is_expected(&self) -> bool { - self.a_is_expected - } - - fn relate_item_args( - &mut self, - _item_def_id: DefId, - a_arg: GenericArgsRef<'tcx>, - b_arg: GenericArgsRef<'tcx>, - ) -> RelateResult<'tcx, GenericArgsRef<'tcx>> { - // N.B., once we are equating types, we don't care about - // variance, so don't try to lookup the variance here. This - // also avoids some cycles (e.g., #41849) since looking up - // variance requires computing types which can require - // performing trait matching (which then performs equality - // unification). - - relate::relate_args_invariantly(self, a_arg, b_arg) - } - - fn relate_with_variance>( - &mut self, - _: ty::Variance, - _info: ty::VarianceDiagInfo<'tcx>, - a: T, - b: T, - ) -> RelateResult<'tcx, T> { - self.relate(a, b) - } - - #[instrument(skip(self), level = "debug")] - fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> { - if a == b { - return Ok(a); - } - - trace!(a = ?a.kind(), b = ?b.kind()); - - let infcx = self.fields.infcx; - - let a = infcx.inner.borrow_mut().type_variables().replace_if_possible(a); - let b = infcx.inner.borrow_mut().type_variables().replace_if_possible(b); - - match (a.kind(), b.kind()) { - (&ty::Infer(TyVar(a_id)), &ty::Infer(TyVar(b_id))) => { - infcx.inner.borrow_mut().type_variables().equate(a_id, b_id); - } - - (&ty::Infer(TyVar(a_vid)), _) => { - infcx.instantiate_ty_var(self, self.a_is_expected, a_vid, ty::Invariant, b)?; - } - - (_, &ty::Infer(TyVar(b_vid))) => { - infcx.instantiate_ty_var(self, !self.a_is_expected, b_vid, ty::Invariant, a)?; - } - - (&ty::Error(e), _) | (_, &ty::Error(e)) => { - infcx.set_tainted_by_errors(e); - return Ok(Ty::new_error(self.tcx(), e)); - } - - ( - &ty::Alias(ty::Opaque, ty::AliasTy { def_id: a_def_id, .. }), - &ty::Alias(ty::Opaque, ty::AliasTy { def_id: b_def_id, .. }), - ) if a_def_id == b_def_id => { - self.fields.infcx.super_combine_tys(self, a, b)?; - } - (&ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. }), _) - | (_, &ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. })) - if self.fields.define_opaque_types == DefineOpaqueTypes::Yes - && def_id.is_local() - && !self.fields.infcx.next_trait_solver() => - { - self.fields.obligations.extend( - infcx - .handle_opaque_type( - a, - b, - self.a_is_expected(), - &self.fields.trace.cause, - self.param_env(), - )? - .obligations, - ); - } - _ => { - self.fields.infcx.super_combine_tys(self, a, b)?; - } - } - - Ok(a) - } - - fn regions( - &mut self, - a: ty::Region<'tcx>, - b: ty::Region<'tcx>, - ) -> RelateResult<'tcx, ty::Region<'tcx>> { - debug!("{}.regions({:?}, {:?})", self.tag(), a, b); - let origin = SubregionOrigin::Subtype(Box::new(self.fields.trace.clone())); - self.fields - .infcx - .inner - .borrow_mut() - .unwrap_region_constraints() - .make_eqregion(origin, a, b); - Ok(a) - } - - fn consts( - &mut self, - a: ty::Const<'tcx>, - b: ty::Const<'tcx>, - ) -> RelateResult<'tcx, ty::Const<'tcx>> { - self.fields.infcx.super_combine_consts(self, a, b) - } - - fn binders( - &mut self, - a: ty::Binder<'tcx, T>, - b: ty::Binder<'tcx, T>, - ) -> RelateResult<'tcx, ty::Binder<'tcx, T>> - where - T: Relate<'tcx>, - { - // A binder is equal to itself if it's structurally equal to itself - if a == b { - return Ok(a); - } - - if a.skip_binder().has_escaping_bound_vars() || b.skip_binder().has_escaping_bound_vars() { - self.fields.higher_ranked_sub(a, b, self.a_is_expected)?; - self.fields.higher_ranked_sub(b, a, self.a_is_expected)?; - } else { - // Fast path for the common case. - self.relate(a.skip_binder(), b.skip_binder())?; - } - Ok(a) - } -} - -impl<'tcx> ObligationEmittingRelation<'tcx> for Equate<'_, '_, 'tcx> { - fn span(&self) -> Span { - self.fields.trace.span() - } - - fn param_env(&self) -> ty::ParamEnv<'tcx> { - self.fields.param_env - } - - fn register_predicates(&mut self, obligations: impl IntoIterator>) { - self.fields.register_predicates(obligations); - } - - fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) { - self.fields.register_obligations(obligations); - } - - fn alias_relate_direction(&self) -> ty::AliasRelationDirection { - ty::AliasRelationDirection::Equate - } -} diff --git a/compiler/rustc_infer/src/infer/relate/glb.rs b/compiler/rustc_infer/src/infer/relate/glb.rs index 6cf5135459967..d081f86db7f98 100644 --- a/compiler/rustc_infer/src/infer/relate/glb.rs +++ b/compiler/rustc_infer/src/infer/relate/glb.rs @@ -151,8 +151,12 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for Glb<'_, '_, 'tcx> { self.fields.register_obligations(obligations); } - fn alias_relate_direction(&self) -> ty::AliasRelationDirection { - // FIXME(deferred_projection_equality): This isn't right, I think? - ty::AliasRelationDirection::Equate + fn register_type_relate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) { + self.register_predicates([ty::Binder::dummy(ty::PredicateKind::AliasRelate( + a.into(), + b.into(), + // FIXME(deferred_projection_equality): This isn't right, I think? + ty::AliasRelationDirection::Equate, + ))]); } } diff --git a/compiler/rustc_infer/src/infer/relate/lub.rs b/compiler/rustc_infer/src/infer/relate/lub.rs index 5b4f80fd73a02..aa31151d10ec1 100644 --- a/compiler/rustc_infer/src/infer/relate/lub.rs +++ b/compiler/rustc_infer/src/infer/relate/lub.rs @@ -151,8 +151,12 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for Lub<'_, '_, 'tcx> { self.fields.register_obligations(obligations) } - fn alias_relate_direction(&self) -> ty::AliasRelationDirection { - // FIXME(deferred_projection_equality): This isn't right, I think? - ty::AliasRelationDirection::Equate + fn register_type_relate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) { + self.register_predicates([ty::Binder::dummy(ty::PredicateKind::AliasRelate( + a.into(), + b.into(), + // FIXME(deferred_projection_equality): This isn't right, I think? + ty::AliasRelationDirection::Equate, + ))]); } } diff --git a/compiler/rustc_infer/src/infer/relate/mod.rs b/compiler/rustc_infer/src/infer/relate/mod.rs index 1207377e85712..b9e13beedda2f 100644 --- a/compiler/rustc_infer/src/infer/relate/mod.rs +++ b/compiler/rustc_infer/src/infer/relate/mod.rs @@ -2,10 +2,9 @@ //! (except for some relations used for diagnostics and heuristics in the compiler). pub(super) mod combine; -mod equate; mod generalize; mod glb; mod higher_ranked; mod lattice; mod lub; -mod sub; +mod type_relating; diff --git a/compiler/rustc_infer/src/infer/relate/sub.rs b/compiler/rustc_infer/src/infer/relate/sub.rs deleted file mode 100644 index 5bd3a238a67e8..0000000000000 --- a/compiler/rustc_infer/src/infer/relate/sub.rs +++ /dev/null @@ -1,222 +0,0 @@ -use super::combine::CombineFields; -use crate::infer::{DefineOpaqueTypes, ObligationEmittingRelation, SubregionOrigin}; -use crate::traits::{Obligation, PredicateObligations}; - -use rustc_middle::ty::relate::{Cause, Relate, RelateResult, TypeRelation}; -use rustc_middle::ty::visit::TypeVisitableExt; -use rustc_middle::ty::TyVar; -use rustc_middle::ty::{self, Ty, TyCtxt}; -use rustc_span::Span; -use std::mem; - -/// Ensures `a` is made a subtype of `b`. Returns `a` on success. -pub struct Sub<'combine, 'a, 'tcx> { - fields: &'combine mut CombineFields<'a, 'tcx>, - a_is_expected: bool, -} - -impl<'combine, 'infcx, 'tcx> Sub<'combine, 'infcx, 'tcx> { - pub fn new( - f: &'combine mut CombineFields<'infcx, 'tcx>, - a_is_expected: bool, - ) -> Sub<'combine, 'infcx, 'tcx> { - Sub { fields: f, a_is_expected } - } - - fn with_expected_switched R>(&mut self, f: F) -> R { - self.a_is_expected = !self.a_is_expected; - let result = f(self); - self.a_is_expected = !self.a_is_expected; - result - } -} - -impl<'tcx> TypeRelation<'tcx> for Sub<'_, '_, 'tcx> { - fn tag(&self) -> &'static str { - "Sub" - } - - fn tcx(&self) -> TyCtxt<'tcx> { - self.fields.infcx.tcx - } - - fn a_is_expected(&self) -> bool { - self.a_is_expected - } - - fn with_cause(&mut self, cause: Cause, f: F) -> R - where - F: FnOnce(&mut Self) -> R, - { - debug!("sub with_cause={:?}", cause); - let old_cause = mem::replace(&mut self.fields.cause, Some(cause)); - let r = f(self); - debug!("sub old_cause={:?}", old_cause); - self.fields.cause = old_cause; - r - } - - fn relate_with_variance>( - &mut self, - variance: ty::Variance, - _info: ty::VarianceDiagInfo<'tcx>, - a: T, - b: T, - ) -> RelateResult<'tcx, T> { - match variance { - ty::Invariant => self.fields.equate(self.a_is_expected).relate(a, b), - ty::Covariant => self.relate(a, b), - ty::Bivariant => Ok(a), - ty::Contravariant => self.with_expected_switched(|this| this.relate(b, a)), - } - } - - #[instrument(skip(self), level = "debug")] - fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> { - if a == b { - return Ok(a); - } - - let infcx = self.fields.infcx; - let a = infcx.inner.borrow_mut().type_variables().replace_if_possible(a); - let b = infcx.inner.borrow_mut().type_variables().replace_if_possible(b); - - match (a.kind(), b.kind()) { - (&ty::Infer(TyVar(_)), &ty::Infer(TyVar(_))) => { - // Shouldn't have any LBR here, so we can safely put - // this under a binder below without fear of accidental - // capture. - assert!(!a.has_escaping_bound_vars()); - assert!(!b.has_escaping_bound_vars()); - - // can't make progress on `A <: B` if both A and B are - // type variables, so record an obligation. - self.fields.obligations.push(Obligation::new( - self.tcx(), - self.fields.trace.cause.clone(), - self.fields.param_env, - ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate { - a_is_expected: self.a_is_expected, - a, - b, - })), - )); - - Ok(a) - } - (&ty::Infer(TyVar(a_vid)), _) => { - infcx.instantiate_ty_var(self, self.a_is_expected, a_vid, ty::Covariant, b)?; - Ok(a) - } - (_, &ty::Infer(TyVar(b_vid))) => { - infcx.instantiate_ty_var(self, !self.a_is_expected, b_vid, ty::Contravariant, a)?; - Ok(a) - } - - (&ty::Error(e), _) | (_, &ty::Error(e)) => { - infcx.set_tainted_by_errors(e); - Ok(Ty::new_error(self.tcx(), e)) - } - - ( - &ty::Alias(ty::Opaque, ty::AliasTy { def_id: a_def_id, .. }), - &ty::Alias(ty::Opaque, ty::AliasTy { def_id: b_def_id, .. }), - ) if a_def_id == b_def_id => { - self.fields.infcx.super_combine_tys(self, a, b)?; - Ok(a) - } - (&ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. }), _) - | (_, &ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. })) - if self.fields.define_opaque_types == DefineOpaqueTypes::Yes - && def_id.is_local() - && !self.fields.infcx.next_trait_solver() => - { - self.fields.obligations.extend( - infcx - .handle_opaque_type( - a, - b, - self.a_is_expected, - &self.fields.trace.cause, - self.param_env(), - )? - .obligations, - ); - Ok(a) - } - _ => { - self.fields.infcx.super_combine_tys(self, a, b)?; - Ok(a) - } - } - } - - fn regions( - &mut self, - a: ty::Region<'tcx>, - b: ty::Region<'tcx>, - ) -> RelateResult<'tcx, ty::Region<'tcx>> { - debug!("{}.regions({:?}, {:?}) self.cause={:?}", self.tag(), a, b, self.fields.cause); - - // FIXME -- we have more fine-grained information available - // from the "cause" field, we could perhaps give more tailored - // error messages. - let origin = SubregionOrigin::Subtype(Box::new(self.fields.trace.clone())); - // Subtype(&'a u8, &'b u8) => Outlives('a: 'b) => SubRegion('b, 'a) - self.fields - .infcx - .inner - .borrow_mut() - .unwrap_region_constraints() - .make_subregion(origin, b, a); - - Ok(a) - } - - fn consts( - &mut self, - a: ty::Const<'tcx>, - b: ty::Const<'tcx>, - ) -> RelateResult<'tcx, ty::Const<'tcx>> { - self.fields.infcx.super_combine_consts(self, a, b) - } - - fn binders( - &mut self, - a: ty::Binder<'tcx, T>, - b: ty::Binder<'tcx, T>, - ) -> RelateResult<'tcx, ty::Binder<'tcx, T>> - where - T: Relate<'tcx>, - { - // A binder is always a subtype of itself if it's structurally equal to itself - if a == b { - return Ok(a); - } - - self.fields.higher_ranked_sub(a, b, self.a_is_expected)?; - Ok(a) - } -} - -impl<'tcx> ObligationEmittingRelation<'tcx> for Sub<'_, '_, 'tcx> { - fn span(&self) -> Span { - self.fields.trace.span() - } - - fn param_env(&self) -> ty::ParamEnv<'tcx> { - self.fields.param_env - } - - fn register_predicates(&mut self, obligations: impl IntoIterator>) { - self.fields.register_predicates(obligations); - } - - fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) { - self.fields.register_obligations(obligations); - } - - fn alias_relate_direction(&self) -> ty::AliasRelationDirection { - ty::AliasRelationDirection::Subtype - } -} diff --git a/compiler/rustc_infer/src/infer/relate/type_relating.rs b/compiler/rustc_infer/src/infer/relate/type_relating.rs new file mode 100644 index 0000000000000..33a4e310221eb --- /dev/null +++ b/compiler/rustc_infer/src/infer/relate/type_relating.rs @@ -0,0 +1,317 @@ +use super::combine::CombineFields; +use crate::infer::{DefineOpaqueTypes, ObligationEmittingRelation, SubregionOrigin}; +use crate::traits::{Obligation, PredicateObligations}; + +use rustc_middle::ty::relate::{Cause, Relate, RelateResult, TypeRelation}; +use rustc_middle::ty::visit::TypeVisitableExt; +use rustc_middle::ty::TyVar; +use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_span::Span; +use std::mem; + +/// Enforce that `a` is equal to or a subtype of `b`. +pub struct TypeRelating<'combine, 'a, 'tcx> { + fields: &'combine mut CombineFields<'a, 'tcx>, + a_is_expected: bool, + ambient_variance: ty::Variance, +} + +impl<'combine, 'infcx, 'tcx> TypeRelating<'combine, 'infcx, 'tcx> { + pub fn new( + f: &'combine mut CombineFields<'infcx, 'tcx>, + a_is_expected: bool, + ambient_variance: ty::Variance, + ) -> TypeRelating<'combine, 'infcx, 'tcx> { + TypeRelating { fields: f, a_is_expected, ambient_variance } + } +} + +impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> { + fn tag(&self) -> &'static str { + "TypeRelating" + } + + fn tcx(&self) -> TyCtxt<'tcx> { + self.fields.infcx.tcx + } + + fn a_is_expected(&self) -> bool { + self.a_is_expected + } + + fn with_cause(&mut self, cause: Cause, f: F) -> R + where + F: FnOnce(&mut Self) -> R, + { + debug!("sub with_cause={:?}", cause); + let old_cause = mem::replace(&mut self.fields.cause, Some(cause)); + let r = f(self); + debug!("sub old_cause={:?}", old_cause); + self.fields.cause = old_cause; + r + } + + fn relate_with_variance>( + &mut self, + variance: ty::Variance, + _info: ty::VarianceDiagInfo<'tcx>, + a: T, + b: T, + ) -> RelateResult<'tcx, T> { + let old_ambient_variance = self.ambient_variance; + self.ambient_variance = self.ambient_variance.xform(variance); + debug!(?self.ambient_variance, "new ambient variance"); + + let r = if self.ambient_variance == ty::Bivariant { Ok(a) } else { self.relate(a, b) }; + + self.ambient_variance = old_ambient_variance; + r + } + + #[instrument(skip(self), level = "debug")] + fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> { + if a == b { + return Ok(a); + } + + let infcx = self.fields.infcx; + let a = infcx.inner.borrow_mut().type_variables().replace_if_possible(a); + let b = infcx.inner.borrow_mut().type_variables().replace_if_possible(b); + + match (a.kind(), b.kind()) { + (&ty::Infer(TyVar(a_id)), &ty::Infer(TyVar(b_id))) => { + // Shouldn't have any LBR here, so we can safely put + // this under a binder below without fear of accidental + // capture. + assert!(!a.has_escaping_bound_vars()); + assert!(!b.has_escaping_bound_vars()); + + match self.ambient_variance { + ty::Covariant => { + // can't make progress on `A <: B` if both A and B are + // type variables, so record an obligation. + self.fields.obligations.push(Obligation::new( + self.tcx(), + self.fields.trace.cause.clone(), + self.fields.param_env, + ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate { + a_is_expected: self.a_is_expected, + a, + b, + })), + )); + } + ty::Contravariant => { + // can't make progress on `B <: A` if both A and B are + // type variables, so record an obligation. + self.fields.obligations.push(Obligation::new( + self.tcx(), + self.fields.trace.cause.clone(), + self.fields.param_env, + ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate { + a_is_expected: !self.a_is_expected, + a: b, + b: a, + })), + )); + } + ty::Invariant => { + infcx.inner.borrow_mut().type_variables().equate(a_id, b_id); + } + ty::Bivariant => { + unreachable!("Expected bivariance to be handled in relate_with_variance") + } + } + } + + (&ty::Infer(TyVar(a_vid)), _) => { + infcx.instantiate_ty_var( + self, + self.a_is_expected, + a_vid, + self.ambient_variance, + b, + )?; + } + (_, &ty::Infer(TyVar(b_vid))) => { + infcx.instantiate_ty_var( + self, + !self.a_is_expected, + b_vid, + self.ambient_variance.xform(ty::Contravariant), + a, + )?; + } + + (&ty::Error(e), _) | (_, &ty::Error(e)) => { + infcx.set_tainted_by_errors(e); + return Ok(Ty::new_error(self.tcx(), e)); + } + + ( + &ty::Alias(ty::Opaque, ty::AliasTy { def_id: a_def_id, .. }), + &ty::Alias(ty::Opaque, ty::AliasTy { def_id: b_def_id, .. }), + ) if a_def_id == b_def_id => { + self.fields.infcx.super_combine_tys(self, a, b)?; + } + + (&ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. }), _) + | (_, &ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. })) + if self.fields.define_opaque_types == DefineOpaqueTypes::Yes + && def_id.is_local() + && !self.fields.infcx.next_trait_solver() => + { + self.fields.obligations.extend( + infcx + .handle_opaque_type( + a, + b, + self.a_is_expected, + &self.fields.trace.cause, + self.param_env(), + )? + .obligations, + ); + } + + _ => { + self.fields.infcx.super_combine_tys(self, a, b)?; + } + } + + Ok(a) + } + + fn regions( + &mut self, + a: ty::Region<'tcx>, + b: ty::Region<'tcx>, + ) -> RelateResult<'tcx, ty::Region<'tcx>> { + debug!("{}.regions({:?}, {:?}) self.cause={:?}", self.tag(), a, b, self.fields.cause); + + // FIXME -- we have more fine-grained information available + // from the "cause" field, we could perhaps give more tailored + // error messages. + let origin = SubregionOrigin::Subtype(Box::new(self.fields.trace.clone())); + + match self.ambient_variance { + // Subtype(&'a u8, &'b u8) => Outlives('a: 'b) => SubRegion('b, 'a) + ty::Covariant => { + self.fields + .infcx + .inner + .borrow_mut() + .unwrap_region_constraints() + .make_subregion(origin, b, a); + } + // Suptype(&'a u8, &'b u8) => Outlives('b: 'a) => SubRegion('a, 'b) + ty::Contravariant => { + self.fields + .infcx + .inner + .borrow_mut() + .unwrap_region_constraints() + .make_subregion(origin, a, b); + } + ty::Invariant => { + // The order of `make_eqregion` apparently matters. + self.fields + .infcx + .inner + .borrow_mut() + .unwrap_region_constraints() + .make_eqregion(origin, a, b); + } + ty::Bivariant => { + unreachable!("Expected bivariance to be handled in relate_with_variance") + } + } + + Ok(a) + } + + fn consts( + &mut self, + a: ty::Const<'tcx>, + b: ty::Const<'tcx>, + ) -> RelateResult<'tcx, ty::Const<'tcx>> { + self.fields.infcx.super_combine_consts(self, a, b) + } + + fn binders( + &mut self, + a: ty::Binder<'tcx, T>, + b: ty::Binder<'tcx, T>, + ) -> RelateResult<'tcx, ty::Binder<'tcx, T>> + where + T: Relate<'tcx>, + { + if a == b { + // Do nothing + } else if let Some(a) = a.no_bound_vars() + && let Some(b) = b.no_bound_vars() + { + self.relate(a, b)?; + } else { + match self.ambient_variance { + ty::Covariant => { + self.fields.higher_ranked_sub(a, b, self.a_is_expected)?; + } + ty::Contravariant => { + self.fields.higher_ranked_sub(b, a, !self.a_is_expected)?; + } + ty::Invariant => { + self.fields.higher_ranked_sub(a, b, self.a_is_expected)?; + self.fields.higher_ranked_sub(b, a, !self.a_is_expected)?; + } + ty::Bivariant => { + unreachable!("Expected bivariance to be handled in relate_with_variance") + } + } + } + + Ok(a) + } +} + +impl<'tcx> ObligationEmittingRelation<'tcx> for TypeRelating<'_, '_, 'tcx> { + fn span(&self) -> Span { + self.fields.trace.span() + } + + fn param_env(&self) -> ty::ParamEnv<'tcx> { + self.fields.param_env + } + + fn register_predicates(&mut self, obligations: impl IntoIterator>) { + self.fields.register_predicates(obligations); + } + + fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) { + self.fields.register_obligations(obligations); + } + + fn register_type_relate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) { + self.register_predicates([ty::Binder::dummy(match self.ambient_variance { + ty::Variance::Covariant => ty::PredicateKind::AliasRelate( + a.into(), + b.into(), + ty::AliasRelationDirection::Subtype, + ), + // a :> b is b <: a + ty::Variance::Contravariant => ty::PredicateKind::AliasRelate( + b.into(), + a.into(), + ty::AliasRelationDirection::Subtype, + ), + ty::Variance::Invariant => ty::PredicateKind::AliasRelate( + a.into(), + b.into(), + ty::AliasRelationDirection::Equate, + ), + ty::Variance::Bivariant => { + unreachable!("Expected bivariance to be handled in relate_with_variance") + } + })]); + } +}