In [1]:
import scanpy as sc
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import moscot
from moscot.problems.time import TemporalProblem
import moscot.plotting as mpl
import pandas as pd
import os
import muon
from ott.geometry import pointcloud
from sklearn.preprocessing import StandardScaler
import networkx as nx
import itertools
import anndata
from mudata import MuData
import jax.numpy as jnp
from typing import Dict, Tuple
from ott import tools
from tqdm import tqdm
import jax
sc.set_figure_params(scanpy=True, dpi=80, dpi_save=200)
                         
import mplscience

mplscience.available_styles()
mplscience.set_style(reset_current=True)
plt.rcParams['legend.scatterpoints'] = 1 

  @numba.jit()
  @numba.jit()
  @numba.jit()
  from .autonotebook import tqdm as notebook_tqdm
  @numba.jit()


['default', 'despine']


In [2]:
output_dir = "/lustre/groups/ml01/workspace/moscot_paper/pancreas_revision/stability_analysis"

In [3]:
mudata = muon.read("/lustre/groups/ml01/workspace/moscot_paper/pancreas_revision/mudata_with_annotation_all.h5mu")

endocrine_celltypes = [
    "Ngn3 low",
    "Ngn3 high",
    "Ngn3 high cycling",
    "Fev+",
    "Fev+ Alpha",
    "Fev+ Beta",
    "Fev+ Delta",
    "Eps. progenitors",
    "Alpha",
    "Beta",
    "Delta",
    "Epsilon"
]

mudata = mudata[mudata.obs["cell_type"].isin(endocrine_celltypes)].copy()

In [4]:
mudata.obsm

MuAxisArrays with keys: X_MultiVI, X_umap, atac, rna

In [5]:
RNA_emb = [None, "X_pca", "X_scVI"]
ATAC_emb = [None, "X_lsi", "X_poissonvi"]



In [6]:
EMBEDDINGS = list(itertools.product(RNA_emb, ATAC_emb)) + ["X_MultiVI"]

In [7]:
COSTS = [("sq_euclidean", None), ("cosine", None), ("geodesic", 5), ("geodesic", 10), ("geodesic", 30)]

In [8]:
EMB = "embedding"

In [9]:
def adapt_time(x):
        if x["stage"]=="E14.5":
            return 14.5
        if x["stage"]=="E15.5":
            return 15.5
        if x["stage"]=="E16.5":
            return 16.5
        raise ValueError
    
def create_adata(mudata: MuData, embedding: str) -> anndata.AnnData:

    adata = mudata["rna"]
    adata.obs["cell_type_refined"] = mudata.obs["cell_type_refined"]
    adata.obs['time'] = adata.obs.apply(adapt_time, axis=1).astype("category")
    if embedding == "X_MultiVI":
        adata.obsm[EMB] = mudata.obsm[embedding].copy()
    elif len(embedding)==2:
        rna_embedding, atac_embedding = embedding
        if rna_embedding is not None:
            rna_emb = mudata["rna"].obsm[rna_embedding]
            rna_emb_scaled = StandardScaler().fit_transform(rna_emb)
        if atac_embedding is not None:
            atac_emb = mudata["atac"].obsm[atac_embedding]
            atac_emb_scaled = StandardScaler().fit_transform(atac_emb)
        if rna_embedding is not None and atac_embedding is not None:
            emb = np.concatenate((rna_emb, atac_emb), axis=1)
        elif rna_embedding is None and atac_embedding is not None:
            emb = atac_emb
        elif rna_embedding is not None and atac_embedding is None:
            emb = rna_emb
        else:
            raise NotImplementedError
    
        adata.obsm[EMB] = emb
        return adata
    raise NotImplementedError


def create_graphs(adata: anndata.AnnData, n_neighbors: int) -> Dict[Tuple, pd.DataFrame]:
    dfs = {}
    batch_column = "time"
    unique_batches = [14.5, 15.5, 16.5]
    for i in range(len(unique_batches) - 1):
        batch1 = unique_batches[i]
        batch2 = unique_batches[i + 1]
    
        indices = np.where((adata.obs[batch_column] == batch1) | (adata.obs[batch_column] == batch2))[0]
        adata_subset = adata[indices]
        sc.pp.neighbors(adata_subset, use_rep=EMB, n_neighbors=n_neighbors)
        G = nx.from_numpy_array(adata_subset.obsp["connectivities"].A)
        assert nx.is_connected(G)
    
        dfs[(batch1, batch2)] = (
            pd.DataFrame(
                index=adata_subset.obs_names,
                columns=adata_subset.obs_names,
                data=adata_subset.obsp["connectivities"].A.astype("float"),
            )
        )
    return dfs

