### config

In [162]:
states = [
    "S_IDLE", 
    "S_PREP", # do normalization
    "S_CONV", # do convolution,
    "S_FC", # do linear,
    "S_DONE"
]

config = {
    "fn": "model.sv",
    "module_name": "Model",
    "inputs": {
        "i_clk": {},
        "i_rst_n": {},
        "i_start": {},
        "i_data": {
            "bits": 16, "int_bits": 16, "size": 40
        }
    },
    "outputs": {
        "o_logits": {
            "bits": 32, "int_bits": 24, "size": 27
        },
        "o_finished": {}
    }
}

### localparams

In [163]:
from math import log2, ceil

localparams_code = ""

for i, s in enumerate(states):
    localparams_code += f"localparam {s} = {ceil(log2(len(states)))}'d{i};\n"

In [164]:
print(localparams_code)

localparam S_IDLE = 3'd0;
localparam S_PREP = 3'd1;
localparam S_CONV = 3'd2;
localparam S_FC = 3'd3;
localparam S_DONE = 3'd4;



In [165]:
# model param
modelparam_code = ""
with open("./cnn_param.txt", "r") as f:
    modelparam_code += f.read()

modelparam_code += "\n"

with open("./fc_param.txt", "r") as f:
    modelparam_code += f.read()

### logics

In [195]:
logics = {
    "control": [
        "norm_start",
        "cnn_start",
        "fc_start",
        "finish",
    ],
    "signal": {
        "norm_finish": 5,
        "cnn_finish": 10,
        "fc_finish": 27,
    },
    "states": {
        "num": 1,
        "bits": ceil(log2(len(states)))
    },
    "numbers": {
        "data": {
            "num": 40,
            "bits": 16
        },
        "norm_data": {
            "num": 40,
            "bits": 16
        },
        "norm_data_T": {
            "num": 40,
            "bits": 16
        },
        "cnn_output": {
            "num": 30,
            "bits": 24,
        },
        "fc_output": {
            "num": 27,
            "bits": 32
        }
    }
}

In [196]:
logics_code = ""

# control
for port in logics["control"]:
    logics_code += f"logic {port}_r, {port}_w;\n"
    
# signal
logics_code += "\n"
for port, size in logics["signal"].items():
    logics_code += f"logic {port} [0:{size-1}];\n"

# state
logics_code += "\n"
logics_code += f"logic [{logics['states']['bits']-1}:0] state_r, state_w;\n"

# numbers
logics_code += "\n"
for key, data in logics["numbers"].items():
    logics_code += f"logic [{data['bits']-1}:0] {key} [0:{data['num']-1}];\n"

In [197]:
print(logics_code)

logic norm_start_r, norm_start_w;
logic cnn_start_r, cnn_start_w;
logic fc_start_r, fc_start_w;
logic finish_r, finish_w;

logic norm_finish [0:4];
logic cnn_finish [0:9];
logic fc_finish [0:26];

logic [2:0] state_r, state_w;

logic [15:0] data [0:39];
logic [15:0] norm_data [0:39];
logic [15:0] norm_data_T [0:39];
logic [23:0] cnn_output [0:29];
logic [31:0] fc_output [0:26];



### assign

In [206]:
assign_code = ""

# assign output
# for i in range(config["outputs"]["o_logits"]["size"]):
assign_code += f"assign o_logits[{i}:0] = fc_output[{i}:0];\n"

# assign finish
# assign_code += "\n"
assign_code += "assign o_finished = finish_r;\n"

# assign norm_data_T
for i in range(8):
    line = f"assign norm_data_T[{i*5}: {i*5+4}] = " + "{"
    for j in range(5):
        line += f"norm_data[{i+j*8}],"
    line = line[:-1] + "};\n"
    assign_code += line


In [207]:
print(assign_code)

