Skip to content

Commit

Permalink
defunctionalize
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed Aug 9, 2019
1 parent 3db71db commit 55c32ba
Showing 1 changed file with 90 additions and 121 deletions.
211 changes: 90 additions & 121 deletions hail/python/hail/ir/renderer.py
Expand Up @@ -308,111 +308,125 @@ def pre_add_lets(self, binding_sites):
self.builder.append(head)
return insert_lets

@staticmethod
def post_lift_visit(state, let_body, lift_to, name):
let_body.append(' ')
# let_bodies is built post-order, which guarantees earlier
# lets can't refer to later lets
state.context[lift_to].let_bodies.append(let_body)
state.builder.append(f'(Ref {name})')
state.i += 1

def loop(self, state, k):
if state.i >= len(state.children):
return k()
return self.apply(k)
state.builder.append(' ')
child = state.children[state.i]
lift_to = self.lifted_in_scope(child, state.context)
child_outermost_scope = state.outermost_scope
if state.x.new_block(state.ir_child_num):
child_outermost_scope = state.depth

new_state = self.State(child, child.render_children(self), state.builder, state.context, state.outermost_scope, state.depth)

if isinstance(state.x, ir.BaseIR):
if state.x.new_block(state.ir_child_num):
new_state.outermost_scope = state.depth
if isinstance(child, ir.BaseIR):
new_state.depth += 1
lift_to = self.lifted_in_scope(child, state.context)
else:
lift_to = -1

if (lift_to >= 0 and
state.context[lift_to] and
state.context[lift_to].depth >= state.outermost_scope):
state.ir_child_num += 1
(name, _) = state.context[lift_to].lifted_lets[id(child)]
if id(child) not in state.context[lift_to].visited:
state.context[lift_to].visited[id(child)] = child
let_body = [f'(Let {name} ']

new_state = self.State(child, child.render_children(self), let_body, state.context, child_outermost_scope, state.depth + 1)
x = new_state.x
builder = state.builder
if id(x) in self.memo:
new_state.builder.append(self.memo[id(x)])
self.post_lift_visit(state, let_body, lift_to, name)
return self.loop(state, k)

insert_lets = new_state.pre_add_lets(self.binding_sites)

def post_children():
new_state.builder.append(new_state.x.render_tail(self))
if insert_lets:
self.add_lets(new_state.context, new_state.builder, builder)
self.post_lift_visit(state, let_body, lift_to, name)
return self.loop(state, k)

return self.loop(new_state, post_children)
else:

if id(child) in state.context[lift_to].visited:
state.builder.append(f'(Ref {name})')
state.i += 1
return self.loop(state, k)
else:
if isinstance(child, ir.BaseIR):
new_state = self.State(child, child.render_children(self), state.builder, state.context, child_outermost_scope, state.depth + 1)

if id(child) in self.memo:
new_state.builder.append(self.memo[id(child)])
state.i += 1
return self.loop(state, k)
state.context[lift_to].visited[id(child)] = child

insert_lets = new_state.pre_add_lets(self.binding_sites)
new_state.builder = [f'(Let {name} ']

def post_children():
new_state.builder.append(child.render_tail(self))
if insert_lets:
self.add_lets(new_state.context, new_state.builder, state.builder)
state.i += 1
return self.loop(state, k)
if id(child) in self.memo:
pcl = self.PostChildrenLifted(state, state.builder, new_state.builder, self.memo[id(child)], lift_to, name, k)
return self.apply(pcl)

return self.loop(new_state, post_children)
else:
head = child.render_head(self)
if head != '':
state.builder.append(head)
insert_lets = new_state.pre_add_lets(self.binding_sites)
assert(not insert_lets)

pcl = self.PostChildrenLifted(state, state.builder, new_state.builder, [child.render_tail(self)], lift_to, name, k)

return self.loop(new_state, pcl)

if isinstance(child, ir.BaseIR):
if id(child) in self.memo:
new_state.builder.append(*self.memo[id(child)])
state.i += 1
return self.loop(state, k)

insert_lets = new_state.pre_add_lets(self.binding_sites)

pc = self.PostChildren(state, state.builder, new_state.builder, [child.render_tail(self)], insert_lets, k)

return self.loop(new_state, pc)
else:
head = child.render_head(self)
if head != '':
state.builder.append(head)

new_state = self.State(child, child.render_children(self), state.builder, state.context, child_outermost_scope, state.depth + 1)
new_state.ir_child_num = state.ir_child_num
new_state.ir_child_num = state.ir_child_num
insert_lets = False

