In [1]:
import matplotlib.pyplot as plt
import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import ELBO
from numpyro.infer.util import Predictive
from numpyro.infer.initialization import init_to_value, init_with_noise
from numpyro.contrib.autoguide import AutoDelta
from numpyro.infer.stein import SVGD
from numpyro.distributions import NormalMixture
import numpyro.infer.kernels as kernels
from numpyro.infer.kernels import SteinKernel
from tqdm import tqdm
import seaborn as sns
import os
import jax
import jax.numpy as np
from random import randint
import scipy.io

In [2]:
# From http://theoval.cmp.uea.ac.uk/matlab/default.html
data = scipy.io.loadmat('data/benchmarks.mat')

In [3]:
datasets = {k: {'train': {'input': v['x'][0, 0][v['train'][0, 0][13, :] - 1], 'class': (v['t'][0, 0][v['train'][0, 0][13, :] - 1] == 1).astype('float')[:, 0]}, 'test': {'input': v['x'][0, 0][v['test'][0, 0][13, :] - 1], 'class': (v['t'][0, 0][v['test'][0, 0][13, :] - 1] == 1).astype('float')[:, 0]}} for k, v in data.items() if not str.startswith(k, "__") and not k == 'benchmarks' and v['x'][0, 0].shape[0] > 500}

In [4]:
rng_key = jax.random.PRNGKey(randint(0, int(1e6)))
num_iterations = 3000
num_particles = 100

In [5]:
def model(data, classes=None):
    alpha = numpyro.sample('alpha', dist.InverseGamma(concentration=1.0, rate=0.01))
    w = numpyro.sample('w', dist.Normal(loc=np.zeros(data.shape[1]+1), scale=alpha))
    with numpyro.plate('data', data.shape[0]):
        biased_data = np.concatenate((np.ones((data.shape[0],1)), data), axis=1)
        return numpyro.sample('x', dist.Bernoulli(logits=biased_data @ w), obs=classes)

In [6]:
def test_accuracy(model, guide, rng_key, testset, params, num_pred=100):
    def single_test_accuracy(rng_key, testset, params):
        guide_trace = handlers.trace(handlers.substitute(guide, params)).get_trace(testset['input'])
        model_trace = handlers.trace(handlers.replay(handlers.seed(model, rng_key), guide_trace)).get_trace(testset['input'])
        accuracy = np.count_nonzero(model_trace['x']['value'] == testset['class']) / testset['input'].shape[0] * 100
        return accuracy
    accs = []
    for i in range(num_particles):
        ps = {k: param[i] for k, param in params.items()}
        accs.append(jax.vmap(lambda rnk: single_test_accuracy(rnk, testset, ps))(jax.random.split(rng_key, num_pred)))
    return np.mean(np.stack(accs))

In [7]:
for name, dataset in datasets.items():
    print(name)
    guide = AutoDelta(model, init_strategy=init_with_noise(init_to_value(values={'x': -10.}), noise_scale=1.0))
    svgd = SVGD(model, guide, numpyro.optim.Adagrad(step_size=.05), ELBO(),
            kernels.RBFKernel(), num_particles=num_particles,
            repulsion_temperature=dataset['train']['input'].shape[0] ** -1)
    svgd_state, loss = svgd.run(rng_key, num_iterations, dataset['train']['input'], dataset['train']['class'])
    print(test_accuracy(model, guide, svgd_state.rng_key, dataset['train'], svgd.get_params(svgd_state)))

banana
SVGD 261.65: 100%|██████████| 3000/3000 [00:12<00:00, 245.49it/s]
49.92515
diabetis
SVGD 224.71: 100%|██████████| 3000/3000 [00:13<00:00, 226.05it/s]
69.00015
german
SVGD 353.76: 100%|██████████| 3000/3000 [00:14<00:00, 210.01it/s]
66.16984
image
SVGD 581.25: 100%|██████████| 3000/3000 [00:16<00:00, 186.11it/s]
72.2445
ringnorm
SVGD 196.58: 100%|██████████| 3000/3000 [00:14<00:00, 204.00it/s]
66.83965
splice
SVGD 587.06: 100%|██████████| 3000/3000 [00:19<00:00, 151.22it/s]
79.54571
twonorm
SVGD 63.764: 100%|██████████| 3000/3000 [00:15<00:00, 188.51it/s]
94.11233
waveform
SVGD 134.37: 100%|██████████| 3000/3000 [00:15<00:00, 198.09it/s]
81.30287
