In [2]:
import os
import numpy as np
import pandas as pd
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error

# --- File paths ---
base_path = r"C:/Document/Serieux/Travail/python_work/cEBMF_additional_simulation_VAE"
Y_path = os.path.join(base_path, "slice4_Y.csv")  # normalized read counts
X_path = os.path.join(base_path, "slice4_X.csv")  # spatial x, y coordinates


def _read_matrix(csv_path):
    """Read a CSV whose first column is a row index; return the other columns as an ndarray."""
    frame = pd.read_csv(csv_path, index_col=0)
    return frame.values


# --- Load data ---
# NOTE(review): .values discards the row index, so Y and X are paired by row
# position only. This assumes both CSVs list spots in the same order — confirm.
Y = _read_matrix(Y_path)  # expected shape (n_spots, n_genes)
X = _read_matrix(X_path)  # expected shape (n_spots, 2)

# --- Standardize features ---
# StandardScaler z-scores each column of Y independently; X is left unscaled.
Y = torch.tensor(StandardScaler().fit_transform(Y), dtype=torch.float32)
X = torch.tensor(X, dtype=torch.float32)

# --- VAE and cVAE Models ---
class VAE(nn.Module):
    """Fully connected variational autoencoder with a diagonal-Gaussian latent.

    Architecture: input_dim -> 128 -> 64 -> (mu, logvar) of size latent_dim.
    The decoder maps latent_dim -> 64 -> 128 -> input_dim with a linear output
    layer (no activation on the reconstruction).

    Args:
        input_dim: number of input features per row.
        latent_dim: dimensionality of the latent space.
    """

    def __init__(self, input_dim, latent_dim=16):
        super().__init__()
        # Layers are created in the order encoder -> mu -> logvar -> decoder.
        # Creation order fixes how the global RNG is consumed during init.
        enc_widths = (input_dim, 128, 64)
        enc_layers = []
        for fan_in, fan_out in zip(enc_widths[:-1], enc_widths[1:]):
            enc_layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        self.encoder = nn.Sequential(*enc_layers)

        # Heads for the posterior mean and log-variance.
        self.mu = nn.Linear(64, latent_dim)
        self.logvar = nn.Linear(64, latent_dim)

        dec_widths = (latent_dim, 64, 128)
        dec_layers = []
        for fan_in, fan_out in zip(dec_widths[:-1], dec_widths[1:]):
            dec_layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        dec_layers.append(nn.Linear(128, input_dim))  # linear output layer
        self.decoder = nn.Sequential(*dec_layers)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * sigma with eps ~ N(0, I) and sigma = exp(logvar / 2)."""
        sigma = (0.5 * logvar).exp()
        eps = torch.randn_like(sigma)
        return mu + eps * sigma

    def forward(self, x):
        """Return (reconstruction, mu, logvar) for a batch x."""
        hidden = self.encoder(x)
        mu, logvar = self.mu(hidden), self.logvar(hidden)
        recon = self.decoder(self.reparameterize(mu, logvar))
        return recon, mu, logvar

class CVAE(nn.Module):
    """Conditional VAE: both the encoder and the decoder receive the condition c.

    Encoder: (input_dim + cond_dim) -> 256 -> 128 -> 64 -> (mu, logvar).
    Decoder: (latent_dim + cond_dim) -> 64 -> 128 -> input_dim, with a linear
    output layer.

    Args:
        input_dim: number of data features per row.
        cond_dim: number of conditioning features per row.
        latent_dim: dimensionality of the latent space.
    """

    def __init__(self, input_dim, cond_dim, latent_dim=16):
        super().__init__()
        # Creation order (encoder -> mu -> logvar -> decoder) is kept fixed so
        # parameter initialisation draws from the global RNG in the same order.
        enc_widths = (input_dim + cond_dim, 256, 128, 64)
        enc_layers = []
        for fan_in, fan_out in zip(enc_widths[:-1], enc_widths[1:]):
            enc_layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        self.encoder = nn.Sequential(*enc_layers)

        self.mu = nn.Linear(64, latent_dim)
        self.logvar = nn.Linear(64, latent_dim)

        dec_widths = (latent_dim + cond_dim, 64, 128)
        dec_layers = []
        for fan_in, fan_out in zip(dec_widths[:-1], dec_widths[1:]):
            dec_layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        dec_layers.append(nn.Linear(128, input_dim))  # linear output layer
        self.decoder = nn.Sequential(*dec_layers)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * sigma with eps ~ N(0, I) and sigma = exp(logvar / 2)."""
        sigma = (0.5 * logvar).exp()
        eps = torch.randn_like(sigma)
        return mu + eps * sigma

    def forward(self, x, c):
        """Return (reconstruction, mu, logvar); c is concatenated on dim 1 at both ends."""
        hidden = self.encoder(torch.cat([x, c], dim=1))
        mu, logvar = self.mu(hidden), self.logvar(hidden)
        z = self.reparameterize(mu, logvar)
        recon = self.decoder(torch.cat([z, c], dim=1))
        return recon, mu, logvar

# --- Loss and training ---
def vae_loss(recon_x, x, mu, logvar):
    """Return the summed VAE objective: squared-error reconstruction plus KL.

    The reconstruction term is the sum (not mean) of squared errors. The KL
    term is KL(N(mu, exp(logvar)) || N(0, I)), summed over every element.
    """
    reconstruction_term = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_term = -0.5 * (1 + logvar - mu ** 2 - torch.exp(logvar)).sum()
    return reconstruction_term + kl_term

