<a href="https://colab.research.google.com/github/sriharikrishna/siamcse21/blob/main/rosenbrock.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Rosenbrock function
Also known as Rosenbrock's valley or Rosenbrock's banana function, the Rosenbrock function is a performance test problem for optimization algorithms. Its global minimum lies inside a long, narrow, parabolic valley: finding the valley is easy, but converging to the minimum is difficult.

\begin{equation}
F(x) = \sum_{i=0}^{N-2}\left[100(x_{i+1} - x_i^2)^2 + (1 - x_i)^2\right], \qquad x = (x_0, \dots, x_{N-1}) \in \mathbb{R}^N.
\end{equation}


### Primal Function

In [None]:
import jax
from jax import random
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

def rosenbrock(x):
    """
    Evaluate the Rosenbrock "banana" function.

    x : array whose leading axis indexes the coordinates x_0 .. x_{N-1};
        the sum runs over that axis (axis=0).
    Returns the sum of 100*(x_{i+1} - x_i^2)^2 + (1 - x_i)^2 over i.
    """
    head = x[:-1]  # x_0 .. x_{N-2}
    tail = x[1:]   # x_1 .. x_{N-1}
    terms = 100.0 * (tail - head**2.0)**2.0 + (1 - head)**2.0
    return jnp.sum(terms, axis=0)

def plot_vals(vals, grad=None, title=''):
    """
    Plot `vals` (and optionally `grad`) as line series on a fresh figure.

    vals  : 1-D sequence of values, plotted with the label 'values'
    grad  : optional 1-D sequence plotted with the label 'grad'; skipped
            when None
    title : figure title (defaults to an empty string, as before)
    """
    # Open the figure *before* plotting so the data lands on it; creating
    # it after plotting left an extra, empty figure behind.
    plt.figure()
    plt.plot(vals, label='values')
    # Identity check: `!= None` on an array is elementwise (NumPy) or
    # relies on fallback behaviour, so use `is not None`.
    if grad is not None:
        plt.plot(grad, label='grad')
    plt.legend()
    plt.title(title)
    plt.show()

### Primal Driver

In [None]:
def rosenbrockdriver(n):
    """
    Evaluate `rosenbrock` at a random normal vector and print input and result.

    n : positive integer, length of the random input vector
    """
    # Fixed PRNG key so every run uses the same input.
    sample = random.normal(random.PRNGKey(0), (n,), jnp.float64)
    value = rosenbrock(sample)

    print("The input is", sample)
    print("The result of Rosenbrock's is ", value)

In [None]:
# Evaluate the primal on a random vector of length 10.
rosenbrockdriver(n=10)

### First look at derivatives: `jax.grad()`
1. Appropriate for a scalar function
2. Uses reverse mode (it is a wrapper for `jax.jacjvp()`)
3. Assumes a seed of `1.0`. 
   (For reverse mode the shape of the seed must match the primal output.)
4. Produces $\frac{\partial{F(x)}}{\partial{x}}$   

In [None]:
# Reverse-mode derivative of `rosenbrock` with respect to its first
# (and only) positional argument; argnums=0 is jax.grad's default.
grad_rosenbrock = jax.grad(rosenbrock, argnums=0)

def rosenbrockgrad(n):
    """
    Differentiate `rosenbrock` at a random point using `grad_rosenbrock`,
    then plot and print the point and its gradient.

    n : positive integer, length of the random input vector
    """
    prng_key = random.PRNGKey(0)
    point = random.normal(prng_key, (n,), jnp.float64)

    gradient = grad_rosenbrock(point)

    plot_vals(point, grad=gradient)
    print("The input is", point)
    print("The grad result is ", gradient)

In [None]:
# Gradient of the Rosenbrock function at a random point of length 10.
rosenbrockgrad(n=10)

### Forward mode using: `jax.jvp()`
1. https://jax.readthedocs.io/en/latest/jax.html#jax.jvp
2. Uses forward mode: for a function $R^n \rightarrow R^m$ it returns the primal output and the Jacobian-vector product $\frac{\partial{F(x)}}{\partial{x}}\,v$ for a seed (tangent) vector $v$.
3. You must provide the primal and a seed vector.
   (For forward mode the shape of the seed must match the primal input.)
