In [1]:
from VariationalBayes import VectorParam, ScalarParam, PosDefMatrixParam, ModelParamsDict
import math

from autograd import grad, hessian, jacobian, hessian_vector_product
from autograd.core import primitive
from autograd.numpy.numpy_grads import unbroadcast

import autograd.numpy as np
import autograd.numpy.random as npr

import copy
import scipy
from scipy import optimize
from scipy import stats

In [5]:
def Log1mInvLogit(u):
    # log(1 - p) = log(1 / (1 + exp(u))) = -log(1 + exp(u))
    return -np.log1p(np.exp(u))

def Logistic(u):
    return np.exp(u) / (1 + np.exp(u))

In [105]:
N = 20000
K = 100

# Build an object to contain a variational approximation to a K-dimensional multivariate normal.
mvn_par = ModelParamsDict()

mvn_par.push_param(VectorParam('e_mu', K))
mvn_par.push_param(VectorParam('var_mu', K, lb=0))

mvn_par['e_mu'].set(np.full(K, 0.1))
mvn_par['var_mu'].set(np.full(K, 2.))

# Generate data
true_mu = np.random.rand(K) - 0.5
x_mat = np.full([N, K], float('nan'))
y_vec = np.full([N], float('nan'))
for n in range(N):
    x_mat[n, :] = np.random.random(K) - 0.5
    y_vec[n] = np.random.random(1) < Logistic(np.dot(x_mat[n, :], true_mu))


In [106]:
def LogLikelihoodVectorized(x_mat, y_vec, e_mu, mu_var, std_draws):
    rho_sd = np.sqrt(np.einsum('ik,ik,k->i', x_mat, x_mat, mu_var))
    rho_mean = np.einsum('ij,j->i', x_mat, e_mu)
    z = np.einsum('i,j->ij', rho_sd, std_draws) + np.expand_dims(rho_mean, 1)

    # The sum is over observations and draws, so dividing by the draws size
    # gives the sum of sample expectations over the draws.
    # log(1 - p) = log(1 / (1 + exp(u))) = -log(1 + exp(u))
    logit_term = -np.sum(np.log1p(np.exp(z))) / std_draws.size
    y_term = np.sum(y_vec * rho_mean)
    return y_term + logit_term


def UnivariateNormalExpectedEntropy(var_mu):
    return 0.5 * np.sum(np.log(var_mu))


def Elbo(y_vec, x_mat, mvn_par_elbo, num_draws=10):
    var_mu = mvn_par_elbo['var_mu'].get()
    e_mu = mvn_par_elbo['e_mu'].get()

    num_draws = 10
    draw_spacing = 1 / float(num_draws + 1)
    target_quantiles = np.linspace(draw_spacing, 1 - draw_spacing, num_draws)
    std_draws = scipy.stats.norm.ppf(target_quantiles)

    assert y_vec.size == x_mat.shape[0]
    assert e_mu.size == x_mat.shape[1]

    ll = LogLikelihoodVectorized(x_mat, y_vec, e_mu, var_mu, std_draws)
    entropy = 0.5 * np.sum(np.log(var_mu))
    #entropy = UnivariateNormalExpectedEntropy(var_mu)

    return ll + entropy


class KLWrapper():
    def __init__(self, mvn_par, x_mat, y_vec, num_draws):
        self.__mvn_par_ad = copy.deepcopy(mvn_par)
        self.x_mat = x_mat
        self.y_vec = y_vec
        self.num_draws = num_draws
        
    def Eval(self, free_par_vec, verbose=False):
        self.__mvn_par_ad.set_free(free_par_vec)
        kl = -Elbo(self.y_vec, self.x_mat, self.__mvn_par_ad, num_draws=self.num_draws)
        if verbose: print kl
        return kl
    
    # Return a posterior moment of interest as a function of
    # unconstrained parameters.  In this case it is a bit silly,
    # but in full generality posterior moments may be a complicated
    # function of moment parameters.
    def GetMu(self, free_par_vec):
        self.__mvn_par_ad.set_free(free_par_vec)
        return self.__mvn_par_ad['e_mu'].get()



