Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] FP8 exploration #283

Open
wants to merge 10 commits into
base: dev
Choose a base branch
from
2 changes: 1 addition & 1 deletion CHIPYARD.hash
Original file line number Diff line number Diff line change
@@ -1 +1 @@
bcbe3b7f1f40d1c388aca68df498fd7dd4d16e89
f86707bc95d7e95828e63d70ff28fbdaa76a884e
23 changes: 4 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,41 +24,26 @@ Dependencies

Before beginning, install the [Chipyard dependencies](https://chipyard.readthedocs.io/en/latest/Chipyard-Basics/Initial-Repo-Setup.html#default-requirements-installation).

Installing Chipyard and Spike
Installing Gemmini
-----------------------------

Run these steps to install Chipyard and Spike (make sure to checkout the correct Chipyard and Spike commits as shown below):
Run these steps:

```shell
git clone https://github.com/ucb-bar/chipyard.git
cd chipyard
git checkout 1.8.1
git checkout gemmini-fp8-exploration-2
./build-setup.sh riscv-tools

source env.sh

cd generators/gemmini
git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*"
git checkout dev && git pull origin dev
git checkout fp8-recoding && git pull origin fp8-recoding
git submodule update --init --recursive

make -C software/libgemmini install

# The final step is only necessary if you want to run MIDAS simulations with
# realistic DRAM models
cd -
cd sims/firesim
source sourceme-f1-manager.sh --skip-ssh-setup # Ignore error messages from this command
./build-setup.sh --library --skip-validate
```

Setting Up Gemmini
------------------

Run the steps below to set up Gemmini configuration files, symlinks, and subdirectories:

```shell
cd chipyard/generators/gemmini
./scripts/setup-paths.sh
```

Expand Down
72 changes: 38 additions & 34 deletions src/main/scala/gemmini/Arithmetic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import chisel3.util._
import hardfloat._

// Bundles that represent the raw bits of custom datatypes
case class Float(expWidth: Int, sigWidth: Int) extends Bundle {
val bits = UInt((expWidth + sigWidth).W)
case class Float(expWidth: Int, sigWidth: Int, isRecoded: Boolean = false) extends Bundle {
val bits = UInt((expWidth + sigWidth + (if (isRecoded) 1 else 0)).W)

val bias: Int = (1 << (expWidth-1)) - 1
}
Expand Down Expand Up @@ -245,7 +245,7 @@ object Arithmetic {
}

override def reciprocal[U <: Data](u: U): Option[(DecoupledIO[UInt], DecoupledIO[U])] = u match {
case Float(expWidth, sigWidth) =>
case Float(expWidth, sigWidth, false) =>
val input = Wire(Decoupled(UInt(0.W)))
val output = Wire(Decoupled(u.cloneType))

Expand Down Expand Up @@ -287,7 +287,7 @@ object Arithmetic {
}

override def mult_with_reciprocal[U <: Data](reciprocal: U): SInt = reciprocal match {
case recip @ Float(expWidth, sigWidth) =>
case recip @ Float(expWidth, sigWidth, false) =>
def in_to_float(x: SInt) = {
val in_to_rec_fn = Module(new INToRecFN(intWidth = self.getWidth, expWidth, sigWidth))
in_to_rec_fn.io.signedIn := true.B
Expand Down Expand Up @@ -328,12 +328,10 @@ object Arithmetic {
}

implicit object FloatArithmetic extends Arithmetic[Float] {
// TODO Floating point arithmetic currently switches between recoded and standard formats for every operation. However, it should stay in the recoded format as it travels through the systolic array

override implicit def cast(self: Float): ArithmeticOps[Float] = new ArithmeticOps(self) {
override def *(t: Float): Float = {
val t_rec = recFNFromFN(t.expWidth, t.sigWidth, t.bits)
val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits)
val t_rec = if (t.isRecoded) t.bits else recFNFromFN(t.expWidth, t.sigWidth, t.bits)
val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits)

val t_resizer = Module(new RecFNToRecFN(t.expWidth, t.sigWidth, self.expWidth, self.sigWidth))
t_resizer.io.in := t_rec
Expand All @@ -351,16 +349,16 @@ object Arithmetic {
muladder.io.b := t_rec_resized
muladder.io.c := 0.U

val out = Wire(Float(self.expWidth, self.sigWidth))
out.bits := fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out)
val out = Wire(Float(self.expWidth, self.sigWidth, self.isRecoded))
out.bits := (if (out.isRecoded) muladder.io.out else fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out))
out
}

override def mac(m1: Float, m2: Float): Float = {
// Recode all operands
val m1_rec = recFNFromFN(m1.expWidth, m1.sigWidth, m1.bits)
val m2_rec = recFNFromFN(m2.expWidth, m2.sigWidth, m2.bits)
val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits)
val m1_rec = if (m1.isRecoded) m1.bits else recFNFromFN(m1.expWidth, m1.sigWidth, m1.bits)
val m2_rec = if (m2.isRecoded) m2.bits else recFNFromFN(m2.expWidth, m2.sigWidth, m2.bits)
val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits)

// Resize m1 to self's width
val m1_resizer = Module(new RecFNToRecFN(m1.expWidth, m1.sigWidth, self.expWidth, self.sigWidth))
Expand Down Expand Up @@ -388,17 +386,17 @@ object Arithmetic {
muladder.io.c := self_rec

// Convert result to standard format // TODO remove these intermediate recodings
val out = Wire(Float(self.expWidth, self.sigWidth))
out.bits := fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out)
val out = Wire(Float(self.expWidth, self.sigWidth, self.isRecoded))
out.bits := (if (out.isRecoded) muladder.io.out else fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out))
out
}

override def +(t: Float): Float = {
require(self.getWidth >= t.getWidth) // This just makes it easier to write the resizing code

// Recode all operands
val t_rec = recFNFromFN(t.expWidth, t.sigWidth, t.bits)
val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits)
val t_rec = if (t.isRecoded) t.bits else recFNFromFN(t.expWidth, t.sigWidth, t.bits)
val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits)

