In [1]:
import numpy as np
import anndata
import pandas as pd
import scanpy as sc
import scipy
from moscot.problems.time._lineage import TemporalProblem
import sys
sys.path.append('/home/mgander/mouse_atlas/Utils')
import c2


# Root directory of the mouse-atlas data; later cells build their input and
# output file paths from it with f-strings.
Path="/home/mgander/mouse_atlas/data/"

# Time point labels. Later cells process consecutive entries as a pair
# (ts[i], ts[i+1]).
ts=['E3.5', 'E4.5', 'E5.25', 'E5.5', 'E6.25', 'E6.5', 'E6.75', 'E7.0', 'E7.25', 'E7.5', 'E7.75', 'E8.0', 'E8.25', 'E8.5a', 'E8.5b', 'E9.5', 'E10.5', 'E11.5', 'E12.5', 'E13.5']

# Cell number for each time point, aligned position-by-position with `ts`.
# These cell numbers were determined by each lab that performed the sequencing
# of the embryos. The numbers are taken from http://tome.gs.washington.edu/
cells=[32, 80, 100, 120, 400, 660, 1720, 4500, 8200, 15000, 30000, 60000, 73000, 90000, 90000, 200000, 1100000, 2600000, 6000000, 13000000]

# Fail loudly if the two lists drift apart, instead of silently dropping
# entries or raising an IndexError halfway through building the mapping.
if len(ts) != len(cells):
    raise ValueError(f"`ts` has {len(ts)} entries but `cells` has {len(cells)}")

# Time point label -> cell number. Built with zip so that no hard-coded length
# is needed and no loop index `i` leaks into the notebook namespace (`i` is
# used further down to select the time pair).
Cell_number_dict=dict(zip(ts, cells))

In [2]:
# Time pairs for which a pull is additionally computed and saved.
# Key:   index i of the time pair (ts[i], ts[i+1]).
# Value: the label handed to tp.pull as `subset` (looked up via data='cell_state').
D_pull_population = {
    7: 'E7.25:Definitive endoderm',
    10: 'E8:Allantois',
    11: 'E8.25:First heart field',
    16: 'E11.5:Pancreatic epithelium',
}

In [3]:
# tau1 values, one per consecutive time pair: taus[i] belongs to the pair
# (ts[i], ts[i+1]) and is passed to tp.solve as `tau_a`.
taus = [
    0.99, 0.98, 0.99, 0.9, 0.95,    # i = 0..4
    0.95, 0.6, 0.8, 0.92, 0.92,     # i = 5..9
    0.87, 0.95, 0.93, 0.95, 0.98,   # i = 10..14
    0.65, 0.8, 0.87, 0.88,          # i = 15..18
]

In [10]:
# Index of the time pair to process: (ts[i], ts[i+1]).
i = 11
ts0 = ts[i]
ts1 = ts[i+1]
print(f'{ts0}_{ts1}')
print('------------------------')

# Load the combined AnnData object holding both time points of the pair.
adata = sc.read(f"{Path}/Comb_anndatas/adata_{ts0}_{ts1}.h5ad")
# The raw count matrix is not needed here, and it causes problems in
# "score_genes_for_marginals", so it is dropped right after loading.
del adata.raw
adata.obs['day'] = adata.obs['day'].astype('category')
# Unpacking into two names requires exactly two distinct 'day' values.
day0, day1 = sorted(set(adata.obs['day']))

# Cell types to exclude before building the problem, chosen by pair index:
#   i <= 4       -> nothing is excluded
#   4 < i < 14   -> the full list below
#   i >= 14      -> only 'Extraembryonic visceral endoderm'
# NOTE(review): "ExE" presumably stands for extraembryonic - confirm.
if i > 4:
    if i < 14:
        ExE_cell_types = [
            'Embryonic visceral endoderm',
            'Extraembryonic visceral endoderm',
            'Parietal endoderm',
            'Extraembryonic ectoderm',
            'Primitive erythroid cells',
            'Blood progenitors',
        ]
    else:
        ExE_cell_types = ['Extraembryonic visceral endoderm']
else:
    ExE_cell_types = []
adata = adata[~adata.obs['cell_type'].isin(ExE_cell_types)].copy()

