Skip to content

Commit

Permalink
Transforms a match containing negative numbers into an assignment sta…
Browse files Browse the repository at this point in the history
…tement as well
  • Loading branch information
DianQK committed Feb 4, 2024
1 parent eccc782 commit 7a47635
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 59 deletions.
55 changes: 44 additions & 11 deletions compiler/rustc_mir_transform/src/match_branches.rs
Expand Up @@ -65,13 +65,13 @@ trait SimplifyMatch<'tcx> {
_ => unreachable!(),
};

if !self.can_simplify(tcx, targets, param_env, bbs) {
let discr_ty = discr.ty(local_decls, tcx);
if !self.can_simplify(tcx, targets, param_env, bbs, discr_ty) {
return false;
}

// Take ownership of items now that we know we can optimize.
let discr = discr.clone();
let discr_ty = discr.ty(local_decls, tcx);

// Introduce a temporary for the discriminant value.
let source_info = bbs[switch_bb_idx].terminator().source_info;
Expand Down Expand Up @@ -101,6 +101,7 @@ trait SimplifyMatch<'tcx> {
targets: &SwitchTargets,
param_env: ParamEnv<'tcx>,
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
discr_ty: Ty<'tcx>,
) -> bool;

fn new_stmts(
Expand Down Expand Up @@ -154,6 +155,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
targets: &SwitchTargets,
param_env: ParamEnv<'tcx>,
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
_discr_ty: Ty<'tcx>,
) -> bool {
if targets.iter().len() != 1 {
return false;
Expand Down Expand Up @@ -265,7 +267,7 @@ struct SimplifyToExp {
enum CompareType<'tcx, 'a> {
Same(&'a StatementKind<'tcx>),
Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt),
Discr(&'a Place<'tcx>, Ty<'tcx>),
Discr(&'a Place<'tcx>, Ty<'tcx>, bool),
}

enum TransfromType {
Expand All @@ -279,7 +281,7 @@ impl From<CompareType<'_, '_>> for TransfromType {
match compare_type {
CompareType::Same(_) => TransfromType::Same,
CompareType::Eq(_, _, _) => TransfromType::Eq,
CompareType::Discr(_, _) => TransfromType::Discr,
CompareType::Discr(_, _, _) => TransfromType::Discr,
}
}
}
Expand Down Expand Up @@ -330,6 +332,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
targets: &SwitchTargets,
param_env: ParamEnv<'tcx>,
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
discr_ty: Ty<'tcx>,
) -> bool {
if targets.iter().len() < 2 || targets.iter().len() > 64 {
return false;
Expand All @@ -352,6 +355,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
return false;
}

let discr_size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size;
let first_stmts = &bbs[first_target].statements;
let (second_val, second_target) = iter.next().unwrap();
let second_stmts = &bbs[second_target].statements;
Expand Down Expand Up @@ -379,12 +383,30 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
) {
(Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f),
(Some(f), Some(s))
if Some(f) == ScalarInt::try_from_uint(first_val, f.size())
&& Some(s) == ScalarInt::try_from_uint(second_val, s.size()) =>
if ((f_c.const_.ty().is_signed() || discr_ty.is_signed())
&& f.try_to_int(f.size()).unwrap()
== ScalarInt::try_from_uint(first_val, discr_size)
.unwrap()
.try_to_int(discr_size)
.unwrap()
&& s.try_to_int(s.size()).unwrap()
== ScalarInt::try_from_uint(second_val, discr_size)
.unwrap()
.try_to_int(discr_size)
.unwrap())
|| (Some(f) == ScalarInt::try_from_uint(first_val, f.size())
&& Some(s)
== ScalarInt::try_from_uint(second_val, s.size())) =>
{
CompareType::Discr(lhs_f, f_c.const_.ty())
CompareType::Discr(
lhs_f,
f_c.const_.ty(),
f_c.const_.ty().is_signed() || discr_ty.is_signed(),
)
}
_ => {
return false;
}
_ => return false,
}
}

Expand All @@ -409,15 +431,26 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
&& s_c.const_.ty() == f_ty
&& s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {}
(
CompareType::Discr(lhs_f, f_ty),
CompareType::Discr(lhs_f, f_ty, is_signed),
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => {
let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else {
return false;
};
if Some(f) != ScalarInt::try_from_uint(other_val, f.size()) {
return false;
if is_signed
&& s_c.const_.ty().is_signed()
&& f.try_to_int(f.size()).unwrap()
== ScalarInt::try_from_uint(other_val, discr_size)
.unwrap()
.try_to_int(discr_size)
.unwrap()
{
continue;
}
if Some(f) == ScalarInt::try_from_uint(other_val, f.size()) {
continue;
}
return false;
}
_ => return false,
}
Expand Down
Expand Up @@ -5,32 +5,37 @@
debug i => _1;
let mut _0: i8;
let mut _2: i16;
+ let mut _3: i16;

bb0: {
_2 = discriminant(_1);
switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb1, otherwise: bb2];
}

bb1: {
_0 = const -3_i8;
goto -> bb5;
}

bb2: {
unreachable;
}

bb3: {
_0 = const -1_i8;
goto -> bb5;
}

bb4: {
_0 = const 2_i8;
goto -> bb5;
}

bb5: {
- switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb1, otherwise: bb2];
- }
-
- bb1: {
- _0 = const -3_i8;
- goto -> bb5;
- }
-
- bb2: {
- unreachable;
- }
-
- bb3: {
- _0 = const -1_i8;
- goto -> bb5;
- }
-
- bb4: {
- _0 = const 2_i8;
- goto -> bb5;
- }
-
- bb5: {
+ StorageLive(_3);
+ _3 = move _2;
+ _0 = _3 as i8 (IntToInt);
+ StorageDead(_3);
return;
}
}
Expand Down
Expand Up @@ -5,32 +5,37 @@
debug i => _1;
let mut _0: i16;
let mut _2: i8;
+ let mut _3: i8;

