# Dirichlet Distribution: Multivariate Beta

The **Dirichlet distribution** is the multivariate generalization of the Beta distribution. While Beta models a single proportion in [0,1], Dirichlet models a **probability vector** that sums to 1.

## Why This Matters for Causal ML

1. **Cell type proportions** — Model composition of cell populations
2. **Treatment assignment probabilities** — Prior for multinomial treatments
3. **Heterogeneous treatment effects** — Discover patient subgroups via Dirichlet Process mixtures
4. **Topic models / latent states** — LDA-style models for gene programs

## Table of Contents

1. [From Beta to Dirichlet](#1-from-beta-to-dirichlet)
2. [Visualization](#2-visualization)
3. [The Concentration Parameter](#3-the-concentration-parameter)
4. [Dirichlet-Multinomial Conjugacy](#4-dirichlet-multinomial-conjugacy)
5. [Applications in Biology](#5-applications-in-biology)
6. [Connection to Dirichlet Process](#6-connection-to-dirichlet-process)
7. [Quick Reference](#7-quick-reference)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

# Style
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 12

## 1. From Beta to Dirichlet

### Beta Distribution (K=2)

The Beta distribution models a single probability $\theta \in [0, 1]$:

$$\theta \sim \text{Beta}(\alpha, \beta)$$

We can think of this as modeling a 2-category probability vector $(\theta, 1-\theta)$.

### Dirichlet Distribution (K≥2)

The Dirichlet distribution generalizes to K categories:

$$\boldsymbol{\theta} = (\theta_1, \theta_2, \ldots, \theta_K) \sim \text{Dirichlet}(\boldsymbol{\alpha})$$

Where:
- $\boldsymbol{\alpha} = (\alpha_1, \alpha_2, \ldots, \alpha_K)$ are concentration parameters
- $\theta_k \geq 0$ for all k
- $\sum_{k=1}^K \theta_k = 1$ (probability simplex)
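A standard way to see where Dirichlet samples come from is to draw independent $G_k \sim \text{Gamma}(\alpha_k, 1)$ and normalize them, $\theta_k = G_k / \sum_j G_j$. The sketch below implements that construction; the helper name `dirichlet_via_gamma` is ours, for illustration only. It checks that every draw lands on the simplex and that the sample mean is close to $\alpha_k / \sum_j \alpha_j$.

```python
import numpy as np

def dirichlet_via_gamma(alpha, size, rng=None):
    """Sample Dirichlet(alpha) by normalizing independent Gamma(alpha_k, 1) draws."""
    rng = np.random.default_rng() if rng is None else rng
    alpha = np.asarray(alpha, dtype=float)
    # One Gamma draw per component per sample: shape (size, K)
    g = rng.gamma(shape=alpha, scale=1.0, size=(size, alpha.size))
    # Dividing each row by its sum puts it on the probability simplex
    return g / g.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
alpha = np.array([2.0, 3.0, 5.0])
samples = dirichlet_via_gamma(alpha, 50_000, rng)

print(samples.sum(axis=1)[:3])   # each row sums to 1
print(samples.mean(axis=0))      # close to alpha / alpha.sum() = [0.2, 0.3, 0.5]
```

`np.random.dirichlet` gives the same distribution; the explicit version is useful when the Gamma draws themselves carry meaning (for example, unnormalized abundances).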

### PDF

$$p(\boldsymbol{\theta} | \boldsymbol{\alpha}) = \frac{\Gamma(\sum_k \alpha_k)}{\prod_k \Gamma(\alpha_k)} \prod_{k=1}^K \theta_k^{\alpha_k - 1}$$
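Evaluating this PDF directly overflows quickly because $\Gamma(\cdot)$ grows very fast, so a log-scale version with `scipy.special.gammaln` is the usual route. The sketch below (the function name `dirichlet_logpdf` is ours) transcribes the formula term by term and compares it with SciPy's `stats.dirichlet.logpdf` at a single point.

```python
import numpy as np
from scipy import stats
from scipy.special import gammaln

def dirichlet_logpdf(theta, alpha):
    """Log of the Dirichlet PDF; components are indexed on the last axis."""
    theta = np.asarray(theta, dtype=float)
    alpha = np.asarray(alpha, dtype=float)
    # log Gamma(sum_k alpha_k) - sum_k log Gamma(alpha_k): the normalizing constant
    log_norm = gammaln(alpha.sum()) - gammaln(alpha).sum()
    # sum_k (alpha_k - 1) * log(theta_k): the kernel
    return log_norm + np.sum((alpha - 1.0) * np.log(theta), axis=-1)

theta = np.array([0.2, 0.3, 0.5])
alpha = np.array([2.0, 3.0, 5.0])
print(dirichlet_logpdf(theta, alpha))
print(stats.dirichlet.logpdf(theta, alpha))   # SciPy's implementation, for comparison
```

A quick sanity check: for $\boldsymbol{\alpha} = (1, 1, 1)$ the kernel vanishes and the density is $\Gamma(3) = 2$ everywhere on the simplex, which is the "uniform" case in the plots below.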

In [None]:
# Beta is Dirichlet with K=2
np.random.seed(42)

alpha, beta_param = 2, 5

# Sample from Beta
beta_samples = np.random.beta(alpha, beta_param, 5000)

# Sample from Dirichlet with K=2
dirichlet_samples = np.random.dirichlet([alpha, beta_param], 5000)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Beta samples
axes[0].hist(beta_samples, bins=50, density=True, alpha=0.7, edgecolor='black')
x = np.linspace(0, 1, 200)
axes[0].plot(x, stats.beta.pdf(x, alpha, beta_param), 'r-', lw=2, label='Beta PDF')
axes[0].set_xlabel('θ')
axes[0].set_ylabel('Density')
axes[0].set_title(f'Beta({alpha}, {beta_param})')
axes[0].legend()

# Dirichlet samples (first component)
axes[1].hist(dirichlet_samples[:, 0], bins=50, density=True, alpha=0.7, edgecolor='black')
axes[1].plot(x, stats.beta.pdf(x, alpha, beta_param), 'r-', lw=2, label='Beta PDF')
axes[1].set_xlabel('θ₁')
axes[1].set_ylabel('Density')
axes[1].set_title(f'Dirichlet([{alpha}, {beta_param}]) - First component')
axes[1].legend()

plt.suptitle('Beta is Dirichlet with K=2', fontsize=14)
plt.tight_layout()
plt.show()

print(f"Beta samples:      mean = {beta_samples.mean():.4f}")
print(f"Dirichlet θ₁:      mean = {dirichlet_samples[:, 0].mean():.4f}")
print(f"Theory:            mean = {alpha / (alpha + beta_param):.4f}")

## 2. Visualization

For K=3, the Dirichlet lives on a 2D simplex (triangle). Let's visualize different parameter settings.
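Scatter plots show samples; to draw the density itself, the PDF has to be evaluated at points on the triangle. The sketch below uses the same barycentric-to-Cartesian map as `plot_dirichlet_simplex` in the next cell, adds its inverse, and evaluates the log-PDF on a regular grid. The helper names (`to_cartesian`, `to_barycentric`, `dirichlet_density_grid`) and the edge offset `eps` are illustrative choices, not part of any library.

```python
import numpy as np
from scipy.special import gammaln

SQRT3 = np.sqrt(3.0)

def to_cartesian(theta):
    """(theta_1, theta_2, theta_3) -> (x, y) with vertices (0, 0), (1, 0), (0.5, sqrt(3)/2)."""
    theta = np.asarray(theta, dtype=float)
    x = theta[..., 1] + theta[..., 2] / 2
    y = theta[..., 2] * SQRT3 / 2
    return np.stack([x, y], axis=-1)

def to_barycentric(xy):
    """Inverse map: a point (x, y) inside the triangle -> (theta_1, theta_2, theta_3)."""
    xy = np.asarray(xy, dtype=float)
    theta3 = 2 * xy[..., 1] / SQRT3
    theta2 = xy[..., 0] - theta3 / 2
    theta1 = 1.0 - theta2 - theta3
    return np.stack([theta1, theta2, theta3], axis=-1)

def dirichlet_density_grid(alpha, n=100, eps=1e-3):
    """Evaluate the K=3 Dirichlet PDF on a regular grid covering the simplex.

    Grid points are pulled slightly inside the triangle (eps) because the PDF
    is infinite on an edge whenever the matching alpha_k < 1.
    """
    alpha = np.asarray(alpha, dtype=float)
    i, j = np.meshgrid(np.arange(n + 1), np.arange(n + 1), indexing="ij")
    keep = (i + j) <= n                      # grid points with theta_1 + theta_2 <= 1
    t1, t2 = i[keep] / n, j[keep] / n
    theta = np.stack([t1, t2, 1.0 - t1 - t2], axis=-1)
    theta = np.clip(theta, eps, None)        # move edge points slightly inside
    theta /= theta.sum(axis=-1, keepdims=True)
    log_norm = gammaln(alpha.sum()) - gammaln(alpha).sum()
    log_pdf = log_norm + np.sum((alpha - 1.0) * np.log(theta), axis=-1)
    return to_cartesian(theta), np.exp(log_pdf)

xy, density = dirichlet_density_grid([2, 3, 5], n=50)
print(xy.shape, density.max())
```

The returned points and values can be passed to `plt.tricontourf(xy[:, 0], xy[:, 1], density)` to get a filled contour plot in the same triangle layout as the scatter plots.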

In [None]:
def plot_dirichlet_simplex(alpha, ax, n_samples=5000, title=None):
    """
    Plot Dirichlet samples on a 2D simplex (ternary plot).
    
    For K=3, we project the 3D simplex onto 2D using:
    x = θ₂ + θ₃/2
    y = θ₃ * sqrt(3)/2
    """
    # Sample from Dirichlet
    samples = np.random.dirichlet(alpha, n_samples)
    
    # Project to 2D (barycentric to Cartesian)
    # Vertices of equilateral triangle: (0,0), (1,0), (0.5, sqrt(3)/2)
    x = samples[:, 1] + samples[:, 2] / 2
    y = samples[:, 2] * np.sqrt(3) / 2
    
    # Plot
    ax.scatter(x, y, alpha=0.3, s=5, c='steelblue')
    
    # Draw simplex boundary
    triangle = plt.Polygon([[0, 0], [1, 0], [0.5, np.sqrt(3)/2]], 
                           fill=False, edgecolor='black', linewidth=2)
    ax.add_patch(triangle)
    
    # Label vertices
    ax.text(-0.05, -0.05, 'θ₁=1', fontsize=10, ha='center')
    ax.text(1.05, -0.05, 'θ₂=1', fontsize=10, ha='center')
    ax.text(0.5, np.sqrt(3)/2 + 0.05, 'θ₃=1', fontsize=10, ha='center')
    
    ax.set_xlim(-0.1, 1.1)
    ax.set_ylim(-0.1, 1.0)
    ax.set_aspect('equal')
    ax.axis('off')
    
    if title:
        ax.set_title(title, fontsize=12)
    else:
        ax.set_title(f'Dirichlet({alpha})', fontsize=12)

# Different parameter settings
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

params = [
    ([1, 1, 1], 'Uniform (α = [1,1,1])'),
    ([5, 5, 5], 'Symmetric concentrated (α = [5,5,5])'),
    ([0.5, 0.5, 0.5], 'Sparse (α = [0.5,0.5,0.5])'),
    ([10, 1, 1], 'Favor θ₁ (α = [10,1,1])'),
    ([1, 10, 1], 'Favor θ₂ (α = [1,10,1])'),
    ([2, 3, 5], 'Asymmetric (α = [2,3,5])'),
]

np.random.seed(42)
for ax, (alpha, title) in zip(axes.flat, params):
    plot_dirichlet_simplex(alpha, ax, title=title)

plt.suptitle('Dirichlet Distribution on 2-Simplex (K=3)', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

### Key Observations

| Parameters | Behavior |
|------------|----------|
| α = [1,1,1] | Uniform over simplex |
| α = [5,5,5] | Concentrated at center (equal proportions) |
| α = [0.5,0.5,0.5] | Sparse — samples near vertices/edges |
| α = [10,1,1] | Concentrated near θ₁=1 vertex |
| Asymmetric α | Concentrated toward higher-α components |
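The table can be checked numerically. One rough summary of "sparse vs. concentrated" is the size of the largest component in each draw: it is large when samples sit near a vertex and small when they cluster near the center. The sketch below compares the three symmetric settings; the 0.9 cutoff is an arbitrary threshold chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(42)
settings = {
    "sparse [0.5,0.5,0.5]": [0.5, 0.5, 0.5],
    "uniform [1,1,1]": [1, 1, 1],
    "concentrated [5,5,5]": [5, 5, 5],
}

summary = {}
for name, alpha in settings.items():
    s = rng.dirichlet(alpha, 20_000)
    largest = s.max(axis=1)                # dominant component of each draw
    near_vertex = (largest > 0.9).mean()   # share of draws close to a vertex
    summary[name] = (largest.mean(), near_vertex)
    print(f"{name:22s} mean max = {largest.mean():.3f}   P(max > 0.9) = {near_vertex:.3f}")
```

Smaller symmetric $\alpha$ pushes both numbers up, matching the "near vertices/edges" description of the sparse case.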

## 3. The Concentration Parameter

Just like Beta, Dirichlet has a **concentration** interpretation:

$$\alpha_0 = \sum_{k=1}^K \alpha_k \quad \text{(total concentration)}$$

We can decompose:
- **Base measure**: $\boldsymbol{\mu} = \frac{\boldsymbol{\alpha}}{\alpha_0}$ (expected proportions)
- **Concentration**: $\alpha_0$ (how tightly concentrated around $\boldsymbol{\mu}$)

This gives the alternative parameterization:
$$\text{Dirichlet}(\alpha_0 \cdot \boldsymbol{\mu})$$
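Converting between the two parameterizations is a one-liner in each direction. The helper names `to_alpha` and `from_alpha` below are ours, for illustration; the check that $\boldsymbol{\mu}$ sums to 1 is a defensive choice, not part of the definition.

```python
import numpy as np

def to_alpha(alpha_0, mu):
    """(concentration alpha_0, base measure mu) -> concentration vector alpha."""
    mu = np.asarray(mu, dtype=float)
    if not np.isclose(mu.sum(), 1.0):
        raise ValueError("mu must sum to 1")
    return alpha_0 * mu

def from_alpha(alpha):
    """Concentration vector alpha -> (alpha_0, mu)."""
    alpha = np.asarray(alpha, dtype=float)
    alpha_0 = alpha.sum()
    return alpha_0, alpha / alpha_0

alpha = to_alpha(20, [0.5, 0.3, 0.2])   # alpha_0 * mu
print(alpha)
print(from_alpha(alpha))
```

Holding $\boldsymbol{\mu}$ fixed and varying only $\alpha_0$ is exactly what the next cell does with `conc * base_measure`.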

In [None]:
# Effect of concentration with fixed base measure
base_measure = np.array([0.5, 0.3, 0.2])  # Expected proportions
concentrations = [1, 5, 20, 100]

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

np.random.seed(42)
for ax, conc in zip(axes.flat, concentrations):
    alpha = conc * base_measure
    plot_dirichlet_simplex(alpha, ax, 
                          title=f'Concentration = {conc}\nα = [{alpha[0]:.1f}, {alpha[1]:.1f}, {alpha[2]:.1f}]')
    
    # Mark the expected value
    exp_x = base_measure[1] + base_measure[2] / 2
    exp_y = base_measure[2] * np.sqrt(3) / 2
    ax.scatter([exp_x], [exp_y], c='red', s=100, marker='*', zorder=5, label='E[θ]')

plt.suptitle(f'Effect of Concentration (Base measure = {base_measure})', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

print("As concentration increases:")
print("  - Samples concentrate around the expected value")
print("  - Variance decreases")
print("  - Distribution becomes more 'certain'")

In [None]:
# Quantify the effect of concentration on variance
np.random.seed(42)  # make the empirical variances reproducible
base_measure = np.array([0.5, 0.3, 0.2])
concentrations = [1, 2, 5, 10, 20, 50, 100, 200]

variances = []
for conc in concentrations:
    alpha = conc * base_measure
    samples = np.random.dirichlet(alpha, 10000)
    # Variance of first component
    variances.append(samples[:, 0].var())

# Theoretical variance: Var(θ_k) = μ_k(1-μ_k) / (α_0 + 1)
theoretical_var = [base_measure[0] * (1 - base_measure[0]) / (c + 1) for c in concentrations]

fig, ax = plt.subplots(figsize=(10, 5))

ax.plot(concentrations, variances, 'bo-', lw=2, markersize=8, label='Empirical')
ax.plot(concentrations, theoretical_var, 'r--', lw=2, label='Theoretical')

ax.set_xlabel('Concentration (α₀)')
ax.set_ylabel('Variance of θ₁')
ax.set_title('Variance Decreases with Concentration\nVar(θₖ) = μₖ(1-μₖ) / (α₀ + 1)')
ax.set_xscale('log')
ax.set_yscale('log')
ax.legend()
ax.grid(True, alpha=0.3)

plt.show()

### Dirichlet Moments

**Mean:**
$$\mathbb{E}[\theta_k] = \frac{\alpha_k}{\alpha_0} = \mu_k$$

**Variance:**
$$\text{Var}(\theta_k) = \frac{\mu_k(1 - \mu_k)}{\alpha_0 + 1}$$

**Covariance:**
$$\text{Cov}(\theta_j, \theta_k) = \frac{-\mu_j \mu_k}{\alpha_0 + 1} \quad (j \neq k)$$

Note: for $j \neq k$ the components are **negatively correlated**: because they must sum to 1, a larger $\theta_j$ leaves less mass for the others.
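The three moment formulas combine into one covariance matrix, $\boldsymbol{\Sigma} = (\text{diag}(\boldsymbol{\mu}) - \boldsymbol{\mu}\boldsymbol{\mu}^\top)/(\alpha_0 + 1)$. The sketch below (the helper name `dirichlet_cov` is ours) builds it and compares it with the sample covariance. A useful consequence of the sum-to-one constraint: every row of $\boldsymbol{\Sigma}$ sums to zero, since $\mu_j(1-\mu_j) - \mu_j \sum_{k \neq j} \mu_k = 0$.

```python
import numpy as np

def dirichlet_cov(alpha):
    """Full covariance matrix of Dirichlet(alpha) from the moment formulas."""
    alpha = np.asarray(alpha, dtype=float)
    alpha_0 = alpha.sum()
    mu = alpha / alpha_0
    # Diagonal: mu_k (1 - mu_k) / (alpha_0 + 1); off-diagonal: -mu_j mu_k / (alpha_0 + 1)
    return (np.diag(mu) - np.outer(mu, mu)) / (alpha_0 + 1)

alpha = np.array([10.0, 6.0, 4.0])
cov_theory = dirichlet_cov(alpha)
samples = np.random.default_rng(0).dirichlet(alpha, 200_000)
cov_empirical = np.cov(samples, rowvar=False)

print(np.round(cov_theory, 5))
print(np.round(cov_empirical, 5))
print(cov_theory.sum(axis=1))   # zero up to rounding: components sum to 1
```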

## 4. Dirichlet-Multinomial Conjugacy

Just as Beta is conjugate to Binomial, **Dirichlet is conjugate to Multinomial**:

$$\text{Dirichlet prior} + \text{Multinomial data} = \text{Dirichlet posterior}$$

### Update Rule

If we observe counts $\mathbf{n} = (n_1, n_2, \ldots, n_K)$:

$$\boldsymbol{\alpha}_{\text{posterior}} = \boldsymbol{\alpha}_{\text{prior}} + \mathbf{n}$$

That's it! Just add the counts to each α.
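Two consequences of this rule are worth seeing in numbers. First, the update is order-independent: adding one observation at a time gives the same posterior as adding all counts at once. Second, with $N = \sum_k n_k$, the posterior mean $(\alpha_k + n_k)/(\alpha_0 + N)$ is a weighted blend $w\,\mu_k + (1 - w)\,n_k/N$ with $w = \alpha_0/(\alpha_0 + N)$, so the prior acts like $\alpha_0$ pseudo-counts. The sketch below reuses the counts from the next cell.

```python
import numpy as np

prior_alpha = np.array([2.0, 2.0, 2.0])
counts = np.array([45, 30, 25])

# Batch update: add all counts at once
post_batch = prior_alpha + counts

# Sequential update: one observed category at a time, in shuffled order
labels = np.repeat(np.arange(3), counts)      # 45 zeros, 30 ones, 25 twos
np.random.default_rng(0).shuffle(labels)
post_seq = prior_alpha.copy()
for k in labels:
    post_seq[k] += 1

# Posterior mean as a blend of the prior mean and the observed proportions
alpha_0, N = prior_alpha.sum(), counts.sum()
w = alpha_0 / (alpha_0 + N)                   # weight on the prior
post_mean = post_batch / post_batch.sum()
blend = w * prior_alpha / alpha_0 + (1 - w) * counts / N

print(post_seq, post_batch)
print(post_mean, blend, w)
```

Here $w = 6/106 \approx 0.057$: with 100 cells, the weak prior moves the estimate only slightly away from the raw proportions.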

In [None]:
# Bayesian updating example: Cell type composition

# Prior: We expect roughly equal proportions of 3 cell types
prior_alpha = np.array([2, 2, 2])  # Weak prior

# Data: We observe cell type counts
observed_counts = np.array([45, 30, 25])  # 100 cells total

# Posterior
posterior_alpha = prior_alpha + observed_counts

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

np.random.seed(42)

# Prior
plot_dirichlet_simplex(prior_alpha, axes[0], 
                       title=f'Prior: Dirichlet({prior_alpha.tolist()})')

# Data visualization (as a point)
data_prop = observed_counts / observed_counts.sum()
data_x = data_prop[1] + data_prop[2] / 2
data_y = data_prop[2] * np.sqrt(3) / 2

axes[1].scatter([data_x], [data_y], c='green', s=200, marker='X', zorder=5)
triangle = plt.Polygon([[0, 0], [1, 0], [0.5, np.sqrt(3)/2]], 
                       fill=False, edgecolor='black', linewidth=2)
axes[1].add_patch(triangle)
axes[1].text(-0.05, -0.05, 'Type A', fontsize=10, ha='center')
axes[1].text(1.05, -0.05, 'Type B', fontsize=10, ha='center')
axes[1].text(0.5, np.sqrt(3)/2 + 0.05, 'Type C', fontsize=10, ha='center')
axes[1].set_xlim(-0.1, 1.1)
axes[1].set_ylim(-0.1, 1.0)
axes[1].set_aspect('equal')
axes[1].axis('off')
axes[1].set_title(f'Data: {observed_counts[0]}A, {observed_counts[1]}B, {observed_counts[2]}C\n'
                  f'Proportions: [{data_prop[0]:.2f}, {data_prop[1]:.2f}, {data_prop[2]:.2f}]')

# Posterior
plot_dirichlet_simplex(posterior_alpha, axes[2],
                       title=f'Posterior: Dirichlet({posterior_alpha.tolist()})')
# Mark posterior mean
post_mean = posterior_alpha / posterior_alpha.sum()
post_x = post_mean[1] + post_mean[2] / 2
post_y = post_mean[2] * np.sqrt(3) / 2
axes[2].scatter([post_x], [post_y], c='red', s=100, marker='*', zorder=5)

plt.suptitle('Dirichlet-Multinomial Conjugacy: Cell Type Composition', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

print("Prior:")
print(f"  α = {prior_alpha}")
print(f"  E[θ] = {prior_alpha / prior_alpha.sum()}")

print(f"\nData: {observed_counts} (n={observed_counts.sum()})")

print(f"\nPosterior:")
print(f"  α = {posterior_alpha}")
print(f"  E[θ] = {posterior_alpha / posterior_alpha.sum()}")

## 5. Applications in Biology

### Application 1: Cell Type Deconvolution

Given bulk RNA-seq, estimate the proportion of each cell type.

In [None]:
# Simulating cell type deconvolution uncertainty
np.random.seed(42)

# True proportions (unknown in practice)
true_proportions = np.array([0.50, 0.30, 0.15, 0.05])  # 4 cell types
cell_types = ['T cells', 'B cells', 'Macrophages', 'NK cells']

# Stand-in for a deconvolution posterior: a Dirichlet whose alpha is roughly
# concentration * true proportions, perturbed with noise
# (in practice, alpha would come from the deconvolution model itself).
# Higher concentration = more confident estimate
concentration = 50
posterior_alpha = concentration * true_proportions + np.random.randn(4) * 2
posterior_alpha = np.maximum(posterior_alpha, 1)  # floor at 1 so every alpha_k stays positive

# Sample from posterior to get uncertainty
n_samples = 10000
proportion_samples = np.random.dirichlet(posterior_alpha, n_samples)

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

for i, (ax, cell_type) in enumerate(zip(axes.flat, cell_types)):
    samples = proportion_samples[:, i]
    
    ax.hist(samples, bins=50, density=True, alpha=0.7, edgecolor='black')
    ax.axvline(true_proportions[i], color='red', linestyle='--', lw=2, 
               label=f'True = {true_proportions[i]:.2f}')
    ax.axvline(samples.mean(), color='green', linestyle='-', lw=2,
               label=f'Mean = {samples.mean():.3f}')
    
    # 95% CI
    ci_low, ci_high = np.percentile(samples, [2.5, 97.5])
    ax.axvspan(ci_low, ci_high, alpha=0.2, color='green',
               label=f'95% CI: [{ci_low:.3f}, {ci_high:.3f}]')
    
    ax.set_xlabel('Proportion')
    ax.set_ylabel('Density')
    ax.set_title(f'{cell_type}')
    ax.legend(fontsize=9)

plt.suptitle('Cell Type Deconvolution: Posterior Uncertainty', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()
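Monte Carlo is not strictly needed for per-component intervals: each marginal of a Dirichlet is a Beta, $\theta_k \sim \text{Beta}(\alpha_k, \alpha_0 - \alpha_k)$. The sketch below assumes an illustrative posterior $\boldsymbol{\alpha} = 50 \times (0.50, 0.30, 0.15, 0.05)$, without the random perturbation used above, and compares exact 95% intervals from `stats.beta.ppf` with sample percentiles.

```python
import numpy as np
from scipy import stats

# Illustrative posterior assumed for this sketch: 50 * [0.50, 0.30, 0.15, 0.05]
alpha = np.array([25.0, 15.0, 7.5, 2.5])
alpha_0 = alpha.sum()

# Exact interval from the marginal Beta(alpha_k, alpha_0 - alpha_k)
exact_ci = np.array([stats.beta.ppf([0.025, 0.975], a, alpha_0 - a) for a in alpha])

# Monte Carlo interval for comparison
samples = np.random.default_rng(0).dirichlet(alpha, 100_000)
mc_ci = np.percentile(samples, [2.5, 97.5], axis=0).T

for k in range(len(alpha)):
    print(f"theta_{k+1}: exact [{exact_ci[k, 0]:.3f}, {exact_ci[k, 1]:.3f}]   "
          f"MC [{mc_ci[k, 0]:.3f}, {mc_ci[k, 1]:.3f}]")
```

Sampling remains the simpler route for joint questions (for example, ratios of two cell types), where no single Beta marginal applies.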

### Application 2: Treatment Response Subgroups

In causal ML, we might model the probability of belonging to different response subgroups.

In [None]:
# Treatment response subgroups
np.random.seed(42)

# Prior: We expect most patients to be moderate responders
# Subgroups: Non-responder, Low, Moderate, High responder
subgroups = ['Non-responder', 'Low', 'Moderate', 'High']
prior_alpha = np.array([1, 2, 5, 2])  # Prior favors moderate

# After observing trial data
observed_counts = np.array([15, 25, 40, 20])  # 100 patients
posterior_alpha = prior_alpha + observed_counts

# Sample from prior and posterior
n_samples = 10000
prior_samples = np.random.dirichlet(prior_alpha, n_samples)
posterior_samples = np.random.dirichlet(posterior_alpha, n_samples)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Prior
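# Tick labels are set separately: boxplot's `labels=` argument was renamed `tick_labels=` in Matplotlib 3.9
bp1 = axes[0].boxplot([prior_samples[:, i] for i in range(4)], patch_artist=True)
axes[0].set_xticklabels(subgroups)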
colors = ['#e74c3c', '#f39c12', '#2ecc71', '#3498db']
for patch, color in zip(bp1['boxes'], colors):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)
axes[0].set_ylabel('Proportion')
axes[0].set_title(f'Prior: Dirichlet({prior_alpha.tolist()})')
axes[0].set_ylim(0, 0.8)

# Posterior
bp2 = axes[1].boxplot([posterior_samples[:, i] for i in range(4)], patch_artist=True)
axes[1].set_xticklabels(subgroups)
for patch, color in zip(bp2['boxes'], colors):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)
axes[1].set_ylabel('Proportion')
axes[1].set_title(f'Posterior: Dirichlet({posterior_alpha.tolist()})')
axes[1].set_ylim(0, 0.8)

# Add observed proportions
obs_prop = observed_counts / observed_counts.sum()
axes[1].scatter(range(1, 5), obs_prop, color='red', s=100, marker='*', 
                zorder=5, label='Observed')
axes[1].legend()

plt.suptitle('Treatment Response Subgroups: Bayesian Inference', fontsize=14)
plt.tight_layout()
plt.show()

print("Posterior summary:")
for i, subgroup in enumerate(subgroups):
    mean = posterior_samples[:, i].mean()
    ci_low, ci_high = np.percentile(posterior_samples[:, i], [2.5, 97.5])
    print(f"  {subgroup:15s}: {mean:.3f} [{ci_low:.3f}, {ci_high:.3f}]")
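Because any function of $\boldsymbol{\theta}$ can be summarized from posterior draws, questions such as "which subgroup is most common?" or "what share of patients respond at all?" get direct probabilistic answers. The sketch below recomputes the posterior from the same prior and counts. "Responder share" is our label for $1 - \theta_{\text{Non-responder}}$ (Low + Moderate + High), chosen for illustration.

```python
import numpy as np

subgroups = ["Non-responder", "Low", "Moderate", "High"]
posterior_alpha = np.array([1, 2, 5, 2]) + np.array([15, 25, 40, 20])

samples = np.random.default_rng(42).dirichlet(posterior_alpha, 100_000)

# P(subgroup is the largest) = share of draws in which it has the max proportion
p_largest = np.bincount(samples.argmax(axis=1), minlength=4) / len(samples)

# Responder share (Low + Moderate + High) = 1 - theta_Non-responder
responders = 1 - samples[:, 0]

for name, p in zip(subgroups, p_largest):
    print(f"P({name} is largest) = {p:.3f}")
print(f"Responder share: {responders.mean():.3f} "
      f"[{np.percentile(responders, 2.5):.3f}, {np.percentile(responders, 97.5):.3f}]")
```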

## 6. Connection to Dirichlet Process

The Dirichlet distribution is finite-dimensional. The **Dirichlet Process (DP)** extends this to infinite dimensions:

| Dirichlet Distribution | Dirichlet Process |
|------------------------|-------------------|
| Fixed K categories | Infinite categories |
| $\boldsymbol{\theta} \sim \text{Dir}(\boldsymbol{\alpha})$ | $G \sim \text{DP}(\alpha_0, G_0)$ |
| Concentration: $\alpha_0 = \sum_k \alpha_k$ | Concentration: $\alpha_0$ |
| Base measure: $\boldsymbol{\mu} = \boldsymbol{\alpha}/\alpha_0$ | Base measure: $G_0$ |
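The correspondence in the table rests on the aggregation property: merging components of a Dirichlet vector gives another Dirichlet whose parameters are summed. A DP is defined so that, for any finite partition $(A_1, \ldots, A_K)$ of the space, $(G(A_1), \ldots, G(A_K)) \sim \text{Dir}(\alpha_0 G_0(A_1), \ldots, \alpha_0 G_0(A_K))$, and aggregation is what keeps these finite-dimensional distributions consistent when a partition is coarsened. The sketch below checks the property numerically for $K = 3$.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
alpha = np.array([2.0, 3.0, 5.0])
samples = rng.dirichlet(alpha, 100_000)

# Merge components 1 and 2: (theta_1 + theta_2, theta_3) ~ Dir(alpha_1 + alpha_2, alpha_3),
# so theta_1 + theta_2 ~ Beta(5, 5)
merged = samples[:, 0] + samples[:, 1]
a, b = alpha[0] + alpha[1], alpha[2]

print(f"empirical mean {merged.mean():.4f}  vs Beta mean {stats.beta.mean(a, b):.4f}")
print(f"empirical var  {merged.var():.5f}  vs Beta var  {stats.beta.var(a, b):.5f}")
# Kolmogorov-Smirnov distance to the Beta(5, 5) CDF; it should be small
print(stats.kstest(merged, stats.beta(a, b).cdf).statistic)
```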

### Why DP Matters for Causal ML

1. **Unknown number of subgroups** — DP lets data determine K
2. **Heterogeneous treatment effects** — Different subgroups have different CATEs
3. **Flexible outcome modeling** — DP mixture models for $Y(0)$ and $Y(1)$
4. **Clustering patients** — Discover latent subpopulations

### Stick-Breaking Construction (Preview)

The DP can be constructed via "stick-breaking":

$$\pi_k = V_k \prod_{j=1}^{k-1}(1 - V_j), \quad V_k \sim \text{Beta}(1, \alpha_0)$$

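This connects back to Beta: every break point $V_k$ is a Beta draw. Because $\mathbb{E}[V_k] = 1/(1+\alpha_0)$, a small $\alpha_0$ lets the first few breaks take most of the stick (few dominant clusters), while a large $\alpha_0$ breaks off small pieces and spreads the mass over many clusters.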
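Since the $V_k$ are independent, taking expectations in the construction gives closed forms for the average weights, $\mathbb{E}[\pi_k] = \frac{1}{1+\alpha_0}\left(\frac{\alpha_0}{1+\alpha_0}\right)^{k-1}$, and for the mass still unallocated after $K$ breaks, $\left(\frac{\alpha_0}{1+\alpha_0}\right)^{K}$. The sketch below (helper names are ours) checks both against a vectorized simulation. The leftover formula is one way to judge whether a truncation level such as `K=20` is large enough.

```python
import numpy as np

def expected_stick_weights(alpha_0, K):
    """E[pi_k] = (1/(1+alpha_0)) * (alpha_0/(1+alpha_0))**(k-1), for k = 1..K."""
    k = np.arange(1, K + 1)
    r = alpha_0 / (1.0 + alpha_0)   # E[1 - V_j]
    return (1.0 / (1.0 + alpha_0)) * r ** (k - 1)

def leftover_mass(alpha_0, K):
    """Expected stick length not yet allocated after K breaks."""
    return (alpha_0 / (1.0 + alpha_0)) ** K

rng = np.random.default_rng(1)
alpha_0, K, n_draws = 5.0, 20, 50_000
V = rng.beta(1.0, alpha_0, size=(n_draws, K))
# pi_k = V_k * prod_{j<k} (1 - V_j), computed for all draws at once
remaining = np.cumprod(np.hstack([np.ones((n_draws, 1)), 1 - V[:, :-1]]), axis=1)
pi = V * remaining

print(np.round(pi.mean(axis=0)[:5], 4))
print(np.round(expected_stick_weights(alpha_0, K)[:5], 4))
print(f"expected leftover after {K} breaks: {leftover_mass(alpha_0, K):.3f}")
```

For example, with $\alpha_0 = 20$ the expected leftover after 20 breaks is $(20/21)^{20} \approx 0.38$, so a 20-component truncation covers only about 62% of the stick on average.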

In [None]:
# Preview: Stick-breaking construction
def stick_breaking(alpha, K=20):
    """Generate stick-breaking weights."""
    betas = np.random.beta(1, alpha, K)
    weights = np.zeros(K)
    remaining = 1.0
    for k in range(K):
        weights[k] = betas[k] * remaining
        remaining *= (1 - betas[k])
    return weights

np.random.seed(42)
concentrations = [0.5, 1, 5, 20]

fig, axes = plt.subplots(2, 2, figsize=(12, 8))

for ax, alpha in zip(axes.flat, concentrations):
    # Multiple draws
    for _ in range(5):
        weights = stick_breaking(alpha, K=20)
        ax.bar(range(1, 21), weights, alpha=0.3, width=0.8)
    
    ax.set_xlabel('Component k')
    ax.set_ylabel('Weight πₖ')
    ax.set_title(f'Concentration α₀ = {alpha}')
    ax.set_xlim(0, 21)
    ax.set_ylim(0, 1)

plt.suptitle('Stick-Breaking: Effect of Concentration\n(Low α → few clusters, High α → many clusters)', 
             fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

print("Concentration controls cluster structure:")
print("  Low α₀ (e.g., 0.5): Few dominant clusters")
print("  High α₀ (e.g., 20): Many small clusters")

## 7. Quick Reference

### Key Formulas

```
Dirichlet(α₁, α₂, ..., αₖ)

Concentration: α₀ = Σαₖ
Base measure:  μₖ = αₖ / α₀

Mean:     E[θₖ] = αₖ / α₀
Variance: Var(θₖ) = μₖ(1-μₖ) / (α₀ + 1)
Covariance: Cov(θⱼ, θₖ) = -μⱼμₖ / (α₀ + 1),  j ≠ k

Bayesian update (Multinomial likelihood):
  Prior: Dirichlet(α)
  Data: counts n = (n₁, n₂, ..., nₖ)
  Posterior: Dirichlet(α + n)
```

### Special Cases

| α | Distribution |
|---|-------------|
| [1, 1, ..., 1] | Uniform on simplex |
| [α, α, ..., α] | Symmetric Dirichlet |
| K=2 | Beta distribution |

### Connections

- **Beta** = Dirichlet with K=2
- **Dirichlet Process** = infinite-dimensional Dirichlet
- **LDA** = Dirichlet prior on per-document topic proportions (and, in smoothed LDA, on per-topic word distributions)

In [None]:
# Helper function
def dirichlet_summary(alpha):
    """Print summary statistics for a Dirichlet distribution."""
    alpha = np.array(alpha)
    alpha_0 = alpha.sum()
    mu = alpha / alpha_0
    
    print(f"Dirichlet({alpha.tolist()}) Summary:")
    print(f"  Concentration (α₀): {alpha_0:.2f}")
    print(f"  Base measure (μ):   {mu}")
    print(f"\n  Component statistics:")
    for k in range(len(alpha)):
        var_k = mu[k] * (1 - mu[k]) / (alpha_0 + 1)
        print(f"    θ_{k+1}: E={mu[k]:.4f}, Var={var_k:.6f}")

# Example
dirichlet_summary([2, 3, 5])

---

## Summary

The Dirichlet distribution is essential for:

1. **Modeling probability vectors** (compositions that sum to 1)
2. **Bayesian inference** as conjugate prior for multinomial data
3. **Understanding concentration** — the key parameter controlling uncertainty
4. **Foundation for Dirichlet Process** — nonparametric Bayesian methods

**For causal ML:**
- Model cell type proportions in deconvolution
- Prior for treatment response subgroup proportions
- Foundation for DP mixture models that discover heterogeneous treatment effects

**Next:** See `05_dirichlet_process.ipynb` for the infinite-dimensional extension and its applications to clustering and causal inference.