Skip to content

Commit

Permalink
allow passing in of enviroment into comb and seq
Browse files Browse the repository at this point in the history
  • Loading branch information
rdaly525 committed Nov 1, 2019
1 parent f56d049 commit 712012d
Show file tree
Hide file tree
Showing 10 changed files with 227 additions and 10 deletions.
11 changes: 7 additions & 4 deletions magma/syntax/combinational.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from magma.config import get_debug_mode
import itertools
import typing
from ast_tools.stack import _SKIP_FRAME_DEBUG_STMT
from ast_tools.stack import _SKIP_FRAME_DEBUG_STMT, SymbolTable

class CircuitDefinitionSyntaxError(Exception):
pass
Expand Down Expand Up @@ -255,8 +255,8 @@ def combinational(
fn: typing.Callable = None,
*,
decorators: typing.Optional[typing.Sequence[typing.Callable]] = None,
env: SymbolTable = None
):

exec(_SKIP_FRAME_DEBUG_STMT)
if decorators is not None:
assert fn is None
Expand All @@ -266,13 +266,16 @@ def wrapped(fn):
decorators = list(itertools.chain(decorators, [wrapped]))
wrapped_combinational = ast_utils.inspect_enclosing_env(
_combinational,
decorators=decorators)
decorators=decorators,
st=env
)
return wrapped_combinational(fn)
return wrapped

else:
wrapped_combinational = ast_utils.inspect_enclosing_env(
_combinational,
decorators=[combinational]
decorators=[combinational],
st=env
)
return wrapped_combinational(fn)
15 changes: 9 additions & 6 deletions magma/syntax/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from collections import Counter
import itertools

from ast_tools.stack import _SKIP_FRAME_DEBUG_STMT
from ast_tools.stack import _SKIP_FRAME_DEBUG_STMT, SymbolTable

class RewriteSelfAttributes(ast.NodeTransformer):
def __init__(self, initial_value_map):
Expand Down Expand Up @@ -401,7 +401,7 @@ def sequential(
async_reset=None,
*,
decorators: typing.Optional[typing.Sequence[typing.Callable]] = None,
):
env: SymbolTable = None):

