Skip to content

Commit

Permalink
Merge 91714fb into 02db34c
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardt committed Jul 28, 2018
2 parents 02db34c + 91714fb commit 4672372
Show file tree
Hide file tree
Showing 9 changed files with 288 additions and 0 deletions.
3 changes: 3 additions & 0 deletions magma/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,6 @@ def set_mantle_target(t):
if mantle_target is not None and mantle_target != t:
warning('changing mantle target', mantle_target, t )
mantle_target = t


from .circuit_def import circuit_def
103 changes: 103 additions & 0 deletions magma/circuit_def.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import functools
import ast
import inspect
import textwrap
from collections import OrderedDict
from magma.logging import warning
import astor
import inspect


class CircuitDefinitionSyntaxError(Exception):
pass


def get_ast(obj):
indented_program_txt = inspect.getsource(obj)
program_txt = textwrap.dedent(indented_program_txt)
return ast.parse(program_txt)


class IfTransformer(ast.NodeTransformer):
def flatten(self, _list):
"""1-deep flatten"""
flat_list = []
for item in _list:
if isinstance(item, list):
flat_list.extend(item)
else:
flat_list.append(item)
return flat_list

def visit_If(self, node):
# Flatten in case there's a nest If statement that returns a list
node.body = self.flatten(map(self.visit, node.body))
if not hasattr(node, "orelse"):
raise NotImplementedError("If without else")
node.orelse = self.flatten(map(self.visit, node.orelse))
seen = OrderedDict()
for stmt in node.body:
if not isinstance(stmt, ast.Assign):
# TODO: Print info from original source file/line
raise CircuitDefinitionSyntaxError(
f"Expected only assignment statements in if statement, got"
f" {type(stmt)}")
if len(stmt.targets) > 1:
raise NotImplementedError("Assigning more than one value")
key = ast.dump(stmt.targets[0])
if key in seen:
# TODO: Print the line number
warning("Assigning to value twice inside `if` block,"
" taking the last value (first value is ignored)")
seen[key] = stmt
orelse_seen = set()
for stmt in node.orelse:
key = ast.dump(stmt.targets[0])
if key in seen:
if key in orelse_seen:
warning("Assigning to value twice inside `else` block,"
" taking the last value (first value is ignored)")
orelse_seen.add(key)
seen[key].value = ast.Call(
ast.Name("mux", ast.Load()),
[ast.List([seen[key].value, stmt.value],
ast.Load()), node.test],
[])
else:
raise NotImplementedError("Assigning to a variable once in"
" `else` block (not in then block)")
return [node for node in seen.values()]

def visit_IfExp(self, node):
if not hasattr(node, "orelse"):
raise NotImplementedError("If without else")
node.body = self.visit(node.body)
node.orelse = self.visit(node.orelse)
return ast.Call(
ast.Name("mux", ast.Load()),
[ast.List([node.body, node.orelse],
ast.Load()), node.test],
[])


def circuit_def(fn):
stack = inspect.stack()
defn_locals = stack[1].frame.f_locals
defn_globals = stack[1].frame.f_globals
tree = get_ast(fn)
tree = IfTransformer().visit(tree)
tree = ast.fix_missing_locations(tree)
# TODO: Only remove @m.circuit_def, there could be others
tree.body[0].decorator_list = []
# print(astor.to_source(tree))
exec(compile(tree, filename="<ast>", mode="exec"), defn_globals,
defn_locals)

fn = defn_locals[fn.__name__]

@classmethod
@functools.wraps(fn)
def wrapped(io):
return fn(io)

