Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify IntVarValue/FloatVarValue #123536

Merged
merged 3 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions compiler/rustc_infer/src/infer/freshen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
use super::InferCtxt;
use rustc_data_structures::fx::FxHashMap;
use rustc_middle::bug;
use rustc_middle::infer::unify_key::ToType;
use rustc_middle::ty::fold::TypeFolder;
use rustc_middle::ty::{self, Ty, TyCtxt, TypeFoldable, TypeSuperFoldable, TypeVisitableExt};
use std::collections::hash_map::Entry;
Expand Down Expand Up @@ -204,22 +203,27 @@ impl<'a, 'tcx> TypeFreshener<'a, 'tcx> {

ty::IntVar(v) => {
let mut inner = self.infcx.inner.borrow_mut();
let input = inner
.int_unification_table()
.probe_value(v)
.map(|v| v.to_type(self.infcx.tcx))
.ok_or_else(|| ty::IntVar(inner.int_unification_table().find(v)));
let value = inner.int_unification_table().probe_value(v);
let input = match value {
ty::IntVarValue::IntType(ty) => Ok(Ty::new_int(self.infcx.tcx, ty)),
ty::IntVarValue::UintType(ty) => Ok(Ty::new_uint(self.infcx.tcx, ty)),
ty::IntVarValue::Unknown => {
Err(ty::IntVar(inner.int_unification_table().find(v)))
}
};
drop(inner);
Some(self.freshen_ty(input, |n| Ty::new_fresh_int(self.infcx.tcx, n)))
}

ty::FloatVar(v) => {
let mut inner = self.infcx.inner.borrow_mut();
let input = inner
.float_unification_table()
.probe_value(v)
.map(|v| v.to_type(self.infcx.tcx))
.ok_or_else(|| ty::FloatVar(inner.float_unification_table().find(v)));
let value = inner.float_unification_table().probe_value(v);
let input = match value {
ty::FloatVarValue::Known(ty) => Ok(Ty::new_float(self.infcx.tcx, ty)),
ty::FloatVarValue::Unknown => {
Err(ty::FloatVar(inner.float_unification_table().find(v)))
}
};
drop(inner);
Some(self.freshen_ty(input, |n| Ty::new_fresh_float(self.infcx.tcx, n)))
}
Expand Down
130 changes: 74 additions & 56 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ use rustc_errors::{Diag, DiagCtxt, ErrorGuaranteed};
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_macros::extension;
use rustc_middle::infer::canonical::{Canonical, CanonicalVarValues};
use rustc_middle::infer::unify_key::ConstVariableOrigin;
use rustc_middle::infer::unify_key::ConstVariableValue;
use rustc_middle::infer::unify_key::EffectVarValue;
use rustc_middle::infer::unify_key::{ConstVariableOrigin, ToType};
use rustc_middle::infer::unify_key::{ConstVidKey, EffectVidKey};
use rustc_middle::mir::interpret::{ErrorHandled, EvalToValTreeResult};
use rustc_middle::mir::ConstraintCategory;
Expand All @@ -41,7 +41,7 @@ use rustc_middle::ty::fold::BoundVarReplacerDelegate;
use rustc_middle::ty::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
use rustc_middle::ty::relate::RelateResult;
use rustc_middle::ty::visit::TypeVisitableExt;
use rustc_middle::ty::{self, GenericParamDefKind, InferConst, InferTy, Ty, TyCtxt};
use rustc_middle::ty::{self, GenericParamDefKind, InferConst, Ty, TyCtxt};
use rustc_middle::ty::{ConstVid, EffectVid, FloatVid, IntVid, TyVid};
use rustc_middle::ty::{GenericArg, GenericArgKind, GenericArgs, GenericArgsRef};
use rustc_middle::{bug, span_bug};
Expand Down Expand Up @@ -811,13 +811,13 @@ impl<'tcx> InferCtxt<'tcx> {
vars.extend(
(0..inner.int_unification_table().len())
.map(|i| ty::IntVid::from_u32(i as u32))
.filter(|&vid| inner.int_unification_table().probe_value(vid).is_none())
.filter(|&vid| inner.int_unification_table().probe_value(vid).is_unknown())
.map(|v| Ty::new_int_var(self.tcx, v)),
);
vars.extend(
(0..inner.float_unification_table().len())
.map(|i| ty::FloatVid::from_u32(i as u32))
.filter(|&vid| inner.float_unification_table().probe_value(vid).is_none())
.filter(|&vid| inner.float_unification_table().probe_value(vid).is_unknown())
.map(|v| Ty::new_float_var(self.tcx, v)),
);
vars
Expand Down Expand Up @@ -1025,14 +1025,28 @@ impl<'tcx> InferCtxt<'tcx> {
ty::Const::new_var(self.tcx, vid, ty)
}

pub fn next_const_var_id(&self, origin: ConstVariableOrigin) -> ConstVid {
self.inner
.borrow_mut()
.const_unification_table()
.new_key(ConstVariableValue::Unknown { origin, universe: self.universe() })
.vid
}

fn next_int_var_id(&self) -> IntVid {
self.inner.borrow_mut().int_unification_table().new_key(ty::IntVarValue::Unknown)
}

pub fn next_int_var(&self) -> Ty<'tcx> {
let vid = self.inner.borrow_mut().int_unification_table().new_key(None);
Ty::new_int_var(self.tcx, vid)
Ty::new_int_var(self.tcx, self.next_int_var_id())
}

fn next_float_var_id(&self) -> FloatVid {
self.inner.borrow_mut().float_unification_table().new_key(ty::FloatVarValue::Unknown)
}

pub fn next_float_var(&self) -> Ty<'tcx> {
let vid = self.inner.borrow_mut().float_unification_table().new_key(None);
Ty::new_float_var(self.tcx, vid)
Ty::new_float_var(self.tcx, self.next_float_var_id())
}