4. Seeding with the $i$-th unit vector produces the $i$-th column of $\frac{\partial{F(x)}}{\partial{x}}$.
5. The code below obtains the entire Jacobian by calling `jax.jvp()` once per input component ($n$ times).

In [None]:
def rosenbrockjvp(n):
    """
    Build the gradient of `rosenbrock` one entry at a time with forward mode.

    Each `jax.jvp` call pushes one canonical basis vector (a row of the
    identity matrix) through `rosenbrock` and yields one directional
    derivative, i.e. one entry of the 1 x n Jacobian. The entries are then
    plotted next to the input and printed.

    n : positive integer, length of the random input vector
    """
    # Draw the input from a normal distribution truncated to [-2.048, 2.048].
    key = random.PRNGKey(0)
    val = random.truncated_normal(key, -2.048, 2.048, (n,), jnp.float64)

    # jax.jvp requires tangents to share the primal's dtype.
    iden_seed = jnp.eye(n, dtype=val.dtype)

    # Collect per-direction results in a list and convert once at the end;
    # jnp.append inside the loop would copy the whole array every iteration.
    tangent_list = []
    for i in range(n):
        _, res = jax.jvp(rosenbrock, (val,), (iden_seed[i],))
        tangent_list.append(res)
    tangents = jnp.asarray(tangent_list)

    plot_vals(val, grad=tangents)
    print("The input is", val)
    print("The jax.jvp result is ", tangents)

In [None]:
# Forward-mode Jacobian of the Rosenbrock function for a length-20 input.
rosenbrockjvp(n=20)

### Reverse mode using: `jax.vjp()`
1. https://jax.readthedocs.io/en/latest/jax.html#jax.vjp
2. Uses reverse mode: for a function $R^n \rightarrow R^m$, `jax.vjp()` returns the primal output together with a function that computes the adjoints (vector-Jacobian products).
3. You must provide the primal; the seed (cotangent) is passed to the returned function.
   (For reverse mode the shape of the seed must match the primal output.)
4. We have chosen a seed of `0.5` by default, so the code below produces $0.5\,\frac{\partial{F(x)}}{\partial{x}}$.
5. Because $F$ is scalar-valued, a single call yields the entire Jacobian (the gradient).
6. Exercise: See how the values change as the seed changes.

In [None]:
def rosenbrockvjp(n, seed=0.5):
    """
    Compute the seeded gradient of `rosenbrock` with one reverse-mode call.

    `jax.vjp` returns the primal value and a pullback function. Because
    `rosenbrock` is scalar-valued, feeding the pullback a scalar seed
    (cotangent) yields `seed * dF/dx`.

    n    : positive integer, length of the random input vector
    seed : scalar cotangent fed to the pullback (default 0.5); vary it to
           see how the adjoints scale
    """
    key = random.PRNGKey(0)
    val = random.normal(key, (n,), jnp.float64)
    primal_out, fun_vjp = jax.vjp(rosenbrock, val)

    # Match the cotangent dtype to the primal output so integer seeds are
    # accepted by the pullback as well.
    cotangent = jnp.asarray(seed, dtype=primal_out.dtype)

    # The pullback returns one adjoint per primal argument; there is one.
    (adj_x,) = fun_vjp(cotangent)
    plot_vals(val, grad=adj_x)

    print("The input is", val)
    print("The jax.vjp result is ", adj_x)

In [None]:
# Reverse-mode adjoints of the Rosenbrock function for a length-7 input.
rosenbrockvjp(n=7)

# Multiple Input Arrays
1. `jax.grad()`, `jax.jvp()`, and `jax.vjp()` all support multiple arrays as *input* (for `jax.grad()`, select them with `argnums`).

In [None]:
def rosenbrock2(x, y):
    """
    Sum of the Rosenbrock function evaluated on two independent arrays.

    x : first array of values (leading axis indexes coordinates)
    y : second array of values (leading axis indexes coordinates)
    Returns rosenbrock(x) + rosenbrock(y), computed inline.
    """
    def _banana(v):
        # One Rosenbrock evaluation, reduced over the leading axis.
        lead, trail = v[:-1], v[1:]
        return jnp.sum(100.0 * (trail - lead**2.0)**2.0 + (1 - lead)**2.0, axis=0)

    return _banana(x) + _banana(y)

