Skip to content

Commit

Permalink
Merge pull request #354 from phanrahan/sequential
Browse files Browse the repository at this point in the history
Refactor combinational to use SSA, initial implementation of sequential
  • Loading branch information
leonardt committed Feb 21, 2019
2 parents 5067d72 + f5e8fea commit f366ca2
Show file tree
Hide file tree
Showing 51 changed files with 914 additions and 326 deletions.
3 changes: 2 additions & 1 deletion magma/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,4 +569,5 @@ def wrapped(*args, **kwargs):
return result
return wrapped

from magma.circuit_def import combinational
from magma.syntax.combinational import combinational
from magma.syntax.sequential import sequential
2 changes: 1 addition & 1 deletion magma/ssa/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from magma.ssa.ssa import ssa
from magma.ssa.ssa import ssa, convert_tree_to_ssa
93 changes: 81 additions & 12 deletions magma/ssa/ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import types
from collections import defaultdict
import inspect
import astor


def flatten(l : list):
Expand All @@ -20,8 +21,11 @@ def flatten(l : list):
class SSAVisitor(ast.NodeTransformer):
def __init__(self):
super().__init__()
self.last_name = defaultdict(lambda : "")
self.var_counter = defaultdict(lambda : -1)
self.last_name = defaultdict(lambda: "")
self.var_counter = defaultdict(lambda: -1)
self.args = set()
self.cond_stack = []
self.return_values = []

def write_name(self, var):
self.var_counter[var] += 1
Expand All @@ -34,13 +38,18 @@ def visit_Assign(self, node):

def visit_FunctionDef(self, node):
for a in node.args.args:
self.args.add(a.arg)
self.last_name[a.arg] = f"{a.arg}_0"
a.arg = f"{a.arg}_0"
node.body = flatten([self.visit(s) for s in node.body])
return node

def visit_Name(self, node):
if node.id not in self.last_name:
self.last_name[node.id] = f"{node.id}_0"
if node.id not in self.args and isinstance(node.ctx, ast.Store):
self.last_name[node.id] = f"{node.id}_0"
else:
return node
if isinstance(node.ctx, ast.Store):
self.write_name(node.id)
node.id = f"{self.last_name[node.id]}"
Expand All @@ -50,17 +59,23 @@ def visit_If(self, node):
false_name = dict(self.last_name)

test = self.visit(node.test)
self.cond_stack.append(test)
result = flatten([self.visit(s) for s in node.body])
true_name = dict(self.last_name)

if node.orelse:
self.last_name = false_name
self.cond_stack[-1] = ast.UnaryOp(ast.Invert(),
self.cond_stack[-1])
result += flatten([self.visit(s) for s in node.orelse])
false_name = dict(self.last_name)

self.cond_stack.pop()

self.last_name = {**true_name, **false_name}
for var in self.last_name.keys():
if var in true_name and var in false_name and true_name[var] != false_name[var]:
if var in true_name and var in false_name and \
true_name[var] != false_name[var]:
phi_args = [
ast.Name(false_name[var], ast.Load()),
ast.Name(true_name[var], ast.Load())
Expand All @@ -75,15 +90,69 @@ def visit_If(self, node):
], [])))
return result

def visit_Return(self, node):
self.return_values.append(self.cond_stack)
node.value = self.visit(node.value)
return node


class TransformReturn(ast.NodeTransformer):
def __init__(self):
self.counter = -1

def visit_Return(self, node):
self.counter += 1
name = f"__magma_ssa_return_value_{self.counter}"
return ast.Assign([ast.Name(name, ast.Store())], node.value)


class MoveReturn(ast.NodeTransformer):
def visit_Return(self, node):
return ast.Assign(
[ast.Name(f"__magma_ssa_return_value", ast.Store())],
node.value
)


def convert_tree_to_ssa(tree: ast.AST, defn_env: dict):
tree.decorator_list = ast_utils.filter_decorator(ssa, tree.decorator_list,
defn_env)
# tree = MoveReturn().visit(tree)
# tree.body.append(
# ast.Return(ast.Name("__magma_ssa_return_value", ast.Load())))
ssa_visitor = SSAVisitor()
tree = ssa_visitor.visit(tree)
return_transformer = TransformReturn()
tree = return_transformer.visit(tree)
num_return_values = len(ssa_visitor.return_values)
for i in reversed(range(num_return_values)):
conds = ssa_visitor.return_values[i]
name = f"__magma_ssa_return_value_{num_return_values - i - 1}"
if i == num_return_values or not conds:
if isinstance(tree.returns, ast.Tuple):
tree.body.append(ast.Assign(
[ast.Tuple([ast.Name(f"O{i}", ast.Store())
for i in range(len(tree.returns.elts))], ast.Store())],
ast.Name(name, ast.Load())
))
else:
tree.body.append(ast.Assign([ast.Name("O", ast.Load)],
ast.Name(name, ast.Load())))
else:
prev_name = ssa_visitor.return_values[i + 1]
cond = conds[-1]
for c in conds[:-1]:
c = ast.BinOp(cond, ast.And(), c)
tree.body.append(ast.Call(ast.Name("phi", ast.Load()), [
ast.List([name, prev_name], ast.Load()),
cond
], []))
return tree


