From 6aa0ba14c3a7c35794f10d900e62d176992a63a6 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 22 Aug 2019 10:10:54 -0400 Subject: [PATCH] refactoring --- hail/python/hail/ir/renderer.py | 203 +++++++++++++++----------------- 1 file changed, 98 insertions(+), 105 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 83e87f3416e..3677c5dbfa8 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -136,26 +136,11 @@ def __call__(self, x: 'Renderable'): return ''.join(builder) -class BindingSite: - def __init__(self, lifted_lets: Dict[int, Tuple[str, 'ir.BaseIR']], depth: int, node: 'ir.BaseIR'): - self.depth = depth - self.lifted_lets = lifted_lets - self.node = node - - -class PrintStackFrame: - def __init__(self, binding_site: BindingSite): - self.depth = binding_site.depth - self.lifted_lets = binding_site.lifted_lets - self.visited = {} - self.let_bodies = [] - - Vars = Dict[str, int] Context = (Vars, Vars, Vars) -class StackFrame: +class AnalysisStackFrame: def __init__(self, min_binding_depth: int, context: Context, x: 'ir.BaseIR', new_bindings=(None, None, None)): # immutable @@ -182,7 +167,7 @@ def bind_depth(self) -> int: bind_depth = self.min_binding_depth return bind_depth - def make_child_frame(self, depth: int) -> 'StackFrame': + def make_child_frame(self, depth: int) -> 'AnalysisStackFrame': x = self.node i = self.child_idx - 1 child = x.children[i] @@ -198,8 +183,8 @@ def make_child_frame(self, depth: int) -> 'StackFrame': new_bindings = (eval_bindings, agg_bindings, scan_bindings) child_context = x.child_context(i, self.context, depth) - return StackFrame(child_outermost_scope, child_context, child, - new_bindings) + return AnalysisStackFrame(child_outermost_scope, child_context, child, + new_bindings) def update_parent_free_vars(self, parent_free_vars: Set[str]): # subtract vars bound by parent from free_vars @@ -214,6 +199,45 @@ def update_parent_free_vars(self, parent_free_vars: Set[str]): parent_free_vars.update(self.free_vars) +class BindingSite: + def __init__(self, lifted_lets: Dict[int, Tuple[str, 'ir.BaseIR']], depth: int, node: 'ir.BaseIR'): + self.depth = depth + self.lifted_lets = lifted_lets + self.node = node + + +class BindingsStackFrame: + def __init__(self, binding_site: BindingSite): + self.depth = binding_site.depth + self.lifted_lets = binding_site.lifted_lets + self.visited = {} + self.let_bodies = [] + + +class PrintStackFrame: + def __init__(self, x, children, builder, outermost_scope, depth): + self.x = x + self.children = children + self.local_builder = builder + self.outermost_scope = outermost_scope + self.ir_child_num = 0 + self.depth = depth + self.i = 0 + +class PostChildrenLifted(PrintStackFrame): + def __init__(self, x, children, local_builder, outermost_scope, depth, builder, lift_to, name): + super().__init__(x, children, local_builder, outermost_scope, depth) + self.builder = builder + self.lift_to = lift_to + self.name = name + +class PostChildren(PrintStackFrame): + def __init__(self, x, children, local_builder, outermost_scope, depth, builder, insert_lets): + super().__init__(x, children, local_builder, outermost_scope, depth) + self.builder = builder + self.insert_lets = insert_lets + + class CSERenderer(Renderer): def __init__(self, stop_at_jir=False): self.stop_at_jir = stop_at_jir @@ -234,14 +258,14 @@ def add_jir(self, jir): return jir_id @staticmethod - def find_in_scope(x: 'ir.BaseIR', context: List[StackFrame], outermost_scope: int) -> int: + def find_in_scope(x: 'ir.BaseIR', context: List[AnalysisStackFrame], outermost_scope: int) -> int: for i in reversed(range(len(context))): if id(x) in context[i].visited: return i return -1 @staticmethod - def lifted_in_scope(x: 'ir.BaseIR', context: List[StackFrame]) -> int: + def lifted_in_scope(x: 'ir.BaseIR', context: List[AnalysisStackFrame]) -> int: for i in range(len(context)): if id(x) in context[i].lifted_lets: return i @@ -262,110 +286,71 @@ def lifted_in_scope(x: 'ir.BaseIR', context: List[StackFrame]) -> int: # * 'self.scopes' is updated to map subtrees y of 'x' to scopes containing # any lets to be inserted above y. - class State: - def __init__(self, x, children, builder, outermost_scope, depth): - self.x = x - self.children = children - self.builder = builder - self.outermost_scope = outermost_scope - self.ir_child_num = 0 - self.depth = depth - self.i = 0 - - def pre_add_lets(self, state, binding_sites, context): - x = state.x - insert_lets = id(x) in binding_sites and len(binding_sites[id(x)].lifted_lets) > 0 - if insert_lets: - state.builder = [] - context.append(PrintStackFrame(binding_sites[id(x)])) - head = x.render_head(self) - if head != '': - state.builder.append(head) - return insert_lets - - class Kont: - pass - - class PostChildrenLifted(Kont): - def __init__(self, state, builder, lift_to, name): - self.state = state - self.builder = builder - self.local_builder = state.builder - self.lift_to = lift_to - self.name = name - - class PostChildren(Kont): - def __init__(self, state, builder, insert_lets): - self.state = state - self.builder = builder - self.local_builder = state.builder - self.insert_lets = insert_lets - - def loop(self, root): + def build_string(self, root): root_builder = [] context = [] - root_state = self.State(root, root.render_children(self), root_builder, 0, 1) + root_state = PrintStackFrame(root, root.render_children(self), root_builder, 0, 1) if id(root) in self.memo: - root_state.builder.append(self.memo[id(root)]) + root_state.local_builder.append(self.memo[id(root)]) return ''.join(root_builder) insert_lets = id(root) in self.binding_sites and len(self.binding_sites[id(root)].lifted_lets) > 0 if insert_lets: - root_state.builder = [] - context.append(PrintStackFrame(self.binding_sites[id(root)])) + root_state.local_builder = [] + context.append(BindingsStackFrame(self.binding_sites[id(root)])) head = root.render_head(self) if head != '': - root_state.builder.append(head) + root_state.local_builder.append(head) - root_frame = lambda x: None - root_frame.state = root_state - stack = [root_frame] + stack = [root_state] while True: - state = stack[-1].state + state = stack[-1] if state.i >= len(state.children): if len(stack) <= 1: - root_state.builder.append(root.render_tail(self)) + root_state.local_builder.append(root.render_tail(self)) insert_lets = id(root) in self.binding_sites and len(self.binding_sites[id(root)].lifted_lets) > 0 if insert_lets: - self.add_lets(context, root_state.builder, root_builder) + self.add_lets(context, root_state.local_builder, root_builder) return ''.join(root_builder) - k = stack[-1] - if isinstance(k, self.PostChildrenLifted): - if id(k.state.x) in self.memo: - k.local_builder.extend(self.memo[id(k.state.x)]) + if isinstance(state, PostChildrenLifted): + if id(state.x) in self.memo: + state.local_builder.extend(self.memo[id(state.x)]) else: - k.local_builder.append(k.state.x.render_tail(self)) - k.local_builder.append(' ') + state.local_builder.append(state.x.render_tail(self)) + state.local_builder.append(' ') # let_bodies is built post-order, which guarantees earlier # lets can't refer to later lets - context[k.lift_to].let_bodies.append(k.local_builder) - k.builder.append(f'(Ref {k.name})') + context[state.lift_to].let_bodies.append(state.local_builder) + state.builder.append(f'(Ref {state.name})') stack.pop() - stack[-1].state.i += 1 + stack[-1].i += 1 continue else: - assert isinstance(k, self.PostChildren) - k.local_builder.append(k.state.x.render_tail(self)) - if k.insert_lets: - self.add_lets(context, k.local_builder, k.builder) + assert isinstance(state, PostChildren) + state.local_builder.append(state.x.render_tail(self)) + if state.insert_lets: + self.add_lets(context, state.local_builder, state.builder) stack.pop() - state = stack[-1].state + state = stack[-1] if not isinstance(state.x, ir.BaseIR): state.ir_child_num += 1 state.i += 1 continue - state.builder.append(' ') + state.local_builder.append(' ') child = state.children[state.i] - new_state = self.State(child, child.render_children(self), state.builder, state.outermost_scope, state.depth) + child_children = child.render_children(self) + child_local_builder = state.local_builder + child_outermost_scope = state.outermost_scope + child_depth = state.depth if isinstance(state.x, ir.BaseIR): if state.x.new_block(state.ir_child_num): - new_state.outermost_scope = state.depth + child_outermost_scope = state.depth if isinstance(child, ir.BaseIR): - new_state.depth += 1 + child_depth += 1 lift_to = self.lifted_in_scope(child, context) else: lift_to = -1 @@ -377,44 +362,52 @@ def loop(self, root): (name, _) = context[lift_to].lifted_lets[id(child)] if id(child) in context[lift_to].visited: - state.builder.append(f'(Ref {name})') + state.local_builder.append(f'(Ref {name})') state.i += 1 continue context[lift_to].visited[id(child)] = child - new_state.builder = [f'(Let {name} '] + child_local_builder = [f'(Let {name} '] if id(child) in self.memo: - new_state.children = [] - stack.append(self.PostChildrenLifted(new_state, state.builder, lift_to, name)) + child_children = [] + stack.append(PostChildrenLifted(child, child_children, child_local_builder, child_outermost_scope, child_depth, state.local_builder, lift_to, name)) continue - insert_lets = self.pre_add_lets(new_state, self.binding_sites, context) - assert(not insert_lets) + assert(not (id(child) in self.binding_sites and len(self.binding_sites[id(child)].lifted_lets) > 0)) + head = child.render_head(self) + if head != '': + child_local_builder.append(head) - stack.append(self.PostChildrenLifted(new_state, state.builder, lift_to, name)) + stack.append(PostChildrenLifted(child, child_children, child_local_builder, child_outermost_scope, child_depth, state.local_builder, lift_to, name)) continue if isinstance(child, ir.BaseIR): if id(child) in self.memo: - new_state.builder.extend(self.memo[id(child)]) + child_local_builder.extend(self.memo[id(child)]) state.i += 1 continue - insert_lets = self.pre_add_lets(new_state, self.binding_sites, context) + insert_lets = id(child) in self.binding_sites and len(self.binding_sites[id(child)].lifted_lets) > 0 + if insert_lets: + child_local_builder = [] + context.append(BindingsStackFrame(self.binding_sites[id(child)])) + head = child.render_head(self) + if head != '': + child_local_builder.append(head) - stack.append(self.PostChildren(new_state, state.builder, insert_lets)) + stack.append(PostChildren(child, child_children, child_local_builder, child_outermost_scope, child_depth, state.local_builder, insert_lets)) continue else: head = child.render_head(self) if head != '': - state.builder.append(head) + child_local_builder.append(head) + new_state = PostChildren(child, child_children, child_local_builder, child_outermost_scope, child_depth, state.local_builder, False) new_state.ir_child_num = state.ir_child_num - insert_lets = False - stack.append(self.PostChildren(new_state, state.builder, insert_lets)) + stack.append(new_state) continue @staticmethod @@ -429,7 +422,7 @@ def add_lets(context, local_builder, builder): builder.append(')') def compute_new_bindings(self, root: 'ir.BaseIR'): - root_frame = StackFrame(0, ({}, {}, {}), root) + root_frame = AnalysisStackFrame(0, ({}, {}, {}), root) stack = [root_frame] binding_sites = {} @@ -522,4 +515,4 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): def __call__(self, root: 'ir.BaseIR') -> str: self.binding_sites = self.compute_new_bindings(root) - return self.loop(root) + return self.build_string(root)