In [1]:
import jax

# Enable 64-bit floating point before any JAX computation runs.
# The config object is reached through the top-level package: the old
# `from jax.config import config` submodule import was removed in JAX 0.4.25.
jax.config.update("jax_enable_x64", True)

In [2]:
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
import moscot
from moscot.problems.time import TemporalProblem
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from muon import atac as ac
import numpy as np
import itertools
import pandas as pd

In [3]:
# Load the processed pancreas multiome AnnData object (split by modality below).
# NOTE(review): absolute, site-specific path — the notebook only runs where this mount exists.
adata = sc.read_h5ad("/lustre/groups/ml01/workspace/moscot_paper/pancreas/pancreas_multiome_2022_processed.h5ad")

In [4]:
# Numeric time point per cell: 14.5 for sample "E14.5", 15.5 for every other sample.
# A vectorised comparison replaces the row-wise `apply(axis=1)`, which ran a Python
# lambda once per cell; the resulting values are the same.
adata.obs["time"] = np.where(adata.obs["sample"] == "E14.5", 14.5, 15.5)

In [5]:
# Read precomputed marginals; the CSV column "Unnamed: 0" becomes the index.
marginals = pd.read_csv("marginals.csv", index_col="Unnamed: 0")
# Stored in obs under "a", the key later passed to `prepare(..., a="a")`.
# NOTE(review): `marginals` is a DataFrame, not a Series — presumably it holds exactly one
# value column and its index matches `adata.obs_names`; confirm.
adata.obs["a"] = marginals

In [6]:
# Split the features by the `modality` annotation in `adata.var`.
# `adata_gex` keeps the "GEX" features (it is not referenced again in the visible cells).
adata_gex = adata[:, adata.var["modality"]=="GEX"].copy()
# `adata` is rebound to the "ATAC" features only; everything below works on this subset.
# NOTE: because `adata` is overwritten, re-running this cell would leave `adata_gex` empty.
adata = adata[:, adata.var["modality"]=="ATAC"].copy()

In [7]:
# Transitions treated as correct when scoring a coupling.
# Each entry is (source cell type, [accepted target cell types]).
known_transitions = [
    ("Imm. Acinar", ["Mat. Acinar"]),
    ("Fev+ Alpha", ["Alpha"]),
    ("Fev+ Beta", ["Beta"]),
    ("Ngn3 high", ["Alpha", "Beta", "Delta", "Epsilon", "Fev+ Alpha", "Fev+ Beta", "Eps. progenitors", "Fev+ Delta", "Fev+"]),
    ("Ngn3 low", ["Alpha", "Beta", "Delta", "Epsilon", "Fev+ Alpha", "Fev+ Beta", "Eps. progenitors", "Fev+ Delta", "Fev+", "Ngn3 high", "Ngn3 high cycling"]),
]


# Default reference set used by `compute_score`.
correct_transitions = known_transitions


def compute_score(df, transitions=None):
    """Average the transition mass that lands on expected source -> target pairs.

    Parameters
    ----------
    df
        Transition table with source labels as the row index and target labels
        as columns (in this notebook, the result of ``cell_transition``).
    transitions
        Iterable of ``(source, [targets])`` pairs to score against. When omitted,
        the module-level ``correct_transitions`` is used, looked up at call time.

    Returns
    -------
    For each pair, the sum of ``df.loc[source, targets]``; these sums are added
    up and divided by the number of pairs.

    Raises
    ------
    KeyError
        Propagated from ``df.loc`` when a source or target label is not present
        in ``df``.
    """
    if transitions is None:
        transitions = correct_transitions
    # Materialise once so generators are supported and can be counted afterwards.
    transitions = list(transitions)

    score = 0
    for source, targets in transitions:
        score += df.loc[source, targets].sum()
    return score / len(transitions)

In [8]:
# Hyperparameter grid for the sweeps below; each combination is passed to `solve`
# as (epsilon, tau_a, tau_b).
epsilon = (1e-4, 1e-3, 1e-2)
tau_a = (1, 0.99, 0.95, 0.9)
tau_b = (1, 0.99, 0.95, 0.9)

# Materialise the grid as a list. A bare `itertools.product` object is a one-shot
# iterator: after the first sweep consumed it, any further loop over `configs`
# would silently run zero iterations.
configs = list(itertools.product(epsilon, tau_a, tau_b))


# LSI ATAC space

