In [1]:
import numpy as np
import anndata
import pandas as pd
import scanpy as sc
import scipy

from moscot.problems.time._lineage import TemporalProblem
from moscot.backends.ott._solver import SinkhornSolver

In [2]:
# `reload` is used below to re-import the local Utils module within a running kernel
from importlib import reload

import sys
# Put the project's notebook directory on the import path so that `import Utils` resolves.
# NOTE(review): this is an absolute, user-specific path - adjust it when running elsewhere
sys.path.append('/home/icb/manuel.gander/moscotTime_Reproducibility/Notebooks/Python_notebooks')
import Utils

# Re-execute the module so that the freshly loaded version is bound to `Utils`
Utils=reload(Utils)

In [3]:
# Root data directory; the cells below read from `{Path}/anndatas/` and write to `{Path}/moscot_maps/`
Path="/home/icb/manuel.gander/moscotTime_Reproducibility/Data"

# Labels of the time points; consecutive entries ts[i], ts[i+1] form the pairs that are mapped below
ts=[
    'E3.5', 'E4.5', 'E5.25', 'E5.5',
    'E6.25', 'E6.5', 'E6.75',
    'E7.0', 'E7.25', 'E7.5', 'E7.75',
    'E8.0', 'E8.25', 'E8.5a', 'E8.5b',
    'E9.5', 'E10.5', 'E11.5', 'E12.5', 'E13.5',
]

In [4]:
def growth_rates_to_apoptosis_ratio(growth_rates, ts0, ts1):
    """Convert per-cell growth rates into an average per-cell loss between two time points.

    The growth rates are rescaled by the ratio of the embryo cell numbers at
    ``ts1`` and ``ts0``. Every cell whose rescaled rate is below 1 contributes
    its shortfall ``1 - rate``; the summed shortfall is divided by the total
    number of cells.

    Parameters
    ----------
    growth_rates
        Array of per-cell growth rates (must support boolean masking and ``.sum()``).
    ts0, ts1
        Labels of the earlier and later time point; both must be entries of
        the module-level list ``ts``.

    Returns
    -------
    The summed shortfall divided by ``len(growth_rates)``.
    """
    # Cell numbers per embryo, taken from http://tome.gs.washington.edu/. They come from the
    # experiments, except for E8.5b, which the authors estimated themselves (it was their experiment).
    embryo_cell_counts=[32, 80, 100, 120, 400, 660, 1720, 4500, 8200, 15000, 30000, 60000, 73000, 90000, 90000, 200000, 1100000, 2600000, 6000000, 13000000]
    count_at={ts[idx]: n_cells for idx, n_cells in enumerate(embryo_cell_counts)}

    # Rescale so that a value of 1 means "this cell is exactly replaced by one cell"
    scaled_rates=growth_rates*count_at[ts1]/count_at[ts0]

    # Only cells with a rescaled rate below 1 count towards the loss
    shrinking=scaled_rates[scaled_rates<1]
    total_loss=(1-shrinking).sum()

    return total_loss/len(growth_rates)



def lambda_to_growth_rates(adata, epsilon, lam1, score_genes=True):
    """Solve the unbalanced temporal OT problem for one lambda and return the growth rates.

    Parameters
    ----------
    adata
        AnnData with a ``'day'`` column in ``.obs`` (two time points) and the
        joint embedding in ``.obsm['X_pcaS']``.
    epsilon
        Entropic regularisation passed to the solver. It also enters the
        conversion of the lambdas into ``tau_a`` / ``tau_b``.
    lam1
        Unbalancedness parameter of the source marginal. The target marginal
        uses ``lam2 = 100 * lam1``.
    score_genes
        If True, proliferation/apoptosis gene scores are computed for the
        marginals (skipped when the two time points are exactly ``(0, 1)``).

    Returns
    -------
    (growth_rates, result)
        ``growth_rates`` is the source marginal of the solution, normalised to
        mean 1; ``result`` is the object returned by ``tp.solve``.
    """
    time_tuple=tuple(sorted(set(adata.obs['day'])))

    tp=TemporalProblem(adata)

    # This corresponds to the map from E8.5a to E8.5b, where no actual time passes
    if time_tuple!=(0,1) and score_genes:
        tp.score_genes_for_marginals(gene_set_proliferation='mouse',  gene_set_apoptosis='mouse')
    tp = tp.prepare('day', joint_attr='X_pcaS')

    # Smaller batches for the later time points
    # NOTE(review): presumably because these data sets are much larger - confirm
    if time_tuple[0]>=8.5:
        batch_size=10**5
    else:
        batch_size=10**6

    lam2=lam1*100

    # Use the epsilon handed in by the caller. It used to be overridden by a hard-coded 0.05,
    # so maps saved under a different epsilon in their file name were still solved with 0.05.
    tau1=lam1/(lam1+epsilon)
    tau2=lam2/(lam2+epsilon)
    result=tp.solve(batch_size=batch_size, epsilon=epsilon, tau_a=tau1, tau_b=tau2, scale_cost="mean", max_iterations=10**6)

    # ToDo: Tell Dominik to fix posterior_growth_rates, .solution.a corresponds to transport_matrix.sum(1)
    growth_rates=np.array(result[time_tuple].solution.a)
    growth_rates=growth_rates/np.mean(growth_rates)
    return (growth_rates, result)



