In [None]:
import numpy as np
import anndata
import pandas as pd
import scanpy as sc
import scipy
import seaborn as sns
import matplotlib.pyplot as plt
#import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
from moscot.problems.time._lineage import TemporalProblem

import warnings
# Hide every FutureWarning for the rest of the session and turn scanpy's
# logging verbosity down to 0, the quietest level.
warnings.simplefilter('ignore', FutureWarning)
sc.settings.verbosity = 0

In [None]:
import os

# Root directory of the mouse-atlas data. The default is the original cluster path;
# set the MOUSE_ATLAS_DATA environment variable to run the notebook elsewhere.
Path = os.environ.get("MOUSE_ATLAS_DATA", "/home/icb/manuel.gander/mouse_atlas/data")
# Ordered time points; consecutive entries (ts[i], ts[i+1]) form the pairs processed below.
# NOTE(review): E8.5a / E8.5b presumably denote two distinct E8.5 datasets — confirm.
ts = ['E3.5', 'E4.5', 'E5.25', 'E5.5', 'E6.25', 'E6.5', 'E6.75', 'E7.0', 'E7.25', 'E7.5',
      'E7.75', 'E8.0', 'E8.25', 'E8.5a', 'E8.5b', 'E9.5', 'E10.5', 'E11.5', 'E12.5', 'E13.5']

In [None]:
def basic_PCA_preprocessing(adata: "anndata.AnnData", n_comps: int = 30) -> "np.ndarray":
    """Recompute a PCA embedding of ``adata`` starting from ``adata.raw.X``.

    ``adata`` is modified in place:
    ``X`` is replaced by a copy of ``adata.raw.X``,
    genes detected in fewer than 3 cells are removed,
    counts are normalised to 1e4 per cell and log1p-transformed,
    highly variable genes are flagged,
    and the PCA (computed on the highly variable genes) is stored in ``adata.obsm['X_pca']``.

    Parameters
    ----------
    adata
        AnnData object whose ``.raw`` is populated.
    n_comps
        Number of principal components to compute (default 30).

    Returns
    -------
    The array stored in ``adata.obsm['X_pca']``.

    Raises
    ------
    ValueError
        If ``adata.raw`` is ``None`` (for example after ``del adata.raw``).
    """
    if adata.raw is None:
        raise ValueError(
            "basic_PCA_preprocessing needs adata.raw, but it is None (was it deleted?)"
        )
    adata.X = adata.raw.X.copy()
    sc.pp.filter_genes(adata, min_cells=3)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
    sc.tl.pca(adata, svd_solver='arpack', use_highly_variable=True, n_comps=n_comps)
    return adata.obsm['X_pca']

In [None]:
# Cell count per time point in `ts` (same order). Only the ratio
# Cell_number_dict[ts1] / Cell_number_dict[ts0] is used, to rescale growth weights
# in the apoptosis check.
# NOTE(review): presumably estimated total cells per embryo — source not recorded here.
cells = [32, 80, 100, 120, 400, 660, 1720, 4500, 8200, 15000, 30000, 60000, 73000,
         90000, 90000, 200000, 1100000, 2600000, 6000000, 13000000]
if len(cells) != len(ts):
    raise ValueError(f"need one cell count per time point: {len(cells)} counts vs {len(ts)} time points")
Cell_number_dict = dict(zip(ts, cells))

In [None]:
# Fit one moscot TemporalProblem per consecutive pair of time points (ts[i] -> ts[i+1]),
# print a rough apoptosis estimate, and pickle the cell-state transition matrix.

# Entropic regularisation passed to tp.solve.
eps = 0.005
# Penalty converted to tau_b = lam2 / (lam2 + eps).
lam2 = 100
tau2 = lam2 / (lam2 + eps)
# Value passed as tau_a for each pair; entry i belongs to (ts[i], ts[i+1]).
tau1s = [0.975, 0.985, 0.97, 0.97, 0.98, 0.9955, 0.94, 0.96, 0.972, 0.95,
         0.975, 0.99, 0.985, 0.98, 0.985, 0.895, 0.935, 0.95, 0.94]
