diff --git a/pymc3/variational/svgd.py b/pymc3/variational/svgd.py index 8d4663bd29..871c796011 100644 --- a/pymc3/variational/svgd.py +++ b/pymc3/variational/svgd.py @@ -7,7 +7,7 @@ import theano import theano.tensor as tt from tqdm import tqdm -from .updates import adagrad +from .updates import adagrad, apply_momentum import pymc3 as pm from pymc3.model import modelcontext @@ -91,7 +91,8 @@ def svgd(vars=None, n=5000, n_particles=100, jitter=.01, logp_grad_vec = _make_vectorized_logp_grad(vars, model, theta) svgd_grad = -1 * _svgd_gradient(vars, model, theta, logp_grad_vec) # maximize - svgd_updates = optimizer([svgd_grad], [theta], learning_rate=1e-3) + svgd_updates = optimizer([svgd_grad], [theta], learning_rate=1.) + svgd_updates = apply_momentum(svgd_updates, [theta], momentum=0.9) i = tt.iscalar('i') svgd_step = theano.function([i], [i],