Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialize ssa decorator #324

Merged
merged 5 commits into from
Jan 12, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
76 changes: 76 additions & 0 deletions magma/ast_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import inspect
import textwrap
import ast
import os
import types
import typing
import astor
import traceback


def get_ast(obj):
indented_program_txt = inspect.getsource(obj)
program_txt = textwrap.dedent(indented_program_txt)
return ast.parse(program_txt)

def get_func_ast(obj : types.FunctionType):
""" Implicitly strip ast.Module() surrounding the function """
return get_ast(obj).body[0]


def compile_function_to_file(tree : typing.Union[ast.Module, ast.FunctionDef],
func_name : str = None, defn_env : dict = {}):
if isinstance(tree, ast.FunctionDef):
if func_name is None:
func_name = tree.name
else:
if func_name != tree.name:
raise Exception("Passed in func_name that does not match the "
"function being compiled. Got"
f" func_name={func_name} expected"
f" tree.name={tree.name}")
elif isinstance(tree, ast.Module):
if func_name is None:
raise Exception("func_name required when passing in an ast.Module")
os.makedirs(".magma", exist_ok=True)
file_name = os.path.join(".magma", func_name + ".py")
with open(file_name, "w") as fp:
fp.write(astor.to_source(tree))
# exec(compile(tree, filename=file_name, mode="exec"), defn_env)
try:
exec(compile(astor.to_source(tree), filename=file_name, mode="exec"), defn_env)
except:
import sys
tb = traceback.format_exc()
print(tb)
raise Exception(f"Error occured when compiling and executing m.circuit.combinational function {func_name}, see above") from None
return defn_env[func_name]


def inspect_enclosing_env(fn):
"""
Traverses the current call stack to get the current locals and globals in
the environment.

Possible Improvements:
* Return a scope object that preserves the distinction between globals
and locals. This isn't currently required by the code using it, but
could be useful for other use cases.
* Maintain the stack hierarchy. Again, not currently used, but could be
useful.
"""
def wrapped(*args, **kwargs):
stack = inspect.stack()
enclosing_env = {}
for i in range(1, len(stack)):
enclosing_env.update(stack[i].frame.f_locals)
enclosing_env.update(stack[i].frame.f_globals)
return fn(enclosing_env, *args, **kwargs)
return wrapped


def filter_decorator(decorator : typing.Callable, decorator_list : typing.List[ast.AST], env : dict):
def _filter(node):
code = compile(ast.Expression(node), filename="<string>", mode="eval")
return eval(code, env) != decorator
return list(filter(_filter, decorator_list))
47 changes: 15 additions & 32 deletions magma/circuit_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import os
import traceback
from .debug import debug_info
import magma.ast_utils as ast_utils
import types


class CircuitDefinitionSyntaxError(Exception):
Expand All @@ -19,12 +21,6 @@ def m_dot(attr):
return ast.Attribute(ast.Name("m", ast.Load()), attr, ast.Load())


def get_ast(obj):
indented_program_txt = inspect.getsource(obj)
program_txt = textwrap.dedent(indented_program_txt)
return ast.parse(program_txt)


def report_transformer_warning(message, filename, lineno, line):
warning(f"\033[1m{make_relative(filename)}:{lineno}: {message}")
warning(line)
Expand Down Expand Up @@ -183,41 +179,27 @@ def visit_Return(self, node):
return ast.Assign([ast.Name("O", ast.Store())], node.value)




def combinational(fn):
stack = inspect.stack()
defn_env = {}
for i in range(1, len(stack)):
defn_env.update(stack[i].frame.f_locals)
defn_env.update(stack[i].frame.f_globals)
tree = get_ast(fn)
@ast_utils.inspect_enclosing_env
def combinational(defn_env : dict, fn : types.FunctionType):
tree = ast_utils.get_func_ast(fn)
tree = FunctionToCircuitDefTransformer().visit(tree)
tree = ast.fix_missing_locations(tree)
tree = IfTransformer(inspect.getsourcefile(fn), inspect.getsourcelines(fn)).visit(tree)
tree = ast.fix_missing_locations(tree)
# TODO: Only remove @m.circuit.combinational, there could be others
tree.body[0].decorator_list = []
tree.decorator_list = ast_utils.filter_decorator(
combinational, tree.decorator_list, defn_env)
if "mux" not in defn_env:
tree.body.insert(0, ast.parse("from mantle import mux").body[0])
tree = ast.Module([
ast.parse("import magma as m").body[0],
ast.parse("from mantle import mux").body[0],
tree
])
source = "\n"
for i, line in enumerate(astor.to_source(tree).splitlines()):
source += f" {i}: {line}\n"

