Timing comparison with POT #556
Hi @zwei-beiner, there are 2 differences between the above benchmarks:
Modifying some of the above code to:

```python
import jax
import ot  # POT
import torch
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

def W_pytorch(x, y):
    nsamples = x.shape[0]
    cost_matrix = ot.utils.dist(x, y)  # squared Euclidean by default
    a = torch.ones(nsamples) / nsamples  # uniform weights
    b = torch.ones(nsamples) / nsamples
    # log-domain Sinkhorn, analogous to OTT's lse_mode=True
    loss = ot.sinkhorn2(a=a, b=b, M=cost_matrix, method='sinkhorn_log', reg=1e-2, stopThr=1e-06)
    return loss

@jax.jit
def W_jax(x, y):
    geom = pointcloud.PointCloud(x, y, epsilon=1e-2)
    ot_prob = linear_problem.LinearProblem(geom)
    solver = sinkhorn.Sinkhorn(threshold=1e-6, lse_mode=True, max_iterations=1000)
    out = solver(ot_prob)  # named `out` so it does not shadow the POT module `ot`
    return out.primal_cost
```

I get:

```
tensor(0.2506)
Time: 44.20695209503174
OTT:
0.25064818126813293
Time: 29.805460929870605
0.25064818126813293
Time: 29.549583911895752
0.25064818126813293
Time: 29.669023990631104
```
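When reproducing timings like these, two JAX details can skew the numbers: the first call to a `jax.jit`-compiled function such as `W_jax` also traces and compiles it, and JAX dispatches work asynchronously, so a timer may stop before the computation has finished. Below is a minimal, library-agnostic sketch of a timing helper that makes one untimed warm-up call and lets the caller pass a synchronization function. The name `benchmark` and its parameters are hypothetical, not part of OTT or POT.

```python
import time
from typing import Any, Callable, List, Tuple


def benchmark(
    fn: Callable[..., Any],
    *args: Any,
    repeats: int = 3,
    sync: Callable[[Any], Any] = lambda out: out,
) -> Tuple[Any, List[float]]:
    """Call `fn(*args)` once untimed, then time `repeats` further calls.

    `sync` is applied to each result before the clock stops; for JAX arrays,
    `jax.block_until_ready` waits until the asynchronous computation is done.
    """
    sync(fn(*args))  # warm-up; for a jitted function this includes compilation
    out, timings = None, []
    for _ in range(repeats):
        start = time.perf_counter()
        out = sync(fn(*args))
        timings.append(time.perf_counter() - start)
    return out, timings
```

For example, `benchmark(W_jax, x, y, sync=jax.block_until_ready)` would time only the compiled solver, while `benchmark(W_pytorch, x, y)` with the default `sync` suits PyTorch tensors on CPU, where operations run synchronously.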
Thanks for the reply! I can confirm that OTT is faster now. I have another question: Suppose that I want to calculate
Thanks a lot @michalk8, and thanks @zwei-beiner for the question! Comparing OTT and POT is challenging because some of their default parameters differ. Also, POT has multiple implementations, as @michalk8 mentioned above. I would also add that the error used by default in OTT to control convergence is a 1-norm, whereas POT uses a 2-norm. As a consequence, you may have to pass
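To see why the norm in the stopping criterion matters for timing, here is a minimal NumPy sketch (not OTT's or POT's actual code) of a log-domain Sinkhorn loop whose stopping test measures the marginal violation with either a 1-norm or a 2-norm. Since ‖e‖₁ ≥ ‖e‖₂ for any vector e (and ‖e‖₁ can be up to √n times larger for an n-dimensional e), the same `threshold` is at least as strict under the 1-norm, so on the same sequence of iterates the 1-norm test never stops earlier. The helper names, checking the row marginal after every iteration, and the data (random uniform points, `eps=0.1`) are assumptions for illustration.

```python
import numpy as np


def logsumexp(z, axis):
    """Numerically stable log(sum(exp(z))) along `axis`."""
    m = z.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(z - m).sum(axis=axis, keepdims=True))).squeeze(axis)


def sinkhorn_log(a, b, C, eps, threshold, norm_ord, max_iter=10_000):
    """Log-domain Sinkhorn that stops once the row-marginal error, measured
    with np.linalg.norm(..., ord=norm_ord), drops below `threshold`.
    Returns the transport plan and the number of iterations used."""
    log_a, log_b = np.log(a), np.log(b)
    f, g = np.zeros_like(a), np.zeros_like(b)
    for it in range(1, max_iter + 1):
        # Update the potentials so that rows match a, then columns match b.
        f = eps * (log_a - logsumexp((g[None, :] - C) / eps, axis=1))
        g = eps * (log_b - logsumexp((f[:, None] - C) / eps, axis=0))
        P = np.exp((f[:, None] + g[None, :] - C) / eps)
        # Columns match b exactly after the g-update, so only rows are checked.
        err = np.linalg.norm(P.sum(axis=1) - a, ord=norm_ord)
        if err < threshold:
            break
    return P, it


rng = np.random.default_rng(0)
n = 100
x, y = rng.uniform(size=(n, 2)), rng.uniform(size=(n, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean cost
a = b = np.full(n, 1.0 / n)  # uniform weights

P1, iters_l1 = sinkhorn_log(a, b, C, eps=0.1, threshold=1e-6, norm_ord=1)
P2, iters_l2 = sinkhorn_log(a, b, C, eps=0.1, threshold=1e-6, norm_ord=2)
print(iters_l1, iters_l2)  # the 1-norm test needs at least as many iterations
```

Because both runs follow identical iterates, any extra iterations under the 1-norm translate directly into extra runtime for the same `threshold`.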
@zwei-beiner did you manage to reproduce the timing benchmarks with the above suggestions? |
@michalk8 With the above suggestions, I get the following timings:
However, when I set
As mentioned above, you should be using
I'm trying to calculate the W2 distance with OTT, and it seems to be ~10x slower than POT. This was run on CPU.
Is there any way to speed up the calculation with OTT?
Also, please let me know if this is the correct way of calculating the W2 distance with OTT (i.e., computing it from the transport plan, since there is no simple way to access it as an attribute of the Sinkhorn solver output).
Output:
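On the question of computing W2 from the transport plan: with a squared-Euclidean cost (the default for both `ot.utils.dist` and OTT's `PointCloud`), the transport cost ⟨P, C⟩ of a coupling P equals W2² when P is optimal, so W2 is its square root. With an entropic plan from Sinkhorn, ⟨P, C⟩ is at least W2², so its square root slightly overestimates W2. Assuming `primal_cost` (used in `W_jax` above) returns this ⟨P, C⟩ for the entropic plan, taking its square root gives the same estimate. The NumPy sketch below, with hypothetical helper names, computes it from an explicit plan; for a pure translation, the optimal plan matches each point with its translate, so W2 equals the length of the shift.

```python
import numpy as np


def sq_euclidean_cost(x, y):
    """C[i, j] = ||x_i - y_j||^2 for point clouds x (n, d) and y (m, d)."""
    diff = x[:, None, :] - y[None, :, :]
    return (diff ** 2).sum(axis=-1)


def w2_from_plan(P, x, y):
    """W2 estimate sqrt(<P, C>) for a coupling P between x and y."""
    C = sq_euclidean_cost(x, y)
    return float(np.sqrt((P * C).sum()))


x = np.array([[0.0, 0.0], [1.0, 0.0]])
y = x + np.array([0.5, 0.0])  # translate every point by (0.5, 0)
P_opt = np.eye(2) / 2         # optimal plan: each x_i moves to its translate y_i
print(w2_from_plan(P_opt, x, y))  # 0.5
```

Any other coupling, such as the independent plan `np.full((2, 2), 0.25)`, gives a larger value, which is why the plan returned by the solver matters for the estimate.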