# Cell-Wise Variational Autoencoder

The following model is trained on data from 20-day-old mice.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from itertools import cycle

from models import Cell_Wise_VAE, VAE_Single_Dataset
import pickle
import umap.umap_ as umap
import time

## Data Preparation

In [None]:
# Load the workshop dataset; later cells rely on its 'sample_id' and 'cell_type' columns
mod1 = pd.read_csv("./cellwiseVAE_workshop.csv")
# Preview the first rows (shown as the cell output)
mod1.head()

In [None]:
# MinMaxScaler rescales every feature column to the range [0, 1].
# Per the original note, this matches the sigmoid() activation in the model
# definition (the model code is not part of this notebook).
scaler_mod1 = MinMaxScaler()

# Scale the measurement columns only; the two metadata columns are set aside
X = mod1.drop(columns=['sample_id', 'cell_type'])
X = pd.DataFrame(scaler_mod1.fit_transform(X), columns=X.columns)

# Re-attach the metadata as the last two columns (.values copies positionally,
# bypassing index alignment)
X = X.assign(
    sample_id=mod1['sample_id'].values,
    cell_type=mod1['cell_type'].values,
)

In [None]:
# Save the fitted scaler to disk so that the same scaling operation can be applied later
# (it is read back from this same path before the inverse transform further down)
with open('./scaler_mod1.pkl', 'wb') as f:
    pickle.dump(scaler_mod1, f)

In [None]:
# Extracting final features and labels into arrays:
# one feature matrix (all columns except the two metadata columns)
# and two per-row metadata arrays, reused below to build the stratification labels
features_1 = X.drop(columns=["sample_id", "cell_type"]).values
cell_types_1 = X["cell_type"].values
sample_ids_1 = X["sample_id"].values

# Number of features going into the VAE (column count of the feature matrix)
num_features = features_1.shape[1]
print(num_features)

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Main loss function for the VAE
def mmvae_loss_function(recons, truths, kl_loss, beta=1.0, device=None):
    """Combine per-modality reconstruction losses with a weighted KL term.

    Args:
        recons: Mapping from modality name to the reconstructed tensor.
        truths: Mapping from modality name to the ground-truth tensor.
            Modalities that have no entry in ``recons`` are skipped.
        kl_loss: KL-divergence term; it enters the total as ``beta * kl_loss``.
        beta: Weight applied to ``kl_loss``.
        device: Device on which the reconstruction-loss accumulator is created.
            When ``None`` (default) the device of ``kl_loss`` is used if it is
            a tensor, otherwise the CPU. Passing a device explicitly behaves
            as before.

    Returns:
        dict with the keys ``'total_loss'`` (reconstruction + beta * KL),
        ``'kl_loss'`` (the value passed in), ``'recon_loss_total'`` (sum of
        the per-modality losses) and ``'recon_losses'`` (modality -> loss).
    """
    # A hard-coded "cuda" default fails on CPU-only machines, so infer the
    # device from the inputs when the caller does not specify one.
    if device is None:
        device = kl_loss.device if torch.is_tensor(kl_loss) else torch.device("cpu")

    recon_losses = {}  # Per-modality reconstruction losses
    recon_loss_total = torch.tensor(0.0, device=device)

    # Looping through the modalities (trivial for single-modality)
    for mod, truth_full in truths.items():
        # Only modalities that were actually reconstructed contribute
        if mod not in recons:
            continue

        # Truncate both tensors to the common batch size
        recon_full = recons[mod]
        n = min(recon_full.shape[0], truth_full.shape[0])

        # MSE reconstruction loss, averaged over every element of the batch
        loss_mod = F.mse_loss(recon_full[:n], truth_full[:n], reduction='mean')
        recon_losses[mod] = loss_mod
        recon_loss_total += loss_mod

    # Total VAE loss = reconstruction + beta * KL
    total_loss = recon_loss_total + beta * kl_loss

    return {
        'total_loss': total_loss,
        'kl_loss': kl_loss,
        'recon_loss_total': recon_loss_total,
        'recon_losses': recon_losses
    }

In [None]:
# Stratified train/validation split

