# Passing callbacks in {meth}`~moscot.problems.space.AlignmentProblem.prepare`

In this example, we show how to use different callbacks.

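A callback specifies which computation is run, e.g. on {attr}`~anndata.AnnData.X`, to obtain the data for one of the terms when the problem is prepared. Callbacks can be set separately for the linear term (`xy_callback`), which defines the joint cost across the two batches being aligned, and for the quadratic terms (`x_callback` and `y_callback`), which define the costs within each batch.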

:::{seealso}
- See {doc}`200_custom_cost_matrices` for an example of how to use custom matrices and pass `joint_attr`, `x_attr` and `y_attr` in the {meth}`~moscot.problems.generic.FGWProblem.prepare` method.
- See {doc}`700_barcode_distance` for an example of how to use the barcode distance as the cost.
:::
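
To make the distinction between the terms concrete, the following NumPy sketch builds the three cost matrices of a fused Gromov-Wasserstein alignment, assuming squared Euclidean costs. The linear term `xy` compares cells across the two batches in a shared feature space, whereas the quadratic terms `x` and `y` compare cells within each batch using their spatial coordinates. The random inputs and the helper `sq_euclidean` are illustrative assumptions, not moscot's implementation.

```python
import numpy as np


def sq_euclidean(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise squared Euclidean distances between the rows of `a` and the rows of `b`."""
    diff = a[:, None, :] - b[None, :, :]
    return (diff**2).sum(axis=-1)


rng = np.random.default_rng(0)
n_x, n_y, n_pcs = 5, 4, 3

# Hypothetical shared feature space (e.g. a PCA of adata.X) used by the linear term ...
feat_x = rng.normal(size=(n_x, n_pcs))
feat_y = rng.normal(size=(n_y, n_pcs))
# ... and hypothetical 2D spatial coordinates used by the quadratic terms.
spatial_x = rng.uniform(size=(n_x, 2))
spatial_y = rng.uniform(size=(n_y, 2))

cost_xy = sq_euclidean(feat_x, feat_y)  # linear term: across batches, shape (n_x, n_y)
cost_x = sq_euclidean(spatial_x, spatial_x)  # quadratic term: within batch x, shape (n_x, n_x)
cost_y = sq_euclidean(spatial_y, spatial_y)  # quadratic term: within batch y, shape (n_y, n_y)

print(cost_xy.shape, cost_x.shape, cost_y.shape)  # (5, 4) (5, 5) (4, 4)
```

In this picture, the callbacks determine which arrays (here `feat_x`/`feat_y` and `spatial_x`/`spatial_y`) each term receives, and the `cost` determines how they are turned into matrices.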

## Imports and data loading

In [1]:
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

from moscot import datasets
from moscot.problems.space import AlignmentProblem

import pandas as pd

import scanpy as sc

In [2]:
adata = datasets.sim_align()
sc.pp.pca(adata)
adata

AnnData object with n_obs × n_vars = 1200 × 500
    obs: 'batch'
    uns: 'batch_colors', 'pca'
    obsm: 'spatial', 'X_pca'
    varm: 'PCs'

## Spatial normalization

By default (`normalize_spatial=True`), the spatial coordinates of each batch are normalized by standardizing them.

In [3]:
ap = AlignmentProblem(adata=adata)
ap = ap.prepare(batch_key="batch", policy="sequential", normalize_spatial=True)

[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  


[34mINFO    [0m Normalizing spatial coordinates of `x`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `y`.                                                                   
[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  
[34mINFO    [0m Normalizing spatial coordinates of `x`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `y`.                                                                   


In [4]:
ap[("1", "2")].x.data_src.std() == 1

True

In [5]:
ap = ap.prepare(batch_key="batch", policy="sequential", normalize_spatial=False)

[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  


[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  


In [6]:
ap[("1", "2")].x.data_src.std() == 1

False

Setting `normalize_spatial=True` is equivalent to using the `"spatial-norm"` callback for the quadratic terms, i.e. `x_callback` and `y_callback`.
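
A minimal sketch of the standardization, assuming a single global mean and standard deviation over all coordinate values of one batch; the actual `"spatial-norm"` callback may differ in details such as per-axis scaling. As the log messages suggest, it is applied to `x` and `y` separately. The helper `standardize_spatial` is hypothetical.

```python
import numpy as np


def standardize_spatial(coords: np.ndarray) -> np.ndarray:
    """Center coordinates and scale them to unit standard deviation (global, over all entries)."""
    centered = coords - coords.mean()
    return centered / centered.std()


rng = np.random.default_rng(0)
# Hypothetical batch: 400 cells with 2D coordinates on an arbitrary scale.
coords = rng.uniform(0, 1000, size=(400, 2))

normed = standardize_spatial(coords)
print(np.isclose(normed.std(), 1.0), np.isclose(normed.mean(), 0.0))  # True True
```

Comparing with `np.isclose` rather than `==` is a more robust way to check this property on floating-point data.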

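To obtain the joint cost of the linear term, we can pass `xy_callback="local-pca"`, which computes a PCA of {attr}`~anndata.AnnData.X` for each pair of batches. As the log messages above show, this is also what happens when no callback is specified.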

In [7]:
ap = ap.prepare(
    batch_key="batch",
    policy="sequential",
    xy_callback="local-pca",
)

[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  


[34mINFO    [0m Normalizing spatial coordinates of `x`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `y`.                                                                   
[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  
[34mINFO    [0m Normalizing spatial coordinates of `x`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `y`.                                                                   


In [8]:
ap[("0", "1")].xy.tag

<Tag.POINT_CLOUD: 'point_cloud'>
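
The resulting `xy` term is a point cloud: each batch gets its own coordinates in a shared PCA space. As a rough sketch, assuming `"local-pca"` amounts to fitting one PCA on the concatenated expression of the two batches and splitting the scores back (the log messages report `n_comps=30`), the idea looks as follows in NumPy. The helper `joint_pca` is hypothetical, and moscot's own implementation may differ, e.g. in preprocessing or in the PCA solver.

```python
import numpy as np


def joint_pca(x: np.ndarray, y: np.ndarray, n_comps: int = 30) -> tuple[np.ndarray, np.ndarray]:
    """Fit a PCA on the stacked rows of x and y and return the scores of each batch."""
    joint = np.vstack([x, y])
    centered = joint - joint.mean(axis=0)  # center each gene over both batches
    # Thin SVD: rows of vt are the principal axes, sorted by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    scores = centered @ vt[:n_comps].T  # project onto the first n_comps axes
    return scores[: len(x)], scores[len(x) :]  # split back into the two batches


rng = np.random.default_rng(0)
# Hypothetical expression matrices: 400 cells per batch, 500 genes (as in the simulated data).
expr_x = rng.poisson(1.0, size=(400, 500)).astype(float)
expr_y = rng.poisson(1.0, size=(400, 500)).astype(float)

pcs_x, pcs_y = joint_pca(expr_x, expr_y, n_comps=30)
print(pcs_x.shape, pcs_y.shape)  # (400, 30) (400, 30)
```

Fitting the PCA on both batches together, rather than on each batch separately, is what places them in a common space, so that a linear cost between their cells is meaningful.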

In [9]:
ap.solve()
ap.solutions

[34mINFO    [0m Solving `[1;36m2[0m` problems                                                                                      
[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m400[0m, [1;36m400[0m[1m)[0m[1m][0m.                                            


No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m400[0m, [1;36m400[0m[1m)[0m[1m][0m.                                            


{('1', '2'): OTTOutput[shape=(400, 400), cost=1.1172, converged=True],
 ('0', '1'): OTTOutput[shape=(400, 400), cost=1.0933, converged=True]}

To use a graph, the cost of the corresponding term must be set to `"geodesic"`. We can either pass `xy_callback="graph-construction"`, which constructs the graph from `X_pca`,

In [10]:
ap = ap.prepare(
    batch_key="batch",
    policy="sequential",
    normalize_spatial=False,
    xy_callback="graph-construction",
    cost={"xy": "geodesic", "x": "sq_euclidean", "y": "sq_euclidean"},
)

[34mINFO    [0m Computing graph construction for `xy` using `X_pca`                                                       
[34mINFO    [0m Computing graph construction for `xy` using `X_pca`                                                       


and verify that a graph has been constructed:

In [11]:
ap[("0", "1")].xy

TaggedArray(data_src=<800x800 sparse matrix of type '<class 'numpy.float64'>'
	with 12766 stored elements in Compressed Sparse Row format>, data_tgt=None, tag=<Tag.GRAPH: 'graph'>, cost='geodesic')
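
The `xy` graph above has 800 nodes: the 400 cells of each batch share one graph, and the cost between batches is read off the cross block. The sketch below illustrates this with SciPy, assuming a symmetric k-nearest-neighbor graph with Euclidean edge weights and using exact all-pairs shortest-path lengths as the geodesic distance. This is a stand-in technique: to our understanding, moscot passes the graph to OTT, whose graph geometry approximates geodesic distances with a heat kernel instead of computing shortest paths. The helper `knn_graph` and all parameters are hypothetical.

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import shortest_path
from scipy.spatial.distance import cdist


def knn_graph(points: np.ndarray, k: int = 10) -> csr_matrix:
    """Symmetric k-nearest-neighbor graph with Euclidean distances as edge weights."""
    dist = cdist(points, points)
    np.fill_diagonal(dist, np.inf)  # exclude self-loops from the neighbor search
    n = len(points)
    neighbors = np.argsort(dist, axis=1)[:, :k]  # indices of the k closest points per row
    rows = np.repeat(np.arange(n), k)
    cols = neighbors.ravel()
    adj = csr_matrix((dist[rows, cols], (rows, cols)), shape=(n, n))
    return adj.maximum(adj.T)  # keep an edge if either endpoint selected the other


rng = np.random.default_rng(0)
n_x, n_y = 40, 30
# Hypothetical joint embedding (e.g. X_pca) of two batches, stacked with x first, then y.
emb = np.vstack([rng.normal(size=(n_x, 5)), rng.normal(loc=0.5, size=(n_y, 5))])

graph = knn_graph(emb, k=10)  # a single graph over the cells of both batches
# All-pairs shortest-path lengths; entries are inf if the graph is disconnected.
geo = shortest_path(graph, method="D", directed=False)
cost_xy = geo[:n_x, n_x:]  # cross block: cells of x (rows) vs cells of y (columns)

print(graph.shape, cost_xy.shape)  # (70, 70) (40, 30)
```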

In [12]:
ap.solve()
ap.solutions

[34mINFO    [0m Solving `[1;36m2[0m` problems                                                                                      
[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m400[0m, [1;36m400[0m[1m)[0m[1m][0m.                                            
[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m400[0m, [1;36m400[0m[1m)[0m[1m][0m.                                            


{('1', '2'): OTTOutput[shape=(400, 400), cost=1.1057, converged=True],
 ('0', '1'): OTTOutput[shape=(400, 400), cost=1.0803, converged=True]}

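Or we can pass a custom graph via `set_graph_xy()`. Here, we build one from the k-nearest-neighbor connectivities computed by {func}`scanpy.pp.neighbors` on the cells of batches `"0"` and `"1"`: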

In [13]:
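# Keep only the two batches of the subproblem; copy to avoid modifying a view
adata_subset = adata[adata.obs["batch"].isin(["0", "1"])].copy()
# k-nearest-neighbor graph on the PCA representation
sc.pp.neighbors(adata_subset, use_rep="X_pca")
# Dense, labeled adjacency matrix over the cells of both batches
df_graph = pd.DataFrame(
    index=adata_subset.obs_names,
    columns=adata_subset.obs_names,
    data=adata_subset.obsp["connectivities"].toarray().astype("float64"),
)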
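
Before handing a custom graph to `set_graph_xy()`, it can help to sanity-check it. The hypothetical helper `check_graph_df` below is not part of moscot; it encodes properties we assume a graph for a geodesic cost should have: square, with identical row and column labels, symmetric (undirected), and with non-negative weights. Whether moscot imposes further requirements, such as a specific cell ordering, is not covered here.

```python
import numpy as np
import pandas as pd


def check_graph_df(df: pd.DataFrame) -> None:
    """Raise ValueError if `df` does not look like an undirected, weighted adjacency matrix."""
    if df.shape[0] != df.shape[1]:
        raise ValueError(f"graph must be square, got shape {df.shape}")
    if not df.index.equals(df.columns):
        raise ValueError("row and column labels must be identical and in the same order")
    values = df.to_numpy(dtype=float)
    if not np.allclose(values, values.T):
        raise ValueError("graph must be symmetric (undirected)")
    if (values < 0).any():
        raise ValueError("edge weights must be non-negative")


# Toy example: three cells connected along a path a - b - c.
cells = ["a", "b", "c"]
toy = pd.DataFrame(
    [[0.0, 1.0, 0.0], [1.0, 0.0, 0.5], [0.0, 0.5, 0.0]],
    index=cells,
    columns=cells,
)
check_graph_df(toy)  # passes silently
```

With the objects defined above, `check_graph_df(df_graph)` could be called right before `set_graph_xy()`.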

In [14]:
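ap = ap.prepare(
    batch_key="batch",
    policy="sequential",
    normalize_spatial=False,
)
# Replace the linear term of the ("0", "1") subproblem with the custom graph;
# the ("1", "2") subproblem keeps the PCA-based linear term.
ap[("0", "1")].set_graph_xy(df_graph, cost="geodesic")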

[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  
[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  


In [15]:
ap[("0", "1")].xy.tag

<Tag.GRAPH: 'graph'>

In [16]:
ap.solve()
ap.solutions

[34mINFO    [0m Solving `[1;36m2[0m` problems                                                                                      
[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m400[0m, [1;36m400[0m[1m)[0m[1m][0m.                                            
[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m400[0m, [1;36m400[0m[1m)[0m[1m][0m.                                            


{('1', '2'): OTTOutput[shape=(400, 400), cost=1.1171, converged=True],
 ('0', '1'): OTTOutput[shape=(400, 400), cost=1.0803, converged=True]}