Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed Aug 22, 2019
1 parent 427d8c7 commit 6aa0ba1
Showing 1 changed file with 98 additions and 105 deletions.
203 changes: 98 additions & 105 deletions hail/python/hail/ir/renderer.py
Expand Up @@ -136,26 +136,11 @@ def __call__(self, x: 'Renderable'):
return ''.join(builder)


class BindingSite:
def __init__(self, lifted_lets: Dict[int, Tuple[str, 'ir.BaseIR']], depth: int, node: 'ir.BaseIR'):
self.depth = depth
self.lifted_lets = lifted_lets
self.node = node


class PrintStackFrame:
def __init__(self, binding_site: BindingSite):
self.depth = binding_site.depth
self.lifted_lets = binding_site.lifted_lets
self.visited = {}
self.let_bodies = []


Vars = Dict[str, int]
Context = (Vars, Vars, Vars)


class StackFrame:
class AnalysisStackFrame:
def __init__(self, min_binding_depth: int, context: Context, x: 'ir.BaseIR',
new_bindings=(None, None, None)):
# immutable
Expand All @@ -182,7 +167,7 @@ def bind_depth(self) -> int:
bind_depth = self.min_binding_depth
return bind_depth

def make_child_frame(self, depth: int) -> 'StackFrame':
def make_child_frame(self, depth: int) -> 'AnalysisStackFrame':
x = self.node
i = self.child_idx - 1
child = x.children[i]
Expand All @@ -198,8 +183,8 @@ def make_child_frame(self, depth: int) -> 'StackFrame':
new_bindings = (eval_bindings, agg_bindings, scan_bindings)
child_context = x.child_context(i, self.context, depth)

return StackFrame(child_outermost_scope, child_context, child,
new_bindings)
return AnalysisStackFrame(child_outermost_scope, child_context, child,
new_bindings)

def update_parent_free_vars(self, parent_free_vars: Set[str]):
# subtract vars bound by parent from free_vars
Expand All @@ -214,6 +199,45 @@ def update_parent_free_vars(self, parent_free_vars: Set[str]):
parent_free_vars.update(self.free_vars)


class BindingSite:
def __init__(self, lifted_lets: Dict[int, Tuple[str, 'ir.BaseIR']], depth: int, node: 'ir.BaseIR'):
self.depth = depth
self.lifted_lets = lifted_lets
self.node = node


class BindingsStackFrame:
def __init__(self, binding_site: BindingSite):
self.depth = binding_site.depth
self.lifted_lets = binding_site.lifted_lets
self.visited = {}
self.let_bodies = []


class PrintStackFrame:
def __init__(self, x, children, builder, outermost_scope, depth):
self.x = x
self.children = children
self.local_builder = builder
self.outermost_scope = outermost_scope
self.ir_child_num = 0
self.depth = depth
self.i = 0

class PostChildrenLifted(PrintStackFrame):
def __init__(self, x, children, local_builder, outermost_scope, depth, builder, lift_to, name):
super().__init__(x, children, local_builder, outermost_scope, depth)
self.builder = builder
self.lift_to = lift_to
self.name = name

class PostChildren(PrintStackFrame):
def __init__(self, x, children, local_builder, outermost_scope, depth, builder, insert_lets):
super().__init__(x, children, local_builder, outermost_scope, depth)
self.builder = builder
self.insert_lets = insert_lets


class CSERenderer(Renderer):
def __init__(self, stop_at_jir=False):
self.stop_at_jir = stop_at_jir
Expand All @@ -234,14 +258,14 @@ def add_jir(self, jir):
return jir_id

@staticmethod
def find_in_scope(x: 'ir.BaseIR', context: List[StackFrame], outermost_scope: int) -> int:
def find_in_scope(x: 'ir.BaseIR', context: List[AnalysisStackFrame], outermost_scope: int) -> int:
for i in reversed(range(len(context))):
if id(x) in context[i].visited:
return i
return -1

