diff --git a/compiler/rustc_mir_transform/src/gvn.rs b/compiler/rustc_mir_transform/src/gvn.rs index 48663c4a52f61..f3b7e24f885ea 100644 --- a/compiler/rustc_mir_transform/src/gvn.rs +++ b/compiler/rustc_mir_transform/src/gvn.rs @@ -85,11 +85,12 @@ //! that contain `AllocId`s. use std::borrow::Cow; +use std::cmp::Ordering; use std::hash::{Hash, Hasher}; use either::Either; use hashbrown::hash_table::{Entry, HashTable}; -use itertools::Itertools as _; +use itertools::{Itertools as _, MinMaxResult}; use rustc_abi::{self as abi, BackendRepr, FIRST_VARIANT, FieldIdx, Primitive, Size, VariantIdx}; use rustc_arena::DroplessArena; use rustc_const_eval::const_eval::DummyMachine; @@ -107,6 +108,7 @@ use rustc_middle::mir::interpret::GlobalAlloc; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; use rustc_middle::ty::layout::HasTypingEnv; +use rustc_middle::ty::util::IntTypeExt; use rustc_middle::ty::{self, Ty, TyCtxt}; use rustc_span::DUMMY_SP; use smallvec::SmallVec; @@ -1367,10 +1369,10 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> { } } - if let Some(value) = self.simplify_binary_inner(op, lhs_ty, lhs, rhs) { + let ty = op.ty(self.tcx, lhs_ty, self.ty(rhs)); + if let Some(value) = self.simplify_binary_inner(op, ty, lhs_ty, lhs, rhs) { return Some(value); } - let ty = op.ty(self.tcx, lhs_ty, self.ty(rhs)); let value = Value::BinaryOp(op, lhs, rhs); Some(self.insert(ty, value)) } @@ -1378,6 +1380,7 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> { fn simplify_binary_inner( &mut self, op: BinOp, + ty: Ty<'tcx>, lhs_ty: Ty<'tcx>, lhs: VnIndex, rhs: VnIndex, @@ -1403,9 +1406,9 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> { }; // Represent the values as `Left(bits)` or `Right(VnIndex)`. - use Either::{Left, Right}; - let a = as_bits(lhs).map_or(Right(lhs), Left); - let b = as_bits(rhs).map_or(Right(rhs), Left); + use BitsOrIndex::*; + let a = as_bits(lhs).map_or(Value(lhs), Bits); + let b = as_bits(rhs).map_or(Value(rhs), Bits); let result = match (op, a, b) { // Neutral elements. @@ -1415,8 +1418,8 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> { | BinOp::AddUnchecked | BinOp::BitOr | BinOp::BitXor, - Left(0), - Right(p), + Bits(0), + Value(p), ) | ( BinOp::Add @@ -1430,17 +1433,17 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> { | BinOp::Offset | BinOp::Shl | BinOp::Shr, - Right(p), - Left(0), + Value(p), + Bits(0), ) - | (BinOp::Mul | BinOp::MulWithOverflow | BinOp::MulUnchecked, Left(1), Right(p)) + | (BinOp::Mul | BinOp::MulWithOverflow | BinOp::MulUnchecked, Bits(1), Value(p)) | ( BinOp::Mul | BinOp::MulWithOverflow | BinOp::MulUnchecked | BinOp::Div, - Right(p), - Left(1), + Value(p), + Bits(1), ) => p, // Attempt to simplify `x & ALL_ONES` to `x`, with `ALL_ONES` depending on type size. - (BinOp::BitAnd, Right(p), Left(ones)) | (BinOp::BitAnd, Left(ones), Right(p)) + (BinOp::BitAnd, Value(p), Bits(ones)) | (BinOp::BitAnd, Bits(ones), Value(p)) if ones == layout.size.truncate(u128::MAX) || (layout.ty.is_bool() && ones == 1) => { @@ -1450,9 +1453,9 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> { ( BinOp::Mul | BinOp::MulWithOverflow | BinOp::MulUnchecked | BinOp::BitAnd, _, - Left(0), + Bits(0), ) - | (BinOp::Rem, _, Left(1)) + | (BinOp::Rem, _, Bits(1)) | ( BinOp::Mul | BinOp::MulWithOverflow @@ -1462,11 +1465,11 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> { | BinOp::BitAnd | BinOp::Shl | BinOp::Shr, - Left(0), + Bits(0), _, ) => self.insert_scalar(lhs_ty, Scalar::from_uint(0u128, layout.size)), // Attempt to simplify `x | ALL_ONES` to `ALL_ONES`. - (BinOp::BitOr, _, Left(ones)) | (BinOp::BitOr, Left(ones), _) + (BinOp::BitOr, _, Bits(ones)) | (BinOp::BitOr, Bits(ones), _) if ones == layout.size.truncate(u128::MAX) || (layout.ty.is_bool() && ones == 1) => { @@ -1482,11 +1485,19 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> { // - if both operands can be computed as bits, just compare the bits; // - if we proved that both operands have the same value, we can insert true/false; // - otherwise, do nothing, as we do not try to prove inequality. - (BinOp::Eq, Left(a), Left(b)) => self.insert_bool(a == b), + (BinOp::Eq, Bits(a), Bits(b)) => self.insert_bool(a == b), (BinOp::Eq, a, b) if a == b => self.insert_bool(true), - (BinOp::Ne, Left(a), Left(b)) => self.insert_bool(a != b), + (BinOp::Ne, Bits(a), Bits(b)) => self.insert_bool(a != b), (BinOp::Ne, a, b) if a == b => self.insert_bool(false), - _ => return None, + // If we know the range of the value, we can compare them. + (BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge | BinOp::Cmp, lhs, rhs) + if let Some(result) = self.simplify_binary_range(op, lhs_ty, lhs, rhs) => + { + self.insert_scalar(ty, result) + } + _ => { + return None; + } }; if op.is_overflowing() { @@ -1498,6 +1509,81 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> { } } + fn simplify_binary_range( + &self, + op: BinOp, + lhs_ty: Ty<'tcx>, + lhs: BitsOrIndex, + rhs: BitsOrIndex, + ) -> Option { + if !lhs_ty.is_integral() { + return None; + } + let layout = self.ecx.layout_of(lhs_ty).ok()?; + let range = |val: BitsOrIndex| match val { + BitsOrIndex::Bits(bits) => { + let value = ImmTy::from_uint(bits, layout); + Some(Either::Left(value)) + } + BitsOrIndex::Value(value) => { + let value = if let Value::Cast { kind: CastKind::IntToInt, value } = self.get(value) + { + value + } else { + value + }; + let Value::Discriminant(discr) = self.get(value) else { + return None; + }; + let ty::Adt(adt, _) = self.ty(discr).kind() else { + return None; + }; + if !adt.is_enum() { + return None; + } + let discr_ty = adt.repr().discr_type().to_ty(self.tcx); + let discr_layout = self.ecx.layout_of(discr_ty).ok()?; + let MinMaxResult::MinMax(min, max) = adt + .discriminants(self.tcx) + .map(|(_, discr)| { + let val = ImmTy::from_uint(discr.val, discr_layout); + let val = self.ecx.int_to_int_or_float(&val, layout).discard_err().unwrap(); + val + }) + .minmax_by(|x, y| { + let cmp = self.ecx.binary_op(BinOp::Cmp, x, y).unwrap(); + let cmp = cmp.to_scalar_int().unwrap().to_i8(); + match cmp { + -1 => Ordering::Less, + 0 => Ordering::Equal, + 1 => Ordering::Greater, + _ => unreachable!(), + } + }) + else { + return None; + }; + Some(Either::Right((min, max))) + } + }; + let lhs = range(lhs)?; + let rhs = range(rhs)?; + match (lhs, rhs) { + (Either::Left(lhs), Either::Right((rhs_min, rhs_max))) => { + let cmp_min = self.ecx.binary_op(op, &lhs, &rhs_min).discard_err()?.to_scalar(); + let cmp_max = self.ecx.binary_op(op, &lhs, &rhs_max).discard_err()?.to_scalar(); + if cmp_min == cmp_max { Some(cmp_min) } else { None } + } + (Either::Right((lhs_min, lhs_max)), Either::Left(rhs)) => { + let cmp_min = self.ecx.binary_op(op, &lhs_min, &rhs).discard_err()?.to_scalar(); + let cmp_max = self.ecx.binary_op(op, &lhs_max, &rhs).discard_err()?.to_scalar(); + if cmp_min == cmp_max { Some(cmp_min) } else { None } + } + (Either::Left(_), Either::Left(_)) => None, + (Either::Right(_), Either::Right(_)) => None, + } + } + fn simplify_cast( &mut self, initial_kind: &mut CastKind, @@ -1960,3 +2046,9 @@ impl<'tcx> MutVisitor<'tcx> for StorageRemover<'tcx> { } } } + +#[derive(Debug, PartialEq, Clone, Copy)] +enum BitsOrIndex { + Bits(u128), + Value(VnIndex), +}