In [1]:
from random import randint

import jax
import jax.numpy as jnp
import scipy.io

import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.contrib.einstein.callbacks import Progbar
from numpyro.infer import Trace_ELBO
from numpyro.contrib.einstein import kernels, Stein
from numpyro.infer.initialization import init_to_value, init_with_noise
from numpyro.infer.autoguide import AutoDelta
from numpyro.examples.datasets import LR_BENCHMARKS, load_dataset

In [2]:
data = load_dataset(LR_BENCHMARKS)

Downloading - https://github.com/pyro-ppl/datasets/raw/master/benchmarks.mat.
Download complete.


TypeError: byte indices must be integers or slices, not str

In [3]:
datasets = {k: {'train': {'input': v['x'][0, 0][v['train'][0, 0][13, :] - 1],
                          'class': (v['t'][0, 0][v['train'][0, 0][13, :] - 1] == 1).astype('float')[:, 0]},
                'test': {'input': v['x'][0, 0][v['test'][0, 0][13, :] - 1],
                         'class': (v['t'][0, 0][v['test'][0, 0][13, :] - 1] == 1).astype('float')[:, 0]}}
            for k, v in data.items()
            if not str.startswith(k, "__") and not k == 'benchmarks' and v['x'][0, 0].shape[0] > 500}

In [4]:
rng_key = jax.random.PRNGKey(randint(0, int(1e6)))
num_iterations = 3000
num_particles = 100

In [5]:
def model(data, classes=None):
    alpha = numpyro.sample('alpha', dist.InverseGamma(concentration=1.0, rate=0.01))
    w = numpyro.sample('w', dist.Normal(loc=jnp.zeros(data.shape[1]+1), scale=alpha))
    with numpyro.plate('data', data.shape[0]):
        biased_data = jnp.concatenate((jnp.ones((data.shape[0],1)), data), axis=1)
        return numpyro.sample('x', dist.Bernoulli(logits=biased_data @ w), obs=classes)

In [6]:
def test_accuracy(model, guide, rng_key, testset, params, num_pred=100):
    def single_test_accuracy(rng_key, testset, params):
        guide_trace = handlers.trace(handlers.substitute(guide, params)).get_trace(testset['input'])
        model_trace = handlers.trace(handlers.replay(handlers.seed(model, rng_key),
                                                     guide_trace)).get_trace(testset['input'])
        accuracy = jnp.count_nonzero(model_trace['x']['value'] == testset['class']) / testset['input'].shape[0] * 100
        return accuracy
    accs = []
    for i in range(num_particles):
        ps = {k: param[i] for k, param in params.items()}
        accs.append(jax.vmap(lambda rnk: single_test_accuracy(rnk, testset, ps))(jax.random.split(rng_key, num_pred)))
    return jnp.mean(jnp.stack(accs))

In [7]:
for name, dataset in datasets.items():
    guide = AutoDelta(model)
    svgd = Stein(model, guide, numpyro.optim.Adagrad(step_size=.05), ELBO(),
                 kernels.RBFKernel(),
                 init_strategy=init_with_noise(init_to_value(values={'x': -10.}), noise_scale=1.0),
                 num_particles=num_particles,
                 repulsion_temperature=dataset['train']['input'].shape[0] ** -1)
    svgd_state, loss = svgd.train(rng_key, num_iterations, dataset['train']['input'], dataset['train']['class'],
                                  callbacks=[Progbar()])
    print(test_accuracy(model, guide, svgd_state.rng_key, dataset['train'], svgd.get_params(svgd_state)))

banana
50.0388
diabetis
69.037
german
66.606766
image
73.13408
ringnorm
67.8514
splice
80.49021
twonorm
94.92745
waveform
82.20442


Stein 269.13: 100%|██████████| 3000/3000 [00:26<00:00, 113.97it/s]
Stein 228.21: 100%|██████████| 3000/3000 [00:20<00:00, 145.42it/s]
Stein 362.52: 100%|██████████| 3000/3000 [00:24<00:00, 122.35it/s]
Stein 595.68: 100%|██████████| 3000/3000 [00:28<00:00, 103.68it/s]
Stein 204.49: 100%|██████████| 3000/3000 [00:23<00:00, 128.35it/s]
Stein 764.99: 100%|██████████| 3000/3000 [00:41<00:00, 72.62it/s]   
Stein 92.452: 100%|██████████| 3000/3000 [00:23<00:00, 130.10it/s]
Stein 155.47: 100%|██████████| 3000/3000 [00:24<00:00, 122.68it/s]
