Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor combinational to use SSA, initial implementation of sequential #354

Merged
merged 9 commits into from
Feb 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion magma/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,4 +572,5 @@ def wrapped(*args, **kwargs):
return result
return wrapped

from magma.circuit_def import combinational
from magma.syntax.combinational import combinational
from magma.syntax.sequential import sequential
2 changes: 1 addition & 1 deletion magma/ssa/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from magma.ssa.ssa import ssa
from magma.ssa.ssa import ssa, convert_tree_to_ssa
93 changes: 81 additions & 12 deletions magma/ssa/ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import types
from collections import defaultdict
import inspect
import astor


def flatten(l : list):
Expand All @@ -20,8 +21,11 @@ def flatten(l : list):
class SSAVisitor(ast.NodeTransformer):
def __init__(self):
super().__init__()
self.last_name = defaultdict(lambda : "")
self.var_counter = defaultdict(lambda : -1)
self.last_name = defaultdict(lambda: "")
self.var_counter = defaultdict(lambda: -1)
self.args = set()
self.cond_stack = []
self.return_values = []

def write_name(self, var):
self.var_counter[var] += 1
Expand All @@ -34,13 +38,18 @@ def visit_Assign(self, node):

def visit_FunctionDef(self, node):
for a in node.args.args:
self.args.add(a.arg)
self.last_name[a.arg] = f"{a.arg}_0"
a.arg = f"{a.arg}_0"
node.body = flatten([self.visit(s) for s in node.body])
return node

def visit_Name(self, node):
if node.id not in self.last_name:
self.last_name[node.id] = f"{node.id}_0"
if node.id not in self.args and isinstance(node.ctx, ast.Store):
self.last_name[node.id] = f"{node.id}_0"
else:
return node
if isinstance(node.ctx, ast.Store):
self.write_name(node.id)
node.id = f"{self.last_name[node.id]}"
Expand All @@ -50,17 +59,23 @@ def visit_If(self, node):
false_name = dict(self.last_name)

test = self.visit(node.test)
self.cond_stack.append(test)
result = flatten([self.visit(s) for s in node.body])
true_name = dict(self.last_name)

if node.orelse:
self.last_name = false_name
self.cond_stack[-1] = ast.UnaryOp(ast.Invert(),
self.cond_stack[-1])
result += flatten([self.visit(s) for s in node.orelse])
false_name = dict(self.last_name)

self.cond_stack.pop()

self.last_name = {**true_name, **false_name}
for var in self.last_name.keys():
if var in true_name and var in false_name and true_name[var] != false_name[var]:
if var in true_name and var in false_name and \
true_name[var] != false_name[var]:
phi_args = [
ast.Name(false_name[var], ast.Load()),
ast.Name(true_name[var], ast.Load())
Expand All @@ -75,15 +90,69 @@ def visit_If(self, node):
], [])))
return result

def visit_Return(self, node):
self.return_values.append(self.cond_stack)
node.value = self.visit(node.value)
return node


class TransformReturn(ast.NodeTransformer):
def __init__(self):
self.counter = -1

def visit_Return(self, node):
self.counter += 1
name = f"__magma_ssa_return_value_{self.counter}"
return ast.Assign([ast.Name(name, ast.Store())], node.value)


class MoveReturn(ast.NodeTransformer):
def visit_Return(self, node):
return ast.Assign(
[ast.Name(f"__magma_ssa_return_value", ast.Store())],
node.value
)


def convert_tree_to_ssa(tree: ast.AST, defn_env: dict):
tree.decorator_list = ast_utils.filter_decorator(ssa, tree.decorator_list,
defn_env)
# tree = MoveReturn().visit(tree)
# tree.body.append(
# ast.Return(ast.Name("__magma_ssa_return_value", ast.Load())))
ssa_visitor = SSAVisitor()
tree = ssa_visitor.visit(tree)
return_transformer = TransformReturn()
tree = return_transformer.visit(tree)
num_return_values = len(ssa_visitor.return_values)
for i in reversed(range(num_return_values)):
conds = ssa_visitor.return_values[i]
name = f"__magma_ssa_return_value_{num_return_values - i - 1}"
if i == num_return_values or not conds:
if isinstance(tree.returns, ast.Tuple):
tree.body.append(ast.Assign(
[ast.Tuple([ast.Name(f"O{i}", ast.Store())
for i in range(len(tree.returns.elts))], ast.Store())],
ast.Name(name, ast.Load())
))
else:
tree.body.append(ast.Assign([ast.Name("O", ast.Load)],
ast.Name(name, ast.Load())))
else:
prev_name = ssa_visitor.return_values[i + 1]
cond = conds[-1]
for c in conds[:-1]:
c = ast.BinOp(cond, ast.And(), c)
tree.body.append(ast.Call(ast.Name("phi", ast.Load()), [
ast.List([name, prev_name], ast.Load()),
cond
], []))
return tree