In [9]:
# TF-IDF transform of the ATAC matrix via muon, with scale_factor=1e4.
# The return value is not assigned, so this relies on `adata` being modified in place.
ac.pp.tfidf(adata, scale_factor=1e4)

In [10]:
# Scale each cell to 1e4 total counts, then apply log1p.
# NOTE(review): this runs on top of the TF-IDF transform from the previous cell —
# confirm that applying both normalisations is intended.
# NOTE(review): `normalize_per_cell` is presumably the legacy scanpy API
# (newer code uses `normalize_total`) — verify before upgrading scanpy.
sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)
sc.pp.log1p(adata)

In [11]:
# Compute a 30-component LSI with muon; the following cells read the embedding
# from `adata.obsm["X_lsi"]`.
ac.tl.lsi(adata, n_comps=30)

In [12]:
# Pearson correlation of every LSI component with `obs["nCount_ATAC"]`
# (presumably the per-cell ATAC depth — confirm), one value per component.
depth_corr = [
    np.corrcoef(component, adata.obs["nCount_ATAC"])[1, 0]
    for component in adata.obsm["X_lsi"].T
]

In [13]:
# Positions of the LSI components whose absolute correlation with depth exceeds 0.4.
comps_to_remove = np.flatnonzero(np.abs(depth_corr) > 0.4)
comps_to_remove

array([0, 4])

In [14]:
# Indices of the LSI components to keep (all components minus the depth-correlated ones).
# `sorted` makes the column order explicit; converting a set straight to a list
# gives no ordering guarantee.
subset = sorted(set(range(adata.obsm["X_lsi"].shape[1])) - set(comps_to_remove))

In [15]:
# Keep only the retained LSI components.
# NOTE: not idempotent — `obsm['X_lsi']` is overwritten, so re-running this cell
# indexes the already-reduced matrix with the same `subset` again.
adata.obsm['X_lsi'] = adata.obsm['X_lsi'][:,subset]

In [16]:
# Temporal problem on the LSI embedding: time key "time", joint space `X_lsi`,
# and `a="a"` referring to the marginals stored in `obs["a"]` earlier.
tp0 = TemporalProblem(adata).prepare("time", joint_attr="X_lsi", a="a")

In [17]:
# Grid search on the LSI problem: re-solve for each (epsilon, tau_a, tau_b) and
# score the 14.5 -> 15.5 cell-type transition table with `compute_score`.
scores = {}
# The loop variable is `cfg`, not `config`, so it cannot shadow a module-level
# `config` object such as JAX's.
for cfg in list(configs):
    eps_val, tau_a_val, tau_b_val = cfg
    # `max_iterations` is an iteration count: pass an int rather than the float 1e8.
    tp0 = tp0.solve(epsilon=eps_val, tau_a=tau_a_val, tau_b=tau_b_val, max_iterations=int(1e8))
    ct_desc = tp0.cell_transition(
        source=14.5,
        target=15.5,
        source_groups="celltype",
        target_groups="celltype",
        forward=True,
    )
    # Keyed by the tuple's string form, e.g. "(0.0001, 1, 1)".
    scores[str(cfg)] = compute_score(ct_desc)

