Skip to content

Commit

Permalink
Merge 3ae1cb8 into 02db34c
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardt committed Jul 28, 2018
2 parents 02db34c + 3ae1cb8 commit 2a6f8d3
Show file tree
Hide file tree
Showing 11 changed files with 368 additions and 0 deletions.
77 changes: 77 additions & 0 deletions doc/circuit_definitions.md
@@ -0,0 +1,77 @@
# Combinational Circuit Definitions
Circuit defintions can be marked with the `@m.circuit.combinational` decorator.
This introduces a set of syntax level features for defining combinational magma
circuits, including the use of `if` statements to generate `Mux`es.

This feature is currently experimental, and therefor expect bugs to occur.
Please file any issues on the magma GitHub repository.

## If and Ternary
The condition must be an expression that evaluates to a `magma` value.

Basic example:
```python
class IfStatementBasic(m.Circuit):
IO = ["I", m.In(m.Bits(2)), "S", m.In(m.Bit), "O", m.Out(m.Bit)]
@m.circuit.combinational
def definition(io):
if io.S:
O = io.I[0]
else:
O = io.I[1]
m.wire(O, io.O)

```

Basic nesting:
```python
class IfStatementNested(m.Circuit):
IO = ["I", m.In(m.Bits(4)), "S", m.In(m.Bits(2)), "O", m.Out(m.Bit)]
@m.circuit.combinational
def definition(io):
if io.S[0]:
if io.S[1]:
O = io.I[0]
else:
O = io.I[1]
else:
if io.S[1]:
O = io.I[2]
else:
O = io.I[3]
m.wire(O, io.O)
```

Terneray expressions
```python
class Ternary(m.Circuit):
IO = ["I", m.In(m.Bits(2)), "S", m.In(m.Bit), "O", m.Out(m.Bit)]
@m.circuit.combinational
def definition(io):
m.wire(io.O, io.I[0] if io.S else io.I[1])
```

Nesting terneray expressions
```python
class TernaryNested(m.Circuit):
IO = ["I", m.In(m.Bits(3)), "S", m.In(m.Bits(2)), "O", m.Out(m.Bit)]
@m.circuit.combinational
def definition(io):
m.wire(io.O,
io.I[0] if io.S[0] else io.I[1] if io.S[1] else io.I[2])
```

