# implémentation s-vae pour tahoe-100m

**Mouhssine Rifaki - Raphaël Rubrice - (MVA)**

reproduction de l'article *Hyperspherical Variational Auto-Encoders* (Davidson, Falorsi, De Cao, Kipf & Tomczak, UAI 2018)

In [2]:
# installations nécessaires
!pip install torch torchvision numpy scipy matplotlib pandas scikit-learn datasets -q


[notice] A new release of pip is available: 25.1.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [3]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.special import ive, i0, i1
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
from datasets import load_dataset

## partie théorique - distribution von mises-fisher

la distribution vmf sur $S^{m-1}$ est définie par:

$$q(z|\mu, \kappa) = C_m(\kappa) \exp(\kappa \mu^T z)$$

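où $C_m(\kappa) = \frac{\kappa^{m/2-1}}{(2\pi)^{m/2} I_{m/2-1}(\kappa)}$ est la constante de normalisation et $I_v$ la fonction de Bessel modifiée de première espèce d'ordre $v$.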

divergence kl entre vmf et uniforme sur la sphère:

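$$KL[vMF(\mu,\kappa) \,\|\, U(S^{m-1})] = \kappa \frac{I_{m/2}(\kappa)}{I_{m/2-1}(\kappa)} + \log C_m(\kappa) + \log \left(\frac{2\pi^{m/2}}{\Gamma(m/2)}\right)$$

le dernier terme vaut $+\log |S^{m-1}|$, où $|S^{m-1}| = \frac{2\pi^{m/2}}{\Gamma(m/2)}$ est l'aire de la sphère : la densité uniforme vaut $1/|S^{m-1}|$, donc $-\mathbb{E}_q[\log U] = +\log |S^{m-1}|$. la divergence ne dépend pas de $\mu$ et tend vers $0$ quand $\kappa \to 0$.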

In [4]:
# fonctions utilitaires pour vmf
def sample_vmf(mu, kappa, batch_size):
    """Draw samples from the von Mises-Fisher distribution vMF(mu, kappa) on S^{m-1}.

    The component w = mu^T z is drawn with the Wood/Ulrich rejection sampler (envelope
    built from eps ~ Beta((m-1)/2, (m-1)/2)), the remaining direction is uniform on
    S^{m-2}, and a Householder reflection maps e1 onto mu. Rejection runs on detached
    parameters; the accepted w is then recomputed from kappa so that gradients reach
    kappa (through w) and mu (through the reflection).

    Args:
        mu: unit-norm mean direction, shape (m,) (shared by every sample) or
            (batch_size, m) (one direction per sample). Requires m >= 2.
        kappa: concentration > 0; a Python number / 0-d tensor, or a tensor with
            batch_size elements (e.g. shape (batch_size,) or (batch_size, 1)).
        batch_size: number of samples to draw.

    Returns:
        Tensor of shape (batch_size, m) whose rows lie on the unit sphere.

    Raises:
        ValueError: if m < 2 or mu / kappa do not match batch_size.
    """
    if not torch.is_tensor(mu):
        mu = torch.as_tensor(mu, dtype=torch.get_default_dtype())
    if mu.dim() == 1:
        mu = mu.unsqueeze(0).expand(batch_size, -1)
    elif mu.dim() != 2 or mu.shape[0] != batch_size:
        raise ValueError(f"mu must have shape (m,) or ({batch_size}, m), got {tuple(mu.shape)}")
    n, m = mu.shape
    if m < 2:
        raise ValueError("vMF sampling needs a sphere of dimension m >= 2")

    if not torch.is_tensor(kappa):
        kappa = torch.as_tensor(kappa)
    kappa = kappa.to(dtype=mu.dtype, device=mu.device).reshape(-1)
    if kappa.numel() == 1:
        kappa = kappa.expand(n)
    elif kappa.numel() != n:
        raise ValueError(f"kappa must have 1 or {n} elements, got {kappa.numel()}")

    opts = dict(dtype=mu.dtype, device=mu.device)
    m1 = m - 1

    # envelope parameters; b = m1 / (2k + root) is algebraically equal to
    # (-2k + root) / m1 but avoids catastrophic cancellation for large kappa
    root = torch.sqrt(4 * kappa ** 2 + m1 ** 2)
    b = m1 / (2 * kappa + root)
    a = (m1 + 2 * kappa + root) / 4
    d = 4 * a * b / (1 + b) - m1 * math.log(m1)

    # vectorised rejection sampling: only still-pending rows are redrawn each round
    beta = torch.distributions.Beta(torch.tensor(m1 / 2, **opts), torch.tensor(m1 / 2, **opts))
    a_ng, b_ng, d_ng = a.detach(), b.detach(), d.detach()
    eps = torch.empty(n, **opts)
    pending = torch.ones(n, dtype=torch.bool, device=mu.device)
    while bool(pending.any()):
        idx = pending.nonzero(as_tuple=True)[0]
        e = beta.sample((idx.numel(),))
        u = torch.rand(idx.numel(), **opts)
        b_i = b_ng[idx]
        t = 2 * a_ng[idx] * b_i / (1 - (1 - b_i) * e)
        accept = m1 * torch.log(t) - t + d_ng[idx] >= torch.log(u)
        eps[idx[accept]] = e[accept]
        pending[idx[accept]] = False

    # w = (1 - (1+b) eps) / (1 - (1-b) eps), recomputed with b attached to the graph
    w = (1 - (1 + b) * eps) / (1 - (1 - b) * eps)

    # uniform direction on S^{m-2} for the part orthogonal to e1
    v = torch.randn(n, m1, **opts)
    v = v / v.norm(dim=1, keepdim=True)
    # clamp keeps sqrt (and its gradient) finite when |w| rounds to 1
    tangential = torch.sqrt(torch.clamp(1 - w ** 2, min=1e-12)).unsqueeze(1) * v
    z = torch.cat([w.unsqueeze(1), tangential], dim=1)

    # Householder reflection mapping e1 to mu (identity when mu == e1)
    e1 = torch.zeros(n, m, **opts)
    e1[:, 0] = 1
    u_h = e1 - mu
    u_h = u_h / (u_h.norm(dim=1, keepdim=True) + 1e-8)
    return z - 2 * (z * u_h).sum(dim=1, keepdim=True) * u_h

In [None]:
class SVAE(nn.Module):
    """S-VAE: variational auto-encoder with a von Mises-Fisher posterior on S^{latent_dim-1}
    and a uniform prior on the same sphere.

    Args:
        input_dim: dimensionality of the input vectors.
        hidden_dim: width of the first hidden layer (the second uses hidden_dim // 2).
        latent_dim: ambient dimension m of the latent sphere S^{m-1}.
    """

    class _Ive(torch.autograd.Function):
        """ive(v, kappa) = I_v(kappa) * exp(-kappa), differentiable w.r.t. kappa.

        scipy.special.ive only accepts numpy data (it fails on a tensor that requires
        grad) and has no derivative, so the backward pass uses
        d/dk ive(v, k) = ive(v - 1, k) - ive(v, k) * (v + k) / k.
        """

        @staticmethod
        def forward(ctx, v, kappa):
            ctx.v = v
            ctx.save_for_backward(kappa)
            k = kappa.detach().cpu().double().numpy()
            return torch.as_tensor(ive(v, k), dtype=kappa.dtype, device=kappa.device)

        @staticmethod
        def backward(ctx, grad_output):
            (kappa,) = ctx.saved_tensors
            v = ctx.v
            k = kappa.detach().cpu().double().numpy()
            d_ive = ive(v - 1, k) - ive(v, k) * (v + k) / k
            d_ive = torch.as_tensor(d_ive, dtype=kappa.dtype, device=kappa.device)
            # no gradient for the (Python float) order v
            return None, grad_output * d_ive

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc_mu = nn.Linear(hidden_dim // 2, latent_dim)
        self.fc_kappa = nn.Linear(hidden_dim // 2, 1)

        # decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim // 2)
        self.fc4 = nn.Linear(hidden_dim // 2, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return (mu, kappa): a unit-norm mean direction with latent_dim entries per row
        and a concentration with a trailing dimension of size 1, bounded below by 0.1."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))

        # project the mean direction onto the unit sphere
        mu = self.fc_mu(h)
        mu = mu / (torch.norm(mu, dim=-1, keepdim=True) + 1e-8)

        # strictly positive concentration
        kappa = F.softplus(self.fc_kappa(h)) + 0.1

        return mu, kappa

    def decode(self, z):
        """Map latent points back to input space (no output activation)."""
        h = F.relu(self.fc3(z))
        h = F.relu(self.fc4(h))
        return self.fc5(h)

    def kl_vmf(self, mu, kappa):
        """Batch mean of KL[vMF(mu, kappa) || U(S^{m-1})], differentiable w.r.t. kappa.

        KL = kappa * I_{m/2}(kappa) / I_{m/2-1}(kappa) + log C_m(kappa) + log|S^{m-1}|
        with |S^{m-1}| = 2 pi^{m/2} / Gamma(m/2). The KL does not depend on mu; the
        argument is kept for interface compatibility.
        """
        m = self.latent_dim
        ive_prev = self._Ive.apply(m / 2 - 1, kappa)
        ive_cur = self._Ive.apply(m / 2, kappa)

        # exp(-kappa) scalings cancel in the ratio, so nothing overflows for large kappa
        mean_cos = ive_cur / ive_prev
        # log I_{m/2-1}(kappa) computed in log space from the scaled Bessel value
        log_iv_prev = torch.log(ive_prev) + kappa
        log_cm = (m / 2 - 1) * torch.log(kappa) - (m / 2) * math.log(2 * math.pi) - log_iv_prev
        # the uniform prior density is 1 / area, hence +log(area)
        log_area = math.log(2.0) + (m / 2) * math.log(math.pi) - math.lgamma(m / 2)

        return (kappa * mean_cos + log_cm + log_area).mean()

    def forward(self, x):
        """Encode, draw one z_i ~ vMF(mu_i, kappa_i) per input row, decode.

        Returns (x_recon, mu, kappa).
        """
        mu, kappa = self.encode(x)

        # one posterior sample per example (not the first example's posterior for all)
        z = torch.cat([sample_vmf(mu[i], kappa[i, 0], 1) for i in range(x.shape[0])], dim=0)

        x_recon = self.decode(z)

        return x_recon, mu, kappa

: 

## données synthétiques - validation sur cercle

In [None]:
# synthetic data: points on S^1 (three noisy angular clusters) embedded linearly in R^100
def generate_circle_data(n_samples=1000, ambient_dim=100, noise_std=0.1, angle_std=0.3, seed=None):
    """Sample 3 angular clusters on the unit circle and embed them in R^ambient_dim.

    Args:
        n_samples: number of points.
        ambient_dim: dimension of the embedding space (previously hard-coded to 100).
        noise_std: std of the isotropic Gaussian noise added after the embedding.
        angle_std: angular std (radians) of each cluster around its centre.
        seed: optional integer seed for a reproducible draw; None uses the global
            np.random state as before.

    Returns:
        (data_high, data_2d, labels): the (n_samples, ambient_dim) embedded points, the
        (n_samples, 2) points on the circle, and the (n_samples,) labels in {0, 1, 2}.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)

    # 3 cluster centres equally spaced on the circle
    cluster_centers = np.array([0.0, 2 * np.pi / 3, 4 * np.pi / 3])
    labels = rng.choice(3, n_samples)
    angles = cluster_centers[labels] + rng.randn(n_samples) * angle_std

    # 2-d coordinates on the circle
    data_2d = np.stack([np.cos(angles), np.sin(angles)], axis=1)

    # random *linear* map R^2 -> R^ambient_dim, then additive noise
    projection = rng.randn(2, ambient_dim)
    data_high = data_2d @ projection + rng.randn(n_samples, ambient_dim) * noise_std

    return data_high, data_2d, labels

data_high, data_2d, labels = generate_circle_data()
# float32 tensor dataset, served in shuffled mini-batches of 32
data_tensor = torch.from_numpy(data_high).float()
dataset = TensorDataset(data_tensor)
dataloader = DataLoader(dataset, shuffle=True, batch_size=32)

In [None]:
# train the s-vae on the circle data: mse reconstruction + 0.1 * vMF KL
model_svae = SVAE(100, 128, 2)
optimizer = torch.optim.Adam(model_svae.parameters(), lr=0.001)

losses = []
for epoch in range(50):
    running = 0.0
    for (x,) in dataloader:
        optimizer.zero_grad()
        x_recon, mu, kappa = model_svae(x)
        loss = F.mse_loss(x_recon, x) + 0.1 * model_svae.kl_vmf(mu, kappa)
        loss.backward()
        optimizer.step()
        running += loss.item()

    # mean batch loss of the epoch
    losses.append(running / len(dataloader))
    if epoch % 10 == 0:
        print(f"epoch {epoch}: {losses[-1]:.4f}")

NameError: name 'SVAE' is not defined

In [None]:
# latent means of the circle s-vae next to the ground-truth circle and the loss curve
with torch.no_grad():
    mu_all, kappa_all = model_svae.encode(data_tensor)
mu_np = mu_all.numpy()

fig, (ax_data, ax_latent, ax_loss) = plt.subplots(1, 3, figsize=(10, 4))

ax_data.scatter(data_2d[:, 0], data_2d[:, 1], c=labels, cmap='viridis', alpha=0.5)
ax_data.set_title('données originales s1')
ax_data.axis('equal')

ax_latent.scatter(mu_np[:, 0], mu_np[:, 1], c=labels, cmap='viridis', alpha=0.5)
ax_latent.set_title('espace latent s-vae')
ax_latent.axis('equal')

ax_loss.plot(losses)
ax_loss.set_title('loss')
ax_loss.set_xlabel('epoch')

fig.tight_layout()
plt.show()

## comparaison avec vae gaussien

In [None]:
class GaussianVAE(nn.Module):
    """Baseline VAE: diagonal-Gaussian posterior with a standard normal prior.

    Uses the same MLP layout as SVAE, but the encoder emits a mean and a
    log-variance instead of a direction and a concentration.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        half = hidden_dim // 2

        # encoder: input -> hidden -> hidden/2 -> (mu, logvar)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, half)
        self.fc_mu = nn.Linear(half, latent_dim)
        self.fc_logvar = nn.Linear(half, latent_dim)

        # decoder mirrors the encoder
        self.fc3 = nn.Linear(latent_dim, half)
        self.fc4 = nn.Linear(half, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return (mu, logvar) of q(z|x)."""
        features = F.relu(self.fc2(F.relu(self.fc1(x))))
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Reparameterisation trick: z = mu + eps * sigma with eps ~ N(0, I)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes back to input space (no output activation)."""
        return self.fc5(F.relu(self.fc4(F.relu(self.fc3(z)))))

    def forward(self, x):
        """Return (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar

# train the gaussian baseline on the same circle data
model_nvae = GaussianVAE(100, 128, 2)
optimizer_nvae = torch.optim.Adam(model_nvae.parameters(), lr=0.001)

for epoch in range(50):
    for (x,) in dataloader:
        optimizer_nvae.zero_grad()
        x_recon, mu, logvar = model_nvae(x)
        recon_loss = F.mse_loss(x_recon, x)
        # closed-form KL(N(mu, sigma^2) || N(0, I)), summed over batch and latent dims
        kl_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
        loss = recon_loss + 0.001 * kl_loss
        loss.backward()
        optimizer_nvae.step()

In [None]:
# latent means side by side: gaussian baseline vs hyperspherical model
with torch.no_grad():
    mu_nvae, _ = model_nvae.encode(data_tensor)
    mu_svae, _ = model_svae.encode(data_tensor)

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, latent, title in zip(axes, (mu_nvae, mu_svae), ('n-vae (gaussien)', 's-vae (vmf)')):
    ax.scatter(latent[:, 0], latent[:, 1], c=labels, cmap='viridis', alpha=0.5)
    ax.set_title(title)
    ax.axis('equal')

fig.tight_layout()
plt.show()

## application au dataset tahoe-100m

In [None]:
# stream a small subset of tahoe-100m
print("chargement tahoe-100m...")
dataset_tahoe = load_dataset("tahoebio/Tahoe-100M", split="train", streaming=True)

n_samples = 1000  # number of non-empty cells to keep
n_genes = 500     # width of the per-cell feature vector


def expression_vector(item, n_genes=500):
    """Turn one streamed record into a fixed-width log1p expression vector.

    Records carry a variable number of expression values (empty ones are skipped
    below), so truncating item['expressions'] alone yields ragged rows that cannot be
    stacked into a matrix. NOTE(review): when the record also holds a parallel 'genes'
    list, the storage is presumably sparse (gene id, value) pairs — verify against the
    dataset card; values are then placed at their gene id (ids >= n_genes dropped) so
    that a column refers to the same gene in every cell. Otherwise the first n_genes
    values are kept and the row is zero-padded.
    """
    values = np.asarray(item['expressions'], dtype=np.float32)
    vec = np.zeros(n_genes, dtype=np.float32)
    genes = item.get('genes')
    if genes is not None and len(genes) == len(values):
        genes = np.asarray(genes, dtype=np.int64)
        keep = (genes >= 0) & (genes < n_genes)
        vec[genes[keep]] = values[keep]
    else:
        k = min(n_genes, len(values))
        vec[:k] = values[:k]
    # log1p is NaN below -1: clip so a negative sentinel value cannot poison the matrix
    return np.log1p(np.clip(vec, 0.0, None))


sample_data = []
sample_labels = []
for item in dataset_tahoe:
    if len(item['expressions']) == 0:
        continue
    sample_data.append(expression_vector(item, n_genes))
    sample_labels.append(item['drug'])
    if len(sample_data) >= n_samples:
        break

if not sample_data:
    raise RuntimeError("aucune cellule non vide lue depuis tahoe-100m")
sample_data = np.stack(sample_data)
print(f"forme données: {sample_data.shape}")

In [None]:
# tahoe preprocessing: per-gene standardisation, then projection on 50 principal components
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

scaler = StandardScaler()
pca = PCA(n_components=50)

data_scaled = scaler.fit_transform(sample_data)
data_pca = pca.fit_transform(data_scaled)
print(f"variance expliquée: {pca.explained_variance_ratio_.sum():.3f}")

# float32 torch dataset over the pca coordinates, shuffled mini-batches of 16
tahoe_tensor = torch.from_numpy(data_pca).float()
tahoe_dataset = TensorDataset(tahoe_tensor)
tahoe_loader = DataLoader(tahoe_dataset, shuffle=True, batch_size=16)

In [None]:
# s-vae on tahoe pca features; latent dimension 10 (sphere S^9)
model_tahoe = SVAE(50, 256, 10)
optimizer_tahoe = torch.optim.Adam(model_tahoe.parameters(), lr=0.0001)

tahoe_losses = []
for epoch in range(100):
    running = 0.0
    for (x,) in tahoe_loader:
        optimizer_tahoe.zero_grad()
        x_recon, mu, kappa = model_tahoe(x)
        loss = F.mse_loss(x_recon, x) + 0.01 * model_tahoe.kl_vmf(mu, kappa)
        loss.backward()
        optimizer_tahoe.step()
        running += loss.item()

    # mean batch loss of the epoch
    tahoe_losses.append(running / len(tahoe_loader))
    if epoch % 20 == 0:
        print(f"epoch {epoch}: loss = {tahoe_losses[-1]:.4f}")

In [None]:
# t-sne view of the tahoe latent means, coloured by drug
from sklearn.manifold import TSNE

with torch.no_grad():
    mu_tahoe, kappa_tahoe = model_tahoe.encode(tahoe_tensor)
    mu_tahoe_np = mu_tahoe.numpy()

# t-sne (dimension 10 -> 2); perplexity must stay below the number of points
tsne = TSNE(n_components=2, perplexity=min(30, len(mu_tahoe_np) - 1), random_state=42)
mu_tsne = tsne.fit_transform(mu_tahoe_np)

# deterministic drug -> integer mapping: set() iteration order of strings changes between
# interpreter runs (hash randomisation), which reshuffled the colours; dict lookup is O(1)
unique_drugs = sorted(set(sample_labels), key=str)
drug_to_id = {drug: i for i, drug in enumerate(unique_drugs)}
drug_colors = [drug_to_id[d] for d in sample_labels]

# NOTE: 'tab20' has 20 colours, so ids repeat colours when there are more than 20 drugs
plt.figure(figsize=(10, 6))
plt.scatter(mu_tsne[:, 0], mu_tsne[:, 1], c=drug_colors, cmap='tab20', alpha=0.6, s=10)
plt.title('espace latent s-vae sur tahoe (projection t-sne)')
plt.colorbar(label='drug id')
plt.show()

In [None]:
# how concentrated are the tahoe posteriors?
kappa_values = kappa_tahoe.numpy().ravel()

fig, (ax_hist, ax_trace) = plt.subplots(1, 2, figsize=(8, 4))

ax_hist.hist(kappa_values, bins=30, alpha=0.7)
ax_hist.set_title('distribution des kappa')
ax_hist.set_xlabel('kappa')

ax_trace.scatter(np.arange(len(kappa_values)), kappa_values, alpha=0.3, s=1)
ax_trace.set_title('kappa par échantillon')
ax_trace.set_xlabel('échantillon')
ax_trace.set_ylabel('kappa')

fig.tight_layout()
plt.show()

print(f"kappa moyen: {kappa_values.mean():.3f} +/- {kappa_values.std():.3f}")

## évaluation quantitative

In [None]:
# reconstruction metrics, evaluated deterministically at the posterior mean direction
# (going through forward() would add sampling noise to the metric)
with torch.no_grad():
    mu_eval, _ = model_tahoe.encode(tahoe_tensor)
    x_recon_tahoe = model_tahoe.decode(mu_eval)
    mse = F.mse_loss(x_recon_tahoe, tahoe_tensor).item()

# per-sample pearson correlation between input and reconstruction (vectorised);
# a constant row yields NaN, excluded from the mean and from the histogram
x_np = tahoe_tensor.numpy()
r_np = x_recon_tahoe.numpy()
x_c = x_np - x_np.mean(axis=1, keepdims=True)
r_c = r_np - r_np.mean(axis=1, keepdims=True)
with np.errstate(invalid='ignore', divide='ignore'):
    correlations = (x_c * r_c).sum(axis=1) / np.sqrt((x_c ** 2).sum(axis=1) * (r_c ** 2).sum(axis=1))
mean_corr = np.nanmean(correlations)

print(f"mse reconstruction: {mse:.4f}")
print(f"corrélation moyenne: {mean_corr:.4f}")

# histogram of the per-sample correlations
plt.figure(figsize=(6, 4))
plt.hist(correlations[np.isfinite(correlations)], bins=30, alpha=0.7)
plt.axvline(mean_corr, color='red', linestyle='--', label=f'moyenne: {mean_corr:.3f}')
plt.title('corrélations reconstruction')
plt.xlabel('corrélation')
plt.legend()
plt.show()

In [None]:
# k-means on the tahoe latent means, k chosen by silhouette score
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

k_range = range(2, 11)
silhouette_scores = [
    silhouette_score(
        mu_tahoe_np,
        KMeans(n_clusters=k, random_state=42, n_init=10).fit_predict(mu_tahoe_np),
    )
    for k in k_range
]

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(k_range, silhouette_scores, 'o-')
ax.set_xlabel('nombre de clusters')
ax.set_ylabel('score silhouette')
ax.set_title('qualité clustering espace latent')
ax.grid(True, alpha=0.3)
plt.show()

best_k = k_range[int(np.argmax(silhouette_scores))]
print(f"nombre optimal de clusters: {best_k}")

## génération et interpolation

In [None]:
# generation from the uniform prior on the latent sphere
def generate_from_uniform_sphere(model, n_samples=10, dim=None):
    """Decode points drawn uniformly on the latent sphere S^{dim-1}.

    Args:
        model: object exposing decode(z); its latent_dim attribute, if any, is the
            default sphere dimension.
        n_samples: number of latent points to draw.
        dim: ambient latent dimension; None means model.latent_dim (10 when the model
            has no such attribute, which was the previous hard-coded default).

    Returns:
        The decoder output for the sampled points, computed without gradients.
    """
    if dim is None:
        dim = getattr(model, "latent_dim", 10)
    # a normalised isotropic Gaussian is uniformly distributed on the sphere
    z = torch.randn(n_samples, dim)
    z = z / torch.norm(z, dim=1, keepdim=True)

    with torch.no_grad():
        generated = model.decode(z)

    return generated

# decode five random points of the latent sphere
generated_samples = generate_from_uniform_sphere(model=model_tahoe, n_samples=5)

# spherical interpolation between two points
def spherical_interpolation(z1, z2, n_steps=10):
    """Geodesic (slerp) path from z1 to z2 on the unit sphere.

    Both endpoints are normalised first. Nearly identical endpoints fall back to linear
    interpolation. For (nearly) antipodal endpoints the geodesic is not unique; the path
    then goes through a deterministic direction orthogonal to z1, instead of dividing by
    sin(pi) ~ 0.

    Args:
        z1, z2: 1-d tensors of the same length.
        n_steps: number of points, t evenly spaced in [0, 1] (endpoints included).

    Returns:
        Tensor of shape (n_steps, len(z1)).
    """
    z1 = z1 / torch.norm(z1)
    z2 = z2 / torch.norm(z2)

    cos_omega = torch.clamp(torch.dot(z1, z2), -1.0, 1.0)
    omega = torch.acos(cos_omega)
    t = torch.linspace(0, 1, n_steps, dtype=z1.dtype, device=z1.device).unsqueeze(1)

    if omega <= 1e-6:
        return (1 - t) * z1 + t * z2

    # unit direction in the plane of (z1, z2), orthogonal to z1; its norm equals sin(omega)
    ortho = z2 - cos_omega * z1
    ortho_norm = torch.norm(ortho)
    if ortho_norm < 1e-6:
        # antipodal: pick the basis vector least aligned with z1 and orthogonalise it
        basis = torch.zeros_like(z1)
        basis[torch.argmin(z1.abs())] = 1
        ortho = basis - torch.dot(basis, z1) * z1
        ortho_norm = torch.norm(ortho)
    ortho = ortho / ortho_norm

    return torch.cos(t * omega) * z1 + torch.sin(t * omega) * ortho

# decode the spherical interpolation path between the latent means of two cells
idx1, idx2 = 0, 10
with torch.no_grad():
    z_interp = spherical_interpolation(mu_tahoe[idx1], mu_tahoe[idx2])
    x_interp = model_tahoe.decode(z_interp)

print(f"interpolation entre échantillons {idx1} et {idx2}")
print(f"forme interpolation: {x_interp.shape}")

## résumé des résultats

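- implémentation complète du s-vae avec distribution von mises-fisher
- validation sur données synthétiques circulaires : l'espace latent retrouve la structure en cercle (constat visuel)
- application à tahoe-100m (sous-échantillon de 1000 cellules) : extraction de représentations latentes dont la pertinence biologique reste à évaluer
- sur les données circulaires, l'espace latent hypersphérique semble mieux capturer la structure que celui du vae gaussien (comparaison qualitative uniquement)
- le clustering dans l'espace latent fait apparaître des groupes ; leur correspondance avec les drugs/perturbations reste à quantifier (p. ex. ari/nmi avec les labels `drug`)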

In [None]:
# persist the tahoe s-vae: weights, optimiser state and loss curve
checkpoint = {
    'model_state_dict': model_tahoe.state_dict(),
    'optimizer_state_dict': optimizer_tahoe.state_dict(),
    'losses': tahoe_losses,
}
torch.save(checkpoint, 'svae_tahoe_checkpoint.pt')

print("modèle sauvegardé")