/// Creates a fresh region variable with the next available index.
Expand Down Expand Up @@ -1234,45 +1248,44 @@ impl<'tcx> InferCtxt<'tcx> {
}

pub fn shallow_resolve(&self, ty: Ty<'tcx>) -> Ty<'tcx> {
if let ty::Infer(v) = ty.kind() { self.fold_infer_ty(*v).unwrap_or(ty) } else { ty }
}

// This is separate from `shallow_resolve` to keep that method small and inlinable.
#[inline(never)]
fn fold_infer_ty(&self, v: InferTy) -> Option<Ty<'tcx>> {
match v {
ty::TyVar(v) => {
// Not entirely obvious: if `typ` is a type variable,
// it can be resolved to an int/float variable, which
// can then be recursively resolved, hence the
// recursion. Note though that we prevent type
// variables from unifying to other type variables
// directly (though they may be embedded
// structurally), and we prevent cycles in any case,
// so this recursion should always be of very limited
// depth.
//
// Note: if these two lines are combined into one we get
// dynamic borrow errors on `self.inner`.
let known = self.inner.borrow_mut().type_variables().probe(v).known();
known.map(|t| self.shallow_resolve(t))
}
if let ty::Infer(v) = *ty.kind() {
match v {
ty::TyVar(v) => {
// Not entirely obvious: if `typ` is a type variable,
// it can be resolved to an int/float variable, which
// can then be recursively resolved, hence the
// recursion. Note though that we prevent type
// variables from unifying to other type variables
// directly (though they may be embedded
// structurally), and we prevent cycles in any case,
// so this recursion should always be of very limited
// depth.
//
// Note: if these two lines are combined into one we get
// dynamic borrow errors on `self.inner`.
let known = self.inner.borrow_mut().type_variables().probe(v).known();
known.map_or(ty, |t| self.shallow_resolve(t))
}

ty::IntVar(v) => {
match self.inner.borrow_mut().int_unification_table().probe_value(v) {
ty::IntVarValue::IntType(ty) => Ty::new_int(self.tcx, ty),
ty::IntVarValue::UintType(ty) => Ty::new_uint(self.tcx, ty),
ty::IntVarValue::Unknown => ty,
}
}

ty::IntVar(v) => self
.inner
.borrow_mut()
.int_unification_table()
.probe_value(v)
.map(|v| v.to_type(self.tcx)),

ty::FloatVar(v) => self
.inner
.borrow_mut()
.float_unification_table()
.probe_value(v)
.map(|v| v.to_type(self.tcx)),

ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) => None,
ty::FloatVar(v) => {
match self.inner.borrow_mut().float_unification_table().probe_value(v) {
ty::FloatVarValue::Known(ty) => Ty::new_float(self.tcx, ty),
ty::FloatVarValue::Unknown => ty,
}
}

ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) => ty,
}
} else {
ty
}
}

