In [None]:
#from jax.config import config
#config.update("jax_enable_x64", True)
import numpy as np
import anndata
import pandas as pd
import scanpy as sc
import scipy 
import seaborn as sns
import matplotlib.pyplot as plt
import os
import time
from moscot.problems.time._lineage import TemporalProblem
import warnings
import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
import sys,os
sys.path.append('/home/icb/manuel.gander/mouse_atlas/notebook')
import c2
warnings.simplefilter(action='ignore', category=FutureWarning) 
sc.settings.verbosity = 0
import wandb

In [None]:
# Root data directory. Set the MOUSE_ATLAS_DATA environment variable to point elsewhere;
# otherwise the original cluster location is used.
Path = os.environ.get("MOUSE_ATLAS_DATA", "/home/icb/manuel.gander/mouse_atlas/data")


In [None]:
# Time points in order; pair i is ts[i] -> ts[i+1]. `cells` holds one count per time point
# (presumably estimated cells per embryo); solve_LR_moscot uses the ratio between consecutive
# entries to rescale growth rates.
ts = ['E3.5', 'E4.5', 'E5.25', 'E5.5', 'E6.25', 'E6.5', 'E6.75', 'E7.0', 'E7.25', 'E7.5', 'E7.75', 'E8.0', 'E8.25', 'E8.5a', 'E8.5b', 'E9.5', 'E10.5', 'E11.5', 'E12.5', 'E13.5']
cells = [32, 80, 100, 120, 400, 660, 1720, 4500, 8200, 15000, 30000, 60000, 73000, 90000, 90000, 200000, 1100000, 2600000, 6000000, 13000000]
# Guard against the two lists drifting apart when entries are added or removed.
assert len(ts) == len(cells), "ts and cells must be aligned one-to-one"
Cell_number_dict = dict(zip(ts, cells))

In [None]:
def solve_LR_moscot(epsilon, rank, gamma, iterations, i=18, tau1=0.07, batch_size=10**5, tau2=0.99995):
    """Fit a moscot TemporalProblem for the time-point pair ts[i] -> ts[i+1] and evaluate it.

    Loads the combined AnnData for the pair and solves the OT problem, either full-rank
    Sinkhorn (rank == -1) or low-rank. It derives an apoptosis rate from the solution's
    ``a`` marginal and scores the cell-state transition matrix with ``c2``'s
    curated-transition and germ-layer benchmarks.

    Parameters
    ----------
    epsilon : float
        Entropic regularization passed to ``tp.solve``.
    rank : int
        ``-1`` selects the full-rank solver; any other value is passed as ``rank`` to
        ``tp.solve`` together with the low-rank options below.
    gamma : float
        Passed as ``gamma`` to the low-rank solve; unused when ``rank == -1``.
    iterations : int or float
        ``max_iterations`` for the low-rank solve (inner iterations = iterations // 20, at
        least 1). When ``rank == -1`` it is ignored and replaced by the iteration count
        reported by the solver.
    i : int
        Index into ``ts`` of the first time point of the pair.
    tau1, tau2 : float
        Passed as ``tau_a`` / ``tau_b`` (unbalancedness) to ``tp.solve``.
    batch_size : int
        Passed as ``batch_size`` to ``tp.solve``.

    Returns
    -------
    tuple
        ``(acc_curated, acc_germ, apoptosis_rate, solve_hours, eval_hours, iterations)``.
    """
    ts0, ts1 = ts[i], ts[i + 1]
    print(f'{ts0}_{ts1}')

    adata = sc.read(f"{Path}/Comb_anndatas/adata_{ts0}_{ts1}.h5ad")
    del adata.raw
    adata.obs['day'] = adata.obs['day'].astype('category')

    # Each combined file is expected to contain exactly two days.
    day0, day1 = sorted(set(adata.obs['day']))

    tp = TemporalProblem(adata)
    # NOTE(review): pair 13 (E8.5a -> E8.5b) skips proliferation/apoptosis scoring; the
    # reason is not visible here.
    if i != 13:
        tp.score_genes_for_marginals(gene_set_proliferation='mouse', gene_set_apoptosis='mouse')
    tp = tp.prepare('day', joint_attr='X_pcaS')

    t0 = time.time()
    if rank == -1:
        tp.solve(batch_size=batch_size, epsilon=epsilon, tau_a=tau1, tau_b=tau2, rank=rank)
        # Report how many Sinkhorn iterations were actually run.
        iterations = float(tp[(day0, day1)].solution._output.n_iters)
    else:
        inners = max(int(iterations / 20), 1)
        tp.solve(batch_size=batch_size, epsilon=epsilon, tau_a=tau1, tau_b=tau2,
                 max_iterations=iterations, rank=rank, threshold=0,
                 inner_iterations=inners, gamma=gamma)
    t1 = time.time() - t0
    print('Sinkhorn done')

    time0 = time.time()
    # Normalize the marginal to mean 1, then rescale by the cell-number ratio between time points.
    gr = tp[(day0, day1)].solution.a
    gr = gr / gr.mean() * Cell_number_dict[ts1] / Cell_number_dict[ts0]
    # Cells with growth < 1 contribute their deficit as "dying" mass.
    cell_dying = np.sum((1 - gr[gr < 1]))
    apoptosis_rate = float(cell_dying / len(gr))

    cell_states0 = {'cell_state': list(set(adata[adata.obs['day'] == day0].obs['cell_state']))}
    cell_states1 = {'cell_state': list(set(adata[adata.obs['day'] == day1].obs['cell_state']))}
    CT = tp.cell_transition(day0, day1, cell_states0, cell_states1)

    ev = c2.evaluate_using_curated_transitions(CT)
    acc0 = list(ev['Accuracy'])[0]

    ev = c2.evaluate_using_germ_layers(CT)
    acc1 = list(ev['Accuracy'])[0]
    time1 = time.time() - time0

    # Timings are returned in hours.
    return (acc0, acc1, apoptosis_rate, t1 / 3600, time1 / 3600, iterations)

