[WIP] small tentative for EMD 1D in torch #218
Conversation
Codecov Report

```
@@           Coverage Diff           @@
##            torch     #218   +/-  ##
=======================================
+ Coverage   92.40%   92.48%   +0.08%
=======================================
  Files          19       19
  Lines        3132     3169      +37
=======================================
+ Hits         2894     2931      +37
  Misses        238      238
```
Hello @AdrienCorenflos, thank you for the PR. As you can see, I'm working on PyTorch functions for POT; I welcome contributions, and I agree that we need a good implementation for sliced.

Note that I want the function to handle different devices sensibly (basically, the computation and temporary variables have to live on the same device as the input tensors), so be careful when initializing in your functions.

I'm wondering why there is a while loop in the code. This does not seem to be necessary judging from the pure numpy implementation here:

Also, I will need you to use the new API naming I'm trying, with ot_loss_1d for the function that computes the loss (I'm trying ot_solve for a solver that returns the OT matrix and ot_loss for functions that return the loss). And of course I want the code to pass the tests.

About type hints, I'm OK with those, even though we have straightforward functions and we suppose that all arrays in the ot.torch module are torch tensors.

What are the computation times of your implementation and of the CPU one from POT?
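(As an aside, a minimal sketch of that device convention, with a hypothetical helper name, not POT API:)

```python
import torch

def _new_temp(x):
    # temporaries inherit dtype and device from the input tensor,
    # so the same code runs on CPU and GPU without explicit transfers
    return torch.zeros(1, dtype=x.dtype, device=x.device)
```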
Hi,

The type hinting: I only have it because it simplifies my development work (under PyCharm); it can disappear (I expect Python 3.5 doesn't like it anyway).

There are three reasons for the while loop. The simple one is that the link you provided computes the distance between the CDFs, whereas OT is a distance between inverse CDFs, so the scipy approach only works for p=1 (see remark 2.30 in Peyré and Cuturi, https://arxiv.org/pdf/1803.00567.pdf). Now, you could indeed compute the pseudo-inverse of the empirical CDF and use the same technique (this is the kind of equation (10.14) in Peyré and Cuturi's book), but it's only easy when we have the same importance weights for the source and the target distributions.

The pedantic one is that this implementation should have a better theoretical complexity than the searchsorted one anyway (O(n) instead of O(n log n)), which shows at least on CPU.

But the real reason is that I adapted this code from my JAX implementation, which doesn't support the native CUDA/CPP searchsorted but instead implements it as a recursive function. That tends to increase the size of the graph (and, as a consequence, the compilation time) quite a bit, so I ended up preferring to loop explicitly.

I'd tend to say that if you want something generic, it's probably best to go with a while loop now and implement computational improvements down the road.

Adrien
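(To make the while-loop variant concrete, here is a minimal NumPy sketch of the idea, not the PR code: walk both sorted samples with two pointers, always transporting the smaller remaining mass, which evaluates the quantile difference on the merged grid in O(n + m).)

```python
import numpy as np

def emd_1d_loop(u_values, v_values, u_weights, v_weights, p=1):
    # sort both samples and carry the weights along
    u_order = np.argsort(u_values)
    v_order = np.argsort(v_values)
    u_values, u_weights = u_values[u_order], u_weights[u_order].copy()
    v_values, v_weights = v_values[v_order], v_weights[v_order].copy()
    i = j = 0
    cost = 0.0
    while i < len(u_values) and j < len(v_values):
        m = min(u_weights[i], v_weights[j])  # mass transported at this step
        cost += m * abs(u_values[i] - v_values[j]) ** p
        u_weights[i] -= m
        v_weights[j] -= m
        if u_weights[i] <= 0:  # source atom exhausted, move on
            i += 1
        if v_weights[j] <= 0:  # target atom exhausted, move on
            j += 1
    return cost
```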
Hi there!
Yes please @ncourty, but does your implementation work for general weights (searchsorted suggests yes)? I would like a quantitative comparison (CPU + GPU) before we jump on one implementation, indeed.
Yes, it is meant for general weights. There is a chance that for a simple 1D distance, the while method of Adrien may be faster than the searchsorted version, but in a sliced Wasserstein method the latter will be much more efficient, as it is vectorized and can work on all projections simultaneously. I'll do some tests.
Hi Nicolas,
OK great then, let's compare them on simple examples and with a large number of samples ;)
Hello, I did some tests. You can find the Google Colab notebook here (with editor privileges):

On the CPU, simple 1D W between 20000 and 100000 random samples:

Sanity check for the values:

On the GPU, there was a problem with your code Adrien (see the colab: two tensors not on the same device).
Thanks for that!
I have some old code that has this searchsorted approach on the inverse CDF and not the CDF; I'll iron it out over the weekend if I have some time.
Ok, I agree with that: the case of using the quantile function when the weights are not equal in the two distributions is trickier than the implementation with the CDF (I have never seen an efficient implementation so far). I guess that searchsorted can also be used efficiently here, in the same spirit. I'll give it a try as well, so we can compare in the end :)
This is very interesting! We will definitely converge toward a very fast implementation. @AdrienCorenflos, what do you mean by batch inputs? Several 1D directions simultaneously? I agree that it is important, especially for implementing sliced down the line.

Also, I think it would be very interesting to try and think about the dual potentials of the OT solution. Thanks to those dual potentials, we can have subgradients for the weights in my implementation of ot_loss, which is very nice if you want to optimize them (I will add a nice example about that in the torch branch). I think feasible duals can be computed from the sorted solution, but it can wait for the next PR, of course.
That's what I mean indeed.
Ok, I managed to write a full version with quantile functions, using searchsorted, that accepts any power and is fully vectorized (function

Comparing with your while loop @AdrienCorenflos, I get the following performances for 100 Wasserstein distances on samples of size 2000 and 1000:
I do not exactly get why the GPU version is slower on your side; there might be room for improvement. In any case, I'll prepare a PR in the upcoming days including sliced Wasserstein and maybe also sliced Gromov-Wasserstein, with test code and proper documentation (it might take some days though).
@rflamary I guess the cost function could/should be set as a parameter of those functions. By this, I mean going beyond Minkowski distances and eventually providing a generic function c(x, y) that could be a NN or whatever custom cost. I guess this is a design choice for the whole torch branch that has broader impact. This is definitely something we should discuss.
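(For what it's worth, a purely hypothetical sketch of what such a parameterization could look like; none of these names are in the PR. One caveat: the sorted/quantile coupling used in the implementations above is only optimal in 1D for costs of the form h(x - y) with h convex, so an arbitrary learned cost would need a different solver.)

```python
import torch

# Hypothetical sketch only: a cost-parameterized 1D loss evaluated on
# matched quantiles; valid when cost(x, y) = h(x - y) with h convex.
def ot_loss_1d_with_cost(u_icdf, v_icdf, delta, cost=None, p=1):
    if cost is None:
        cost = lambda x, y: torch.abs(x - y) ** p  # default Minkowski cost
    # integrate the cost of matched quantiles against the weight deltas
    return torch.sum(delta * cost(u_icdf, v_icdf), dim=-1)
```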
Looping over GPU operations is a pretty bad thing to do when you don't need to. There's a good chance the control-flow operations are actually executed on the CPU, so that there is undue memory exchange between CPU and GPU, though torch isn't my cup of tea so I can't be 100% sure about this.

I just checked: your implementation doesn't support non-uniform weights. I'll try and correct this during the day.

On a side note, the .unique isn't a good idea; it makes the code non-future-proof, as XLA doesn't support dynamic shape allocation.
@ncourty
@ncourty wrt the cost function: since we are in 1D, I think a reasonable parameter is p for |x-y|^p with p >= 1.
Finished it, see the colab.

```python
import torch

def emd1D_searchsorted_quantile_adrien(u_values, v_values, u_weights, v_weights, p=1):
    # sort both samples and carry the weights along
    u_values_sorted, u_sorter = torch.sort(u_values)
    v_values_sorted, v_sorter = torch.sort(v_values)
    zero = torch.zeros(1, dtype=u_values.dtype, device=u_values.device)
    u_weights = u_weights[u_sorter]
    u_cdf = torch.cumsum(u_weights, 0)
    v_weights = v_weights[v_sorter]
    v_cdf = torch.cumsum(v_weights, 0)
    # merged grid of both CDFs, with a leading zero to form the deltas
    cdf_axis = torch.cat((zero, torch.sort(torch.cat((u_cdf, v_cdf)))[0]))
    # evaluate both quantile functions (inverse CDFs) on the merged grid
    u_index = torch.searchsorted(u_cdf, cdf_axis[1:])
    v_index = torch.searchsorted(v_cdf, cdf_axis[1:])
    u_icdf = u_values_sorted[u_index]
    v_icdf = v_values_sorted[v_index]
    # integrate |F_u^{-1} - F_v^{-1}|^p against the weight deltas
    delta = cdf_axis[1:] - cdf_axis[:-1]
    return torch.sum(delta * torch.pow(torch.abs(u_icdf - v_icdf), p))
```

Seems to be working fine on a simple example. Will need some ironing out (default arguments and such) and proper testing (with duplicates in weights and locations), but otherwise it seems pretty fine.
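(Not from the thread: a minimal usage sketch on random weighted samples. Note that the later versions in this thread clip the searchsorted indices to guard against floating-point spill at the right end of the CDF.)

```python
import torch

u = torch.randn(500)
v = torch.randn(600) + 1.0
a = torch.rand(500)
a = a / a.sum()  # normalized source weights
b = torch.rand(600)
b = b / b.sum()  # normalized target weights

loss = emd1D_searchsorted_quantile_adrien(u, v, a, b, p=2)  # scalar tensor
```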
Seems fine even with duplicates.
How do you guys want to proceed re the PR?
Thank you @AdrienCorenflos, but there are some elements I do not get: it seems that this version of the function is not adapted to handle several distributions at the same time, and it is also based on the CDF rather than the quantile functions (rendering the computation for other powers p invalid?). Besides, I do not see much difference with the
Agreed, I still need to (trivially) vectorise it.

That's not true; if anything, I have tested for different powers. What I do is indeed build the CDF for u and v, then compute the pseudo-inverse of the CDF (so basically the quantile function) on a common grid for u and v, so that we can subtract them.

If you check what you are integrating against what: you are integrating the difference of the CDFs against the locations of the particles, while I am integrating the difference between the locations against the cumulative weights. To make it concrete, you can look at the difference in how you define
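(For reference, hopefully restating the point faithfully: the quantile form of the 1D Wasserstein distance is

$$W_p^p(u, v) = \int_0^1 \big| F_u^{-1}(t) - F_v^{-1}(t) \big|^p \, \mathrm{d}t,$$

while the CDF form

$$W_1(u, v) = \int_{\mathbb{R}} \big| F_u(x) - F_v(x) \big| \, \mathrm{d}x$$

only holds for $p = 1$. Integrating location differences against the cumulative weights evaluates the first integral on the merged CDF grid, which is why it stays valid for any $p \ge 1$.)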
Ok, thanks @AdrienCorenflos, I get the difference. Here is the (not so trivial) vectorized version:

```python
import torch

def emd1D_searchsorted_quantile_vector(U_values, V_values, u_weights=None, v_weights=None, p=1, require_sort=True):
    # sort the two batches of distributions along the sample dimension
    if require_sort:
        U_values_sorted, U_sorter = torch.sort(U_values, dim=1)
        V_values_sorted, V_sorter = torch.sort(V_values, dim=1)
    else:
        U_values_sorted = U_values
        V_values_sorted = V_values
    # build the CDFs (uniform weights if none are given)
    if u_weights is None:
        U_cdf = (torch.arange(1, U_values.size(1) + 1).double().to(U_values.device) / U_values.size(1)).repeat(U_values.size(0), 1)
    else:
        U_weights_repeat = u_weights.repeat(U_values.size(0), 1)
        if require_sort:
            sorted_U_weights_repeat = U_weights_repeat.gather(1, U_sorter)
            U_cdf = torch.cumsum(sorted_U_weights_repeat, dim=1)
        else:
            U_cdf = torch.cumsum(U_weights_repeat, dim=1)
    if v_weights is None:
        V_cdf = (torch.arange(1, V_values.size(1) + 1).double().to(U_values.device) / V_values.size(1)).repeat(V_values.size(0), 1)
    else:
        V_weights_repeat = v_weights.repeat(V_values.size(0), 1)
        if require_sort:
            sorted_V_weights_repeat = V_weights_repeat.gather(1, V_sorter)
            V_cdf = torch.cumsum(sorted_V_weights_repeat, dim=1)
        else:
            V_cdf = torch.cumsum(V_weights_repeat, dim=1)
    # merge the two CDF grids and prepend a zero to form the weight deltas
    all_cdf_values, _ = torch.sort(torch.cat((U_cdf, V_cdf), dim=1), dim=1)
    all_weight_values = torch.cat((torch.zeros(U_values.size(0), 1).to(U_values.device),
                                   all_cdf_values),
                                  dim=1)
    deltas = all_weight_values[:, 1:] - all_weight_values[:, :-1]
    all_weight_values = all_weight_values.contiguous()
    # evaluate both quantile functions on the merged grid
    idx_U = torch.searchsorted(U_cdf, all_weight_values[:, 1:])
    idx_V = torch.searchsorted(V_cdf, all_weight_values[:, 1:])
    U = U_values_sorted.gather(1, idx_U)
    V = V_values_sorted.gather(1, idx_V)
    if p == 1:
        return torch.sum(torch.mul(torch.abs(U - V), deltas), dim=1)
    else:
        return torch.sum(torch.mul(torch.pow(torch.abs(U - V), p), deltas), dim=1)
```

Please note that I assume all the 1D distributions share the same weight vector. I tested against various sizes of distributions with uniform and random weights. I scratched my head (for much too long) against a weird bug: when the weight vector is cast as .double() instead of .float(), it seems that searchsorted from pytorch behaves differently. We will try to make this clear in the doc.
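(Again not from the thread, just a minimal usage sketch under the stated assumption of a shared weight vector per side:)

```python
import torch

# 100 pairs of 1D distributions, sizes 2000 and 1000
U = torch.rand(100, 2000)
V = torch.rand(100, 1000)
wu = torch.rand(2000)
wu = wu / wu.sum()  # shared source weights across the batch
wv = torch.rand(1000)
wv = wv / wv.sum()  # shared target weights across the batch

d = emd1D_searchsorted_quantile_vector(U, V, wu, wv, p=2)  # shape: (100,)
```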
This is what I had in mind in terms of vectorisation. Using the Python ellipsis makes everything way easier. I've not checked contiguity of all the tensors, but torch doesn't complain, so I guess it's all good (I was worried about

Let me know if you have worries.

```python
import torch

def emd1D_searchsorted_quantile_adrien(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True):
    n = u_values.shape[-1]
    m = v_values.shape[-1]
    device = u_values.device
    dtype = u_values.dtype

    if require_sort:
        u_values, u_sorter = torch.sort(u_values, -1)
        v_values, v_sorter = torch.sort(v_values, -1)
    else:
        # a full slice leaves pre-sorted inputs untouched
        u_sorter = v_sorter = slice(None)

    # uniform weights by default, on the input's device and dtype
    if u_weights is None:
        u_weights = torch.full(n, 1/n, dtype=dtype, device=device)
    if v_weights is None:
        v_weights = torch.full(m, 1/m, dtype=dtype, device=device)

    u_weights = u_weights[..., u_sorter]
    u_cdf = torch.cumsum(u_weights, -1)
    v_weights = v_weights[..., v_sorter]
    v_cdf = torch.cumsum(v_weights, -1)

    # merged CDF grid; clip guards against floating-point spill past the end
    cdf_axis, _ = torch.sort(torch.cat((u_cdf, v_cdf), -1), -1)
    u_index = torch.searchsorted(u_cdf, cdf_axis)
    v_index = torch.searchsorted(v_cdf, cdf_axis)
    u_icdf = torch.gather(u_values, -1, u_index.clip(0, n - 1))
    v_icdf = torch.gather(v_values, -1, v_index.clip(0, m - 1))

    # prepend a zero so consecutive differences give the transported masses
    cdf_axis = torch.nn.functional.pad(cdf_axis, (1, 0))
    delta = cdf_axis[..., 1:] - cdf_axis[..., :-1]

    if p == 1:
        return torch.sum(delta * torch.abs(u_icdf - v_icdf), axis=-1)
    if p == 2:
        # I actually don't think this is useful, a lot of core pytorch code uses .pow(2)
        return torch.sum(delta * torch.square(u_icdf - v_icdf), axis=-1)
    return torch.sum(delta * torch.pow(torch.abs(u_icdf - v_icdf), p), axis=-1)
```
Also please note that the slices allow for higher-dimensional batching for free; for example, this will work:

```python
import torch

nb_source = 3
nb_target = 5
d = 2
# two batches of d projections each, batched over the two leading dimensions
u_s = torch.randn(2, d, nb_source)
u_t = torch.randn(2, d, nb_target)
loss = emd1D_searchsorted_quantile_adrien(u_s, u_t)  # shape: (2, d)
```

It can be useful if you're trying to do some SW barycenter.
Nice, I see we have two different ways of coding. Do you know why
What exactly is tricky, so that I can comment it further? It looks like this in the end:

```python
import torch

def emd1D_searchsorted_quantile_adrien(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True):
    n = u_values.shape[-1]
    m = v_values.shape[-1]
    device = u_values.device
    dtype = u_values.dtype

    # uniform weights by default, on the input's device and dtype
    if u_weights is None:
        u_weights = torch.full(n, 1/n, dtype=dtype, device=device)
    if v_weights is None:
        v_weights = torch.full(m, 1/m, dtype=dtype, device=device)

    if require_sort:
        u_values, u_sorter = torch.sort(u_values, -1)
        v_values, v_sorter = torch.sort(v_values, -1)
        u_weights = u_weights[..., u_sorter]
        v_weights = v_weights[..., v_sorter]

    u_cdf = torch.cumsum(u_weights, -1)
    v_cdf = torch.cumsum(v_weights, -1)

    # merged CDF grid; the quantile functions are piecewise constant on it
    cdf_axis, _ = torch.sort(torch.cat((u_cdf, v_cdf), -1), -1)
    u_index = torch.searchsorted(u_cdf, cdf_axis)
    v_index = torch.searchsorted(v_cdf, cdf_axis)
    # clip guards against floating-point spill past the last index
    u_icdf = torch.gather(u_values, -1, u_index.clip(0, n - 1))
    v_icdf = torch.gather(v_values, -1, v_index.clip(0, m - 1))

    # prepend a zero so consecutive differences give the transported masses
    cdf_axis = torch.nn.functional.pad(cdf_axis, (1, 0))
    delta = cdf_axis[..., 1:] - cdf_axis[..., :-1]

    if p == 1:
        return torch.sum(delta * torch.abs(u_icdf - v_icdf), axis=-1)
    if p == 2:
        # I actually don't think this is useful, a lot of core pytorch code uses .pow(2)
        return torch.sum(delta * torch.square(u_icdf - v_icdf), axis=-1)
    return torch.sum(delta * torch.pow(torch.abs(u_icdf - v_icdf), p), axis=-1)
```
All good on my side! Minor comment: on Google Colab, I had to change `torch.full(n, 1/n, dtype=dtype, device=device)` to `torch.full((n,), 1/n, dtype=dtype, device=device)`. I guess it is a torch version thing?
Yes, this is great, you can add this very nice and efficient implementation to the PR. Note that some comments in the function, with a reference to where the algorithm can be found, would be a plus. If Nico had a hard time understanding it, it means that it is not easy to understand, so please help the next reader ;).

About 3.5, I agree that since it reached end of life I should probably remove its support. You can do it in the GitHub action, but then I want it tested for 3.9, the recent release (it should work, but that's usually when we get surprises).
Yes, I'm actually going to have an auxiliary
torch for Python 3.5 doesn't have searchsorted.
As I said, remove 3.5 from https://github.com/PythonOT/POT/blob/master/.github/workflows/build_tests.yml and add 3.9.
Ah, my bad, I thought you were doing it in the other PR.
Remove py3.5 from supported versions and add 3.9 instead
cvxopt doesn't build under py3.9, so I'm removing 3.9.
CVXOPT doesn't support 3.9
I'll do the SWD a bit later |
@AdrienCorenflos I can do SWD if you want, I already have it coded |
Sure |
ok ! I'll do it once your PR is merged in the torch branch |
This is great, and I see you guys are eager to code new stuff from this;
here are a few comments before merging.
pytorch doesn't allow for a bunch of operations to be taken along a certain dimension (searchsorted, etc...)
@rflamary do you have further comments?
I will let @ncourty do the last code review; he will be the one using your code for sliced.
Nice work @AdrienCorenflos!
For some reason GitHub doesn't send notifications about PR comments, so I just saw this. Sorry about the delay; it would have taken 5 minutes...
I saw the torch branch for LP stuff.
Would you be interested in my implementation of the 1D EMD (and the sliced Wasserstein with it)?
I'm not a huge fan of PyTorch, so I can't vouch that what I'm doing here is the best implementation, but it feels to me like it should be fairly OK for batched inputs, which is what you want for sliced stuff anyway.