Skip to content

Commit

Permalink
Rollup merge of #119461 - cjgillot:jump-threading-interp, r=tmiasko
Browse files Browse the repository at this point in the history
Use an interpreter in MIR jump threading

This allows to understand assignments of aggregate constants. This case appears more frequently with GVN promoting aggregates to constants.
  • Loading branch information
Nadrieril committed Jan 21, 2024
2 parents 3cd378c + e72b2b1 commit 203cc69
Show file tree
Hide file tree
Showing 4 changed files with 277 additions and 107 deletions.
262 changes: 155 additions & 107 deletions compiler/rustc_mir_transform/src/jump_threading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,21 @@
//! cost by `MAX_COST`.

use rustc_arena::DroplessArena;
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
use rustc_data_structures::fx::FxHashSet;
use rustc_index::bit_set::BitSet;
use rustc_index::IndexVec;
use rustc_middle::mir::interpret::Scalar;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{self, ScalarInt, TyCtxt};
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
use rustc_span::DUMMY_SP;
use rustc_target::abi::{TagEncoding, Variants};

use crate::cost_checker::CostChecker;
use crate::dataflow_const_prop::DummyMachine;

pub struct JumpThreading;

Expand All @@ -71,6 +76,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
let mut finder = TOFinder {
tcx,
param_env,
ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
body,
arena: &arena,
map: &map,
Expand All @@ -88,7 +94,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
debug!(?discr, ?bb);

let discr_ty = discr.ty(body, tcx).ty;
let Ok(discr_layout) = tcx.layout_of(param_env.and(discr_ty)) else { continue };
let Ok(discr_layout) = finder.ecx.layout_of(discr_ty) else { continue };

let Some(discr) = finder.map.find(discr.as_ref()) else { continue };
debug!(?discr);
Expand Down Expand Up @@ -142,6 +148,7 @@ struct ThreadingOpportunity {
struct TOFinder<'tcx, 'a> {
tcx: TyCtxt<'tcx>,
param_env: ty::ParamEnv<'tcx>,
ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
body: &'a Body<'tcx>,
map: &'a Map,
loop_headers: &'a BitSet<BasicBlock>,
Expand Down Expand Up @@ -329,25 +336,82 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
}

#[instrument(level = "trace", skip(self))]
fn process_operand(
fn process_immediate(
&mut self,
bb: BasicBlock,
lhs: PlaceIndex,
rhs: &Operand<'tcx>,
rhs: ImmTy<'tcx>,
state: &mut State<ConditionSet<'a>>,
) -> Option<!> {
let register_opportunity = |c: Condition| {
debug!(?bb, ?c.target, "register");
self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
};

let conditions = state.try_get_idx(lhs, self.map)?;
if let Immediate::Scalar(Scalar::Int(int)) = *rhs {
conditions.iter_matches(int).for_each(register_opportunity);
}

None
}

/// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
#[instrument(level = "trace", skip(self))]
fn process_constant(
&mut self,
bb: BasicBlock,
lhs: PlaceIndex,
constant: OpTy<'tcx>,
state: &mut State<ConditionSet<'a>>,
) {
self.map.for_each_projection_value(
lhs,
constant,
&mut |elem, op| match elem {
TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(),
TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(),
TrackElem::Discriminant => {
let variant = self.ecx.read_discriminant(op).ok()?;
let discr_value =
self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?;
Some(discr_value.into())
}
TrackElem::DerefLen => {
let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into();
let len_usize = op.len(&self.ecx).ok()?;
let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
Some(ImmTy::from_uint(len_usize, layout).into())
}
},
&mut |place, op| {
if let Some(conditions) = state.try_get_idx(place, self.map)
&& let Ok(imm) = self.ecx.read_immediate_raw(op)
&& let Some(imm) = imm.right()
&& let Immediate::Scalar(Scalar::Int(int)) = *imm
{
conditions.iter_matches(int).for_each(|c: Condition| {
self.opportunities
.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
})
}
},
);
}

#[instrument(level = "trace", skip(self))]
fn process_operand(
&mut self,
bb: BasicBlock,
lhs: PlaceIndex,
rhs: &Operand<'tcx>,
state: &mut State<ConditionSet<'a>>,
) -> Option<!> {
match rhs {
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
Operand::Constant(constant) => {
let conditions = state.try_get_idx(lhs, self.map)?;
let constant =
constant.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
conditions.iter_matches(constant).for_each(register_opportunity);
let constant = self.ecx.eval_mir_constant(&constant.const_, None, None).ok()?;
self.process_constant(bb, lhs, constant, state);
}
// Transfer the conditions on the copied rhs.
Operand::Move(rhs) | Operand::Copy(rhs) => {
Expand All @@ -359,6 +423,84 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
None
}

#[instrument(level = "trace", skip(self))]
fn process_assign(
&mut self,
bb: BasicBlock,
lhs_place: &Place<'tcx>,
rhs: &Rvalue<'tcx>,
state: &mut State<ConditionSet<'a>>,
) -> Option<!> {
let lhs = self.map.find(lhs_place.as_ref())?;
match rhs {
Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?,
// Transfer the conditions on the copy rhs.
Rvalue::CopyForDeref(rhs) => {
self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)?
}
Rvalue::Discriminant(rhs) => {
let rhs = self.map.find_discr(rhs.as_ref())?;
state.insert_place_idx(rhs, lhs, self.map);
}
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
Rvalue::Aggregate(box ref kind, ref operands) => {
let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
let lhs = match kind {
// Do not support unions.
AggregateKind::Adt(.., Some(_)) => return None,
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
&& let Ok(discr_value) =
self.ecx.discriminant_for_variant(agg_ty, *variant_index)
{
self.process_immediate(bb, discr_target, discr_value, state);
}
self.map.apply(lhs, TrackElem::Variant(*variant_index))?
}
_ => lhs,
};
for (field_index, operand) in operands.iter_enumerated() {
if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) {
self.process_operand(bb, field, operand, state);
}
}
}
// Transfer the conditions on the copy rhs, after inversing polarity.
Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
let conditions = state.try_get_idx(lhs, self.map)?;
let place = self.map.find(place.as_ref())?;
let conds = conditions.map(self.arena, Condition::inv);
state.insert_value_idx(place, conds, self.map);
}
// We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
// Create a condition on `rhs ?= B`.
Rvalue::BinaryOp(
op,
box (Operand::Move(place) | Operand::Copy(place), Operand::Constant(value))
| box (Operand::Constant(value), Operand::Move(place) | Operand::Copy(place)),
) => {
let conditions = state.try_get_idx(lhs, self.map)?;
let place = self.map.find(place.as_ref())?;
let equals = match op {
BinOp::Eq => ScalarInt::TRUE,
BinOp::Ne => ScalarInt::FALSE,
_ => return None,
};
let value = value.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
let conds = conditions.map(self.arena, |c| Condition {
value,
polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne },
..c
});
state.insert_value_idx(place, conds, self.map);
}

_ => {}
}

None
}