# Set up the temporal problem on the filtered cells, with 'day' as the time
# key and the 'X_pcaS' representation as the joint attribute.
tp = TemporalProblem(adata)
# Marginal scoring is currently disabled:
#tp.score_genes_for_marginals(gene_set_proliferation='mouse',  gene_set_apoptosis='mouse')
tp = tp.prepare('day', joint_attr='X_pcaS')

# Solver settings; tau1 is specific to the time pair, the rest is fixed.
batch_size = 2500
eps = 0.005
tau1 = taus[i]
tau2 = 0.99995

tp.solve(
    batch_size=batch_size,
    epsilon=eps,
    tau_a=tau1,
    tau_b=tau2,
    scale_cost="mean",
    max_iterations=100_000,
)

E8.0_E8.25
------------------------
Only considering the two last: ['.25', '.h5ad'].
Only considering the two last: ['.25', '.h5ad'].




[34mINFO    [0m Ordering [1;35mIndex[0m[1m([0m[1m[[0m[32m'cell_30635'[0m, [32m'cell_30636'[0m, [32m'cell_30638'[0m, [32m'cell_30639'[0m, [32m'cell_30642'[0m,                     
                [32m'cell_30647'[0m, [32m'cell_30652'[0m, [32m'cell_30654'[0m, [32m'cell_30655'[0m, [32m'cell_30656'[0m,                              
                [33m...[0m                                                                                                
                [32m'cell_95713'[0m, [32m'cell_95715'[0m, [32m'cell_95716'[0m, [32m'cell_95717'[0m, [32m'cell_95718'[0m,                              
                [32m'cell_95719'[0m, [32m'cell_95721'[0m, [32m'cell_95722'[0m, [32m'cell_95723'[0m, [32m'cell_95725'[0m[1m][0m,                             
               [33mdtype[0m=[32m'object'[0m, [33mlength[0m=[1;36m25041[0m[1m)[0m in ascending order.                                                   


  if not (is_categorical_dtype(col) and is_numeric_dtype(col.cat.categories)):


[34mINFO    [0m Solving `[1;36m1[0m` problems                                                                                      
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m11674[0m, [1;36m13367[0m[1m)[0m[1m][0m.                                


TemporalProblem[(8.0, 8.25)]

In [11]:
# Take the `a` vector of the solution for this time pair.
# NOTE(review): assumed to hold one entry per day0 cell, in the same row order
# as adata.obs restricted to day0 - confirm, since it is used as such below.
gr = tp[(day0, day1)].solution.a
# Rescale so that the mean of `gr` equals the ratio of the cell numbers
# between the later and the earlier time point.
gr = gr / gr.mean() * Cell_number_dict[ts1] / Cell_number_dict[ts0]

# Summed shortfall below 1 over all entries with gr < 1, divided by the total
# number of entries.
cell_dying = np.sum(1 - gr[gr < 1])
apoptosis_rate = float(cell_dying / len(gr))

# Store the rescaled values, indexed by the day0 cell names.
df_gr = pd.DataFrame(
    data=gr,
    index=adata.obs[adata.obs['day'] == day0].index,
    columns=['trained_growth_rate'],
)
df_gr.to_pickle(f'{Path}/moscot_maps/growth_rates_and_pulls/{ts0}_{ts1}_growth_rates.pkl')


# Evaluation of the cell transitions is currently disabled:
#cell_states0={'cell_state': list(set(adata[adata.obs['day']==day0].obs['cell_state']))}
#cell_states1={'cell_state': list(set(adata[adata.obs['day']==day1].obs['cell_state']))}
#CT=tp.cell_transition(day0, day1, cell_states0, cell_states1)
#df_curated=c2.evaluate_using_curated_transitions(CT)
#df_germ=c2.evaluate_using_germ_layers(CT)
print(apoptosis_rate)

0.04827921465039253


In [12]:
# Only the time pairs listed in D_pull_population get a pull; every other
# pair skips this cell.
if i in D_pull_population:
    tp.pull(day0, day1, data='cell_state', subset=D_pull_population[i])
    # NOTE(review): relies on tp.pull having written its result to
    # adata.obs['pull'] (presumably moscot's default output key) - confirm.
    df_pull = adata.obs['pull']
    # The population label (including its ':' and spaces) is part of the file name.
    df_pull.to_pickle(f'{Path}/moscot_maps/growth_rates_and_pulls/{ts0}_{D_pull_population[i]}_pulls.pkl')