In [1]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns

import numpyro
from numpyro.contrib.einstein import Stein, kernels
from numpyro.contrib.einstein.callbacks import Progbar
from numpyro.distributions import NormalMixture
from numpyro.infer import ELBO, SVI
from numpyro.infer.autoguide import AutoDelta
from numpyro.infer.initialization import init_with_noise, init_to_value

In [2]:
rng_key = jax.random.PRNGKey(42)
num_iterations = 1500

In [3]:
def model():
    numpyro.sample('x', NormalMixture(jnp.array([1 / 3, 2 / 3]),
                                      jnp.array([-2., 2.]), jnp.array([1., 1.])))


guide = AutoDelta(model)
init_strategy = init_with_noise(init_to_value(values={'x': -10.}), noise_scale=1.0)

In [4]:
kernels_fns = {'rbf_kernel': kernels.RBFKernel(),
               'linear_kernel': kernels.LinearKernel(),
               'random_kernel': kernels.RandomFeatureKernel(),
               'imq_kernel': kernels.IMQKernel(),
               'matrix_kernel': kernels.MixtureKernel([0.5, 0.5],
                                                      [kernels.LinearKernel(),
                                                       kernels.RandomFeatureKernel()])}

svi = SVI(model, guide, numpyro.optim.Adagrad(step_size=1.0), ELBO(), init_strategy=init_strategy)

for label, kernel_fn in kernels.items():
    svgd = Stein(model, guide, numpyro.optim.Adagrad(step_size=1.0), ELBO(),
                 kernels.LinearKernel(), init_strategy=init_strategy, num_particles=100)
    state, _ = svi.train(rng_key, num_iterations, callbacks=[Progbar()])
    sns.kdeplot(svgd.get_params(state)['auto_x'], label=label)
    plt.legend()
    plt.show()
    plt.clf()

SVI 2.0169: 100%|██████████| 1500/1500 [00:00<00:00, 2251.82it/s]


-1.9972894
