In [1]:
import jax
from jax import config
config.update("jax_enable_x64", True)

  from jax.config import config


In [2]:
import scanpy as sc
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import moscot
from moscot.problems.time import TemporalProblem
import moscot.plotting as mpl
import pandas as pd
import os
import muon
from ott.geometry import pointcloud
from sklearn.preprocessing import StandardScaler
import networkx as nx
import itertools
import anndata
from mudata import MuData
import jax.numpy as jnp
from typing import Dict, Tuple
from ott import tools
from tqdm import tqdm
import jax
sc.set_figure_params(scanpy=True, dpi=80, dpi_save=200)
                         
import mplscience

mplscience.available_styles()
mplscience.set_style(reset_current=True)
plt.rcParams['legend.scatterpoints'] = 1 

  @numba.jit()
  @numba.jit()
  @numba.jit()
  from .autonotebook import tqdm as notebook_tqdm
  @numba.jit()


['default', 'despine']


In [3]:
output_dir = "/lustre/groups/ml01/workspace/moscot_paper/pancreas_revision/stability_analysis"

In [4]:
mudata = muon.read("/lustre/groups/ml01/workspace/moscot_paper/pancreas_revision/mudata_with_annotation_all.h5mu")

endocrine_celltypes = [
    "Ngn3 low",
    "Ngn3 high",
    "Ngn3 high cycling",
    "Fev+",
    "Fev+ Alpha",
    "Fev+ Beta",
    "Fev+ Delta",
    "Eps. progenitors",
    "Alpha",
    "Beta",
    "Delta",
    "Epsilon"
]

mudata = mudata[mudata.obs["cell_type"].isin(endocrine_celltypes)].copy()

In [5]:
mudata.obsm

MuAxisArrays with keys: X_MultiVI, X_umap, atac, rna

In [6]:
RNA_emb = [None, "X_pca", "X_scVI"]
ATAC_emb = [None, "X_lsi", "X_poissonvi"]



In [7]:
EMBEDDINGS = ["X_MultiVI"]

In [8]:
COSTS = [("geodesic", 5)]

In [9]:
EMB = "embedding"

In [10]:
def adapt_time(x):
        if x["stage"]=="E14.5":
            return 14.5
        if x["stage"]=="E15.5":
            return 15.5
        if x["stage"]=="E16.5":
            return 16.5
        raise ValueError
    
def create_adata(mudata: MuData, embedding: str) -> anndata.AnnData:

    adata = mudata["rna"]
    adata.obs["cell_type_refined"] = mudata.obs["cell_type_refined"]
    adata.obs['time'] = adata.obs.apply(adapt_time, axis=1).astype("category")
    if embedding == "X_MultiVI":
        adata.obsm[EMB] = mudata.obsm[embedding].copy()
    elif len(embedding)==2:
        rna_embedding, atac_embedding = embedding
        if rna_embedding is not None:
            rna_emb = mudata["rna"].obsm[rna_embedding]
            rna_emb_scaled = StandardScaler().fit_transform(rna_emb)
        if atac_embedding is not None:
            atac_emb = mudata["atac"].obsm[atac_embedding]
            atac_emb_scaled = StandardScaler().fit_transform(atac_emb)
        if rna_embedding is not None and atac_embedding is not None:
            emb = np.concatenate((rna_emb, atac_emb), axis=1)
        elif rna_embedding is None and atac_embedding is not None:
            emb = atac_emb
        elif rna_embedding is not None and atac_embedding is None:
            emb = rna_emb
        else:
            raise NotImplementedError
    
        adata.obsm[EMB] = emb
        return adata
    raise NotImplementedError


