In [1]:
import jax
# Pin JAX to the CPU backend. This is the first cell, so it runs before any other JAX/ott code.
jax.config.update('jax_platform_name', 'cpu')

In [2]:
import muon
import ott
import functools
import logging
import os
import typing as t

import anndata as ad
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import scanpy as sc
import scipy.sparse as sp
import scipy.stats as ss
from ott.geometry import costs, geometry, pointcloud
from ott.problems.linear import linear_problem, potentials
from ott.solvers.linear import sinkhorn
from ott.tools import sinkhorn_divergence
from sklearn import metrics, model_selection
from ott.geometry import costs as sparse_costs

# Plot styling: scanpy figure defaults first, then the mplscience style, then a legend tweak.
# Order matters: later calls can override matplotlib settings made by earlier ones.
sc.set_figure_params(scanpy=True, dpi=80, dpi_save=200)

import mplscience

# NOTE(review): looks like leftover exploration (lists available styles); presumably has no
# effect on settings — confirm and remove.
mplscience.available_styles()
mplscience.set_style(reset_current=True)
# Show a single marker per legend entry in scatter plots.
plt.rcParams['legend.scatterpoints'] = 1 

  from .autonotebook import tqdm as notebook_tqdm


['default', 'despine']


In [3]:
# Directory holding the annotated pancreas MuData. Defaults to the cluster workspace;
# set the MOSCOT_PANCREAS_DIR environment variable to run against a local copy.
PANCREAS_DIR = os.environ.get(
    "MOSCOT_PANCREAS_DIR",
    "/lustre/groups/ml01/workspace/moscot_paper/pancreas_revision",
)
mudata = muon.read(os.path.join(PANCREAS_DIR, "mudata_with_annotation_all.h5mu"))



In [4]:
# Endocrine cell-type labels kept for the analysis, grouped by label prefix.
endocrine_celltypes = (
    ["Ngn3 low", "Ngn3 high"]
    + ["Fev+", "Fev+ Alpha", "Fev+ Beta", "Fev+ Delta"]
    + ["Eps. progenitors"]
    + ["Alpha", "Beta", "Delta", "Epsilon"]
)

In [5]:
# Restrict to the endocrine cell types. `.copy()` turns the subset into a standalone object,
# so the `mudata.obs["time"]` assignment below modifies real data rather than a view
# (assigning on the view emitted a warning in the saved output of that cell).
mudata = mudata[mudata.obs["cell_type"].isin(endocrine_celltypes)].copy()

In [6]:
def adapt_time(x):
    """Map a row's developmental ``stage`` label to a numeric time point.

    Parameters
    ----------
    x
        Mapping-like row (here a row of ``mudata.obs`` passed by ``DataFrame.apply``)
        with a ``"stage"`` entry.

    Returns
    -------
    float
        14.5, 15.5 or 16.5 for ``"E14.5"``, ``"E15.5"`` or ``"E16.5"`` respectively.

    Raises
    ------
    ValueError
        If the stage is none of the supported labels; the message names the value.
    """
    stage_to_time = {"E14.5": 14.5, "E15.5": 15.5, "E16.5": 16.5}
    stage = x["stage"]
    # Compare with == (not a dict lookup) to keep the original matching semantics.
    for label, time_point in stage_to_time.items():
        if stage == label:
            return time_point
    raise ValueError(
        f"Unknown developmental stage {stage!r}; expected one of {list(stage_to_time)}"
    )

# Convert each cell's stage label to its time point, then store it as a categorical column.
time_per_cell = mudata.obs.apply(adapt_time, axis=1)
mudata.obs["time"] = time_per_cell.astype("category")

  mudata.obs['time'] = mudata.obs.apply(adapt_time, axis=1).astype("category")


In [7]:
# Work on the RNA modality of `mudata` from here on.
adata=mudata["rna"]


In [8]:
# Copy the refined annotation, the UMAP embedding and the time point from `mudata` onto the
# RNA modality. NOTE(review): the obs columns are assigned via pandas (index-aligned), but
# obsm is assigned positionally — assumes both objects list the same cells in the same order.
adata.obs["cell_type_refined"] = mudata.obs["cell_type_refined"]
adata.obsm["X_umap"] = mudata.obsm["X_umap"]
adata.obs["time"] = mudata.obs["time"] 

In [9]:
# Start from raw counts. The `.copy()` is essential: without it `adata.X` is the very same
# object as `layers["raw_counts"]`, and the normalize_total/log1p calls below can scale that
# matrix in place, silently turning the stored "raw" counts into log-normalised values.
adata.X = adata.layers["raw_counts"].copy()

In [10]:
# Library-size normalisation followed by log(1 + x); scanpy updates `adata` directly.
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)



