<a href="https://colab.research.google.com/github/sriharikrishna/siamcse21/blob/main/norm_nm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Performance testing derivative comparison
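This notebook compares the runtime of different JAX modes of derivative computation (forward mode and reverse mode) for a vector-valued function. Each runtime is measured relative to the cost of evaluating the function itself.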

In [None]:
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
from scipy.optimize import minimize
from jax import random
from jax import float0
import time
from jax import jacfwd, jacrev
import numpy as np

import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator

# Primal Function 
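The function takes an input vector $x$ of length $n$ and computes $m$ of its norms, $\lVert x \rVert_p$ for orders $p = 0, 1, \dots, m-1$. The order-0 "norm" counts the non-zero entries. The function is therefore a map $\mathbb{R}^n \rightarrow \mathbb{R}^m$.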

In [None]:
def fun(x, m):
  """Evaluate the first ``m`` vector norms of ``x``.

  Input:
    x: vector of values.
    m: number of norms to compute.
  Output:
    jax array of length m whose entry ``i`` is
    ``jnp.linalg.norm(x, ord=i)`` for i = 0, ..., m-1
    (for ord=0 this is the count of non-zero entries).
  """
  norm_values = []
  for order in range(m):
    norm_values.append(jnp.linalg.norm(x, ord=order))
  return jnp.array(norm_values)

# Derivative Drivers
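We compute the full $m \times n$ Jacobian in 4 ways:
1. Calling `jax.jvp` `n` times (forward mode, one Jacobian column per call).
2. Calling `jax.vjp` `m` times (reverse mode, one Jacobian row per call).
3. Calling `jax.jacfwd` once.
4. Calling `jax.jacrev` once.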

In [None]:
def jvp_driver(val, n, m):
    """Assemble the m x n Jacobian of ``fun(., m)`` at ``val`` column by column.

    Forward mode: one ``jax.jvp`` call per input direction, i.e. n calls.

    Input:
      val: point at which to differentiate (vector of length n).
      n:   number of inputs, i.e. the length of ``val``.
      m:   number of outputs (norms) produced by ``fun``.
    Output:
      numpy array of shape (m, n); column i is the directional derivative
      of ``fun(., m)`` along the i-th unit vector.
    """
    val = jnp.asarray(val)
    tangents = np.zeros((m, n))

    def f(x):
        return fun(x, m)

    # Rows of the identity are the unit seed vectors e_0, ..., e_{n-1}.
    # This replaces jax.ops.index_update, which has been removed from JAX.
    # jax.jvp requires the tangent dtype to match the primal dtype.
    seeds = jnp.eye(n, dtype=val.dtype)
    for i in range(n):
        # One forward-mode pass per input direction.
        _, column = jax.jvp(f, (val,), (seeds[i],))
        tangents[:, i] = np.asarray(column)
    return tangents


In [None]:
def vjp_driver(val, n, m):
    """Assemble the m x n Jacobian of ``fun(., m)`` at ``val`` row by row.

    Reverse mode: ``jax.vjp`` is called once to obtain the pullback, which is
    then applied to each of the m output seed vectors (m pullback calls).

    Input:
      val: point at which to differentiate (vector of length n).
      n:   number of inputs, i.e. the length of ``val``.
      m:   number of outputs (norms) produced by ``fun``.
    Output:
      numpy array of shape (m, n); row i is the gradient of the i-th output.
    """
    def f(x):
        return fun(x, m)

    # A single call records the primal evaluation; the returned pullback can
    # then be applied to many cotangent seeds.
    primals, fun_vjp = jax.vjp(f, val)

    adjoints = np.zeros((m, n))
    # Rows of the identity are the unit cotangent seeds. This replaces
    # jax.ops.index_update, which has been removed from JAX. The cotangent
    # dtype must match the dtype of the function output.
    seeds = jnp.eye(m, dtype=primals.dtype)
    for i in range(m):
        # The pullback returns a tuple with one cotangent per primal input.
        (row,) = fun_vjp(seeds[i])
        adjoints[i, :] = np.asarray(row)
    return adjoints

In [None]:
# Whole-Jacobian transforms of `fun` with respect to its first argument `x`.
# They are called as fun_jacrev(x, m) / fun_jacfwd(x, m); `m` is passed
# through undifferentiated.
fun_jacrev = jax.jacrev(fun, argnums=0)
fun_jacfwd = jax.jacfwd(fun, argnums=0)

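# Compute Runtimes
Each derivative runtime is divided by the runtime of the primal function at the same $(n, m)$. The plots therefore show the cost of differentiation relative to one function evaluation.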

In [None]:
grid_size = 4
# Grid index k -> problem size 4**k. The same sizes are used for the number
# of inputs (n) and the number of outputs (m).
n_vals = dict(enumerate(4 ** k for k in range(grid_size)))
m_vals = dict(enumerate(4 ** k for k in range(grid_size)))

def compute_times(derfun):
  """Time ``derfun(val, n, m)`` for every (n, m) pair of the benchmark grid.

  Input:
    derfun: callable taking (val, n, m), where val is a length-n float64
            vector drawn from a fixed PRNG key.
  Output:
    numpy array ``times`` of shape (len(n_vals), len(m_vals)), where
    ``times[idxn, idxm]`` is the wall-clock time in seconds for
    n = n_vals[idxn] and m = m_vals[idxm].
  """
  times = np.zeros((len(n_vals), len(m_vals)))
  for idxn, n in n_vals.items():
    val = random.normal(random.PRNGKey(0), (n,), jnp.float64)
    for idxm, m in m_vals.items():
      tic = time.perf_counter()
      u = derfun(val, n, m)
      # JAX dispatches work asynchronously. Wait for the result so the timing
      # covers the computation and not just its dispatch. NumPy results
      # (from the jvp/vjp drivers) are already materialized and lack this method.
      if hasattr(u, "block_until_ready"):
        u.block_until_ready()
      toc = time.perf_counter()
      times[idxn, idxm] = toc - tic
      print("Completed ", "n=", n, "\t m=", m, "\t time ", toc - tic)
  return times

# Baseline: cost of evaluating the primal function itself.
times_fun = compute_times(lambda val, n, m: fun(val, m))

# Each derivative timing is expressed relative to the primal cost.
times_vjp = compute_times(vjp_driver) / times_fun
times_jvp = compute_times(jvp_driver) / times_fun
times_jf = compute_times(lambda val, n, m: fun_jacfwd(val, m)) / times_fun
times_jr = compute_times(lambda val, n, m: fun_jacrev(val, m)) / times_fun

# Plot Runtimes

In [None]:
def plot_fun(ax, n_vals, m_vals, times, maxval, title=None):
  """Draw ``times`` as a 3-D surface over the (#inputs, #outputs) grid.

  Input:
    ax:     matplotlib axes created with projection="3d".
    n_vals: dict whose keys are the input-size grid indices k; x ticks are
            labelled $4^k$.
    m_vals: dict whose keys are the output-size grid indices k; y ticks are
            labelled $4^k$.
    times:  2-D array indexed [idxn, idxm], as produced by compute_times.
    maxval: reference maximum; the z axis spans [0, 1.1 * maxval].
    title:  optional subplot title.
  """
  coordsn = list(n_vals.keys())
  coordsm = list(m_vals.keys())

  # indexing="ij" gives X, Y shape (len(n_vals), len(m_vals)), so X[i, j] and
  # Y[i, j] line up with times[i, j] = times[idxn, idxm]. The default "xy"
  # indexing would transpose the grid and draw times(n, m) at position (m, n).
  X, Y = np.meshgrid(coordsn, coordsm, indexing="ij")
  ax.plot_surface(X, Y, times, cmap=cm.coolwarm, antialiased=True)
  # Customize the z axis.
  ax.set_zlim(0, maxval * 1.1)

  # The extra braces keep multi-digit exponents inside the superscript.
  ax.set_xticks(coordsn, ["$4^{{{:d}}}$".format(i) for i in coordsn])
  ax.set_yticks(coordsm, ["$4^{{{:d}}}$".format(i) for i in coordsm])

  ax.invert_yaxis()

  ax.title.set_text(title)
  ax.set_xlabel('#inputs')
  ax.set_ylabel('#outputs')

# 2 x 2 panel of derivative runtimes (each divided by the primal runtime).
# All panels share one z range so they can be compared directly.
fig, ax = plt.subplots(subplot_kw={"projection": "3d"}, nrows=2, ncols=2)
panels = [
    (ax[0, 0], times_vjp, "Time: jax.vjp()"),
    (ax[0, 1], times_jvp, "Time: jax.jvp()"),
    (ax[1, 0], times_jr, "Time: jax.jacrev()"),
    (ax[1, 1], times_jf, "Time: jax.jacfwd()"),
]
maxval = max(np.amax(panel_times) for _, panel_times, _ in panels)
for panel_ax, panel_times, panel_title in panels:
    plot_fun(panel_ax, n_vals, m_vals, panel_times, maxval, title=panel_title)
plt.show()