assign o_logits[7:0] = fc_output[7:0];
assign o_finished = finish_r;
assign norm_data_T[0: 4] = {norm_data[0],norm_data[8],norm_data[16],norm_data[24],norm_data[32]};
assign norm_data_T[5: 9] = {norm_data[1],norm_data[9],norm_data[17],norm_data[25],norm_data[33]};
assign norm_data_T[10: 14] = {norm_data[2],norm_data[10],norm_data[18],norm_data[26],norm_data[34]};
assign norm_data_T[15: 19] = {norm_data[3],norm_data[11],norm_data[19],norm_data[27],norm_data[35]};
assign norm_data_T[20: 24] = {norm_data[4],norm_data[12],norm_data[20],norm_data[28],norm_data[36]};
assign norm_data_T[25: 29] = {norm_data[5],norm_data[13],norm_data[21],norm_data[29],norm_data[37]};
assign norm_data_T[30: 34] = {norm_data[6],norm_data[14],norm_data[22],norm_data[30],norm_data[38]};
assign norm_data_T[35: 39] = {norm_data[7],norm_data[15],norm_data[23],norm_data[31],norm_data[39]};



### norm

In [208]:
def get_norm_module(idx: int, offset = 0):
    module_code = ""
    module_code += "\t" * offset + f"Normalizer norm{idx}(\n"
    module_code += "\t" * offset + "\t.i_clk(i_clk),\n"
    module_code += "\t" * offset + "\t.i_rst_n(i_rst_n),\n"
    module_code += "\t" * offset + "\t.i_start(norm_start_r),\n"
    module_code += "\t" * offset + f"\t.i_data(i_data[{idx*8}:{idx*8+7}]),\n"
    module_code += "\t" * offset + f"\t.o_norm(norm_data[{idx*8}:{idx*8+7}]),\n"
    module_code += "\t" * offset + f"\t.o_finished(norm_finish[{idx}])\n"
    module_code += "\t" * offset + f");\n\n"
    
    return module_code

In [209]:
norm_code = ""

for i in range(5):
    norm_code += get_norm_module(i)

In [210]:
print(norm_code)

Normalizer norm0(
	.i_clk(i_clk),
	.i_rst_n(i_rst_n),
	.i_start(norm_start_r),
	.i_data(i_data[0:7]),
	.o_norm(norm_data[0:7]),
	.o_finished(norm_finish[0])
);

Normalizer norm1(
	.i_clk(i_clk),
	.i_rst_n(i_rst_n),
	.i_start(norm_start_r),
	.i_data(i_data[8:15]),
	.o_norm(norm_data[8:15]),
	.o_finished(norm_finish[1])
);

Normalizer norm2(
	.i_clk(i_clk),
	.i_rst_n(i_rst_n),
	.i_start(norm_start_r),
	.i_data(i_data[16:23]),
	.o_norm(norm_data[16:23]),
	.o_finished(norm_finish[2])
);

Normalizer norm3(
	.i_clk(i_clk),
	.i_rst_n(i_rst_n),
	.i_start(norm_start_r),
	.i_data(i_data[24:31]),
	.o_norm(norm_data[24:31]),
	.o_finished(norm_finish[3])
);

Normalizer norm4(
	.i_clk(i_clk),
	.i_rst_n(i_rst_n),
	.i_start(norm_start_r),
	.i_data(i_data[32:39]),
	.o_norm(norm_data[32:39]),
	.o_finished(norm_finish[4])
);




### cnn

In [211]:
def get_conv_module(idx: int, offset = 0):
    module_code = ""
    module_code += "\t" * offset + f"Conv conv{idx}(\n"
    module_code += "\t" * offset + "\t.i_clk(i_clk),\n"
    module_code += "\t" * offset + "\t.i_rst_n(i_rst_n),\n"
    module_code += "\t" * offset + "\t.i_start(cnn_start_r),\n"
    module_code += "\t" * offset + f"\t.i_kernel(kernel_weight{idx}),\n"
    module_code += "\t" * offset + f"\t.i_data(norm_data_T),\n"
    module_code += "\t" * offset + f"\t.i_bias(cnn_bias[{idx}]),\n"
    module_code += "\t" * offset + f"\t.o_weights(cnn_output[{idx*3}:{idx*3+2}]),\n"
    module_code += "\t" * offset + f"\t.o_finished(cnn_finish[{idx}])\n"
    module_code += "\t" * offset + f");\n\n"
    
    return module_code

In [212]:
cnn_code = ""
cnn_output_channel = 10

for i in range(cnn_output_channel):
    cnn_code += get_conv_module(i)
    
print(cnn_code)

