In [1]:
import cloudpickle
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData, concat

In [2]:
import jax.numpy as jnp

In [3]:
import cellrank as cr
import scvelo as scv

In [4]:
from moscot.problems.time import TemporalNeuralProblem

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [5]:
# Load the full gastrulation atlas (all stages, all cell types).
org_adata = sc.read("/lustre/groups/ml01/workspace/monge_velo/data/benchmarks/gastrulation/adata_gastrulation.h5ad")

# Keep only the erythroid-branch cell types at the stages NOT listed in `leave_out`.
leave_out = ['E6.5', 'E6.75', 'E7.0', 'E7.25', 'E7.5', 'E7.75', 'E8.0']
leave_in_cell = ['Blood progenitors 2', 'Erythroid1', 'Erythroid2', 'Erythroid3']
# Apply both filters in a single pass so the (large) stage-only subset of the
# atlas is never copied as an intermediate object.
cell_mask = (
    ~org_adata.obs["stage"].isin(leave_out)
    & org_adata.obs["celltype"].isin(leave_in_cell)
)
adata = org_adata[cell_mask.to_numpy()].copy()

# Convert stage labels such as "E8.25" into numeric time points; the "stage"
# column is later used as the time key of the temporal problem.
map_dict = {"E6.5": 6.5, "E6.75": 6.75, "E7.0": 7.0, "E7.25": 7.25, "E7.5": 7.5, "E7.75": 7.75, "E8.0": 8.0,
            "E8.25": 8.25, "E8.5": 8.5}
adata.obs["stage"] = adata.obs["stage"].map(map_dict)
# A stage label missing from `map_dict` would silently become NaN; fail loudly instead.
assert adata.obs["stage"].notna().all(), "found stage labels without a numeric mapping in map_dict"

# Keep only the annotations needed downstream and drop every gene-metadata column.
adata.obs = adata.obs.loc[:, ['stage', 'celltype']]
adata.var = adata.var[[]]

# NOTE(review): assumes adata.X holds raw (not yet log-transformed) counts — confirm.
sc.pp.log1p(adata)

In [6]:
# Inspect the preprocessed AnnData: dimensions and the retained annotations/layers.
adata

AnnData object with n_obs × n_vars = 4645 × 53801
    obs: 'stage', 'celltype'
    uns: 'log1p'
    obsm: 'X_pca', 'X_umap'
    layers: 'spliced', 'unspliced'

In [7]:
# Build the temporal neural OT problem on the filtered erythroid subset, then
# score cells with moscot's built-in "mouse" gene sets for the marginals
# (presumably proliferation / apoptosis sets — confirm in the moscot docs).
tnp = TemporalNeuralProblem(adata).score_genes_for_marginals("mouse", "mouse")



In [None]:
# Use "stage" as the time key and the PCA embedding as the shared space, then
# train the neural map with tau_a = tau_b = 0.95.
# NOTE: the logged run of this cell took ~2 s/iteration for 25000 iterations on CPU.
tnp = tnp.prepare("stage", joint_attr="X_pca").solve(tau_a=0.95, tau_b=0.95)

INFO     Solving problem BirthDeathNeuralProblem[stage='prepared', shape=(1661, 2984)].


 16%|███████████▎                                                             | 3888/25000 [2:16:54<11:53:41,  2.03s/it]

In [None]:
# Serialise the trained problem with cloudpickle so training need not be repeated.
tnp_path = "/lustre/groups/ml01/workspace/monge_velo/data/benchmarks/gastrulation/test_save_branch_tnp"
with open(tnp_path, "wb") as handle:
    cloudpickle.dump(tnp, handle)

tnp

In [None]:
# Restore the trained problem written by the save cell.
# cloudpickle.load can execute arbitrary code: only load files produced by this workflow.
tnp_path = "/lustre/groups/ml01/workspace/monge_velo/data/benchmarks/gastrulation/test_save_branch_tnp"
with open(tnp_path, "rb") as handle:
    tnp = cloudpickle.load(handle)

In [None]:
# Validation Sinkhorn loss of the forward and inverse maps for the 8.25 -> 8.5 pair.
# NOTE: `_training_logs` is a private moscot attribute and may change between versions.
valid_logs = tnp[8.25, 8.5].solution._training_logs["valid_logs"]

fig, ax = plt.subplots()
for direction in ("forward", "inverse"):
    ax.plot(valid_logs[f"sinkhorn_loss_{direction}"], label=direction)
ax.set_xlabel("validation log entry")
ax.set_ylabel("sinkhorn loss")
ax.set_title("Validation Sinkhorn loss (stage 8.25 -> 8.5)")
ax.legend()
plt.show()

In [None]:
# Validation W2-distance for the 8.25 -> 8.5 pair, read from moscot's private training logs.
valid_logs = tnp[8.25, 8.5].solution._training_logs["valid_logs"]

fig, ax = plt.subplots()
ax.plot(valid_logs["valid_w_dist"])
ax.set_xlabel("validation log entry")
ax.set_ylabel("W2-distance")
ax.set_title("Validation W2-distance (stage 8.25 -> 8.5)")
plt.show()

In [None]:
# PCA coordinates of every cell as a JAX array. jnp.array copies its input by
# default, so no additional NumPy-side copy is needed.
# NOTE(review): this includes cells from both stages, not only the 8.25 source
# cells of the trained map — confirm this is intended.
source = jnp.array(adata.obsm["X_pca"])

In [None]:
# Velocity in PCA space: the displacement of each cell under the learned 8.25 -> 8.5 map.
stage_solution = tnp[8.25, 8.5].solution
velocity = stage_solution.push(source) - source

In [None]:
# Empty AnnData (no X) with one row per cell and one column per dimension of
# `velocity`, carrying over the cell annotations from `adata`.
n_cells, n_components = velocity.shape
adata_result = AnnData(shape=(n_cells, n_components))
adata_result.obs = adata.obs.copy()

In [None]:
# Store the velocity and the PCA coordinates as NumPy layers (JAX arrays are converted).
for layer_key, values in (("GEX_velocity", velocity), ("X_pca", adata.obsm["X_pca"])):
    adata_result.layers[layer_key] = np.asarray(values)

In [None]:
# Inspect the assembled result object (annotations and layers) before it is written to disk.
adata_result

In [None]:
# Write the per-cell velocity result to disk as an .h5ad file.
result_path = "/lustre/groups/ml01/workspace/monge_velo/data/benchmarks/gastrulation/adata_gex_velocities_branch.h5ad"
adata_result.write(result_path)