def create_graphs(adata: anndata.AnnData, n_neighbors: int) -> Dict[Tuple, pd.DataFrame]:
    dfs = {}
    batch_column = "time"
    unique_batches = [14.5, 15.5, 16.5]
    for i in range(len(unique_batches) - 1):
        batch1 = unique_batches[i]
        batch2 = unique_batches[i + 1]
    
        indices = np.where((adata.obs[batch_column] == batch1) | (adata.obs[batch_column] == batch2))[0]
        adata_subset = adata[indices]
        sc.pp.neighbors(adata_subset, use_rep=EMB, n_neighbors=n_neighbors)
        G = nx.from_numpy_array(adata_subset.obsp["connectivities"].A)
        assert nx.is_connected(G)
    
        dfs[(batch1, batch2)] = (
            pd.DataFrame(
                index=adata_subset.obs_names,
                columns=adata_subset.obs_names,
                data=adata_subset.obsp["connectivities"].A.astype("float"),
            )
        )
    return dfs

In [11]:
cm = jnp.ones((144, 144)) - jnp.eye(144)

def compute_metrics(df_reference: jax.Array, df: pd.DataFrame, emb: str, cost_0: str, cost_1: str) -> pd.DataFrame:
    
    sink_div = tools.sinkhorn_divergence.sinkhorn_divergence(geometry.Geometry, cost_matrix=(cm,cm,cm), a=df_reference.values.flatten(), b=df.values.flatten(), epsilon=1e-3).divergence
    eps_from_eps_prog = df.loc["Eps. progenitors", "Epsilon"]
    delta_from_fev_delta = df.loc["Delta", "Fev+ Delta"]
    fev_delta_from_eps_prog = df.loc["Eps. progenitors", "Fev+ Delta"]
    eps_from_fev_delta = df.loc["Fev+ Delta", "Epsilon"]
    print(emb, cost_0, cost_1, sink_div)
    print(eps_from_eps_prog, delta_from_fev_delta, fev_delta_from_eps_prog, eps_from_fev_delta)
    data = [[str(emb), str(cost_0), str(cost_1), sink_div, eps_from_eps_prog, delta_from_fev_delta, fev_delta_from_eps_prog, eps_from_fev_delta]]
    
    return pd.DataFrame(data=data, columns=["emb", "cost_0", "cost_1", "sink_div", "eps_from_eps_prog", "delta_from_dev_delta", "fev_delta_from_eps_prog", "eps_from_fev_delta"])
                            
                                      

CUDA backend failed to initialize: Found CUDA version 12020, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [11]:
#tp_reference = TemporalProblem.load("/lustre/groups/ml01/workspace/moscot_paper/pancreas_revision/plots/OT_encodrine_analysis/TemporalProblem.pkl")


In [12]:
metrics_early = pd.DataFrame(columns=["emb", "cost_0", "cost_1", "sink_div", "eps_from_eps_prog", "delta_from_dev_delta", "fev_delta_from_eps_prog", "eps_from_fev_delta"])
metrics_late= pd.DataFrame(columns=["emb", "cost_0", "cost_1", "sink_div", "eps_from_eps_prog", "delta_from_dev_delta", "fev_delta_from_eps_prog", "eps_from_fev_delta"])

In [13]:
order_cell_types = list(adata.obs["cell_type"].cat.categories)

NameError: name 'adata' is not defined

In [18]:
emb_1 = None#"X_pca"#arguments[1] if arguments[1] != "None" else None
emb_2 = None#"X_poissonvi" #None#arguments[2] if arguments[2] != "None" else None
cost_1 = "geodesic"#arguments[3] if arguments[3] != "None" else None
cost_2 = 30#int(arguments[4]) if arguments[4] != "None" else None

EMB = "embedding"

def adapt_time(x):
        if x["stage"]=="E14.5":
            return 14.5
        if x["stage"]=="E15.5":
            return 15.5
        if x["stage"]=="E16.5":
            return 16.5
        raise ValueError
    