debug(source)
os.makedirs(".magma", exist_ok=True)
file_name = os.path.join(".magma", fn.__name__ + ".py")
with open(file_name, "w") as fp:
fp.write(astor.to_source(tree))
# exec(compile(tree, filename=file_name, mode="exec"), defn_env)
try:
exec(compile(astor.to_source(tree), filename=file_name, mode="exec"), defn_env)
except:
import sys
tb = traceback.format_exc()
print(tb)
raise Exception(f"Error occured when compiling and executing m.circuit.combinational function {fn.__name__}, see above") from None
circuit_def = defn_env[fn.__name__]
circuit_def = ast_utils.compile_function_to_file(tree, fn.__name__, defn_env)
circuit_def.debug_info = debug_info(circuit_def.debug_info.filename,
circuit_def.debug_info.lineno,
inspect.getmodule(fn))
Expand All @@ -227,5 +209,6 @@ def func(*args, **kwargs):
return circuit_def()(*args, **kwargs)
func.__name__ = fn.__name__
func.__qualname__ = fn.__name__
func.circuit_definition = circuit_def
# Provide a mechanism for accessing the underlying circuit definition
setattr(func, "circuit_definition", circuit_def)
return func
1 change: 1 addition & 0 deletions magma/ssa/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from magma.ssa.ssa import ssa
71 changes: 71 additions & 0 deletions magma/ssa/ssa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import magma.ast_utils as ast_utils
import ast
import types
from collections import defaultdict
import inspect


def flatten(l : list):
"""
Non-recursive flatten that ignores non-list children
"""
flat = []
for item in l:
if not isinstance(item, list):
item = [item]
flat += item
return flat


class SSAVisitor(ast.NodeTransformer):
def __init__(self):
super().__init__()
self.var_counter = defaultdict(lambda : -1)

def visit_FunctionDef(self, node):
node.body = flatten([self.visit(s) for s in node.body])
return node

def visit_Name(self, node):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might want to rename inputs.

For example, this is handled incorrectly:

@ssa
def default(I: m.Bits(2), x_0: m.Bit) -> m.Bit:
    x = I[1]
    if x_0:
        x = I[0]
    return x

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to preserve interface names, I would suggest inserting input renamings as the first thing in the SSA-ified code, so the output might reasonably look like this:

def default(I: m.Bits(2), x_0: m.Bit) -> m.Bit:
    I_0, x_0_0 = I, x_0 # this is done in a single tuple assign to avoid things like def default(x_0: m.Bits(2), x_0_0: m.Bit), which would make x_0_0 = x_0, x_0_0_0 = x_0_0, which is also wrong
    # rest of ssa translated code here using new input names...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first example I posted generates this right now:

def default(I: m.Bits(2), x_0: m.Bit) ->m.Bit:
    x_0 = I[1]
    x_1 = I[0]
    x_2 = phi([x_0, x_1], x_0)
    return x_2

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's a test for the edge case I mentioned in the second comment:

@ssa
def default(x_0: m.Bits(2), x_0_0: m.Bit) -> m.Bit:
    x = x_0[1]
    if x_0_0:
        x = x_0[0]
    return x

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just going to piggyback on this and say it might be a bigger issue than you think, if this is intended to be used in conjunction with other passes, I wouldn't be surprised if there would be renaming of variables happening in other passes.

For example, creating a wrapper module - I might rename the wires on the internal one by appending _0 to the inputs if I need to modify the wires before feeding them into the internal module. Rigel did something like this. Combined with the ssa pass things would break in ways that would be pretty odd.

if isinstance(node.ctx, ast.Store):
self.var_counter[node.id] += 1
if isinstance(node.ctx, ast.Store) or node.id in self.var_counter:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if isinstance(node.ctx, ast.Store) or node.id in self.var_counter:

Redundant check on if it's a store, no?

node.id += f"_{self.var_counter[node.id]}"
return node

def visit_If(self, node):
Copy link
Collaborator

@hofstee hofstee Dec 20, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this (the visit_If function) might be made simpler if you kept track of two things, a dict with the counters for stores, and a dict mapping the variable to its last name for loads.

