In [1]:
# Switch JAX to 64-bit precision. Kept in the first cell so the flag is set
# before any other library in this notebook starts using JAX.
import jax

# `jax.config.update` is the supported entry point; importing `config` from the
# `jax.config` submodule is not available in recent JAX releases. It also avoids
# binding a global named `config`, which the grid-search loops below reuse.
jax.config.update("jax_enable_x64", True)

In [2]:
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
import moscot
from moscot.problems.time import TemporalProblem
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from muon import atac as ac
import numpy as np
import itertools
import pandas as pd

2022-11-24 10:03:06.908029: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


In [3]:
# Processed pancreas multiome object; its var["modality"] column is used below
# to separate GEX from ATAC features.
processed_h5ad = "/lustre/groups/ml01/workspace/moscot_paper/pancreas/pancreas_multiome_2022_processed.h5ad"
adata = sc.read_h5ad(processed_h5ad)

In [4]:
# Numeric time point per cell: sample "E14.5" -> 14.5, any other sample -> 15.5.
# Vectorised comparison instead of a row-wise `apply`, which calls a Python
# lambda once per cell.
adata.obs["time"] = np.where(adata.obs["sample"] == "E14.5", 14.5, 15.5)

In [5]:
# Read the marginals table; the CSV column named "Unnamed: 0" becomes the row index.
# NOTE(review): presumably that index holds cell barcodes matching adata.obs_names — confirm.
marginals = pd.read_csv("marginals.csv", index_col="Unnamed: 0")
# Store the values in obs under "a"; that key is later handed to `prepare(..., a="a")`.
# NOTE(review): `marginals` is a DataFrame, so this assumes a single value column — confirm.
adata.obs["a"] = marginals

In [None]:
# One independent AnnData copy per modality, selected via var["modality"].
modality = adata.var["modality"]
adata_gex = adata[:, modality == "GEX"].copy()
adata_atac = adata[:, modality == "ATAC"].copy()

In [None]:
# Target labels shared by both Ngn3 progenitor entries below.
_endocrine_fates = [
    "Alpha", "Beta", "Delta", "Epsilon",
    "Fev+ Alpha", "Fev+ Beta", "Eps. progenitors", "Fev+ Delta", "Fev+",
]

# Reference transitions as (source cell type, [accepted target cell types]).
known_transitions = [
    ("Imm. Acinar", ["Mat. Acinar"]),
    ("Fev+ Alpha", ["Alpha"]),
    ("Fev+ Beta", ["Beta"]),
    ("Ngn3 high", list(_endocrine_fates)),
    ("Ngn3 low", [*_endocrine_fates, "Ngn3 high", "Ngn3 high cycling"]),
]

# Each cell type category is also paired with itself.
celltype_names = adata.obs["celltype"].cat.categories
self_transitions = [(name, name) for name in celltype_names]

# Complete reference set consumed by `compute_score`.
correct_transitions = [*known_transitions, *self_transitions]

def compute_score(df, transitions=None):
    """Average transition mass that lands on the expected transitions.

    Parameters
    ----------
    df
        DataFrame indexed by source label (rows) and target label (columns).
    transitions
        Iterable of ``(source, targets)`` pairs used as ``df.loc[source, targets]``;
        ``targets`` may be one column label or a list of labels. When omitted,
        the notebook-level ``correct_transitions`` is used, which reproduces the
        previous single-argument behaviour.

    Returns
    -------
    The sum of the selected entries over all pairs, divided by the number of pairs.

    Raises
    ------
    KeyError
        If a label in ``transitions`` is not present in ``df``.
    ZeroDivisionError
        If ``transitions`` is empty.
    """
    if transitions is None:
        transitions = correct_transitions
    # Materialise once so generators can be both iterated and counted.
    transitions = list(transitions)

    score = 0
    for pair in transitions:
        # `pair` is passed to .loc as-is, i.e. (row label, column label(s)).
        score += df.loc[pair].sum()
    return score / len(transitions)

In [None]:
# Grid of solver settings; each triple is later passed to
# `TemporalProblem.solve` as (epsilon, tau_a, tau_b).
epsilon = (1e-4, 1e-3, 1e-2)
tau_a = (1, 0.99, 0.95, 0.9)
tau_b = (1, 0.99, 0.95, 0.9)

# Materialised as a list: `itertools.product` returns a one-shot iterator, so a
# loop cell that is re-run would otherwise see an empty grid and silently
# produce no scores.
configs = list(itertools.product(epsilon, tau_a, tau_b))


# LSI ATAC space

In [None]:
# TF-IDF transform of the ATAC data via muon, with scale_factor=1e4.
# The return value is not used, so the call is presumably in place on adata_atac.
ac.pp.tfidf(adata_atac, scale_factor=1e4)

