# EE194 Lab 0: Chisel - Part 4: Scala Functional Programming
> AKA: Scala meta-programming magic ✨

Make sure you fill in any place that says `YOUR CODE HERE`, `YOUR ANSWER HERE`, or `YOUR ACTION NEEDED HERE`.

If you see `???` right below `YOUR CODE HERE`, remove it once you have implemented your solution (and before you run the code cell). When executed, `???` throws a `NotImplementedError`.

### Import the necessary Chisel dependencies
> There will be cells like these in every lab. Make sure you run them before proceeding to bring the Chisel Library into the Jupyter Notebook scope!

In [None]:
interp.configureCompiler(_.settings.processArguments(List("-Wconf:cat=deprecation:s"), true))
interp.load.module(os.Path(s"${System.getProperty("user.dir")}/resource/chisel_deps.sc"))

In [None]:
import chisel3._
import chisel3.util._
import chiseltest._
import chiseltest.RawTester.test

## Scala functions
We can write Chisel logic inside ordinary Scala functions and call those functions from our modules. The hardware is generated wherever the function is called.

> Write a Scala function named `sum` that takes two `UInt`s as arguments and returns their sum with width growth (the result must be wide enough to hold the carry-out). You might find the [Chisel Cheat Sheet](https://github.com/freechipsproject/chisel-cheatsheet/releases/latest/download/chisel_cheatsheet.pdf) a helpful reference.

In [None]:
// YOUR CODE HERE
???

In [None]:
class ScalaFunctions extends Module {
    val io = IO(new Bundle {
        val in1 = Input(UInt(4.W))
        val in2 = Input(UInt(4.W))
        val out = Output(UInt())
    })
    io.out := sum(io.in1, io.in2)
}

In [None]:
def testSum(): Boolean = {
    test(new ScalaFunctions) { dut => 
        for (in1 <- 0 until 16) {
            for (in2 <- 0 until 16) {
                dut.io.in1.poke(in1.U)
                dut.io.in2.poke(in2.U)
                dut.io.out.expect((in1 + in2).U)
            }
        }
    }
    true
}

assert(testSum())

## Chisel without Modules - Companion Objects
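A _companion object_ (an `object` that shares its name with a class and is defined alongside it) can provide a _factory method_, conventionally named `apply`, that **returns a properly instantiated module**. Factory methods like this are useful for tidying up repetitive IO connections.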

More generally, an `object` can hold:
* Shared state (constant or mutable)
* Stateless functions
* Factory methods (as companion object)
* ChiselEnums

The example below shows a Scala class together with its companion object, just to get you familiar with the syntax.

In [None]:
class MyPair(a: Int, b: Int) {
    def sum() = a + b
}

val mpc = new MyPair(3,4)
mpc.sum()
mpc.sum

object MyPair {
    var numPairs = 0
    def apply(a: Int, b: Int) = {
        numPairs += 1
        new MyPair(a,b)
    }
    def apply(a: Int): MyPair = apply(a, 0)
}

MyPair(2,3)
MyPair(2,3).sum
MyPair.numPairs
val mpo = MyPair(3)
mpo.sum
MyPair.numPairs

> Write an `apply` method that creates an instance of the `ManyConnections` module, connects all of its inputs, and returns the module's output as a `Bool`.

In [None]:
class ManyConnections extends Module {
    val io = IO(new Bundle {
        val in0  = Input(Bool())
        val in1  = Input(Bool())
        val in2  = Input(Bool())
        val in3  = Input(Bool())
        val out = Output(Bool())
    })
    io.out := io.in0 & io.in1 & io.in2 & io.in3
    // you can have additional Chisel logic here -- this functions just like the other Chisel modules you've written.
}

object ManyConnections {
// YOUR CODE HERE
???

}

In [None]:
class UseManyConn extends Module {
    val io = IO(new Bundle {
        val in0 = Input(Bool())
        val in1 = Input(Bool())
        val in2 = Input(Bool())
        val in3 = Input(Bool())
        val out = Output(Bool())
    })
    
    // Calling ManyConnections(...) invokes the companion object's apply method, which instantiates the module,
    // connects the given signals to its inputs, and returns whatever value apply chooses (here, the module's output).
    val one = ManyConnections(true.B, true.B, true.B, true.B)
    val and = ManyConnections(io.in0, io.in1, io.in2, io.in3)
    
//     Without a companion object, every use of the module would need its own copy of the connections below.
//     With apply, these connections are written once (inside the object) and reused for every instance of the module:
    
//     val m0 = Module(new ManyConnections)
//     m0.io.in0 := io.in0
//     m0.io.in1 := io.in1
//     m0.io.in2 := io.in2
//     m0.io.in3 := io.in3
//     val and = m0.io.out
    io.out := one & and
}

In [None]:
def testManyConnections: Boolean = {
    test(new ManyConnections) { dut =>
        dut.io.in0.poke(true.B)
        dut.io.in1.poke(true.B)
        dut.io.in2.poke(true.B)
        dut.io.in3.poke(true.B)
        dut.io.out.expect(true.B)

        dut.io.in0.poke(false.B)
        dut.io.in1.poke(true.B)
        dut.io.in2.poke(true.B)
        dut.io.in3.poke(true.B)
        dut.io.out.expect(false.B)

    }
    true
}
assert(testManyConnections)

## Chisel without Modules - Accumulator
> So far, we have wrapped our code in a class that extends `Module` whenever we wanted to write Chisel code. However, ordinary Scala classes, objects, and functions can also construct and return Chisel hardware, as long as they are called from within a `Module`. Write a class `Accumulator` that does **not** extend `Module`. It should contain a `width`-bit `UInt` register, initialized to 0 and exposed as a field named `count` (the companion object below returns `m.count`), that adds `data` to its value every clock cycle and wraps around on overflow. When `rst` is high, the accumulated value should be set to `0.U` on the next clock edge. We have provided a companion object to instantiate your class.

In [None]:
object Accumulator {
    def apply(width: Int, data: UInt, rst: Bool) = {
        val m = new Accumulator(width, data, rst)
        m.count
    }
}

// YOUR CODE HERE
???

In [None]:
class AccumulatorInstMod(width: Int) extends Module {
    val io = IO(new Bundle {
        val data  = Input(UInt(width.W))
        val rst   = Input(Bool())
        val count = Output(UInt(width.W))
    })
    io.count := Accumulator(width, io.data, io.rst)
}

def testAccumulator: Boolean = {
    test(new AccumulatorInstMod(4)) { dut =>
        dut.io.data.poke(5.U)
        dut.io.rst.poke(false.B)
        dut.io.count.expect(0.U)
        dut.clock.step()
        
        dut.io.data.poke(6.U)
        dut.io.rst.poke(false.B)
        dut.io.count.expect(5.U)
        dut.clock.step()
        
        dut.io.data.poke(7.U)
        dut.io.rst.poke(false.B)
        dut.io.count.expect(11.U)
        dut.clock.step()
        
        dut.io.data.poke(0.U)
        dut.io.rst.poke(true.B)
        dut.io.count.expect(2.U)
        dut.clock.step()
        
        dut.io.count.expect(0.U)
    }
    true
}
assert(testAccumulator)

## Case Classes
* Special type of class with additional features built in:
    * An automatically generated companion object with an `apply` factory method (no `new` needed to instantiate)
    * All constructor parameters are automatically public (no need to declare them as `val`)
    * Automatic implementations of `toString`, `equals`, `hashCode`, and `copy`, plus support for pattern matching
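A short sketch of these built-in features, using a hypothetical `Point` case class that is not part of the lab. The values in the comments follow from the standard generated `toString`, `equals`, and `copy`.

```scala
// Hypothetical case class used only to demonstrate the generated features.
case class Point(x: Int, y: Int)

val p1 = Point(1, 2)       // no `new`: the generated companion's apply is used
val p2 = Point(1, 2)
val sameValue = p1 == p2   // structural equality compares fields: true
val text = p1.toString     // "Point(1,2)"
val moved = p1.copy(y = 5) // Point(1,5); p1 itself is unchanged
val px = p1.x              // constructor parameters are public fields

// The generated unapply lets a case class be taken apart by pattern matching
def describe(p: Point): String = p match {
  case Point(0, 0) => "origin"
  case Point(x, 0) => s"on x-axis at $x"
  case Point(_, _) => "elsewhere"
}
```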

Example Scala syntax for reference:

In [None]:
case class Movie(name: String, year: Int, genre: String) {
    def decade: String = s"${year - year % 10}s" // e.g. 1997 -> "1990s"
}
}

val m1 = Movie("Gattaca", 1997, "drama")
m1.genre
val m2 = Movie("The Avengers", 1998, "action")
m2.copy(year=2012)
m2.decade

> Scala's _case classes_ are convenient for bundling related parameters in one place. The case class `ROMParams` packages a `Seq` of ROM contents together with parameters derived from it, and is passed to the `ROM` module's constructor. Complete both `ROMParams` and `ROM`. From the code provided (including `ROMIO` and the test), you should be able to infer the missing field names, size the bit widths appropriately, and instantiate a read-only memory.
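A case class body can also compute derived parameters from its constructor arguments and validate them with `require`. The sketch below uses a hypothetical `CacheParams` class, unrelated to the `ROM` exercise, to show the pattern.

```scala
// Hypothetical parameter bundle (not part of the lab). The require checks run
// when an instance is constructed; the derived vals are computed once per instance.
case class CacheParams(numSets: Int, numWays: Int, blockBytes: Int) {
  require(numSets > 0 && (numSets & (numSets - 1)) == 0, "numSets must be a power of two")
  require(numWays > 0, "numWays must be positive")
  require(blockBytes > 0, "blockBytes must be positive")

  // Total storage in bytes
  val capacityBytes: Int = numSets * numWays * blockBytes
  // Index bits needed to select one of numSets sets (numSets is a power of two)
  val indexBits: Int = Integer.numberOfTrailingZeros(numSets)
}

val small = CacheParams(numSets = 64, numWays = 4, blockBytes = 64)
val bigger = small.copy(numWays = 8) // derived fields are recomputed for the copy
```

Because the derived `val`s are evaluated in the constructor, `copy` yields a new instance whose derived fields match its new arguments, and an invalid configuration fails as soon as it is built rather than later during elaboration.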

In [None]:
case class ROMParams(data: Seq[Int]) {
    val numElems = data.size
    val largestElem = data.max
    val dataInChiselT: Seq[UInt] = Seq.tabulate(data.size)(i => data(i).U)
    // YOUR CODE HERE
    ???
}

class ROMIO (p: ROMParams) extends Bundle {
    val sel = Input(UInt(p.addrWidth.W))
    val out = Output(UInt(p.elemWidth.W))
}

class ROM (p: ROMParams) extends Module {
    val io = IO(new ROMIO(p))
    // YOUR CODE HERE
    ???
}

In [None]:
def testROM(l: Seq[Int]): Boolean = {
    val p = ROMParams(l)
    test(new ROM(p)) { dut =>
        for (i <- 0 until l.size) {
            dut.io.sel.poke(i.U)
            dut.io.out.expect(l(i).U)
        }
        assert(p.addrWidth == log2Ceil(l.size + 1))
        assert(p.elemWidth == log2Ceil(l.max + 1))
    }
    true
}
assert(testROM((0 until 5).toSeq))
assert(testROM((20 until 31).toSeq))

## Applying Scala Methods to Chisel
In the following sections we'll focus on how Scala's collection methods can supercharge Chisel. We'll use quite a few of them; if you are unfamiliar with the syntax of a particular method, look it up in the Scala documentation.

### Seq addition
> Use Scala `zip` and `map` to add the contents of two `Seq`s (element by element).
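As an illustration of `zip` followed by `map` (using a different operation than the exercise asks for), the hypothetical `dotProduct` function below multiplies corresponding elements and then sums the products.

```scala
// Hypothetical helper (not the addSeqs solution): zip pairs elements by position,
// map combines each pair, and sum adds up the results.
def dotProduct(a: Seq[Int], b: Seq[Int]): Int = {
  require(a.size == b.size, "Seqs must have the same length")
  a.zip(b).map { case (x, y) => x * y }.sum
}

val pairs = Seq(1, 2, 3).zip(Seq(4, 5, 6))       // Seq((1,4), (2,5), (3,6))
val dp = dotProduct(Seq(1, 2, 3), Seq(4, 5, 6)) // 1*4 + 2*5 + 3*6 = 32
```

`zip` silently truncates to the length of the shorter sequence, so the sketch checks lengths with `require` rather than quietly ignoring extra elements.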

In [None]:
def addSeqs(a: Seq[Int], b: Seq[Int]): Seq[Int] = {
    // YOUR CODE HERE
    ???
}

In [None]:
val a = Seq.tabulate(8)(i => i) // Seq(0, 1, 2, 3, 4, 5, 6, 7)
val b = Seq.tabulate(8)(i => i)
assert(addSeqs(a, b) == Seq(0, 2, 4, 6, 8, 10, 12, 14))


## foreach with Chisel
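> The `VecRotate` module below rotates its input `Vec` left by a constant `offset`: elements shifted off the front wrap around to the back. The rotated sequence is already computed for you with the Scala methods `drop` and `take`; complete the module by using `foreach` to connect it to `io.out`.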
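The same `drop`/`take` rotation works on any Scala `Seq`. The plain-Scala sketch below rotates a `Seq[Int]` and then uses `foreach` to write each element into a destination `Array`, which stands in for the output `Vec` purely for illustration; in the module, each write would be a Chisel connection instead of an array assignment.

```scala
// Rotate left by `offset`: the first `offset` elements move to the back.
def rotateLeft(src: Seq[Int], offset: Int): Seq[Int] =
  src.drop(offset) ++ src.take(offset)

val src = Seq(10, 20, 30, 40)
val rotated = rotateLeft(src, 1) // Seq(20, 30, 40, 10)

// foreach runs a side-effecting action per element; here it fills an Array
// (a stand-in for the module's output) slot by slot.
val dst = Array.fill(src.size)(0)
rotated.zipWithIndex.foreach { case (value, i) => dst(i) = value }
```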

In [None]:
class VecRotate(numElems: Int, width: Int, offset: Int) extends Module {
    val io = IO(new Bundle {
        val in  = Input(Vec(numElems, UInt(width.W)))
        val out = Output(Vec(numElems, UInt(width.W)))
    })
    val rotated = io.in.drop(offset) ++ io.in.take(offset)
    // YOUR CODE HERE
    ???
}

In [None]:
def testVecRotate(numElems: Int, width: Int): Boolean = {
    for (offset <- 0 until numElems) {
        val input = 0 until numElems
        val expected = input.drop(offset) ++ input.take(offset)
        test(new VecRotate(numElems, width, offset)) { dut =>
            (0 until numElems).foreach{ i => dut.io.in(i).poke(input(i).U) }
            (0 until numElems).foreach{ i => dut.io.out(i).expect(expected(i).U) }
        }
    }
    true
}

assert(testVecRotate(4,8))

## zipWithIndex
> First, use `foldLeft` to implement the `exp` function, which computes `base` raised to the power `deg` (with `exp(b, 0) == 1`). Then use `zipWithIndex`, `map`, `exp`, and `reduce`/`foldLeft` to concisely evaluate a polynomial. Each coefficient's index in the sequence is its degree in the polynomial, i.e. element `i` contributes _coefs(i) * x^i_.
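A few plain-Scala illustrations of the methods named above (these are not the `exp`/`polyEval` solutions): `foldLeft` threads an accumulator through a collection starting from an initial value, `reduce` does the same without an initial value, and `zipWithIndex` pairs each element with its position.

```scala
// foldLeft: n! = 1 * 2 * ... * n; the initial value 1 also covers 0! = 1
def factorial(n: Int): Int = (1 to n).foldLeft(1)((acc, k) => acc * k)

// reduce has no initial value: it combines elements pairwise and throws on an empty Seq
val largest = Seq(3, 1, 4).reduce((x, y) => x max y) // 4
// foldLeft's initial value makes the empty case well defined
val emptySum = Seq.empty[Int].foldLeft(0)(_ + _) // 0

// zipWithIndex pairs each element with its position
val tagged = Seq("a", "b", "c").zipWithIndex // Seq(("a",0), ("b",1), ("c",2))
// Positions of the even elements, read from the index that zipWithIndex supplies
val evenAt = Seq(5, 8, 3, 6).zipWithIndex.collect { case (v, i) if v % 2 == 0 => i } // Seq(1, 3)
```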

In [None]:
def exp(base: Int, deg: Int): Int = {
    // YOUR CODE HERE
    ???
}

def polyEval(coefs: Seq[Int], x: Int): Int = {
    // YOUR CODE HERE
    ???
}

In [None]:
assert (exp(5, 0) == 1)
assert (exp(2, 5) == 32)
assert (exp(4, 3) == 64)
// 0*x^0 + 1*x^1 + 2*x^2
assert(polyEval(Seq(0, 1, 2), 5) == 55)
assert(polyEval(Seq(0, 1, 2), 0) == 0)


## Problem 4 (6 pts) - map on matrix
> Given an `n` x `n` matrix (`Seq[Seq[Int]]`), write a function `incDiag` that uses `map` and `zipWithIndex` to add `x` to each element on the main diagonal, leaving all other elements unchanged. The matrix is stored in row-major order (each inner `Seq` is one row). For example, if `x=4`:
``` 
    List(1, 1, 1, 1, 1) -> List(5, 1, 1, 1, 1)
    List(1, 1, 1, 1, 1) -> List(1, 5, 1, 1, 1)
    List(1, 1, 1, 1, 1) -> List(1, 1, 5, 1, 1)
    List(1, 1, 1, 1, 1) -> List(1, 1, 1, 5, 1)
    List(1, 1, 1, 1, 1) -> List(1, 1, 1, 1, 5)
```
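A sketch of the nested-`map` pattern, using a hypothetical `addRowIndex` function that is not the `incDiag` solution: the outer `zipWithIndex` supplies each row's index `r`, and the inner `map` rewrites every element of that row.

```scala
// Hypothetical helper: add each row's index to every element in that row.
// Rows are the inner Seqs (row-major order).
def addRowIndex(mat: Seq[Seq[Int]]): Seq[Seq[Int]] =
  mat.zipWithIndex.map { case (row, r) => row.map(v => v + r) }

val shifted = addRowIndex(Seq(Seq(1, 1), Seq(1, 1), Seq(1, 1)))
// Seq(Seq(1, 1), Seq(2, 2), Seq(3, 3))
```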

In [None]:
// YOUR CODE HERE
???

In [None]:
val in = List(
  List(1, 1, 1, 1, 1),
  List(1, 1, 1, 1, 1),
  List(1, 1, 1, 1, 1),
  List(1, 1, 1, 1, 1),
  List(1, 1, 1, 1, 1)
)

val out = List(
  List(5, 1, 1, 1, 1),
  List(1, 5, 1, 1, 1),
  List(1, 1, 5, 1, 1),
  List(1, 1, 1, 5, 1),
  List(1, 1, 1, 1, 5)
)
assert(incDiag(in, 4) == out)



## flatMap and reduce with Chisel
> Let's put together what we've covered to make a Chisel module. Complete the `MatrixSearch` module below that looks for the input `searchFor` by comparing all of the elements of the input matrix `mat` (2D `Vec`). If (and only if) `searchFor` matches any of the elements in `mat`, the output `found` should be _true_. Your solution should use `flatMap`, `reduce`, and possibly `map`.
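A plain-Scala sketch of `flatMap` and `reduce` on a nested `Seq` (finding a maximum rather than performing the search the exercise asks for). With the identity function, `flatMap` flattens the matrix row by row, and `reduce` collapses the flattened elements into one value.

```scala
val grid = Seq(Seq(3, 9, 2), Seq(7, 1, 8))

// flatMap maps each row to a collection and concatenates the results
val flat = grid.flatMap(row => row) // Seq(3, 9, 2, 7, 1, 8)

// reduce then combines the flattened elements into a single value
val maxElem = grid.flatMap(row => row).reduce((x, y) => x max y) // 9

// flatMap can also transform while flattening: tag each element with its row index
val withRow = grid.zipWithIndex.flatMap { case (row, r) => row.map(v => (r, v)) }
// Seq((0,3), (0,9), (0,2), (1,7), (1,1), (1,8))
```

Chisel `Vec`s support these collection methods too; the `VecRotate` module earlier in this lab already calls `drop` and `take` on a `Vec`.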

In [None]:
class MatrixSearch(numRows: Int, numCols: Int, width: Int) extends Module {
    require(numRows > 1)
    require(numCols > 1)
    val io = IO(new Bundle {
        val mat = Input(Vec(numRows, Vec(numCols, UInt(width.W))))
        val searchFor = Input(UInt(width.W))
        val found = Output(Bool())
    })
    // YOUR CODE HERE
    ???
}

In [None]:
def testMatrixSearch(numRows: Int, numCols: Int, width: Int): Boolean = {
    require(log2Ceil(numRows) < width)
    test(new MatrixSearch(numRows, numCols, width)) { dut =>
        for (r <- 0 until numRows; c <- 0 until numCols) {
            dut.io.mat(r)(c).poke(r.U) // every element of row r holds the value r
        }
        for (r <- 0 until numRows) {
            dut.io.searchFor.poke(r.U)
            dut.io.found.expect(true.B)
        }
        dut.io.searchFor.poke(numRows.U)
        dut.io.found.expect(false.B)
    }
    true
}

assert(testMatrixSearch(2,2,8))