exec(_SKIP_FRAME_DEBUG_STMT)
if async_reset is not None or decorators is not None:
Expand All @@ -416,15 +416,18 @@ def wrapped(cls):
decorators = list(itertools.chain(decorators, [wrapped]))
wrapped_sequential = ast_utils.inspect_enclosing_env(
_sequential,
decorators=decorators)
combinational = m.circuit.combinational(decorators=decorators)
decorators=decorators,
st=env
)
combinational = m.circuit.combinational(decorators=decorators, env=env)
return wrapped_sequential(async_reset, cls, combinational)
return wrapped
else:
assert cls is not None
wrapped_sequential = ast_utils.inspect_enclosing_env(
_sequential,
decorators=[sequential]
decorators=[sequential],
st=env
)
combinational = m.circuit.combinational(decorators=[sequential])
combinational = m.circuit.combinational(decorators=[sequential], env=env)
return wrapped_sequential(True, cls, combinational)
58 changes: 58 additions & 0 deletions tests/test_syntax/gold/CustomEnv.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
{"top":"global.TestBasic",
"namespaces":{
"global":{
"modules":{
"TestBasic":{
"type":["Record",[
["I",["Array",2,"BitIn"]],
["CLK",["Named","coreir.clkIn"]],
["ASYNCRESET",["Named","coreir.arstIn"]],
["O",["Array",2,"Bit"]]
]],
"instances":{
"TestBasic_comb_inst0":{
"modref":"global.TestBasic_comb"
},
"reg_PR_inst0":{
"genref":"coreir.reg_arst",
"genargs":{"width":["Int",2]},
"modargs":{"arst_posedge":["Bool",true], "clk_posedge":["Bool",true], "init":[["BitVector",2],"2'h2"]}
},
"reg_PR_inst1":{
"genref":"coreir.reg_arst",
"genargs":{"width":["Int",2]},
"modargs":{"arst_posedge":["Bool",true], "clk_posedge":["Bool",true], "init":[["BitVector",2],"2'h0"]}
}
},
"connections":[
["self.I","TestBasic_comb_inst0.I"],
["reg_PR_inst0.in","TestBasic_comb_inst0.O0"],
["reg_PR_inst1.in","TestBasic_comb_inst0.O1"],
["self.O","TestBasic_comb_inst0.O2"],
["reg_PR_inst0.out","TestBasic_comb_inst0.self_x_O"],
["reg_PR_inst1.out","TestBasic_comb_inst0.self_y_O"],
["self.ASYNCRESET","reg_PR_inst0.arst"],
["self.CLK","reg_PR_inst0.clk"],
["self.ASYNCRESET","reg_PR_inst1.arst"],
["self.CLK","reg_PR_inst1.clk"]
]
},
"TestBasic_comb":{
"type":["Record",[
["I",["Array",2,"BitIn"]],
["self_x_O",["Array",2,"BitIn"]],
["self_y_O",["Array",2,"BitIn"]],
["O0",["Array",2,"Bit"]],
["O1",["Array",2,"Bit"]],
["O2",["Array",2,"Bit"]]
]],
"connections":[
["self.O0","self.I"],
["self.self_x_O","self.O1"],
["self.self_y_O","self.O2"]
]
}
}
}
}
}
31 changes: 31 additions & 0 deletions tests/test_syntax/gold/CustomEnv.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
module coreir_reg_arst #(parameter width = 1, parameter arst_posedge = 1, parameter clk_posedge = 1, parameter init = 1) (input clk, input arst, input [width-1:0] in, output [width-1:0] out);
reg [width-1:0] outReg;
wire real_rst;
assign real_rst = arst_posedge ? arst : ~arst;
wire real_clk;
assign real_clk = clk_posedge ? clk : ~clk;
always @(posedge real_clk, posedge real_rst) begin
if (real_rst) outReg <= init;
else outReg <= in;
end
assign out = outReg;
endmodule

module TestBasic_comb (input [1:0] I, output [1:0] O0, output [1:0] O1, output [1:0] O2, input [1:0] self_x_O, input [1:0] self_y_O);
assign O0 = I;
assign O1 = self_x_O;
assign O2 = self_y_O;
endmodule

module TestBasic (input ASYNCRESET, input CLK, input [1:0] I, output [1:0] O);
wire [1:0] TestBasic_comb_inst0_O0;
wire [1:0] TestBasic_comb_inst0_O1;
wire [1:0] TestBasic_comb_inst0_O2;
wire [1:0] reg_PR_inst0_out;
wire [1:0] reg_PR_inst1_out;
TestBasic_comb TestBasic_comb_inst0(.I(I), .O0(TestBasic_comb_inst0_O0), .O1(TestBasic_comb_inst0_O1), .O2(TestBasic_comb_inst0_O2), .self_x_O(reg_PR_inst0_out), .self_y_O(reg_PR_inst1_out));
coreir_reg_arst #(.arst_posedge(1), .clk_posedge(1), .init(2'h2), .width(2)) reg_PR_inst0(.arst(ASYNCRESET), .clk(CLK), .in(TestBasic_comb_inst0_O0), .out(reg_PR_inst0_out));
coreir_reg_arst #(.arst_posedge(1), .clk_posedge(1), .init(2'h0), .width(2)) reg_PR_inst1(.arst(ASYNCRESET), .clk(CLK), .in(TestBasic_comb_inst0_O1), .out(reg_PR_inst1_out));
assign O = TestBasic_comb_inst0_O2;
endmodule

