In [1]:
from magma import *
from functools import reduce

def one_hot_mux(conds, inputs):
    """Select among ``inputs`` using per-input select signals ``conds``.

    Each input is AND-ed with a list holding ``len(inp)`` copies of its
    select signal, and the masked terms are then OR-ed together.

    Args:
        conds: iterable of select signals, one per input.
        inputs: iterable of multi-bit values; each must support ``len()``,
            ``&`` against a list, and ``|`` against another masked term.

    Returns:
        ``list(...)`` of the OR-reduction of all masked inputs.

    Notes:
        * The pairs come from ``zip``, so extra entries in the longer of
          ``conds`` / ``inputs`` are ignored.
        * ``reduce`` is called without an initial value, so an empty
          selection raises ``TypeError``.
        * Nothing here checks that exactly one select is asserted; the
          result is only a plain "mux" when the caller guarantees that.
    """
    # Mask every candidate with its own select signal, replicated to the
    # candidate's width.
    masked_terms = [
        candidate & [select] * len(candidate)
        for select, candidate in zip(conds, inputs)
    ]
    # Fold the masked candidates together with bitwise OR, left to right.
    merged = reduce(lambda acc, term: acc | term, masked_terms)
    return list(merged)



class SimpleALU(Circuit):
    """Four-operation ALU on 4-bit unsigned operands.

    Ports: ``a`` and ``b`` are 4-bit inputs, ``opcode`` is a 2-bit input and
    ``out`` is a 4-bit output. ``opcode`` selects the value driven on ``out``:

        0 -> a + b
        1 -> a - b
        2 -> a
        3 -> b
    """
    name = "SimpleALU"
    IO = [
        "a", In(UInt(4)),
        "b", In(UInt(4)),
        "opcode", In(UInt(2)),
        "out", Out(UInt(4)),
    ]

    @classmethod
    def definition(io):
        """Build the ALU body and wire the selected result to ``io.out``.

        Being a classmethod, ``io`` is bound to the class itself; the ports
        declared in ``IO`` are read from it as attributes.
        """
        # One equality comparison per opcode value 0..3, built in ascending
        # order before any arithmetic.
        opcode_selects = [io.opcode == uint(code, 2) for code in range(4)]
        # Candidate results, listed in the same order as the opcode values
        # that choose them.
        candidates = [
            io.a + io.b,  # opcode 0
            io.a - io.b,  # opcode 1
            io.a,         # opcode 2
            io.b,         # opcode 3
        ]
        wire(io.out, one_hot_mux(opcode_selects, candidates))

In [2]:
# Import the Verilog backend's `compile` under an alias so it does not shadow
# Python's built-in `compile`.
from magma.backend.verilog import compile as compile_verilog

# Print whatever the backend returns for SimpleALU; the cell output below
# shows the generated `module SimpleALU (...)` text.
print(compile_verilog(SimpleALU))

compiling SimpleALU
module SimpleALU (input [3:0] a, input [3:0] b, input [1:0] opcode, output [3:0] out);
wire  inst0_out;
wire  inst1_out;
wire  inst2_out;
wire  inst3_out;
wire [3:0] inst4_out;
wire [3:0] inst5_out;
wire [3:0] inst6_out;
wire [3:0] inst7_out;
wire [3:0] inst8_out;
wire [3:0] inst9_out;
wire [3:0] inst10_out;
wire [3:0] inst11_out;
wire [3:0] inst12_out;
coreir_eq #(.width(2)) inst0 (.in0(opcode), .in1({1'b0,1'b0}), .out(inst0_out));
coreir_eq #(.width(2)) inst1 (.in0(opcode), .in1({1'b0,1'b1}), .out(inst1_out));
coreir_eq #(.width(2)) inst2 (.in0(opcode), .in1({1'b1,1'b0}), .out(inst2_out));
coreir_eq #(.width(2)) inst3 (.in0(opcode), .in1({1'b1,1'b1}), .out(inst3_out));
coreir_add #(.width(4)) inst4 (.in0(a), .in1(b), .out(inst4_out));
coreir_sub #(.width(4)) inst5 (.in0(a), .in1(b), .out(inst5_out));
coreir_and #(.width(4)) inst6 (.in0(inst4_out), .in1({inst0_out,inst0_out,inst0_out,inst0_out}), .out(inst6_out));
coreir_and #(.width(4)) inst7 (.in0(inst5_out), .in

In [3]:
from magma.python_simulator import PythonSimulator
from magma.scope import Scope
from magma.bit_vector import BitVector

# Exercise SimpleALU in the Python simulator once per opcode and check the
# value observed on `out`. The stimulus is table-driven so each opcode is a
# single row instead of a copy-pasted stanza.
simulator = PythonSimulator(SimpleALU)
scope = Scope()

# Operand values applied to ports `a` and `b` for every opcode.
A_VALUE = 3
B_VALUE = 2

# (opcode, expected value on `out`), following the opcode mapping in
# SimpleALU.definition: 0 -> a + b, 1 -> a - b, 2 -> a, 3 -> b.
expected_by_opcode = [
    (0, A_VALUE + B_VALUE),
    (1, A_VALUE - B_VALUE),
    (2, A_VALUE),
    (3, B_VALUE),
]

for opcode, expected in expected_by_opcode:
    # Operands are re-applied on every iteration so each check sets all three
    # inputs before evaluating.
    simulator.set_value(SimpleALU.a, scope, BitVector(A_VALUE, num_bits=4))
    simulator.set_value(SimpleALU.b, scope, BitVector(B_VALUE, num_bits=4))
    simulator.set_value(SimpleALU.opcode, scope, BitVector(opcode, num_bits=2))
    simulator.evaluate()

    actual = simulator.get_value(SimpleALU.out, scope)
    # The message names the failing opcode so a broken case is identifiable.
    assert actual == BitVector(expected, num_bits=4), (
        "opcode {}: expected {}, got {}".format(opcode, expected, actual)
    )

print("Success!")

Success!