bb0: {
_2 = discriminant(_1);
switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb1, otherwise: bb2];
}

bb1: {
_0 = const -3_i16;
goto -> bb5;
}

bb2: {
unreachable;
}

bb3: {
_0 = const -1_i16;
goto -> bb5;
}

bb4: {
_0 = const 2_i16;
goto -> bb5;
}

bb5: {
- switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb1, otherwise: bb2];
- }
-
- bb1: {
- _0 = const -3_i16;
- goto -> bb5;
- }
-
- bb2: {
- unreachable;
- }
-
- bb3: {
- _0 = const -1_i16;
- goto -> bb5;
- }
-
- bb4: {
- _0 = const 2_i16;
- goto -> bb5;
- }
-
- bb5: {
+ StorageLive(_3);
+ _3 = move _2;
+ _0 = _3 as i16 (IntToInt);
+ StorageDead(_3);
return;
}
}
Expand Down
8 changes: 6 additions & 2 deletions tests/mir-opt/matches_reduce_branches.rs
Expand Up @@ -144,7 +144,9 @@ enum EnumAi8 {
// EMIT_MIR matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff
fn match_i8_i16(i: EnumAi8) -> i16 {
// CHECK-LABEL: fn match_i8_i16(
// CHECK: switchInt
// CHECK-NOT: switchInt
// CHECK: _0 = _3 as i16 (IntToInt);
// CHECH: return
match i {
EnumAi8::A => -1,
EnumAi8::B => 2,
Expand Down Expand Up @@ -173,7 +175,9 @@ enum EnumAi16 {
// EMIT_MIR matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff
fn match_i16_i8(i: EnumAi16) -> i8 {
// CHECK-LABEL: fn match_i16_i8(
// CHECK: switchInt
// CHECK-NOT: switchInt
// CHECK: _0 = _3 as i8 (IntToInt);
// CHECH: return
match i {
EnumAi16::A => -1,
EnumAi16::B => 2,
Expand Down

0 comments on commit 7a47635

Please sign in to comment.