In [2]:
# Imports and global settings for this metacell moscot run.
import numpy as np
import anndata
import pandas as pd
import scanpy as sc
import scipy
import seaborn as sns
import matplotlib.pyplot as plt
import os
import time
# Optional: uncomment to ask JAX for the CPU backend.
# NOTE(review): presumably needs to run before JAX initialises a backend — confirm.
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
# NOTE(review): imported from a private moscot module (`_lineage`); may break across moscot versions.
from moscot.problems.time._lineage import TemporalProblem
import warnings
import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
# NOTE(review): `os` is already imported above; re-importing it here is redundant but harmless.
import sys,os
# Presumably makes the local helper modules `scripts` and `c2` (imported next) importable.
# NOTE(review): hard-coded absolute, user-specific path — consider a configurable project root.
sys.path.append('/home/icb/manuel.gander/mouse_atlas/notebook')
import scripts as scr
import c2
# Hide FutureWarnings and lower scanpy's log verbosity to keep cell output short.
warnings.simplefilter(action='ignore', category=FutureWarning)
sc.settings.verbosity = 0
from tqdm import tqdm

# Load data aggregated into SEACells metacells

In [3]:
# Developmental time points and a cell count for each one (presumably the estimated
# total number of cells in the embryo). The ratio between two entries is used later
# to rescale growth rates between consecutive time points.
ts=['E3.5', 'E4.5', 'E5.25', 'E5.5', 'E6.25', 'E6.5', 'E6.75', 'E7.0', 'E7.25', 'E7.5', 'E7.75', 'E8.0', 'E8.25', 'E8.5a', 'E8.5b', 'E9.5', 'E10.5', 'E11.5', 'E12.5', 'E13.5']
cells=[32, 80, 100, 120, 400, 660, 1720, 4500, 8200, 15000, 30000, 60000, 73000, 90000, 90000, 200000, 1100000, 2600000, 6000000, 13000000]
# Pair labels with counts explicitly rather than indexing with a hard-coded length,
# so adding or removing a time point cannot silently misalign or truncate the mapping.
assert len(ts) == len(cells), "every time point needs exactly one cell count"
Cell_number_dict = dict(zip(ts, cells))

In [4]:
# Metacell-level AnnData; the file name suggests it holds both E10.5 and E11.5 metacells.
# NOTE(review): this cell emitted `utils.warn_names_duplicates("obs")`, i.e. some obs
# names are not unique — keep that in mind wherever metacells are indexed by name.
adata = sc.read_h5ad('E105_E115_metacell_aggregated_anndata.h5ad')
day_labels = adata.obs['day']
# Store the time point as a categorical column.
adata.obs['day'] = day_labels.astype('category')

  utils.warn_names_duplicates("obs")


# Run moscot

In [47]:
# Source and target time points; both are keys into Cell_number_dict.
ts0, ts1 = 'E10.5', 'E11.5'

# Hyper-parameters forwarded to `tp.solve`: tau1 -> tau_a, tau2 -> tau_b.
# NOTE(review): in moscot/OTT a tau below 1 presumably relaxes that marginal
# (1 = balanced) — confirm against the moscot docs.
tau1, tau2 = 0.8, 0.99995
# Entropic regularisation strength passed as `epsilon`.
epsilon = 0.005

In [48]:
# The AnnData must hold exactly two time points; the unpacking fails otherwise.
day0, day1 = sorted(set(adata.obs['day']))

# Obs names of the metacells at each time point, in AnnData order.
all_obs_names = adata.obs.index
inds0 = list(all_obs_names[(adata.obs['day'] == day0).to_numpy()])
inds1 = list(all_obs_names[(adata.obs['day'] == day1).to_numpy()])

In [49]:
# Metacells aggregate different numbers of single cells, so each metacell's weight in
# the transport marginals is later scaled by how many cells it contains.
# ob0 / ob1 assign each single cell (one row per cell, per the counting below) to a
# metacell through their 'metacell' column.
# NOTE(review): pickle loading can execute code — only read trusted files.
ob0 = pd.read_pickle('E105_metacells.pkl')
ob1 = pd.read_pickle('E115_metacells.pkl')

# Cells per metacell, ordered by sorted metacell label. np.unique sorts and counts in
# one pass, replacing a full-table equality scan per label (O(n_cells * n_metacells)).
# NOTE(review): multiplying these into the marginals assumes each day's metacells
# appear in `adata` in this same sorted-label order — confirm.
a_labels, a_mult = np.unique(ob0['metacell'].to_numpy(), return_counts=True)
b_labels, b_mult = np.unique(ob1['metacell'].to_numpy(), return_counts=True)