Conv conv0(
	.i_clk(i_clk),
	.i_rst_n(i_rst_n),
	.i_start(cnn_start_r),
	.i_kernel(kernel_weight0),
	.i_data(norm_data_T),
	.i_bias(cnn_bias[0]),
	.o_weights(cnn_output[0:2]),
	.o_finished(cnn_finish[0])
);

Conv conv1(
	.i_clk(i_clk),
	.i_rst_n(i_rst_n),
	.i_start(cnn_start_r),
	.i_kernel(kernel_weight1),
	.i_data(norm_data_T),
	.i_bias(cnn_bias[1]),
	.o_weights(cnn_output[3:5]),
	.o_finished(cnn_finish[1])
);

Conv conv2(
	.i_clk(i_clk),
	.i_rst_n(i_rst_n),
	.i_start(cnn_start_r),
	.i_kernel(kernel_weight2),
	.i_data(norm_data_T),
	.i_bias(cnn_bias[2]),
	.o_weights(cnn_output[6:8]),
	.o_finished(cnn_finish[2])
);

Conv conv3(
	.i_clk(i_clk),
	.i_rst_n(i_rst_n),
	.i_start(cnn_start_r),
	.i_kernel(kernel_weight3),
	.i_data(norm_data_T),
	.i_bias(cnn_bias[3]),
	.o_weights(cnn_output[9:11]),
	.o_finished(cnn_finish[3])
);

