# Dirichlet Distribution (`dirichlet`) — Modeling Random Probability Vectors

The **Dirichlet distribution** is the canonical distribution over **probability vectors**:
random vectors \(X = (X_1,\dots,X_K)\) with \(X_i \ge 0\) and \(\sum_i X_i = 1\).

It appears everywhere you want to express uncertainty about categorical probabilities:
Bayesian smoothing of histograms, mixture-model weights, topic proportions, and more.

**Goals**
- Understand what the Dirichlet models (the probability simplex).
- Work with the PDF, moments, and key properties.
- Interpret parameters and see how the shape changes.
- Sample from a Dirichlet using a **NumPy-only** algorithm.
- Use `scipy.stats.dirichlet` for evaluation/simulation and fit parameters via MLE.

**Prerequisites**
- Gamma/Beta functions and basic multivariate calculus.
- Familiarity with Bayesian updating for categorical/multinomial data.
- NumPy + basic plotting.


## Notebook roadmap

1. Title & classification
2. Intuition & motivation
3. Formal definition (PDF + CDF notes)
4. Moments & properties
5. Parameter interpretation + shape changes
6. Derivations (expectation, variance, likelihood)
7. Sampling & simulation (NumPy-only)
8. Visualization (PDF, CDF, Monte Carlo)
9. SciPy integration
10. Statistical use cases
11. Pitfalls
12. Summary


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.tri as mtri

from scipy import stats
from scipy.special import gammaln, psi
from scipy.optimize import minimize

np.set_printoptions(precision=4, suppress=True)

# Reproducibility
rng = np.random.default_rng(0)

# Quick/slow toggle (mirrors patterns used elsewhere in this repo)
FAST_RUN = True
N_SAMPLES = 30_000 if FAST_RUN else 300_000
GRID_N = 35 if FAST_RUN else 90

plt.rcParams.update({
    "figure.figsize": (9, 4.5),
    "axes.grid": True,
})


In [None]:
def validate_alpha(alpha: np.ndarray) -> np.ndarray:
    '''Validate Dirichlet parameters (all strictly positive).'''
    alpha = np.asarray(alpha, dtype=float)
    if alpha.ndim != 1:
        raise ValueError(f"alpha must be 1D, got shape={alpha.shape}")
    if alpha.size < 2:
        raise ValueError("Dirichlet requires K>=2 parameters")
    if np.any(alpha <= 0):
        raise ValueError("All alpha_i must be > 0")
    return alpha


def validate_simplex(X: np.ndarray, *, atol: float = 1e-10, allow_zeros: bool = True) -> np.ndarray:
    '''Validate points on the probability simplex (rows sum to 1).'''
    X = np.asarray(X, dtype=float)
    if X.ndim == 1:
        X = X[None, :]
    if X.ndim != 2:
        raise ValueError(f"X must be 1D or 2D, got shape={X.shape}")
    if np.any(X < 0):
        raise ValueError("Simplex components must be >= 0")
    if not allow_zeros and np.any(X <= 0):
        raise ValueError("All simplex components must be > 0 for log-likelihood computations")
    row_sums = X.sum(axis=1)
    if not np.allclose(row_sums, 1.0, atol=atol):
        raise ValueError("Each row of X must sum to 1")
    return X


def dirichlet_logpdf_numpy(X: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    '''Dirichlet log-PDF implemented with NumPy + SciPy special functions.

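    Notes:
    - Supports X as shape (K,) or (N, K).
    - Works on the boundary too (x_i=0): a zero component contributes +inf if alpha_i < 1,
      -inf if alpha_i > 1, and 0 if alpha_i == 1 (0 * log 0 is treated as 0).
    '''
    alpha = validate_alpha(alpha)
    X = validate_simplex(X, allow_zeros=True)

    log_norm = gammaln(alpha.sum()) - gammaln(alpha).sum()
    with np.errstate(divide="ignore", invalid="ignore"):
        log_x = np.log(X)
        # Mask alpha_i == 1: otherwise (alpha_i - 1) * log(0) = 0 * -inf = NaN on the boundary
        terms = np.where(alpha == 1.0, 0.0, (alpha - 1.0) * log_x)
    return log_norm + terms.sum(axis=1)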


def sample_dirichlet_numpy(alpha: np.ndarray, *, size: int, rng: np.random.Generator) -> np.ndarray:
    '''NumPy-only sampler via Gamma normalization.'''
    alpha = validate_alpha(alpha)
    y = rng.gamma(shape=alpha, scale=1.0, size=(size, alpha.size))
    return y / y.sum(axis=1, keepdims=True)


def dirichlet_mean(alpha: np.ndarray) -> np.ndarray:
    alpha = validate_alpha(alpha)
    return alpha / alpha.sum()


def dirichlet_cov(alpha: np.ndarray) -> np.ndarray:
    alpha = validate_alpha(alpha)
    a0 = alpha.sum()
    cov = -np.outer(alpha, alpha) / (a0**2 * (a0 + 1.0))
    np.fill_diagonal(cov, alpha * (a0 - alpha) / (a0**2 * (a0 + 1.0)))
    return cov


def beta_skewness(a: float, b: float) -> float:
    # Standardized third central moment of Beta(a, b)
    return 2 * (b - a) * np.sqrt(a + b + 1) / ((a + b + 2) * np.sqrt(a * b))


def beta_excess_kurtosis(a: float, b: float) -> float:
    # Excess kurtosis (kurtosis minus 3) of Beta(a, b)
    num = 6 * ((a - b) ** 2 * (a + b + 1) - a * b * (a + b + 2))
    den = a * b * (a + b + 2) * (a + b + 3)
    return num / den


def dirichlet_entropy(alpha: np.ndarray) -> float:
    '''Differential entropy H(X) in nats.'''
    alpha = validate_alpha(alpha)
    a0 = alpha.sum()
    k = alpha.size
    log_B = gammaln(alpha).sum() - gammaln(a0)
    return log_B + (a0 - k) * psi(a0) - ((alpha - 1.0) * psi(alpha)).sum()


SQRT3 = float(np.sqrt(3))


def simplex3_grid(n: int, *, min_component: float = 0.0) -> np.ndarray:
    '''Grid of points on the 2-simplex for K=3.

    Returns an array of shape (M, 3) with entries in [0,1] summing to 1.
    '''
    pts = []
    for i in range(n + 1):
        for j in range(n + 1 - i):
            k = n - i - j
            x = np.array([i, j, k], dtype=float) / n
            if x.min() < min_component:
                continue
            pts.append(x)
    return np.vstack(pts)


def simplex3_to_xy(X: np.ndarray) -> np.ndarray:
    '''Map (x1,x2,x3) on the simplex to 2D coordinates inside an equilateral triangle.'''
    X = np.asarray(X, dtype=float)
    if X.ndim == 1:
        X = X[None, :]
    # vertices: e1=(0,0), e2=(1,0), e3=(1/2, sqrt(3)/2)
    x = X[:, 1] + 0.5 * X[:, 2]
    y = (SQRT3 / 2.0) * X[:, 2]
    return np.column_stack([x, y])


def plot_simplex3_outline(ax: plt.Axes) -> None:
    tri = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, SQRT3 / 2.0], [0.0, 0.0]])
    ax.plot(tri[:, 0], tri[:, 1], color="black", lw=1.2)
    ax.set_aspect("equal")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)


def plot_dirichlet_simplex3(alpha: np.ndarray, *, ax: plt.Axes, grid_n: int, min_component: float) -> plt.cm.ScalarMappable:
    pts = simplex3_grid(grid_n, min_component=min_component)
    xy = simplex3_to_xy(pts)
    tri = mtri.Triangulation(xy[:, 0], xy[:, 1])
    logpdf = dirichlet_logpdf_numpy(pts, alpha)
    contour = ax.tricontourf(tri, logpdf, levels=35, cmap="viridis")
    plot_simplex3_outline(ax)
    return contour


def dirichlet_fit_mle(X: np.ndarray, *, alpha_init: np.ndarray | None = None) -> tuple[np.ndarray, object]:
    '''Fit alpha by maximum likelihood.

    SciPy's `scipy.stats.dirichlet` does not provide `.fit()` (as of SciPy 1.15),
    so we optimize the log-likelihood ourselves.

    We optimize over theta = log(alpha) to enforce alpha_i > 0.
    '''
    X = validate_simplex(X, allow_zeros=False)
    n, k = X.shape
    sum_log_x = np.log(X).sum(axis=0)

    if alpha_init is None:
        # Method-of-moments-inspired initialization
        mean = X.mean(axis=0)
        var = X.var(axis=0, ddof=1)
        eps = 1e-8
        a0_est = np.mean(mean * (1.0 - mean) / np.maximum(var, eps) - 1.0)
        a0_est = float(np.clip(a0_est, 1e-2, 1e6))
        alpha_init = np.clip(mean * a0_est, 1e-2, None)

    def nll(theta: np.ndarray) -> float:
        alpha = np.exp(theta)
        a0 = alpha.sum()
        ll = n * (gammaln(a0) - gammaln(alpha).sum()) + ((alpha - 1.0) * sum_log_x).sum()
        return -float(ll)

    def grad(theta: np.ndarray) -> np.ndarray:
        alpha = np.exp(theta)
        a0 = alpha.sum()
        grad_alpha = -(n * (psi(a0) - psi(alpha)) + sum_log_x)
        return grad_alpha * alpha

    res = minimize(nll, x0=np.log(alpha_init), jac=grad, method="L-BFGS-B")
    alpha_hat = np.exp(res.x)
    return alpha_hat, res


## 1) Title & Classification

- **Name**: `dirichlet`
- **Type**: **continuous** (a \((K-1)\)-dimensional distribution embedded in \(\mathbb{R}^K\))
- **Support** (the probability simplex):

\[
\Delta^{K-1} = \left\{x \in \mathbb{R}^K : x_i \ge 0, \; \sum_{i=1}^K x_i = 1\right\}
\]

- **Parameter space**:

\[
\alpha = (\alpha_1,\dots,\alpha_K), \qquad \alpha_i > 0
\]

Useful shorthand:

\[
\alpha_0 = \sum_{i=1}^K \alpha_i \quad \text{(total concentration)}
\]

Interpretation preview:
- The mean is \(\mathbb{E}[X_i] = \alpha_i / \alpha_0\).
- \(\alpha_0\) controls how *concentrated* draws are around the mean.


## 2) Intuition & Motivation

### What it models
A Dirichlet random vector is a **random categorical probability vector**.
For \(K\) categories, a draw \(X\) can serve as the parameter vector of a categorical/multinomial distribution, giving the two-stage model

\[
X \sim \text{Dirichlet}(\alpha), \qquad Y \mid X \sim \text{Multinomial}(n, X).
\]

### Typical real-world use cases
- **Bayesian smoothing** of empirical category frequencies (avoids zero probabilities).
- **Mixture models**: prior over mixture weights (e.g., Gaussian mixture weights).
- **Topic models** (LDA): prior over per-document topic proportions.
- **Compositional data** (parts of a whole): proportions of time, budget, species, etc.

### Relations to other distributions
- **Beta distribution**: \(K=2\) gives \(X_1 \sim \text{Beta}(\alpha_1, \alpha_2)\) and \(X_2 = 1 - X_1\).
- **Gamma normalization**: if \(Y_i \sim \text{Gamma}(\alpha_i, 1)\) are independent (identically distributed only when all \(\alpha_i\) are equal), then
  \(X = (Y_1,\dots,Y_K)/\sum_j Y_j \sim \text{Dirichlet}(\alpha)\).
- **Conjugacy**: Dirichlet is conjugate to categorical/multinomial likelihoods.
- **Dirichlet–multinomial**: integrating out \(X\) yields an overdispersed count model.
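
The first two relations can be checked numerically. A minimal sketch, assuming SciPy is available; the parameter values and seed are arbitrary choices for illustration:

```python
import numpy as np
from scipy import stats

# K = 2: Dirichlet(a1, a2) evaluated at (x, 1 - x) should equal Beta(a1, a2) at x.
a1, a2 = 2.5, 4.0
xs = np.array([0.1, 0.35, 0.8])
dir_logpdf = np.array([stats.dirichlet([a1, a2]).logpdf([x, 1.0 - x]) for x in xs])
beta_logpdf = stats.beta(a1, a2).logpdf(xs)
print("max |Dirichlet - Beta| logpdf gap:", np.max(np.abs(dir_logpdf - beta_logpdf)))

# Gamma normalization: independent Gamma(alpha_i, 1) draws divided by their row sum.
rng = np.random.default_rng(1)
alpha = np.array([2.0, 3.0, 5.0])
y = rng.gamma(shape=alpha, scale=1.0, size=(200_000, alpha.size))
x_gamma = y / y.sum(axis=1, keepdims=True)
print("sample mean:    ", x_gamma.mean(axis=0))
print("alpha / alpha_0:", alpha / alpha.sum())
```

The sample mean should match \(\alpha/\alpha_0\) up to Monte Carlo error, consistent with the Dirichlet mean.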


## 3) Formal Definition

### PDF
For \(x \in \Delta^{K-1}\) and \(\alpha_i > 0\):

\[
 f(x;\alpha) = \frac{1}{B(\alpha)} \prod_{i=1}^K x_i^{\alpha_i - 1}
\]

where the multivariate Beta function \(B(\alpha)\) is

\[
B(\alpha) = \frac{\prod_{i=1}^K \Gamma(\alpha_i)}{\Gamma(\alpha_0)},
\qquad \alpha_0 = \sum_{i=1}^K \alpha_i.
\]

A numerically stable log-form is

\[
\log f(x;\alpha) = \log \Gamma(\alpha_0) - \sum_i \log \Gamma(\alpha_i) + \sum_i (\alpha_i - 1)\log x_i.
\]

### CDF
A multivariate CDF can be defined in \((K-1)\) free coordinates, e.g. \(x_K = 1 - \sum_{i=1}^{K-1} x_i\):

\[
F(x_1,\dots,x_{K-1}) = \mathbb{P}(X_1 \le x_1, \dots, X_{K-1} \le x_{K-1}).
\]

For \(K>2\), there is **no simple closed form in general**.
A key special case is \(K=2\), where Dirichlet reduces to **Beta**, and the CDF is the regularized incomplete beta function.

In practice, common workarounds are:
- use **marginal CDFs** (each \(X_i\) marginal is Beta), or
- estimate multivariate probabilities by **Monte Carlo**.


In [None]:
# Quick consistency check: our logpdf vs SciPy
alpha = np.array([2.5, 1.2, 3.0])
x = np.array([0.2, 0.5, 0.3])

ours = dirichlet_logpdf_numpy(x, alpha)[0]
ref = stats.dirichlet(alpha).logpdf(x)  # not named `scipy`, which would shadow the package name

ours, ref, float(ours - ref)


## 4) Moments & Properties

Let \(X \sim \text{Dirichlet}(\alpha)\) with \(\alpha_0 = \sum_i \alpha_i\).

### Mean
\[
\mathbb{E}[X_i] = \frac{\alpha_i}{\alpha_0}.
\]

### Variance and covariance
\[
\mathrm{Var}(X_i) = \frac{\alpha_i(\alpha_0 - \alpha_i)}{\alpha_0^2(\alpha_0+1)}
= \frac{m_i(1-m_i)}{\alpha_0+1},\quad m_i=\alpha_i/\alpha_0
\]

\[
\mathrm{Cov}(X_i, X_j) = -\frac{\alpha_i\alpha_j}{\alpha_0^2(\alpha_0+1)}\quad (i\ne j).
\]

### Marginals
Each component has a Beta marginal:

\[
X_i \sim \mathrm{Beta}(\alpha_i,\; \alpha_0 - \alpha_i).
\]

This is very useful: you can get **univariate PDFs/CDFs** and skewness/kurtosis per component.

### Skewness and kurtosis (component-wise)
For \(X_i\), with \(a=\alpha_i\) and \(b=\alpha_0-\alpha_i\):

\[
\mathrm{skew}(X_i) = \frac{2(b-a)\sqrt{a+b+1}}{(a+b+2)\sqrt{ab}}
\]

\[
\mathrm{excess\ kurt}(X_i) = \frac{6\big((a-b)^2(a+b+1)-ab(a+b+2)\big)}{ab(a+b+2)(a+b+3)}.
\]

### Other useful properties
- **Mode** (if all \(\alpha_i>1\)):
  \(\;\mathrm{mode}_i = \frac{\alpha_i-1}{\alpha_0-K}\).
- **Conjugacy** (multinomial counts \(c_i\)):
  \(\alpha \mapsto \alpha + c\).
- **Additivity**: merging categories keeps Dirichlet form; e.g. \((X_1+X_2, X_3,\dots)\) is Dirichlet with \((\alpha_1+\alpha_2, \alpha_3,\dots)\).
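
A quick numerical check of the mode and aggregation properties, assuming \(\alpha = (2, 3, 5)\) (all \(\alpha_i > 1\), so the mode formula applies). Small moves along the simplex (directions whose entries sum to 0) should lower the log-density at the mode, and \(X_1 + X_2\) should match the moments of \(\mathrm{Beta}(\alpha_1+\alpha_2, \alpha_3)\):

```python
import numpy as np
from scipy import stats

alpha = np.array([2.0, 3.0, 5.0])
K = alpha.size
rv = stats.dirichlet(alpha)

# Mode (valid when all alpha_i > 1): (alpha_i - 1) / (alpha_0 - K)
mode = (alpha - 1.0) / (alpha.sum() - K)

# Directions with zero sum keep the point on the simplex; each step should lower the density.
directions = np.array([[1.0, -1.0, 0.0], [0.0, 1.0, -1.0], [1.0, 0.0, -1.0]])
step = 1e-3
is_max = all(
    rv.logpdf(mode) > rv.logpdf(mode + s * step * d)
    for d in directions
    for s in (-1.0, 1.0)
)
print("mode:", mode, "| local maximum:", is_max)

# Aggregation: X1 + X2 should be Beta(alpha_1 + alpha_2, alpha_3)
rng = np.random.default_rng(3)
merged = rv.rvs(size=200_000, random_state=rng)[:, :2].sum(axis=1)
ref = stats.beta(alpha[0] + alpha[1], alpha[2])
print("mean:", merged.mean(), "vs", ref.mean())
print("var: ", merged.var(), "vs", ref.var())
```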

### MGF / characteristic function
They exist (support is bounded), but are not usually expressed in elementary functions for general \(K\).
They can be written using multivariate hypergeometric functions (e.g. Lauricella functions).

### Entropy
Differential entropy (nats):

\[
H(X) = \log B(\alpha) + (\alpha_0-K)\,\psi(\alpha_0) - \sum_i (\alpha_i-1)\,\psi(\alpha_i)
\]

where \(\psi\) is the digamma function.


In [None]:
alpha = np.array([2.0, 3.0, 5.0])
rv = stats.dirichlet(alpha)

samples = sample_dirichlet_numpy(alpha, size=N_SAMPLES, rng=rng)

mean_theory = dirichlet_mean(alpha)
mean_mc = samples.mean(axis=0)

cov_theory = dirichlet_cov(alpha)
cov_mc = np.cov(samples, rowvar=False)

# Component-wise skewness/kurtosis via Beta marginals
alpha0 = alpha.sum()
skew_theory = np.array([beta_skewness(a, alpha0 - a) for a in alpha])
exkurt_theory = np.array([beta_excess_kurtosis(a, alpha0 - a) for a in alpha])

skew_mc = stats.skew(samples, axis=0, bias=False)
exkurt_mc = stats.kurtosis(samples, axis=0, fisher=True, bias=False)

print('Mean (theory):', mean_theory)
print('Mean (MC):    ', mean_mc)

print()
print('Cov (theory):')
print(cov_theory)

print()
print('Cov (MC):')
print(cov_mc)

print()
print('Skewness (theory, per component):', skew_theory)
print('Skewness (MC):                   ', skew_mc)

print()
print('Excess kurtosis (theory, per component):', exkurt_theory)
print('Excess kurtosis (MC):                   ', exkurt_mc)

print()
print('Entropy (theory):', dirichlet_entropy(alpha))
print('Entropy (SciPy): ', rv.entropy())


## 5) Parameter Interpretation

A very useful reparameterization is:

\[
\alpha = \alpha_0\, m, \qquad m \in \Delta^{K-1}, \quad \alpha_0 > 0.
\]

- \(m\) is the **mean direction** (since \(\mathbb{E}[X]=m\)).
- \(\alpha_0\) is the **concentration**: a larger \(\alpha_0\) means smaller variance, since \(\mathrm{Var}(X_i) = m_i(1-m_i)/(\alpha_0+1)\).

Heuristics (for symmetric \(\alpha_1=\cdots=\alpha_K=a\)):
- \(a=1\): uniform over the simplex.
- \(a>1\): mass near the center (balanced proportions).
- \(0<a<1\): mass near corners/edges (sparse proportions).

For asymmetric \(\alpha\), the mean shifts toward larger \(\alpha_i\) components.
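
Both ideas can be seen numerically. A minimal sketch, assuming the mean direction \(m = (0.2, 0.3, 0.5)\) and a few symmetric values of \(a\), all chosen only for illustration: scaling \(\alpha_0\) leaves the mean at \(m\) while the variance shrinks, and the symmetric heuristics show up as how often one component dominates.

```python
import numpy as np
from scipy import stats

m = np.array([0.2, 0.3, 0.5])  # mean direction: a point on the simplex
means, variances = {}, {}
for alpha0 in (1.0, 10.0, 100.0):
    rv = stats.dirichlet(alpha0 * m)
    means[alpha0], variances[alpha0] = rv.mean(), rv.var()
    print(f"alpha0={alpha0:6.1f}  mean={rv.mean()}  var={rv.var()}")

# Symmetric alpha = (a, a, a): how often is one component dominant (> 0.9)?
rng = np.random.default_rng(4)
dominant = {}
for a in (0.1, 1.0, 5.0):
    x = rng.dirichlet(np.full(3, a), size=50_000)
    dominant[a] = np.mean(x.max(axis=1) > 0.9)
    print(f"a={a}: P(max_i X_i > 0.9) ~= {dominant[a]:.3f}")
```

Sparse draws (one dominant component) should be common for \(a < 1\) and rare for \(a > 1\).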


In [None]:
# Shape changes on the 2-simplex (K=3)

alphas = [
    np.array([1.0, 1.0, 1.0]),     # uniform
    np.array([5.0, 5.0, 5.0]),     # concentrated around center
    np.array([0.35, 0.35, 0.35]),  # corners/edges (divergent at boundaries)
    np.array([2.0, 6.0, 1.2]),     # asymmetric
]

fig, axes = plt.subplots(2, 2, figsize=(10, 8), constrained_layout=True)
axes = axes.ravel()

mappables = []
for ax, a in zip(axes, alphas):
    min_comp = 1.0 / GRID_N
    m = plot_dirichlet_simplex3(a, ax=ax, grid_n=GRID_N, min_component=min_comp)
    ax.set_title(f"alpha={a}")
    mappables.append(m)

# Each panel has its own color scale, so attach a separate colorbar (log-density) to each
for ax, mappable in zip(axes, mappables):
    cb = fig.colorbar(mappable, ax=ax, shrink=0.85)
    cb.set_label("log PDF")

plt.show()


## 6) Derivations

### Expectation (sketch)
Start from the definition:

\[
\mathbb{E}[X_i] = \int_{\Delta^{K-1}} x_i\, f(x;\alpha)\, dx.
\]

Using the normalization constant, notice that multiplying by \(x_i\) increases the exponent of \(x_i\) by 1:

\[
\mathbb{E}[X_i] = \frac{B(\alpha + e_i)}{B(\alpha)}
\]

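where \(e_i\) is the \(i\)-th standard basis vector.
The two Beta functions differ only in \(\Gamma(\alpha_i+1)\) versus \(\Gamma(\alpha_i)\) and \(\Gamma(\alpha_0+1)\) versus \(\Gamma(\alpha_0)\), so \(\Gamma(z+1) = z\,\Gamma(z)\) gives

\[
\frac{B(\alpha + e_i)}{B(\alpha)} = \frac{\Gamma(\alpha_i+1)}{\Gamma(\alpha_i)}\cdot\frac{\Gamma(\alpha_0)}{\Gamma(\alpha_0+1)} = \frac{\alpha_i}{\alpha_0},
\]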

we get \(\mathbb{E}[X_i]=\alpha_i/\alpha_0\).

### Variance / covariance (sketch)
Similarly, for second moments:

\[
\mathbb{E}[X_i^2] = \frac{B(\alpha + 2e_i)}{B(\alpha)} = \frac{\alpha_i(\alpha_i+1)}{\alpha_0(\alpha_0+1)}
\]

\[
\mathbb{E}[X_iX_j] = \frac{B(\alpha + e_i + e_j)}{B(\alpha)} = \frac{\alpha_i\alpha_j}{\alpha_0(\alpha_0+1)}\quad (i\ne j)
\]

Combine these with \(\mathrm{Var}(X_i)=\mathbb{E}[X_i^2]-\mathbb{E}[X_i]^2\) to get the usual formulas.
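
The Beta-ratio identities can be checked directly. A minimal sketch, assuming \(\alpha = (2, 3, 5)\) and the pair \((i, j) = (0, 1)\) (0-based); `log_multivariate_beta` and `beta_ratio` are illustrative helpers defined here, computing the ratios in log space with `gammaln`:

```python
import numpy as np
from scipy.special import gammaln


def log_multivariate_beta(alpha):
    '''log B(alpha) = sum_i log Gamma(alpha_i) - log Gamma(alpha_0).'''
    alpha = np.asarray(alpha, dtype=float)
    return gammaln(alpha).sum() - gammaln(alpha.sum())


def beta_ratio(alpha, shift):
    '''B(alpha + shift) / B(alpha), computed in log space for stability.'''
    return np.exp(log_multivariate_beta(alpha + shift) - log_multivariate_beta(alpha))


alpha = np.array([2.0, 3.0, 5.0])
a0 = alpha.sum()
e = np.eye(alpha.size)  # e[i] is the i-th standard basis vector
i, j = 0, 1

E_xi = beta_ratio(alpha, e[i])
E_xj = beta_ratio(alpha, e[j])
E_xi2 = beta_ratio(alpha, 2 * e[i])
E_xixj = beta_ratio(alpha, e[i] + e[j])

var_i = E_xi2 - E_xi**2
cov_ij = E_xixj - E_xi * E_xj
print("E[X_i]:      ", E_xi, "vs", alpha[i] / a0)
print("Var(X_i):    ", var_i, "vs", alpha[i] * (a0 - alpha[i]) / (a0**2 * (a0 + 1)))
print("Cov(X_i,X_j):", cov_ij, "vs", -alpha[i] * alpha[j] / (a0**2 * (a0 + 1)))
```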

### Likelihood (iid observations)
Given samples \(x^{(1)},\dots,x^{(N)}\) on the simplex:

\[
\ell(\alpha) = \sum_{n=1}^N \log f(x^{(n)};\alpha)
= N\Big(\log\Gamma(\alpha_0) - \sum_i \log\Gamma(\alpha_i)\Big) + \sum_i (\alpha_i-1)\sum_{n=1}^N \log x_i^{(n)}.
\]

The gradient is

\[
\frac{\partial \ell}{\partial \alpha_i} = N\big(\psi(\alpha_0) - \psi(\alpha_i)\big) + \sum_{n=1}^N \log x_i^{(n)}.
\]

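Setting it to zero gives the stationarity condition \(\psi(\alpha_i) = \psi(\alpha_0) + \frac{1}{N}\sum_{n=1}^N \log x_i^{(n)}\) for every \(i\).
This is the basis for MLE algorithms, either gradient-based optimizers (as in `dirichlet_fit_mle`) or fixed-point iteration on this condition.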
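
A gradient-free alternative is to iterate the stationarity condition \(\psi(\alpha_i) = \psi(\alpha_0) + \frac{1}{N}\sum_n \log x_i^{(n)}\) directly (Minka's fixed-point iteration), which needs an inverse digamma. A minimal sketch under these assumptions: `inv_digamma` and `dirichlet_mle_fixed_point` are illustrative helpers defined here (not SciPy functions), the inverse uses a few Newton steps, and starting from all-ones \(\alpha\) is an arbitrary choice.

```python
import numpy as np
from scipy.special import digamma, polygamma


def inv_digamma(y, n_iter=5):
    '''Invert digamma elementwise via Newton's method (starting point as suggested by Minka).'''
    y = np.asarray(y, dtype=float)
    with np.errstate(divide="ignore"):
        # exp(y) + 1/2 for moderate/large y; -1 / (y - psi(1)) for very negative y
        x = np.where(y >= -2.22, np.exp(y) + 0.5, -1.0 / (y - digamma(1.0)))
    for _ in range(n_iter):
        x = x - (digamma(x) - y) / polygamma(1, x)  # polygamma(1, x) is the trigamma function
    return x


def dirichlet_mle_fixed_point(X, tol=1e-10, max_iter=10_000):
    '''Iterate psi(alpha_i) <- psi(alpha_0) + mean_n log x_i^(n) until alpha stops changing.'''
    mean_log_x = np.log(X).mean(axis=0)  # the only data statistic the likelihood needs
    alpha = np.ones(X.shape[1])          # simple starting point (an assumption)
    for _ in range(max_iter):
        alpha_new = inv_digamma(digamma(alpha.sum()) + mean_log_x)
        if np.max(np.abs(alpha_new - alpha)) < tol:
            return alpha_new
        alpha = alpha_new
    return alpha


rng = np.random.default_rng(5)
alpha_true = np.array([1.8, 3.2, 5.0])
X = rng.dirichlet(alpha_true, size=20_000)
alpha_hat = dirichlet_mle_fixed_point(X)
print("alpha_true:", alpha_true)
print("alpha_hat: ", alpha_hat)
```

On the same data, this should agree with the L-BFGS-B result from `dirichlet_fit_mle` up to optimizer tolerance, since both target the same stationarity condition.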


In [None]:
# MLE demo on synthetic data
alpha_true = np.array([1.8, 3.2, 5.0])
X = sample_dirichlet_numpy(alpha_true, size=8_000 if FAST_RUN else 40_000, rng=rng)

alpha_hat, opt = dirichlet_fit_mle(X)

print('alpha_true:', alpha_true)
print('alpha_hat: ', np.round(alpha_hat, 4))
print('optimizer success:', opt.success)
print('final nll:', opt.fun)


## 7) Sampling & Simulation (NumPy-only)

A standard and efficient sampling method uses the **Gamma normalization** representation.

**Algorithm** for \(X \sim \mathrm{Dirichlet}(\alpha)\):
1. For each component \(i\), draw \(Y_i \sim \mathrm{Gamma}(\alpha_i, 1)\) independently.
2. Normalize: \(X_i = \frac{Y_i}{\sum_j Y_j}\).

Why it works (high-level idea):
- The joint density of independent Gammas factorizes nicely.
- The change of variables from \((Y_1,\dots,Y_K)\) to \((X_1,\dots,X_{K-1}, S)\) with \(S=\sum_i Y_i\)
  yields a Jacobian of \(S^{K-1}\).
- After integrating out \(S\), the remaining density over \(X\) matches the Dirichlet PDF.

This approach is the practical default in most libraries.
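
Written out, the change of variables goes as follows, with \(y_i = s\,x_i\) for \(i < K\), \(y_K = s\,x_K\), \(x_K = 1 - \sum_{i<K} x_i\), and Jacobian \(s^{K-1}\):

\[
\begin{aligned}
p(x, s) &= \prod_{i=1}^K \frac{(s x_i)^{\alpha_i-1} e^{-s x_i}}{\Gamma(\alpha_i)}\; s^{K-1}
= \frac{s^{\alpha_0-1} e^{-s}}{\prod_{i=1}^K \Gamma(\alpha_i)} \prod_{i=1}^K x_i^{\alpha_i-1}, \\
p(x) &= \int_0^\infty p(x, s)\, ds
= \frac{\Gamma(\alpha_0)}{\prod_{i=1}^K \Gamma(\alpha_i)} \prod_{i=1}^K x_i^{\alpha_i-1}
= \frac{1}{B(\alpha)} \prod_{i=1}^K x_i^{\alpha_i-1}.
\end{aligned}
\]

The exponents combine because \(\sum_i(\alpha_i - 1) + (K-1) = \alpha_0 - 1\) and \(\sum_i s\,x_i = s\). Since \(p(x, s)\) factorizes into a \(\mathrm{Gamma}(\alpha_0, 1)\) density in \(s\) times the Dirichlet density in \(x\), the sum \(S\) is also independent of \(X\).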


In [None]:
alpha = np.array([0.7, 1.5, 2.2])
X = sample_dirichlet_numpy(alpha, size=5, rng=rng)

print('samples:')
print(X)
print()
print('row sums:', X.sum(axis=1))


## 8) Visualization

Because the Dirichlet is multivariate, the cleanest visualization is for \(K=3\):
a density over a **triangle** (the 2-simplex).

We also visualize a **univariate CDF** via a marginal \(X_i\), since each component is Beta-distributed.


In [None]:
alpha = np.array([2.0, 3.0, 5.0])
rv = stats.dirichlet(alpha)

# 1) PDF over the 2-simplex (K=3)
fig, axes = plt.subplots(1, 3, figsize=(14, 4), constrained_layout=True)

m = plot_dirichlet_simplex3(alpha, ax=axes[0], grid_n=GRID_N, min_component=1.0 / GRID_N)
axes[0].set_title('Dirichlet log PDF on simplex (K=3)')
cb = fig.colorbar(m, ax=axes[0], fraction=0.046, pad=0.04)
cb.set_label('log PDF')

# 2) Monte Carlo samples
samps = rv.rvs(size=4_000 if FAST_RUN else 20_000, random_state=rng)
xy = simplex3_to_xy(samps)
axes[1].scatter(xy[:, 0], xy[:, 1], s=6, alpha=0.35)
plot_simplex3_outline(axes[1])
axes[1].set_title('Monte Carlo samples')

# 3) Marginal PDF and CDF (component 1)
a = alpha[0]
b = alpha.sum() - alpha[0]
xs = np.linspace(0, 1, 400)

axes[2].plot(xs, stats.beta(a, b).pdf(xs), label='marginal PDF (Beta)')
axes[2].set_xlabel('x')
axes[2].set_ylabel('pdf')

ax2 = axes[2].twinx()
ax2.plot(xs, stats.beta(a, b).cdf(xs), color='tab:orange', label='marginal CDF (Beta)')
ax2.set_ylabel('cdf')

lines1, labels1 = axes[2].get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
axes[2].legend(lines1 + lines2, labels1 + labels2, loc='center right')
axes[2].set_title('Univariate marginal (X1)')

plt.show()


## 9) SciPy Integration (`scipy.stats.dirichlet`)

SciPy provides a frozen Dirichlet distribution via:

```python
rv = scipy.stats.dirichlet(alpha)
```

Available methods (SciPy 1.15):
- `pdf`, `logpdf`
- `rvs`
- `mean`, `var`, `entropy`

**Notably missing**:
- `cdf` (multivariate CDF is nontrivial)
- `fit` (no built-in MLE)

Workarounds:
- For CDF-like quantities, use **marginal Beta CDFs** or Monte Carlo.
- For fitting, optimize the log-likelihood (we implemented `dirichlet_fit_mle`).
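
One practical detail: the SciPy documentation notes that the `dirichlet` interface is inconsistent in its array layout. `rvs` returns one draw per row (shape `(N, K)`), while `pdf`/`logpdf` expect components along the first axis (shape `(K, N)`). A minimal sketch of batch evaluation under that convention, cross-checked against the log-form of the PDF from Section 3 (parameter values and seed are arbitrary):

```python
import numpy as np
from scipy import stats
from scipy.special import gammaln

alpha = np.array([1.4, 2.1, 3.7])
rv = stats.dirichlet(alpha)

samples = rv.rvs(size=5, random_state=np.random.default_rng(6))  # shape (5, 3): one row per draw
print("rvs shape:", samples.shape)

# logpdf expects components along the FIRST axis, i.e. shape (K, N): pass the transpose
logp_scipy = rv.logpdf(samples.T)

# Reference: log Gamma(alpha_0) - sum_i log Gamma(alpha_i) + sum_i (alpha_i - 1) log x_i, per row
logp_ref = (gammaln(alpha.sum()) - gammaln(alpha).sum()
            + ((alpha - 1.0) * np.log(samples)).sum(axis=1))
print("max abs difference:", np.max(np.abs(logp_scipy - logp_ref)))
```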


In [None]:
alpha = np.array([1.4, 2.1, 3.7])
rv = stats.dirichlet(alpha)

x = np.array([0.25, 0.5, 0.25])

print('pdf:', rv.pdf(x))
print('logpdf:', rv.logpdf(x))
print('mean:', rv.mean())
print('var:', rv.var())
print('entropy:', rv.entropy())

# Fit alpha on synthetic samples
X = rv.rvs(size=6_000 if FAST_RUN else 30_000, random_state=rng)
alpha_hat, opt = dirichlet_fit_mle(X)
print()
print('alpha true:', alpha)
print('alpha hat :', np.round(alpha_hat, 4))


## 10) Statistical Use Cases

### A) Hypothesis testing (Bayesian style)
With a Dirichlet posterior over probabilities \(p\), you can ask questions like:

\[
\mathbb{P}(p_1 > p_2 \mid \text{data})
\]

and decide using a threshold (e.g. > 0.95).

### B) Bayesian modeling (conjugate updating)
For multinomial counts \(c\) and prior \(\alpha\):

\[
\alpha_{\text{post}} = \alpha + c.
\]

This gives closed-form posterior mean/variance and easy simulation.

### C) Generative modeling
Dirichlet priors are commonly used for:
- **mixture weights** (e.g. Dirichlet prior on component proportions),
- **topic proportions** (LDA),
- any model where a latent probability simplex vector is required.

A simple generative story:
1. Sample probabilities \(p \sim \text{Dirichlet}(\alpha)\)
2. Sample data \(y \mid p \sim \text{Multinomial}(n, p)\)
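
Integrating \(p\) out of this two-stage story gives the Dirichlet–multinomial pmf
\(P(y \mid \alpha) = \frac{n!}{\prod_i y_i!}\,\frac{\Gamma(\alpha_0)}{\Gamma(n+\alpha_0)}\prod_i \frac{\Gamma(y_i+\alpha_i)}{\Gamma(\alpha_i)}\).
A minimal sketch, assuming \(\alpha = (2, 1, 0.5)\) and \(n = 6\) for illustration; `dirichlet_multinomial_logpmf` is a hypothetical helper defined here. It enumerates every count vector, checks that the pmf sums to 1, and compares the variance with the plain multinomial to show the overdispersion.

```python
import numpy as np
from scipy.special import gammaln


def dirichlet_multinomial_logpmf(y, alpha):
    '''log P(y | alpha) after integrating p ~ Dirichlet(alpha) out of Multinomial(n, p).'''
    y = np.asarray(y, dtype=float)
    alpha = np.asarray(alpha, dtype=float)
    n, a0 = y.sum(), alpha.sum()
    return (gammaln(n + 1) - gammaln(y + 1).sum()
            + gammaln(a0) - gammaln(n + a0)
            + (gammaln(y + alpha) - gammaln(alpha)).sum())


alpha = np.array([2.0, 1.0, 0.5])
a0 = alpha.sum()
m = alpha / a0
n = 6

# All count vectors (y1, y2, y3) with y1 + y2 + y3 = n
Y = np.array([[i, j, n - i - j] for i in range(n + 1) for j in range(n + 1 - i)])
pmf = np.exp([dirichlet_multinomial_logpmf(y, alpha) for y in Y])
print("sum of pmf:", pmf.sum())

mean_y = pmf @ Y
var_y = pmf @ (Y - mean_y) ** 2
print("mean:", mean_y, "vs n*m =", n * m)
print("var: ", var_y, "vs multinomial", n * m * (1 - m))
```

The variance should equal \(n m_i (1 - m_i)\frac{n + \alpha_0}{1 + \alpha_0}\), larger than the multinomial \(n m_i(1-m_i)\) whenever \(n > 1\).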


In [None]:
# Example: Bayesian updating + a posterior probability "test"

prior_alpha = np.array([1.0, 1.0, 1.0])  # uniform prior on simplex
counts = np.array([12, 5, 3])
posterior_alpha = prior_alpha + counts

rv_post = stats.dirichlet(posterior_alpha)

print('prior alpha:    ', prior_alpha)
print('observed counts:', counts)
print('posterior alpha:', posterior_alpha)

print()
print('posterior mean:', np.round(rv_post.mean(), 4))
print('posterior var :', np.round(rv_post.var(), 6))

# Posterior probability that category 1 is more likely than category 2
post_samples = rv_post.rvs(size=50_000 if FAST_RUN else 250_000, random_state=rng)
p_gt = (post_samples[:, 0] > post_samples[:, 1]).mean()
print()
print('P(p1 > p2 | data) ≈', round(float(p_gt), 4))

# Posterior predictive simulation for 20 future trials
n_future = 20
p_draw = rv_post.rvs(size=1, random_state=rng)[0]
future_counts = rng.multinomial(n_future, p_draw)
print()
print('p_draw (one posterior draw):', np.round(p_draw, 4))
print('one predictive sample (n=20):', future_counts)


## 11) Pitfalls

- **Invalid parameters**: all \(\alpha_i\) must be strictly positive.
- **Zeros in data**: real datasets often contain exact zeros in proportions. The Dirichlet is continuous, so the boundary has probability 0, and the log-likelihood contains \(\log x_i\), which is \(-\infty\) at \(x_i = 0\); `dirichlet_fit_mle` therefore rejects zeros.
  Common fixes include adding small pseudocounts / smoothing, or using models designed for zero inflation.
- **Numerical issues near boundaries**:
  - when \(\alpha_i < 1\), the density diverges as \(x_i \to 0\);
  - evaluate in **log space** (`logpdf`) and avoid plotting exactly at \(x_i=0\) for such parameters.
- **Interpreting scale**: scaling all \(\alpha\) by a constant keeps the mean the same but changes concentration.
- **Fitting can be tricky**: MLE is well-defined but may require good initialization and care with constraints.
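
A minimal sketch of the simplest smoothing fix for zeros, assuming a fixed offset `eps` added to every component before renormalizing; `smooth_zeros` is a hypothetical helper and `eps=1e-6` is an arbitrary choice. Fitted \(\alpha_i\) for components with many zeros can be quite sensitive to `eps`, so it is worth trying a few values.

```python
import numpy as np


def smooth_zeros(X, eps=1e-6):
    '''Shift rows off the boundary: x -> (x + eps) / (1 + K * eps), which keeps row sums at 1.'''
    X = np.asarray(X, dtype=float)
    K = X.shape[1]
    return (X + eps) / (1.0 + K * eps)


X = np.array([[0.5, 0.5, 0.0],
              [0.2, 0.3, 0.5]])
X_s = smooth_zeros(X)
print(X_s)
print("row sums:", X_s.sum(axis=1))
print("min log x:", np.log(X_s).min())  # finite, so log-likelihood terms are finite
```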


## 12) Summary

- Dirichlet is a **continuous distribution on the probability simplex**, modeling random categorical probabilities.
- Parameters \(\alpha\) encode both a **mean direction** \(\alpha/\alpha_0\) and a **concentration** \(\alpha_0\).
- Each component has a **Beta marginal**, which provides easy univariate PDFs/CDFs and skewness/kurtosis.
- Sampling is simple and efficient via **Gamma normalization**.
- `scipy.stats.dirichlet` supports `pdf/logpdf/rvs/mean/var/entropy`; multivariate `cdf` and `fit` are not provided, but MLE fitting is straightforward via likelihood optimization.
