# nPE: A Configurable Processing Engine
#### Verification
___

## Setup

In [4]:
// Absolute location of the dependency-loading script, relative to the
// directory the kernel was started from.
val path = List(System.getProperty("user.dir"), "source", "load-ivy.sc").mkString("/")
// Hand the script to the running Ammonite interpreter as a module.
interp.load.module(
  ammonite.ops.Path(java.nio.file.FileSystems.getDefault.getPath(path))
)

Compiling Main.sc


[36mpath[39m: [32mString[39m = [32m"""
C:\Users\RyanL\OneDrive\Research\SEAL\nPE/source/load-ivy.sc
"""[39m

In [5]:
import chisel3._
import chisel3.util._
import chisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester}

import scala.math.pow

[32mimport [39m[36mchisel3._
[39m
[32mimport [39m[36mchisel3.util._
[39m
[32mimport [39m[36mchisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester}

[39m
[32mimport [39m[36mscala.math.pow[39m

## Parallel Register File

### Single Register File

In [3]:
/** Single register file with one write port and two read ports.
  *
  * A write happens on `waddr`/`wdata` while `write_en` is high. Both read
  * ports (`rdata_int`, `rdata_ext`) present the entry selected by their
  * address while `read_en` is high and drive zero otherwise.
  *
  * @param datawidth width in bits of every stored word and of the data ports
  * @param addrwidth width in bits of the address ports; the file holds
  *                  2^addrwidth entries
  */
class RF (datawidth: Int, addrwidth: Int) extends Module {
  
    val io = IO(new Bundle {
        val write_en  = Input (Bool())
        val read_en   = Input (Bool())
        val waddr     = Input (UInt(addrwidth.W))
        val wdata     = Input (SInt(datawidth.W))
        val raddr_int = Input (UInt(addrwidth.W))
        val raddr_ext = Input (UInt(addrwidth.W))
        val rdata_int = Output(SInt(datawidth.W))
        val rdata_ext = Output(SInt(datawidth.W))
    })
    
    // Each entry must be as wide as the data ports. Sizing entries with the
    // address width (as before) narrowed stored data whenever
    // datawidth > addrwidth.
    val registers  = RegInit(Vec(Seq.fill(pow(2, addrwidth).toInt) { 0.S(datawidth.W) }))
    
    when(io.write_en) {
        registers(io.waddr) := io.wdata
    }
    
    when(io.read_en) {
        io.rdata_int := registers(io.raddr_int)
        io.rdata_ext := registers(io.raddr_ext)
    } .otherwise {
        // Reads are gated: outputs are forced to zero when not enabled.
        io.rdata_int := 0.S
        io.rdata_ext := 0.S
    }
}

defined [32mclass[39m [36mRF[39m

### Putting them Together

In [None]:
/** A bank of `ports` independent register files with an optional write-data bypass.
  *
  * Bypass modes:
  *  - "None": read outputs always come from the register files.
  *  - "Soft": a per-port `bp_slct` input chooses, when asserted, to forward
  *            `wdata` straight to both read outputs instead of the register
  *            file contents.
  *  - "Hard": no register files are built; `wdata` is wired directly to both
  *            read outputs.
  *
  * @param ports     number of parallel register files / port groups
  * @param bypass    one of "None", "Soft", "Hard"
  * @param datawidth width in bits of the data ports
  * @param addrwidth width in bits of the address ports
  */
class pRF(ports: Int, bypass: String, datawidth: Int, addrwidth: Int) extends Module {
    
    require(List("None", "Soft", "Hard").contains(bypass))
    
    val io = IO(new Bundle {
        val write_en  = Input (Vec(ports, Bool()))
        val read_en   = Input (Vec(ports, Bool()))
        val waddr     = Input (Vec(ports, UInt(addrwidth.W)))
        val wdata     = Input (Vec(ports, SInt(datawidth.W)))
        val raddr_int = Input (Vec(ports, UInt(addrwidth.W)))
        val raddr_ext = Input (Vec(ports, UInt(addrwidth.W)))
        val rdata_int = Output(Vec(ports, SInt(datawidth.W)))
        val rdata_ext = Output(Vec(ports, SInt(datawidth.W)))
        val bp_slct   = if (bypass == "Soft") Some(Input(Vec(ports, Bool()))) else None
    })
    
    if(bypass == "None" || bypass == "Soft") {
        
        // Child modules have to be wrapped in Module(...) to become part of
        // the hardware hierarchy.
        val rf = Seq.fill(ports){ Module(new RF(datawidth, addrwidth)) }
        
        rf.zipWithIndex.foreach { case (x, i) =>
            
            x.io.write_en  := io.write_en(i)
            x.io.read_en   := io.read_en(i)
            x.io.waddr     := io.waddr(i)
            x.io.wdata     := io.wdata(i)
            x.io.raddr_int := io.raddr_int(i)
            x.io.raddr_ext := io.raddr_ext(i)
            
            // Without a select port ("None") the bypass is never taken.
            val bypassed = io.bp_slct.map(_(i)).getOrElse(false.B)
            
            when (bypassed) {
                // Bypass selected: forward the incoming write data.
                io.rdata_int(i) := io.wdata(i)
                io.rdata_ext(i) := io.wdata(i)
            } .otherwise {
                // Normal operation: return what the register file reads.
                io.rdata_int(i) := x.io.rdata_int
                io.rdata_ext(i) := x.io.rdata_ext
            }
        }
        
    } else if(bypass == "Hard") {
        io.rdata_int := io.wdata
        io.rdata_ext := io.wdata
    }
}

## Inner Product Unit

### Parallel Multiplier

In [14]:
/** Lane-wise signed multiplier.
  *
  * Output lane `i` is driven by `in1(i) * in2(i)`.
  * NOTE(review): the output lanes are declared `bitwidth` bits wide, the same
  * as the operands, so a full-width product presumably gets narrowed on
  * connection; confirm against Chisel's connect rules.
  *
  * @param width    number of parallel lanes (at least one)
  * @param bitwidth width in bits of each operand and each result lane
  */
class pMultiplier(width: Int, bitwidth: Int) extends Module {

    require(width >= 1, "Width must be at least one.")
    require(bitwidth >= 1, "Bitwidth must be at least one.")

    val io = IO(new Bundle {
        val in1 = Input(Vec(width, SInt(bitwidth.W)))
        val in2 = Input(Vec(width, SInt(bitwidth.W)))
        val out = Output(Vec(width, SInt(bitwidth.W)))
    })

    // One multiplier per lane, connected element by element.
    for (lane <- 0 until width) {
        io.out(lane) := io.in1(lane) * io.in2(lane)
    }
}

defined [32mclass[39m [36mpMultiplier[39m

In [20]:
// Smoke test: drive each of the four lanes with a small operand pair and
// check the product on the matching output lane.
Driver(() => new pMultiplier(4, 8)) { dut =>
    new PeekPokeTester(dut) {
        // (in1, in2, expected out) for lanes 0 through 3
        val vectors = Seq((1, 2, 2), (3, 4, 12), (5, 6, 30), (7, 8, 56))

        for (((a, b, product), lane) <- vectors.zipWithIndex) {
            poke(dut.io.in1(lane), a)
            poke(dut.io.in2(lane), b)
            expect(dut.io.out(lane), product)
        }
    }
}

[[35minfo[0m] [0.000] Elaborating design...
[[35minfo[0m] [0.001] Done elaborating.
Total FIRRTL Compile Time: 11.5 ms
Total FIRRTL Compile Time: 9.6 ms
End of dependency graph
Circuit state created
[[35minfo[0m] [0.000] SEED 1531516643331
test cmd13WrapperHelperpMultiplier Success: 4 tests passed in 5 cycles taking 0.004315 seconds
[[35minfo[0m] [0.003] RAN 0 CYCLES PASSED


[36mres19[39m: [32mBoolean[39m = [32mtrue[39m

### Additive Reduction Tree

In [6]:
// Reduces a list by combining neighbouring elements level by level, which
// yields a balanced tree of `op` applications rather than a linear chain.
// An unpaired trailing element is carried unchanged to the next level.
// Throws IllegalArgumentException when the list is empty.
def nonassocPairwiseReduce[A](xs: List[A], op: (A, A) => A): A = {
  if (xs.isEmpty) {
    throw new IllegalArgumentException
  } else if (xs.tail.isEmpty) {
    // A single element left: this is the root of the tree.
    xs.head
  } else {
    // Each group holds either two neighbours (combined left-to-right) or a
    // lone leftover, which reduce returns as-is.
    val nextLevel = xs.grouped(2).map(_.reduce(op)).toList
    nonassocPairwiseReduce(nextLevel, op)
  }
}


/** Adder tree that sums all input lanes into a single output.
  *
  * The additions are arranged by `nonassocPairwiseReduce`, so neighbouring
  * lanes are added pairwise, level by level.
  *
  * @param width    number of input lanes (at least one)
  * @param bitwidth width in bits of each input lane and of the output
  */
class AdditiveRT(width: Int, bitwidth: Int) extends Module {

    require(width >= 1, "Width must be at least one.")
    require(bitwidth >= 1, "Bitwidth must be at least one.")

    val io = IO(new Bundle {
        val in  = Input(Vec(width, SInt(bitwidth.W)))
        val out = Output(SInt(bitwidth.W))
    })

    // Fold the lanes with `+`, pairing neighbours at every level.
    io.out := nonassocPairwiseReduce[SInt](io.in.toList, _ + _)
}

defined [32mfunction[39m [36mnonassocPairwiseReduce[39m
defined [32mclass[39m [36mAdditiveRT[39m

In [13]:
// Smoke test: apply four operands to the reduction tree and check their sum.
Driver(() => new AdditiveRT(4, 8)) { dut =>
    new PeekPokeTester(dut) {
        val operands = Seq(1, 2, 8, 9)

        operands.zipWithIndex.foreach { case (value, lane) =>
            poke(dut.io.in(lane), value)
        }

        expect(dut.io.out, 20)
    }
}

[[35minfo[0m] [0.000] Elaborating design...
[[35minfo[0m] [0.002] Done elaborating.
Total FIRRTL Compile Time: 11.5 ms
Total FIRRTL Compile Time: 9.6 ms
End of dependency graph
Circuit state created
[[35minfo[0m] [0.000] SEED 1531516373702
test cmd5WrapperHelperAdditiveRT Success: 1 tests passed in 5 cycles taking 0.003067 seconds
[[35minfo[0m] [0.001] RAN 0 CYCLES PASSED


[36mres12[39m: [32mBoolean[39m = [32mtrue[39m

### Putting them Together

In [None]:
/** Validates the construction parameters of the inner product unit.
  *
  * @param width    number of lanes; must be at least one
  * @param bypass   bypass mode; must be "None" or "Firm"
  * @param bitwidth operand width in bits; must be at least one, matching the
  *                 requirement of the multiplier and adder tree the IPU builds
  * @throws IllegalArgumentException if any parameter is out of range
  */
def checkparamsIPU(width: Int, bypass: String, bitwidth: Int): Unit = {
    require(width >= 1, "Width must be at least one.")
    require(List("None", "Firm").contains(bypass), "Bypass must be \"None\" or \"Firm\"")
    // A zero-width datapath was accepted here before, only to be rejected
    // later by the sub-modules; fail early with a consistent rule instead.
    require(bitwidth >= 1, "Data bitwidth must be at least one.")
}


/** Inner product unit: a lane-wise multiplier feeding an additive reduction tree.
  *
  * In "Firm" bypass mode the unit additionally exposes `bp1`/`bp2`, which
  * carry the `in1`/`in2` lane chosen by the `sel` vector via `PriorityMux`.
  *
  * NOTE(review): `out`, `bp1` and `bp2` are declared as UInt while every
  * internal value is an SInt. The port types are kept as declared and the
  * signed values are reinterpreted with `asUInt`; confirm whether the ports
  * were meant to be SInt.
  *
  * @param width    number of lanes
  * @param bypass   "None" or "Firm"
  * @param bitwidth operand width in bits
  */
class IPU(width: Int, bypass: String, bitwidth: Int) extends Module {
    
    checkparamsIPU(width, bypass, bitwidth)
    
    val io = IO(new Bundle {
        val in1 = Input(Vec(width, SInt(bitwidth.W)))
        val in2 = Input(Vec(width, SInt(bitwidth.W)))
        val out = Output(UInt(bitwidth.W))
        val sel = if(bypass == "Firm") Some(Input(Vec(width, Bool()))) else None
        val bp1 = if(bypass == "Firm") Some(Output(UInt(bitwidth.W)))  else None
        val bp2 = if(bypass == "Firm") Some(Output(UInt(bitwidth.W)))  else None
    })
    
    // Sub-modules must be wrapped in Module(...) to be elaborated as hardware.
    val pM = Module(new pMultiplier(width, bitwidth))
    pM.io.in1 := io.in1
    pM.io.in2 := io.in2
    
    val aRT = Module(new AdditiveRT(width, bitwidth))
    aRT.io.in := pM.io.out
    
    // SInt cannot be connected to a UInt port directly; cast explicitly.
    io.out := aRT.io.out.asUInt
    
    // The three optional ports exist together (only in "Firm" mode), so this
    // body runs either for all of them or not at all.
    for (sel <- io.sel; bp1 <- io.bp1; bp2 <- io.bp2) {
        bp1 := PriorityMux(sel, io.in1).asUInt
        bp2 := PriorityMux(sel, io.in2).asUInt
    }
}

## ALU

In [None]:
/** Validates the list of functions requested for an ALU.
  *
  * "Identity" has to be listed explicitly, and every entry must be one of the
  * supported function names. `datawidth` is accepted but not checked here.
  *
  * @throws IllegalArgumentException on the first violated rule
  */
def checkparamsALU(funcs: List[String], datawidth: Int): Unit = {
    val known = Set("Identity", "Add", "Max", "Accumulate")

    require(funcs.contains("Identity"), "ALU functions must explicitly include Identity.")
    funcs.foreach { name =>
        require(known(name), "Unsupported Function")
    }
}

/** Configurable ALU offering a selectable subset of functions.
  *
  * Supported functions and the value each one produces:
  *  - "Identity":   `innr_prod`
  *  - "Add":        `weight_bp + actvtn_bp`
  *  - "Max":        the larger of `weight_bp` and `actvtn_bp`
  *  - "Accumulate": `innr_prod + rf_feedbk`
  *
  * Optional ports are only generated when a function that uses them is
  * enabled. The candidate results are collected in the fixed order Identity,
  * Add, Max, Accumulate (skipping disabled ones) and `func_slct` picks among
  * them through `PriorityMux`.
  * NOTE(review): the select bits are paired with that fixed order, not with
  * the order of `funcs`; confirm this is what callers expect.
  *
  * @param funcs     enabled function names; must include "Identity"
  * @param datawidth width in bits of all data ports
  */
class ALU(funcs: List[String], datawidth: Int) extends Module {
    
    checkparamsALU(funcs, datawidth)
    
    // The bypass operands are needed as soon as Add or Max is enabled.
    // (The previous test asked whether a list of names contained the whole
    // `funcs` list, which is never true, so the ports were never created.)
    private val needsBypass = funcs.contains("Add") || funcs.contains("Max")
 
    val io = IO(new Bundle {
        val innr_prod = Input(SInt(datawidth.W))
        val func_slct = Input(Vec(funcs.length, Bool()))
        val output    = Output(SInt(datawidth.W))
        val weight_bp = if(needsBypass)                  Some(Input(SInt(datawidth.W))) else None
        val actvtn_bp = if(needsBypass)                  Some(Input(SInt(datawidth.W))) else None
        val rf_feedbk = if(funcs.contains("Accumulate")) Some(Input(SInt(datawidth.W))) else None
    })
    
    val idnOut = Some(Wire(SInt(datawidth.W)))
    val addOut = if(funcs.contains("Add"))        Some(Wire(SInt(datawidth.W))) else None
    val maxOut = if(funcs.contains("Max"))        Some(Wire(SInt(datawidth.W))) else None
    val accOut = if(funcs.contains("Accumulate")) Some(Wire(SInt(datawidth.W))) else None
    
    idnOut.foreach(_ := io.innr_prod)
    
    // Each block below only elaborates when its result wire and all of the
    // ports it reads were generated.
    for (sum <- addOut; w <- io.weight_bp; a <- io.actvtn_bp) {
        sum := w + a
    }
    for (acc <- accOut; fb <- io.rf_feedbk) {
        acc := io.innr_prod + fb
    }
    for (mx <- maxOut; w <- io.weight_bp; a <- io.actvtn_bp) {
        // Compare the two distinct operands (the weight used to be compared
        // with itself, so the activation was always selected).
        when (w > a) {
            mx := w
        } .otherwise {
            mx := a
        }
    }
    
    // Keep only the results of enabled functions, in the fixed order above.
    val inters = List(idnOut, addOut, maxOut, accOut).flatten
    io.output := PriorityMux(io.func_slct, inters)
}

## Nonlinear Unit

In [None]:
/** Validates the list of functions requested for a nonlinear unit.
  *
  * "Identity" has to be listed explicitly, and every entry must be one of the
  * supported function names. `datawidth` is accepted but not checked here.
  *
  * @throws IllegalArgumentException on the first violated rule
  */
def checkparamsNLU(funcs: List[String], datawidth: Int): Unit = {
    val known = Set("Identity", "ReLu")

    require(funcs.contains("Identity"), "NLU functions must explicitly include Identity.")
    funcs.foreach { name =>
        require(known(name), "Unsupported Function")
    }
}

/** Nonlinear activation unit with a selectable subset of functions.
  *
  * Candidate results are the unmodified input ("Identity") and, when
  * enabled, the input clamped at zero from below ("ReLu"). `fslct` chooses
  * among the enabled candidates through `PriorityMux`.
  *
  * @param funcs     enabled function names; must include "Identity"
  * @param datawidth width in bits of the input and output
  */
class NonlinearUnit(funcs: List[String], datawidth: Int) extends Module {

    checkparamsNLU(funcs, datawidth)

    val io = IO(new Bundle {
        val input = Input(SInt(datawidth.W))
        val fslct = Input(Vec(funcs.length, Bool()))
        val outpt = Output(SInt(datawidth.W))
    })

    val idntOut = Some(Wire(SInt(datawidth.W)))
    val reluOut = if (funcs.contains("ReLu")) Some(Wire(SInt(datawidth.W))) else None

    // Identity: pass the input straight through.
    idntOut.get := io.input

    // ReLu: default to zero, overridden by the input when it is positive.
    if (funcs.contains("ReLu")) {
        reluOut.get := 0.S
        when (io.input > 0.S) {
            reluOut.get := io.input
        }
    }

    // Enabled candidates only, Identity first.
    val inters = List(idntOut, reluOut).flatten
    io.outpt := PriorityMux(io.fslct, inters)
}

## PE