Skip to content

Commit

Permalink
Merge c3c9083 into d56ce62
Browse files Browse the repository at this point in the history
  • Loading branch information
cdonovick committed Jul 23, 2019
2 parents d56ce62 + c3c9083 commit e3ead7d
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 19 deletions.
39 changes: 37 additions & 2 deletions magma/ast_utils.py
Expand Up @@ -73,9 +73,44 @@ def wrapped(*args, **kwargs):
return fn(enclosing_env, *args, **kwargs)
return wrapped

class NameCollector(ast.NodeVisitor):
def __init__(self):
self.names = set()

def visit_Name(self, node: ast.Name):
self.names.add(node.id)

def visit_FunctionDef(self, node: ast.FunctionDef):
self.names.add(node.name)

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
self.names.add(node.name)

def visit_ClassDef(self, node: ast.ClassDef):
self.names.add(node.name)

def gen_free_name(tree: ast.AST, defn_env: dict, prefix: str = '__auto_name_'):
visitor = NameCollector()
visitor.visit(tree)
used_names = visitor.names | defn_env.keys()
f_str = prefix+'{}'
c = 0
name = f_str.format(c)
while name in used_names:
c += 1
name = f_str.format(c)

return name

def filter_decorator(decorator : typing.Callable, decorator_list : typing.List[ast.AST], env : dict):
def _filter(node):
code = compile(ast.Expression(node), filename="<string>", mode="eval")
return eval(code, env) != decorator
if isinstance(node, ast.Call):
expr = ast.Expression(node.func)
elif isinstance(node, ast.Name):
expr = ast.Expression(node)
else:
return True
code = compile(expr, filename="<string>", mode="eval")
e = eval(code, env)
return e != decorator
return list(filter(_filter, decorator_list))
41 changes: 29 additions & 12 deletions magma/ssa/ssa.py
@@ -1,8 +1,9 @@
import magma.ast_utils as ast_utils
import ast
import types
import typing
from collections import defaultdict
import astor
import functools


def flatten(l: list):
Expand All @@ -18,13 +19,14 @@ def flatten(l: list):


class SSAVisitor(ast.NodeTransformer):
def __init__(self):
def __init__(self, phi_name):
super().__init__()
self.last_name = defaultdict(lambda: "")
self.var_counter = defaultdict(lambda: -1)
self.args = []
self.cond_stack = []
self.return_values = []
self.phi_name = phi_name

