**Goal: variational inference on the concentration parameters of a Dirichlet distribution**

Model ($n = 200$ observations $y_i$, each in $d = 5$ dimensions):
$$
\begin{aligned}
y_i &\sim \text{Dirichlet}(y_i\mid\alpha), \quad i = 1, \dots, n\\
\log\alpha &\sim \text{MVNormal}(\log\alpha\mid \log(0.5)\,\mathbf{1}_d,\ I_d)
\end{aligned}
$$

In [1]:
import silence_tensorflow.auto
import json
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
from tensorflow_probability import bijectors as tfb
from numpy.random import gamma

from tfprojgamma import ProjectedGamma
# Reproducibility: seed TensorFlow's global seed (drives the Dirichlet data draws
# and the surrogate sampling below) and numpy's legacy global RNG (drives the
# `gamma` call that generates alpha_true).
tf.random.set_seed(1)
np.random.seed(1)

In [2]:
# Ground-truth Dirichlet concentration: 5 iid Gamma(shape = 1.5) draws from
# numpy's global RNG.
alpha_true = gamma(shape = 1.5, size = 5)
print(alpha_true)

# Observed data: 200 iid Dirichlet(alpha_true) draws (float64, like alpha_true),
# cast to float32 to match the model's dtype.
# Alternative data source, currently disabled:
#   dist1 = ProjectedGamma(alpha_true, 10); Yp = dist1.sample(200)
dist1 = tfd.Dirichlet(concentration = alpha_true)
Y = tf.cast(dist1.sample(sample_shape = 200), dtype = tf.float32)

[3.94761907 0.62279234 0.68411215 0.34912511 4.2482095 ]


In [3]:
# Prior mean of log(alpha): log(0.5) in each of the 5 coordinates, as float32.
log_alpha_0 = tf.fill([5], np.float32(np.log(0.5)))

**Define the Joint Distribution**

In [4]:
# define generator
def generative_model(log_alpha_0, n_samples):
    """Coroutine for the joint model p(log_alpha, Yp).

    log_alpha ~ MVNormal(log_alpha_0, I_d), and Yp holds `n_samples` iid
    Dirichlet(exp(log_alpha)) draws as one event of shape [n_samples, d].

    Args:
      log_alpha_0: 1-D tensor, the prior mean of log(alpha); its length sets d.
      n_samples: number of Dirichlet observations (rows of Yp).
    """
    log_alpha = yield tfd.JointDistributionCoroutine.Root(
        tfd.MultivariateNormalDiag(
            loc = log_alpha_0,
            scale_diag = tf.ones_like(log_alpha_0),  # identity prior covariance
            name = 'log_alpha',
            ),
        )
    # sample_shape = n_samples turns the iid Dirichlet draws into a single
    # [n_samples, d] event, so the joint reports event_shape Yp=[n_samples, d].
    yield tfd.Sample(
        tfd.Dirichlet(concentration = tf.exp(log_alpha)),
        sample_shape = n_samples,
        name = 'Yp',
        )

# The number of observations is read from the data instead of being hard-coded.
model_joint = tfd.JointDistributionCoroutineAutoBatched(
    lambda: generative_model(log_alpha_0, Y.shape[0]),
    )

def model_joint_log_prob(log_alpha):
    """Unnormalized log posterior: log p(log_alpha, Y) with the observed Y held fixed."""
    return model_joint.log_prob(log_alpha, Y)

**Verifying the structure of the joint model**

In [5]:
# Display the joint distribution to confirm each component's event shape and dtype.
model_joint

<tfp.distributions.JointDistributionCoroutineAutoBatched 'JointDistributionCoroutineAutoBatched' batch_shape=[] event_shape=StructTuple(
  log_alpha=[5],
  Yp=[200, 5]
) dtype=StructTuple(
  log_alpha=float32,
  Yp=float32
)>

**Mean-Field Variational Bayes -- Independence between columns**

The surrogate posterior factorizes over the coordinates of $\log\alpha$:
$$
q(\log\alpha) = \prod_{\ell = 1}^d\text{Normal}(\log\alpha_{\ell} \mid \mu_{q\ell}, \sigma_{q\ell})
$$

In [6]:
# Mean-field surrogate: independent Normals on each coordinate of log(alpha).
# The location starts at the prior mean; the scale is kept positive via an Exp bijector.
q_mu = tf.Variable(log_alpha_0, dtype = tf.float32)
q_scale = tfp.util.TransformedVariable(
    np.ones(5), bijector = tfb.Exp(), dtype = tf.float32,
    )
surrogate_posterior = tfd.MultivariateNormalDiag(
    loc = q_mu,
    scale_diag = q_scale,
    name = 'surrogate 1',
    )

# Sanity check: a 100-draw Monte Carlo estimate of the negative ELBO should be
# differentiable with respect to both trainable variables.
with tf.GradientTape() as tape:
    draws = surrogate_posterior.sample(100)
    log_ratio = model_joint_log_prob(draws) - surrogate_posterior.log_prob(draws)
    neg_elbo = -tf.reduce_mean(log_ratio)
print(tape.gradient(neg_elbo, surrogate_posterior.trainable_variables)) # gradients exist

(<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([-161.76694,  399.62607,  284.2843 ,  511.3727 , -201.02748],
      dtype=float32)>, <tf.Tensor: shape=(5,), dtype=float32, numpy=
array([ 85.67609, 698.04956, 525.29767, 713.74994,  65.32199],
      dtype=float32)>)


In [7]:
# Maximize the ELBO for the mean-field surrogate with Adam. With the default
# trace_fn, `path` is the per-step loss (Monte Carlo negative-ELBO estimate).
path = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn = model_joint_log_prob,
    surrogate_posterior = surrogate_posterior,
    optimizer = tf.optimizers.Adam(.2),
    num_steps = 1000,
    sample_size = 500,
    )

# Convergence check: the loss should level off well before the final step.
fig, ax = plt.subplots()
ax.plot(path.numpy())
ax.set(xlabel = 'optimization step', ylabel = 'negative ELBO (MC estimate)',
       title = 'Mean-field VI loss trace')
plt.show()

# exp(q_mu) is the posterior median of alpha under the log-normal surrogate.
# Compare it with alpha_true: the values should end up *roughly* in the right place.
print(tf.exp(q_mu))
print(alpha_true)

tf.Tensor([4.446111   0.6254502  0.75169945 0.4220706  4.878486  ], shape=(5,), dtype=float32)


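**Full-Covariance Gaussian Variational Bayes -- Dependence between columns**

The surrogate posterior is a single multivariate normal with a full covariance, parameterized through its lower-triangular Cholesky factor $L_q$ (so $\Sigma_q = L_q L_q^\top$):
$$
q(\log\alpha) = \text{MVNormal}(\log\alpha \mid \mu_q, \Sigma_q)
$$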

In [8]:
# Full-covariance Gaussian surrogate for log(alpha).
# - Location q_nu starts at zero (alpha = 1). Note: the mean-field fit started at the prior mean.
# - Scale q_Lu is a lower-triangular Cholesky factor built by FillScaleTriL with an
#   Exp diagonal, so it stays a valid Cholesky factor during unconstrained
#   optimization. It starts at the identity.
# The dimension is taken from the prior mean rather than hard-coded.
n_dim = log_alpha_0.shape[0]
# Variable names must be valid TF op/scope names (no spaces or parentheses).
q_nu = tf.Variable(tf.zeros(n_dim, dtype = tf.float32), name = 'q_nu')
cholbijector = tfb.FillScaleTriL(diag_bijector = tfb.Exp())
q_Lu = tfp.util.TransformedVariable(tf.eye(n_dim), bijector = cholbijector)

surrogate_posterior_mvnorm = tfd.MultivariateNormalTriL(loc = q_nu, scale_tril = q_Lu)

# Sanity check: the 100-draw negative-ELBO estimate is differentiable w.r.t. q_nu and q_Lu.
with tf.GradientTape() as g:
    samples = surrogate_posterior_mvnorm.sample(100)
    neg_elbo = -tf.reduce_mean(model_joint_log_prob(samples) - surrogate_posterior_mvnorm.log_prob(samples))
print(g.gradient(neg_elbo, surrogate_posterior_mvnorm.trainable_variables)) # Exists!

(<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([-162.54315,  855.37195,  544.6034 , 1385.1965 , -174.64993],
      dtype=float32)>, <tf.Tensor: shape=(15,), dtype=float32, numpy=
array([ 308.8647  , -122.44166 ,   27.030043,  -40.740246, -138.23465 ,
        177.29811 , 2592.4873  , -977.86127 , -904.6997  ,  440.0187  ,
        170.37288 , 1404.0281  , 1037.0505  ,   64.93281 ,  265.7789  ],
      dtype=float32)>)


In [9]:
# Fit the full-covariance surrogate with the same optimizer settings as the
# mean-field fit, so the two results are directly comparable.
path_mvnorm = tfp.vi.fit_surrogate_posterior(
    surrogate_posterior = surrogate_posterior_mvnorm,
    target_log_prob_fn = model_joint_log_prob,
    sample_size = 500,
    num_steps = 1000,
    optimizer = tf.optimizers.Adam(.2),
    )
# exp of the fitted mean of log(alpha): essentially the same answer as the mean-field fit.
print(tf.exp(q_nu))

tf.Tensor([4.4074626 0.6281958 0.7440691 0.4203112 4.9172263], shape=(5,), dtype=float32)


In [10]:
# Sign pattern of the fitted posterior covariance of log(alpha), Sigma_q = L L^T,
# where L is the lower-triangular Cholesky factor held by q_Lu.
chol_q = q_Lu.numpy()
np.greater(chol_q @ chol_q.T, 0)

array([[ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True]])

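Every entry of the fitted posterior covariance of $\log\alpha$ is positive. This is plausible even though the covariance between the *values* (components) of a Dirichlet draw is negative. The data constrain the Dirichlet mean $\alpha_\ell / \sum_k \alpha_k$, i.e. the *ratios* of the parameters, more tightly than the overall scale $\sum_k \alpha_k$. As a result, the coordinates of $\log\alpha$ tend to move up or down together in the posterior.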