# One stratum label per row, built from its cell type and sample ID, so that
# every (cell_type, sample_id) group keeps the same share in both subsets
combined_strata_x1 = [
    f"{ctype}_{sid}" for ctype, sid in zip(cell_types_1, sample_ids_1)
]

# Split row indices rather than rows, 80% train / 20% validation.
# NOTE(review): no random_state is passed, so the split differs between runs.
train_idx_x1, val_idx_x1 = train_test_split(
    np.arange(len(features_1)),
    test_size=0.2,
    shuffle=True,
    stratify=combined_strata_x1,
)

# Materialise the two feature matrices from the selected indices
train_x1, val_x1 = features_1[train_idx_x1], features_1[val_idx_x1]

In [None]:
def _repeat_loader(loader):
    """Yield batches from ``loader`` forever, starting a fresh pass each time it is exhausted.

    Unlike ``itertools.cycle``, which stores the items of the first pass and
    replays them, every pass re-iterates ``loader``. A shuffling DataLoader is
    therefore reshuffled each epoch, and no batch is kept in memory.
    Stops instead of spinning if ``loader`` yields nothing.
    """
    while True:
        yielded_any = False
        for batch in loader:
            yielded_any = True
            yield batch
        if not yielded_any:
            return


# Training loader (repeated for infinite iteration during training loops)
# This helps extend the logic to multiple modalities and prevents the StopIteration Error
train_loader_x1 = DataLoader(VAE_Single_Dataset(train_x1), batch_size=128, shuffle=True)
# Number of steps per epoch (taken before wrapping; the generator has no length)
num_steps = len(train_loader_x1)
train_loader_x1 = _repeat_loader(train_loader_x1)

# Validation loader
val_loader_x1 = DataLoader(VAE_Single_Dataset(val_x1), batch_size=128, shuffle=True)

## VAE Training

In [None]:
# Instantiate the Cell_Wise_VAE model and move it to the selected device (GPU/CPU).
# The input width comes from the data (num_features) instead of a hard-coded 15,
# so the model always matches the feature matrix it is trained on.
model = Cell_Wise_VAE(input_dim=num_features, latent_dim=3, use_mean=True).to(device)
print("using device : ", device)
print("latent dim: ", model.latent_dim)

# Training Hyperparameters
num_epochs = 100
BETA = 0.001 # KL weight (beta-penalty)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# NOTE(review): the scheduler is stepped once per epoch, and step_size (150)
# exceeds num_epochs (100), so the learning rate never decays in this run.
scheduler = StepLR(optimizer, step_size=150, gamma=0.5)

# Lists for storing performance metrics over epochs
train_losses, train_recon_totals, train_kl_losses = [], [], []
val_losses, val_recon_totals, val_kl_losses = [], [], []

# Dictionaries for storing reconstruction losses per modality
train_recon_losses_mod = {"mod1": []}
val_recon_losses_mod = {"mod1": []}

#################
# Training Loop #
#################

