In [1]:
# Enable 64-bit floating point in JAX. This is set in the first cell so it is in
# effect before any other library runs JAX code.
# `jax.config` is accessed through the top-level package: the older
# `from jax.config import config` import path is deprecated and no longer exists
# in recent JAX releases, while `jax.config.update` works on old and new versions.
import jax

jax.config.update("jax_enable_x64", True)

In [2]:
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
import moscot
from moscot.problems.time import TemporalProblem
import seaborn as sns
from sklearn.preprocessing import StandardScaler
import itertools
import pandas as pd

In [3]:
# Print the installed moscot version (the saved output below shows 0.1.0).
print(moscot.__version__)

0.1.0


In [4]:
# Load the processed pancreas multiome AnnData object.
# NOTE(review): hard-coded absolute path; it only resolves on the filesystem this
# notebook was written on. Consider a configurable data directory.
adata = sc.read_h5ad("/lustre/groups/ml01/workspace/moscot_paper/pancreas/pancreas_multiome_2022_processed.h5ad")

In [5]:
# Read marginals from a CSV in the working directory; the CSV column named
# "Unnamed: 0" is used as the index.
# NOTE(review): a whole DataFrame is assigned to the single obs column "a".
# Presumably the file has exactly one value column and its index matches
# adata.obs_names (assignment aligns on index) — confirm.
marginals = pd.read_csv("marginals.csv", index_col="Unnamed: 0")
adata.obs["a"] = marginals

In [6]:
# Keep only the variables whose `modality` annotation is "GEX"; `.copy()` turns
# the resulting view into an independent AnnData object.
adata = adata[:, adata.var["modality"].eq("GEX")].copy()

In [7]:
# Derive a numeric time point from the sample label:
# "E14.5" -> 14.5, every other label -> 15.5.
# Computed on the `sample` column as a whole instead of a Python-level row-wise
# `DataFrame.apply(axis=1)`, which builds one Series per cell.
# NOTE(review): any label other than "E14.5" silently becomes 15.5 — confirm
# that only two samples occur in this dataset.
adata.obs["time"] = adata.obs["sample"].eq("E14.5").map({True: 14.5, False: 15.5})

In [8]:
# Display the AnnData summary (shape plus the available obs/var/uns/obsm/obsp keys).
adata

AnnData object with n_obs × n_vars = 16918 × 14663
    obs: 'n_counts', 'sample', 'n_genes', 'log_genes', 'mt_frac', 'rp_frac', 'ambi_frac', 'nCount_RNA', 'nFeature_RNA', 'nCount_ATAC', 'nFeature_ATAC', 'nucleosome_signal', 'nucleosome_percentile', 'TSS.enrichment', 'TSS.percentile', 'S_score', 'G2M_score', 'phase', 'proliferation', 'celltype', 'nCount_peaks', 'nFeature_peaks', 'a', 'time'
    var: 'modality'
    uns: 'celltype_colors', 'neighbors'
    obsm: 'X_pca', 'X_pca_wsnn', 'X_spca_wsnn', 'X_umap', 'X_umap_ATAC', 'X_umap_GEX', 'X_umap_wsnn', 'lsi_full', 'lsi_red', 'umap', 'umap_ATAC', 'umap_GEX'
    obsp: 'connectivities', 'connectivities_wnn', 'distances', 'distances_wnn'

In [9]:
# NOTE(review): this cell repeats the earlier marginals cell verbatim. The AnnData
# summary printed above already lists 'a' in obs, so this reload looks redundant —
# confirm before removing.
marginals = pd.read_csv("marginals.csv", index_col="Unnamed: 0")
adata.obs["a"] = marginals

In [10]:
# One (celltype, celltype) pair per category of obs["celltype"], i.e. every cell
# type paired with itself. These pairs are what `compute_score` looks up.
self_transitions = [
    (celltype, celltype) for celltype in adata.obs["celltype"].cat.categories
]
correct_transitions = self_transitions