Things that aren't supported:
* Using anything other than an assignment statement in the if/else body
* Assigning to a variable only once in the if or else body (not both). We could
support this if the variable is already defined in the enclosing scope, for
example using a default value
```
x = 3
if S:
x = 4
```
* This brings up another issue, which is that it doesn't support a default
value. (So the above code would break even if x was assigned in the else
block.
* If without an else (for the same reason as the above)
2 changes: 2 additions & 0 deletions magma/circuit.py
Expand Up @@ -547,3 +547,5 @@ def wrapped(*args, **kwargs):
result._generator_arguments = GeneratorArguments(args, kwargs)
return result
return wrapped

from magma.circuit_def import combinational
102 changes: 102 additions & 0 deletions magma/circuit_def.py
@@ -0,0 +1,102 @@
import functools
import ast
import inspect
import textwrap
from collections import OrderedDict
from magma.logging import warning, debug
import astor


class CircuitDefinitionSyntaxError(Exception):
pass


def get_ast(obj):
indented_program_txt = inspect.getsource(obj)
program_txt = textwrap.dedent(indented_program_txt)
return ast.parse(program_txt)


class IfTransformer(ast.NodeTransformer):
def flatten(self, _list):
"""1-deep flatten"""
flat_list = []
for item in _list:
if isinstance(item, list):
flat_list.extend(item)
else:
flat_list.append(item)
return flat_list

def visit_If(self, node):
# Flatten in case there's a nest If statement that returns a list
node.body = self.flatten(map(self.visit, node.body))
if not hasattr(node, "orelse"):
raise NotImplementedError("If without else")
node.orelse = self.flatten(map(self.visit, node.orelse))
seen = OrderedDict()
for stmt in node.body:
if not isinstance(stmt, ast.Assign):
# TODO: Print info from original source file/line
raise CircuitDefinitionSyntaxError(
f"Expected only assignment statements in if statement, got"
f" {type(stmt)}")
if len(stmt.targets) > 1:
raise NotImplementedError("Assigning more than one value")
key = ast.dump(stmt.targets[0])
if key in seen:
# TODO: Print the line number
warning("Assigning to value twice inside `if` block,"
" taking the last value (first value is ignored)")
seen[key] = stmt
orelse_seen = set()
for stmt in node.orelse:
key = ast.dump(stmt.targets[0])
if key in seen:
if key in orelse_seen:
warning("Assigning to value twice inside `else` block,"
" taking the last value (first value is ignored)")
orelse_seen.add(key)
seen[key].value = ast.Call(
ast.Name("mux", ast.Load()),
[ast.List([seen[key].value, stmt.value],
ast.Load()), node.test],
[])
else:
raise NotImplementedError("Assigning to a variable once in"
" `else` block (not in then block)")
return [node for node in seen.values()]

def visit_IfExp(self, node):
if not hasattr(node, "orelse"):
raise NotImplementedError("If without else")
node.body = self.visit(node.body)
node.orelse = self.visit(node.orelse)
return ast.Call(
ast.Name("mux", ast.Load()),
[ast.List([node.body, node.orelse],
ast.Load()), node.test],
[])


def combinational(fn):
stack = inspect.stack()
defn_locals = stack[1].frame.f_locals
defn_globals = stack[1].frame.f_globals
tree = get_ast(fn)
tree = IfTransformer().visit(tree)
tree = ast.fix_missing_locations(tree)
# TODO: Only remove @m.circuit.combinational, there could be others
tree.body[0].decorator_list = []
debug(astor.to_source(tree))
exec(compile(tree, filename="<ast>", mode="exec"), defn_globals,
defn_locals)

fn = defn_locals[fn.__name__]

@classmethod
@functools.wraps(fn)
def wrapped(io):
return fn(io)

return wrapped
5 changes: 5 additions & 0 deletions magma/logging.py
Expand Up @@ -25,6 +25,10 @@ def get_original_wire_call_stack_frame():
return frame.frame


def debug(message, *args, **kwargs):
log.debug(message, *args, **kwargs)


def info(message, *args, **kwargs):
log.info(message, *args, **kwargs)

Expand All @@ -43,3 +47,4 @@ def error(message, include_wire_traceback=False, *args, **kwargs):
print(message, file=sys.stderr, *args, **kwargs)
if include_wire_traceback:
sys.stderr.write("="*80 + "\n")

1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -15,6 +15,7 @@
"magma.testing"
],
install_requires=[
"astor",
"six",
"mako",
"pyverilog",
Expand Down
6 changes: 6 additions & 0 deletions tests/gold/test_if_statement_basic.v
@@ -0,0 +1,6 @@
module TestIfStatementBasic (input [1:0] I, input S, output O);
wire inst0_O;
Mux2 inst0 (.I0(I[0]), .I1(I[1]), .S(S), .O(inst0_O));
assign O = inst0_O;
endmodule

10 changes: 10 additions & 0 deletions tests/gold/test_if_statement_nested.v
@@ -0,0 +1,10 @@
module TestIfStatementNested (input [3:0] I, input [1:0] S, output O);
wire inst0_O;
wire inst1_O;
wire inst2_O;
Mux2 inst0 (.I0(I[0]), .I1(I[1]), .S(S[1]), .O(inst0_O));
Mux2 inst1 (.I0(I[2]), .I1(I[3]), .S(S[1]), .O(inst1_O));
Mux2 inst2 (.I0(inst0_O), .I1(inst1_O), .S(S[0]), .O(inst2_O));
assign O = inst2_O;
endmodule

6 changes: 6 additions & 0 deletions tests/gold/test_ternary.v
@@ -0,0 +1,6 @@
module TestTernary (input [1:0] I, input S, output O);
wire inst0_O;
Mux2 inst0 (.I0(I[0]), .I1(I[1]), .S(S), .O(inst0_O));
assign O = inst0_O;
endmodule

8 changes: 8 additions & 0 deletions tests/gold/test_ternary_nested.v
@@ -0,0 +1,8 @@
module TestTernaryNested (input [2:0] I, input [1:0] S, output O);
wire inst0_O;
wire inst1_O;
Mux2 inst0 (.I0(I[1]), .I1(I[2]), .S(S[1]), .O(inst0_O));
Mux2 inst1 (.I0(I[0]), .I1(inst0_O), .S(S[0]), .O(inst1_O));
assign O = inst1_O;
endmodule

8 changes: 8 additions & 0 deletions tests/gold/test_ternary_nested2.v
@@ -0,0 +1,8 @@
module TestTernaryNested2 (input [2:0] I, input [1:0] S, output O);
wire inst0_O;
wire inst1_O;
Mux2 inst0 (.I0(I[0]), .I1(I[1]), .S(S[0]), .O(inst0_O));
Mux2 inst1 (.I0(inst0_O), .I1(I[2]), .S(S[1]), .O(inst1_O));
assign O = inst1_O;
endmodule

143 changes: 143 additions & 0 deletions tests/test_circuit_def.py
@@ -0,0 +1,143 @@
import magma as m
from magma.testing import check_files_equal


@m.cache_definition
def DefineMux(height=2, width=None):
if width is None:
T = m.Bit
else:
T = m.Bits(width)

io = []
for i in range(height):
io += ["I{}".format(i), m.In(T)]
if height == 2:
select_type = m.Bit
else:
select_type = m.Bits(m.bitutils.clog2(height))
io += ['S', m.In(select_type)]
io += ['O', m.Out(T)]

class _Mux(m.Circuit):
name = f"Mux{height}" + (f"_x{width}" if width else "")
IO = io
return _Mux


def Mux(height=2, width=None, **kwargs):
return DefineMux(height, width)(**kwargs)


def get_length(value):
if isinstance(value, m._BitType):
return None
elif isinstance(value, m.ArrayType):
return len(value)
else:
raise NotImplementedError(f"Cannot get_length of {type(value)}")


def mux(I, S):
if isinstance(S, int):
return I[S]
elif S.const():
return I[m.bitutils.seq2int(S.bits())]
return Mux(len(I), get_length(I[0]))(*I, S)


def test_if_statement_basic():
class TestIfStatementBasic(m.Circuit):
IO = ["I", m.In(m.Bits(2)), "S", m.In(m.Bit), "O", m.Out(m.Bit)]

@m.circuit.combinational
def definition(io):
if io.S:
O = io.I[0]
# m.wire(io.O, io.I[0])
# io.O = io.I[0]
# TODO: Alternative syntax
# io.O <= io.I[0]
# TODO: Or we could use wire syntax
# wire(io.O, io.I[0])
else:
O = io.I[1]
m.wire(O, io.O)
m.compile("build/test_if_statement_basic", TestIfStatementBasic)
assert check_files_equal(__file__, f"build/test_if_statement_basic.v",
f"gold/test_if_statement_basic.v")


def test_if_statement_nested():
class TestIfStatementNested(m.Circuit):
IO = ["I", m.In(m.Bits(4)), "S", m.In(m.Bits(2)), "O", m.Out(m.Bit)]

@m.circuit.combinational
def definition(io):
if io.S[0]:
if io.S[1]:
O = io.I[0]
else:
O = io.I[1]
# m.wire(io.O, io.I[0])
# io.O = io.I[0]
# TODO: Alternative syntax
# io.O <= io.I[0]
# TODO: Or we could use wire syntax
# wire(io.O, io.I[0])
else:
if io.S[1]:
O = io.I[2]
else:
O = io.I[3]
m.wire(O, io.O)
m.compile("build/test_if_statement_nested", TestIfStatementNested)
assert check_files_equal(__file__, f"build/test_if_statement_nested.v",
f"gold/test_if_statement_nested.v")


def test_ternary():
class TestTernary(m.Circuit):
IO = ["I", m.In(m.Bits(2)), "S", m.In(m.Bit), "O", m.Out(m.Bit)]

@m.circuit.combinational
def definition(io):
m.wire(io.O, io.I[0] if io.S else io.I[1])
# io.O = io.I[0] if io.S else io.I[1]
# TODO: Or non block assign?
# io.O <= io.I[0] if io.S else io.[1]
m.compile("build/test_ternary", TestTernary)
assert check_files_equal(__file__, f"build/test_ternary.v",
f"gold/test_ternary.v")


def test_ternary_nested():
class TestTernaryNested(m.Circuit):
IO = ["I", m.In(m.Bits(3)), "S", m.In(m.Bits(2)), "O", m.Out(m.Bit)]

@m.circuit.combinational
def definition(io):
m.wire(io.O,
io.I[0] if io.S[0] else io.I[1] if io.S[1] else io.I[2])
# io.O = io.I[0] if io.S else io.I[1]
# TODO: Or non block assign?
# io.O <= io.I[0] if io.S else io.[1]
m.compile("build/test_ternary_nested", TestTernaryNested)
assert check_files_equal(__file__, f"build/test_ternary_nested.v",
f"gold/test_ternary_nested.v")


def test_ternary_nested2():
class TestTernaryNested2(m.Circuit):
IO = ["I", m.In(m.Bits(3)), "S", m.In(m.Bits(2)), "O", m.Out(m.Bit)]

@m.circuit.combinational
def definition(io):
m.wire(io.O,
(io.I[0] if io.S[0] else io.I[1]) if io.S[1] else io.I[2])
# io.O = io.I[0] if io.S else io.I[1]
# TODO: Or non block assign?
# io.O <= io.I[0] if io.S else io.[1]
m.compile("build/test_ternary_nested2", TestTernaryNested2)
assert check_files_equal(__file__, f"build/test_ternary_nested2.v",
f"gold/test_ternary_nested2.v")

0 comments on commit 2a6f8d3

Please sign in to comment.