In [107]:
kl_wrapper = KLWrapper(mvn_par, x_mat, y_vec, 10)
KLGrad = grad(kl_wrapper.Eval)
KLHess = hessian(kl_wrapper.Eval)
MomentJacobian = jacobian(kl_wrapper.GetMu)
KLHessVecProd = hessian_vector_product(kl_wrapper.Eval)  

# Check that the AD functions are working:
mvn_par['e_mu'].set(true_mu)
mvn_par['var_mu'].set(np.abs(true_mu) * 0.1)
free_par_vec = mvn_par.get_free()
print kl_wrapper.Eval(free_par_vec)
if K < 10:
    print 'Grad:'
    print KLGrad(free_par_vec)
    print 'Hess:'
    print KLHess(free_par_vec)
    print 'Jac:'
    print MomentJacobian(free_par_vec)
    print 'Hess vector product:'
    print KLHessVecProd(free_par_vec, free_par_vec + 1)

102702.375647


In [108]:
import timeit

time_num = 10

print 'Function time:'
print timeit.timeit(lambda: kl_wrapper.Eval(free_par_vec), number=time_num) / time_num

print 'Grad time:'
print timeit.timeit(lambda: KLGrad(free_par_vec), number=time_num) / time_num

print 'Hessian vector product time:'
print timeit.timeit(lambda: KLHessVecProd(free_par_vec, free_par_vec + 1), number=time_num) / time_num

print 'Hessian time:'
print timeit.timeit(lambda: KLHess(free_par_vec), number=time_num) / time_num


Function time:
0.0119942188263
Grad time:
0.0181173086166
Hessian vector product time:
0.0369627952576
Hessian time:
2.93001360893


In [67]:
import cProfile

profile = False
if profile:
    cProfile.run('kl_wrapper.Eval(free_par_vec)', '/tmp/cprofilestats_func.prof')
    cProfile.run('KLHess(free_par_vec)', '/tmp/cprofilestats_hess.prof')
    cProfile.run('KLGrad(free_par_vec)', '/tmp/cprofilestats_grad.prof')

# A better way to visualize this https://jiffyclub.github.io/snakeviz/
# snakeviz /tmp/cprofilestats_hess.prof

In [110]:
# Set initial values.
xtx = np.matmul(x_mat.T, x_mat)
mu_reg = np.linalg.solve(xtx, np.matmul(x_mat.T, y_vec))

mvn_par['e_mu'].set(mu_reg)
mvn_par['var_mu'].set(np.full(K, 1.))
init_par_vec = mvn_par.get_free()

In [None]:
# Optimize.
print 'Running BFGS'
vb_opt_bfgs = optimize.minimize(
    lambda par: kl_wrapper.Eval(par, verbose=True), init_par_vec,
    method='bfgs', jac=KLGrad, tol=1e-6)
print 'Running Newton Trust Region'
vb_opt = optimize.minimize(
    lambda par: kl_wrapper.Eval(par, verbose=True),
    vb_opt_bfgs.x, method='trust-ncg', jac=KLGrad, hess=KLHess)
mvn_par_opt = copy.deepcopy(mvn_par)
mvn_par_opt.set_free(vb_opt.x)
print 'Done.'

Running BFGS
23585.1071997
22324.0309127
20312.3601863
30440.7773204
20047.1742141
19725.4863251
19158.0105684
18221.1566033
16764.0732477
23359.3288387
16570.9366342
16278.9191442
17751.1322485
16238.307753
16160.6396242
16026.4170607
15771.6683809
15449.2118817
15481.1045371
15296.3545072
15164.7785416
15153.7789333
15132.3843558
15062.7782711
14926.0734808
14684.9332012
14613.4472738
14480.2571743
14506.8516664
14419.2310976
14314.2303931
14347.729729
14268.5837817
14188.405458
14115.5666034
14004.5787093
13843.8378821
13988.6873791
13788.5400356
13688.0289203
13818.377073
13657.2017786
13599.5591292
13544.5384318
13478.2035614
13393.1661773
13621.2939628
13374.8023686
13341.536331
13334.6989842
13321.1575415
13272.1622116
13228.2042251
13245.0807301
13209.561758
13173.6481262
13061.9778072
13882.727208
13042.4808753
13037.2751121
13027.0080077
12990.7160911
13119.8916185
12980.8379181
12962.5717675
12933.0589833
13068.2559523
12928.5113124
12919.5187754
12887.3024672
12890.2561884