def compute_score(df, transitions=None):
    """Average the entries of ``df`` found at a set of (row, column) label pairs.

    Parameters
    ----------
    df
        DataFrame that is looked up by label with ``df.loc[row, column]``.
    transitions
        Iterable of ``(row_label, column_label)`` pairs to read from ``df``.
        When omitted, the notebook-level ``correct_transitions`` is used, which
        keeps existing ``compute_score(df)`` calls working unchanged.

    Returns
    -------
    The sum of the looked-up entries divided by the number of pairs.

    Raises
    ------
    KeyError
        If a row or column label of a pair is not present in ``df``.
    ValueError
        If there are no pairs to score (previously an opaque ZeroDivisionError).
    """
    # Resolve the global at call time so later reassignment of
    # `correct_transitions` is still picked up, as in the original.
    pairs = list(correct_transitions if transitions is None else transitions)
    if not pairs:
        raise ValueError("compute_score needs at least one (row, column) pair")
    total = sum(df.loc[row, col] for row, col in pairs)
    return total / len(pairs)
        

In [11]:
epsilon = (1e-4, 1e-3, 1e-2)
tau_a = (1, 0.99, 0.95, 0.9)
tau_b = (1, 0.99, 0.95, 0.9)

configs = itertools.product(epsilon, tau_a, tau_b)

# PCA on GEX space

In [12]:
# Normalise total counts per cell to 1e4, then apply log1p.
# The return values are discarded, so both calls act on `adata` in place;
# re-running this cell therefore transforms already-transformed data again.
# NOTE(review): assumes adata.X still holds un-normalised counts here — confirm
# against how the "processed" input file was produced.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

In [13]:
# Temporal problem keyed on obs["time"], with obs["a"] passed as `a`.
# No `joint_attr` is given; the log output below shows a 30-component PCA being
# computed from `adata.X`.
tp0 = TemporalProblem(adata).prepare("time", a="a")

