# Jax control flow

1. If else
2. Loops

In [None]:
import math

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

Does this work?

In [None]:
def test_if(x):
    if x < 0:
        return -x
    else:
        return x

test_if(jnp.array(-1.))

Yes, it works.

How about the for loop?

In [None]:
def test_for(x):
    for i in range(5):
        x = jnp.eye(2) @ x
    return x

test_for(jnp.ones((2, )))

Yep, it works too!

It **seems** unnecessary to fuss over normal Python control flow.

But if we want to use `jit` (and, in some cases, autodiff), we have to be careful with control flow.

# Example

Will this work?

In [None]:
@jax.jit
def test_if(x):
    if x < 0:
        return -x
    else:
        return x

test_if(jnp.array(-1.))

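No. The error message already explains why quite well.

It fails because, under `jit`, `x` is an abstract tracer that carries only its shape and dtype, not its **concrete** value, so Python cannot decide which branch of the `if` to take while the computation graph is being built.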

We can force this example to work by using the `static_argnums` argument of `jit`. But use it only when this is intentional, that is, when `x` really is static: every new value of `x` triggers a recompilation.

Recall that `jax.jit` traces all the (numerical) operations in order to compile them to XLA.

`jax.jit` does not see Python control flow such as `if else` and `for`: the Python code simply runs once during tracing.

If the function to be jitted has a for loop, the loop is executed at trace time and the operations of every iteration are hardcoded (unrolled) into the XLA programme.

Why? 

Imagine we ask you to implement a function

```python
def my_jit(f):
    return ...
```

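such that `my_jit` takes a function `f` as input and detects whether `f` contains a `for` loop. This is very hard: we might need to get the source of `f` as a string and then search it for `for`. Instead, JAX *traces* `f`: it runs the Python code once with abstract placeholder values and records only the array operations that actually get executed.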

A more illustrative example:

Suppose that JAX compiled Python code to C code. What would the compiled C code of

```python
for i in range(100):
    x = f(x)
```

look like?

We would expect C code like this:

```c
for (int i = 0; i < 100; i++) {
    x = f(x);
}
```

But what we actually get is

```c
x = f(x);
x = f(x);
x = f(x);
... // hard-coded: repeated 100 times
```

Hence, if we want `for`/`if` to be decided at **runtime**, we need to write them with constructs that JAX can trace, namely the control-flow primitives in `jax.lax`.
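
To see the unrolling for yourself, here is a small sketch (the function `loop3` is a hypothetical example). `jax.make_jaxpr` shows the traced program, and the `print` inside the loop fires only while tracing, not when the compiled program runs.

```python
import jax
import jax.numpy as jnp

def loop3(x):
    for i in range(3):
        # This print runs while JAX traces the function, not at runtime
        print("tracing iteration", i, "x =", x)
        x = 2.0 * x
    return x

# The jaxpr contains three separate `mul` equations: the loop was unrolled
jaxpr = jax.make_jaxpr(loop3)(jnp.array(1.0))
print(jaxpr)

fast = jax.jit(loop3)
fast(jnp.array(1.0))  # traces (and prints three lines), then compiles
fast(jnp.array(5.0))  # same shape and dtype: the cached program is reused, nothing printed
```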

# If else

Consider a Python `if else`:

```python
if condition:
    result = true_func(x)
else:
    result = false_func(x)
```

In JAX we write it as

```python
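result = jax.lax.cond(condition,   # a scalar boolean
                      true_func,   # called as true_func(x) if condition is True
                      false_func,  # called as false_func(x) otherwise
                      x)           # operand(s) passed to the chosen branch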
```

Let us implement `test_if` in jax as an example. 

In [None]:
@jax.jit
def test_if(x):
    return jax.lax.cond(x < 0.,       # condition
                        lambda _: -x, # what to execute if the condition is true
                        lambda _: x,  # what to execute if the condition is false
                        x)            # the operand can be anything here, because the branches use x from the outer scope

test_if(jnp.array(-1.))

A cleaner version passes `x` as the operand, so the branches do not rely on the outer scope:

In [None]:
@jax.jit
def test_if(x):
    return jax.lax.cond(x < 0., 
                        lambda u: -u, 
                        lambda u: u,
                        x)

test_if(jnp.array(-1.))
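
Both branch functions must return outputs with the same shape and dtype, and `jax.lax.cond` composes with autodiff. A small sketch (the name `abs_cond` is ours) that differentiates the function above with `jax.grad`:

```python
import jax
import jax.numpy as jnp

@jax.jit
def abs_cond(x):
    # Both branches take the operand and return an array of the same shape/dtype
    return jax.lax.cond(x < 0., lambda u: -u, lambda u: u, x)

# Reverse-mode differentiation through lax.cond: d|x|/dx is -1 for x < 0 and +1 for x > 0
d_abs = jax.grad(abs_cond)
print(d_abs(jnp.array(-1.)), d_abs(jnp.array(2.)))
```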

# Exercise

Write the ELU activation function (with $\alpha = 1$) in JAX and jit it.

$$
\mathrm{elu}(x) = \begin{cases}e^x - 1, & x < 0,\\
                               x, & x\geq 0.\end{cases}
$$

A template to fill in:

```python
@jax.jit
def elu(x):
    return jax.lax.cond(?,
                        ?, 
                        ?, 
                        ?)

# test
elu(1.)
```

## Solution

In [None]:
@jax.jit
def elu(x):
    return jax.lax.cond(x < 0.,
                        lambda u: jnp.exp(u) - 1.,  # branch for x < 0
                        lambda u: u,                # branch for x >= 0
                        x)

elu(1.)

# ~

1. What if we have multiple if conditions, i.e., `if elif elif ... else`? Use `jax.lax.switch`.

2. What if we have vector input? Use `jnp.where` or `jax.vmap`.
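
A sketch of both suggestions. The three-way function `clip_unit` is a hypothetical example (it should agree with `jnp.clip(x, -1., 1.)`), and `elu_vec` uses ELU with $\alpha = 1$, i.e. $e^x - 1$ for $x < 0$ and $x$ otherwise. `jax.lax.switch` picks a branch by a scalar integer index, so for vector input we either `vmap` it or switch to the elementwise `jnp.where`. Under `vmap` with a batched index, `switch`/`cond` evaluate all branches and then select, much like `jnp.where`.

```python
import jax
import jax.numpy as jnp

@jax.jit
def clip_unit(x):
    # Hypothetical piecewise function: -1 if x < -1, x if -1 <= x <= 1, 1 if x > 1
    index = (x >= -1.).astype(jnp.int32) + (x > 1.).astype(jnp.int32)  # 0, 1 or 2
    branches = [lambda u: -jnp.ones_like(u),  # index 0
                lambda u: u,                  # index 1
                lambda u: jnp.ones_like(u)]   # index 2
    return jax.lax.switch(index, branches, x)

@jax.jit
def elu_vec(x):
    # Elementwise select; both expressions are evaluated for every element
    return jnp.where(x < 0., jnp.expm1(x), x)

# Vectorise the scalar clip_unit with vmap
clip_unit_vec = jax.vmap(clip_unit)

xs = jnp.array([-3., -0.5, 0.5, 2.])
print(clip_unit_vec(xs))
print(elu_vec(xs))
```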

# Loops

$$
\begin{split}
\begin{array} {r|rr} 
\hline \
\textrm{construct} 
& \textrm{jit} 
& \textrm{grad} \\
\hline \
\textrm{if} & ❌ & ✔ \\
\textrm{for} & ✔* & ✔\\
\textrm{while} & ✔* & ✔\\
\textrm{lax.cond} & ✔ & ✔\\
\textrm{lax.while_loop} & ✔ & \textrm{fwd}\\
\textrm{lax.fori_loop} & ✔ & \textrm{fwd}\\
\textrm{lax.scan} & ✔ & ✔\\
\hline
\end{array}
\end{split}
$$

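\* The loop condition must not depend on argument values; such loops are unrolled. See https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html. For `lax.fori_loop`, reverse-mode `grad` also works when the trip count is static (for example, when `lower` and `upper` are Python ints), because the loop is then implemented with `lax.scan`.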

Analogous to `jax.lax.cond` for branching, JAX provides `jax.lax.while_loop`, `jax.lax.fori_loop`, and `jax.lax.scan` for loops.
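
`jax.lax.while_loop` is not used elsewhere in this notebook, so here is a minimal sketch (both functions are hypothetical examples). The loop state is a tuple that `cond_fun` inspects and `body_fun` must return with the same structure and dtypes. As the table says, `while_loop` supports forward-mode differentiation (`jax.jacfwd`) but not reverse mode.

```python
import jax
import jax.numpy as jnp

@jax.jit
def count_halvings(x):
    # How many times can x be halved before it drops below 1?
    def cond_fun(state):
        n, v = state
        return v >= 1.          # keep looping while this is True

    def body_fun(state):
        n, v = state
        return n + 1, v / 2.    # same structure and dtypes as the input state

    n, _ = jax.lax.while_loop(cond_fun, body_fun, (jnp.int32(0), x))
    return n

def cube(x):
    # x ** 3 computed by a while_loop that multiplies three times
    def cond_fun(state):
        i, acc = state
        return i < 3

    def body_fun(state):
        i, acc = state
        return i + 1, acc * x

    _, acc = jax.lax.while_loop(cond_fun, body_fun, (jnp.int32(0), jnp.ones_like(x)))
    return acc

print(count_halvings(jnp.array(10.)))   # 10 -> 5 -> 2.5 -> 1.25 -> 0.625, so 4
print(jax.jacfwd(cube)(jnp.array(2.)))  # forward mode works: 3 * 2**2 = 12
# jax.grad(cube) would raise an error: while_loop is not reverse-mode differentiable
```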

Consider a naive numpy implementation of summation:

```python
def my_sum(x):
    summation = 0.
    for i in range(x.shape[0]):
        summation = summation + x[i]
    return summation
```

A JAX implementation using `jax.lax.fori_loop` is:

In [None]:
def my_sum(x):

    def body_func(i, val):
        return x[i] + val

    return jax.lax.fori_loop(lower=0,            # The starting index
                             upper=x.shape[0],   # The end index (exclusive): i runs over lower, ..., upper - 1
                             body_fun=body_func, # The loop body (index, previous_val) -> val
                             init_val=0.)        # Initial value of the loop val

my_sum(jnp.ones((10, )))

Its jaxpr (the program JAX traces and then hands to XLA) looks like this; the whole loop appears as a single loop primitive:

In [None]:
jax.make_jaxpr(my_sum)(jnp.ones((10, )))

Now, what does JAX see if we use a plain Python for loop instead of `jax.lax.fori_loop`?

In [None]:
# this loop is jittable indeed, but ...
def my_sum_naive(x):
    summation = 0.
    for i in range(x.shape[0]):
        summation = summation + x[i]
    return summation

jax.make_jaxpr(my_sum_naive)(jnp.ones((100, )))

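See? The jaxpr repeats the indexing and addition once for each of the 100 iterations: the Python loop was unrolled during tracing, because JAX does not see Python control flow.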
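
A quick way to quantify this is to count the top-level equations in each jaxpr as the input grows. A sketch (the helper `jaxpr_size` is ours, and the `fori_loop` version starts from a float32 zero of the input's dtype): the `fori_loop` program should stay the same size, while the unrolled one grows with the length.

```python
import jax
import jax.numpy as jnp

def my_sum(x):
    return jax.lax.fori_loop(0, x.shape[0],
                             lambda i, val: x[i] + val,
                             jnp.zeros((), x.dtype))

def my_sum_naive(x):
    summation = 0.
    for i in range(x.shape[0]):
        summation = summation + x[i]
    return summation

def jaxpr_size(f, n):
    # Number of top-level equations in the jaxpr traced for an input of length n
    return len(jax.make_jaxpr(f)(jnp.ones((n,))).jaxpr.eqns)

for n in (10, 100):
    print(n, jaxpr_size(my_sum, n), jaxpr_size(my_sum_naive, n))
```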

# Exercise

Consider a recursion

$$
X_k = 0.1 \, X_{k-1} + U_{k-1}.
$$

Suppose that the initial $X_0$ and inputs $\lbrace U_k \rbrace_{k=0}^{T-1}$ are known. Could you compute $X_{T}$?

This is a simple numpy way to do it:

In [None]:
def recursion(x0, us):
    T = us.shape[0]
    x = x0
    for k in range(T):
        x = 0.1 * x + us[k]
    return x

recursion(np.array(0.1), 0.2 * np.ones((10, )))

Now fill in the blanks to implement the same recursion with `jax.lax.fori_loop`:

```python
def recursion(x0, us):
    def fori_body(k, x):
        return ??

    return jax.lax.fori_loop(lower=??,
                             upper=??,
                             body_fun=??,
                             init_val=??)

recursion(jnp.array(0.1), 0.2 * jnp.ones((10, )))
```

## Solution

In [None]:
def recursion(x0, us):
    def fori_body(k, x):
        return 0.1 * x + us[k]

    return jax.lax.fori_loop(lower=0,
                             upper=us.shape[0],
                             body_fun=fori_body,
                             init_val=x0)

recursion(jnp.array(0.1), 0.2 * jnp.ones((10, )))
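
Unrolling the recursion gives the closed form $X_T = 0.1^T X_0 + \sum_{k=0}^{T-1} 0.1^{T-1-k} U_k$, which is a handy sanity check for the `fori_loop` solution. A minimal sketch (the helper name `closed_form` is ours):

```python
import jax
import jax.numpy as jnp

def recursion(x0, us):
    def fori_body(k, x):
        return 0.1 * x + us[k]
    return jax.lax.fori_loop(0, us.shape[0], fori_body, x0)

def closed_form(x0, us):
    T = us.shape[0]
    powers = 0.1 ** jnp.arange(T - 1, -1, -1)  # 0.1^(T-1-k) for k = 0, ..., T-1
    return 0.1 ** T * x0 + jnp.sum(powers * us)

x0, us = jnp.array(0.1), 0.2 * jnp.ones((10,))
print(recursion(x0, us), closed_form(x0, us))
```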

# ~

Wait a sec: the function only returns the final value $X_T$. How do I keep the whole history $X_1, \ldots, X_T$?

In NumPy this is simple: just introduce a result accumulator, say, `xs`.

In [None]:
def recursion(x0, us):
    T = us.shape[0]
    xs = np.zeros((T, )) # The accumulator

    x = x0
    for k in range(T):
        x = 0.1 * x + us[k]
        xs[k] = x
    return xs

recursion(np.array(0.1), 0.2 * np.ones((10, )))

Can I do the same in jax?

```python
@jax.jit
def recursion(x0, us):
    xs = jnp.zeros((us.shape[0], ))

    def fori_body(k, x):
        x = 0.1 * x + us[k]
        xs[k] = x
        return x

    return jax.lax.fori_loop(lower=0,
                             upper=us.shape[0],
                             body_fun=fori_body,
                             init_val=x0)
```

No. We get an error at the line `xs[k] = x`, because JAX arrays are immutable: they do not support in-place item assignment.

We can, to some extent, emulate `xs[k] = x` with JAX's functional array updates (`xs.at[k].set(x)`), but then `xs` has to be threaded through the loop state by hand, which makes the programme clumsier and harder to read.
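
A minimal sketch of that workaround (the name `recursion_at` is ours): carry the pair `(x, xs)` through `jax.lax.fori_loop` and update `xs` with `.at[k].set`, which returns a new array instead of mutating the old one.

```python
import jax
import jax.numpy as jnp

@jax.jit
def recursion_at(x0, us):
    T = us.shape[0]

    def fori_body(k, carry):
        x, xs = carry
        x = 0.1 * x + us[k]
        xs = xs.at[k].set(x)   # functional update: a new array with entry k replaced
        return x, xs

    init = (x0, jnp.zeros((T,), dtype=us.dtype))
    _, xs = jax.lax.fori_loop(0, T, fori_body, init)
    return xs

print(recursion_at(jnp.array(0.1), 0.2 * jnp.ones((10,))))
```

It works, but the accumulator plumbing is exactly the bookkeeping that `scan` does for you.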

The idiomatic way is the **scan** operation, because

$$
X_k = 0.1 \, X_{k-1} + U_{k-1}
$$

is essentially a scan: a state (the *carry*, here $X_{k-1}$) is threaded through the loop, each step consumes one input $U_{k-1}$, and each step emits one output $X_k$ that is stacked into the result. `jax.lax.scan` abstracts exactly these parts.
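
The `jax.lax.scan` documentation describes its semantics with a plain-Python equivalent. A sketch along those lines (the name `scan_reference` is ours; it runs eagerly with a Python loop, so it is for understanding only and handles a single input array):

```python
import jax
import jax.numpy as jnp

def scan_reference(f, init, xs):
    # Python-loop model of jax.lax.scan for a single input array
    carry = init
    ys = []
    for x in xs:                 # one step per element along the leading axis
        carry, y = f(carry, x)   # f maps (carry, elem) -> (new carry, output)
        ys.append(y)
    return carry, jnp.stack(ys)  # final carry and stacked outputs

def body(carry, u):
    x = 0.1 * carry + u
    return x, x

us = 0.2 * jnp.ones((10,))
last_ref, xs_ref = scan_reference(body, jnp.array(0.1), us)
last, xs = jax.lax.scan(body, jnp.array(0.1), us)
print(jnp.allclose(xs_ref, xs))
```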

In [None]:
def recursion(x0, us):
    def scan_body(carry, elem):
        # Unpack carry and elem
        x = carry
        u = elem

        x = 0.1 * x + u
        return x, x                 # (next carry, output): the first becomes the next carry, the second is stacked into the results.

    return jax.lax.scan(scan_body,  # The scan body function
                        x0,         # Initial value/carry
                        us)         # Inputs

(last_x, xs) = recursion(jnp.array(0.1), 0.2 * jnp.ones((10, )))
xs
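
Unlike `while_loop`, `scan` supports reverse-mode differentiation. As a sketch (the name `final_state` is ours): by the recursion, $X_T$ depends on $X_0$ only through the factor $0.1^T$, so with $T = 10$ we expect $\partial X_T / \partial X_0 = 10^{-10}$ up to float32 rounding.

```python
import jax
import jax.numpy as jnp

def final_state(x0, us):
    def scan_body(x, u):
        x = 0.1 * x + u
        return x, x
    last_x, _ = jax.lax.scan(scan_body, x0, us)
    return last_x

us = 0.2 * jnp.ones((10,))
# Gradient of X_T with respect to X_0 (argument 0)
g = jax.grad(final_state)(jnp.array(0.1), us)
print(g)
```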

# Exercise

Consider an SDE

$$
\mathrm{d} X(t) = \sin(10 \, \pi \, X(t)) \, \mathrm{d} t + \mathrm{d}W(t),
$$

where $X(0) = 0.1$. Use Euler--Maruyama to simulate a trajectory of $X$ at times $0.01, 0.02, \ldots, 1$.

The Euler--Maruyama update reads

$$
X(t_k) \approx X(t_{k-1}) + \sin(10 \, \pi \, X(t_{k-1})) \, (t_k - t_{k-1}) + \Delta W_k, \quad \Delta W_k \sim \mathrm{N}(0, t_k - t_{k-1}).
$$

```python
dt = 0.01
T = 100
ts = jnp.linspace(dt, dt * T, T)

key = jax.random.PRNGKey(666)
dws = ? # Generate the Wiener increments ΔW_k ~ N(0, dt), k = 1, ..., T


def scan_body(carry, elem):
    ? = carry
    dw = elem

    # Euler equation here?
    return ?, ?

_, xs = jax.lax.scan(?, 
                     ?, 
                     dws)

plt.plot(ts, xs)
```

## Solution

In [None]:
dt = 0.01
T = 100
ts = jnp.linspace(dt, dt * T, T)

key = jax.random.PRNGKey(666)
dws = math.sqrt(dt) * jax.random.normal(key, (T, )) # Wiener increments ΔW_k, each ~ N(0, dt)

def scan_body(carry, elem):
    x = carry
    dw = elem

    x = x + jnp.sin(10 * math.pi * x) * dt + dw  # one Euler--Maruyama step
    return x, x

_, xs = jax.lax.scan(scan_body,
                     jnp.array(0.1),  # X(0)
                     dws)

plt.plot(ts, xs)
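
Because one trajectory is a pure function of its PRNG key, several independent trajectories can be simulated at once by `vmap`-ing over split keys. A sketch (the function name `simulate` and the choice of 5 paths are ours):

```python
import math

import jax
import jax.numpy as jnp

dt = 0.01
T = 100

def simulate(key):
    # One Euler--Maruyama trajectory of dX = sin(10 pi X) dt + dW with X(0) = 0.1
    dws = math.sqrt(dt) * jax.random.normal(key, (T,))

    def scan_body(x, dw):
        x = x + jnp.sin(10 * math.pi * x) * dt + dw
        return x, x

    _, xs = jax.lax.scan(scan_body, jnp.array(0.1), dws)
    return xs

keys = jax.random.split(jax.random.PRNGKey(666), 5)  # 5 independent keys
paths = jax.jit(jax.vmap(simulate))(keys)             # shape (5, T)
print(paths.shape)
```

To plot them together with the time grid from the solution: `plt.plot(ts, paths.T)`.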