# Mixed Dirichlet


A mixed Dirichlet random variable $Y$ takes values in the probability simplex $\Delta_{K-1}$; an assignment $Y=y$ has probability density

\begin{align}
P_{Y}(y|\alpha, w) &= \sum_{f} \mathrm{Gibbs}(f|w) \times \mathrm{Dirichlet}(y|\alpha \odot f)
\end{align}

where $w \in \mathbb R^K$, $\alpha \in \mathbb R^K_{>0}$, the sum ranges over the non-empty faces $f$ of the simplex, and $\alpha \odot f$ denotes the sub-vector of $\alpha$ whose coordinates are associated with the vertices in $f$.
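
Each $\mathrm{Dirichlet}(\cdot \mid \alpha \odot f)$ puts all of its mass on the relative interior of face $f$: every coordinate in $f$ is strictly positive and every other coordinate is zero. So for a given $y$, at most one term of the sum is non-zero, namely the face whose vertices are exactly the non-zero coordinates of $y$. The sketch below evaluates the density using that observation, in plain `torch`. The function `mixed_dirichlet_log_density` and its `face_log_prob` argument are hypothetical illustrations, not the `probabll` API.

```python
import torch
import torch.distributions as td


def mixed_dirichlet_log_density(y, alpha, face_log_prob):
    """Log-density of a point y on the simplex under a mixed Dirichlet (hypothetical helper).

    `face_log_prob` maps a boolean face mask of shape [K] to log Gibbs(f | w).
    """
    face = y > 0  # the only face whose Dirichlet has y in its relative interior
    log_p = face_log_prob(face)
    if face.sum() > 1:
        # Dirichlet density of the non-zero coordinates, alpha restricted to the face
        log_p = log_p + td.Dirichlet(alpha[face]).log_prob(y[face])
    # a single-vertex face is a point mass, so its Dirichlet factor is 1 (log 1 = 0)
    return log_p


def uniform_faces(face):
    """log of the uniform pmf over the 2^K - 1 non-empty faces."""
    return torch.tensor(1.0 / (2 ** len(face) - 1)).log()


K = 3
alpha = torch.ones(K)
y_edge = torch.tensor([0.3, 0.7, 0.0])    # on the edge {e1, e2}
y_vertex = torch.tensor([0.0, 0.0, 1.0])  # the vertex e3
print(mixed_dirichlet_log_density(y_edge, alpha, uniform_faces).exp())    # ≈ 1/7 (Dirichlet(1, 1) density is 1)
print(mixed_dirichlet_log_density(y_vertex, alpha, uniform_faces).exp())  # ≈ 1/7
```

With $\alpha = \mathbf 1$ and uniform faces, a point in the interior of the triangle gets density $\tfrac{1}{7}\Gamma(3) = \tfrac{2}{7}$, because $\mathrm{Dirichlet}(\mathbf 1_3)$ is constant at $\Gamma(3) = 2$.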

The distribution over non-empty faces has probability mass function:
\begin{align}
\mathrm{Gibbs}(f|w) = \frac{\exp(w^\top \phi(f))}{\sum_{f'} \exp(w^\top \phi(f'))}
\end{align}
where $\phi(f) \in \{-1, +1\}^K$ is such that $\phi_k(f) = 1$ if the vertex $\mathbf e_k$ is in the face $f$, and $\phi_k(f) = -1$ otherwise.
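
For small $K$ the Gibbs distribution can be evaluated by brute force over all $2^K - 1$ non-empty faces. Encode a face as a binary vector $f \in \{0, 1\}^K$, so that $\phi(f) = 2f - 1$. Then $w^\top \phi(f) = 2 w^\top f - \sum_k w_k$, and the last term is shared by every face. Hence $\mathrm{Gibbs}(f \mid w) \propto \exp(2 w^\top f)$, which is the law of independent $\mathrm{Bernoulli}(\sigma(2 w_k))$ vertex indicators conditioned on the face being non-empty. The sketch below checks both forms; the helper names are hypothetical.

```python
import itertools

import torch


def enumerate_faces(K):
    """All non-empty faces of the (K-1)-simplex as {0,1}-vectors, shape [2^K - 1, K]."""
    return torch.tensor([f for f in itertools.product([0.0, 1.0], repeat=K) if any(f)])


def gibbs_log_pmf(w, faces):
    """log Gibbs(f | w) for every row f of `faces`, using phi(f) = 2f - 1 in {-1, +1}^K."""
    energies = (2.0 * faces - 1.0) @ w  # w^T phi(f) for each face
    return energies - torch.logsumexp(energies, dim=0)


w = torch.tensor([0.5, -1.0, 2.0])
faces = enumerate_faces(3)
pmf = gibbs_log_pmf(w, faces).exp()

# Same pmf from independent Bernoulli(sigmoid(2 w_k)) indicators, conditioned on f != 0
p_on = torch.sigmoid(2.0 * w)
joint = (faces * p_on + (1 - faces) * (1 - p_on)).prod(dim=-1)
conditioned = joint / (1 - (1 - p_on).prod())

print(faces.shape)                                  # torch.Size([7, 3])
print(torch.allclose(pmf, conditioned))             # True
print(gibbs_log_pmf(torch.zeros(3), faces).exp())   # w = 0: 1/7 for every face
```

The Bernoulli view suggests drawing a face without enumerating all $2^K - 1$ of them, for example by rejecting the empty draw. Whether `probabll` samples faces this way is not stated here.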


In [None]:
import probabll.distributions as pd

In [None]:
import torch
import torch.distributions as td

In [None]:
import matplotlib.pyplot as plt

In [None]:
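def plot_marginals(samples, bins=100):
    """Plot a normalized histogram of each coordinate of `samples` (a tensor of shape [..., D])."""
    D = samples.shape[-1]
    # squeeze=False keeps `ax` two-dimensional, so indexing also works when D == 1
    fig, ax = plt.subplots(D, 1, figsize=(4, 2*D), sharex=True, squeeze=False)
    for d in range(D):
        _ = ax[d, 0].hist(samples[..., d].flatten().numpy(), bins=bins, density=True)
    return fig, ax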

# Uniform F and Uniform Y|f

In [None]:
p3d = pd.MixedDirichlet(concentration=torch.ones(3), scores=torch.zeros(3))

In [None]:
p3d.sample([10])
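
Sampling follows the mixture directly. First draw a face $f \sim \mathrm{Gibbs}(\cdot \mid w)$. Then draw the coordinates in $f$ from $\mathrm{Dirichlet}(\alpha \odot f)$ and set all other coordinates to exactly zero. With uniform faces over $K=3$, 6 of the 7 faces are proper, so about $6/7$ of the samples should contain at least one exact zero. The following is a minimal ancestral-sampling sketch in plain `torch`. It is not necessarily how `pd.MixedDirichlet.sample` is implemented, and the helper name is hypothetical.

```python
import itertools

import torch
import torch.distributions as td


def sample_mixed_dirichlet(alpha, w, n):
    """Ancestral sampling: draw a face from Gibbs(. | w), then a Dirichlet point on it.

    Hypothetical helper; it enumerates all 2^K - 1 faces, so it only suits small K.
    """
    K = alpha.shape[-1]
    faces = torch.tensor([f for f in itertools.product([0.0, 1.0], repeat=K) if any(f)])
    logits = (2.0 * faces - 1.0) @ w                  # w^T phi(f) for every face
    idx = td.Categorical(logits=logits).sample((n,))  # one face index per sample
    y = torch.zeros(n, K)                             # coordinates outside the face stay 0
    for j in range(faces.shape[0]):
        rows = (idx == j).nonzero(as_tuple=True)[0]   # samples that drew face j
        cols = faces[j].nonzero(as_tuple=True)[0]     # vertices of face j
        if rows.numel() == 0:
            continue
        if cols.numel() == 1:
            y[rows, cols] = 1.0                       # a vertex is a point mass
        else:
            y[rows.unsqueeze(-1), cols] = td.Dirichlet(alpha[cols]).sample((rows.numel(),))
    return y


torch.manual_seed(0)
n = 5000
y = sample_mixed_dirichlet(torch.ones(3), torch.zeros(3), n)
print(y[:3])
print(torch.allclose(y.sum(-1), torch.ones(n)))  # True: every sample lies on the simplex
print((y == 0).any(-1).float().mean())           # close to 6/7 ≈ 0.857
```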

In [None]:
p3d.entropy(), p3d.cross_entropy(p3d), td.kl_divergence(p3d, p3d)
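
The face is a deterministic function of $y$ (its set of non-zero coordinates). Consequently, the entropy of a mixed Dirichlet decomposes as $H(Y) = H(F) + \sum_f \mathrm{Gibbs}(f \mid w)\, H(Y \mid F = f)$. Here $H(Y \mid F = f)$ is the differential entropy of $\mathrm{Dirichlet}(\alpha \odot f)$, and it is zero for a vertex.

For this example the value can be checked by hand. $F$ is uniform over 7 faces, and $\mathrm{Dirichlet}(\mathbf 1_k)$ has constant density $\Gamma(k)$, hence entropy $-\log (k-1)!$. That entropy is $0$ for vertices and edges and $-\log 2$ for the full triangle. So $H(Y) = \log 7 - \tfrac{1}{7}\log 2 \approx 1.847$ nats.

A brute-force sketch in plain `torch` follows. It assumes nats, as in `torch.distributions`, and the same base measure as `probabll` (a point mass per vertex, Lebesgue measure on higher faces); both are assumptions rather than documented facts about the library.

```python
import itertools
import math

import torch
import torch.distributions as td


def mixed_dirichlet_entropy(alpha, w):
    """H(Y) = H(F) + sum_f P(f) H(Y | F=f), by enumerating all non-empty faces (hypothetical helper)."""
    K = alpha.shape[-1]
    faces = torch.tensor([f for f in itertools.product([0.0, 1.0], repeat=K) if any(f)])
    log_pf = torch.log_softmax((2.0 * faces - 1.0) @ w, dim=0)  # log Gibbs(f | w)
    h_faces = -(log_pf.exp() * log_pf).sum()                    # H(F)
    h_cond = torch.stack([
        td.Dirichlet(alpha[f.bool()]).entropy()
        if f.sum() > 1
        else torch.tensor(0.0)  # a vertex is a point mass: zero entropy
        for f in faces
    ])
    return h_faces + (log_pf.exp() * h_cond).sum()


H = mixed_dirichlet_entropy(torch.ones(3), torch.zeros(3))
print(H)                                # ≈ 1.8469
print(math.log(7) - math.log(2) / 7)    # closed form for uniform faces and alpha = 1
```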

In [None]:
_ = plot_marginals(p3d.sample([1000]), bins=100)

In [None]:
_p = p3d.expand([2, 1])
_p.sample().shape
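
In `torch.distributions`, `expand` prepends the new batch shape to the event shape. Assuming `pd.MixedDirichlet` follows the same convention, a single draw from `_p` (batch shape `[2, 1]`, event shape `[3]`) has shape `[2, 1, 3]`, and `entropy()` returns one value per batch element. The convention is shown here with a plain `td.Dirichlet`:

```python
import torch
import torch.distributions as td

base = td.Dirichlet(torch.ones(3))   # batch_shape [], event_shape [3]
expanded = base.expand([2, 1])       # batch_shape [2, 1], event_shape [3]

print(expanded.batch_shape, expanded.event_shape)  # torch.Size([2, 1]) torch.Size([3])
print(expanded.sample().shape)                     # torch.Size([2, 1, 3])
print(expanded.sample([5]).shape)                  # torch.Size([5, 2, 1, 3])
print(expanded.entropy().shape)                    # torch.Size([2, 1])
```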

In [None]:
_p.entropy(), _p.cross_entropy(_p), td.kl_divergence(_p, _p)

In [None]:
_p.faces.cross_entropy(_p.faces).shape

# Max-Ent F and Uniform Y|f

In [None]:
pm3d = pd.MixedDirichlet(concentration=torch.ones(3), pmf_n=pd.MaxEntropyFaces.pmf_n(3, 1))

In [None]:
_ = plot_marginals(pm3d.sample(torch.Size([1000])), bins=100)

In [None]:
pm3d.entropy(), pm3d.cross_entropy(pm3d), td.kl_divergence(pm3d, pm3d)
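
One reading of the max-entropy construction is stated here as an assumption, not as a description of `pd.MaxEntropyFaces` or its arguments. Fix $Y \mid f$. The entropy $H(Y) = \sum_f P(f)\,[-\log P(f) + H(Y \mid F=f)]$ is then maximized by $P(f) \propto \exp H(Y \mid F=f)$, and the maximum equals $H(Y) = \log \sum_f \exp H(Y \mid F=f)$.

With $\alpha = \mathbf 1$, a face with $k$ vertices has $H(Y \mid F=f) = -\log (k-1)!$, so $P(f) \propto 1/(k-1)!$. For $K=3$ the unnormalized masses are $1$ for each vertex and each edge and $1/2$ for the triangle. This gives $H(Y) = \log 6.5 \approx 1.872$ nats, slightly above the $\log 7 - \tfrac{1}{7}\log 2 \approx 1.847$ nats obtained with uniform faces. A plain-`torch` sketch with a hypothetical helper:

```python
import itertools
import math

import torch


def max_entropy_faces(K):
    """Face pmf maximizing H(Y) when Y | f is uniform (alpha = 1) on every face.

    Hypothetical helper: P(f) is proportional to exp H(Dirichlet(1_k)) = 1 / (k - 1)!.
    Returns the faces, their log-probabilities and the per-face entropies H(Y | F=f).
    """
    faces = torch.tensor([f for f in itertools.product([0.0, 1.0], repeat=K) if any(f)])
    k = faces.sum(-1)              # number of vertices of each face
    h_cond = -torch.lgamma(k)      # H(Dirichlet(1_k)) = -log Gamma(k) = -log (k - 1)!
    log_pf = h_cond - torch.logsumexp(h_cond, dim=0)
    return faces, log_pf, h_cond


faces, log_pf, h_cond = max_entropy_faces(3)
pf = log_pf.exp()
H = -(pf * log_pf).sum() + (pf * h_cond).sum()  # H(F) + E_F[H(Y | F)]
size_pmf = [pf[faces.sum(-1) == k].sum().item() for k in (1, 2, 3)]

print(H.item(), math.log(6.5))  # both ≈ 1.8718
print(size_pmf)                 # ≈ [3/6.5, 3/6.5, 0.5/6.5]
```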

In [None]:
_pm = pm3d.expand([2, 1])
_pm.sample().shape

In [None]:
_pm.entropy(), _pm.cross_entropy(_pm), td.kl_divergence(_pm, _pm)

# VI

In [None]:
p = pd.MixedDirichlet(concentration=torch.ones(5), pmf_n=pd.MaxEntropyFaces.pmf_n(5, 1))
q = pd.MixedDirichlet(concentration=torch.ones(5)/10, scores=torch.zeros(5))

In [None]:
p.batch_shape, p.event_shape

In [None]:
p.sample(torch.Size([10]))

In [None]:
f = p.faces.enumerate_support()

In [None]:
p.faces.log_prob(f).exp(), f.sum(-1)

In [None]:
p.cross_entropy(q)

In [None]:
p.Y(f).cross_entropy(q.Y(f))

In [None]:
p.Y(f).entropy()
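
The per-face quantities above combine with the face distributions through a chain rule, which holds because $F$ is a deterministic function of $Y$. The rule is $\mathrm{CE}(p, q) = \mathrm{CE}(p_F, q_F) + \sum_f p_F(f)\, \mathrm{CE}(p_{Y|f}, q_{Y|f})$, where each Dirichlet cross-entropy equals $H(p_{Y|f}) + \mathrm{KL}(p_{Y|f} \,\|\, q_{Y|f})$. Vertex faces are point masses and contribute zero.

The brute-force sketch below uses plain `torch` and hypothetical helper names. To stay self-contained, it gives `p` uniform faces instead of the max-entropy faces used above; `q` matches the cell above.

```python
import itertools

import torch
import torch.distributions as td


def faces_of(K):
    """All non-empty faces of the (K-1)-simplex as {0,1}-vectors."""
    return torch.tensor([f for f in itertools.product([0.0, 1.0], repeat=K) if any(f)])


def dirichlet_cross_entropy(a_p, a_q):
    """CE(Dir(a_p), Dir(a_q)) = H(Dir(a_p)) + KL(Dir(a_p) || Dir(a_q))."""
    p, q = td.Dirichlet(a_p), td.Dirichlet(a_q)
    return p.entropy() + td.kl_divergence(p, q)


def mixed_dirichlet_cross_entropy(alpha_p, w_p, alpha_q, w_q):
    """CE(p, q) = CE(p_F, q_F) + sum_f p_F(f) CE(p_{Y|f}, q_{Y|f})."""
    faces = faces_of(alpha_p.shape[-1])
    phi = 2.0 * faces - 1.0
    log_pf = torch.log_softmax(phi @ w_p, dim=0)
    log_qf = torch.log_softmax(phi @ w_q, dim=0)
    ce_faces = -(log_pf.exp() * log_qf).sum()
    ce_cond = torch.stack([
        dirichlet_cross_entropy(alpha_p[f.bool()], alpha_q[f.bool()])
        if f.sum() > 1
        else torch.tensor(0.0)  # a vertex is a point mass: its CE term is 0
        for f in faces
    ])
    return ce_faces + (log_pf.exp() * ce_cond).sum()


K = 5
alpha_p, w_p = torch.ones(K), torch.zeros(K)       # uniform faces (instead of max-ent faces)
alpha_q, w_q = torch.ones(K) / 10, torch.zeros(K)
ce_pq = mixed_dirichlet_cross_entropy(alpha_p, w_p, alpha_q, w_q)
ce_pp = mixed_dirichlet_cross_entropy(alpha_p, w_p, alpha_p, w_p)  # equals H(p) ≈ 2.819
print(ce_pq, ce_pp)
```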

In [None]:
td.kl_divergence(p, q)

In [None]:
td.kl_divergence(p.faces, q.faces)
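
The KL divergence obeys the same chain rule over faces: $\mathrm{KL}(p \,\|\, q) = \mathrm{KL}(p_F \,\|\, q_F) + \sum_f p_F(f)\, \mathrm{KL}(p_{Y|f} \,\|\, q_{Y|f})$. The last cell therefore isolates the first term, and the total KL can be no smaller than it. A plain-`torch` sketch follows; the helper name is hypothetical, and the random face scores for `p` are only illustrative.

```python
import itertools

import torch
import torch.distributions as td


def mixed_dirichlet_kl(alpha_p, w_p, alpha_q, w_q):
    """Return (KL(p_F || q_F), KL(p || q)) via the chain rule over non-empty faces."""
    K = alpha_p.shape[-1]
    faces = torch.tensor([f for f in itertools.product([0.0, 1.0], repeat=K) if any(f)])
    phi = 2.0 * faces - 1.0
    log_pf = torch.log_softmax(phi @ w_p, dim=0)
    log_qf = torch.log_softmax(phi @ w_q, dim=0)
    kl_faces = (log_pf.exp() * (log_pf - log_qf)).sum()
    kl_cond = torch.stack([
        td.kl_divergence(td.Dirichlet(alpha_p[f.bool()]), td.Dirichlet(alpha_q[f.bool()]))
        if f.sum() > 1
        else torch.tensor(0.0)  # two point masses at the same vertex: KL is 0
        for f in faces
    ])
    return kl_faces, kl_faces + (log_pf.exp() * kl_cond).sum()


torch.manual_seed(0)
K = 5
w_p = torch.randn(K)
kl_f, kl = mixed_dirichlet_kl(torch.ones(K), w_p, torch.ones(K) / 10, torch.zeros(K))
print(kl_f, kl)  # kl >= kl_f >= 0
```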