Skip to content

Commit

Permalink
Merge 426179d into d2e3160
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardt committed Dec 14, 2018
2 parents d2e3160 + 426179d commit d6ef40c
Show file tree
Hide file tree
Showing 5 changed files with 229 additions and 32 deletions.
76 changes: 76 additions & 0 deletions magma/ast_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import inspect
import textwrap
import ast
import os
import types
import typing
import astor
import traceback


def get_ast(obj):
indented_program_txt = inspect.getsource(obj)
program_txt = textwrap.dedent(indented_program_txt)
return ast.parse(program_txt)

def get_func_ast(obj : types.FunctionType):
""" Implicitly strip ast.Module() surrounding the function """
return get_ast(obj).body[0]


def compile_function_to_file(tree : typing.Union[ast.Module, ast.FunctionDef],
func_name : str = None, defn_env : dict = {}):
if isinstance(tree, ast.FunctionDef):
if func_name is None:
func_name = tree.name
else:
if func_name != tree.name:
raise Exception("Passed in func_name that does not match the "
"function being compiled. Got"
f" func_name={func_name} expected"
f" tree.name={tree.name}")
elif isinstance(tree, ast.Module):
if func_name is None:
raise Exception("func_name required when passing in an ast.Module")
os.makedirs(".magma", exist_ok=True)
file_name = os.path.join(".magma", func_name + ".py")
with open(file_name, "w") as fp:
fp.write(astor.to_source(tree))
# exec(compile(tree, filename=file_name, mode="exec"), defn_env)
try:
exec(compile(astor.to_source(tree), filename=file_name, mode="exec"), defn_env)
except:
import sys
tb = traceback.format_exc()
print(tb)
raise Exception(f"Error occured when compiling and executing m.circuit.combinational function {func_name}, see above") from None
return defn_env[func_name]


def inspect_enclosing_env(fn):
"""
Traverses the current call stack to get the current locals and globals in
the environment.
Possible Improvements:
* Return a scope object that preserves the distinction between globals
and locals. This isn't currently required by the code using it, but
could be useful for other use cases.
* Maintain the stack hierarchy. Again, not currently used, but could be
useful.
"""
def wrapped(*args, **kwargs):
stack = inspect.stack()
enclosing_env = {}
for i in range(1, len(stack)):
enclosing_env.update(stack[i].frame.f_locals)
enclosing_env.update(stack[i].frame.f_globals)
return fn(enclosing_env, *args, **kwargs)
return wrapped


def filter_decorator(decorator : typing.Callable, decorator_list : typing.List[ast.AST], env : dict):
def _filter(node):
code = compile(ast.Expression(node), filename="<string>", mode="eval")
return eval(code, env) != decorator
return list(filter(_filter, decorator_list))
47 changes: 15 additions & 32 deletions magma/circuit_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import os
import traceback
from .debug import debug_info
import magma.ast_utils as ast_utils
import types


class CircuitDefinitionSyntaxError(Exception):
Expand All @@ -19,12 +21,6 @@ def m_dot(attr):
return ast.Attribute(ast.Name("m", ast.Load()), attr, ast.Load())


def get_ast(obj):
indented_program_txt = inspect.getsource(obj)
program_txt = textwrap.dedent(indented_program_txt)
return ast.parse(program_txt)


def report_transformer_warning(message, filename, lineno, line):
warning(f"\033[1m{make_relative(filename)}:{lineno}: {message}")
warning(line)
Expand Down Expand Up @@ -183,41 +179,27 @@ def visit_Return(self, node):
return ast.Assign([ast.Name("O", ast.Store())], node.value)




def combinational(fn):
stack = inspect.stack()
defn_env = {}
for i in range(1, len(stack)):
defn_env.update(stack[i].frame.f_locals)
defn_env.update(stack[i].frame.f_globals)
tree = get_ast(fn)
@ast_utils.inspect_enclosing_env
def combinational(defn_env : dict, fn : types.FunctionType):
tree = ast_utils.get_func_ast(fn)
tree = FunctionToCircuitDefTransformer().visit(tree)
tree = ast.fix_missing_locations(tree)
tree = IfTransformer(inspect.getsourcefile(fn), inspect.getsourcelines(fn)).visit(tree)
tree = ast.fix_missing_locations(tree)
# TODO: Only remove @m.circuit.combinational, there could be others
tree.body[0].decorator_list = []
tree.decorator_list = ast_utils.filter_decorator(
combinational, tree.decorator_list, defn_env)
if "mux" not in defn_env:
tree.body.insert(0, ast.parse("from mantle import mux").body[0])
tree = ast.Module([
ast.parse("import magma as m").body[0],
ast.parse("from mantle import mux").body[0],
tree
])
source = "\n"
for i, line in enumerate(astor.to_source(tree).splitlines()):
source += f" {i}: {line}\n"