if len(tau1s) != len(ts) - 1:
    raise ValueError(f"expected {len(ts) - 1} tau values (one per time-point pair), got {len(tau1s)}")
# tau -> lam = tau*eps/(1-tau) -> tau = lam/(lam+eps), kept so tau_a is computed exactly as before.
lam1s = [t * eps / (1 - t) for t in tau1s]

# obsm key used as the joint representation for the transport cost.
# NOTE(review): this cell previously passed 'X_pca', yet it loads the scVI representation
# and writes to .../moscot_maps/scVI/, so 'X_scVI' is assumed to be the intent — confirm.
JOINT_ATTR = 'X_scVI'


def estimate_apoptosis_rate(growth, n_cells_0, n_cells_1):
    """Return the fraction of source cells expected to die, given per-cell growth weights.

    ``growth`` is rescaled so that its mean equals ``n_cells_1 / n_cells_0``.
    Every cell whose rescaled weight is below 1 contributes its shortfall ``1 - weight``.
    The summed shortfall is divided by the number of cells.
    """
    growth = np.asarray(growth)
    growth = growth / growth.mean() * n_cells_1 / n_cells_0
    shortfall = 1 - growth[growth < 1]
    return float(np.sum(shortfall) / len(growth))


for i, (ts0, ts1) in enumerate(zip(ts[:-1], ts[1:])):
    print(tau1s[i])
    lam1 = lam1s[i]
    tau1 = lam1 / (lam1 + eps)
    print(f'{ts0}_{ts1}')

    adata = sc.read(f"{Path}/Comb_anndatas/adata_{ts0}_{ts1}.h5ad")
    del adata.raw
    adata.obs['day'] = adata.obs['day'].astype('category')

    # Attach the scVI representation, aligned to adata's cell order.
    # Cells without an embedding are an error rather than silently leaving X_scVI unset.
    scVI_repr = pd.read_pickle(f'{Path}/scVI_Representations/{ts0}_{ts1}_scVI.pkl')
    missing = adata.obs_names.difference(scVI_repr.index)
    if len(missing) > 0:
        raise ValueError(f'{ts0}_{ts1}: {len(missing)} cells of adata have no scVI representation')
    adata.obsm['X_scVI'] = scVI_repr.loc[adata.obs_names].to_numpy()

    tp = TemporalProblem(adata)
    # Proliferation/apoptosis gene scoring is skipped for E8.5a -> E8.5b (reason not recorded here).
    if (ts0, ts1) != ('E8.5a', 'E8.5b'):
        tp.score_genes_for_marginals(gene_set_proliferation='mouse', gene_set_apoptosis='mouse')
    tp = tp.prepare('day', joint_attr=JOINT_ATTR)
    tp.solve(batch_size=3 * 10**3, epsilon=eps, tau_a=tau1, tau_b=tau2)

    # The (day0, day1) key of the solved subproblem must be known before indexing tp with it.
    day0, day1 = next(iter(tp.solutions))

    # Sanity check: implied apoptosis rate given the expected growth in cell number.
    gr = tp[(day0, day1)].solution.a
    apoptosis_rate = estimate_apoptosis_rate(gr, Cell_number_dict[ts0], Cell_number_dict[ts1])
    print(f'{ts0}:{apoptosis_rate}')

    # Cell states present at each day, sorted so the transition matrix layout is reproducible.
    day_col = adata.obs['day']
    cell_states0 = {'cell_state': sorted(set(adata.obs.loc[day_col == day0, 'cell_state']), key=str)}
    cell_states1 = {'cell_state': sorted(set(adata.obs.loc[day_col == day1, 'cell_state']), key=str)}
    CT = tp.cell_transition(day0, day1, cell_states0, cell_states1)
    CT.to_pickle(f'{Path}/moscot_maps/scVI/CTs/{ts0}_{ts1}_cell_type_transitions.pkl')