In [1]:
from magma import *
from functools import reduce

def one_hot_mux(conds, inputs):
    """Select one of several equal-width bit vectors using one-hot select lines.

    Each input is masked by ANDing it with its select bit replicated across the
    input's width, and all masked vectors are ORed together. When exactly one
    condition is high the result equals the corresponding input; if none is
    high the result is all zeros, and if several are high their inputs are ORed.

    Args:
        conds: select bits, one per entry of ``inputs``.
        inputs: bit vectors (supporting ``len``, ``&`` with a list of bits and
            ``|`` with each other) to choose between.

    Returns:
        The selected value as a list of bits.

    Raises:
        ValueError: if ``conds`` and ``inputs`` differ in length, or are empty.
    """
    conds = list(conds)
    inputs = list(inputs)
    if len(conds) != len(inputs):
        # zip() would silently drop the extra entries and hide a wiring mistake.
        raise ValueError(
            "one_hot_mux: got {} conditions for {} inputs".format(len(conds), len(inputs)))
    if not inputs:
        # reduce() over an empty sequence would fail with an opaque TypeError.
        raise ValueError("one_hot_mux: at least one input is required")
    # Replicate each select bit across its input's width so it can mask every bit.
    masked = [inp & [cond] * len(inp) for cond, inp in zip(conds, inputs)]
    return list(reduce(lambda acc, nxt: acc | nxt, masked))


def uint(*ts, num_bits=None):
    """Build a constant magma ``UInt`` of width ``num_bits`` holding ``ts[0]``.

    Only the first positional argument is used; any further positional
    arguments are ignored. ``num_bits`` is forwarded unchanged to both
    ``int2seq`` and ``UInt`` — NOTE(review): a width presumably has to be
    supplied, since ``None`` is passed straight through; confirm against magma.
    """
    first_value = ts[0]
    # Convert to a bit sequence first, then build the sized type and instantiate it.
    bit_seq = int2seq(first_value, num_bits)
    uint_type = UInt(num_bits, BitOut)
    return uint_type(*bit_seq)


class SimpleALU(Circuit):
    """Combinational ALU on 4-bit operands selected by a 2-bit opcode.

    opcode 0 -> a + b, 1 -> a - b, 2 -> a, 3 -> b.
    """
    name = "SimpleALU"
    IO = ["a", In(UInt(4)), "b", In(UInt(4)), "opcode", In(UInt(2)), "out", Out(UInt(4))]

    @classmethod
    def definition(io):
        # `io` receives the class itself (classmethod); its port attributes
        # a, b, opcode and out are what the circuit body is built from.
        # One select line per possible 2-bit opcode value, in order 0..3.
        selects = [io.opcode == uint(code, num_bits=2) for code in range(4)]
        # Candidate results indexed by opcode. The adder is built before the
        # subtractor, matching the original order, since construction order
        # may affect generated instance names.
        results = [io.a + io.b, io.a - io.b, io.a, io.b]
        wire(io.out, one_hot_mux(selects, results))

In [2]:
# Compile SimpleALU to CoreIR and show the generated JSON netlist.
build_base = "build/SimpleALU"
compile(build_base, SimpleALU, output="coreir")
with open(build_base + ".json", "r") as json_file:
    print(json_file.read())

{
  "namespaces": {
    "_G": {
      "modules": {
        "SimpleALU": {
          "connections": [
            [
              "inst9.in1.3",
              "inst3.out"
            ],
            [
              "inst9.in1.2",
              "inst3.out"
            ],
            [
              "inst9.in1.1",
              "inst3.out"
            ],
            [
              "inst9.in1.0",
              "inst3.out"
            ],
            [
              "inst8.in1.3",
              "inst2.out"
            ],
            [
              "inst8.in1.2",
              "inst2.out"
            ],
            [
              "inst8.in1.1",
              "inst2.out"
            ],
            [
              "inst8.in1.0",
              "inst2.out"
            ],
            [
              "inst7.in1.3",
              "inst1.out"
            ],
            [
              "inst7.in1.2",
              "inst1.out"
            ],
            [
              "inst7.in1.1",
              "i

In [3]:
from magma.python_simulator import PythonSimulator
from magma.scope import Scope
from magma.bit_vector import BitVector

simulator = PythonSimulator(SimpleALU)
scope = Scope()


def check_alu(a, b, opcode, expected):
    """Drive SimpleALU's inputs, evaluate, and assert the output equals ``expected``.

    Args:
        a, b: operand values, driven as 4-bit vectors.
        opcode: operation selector, driven as a 2-bit vector.
        expected: integer expected on ``out``, compared as a 4-bit sequence.

    Raises:
        AssertionError: naming the failing case if the output does not match.
    """
    simulator.set_value(SimpleALU.a, scope, BitVector(a, num_bits=4).as_bool_list())
    simulator.set_value(SimpleALU.b, scope, BitVector(b, num_bits=4).as_bool_list())
    simulator.set_value(SimpleALU.opcode, scope, BitVector(opcode, num_bits=2).as_bool_list())
    simulator.evaluate()
    actual = simulator.get_value(SimpleALU.out, scope)
    assert actual == int2seq(expected, 4), \
        "opcode={}, a={}, b={}: expected {}, got {}".format(opcode, a, b, expected, actual)


A_VAL, B_VAL = 3, 2
# (opcode, expected output) for each ALU operation: add, sub, pass a, pass b.
for opcode, expected in [(0, A_VAL + B_VAL), (1, A_VAL - B_VAL), (2, A_VAL), (3, B_VAL)]:
    check_alu(A_VAL, B_VAL, opcode, expected)
print("Success!")

Success!