@staticmethod
def lifted_in_scope(x: 'ir.BaseIR', context: List[StackFrame]) -> int:
def lifted_in_scope(x: 'ir.BaseIR', context: List[AnalysisStackFrame]) -> int:
for i in range(len(context)):
if id(x) in context[i].lifted_lets:
return i
Expand All @@ -262,110 +286,71 @@ def lifted_in_scope(x: 'ir.BaseIR', context: List[StackFrame]) -> int:
# * 'self.scopes' is updated to map subtrees y of 'x' to scopes containing
# any lets to be inserted above y.

class State:
def __init__(self, x, children, builder, outermost_scope, depth):
self.x = x
self.children = children
self.builder = builder
self.outermost_scope = outermost_scope
self.ir_child_num = 0
self.depth = depth
self.i = 0

def pre_add_lets(self, state, binding_sites, context):
x = state.x
insert_lets = id(x) in binding_sites and len(binding_sites[id(x)].lifted_lets) > 0
if insert_lets:
state.builder = []
context.append(PrintStackFrame(binding_sites[id(x)]))
head = x.render_head(self)
if head != '':
state.builder.append(head)
return insert_lets

class Kont:
pass

class PostChildrenLifted(Kont):
def __init__(self, state, builder, lift_to, name):
self.state = state
self.builder = builder
self.local_builder = state.builder
self.lift_to = lift_to
self.name = name

class PostChildren(Kont):
def __init__(self, state, builder, insert_lets):
self.state = state
self.builder = builder
self.local_builder = state.builder
self.insert_lets = insert_lets

def loop(self, root):
def build_string(self, root):
root_builder = []
context = []
root_state = self.State(root, root.render_children(self), root_builder, 0, 1)
root_state = PrintStackFrame(root, root.render_children(self), root_builder, 0, 1)
if id(root) in self.memo:
root_state.builder.append(self.memo[id(root)])
root_state.local_builder.append(self.memo[id(root)])
return ''.join(root_builder)
insert_lets = id(root) in self.binding_sites and len(self.binding_sites[id(root)].lifted_lets) > 0
if insert_lets:
root_state.builder = []
context.append(PrintStackFrame(self.binding_sites[id(root)]))
root_state.local_builder = []
context.append(BindingsStackFrame(self.binding_sites[id(root)]))
head = root.render_head(self)
if head != '':
root_state.builder.append(head)
root_state.local_builder.append(head)

root_frame = lambda x: None
root_frame.state = root_state
stack = [root_frame]
stack = [root_state]

while True:
state = stack[-1].state
state = stack[-1]

if state.i >= len(state.children):
if len(stack) <= 1:
root_state.builder.append(root.render_tail(self))
root_state.local_builder.append(root.render_tail(self))
insert_lets = id(root) in self.binding_sites and len(self.binding_sites[id(root)].lifted_lets) > 0
if insert_lets:
self.add_lets(context, root_state.builder, root_builder)
self.add_lets(context, root_state.local_builder, root_builder)
return ''.join(root_builder)
k = stack[-1]
if isinstance(k, self.PostChildrenLifted):
if id(k.state.x) in self.memo:
k.local_builder.extend(self.memo[id(k.state.x)])
if isinstance(state, PostChildrenLifted):
if id(state.x) in self.memo:
state.local_builder.extend(self.memo[id(state.x)])
else:
k.local_builder.append(k.state.x.render_tail(self))
k.local_builder.append(' ')
state.local_builder.append(state.x.render_tail(self))
state.local_builder.append(' ')
# let_bodies is built post-order, which guarantees earlier
# lets can't refer to later lets
context[k.lift_to].let_bodies.append(k.local_builder)
k.builder.append(f'(Ref {k.name})')
context[state.lift_to].let_bodies.append(state.local_builder)
state.builder.append(f'(Ref {state.name})')
stack.pop()
stack[-1].state.i += 1
stack[-1].i += 1
continue
else:
assert isinstance(k, self.PostChildren)
k.local_builder.append(k.state.x.render_tail(self))
if k.insert_lets:
self.add_lets(context, k.local_builder, k.builder)
assert isinstance(state, PostChildren)
state.local_builder.append(state.x.render_tail(self))
if state.insert_lets:
self.add_lets(context, state.local_builder, state.builder)
stack.pop()
state = stack[-1].state
state = stack[-1]
if not isinstance(state.x, ir.BaseIR):
state.ir_child_num += 1
state.i += 1
continue