In [99]:
# The mean parameters match, as expected.
print np.vstack((mvn_par_opt['e_mu'].get(), true_mu)).T

[[ 0.03972175  0.03602481]
 [ 0.34474658  0.28744943]
 [-0.15526871 -0.10642935]
 [-0.10595918 -0.10104007]
 [ 0.34608339  0.2459426 ]
 [ 0.24852655  0.3611111 ]
 [-0.32025061 -0.34503632]
 [ 0.42270831  0.40722197]
 [ 0.33635276  0.30287017]
 [-0.18016719 -0.15084181]
 [-0.26960135 -0.35976978]
 [ 0.2433167   0.22876347]
 [ 0.37648733  0.33610404]
 [ 0.18495355  0.21945538]
 [-0.12999183 -0.17547628]
 [ 0.50363108  0.49322671]
 [-0.40239219 -0.43670232]
 [ 0.3802229   0.42390039]
 [ 0.40510501  0.35875045]
 [-0.37619394 -0.34651961]
 [ 0.50476477  0.4536665 ]
 [-0.32689877 -0.33462315]
 [-0.1130267  -0.15972893]
 [-0.20568697 -0.13699205]
 [-0.01741392 -0.02381676]
 [ 0.37405925  0.4464989 ]
 [ 0.08545523  0.09282559]
 [ 0.1532413   0.13079413]
 [ 0.46029478  0.44468218]
 [-0.35001903 -0.31313023]
 [-0.14182729 -0.21457137]
 [ 0.45501067  0.4720647 ]
 [ 0.1565683   0.14717246]
 [ 0.04646432  0.04107799]
 [-0.37272817 -0.3582593 ]
 [-0.27828877 -0.23237042]
 [-0.0248029  -0.03265384]
 

In [100]:
# LRVB
moment_jac = MomentJacobian(vb_opt.x)
opt_hess = KLHess(vb_opt.x)
mu_cov = np.matmul(moment_jac, np.linalg.solve(opt_hess, moment_jac.T))

# The VB variance is underestimated.
print np.diag(mu_cov)
print mvn_par_opt['var_mu']

[ 0.00261606  0.00262803  0.00263939  0.00263788  0.00264194  0.0026511
  0.00261329  0.00262633  0.00262754  0.00264534  0.00262042  0.00262479
  0.00263479  0.00263222  0.00261616  0.00265636  0.00262059  0.00261051
  0.0026039   0.00264844  0.0026166   0.00264532  0.00264977  0.00262649
  0.00261183  0.00268524  0.00262318  0.00262324  0.00263058  0.00262665
  0.00258853  0.00264293  0.0026046   0.00260108  0.00267234  0.00264285
  0.00260006  0.00263511  0.00261834  0.0026435   0.00265169  0.00264181
  0.00260559  0.00262355  0.00264402  0.00264773  0.00264964  0.00265635
  0.00264481  0.00265818]
var_mu: [ 0.00419352  0.00420885  0.00423482  0.00423043  0.00423667  0.00425047
  0.00418842  0.004209    0.00421243  0.00424299  0.00420174  0.0042097
  0.00422118  0.0042229   0.00419474  0.00425071  0.00419174  0.00418174
  0.00417598  0.00424355  0.00418947  0.00423973  0.00424825  0.00421205
  0.00418598  0.00430262  0.00420812  0.0042025   0.00421254  0.00420603
  0.00415222  0.004