<a href="https://colab.research.google.com/github/parvbhargava/SRIP-Parv-Bhargava/blob/main/Question2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Sampling from a Multivariate Normal Distribution
### Multivariate Normal Distribution

Recall that a random vector X = (X<sub>1</sub>, …, X<sub>d</sub>) has a multivariate normal (or Gaussian) distribution if every linear combination

\begin{align}
         \sum_{i=1}^d a_i X_i, \quad a \in \mathbb{R}^d,
    \end{align}

is normally distributed.

Warning: The sum of two normally distributed random variables need not be normally distributed unless the two variables are jointly normal.
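A standard counterexample makes the warning concrete. This is an illustrative sketch that is not part of the original notebook: take X ∼ N(0, 1) and an independent random sign S = ±1 (each with probability 1/2). Then Y = SX is also N(0, 1), but X + Y = (1 + S)X is exactly 0 whenever S = -1, so X + Y has a point mass at 0 and cannot be normal. The pair (X, Y) is therefore not jointly normal. The seed and sample size are arbitrary choices.

```python
import jax.numpy as jnp
import jax.random as random

# Illustrative counterexample (not from the original notebook).
key_x, key_s = random.split(random.PRNGKey(0))
n_samples = 100_000

# X ~ N(0, 1) and an independent random sign S in {-1, +1}.
X = random.normal(key_x, shape=(n_samples,))
S = jnp.where(random.bernoulli(key_s, 0.5, shape=(n_samples,)), 1.0, -1.0)

# Y = S * X is again N(0, 1), because the N(0, 1) density is symmetric.
Y = S * X

# X + Y = (1 + S) X is exactly 0 whenever S = -1, i.e. about half the time.
Z = X + Y
frac_zero = jnp.mean(Z == 0.0)

print("P(X + Y == 0) ~", float(frac_zero))
print("std of Y ~", float(jnp.std(Y)))
```

A normal random variable with positive variance takes any single value with probability 0, so a point mass of about 1/2 at zero rules out normality.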

The multivariate normal distribution has a joint probability density given by

\begin{align} 
      p(x \mid m, K_0) = (2\pi)^{-d/2} |K_0|^{-1/2} \exp\left(-\frac{1}{2}(x-m)^T K_0^{-1}(x-m)\right)
\end{align}

where m ∈ ℝ<sup>d</sup> is the mean vector and K<sub>0</sub> ∈ M<sub>d</sub>(ℝ) is the (symmetric, positive definite) covariance matrix.
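The density formula can be evaluated directly. The sketch below works with the log-density and, as a design choice, uses a Cholesky factor K<sub>0</sub> = LL<sup>T</sup> instead of forming K<sub>0</sub><sup>-1</sup> and |K<sub>0</sub>| explicitly: |K<sub>0</sub>| = ∏ L<sub>ii</sub><sup>2</sup>, and solving Lz = x - m gives z<sup>T</sup>z = (x-m)<sup>T</sup>K<sub>0</sub><sup>-1</sup>(x-m). The helper name `mvn_logpdf` and the 2-dimensional example values are hypothetical; the result is checked against `jax.scipy.stats.multivariate_normal.logpdf` and against a direct inverse/determinant evaluation.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal


def mvn_logpdf(x, m, K_0):
    """Hypothetical helper: log p(x | m, K_0) evaluated via a Cholesky factor."""
    d = m.shape[0]
    L = jnp.linalg.cholesky(K_0)
    # Solve L z = (x - m), so that z^T z = (x - m)^T K_0^{-1} (x - m).
    z = solve_triangular(L, x - m, lower=True)
    # log|K_0| = 2 * sum(log(diag(L))).
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
    return -0.5 * d * jnp.log(2.0 * jnp.pi) - 0.5 * log_det - 0.5 * jnp.dot(z, z)


# Example values (assumed, for illustration only).
m = jnp.array([1.0, -2.0])
K_0 = jnp.array([[2.0, 0.3], [0.3, 1.0]])
x = jnp.array([0.5, -1.0])

ours = mvn_logpdf(x, m, K_0)
ref = multivariate_normal.logpdf(x, m, K_0)

# Direct, less stable evaluation of the formula with an explicit inverse and determinant.
diff = x - m
direct = (-0.5 * 2 * jnp.log(2.0 * jnp.pi)
          - 0.5 * jnp.log(jnp.linalg.det(K_0))
          - 0.5 * diff @ jnp.linalg.inv(K_0) @ diff)

print(ours, ref, direct)
```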

```python
import jax.numpy as jnp
import jax.random as random

# Fixed PRNG key so the results are reproducible.
key = random.PRNGKey(23)
```

## Set parameters



```python
# Define the dimension.
d = 10
# Set the mean vector m = (1, 2, ..., 10).
m = jnp.arange(1.0, d + 1.0)
# Set the covariance matrix K_0 (here the d x d identity matrix).
K_0 = jnp.eye(d)

K_0, m.reshape(d, 1)
```

## Sampling Process
### Step 1: Compute the Cholesky Decomposition
We want to compute the Cholesky decomposition of the covariance matrix K<sub>0</sub>. That is, we want to find a lower triangular matrix L ∈ M<sub>d</sub>(ℝ) such that
\begin{align}K_0=LL^T\end{align}
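To show how such an L can be computed, here is a textbook Cholesky-Banachiewicz factorization written with plain Python loops. It is a sketch for illustration only (O(d<sup>3</sup>), not jit-friendly); in practice one calls `jnp.linalg.cholesky`, as the notebook does. The function name `cholesky_lower` and the example matrices are hypothetical. Each diagonal entry needs K<sub>ii</sub> - ∑<sub>k<i</sub> L<sub>ik</sub><sup>2</sup> > 0, which is exactly where positive definiteness enters.

```python
import jax.numpy as jnp
import jax.random as random


def cholesky_lower(K):
    """Hypothetical helper: textbook Cholesky-Banachiewicz factorization K = L L^T."""
    d = K.shape[0]
    L = jnp.zeros_like(K)
    for i in range(d):
        for j in range(i + 1):
            # Dot product of the already computed parts of rows i and j.
            s = jnp.dot(L[i, :j], L[j, :j])
            if i == j:
                # Diagonal entry: requires K[i, i] - s > 0 (positive definiteness).
                L = L.at[i, i].set(jnp.sqrt(K[i, i] - s))
            else:
                # Below-diagonal entry of column j.
                L = L.at[i, j].set((K[i, j] - s) / L[j, j])
    return L


# Small hand-checkable example: L = [[2, 0], [1, sqrt(2)]].
K_example = jnp.array([[4.0, 2.0], [2.0, 3.0]])
L_ours = cholesky_lower(K_example)

# A random symmetric positive definite matrix, compared with the library routine.
A = random.normal(random.PRNGKey(1), (5, 5))
K_rand = A @ A.T + 5.0 * jnp.eye(5)
L_rand = cholesky_lower(K_rand)

print(L_ours)
```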

“In practice it may be necessary to add a small multiple of the identity matrix, εI, to the covariance matrix for numerical reasons. This is because the eigenvalues of the matrix K<sub>0</sub> can decay very rapidly and without this stabilization the Cholesky decomposition fails. The effect on the generated samples is to add additional independent noise of variance ε. From the context ε can usually be chosen to have inconsequential effects on the samples, while ensuring numerical stability.”
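The failure mode and the fix can be seen on a tiny singular matrix. This sketch assumes a 3 × 3 all-ones matrix as the covariance (positive semidefinite with rank 1, so not positive definite). In JAX, `jnp.linalg.cholesky` does not raise on such input; it returns a matrix containing NaNs. Adding εI with ε = 10<sup>-4</sup> (the same value used below) makes the factorization succeed, and L L<sup>T</sup> then equals the original matrix plus ε on the diagonal, i.e. the extra independent noise of variance ε mentioned in the quote.

```python
import jax.numpy as jnp

# Assumed example: a rank-1 covariance matrix with all entries equal to 1.
K_singular = jnp.ones((3, 3))

# Without jitter the factorization breaks down; JAX signals this with NaNs.
L_bad = jnp.linalg.cholesky(K_singular)

# Adding eps * I makes the matrix positive definite.
eps = 1e-4
K_jittered = K_singular + eps * jnp.eye(3)
L_good = jnp.linalg.cholesky(K_jittered)

print("NaNs without jitter:", bool(jnp.isnan(L_bad).any()))
# L L^T reproduces K_singular + eps * I, so the diagonal gains variance eps.
print(jnp.diag(L_good @ L_good.T) - jnp.diag(K_singular))
```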

```python
# Define the jitter epsilon.
epsilon = 1e-4

# Add a small perturbation epsilon * I to the covariance matrix.
K = K_0 + epsilon * jnp.identity(d)

# Cholesky decomposition: L is lower triangular with K = L L^T.
L = jnp.linalg.cholesky(K)
L
```

Let us verify the desired property L L<sup>T</sup> = K:

```python
# Should reproduce K = K_0 + epsilon * I up to round-off.
jnp.dot(L, jnp.transpose(L))
```
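Instead of inspecting the printed matrix by eye, the check can be made numerical. This sketch recreates the notebook's K (d = 10, K<sub>0</sub> = I, ε = 10<sup>-4</sup>) so that it runs on its own, then reports the largest absolute entry of L L<sup>T</sup> - K and confirms that L is lower triangular. The tolerance 10<sup>-5</sup> is an assumed float32 round-off level, not a value from the original.

```python
import jax.numpy as jnp

# Recreate the notebook's setup so this cell is self-contained.
d = 10
epsilon = 1e-4
K = jnp.eye(d) + epsilon * jnp.eye(d)
L = jnp.linalg.cholesky(K)

# Largest absolute entry of L L^T - K; should be at float32 round-off level.
max_err = jnp.max(jnp.abs(L @ L.T - K))
# L should equal its own lower-triangular part.
is_lower = jnp.allclose(L, jnp.tril(L))

print("max |L L^T - K| =", float(max_err))
print("lower triangular:", bool(is_lower))
```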

### Step 2: Generate Independent Samples u ∼ N(0, I)



```python
# Number of samples.
n = 10000
# Each column of u is an independent draw from N(0, I) in d dimensions.
u = random.normal(key, shape=(d, n))
u
```
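Uniform draws cannot be used for u directly: a uniform variable on [-3, 3], for instance, has variance 3 and bounded support, so it is not N(0, 1). If only uniform random numbers were available, one classical way to turn them into standard normals is the Box-Muller transform, which is swapped in here purely for illustration; `random.normal` is the natural choice in JAX. The helper name `box_muller` is hypothetical, and the shape (10, 10000) mirrors u above.

```python
import jax.numpy as jnp
import jax.random as random


def box_muller(key, shape):
    """Hypothetical helper: standard normal samples from uniforms (Box-Muller)."""
    k1, k2 = random.split(key)
    # random.uniform samples from [0, 1); use 1 - u so the log argument lies in (0, 1].
    u1 = 1.0 - random.uniform(k1, shape)
    u2 = random.uniform(k2, shape)
    radius = jnp.sqrt(-2.0 * jnp.log(u1))
    return radius * jnp.cos(2.0 * jnp.pi * u2)


z = box_muller(random.PRNGKey(0), (10, 10000))
print("mean ~", float(z.mean()), "std ~", float(z.std()))
```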

### Step 3: Compute x = m + Lu
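The variable x = m + Lu has a multivariate normal distribution, since it is an affine transformation of independent normally distributed variables: every linear combination of the components of x is a constant plus a linear combination of the independent normal components of u, and is therefore normal. Moreover,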

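\begin{align} 
      \mathbb{E}[x] = \mathbb{E}[m + Lu] = m + L\,\mathbb{E}[u] = m
\end{align}

and, since E[u] = 0 (so the two cross terms vanish) and E[uu<sup>T</sup>] = I,

\begin{align} 
      \mathbb{E}[xx^T] = \mathbb{E}[mm^T] + \mathbb{E}[mu^TL^T] + \mathbb{E}[Lum^T] + \mathbb{E}[Luu^TL^T] = mm^T + L\,\mathbb{E}[uu^T]\,L^T = mm^T + LL^T = mm^T + K
\end{align}

hence the covariance of x is K = K<sub>0</sub> + εI, which differs from K<sub>0</sub> only by the jitter:

\begin{align} 
      \mathbb{E}[(x-m)(x-m)^T] = \mathbb{E}[xx^T] - mm^T = K
\end{align}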



```python
# Row i of x is the sample m + L u_i, where u_i is column i of u; x has shape (n, d).
x = m + jnp.dot(L, u).T
x
```
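The two moment identities derived above can be checked empirically. With the identity covariance used in the notebook, L is just a scaled identity, so this sketch instead assumes a correlated 3 × 3 covariance (chosen by hand and positive definite) and a mean (1, 2, 3), where the role of L is visible. The sample mean should approach m and the sample covariance should approach K = LL<sup>T</sup>; the tolerance 0.1 is an assumed bound that is generous for n = 100,000 samples.

```python
import jax.numpy as jnp
import jax.random as random

# Assumed example parameters (not from the original notebook).
m = jnp.array([1.0, 2.0, 3.0])
K = jnp.array([[2.0, 0.6, 0.2],
               [0.6, 1.0, 0.3],
               [0.2, 0.3, 0.5]])
d, n = m.shape[0], 100_000

# Steps 1-3: Cholesky factor, standard normal draws, affine transformation.
L = jnp.linalg.cholesky(K)
u = random.normal(random.PRNGKey(23), shape=(d, n))
x = m + (L @ u).T  # shape (n, d): one sample per row

sample_mean = x.mean(axis=0)
sample_cov = jnp.cov(x, rowvar=False)

print("max |mean - m| =", float(jnp.max(jnp.abs(sample_mean - m))))
print("max |cov - K|  =", float(jnp.max(jnp.abs(sample_cov - K))))
```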

### Using the JAX Sampler
JAX has a built-in multivariate normal sampling function, `jax.random.multivariate_normal`:

```python
key = random.PRNGKey(67)
cov = jnp.array([[1.2, 0.4], [0.4, 1.0]])
mean = jnp.array([3.0, -1.0])
# Draw 5000 samples (shape (5000, 2)); transpose so each row holds one coordinate.
x1 = random.multivariate_normal(key, mean, cov, shape=(5000,)).T
x1
```
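As a sanity check, the built-in sampler and the manual Cholesky construction from Steps 1-3 should agree in distribution. This sketch reuses the mean and covariance above, draws samples with both methods using independent keys (so the individual samples differ), and compares their sample means and covariances. Samples are stored as rows here, so the manual version computes m + Lu<sub>i</sub> as the row vector u<sub>i</sub><sup>T</sup>L<sup>T</sup> + m<sup>T</sup>. The sample size and the tolerance 0.05 are assumptions.

```python
import jax.numpy as jnp
import jax.random as random

key_builtin, key_manual = random.split(random.PRNGKey(67))
mean = jnp.array([3.0, -1.0])
cov = jnp.array([[1.2, 0.4], [0.4, 1.0]])
n = 100_000

# Built-in sampler: returns an array of shape (n, 2).
x_builtin = random.multivariate_normal(key_builtin, mean, cov, shape=(n,))

# Manual sampler (Steps 1-3), one sample per row: x_i = m + L u_i.
L = jnp.linalg.cholesky(cov)
u = random.normal(key_manual, shape=(n, 2))
x_manual = mean + u @ L.T

for name, s in [("built-in", x_builtin), ("manual", x_manual)]:
    print(name, s.mean(axis=0), jnp.cov(s, rowvar=False))
```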