for epoch in range(1, num_epochs + 1):
    start = time.time()
    model.train() # Set model to training mode

    # Running sums for the epoch
    epoch_total_loss, epoch_recon_loss, epoch_kl_loss = 0.0, 0.0, 0.0

    # Per-modality running sums and batch counts
    train_mod_sums = {m: 0.0 for m in train_recon_losses_mod}
    train_mod_counts = {m: 0 for m in train_recon_losses_mod}

    ##################
    # Training Steps #
    ##################
    for step in range(num_steps):
        # Get the next training batch from the infinite iterator
        x1 = next(train_loader_x1)[0].to(device)

        # Forward pass through the VAE
        optimizer.zero_grad()
        # Reconstructed outputs for each modality and KL divergence term for the latent distribution
        recons, kl_loss, _, _ = model(x1=x1)
        # Ground-Truth tensors for reconstruction
        truths = {"mod1": x1}

        # Computing the VAE loss (reconstruction + beta * KL)
        losses = mmvae_loss_function(recons, truths, kl_loss, beta=BETA, device=device)
        total_loss = losses["total_loss"]

        # Backpropagation
        total_loss.backward()
        optimizer.step()

        # Storing the step losses
        epoch_total_loss += total_loss.item()
        epoch_recon_loss += losses["recon_loss_total"].item()
        epoch_kl_loss += losses["kl_loss"].item()

        # Accumulating per-modality reconstruction loss
        for mod, val in losses["recon_losses"].items():
            train_mod_sums[mod] += val.item()
            train_mod_counts[mod] += 1

    # Averaging over all the steps executed in an epoch
    epoch_total_loss /= num_steps
    epoch_recon_loss /= num_steps
    epoch_kl_loss /= num_steps

    # Saving the epoch metrics
    train_losses.append(epoch_total_loss)
    train_recon_totals.append(epoch_recon_loss)
    train_kl_losses.append(epoch_kl_loss)

    # Saving the per-modality training history. The sums were accumulated above
    # but previously never stored, so the history lists stayed empty.
    for mod, count in train_mod_counts.items():
        train_recon_losses_mod[mod].append(
            train_mod_sums[mod] / count if count > 0 else np.nan
        )

    ####################
    # Validation Steps #
    ####################
    model.eval()
    val_mod_sums = {"mod1": 0.0}
    val_mod_counts = {"mod1": 0}
    val_kl_sums = {"mod1": 0.0}
    val_kl_counts = {"mod1": 0}

    with torch.no_grad():

        # Iterating through all the modalities (only one in this example)
        for mod_name, (x_name, loader) in zip(["mod1"], [("x1", val_loader_x1)]):
            for batch in loader:
                batch = batch[0].to(device)

                # Forward pass depending on the modality considered
                if x_name == "x1":
                    recons, kl_loss, _, _ = model(x1=batch)

                truths = {mod_name: batch}
                # Computing the validation loss
                losses = mmvae_loss_function(recons=recons,truths=truths,kl_loss=kl_loss,beta=BETA,device=device)

                val_mod_sums[mod_name] += losses["recon_loss_total"].item()
                val_mod_counts[mod_name] += 1
                val_kl_sums[mod_name] += losses["kl_loss"].item()
                val_kl_counts[mod_name] += 1

    # Computing the average validation metrics per modality
    # (mean of the per-batch values; NaN when a modality saw no batch)
    avg_per_mod_recon = {
        mod: (val_mod_sums[mod] / val_mod_counts[mod] if val_mod_counts[mod] > 0 else np.nan)
        for mod in val_mod_sums
    }
    avg_per_mod_kl = {
        mod: (val_kl_sums[mod] / val_kl_counts[mod] if val_kl_counts[mod] > 0 else np.nan)
        for mod in val_kl_sums
    }

    # Saving the per-modality validation history (previously never stored)
    for mod, avg in avg_per_mod_recon.items():
        val_recon_losses_mod[mod].append(avg)

    # Computing the total validation loss for this epoch
    val_recon_total_epoch = np.sum(list(avg_per_mod_recon.values()))
    val_kl_total_epoch = np.sum(list(avg_per_mod_kl.values()))
    val_total_loss_epoch = val_recon_total_epoch + BETA * val_kl_total_epoch

    # Storing the metrics
    val_losses.append(val_total_loss_epoch)
    val_recon_totals.append(val_recon_total_epoch)
    val_kl_losses.append(val_kl_total_epoch)

    # Step learning-rate scheduler (once per epoch)
    scheduler.step()
    stop = time.time()

    # Printing out the results for this epoch
    print(f"Epoch {epoch}/{num_epochs} | time {stop - start:.2f}s")
    print(f"Train: Total={epoch_total_loss:.6f} | Recon={epoch_recon_loss:.6f} | KL={epoch_kl_loss:.6f}")
    print(f" Val : Total={val_total_loss_epoch:.6f} | Recon={val_recon_total_epoch:.6f} | KL={val_kl_total_epoch:.6f}")


In [None]:
# Switch the model to evaluation mode before the post-training cells
model.eval()

# Total loss per epoch, training and validation on the same axes
plt.plot(train_losses, label="Train Loss")
plt.plot(val_losses, label="Validation Loss")
plt.gca().set(xlabel="Epoch", ylabel="Loss", title="Training Loss")
# Anchor the legend just outside the right edge of the axes
plt.legend(loc="upper left", bbox_to_anchor=(1.05, 1))
plt.show()