state.builder.append(' ')
state.local_builder.append(' ')
child = state.children[state.i]

new_state = self.State(child, child.render_children(self), state.builder, state.outermost_scope, state.depth)
child_children = child.render_children(self)
child_local_builder = state.local_builder
child_outermost_scope = state.outermost_scope
child_depth = state.depth

if isinstance(state.x, ir.BaseIR):
if state.x.new_block(state.ir_child_num):
new_state.outermost_scope = state.depth
child_outermost_scope = state.depth
if isinstance(child, ir.BaseIR):
new_state.depth += 1
child_depth += 1
lift_to = self.lifted_in_scope(child, context)
else:
lift_to = -1
Expand All @@ -377,44 +362,52 @@ def loop(self, root):
(name, _) = context[lift_to].lifted_lets[id(child)]

if id(child) in context[lift_to].visited:
state.builder.append(f'(Ref {name})')
state.local_builder.append(f'(Ref {name})')
state.i += 1
continue

context[lift_to].visited[id(child)] = child

new_state.builder = [f'(Let {name} ']
child_local_builder = [f'(Let {name} ']

if id(child) in self.memo:
new_state.children = []
stack.append(self.PostChildrenLifted(new_state, state.builder, lift_to, name))
child_children = []
stack.append(PostChildrenLifted(child, child_children, child_local_builder, child_outermost_scope, child_depth, state.local_builder, lift_to, name))
continue

insert_lets = self.pre_add_lets(new_state, self.binding_sites, context)
assert(not insert_lets)
assert(not (id(child) in self.binding_sites and len(self.binding_sites[id(child)].lifted_lets) > 0))
head = child.render_head(self)
if head != '':
child_local_builder.append(head)

stack.append(self.PostChildrenLifted(new_state, state.builder, lift_to, name))
stack.append(PostChildrenLifted(child, child_children, child_local_builder, child_outermost_scope, child_depth, state.local_builder, lift_to, name))
continue

if isinstance(child, ir.BaseIR):
if id(child) in self.memo:
new_state.builder.extend(self.memo[id(child)])
child_local_builder.extend(self.memo[id(child)])
state.i += 1
continue

insert_lets = self.pre_add_lets(new_state, self.binding_sites, context)
insert_lets = id(child) in self.binding_sites and len(self.binding_sites[id(child)].lifted_lets) > 0
if insert_lets:
child_local_builder = []
context.append(BindingsStackFrame(self.binding_sites[id(child)]))
head = child.render_head(self)
if head != '':
child_local_builder.append(head)

stack.append(self.PostChildren(new_state, state.builder, insert_lets))
stack.append(PostChildren(child, child_children, child_local_builder, child_outermost_scope, child_depth, state.local_builder, insert_lets))
continue
else:
head = child.render_head(self)
if head != '':
state.builder.append(head)
child_local_builder.append(head)

new_state = PostChildren(child, child_children, child_local_builder, child_outermost_scope, child_depth, state.local_builder, False)
new_state.ir_child_num = state.ir_child_num
insert_lets = False

stack.append(self.PostChildren(new_state, state.builder, insert_lets))
stack.append(new_state)
continue

@staticmethod
Expand All @@ -429,7 +422,7 @@ def add_lets(context, local_builder, builder):
builder.append(')')

def compute_new_bindings(self, root: 'ir.BaseIR'):
root_frame = StackFrame(0, ({}, {}, {}), root)
root_frame = AnalysisStackFrame(0, ({}, {}, {}), root)
stack = [root_frame]
binding_sites = {}

Expand Down Expand Up @@ -522,4 +515,4 @@ def compute_new_bindings(self, root: 'ir.BaseIR'):
def __call__(self, root: 'ir.BaseIR') -> str:
self.binding_sites = self.compute_new_bindings(root)

return self.loop(root)
return self.build_string(root)

0 comments on commit 6aa0ba1

Please sign in to comment.