In [1]:
from random import randint

import jax
import jax.numpy as jnp
import scipy.io

import numpyro
import numpyro.distributions as dist
import numpyro.infer.kernels as kernels
from numpyro import handlers
from numpyro.callbacks import Progbar
from numpyro.contrib.autoguide import AutoDelta
from numpyro.infer import ELBO, Stein
from numpyro.infer.initialization import init_to_value, init_with_noise



In [2]:
# From http://theoval.cmp.uea.ac.uk/matlab/default.html
data = scipy.io.loadmat('data/benchmarks.mat')

In [3]:
datasets = {k: {'train': {'input': v['x'][0, 0][v['train'][0, 0][13, :] - 1], 'class': (v['t'][0, 0][v['train'][0, 0][13, :] - 1] == 1).astype('float')[:, 0]}, 'test': {'input': v['x'][0, 0][v['test'][0, 0][13, :] - 1], 'class': (v['t'][0, 0][v['test'][0, 0][13, :] - 1] == 1).astype('float')[:, 0]}} for k, v in data.items() if not str.startswith(k, "__") and not k == 'benchmarks' and v['x'][0, 0].shape[0] > 500}

In [4]:
rng_key = jax.random.PRNGKey(randint(0, int(1e6)))
num_iterations = 3000
num_particles = 100

In [5]:
def model(data, classes=None):
    alpha = numpyro.sample('alpha', dist.InverseGamma(concentration=1.0, rate=0.01))
    w = numpyro.sample('w', dist.Normal(loc=jnp.zeros(data.shape[1]+1), scale=alpha))
    with numpyro.plate('data', data.shape[0]):
        biased_data = jnp.concatenate((jnp.ones((data.shape[0],1)), data), axis=1)
        return numpyro.sample('x', dist.Bernoulli(logits=biased_data @ w), obs=classes)

In [6]:
def test_accuracy(model, guide, rng_key, testset, params, num_pred=100):
    def single_test_accuracy(rng_key, testset, params):
        guide_trace = handlers.trace(handlers.substitute(guide, params)).get_trace(testset['input'])
        model_trace = handlers.trace(handlers.replay(handlers.seed(model, rng_key), guide_trace)).get_trace(testset['input'])
        accuracy = jnp.count_nonzero(model_trace['x']['value'] == testset['class']) / testset['input'].shape[0] * 100
        return accuracy
    accs = []
    for i in range(num_particles):
        ps = {k: param[i] for k, param in params.items()}
        accs.append(jax.vmap(lambda rnk: single_test_accuracy(rnk, testset, ps))(jax.random.split(rng_key, num_pred)))
    return jnp.mean(jnp.stack(accs))

In [7]:
for name, dataset in datasets.items():
    print(name)
    guide = AutoDelta(model, init_strategy=init_with_noise(init_to_value(values={'x': -10.}), noise_scale=1.0))
    svgd = Stein(model, guide, numpyro.optim.Adagrad(step_size=.05), ELBO(),
                 kernels.RBFKernel(), num_particles=num_particles,
                 repulsion_temperature=dataset['train']['input'].shape[0] ** -1)
    svgd_state, loss = svgd.train(rng_key, num_iterations, dataset['train']['input'], dataset['train']['class'],
                                  callbacks=[Progbar()])
    print(test_accuracy(model, guide, svgd_state.rng_key, dataset['train'], svgd.get_params(svgd_state)))

banana
49.755676
diabetis
69.31567
german
66.303566
image
73.10204
ringnorm
67.07942
splice
80.4721
twonorm
95.4193
waveform
82.7091


Stein 267.33: 100%|██████████| 3000/3000 [00:10<00:00, 273.56it/s]
Stein 226.89: 100%|██████████| 3000/3000 [00:11<00:00, 250.18it/s]
Stein 355.48: 100%|██████████| 3000/3000 [00:13<00:00, 218.92it/s]
Stein 585.72: 100%|██████████| 3000/3000 [00:15<00:00, 198.52it/s]
Stein 199.2: 100%|██████████| 3000/3000 [00:12<00:00, 232.91it/s] 
Stein 673.16: 100%|██████████| 3000/3000 [00:19<00:00, 155.06it/s]  
Stein 84.314: 100%|██████████| 3000/3000 [00:13<00:00, 229.32it/s]
Stein 144.22: 100%|██████████| 3000/3000 [00:13<00:00, 226.83it/s]
