In [None]:
from random import randint

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns

import numpyro
import numpyro.distributions as dist
from numpyro.infer import Trace_ELBO
from numpyro.contrib.einstein import kernels, Stein
from numpyro.infer.autoguide import AutoDelta
from numpyro.infer.initialization import init_with_noise, init_to_value
from numpyro.contrib.einstein.callbacks import Progbar

In [None]:
# Fixed seed so that Restart Kernel -> Run All reproduces the same particles and
# plots; change SEED by hand to explore a different random initialisation.
SEED = 0
rng_key = jax.random.PRNGKey(SEED)

# Number of optimisation steps passed to each `svgd.run` call in this notebook.
num_iterations = 6000

In [None]:
def model():
    """Declare one latent site ``'x'`` with a correlated 2-D Gaussian distribution.

    ``x ~ MultivariateNormal(loc=[5, 10], covariance_matrix=[[3, 5], [5, 10]])``.
    The model takes no arguments and has no observed sites.
    """
    loc = jnp.array([5., 10.])
    # Build the covariance as a JAX array, like `loc`: JAX numpy routines do not
    # accept plain Python lists as array arguments, so a nested list here only
    # works on versions that happen to convert it implicitly.
    covariance_matrix = jnp.array([[3., 5.],
                                   [5., 10.]])
    numpyro.sample('x', dist.MultivariateNormal(loc=loc, covariance_matrix=covariance_matrix))

In [None]:
# Delta guide over the model's single site; its parameters are read back later
# under the key 'x_auto_loc'.
guide = AutoDelta(model)

# Initialisation: the fixed value [-10, 30] for 'x', wrapped by init_with_noise
# with noise_scale=1.0.
rbf_init_strategy = init_with_noise(
    init_to_value(values={'x': jnp.array([-10., 30.])}),
    noise_scale=1.0,
)

# Stein inference with an RBF kernel in 'vector' mode and 100 particles.
svgd = Stein(
    model,
    guide,
    numpyro.optim.Adagrad(step_size=1.0),
    Trace_ELBO(),
    kernels.RBFKernel(mode='vector'),
    init_strategy=rbf_init_strategy,
    num_particles=100,
)
svgd_state = svgd.init(rng_key)

In [None]:
# KDE of the particles in the state returned by `svgd.init`, i.e. before `svgd.run`
# is called: first coordinate on the x-axis, second on the y-axis.
initial_particles = svgd.get_params(svgd_state)['x_auto_loc']
sns.kdeplot(x=initial_particles[:, 0], y=initial_particles[:, 1])


In [None]:
# Run the RBF-kernel Stein optimisation for `num_iterations` steps with a
# progress-bar callback; keep the final state and the value returned as `loss`.
rbf_callbacks = [Progbar()]
svgd_state, loss = svgd.run(rng_key, num_iterations, callbacks=rbf_callbacks)

In [None]:
# Clear the current figure, then draw the KDE of the particles held in the
# current `svgd_state`: first coordinate on the x-axis, second on the y-axis.
plt.clf()
rbf_particles = svgd.get_params(svgd_state)['x_auto_loc']
sns.kdeplot(x=rbf_particles[:, 0], y=rbf_particles[:, 1])

In [None]:
# Display the raw particle array stored under 'x_auto_loc' in the current state.
rbf_params = svgd.get_params(svgd_state)
rbf_params['x_auto_loc']

In [None]:
# Second experiment: same model, optimiser, loss, initialisation and particle
# count as the RBF setup, but with a GraphicalKernel (default arguments).
guide = AutoDelta(model)

# Initialisation: the fixed value [-10, 30] for 'x', wrapped by init_with_noise
# with noise_scale=1.0.
graphical_init_strategy = init_with_noise(
    init_to_value(values={'x': jnp.array([-10., 30.])}),
    noise_scale=1.0,
)

svgd = Stein(
    model,
    guide,
    numpyro.optim.Adagrad(step_size=1.0),
    Trace_ELBO(),
    kernels.GraphicalKernel(),
    init_strategy=graphical_init_strategy,
    num_particles=100,
)
svgd_state = svgd.init(rng_key)

In [None]:
# Run the GraphicalKernel Stein optimisation for `num_iterations` steps with a
# progress-bar callback; this overwrites `svgd_state` and `loss`.
graphical_callbacks = [Progbar()]
svgd_state, loss = svgd.run(rng_key, num_iterations, callbacks=graphical_callbacks)

In [None]:
# Clear the current figure, then draw the KDE of the particles held in the
# current `svgd_state`: first coordinate on the x-axis, second on the y-axis.
plt.clf()
graphical_particles = svgd.get_params(svgd_state)['x_auto_loc']
sns.kdeplot(x=graphical_particles[:, 0], y=graphical_particles[:, 1])