# Diffusion Models for Gene Expression Data

This notebook extends the SDE-based diffusion framework to generate **realistic gene expression data**.

**Motivation:**
- Companies such as Synthesize Bio (GEM-1) and Insilico Medicine (Precious3GPT), along with research models such as scGPT, are building generative models for gene expression
- Applications: drug target discovery, clinical trial acceleration, in-silico perturbation experiments

**Learning objectives:**
1. Understand challenges of applying diffusion to gene expression data
2. Implement latent diffusion for high-dimensional biological data
3. Add conditional generation (cell type, tissue, disease)
4. Evaluate generated samples with biological metrics

**Prerequisites:** `02_sde_formulation.ipynb`

---

## Setup

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from sklearn.decomposition import PCA

sys.path.insert(0, str(Path('../../../src').resolve()))

from genailab.diffusion import VPSDE, train_score_network, sample_reverse_sde
from genailab.data import ToyBulkDataset

sns.set_style('whitegrid')
np.random.seed(42)
torch.manual_seed(42)

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using device: {device}')

## 1. Load Synthetic Gene Expression Data

In [None]:
dataset = ToyBulkDataset(n=5000, n_genes=500, n_tissues=5, n_diseases=3, n_batches=4, seed=42)
print(f"Dataset: {len(dataset)} samples, {dataset.n_genes} genes")
print(f"Conditions: {list(dataset.cond.keys())}")

In [None]:
# Visualize data
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
pca = PCA(n_components=2)
x_pca = pca.fit_transform(dataset.x.numpy())
scatter = axes[0].scatter(x_pca[:, 0], x_pca[:, 1], c=dataset.cond['tissue'].numpy(), cmap='tab10', alpha=0.5, s=5)
axes[0].set_title('PCA by Tissue')
plt.colorbar(scatter, ax=axes[0])
axes[1].hist(dataset.x.numpy().flatten(), bins=50, density=True)
axes[1].set_title('Expression Distribution')
plt.tight_layout()
plt.show()

## 2. Latent Diffusion Approach

Key insight: Run diffusion in a learned latent space (like Stable Diffusion, scPPDM).

```
Genes (500) → VAE Encoder → Latent (32) → Diffusion → VAE Decoder → Genes (500)
```
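To make the "Diffusion" stage of the diagram concrete, here is a minimal NumPy sketch of how a variance-preserving SDE perturbs latent vectors in closed form. It assumes a linear schedule β(t) = β_min + t(β_max − β_min) on t ∈ [0, 1]; the helper names `vp_marginal_coeffs` and `perturb_latents` are hypothetical and are not part of `genailab`, whose `VPSDE` class may differ in detail.

```python
import numpy as np

def vp_marginal_coeffs(t, beta_min=0.1, beta_max=20.0):
    """Return (alpha_t, sigma_t) such that z_t = alpha_t * z_0 + sigma_t * eps."""
    t = np.asarray(t, dtype=np.float64)
    # Integral of beta(s) from 0 to t for the assumed linear schedule
    int_beta = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    alpha = np.exp(-0.5 * int_beta)
    sigma = np.sqrt(1.0 - np.exp(-int_beta))
    return alpha, sigma

def perturb_latents(z0, t, rng):
    """Sample z_t given clean latents z0 (n, d) and one time value per row."""
    alpha, sigma = vp_marginal_coeffs(t)
    eps = rng.standard_normal(z0.shape)
    return alpha[:, None] * z0 + sigma[:, None] * eps, eps

rng = np.random.default_rng(0)
z0 = rng.standard_normal((4, 32))      # stand-in for VAE latent means
t = np.array([0.0, 0.1, 0.5, 1.0])     # one diffusion time per latent vector
zt, eps = perturb_latents(z0, t, rng)

for ti, a, s in zip(t, *vp_marginal_coeffs(t)):
    print(f"t={ti:.1f}: alpha={a:.4f}, sigma={s:.4f}")
```

At t = 0 the latent is untouched, while at t = 1 almost no signal remains, which is why sampling can start from a standard normal prior in the 32-dimensional latent space.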

In [None]:
class GeneVAE(nn.Module):
    def __init__(self, n_genes, latent_dim=32, hidden_dim=256):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(nn.Linear(n_genes, hidden_dim), nn.LayerNorm(hidden_dim), nn.SiLU(),
                                     nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.SiLU())
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.decoder = nn.Sequential(nn.Linear(latent_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.SiLU(),
                                     nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.SiLU(),
                                     nn.Linear(hidden_dim, n_genes))
    
    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)
    
    def decode(self, z):
        return self.decoder(z)
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        return self.decode(z), mu, logvar

In [None]:
# Train VAE
latent_dim = 32
vae = GeneVAE(n_genes=dataset.n_genes, latent_dim=latent_dim).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
data = dataset.x.to(device)

for epoch in tqdm(range(2000), desc="Training VAE"):
    idx = np.random.choice(len(data), 128)
    x = data[idx]
    recon, mu, logvar = vae(x)
    loss = F.mse_loss(recon, x) + 0.01 * (-0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 500 == 0:
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

In [None]:
# Get latent representations
vae.eval()
with torch.no_grad():
    mu, _ = vae.encode(data)
    latent_data = mu.cpu().numpy()
print(f"Latent shape: {latent_data.shape} (compression: {dataset.n_genes/latent_dim:.0f}x)")

In [None]:
# Train diffusion in latent space
from genailab.diffusion import SimpleScoreNetwork

score_net = SimpleScoreNetwork(data_dim=latent_dim, hidden_dim=256, num_layers=4).to(device)
sde = VPSDE(beta_min=0.1, beta_max=20.0, T=1.0)

losses = train_score_network(score_net, latent_data, sde, num_epochs=5000, batch_size=128, lr=1e-3, device=device)

plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.yscale('log')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Latent Diffusion Training')
plt.show()

In [None]:
# Generate samples
latent_samples, _ = sample_reverse_sde(score_net, sde, n_samples=1000, num_steps=500, data_dim=latent_dim, device=device)

with torch.no_grad():
    gene_samples = vae.decode(torch.FloatTensor(latent_samples).to(device)).cpu().numpy()

print(f"Generated {gene_samples.shape[0]} samples with {gene_samples.shape[1]} genes")

In [None]:
# Evaluate
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Distribution
axes[0].hist(dataset.x.numpy().flatten(), bins=50, density=True, alpha=0.5, label='Real')
axes[0].hist(gene_samples.flatten(), bins=50, density=True, alpha=0.5, label='Generated')
axes[0].legend()
axes[0].set_title('Expression Distribution')

# Gene means
real_means = dataset.x.numpy().mean(axis=0)
gen_means = gene_samples.mean(axis=0)
axes[1].scatter(real_means, gen_means, alpha=0.3, s=5)
axes[1].plot([-2, 2], [-2, 2], 'r--')
corr = np.corrcoef(real_means, gen_means)[0, 1]
axes[1].set_title(f'Gene Means (r={corr:.3f})')

# PCA
pca = PCA(n_components=2)
real_pca = pca.fit_transform(dataset.x.numpy())
gen_pca = pca.transform(gene_samples)
axes[2].scatter(real_pca[:, 0], real_pca[:, 1], alpha=0.3, s=5, label='Real')
axes[2].scatter(gen_pca[:, 0], gen_pca[:, 1], alpha=0.3, s=5, label='Generated')
axes[2].legend()
axes[2].set_title('PCA Overlay')

plt.tight_layout()
plt.show()

## 3. Handling Count Data: The Core Challenge

The approach above uses MSE loss and Gaussian outputs, which works for **log-normalized** expression data. But real gene expression is **count data**:

- **UMI counts** (scRNA-seq): integers with many zeros
- **TPM/FPKM** (bulk RNA-seq): continuous but count-derived
- **Heavy-tailed**: few highly expressed genes, many low/zero

**Problem**: Adding Gaussian noise directly to counts produces non-integer and possibly negative values, which have no clear biological meaning.

**Solution**: Use count-aware decoders (Negative Binomial, Zero-Inflated NB) that output distribution parameters instead of point estimates.
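As a rough illustration of what "outputting distribution parameters" means for training, the sketch below evaluates the Negative Binomial log-probability of an observed count given a predicted mean μ and dispersion θ. It uses the mean/dispersion parameterization (variance μ + μ²/θ) and only the standard library. The function name `nb_log_pmf` is hypothetical; this is not the implementation of `genailab`'s `nb_loss`, which may differ in reduction and numerical safeguards.

```python
import math

def nb_log_pmf(x, mu, theta):
    """log NB(x | mu, theta) with mean mu > 0 and variance mu + mu**2 / theta."""
    return (
        math.lgamma(x + theta) - math.lgamma(theta) - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mu))
        + x * math.log(mu / (theta + mu))
    )

mu, theta = 5.0, 2.0
# Enumerate the pmf far into the tail to check normalization and moments
probs = [math.exp(nb_log_pmf(k, mu, theta)) for k in range(2000)]
total = sum(probs)
mean = sum(k * p for k, p in enumerate(probs))
var = sum((k - mean) ** 2 * p for k, p in enumerate(probs))

print(f"sum={total:.6f}, mean={mean:.4f}, var={var:.4f}")
```

A reconstruction loss for a count-aware decoder is then the negative of this log-probability, summed over genes, in place of the squared error used earlier.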

In [None]:
# Import count-aware decoders and losses from genailab
from genailab.model.decoders import NegativeBinomialDecoder, ZINBDecoder
from genailab.objectives.losses import nb_loss, zinb_loss, elbo_loss_nb, elbo_loss_zinb

print("Available count-aware components:")
print("  Decoders: NegativeBinomialDecoder, ZINBDecoder")
print("  Losses: nb_loss, zinb_loss, elbo_loss_nb, elbo_loss_zinb")

### 3.1 VAE with Negative Binomial Decoder

The key insight: **run diffusion in continuous latent space, but decode to count distributions**.

```
Counts → log1p → Encoder → z (continuous) → Diffusion → z' → NB Decoder → NB(μ, θ) → Sample counts
```

The NB decoder outputs:
- **μ (mu)**: Expected count per gene
- **θ (theta)**: Inverse-dispersion parameter. The variance is μ + μ²/θ, so a smaller θ means more overdispersion
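NumPy's sampler does not take (μ, θ) directly. It expects a number of successes `n` and a success probability `p`. The sketch below shows the conversion and checks empirically that the draws have mean μ and variance μ + μ²/θ. The helper name `nb_mean_disp_to_numpy` is hypothetical and used only for illustration.

```python
import numpy as np

def nb_mean_disp_to_numpy(mu, theta):
    """Map (mu, theta) to the (n, p) arguments of NumPy's negative_binomial."""
    n = theta                      # NumPy accepts non-integer n > 0
    p = theta / (theta + mu)       # success probability
    return n, p

rng = np.random.default_rng(0)
mu = np.array([1.0, 10.0, 100.0])
theta = np.array([0.5, 2.0, 10.0])

n, p = nb_mean_disp_to_numpy(mu, theta)
# One vectorized call draws all genes at once; (n, p) broadcast over rows
draws = rng.negative_binomial(n, p, size=(200_000, 3))

emp_mean = draws.mean(axis=0)
emp_var = draws.var(axis=0)
expected_var = mu + mu**2 / theta

print("empirical mean:", emp_mean)
print("empirical var: ", emp_var)
print("expected var:  ", expected_var)
```

Because the arguments broadcast, the same call can sample a full `(n_samples, n_genes)` matrix without Python loops.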

In [None]:
class GeneVAE_NB(nn.Module):
    """VAE with Negative Binomial decoder for count data.
    
    Architecture:
        Encoder: expression → latent (mu, logvar)
        Decoder: latent → NB parameters (mu, theta)
    """
    
    def __init__(self, n_genes, latent_dim=32, hidden_dim=256):
        super().__init__()
        self.n_genes = n_genes
        self.latent_dim = latent_dim
        
        # Encoder (same as before)
        self.encoder = nn.Sequential(
            nn.Linear(n_genes, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU()
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        
        # NB Decoder: outputs rate parameters
        self.decoder_net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU()
        )
        self.rho_head = nn.Linear(hidden_dim, n_genes)  # Rate (before library scaling)
        
        # Gene-specific dispersion (learned parameter)
        self.log_theta = nn.Parameter(torch.zeros(n_genes))
    
    def encode(self, x):
        """Encode expression to latent distribution."""
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)
    
    def reparameterize(self, mu, logvar):
        """Sample z from q(z|x) using reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z, library_size=None):
        """Decode latent to NB parameters.
        
        Returns:
            mu: Expected counts (n_samples, n_genes)
            theta: Dispersion (n_genes,) broadcast to (n_samples, n_genes)
        """
        h = self.decoder_net(z)
        
        # Rate: softmax ensures non-negative and sums to 1
        rho = F.softmax(self.rho_head(h), dim=-1)
        
        # Scale by library size (total counts per sample)
        if library_size is not None:
            if library_size.dim() == 1:
                library_size = library_size.unsqueeze(-1)
            mu = rho * library_size
        else:
            # Default: assume library size = n_genes (normalized)
            mu = rho * self.n_genes
        
        # Dispersion: exp to ensure positive
        theta = torch.exp(self.log_theta).unsqueeze(0).expand(z.shape[0], -1)
        
        return mu, theta
    
    def forward(self, x, library_size=None):
        """Full forward pass."""
        # Encode
        enc_mu, enc_logvar = self.encode(x)
        z = self.reparameterize(enc_mu, enc_logvar)
        
        # Decode to NB parameters
        dec_mu, dec_theta = self.decode(z, library_size)
        
        return dec_mu, dec_theta, enc_mu, enc_logvar

print("GeneVAE_NB defined with Negative Binomial decoder")

In [None]:
# Generate synthetic count data for demonstration
# (Our ToyBulkDataset uses log-normalized data; let's create count-like data)

def generate_synthetic_counts(n_samples=5000, n_genes=500, seed=42):
    """Generate synthetic count data with NB-like properties."""
    np.random.seed(seed)
    
    # Base expression rates (log-normal distributed)
    log_base_rates = np.random.normal(2, 2, n_genes)
    base_rates = np.exp(log_base_rates)
    
    # Sample-specific library sizes (total counts)
    library_sizes = np.random.lognormal(10, 0.5, n_samples)
    
    # Gene-specific dispersion (smaller = more overdispersion)
    dispersions = np.random.uniform(0.1, 10, n_genes)
    
    # Generate counts using Negative Binomial (mean=mu, var=mu + mu^2/theta)
    # Per-sample rates: base rates rescaled so each row sums to its library size
    rates = base_rates[None, :] * (library_sizes[:, None] / base_rates.sum())
    # Convert to NumPy's (n, p) parameterization: n = theta, p = theta / (theta + mu)
    p = dispersions[None, :] / (dispersions[None, :] + rates)
    # Vectorized draw; equivalent in distribution to sampling each entry in a loop
    counts = np.random.negative_binomial(dispersions[None, :], p)
    
    return counts.astype(np.float32), library_sizes.astype(np.float32)

# Generate count data
count_data, library_sizes = generate_synthetic_counts(n_samples=5000, n_genes=500)
print(f"Count data shape: {count_data.shape}")
print(f"Count range: [{count_data.min():.0f}, {count_data.max():.0f}]")
print(f"Sparsity (zeros): {(count_data == 0).mean()*100:.1f}%")
print(f"Library sizes: mean={library_sizes.mean():.0f}, std={library_sizes.std():.0f}")

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].hist(count_data.flatten(), bins=50, density=True, log=True)
axes[0].set_xlabel('Count')
axes[0].set_ylabel('Density (log)')
axes[0].set_title('Count Distribution (heavy-tailed)')

axes[1].hist(np.log1p(count_data).flatten(), bins=50, density=True)
axes[1].set_xlabel('log1p(Count)')
axes[1].set_ylabel('Density')
axes[1].set_title('Log-transformed Distribution')

axes[2].hist(library_sizes, bins=30, density=True)
axes[2].set_xlabel('Library Size')
axes[2].set_ylabel('Density')
axes[2].set_title('Library Size Distribution')

plt.tight_layout()
plt.show()

In [None]:
# Train VAE with NB decoder on count data
# Note: We log-transform input for encoder stability, but use NB loss on original counts

# Prepare data
count_tensor = torch.FloatTensor(count_data).to(device)
library_tensor = torch.FloatTensor(library_sizes).to(device)

# Log-transform for encoder input (standard practice in scVI, scGen)
log_counts = torch.log1p(count_tensor)

# Initialize model
vae_nb = GeneVAE_NB(n_genes=500, latent_dim=32, hidden_dim=256).to(device)
optimizer_nb = torch.optim.Adam(vae_nb.parameters(), lr=1e-3)

# Training loop
losses_nb = []
for epoch in tqdm(range(3000), desc="Training VAE-NB"):
    idx = np.random.choice(len(count_tensor), 128)
    x_counts = count_tensor[idx]
    x_log = log_counts[idx]
    lib_size = library_tensor[idx]
    
    # Forward pass (encode log-transformed, decode to NB params)
    dec_mu, dec_theta, enc_mu, enc_logvar = vae_nb(x_log, lib_size)
    
    # Loss: NB reconstruction + KL divergence
    loss, loss_dict = elbo_loss_nb(
        x=x_counts,           # Original counts for NB loss
        mu=dec_mu,            # Predicted mean
        theta=dec_theta,      # Predicted dispersion
        enc_mu=enc_mu,        # Encoder mean
        enc_logvar=enc_logvar,# Encoder logvar
        beta=0.01             # KL weight (low to avoid posterior collapse)
    )
    
    optimizer_nb.zero_grad()
    loss.backward()
    optimizer_nb.step()
    
    losses_nb.append(loss.item())
    
    if (epoch + 1) % 1000 == 0:
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, "
              f"Recon: {loss_dict['recon'].item():.4f}, KL: {loss_dict['kl'].item():.4f}")

# Plot training curve
plt.figure(figsize=(10, 4))
plt.plot(losses_nb)
plt.yscale('log')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('VAE-NB Training (NB Reconstruction Loss)')
plt.grid(True, alpha=0.3)
plt.show()

### 3.2 Latent Diffusion with NB Decoder

Now we combine the trained VAE-NB with diffusion in latent space:

1. **Extract latent representations** from VAE-NB encoder
2. **Train diffusion** in the continuous latent space
3. **Sample**: noise → diffusion → latent → NB decoder → sample from NB distribution
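The "noise → diffusion → latent" part of step 3 can be sketched with a plain Euler-Maruyama integration of the reverse-time SDE. To keep the example self-contained and checkable, the learned score network is replaced by the exact score of a Gaussian toy "latent distribution". The linear β schedule, the function names, and the step count are assumptions for illustration and do not describe how `sample_reverse_sde` is implemented in `genailab`.

```python
import numpy as np

BETA_MIN, BETA_MAX = 0.1, 20.0

def beta(t):
    """Assumed linear noise schedule on t in [0, 1]."""
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)

def alpha_sq(t):
    """Squared mean coefficient exp(-int_0^t beta(s) ds)."""
    return np.exp(-(BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t**2))

def gaussian_score(x, t, data_mean, data_std):
    """Exact score of p_t when clean latents are N(data_mean, data_std^2)."""
    a2 = alpha_sq(t)
    var_t = a2 * data_std**2 + (1.0 - a2)
    return -(x - np.sqrt(a2) * data_mean) / var_t

def reverse_sde_sample(score_fn, n_samples, dim, num_steps=1000, eps=1e-3, seed=0):
    """Euler-Maruyama integration of the reverse VP-SDE from t=1 down to t=eps."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((n_samples, dim))   # start from the N(0, I) prior
    ts = np.linspace(1.0, eps, num_steps + 1)
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        dt = t_cur - t_next                     # positive step, time runs backwards
        b = beta(t_cur)
        # Reverse drift: -(f - g^2 * score) with f = -0.5 * b * x and g^2 = b
        drift = 0.5 * b * x + b * score_fn(x, t_cur)
        x = x + drift * dt + np.sqrt(b * dt) * rng.standard_normal(x.shape)
    return x

samples = reverse_sde_sample(
    lambda x, t: gaussian_score(x, t, data_mean=2.0, data_std=0.5),
    n_samples=5000, dim=4,
)
print(f"sample mean={samples.mean():.3f}, sample std={samples.std():.3f}")
```

In the notebook's pipeline the analytic score would be replaced by the trained score network, and the resulting latents would then be passed through the NB decoder and sampled as counts.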

In [None]:
# Step 1: Extract latent representations from VAE-NB
vae_nb.eval()
with torch.no_grad():
    enc_mu, _ = vae_nb.encode(log_counts)
    latent_data_nb = enc_mu.cpu().numpy()

print(f"Latent shape: {latent_data_nb.shape}")
print(f"Latent range: [{latent_data_nb.min():.2f}, {latent_data_nb.max():.2f}]")

# Step 2: Train diffusion in latent space
score_net_nb = SimpleScoreNetwork(data_dim=32, hidden_dim=256, num_layers=4).to(device)
sde_nb = VPSDE(beta_min=0.1, beta_max=20.0, T=1.0)

losses_diff_nb = train_score_network(
    score_net_nb, latent_data_nb, sde_nb, 
    num_epochs=5000, batch_size=128, lr=1e-3, device=device
)

plt.figure(figsize=(10, 4))
plt.plot(losses_diff_nb)
plt.yscale('log')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Latent Diffusion Training (for NB-VAE)')
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Step 3: Generate samples using latent diffusion + NB decoder

# Sample latent vectors from diffusion
latent_samples_nb, _ = sample_reverse_sde(
    score_net_nb, sde_nb, 
    n_samples=500, num_steps=500, data_dim=32, device=device
)

# Decode to NB parameters
vae_nb.eval()
with torch.no_grad():
    z_tensor = torch.FloatTensor(latent_samples_nb).to(device)
    # Use the mean library size for every generated sample
    gen_lib_size = torch.full((500,), float(library_sizes.mean()), device=device)
    gen_mu, gen_theta = vae_nb.decode(z_tensor, gen_lib_size)

# Sample from NB distribution to get actual counts
def sample_from_nb(mu, theta):
    """Sample counts from Negative Binomial distribution."""
    mu_np = mu.cpu().numpy()
    theta_np = theta.cpu().numpy()
    
    # NumPy parameterization: n = theta, p = theta / (theta + mu)
    p = theta_np / (theta_np + mu_np + 1e-8)
    p = np.clip(p, 1e-8, 1 - 1e-8)
    
    # Vectorized draw over the whole (n_samples, n_genes) array
    samples = np.random.negative_binomial(theta_np, p).astype(mu_np.dtype)
    
    return samples

# Generate count samples
gen_counts = sample_from_nb(gen_mu, gen_theta)
print(f"Generated {gen_counts.shape[0]} count samples")
print(f"Count range: [{gen_counts.min():.0f}, {gen_counts.max():.0f}]")
print(f"Sparsity (zeros): {(gen_counts == 0).mean()*100:.1f}%")

In [None]:
# Evaluate: Compare real vs generated count distributions
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

# Row 1: Count distributions
axes[0, 0].hist(count_data.flatten(), bins=50, density=True, alpha=0.5, label='Real', log=True)
axes[0, 0].hist(gen_counts.flatten(), bins=50, density=True, alpha=0.5, label='Generated', log=True)
axes[0, 0].set_xlabel('Count')
axes[0, 0].set_ylabel('Density (log)')
axes[0, 0].set_title('Count Distribution')
axes[0, 0].legend()

# Log-transformed comparison
axes[0, 1].hist(np.log1p(count_data).flatten(), bins=50, density=True, alpha=0.5, label='Real')
axes[0, 1].hist(np.log1p(gen_counts).flatten(), bins=50, density=True, alpha=0.5, label='Generated')
axes[0, 1].set_xlabel('log1p(Count)')
axes[0, 1].set_ylabel('Density')
axes[0, 1].set_title('Log-transformed Distribution')
axes[0, 1].legend()

# Gene means correlation
real_means = count_data.mean(axis=0)
gen_means = gen_counts.mean(axis=0)
axes[0, 2].scatter(real_means, gen_means, alpha=0.3, s=5)
max_val = max(real_means.max(), gen_means.max())
axes[0, 2].plot([0, max_val], [0, max_val], 'r--')
corr = np.corrcoef(real_means, gen_means)[0, 1]
axes[0, 2].set_xlabel('Real Gene Mean')
axes[0, 2].set_ylabel('Generated Gene Mean')
axes[0, 2].set_title(f'Gene Means (r={corr:.3f})')

# Row 2: More detailed comparisons
# Gene variances
real_vars = count_data.var(axis=0)
gen_vars = gen_counts.var(axis=0)
axes[1, 0].scatter(np.log1p(real_vars), np.log1p(gen_vars), alpha=0.3, s=5)
max_var = max(np.log1p(real_vars).max(), np.log1p(gen_vars).max())
axes[1, 0].plot([0, max_var], [0, max_var], 'r--')
# Correlate on the same log1p scale that the scatter plot shows
corr_var = np.corrcoef(np.log1p(real_vars), np.log1p(gen_vars))[0, 1]
axes[1, 0].set_xlabel('Real Gene Variance (log)')
axes[1, 0].set_ylabel('Generated Gene Variance (log)')
axes[1, 0].set_title(f'Gene Variances (r={corr_var:.3f})')

# Sparsity per gene
real_sparsity = (count_data == 0).mean(axis=0)
gen_sparsity = (gen_counts == 0).mean(axis=0)
axes[1, 1].scatter(real_sparsity, gen_sparsity, alpha=0.3, s=5)
axes[1, 1].plot([0, 1], [0, 1], 'r--')
corr_sparse = np.corrcoef(real_sparsity, gen_sparsity)[0, 1]
axes[1, 1].set_xlabel('Real Sparsity')
axes[1, 1].set_ylabel('Generated Sparsity')
axes[1, 1].set_title(f'Gene Sparsity (r={corr_sparse:.3f})')

# PCA comparison
pca = PCA(n_components=2)
real_pca = pca.fit_transform(np.log1p(count_data[:500]))  # Subsample for speed
gen_pca = pca.transform(np.log1p(gen_counts))
axes[1, 2].scatter(real_pca[:, 0], real_pca[:, 1], alpha=0.3, s=5, label='Real')
axes[1, 2].scatter(gen_pca[:, 0], gen_pca[:, 1], alpha=0.3, s=5, label='Generated')
axes[1, 2].set_xlabel('PC1')
axes[1, 2].set_ylabel('PC2')
axes[1, 2].set_title('PCA Overlay (log-transformed)')
axes[1, 2].legend()

plt.tight_layout()
plt.suptitle('Latent Diffusion + NB Decoder: Real vs Generated Counts', y=1.02, fontsize=14)
plt.show()

print("\nSummary Statistics:")
print(f"  Real counts - Mean: {count_data.mean():.1f}, Sparsity: {(count_data==0).mean()*100:.1f}%")
print(f"  Generated   - Mean: {gen_counts.mean():.1f}, Sparsity: {(gen_counts==0).mean()*100:.1f}%")

## 4. Summary & Key Takeaways

### What We Learned

**The Count Data Challenge:**
- Gene expression measurements are counts or count-derived values (UMI counts; TPM/FPKM)
- Standard diffusion assumes continuous data with Gaussian noise
- Adding noise to counts doesn't have clear biological meaning

**Solutions Implemented:**

| Approach | How It Works | Pros | Cons |
|----------|--------------|------|------|
| **Latent Diffusion** | Diffusion in VAE latent space | Well-defined, flexible | Requires VAE training |
| **NB Decoder** | Output NB(μ, θ) parameters | Proper count model | More complex |
| **ZINB Decoder** | NB + dropout probability π | Handles sparsity | Even more complex |
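To illustrate the ZINB row of the table, the sketch below extends a Negative Binomial log-probability with a dropout probability π: with probability π the model emits a structural zero, otherwise it draws from the NB. The function names are hypothetical, and this standard-library version is not the implementation of `genailab`'s `zinb_loss`.

```python
import math

def nb_log_pmf(x, mu, theta):
    """log NB(x | mu, theta) with mean mu > 0 and variance mu + mu**2 / theta."""
    return (
        math.lgamma(x + theta) - math.lgamma(theta) - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mu))
        + x * math.log(mu / (theta + mu))
    )

def zinb_log_pmf(x, mu, theta, pi):
    """log ZINB(x | mu, theta, pi) for a dropout probability 0 <= pi < 1."""
    log_nb = nb_log_pmf(x, mu, theta)
    if x == 0:
        # Zeros come from dropout or from the NB component itself
        return math.log(pi + (1.0 - pi) * math.exp(log_nb))
    # Non-zero counts can only come from the NB component
    return math.log1p(-pi) + log_nb

mu, theta, pi = 5.0, 2.0, 0.3
zinb_probs = [math.exp(zinb_log_pmf(k, mu, theta, pi)) for k in range(2000)]
zinb_total = sum(zinb_probs)
zinb_mean = sum(k * p for k, p in enumerate(zinb_probs))

print(f"P(0) under NB:   {math.exp(nb_log_pmf(0, mu, theta)):.4f}")
print(f"P(0) under ZINB: {zinb_probs[0]:.4f}")
print(f"sum={zinb_total:.6f}, mean={zinb_mean:.4f}")
```

The extra zero mass is what lets ZINB fit sparse scRNA-seq data, at the cost of one more decoder output per gene.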

**The Recommended Pipeline:**
```
Counts → log1p → Encoder → z (continuous) → Diffusion → z' → NB Decoder → NB(μ,θ) → Sample
```

### Implementation in genai-lab

- `src/genailab/model/decoders.py`: `NegativeBinomialDecoder`, `ZINBDecoder`
- `src/genailab/objectives/losses.py`: `nb_loss()`, `zinb_loss()`, `elbo_loss_nb()`, `elbo_loss_zinb()`

### Next Steps

1. **Add conditioning** - Condition on tissue, disease, perturbation
2. **Apply to real data** - PBMC3k, GTEx, scPerturb
3. **Implement ZINB** - For sparse scRNA-seq with dropout
4. **Benchmark** - Compare with scVI, scGen on standard tasks
5. **Connect to scPPDM** - Full perturbation prediction pipeline

### References

- Lopez et al. (2018) - "Deep generative modeling for single-cell transcriptomics" (scVI)
- Lotfollahi et al. (2019) - "scGen predicts single-cell perturbation responses"
- See also: `docs/incubation/generative-ai-for-gene-expression-prediction.md`