In [None]:
import numpy as np
import elboflow as ef
import tensorflow as tf
import scipy.stats
from tqdm import tqdm_notebook
from matplotlib import pyplot as plt
%matplotlib inline

In [None]:
# Generate some data from a multivariate normal distribution with a random mean
# and a random precision matrix. The true `mean` and `precision` are kept so the
# fitted posteriors can be compared against them at the end of the notebook.
np.random.seed(2)
num_samples = 1000
num_dims = 2
# Shift applied to the standard-normal draw of the true mean. It needs one entry
# per dimension, so fail loudly if `num_dims` is changed without updating it.
mean_offset = np.array([3.0, -2.0])
assert mean_offset.shape == (num_dims,), "mean_offset needs one entry per dimension"

mean = np.random.normal(0, 1, num_dims) + mean_offset
# Random precision matrix: Wishart draw with `num_dims` degrees of freedom and
# identity scale; the covariance is its inverse.
precision = scipy.stats.wishart.rvs(num_dims, np.eye(num_dims))
cov = np.linalg.inv(precision)

x = np.random.multivariate_normal(mean, cov, num_samples)

# Scatter the first two coordinates explicitly; unpacking `*x.T` would pass any
# third dimension to `scatter` as marker sizes. Dotted lines mark the true mean.
fig, ax = plt.subplots(1, 1)
ax.scatter(x[:, 0], x[:, 1])
ax.axhline(mean[1], ls=':')
ax.axvline(mean[0], ls=':')
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
ax.set_title('Simulated data (dotted lines: true mean)')
ax.set_aspect('equal')

In [None]:
# Define the variational factors: a multivariate normal for the mean `mu` and a
# Wishart for the precision matrix `tau`.
q_mu = ef.MultiNormalDistribution(
    ef.get_variable('mu_mean', num_dims),
    ef.get_positive_definite_variable('mu_precision', (num_dims, num_dims))
)
# Adding `num_dims - 1` to a positive variable keeps the Wishart degrees of
# freedom strictly above `num_dims - 1`, as a proper Wishart requires.
q_tau = ef.WishartDistribution(
    ef.get_positive_variable('tau_dof', []) + float(num_dims - 1),
    ef.get_positive_definite_variable('tau_scale', (num_dims, num_dims)),
)

# NOTE(review): the second argument looks like a precision (the rest of this
# notebook parameterizes normals by precision), which would make this a broad
# prior on the mean — confirm against the elboflow API.
prior_mu = ef.NormalDistribution(0.0, 1e-3)
# Degrees of freedom of `num_dims` keep the prior valid (> num_dims - 1) for any
# dimensionality; for num_dims == 2 this is the previous hard-coded 2.0.
prior_tau = ef.WishartDistribution(float(num_dims), 2.0 * np.eye(num_dims))

# Log-likelihood terms of the observed data `x` given the factors `q_mu` and
# `q_tau`; they are summed into the log joint below.
log_likelihood = ef.MultiNormalDistribution.log_likelihood(x, q_mu, q_tau)
log_joint = (
    tf.reduce_sum(log_likelihood)
    + tf.reduce_sum(prior_mu.log_proba(q_mu))
    + tf.reduce_sum(prior_tau.log_proba(q_tau))
)
entropy = tf.reduce_sum(q_mu.entropy) + tf.reduce_sum(q_tau.entropy)
# Evidence lower bound: log joint plus the entropy of the variational factors.
elbo = log_joint + entropy

# Add a training operation that maximizes the ELBO by minimizing its negative.
learning_rate = 1.0
train_op = tf.train.AdamOptimizer(learning_rate).minimize(-elbo)
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Display the ELBO at the initial parameter values.
sess.run(elbo)

In [None]:
# Maximize the ELBO, recording the ELBO and the first statistic of `q_tau` at
# every optimization step.
num_steps = 2000
# Build the fetch list once. Calling `q_tau.statistic(1)` inside the loop may
# add fresh ops to the TF graph on every iteration (unless elboflow caches
# statistics), which would slow training and grow memory.
fetches = [train_op, elbo, q_tau.statistic(1)]

elbos = []
precisions = []
for _ in tqdm_notebook(range(num_steps)):
    _, elbo_value, precision_value = sess.run(fetches)
    elbos.append(elbo_value)
    precisions.append(precision_value)

# Plot the loss (negative ELBO); it should decrease as training converges.
fig_loss, ax_loss = plt.subplots(1, 1)
ax_loss.plot(-np.asarray(elbos))
ax_loss.set_xlabel('iteration')
ax_loss.set_ylabel('negative ELBO');

In [None]:
# Compare each fitted factor with the ground truth used to simulate the data:
# the left panel shows the mean, the right panel the precision matrix.
fig, axes = plt.subplots(1, 2)
ax_mean, ax_precision = axes
ef.plot_comparison(sess, q_mu, mean, ax=ax_mean)
ef.plot_comparison(sess, q_tau, precision, ax=ax_precision)