return wrapped
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"magma.testing"
],
install_requires=[
"astor",
"six",
"mako",
"pyverilog",
Expand Down
6 changes: 6 additions & 0 deletions tests/gold/test_if_statement_basic.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
module TestIfStatementBasic (input [1:0] I, input S, output O);
wire inst0_O;
Mux2xNone inst0 (.I0(I[0]), .I1(I[1]), .S(S), .O(inst0_O));
assign O = inst0_O;
endmodule

10 changes: 10 additions & 0 deletions tests/gold/test_if_statement_nested.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module TestIfStatementNested (input [3:0] I, input [1:0] S, output O);
wire inst0_O;
wire inst1_O;
wire inst2_O;
Mux2xNone inst0 (.I0(I[0]), .I1(I[1]), .S(S[1]), .O(inst0_O));
Mux2xNone inst1 (.I0(I[2]), .I1(I[3]), .S(S[1]), .O(inst1_O));
Mux2xNone inst2 (.I0(inst0_O), .I1(inst1_O), .S(S[0]), .O(inst2_O));
assign O = inst2_O;
endmodule

6 changes: 6 additions & 0 deletions tests/gold/test_ternary.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
module TestTernary (input [1:0] I, input S, output O);
wire inst0_O;
Mux2xNone inst0 (.I0(I[0]), .I1(I[1]), .S(S), .O(inst0_O));
assign O = inst0_O;
endmodule

8 changes: 8 additions & 0 deletions tests/gold/test_ternary_nested.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module TestTernaryNested (input [2:0] I, input [1:0] S, output O);
wire inst0_O;
wire inst1_O;
Mux2xNone inst0 (.I0(I[1]), .I1(I[2]), .S(S[1]), .O(inst0_O));
Mux2xNone inst1 (.I0(I[0]), .I1(inst0_O), .S(S[0]), .O(inst1_O));
assign O = inst1_O;
endmodule

8 changes: 8 additions & 0 deletions tests/gold/test_ternary_nested2.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module TestTernaryNested2 (input [2:0] I, input [1:0] S, output O);
wire inst0_O;
wire inst1_O;
Mux2xNone inst0 (.I0(I[0]), .I1(I[1]), .S(S[0]), .O(inst0_O));
Mux2xNone inst1 (.I0(inst0_O), .I1(I[2]), .S(S[1]), .O(inst1_O));
assign O = inst1_O;
endmodule

143 changes: 143 additions & 0 deletions tests/test_circuit_def.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import magma as m
from magma.testing import check_files_equal


@m.cache_definition
def DefineMux(height=2, width=None):
if width is None:
T = m.Bit
else:
T = m.Bits(width)

io = []
for i in range(height):
io += ["I{}".format(i), m.In(T)]
if height == 2:
select_type = m.Bit
else:
select_type = m.Bits(m.bitutils.clog2(height))
io += ['S', m.In(select_type)]
io += ['O', m.Out(T)]

class _Mux(m.Circuit):
name = "Mux{}x{}".format(height, width)
IO = io
return _Mux


def Mux(height=2, width=None, **kwargs):
return DefineMux(height, width)(**kwargs)


def get_length(value):
if isinstance(value, m._BitType):
return None
elif isinstance(value, m.ArrayType):
return len(value)
else:
raise NotImplementedError(f"Cannot get_length of {type(value)}")


def mux(I, S):
if isinstance(S, int):
return I[S]
elif S.const():
return I[m.bitutils.seq2int(S.bits())]
return Mux(len(I), get_length(I[0]))(*I, S)


def test_if_statement_basic():
class TestIfStatementBasic(m.Circuit):
IO = ["I", m.In(m.Bits(2)), "S", m.In(m.Bit), "O", m.Out(m.Bit)]

@m.circuit_def
def definition(io):
if io.S:
O = io.I[0]
# m.wire(io.O, io.I[0])
# io.O = io.I[0]
# TODO: Alternative syntax
# io.O <= io.I[0]
# TODO: Or we could use wire syntax
# wire(io.O, io.I[0])
else:
O = io.I[1]
m.wire(O, io.O)
m.compile("build/test_if_statement_basic", TestIfStatementBasic)
assert check_files_equal(__file__, f"build/test_if_statement_basic.v",
f"gold/test_if_statement_basic.v")


def test_if_statement_nested():
class TestIfStatementNested(m.Circuit):
IO = ["I", m.In(m.Bits(4)), "S", m.In(m.Bits(2)), "O", m.Out(m.Bit)]

@m.circuit_def
def definition(io):
if io.S[0]:
if io.S[1]:
O = io.I[0]
else:
O = io.I[1]
# m.wire(io.O, io.I[0])
# io.O = io.I[0]
# TODO: Alternative syntax
# io.O <= io.I[0]
# TODO: Or we could use wire syntax
# wire(io.O, io.I[0])
else:
if io.S[1]:
O = io.I[2]
else:
O = io.I[3]
m.wire(O, io.O)
m.compile("build/test_if_statement_nested", TestIfStatementNested)
assert check_files_equal(__file__, f"build/test_if_statement_nested.v",
f"gold/test_if_statement_nested.v")


def test_ternary():
class TestTernary(m.Circuit):
IO = ["I", m.In(m.Bits(2)), "S", m.In(m.Bit), "O", m.Out(m.Bit)]

@m.circuit_def
def definition(io):
m.wire(io.O, io.I[0] if io.S else io.I[1])
# io.O = io.I[0] if io.S else io.I[1]
# TODO: Or non block assign?
# io.O <= io.I[0] if io.S else io.[1]
m.compile("build/test_ternary", TestTernary)
assert check_files_equal(__file__, f"build/test_ternary.v",
f"gold/test_ternary.v")


def test_ternary_nested():
class TestTernaryNested(m.Circuit):
IO = ["I", m.In(m.Bits(3)), "S", m.In(m.Bits(2)), "O", m.Out(m.Bit)]

@m.circuit_def
def definition(io):
m.wire(io.O,
io.I[0] if io.S[0] else io.I[1] if io.S[1] else io.I[2])
# io.O = io.I[0] if io.S else io.I[1]
# TODO: Or non block assign?
# io.O <= io.I[0] if io.S else io.[1]
m.compile("build/test_ternary_nested", TestTernaryNested)
assert check_files_equal(__file__, f"build/test_ternary_nested.v",
f"gold/test_ternary_nested.v")


def test_ternary_nested2():
class TestTernaryNested2(m.Circuit):
IO = ["I", m.In(m.Bits(3)), "S", m.In(m.Bits(2)), "O", m.Out(m.Bit)]

@m.circuit_def
def definition(io):
m.wire(io.O,
(io.I[0] if io.S[0] else io.I[1]) if io.S[1] else io.I[2])
# io.O = io.I[0] if io.S else io.I[1]
# TODO: Or non block assign?
# io.O <= io.I[0] if io.S else io.[1]
m.compile("build/test_ternary_nested2", TestTernaryNested2)
assert check_files_equal(__file__, f"build/test_ternary_nested2.v",
f"gold/test_ternary_nested2.v")

0 comments on commit 4672372

Please sign in to comment.