# nPE: A Configurable Processing Engine
#### Verification | Version 0.4.1 | Updated 2018.7.25
___

## Setup

In [1]:
// Load the project's Ivy dependency script from <working dir>/source/load-ivy.sc.
// NOTE(review): the script presumably pulls in chisel3 and chisel-iotesters used below — confirm.
val path = System.getProperty("user.dir") + "/source/load-ivy.sc"
interp.load.module(ammonite.ops.Path(java.nio.file.Paths.get(path)))

[36mpath[39m: [32mString[39m = [32m"""
C:\Users\RyanL\OneDrive\Research\SEAL\processing-engine/source/load-ivy.sc
"""[39m

In [2]:
import chisel3._
import chisel3.util._
import chisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester}

import scala.math.pow

[32mimport [39m[36mchisel3._
[39m
[32mimport [39m[36mchisel3.util._
[39m
[32mimport [39m[36mchisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester}

[39m
[32mimport [39m[36mscala.math.pow[39m

## Parallel Register File

### Single Register File

#### Definition

In [350]:
// Input side of a single register file (RF): one write port (wEnable/wAddr/wData)
// and two read addresses (rAddrInt/rAddrExt) that share a single read enable.
//   dataWidth - width of the signed data word
//   addrWidth - width of every address field
class RFInputs(dataWidth: Int, addrWidth: Int) extends Bundle {
    
    // Parameterized Bundle: cloneType rebuilds a copy with the same widths.
    override def cloneType = (new RFInputs(dataWidth, addrWidth)).asInstanceOf[this.type]
    
    val wEnable  = Input(Bool())              // write enable
    val rEnable  = Input(Bool())              // read enable, gates both read ports
    val wAddr    = Input(UInt(addrWidth.W))   // write address
    val wData    = Input(SInt(dataWidth.W))   // write data
    val rAddrInt = Input(UInt(addrWidth.W))   // read address, "Int" port (NOTE(review): presumably "internal" — confirm)
    val rAddrExt = Input(UInt(addrWidth.W))   // read address, "Ext" port (NOTE(review): presumably "external" — confirm)
}

// Output side of a single register file: the data read from rAddrInt and rAddrExt.
//   dataWidth - width of the signed data word
class RFOutputs(dataWidth: Int) extends Bundle {
    
    // Parameterized Bundle: cloneType rebuilds a copy with the same width.
    override def cloneType = (new RFOutputs(dataWidth)).asInstanceOf[this.type]
    
    val rDataInt = Output(SInt(dataWidth.W))  // data at rAddrInt
    val rDataExt = Output(SInt(dataWidth.W))  // data at rAddrExt
}

// Register file with 2^addrWidth signed entries, one synchronous write port and
// two combinational read ports.
//   dataWidth - width of each stored word (and of wData / rData*)
//   addrWidth - address width; the file holds 1 << addrWidth entries
// Writes take effect on the next clock edge when wEnable is high. Reads are
// combinational; both read outputs are forced to 0 while rEnable is low.
class RF (dataWidth: Int, addrWidth: Int) extends Module {

    require(dataWidth >= 1, "Data width must be at least one.")
    require(addrWidth >= 1, "Address width must be at least one.")

    val io = IO(new Bundle {
        val in  = new RFInputs(dataWidth, addrWidth)
        val out = new RFOutputs(dataWidth)
    })

    // Each entry must be dataWidth bits wide; sizing it by addrWidth would
    // silently truncate every write whenever dataWidth > addrWidth.
    val registers  = RegInit(Vec(Seq.fill(1 << addrWidth) { 0.S(dataWidth.W) }))

    when(io.in.wEnable) {
        registers(io.in.wAddr) := io.in.wData
    }

    when(io.in.rEnable) {
        io.out.rDataInt := registers(io.in.rAddrInt)
        io.out.rDataExt := registers(io.in.rAddrExt)
    } .otherwise {
        io.out.rDataInt := 0.S
        io.out.rDataExt := 0.S
    }
}

defined [32mclass[39m [36mRFInputs[39m
defined [32mclass[39m [36mRFOutputs[39m
defined [32mclass[39m [36mRF[39m

#### Verification

In [351]:
Driver(() => new RF(8, 4)) {
    uut => new PeekPokeTester(uut) {

        // Present (addr, data) on the write port and clock it in.
        def writeReg(addr: Int, data: Int): Unit = {
            poke(uut.io.in.wAddr, addr)
            poke(uut.io.in.wData, data)
            step(1)
        }

        // Combinational read through the "Int" read port.
        def checkInt(addr: Int, expected: Int): Unit = {
            poke(uut.io.in.rAddrInt, addr)
            expect(uut.io.out.rDataInt, expected)
        }

        // Combinational read through the "Ext" read port.
        def checkExt(addr: Int, expected: Int): Unit = {
            poke(uut.io.in.rAddrExt, addr)
            expect(uut.io.out.rDataExt, expected)
        }

        poke(uut.io.in.wEnable, true)
        poke(uut.io.in.rEnable, true)

        writeReg(1, 1)
        checkInt(1, 1)
        checkExt(1, 1)

        writeReg(2, 2)
        checkInt(1, 1)
        checkExt(2, 2)

        writeReg(3, 3)
        checkInt(1, 1)
        checkExt(2, 2)
        checkInt(3, 3)
    }
}

[[35minfo[0m] [0.000] Elaborating design...
[[35minfo[0m] [0.009] Done elaborating.
Total FIRRTL Compile Time: 31.7 ms
Total FIRRTL Compile Time: 35.2 ms
End of dependency graph
Circuit state created
[[35minfo[0m] [0.000] SEED 1532562325397
test cmd349WrapperHelperRF Success: 7 tests passed in 8 cycles taking 0.012143 seconds
[[35minfo[0m] [0.009] RAN 3 CYCLES PASSED


[36mres350[39m: [32mBoolean[39m = [32mtrue[39m

### Putting Them Together

#### Definition

In [352]:
// Parallel register file: `ports` independent RF instances, one per lane.
//   ports     - number of lanes (each lane owns a private RF)
//   bypass    - "None": reads always come from the RF
//               "Soft": adds io.bpSel; when bpSel(i) is high, lane i's read
//                       outputs forward lane i's incoming wData instead of the RF
//               "Hard": no storage at all; every read output forwards wData
//   dataWidth - signed data width
//   addrWidth - per-RF address width
class PRF(ports: Int, bypass: String, dataWidth: Int, addrWidth: Int) extends Module {

    require(ports >= 1, "Must have at least one port.")
    require(List("None", "Soft", "Hard").contains(bypass),
            "Bypass must be \"None\", \"Soft\" or \"Hard\"")

    val io = IO(new Bundle {
        val in    = Vec(ports, new RFInputs(dataWidth, addrWidth))
        val out   = Vec(ports, new RFOutputs(dataWidth))
        val bpSel = if (bypass == "Soft") Some(Input(Vec(ports, Bool()))) else None
    })

    if (bypass == "None" || bypass == "Soft") {

        val rf = Seq.fill(ports){ Module(new RF(dataWidth, addrWidth)) }

        rf.zipWithIndex.foreach { case (x, i) =>

            x.io.in.wEnable  := io.in(i).wEnable
            x.io.in.rEnable  := io.in(i).rEnable
            x.io.in.wAddr    := io.in(i).wAddr
            x.io.in.wData    := io.in(i).wData
            x.io.in.rAddrInt := io.in(i).rAddrInt
            x.io.in.rAddrExt := io.in(i).rAddrExt

            // Without a bpSel port ("None") the bypass condition is constant false.
            val bypassSel = io.bpSel.map(_(i)).getOrElse(false.B)

            when (bypassSel) {
                io.out(i).rDataInt := x.io.in.wData
                io.out(i).rDataExt := x.io.in.wData
            } .otherwise {
                io.out(i).rDataInt := x.io.out.rDataInt
                io.out(i).rDataExt := x.io.out.rDataExt
            }
        }

    } else {
        // "Hard" bypass: no register files; reads always return the incoming write data.
        for (i <- 0 until ports) {
            io.out(i).rDataInt := io.in(i).wData
            io.out(i).rDataExt := io.in(i).wData
        }
    }
}

defined [32mclass[39m [36mPRF[39m

#### Verification

In [353]:
Driver(() => new PRF(2, "Soft", 8, 4)) {
    uut => new PeekPokeTester(uut) {

        val lanes = 0 until 2

        // Present the same (addr, data) write on every lane and clock it in.
        def writeAll(addr: Int, data: Int): Unit = {
            lanes.foreach(p => poke(uut.io.in(p).wAddr, addr))
            lanes.foreach(p => poke(uut.io.in(p).wData, data))
            step(1)
        }

        // Read `addr` through every lane's "Int" port and check the result.
        def checkIntAll(addr: Int, expected: Int): Unit = {
            lanes.foreach(p => poke(uut.io.in(p).rAddrInt, addr))
            lanes.foreach(p => expect(uut.io.out(p).rDataInt, expected))
        }

        // Read `addr` through every lane's "Ext" port and check the result.
        def checkExtAll(addr: Int, expected: Int): Unit = {
            lanes.foreach(p => poke(uut.io.in(p).rAddrExt, addr))
            lanes.foreach(p => expect(uut.io.out(p).rDataExt, expected))
        }

        lanes.foreach(p => poke(uut.io.in(p).wEnable, true))
        lanes.foreach(p => poke(uut.io.in(p).rEnable, true))
        lanes.foreach(p => poke(uut.io.bpSel.get(p), false))

        writeAll(1, 1)
        checkIntAll(1, 1)
        checkExtAll(1, 1)

        writeAll(2, 2)
        checkIntAll(1, 1)
        checkExtAll(2, 2)

        writeAll(3, 3)
        checkIntAll(1, 1)
        checkExtAll(2, 2)
        checkIntAll(3, 3)

        // Bypass: lane 0 forwards its incoming wData; lane 1 still reads its RF
        // (rAddrInt = 3 and rAddrExt = 2 are left over from the reads above).
        poke(uut.io.bpSel.get(0), true)
        poke(uut.io.bpSel.get(1), false)
        lanes.foreach(p => poke(uut.io.in(p).wData, 10))
        expect(uut.io.out(0).rDataInt, 10)
        expect(uut.io.out(1).rDataInt, 3)
        expect(uut.io.out(0).rDataExt, 10)
        expect(uut.io.out(1).rDataExt, 2)
    }
}

[[35minfo[0m] [0.000] Elaborating design...
[[35minfo[0m] [0.007] Done elaborating.
Total FIRRTL Compile Time: 30.3 ms
Total FIRRTL Compile Time: 25.3 ms
End of dependency graph
Circuit state created
[[35minfo[0m] [0.000] SEED 1532562339958
test cmd351WrapperHelperPRF Success: 18 tests passed in 8 cycles taking 0.022411 seconds
[[35minfo[0m] [0.019] RAN 3 CYCLES PASSED


[36mres352[39m: [32mBoolean[39m = [32mtrue[39m

## Inner Product Unit

### Parallel Multiplier

#### Definition

In [233]:
// Lane-wise signed multiplier: out(i) = in1(i) * in2(i) for every lane i.
//   numPairs - number of independent lanes (>= 1)
//   bitWidth - width of the operands and of each output (>= 1)
// NOTE(review): the full-width product is connected to a bitWidth-wide output,
// so upper product bits are presumably dropped on connection — confirm intended.
class PMultiplier(numPairs: Int, bitWidth: Int) extends Module {

    require(numPairs >= 1, "Must have at least one pair of multiplicands.")
    require(bitWidth >= 1, "Bitwidth must be at least one.")

    val io = IO(new Bundle {
        val in1 = Input (Vec(numPairs, SInt(bitWidth.W)))
        val in2 = Input (Vec(numPairs, SInt(bitWidth.W)))
        val out = Output(Vec(numPairs, SInt(bitWidth.W)))
    })

    for (lane <- 0 until numPairs) {
        io.out(lane) := io.in1(lane) * io.in2(lane)
    }
}

defined [32mclass[39m [36mPMultiplier[39m

#### Verification

In [234]:
Driver(() => new PMultiplier(4, 8)) {
    uut => new PeekPokeTester(uut) {

        // One entry per lane: (in1, in2, expected product).
        val vectors = Seq((1, 2, 2), (3, 4, 12), (5, 6, 30), (7, 8, 56))

        for (((a, b, _), lane) <- vectors.zipWithIndex) {
            poke(uut.io.in1(lane), a)
            poke(uut.io.in2(lane), b)
        }

        for (((_, _, product), lane) <- vectors.zipWithIndex) {
            expect(uut.io.out(lane), product)
        }
  }
}

[[35minfo[0m] [0.000] Elaborating design...
[[35minfo[0m] [0.002] Done elaborating.
Total FIRRTL Compile Time: 2.9 ms
Total FIRRTL Compile Time: 2.5 ms
End of dependency graph
Circuit state created
[[35minfo[0m] [0.000] SEED 1532545425580
test cmd232WrapperHelperPMultiplier Success: 4 tests passed in 5 cycles taking 0.002662 seconds
[[35minfo[0m] [0.002] RAN 0 CYCLES PASSED


[36mres233[39m: [32mBoolean[39m = [32mtrue[39m

### Additive Reduction Tree

#### Definition

In [235]:
// Reduces `xs` with a balanced tree of `op` applications: neighbours are
// combined pairwise, level by level, until one value is left. An odd element
// at the end of a level is carried up unchanged. For n elements the tree depth
// is ceil(log2(n)), versus n - 1 for a left fold, which keeps hardware trees shallow.
// `op` need not be associative; e.g. for five elements the grouping is
// (((x0 op x1) op (x2 op x3)) op x4).
//   xs - elements to reduce; must be non-empty
//   op - binary combiner
// Returns the single reduced value.
// Throws IllegalArgumentException if `xs` is empty.
@scala.annotation.tailrec
def nonassocPairwiseReduce[A](xs: List[A], op: (A, A) => A): A = xs match {
  case Nil =>
    throw new IllegalArgumentException("nonassocPairwiseReduce requires a non-empty list")
  case single :: Nil =>
    single
  case _ =>
    // grouped(2) yields pairs plus at most one trailing singleton; reducing a
    // pair applies op once, reducing a singleton returns it unchanged.
    nonassocPairwiseReduce(xs.grouped(2).map(_.reduce(op)).toList, op)
}


// Balanced adder tree summing numAddends signed inputs into one output.
// The tree is built with nonassocPairwiseReduce, so its depth is
// ceil(log2(numAddends)).
//   numAddends - number of inputs (>= 1)
//   bitWidth   - width of every input, partial sum and the output (>= 1)
// Each `+` keeps the operand width, so partial sums wrap at bitWidth bits
// (the generated Verilog keeps only the low bits of each sum).
class AdditiveRT(numAddends: Int, bitWidth: Int) extends Module {

    require(numAddends >= 1, "Number of addends must be at least one.")
    require(bitWidth >= 1, "Bitwidth must be at least one.")

    val io = IO(new Bundle {
        val in  = Input (Vec(numAddends, SInt(bitWidth.W)))
        val out = Output(SInt(bitWidth.W))
    })

    io.out := nonassocPairwiseReduce(io.in.toList, (x: SInt, y: SInt) => x + y)
}

defined [32mfunction[39m [36mnonassocPairwiseReduce[39m
defined [32mclass[39m [36mAdditiveRT[39m

#### Verilog

In [236]:
// Print the generated Verilog for a power-of-two (4) and a non-power-of-two (6)
// number of 4-bit addends, to inspect the shape of the adder tree.
for (addends <- Seq(4, 6)) println(getVerilog(new AdditiveRT(addends, 4)))

[[35minfo[0m] [0.000] Elaborating design...
[[35minfo[0m] [0.003] Done elaborating.
Total FIRRTL Compile Time: 109.2 ms

module cmd234WrapperHelperAdditiveRT( // @[:@3.2]
  input        clock, // @[:@4.4]
  input        reset, // @[:@5.4]
  input  [3:0] io_in_0, // @[:@6.4]
  input  [3:0] io_in_1, // @[:@6.4]
  input  [3:0] io_in_2, // @[:@6.4]
  input  [3:0] io_in_3, // @[:@6.4]
  output [3:0] io_out // @[:@6.4]
);
  wire [4:0] _T_12; // @[cmd234.sc 29:76:@8.4]
  wire [3:0] _T_13; // @[cmd234.sc 29:76:@9.4]
  wire [3:0] _T_14; // @[cmd234.sc 29:76:@10.4]
  wire [4:0] _T_15; // @[cmd234.sc 29:76:@11.4]
  wire [3:0] _T_16; // @[cmd234.sc 29:76:@12.4]
  wire [3:0] _T_17; // @[cmd234.sc 29:76:@13.4]
  wire [4:0] _T_18; // @[cmd234.sc 29:76:@14.4]
  wire [3:0] _T_19; // @[cmd234.sc 29:76:@15.4]
  wire [3:0] _T_20; // @[cmd234.sc 29:76:@16.4]
  assign _T_12 = $signed(io_in_0) + $signed(io_in_1); // @[cmd234.sc 29:76:@8.4]
  assign _T_13 = _T_12[3:0]; // @[cmd234.sc 29:76:@9.4]
  assign 

#### Verification

In [237]:
Driver(() => new AdditiveRT(4, 8)) {
    uut => new PeekPokeTester(uut) {
        // All-positive addends: 1 + 2 + 8 + 9 = 20.
        poke(uut.io.in(0), 1)
        poke(uut.io.in(1), 2)
        poke(uut.io.in(2), 8)
        poke(uut.io.in(3), 9)
        expect(uut.io.out, 20)

        // Mixed-sign addends exercise the signed datapath:
        // (-4 + 2) + (10 + 7) = -2 + 17 = 15.
        poke(uut.io.in(0), -4)
        poke(uut.io.in(1), 2)
        poke(uut.io.in(2), 10)
        poke(uut.io.in(3), 7)
        expect(uut.io.out, 15)
    }
}

[[35minfo[0m] [0.000] Elaborating design...
[[35minfo[0m] [0.001] Done elaborating.
Total FIRRTL Compile Time: 2.6 ms
Total FIRRTL Compile Time: 2.1 ms
End of dependency graph
Circuit state created
[[35minfo[0m] [0.000] SEED 1532545575458
test cmd234WrapperHelperAdditiveRT Success: 2 tests passed in 5 cycles taking 0.001228 seconds
[[35minfo[0m] [0.001] RAN 0 CYCLES PASSED


[36mres236[39m: [32mBoolean[39m = [32mtrue[39m

### Putting Them Together

#### Definition

In [244]:
// Validates IPU elaboration parameters; throws IllegalArgumentException (via
// require) on the first violated constraint.
//   width      - number of multiplicand pairs per inner product; must be >= 1
//   bypassType - "None" (no bypass ports) or "Firm" (adds sel/bp1/bp2 ports)
//   bitWidth   - signed operand/result width; must be >= 1, the same bound the
//                PMultiplier and AdditiveRT submodules enforce
def checkparamsIPU(width: Int, bypassType: String, bitWidth: Int): Unit = {
    require(width >= 1, "Width must be at least one.")
    require(List("None", "Firm").contains(bypassType), "Bypass must be \"None\" or \"Firm\"")
    require(bitWidth >= 1, "Bitwidth must be at least one.")
}


// Inner-product unit: out = sum over i of in1(i) * in2(i), computed as lane-wise
// products (PMultiplier) followed by a balanced adder tree (AdditiveRT).
//   width      - number of operand pairs
//   bypassType - "None" or "Firm"; "Firm" adds a lane select (sel) and two
//                outputs (bp1/bp2) that forward the selected lane's raw operands
//   bitWidth   - signed operand/result width
class IPU(width: Int, bypassType: String, bitWidth: Int) extends Module {

    checkparamsIPU(width, bypassType, bitWidth)

    private val firm: Boolean = bypassType == "Firm"

    val io = IO(new Bundle {
        val in1 = Input(Vec(width, SInt(bitWidth.W)))
        val in2 = Input(Vec(width, SInt(bitWidth.W)))
        val out = Output(SInt(bitWidth.W))
        val sel = if (firm) Some(Input(Vec(width, Bool())))  else None
        val bp1 = if (firm) Some(Output(SInt(bitWidth.W))) else None
        val bp2 = if (firm) Some(Output(SInt(bitWidth.W))) else None
    })

    // Datapath: lane-wise products feed the adder tree.
    val pM = Module(new PMultiplier(width, bitWidth))
    pM.io.in1 := io.in1
    pM.io.in2 := io.in2

    val aRT = Module(new AdditiveRT(width, bitWidth))
    aRT.io.in := pM.io.out

    io.out := aRT.io.out

    // Bypass ports exist only for "Firm"; the body runs only when all three are present.
    // PriorityMux gives priority to the lowest-index asserted select.
    for (laneSel <- io.sel; out1 <- io.bp1; out2 <- io.bp2) {
        out1 := PriorityMux(laneSel, io.in1)
        out2 := PriorityMux(laneSel, io.in2)
    }
}

defined [32mfunction[39m [36mcheckparamsIPU[39m
defined [32mclass[39m [36mIPU[39m

#### Verification

In [259]:
Driver(() => new IPU(width=4, bypassType="Firm", bitWidth=8)) {
    uut => new PeekPokeTester(uut) {

        // Drive the lane select: only lane `hot` high (hot = -1 clears every lane).
        def selectLane(hot: Int): Unit =
            for (lane <- 0 until 4) poke(uut.io.sel.get(lane), if (lane == hot) 1 else 0)

        selectLane(-1)

        Seq(1, 2, 3, 4).zipWithIndex.foreach { case (v, lane) => poke(uut.io.in1(lane), v) }
        Seq(5, 6, 7, 8).zipWithIndex.foreach { case (v, lane) => poke(uut.io.in2(lane), v) }

        // 1*5 + 2*6 + 3*7 + 4*8 = 70
        expect(uut.io.out, 70)

        selectLane(1)
        expect(uut.io.bp1.get, 2)
        expect(uut.io.bp2.get, 6)

        selectLane(2)
        expect(uut.io.bp1.get, 3)
        expect(uut.io.bp2.get, 7)
    }
}

[[35minfo[0m] [0.000] Elaborating design...
[[35minfo[0m] [0.001] Done elaborating.
Total FIRRTL Compile Time: 6.7 ms
Total FIRRTL Compile Time: 6.2 ms
End of dependency graph
Circuit state created
[[35minfo[0m] [0.000] SEED 1532546560594
test cmd243WrapperHelperIPU Success: 5 tests passed in 5 cycles taking 0.002253 seconds
[[35minfo[0m] [0.002] RAN 0 CYCLES PASSED


[36mres258[39m: [32mBoolean[39m = [32mtrue[39m

## ALU

#### Definition

In [281]:
// Validates ALU elaboration parameters; throws IllegalArgumentException (via
// require) if "Identity" is missing or any entry is not a supported function.
//   funcs     - ALU functions to build; must include "Identity"
//   datawidth - data width (currently not checked)
def checkparamsALU(funcs: List[String], datawidth: Int): Unit = {
    val supportedFuncs = Set("Identity", "Add", "Max", "Accumulate")
    require(funcs.contains("Identity"), "ALU functions must explicitly include Identity.")
    require(funcs.forall(supportedFuncs.contains), "Unsupported Function")
}

// Configurable ALU stage. Each entry of `funcs` contributes one candidate
// result, and io.funcSel(i) selects the result of funcs(i) (PriorityMux: the
// lowest asserted index wins).
//   "Identity"   - io.innerProduct
//   "Add"        - io.weightBp + io.actvtnBp
//   "Max"        - the larger (signed compare) of io.weightBp and io.actvtnBp
//   "Accumulate" - io.innerProduct + io.rfFeedback
// weightBp/actvtnBp exist only when "Add" or "Max" is configured; rfFeedback
// exists only with "Accumulate".
//   funcs     - functions to build, in select order (validated by checkparamsALU)
//   dataWidth - signed data width of every input and the output
class ALU(funcs: List[String], dataWidth: Int) extends Module {

    checkparamsALU(funcs, dataWidth)

    val addBypassIn = List("Add", "Max").intersect(funcs).nonEmpty
    val addFeedback = funcs.contains("Accumulate")

    val io = IO(new Bundle {
        val innerProduct = Input(SInt(dataWidth.W))
        val funcSel      = Input(Vec(funcs.length, Bool()))
        val output       = Output(SInt(dataWidth.W))
        val weightBp     = if(addBypassIn) Some(Input(SInt(dataWidth.W))) else None
        val actvtnBp     = if(addBypassIn) Some(Input(SInt(dataWidth.W))) else None
        val rfFeedback   = if(addFeedback) Some(Input(SInt(dataWidth.W))) else None
    })

    val idnOut = Some(Wire(SInt(dataWidth.W)))
    val addOut = if(funcs.contains("Add"))        Some(Wire(SInt(dataWidth.W))) else None
    val maxOut = if(funcs.contains("Max"))        Some(Wire(SInt(dataWidth.W))) else None
    val accOut = if(funcs.contains("Accumulate")) Some(Wire(SInt(dataWidth.W))) else None

    idnOut.get := io.innerProduct

    if (funcs.contains("Add")       ) { addOut.get := io.weightBp.get + io.actvtnBp.get }
    if (funcs.contains("Accumulate")) { accOut.get := io.innerProduct + io.rfFeedback.get }
    if (funcs.contains("Max")       ) {
        when (io.weightBp.get > io.actvtnBp.get) {
            maxOut.get := io.weightBp.get
        } .otherwise {
            maxOut.get := io.actvtnBp.get
        }
    }

    // Build the mux inputs in `funcs` order so funcSel(i) always selects funcs(i),
    // whatever order the caller lists the functions in. Every name is supported
    // (checkparamsALU) and its wire exists because it appears in `funcs`.
    val inters = funcs.map {
        case "Identity"   => idnOut.get
        case "Add"        => addOut.get
        case "Max"        => maxOut.get
        case "Accumulate" => accOut.get
    }
    io.output := PriorityMux(io.funcSel, inters)
}

defined [32mfunction[39m [36mcheckparamsALU[39m
defined [32mclass[39m [36mALU[39m

#### Verification

In [287]:
val funcs = "Identity" :: "Add" :: "Max" :: "Accumulate" :: Nil

Driver(() => new ALU(funcs, 8)) {
    uut => new PeekPokeTester(uut) {

        poke(uut.io.innerProduct, 1)
        poke(uut.io.weightBp.get, 2)
        poke(uut.io.actvtnBp.get, 3)
        poke(uut.io.rfFeedback.get, 4)

        // (index of the selected function, expected output)
        val cases = Seq(
            0 -> 1, // Identity:   innerProduct
            1 -> 5, // Add:        weightBp + actvtnBp = 2 + 3
            2 -> 3, // Max:        max(2, 3)
            3 -> 5  // Accumulate: innerProduct + rfFeedback = 1 + 4
        )

        for ((hot, expected) <- cases) {
            for (i <- funcs.indices) poke(uut.io.funcSel(i), if (i == hot) 1 else 0)
            expect(uut.io.output, expected)
        }
    }
}

[[35minfo[0m] [0.000] Elaborating design...
[[35minfo[0m] [0.001] Done elaborating.
Total FIRRTL Compile Time: 5.0 ms
Total FIRRTL Compile Time: 4.2 ms
End of dependency graph
Circuit state created
[[35minfo[0m] [0.000] SEED 1532556227733
test cmd280WrapperHelperALU Success: 4 tests passed in 5 cycles taking 0.002605 seconds
[[35minfo[0m] [0.002] RAN 0 CYCLES PASSED


[36mfuncs[39m: [32mList[39m[[32mString[39m] = [33mList[39m([32m"Identity"[39m, [32m"Add"[39m, [32m"Max"[39m, [32m"Accumulate"[39m)
[36mres286_1[39m: [32mBoolean[39m = [32mtrue[39m

## Nonlinear Unit

In [9]:
// Validates NonlinearUnit parameters; throws IllegalArgumentException (via
// require) if "Identity" is missing or any entry is not a supported function.
//   funcs     - nonlinear functions to build; must include "Identity"
//   datawidth - data width (currently not checked)
def checkparamsNLU(funcs: List[String], datawidth: Int): Unit = {
    val supportedFuncs = Set("Identity", "ReLu")
    require(funcs.contains("Identity"), "NLU functions must explicitly include Identity.")
    require(funcs.forall(supportedFuncs.contains), "Unsupported Function")
}

// Selectable nonlinear activation stage. Each entry of `funcs` contributes one
// candidate result, and io.fslct(i) selects the result of funcs(i)
// (PriorityMux: the lowest asserted index wins).
//   "Identity" - io.input
//   "ReLu"     - io.input if it is greater than 0, otherwise 0
//   funcs     - functions to build, in select order (validated by checkparamsNLU)
//   datawidth - signed data width of the input and output
class NonlinearUnit(funcs: List[String], datawidth: Int) extends Module {

    checkparamsNLU(funcs, datawidth)

    val io = IO(new Bundle {
        val input = Input(SInt(datawidth.W))
        val fslct = Input(Vec(funcs.length, Bool()))
        val outpt = Output(SInt(datawidth.W))
    })

    val idntOut = Some(Wire(SInt(datawidth.W)))
    val reluOut = if(funcs.contains("ReLu")) Some(Wire(SInt(datawidth.W))) else None

    idntOut.get := io.input
    if (funcs.contains("ReLu")) {
        when (io.input > 0.S) {
            reluOut.get := io.input
        } .otherwise {
            reluOut.get := 0.S
        }
    }

    // Build the mux inputs in `funcs` order so fslct(i) always selects funcs(i),
    // whatever order the caller lists the functions in. Every name is supported
    // (checkparamsNLU) and its wire exists because it appears in `funcs`.
    val inters = funcs.map {
        case "Identity" => idntOut.get
        case "ReLu"     => reluOut.get
    }
    io.outpt := PriorityMux(io.fslct, inters)
}

defined [32mfunction[39m [36mcheckparamsNLU[39m
defined [32mclass[39m [36mNonlinearUnit[39m

## Control

### State Machine

#### Definition

In [340]:
// Generic state register whose output is the current state (reset to 0). On
// every clock edge it loads nextState(currentState, io.control, stateWidth).
//   numStates - number of states; the register is log2Up(numStates) bits wide
//   nextState - combinational transition function (state, control, width) => next
//               state; its Int argument is the width of the state register
//   ctrlWidth - width of the control input
class StateMachine(numStates: Int, nextState: (UInt, UInt, Int) => UInt, ctrlWidth: Int) extends Module {

    require(numStates >= 1, "Must have at least one state.")
    require(ctrlWidth >= 1, "Control width must be at least one.")

    val stateWidth: Int = log2Up(numStates)

    val io = IO(new Bundle {
        val control = Input (UInt(ctrlWidth.W ))
        val out     = Output(UInt(stateWidth.W))
    })

    val register = RegInit(0.U(stateWidth.W))
    // Hand the transition function the *state* width (not ctrlWidth) so the
    // next-state value it builds matches the width of `register`.
    register := nextState(register, io.control, stateWidth)
    io.out := register
}

defined [32mclass[39m [36mStateMachine[39m

#### Example

In [341]:
// Example two-state transition function for StateMachine: from state 0 or 1,
// control == 1 moves to state 1 and control == 0 moves to state 0. Any other
// (state, control) combination falls back to state 0.
//   state      - current state
//   control    - control input
//   stateWidth - width of the returned next-state value
def stateMap(state: UInt, control: UInt, stateWidth: Int): UInt = {

    val next = Wire(UInt(stateWidth.W))

    // Default to state 0; override only for the transitions that lead to state 1.
    next := 0.U
    when ((state === 0.U || state === 1.U) && control === 1.U) {
        next := 1.U
    }

    next
}

defined [32mfunction[39m [36mstateMap[39m

#### Verification

In [343]:
Driver(() => new StateMachine(2, stateMap, 4)) {
    uut => new PeekPokeTester(uut) {
        // Reset state is 0.
        poke(uut.io.control, 0)
        expect(uut.io.out, 0)

        // Each entry: (control input, expected state after one clock edge).
        val transitions = Seq(
            1 -> 1, // 0 -> 1
            1 -> 1, // 1 -> 1
            0 -> 0, // 1 -> 0
            0 -> 0  // 0 -> 0
        )

        for ((ctrl, nextState) <- transitions) {
            poke(uut.io.control, ctrl)
            step(1)
            expect(uut.io.out, nextState)
        }
    }
}

[[35minfo[0m] [0.000] Elaborating design...
[[35minfo[0m] [0.002] Done elaborating.
Total FIRRTL Compile Time: 10.1 ms
Total FIRRTL Compile Time: 12.5 ms
End of dependency graph
Circuit state created
[[35minfo[0m] [0.000] SEED 1532561171036
test cmd339WrapperHelperStateMachine Success: 5 tests passed in 9 cycles taking 0.004398 seconds
[[35minfo[0m] [0.004] RAN 4 CYCLES PASSED


[36mres342[39m: [32mBoolean[39m = [32mtrue[39m

### Decoder

#### Definition

In [362]:
// Per-port control signals the Decoder drives for one parallel register file:
// the RFInputs control fields (everything except wData), replicated per port,
// plus a per-port bypass select. Every field is an Output of the Decoder.
//   ports     - number of register-file ports
//   addrWidth - width of each address field
class RFControl(ports: Int, addrWidth: Int) extends Bundle {
    
    // Parameterized Bundle: cloneType rebuilds a copy with the same ports/addrWidth.
    override def cloneType = (new RFControl(ports, addrWidth)).asInstanceOf[this.type]
    
    val wEnable  = Output(Vec(ports, Bool()))               // per-port write enable
    val rEnable  = Output(Vec(ports, Bool()))               // per-port read enable
    val wAddr    = Output(Vec(ports, UInt(addrWidth.W)))    // per-port write address
    val rAddrInt = Output(Vec(ports, UInt(addrWidth.W)))    // per-port "Int" read address
    val rAddrExt = Output(Vec(ports, UInt(addrWidth.W)))    // per-port "Ext" read address
    val bpSelGet = Output(Vec(ports, Bool()))               // per-port bypass select (NOTE(review): presumably feeds PRF's bpSel — confirm)
}

// Control decoder: maps the current state to every control output of the
// processing engine by calling the user-supplied `decode(state, key)` once per
// output. Keys are strings such as "weightRF wEnable" or "aluFuncSel"; `decode`
// must return a Data value that can be connected to the named output.
//   decode     - (state, key) => control value
//   statewidth - width of io.state
//   ports      - ports per register file
//   datawidth  - data width (not used by the decoder itself)
//   addrwidth  - register-file address width
//   aluFuncs   - ALU function list; sets the width of io.aluFuncSel
//   nluFuncs   - NLU function list; sets the width of io.nluFuncSel
class Decoder(decode: (UInt, String) => Data,
              statewidth: Int,
              ports: Int, datawidth: Int, addrwidth: Int,
              aluFuncs: List[String], 
              nluFuncs: List[String]) extends Module {

    val io = IO(new Bundle {

        val state = Input(UInt(statewidth.W))

        val weightRFControl = new RFControl(ports, addrwidth)
        val actvtnRFControl = new RFControl(ports, addrwidth)

        val ipuSelGet       = Output(Vec(ports, Bool()))
        val aluFuncSel      = Output(Vec(aluFuncs.length, Bool()))

        val intrnlRFControl = new RFControl(ports, addrwidth)
        val intrnlRFDataSel = Output(Bool())

        val nluFuncSel      = Output(Vec(nluFuncs.length, Bool()))

    })

    // Drive all six fields of one RFControl from decode keys "<rfName> <field>".
    private def driveRF(ctrl: RFControl, rfName: String): Unit = {
        ctrl.wEnable  := decode(io.state, s"$rfName wEnable")
        ctrl.rEnable  := decode(io.state, s"$rfName rEnable")
        ctrl.wAddr    := decode(io.state, s"$rfName wAddr")
        ctrl.rAddrInt := decode(io.state, s"$rfName rAddrInt")
        ctrl.rAddrExt := decode(io.state, s"$rfName rAddrExt")
        ctrl.bpSelGet := decode(io.state, s"$rfName bpSelGet")
    }

    driveRF(io.weightRFControl, "weightRF")
    driveRF(io.actvtnRFControl, "actvtnRF")

    io.ipuSelGet  := decode(io.state, "ipuSelGet")
    io.aluFuncSel := decode(io.state, "aluFuncSel")

    driveRF(io.intrnlRFControl, "intrnlRF")
    io.intrnlRFDataSel := decode(io.state, "intrnlRF dataSel")

    io.nluFuncSel := decode(io.state, "nluFuncSel")
}

defined [32mclass[39m [36mRFControl[39m
defined [32mclass[39m [36mDecoder[39m

#### Example

In [363]:
// Example `decode` function for Decoder: returns the value the control output
// named `output` should take in `state`.
// It is hard-wired for 2 ports, 8-bit addresses, 4 ALU functions and 2 NLU
// functions. NOTE(review): the Decoder test uses addrwidth=4, so these 8-bit
// address wires are presumably narrowed on connection — confirm intended.
// State 0 drives one pattern (enables and bypass selects high, addresses 1/2/3,
// first ALU/NLU function). Every other state drives the alternate pattern
// (enables low, addresses 4/5/6, second ALU/NLU function).
// An unrecognised `output` key has no matching case and throws scala.MatchError
// during elaboration.
def decode(state: UInt, output: String): Data = {
    
    // Set Types: allocate a wire whose type matches the requested output.
    val data = output match {
        case "weightRF wEnable"  => Wire(Vec(2, Bool()))
        case "weightRF rEnable"  => Wire(Vec(2, Bool()))
        case "weightRF wAddr"    => Wire(Vec(2, UInt(8.W)))
        case "weightRF rAddrInt" => Wire(Vec(2, UInt(8.W)))
        case "weightRF rAddrExt" => Wire(Vec(2, UInt(8.W)))
        case "weightRF bpSelGet" => Wire(Vec(2, Bool()))
        
        case "actvtnRF wEnable"  => Wire(Vec(2, Bool()))
        case "actvtnRF rEnable"  => Wire(Vec(2, Bool()))
        case "actvtnRF wAddr"    => Wire(Vec(2, UInt(8.W)))
        case "actvtnRF rAddrInt" => Wire(Vec(2, UInt(8.W)))
        case "actvtnRF rAddrExt" => Wire(Vec(2, UInt(8.W)))
        case "actvtnRF bpSelGet" => Wire(Vec(2, Bool()))
        
        case "ipuSelGet"         => Wire(Vec(2, Bool()))
        case "aluFuncSel"        => Wire(Vec(4, Bool()))
         
        case "intrnlRF wEnable"  => Wire(Vec(2, Bool()))
        case "intrnlRF rEnable"  => Wire(Vec(2, Bool()))
        case "intrnlRF wAddr"    => Wire(Vec(2, UInt(8.W)))
        case "intrnlRF rAddrInt" => Wire(Vec(2, UInt(8.W)))
        case "intrnlRF rAddrExt" => Wire(Vec(2, UInt(8.W)))
        case "intrnlRF bpSelGet" => Wire(Vec(2, Bool()))
        case "intrnlRF dataSel"  => Wire(Bool())
        
        case "nluFuncSel"        => Wire(Vec(2, Bool()))
                                    
    }
    
    // Set Values: drive the wire with the state-0 pattern or the alternate pattern.
    // Both branches must cover every key handled above.
    when(state === 0.U) {
        data := { output match {
            case "weightRF wEnable"  => Vec.fill(2){true.B}
            case "weightRF rEnable"  => Vec.fill(2){true.B}
            case "weightRF wAddr"    => Vec.fill(2){1.U}
            case "weightRF rAddrInt" => Vec.fill(2){2.U}
            case "weightRF rAddrExt" => Vec.fill(2){3.U}
            case "weightRF bpSelGet" => Vec.fill(2){true.B}
            
            case "actvtnRF wEnable"  => Vec.fill(2){true.B}
            case "actvtnRF rEnable"  => Vec.fill(2){true.B}
            case "actvtnRF wAddr"    => Vec.fill(2){1.U}
            case "actvtnRF rAddrInt" => Vec.fill(2){2.U}
            case "actvtnRF rAddrExt" => Vec.fill(2){3.U}
            case "actvtnRF bpSelGet" => Vec.fill(2){true.B}
            
            // Select vectors: the first entry is active in state 0.
            case "ipuSelGet"         => Vec(1.U :: 0.U :: Nil)
            case "aluFuncSel"        => Vec(1.U :: 0.U :: 0.U :: 0.U :: Nil)
            
            case "intrnlRF wEnable"  => Vec.fill(2){true.B}
            case "intrnlRF rEnable"  => Vec.fill(2){true.B}
            case "intrnlRF wAddr"    => Vec.fill(2){1.U}
            case "intrnlRF rAddrInt" => Vec.fill(2){2.U}
            case "intrnlRF rAddrExt" => Vec.fill(2){3.U}
            case "intrnlRF bpSelGet" => Vec.fill(2){true.B}
            case "intrnlRF dataSel"  => true.B
            
            case "nluFuncSel"        => Vec(1.U :: 0.U :: Nil)
        }}
    } 

    .otherwise {
        data := { output match {
            case "weightRF wEnable"  => Vec.fill(2){false.B}
            case "weightRF rEnable"  => Vec.fill(2){false.B}
            case "weightRF wAddr"    => Vec.fill(2){4.U}
            case "weightRF rAddrInt" => Vec.fill(2){5.U}
            case "weightRF rAddrExt" => Vec.fill(2){6.U}
            case "weightRF bpSelGet" => Vec.fill(2){false.B}
            
            case "actvtnRF wEnable"  => Vec.fill(2){false.B}
            case "actvtnRF rEnable"  => Vec.fill(2){false.B}
            case "actvtnRF wAddr"    => Vec.fill(2){4.U}
            case "actvtnRF rAddrInt" => Vec.fill(2){5.U}
            case "actvtnRF rAddrExt" => Vec.fill(2){6.U}
            case "actvtnRF bpSelGet" => Vec.fill(2){false.B}
            
            // Select vectors: the second entry is active outside state 0.
            case "ipuSelGet"         => Vec(0.U :: 1.U :: Nil)
            case "aluFuncSel"        => Vec(0.U :: 1.U :: 0.U :: 0.U :: Nil)
            
            case "intrnlRF wEnable"  => Vec.fill(2){false.B}
            case "intrnlRF rEnable"  => Vec.fill(2){false.B}
            case "intrnlRF wAddr"    => Vec.fill(2){4.U}
            case "intrnlRF rAddrInt" => Vec.fill(2){5.U}
            case "intrnlRF rAddrExt" => Vec.fill(2){6.U}
            case "intrnlRF bpSelGet" => Vec.fill(2){false.B}
            case "intrnlRF dataSel"  => false.B
            
            case "nluFuncSel"        => Vec(0.U :: 1.U :: Nil)
        }}
    }
    
    data
}

defined [32mfunction[39m [36mdecode[39m

#### Verification

In [369]:
// Checks the Decoder wired to the example `decode` function: state 0 must
// produce decode's state-0 pattern and state 1 its alternate pattern. Only
// port 0 of each register-file control vector is checked.
// NOTE(review): the Decoder contains no registers in its definition, so the
// step(1) calls are presumably unnecessary for outputs to settle — confirm.
Driver(() => new Decoder(decode, statewidth=4, ports=2,
                        datawidth=4, addrwidth=4,
                        aluFuncs=List("Identity", "Add", "Max", "Accumulate"),
                        nluFuncs=List("Identity", "ReLu"))) {
    
    uut => new PeekPokeTester(uut) {
        
        
        // State 0: enables/bypass high, addresses 1/2/3, first function selected.
        poke(uut.io.state, 0.U)
        step(1)
        
        expect(uut.io.weightRFControl.wEnable(0), true.B)
        expect(uut.io.weightRFControl.rEnable(0), true.B)
        expect(uut.io.weightRFControl.wAddr(0), 1.U)
        expect(uut.io.weightRFControl.rAddrInt(0), 2.U)
        expect(uut.io.weightRFControl.rAddrExt(0), 3.U)
        expect(uut.io.weightRFControl.bpSelGet(0), true.B)
        
        expect(uut.io.actvtnRFControl.wEnable(0), true.B)
        expect(uut.io.actvtnRFControl.rEnable(0), true.B)
        expect(uut.io.actvtnRFControl.wAddr(0), 1.U)
        expect(uut.io.actvtnRFControl.rAddrInt(0), 2.U)
        expect(uut.io.actvtnRFControl.rAddrExt(0), 3.U)
        expect(uut.io.actvtnRFControl.bpSelGet(0), true.B)
        
        expect(uut.io.ipuSelGet(0), 1)
        expect(uut.io.ipuSelGet(1), 0)
        
        expect(uut.io.aluFuncSel(0), 1)
        expect(uut.io.aluFuncSel(1), 0)
        expect(uut.io.aluFuncSel(2), 0)
        expect(uut.io.aluFuncSel(3), 0)
        
        expect(uut.io.intrnlRFControl.wEnable(0), true.B)
        expect(uut.io.intrnlRFControl.rEnable(0), true.B)
        expect(uut.io.intrnlRFControl.wAddr(0), 1.U)
        expect(uut.io.intrnlRFControl.rAddrInt(0), 2.U)
        expect(uut.io.intrnlRFControl.rAddrExt(0), 3.U)
        expect(uut.io.intrnlRFControl.bpSelGet(0), true.B)
        
        expect(uut.io.nluFuncSel(0), 1)
        expect(uut.io.nluFuncSel(1), 0)
        
        // State 1 (any non-zero state): enables/bypass low, addresses 4/5/6,
        // second function selected.
        poke(uut.io.state, 1.U) 
        step(1)
        
        expect(uut.io.weightRFControl.wEnable(0), false.B)
        expect(uut.io.weightRFControl.rEnable(0), false.B)
        expect(uut.io.weightRFControl.wAddr(0), 4.U)
        expect(uut.io.weightRFControl.rAddrInt(0), 5.U)
        expect(uut.io.weightRFControl.rAddrExt(0), 6.U)
        expect(uut.io.weightRFControl.bpSelGet(0), false.B)
        
        expect(uut.io.actvtnRFControl.wEnable(0), false.B)
        expect(uut.io.actvtnRFControl.rEnable(0), false.B)
        expect(uut.io.actvtnRFControl.wAddr(0), 4.U)
        expect(uut.io.actvtnRFControl.rAddrInt(0), 5.U)
        expect(uut.io.actvtnRFControl.rAddrExt(0), 6.U)
        expect(uut.io.actvtnRFControl.bpSelGet(0), false.B)
        
        expect(uut.io.ipuSelGet(0), 0)
        expect(uut.io.ipuSelGet(1), 1)
        
        expect(uut.io.aluFuncSel(0), 0)
        expect(uut.io.aluFuncSel(1), 1)
        expect(uut.io.aluFuncSel(2), 0)
        expect(uut.io.aluFuncSel(3), 0)
        
        expect(uut.io.intrnlRFControl.wEnable(0), false.B)
        expect(uut.io.intrnlRFControl.rEnable(0), false.B)
        expect(uut.io.intrnlRFControl.wAddr(0), 4.U)
        expect(uut.io.intrnlRFControl.rAddrInt(0), 5.U)
        expect(uut.io.intrnlRFControl.rAddrExt(0), 6.U)
        expect(uut.io.intrnlRFControl.bpSelGet(0), false.B)
        
        expect(uut.io.nluFuncSel(0), 0)
        expect(uut.io.nluFuncSel(1), 1)
        
    }
}


[[35minfo[0m] [0.000] Elaborating design...
[[35minfo[0m] [0.006] Done elaborating.
Total FIRRTL Compile Time: 45.9 ms
Total FIRRTL Compile Time: 30.4 ms
End of dependency graph
Circuit state created
[[35minfo[0m] [0.000] SEED 1532565630432
test cmd361WrapperHelperDecoder Success: 52 tests passed in 7 cycles taking 0.023516 seconds
[[35minfo[0m] [0.019] RAN 2 CYCLES PASSED


[36mres368[39m: [32mBoolean[39m = [32mtrue[39m

## PE

## Future Plans
* Verify every module against golden (reference) models