# Basics of NumPyro for Bayesian Inference with MCMC

In [10]:
import math
import os

import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
from causalgraphicalmodels import CausalGraphicalModel
from IPython.display import Image, set_matplotlib_formats
from matplotlib.patches import Ellipse, transforms

import jax
import jax.numpy as jnp  # numpy, superfast
from jax import ops, random, vmap
from jax.scipy.special import expit

import numpy as onp       # the numpy, original

import numpyro as numpyro
import numpyro as npr
import numpyro.distributions as dist
from numpyro.diagnostics import effective_sample_size, print_summary
from numpyro.infer import MCMC, NUTS, Predictive

# Render inline figures as SVG only when an `SVG` environment variable is present.
if "SVG" in os.environ:
    %config InlineBackend.figure_formats = ["svg"]
az.style.use("arviz-darkgrid")
# Request 4 host (CPU) devices from XLA, presumably so several MCMC chains can run in
# parallel. NOTE(review): per the NumPyro docs this only takes effect if called before
# JAX initialises its backend, i.e. before any JAX computation — keep it in this cell.
numpyro.set_host_device_count(4)

## Distributions

In [2]:
# A Bernoulli(p=0.3) distribution object, sampled directly with an explicit PRNG key.
b = dist.Bernoulli(probs=0.3)

random_samples = b.sample(key=random.PRNGKey(0), sample_shape=(1000,))

In [3]:
# The `seed` handler supplies the PRNG key that `numpyro.sample` needs when it is run
# eagerly, outside an inference algorithm. Use the `numpyro` alias throughout, as the
# handler on this same line does.
with numpyro.handlers.seed(rng_seed=0):
    x = numpyro.sample('x', dist.Bernoulli(probs=0.3))

In [4]:
# A value drawn with `numpyro.sample` is a JAX device array, not a plain Python int.
x

DeviceArray(1, dtype=int32)

In [5]:
# `.item()` extracts the 0-d array's value as a plain Python scalar.
x.item()

1

In [6]:
# Arithmetic on the sample returns another JAX array and keeps its int32 dtype.
100 * x

DeviceArray(100, dtype=int32)

In [7]:
# Inspect the concrete JAX array class behind the sample.
type(x)

jax.interpreters.xla._DeviceArray

In [8]:
# A vector-valued scale gives one independent Normal draw per scale entry (shape (3,)).
# Use the `numpyro` alias consistently with the handler on the same line.
with numpyro.handlers.seed(rng_seed=0):
    x = numpyro.sample('x', dist.Normal(0, jnp.array([1, 2, 4])))
x

DeviceArray([ 1.1378783, -2.44191  , -2.3661458], dtype=float32)

In [34]:
# Thresholded cumulative-normal (ordered-probit style) setup: nYlevels outcome levels
# are separated by nYlevels - 1 cut points at 0.5, 1.5, ...; each observation has a
# latent Normal(mu[i], sigma).
nYlevels = 4
nObs = 7  # number of observations, i.e. columns of `cdfs`
cuts = jnp.arange(nYlevels - 1) + 0.5
print(cuts, cuts.shape)
mu = jnp.zeros(nObs)  # float latent means, one per observation
sigma = 1
# A (n_cuts, 1) column broadcast against mu's (nObs,) row gives a (n_cuts, nObs) table of
# P(latent <= cut) for every cut point / observation pair.
cdfs = jax.scipy.stats.norm.cdf(cuts.reshape(-1, 1), mu, sigma)
assert cdfs.shape == (nYlevels - 1, nObs)
print('cdfs', cdfs.shape, cdfs)

[0.5 1.5 2.5] (3,)
cdfs (3, 7) [[0.69146246 0.69146246 0.69146246 0.69146246 0.69146246 0.69146246
  0.69146246]
 [0.9331928  0.9331928  0.9331928  0.9331928  0.9331928  0.9331928
  0.9331928 ]
 [0.9937903  0.9937903  0.9937903  0.9937903  0.9937903  0.9937903
  0.9937903 ]]


In [37]:
# Append a final row of ones (P(latent <= +inf) = 1) to close the cumulative table, so
# that differencing it yields one probability per outcome level.
# The row width is taken from `cdfs` itself rather than `mu`, so this also works when
# `mu` is a scalar broadcast against the cut points; the dtype matches `cdfs`.
cdfs1 = jnp.concatenate((cdfs, jnp.ones((1, cdfs.shape[1]), dtype=cdfs.dtype)))
cdfs1.shape, cdfs1

