# Diffusion Models for Gene Expression Data

This notebook extends the SDE-based diffusion framework to generate **realistic gene expression data**.

**Motivation:**
- Companies such as Synthesize Bio (GEM-1) and Insilico Medicine (Precious3GPT), as well as research models such as scGPT, are building generative models for gene expression
- Applications: drug target discovery, clinical trial acceleration, in-silico perturbation experiments

**Learning objectives:**
1. Understand challenges of applying diffusion to gene expression data
2. Implement latent diffusion for high-dimensional biological data
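3. Add conditional generation (cell type, tissue, disease) — outlined under "Next steps"; not implemented in this notebook
4. Evaluate generated samples against real data (expression distribution, per-gene means, PCA); biological metrics are listed under "Next steps"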

**Prerequisites:** `02_sde_formulation.ipynb`

---

## Setup

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from sklearn.decomposition import PCA

sys.path.insert(0, str(Path('../../../src').resolve()))

from genailab.diffusion import VPSDE, train_score_network, sample_reverse_sde
from genailab.data import ToyBulkDataset

# Plot styling and reproducibility: seed both the NumPy and PyTorch RNGs.
sns.set_style('whitegrid')
np.random.seed(42)
torch.manual_seed(42)

# Pick the compute device by preference: CUDA first, then MPS, else CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f'Using device: {device}')

## 1. Load Synthetic Gene Expression Data

In [None]:
# Build the toy expression dataset that every later cell reads from, then
# report its size and the names of the available condition labels.
dataset = ToyBulkDataset(
    n=5000,
    n_genes=500,
    n_tissues=5,
    n_diseases=3,
    n_batches=4,
    seed=42,
)
print(f"Dataset: {len(dataset)} samples, {dataset.n_genes} genes")
print(f"Conditions: {list(dataset.cond.keys())}")

In [None]:
# Visualize data: a 2-component PCA projection colored by tissue label, next
# to a pooled density histogram of every expression value.
expr = dataset.x.numpy()
tissue_labels = dataset.cond['tissue'].numpy()

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
ax_pca, ax_hist = axes

# Left panel: first two principal components of the expression matrix.
pca = PCA(n_components=2)
x_pca = pca.fit_transform(expr)
scatter = ax_pca.scatter(
    x_pca[:, 0],
    x_pca[:, 1],
    c=tissue_labels,
    cmap='tab10',
    alpha=0.5,
    s=5,
)
ax_pca.set_title('PCA by Tissue')
plt.colorbar(scatter, ax=ax_pca)

# Right panel: histogram over all entries of the matrix, flattened.
ax_hist.hist(expr.flatten(), bins=50, density=True)
ax_hist.set_title('Expression Distribution')

plt.tight_layout()
plt.show()

## 2. Latent Diffusion Approach

Key insight: Run diffusion in a learned latent space (like Stable Diffusion, scPPDM).

```
Genes (500) → VAE Encoder → Latent (32) → Diffusion → VAE Decoder → Genes (500)
```

In [None]:
class GeneVAE(nn.Module):
    """Fully connected variational autoencoder over expression vectors.

    The encoder maps an input of width ``n_genes`` to the mean and
    log-variance of a per-dimension Gaussian over a ``latent_dim``-sized
    code; the decoder maps a latent code back to ``n_genes`` outputs.

    Args:
        n_genes: Width of the input and of the reconstruction.
        latent_dim: Size of the latent code.
        hidden_dim: Width of every hidden layer in encoder and decoder.
    """

    def __init__(self, n_genes, latent_dim=32, hidden_dim=256):
        super().__init__()
        self.latent_dim = latent_dim

        # Submodules are created in encoder -> heads -> decoder order so that
        # parameter initialisation under a fixed seed stays the same.
        self.encoder = nn.Sequential(
            *self._hidden_block(n_genes, hidden_dim),
            *self._hidden_block(hidden_dim, hidden_dim),
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.decoder = nn.Sequential(
            *self._hidden_block(latent_dim, hidden_dim),
            *self._hidden_block(hidden_dim, hidden_dim),
            nn.Linear(hidden_dim, n_genes),
        )

    @staticmethod
    def _hidden_block(in_features, out_features):
        """Return the layers of one Linear -> LayerNorm -> SiLU stage."""
        return [
            nn.Linear(in_features, out_features),
            nn.LayerNorm(out_features),
            nn.SiLU(),
        ]

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent Gaussian for input ``x``."""
        hidden = self.encoder(x)
        mu = self.fc_mu(hidden)
        logvar = self.fc_logvar(hidden)
        return mu, logvar

    def decode(self, z):
        """Map latent codes ``z`` back to expression space."""
        return self.decoder(z)

    def forward(self, x):
        """Encode ``x``, draw one latent sample, and decode it.

        Returns:
            A tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        # Reparameterisation: z = mu + eps * sigma, sigma = exp(logvar / 2).
        sigma = torch.exp(0.5 * logvar)
        eps = torch.randn_like(mu)
        z = mu + eps * sigma
        return self.decode(z), mu, logvar

In [None]:
# Train VAE
latent_dim = 32
vae = GeneVAE(n_genes=dataset.n_genes, latent_dim=latent_dim).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
data = dataset.x.to(device)

n_steps = 2000     # one random minibatch per step
batch_size = 128
kl_weight = 0.01   # scale of the KL term relative to the reconstruction error
log_every = 500

for epoch in tqdm(range(n_steps), desc="Training VAE"):
    # Draw a minibatch of row indices (np.random.choice samples with
    # replacement by default).
    batch_idx = np.random.choice(len(data), batch_size)
    x = data[batch_idx]

    recon, mu, logvar = vae(x)

    # Loss = mean-squared reconstruction error + weighted KL divergence to a
    # standard normal, averaged over all latent entries in the batch.
    recon_loss = F.mse_loss(recon, x)
    kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    loss = recon_loss + kl_weight * kl_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % log_every == 0:
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

In [None]:
# Get latent representations
# Encode the full dataset once without tracking gradients and keep only the
# encoder means (the log-variance output is discarded).
vae.eval()
with torch.no_grad():
    posterior_mean = vae.encode(data)[0]
latent_data = posterior_mean.cpu().numpy()
print(f"Latent shape: {latent_data.shape} (compression: {dataset.n_genes/latent_dim:.0f}x)")

In [None]:
# Train diffusion in latent space
from genailab.diffusion import SimpleScoreNetwork

# Score network sized to the latent dimension, plus the SDE it is trained on.
score_net = SimpleScoreNetwork(
    data_dim=latent_dim,
    hidden_dim=256,
    num_layers=4,
).to(device)
sde = VPSDE(beta_min=0.1, beta_max=20.0, T=1.0)

losses = train_score_network(
    score_net,
    latent_data,
    sde,
    num_epochs=5000,
    batch_size=128,
    lr=1e-3,
    device=device,
)

# Training curve on a log-scaled y-axis.
loss_fig, loss_ax = plt.subplots(figsize=(10, 4))
loss_ax.plot(losses)
loss_ax.set_yscale('log')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Latent Diffusion Training')
plt.show()

In [None]:
# Generate samples
# Draw latent codes with the reverse SDE, then map them to gene space with the
# VAE decoder.
latent_samples, _ = sample_reverse_sde(
    score_net,
    sde,
    n_samples=1000,
    num_steps=500,
    data_dim=latent_dim,
    device=device,
)

with torch.no_grad():
    # torch.as_tensor accepts both NumPy arrays and tensors (on any device),
    # unlike the legacy torch.FloatTensor constructor. The float32 cast happens
    # before the move to `device`.
    latent_batch = torch.as_tensor(latent_samples, dtype=torch.float32).to(device)
    gene_samples = vae.decode(latent_batch).cpu().numpy()

print(f"Generated {gene_samples.shape[0]} samples with {gene_samples.shape[1]} genes")

In [None]:
# Evaluate: compare real and generated samples in three panels.
real_expr = dataset.x.numpy()  # converted once; reused by every panel

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Distribution: pooled histogram of all expression values.
axes[0].hist(real_expr.flatten(), bins=50, density=True, alpha=0.5, label='Real')
axes[0].hist(gene_samples.flatten(), bins=50, density=True, alpha=0.5, label='Generated')
axes[0].legend()
axes[0].set_title('Expression Distribution')

# Gene means: per-column mean of real vs. generated samples.
real_means = real_expr.mean(axis=0)
gen_means = gene_samples.mean(axis=0)
axes[1].scatter(real_means, gen_means, alpha=0.3, s=5)
# Draw the y = x reference over the observed range of the means rather than a
# fixed interval, so it always spans the plotted points.
lo = float(min(real_means.min(), gen_means.min()))
hi = float(max(real_means.max(), gen_means.max()))
axes[1].plot([lo, hi], [lo, hi], 'r--')
corr = np.corrcoef(real_means, gen_means)[0, 1]
axes[1].set_title(f'Gene Means (r={corr:.3f})')
axes[1].set_xlabel('Real mean')
axes[1].set_ylabel('Generated mean')

# PCA: fit on real data only, then project generated samples into that basis.
pca = PCA(n_components=2)
real_pca = pca.fit_transform(real_expr)
gen_pca = pca.transform(gene_samples)
axes[2].scatter(real_pca[:, 0], real_pca[:, 1], alpha=0.3, s=5, label='Real')
axes[2].scatter(gen_pca[:, 0], gen_pca[:, 1], alpha=0.3, s=5, label='Generated')
axes[2].legend()
axes[2].set_title('PCA Overlay')

plt.tight_layout()
plt.show()

## 3. Summary & Next Steps

**What we learned:**
- Latent diffusion is more efficient than direct diffusion for high-dimensional data
- VAE provides a compressed, structured representation
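- Generated samples are compared with real data on expression distribution, per-gene means, and PCA structure; whether they capture gene-gene correlations still needs a direct check (e.g., comparing correlation matrices)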

**Next steps:**
1. Add conditional generation (tissue, disease)
2. Apply to real single-cell data (PBMC3k)
3. Evaluate with biological metrics (pathway enrichment, DE analysis)
4. Connect to scPPDM for perturbation prediction