Skip to content

Commit

Permalink
feat[venom]: optimize get_basic_block() (#4002)
Browse files Browse the repository at this point in the history
`get_basic_block()` is a hotspot in venom (up to 35% of total
compilation time!). This commit optimizes `get_basic_block()`: on a large
contract near the 24kb limit, it reduces the time spent in venom from 3s to
1s (total compilation time from 6s to 4s).

Note: on the same contract, the time spent in the IRnode optimizer pipeline
is 2s - so time in venom is now smaller than time in the legacy optimizer(!)

notes:
- refactor to use dict for basic_blocks
- clean up the basic blocks API:
    hide basic blocks behind the `get_basic_blocks()` iterator and the
    `num_basic_blocks` property.
  • Loading branch information
charles-cooper committed May 8, 2024
1 parent 75c75c5 commit 93147be
Show file tree
Hide file tree
Showing 12 changed files with 72 additions and 89 deletions.
6 changes: 3 additions & 3 deletions vyper/venom/analysis/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ class CFGAnalysis(IRAnalysis):

def analyze(self) -> None:
fn = self.function
for bb in fn.basic_blocks:
for bb in fn.get_basic_blocks():
bb.cfg_in = OrderedSet()
bb.cfg_out = OrderedSet()
bb.out_vars = OrderedSet()

for bb in fn.basic_blocks:
for bb in fn.get_basic_blocks():
assert len(bb.instructions) > 0, "Basic block should not be empty"
last_inst = bb.instructions[-1]
assert (
Expand All @@ -29,7 +29,7 @@ def analyze(self) -> None:
fn.get_basic_block(op.value).add_cfg_in(bb)

# Fill in the "out" set for each basic block
for bb in fn.basic_blocks:
for bb in fn.get_basic_blocks():
for in_bb in bb.cfg_in:
in_bb.add_cfg_out(bb)

Expand Down
2 changes: 1 addition & 1 deletion vyper/venom/analysis/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def analyze(self):
# %16 = iszero %15
# dfg_outputs of %15 is (%15 = add %13 %14)
# dfg_inputs of %15 is all the instructions which *use* %15, ex. [(%16 = iszero %15), ...]
for bb in self.function.basic_blocks:
for bb in self.function.get_basic_blocks():
for inst in bb.instructions:
operands = inst.get_inputs()
res = inst.get_outputs()
Expand Down
2 changes: 1 addition & 1 deletion vyper/venom/analysis/dominators.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def as_graph(self) -> str:
Generate a graphviz representation of the dominator tree.
"""
lines = ["digraph dominator_tree {"]
for bb in self.fn.basic_blocks:
for bb in self.fn.get_basic_blocks():
if bb == self.entry_block:
continue
idom = self.immediate_dominator(bb)
Expand Down
2 changes: 1 addition & 1 deletion vyper/venom/analysis/dup_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

class DupRequirementsAnalysis(IRAnalysis):
def analyze(self):
for bb in self.function.basic_blocks:
for bb in self.function.get_basic_blocks():
last_liveness = bb.out_vars
for inst in reversed(bb.instructions):
inst.dup_requirements = OrderedSet()
Expand Down
4 changes: 2 additions & 2 deletions vyper/venom/analysis/liveness.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ def analyze(self):
self._reset_liveness()
while True:
changed = False
for bb in self.function.basic_blocks:
for bb in self.function.get_basic_blocks():
changed |= self._calculate_out_vars(bb)
changed |= self._calculate_liveness(bb)

if not changed:
break

def _reset_liveness(self) -> None:
for bb in self.function.basic_blocks:
for bb in self.function.get_basic_blocks():
bb.out_vars = OrderedSet()
for inst in bb.instructions:
inst.liveness = OrderedSet()
Expand Down
97 changes: 40 additions & 57 deletions vyper/venom/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,91 +13,74 @@ class IRFunction:
name: IRLabel # symbol name
ctx: "IRContext" # type: ignore # noqa: F821
args: list
basic_blocks: list[IRBasicBlock]
last_label: int
last_variable: int
_basic_block_dict: dict[str, IRBasicBlock]

# Used during code generation
_ast_source_stack: list[IRnode]
_error_msg_stack: list[str]
_bb_index: dict[str, int]

def __init__(self, name: IRLabel, ctx: "IRContext" = None) -> None: # type: ignore # noqa: F821
self.ctx = ctx
self.name = name
self.args = []
self.basic_blocks = []
self._basic_block_dict = {}

self.last_variable = 0

self._ast_source_stack = []
self._error_msg_stack = []
self._bb_index = {}

self.append_basic_block(IRBasicBlock(name, self))

@property
def entry(self) -> IRBasicBlock:
return self.basic_blocks[0]
return next(self.get_basic_blocks())

def append_basic_block(self, bb: IRBasicBlock) -> IRBasicBlock:
def append_basic_block(self, bb: IRBasicBlock):
"""
Append basic block to function.
"""
assert isinstance(bb, IRBasicBlock), f"append_basic_block takes IRBasicBlock, got '{bb}'"
self.basic_blocks.append(bb)

return self.basic_blocks[-1]

def _get_basicblock_index(self, label: str):
# perf: keep an "index" of labels to block indices to
# perform fast lookup.
# TODO: maybe better just to throw basic blocks in an ordered
# dict of some kind.
ix = self._bb_index.get(label, -1)
if 0 <= ix < len(self.basic_blocks) and self.basic_blocks[ix].label == label:
return ix
# do a reindex
self._bb_index = dict((bb.label.name, ix) for ix, bb in enumerate(self.basic_blocks))
# sanity check - no duplicate labels
assert len(self._bb_index) == len(
self.basic_blocks
), f"Duplicate labels in function '{self.name}' {self._bb_index} {self.basic_blocks}"
return self._bb_index[label]
assert isinstance(bb, IRBasicBlock), bb
assert bb.label.name not in self._basic_block_dict
self._basic_block_dict[bb.label.name] = bb

def remove_basic_block(self, bb: IRBasicBlock):
assert isinstance(bb, IRBasicBlock), bb
del self._basic_block_dict[bb.label.name]

def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock:
"""
Get basic block by label.
If label is None, return the last basic block.
"""
if label is None:
return self.basic_blocks[-1]
ix = self._get_basicblock_index(label)
return self.basic_blocks[ix]
return next(reversed(self._basic_block_dict.values()))

return self._basic_block_dict[label]

def clear_basic_blocks(self):
self._basic_block_dict.clear()

def get_basic_block_after(self, label: IRLabel) -> IRBasicBlock:
def get_basic_blocks(self) -> Iterator[IRBasicBlock]:
"""
Get basic block after label.
Get an iterator over this function's basic blocks
"""
ix = self._get_basicblock_index(label.value)
if 0 <= ix < len(self.basic_blocks) - 1:
return self.basic_blocks[ix + 1]
raise AssertionError(f"Basic block after '{label}' not found")
return iter(self._basic_block_dict.values())

@property
def num_basic_blocks(self) -> int:
return len(self._basic_block_dict)

def get_terminal_basicblocks(self) -> Iterator[IRBasicBlock]:
"""
Get basic blocks that are terminal.
"""
for bb in self.basic_blocks:
for bb in self.get_basic_blocks():
if bb.is_terminal:
yield bb

def get_basicblocks_in(self, basic_block: IRBasicBlock) -> list[IRBasicBlock]:
"""
Get basic blocks that point to the given basic block
"""
return [bb for bb in self.basic_blocks if basic_block.label in bb.cfg_in]

def get_next_variable(self) -> IRVariable:
self.last_variable += 1
return IRVariable(f"%{self.last_variable}")
Expand All @@ -109,15 +92,14 @@ def remove_unreachable_blocks(self) -> int:
self._compute_reachability()

removed = []
new_basic_blocks = []

# Remove unreachable basic blocks
for bb in self.basic_blocks:
for bb in self.get_basic_blocks():
if not bb.is_reachable:
removed.append(bb)
else:
new_basic_blocks.append(bb)
self.basic_blocks = new_basic_blocks

for bb in removed:
self.remove_basic_block(bb)

# Remove phi instructions that reference removed basic blocks
for bb in removed:
Expand All @@ -142,7 +124,7 @@ def _compute_reachability(self) -> None:
"""
Compute reachability of basic blocks.
"""
for bb in self.basic_blocks:
for bb in self.get_basic_blocks():
bb.reachable = OrderedSet()
bb.is_reachable = False

Expand Down Expand Up @@ -172,7 +154,7 @@ def normalized(self) -> bool:
Having a normalized CFG makes calculation of stack layout easier when
emitting assembly.
"""
for bb in self.basic_blocks:
for bb in self.get_basic_blocks():
# Ignore if there are no multiple predecessors
if len(bb.cfg_in) <= 1:
continue
Expand Down Expand Up @@ -211,22 +193,23 @@ def chain_basic_blocks(self) -> None:
Otherwise, append a stop instruction. This is necessary for the IR to be valid, and is
done after the IR is generated.
"""
for i, bb in enumerate(self.basic_blocks):
bbs = list(self.get_basic_blocks())
for i, bb in enumerate(bbs):
if not bb.is_terminated:
if len(self.basic_blocks) - 1 > i:
if i < len(bbs) - 1:
# TODO: revisit this. When constructor calls internal functions they
# are linked to the last ctor block. Should separate them before this
# so we don't have to handle this here
if self.basic_blocks[i + 1].label.value.startswith("internal"):
if bbs[i + 1].label.value.startswith("internal"):
bb.append_instruction("stop")
else:
bb.append_instruction("jmp", self.basic_blocks[i + 1].label)
bb.append_instruction("jmp", bbs[i + 1].label)
else:
bb.append_instruction("exit")

def copy(self):
new = IRFunction(self.name)
new.basic_blocks = self.basic_blocks.copy()
new._basic_block_dict = self._basic_block_dict.copy()
new.last_label = self.last_label
new.last_variable = self.last_variable
return new
Expand All @@ -246,11 +229,11 @@ def _make_label(bb):

ret = "digraph G {\n"

for bb in self.basic_blocks:
for bb in self.get_basic_blocks():
for out_bb in bb.cfg_out:
ret += f' "{bb.label.value}" -> "{out_bb.label.value}"\n'

for bb in self.basic_blocks:
for bb in self.get_basic_blocks():
ret += f' "{bb.label.value}" [shape=plaintext, '
ret += f'label={_make_label(bb)}, fontname="Courier" fontsize="8"]\n'

Expand All @@ -259,6 +242,6 @@ def _make_label(bb):

def __repr__(self) -> str:
str = f"IRFunction: {self.name}\n"
for bb in self.basic_blocks:
for bb in self.get_basic_blocks():
str += f"{bb}\n"
return str.strip()
7 changes: 3 additions & 4 deletions vyper/venom/ir_node_to_venom.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,9 @@ def _append_jmp(fn: IRFunction, label: IRLabel) -> None:
bb.append_instruction("jmp", label)


def _new_block(fn: IRFunction) -> IRBasicBlock:
def _new_block(fn: IRFunction) -> None:
bb = IRBasicBlock(fn.ctx.get_next_label(), fn)
bb = fn.append_basic_block(bb)
return bb
fn.append_basic_block(bb)


def _append_return_args(fn: IRFunction, ofst: int = 0, size: int = 0):
Expand Down Expand Up @@ -328,7 +327,7 @@ def _convert_ir_bb(fn, ir, symbols):

# exit bb
exit_bb = IRBasicBlock(ctx.get_next_label("if_exit"), fn)
exit_bb = fn.append_basic_block(exit_bb)
fn.append_basic_block(exit_bb)

if_ret = fn.get_next_variable()
if then_ret_val is not None and else_ret_val is not None:
Expand Down
4 changes: 2 additions & 2 deletions vyper/venom/passes/dft.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def run_pass(self) -> None:
self.fence_id = 0
self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet()

basic_blocks = self.function.basic_blocks
basic_blocks = list(self.function.get_basic_blocks())

self.function.basic_blocks = []
self.function.clear_basic_blocks()
for bb in basic_blocks:
self._process_basic_block(bb)
4 changes: 2 additions & 2 deletions vyper/venom/passes/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _run_pass(self) -> int:
self.analyses_cache.request_analysis(CFGAnalysis)

# Split blocks that need splitting
for bb in fn.basic_blocks:
for bb in list(fn.get_basic_blocks()):
if len(bb.cfg_in) > 1:
self._split_basic_block(bb)

Expand All @@ -71,7 +71,7 @@ def _run_pass(self) -> int:

def run_pass(self):
fn = self.function
for _ in range(len(fn.basic_blocks) * 2):
for _ in range(fn.num_basic_blocks * 2):
if self._run_pass() == 0:
break
else:
Expand Down
2 changes: 1 addition & 1 deletion vyper/venom/passes/remove_unused_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def run_pass(self):

self.analyses_cache.request_analysis(LivenessAnalysis)

for bb in self.function.basic_blocks:
for bb in self.function.get_basic_blocks():
for i, inst in enumerate(bb.instructions[:-1]):
if inst.volatile:
continue
Expand Down
2 changes: 1 addition & 1 deletion vyper/venom/passes/sccp/sccp.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _calculate_sccp(self, entry: IRBasicBlock):
and the work list. The `_propagate_constants()` method is responsible
for updating the IR with the constant values.
"""
self.cfg_in_exec = {bb: OrderedSet() for bb in self.fn.basic_blocks}
self.cfg_in_exec = {bb: OrderedSet() for bb in self.fn.get_basic_blocks()}

dummy = IRBasicBlock(IRLabel("__dummy_start"), self.fn)
self.work_list.append(FlowWorkItem(dummy, entry))
Expand Down

0 comments on commit 93147be

Please sign in to comment.