Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace the default branch with an unreachable branch If it is the last variant #120268

Merged
merged 5 commits into from Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 25 additions & 2 deletions compiler/rustc_middle/src/mir/patch.rs
Expand Up @@ -11,6 +11,8 @@ pub struct MirPatch<'tcx> {
resume_block: Option<BasicBlock>,
// Only for unreachable in cleanup path.
unreachable_cleanup_block: Option<BasicBlock>,
// Only for unreachable not in cleanup path.
unreachable_no_cleanup_block: Option<BasicBlock>,
// Cached block for UnwindTerminate (with reason)
terminate_block: Option<(BasicBlock, UnwindTerminateReason)>,
body_span: Span,
Expand All @@ -27,6 +29,7 @@ impl<'tcx> MirPatch<'tcx> {
next_local: body.local_decls.len(),
resume_block: None,
unreachable_cleanup_block: None,
unreachable_no_cleanup_block: None,
terminate_block: None,
body_span: body.span,
};
Expand All @@ -43,9 +46,12 @@ impl<'tcx> MirPatch<'tcx> {
// Check if we already have an unreachable block
if matches!(block.terminator().kind, TerminatorKind::Unreachable)
&& block.statements.is_empty()
&& block.is_cleanup
{
result.unreachable_cleanup_block = Some(bb);
if block.is_cleanup {
result.unreachable_cleanup_block = Some(bb);
} else {
result.unreachable_no_cleanup_block = Some(bb);
}
continue;
}

Expand Down Expand Up @@ -95,6 +101,23 @@ impl<'tcx> MirPatch<'tcx> {
bb
}

/// Returns a block that is not marked as cleanup and consists solely of an
/// `Unreachable` terminator.
///
/// The block is cached in `self.unreachable_no_cleanup_block`: if one is
/// already recorded it is returned as-is, otherwise a new block is created
/// through `new_block` and remembered for subsequent calls.
pub fn unreachable_no_cleanup_block(&mut self) -> BasicBlock {
    match self.unreachable_no_cleanup_block {
        Some(cached) => cached,
        None => {
            let terminator = Terminator {
                source_info: SourceInfo::outermost(self.body_span),
                kind: TerminatorKind::Unreachable,
            };
            let fresh = self.new_block(BasicBlockData {
                statements: Vec::new(),
                terminator: Some(terminator),
                is_cleanup: false,
            });
            self.unreachable_no_cleanup_block = Some(fresh);
            fresh
        }
    }
}

pub fn terminate_block(&mut self, reason: UnwindTerminateReason) -> BasicBlock {
if let Some((cached_bb, cached_reason)) = self.terminate_block
&& reason == cached_reason
Expand Down
11 changes: 11 additions & 0 deletions compiler/rustc_middle/src/mir/terminator.rs
Expand Up @@ -74,6 +74,17 @@ impl SwitchTargets {
pub fn target_for_value(&self, value: u128) -> BasicBlock {
self.iter().find_map(|(v, t)| (v == value).then_some(t)).unwrap_or_else(|| self.otherwise())
}

/// Adds a new `value -> bb` target to the switch.
///
/// # Panics
///
/// Raises a compiler bug (`bug!`) if `value` is already one of the switch values.
#[inline]
pub fn add_target(&mut self, value: u128, bb: BasicBlock) {
    let new_value = Pu128(value);
    let already_present = self.values.iter().any(|present| *present == new_value);
    if already_present {
        bug!("target value {:?} already present", new_value);
    }
    self.values.push(new_value);
    // Insert just before the final entry of `targets` so that the last slot
    // (the `otherwise` target) keeps its position at the end.
    let before_last = self.targets.len() - 1;
    self.targets.insert(before_last, bb);
}
}

pub struct SwitchTargetsIter<'a> {
Expand Down
84 changes: 57 additions & 27 deletions compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs
Expand Up @@ -2,8 +2,10 @@

use crate::MirPass;
use rustc_data_structures::fx::FxHashSet;
use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::{
BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, Terminator, TerminatorKind,
BasicBlock, BasicBlockData, BasicBlocks, Body, Local, Operand, Rvalue, StatementKind,
TerminatorKind,
};
use rustc_middle::ty::layout::TyAndLayout;
use rustc_middle::ty::{Ty, TyCtxt};
Expand Down Expand Up @@ -77,7 +79,8 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
trace!("UninhabitedEnumBranching starting for {:?}", body.source);

let mut removable_switchs = Vec::new();
let mut unreachable_targets = Vec::new();
let mut patch = MirPatch::new(body);

for (bb, bb_data) in body.basic_blocks.iter_enumerated() {
trace!("processing block {:?}", bb);
Expand All @@ -92,46 +95,73 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
tcx.param_env_reveal_all_normalized(body.source.def_id()).and(discriminant_ty),
);

let allowed_variants = if let Ok(layout) = layout {
let mut allowed_variants = if let Ok(layout) = layout {
variant_discriminants(&layout, discriminant_ty, tcx)
} else if let Some(variant_range) = discriminant_ty.variant_range(tcx) {
variant_range
.map(|variant| {
discriminant_ty.discriminant_for_variant(tcx, variant).unwrap().val
})
.collect()
} else {
continue;
};

trace!("allowed_variants = {:?}", allowed_variants);

let terminator = bb_data.terminator();
let TerminatorKind::SwitchInt { targets, .. } = &terminator.kind else { bug!() };
unreachable_targets.clear();
let TerminatorKind::SwitchInt { targets, discr } = &bb_data.terminator().kind else {
bug!()
};

let mut reachable_count = 0;
for (index, (val, _)) in targets.iter().enumerate() {
if allowed_variants.contains(&val) {
reachable_count += 1;
} else {
removable_switchs.push((bb, index));
if !allowed_variants.remove(&val) {
unreachable_targets.push(index);
}
}
let otherwise_is_empty_unreachable =
body.basic_blocks[targets.otherwise()].is_empty_unreachable();
// After resolving https://github.com/llvm/llvm-project/issues/78578,
// we can remove the limit on the number of successors.
fn check_successors(basic_blocks: &BasicBlocks<'_>, bb: BasicBlock) -> bool {
let mut successors = basic_blocks[bb].terminator().successors();
let Some(first_successor) = successors.next() else { return true };
if successors.next().is_some() {
return true;
}
if let TerminatorKind::SwitchInt { .. } =
&basic_blocks[first_successor].terminator().kind
{
return false;
};
true
}
let otherwise_is_last_variant = !otherwise_is_empty_unreachable
&& allowed_variants.len() == 1
&& check_successors(&body.basic_blocks, targets.otherwise());
let replace_otherwise_to_unreachable = otherwise_is_last_variant
|| !otherwise_is_empty_unreachable && allowed_variants.is_empty();
Copy link
Member

@RalfJung RalfJung Mar 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could use a few more comments explaining what happens here -- why all these checks are needed and why they are combined in exactly the way they are. Imagine someone reading this code in a year without knowing about this PR -- what would they have to know to make sense of all this? For instance, what is check_successors even checking?

Also, a || b && c could use parentheses; the precedence is currently unclear.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, testing has already started. Should this be r-'d, or can I address it in a follow-up PR?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Subsequent PR is fine.


if reachable_count == allowed_variants.len() {
removable_switchs.push((bb, targets.iter().count()));
if unreachable_targets.is_empty() && !replace_otherwise_to_unreachable {
continue;
}
}

if removable_switchs.is_empty() {
return;
let unreachable_block = patch.unreachable_no_cleanup_block();
let mut targets = targets.clone();
if replace_otherwise_to_unreachable {
if otherwise_is_last_variant {
#[allow(rustc::potential_query_instability)]
let last_variant = *allowed_variants.iter().next().unwrap();
targets.add_target(last_variant, targets.otherwise());
}
unreachable_targets.push(targets.iter().count());
}
for index in unreachable_targets.iter() {
targets.all_targets_mut()[*index] = unreachable_block;
}
patch.patch_terminator(bb, TerminatorKind::SwitchInt { targets, discr: discr.clone() });
}

let new_block = BasicBlockData::new(Some(Terminator {
source_info: body.basic_blocks[removable_switchs[0].0].terminator().source_info,
kind: TerminatorKind::Unreachable,
}));
let unreachable_block = body.basic_blocks.as_mut().push(new_block);

for (bb, index) in removable_switchs {
let bb = &mut body.basic_blocks.as_mut()[bb];
let terminator = bb.terminator_mut();
let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind else { bug!() };
targets.all_targets_mut()[index] = unreachable_block;
}
patch.apply(body);
}
}
24 changes: 24 additions & 0 deletions tests/codegen/enum/uninhabited_enum_default_branch.rs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no uninhabited enum anywhere in this test... how does the test filename make sense?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test case came directly from the issue this PR fixes. It calls partial_cmp, so it's essentially the same as #119520 (comment). I think it makes sense to add the test code from the issue; maybe I should create two test cases.

Copy link
Member

@RalfJung RalfJung Mar 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But again, there are no uninhabited enums anywhere that I can see, so (a) what does the test content have to do with the filename, and (b) what does it have to do with this PR?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reading the code a bit more, I think the MIR pass (and associated test) are misnamed. This is no longer just about uninhabited variants, it is now also about exploiting that Discriminant will never return something that isn't a variant index. I am a bit surprised that this is done as a MIR transform rather than during MIR building but the MIR transform is correct according to our current understanding of MIR semantics. Just the name is misleading after this PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(a) I can change the file name to issue-119520.rs.
(b) https://rust.godbolt.org/z/za5c5hzoY When I reduced the issue's test case, I found that it uses Ordering after inlining, which is where the enum comes from. I also hope this test case will guard against the optimization being lost to other changes in the future.

Copy link
Member Author

@DianQK DianQK Mar 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I'm considering updating the name. (It’s just that I haven’t thought of a suitable name yet.)
Maybe I can change it to UnreachableEnumBranching.

I am a bit surprised that this is done as a MIR transform rather than during MIR building but the MIR transform is correct according to our current understanding of MIR semantics. Just the name is misleading after this PR.

It is better for me to have MIR building match the structure of the code itself where possible. (This purpose may not matter either?)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is better for me to have MIR building match the structure of the code itself where possible. (This purpose may not matter either?)

Ah, I may have misunderstood where this optimization kicks in. I thought even this would just use fallback for the last variant:

    match c {
        Less => -5,
        Equal => 0,
        Greater => 42,
    }

But already on stable that becomes switchInt(move _2) -> [255: bb3, 0: bb4, 1: bb1, otherwise: bb2].

I can change the file name to issue-119520.rs.

Once we have a better name for the pass, it can use that name. (Though it would also be good to mention the issue either in the file name or file contents. It's always good to add more cross-references and those are otherwise much harder to reconstruct in the future.)

Maybe I can change it to UnreachableEnumBranching.

I like it. :) The module-level comment in that file should then explain the two ways that "unreachable" is determined.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I may have misunderstood where this optimization kicks in. I thought even this would just use fallback for the last variant:

    match c {
        Less => -5,
        Equal => 0,
        Greater => 42,
    }

But already on stable that becomes switchInt(move _2) -> [255: bb3, 0: bb4, 1: bb1, otherwise: bb2].

This is something UninhabitedEnumBranching has already done.

This PR transforms the following code

    match c {
        Less => -5,
        Equal => 0,
        _ => 42,
    }

to

    match c {
        Less => -5,
        Equal => 0,
        Greater => 42,
    }

.

@@ -0,0 +1,24 @@
//@ compile-flags: -O

#![crate_type = "lib"]

// Newtype over `u32`; its derived comparison impls are what `foo` exercises.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Int(u32);

// `A` and `B` are the lower and upper bounds of the inclusive range tested in `foo`.
const A: Int = Int(201);
const B: Int = Int(270);
// `C` is the single extra value that `foo` accepts in addition to that range.
const C: Int = Int(153);

// `foo` returns whether `x` lies in the inclusive range `A..=B` or equals `C`.
// The FileCheck directives below pin the optimized IR expected for it: one
// range test (add + unsigned compare), one equality test, combined with `or`.
// NOTE(review): presumably the reproducer from issue #119520 - confirm and cross-reference.
// CHECK-LABEL: @foo(
// CHECK-SAME: [[TMP0:%.*]])
// CHECK-NEXT: start:
// CHECK-NEXT: [[TMP1:%.*]] = add i32 [[TMP0]], -201
// CHECK-NEXT: icmp ult i32 [[TMP1]], 70
// CHECK-NEXT: icmp eq i32 [[TMP0]], 153
// CHECK-NEXT: [[SPEC_SELECT:%.*]] = or i1
// CHECK-NEXT: ret i1 [[SPEC_SELECT]]
#[no_mangle]
pub fn foo(x: Int) -> bool {
(x >= A && x <= B)
|| x == C
}
Expand Up @@ -69,7 +69,7 @@
StorageLive(_6);
_6 = ((*_1).4: std::option::Option<usize>);
_7 = discriminant(_6);
switchInt(move _7) -> [1: bb4, otherwise: bb6];
switchInt(move _7) -> [1: bb4, 0: bb6, otherwise: bb9];
}

bb4: {
Expand Down Expand Up @@ -135,5 +135,9 @@
StorageDead(_6);
return;
}

bb9: {
unreachable;
}
}

