# Tagged Arrays


In this example, we introduce the {class}`~moscot.solvers.TaggedArray`.

{class}`~moscot.solvers.TaggedArray` stores the data provided by the user in a unified format before it is handed to the backend.
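The idea can be made concrete with a minimal, hypothetical stand-in written with the standard library and NumPy. It only mirrors the four attributes discussed in this example (`data_src`, `data_tgt`, `cost`, `tag`) and the three tag values. It is not moscot's implementation, and the names `Tag`, `TaggedArraySketch`, and the cost string `"sq_euclidean"` below are illustrative assumptions.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional

import numpy as np


class Tag(str, Enum):
    """Hypothetical stand-in for the kinds of data a tagged array can hold."""

    COST_MATRIX = "cost_matrix"
    KERNEL = "kernel"
    POINT_CLOUD = "point_cloud"


@dataclass
class TaggedArraySketch:
    """Hypothetical container: raw data plus a tag saying how to interpret it."""

    data_src: np.ndarray
    data_tgt: Optional[np.ndarray] = None
    tag: Tag = Tag.POINT_CLOUD
    cost: Optional[str] = "sq_euclidean"  # illustrative cost name, not moscot's

    @property
    def is_point_cloud(self) -> bool:
        # Point clouds mean the backend has to compute the cost itself.
        return self.tag == Tag.POINT_CLOUD


sketch_rng = np.random.default_rng(0)
# Linear term: point clouds of the source (20 cells) and target (15 cells) in 5 dimensions.
linear = TaggedArraySketch(
    data_src=sketch_rng.normal(size=(20, 5)),
    data_tgt=sketch_rng.normal(size=(15, 5)),
)
print(linear.tag.value, linear.data_src.shape, linear.data_tgt.shape)
```

Subclassing `str` makes each tag compare equal to its plain string value. This is convenient when tags arrive as user input such as `tag="cost_matrix"`.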

In [2]:
from moscot.datasets import simulate_data
from moscot.problems.generic import GWProblem
import scanpy as sc

import numpy as np
import pandas as pd

In [9]:
rng = np.random.default_rng(seed=42)
np.set_printoptions(threshold=2)

adata = simulate_data(n_distributions=2, key="batch")
sc.pp.pca(adata)
adata

AnnData object with n_obs × n_vars = 40 × 60
    obs: 'batch', 'celltype'
    uns: 'pca'
    obsm: 'X_pca'
    varm: 'PCs'

We instantiate and prepare a {class}`~moscot.problems.generic.GWProblem` to demonstrate the role of the {class}`~moscot.solvers.TaggedArray`.

In [4]:
gwp = GWProblem(adata)
gwp = gwp.prepare(key="batch", joint_attr="X_pca", GW_x="X_pca", GW_y="X_pca")

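The {class}`~moscot.problems.base.OTProblem` (here, the subproblem `gwp["0", "1"]`) has the attributes {attr}`~moscot.problems.base.OTProblem.xy`, {attr}`~moscot.problems.base.OTProblem.x`, and {attr}`~moscot.problems.base.OTProblem.y`. {attr}`~moscot.problems.base.OTProblem.xy` stores the data for the linear term. {attr}`~moscot.problems.base.OTProblem.x` and {attr}`~moscot.problems.base.OTProblem.y` store the data for the quadratic term of the source and the target distribution, respectively. All three are {class}`TaggedArrays <moscot.solvers.TaggedArray>`.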
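The objective below is the standard fused Gromov-Wasserstein formulation, included only to show where the three arrays enter; moscot's solvers may add entropic regularization or weight the terms differently. The symbols are:

- $C^{xy}$: the cost built from `xy`.
- $C^{x}$ and $C^{y}$: the intra-distribution costs built from `x` and `y`.
- $P$: the coupling, with marginals $a$ and $b$.
- $\alpha \in [0, 1]$: the weight of the quadratic term relative to the linear term.

$$
\min_{P \in U(a, b)} \; (1 - \alpha) \sum_{i,j} C^{xy}_{ij} P_{ij} \;+\; \alpha \sum_{i,j,k,l} \left( C^{x}_{ik} - C^{y}_{jl} \right)^2 P_{ij} P_{kl}
$$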

In [13]:
gwp["0", "1"].xy

TaggedArray(data_src=ArrayView([[ 4.718115  , -0.16162483, -1.0284393 , ...,  0.5209566 ,
            -0.0388568 , -0.04343414],
           [-1.6827182 ,  0.75426227, -2.1140182 , ..., -0.348197  ,
            -0.4980026 , -0.0512526 ],
           [ 0.11928   , -1.9083179 ,  2.5255933 , ..., -0.53636247,
            -0.6098449 , -0.2735412 ],
           ...,
           [-1.7637906 ,  2.95762   , -0.06531301, ..., -0.65913224,
            -0.0113457 ,  0.06709672],
           [-1.0875723 ,  3.231092  , -0.98423856, ..., -0.5448444 ,
            -0.2618213 , -0.01479931],
           [-1.4714718 ,  0.1015937 , -0.9966435 , ...,  0.23407328,
             0.04770787, -0.02073438]], dtype=float32), data_tgt=ArrayView([[-1.0755012 , -3.5715468 , -0.31484002, ...,  0.49174863,
            -0.22571746,  0.1941065 ],
           [ 0.8515434 , -1.0328988 ,  3.8338466 , ...,  0.2522054 ,
            -0.11483867,  0.24348386],
           [-0.4111828 , -2.6891477 , -1.8634194 , ..., -0.3468407 ,
    

Each {class}`~moscot.solvers.TaggedArray` has attributes
{attr}`~moscot.solvers.TaggedArray.data_src`, {attr}`~moscot.solvers.TaggedArray.data_tgt`,
{attr}`~moscot.solvers.TaggedArray.cost`, and {attr}`~moscot.solvers.TaggedArray.tag`.

In [14]:
gwp["0", "1"].xy.tag, gwp["0", "1"].xy.cost

('point_cloud', <ott.geometry.costs.SqEuclidean at 0x28d274700>)

The {attr}`~moscot.solvers.TaggedArray.tag` is of type {class}`~moscot.solvers.Tag` and
defines what kind of data is stored in the {class}`~moscot.solvers.TaggedArray`.
Possible tags are `"cost_matrix"`, `"kernel"`, and `"point_cloud"`. Whenever the `tag` is `"point_cloud"`,
the backend is expected to compute the cost on the fly. Because only the point clouds are stored,
not the full cost matrix, this often reduces the memory complexity from quadratic to linear in the
number of cells and is hence advisable. In this case, {attr}`~moscot.solvers.TaggedArray.cost`
specifies which cost to compute from the point clouds.
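The memory argument can be checked with a rough back-of-the-envelope sketch. It assumes float32 data, $n$ source and $m$ target points in $d$ dimensions, and a squared Euclidean cost. The point clouds then need $(n + m) \cdot d$ numbers, while a materialized cost matrix needs $n \cdot m$. The helper `cost_row` is hypothetical; it only shows how a backend could compute one row of the cost at a time instead of the full matrix.

```python
import numpy as np

n, m, d = 10_000, 12_000, 30
bytes_per_float = np.dtype(np.float32).itemsize  # 4 bytes

point_cloud_bytes = (n + m) * d * bytes_per_float  # grows linearly in n + m
cost_matrix_bytes = n * m * bytes_per_float  # grows with n * m
print(f"point clouds: {point_cloud_bytes / 1e6:.1f} MB")  # 2.6 MB
print(f"cost matrix:  {cost_matrix_bytes / 1e6:.1f} MB")  # 480.0 MB


def cost_row(x_src: np.ndarray, y_tgt: np.ndarray, i: int) -> np.ndarray:
    """Hypothetical helper: squared Euclidean cost from source point i to every target point."""
    diff = y_tgt - x_src[i]  # broadcasting: (m, d) - (d,) -> (m, d)
    return np.sum(diff**2, axis=1)  # (m,)


sketch_rng = np.random.default_rng(0)
x = sketch_rng.normal(size=(n, d)).astype(np.float32)
y = sketch_rng.normal(size=(m, d)).astype(np.float32)
row = cost_row(x, y, 0)  # only m cost values are held in memory at a time
print(row.shape)
```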

If the {class}`~moscot.solvers.TaggedArray` corresponds to a linear term,
{attr}`~moscot.solvers.TaggedArray.data_src` and {attr}`~moscot.solvers.TaggedArray.data_tgt`
contain the point clouds of the source and the target distribution, respectively.
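For a linear term, both point clouds are needed because the cost compares points across the two distributions. The minimal sketch below assumes the squared Euclidean cost shown above for `xy`, so entry $(i, j)$ is $\lVert x_i - y_j \rVert^2$. The function `sq_euclidean_cost` is hypothetical. It uses the expansion $\lVert x \rVert^2 + \lVert y \rVert^2 - 2 \langle x, y \rangle$ to build the $n \times m$ matrix with a single matrix product, and checks the result against direct broadcasting.

```python
import numpy as np


def sq_euclidean_cost(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Hypothetical helper: squared Euclidean distances between rows of x (n, d) and y (m, d)."""
    x_sq = np.sum(x**2, axis=1)[:, None]  # (n, 1)
    y_sq = np.sum(y**2, axis=1)[None, :]  # (1, m)
    cost = x_sq + y_sq - 2.0 * x @ y.T  # (n, m)
    return np.maximum(cost, 0.0)  # clip tiny negative values caused by round-off


sketch_rng = np.random.default_rng(1)
data_src = sketch_rng.normal(size=(20, 5))  # source point cloud
data_tgt = sketch_rng.normal(size=(15, 5))  # target point cloud

cost_xy = sq_euclidean_cost(data_src, data_tgt)
# Reference result via broadcasting: (20, 1, 5) - (1, 15, 5) -> (20, 15, 5).
direct = np.sum((data_src[:, None, :] - data_tgt[None, :, :]) ** 2, axis=-1)
print(cost_xy.shape, np.allclose(cost_xy, direct))
```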

In [16]:
print(type(gwp["0", "1"].xy.data_src))
print(type(gwp["0", "1"].xy.data_tgt))

<class 'anndata._core.views.ArrayView'>
<class 'anndata._core.views.ArrayView'>


If the {class}`~moscot.solvers.TaggedArray` corresponds to a quadratic term, the cost
is computed pairwise between points of the same distribution. Hence,
{attr}`~moscot.solvers.TaggedArray.data_tgt` is `None`.

In [None]:
print(type(gwp["0", "1"].x.data_src))
print(type(gwp["0", "1"].x.data_tgt))
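For a quadratic term, a single point cloud suffices, and the resulting cost matrix is square, symmetric, and zero on the diagonal. The helper `pairwise_sq_euclidean` below is a hypothetical sketch that assumes a squared Euclidean cost. It illustrates why there is nothing to store in `data_tgt`.

```python
import numpy as np


def pairwise_sq_euclidean(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: squared Euclidean distances between all pairs of rows of x (n, d)."""
    diff = x[:, None, :] - x[None, :, :]  # (n, n, d)
    return np.sum(diff**2, axis=-1)  # (n, n)


sketch_rng = np.random.default_rng(2)
data_src = sketch_rng.normal(size=(20, 5))  # the only array the quadratic term needs
data_tgt = None  # nothing to pair with: costs are computed within one distribution

cost_x = pairwise_sq_euclidean(data_src)
print(cost_x.shape, np.allclose(cost_x, cost_x.T), np.allclose(np.diag(cost_x), 0.0))
```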

Whenever the `tag` is `"cost_matrix"`, the backend expects a precomputed cost matrix.
There are two cases to distinguish. First, the user may want to pass a custom cost
matrix directly, see for example
{ref}`sphx_glr_auto_examples_problems_ex_passing_custom_cost_matrices.py`. In this case,
{attr}`~moscot.solvers.TaggedArray.cost` must be set to `"custom"`. When a custom
cost matrix is set, e.g. via {meth}`~moscot.problems.base.OTProblem.set_xy`, the
{class}`~moscot.solvers.TaggedArray` changes its {attr}`~moscot.solvers.TaggedArray.tag`.
Before the custom cost matrix is set, the tag is still `"point_cloud"` and
{attr}`~moscot.solvers.TaggedArray.data_tgt` is not `None`, as it contains
the point cloud of the target distribution.
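The minimal sketch below shows the kind of switch described above on a hypothetical stand-in. A tagged array holding two point clouds is replaced by one holding a precomputed $n \times m$ matrix, with the tag set to `"cost_matrix"`, the cost set to `"custom"`, and no target array. The names `TaggedArraySketch` and `with_custom_cost` are illustrative. Storing the matrix in `data_src` is an assumption of this sketch, not a description of moscot's internals.

```python
from dataclasses import dataclass, replace
from typing import Optional

import numpy as np


@dataclass(frozen=True)
class TaggedArraySketch:
    """Hypothetical stand-in with the four attributes discussed in this example."""

    data_src: np.ndarray
    data_tgt: Optional[np.ndarray]
    tag: str
    cost: Optional[str]


def with_custom_cost(ta: TaggedArraySketch, cost_matrix: np.ndarray) -> TaggedArraySketch:
    """Hypothetical helper: swap the point clouds for a precomputed (n, m) cost matrix."""
    n, m = ta.data_src.shape[0], ta.data_tgt.shape[0]
    if cost_matrix.shape != (n, m):
        raise ValueError(f"expected shape {(n, m)}, got {cost_matrix.shape}")
    # The matrix now lives in data_src; there is no separate target array any more.
    return replace(ta, data_src=cost_matrix, data_tgt=None, tag="cost_matrix", cost="custom")


sketch_rng = np.random.default_rng(3)
before = TaggedArraySketch(
    data_src=sketch_rng.normal(size=(20, 5)),
    data_tgt=sketch_rng.normal(size=(15, 5)),
    tag="point_cloud",
    cost="sq_euclidean",  # illustrative cost name
)
after = with_custom_cost(before, np.abs(sketch_rng.normal(size=(20, 15))))
print(before.tag, before.data_tgt is None)  # point_cloud False
print(after.tag, after.data_tgt is None)  # cost_matrix True
```

The shape check guards against passing a matrix whose rows and columns do not match the source and target distributions.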

In [None]:
print(gwp["0", "1"].xy.tag)
print(type(gwp["0", "1"].xy.data_tgt))

We now construct a (random) custom cost matrix for the linear term.

In [None]:
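# Observation names of the source ("0") and target ("1") distributions.
obs_names_0 = gwp["0", "1"].adata_src.obs_names
obs_names_1 = gwp["0", "1"].adata_tgt.obs_names
# Random non-negative cost matrix of shape (n_source, n_target), labeled by cell names.
cost_linear_01 = np.abs(rng.normal(size=(len(obs_names_0), len(obs_names_1))))
cm_linear = pd.DataFrame(data=cost_linear_01, index=obs_names_0, columns=obs_names_1)

# Replace the point clouds of the linear term with the custom cost matrix.
gwp["0", "1"].set_xy(cm_linear, tag="cost_matrix")

print(gwp["0", "1"].xy.tag)
print(type(gwp["0", "1"].xy.data_tgt))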

Second, if the cost matrix is to be computed via a class in {mod}`moscot.costs`,
{attr}`~moscot.solvers.TaggedArray.cost` must be set to the corresponding {obj}`str`, see for example
{ref}`sphx_glr_auto_examples_problems_ex_use_leaf_distance.py`.
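A string is enough to identify such a cost if the library resolves names through a registry. The sketch below shows one common pattern for this. It is purely hypothetical: the registry, `register_cost`, `compute_cost`, and the names `"sq_euclidean"` and `"manhattan"` are illustrative and are not moscot's API.

```python
from typing import Callable, Dict

import numpy as np

CostFn = Callable[[np.ndarray, np.ndarray], np.ndarray]
_COSTS: Dict[str, CostFn] = {}  # hypothetical name -> cost function registry


def register_cost(name: str) -> Callable[[CostFn], CostFn]:
    """Hypothetical decorator that stores a cost function under a string name."""

    def decorator(fn: CostFn) -> CostFn:
        _COSTS[name] = fn
        return fn

    return decorator


@register_cost("sq_euclidean")
def _sq_euclidean(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    return np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)


@register_cost("manhattan")
def _manhattan(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    return np.sum(np.abs(x[:, None, :] - y[None, :, :]), axis=-1)


def compute_cost(name: str, x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Look up the cost by name and build the (n, m) cost matrix."""
    try:
        fn = _COSTS[name]
    except KeyError:
        raise ValueError(f"unknown cost {name!r}; available: {sorted(_COSTS)}") from None
    return fn(x, y)


x = np.array([[0.0, 0.0], [1.0, 1.0]])
y = np.array([[1.0, 0.0]])
print(compute_cost("manhattan", x, y))
```

Keeping only the name in the tagged array defers the expensive computation to the backend. New costs can then be added without changing the container.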