# Tagged arrays

This example introduces the {class}`~moscot.solvers.TaggedArray`.

{class}`~moscot.solvers.TaggedArray` stores user-provided data in a unified format before it is passed to the backend.

In [1]:
from moscot.datasets import simulate_data
from moscot.problems.generic import GWProblem
import scanpy as sc

import numpy as np
import pandas as pd

np.set_printoptions(threshold=1, precision=3)

In [3]:
adata = simulate_data(n_distributions=2, key="batch")
sc.pp.pca(adata)
adata

AnnData object with n_obs × n_vars = 40 × 60
    obs: 'batch', 'celltype'
    uns: 'pca'
    obsm: 'X_pca'
    varm: 'PCs'

We instantiate and prepare a {class}`~moscot.problems.generic.GWProblem` to demonstrate the role of the {class}`~moscot.solvers.TaggedArray`.

In [4]:
gwp = GWProblem(adata)
gwp = gwp.prepare(key="batch", joint_attr="X_pca", GW_x="X_pca", GW_y="X_pca")

Each {class}`~moscot.problems.base.OTProblem` has the attributes {attr}`~moscot.problems.base.OTProblem.xy`, {attr}`~moscot.problems.base.OTProblem.x`, and {attr}`~moscot.problems.base.OTProblem.y`. {attr}`~moscot.problems.base.OTProblem.xy` stores the data for the linear term, whereas {attr}`~moscot.problems.base.OTProblem.x` and {attr}`~moscot.problems.base.OTProblem.y` store the data for the quadratic term. All three attributes are {class}`TaggedArrays <moscot.solvers.TaggedArray>`.

In [5]:
gwp["0", "1"].xy

TaggedArray(data_src=ArrayView([[ 4.718, -0.162, -1.028, ...,  0.521, -0.039, -0.043],
           [-1.683,  0.754, -2.114, ..., -0.348, -0.498, -0.051],
           [ 0.119, -1.908,  2.526, ..., -0.536, -0.61 , -0.274],
           ...,
           [-1.764,  2.958, -0.065, ..., -0.659, -0.011,  0.067],
           [-1.088,  3.231, -0.984, ..., -0.545, -0.262, -0.015],
           [-1.471,  0.102, -0.997, ...,  0.234,  0.048, -0.021]],
          dtype=float32), data_tgt=ArrayView([[-1.076, -3.572, -0.315, ...,  0.492, -0.226,  0.194],
           [ 0.852, -1.033,  3.834, ...,  0.252, -0.115,  0.243],
           [-0.411, -2.689, -1.863, ..., -0.347,  0.601,  0.005],
           ...,
           [ 2.024, -1.597, -0.591, ..., -0.114,  0.29 , -0.332],
           [-0.938,  2.426, -0.128, ...,  0.194,  0.829, -0.438],
           [ 2.709, -2.885, -0.925, ...,  0.315,  0.334,  0.078]],
          dtype=float32), tag='point_cloud', cost=<ott.geometry.costs.SqEuclidean object at 0x29e7bb310>)

## Attributes

Each {class}`~moscot.solvers.TaggedArray` has attributes
{attr}`~moscot.solvers.TaggedArray.data_src`, {attr}`~moscot.solvers.TaggedArray.data_tgt`,
{attr}`~moscot.solvers.TaggedArray.cost`, and {attr}`~moscot.solvers.TaggedArray.tag`.

In [6]:
gwp["0", "1"].xy.tag, gwp["0", "1"].xy.cost

('point_cloud', <ott.geometry.costs.SqEuclidean at 0x29e7bb310>)

The {attr}`~moscot.solvers.TaggedArray.tag` attribute is of type {class}`~moscot.solvers.Tag` and
defines what kind of data is stored in the {class}`~moscot.solvers.TaggedArray`.
Possible tags are {attr}`cost_matrix <moscot.solvers.Tag.COST_MATRIX>`, {attr}`kernel <moscot.solvers.Tag.KERNEL>`, and {attr}`point_cloud <moscot.solvers.Tag.POINT_CLOUD>`. Whenever `tag="point_cloud"`,
the backend is expected to compute the cost on the fly. The {attr}`~moscot.solvers.TaggedArray.cost` attribute should then specify which cost to compute from the point clouds.
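
To make the role of the tag concrete, here is a minimal NumPy sketch of how a backend might treat a linear term depending on its tag. It is not the implementation used by moscot or its backend: the helper `resolve_cost` and its signature are hypothetical, and squared Euclidean is chosen only because it is the cost shown in the output above.

```python
import numpy as np


def resolve_cost(data_src, data_tgt=None, tag="point_cloud"):
    """Hypothetical helper: return a dense cost matrix for a linear term."""
    if tag == "cost_matrix":
        # The data already is the cost matrix, so nothing is computed.
        return np.asarray(data_src)
    if tag == "point_cloud":
        if data_tgt is None:
            raise ValueError("A linear term needs both point clouds.")
        x = np.asarray(data_src)
        y = np.asarray(data_tgt)
        # Squared Euclidean distance between every source/target pair.
        diff = x[:, None, :] - y[None, :, :]
        return np.sum(diff**2, axis=-1)
    raise ValueError(f"Unsupported tag: {tag!r}")


rng = np.random.default_rng(0)
src = rng.normal(size=(5, 3))  # 5 source points in 3 dimensions
tgt = rng.normal(size=(4, 3))  # 4 target points in 3 dimensions

# One row per source point, one column per target point.
cost_xy = resolve_cost(src, tgt, tag="point_cloud")
print(cost_xy.shape)
```

Passing `cost_xy` back in with `tag="cost_matrix"` returns it unchanged, which mirrors the distinction between the two tags described above.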

If the {class}`~moscot.solvers.TaggedArray` corresponds to a linear term,
{attr}`~moscot.solvers.TaggedArray.data_src` and {attr}`~moscot.solvers.TaggedArray.data_tgt`
contain the point clouds of the source and the target distribution, respectively.

In [7]:
gwp["0", "1"].xy.data_src, gwp["0", "1"].xy.data_tgt

(ArrayView([[ 4.718, -0.162, -1.028, ...,  0.521, -0.039, -0.043],
            [-1.683,  0.754, -2.114, ..., -0.348, -0.498, -0.051],
            [ 0.119, -1.908,  2.526, ..., -0.536, -0.61 , -0.274],
            ...,
            [-1.764,  2.958, -0.065, ..., -0.659, -0.011,  0.067],
            [-1.088,  3.231, -0.984, ..., -0.545, -0.262, -0.015],
            [-1.471,  0.102, -0.997, ...,  0.234,  0.048, -0.021]],
           dtype=float32),
 ArrayView([[-1.076, -3.572, -0.315, ...,  0.492, -0.226,  0.194],
            [ 0.852, -1.033,  3.834, ...,  0.252, -0.115,  0.243],
            [-0.411, -2.689, -1.863, ..., -0.347,  0.601,  0.005],
            ...,
            [ 2.024, -1.597, -0.591, ..., -0.114,  0.29 , -0.332],
            [-0.938,  2.426, -0.128, ...,  0.194,  0.829, -0.438],
            [ 2.709, -2.885, -0.925, ...,  0.315,  0.334,  0.078]],
           dtype=float32))

If the {class}`~moscot.solvers.TaggedArray` corresponds to a quadratic term, the cost
will be computed pairwise between points of the same distribution. Hence,
{attr}`data_tgt is None <moscot.solvers.TaggedArray.data_tgt>`.

In [8]:
gwp["0", "1"].x.data_src, gwp["0", "1"].x.data_tgt

(ArrayView([[ 4.718, -0.162, -1.028, ...,  0.521, -0.039, -0.043],
            [-1.683,  0.754, -2.114, ..., -0.348, -0.498, -0.051],
            [ 0.119, -1.908,  2.526, ..., -0.536, -0.61 , -0.274],
            ...,
            [-1.764,  2.958, -0.065, ..., -0.659, -0.011,  0.067],
            [-1.088,  3.231, -0.984, ..., -0.545, -0.262, -0.015],
            [-1.471,  0.102, -0.997, ...,  0.234,  0.048, -0.021]],
           dtype=float32),
 None)
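
As a rough illustration of why a single point cloud suffices for a quadratic term, the NumPy sketch below computes a pairwise cost within one distribution. The squared Euclidean cost and the array sizes are assumptions made for the example; the backend may compute this differently.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 3))  # one distribution: 6 points in 3 dimensions

# Pairwise squared Euclidean cost between points of the same distribution.
diff = x[:, None, :] - x[None, :, :]
cost_x = np.sum(diff**2, axis=-1)

# The result is square and symmetric, with zero cost from a point to itself.
is_square = cost_x.shape == (x.shape[0], x.shape[0])
is_symmetric = np.allclose(cost_x, cost_x.T)
has_zero_diagonal = np.allclose(np.diag(cost_x), 0.0)
print(is_square, is_symmetric, has_zero_diagonal)
```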

## Modifying tags

Whenever `tag="cost_matrix"`, the backend expects an instantiated cost matrix.
There are two cases to distinguish. First, the user might want to pass custom
cost matrices directly, see {doc}`002_custom_cost_matrices` for more information. In this case, `cost="custom"` must be set.

When setting custom cost matrices, e.g., via {meth}`~moscot.problems.base.OTProblem.set_xy`, the
{class}`~moscot.solvers.TaggedArray` will change its {attr}`~moscot.solvers.TaggedArray.tag`.
Before setting the custom cost matrix, we still have `tag="point_cloud"` and
{attr}`data_tgt is not None <moscot.solvers.TaggedArray.data_tgt>`, because it contains
the point cloud of the target distribution.

In [9]:
gwp["0", "1"].xy.tag, gwp["0", "1"].xy.data_tgt

('point_cloud',
 ArrayView([[-1.076, -3.572, -0.315, ...,  0.492, -0.226,  0.194],
            [ 0.852, -1.033,  3.834, ...,  0.252, -0.115,  0.243],
            [-0.411, -2.689, -1.863, ..., -0.347,  0.601,  0.005],
            ...,
            [ 2.024, -1.597, -0.591, ..., -0.114,  0.29 , -0.332],
            [-0.938,  2.426, -0.128, ...,  0.194,  0.829, -0.438],
            [ 2.709, -2.885, -0.925, ...,  0.315,  0.334,  0.078]],
           dtype=float32))

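We now construct a random custom cost matrix for the linear term and set it. Afterwards, `tag="cost_matrix"` and {attr}`~moscot.solvers.TaggedArray.data_tgt` is `None`.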

In [10]:
rng = np.random.default_rng(seed=42)
obs_names_0 = gwp["0", "1"].adata_src.obs_names
obs_names_1 = gwp["0", "1"].adata_tgt.obs_names

cost_linear_01 = np.abs(rng.normal(size=(len(obs_names_0), len(obs_names_1))))
cm_linear = pd.DataFrame(data=cost_linear_01, index=obs_names_0, columns=obs_names_1)

gwp["0", "1"].set_xy(cm_linear, tag="cost_matrix")
gwp["0", "1"].xy.tag, gwp["0", "1"].xy.data_tgt

('cost_matrix', None)
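
Because the custom cost matrix is a labelled {class}`pandas.DataFrame`, it can be worth checking its labels before setting it. The sketch below is one possible sanity check: `check_cost_frame` and the observation names are hypothetical, and the non-negativity check is an assumption that mirrors the `np.abs` call above. Whether moscot performs similar validation itself is not covered here.

```python
import numpy as np
import pandas as pd


def check_cost_frame(cost, src_names, tgt_names):
    """Hypothetical check: rows are source cells, columns are target cells."""
    if not cost.index.equals(pd.Index(src_names)):
        raise ValueError("Row labels do not match the source observations.")
    if not cost.columns.equals(pd.Index(tgt_names)):
        raise ValueError("Column labels do not match the target observations.")
    if (cost.to_numpy() < 0).any():
        raise ValueError("This sketch assumes non-negative costs.")
    return cost


# Made-up observation names standing in for `adata_src.obs_names` etc.
src_names = [f"src_{i}" for i in range(3)]
tgt_names = [f"tgt_{j}" for j in range(2)]

rng = np.random.default_rng(42)
cost = pd.DataFrame(
    np.abs(rng.normal(size=(len(src_names), len(tgt_names)))),
    index=src_names,
    columns=tgt_names,
)

checked = check_cost_frame(cost, src_names, tgt_names)
print(checked.shape)
```

A transposed matrix, with target cells in the rows, fails the first check and raises a `ValueError`.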

Second, if the cost matrix is to be computed via a class in `moscot.costs`, the
{attr}`~moscot.solvers.TaggedArray.cost` attribute must be set to the corresponding {class}`str`, see, e.g., {doc}`008_leaf_distance` for more information.