application of low-rank to single-cell data #83

Closed
zoepiran opened this issue Jun 23, 2022 · 5 comments
Labels
bug Something isn't working

Comments

@zoepiran
Contributor

TL;DR: in this Colab we provide an example of our failure to obtain a valid mapping using the low-rank solver.

Problem setup: in this example dataset we map single-cell-resolution spatial transcriptomics data from mouse embryonic tissues across two time points. The quadratic term therefore accounts for distances in spatial coordinates, and the linear term captures distances in gene expression.
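To make the two terms concrete, here is a minimal NumPy sketch that evaluates a fused GW objective for a fixed coupling. It assumes the convex-combination convention `(1 - alpha) * linear + alpha * quadratic` (as in the original FGW formulation) with a squared loss on the intra-domain costs; ott's own weighting of the two terms may differ. `fgw_objective` is a hypothetical helper, and all data are toy values.

```python
import numpy as np


def fgw_objective(Cx, Cy, M, P, alpha=0.5):
    """Fused Gromov-Wasserstein objective for a fixed coupling P (sketch).

    Cx: (n, n) intra-domain cost at time point 1 (e.g. spatial distances)
    Cy: (m, m) intra-domain cost at time point 2
    M:  (n, m) cross-domain cost (e.g. gene-expression distances)
    P:  (n, m) coupling
    alpha: weight of the quadratic term (assumed convex-combination convention)
    """
    a = P.sum(axis=1)  # first marginal of P
    b = P.sum(axis=0)  # second marginal of P
    # sum_{ijkl} (Cx[i,k] - Cy[j,l])**2 * P[i,j] * P[k,l], expanded into three terms
    quad = a @ (Cx ** 2) @ a + b @ (Cy ** 2) @ b - 2.0 * np.sum(P * (Cx @ P @ Cy.T))
    lin = np.sum(M * P)  # linear term <M, P>
    return (1.0 - alpha) * lin + alpha * quad


rng = np.random.default_rng(0)
xs, ys = rng.normal(size=(5, 2)), rng.normal(size=(7, 2))  # toy spatial coordinates
gx, gy = rng.normal(size=(5, 4)), rng.normal(size=(7, 4))  # toy expression profiles
Cx = np.linalg.norm(xs[:, None] - xs[None, :], axis=-1)  # quadratic term, time point 1
Cy = np.linalg.norm(ys[:, None] - ys[None, :], axis=-1)  # quadratic term, time point 2
M = np.linalg.norm(gx[:, None] - gy[None, :], axis=-1)   # linear term
P = np.full((5, 7), 1.0 / 35)  # independent coupling of uniform marginals
print(fgw_objective(Cx, Cy, M, P, alpha=0.5))
```

The expansion avoids materializing the `(n, m, n, m)` tensor, which matters at single-cell scale.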

Evaluation of the mapping: as an initial sanity check, we look at the cell-transition table, i.e. the transition matrix with entries grouped by cell type (we assess the row-stochastic, forward setting). The naive expectation is that cells of a given type, e.g. brain, will mainly be mapped to cells of the same type. For the regular FGW and unbalanced FGW this is indeed what we observe. For low-rank, however, we get a matrix with constant columns. We observed a similar phenomenon at other time points. Comparing the results, there are hints of these constant columns in the regular regime too: they correspond to cell types that are also favored there.
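A minimal sketch of how such a table can be computed from a coupling `P` (rows: source cells, columns: target cells), assuming cell-type labels are given as arrays; `cell_transition_table` is a hypothetical helper, not part of ott. It also illustrates one way constant columns can arise: if the coupling degenerates to the independent coupling `outer(a, b)`, every row of the row-normalized table is identical. Whether this is what happens in the low-rank solver here is not established by this sketch.

```python
import numpy as np


def cell_transition_table(P, src_types, tgt_types):
    """Aggregate a coupling P (n_src x n_tgt) into a cell-type transition table.

    Rows are source types, columns target types; each row is normalized to sum
    to 1 (row-stochastic, forward setting). A source type with no outgoing mass
    would produce a NaN row.
    """
    src_labels, src_idx = np.unique(src_types, return_inverse=True)
    tgt_labels, tgt_idx = np.unique(tgt_types, return_inverse=True)
    A = np.eye(len(src_labels))[src_idx]  # (n_src, S) one-hot source types
    B = np.eye(len(tgt_labels))[tgt_idx]  # (n_tgt, T) one-hot target types
    T = A.T @ P @ B  # mass moved from each source type to each target type
    return src_labels, tgt_labels, T / T.sum(axis=1, keepdims=True)


src = np.array(["brain", "brain", "heart", "gut"])
tgt = np.array(["brain", "heart", "heart", "gut", "gut"])
a, b = np.full(4, 1 / 4), np.full(5, 1 / 5)
_, tgt_labels, T = cell_transition_table(np.outer(a, b), src, tgt)
print(tgt_labels)  # ['brain' 'gut' 'heart']
print(T)           # every row is [0.2, 0.4, 0.4]: constant columns
```

For a rank-1 coupling, the aggregated table is itself an outer product, so row normalization leaves each row equal to the target-type mass distribution.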


@zoepiran
Contributor Author

zoepiran commented Jun 23, 2022

@LaetitiaPapaxanthos here you can also observe the performance of unbalanced FGW with $\tau_a = \tau_b$.
As I show, the problem is indeed in `gw_unbalanced_correction=True` (with `False` it works).
You can of course play with everything there :)

@michalk8
Collaborator

For `tau_a = tau_b = 0.9`, I noticed that the total transported mass is very low (~1e-5), whereas if only one side is unbalanced, it's fairly high (~0.9). Code to reproduce (the snippet below uses `tau_a = tau_b = 0.8`):

```python
import jax
# Enable float64 in JAX (set before importing ott).
jax.config.update("jax_enable_x64", True)

import numpy as np
import ott
from ott.geometry.pointcloud import PointCloud

# Toy data: x / y define the quadratic (intra-domain) geometries,
# xx / yy the linear (cross-domain) geometry.
np.random.seed(0)
x = np.random.normal(size=(64, 3))
y = np.random.normal(size=(128, 3))
xx = np.random.normal(size=(64, 3))
yy = np.random.normal(size=(128, 3))

o, scale_cost = True, 'max_cost'
geom_x = PointCloud(x, online=o, scale_cost=scale_cost)
geom_y = PointCloud(y, online=o, scale_cost=scale_cost)
geom_xy = PointCloud(xx, yy, online=o, scale_cost=scale_cost)

solver = ott.core.gromov_wasserstein.GromovWasserstein(jit=False, epsilon=1e-2, lse_mode=False)
prob = ott.core.quad_problems.QuadraticProblem(geom_x, geom_y,
                                               geom_xy,
                                               tau_a=0.8, tau_b=0.8,
                                               gw_unbalanced_correction=True)

# Run the first outer GW iteration by hand.
iteration = 0
state = solver.init_state(prob, -1)
# Linearize the quadratic problem around the current coupling.
linear_pb = prob.update_linearization(state.linear_state, solver.epsilon, state.old_transport_mass)

# Solve the linearized problem with the inner linear solver.
out = solver.linear_ot_solver(linear_pb)
old_transport_mass = jax.lax.stop_gradient(
    state.linear_state.transport_mass()
)
state = state.update(
    iteration, out, linear_pb, solver.store_inner_errors, old_transport_mass
)
# Total transported mass after the first iteration.
print(state.linear_state.marginal(0).sum())  # 1.883714535238546e-05
```
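For reference, a minimal NumPy sketch of how `tau_a` / `tau_b` enter a plain (linear) unbalanced Sinkhorn solver and how the total transported mass can be measured. It assumes the standard scaling algorithm with KL-relaxed marginals, where `tau = rho / (rho + epsilon)` and `tau = 1` enforces the marginal exactly. This is not ott's GW solver and does not exercise `gw_unbalanced_correction`; `unbalanced_sinkhorn` is a hypothetical helper.

```python
import numpy as np


def unbalanced_sinkhorn(C, a, b, epsilon=0.1, tau_a=1.0, tau_b=1.0, n_iter=2000):
    """Scaling-form Sinkhorn with KL-relaxed marginals (sketch, not ott's solver).

    tau = rho / (rho + epsilon), with rho the KL marginal penalty;
    tau = 1 enforces that marginal exactly (balanced OT).
    """
    K = np.exp(-C / epsilon)  # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = (a / (K @ v)) ** tau_a
        v = (b / (K.T @ u)) ** tau_b
    return u[:, None] * K * v[None, :]  # coupling diag(u) K diag(v)


rng = np.random.default_rng(0)
x, y = rng.normal(size=(64, 3)), rng.normal(size=(128, 3))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
C = C / C.max()  # analogous in spirit to scale_cost='max_cost'
a, b = np.full(64, 1 / 64), np.full(128, 1 / 128)
for tau_a, tau_b in [(1.0, 1.0), (0.9, 1.0), (0.9, 0.9)]:
    P = unbalanced_sinkhorn(C, a, b, tau_a=tau_a, tau_b=tau_b)
    print(tau_a, tau_b, P.sum())  # total transported mass
```

Comparing these masses against the GW solver's `marginal(0).sum()` can help separate effects of the marginal relaxation itself from those of the GW-specific correction.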

@michalk8
Collaborator

In the next iteration, the solution to the linearized problem contains `inf`s; this also makes the total transported mass 0 (and the scaling factor between the old and new transport mass NaN). This only happens when `gw_unbalanced_correction=True`.
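A hypothetical debugging helper (not part of ott) that checks one outer iteration for the symptoms described: non-finite values in the linearized solution, a vanishing total transported mass, and a non-finite old/new mass ratio. It assumes the potentials and the coupling are available as arrays; the exact quantities ott uses internally may differ.

```python
import numpy as np


def diagnose_outer_iteration(potentials, P, old_mass):
    """Report finiteness of potentials, new total mass, and old/new mass ratio."""
    new_mass = np.float64(np.sum(P))
    # numpy float division yields inf/nan on zero instead of raising
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = np.float64(old_mass) / new_mass
    return {
        "potentials_finite": bool(all(np.all(np.isfinite(p)) for p in potentials)),
        "new_mass": float(new_mass),
        "mass_ratio": float(ratio),
        "ratio_finite": bool(np.isfinite(ratio)),
    }


# Example: a potential containing -inf and an all-zero coupling
f, g = np.array([0.0, -np.inf]), np.zeros(3)
print(diagnose_outer_iteration([f, g], np.zeros((2, 3)), old_mass=0.0))
```

Calling such a check after each outer iteration localizes the first iteration where the mass collapses, rather than only seeing the final NaN.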

@marcocuturi
Contributor

Maybe we can close this now?

@michalk8
Collaborator

Completed via #128.

3 participants