In [None]:
# Plotting the KL divergence component over epochs (training vs. validation)
plt.plot(train_kl_losses, label="Train KLD Loss", color="black")
plt.plot(val_kl_losses, label="Val KLD Loss", color="gray")

plt.xlabel("Epoch")
plt.ylabel("Loss")
# Title typo fixed ("Losse" -> "Loss"), matching the total-loss plot title
plt.title("KL: Training Loss")
# Anchor the legend just outside the right edge of the axes
plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt.show()

## Latent-Space Encoding

In [None]:
def encode_latent_space_by_mod_column_mod1(model, sampled_df, feature_columns):
    """
    Run the mod1 encoder over a dataframe and return its latent coordinates.

    Args:
        model: Trained model exposing ``encoder_body_mod1``,
            ``encoder_head_mod1`` and ``latent_dim``. It is switched to
            evaluation mode as a side effect.
        sampled_df: Dataframe holding one row per cell.
        feature_columns: Names of the columns of ``sampled_df`` fed to the encoder.

    Returns:
        Dataframe indexed like ``sampled_df`` (sorted by index) with columns
        ``z1 .. z{latent_dim}``: the first ``latent_dim`` outputs of the
        encoder head, which this notebook treats as the latent means.
    """
    model.eval()

    # Use the device the model's weights live on rather than a notebook-level
    # `device` global, so the helper works wherever the model has been moved.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    # Mapping the modality name to the corresponding encoder component
    encoder_map = {"mod1": (model.encoder_body_mod1, model.encoder_head_mod1)}
    latent_columns = [f"z{i + 1}" for i in range(model.latent_dim)]

    latent_dfs = []
    # Looping over each modality (only 1 considered in the example)
    for body, head in encoder_map.values():
        # Converting the selected feature columns to a float32 tensor on the model's device
        x = torch.tensor(
            sampled_df[feature_columns].values,
            dtype=torch.float32,
            device=model_device,
        )

        with torch.no_grad():
            # Forward pass through the encoder layers; keep the leading latent_dim outputs
            stats = head(body(x))
            mu = stats[:, :model.latent_dim]

        latent_dfs.append(
            pd.DataFrame(mu.cpu().numpy(), index=sampled_df.index, columns=latent_columns)
        )

    # Combine (if multiple modalities are used) and order rows by index
    return pd.concat(latent_dfs).sort_index()

In [None]:
# Drop the cells labelled "unknown" before encoding the latent space;
# per the original note, these act like extra noise
X = X.loc[X["cell_type"] != "unknown"]

# Every column except the trailing two (sample_id and cell_type) is a feature
feature_cols = X.columns[:-2].tolist()

# Latent representation of the remaining cells for the chosen modality
latent_df_mod1 = encode_latent_space_by_mod_column_mod1(
    model=model,
    sampled_df=X,
    feature_columns=feature_cols,
)

In [None]:
# Load the same scaler used during training
# (the pickle written earlier in this notebook; only unpickle files you trust)
with open('./scaler_mod1.pkl', 'rb') as f:
    scaler_mod1 = pickle.load(f)

In [None]:
# Undo the MinMax scaling so the feature columns return to their original range
data_temp_mod1 = X.drop(columns=["sample_id", "cell_type"])
data_temp_mod1 = pd.DataFrame(
    scaler_mod1.inverse_transform(data_temp_mod1.values),
    columns=data_temp_mod1.columns,
    index=data_temp_mod1.index,
)

# Re-attach the metadata as the last two columns (positional copy via .values)
data_temp_mod1 = data_temp_mod1.assign(
    sample_id=X["sample_id"].values,
    cell_type=X["cell_type"].values,
)

In [None]:
# Final merging of the original (unscaled) data with its latent representation:
# the unscaled rows are selected in the latent dataframe's index order, then the
# latent columns (z1, z2, ...) are appended side by side
merged_mod1 = pd.concat([data_temp_mod1.loc[latent_df_mod1.index], latent_df_mod1], axis=1)
# Display the merged table as the cell output
merged_mod1

## Plotting