def create_adata(mudata: MuData, rna_embedding: str, atac_embedding: str) -> anndata.AnnData:
    print("rna_emb", rna_embedding)
    print("atac_emb", atac_embedding)
    adata = mudata["rna"]
    adata.obs["cell_type_refined"] = mudata.obs["cell_type_refined"]
    adata.obs['time'] = adata.obs.apply(adapt_time, axis=1).astype("category")
    if rna_embedding == "X_MultiVI":
        adata.obsm[EMB] = mudata.obsm[rna_embedding].copy()
        return adata
    else:
        if rna_embedding is not None:
            rna_emb = mudata["rna"].obsm[rna_embedding]
            rna_emb_scaled = StandardScaler().fit_transform(rna_emb)
        if atac_embedding is not None:
            atac_emb = mudata["atac"].obsm[atac_embedding]
            atac_emb_scaled = StandardScaler().fit_transform(atac_emb)
        if rna_embedding is not None and atac_embedding is not None:
            emb = np.concatenate((rna_emb_scaled, atac_emb_scaled), axis=1)
        elif rna_embedding is None and atac_embedding is not None:
            emb = atac_emb
        elif rna_embedding is not None and atac_embedding is None:
            emb = rna_emb
        else:
            raise NotImplementedError
    
        adata.obsm[EMB] = emb
        return adata
    raise NotImplementedError


def create_graphs(adata: anndata.AnnData, n_neighbors: int) -> Dict[Tuple, pd.DataFrame]:
    dfs = {}
    batch_column = "time"
    unique_batches = [14.5, 15.5, 16.5]
    for i in range(len(unique_batches) - 1):
        batch1 = unique_batches[i]
        batch2 = unique_batches[i + 1]
    
        indices = np.where((adata.obs[batch_column] == batch1) | (adata.obs[batch_column] == batch2))[0]
        adata_subset = adata[indices]
        sc.pp.neighbors(adata_subset, use_rep=EMB, n_neighbors=n_neighbors)
        G = nx.from_numpy_array(adata_subset.obsp["connectivities"].A)
        assert nx.is_connected(G)
    
        dfs[(batch1, batch2)] = (
            pd.DataFrame(
                index=adata_subset.obs_names,
                columns=adata_subset.obs_names,
                data=adata_subset.obsp["connectivities"].A.astype("float"),
            )
        )
    return dfs

cm = jnp.ones((144, 144)) - jnp.eye(144)

def compute_metrics(df_reference: jax.Array, df: pd.DataFrame, emb_0: str, emb_1: str, cost_0: str, cost_1: str) -> pd.DataFrame:
    
    sink_div = tools.sinkhorn_divergence.sinkhorn_divergence(geometry.Geometry, cost_matrix=(cm,cm,cm), a=df_reference.values.flatten(), b=df.values.flatten(), epsilon=1e-3).divergence
    eps_from_eps_prog = df.loc["Eps. progenitors", "Epsilon"]
    delta_from_fev_delta = df.loc["Fev+ Delta", "Delta"]
    fev_delta_from_eps_prog = df.loc["Eps. progenitors", "Fev+ Delta"]
    eps_from_fev_delta = df.loc["Fev+ Delta", "Epsilon"]
    beta_from_fev_beta = df.loc["Fev+ Beta", "Beta"]
    print(emb_0 , emb_1, cost_0, cost_1, sink_div)
    print(eps_from_eps_prog, delta_from_fev_delta, fev_delta_from_eps_prog, eps_from_fev_delta)
    data = [[str(emb_0),str(emb_0), str(cost_0), str(cost_1), sink_div, eps_from_eps_prog, delta_from_fev_delta, fev_delta_from_eps_prog, eps_from_fev_delta, beta_from_fev_beta]]
    
    return pd.DataFrame(data=data, columns=["emb_0", "emb_1", "cost_0", "cost_1", "sink_div", "eps_from_eps_prog", "delta_from_fev_delta", "fev_delta_from_eps_prog", "eps_from_fev_delta", "beta_from_fev_beta"])
                            
order_cell_types = endocrine_celltypes


#tp_reference = TemporalProblem.load("/lustre/groups/ml01/workspace/moscot_paper/pancreas_revision/plots/OT_encodrine_analysis/TemporalProblem.pkl")
metrics_early = pd.DataFrame(columns=["emb_0", "emb_1", "cost_0", "cost_1", "sink_div", "eps_from_eps_prog", "delta_from_fev_delta", "fev_delta_from_eps_prog", "eps_from_fev_delta", "beta_from_fev_beta"])
metrics_late= pd.DataFrame(columns=["emb_0", "emb_1", "cost_0", "cost_1", "sink_div", "eps_from_eps_prog", "delta_from_fev_delta", "fev_delta_from_eps_prog", "eps_from_fev_delta", "beta_from_fev_beta"])
            

