# Disentangling lineage relationships of delta and epsilon cells in pancreatic development with the {class}`~moscot.problems.time.TemporalProblem`

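In this tutorial, we showcase how to use the {class}`~moscot.problems.time.TemporalProblem` to disentangle the lineage relationships of delta and epsilon cells in pancreatic development. The method builds upon {cite}`klein:23`.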


:::{seealso}
- See {doc}`500_spatiotemporal` on how to incorporate spatial information as an additional prior.
- See {doc}`100_lineage` on how to incorporate lineage information as an additional prior.
:::

In [1]:
import os

import moscot as mt
import moscot.plotting as mpl
from moscot.problems.time import TemporalProblem

import networkx as nx
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import scanpy as sc

import jax
import mplscience

# Global plotting configuration for every figure in this notebook.
# Order matters: scanpy's figure parameters are applied first, then the mplscience style.
sc.set_figure_params(scanpy=True, dpi=80, dpi_save=200)
mplscience.set_style(reset_current=True)
# Draw a single marker per entry in scatter-plot legends.
plt.rcParams["legend.scatterpoints"] = 1

['default', 'despine']


In [2]:
# Load the pancreas multiome dataset restricted to its RNA part (`rna_only=True`).
adata = mt.datasets.pancreas_multiome(
    rna_only=True,
)

In [3]:
# Optional: mt.datasets.pancreas_multiome(rna_only=False) loads all modalities instead of RNA only; not needed here.

In [4]:
# Values of adata.obs["cell_type"] that define the endocrine subset used in this notebook.
# The labels are grouped by prefix; the flattened order is unchanged.
endocrine_celltypes = [
    # Ngn3 populations
    *("Ngn3 low", "Ngn3 high", "Ngn3 high cycling"),
    # Fev+ populations
    *("Fev+", "Fev+ Alpha", "Fev+ Beta", "Fev+ Delta"),
    # Remaining endocrine labels
    "Eps. progenitors",
    *("Alpha", "Beta", "Delta", "Epsilon"),
]

In [5]:
# Restrict the dataset to cells annotated with one of the endocrine cell types.
is_endocrine = adata.obs["cell_type"].isin(endocrine_celltypes)
adata = adata[is_endocrine].copy()

In [6]:
# Developmental stage labels (values of adata.obs["stage"]) and their numeric time points.
STAGE_TO_TIME = {"E14.5": 14.5, "E15.5": 15.5, "E16.5": 16.5}


def adapt_time(x):
    """Convert a cell's developmental stage label into a numeric time point.

    Parameters
    ----------
    x
        A row of ``adata.obs`` (or any mapping) with a ``"stage"`` entry.

    Returns
    -------
    float
        The time point for the stage, e.g. ``14.5`` for ``"E14.5"``.

    Raises
    ------
    ValueError
        If the stage is not one of the keys of ``STAGE_TO_TIME``. The message names
        the offending value so a bad row can be identified.
    """
    stage = x["stage"]
    try:
        return STAGE_TO_TIME[stage]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which the original equality chain also rejected.
        raise ValueError(
            f"Unknown stage {stage!r}; expected one of {list(STAGE_TO_TIME)}."
        ) from None


# Convert every cell's stage label into a numeric time point and store it as a categorical column.
stage_times = adata.obs.apply(adapt_time, axis="columns")
adata.obs["time"] = stage_times.astype("category")

In [7]:
# For each pair of consecutive time points, compute a neighbour graph on the MultiVI
# representation of the cells from both time points. Check that the graph is connected,
# and store its connectivities as a dense (cells x cells) DataFrame keyed by (earlier, later).
dfs = {}
batch_column = "time"
unique_batches = [14.5, 15.5, 16.5]
for batch1, batch2 in zip(unique_batches[:-1], unique_batches[1:]):
    # Explicit copy: sc.pp.neighbors writes to the object, which would otherwise
    # implicitly turn a view into a copy (and emit an ImplicitModificationWarning).
    adata_subset = adata[adata.obs[batch_column].isin([batch1, batch2])].copy()
    sc.pp.neighbors(adata_subset, use_rep="X_MultiVI", n_neighbors=30)

    # Densify once with `.toarray()`; the `.A` shorthand is deprecated/removed in newer SciPy.
    connectivities = adata_subset.obsp["connectivities"].toarray()
    G = nx.from_numpy_array(connectivities)
    assert nx.is_connected(G), (
        f"Neighbour graph for time points {batch1} and {batch2} is not connected."
    )

    dfs[(batch1, batch2)] = pd.DataFrame(
        index=adata_subset.obs_names,
        columns=adata_subset.obs_names,
        data=connectivities.astype("float"),
    )

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
# Build the temporal problem on the endocrine subset, keyed by the "time" column,
# with the MultiVI embedding passed as `joint_attr`.
tp0 = TemporalProblem(adata).prepare("time", joint_attr="X_MultiVI")

In [9]:
# Attach the precomputed connectivity graph to the subproblem for each pair of consecutive
# time points. Iterating over `dfs` keeps this in sync with the pairs built earlier.
# NOTE(review): the meaning of `t` is defined by `set_graph_xy`; 100.0 is kept from the original — TODO document.
GRAPH_T = 100.0
for (t_source, t_target), graph_df in dfs.items():
    tp0[t_source, t_target].set_graph_xy(graph_df.astype("float"), t=GRAPH_T)

In [None]:
# NOTE(review): max_iterations=5 is very low — presumably a quick demo setting; TODO confirm it is sufficient.
solve_kwargs = {"max_iterations": 5, "device": "CPU"}
tp0 = tp0.solve(**solve_kwargs)

[34mINFO    [0m Solving `[1;36m2[0m` problems                                                                                      
[34mINFO    [0m Solving problem BirthDeathProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m5185[0m, [1;36m1699[0m[1m)[0m[1m][0m.                                  