In [50]:
# Build the moscot temporal problem on the metacell AnnData.
tp = TemporalProblem(adata)
# Score mouse proliferation / apoptosis gene sets; per the method name these scores
# feed the prior marginals.
tp.score_genes_for_marginals(
    gene_set_proliferation='mouse',
    gene_set_apoptosis='mouse',
)
# Create subproblems between consecutive 'day' values, using 'X_pcaS'
# (presumably an obsm embedding) as the joint representation.
tp = tp.prepare('day', joint_attr='X_pcaS')

In [51]:
# Re-weight the prior marginals by metacell size so each metacell counts in proportion
# to the number of single cells it aggregates.
# The problem key comes from the data-derived (day0, day1) rather than a hard-coded
# (10.5, 11.5), consistent with the other (day0, day1) lookups in this notebook.
# NOTE(review): this writes moscot's private `_a` / `_b` attributes and is not
# idempotent — re-running this cell without re-running `prepare` compounds the weights.
pair_key = (day0, day1)
pair_problem = tp[pair_key]
assert len(a_mult) == len(pair_problem.a), "a_mult does not match the number of source metacells"
assert len(b_mult) == len(pair_problem.b), "b_mult does not match the number of target metacells"
pair_problem._a = pair_problem.a * a_mult / a_mult.sum()
pair_problem._b = pair_problem.b * b_mult / b_mult.sum()

In [52]:
# Solve the prepared transport problem(s) with the configured hyper-parameters.
result = tp.solve(
    epsilon=epsilon,
    tau_a=tau1,
    tau_b=tau2,
)
# Iteration count reported by the underlying solver output (a private attribute).
pair_output = tp[(day0, day1)].solution._output
iterations = float(pair_output.n_iters)

[34mINFO    [0m Solving `[1;36m1[0m` problems                                                                                      
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m4573[0m, [1;36m7468[0m[1m)[0m[1m][0m.                                  


In [53]:
# Per-metacell growth estimate from the solved plan: rescale the solution's `a`
# marginal to mean 1, then by the ratio of cell numbers between ts1 and ts0.
gr = tp[(day0, day1)].solution.a
gr = gr / gr.mean() * Cell_number_dict[ts1] / Cell_number_dict[ts0]
# Entries below 1 count as net loss; their summed shortfall, averaged over all
# metacells, is reported as the apoptosis rate.
# NOTE(review): `solution.a` may also reflect the metacell-size weights applied to the
# prior marginals, and the average is unweighted across metacells — confirm this is
# the intended per-cell normalisation.
shrinking = gr < 1
cell_dying = np.sum(1 - gr[shrinking])
apoptosis_rate = float(cell_dying / len(gr))
apoptosis_rate

0.023928280919790268

In [59]:
# Dense transport matrix of the solved pair, labelled with the day0 metacell names as
# rows and the day1 metacell names as columns, then written to disk.
# The key and day filters come from the data-derived (day0, day1) instead of a
# hard-coded 10.5 / 11.5, so this cell stays consistent with the rest of the notebook.
# NOTE(review): obs names were reported as non-unique when loading; if any repeat
# within a single day, the row/column labels of T are ambiguous.
pair_key = (day0, day1)
M = np.array(tp.solutions[pair_key].transport_matrix)
source_names = adata.obs.index[(adata.obs['day'] == day0).to_numpy()]
target_names = adata.obs.index[(adata.obs['day'] == day1).to_numpy()]
T = pd.DataFrame(data=M, index=source_names, columns=target_names)
T.to_pickle('Metacell_T.pkl')

array([[0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ..., 0.0000000e+00,
        0.0000000e+00, 9.7563513e-10],
       [2.2820572e-22, 0.0000000e+00, 0.0000000e+00, ..., 1.8278701e-11,
        6.2529975e-23, 0.0000000e+00],
       [2.0755801e-21, 0.0000000e+00, 0.0000000e+00, ..., 1.0607004e-15,
        8.6965156e-29, 0.0000000e+00],
       ...,
       [1.1972049e-21, 0.0000000e+00, 0.0000000e+00, ..., 1.1185957e-11,
        3.5382883e-19, 0.0000000e+00],
       [4.9574332e-21, 0.0000000e+00, 0.0000000e+00, ..., 3.0868795e-12,
        1.0455959e-21, 0.0000000e+00],
       [9.7072443e-16, 1.1818551e-40, 1.8809112e-29, ..., 9.1501665e-22,
        3.7228707e-15, 6.6415752e-14]], dtype=float32)