-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #224 from phanrahan/if-statements
Preliminary support for if statements and ternary expressions
- Loading branch information
Showing
26 changed files
with
893 additions
and
481 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# Combinational Circuit Definitions | ||
Circuit defintions can be marked with the `@m.circuit.combinational` decorator. | ||
This introduces a set of syntax level features for defining combinational magma | ||
circuits, including the use of `if` statements to generate `Mux`es. | ||
|
||
This feature is currently experimental, and therefor expect bugs to occur. | ||
Please file any issues on the magma GitHub repository. | ||
|
||
## If and Ternary | ||
The condition must be an expression that evaluates to a `magma` value. | ||
|
||
Basic example: | ||
```python | ||
@m.circuit.combinational | ||
def basic_if(I: m.Bits(2), S: m.Bit) -> m.Bit: | ||
if S: | ||
return I[0] | ||
else: | ||
return I[1] | ||
|
||
``` | ||
|
||
Basic nesting: | ||
```python | ||
class IfStatementNested(m.Circuit): | ||
@m.circuit.combinational | ||
def if_statement_nested(I: m.Bits(4), S: m.Bits(2)) -> m.Bit: | ||
if S[0]: | ||
if S[1]: | ||
return I[0] | ||
else: | ||
return I[1] | ||
else: | ||
if S[1]: | ||
return I[2] | ||
else: | ||
return I[3] | ||
``` | ||
|
||
Terneray expressions | ||
```python | ||
def ternary(I: m.Bits(2), S: m.Bit) -> m.Bit: | ||
return I[0] if S else I[1] | ||
``` | ||
|
||
Nesting terneray expressions | ||
```python | ||
@m.circuit.combinational | ||
def ternary_nested(I: m.Bits(4), S: m.Bits(2)) -> m.Bit: | ||
return I[0] if S[0] else I[1] if S[1] else I[2] | ||
``` | ||
|
||
Things that aren't supported: | ||
* Using anything other than an assignment statement in the if/else body | ||
* Assigning to a variable only once in the if or else body (not both). We could | ||
support this if the variable is already defined in the enclosing scope, for | ||
example using a default value | ||
``` | ||
x = 3 | ||
if S: | ||
x = 4 | ||
``` | ||
* This brings up another issue, which is that it doesn't support a default | ||
value. (So the above code would break even if x was assigned in the else | ||
block. | ||
* If without an else (for the same reason as the above) | ||
|
||
## Function composition: | ||
``` | ||
@m.circuit.combinational | ||
def basic_if_function_call(I: m.Bits(2), S: m.Bit) -> m.Bit: | ||
return basic_if(I, S) | ||
``` | ||
Function calls must refer to another `m.circuit.combinational` element, or a | ||
function that accepts magma values, define instances and wires values, and | ||
returns a magma. Calling any other type of function has undefined behavior. | ||
|
||
## Returning multiple values (tuples) | ||
|
||
There are two ways to return multiple values, first is to use a Python tuple. | ||
This is specified in the type signature as `(m.Type, m.Type, ...)`. In the | ||
body of the definition, the values can be returned using the standard Python | ||
tuple syntax. The circuit defined with a Python tuple as an output type will | ||
default to the naming convetion `O0, O1, ...` for the output ports. | ||
|
||
```python | ||
@m.circuit.combinational | ||
def return_py_tuple(I: m.Bits(2)) -> (m.Bit, m.Bit): | ||
return I[0], I[1] | ||
``` | ||
|
||
The other method is to use an `m.Tuple` (magma's tuple type). Again, this is | ||
specified in the type signature, using `m.Tuple(m.Type, m.Type, ...)`. You can | ||
also use the namedtuple pattern to give your multiple outputs explicit names | ||
with `m.Tuple(O0=m.Bit, O1=m.Bit)`. | ||
|
||
|
||
```python | ||
@m.circuit.combinational | ||
def return_magma_tuple(I: m.Bits(2)) -> m.Tuple(m.Bit, m.Bit): | ||
return m.tuple_([I[0], I[1]]) | ||
``` | ||
|
||
``` | ||
def return_magma_named_tuple(I: m.Bits(2)) -> m.Tuple(x=m.Bit, y=m.Bit): | ||
return m.namedtuple(x=I[0], y=I[1]) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
import functools | ||
import ast | ||
import inspect | ||
import textwrap | ||
from collections import OrderedDict | ||
from magma.logging import warning, debug | ||
import astor | ||
# import astunparse | ||
|
||
|
||
class CircuitDefinitionSyntaxError(Exception): | ||
pass | ||
|
||
|
||
def m_dot(attr): | ||
return ast.Attribute(ast.Name("m", ast.Load()), attr, ast.Load()) | ||
|
||
|
||
def get_ast(obj): | ||
indented_program_txt = inspect.getsource(obj) | ||
program_txt = textwrap.dedent(indented_program_txt) | ||
return ast.parse(program_txt) | ||
|
||
|
||
class IfTransformer(ast.NodeTransformer): | ||
def flatten(self, _list): | ||
"""1-deep flatten""" | ||
flat_list = [] | ||
for item in _list: | ||
if isinstance(item, list): | ||
flat_list.extend(item) | ||
else: | ||
flat_list.append(item) | ||
return flat_list | ||
|
||
def visit_If(self, node): | ||
# Flatten in case there's a nest If statement that returns a list | ||
node.body = self.flatten(map(self.visit, node.body)) | ||
if not hasattr(node, "orelse"): | ||
raise NotImplementedError("If without else") | ||
node.orelse = self.flatten(map(self.visit, node.orelse)) | ||
seen = OrderedDict() | ||
for stmt in node.body: | ||
if not isinstance(stmt, ast.Assign): | ||
# TODO: Print info from original source file/line | ||
raise CircuitDefinitionSyntaxError( | ||
f"Expected only assignment statements in if statement, got" | ||
f" {type(stmt)}") | ||
if len(stmt.targets) > 1: | ||
raise NotImplementedError("Assigning more than one value") | ||
key = ast.dump(stmt.targets[0]) | ||
if key in seen: | ||
# TODO: Print the line number | ||
warning("Assigning to value twice inside `if` block," | ||
" taking the last value (first value is ignored)") | ||
seen[key] = stmt | ||
orelse_seen = set() | ||
for stmt in node.orelse: | ||
key = ast.dump(stmt.targets[0]) | ||
if key in seen: | ||
if key in orelse_seen: | ||
warning("Assigning to value twice inside `else` block," | ||
" taking the last value (first value is ignored)") | ||
orelse_seen.add(key) | ||
seen[key].value = ast.Call( | ||
ast.Name("mux", ast.Load()), | ||
[ast.List([stmt.value, seen[key].value], | ||
ast.Load()), node.test], | ||
[]) | ||
else: | ||
raise NotImplementedError("Assigning to a variable once in" | ||
" `else` block (not in then block)") | ||
return [node for node in seen.values()] | ||
|
||
def visit_IfExp(self, node): | ||
if not hasattr(node, "orelse"): | ||
raise NotImplementedError("If without else") | ||
node.body = self.visit(node.body) | ||
node.orelse = self.visit(node.orelse) | ||
return ast.Call( | ||
ast.Name("mux", ast.Load()), | ||
[ast.List([node.orelse, node.body], | ||
ast.Load()), node.test], | ||
[]) | ||
|
||
|
||
class FunctionToCircuitDefTransformer(ast.NodeTransformer): | ||
def __init__(self): | ||
super().__init__() | ||
self.IO = set() | ||
|
||
def visit(self, node): | ||
new_node = super().visit(node) | ||
if new_node is not node: | ||
return ast.copy_location(new_node, node) | ||
return node | ||
|
||
def qualify(self, node, direction): | ||
return ast.Call(m_dot(direction), [node], []) | ||
|
||
def visit_FunctionDef(self, node): | ||
names = [arg.arg for arg in node.args.args] | ||
types = [arg.annotation for arg in node.args.args] | ||
IO = [] | ||
for name, type_ in zip(names, types): | ||
self.IO.add(name) | ||
IO.extend([ast.Str(name), | ||
self.qualify(type_, "In")]) | ||
if isinstance(node.returns, ast.Tuple): | ||
for i, elt in enumerate(node.returns.elts): | ||
IO.extend([ast.Str(f"O{i}"), self.qualify(elt, "Out")]) | ||
else: | ||
IO.extend([ast.Str("O"), self.qualify(node.returns, "Out")]) | ||
IO = ast.List(IO, ast.Load()) | ||
node.body = [self.visit(s) for s in node.body] | ||
if isinstance(node.returns, ast.Tuple): | ||
for i, elt in enumerate(node.returns.elts): | ||
node.body.append(ast.Expr(ast.Call( | ||
m_dot("wire"), | ||
[ast.Name(f"O{i}", ast.Load()), | ||
ast.Attribute(ast.Name("io", ast.Load()), f"O{i}", ast.Load())], | ||
[] | ||
))) | ||
else: | ||
node.body.append(ast.Expr(ast.Call( | ||
m_dot("wire"), | ||
[ast.Name("O", ast.Load()), | ||
ast.Attribute(ast.Name("io", ast.Load()), "O", ast.Load())], | ||
[] | ||
))) | ||
# class {node.name}(m.Circuit): | ||
# IO = {IO} | ||
# @classmethod | ||
# def definition(io): | ||
# {body} | ||
class_def = ast.ClassDef( | ||
node.name, | ||
[ast.Attribute(ast.Name("m", ast.Load()), "Circuit", ast.Load())], | ||
[], [ | ||
ast.Assign([ast.Name("IO", ast.Store())], IO), | ||
ast.FunctionDef( | ||
"definition", | ||
ast.arguments([ast.arg("io", None)], | ||
None, [], [], | ||
None, []), | ||
node.body, | ||
[ast.Name("classmethod", ast.Load())], | ||
None | ||
) | ||
|
||
], | ||
[]) | ||
return class_def | ||
|
||
def visit_Name(self, node): | ||
if node.id in self.IO: | ||
return ast.Attribute(ast.Name("io", ast.Load()), node.id, | ||
ast.Load()) | ||
return node | ||
|
||
def visit_Return(self, node): | ||
node.value = self.visit(node.value) | ||
if isinstance(node.value, ast.Tuple): | ||
return ast.Assign( | ||
[ast.Tuple([ast.Name(f"O{i}", ast.Store()) | ||
for i in range(len(node.value.elts))], ast.Store())], | ||
node.value | ||
) | ||
return ast.Assign([ast.Name("O", ast.Store())], node.value) | ||
|
||
|
||
|
||
|
||
def combinational(fn): | ||
stack = inspect.stack() | ||
defn_locals = stack[1].frame.f_locals | ||
defn_globals = stack[1].frame.f_globals | ||
tree = get_ast(fn) | ||
tree = FunctionToCircuitDefTransformer().visit(tree) | ||
tree = ast.fix_missing_locations(tree) | ||
tree = IfTransformer().visit(tree) | ||
tree = ast.fix_missing_locations(tree) | ||
# TODO: Only remove @m.circuit.combinational, there could be others | ||
tree.body[0].decorator_list = [] | ||
if "mux" not in defn_globals and \ | ||
"mux" not in defn_locals: | ||
tree.body.insert(0, ast.parse("from mantle import mux").body[0]) | ||
debug(astor.to_source(tree)) | ||
# debug(astunparse.dump(tree)) | ||
exec(compile(tree, filename="<ast>", mode="exec"), defn_globals, | ||
defn_locals) | ||
|
||
circuit_def = defn_locals[fn.__name__] | ||
|
||
@functools.wraps(fn) | ||
def func(*args, **kwargs): | ||
return circuit_def()(*args, **kwargs) | ||
func.__name__ = fn.__name__ | ||
func.__qualname__ = fn.__name__ | ||
func.circuit_definition = circuit_def | ||
return func |
Oops, something went wrong.