In [29]:
from VariationalBayes import VectorParam, ScalarParam, PosDefMatrixParam, ModelParamsDict
import math

from autograd import grad, hessian, jacobian, hessian_vector_product
from autograd.core import primitive
from autograd.numpy.numpy_grads import unbroadcast

import autograd.numpy as np
import autograd.numpy.random as npr

import copy
import scipy
from scipy import optimize
from scipy import stats

In [71]:
# Build an object to contain a variational approximation to a K-dimensional multivariate normal.

K = 5
mvn_par = ModelParamsDict()

mvn_par.push_param(VectorParam('e_mu', K))
mvn_par.push_param(VectorParam('var_mu', K, lb=0))

mvn_par['e_mu'].set(np.full(K, 0.1))
mvn_par['var_mu'].set(np.full(K, 2.))



In [60]:

def Log1mInvLogit(u):
    # log(1 + exp(-u)) = u + log(1 + exp(u))
    return -np.log1p(np.exp(u))
    



    
    

In [72]:
# Generate data

N = 200000
true_mu = np.random.rand(K).T - 0.5
x_mat = np.full([N, K], float('nan'))
y_vec = np.full([N], float('nan'))
for n in range(N):
    x_mat[n, :] = np.random.random(K) - 0.5
    y_vec[n] = np.random.random(1) < Logistic(np.dot(x_mat[n, :], true_mu))


In [73]:
# Define the variational objective
def LogLikelihood(x_row, y, e_mu, mu_var, std_draws):
    # logit(rho) is the probability of y being 1, which has a normal distribution under q().
    rho_mean = np.dot(x_row, e_mu)
    rho_sd = np.sqrt(np.sum(x_row * x_row * mu_var))
    e_log_1mrho = 0.
    for std_draw in std_draws:
        e_log_1mrho += Log1mInvLogit(std_draw * rho_sd + rho_mean)
    e_log_1mrho /= len(std_draws)
    # e_log_1mrho = np.mean(Log1mInvLogit(std_draws * rho_sd + rho_mean))
    return y * rho_mean + e_log_1mrho

# Just to see how much faster a simpler function is
def LogLikelihoodJunk(x_row, y, e_mu, mu_var, std_draws):
    # logit(rho) is the probability of y being 1, which has a normal distribution under q().
    rho_mean = np.dot(x_row, e_mu)
    rho_sd = np.sum(mu_var)
    e_log_1mrho = 0.
    for std_draw in std_draws:
        e_log_1mrho += np.exp(std_draw * rho_sd + rho_mean)
    e_log_1mrho /= len(std_draws)
    # e_log_1mrho = np.mean(Log1mInvLogit(std_draws * rho_sd + rho_mean))
    return y * rho_mean + e_log_1mrho


def UnivariateNormalExpectedEntropy(var_mu):
    return 0.5 * np.log(var_mu)


def Elbo(y_vec, x_mat, mvn_par_elbo, num_draws=10):
    var_mu = mvn_par_elbo['var_mu'].get()
    e_mu = mvn_par_elbo['e_mu'].get()

    num_draws = 10
    draw_spacing = 1 / float(num_draws + 1)
    target_quantiles = np.linspace(draw_spacing, 1 - draw_spacing, num_draws)
    std_draws = scipy.stats.norm.ppf(target_quantiles)

    assert y_vec.size == x_mat.shape[0]
    assert e_mu.size == x_mat.shape[1]

    ll = 0
    for n in range(y_vec.size):
        #ll += LogLikelihood(x_mat[n, :], y_vec[n], e_mu, var_mu, std_draws)
        ll += LogLikelihoodJunk(x_mat[n, :], y_vec[n], e_mu, var_mu, std_draws)

    entropy = sum([ UnivariateNormalExpectedEntropy(var_mu_k) for var_mu_k in var_mu])

    return ll + entropy


class KLWrapper():
    def __init__(self, mvn_par, x_mat, y_vec, num_draws):
        self.__mvn_par_ad = copy.deepcopy(mvn_par)
        self.x_mat = x_mat
        self.y_vec = y_vec
        self.num_draws = num_draws
        
    def Eval(self, free_par_vec, verbose=False):
        self.__mvn_par_ad.set_free(free_par_vec)
        kl = -Elbo(self.y_vec, self.x_mat, self.__mvn_par_ad, num_draws=self.num_draws)
        if verbose: print kl
        return kl
    
    # Return a posterior moment of interest as a function of
    # unconstrained parameters.  In this case it is a bit silly,
    # but in full generality posterior moments may be a complicated
    # function of moment parameters.
    def GetMu(self, free_par_vec):
        self.__mvn_par_ad.set_free(free_par_vec)
        return self.__mvn_par_ad['e_mu'].get()

    
kl_wrapper = KLWrapper(mvn_par, x_mat, y_vec, 10)
KLGrad = grad(kl_wrapper.Eval)
KLHess = hessian(kl_wrapper.Eval)
MomentJacobian = jacobian(kl_wrapper.GetMu)
KLHessVecProd = hessian_vector_product(kl_wrapper.Eval)  

In [74]:
# Check that the AD functions are working:
mvn_par['e_mu'].set(true_mu)
mvn_par['var_mu'].set(np.abs(true_mu) * 0.1)
free_par_vec = mvn_par.get_free()
print kl_wrapper.Eval(free_par_vec)
if K < 10:
    print 'Grad:'
    print KLGrad(free_par_vec)
    print 'Hess:'
    print KLHess(free_par_vec)
    print 'Jac:'
    print MomentJacobian(free_par_vec)
    print 'Hess vector product:'
    print KLHessVecProd(free_par_vec, free_par_vec + 1)

