In [61]:
%matplotlib inline
import matplotlib.pyplot as plt
import palettes
import numpyro

from jax import numpy as jnp
from jax import random as jr
from numpyro import distributions as dist
from numpyro.infer import Predictive, NUTS, MCMC, SVI, autoguide, Trace_ELBO
from numpyro.distributions import constraints
from ramsey.data import sample_from_sine_function
from ramsey.experimental import ARMA, Autoregressive

palettes.set_theme()

In [62]:
numpyro.set_host_device_count(4)

In [None]:
## Bayesian inference with NumPyro

In [2]:
def model(y=None):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    scale = numpyro.sample("scale", dist.HalfNormal(1.0))
    ar_coefficients = numpyro.sample(
        "ar_coefficients", dist.Normal(jnp.zeros(3), 1.0)
    )
    numpyro.sample(
        "y", Autoregressive(loc, ar_coefficients, scale, length=10), obs=y
    )

In [3]:
D = Predictive(model, num_samples=1)(rng_key=jr.PRNGKey(3))
D

{'ar_coefficients': Array([[-0.69692624,  1.3743659 , -0.31461632]], dtype=float32),
 'loc': Array([-0.05887505], dtype=float32),
 'scale': Array([1.0093267], dtype=float32),
 'y': Array([[  0.791172  ,  -0.16440308,   0.84213233,  -0.60780644,
           0.9093818 ,  -3.276204  ,   3.1482418 ,  -6.157694  ,
           9.625948  , -17.264713  ]], dtype=float32)}

In [4]:
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=2000, num_chains=4)
mcmc.run(jr.PRNGKey(1), y=D["y"].flatten())

  mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=2000, num_chains=4)
sample: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 3000/3000 [00:01<00:00, 2640.73it/s, 15 steps of size 1.74e-01. acc. prob=0.90]
sample: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 3000/3000 [00:00<00:00, 10705.34it/s, 31 steps of size 1.58e-01. acc. prob=0.92]
sample: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 3000/3000 [00:00<00:00, 10503.90it/s, 31 steps of size 1.56e-01. acc. prob=0.93]
sample: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 3000/3000 [00:00<00:00, 10549.45it/s, 15 steps of size 1.66e-01. acc. prob=0.90]


In [7]:
mcmc.print_summary()


                        mean       std    median      5.0%     95.0%     n_eff     r_hat
ar_coefficients[0]     -0.79      0.34     -0.79     -1.34     -0.24   3765.03      1.00
ar_coefficients[1]      1.13      0.42      1.14      0.43      1.80   4097.92      1.00
ar_coefficients[2]     -0.53      0.53     -0.54     -1.40      0.32   4913.12      1.00
               loc     -0.19      0.31     -0.20     -0.69      0.31   4951.99      1.00
             scale      0.99      0.27      0.95      0.57      1.37   3056.92      1.00

Number of divergences: 0


## Maximum likelihood inference with NumPyro

In [45]:
D

{'ar_coefficients': Array([[-0.69692624,  1.3743659 , -0.31461632]], dtype=float32),
 'loc': Array([-0.05887505], dtype=float32),
 'scale': Array([1.0093267], dtype=float32),
 'y': Array([[  0.791172  ,  -0.16440308,   0.84213233,  -0.60780644,
           0.9093818 ,  -3.276204  ,   3.1482418 ,  -6.157694  ,
           9.625948  , -17.264713  ]], dtype=float32)}

In [15]:
def model(y=None):
    loc = numpyro.param("loc", 0.0)
    scale = numpyro.param("scale", 1.0, constraints=constraints.positive)
    ar_coefficients = numpyro.param(
        "ar_coefficients", jnp.array([-1.0, 0.0, 1.0])
    )
    numpyro.sample(
        "y", Autoregressive(loc, ar_coefficients, scale, length=10), obs=y
    )

In [30]:
def guide(y=None):
    pass