[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.            
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m

[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              


In [18]:
# One row per configuration string and a single score column labelled 0.
# `from_dict` is given the dict itself with orient="index" instead of a one-element
# list that then has to be transposed.
df_lsi = pd.DataFrame.from_dict(scores, orient="index")

In [19]:
# Persist the LSI sweep scores to the working directory (index = configuration string).
df_lsi.to_csv("lsi_scores_forward.csv")

# poissVI on ATAC space

In [20]:
# Load the AnnData object holding precomputed ATAC embeddings; later cells use
# "X_patac" and "X_patac_batch_corrected" from it as `joint_attr`.
# NOTE: this rebinds `adata`, replacing the LSI-processed object used above.
# NOTE(review): absolute, site-specific path — only resolves where this mount exists.
adata = sc.read_h5ad("/lustre/groups/ml01/workspace/moscot_paper/pancreas/embeddings/pancreas_ATAC_embeddings.h5ad")

In [21]:
# Numeric time point per cell: 14.5 for sample "E14.5", 15.5 for every other sample.
# A vectorised comparison replaces the row-wise `apply(axis=1)`, which ran a Python
# lambda once per cell; the resulting values are the same.
adata.obs["time"] = np.where(adata.obs["sample"] == "E14.5", 14.5, 15.5)

In [22]:
# Re-read the precomputed marginals; the CSV column "Unnamed: 0" becomes the index.
marginals = pd.read_csv("marginals.csv", index_col="Unnamed: 0")
# Stored in obs under "a" on the newly loaded `adata`, for `prepare(..., a="a")`.
# NOTE(review): `marginals` is a DataFrame, not a Series — presumably it holds exactly one
# value column and its index matches `adata.obs_names`; confirm.
adata.obs["a"] = marginals

In [23]:
# Rebuild the hyperparameter grid as a list, so it can be iterated more than once
# (a bare `itertools.product` object is exhausted after a single pass).
configs = list(itertools.product(epsilon, tau_a, tau_b))

In [24]:
# Temporal problem on the "X_patac" embedding: time key "time", and `a="a"`
# referring to the marginals stored in `obs["a"]`.
tp1 = TemporalProblem(adata).prepare("time", joint_attr="X_patac", a="a")

In [25]:
# Grid search on the "X_patac" problem: re-solve for each (epsilon, tau_a, tau_b)
# and score the 14.5 -> 15.5 cell-type transition table with `compute_score`.
scores_patac = {}
# The loop variable is `cfg`, not `config`, so it cannot shadow a module-level
# `config` object such as JAX's.
for cfg in list(configs):
    eps_val, tau_a_val, tau_b_val = cfg
    # `max_iterations` is an iteration count: pass an int rather than the float 1e8.
    tp1 = tp1.solve(epsilon=eps_val, tau_a=tau_a_val, tau_b=tau_b_val, max_iterations=int(1e8))
    ct_desc = tp1.cell_transition(
        source=14.5,
        target=15.5,
        source_groups="celltype",
        target_groups="celltype",
        forward=True,
    )
    # Keyed by the tuple's string form, e.g. "(0.0001, 1, 1)".
    scores_patac[str(cfg)] = compute_score(ct_desc)

[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.            
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m

[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              


In [26]:
# One row per configuration string and a single score column labelled 0.
# `from_dict` is given the dict itself with orient="index" instead of a one-element
# list that then has to be transposed.
df_patac = pd.DataFrame.from_dict(scores_patac, orient="index")

In [27]:
# Persist the "X_patac" sweep scores to the working directory (index = configuration string).
df_patac.to_csv("patac_scores_forward.csv")

In [28]:
# Rebuild the hyperparameter grid as a list, so it can be iterated more than once
# (a bare `itertools.product` object is exhausted after a single pass).
configs = list(itertools.product(epsilon, tau_a, tau_b))

In [29]:
# Temporal problem on the "X_patac_batch_corrected" embedding: time key "time",
# and `a="a"` referring to the marginals stored in `obs["a"]`.
tp2 = TemporalProblem(adata).prepare("time", joint_attr="X_patac_batch_corrected", a="a")

In [30]:
# Grid search on the "X_patac_batch_corrected" problem: re-solve for each
# (epsilon, tau_a, tau_b) and score the 14.5 -> 15.5 cell-type transition table
# with `compute_score`.
scores_patac_bc = {}
# The loop variable is `cfg`, not `config`, so it cannot shadow a module-level
# `config` object such as JAX's.
for cfg in list(configs):
    eps_val, tau_a_val, tau_b_val = cfg
    # `max_iterations` is an iteration count: pass an int rather than the float 1e8.
    tp2 = tp2.solve(epsilon=eps_val, tau_a=tau_a_val, tau_b=tau_b_val, max_iterations=int(1e8))
    ct_desc = tp2.cell_transition(
        source=14.5,
        target=15.5,
        source_groups="celltype",
        target_groups="celltype",
        forward=True,
    )
    # Keyed by the tuple's string form, e.g. "(0.0001, 1, 1)".
    scores_patac_bc[str(cfg)] = compute_score(ct_desc)

[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.            
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m

[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m9811[0m, [1;36m7107[0m[1m)[0m[1m][0m.              


In [31]:
# One row per configuration string and a single score column labelled 0.
# `from_dict` is given the dict itself with orient="index" instead of a one-element
# list that then has to be transposed.
df_patac_bc = pd.DataFrame.from_dict(scores_patac_bc, orient="index")

In [32]:
# Persist the batch-corrected sweep scores to the working directory (index = configuration string).
df_patac_bc.to_csv("patac_bc_scores_forward.csv")