# Dirichlet Process Mixture Models

**DP Mixture Models (DPMMs)** combine the Dirichlet Process with parametric mixture components for flexible, nonparametric clustering.

## Why This Matters for Causal ML

1. **Heterogeneous Treatment Effects** — Discover patient subgroups with different CATEs
2. **Flexible Outcome Modeling** — Nonparametric priors for potential outcomes
3. **Clustering for Stratification** — Data-driven patient stratification

## Table of Contents

1. [DPMM Overview](#1-dpmm-overview)
2. [Gaussian DPMM Implementation](#2-gaussian-dpmm-implementation)
3. [Application: Treatment Effect Heterogeneity](#3-application-treatment-effect-heterogeneity)
4. [Using sklearn](#4-using-sklearn)
5. [Quick Reference](#5-quick-reference)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from scipy.special import logsumexp
from collections import Counter

# Notebook-wide plotting defaults, plus a fixed seed for NumPy's global RNG
# so the random draws below are reproducible.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({
    'figure.figsize': (10, 6),
    'font.size': 12,
})
np.random.seed(42)

## 1. DPMM Overview

### The Model

$$G \sim \text{DP}(\alpha, G_0)$$
$$\theta_i | G \sim G$$
$$x_i | \theta_i \sim F(\theta_i)$$

Where:
- $G_0$ is the base measure (prior over component parameters)
- $F(\theta)$ is the likelihood (e.g., Gaussian)
- A draw $G$ from a DP is discrete with probability 1, so multiple observations share the same $\theta$ → clustering

In [None]:
# Simulate a 1D Gaussian mixture with four components.
true_means = [-3, 0, 2, 5]
true_stds = [0.5, 0.8, 0.4, 0.6]
true_weights = [0.3, 0.25, 0.25, 0.2]

n_samples = 300
# First draw a component label per sample, then one Gaussian draw per label
# (one RNG call per sample, in sample order).
true_assignments = np.random.choice(len(true_weights), size=n_samples, p=true_weights)
data = np.array([
    true_means[k] + true_stds[k] * np.random.randn()
    for k in true_assignments
])

# Histogram of the sample with each weighted component density on top.
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(data, bins=40, density=True, alpha=0.5, edgecolor='black')

x = np.linspace(-6, 8, 200)
for k, (mu, std, w) in enumerate(zip(true_means, true_stds, true_weights)):
    weighted_density = w * stats.norm.pdf(x, mu, std)
    ax.plot(x, weighted_density, '--', lw=2, label=f'Component {k}')

ax.set(xlabel='x', ylabel='Density', title='True Mixture (K=4)')
ax.legend()
plt.show()

## 2. Gaussian DPMM Implementation

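A simple collapsed Gibbs sampler for a 1D Gaussian DPMM. Each component has a known, shared standard deviation `sigma`, and each component mean has a Normal(`mu0`, `sigma0`²) prior; the means are integrated out analytically, so only the cluster assignments are sampled.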

In [None]:
class GaussianDPMM:
    """1D Gaussian Dirichlet-process mixture fitted by collapsed Gibbs sampling.

    Every component is Normal(mean_k, sigma**2) with a known, shared
    ``sigma``. Component means have a conjugate Normal(mu0, sigma0**2) prior
    and are integrated out, so the sampler only updates cluster assignments.

    Parameters
    ----------
    alpha : float
        DP concentration; the weight given to opening a new cluster.
    mu0, sigma0 : float
        Mean and standard deviation of the prior on component means.
    sigma : float
        Known within-component standard deviation.

    Attributes
    ----------
    assignments_ : ndarray of int
        Cluster labels (0..K-1) after the last sweep; set by ``fit``.
    history_ : list of int
        Number of clusters after each sweep; set by ``fit``.
    """

    def __init__(self, alpha=1.0, mu0=0.0, sigma0=3.0, sigma=1.0):
        self.alpha = alpha
        self.mu0 = mu0
        self.sigma0 = sigma0
        self.sigma = sigma

    def _posterior_params(self, data_k):
        """Return (mean, std) of the posterior over a component mean.

        With no observations the prior ``(mu0, sigma0)`` is returned.
        """
        n = len(data_k)
        if n == 0:
            return self.mu0, self.sigma0
        precision0 = 1 / self.sigma0**2
        precision_data = n / self.sigma**2
        precision_post = precision0 + precision_data
        sigma_post = np.sqrt(1 / precision_post)
        mu_post = (precision0 * self.mu0 + precision_data * np.mean(data_k)) / precision_post
        return mu_post, sigma_post

    def _predictive_params(self, data_k):
        """Return (mean, std) of the posterior predictive given ``data_k``."""
        mu_post, sigma_post = self._posterior_params(data_k)
        # Predictive variance = uncertainty about the mean + observation noise.
        return mu_post, np.sqrt(sigma_post**2 + self.sigma**2)

    def _log_marginal_likelihood(self, x, data_k):
        """Log posterior-predictive density of ``x`` given cluster ``data_k``.

        Computed directly in log space so that points far from a cluster give
        a large negative number instead of ``log(0) = -inf``.
        """
        mu_pred, std_pred = self._predictive_params(data_k)
        return stats.norm.logpdf(x, mu_pred, std_pred)

    def _marginal_likelihood(self, x, data_k):
        """Posterior-predictive density of ``x`` given cluster ``data_k``."""
        mu_pred, std_pred = self._predictive_params(data_k)
        return stats.norm.pdf(x, mu_pred, std_pred)

    def fit(self, data, n_iter=100):
        """Run ``n_iter`` Gibbs sweeps and return the final cluster labels.

        Parameters
        ----------
        data : array-like of float, shape (n,)
            1D observations; lists are accepted.
        n_iter : int
            Number of full sweeps over the data.

        Returns
        -------
        ndarray of int, shape (n,)
            Labels relabelled to 0..K-1; also stored as ``assignments_``.

        Raises
        ------
        ValueError
            If ``data`` is not one-dimensional.

        Notes
        -----
        Uses NumPy's global random state (``np.random``) and prints the
        cluster count every 25 sweeps.
        """
        data = np.asarray(data, dtype=float)
        if data.ndim != 1:
            raise ValueError("GaussianDPMM.fit expects a 1-D array of observations")

        n = len(data)
        assignments = np.arange(n)  # start with every point in its own cluster
        history = []
        log_alpha = np.log(self.alpha)

        for iteration in range(n_iter):
            for i in np.random.permutation(n):
                # Take point i out of its cluster; -1 marks "unassigned".
                # The slot is always overwritten below, so no copy is needed.
                assignments[i] = -1
                labels = np.unique(assignments)
                labels = labels[labels != -1]

                # CRP weight (cluster size) times predictive density, in logs.
                log_probs = []
                for k in labels:
                    members = assignments == k
                    log_probs.append(
                        np.log(members.sum())
                        + self._log_marginal_likelihood(data[i], data[members])
                    )

                # Option of opening a new cluster, scored under the prior.
                log_probs.append(log_alpha + self._log_marginal_likelihood(data[i], []))
                new_label = labels[-1] + 1 if labels.size else 0
                candidates = list(labels) + [new_label]

                log_probs = np.array(log_probs)
                probs = np.exp(log_probs - logsumexp(log_probs))
                assignments[i] = np.random.choice(candidates, p=probs)

            # Relabel to contiguous 0..K-1, preserving the sorted label order.
            unique_labels, assignments = np.unique(assignments, return_inverse=True)
            history.append(len(unique_labels))

            if (iteration + 1) % 25 == 0:
                print(f"Iter {iteration+1}: {history[-1]} clusters")

        self.assignments_ = assignments
        self.history_ = history
        return assignments

In [None]:
# Run the Gibbs sampler on the simulated mixture sample.
model = GaussianDPMM(alpha=2.0, mu0=0.0, sigma0=5.0, sigma=0.6)
assignments = model.fit(data, n_iter=100)

# Number of distinct labels left after the final sweep.
n_clusters = len(np.unique(assignments))
print(f"\nFound {n_clusters} clusters (true: 4)")

In [None]:
# Visualize results: convergence trace, inferred clusters, and cluster means.
# The true component count is read from `true_means` rather than hard-coded,
# so the reference line and bar positions stay aligned with the data.
n_true = len(true_means)
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Cluster count over iterations
axes[0].plot(model.history_, 'b-', lw=2)
axes[0].axhline(n_true, color='red', linestyle='--', label=f'True K={n_true}')
axes[0].set_xlabel('Iteration')
axes[0].set_ylabel('Number of clusters')
axes[0].set_title('Convergence')
axes[0].legend()

# Inferred clusters. tab10 has only 10 entries and integer indices past the
# end all map to one colour, so wrap the index to keep clusters distinct.
colors = plt.cm.tab10(np.arange(n_clusters) % plt.cm.tab10.N)
sample_index = np.arange(len(data))
for k in range(n_clusters):
    mask = assignments == k
    axes[1].scatter(sample_index[mask], data[mask], c=[colors[k]], alpha=0.6, s=20)
axes[1].set_xlabel('Index')
axes[1].set_ylabel('x')
axes[1].set_title(f'Inferred Clusters (K={n_clusters})')

# Compare means: both lists are sorted so bars pair up by rank, not by label.
inferred_means = sorted([data[assignments == k].mean() for k in range(n_clusters)])
axes[2].bar(np.arange(n_true) - 0.2, sorted(true_means), width=0.4, alpha=0.7, label='True')
axes[2].bar(np.arange(len(inferred_means)) + 0.2, inferred_means, width=0.4, alpha=0.7, label='Inferred')
axes[2].set_xlabel('Cluster')
axes[2].set_ylabel('Mean')
axes[2].set_title('Cluster Means')
axes[2].legend()

plt.tight_layout()
plt.show()

## 3. Application: Treatment Effect Heterogeneity

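Discovering patient subgroups with different treatment effects. Note that this demo clusters the simulated individual effects (`true_effects`) directly; with real data individual treatment effects are never observed, so the DPMM would have to be fit to estimated effects or embedded in an outcome model instead.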

In [None]:
# Simulate a randomized trial whose treatment effect depends on a latent
# subgroup. RNG calls happen in the order: subgroup, treatment, baseline, noise.
np.random.seed(42)
n_patients = 400

# Latent subgroups: effect size per subgroup and subgroup prevalence.
subgroup_effects = {0: 0.0, 1: 3.0, 2: 8.0, 3: -2.0}
subgroup_probs = [0.35, 0.30, 0.20, 0.15]

true_subgroups = np.random.choice(len(subgroup_probs), size=n_patients, p=subgroup_probs)
# Lookup table indexed by subgroup id, then gathered per patient.
effect_by_subgroup = np.array([subgroup_effects[s] for s in sorted(subgroup_effects)])
true_effects = effect_by_subgroup[true_subgroups]

# Treatment assignment (fair coin) and the ingredients of the outcome.
treatment = np.random.binomial(1, 0.5, size=n_patients)
baseline = 10 + 2 * np.random.randn(n_patients)
noise = 1.5 * np.random.randn(n_patients)

# Potential outcomes; only the one matching the assigned arm is observed.
Y0 = baseline + noise
Y1 = baseline + true_effects + noise
Y_obs = treatment * Y1 + (1 - treatment) * Y0

# Difference in observed arm means vs. the average of the simulated effects.
treated = treatment == 1
naive_ate = Y_obs[treated].mean() - Y_obs[~treated].mean()
true_ate = np.mean(true_effects)

print(f"True ATE: {true_ate:.2f}")
print(f"Naive ATE: {naive_ate:.2f}")
print(f"True subgroup effects: {list(subgroup_effects.values())}")

In [None]:
# Two panels: the overall distribution of simulated effects, and the same
# effects plotted against baseline and coloured by subgroup.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
ax_dist, ax_groups = axes

# Effect distribution with the average effect marked.
ax_dist.hist(true_effects, bins=20, density=True, alpha=0.7, edgecolor='black')
ax_dist.axvline(true_ate, color='red', linestyle='--', lw=2, label=f'ATE = {true_ate:.2f}')
ax_dist.set(
    xlabel='Treatment Effect',
    ylabel='Density',
    title='True Effect Distribution (Heterogeneous)',
)
ax_dist.legend()

# One scatter series per subgroup so each gets its own legend entry.
colors = plt.cm.tab10(np.arange(4))
for s, color in enumerate(colors):
    mask = true_subgroups == s
    ax_groups.scatter(
        baseline[mask],
        true_effects[mask],
        c=[color],
        alpha=0.5,
        s=30,
        label=f'Subgroup {s} (τ={subgroup_effects[s]})',
    )
ax_groups.set(xlabel='Baseline', ylabel='Treatment Effect', title='Effects by Subgroup')
ax_groups.legend()

plt.tight_layout()
plt.show()

In [None]:
# Cluster the per-patient effects with the DPMM.
# NOTE: `true_effects` are the simulated ground-truth effects, so this is an
# oracle illustration; such per-patient effects are not observed in real data.
model_hte = GaussianDPMM(alpha=2.0, mu0=0.0, sigma0=5.0, sigma=0.5)
inferred = model_hte.fit(true_effects, n_iter=100)

# `fit` returns contiguous labels 0..K-1, so the unique labels enumerate them.
subgroup_labels = np.unique(inferred)
n_found = len(subgroup_labels)
print(f"\nFound {n_found} subgroups")

# Size and mean effect of each recovered subgroup.
for k in subgroup_labels:
    mask = inferred == k
    print(f"  Subgroup {k}: n={np.sum(mask)}, effect={true_effects[mask].mean():.2f}")

## 4. Using sklearn

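For practical use, sklearn provides `BayesianGaussianMixture`, which fits a truncated Dirichlet-process mixture by variational inference; `n_components` is the truncation level, i.e. an upper bound on the number of clusters.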

In [None]:
from sklearn.mixture import BayesianGaussianMixture

# The sample is reshaped to a single-feature column (n_samples, 1) once and
# reused for both fitting and prediction.
data_column = data.reshape(-1, 1)

# n_components is an upper bound on the cluster count;
# weight_concentration_prior is the concentration parameter.
bgm = BayesianGaussianMixture(
    n_components=10,
    weight_concentration_prior_type='dirichlet_process',
    weight_concentration_prior=2.0,
    random_state=42,
)
bgm.fit(data_column)
sklearn_assignments = bgm.predict(data_column)

# Treat components with mixture weight above 1% as "active".
weights = bgm.weights_
active = (weights > 0.01).sum()
print(f"sklearn found {active} active components")
print(f"Weights: {weights[weights > 0.01].round(3)}")

In [None]:
# Side-by-side scatter of the cluster labels from the Gibbs sampler (left)
# and from sklearn (right); each cluster is drawn as its own series so it
# picks up the next colour in the default cycle.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
ax_ours, ax_sklearn = axes
sample_index = np.arange(len(data))

# Our implementation
for k in range(n_clusters):
    mask = assignments == k
    ax_ours.scatter(sample_index[mask], data[mask], alpha=0.6, s=20)
ax_ours.set(title=f'Our DPMM: {n_clusters} clusters', xlabel='Index', ylabel='x')

# sklearn: iterate only over labels that were actually predicted.
for k in np.unique(sklearn_assignments):
    mask = sklearn_assignments == k
    ax_sklearn.scatter(sample_index[mask], data[mask], alpha=0.6, s=20)
ax_sklearn.set(
    title=f'sklearn BGM: {len(np.unique(sklearn_assignments))} clusters',
    xlabel='Index',
    ylabel='x',
)

plt.tight_layout()
plt.show()

## 5. Quick Reference

### Key Concepts

| Concept | Description |
|---------|-------------|
| **DPMM** | DP prior + parametric likelihood |
| **Concentration α** | Controls the expected number of clusters (larger α → more clusters; roughly α·log(1 + n/α) for n points) |
| **Collapsed Gibbs** | Integrate out component parameters |
| **sklearn BGM** | Practical implementation with DP prior |

### When to Use DPMM

- Unknown number of clusters
- Heterogeneous treatment effects
- Flexible density estimation
- Patient stratification

### sklearn Usage

```python
from sklearn.mixture import BayesianGaussianMixture

model = BayesianGaussianMixture(
    n_components=20,  # Upper bound
    weight_concentration_prior_type='dirichlet_process',
    weight_concentration_prior=alpha
)
model.fit(X)
labels = model.predict(X)
```

---

## Summary

DPMMs are powerful for:

1. **Nonparametric clustering** — data determines K
2. **Treatment effect heterogeneity** — discover patient subgroups
3. **Flexible modeling** — complex outcome distributions

**For causal ML:** Use DPMMs to discover latent subgroups with different treatment effects, enabling personalized treatment decisions.

**Next steps:** Return to core causal ML topics (ATE, CATE, causal graphs) in the main project.