In [11]:
# Keep the 5,000 most highly variable genes of the log-normalised data; `subset=True`
# drops every other gene from `adata`.
N_TOP_HVG = 5000
sc.pp.highly_variable_genes(adata, n_top_genes=N_TOP_HVG, subset=True, inplace=True)

  disp_grouped = df.groupby("mean_bin")["dispersions"]


In [12]:
def _to_dense(matrix):
    """Return `matrix` as a dense ndarray, whether it is scipy-sparse or already dense.

    Replaces the `.A` shortcut, which SciPy deprecated and then removed from sparse
    matrices, and which plain ndarrays never had.
    """
    return matrix.toarray() if sp.issparse(matrix) else np.asarray(matrix)


# Dense expression (cells x genes) of the source (E15.5) and target (E16.5) time points.
gex_early = _to_dense(adata[adata.obs["time"] == 15.5].X)
gex_late = _to_dense(adata[adata.obs["time"] == 16.5].X)

In [13]:
# Source (E15.5) and target (E16.5) point clouds for the OT map.
x, y = gex_early, gex_late

In [14]:
solver = jax.jit(sinkhorn.Sinkhorn())


def entropic_map(
    x, y, cost_fn: costs.TICost, batch_size: t.Optional[int] = 1
) -> t.Callable[[jnp.ndarray], jnp.ndarray]:
    """Fit an entropic OT map from point cloud ``x`` to point cloud ``y``.

    Solves the entropy-regularised OT problem between ``x`` and ``y`` with the
    module-level jitted ``solver``. It returns the ``transport`` function of the
    resulting dual potentials, not an array.

    Parameters
    ----------
    x, y
        Source and target point clouds (in this notebook: cells x genes expression).
    cost_fn
        Translation-invariant cost, e.g. ``costs.ElasticL1``.
    batch_size
        Forwarded to ``pointcloud.PointCloud``. Defaults to 1, the value previously
        hard-coded; larger values (or ``None``) let ott process more rows per step.

    Returns
    -------
    Callable that maps an array of source points to their transported images.
    """
    geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn, batch_size=batch_size)
    output = solver(linear_problem.LinearProblem(geom))
    return output.to_dual_potentials().transport

In [15]:
# ott's elastic-L1 cost: its L1 term (weighted by `scaling_reg`) favours displacements
# that touch few genes, which `gene_mask` below quantifies.
elastic_l1_cost = costs.ElasticL1(scaling_reg=20.0)
map_l1 = entropic_map(x, y, cost_fn=elastic_l1_cost)

In [None]:
from tqdm import tqdm

# Push the source cells through the fitted map in chunks of `batch_size` rows.
batch_size = 100
push_forward = [
    map_l1(x[start:start + batch_size])
    for start in tqdm(range(0, len(x), batch_size))
]

  0%|                                                                                                | 0/52 [00:00<?, ?it/s]

In [None]:
# Flatten the per-batch push-forwards into a list with one row per source cell.
# Each batch is moved to host memory once (`np.asarray`) instead of slicing the JAX
# array row by row, which dispatched one device operation per cell.
ls_flattened = []
for batch in push_forward:
    ls_flattened.extend(np.asarray(batch))

In [None]:
# Stack the per-cell rows into one array of mapped expression (expected to match `x`'s shape).
res=np.asarray(ls_flattened)

In [None]:
# Entries the map actually moved: |T(x) - x| above a small tolerance (per cell, per gene).
DISPLACEMENT_TOL = 1e-6
# Guard against silent broadcasting if the push-forward did not come back cell-for-cell.
assert res.shape == x.shape, f"push-forward shape {res.shape} != source shape {x.shape}"
gene_mask = np.abs(res - x) > DISPLACEMENT_TOL

In [None]:
# E15.5 cells, selected exactly like `gex_early` so rows line up with `gene_mask`.
# `.copy()` makes a standalone AnnData, so the layer/column assignments that follow
# do not act on (and implicitly materialise) a view of `adata`.
adata_15 = adata[adata.obs["time"] == 15.5].copy()

In [None]:
# Store the per-cell, per-gene displacement mask as a layer of the E15.5 AnnData.
adata_15.layers["sparse_mask"] = gene_mask

In [None]:
import os

# Persist the E15.5 AnnData together with its sparse-displacement mask.
output_dir = "/lustre/groups/ml01/workspace/moscot_paper/pancreas_revision/sparse_monge"
# Make sure the target directory exists before writing the file.
os.makedirs(output_dir, exist_ok=True)
adata_15.write(os.path.join(output_dir, "adata_15_with_sparse_mask_balanced_5000_hvg.h5ad"))

In [None]:
# Per gene: in how many E15.5 cells the map moves it; per cell: how many genes it moves.
mask_15 = adata_15.layers["sparse_mask"]
adata_15.var["sparse_counts"] = mask_15.sum(axis=0)
adata_15.obs["sparse_counts"] = mask_15.sum(axis=1)