def train_vae(model, X, cond=None, epochs=100, lr=1e-3):
    """Fit a VAE or CVAE with full-batch Adam on ``vae_loss``.

    Each epoch is a single optimizer step over all of ``X`` (no mini-batching).

    Args:
        model: a module returning ``(recon, mu, logvar)``. It is called as
            ``model(X)``, or as ``model(X, cond)`` when ``cond`` is given.
        X: tensor used both as model input and as reconstruction target.
        cond: optional conditioning tensor passed as the second argument.
        epochs: number of full-batch optimization steps.
        lr: Adam learning rate.

    Returns:
        The same ``model`` object, trained in place.
    """
    model.train()
    opt = optim.Adam(model.parameters(), lr=lr)
    model_args = (X,) if cond is None else (X, cond)
    for _ in range(epochs):
        opt.zero_grad()
        x_recon, mu, logvar = model(*model_args)
        vae_loss(x_recon, X, mu, logvar).backward()
        opt.step()
    return model

# --- Train and extract embeddings ---
input_dim, cond_dim = Y.shape[1], X.shape[1]

# Plain VAE on expression only; the embedding is the encoder's posterior mean.
vae = VAE(input_dim=input_dim, latent_dim=16)
train_vae(vae, Y)
with torch.no_grad():
    vae_hidden = vae.encoder(Y)
    vae_mu = vae.mu(vae_hidden).numpy()

# Conditional VAE: the encoder input is expression concatenated with the
# coordinates X along dim 1, the same layout CVAE.forward uses.
cvae = CVAE(input_dim=input_dim, cond_dim=cond_dim, latent_dim=16)
train_vae(cvae, Y, cond=X)
with torch.no_grad():
    cvae_hidden = cvae.encoder(torch.cat([Y, X], dim=1))
    cvae_mu = cvae.mu(cvae_hidden).numpy()

# --- Save embeddings ---
vae_out = os.path.join(base_path, "slice4_vae_embeddings.csv")
cvae_out = os.path.join(base_path, "slice4_cvae_embeddings.csv")
pd.DataFrame(vae_mu).to_csv(vae_out, index=False)
pd.DataFrame(cvae_mu).to_csv(cvae_out, index=False)

print("Saved VAE and cVAE embeddings.")

# --- Optional: NCF-style embedding from long-format data ---
class NCFLayer(nn.Module):
    """Neural collaborative filtering regressor over (spot, gene) index pairs.

    Each spot and each gene gets a learned embedding. The two embeddings of a
    pair are concatenated and passed through a small MLP that predicts one
    scalar per pair.

    Args:
        n_spots: size of the ``user_emb`` table (one row per spot).
        n_genes: size of the ``item_emb`` table (one row per gene).
        embedding_dim: width of both embedding tables.
    """

    def __init__(self, n_spots, n_genes, embedding_dim=16):
        super().__init__()
        # Recommender-system naming: "users" are spots, "items" are genes.
        self.user_emb = nn.Embedding(n_spots, embedding_dim)
        self.item_emb = nn.Embedding(n_genes, embedding_dim)
        self.fc = nn.Sequential(
            nn.Linear(2 * embedding_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, user_idx, item_idx):
        """Predict one value per (user_idx[k], item_idx[k]) pair.

        Args:
            user_idx: 1-D index tensor of spots, shape (batch,).
            item_idx: 1-D index tensor of genes, same shape as ``user_idx``.

        Returns:
            Tensor of shape (batch,).
        """
        u = self.user_emb(user_idx)
        i = self.item_emb(item_idx)
        x = torch.cat([u, i], dim=1)
        # Squeeze only the trailing size-1 output dim. A bare .squeeze() would
        # turn a batch of one into a 0-d scalar, which mismatches a (1,)
        # target in nn.MSELoss (e.g. a final DataLoader batch of size 1).
        return self.fc(x).squeeze(-1)

class NCFDataset(Dataset):
    """Expose the nonzero entries of a 2-D tensor as (row, col, value) samples.

    Entries exactly equal to zero are skipped. Samples are ordered as returned
    by ``nonzero``, i.e. row-major.
    """

    def __init__(self, Y):
        nz_rows, nz_cols = Y.nonzero(as_tuple=True)
        self.rows = nz_rows
        self.cols = nz_cols
        self.values = Y[nz_rows, nz_cols]

    def __len__(self):
        return self.rows.shape[0]

    def __getitem__(self, idx):
        return (self.rows[idx], self.cols[idx], self.values[idx])

# Prepare long-form gene expression for NCF.
# NCFDataset keeps only the entries of its input that are exactly nonzero.
# NOTE(review): Y was z-scored with StandardScaler earlier in this script, so
# most entries are likely nonzero here and "nonzero" no longer means
# "observed". Confirm this is the intended training set.
Y_long = Y.clone()
ncf_data = NCFDataset(Y_long)
ncf_model = NCFLayer(n_spots=Y.shape[0], n_genes=Y.shape[1], embedding_dim=16)
optimizer = optim.Adam(ncf_model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

# Train NCF: mini-batch regression of (spot, gene) -> value.
ncf_model.train()
loader = DataLoader(ncf_data, batch_size=512, shuffle=True)
for epoch in range(10):
    for spot, gene, value in loader:
        optimizer.zero_grad()
        loss = loss_fn(ncf_model(spot, gene), value)
        loss.backward()
        optimizer.step()

# Save spot embeddings: the learned rows of the spot ("user") embedding table.
ncf_model.eval()
with torch.no_grad():
    spot_embeddings = ncf_model.user_emb.weight.detach().numpy()
ncf_out = os.path.join(base_path, "slice4_ncf_embeddings.csv")
pd.DataFrame(spot_embeddings).to_csv(ncf_out, index=False)

print("Saved NCF embeddings.")


Saved VAE and cVAE embeddings.
Saved NCF embeddings.
