In [5]:
from functools import partial

from matplotlib import pyplot as plt
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split

# Run the following before any XLA modules such as JAX:
import chex

chex.set_n_cpu_devices(2)

import sys
sys.path.append("../src")

# Import the remaining JAX related 
from gabenet.mcmc import sample_markov_chain
from gabenet.nets import MultinomialDirichletBelieve
from gabenet.utils import freeze_trainable_states, holdout_split, perplexity

import arviz as az
import haiku as hk
import jax
from jax import random
from jax.tree_util import tree_map
import jax.numpy as jnp
from scipy.stats import entropy

To illustrate how to use the multinomial-Dirichlet believe network, we will train the
model on scikit-learn's handwritten digits dataset (a small, MNIST-like collection of 8x8 images).

The dataset can be loaded directly from scikit-learn. As a preprocessing step, we reshape the
digits from an 8x8 square matrix to a flat array of size 64.

In [2]:
# Fetch the bundled digit images and flatten every 8x8 picture into one row.
digits = datasets.load_digits()
n_samples = digits.images.shape[0]
X = digits.images.reshape(n_samples, -1)
# Hold out 20% of the rows; the fixed seed keeps the split reproducible.
X_train, X_test = train_test_split(X, random_state=0, test_size=0.2)

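Next, we define the model. We use a simple decoder network with two hidden layers. In
total, the network has size 2 x 10 x 64. (Note that the runnable cell further below uses a single hidden layer instead, `n_hidden_units = (10,)`.)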

```python
n_hidden_units = (2, 10)
model = MultinomialDirichletBelieve(n_hidden_units, n_features)
```

This code has to be defined inside a [haiku](https://github.com/deepmind/dm-haiku) context, which transforms the network into a pure function for JAX.

Since the network is a Bayesian model, we don't train the model by
minimising a loss. Rather, we infer the distribution $p(\boldsymbol{\theta}|\boldsymbol{X}_{\mathrm{train}})$ of the model's parameters $\boldsymbol{\theta}$ given the training data $\boldsymbol{X}_{\mathrm{train}}$ that we observe. This probability distribution is called the [posterior](https://en.wikipedia.org/wiki/Posterior_probability).

Unfortunately, we don't know what this distribution is. However, we do know a way to sample from it: using Markov chain Monte Carlo (MCMC). This simulation method samples the distribution by taking small steps, each of which depends on the previous state. In theory, when we have taken enough steps, the state converges to the true (posterior) distribution.

First, initialise the chain using training data:

```python
model.init(X_train)
```

This method takes samples from the prior as a starting point. After that, keep taking steps from your current to your next state. To take one step, you simply call your model using
the training data:

```python
model(X_train)
```

This function call does one Gibbs sampling step, which updates all the parameters one-by-one.

Now, let's put all elements together.

In [3]:
# Seeded source of fresh pseudo-random keys.
key_seq = hk.PRNGSequence(42)

# Convert the training matrix to a JAX array and record its dimensions.
X_train = jnp.array(X_train)
m_samples, n_features = X_train.shape


@hk.transform_with_state
def kernel(X=X_train):
    """Perform a single transition of the Markov chain on `X`.

    Args:
        X: Data the network is conditioned on; defaults to the training set
            that is bound when this cell is executed.

    Returns:
        None. The result of calling the network is discarded; the updated
        values are exposed through the haiku state of the transformed function.
    """
    hidden_sizes = (10,)
    net = MultinomialDirichletBelieve(hidden_sizes, n_features)
    # Calling the network carries out one Gibbs sampling step.
    net(X)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Here, we defined a function that proposes a new state based on its current configuration. This is called a _kernel_. The `hk.transform_with_state` decorator uses haiku to purify the function into something that is stateless.

Finally, we draw samples from the Markov chain. We first take 1,000 burn-in steps, in the
hope that the chain converges to the true distribution. After throwing away these burn-in samples, we collect a new set
of 800 samples (400 in each chain) to estimate the posterior distribution.

Note that `sample_markov_chain` (below) automatically takes care of distributing your
computation across multiple devices. For simplicity, we assume you are running on a CPU and split the CPU up into two virtual devices. (See above, at the import section, where we've used
`chex` to set the number of devices to 2.)

The following cell, which collects statistics from the Markov chain, takes about `10 minutes` to run on a CPU.

In [4]:
# Run one chain per device that JAX reports.
n_chains = jax.device_count()

# Discard 1,000 burn-in steps, then keep 800 draws from the chains.
params, states = sample_markov_chain(
    next(key_seq),
    kernel=kernel,
    n_chains=n_chains,
    n_burnin_steps=1_000,
    n_samples=800,
)

# Wait for the asynchronously dispatched computation to finish before the
# cell returns; the result is bound to `_` to keep it out of the cell output.
_ = states["multinomial_dirichlet_believe/~/multinomial_layer"]["phi"].block_until_ready()



In [6]:
def make_idata(states):
    """Pack summary quantities of the cap layer into an ArviZ dataset.

    Args:
        states: Mapping that contains the entry
            "multinomial_dirichlet_believe/~/cap_layer", which in turn holds
            the arrays "theta", "r" and "c".

    Returns:
        The result of `az.convert_to_dataset` for three variables: "c" (taken
        as is), "s[theta(1)]" (entropy of "theta" over its last axis) and
        "s[r]" (entropy of "r" over its last axis).
    """
    cap_layer = states["multinomial_dirichlet_believe/~/cap_layer"]
    summaries = {
        "c": cap_layer["c"],
        # Reduce the last axis to a single entropy value per entry.
        "s[theta(1)]": entropy(cap_layer["theta"], axis=-1),
        "s[r]": entropy(cap_layer["r"], axis=-1),
    }
    return az.convert_to_dataset(summaries)

In [8]:
# Convert the states sampled so far into an ArviZ dataset of summary quantities.
idata = make_idata(states)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
c,0.838,0.037,0.770,0.911,0.021,0.016,3.0,12.0,1.74
s[theta(1)][0],0.988,0.091,0.809,1.144,0.032,0.023,8.0,148.0,1.22
s[theta(1)][1],0.991,0.205,0.654,1.364,0.118,0.093,3.0,20.0,1.73
s[theta(1)][2],0.887,0.228,0.536,1.199,0.146,0.120,3.0,43.0,1.82
s[theta(1)][3],0.970,0.143,0.684,1.157,0.077,0.060,4.0,44.0,1.51
...,...,...,...,...,...,...,...,...,...
s[theta(1)][1433],0.837,0.105,0.660,1.012,0.047,0.036,5.0,101.0,1.40
s[theta(1)][1434],1.051,0.076,0.902,1.175,0.041,0.032,4.0,61.0,1.59
s[theta(1)][1435],1.061,0.099,0.908,1.255,0.052,0.041,4.0,52.0,1.46
s[theta(1)][1436],1.284,0.144,1.050,1.563,0.034,0.024,17.0,41.0,1.10


In [None]:
# Summary statistics (including R-hat) for `c` and the entropy of `r`.
r_hats = az.summary(idata, var_names=["c", "s[r]"])
r_hats

In [None]:
# Describe how R-hat is distributed across the entropies of theta.
r_s = az.summary(idata, var_names=["s[theta(1)]"])
r_s["r_hat"].describe()

In [9]:
# Build a dilution series by repeatedly splitting what is left of the data.
X_rest = X_train
X_train_sets = [X_train]
for i in range(9):
    X_train_i, X_rest = holdout_split(next(key_seq), X_rest)
    # Prepend, so that the most recent split comes first and the full
    # `X_train` ends up as the last element of the list.
    X_train_sets.insert(0, X_train_i)

In [10]:
# Start the chains on the first element of the dilution series.
X_subset = X_train_sets[0]
# Bind the subset as the kernel's data argument for both init and apply.
kernel_i = hk.TransformedWithState(
    init=partial(kernel.init, X=X_subset),
    apply=partial(kernel.apply, X=X_subset),
)
# n_samples equals n_chains here; presumably this keeps a single draw per
# chain after the 100 burn-in steps (confirm against `sample_markov_chain`).
params, states = sample_markov_chain(
    next(key_seq),
    kernel=kernel_i,
    n_chains=n_chains,
    n_burnin_steps=100,
    n_samples=n_chains,
)
# Keep only the final entry along axis 1 of every state array.
last_state = tree_map(lambda x: x[:, -1], states)



In [11]:
# Walk through the remaining elements of the dilution series, continuing each
# run from the parameters and final state of the previous one.
for i in range(1, 10):
    X_subset = X_train_sets[i]
    # Bind the current subset as the kernel's data argument.
    kernel_i = hk.TransformedWithState(
        init=partial(kernel.init, X=X_subset),
        apply=partial(kernel.apply, X=X_subset),
    )
    params, states = sample_markov_chain(
        next(key_seq),
        kernel=kernel_i,
        n_chains=n_chains,
        n_burnin_steps=100,
        n_samples=n_chains,
        params=params,
        initial_state=last_state,
    )
    # Keep only the final entry along axis 1 as the next starting point.
    last_state = tree_map(lambda x: x[:, -1], states)

In [13]:
# Bind the full training set explicitly, under a NEW name. Rebinding `kernel`
# to a wrapper built from `kernel` itself would make this cell self-referential:
# every re-run would stack another `partial` layer on top of the previous one
# and the decorated kernel defined earlier would no longer be reachable.
kernel_full = hk.TransformedWithState(
    init=partial(kernel.init, X=X_train),
    apply=partial(kernel.apply, X=X_train),
)
# Continue from the parameters and final state of the preceding runs; no
# additional burn-in steps are requested for this run.
params, states = sample_markov_chain(
    next(key_seq),
    kernel=kernel_full,
    n_samples=800,
    n_burnin_steps=0,
    n_chains=n_chains,
    params=params,
    initial_state=last_state,
)

In [14]:
# Convert the states of the most recent sampling run into an ArviZ dataset.
idata2 = make_idata(states)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
c,0.624,0.414,0.132,1.113,0.287,0.241,3.0,17.0,1.92
s[theta(1)][0],0.338,0.358,0.000,0.810,0.243,0.202,2.0,4.0,2.49
s[theta(1)][1],0.051,0.082,0.000,0.226,0.037,0.028,3.0,3.0,2.03
s[theta(1)][2],0.316,0.318,0.000,0.691,0.222,0.187,3.0,4.0,2.34
s[theta(1)][3],0.038,0.090,0.000,0.141,0.015,0.010,4.0,3.0,1.44
...,...,...,...,...,...,...,...,...,...
s[theta(1)][1433],0.140,0.161,0.000,0.426,0.103,0.083,3.0,4.0,2.19
s[theta(1)][1434],0.303,0.305,0.000,0.658,0.213,0.179,3.0,4.0,2.11
s[theta(1)][1435],0.241,0.245,0.000,0.585,0.137,0.108,4.0,4.0,1.50
s[theta(1)][1436],0.270,0.318,0.000,0.684,0.201,0.163,3.0,3.0,2.09


In [26]:
# Summary statistics (including R-hat) for `c` and the entropy of `r`.
r_hats2 = az.summary(idata2, var_names=["c", "s[r]"])
r_hats2

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
c,0.624,0.414,0.132,1.113,0.287,0.241,3.0,17.0,1.92
s[r],0.422,0.285,0.042,0.722,0.199,0.168,2.0,11.0,2.85


In [22]:
# Describe how R-hat is distributed across the entropies of theta.
r_s2 = az.summary(idata2, var_names=["s[theta(1)]"])
r_s2["r_hat"].describe()