// Generate 1 as a float
val in_to_rec_fn = Module(new INToRecFN(1, self.expWidth, self.sigWidth))
Expand Down Expand Up @@ -427,8 +425,8 @@ object Arithmetic {
muladder.io.b := one_rec
muladder.io.c := self_rec

val result = Wire(Float(self.expWidth, self.sigWidth))
result.bits := fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out)
val result = Wire(Float(self.expWidth, self.sigWidth, self.isRecoded))
result.bits := (if (result.isRecoded) muladder.io.out else fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out))
result
}

Expand All @@ -440,7 +438,7 @@ object Arithmetic {

override def >>(u: UInt): Float = {
// Recode self
val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits)
val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits)

// Get 2^(-u) as a recoded float
val shift_exp = Wire(UInt(self.expWidth.W))
Expand All @@ -461,15 +459,15 @@ object Arithmetic {
muladder.io.b := shift_rec
muladder.io.c := 0.U

val result = Wire(Float(self.expWidth, self.sigWidth))
result.bits := fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out)
val result = Wire(Float(self.expWidth, self.sigWidth, self.isRecoded))
result.bits := (if (result.isRecoded) muladder.io.out else fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out))
result
}

override def >(t: Float): Bool = {
// Recode all operands
val t_rec = recFNFromFN(t.expWidth, t.sigWidth, t.bits)
val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits)
val t_rec = if (t.isRecoded) t.bits else recFNFromFN(t.expWidth, t.sigWidth, t.bits)
val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits)

// Resize t to self's width
val t_resizer = Module(new RecFNToRecFN(t.expWidth, t.sigWidth, self.expWidth, self.sigWidth))
Expand All @@ -487,43 +485,49 @@ object Arithmetic {
}

override def withWidthOf(t: Float): Float = {
val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits)
val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits)

val resizer = Module(new RecFNToRecFN(self.expWidth, self.sigWidth, t.expWidth, t.sigWidth))
resizer.io.in := self_rec
resizer.io.roundingMode := consts.round_near_even // consts.round_near_maxMag
resizer.io.detectTininess := consts.tininess_afterRounding

