In [None]:
import variational_bayes as vb
import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline

In [None]:
# Ground-truth parameters of the 2-D Gaussian that the synthetic data is drawn from.
SEED = 0  # fixed so that Restart & Run All reproduces the same samples
mean = np.asarray((5, -2))
precision = np.asarray([[4, -1], [-1, 2]])
n = 100  # number of samples to draw

# Seed NumPy's global generator right before sampling; this also pins any later
# code that draws from np.random's global state in this kernel.
np.random.seed(SEED)

# multivariate_normal is parameterised by a covariance matrix, so the
# precision matrix is inverted first.
x = np.random.multivariate_normal(mean, np.linalg.inv(precision), n)

In [None]:
# Initial variational factors for the two unknowns.
# NOTE(review): the second argument of MultiNormalDistribution and both
# arguments of WishartDistribution are passed positionally; their exact
# meaning is defined by `vb` -- confirm against the library.
q_mean = vb.MultiNormalDistribution(np.zeros(2), 1e-3 * np.eye(2))
q_precision = vb.WishartDistribution(np.asarray(3), np.eye(2))

# One term per factor-graph node, built in the same order they are handed to
# the model: the data term first, then one term on each factor.
# NOTE(review): the last two terms look like priors on the mean and on the
# precision respectively -- presumably weakly informative; verify with `vb`.
observation_term = vb.MultiNormalLikelihood(x, q_mean, q_precision)
mean_term = vb.MultiNormalLikelihood(q_mean, 0, np.eye(2) * 1e-3)
precision_term = vb.WishartLikelihood(q_precision, 2, np.eye(2) * 1e-3)
likelihoods = [observation_term, mean_term, precision_term]

# Register the factors under the names used to look them up later
# (model['mean'], model['precision']).
factors = {'mean': q_mean, 'precision': q_precision}
model = vb.Model(factors, likelihoods)

# Run the update routine with argument 10 and keep what it returns
# (plotted as the ELBO trace in a later cell).
elbo = model.update(10)

In [None]:
# Convergence diagnostic: plot the values returned by model.update against
# their position in the returned sequence, on an explicit, labelled axes so
# the figure stands on its own when the notebook is skimmed.
# NOTE(review): assumes `elbo` holds one value per update step -- confirm with `vb`.
elbo_fig, elbo_ax = plt.subplots()
elbo_ax.plot(elbo)
elbo_ax.set_xlabel("Update step")
elbo_ax.set_ylabel("ELBO")
elbo_ax.set_title("ELBO trace")

# Summaries of the fitted factors, read through the names they were
# registered under in the model.
print("Mean     : %s +- %s" % (model['mean'].mean, model['mean'].std))
print("Precision: %s +- %s" % (model['precision'].mean, model['precision'].std))

In [None]:
# Side-by-side comparison figure: one panel per fitted factor, each compared
# against the ground-truth value used to generate the data.
fig, panels = plt.subplots(nrows=1, ncols=2)
ax1, ax2 = panels

# Left panel: mean factor vs. true mean; right panel: precision factor vs.
# true precision.
vb.plot_comparison(q_mean, mean, ax=ax1)
vb.plot_comparison(q_precision, precision, ax=ax2)