def post_children_renderable():
new_state.builder.append(new_state.x.render_tail(self))
state.i += 1
return self.loop(state, k)
pc = self.PostChildren(state, state.builder, new_state.builder, new_state.x.render_tail(self), insert_lets, k)

return self.loop2(new_state, post_children_renderable)
return self.loop(new_state, pc)

class Kont:
def __apply__(self):
pass
pass

class PostChildrenLifted(Kont):
def __init__(self, state, builder, local_builder, tail, lift_to, name, k):
self.state = state
self.builder = builder
self.k = k
self.local_builder = local_builder
self.tail = tail
self.lift_to = lift_to
self.name = name

class PostChildren(Kont):
def __init__(self, node, state, insert_lets, builder, local_builder, k):
def __init__(self, state, builder, local_builder, tail, insert_lets, k):
self.state = state
self.insert_lets = insert_lets
self.builder = builder
self.k = k
self.local_builder = local_builder
self.node = node
self.context = state.context
self.tail = tail
self.insert_lets = insert_lets

class PostRoot(Kont):
def __init__(self, state, builder, local_builder, tail, insert_lets):
self.state = state
self.builder = builder
self.local_builder = local_builder
self.tail = tail
self.insert_lets = insert_lets

def apply(self, k: Kont):
if isinstance(k, self.PostChildren):
k.local_builder.append(k.node.render_tail(self))
if isinstance(k, self.PostChildrenLifted):
k.local_builder.extend(k.tail)
k.local_builder.append(' ')
# let_bodies is built post-order, which guarantees earlier
# lets can't refer to later lets
k.state.context[k.lift_to].let_bodies.append(k.local_builder)
k.builder.append(f'(Ref {k.name})')
k.state.i += 1
return self.loop(k.state, k.k)
elif isinstance(k, self.PostChildren):
k.local_builder.extend(k.tail)
if k.insert_lets:
CSERenderer.add_lets(k.context, k.local_builder, k.builder)
k.state.ir_child_num += 1
self.add_lets(k.state.context, k.local_builder, k.builder)
if not isinstance(k.state.x, ir.BaseIR):
k.state.ir_child_num += 1
k.state.i += 1
return k.loop2(k.state, k.k)
return self.loop(k.state, k.k)
else:
k.local_builder.append(k.tail)
if k.insert_lets:
self.add_lets(k.state.context, k.local_builder, k.builder)
return ''.join(k.builder)

@staticmethod
def add_lets(context, local_builder, builder):
Expand All @@ -425,47 +439,6 @@ def add_lets(context, local_builder, builder):
for _ in range(num_lets):
builder.append(')')

def loop2(self, state, k):
if state.i >= len(state.children):
return k()
state.builder.append(' ')
child = state.children[state.i]
if isinstance(child, ir.BaseIR):
new_state = self.State(child, child.render_children(self), state.builder, state.context, state.outermost_scope, state.depth)

x = new_state.x
builder = new_state.builder
if id(x) in self.memo:
new_state.builder.append(self.memo[id(x)])
state.ir_child_num += 1
state.i += 1
return self.loop2(state, k)

insert_lets = new_state.pre_add_lets(self.binding_sites)

def post_children():
new_state.builder.append(new_state.x.render_tail(self))
if insert_lets:
self.add_lets(new_state, builder)
state.ir_child_num += 1
state.i += 1
return self.loop2(state, k)

return self.loop(new_state, post_children)
else:
head = child.render_head(self)
if head != '':
state.builder.append(head)

new_state = self.State(child, child.render_children(self), state.builder, state.context, state.outermost_scope, state.depth)
new_state.ir_child_num = state.ir_child_num

def post_children_renderable():
new_state.builder.append(new_state.x.render_tail(self))
state.i += 1
return self.loop2(state, k)
return self.loop2(new_state, post_children_renderable)

def compute_new_bindings(self, root: 'ir.BaseIR'):
# force computation of all types
root.typ
Expand Down Expand Up @@ -561,10 +534,6 @@ def __call__(self, root: 'ir.BaseIR') -> str:
if head != '':
state.builder.append(head)

def post_children():
state.builder.append(root.render_tail(self))
if insert_lets:
self.add_lets(state.context, state.builder, builder)
return ''.join(builder)
pc = self.PostRoot(state, builder, state.builder, root.render_tail(self), insert_lets)

return self.loop(state, post_children)
return self.loop(state, pc)

0 comments on commit 55c32ba

Please sign in to comment.