In [1]:
from random import randint

import jax
import jax.numpy as jnp
import scipy.io

import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.contrib.einstein.callbacks import Progbar
from numpyro.infer import Trace_ELBO
from numpyro.contrib.einstein import kernels, Stein
from numpyro.infer.initialization import init_to_value, init_with_noise
from numpyro.infer.autoguide import AutoDelta
from numpyro.examples.datasets import LR_BANANA, LR_DIABETIS, LR_GERMAN, LR_IMAGE,\
                                      LR_RINGNORM, LR_SPLICE, LR_TWONORM, LR_WAVEFORM, load_dataset

ModuleNotFoundError: No module named 'numpyro.contrib.einstein.callbacks'

In [None]:
# Logistic-regression benchmark datasets from numpyro.examples.datasets;
# the training/evaluation loop iterates over them in this order.
datasets = [
    LR_BANANA,
    LR_DIABETIS,
    LR_GERMAN,
    LR_IMAGE,
    LR_RINGNORM,
    LR_SPLICE,
    LR_TWONORM,
    LR_WAVEFORM,
]

In [None]:
# Draw the seed once and keep it in a named constant so the run can be reported
# and reproduced; replace with a fixed integer to repeat a previous run exactly.
SEED = randint(0, 10**6)
print(f"SEED = {SEED}")
rng_key = jax.random.PRNGKey(SEED)

num_iterations = 3000  # optimisation steps passed to svgd.run
num_particles = 100    # number of Stein particles passed to Stein(...)

In [None]:
def model(data, classes=None):
    """Bayesian logistic regression with a learned prior scale.

    ``alpha`` ~ InverseGamma(1.0, 0.01) is used as the scale of a zero-mean
    Normal prior over the weights ``w``. A column of ones is prepended to
    ``data`` so ``w[0]`` acts as the intercept. Labels are modelled at site
    ``'x'`` with a Bernoulli likelihood on the logits ``[1, data] @ w``.

    Args:
        data: feature matrix; rows are examples (``data.shape[0]``) and
            columns are features (``data.shape[1]``).
        classes: observed binary labels, or ``None`` to sample them.

    Returns:
        The value of the ``'x'`` site (observed labels or sampled predictions).
    """
    num_rows = data.shape[0]
    num_features = data.shape[1]

    alpha = numpyro.sample('alpha', dist.InverseGamma(concentration=1.0, rate=0.01))
    # One extra weight for the bias column added below.
    weight_prior = dist.Normal(loc=jnp.zeros(num_features + 1), scale=alpha)
    w = numpyro.sample('w', weight_prior)

    with numpyro.plate('data', num_rows):
        bias_column = jnp.ones((num_rows, 1))
        design = jnp.concatenate((bias_column, data), axis=1)
        logits = design @ w
        return numpyro.sample('x', dist.Bernoulli(logits=logits), obs=classes)

In [None]:
def test_accuracy(model, guide, rng_key, testset, params, num_pred=100, num_particles=None):
    """Mean predictive accuracy (in percent) of a particle ensemble on a test set.

    For each particle, the guide is run with that particle's parameters
    substituted in, the seeded model is replayed against the guide trace, and
    the sampled ``'x'`` site is compared with the true labels. This is done for
    ``num_pred`` random keys per particle, and the accuracies are averaged over
    particles and keys.

    Args:
        model: numpyro model whose ``'x'`` site holds predicted labels.
        guide: guide matching ``model``; its params are substituted per particle.
        rng_key: JAX PRNG key, split into ``num_pred`` keys for model sampling.
        testset: ``(inputs, labels)`` pair.
        params: dict of parameter arrays whose leading axis indexes particles.
        num_pred: number of predictive draws per particle.
        num_particles: number of particles to evaluate. When ``None`` (default),
            it is inferred from the leading axis of ``params`` instead of relying
            on a global variable.

    Returns:
        Scalar JAX array: mean accuracy in percent.

    Raises:
        ValueError: if ``params`` is empty and ``num_particles`` is not given.
    """
    test_inp, test_clz = testset

    if num_particles is None:
        if not params:
            raise ValueError("params is empty; cannot infer the number of particles")
        # Each param array stacks particles along axis 0 (indexed as param[i] below).
        num_particles = next(iter(params.values())).shape[0]

    def single_test_accuracy(key, particle_params):
        # Run the guide with this particle's params, then replay the model on it
        # so only the unobserved 'x' site is freshly sampled with `key`.
        guide_trace = handlers.trace(handlers.substitute(guide, particle_params)).get_trace(test_inp)
        model_trace = handlers.trace(handlers.replay(handlers.seed(model, key),
                                                     guide_trace)).get_trace(test_inp)
        num_correct = jnp.count_nonzero(model_trace['x']['value'] == test_clz)
        return num_correct / test_inp.shape[0] * 100

    # The same predictive keys are reused for every particle (loop-invariant).
    pred_keys = jax.random.split(rng_key, num_pred)
    accs = []
    for i in range(num_particles):
        particle_params = {k: param[i] for k, param in params.items()}
        accs.append(jax.vmap(lambda key: single_test_accuracy(key, particle_params))(pred_keys))
    return jnp.mean(jnp.stack(accs))

In [None]:
# Train an SVGD particle ensemble on each dataset's training split and report
# its predictive accuracy on the held-out test split.
# NOTE(review): the Progbar import in the first cell raises ModuleNotFoundError
# with the installed numpyro version — TODO update the import or drop the callback.
for dataset in datasets:
    _, get_train_batch = load_dataset(dataset, split='train')
    train_inp, train_clz = get_train_batch()
    print(dataset.name)
    guide = AutoDelta(model)
    svgd = Stein(model, guide, numpyro.optim.Adagrad(step_size=.05), Trace_ELBO(),
                 kernels.RBFKernel(),
                 # NOTE(review): 'x' is the observed site of `model` during training;
                 # confirm this init value has any effect.
                 init_strategy=init_with_noise(init_to_value(values={'x': -10.}), noise_scale=1.0),
                 num_particles=num_particles,
                 # Repulsion scaled by 1 / number of training examples.
                 repulsion_temperature=train_inp.shape[0] ** -1)
    svgd_state, loss = svgd.run(rng_key, num_iterations, train_inp, train_clz, callbacks=[Progbar()])
    _, get_test_batch = load_dataset(dataset, split='test')
    # Evaluate on the test split, not the training batch.
    test_data = get_test_batch()
    print(test_accuracy(model, guide, svgd_state.rng_key, test_data, svgd.get_params(svgd_state)))