# **Data pipeline**
<a target="_blank" href="https://colab.research.google.com/github/raphaelrubrice/scVAE_mva2025/blob/raph/data_pipeline.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## **Colab setup**

In [12]:
!git clone https://github.com/raphaelrubrice/scVAE_mva2025.git
%cd scVAE_mva2025
!python -m pip install -r requirements.txt
!git checkout raph

Cloning into 'scVAE_mva2025'...
remote: Enumerating objects: 426, done.[K
remote: Counting objects: 100% (107/107), done.[K
remote: Compressing objects: 100% (65/65), done.[K
remote: Total 426 (delta 60), reused 66 (delta 42), pack-reused 319 (from 1)[K
Receiving objects: 100% (426/426), 1.68 MiB | 21.24 MiB/s, done.
Resolving deltas: 100% (238/238), done.
/content/scVAE_mva2025/scVAE_mva2025
Branch 'raph' set up to track remote branch 'raph' from 'origin'.
Switched to a new branch 'raph'


### Test Downloads

In [13]:
# Fetch the raw PBMC datasets through the project's downloader.
# In the recorded run, datasets already on disk are reported as
# "Already downloaded/extracted" rather than fetched again.
from data_pipeline.src.downloader import run_downloads

run_downloads()


Saving extracted PBMC datasets to: /content/scVAE_mva2025/data_pipeline/data/pbmc_raw

✓ Already downloaded/extracted: /content/scVAE_mva2025/data_pipeline/data/pbmc_raw/CD34
------------------------------------------------------------
✓ Already downloaded/extracted: /content/scVAE_mva2025/data_pipeline/data/pbmc_raw/CD19_B
------------------------------------------------------------
✓ Already downloaded/extracted: /content/scVAE_mva2025/data_pipeline/data/pbmc_raw/CD56_NK
------------------------------------------------------------
✓ Already downloaded/extracted: /content/scVAE_mva2025/data_pipeline/data/pbmc_raw/CD4_helper
------------------------------------------------------------
✓ Already downloaded/extracted: /content/scVAE_mva2025/data_pipeline/data/pbmc_raw/CD4_CD25
------------------------------------------------------------
✓ Already downloaded/extracted: /content/scVAE_mva2025/data_pipeline/data/pbmc_raw/CD4_CD45RA_CD25neg
--------------------------------------------------

In [14]:
# Smoke test: load a single raw dataset (CD4_CD45RO) as an AnnData object.
# NOTE(review): the recorded run of this cell raised FileNotFoundError
# (matrix.mtx not found below this folder) — confirm that the dataset has been
# downloaded/extracted before running it.
from pathlib import Path

from data_pipeline.src.config import DATASETS
from data_pipeline.src.load_anndata import load_anndata

# Metadata entry and on-disk folder of the dataset under test.
meta = DATASETS["CD4_CD45RO"]
folder = Path("data_pipeline/data/pbmc_raw/CD4_CD45RO")
adata = load_anndata(folder, meta)

# Last expression of the cell: show the object's summary.
adata

FileNotFoundError: Could not locate matrix.mtx below data_pipeline/data/pbmc_raw/CD4_CD45RO

In [None]:
# Combine the individual datasets through the project's run_combine helper.
# Shard writing and writing of the combined object are both switched on, and
# var harmonization is switched off. The call returns two values, kept as
# `combined` and `collection`.
from data_pipeline.src.combine import run_combine

combined, collection = run_combine(
    do_write_shards=True,
    write_combined=True,
    harmonize_var=False
)

In [None]:
# Read the combined dataset back from disk and display its summary.
# NOTE(review): this file is presumably the one written by
# run_combine(write_combined=True) — confirm the output path.
import scanpy as sc
adata_combined = sc.read_h5ad("data_pipeline/data/pbmc_processed/pbmc_combined.h5ad")
adata_combined

In [None]:
# Build the cross-validation data loaders from the processed shards.
# The call returns three values: `kept_idx`, `folds` (indexed per fold as a
# (train_loader, val_loader) pair by the cells below) and `test_loader`.
from data_pipeline.src.dataloader import build_cv_dataloaders

kept_idx, folds, test_loader = build_cv_dataloaders(
    shard_dir="data_pipeline/data/pbmc_processed/shards",
    label_maps_path="data_pipeline/data/pbmc_processed/label_maps.json",
    batch_size=256,
    one_hot=True,
    # The trailing comma was missing here, which made the whole cell a
    # SyntaxError.
    # NOTE(review): `pin_m` is passed through unchanged — confirm that it is
    # the keyword name expected by build_cv_dataloaders.
    pin_m=True,
    filter_genes=True,
    max_genes=5000,
)

# **Cross-Validation of architectures**

In [None]:
import torch
from tqdm.auto import tqdm

from mixture_vae.distributions import NormalDistribution, UniformDistribution, NegativeBinomial
# NOTE(review): training_mvae is called by the cross-validation cells below but
# was never imported. It is assumed to live next to training_momixvae in
# mixture_vae.training — confirm.
from mixture_vae.training import training_mvae
# TODO(review): plot_loss_components is also called below without being
# imported anywhere in this notebook; import it from its defining module.

## **Cross-Validated scVAE**

In [None]:
from mixture_vae.mvae import MixtureVAE

In [None]:
# Problem setup shared by every architecture evaluated below.
# The original assigned `kept_idx` itself to `input_dim`, but `input_dim` is
# used as a tensor dimension further down, so it has to be an integer count.
# NOTE(review): assumes kept_idx is the collection of retained gene indices
# returned by build_cv_dataloaders — confirm.
input_dim = len(kept_idx) # genes
hidden_dim = 100 # hidden neurons per layer
n_components = 9 # 9 clusters are assumed
latent_dim = 100 # dimension latent space
n_layers = 2 # number of encoding and decoding layers

# Prior on latent: standard Gaussian in R^latent_dim
mu = torch.zeros((1,latent_dim))
std = torch.ones((1,latent_dim))
prior_latent = NormalDistribution({"mu":mu,
                                    "std":std})

# Prior on input gene counts: NB for each gene
p = 0.5 * torch.ones((1,input_dim)) # 50/50 chance of expression
# The original read an undefined `X` here. The per-gene mean is accumulated
# batch by batch from the first fold's training loader instead, so the full
# count matrix never has to be held in memory.
# NOTE(review): assumes each batch is a mapping holding a dense
# (cells, genes) tensor under "X" — confirm against build_cv_dataloaders.
reference_train_loader, _ = folds[0]
gene_sums = torch.zeros(input_dim)
n_cells = 0
for batch in reference_train_loader:
    counts = batch["X"].float()
    gene_sums += counts.sum(dim=0)
    n_cells += counts.shape[0]
assert n_cells > 0, "the training loader yielded no cells"
r = (gene_sums / n_cells).reshape(1,-1) # prior = average count in train data
prior_input = NegativeBinomial({"p":p,
                                "r":r})

# Prior on cluster repartitions (mixture): Assume balanced
# cluster classes = Uniform on [0,1]
a = torch.zeros((1,n_components))
b = torch.ones((1,n_components))
prior_categorical = UniformDistribution({"a":a,
                                          "b":b})

# Posterior on latent: Gaussian in R^latent_dim
# (here the assumed posterior equals the assumed prior,
# but it could have been different)
mu = torch.zeros((1,latent_dim))
std = torch.ones((1,latent_dim))
posterior_latent = NormalDistribution({"mu":mu,
                                        "std":std})

In [None]:
# Cross-validate the scVAE: train one fresh MixtureVAE per fold, keep the
# trained model and save the loss breakdown of each fold as a PDF.

# Training hyper-parameters (identical for every fold).
EPOCHS = 500
BETA_KL = 1.0
WARMUP_BETA = 200
PATIENCE = 20
TOL = 5e-3

# Constructor arguments shared by all folds.
scvae_kwargs = dict(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    n_components=n_components,
    n_layers=n_layers,
    prior_latent=prior_latent,
    prior_input=prior_input,
    prior_categorical=prior_categorical,
    posterior_latent=posterior_latent,
)

cv_scVAE = []
for fold in tqdm(range(len(folds))):
    print("\nInstantiating scVAE..")
    scVAE = MixtureVAE(**scvae_kwargs)

    # Each fold provides its own (train, validation) loader pair.
    train_loader, val_loader = folds[fold]

    optimizer = torch.optim.Adam(scVAE.parameters(), lr=1e-4)

    # model_type=0 is the value this notebook passes for the scVAE; its
    # meaning is defined by training_mvae.
    scVAE, losses, parts, clusters, all_betas = training_mvae(
        train_loader,
        val_loader,
        scVAE,
        optimizer,
        epochs=EPOCHS,
        beta_kl=BETA_KL,
        warmup=WARMUP_BETA,
        patience=PATIENCE,
        tol=TOL,
        show_loss_every=10,
        model_type=0,
        track_clusters=True,
    )
    cv_scVAE.append(scVAE)

    # Plot training and validation loss components for this fold.
    plot_loss_components(
        parts["train"],
        parts["val"],
        all_betas,
        title="Loss Breakdown",
        save_path=f"./scVAE_{fold}_losses.pdf",
    )



## **Cross-Validated independent Mixture of Mixtures**

In [None]:
from mixture_vae.mvae import ind_MoMVAE

In [None]:
# Cross-validate the independent Mixture of Mixtures: train one ind_MoMVAE per
# fold, keep the trained model and save the loss breakdown of each fold.

# Training hyper-parameters (loop-invariant, so set once).
EPOCHS = 500
BETA_KL = 1.0
WARMUP_BETA = 200
PATIENCE = 20
TOL = 5e-3

cv_IndMoM = []
for fold in tqdm(list(range(len(folds)))):
    print("\nInstantiating IndMoM..")
    # One parameter set per sub-model, with 2, 3, 5 and 9 components.
    # The original handed the single notebook-level `prior_categorical`
    # (built for 9 components, bounds of shape (1, 9)) to every sub-model.
    # Each sub-model now gets a uniform prior sized to its own component
    # count, mirroring how the per-level priors are built for MoMixVAE.
    IndMoM = ind_MoMVAE(
            PARAMS = [
            {"input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "n_components": k,
            "n_layers": n_layers,
            "prior_latent": prior_latent,
            "prior_input": prior_input,
            "prior_categorical": UniformDistribution({"a": torch.zeros((1, k)),
                                                      "b": torch.ones((1, k))}),
            "posterior_latent": posterior_latent}
            for k in [2, 3, 5, 9]]
        )

    # Each fold provides its own (train, validation) loader pair.
    train_loader, val_loader = folds[fold]

    optimizer = torch.optim.Adam(IndMoM.parameters(), lr=1e-3)

    # model_type=1 is the value this notebook passes for the independent
    # mixture model; its meaning is defined by training_mvae.
    IndMoM, losses, parts, clusters, all_betas = training_mvae(
        train_loader,
        val_loader,
        IndMoM,
        optimizer,
        epochs=EPOCHS,
        beta_kl=BETA_KL,
        warmup=WARMUP_BETA,
        patience=PATIENCE,
        tol=TOL,
        show_loss_every=10,
        model_type=1,
        track_clusters=True,
    )
    cv_IndMoM.append(IndMoM)
    # plot training and validation losses
    plot_loss_components(parts["train"],
                         parts["val"],
                         all_betas,
                         title="Loss Breakdown",
                         save_path=f"./IndMoM_{fold}_losses.pdf")



## **Cross-Validated Mixture of Mixtures**

In [None]:
from mixture_vae.mvae import MoMixVAE
from mixture_vae.training import training_momixvae

In [None]:
# Number of mixture components at each level of the hierarchy.
# NOTE(review): `hierarchy_components` was used but never defined in this
# notebook. The values below mirror the component counts of the independent
# mixture model ([2, 3, 5, 9]) — confirm they are the intended hierarchy.
hierarchy_components = [2, 3, 5, 9]

# Prior on cluster repartitions (mixture): Assume balanced
# cluster classes = Uniform on [0,1]
# currently we assume the same for all levels
# Comprehensions are used so that the notebook-level `n_components`,
# `prior_categorical`, `a` and `b` are not overwritten as a side effect.
all_prior_categorical = [
    UniformDistribution({"a": torch.zeros((1, level_components)),
                         "b": torch.ones((1, level_components))})
    for level_components in hierarchy_components
]

# Posterior on latent: Gaussian in R^latent_dim
# (here the assumed posterior equals the assumed prior,
# but it could have been different)
# currently we assume the same for all levels
# One separate distribution object is created per level; `posterior_latent`,
# `mu` and `std` from the setup cell are left untouched.
all_posterior_latent = [
    NormalDistribution({"mu": torch.zeros((1, latent_dim)),
                        "std": torch.ones((1, latent_dim))})
    for _ in hierarchy_components
]

In [None]:
# Cross-validate the hierarchical Mixture of Mixtures: train one MoMixVAE per
# fold, keep the trained model and save the loss breakdown of each fold.

# Training hyper-parameters (loop-invariant, so set once).
EPOCHS = 500
BETA_KL = 1.0
WARMUP_BETA = 200
PATIENCE = 20
TOL = 5e-3

cv_MoMix = []
for fold in tqdm(list(range(len(folds)))):
    print("\nInstantiating MoMix..")
    MoMix = MoMixVAE(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        hierarchy_components=hierarchy_components,
        n_layers=n_layers,
        prior_latent=prior_latent,
        prior_input=prior_input,
        all_prior_categorical=all_prior_categorical,
        all_posterior_latent=all_posterior_latent
    )

    # Each fold provides its own (train, validation) loader pair.
    train_loader, val_loader = folds[fold]

    optimizer = torch.optim.Adam(MoMix.parameters(), lr=1e-3)

    # The original call passed `dataloader`, `val_dataloader`, `model` and
    # `WARMUP`, none of which exist in this notebook. The fold's loaders, the
    # model built above and WARMUP_BETA are what the sibling cells use.
    MoMix, losses, parts, clusters, all_betas = training_momixvae(
        train_loader,
        val_loader,
        MoMix,
        optimizer,
        epochs=EPOCHS,
        beta_kl=BETA_KL,
        warmup=WARMUP_BETA,
        patience=PATIENCE,
        tol=TOL,
        show_loss_every=10,
        track_clusters=True,
    )
    cv_MoMix.append(MoMix)
    # plot training and validation losses
    plot_loss_components(parts["train"],
                         parts["val"],
                         all_betas,
                         title="Loss Breakdown",
                         save_path=f"./MoMix_{fold}_losses.pdf")



# **Figure 1: Generative and Clustering performances**

In [None]:
# Figure 1: compare the cross-validated models on the held-out test loader.
from mixture_vae.figure1 import plot_figure1

# Maps a label to the per-fold model lists, in the order scVAE, independent
# mixture, hierarchical mixture. The label is a raw string because "\m" is not
# a valid escape sequence in a normal string literal; the resulting text is
# unchanged.
cv_dico = {r"$\mathcal{N}$(0,I)": [cv_scVAE, cv_IndMoM, cv_MoMix]}

plot_figure1(cv_dico, test_loader, save_path="./Figure1.pdf")

In [None]:
# from data_pipeline.src.dataloader import build_dataloaders

# train_loader, val_loader, test_loader = build_dataloaders(
#     shard_dir="data_pipeline/data/pbmc_processed/shards",
#     label_maps_path="data_pipeline/data/pbmc_processed/label_maps.json",
#     batch_size=512,
#     one_hot=True,
# )

# for i, batch in enumerate(train_loader):
#         print("Example batch shapes:")
#         print("X:", batch["X"].shape, "y1:", batch["y1"].shape, "y4:", batch["y4"].shape)
#         break