In [8]:
# Setting jax to use float64, makes the algorithm more stable
from jax.config import config
config.update("jax_enable_x64", True)

# The lines below are useful to customize RAM-Memory usage
import os
# By default, OTT takes 90% of RAM-memory. Using the line below it uses exactly how much it needs
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# This frees RAM-Memory if no longer needed, but might slow down computations (I didn't observe that)
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR']='platform'


import numpy as np
import anndata
import pandas as pd
import scanpy as sc
import scipy
import time

from moscot.problems.time._lineage import TemporalProblem
from moscot.backends.ott._solver import SinkhornSolver

In [4]:
Path="/home/icb/manuel.gander/moscotTime_Reproducibility/Data"
ts=['E3.5', 'E4.5', 'E5.25', 'E5.5', 'E6.25', 'E6.5', 'E6.75', 'E7.0', 'E7.25', 'E7.5', 'E7.75', 'E8.0', 'E8.25', 'E8.5a', 'E8.5b', 'E9.5', 'E10.5', 'E11.5', 'E12.5', 'E13.5']

# Calculate transport maps

In [9]:
def growth_rates_to_apoptosis_ratio(growth_rates, ts0, ts1):
    
    # I got these cell numbers from http://tome.gs.washington.edu/, and they got them from the experiments or 
    # for E8.5b they estimated it themselves (it was their experiment)
    
    cells=[32, 80, 100, 120, 400, 660, 1720, 4500, 8200, 15000, 30000, 60000, 73000, 90000, 90000, 200000, 1100000, 2600000, 6000000, 13000000]
    Cell_number_dict={}
    for i in range(20):
        Cell_number_dict[ts[i]]=cells[i]
    
    cellular_growth_rates=growth_rates*Cell_number_dict[ts1]/Cell_number_dict[ts0]
    apoptotic_cells=cellular_growth_rates[cellular_growth_rates<1]
    sum_apoptotic_cells=(1-apoptotic_cells).sum()
    perc_apoptotic_cells=sum_apoptotic_cells/len(growth_rates)
    
    return(perc_apoptotic_cells)



def lambda_to_growth_rates(adata, epsilon, lam1, score_genes=True):
    time_tuple=tuple(sorted(set(adata.obs['day'])))
    
    tp=TemporalProblem(adata)
    
    if score_genes:
        tp.score_genes_for_marginals(gene_set_proliferation='mouse',  gene_set_apoptosis='mouse')
    tp = tp.prepare('day', joint_attr=f'X_pcaS')
    
    
    if len(adata)<=10**5:
        batch_size=20000
    else:
        batch_size=20000
    
    lam2=lam1*100
    
    tau1=lam1/(lam1+epsilon)
    tau2=lam2/(lam2+epsilon)
    print(batch_size)
    result=tp.solve(batch_size=batch_size, epsilon=epsilon, tau_a=tau1, tau_b=tau2, scale_cost="mean", max_iterations=10**6)
    
    # ToDo: Tell Dominik to fix posterior_growth_rates, .solutions.a corresponds to transport_matrix.sum(1)
    growth_rates=np.array(result[time_tuple].solution.a)
    growth_rates=growth_rates/np.mean(growth_rates)
    return(growth_rates, result)



def given_apoptosis_rate_find_lam(adata, ts0, ts1, epsilon, ap_min, ap_max, posterior_search=False, lam_init=None):
    # For the big data sets, it is faster to subsample to find the right lambda, and then to compute the full map
    # in the end once. Apoptosis rate typically stay approxiamtely the same

    # Lambdas are in log-scale, i.e. lam=10**x, with x in linear scale
    xm=-1
    x_interval_initial=[4, -6]    
    
    # This is only used for the big data sets, where lambda search is done on a subsampled set. The lambdas found in
    # the subsampled case should already correspond to the right apoptosis rates, but in case they do not, find the
    # right lambda in the not subsampled data set
    if posterior_search:
        xm=np.log10(lam)
        x_interval_initial=[xm+2, xm-2]
        
        
    x_interval=x_interval_initial.copy()
        
    
    time0=time.time()
    
    
    
    while(True):
        lam=10**xm
        
        result=0
        
        growth_rates, result=lambda_to_growth_rates(adata, epsilon, lam)
        perc_apop=growth_rates_to_apoptosis_ratio(growth_rates, ts0, ts1)
        print(lam)
        print(perc_apop)
        
        print(time.time()-time0)
        t0=time.time()
        time.sleep(0.1)

        if ap_min<=perc_apop<=ap_max:
            # First entry is whether it was sucessful
            return(lam, result)

        elif xm>max(x_interval_initial)-1 or xm<min(x_interval_initial)+1:
            # Not successful
            return(np.NaN, np.NaN)

        else:
            if ap_max<perc_apop:
                x_interval=[x_interval[0], xm]
            elif ap_min>perc_apop:
                x_interval=[xm, x_interval[1]]
            xm=np.mean(x_interval)

