Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Named phi #426

Merged
merged 4 commits into from
Jul 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
39 changes: 37 additions & 2 deletions magma/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,44 @@ def wrapped(*args, **kwargs):
return fn(enclosing_env, *args, **kwargs)
return wrapped

class NameCollector(ast.NodeVisitor):
def __init__(self):
self.names = set()

def visit_Name(self, node: ast.Name):
self.names.add(node.id)

def visit_FunctionDef(self, node: ast.FunctionDef):
self.names.add(node.name)

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
self.names.add(node.name)

def visit_ClassDef(self, node: ast.ClassDef):
self.names.add(node.name)

def gen_free_name(tree: ast.AST, defn_env: dict, prefix: str = '__auto_name_'):
visitor = NameCollector()
visitor.visit(tree)
used_names = visitor.names | defn_env.keys()
f_str = prefix+'{}'
c = 0
name = f_str.format(c)
while name in used_names:
c += 1
name = f_str.format(c)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a timeout (e.g. if c >= 1000: raise Exception). Unlikely to come up and would incur a performance cost, but you never know?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As python ints don't overflow don't see the need to give up.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just realized we can assume that the space of names is finite, so we will eventually find a name.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(because the length of the program must be finite)


return name

def filter_decorator(decorator : typing.Callable, decorator_list : typing.List[ast.AST], env : dict):
def _filter(node):
code = compile(ast.Expression(node), filename="<string>", mode="eval")
return eval(code, env) != decorator
if isinstance(node, ast.Call):
expr = ast.Expression(node.func)
elif isinstance(node, ast.Name):
expr = ast.Expression(node)
else:
return True
code = compile(expr, filename="<string>", mode="eval")
e = eval(code, env)
return e != decorator
return list(filter(_filter, decorator_list))
41 changes: 29 additions & 12 deletions magma/ssa/ssa.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import magma.ast_utils as ast_utils
import ast
import types
import typing
from collections import defaultdict
import astor
import functools


def flatten(l: list):
Expand All @@ -18,13 +19,14 @@ def flatten(l: list):


class SSAVisitor(ast.NodeTransformer):
def __init__(self):
def __init__(self, phi_name):
super().__init__()
self.last_name = defaultdict(lambda: "")
self.var_counter = defaultdict(lambda: -1)
self.args = []
self.cond_stack = []
self.return_values = []
self.phi_name = phi_name

