<a href="https://colab.research.google.com/github/sriharikrishna/siamcse23/blob/main/rosenbrock_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Rosenbrock function
**Notebook Authors**:
[Jan Hückelheim](https://www.anl.gov/profile/jan-christian-hueckelheim)
[Sri Hari Krishna Narayanan](https://www.mcs.anl.gov/~snarayan) 
[Ludger Paehler](https://ludger.fyi/) 


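Also known as Rosenbrock's valley or Rosenbrock's banana function, it is a standard test problem for optimization algorithms. Its global minimum, $F(x) = 0$ at $x = (1, \dots, 1)$, lies inside a long, narrow, curved valley. Reaching the valley is easy, but converging along it to the minimum is difficult.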

\begin{equation}
F(x) = \sum_{i=0}^{N-2} \left[ 100\,(x_{i+1} - x_i^2)^2 + (1 - x_i)^2 \right], \qquad x = (x_0, \dots, x_{N-1}) \in \mathbb{R}^N.
\end{equation}

<center><img src="https://upload.wikimedia.org/wikipedia/commons/3/32/Rosenbrock_function.svg" width="40%" /></center>
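To make the indices concrete, here is a minimal loop-based sketch of the sum. The name `rosenbrock_loop` is hypothetical and is not part of the notebook. It pairs each $x_i$ with $x_{i+1}$, so an input of length $N$ contributes $N-1$ terms.

```python
def rosenbrock_loop(x):
    """Hypothetical reference version: evaluate F(x) term by term with a plain loop."""
    total = 0.0
    # i runs from 0 to N-2, so the last term pairs x[N-2] with x[N-1]
    for i in range(len(x) - 1):
        total += 100.0 * (x[i + 1] - x[i] ** 2) ** 2 + (1.0 - x[i]) ** 2
    return total

print(rosenbrock_loop([1.0] * 5))   # 0.0 (global minimum)
print(rosenbrock_loop([0.5] * 10))  # 58.5 (9 terms of 6.5 each)
```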


### 1. Primal Function
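1. Evaluates the Rosenbrock function for an arbitrary input vector.
2. Defines `plot_vals`, a helper that plots the input vector and, optionally, its derivatives against the array index.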

In [1]:
import torch
import matplotlib.pyplot as plt

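def rosenbrock(x):
    """
    Input: x, a tensor of values (the sum runs over its first dimension)
    Output: Result of Rosenbrock's banana function
    """
    # x[:-1] holds x_0..x_{N-2} and x[1:] holds x_1..x_{N-1}
    result = torch.sum(100.0 * (x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0, dim=0)
    return result

def plot_vals(vals, grad=None):
    """
    Input: vals, the primal input vector; grad, optional derivatives to overlay
    Output: a line plot of both against the array index
    """
    plt.plot(vals, label='primal input')
    if grad is not None:
        # accept tensors as well as arrays; detach in case grad is part of a graph
        if torch.is_tensor(grad):
            grad = grad.detach().numpy()
        plt.plot(grad, label='derivatives')
    plt.xlabel('array index')
    plt.legend()
    plt.show()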
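The slicing in `rosenbrock` replaces an explicit loop: `x[:-1]` and `x[1:]` line up every pair $(x_i, x_{i+1})$ elementwise. Because the sum runs over the first dimension, the same expression also evaluates several points at once when they are stored as the columns of a 2-D tensor. The sketch below repeats the definition so it runs on its own; the batched use is our illustration, not something the notebook relies on.

```python
import torch

def rosenbrock(x):
    # Same formula as the notebook's rosenbrock; dim=0 is the "array index" axis
    return torch.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2, dim=0)

# The two slices line up the (x_i, x_{i+1}) pairs without a loop
x = torch.arange(5.0)
print(x[:-1])  # x_0..x_3: values 0, 1, 2, 3
print(x[1:])   # x_1..x_4: values 1, 2, 3, 4

# A (N, M) tensor is treated as M separate points, one per column
points = torch.stack([torch.ones(10), torch.full((10,), 0.5)], dim=1)  # shape (10, 2)
print(rosenbrock(points))  # one value per column: 0 (the minimum) and 58.5
```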

### 2. Primal Driver
1. Calls the Rosenbrock function.
2. Uses an arbitrarily chosen input, the array `[0.5, 0.5, ..., 0.5]`.

In [2]:
def fun_driver(n):
    """
    Input: n, array length
    Output: prints the input and the value of Rosenbrock's banana function
    """
    # create the input array
    val = torch.full((n,), 0.5)

    # compute the result
    result = rosenbrock(val)
    
    print("The input is", val)
    print("The value of the Rosenbrock function is", result)

In [None]:
fun_driver(10)
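For a constant input $x = (c, \dots, c)$ all $N-1$ terms of the sum are identical, so $F = (N-1)\left[100\,(c - c^2)^2 + (1 - c)^2\right]$. With $c = 0.5$ and $N = 10$ this gives $9 \times 6.5 = 58.5$, which is the value the driver should print. A small sketch of that check; `constant_input_value` is a hypothetical helper, not part of the notebook.

```python
import torch

def rosenbrock(x):
    return torch.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2, dim=0)

def constant_input_value(n, c):
    """Hypothetical helper: closed form of F for x = (c, ..., c), i.e. N-1 identical terms."""
    return (n - 1) * (100.0 * (c - c ** 2) ** 2 + (1.0 - c) ** 2)

n, c = 10, 0.5
print(constant_input_value(n, c))              # 58.5
print(rosenbrock(torch.full((n,), c)).item())  # 58.5
```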

### 3. First look at derivatives: `torch.autograd.grad()`
1. https://pytorch.org/docs/stable/autograd.html
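2. Reverse-mode automatic differentiation computes the vector-Jacobian product $v\cdot J$. For a scalar-valued function ($\mathbb{R}^n \rightarrow \mathbb{R}$) the Jacobian $J$ is a $1 \times n$ row vector, so with the scalar $v = 1$ the product is exactly the gradient. `torch.autograd.grad()` uses $v = 1$ implicitly when the output is a scalar.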
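The role of $v$ is easier to see with a vector-valued function. A minimal sketch, assuming a hypothetical helper `rosenbrock_terms` that returns the $N-1$ individual terms of the sum ($\mathbb{R}^N \rightarrow \mathbb{R}^{N-1}$): passing $v$ through the `grad_outputs` argument of `torch.autograd.grad()` yields $v \cdot J$ without forming $J$, and with $v$ all ones this equals the gradient of $F$. The explicit Jacobian from `torch.autograd.functional.jacobian` is built only to check the result.

```python
import torch

def rosenbrock_terms(x):
    # Hypothetical helper: the N-1 individual terms of the sum, R^N -> R^(N-1)
    return 100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2

x = torch.full((10,), 0.5, requires_grad=True)
terms = rosenbrock_terms(x)

# Reverse mode: grad_outputs supplies v, and the result is the product v . J
v = torch.ones_like(terms)
(vjp,) = torch.autograd.grad(outputs=terms, inputs=x, grad_outputs=v)

# For comparison only: build the full (N-1) x N Jacobian and multiply explicitly
J = torch.autograd.functional.jacobian(rosenbrock_terms, x.detach())
print(J.shape)                     # torch.Size([9, 10])
print(torch.allclose(vjp, v @ J))  # True
```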

In [8]:
# Create a function that computes the derivatives.
def grad_driver(n):
    """
    Input: n, array length
    Output: plots and prints the derivatives of Rosenbrock's banana function
    """
    # create the input array and let autograd record operations on it
    val = torch.full((n,), 0.5, requires_grad=True)

    # compute the result
    result = rosenbrock(val)

    # compute the derivatives; torch.autograd.grad returns a tuple with one entry per input
    grad_vals = torch.autograd.grad(outputs=result, inputs=val)
    
    plot_vals(val.detach().numpy(), grad=grad_vals[0])
    print("The input is", val)
    print("The gradient is", grad_vals[0])

In [None]:
grad_driver(10)
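Differentiating the sum by hand gives a reference to compare autograd against. Each $x_j$ appears in term $i = j$ (when $j < N-1$) and in term $i = j-1$ (when $j > 0$):

$$\frac{\partial F}{\partial x_j} = \underbrace{-400\,x_j\,(x_{j+1} - x_j^2) - 2\,(1 - x_j)}_{\text{if } j < N-1} + \underbrace{200\,(x_j - x_{j-1}^2)}_{\text{if } j > 0}.$$

At $x = (0.5, \dots, 0.5)$ this gives $-51$ for the first entry, $-1$ for the interior entries, and $50$ for the last. The helper `rosenbrock_grad_analytic` below is hypothetical and written only for this check.

```python
import torch

def rosenbrock(x):
    return torch.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2, dim=0)

def rosenbrock_grad_analytic(x):
    """Hypothetical helper: hand-derived gradient of F, used to check autograd."""
    g = torch.zeros_like(x)
    # contribution of term i = j (exists for j < N-1)
    g[:-1] += -400.0 * x[:-1] * (x[1:] - x[:-1] ** 2) - 2.0 * (1 - x[:-1])
    # contribution of term i = j-1 (exists for j > 0)
    g[1:] += 200.0 * (x[1:] - x[:-1] ** 2)
    return g

x = torch.full((10,), 0.5, requires_grad=True)
(g_auto,) = torch.autograd.grad(rosenbrock(x), x)
g_hand = rosenbrock_grad_analytic(x.detach())
print(g_hand)                          # -51 first, -1 in the interior, 50 last
print(torch.allclose(g_auto, g_hand))  # True
```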

### 4. Second look at derivatives: `torch.Tensor.backward()` 
1. https://pytorch.org/docs/stable/autograd.html
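2. `backward()` computes the same product $v\cdot J$ (again with $v = 1$ for a scalar output), but instead of returning it, it accumulates the result into the `.grad` attribute of every leaf tensor created with `requires_grad=True`.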
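"Accumulates" matters in practice: calling `backward()` again adds the new gradient to whatever is already in `.grad`. A minimal sketch of that behavior and of resetting `.grad` (here by assigning `None`; `x.grad.zero_()` would also work):

```python
import torch

def rosenbrock(x):
    return torch.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2, dim=0)

x = torch.full((10,), 0.5, requires_grad=True)

rosenbrock(x).backward()  # x.grad now holds the gradient
first = x.grad.clone()

rosenbrock(x).backward()  # a second backward pass ADDS to x.grad
print(torch.allclose(x.grad, 2 * first))  # True

x.grad = None             # reset before computing a fresh gradient
rosenbrock(x).backward()
print(torch.allclose(x.grad, first))      # True
```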

In [6]:
# Create a function that computes the derivatives.
def backward_driver(n):
    """
    Input: n, array length
    Output: plots and prints the derivatives of Rosenbrock's banana function
    """
    # create the input array and let autograd record operations on it
    val = torch.full((n,), 0.5, requires_grad=True)
    
    # compute the result
    result = rosenbrock(val)

    # compute the derivatives; they are stored in val.grad
    result.backward()
    
    plot_vals(val.detach().numpy(), grad=val.grad)
    print("The input is", val)
    print("The gradient is", val.grad)

In [None]:
backward_driver(10)
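The two routes should produce the same gradient; they differ only in where the result ends up. A small sketch that checks this at a random point (the seed and the point are arbitrary choices for illustration):

```python
import torch

def rosenbrock(x):
    return torch.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2, dim=0)

torch.manual_seed(0)
x0 = torch.randn(10)

# Route 1: torch.autograd.grad returns the gradient as a tuple
x1 = x0.clone().requires_grad_(True)
(g1,) = torch.autograd.grad(rosenbrock(x1), x1)

# Route 2: backward() stores it in the leaf tensor's .grad attribute
x2 = x0.clone().requires_grad_(True)
rosenbrock(x2).backward()
g2 = x2.grad

print(torch.allclose(g1, g2))  # True
print(x1.grad)                 # None: torch.autograd.grad does not populate .grad
```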