In [None]:
# Reverse-mode derivative of `rosenbrock2` with respect to both inputs:
# x (argnum 0) and y (argnum 1). Calling it returns the tuple (dF/dx, dF/dy).
grad_rosenbrock2 = jax.grad(rosenbrock2, argnums=(0, 1))

def rosenbrock2grad(n1, n2):
    """
    Gradient of `rosenbrock2` with respect to both of its input arrays,
    evaluated at two random vectors and printed.

    n1 : positive integer, length of the first random input
    n2 : positive integer, length of the second random input
    """
    # Split the key so the two inputs are independent draws. Reusing one
    # key for both calls would make the inputs correlated; JAX keys must
    # not be reused.
    key1, key2 = random.split(random.PRNGKey(0))
    val1 = random.normal(key1, (n1,), jnp.float64)
    val2 = random.normal(key2, (n2,), jnp.float64)

    # grad_rosenbrock2 differentiates w.r.t. both arguments -> two results.
    result1, result2 = grad_rosenbrock2(val1, val2)
    print("The input is", val1, val2)
    print("The grad result is ", result1, result2)

In [None]:
# Gradients for inputs of lengths 3 and 4.
rosenbrock2grad(n1=3, n2=4)

In [None]:
def rosenbrock2jvp(n1, n2):
    """
    Build the full gradient of `rosenbrock2` over both inputs with forward mode.

    One `jax.jvp` call is made per input coordinate. The tangent is a unit
    vector in the array being probed and zeros in the other array. The
    resulting n1 + n2 directional derivatives are printed in order
    (first-array entries, then second-array entries).

    n1 : positive integer, length of the first random input
    n2 : positive integer, length of the second random input
    """
    # Independent keys for the two inputs; JAX keys must not be reused.
    key1, key2 = random.split(random.PRNGKey(0))
    val1 = random.normal(key1, (n1,), jnp.float64)
    val2 = random.normal(key2, (n2,), jnp.float64)

    # jax.jvp requires each tangent to share its primal's dtype.
    iden_seed1 = jnp.eye(n1, dtype=val1.dtype)
    iden_seed2 = jnp.eye(n2, dtype=val2.dtype)
    zero1 = jnp.zeros(n1, dtype=val1.dtype)
    zero2 = jnp.zeros(n2, dtype=val2.dtype)

    # Accumulate in a list and convert once (avoids O(n^2) jnp.append copies).
    tangent_list = []
    # Directions that perturb only the first input...
    for i in range(n1):
        _, res = jax.jvp(rosenbrock2, (val1, val2), (iden_seed1[i], zero2))
        tangent_list.append(res)
    # ...then directions that perturb only the second input.
    for i in range(n2):
        _, res = jax.jvp(rosenbrock2, (val1, val2), (zero1, iden_seed2[i]))
        tangent_list.append(res)
    tangents = jnp.asarray(tangent_list)

    print("The input is", val1, val2)
    print("The jax.jvp result is ", tangents)

In [None]:
# Forward-mode derivatives for inputs of lengths 3 and 4.
rosenbrock2jvp(n1=3, n2=4)

In [None]:
def rosenbrock2vjp(n1, n2, seed=0.5):
    """
    Seeded adjoints of `rosenbrock2` for both inputs via one reverse-mode call.

    The pullback returned by `jax.vjp` maps a scalar seed (cotangent) to one
    adjoint array per primal input. Here that is `seed * dF/dx` and
    `seed * dF/dy`.

    n1   : positive integer, length of the first random input
    n2   : positive integer, length of the second random input
    seed : scalar cotangent fed to the pullback (default 0.5)
    """
    # Independent keys for the two inputs; JAX keys must not be reused.
    key1, key2 = random.split(random.PRNGKey(0))
    val1 = random.normal(key1, (n1,), jnp.float64)
    val2 = random.normal(key2, (n2,), jnp.float64)
    primal_out, fun_vjp = jax.vjp(rosenbrock2, val1, val2)

    # Match the cotangent dtype to the primal output so integer seeds work.
    cotangent = jnp.asarray(seed, dtype=primal_out.dtype)
    adj_val1, adj_val2 = fun_vjp(cotangent)
    print("The input is", val1, val2)
    print("The jax.vjp result is ", np.array(adj_val1), np.array(adj_val2))

In [None]:
# Reverse-mode adjoints for inputs of lengths 3 and 4.
rosenbrock2vjp(n1=3, n2=4)