
In [None]:
# Importing the necessary modules
import numpy as np
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import multivariate_normal
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
key = random.PRNGKey(29)

# **What is JAX?**
JAX is an automatic differentiation (AD) toolbox developed by a group of people at Google Brain and the open-source community. It aims to bring NumPy-style differentiable programming onto accelerators such as GPUs and TPUs. At the highest level, JAX combines two earlier projects, XLA and Autograd, to accelerate your favorite linear-algebra-based projects.

Python, as an interpreted programming language, is slow by nature. The interpreter executes a program one statement at a time, and multithreaded computations are serialized by the global interpreter lock (GIL). So, in order to train networks at scale, we need fast compilation and parallel computing! Compiled CUDA kernels, for example, provide a set of primitive instructions that can be executed in a massively parallel fashion on an NVIDIA GPU. The computation graph generated by PyTorch or TensorFlow can then be compiled into a sequence of executions (basic operations, e.g. add/multiply) with precompiled kernels. Ideally, we want to launch as few kernels as possible, because this reduces communication time and memory load. This is where XLA comes in. It optimizes memory bandwidth by “fusing” operations, so that fewer intermediate results have to be written back to and read from memory. In practice this can significantly speed things up.
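To see what XLA compilation looks like from the user's side, here is a minimal sketch using `jax.jit`. The function `tanh_gelu` is only an illustrative chain of elementwise operations and is not part of this notebook. Whether and how XLA fuses these operations depends on the backend, so the sketch only checks that the compiled and eager results agree; it does not measure the speed-up.

```python
import jax
import jax.numpy as jnp

# A chain of elementwise operations (tanh approximation of GELU, chosen only
# as an illustration). Run eagerly, each jnp call is dispatched separately
# and produces its own intermediate array.
def tanh_gelu(x):
    return 0.5 * x * (1.0 + jnp.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

# jax.jit traces the function once and hands the whole computation to XLA,
# which may fuse the elementwise operations into fewer kernels.
tanh_gelu_jit = jax.jit(tanh_gelu)

x = jnp.linspace(-3.0, 3.0, 1_000_000)
y_eager = tanh_gelu(x)
# block_until_ready() waits for JAX's asynchronous dispatch to finish.
y_jit = tanh_gelu_jit(x).block_until_ready()

# Compilation changes how the result is computed, not what is computed.
print(jnp.allclose(y_eager, y_jit, atol=1e-5))
```

To compare speeds in a notebook, one could time both calls with `%timeit`; note that the first call to `tanh_gelu_jit` includes compilation time and should be excluded.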

Autograd, on the other hand, provides automatic differentiation support for large parts of standard Python. AD forms the backbone of optimization in deep learning: it evaluates the derivative of a composite function at a given input point. For a vast set of basic math operations we already know the functional form of their derivatives. Thanks to the chain rule, we can combine these elementary derivatives instead of building one huge symbolic expression, at the cost of storing intermediate values in memory. This allows us to compute gradients, which we can then use to optimize the parameters of our models with our favorite gradient-based optimization algorithm. Broadly speaking, there are two types of automatic differentiation: forward mode and reverse mode (backpropagation is a special case of reverse mode). JAX supports AD for standard NumPy functions as well as for Python control flow, such as loops, that transforms numerical variables.
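Both AD modes can be tried directly in JAX. The following is a minimal sketch; the function `f` and the helper `power_loop` are hypothetical examples chosen so that the derivatives can be checked by hand. `jax.grad` uses reverse mode, while `jax.jvp` computes a forward-mode Jacobian-vector product.

```python
import jax
import jax.numpy as jnp

# A small composite function f(x) = sin(x) * x**2, built from primitives
# whose derivatives JAX already knows.
def f(x):
    return jnp.sin(x) * x ** 2

# Reverse mode: jax.grad returns a new function computing df/dx.
df = jax.grad(f)

# Hand-derived derivative via the product rule, for comparison.
def df_exact(x):
    return jnp.cos(x) * x ** 2 + 2.0 * x * jnp.sin(x)

x0 = 1.5
print(df(x0), df_exact(x0))

# Forward mode: jax.jvp pushes a tangent (here 1.0) through f with the value.
value, tangent = jax.jvp(f, (x0,), (1.0,))
print(tangent)

# An ordinary Python loop is unrolled while JAX traces it, so it can be
# differentiated like any other code.
def power_loop(x, n=3):
    y = 1.0
    for _ in range(n):
        y = y * x
    return y  # equals x ** n

print(jax.grad(power_loop)(2.0))  # d/dx x**3 at x = 2 -> 12.0
```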

In principle, these ingredients make JAX’s applicability much broader than deep learning and provide another step into the era of “Code 2.0” and differentiable programming. Many recent projects focus on DL applications (such as rlax or Haiku, two of DeepMind’s recent open-source releases), but there are also other examples that benefit from both Numba-like speed-ups and some gradient sauce on top (e.g. Hamiltonian Monte Carlo).

# Multivariate Gaussian distribution
The multivariate Gaussian distribution of an $n$-dimensional vector $\boldsymbol{x} = (x_1, x_2, \dots, x_n)$ may be written as:

\begin{align} p(\boldsymbol{x}; \boldsymbol{\mu}, \boldsymbol{\Sigma}) = \frac{1}{\sqrt{(2\pi)^n|\boldsymbol{\Sigma}|}} \exp\left( -\frac{1}{2}(\boldsymbol{x}-\boldsymbol{\mu})^\mathrm{T}\boldsymbol{\Sigma}^{-1}(\boldsymbol{x}-\boldsymbol{\mu})\right),\end{align}

where $\boldsymbol{\mu}$ is the $n$-dimensional mean vector, $\boldsymbol{\Sigma}$ is the $n \times n$ covariance matrix, and $|\boldsymbol{\Sigma}|$ is its determinant.
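As a sanity check on the formula, the sketch below evaluates it term by term for a single point and compares the result with `jax.scipy.stats.multivariate_normal.pdf`. It assumes the same $\boldsymbol{\mu}$ and $\boldsymbol{\Sigma}$ used in the plotting code. `gaussian_pdf` is a hypothetical helper written only for this check.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

# Same mean vector and covariance matrix as the plotting cell.
mu = jnp.array([0., 1.])
Sigma = jnp.array([[1., 0.6], [0.6, 2.]])

def gaussian_pdf(x, mu, Sigma):
    """Evaluate the density formula term by term for a single point x."""
    n = mu.shape[0]
    diff = x - mu
    # Quadratic form (x - mu)^T Sigma^{-1} (x - mu); solve() avoids forming
    # the inverse explicitly.
    quad = diff @ jnp.linalg.solve(Sigma, diff)
    # Normalizing constant sqrt((2*pi)^n * |Sigma|).
    norm = jnp.sqrt((2 * jnp.pi) ** n * jnp.linalg.det(Sigma))
    return jnp.exp(-0.5 * quad) / norm

x = jnp.array([0.5, 0.2])
print(gaussian_pdf(x, mu, Sigma))
print(multivariate_normal.pdf(x, mu, Sigma))  # library reference value

# At x = mu the exponent vanishes, so p = 1 / (2*pi*sqrt(|Sigma|)),
# with |Sigma| = 1*2 - 0.6**2 = 1.64.
print(gaussian_pdf(mu, mu, Sigma), 1 / (2 * jnp.pi * jnp.sqrt(1.64)))
```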

Visualizing the magnitude of $p(\boldsymbol{x}; \boldsymbol{\mu}, \boldsymbol{\Sigma})$ as a function of all $n$ dimensions requires a plot in $n+1$ dimensions, so visualizing this distribution for $n > 2$ is tricky. The code below calculates and visualizes the case $n = 2$, the bivariate Gaussian distribution.

The plot uses the colormap viridis, which was introduced in

In [None]:
# Our 2-dimensional distribution will be over variables X and Y
# The grid is built with NumPy rather than JAX because JAX arrays are
# immutable, and `pos` below is filled by in-place assignment.
N = 60
X = np.linspace(-3, 3, N)
Y = np.linspace(-3, 4, N)
X, Y = np.meshgrid(X, Y)

# Mean vector and covariance matrix
mu = jnp.array([0., 1.])
Sigma = jnp.array([[1., 0.6], [0.6, 2.]])

# Pack X and Y into a single 3-dimensional array
pos = np.empty(X.shape + (2,))
pos[:, :, 0] = X
pos[:, :, 1] = Y

def multivariate_gaussian(pos, mu, Sigma):
    """Return the multivariate Gaussian density evaluated on array pos.

    pos is an array constructed by packing the meshed arrays of variables
    x_1, x_2, ..., x_n into its _last_ dimension.
    """
    n = mu.shape[0]
    Sigma_det = jnp.linalg.det(Sigma)
    Sigma_inv = jnp.linalg.inv(Sigma)
    # Normalizing constant sqrt((2*pi)^n * |Sigma|); named `norm` so it
    # does not shadow the grid size N defined above.
    norm = jnp.sqrt((2 * np.pi)**n * Sigma_det)
    # This einsum call calculates (x-mu)^T Sigma^-1 (x-mu) in a vectorized
    # way across all grid points.
    fac = jnp.einsum('...k,kl,...l->...', pos - mu, Sigma_inv, pos - mu)

    return jnp.exp(-fac / 2) / norm

# The distribution on the variables X, Y packed into pos.
Z = multivariate_gaussian(pos, mu, Sigma)

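# Create a surface plot and a projected filled contour plot under it.
fig = plt.figure()
# fig.gca(projection='3d') no longer accepts keyword arguments in recent
# Matplotlib releases; add_subplot creates the 3D axes explicitly.
ax = fig.add_subplot(projection='3d')
ax.plot_surface(X, Y, Z, rstride=3, cstride=3, linewidth=1, antialiased=True,
                cmap=cm.viridis)

# Project the density onto the plane z = -0.15, below the surface.
cset = ax.contourf(X, Y, Z, zdir='z', offset=-0.15, cmap=cm.viridis)

# Adjust the limits, ticks and view angle
ax.set_zlim(-0.15, 0.2)
ax.set_zticks(np.linspace(0, 0.2, 5))
ax.view_init(27, -21)

plt.show()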
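Because the density is written in JAX, it can also be differentiated. A minimal sketch, assuming the same `mu` and `Sigma` as above: `jax.grad` of the log-density should match the closed-form gradient $-\boldsymbol{\Sigma}^{-1}(\boldsymbol{x}-\boldsymbol{\mu})$, and `jax.vmap` applies it to several points at once. The helper names `log_p` and `score` and the test points are illustrative choices.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

mu = jnp.array([0., 1.])
Sigma = jnp.array([[1., 0.6], [0.6, 2.]])

# Log-density of the bivariate Gaussian as a function of a single point x.
def log_p(x):
    return multivariate_normal.logpdf(x, mu, Sigma)

# jax.grad gives the gradient of log p with respect to x;
# jax.vmap maps it over a batch of points without a Python loop.
score = jax.vmap(jax.grad(log_p))

pts = jnp.array([[0.0, 1.0], [1.0, -1.0], [-2.0, 3.0]])
auto = score(pts)

# Closed form for a Gaussian: grad log p(x) = -Sigma^{-1} (x - mu),
# computed for all points at once (columns of (pts - mu).T).
exact = -jnp.linalg.solve(Sigma, (pts - mu).T).T

print(auto)
print(jnp.allclose(auto, exact, atol=1e-5))
```

The first test point is $\boldsymbol{\mu}$ itself, where the density peaks, so its gradient is zero.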

In [None]:
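# 1-D normal density with mean 0 and variance equal to det(Sigma) (about 1.64)
Sigma_det = jnp.linalg.det(Sigma)
x = jnp.linspace(-6, 7, 200)  # a dense grid gives a smooth curve
y = multivariate_normal.pdf(x, mean=0., cov=Sigma_det)
fig1 = plt.figure()
ax = fig1.add_subplot(111)
ax.plot(x, y)
ax.set_xlabel('x')
ax.set_ylabel('p(x)')
plt.show()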
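The first cell creates `key = random.PRNGKey(29)` but never uses it. One natural use, sketched here as an assumption about how the notebook might continue, is to draw samples from the same bivariate Gaussian with `jax.random.multivariate_normal` and check that their statistics approach $\boldsymbol{\mu}$ and $\boldsymbol{\Sigma}$. It also illustrates that each coordinate on its own is Gaussian with variance $\Sigma_{ii}$. The sample size of 100,000 is an arbitrary choice.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(29)  # same seed as the first cell
mu = jnp.array([0., 1.])
Sigma = jnp.array([[1., 0.6], [0.6, 2.]])

# Draw samples from N(mu, Sigma); the result has shape (100_000, 2).
samples = random.multivariate_normal(key, mu, Sigma, shape=(100_000,))

# Sample statistics should approach mu and Sigma as the sample grows.
print(samples.mean(axis=0))
print(jnp.cov(samples, rowvar=False))

# Each coordinate on its own is Gaussian: x_i ~ N(mu_i, Sigma_ii).
print(samples[:, 0].var(), Sigma[0, 0])
print(samples[:, 1].var(), Sigma[1, 1])
```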