def write_name(self, var):
self.var_counter[var] += 1
Expand Down Expand Up @@ -87,7 +89,7 @@ def visit_If(self, node):
self.write_name(var)
result.append(ast.Assign(
[ast.Name(self.last_name[var], ast.Store())],
ast.Call(ast.Name("phi", ast.Load()), [
ast.Call(ast.Name(self.phi_name, ast.Load()), [
ast.List(phi_args, ast.Load()),
test
], [])))
Expand Down Expand Up @@ -117,13 +119,11 @@ def visit_Return(self, node):
)


def convert_tree_to_ssa(tree: ast.AST, defn_env: dict):
tree.decorator_list = ast_utils.filter_decorator(ssa, tree.decorator_list,
defn_env)
def convert_tree_to_ssa(tree: ast.AST, defn_env: dict, phi_name: str = "phi"):
# tree = MoveReturn().visit(tree)
# tree.body.append(
# ast.Return(ast.Name("__magma_ssa_return_value", ast.Load())))
ssa_visitor = SSAVisitor()
ssa_visitor = SSAVisitor(phi_name)
tree = ssa_visitor.visit(tree)
return_transformer = TransformReturn()
tree = return_transformer.visit(tree)
Expand All @@ -149,7 +149,7 @@ def convert_tree_to_ssa(tree: ast.AST, defn_env: dict):
for i in range(len(tree.returns.elts)):
tree.body.append(ast.Assign(
[ast.Name(f"O{i}", ast.Store())],
ast.Call(ast.Name("phi", ast.Load()), [
ast.Call(ast.Name(phi_name, ast.Load()), [
ast.List([
ast.Name(f"O{i}", ast.Load()),
ast.Subscript(ast.Name(name, ast.Load()),
Expand All @@ -161,16 +161,33 @@ def convert_tree_to_ssa(tree: ast.AST, defn_env: dict):
else:
tree.body.append(ast.Assign(
[ast.Name("O", ast.Store())],
ast.Call(ast.Name("phi", ast.Load()), [
ast.Call(ast.Name(phi_name, ast.Load()), [
ast.List([ast.Name("O", ast.Load()), ast.Name(name, ast.Load())],
ast.Load()), cond], []))
)
return tree, ssa_visitor.args


@ast_utils.inspect_enclosing_env
def ssa(defn_env: dict, fn: types.FunctionType):
def _ssa(defn_env: dict, phi: typing.Union[str, typing.Callable],
fn: typing.Callable):
tree = ast_utils.get_func_ast(fn)
tree, _ = convert_tree_to_ssa(tree, defn_env)
tree.decorator_list = ast_utils.filter_decorator(ssa,
tree.decorator_list,
defn_env)

if isinstance(phi, str):
phi_name = phi
else:
phi_name = ast_utils.gen_free_name(tree, defn_env)

tree, _ = convert_tree_to_ssa(tree, defn_env, phi_name=phi_name)

if not isinstance(phi, str):
defn_env[phi_name] = phi

tree.body.append(ast.Return(ast.Name("O", ast.Load())))
return ast_utils.compile_function_to_file(tree, defn_env=defn_env)

@ast_utils.inspect_enclosing_env
def ssa(defn_env: dict, phi: typing.Union[str, typing.Callable] = "phi"):
return functools.partial(_ssa, defn_env, phi)
76 changes: 71 additions & 5 deletions tests/test_ssa/test_ssa.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import magma as m
from magma.ssa import ssa
import inspect
import re


def test_basic():
@ssa
@ssa()
def basic_if(I: m.Bits[2], S: m.Bit) -> m.Bit:
if S:
x = I[0]
Expand All @@ -24,7 +25,7 @@ def basic_if(I_0: m.Bits[2], S_0: m.Bit) ->m.Bit:


def test_wat():
@ssa
@ssa()
def basic_if(I: m.Bit, S: m.Bit) -> m.Bit:
x = I
if S:
Expand All @@ -46,7 +47,7 @@ def basic_if(I_0: m.Bit, S_0: m.Bit) ->m.Bit:


def test_default():
@ssa
@ssa()
def default(I: m.Bits[2], S: m.Bit) -> m.Bit:
x = I[1]
if S:
Expand All @@ -65,7 +66,7 @@ def default(I_0: m.Bits[2], S_0: m.Bit) ->m.Bit:


def test_nested():
@ssa
@ssa()
def nested(I: m.Bits[4], S: m.Bits[2]) -> m.Bit:
if S[0]:
if S[1]:
Expand Down Expand Up @@ -94,7 +95,7 @@ def nested(I_0: m.Bits[4], S_0: m.Bits[2]) ->m.Bit:
"""

def test_weird():
@ssa
@ssa()
def default(I: m.Bits[2], x_0: m.Bit) -> m.Bit:
x = I[1]
if x_0:
Expand All @@ -111,5 +112,70 @@ def default(I_0: m.Bits[2], x_0_0: m.Bit) ->m.Bit:
return O
"""


def test_phi_name():
@ssa(phi='foo')
def basic_if(I: m.Bits[2], S: m.Bit) -> m.Bit:
if S:
x = I[0]
else:
x = I[1]
return x

assert inspect.getsource(basic_if) == """\
def basic_if(I_0: m.Bits[2], S_0: m.Bit) ->m.Bit:
x_0 = I_0[0]
x_1 = I_0[1]
x_2 = foo([x_1, x_0], S_0)
__magma_ssa_return_value_0 = x_2
O = __magma_ssa_return_value_0
return O
"""

def test_phi_custom():
def bar(args, select):
return 'bar'

@ssa(phi=bar)
def basic_if(I: m.Bits[2], S: m.Bit) -> m.Bit:
if S:
x = I[0]
else:
x = I[1]
return x

assert basic_if([0, 1], 0) == 'bar'

assert inspect.getsource(basic_if) == """\
def basic_if(I_0: m.Bits[2], S_0: m.Bit) ->m.Bit:
x_0 = I_0[0]
x_1 = I_0[1]
x_2 = __auto_name_0([x_1, x_0], S_0)
__magma_ssa_return_value_0 = x_2
O = __magma_ssa_return_value_0
return O
"""

def test_phi_lambda():
@ssa(phi=lambda args, s : args[s])
def basic_if(I: m.Bits[2], S: m.Bit) -> m.Bit:
if S:
x = I[0]
else:
x = I[1]
return x

assert basic_if([0, 1], 0) == 1

assert inspect.getsource(basic_if) == """\
def basic_if(I_0: m.Bits[2], S_0: m.Bit) ->m.Bit:
x_0 = I_0[0]
x_1 = I_0[1]
x_2 = __auto_name_0([x_1, x_0], S_0)
__magma_ssa_return_value_0 = x_2
O = __magma_ssa_return_value_0
return O
"""

# test_wat()
# test_weird()