Skip to content

Commit

Permalink
Auto merge of #69756 - wesleywiser:simplify_try, r=<try>
Browse files Browse the repository at this point in the history
[WIP] Modify SimplifyArmIdentity so it can trigger on mir-opt-level=1

r? @ghost
  • Loading branch information
bors committed Mar 9, 2020
2 parents 2cb0b85 + f0d5e44 commit be3cc06
Show file tree
Hide file tree
Showing 5 changed files with 719 additions and 80 deletions.
4 changes: 2 additions & 2 deletions src/librustc_mir/transform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,12 +311,12 @@ fn run_optimization_passes<'tcx>(
&const_prop::ConstProp,
&simplify_branches::SimplifyBranches::new("after-const-prop"),
&deaggregator::Deaggregator,
&simplify_try::SimplifyArmIdentity,
&simplify_try::SimplifyBranchSame,
&copy_prop::CopyPropagation,
&simplify_branches::SimplifyBranches::new("after-copy-prop"),
&remove_noop_landing_pads::RemoveNoopLandingPads,
&simplify::SimplifyCfg::new("after-remove-noop-landing-pads"),
&simplify_try::SimplifyArmIdentity,
&simplify_try::SimplifyBranchSame,
&simplify::SimplifyCfg::new("final"),
&simplify::SimplifyLocals,
&add_call_guards::CriticalCallEdges,
Expand Down
283 changes: 244 additions & 39 deletions src/librustc_mir/transform/simplify_try.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@ use crate::transform::{simplify, MirPass, MirSource};
use itertools::Itertools as _;
use rustc::mir::*;
use rustc::ty::{Ty, TyCtxt};
use rustc_index::vec::IndexVec;
use rustc_target::abi::VariantIdx;
use std::iter::{Enumerate, Peekable};
use std::slice::Iter;

/// Simplifies arms of form `Variant(x) => Variant(x)` to just a move.
///
/// This is done by transforming basic blocks where the statements match:
///
/// ```rust
/// _LOCAL_TMP = ((_LOCAL_1 as Variant ).FIELD: TY );
/// ((_LOCAL_0 as Variant).FIELD: TY) = move _LOCAL_TMP;
/// _TMP_2 = _LOCAL_TMP;
/// ((_LOCAL_0 as Variant).FIELD: TY) = move _TMP_2;
/// discriminant(_LOCAL_0) = VAR_IDX;
/// ```
///
Expand All @@ -32,50 +36,251 @@ use rustc_target::abi::VariantIdx;
/// ```
pub struct SimplifyArmIdentity;

#[derive(Debug)]
struct ArmIdentityInfo<'tcx> {
/// Storage location for the variant's field
local_temp_0: Local,
/// Storage location holding the varient being read from
local_1: Local,
/// The varient field being read from
vf_s0: VarField<'tcx>,

/// Tracks each assignment to a temporary of the varient's field
field_tmp_assignments: Vec<(Local, Local)>,

/// Storage location holding the variant's field that was read from
local_tmp_s1: Local,
/// Storage location holding the enum that we are writing to
local_0: Local,
/// The varient field being written to
vf_s1: VarField<'tcx>,

/// Storage location that the discrimentant is being set to
set_discr_local: Local,
/// The variant being written
set_discr_var_idx: VariantIdx,

/// Index of the statement that should be overwritten as a move
stmt_to_overwrite: usize,
/// SourceInfo for the new move
source_info: SourceInfo,

/// Indexes of matching Storage{Live,Dead} statements encountered.
/// (StorageLive index,, StorageDead index, Local)
storage_stmts: Vec<(usize, usize, Local)>,

/// The statements that should be removed (turned into nops)
stmts_to_remove: Vec<usize>,
}

fn get_arm_identity_info<'a, 'tcx>(stmts: &'a [Statement<'tcx>]) -> Option<ArmIdentityInfo<'tcx>> {
let mut tmp_assigns = Vec::new();
let mut nop_stmts = Vec::new();
let mut storage_stmts = Vec::new();
let mut storage_live_stmts = Vec::new();
let mut storage_dead_stmts = Vec::new();

type StmtIter<'a, 'tcx> = Peekable<Enumerate<Iter<'a, Statement<'tcx>>>>;

fn is_storage_stmt<'tcx>(stmt: &Statement<'tcx>) -> bool {
matches!(stmt.kind, StatementKind::StorageLive(_) | StatementKind::StorageDead(_))
}