def write_name(self, var):
self.var_counter[var] += 1
Expand Down Expand Up @@ -87,7 +89,7 @@ def visit_If(self, node):
self.write_name(var)
result.append(ast.Assign(
[ast.Name(self.last_name[var], ast.Store())],
ast.Call(ast.Name("phi", ast.Load()), [
ast.Call(ast.Name(self.phi_name, ast.Load()), [
ast.List(phi_args, ast.Load()),
test
], [])))
Expand Down Expand Up @@ -117,13 +119,11 @@ def visit_Return(self, node):
)


def convert_tree_to_ssa(tree: ast.AST, defn_env: dict):
tree.decorator_list = ast_utils.filter_decorator(ssa, tree.decorator_list,
defn_env)
def convert_tree_to_ssa(tree: ast.AST, defn_env: dict, phi_name: str = "phi"):
# tree = MoveReturn().visit(tree)
# tree.body.append(
# ast.Return(ast.Name("__magma_ssa_return_value", ast.Load())))
ssa_visitor = SSAVisitor()
ssa_visitor = SSAVisitor(phi_name)
tree = ssa_visitor.visit(tree)
return_transformer = TransformReturn()
tree = return_transformer.visit(tree)
Expand All @@ -149,7 +149,7 @@ def convert_tree_to_ssa(tree: ast.AST, defn_env: dict):
for i in range(len(tree.returns.elts)):
tree.body.append(ast.Assign(
[ast.Name(f"O{i}", ast.Store())],
ast.Call(ast.Name("phi", ast.Load()), [
ast.Call(ast.Name(phi_name, ast.Load()), [
ast.List([
ast.Name(f"O{i}", ast.Load()),
ast.Subscript(ast.Name(name, ast.Load()),
Expand All @@ -161,16 +161,33 @@ def convert_tree_to_ssa(tree: ast.AST, defn_env: dict):
else:
tree.body.append(ast.Assign(
[ast.Name("O", ast.Store())],
ast.Call(ast.Name("phi", ast.Load()), [
ast.Call(ast.Name(phi_name, ast.Load()), [
ast.List([ast.Name("O", ast.Load()), ast.Name(name, ast.Load())],
ast.Load()), cond], []))
)
return tree, ssa_visitor.args


@ast_utils.inspect_enclosing_env
def ssa(defn_env: dict, fn: types.FunctionType):
def _ssa(defn_env: dict, phi: typing.Union[str, typing.Callable],
fn: typing.Callable):
tree = ast_utils.get_func_ast(fn)
tree, _ = convert_tree_to_ssa(tree, defn_env)
tree.decorator_list = ast_utils.filter_decorator(ssa,
tree.decorator_list,
defn_env)

if isinstance(phi, str):
phi_name = phi
else:
phi_name = ast_utils.gen_free_name(tree, defn_env)

tree, _ = convert_tree_to_ssa(tree, defn_env, phi_name=phi_name)

if not isinstance(phi, str):
defn_env[phi_name] = phi

tree.body.append(ast.Return(ast.Name("O", ast.Load())))
return ast_utils.compile_function_to_file(tree, defn_env=defn_env)

@ast_utils.inspect_enclosing_env
def ssa(defn_env: dict, phi: typing.Union[str, typing.Callable] = "phi"):
return functools.partial(_ssa, defn_env, phi)
76 changes: 71 additions & 5 deletions tests/test_ssa/test_ssa.py
@@ -1,10 +1,11 @@
import magma as m
from magma.ssa import ssa
import inspect
import re


def test_basic():
@ssa
@ssa()
def basic_if(I: m.Bits[2], S: m.Bit) -> m.Bit:
if S:
x = I[0]
Expand All @@ -24,7 +25,7 @@ def basic_if(I_0: m.Bits[2], S_0: m.Bit) ->m.Bit:


def test_wat():
@ssa
@ssa()
def basic_if(I: m.Bit, S: m.Bit) -> m.Bit:
x = I
if S:
Expand All @@ -46,7 +47,7 @@ def basic_if(I_0: m.Bit, S_0: m.Bit) ->m.Bit:


def test_default():
@ssa
@ssa()
def default(I: m.Bits[2], S: m.Bit) -> m.Bit:
x = I[1]
if S:
Expand All @@ -65,7 +66,7 @@ def default(I_0: m.Bits[2], S_0: m.Bit) ->m.Bit:


def test_nested():
@ssa
@ssa()
def nested(I: m.Bits[4], S: m.Bits[2]) -> m.Bit:
if S[0]:
if S[1]:
Expand Down Expand Up @@ -94,7 +95,7 @@ def nested(I_0: m.Bits[4], S_0: m.Bits[2]) ->m.Bit:
"""

def test_weird():
@ssa
@ssa()
def default(I: m.Bits[2], x_0: m.Bit) -> m.Bit:
x = I[1]
if x_0:
Expand All @@ -111,5 +112,70 @@ def default(I_0: m.Bits[2], x_0_0: m.Bit) ->m.Bit:
return O
"""


def test_phi_name():
@ssa(phi='foo')
def basic_if(I: m.Bits[2], S: m.Bit) -> m.Bit:
if S:
x = I[0]
else:
x = I[1]
return x

assert inspect.getsource(basic_if) == """\
def basic_if(I_0: m.Bits[2], S_0: m.Bit) ->m.Bit:
x_0 = I_0[0]
x_1 = I_0[1]
x_2 = foo([x_1, x_0], S_0)
__magma_ssa_return_value_0 = x_2
O = __magma_ssa_return_value_0
return O
"""

def test_phi_custom():
def bar(args, select):
return 'bar'

@ssa(phi=bar)
def basic_if(I: m.Bits[2], S: m.Bit) -> m.Bit:
if S:
x = I[0]
else:
x = I[1]
return x

assert basic_if([0, 1], 0) == 'bar'

assert inspect.getsource(basic_if) == """\
def basic_if(I_0: m.Bits[2], S_0: m.Bit) ->m.Bit:
x_0 = I_0[0]
x_1 = I_0[1]
x_2 = __auto_name_0([x_1, x_0], S_0)
__magma_ssa_return_value_0 = x_2
O = __magma_ssa_return_value_0
return O
"""

def test_phi_lambda():
@ssa(phi=lambda args, s : args[s])
def basic_if(I: m.Bits[2], S: m.Bit) -> m.Bit:
if S:
x = I[0]
else:
x = I[1]
return x

assert basic_if([0, 1], 0) == 1

assert inspect.getsource(basic_if) == """\
def basic_if(I_0: m.Bits[2], S_0: m.Bit) ->m.Bit:
x_0 = I_0[0]
x_1 = I_0[1]
x_2 = __auto_name_0([x_1, x_0], S_0)
__magma_ssa_return_value_0 = x_2
O = __magma_ssa_return_value_0
return O
"""

# test_wat()
# test_weird()

0 comments on commit e3ead7d

Please sign in to comment.