Skip to content

Commit

Permalink
[VTA] Enable streamlined GEMM execution (apache#4392)
Browse files Browse the repository at this point in the history
* disable pipelined adder and enable streamlined gemm execution

* pipeline first layer of adder

* explain difference between pipeadder and adder

* add comment for explaining the hard-coded latency
  • Loading branch information
liangfu authored and tmoreau89 committed Dec 3, 2019
1 parent d56ae16 commit 5dadf76
Showing 1 changed file with 45 additions and 14 deletions.
59 changes: 45 additions & 14 deletions vta/hardware/chisel/src/main/scala/core/TensorGemm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ class MAC(aBits: Int = 8, bBits: Int = 8, cBits: Int = 16) extends Module {
io.y := add
}

/** Pipelined adder */
/** PipeAdder
*
* This unit loads the input bits into registers and performs the addition in the next cycle.
*/
class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module {
val outBits = Math.max(aBits, bBits) + 1
val io = IO(new Bundle {
Expand All @@ -61,6 +64,27 @@ class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module {
io.y := add
}

/** Adder
  *
  * Combinational signed adder. Unlike a registered adder, the operands are
  * wired straight into the addition, so the sum is available on `io.y` in the
  * same cycle the inputs are applied.
  *
  * @param aBits width in bits of operand `a`
  * @param bBits width in bits of operand `b`
  */
class Adder(aBits: Int = 8, bBits: Int = 8) extends Module {
  // The result is one bit wider than the widest operand so the sum cannot overflow.
  val outBits = (aBits max bBits) + 1
  val io = IO(new Bundle {
    val a = Input(SInt(aBits.W))
    val b = Input(SInt(bBits.W))
    val y = Output(SInt(outBits.W))
  })
  // Plain wires (no registers) carrying the operands; each is declared together with its driver.
  val rA = WireInit(SInt(aBits.W), io.a)
  val rB = WireInit(SInt(bBits.W), io.b)
  // `+&` is the width-expanding add, keeping the carry bit.
  val add = WireInit(SInt(outBits.W), rA +& rB)
  io.y := add
}

/** Pipelined DotProduct based on MAC, PipeAdder (first adder layer) and Adder (remaining adder layers) */
class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16)
extends Module {
Expand All @@ -80,9 +104,11 @@ class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16)
val m = Seq.fill(s(0))(Module(new MAC(aBits, bBits, cBits = 1))) // # of total vector pairs
val a = Seq.tabulate(p)(
i =>
Seq.fill(s(i + 1))(Module(new PipeAdder(
aBits = (b + i + 1),
bBits = (b + i + 1))))) // # adders within each layer
Seq.fill(s(i + 1))(
if (i == 0)
Module(new PipeAdder(aBits = (b + i + 1), bBits = (b + i + 1)))
else
Module(new Adder(aBits = (b + i + 1), bBits = (b + i + 1))))) // # adders within each layer

// Vector MACs
for (i <- 0 until s(0)) {
Expand Down Expand Up @@ -126,8 +152,9 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
})
val dot = Seq.fill(size)(
Module(new DotProduct(aBits = inpBits, bBits = wgtBits, size)))
val acc = Seq.fill(size)(
Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1)))
// The latency is fixed at two cycles: one cycle is spent in the MAC module, and
// another in the pipelined adders (PipeAdder) that form the first layer of the accumulator.
val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = 2)))
val add = Seq.fill(size)(Wire(SInt(accBits.W)))
val vld = Wire(Vec(size, Bool()))

Expand Down Expand Up @@ -188,7 +215,9 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters)
val wgt_i = Reg(chiselTypeOf(dec.uop_end))
val pBits = log2Ceil(p(CoreKey).blockOut) + 1
val inflight = Reg(UInt(pBits.W))
val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = pBits))
// The latency is fixed at two cycles: one cycle is spent in the MAC module, and
// another in the pipelined adders (PipeAdder) that form the first layer of the accumulator.
val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = 2))
val done = inflight === 0.U &
((state === sExe &
cnt_o === dec.lp_0 - 1.U &
Expand Down Expand Up @@ -236,11 +265,14 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters)
when(state === sIdle) {
inflight := 0.U
}.elsewhen(!dec.reset) {
when(state === sReadTensor) { // issue a tensor
inflight := inflight + 1.U
}.elsewhen(mvc.io.acc_o.data.valid) { // commit a tensor
inflight := inflight - 1.U
}
when((state === sReadTensor) && mvc.io.acc_o.data.valid) { // issue & commit
inflight := inflight
}.elsewhen(state === sReadTensor) { // issue a tensor
inflight := inflight + 1.U
}
.elsewhen(mvc.io.acc_o.data.valid) { // commit a tensor
inflight := inflight - 1.U
}
}

when(
Expand Down Expand Up @@ -278,8 +310,7 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters)
inp_i := inp_o
wgt_i := wgt_o
}
.elsewhen(state === sExe &&
uop_idx === uop_end - 1.U) {
.elsewhen(state === sExe && uop_idx === uop_end - 1.U) {
cnt_i := cnt_i + 1.U
acc_i := acc_i + dec.acc_1
inp_i := inp_i + dec.inp_1
Expand Down

0 comments on commit 5dadf76

Please sign in to comment.