In [None]:
%matplotlib inline

Solving linear problems
=======================


This notebook elaborates on how to solve linear problems, e.g. the
`moscot.problems.time.TemporalProblem`{.interpreted-text role="class"}
and the `moscot.problems.generic.SinkhornProblem`{.interpreted-text
role="class"}.


In [None]:
from moscot.datasets import simulate_data
from moscot.problems.generic import SinkhornProblem

import numpy as np

adata = simulate_data(n_distributions=2, key="day")
adata

The `moscot.problems.time.TemporalProblem.solve`{.interpreted-text
role="meth"} method has numerous arguments; a few of them are discussed
below.


Basic parameters
================

[epsilon]{.title-ref} is the entropic regularization parameter. The lower
[epsilon]{.title-ref}, the sparser the transport map, but the longer the
algorithm takes to converge. [tau\_a]{.title-ref} and
[tau\_b]{.title-ref} are the unbalancedness parameters for the source
and the target distribution, respectively; they control how strictly the
marginal constraints are enforced. [tau\_a = 1]{.title-ref} means the
source marginals have to be fully satisfied, while [0 \< tau\_a
\< 1]{.title-ref} relaxes this condition. Analogously,
[tau\_b]{.title-ref} affects the marginals of the target distribution.
We demonstrate the effect of [tau\_a]{.title-ref} and
[tau\_b]{.title-ref} with the
`moscot.problems.generic.SinkhornProblem`{.interpreted-text
role="class"}. Whenever the prior marginals [a]{.title-ref} and
[b]{.title-ref} of the source and the target distribution, respectively,
are not passed (TODO link to marginals notebook), they are set to be
uniform.


In [None]:
sp = SinkhornProblem(adata)
sp = sp.prepare(key="day")
print(sp[0, 1].a[:5], sp[0, 1].b[:5])

First, we solve the problem in a balanced manner, such that the
posterior marginals of the solution (the row sums of the transport
matrix for the source marginals and the column sums for the target
marginals) are equal to the prior marginals up to small errors (which
define the convergence criterion in the balanced case).


In [None]:
sp = sp.solve(epsilon=1e-2, tau_a=1, tau_b=1)
print(sp[0, 1].solution.a[:5], sp[0, 1].solution.b[:5])
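
To illustrate what a balanced solve computes, here is a minimal NumPy
sketch of entropic Sinkhorn iterations on toy data. It is not moscot's
implementation; the function `sinkhorn`, the helper `entropy`, the toy
point clouds, and the tolerance are assumptions made for illustration.
The sketch checks that the row and column sums of the plan match the
uniform prior marginals, and that a lower `epsilon` gives a more
concentrated (lower-entropy) plan.

```python
import numpy as np


def sinkhorn(C, a, b, epsilon, tol=1e-8, max_iter=100_000):
    """Balanced entropic OT via plain Sinkhorn iterations (illustrative only)."""
    K = np.exp(-C / epsilon)  # Gibbs kernel
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(max_iter):
        u = a / (K @ v)  # rescale rows towards the source marginals
        v = b / (K.T @ u)  # rescale columns to match the target marginals
        P = u[:, None] * K * v[None, :]
        # Columns match b right after the v-update, so only rows are checked.
        if np.abs(P.sum(axis=1) - a).max() < tol:
            break
    return P


def entropy(P):
    """Shannon entropy of a transport plan, ignoring exact zeros."""
    p = P[P > 0]
    return float(-(p * np.log(p)).sum())


rng = np.random.default_rng(0)
x = rng.normal(size=(20, 2))  # toy source points (assumption)
y = rng.normal(size=(30, 2)) + 0.5  # toy target points (assumption)
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean cost
C = C / C.mean()
a = np.full(20, 1 / 20)  # uniform prior marginals
b = np.full(30, 1 / 30)

plans = {eps: sinkhorn(C, a, b, eps) for eps in (1.0, 0.1)}
entropies = {eps: entropy(P) for eps, P in plans.items()}
for eps, P in plans.items():
    row_err = np.abs(P.sum(axis=1) - a).max()
    col_err = np.abs(P.sum(axis=0) - b).max()
    print(f"epsilon={eps}: row err={row_err:.1e}, col err={col_err:.1e}, "
          f"entropy={entropies[eps]:.3f}")
```

The stopping rule mirrors the idea stated above: in the balanced case,
the deviation of the posterior marginals from the prior marginals serves
as the convergence criterion.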

If we solve an unbalanced problem, the posterior marginals will differ
from the prior marginals.


In [None]:
sp = sp.solve(epsilon=1e-2, tau_a=0.9, tau_b=0.99)
print(sp[0, 1].solution.a[:5], sp[0, 1].solution.b[:5])
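
The effect of [tau\_a]{.title-ref} and [tau\_b]{.title-ref} can also be
sketched in plain NumPy. The snippet below is an illustration, not
moscot's solver: it assumes the common formulation in which the marginal
constraints are replaced by KL penalties of strength `rho` and
`tau = rho / (rho + epsilon)`, so that the Sinkhorn updates are simply
raised to the power `tau`. The function name `unbalanced_sinkhorn` and
the toy data are hypothetical.

```python
import numpy as np


def unbalanced_sinkhorn(C, a, b, epsilon, tau_a=1.0, tau_b=1.0, n_iter=2000):
    """Sinkhorn scaling with KL-relaxed marginals (illustrative only)."""
    K = np.exp(-C / epsilon)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = (a / (K @ v)) ** tau_a  # tau_a = 1 recovers the balanced update
        v = (b / (K.T @ u)) ** tau_b  # tau_b = 1 recovers the balanced update
    return u[:, None] * K * v[None, :]


rng = np.random.default_rng(0)
x = rng.normal(size=(20, 2))  # toy source points (assumption)
y = rng.normal(size=(30, 2)) + 0.5  # toy target points (assumption)
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
C = C / C.mean()
a = np.full(20, 1 / 20)  # uniform prior marginals
b = np.full(30, 1 / 30)

P_bal = unbalanced_sinkhorn(C, a, b, epsilon=0.1)
P_unbal = unbalanced_sinkhorn(C, a, b, epsilon=0.1, tau_a=0.9, tau_b=0.99)

for name, P in [("balanced", P_bal), ("unbalanced", P_unbal)]:
    src_dev = np.abs(P.sum(axis=1) - a).max()
    tgt_dev = np.abs(P.sum(axis=0) - b).max()
    print(f"{name}: max source deviation={src_dev:.2e}, "
          f"max target deviation={tgt_dev:.2e}, total mass={P.sum():.4f}")
```

In the unbalanced run, the posterior marginals are no longer forced to
equal `a` and `b`, so the total transported mass may also deviate from 1.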

Low-rank solutions
==================

Whenever the dataset is very large, the computational complexity can be
reduced by setting [rank]{.title-ref} to a positive integer
(`scetbon:21a`{.interpreted-text role="cite"}). In this case,
[epsilon]{.title-ref} can also be set to 0, but only the balanced case
([tau\_a = tau\_b = 1]{.title-ref}) is supported. The [rank]{.title-ref}
should be significantly smaller than the number of cells in both the
source and the target distribution.


In [None]:
sp = sp.solve(epsilon=0, rank=3, initializer="random")
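
To give an intuition for why a low rank saves resources, the sketch below
builds a coupling of the factored form `P = Q diag(1/g) R^T`, which is
our understanding of the structure used in the cited work. It does not
optimize any transport cost and is not what `solve` does internally; it
only shows that such a factorization has the right marginals, has rank
at most `rank`, and needs far fewer stored numbers than the full matrix.
The helper `project_to_coupling` and all sizes are assumptions made for
illustration.

```python
import numpy as np


def project_to_coupling(M, r, c, n_iter=1000):
    """Rescale a positive matrix so its row sums are r and column sums are c."""
    u, v = np.ones_like(r), np.ones_like(c)
    for _ in range(n_iter):
        u = r / (M @ v)
        v = c / (M.T @ u)
    return u[:, None] * M * v[None, :]


rng = np.random.default_rng(0)
n, m, rank = 200, 300, 3
a = np.full(n, 1 / n)  # source marginals
b = np.full(m, 1 / m)  # target marginals
g = np.full(rank, 1 / rank)  # weights of the shared low-dimensional marginal

# Q couples (a, g) and R couples (b, g); both are thin matrices.
Q = project_to_coupling(rng.uniform(0.5, 1.5, size=(n, rank)), a, g)
R = project_to_coupling(rng.uniform(0.5, 1.5, size=(m, rank)), b, g)

# The full coupling is materialized here only to verify its properties.
P = Q @ np.diag(1 / g) @ R.T

n_factored = Q.size + R.size + g.size
print("rank of P:", np.linalg.matrix_rank(P))
print("stored numbers (factored vs. full):", n_factored, "vs.", P.size)
```

Storing `Q`, `R`, and `g` takes `(n + m + 1) * rank` numbers instead of
`n * m`, which is why `rank` should stay well below the number of cells.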

Scaling the cost
================

[scale\_cost]{.title-ref} scales the cost matrix, which often helps the
algorithm converge. While any number can be passed, it is also possible
to scale the cost matrix by, e.g., its mean, median, or maximum. We
recommend using the [mean]{.title-ref}, as it can be computed without
instantiating the cost matrix and hence reduces computational
complexity. Moreover, it is more robust to outliers than, for example,
scaling by the maximum. Note that the solution of the optimal transport
problem is not invariant to the choice of scaling:


In [None]:
sp = sp.solve(epsilon=1e-2, scale_cost="mean")
tm_mean = sp[0, 1].solution.transport_matrix
print(tm_mean[:3, :3])

In [None]:
sp = sp.solve(epsilon=1e-2, scale_cost="max_cost")
tm_max = sp[0, 1].solution.transport_matrix
print(tm_max[:3, :3])

We can compute the correlation between the two flattened transport
matrices to get an idea of how much the choice of scaling influences
the solution.


In [None]:
correlation = np.corrcoef(tm_mean.flatten(), tm_max.flatten())[0, 1]
print(f"{correlation:.4f}")
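
One way to see why the scaling matters: in entropic optimal transport,
the cost enters only through `exp(-C / epsilon)`, so dividing the cost
by a constant `s` has the same effect as multiplying `epsilon` by `s`.
The NumPy sketch below illustrates this on toy data; it is not moscot's
implementation, and the function `sinkhorn` and the point clouds are
assumptions made for illustration.

```python
import numpy as np


def sinkhorn(C, a, b, epsilon, n_iter=2000):
    """Balanced entropic OT via plain Sinkhorn iterations (illustrative only)."""
    K = np.exp(-C / epsilon)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]


rng = np.random.default_rng(0)
x = rng.normal(size=(20, 2))  # toy source points (assumption)
y = rng.normal(size=(30, 2)) + 0.5  # toy target points (assumption)
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
a = np.full(20, 1 / 20)
b = np.full(30, 1 / 30)
eps = 0.1

P_mean = sinkhorn(C / C.mean(), a, b, eps)  # cost scaled by its mean
P_max = sinkhorn(C / C.max(), a, b, eps)  # cost scaled by its maximum
P_equiv = sinkhorn(C, a, b, eps * C.mean())  # unscaled cost, rescaled epsilon

corr = np.corrcoef(P_mean.ravel(), P_max.ravel())[0, 1]
print("mean scaling equals rescaled epsilon:", np.allclose(P_mean, P_equiv))
print(f"correlation between mean and max scaling: {corr:.4f}")
```

Since the maximum of a cost matrix is at least as large as its mean,
scaling by the maximum corresponds to a larger effective `epsilon` and
therefore to a more diffuse transport matrix.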

TODO See other examples for \...
