Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 113 additions & 21 deletions compiler/rustc_mir_transform/src/gvn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,12 @@
//! that contain `AllocId`s.

use std::borrow::Cow;
use std::cmp::Ordering;
use std::hash::{Hash, Hasher};

use either::Either;
use hashbrown::hash_table::{Entry, HashTable};
use itertools::Itertools as _;
use itertools::{Itertools as _, MinMaxResult};
use rustc_abi::{self as abi, BackendRepr, FIRST_VARIANT, FieldIdx, Primitive, Size, VariantIdx};
use rustc_arena::DroplessArena;
use rustc_const_eval::const_eval::DummyMachine;
Expand All @@ -107,6 +108,7 @@ use rustc_middle::mir::interpret::GlobalAlloc;
use rustc_middle::mir::visit::*;
use rustc_middle::mir::*;
use rustc_middle::ty::layout::HasTypingEnv;
use rustc_middle::ty::util::IntTypeExt;
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_span::DUMMY_SP;
use smallvec::SmallVec;
Expand Down Expand Up @@ -1367,17 +1369,18 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
}
}

if let Some(value) = self.simplify_binary_inner(op, lhs_ty, lhs, rhs) {
let ty = op.ty(self.tcx, lhs_ty, self.ty(rhs));
if let Some(value) = self.simplify_binary_inner(op, ty, lhs_ty, lhs, rhs) {
return Some(value);
}
let ty = op.ty(self.tcx, lhs_ty, self.ty(rhs));
let value = Value::BinaryOp(op, lhs, rhs);
Some(self.insert(ty, value))
}

fn simplify_binary_inner(
&mut self,
op: BinOp,
ty: Ty<'tcx>,
lhs_ty: Ty<'tcx>,
lhs: VnIndex,
rhs: VnIndex,
Expand All @@ -1403,9 +1406,9 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
};

// Represent the values as `Left(bits)` or `Right(VnIndex)`.
use Either::{Left, Right};
let a = as_bits(lhs).map_or(Right(lhs), Left);
let b = as_bits(rhs).map_or(Right(rhs), Left);
use BitsOrIndex::*;
let a = as_bits(lhs).map_or(Value(lhs), Bits);
let b = as_bits(rhs).map_or(Value(rhs), Bits);

let result = match (op, a, b) {
// Neutral elements.
Expand All @@ -1415,8 +1418,8 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
| BinOp::AddUnchecked
| BinOp::BitOr
| BinOp::BitXor,
Left(0),
Right(p),
Bits(0),
Value(p),
)
| (
BinOp::Add
Expand All @@ -1430,17 +1433,17 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
| BinOp::Offset
| BinOp::Shl
| BinOp::Shr,
Right(p),
Left(0),
Value(p),
Bits(0),
)
| (BinOp::Mul | BinOp::MulWithOverflow | BinOp::MulUnchecked, Left(1), Right(p))
| (BinOp::Mul | BinOp::MulWithOverflow | BinOp::MulUnchecked, Bits(1), Value(p))
| (
BinOp::Mul | BinOp::MulWithOverflow | BinOp::MulUnchecked | BinOp::Div,
Right(p),
Left(1),
Value(p),
Bits(1),
) => p,
// Attempt to simplify `x & ALL_ONES` to `x`, with `ALL_ONES` depending on type size.
(BinOp::BitAnd, Right(p), Left(ones)) | (BinOp::BitAnd, Left(ones), Right(p))
(BinOp::BitAnd, Value(p), Bits(ones)) | (BinOp::BitAnd, Bits(ones), Value(p))
if ones == layout.size.truncate(u128::MAX)
|| (layout.ty.is_bool() && ones == 1) =>
{
Expand All @@ -1450,9 +1453,9 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
(
BinOp::Mul | BinOp::MulWithOverflow | BinOp::MulUnchecked | BinOp::BitAnd,
_,
Left(0),
Bits(0),
)
| (BinOp::Rem, _, Left(1))
| (BinOp::Rem, _, Bits(1))
| (
BinOp::Mul
| BinOp::MulWithOverflow
Expand All @@ -1462,11 +1465,11 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
| BinOp::BitAnd
| BinOp::Shl
| BinOp::Shr,
Left(0),
Bits(0),
_,
) => self.insert_scalar(lhs_ty, Scalar::from_uint(0u128, layout.size)),
// Attempt to simplify `x | ALL_ONES` to `ALL_ONES`.
(BinOp::BitOr, _, Left(ones)) | (BinOp::BitOr, Left(ones), _)
(BinOp::BitOr, _, Bits(ones)) | (BinOp::BitOr, Bits(ones), _)
if ones == layout.size.truncate(u128::MAX)
|| (layout.ty.is_bool() && ones == 1) =>
{
Expand All @@ -1482,11 +1485,19 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
// - if both operands can be computed as bits, just compare the bits;
// - if we proved that both operands have the same value, we can insert true/false;
// - otherwise, do nothing, as we do not try to prove inequality.
(BinOp::Eq, Left(a), Left(b)) => self.insert_bool(a == b),
(BinOp::Eq, Bits(a), Bits(b)) => self.insert_bool(a == b),
(BinOp::Eq, a, b) if a == b => self.insert_bool(true),
(BinOp::Ne, Left(a), Left(b)) => self.insert_bool(a != b),
(BinOp::Ne, Bits(a), Bits(b)) => self.insert_bool(a != b),
(BinOp::Ne, a, b) if a == b => self.insert_bool(false),
_ => return None,
// If we know the range of the value, we can compare them.
(BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge | BinOp::Cmp, lhs, rhs)
if let Some(result) = self.simplify_binary_range(op, lhs_ty, lhs, rhs) =>
{
self.insert_scalar(ty, result)
}
_ => {
return None;
}
};

if op.is_overflowing() {
Expand All @@ -1498,6 +1509,81 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
}
}

