In [9]:
from magma import *
from functools import reduce
import os
os.environ["MANTLE"] = "lattice"
from mantle import EQ, Add, Sub, Or, And

def one_hot_mux(conds, inputs):
    """Select among ``inputs`` using one-hot select bits ``conds``.

    Every input is masked by ANDing each of its bits with its select bit, and
    the masked values are ORed together. With exactly one select bit high the
    result is that input; with none high it is all zeros; with several high it
    is the bitwise OR of the selected inputs.

    Args:
        conds: iterable of 1-bit select signals, one per input.
        inputs: iterable of bit vectors of equal width to choose from.

    Returns:
        A bit vector with the same width as the inputs.

    Raises:
        ValueError: if ``inputs`` is empty, if ``conds`` and ``inputs`` differ
            in length, or if the inputs do not all share one width.
    """
    conds = list(conds)
    inputs = list(inputs)
    # zip() would silently drop unmatched entries; make the mismatch explicit.
    if len(conds) != len(inputs):
        raise ValueError(
            "one_hot_mux: got {} select bits for {} inputs".format(len(conds), len(inputs)))
    # reduce() on an empty sequence fails with an opaque TypeError.
    if not inputs:
        raise ValueError("one_hot_mux: at least one input is required")
    width = len(inputs[0])
    if any(len(inp) != width for inp in inputs):
        raise ValueError("one_hot_mux: all inputs must have the same width")
    # Gate width follows the data instead of being fixed at 4 bits; the
    # select bit is replicated across the word so it can be ANDed bitwise.
    masked = [And(2, width)(inp, bits([cond] * width))
              for cond, inp in zip(conds, inputs)]
    return reduce(lambda acc, nxt: Or(2, width)(acc, nxt), masked)


class SimpleALU(Circuit):
    """4-bit ALU: the 2-bit ``opcode`` chooses what drives ``out``.

    opcode 0 -> a + b, opcode 1 -> a - b, opcode 2 -> a, opcode 3 -> b.
    """
    name = "SimpleALU"
    IO = ["a", In(UInt(4)), "b", In(UInt(4)), "opcode", In(UInt(2)), "out", Out(UInt(4))]

    @classmethod
    def definition(io):
        # One equality comparator per opcode value, built in order 0..3 so that
        # select line k pairs with candidate k below.
        selects = [EQ(2)(io.opcode, uint(code, n=2)) for code in range(4)]
        candidates = [
            Add(4)(io.a, io.b),  # opcode 0
            Sub(4)(io.a, io.b),  # opcode 1
            io.a,                # opcode 2
            io.b,                # opcode 3
        ]
        wire(io.out, one_hot_mux(selects, candidates))

In [10]:
from magma.backend.verilog import compile as compile_verilog

# Generate the Verilog text for SimpleALU (the output below lists each module
# emitted); print() is used so newlines render instead of a quoted repr.
verilog_text = compile_verilog(SimpleALU)
print(verilog_text)

compiling EQ2
compiling FullAdder
compiling Add4
compiling Invert4
compiling Add4Cin
compiling Sub4
compiling And2
compiling And2x4
compiling Or2
compiling Or2x4
compiling SimpleALU
module EQ2 (input [1:0] I0, input [1:0] I1, output  O);
wire  inst0_O;
SB_LUT4 #(.LUT_INIT(16'h9009)) inst0 (.I0(I0[0]), .I1(I1[0]), .I2(I0[1]), .I3(I1[1]), .O(inst0_O));
assign O = inst0_O;
endmodule

module FullAdder (input  I0, input  I1, input  CIN, output  O, output  COUT);
wire  inst0_O;
wire  inst1_CO;
SB_LUT4 #(.LUT_INIT(16'h9696)) inst0 (.I0(I0), .I1(I1), .I2(CIN), .I3(1'b0), .O(inst0_O));
SB_CARRY inst1 (.I0(I0), .I1(I1), .CI(CIN), .CO(inst1_CO));
assign O = inst0_O;
assign COUT = inst1_CO;
endmodule

module Add4 (input [3:0] I0, input [3:0] I1, output [3:0] O);
wire  inst0_O;
wire  inst0_COUT;
wire  inst1_O;
wire  inst1_COUT;
wire  inst2_O;
wire  inst2_COUT;
wire  inst3_O;
wire  inst3_COUT;
FullAdder inst0 (.I0(I0[0]), .I1(I1[0]), .CIN(1'b0), .O(inst0_O), .COUT(inst0_COUT));
FullAdder inst1 (.I0(

In [11]:
from magma.simulator import PythonSimulator
from magma.bit_vector import BitVector

# Test vectors: (a, b, opcode, expected out). The opcode map mirrors
# SimpleALU.definition: 0 -> a + b, 1 -> a - b, 2 -> a, 3 -> b.
ALU_CASES = [
    (3, 2, 0, 3 + 2),
    (3, 2, 1, 3 - 2),
    (3, 2, 2, 3),
    (3, 2, 3, 2),
]


def check_alu(simulator, a, b, opcode, expected):
    """Drive SimpleALU's inputs, evaluate once, and assert ``out == expected``.

    Args:
        simulator: a PythonSimulator built for SimpleALU.
        a, b: integer values for the 4-bit ``a`` and ``b`` inputs.
        opcode: integer value for the 2-bit ``opcode`` input.
        expected: integer value ``out`` should hold (compared as 4 bits).

    Raises:
        AssertionError: naming the failing case if the output differs.
    """
    simulator.set_value(SimpleALU.a, BitVector(a, num_bits=4))
    simulator.set_value(SimpleALU.b, BitVector(b, num_bits=4))
    simulator.set_value(SimpleALU.opcode, BitVector(opcode, num_bits=2))
    simulator.evaluate()
    actual = simulator.get_value(SimpleALU.out)
    assert actual == BitVector(expected, num_bits=4), (
        "opcode={}, a={}, b={}: expected {}, got {}".format(opcode, a, b, expected, actual))


simulator = PythonSimulator(SimpleALU)
for a, b, opcode, expected in ALU_CASES:
    check_alu(simulator, a, b, opcode, expected)
print("Success!")

Success!
