# Neuron fate prediction from combinatorial morphogen treatment

In this notebook, we show how {class}`~cellflow.model.CellFlow` can be used to predict the outcome of **neuron fate programming experiments**. We use the dataset from [Lin, Jansen et al.](https://www.biorxiv.org/content/10.1101/2023.12.12.571318v2), which contains scRNA-seq data from a morphogen screen in NGN2-induced neurons (iNeurons). The treatment conditions combined modulators of anterior-posterior (AP) patterning (RA, CHIR99021, XAV-939, FGF8) with modulators of dorso-ventral (DV) patterning (BMP4, SHH), each applied at multiple concentrations. We use CellFlow to predict neuron distributions for held-out combinations of morphogens.

## Preparing the data

```python
import os
from functools import partial

import anndata as ad
import flax.linen as nn
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import umap
from ott.solvers import utils as solver_utils
from scipy.sparse import csr_matrix
from sklearn.preprocessing import OneHotEncoder

import cellflow
import cellflow.preprocessing as cfpp
```

```python
adata = cellflow.datasets.ineurons()
print(adata)
```

```
AnnData object with n_obs × n_vars = 178437 × 4000
    obs: 'sample', 'species', 'gene_count', 'tscp_count', 'mread_count', 'bc1_well', 'bc2_well', 'bc3_well', 'bc1_wind', 'bc2_wind', 'bc3_wind', 'plateID', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'n_genes', 'percent_mito', 'n_counts', 'outlier', 'mt_outlier', 'doublet_score', 'predicted_doublet', 'leiden_4', 'leiden_10', 'merged_clusters', 'final_clustering', 'final_clustering_reset', 'merged_clusters_from_10', 'M_XAV', 'M_CHIR', 'M_RA', 'M_FGF8', 'M_BMP4', 'M_SHH', 'M_PM', 'media', 'sample-CellID', 'Neuron_type', 'Division', 'Region', 'FGF8_conc', 'FGF8_start_time', 'FGF8_end_time', 'XAV_conc', 'XAV_start_time', 'XAV_end_time', 'RA_conc', 'RA_start_time', 'RA_end_time', 'CHIR_conc', 'CHIR_start_time', 'CHIR_end_time', 'SHH_conc', 'SHH
```

### Encoding of morphogen treatment conditions
Now we need to create **representations for the perturbation conditions** to be used by CellFlow. For this, we use a one-hot encoding of each morphogen, multiplied by the log-transformed concentration (`log1p`), and store these vectors in `adata.uns["conditions"]`. We also add one boolean column per morphogen/concentration pair to `adata.obs`, indicating which cells received that treatment.

```python
morphogens = ["FGF8", "XAV", "RA", "CHIR", "SHH", "BMP4"]
# One-hot encoder over morphogen names (categories are sorted alphabetically)
dataset_enc = OneHotEncoder()
dataset_enc.fit(np.array(morphogens).reshape(-1, 1))
condition_dict = {}  # condition key -> representation vector
condition_keys = []  # names of the boolean obs columns

for mol in morphogens:
    mol_onehot = (
        dataset_enc.transform(np.array([mol]).reshape(-1, 1)).toarray().flatten()
    )
    concs = adata.obs[mol + "_conc"].unique()
    for conc in concs:
        cond_cells = adata.obs[mol + "_conc"] == conc
        cond_id = mol + "_" + str(conc)
        if cond_id in condition_keys:
            continue
        condition_keys.append(cond_id)
        # Boolean column marking the cells that received this morphogen at this concentration
        adata.obs[cond_id] = cond_cells
        # Representation: one-hot morphogen identity scaled by log1p(concentration)
        condition_dict[cond_id] = mol_onehot * np.log1p(float(conc))

adata.uns["conditions"] = condition_dict
print(adata)
```

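The encoding above can be reproduced in isolation with plain NumPy, which makes the resulting vectors easier to inspect. In this sketch, the helper `encode_condition` and the example concentrations are hypothetical (they are not part of CellFlow or read from the dataset), and the one-hot index follows list order, whereas `OneHotEncoder` orders categories alphabetically. A concentration of 0 yields an all-zero vector, because log1p(0) = 0.

```python
import numpy as np

# Same morphogens as above; here the one-hot index simply follows list order
MORPHOGENS = ["FGF8", "XAV", "RA", "CHIR", "SHH", "BMP4"]


def encode_condition(morphogen: str, conc: float) -> np.ndarray:
    """Hypothetical helper: one-hot morphogen identity scaled by log1p(concentration)."""
    onehot = np.zeros(len(MORPHOGENS))
    onehot[MORPHOGENS.index(morphogen)] = 1.0
    return onehot * np.log1p(float(conc))


# A treatment combination is represented as a set of such vectors, one per morphogen
combination = [("RA", 100.0), ("BMP4", 50.0)]  # illustrative values only
reps = np.stack([encode_condition(m, c) for m, c in combination])
print(reps.shape)  # (2, 6)
```

Scaling by `log1p` keeps the magnitude of the vectors in a moderate range even when concentrations differ by orders of magnitude.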

### Train / test split
Next, we split the data into **training and test sets**. We know from the original study that RA+BMP4 is a particularly interesting combination, as it resulted in new cell states that were not seen with any individual morphogen treatment. We now want to test whether CellFlow can predict these new cell states, so we hold out all conditions containing the RA+BMP4 combination (RA+BMP4 and RA+CHIR+BMP4) from training.

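```python
# Combinations of interest that are held out from training
exclude_combs = ["RA+BMP4", "RA+CHIR+BMP4"]
# Training set: all other combinations (copy, so later modifications do not act on a view)
adata_train = adata[~adata.obs["comb"].isin(exclude_combs)].copy()
# Evaluation set: the held-out combinations only
adata_eval = adata[adata.obs["comb"].isin(exclude_combs)].copy()
```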
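The held-out set is defined by combination labels. As a hedged illustration, the rule "hold out every combination that contains both RA and BMP4" can be written as a set-containment test on labels of the form `A+B+...`. The labels below are invented and not read from `adata.obs["comb"]`, and the helper `contains_all` is hypothetical.

```python
import numpy as np

# Invented combination labels in the same "A+B" format as the "comb" column
combs = np.array(["RA", "BMP4", "RA+BMP4", "RA+CHIR+BMP4", "CHIR+BMP4", "RA+SHH"])


def contains_all(label: str, required: set[str]) -> bool:
    """True if every required morphogen appears in a '+'-separated combination label."""
    return required.issubset(label.split("+"))


held_out = np.array([contains_all(c, {"RA", "BMP4"}) for c in combs])
print(combs[held_out])  # ['RA+BMP4' 'RA+CHIR+BMP4']
print(combs[~held_out])  # everything else stays in the training set
```

Deriving the exclusion list from a rule like this, rather than typing it by hand, can help avoid accidentally leaving a larger combination that contains RA+BMP4 in the training set.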

To prevent leakage of information through a latent space computed with the held-out conditions, we recompute the PCA on the training set only and then project the test set into this space.

```python
# Fit the PCA on training cells only, then project the held-out cells into that space
cfpp.centered_pca(adata_train, n_comps=30, method="rapids")
cfpp.project_pca(adata_eval, adata_train)
```
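Conceptually, fitting the PCA on the training cells and projecting the held-out cells means the held-out data influence neither the mean nor the principal axes. Below is a minimal NumPy sketch of that idea on random data, assuming a plain centered PCA via SVD; it is not the implementation of `cfpp.centered_pca` or `cfpp.project_pca`, whose details may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
X_train = rng.normal(size=(200, 50))  # toy "training" expression matrix
X_test = rng.normal(size=(40, 50))  # toy "held-out" cells
n_comps = 10

# Fit: mean and principal axes are computed from the training cells only
mean = X_train.mean(axis=0)
_, _, vt = np.linalg.svd(X_train - mean, full_matrices=False)
pcs = vt[:n_comps].T  # (n_genes, n_comps) loadings

# Project: held-out cells are centered with the *training* mean
train_scores = (X_train - mean) @ pcs
test_scores = (X_test - mean) @ pcs
print(train_scores.shape, test_scores.shape)  # (200, 10) (40, 10)
```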

### Generating a source distribution
To train the model and generate predictions, we also need a **source distribution** to generate from. In most other use cases, we can simply use the control condition for this. Here, however, every condition (including the absence of morphogens) produces its own distinct cell state distribution. The control condition therefore cannot be viewed as an "unperturbed" state, and it does not necessarily make sense to use it as a source distribution. Instead, we use a random distribution whose cells are the means of random subsamples of the training data in PCA space. As a result, the model essentially acts in a fully generative way, generating new cell distributions from a random source.

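```python
n_src_cells = 10000  # number of cells in the random source distribution
n_samples = 1000  # number of training cells averaged per source cell
sample_rep = "X_pca"
rng = np.random.default_rng(0)  # fixed seed for reproducibility

# Each source cell is the mean of a random subsample (drawn with replacement)
# of training cells in PCA space
samples = []
for _ in range(n_src_cells):
    idx = rng.choice(adata_train.n_obs, n_samples)
    samples.append(adata_train.obsm[sample_rep][idx, :].mean(axis=0))
samples = np.array(samples)

# Metadata for the source cells: all morphogen concentrations are zero
samples_obs = pd.DataFrame(
    {col: 0.0 for col in [mol + "_conc" for mol in morphogens]},
    index=range(samples.shape[0]),
)
samples_obs["dataset"] = "CTRL"
samples_obs["media"] = "CTRL"
samples_obs["condition"] = "CTRL"

# AnnData for the source cells; gene expression is an empty (all-zero) sparse matrix
adata_ctrl = sc.AnnData(
    X=csr_matrix((samples.shape[0], adata_train.n_vars)), obs=samples_obs
)
adata_ctrl.obsm[sample_rep] = samples
adata_ctrl.var_names = adata_train.var_names

# Combine training and source cells, and flag the source cells with a boolean "CTRL" column
adata_train_full = ad.concat([adata_train, adata_ctrl], join="outer")
adata_train_full.obs["CTRL"] = adata_train_full.obs["dataset"] == "CTRL"
adata_ctrl.obs["CTRL"] = True
# ad.concat does not keep .uns by default, so copy the condition representations over
adata_train_full.uns, adata_eval.uns = adata.uns, adata.uns
```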
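Averaging `n_samples` random cells shrinks the spread of each source cell around the global mean by roughly a factor of `sqrt(n_samples)`, so the source is a tight cloud near the centre of the training data rather than a copy of any particular condition. The toy NumPy example below illustrates this on synthetic data (not the iNeuron PCA coordinates) and also shows a vectorized alternative to the sampling loop; the `toy_` names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for PCA coordinates: 5000 cells x 30 components
toy_pcs = rng.normal(loc=2.0, scale=3.0, size=(5000, 30))

toy_n_src, toy_n_samples = 1000, 100
# All subsample indices at once, drawn with replacement: shape (toy_n_src, toy_n_samples)
idx = rng.integers(0, toy_pcs.shape[0], size=(toy_n_src, toy_n_samples))
toy_source = toy_pcs[idx].mean(axis=1)  # one subsample mean per source cell: (toy_n_src, 30)

# Ratio of per-dimension spreads; expected to be close to sqrt(toy_n_samples) = 10
ratio = toy_pcs.std(axis=0).mean() / toy_source.std(axis=0).mean()
print(toy_source.shape, round(float(ratio), 2))
```

With `n_samples=1000` as in the tutorial, the shrinkage factor would be about 32, so the source distribution is much narrower than any real condition.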

## Running CellFlow
Now we are ready to set up the `CellFlow` model. We use the default deterministic `otfm` solver for this task.

```python
cf = cellflow.model.CellFlow(adata_train_full, solver="otfm")
```

### Preparing CellFlow’s data handling with {meth}`~cellflow.model.CellFlow.prepare_data`
We set up the data as follows:
- We use `.obsm["X_pca"]` as the cellular representation (`sample_rep`)
- `"CTRL"` indicates the source distribution we constructed earlier
- We use the previously constructed boolean columns indicating morphogen concentrations as `perturbation_covariates`
- As representations for the perturbation conditions, we use the one-hot encoded, concentration-scaled morphogen vectors in `.uns["conditions"]` (`perturbation_covariate_reps`)

```python
cf.prepare_data(
    sample_rep=sample_rep,
    control_key="CTRL",
    perturbation_covariates={"conditions": condition_keys},
    perturbation_covariate_reps={"conditions": "conditions"},
)
```

```
100%|██████████| 169/169 [00:03<00:00, 53.24it/s]
```
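One way to picture what `prepare_data` works with: for each cell, the boolean columns that are `True` identify its morphogen/concentration pairs, each pair maps to a vector, and cells with the same set of vectors form one condition. The sketch below reproduces that grouping on a hand-made table. It is a conceptual illustration under that assumption, not CellFlow's internal code; the morphogen subset, concentrations, and the helper `condition_reps` are hypothetical.

```python
import numpy as np

toy_morphogens = ["RA", "CHIR", "BMP4"]
# Invented per-cell treatment table: concentration of each morphogen for three cells
cells = [
    {"RA": 100.0, "CHIR": 0.0, "BMP4": 50.0},
    {"RA": 0.0, "CHIR": 3.0, "BMP4": 0.0},
    {"RA": 100.0, "CHIR": 0.0, "BMP4": 50.0},
]


def condition_reps(cell: dict) -> np.ndarray:
    """Stack one vector per morphogen: one-hot identity times log1p(concentration)."""
    rows = []
    for i, mol in enumerate(toy_morphogens):
        onehot = np.eye(len(toy_morphogens))[i]
        rows.append(onehot * np.log1p(cell[mol]))
    return np.stack(rows)  # (n_morphogens, n_morphogens)


# Cells with identical covariate sets are grouped into the same condition
groups = {}
for j, cell in enumerate(cells):
    key = tuple(condition_reps(cell).ravel())
    groups.setdefault(key, []).append(j)
print(len(groups), sorted(groups.values()))  # 2 [[0, 2], [1]]
```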


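### Preparing CellFlow’s model architecture with {meth}`~cellflow.model.CellFlow.prepare_model`

Next, we define the model architecture. We first specify the layers that embed each perturbation covariate before pooling; here, this is a four-layer MLP with 2048 units per layer and ReLU activations. We do not add any layers after pooling.

```python
# Layers applied to each perturbation covariate before pooling
layers_before_pool = [
    {
        "layer_type": "mlp",
        "dims": [2048] * 4,
        "dropout_rate": 0.0,
        "act_fn": nn.relu,
    }
]
# No additional layers after pooling
layers_after_pool = []
```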
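Because the covariates of a condition are embedded one at a time and then mean-pooled (`pooling="mean"` in `prepare_model`), the resulting condition embedding does not depend on the order in which the morphogens are listed. A minimal NumPy sketch of that idea with random weights follows; two small ReLU layers stand in for the `[2048] * 4` MLP, the sizes are illustrative assumptions, and this is not CellFlow's Flax module.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_hidden = 6, 16
W1, b1 = rng.normal(size=(d_in, d_hidden)), rng.normal(size=d_hidden)
W2, b2 = rng.normal(size=(d_hidden, d_hidden)), rng.normal(size=d_hidden)


def relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)


def embed_condition(covariates: np.ndarray) -> np.ndarray:
    """Apply the same MLP to each covariate vector (row), then mean-pool over rows."""
    h = relu(relu(covariates @ W1 + b1) @ W2 + b2)  # (n_covariates, d_hidden)
    return h.mean(axis=0)  # pooled condition embedding, (d_hidden,)


covs = rng.normal(size=(3, d_in))  # e.g. three morphogen vectors of one condition
emb = embed_condition(covs)
emb_shuffled = embed_condition(covs[[2, 0, 1]])
print(np.allclose(emb, emb_shuffled))  # True: covariate order does not matter
```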

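We also define the function used to match source and target cells within each training batch: an entropic optimal transport coupling computed with OTT-JAX, where `tau_a` and `tau_b` below 1 relax the marginal constraints (unbalanced OT).

```python
match_fn = partial(
    solver_utils.match_linear,
    epsilon=0.5,  # entropic regularization strength
    scale_cost="mean",  # rescale the cost matrix by its mean
    tau_a=0.99,  # < 1: source marginal is only softly enforced
    tau_b=0.99,  # < 1: target marginal is only softly enforced
)
```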
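To see where `epsilon`, `scale_cost`, `tau_a`, and `tau_b` enter, the NumPy sketch below implements the standard (unbalanced) Sinkhorn scaling iterations, in which `tau` appears as an exponent on the marginal updates and `tau = 1` recovers balanced Sinkhorn. It is a simplified illustration assuming uniform weights and a squared Euclidean cost divided by its mean; it is not OTT-JAX's implementation, which differs in numerical details (for example, it works in log-space by default). The function name `entropic_coupling` is hypothetical.

```python
import numpy as np


def entropic_coupling(x, y, epsilon=0.5, tau_a=0.99, tau_b=0.99, n_iter=1000):
    """Sketch of (unbalanced) entropic OT between point clouds x and y.

    Uniform weights, squared Euclidean cost divided by its mean, and scaling
    updates u = (a / K v) ** tau_a, v = (b / K^T u) ** tau_b.
    """
    a = np.full(len(x), 1.0 / len(x))
    b = np.full(len(y), 1.0 / len(y))
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    cost = cost / cost.mean()  # analogue of scale_cost="mean"
    K = np.exp(-cost / epsilon)  # Gibbs kernel
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = (a / (K @ v)) ** tau_a
        v = (b / (K.T @ u)) ** tau_b
    return u[:, None] * K * v[None, :]  # coupling matrix, shape (len(x), len(y))


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 2))  # toy source batch
y = rng.normal(loc=1.0, size=(10, 2))  # toy target batch

P_balanced = entropic_coupling(x, y, tau_a=1.0, tau_b=1.0)
P_unbalanced = entropic_coupling(x, y)  # tau_a = tau_b = 0.99, as in match_fn above
print(P_balanced.sum(axis=1))  # balanced case: each row sums to about 1/8
print(P_unbalanced.sum(axis=1))  # unbalanced case: row sums may deviate from 1/8
```

Relaxing the marginals lets outlier cells in a batch receive little mass instead of forcing every cell to be matched.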

```python
cf.prepare_model(
    condition_embedding_dim=128,
    time_encoder_dims=[1024] * 5,
    time_encoder_dropout=0.1,
    hidden_dims=[2048] * 2 + [128],
    hidden_dropout=0.2,
    decoder_dims=[512] * 2,
    decoder_dropout=0.1,
    pooling="mean",
    layers_before_pool=layers_before_pool,
    layers_after_pool=layers_after_pool,
    cond_output_dropout=0.3,
    flow={"constant_noise": 0.0},
    match_fn=match_fn,
)
```

```python
cf.train(num_iterations=500000)
```

```
100%|██████████| 500000/500000 [3:44:15<00:00, 37.16it/s, loss=0.969]
```
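For intuition about what is being optimized, the sketch below shows a flow matching regression target with straight-line paths between matched source and target samples: at a random time `t`, the point `x_t = (1 - t) * x0 + t * x1` should move with velocity `x1 - x0`. This is a textbook form of the objective written in NumPy, assuming no added noise (consistent with `flow={"constant_noise": 0.0}` above); it is not CellFlow's training loop, which trains a condition-dependent neural velocity field on OT-matched batches. The helper `flow_matching_loss` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.normal(size=(256, 30))  # source batch (e.g. random source cells)
x1 = x0 + rng.normal(size=(256, 30)) + 2.0  # matched target batch
t = rng.uniform(size=(256, 1))  # one time per pair

x_t = (1.0 - t) * x0 + t * x1  # point on the straight path at time t
target_velocity = x1 - x0  # constant velocity along that path


def flow_matching_loss(velocity_fn) -> float:
    """Mean squared error between predicted and target velocities."""
    pred = velocity_fn(x_t, t)
    return float(((pred - target_velocity) ** 2).sum(axis=1).mean())


# An oracle that knows the pairs reaches zero loss; a zero velocity field does not
oracle_loss = flow_matching_loss(lambda x, s: target_velocity)
zero_loss = flow_matching_loss(lambda x, s: np.zeros_like(x))
print(oracle_loss, zero_loss > 0)
```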


We save the trained model together with the AnnData objects, so that predictions can be generated later without retraining.

```python
# Output directory used for this tutorial; adjust to your own setup
results_dir = "/home/fleckj/projects/cellflow/results/ineuron_tutorial/"
cf.save(results_dir, overwrite=True)
```

```python
adata_train.write_h5ad(os.path.join(results_dir, "adata_train.h5ad"))
adata_eval.write_h5ad(os.path.join(results_dir, "adata_eval.h5ad"))
adata_train_full.write_h5ad(os.path.join(results_dir, "adata_train_full.h5ad"))
adata_ctrl.write_h5ad(os.path.join(results_dir, "adata_ctrl.h5ad"))
adata.write_h5ad(os.path.join(results_dir, "adata.h5ad"))
```

```python
# Reload the trained model
cf = cellflow.model.CellFlow.load(results_dir)
```

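To predict the held-out conditions, we apply the trained model to the random source cells in `adata_ctrl`, using one row of covariates per held-out condition from `adata_eval`.

```python
adata_ctrl.uns = adata.uns  # condition representations are needed at prediction time
# One row of covariates per held-out condition
obs_pred = adata_eval.obs.drop_duplicates("condition")

# x_pred maps each condition to predicted coordinates (in sample_rep space) of the source cells
x_pred = cf.predict(
    adata_ctrl,
    sample_rep=sample_rep,
    covariate_data=obs_pred,
    condition_id_key="condition",
)
```

```
100%|██████████| 24/24 [00:00<00:00, 91.78it/s]
```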
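At prediction time, each source cell is transported by integrating the learned velocity field from t = 0 to t = 1. The sketch below uses a fixed-step explicit Euler scheme with a known toy velocity field (a constant shift) so the result can be checked. The field, the step count, and the helper `euler_integrate` are assumptions for illustration; this loop is a didactic stand-in, not CellFlow's ODE solver.

```python
import numpy as np


def euler_integrate(velocity_fn, x0: np.ndarray, n_steps: int = 100) -> np.ndarray:
    """Integrate dx/dt = velocity_fn(x, t) from t=0 to t=1 with explicit Euler steps."""
    x = x0.copy()
    dt = 1.0 / n_steps
    for k in range(n_steps):
        x = x + dt * velocity_fn(x, k * dt)
    return x


shift = np.array([1.0, -2.0, 0.5])
rng = np.random.default_rng(0)
source = rng.normal(size=(5, 3))  # toy source cells
predicted = euler_integrate(lambda x, t: np.broadcast_to(shift, x.shape), source)
print(np.allclose(predicted, source + shift))  # constant field: x(1) = x(0) + shift
```

For a velocity field that varies with position or time, the step count controls the integration error, which is one reason adaptive ODE solvers are commonly used in practice.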


```python
# Wrap each prediction in an AnnData object labelled with its condition and combination
adatas_pred = []
for condition in x_pred.keys():
    adata_pred = ad.AnnData(X=adata_ctrl.X)
    adata_pred.obs["comb"] = adata_eval.obs["comb"][
        adata_eval.obs["condition"] == condition
    ].values[0]
    adata_pred.obs["condition"] = condition
    adata_pred.obsm[sample_rep] = x_pred[condition]
    adatas_pred.append(adata_pred)

# index_unique makes the cell names unique across conditions
adata_pred = ad.concat(adatas_pred, join="outer", index_unique="-")
```

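```python
# Map the predicted PCA coordinates back to gene space using the training PCA
cfpp.reconstruct_pca(adata_pred, use_rep="X_pca", ref_adata=adata_train)
adata_pred.X = csr_matrix(adata_pred.layers["X_recon"])
del adata_pred.layers["X_recon"]

# For visualization, compute a PCA on the full dataset and project the predictions into it
cfpp.centered_pca(adata, n_comps=30, method="rapids")
cfpp.project_pca(adata_pred, adata, obsm_key_added="X_pca_reproj")
```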
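For a centered PCA, mapping scores back to feature space amounts to `scores @ loadings.T + mean`. The NumPy sketch below shows that formula on random data and that the reconstruction is exact only when all components are kept. It reflects the general approach under that assumption and is not a copy of `cfpp.reconstruct_pca`; the helper `reconstruct` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(100, 20))  # toy data: 100 cells x 20 features
mean = X.mean(axis=0)
_, _, vt = np.linalg.svd(X - mean, full_matrices=False)


def reconstruct(scores, loadings, mean):
    """Map PCA scores back to feature space: scores @ loadings.T + mean."""
    return scores @ loadings.T + mean


errors = {}
for k in (5, 20):
    loadings = vt[:k].T  # (n_features, k)
    scores = (X - mean) @ loadings  # project onto the first k components
    errors[k] = np.abs(reconstruct(scores, loadings, mean) - X).max()
print(errors)  # the error is ~0 only when all 20 components are kept
```

With 30 components out of 4000 genes, the reconstructed expression in this tutorial is therefore a smoothed approximation rather than an exact profile.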

```python
umap_model = umap.UMAP(
    n_neighbors=15,
    n_components=2,
    n_epochs=500,
    learning_rate=1.0,
    init="spectral",
    min_dist=0.4,
    spread=1.0,
    negative_sample_rate=5,
    a=None,
    b=None,
    random_state=212,
    n_jobs=-1,
)
# Fit the UMAP on the full-dataset PCA, then embed the predicted cells into the same space
adata.obsm["X_umap"] = umap_model.fit_transform(adata.obsm["X_pca"])
adata_pred.obsm["X_umap"] = umap_model.transform(adata_pred.obsm["X_pca_reproj"])
```