In [None]:
def plot_latent_umap(latent_vecs, labels, sample_ids, color_by="cell_type"):
    """Plot a UMAP embedding of the latent vectors, coloured by categorical label.

    Two figures are shown: one scatter of all points, and one grid with a
    panel per unique sample ID (two panels per row, as many rows as needed).

    Args:
        latent_vecs: Array of latent vectors, one row per cell.
        labels: Per-row labels; points are only drawn in the first figure
            when ``color_by == "cell_type"`` and the labels are strings.
        sample_ids: Per-row sample identifiers used to facet the second figure.
        color_by: Colouring mode; only ``"cell_type"`` draws the first figure.
    """
    reducer = umap.UMAP(random_state=3624)
    Z_umap = reducer.fit_transform(latent_vecs)

    # Fixed colours for the known cell types; anything else falls back to grey
    cell_type_colors = {
        'T1': '#D32F2F',
        'T2': '#FF9800',
        'MZ': '#2196F3',
        'FM': '#9C27B0'
    }

    # Convert once instead of rebuilding the arrays inside every loop iteration
    labels_arr = np.asarray(labels)
    sample_arr = np.asarray(sample_ids)

    # Plot 1: UMAP colored by pre-set labels
    plt.figure(figsize=(10, 6))
    if color_by == "cell_type" and isinstance(labels[0], str):
        for lbl in np.unique(labels_arr):
            mask = labels_arr == lbl
            color = cell_type_colors.get(lbl, "#999999")
            plt.scatter(Z_umap[mask, 0], Z_umap[mask, 1], label=lbl, c=color, s=10, alpha=0.7, edgecolors='none')
        plt.legend(title="Cell Type", bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.title("Latent Space UMAP")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # Plot 2: UMAPs faceted by Sample ID
    unique_samples = np.unique(sample_arr)
    ncols = 2
    # Grow the grid with the number of samples; a fixed 1x2 grid silently
    # dropped every sample beyond the second.
    nrows = max(1, int(np.ceil(len(unique_samples) / ncols)))

    # squeeze=False always returns a 2-D array of axes, so flatten() is safe for any grid
    fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 6 * nrows),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()

    unique_labels = np.unique(labels_arr)
    for ax, sample in zip(axes, unique_samples):
        sample_mask = sample_arr == sample
        for lbl in unique_labels:
            mask = (labels_arr == lbl) & sample_mask
            color = cell_type_colors.get(lbl, "#999999")
            ax.scatter(Z_umap[mask, 0], Z_umap[mask, 1], label=lbl, c=color, s=10, alpha=0.7, edgecolors='none')
        ax.set_title(f"Sample: {sample}")
        ax.grid(True)

    # Hide the panels that have no sample assigned
    for j in range(len(unique_samples), len(axes)):
        axes[j].axis("off")

    # Every panel draws every label, so the first panel's handles cover the legend
    handles, labels_legend = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels_legend, title="Cell Type", bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.suptitle("Latent Space UMAP, faceted by Sample ID")
    plt.tight_layout()
    plt.show()

