Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed-up Sinkhorn #57

Merged
merged 3 commits into from
Jul 18, 2018
Merged

Speed-up Sinkhorn #57

merged 3 commits into from
Jul 18, 2018

Conversation

LeoGautheron
Copy link

Speed-up in 3 places:

  • the computation of pairwise distance is faster with sklearn.metrics.pairwise.euclidean_distances
  • faster computation of K = np.exp(-M / reg)
  • faster computation of the error every 10 iterations

Example with this little script:

import time
import numpy as np
import ot
rng = np.random.RandomState(0)
transport = ot.da.SinkhornTransport()
time1 = time.time()
Xs, ys, Xt = rng.randn(10000, 100), rng.randint(0, 2, size=10000), rng.randn(10000, 100)
transport.fit(Xs=Xs, Xt=Xt)
time2 = time.time()
print("OT Computation Time {:6.2f} sec".format(time2-time1))
transport = ot.da.SinkhornLpl1Transport()
transport.fit(Xs=Xs, ys=ys, Xt=Xt)
time3 = time.time()
print("OT LpL1 Computation Time {:6.2f} sec".format(time3-time2))

Before
OT Computation Time 19.93 sec
OT LpL1 Computation Time 133.43 sec

After
OT Computation Time 7.55 sec
OT LpL1 Computation Time 82.25 sec

Speed-up in 3 places:
 - the computation of pairwise distance is faster with sklearn.metrics.pairwise.euclidean_distances
 - faster computation of K = np.exp(-M / reg)
 - faster computation of the error every 10 iterations

Example with this little script:

import time
import numpy as np
import ot
rng = np.random.RandomState(0)
transport = ot.da.SinkhornTransport()
time1 = time.time()
Xs, ys, Xt = rng.randn(10000, 100), rng.randint(0, 2, size=10000), rng.randn(10000, 100)
transport.fit(Xs=Xs, Xt=Xt)
time2 = time.time()
print("OT Computation Time {:6.2f} sec".format(time2-time1))
transport = ot.da.SinkhornLpl1Transport()
transport.fit(Xs=Xs, ys=ys, Xt=Xt)
time3 = time.time()
print("OT LpL1 Computation Time {:6.2f} sec".format(time3-time2))

Before
OT Computation Time  19.93 sec
OT LpL1 Computation Time 133.43 sec

After
OT Computation Time   7.55 sec
OT LpL1 Computation Time  82.25 sec
@rflamary
Copy link
Collaborator

Hello @LeoGautheron ,

Thank you for all your work. Those are very nice speedup and i can confirm that I have similar gains. Still have a few comments.

  • My main concern is to add sklearn as a dependency. Up to now we managed to avoid that (and the tests fail because sklearn is not a hard dependency).

When i run the following code

ot.tic()
M1=ot.dist(x1,x2)
ot.toc()

ot.tic()
M2=euclidean_distances(x1,x2, squared=True)
ot.toc()

ot.tic()
x1p2 = np.sum(np.square(x1), 1)
x2p2 = np.sum(np.square(x2), 1)
M3=x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * np.dot(x1, x2.T)
ot.toc()

I get the following:

Elapsed time : 2.0640811920166016 s
Elapsed time : 0.37867116928100586 s
Elapsed time : 0.46784543991088867 s

with the last one pure numpy. I think the gain is sufficient with the last implementation and avoid an additional dependency for POT.

  • Next I looked at all your speedup and they all come at the cost of hard to read code. Since they all bring at least a 20% gain in performance i'm ready to keep them but please add some comments around the code like
# Next N lines equivalent to:
# K= np.exp(-M/reg)

For the next guy who wants to have a look at the function.

In any case its nice work and the computational gain is very important.

@agramfort
Copy link
Collaborator

just copy the code you need then https://github.com/scikit-learn/scikit-learn/blob/a24c8b46/sklearn/metrics/pairwise.py#L163

it's pure python code.

@LeoGautheron
Copy link
Author

I used the code from sklearn, still the same performances now :)

@rflamary rflamary merged commit 5cd6c0a into PythonOT:master Jul 18, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants