In [None]:
import sys
import warnings

import numpy as np

import jax
import jax.numpy as jnp
from jax import lax

import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm  # Colormaps
import seaborn as sns
import arviz as az

from tqdm import tqdm_notebook as tqdm

In [None]:
# ArviZ summaries and plots report 90% highest-density intervals.
az.rcParams["stats.hdi_prob"] = 0.90
# Matplotlib look: seaborn's dark grid first, then the ArviZ theme layered on top.
sns.set_style("darkgrid")
az.style.use("arviz-darkgrid")

In [None]:
# Force the CPU backend and split it into 8 virtual XLA host devices
# (presumably so that MCMC chains can later run in parallel — confirm downstream).
# NOTE: the device count only takes effect if set before JAX initializes its backend.
numpyro.set_platform(platform="cpu")
numpyro.set_host_device_count(n=8)

In [None]:
# One seed drives both NumPy's global RNG and the JAX PRNG key, so the two can't drift apart.
SEED = 42
np.random.seed(SEED)
rng_key = jax.random.PRNGKey(SEED)

In [None]:
# A single scalar standard normal: scalar parameters, so no batch dimensions.
normal = dist.Normal(loc=0., scale=1.)
for label, value in (
    ("normal.shape(): ", normal.shape()),
    ("normal.batch_shape: ", normal.batch_shape),
    ("normal.event_shape: ", normal.event_shape),
    ("normal.event_dim: ", normal.event_dim),
):
    print(label, value)
print('')
# The requested sample shape (100,) is prepended to the distribution's own shape.
samples = normal.sample(rng_key, sample_shape=(100,))
print('samples: ', samples.shape)
print('')
# Log-density of every drawn sample.
lp = normal.log_prob(samples)
print('lp: ', lp.shape)

In [None]:
# k independent scalar normals: vector-valued loc/scale of length k.
k = 3
loc, scale = jnp.zeros(k), jnp.ones(k)
for label, arr in (("loc: ", loc), ("scale: ", scale)):
    print(label, arr.shape)
print('')
normal = dist.Normal(loc=loc, scale=scale)
for label, value in (
    ("normal.shape(): ", normal.shape()),
    ("normal.batch_shape: ", normal.batch_shape),
    ("normal.event_shape: ", normal.event_shape),
    ("normal.event_dim: ", normal.event_dim),
):
    print(label, value)
print('')
# Draw 100 samples from the batch and score each of them.
samples = normal.sample(rng_key, sample_shape=(100,))
print('samples: ', samples.shape)
print('')
lp = normal.log_prob(samples)
print('lp: ', lp.shape)

In [None]:
# An (s, k) grid of independent scalar normals: 2-D loc/scale arrays, every row
# identical (zero means, scale 1/k).
k = 3
s = 4
# Build the (s, k) parameter arrays directly rather than stacking s identical rows.
loc = jnp.zeros((s, k))
print("loc: ", loc.shape)
scale = jnp.ones((s, k)) / k
print("scale: ", scale.shape)
print('')
normal = dist.Normal(loc=loc, scale=scale)
print("normal.shape(): ", normal.shape())
print("normal.batch_shape: ", normal.batch_shape)
print("normal.event_shape: ", normal.event_shape)
print("normal.event_dim: ", normal.event_dim)
print('')
# NOTE: the same rng_key is reused (never split) in every cell. That is fine for
# inspecting shapes, but split it with jax.random.split before drawing samples
# whose values should be independent across cells.
samples = normal.sample(rng_key, (100,))
print('samples: ', samples.shape)
print('')
lp = normal.log_prob(samples)
print('lp: ', lp.shape)

In [None]:
# A single d-dimensional multivariate normal with identity covariance.
d = 2
loc = jnp.zeros(d)
cov_matrix = jnp.eye(d)
for label, arr in (("loc: ", loc), ("cov_matrix: ", cov_matrix)):
    print(label, arr.shape)
print('')
normal = dist.MultivariateNormal(loc=loc, covariance_matrix=cov_matrix)
for label, value in (
    ("normal.shape(): ", normal.shape()),
    ("normal.batch_shape: ", normal.batch_shape),
    ("normal.event_shape: ", normal.event_shape),
    ("normal.event_dim: ", normal.event_dim),
):
    print(label, value)
print('')
# 100 draws; each draw is a whole length-d vector.
samples = normal.sample(rng_key, sample_shape=(100,))
print('samples: ', samples.shape)
print('')
lp = normal.log_prob(samples)
print('lp: ', lp.shape)

In [None]:
# A batch of k independent d-dimensional multivariate normals, each with
# identity covariance: loc is (k, d) and the covariance stack is (k, d, d).
k = 3
d = 2
loc = jnp.zeros((k, d))
cov_matrix = jnp.tile(jnp.eye(d), (k, 1, 1))
for label, arr in (("loc: ", loc), ("cov_matrix: ", cov_matrix)):
    print(label, arr.shape)
print('')
normal = dist.MultivariateNormal(loc=loc, covariance_matrix=cov_matrix)
for label, value in (
    ("normal.shape(): ", normal.shape()),
    ("normal.batch_shape: ", normal.batch_shape),
    ("normal.event_shape: ", normal.event_shape),
    ("normal.event_dim: ", normal.event_dim),
):
    print(label, value)
print('')
samples = normal.sample(rng_key, sample_shape=(100,))
print('samples: ', samples.shape)
print('')
lp = normal.log_prob(samples)
print('lp: ', lp.shape)