#reference_tmap_early = tp_reference.cell_transition(14.5, 15.5, {"cell_type": order_cell_types}, {"cell_type": order_cell_types}, forward=False)
#reference_tmap_late = tp_reference.cell_transition(15.5, 16.5, {"cell_type": order_cell_types}, {"cell_type": order_cell_types}, forward=False)

adata = create_adata(mudata, emb_1, emb_2)

tp = TemporalProblem(adata)
tp = tp.prepare("time", joint_attr=EMB, cost=cost_1 if cost_1!="geodesic" else "sq_euclidean")

if cost_1=="geodesic":
    dfs = create_graphs(adata, cost_2)
    tp[14.5, 15.5].set_graph_xy((dfs[14.5, 15.5]).astype("float"), t=100.0)
    tp[15.5, 16.5].set_graph_xy((dfs[15.5, 16.5]).astype("float"), t=100.0)

tp[15.5, 16.5].solve(max_iterations=1e7, threshold=5e-3, lse_mode=False if cost_1=="geodesic" else True)

#df_early = tp.cell_transition(14.5, 15.5, {"cell_type": order_cell_types}, {"cell_type": order_cell_types}, forward=False)
#df_late = tp.cell_transition(15.5, 16.5, {"cell_type": order_cell_types}, {"cell_type": order_cell_types}, forward=False)
#metrics_early = pd.concat((metrics_early, compute_metrics(reference_tmap_early, df_early,emb_1, emb_2,  str(cost_1), str(cost_2))))
#metrics_late = pd.concat((metrics_late, compute_metrics(reference_tmap_late, df_late, emb_1, emb_2, str(cost_1), str(cost_2))))

#metrics_early.to_csv(os.path.join(output_dir, f"stability_metrics_early_{emb_1}_{emb_2}_{cost_1}_{cost_2}.csv"))
#metrics_late.to_csv(os.path.join(output_dir, f"stability_metrics_late_{emb_1}_{emb_2}_{cost_1}_{cost_2}.csv"))


rna_emb None
atac_emb None


NotImplementedError: 

In [15]:
tp[15.5, 16.5].solution.transport_matrix.sum(axis=0)

Array([0.00058688, 0.00058828, 0.00059004, ..., 0.00058786, 0.00059012,
       0.0005872 ], dtype=float64)

In [17]:
tp.cell_transition(15.5, 16.5, "cell_type", "cell_type")

Unnamed: 0,Alpha,Beta,Delta,Eps. progenitors,Epsilon,Fev+,Fev+ Alpha,Fev+ Beta,Fev+ Delta,Ngn3 high,Ngn3 high cycling,Ngn3 low
Alpha,0.420203,0.088592,0.057062,0.001652,0.126432,0.027327,0.142663,0.088995,0.010205,0.001095,0.000159,1.880891e-06
Beta,0.06001,0.355942,0.056873,9.7e-05,0.010282,0.017772,0.037434,0.154694,0.00276,0.000411,3.8e-05,3.048524e-07
Delta,0.005068,0.004137,0.368122,0.000193,0.002461,0.001505,0.002508,0.00483,0.004495,0.000215,3.2e-05,9.658101e-07
Eps. progenitors,0.005698,0.001217,0.008777,0.064053,0.072119,0.030769,0.013935,0.007578,0.074012,0.043329,0.035688,0.007123854
Epsilon,0.072347,0.008999,0.024386,0.01981,0.162071,0.007469,0.027521,0.013687,0.037053,0.002321,0.000804,5.064402e-05
Fev+,0.122321,0.117201,0.11889,0.012501,0.054899,0.1688,0.225395,0.204165,0.041209,0.027845,0.012048,0.0004190892
Fev+ Alpha,0.058518,0.024189,0.020325,0.000892,0.020315,0.023791,0.05806,0.038678,0.005335,0.002208,0.000693,1.564537e-05
Fev+ Beta,0.142342,0.344473,0.154926,0.001014,0.037775,0.069493,0.134014,0.259702,0.012986,0.003532,0.000718,1.016737e-05
Fev+ Delta,0.024009,0.006082,0.049992,0.055124,0.14323,0.030655,0.027246,0.016896,0.085861,0.025766,0.017354,0.002088319
Ngn3 high,0.08351,0.046702,0.130534,0.551027,0.296181,0.533752,0.302945,0.194555,0.543523,0.558853,0.490277,0.1992481