In [43]:
svi = SVI(model, guide, optim=Adam(0.01), loss=Trace_ELBO())
svi_res = svi.run(jr.PRNGKey(1), y=D["y"].flatten(), num_steps=1000)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 6641.08it/s, init loss: 104.0581, avg. loss [951-1000]: 11.3399]


In [44]:
svi_res.params

{'ar_coefficients': Array([-0.8599633 ,  1.103057  , -0.41942602], dtype=float32),
 'loc': Array(-0.20917863, dtype=float32),
 'scale': Array(0.75362915, dtype=float32)}

In [53]:
## MAP with NumPyro

In [57]:
def model(y=None):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    scale = numpyro.sample("scale", dist.HalfNormal(1.0))
    ar_coefficients = numpyro.sample(
        "ar_coefficients", dist.Normal(jnp.zeros(3), 1.0)
    )
    numpyro.sample(
        "y", Autoregressive(loc, ar_coefficients, scale, length=10), obs=y
    )

In [58]:
guide = autoguide.AutoDelta(model)

In [59]:
svi = SVI(model, guide, optim=Adam(0.01), loss=Trace_ELBO())
svi_res = svi.run(jr.PRNGKey(1), y=D["y"].flatten(), num_steps=1000)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 6383.21it/s, init loss: 1647.7804, avg. loss [951-1000]: 17.6015]


In [60]:
svi_res.params

{'ar_coefficients_auto_loc': Array([-1.0999439,  0.7032318, -0.4249696], dtype=float32),
 'loc_auto_loc': Array(-0.25446877, dtype=float32),
 'scale_auto_loc': Array(0.7408141, dtype=float32)}

## Maximum likelihood inference in Flax

In [128]:
from flax import linen as nn
import optax
import jax
from flax.linen import initializers
import numpy as np
from flax.training.train_state import TrainState

In [113]:
y = D["y"]
y = y.flatten()

In [107]:
class ARModel(nn.Module):
    order: int

    def setup(self):
        self.loc = self.param(
            "loc", initializers.glorot_normal(), (1, 1), jnp.float32
        )
        self.log_scale = self.param(
            "log_scale", initializers.glorot_normal(), (1, 1), jnp.float32
        )
        self.ar_coefficients = self.param(
            "ar_coefficients",
            initializers.glorot_normal(),
            (self.order, 1),
            jnp.float32,
        )

    def __call__(self, inputs):
        return self.log_prob(inputs)

    def _get_distr(self):
        ar = Autoregressive(
            self.loc.flatten(),
            self.ar_coefficients.flatten(),
            jnp.exp(self.log_scale.flatten()),
        )
        return ar

    def log_prob(self, inputs):
        ar = self._get_distr()
        return ar.log_prob(inputs)

    def sample(self, initial_state, length, shape=()):
        ar = self._get_distr()
        return ar.sample(self.make_rng("sample"), length, initial_state, shape)

In [131]:
def create_train_state(rng, model, optimizer, **init_data):
    init_key, sample_key = jr.split(rng)
    params = model.init({"sample": sample_key, "params": init_key}, **init_data)
    state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)
    return state

In [None]:
def train(seed, y, state, n_iter=1000):
    objectives = np.zeros(n_iter)
    for i in range(n_iter):

        def step(rngs, state, **batch):
            current_step = state.step
            rngs = {
                name: jr.fold_in(rng, current_step)
                for name, rng in rngs.items()
            }

            def obj_fn(params):
                obj = state.apply_fn(variables=params, rngs=rngs, **batch)
                return -jnp.sum(obj)

            obj, grads = jax.value_and_grad(obj_fn)(state.params)
            new_state = state.apply_gradients(grads=grads)
            return new_state, obj

        sample_rng_key, seed = jr.split(seed)
        state, obj = step({"sample": sample_rng_key}, state, inputs=y)
        objectives[i] = obj
    return state.params, objectives


state = create_train_state(
    jr.PRNGKey(123), ARModel(3), optax.adam(0.01), inputs=y
)
params, objectives = train(jr.PRNGKey(2), y, state, 2000)
params

In [None]:
D