In [None]:
def plot_latent_umap_continuous(latent_vecs, values, sample_ids, color_label="GFP-A Rag2"):
    """Plot a UMAP embedding of the latent vectors, coloured by a continuous value.

    Two figures are shown: one scatter of all points, and one grid with a
    panel per unique sample ID (two panels per row, as many rows as needed)
    sharing a single colour scale.

    Args:
        latent_vecs: Array of latent vectors, one row per cell.
        values: Per-row numeric values mapped to the "viridis" colour map.
        sample_ids: Per-row sample identifiers used to facet the second figure.
        color_label: Text used for the colour bar label and the titles.
    """
    reducer = umap.UMAP(random_state=3624)
    Z_umap = reducer.fit_transform(latent_vecs)

    # Convert once instead of rebuilding the arrays inside the loop
    values_arr = np.asarray(values)
    sample_arr = np.asarray(sample_ids)

    # Plot 1: UMAP colored by a continuous label, here GFP
    plt.figure(figsize=(10, 6))
    scatter = plt.scatter(Z_umap[:, 0], Z_umap[:, 1], c=values_arr, s=10, alpha=0.7, cmap="viridis")
    cbar = plt.colorbar(scatter)
    cbar.set_label(color_label)
    plt.title(f"Latent Space UMAP, by {color_label}")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # Plot 2: UMAPs faceted by Sample ID
    unique_samples = np.unique(sample_arr)
    ncols = 2
    # Grow the grid with the number of samples; a fixed 1x2 grid silently
    # dropped every sample beyond the second.
    nrows = max(1, int(np.ceil(len(unique_samples) / ncols)))

    # squeeze=False always returns a 2-D array of axes, so flatten() is safe for any grid
    fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 6 * nrows),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()

    # One colour range for all panels so they are directly comparable
    vmin, vmax = np.min(values_arr), np.max(values_arr)

    for ax, sample in zip(axes, unique_samples):
        sample_mask = sample_arr == sample
        scatter = ax.scatter(Z_umap[sample_mask, 0], Z_umap[sample_mask, 1],
                             c=values_arr[sample_mask], s=10,
                             alpha=0.7, cmap="viridis", vmin=vmin, vmax=vmax)
        ax.set_title(f"Sample: {sample}")
        ax.grid(True)

    # Hide the panels that have no sample assigned
    for j in range(len(unique_samples), len(axes)):
        axes[j].axis("off")

    # Shared colour bar in its own axes on the right; the last panel's scatter
    # is representative because all panels use the same vmin/vmax
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    fig.colorbar(scatter, cax=cbar_ax, label=color_label)

    plt.suptitle(f"Latent Space UMAP by {color_label}, faceted by Sample ID")
    plt.show()


In [None]:
# 2D UMAP of the latent space, coloured by cell type.
# Labels and sample IDs are selected in the latent dataframe's row order.
plot_latent_umap(
    latent_df_mod1.values,
    merged_mod1["cell_type"].loc[latent_df_mod1.index].values,
    merged_mod1["sample_id"].loc[latent_df_mod1.index].values,
    color_by="cell_type",
)

In [None]:
# 2D UMAP of the latent space, coloured by the continuous "GFP-A Rag2" column.
# Values and sample IDs are selected in the latent dataframe's row order, as in
# the cell-type plot, so each point is guaranteed to pair with its own row.
plot_latent_umap_continuous(
    latent_vecs=latent_df_mod1.values,
    values=merged_mod1.loc[latent_df_mod1.index, "GFP-A Rag2"].values,
    sample_ids=merged_mod1.loc[latent_df_mod1.index, "sample_id"].values,
    color_label="GFP-A Rag2"
)

In [None]:
# 3D rendering of the latent space (z1, z2, z3), one colour per cell type
cell_type_colors = {
    'T1': '#D32F2F',
    'T2': '#FF9800',
    'MZ': '#2196F3',
    'FM': '#9C27B0',
}

fig = px.scatter_3d(
    merged_mod1,
    x=merged_mod1['z1'],
    y=merged_mod1['z2'],
    z=merged_mod1['z3'],
    color='cell_type',
    title=f'3D Representation of Latent Space, sample of {len(X)} cells',
    color_discrete_map=cell_type_colors,
)

# Small, fully opaque markers
fig.update_traces(marker={"size": 2, "opacity": 1})

# Legend settings: constant item size, fixed item width, click isolates a trace
fig.update_layout(
    legend={
        "itemsizing": 'constant',
        "itemwidth": 30,
        "itemclick": 'toggleothers',
    }
)

# Open the interactive figure in the browser
fig.show(renderer='browser')

In [None]:
# 3D rendering of the latent space (z1, z2, z3), coloured by the "GFP-A Rag2" column
fig = px.scatter_3d(
    merged_mod1,
    x=merged_mod1['z1'],
    y=merged_mod1['z2'],
    z=merged_mod1['z3'],
    color='GFP-A Rag2',
    title=f'3D Representation of Latent Space, sample of {len(X)} cells',
    color_continuous_scale='viridis',
)

# Small, fully opaque markers
fig.update_traces(marker={"size": 2, "opacity": 1})

# Same legend settings as the cell-type figure
fig.update_layout(
    legend={
        "itemsizing": 'constant',
        "itemwidth": 30,
        "itemclick": 'toggleothers',
    }
)

# Open the interactive figure in the browser
fig.show(renderer='browser')