-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
229 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import inspect | ||
import textwrap | ||
import ast | ||
import os | ||
import types | ||
import typing | ||
import astor | ||
import traceback | ||
|
||
|
||
def get_ast(obj): | ||
indented_program_txt = inspect.getsource(obj) | ||
program_txt = textwrap.dedent(indented_program_txt) | ||
return ast.parse(program_txt) | ||
|
||
def get_func_ast(obj : types.FunctionType): | ||
""" Implicitly strip ast.Module() surrounding the function """ | ||
return get_ast(obj).body[0] | ||
|
||
|
||
def compile_function_to_file(tree : typing.Union[ast.Module, ast.FunctionDef], | ||
func_name : str = None, defn_env : dict = {}): | ||
if isinstance(tree, ast.FunctionDef): | ||
if func_name is None: | ||
func_name = tree.name | ||
else: | ||
if func_name != tree.name: | ||
raise Exception("Passed in func_name that does not match the " | ||
"function being compiled. Got" | ||
f" func_name={func_name} expected" | ||
f" tree.name={tree.name}") | ||
elif isinstance(tree, ast.Module): | ||
if func_name is None: | ||
raise Exception("func_name required when passing in an ast.Module") | ||
os.makedirs(".magma", exist_ok=True) | ||
file_name = os.path.join(".magma", func_name + ".py") | ||
with open(file_name, "w") as fp: | ||
fp.write(astor.to_source(tree)) | ||
# exec(compile(tree, filename=file_name, mode="exec"), defn_env) | ||
try: | ||
exec(compile(astor.to_source(tree), filename=file_name, mode="exec"), defn_env) | ||
except: | ||
import sys | ||
tb = traceback.format_exc() | ||
print(tb) | ||
raise Exception(f"Error occured when compiling and executing m.circuit.combinational function {func_name}, see above") from None | ||
return defn_env[func_name] | ||
|
||
|
||
def inspect_enclosing_env(fn): | ||
""" | ||
Traverses the current call stack to get the current locals and globals in | ||
the environment. | ||
Possible Improvements: | ||
* Return a scope object that preserves the distinction between globals | ||
and locals. This isn't currently required by the code using it, but | ||
could be useful for other use cases. | ||
* Maintain the stack hierarchy. Again, not currently used, but could be | ||
useful. | ||
""" | ||
def wrapped(*args, **kwargs): | ||
stack = inspect.stack() | ||
enclosing_env = {} | ||
for i in range(1, len(stack)): | ||
enclosing_env.update(stack[i].frame.f_locals) | ||
enclosing_env.update(stack[i].frame.f_globals) | ||
return fn(enclosing_env, *args, **kwargs) | ||
return wrapped | ||
|
||
|
||
def filter_decorator(decorator : typing.Callable, decorator_list : typing.List[ast.AST], env : dict): | ||
def _filter(node): | ||
code = compile(ast.Expression(node), filename="<string>", mode="eval") | ||
return eval(code, env) != decorator | ||
return list(filter(_filter, decorator_list)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from magma.ssa.ssa import ssa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import magma.ast_utils as ast_utils | ||
import ast | ||
import types | ||
from collections import defaultdict | ||
import inspect | ||
|
||
|
||
def flatten(l : list): | ||
""" | ||
Non-recursive flatten that ignores non-list children | ||
""" | ||
flat = [] | ||
for item in l: | ||
if not isinstance(item, list): | ||
item = [item] | ||
flat += item | ||
return flat | ||
|
||
|
||
class SSAVisitor(ast.NodeTransformer): | ||
def __init__(self): | ||
super().__init__() | ||
self.var_counter = defaultdict(lambda : -1) | ||
|
||
def visit_FunctionDef(self, node): | ||
node.body = flatten([self.visit(s) for s in node.body]) | ||
return node | ||
|
||
def visit_Name(self, node): | ||
if isinstance(node.ctx, ast.Store): | ||
self.var_counter[node.id] += 1 | ||
if isinstance(node.ctx, ast.Store) or node.id in self.var_counter: | ||
node.id += f"_{self.var_counter[node.id]}" | ||
return node | ||
|
||
def visit_If(self, node): | ||
false_var_counter = dict(self.var_counter) | ||
test = self.visit(node.test) | ||
result = flatten([self.visit(s) for s in node.body]) | ||
if node.orelse: | ||
false_var_counter = dict(self.var_counter) | ||
result += flatten([self.visit(s) for s in node.orelse]) | ||
for var, count in self.var_counter.items(): | ||
if var in false_var_counter and count != false_var_counter[var]: | ||
phi_args = [ | ||
ast.Name(f"{var}_{count}", ast.Load()), | ||
ast.Name(f"{var}_{false_var_counter[var]}", ast.Load()) | ||
] | ||
if not node.orelse: | ||
phi_args = [phi_args[1], phi_args[0]] | ||
result.append(ast.Assign( | ||
[ast.Name(f"{var}_{count + 1}", ast.Store())], | ||
ast.Call(ast.Name("phi", ast.Load()), [ | ||
ast.List(phi_args, ast.Load()), | ||
test | ||
], []))) | ||
self.var_counter[var] += 1 | ||
return result | ||
|
||
|
||
@ast_utils.inspect_enclosing_env | ||
def ssa(defn_env : dict, fn : types.FunctionType): | ||
stack = inspect.stack() | ||
defn_env = {} | ||
for i in range(1, len(stack)): | ||
defn_env.update(stack[i].frame.f_locals) | ||
defn_env.update(stack[i].frame.f_globals) | ||
tree = ast_utils.get_func_ast(fn) | ||
tree.decorator_list = ast_utils.filter_decorator(ssa, tree.decorator_list, defn_env) | ||
tree = SSAVisitor().visit(tree) | ||
return ast_utils.compile_function_to_file(tree, defn_env=defn_env) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import magma as m | ||
from magma.ssa import ssa | ||
import inspect | ||
|
||
|
||
def test_basic(): | ||
@ssa | ||
def basic_if(I: m.Bits(2), S: m.Bit) -> m.Bit: | ||
if S: | ||
x = I[0] | ||
else: | ||
x = I[1] | ||
return x | ||
|
||
assert inspect.getsource(basic_if) == """\ | ||
def basic_if(I: m.Bits(2), S: m.Bit) ->m.Bit: | ||
x_0 = I[0] | ||
x_1 = I[1] | ||
x_2 = phi([x_1, x_0], S) | ||
return x_2 | ||
""" | ||
|
||
|
||
def test_default(): | ||
@ssa | ||
def default(I: m.Bits(2), S: m.Bit) -> m.Bit: | ||
x = I[1] | ||
if S: | ||
x = I[0] | ||
return x | ||
|
||
assert inspect.getsource(default) == """\ | ||
def default(I: m.Bits(2), S: m.Bit) ->m.Bit: | ||
x_0 = I[1] | ||
x_1 = I[0] | ||
x_2 = phi([x_0, x_1], S) | ||
return x_2 | ||
""" | ||
|
||
|
||
def test_nested(): | ||
@ssa | ||
def nested(I: m.Bits(4), S: m.Bits(2)) -> m.Bit: | ||
if S[0]: | ||
if S[1]: | ||
x = I[0] | ||
else: | ||
x = I[1] | ||
else: | ||
if S[1]: | ||
x = I[2] | ||
else: | ||
x = I[3] | ||
return x | ||
|
||
assert inspect.getsource(nested) == """\ | ||
def nested(I: m.Bits(4), S: m.Bits(2)) ->m.Bit: | ||
x_0 = I[0] | ||
x_1 = I[1] | ||
x_2 = phi([x_1, x_0], S[1]) | ||
x_3 = I[2] | ||
x_4 = I[3] | ||
x_5 = phi([x_4, x_3], S[1]) | ||
x_6 = phi([x_5, x_2], S[0]) | ||
return x_6 | ||
""" |