((4, 7),
 DeviceArray([[0.69146246, 0.69146246, 0.69146246, 0.69146246, 0.69146246,
               0.69146246, 0.69146246],
              [0.9331928 , 0.9331928 , 0.9331928 , 0.9331928 , 0.9331928 ,
               0.9331928 , 0.9331928 ],
              [0.9937903 , 0.9937903 , 0.9937903 , 0.9937903 , 0.9937903 ,
               0.9937903 , 0.9937903 ],
              [1.        , 1.        , 1.        , 1.        , 1.        ,
               1.        , 1.        ]], dtype=float32))

In [38]:
import numpy as np
def getAmat(nYlevels):
    """Return the ``nYlevels x nYlevels`` first-difference matrix.

    The matrix has ones on the diagonal and -1 on the first sub-diagonal, so
    ``getAmat(n) @ c`` maps a cumulative vector ``c`` to its successive differences
    ``[c[0], c[1] - c[0], c[2] - c[1], ...]``. Applied to a column of cumulative
    probabilities that ends in 1, it yields the per-level probabilities.

    Args:
        nYlevels: number of outcome levels, which is also the matrix size.

    Returns:
        A float64 ``numpy.ndarray`` of shape ``(nYlevels, nYlevels)``.
    """
    # np.eye(n, k=-1) has ones on the first sub-diagonal. Subtracting it from the
    # identity builds the matrix without a Python loop or the redundant extra copy.
    return np.eye(nYlevels) - np.eye(nYlevels, k=-1)
# Differencing matrix sized to the current number of outcome levels.
Amat = getAmat(nYlevels=nYlevels)
Amat

array([[ 1.,  0.,  0.,  0.],
       [-1.,  1.,  0.,  0.],
       [ 0., -1.,  1.,  0.],
       [ 0.,  0., -1.,  1.]])

In [48]:
# Row j of the product is cdfs1[j] - cdfs1[j-1] (row 0 is kept as-is), which turns the
# cumulative table into per-level probabilities. For 2-D operands matmul equals dot.
diff = jnp.matmul(Amat, cdfs1)
diff

DeviceArray([[0.69146246, 0.69146246, 0.69146246, 0.69146246, 0.69146246,
              0.69146246, 0.69146246],
             [0.24173033, 0.24173033, 0.24173033, 0.24173033, 0.24173033,
              0.24173033, 0.24173033],
             [0.06059754, 0.06059754, 0.06059754, 0.06059754, 0.06059754,
              0.06059754, 0.06059754],
             [0.00620967, 0.00620967, 0.00620967, 0.00620967, 0.00620967,
              0.00620967, 0.00620967]], dtype=float32)

In [49]:
# Each column sums to 1: the differences telescope back to the final cumulative value.
jnp.sum(diff, axis=0)

DeviceArray([1., 1., 1., 1., 1., 1., 1.], dtype=float32)

In [50]:
# Guard against negative differences (e.g. from unsorted cut points) by clipping at zero.
max0 = jnp.maximum(diff, 0)
jnp.sum(max0, axis=0)

DeviceArray([1., 1., 1., 1., 1., 1., 1.], dtype=float32)

In [51]:
# Renormalise each column (observation) so its level probabilities sum to 1.
probs = max0 / jnp.sum(max0, axis=0, keepdims=True)

In [58]:
# Per-observation level probabilities: rows are outcome levels, columns are observations.
(probs, probs.shape)

(DeviceArray([[0.69146246, 0.69146246, 0.69146246, 0.69146246, 0.69146246,
               0.69146246, 0.69146246],
              [0.24173033, 0.24173033, 0.24173033, 0.24173033, 0.24173033,
               0.24173033, 0.24173033],
              [0.06059754, 0.06059754, 0.06059754, 0.06059754, 0.06059754,
               0.06059754, 0.06059754],
              [0.00620967, 0.00620967, 0.00620967, 0.00620967, 0.00620967,
               0.00620967, 0.00620967]], dtype=float32),
 (4, 7))

In [60]:
# `probs` is (levels, observations), but Categorical expects the categories on the last
# axis, hence the transpose: one category index is drawn per observation.
with numpyro.handlers.seed(rng_seed=0):
    s = numpyro.sample('obs', dist.Categorical(probs=probs.T))
print(s)

[0 0 0 2 0 0 0]
