Skip to content

Commit

Permalink
simplify base_struct function
Browse files Browse the repository at this point in the history
  • Loading branch information
SparrowLii committed Nov 11, 2022
1 parent b63aea5 commit a04088b
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 146 deletions.
8 changes: 7 additions & 1 deletion compiler/rustc_hir_typeck/src/demand.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use crate::FnCtxt;
use rustc_infer::infer::base_struct::base_struct;
use rustc_infer::infer::InferOk;
use rustc_middle::middle::stability::EvalResult;
use rustc_trait_selection::infer::InferCtxtExt as _;
use rustc_trait_selection::traits::ObligationCause;
use rustc_ast::util::parser::PREC_POSTFIX;
use rustc_errors::{Applicability, Diagnostic, DiagnosticBuilder, ErrorGuaranteed};
use rustc_hir as hir;
Expand Down Expand Up @@ -181,7 +186,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
expected: Ty<'tcx>,
actual: Ty<'tcx>,
) -> Option<DiagnosticBuilder<'tcx, ErrorGuaranteed>> {
match self.at(cause, self.param_env).base_struct(expected, actual) {
let at = self.at(cause, self.param_env);
match base_struct(at, expected, actual) {
Ok(InferOk { obligations, value: () }) => {
self.register_predicates(obligations);
None
Expand Down
26 changes: 0 additions & 26 deletions compiler/rustc_infer/src/infer/at.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,17 +218,6 @@ impl<'a, 'tcx> At<'a, 'tcx> {
self.trace(expected, actual).glb(expected, actual)
}

pub fn base_struct<T>(
self,
expected: T,
actual: T,
) -> InferResult<'tcx, ()>
where
T: ToTrace<'tcx>,
{
self.trace(expected, actual).base_struct(expected, actual)
}

/// Sets the "trace" values that will be used for
/// error-reporting, but doesn't actually perform any operation
/// yet (this is useful when you want to set the trace using
Expand Down Expand Up @@ -270,21 +259,6 @@ impl<'a, 'tcx> Trace<'a, 'tcx> {
})
}

#[instrument(skip(self), level = "debug")]
pub fn base_struct<T>(self, a: T, b: T) -> InferResult<'tcx, ()>
where
T: Relate<'tcx>,
{
let Trace { at, trace, a_is_expected } = self;
at.infcx.commit_if_ok(|_| {
let mut fields = at.infcx.combine_fields(trace, at.param_env, at.define_opaque_types);
fields
.base_struct(a_is_expected)
.relate(a, b)
.map(move |_| InferOk { value: (), obligations: fields.obligations })
})
}

/// Makes `a == b`; the expectation is set by the call to
/// `trace()`.
#[instrument(skip(self), level = "debug")]
Expand Down
151 changes: 38 additions & 113 deletions compiler/rustc_infer/src/infer/base_struct.rs
Original file line number Diff line number Diff line change
@@ -1,129 +1,54 @@
use std::{iter, mem};
use crate::infer::at::{At, ToTrace};
use crate::infer::sub::Sub;
use crate::infer::{InferOk, InferResult};
use rustc_middle::ty;
use rustc_middle::ty::relate::{Cause, Relate, relate_generic_arg, RelateResult, TypeRelation};
use rustc_middle::ty::{Subst, SubstsRef, Ty, TyCtxt};
use rustc_middle::ty::relate::{relate_generic_arg, RelateResult, TypeRelation};
use rustc_middle::ty::{GenericArg, SubstsRef, Ty};
use rustc_span::def_id::DefId;
use std::iter;

pub struct BaseStruct<'combine, 'infcx, 'tcx> {
sub: Sub<'combine, 'infcx, 'tcx>,
}

impl<'combine, 'infcx, 'tcx> BaseStruct<'combine, 'infcx, 'tcx> {
pub fn new(
sub: Sub<'combine, 'infcx, 'tcx>,
) -> Self {
BaseStruct { sub }
}
pub fn base_struct<'a, 'tcx>(at: At<'a, 'tcx>, a: Ty<'tcx>, b: Ty<'tcx>) -> InferResult<'tcx, ()> {
let trace = ToTrace::to_trace(at.infcx.tcx, at.cause, true, a, b);
at.infcx.commit_if_ok(|_| {
let mut fields = at.infcx.combine_fields(trace, at.param_env, at.define_opaque_types);
let mut sub = Sub::new(&mut fields, true);
base_struct_tys(&mut sub, a, b)
.map(move |_| InferOk { value: (), obligations: fields.obligations })
})
}

impl<'tcx> TypeRelation<'tcx> for BaseStruct<'_, '_, 'tcx> {
fn tag(&self) -> &'static str {
"BaseStruct"
}

#[inline(always)]
fn tcx(&self) -> TyCtxt<'tcx> {
self.sub.tcx()
}

#[inline(always)]
fn param_env(&self) -> ty::ParamEnv<'tcx> {
self.sub.param_env()
}

#[inline(always)]
fn a_is_expected(&self) -> bool {
self.sub.a_is_expected()
}

