Skip to content

Commit

Permalink
Auto merge of #118461 - celinval:smir-switch-targets, r=ouz-a
Browse files Browse the repository at this point in the history
Change `SwitchTarget` representation in StableMIR

The new structure encodes its invariant, which reduces the likelihood of having an inconsistent representation. It is also more intuitive and user friendly.

I encapsulated the structure for now in case we decide to change it back.

### Notes:

1. I had to change the `Successors` type, since there's a conflict on the iterator type. We could potentially implement an iterator here, but I would prefer keeping it simple for now, and add a `successors_iter()` method if needed.
2. I removed `CoroutineDrop` for now since it we never create it. We can add it when we add support to other MIR stages.
  • Loading branch information
bors committed Dec 1, 2023
2 parents 1d726a2 + 9d2c923 commit c263ccf
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 44 deletions.
11 changes: 5 additions & 6 deletions compiler/rustc_smir/src/rustc_smir/convert/mir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,13 +579,12 @@ impl<'tcx> Stable<'tcx> for mir::TerminatorKind<'tcx> {
mir::TerminatorKind::SwitchInt { discr, targets } => TerminatorKind::SwitchInt {
discr: discr.stable(tables),
targets: {
let (value_vec, mut target_vec): (Vec<_>, Vec<_>) =
targets.iter().map(|(value, target)| (value, target.as_usize())).unzip();
// We need to push otherwise as last element to ensure it's same as in MIR.
target_vec.push(targets.otherwise().as_usize());
stable_mir::mir::SwitchTargets { value: value_vec, targets: target_vec }
let branches = targets.iter().map(|(val, target)| (val, target.as_usize()));
stable_mir::mir::SwitchTargets::new(
branches.collect(),
targets.otherwise().as_usize(),
)
},
otherwise: targets.otherwise().as_usize(),
},
mir::TerminatorKind::UnwindResume => TerminatorKind::Resume,
mir::TerminatorKind::UnwindTerminate(_) => TerminatorKind::Abort,
Expand Down
82 changes: 55 additions & 27 deletions compiler/stable_mir/src/mir/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::ty::{
AdtDef, ClosureDef, Const, CoroutineDef, GenericArgs, Movability, Region, RigidTy, Ty, TyKind,
};
use crate::{Error, Opaque, Span, Symbol};
use std::{io, slice};
use std::io;
/// The SMIR representation of a single function.
#[derive(Clone, Debug)]
pub struct Body {
Expand All @@ -23,6 +23,8 @@ pub struct Body {
pub(super) var_debug_info: Vec<VarDebugInfo>,
}

pub type BasicBlockIdx = usize;

impl Body {
/// Constructs a `Body`.
///
Expand Down Expand Up @@ -114,66 +116,64 @@ pub struct Terminator {
}

impl Terminator {
pub fn successors(&self) -> Successors<'_> {
pub fn successors(&self) -> Successors {
self.kind.successors()
}
}

pub type Successors<'a> = impl Iterator<Item = usize> + 'a;
pub type Successors = Vec<BasicBlockIdx>;

#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TerminatorKind {
Goto {
target: usize,
target: BasicBlockIdx,
},
SwitchInt {
discr: Operand,
targets: SwitchTargets,
otherwise: usize,
},
Resume,
Abort,
Return,
Unreachable,
Drop {
place: Place,
target: usize,
target: BasicBlockIdx,
unwind: UnwindAction,
},
Call {
func: Operand,
args: Vec<Operand>,
destination: Place,
target: Option<usize>,
target: Option<BasicBlockIdx>,
unwind: UnwindAction,
},
Assert {
cond: Operand,
expected: bool,
msg: AssertMessage,
target: usize,
target: BasicBlockIdx,
unwind: UnwindAction,
},
CoroutineDrop,
InlineAsm {
template: String,
operands: Vec<InlineAsmOperand>,
options: String,
line_spans: String,
destination: Option<usize>,
destination: Option<BasicBlockIdx>,
unwind: UnwindAction,
},
}

impl TerminatorKind {
pub fn successors(&self) -> Successors<'_> {
pub fn successors(&self) -> Successors {
use self::TerminatorKind::*;
match *self {
Call { target: Some(t), unwind: UnwindAction::Cleanup(ref u), .. }
| Drop { target: t, unwind: UnwindAction::Cleanup(ref u), .. }
| Assert { target: t, unwind: UnwindAction::Cleanup(ref u), .. }
| InlineAsm { destination: Some(t), unwind: UnwindAction::Cleanup(ref u), .. } => {
Some(t).into_iter().chain(slice::from_ref(u).into_iter().copied())
Call { target: Some(t), unwind: UnwindAction::Cleanup(u), .. }
| Drop { target: t, unwind: UnwindAction::Cleanup(u), .. }
| Assert { target: t, unwind: UnwindAction::Cleanup(u), .. }
| InlineAsm { destination: Some(t), unwind: UnwindAction::Cleanup(u), .. } => {
vec![t, u]
}
Goto { target: t }
| Call { target: None, unwind: UnwindAction::Cleanup(t), .. }
Expand All @@ -182,21 +182,18 @@ impl TerminatorKind {
| Assert { target: t, unwind: _, .. }
| InlineAsm { destination: None, unwind: UnwindAction::Cleanup(t), .. }
| InlineAsm { destination: Some(t), unwind: _, .. } => {
Some(t).into_iter().chain((&[]).into_iter().copied())
vec![t]
}

CoroutineDrop
| Return
Return
| Resume
| Abort
| Unreachable
| Call { target: None, unwind: _, .. }
| InlineAsm { destination: None, unwind: _, .. } => {
None.into_iter().chain((&[]).into_iter().copied())
}
SwitchInt { ref targets, .. } => {
None.into_iter().chain(targets.targets.iter().copied())
vec![]
}
SwitchInt { ref targets, .. } => targets.all_targets(),
}
}

Expand All @@ -205,7 +202,6 @@ impl TerminatorKind {
TerminatorKind::Goto { .. }
| TerminatorKind::Return
| TerminatorKind::Unreachable
| TerminatorKind::CoroutineDrop
| TerminatorKind::Resume
| TerminatorKind::Abort
| TerminatorKind::SwitchInt { .. } => None,
Expand All @@ -231,7 +227,7 @@ pub enum UnwindAction {
Continue,
Unreachable,
Terminate,
Cleanup(usize),
Cleanup(BasicBlockIdx),
}

#[derive(Clone, Debug, Eq, PartialEq)]
Expand Down Expand Up @@ -662,10 +658,42 @@ pub struct Constant {
pub literal: Const,
}

/// The possible branch sites of a [TerminatorKind::SwitchInt].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SwitchTargets {
pub value: Vec<u128>,
pub targets: Vec<usize>,
/// The conditional branches where the first element represents the value that guards this
/// branch, and the second element is the branch target.
branches: Vec<(u128, BasicBlockIdx)>,
/// The `otherwise` branch which will be taken in case none of the conditional branches are
/// satisfied.
otherwise: BasicBlockIdx,
}

impl SwitchTargets {
/// All possible targets including the `otherwise` target.
pub fn all_targets(&self) -> Successors {
self.branches.iter().map(|(_, target)| *target).chain(Some(self.otherwise)).collect()
}

/// The `otherwise` branch target.
pub fn otherwise(&self) -> BasicBlockIdx {
self.otherwise
}

/// The conditional targets which are only taken if the pattern matches the given value.
pub fn branches(&self) -> impl Iterator<Item = (u128, BasicBlockIdx)> + '_ {
self.branches.iter().copied()
}

/// The number of targets including `otherwise`.
pub fn len(&self) -> usize {
self.branches.len() + 1
}

/// Create a new SwitchTargets from the given branches and `otherwise` target.
pub fn new(branches: Vec<(u128, BasicBlockIdx)>, otherwise: BasicBlockIdx) -> SwitchTargets {
SwitchTargets { branches, otherwise }
}
}

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
Expand Down
15 changes: 7 additions & 8 deletions compiler/stable_mir/src/mir/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ pub fn pretty_statement(statement: &StatementKind) -> String {

pub fn pretty_terminator<W: io::Write>(terminator: &TerminatorKind, w: &mut W) -> io::Result<()> {
write!(w, "{}", pretty_terminator_head(terminator))?;
let successor_count = terminator.successors().count();
let successors = terminator.successors();
let successor_count = successors.len();
let labels = pretty_successor_labels(terminator);

let show_unwind = !matches!(terminator.unwind(), None | Some(UnwindAction::Cleanup(_)));
Expand All @@ -98,12 +99,12 @@ pub fn pretty_terminator<W: io::Write>(terminator: &TerminatorKind, w: &mut W) -
Ok(())
}
(1, false) => {
write!(w, " -> {:?}", terminator.successors().next().unwrap())?;
write!(w, " -> {:?}", successors[0])?;
Ok(())
}
_ => {
write!(w, " -> [")?;
for (i, target) in terminator.successors().enumerate() {
for (i, target) in successors.iter().enumerate() {
if i > 0 {
write!(w, ", ")?;
}
Expand Down Expand Up @@ -157,20 +158,18 @@ pub fn pretty_terminator_head(terminator: &TerminatorKind) -> String {
pretty.push_str(")");
pretty
}
CoroutineDrop => format!(" coroutine_drop"),
InlineAsm { .. } => todo!(),
}
}

pub fn pretty_successor_labels(terminator: &TerminatorKind) -> Vec<String> {
use self::TerminatorKind::*;
match terminator {
Resume | Abort | Return | Unreachable | CoroutineDrop => vec![],
Resume | Abort | Return | Unreachable => vec![],
Goto { .. } => vec!["".to_string()],
SwitchInt { targets, .. } => targets
.value
.iter()
.map(|target| format!("{}", target))
.branches()
.map(|(val, _target)| format!("{val}"))
.chain(iter::once("otherwise".into()))
.collect(),
Drop { unwind: UnwindAction::Cleanup(_), .. } => vec!["return".into(), "unwind".into()],
Expand Down
5 changes: 2 additions & 3 deletions compiler/stable_mir/src/mir/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,7 @@ pub trait MirVisitor {
TerminatorKind::Goto { .. }
| TerminatorKind::Resume
| TerminatorKind::Abort
| TerminatorKind::Unreachable
| TerminatorKind::CoroutineDrop => {}
| TerminatorKind::Unreachable => {}
TerminatorKind::Assert { cond, expected: _, msg, target: _, unwind: _ } => {
self.visit_operand(cond, location);
self.visit_assert_msg(msg, location);
Expand Down Expand Up @@ -268,7 +267,7 @@ pub trait MirVisitor {
let local = RETURN_LOCAL;
self.visit_local(&local, PlaceContext::NON_MUTATING, location);
}
TerminatorKind::SwitchInt { discr, targets: _, otherwise: _ } => {
TerminatorKind::SwitchInt { discr, targets: _ } => {
self.visit_operand(discr, location);
}
}
Expand Down

0 comments on commit c263ccf

Please sign in to comment.