In [None]:
# Scale every cell to a total of 1e4, then log1p-transform.
# `normalize_total` is the documented replacement for the deprecated
# `normalize_per_cell` and is the same call used for the GEX data further down.
# NOTE(review): `normalize_per_cell` could additionally drop cells below its
# minimum-count threshold, which would misalign the ATAC rows with `adata`
# when the embeddings are concatenated — confirm no cells were being removed.
# NOTE(review): TF-IDF has already been applied in the previous cell; confirm
# that this extra normalisation + log1p on top of it is intended.
sc.pp.normalize_total(adata_atac, target_sum=1e4)
sc.pp.log1p(adata_atac)

In [None]:
# LSI with 15 components via muon; the following cells read the result from
# adata_atac.obsm["X_lsi"].
ac.tl.lsi(adata_atac, n_comps=15)

In [None]:
# Pearson correlation of every LSI component with the per-cell ATAC count
# stored in obs["nCount_ATAC"].
lsi_embedding = adata_atac.obsm["X_lsi"]
atac_depth = adata_atac.obs["nCount_ATAC"]
depth_corr = []
for comp_idx in range(lsi_embedding.shape[1]):
    corr_matrix = np.corrcoef(lsi_embedding[:, comp_idx], atac_depth)
    depth_corr.append(corr_matrix[1, 0])

In [None]:
# Components whose absolute depth correlation exceeds 0.4 are marked for removal.
is_depth_driven = np.abs(depth_corr) > 0.4
comps_to_remove = np.where(is_depth_driven)[0]
comps_to_remove

In [None]:
# Indices of the LSI components to keep. Sorted explicitly because the
# iteration order of a set is not guaranteed, and the column order of the
# filtered embedding should be reproducible.
n_lsi_comps = adata_atac.obsm["X_lsi"].shape[1]
subset = sorted(set(range(n_lsi_comps)) - set(comps_to_remove))

In [None]:
# Keep only the columns listed in `subset`.
# NOTE(review): this overwrites "X_lsi" in place, so the cell is not idempotent —
# re-running it without recomputing LSI indexes the already-reduced matrix with
# the old `subset`.
adata_atac.obsm['X_lsi'] = adata_atac.obsm['X_lsi'][:,subset]

# PCA GEX space

In [None]:
# Scale every cell of the GEX data to a total of 1e4, then log1p-transform.
sc.pp.normalize_total(adata_gex, target_sum=1e4)
sc.pp.log1p(adata_gex)

In [None]:
# PCA with 15 components on the GEX data; the next cell reads the result from
# adata_gex.obsm["X_pca"].
sc.pp.pca(adata_gex, n_comps=15)

In [None]:
# Standardise the PCA embedding column-wise with scikit-learn's StandardScaler
# and keep the result under a separate obsm key.
pca_scaler = StandardScaler()
adata_gex.obsm["X_pca_scaled"] = pca_scaler.fit_transform(adata_gex.obsm["X_pca"])

# Concatenate

In [None]:
# Joint representation: scaled GEX PCs followed by the filtered ATAC LSI
# components, stacked column-wise and stored on the full object.
gex_embedding = adata_gex.obsm["X_pca_scaled"]
atac_embedding = adata_atac.obsm["X_lsi"]
adata.obsm["X_joint"] = np.concatenate((gex_embedding, atac_embedding), axis=1)

In [None]:
# Temporal problem on the joint PCA+LSI space, using obs["time"] as the time key
# and obs["a"] for the `a` argument.
tp0 = TemporalProblem(adata).prepare(
    "time",
    joint_attr="X_joint",
    a="a",
)

In [None]:
# Grid search on the joint PCA+LSI space: solve once per (epsilon, tau_a, tau_b)
# triple and score the resulting E14.5 -> E15.5 cell type transition matrix.
scores = {}
# The loop variable is deliberately not called `config`, so it does not
# overwrite the JAX config object imported under that name in the first cell.
for grid_point in list(configs):
    eps, unbalanced_a, unbalanced_b = grid_point
    tp0 = tp0.solve(
        epsilon=eps,
        tau_a=unbalanced_a,
        tau_b=unbalanced_b,
        max_iterations=int(1e8),  # iteration count passed as an int, not a float
    )
    ct_desc = tp0.cell_transition(
        source=14.5,
        target=15.5,
        source_groups="celltype",
        target_groups="celltype",
        forward=True,
    )
    # Keyed by the full parameter triple, as before.
    scores[grid_point] = compute_score(ct_desc)

In [None]:
# Turn the {config: score} dict into a table: wrapping it in a one-element list
# gives a single-row frame with one column per config, and the transpose turns
# each config into a row.
df_joint = pd.DataFrame.from_dict([scores]).transpose()

In [None]:
# Write the joint-space scores to a CSV file (path is relative to the working directory).
df_joint.to_csv("joint_pca_lsi_scores.csv")