In [10]:
cm = jnp.ones((144, 144)) - jnp.eye(144)

def compute_metrics(df_reference: jax.Array, df: pd.DataFrame, emb: str, cost_0: str, cost_1: str) -> pd.DataFrame:
    
    sink_div = tools.sinkhorn_divergence.sinkhorn_divergence(geometry.Geometry, cost_matrix=(cm,cm,cm), a=df_reference.values.flatten(), b=df.values.flatten(), epsilon=1e-3).divergence
    eps_from_eps_prog = df.loc["Eps. progenitors", "Epsilon"]
    delta_from_fev_delta = df.loc["Delta", "Fev+ Delta"]
    fev_delta_from_eps_prog = df.loc["Eps. progenitors", "Fev+ Delta"]
    eps_from_fev_delta = df.loc["Fev+ Delta", "Epsilon"]
    print(emb, cost_0, cost_1, sink_div)
    print(eps_from_eps_prog, delta_from_fev_delta, fev_delta_from_eps_prog, eps_from_fev_delta)
    data = [[str(emb), str(cost_0), str(cost_1), sink_div, eps_from_eps_prog, delta_from_fev_delta, fev_delta_from_eps_prog, eps_from_fev_delta]]
    
    return pd.DataFrame(data=data, columns=["emb", "cost_0", "cost_1", "sink_div", "eps_from_eps_prog", "delta_from_dev_delta", "fev_delta_from_eps_prog", "eps_from_fev_delta"])
                            
                                      

In [11]:
tp_reference = TemporalProblem.load("/lustre/groups/ml01/workspace/moscot_paper/pancreas_revision/plots/OT_encodrine_analysis/TemporalProblem.pkl")


In [12]:
metrics_early = pd.DataFrame(columns=["emb", "cost_0", "cost_1", "sink_div", "eps_from_eps_prog", "delta_from_dev_delta", "fev_delta_from_eps_prog", "eps_from_fev_delta"])
metrics_late= pd.DataFrame(columns=["emb", "cost_0", "cost_1", "sink_div", "eps_from_eps_prog", "delta_from_dev_delta", "fev_delta_from_eps_prog", "eps_from_fev_delta"])

In [None]:
order_cell_types = list(adata.obs["cell_type"].cat.categories)

