Skip to content

Commit

Permalink
Merge pull request #224 from phanrahan/if-statements
Browse files Browse the repository at this point in the history
Preliminary support for if statements and ternary expressions
  • Loading branch information
leonardt committed Aug 2, 2018
2 parents 02db34c + f55820b commit 24083ef
Show file tree
Hide file tree
Showing 26 changed files with 893 additions and 481 deletions.
107 changes: 107 additions & 0 deletions doc/circuit_definitions.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Combinational Circuit Definitions
Circuit defintions can be marked with the `@m.circuit.combinational` decorator.
This introduces a set of syntax level features for defining combinational magma
circuits, including the use of `if` statements to generate `Mux`es.

This feature is currently experimental, and therefor expect bugs to occur.
Please file any issues on the magma GitHub repository.

## If and Ternary
The condition must be an expression that evaluates to a `magma` value.

Basic example:
```python
@m.circuit.combinational
def basic_if(I: m.Bits(2), S: m.Bit) -> m.Bit:
if S:
return I[0]
else:
return I[1]

```

Basic nesting:
```python
class IfStatementNested(m.Circuit):
@m.circuit.combinational
def if_statement_nested(I: m.Bits(4), S: m.Bits(2)) -> m.Bit:
if S[0]:
if S[1]:
return I[0]
else:
return I[1]
else:
if S[1]:
return I[2]
else:
return I[3]
```

Terneray expressions
```python
def ternary(I: m.Bits(2), S: m.Bit) -> m.Bit:
return I[0] if S else I[1]
```

Nesting terneray expressions
```python
@m.circuit.combinational
def ternary_nested(I: m.Bits(4), S: m.Bits(2)) -> m.Bit:
return I[0] if S[0] else I[1] if S[1] else I[2]
```

