In [None]:
import numpy as np
import elboflow as ef
import tensorflow as tf
import scipy.stats
from tqdm import tqdm_notebook
from matplotlib import pyplot as plt
%matplotlib inline

In [None]:
# Simulate a linear-regression data set: Gaussian features, Gaussian
# coefficients and Gaussian observation noise with precision `tau`.
# The order of the random draws below is significant: it fixes the values
# obtained from the seeded global NumPy generator.
np.random.seed(1)
num_samples, num_dims = 100, 3

# Design matrix with one row per sample and the true coefficient vector
x = np.random.normal(loc=0, scale=1, size=(num_samples, num_dims))
theta = np.random.normal(loc=0, scale=1, size=num_dims)
predictor = x.dot(theta)

# Noise precision drawn from a unit-shape gamma distribution; the noise
# standard deviation is the inverse square root of the precision
tau = np.random.gamma(shape=1)
noise_scale = 1 / np.sqrt(tau)
noise = np.random.normal(loc=0, scale=noise_scale, size=num_samples)
y = predictor + noise

In [None]:
# Build the variational model inside its own TensorFlow graph
with tf.Graph().as_default() as graph:
    # Factor for the regression coefficients: a free mean plus a second
    # parameter that is exponentiated and therefore positive
    # (presumably a precision, judging by the variable name -- confirm
    # against the elboflow documentation)
    theta_mean = tf.get_variable(
        'theta_mean', num_dims, initializer=tf.random_normal_initializer()
    )
    theta_log_precision = tf.get_variable(
        'theta_log_precision', num_dims, initializer=tf.random_normal_initializer()
    )
    q_theta = ef.NormalDistribution(theta_mean, tf.exp(theta_log_precision))

    # Factor for the noise precision: both scalar variables are passed
    # through `exp`, so the distribution receives positive parameters
    tau_log_shape = tf.get_variable(
        'tau_shape', [], initializer=tf.random_normal_initializer()
    )
    tau_log_scale = tf.get_variable(
        'tau_scale', [], initializer=tf.random_normal_initializer()
    )
    q_tau = ef.GammaDistribution(tf.exp(tau_log_shape), tf.exp(tau_log_scale))

    # Expected log joint: likelihood of the observations plus the prior terms
    # for both factors (the trailing `True` presumably requests a reduced,
    # scalar result -- TODO confirm with elboflow)
    log_likelihood = ef.NormalDistribution.linear_log_likelihood(y, x, q_theta, q_tau, True)
    theta_log_prior = ef.NormalDistribution(0, 1e-3).log_pdf(q_theta, True)
    tau_log_prior = ef.GammaDistribution(1e-3, 1e-3).log_pdf(q_tau, True)
    log_joint = log_likelihood + theta_log_prior + tau_log_prior

    # The ELBO is the expected log joint plus the entropy of each factor;
    # the entropy of `q_theta` is summed before it is added
    theta_entropy = tf.reduce_sum(q_theta.entropy)
    elbo = log_joint + theta_entropy + q_tau.entropy

    # Maximising the ELBO is expressed as minimising its negative with Adam
    optimizer = tf.train.AdamOptimizer(learning_rate=0.1)
    train_op = optimizer.minimize(-elbo)

    # The session is created while `graph` is the default graph and is kept
    # open for use by the following cells
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

# Display the ELBO at the randomly initialised parameters
sess.run(elbo)

In [None]:
# Take optimisation steps and record the ELBO fetched alongside each one
num_steps = 1000
elbos = []

for step in tqdm_notebook(range(num_steps)):
    # Fetch the training op together with the ELBO; only the ELBO value is kept
    step_results = sess.run([train_op, elbo])
    elbos.append(step_results[1])

# Trace of the negated ELBO on a logarithmic y-axis
trace_ax = plt.gca()
trace_ax.plot(-np.asarray(elbos))
trace_ax.set_yscale('log')

In [None]:
# One panel per variational factor, each drawn with the value used to
# simulate the data passed as `reference`
fig, axes = plt.subplots(nrows=1, ncols=2)
ax1, ax2 = axes

with graph.as_default():
    panels = ((q_theta, theta, ax1), (q_tau, tau, ax2))
    for factor, true_value, panel_ax in panels:
        ef.plot_pdf(sess, factor, reference=true_value, ax=panel_ax)

In [None]:
# Compare the fitted coefficient factor with the coefficients used to
# simulate the data, drawing on the current matplotlib axes
comparison_ax = plt.gca()

with graph.as_default():
    foo = ef.plot_comparison(sess, q_theta, theta, ax=comparison_ax)