debug(source)
os.makedirs(".magma", exist_ok=True)
file_name = os.path.join(".magma", fn.__name__ + ".py")
with open(file_name, "w") as fp:
fp.write(astor.to_source(tree))
# exec(compile(tree, filename=file_name, mode="exec"), defn_env)
try:
exec(compile(astor.to_source(tree), filename=file_name, mode="exec"), defn_env)
except:
import sys
tb = traceback.format_exc()
print(tb)
raise Exception(f"Error occured when compiling and executing m.circuit.combinational function {fn.__name__}, see above") from None
circuit_def = defn_env[fn.__name__]
circuit_def = ast_utils.compile_function_to_file(tree, fn.__name__, defn_env)
circuit_def.debug_info = debug_info(circuit_def.debug_info.filename,
circuit_def.debug_info.lineno,
inspect.getmodule(fn))
Expand All @@ -227,5 +209,6 @@ def func(*args, **kwargs):
return circuit_def()(*args, **kwargs)
func.__name__ = fn.__name__
func.__qualname__ = fn.__name__
func.circuit_definition = circuit_def
# Provide a mechanism for accessing the underlying circuit definition
setattr(func, "circuit_definition", circuit_def)
return func
1 change: 1 addition & 0 deletions magma/ssa/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from magma.ssa.ssa import ssa
71 changes: 71 additions & 0 deletions magma/ssa/ssa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import magma.ast_utils as ast_utils
import ast
import types
from collections import defaultdict
import inspect


def flatten(l : list):
"""
Non-recursive flatten that ignores non-list children
"""
flat = []
for item in l:
if not isinstance(item, list):
item = [item]
flat += item
return flat


class SSAVisitor(ast.NodeTransformer):
def __init__(self):
super().__init__()
self.var_counter = defaultdict(lambda : -1)

def visit_FunctionDef(self, node):
node.body = flatten([self.visit(s) for s in node.body])
return node

def visit_Name(self, node):
if isinstance(node.ctx, ast.Store):
self.var_counter[node.id] += 1
if isinstance(node.ctx, ast.Store) or node.id in self.var_counter:
node.id += f"_{self.var_counter[node.id]}"
return node

def visit_If(self, node):
false_var_counter = dict(self.var_counter)
test = self.visit(node.test)
result = flatten([self.visit(s) for s in node.body])
if node.orelse:
false_var_counter = dict(self.var_counter)
result += flatten([self.visit(s) for s in node.orelse])
for var, count in self.var_counter.items():
if var in false_var_counter and count != false_var_counter[var]:
phi_args = [
ast.Name(f"{var}_{count}", ast.Load()),
ast.Name(f"{var}_{false_var_counter[var]}", ast.Load())
]
if not node.orelse:
phi_args = [phi_args[1], phi_args[0]]
result.append(ast.Assign(
[ast.Name(f"{var}_{count + 1}", ast.Store())],
ast.Call(ast.Name("phi", ast.Load()), [
ast.List(phi_args, ast.Load()),
test
], [])))
self.var_counter[var] += 1
return result


@ast_utils.inspect_enclosing_env
def ssa(defn_env : dict, fn : types.FunctionType):
stack = inspect.stack()
defn_env = {}
for i in range(1, len(stack)):
defn_env.update(stack[i].frame.f_locals)
defn_env.update(stack[i].frame.f_globals)
tree = ast_utils.get_func_ast(fn)
tree.decorator_list = ast_utils.filter_decorator(ssa, tree.decorator_list, defn_env)
tree = SSAVisitor().visit(tree)
return ast_utils.compile_function_to_file(tree, defn_env=defn_env)
66 changes: 66 additions & 0 deletions tests/test_ssa/test_ssa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import magma as m
from magma.ssa import ssa
import inspect


def test_basic():
@ssa
def basic_if(I: m.Bits(2), S: m.Bit) -> m.Bit:
if S:
x = I[0]
else:
x = I[1]
return x

assert inspect.getsource(basic_if) == """\
def basic_if(I: m.Bits(2), S: m.Bit) ->m.Bit:
x_0 = I[0]
x_1 = I[1]
x_2 = phi([x_1, x_0], S)
return x_2
"""


def test_default():
@ssa
def default(I: m.Bits(2), S: m.Bit) -> m.Bit:
x = I[1]
if S:
x = I[0]
return x

assert inspect.getsource(default) == """\
def default(I: m.Bits(2), S: m.Bit) ->m.Bit:
x_0 = I[1]
x_1 = I[0]
x_2 = phi([x_0, x_1], S)
return x_2
"""


def test_nested():
@ssa
def nested(I: m.Bits(4), S: m.Bits(2)) -> m.Bit:
if S[0]:
if S[1]:
x = I[0]
else:
x = I[1]
else:
if S[1]:
x = I[2]
else:
x = I[3]
return x

assert inspect.getsource(nested) == """\
def nested(I: m.Bits(4), S: m.Bits(2)) ->m.Bit:
x_0 = I[0]
x_1 = I[1]
x_2 = phi([x_1, x_0], S[1])
x_3 = I[2]
x_4 = I[3]
x_5 = phi([x_4, x_3], S[1])
x_6 = phi([x_5, x_2], S[0])
return x_6
"""

0 comments on commit d6ef40c

Please sign in to comment.