# Dirichlet Process: Infinite-Dimensional Dirichlet

The **Dirichlet Process (DP)** is a distribution over probability distributions. It extends the finite Dirichlet to infinite dimensions, allowing the data to determine the number of clusters.

## Why This Matters for Causal ML

1. **Heterogeneous treatment effects** — Discover patient subgroups with different CATEs without pre-specifying K
2. **Flexible outcome modeling** — Nonparametric priors for potential outcomes $Y(0)$, $Y(1)$
3. **Latent confounders** — Model unknown confounding structure
4. **Clustering for stratification** — Data-driven patient stratification

## Table of Contents

1. [From Dirichlet to Dirichlet Process](#1-from-dirichlet-to-dirichlet-process)
2. [Stick-Breaking Construction](#2-stick-breaking-construction)
3. [Chinese Restaurant Process](#3-chinese-restaurant-process)
4. [The Concentration Parameter](#4-the-concentration-parameter)
5. [Posterior Inference](#5-posterior-inference)
6. [Connection to Causal ML](#6-connection-to-causal-ml)
7. [Quick Reference](#7-quick-reference)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from collections import Counter

# Style
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 12

## 1. From Dirichlet to Dirichlet Process

### The Limitation of Finite Dirichlet

With a finite Dirichlet distribution, we must specify the number of categories K upfront:

$$\boldsymbol{\theta} = (\theta_1, \ldots, \theta_K) \sim \text{Dirichlet}(\alpha_1, \ldots, \alpha_K)$$

**Problem:** What if we don't know K? What if K should grow with data?
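
To make the limitation concrete, here is a small sketch of drawing from a finite Dirichlet. The values of `K` and `alpha_vec` are illustrative assumptions, not taken from the text. The length of every draw is fixed by the length of the parameter vector, regardless of how much data arrives later.

```python
import numpy as np

rng = np.random.default_rng(0)

K = 4                      # must be chosen before seeing any data
alpha_vec = np.ones(K)     # symmetric Dirichlet(1, ..., 1); illustrative choice

# Three draws, each a length-K probability vector
theta = rng.dirichlet(alpha_vec, size=3)

print(theta.shape)         # (3, 4)
print(theta.sum(axis=1))   # each row sums to 1 (up to floating point)
```

Changing the number of categories means changing the prior itself, which is the constraint the DP removes.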

### The Dirichlet Process Solution

The DP is a distribution over **probability measures**:

$$G \sim \text{DP}(\alpha, G_0)$$

Where:
- $\alpha > 0$ is the **concentration parameter**
- $G_0$ is the **base measure**, the prior mean of $G$: $E[G(A)] = G_0(A)$ for any measurable set $A$
- $G$ is a random probability measure (discrete with probability 1)

### Key Property: Discreteness

Even if $G_0$ is continuous, samples from a DP are **discrete** (atomic). This is what enables clustering!

In [None]:
# Illustrate: DP draws are discrete even with continuous base measure
np.random.seed(42)

def sample_dp_naive(alpha, G0_sampler, n_samples):
    """
    Naive DP sampling via stick-breaking (for illustration).
    Returns atoms and their weights.
    """
    # Stick-breaking weights
    K = 100  # Truncation
    betas = np.random.beta(1, alpha, K)
    weights = np.zeros(K)
    remaining = 1.0
    for k in range(K):
        weights[k] = betas[k] * remaining
        remaining *= (1 - betas[k])
    
    # Atoms from base measure
    atoms = G0_sampler(K)
    
    # Sample from the discrete distribution; renormalize because the
    # truncated weights sum to slightly less than 1
    indices = np.random.choice(K, size=n_samples, p=weights/weights.sum())
    samples = atoms[indices]
    
    return samples, atoms, weights

# Base measure: standard normal
G0_sampler = lambda n: np.random.randn(n)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, alpha in zip(axes, [0.5, 2, 10]):
    samples, atoms, weights = sample_dp_naive(alpha, G0_sampler, 1000)
    
    # Plot histogram of samples
    ax.hist(samples, bins=50, density=True, alpha=0.7, edgecolor='black')
    
    # Overlay base measure
    x = np.linspace(-4, 4, 200)
    ax.plot(x, stats.norm.pdf(x), 'r--', lw=2, label='Base measure $G_0$')
    
    # Count unique values (clusters)
    n_unique = len(np.unique(np.round(samples, 6)))
    
    ax.set_xlabel('x')
    ax.set_ylabel('Density')
    ax.set_title(f'DP(α={alpha}, $G_0$=Normal)\n{n_unique} unique atoms in 1000 samples')
    ax.legend()

plt.suptitle('DP Samples are Discrete (Clustered)', fontsize=14, y=1.05)
plt.tight_layout()
plt.show()

## 2. Stick-Breaking Construction

The most intuitive way to understand DP is through **stick-breaking** (Sethuraman, 1994).

### The Construction

Imagine a stick of length 1. We break it sequentially:

1. Break off fraction $V_1 \sim \text{Beta}(1, \alpha)$ → weight $\pi_1 = V_1$
2. From remaining $(1-V_1)$, break off $V_2 \sim \text{Beta}(1, \alpha)$ → weight $\pi_2 = V_2(1-V_1)$
3. Continue forever...

$$\pi_k = V_k \prod_{j=1}^{k-1}(1 - V_j), \quad V_k \stackrel{iid}{\sim} \text{Beta}(1, \alpha)$$

Each weight $\pi_k$ is associated with an atom $\theta_k \sim G_0$:

$$G = \sum_{k=1}^{\infty} \pi_k \delta_{\theta_k}$$
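
In practice the infinite sum is truncated at some level $K$. Because the $V_k$ are independent with $E[1 - V_k] = \alpha/(1+\alpha)$, the expected mass left after $K$ breaks is $(\alpha/(1+\alpha))^K$. The sketch below uses this to pick a truncation level for a given tolerance. The helper names `expected_leftover` and `truncation_level` and the default tolerance are hypothetical, chosen only for illustration.

```python
import math

def expected_leftover(alpha, K):
    """Expected stick mass remaining after K breaks: (alpha / (1 + alpha))**K."""
    return (alpha / (1.0 + alpha)) ** K

def truncation_level(alpha, tol=1e-3):
    """Smallest K whose expected leftover mass is at most tol."""
    ratio = alpha / (1.0 + alpha)
    return max(1, math.ceil(math.log(tol) / math.log(ratio)))

for alpha in [0.5, 2, 10]:
    K = truncation_level(alpha)
    print(f"alpha={alpha}: K={K}, expected leftover={expected_leftover(alpha, K):.2e}")
```

Larger $\alpha$ spreads mass over more sticks, so it needs a deeper truncation to reach the same tolerance. This bounds only the expected leftover; an individual draw can leave more.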

In [None]:
def stick_breaking(alpha, K=50):
    """
    Generate stick-breaking weights for DP.
    
    Args:
        alpha: Concentration parameter
        K: Truncation level
        
    Returns:
        weights: Array of K weights summing to slightly less than 1
                 (the mass beyond the truncation level is dropped)
    """
    betas = np.random.beta(1, alpha, K)
    weights = np.zeros(K)
    remaining = 1.0
    
    for k in range(K):
        weights[k] = betas[k] * remaining
        remaining *= (1 - betas[k])
    
    return weights

# Visualize stick-breaking process
np.random.seed(42)
alpha = 2

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Top left: The stick-breaking process visually
ax = axes[0, 0]
K = 10
betas = np.random.beta(1, alpha, K)
weights = []
remaining = 1.0

colors = plt.cm.tab10(np.arange(K))
left = 0
for k in range(K):
    w = betas[k] * remaining
    weights.append(w)
    ax.barh(0, w, left=left, height=0.5, color=colors[k], edgecolor='black')
    if w > 0.03:
        ax.text(left + w/2, 0, f'π{k+1}', ha='center', va='center', fontsize=10)
    left += w
    remaining *= (1 - betas[k])

ax.set_xlim(0, 1)
ax.set_ylim(-0.5, 0.5)
ax.set_xlabel('Cumulative weight')
ax.set_title(f'Stick-Breaking Visualization (α={alpha})')
ax.set_yticks([])

# Top right: Weight distribution
ax = axes[0, 1]
ax.bar(range(1, K+1), weights, color=colors, edgecolor='black')
ax.set_xlabel('Component k')
ax.set_ylabel('Weight πₖ')
ax.set_title(f'Stick-Breaking Weights (sum = {sum(weights):.4f})')

# Bottom: Multiple draws for different α
for ax, alpha in zip(axes[1], [0.5, 10]):
    for i in range(5):
        weights = stick_breaking(alpha, K=30)
        ax.plot(range(1, 31), weights, 'o-', alpha=0.5, markersize=4)
    
    ax.set_xlabel('Component k')
    ax.set_ylabel('Weight πₖ')
    ax.set_title(f'Multiple Draws: α = {alpha}')
    ax.set_xlim(0, 31)

plt.tight_layout()
plt.show()

print("Key insight:")
print("  - Small α → first few components dominate (few clusters)")
print("  - Large α → weights spread across many components (many clusters)")

In [None]:
# Expected number of components with weight > threshold
np.random.seed(42)

alphas = [0.1, 0.5, 1, 2, 5, 10, 20]
thresholds = [0.01, 0.05, 0.1]

fig, ax = plt.subplots(figsize=(10, 6))

for thresh in thresholds:
    n_components = []
    for alpha in alphas:
        counts = []
        for _ in range(100):
            weights = stick_breaking(alpha, K=100)
            counts.append(np.sum(weights > thresh))
        n_components.append(np.mean(counts))
    
    ax.plot(alphas, n_components, 'o-', lw=2, markersize=8, 
            label=f'Weight > {thresh}')

ax.set_xlabel('Concentration α')
ax.set_ylabel('Expected number of components')
ax.set_title('Number of "Active" Components vs Concentration')
ax.legend()
ax.set_xscale('log')
ax.grid(True, alpha=0.3)

plt.show()

## 3. Chinese Restaurant Process

The **Chinese Restaurant Process (CRP)** is an equivalent way to understand DP, focused on clustering.

### The Metaphor

Imagine customers entering a restaurant with infinite tables:

1. **Customer 1** sits at table 1
2. **Customer n+1** either:
   - Joins existing table k with probability $\frac{n_k}{n + \alpha}$ (where $n_k$ = customers at table k)
   - Starts a new table with probability $\frac{\alpha}{n + \alpha}$
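
As a worked example of the seating rule, the sketch below computes the probabilities facing the next customer for one fixed seating. The helper name `crp_seating_probs` and the example state (tables of sizes 3 and 1, $\alpha = 2$) are assumptions made for illustration.

```python
def crp_seating_probs(table_counts, alpha):
    """Seating probabilities for the next customer.

    Returns one entry per existing table, plus a final entry
    for opening a new table.
    """
    n = sum(table_counts)                  # customers already seated
    denom = n + alpha
    probs = [n_k / denom for n_k in table_counts]
    probs.append(alpha / denom)            # new-table probability
    return probs

# 4 customers seated at tables of sizes 3 and 1, alpha = 2
probs = crp_seating_probs([3, 1], alpha=2)
print(probs)        # 3/6 for table 1, 1/6 for table 2, 2/6 for a new table
print(sum(probs))   # the probabilities sum to 1
```

The larger table is three times as likely to be chosen as the smaller one, and $\alpha$ acts like the head count of a phantom "new" table.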

### The "Rich Get Richer" Property

Popular tables attract more customers → **preferential attachment**

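This produces skewed cluster sizes: a few large tables hold most customers while many tables stay small. (Genuinely heavy, power-law tails arise in the Pitman-Yor generalization rather than in the DP itself.)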

In [None]:
def chinese_restaurant_process(n_customers, alpha):
    """
    Simulate the Chinese Restaurant Process.
    
    Args:
        n_customers: Number of customers
        alpha: Concentration parameter
        
    Returns:
        assignments: Table assignment for each customer
        table_counts: Number of customers at each table
    """
    assignments = []
    table_counts = []  # Number at each table
    
    for n in range(n_customers):
        if n == 0:
            # First customer starts table 1
            assignments.append(0)
            table_counts.append(1)
        else:
            # Compute probabilities
            probs = np.array(table_counts + [alpha]) / (n + alpha)
            
            # Sample table
            table = np.random.choice(len(probs), p=probs)
            
            if table == len(table_counts):
                # New table
                assignments.append(len(table_counts))
                table_counts.append(1)
            else:
                # Existing table
                assignments.append(table)
                table_counts[table] += 1
    
    return np.array(assignments), np.array(table_counts)

# Visualize CRP
np.random.seed(42)
n_customers = 100

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

for ax_row, alpha in zip(axes, [1, 10]):
    assignments, table_counts = chinese_restaurant_process(n_customers, alpha)
    n_tables = len(table_counts)
    
    # Seating arrangement
    ax_row[0].scatter(range(n_customers), assignments, c=assignments, cmap='tab20', s=30)
    ax_row[0].set_xlabel('Customer arrival order')
    ax_row[0].set_ylabel('Table assignment')
    ax_row[0].set_title(f'CRP(α={alpha}): Seating over time\n{n_tables} tables')
    
    # Table sizes
    sorted_counts = np.sort(table_counts)[::-1]
    ax_row[1].bar(range(1, n_tables+1), sorted_counts, color='steelblue', edgecolor='black')
    ax_row[1].set_xlabel('Table (sorted by size)')
    ax_row[1].set_ylabel('Number of customers')
    ax_row[1].set_title('Table Size Distribution')
    
    # Multiple runs: number of tables
    n_tables_runs = []
    for _ in range(500):
        _, counts = chinese_restaurant_process(n_customers, alpha)
        n_tables_runs.append(len(counts))
    
    ax_row[2].hist(n_tables_runs, bins=20, density=True, alpha=0.7, edgecolor='black')
    ax_row[2].axvline(np.mean(n_tables_runs), color='red', linestyle='--', 
                      label=f'Mean = {np.mean(n_tables_runs):.1f}')
    ax_row[2].set_xlabel('Number of tables')
    ax_row[2].set_ylabel('Density')
    ax_row[2].set_title('Distribution of #Tables (500 runs)')
    ax_row[2].legend()

plt.suptitle('Chinese Restaurant Process', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

In [None]:
# Expected number of tables: E[K] ≈ α log(1 + n/α)
np.random.seed(42)

n_values = [10, 50, 100, 500, 1000]
alphas = [0.5, 1, 2, 5, 10]

fig, ax = plt.subplots(figsize=(10, 6))

for alpha in alphas:
    empirical = []
    theoretical = []
    
    for n in n_values:
        # Empirical
        n_tables = [len(chinese_restaurant_process(n, alpha)[1]) for _ in range(100)]
        empirical.append(np.mean(n_tables))
        
        # Theoretical approximation
        theoretical.append(alpha * np.log(1 + n/alpha))
    
    ax.plot(n_values, empirical, 'o-', lw=2, label=f'α={alpha} (empirical)')
    ax.plot(n_values, theoretical, '--', alpha=0.5)

ax.set_xlabel('Number of customers (n)')
ax.set_ylabel('Expected number of tables')
ax.set_title('Expected Number of Clusters: E[K] ≈ α log(1 + n/α)')
ax.legend()
ax.set_xscale('log')
ax.grid(True, alpha=0.3)

plt.show()

print("Key result: Number of clusters grows logarithmically with n")
print("This is much slower than linear — DP is parsimonious!")

## 4. The Concentration Parameter

The concentration parameter $\alpha$ is the key hyperparameter:

| α | Effect |
|---|--------|
| Small (< 1) | Few large clusters, strong "rich get richer" |
| α = 1 | Moderate clustering |
| Large (> 10) | Many small clusters, closer to base measure |
| α → ∞ | Approaches iid samples from $G_0$ |

### Interpretation

- $\alpha$ = "pseudo-count" for a new cluster
- Probability of new cluster = $\frac{\alpha}{n + \alpha}$
- Expected number of clusters ≈ $\alpha \log(1 + n/\alpha)$, which grows like $\alpha \log n$ when $n \gg \alpha$
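
The logarithmic approximation $\alpha \log(1 + n/\alpha)$ used in this notebook can be checked against the exact expectation. Customer $i+1$ opens a new table with probability $\alpha/(\alpha + i)$, so by linearity of expectation $E[K_n] = \sum_{i=0}^{n-1} \alpha/(\alpha + i)$. A minimal sketch follows; the function names are hypothetical.

```python
import math

def expected_tables_exact(alpha, n):
    """Exact E[K_n] under CRP(alpha): sum of new-table probabilities."""
    return sum(alpha / (alpha + i) for i in range(n))

def expected_tables_approx(alpha, n):
    """Logarithmic approximation alpha * log(1 + n / alpha)."""
    return alpha * math.log(1 + n / alpha)

for alpha in [0.5, 2, 10]:
    for n in [10, 100, 1000]:
        exact = expected_tables_exact(alpha, n)
        approx = expected_tables_approx(alpha, n)
        print(f"alpha={alpha}, n={n}: exact={exact:.2f}, approx={approx:.2f}")
```

The exact value is never below the approximation, because the sum is a left Riemann sum of the decreasing function $\alpha/(\alpha + x)$ on $[0, n]$.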

In [None]:
# Comprehensive visualization of α effect
np.random.seed(42)
n_customers = 200
alphas = [0.1, 0.5, 1, 2, 5, 10, 20, 50]

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for ax, alpha in zip(axes.flat, alphas):
    assignments, table_counts = chinese_restaurant_process(n_customers, alpha)
    n_tables = len(table_counts)
    
    # Plot cluster sizes (sorted)
    sorted_counts = np.sort(table_counts)[::-1]
    ax.bar(range(1, min(21, n_tables+1)), sorted_counts[:20], 
           color='steelblue', edgecolor='black', alpha=0.7)
    
    ax.set_xlabel('Cluster rank')
    ax.set_ylabel('Size')
    ax.set_title(f'α = {alpha}\n{n_tables} clusters')
    ax.set_xlim(0, 21)

plt.suptitle(f'Effect of Concentration on Clustering (n={n_customers})', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

## 5. Posterior Inference

Given observations $x_1, \ldots, x_n$ drawn iid from $G$, the posterior over $G$ is again a DP (conjugacy):

$$G | x_1, \ldots, x_n \sim \text{DP}\left(\alpha + n, \frac{\alpha G_0 + \sum_{i=1}^n \delta_{x_i}}{\alpha + n}\right)$$

The posterior base measure is a mixture of:
- Prior base measure $G_0$ (weight $\alpha$)
- Empirical distribution of data (weight $n$)
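
As an illustration of this mixture, the posterior mean of $G((-\infty, t])$ is the same weighted average applied to the base CDF and the counts of observations at or below $t$. The sketch below assumes $G_0 = \mathcal{N}(0, 1)$, the base measure used in this notebook's examples. The helper name `posterior_mean_cdf` is hypothetical and the data values are made up.

```python
import math

def std_normal_cdf(t):
    """CDF of the standard normal, used here as the base measure G0."""
    return 0.5 * (1.0 + math.erf(t / math.sqrt(2.0)))

def posterior_mean_cdf(t, data, alpha):
    """E[G((-inf, t]) | data] under a DP(alpha, N(0, 1)) prior."""
    n = len(data)
    n_below = sum(1 for x in data if x <= t)   # observations at or below t
    return (alpha * std_normal_cdf(t) + n_below) / (alpha + n)

data = [0.5, -1.0, 2.0]   # made-up observations
print(posterior_mean_cdf(0.0, data, alpha=2.0))   # (2 * 0.5 + 1) / 5 = 0.4
```

With small $\alpha$ the estimate tracks the empirical CDF; with large $\alpha$ it stays close to $G_0$.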

### Predictive Distribution

For a new observation $x_{n+1}$:

$$x_{n+1} | x_1, \ldots, x_n \sim \frac{\alpha}{\alpha + n} G_0 + \frac{1}{\alpha + n} \sum_{i=1}^n \delta_{x_i}$$

This is the **Pólya urn** scheme: either draw from $G_0$ or copy an existing observation.

In [None]:
# Pólya urn / Blackwell-MacQueen scheme
def polya_urn(alpha, G0_sampler, n_samples):
    """
    Sample from DP using Pólya urn scheme.
    
    Each new sample either:
    - Comes from G0 (prob α/(α+n))
    - Copies an existing sample (prob n/(α+n))
    """
    samples = []
    
    for n in range(n_samples):
        if n == 0 or np.random.rand() < alpha / (alpha + n):
            # Draw from base measure
            samples.append(G0_sampler())
        else:
            # Copy existing sample
            samples.append(np.random.choice(samples))
    
    return np.array(samples)

# Visualize Pólya urn
np.random.seed(42)
G0_sampler = lambda: np.random.randn()

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, alpha in zip(axes, [0.5, 2, 10]):
    samples = polya_urn(alpha, G0_sampler, 500)
    
    ax.hist(samples, bins=50, density=True, alpha=0.7, edgecolor='black')
    
    # Base measure
    x = np.linspace(-4, 4, 200)
    ax.plot(x, stats.norm.pdf(x), 'r--', lw=2, label='$G_0$ = Normal(0,1)')
    
    n_unique = len(np.unique(np.round(samples, 10)))
    ax.set_xlabel('x')
    ax.set_ylabel('Density')
    ax.set_title(f'Pólya Urn (α={alpha})\n{n_unique} unique values')
    ax.legend()

plt.suptitle('Pólya Urn Sampling from DP', fontsize=14, y=1.05)
plt.tight_layout()
plt.show()

## 6. Connection to Causal ML

### Why DP Matters for Causal Inference

1. **Heterogeneous Treatment Effects (HTE)**
   - Patients may belong to unknown subgroups with different treatment effects
   - DP mixture models can discover these subgroups
   - No need to pre-specify the number of subgroups

2. **Flexible Outcome Modeling**
   - Potential outcomes $Y(0)$, $Y(1)$ may have complex distributions
   - DP mixtures provide flexible nonparametric priors

3. **Instrumental Variables**
   - Latent compliance types (compliers, always-takers, never-takers)
   - DP can model unknown compliance structure

4. **Confounding**
   - Latent confounders may cluster observations
   - DP mixtures can model this latent cluster structure; identifying causal effects still depends on the usual assumptions
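
To see why subgroup discovery matters, note that the ATE is the proportion-weighted average of subgroup effects, so it can hide large effects of opposite sign. A minimal sketch using the subgroup proportions and effects from this section's simulation:

```python
proportions = [0.5, 0.3, 0.2]   # subgroup shares used in the simulation
effects = [0, 5, -3]            # non-responders, responders, adverse

# ATE = sum over subgroups of P(subgroup) * effect in that subgroup
ate = sum(p * e for p, e in zip(proportions, effects))
print(round(ate, 2))   # 0.9
```

An average effect of 0.9 describes none of the three subgroups well, which is the motivation for modeling the mixture directly.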

In [None]:
# Example: Discovering treatment effect heterogeneity
np.random.seed(42)

# True subgroups (unknown to us)
n_patients = 300
true_subgroups = np.random.choice([0, 1, 2], n_patients, p=[0.5, 0.3, 0.2])

# True treatment effects by subgroup
true_effects = {0: 0, 1: 5, 2: -3}  # Non-responders, responders, adverse

# Generate data
treatment = np.random.binomial(1, 0.5, n_patients)
baseline = np.random.randn(n_patients) * 2
noise = np.random.randn(n_patients) * 1

# Outcome depends on subgroup
individual_effects = np.array([true_effects[s] for s in true_subgroups])
outcome = baseline + treatment * individual_effects + noise

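# Difference-in-means ATE estimate. Treatment is randomized, so this targets the
# true ATE of 0.5*0 + 0.3*5 + 0.2*(-3) = 0.9, but it hides the subgroup effects.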
naive_ate = outcome[treatment == 1].mean() - outcome[treatment == 0].mean()

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# True subgroups
colors = ['#e74c3c', '#2ecc71', '#3498db']
for s in range(3):
    mask = true_subgroups == s
    axes[0].scatter(baseline[mask], outcome[mask], c=colors[s], 
                    alpha=0.5, label=f'Subgroup {s} (effect={true_effects[s]})')
axes[0].set_xlabel('Baseline')
axes[0].set_ylabel('Outcome')
axes[0].set_title('True Subgroups (Unknown)')
axes[0].legend()

# What we observe
axes[1].scatter(baseline[treatment==0], outcome[treatment==0], 
                alpha=0.5, label='Control', c='steelblue')
axes[1].scatter(baseline[treatment==1], outcome[treatment==1], 
                alpha=0.5, label='Treated', c='coral')
axes[1].set_xlabel('Baseline')
axes[1].set_ylabel('Outcome')
axes[1].set_title(f'Observed Data\nNaive ATE = {naive_ate:.2f}')
axes[1].legend()

# Distribution of individual effects
axes[2].hist(individual_effects, bins=20, density=True, alpha=0.7, edgecolor='black')
axes[2].axvline(naive_ate, color='red', linestyle='--', lw=2, label=f'Naive ATE = {naive_ate:.2f}')
axes[2].axvline(0, color='black', linestyle='-', alpha=0.3)
axes[2].set_xlabel('Individual Treatment Effect')
axes[2].set_ylabel('Density')
axes[2].set_title('True Distribution of Effects\n(Heterogeneous!)')
axes[2].legend()

plt.suptitle('Treatment Effect Heterogeneity: Why DP Matters', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

print("Key insight:")
print(f"  - Naive ATE = {naive_ate:.2f} (average across all subgroups)")
print(f"  - But true effects are: {list(true_effects.values())}")
print("  - DP mixture models can discover these subgroups!")
print("  - This enables personalized treatment decisions.")

## 7. Quick Reference

### Key Concepts

| Concept | Description |
|---------|-------------|
| **DP(α, G₀)** | Distribution over distributions |
| **Concentration α** | Controls number of clusters |
| **Base measure G₀** | Prior for cluster locations |
| **Stick-breaking** | Constructive definition via Beta draws |
| **CRP** | Sequential clustering metaphor |
| **Pólya urn** | Predictive distribution |

### Key Formulas

```
Stick-breaking:
  Vₖ ~ Beta(1, α)
  πₖ = Vₖ ∏ⱼ₌₁ᵏ⁻¹ (1 - Vⱼ)
  θₖ ~ G₀
  G = Σₖ πₖ δ_θₖ

CRP:
  P(new table) = α / (n + α)
  P(table k) = nₖ / (n + α)

Expected clusters:
  E[K] ≈ α log(1 + n/α)
```

### When to Use DP

- Unknown number of clusters
- Want data to determine complexity
- Flexible nonparametric modeling
- Discovering latent subgroups

In [None]:
# Helper functions for future use

def dp_summary(alpha, n):
    """Print summary for DP with given α and n observations."""
    expected_clusters = alpha * np.log(1 + n/alpha)
    prob_new = alpha / (alpha + n)
    
    print(f"DP(α={alpha}) with n={n} observations:")
    print(f"  Expected clusters: {expected_clusters:.2f}")
    print(f"  P(new cluster for n+1): {prob_new:.4f}")
    print(f"  P(join existing): {1-prob_new:.4f}")

# Example
dp_summary(alpha=2, n=100)

---

## Summary

The Dirichlet Process is essential for:

1. **Nonparametric clustering** — let data determine K
2. **Flexible density estimation** — DP mixture models
3. **Bayesian nonparametrics** — infinite-dimensional priors

**For causal ML:**
- Discover patient subgroups with heterogeneous treatment effects
- Model complex outcome distributions
- Handle unknown confounding structure

**Key insight:** The concentration parameter α controls the "rich get richer" dynamics:
- Small α → few dominant clusters
- Large α → many small clusters

**Next:** See `06_dp_mixture_models.ipynb` for practical applications to clustering and causal inference.