Expand Up @@ -69,7 +69,7 @@
StorageLive(_6);
_6 = ((*_1).4: std::option::Option<usize>);
_7 = discriminant(_6);
switchInt(move _7) -> [1: bb4, otherwise: bb6];
switchInt(move _7) -> [1: bb4, 0: bb6, otherwise: bb9];
}

bb4: {
Expand Down Expand Up @@ -135,5 +135,9 @@
StorageDead(_6);
return;
}

bb9: {
unreachable;
}
}

Expand Up @@ -33,21 +33,21 @@ fn num_to_digit(_1: char) -> u32 {
_3 = &_2;
StorageLive(_4);
_4 = discriminant(_2);
StorageDead(_3);
StorageDead(_2);
switchInt(move _4) -> [1: bb2, otherwise: bb7];
switchInt(move _4) -> [1: bb2, 0: bb6, otherwise: bb8];
}

bb2: {
StorageDead(_4);
StorageDead(_3);
StorageDead(_2);
StorageLive(_5);
_5 = char::methods::<impl char>::to_digit(move _1, const 8_u32) -> [return: bb3, unwind unreachable];
}

bb3: {
StorageLive(_6);
_6 = discriminant(_5);
switchInt(move _6) -> [0: bb4, 1: bb5, otherwise: bb6];
switchInt(move _6) -> [0: bb4, 1: bb5, otherwise: bb8];
}

