In [1]:
import magma as m
from functools import reduce
m.set_mantle_target("coreir")
import mantle

def one_hot_mux(conds, inputs):
    """Select among ``inputs`` using a one-hot list of select bits.

    Each input is ANDed with its condition bit replicated to that input's
    width, and the masked inputs are ORed together:

    * exactly one condition high -> the corresponding input;
    * no condition high          -> all zeros;
    * several conditions high    -> the OR of the selected inputs.

    Args:
        conds: sequence of single-bit select signals, one per input.
        inputs: sequence of bit-vector values to choose from (presumably
            all of the same width -- TODO confirm for mixed widths).

    Returns:
        The OR of every input masked by its condition.

    Raises:
        ValueError: if ``conds`` and ``inputs`` differ in length (``zip``
            would otherwise silently drop the extras), or if both are
            empty (there is nothing to select).
    """
    conds = list(conds)
    inputs = list(inputs)
    if len(conds) != len(inputs):
        raise ValueError(
            "one_hot_mux: got %d conditions for %d inputs"
            % (len(conds), len(inputs)))
    if not inputs:
        raise ValueError("one_hot_mux: needs at least one input")
    # Replicate each select bit across the input's width so a single
    # bitwise AND masks the whole vector.
    masked = [inp & m.bits([cond] * len(inp))
              for cond, inp in zip(conds, inputs)]
    return reduce(lambda acc, value: acc | value, masked)


class SimpleALU(m.Circuit):
    """4-bit ALU selected by a 2-bit opcode.

    Opcode map (as wired in ``definition``):
        0 -> a + b
        1 -> a - b
        2 -> a (pass-through)
        3 -> b (pass-through)
    """
    name = "SimpleALU"
    IO = [
        "a", m.In(m.UInt(4)),
        "b", m.In(m.UInt(4)),
        "opcode", m.In(m.UInt(2)),
        "out", m.Out(m.UInt(4)),
    ]

    @classmethod
    def definition(io):
        # One equality comparator per opcode value, built in ascending
        # order 0..3; together they form the one-hot select vector.
        selects = [io.opcode == m.uint(code, n=2) for code in range(4)]
        # Candidate results, listed in the same order as the opcodes.
        results = [io.a + io.b, io.a - io.b, io.a, io.b]
        m.wire(io.out, one_hot_mux(selects, results))

In [2]:
from pathlib import Path

# Create the output directory in case it is missing (e.g. a fresh checkout),
# so Restart & Run All does not depend on leftover state.
Path("build").mkdir(exist_ok=True)
m.compile("build/SimpleALU", SimpleALU, output="coreir")
# Show the generated CoreIR JSON. Reading the file from Python works on every
# platform, unlike IPython's `%cat` alias, which is only defined on POSIX.
print(Path("build/SimpleALU.json").read_text())

{"top":"global.SimpleALU",
"namespaces":{
  "global":{
    "modules":{
      "Add4_cin":{
        "type":["Record",[
          ["I0",["Array",4,"BitIn"]],
          ["I1",["Array",4,"BitIn"]],
          ["O",["Array",4,"Bit"]],
          ["CIN","BitIn"]
        ]],
        "instances":{
          "bit_const_0_None":{
            "modref":"corebit.const",
            "modargs":{"value":["Bool",false]}
          },
          "inst0":{
            "genref":"coreir.add",
            "genargs":{"width":["Int",4]}
          },
          "inst1":{
            "genref":"coreir.add",
            "genargs":{"width":["Int",4]}
          }
        },
        "connections":[
          ["inst1.in0.1","bit_const_0_None.out"],
          ["inst1.in0.2","bit_const_0_None.out"],
          ["inst1.in0.3","bit_const_0_None.out"],
          ["inst1.out","inst0.in0"],
          ["self.I1","inst0.in1"],
          ["self.O","inst0.out"],
          ["self.CIN","inst1.in0.0"],
  

In [3]:
from magma.simulator import PythonSimulator
from magma.bit_vector import BitVector

# Operand values driven for every opcode check.
A, B = 3, 2
# The ALU output is declared as UInt(4); mask arithmetic results to that
# width so expectations stay right if A/B are changed to values that
# overflow (e.g. A < B makes ``a - b`` wrap around).
WIDTH_MASK = (1 << 4) - 1
# (opcode, expected ``out``) pairs, in opcode order.
EXPECTED_BY_OPCODE = [
    (0, (A + B) & WIDTH_MASK),  # add
    (1, (A - B) & WIDTH_MASK),  # subtract
    (2, A),                     # pass a through
    (3, B),                     # pass b through
]

simulator = PythonSimulator(SimpleALU)
simulator.set_value(SimpleALU.a, A)
simulator.set_value(SimpleALU.b, B)

for opcode, expected in EXPECTED_BY_OPCODE:
    simulator.set_value(SimpleALU.opcode, opcode)
    simulator.evaluate()
    actual = simulator.get_value(SimpleALU.out)
    # Report the opcode and both values so any failure is self-explanatory.
    assert actual == expected, "opcode {}: expected {}, got {}".format(
        opcode, expected, actual)
print("Success!")

Success!
