# Continuation Passing Style

Thus far, we have built interpreters for various features in Lettuce. However, all of our interpreters depended
on recursive calls to the `eval` function. Recursion made it convenient to translate
the semantics directly into a Scala program. However, it is not ideal: every recursive call that is not a tail call
adds an activation record to the stack, so large programs can cause the stack to overflow.

Today, we will revisit the theme of eliminating non-tail recursion. We have already done this using
an accumulator. However, accumulators are limited in their scope. We will now present a general scheme
that works without accumulators.

## Recap: Recursion, Tail Recursion and Eliminating the Non-Tail Recursion

We will take a few minutes to quickly recap recursion, tail recursion and the problem of eliminating
non-tail recursion.

- Recursion causes the activation records to grow on the stack, potentially causing stack overflow.
- Tail recursion is the benign case in which the result of each recursive call is returned without any further processing.
- Tail recursive calls can be implemented such that the activation records need not grow.
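
As a reminder of the accumulator technique, here is a minimal sketch using factorial. The names `factorial` and `factorialAcc` are chosen for illustration and do not come from the earlier notes.

~~~scala
import scala.annotation.tailrec

// Direct recursion: the multiplication happens after the recursive call returns,
// so every call keeps its own activation record alive.
def factorial(n: Int): BigInt =
  if (n <= 1) BigInt(1) else BigInt(n) * factorial(n - 1)

// Accumulator version: the recursive call is the last thing the function does.
// @tailrec asks the compiler to reject the definition if that is not the case.
@tailrec
def factorialAcc(n: Int, acc: BigInt = BigInt(1)): BigInt =
  if (n <= 1) acc else factorialAcc(n - 1, acc * BigInt(n))
~~~

Both functions compute the same value; only the second one has its recursive call in tail position.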

## Continuation passing style (CPS)

Continuation passing style (CPS) is a "style" of programming wherein every function will have an extra argument
called the `continuation`. A continuation is a function that is passed in and specifies what the caller
wishes to do with the result that has been computed.

Take, for instance, a function `func` that takes in an integer and returns an integer.
~~~
def func(x: Int): Int = {
     // .. do some work to compute result .. 
     return result
}
~~~

In CPS, this function is written as

~~~
def func_k(x: Int, k: Int => Int): Int = {
    //  .. do some work to compute result ..
    k(result) // Pass the result on to the continuation.
}
~~~

Note that `func_k` takes in an extra argument `k` called the continuation. It
is the function through which the caller specifies what they want done with
the result of the call. Rather than returning the result and making the caller operate
on it, the caller bundles up the remaining work into `k` and passes it in; `func_k` then calls `k` on the result.

Let us look at a concrete example. First take a look at these three functions defined below.





In [38]:
// This function takes an integer x and returns x + 1
def addOne(x: Int): Int = {
    val result = x + 1
    result
}

defined [32mfunction[39m [36maddOne[39m

In [39]:
addOne(5)

res38: Int = 6

In [40]:
// In CPS, we would write...
def addOne_cps(x: Int, k: Int => Int): Int = {
    val result = x + 1
    k(result)
}

defined [32mfunction[39m [36maddOne_cps[39m

In [41]:
val square: Int => Int = y => y * y
addOne_cps(5, square)

square: Int => Int = ammonite.$sess.cmd40$Helper$$Lambda$2810/0x0000000800c8f040@48531c33
res40_1: Int = 36

In [42]:
// Use generics to allow any return type for k
def addTwo_cps[T](x: Int, k: Int => T): T = {
    val result = x + 2
    k(result)
}

defined [32mfunction[39m [36maddTwo_cps[39m

In [43]:
addTwo_cps(5, square)

res42: Int = 49

In [44]:
addTwo_cps(5, x => List(x))

res43: List[Int] = List(7)

### Another Example
Here, we will write a function `madd_k` that will:
- call `multiply_k` on `x` and `y`, passing it a continuation `k1`;
- have the continuation `k1`:
  1. call `addUp_k` on the product, `y` and `z`;
  2. pass the result on to the continuation `k`.
In [45]:
def addUp_k(x: Int, y: Int, z:Int, k: Int => Int): Int = {
    k(x + y + z)
}

def multiply_k(x: Int, y: Int, k: Int => Int): Int = {
    k ( x * y)
}
// Let's create a function that does both, i.e. first multiply and then add.
def madd_k(x: Int, y: Int, z: Int, k: Int => Int): Int = {
    // Create a new continuation.
    // This continuation k1 is a closure that will be passed to multiply_k.
    // It will be called by multiply_k with the product x * y, and must do the rest of the work of madd.
    def k1(v1: Int): Int = addUp_k(v1, y, z, k) // Call addUp_k on v1, y, z and ask addUp_k to run k on the result.
    multiply_k(x, y, k1) // k1 receives v1 = x * y, e.g. 1 * 2 = 2
}
// First, x * y is computed: here 1 * 2 = 2.
// Then k1 adds y and z to that result: x * y + y + z, here 2 + 2 + 3 = 7.
// Finally, the continuation k is applied: with k = (x => x * 2), k(7) = 14.


defined [32mfunction[39m [36maddUp_k[39m
defined [32mfunction[39m [36mmultiply_k[39m
defined [32mfunction[39m [36mmadd_k[39m

In [46]:
madd_k(1, 2, 3, x => x * 2)

res45: Int = 14
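
For comparison, the sketch below puts a direct-style version next to the CPS version. The name `madd` is hypothetical (the notes only refer to "madd" in a comment), and the CPS definitions are repeated so the snippet runs on its own.

~~~scala
// Direct style: each intermediate result is returned to the caller.
def madd(x: Int, y: Int, z: Int): Int = {
  val v1 = x * y
  v1 + y + z
}

// CPS: each intermediate result is handed to a continuation instead.
def addUp_k(x: Int, y: Int, z: Int, k: Int => Int): Int = k(x + y + z)
def multiply_k(x: Int, y: Int, k: Int => Int): Int = k(x * y)
def madd_k(x: Int, y: Int, z: Int, k: Int => Int): Int =
  multiply_k(x, y, v1 => addUp_k(v1, y, z, k))

val direct = madd(1, 2, 3)            // 1 * 2 + 2 + 3
val viaCps = madd_k(1, 2, 3, v => v)  // identity continuation returns the result unchanged
~~~

Passing the identity function as the continuation recovers the direct-style result, which is a convenient way to check a CPS translation.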

### Example: Side Effects
Given the following functions, change them to use continuations:

In [47]:
// Function with side effects
def printFive(f: Int => String): Int = {
    val five = 5
    println("My value is: "+ f(five))
    five
}

defined [32mfunction[39m [36mprintFive[39m

In [48]:
def printFive_cps(f: Int => String, k: Int => Int): Int = {
    // Your Code
    val five = 5
    print("My value now is: " + f(five))
    k(five)
}

defined [32mfunction[39m [36mprintFive_cps[39m

In [49]:
printFive(x => x.toString)

My value is: 5


res48: Int = 5

In [50]:
printFive_cps(x => x.toString, y => y * 10)

My value now is: 5

res49: Int = 50

# Error Handling

So far we have worked with continuations without considering errors. Now let's see how to handle error cases in CPS using a second continuation, called the "error continuation". It is called whenever the program encounters an error.

The type of our CPS function will become:

fun_k(arg: ..., k: ResultType => T, err_k: () => T): T

Here, if some error arises in the computation that would normally be handled by throwing an exception, we will call the error continuation instead.



In [1]:
def error_continuation_ex[T](x: Int,  k: Int => T, err_k: () => T): T = {
    x match {
        case 1 => k(1)
        case x if x > 1 => error_continuation_ex(x - 1, k, err_k)
        case _ => err_k()
    }
}

defined [32mfunction[39m [36merror_continuation_ex[39m

In [8]:
println(error_continuation_ex(1, x => "I've found the one", () => "Haven't found the one yet!"))
println(error_continuation_ex(10, x => "It took me a while! But I've found the one", () => "Haven't found the one yet!"))
println(error_continuation_ex(0, x => "I've found the one", () => "Haven't found the one yet!"))

I've found the one
It took me a while! But I've found the one
Haven't found the one yet!
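
A case where an exception would normally be thrown is integer division by zero. The sketch below, with the hypothetical name `divide_k`, shows one way the error continuation could replace the exception; it is an illustration rather than part of the original exercises.

~~~scala
// Integer division in CPS with an error continuation.
// Instead of letting `num / den` throw an ArithmeticException when den == 0,
// the function hands control to err_k.
def divide_k[T](num: Int, den: Int, k: Int => T, err_k: () => T): T =
  if (den == 0) err_k()
  else k(num / den)

val ok  = divide_k(10, 2, q => s"quotient = $q", () => "division by zero")
val bad = divide_k(10, 0, q => s"quotient = $q", () => "division by zero")
~~~

Both continuations must produce the same type `T`, so the caller decides up front what a failure looks like in the final result.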


### Exercise: Fibonacci

In [51]:
def fibonacci(n: Int): Int = {
    if (n < 2) {
        n // fib(0) = 0 and fib(1) = 1, matching the base cases of fib_cps below
    } else {
        fibonacci(n-1) + fibonacci(n-2)
    }
}

// Fibonacci in Continuation Passing Style
def fib_cps(n: Int, k: Int => Int): Int = n match {
    case 0 => k(0)
    case 1 => k(1)
    case _ => fib_cps(n-1, (a: Int) => // This continuation receives fib(n-1) as a
        fib_cps(n-2, (b: Int) => // This continuation receives fib(n-2) as b
            k(a+b))) // Finally, pass the sum on to the original continuation k
}

defined [32mfunction[39m [36mfibonacci[39m
defined [32mfunction[39m [36mfib_cps[39m

In [None]:
assert(fib_cps(0, (x: Int) => x) == 0)
assert(fib_cps(6, (x: Int) => x) == 8)
assert(fib_cps(8, (x: Int) => x) == 21)
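
The same pattern applies to a function with a single recursive call. As a sketch, here is factorial in CPS; the name `factorial_k` is an assumption made for this example.

~~~scala
// Factorial in CPS: the pending multiplication "n * (result of the recursive call)"
// moves out of the stack frame and into the continuation.
def factorial_k[T](n: Int, k: Int => T): T =
  if (n <= 1) k(1)
  else factorial_k(n - 1, (v: Int) => k(n * v))

val asInt    = factorial_k(5, (v: Int) => v)
val asString = factorial_k(5, (v: Int) => s"5! = $v")
~~~

The recursive call to `factorial_k` is now in tail position; the work that used to follow it is carried by the growing continuation.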

### Exercise: Backtracking
Search a binary tree using CPS. Return true if the tree has a node with the integer `i` as a value.

In [52]:
sealed trait Tree
case object Empty extends Tree
case class Node(left: Tree, value: Int, right: Tree) extends Tree

def search(t: Tree, i: Int): Boolean =
    t match {
        case Empty => false
        case Node(l, j, r) =>
            if (i == j) true
            else if (search(l, i)) true
            else search(r, i)
    }
def search_cps(t: Tree, i: Int, continuation: Boolean => Boolean): Boolean =
    t match {
        case Empty => continuation(false)
        // Pre-order traversal: check the node itself, then the left subtree, then the right subtree.
        case Node(l, j, r) if i == j =>
            continuation(true)
        case Node(l, j, r) =>
            search_cps(l, i, (foundLeft: Boolean) => {
                if (foundLeft) continuation(true)
                else search_cps(r, i, (foundRight: Boolean) => {
                    // The right subtree is the last place left to look,
                    // so its result is passed straight on to the original continuation.
                    continuation(foundRight)
                })
            })
    }

defined [32mtrait[39m [36mTree[39m
defined [32mobject[39m [36mEmpty[39m
defined [32mclass[39m [36mNode[39m
defined [32mfunction[39m [36msearch[39m
defined [32mfunction[39m [36msearch_cps[39m

In [54]:
val t = Node(Empty, 10, Node(Node(Empty, 15, Empty), 6, Node(Empty, 12, Empty)))
//    10
//  /   \
// empty 6
//      / \
//     15 12

assert(search(t, 10))
assert(!search(t, 0))
assert(search_cps(t, 10, (f:Boolean) => {println(s"The result is $f") 
                                         f}))
assert(!search_cps(t, 0, (f:Boolean) => {println(s"The result is $f") 
                                         f}))

The result is true
The result is false


t: Node = Node(
  Empty,
  10,
  Node(Node(Empty, 15, Empty), 6, Node(Empty, 12, Empty))
)
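
The search can also be phrased with two continuations, in the spirit of the error continuation from the Error Handling section: one for success and one for failure. This is a sketch under that assumption; the names `search_k2`, `found_k` and `fail_k` are hypothetical, and the tree definitions are repeated so the snippet is self-contained.

~~~scala
sealed trait Tree
case object Empty extends Tree
case class Node(left: Tree, value: Int, right: Tree) extends Tree

// found_k is called as soon as the value is located.
// fail_k says what to try next when the current subtree does not contain it.
def search_k2[T](t: Tree, i: Int, found_k: () => T, fail_k: () => T): T = t match {
  case Empty => fail_k()
  case Node(_, j, _) if i == j => found_k()
  case Node(l, _, r) =>
    // If the left subtree fails, backtrack by searching the right subtree.
    search_k2(l, i, found_k, () => search_k2(r, i, found_k, fail_k))
}

val t = Node(Empty, 10, Node(Node(Empty, 15, Empty), 6, Node(Empty, 12, Empty)))
val hit  = search_k2(t, 12, () => "found", () => "not found")
val miss = search_k2(t, 0, () => "found", () => "not found")
~~~

Here the failure continuation plays the role of the backtracking point: it remembers which subtree still has to be explored.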

# Key takeaway

- We add an extra continuation argument to every function in the program.
- We transform the program so that all function calls happen at the tail position.
- Finally, we __hope__ that the compiler/interpreter in all its goodness will optimize the tail call away.
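
In Scala, the last point can be checked for self-recursive functions with the `@tailrec` annotation, which makes the compiler report an error if the recursive call is not in tail position. The sketch below uses a hypothetical `sumList_k` to illustrate this.

~~~scala
import scala.annotation.tailrec

// Sum a list in CPS. The only recursive call is in tail position,
// so the definition is accepted under @tailrec.
@tailrec
def sumList_k(xs: List[Int], k: Int => Int): Int = xs match {
  case Nil => k(0)
  case head :: tail => sumList_k(tail, v => k(head + v))
}

val total = sumList_k(List(1, 2, 3, 4), v => v)
~~~

Note that `@tailrec` only covers the direct self-call. The continuations built along the way still call one another when `k(0)` finally runs, and the JVM does not eliminate those calls, so very long lists can still exhaust the stack.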