#[instrument(level = "trace", skip(self))]
fn process_statement(
&mut self,
Expand All @@ -374,18 +516,6 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
// Below, `lhs` is the return value of `mutated_statement`,
// the place to which `conditions` apply.

let discriminant_for_variant = |enum_ty: Ty<'tcx>, variant_index| {
let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
let scalar = ScalarInt::try_from_uint(discr.val, discr_layout.size)?;
Some(Operand::const_from_scalar(
self.tcx,
discr.ty,
scalar.into(),
rustc_span::DUMMY_SP,
))
};

match &stmt.kind {
// If we expect `discriminant(place) ?= A`,
// we have an opportunity if `variant_index ?= A`.
Expand All @@ -395,7 +525,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
// `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant
// of a niche encoding. If we cannot ensure that we write to the discriminant, do
// nothing.
let enum_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?;
let enum_layout = self.ecx.layout_of(enum_ty).ok()?;
let writes_discriminant = match enum_layout.variants {
Variants::Single { index } => {
assert_eq!(index, *variant_index);
Expand All @@ -408,8 +538,8 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
} => *variant_index != untagged_variant,
};
if writes_discriminant {
let discr = discriminant_for_variant(enum_ty, *variant_index)?;
self.process_operand(bb, discr_target, &discr, state)?;
let discr = self.ecx.discriminant_for_variant(enum_ty, *variant_index).ok()?;
self.process_immediate(bb, discr_target, discr, state)?;
}
}
// If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
Expand All @@ -420,89 +550,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity);
}
StatementKind::Assign(box (lhs_place, rhs)) => {
if let Some(lhs) = self.map.find(lhs_place.as_ref()) {
match rhs {
Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?,
// Transfer the conditions on the copy rhs.
Rvalue::CopyForDeref(rhs) => {
self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)?
}
Rvalue::Discriminant(rhs) => {
let rhs = self.map.find_discr(rhs.as_ref())?;
state.insert_place_idx(rhs, lhs, self.map);
}
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
Rvalue::Aggregate(box ref kind, ref operands) => {
let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
let lhs = match kind {
// Do not support unions.
AggregateKind::Adt(.., Some(_)) => return None,
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
if let Some(discr_target) =
self.map.apply(lhs, TrackElem::Discriminant)
&& let Some(discr_value) =
discriminant_for_variant(agg_ty, *variant_index)
{
self.process_operand(bb, discr_target, &discr_value, state);
}
self.map.apply(lhs, TrackElem::Variant(*variant_index))?
}
_ => lhs,
};
for (field_index, operand) in operands.iter_enumerated() {
if let Some(field) =
self.map.apply(lhs, TrackElem::Field(field_index))
{
self.process_operand(bb, field, operand, state);
}
}
}
// Transfer the conditions on the copy rhs, after inversing polarity.
Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
let conditions = state.try_get_idx(lhs, self.map)?;
let place = self.map.find(place.as_ref())?;
let conds = conditions.map(self.arena, Condition::inv);
state.insert_value_idx(place, conds, self.map);
}
// We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
// Create a condition on `rhs ?= B`.
Rvalue::BinaryOp(
op,
box (
Operand::Move(place) | Operand::Copy(place),
Operand::Constant(value),
)
| box (
Operand::Constant(value),
Operand::Move(place) | Operand::Copy(place),
),
) => {
let conditions = state.try_get_idx(lhs, self.map)?;
let place = self.map.find(place.as_ref())?;
let equals = match op {
BinOp::Eq => ScalarInt::TRUE,
BinOp::Ne => ScalarInt::FALSE,
_ => return None,
};
let value = value
.const_
.normalize(self.tcx, self.param_env)
.try_to_scalar_int()?;
let conds = conditions.map(self.arena, |c| Condition {
value,
polarity: if c.matches(equals) {
Polarity::Eq
} else {
Polarity::Ne
},
..c
});
state.insert_value_idx(place, conds, self.map);
}

_ => {}
}
}
self.process_assign(bb, lhs_place, rhs, state)?;
}
_ => {}
}
Expand Down Expand Up @@ -577,7 +625,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {

let discr = discr.place()?;
let discr_ty = discr.ty(self.body, self.tcx).ty;
let discr_layout = self.tcx.layout_of(self.param_env.and(discr_ty)).ok()?;
let discr_layout = self.ecx.layout_of(discr_ty).ok()?;
let conditions = state.try_get(discr.as_ref(), self.map)?;

if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) {
Expand Down

0 comments on commit 203cc69

Please sign in to comment.