# Diffusion Models for Gene Expression: DDPM Tutorial

**Goal**: Implement a denoising diffusion probabilistic model (DDPM) for generating gene expression profiles.

This notebook demonstrates:
1. Core DDPM mechanics (forward/reverse diffusion)
2. Training a time-conditional score network on gene expression data
3. Conditional generation (cell type → gene expression)
4. Foundation for drug-response prediction (scPPDM approach)

**Dataset**: PBMC 3k (~2,700 cells), restricted to 500 highly variable genes for fast iteration

**Next steps**: Extend to perturbation response (baseline + drug → perturbed expression)

---

## Prerequisites

**Environment**: Make sure you're in the `genailab` conda environment:

```bash
mamba activate genailab
```

**Required packages**: torch, scanpy, numpy, pandas, matplotlib, seaborn, tqdm, and leidenalg (used by `sc.tl.leiden`)

## Setup

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import scanpy as sc
from tqdm.auto import tqdm

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

## 1. Load and Prepare Gene Expression Data

We'll use PBMC 3k reduced to 500 highly variable genes so that training is fast. A production model would keep more genes and typically train on more cells.

In [None]:
# Load PBMC 3k data
data_path = Path("../data/pbmc3k_raw.h5ad")

if data_path.exists():
    adata = sc.read_h5ad(data_path)
    print(f"Loaded data: {adata.shape}")
else:
    # Download if not available
    adata = sc.datasets.pbmc3k()
    print(f"Downloaded PBMC 3k: {adata.shape}")

# Basic preprocessing
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_genes(adata, min_cells=3)

# Normalize and log-transform for diffusion model
# Note: For count-based models (like NB VAE), we'd use raw counts
# For diffusion, we typically work with normalized continuous data
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# Select highly variable genes for faster training.
# flavor='seurat' expects log-normalized data (as we have here); 'seurat_v3' expects raw counts.
sc.pp.highly_variable_genes(adata, n_top_genes=500, flavor='seurat')
adata = adata[:, adata.var.highly_variable].copy()

print(f"After preprocessing: {adata.shape}")
print(f"Gene expression range: [{adata.X.min():.2f}, {adata.X.max():.2f}]")

In [None]:
# Cluster cells with Leiden; the clusters serve as proxy cell-type labels for conditioning
sc.pp.pca(adata, n_comps=50)
sc.pp.neighbors(adata, n_neighbors=10)
sc.tl.leiden(adata, resolution=0.5)

# Store cluster labels
cell_types = adata.obs['leiden'].values
n_cell_types = len(np.unique(cell_types))

print(f"Found {n_cell_types} Leiden clusters")
print(adata.obs['leiden'].value_counts())

## 2. Create PyTorch Dataset

We'll create a dataset that returns:
- Gene expression vector (x)
- Cell type label (condition)

In [None]:
class GeneExpressionDataset(Dataset):
    """Dataset for gene expression with optional conditioning."""
    
    def __init__(self, adata, condition_key=None):
        """
        Args:
            adata: AnnData object with preprocessed gene expression
            condition_key: Key in adata.obs for conditioning (e.g., 'leiden', 'treatment')
        """
        # Convert to dense array if sparse
        if hasattr(adata.X, 'toarray'):
            self.X = adata.X.toarray()
        else:
            self.X = adata.X
        
        self.X = torch.FloatTensor(self.X)
        
        # Extract conditions if provided
        if condition_key is not None:
            conditions = adata.obs[condition_key].values
            # Convert to integer labels
            unique_conditions = np.unique(conditions)
            condition_to_idx = {c: i for i, c in enumerate(unique_conditions)}
            self.conditions = torch.LongTensor([condition_to_idx[c] for c in conditions])
            self.n_conditions = len(unique_conditions)
        else:
            self.conditions = None
            self.n_conditions = 0
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        if self.conditions is not None:
            return self.X[idx], self.conditions[idx]
        return self.X[idx]

# Create dataset
dataset = GeneExpressionDataset(adata, condition_key='leiden')
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

print(f"Dataset size: {len(dataset)}")
print(f"Gene dimension: {dataset.X.shape[1]}")
print(f"Number of conditions: {dataset.n_conditions}")

## 3. Implement DDPM Components

### 3.1 Noise Scheduler

The noise scheduler defines how we add noise in the forward process:
- $\beta_t$: variance schedule (how much noise to add at each step)
- $\alpha_t = 1 - \beta_t$
- $\bar{\alpha}_t = \prod_{i=1}^t \alpha_i$ (cumulative product)
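These quantities give the closed-form marginal that `add_noise` implements, and the posterior variance used later during sampling. The identities below are the standard DDPM ones, written with this notebook's symbols (math steps $t = 1, \dots, T$ with the convention $\bar{\alpha}_0 := 1$). In the code, timestep indices run from 0 to T-1 and `alphas_cumprod_prev` supplies $\bar{\alpha}_{t-1}$, padded with 1 at the first step.

$$
\begin{aligned}
q(x_t \mid x_{t-1}) &= \mathcal{N}\!\left(x_t;\ \sqrt{\alpha_t}\,x_{t-1},\ \beta_t I\right) \\
q(x_t \mid x_0) &= \mathcal{N}\!\left(x_t;\ \sqrt{\bar{\alpha}_t}\,x_0,\ (1-\bar{\alpha}_t) I\right) \\
\tilde{\beta}_t &= \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t
\end{aligned}
$$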

In [None]:
class NoiseScheduler:
    """Linear noise schedule for DDPM."""
    
    def __init__(self, num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
        """
        Args:
            num_timesteps: Number of diffusion steps (T)
            beta_start: Starting noise variance
            beta_end: Ending noise variance
        """
        self.num_timesteps = num_timesteps
        
        # Linear schedule for beta
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        
        # Precompute useful quantities
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        
        # For sampling
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        
        # For posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
    
    def add_noise(self, x_0, t, noise=None):
        """
        Forward diffusion: q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
        
        Args:
            x_0: Original data [batch_size, dim]
            t: Timestep [batch_size]
            noise: Optional noise to add (for reproducibility)
        
        Returns:
            x_t: Noisy data at timestep t
            noise: The noise that was added
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        
        # Get coefficients for this timestep (moved to the data's device so t can live on GPU/MPS)
        sqrt_alpha_prod = self.sqrt_alphas_cumprod.to(x_0.device)[t].reshape(-1, 1)
        sqrt_one_minus_alpha_prod = self.sqrt_one_minus_alphas_cumprod.to(x_0.device)[t].reshape(-1, 1)
        
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        x_t = sqrt_alpha_prod * x_0 + sqrt_one_minus_alpha_prod * noise
        
        return x_t, noise

# Test the scheduler
scheduler = NoiseScheduler(num_timesteps=1000)

# Visualize noise schedule
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].plot(scheduler.betas.numpy())
axes[0].set_title('Beta Schedule')
axes[0].set_xlabel('Timestep')
axes[0].set_ylabel('Beta')

axes[1].plot(scheduler.alphas_cumprod.numpy())
axes[1].set_title('Cumulative Alpha')
axes[1].set_xlabel('Timestep')
axes[1].set_ylabel('Alpha_bar')

axes[2].plot(scheduler.sqrt_one_minus_alphas_cumprod.numpy())
axes[2].set_title('Noise Coefficient')
axes[2].set_xlabel('Timestep')
axes[2].set_ylabel('sqrt(1 - alpha_bar)')

plt.tight_layout()
plt.show()
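To check that the closed form in `add_noise` really matches applying the one-step kernel $q(x_t \mid x_{t-1})$ over and over, a quick Monte Carlo comparison can help. This is a standalone sketch, not part of the notebook's pipeline: it rebuilds the same linear schedule in float64, and the constant starting value of 2.0 (a stand-in for one gene's expression), the target step `t=500`, and the number of chains are arbitrary choices.

```python
import torch


def one_step_vs_closed_form(num_timesteps=1000, t=500, x0_value=2.0, n=20000, seed=0):
    """Compare iterated one-step noising with the closed-form q(x_t | x_0)."""
    g = torch.Generator().manual_seed(seed)
    # Same linear schedule as NoiseScheduler, in float64 for a cleaner comparison
    betas = torch.linspace(1e-4, 0.02, num_timesteps, dtype=torch.float64)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # n independent chains, all starting from the same value
    x = torch.full((n,), x0_value, dtype=torch.float64)
    # Apply x_s = sqrt(alpha_s) * x_{s-1} + sqrt(beta_s) * eps for s = 0..t (0-indexed, like the scheduler)
    for s in range(t + 1):
        eps = torch.randn(n, generator=g, dtype=torch.float64)
        x = torch.sqrt(alphas[s]) * x + torch.sqrt(betas[s]) * eps

    # Closed form: mean sqrt(alpha_bar_t) * x0, variance 1 - alpha_bar_t
    closed_mean = (torch.sqrt(alphas_cumprod[t]) * x0_value).item()
    closed_var = (1.0 - alphas_cumprod[t]).item()
    return x.mean().item(), x.var().item(), closed_mean, closed_var


emp_mean, emp_var, cf_mean, cf_var = one_step_vs_closed_form()
print(f"iterated:    mean={emp_mean:.3f}, var={emp_var:.3f}")
print(f"closed form: mean={cf_mean:.3f}, var={cf_var:.3f}")
```

The two lines should agree up to Monte Carlo error, which is why training can jump straight to any $t$ instead of simulating the chain.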

### 3.2 Visualize Forward Diffusion Process

Let's see how a gene expression vector gets progressively noisier.

In [None]:
# Take a sample gene expression vector
x_0 = dataset.X[0:1]  # Shape: [1, n_genes]

# Add noise at different timesteps
timesteps = [0, 100, 250, 500, 750, 999]
noisy_samples = []

for t in timesteps:
    t_tensor = torch.tensor([t])
    x_t, _ = scheduler.add_noise(x_0, t_tensor)
    noisy_samples.append(x_t[0].numpy())

# Visualize
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()

for i, (t, x_t) in enumerate(zip(timesteps, noisy_samples)):
    axes[i].hist(x_t, bins=50, alpha=0.7)
    axes[i].set_title(f'Timestep t={t}')
    axes[i].set_xlabel('Expression value')
    axes[i].set_ylabel('Frequency')
    axes[i].axvline(x=0, color='r', linestyle='--', alpha=0.5)

plt.suptitle('Forward Diffusion: Gene Expression → Noise', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

print(f"Original data mean: {x_0.mean():.3f}, std: {x_0.std():.3f}")
print(f"Final noise mean: {noisy_samples[-1].mean():.3f}, std: {noisy_samples[-1].std():.3f}")
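The histograms show the signal fading; one way to put a number on it (a common summary, not used elsewhere in this notebook) is the signal-to-noise ratio $\mathrm{SNR}(t) = \bar{\alpha}_t / (1 - \bar{\alpha}_t)$, the squared signal coefficient divided by the noise variance. The sketch below recomputes the linear schedule independently of `NoiseScheduler`, assuming the same default `beta_start` and `beta_end`.

```python
import torch


def linear_schedule_snr(num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Return SNR(t) = alpha_bar_t / (1 - alpha_bar_t) for a linear beta schedule."""
    betas = torch.linspace(beta_start, beta_end, num_timesteps, dtype=torch.float64)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    return alphas_cumprod / (1.0 - alphas_cumprod)


snr = linear_schedule_snr()
# First timestep index where the noise variance exceeds the squared signal coefficient
crossover = int(torch.nonzero(snr < 1.0)[0])
print(f"SNR at t=0: {snr[0]:.1f}, at t=999: {snr[-1]:.2e}")
print(f"SNR drops below 1 at t={crossover}")
```

Because $\bar{\alpha}_t$ strictly decreases, the SNR does too; the crossover index gives a rough sense of where in the chain the data stops dominating.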

### 3.3 Time-Conditional Score Network

The core of DDPM: a neural network that predicts the noise $\epsilon_\theta(x_t, t, c)$ given:
- Noisy data $x_t$
- Timestep $t$
- Condition $c$ (e.g., cell type, drug)

For gene expression (tabular data), we use an MLP with:
- Sinusoidal time embeddings
- Conditional embeddings
- Residual connections
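Why call a noise predictor a "score network"? Because $q(x_t \mid x_0)$ is Gaussian, the noise and the score of the noisy distribution differ only by a known scale factor, so predicting one is equivalent to predicting the other:

$$
\nabla_{x_t} \log q(x_t \mid x_0) = -\frac{x_t - \sqrt{\bar{\alpha}_t}\,x_0}{1 - \bar{\alpha}_t} = -\frac{\epsilon}{\sqrt{1 - \bar{\alpha}_t}}
\quad\Longrightarrow\quad
s_\theta(x_t, t, c) \approx -\frac{\epsilon_\theta(x_t, t, c)}{\sqrt{1 - \bar{\alpha}_t}}
$$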

In [None]:
class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embeddings for timesteps (like in Transformers)."""
    
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    
    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = np.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        return embeddings


class MLPBlock(nn.Module):
    """MLP block with residual connection."""
    
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout),
        )
        self.norm = nn.LayerNorm(dim)
    
    def forward(self, x):
        return self.norm(x + self.net(x))


class ConditionalScoreNetwork(nn.Module):
    """Score (noise-prediction) network for gene expression, conditioned on the timestep and an optional class label."""
    
    def __init__(
        self,
        input_dim,
        hidden_dim=256,
        time_dim=64,
        n_conditions=0,
        condition_dim=32,
        n_layers=4,
        dropout=0.1,
    ):
        """
        Args:
            input_dim: Gene expression dimension
            hidden_dim: Hidden layer dimension
            time_dim: Time embedding dimension
            n_conditions: Number of condition classes (0 for unconditional)
            condition_dim: Condition embedding dimension
            n_layers: Number of MLP blocks
            dropout: Dropout rate
        """
        super().__init__()
        
        self.input_dim = input_dim
        self.n_conditions = n_conditions
        
        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim * 2),
            nn.GELU(),
            nn.Linear(time_dim * 2, time_dim),
        )
        
        # Condition embedding (if conditional)
        if n_conditions > 0:
            self.condition_embed = nn.Embedding(n_conditions, condition_dim)
            total_input_dim = input_dim + time_dim + condition_dim
        else:
            self.condition_embed = None
            total_input_dim = input_dim + time_dim
        
        # Input projection
        self.input_proj = nn.Linear(total_input_dim, hidden_dim)
        
        # MLP blocks
        self.blocks = nn.ModuleList([
            MLPBlock(hidden_dim, dropout) for _ in range(n_layers)
        ])
        
        # Output projection (predict noise)
        self.output_proj = nn.Linear(hidden_dim, input_dim)
    
    def forward(self, x, t, condition=None):
        """
        Args:
            x: Noisy gene expression [batch_size, input_dim]
            t: Timestep [batch_size]
            condition: Condition labels [batch_size] (optional)
        
        Returns:
            Predicted noise [batch_size, input_dim]
        """
        # Time embedding
        t_emb = self.time_mlp(t)
        
        # Concatenate inputs
        if self.condition_embed is not None and condition is not None:
            c_emb = self.condition_embed(condition)
            h = torch.cat([x, t_emb, c_emb], dim=-1)
        else:
            h = torch.cat([x, t_emb], dim=-1)
        
        # Project to hidden dimension
        h = self.input_proj(h)
        
        # Apply MLP blocks
        for block in self.blocks:
            h = block(h)
        
        # Predict noise
        noise_pred = self.output_proj(h)
        
        return noise_pred

# Test the network
model = ConditionalScoreNetwork(
    input_dim=dataset.X.shape[1],
    hidden_dim=256,
    n_conditions=dataset.n_conditions,
    n_layers=4,
).to(device)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Test forward pass
x_test, c_test = next(iter(dataloader))
x_test, c_test = x_test.to(device), c_test.to(device)
t_test = torch.randint(0, scheduler.num_timesteps, (x_test.shape[0],), device=device)

noise_pred = model(x_test, t_test, c_test)
print(f"Input shape: {x_test.shape}")
print(f"Output shape: {noise_pred.shape}")
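As a quick sanity check of the time embedding, the sketch below reimplements the formula from `SinusoidalPositionEmbeddings.forward` as a plain function (the name `sinusoidal_embedding` is ours). It checks that the output has `dim` columns, that at $t = 0$ the sine half is all zeros and the cosine half all ones, and that embeddings of nearby timesteps are closer than those of distant ones, which is what lets the MLP treat $t$ as a smooth input.

```python
import math

import torch


def sinusoidal_embedding(t, dim):
    """Same formula as SinusoidalPositionEmbeddings.forward, as a plain function."""
    half_dim = dim // 2
    # Geometric sequence of frequencies from 1 down to 1/10000
    scale = math.log(10000) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -scale)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)


emb = sinusoidal_embedding(torch.tensor([0, 500, 501, 900]), dim=64)
print(emb.shape)  # torch.Size([4, 64])
near = torch.dist(emb[1], emb[2]).item()
far = torch.dist(emb[1], emb[3]).item()
print(f"|e(500) - e(501)| = {near:.2f}, |e(500) - e(900)| = {far:.2f}")
```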

## 4. Training Loop

DDPM training is simple:
1. Sample a batch of data $x_0$ (with its conditions $c$)
2. Sample timesteps $t$ uniformly (indices $0, \dots, T-1$ in the code)
3. Sample $\epsilon \sim \mathcal{N}(0, I)$ and add noise: $x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon$
4. Predict the noise: $\epsilon_\theta(x_t, t, c)$
5. Minimize the MSE loss: $\|\epsilon - \epsilon_\theta(x_t, t, c)\|^2$
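The five steps map onto a handful of tensor operations. The sketch below runs them end to end on a toy problem so it finishes in seconds on CPU. Two simplifications are ours, not part of the notebook's model: the data is a 4-dimensional Gaussian with fixed means, and the tiny MLP receives $t/T$ as an extra input feature instead of a sinusoidal embedding.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

# Toy "expression" data: 4 genes with fixed means and small spread
mu = torch.tensor([2.0, 0.5, 1.0, 3.0])
# Tiny epsilon-predictor: input is [x_t, t/T]
net = nn.Sequential(nn.Linear(5, 64), nn.GELU(), nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 4))
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

losses = []
for step in range(400):
    x0 = mu + 0.1 * torch.randn(256, 4)        # 1. sample a batch x_0
    t = torch.randint(0, T, (256,))             # 2. sample timesteps uniformly
    eps = torch.randn_like(x0)                  # 3. sample noise and add it in closed form
    a_bar = alphas_cumprod[t].unsqueeze(-1)
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps
    eps_pred = net(torch.cat([x_t, (t.float() / T).unsqueeze(-1)], dim=-1))  # 4. predict noise
    loss = F.mse_loss(eps_pred, eps)            # 5. MSE between true and predicted noise
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(f"first 50 steps: {sum(losses[:50]) / 50:.3f}, last 50 steps: {sum(losses[-50:]) / 50:.3f}")
```

The loss typically does not approach zero: at small $t$ the noise is a tiny perturbation of $x_0$ and is hard to recover exactly, so the residual error is expected.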

In [None]:
def train_ddpm(model, dataloader, scheduler, num_epochs=100, lr=1e-4):
    """Train DDPM model."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    
    losses = []
    
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        
        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False):
            if isinstance(batch, (list, tuple)):  # conditional dataset yields (x, condition)
                x_0, condition = batch
                x_0 = x_0.to(device)
                condition = condition.to(device)
            else:
                x_0 = batch.to(device)
                condition = None
            
            batch_size = x_0.shape[0]
            
            # Sample random timesteps
            t = torch.randint(0, scheduler.num_timesteps, (batch_size,), device=device)
            
            # Add noise
            noise = torch.randn_like(x_0)
            x_t, _ = scheduler.add_noise(x_0, t, noise)
            
            # Predict noise
            noise_pred = model(x_t, t, condition)
            
            # Compute loss
            loss = F.mse_loss(noise_pred, noise)
            
            # Backprop
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(dataloader)
        losses.append(avg_loss)
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
    
    return losses

In [None]:
# Train the model (start with fewer epochs for testing)
losses = train_ddpm(
    model=model,
    dataloader=dataloader,
    scheduler=scheduler,
    num_epochs=50,  # quick test; more epochs (e.g., 200-500) will likely give better samples
    lr=1e-4,
)

# Plot training curve
plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('DDPM Training Loss')
plt.grid(True, alpha=0.3)
plt.show()

## 5. Sampling (Reverse Diffusion)

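Generate new gene expression profiles by:
1. Start from pure noise $x_T \sim \mathcal{N}(0, I)$
2. For $t = T, \dots, 1$, denoise iteratively: $x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t, c) \right) + \sigma_t z$, where $z \sim \mathcal{N}(0, I)$ and $\sigma_t^2 = \tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \beta_t$ is the posterior variance (no noise is added at the final step)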
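The update mean is the DDPM posterior mean $\tilde{\mu}_t(x_t, x_0)$ of $q(x_{t-1} \mid x_t, x_0)$ rewritten in terms of the noise (note that $1 - \alpha_t = \beta_t$, which is why the code uses `beta_t`). A short numerical check makes this concrete: if the network predicted the true $\epsilon$ exactly, the update mean would equal the posterior mean computed from the true $x_0$. The sketch rebuilds the linear schedule in float64 and uses random vectors as stand-ins for data.

```python
import torch


def check_reverse_mean(t, dim=500, seed=0):
    """Max |difference| between the epsilon-form update mean and the posterior mean."""
    g = torch.Generator().manual_seed(seed)
    betas = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)
    alphas = 1.0 - betas
    a_bar = torch.cumprod(alphas, dim=0)
    # alpha_bar_{t-1}, padded with 1 at the first step (as in alphas_cumprod_prev)
    a_bar_prev = a_bar[t - 1] if t > 0 else torch.tensor(1.0, dtype=torch.float64)

    x0 = torch.randn(dim, generator=g, dtype=torch.float64)
    eps = torch.randn(dim, generator=g, dtype=torch.float64)
    x_t = a_bar[t].sqrt() * x0 + (1 - a_bar[t]).sqrt() * eps

    # Posterior mean of q(x_{t-1} | x_t, x_0)
    mu_post = (a_bar_prev.sqrt() * betas[t] / (1 - a_bar[t])) * x0 + (
        alphas[t].sqrt() * (1 - a_bar_prev) / (1 - a_bar[t])
    ) * x_t
    # Epsilon-form mean, as computed in sample_ddpm
    mu_eps = (x_t - betas[t] / (1 - a_bar[t]).sqrt() * eps) / alphas[t].sqrt()
    return (mu_post - mu_eps).abs().max().item()


for t in [0, 1, 500, 999]:
    print(t, check_reverse_mean(t))
```

The printed differences should sit at floating-point rounding level for every timestep.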

In [None]:
@torch.no_grad()
def sample_ddpm(model, scheduler, n_samples, condition=None, device='cpu'):
    """Sample from DDPM model.
    
    Args:
        model: Trained score network
        scheduler: Noise scheduler
        n_samples: Number of samples to generate
        condition: Condition labels [n_samples] (optional)
        device: Device to run on
    
    Returns:
        Generated samples [n_samples, input_dim]
    """
    model.eval()
    
    # Start from pure noise
    x = torch.randn(n_samples, model.input_dim, device=device)
    
    if condition is not None:
        condition = condition.to(device)
    
    # Reverse diffusion
    for t in tqdm(reversed(range(scheduler.num_timesteps)), desc="Sampling", total=scheduler.num_timesteps):
        t_batch = torch.full((n_samples,), t, device=device, dtype=torch.long)
        
        # Predict noise
        noise_pred = model(x, t_batch, condition)
        
        # Get scheduler coefficients
        alpha_t = scheduler.alphas[t]
        alpha_bar_t = scheduler.alphas_cumprod[t]
        beta_t = scheduler.betas[t]
        
        # Compute mean
        mean = (1 / torch.sqrt(alpha_t)) * (
            x - (beta_t / torch.sqrt(1 - alpha_bar_t)) * noise_pred
        )
        
        # Add noise (except at t=0)
        if t > 0:
            noise = torch.randn_like(x)
            sigma_t = torch.sqrt(scheduler.posterior_variance[t])
            x = mean + sigma_t * noise
        else:
            x = mean
    
    return x

In [None]:
# Generate samples for each cell type
n_samples_per_type = 50
generated_samples = []
generated_labels = []

for cell_type_idx in range(dataset.n_conditions):
    condition = torch.full((n_samples_per_type,), cell_type_idx, dtype=torch.long)
    samples = sample_ddpm(model, scheduler, n_samples_per_type, condition, device=device)
    generated_samples.append(samples.cpu())
    generated_labels.extend([cell_type_idx] * n_samples_per_type)

generated_samples = torch.cat(generated_samples, dim=0).numpy()
generated_labels = np.array(generated_labels)

print(f"Generated {generated_samples.shape[0]} samples")
print(f"Sample shape: {generated_samples.shape}")

## 6. Evaluation

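Compare generated and real gene expression: the pooled value distribution, per-gene means and standard deviations, and the marginals of a few individual genes. We generated the same number of cells for every cluster while the real clusters differ in size, so pooled comparisons are only approximate.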

In [None]:
# Compare distributions
real_data = dataset.X.numpy()

fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()

# Overall distribution
axes[0].hist(real_data.flatten(), bins=50, alpha=0.5, label='Real', density=True)
axes[0].hist(generated_samples.flatten(), bins=50, alpha=0.5, label='Generated', density=True)
axes[0].set_title('Overall Distribution')
axes[0].legend()

# Mean expression per gene
axes[1].scatter(real_data.mean(axis=0), generated_samples.mean(axis=0), alpha=0.3)
axes[1].plot([real_data.mean(axis=0).min(), real_data.mean(axis=0).max()],
             [real_data.mean(axis=0).min(), real_data.mean(axis=0).max()],
             'r--', alpha=0.5)
axes[1].set_xlabel('Real mean expression')
axes[1].set_ylabel('Generated mean expression')
axes[1].set_title('Mean Expression per Gene')

# Std expression per gene
axes[2].scatter(real_data.std(axis=0), generated_samples.std(axis=0), alpha=0.3)
axes[2].plot([real_data.std(axis=0).min(), real_data.std(axis=0).max()],
             [real_data.std(axis=0).min(), real_data.std(axis=0).max()],
             'r--', alpha=0.5)
axes[2].set_xlabel('Real std expression')
axes[2].set_ylabel('Generated std expression')
axes[2].set_title('Std Expression per Gene')

# Sample a few genes and compare distributions
for i, gene_idx in enumerate([0, 10, 50]):
    axes[3 + i].hist(real_data[:, gene_idx], bins=30, alpha=0.5, label='Real', density=True)
    axes[3 + i].hist(generated_samples[:, gene_idx], bins=30, alpha=0.5, label='Generated', density=True)
    axes[3 + i].set_title(f'Gene {gene_idx}')
    axes[3 + i].legend()

plt.tight_layout()
plt.show()
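Histograms and scatter plots are qualitative. Two cheap numerical summaries are sketched below as additions of ours, not part of the original evaluation: the Pearson correlation of per-gene means, and a biased RBF-kernel maximum mean discrepancy (MMD²), which is near zero when two samples come from the same distribution. The bandwidth `gamma` is an arbitrary assumption and would need tuning (e.g., via the median heuristic). Comparing within each cluster sidesteps the class-imbalance issue mentioned above.

```python
import numpy as np


def per_gene_mean_corr(real, gen):
    """Pearson correlation between per-gene means of two [cells, genes] arrays."""
    return float(np.corrcoef(real.mean(axis=0), gen.mean(axis=0))[0, 1])


def rbf_mmd2(x, y, gamma=1.0):
    """Biased MMD^2 estimate with kernel k(a, b) = exp(-gamma * ||a - b||^2)."""
    def kernel(a, b):
        # Squared Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b
        sq = (a ** 2).sum(1)[:, None] + (b ** 2).sum(1)[None, :] - 2 * a @ b.T
        return np.exp(-gamma * np.maximum(sq, 0.0))
    return float(kernel(x, x).mean() + kernel(y, y).mean() - 2 * kernel(x, y).mean())


# Toy usage on synthetic data. With the notebook's arrays, a per-cluster comparison would pass
# real_data[dataset.conditions.numpy() == k] and generated_samples[generated_labels == k]
rng = np.random.default_rng(0)
a = rng.normal(0.0, 1.0, size=(200, 5))
b = rng.normal(0.0, 1.0, size=(200, 5))
c = rng.normal(2.0, 1.0, size=(200, 5))
print(f"MMD^2 same dist: {rbf_mmd2(a, b, gamma=0.1):.4f}, shifted: {rbf_mmd2(a, c, gamma=0.1):.4f}")
```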

## 7. Next Steps: Extending to Drug-Response Prediction

To implement the scPPDM approach for perturbation response:

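### Architecture Changes:
1. **Generated variable**: the perturbed expression profile, rather than a generic profile for a cell type
2. **Condition**: baseline expression concatenated with a drug + dose embedding, instead of a cell-type label
3. **Network output**: can remain the predicted noise (as in the training sketch below) or be reparameterized to predict the perturbed expression directly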

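### Data Requirements:
- Matched samples: (baseline, drug, dose, perturbed_expression). Because single-cell sequencing is destructive, the baseline is usually a matched control population (same cell line or cell type) rather than the same cell measured twice.
- Examples: sci-Plex, LINCS L1000 (bulk rather than single-cell), Replogle et al. Perturb-seq (genetic rather than chemical perturbations)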

### Modified Forward Process:
```python
# Instead of: x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * noise
# Use: x_t = sqrt(alpha_bar) * x_perturbed + sqrt(1 - alpha_bar) * noise
# Condition on: [x_baseline, drug_embedding, dose]
```

### Training:
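```python
# Predict noise on the perturbed expression, conditioned on baseline + drug.
# drug_encoder and score_network are placeholders: score_network must accept a
# continuous condition vector (ConditionalScoreNetwork above expects class labels).
def forward(x_perturbed, x_baseline, drug, dose, t):
    # Noise the perturbed profile with the same forward process as before
    x_t, noise = scheduler.add_noise(x_perturbed, t)

    # Encode drug identity and dose
    drug_emb = drug_encoder(drug, dose)

    # Concatenate baseline + drug info
    condition = torch.cat([x_baseline, drug_emb], dim=-1)

    # Predict noise for perturbed expression
    noise_pred = score_network(x_t, t, condition)

    # Same DDPM objective as in Section 4
    return F.mse_loss(noise_pred, noise)
```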
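`ConditionalScoreNetwork` looks its condition up in an `nn.Embedding`, so it cannot take a continuous vector such as `torch.cat([x_baseline, drug_emb], dim=-1)`. Below is a minimal sketch of a variant that accepts any fixed-size condition vector. The class name, the single linear projection of the condition, the pre-norm residual blocks, and the sizes in the usage lines are illustrative assumptions, not the scPPDM architecture.

```python
import math

import torch
import torch.nn as nn


class VectorConditionedScoreNetwork(nn.Module):
    """Hypothetical epsilon-predictor conditioned on a continuous vector."""

    def __init__(self, input_dim, cond_dim, hidden_dim=256, time_dim=64, n_layers=3):
        super().__init__()
        self.input_dim = input_dim
        self.time_dim = time_dim
        self.time_mlp = nn.Sequential(nn.Linear(time_dim, hidden_dim), nn.GELU())
        # Project the raw condition (e.g., baseline expression + drug/dose embedding)
        self.cond_proj = nn.Sequential(nn.Linear(cond_dim, hidden_dim), nn.GELU())
        self.input_proj = nn.Linear(input_dim + 2 * hidden_dim, hidden_dim)
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.LayerNorm(hidden_dim), nn.Linear(hidden_dim, hidden_dim), nn.GELU())
             for _ in range(n_layers)]
        )
        self.output_proj = nn.Linear(hidden_dim, input_dim)

    def time_embedding(self, t):
        # Same sinusoidal formula as SinusoidalPositionEmbeddings
        half = self.time_dim // 2
        freqs = torch.exp(torch.arange(half, device=t.device) * -(math.log(10000) / (half - 1)))
        args = t.float()[:, None] * freqs[None, :]
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

    def forward(self, x_t, t, cond):
        h = torch.cat([x_t, self.time_mlp(self.time_embedding(t)), self.cond_proj(cond)], dim=-1)
        h = self.input_proj(h)
        for block in self.blocks:
            h = h + block(h)  # pre-norm residual block
        return self.output_proj(h)


# Usage with made-up sizes: 500 genes, baseline (500) + drug/dose embedding (16)
net = VectorConditionedScoreNetwork(input_dim=500, cond_dim=500 + 16)
x_t = torch.randn(8, 500)
t = torch.randint(0, 1000, (8,))
cond = torch.cat([torch.randn(8, 500), torch.randn(8, 16)], dim=-1)
print(net(x_t, t, cond).shape)  # torch.Size([8, 500])
```

Because it keeps the `(x, t, condition)` call signature and exposes `input_dim`, a network like this can be passed to `sample_ddpm` unchanged.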

### Sampling:
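```python
# Generate a counterfactual response. sample_ddpm works unchanged as long as
# `model` accepts a continuous condition vector (drug_encoder is a placeholder).
condition = torch.cat([x_baseline, drug_encoder(drug_id, dose_value)], dim=-1)
x_perturbed = sample_ddpm(
    model,
    scheduler,
    n_samples=x_baseline.shape[0],
    condition=condition,
    device=device,
)
```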

## Summary

We've implemented:
1. ✅ Noise scheduler (linear beta schedule)
2. ✅ Forward diffusion (adding noise)
3. ✅ Time-conditional score network (MLP for tabular data)
4. ✅ Training loop (simple MSE loss)
5. ✅ Sampling (reverse diffusion)
6. ✅ Conditional generation (cell type → expression)

**Next notebook**: Implement full scPPDM for drug-response prediction with perturbation datasets.