Skip to content

Commit

Permalink
New mir-opt pass to simplify gotos with const values
Browse files Browse the repository at this point in the history
Fixes #77355
  • Loading branch information
simonvandel committed Feb 22, 2021
1 parent 15598a8 commit a6dccfe
Show file tree
Hide file tree
Showing 22 changed files with 652 additions and 374 deletions.
9 changes: 8 additions & 1 deletion compiler/rustc_middle/src/mir/mod.rs
Expand Up @@ -1518,7 +1518,14 @@ pub enum StatementKind<'tcx> {
}

impl<'tcx> StatementKind<'tcx> {
pub fn as_assign_mut(&mut self) -> Option<&mut Box<(Place<'tcx>, Rvalue<'tcx>)>> {
pub fn as_assign_mut(&mut self) -> Option<&mut (Place<'tcx>, Rvalue<'tcx>)> {
match self {
StatementKind::Assign(x) => Some(x),
_ => None,
}
}

pub fn as_assign(&self) -> Option<&(Place<'tcx>, Rvalue<'tcx>)> {
match self {
StatementKind::Assign(x) => Some(x),
_ => None,
Expand Down
16 changes: 16 additions & 0 deletions compiler/rustc_middle/src/mir/terminator.rs
Expand Up @@ -407,6 +407,22 @@ impl<'tcx> TerminatorKind<'tcx> {
| TerminatorKind::FalseUnwind { ref mut unwind, .. } => Some(unwind),
}
}

pub fn as_switch(&self) -> Option<(&Operand<'tcx>, Ty<'tcx>, &SwitchTargets)> {
match self {
TerminatorKind::SwitchInt { discr, switch_ty, targets } => {
Some((discr, switch_ty, targets))
}
_ => None,
}
}

pub fn as_goto(&self) -> Option<BasicBlock> {
match self {
TerminatorKind::Goto { target } => Some(*target),
_ => None,
}
}
}

impl<'tcx> Debug for TerminatorKind<'tcx> {
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_mir/src/interpret/operand.rs
Expand Up @@ -514,6 +514,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
/// Evaluate the operand, returning a place where you can then find the data.
/// If you already know the layout, you can save two table lookups
/// by passing it in here.
#[inline]
pub fn eval_operand(
&self,
mir_op: &mir::Operand<'tcx>,
Expand Down
122 changes: 122 additions & 0 deletions compiler/rustc_mir/src/transform/const_goto.rs
@@ -0,0 +1,122 @@
//! This pass optimizes the following sequence
//! ```rust,ignore (example)
//! bb2: {
//! _2 = const true;
//! goto -> bb3;
//! }
//!
//! bb3: {
//! switchInt(_2) -> [false: bb4, otherwise: bb5];
//! }
//! ```
//! into
//! ```rust,ignore (example)
//! bb2: {
//! _2 = const true;
//! goto -> bb5;
//! }
//! ```

use crate::transform::MirPass;
use rustc_middle::mir::*;
use rustc_middle::ty::TyCtxt;
use rustc_middle::{mir::visit::Visitor, ty::ParamEnv};

use super::simplify::{simplify_cfg, simplify_locals};

pub struct ConstGoto;

impl<'tcx> MirPass<'tcx> for ConstGoto {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
if tcx.sess.opts.debugging_opts.mir_opt_level < 3 {
return;
}
trace!("Running ConstGoto on {:?}", body.source);
let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
let mut opt_finder =
ConstGotoOptimizationFinder { tcx, body, optimizations: vec![], param_env };
opt_finder.visit_body(body);
let should_simplify = !opt_finder.optimizations.is_empty();
for opt in opt_finder.optimizations {
let terminator = body.basic_blocks_mut()[opt.bb_with_goto].terminator_mut();
let new_goto = TerminatorKind::Goto { target: opt.target_to_use_in_goto };
debug!("SUCCESS: replacing `{:?}` with `{:?}`", terminator.kind, new_goto);
terminator.kind = new_goto;
}

// if we applied optimizations, we potentially have some cfg to cleanup to
// make it easier for further passes
if should_simplify {
simplify_cfg(body);
simplify_locals(body, tcx);
}
}
}

impl<'a, 'tcx> Visitor<'tcx> for ConstGotoOptimizationFinder<'a, 'tcx> {
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
let _: Option<_> = try {
let target = terminator.kind.as_goto()?;
// We only apply this optimization if the last statement is a const assignment
let last_statement = self.body.basic_blocks()[location.block].statements.last()?;

if let (place, Rvalue::Use(Operand::Constant(_const))) =
last_statement.kind.as_assign()?
{
// We found a constant being assigned to `place`.
// Now check that the target of this Goto switches on this place.
let target_bb = &self.body.basic_blocks()[target];

// FIXME(simonvandel): We are conservative here when we don't allow
// any statements in the target basic block.
// This could probably be relaxed to allow `StorageDead`s which could be
// copied to the predecessor of this block.
if !target_bb.statements.is_empty() {
None?
}

let target_bb_terminator = target_bb.terminator();
let (discr, switch_ty, targets) = target_bb_terminator.kind.as_switch()?;
if discr.place() == Some(*place) {
// We now know that the Switch matches on the const place, and it is statementless
// Now find which value in the Switch matches the const value.
let const_value =
_const.literal.try_eval_bits(self.tcx, self.param_env, switch_ty)?;
let found_value_idx_option = targets
.iter()
.enumerate()
.find(|(_, (value, _))| const_value == *value)
.map(|(idx, _)| idx);

let target_to_use_in_goto =
if let Some(found_value_idx) = found_value_idx_option {
targets.iter().nth(found_value_idx).unwrap().1
} else {
// If we did not find the const value in values, it must be the otherwise case
targets.otherwise()
};

self.optimizations.push(OptimizationToApply {
bb_with_goto: location.block,
target_to_use_in_goto,
});
}
}
Some(())
};

self.super_terminator(terminator, location);
}
}

struct OptimizationToApply {
bb_with_goto: BasicBlock,
target_to_use_in_goto: BasicBlock,
}

pub struct ConstGotoOptimizationFinder<'a, 'tcx> {
tcx: TyCtxt<'tcx>,
body: &'a Body<'tcx>,
param_env: ParamEnv<'tcx>,
optimizations: Vec<OptimizationToApply>,
}
2 changes: 2 additions & 0 deletions compiler/rustc_mir/src/transform/mod.rs
Expand Up @@ -22,6 +22,7 @@ pub mod check_packed_ref;
pub mod check_unsafety;
pub mod cleanup_post_borrowck;
pub mod const_debuginfo;
pub mod const_goto;
pub mod const_prop;
pub mod coverage;
pub mod deaggregator;
Expand Down Expand Up @@ -492,6 +493,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {

// The main optimizations that we do on MIR.
let optimizations: &[&dyn MirPass<'tcx>] = &[
&const_goto::ConstGoto,
&remove_unneeded_drops::RemoveUnneededDrops,
&match_branches::MatchBranchSimplification,
// inst combine is after MatchBranchSimplification to clean up Ne(_1, false)
Expand Down
37 changes: 20 additions & 17 deletions compiler/rustc_mir/src/transform/simplify.rs
Expand Up @@ -320,28 +320,31 @@ pub struct SimplifyLocals;
impl<'tcx> MirPass<'tcx> for SimplifyLocals {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
trace!("running SimplifyLocals on {:?}", body.source);
simplify_locals(body, tcx);
}
}

// First, we're going to get a count of *actual* uses for every `Local`.
let mut used_locals = UsedLocals::new(body);
pub fn simplify_locals<'tcx>(body: &mut Body<'tcx>, tcx: TyCtxt<'tcx>) {
// First, we're going to get a count of *actual* uses for every `Local`.
let mut used_locals = UsedLocals::new(body);

// Next, we're going to remove any `Local` with zero actual uses. When we remove those
// `Locals`, we're also going to subtract any uses of other `Locals` from the `used_locals`
// count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from
// `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a
// fixedpoint where there are no more unused locals.
remove_unused_definitions(&mut used_locals, body);
// Next, we're going to remove any `Local` with zero actual uses. When we remove those
// `Locals`, we're also going to subtract any uses of other `Locals` from the `used_locals`
// count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from
// `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a
// fixedpoint where there are no more unused locals.
remove_unused_definitions(&mut used_locals, body);

// Finally, we'll actually do the work of shrinking `body.local_decls` and remapping the `Local`s.
let map = make_local_map(&mut body.local_decls, &used_locals);
// Finally, we'll actually do the work of shrinking `body.local_decls` and remapping the `Local`s.
let map = make_local_map(&mut body.local_decls, &used_locals);

// Only bother running the `LocalUpdater` if we actually found locals to remove.
if map.iter().any(Option::is_none) {
// Update references to all vars and tmps now
let mut updater = LocalUpdater { map, tcx };
updater.visit_body(body);
// Only bother running the `LocalUpdater` if we actually found locals to remove.
if map.iter().any(Option::is_none) {
// Update references to all vars and tmps now
let mut updater = LocalUpdater { map, tcx };
updater.visit_body(body);

body.local_decls.shrink_to_fit();
}
body.local_decls.shrink_to_fit();
}
}

Expand Down
Expand Up @@ -80,7 +80,7 @@ impl<'tcx> MirPass<'tcx> for SimplifyComparisonIntegral {
// we convert the move in the comparison statement to a copy.

// unwrap is safe as we know this statement is an assign
let box (_, rhs) = bb.statements[opt.bin_op_stmt_idx].kind.as_assign_mut().unwrap();
let (_, rhs) = bb.statements[opt.bin_op_stmt_idx].kind.as_assign_mut().unwrap();

use Operand::*;
match rhs {
Expand Down
52 changes: 52 additions & 0 deletions src/test/mir-opt/const_goto.issue_77355_opt.ConstGoto.diff
@@ -0,0 +1,52 @@
- // MIR for `issue_77355_opt` before ConstGoto
+ // MIR for `issue_77355_opt` after ConstGoto

fn issue_77355_opt(_1: Foo) -> u64 {
debug num => _1; // in scope 0 at $DIR/const_goto.rs:11:20: 11:23
let mut _0: u64; // return place in scope 0 at $DIR/const_goto.rs:11:33: 11:36
- let mut _2: bool; // in scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
- let mut _3: isize; // in scope 0 at $DIR/const_goto.rs:12:22: 12:28
+ let mut _2: isize; // in scope 0 at $DIR/const_goto.rs:12:22: 12:28

bb0: {
- StorageLive(_2); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
- _3 = discriminant(_1); // scope 0 at $DIR/const_goto.rs:12:22: 12:28
- switchInt(move _3) -> [1_isize: bb2, 2_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/const_goto.rs:12:22: 12:28
+ _2 = discriminant(_1); // scope 0 at $DIR/const_goto.rs:12:22: 12:28
+ switchInt(move _2) -> [1_isize: bb2, 2_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/const_goto.rs:12:22: 12:28
}

bb1: {
- _2 = const false; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
- goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
+ _0 = const 42_u64; // scope 0 at $DIR/const_goto.rs:12:53: 12:55
+ goto -> bb3; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
}

bb2: {
- _2 = const true; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
- goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
- }
-
- bb3: {
- switchInt(move _2) -> [false: bb5, otherwise: bb4]; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
- }
-
- bb4: {
_0 = const 23_u64; // scope 0 at $DIR/const_goto.rs:12:41: 12:43
- goto -> bb6; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
+ goto -> bb3; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
}

- bb5: {
- _0 = const 42_u64; // scope 0 at $DIR/const_goto.rs:12:53: 12:55
- goto -> bb6; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
- }
-
- bb6: {
- StorageDead(_2); // scope 0 at $DIR/const_goto.rs:12:56: 12:57
+ bb3: {
return; // scope 0 at $DIR/const_goto.rs:13:2: 13:2
}
}

16 changes: 16 additions & 0 deletions src/test/mir-opt/const_goto.rs
@@ -0,0 +1,16 @@
pub enum Foo {
A,
B,
C,
D,
E,
F,
}

// EMIT_MIR const_goto.issue_77355_opt.ConstGoto.diff
fn issue_77355_opt(num: Foo) -> u64 {
if matches!(num, Foo::B | Foo::C) { 23 } else { 42 }
}
fn main() {
issue_77355_opt(Foo::A);
}
51 changes: 51 additions & 0 deletions src/test/mir-opt/const_goto_const_eval_fail.f.ConstGoto.diff
@@ -0,0 +1,51 @@
- // MIR for `f` before ConstGoto
+ // MIR for `f` after ConstGoto

fn f() -> u64 {
let mut _0: u64; // return place in scope 0 at $DIR/const_goto_const_eval_fail.rs:6:44: 6:47
let mut _1: bool; // in scope 0 at $DIR/const_goto_const_eval_fail.rs:7:11: 12:6
let mut _2: i32; // in scope 0 at $DIR/const_goto_const_eval_fail.rs:8:15: 8:16

bb0: {
StorageLive(_1); // scope 0 at $DIR/const_goto_const_eval_fail.rs:7:11: 12:6
StorageLive(_2); // scope 0 at $DIR/const_goto_const_eval_fail.rs:8:15: 8:16
_2 = const A; // scope 0 at $DIR/const_goto_const_eval_fail.rs:8:15: 8:16
switchInt(_2) -> [1_i32: bb2, 2_i32: bb2, 3_i32: bb2, otherwise: bb1]; // scope 0 at $DIR/const_goto_const_eval_fail.rs:9:13: 9:14
}

bb1: {
_1 = const true; // scope 0 at $DIR/const_goto_const_eval_fail.rs:10:18: 10:22
goto -> bb3; // scope 0 at $DIR/const_goto_const_eval_fail.rs:8:9: 11:10
}

bb2: {
_1 = const B; // scope 0 at $DIR/const_goto_const_eval_fail.rs:9:26: 9:27
- goto -> bb3; // scope 0 at $DIR/const_goto_const_eval_fail.rs:8:9: 11:10
+ switchInt(_1) -> [false: bb4, otherwise: bb3]; // scope 0 at $DIR/const_goto_const_eval_fail.rs:13:9: 13:14
}

bb3: {
- switchInt(_1) -> [false: bb5, otherwise: bb4]; // scope 0 at $DIR/const_goto_const_eval_fail.rs:13:9: 13:14
- }
-
- bb4: {
_0 = const 2_u64; // scope 0 at $DIR/const_goto_const_eval_fail.rs:14:17: 14:18
- goto -> bb6; // scope 0 at $DIR/const_goto_const_eval_fail.rs:7:5: 15:6
+ goto -> bb5; // scope 0 at $DIR/const_goto_const_eval_fail.rs:7:5: 15:6
}

- bb5: {
+ bb4: {
_0 = const 1_u64; // scope 0 at $DIR/const_goto_const_eval_fail.rs:13:18: 13:19
- goto -> bb6; // scope 0 at $DIR/const_goto_const_eval_fail.rs:7:5: 15:6
+ goto -> bb5; // scope 0 at $DIR/const_goto_const_eval_fail.rs:7:5: 15:6
}

- bb6: {
+ bb5: {
StorageDead(_2); // scope 0 at $DIR/const_goto_const_eval_fail.rs:16:1: 16:2
StorageDead(_1); // scope 0 at $DIR/const_goto_const_eval_fail.rs:16:1: 16:2
return; // scope 0 at $DIR/const_goto_const_eval_fail.rs:16:2: 16:2
}
}

16 changes: 16 additions & 0 deletions src/test/mir-opt/const_goto_const_eval_fail.rs
@@ -0,0 +1,16 @@
#![feature(min_const_generics)]
#![crate_type = "lib"]

// If const eval fails, then don't crash
// EMIT_MIR const_goto_const_eval_fail.f.ConstGoto.diff
pub fn f<const A: i32, const B: bool>() -> u64 {
match {
match A {
1 | 2 | 3 => B,
_ => true,
}
} {
false => 1,
true => 2,
}
}

0 comments on commit a6dccfe

Please sign in to comment.