# MultiVI spaces

In [None]:
# Object holding the shared embeddings; note that this rebinds `adata`, so the
# object loaded at the top of the notebook is no longer reachable by that name.
shared_embeddings_h5ad = "/lustre/groups/ml01/workspace/moscot_paper/pancreas/embeddings/pancreas_shared_embeddings.h5ad"
adata = sc.read(shared_embeddings_h5ad)

In [None]:
# Numeric time point per cell: sample "E14.5" -> 14.5, any other sample -> 15.5.
# Vectorised comparison instead of a row-wise `apply`, which calls a Python
# lambda once per cell.
adata.obs["time"] = np.where(adata.obs["sample"] == "E14.5", 14.5, 15.5)

In [None]:
# Read the marginals table again for the newly loaded object; the CSV column
# named "Unnamed: 0" becomes the row index.
# NOTE(review): presumably that index holds cell barcodes matching adata.obs_names — confirm.
marginals = pd.read_csv("marginals.csv", index_col="Unnamed: 0")
# Store the values in obs under "a"; that key is later handed to `prepare(..., a="a")`.
# NOTE(review): `marginals` is a DataFrame, so this assumes a single value column — confirm.
adata.obs["a"] = marginals

In [None]:
# Rebuild the parameter grid as a list rather than a one-shot iterator, so the
# loop cell below can be re-run without silently iterating over nothing.
configs = list(itertools.product(epsilon, tau_a, tau_b))

In [None]:
# Temporal problem on the embedding stored under obsm["X_multi_vi"].
tp1 = TemporalProblem(adata).prepare(
    "time",
    joint_attr="X_multi_vi",
    a="a",
)

In [None]:
# Grid search on the "X_multi_vi" embedding: solve once per
# (epsilon, tau_a, tau_b) triple and score the E14.5 -> E15.5 transitions.
scores_multivi = {}
# The loop variable is deliberately not called `config`, so it does not
# overwrite the JAX config object imported under that name in the first cell.
for grid_point in list(configs):
    eps, unbalanced_a, unbalanced_b = grid_point
    tp1 = tp1.solve(
        epsilon=eps,
        tau_a=unbalanced_a,
        tau_b=unbalanced_b,
        max_iterations=int(1e8),  # iteration count passed as an int, not a float
    )
    ct_desc = tp1.cell_transition(
        source=14.5,
        target=15.5,
        source_groups="celltype",
        target_groups="celltype",
        forward=True,
    )
    # Keyed by the full parameter triple, as before.
    scores_multivi[grid_point] = compute_score(ct_desc)

In [None]:
# Turn the {config: score} dict into a table: wrapping it in a one-element list
# gives a single-row frame with one column per config, and the transpose turns
# each config into a row.
df_multivi = pd.DataFrame.from_dict([scores_multivi]).transpose()

In [None]:
# Write the "X_multi_vi" scores to a CSV file (path is relative to the working directory).
df_multivi.to_csv("multivi_scores.csv")

In [None]:
# Rebuild the parameter grid as a list rather than a one-shot iterator, so the
# loop cell below can be re-run without silently iterating over nothing.
configs = list(itertools.product(epsilon, tau_a, tau_b))

In [None]:
# Temporal problem on the embedding stored under obsm["X_multi_vi_batch_corrected"].
tp2 = TemporalProblem(adata).prepare(
    "time",
    joint_attr="X_multi_vi_batch_corrected",
    a="a",
)

In [None]:
# Grid search on the "X_multi_vi_batch_corrected" embedding: solve once per
# (epsilon, tau_a, tau_b) triple and score the E14.5 -> E15.5 transitions.
scores_multivi_bc = {}
# The loop variable is deliberately not called `config`, so it does not
# overwrite the JAX config object imported under that name in the first cell.
for grid_point in list(configs):
    eps, unbalanced_a, unbalanced_b = grid_point
    tp2 = tp2.solve(
        epsilon=eps,
        tau_a=unbalanced_a,
        tau_b=unbalanced_b,
        max_iterations=int(1e8),  # iteration count passed as an int, not a float
    )
    ct_desc = tp2.cell_transition(
        source=14.5,
        target=15.5,
        source_groups="celltype",
        target_groups="celltype",
        forward=True,
    )
    # Keyed by the full parameter triple, as before.
    scores_multivi_bc[grid_point] = compute_score(ct_desc)

In [None]:
# Turn the {config: score} dict into a table: wrapping it in a one-element list
# gives a single-row frame with one column per config, and the transpose turns
# each config into a row.
df_multivi_bc = pd.DataFrame.from_dict([scores_multivi_bc]).transpose()

In [None]:
# Write the batch-corrected embedding scores to a CSV file (path is relative to the working directory).
df_multivi_bc.to_csv("multivi_bc_scores.csv")