In [None]:
def solve_and_save_to_wandb(epsilon, rank, gamma, iterations, i, tau1, batch_size):
    """Run one benchmark configuration through solve_LR_moscot and log the metrics to W&B.

    Opens a run in the "Low_Rank_benchmark" project and always closes it. If solving or
    logging raises, the run is finished with exit code 1 and the exception propagates, so
    the run is not left open.
    """
    wandb.init(project="Low_Rank_benchmark")
    exit_code = 1  # switched to 0 only after the metrics were logged successfully
    try:
        acc0, acc1, ap_rate, t1, t2, iterations = solve_LR_moscot(
            epsilon, rank, gamma, iterations, i, tau1, batch_size)
        wandb.log({"Accuracy_Curated": acc0, 'Accuracy_Germ': acc1, 'Iteration': iterations,
                   'Gamma': gamma, 'Rank': rank, 'Time_Sinkhorn': t1, 'Time_Evaluation': t2,
                   'Epsilon': epsilon, "Apoptosis_rate": ap_rate, 'Gamma*Epsilon': gamma * epsilon,
                   'tau1': tau1, 'i': i, 'batch_size': batch_size})
        exit_code = 0
    finally:
        wandb.finish(exit_code=exit_code)

In [None]:
# tau1 (passed as tau_a) per rank, one entry per time-point pair i = 0..18, chosen so that the
# resulting apoptosis rate lies in an appropriate range. Key -1 is the full-rank solver.
D = {}
D[-1] = [0.99, 0.98, 0.99, 0.92, 0.985, 0.9, 0.3, 0.65, 0.9, 0.82, 0.93, 0.996, 0.96, 0.99, 0.98, 0.68, 0.8, 0.9, 0.9]
D[10] = [0.4, 0.23, 0.35, 0.083, 0.215, 0.15, 0.115, 0.21, 0.12, 0.15, 0.3, 0.65, 0.4, 0.6, 0.45, 0.25, 0.4, 0.4, 0.35]
D[100] = [0.33, 0.2, 0.34, 0.062, 0.215, 0.12, 0.07, 0.065, 0.07, 0.077, 0.25, 0.65, 0.16, 0.5, 0.22, 0.08, 0.15, 0.15, 0.12]
D[1000] = [0.33, 0.218, 0.34, 0.058, 0.155, 0.1, 0.035, 0.045, 0.065, 0.04, 0.2, 0.62, 0.13, 0.5, 0.2, 0.045, 0.065, 0.085, 0.09]
# Rank 2000 reuses the rank-1000 values except for a few re-tuned pairs. Copy the list:
# a plain `D[2000] = D[1000]` aliases it, so the overrides below would also rewrite D[1000].
D[2000] = list(D[1000])
D[2000][5] = 0.08
D[2000][6] = 0.02
D[2000][8] = 0.06
D[2000][15] = 0.035
D[2000][16] = 0.06

In [None]:
# Sweep every rank (-1 = full-rank Sinkhorn) over all consecutive time-point pairs and
# log each run to W&B.
for rank in [-1, 10, 100, 1000, 2000]:
    # Solver settings depend only on the rank, so set them once per rank.
    if rank == -1:
        eps = 0.005
        gamma = np.nan  # not used by the full-rank solve; logged as NaN
        iters = np.nan  # replaced by the solver's reported iteration count
    else:
        eps = 0.0001
        gamma = 500
        iters = 1000

    for i in range(len(ts) - 1):
        tau1 = D[rank][i]

        # Batch size chosen ~as large as possible without exhausting GPU memory (40 GB here).
        if i < 13:
            batch_size = 10**6
        elif i == 13:
            batch_size = 10**5
        else:
            batch_size = 3 * 10**3

        solve_and_save_to_wandb(eps, rank, gamma, iters, i, tau1, batch_size)