# Mapping across space and time

This tutorial shows the standard pipeline for mapping cells across time points using spatial information,
based on the moscot problem class :class:`moscot.problems.spatio_temporal.SpatioTemporalProblem`.

We illustrate this on a subsample of the spatiotemporal transcriptomic atlas of mouse organogenesis,
generated with DNA nanoball-patterned arrays by :cite:`chen:22`.

In [None]:
import os
import sys

# Make a local (development) checkout of moscot importable
module_path = os.path.dirname(os.path.abspath(".."))
sys.path.append(module_path)
sys.path.append(os.path.join(module_path, "moscot"))

import jax

# Use double precision in JAX, on which the OT solvers run
jax.config.update("jax_enable_x64", True)

In [None]:
from moscot.problems.spatio_temporal import SpatioTemporalProblem
from moscot.datasets import mosta

import scanpy as sc
import matplotlib.pyplot as plt

## Load the data

The AnnData object contains three time points, represented by the embryo sections E9.5 E2S1, E10.5 E2S1 and E11.5 E1S2.
:attr:`anndata.AnnData.X` holds preprocessed counts, obtained by applying :func:`scanpy.pp.normalize_total` followed by :func:`scanpy.pp.log1p` to the raw counts.
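For reference, the arithmetic of these two preprocessing steps can be sketched with NumPy alone. This is a simplified illustration, not scanpy's implementation: it assumes a dense cell-by-gene count matrix and, when no `target_sum` is given, scales every cell to the median of the non-zero per-cell totals, which approximates scanpy's default. The helper `normalize_log1p` is hypothetical.

```python
import numpy as np


def normalize_log1p(counts, target_sum=None):
    """Hypothetical sketch of normalize_total followed by log1p on a dense matrix."""
    counts = np.asarray(counts, dtype=float)
    totals = counts.sum(axis=1, keepdims=True)  # total counts per cell, shape (n_cells, 1)
    if target_sum is None:
        # Approximation of scanpy's default: median of the non-zero per-cell totals
        target_sum = np.median(totals[totals > 0])
    totals[totals == 0] = 1.0  # keep all-zero cells at zero instead of dividing by 0
    scaled = counts / totals * target_sum  # each non-empty cell now sums to target_sum
    return np.log1p(scaled)  # natural log of (1 + x)


counts = np.array([[1, 5], [2, 2], [0, 3]])
logged = normalize_log1p(counts)  # every row of expm1(logged) sums to the median total, 4
```

Before the log transform every non-empty cell sums to the same value, which removes differences in sequencing depth between cells.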

In [None]:
adata = mosta()

In [None]:
adata

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(8, 16))
# One panel per time point, colored by tissue annotation
for i, tp in enumerate(adata.obs["timepoint"].cat.categories):
    sc.pl.spatial(
        adata[adata.obs["timepoint"] == tp],
        color="annotation",
        title=str(tp),
        spot_size=1.0,
        frameon=False,
        wspace=0.1,
        alpha=1.0,
        show=False,
        ax=axs[i],
    )
    # Keep the legend only on the last panel
    if i < 2:
        axs[i].legend().remove()

plt.tight_layout()
plt.show()

## Set up the `SpatioTemporalProblem`

To set up the problem, we pass the following parameters to :meth:`~moscot.problems.spatio_temporal.SpatioTemporalProblem.prepare`:

- `time_key` - key in :attr:`anndata.AnnData.obs` where the time points are stored.
- `spatial_key` - key in :attr:`anndata.AnnData.obsm` where the spatial coordinates are stored.
- `joint_attr` - key of the joint space used for the linear term of the mapping:

    * If `None`, the joint space is computed from :attr:`anndata.AnnData.X` using `callback`; if no callback is specified, PCA is computed.
    * If `str`, it must refer to a key in :attr:`anndata.AnnData.obsm`.
    * If `dict`, it stores `attr` (an attribute of :class:`anndata.AnnData`, e.g. `"obsm"`) and `key` (a key within ``adata.{attr}``).

- `callback` - custom callback applied to each distribution as a preprocessing step.

For the purpose of illustration we specify the callback explicitly, although this is not necessary.
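The three accepted forms of `joint_attr` can be illustrated with a hypothetical helper that resolves each of them to a cell-by-feature matrix. This is a sketch of the documented behavior, not moscot's code: the plain SVD-based PCA and `n_comps=10` are assumptions, and `toy` is a stand-in object exposing the same `X`, `obsm` and `layers` attributes as an :class:`anndata.AnnData`.

```python
from types import SimpleNamespace

import numpy as np


def resolve_joint_attr(adata, joint_attr=None, n_comps=10):
    """Hypothetical helper: return the joint-space matrix described by `joint_attr`."""
    if joint_attr is None:
        # No key given: PCA of adata.X (center the columns, then SVD)
        X = np.asarray(adata.X, dtype=float)
        X = X - X.mean(axis=0)
        U, S, _ = np.linalg.svd(X, full_matrices=False)
        return U[:, :n_comps] * S[:n_comps]
    if isinstance(joint_attr, str):
        # A string is looked up in adata.obsm
        return np.asarray(adata.obsm[joint_attr])
    if isinstance(joint_attr, dict):
        # A dict names an attribute and a key inside it, i.e. adata.{attr}[key]
        container = getattr(adata, joint_attr["attr"])
        return np.asarray(container[joint_attr["key"]])
    raise TypeError(f"unsupported joint_attr type: {type(joint_attr).__name__}")


# Toy stand-in with the same attributes as an AnnData object
rng = np.random.default_rng(0)
toy = SimpleNamespace(
    X=rng.random((30, 20)),
    obsm={"X_emb": rng.random((30, 5))},
    layers={"counts": rng.poisson(2.0, size=(30, 20))},
)

pca = resolve_joint_attr(toy)                                       # shape (30, 10)
emb = resolve_joint_attr(toy, "X_emb")                              # shape (30, 5)
cnt = resolve_joint_attr(toy, {"attr": "layers", "key": "counts"})  # shape (30, 20)
```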

In [None]:
stp = SpatioTemporalProblem(adata=adata).prepare(
    time_key="time",
    spatial_key="spatial",
    joint_attr=None,
    callback="local-pca",
)


## Solve the problem

To solve the problem, we call :meth:`moscot.problems.spatio_temporal.SpatioTemporalProblem.solve` and pass:

- `alpha` - interpolation parameter between the quadratic term (spatial coordinates) and the linear term (PCA space).
- `epsilon` - entropic regularization parameter.
- `rank` - rank of the coupling matrix used by the low-rank solver.
- `gamma` - inverse of the gradient step size used by mirror descent.
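Schematically, `alpha` and `epsilon` enter an entropically regularized fused Gromov-Wasserstein objective. The form below is a sketch assuming that `alpha` weights the quadratic (spatial) term and `1 - alpha` the linear term; the exact normalization used by moscot may differ. Here $C$ is the cost between source and target cells in the joint (PCA) space, $D^{s}$ and $D^{t}$ are the pairwise distances between spatial coordinates within the source and target sections, $a$ and $b$ are the cell marginals, and $H(P) = -\sum_{i,j} P_{ij} (\log P_{ij} - 1)$ is the entropy of the coupling $P$:

$$
\min_{P \in U(a, b)} \; (1 - \alpha) \sum_{i,j} C_{ij} P_{ij}
+ \alpha \sum_{i,j,k,l} \left( D^{s}_{ik} - D^{t}_{jl} \right)^2 P_{ij} P_{kl}
- \epsilon H(P),
\qquad
U(a, b) = \{ P \geq 0 : P \mathbf{1} = a, \; P^\top \mathbf{1} = b \}.
$$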

Here we use the low-rank approach of Scetbon et al. :cite:`scetbon:2021_a`, :cite:`scetbon:2021_b`, which is beneficial for large datasets because the coupling is represented by low-rank factors instead of a full cell-by-cell matrix.
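In this factorization the coupling is written as $P = Q \operatorname{diag}(1/g) R^\top$, with non-negative $Q \in \mathbb{R}^{n \times r}$, $R \in \mathbb{R}^{m \times r}$ and $g \in \mathbb{R}^{r}$, where $Q$ has row sums $a$, $R$ has row sums $b$, and both have column sums $g$. The NumPy sketch below does not solve an OT problem: it only builds a random feasible triple $(Q, R, g)$ with iterative proportional fitting (a choice made for illustration) and checks that $P$ has the prescribed marginals and that $P v$ can be computed without forming the $n \times m$ matrix. The names `ipf` and `apply_coupling` are hypothetical.

```python
import numpy as np


def ipf(K, row_marg, col_marg, n_iter=500):
    """Rescale a positive matrix K so its row/column sums match the given marginals."""
    u = np.ones(K.shape[0])
    v = np.ones(K.shape[1])
    for _ in range(n_iter):
        u = row_marg / (K @ v)
        v = col_marg / (K.T @ u)  # after this step the column sums are exact
    return u[:, None] * K * v[None, :]


def apply_coupling(Q, R, g, v):
    """Compute P @ v for P = Q diag(1/g) R^T without building the n x m matrix."""
    return Q @ ((R.T @ v) / g)


rng = np.random.default_rng(0)
n, m, r = 200, 300, 5       # source cells, target cells, rank
a = np.full(n, 1.0 / n)     # uniform source marginal
b = np.full(m, 1.0 / m)     # uniform target marginal
g = np.full(r, 1.0 / r)     # inner marginal shared by both factors

Q = ipf(rng.random((n, r)) + 0.1, a, g)  # rows sum to a, columns sum to g
R = ipf(rng.random((m, r)) + 0.1, b, g)  # rows sum to b, columns sum to g

v = rng.random(m)
Pv = apply_coupling(Q, R, g, v)  # shape (n,)
```

Storing the factors takes $(n + m + 1) r$ numbers instead of $n m$, i.e. 2,505 instead of 60,000 for the sizes above, and a matrix-vector product costs $O((n + m) r)$ instead of $O(n m)$.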


In [None]:
stp = stp.solve(alpha=0.9, epsilon=1e-3, rank=100, gamma=100)

## Visualize the results

In [None]:
import moscot.plotting as mpl
import numpy as np

We start by plotting the cell-transition map inferred by `moscot`, aggregated according to the tissue annotation.

As expected, cell types present at consecutive time points are mainly mapped to themselves.
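Conceptually, such an aggregated map sums the coupling mass between every pair of source and target groups. The NumPy sketch below illustrates this on a toy coupling; `aggregate_coupling` is a hypothetical helper, and its normalization (each row sums to one, i.e. the fraction of a source tissue's mass sent to each target tissue) is an assumption that may differ from what `cell_transition` applies.

```python
import numpy as np


def aggregate_coupling(P, src_labels, tgt_labels):
    """Hypothetical helper: sum coupling mass per (source group, target group), row-normalized."""
    src_cats, src_idx = np.unique(src_labels, return_inverse=True)
    tgt_cats, tgt_idx = np.unique(tgt_labels, return_inverse=True)
    S = np.eye(len(src_cats))[src_idx]  # one-hot membership of source cells, n x k_src
    T = np.eye(len(tgt_cats))[tgt_idx]  # one-hot membership of target cells, m x k_tgt
    M = S.T @ P @ T                     # mass moved between each pair of groups
    M = M / M.sum(axis=1, keepdims=True)  # fraction of each source group's mass
    return src_cats, tgt_cats, M


# Toy coupling: 3 source cells (rows) and 2 target cells (columns)
P = np.array([[0.3, 0.0], [0.2, 0.0], [0.1, 0.4]])
src_cats, tgt_cats, M = aggregate_coupling(P, ["Liver", "Liver", "Heart"], ["Liver", "Heart"])
# Rows and columns are ordered ["Heart", "Liver"]; M == [[0.8, 0.2], [0.0, 1.0]]
```

In this toy case most of the mass of each tissue stays within the same tissue, giving the diagonal-dominant pattern described above.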

In [None]:
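# TODO: modify after issue # passing `fig` is resolved
# fig, axs = plt.subplots(1, 2)
dfs = {}
# Sort the time points so that each (source, target) pair is consecutive
timepoints = sorted(adata.obs["time"].unique())
for i, tp in enumerate(timepoints[:-1]):
    # Aggregate the mapping between consecutive time points by tissue annotation
    dfs[i] = stp.cell_transition(
        source=tp,
        target=timepoints[i + 1],
        source_groups="annotation",
        target_groups="annotation",
    )
    # Plot the cell-transition result stored on `stp`
    mpl.cell_transition(stp, figsize=(8, 8))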

Next, we visualize the predicted spatial destination of the cells, again aggregated by tissue.
That is, we focus on a single tissue, e.g. `Liver` cells, at an early time point and look at the spatial locations these cells are mapped to at the next time point.
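One way to obtain such a per-cell map is to push the indicator of the chosen tissue through the coupling, so that each cell at the later time point receives the mass sent to it by the earlier tissue cells. The sketch below uses a toy coupling and a hypothetical helper `push_group`; it illustrates the idea and is not the moscot API used to produce the plots.

```python
import numpy as np


def push_group(P, src_labels, group):
    """Hypothetical helper: distribute the mass of one source group over the target cells."""
    indicator = (np.asarray(src_labels) == group).astype(float)  # 1 for cells of `group`
    pushed = P.T @ indicator  # mass each target cell receives from the group
    return pushed / pushed.sum()  # normalize so the values sum to 1


# Toy coupling: 3 source cells (rows) and 2 target cells (columns)
P = np.array([[0.3, 0.0], [0.2, 0.0], [0.1, 0.4]])
liver_dest = push_group(P, ["Liver", "Liver", "Heart"], "Liver")  # [1.0, 0.0]
heart_dest = push_group(P, ["Liver", "Liver", "Heart"], "Heart")  # [0.2, 0.8]
```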

In [None]:
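for i, col in enumerate(adata.obs["annotation"].cat.categories):
    # Placeholder column (float) for the mapped values of this tissue, filled in below
    adata.obs[f"{col}_mapping"] = 0.0
    # Binary tissue indicator, stored as a categorical with levels "rest" / "<tissue> cells"
    adata.obs[f"{col}_annotation"] = np.asarray(adata.obs["annotation"] == col, dtype=int)
    adata.obs[f"{col}_annotation"] = adata.obs[f"{col}_annotation"].astype("category")
    adata.obs[f"{col}_annotation"] = adata.obs[f"{col}_annotation"].cat.rename_categories({0: "rest", 1: f"{col} cells"})
    # Grey for the rest, the tissue's own annotation color for its cells
    adata.uns[f"{col}_annotation_colors"] = ["#808080", adata.uns["annotation_colors"][i]]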

for i, tp in enumerate(timepoints[1:]):
    for j, col in enumerate(dfs[i]):
        adata.obs.loc[adata.obs["time"].isin([tp]) ,f"{col}_mapping"] =

In [None]:
tissue = "Liver"

In [None]:
fig, axs = plt.subplots(2, 3, figsize=(8, 16))
for i, tp in enumerate(timepoints):
    adata_tp = adata[adata.obs["time"] == tp]
    # Top row: cells of the selected tissue at each time point
    sc.pl.spatial(
        adata_tp,
        color=f"{tissue}_annotation",
        title=f"{tissue} annotation in E{tp}",
        spot_size=1.0, frameon=False, wspace=0.1, alpha=1.0, show=False, ax=axs[0, i],
    )
    # Bottom row: where the tissue cells of the previous time point are mapped to
    if i > 0:
        sc.pl.spatial(
            adata_tp,
            color=f"{tissue}_mapping",
            title=f"mapping from {tissue} E{timepoints[i - 1]} to E{tp}",
            spot_size=1.0, frameon=False, wspace=0.1, color_map="viridis_r", alpha=1.0, show=False, ax=axs[1, i],
        )

plt.tight_layout()
# The first time point has no earlier time point to map from
axs[1, 0].remove()
plt.show()