[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` using `adata.X`                                     


In [14]:
# Sweep the hyperparameter grid on the PCA-based problem and score every solution.
# If `configs` is a one-shot iterator (e.g. `itertools.product`), materialise it
# first so that re-running this cell does not silently loop over nothing.
configs = list(configs)

scores = {}
for params in configs:
    # Unpack into short names that do not clobber the grid tuples
    # `epsilon`, `tau_a` and `tau_b`.
    eps, ta, tb = params
    # Pass the iteration cap as an integer rather than the float 1e8.
    tp0 = tp0.solve(epsilon=eps, tau_a=ta, tau_b=tb, max_iterations=100_000_000)
    ct_desc = tp0.cell_transition(source=14.5, target=15.5, source_groups="celltype", target_groups="celltype", forward=True)
    # Key by the full (epsilon, tau_a, tau_b) tuple, as before.
    scores[params] = compute_score(ct_desc)

[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.            
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m

[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              


In [15]:
# Build a one-row frame from the {(epsilon, tau_a, tau_b): score} dict, then flip
# it so that each configuration becomes a row.
df_pca = pd.DataFrame.from_dict([scores]).T

In [16]:
# Write the PCA sweep scores to a CSV in the current working directory.
df_pca.to_csv("scores_pca_self.csv")

# scVI on GEX space

In [17]:
# Rebind `adata` to the AnnData object holding the GEX embeddings; the previously
# loaded object is no longer reachable through this name.
# NOTE(review): hard-coded absolute path, as for the first load.
# NOTE(review): `correct_transitions` was built from the previous object's
# obs["celltype"] categories — presumably this file has the same categories; confirm.
adata = sc.read("/lustre/groups/ml01/workspace/moscot_paper/pancreas/embeddings/pancreas_GEX_embeddings.h5ad")

In [18]:
# Derive a numeric time point from the sample label:
# "E14.5" -> 14.5, every other label -> 15.5.
# Computed on the `sample` column as a whole instead of a Python-level row-wise
# `DataFrame.apply(axis=1)`, which builds one Series per cell.
# NOTE(review): any label other than "E14.5" silently becomes 15.5 — confirm
# that only two samples occur in this dataset.
adata.obs["time"] = adata.obs["sample"].eq("E14.5").map({True: 14.5, False: 15.5})

In [19]:
# `adata` was just reloaded from a different file, so the marginals are attached
# again as obs["a"].
# NOTE(review): a whole DataFrame is assigned to a single obs column. Presumably
# the CSV has one value column and its index matches this object's obs_names
# (assignment aligns on index) — confirm.
marginals = pd.read_csv("marginals.csv", index_col="Unnamed: 0")
adata.obs["a"] = marginals

In [20]:
# Rebuild the hyperparameter grid as a list. A bare `itertools.product` is a
# one-shot iterator, so a second pass over it would silently yield nothing.
configs = list(itertools.product(epsilon, tau_a, tau_b))

In [21]:
# Temporal problem on the embedding stored under "X_scVI", keyed on obs["time"],
# with obs["a"] passed as `a`.
tp1 = TemporalProblem(adata).prepare("time", joint_attr="X_scVI", a="a")

In [22]:
# Sweep the hyperparameter grid on the scVI-based problem and score every solution.
# If `configs` is a one-shot iterator (e.g. `itertools.product`), materialise it
# first so that re-running this cell does not silently loop over nothing.
configs = list(configs)

scores_scvi = {}
for params in configs:
    # Unpack into short names that do not clobber the grid tuples
    # `epsilon`, `tau_a` and `tau_b`.
    eps, ta, tb = params
    # Pass the iteration cap as an integer rather than the float 1e8.
    tp1 = tp1.solve(epsilon=eps, tau_a=ta, tau_b=tb, max_iterations=100_000_000)
    ct_desc = tp1.cell_transition(source=14.5, target=15.5, source_groups="celltype", target_groups="celltype", forward=True)
    # Key by the full (epsilon, tau_a, tau_b) tuple, as before.
    scores_scvi[params] = compute_score(ct_desc)

[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.            
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m

[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              


In [23]:
# Build a one-row frame from the {(epsilon, tau_a, tau_b): score} dict, then flip
# it so that each configuration becomes a row.
df_scvi = pd.DataFrame.from_dict([scores_scvi]).T

In [24]:
# Write the scVI sweep scores to a CSV in the current working directory.
df_scvi.to_csv("scores_scvi_self.csv")

In [25]:
# Rebuild the hyperparameter grid as a list. A bare `itertools.product` is a
# one-shot iterator, so a second pass over it would silently yield nothing.
configs = list(itertools.product(epsilon, tau_a, tau_b))

In [26]:
# Temporal problem on the embedding stored under "X_scVI_batch_corrected".
# `a="a"` is passed here as well: the PCA and scVI problems in this notebook both
# use obs["a"], and omitting it only for this run means the scores would differ
# for a reason other than the embedding.
# NOTE(review): if the omission was deliberate, drop `a="a"` again and say why.
tp2 = TemporalProblem(adata)
tp2 = tp2.prepare("time", joint_attr="X_scVI_batch_corrected", a="a")

In [27]:
# Sweep the hyperparameter grid on the batch-corrected scVI problem and score
# every solution.
# If `configs` is a one-shot iterator (e.g. `itertools.product`), materialise it
# first so that re-running this cell does not silently loop over nothing.
configs = list(configs)

scores_scvi_bc = {}
for params in configs:
    # Unpack into short names that do not clobber the grid tuples
    # `epsilon`, `tau_a` and `tau_b`.
    eps, ta, tb = params
    # Pass the iteration cap as an integer rather than the float 1e8.
    tp2 = tp2.solve(epsilon=eps, tau_a=ta, tau_b=tb, max_iterations=100_000_000)
    ct_desc = tp2.cell_transition(source=14.5, target=15.5, source_groups="celltype", target_groups="celltype", forward=True)
    # Key by the full (epsilon, tau_a, tau_b) tuple, as before.
    scores_scvi_bc[params] = compute_score(ct_desc)

[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.            
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m

[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              


In [28]:
# Build a one-row frame from the {(epsilon, tau_a, tau_b): score} dict, then flip
# it so that each configuration becomes a row.
df_scvi_bc = pd.DataFrame.from_dict([scores_scvi_bc]).T

In [29]:
# Write the batch-corrected scVI sweep scores to a CSV in the current working directory.
df_scvi_bc.to_csv("scores_scvi_bc_self.csv")