-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Initialize ssa decorator #324
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import inspect | ||
import textwrap | ||
import ast | ||
import os | ||
import types | ||
import typing | ||
import astor | ||
import traceback | ||
|
||
|
||
def get_ast(obj): | ||
indented_program_txt = inspect.getsource(obj) | ||
program_txt = textwrap.dedent(indented_program_txt) | ||
return ast.parse(program_txt) | ||
|
||
def get_func_ast(obj : types.FunctionType): | ||
""" Implicitly strip ast.Module() surrounding the function """ | ||
return get_ast(obj).body[0] | ||
|
||
|
||
def compile_function_to_file(tree : typing.Union[ast.Module, ast.FunctionDef], | ||
func_name : str = None, defn_env : dict = {}): | ||
if isinstance(tree, ast.FunctionDef): | ||
if func_name is None: | ||
func_name = tree.name | ||
else: | ||
if func_name != tree.name: | ||
raise Exception("Passed in func_name that does not match the " | ||
"function being compiled. Got" | ||
f" func_name={func_name} expected" | ||
f" tree.name={tree.name}") | ||
elif isinstance(tree, ast.Module): | ||
if func_name is None: | ||
raise Exception("func_name required when passing in an ast.Module") | ||
os.makedirs(".magma", exist_ok=True) | ||
file_name = os.path.join(".magma", func_name + ".py") | ||
with open(file_name, "w") as fp: | ||
fp.write(astor.to_source(tree)) | ||
# exec(compile(tree, filename=file_name, mode="exec"), defn_env) | ||
try: | ||
exec(compile(astor.to_source(tree), filename=file_name, mode="exec"), defn_env) | ||
except: | ||
import sys | ||
tb = traceback.format_exc() | ||
print(tb) | ||
raise Exception(f"Error occured when compiling and executing m.circuit.combinational function {func_name}, see above") from None | ||
return defn_env[func_name] | ||
|
||
|
||
def inspect_enclosing_env(fn): | ||
""" | ||
Traverses the current call stack to get the current locals and globals in | ||
the environment. | ||
|
||
Possible Improvements: | ||
* Return a scope object that preserves the distinction between globals | ||
and locals. This isn't currently required by the code using it, but | ||
could be useful for other use cases. | ||
* Maintain the stack hierarchy. Again, not currently used, but could be | ||
useful. | ||
""" | ||
def wrapped(*args, **kwargs): | ||
stack = inspect.stack() | ||
enclosing_env = {} | ||
for i in range(1, len(stack)): | ||
enclosing_env.update(stack[i].frame.f_locals) | ||
enclosing_env.update(stack[i].frame.f_globals) | ||
return fn(enclosing_env, *args, **kwargs) | ||
return wrapped | ||
|
||
|
||
def filter_decorator(decorator : typing.Callable, decorator_list : typing.List[ast.AST], env : dict): | ||
def _filter(node): | ||
code = compile(ast.Expression(node), filename="<string>", mode="eval") | ||
return eval(code, env) != decorator | ||
return list(filter(_filter, decorator_list)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from magma.ssa.ssa import ssa |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import magma.ast_utils as ast_utils | ||
import ast | ||
import types | ||
from collections import defaultdict | ||
import inspect | ||
|
||
|
||
def flatten(l : list): | ||
""" | ||
Non-recursive flatten that ignores non-list children | ||
""" | ||
flat = [] | ||
for item in l: | ||
if not isinstance(item, list): | ||
item = [item] | ||
flat += item | ||
return flat | ||
|
||
|
||
class SSAVisitor(ast.NodeTransformer): | ||
def __init__(self): | ||
super().__init__() | ||
self.var_counter = defaultdict(lambda : -1) | ||
|
||
def visit_FunctionDef(self, node): | ||
node.body = flatten([self.visit(s) for s in node.body]) | ||
return node | ||
|
||
def visit_Name(self, node): | ||
if isinstance(node.ctx, ast.Store): | ||
self.var_counter[node.id] += 1 | ||
if isinstance(node.ctx, ast.Store) or node.id in self.var_counter: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Redundant check on if it's a store, no? |
||
node.id += f"_{self.var_counter[node.id]}" | ||
return node | ||
|
||
def visit_If(self, node): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this (the This way you would preserve the store dict across both branches in the if, and reset the load mapping dict for each branch. This would also eliminate the need to go through everything in the false_var_counter and remap variable store names. The phi logic would still need to be roughly the same though, but you would have the two load mappings and then you would insert a phi for every entry where the last load name is different between the two dicts for a variable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here's another bug case I think:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this phi be the other way around?
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nevermind on that last point. |
||
false_var_counter = dict(self.var_counter) | ||
test = self.visit(node.test) | ||
result = flatten([self.visit(s) for s in node.body]) | ||
if node.orelse: | ||
false_var_counter = dict(self.var_counter) | ||
result += flatten([self.visit(s) for s in node.orelse]) | ||
for var, count in self.var_counter.items(): | ||
if var in false_var_counter and count != false_var_counter[var]: | ||
phi_args = [ | ||
ast.Name(f"{var}_{count}", ast.Load()), | ||
ast.Name(f"{var}_{false_var_counter[var]}", ast.Load()) | ||
] | ||
if not node.orelse: | ||
phi_args = [phi_args[1], phi_args[0]] | ||
result.append(ast.Assign( | ||
[ast.Name(f"{var}_{count + 1}", ast.Store())], | ||
ast.Call(ast.Name("phi", ast.Load()), [ | ||
ast.List(phi_args, ast.Load()), | ||
test | ||
], []))) | ||
self.var_counter[var] += 1 | ||
return result | ||
|
||
|
||
@ast_utils.inspect_enclosing_env | ||
def ssa(defn_env : dict, fn : types.FunctionType): | ||
stack = inspect.stack() | ||
defn_env = {} | ||
for i in range(1, len(stack)): | ||
defn_env.update(stack[i].frame.f_locals) | ||
defn_env.update(stack[i].frame.f_globals) | ||
tree = ast_utils.get_func_ast(fn) | ||
tree.decorator_list = ast_utils.filter_decorator(ssa, tree.decorator_list, defn_env) | ||
tree = SSAVisitor().visit(tree) | ||
return ast_utils.compile_function_to_file(tree, defn_env=defn_env) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import magma as m | ||
from magma.ssa import ssa | ||
import inspect | ||
|
||
|
||
def test_basic(): | ||
@ssa | ||
def basic_if(I: m.Bits(2), S: m.Bit) -> m.Bit: | ||
if S: | ||
x = I[0] | ||
else: | ||
x = I[1] | ||
return x | ||
|
||
assert inspect.getsource(basic_if) == """\ | ||
def basic_if(I: m.Bits(2), S: m.Bit) ->m.Bit: | ||
x_0 = I[0] | ||
x_1 = I[1] | ||
x_2 = phi([x_1, x_0], S) | ||
return x_2 | ||
""" | ||
|
||
|
||
def test_default(): | ||
@ssa | ||
def default(I: m.Bits(2), S: m.Bit) -> m.Bit: | ||
x = I[1] | ||
if S: | ||
x = I[0] | ||
return x | ||
|
||
assert inspect.getsource(default) == """\ | ||
def default(I: m.Bits(2), S: m.Bit) ->m.Bit: | ||
x_0 = I[1] | ||
x_1 = I[0] | ||
x_2 = phi([x_0, x_1], S) | ||
return x_2 | ||
""" | ||
|
||
|
||
def test_nested(): | ||
@ssa | ||
def nested(I: m.Bits(4), S: m.Bits(2)) -> m.Bit: | ||
if S[0]: | ||
if S[1]: | ||
x = I[0] | ||
else: | ||
x = I[1] | ||
else: | ||
if S[1]: | ||
x = I[2] | ||
else: | ||
x = I[3] | ||
return x | ||
|
||
assert inspect.getsource(nested) == """\ | ||
def nested(I: m.Bits(4), S: m.Bits(2)) ->m.Bit: | ||
x_0 = I[0] | ||
x_1 = I[1] | ||
x_2 = phi([x_1, x_0], S[1]) | ||
x_3 = I[2] | ||
x_4 = I[3] | ||
x_5 = phi([x_4, x_3], S[1]) | ||
x_6 = phi([x_5, x_2], S[0]) | ||
return x_6 | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You might want to rename inputs.
For example, this is handled incorrectly:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you want to preserve interface names, I would suggest inserting input renamings as the first thing in the SSA-ified code, so the output might reasonably look like this:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The first example I posted generates this right now:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here's a test for the edge case I mentioned in the second comment:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just going to piggyback on this and say it might be a bigger issue than you think, if this is intended to be used in conjunction with other passes, I wouldn't be surprised if there would be renaming of variables happening in other passes.
For example, creating a wrapper module - I might rename the wires on the internal one by appending _0 to the inputs if I need to modify the wires before feeding them into the internal module. Rigel did something like this. Combined with the ssa pass things would break in ways that would be pretty odd.