Conv conv4(
	.i_clk(i_clk),
	.i_rst_n(i_rst_n),
	.i_start(cnn_start_r),
	.i_kernel(kernel_weight4),
	.i_data(norm_data_T),
	.i_bias(cnn_bias[4]),
	.o_weights(cnn_outpu

### fc

In [176]:
def get_fc_module(idx: int, offset:int = 0):
    module_code = ""
    module_code += "\t" * offset + f"FC fc{idx}(\n"
    module_code += "\t" * offset + "\t.i_clk(i_clk),\n"
    module_code += "\t" * offset + "\t.i_rst_n(i_rst_n),\n"
    module_code += "\t" * offset + "\t.i_start(fc_start_r),\n"
    module_code += "\t" * offset + f"\t.i_weight(fc_weight{idx}),\n"
    module_code += "\t" * offset + f"\t.i_data(cnn_output),\n"
    module_code += "\t" * offset + f"\t.i_bias(fc_bias[{idx}]),\n"
    module_code += "\t" * offset + f"\t.o_output(fc_output[{idx}]),\n"
    module_code += "\t" * offset + f"\t.o_finished(fc_finish[{idx}])\n"
    module_code += "\t" * offset + f");\n\n"
    
    return module_code

In [177]:
fc_code = ""
classes = 27

for i in range(classes):
    fc_code += get_fc_module(i)
    
print(fc_code)

FC fc0(
	.i_clk(i_clk),
	.i_rst_n(i_rst_n),
	.i_start(fc_start_r),
	.i_weight(fc_weight0),
	.i_data(cnn_output),
	.i_bias(fc_bias[0]),
	.o_output(fc_output[0]),
	.o_finished(fc_finish[0])
);

FC fc1(
	.i_clk(i_clk),
	.i_rst_n(i_rst_n),
	.i_start(fc_start_r),
	.i_weight(fc_weight1),
	.i_data(cnn_output),
	.i_bias(fc_bias[1]),
	.o_output(fc_output[1]),
	.o_finished(fc_finish[1])
);

FC fc2(
	.i_clk(i_clk),
	.i_rst_n(i_rst_n),
	.i_start(fc_start_r),
	.i_weight(fc_weight2),
	.i_data(cnn_output),
	.i_bias(fc_bias[2]),
	.o_output(fc_output[2]),
	.o_finished(fc_finish[2])
);

FC fc3(
	.i_clk(i_clk),
	.i_rst_n(i_rst_n),
	.i_start(fc_start_r),
	.i_weight(fc_weight3),
	.i_data(cnn_output),
	.i_bias(fc_bias[3]),
	.o_output(fc_output[3]),
	.o_finished(fc_finish[3])
);

FC fc4(
	.i_clk(i_clk),
	.i_rst_n(i_rst_n),
	.i_start(fc_start_r),
	.i_weight(fc_weight4),
	.i_data(cnn_output),
	.i_bias(fc_bias[4]),
	.o_output(fc_output[4]),
	.o_finished(fc_finish[4])
);

FC fc5(
	.i_clk(i_clk),
	.i_rst_n(i_rst_

### always_comb

In [178]:
indent = 0
always_comb_code = ""

always_comb_code += "always_comb begin\n"

indent += 1
for control in logics["control"]:
    always_comb_code += "\t" * indent + f"{control}_w = {control}_r;\n"
always_comb_code += "\t" * indent + f"state_w = state_r;\n"    
always_comb_code += "\n"

always_comb_code += "\t" * indent + "case (state_r)\n"

indent += 1
for state in states:
    always_comb_code += "\t" * indent + f"{state}: begin\n\n"
    always_comb_code += "\t" * indent + "end\n"
indent -= 1

always_comb_code += "\t" * indent + "endcase\n"

always_comb_code += "end\n"

In [179]:
print(always_comb_code)

always_comb begin
	norm_start_w = norm_start_r;
	cnn_start_w = cnn_start_r;
	fc_start_w = fc_start_r;
	finish_w = finish_r;
	state_w = state_r;

	case (state_r)
		S_IDLE: begin

		end
		S_PREP: begin

		end
		S_CONV: begin

		end
		S_FC: begin

		end
		S_DONE: begin

		end
	endcase
end



### always_ff

In [180]:
indent = 0
always_ff_code = ""

always_ff_code += "\t" * indent + "always_ff @ (posedge i_clk or negedge i_rst_n) begin\n"

indent += 1

always_ff_code += "\t" * indent + "if (!i_rst_n) begin\n"

indent += 1
for control in logics["control"]:
    always_ff_code += "\t" * indent + f"{control}_r <= 0;\n"
always_ff_code += "\t" * indent + f"state_r <= S_IDLE;\n"
indent -= 1

always_ff_code += "\t" * indent + "end\n"

always_ff_code += "\t" * indent + "else begin\n"

indent += 1
for control in logics["control"]:
    always_ff_code += "\t" * indent + f"{control}_r <= {control}_w;\n"
always_ff_code += "\t" * indent + f"state_r <= state_w;\n"
indent -= 1
always_ff_code += "\t" * indent + "end\n"

indent -= 1

always_ff_code += "\t" * indent + "end\n"

In [181]:
print(always_ff_code)

always_ff @ (posedge i_clk or negedge i_rst_n) begin
	if (!i_rst_n) begin
		norm_start_r <= 0;
		cnn_start_r <= 0;
		fc_start_r <= 0;
		finish_r <= 0;
		state_r <= S_IDLE;
	end
	else begin
		norm_start_r <= norm_start_w;
		cnn_start_r <= cnn_start_w;
		fc_start_r <= fc_start_w;
		finish_r <= finish_w;
		state_r <= state_w;
	end
end



In [182]:
with open(config["fn"], "w") as fp:
    fp.write(f"module {config['module_name']}(\n")

    # configure inputs and outputs

    for port, setting in config["inputs"].items():
        fp.write("\tinput ")

        if setting.get("bits"):
            fp.write(f"[{setting['bits']-1}:0] ")

        fp.write(f"{port}")

        if setting.get("size"):
            fp.write(f" [0:{setting['size']-1}],\n")
        else:
            fp.write(",\n")

    for port, setting in config["outputs"].items():
        fp.write("\toutput ")

        if setting.get("bits"):
            fp.write(f"[{setting['bits']-1}:0] ")

        fp.write(f"{port}")

        if setting.get("size"):
            fp.write(f" [0:{setting['size']-1}],\n")
        else:
            fp.write(",\n")

    fp.write(");\n")
    
    fp.write("\n")
    
    fp.write(localparams_code)
    fp.write("\n")
    fp.write(modelparam_code)
    fp.write("\n")
    fp.write(logics_code)
    fp.write("\n")
    fp.write(assign_code)
    fp.write("\n")
    fp.write(norm_code)
    fp.write("\n")
    fp.write(cnn_code)
    fp.write("\n")
    fp.write(fc_code)
    fp.write("\n")
    fp.write(always_comb_code)
    fp.write("\n")
    fp.write(always_ff_code)
    fp.write("\n")
    fp.write("endmodule\n")