In [6]:
def compute_map_given_apoptosis_rate(adata, ts0, ts1, epsilon, ap_min, ap_max):
    # For the big data set, estimate lambda on a subsampled set to speed up computations
    if len(adata)>100000:
        adatas=sc.pp.subsample(adata, n_obs=10**5, copy=True)
        print('Subsampling to find lambda')
    else:
        adatas=adata.copy()
        
    lam, result=given_apoptosis_rate_find_lam(adatas, ts0, ts1, epsilon, ap_min, ap_max)
    
    if len(adatas)!=len(adata):
        result=0
        lam, result=given_apoptosis_rate_find_lam(adata, ts0, ts1, epsilon, ap_min, ap_max, posterior_serch=True, lam_init=lam)
    
    return(not np.isnan(lam), result, lam)

# Loop over all time points

In [15]:
epsilon=0.005
ap_min=0.02
ap_max=0.04

for i in range(13,19):
    if i==13:
        continue
    ts0=ts[i]
    ts1=ts[i+1]
    print(ts0)
    adata=sc.read(f"{Path}/anndatas/adata_{ts0}_{ts1}.h5ad")

    del adata.raw
        
    success, result, lam=compute_map_given_apoptosis_rate(adata, ts0, ts1, epsilon, ap_min, ap_max)
    
    if success:
        result.save(f'{Path}/moscot_maps/', f'{ts0}_{epsilon}_{ap_min}_{ap_max}', overwrite=True)
        np.save(f'{Path}/moscot_maps/{ts0}_{epsilon}_{ap_min}_{ap_max}_used_lam.npy', lam)    
    print('-------------------------------------------------------------------')

E8.5b
Only considering the two last: ['.5', '.h5ad'].
Only considering the two last: ['.5', '.h5ad'].
Subsampling to find lambda


  for cut in np.unique(obs_cut.loc[gene_list]):




  for cut in np.unique(obs_cut.loc[gene_list]):


20000
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m58177[0m, [1;36m41823[0m[1m)[0m[1m][0m.                                
0.1
0.08542798571977343


UnboundLocalError: local variable 't0' referenced before assignment

In [7]:
epsilon=0.005
ap_min=0.02
ap_max=0.04

i=14
ts0=ts[i]
ts1=ts[i+1]
print(ts0)
adata=sc.read(f"{Path}/anndatas/adata_{ts0}_{ts1}.h5ad")

del adata.raw

success, result, lam=compute_map_given_apoptosis_rate(adata, ts0, ts1, epsilon, ap_min, ap_max)


E8.5b
Only considering the two last: ['.5', '.h5ad'].
Only considering the two last: ['.5', '.h5ad'].
Subsampling to find lambda


NameError: name 'time' is not defined

In [None]:
success, result, lam=compute_map_given_apoptosis_rate(adata, ts0, ts1, epsilon, ap_min, ap_max)

Subsampling to find lambda


  for cut in np.unique(obs_cut.loc[gene_list]):




  for cut in np.unique(obs_cut.loc[gene_list]):


20000
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m58177[0m, [1;36m41823[0m[1m)[0m[1m][0m.                                
0.1
0.08542798571977343
191.9025821685791


  for cut in np.unique(obs_cut.loc[gene_list]):




  for cut in np.unique(obs_cut.loc[gene_list]):


20000
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m58177[0m, [1;36m41823[0m[1m)[0m[1m][0m.                                