fn try_eat_storage_stmts<'a, 'tcx>(
stmt_iter: &mut StmtIter<'a, 'tcx>,
storage_live_stmts: &mut Vec<(usize, Local)>,
storage_dead_stmts: &mut Vec<(usize, Local)>,
) {
while stmt_iter.peek().map(|(_, stmt)| is_storage_stmt(stmt)).unwrap_or(false) {
let (idx, stmt) = stmt_iter.next().unwrap();

if let StatementKind::StorageLive(l) = stmt.kind {
storage_live_stmts.push((idx, l));
} else if let StatementKind::StorageDead(l) = stmt.kind {
storage_dead_stmts.push((idx, l));
}
}
}

fn is_tmp_storage_stmt<'tcx>(stmt: &Statement<'tcx>) -> bool {
if let StatementKind::Assign(box (place, Rvalue::Use(op))) = &stmt.kind {
if let Operand::Copy(p) | Operand::Move(p) = op {
return place.as_local().is_some() && p.as_local().is_some();
}
}

false
}

fn try_eat_assign_tmp_stmts<'a, 'tcx>(
stmt_iter: &mut StmtIter<'a, 'tcx>,
tmp_assigns: &mut Vec<(Local, Local)>,
nop_stmts: &mut Vec<usize>,
) {
while stmt_iter.peek().map(|(_, stmt)| is_tmp_storage_stmt(stmt)).unwrap_or(false) {
let (idx, stmt) = stmt_iter.next().unwrap();

if let StatementKind::Assign(box (place, Rvalue::Use(op))) = &stmt.kind {
if let Operand::Copy(p) | Operand::Move(p) = op {
tmp_assigns.push((place.as_local().unwrap(), p.as_local().unwrap()));
nop_stmts.push(idx);
}
}
}
}

let mut stmt_iter = stmts.iter().enumerate().peekable();

try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);

let (starting_stmt, stmt) = stmt_iter.next()?;
let (local_tmp_s0, local_1, vf_s0) = match_get_variant_field(stmt)?;

try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);

try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts);

try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);

try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts);

let (idx, stmt) = stmt_iter.next()?;
let (local_tmp_s1, local_0, vf_s1) = match_set_variant_field(stmt)?;
nop_stmts.push(idx);

let (idx, stmt) = stmt_iter.next()?;
let (set_discr_local, set_discr_var_idx) = match_set_discr(stmt)?;
let discr_stmt_source_info = stmt.source_info;
nop_stmts.push(idx);

try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);

for (live_idx, live_local) in storage_live_stmts {
if let Some(i) = storage_dead_stmts.iter().rposition(|(_, l)| *l == live_local) {
let (dead_idx, _) = storage_dead_stmts.swap_remove(i);
storage_stmts.push((live_idx, dead_idx, live_local));
}
}

Some(ArmIdentityInfo {
local_temp_0: local_tmp_s0,
local_1,
vf_s0,
field_tmp_assignments: tmp_assigns,
local_tmp_s1,
local_0,
vf_s1,
set_discr_local,
set_discr_var_idx,
stmt_to_overwrite: starting_stmt,
source_info: discr_stmt_source_info,
storage_stmts,
stmts_to_remove: nop_stmts,
})
}

fn optimization_applies<'tcx>(
opt_info: &ArmIdentityInfo<'tcx>,
local_decls: &IndexVec<Local, LocalDecl<'tcx>>,
) -> bool {
trace!("testing if optimization applies...");

if opt_info.local_0 == opt_info.local_1 {
trace!("NO: moving into ourselves");
return false;
} else if opt_info.vf_s0 != opt_info.vf_s1 {
trace!("NO: the field-and-variant information do not match");
return false;
} else if local_decls[opt_info.local_0].ty != local_decls[opt_info.local_1].ty {
// FIXME(Centril,oli-obk): possibly relax ot same layout?
trace!("NO: source and target locals have different types");
return false;
} else if (opt_info.local_0, opt_info.vf_s0.var_idx)
!= (opt_info.set_discr_local, opt_info.set_discr_var_idx)
{
trace!("NO: the discriminants do not match");
return false;
}

// Verify the assigment chain consists of the form b = a; c = b; d = c; etc...
if opt_info.field_tmp_assignments.len() == 0 {
trace!("NO: no assignments found");
}
let mut last_assigned_to = opt_info.field_tmp_assignments[0].1;
let source_local = last_assigned_to;
for (l, r) in &opt_info.field_tmp_assignments {
if *r != last_assigned_to {
trace!("NO: found unexpected assignment {:?} = {:?}", l, r);
return false;
}

last_assigned_to = *l;
}

if source_local != opt_info.local_temp_0 {
trace!(
"NO: start of assignment chain does not match enum variant temp: {:?} != {:?}",
source_local,
opt_info.local_temp_0
);
return false;
} else if last_assigned_to != opt_info.local_tmp_s1 {
trace!(
"NO: end of assignemnt chain does not match written enum temp: {:?} != {:?}",
last_assigned_to,
opt_info.local_tmp_s1
);
return false;
}

trace!("SUCCESS: optimization applies!");
return true;
}

impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity {
fn run_pass(&self, _: TyCtxt<'tcx>, _: MirSource<'tcx>, body: &mut BodyAndCache<'tcx>) {
fn run_pass(&self, _: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut BodyAndCache<'tcx>) {
trace!("running SimplifyArmIdentity on {:?}", source);
let (basic_blocks, local_decls) = body.basic_blocks_and_local_decls_mut();
for bb in basic_blocks {
// Need 3 statements:
let (s0, s1, s2) = match &mut *bb.statements {
[s0, s1, s2] => (s0, s1, s2),
_ => continue,
};
trace!("bb is {:?}", bb.statements);

// Pattern match on the form we want:
let (local_tmp_s0, local_1, vf_s0) = match match_get_variant_field(s0) {
None => continue,
Some(x) => x,
};
let (local_tmp_s1, local_0, vf_s1) = match match_set_variant_field(s1) {
None => continue,
Some(x) => x,
};
if local_tmp_s0 != local_tmp_s1
// Avoid moving into ourselves.
|| local_0 == local_1
// The field-and-variant information match up.
|| vf_s0 != vf_s1
// Source and target locals have the same type.
// FIXME(Centril | oli-obk): possibly relax to same layout?
|| local_decls[local_0].ty != local_decls[local_1].ty
// We're setting the discriminant of `local_0` to this variant.
|| Some((local_0, vf_s0.var_idx)) != match_set_discr(s2)
{
continue;
}
if let Some(mut opt_info) = get_arm_identity_info(&bb.statements) {
trace!("got opt_info = {:#?}", opt_info);
if !optimization_applies(&opt_info, local_decls) {
debug!("optimization skipped for {:?}", source);
continue;
}

// Also remove unused Storage{Live,Dead} statements which correspond
// to temps used previously.
for (left, right) in opt_info.field_tmp_assignments {
for (live_idx, dead_idx, local) in &opt_info.storage_stmts {
if *local == left || *local == right {
opt_info.stmts_to_remove.push(*live_idx);
opt_info.stmts_to_remove.push(*dead_idx);
}
}
}

// Right shape; transform!
s0.source_info = s2.source_info;
match &mut s0.kind {
StatementKind::Assign(box (place, rvalue)) => {
*place = local_0.into();
*rvalue = Rvalue::Use(Operand::Move(local_1.into()));
// Right shape; transform!
let stmt = &mut bb.statements[opt_info.stmt_to_overwrite];
stmt.source_info = opt_info.source_info;
match &mut stmt.kind {
StatementKind::Assign(box (place, rvalue)) => {
*place = opt_info.local_0.into();
*rvalue = Rvalue::Use(Operand::Move(opt_info.local_1.into()));
}
_ => unreachable!(),
}
_ => unreachable!(),

for stmt_idx in opt_info.stmts_to_remove {
bb.statements[stmt_idx].make_nop();
}

bb.statements.retain(|stmt| stmt.kind != StatementKind::Nop);

trace!("block is now {:?}", bb.statements);
}
s1.make_nop();
s2.make_nop();
}
}
}
Expand Down Expand Up @@ -129,7 +334,7 @@ fn match_set_discr<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, VariantIdx)>
}
}

#[derive(PartialEq)]
#[derive(PartialEq, Debug)]
struct VarField<'tcx> {
field: Field,
field_ty: Ty<'tcx>,
Expand Down
14 changes: 12 additions & 2 deletions src/test/mir-opt/simplify-arm-identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,14 @@ fn main() {
// }
// ...
// bb3: {
// StorageLive(_4);
// _4 = ((_1 as Foo).0: u8);
// ((_2 as Foo).0: u8) = move _4;
// StorageLive(_5);
// _5 = _4;
// ((_2 as Foo).0: u8) = move _5;
// discriminant(_2) = 0;
// StorageDead(_5);
// StorageDead(_4);
// goto -> bb4;
// }
// ...
Expand All @@ -65,9 +70,14 @@ fn main() {
// }
// ...
// bb3: {
// StorageLive(_4);
// _4 = ((_1 as Foo).0: u8);
// ((_2 as Foo).0: u8) = move _4;
// StorageLive(_5);
// _5 = _4;
// ((_2 as Foo).0: u8) = move _5;
// discriminant(_2) = 0;
// StorageDead(_5);
// StorageDead(_4);
// goto -> bb4;
// }
// ...
Expand Down
Loading

0 comments on commit be3cc06

Please sign in to comment.