# Passing callbacks in {meth}`~moscot.problems.space.MappingProblem.prepare`

In this example, we show how to use different callbacks when preparing a {class}`~moscot.problems.space.MappingProblem`.

Callbacks specify which computation is applied to the data (for example, {attr}`~anndata.AnnData.X`) to produce the representation from which a cost term is computed when the problem is prepared. They can be set separately for the linear term (`xy_callback`) and for the quadratic terms (`x_callback` and `y_callback`).
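For orientation, the standard fused Gromov-Wasserstein objective shows where both kinds of terms enter. This is the textbook formulation, not necessarily moscot's exact internals (weighting and entropic regularization may differ): the linear term compares source and target points directly through a cost {math}`C^{xy}`, while the quadratic term compares pairwise costs {math}`C^{x}` and {math}`C^{y}` computed within each dataset.

```{math}
\min_{P \in \Pi(a, b)} \; (1 - \alpha) \sum_{i,j} C^{xy}_{ij} P_{ij}
+ \alpha \sum_{i,j,k,l} \left( C^{x}_{ik} - C^{y}_{jl} \right)^2 P_{ij} P_{kl}
```

Here {math}`\Pi(a, b)` is the set of couplings with marginals {math}`a` and {math}`b`, and {math}`\alpha \in [0, 1]` balances the linear and quadratic terms.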

:::{seealso}
- See {doc}`200_custom_cost_matrices` for an example of how to use custom matrices and pass `joint_attr`, `x_attr` and `y_attr` to the {meth}`~moscot.problems.generic.FGWProblem.prepare` method.
- See {doc}`700_barcode_distance` for an example of how to use the barcode distance as a cost.
:::

## Imports and data loading

In [1]:
import warnings

warnings.simplefilter(action="ignore", category=UserWarning)
warnings.simplefilter(action="ignore", category=FutureWarning)

from moscot import datasets
from moscot.problems.space import MappingProblem
from moscot.utils.tagged_array import TaggedArray

import numpy as np
import pandas as pd
from sklearn.decomposition import SparsePCA

import anndata
import scanpy as sc

In [2]:
adata_sc = datasets.drosophila(spatial=False)
adata_sp = datasets.drosophila(spatial=True)
adata_sc, adata_sp

(AnnData object with n_obs × n_vars = 1297 × 2000
     obs: 'n_counts'
     var: 'n_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
     uns: 'hvg', 'log1p', 'pca'
     obsm: 'X_pca'
     varm: 'PCs'
     layers: 'counts',
 AnnData object with n_obs × n_vars = 3039 × 82
     obs: 'n_counts'
     var: 'n_counts'
     uns: 'log1p', 'pca'
     obsm: 'X_pca', 'spatial'
     varm: 'PCs'
     layers: 'counts')

## Spatial normalization

With `normalize_spatial=True` (the default), the spatial coordinates are standardized, i.e. centered and scaled to unit standard deviation.

In [3]:
mp = MappingProblem(adata_sc=adata_sc, adata_sp=adata_sp)
mp = mp.prepare(sc_attr={"attr": "obsm", "key": "X_pca"}, normalize_spatial=True)

[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  
[34mINFO    [0m Normalizing spatial coordinates of `x`.                                                                   


In [4]:
mp[("src", "tgt")].x.data_src.std()

1.0000000000000002

In [5]:
mp = mp.prepare(sc_attr={"attr": "obsm", "key": "X_pca"}, normalize_spatial=False)

[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  


In [6]:
mp[("src", "tgt")].x.data_src.std()

66.97163996056013

Setting `normalize_spatial=True` is effectively equivalent to applying the `"spatial-norm"` callback to the spatial coordinates of the `x` term.
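The effect of standardization can be reproduced on synthetic data. The helper below is a hypothetical stand-in, not moscot's `"spatial-norm"` code, and it assumes a single mean and standard deviation over all coordinate entries; a per-axis variant would subtract and divide column-wise instead.

```python
import numpy as np


def standardize_coordinates(spatial: np.ndarray) -> np.ndarray:
    """Center spatial coordinates and rescale them to unit standard deviation.

    Assumption: one global mean and standard deviation over all entries,
    which is one plausible reading of "standardizing" here.
    """
    spatial = np.asarray(spatial, dtype=np.float64)
    return (spatial - spatial.mean()) / spatial.std()


rng = np.random.default_rng(0)
coords = rng.uniform(0, 200, size=(100, 2))  # synthetic 2D coordinates
normed = standardize_coordinates(coords)
print(round(float(normed.std()), 6))  # 1.0
```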

## Passing callbacks

### PCA computation in gene space

To compute a joint PCA embedding of both datasets in gene expression space, pass `xy_callback="local-pca"`. For each pair of distributions, PCA is then run on the {attr}`~anndata.AnnData.X` of the source and target together.
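To make "joint embedding" concrete, the sketch below stacks a source and a target expression matrix, fits a single PCA through NumPy's SVD, and splits the projections back into one point cloud per dataset. It is an illustration under assumptions (dense inputs, centering per gene, `n_comps=30` taken from the log messages in this notebook), not moscot's implementation, whose preprocessing may differ.

```python
import numpy as np


def joint_pca(x_src: np.ndarray, x_tgt: np.ndarray, n_comps: int = 30):
    """Fit one PCA on the stacked source and target matrices and split it back.

    Hypothetical sketch of a "local" PCA for one pair of distributions.
    """
    joint = np.vstack([x_src, x_tgt])
    joint = joint - joint.mean(axis=0)  # center each gene over both datasets
    # Rows of vt are the principal axes; keep the leading n_comps of them.
    _, _, vt = np.linalg.svd(joint, full_matrices=False)
    n_comps = min(n_comps, vt.shape[0])
    embedding = joint @ vt[:n_comps].T
    # The first x_src.shape[0] rows belong to the source, the rest to the target.
    return np.split(embedding, [x_src.shape[0]])


rng = np.random.default_rng(0)
emb_src, emb_tgt = joint_pca(rng.normal(size=(60, 82)), rng.normal(size=(40, 82)))
print(emb_src.shape, emb_tgt.shape)  # (60, 30) (40, 30)
```

Fitting the PCA on both datasets at once is what puts the two point clouds in the same coordinate system, so that a cost between them is meaningful.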

In [7]:
mp = mp.prepare(
    sc_attr={"attr": "obsm", "key": "X_pca"},
    xy_callback="local-pca",
)

[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  


[34mINFO    [0m Normalizing spatial coordinates of `x`.                                                                   


The callback creates a point cloud that contains PCA projections of the data.

In [8]:
mp[("src", "tgt")].xy.tag

<Tag.POINT_CLOUD: 'point_cloud'>
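A point cloud becomes a cost only once a cost function is applied to it. With the default `"sq_euclidean"` cost, every source/target pair is compared by its squared Euclidean distance. The NumPy sketch below materializes that matrix for two tiny point clouds; it is illustrative only, since the solver may evaluate costs lazily instead of storing the full matrix.

```python
import numpy as np


def sq_euclidean_cost(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Pairwise squared Euclidean distances: cost[i, j] = ||x_i - y_j||^2."""
    # Expand ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y> to avoid Python loops.
    sq_x = np.sum(x**2, axis=1)[:, None]
    sq_y = np.sum(y**2, axis=1)[None, :]
    cost = sq_x + sq_y - 2.0 * x @ y.T
    return np.maximum(cost, 0.0)  # clip tiny negatives caused by rounding


x = np.array([[0.0, 0.0], [1.0, 1.0]])
y = np.array([[0.0, 1.0], [3.0, 4.0], [1.0, 1.0]])
print(sq_euclidean_cost(x, y))  # rows: [1, 25, 2] and [1, 13, 0]
```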

In [9]:
mp.solve()
mp.solutions

[34mINFO    [0m Solving `[1;36m1[0m` problems                                                                                      
[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m3039[0m, [1;36m1297[0m[1m)[0m[1m][0m.                                          


{('src', 'tgt'): OTTOutput[shape=(3039, 1297), cost=1.6444, converged=True]}

### Using geodesic costs

To use geodesic costs defined on a graph, first construct the underlying graph (here in gene expression space) with `xy_callback="graph-construction"`. Note that the cost must then be set explicitly, here to `"geodesic"` for the `xy` term.

In [10]:
mp = mp.prepare(
    sc_attr={"attr": "obsm", "key": "X_pca"},
    normalize_spatial=False,
    xy_callback="graph-construction",
    cost={"xy": "geodesic", "x": "sq_euclidean", "y": "sq_euclidean"},
)

[34mINFO    [0m Computing graph construction for `xy` using `X_pca`                                                       


We then verify that a graph has been constructed:

In [11]:
mp[("src", "tgt")].xy.tag

<Tag.GRAPH: 'graph'>
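"Geodesic" here means distances measured along the graph rather than straight through the embedding. The NumPy sketch below assumes shortest-path length as the notion of geodesic distance: it builds a symmetric k-nearest-neighbor graph on points sampled from a half circle and runs Floyd-Warshall. The solver may approximate geodesic distances differently (for instance through a heat kernel), so this shows the idea rather than reproducing moscot's numbers.

```python
import numpy as np


def knn_graph(points: np.ndarray, k: int) -> np.ndarray:
    """Symmetric k-nearest-neighbor graph with Euclidean edge lengths (inf = no edge)."""
    diff = points[:, None, :] - points[None, :, :]
    dist = np.sqrt(np.sum(diff**2, axis=-1))
    n = len(points)
    graph = np.full((n, n), np.inf)
    np.fill_diagonal(graph, 0.0)
    # Column 0 of argsort is the point itself (distance 0), so take columns 1..k.
    neighbors = np.argsort(dist, axis=1)[:, 1 : k + 1]
    for i in range(n):
        for j in neighbors[i]:
            graph[i, j] = graph[j, i] = dist[i, j]
    return graph


def shortest_path_lengths(graph: np.ndarray) -> np.ndarray:
    """All-pairs shortest paths via Floyd-Warshall (O(n^3), fine for small n)."""
    d = graph.copy()
    for m in range(len(d)):
        d = np.minimum(d, d[:, m : m + 1] + d[m : m + 1, :])
    return d


# Points on a half circle: the graph path must follow the arc.
angles = np.linspace(0.0, np.pi, 9)
pts = np.column_stack([np.cos(angles), np.sin(angles)])
geo = shortest_path_lengths(knn_graph(pts, k=2))
straight = np.linalg.norm(pts[0] - pts[-1])
print(geo[0, -1] > straight)  # True
```

On curved data the geodesic distance between the two ends exceeds the straight-line distance, which is why graph-based costs can respect the shape of the data manifold.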

In [12]:
mp.solve()
mp.solutions

[34mINFO    [0m Solving `[1;36m1[0m` problems                                                                                      
[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m3039[0m, [1;36m1297[0m[1m)[0m[1m][0m.                                          


{('src', 'tgt'): OTTOutput[shape=(3039, 1297), cost=1.3147, converged=True]}

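Alternatively, a custom graph can be passed with {meth}`~moscot.base.problems.OTProblem.set_graph_xy`. Here, the graph is the nearest-neighbor connectivity matrix computed by {func}`scanpy.pp.neighbors` on the concatenated data, stored as a {class}`pandas.DataFrame` indexed by cell names: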

In [13]:
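# Concatenate source (spatial) and target (single-cell) cells, in this order.
adata_concat = anndata.concat([adata_sp, adata_sc])
sc.pp.neighbors(adata_concat, use_rep="X_pca")
# Dense, labelled adjacency matrix over all cells; `.toarray()` replaces the
# deprecated `.A` shortcut for sparse matrices.
df_graph = pd.DataFrame(
    index=adata_concat.obs_names,
    columns=adata_concat.obs_names,
    data=adata_concat.obsp["connectivities"].toarray().astype("float64"),
)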
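Before handing `df_graph` to the problem, a few quick checks can catch layout mistakes. The helper `check_graph_frame` below is hypothetical: it encodes what the cell above produces (identical row and column labels covering every source and target cell, and a symmetric matrix, as scanpy's connectivities are). These are assumptions drawn from this example, not documented requirements of `set_graph_xy`.

```python
import numpy as np
import pandas as pd


def check_graph_frame(df: pd.DataFrame, src_names, tgt_names) -> None:
    """Hypothetical sanity checks for a square adjacency DataFrame over src + tgt cells."""
    expected = pd.Index(list(src_names) + list(tgt_names))
    if not df.index.equals(df.columns):
        raise ValueError("row and column labels must be identical")
    missing = expected.difference(df.index)
    if len(missing):
        raise ValueError(f"{len(missing)} cells are missing from the graph")
    if not np.allclose(df.to_numpy(), df.to_numpy().T):
        raise ValueError("adjacency matrix is not symmetric")


names = ["sp_0", "sp_1", "sc_0"]  # toy cell names
toy = pd.DataFrame(
    [[0.0, 1.0, 0.5], [1.0, 0.0, 0.0], [0.5, 0.0, 0.0]], index=names, columns=names
)
check_graph_frame(toy, ["sp_0", "sp_1"], ["sc_0"])  # passes silently
```

In this notebook, the call would be `check_graph_frame(df_graph, adata_sp.obs_names, adata_sc.obs_names)`.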

The problem is first prepared with the default (`"sq_euclidean"`) cost, and the `xy` term is then overwritten by {meth}`~moscot.base.problems.OTProblem.set_graph_xy`:

In [14]:
mp = mp.prepare(
    sc_attr={"attr": "obsm", "key": "X_pca"},
    normalize_spatial=False,
)
mp[("src", "tgt")].set_graph_xy(df_graph, cost="geodesic")

[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  


In [15]:
mp[("src", "tgt")].xy.tag

<Tag.GRAPH: 'graph'>

In [16]:
mp.solve()
mp.solutions

[34mINFO    [0m Solving `[1;36m1[0m` problems                                                                                      
[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m3039[0m, [1;36m1297[0m[1m)[0m[1m][0m.                                          


{('src', 'tgt'): OTTOutput[shape=(3039, 1297), cost=1.3147, converged=True]}

### Custom callback function

A callable can also be passed as a custom callback. In this example, we use scikit-learn's [`SparsePCA`](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.SparsePCA.html#) estimator.

The callback receives three arguments: the term (`Literal["xy", "x", "y"]`), `problem.adata_src` and `problem.adata_tgt`, together with any keyword arguments passed in `xy_callback_kwargs`. It must return a {class}`moscot.utils.tagged_array.TaggedArray`.

In [17]:
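def sparse_pca_callback(term, src, tgt):
    # Fit SparsePCA on the stacked source and target expression matrices
    # (X is sparse here, hence `.toarray()`).
    embedding = SparsePCA().fit_transform(np.vstack([src.X.toarray(), tgt.X.toarray()]))
    # Split the joint embedding back into source and target rows.
    data_src, data_tgt = np.split(embedding, [src.shape[0]])
    return TaggedArray(data_src, data_tgt)


mp = mp.prepare(
    sc_attr={"attr": "obsm", "key": "X_pca"},
    normalize_spatial=False,
    xy_callback=sparse_pca_callback,
)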

In [18]:
mp[("src", "tgt")].xy.tag

<Tag.POINT_CLOUD: 'point_cloud'>

In [19]:
mp.solve()
mp.solutions

[34mINFO    [0m Solving `[1;36m1[0m` problems                                                                                      
[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m3039[0m, [1;36m1297[0m[1m)[0m[1m][0m.                                          


{('src', 'tgt'): OTTOutput[shape=(3039, 1297), cost=1.6591, converged=True]}
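The same `(term, src, tgt, **kwargs)` pattern works with other dimensionality reductions. As a hypothetical variant of the SparsePCA callback, the sketch below uses a Gaussian random projection and takes `n_components` and `seed` as keyword arguments, which in moscot would be supplied through `xy_callback_kwargs` (for example `xy_callback_kwargs={"n_components": 5}`). So that it runs without moscot, it uses `SimpleNamespace` objects as stand-ins for the AnnData inputs and returns a plain tuple instead of a `TaggedArray`.

```python
from types import SimpleNamespace

import numpy as np


def random_projection_callback(term, src, tgt, n_components=10, seed=0):
    """Hypothetical callback with the (term, adata_src, adata_tgt, **kwargs) shape.

    `n_components` and `seed` stand in for entries of `xy_callback_kwargs`.
    """
    # `term` names the cost term ("xy", "x" or "y"); this sketch ignores it.
    rng = np.random.default_rng(seed)
    n_features = src.X.shape[1]
    # A single Gaussian projection shared by both datasets keeps them in one space.
    proj = rng.normal(scale=1.0 / np.sqrt(n_components), size=(n_features, n_components))
    # With sparse X, call `.toarray()` first. In moscot, the result would be
    # returned as TaggedArray(data_src, data_tgt) instead of a plain tuple.
    return src.X @ proj, tgt.X @ proj


rng = np.random.default_rng(1)
src = SimpleNamespace(X=rng.random((6, 82)))  # stand-in for problem.adata_src
tgt = SimpleNamespace(X=rng.random((4, 82)))  # stand-in for problem.adata_tgt
data_src, data_tgt = random_projection_callback("xy", src, tgt, n_components=5)
print(data_src.shape, data_tgt.shape)  # (6, 5) (4, 5)
```

Applying one shared projection matrix to both datasets is the key design choice: projecting source and target independently would place them in unrelated coordinate systems, and a linear cost between them would carry no meaning.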