# Quadratic problems

This example shows how to solve quadratic problems, e.g., the {class}`~moscot.problems.time.LineageProblem`, the {class}`~moscot.problems.spatiotemporal.SpatioTemporalProblem`, the {class}`~moscot.problems.space.MappingProblem`, the {class}`~moscot.problems.space.AlignmentProblem`, the {class}`~moscot.problems.generic.GWProblem`, and the {class}`~moscot.problems.generic.FGWProblem`.

:::{seealso}
- See {doc}`400_quad_problems_advanced` for an advanced example on how to solve quadratic problems.
- See {doc}`100_linear_problems_basic` for an introduction on how to solve linear problems.
- See {doc}`200_linear_problems_advanced` for an advanced example on how to solve linear problems.
:::

## Imports and data loading

In [1]:
import warnings

warnings.simplefilter("ignore", FutureWarning)

from moscot import datasets
from moscot.problems.generic import FGWProblem, GWProblem

import numpy as np

import scanpy as sc

Simulate data using {func}`~moscot.datasets.simulate_data`.

In [2]:
adata = datasets.simulate_data(n_distributions=2, key="batch", quad_term="spatial")
sc.pp.pca(adata)
adata

AnnData object with n_obs × n_vars = 40 × 60
    obs: 'batch', 'celltype'
    uns: 'pca'
    obsm: 'spatial', 'X_pca'
    varm: 'PCs'

## Basic parameters

There are some parameters in quadratic problems which play the same role as in linear problems. Hence, we refer to {doc}`100_linear_problems_basic` for the role of `epsilon`, `tau_a`, and `tau_b`. In fused quadratic problems (also referred to as Fused Gromov-Wasserstein), there is an additional parameter `alpha`, which defines the convex combination between the quadratic term and the linear term, the latter being defined by `joint_attr`. Setting `alpha = 1` considers only the pure quadratic problem, ignoring `joint_attr`. Setting `alpha = 0` is not possible; in that case, a linear problem must be used instead.
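To make the role of `alpha` concrete, here is a minimal numpy sketch of the fused Gromov-Wasserstein objective with a square loss, evaluated for a fixed transport matrix. It is not moscot's implementation: the helper `fgw_objective`, the toy point clouds, and the cross-distribution cost `M` are hypothetical, and the weighting `alpha * quadratic + (1 - alpha) * linear` follows the convex combination described above.

```python
import numpy as np


def fgw_objective(C_x, C_y, M, P, alpha):
    """Hypothetical helper: fused GW objective (square loss) for a fixed coupling P.

    C_x: (n, n) cost within the source distribution (cf. `x_attr`).
    C_y: (m, m) cost within the target distribution (cf. `y_attr`).
    M:   (n, m) cost across the two distributions (cf. `joint_attr`).
    P:   (n, m) coupling, i.e. a transport matrix.
    """
    if not 0 < alpha <= 1:
        # Mirrors the text: alpha = 0 would leave only the linear term.
        raise ValueError("alpha must be in (0, 1]")
    # diff[i, j, k, l] = C_x[i, k] - C_y[j, l]
    diff = C_x[:, None, :, None] - C_y[None, :, None, :]
    # Quadratic term: sum over i, j, k, l of diff**2 * P[i, j] * P[k, l]
    quadratic = np.einsum("ijkl,ij,kl->", diff**2, P, P)
    # Linear term: <M, P>
    linear = np.sum(M * P)
    return alpha * quadratic + (1 - alpha) * linear


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2))  # hypothetical source coordinates
y = rng.normal(size=(4, 2))  # hypothetical target coordinates
C_x = ((x[:, None] - x[None]) ** 2).sum(-1)
C_y = ((y[:, None] - y[None]) ** 2).sum(-1)
M = rng.random((5, 4))  # hypothetical cross-distribution cost
P = np.full((5, 4), 1 / 20)  # independent coupling of two uniform marginals

print(fgw_objective(C_x, C_y, M, P, alpha=1.0))  # pure quadratic term, M is ignored
print(fgw_objective(C_x, C_y, M, P, alpha=0.5))  # equal weight on both terms
```

With `alpha=1.0` the value does not depend on `M` at all, which is the sense in which `joint_attr` is ignored.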

In [3]:
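# Pure Gromov-Wasserstein: only the spatial (quadratic) term is used
gwp = GWProblem(adata)
gwp = gwp.prepare(
    key="batch",
    x_attr={"attr": "obsm", "key": "spatial"},
    y_attr={"attr": "obsm", "key": "spatial"},
)
gwp = gwp.solve(epsilon=1e-1)

# Fused Gromov-Wasserstein: additionally uses a linear term computed on the PCA embedding
fgwp = FGWProblem(adata)
fgwp = fgwp.prepare(
    key="batch",
    x_attr={"attr": "obsm", "key": "spatial"},
    y_attr={"attr": "obsm", "key": "spatial"},
    joint_attr="X_pca",
)
# alpha=0.5 weighs the quadratic and the linear term equally
fgwp = fgwp.solve(epsilon=1e-1, alpha=0.5)

# Compare the transport matrices of the two solutions for the batch pair ("0", "1")
max_difference = np.max(
    np.abs(
        gwp["0", "1"].solution.transport_matrix
        - fgwp["0", "1"].solution.transport_matrix
    )
)
print(f"max difference: {max_difference:.6f}")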

[34mINFO    [0m Solving `[1;36m1[0m` problems                                                                                      
[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m20[0m, [1;36m20[0m[1m)[0m[1m][0m.                                              


No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


[34mINFO    [0m Solving `[1;36m1[0m` problems                                                                                      
[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'prepared'[0m, [33mshape[0m=[1m([0m[1;36m20[0m, [1;36m20[0m[1m)[0m[1m][0m.                                              
max difference: 0.021854

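As a rough picture of what `x_attr`, `y_attr`, and `joint_attr` feed into, the following numpy sketch builds the three cost matrices from per-batch arrays. It is an illustration under assumptions, not moscot's code: the arrays are random stand-ins for `adata.obsm["spatial"]`, `adata.obsm["X_pca"]`, and `adata.obs["batch"]`, the 2-D spatial layout is assumed, and squared Euclidean cost is assumed (moscot's default cost and any scaling may differ).

```python
import numpy as np


def sq_euclidean(a, b):
    """Pairwise squared Euclidean distances between the rows of a and b."""
    return ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)


rng = np.random.default_rng(0)
# Hypothetical stand-ins for adata.obsm["spatial"], adata.obsm["X_pca"], adata.obs["batch"]
n_obs = 40
spatial = rng.normal(size=(n_obs, 2))
x_pca = rng.normal(size=(n_obs, 30))
batch = np.repeat(np.array(["0", "1"]), n_obs // 2)

src, tgt = batch == "0", batch == "1"
# x_attr / y_attr: costs *within* each batch, from the spatial coordinates
C_x = sq_euclidean(spatial[src], spatial[src])  # (20, 20)
C_y = sq_euclidean(spatial[tgt], spatial[tgt])  # (20, 20)
# joint_attr: cost *across* batches, from a feature space shared by both batches
M = sq_euclidean(x_pca[src], x_pca[tgt])  # (20, 20)

# Rotating and shifting one batch leaves its intra-batch cost unchanged
theta = 0.7
rotation = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
moved = spatial[tgt] @ rotation.T + 5.0
C_y_moved = sq_euclidean(moved, moved)
print(np.max(np.abs(C_y_moved - C_y)))
```

The last lines show why the quadratic term suits spatial data: it only compares distances within each batch, so the two coordinate systems need not be aligned, whereas the linear term requires a space in which cells of both batches are directly comparable.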

## Low-rank solutions

Whenever the dataset is very large, the computational complexity can be
reduced by setting `rank` to a positive integer {cite}`scetbon:21a`, which
restricts the transport matrix to a low-rank factorization. In this case,
`epsilon` can also be set to $0$, while only the balanced case
(`tau_a = tau_b = 1`) is supported. Moreover, the data has to be provided
as point clouds, i.e., no precomputed cost matrix can be passed.
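To give an idea of what a low-rank transport matrix looks like, here is a numpy sketch of the factorization $P = Q \operatorname{diag}(1/g) R^\top$ used in the low-rank OT literature (Scetbon et al.), as we understand it. It only constructs a feasible coupling from random factors via Sinkhorn matrix scaling; it does not solve the optimization problem the solver addresses. The helper `scale_to_marginals`, the sizes, and the uniform marginals are assumptions for illustration.

```python
import numpy as np


def scale_to_marginals(K, row, col, n_iter=2000):
    """Hypothetical helper: Sinkhorn scaling of a positive matrix K so that
    diag(u) K diag(v) has row sums `row` and column sums `col`."""
    u = np.ones(K.shape[0])
    v = np.ones(K.shape[1])
    for _ in range(n_iter):
        u = row / (K @ v)
        v = col / (K.T @ u)
    return u[:, None] * K * v[None, :]


rng = np.random.default_rng(0)
n, m, r = 20, 20, 3  # r plays the role of `rank`
a = np.full(n, 1 / n)  # source marginal (balanced case)
b = np.full(m, 1 / m)  # target marginal
g = np.full(r, 1 / r)  # inner marginal shared by both factors

Q = scale_to_marginals(rng.random((n, r)) + 0.1, a, g)  # (n, r), rows sum to a
R = scale_to_marginals(rng.random((m, r)) + 0.1, b, g)  # (m, r), rows sum to b
P = Q @ np.diag(1 / g) @ R.T  # (n, m) coupling of rank at most r

print("rank of P:", np.linalg.matrix_rank(P))
print(n * m, "entries for the full matrix vs", (n + m + 1) * r, "for the factors")
```

Because the columns of both `Q` and `R` sum to `g`, the row and column sums of `P` reduce to those of `Q` and `R`, so `P` matches the marginals `a` and `b` exactly by construction, while only the factors ever need to be stored.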

In [4]:
gwp = gwp.solve(epsilon=1e-2, rank=3)

[34mINFO    [0m Solving `[1;36m1[0m` problems                                                                                      
[34mINFO    [0m Solving problem OTProblem[1m[[0m[33mstage[0m=[32m'solved'[0m, [33mshape[0m=[1m([0m[1;36m20[0m, [1;36m20[0m[1m)[0m[1m][0m.                                                


## Scaling the cost

The `scale_cost` parameter works the same way as for linear problems; see {doc}`100_linear_problems_basic` for more information. Note that all cost terms are scaled according to the same `scale_cost` argument.
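Here is a minimal numpy sketch of why scaling matters when a quadratic and a linear term are combined. It assumes that "the same argument" means the same scaling option is applied to every cost term, with each term divided by its own statistic; this is our reading, not confirmed by the text. The helper `scale_cost` and the option names `"mean"`, `"median"`, and `"max_cost"` are illustrative only and are not moscot's code.

```python
import numpy as np


def scale_cost(C, how):
    """Hypothetical helper: divide a cost matrix by a summary statistic of itself."""
    if how == "mean":
        return C / C.mean()
    if how == "median":
        return C / np.median(C)
    if how == "max_cost":
        return C / C.max()
    raise ValueError(f"unknown scaling {how!r}")


def sq_euclidean(a, b):
    """Pairwise squared Euclidean distances between the rows of a and b."""
    return ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)


rng = np.random.default_rng(0)
x = rng.normal(size=(20, 2)) * 100.0  # hypothetical spatial coordinates, large units
y = rng.normal(size=(20, 2)) * 100.0
feat_x, feat_y = rng.normal(size=(20, 30)), rng.normal(size=(20, 30))  # PCA-like features

costs = {
    "x": sq_euclidean(x, x),  # quadratic term, source
    "y": sq_euclidean(y, y),  # quadratic term, target
    "xy": sq_euclidean(feat_x, feat_y),  # linear term
}
# The same option is applied to every term; each is normalized by its own statistic
scaled = {name: scale_cost(C, "mean") for name, C in costs.items()}
for name in costs:
    print(name, costs[name].mean(), scaled[name].mean())
```

Before scaling, the spatial costs are several hundred times larger on average than the feature-based cost, so the balance set by `alpha` would be distorted; after scaling, all three terms have a mean of 1.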