def given_apoptosis_rate_find_lam(adata, ts0, ts1, epsilon, ap_min, ap_max, posterior_search=False, lam_init=None):
    """Bisect over lambda (in log10 scale) until the apoptosis ratio lies in ``[ap_min, ap_max]``.

    For the big data sets it is faster to subsample to find the right lambda and
    then to compute the full map once in the end. Apoptosis rates typically stay
    approximately the same.

    Parameters
    ----------
    adata
        AnnData passed on to ``lambda_to_growth_rates``.
    ts0, ts1
        Labels of the two time points, passed on to ``growth_rates_to_apoptosis_ratio``.
    epsilon
        Entropic regularisation passed on to ``lambda_to_growth_rates``.
    ap_min, ap_max
        Accepted range of the apoptosis ratio.
    posterior_search
        If True, search only within two orders of magnitude around ``lam_init``
        instead of the default range.
    lam_init
        Starting lambda for the posterior search; must be positive and finite
        when ``posterior_search`` is True.

    Returns
    -------
    (lam, result)
        On success the accepted lambda and the corresponding solve result.
        On failure ``(np.nan, None)``, so that callers can always unpack two values.

    Raises
    ------
    ValueError
        If ``posterior_search`` is True but ``lam_init`` is not a positive finite number.
    """
    # Lambdas are in log-scale, i.e. lam=10**x, with x in linear scale.
    # The interval is ordered [upper, lower]: the update rule below moves towards the first
    # entry when the apoptosis ratio is too high and towards the second when it is too low.
    x_interval=[4, -6]
    xm=-1

    # This is only used for the big data sets, where lambda search is done on a subsampled set. The lambdas found in
    # the subsampled case should already correspond to the right apoptosis rates, but in case they do not, find the
    # right lambda in the not subsampled data set
    if posterior_search:
        if lam_init is None or not np.isfinite(lam_init) or lam_init<=0:
            raise ValueError("posterior_search=True requires a positive, finite lam_init")
        xm=np.log10(lam_init)
        # Same [upper, lower] orientation as the default interval, otherwise the bisection
        # would step away from the solution
        x_interval=[xm+2, xm-2]

    while True:
        lam=10**xm

        growth_rates, result=lambda_to_growth_rates(adata, epsilon, lam)
        perc_apop=growth_rates_to_apoptosis_ratio(growth_rates, ts0, ts1)

        if ap_min<=perc_apop<=ap_max:
            # Successful
            return (lam, result)

        # Not successful: the ratio is undefined, the search left the admissible lambda range,
        # or the interval has collapsed without hitting the target range
        interval_collapsed=abs(x_interval[0]-x_interval[1])<1e-3
        if np.isnan(perc_apop) or xm>3 or xm<-5 or interval_collapsed:
            return (np.nan, None)

        if ap_max<perc_apop:
            x_interval=[x_interval[0], xm]
        elif ap_min>perc_apop:
            x_interval=[xm, x_interval[1]]
        xm=np.mean(x_interval)

In [5]:
def compute_map_given_apoptosis_rate(adata, ts0, ts1, epsilon, ap_min, ap_max):
    """Find a lambda whose apoptosis ratio lies in ``[ap_min, ap_max]`` and return the map.

    For data sets with more than 100000 cells, lambda is first estimated on a 30%
    subsample and then refined on the full data, starting from that estimate.

    Parameters
    ----------
    adata
        AnnData holding the two time points.
    ts0, ts1
        Labels of the earlier and later time point.
    epsilon
        Entropic regularisation passed on to the lambda search.
    ap_min, ap_max
        Accepted range of the apoptosis ratio.

    Returns
    -------
    (success, result)
        ``success`` is True if a suitable lambda was found; ``result`` is the
        corresponding solve result (``None`` if the search failed).
    """
    def _search(data, **search_kwargs):
        found=given_apoptosis_rate_find_lam(data, ts0, ts1, epsilon, ap_min, ap_max, **search_kwargs)
        # A failed search may come back as a bare NaN instead of a (lam, result) pair
        if not isinstance(found, tuple):
            return (np.nan, None)
        return found

    # For the big data set, estimate lambda on a subsampled set to speed up computations
    if len(adata)>100000:
        adatas=sc.pp.subsample(adata, fraction=0.3, copy=True)
    else:
        adatas=adata.copy()

    lam, result=_search(adatas)

    # Refine on the full data only if we subsampled and the first search found a lambda
    if len(adatas)!=len(adata) and not np.isnan(lam):
        lam, result=_search(adata, posterior_search=True, lam_init=lam)

    # NaN marks a failed search, so success is the negation
    success=not np.isnan(lam)
    return (success, result)

