# Neuron fate prediction from combinatorial morphogen treatment

In this notebook, we show how CellFlow can be used to predict the outcome of neuron fate programming experiments. We use the dataset from [Lin, Jansen et al.](https://www.biorxiv.org/content/10.1101/2023.12.12.571318v2), which contains scRNA-seq data from a morphogen screen in NGN2-induced neurons (iNeurons). The treatment conditions comprised combinations of modulators of anterior-posterior (AP) patterning (RA, CHIR99021, XAV-939, FGF8) with modulators of dorso-ventral (DV) patterning (BMP4, SHH), each applied at multiple concentrations. We use CellFlow to predict neuron distributions for held-out combinations of morphogens.

## Preliminaries

In [1]:
import warnings
from functools import partial, reduce
from typing import Literal

import anndata as ad
import cellflow
import flax.linen as nn
import numpy as np
import optax
import pandas as pd
import scanpy as sc
import yaml
from ott.solvers import utils as solver_utils
from scipy.sparse import csr_matrix

In [4]:
adata = cellflow.datasets.ineurons()
print(adata)

AnnData object with n_obs × n_vars = 178437 × 4000
    obs: 'sample', 'species', 'gene_count', 'tscp_count', 'mread_count', 'bc1_well', 'bc2_well', 'bc3_well', 'bc1_wind', 'bc2_wind', 'bc3_wind', 'plateID', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'n_genes', 'percent_mito', 'n_counts', 'outlier', 'mt_outlier', 'doublet_score', 'predicted_doublet', 'leiden_4', 'leiden_10', 'merged_clusters', 'final_clustering', 'final_clustering_reset', 'merged_clusters_from_10', 'M_XAV', 'M_CHIR', 'M_RA', 'M_FGF8', 'M_BMP4', 'M_SHH', 'M_PM', 'media', 'sample-CellID', 'Neuron_type', 'Division', 'Region', 'FGF8_conc', 'FGF8_start_time', 'FGF8_end_time', 'XAV_conc', 'XAV_start_time', 'XAV_end_time', 'RA_conc', 'RA_start_time', 'RA_end_time', 'CHIR_conc', 'CHIR_start_time', 'CHIR_end_time', 'SHH_conc', 'SHH

Now we need to create representations for the perturbation conditions to be used by CellFlow. For this, we use a one-hot encoding of the morphogen multiplied by its concentration, which we store in `adata.uns["conditions"]`.
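To make the encoding concrete, here is a minimal sketch of a concentration-scaled one-hot vector. It is an illustration only, not necessarily the code used in this notebook: the morphogen order, the helper name `scaled_one_hot`, the dictionary keys, and the concentration values are all assumptions made up for the example.

```python
import numpy as np

# Assumed morphogen order; it fixes which position each one-hot entry occupies.
MORPHOGENS = ["XAV", "CHIR", "RA", "FGF8", "BMP4", "SHH"]


def scaled_one_hot(morphogen: str, conc: float) -> np.ndarray:
    """Return a (1, n_morphogens) one-hot vector multiplied by the concentration."""
    vec = np.zeros((1, len(MORPHOGENS)), dtype=np.float32)
    vec[0, MORPHOGENS.index(morphogen)] = conc
    return vec


# Hypothetical keys and concentrations, chosen only to show the layout.
conditions = {
    "RA_2.0": scaled_one_hot("RA", 2.0),
    "BMP4_0.5": scaled_one_hot("BMP4", 0.5),
}

# In the notebook, a dictionary like this would be stored in adata.uns["conditions"].
for name, vec in conditions.items():
    print(name, vec)
```

With this layout, the position of the non-zero entry identifies the morphogen and its value carries the dose, so two concentrations of the same morphogen differ only in magnitude.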