@ast_utils.inspect_enclosing_env
def ssa(defn_env : dict, fn : types.FunctionType):
stack = inspect.stack()
defn_env = {}
for i in range(1, len(stack)):
defn_env.update(stack[i].frame.f_locals)
defn_env.update(stack[i].frame.f_globals)
def ssa(defn_env: dict, fn: types.FunctionType):
tree = ast_utils.get_func_ast(fn)
tree.decorator_list = ast_utils.filter_decorator(ssa, tree.decorator_list, defn_env)
tree = SSAVisitor().visit(tree)
tree = convert_tree_to_ssa(tree, defn_env)
tree.body.append(ast.Return(ast.Name("O", ast.Load())))
return ast_utils.compile_function_to_file(tree, defn_env=defn_env)
1 change: 1 addition & 0 deletions magma/syntax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from magma.syntax.combinational import combinational
32 changes: 17 additions & 15 deletions magma/circuit_def.py → magma/syntax/combinational.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import textwrap
from collections import OrderedDict
from magma.logging import debug, warning, error
from .backend.util import make_relative
from magma.backend.util import make_relative
import astor
import os
import traceback
from .debug import debug_info
import magma.ast_utils as ast_utils
import types
from magma.debug import debug_info
from magma.ssa import convert_tree_to_ssa


class CircuitDefinitionSyntaxError(Exception):
Expand Down Expand Up @@ -72,7 +73,7 @@ def visit_If(self, node):
" taking the last value (first value is ignored)", self.filename, node.lineno + self.starting_line, self.lines[node.lineno])
orelse_seen.add(key)
seen[key].value = ast.Call(
ast.Name("mux", ast.Load()),
ast.Name("phi", ast.Load()),
[ast.List([stmt.value, seen[key].value],
ast.Load()), node.test],
[])
Expand All @@ -88,7 +89,7 @@ def visit_IfExp(self, node):
node.body = self.visit(node.body)
node.orelse = self.visit(node.orelse)
return ast.Call(
ast.Name("mux", ast.Load()),
ast.Name("phi", ast.Load()),
[ast.List([node.orelse, node.body],
ast.Load()), node.test],
[])
Expand Down Expand Up @@ -168,30 +169,31 @@ def visit_Name(self, node):
ast.Load())
return node

def visit_Return(self, node):
node.value = self.visit(node.value)
if isinstance(node.value, ast.Tuple):
return ast.Assign(
[ast.Tuple([ast.Name(f"O{i}", ast.Store())
for i in range(len(node.value.elts))], ast.Store())],
node.value
)
return ast.Assign([ast.Name("O", ast.Store())], node.value)
# def visit_Return(self, node):
# node.value = self.visit(node.value)
# if isinstance(node.value, ast.Tuple):
# return ast.Assign(
# [ast.Tuple([ast.Name(f"O{i}", ast.Store())
# for i in range(len(node.value.elts))], ast.Store())],
# node.value
# )
# return ast.Assign([ast.Name("O", ast.Store())], node.value)


@ast_utils.inspect_enclosing_env
def combinational(defn_env : dict, fn : types.FunctionType):
tree = ast_utils.get_func_ast(fn)
tree = convert_tree_to_ssa(tree, defn_env)
tree = FunctionToCircuitDefTransformer().visit(tree)
tree = ast.fix_missing_locations(tree)
tree = IfTransformer(inspect.getsourcefile(fn), inspect.getsourcelines(fn)).visit(tree)
tree = ast.fix_missing_locations(tree)
tree.decorator_list = ast_utils.filter_decorator(
combinational, tree.decorator_list, defn_env)
if "mux" not in defn_env:
if "phi" not in defn_env:
tree = ast.Module([
ast.parse("import magma as m").body[0],
ast.parse("from mantle import mux").body[0],
ast.parse("from mantle import mux as phi").body[0],
tree
])
source = "\n"
Expand Down

0 comments on commit f366ca2

Please sign in to comment.