Skip to content

Commit

Permalink
Split blocks after match statements (#2224)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuvalsw committed Feb 22, 2023
1 parent 34796ef commit eaa1153
Show file tree
Hide file tree
Showing 40 changed files with 2,391 additions and 1,465 deletions.
49 changes: 47 additions & 2 deletions crates/cairo-lang-lowering/src/borrow_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,26 @@ impl<'a> BorrowChecker<'a> {
demand
}

/// Gets the demand of the variables prior to the given remapping.
/// Also returns a new remapping of only the entries that are used according to the demand. The
/// caller can use it for optimizing the remapping.
fn get_remapping_demand(
&mut self,
target_block_id: &BlockId,
remapping: &VarRemapping,
callsite_info: Option<CallsiteInfo<'_>>,
) -> (VarRemapping, LoweredDemand) {
let mut demand = self.get_demand(callsite_info, RealBlock(*target_block_id, 0));
let mut new_remapping = VarRemapping::default();
for (dst, src) in remapping.iter() {
if demand.vars.swap_remove(dst) {
demand.vars.insert(*src);
new_remapping.insert(*dst, *src);
}
}
(new_remapping, demand)
}

/// Computes the variables [LoweredDemand] from a [FlatBlockEnd], while outputting borrow
/// checking diagnostics.
fn get_block_end_demand(
Expand All @@ -125,9 +145,33 @@ impl<'a> BorrowChecker<'a> {
callsite_info: Option<CallsiteInfo<'_>>,
) -> LoweredDemand {
let demand = match block_end {
FlatBlockEnd::Fallthrough(_target_block_id, _remapping)
| FlatBlockEnd::Goto(_target_block_id, _remapping) => todo!(),
FlatBlockEnd::Fallthrough(target_block_id, remapping) => {
let (new_remapping, demand) =
self.get_remapping_demand(target_block_id, remapping, callsite_info);
assert!(
self.new_ends
.insert(
block_id,
FlatBlockEnd::Fallthrough(*target_block_id, new_remapping)
)
.is_none(),
"Borrow checker cannot visit a block more than once."
);
demand
}
FlatBlockEnd::Goto(target_block_id, remapping) => {
let (new_remapping, demand) =
self.get_remapping_demand(target_block_id, remapping, callsite_info);
assert!(
self.new_ends
.insert(block_id, FlatBlockEnd::Goto(*target_block_id, new_remapping))
.is_none(),
"Borrow checker cannot visit a block more than once."
);
demand
}
FlatBlockEnd::Callsite(remapping) => {
// TODO(yuval): remove in the future, or export to function.
let callsite_info = callsite_info.unwrap();
let mut demand =
self.get_demand(callsite_info.parent.cloned(), callsite_info.return_site);
Expand All @@ -146,6 +190,7 @@ impl<'a> BorrowChecker<'a> {
}
FlatBlockEnd::Return(vars) => LoweredDemand { vars: vars.iter().copied().collect() },
FlatBlockEnd::Unreachable => LoweredDemand::default(),
FlatBlockEnd::NotSet => unreachable!(),
};
demand
}
Expand Down
4 changes: 1 addition & 3 deletions crates/cairo-lang-lowering/src/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,7 @@ impl DiagnosticEntry for LoweringDiagnostic {
"`inline` without arguments is not supported.".into()
}
LoweringDiagnosticKind::InliningFunctionWithUnreachableEndNotSupported => {
"Inlining of functions where the end of the root block is unreachable is not \
supported."
.into()
"Inlining of functions where the end is unreachable is not supported.".into()
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions crates/cairo-lang-lowering/src/fmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ impl DebugWithDb<LoweredFormatter<'_>> for StructuredLowered {

impl DebugWithDb<LoweredFormatter<'_>> for StructuredBlock {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>, ctx: &LoweredFormatter<'_>) -> std::fmt::Result {
if !self.is_set() {
return write!(f, "BLOCK_NOT_SET");
}
write!(f, "Inputs:")?;
let mut inputs = self.inputs.iter().peekable();
while let Some(var) = inputs.next() {
Expand Down Expand Up @@ -95,6 +98,9 @@ impl DebugWithDb<LoweredFormatter<'_>> for StructuredBlockEnd {
StructuredBlockEnd::Callsite(remapping) => {
return write!(f, " Callsite({:?})", remapping.debug(ctx));
}
StructuredBlockEnd::Fallthrough { target, remapping } => {
return write!(f, " Fallthrough({}, {:?})", target.0, remapping.debug(ctx));
}
StructuredBlockEnd::Return { refs, returns } => {
write!(f, " Return(")?;
chain!(refs, returns).copied().collect()
Expand All @@ -106,6 +112,7 @@ impl DebugWithDb<LoweredFormatter<'_>> for StructuredBlockEnd {
StructuredBlockEnd::Unreachable => {
return write!(f, " Unreachable");
}
StructuredBlockEnd::NotSet => unreachable!(),
};
let mut outputs = outputs.iter().peekable();
while let Some(var) = outputs.next() {
Expand Down Expand Up @@ -183,6 +190,7 @@ impl DebugWithDb<LoweredFormatter<'_>> for FlatBlockEnd {
FlatBlockEnd::Unreachable => {
return write!(f, " Unreachable");
}
FlatBlockEnd::NotSet => unreachable!(),
};
let mut outputs = outputs.iter().peekable();
while let Some(var) = outputs.next() {
Expand Down
52 changes: 37 additions & 15 deletions crates/cairo-lang-lowering/src/inline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use cairo_lang_syntax::node::ast;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use itertools::{izip, Itertools};

use crate::blocks::FlatBlocks;
use crate::blocks::{Blocks, FlatBlocks};
use crate::db::LoweringGroup;
use crate::diagnostic::{LoweringDiagnostic, LoweringDiagnosticKind, LoweringDiagnostics};
use crate::lower::context::{LoweringContext, LoweringContextBuilder, VarRequest};
Expand Down Expand Up @@ -91,13 +91,15 @@ fn gather_inlining_info(

let lowered = db.priv_function_with_body_lowered_flat(function_id)?;
let root_block_id = lowered.root_block?;
// TODO(yuval): Consider caching the result and use it in the next phase.
let last_block_id = find_last_block(root_block_id, &lowered.blocks);

for (block_id, block) in lowered.blocks.iter() {
match &block.end {
FlatBlockEnd::Return(_) => {}
FlatBlockEnd::Unreachable => {
FlatBlockEnd::Unreachable | FlatBlockEnd::Goto(..) | FlatBlockEnd::Fallthrough(..) => {
// TODO(ilya): Remove the following limitation.
if block_id == root_block_id {
if block_id == last_block_id {
if report_diagnostics {
diagnostics.report(
function_id.untyped_stable_ptr(defs_db),
Expand All @@ -107,8 +109,9 @@ fn gather_inlining_info(
return Ok(info);
}
}
FlatBlockEnd::Callsite(_) | FlatBlockEnd::Fallthrough(..) | FlatBlockEnd::Goto(..) => {
if block_id == root_block_id {
FlatBlockEnd::NotSet => unreachable!(),
FlatBlockEnd::Callsite(_) => {
if block_id == root_block_id || block_id == last_block_id {
panic!("Unexpected block end.");
}
}
Expand All @@ -127,8 +130,11 @@ fn should_inline(_db: &dyn LoweringGroup, lowered: &FlatLowered) -> Maybe<bool>
let root_block = &lowered.blocks[root_block_id];

match &root_block.end {
FlatBlockEnd::Return(_) | FlatBlockEnd::Unreachable => {}
FlatBlockEnd::Callsite(_) | FlatBlockEnd::Fallthrough(..) | FlatBlockEnd::Goto(..) => {
FlatBlockEnd::Return(_)
| FlatBlockEnd::Unreachable
| FlatBlockEnd::Goto(..)
| FlatBlockEnd::Fallthrough(..) => {}
FlatBlockEnd::Callsite(_) | FlatBlockEnd::NotSet => {
panic!("Unexpected block end.");
}
};
Expand Down Expand Up @@ -240,7 +246,7 @@ impl BlockQueue {
self.block_queue.push_back(block);
BlockId(self.flat_blocks.len() + self.block_queue.len())
}
//. Pops a block from the queue.
// Pops a block from the queue.
fn dequeue(&mut self) -> Option<FlatBlock> {
self.block_queue.pop_front()
}
Expand All @@ -263,7 +269,8 @@ pub struct Mapper<'a, 'b> {
return_block_id: BlockId,
outputs: &'a [id_arena::Id<crate::Variable>],

/// Offset between blocks_ids in lowered and block_ids in ctx.
/// An offset that is added to all the block IDs in order to translate them into the new
/// lowering representation.
block_id_offset: BlockId,
}

Expand Down Expand Up @@ -301,6 +308,7 @@ impl<'a, 'b> Rebuilder for Mapper<'a, 'b> {
| FlatBlockEnd::Unreachable
| FlatBlockEnd::Fallthrough(_, _)
| FlatBlockEnd::Goto(_, _) => {}
FlatBlockEnd::NotSet => unreachable!(),
}
}
}
Expand Down Expand Up @@ -374,9 +382,9 @@ impl<'db> FunctionInlinerRewriter<'db> {
Ok(())
}

/// Inlines the given function, with the given input and outputs variables.
/// The statements that needs to replace the call statement in the original block
/// are pushed into the statement_rewrite_stack.
/// Inlines the given function, with the given input and output variables.
/// The statements that need to replace the call statement in the original block
/// are pushed into the `statement_rewrite_stack`.
/// May also push additional blocks to the block queue.
/// The function takes an optional return block id to handle early returns.
pub fn inline_function(
Expand All @@ -389,8 +397,7 @@ impl<'db> FunctionInlinerRewriter<'db> {
let root_block_id = lowered.root_block?;
let root_block = &lowered.blocks[root_block_id];

// Create a new block with all the statements that follow
// the call statement.
// Create a new block with all the statements that follow the call statement.
let return_block_id = self.block_queue.enqueue_block(FlatBlock {
inputs: vec![],
statements: self.statement_rewrite_stack.consume(),
Expand Down Expand Up @@ -424,11 +431,17 @@ impl<'db> FunctionInlinerRewriter<'db> {
self.block_end =
FlatBlockEnd::Fallthrough(mapper.map_block_id(root_block_id), VarRemapping::default());

// Find the last block of the function.
let last_block_id = find_last_block(root_block_id, &lowered.blocks);

for (block_id, block) in lowered.blocks.iter() {
// The root block should end with Fallthrough and not Goto.
let mut block = mapper.rebuild_block(block);
// Remove the inputs from the root block.
if block_id == root_block_id {
block.inputs = vec![];
}
// The last block should end with Fallthrough and not Goto.
if block_id == last_block_id {
if let FlatBlockEnd::Goto(target, remapping) = block.end {
block.end = FlatBlockEnd::Fallthrough(target, remapping);
}
Expand All @@ -445,6 +458,15 @@ impl<'db> FunctionInlinerRewriter<'db> {
}
}

/// Finds the last block of a function, given its root block.
fn find_last_block(root_block_id: BlockId, blocks: &Blocks<FlatBlock>) -> BlockId {
let mut cur_block_id = root_block_id;
while let FlatBlockEnd::Fallthrough(target_id, _) = blocks[cur_block_id].end {
cur_block_id = target_id;
}
cur_block_id
}

pub fn apply_inlining(
db: &dyn LoweringGroup,
function_id: FunctionWithBodyId,
Expand Down
Loading

0 comments on commit eaa1153

Please sign in to comment.