Expand Down Expand Up @@ -1321,21 +1334,26 @@ impl<'tcx> InferCtxt<'tcx> {
/// or else the root int var in the unification table.
pub fn opportunistic_resolve_int_var(&self, vid: ty::IntVid) -> Ty<'tcx> {
let mut inner = self.inner.borrow_mut();
if let Some(value) = inner.int_unification_table().probe_value(vid) {
value.to_type(self.tcx)
} else {
Ty::new_int_var(self.tcx, inner.int_unification_table().find(vid))
let value = inner.int_unification_table().probe_value(vid);
match value {
ty::IntVarValue::IntType(ty) => Ty::new_int(self.tcx, ty),
ty::IntVarValue::UintType(ty) => Ty::new_uint(self.tcx, ty),
ty::IntVarValue::Unknown => {
Ty::new_int_var(self.tcx, inner.int_unification_table().find(vid))
}
}
}

/// Resolves a float var to a rigid int type, if it was constrained to one,
/// or else the root float var in the unification table.
pub fn opportunistic_resolve_float_var(&self, vid: ty::FloatVid) -> Ty<'tcx> {
let mut inner = self.inner.borrow_mut();
if let Some(value) = inner.float_unification_table().probe_value(vid) {
value.to_type(self.tcx)
} else {
Ty::new_float_var(self.tcx, inner.float_unification_table().find(vid))
let value = inner.float_unification_table().probe_value(vid);
match value {
ty::FloatVarValue::Known(ty) => Ty::new_float(self.tcx, ty),
ty::FloatVarValue::Unknown => {
Ty::new_float_var(self.tcx, inner.float_unification_table().find(vid))
}
}
}

Expand Down Expand Up @@ -1626,15 +1644,15 @@ impl<'tcx> InferCtxt<'tcx> {
// If `inlined_probe_value` returns a value it's always a
// `ty::Int(_)` or `ty::UInt(_)`, which never matches a
// `ty::Infer(_)`.
self.inner.borrow_mut().int_unification_table().inlined_probe_value(v).is_some()
self.inner.borrow_mut().int_unification_table().inlined_probe_value(v).is_known()
}

TyOrConstInferVar::TyFloat(v) => {
// If `probe_value` returns a value it's always a
// `ty::Float(_)`, which never matches a `ty::Infer(_)`.
//
// Not `inlined_probe_value(v)` because this call site is colder.
self.inner.borrow_mut().float_unification_table().probe_value(v).is_some()
self.inner.borrow_mut().float_unification_table().probe_value(v).is_known()
}

TyOrConstInferVar::Const(v) => {
Expand Down
81 changes: 21 additions & 60 deletions compiler/rustc_infer/src/infer/relate/combine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::infer::{DefineOpaqueTypes, InferCtxt, TypeTrace};
use crate::traits::{Obligation, PredicateObligations};
use rustc_middle::bug;
use rustc_middle::infer::unify_key::EffectVarValue;
use rustc_middle::ty::error::{ExpectedFound, TypeError};
use rustc_middle::ty::error::TypeError;
use rustc_middle::ty::relate::{RelateResult, TypeRelation};
use rustc_middle::ty::{self, InferConst, Ty, TyCtxt, TypeVisitableExt, Upcast};
use rustc_middle::ty::{IntType, UintType};
Expand Down Expand Up @@ -68,40 +68,38 @@ impl<'tcx> InferCtxt<'tcx> {
match (a.kind(), b.kind()) {
// Relate integral variables to other types
(&ty::Infer(ty::IntVar(a_id)), &ty::Infer(ty::IntVar(b_id))) => {
self.inner
.borrow_mut()
.int_unification_table()
.unify_var_var(a_id, b_id)
.map_err(|e| int_unification_error(true, e))?;
self.inner.borrow_mut().int_unification_table().union(a_id, b_id);
Ok(a)
}
(&ty::Infer(ty::IntVar(v_id)), &ty::Int(v)) => {
self.unify_integral_variable(true, v_id, IntType(v))
self.unify_integral_variable(v_id, IntType(v));
Ok(b)
}
(&ty::Int(v), &ty::Infer(ty::IntVar(v_id))) => {
self.unify_integral_variable(false, v_id, IntType(v))
self.unify_integral_variable(v_id, IntType(v));
Ok(a)
}
(&ty::Infer(ty::IntVar(v_id)), &ty::Uint(v)) => {
self.unify_integral_variable(true, v_id, UintType(v))
self.unify_integral_variable(v_id, UintType(v));
Ok(b)
}
(&ty::Uint(v), &ty::Infer(ty::IntVar(v_id))) => {
self.unify_integral_variable(false, v_id, UintType(v))
self.unify_integral_variable(v_id, UintType(v));
Ok(a)
}

// Relate floating-point variables to other types
(&ty::Infer(ty::FloatVar(a_id)), &ty::Infer(ty::FloatVar(b_id))) => {
self.inner
.borrow_mut()
.float_unification_table()
.unify_var_var(a_id, b_id)
.map_err(|e| float_unification_error(true, e))?;
self.inner.borrow_mut().float_unification_table().union(a_id, b_id);
Ok(a)
}
(&ty::Infer(ty::FloatVar(v_id)), &ty::Float(v)) => {
self.unify_float_variable(true, v_id, v)
self.unify_float_variable(v_id, ty::FloatVarValue::Known(v));
Ok(b)
}
(&ty::Float(v), &ty::Infer(ty::FloatVar(v_id))) => {
self.unify_float_variable(false, v_id, v)
self.unify_float_variable(v_id, ty::FloatVarValue::Known(v));
Ok(a)
}

// We don't expect `TyVar` or `Fresh*` vars at this point with lazy norm.
Expand Down Expand Up @@ -244,35 +242,14 @@ impl<'tcx> InferCtxt<'tcx> {
}
}

fn unify_integral_variable(
&self,
vid_is_expected: bool,
vid: ty::IntVid,
val: ty::IntVarValue,
) -> RelateResult<'tcx, Ty<'tcx>> {
self.inner
.borrow_mut()
.int_unification_table()
.unify_var_value(vid, Some(val))
.map_err(|e| int_unification_error(vid_is_expected, e))?;
match val {
IntType(v) => Ok(Ty::new_int(self.tcx, v)),
UintType(v) => Ok(Ty::new_uint(self.tcx, v)),
}
#[inline(always)]
fn unify_integral_variable(&self, vid: ty::IntVid, val: ty::IntVarValue) {
self.inner.borrow_mut().int_unification_table().union_value(vid, val);
}

fn unify_float_variable(
&self,
vid_is_expected: bool,
vid: ty::FloatVid,
val: ty::FloatTy,
) -> RelateResult<'tcx, Ty<'tcx>> {
self.inner
.borrow_mut()
.float_unification_table()
.unify_var_value(vid, Some(ty::FloatVarValue(val)))
.map_err(|e| float_unification_error(vid_is_expected, e))?;
Ok(Ty::new_float(self.tcx, val))
#[inline(always)]
fn unify_float_variable(&self, vid: ty::FloatVid, val: ty::FloatVarValue) {
self.inner.borrow_mut().float_unification_table().union_value(vid, val);
}

fn unify_effect_variable(&self, vid: ty::EffectVid, val: ty::Const<'tcx>) -> ty::Const<'tcx> {
Expand Down Expand Up @@ -350,19 +327,3 @@ pub trait ObligationEmittingRelation<'tcx>: TypeRelation<'tcx> {
/// Register `AliasRelate` obligation(s) that both types must be related to each other.
fn register_type_relate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>);
}

fn int_unification_error<'tcx>(
a_is_expected: bool,
v: (ty::IntVarValue, ty::IntVarValue),
) -> TypeError<'tcx> {
let (a, b) = v;
TypeError::IntMismatch(ExpectedFound::new(a_is_expected, a, b))
}

fn float_unification_error<'tcx>(
a_is_expected: bool,
v: (ty::FloatVarValue, ty::FloatVarValue),
) -> TypeError<'tcx> {
let (ty::FloatVarValue(a), ty::FloatVarValue(b)) = v;
TypeError::FloatMismatch(ExpectedFound::new(a_is_expected, a, b))
}
4 changes: 2 additions & 2 deletions compiler/rustc_infer/src/infer/relate/lattice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ where

let infcx = this.infcx();

let a = infcx.inner.borrow_mut().type_variables().replace_if_possible(a);
let b = infcx.inner.borrow_mut().type_variables().replace_if_possible(b);
let a = infcx.shallow_resolve(a);
let b = infcx.shallow_resolve(b);

match (a.kind(), b.kind()) {
// If one side is known to be a variable and one is not,
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_infer/src/infer/relate/type_relating.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
}

let infcx = self.fields.infcx;
let a = infcx.inner.borrow_mut().type_variables().replace_if_possible(a);
let b = infcx.inner.borrow_mut().type_variables().replace_if_possible(b);
let a = infcx.shallow_resolve(a);
let b = infcx.shallow_resolve(b);

match (a.kind(), b.kind()) {
(&ty::Infer(TyVar(a_id)), &ty::Infer(TyVar(b_id))) => {
Expand Down
Loading
Loading