**Goal: Variational Inference on parameters of Dirichlet distribution**

Model:
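$$
\begin{aligned}
y_i &\sim \mathcal{PG}_p(y_i\mid\alpha_i)\\
\alpha_i &\sim G\\
G &\sim \mathcal{PY}(\eta, \delta, G_0)\\
\end{aligned}
~\hspace{1cm}~
\begin{aligned}
G_0 &= \prod_{\ell = 1}^D\mathcal{G}(\alpha_{\ell}\mid\xi_{\ell},\tau_{\ell})\\
\xi_{\ell} &\sim \mathcal{G}(\xi_{\ell}\mid a, b)\\
\tau_{\ell} &\sim \mathcal{G}(\tau_{\ell}\mid c,d)\\
\end{aligned}
$$

Here $\delta$ is the Pitman–Yor discount (`dis` in the code), $D$ is the number of columns, and $\mathcal{G}(\cdot\mid\text{shape},\text{rate})$ is a gamma density.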

In [3]:
import silence_tensorflow.auto
import json
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
from tensorflow_probability import bijectors as tfb
from numpy.random import gamma

from tfprojgamma import ProjectedGamma
# Fix both global RNGs -- NumPy's (used by `gamma` below) and TensorFlow's
# (used by `.sample` and `tf.random.normal`) -- so a fresh Run All reproduces
# every random draw in this notebook.
SEED = 1
tf.random.set_seed(SEED)
np.random.seed(SEED)

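*Generate a synthetic sample: 1000 draws from a 3-component mixture of projected gammas with known concentrations `alpha_true` and weights `pi_true`*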

In [4]:
# Ground truth: 3 mixture components, each a 5-column projected gamma whose
# concentrations are drawn from Gamma(shape=1.5, scale=1).
alpha_true = gamma(size = (3,5), shape = 1.5)
pi_true = (0.3, 0.5, 0.2)

MixProjectedGamma = tfd.MixtureSameFamily(
    mixture_distribution = tfd.Categorical(
        # A bare tuple of Python floats becomes a float32 tensor. The
        # components are built from float64 `alpha_true`, so without this cast
        # MixProjectedGamma.log_prob would add float32 mixture log-weights to
        # (presumably float64) component log-probs and fail on the dtype mismatch.
        probs = np.asarray(pi_true, dtype = alpha_true.dtype),
        ),
    components_distribution = ProjectedGamma(
        concentration = alpha_true,
        ),
    )
print(alpha_true)
print(pi_true)
# 1000 draws from the mixture, stored as float32.
Yp = tf.cast(MixProjectedGamma.sample(1000), tf.float32)

[[3.94761907 0.62279234 0.68411215 0.34912511 4.2482095 ]
 [0.52249041 3.56568716 0.05637186 2.86972854 0.33627595]
 [0.99015762 0.45215243 0.33593299 2.89134277 2.43528465]]
(0.3, 0.5, 0.2)


*Specifications*

In [5]:
N, D = Yp.shape  # N = number of observations, D = number of columns
J = 20           # truncation level: number of mixture components (clusters)

*Prior Parameters*

In [6]:
# Gamma(shape, rate) hyperparameters of the per-column prior xi ~ Gamma(a, b).
# Diffuse (mean a / b = 1, variance a / b**2 = 2); noted as inherently unstable.
a, b = 0.5, 0.5
# Gamma(shape, rate) hyperparameters of the per-column prior tau ~ Gamma(c, d).
# Prior mean c / d = 1, so tau is biased towards 1.
c, d = 2.0, 2.0
# Pitman-Yor strength (eta) and discount (dis) parameters.
eta = 0.1
dis = 0.1

*Define the Joint Distribution*

In [9]:
def _stickbreak(nu):
    """Convert truncated stick-breaking fractions into mixture weights.

    Given ``nu`` whose last axis has length J - 1, returns weights whose last
    axis has length J:
        pi_j = nu_j * prod_{k<j} (1 - nu_k)   for j < J,
        pi_J = prod_{k<J} (1 - nu_k),
    so the weights sum to one along the last axis. Leading (batch) axes are
    carried through unchanged.
    """
    nu = tf.convert_to_tensor(nu)
    # A length-1 slab of ones with nu's leading shape; built from tf.shape so it
    # also works when nu has zero length (J == 1).
    ones = tf.ones(tf.concat([tf.shape(nu)[:-1], [1]], axis = 0), dtype = nu.dtype)
    # remaining[..., j] = prod_{k<=j} (1 - nu_k): the stick left after break j.
    remaining = tf.math.cumprod(1 - nu, axis = -1)
    return tf.concat([nu, ones], axis = -1) * tf.concat([ones, remaining], axis = -1)


def create_model(N, J, D, eta, discount, dtype = np.float64):
    """Build the truncated Pitman-Yor mixture of projected gammas as a joint distribution.

    Named components:
      xi    : (D,)     xi_l ~ Gamma(a, b), shape hyperparameters of G_0.
      tau   : (D,)     tau_l ~ Gamma(c, d), rate hyperparameters of G_0.
      nu    : (J - 1,) nu_j ~ Beta(1 - discount, eta + j * discount), stick-breaking fractions.
      alpha : (J, D)   alpha_jl ~ Gamma(xi_l, tau_l), cluster concentration parameters.
      obs   : (N, ...) N i.i.d. draws from the J-component projected-gamma mixture
                       with weights stickbreak(nu). Presumably (N, D) if ProjectedGamma
                       has event shape (D,) -- TODO confirm.

    Args:
      N: number of observations.
      J: truncation level (number of mixture components).
      D: number of columns per observation.
      eta: Pitman-Yor strength parameter.
      discount: Pitman-Yor discount parameter.
      dtype: NumPy float dtype used for every prior parameter.

    Returns:
      A ``tfd.JointDistributionNamed`` over {xi, tau, nu, alpha, obs}.

    Note: the hyperparameters ``a``, ``b``, ``c``, ``d`` are read from notebook
    globals at call time.
    """
    model = tfd.JointDistributionNamed(dict(
        xi = tfd.Independent(
            tfd.Gamma(
                concentration = np.full(D, a, dtype),
                rate = np.full(D, b, dtype),
                ),
            reinterpreted_batch_ndims = 1,
            ),
        tau = tfd.Independent(
            tfd.Gamma(
                concentration = np.full(D, c, dtype),
                rate = np.full(D, d, dtype),
                ),
            reinterpreted_batch_ndims = 1,
            ),
        nu = tfd.Independent(
            tfd.Beta(
                np.ones(J - 1, dtype) - discount,
                # arange in `dtype` so the rate matches the concentration's dtype.
                eta + np.arange(1, J, dtype = dtype) * discount,
                ),
            reinterpreted_batch_ndims = 1,
            ),
        alpha = lambda xi, tau: tfd.Independent(
            tfd.Gamma(
                # Insert a cluster axis so xi / tau of shape (..., D) broadcast to
                # (..., J, D) even when the joint is sampled with a sample shape.
                concentration = np.ones((J, D), dtype) * xi[..., tf.newaxis, :],
                rate = np.ones((J, D), dtype) * tau[..., tf.newaxis, :],
                ),
            reinterpreted_batch_ndims = 2,
            ),
        obs = lambda alpha, nu: tfd.Sample(
            tfd.MixtureSameFamily(
                mixture_distribution = tfd.Categorical(probs = _stickbreak(nu)),
                # NOTE(review): the second positional argument is assumed to be a
                # unit rate -- confirm against tfprojgamma.ProjectedGamma.
                components_distribution = ProjectedGamma(alpha, np.ones((J, D), dtype)),
                ),
            # Each mixture draw is already one D-column observation; take N of them.
            sample_shape = N,
            ),
        ))
    return model

**Variational Parameters**

In [None]:
# Variational (surrogate) parameters: one location/scale pair per latent
# variable of create_model. nu has J - 1 stick-breaking fractions, alpha is
# (J, D), and xi and tau are (D,).
# The scales are TransformedVariables through tfb.Exp(), like q_scale in the
# mean-field cell. The optimizer therefore works on an unconstrained log-scale,
# and the scales can never become negative; they start at 1.
# NOTE(review): nu lies in (0, 1) and alpha, xi, tau are positive, so these
# locations presumably parameterize Normals on a logit / log scale -- confirm
# when the surrogate posterior is built.
q_nu_mu    = tf.Variable(tf.random.normal([J - 1], dtype = np.float64), name = 'q_nu_mu')
q_nu_sd    = tfp.util.TransformedVariable(tf.ones([J - 1], dtype = np.float64), tfb.Exp(), name = 'q_nu_sd')
q_alpha_mu = tf.Variable(tf.random.normal([J, D], dtype = np.float64), name = 'q_alpha_mu')
q_alpha_sd = tfp.util.TransformedVariable(tf.ones([J, D], dtype = np.float64), tfb.Exp(), name = 'q_alpha_sd')
q_xi_mu    = tf.Variable(tf.random.normal([D], dtype = np.float64), name = 'q_xi_mu')
q_xi_sd    = tfp.util.TransformedVariable(tf.ones([D], dtype = np.float64), tfb.Exp(), name = 'q_xi_sd')
q_tau_mu   = tf.Variable(tf.random.normal([D], dtype = np.float64), name = 'q_tau_mu')
q_tau_sd   = tfp.util.TransformedVariable(tf.ones([D], dtype = np.float64), tfb.Exp(), name = 'q_tau_sd')

**Verifying the structure of the joint model**

In [None]:
# Instantiate the joint model from the specifications and priors above.
# Displaying it lets us check each component's batch and event shapes.
model_joint = create_model(N, J, D, eta, dis)
model_joint

**Mean Field Variational Bayes -- Independence between columns**
$$
q(\log\alpha) = \prod_{\ell = 1}^D\mathcal{N}(\log\alpha_{\ell} \mid \mu_{q\ell}, \sigma_{q\ell})
$$

In [None]:
# Mean-field surrogate: independent Normals on log(alpha), one per column.
# NOTE(review): `log_alpha_0` (initial location) and `model_joint_log_prob`
# (target log density over 5-dimensional log-alpha samples) are not defined in
# any visible cell; this cell only runs if they were created elsewhere.
q_mu = tf.Variable(log_alpha_0, dtype = tf.float32)
q_scale = tfp.util.TransformedVariable(
    np.ones(5),
    tfb.Exp(),          # optimizer works on log-scale; scales stay positive
    dtype = tf.float32,
    )

surrogate_posterior = tfd.MultivariateNormalDiag(
    loc = q_mu,
    scale_diag = q_scale,
    name = 'surrogate 1',
    )

# Sanity check: the Monte Carlo negative ELBO is differentiable with respect to
# the surrogate's trainable variables.
with tf.GradientTape() as g:
    samples = surrogate_posterior.sample(100)
    neg_elbo = -tf.reduce_mean(
        model_joint_log_prob(samples) - surrogate_posterior.log_prob(samples)
        )
print(g.gradient(neg_elbo, surrogate_posterior.trainable_variables)) # exists!

In [None]:
# Fit the mean-field surrogate: Adam (learning rate 0.2), 1000 steps, and a
# 500-sample Monte Carlo ELBO estimate per step.
path = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn = model_joint_log_prob,
    surrogate_posterior = surrogate_posterior,
    optimizer = tf.optimizers.Adam(learning_rate = 0.2),
    num_steps = 1000,
    sample_size = 500,
    )

# exp of the fitted location (the lognormal median of alpha); these end up
# *roughly* in the right place.
print(tf.exp(q_mu))
# Fitted variances of log(alpha).
print(q_scale**2)

**Gaussian Variational Bayes -- Dependence Between Columns**
$$
q(\log\alpha) = \mathcal{N}_D(\log\alpha \mid \mu_q, \Sigma_q), \qquad \Sigma_q = L_q L_q^\top
$$

In [None]:
# Full-covariance surrogate on log(alpha). The location q_nu is unrelated to
# the stick-breaking nu of the model. The lower-triangular scale q_Lu gives
# Sigma_q = q_Lu @ q_Lu.T.
# FillScaleTriL with an Exp diagonal keeps q_Lu a valid Cholesky factor while
# the optimizer works on unconstrained values; q_Lu starts at the identity.
q_nu = tf.Variable(
    tf.zeros(5, dtype = tf.float32),
    name = 'Mu Surrogate (mean of log alpha)',
    )
cholbijector = tfb.FillScaleTriL(diag_bijector = tfb.Exp())
q_Lu = tfp.util.TransformedVariable(tf.eye(5), bijector = cholbijector)

surrogate_posterior_mvnorm = tfd.MultivariateNormalTriL(
    loc = q_nu,
    scale_tril = q_Lu,
    )

# Sanity check: gradients of the Monte Carlo negative ELBO reach q_nu and q_Lu.
# NOTE(review): relies on `model_joint_log_prob`, which no visible cell defines.
with tf.GradientTape() as g:
    samples = surrogate_posterior_mvnorm.sample(100)
    neg_elbo = -tf.reduce_mean(
        model_joint_log_prob(samples) - surrogate_posterior_mvnorm.log_prob(samples)
        )
print(g.gradient(neg_elbo, surrogate_posterior_mvnorm.trainable_variables)) # Exists!

In [None]:
# Fit the full-covariance surrogate with the same settings as the mean-field
# fit: Adam (learning rate 0.2), 1000 steps, 500 ELBO samples per step.
path_mvnorm = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn = model_joint_log_prob,
    surrogate_posterior = surrogate_posterior_mvnorm,
    optimizer = tf.optimizers.Adam(learning_rate = 0.2),
    num_steps = 1000,
    sample_size = 500,
    )
# exp of the fitted location (lognormal median of alpha); this gives the same
# basic response as the mean-field fit.
print(tf.exp(q_nu))

In [None]:
# Sign pattern of the fitted covariance Sigma_q = L L^T (True where positive).
np.matmul(q_Lu.numpy(), q_Lu.numpy().T) > 0

In [None]:
# Fitted covariance Sigma_q = L L^T of log(alpha) under the full-covariance surrogate.
q_Lu.numpy().dot(q_Lu.numpy().T)

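It makes some sense that the (variational) posterior covariance between the Dirichlet parameters $\log\alpha_\ell$ is positive, even though the covariance between the *values* of a Dirichlet draw is negative. A likely reason is that the data constrain the ratios $\alpha_\ell / \sum_k \alpha_k$ (the Dirichlet mean) much more tightly than the total concentration $\sum_k \alpha_k$. The $\log\alpha_\ell$ therefore tend to move together along that poorly identified overall-scale direction.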