In [1]:
import scanpy as sc
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import moscot
from moscot.problems.time import TemporalProblem
import moscot.plotting as mpl
import pandas as pd
import os
import muon
from ott.geometry import pointcloud, geometry
from sklearn.preprocessing import StandardScaler
import networkx as nx
import itertools
import anndata
from mudata import MuData
import jax.numpy as jnp
from typing import Dict, Tuple
from ott import tools
from tqdm import tqdm
import jax
# Scanpy figure defaults: 80 dpi for on-screen figures, 200 dpi for saved ones.
sc.set_figure_params(scanpy=True, dpi=80, dpi_save=200)

import mplscience

# List the styles mplscience ships (the captured cell output shows
# ['default', 'despine']).
mplscience.available_styles()
# Apply the mplscience style.
# NOTE(review): reset_current=True presumably resets the current rcParams
# (including the scanpy settings above) before applying the style — confirm.
mplscience.set_style(reset_current=True)
# Set after set_style so this override is not lost: one marker per scatter
# entry in legends.
plt.rcParams['legend.scatterpoints'] = 1 

  @numba.jit()
  @numba.jit()
  @numba.jit()
  from .autonotebook import tqdm as notebook_tqdm
  @numba.jit()


['default', 'despine']


In [2]:
# Directory the stability-metric CSV files are written to by the final
# to_csv calls of this notebook. This is an absolute, cluster-specific path:
# adjust it when running elsewhere.
output_dir = "/lustre/groups/ml01/workspace/moscot_paper/pancreas_revision/stability_analysis"

In [3]:
# Cell-type labels that are kept for the analysis; every other label in
# obs["cell_type"] is dropped below.
endocrine_celltypes = [
    "Ngn3 low",
    "Ngn3 high",
    "Ngn3 high cycling",
    "Fev+",
    "Fev+ Alpha",
    "Fev+ Beta",
    "Fev+ Delta",
    "Eps. progenitors",
    "Alpha",
    "Beta",
    "Delta",
    "Epsilon",
]

# Read the annotated multimodal dataset from disk.
mudata = muon.read("/lustre/groups/ml01/workspace/moscot_paper/pancreas_revision/mudata_with_annotation_all.h5mu")

# Boolean mask over observations, then an independent copy of the subset so
# later edits do not operate on a view of the full object.
is_endocrine = mudata.obs["cell_type"].isin(endocrine_celltypes)
mudata = mudata[is_endocrine].copy()

In [4]:
# Display the obsm container of the filtered MuData to see which embeddings /
# per-observation arrays are available.
mudata.obsm

MuAxisArrays with keys: X_MultiVI, X_umap, atac, rna

In [5]:
# 144 x 144 cost matrix with 0 on the diagonal and 1 everywhere else.
# compute_metrics passes it as all three cost matrices of the Sinkhorn
# divergence between two flattened transition tables, so the flattened tables
# must have exactly 144 entries.
# NOTE(review): 144 presumably equals 12 * 12 (12 endocrine cell types on each
# axis of the transition table) — confirm, and derive it from the data if so.
cm = jnp.ones((144, 144)) - jnp.eye(144)

def compute_metrics(df_reference: pd.DataFrame, df: pd.DataFrame, emb_0: str, emb_1: str, cost_0: str, cost_1: str) -> pd.DataFrame:
    """Compare a cell-type transition table against a reference table.

    Parameters
    ----------
    df_reference
        Reference transition table. Only its flattened values are used, as the
        first weight vector of the Sinkhorn divergence.
    df
        Transition table to evaluate. Its flattened values form the second
        weight vector, and individual entries are read by label via
        ``df.loc[row_label, column_label]``; the labels "Eps. progenitors",
        "Epsilon", "Fev+ Delta", "Delta", "Fev+ Beta", "Beta" and "Ngn3 low"
        must therefore be present.
    emb_0, emb_1, cost_0, cost_1
        Identifiers of the configuration being evaluated. They are converted
        with ``str`` and stored verbatim in the returned row.

    Returns
    -------
    pd.DataFrame
        A single-row frame with the columns ``emb_0``, ``emb_1``, ``cost_0``,
        ``cost_1``, ``sink_div`` and the six selected transition entries.

    Notes
    -----
    Relies on the module-level cost matrix ``cm``; both flattened tables must
    have as many entries as ``cm`` has rows.
    Also prints the identifiers, the divergence and four of the entries.
    """
    # Sinkhorn divergence between the two flattened tables, using the same
    # cost matrix for the xy, xx and yy terms.
    sink_div = tools.sinkhorn_divergence.sinkhorn_divergence(
        geometry.Geometry,
        cost_matrix=(cm, cm, cm),
        a=df_reference.values.flatten(),
        b=df.values.flatten(),
        epsilon=1e-3,
    ).divergence

    # Individual transition entries tracked across configurations.
    eps_from_eps_prog = df.loc["Eps. progenitors", "Epsilon"]
    delta_from_fev_delta = df.loc["Fev+ Delta", "Delta"]
    fev_delta_from_eps_prog = df.loc["Eps. progenitors", "Fev+ Delta"]
    eps_from_fev_delta = df.loc["Fev+ Delta", "Epsilon"]
    beta_from_fev_beta = df.loc["Fev+ Beta", "Beta"]
    delta_from_ngn3_low = df.loc["Ngn3 low", "Delta"]

    print(emb_0 , emb_1, cost_0, cost_1, sink_div)
    print(eps_from_eps_prog, delta_from_fev_delta, fev_delta_from_eps_prog, eps_from_fev_delta)

    # Column name -> value, in output order. The "emb_1" entry previously
    # repeated emb_0, so the second embedding was never recorded.
    row = {
        "emb_0": str(emb_0),
        "emb_1": str(emb_1),
        "cost_0": str(cost_0),
        "cost_1": str(cost_1),
        "sink_div": sink_div,
        "eps_from_eps_prog": eps_from_eps_prog,
        "delta_from_fev_delta": delta_from_fev_delta,
        "fev_delta_from_eps_prog": fev_delta_from_eps_prog,
        "eps_from_fev_delta": eps_from_fev_delta,
        "beta_from_fev_beta": beta_from_fev_beta,
        "delta_from_ngn3_low": delta_from_ngn3_low,
    }

    return pd.DataFrame(data=[list(row.values())], columns=list(row))
                              
                                      