-916375487.479
Grad:
[  2.56821327e+08   2.14038365e+08   2.22173668e+08   2.65241202e+08
   2.59800995e+08  -1.03331548e+09  -1.73837905e+09  -1.62902433e+09
  -1.04581870e+09  -1.34448324e+09]
Hess:
[[ -1.08601594e+08  -5.97832492e+07  -6.20320320e+07  -7.46762397e+07
   -7.26689854e+07   2.89591835e+08   4.87189427e+08   4.56542221e+08
    2.93095923e+08   3.76798155e+08]
 [ -5.97832492e+07  -9.81902660e+07  -5.17464564e+07  -6.23715383e+07
   -6.05022931e+07   2.41351910e+08   4.06033889e+08   3.80491864e+08
    2.44272291e+08   3.14031487e+08]
 [ -6.20320320e+07  -5.17464564e+07  -1.00303898e+08  -6.47180963e+07
   -6.28799236e+07   2.50525114e+08   4.21466257e+08   3.94953442e+08
    2.53556491e+08   3.25967066e+08]
 [ -7.46762397e+07  -6.23715383e+07  -6.47180963e+07  -1.12241665e+08
   -7.54802529e+07   2.99086183e+08   5.03162067e+08   4.71510084e+08
    3.02705154e+08   3.89151587e+08]
 [ -7.26689854e+07  -6.05022931e+07  -6.28799236e+07  -7.54802529e+07
   -1.09826192e+08   

In [70]:
import timeit

time_num = 1

print 'Function time:'
print timeit.timeit(lambda: kl_wrapper.Eval(free_par_vec), number=time_num) / time_num

print 'Grad time:'
print timeit.timeit(lambda: KLGrad(free_par_vec), number=time_num) / time_num

print 'Hessian vector product time:'
print timeit.timeit(lambda: KLHessVecProd(free_par_vec, free_par_vec + 1), number=time_num) / time_num

if K < 10:
    print 'Hessian time:'
    print timeit.timeit(lambda: KLHess(free_par_vec), number=time_num) / time_num


Function time:
0.0106568336487
Grad time:
0.308406114578
Hessian vector product time:
0.609016895294


In [45]:
import cProfile

cProfile.run('kl_wrapper.Eval(free_par_vec)', '/tmp/cprofilestats_func.prof')
cProfile.run('KLHess(free_par_vec)', '/tmp/cprofilestats_hess.prof')
cProfile.run('KLGrad(free_par_vec)', '/tmp/cprofilestats_grad.prof')

In [44]:
import pstats
p = pstats.Stats('/tmp/cprofilestats')
p.strip_dirs().sort_stats('cumulative').print_stats(100)
# p.strip_dirs().sort_stats('cumulative').print_callers(100)

Wed Feb 22 18:17:23 2017    /tmp/cprofilestats

         12676517 function calls (12415228 primitive calls) in 7.331 seconds

   Ordered by: cumulative time
   List reduced from 183 to 100 due to restriction <100>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.010    0.010    7.331    7.331 <string>:1(<module>)
      2/1    0.000    0.000    7.321    7.321 convenience_wrappers.py:53(jacfun)
214857/14    0.124    0.000    7.088    0.506 {map}
       11    0.036    0.003    7.088    0.644 core.py:18(<lambda>)
       11    0.886    0.081    7.051    0.641 core.py:31(backward_pass)
   221916    0.164    0.000    2.397    0.000 core.py:76(vjp)
222222/175779    0.888    0.000    1.606    0.000 core.py:59(__call__)
   179768    0.650    0.000    1.515    0.000 core.py:166(toposort)
   265957    0.276    0.000    1.147    0.000 core.py:253(assert_vspace_match)
   179757    0.153    0.000    0.854    0.000 core.py:127(vsum)
      2/1    0.000    0.000   

<pstats.Stats instance at 0x7fde9c733e18>

In [None]:
# Set initial values.

# Is there not a better way than reduce?
true_means = reduce(lambda x, y: x + y, x_draws) / N

mvn_par['e_mu'].set(np.full(K, 1.0))
init_par_vec = mvn_par.get_free()

In [None]:
# Optimize.

print 'Running BFGS'
vb_opt_bfgs = optimize.minimize(
    lambda par: kl_wrapper.Eval(par, verbose=True), init_par_vec,
    method='bfgs', jac=KLGrad, tol=1e-6)
print 'Running Newton Trust Region'
vb_opt = optimize.minimize(
    lambda par: kl_wrapper.Eval(par, verbose=True),
    vb_opt_bfgs.x, method='trust-ncg', jac=KLGrad, hess=KLHess)
mvn_par_opt = copy.deepcopy(mvn_par)
mvn_par_opt.set_free(vb_opt.x)
print 'Done.'

In [None]:
# The mean parameters match, as expected.
print mvn_par_opt['e_mu']
print true_means

In [None]:
# LRVB
moment_jac = MomentJacobian(vb_opt.x)
opt_hess = KLHess(vb_opt.x)
mu_cov = np.matmul(moment_jac, np.linalg.solve(opt_hess, moment_jac.T))

# The VB variance is underestimated.
print np.diag(mu_cov)
print mvn_par_opt['var_mu']