bb4: {
Expand All @@ -58,20 +58,22 @@ fn num_to_digit(_1: char) -> u32 {
_0 = move ((_5 as Some).0: u32);
StorageDead(_6);
StorageDead(_5);
goto -> bb8;
goto -> bb7;
}

bb6: {
unreachable;
StorageDead(_4);
StorageDead(_3);
StorageDead(_2);
_0 = const 0_u32;
goto -> bb7;
}

bb7: {
StorageDead(_4);
_0 = const 0_u32;
goto -> bb8;
return;
}

bb8: {
return;
unreachable;
}
}
Expand Up @@ -33,21 +33,21 @@ fn num_to_digit(_1: char) -> u32 {
_3 = &_2;
StorageLive(_4);
_4 = discriminant(_2);
StorageDead(_3);
StorageDead(_2);
switchInt(move _4) -> [1: bb2, otherwise: bb7];
switchInt(move _4) -> [1: bb2, 0: bb6, otherwise: bb8];
}

bb2: {
StorageDead(_4);
StorageDead(_3);
StorageDead(_2);
StorageLive(_5);
_5 = char::methods::<impl char>::to_digit(move _1, const 8_u32) -> [return: bb3, unwind continue];
}

bb3: {
StorageLive(_6);
_6 = discriminant(_5);
switchInt(move _6) -> [0: bb4, 1: bb5, otherwise: bb6];
switchInt(move _6) -> [0: bb4, 1: bb5, otherwise: bb8];
}

bb4: {
Expand All @@ -58,20 +58,22 @@ fn num_to_digit(_1: char) -> u32 {
_0 = move ((_5 as Some).0: u32);
StorageDead(_6);
StorageDead(_5);
goto -> bb8;
goto -> bb7;
}

bb6: {
unreachable;
StorageDead(_4);
StorageDead(_3);
StorageDead(_2);
_0 = const 0_u32;
goto -> bb7;
}

bb7: {
StorageDead(_4);
_0 = const 0_u32;
goto -> bb8;
return;
}

bb8: {
return;
unreachable;
}
}