<a href="https://colab.research.google.com/github/sriharikrishna/siamcse21/blob/main/rosenbrock_dot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Rosenbrock Dot Product Test
Here we first rewrite the Rosenbrock example from earlier as calls to two functions:

`result = rosenbrock_vec_2(rosenbrock_vec_1(x))`

In [60]:
import jax
from jax import random
import jax.numpy as jnp
import numpy as np

def rosenbrock(x):
    """
    Computes Rosenbrock's banana function
    x : array of values
    """
    y = rosenbrock_vec_1(x)
    z = rosenbrock_vec_2(y)
    return z

def rosenbrock_vec_1(x):
    """
    Computes the individual summation terms of Rosenbrock's banana
    x : array of values
    """
    y = (100.0 * (x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0)
    return y

def rosenbrock_vec_2(y):
    """
    Sums the individual terms of Rosenbrock's banana
    y : array of summation terms
    """
    z = sum(y)
    return z
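
As a quick sanity check, the two-step form can be compared against a single-expression version of the same function. The sketch below is self-contained: it restates the two functions (using `jnp.sum` for the second step), and `rosenbrock_reference` is a hypothetical name introduced only for this comparison.

```python
import jax.numpy as jnp

def rosenbrock_vec_1(x):
    # Individual summation terms, as in the cell above
    return 100.0 * (x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0

def rosenbrock_vec_2(y):
    # Sum of the terms (jnp.sum used here instead of the builtin sum)
    return jnp.sum(y)

def rosenbrock_reference(x):
    # Hypothetical helper: the same function written as one expression
    return jnp.sum(100.0 * (x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0)

x_check = jnp.array([0.5, -1.0, 2.0])
composed = rosenbrock_vec_2(rosenbrock_vec_1(x_check))
reference = rosenbrock_reference(x_check)
print("composed", composed, "reference", reference)

# At x = (1, ..., 1) every summation term vanishes
at_min = rosenbrock_vec_2(rosenbrock_vec_1(jnp.ones(5)))
print("value at ones", at_min)
```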

We then compute the forward-mode derivatives of each function.
First, create a random input vector `x`.

In [61]:
#create a random array
n=10
key = random.PRNGKey(0)
x = random.normal(key, (n,), jnp.float64)
print("x", x)

x [-0.372111    0.26423106 -0.18252774 -0.7368198  -0.44030386 -0.15214427
 -0.6713536  -0.5908642   0.73168874  0.5673025 ]


## Forward mode
Create a random seed (tangent) vector `xd` and propagate it through both functions with `jax.jvp`.

In [62]:
#Forward mode; create a random seed; compute gradients
key = random.PRNGKey(34234)
xd = random.normal(key, (x.shape), jnp.float64)
print("xd", xd)

xd [ 0.25664428  2.0618527   0.17896675 -0.05648184 -1.0964983  -0.7214161
 -1.4962182   0.49886012 -0.35361898 -0.07395974]


In [63]:
y, yd = jax.jvp(rosenbrock_vec_1, (x,),(xd,))
print("yd",yd)

yd [ 55.96148    42.925358   -1.7865543 232.18044   119.90302   239.9787
 319.5841     16.462196    3.0224354]


In [64]:
z, zd = jax.jvp(rosenbrock_vec_2, (y,), (yd,))
print("zd",zd)

zd 1028.2312
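
The chained `jax.jvp` calls should give the same directional derivative as one `jax.jvp` call on the composed function, and the same value as the dot product of the gradient with the seed. The following is a minimal sketch of that check, assuming small fixed inputs rather than the random vectors above; the names `f1`, `f2`, and `f` are shorthand introduced here.

```python
import jax
import jax.numpy as jnp

def f1(x):
    # Same terms as rosenbrock_vec_1
    return 100.0 * (x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0

def f2(y):
    # Same reduction as rosenbrock_vec_2
    return jnp.sum(y)

def f(x):
    return f2(f1(x))

# Fixed illustrative inputs (assumed values, not taken from the notebook)
x = jnp.linspace(-1.0, 1.0, 5)
xd = jnp.array([1.0, -2.0, 0.5, 0.0, 3.0])

# Two-step forward mode, as in the cells above
y, yd = jax.jvp(f1, (x,), (xd,))
z, zd = jax.jvp(f2, (y,), (yd,))

# One-step forward mode on the composition
z_direct, zd_direct = jax.jvp(f, (x,), (xd,))

# Directional derivative from the full gradient
zd_grad = jnp.dot(jax.grad(f)(x), xd)

print("zd", zd, "zd_direct", zd_direct, "zd_grad", zd_grad)
```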


## Reverse mode
Create a random seed (adjoint) `zb` and propagate it backwards through both functions with `jax.vjp` to obtain `yb` and `xb`.

In [65]:
#Reverse mode; create a random seed; compute adjoints
key = random.PRNGKey(134534)
zb = random.normal(key, (1,), jnp.float64)[0]
print("zb",zb)

zb -0.47497585


In [76]:
_, fun_vjp = jax.vjp(rosenbrock_vec_2, y)
# fun_vjp returns a 1-tuple, so np.array(...) has shape (1, n-1)
yb = np.array(fun_vjp(zb))
#yb = -yb
print("yb", yb)

yb [[-0.47497585 -0.47497585 -0.47497585 -0.47497585 -0.47497585 -0.47497585
  -0.47497585 -0.47497585 -0.47497585]]


In [77]:
_, fun_vjp = jax.vjp(rosenbrock_vec_1, x)
xb = np.array(fun_vjp(yb[0]))
print("xb", xb)

xb [[ -7.5877934 -23.916172   51.802147  212.44696   123.71325    54.039146
  200.41623    57.509766  -31.647968   -3.0335894]]
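
Because the output is a scalar, the chained `jax.vjp` calls should reproduce `zb` times the gradient of the composed function. The sketch below checks this, assuming small fixed inputs and an arbitrary seed value of `-0.5`; it unpacks the 1-tuple returned by each pullback instead of wrapping it in `np.array`, which is a stylistic choice and not the approach used in the cells above.

```python
import jax
import jax.numpy as jnp

def f1(x):
    # Same terms as rosenbrock_vec_1
    return 100.0 * (x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0

def f2(y):
    # Same reduction as rosenbrock_vec_2
    return jnp.sum(y)

def f(x):
    return f2(f1(x))

# Fixed illustrative input (assumed values, not taken from the notebook)
x = jnp.linspace(-1.0, 1.0, 5)

# Forward sweep, keeping the pullback of each step
y, vjp1 = jax.vjp(f1, x)
z, vjp2 = jax.vjp(f2, y)

# Scalar seed with the same shape and dtype as z; -0.5 is arbitrary
zb = jnp.full_like(z, -0.5)

# Reverse sweep: each pullback returns a tuple with one entry per input
(yb,) = vjp2(zb)
(xb,) = vjp1(yb)

# Reference: seed times the gradient of the composition
xb_direct = zb * jax.grad(f)(x)

print("yb", yb)
print("xb", xb)
print("xb_direct", xb_direct)
```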


## Dot Products
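Compute the dot products. If the forward-mode and reverse-mode derivatives are consistent, `dot(xb, xd)`, `dot(yb, yd)`, and `zb*zd` agree up to floating-point round-off, because each reverse-mode step applies the transpose of the Jacobian used in the corresponding forward-mode step.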

In [78]:
#Compute the dot products
print("jnp.dot(xb,xd)", jnp.dot(xb,xd))
print("jnp.dot(yb,yd)", jnp.dot(yb,yd))
print("jnp.dot(zb,zd)",jnp.dot(zb,zd))

jnp.dot(xb,xd) [-488.38498]
jnp.dot(yb,yd) [-488.38498]
jnp.dot(zb,zd) -488.38498
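
The same check can be wrapped in a small helper that works for any single-input function. This is a minimal sketch: `dot_product_test` is a hypothetical name introduced here, and the input, tangent, and adjoint values are arbitrary fixed numbers chosen for illustration.

```python
import jax
import jax.numpy as jnp

def dot_product_test(f, x, xd, yb):
    """Hypothetical helper: returns (<yb, J xd>, <J^T yb, xd>) for y = f(x)."""
    _, yd = jax.jvp(f, (x,), (xd,))   # forward mode: yd = J xd
    _, f_vjp = jax.vjp(f, x)
    (xb,) = f_vjp(yb)                 # reverse mode: xb = J^T yb
    return jnp.vdot(yb, yd), jnp.vdot(xb, xd)

def rosenbrock_vec_1(x):
    return 100.0 * (x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0

# Arbitrary fixed values (assumptions for illustration)
x = jnp.linspace(-1.0, 1.0, 5)
xd = jnp.array([1.0, -2.0, 0.5, 0.0, 3.0])
yb = jnp.array([0.5, -1.0, 2.0, 0.25])

lhs, rhs = dot_product_test(rosenbrock_vec_1, x, xd, yb)
print("yb . yd =", lhs)
print("xb . xd =", rhs)
```

The two printed values should match to within single-precision round-off; a larger discrepancy would point to an inconsistency between the forward-mode and reverse-mode derivatives.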


## Exercises
1. Use different seeds for `x`, `xd`, and `zb`.
2. Edit the value of `yd` or `yb` after it has been calculated and observe how the dot products change.
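
For the first exercise, one possible way to obtain independent random vectors is to split a single key with `jax.random.split` instead of picking three unrelated integers. The sketch below assumes an arbitrary base seed of `2021`; the names `key_x`, `key_xd`, and `key_zb` are introduced here for illustration.

```python
import jax.numpy as jnp
from jax import random

n = 10
key = random.PRNGKey(2021)  # arbitrary base seed (assumption)

# Derive three independent subkeys from the base key
key_x, key_xd, key_zb = random.split(key, 3)

x = random.normal(key_x, (n,))    # input vector
xd = random.normal(key_xd, (n,))  # forward-mode seed
zb = random.normal(key_zb, ())    # scalar reverse-mode seed

print("x", x)
print("xd", xd)
print("zb", zb)
```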