# Tagged arrays

This example shows how to use the {class}`~moscot.utils.tagged_array.TaggedArray`.

{class}`~moscot.utils.tagged_array.TaggedArray` stores user-provided data in a unified format before it is passed to the backend.

:::{seealso}
- See {doc}`600_leaf_distance` on how to use a lineage tree to compute leaf distances.
- See {doc}`200_custom_cost_matrices` on how to pass precomputed cost matrices.
- See {doc}`../solvers/300_quad_problems_basic` for an introduction on how to solve quadratic problems.
- See {doc}`../solvers/400_quad_problems_advanced` for an advanced example on how to solve quadratic problems.
:::

## Imports and data loading

In [1]:
import warnings

warnings.simplefilter("ignore", FutureWarning)

from moscot import datasets
from moscot.problems.generic import FGWProblem

import numpy as np
import pandas as pd

import scanpy as sc

np.set_printoptions(threshold=1, precision=3)

Simulate data using {func}`~moscot.datasets.simulate_data`.

In [2]:
adata = datasets.simulate_data(n_distributions=2, key="batch")
sc.pp.pca(adata)
adata

AnnData object with n_obs × n_vars = 40 × 60
    obs: 'batch', 'celltype'
    uns: 'pca'
    obsm: 'X_pca'
    varm: 'PCs'

## Prepare the problem

We instantiate and prepare a {class}`~moscot.problems.generic.FGWProblem` to demonstrate the role of the {class}`~moscot.utils.tagged_array.TaggedArray`.

In [3]:
fgw = FGWProblem(adata)
fgw = fgw.prepare(key="batch", x_attr="X_pca", y_attr="X_pca", joint_attr="X_pca")

The {class}`~moscot.base.problems.OTProblem` has attributes {attr}`~moscot.base.problems.OTProblem.xy`, {attr}`~moscot.base.problems.OTProblem.x`, and {attr}`~moscot.base.problems.OTProblem.y`. {attr}`~moscot.base.problems.OTProblem.xy` stores the data for the linear term, while {attr}`~moscot.base.problems.OTProblem.x` and {attr}`~moscot.base.problems.OTProblem.y` store the data for the two quadratic terms. These attributes are all {class}`TaggedArrays <moscot.utils.tagged_array.TaggedArray>`.

In [4]:
fgw[("0", "1")].xy

TaggedArray(data_src=ArrayView([[ 4.718, -0.162, -1.028, ...,  0.521, -0.039, -0.043],
           [-1.683,  0.754, -2.114, ..., -0.348, -0.498, -0.051],
           [ 0.119, -1.908,  2.526, ..., -0.536, -0.61 , -0.274],
           ...,
           [-1.764,  2.958, -0.065, ..., -0.659, -0.011,  0.067],
           [-1.088,  3.231, -0.984, ..., -0.545, -0.262, -0.015],
           [-1.471,  0.102, -0.997, ...,  0.234,  0.048, -0.021]],
          dtype=float32), data_tgt=ArrayView([[-1.076, -3.572, -0.315, ...,  0.492, -0.226,  0.194],
           [ 0.852, -1.033,  3.834, ...,  0.252, -0.115,  0.243],
           [-0.411, -2.689, -1.863, ..., -0.347,  0.601,  0.005],
           ...,
           [ 2.024, -1.597, -0.591, ..., -0.114,  0.29 , -0.332],
           [-0.938,  2.426, -0.128, ...,  0.194,  0.829, -0.438],
           [ 2.709, -2.885, -0.925, ...,  0.315,  0.334,  0.078]],
          dtype=float32), tag=<Tag.POINT_CLOUD: 'point_cloud'>, cost=<ott.geometry.costs.SqEuclidean object at 0x00000

## Attributes

Each {class}`~moscot.utils.tagged_array.TaggedArray` has attributes {attr}`~moscot.utils.tagged_array.TaggedArray.data_src`, {attr}`~moscot.utils.tagged_array.TaggedArray.data_tgt`, {attr}`~moscot.utils.tagged_array.TaggedArray.cost`, and {attr}`~moscot.utils.tagged_array.TaggedArray.tag`.

In [5]:
fgw["0", "1"].xy.tag, fgw["0", "1"].xy.cost

(<Tag.POINT_CLOUD: 'point_cloud'>,
 <ott.geometry.costs.SqEuclidean at 0x2095006ed50>)

The {attr}`~moscot.utils.tagged_array.TaggedArray.tag` attribute is of type {class}`~moscot.utils.tagged_array.Tag` and defines what kind of data is stored in the {class}`~moscot.utils.tagged_array.TaggedArray`.
Possible tags are {attr}`cost_matrix <moscot.utils.tagged_array.Tag.COST_MATRIX>`, {attr}`kernel <moscot.utils.tagged_array.Tag.KERNEL>`, and {attr}`point_cloud <moscot.utils.tagged_array.Tag.POINT_CLOUD>`. Whenever `tag='point_cloud'`,
the backend is expected to compute the cost on the fly. The {attr}`~moscot.utils.tagged_array.TaggedArray.cost` attribute should then specify which cost to compute from the point clouds.
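As a rough mental model, the interplay between the tag, the cost, and the two data slots can be sketched with a small stand-in container. The names `SketchTag`, `SketchTaggedArray`, and `needs_cost_computation` are hypothetical, and the cost string `"sq_euclidean"` is only an illustrative value; this is not moscot's implementation, just a sketch that mirrors the attributes described above.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional


class SketchTag(str, Enum):
    """Hypothetical stand-in mirroring the three tag values named above."""

    COST_MATRIX = "cost_matrix"
    KERNEL = "kernel"
    POINT_CLOUD = "point_cloud"


@dataclass
class SketchTaggedArray:
    """Hypothetical container with the four attributes described above."""

    data_src: Any
    data_tgt: Optional[Any] = None
    tag: SketchTag = SketchTag.POINT_CLOUD
    cost: Optional[str] = None

    @property
    def needs_cost_computation(self) -> bool:
        # Only point clouds leave the cost to be computed by the backend.
        return self.tag == SketchTag.POINT_CLOUD


# A point-cloud entry: the cost still has to be computed from the points.
pc = SketchTaggedArray(data_src=[[0.0, 1.0]], data_tgt=[[1.0, 0.0]], cost="sq_euclidean")

# A cost-matrix entry: the data already is the cost, so nothing is computed.
cm = SketchTaggedArray(data_src=[[0.0, 2.0], [2.0, 0.0]], tag=SketchTag.COST_MATRIX, cost="custom")
```

In this sketch, `pc.needs_cost_computation` is true and `cm.needs_cost_computation` is false, which is the distinction the tag is meant to communicate to the backend.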

If the {class}`~moscot.utils.tagged_array.TaggedArray` corresponds to a linear term,
{attr}`~moscot.utils.tagged_array.TaggedArray.data_src` and {attr}`~moscot.utils.tagged_array.TaggedArray.data_tgt`
contain the point clouds of the source and the target distribution, respectively.

In [6]:
fgw["0", "1"].xy.data_src, fgw["0", "1"].xy.data_tgt

(ArrayView([[ 4.718, -0.162, -1.028, ...,  0.521, -0.039, -0.043],
            [-1.683,  0.754, -2.114, ..., -0.348, -0.498, -0.051],
            [ 0.119, -1.908,  2.526, ..., -0.536, -0.61 , -0.274],
            ...,
            [-1.764,  2.958, -0.065, ..., -0.659, -0.011,  0.067],
            [-1.088,  3.231, -0.984, ..., -0.545, -0.262, -0.015],
            [-1.471,  0.102, -0.997, ...,  0.234,  0.048, -0.021]],
           dtype=float32),
 ArrayView([[-1.076, -3.572, -0.315, ...,  0.492, -0.226,  0.194],
            [ 0.852, -1.033,  3.834, ...,  0.252, -0.115,  0.243],
            [-0.411, -2.689, -1.863, ..., -0.347,  0.601,  0.005],
            ...,
            [ 2.024, -1.597, -0.591, ..., -0.114,  0.29 , -0.332],
            [-0.938,  2.426, -0.128, ...,  0.194,  0.829, -0.438],
            [ 2.709, -2.885, -0.925, ...,  0.315,  0.334,  0.078]],
           dtype=float32))

If the {class}`~moscot.utils.tagged_array.TaggedArray` corresponds to a quadratic term, the cost
will be computed pairwise between points of the same distribution. Hence,
{attr}`data_tgt is None <moscot.utils.tagged_array.TaggedArray.data_tgt>`.

In [7]:
fgw["0", "1"].x.data_src, fgw["0", "1"].x.data_tgt

(ArrayView([[ 4.718, -0.162, -1.028, ...,  0.521, -0.039, -0.043],
            [-1.683,  0.754, -2.114, ..., -0.348, -0.498, -0.051],
            [ 0.119, -1.908,  2.526, ..., -0.536, -0.61 , -0.274],
            ...,
            [-1.764,  2.958, -0.065, ..., -0.659, -0.011,  0.067],
            [-1.088,  3.231, -0.984, ..., -0.545, -0.262, -0.015],
            [-1.471,  0.102, -0.997, ...,  0.234,  0.048, -0.021]],
           dtype=float32),
 None)
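The difference between the two cases can be illustrated with a minimal, pure-Python sketch. It assumes a squared Euclidean cost, matching the `SqEuclidean` object shown in the output above; the helper names `sq_euclidean` and `cost_from_point_clouds` are hypothetical, and the actual computation is carried out by the backend, not by code like this.

```python
from typing import List, Optional, Sequence

Point = Sequence[float]


def sq_euclidean(a: Point, b: Point) -> float:
    """Squared Euclidean distance between two points of equal dimension."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def cost_from_point_clouds(
    data_src: Sequence[Point], data_tgt: Optional[Sequence[Point]] = None
) -> List[List[float]]:
    """Hypothetical helper: pairwise cost between two clouds, or within one."""
    # Without a target cloud, compare the source cloud with itself,
    # as is done for a quadratic term.
    other = data_src if data_tgt is None else data_tgt
    return [[sq_euclidean(p, q) for q in other] for p in data_src]


src = [[0.0, 0.0], [1.0, 0.0]]
tgt = [[0.0, 1.0], [2.0, 0.0], [1.0, 1.0]]

linear_cost = cost_from_point_clouds(src, tgt)  # 2 x 3: source vs. target
quadratic_cost = cost_from_point_clouds(src)    # 2 x 2: source vs. itself
```

The linear cost has one row per source point and one column per target point, whereas the quadratic cost is square with a zero diagonal because each point is compared with the points of its own distribution.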

## Modifying the tags

Whenever `tag='cost_matrix'`, the backend expects an instantiated cost matrix.
There are two different cases to distinguish. First, the user might want to
pass custom cost matrices directly, see {doc}`200_custom_cost_matrices` for more information. In this case, `cost='custom'` must be set.

When setting custom cost matrices, e.g., via {meth}`~moscot.base.problems.OTProblem.set_xy`, the
{class}`~moscot.utils.tagged_array.TaggedArray` changes its {attr}`~moscot.utils.tagged_array.TaggedArray.tag`.
Before setting the custom cost matrix, we still have `tag='point_cloud'` and
{attr}`data_tgt is not None <moscot.utils.tagged_array.TaggedArray.data_tgt>`, as it contains
the point cloud of the target distribution.

In [8]:
fgw["0", "1"].xy.tag, fgw["0", "1"].xy.data_tgt

(<Tag.POINT_CLOUD: 'point_cloud'>,
 ArrayView([[-1.076, -3.572, -0.315, ...,  0.492, -0.226,  0.194],
            [ 0.852, -1.033,  3.834, ...,  0.252, -0.115,  0.243],
            [-0.411, -2.689, -1.863, ..., -0.347,  0.601,  0.005],
            ...,
            [ 2.024, -1.597, -0.591, ..., -0.114,  0.29 , -0.332],
            [-0.938,  2.426, -0.128, ...,  0.194,  0.829, -0.438],
            [ 2.709, -2.885, -0.925, ...,  0.315,  0.334,  0.078]],
           dtype=float32))

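We now construct a random custom cost matrix for the linear term, indexed by the observation names of the source and target distributions, and set it via {meth}`~moscot.base.problems.OTProblem.set_xy`. Afterwards, the tag is `cost_matrix` and {attr}`data_tgt is None <moscot.utils.tagged_array.TaggedArray.data_tgt>`.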

In [9]:
rng = np.random.default_rng(seed=42)
obs_names_0 = fgw["0", "1"].adata_src.obs_names
obs_names_1 = fgw["0", "1"].adata_tgt.obs_names

cost_linear_01 = np.abs(rng.normal(size=(len(obs_names_0), len(obs_names_1))))
cm_linear = pd.DataFrame(data=cost_linear_01, index=obs_names_0, columns=obs_names_1)

fgw["0", "1"].set_xy(cm_linear, tag="cost_matrix")
fgw["0", "1"].xy.tag, fgw["0", "1"].xy.data_tgt

(<Tag.COST_MATRIX: 'cost_matrix'>, None)
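Because a custom cost matrix for the linear term is indexed by the source observations along its rows and the target observations along its columns, it can help to check its shape before setting it. The following is a minimal sketch of such a check in plain Python; the helper `check_cost_matrix` and the cell names are hypothetical, and these checks are an assumption about what is useful, not a description of what moscot validates internally.

```python
import math
from typing import Sequence


def check_cost_matrix(
    cost: Sequence[Sequence[float]],
    src_names: Sequence[str],
    tgt_names: Sequence[str],
) -> None:
    """Raise ValueError unless `cost` is a finite (n_src x n_tgt) matrix."""
    if len(cost) != len(src_names):
        raise ValueError(f"Expected {len(src_names)} rows, found {len(cost)}.")
    for i, row in enumerate(cost):
        if len(row) != len(tgt_names):
            raise ValueError(f"Row {i}: expected {len(tgt_names)} columns, found {len(row)}.")
        if not all(math.isfinite(v) for v in row):
            raise ValueError(f"Row {i} contains non-finite values.")


# Hypothetical observation names for a tiny source and target distribution.
src_names = ["src_cell_0", "src_cell_1"]
tgt_names = ["tgt_cell_0", "tgt_cell_1", "tgt_cell_2"]

cost = [[0.3, 1.2, 0.7], [2.1, 0.1, 0.9]]
check_cost_matrix(cost, src_names, tgt_names)  # passes silently
```

A matrix whose rows and columns are swapped, or one that contains `nan` or `inf` entries, would raise a `ValueError` here before it is handed to the problem.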