fn simplify_binary_range(
&self,
op: BinOp,
lhs_ty: Ty<'tcx>,
lhs: BitsOrIndex,
rhs: BitsOrIndex,
) -> Option<Scalar> {
if !lhs_ty.is_integral() {
return None;
}
let layout = self.ecx.layout_of(lhs_ty).ok()?;
let range = |val: BitsOrIndex| match val {
BitsOrIndex::Bits(bits) => {
let value = ImmTy::from_uint(bits, layout);
Some(Either::Left(value))
}
BitsOrIndex::Value(value) => {
let value = if let Value::Cast { kind: CastKind::IntToInt, value } = self.get(value)
{
value
} else {
value
};
let Value::Discriminant(discr) = self.get(value) else {
return None;
};
let ty::Adt(adt, _) = self.ty(discr).kind() else {
return None;
};
if !adt.is_enum() {
return None;
}
let discr_ty = adt.repr().discr_type().to_ty(self.tcx);
let discr_layout = self.ecx.layout_of(discr_ty).ok()?;
let MinMaxResult::MinMax(min, max) = adt
.discriminants(self.tcx)
.map(|(_, discr)| {
let val = ImmTy::from_uint(discr.val, discr_layout);
let val = self.ecx.int_to_int_or_float(&val, layout).discard_err().unwrap();
val
})
.minmax_by(|x, y| {
let cmp = self.ecx.binary_op(BinOp::Cmp, x, y).unwrap();
let cmp = cmp.to_scalar_int().unwrap().to_i8();
match cmp {
-1 => Ordering::Less,
0 => Ordering::Equal,
1 => Ordering::Greater,
_ => unreachable!(),
}
})
else {
return None;
};
Some(Either::Right((min, max)))
}
};
let lhs = range(lhs)?;
let rhs = range(rhs)?;
match (lhs, rhs) {
(Either::Left(lhs), Either::Right((rhs_min, rhs_max))) => {
let cmp_min = self.ecx.binary_op(op, &lhs, &rhs_min).discard_err()?.to_scalar();
let cmp_max = self.ecx.binary_op(op, &lhs, &rhs_max).discard_err()?.to_scalar();
if cmp_min == cmp_max { Some(cmp_min) } else { None }
}
(Either::Right((lhs_min, lhs_max)), Either::Left(rhs)) => {
let cmp_min = self.ecx.binary_op(op, &lhs_min, &rhs).discard_err()?.to_scalar();
let cmp_max = self.ecx.binary_op(op, &lhs_max, &rhs).discard_err()?.to_scalar();
if cmp_min == cmp_max { Some(cmp_min) } else { None }
}
(Either::Left(_), Either::Left(_)) => None,
(Either::Right(_), Either::Right(_)) => None,
}
}

fn simplify_cast(
&mut self,
initial_kind: &mut CastKind,
Expand Down Expand Up @@ -1960,3 +2046,9 @@ impl<'tcx> MutVisitor<'tcx> for StorageRemover<'tcx> {
}
}
}

#[derive(Debug, PartialEq, Clone, Copy)]
enum BitsOrIndex {
Bits(u128),
Value(VnIndex),
}
Loading