#[inline(always)]
fn with_cause<F, R>(&mut self, cause: Cause, f: F) -> R
where
F: FnOnce(&mut Self) -> R,
{
let old_cause = mem::replace(&mut self.sub.fields.cause, Some(cause));
let r = f(self);
self.sub.fields.cause = old_cause;
r
pub fn base_struct_tys<'tcx>(sub: &mut Sub<'_, '_, 'tcx>, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, ()> {
match (a.kind(), b.kind()) {
(&ty::Adt(a_def, a_substs), &ty::Adt(b_def, b_substs)) if a_def == b_def => {
base_struct_substs(sub, a_def.did(), a_substs, b_substs)?;
Ok(())
}
_ => bug!("not adt ty: {:?} and {:?}", a, b),
}
}

fn relate_item_substs(
&mut self,
item_def_id: DefId,
a_subst: SubstsRef<'tcx>,
b_subst: SubstsRef<'tcx>,
) -> RelateResult<'tcx, SubstsRef<'tcx>> {
debug!(
fn base_struct_substs<'tcx>(
sub: &mut Sub<'_, '_, 'tcx>,
item_def_id: DefId,
a_subst: SubstsRef<'tcx>,
b_subst: SubstsRef<'tcx>,
) -> RelateResult<'tcx, ()> {
debug!(
"relate_item_substs(item_def_id={:?}, a_subst={:?}, b_subst={:?})",
item_def_id, a_subst, b_subst
);

let tcx = self.tcx();
let variances = tcx.variances_of(item_def_id);
let tcx = sub.tcx();
let variances = tcx.variances_of(item_def_id);

let mut cached_ty = None;
let params = iter::zip(a_subst, b_subst).enumerate().map(|(i, (a, b))| {
let cached_ty =
*cached_ty.get_or_insert_with(|| tcx.bound_type_of(item_def_id).subst(tcx, a_subst));
relate_generic_arg(&mut self.sub, variances, cached_ty, a, b, i).or_else(|_| {
Ok(b)
})
});
let mut cached_ty = None;
iter::zip(a_subst, b_subst).enumerate().for_each(|(i, (a, b))| {
let cached_ty = *cached_ty
.get_or_insert_with(|| tcx.bound_type_of(item_def_id).subst(tcx, a_subst));
let _arg: RelateResult<'tcx, GenericArg<'tcx>> =
relate_generic_arg(sub, variances, cached_ty, a, b, i).or_else(|_| Ok(b));
});

tcx.mk_substs(params)
}

#[inline(always)]
fn relate_with_variance<T: Relate<'tcx>>(
&mut self,
variance: ty::Variance,
info: ty::VarianceDiagInfo<'tcx>,
a: T,
b: T,
) -> RelateResult<'tcx, T> {
self.sub.relate_with_variance(variance, info, a, b)
}

#[inline(always)]
#[instrument(skip(self), level = "debug")]
fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
match (a.kind(), b.kind()) {
(&ty::Adt(a_def, a_substs), &ty::Adt(b_def, b_substs)) if a_def == b_def => {
let substs = self.relate_item_substs(a_def.did(), a_substs, b_substs)?;
Ok(self.tcx().mk_adt(a_def, substs))
}
_ => bug!("not adt ty: {:?} and {:?}", a, b)
}
}

#[inline(always)]
fn regions(
&mut self,
a: ty::Region<'tcx>,
b: ty::Region<'tcx>,
) -> RelateResult<'tcx, ty::Region<'tcx>> {
self.sub.regions(a, b)
}

#[inline(always)]
fn consts(
&mut self,
a: ty::Const<'tcx>,
b: ty::Const<'tcx>,
) -> RelateResult<'tcx, ty::Const<'tcx>> {
self.sub.consts(a, b)
}

#[inline(always)]
fn binders<T>(
&mut self,
a: ty::Binder<'tcx, T>,
b: ty::Binder<'tcx, T>,
) -> RelateResult<'tcx, ty::Binder<'tcx, T>>
where
T: Relate<'tcx>,
{
self.sub.binders(a, b)
}
Ok(())
}
5 changes: 0 additions & 5 deletions compiler/rustc_infer/src/infer/combine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ use rustc_middle::ty::subst::SubstsRef;
use rustc_middle::ty::{self, InferConst, ToPredicate, Ty, TyCtxt, TypeVisitable};
use rustc_middle::ty::{IntType, UintType};
use rustc_span::{Span, DUMMY_SP};
use crate::infer::base_struct::BaseStruct;

#[derive(Clone)]
pub struct CombineFields<'infcx, 'tcx> {
Expand Down Expand Up @@ -301,10 +300,6 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> {
Sub::new(self, a_is_expected)
}

pub fn base_struct<'a>(&'a mut self, a_is_expected: bool) -> BaseStruct<'a, 'infcx, 'tcx> {
BaseStruct::new(Sub::new(self, a_is_expected))
}

pub fn lub<'a>(&'a mut self, a_is_expected: bool) -> Lub<'a, 'infcx, 'tcx> {
Lub::new(self, a_is_expected)
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub mod resolve;
mod sub;
pub mod type_variable;
mod undo_log;
mod base_struct;
pub mod base_struct;

#[must_use]
#[derive(Debug)]
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/ty/relate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ pub fn relate_substs_with_variances<'tcx, R: TypeRelation<'tcx>>(
tcx.mk_substs(params)
}

#[inline]
pub fn relate_generic_arg<'tcx, R: TypeRelation<'tcx>>(
relation: &mut R,
variances: &[ty::Variance],
Expand Down

0 comments on commit a04088b

Please sign in to comment.