38 changes: 38 additions & 0 deletions tests/test_syntax/gold/custom_env0.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{"top":"global.basic_fun",
"namespaces":{
"global":{
"modules":{
"Mux2":{
"type":["Record",[
["I0","BitIn"],
["I1","BitIn"],
["S","BitIn"],
["O","Bit"]
]]
},
"basic_fun":{
"type":["Record",[
["I","BitIn"],
["S","BitIn"],
["O","Bit"]
]],
"instances":{
"Mux2_inst0":{
"modref":"global.Mux2"
},
"bit_const_0_None":{
"modref":"corebit.const",
"modargs":{"value":["Bool",false]}
}
},
"connections":[
["bit_const_0_None.out","Mux2_inst0.I0"],
["self.I","Mux2_inst0.I1"],
["self.O","Mux2_inst0.O"],
["self.S","Mux2_inst0.S"]
]
}
}
}
}
}
6 changes: 6 additions & 0 deletions tests/test_syntax/gold/custom_env0.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
module basic_fun (input I, input S, output O);
wire Mux2_inst0_O;
Mux2 Mux2_inst0 (.I0(1'b0), .I1(I), .S(S), .O(Mux2_inst0_O));
assign O = Mux2_inst0_O;
endmodule

38 changes: 38 additions & 0 deletions tests/test_syntax/gold/custom_env1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{"top":"global.basic_fun",
"namespaces":{
"global":{
"modules":{
"Mux2":{
"type":["Record",[
["I0","BitIn"],
["I1","BitIn"],
["S","BitIn"],
["O","Bit"]
]]
},
"basic_fun":{
"type":["Record",[
["I","BitIn"],
["S","BitIn"],
["O","Bit"]
]],
"instances":{
"Mux2_inst0":{
"modref":"global.Mux2"
},
"bit_const_1_None":{
"modref":"corebit.const",
"modargs":{"value":["Bool",true]}
}
},
"connections":[
["bit_const_1_None.out","Mux2_inst0.I0"],
["self.I","Mux2_inst0.I1"],
["self.O","Mux2_inst0.O"],
["self.S","Mux2_inst0.S"]
]
}
}
}
}
}
6 changes: 6 additions & 0 deletions tests/test_syntax/gold/custom_env1.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
module basic_fun (input I, input S, output O);
wire Mux2_inst0_O;
Mux2 Mux2_inst0 (.I0(1'b1), .I1(I), .S(S), .O(Mux2_inst0_O));
assign O = Mux2_inst0_O;
endmodule

15 changes: 15 additions & 0 deletions tests/test_syntax/test_combinational.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,18 @@ def definition(io):
io.O <= inv.O

compile_and_check("test_renamed_args_wire", Foo, target)


@pytest.mark.parametrize("val", [0,1])
def test_custom_env(target, val):
def basic_fun(I: m.Bit, S: m.Bit) -> m.Bit:
if S:
return I
else:
return m.Bit(_custom_local_var_)

_globals = globals()
_globals.update({'_custom_local_var_':val})
env = ast_tools.stack.SymbolTable(locals=locals(),globals=_globals)
_basic_fun = m.circuit.combinational(basic_fun,env=env)
compile_and_check(f"custom_env{val}", _basic_fun.circuit_definition, target)
19 changes: 19 additions & 0 deletions tests/test_syntax/test_sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,25 @@ def __call__(self, I: m.Bits[2]) -> m.Bits[2]:
"""
_run_verilator(TestBasic, directory="tests/test_syntax/build")

def test_custom_env(target):

_globals = globals()
_globals.update({'_custom_local_var_':2})
env = ast_tools.stack.SymbolTable(locals=locals(),globals=_globals)

class TestBasic:
def __init__(self):
self.x: m.Bits[2] = m.bits(_custom_local_var_, 2)
self.y: m.Bits[2] = m.bits(0, 2)

def __call__(self, I: m.Bits[2]) -> m.Bits[2]:
O = self.y
self.y = self.x
self.x = I
return O

_TestBasic = m.circuit.sequential(TestBasic,env=env)
compile_and_check("CustomEnv", _TestBasic, target)

def test_seq_hierarchy(target, async_reset):
@m.cache_definition
Expand Down

0 comments on commit 712012d

Please sign in to comment.