@ast_utils.inspect_enclosing_env
def ssa(defn_env : dict, fn : types.FunctionType):
stack = inspect.stack()
defn_env = {}
for i in range(1, len(stack)):
defn_env.update(stack[i].frame.f_locals)
defn_env.update(stack[i].frame.f_globals)
def ssa(defn_env: dict, fn: types.FunctionType):
tree = ast_utils.get_func_ast(fn)
tree.decorator_list = ast_utils.filter_decorator(ssa, tree.decorator_list, defn_env)
tree = SSAVisitor().visit(tree)
tree = convert_tree_to_ssa(tree, defn_env)
tree.body.append(ast.Return(ast.Name("O", ast.Load())))
return ast_utils.compile_function_to_file(tree, defn_env=defn_env)
1 change: 1 addition & 0 deletions magma/syntax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from magma.syntax.combinational import combinational
32 changes: 17 additions & 15 deletions magma/circuit_def.py → magma/syntax/combinational.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import textwrap
from collections import OrderedDict
from magma.logging import debug, warning, error
from .backend.util import make_relative
from magma.backend.util import make_relative
import astor
import os
import traceback
from .debug import debug_info
import magma.ast_utils as ast_utils
import types
from magma.debug import debug_info
from magma.ssa import convert_tree_to_ssa


class CircuitDefinitionSyntaxError(Exception):
Expand Down Expand Up @@ -72,7 +73,7 @@ def visit_If(self, node):
" taking the last value (first value is ignored)", self.filename, node.lineno + self.starting_line, self.lines[node.lineno])
orelse_seen.add(key)
seen[key].value = ast.Call(
ast.Name("mux", ast.Load()),
ast.Name("phi", ast.Load()),
[ast.List([stmt.value, seen[key].value],
ast.Load()), node.test],
[])
Expand All @@ -88,7 +89,7 @@ def visit_IfExp(self, node):
node.body = self.visit(node.body)
node.orelse = self.visit(node.orelse)
return ast.Call(
ast.Name("mux", ast.Load()),
ast.Name("phi", ast.Load()),
[ast.List([node.orelse, node.body],
ast.Load()), node.test],
[])
Expand Down Expand Up @@ -168,30 +169,31 @@ def visit_Name(self, node):
ast.Load())
return node

def visit_Return(self, node):
node.value = self.visit(node.value)
if isinstance(node.value, ast.Tuple):
return ast.Assign(
[ast.Tuple([ast.Name(f"O{i}", ast.Store())
for i in range(len(node.value.elts))], ast.Store())],
node.value
)
return ast.Assign([ast.Name("O", ast.Store())], node.value)
# def visit_Return(self, node):
# node.value = self.visit(node.value)
# if isinstance(node.value, ast.Tuple):
# return ast.Assign(
# [ast.Tuple([ast.Name(f"O{i}", ast.Store())
# for i in range(len(node.value.elts))], ast.Store())],
# node.value
# )
# return ast.Assign([ast.Name("O", ast.Store())], node.value)


@ast_utils.inspect_enclosing_env
def combinational(defn_env : dict, fn : types.FunctionType):
tree = ast_utils.get_func_ast(fn)
tree = convert_tree_to_ssa(tree, defn_env)
tree = FunctionToCircuitDefTransformer().visit(tree)
tree = ast.fix_missing_locations(tree)
tree = IfTransformer(inspect.getsourcefile(fn), inspect.getsourcelines(fn)).visit(tree)
tree = ast.fix_missing_locations(tree)
tree.decorator_list = ast_utils.filter_decorator(
combinational, tree.decorator_list, defn_env)
if "mux" not in defn_env:
if "phi" not in defn_env:
tree = ast.Module([
ast.parse("import magma as m").body[0],
ast.parse("from mantle import mux").body[0],
ast.parse("from mantle import mux as phi").body[0],
tree
])
source = "\n"
Expand Down