# Passing callbacks in {meth}`~moscot.problems.space.AlignmentProblem.prepare`

In this example, we show how to use different callbacks.

The `callback` argument specifies which computation is run on {attr}`~anndata.AnnData.X` to obtain the joint cost.

:::{seealso}
- See {doc}`200_custom_cost_matrices` for an example of how to use custom cost matrices and how to pass `joint_attr`, `x_attr` and `y_attr` to the {meth}`~moscot.problems.generic.FGWProblem.prepare` method.
:::

## Imports and data loading

In [1]:
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

from moscot import datasets
from moscot.problems.space import AlignmentProblem

In [2]:
adata = datasets.sim_align()
adata

AnnData object with n_obs × n_vars = 1200 × 500
    obs: 'batch'
    uns: 'batch_colors'
    obsm: 'spatial'

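When `normalize_spatial=True` is passed, the spatial coordinates are normalized by standardizing them. As the log messages show, this is done for both `x` and `y` of each subproblem. Below, we check that the `x` coordinates of the subproblem `("1", "2")` have unit standard deviation.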
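Standardizing means centering the coordinates and dividing them by their standard deviation. The sketch below shows this on synthetic coordinates with plain numpy. `standardize` is a hypothetical helper, not a moscot function, and the use of one global scale for both axes is an assumption; moscot's own normalization may differ in detail. It also uses `np.isclose`, which is a more robust way than `== 1` to check a floating-point standard deviation.

```python
import numpy as np


def standardize(coords: np.ndarray) -> np.ndarray:
    """Hypothetical helper: center each axis, then divide by one global std."""
    centered = coords - coords.mean(axis=0)  # per-axis centering
    return centered / centered.std()  # single scale keeps the aspect ratio (assumption)


rng = np.random.default_rng(0)
spatial = rng.uniform(0.0, 500.0, size=(400, 2))  # synthetic stand-in for adata.obsm["spatial"]
normalized = standardize(spatial)

# Tolerance-based check instead of exact float equality
print(np.isclose(normalized.std(), 1.0))  # True
```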

In [3]:
ap = AlignmentProblem(adata=adata)
ap = ap.prepare(batch_key="batch", policy="sequential", normalize_spatial=True)

[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  


[34mINFO    [0m Normalizing spatial coordinates of `x`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `y`.                                                                   
[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  
[34mINFO    [0m Normalizing spatial coordinates of `x`.                                                                   
[34mINFO    [0m Normalizing spatial coordinates of `y`.                                                                   

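With `policy="sequential"`, each batch is paired with the next one, which is why every log message above appears twice and why the solutions further down are keyed by `('0', '1')` and `('1', '2')`. A minimal sketch of that pairing follows. `sequential_pairs` is a hypothetical helper, and taking the batches in the order of the categories of `adata.obs["batch"]` is an assumption; moscot's policy code may build the pairs differently.

```python
def sequential_pairs(batches):
    """Hypothetical helper: pair each batch with the next one, in the given order."""
    return list(zip(batches[:-1], batches[1:]))


print(sequential_pairs(["0", "1", "2"]))
# [('0', '1'), ('1', '2')]
```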

In [4]:
ap[("1", "2")].x.data_src.std() == 1

True
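
With `normalize_spatial=False`, the coordinates are not standardized, and the same check returns `False`.
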
In [5]:
ap = ap.prepare(batch_key="batch", policy="sequential", normalize_spatial=False)

[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  


[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  


In [6]:
ap[("1", "2")].x.data_src.std() == 1

False

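We can pass `callback="local-pca"` to compute a PCA on {attr}`~anndata.AnnData.X` for each subproblem and use it to get the joint cost. The log messages are the same as in the previous calls, which also computed a PCA (`n_comps=30`) for the joint term `xy`.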
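To make the idea concrete, here is a minimal numpy sketch of a joint PCA on two batches followed by a pairwise cost between the embeddings. `local_pca` and `sq_euclidean_cost` are hypothetical helpers, not moscot functions. Centering, the SVD-based PCA and the squared Euclidean cost are assumptions, and `n_comps=30` only mirrors the log message.

```python
import numpy as np


def local_pca(x_src: np.ndarray, x_tgt: np.ndarray, n_comps: int = 30):
    """Hypothetical joint PCA: fit on both batches together, then split the embedding."""
    joint = np.vstack([x_src, x_tgt])  # stack the cells of both batches
    joint = joint - joint.mean(axis=0)  # center each gene
    _, _, vt = np.linalg.svd(joint, full_matrices=False)  # rows of vt are principal axes
    emb = joint @ vt[:n_comps].T  # project onto the top components
    return emb[: len(x_src)], emb[len(x_src):]


def sq_euclidean_cost(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise squared Euclidean distances between the rows of a and b."""
    sq_a = (a**2).sum(axis=1)[:, None]
    sq_b = (b**2).sum(axis=1)[None, :]
    return np.maximum(sq_a + sq_b - 2.0 * a @ b.T, 0.0)  # clip tiny negative round-off


rng = np.random.default_rng(0)
x_src = rng.poisson(1.0, size=(400, 500)).astype(float)  # synthetic stand-in for batch "0"
x_tgt = rng.poisson(1.0, size=(400, 500)).astype(float)  # synthetic stand-in for batch "1"
emb_src, emb_tgt = local_pca(x_src, x_tgt)
cost = sq_euclidean_cost(emb_src, emb_tgt)
print(emb_src.shape, cost.shape)  # (400, 30) (400, 400)
```

Fitting the PCA on both batches at once places their cells in one shared space, so distances between source and target cells are comparable.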

In [7]:
ap = ap.prepare(
    batch_key="batch",
    policy="sequential",
    normalize_spatial=False,
    callback="local-pca",
)

[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  
[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  


In [8]:
ap.solve()
ap.solutions

[34mINFO    [0m Solving `[1;36m2[0m` problems                                                                                      
[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m400[0m, [1;36m400[0m[1m)[0m[1m][0m.                                            


No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m400[0m, [1;36m400[0m[1m)[0m[1m][0m.                                            


{('0', '1'): OTTOutput[shape=(400, 400), cost=1.0932, converged=True],
 ('1', '2'): OTTOutput[shape=(400, 400), cost=1.1171, converged=True]}

We can also pass `callback="spatial-norm"`.

In [9]:
ap = ap.prepare(
    batch_key="batch",
    policy="sequential",
    normalize_spatial=False,
    callback="spatial-norm",
)

[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  
[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  


In [10]:
ap.solve()
ap.solutions

[34mINFO    [0m Solving `[1;36m2[0m` problems                                                                                      
[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m400[0m, [1;36m400[0m[1m)[0m[1m][0m.                                            
[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m400[0m, [1;36m400[0m[1m)[0m[1m][0m.                                            


{('0', '1'): OTTOutput[shape=(400, 400), cost=1.0932, converged=True],
 ('1', '2'): OTTOutput[shape=(400, 400), cost=1.1171, converged=True]}

To use graphs, we can pass `callback="graph-construction"`.
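As a generic illustration of graph-based costs, the sketch below builds a symmetric k-nearest-neighbor graph and uses shortest-path (geodesic) lengths on it as the cost. `knn_graph` and `geodesic_cost` are hypothetical helpers, and the choice of kNN edges, Euclidean edge weights and Floyd-Warshall shortest paths is an assumption; moscot's own callback may construct the graph and compute the cost differently.

```python
import numpy as np


def knn_graph(x: np.ndarray, k: int = 5) -> np.ndarray:
    """Dense symmetric kNN adjacency with Euclidean edge weights (np.inf = no edge)."""
    d = np.sqrt(((x[:, None, :] - x[None, :, :]) ** 2).sum(axis=-1))
    d_no_self = d.copy()
    np.fill_diagonal(d_no_self, np.inf)  # a point is not its own neighbor
    nn = np.argsort(d_no_self, axis=1)[:, :k]  # indices of the k nearest neighbors
    n = len(x)
    rows = np.repeat(np.arange(n), k)
    cols = nn.ravel()
    w = np.full((n, n), np.inf)
    w[rows, cols] = d[rows, cols]
    w = np.minimum(w, w.T)  # keep an edge if either endpoint selected it
    np.fill_diagonal(w, 0.0)
    return w


def geodesic_cost(w: np.ndarray) -> np.ndarray:
    """All-pairs shortest-path lengths via Floyd-Warshall."""
    dist = w.copy()
    for m in range(dist.shape[0]):
        dist = np.minimum(dist, dist[:, [m]] + dist[[m], :])
    return dist


# Synthetic example: 50 points on a half circle of radius 1
theta = np.linspace(0.0, np.pi, 50)
points = np.column_stack([np.cos(theta), np.sin(theta)])
cost = geodesic_cost(knn_graph(points, k=2))
print(cost[0, -1])  # follows the arc, so it exceeds the straight-line distance of 2
```

The geodesic cost follows the shape of the data: the two ends of the half circle are 2 apart in a straight line, but the graph path between them runs along the arc and is close to its length of pi.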

In [11]:
ap = ap.prepare(
    batch_key="batch",
    policy="sequential",
    normalize_spatial=False,
    callback="graph-construction",
)

[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  
[34mINFO    [0m Computing pca with `[33mn_comps[0m=[1;36m30[0m` for `xy` using `adata.X`                                                  


In [12]:
ap.solve()
ap.solutions

[34mINFO    [0m Solving `[1;36m2[0m` problems                                                                                      
[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m400[0m, [1;36m400[0m[1m)[0m[1m][0m.                                            
[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m400[0m, [1;36m400[0m[1m)[0m[1m][0m.                                            


{('0', '1'): OTTOutput[shape=(400, 400), cost=1.0932, converged=True],
 ('1', '2'): OTTOutput[shape=(400, 400), cost=1.1171, converged=True]}