In [31]:
tp[15.5, 16.5].solution._output.errors.max()

Array(0.99999994, dtype=float32)

In [14]:
#reference_tmap_early = tp_reference.cell_transition(14.5, 15.5, {"cell_type": order_cell_types}, {"cell_type": order_cell_types}, forward=False)
#reference_tmap_late = tp_reference.cell_transition(15.5, 16.5, {"cell_type": order_cell_types}, {"cell_type": order_cell_types}, forward=False)
i=0
for emb in tqdm(EMBEDDINGS):
    if emb == (None, None):
        continue
    adata = create_adata(mudata, emb)
    for cost in COSTS:
    
        tp = TemporalProblem(adata)
        tp = tp.prepare("time", joint_attr=EMB, cost=cost[0] if cost[0]!="geodesic" else "sq_euclidean")

        if cost[0]=="geodesic":
            dfs = create_graphs(adata, cost[1])
            tp[14.5, 15.5].set_graph_xy((dfs[14.5, 15.5]).astype("float"), t=100.0)
            tp[15.5, 16.5].set_graph_xy((dfs[15.5, 16.5]).astype("float"), t=100.0)
        
        tp = tp.solve(lse_mode=False if cost[0]=="geodesic" else True, max_iterations=1e7, device='cpu')

        #df_early = tp.cell_transition(14.5, 15.5, {"cell_type": order_cell_types}, {"cell_type": order_cell_types}, forward=False)
        #df_late = tp.cell_transition(15.5, 16.5, {"cell_type": order_cell_types}, {"cell_type": order_cell_types}, forward=False)
        #metrics_early = pd.concat((metrics_early, compute_metrics(reference_tmap_early, df_early, str(cost[0]), str(cost[1]), str(emb))))
        #metrics_late = pd.concat((metrics_late, compute_metrics(reference_tmap_late, df_late, str(cost[0]), str(cost[1]), str(emb))))

        i+=1

  0%|          | 0/1 [00:00<?, ?it/s]


NotImplementedError: 

In [None]:
#metrics_early.to_csv(os.path.join(output_dir, f"stability_metrics_early.csv"))
#metrics_late.to_csv(os.path.join(output_dir, f"stability_metrics_late.csv"))

In [None]:
cost[0], cost[1], str(emb)

In [None]:
data = np.ones((8,1)).T

In [None]:
pd.DataFrame(data=data, columns=["emb", "cost_0", "cost_1", "sink_div", "eps_from_eps_prog", "delta_from_dev_delta", "fev_delta_from_eps_prog", "eps_from_fev_delta"])

In [None]:
out

In [None]:
reference_tmap_early.shape

In [None]:
reference_tmap_early.values.flatten()[:,None].shape

In [None]:
df_early.values.flatten()[:,None].shape

In [None]:
a = np.array([[1,2], [23,3]])
b = np.array([[13,1], [21,3]])

In [None]:
sink_div = tools.sinkhorn_divergence.sinkhorn_divergence(pointcloud.PointCloud, x=a, y=b, epsilon=1e-1)
    

In [None]:
sink_div = tools.sinkhorn_divergence.sinkhorn_divergence(pointcloud.PointCloud, pointcloud.PointCloud(reference_tmap_early.values.flatten()[:,None]), pointcloud.PointCloud(df_early.values.flatten()[:,None]), epsilon=1e-1)
    

In [None]:
metrics_early.to_csv(os.path.join(output_dir, f"stability_metrics_early.csv"))
metrics_late.to_csv(os.path.join(output_dir, f"stability_metrics_late.csv"))