This way you would preserve the store dict across both branches in the if, and reset the load mapping dict for each branch. This would also eliminate the need to go through everything in the false_var_counter and remap variable store names.

The phi logic would still need to be roughly the same though, but you would have the two load mappings and then you would insert a phi for every entry where the last load name is different between the two dicts for a variable.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's another bug case I think:

def test_wat():
    @ssa
    def basic_if(I: m.Bit, S: m.Bit) -> m.Bit:
        x = I
        if S:
            x = x
        else:
            x = x
        return x

    print(inspect.getsource(basic_if))
    assert inspect.getsource(basic_if) == """\
def basic_if(I: m.Bits(2), S: m.Bit) ->m.Bit:
    x_0 = I
    x_1 = x_0
    x_2 = x_0
    x_3 = phi([x_2, x_1], S)
    return x_3
"""

Copy link
Collaborator

@hofstee hofstee Jan 9, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this phi be the other way around?

def test_default():
    @ssa
    def default(I: m.Bits(2), S: m.Bit) -> m.Bit:
        x = I[1]
        if S:
            x = I[0]
        return x

    assert inspect.getsource(default) == """\
def default(I: m.Bits(2), S: m.Bit) ->m.Bit:
    x_0 = I[1]
    x_1 = I[0]
    x_2 = phi([x_0, x_1], S)
    return x_2
"""

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nevermind on that last point.

false_var_counter = dict(self.var_counter)
test = self.visit(node.test)
result = flatten([self.visit(s) for s in node.body])
if node.orelse:
false_var_counter = dict(self.var_counter)
result += flatten([self.visit(s) for s in node.orelse])
for var, count in self.var_counter.items():
if var in false_var_counter and count != false_var_counter[var]:
phi_args = [
ast.Name(f"{var}_{count}", ast.Load()),
ast.Name(f"{var}_{false_var_counter[var]}", ast.Load())
]
if not node.orelse:
phi_args = [phi_args[1], phi_args[0]]
result.append(ast.Assign(
[ast.Name(f"{var}_{count + 1}", ast.Store())],
ast.Call(ast.Name("phi", ast.Load()), [
ast.List(phi_args, ast.Load()),
test
], [])))
self.var_counter[var] += 1
return result


@ast_utils.inspect_enclosing_env
def ssa(defn_env : dict, fn : types.FunctionType):
stack = inspect.stack()
defn_env = {}
for i in range(1, len(stack)):
defn_env.update(stack[i].frame.f_locals)
defn_env.update(stack[i].frame.f_globals)
tree = ast_utils.get_func_ast(fn)
tree.decorator_list = ast_utils.filter_decorator(ssa, tree.decorator_list, defn_env)
tree = SSAVisitor().visit(tree)
return ast_utils.compile_function_to_file(tree, defn_env=defn_env)
66 changes: 66 additions & 0 deletions tests/test_ssa/test_ssa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import magma as m
from magma.ssa import ssa
import inspect


def test_basic():
@ssa
def basic_if(I: m.Bits(2), S: m.Bit) -> m.Bit:
if S:
x = I[0]
else:
x = I[1]
return x

assert inspect.getsource(basic_if) == """\
def basic_if(I: m.Bits(2), S: m.Bit) ->m.Bit:
x_0 = I[0]
x_1 = I[1]
x_2 = phi([x_1, x_0], S)
return x_2
"""


def test_default():
@ssa
def default(I: m.Bits(2), S: m.Bit) -> m.Bit:
x = I[1]
if S:
x = I[0]
return x

assert inspect.getsource(default) == """\
def default(I: m.Bits(2), S: m.Bit) ->m.Bit:
x_0 = I[1]
x_1 = I[0]
x_2 = phi([x_0, x_1], S)
return x_2
"""


def test_nested():
@ssa
def nested(I: m.Bits(4), S: m.Bits(2)) -> m.Bit:
if S[0]:
if S[1]:
x = I[0]
else:
x = I[1]
else:
if S[1]:
x = I[2]
else:
x = I[3]
return x

assert inspect.getsource(nested) == """\
def nested(I: m.Bits(4), S: m.Bits(2)) ->m.Bit:
x_0 = I[0]
x_1 = I[1]
x_2 = phi([x_1, x_0], S[1])
x_3 = I[2]
x_4 = I[3]
x_5 = phi([x_4, x_3], S[1])
x_6 = phi([x_5, x_2], S[0])
return x_6
"""