# Variational Inference: Multivariate Gaussian

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats

## Utilities

In [2]:
def genGauss(n=100,loc=np.array([0,0]), scale=np.array([1,1])):
    dim = loc.shape[0]
    X = np.random.normal(loc=loc, scale=scale, size=(n,dim))
    return X

In [3]:
def plot_data(X):
    plt.figure(figsize=(10,7))
    plt.scatter(X[:,0], X[:,1])
#     plt.legend()
    plt.show()

In [22]:
from scipy.special import gamma as gamma_func

def elbo(λ, a, b):
    return -0.5*np.log(λ) + np.log(gamma_func(a)) - a*np.log(b) 

In [23]:
def var_inference(data, num_iter=10):
    r_0 = np.ones(data.shape[1])
    μ_0 = np.zeros(data.shape[1])
    a_0 = np.ones(data.shape[1])
    b_0 = np.ones(data.shape[1])
    N = len(data)
    μ_N = (r_0*μ_0 + N * np.mean(data, axis=0)) / (r_0 + N)
    a_N = a_0 + (N+1)/2
    r_N = np.ones(data.shape[1])
    
    ELBO_prev = -1e9
    for i in range(num_iter):
        E_mu_2 = 1/r_N + np.power(μ_N,2)
        b_N = b_0 + r_0*(E_mu_2 + μ_0**2 - 2*μ_N*μ_0) + 0.5*np.sum(np.power(data,2) + E_mu_2 - 2*μ_N*data)
        r_N = (r_0 + N)*(a_N / b_N)
        
        ELBO_cur = elbo(r_N, a_N, b_N)
#         assert ELBO_cur >= ELBO_prev
#         ELBO_prev = ELBO_cur
        print('ELBO at step %3i/%3i is'%(i+1, num_iter), ELBO_cur)
        print('a_N=', a_N)
        print('b_N=', b_N)
        print('r_N=', r_N)
    return μ_N

## Run test

In [33]:
num_samples = 300
true_mu = np.array([2,3,5])
true_sigma = np.array([1,2,3])
X = genGauss(num_samples, true_mu, true_sigma)

In [36]:
μ_estimated = var_inference(X, num_iter=8)
print(μ_estimated)
print(true_mu)
print(true_mu - μ_estimated)

ELBO at step   1/  8 is [-576.7192618  -577.06716827 -577.98615983]
a_N= [151.5 151.5 151.5]
b_N= [2458.17475213 2463.84495657 2478.88573068]
r_N= [18.55095939 18.50826688 18.39596696]
ELBO at step   2/  8 is [-547.93733855 -548.35820799 -549.4689558 ]
a_N= [151.5 151.5 151.5]
b_N= [2031.57294058 2037.24326936 2052.2843733 ]
r_N= [22.44640056 22.38392473 22.21987391]
ELBO at step   3/  8 is [-547.62342085 -548.04516494 -549.15820939]
a_N= [151.5 151.5 151.5]
b_N= [2027.35383967 2033.02416845 2048.0652724 ]
r_N= [22.49311349 22.43037771 22.26564779]
ELBO at step   4/  8 is [-547.62031294 -548.04206569 -549.1551329 ]
a_N= [151.5 151.5 151.5]
b_N= [2027.31211265 2032.98244144 2048.02354539]
r_N= [22.49357645 22.4308381  22.26610143]
ELBO at step   5/  8 is [-547.6202822  -548.04203504 -549.15510248]
a_N= [151.5 151.5 151.5]
b_N= [2027.31169997 2032.98202876 2048.02313271]
r_N= [22.49358103 22.43084265 22.26610592]
ELBO at step   6/  8 is [-547.62028189 -548.04203473 -549.15510217]
a_N= [1