# Unified Processing Engine
#### Verification | Version 0.6.1 | Updated 2018.7.30
___

## Setup

In [1]:
// Build the location of the dependency-loading script relative to the JVM's
// current working directory ("user.dir").
val path = System.getProperty("user.dir") + "/source/load-ivy.sc"
// Load that script into the running Ammonite interpreter session as a module.
// NOTE(review): presumably this is what makes the Chisel imports below resolvable — confirm.
interp.load.module(ammonite.ops.Path(java.nio.file.FileSystems.getDefault().getPath(path)))

[36mpath[39m: [32mString[39m = [32m"""
C:\Users\RyanL\OneDrive\Research\SEAL\processing-engine/source/load-ivy.sc
"""[39m

In [2]:
import chisel3._
import chisel3.util._
import chisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester}

import scala.math.pow

[32mimport [39m[36mchisel3._
[39m
[32mimport [39m[36mchisel3.util._
[39m
[32mimport [39m[36mchisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester}

[39m
[32mimport [39m[36mscala.math.pow[39m

## Register File

##### Definition

In [38]:
/** Parameters describing one side (internal or external bus) of the register file.
  *
  * @param numInputs      number of write ports on this side
  * @param numOutputs     number of read ports on this side
  * @param numCrossInputs length of each read port's bypass-select vector
  * @param addrWidth      bit width of the read/write addresses
  * @param bpSoft         true when the "Soft" bypass mode is configured
  * @param bpFirm         true when the "Firm" bypass mode is configured
  */
class PartialRFConfig(
    val numInputs: Int,
    val numOutputs: Int,
    val numCrossInputs: Int,
    val addrWidth: Int,
    val bpSoft: Boolean,
    val bpFirm: Boolean
)

/** Control signals for one side (internal or external) of the register file.
  *
  * Address fields exist only when Firm bypassing is off; the bypass-select
  * field exists only when Soft or Firm bypassing is on.
  */
class PartialRFControl(c: PartialRFConfig) extends Bundle {
    // One write-enable per write port and one read-enable per read port.
    val wEnable = Vec(c.numInputs, Bool())
    val rEnable = Vec(c.numOutputs, Bool())
    // Per-port addresses; omitted entirely in Firm mode.
    val wAddr = if (!c.bpFirm) Some(Vec(c.numInputs, UInt(c.addrWidth.W))) else None
    val rAddr = if (!c.bpFirm) Some(Vec(c.numOutputs, UInt(c.addrWidth.W))) else None
    // Each output can select which input of the opposite bus to bypass from
    val bpSel = if (c.bpSoft || c.bpFirm) Some(Vec(c.numOutputs, Vec(c.numCrossInputs, Bool()))) else None
}

/** Top-level register file configuration.
  *
  * @param numIntInputs  number of internal write ports
  * @param numExtInputs  number of external write ports
  * @param numIntOutputs number of internal read ports
  * @param numExtOutputs number of external read ports
  * @param addrWidth     address bit width (must be 0 for Firm bypassing)
  * @param dataWidth     data bit width (at least 1)
  * @param bpType        bypass mode: "None", "Soft" or "Firm"
  * @throws IllegalArgumentException if any of the constraints above is violated
  */
class RFConfig(
        val numIntInputs: Int,
        val numExtInputs: Int,
        val numIntOutputs: Int,
        val numExtOutputs: Int,
        val addrWidth: Int,
        val dataWidth: Int,
        val bpType: String) {

    val bpNone = (bpType == "None")
    val bpSoft = (bpType == "Soft")
    val bpFirm = (bpType == "Firm")

    require(bpNone || bpSoft || bpFirm, "Invalid Bypass type.\n")
    require(numIntInputs > 0 || numExtInputs > 0, "Must have at least one input.\n")
    require(numIntOutputs > 0 || numExtOutputs > 0, "Must have at least one output.\n")
    require(dataWidth > 0, "Data bitwidth must be at least one.\n")
    require(addrWidth >= 0, "Address width must be non-negative.\n")
    if (bpFirm) { require(addrWidth == 0, "Address width must be 0 when Firm Bypassing.\n") }

    // numCrossInputs sizes each read port's bypass-select vector. A read port
    // bypasses from the write ports of the *opposite* bus, so the count passed
    // here must be the opposite bus's input count. Passing the output count
    // instead leaves the select vector mis-sized whenever the two differ.
    val intConfig = new PartialRFConfig(
        numIntInputs, numIntOutputs, numExtInputs, addrWidth, bpSoft, bpFirm)

    val extConfig = new PartialRFConfig(
        numExtInputs, numExtOutputs, numIntInputs, addrWidth, bpSoft, bpFirm)
}

/** Complete control bundle for the register file: one optional
  * PartialRFControl per bus, present only if that bus has any port at all.
  */
class RFControl(c: RFConfig) extends Bundle {
    
    // Rebuild from the same config when Chisel needs a fresh copy of this bundle.
    override def cloneType = (new RFControl(c)).asInstanceOf[this.type]
    
    // Internal-bus controls; None when there are no internal inputs or outputs.
    val internal = if (c.numIntInputs > 0 || c.numIntOutputs > 0)
        Some(new PartialRFControl(c.intConfig)) else None
    // External-bus controls; None when there are no external inputs or outputs.
    val external = if (c.numExtInputs > 0 || c.numExtOutputs > 0)
        Some(new PartialRFControl(c.extConfig)) else None
}

/** Register file with separate internal and external read/write port groups
  * and an optional cross-bus bypass.
  *
  * Bypass modes (from RFConfig.bpType):
  *  - None: reads always come from the addressed storage register.
  *  - Soft: a read comes from the opposite bus's bypass register when any of
  *          its bpSel bits is set, otherwise from the addressed storage register.
  *  - Firm: no addressed storage exists; reads always come from the opposite
  *          bus's bypass registers.
  * A read port whose rEnable is low drives 0.
  */
class RF(c: RFConfig) extends Module {
    
    val io = IO(new Bundle {
        val control = Input(new RFControl(c))
        val wInternal = Input(Vec(c.numIntInputs, SInt(c.dataWidth.W))) 
        val wExternal = Input(Vec(c.numExtInputs, SInt(c.dataWidth.W)))
        val rInternal = Output(Vec(c.numIntOutputs, SInt(c.dataWidth.W)))
        val rExternal = Output(Vec(c.numExtOutputs, SInt(c.dataWidth.W)))
    })
    
    // Addressed storage: 2^addrWidth entries, each initialised to 0. Not built in Firm mode.
    val dataRegister = if (!c.bpFirm) 
        Some(RegInit(Vec.fill(pow(2, c.addrWidth).toInt){0.S(c.dataWidth.W)})) else None
    
    // Need to bypass through a register to prevent combinational loops
    val bpAny = c.bpSoft || c.bpFirm
    // One bypass register per write port, per bus; only built when bypassing is enabled
    // and that bus actually has write ports.
    val bpRegisterInt = if (bpAny && c.numIntInputs > 0)
        Some(RegInit(Vec.fill(c.numIntInputs){0.S(c.dataWidth.W)})) else None
    val bpRegisterExt = if (bpAny && c.numExtInputs > 0)
        Some(RegInit(Vec.fill(c.numExtInputs){0.S(c.dataWidth.W)})) else None
    
    // Internal writes: on wEnable, store to the addressed entry (if storage exists)
    // and capture the value in this port's bypass register (if it exists).
    for (i <- 0 until c.numIntInputs) {
        when (io.control.internal.get.wEnable(i)) {
            if (!c.bpFirm) { dataRegister.get(io.control.internal.get.wAddr.get(i)) := io.wInternal(i) }
            if (bpRegisterInt.isDefined) { bpRegisterInt.get(i) := io.wInternal(i) }
        }
    }
    
    // External writes: same scheme as the internal writes above.
    // Declared after the internal loop, so on a same-address collision these
    // assignments are the later ones.
    for (i <- 0 until c.numExtInputs) {
        when (io.control.external.get.wEnable(i)) {
            if (!c.bpFirm) { dataRegister.get(io.control.external.get.wAddr.get(i)) := io.wExternal(i) }
            if (bpRegisterExt.isDefined) { bpRegisterExt.get(i) := io.wExternal(i) }
        }
    }
    
    // Internal reads.
    // NOTE(review): bpRegisterExt is None when numExtInputs == 0, so the `.get`
    // calls below would fail at elaboration for a Soft/Firm config with internal
    // outputs but no external inputs — confirm such configs are not intended.
    for (i <- 0 until c.numIntOutputs) {
        when (io.control.internal.get.rEnable(i)) {
            if (c.bpFirm) {
                io.rInternal(i) := PriorityMux(io.control.internal.get.bpSel.get(i), bpRegisterExt.get)
            } else if (c.bpSoft) {
                when (io.control.internal.get.bpSel.get(i).contains(true.B)) {
                    // External write bypasses to Internal read
                    io.rInternal(i) := PriorityMux(io.control.internal.get.bpSel.get(i), bpRegisterExt.get)
                } .otherwise {
                    io.rInternal(i) := dataRegister.get(io.control.internal.get.rAddr.get(i))
                }
            } else {
                io.rInternal(i) := dataRegister.get(io.control.internal.get.rAddr.get(i))
            }
        } .otherwise {
            // Disabled read ports drive zero.
            io.rInternal(i) := 0.S
        }
    }
    
    // External reads: mirror image of the internal reads, bypassing from the
    // internal bus's bypass registers.
    for (i <- 0 until c.numExtOutputs) {
        when (io.control.external.get.rEnable(i)) {
            if (c.bpFirm) {
                io.rExternal(i) := PriorityMux(io.control.external.get.bpSel.get(i), bpRegisterInt.get)
            } else if (c.bpSoft) {
                when (io.control.external.get.bpSel.get(i).contains(true.B)) {
                    // Internal write bypasses to External read
                    io.rExternal(i) := PriorityMux(io.control.external.get.bpSel.get(i), bpRegisterInt.get)
                } .otherwise {
                    io.rExternal(i) := dataRegister.get(io.control.external.get.rAddr.get(i))
                }
            } else {
                io.rExternal(i) := dataRegister.get(io.control.external.get.rAddr.get(i))
            }
        } .otherwise {
            // Disabled read ports drive zero.
            io.rExternal(i) := 0.S
        }
    }
}

defined [32mclass[39m [36mPartialRFConfig[39m
defined [32mclass[39m [36mPartialRFControl[39m
defined [32mclass[39m [36mRFConfig[39m
defined [32mclass[39m [36mRFControl[39m
defined [32mclass[39m [36mRF[39m

##### Verification

In [41]:
/*
Basic Test Checklist:
[-] Optional Hardware
    [-] No Internal Read Port
    [-] No External Read Port
    [-] No Internal Write Port
    [-] No External Write Port

[-] No Bypass
    [-] Standard Read/Write
    [-] Port Independence 
    [-] Read Enable
    [-] Write Enable

[-] Soft Bypass
    [-] Standard Read/Write
    [-] Port Independence
    [-] Read Enable
    [-] Write Enable 
    [-] Bypass Enable/Select
    
[-] Hard Bypass (bpType "Firm")
    [-] Bypass Enable/Select

Better would be to check these together.
Even better would be to use Golden Model...
*/

// TODO: Do this.
// RFConfig arguments, in order: numIntInputs, numExtInputs, numIntOutputs,
// numExtOutputs, addrWidth, dataWidth, bpType.
// Configurations with one port group removed (count set to 0).
val exRFConfigNoIntWrite = new RFConfig(0, 2, 2, 2, 4, 8, "None")
val exRFConfigNoExtWrite = new RFConfig(2, 0, 2, 2, 4, 8, "None")
val exRFConfigNoIntRead = new RFConfig(2, 2, 0, 2, 4, 8, "None")
val exRFConfigNoExtRead = new RFConfig(2, 2, 2, 0, 4, 8, "None")

// Fully populated configurations, one per bypass mode.
// The "Hard" configuration uses bpType "Firm", which requires addrWidth 0.
val exRFConfigNoBypass = new RFConfig(2, 2, 2, 2, 4, 8, "None")
val exRFConfigSoftBypass = new RFConfig(2, 2, 2, 2, 4, 8, "Soft")
val exRFConfigHardBypass = new RFConfig(2, 2, 2, 2, 0, 8, "Firm")

// Each test below currently only elaborates the RF for its configuration;
// the tester bodies are placeholders with no pokes or expects yet.
val noIntWriteTest = Driver(() => new RF(exRFConfigNoIntWrite)) {
    uut => new PeekPokeTester(uut) {
        // TODO
    }
}

val noExtWriteTest = Driver(() => new RF(exRFConfigNoExtWrite)) {
    uut => new PeekPokeTester(uut) {
        // TODO
    }
}

val noIntReadTest = Driver(() => new RF(exRFConfigNoIntRead)) {
    uut => new PeekPokeTester(uut) {
        // TODO
    }
}

val noExtReadTest = Driver(() => new RF(exRFConfigNoExtRead)) {
    uut => new PeekPokeTester(uut) {
        // TODO
    }
}

val noBypassTest = Driver(() => new RF(exRFConfigNoBypass)) {
    uut => new PeekPokeTester(uut) {
        // TODO
    }
}

val softBypassTest = Driver(() => new RF(exRFConfigSoftBypass)) {
    uut => new PeekPokeTester(uut) {
        // TODO
    }
}

val hardBypassTest = Driver(() => new RF(exRFConfigHardBypass)) {
    uut => new PeekPokeTester(uut) {
        // TODO
    }
}
                          

[[35minfo[0m] [0.000] Elaborating design...
[[35minfo[0m] [0.008] Done elaborating.
Total FIRRTL Compile Time: 51.8 ms
Total FIRRTL Compile Time: 41.8 ms
End of dependency graph
Circuit state created
[[35minfo[0m] [0.000] SEED 1533006619289
test cmd37WrapperHelperRF Success: 0 tests passed in 5 cycles taking 0.003443 seconds
[[35minfo[0m] [0.000] RAN 0 CYCLES PASSED
[[35minfo[0m] [0.000] Elaborating design...
[[35minfo[0m] [0.004] Done elaborating.
Total FIRRTL Compile Time: 54.9 ms
Total FIRRTL Compile Time: 42.2 ms
End of dependency graph
Circuit state created
[[35minfo[0m] [0.000] SEED 1533006619413
test cmd37WrapperHelperRF Success: 0 tests passed in 5 cycles taking 0.002061 seconds
[[35minfo[0m] [0.000] RAN 0 CYCLES PASSED
[[35minfo[0m] [0.000] Elaborating design...
[[35minfo[0m] [0.003] Done elaborating.
Total FIRRTL Compile Time: 41.5 ms
Total FIRRTL Compile Time: 45.5 ms
End of dependency graph
Circuit state created
[[35minfo[0m] [0.000] SEED 1533006619536

[36mexRFConfigNoIntWrite[39m: [32mRFConfig[39m = $sess.cmd37Wrapper$Helper$RFConfig@7b591557
[36mexRFConfigNoExtWrite[39m: [32mRFConfig[39m = $sess.cmd37Wrapper$Helper$RFConfig@421b2be0
[36mexRFConfigNoIntRead[39m: [32mRFConfig[39m = $sess.cmd37Wrapper$Helper$RFConfig@797afcca
[36mexRFConfigNoExtRead[39m: [32mRFConfig[39m = $sess.cmd37Wrapper$Helper$RFConfig@21dbc84f
[36mexRFConfigNoBypass[39m: [32mRFConfig[39m = $sess.cmd37Wrapper$Helper$RFConfig@121b928e
[36mexRFConfigSoftBypass[39m: [32mRFConfig[39m = $sess.cmd37Wrapper$Helper$RFConfig@944ff06
[36mexRFConfigHardBypass[39m: [32mRFConfig[39m = $sess.cmd37Wrapper$Helper$RFConfig@2776ce48
[36mnoIntWriteTest[39m: [32mBoolean[39m = [32mtrue[39m
[36mnoExtWriteTest[39m: [32mBoolean[39m = [32mtrue[39m
[36mnoIntReadTest[39m: [32mBoolean[39m = [32mtrue[39m
[36mnoExtReadTest[39m: [32mBoolean[39m = [32mtrue[39m
[36mnoBypassTest[39m: [32mBoolean[39m = [32mtrue[39m
[36msoftBypa

## Inner Product Unit

### Parallel Multiplier

##### Definition

In [None]:
/** Parameters for the parallel multiplier.
  *
  * @param numPairs number of (weight, activation) pairs multiplied side by side; at least 1
  * @param bitWidth bit width of every operand and product port; at least 1
  * @throws IllegalArgumentException if either parameter is below 1
  */
class PMultConfig(val numPairs: Int, val bitWidth: Int) {
    require(numPairs > 0, "Must have at least one pair of multiplicands.")
    require(bitWidth > 0, "Bitwidth must be at least one.")
}

/** Operand bundle for the parallel multiplier: two equally sized vectors of
  * signed values, one of weights and one of activations.
  */
class PMultInput(c: PMultConfig) extends Bundle {
    
    // Rebuild from the same config when Chisel needs a fresh copy of this bundle.
    override def cloneType = (new PMultInput(c)).asInstanceOf[this.type]
    
    val weight = Vec(c.numPairs, SInt(c.bitWidth.W))
    val actvtn = Vec(c.numPairs, SInt(c.bitWidth.W))
}

/** Parallel multiplier: multiplies weight(i) by actvtn(i) for every pair,
  * with no registers between input and output.
  */
class PMult(c: PMultConfig) extends Module {
    
    val io = IO(new Bundle {
        val in = Input(new PMultInput(c))
        val prod = Output(Vec(c.numPairs, SInt(c.bitWidth.W)))
    })
    
    // Element-wise product of the two operand vectors.
    // NOTE(review): the output elements are only bitWidth wide, the same as each
    // operand, so large products presumably lose their upper bits — confirm intended.
    io.prod := (io.in.weight zip io.in.actvtn).map { case(a, b) => a * b }
}

##### Verification

In [None]:
val examplePMultCon = new PMultConfig(numPairs = 4, bitWidth = 8)

Driver(() => new PMult(examplePMultCon)) { mult =>
    new PeekPokeTester(mult) {
        // (weight, activation) stimulus for each of the four multiplier lanes.
        val operands = Seq((1, 2), (3, 4), (5, 6), (7, 8))

        // Drive every lane first ...
        for (((w, a), lane) <- operands.zipWithIndex) {
            poke(mult.io.in.weight(lane), w)
            poke(mult.io.in.actvtn(lane), a)
        }

        // ... then check each product without stepping the clock.
        for (((w, a), lane) <- operands.zipWithIndex) {
            expect(mult.io.prod(lane), w * a)
        }
    }
}

### Additive Reduction Tree

#### Definition

In [None]:
/** Parameters for the additive reduction tree.
  *
  * @param numAddends number of values summed by the tree; at least 1
  * @param bitWidth   bit width of every addend and of the sum; at least 1
  * @throws IllegalArgumentException if either parameter is below 1
  */
class ARTreeConfig(val numAddends: Int, val bitWidth: Int) {
    require(numAddends > 0, "Number of addends must be at least one.")
    require(bitWidth > 0, "Bitwidth must be at least one.")
}

/** Reduces `xs` with `op` by repeatedly combining adjacent pairs, which
  * yields a balanced expression tree rather than a left-leaning chain.
  *
  * On each pass neighbouring elements are combined two at a time; when the
  * level has an odd number of elements the last one is carried up unchanged.
  *
  * @param xs non-empty list of operands
  * @param op binary combining operation
  * @return the single value left after all passes
  * @throws IllegalArgumentException if `xs` is empty
  */
def adjReduce[A](xs: List[A], op: (A, A) => A): A = {
    // Local so the compiler can verify the recursion is in tail position.
    @scala.annotation.tailrec
    def reduceLevel(level: List[A]): A = level match {
        case Nil => throw new IllegalArgumentException("adjReduce requires a non-empty list")
        case single :: Nil => single
        case _ =>
            val nextLevel = level.grouped(2).map {
                case List(a, b) => op(a, b)
                case List(x) => x // odd element out: carried to the next level
            }.toList
            reduceLevel(nextLevel)
    }

    reduceLevel(xs)
}

/** Additive reduction tree: sums all inputs using a balanced tree of adders
  * built by adjReduce, with no registers between input and output.
  */
class ARTree(c: ARTreeConfig) extends Module {
    
    val io = IO(new Bundle {
        val in  = Input(Vec(c.numAddends, SInt(c.bitWidth.W)))
        val sum = Output(SInt(c.bitWidth.W))
    })
    
    // Method-call syntax instead of the postfix form `io.in toList`, which
    // needs scala.language.postfixOps to compile without a feature warning.
    // NOTE(review): the sum port is as wide as one addend, so overflow is
    // presumably expected to wrap — confirm.
    io.sum := adjReduce(io.in.toList, (x: SInt, y: SInt) => x + y)
}

#### Verilog

In [None]:
// Example tree: 4 addends, 8 bits each.
val exampleARTreeCon = new ARTreeConfig(4, 8)
// Print the Verilog generated for that tree.
// NOTE(review): getVerilog is not defined in this notebook; presumably it comes from load-ivy.sc — confirm.
println(getVerilog(new ARTree(exampleARTreeCon)))

#### Verification

In [None]:
// Checks the 4-input, 8-bit example tree without stepping the clock.
// The second stimulus previously repeated the first one verbatim and added
// no coverage; it now exercises negative addends.
Driver(() => new ARTree(exampleARTreeCon)) {
    uut => new PeekPokeTester(uut) {
        // All-positive addends: 1 + 2 + 8 + 9 = 20.
        poke(uut.io.in(0), 1)
        poke(uut.io.in(1), 2)
        poke(uut.io.in(2), 8)
        poke(uut.io.in(3), 9)
        expect(uut.io.sum, 20)

        // Mixed-sign addends: -1 + 2 - 8 + 9 = 2.
        poke(uut.io.in(0), -1)
        poke(uut.io.in(1), 2)
        poke(uut.io.in(2), -8)
        poke(uut.io.in(3), 9)
        expect(uut.io.sum, 2)
    }
}

### Putting Them Together

#### Definition

In [None]:
/** Configuration of the inner product unit (parallel multiplier + adder tree).
  *
  * @param width    number of (weight, activation) pairs; at least 1
  * @param bitWidth data bit width; at least 1
  * @param bpType   bypass mode, "None" or "Firm"
  * @throws IllegalArgumentException if any constraint above is violated
  */
class IPUConfig(val width: Int, val bitWidth: Int, val bpType: String) {
    
    private val bypssError = "Bypass must be \"None\" or \"Firm\""
    private val widthError = "Width must be at least one"
    private val bitWdError = "Data bitwidth must be at least one"
    
    val supportedBp = List("None", "Firm")
    
    require(width >= 1, widthError)
    require(supportedBp.contains(bpType), bypssError)
    // Both child configs below reject a bit width under 1, so enforce the same
    // bound here; the old `>= 0` check let 0 through only to fail in the child
    // constructor with an unrelated message.
    require(bitWidth >= 1, bitWdError)
    
    val childPMultConfig = new PMultConfig(width, bitWidth)
    val childARTreeConfig = new ARTreeConfig(width, bitWidth)
    
    val bpFirm = (bpType == "Firm")
}

/** Output bundle of the inner product unit: the inner product itself and,
  * in Firm bypass mode only, one bypassed weight and one bypassed activation.
  */
class IPUOutput(c: IPUConfig) extends Bundle {
    
    // Rebuild from the same config when Chisel needs a fresh copy of this bundle.
    override def cloneType = (new IPUOutput(c)).asInstanceOf[this.type]
    
    // Disallowed. TODO: Read up on co/contravariance.
    // def this(ca: ALUConfig) = this(new IPUConfig(42, ca.dataWidth, ca.ipuBypassTp))
    
    val innerProd = Output(SInt(c.bitWidth.W))
    // Bypass fields exist only when the IPU is configured for Firm bypassing.
    val bpWeight = if (c.bpFirm) Some(SInt(c.bitWidth.W)) else None
    val bpActvtn = if (c.bpFirm) Some(SInt(c.bitWidth.W)) else None
}


/** Inner product unit: feeds the operand pairs through a PMult, sums the
  * products with an ARTree, and (Firm bypass only) forwards one selected
  * weight and one selected activation unchanged.
  */
class IPU(c: IPUConfig) extends Module {
    
    val cPMConfig = c.childPMultConfig
    val cARTConfig = c.childARTreeConfig
    
    val io = IO(new Bundle {
        val dataIn = Input(new PMultInput(cPMConfig))
        val dataOut = Output(new IPUOutput(c))
        // One select bit per operand pair; only present in Firm bypass mode.
        val bpSel = if (c.bpFirm) Some(Input(Vec(c.width, Bool()))) else None
    })
    
    // Stage 1: element-wise products.
    val pMult = Module(new PMult(cPMConfig))
    pMult.io.in <> io.dataIn
    
    // Stage 2: sum of the products.
    val aRTree = Module(new ARTree(cARTConfig))
    aRTree.io.in := pMult.io.prod
    
    io.dataOut.innerProd := aRTree.io.sum
    
    // The same select vector picks both the bypassed weight and the bypassed activation.
    if (c.bpFirm) {
        io.dataOut.bpWeight.get := PriorityMux(io.bpSel.get, io.dataIn.weight)
        io.dataOut.bpActvtn.get := PriorityMux(io.bpSel.get, io.dataIn.actvtn)
    }
}

#### Verification

In [None]:
val exampleIPUCon = new IPUConfig(width = 4, bitWidth = 8, bpType = "Firm")

Driver(() => new IPU(exampleIPUCon)) { ipu =>
    new PeekPokeTester(ipu) {
        val lanes = 0 until 4
        val weights = Seq(1, 2, 3, 4)
        val actvtns = Seq(5, 6, 7, 8)

        // Drives bpSel as a one-hot vector with bit `hot` set; a value outside
        // the lane range (e.g. -1) clears every bit.
        def select(hot: Int): Unit =
            lanes.foreach { i => poke(ipu.io.bpSel.get(i), if (i == hot) 1 else 0) }

        // Inner product with no bypass lane selected: 1*5 + 2*6 + 3*7 + 4*8 = 70.
        select(-1)
        lanes.foreach { i => poke(ipu.io.dataIn.weight(i), weights(i)) }
        lanes.foreach { i => poke(ipu.io.dataIn.actvtn(i), actvtns(i)) }
        expect(ipu.io.dataOut.innerProd, 70)

        // Bypass lane 1.
        select(1)
        expect(ipu.io.dataOut.bpWeight.get, 2)
        expect(ipu.io.dataOut.bpActvtn.get, 6)

        // Bypass lane 2.
        select(2)
        expect(ipu.io.dataOut.bpWeight.get, 3)
        expect(ipu.io.dataOut.bpActvtn.get, 7)
    }
}

## ALU

#### Definition

In [None]:
/** Configuration of the ALU.
  *
  * @param dataWidth bit width of the ALU data path
  * @param funcs     functions to support; must include "Identity" and may
  *                  additionally include "Add", "Max" and "Accumulate"
  * @throws IllegalArgumentException if "Identity" is missing or an entry of
  *                                  `funcs` is not a supported function
  */
class ALUConfig(val dataWidth: Int, val funcs: List[String]) {
    val identityError = "ALU functions must explicitly include Identity."
    // Was "Unsupported Error", which did not describe the failure.
    val functionError = "Unsupported Function"
    val supportedFuncs = List("Identity", "Add", "Max", "Accumulate")
    
    require(funcs.contains("Identity"), identityError)
    // Name the offending entry so a bad configuration is easy to locate.
    funcs.foreach { f => require(supportedFuncs.contains(f), s"$functionError: $f") }
    
    val addSupp = funcs.contains("Add")
    val maxSupp = funcs.contains("Max")
    val accSupp = funcs.contains("Accumulate")
    // Add and Max both operate on the IPU's bypassed weight/activation pair,
    // so either one requires the IPU output to be in "Firm" bypass mode.
    val addBypassIn = addSupp || maxSupp
    val ipuBypassTp = if (addBypassIn) "Firm" else "None"
    val numFuncs = funcs.length
}

/** Input bundle of the ALU: function select bits, the IPU's output bundle,
  * and (only when "Accumulate" is supported) a feedback value from the register file.
  */
class ALUInput(c: ALUConfig) extends Bundle {
    
    // Rebuild from the same config when Chisel needs a fresh copy of this bundle.
    override def cloneType = (new ALUInput(c)).asInstanceOf[this.type]
    
    // One select bit per configured function.
    val funcSel = Vec(c.numFuncs, Bool())
    // NOTE(review): 42 is a placeholder IPU width; IPUOutput as defined in this
    // file only reads bitWidth and bpFirm from its config, so the value looks
    // irrelevant here — confirm.
    val ipu = new IPUOutput(new IPUConfig(42, c.dataWidth, c.ipuBypassTp))
    val rfFeedback = if(c.accSupp) Some(SInt(c.dataWidth.W)) else None
}

/** ALU: computes every configured function in parallel and selects one
  * result with the funcSel bits.
  *
  *  - Identity:   the IPU inner product
  *  - Add:        bypassed weight + bypassed activation
  *  - Max:        the larger of bypassed weight and bypassed activation
  *  - Accumulate: the IPU inner product + register-file feedback
  */
class ALU(c: ALUConfig) extends Module {
 
    val io = IO(new Bundle {
        val in = Input(new ALUInput(c))
        val out = Output(SInt(c.dataWidth.W))
    })
    
    // One intermediate wire per function; wires for unsupported functions are not built.
    val idnOut = Some(Wire(SInt(c.dataWidth.W)))
    val addOut = if(c.addSupp) Some(Wire(SInt(c.dataWidth.W))) else None
    val maxOut = if(c.maxSupp) Some(Wire(SInt(c.dataWidth.W))) else None
    val accOut = if(c.accSupp) Some(Wire(SInt(c.dataWidth.W))) else None
    
    idnOut.get := io.in.ipu.innerProd
    
    if (c.addSupp) { addOut.get := io.in.ipu.bpWeight.get + io.in.ipu.bpActvtn.get }
    if (c.accSupp) { accOut.get := io.in.ipu.innerProd + io.in.rfFeedback.get }
    if (c.maxSupp) {
        when (io.in.ipu.bpWeight.get > io.in.ipu.bpActvtn.get) {
            maxOut.get := io.in.ipu.bpWeight.get
        } .otherwise {
            maxOut.get := io.in.ipu.bpActvtn.get
        }
    }
    
    // Results are collected in the fixed order Identity, Add, Max, Accumulate
    // (skipping unsupported ones), independent of the order of c.funcs.
    // NOTE(review): funcSel(i) therefore selects the i-th *supported* function in
    // that fixed order — confirm decoders do not assume the order of c.funcs.
    val inters = (idnOut :: addOut :: maxOut :: accOut :: Nil) filter ( _.isDefined ) map ( _.get )
    io.out := PriorityMux(io.in.funcSel, inters)
}

#### Verification

In [None]:
val exampleALUFuncs = "Identity" :: "Add" :: "Max" :: "Accumulate" :: Nil
val exampleALUCon = new ALUConfig(dataWidth = 8, funcs = exampleALUFuncs)

Driver(() => new ALU(exampleALUCon)) { alu =>
    new PeekPokeTester(alu) {
        // Fixed operands shared by every case below.
        poke(alu.io.in.ipu.innerProd, 1)
        poke(alu.io.in.ipu.bpWeight.get, 2)
        poke(alu.io.in.ipu.bpActvtn.get, 3)
        poke(alu.io.in.rfFeedback.get, 4)

        // (one-hot funcSel index, expected output)
        val cases = Seq(
            0 -> 1,  // Identity:   innerProd
            1 -> 5,  // Add:        2 + 3
            2 -> 3,  // Max:        max(2, 3)
            3 -> 5)  // Accumulate: 1 + 4

        for ((hot, want) <- cases) {
            (0 until 4).foreach { i => poke(alu.io.in.funcSel(i), if (i == hot) 1 else 0) }
            expect(alu.io.out, want)
        }
    }
}

## Nonlinear Unit

In [None]:
/** Configuration of the nonlinear unit.
  *
  * @param dataWidth bit width of the data path
  * @param funcs     functions to support; must include "Identity" and may
  *                  additionally include "ReLu"
  * @throws IllegalArgumentException if "Identity" is missing or an entry of
  *                                  `funcs` is not a supported function
  */
class NLUConfig(val dataWidth: Int, val funcs: List[String]) {

    val supportedFuncs = List("Identity", "ReLu")
    val identityError = "NLU functions must explicitly include Identity."
    val functionError = "Unsupported Function"

    require(funcs.contains("Identity"), identityError)
    funcs.foreach { name => require(supportedFuncs.contains(name), functionError) }

    val reluSupp = funcs.contains("ReLu")
    val numFuncs = funcs.size
}

/** Input bundle of the nonlinear unit: the signed data value and one
  * select bit per configured function.
  */
class NLUInputs(c: NLUConfig) extends Bundle {
    
    // Rebuild from the same config when Chisel needs a fresh copy of this bundle.
    override def cloneType = (new NLUInputs(c)).asInstanceOf[this.type]
    
    val data = SInt(c.dataWidth.W)
    val fSel = Vec(c.numFuncs, Bool())
}

/** Nonlinear unit: computes Identity and (if configured) ReLu of the input
  * in parallel and selects one result with the fSel bits.
  */
class NLU(c: NLUConfig) extends Module {
    
    val io = IO(new Bundle {
        val in  = Input(new NLUInputs(c))
        val out = Output(SInt(c.dataWidth.W))
    })
    
    // One intermediate wire per function; the ReLu wire is not built when unsupported.
    val idRes   = Some(Wire(SInt(c.dataWidth.W)))
    val reluRes = if(c.reluSupp) Some(Wire(SInt(c.dataWidth.W))) else None
    
    idRes.get := io.in.data
    
    // ReLu: pass positive values through, clamp everything else to 0.
    if (c.reluSupp) {
        when (io.in.data > 0.S) {
            reluRes.get := io.in.data
        } .otherwise {
            reluRes.get := 0.S
        }
    }
    
    // Results are collected in the fixed order Identity, ReLu, independent of
    // the order of c.funcs.
    val inters = (idRes :: reluRes :: Nil) filter ( _.isDefined ) map ( _.get )
    io.out := PriorityMux(io.in.fSel, inters)
}

In [None]:
val nluFuncs = "Identity" :: "ReLu" :: Nil
val nluCon = new NLUConfig(dataWidth = 8, funcs = nluFuncs)

Driver(() => new NLU(nluCon)) { nlu =>
    new PeekPokeTester(nlu) {
        // input value -> expected outputs for (Identity selected, ReLu selected)
        val cases = Seq(
            5  -> Seq(5, 5),
            -4 -> Seq(-4, 0))

        for ((input, wanted) <- cases) {
            poke(nlu.io.in.data, input)

            for ((want, hot) <- wanted.zipWithIndex) {
                // One-hot function select.
                poke(nlu.io.in.fSel(0), if (hot == 0) 1 else 0)
                poke(nlu.io.in.fSel(1), if (hot == 1) 1 else 0)
                expect(nlu.io.out, want)
            }
        }
    }
}

## Control

### State Machine

#### Definition

In [None]:
/** Configuration of the control state machine.
  *
  * @param numStates   number of states; determines the state register width
  * @param numCtrlSigs number of distinct control values; determines the control input width
  * @param stateMap    next-state function taking (current state, control input, this config)
  */
class StateMachineConfig(
    val numStates: Int,
    val numCtrlSigs: Int,
    val stateMap: (UInt, UInt, StateMachineConfig) => UInt
) {
    // Bit widths derived with log2Up from the two counts.
    val stateWidth: Int = log2Up(numStates)
    val ctrlWidth: Int = log2Up(numCtrlSigs)
}

/** State machine whose next state is computed by the user-supplied
  * c.stateMap from the current state and the control input. The state
  * register initialises to 0 and is driven directly onto `out`.
  */
class StateMachine(c: StateMachineConfig) extends Module {
    
    // Same computation as c.stateWidth; the code below uses c.stateWidth.
    val stateWidth: Int = log2Up(c.numStates)
    
    val io = IO(new Bundle {
        val control = Input (UInt(c.ctrlWidth.W))
        val out = Output(UInt(c.stateWidth.W))
    })
    
    val register = RegInit(0.U(c.stateWidth.W))
    // The register is reassigned from stateMap unconditionally (no enable).
    register := c.stateMap(register, io.control, c)
    io.out := register
}

#### Example

In [None]:
// Example next-state function for StateMachineConfig.
// For states 0 and 1 with control 0 or 1, the next state equals the control
// value; every other (state, control) combination goes to state 0.
def exampleStateMap(state: UInt, control: UInt, c: StateMachineConfig): UInt = {
    
    val nextState = Wire(UInt(c.stateWidth.W))
    
    when      (state === 0.U & control === 0.U) { nextState := 0.U }
    .elsewhen (state === 0.U & control === 1.U) { nextState := 1.U }
    .elsewhen (state === 1.U & control === 0.U) { nextState := 0.U }
    .elsewhen (state === 1.U & control === 1.U) { nextState := 1.U }
    .otherwise { nextState := 0.U }
    
    nextState
}

#### Verification

In [None]:
val exampleStateMachineConfig = new StateMachineConfig(2, 2, exampleStateMap)

Driver(() => new StateMachine(exampleStateMachineConfig)) { fsm =>
    new PeekPokeTester(fsm) {
        // Before any clock step the state register should read its initial value 0.
        poke(fsm.io.control, 0)
        expect(fsm.io.out, 0)

        // (control input, state expected after one clock step), applied in sequence.
        val transitions = Seq(
            1 -> 1,  // 0 -> 1
            1 -> 1,  // 1 -> 1
            0 -> 0,  // 1 -> 0
            0 -> 0)  // 0 -> 0

        transitions.foreach { case (ctrl, next) =>
            poke(fsm.io.control, ctrl)
            step(1)
            expect(fsm.io.out, next)
        }
    }
}

### Decoder

#### Definition

In [None]:
// Aggregate configuration for a processing element: one config per sub-unit
// plus one decode function per controlled unit. Each decode function takes the
// current state and that unit's config and returns the unit's control signals.
// NOTE(review): PRFConfig is not defined in the visible part of this notebook —
// confirm where it comes from.
class PEConfig(
        val weightPRFConfig: PRFConfig,
        val actvtnPRFConfig: PRFConfig,
        val scratchRFConfig: RFConfig,
        val ipuConfig: IPUConfig,
        val aluConfig: ALUConfig,
        val nluConfig: NLUConfig,
        val smConfig: StateMachineConfig,
        val decodeWeightPRF: (UInt, PRFConfig) => Data,
        val decodeActvtnPRF: (UInt, PRFConfig) => Data,
        val decodeScratchRF: (UInt, RFConfig) => Data,
        val decodeIPU: (UInt, IPUConfig) => Data,
        val decodeALU: (UInt, ALUConfig) => Data,
        val decodeNLU: (UInt, NLUConfig) => Data)

// Control signals for the three storage units of a PE: the weight and
// activation register files and the scratch register file.
// NOTE(review): PRFControl is not defined in the visible part of this notebook.
class MemoryControl(c: PEConfig) extends Bundle {
    
    // Rebuild from the same config when Chisel needs a fresh copy of this bundle.
    override def cloneType = (new MemoryControl(c)).asInstanceOf[this.type]
    
    val weightPRF = new PRFControl(c.weightPRFConfig)
    val actvtnPRF = new PRFControl(c.actvtnPRFConfig)
    val scratchRF = new RFControl(c.scratchRFConfig)
}

// Control signals for the compute units of a PE: function selects for the ALU
// and NLU, plus the IPU bypass select when the IPU uses Firm bypassing.
class ProcessControl(c: PEConfig) extends Bundle {
    
    // Rebuild from the same config when Chisel needs a fresh copy of this bundle.
    override def cloneType = (new ProcessControl(c)).asInstanceOf[this.type]
    
    val aluFSel = Output(Vec(c.aluConfig.numFuncs, Bool()))
    val nluFSel = Output(Vec(c.nluConfig.numFuncs, Bool()))
    
    // The bypass select is sized from the weight PRF's port count.
    // NOTE(review): IPU.io.bpSel is sized from IPUConfig.width instead; the two
    // presumably must be equal — confirm (see the TODO in the verification cell).
    private val ipuPorts = c.weightPRFConfig.ports
    private val ipuBpFirm = c.ipuConfig.bpFirm
    val ipuBpSel = if (ipuBpFirm) Some(Output(Vec(ipuPorts, Bool()))) else None
}

// Maps the current state to the control signals of every PE sub-unit by
// calling the decode functions supplied in the PEConfig. There are no
// registers in this module.
class Decoder(c: PEConfig) extends Module {
    
    val io = IO(new Bundle {
        val state = Input(UInt(c.smConfig.stateWidth.W))
        val mem = Output(new MemoryControl(c))
        val proc = Output(new ProcessControl(c))
    })
    
    // Storage-unit controls.
    io.mem.weightPRF <> c.decodeWeightPRF(io.state, c.weightPRFConfig)
    io.mem.actvtnPRF <> c.decodeActvtnPRF(io.state, c.actvtnPRFConfig)
    io.mem.scratchRF <> c.decodeScratchRF(io.state, c.scratchRFConfig)
    
    // The IPU bypass select only exists (and is only decoded) in Firm mode.
    if (c.ipuConfig.bpFirm) { 
        io.proc.ipuBpSel.get := c.decodeIPU(io.state, c.ipuConfig)
    }
    
    // Compute-unit function selects.
    io.proc.aluFSel := c.decodeALU(io.state, c.aluConfig)
    io.proc.nluFSel := c.decodeNLU(io.state, c.nluConfig)
}

#### Example

In [None]:
// Example decode function for the weight PRF.
// State 0: every entry of data.rf gets enables high, addresses 1/2/3 and
// bypass select (if present) high. Any other state: enables low,
// addresses 4/5/6 and bypass select (if present) low.
// NOTE(review): PRFControl and the fields used here (rf, wAddr, rAddrInt,
// rAddrExt, bpSel) are not defined in the visible part of this notebook.
def exampleDecodeWeightPRF(state: UInt, c: PRFConfig) = {
    
    val data = Wire(new PRFControl(c))
    
    when (state === 0.U) {
        data.rf.foreach { k =>
            k.wEnable   := true.B
            k.rEnable   := true.B
            k.wAddr     := 1.U
            k.rAddrInt  := 2.U
            k.rAddrExt  := 3.U
            if (k.bpSel.isDefined) { k.bpSel.get := true.B }
        }
    } .otherwise {
        data.rf.foreach { k =>
            k.wEnable   := false.B
            k.rEnable   := false.B
            k.wAddr     := 4.U
            k.rAddrInt  := 5.U
            k.rAddrExt  := 6.U
            if (k.bpSel.isDefined) { k.bpSel.get := false.B }
        }
    }
    
    data
}

// Example decode function for the activation PRF.
// State 0: every entry of data.rf gets enables high, addresses 1/2/3 and
// bypass select (if present) high. Any other state: enables low,
// addresses 4/5/6 and bypass select (if present) low.
// NOTE(review): PRFControl and the fields used here (rf, wAddr, rAddrInt,
// rAddrExt, bpSel) are not defined in the visible part of this notebook.
def exampleDecodeActvtnPRF(state: UInt, c: PRFConfig) = {
    
    val data = Wire(new PRFControl(c))
    
    when (state === 0.U) {
        data.rf.foreach { k =>
            k.wEnable   := true.B
            k.rEnable   := true.B
            k.wAddr     := 1.U
            k.rAddrInt  := 2.U
            k.rAddrExt  := 3.U
            if (k.bpSel.isDefined) { k.bpSel.get := true.B }
        }
    } .otherwise {
        data.rf.foreach { k =>
            k.wEnable   := false.B
            k.rEnable   := false.B
            k.wAddr     := 4.U
            k.rAddrInt  := 5.U
            k.rAddrExt  := 6.U
            if (k.bpSel.isDefined) { k.bpSel.get := false.B }
        }
    }
    
    data
}

// Example decode function for the scratch register file.
// State 0: enables high, addresses 1/2/3, bypass select (if present) high and
// input select (if present) choosing element 0. Any other state: enables low,
// addresses 4/5/6, bypass select low and input select choosing element 1.
// NOTE(review): the RFControl defined in this notebook only has `internal` and
// `external` members; the flat fields used here (wEnable, rEnable, wAddr,
// rAddrInt, rAddrExt, bpSel, inSel) look like they target an older RFControl —
// confirm and update before running.
def exampleDecodeScratchRF(state: UInt, c: RFConfig) = {
    
    val data = Wire(new RFControl(c))
    
    when (state === 0.U) {
        data.wEnable   := true.B
        data.rEnable   := true.B
        data.wAddr     := 1.U
        data.rAddrInt  := 2.U
        data.rAddrExt  := 3.U
        if(data.bpSel.isDefined) { data.bpSel.get := true.B }
        if(data.inSel.isDefined) { data.inSel.get := Vec(List(true.B, false.B)) }
    } .otherwise {
        data.wEnable   := false.B
        data.rEnable   := false.B
        data.wAddr     := 4.U
        data.rAddrInt  := 5.U
        data.rAddrExt  := 6.U
        if(data.bpSel.isDefined) { data.bpSel.get := false.B }
        if(data.inSel.isDefined) { data.inSel.get := Vec(List(false.B, true.B)) }
    }
    
    data
}

// Example decode function for the IPU bypass select.
// Returns a one-hot vector of c.width bits: bit 0 is set in state 0 and bit 1
// in every other state. The pattern is sized from the config instead of being
// a fixed two-element literal, so it elaborates for any width (for width 2 the
// values are unchanged). With width 1 the non-zero-state pattern is all zeros.
def exampleDecodeIPU(state: UInt, c: IPUConfig) = {
    
    // One-hot literal with bit `hot` set, matching the width of `data`.
    def oneHot(hot: Int) = Vec(Seq.tabulate(c.width) { i => (i == hot).B })
    
    val data = Wire(Vec(c.width, Bool()))
    
    when (state === 0.U) {
        data := oneHot(0)
    } .otherwise {
        data := oneHot(1)
    }
    
    data
}

// Example decode function for the ALU function select.
// Returns a one-hot vector of c.numFuncs bits: bit 0 is set in state 0 and
// bit 1 in every other state. The pattern is sized from the config instead of
// being a fixed four-element literal, so it elaborates for any number of
// functions (for four functions the values are unchanged). With a single
// function the non-zero-state pattern is all zeros.
def exampleDecodeALU(state: UInt, c: ALUConfig) = {
    
    // One-hot literal with bit `hot` set, matching the width of `data`.
    def oneHot(hot: Int) = Vec(Seq.tabulate(c.numFuncs) { i => (i == hot).B })
    
    val data = Wire(Vec(c.numFuncs, Bool()))
    
    when (state === 0.U) {
        data := oneHot(0)
    } .otherwise {
        data := oneHot(1)
    }
    
    data
}

// Example decode function for the NLU function select.
// Returns a one-hot vector of c.numFuncs bits: bit 0 is set in state 0 and
// bit 1 in every other state. The pattern is sized from the config instead of
// being a fixed two-element literal, so it elaborates for any number of
// functions (for two functions the values are unchanged). With a single
// function the non-zero-state pattern is all zeros.
def exampleDecodeNLU(state: UInt, c: NLUConfig) = {
    
    // One-hot literal with bit `hot` set, matching the width of `data`.
    def oneHot(hot: Int) = Vec(Seq.tabulate(c.numFuncs) { i => (i == hot).B })
    
    val data = Wire(Vec(c.numFuncs, Bool()))
    
    when (state === 0.U) {
        data := oneHot(0)
    } .otherwise {
        data := oneHot(1)
    }
    
    data
}


#### Verification

In [None]:
// TODO: require IPU width == weightPRF width == actvtnPRF width
// TODO: require IPUConfig "Firm" if ALUConfig "Add" or "Max"

// Example PE configuration used by the Decoder test below.
// NOTE(review): `new RFConfig(8, 4, 2, true)` passes four arguments, but the
// RFConfig defined in this notebook takes seven (ending in a bpType string);
// this call looks like it targets an older signature — confirm and update.
// NOTE(review): PRFConfig is not defined in the visible part of this notebook.
val examplePEConfig = new PEConfig(
    new PRFConfig(2, 8, 4, 1, "Soft"),
    new PRFConfig(2, 8, 4, 1, "Soft"),
    new RFConfig(8, 4, 2, true),
    new IPUConfig(2, 8, "Firm"),
    new ALUConfig(8, List("Identity", "Add", "Max", "Accumulate")),
    new NLUConfig(8, List("Identity", "ReLu")),
    new StateMachineConfig(4, 4, exampleStateMap),
    exampleDecodeWeightPRF,
    exampleDecodeActvtnPRF,
    exampleDecodeScratchRF,
    exampleDecodeIPU,
    exampleDecodeALU,
    exampleDecodeNLU
)


// Checks the Decoder outputs produced by the example decode functions for
// state 0 and for state 1.
// The state-0 scratch RF checks previously referenced `io.mem.sratchRF`, a
// misspelling of the `scratchRF` field declared in MemoryControl.
// NOTE(review): the `.rf(0)` accessors on scratchRF do not match the RFControl
// defined in this notebook (which has `internal`/`external`) — confirm.
Driver(() => new Decoder(examplePEConfig)) {
    
    uut => new PeekPokeTester(uut) {
        
        // ---- State 0 ----
        poke(uut.io.state, 0.U)
        step(1)
        
        expect(uut.io.mem.weightPRF.rf(0).wEnable, true.B)
        expect(uut.io.mem.weightPRF.rf(0).rEnable, true.B)
        expect(uut.io.mem.weightPRF.rf(0).wAddr, 1.U)
        expect(uut.io.mem.weightPRF.rf(0).rAddrInt, 2.U)
        expect(uut.io.mem.weightPRF.rf(0).rAddrExt, 3.U)
        expect(uut.io.mem.weightPRF.rf(0).bpSel.get, true.B)
        
        expect(uut.io.mem.actvtnPRF.rf(0).wEnable, true.B)
        expect(uut.io.mem.actvtnPRF.rf(0).rEnable, true.B)
        expect(uut.io.mem.actvtnPRF.rf(0).wAddr, 1.U)
        expect(uut.io.mem.actvtnPRF.rf(0).rAddrInt, 2.U)
        expect(uut.io.mem.actvtnPRF.rf(0).rAddrExt, 3.U)
        expect(uut.io.mem.actvtnPRF.rf(0).bpSel.get, true.B)
        
        expect(uut.io.proc.ipuBpSel.get(0), 1)
        expect(uut.io.proc.ipuBpSel.get(1), 0)
        
        expect(uut.io.proc.aluFSel(0), 1)
        expect(uut.io.proc.aluFSel(1), 0)
        expect(uut.io.proc.aluFSel(2), 0)
        expect(uut.io.proc.aluFSel(3), 0)
        
        expect(uut.io.mem.scratchRF.rf(0).wEnable, true.B)
        expect(uut.io.mem.scratchRF.rf(0).rEnable, true.B)
        expect(uut.io.mem.scratchRF.rf(0).wAddr, 1.U)
        expect(uut.io.mem.scratchRF.rf(0).rAddrInt, 2.U)
        expect(uut.io.mem.scratchRF.rf(0).rAddrExt, 3.U)
        expect(uut.io.mem.scratchRF.rf(0).bpSel.get, true.B)
        
        expect(uut.io.proc.nluFSel(0), 1)
        expect(uut.io.proc.nluFSel(1), 0)
        
        // ---- State 1 ----
        poke(uut.io.state, 1.U) 
        step(1)
        
        expect(uut.io.mem.weightPRF.rf(0).wEnable, false.B)
        expect(uut.io.mem.weightPRF.rf(0).rEnable, false.B)
        expect(uut.io.mem.weightPRF.rf(0).wAddr, 4.U)
        expect(uut.io.mem.weightPRF.rf(0).rAddrInt, 5.U)
        expect(uut.io.mem.weightPRF.rf(0).rAddrExt, 6.U)
        expect(uut.io.mem.weightPRF.rf(0).bpSel.get, false.B)
        
        expect(uut.io.mem.actvtnPRF.rf(0).wEnable, false.B)
        expect(uut.io.mem.actvtnPRF.rf(0).rEnable, false.B)
        expect(uut.io.mem.actvtnPRF.rf(0).wAddr, 4.U)
        expect(uut.io.mem.actvtnPRF.rf(0).rAddrInt, 5.U)
        expect(uut.io.mem.actvtnPRF.rf(0).rAddrExt, 6.U)
        expect(uut.io.mem.actvtnPRF.rf(0).bpSel.get, false.B)
        
        expect(uut.io.proc.ipuBpSel.get(0), 0)
        expect(uut.io.proc.ipuBpSel.get(1), 1)
        
        expect(uut.io.proc.aluFSel(0), 0)
        expect(uut.io.proc.aluFSel(1), 1)
        expect(uut.io.proc.aluFSel(2), 0)
        expect(uut.io.proc.aluFSel(3), 0)
        
        expect(uut.io.mem.scratchRF.rf(0).wEnable, false.B)
        expect(uut.io.mem.scratchRF.rf(0).rEnable, false.B)
        expect(uut.io.mem.scratchRF.rf(0).wAddr, 4.U)
        expect(uut.io.mem.scratchRF.rf(0).rAddrInt, 5.U)
        expect(uut.io.mem.scratchRF.rf(0).rAddrExt, 6.U)
        expect(uut.io.mem.scratchRF.rf(0).bpSel.get, false.B)
        
        expect(uut.io.proc.nluFSel(0), 0)
        expect(uut.io.proc.nluFSel(1), 1)
        
    }
}


## PE

#### Definition

In [None]:
// Processing element: wires together the state machine, decoder, weight and
// activation PRFs, IPU, ALU, scratch RF and NLU.
// Data path: PRF internal outputs -> IPU -> ALU -> scratch RF -> NLU -> totalOutput.
// NOTE(review): PRF, PRFInput and RFOutput are not defined in the visible part
// of this notebook, and the scratchRF.io.in / scratchRF.io.out members used
// below do not exist on the RF module defined here (which exposes wInternal,
// wExternal, rInternal, rExternal) — this class looks like it targets an older
// RF interface; confirm and update before running.
class PE(c: PEConfig) extends Module {
    
    val cw = c.weightPRFConfig
    val ca = c.actvtnPRFConfig
    val cs = c.scratchRFConfig
    
    val io = IO(new Bundle {
        val stateCtrl = Input(UInt(c.smConfig.ctrlWidth.W))
        val toWeightPRF = Input(new PRFInput(cw))
        val toActvtnPRF = Input(new PRFInput(ca))
        val toScratchRF = Input(SInt(cs.dataWidth.W))
        val fromWeightPRF = Output(Vec(cw.ports, SInt(cw.dataWidth.W)))
        val fromActvtnPRF = Output(Vec(ca.ports, SInt(ca.dataWidth.W)))
        val fromScratchRF = Output(SInt(cs.dataWidth.W))
        val totalOutput = Output(SInt(c.nluConfig.dataWidth.W))
    })
    
    // Control path: external control -> state machine -> decoder.
    val stateMachine = Module(new StateMachine(c.smConfig))
    stateMachine.io.control := io.stateCtrl
    
    val decoder = Module(new Decoder(c))
    decoder.io.state := stateMachine.io.out
    
    // Weight PRF: the external output of each register file goes to the PE's ports.
    val weightPRF = Module(new PRF(cw))
    weightPRF.io.control <> decoder.io.mem.weightPRF
    weightPRF.io.in <> io.toWeightPRF
    weightPRF.io.out.rf.zipWithIndex.map { 
        case (x: RFOutput, i: Int) => io.fromWeightPRF(i) := x.ext
    }
    
    // Activation PRF: same wiring as the weight PRF.
    val actvtnPRF = Module(new PRF(ca))
    actvtnPRF.io.control <> decoder.io.mem.actvtnPRF
    actvtnPRF.io.in <> io.toActvtnPRF
    actvtnPRF.io.out.rf.zipWithIndex.map {
        case (x: RFOutput, i: Int) => io.fromActvtnPRF(i) := x.ext
    }
       
    // IPU: fed from the internal outputs of both PRFs.
    val ipu = Module(new IPU(c.ipuConfig))
    if (ipu.io.bpSel.isDefined) { ipu.io.bpSel.get := decoder.io.proc.ipuBpSel.get }
    weightPRF.io.out.rf.zipWithIndex.map { 
        case (x: RFOutput, i: Int) => ipu.io.dataIn.weight(i) := x.int
    }
    actvtnPRF.io.out.rf.zipWithIndex.map {
        case (x: RFOutput, i: Int) => ipu.io.dataIn.actvtn(i) := x.int
    }

    // ALU: consumes the whole IPU output bundle.
    val alu = Module(new ALU(c.aluConfig))
    alu.io.in.funcSel := decoder.io.proc.aluFSel
    alu.io.in.ipu <> ipu.io.dataOut
    
    // Scratch RF: input 0 is the external value, input 1 is the ALU result;
    // its internal output feeds back to the ALU when Accumulate is supported.
    val scratchRF = Module(new RF(cs))
    scratchRF.io.control <> decoder.io.mem.scratchRF
    // This next group of statements is the result of poor decisions :(
    scratchRF.io.in.data(0) := io.toScratchRF
    scratchRF.io.in.data(1) := alu.io.out
    io.fromScratchRF := scratchRF.io.out.ext
    if(alu.io.in.rfFeedback.isDefined) alu.io.in.rfFeedback.get := scratchRF.io.out.int
    
    // NLU: applied to the scratch RF's internal output to form the PE output.
    val nlu = Module(new NLU(c.nluConfig))
    nlu.io.in.fSel := decoder.io.proc.nluFSel
    nlu.io.in.data := scratchRF.io.out.int
    io.totalOutput := nlu.io.out
}

#### Verification

In [None]:
// Example configuration for the PE elaboration check below.
// NOTE(review): `new RFConfig(8, 4, 2, false)` passes four arguments, but the
// RFConfig defined in this notebook takes seven — confirm and update.
val examplePEConfig10 = new PEConfig(
    new PRFConfig(2, 8, 4, 1, "Soft"),
    new PRFConfig(2, 8, 4, 1, "Soft"),
    new RFConfig(8, 4, 2, false),
    new IPUConfig(2, 8, "Firm"),
    new ALUConfig(8, List("Identity", "Add", "Max", "Accumulate")),
    new NLUConfig(8, List("Identity", "ReLu")),
    new StateMachineConfig(4, 4, exampleStateMap),
    exampleDecodeWeightPRF,
    exampleDecodeActvtnPRF,
    exampleDecodeScratchRF,
    exampleDecodeIPU,
    exampleDecodeALU,
    exampleDecodeNLU
)

// The tester body is empty: this only checks that the PE elaborates.
Driver(() => new PE(examplePEConfig10)) {
    uut => new PeekPokeTester(uut) {
        
    }
}

## Future Plans
* Verify everything using Golden Models