val result = Wire(Float(t.expWidth, t.sigWidth))
result.bits := fNFromRecFN(t.expWidth, t.sigWidth, resizer.io.out)
val result = Wire(Float(t.expWidth, t.sigWidth, t.isRecoded))
result.bits := (if (result.isRecoded) resizer.io.out else fNFromRecFN(t.expWidth, t.sigWidth, resizer.io.out))
result
}

override def clippedToWidthOf(t: Float): Float = {
// TODO check for overflow. Right now, we just assume that overflow doesn't happen
val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits)
val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits)

val resizer = Module(new RecFNToRecFN(self.expWidth, self.sigWidth, t.expWidth, t.sigWidth))
resizer.io.in := self_rec
resizer.io.roundingMode := consts.round_near_even // consts.round_near_maxMag
resizer.io.detectTininess := consts.tininess_afterRounding

val result = Wire(Float(t.expWidth, t.sigWidth))
result.bits := fNFromRecFN(t.expWidth, t.sigWidth, resizer.io.out)
val result = Wire(Float(t.expWidth, t.sigWidth, t.isRecoded))
result.bits := (if (result.isRecoded) resizer.io.out else fNFromRecFN(t.expWidth, t.sigWidth, resizer.io.out))
result
}

override def relu: Float = {
val raw = rawFloatFromFN(self.expWidth, self.sigWidth, self.bits)
val raw = if (self.isRecoded) rawFloatFromRecFN(self.expWidth, self.sigWidth, self.bits) else rawFloatFromFN(self.expWidth, self.sigWidth, self.bits)

val result = Wire(Float(self.expWidth, self.sigWidth))
val result = Wire(Float(self.expWidth, self.sigWidth, self.isRecoded))
result.bits := Mux(!raw.isZero && raw.sign, 0.U, self.bits)
result
}

override def zero: Float = 0.U.asTypeOf(self)
override def identity: Float = Cat(0.U(2.W), ~(0.U((self.expWidth-1).W)), 0.U((self.sigWidth-1).W)).asTypeOf(self)
override def minimum: Float = Cat(1.U, ~(0.U(self.expWidth.W)), 0.U((self.sigWidth-1).W)).asTypeOf(self)
override def identity: Float = {
require(!self.isRecoded)
Cat(0.U(2.W), ~(0.U((self.expWidth-1).W)), 0.U((self.sigWidth-1).W)).asTypeOf(self)
}
override def minimum: Float = {
require(!self.isRecoded)
Cat(1.U, ~(0.U(self.expWidth.W)), 0.U((self.sigWidth-1).W)).asTypeOf(self)
}
}
}

Expand Down
7 changes: 6 additions & 1 deletion src/main/scala/gemmini/Configs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import freechips.rocketchip.rocket._
import freechips.rocketchip.tile._
import freechips.rocketchip.system._
import freechips.rocketchip.diplomacy._

import gemmini.Arithmetic.SIntArithmetic
import hardfloat._

Expand All @@ -22,8 +21,11 @@ object GemminiConfigs {
val defaultConfig = GemminiArrayConfig[SInt, Float, Float](
// Datatypes
inputType = SInt(8.W),
weightType = SInt(8.W),
accType = SInt(32.W),

spatialArrayInputType = SInt(8.W),
spatialArrayWeightType = SInt(8.W),
spatialArrayOutputType = SInt(20.W),

// Spatial array size options
Expand Down Expand Up @@ -166,7 +168,10 @@ object GemminiConfigs {

val dummyConfig = GemminiArrayConfig[DummySInt, Float, Float](
inputType = DummySInt(8),
weightType = DummySInt(8),
accType = DummySInt(32),
spatialArrayInputType = DummySInt(8),
spatialArrayWeightType = DummySInt(8),
spatialArrayOutputType = DummySInt(20),
tileRows = defaultConfig.tileRows,
tileColumns = defaultConfig.tileColumns,
Expand Down
Loading