In [ ]:
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from itertools import permutations

# Dirichlet Distribution and Dirichlet Process

## Dirichlet Distribution
Both the DD and the DP can be considered distributions of distributions. The DD can be thought of as a bag of dice made in the dark ages.
Each die has its own production errors and therefore its own probability of landing on each face. Every draw from the bag gives us one die, which in turn represents a distribution over 6 discrete values.
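
The bag-of-dice picture can be sketched in a few lines. This is only an illustration: the concentration parameters (5 for every face), the seed and the number of rolls are assumed values, not part of the original text.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed concentration parameters: one per face of the die.
# Equal values mean no face is favoured on average.
alpha = np.full(6, 5.0)

# Each row is one die drawn from the bag: a probability vector over 6 faces.
dice = rng.dirichlet(alpha, size=3)

for die in dice:
    # Roll this particular die 10 times (faces numbered 1..6).
    rolls = rng.choice(np.arange(1, 7), size=10, p=die)
    print(die.round(3), rolls)
```

Every printed die is slightly unfair in its own way, while each row of probabilities still sums to 1.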

## Dirichlet Process
With the DD we define up front how many discrete outcomes each sampled distribution has. In the example above, that is 6. The Dirichlet Process is a method that allows us to sample a distribution over an infinite number of discrete values. Note that the values are still discrete!

The DP is often described online with analogies such as the Chinese Restaurant Process and the Stick Breaking Process. These analogies explain how one could sample from a DP.
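
As a rough sketch of the Chinese Restaurant Process: customer $n+1$ joins an existing table with probability proportional to the number of people already sitting there, and opens a new table with probability proportional to $\alpha$. The function name `chinese_restaurant_process`, the seed and the number of customers below are illustrative choices, not taken from the text.

```python
import numpy as np

def chinese_restaurant_process(n_customers, alpha, rng):
    """Seat customers one by one; return table assignments and table sizes."""
    counts = []       # counts[t] = number of customers at table t
    assignments = []
    for n in range(n_customers):
        # Existing table t: probability counts[t] / (n + alpha).
        # New table: probability alpha / (n + alpha).
        probs = np.array(counts + [alpha], dtype=float) / (n + alpha)
        table = int(rng.choice(len(probs), p=probs))
        if table == len(counts):
            counts.append(1)      # open a new table
        else:
            counts[table] += 1    # join an existing table
        assignments.append(table)
    return np.array(assignments), np.array(counts)

rng = np.random.default_rng(0)
for alpha in [0.1, 1, 10]:
    _, counts = chinese_restaurant_process(200, alpha, rng)
    largest = np.sort(counts)[::-1][:5]
    print(f'alpha = {alpha}: {len(counts)} tables, largest sizes {largest}')
```

A small $\alpha$ tends to put almost everyone at a few tables, while a large $\alpha$ spreads the customers over many tables.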

# Dirichlet distribution
The Dirichlet distribution is a probability distribution that returns a vector $\theta$ of $k$ probabilities. This vector can be seen as a discrete distribution over $k$ values:

$ \sum_{i=1}^{k}{\theta_i} = 1 $

The Dirichlet distribution is thus a distribution of distributions: every sample is itself a probability vector.

$ \theta \sim Dir(\alpha)$

The probability density function is 

$f(x) = \frac{1}{B(\alpha)}\prod_{i=1}^{k}{x_i^{\alpha_i-1}}$

where 

$B(\alpha) = \frac{\prod_{i=1}^{k}{\Gamma(\alpha_i)}}{\Gamma(\sum_{i=1}^{k}{\alpha_i})}$

$\alpha = (\alpha_1,\dots, \alpha_k)$

$\alpha_0 = \sum_{i=1}^{k}{\alpha_i}$

* $\alpha$ = the concentration parameters of the Dirichlet distribution.
* $k$ = the dimension of the probability space of the samples of the distribution.
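
To make the density concrete, the formula can be evaluated by hand and compared with `scipy.stats.dirichlet`. This is a minimal sketch; the helper name `dirichlet_pdf` and the evaluation point `x` are illustrative choices. It works in log space with `gammaln` to avoid overflow in the Gamma functions.

```python
import numpy as np
from scipy import stats
from scipy.special import gammaln

def dirichlet_pdf(x, alpha):
    """Evaluate f(x) = 1/B(alpha) * prod_i x_i^(alpha_i - 1)."""
    x = np.asarray(x, dtype=float)
    alpha = np.asarray(alpha, dtype=float)
    # log B(alpha) = sum_i log Gamma(alpha_i) - log Gamma(sum_i alpha_i)
    log_beta = gammaln(alpha).sum() - gammaln(alpha.sum())
    # log f(x) = sum_i (alpha_i - 1) * log x_i - log B(alpha)
    return np.exp(((alpha - 1) * np.log(x)).sum() - log_beta)

alpha = (1, 3, 4)
x = (0.2, 0.3, 0.5)   # must lie on the simplex: non-negative, sums to 1

print('by hand:', dirichlet_pdf(x, alpha))
print('scipy  :', stats.dirichlet(alpha).pdf(x))
```

Both lines should print the same value up to floating-point error.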

The expected value of a single component $\theta_i$ is

$\mathbb{E} \theta_i = \frac{\alpha_i}{\alpha_0}$

In [ ]:
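for alpha in [(1, 3, 4), (1, 1, 1), (10, 0.2, 0.2), (0.1, 0.1, 0.1)]:
    d = stats.dirichlet(alpha)
    # Draw 100 probability vectors; each row sums to 1.
    theta = d.rvs(100)
    fig = plt.figure()
    # plt.gca(projection='3d') is no longer accepted by recent Matplotlib releases.
    ax = fig.add_subplot(projection='3d')
    ax.set_title(f'alpha = {alpha}')
    ax.scatter(theta[:, 0], theta[:, 1], theta[:, 2])
    ax.view_init(azim=30)
    plt.show()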

Other properties of a Dirichlet distribution.

$ G \sim Dir(\alpha G_0)$ 

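where $G_0$ is a discrete probability distribution and $\alpha$ is here a positive scalar that scales it.

Scaling $G_0$ leaves the mean of the samples drawn from $Dir(\alpha G_0)$ unchanged but changes their variance: the larger $\alpha$, the more tightly the samples concentrate around $G_0$.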

$\mathbb{E} G = G_0$
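
The effect of the scale on the spread can be written in closed form. For the parameter vector $\alpha G_0$, with $G_0$ summing to 1, the standard Dirichlet variance reduces to $Var[G_i] = \frac{G_{0,i}(1 - G_{0,i})}{\alpha + 1}$. The sketch below checks this against the variance reported by SciPy; the helper name `dirichlet_sd` is an illustrative choice.

```python
import numpy as np
from scipy import stats

def dirichlet_sd(G0, scale):
    """Closed-form standard deviation of each component of Dir(scale * G0)."""
    return np.sqrt(G0 * (1 - G0) / (scale + 1))

G0 = np.array([5, 5, 8], dtype=float)
G0 /= G0.sum()   # a probability distribution should sum to one

for s in [0.1, 1, 10, 100, 1000]:
    closed_form = dirichlet_sd(G0, s)
    from_scipy = np.sqrt(stats.dirichlet(s * G0).var())
    print(s, closed_form.round(3), from_scipy.round(3))
```

The standard deviation shrinks roughly like $1/\sqrt{\alpha}$ for large $\alpha$, while the mean stays at $G_0$.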

In [ ]:
# Set probability distribution G0
G0 = np.array([5, 5, 8], dtype=np.float32)

# A probability distribution should sum to one.
G0 /= G0.sum()

print('G0 = ', G0)

scale = [.1, 1, 10, 100, 1000]

for s in scale:
    print('\nscale:', s)
    theta = stats.dirichlet(alpha=s * G0).rvs(10000)
    print('elementwise mean: {}'.format(theta.mean(axis=0).round(3)))
    print('elementwise sd: {}'.format(theta.std(axis=0).round(3)))

# Dirichlet process

$H \sim DP(\alpha H_0) $

A sample $H$ (itself a probability distribution) is created from a Dirichlet process by drawing an infinite number of samples $\theta_k$ from $H_0$ and giving each of them a weight $\pi_k$:

$H = \sum_{k=1}^{\infty}{\pi_k \cdot \delta(\theta_k)}$

where
* $\pi_k$ are random weights that sum to 1.
* $\delta(\theta_k)$ is a Dirac delta, i.e. a point mass located at $\theta_k$.
* $\theta_k$ are samples of $H_0$

The weights follow from the stick-breaking construction:

$\pi_k = \pi'_k \cdot \prod_{i=1}^{k-1}{(1 - \pi'_i)}$

where $\pi'_k \sim Beta(1, \alpha)$

$H$ is a discrete distribution that takes the value $\theta_k$ with probability $\pi_k$.
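
In practice the infinite sum has to be truncated. A useful identity is that the first $K$ weights sum to $1 - \prod_{i=1}^{K}(1 - \pi'_i)$, so the product is exactly the probability mass that is still missing. The sketch below checks this numerically; `alpha = 10`, `K = 50` and the seed are assumed values for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = 10.0
K = 50   # assumed truncation level

# Stick-breaking proportions pi'_k ~ Beta(1, alpha).
beta = rng.beta(1, alpha, size=K)

# remaining[k] = length of stick left after the first k + 1 breaks.
remaining = np.cumprod(1 - beta)

# pi_k = pi'_k * (stick left before break k); the full stick has length 1.
pi = beta * np.concatenate(([1.0], remaining[:-1]))

print('mass captured by first K weights:', pi.sum())
print('mass left on the stick          :', remaining[-1])
```

The two numbers add up to 1. On average the leftover mass is $(\alpha / (1 + \alpha))^K$, so a larger $\alpha$ needs a larger truncation level.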

In [ ]:
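def dirichlet_process(h_0, alpha, n):
    """
    Truncated stick-breaking sample from a Dirichlet process.
    :param h_0: (callable) Sampler of the base distribution, called as h_0(size=n).
    :param alpha: (flt) Concentration parameter.
    :param n: (int) Truncate value.
    :return: (weights pi, atoms theta)
    """
    # Stick-breaking proportions pi'_k ~ Beta(1, alpha).
    pi = stats.beta(1, alpha).rvs(size=n)
    # pi_k = pi'_k * prod_{i<k} (1 - pi'_i); the right-hand side is evaluated before the assignment.
    pi[1:] = pi[1:] * (1 - pi[:-1]).cumprod()
    theta = h_0(size=n)
    # Keep the atoms up to and including the one where the cumulative weight is closest to 1.
    i = np.argmin(np.abs(np.cumsum(pi) - 1))
    return pi[:i + 1], theta[:i + 1]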
    
    
def plot_normal_dp_approximation(alpha):
    
    pis, thetas = dirichlet_process(stats.norm.rvs, alpha, 5000)
    x = np.linspace(-4, 4, 100)
    
    plt.figure(figsize=(14, 4))
    # One figure-level title; calling plt.title() before plt.subplot() would create an extra empty axes.
    plt.suptitle('Dirichlet Process Sample with N(0,1), alpha = {}'.format(alpha))
    plt.subplot(121)
    plt.vlines(thetas, 0, pis)
    plt.plot(x, stats.norm.pdf(x))
    
    plt.subplot(122)
    pis = pis * (stats.norm.pdf(0) / pis.max())
    plt.vlines(thetas, 0, pis)
    plt.ylim(0, 1)
    plt.plot(x, stats.norm.pdf(x))


for alpha in [.1, 1, 10, 1000]:
    plot_normal_dp_approximation(alpha)
        