CUDA backend failed to initialize: Found CUDA version 12020, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [6]:
# Load the previously saved reference TemporalProblem; its transition tables
# are what the baseline below is compared against.
# NOTE(review): the ".pkl" extension suggests a pickle-based format — only load
# files from a trusted source.
tp_reference = TemporalProblem.load("/lustre/groups/ml01/workspace/moscot_paper/pancreas_revision/plots/OT_encodrine_analysis/TemporalProblem.pkl")


In [7]:
# Column layout of the metric tables. It must match the frame returned by
# compute_metrics (which reports "emb_0" and "emb_1", not a single "emb");
# otherwise pd.concat adds an always-empty "emb" column and moves the real
# embedding columns to the end of the table.
METRIC_COLUMNS = [
    "emb_0",
    "emb_1",
    "cost_0",
    "cost_1",
    "sink_div",
    "eps_from_eps_prog",
    "delta_from_fev_delta",
    "fev_delta_from_eps_prog",
    "eps_from_fev_delta",
    "beta_from_fev_beta",
    "delta_from_ngn3_low",
]

# Empty accumulators, one per time-point pair (14.5 -> 15.5 and 15.5 -> 16.5).
metrics_early = pd.DataFrame(columns=METRIC_COLUMNS)
metrics_late = pd.DataFrame(columns=METRIC_COLUMNS)

In [8]:
# Category labels of the reference problem's "cell_type" annotation, in the
# order stored on the categorical; used as the row/column order of every
# transition table computed below.
order_cell_types = tp_reference.adata.obs["cell_type"].cat.categories.tolist()

In [9]:
def _uniform_coupling(shape):
    """Return an array of ``shape`` whose entries all equal 1 / (shape[0] * shape[1])."""
    return np.full(shape, 1.0 / (shape[0] * shape[1]))


# Constant matrices with the same shape as the reference problems for the two
# time-point pairs; each one sums to 1.
outer_tmat_early = _uniform_coupling(tp_reference[14.5, 15.5].shape)
outer_tmat_late = _uniform_coupling(tp_reference[15.5, 16.5].shape)

In [10]:
def _cell_type_transition(problem, source, target):
    """Cell-type transition table of ``problem`` between two time points.

    Rows and columns both follow ``order_cell_types``; the table is computed
    with ``forward=False``.
    """
    return problem.cell_transition(
        source,
        target,
        {"cell_type": order_cell_types},
        {"cell_type": order_cell_types},
        forward=False,
    )


# Transition tables of the loaded reference problem.
reference_tmap_early = _cell_type_transition(tp_reference, 14.5, 15.5)
reference_tmap_late = _cell_type_transition(tp_reference, 15.5, 16.5)

# A fresh problem on the same AnnData whose solutions are replaced by the
# constant matrices instead of being solved.
tp = TemporalProblem(tp_reference.adata)
tp = tp.prepare("time", joint_attr="X_pca")

tp[14.5, 15.5].set_solution(outer_tmat_early)
tp[15.5, 16.5].set_solution(outer_tmat_late)

# Transition tables induced by the constant matrices.
df_early = _cell_type_transition(tp, 14.5, 15.5)
df_late = _cell_type_transition(tp, 15.5, 16.5)

# Append one metrics row per time-point pair; all identifiers are "ref".
row_early = compute_metrics(reference_tmap_early, df_early, "ref", "ref", "ref", "ref")
metrics_early = pd.concat((metrics_early, row_early))
row_late = compute_metrics(reference_tmap_late, df_late, "ref", "ref", "ref", "ref")
metrics_late = pd.concat((metrics_late, row_late))

ref ref ref ref 6.8868113
0.034236504935937835 0.025204788909892875 0.03423650493593782 0.02520478890989288
ref ref ref ref 7.550257
0.019093539054966253 0.025458052073288332 0.01909353905496625 0.025458052073288336


In [11]:
# Write both metric tables into output_dir (index included, pandas default).
early_csv_path = os.path.join(output_dir, "stability_metrics_early_outer.csv")
late_csv_path = os.path.join(output_dir, "stability_metrics_late_outer.csv")

metrics_early.to_csv(early_csv_path)
metrics_late.to_csv(late_csv_path)