In [None]:
reference_tmap_early = tp_reference.cell_transition(14.5, 15.5, {"cell_type": order_cell_types}, {"cell_type": order_cell_types}, forward=False)
reference_tmap_late = tp_reference.cell_transition(15.5, 16.5, {"cell_type": order_cell_types}, {"cell_type": order_cell_types}, forward=False)
i=0
for emb in tqdm(EMBEDDINGS):
    if emb == (None, None):
        continue
    adata = create_adata(mudata, emb)
    for cost in COSTS:
    
        tp = TemporalProblem(adata)
        tp = tp.prepare("time", joint_attr=EMB, cost=cost[0] if cost[0]!="geodesic" else "sq_euclidean")

        if cost[0]=="geodesic":
            dfs = create_graphs(adata, cost[1])
            tp[14.5, 15.5].set_graph_xy((dfs[14.5, 15.5]).astype("float"), t=100.0)
            tp[15.5, 16.5].set_graph_xy((dfs[15.5, 16.5]).astype("float"), t=100.0)
        
        tp = tp.solve(lse_mode=False if cost[0]=="geodesic" else True, max_iterations=1e7, device='cpu')

        df_early = tp.cell_transition(14.5, 15.5, {"cell_type": order_cell_types}, {"cell_type": order_cell_types}, forward=False)
        df_late = tp.cell_transition(15.5, 16.5, {"cell_type": order_cell_types}, {"cell_type": order_cell_types}, forward=False)
        metrics_early = pd.concat((metrics_early, compute_metrics(reference_tmap_early, df_early, str(cost[0]), str(cost[1]), str(emb))))
        metrics_late = pd.concat((metrics_late, compute_metrics(reference_tmap_late, df_late, str(cost[0]), str(cost[1]), str(emb))))

        i+=1

  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True)
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True)
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True)
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True)
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(ann

[34mINFO    [0m Solving `[1;36m2[0m` problems                                                                                      
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m5185[0m, [1;36m1699[0m[1m)[0m[1m][0m.                                  
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m4761[0m, [1;36m5185[0m[1m)[0m[1m][0m.                                  


  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True)
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True)
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True)
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True)
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(ann

sq_euclidean None (None, 'X_lsi') 0.00049286336
0.07388929205126517 0.0011761732505029291 0.24813968657113114 0.03838017574168388


  metrics_early = pd.concat((metrics_early, compute_metrics(reference_tmap_early, df_early, str(cost[0]), str(cost[1]), str(emb))))


sq_euclidean None (None, 'X_lsi') 0.00027394667
0.05067166622634431 0.005210214984703463 0.11008996752062566 0.09194515461326355
[34mINFO    [0m Solving `[1;36m2[0m` problems                                                                                      


  metrics_late = pd.concat((metrics_late, compute_metrics(reference_tmap_late, df_late, str(cost[0]), str(cost[1]), str(emb))))
  if not (is_categorical_dtype(col) and is_numeric_dtype(col.cat.categories)):
  if not is_categorical_dtype(df_full[k]):


[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m5185[0m, [1;36m1699[0m[1m)[0m[1m][0m.                                  
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m4761[0m, [1;36m5185[0m[1m)[0m[1m][0m.                                  


  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True)
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True)
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True)
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True)
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True)
  if no

cosine None (None, 'X_lsi') 0.00087735057
0.056742096841788964 1.162978191893988e-15 0.20367795731499314 0.06458515945300955
cosine None (None, 'X_lsi') 0.0007263534
0.06757517501245093 0.007442131127813969 0.1219223581324375 0.1550180074557546


  if not (is_categorical_dtype(col) and is_numeric_dtype(col.cat.categories)):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):


[34mINFO    [0m Solving `[1;36m2[0m` problems                                                                                      
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m5185[0m, [1;36m1699[0m[1m)[0m[1m][0m.                                  


  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):




  if not is_categorical_dtype(df_full[k]):


[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m4761[0m, [1;36m5185[0m[1m)[0m[1m][0m.                                  


  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True)
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True)
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True)
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True)
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(ann

geodesic 5 (None, 'X_lsi') nan
nan nan nan nan
geodesic 5 (None, 'X_lsi') nan
nan nan nan nan


  if not (is_categorical_dtype(col) and is_numeric_dtype(col.cat.categories)):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):


[34mINFO    [0m Solving `[1;36m2[0m` problems                                                                                      
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m5185[0m, [1;36m1699[0m[1m)[0m[1m][0m.                                  


  if not is_categorical_dtype(df_full[k]):
  if not is_categorical_dtype(df_full[k]):


[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m4761[0m, [1;36m5185[0m[1m)[0m[1m][0m.                                  


In [None]:
metrics_early.to_csv(os.path.join(output_dir, f"stability_metrics_early.csv"))
metrics_late.to_csv(os.path.join(output_dir, f"stability_metrics_late.csv"))

In [None]:
cost[0], cost[1], str(emb)

In [None]:
data = np.ones((8,1)).T

In [None]:
pd.DataFrame(data=data, columns=["emb", "cost_0", "cost_1", "sink_div", "eps_from_eps_prog", "delta_from_dev_delta", "fev_delta_from_eps_prog", "eps_from_fev_delta"])

In [None]:
out

In [None]:
reference_tmap_early.shape

In [None]:
reference_tmap_early.values.flatten()[:,None].shape

In [None]:
df_early.values.flatten()[:,None].shape

In [None]:
a = np.array([[1,2], [23,3]])
b = np.array([[13,1], [21,3]])

In [None]:
sink_div = tools.sinkhorn_divergence.sinkhorn_divergence(pointcloud.PointCloud, x=a, y=b, epsilon=1e-1)
    

In [None]:
sink_div = tools.sinkhorn_divergence.sinkhorn_divergence(pointcloud.PointCloud, pointcloud.PointCloud(reference_tmap_early.values.flatten()[:,None]), pointcloud.PointCloud(df_early.values.flatten()[:,None]), epsilon=1e-1)
    

In [None]:
metrics_early.to_csv(os.path.join(output_dir, f"stability_metrics_early.csv"))
metrics_late.to_csv(os.path.join(output_dir, f"stability_metrics_late.csv"))