In [None]:
# Fraction of E15.5 cells displacing each of the five most frequently moved genes.
top5_counts = adata_15.var["sparse_counts"].sort_values().iloc[-5:]
top5_counts / adata_15.n_obs

In [None]:
# UMAP of E15.5 cells coloured by the obs column `sparse_counts` (genes moved per cell).
sc.pl.umap(adata_15, color='sparse_counts')

In [None]:
# Per endocrine cell type at E15.5:
#   res_dict[ct]   -> the five genes displaced in the largest fraction of that type's cells
#   res_dict_2[ct] -> variance of the per-gene displacement fractions
# Works on row subsets of the mask directly. Writing columns into AnnData views forced a
# full copy of every subset, and the per-cell `obs` counts it also wrote were never used.
res_dict = {}
res_dict_2 = {}
sparse_mask_15 = adata_15.layers["sparse_mask"]
for cell_type in endocrine_celltypes:
    in_type = (adata_15.obs["cell_type"] == cell_type).to_numpy()
    n_cells = int(in_type.sum())
    if n_cells == 0:
        # Cell type absent at E15.5: every fraction would be 0/0, so leave it out.
        continue
    counts = pd.Series(
        # np.asarray(...).ravel() also handles a scipy-sparse mask (whose sum is a 1 x n matrix).
        np.asarray(sparse_mask_15[in_type].sum(axis=0)).ravel(),
        index=adata_15.var_names,
        name="sparse_counts",
    )
    res_dict[cell_type] = counts.sort_values().iloc[-5:] / n_cells
    res_dict_2[cell_type] = (counts / n_cells).var()

In [None]:
# Top displaced genes (as fractions of cells) per cell type.
res_dict

In [None]:
# Variance of per-gene displacement fractions per cell type.
res_dict_2

In [None]:
import pandas as pd

# One row per cell type; the single column `0` holds the variance of its per-gene
# displacement fractions.
df = pd.DataFrame.from_dict(data=res_dict_2, orient="index")

In [None]:

# Cell types ordered by increasing variance of displacement fractions.
df.sort_values(0)

In [None]:
# PCA of the E15.5 expression, then a 50-nearest-neighbour graph in PCA space; its
# `obsp["distances"]` defines the per-cell neighbourhoods used for the Jaccard scores.
sc.pp.pca(adata_15)
sc.pp.neighbors(adata_15, use_rep="X_pca", n_neighbors=50)

In [None]:
def jaccard_similarity(set1, set2):
    """Return the Jaccard index |A ∩ B| / |A ∪ B| of two sets.

    Two empty sets are treated as identical (similarity 1.0) instead of raising
    ZeroDivisionError, e.g. for a cell that displaces no genes and whose neighbours
    displace none either. ``set2`` may be any iterable, as with ``set.union``.
    """
    union = len(set1.union(set2))
    if union == 0:
        return 1.0
    return len(set1.intersection(set2)) / union

# For every E15.5 cell, compare the genes the map displaces in that cell with the genes
# displaced anywhere in its kNN neighbourhood; `js` collects the Jaccard similarities.
raw_mask = adata_15.layers["sparse_mask"]
sparse_mask = (raw_mask.toarray() if sp.issparse(raw_mask) else np.asarray(raw_mask)).astype(bool)
knn_graph = sp.csr_matrix(adata_15.obsp["distances"])
gene_names = adata_15.var_names

js = []
for i in tqdm(range(adata_15.n_obs)):
    row = knn_graph[i]
    # Neighbours = columns with a positive stored distance in row i.
    neighbors = row.indices[row.data > 0]
    genes_neighborhood = set(gene_names[sparse_mask[neighbors].any(axis=0)])
    # Use this cell's own mask row. The previous code read `adata_15[i].var["sparse_counts"]`,
    # i.e. the dataset-wide per-gene counts, so every cell was compared with the same gene set.
    genes_cell = set(gene_names[sparse_mask[i]])
    js.append(jaccard_similarity(genes_neighborhood, genes_cell))
    
    

In [None]:
# Jaccard dissimilarity: 1 means the cell shares none of its displaced genes with its neighbourhood.
adata_15.obs["js_inv"] = np.subtract(1, np.asarray(js))

In [None]:
# UMAP of E15.5 cells coloured by neighbourhood Jaccard dissimilarity.
sc.pl.umap(adata_15, color="js_inv")

In [None]:
# Mean Jaccard dissimilarity per refined cell type, lowest (most neighbourhood-consistent) first.
(
    adata_15.obs
    .groupby("cell_type_refined")["js_inv"]
    .mean()
    .to_frame()
    .sort_values("js_inv")
)