Timing comparison with POT #556

Closed
zwei-beiner opened this issue Jul 7, 2024 · 6 comments

Comments

@zwei-beiner

zwei-beiner commented Jul 7, 2024

I'm trying to calculate the W2 distance with OTT, and it seems to be ~10x slower than POT. This was run on CPU.

Is there any way to speed up the calculation with OTT?

Also, please let me know if this is the correct way of calculating the W2 distance with OTT (i.e. calculating it from the transport plan since there is no simple way of accessing it as an attribute of the Sinkhorn solver output).

import time
import numpy as np

import torch
import ot

import jax
import jax.numpy as jnp
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

torch.set_default_dtype(torch.float64)
jax.config.update("jax_enable_x64", True)


def make_samples(nsamples, ndims):
    key1, key2 = jax.random.split(jax.random.PRNGKey(0))
    samples1 = jax.random.ball(key1, ndims, shape=(nsamples,))
    samples2 = jax.random.ball(key2, ndims, shape=(nsamples,))
    return samples1, samples2
samples1, samples2 = make_samples(3000, 10)

def W_pytorch(x, y):
    nsamples = x.shape[0]
    cost_matrix = ot.utils.dist(x, y)
    a = torch.ones(nsamples) / nsamples
    b = torch.ones(nsamples) / nsamples
    loss = ot.sinkhorn2(a=a, b=b, M=cost_matrix, reg=1e-2, stopThr=1e-06)
    return loss
print("POT:")
torch_samples1 = torch.from_numpy(np.asarray(samples1))
torch_samples2 = torch.from_numpy(np.asarray(samples2))
tic = time.time()
print(W_pytorch(torch_samples1, torch_samples2))
print("Time:", time.time() - tic)


@jax.jit
def W_jax(x, y):
    geom = pointcloud.PointCloud(x, y, epsilon=1e-2)
    ot_prob = linear_problem.LinearProblem(geom)
    solver = sinkhorn.Sinkhorn(threshold=1e-6)
    ot = solver(ot_prob)
    return jnp.sum(ot.matrix * ot.geom.cost_matrix)
print("OTT:")
for _ in range(3): # Running multiple times because of jit compilation time
    tic = time.time()
    print(W_jax(samples1, samples2))
    print("Time:", time.time() - tic)

Output:

POT:
tensor(0.2506)
Time: 11.448308944702148
OTT:
0.2506479110728857
Time: 130.14876246452332
0.2506479110728857
Time: 133.8663191795349
0.2506479110728857
Time: 130.24367785453796
@michalk8
Collaborator

michalk8 commented Jul 8, 2024

Hi @zwei-beiner, there are two differences between the two benchmarks above:

  1. By default, we run our computations in LSE mode (lse_mode=True in sinkhorn.Sinkhorn) for better numerical stability.
  2. By default, we use at most 2k iterations.
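
To illustrate the stability point in the first item, here is a minimal sketch in plain Python (no OTT or POT calls; the cost values and the helper name logsumexp are made up for the example, and the costs are larger than those in the benchmark). In float64, exp(-c / epsilon) underflows to zero once c / epsilon exceeds roughly 745, whereas a log-sum-exp reduction stays finite.

```python
import math

epsilon = 1e-2
costs = [10.0, 11.0]  # hypothetical cost entries, not taken from the benchmark

# Kernel mode: form exp(-c / epsilon) explicitly, then sum.
kernel = [math.exp(-c / epsilon) for c in costs]
kernel_sum = sum(kernel)  # both terms underflow to 0.0, so log() would fail


def logsumexp(values):
    """Stable log(sum(exp(v))), obtained by factoring out the maximum."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


# LSE mode: stay in the log domain throughout.
log_kernel_sum = logsumexp([-c / epsilon for c in costs])

print(kernel_sum)
print(log_kernel_sum)
```

The price of the log-domain version is an extra exp and log per reduction, which is one plausible reason why a kernel-mode solver can be faster on problems where underflow is not an issue.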

Modifying some of the above code to

def W_pytorch(x, y):
    nsamples = x.shape[0]
    cost_matrix = ot.utils.dist(x, y)
    a = torch.ones(nsamples) / nsamples
    b = torch.ones(nsamples) / nsamples
    loss = ot.sinkhorn2(a=a, b=b, M=cost_matrix, method='sinkhorn_log', reg=1e-2, stopThr=1e-06)
    return loss


@jax.jit
def W_jax(x, y):
    geom = pointcloud.PointCloud(x, y, epsilon=1e-2)
    ot_prob = linear_problem.LinearProblem(geom)
    solver = sinkhorn.Sinkhorn(threshold=1e-6, lse_mode=True, max_iterations=1000)
    ot = solver(ot_prob)
    return ot.primal_cost

I get

tensor(0.2506)
Time: 44.20695209503174
OTT:
0.25064818126813293
Time: 29.805460929870605
0.25064818126813293
Time: 29.549583911895752
0.25064818126813293
Time: 29.669023990631104
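
On the earlier question of reading the cost off the transport plan: the sketch below is a toy log-domain Sinkhorn in plain Python, not the OTT or POT implementation. The points, the value epsilon=0.1, and the function name sinkhorn_log are assumptions made for the example. It returns the plan P and evaluates the cost as the sum of P * C, which is the expression the original snippet computes with jnp.sum(ot.matrix * ot.geom.cost_matrix).

```python
import math


def logsumexp(values):
    """Stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def sinkhorn_log(cost, a, b, epsilon, threshold=1e-6, max_iterations=2000):
    """Toy log-domain Sinkhorn; returns the transport plan as nested lists."""
    n, m = len(a), len(b)
    f = [0.0] * n  # dual potential for the rows
    g = [0.0] * m  # dual potential for the columns
    plan = None
    for _ in range(max_iterations):
        # Update f so that the row marginals match a.
        f = [
            epsilon * (math.log(a[i])
                       - logsumexp([(g[j] - cost[i][j]) / epsilon for j in range(m)]))
            for i in range(n)
        ]
        # Update g so that the column marginals match b.
        g = [
            epsilon * (math.log(b[j])
                       - logsumexp([(f[i] - cost[i][j]) / epsilon for i in range(n)]))
            for j in range(m)
        ]
        plan = [
            [math.exp((f[i] + g[j] - cost[i][j]) / epsilon) for j in range(m)]
            for i in range(n)
        ]
        # Columns are exact after the g update; monitor the rows (1-norm).
        err = sum(abs(sum(plan[i]) - a[i]) for i in range(n))
        if err < threshold:
            break
    return plan


x = [0.0, 0.5, 1.0]  # hypothetical 1D source points
y = [0.1, 0.4, 0.9]  # hypothetical 1D target points
a = [1.0 / 3] * 3
b = [1.0 / 3] * 3
cost = [[(xi - yj) ** 2 for yj in y] for xi in x]

plan = sinkhorn_log(cost, a, b, epsilon=0.1)
transport_cost = sum(plan[i][j] * cost[i][j] for i in range(3) for j in range(3))
print(transport_cost)
```

Because the plan is entropically regularized, this value sits above the unregularized optimum (0.01 for these points) and depends on epsilon.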

@zwei-beiner
Author

Thanks for the reply! I can confirm that OTT is faster now.

I have another question: Suppose that I want to calculate $W_2$ for a sequence of distributions $\mu_n$ which converge to some distribution $\mu$ and I can only access the distributions through samples. I want to show that $W_2\rightarrow 0$ as $n\rightarrow \infty$. The problem is that $W_2$ converges very slowly to zero with increasing sample size, even when I calculate it on two sets of samples from the same distribution. Is there any way around this, i.e. can I force $W_2\rightarrow 0$?
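
The slow decay described above can be reproduced in a setting where no solver is needed. The sketch below is plain Python and restricted to one dimension, an assumption made purely for simplicity: there, the exact $W_2$ between two equal-size, uniformly weighted samples is obtained by pairing the sorted points. Both samples are drawn from the same uniform distribution, yet the estimate stays strictly positive and only shrinks gradually as the sample size grows. The helper name w2_1d is made up for the example, and the 10-dimensional setting of the benchmark is not covered.

```python
import math
import random


def w2_1d(xs, ys):
    """Exact W2 between two equal-size 1D samples with uniform weights."""
    assert len(xs) == len(ys)
    # In 1D with squared cost, the optimal plan pairs the sorted points.
    pairs = zip(sorted(xs), sorted(ys))
    return math.sqrt(sum((x - y) ** 2 for x, y in pairs) / len(xs))


rng = random.Random(0)
estimates = {}
for n in (100, 10_000):
    xs = [rng.random() for _ in range(n)]
    ys = [rng.random() for _ in range(n)]
    estimates[n] = w2_1d(xs, ys)
    print(n, estimates[n])
```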

@marcocuturi
Contributor

marcocuturi commented Jul 8, 2024

Thanks a lot @michalk8 and thanks @zwei-beiner for the question!

Comparing OTT and POT is challenging because some of the default parameters are not the same. Also, POT has multiple implementations, as mentioned above by @michalk8.

I would also add that the error used to monitor convergence is, by default, a 1-norm in OTT, whereas it is a 2-norm in POT. As a consequence, you may have to pass norm_error=2.0 to your Sinkhorn solver, as done in the tutorial https://ott-jax.readthedocs.io/en/latest/tutorials/OTT_%26_POT.html
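
To see why the choice of norm matters for a fixed threshold, here is a small sketch in plain Python. It assumes the convergence error is some p-norm of the marginal violation; the violation vector is made up, and neither library's exact error definition is reproduced here. For the same violation, the 2-norm is never larger than the 1-norm, so a 2-norm criterion can stop earlier at the same threshold.

```python
import math


def p_norm(vec, p):
    """p-norm of a list of floats."""
    return sum(abs(v) ** p for v in vec) ** (1.0 / p)


n = 3000          # same number of points as in the benchmark above
threshold = 1e-6  # same threshold as in the benchmark above
# Hypothetical marginal violation: every entry is off by 1e-9.
violation = [1e-9] * n

err_1 = p_norm(violation, 1)  # n * 1e-9, about 3e-6
err_2 = p_norm(violation, 2)  # sqrt(n) * 1e-9, about 5.5e-8

print(err_1 < threshold)  # not yet converged under the 1-norm
print(err_2 < threshold)  # already converged under the 2-norm
```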

@michalk8
Collaborator

@zwei-beiner did you manage to reproduce the timing benchmarks with the above suggestions?

@zwei-beiner
Author

@michalk8 With the above suggestions, I get the following timings:

tensor(0.2392)
Time: 128.88550853729248
OTT:
0.23915911736256118
Time: 67.72269916534424
0.23915911736256118
Time: 69.8149642944336
0.23915911736256118
Time: 69.45572185516357

However, when I set lse_mode=False and method='sinkhorn', POT gets a significant speedup but OTT is essentially unchanged:

tensor(0.2392)
Time: 10.499557495117188
OTT:
0.23915911736256149
Time: 72.12179160118103
0.23915911736256149
Time: 72.88275790214539
0.23915911736256149
Time: 74.41572165489197

@michalk8
Collaborator

As mentioned above, you should be using method="sinkhorn_log" in POT when comparing against our lse_mode=True.