Things that aren't supported:
* Using anything other than an assignment statement in the if/else body
* Assigning to a variable only once in the if or else body (not both). We could
support this if the variable is already defined in the enclosing scope, for
example using a default value
```
x = 3
if S:
x = 4
```
* This brings up another issue, which is that it doesn't support a default
value. (So the above code would break even if x was assigned in the else
block.
* If without an else (for the same reason as the above)

## Function composition:
```
@m.circuit.combinational
def basic_if_function_call(I: m.Bits(2), S: m.Bit) -> m.Bit:
return basic_if(I, S)
```
Function calls must refer to another `m.circuit.combinational` element, or a
function that accepts magma values, define instances and wires values, and
returns a magma. Calling any other type of function has undefined behavior.

## Returning multiple values (tuples)

There are two ways to return multiple values, first is to use a Python tuple.
This is specified in the type signature as `(m.Type, m.Type, ...)`. In the
body of the definition, the values can be returned using the standard Python
tuple syntax. The circuit defined with a Python tuple as an output type will
default to the naming convetion `O0, O1, ...` for the output ports.

```python
@m.circuit.combinational
def return_py_tuple(I: m.Bits(2)) -> (m.Bit, m.Bit):
return I[0], I[1]
```

The other method is to use an `m.Tuple` (magma's tuple type). Again, this is
specified in the type signature, using `m.Tuple(m.Type, m.Type, ...)`. You can
also use the namedtuple pattern to give your multiple outputs explicit names
with `m.Tuple(O0=m.Bit, O1=m.Bit)`.


```python
@m.circuit.combinational
def return_magma_tuple(I: m.Bits(2)) -> m.Tuple(m.Bit, m.Bit):
return m.tuple_([I[0], I[1]])
```

```
def return_magma_named_tuple(I: m.Bits(2)) -> m.Tuple(x=m.Bit, y=m.Bit):
return m.namedtuple(x=I[0], y=I[1])
```
29 changes: 20 additions & 9 deletions magma/backend/coreir_.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ def magma_port_to_coreir(port):
select = repr(port)

name = port.name
if isinstance(name, TupleRef):
# Prefix integer indexes for unnamed tuples (e.g. 0, 1, 2) with "_"
if name.index.isdigit():
select = select.split(".")
select[-1] = "_" + select[-1]
select = ".".join(select)
name = get_top_name(name)
if isinstance(name, DefnRef):
if name.defn.name != "":
Expand Down Expand Up @@ -84,8 +90,20 @@ def get_type(self, port, is_input):
if isinstance(port, (ArrayType, ArrayKind)):
_type = self.context.Array(port.N, self.get_type(port.T, is_input))
elif isinstance(port, (TupleType, TupleKind)):
_type = self.context.Record({k:self.get_type(t, is_input)
for (k,t) in zip(port.Ks, port.Ts)})
def to_string(k):
"""
Unnamed tuples have integer keys (e.g. 0, 1, 2),
we prefix them with "_" so they can be consumed by coreir's
Record type (key names are constrained such that they can't be
integers)
"""
if isinstance(k, int):
return f"_{k}"
return k
_type = self.context.Record({
to_string(k): self.get_type(t, is_input) for (k, t) in
zip(port.Ks, port.Ts)
})
elif is_input:
if isinstance(port, ClockType):
_type = self.context.named_types[("coreir", "clk")]
Expand Down Expand Up @@ -211,13 +229,6 @@ def compile_definition_to_module_definition(self, definition, module_definition)
if port.isoutput():
self.add_output_port(output_ports, port)


def get_select(value):
if value in [VCC, GND]:
return self.get_constant_instance(value, None, module_definition)
else:
return module_definition.select(output_ports[value])

for instance in definition.instances:
for name, port in instance.interface.ports.items():
if port.isinput():
Expand Down
2 changes: 2 additions & 0 deletions magma/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,3 +547,5 @@ def wrapped(*args, **kwargs):
result._generator_arguments = GeneratorArguments(args, kwargs)
return result
return wrapped

from magma.circuit_def import combinational
201 changes: 201 additions & 0 deletions magma/circuit_def.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import functools
import ast
import inspect
import textwrap
from collections import OrderedDict
from magma.logging import warning, debug
import astor
# import astunparse


class CircuitDefinitionSyntaxError(Exception):
pass


def m_dot(attr):
return ast.Attribute(ast.Name("m", ast.Load()), attr, ast.Load())


def get_ast(obj):
indented_program_txt = inspect.getsource(obj)
program_txt = textwrap.dedent(indented_program_txt)
return ast.parse(program_txt)


class IfTransformer(ast.NodeTransformer):
def flatten(self, _list):
"""1-deep flatten"""
flat_list = []
for item in _list:
if isinstance(item, list):
flat_list.extend(item)
else:
flat_list.append(item)
return flat_list

def visit_If(self, node):
# Flatten in case there's a nest If statement that returns a list
node.body = self.flatten(map(self.visit, node.body))
if not hasattr(node, "orelse"):
raise NotImplementedError("If without else")
node.orelse = self.flatten(map(self.visit, node.orelse))
seen = OrderedDict()
for stmt in node.body:
if not isinstance(stmt, ast.Assign):
# TODO: Print info from original source file/line
raise CircuitDefinitionSyntaxError(
f"Expected only assignment statements in if statement, got"
f" {type(stmt)}")
if len(stmt.targets) > 1:
raise NotImplementedError("Assigning more than one value")
key = ast.dump(stmt.targets[0])
if key in seen:
# TODO: Print the line number
warning("Assigning to value twice inside `if` block,"
" taking the last value (first value is ignored)")
seen[key] = stmt
orelse_seen = set()
for stmt in node.orelse:
key = ast.dump(stmt.targets[0])
if key in seen:
if key in orelse_seen:
warning("Assigning to value twice inside `else` block,"
" taking the last value (first value is ignored)")
orelse_seen.add(key)
seen[key].value = ast.Call(
ast.Name("mux", ast.Load()),
[ast.List([stmt.value, seen[key].value],
ast.Load()), node.test],
[])
else:
raise NotImplementedError("Assigning to a variable once in"
" `else` block (not in then block)")
return [node for node in seen.values()]

def visit_IfExp(self, node):
if not hasattr(node, "orelse"):
raise NotImplementedError("If without else")
node.body = self.visit(node.body)
node.orelse = self.visit(node.orelse)
return ast.Call(
ast.Name("mux", ast.Load()),
[ast.List([node.orelse, node.body],
ast.Load()), node.test],
[])


class FunctionToCircuitDefTransformer(ast.NodeTransformer):
def __init__(self):
super().__init__()
self.IO = set()

def visit(self, node):
new_node = super().visit(node)
if new_node is not node:
return ast.copy_location(new_node, node)
return node

def qualify(self, node, direction):
return ast.Call(m_dot(direction), [node], [])

def visit_FunctionDef(self, node):
names = [arg.arg for arg in node.args.args]
types = [arg.annotation for arg in node.args.args]
IO = []
for name, type_ in zip(names, types):
self.IO.add(name)
IO.extend([ast.Str(name),
self.qualify(type_, "In")])
if isinstance(node.returns, ast.Tuple):
for i, elt in enumerate(node.returns.elts):
IO.extend([ast.Str(f"O{i}"), self.qualify(elt, "Out")])
else:
IO.extend([ast.Str("O"), self.qualify(node.returns, "Out")])
IO = ast.List(IO, ast.Load())
node.body = [self.visit(s) for s in node.body]
if isinstance(node.returns, ast.Tuple):
for i, elt in enumerate(node.returns.elts):
node.body.append(ast.Expr(ast.Call(
m_dot("wire"),
[ast.Name(f"O{i}", ast.Load()),
ast.Attribute(ast.Name("io", ast.Load()), f"O{i}", ast.Load())],
[]
)))
else:
node.body.append(ast.Expr(ast.Call(
m_dot("wire"),
[ast.Name("O", ast.Load()),
ast.Attribute(ast.Name("io", ast.Load()), "O", ast.Load())],
[]
)))
# class {node.name}(m.Circuit):
# IO = {IO}
# @classmethod
# def definition(io):
# {body}
class_def = ast.ClassDef(
node.name,
[ast.Attribute(ast.Name("m", ast.Load()), "Circuit", ast.Load())],
[], [
ast.Assign([ast.Name("IO", ast.Store())], IO),
ast.FunctionDef(
"definition",
ast.arguments([ast.arg("io", None)],
None, [], [],
None, []),
node.body,
[ast.Name("classmethod", ast.Load())],
None
)

],
[])
return class_def

def visit_Name(self, node):
if node.id in self.IO:
return ast.Attribute(ast.Name("io", ast.Load()), node.id,
ast.Load())
return node

def visit_Return(self, node):
node.value = self.visit(node.value)
if isinstance(node.value, ast.Tuple):
return ast.Assign(
[ast.Tuple([ast.Name(f"O{i}", ast.Store())
for i in range(len(node.value.elts))], ast.Store())],
node.value
)
return ast.Assign([ast.Name("O", ast.Store())], node.value)




def combinational(fn):
stack = inspect.stack()
defn_locals = stack[1].frame.f_locals
defn_globals = stack[1].frame.f_globals
tree = get_ast(fn)
tree = FunctionToCircuitDefTransformer().visit(tree)
tree = ast.fix_missing_locations(tree)
tree = IfTransformer().visit(tree)
tree = ast.fix_missing_locations(tree)
# TODO: Only remove @m.circuit.combinational, there could be others
tree.body[0].decorator_list = []
if "mux" not in defn_globals and \
"mux" not in defn_locals:
tree.body.insert(0, ast.parse("from mantle import mux").body[0])
debug(astor.to_source(tree))
# debug(astunparse.dump(tree))
exec(compile(tree, filename="<ast>", mode="exec"), defn_globals,
defn_locals)

circuit_def = defn_locals[fn.__name__]

@functools.wraps(fn)
def func(*args, **kwargs):
return circuit_def()(*args, **kwargs)
func.__name__ = fn.__name__
func.__qualname__ = fn.__name__
func.circuit_definition = circuit_def
return func

0 comments on commit 24083ef

Please sign in to comment.