In [None]:
# Solver regularisation and the accepted range for the apoptosis ratio
epsilon, ap_min, ap_max = 0.05, 0.04, 0.07

# Index of the earlier time point of the pair to map
i=8
ts0, ts1 = ts[i], ts[i+1]
print(ts0)

adata=sc.read(f"{Path}/anndatas/adata_{ts0}_{ts1}.h5ad")
# Drop .raw before running the map computation
del adata.raw

success, result=compute_map_given_apoptosis_rate(adata, ts0, ts1, epsilon, ap_min, ap_max)

E7.25
Only considering the two last: ['.5', '.h5ad'].
Only considering the two last: ['.5', '.h5ad'].


  for cut in np.unique(obs_cut.loc[gene_list]):




  for cut in np.unique(obs_cut.loc[gene_list]):


[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m13537[0m, [1;36m10994[0m[1m)[0m[1m][0m.                                


  for cut in np.unique(obs_cut.loc[gene_list]):




  for cut in np.unique(obs_cut.loc[gene_list]):


[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m13537[0m, [1;36m10994[0m[1m)[0m[1m][0m.                                


In [16]:
# Step-by-step version of the lambda search, run directly on the notebook globals.
# For the big data set, estimate lambda on a subsampled set to speed up computations
if len(adata)>100000:
    adatas=sc.pp.subsample(adata, fraction=0.3, copy=True)
else:
    adatas=adata.copy()

lam, result=given_apoptosis_rate_find_lam(adatas, ts0, ts1, epsilon, ap_min, ap_max)

# Refine on the full data only if we subsampled and the first search returned a usable lambda.
# The keyword is `posterior_search`; the misspelt `posterior_serch` raised a TypeError.
if len(adatas)!=len(adata) and not np.isnan(lam):
    lam, result=given_apoptosis_rate_find_lam(adata, ts0, ts1, epsilon, ap_min, ap_max, posterior_search=True, lam_init=lam)



ValueError: No valid genes were passed for scoring.

In [17]:
# Display the AnnData used for the lambda search (a subsample or a copy of `adata`)
adatas

AnnData object with n_obs × n_vars = 24531 × 29452
    obs: 'cellID', 'day', 'cell_state', 'cell_type', 'group', 'sample', 'origin'
    var: 'features', 'gene_names'
    obsm: 'X_pcaS', 'X_umap3'

In [18]:
# Display the full AnnData for comparison with `adatas`
adata

AnnData object with n_obs × n_vars = 24531 × 29452
    obs: 'cellID', 'day', 'cell_state', 'cell_type', 'group', 'sample', 'origin'
    var: 'features', 'gene_names'
    obsm: 'X_pcaS', 'X_umap3'

In [None]:
# Solver regularisation and the accepted range for the apoptosis ratio
epsilon, ap_min, ap_max = 0.005, 0.04, 0.07

# Map every consecutive pair of time points from ts[13] to ts[19]
for i in range(13,19):
    ts0, ts1 = ts[i], ts[i+1]
    print(ts0)

    adata=sc.read(f"{Path}/anndatas/adata_{ts0}_{ts1}.h5ad")
    # adata.raw causes problems in gene scoring, and is not needed in any other way
    del adata.raw

    success, result=compute_map_given_apoptosis_rate(adata, ts0, ts1, epsilon, ap_min, ap_max)

    # Only keep the map on disk when the flag returned above is truthy
    if success:
        result.save(f'{Path}/moscot_maps/', f'{ts0}_{epsilon}_{ap_min}_{ap_max}')
    print('-------------------------------------------------------------------')

E8.5a
Only considering the two last: ['.5b', '.h5ad'].
Only considering the two last: ['.5b', '.h5ad'].


  for cut in np.unique(obs_cut.loc[gene_list]):




  for cut in np.unique(obs_cut.loc[gene_list]):


[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m16909[0m, [1;36m154313[0m[1m)[0m[1m][0m.                               


  for cut in np.unique(obs_cut.loc[gene_list]):




  for cut in np.unique(obs_cut.loc[gene_list]):


[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m16909[0m, [1;36m154313[0m[1m)[0m[1m][0m.                               
