# Types and Pattern Matching

## Decomposition

- Let's motivate this with an example: suppose we want to write an interpreter for arithmetic expressions

- For simplicity, let's only deal with two kinds of expressions: a `Number` or a `Sum`
  - That is, the evaluator receives either a number, or a `Sum` of two sub-expressions (each of which is itself an `Expr`)

  ```scala
    trait Expr:
      def isNumber: Boolean
      def isSum: Boolean
      def numValue: Int
      def leftOp: Expr
      def rightOp: Expr

    class Number(n: Int) extends Expr:
      def isNumber = true
      def isSum = false
      def numValue = n
      def leftOp = throw new Error("Number.leftOp")
      def rightOp = throw new Error("Number.rightOp")

    class Sum(e1: Expr, e2: Expr) extends Expr:
      def isNumber = false
      def isSum = true
      def numValue = throw new Error("Sum.numValue")
      def leftOp = e1
      def rightOp = e2
  ```


- Let's try to implement the evaluator `eval`
  ```scala
    def eval(e: Expr): Int = 
      if e.isNumber then e.numValue
      else if e.isSum then eval(e.leftOp) + eval(e.rightOp)
      else throw new Error("Unknown expression " + e)
  ```

- This works, but what if we want to add operators like multiply, subtract or divide?
  - It quickly becomes messy
  - Writing the evaluator also becomes error-prone: before accessing a member you must check which kind of `Expr` you hold, or you may call an accessor that is not available (e.g. `numValue` on a `Sum` throws an `Error`)

  ```scala
  class Prod(e1: Expr, e2: Expr) extends Expr

  class Var(x: String) extends Expr

  ```

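- Exercise: To integrate `Prod` and `Var` into the hierarchy, how many new method definitions do you need?
  - New classification methods: `isProd`, `isVar`
  - [Optional] a new accessor for `Var`, e.g. `varValue`
  - Every new method must be declared in `Expr` and defined in all four classes, and the two new classes must also define the five existing methods
  - With all three new methods that is 8 methods in each of 5 types (the trait plus 4 classes) = 40 definitions, up from 5 methods in each of 3 types = 15. So you are adding 25 new method definitions for a basic change!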

### (Bad) Solution 1: Type Tests and Type Casts

- In `eval`, you can explicitly test the runtime type of the `Expr` and cast it before evaluating

- This is low-level and potentially unsafe: a cast can fail at runtime, and every new subclass needs another branch

  ```scala
    def eval(e: Expr): Int = {
      if e.isInstanceOf[Number] then
        e.asInstanceOf[Number].numValue
      else if e.isInstanceOf[Sum] then
        val s = e.asInstanceOf[Sum] // cast first, mirroring the Number branch
        eval(s.leftOp) + eval(s.rightOp)
      else
        throw Error("Unknown Expression " + e)
    }
  ```

### Solution 2: Object-Oriented Decomposition

- Instead of a single `eval` function, you could declare `eval` as an abstract method in the `Expr` trait, then implement it in each of the subclasses

  ```scala
    trait Expr:
      def eval: Int

    class Number(n: Int) extends Expr:
      def eval: Int = n

    class Sum(e1: Expr, e2: Expr) extends Expr:
      def eval: Int = e1.eval + e2.eval
  ```

- This is known as object-oriented decomposition, where you mix the data object with the relevant operations

- Pros: To add a new class of data, you can just add a single class

  ```scala
    class Product(e1: Expr, e2: Expr) extends Expr:
      def eval: Int = e1.eval * e2.eval
  ```

- Cons:
  - Adding a new operation (e.g. `show`) means adding a new method to every class in the hierarchy
  - Some operations don't work on a single object. e.g. I want to simplify `a*b + a*c` into `a * (b+c)`
    - This operation is non-local (i.e. it needs to inspect more than 1 object in the tree)
    - Therefore it cannot be encapsulated in a method of a single class; it still needs test and access methods for the different subclasses, which is what we desperately want to avoid

## Pattern Matching

- Recall from the previous section, we are trying to find a general way to access objects in a class hierarchy

- In the last section, we tried
  - Adding classification and accessor methods to all classes ==> led to a quadratic explosion in the number of methods
  - Adding type tests and casts ==> unmaintainable code + potentially unsafe
  - Object-oriented decomposition ==> couples data and operations, all classes affected when adding a new method

- Let's try to generalise what we want to do; basically we need the same function to do different things depending on the class of the value it is given
  - In this case, the type checks/decomposition etc are all trying to **reverse** the construction process (i.e. figure out which subclass was used and what the arguments were)

- Thankfully in Scala, there is an idiomatic way to do this via **case classes**

- This can be applied in 2 steps. First, define the relevant `case class`. Then, use the keyword `match` to check if the input matches the `case class`

  ```scala
    trait Expr
    case class Number(n: Int) extends Expr
    case class Sum(e1: Expr, e2: Expr) extends Expr

    def eval(e: Expr): Int = e match
      case Number(n) => n
      case Sum(e1, e2) => eval(e1) + eval(e2)
      case _ => throw new Error("Unknown expression " + e)
  ```

- Patterns are built from:
  - constructors e.g. `Number(n)`, `Sum(e1, e2)`
  - variables e.g. `n`, `e1`, `e2`
  - the wildcard pattern `_`
  - constants e.g. `1`, `true`
  - type tests e.g. `n: Number`
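
- As a rough illustration of the pattern forms listed above, the sketch below uses each of them once. The function `describe` and its output strings are made up for this example (they are not from the course material), and `Expr`, `Number` and `Sum` are repeated so the block is self-contained

  ```scala
    trait Expr
    case class Number(n: Int) extends Expr
    case class Sum(e1: Expr, e2: Expr) extends Expr

    // `describe` is a hypothetical helper used only to show the pattern forms
    def describe(x: Any): String = x match
      case true              => "the constant true"          // constant pattern
      case Sum(Number(0), r) => s"zero plus $r"               // nested constructors, a constant and a variable
      case Sum(_, _)         => "some other sum"              // constructor with wildcards
      case num: Number       => s"a number holding ${num.n}"  // type test pattern
      case _                 => "something else"              // wildcard catches everything else
  ```

  - Cases are tried from top to bottom, so the more specific `Sum(Number(0), r)` has to come before the general `Sum(_, _)`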

- The evaluation resolves in the following order
  ```scala
    //1
    eval(Sum(Number(1), Number(2))) 

    //2
    Sum(Number(1), Number(2)) match
      case Number(n) => n
      case Sum(e1, e2) => eval(e1) + eval(e2)

    //3
    eval(Number(1)) + eval(Number(2))

    //4
    Number(1) match
      case Number(n) => n
      case Sum(e1, e2) => eval(e1) + eval(e2)
    + eval(Number(2))

    //5
    1 + eval(Number(2))

    //6
    3
  ```



- As a convenient alternative, you can place the `eval` function under the `Expr` trait
  ```scala
  trait Expr:
    def eval: Int = this match
      case Number(n) => n
      case Sum(e1, e2) => e1.eval + e2.eval
  ```

### Exercise

- Write a function `show` that uses pattern matching to return the representation of a given expression as a string.

  ```scala
    def show(e: Expr): String = ???
  ```

- See `2-slides.scala`

### Exercise 2

- Add case classes `Var` for variables `x` and `Prod` for products `x * y` as discussed previously.

- Change your `show` function so that it also deals with products. Make sure you get operator precedence right, but use as few parentheses as possible

- Examples
  - `Sum(Prod(Number(2), Var("x")), Var("y"))` ==> `"2 * x + y"`
  - `Prod(Sum(Number(2), Var("x")), Var("y"))` ==> `"(2 + x) * y"`

- See `2-slides.scala`

## Lists

- You should have noticed by now that most of the functions we write in Scala are recursive. This is because in the functional programming paradigm you want values to be immutable as far as possible: instead of updating a variable in a loop, you build a new value from the result of a recursive call

- Let's talk about a common data structure in Scala, the `List`

- `List` vs `Array`
  - Same
    - Both are homogeneous: all elements share one element type, which is part of the collection's type (e.g. `List[Int]`)

  - Different
    - Lists are immutable, arrays are not
    - Lists are recursive, arrays are flat
      - i.e. a list is either empty (`Nil`) or a head element followed by a tail that is itself a list, whereas an array is a flat block of elements

- Lists are built with the operator `::`, pronounced **cons** (short for construction)
  - To get a new list, you can do something like `value :: existing_list`
  - `::` associates to the right, so `1 :: 2 :: xs` means `1 :: (2 :: xs)`
  - Examples:
    - `1 :: 2 :: existing_list` ==> List(1, 2, ...)
    - `x :: Nil` ==> List(x)
    - `List()` ==> Nil
    - `List(2 :: existing_list)` ==> List(List(2, ...)), i.e. a list whose only element is a list
 
- `List` has 3 fundamental operations
  - `xs.head` returns the first item in the list
  - `xs.tail` returns the **list of all remaining items**
  - `xs.isEmpty` returns a boolean indicating whether the list is empty
  - `head` and `tail` throw an exception when called on an empty list
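
- A small sketch of cons and the three operations in action. The names `existing`, `xs` and `firstOrZero` are invented for this illustration

  ```scala
    val existing = List(3, 4)

    // `::` is right-associative: this is 1 :: (2 :: existing)
    val xs = 1 :: 2 :: existing

    val first = xs.head               // first element
    val rest = xs.tail                // everything after the first element, as a list
    val single = 5 :: Nil             // a one-element list
    val nested = List(2 :: existing)  // a list whose only element is the list 2 :: existing

    // head and tail are only defined on non-empty lists, so check isEmpty first
    def firstOrZero(ys: List[Int]): Int =
      if ys.isEmpty then 0 else ys.head
  ```

  - Note that `existing` is never modified; every `::` builds a new list that shares `existing` as its tail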

### Exercise 1

- Consider the pattern `x :: y :: List(xs, ys) :: zs`. What condition most accurately describes the length L of the lists it matches?

- Answer
  - `::` associates to the right, so the pattern reads `x :: (y :: (List(xs, ys) :: zs))`
  - `x` matches the first element and `y` the second; both are variable patterns, so any element matches
  - `List(xs, ys)` matches the third element, which must itself be a list of exactly 2 elements. Since a `List` is homogeneous, the matched list is a list of lists
  - `zs` matches the rest of the list, which may be empty (`Nil`)

  - Therefore the list must have at least 3 elements, i.e. L >= 3
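
- To check the answer experimentally, here is a minimal sketch. `matchesPattern` is a hypothetical helper, and `List[List[Int]]` is just one element type for which the pattern type-checks

  ```scala
    // Reports whether a list of lists matches the pattern from the exercise
    def matchesPattern(l: List[List[Int]]): Boolean = l match
      case x :: y :: List(xs, ys) :: zs => true
      case _ => false

    val exactlyThree = matchesPattern(List(List(1), List(2), List(3, 4)))       // zs is Nil here
    val longer = matchesPattern(List(Nil, Nil, List(3, 4), List(5)))            // zs is List(List(5))
    val tooShort = matchesPattern(List(List(1), List(2)))                       // no third element
    val wrongThird = matchesPattern(List(List(1), List(2), List(3)))            // third element has 1 item, not 2
  ```

  - The last case shows that the length condition L >= 3 is necessary but not sufficient: the third element must also be a two-element list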

### Exercise 2: Insertion Sort

- Suppose we want to sort a list of numbers in ascending order. We implement insertion sort by doing this:
  - Suppose we have List(7, 3, 9, 2)
  - To sort it, first sort the tail List(3, 9, 2) to obtain List(2, 3, 9)
  - Then, traverse the sorted tail to find the right place for the head 7, giving List(2, 3, 7, 9)
  ```scala
    def isort(xs: List[Int]): List[Int] = xs match
      case List() => List()
      case y :: ys => insert(y, isort(ys))

    def insert(x: Int, xs: List[Int]): List[Int] = xs match
      case List() => ???
      case y :: ys => ???
  ```

  - See `3-slides.scala` for implementation

- What is the worst-case complexity of insertion sort relative to the length of the input list N?
  - Each `insert` may have to traverse the whole sorted list, so it takes at most $O(N)$ time
  - `insert` is called once per element, so for the entire list the worst case is $O(N^2)$
## Enums

- We've seen in the last section how to model data with class hierarchies. Classes are a useful construct to group functions operating on some common set of values, which are typically declared as fields
  - This is super neat, because data is encapsulated with its corresponding functions

- But what about cases where we want to store just data (**pure data**), without any associated methods?
  - In such cases, the `case class ...` and pattern matching approach works quite well

- Recall from the last section that we were trying to write an evaluator. This is the case class hierarchy we established, now nested inside the companion object `Expr`. If you want to refer to the individual case classes without the `Expr.` prefix, you can always do `import Expr.*`

  ```scala
    trait Expr
    object Expr:
      case class Var(s: String) extends Expr
      case class Number(n: Int) extends Expr
      case class Sum(e1: Expr, e2: Expr) extends Expr
      case class Prod(e1: Expr, e2: Expr) extends Expr
  ```

- A pure data definition like this (i.e. a set of case classes with no associated methods) is known as an **algebraic data type**, or ADT

- Since such definitions are so common, Scala offers special syntax to write them concisely: `enum`

  ```scala
    enum Expr:
      case Var(s: String)
      case Number(n: Int)
      case Sum(e1: Expr, e2: Expr)
      case Prod(e1: Expr, e2: Expr)
  ```

- `match` expressions on an `enum` use almost the same syntax; the cases are referred to through the enum name (or imported with `import Expr.*`)

  ```scala
    def show(e: Expr): String = e match
      case Expr.Var(x) => x
      case Expr.Number(n) => n.toString
      case Expr.Sum(e1, e2) => s"${show(e1)} + ${show(e2)}"
      case Expr.Prod(e1, e2) => s"${showP(e1)} * ${showP(e2)}"
    
    def showP(e: Expr): String = e match
      case _: Expr.Sum => s"(${show(e)})" // type test: only sums need parentheses inside a product
      case _ => show(e)
  ```

- In fact, `enum` can also work for classes that don't take parameters!

  ```scala
    enum Color:
      case Red
      case Green
      case Blue

    // same as 
    enum Color:
      case Red, Green, Blue
  ```

- Matching works the same way

```scala
  def isBlue(color: Color) = color match 
    case Color.Blue => true
    case _ => false
```

- `enum` definitions can even be composed of other enums. This sort of implicit hierarchy makes it really easy to organise and define nested type structures
  ```scala 
    import java.util.Date // assumption: any date type would do here

    enum PaymentMethod:
      case CreditCard(kind: Card, holder: String, number: Long, expires: Date)
      case PayPal(email: String)
      case Cash

    enum Card:
      case Visa, Mastercard, Amex
  ```

- Like `Enum` in Python, Scala enums implicitly associate each case with an ordinal value, starting from 0 in declaration order
  - The companion object of an `enum` made up of simple (parameterless) cases has a `values` method that returns an `Array` of all the cases, in declaration order
  - Printing a case (or calling `toString` on it) shows its name
  - To get the ordinal value of a case, call its `ordinal` method

  ```scala
    enum Direction(val dx: Int, val dy: Int):
      case Right extends Direction( 1, 0) //ordinal 0
      case Up extends Direction( 0, 1) //ordinal 1
      case Left extends Direction(-1, 0) //ordinal 2
      case Down extends Direction( 0, -1) //ordinal 3

      def leftTurn = Direction.values((ordinal + 1) % 4)
    end Direction

    val r = Direction.Right
    val u = r.leftTurn // u = Up
    val v = (u.dx, u.dy) // v = (0, 1)

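- Note that when the `enum` itself takes parameters, each `case` must pass them explicitly using an `extends` clause
  - i.e. `case Right extends Direction( 1, 0)` --> needs `extends`, but `case Var(s: String)` (which declares its own parameters and passes none to the enum) does not

- Note that an enum `case` that takes in parameters **DOES NOT** show up in `values`
  - i.e. `Right` is a single value and is listed in `Direction.values`, but `case Var(s: String)` defines a whole class of values, so there is no single `Var` value to list. Each instance such as `Var("x")` still has an `ordinal`, given by the position of its case in the declaration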
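
- The helper methods around ordinals can be tried out with a sketch like the one below. It assumes a current Scala 3 compiler; the `val` names are invented, and `Color` and a cut-down `Expr` are repeated so the block stands alone

  ```scala
    enum Color:
      case Red, Green, Blue

    enum Expr:
      case Var(s: String)
      case Number(n: Int)

    val all = Color.values.toList           // all simple cases, in declaration order
    val greenOrdinal = Color.Green.ordinal  // ordinals start at 0
    val second = Color.fromOrdinal(1)       // look a case up by its ordinal
    val byName = Color.valueOf("Blue")      // look a case up by its name

    // An instance of a parameterised case reports the position of its case
    val numberOrdinal = Expr.Number(3).ordinal
  ```

  - `values`, `valueOf` and `fromOrdinal` live on the companion object (`Color`), while `ordinal` is called on an individual value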

## Subtyping and Generics

- In this section, we're introducing concepts that you rarely have to think about in dynamically typed Python code

- We'll look into `bounds` and `variance`

### Type Bounds

- In Scala, types are checked strictly at compile time. So in general, if you declare that a function accepts some input of type A, you can pass values of A or of any subtype of A, but not of a supertype of A

- However, this says nothing about how the result type relates to the type you actually passed in. It is idiomatic (though not compulsory) in Scala to use a type parameter and declare the bounds of the types you accept

- For example, let's suppose we have a `Shape` superclass, with `Circle` and `Triangle` as the subclasses
  - We wish to create a function `increment_area` that can work on any subclass of `Shape`, so you don't need to create a different function for each type and add another one every time you have a new subclass of `Shape`

  - Let's further assume we want `increment_area` to accept nothing above `Shape` in the hierarchy, and to constrain the type parameter from below by `SubCircle`

  - Then we can use bounds notation to enforce a mixed bound, where `U <: Shape` means `U` is `Shape` or a subtype of it (upper bound), and `U >: SubCircle` means `U` is `SubCircle` or a supertype of it (lower bound)

  ```scala    
    class Shape 

    class Circle extends Shape 
    class Triangle extends Shape

    class SubCircle extends Circle

    // In a mixed bound, the lower bound (>:) must be written before the upper bound (<:)
    def increment_area[U >: SubCircle <: Shape](x: U): U = { ??? }
  ```
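
- To see what an upper bound buys you, here is a runnable sketch. The `area` field, the constructor parameters and the function `larger` are assumptions made for this example; they are not part of the `increment_area` skeleton above

  ```scala
    class Shape(val area: Double)
    class Circle(r: Double) extends Shape(math.Pi * r * r)
    class Triangle(base: Double, height: Double) extends Shape(0.5 * base * height)

    // Upper bound: U may be Shape or any subtype of Shape.
    // The result type is U, so two Circles in means a Circle out.
    def larger[U <: Shape](a: U, b: U): U =
      if a.area >= b.area then a else b

    val small = Circle(1.0)
    val big = Circle(2.0)
    val tri = Triangle(2.0, 3.0)

    val c: Circle = larger(small, big)  // U is Circle, so the result is still typed as Circle
    val s: Shape = larger(small, tri)   // mixing subclasses: the result is only known to be a Shape
    // larger("a", "b")                 // would not compile: String is not a subtype of Shape
  ```

  - Had `larger` been declared as `(a: Shape, b: Shape): Shape`, the first call would return a plain `Shape` and the `Circle` type would be lost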


### Covariance

- So far, we have dealt with direct type bounds; that is, we constrained "simple" types

- But there are types that are built from other types. Remember, in Scala, everything is an object! So, for example, a function type or a `List[A]` also takes part in the subtype relation

- With such parameterised types, it becomes more complicated to discuss subtyping. Do we always want to have `List[A] <: List[B]` if `A <: B` (covariance)? Are there cases where we want to have `List[A] <: List[B]` if `A >: B` (contravariance)?

- For perspective, we'll consider the case of arrays in Java vs Scala
  - In Java, arrays are covariant. So, in Scala notation, `Array[A] <: Array[B]` if `A <: B`
  - In Scala, arrays are not covariant (they are invariant)

- This causes some problems in Java. Let's see how using the Java block below
  ```java
    NonEmpty[] a = new NonEmpty[]{
      new NonEmpty(1, new Empty(), new Empty())};
    IntSet[] b = a;
    b[0] = new Empty();
    NonEmpty s = a[0];
  ```

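- Notice how this compiles, but the assignment `b[0] = new Empty()` fails at run time with an `ArrayStoreException`: `a` and `b` refer to the same array, so without that check `a[0]` would hold an `Empty` even though it is typed as `NonEmpty`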

- Let's consider this case in Scala

```scala
  val a: Array[NonEmpty] = Array(NonEmpty(1, Empty(), Empty()))
  val b: Array[IntSet] = a
  b(0) = Empty()
  val s: NonEmpty = a(0)
```

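- When you try out this example, what do you observe?
  - You will see a type error in line 2, because arrays are invariant in Scala: `Array[NonEmpty]` is not a subtype of `Array[IntSet]`, so the compiler rejects the assignment before anything can go wrong at run time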
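
- For contrast, immutable lists can safely be covariant. The sketch below uses two hypothetical classes, `Shape` and `Circle`, instead of the `IntSet` hierarchy (which is not defined in these notes)

  ```scala
    class Shape
    class Circle extends Shape

    val circles: List[Circle] = List(Circle(), Circle())

    // List is declared covariant (List[+A]) and cannot be mutated, so this upcast is safe
    val shapes: List[Shape] = circles

    // "Adding" a Shape builds a new, wider list; `circles` itself is untouched
    val more: List[Shape] = Shape() :: shapes

    val circleArray: Array[Circle] = Array(Circle())
    // val shapeArray: Array[Shape] = circleArray  // rejected by the compiler: Array is invariant
  ```

  - The difference is mutability: nothing can be written into `circles` through `shapes`, whereas an array alias would allow exactly the update that breaks the Java example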

- Definition of Liskov Substitution: a value of a subtype must always be usable wherever a value of the supertype is expected. If `A <: B`, then everything that B can do, A must be able to do

- Imagine functions `F: A1 => B1` and `G: A2 => B2`. What must be true if `F <: G`, i.e. `F` can be used wherever `G` is expected?
  - A1 >: A2
    - Because if I want to substitute G with F, everything that I pass to G, I will also pass to F
    - Therefore, I must be able to pass A2 to F
    - And for me to be sure that A2 will work in F as it does in G, A2 must be a subtype of A1
  - B1 <: B2
    - Because if I want to substitute G with F, anything that F returns must be usable as if it were returned by G
    - That is, anything that I can do with B2 downstream, I must also be able to do with B1 downstream
    - Therefore, B1 <: B2
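
- The rule for function types can be checked directly with the compiler. This is a minimal sketch using two hypothetical classes, `Shape` and `Circle`; `f`, `g` and `h` are names chosen for the example

  ```scala
    class Shape
    class Circle extends Shape

    // f accepts the more general argument (Shape) and returns the more specific result (Circle)
    val f: Shape => Circle = _ => Circle()

    // Accepted: function types are contravariant in the argument and covariant in the result,
    // and here Shape >: Circle (argument) and Circle <: Shape (result)
    val g: Circle => Shape = f

    // The reverse direction does not type-check:
    // val h: Shape => Circle = g
  ```

  - Here `A1 = Shape`, `B1 = Circle`, `A2 = Circle` and `B2 = Shape`, so both conditions `A1 >: A2` and `B1 <: B2` hold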