
[WIP] small tentative for EMD 1D in torch #218

Merged: 8 commits into PythonOT:torch on Jan 3, 2021
Conversation

@AdrienCorenflos (Contributor)

I saw the torch branch for LP stuff.
Would you be interested in my implementation of the 1D EMD (and the sliced Wasserstein that goes with it)?

I'm not a huge fan of PyTorch, so I can't vouch that what I'm doing here is the best implementation, but it feels to me like it should be fairly OK for batched inputs, which is what you want for sliced stuff anyway.

@codecov (bot) commented Nov 16, 2020

Codecov Report

Merging #218 (bf62359) into torch (6edcb1f) will increase coverage by 0.08%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##            torch     #218      +/-   ##
==========================================
+ Coverage   92.40%   92.48%   +0.08%     
==========================================
  Files          19       19              
  Lines        3132     3169      +37     
==========================================
+ Hits         2894     2931      +37     
  Misses        238      238              

@rflamary (Collaborator)

Hello @AdrienCorenflos,

Thank you for the PR. As you can see, I'm working on PyTorch functions for POT. I welcome contributions and I agree that we need a good implementation for sliced. Note that I want the functions to handle different devices sensibly (basically, the computation and temporary variables have to live on the same device as the input tensors), so be careful when initializing tensors in your functions.
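
For instance (a minimal sketch of the pattern only; the helper name is made up):

import torch

def _zeros_on_input_device(x, shape):
    # Temporary buffers inherit the device and dtype of the input tensor,
    # so the same code runs unchanged on CPU and GPU.
    return torch.zeros(shape, dtype=x.dtype, device=x.device)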

I'm wondering why there is a while loop in the code. It does not seem to be necessary in the pure numpy implementation here:
https://github.com/scipy/scipy/blob/v1.5.4/scipy/stats/stats.py#L7723

Also, I will need you to use the new API naming I'm trying, with ot_loss_1d for the function that computes the loss (I'm trying ot_solve for a solver that returns the OT matrix and ot_loss for functions that return the loss).

And of course I want the code to pass the tests. About type hints, I'm OK with those, even though we have straightforward functions and we assume that all arrays in the ot.torch module are torch tensors.

What are the computation times of your implementation and of the CPU one from POT?

@AdrienCorenflos (Contributor, Author) commented Nov 17, 2020

Hi,
I will do all the requested things. A longer response below.

The type hinting: I only have it because it simplifies my development work (under PyCharm); it can be removed (I expect Python 3.5 doesn't like it anyway).

There are three reasons for the while loop.

The simpler one is that the link you provided computes the distance between the CDFs, whereas OT is a distance between inverse CDFs, so the scipy approach only works for p=1 (see Remark 2.30 in Peyré and Cuturi, https://arxiv.org/pdf/1803.00567.pdf). Now you could indeed compute the pseudo-inverse of the empirical CDF and use the same technique (this is roughly equation (10.14) in Peyré and Cuturi's book), but it is only easy when the source and target distributions have the same importance weights.
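
For reference, the quantile-function formulation (Remark 2.30 in Peyré and Cuturi) reads W_p^p(u, v) = \int_0^1 |F_u^{-1}(q) - F_v^{-1}(q)|^p dq, and only for p = 1 does this coincide with the CDF distance \int |F_u(x) - F_v(x)| dx that the scipy code computes.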

The pedantic one is that this implementation should in any case have a better theoretical complexity than the searchsorted one (O(n) instead of O(n log n)), which shows at least on CPU.

But the real reason is that I adapted this code from my JAX implementation, which doesn't have a native CUDA/C++ searchsorted but instead implements it as a recursive function; that tends to increase the size of the graph (and, as a consequence, the compilation time) quite a bit, so I ended up preferring to loop explicitly.

I'd tend to say that if you want something generic, it's probably best to go with a while loop now and implement computational improvements down the road.

Adrien

@ncourty (Collaborator) commented Nov 17, 2020

Hi there!
I have a torch implementation of the sliced Wasserstein distance that uses the searchsorted feature (originally written by A. Liutkus, and ported to pytorch a few months ago). I can provide it, and we can see which of the two is better (your while loop or the searchsorted implementation).

@rflamary (Collaborator)

Yes please @ncourty, but does your implementation work for general weights (searchsorted suggests yes)? I would like a quantitative comparison (CPU + GPU) before we settle on one implementation.

@ncourty (Collaborator) commented Nov 18, 2020

Yes, it is meant for general weights. There is a chance that, for a single 1D distance, Adrien's while-loop method may be faster than the searchsorted version, but in a sliced Wasserstein method the latter will be much more efficient, as it is vectorized and can work on all projections simultaneously. I'll do some tests.

@AdrienCorenflos (Contributor, Author)

Hi Nicolas,
Please note that my code is (manually) vectorised. The only possible slowdown, in my opinion, is the Python while loop, which doesn't behave well on GPU...
If this were compiled as a TorchScript, there might be ways to tell the compiler that the memory accesses are contiguous, but that exceeds my torch skills :)

@rflamary (Collaborator)

OK great, then let's compare them on simple examples and with large numbers of samples ;)

@ncourty (Collaborator) commented Nov 18, 2020

Hello, I did some tests. You can find the google colab notebook here (with editor privileges):
https://colab.research.google.com/drive/1Bv_JPPfQuMekQjAkAHtowsVQzMm2zwWv?usp=sharing

On the CPU, simple 1D W between 20000 and 100000 random samples:
10 loops, best of 3: 42 ms per loop <-- Searchsorted
1 loop, best of 3: 18.6 s per loop <-- While loop
10 loops, best of 3: 29.2 ms per loop <-- SciPy version

Sanity check for the values:
Torch Searchsorted EMD1D 0.01258190723442535
Torch While EMD1D 0.012581907234425298
Scipy EMD1D 0.012581907234425349

On the GPU, there was a problem with your code, Adrien (see the colab: two tensors not on the same device).
The searchsorted version took 1.19 ms (x38 faster than CPU).

@AdrienCorenflos (Contributor, Author)

Thanks for that!
I modified the notebook a bit to make it fairer (not yet for batched inputs on your side of the code, though), so that the vectorization shows.
Please note that, as mentioned above, your approach only works for the Minkowski cost with p=1, since it reduces to a CDF distance in that special case; it doesn't extend to p != 1, as shown in the notebook with p=2, while my approach does.
Adrien

@AdrienCorenflos (Contributor, Author)

I have some old code that uses this searchsorted approach on the inverse CDF rather than the CDF; I'll iron it out over the weekend if I have some time.

@ncourty (Collaborator) commented Nov 18, 2020

OK, I agree with that: the case of using the quantile function when the weights are not equal in the two distributions is trickier than the implementation with the CDF (I have never seen an efficient implementation so far). I guess that searchsorted can also be used efficiently here, in the same spirit. I'll give it a try as well, so we can compare in the end :)

@rflamary (Collaborator)

This is very interesting!

We will definitely converge toward a very fast implementation. @AdrienCorenflos, what do you mean by batched inputs? Several 1D directions simultaneously? I agree that this is important, especially for implementing sliced down the line.

Also, I think it would be very interesting to think about the dual potentials of the OT solution. Thanks to those dual potentials, we can have subgradients with respect to the weights in my implementation of ot_loss, which is very nice if you want to optimize them (I will add a nice example about that in the torch branch). I think feasible duals can be computed from the sorted solution, but that can wait for the next PR of course.

@AdrienCorenflos (Contributor, Author)

That's what I meant, indeed.

@ncourty (Collaborator) commented Nov 22, 2020

OK, I managed to write a full version with quantile functions, using searchsorted, that accepts any power and is fully vectorized (function emd1D_searchsorted_quantile_vector in the previous notebook).

Comparing with your while loop @AdrienCorenflos, I get the following timings for 100 Wasserstein distances on samples of size 2000 and 1000:

  • On the CPU:
    10 loops, best of 3: 82.4 ms per loop <-- search sorted
    1 loop, best of 3: 447 ms per loop <-- while

  • On the GPU
    100 loops, best of 3: 2.66 ms per loop <-- search sorted
    1 loop, best of 3: 1.32 s per loop <-- while

I do not exactly get why the GPU version is slower on your side; there might be room for improvement. In any case, I'll prepare a PR in the upcoming days including sliced Wasserstein and maybe also sliced Gromov-Wasserstein, with test code and proper documentation (it might take a few days though).

@ncourty (Collaborator) commented Nov 22, 2020

@rflamary I guess the cost function could/should be a parameter of those functions. By this I mean going beyond Minkowski distances and eventually providing a generic function c(x, y) that could be a neural network or any custom cost. I guess this is a design choice for the whole torch branch, with broader impact. This is definitely something we should discuss.

@AdrienCorenflos (Contributor, Author)

Looping over GPU operations is a pretty bad thing to do when you don't need to. There's a good chance the control-flow operations are actually executed on the CPU, so there is undue memory traffic between CPU and GPU, though torch isn't my cup of tea so I can't be 100% sure about this.

I just checked: your implementation doesn't support non-uniform weights. I'll try to correct this during the day.

On a side note, the .unique isn't a good idea: it makes the code non-future-proof, as XLA doesn't support dynamic shape allocation.

@AdrienCorenflos (Contributor, Author) commented Nov 23, 2020

@ncourty
I've put my working version in the Colab notebook.
There's a border effect which is not taken into account (probably an undue shift or something; I need to check against the data how I build my quantile function), but it's close. I'll get back to it later.

@rflamary (Collaborator)

@ncourty, w.r.t. the cost function: since we are in 1D, I think a reasonable parameter is p for |x-y|^p with p >= 1.
There is no closed form for non-convex costs to my knowledge, and this is fine, or at least this is what we implemented in the CPU version.

@AdrienCorenflos
Copy link
Contributor Author

AdrienCorenflos commented Nov 23, 2020

Finished it, see the colab.

def emd1D_searchsorted_quantile_adrien(u_values, v_values, u_weights, v_weights, p=1):
    u_values_sorted, u_sorter = torch.sort(u_values)
    v_values_sorted, v_sorter = torch.sort(v_values)

    zero = torch.zeros(1, dtype=u_values.dtype, device=u_values.device)

    u_weights = u_weights[u_sorter]
    u_cdf = torch.cumsum(u_weights, 0)

    v_weights = v_weights[v_sorter]
    v_cdf = torch.cumsum(v_weights, 0)

    cdf_axis = torch.cat((zero, torch.sort(torch.cat((u_cdf, v_cdf)))[0]))

    u_index = torch.searchsorted(u_cdf, cdf_axis[1:])
    v_index = torch.searchsorted(v_cdf, cdf_axis[1:])

    u_icdf = u_values_sorted[u_index]
    v_icdf = v_values_sorted[v_index]

    delta = cdf_axis[1:] - cdf_axis[:-1]

    return torch.sum(delta * torch.pow(torch.abs(u_icdf - v_icdf), p))

It seems to be working fine on a simple example. It will need some ironing out (default arguments and such) and proper testing (with duplicates in weights and locations), but otherwise it looks pretty fine.

@AdrienCorenflos (Contributor, Author) commented Nov 23, 2020

Seems fine even with duplicates

@AdrienCorenflos (Contributor, Author)

How do you guys want to proceed re the PR?

@ncourty (Collaborator) commented Nov 23, 2020

Thank you @AdrienCorenflos, but there are some elements I do not get: it seems that this version of the function is not adapted to handle several distributions at the same time, and it also seems to be based on the CDF rather than on the quantile functions (which would render the computation for other powers p invalid?). Besides, I do not see much difference with the emd1D_searchsortedfirst function that I wrote, apart from the fact that for p=2 I take the square root of the expression (W_2) instead of (W_2^2). Also, I just realized that there is some discrepancy between the notebook I'm working on and the one I referenced before (yes, colab is not the best option to share code; I need to take a closer look to see what's going on).

@AdrienCorenflos (Contributor, Author)

Several distributions at the same time:

agreed, I still need to (trivially) vectorise it

based on the CDF rather than the quantile functions

That's not true; if anything, I have tested different powers. What I do is indeed build the CDF for u and v, and then compute the pseudo-inverse of the CDF (so basically the quantile function) on a common grid for u and v, so that we can subtract them.

Besides I do not see much the difference with the emd1D_searchsortedfirst function that I wrote

If you check what you are integrating against what: you are integrating the difference of the CDFs against the locations of the particles, while I am integrating the difference between the locations against the cumulative weights. To make it concrete, you can look at the difference in how you define deltas vs how I do.
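
In symbols: your version computes \sum_k (x_{k+1} - x_k) |F_u(x_k) - F_v(x_k)|^p over the merged support, whereas mine computes \sum_k (q_{k+1} - q_k) |F_u^{-1}(q_{k+1}) - F_v^{-1}(q_{k+1})|^p over the merged CDF values, and the two only coincide for p = 1.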

@ncourty (Collaborator) commented Nov 24, 2020

OK, thanks @AdrienCorenflos, I get the differences now. Here is the (not so trivial) vectorized version:

def emd1D_searchsorted_quantile_vector(U_values, V_values, u_weights=None, v_weights=None,p=1, require_sort=True):
    # sort the two distributions 
    if require_sort:
      U_values_sorted, U_sorter = torch.sort(U_values,dim=1)
      V_values_sorted, V_sorter = torch.sort(V_values,dim=1)
    else:
      U_values_sorted = U_values
      V_values_sorted = V_values

    # build cdfs
    if u_weights is None:
        U_cdf = (torch.arange(1,U_values.size(1)+1).double().to(U_values.device) / U_values.size(1)).repeat(U_values.size(0),1)  
    else:
      U_weights_repeat = u_weights.repeat(U_values.size(0),1)
      if require_sort:
        sorted_U_weights_repeat = U_weights_repeat.gather(1,U_sorter)
        U_cdf = torch.cumsum(sorted_U_weights_repeat,dim=1)
      else:
        U_cdf = torch.cumsum(U_weights_repeat,dim=1)

    if v_weights is None:
        V_cdf = (torch.arange(1,V_values.size(1)+1).double().to(U_values.device) / V_values.size(1)).repeat(V_values.size(0),1)  
    else:
      V_weights_repeat = v_weights.repeat(V_values.size(0),1)
      if require_sort:
        sorted_V_weights_repeat = V_weights_repeat.gather(1,V_sorter)
        V_cdf = torch.cumsum(sorted_V_weights_repeat,dim=1)
      else:
        V_cdf = torch.cumsum(V_weights_repeat,dim=1)

    all_cdf_values,_ = torch.sort(torch.cat((U_cdf, V_cdf),dim=1),dim=1)

    all_weight_values = torch.cat((torch.zeros(U_values.size(0),1).to(U_values.device),
                                   all_cdf_values),
                                   dim=1) 

    deltas = all_weight_values[:,1:] - all_weight_values[:,:-1]

    all_weight_values =all_weight_values.contiguous()

    idx_U = torch.searchsorted(U_cdf,all_weight_values[:,1:])
    idx_V = torch.searchsorted(V_cdf,all_weight_values[:,1:])

    U = U_values_sorted.gather(1,idx_U)
    V = V_values_sorted.gather(1,idx_V)

    if p==1:
      return torch.sum(torch.mul(torch.abs(U - V),deltas),dim=1)
    else:
      return torch.sum(torch.mul(torch.pow(torch.abs(U - V),p),deltas),dim=1)

Please note that I assume all the 1D distributions share the same weight vector. I tested against various sizes of distributions with uniform and random weights. I scratched my head (for far too long) over a weird bug: when the weight vector is cast as .double() instead of .float(), it seems that searchsorted from pytorch behaves differently. We will have to make that clear in the doc.
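
One way to sidestep that dtype sensitivity (just a sketch, not the final PR code; the helper name is made up) would be to force the CDF and the query values into a single dtype before calling searchsorted:

import torch

def _searchsorted_same_dtype(sorted_cdf, queries):
    # Defensive cast: keep the boundaries (the CDF) and the query values
    # in one dtype before calling torch.searchsorted.
    return torch.searchsorted(sorted_cdf, queries.to(sorted_cdf.dtype))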

@AdrienCorenflos (Contributor, Author)

This is what I had in mind in terms of vectorisation.

Using the Python ellipsis makes everything way easier. I haven't checked the contiguity of all the tensors, but torch doesn't complain, so I guess it's all good (I was worried about cdf_axis[..., 1:], which breaks contiguity, but I guess it's fine thanks to the diff).

Let me know if you have worries.

def emd1D_searchsorted_quantile_adrien(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True):
    
    n = u_values.shape[-1]
    m = v_values.shape[-1]

    device = u_values.device
    dtype = u_values.dtype

    if require_sort:
        u_values, u_sorter = torch.sort(u_values, -1)
        v_values, v_sorter = torch.sort(v_values, -1)
    else:
        u_sorter = v_sorter = slice(None)

    zero = torch.zeros(1, dtype=dtype, device=device)

    if u_weights is None:
        u_weights = torch.full(n, 1/n, dtype=dtype, device=device)

    if v_weights is None:
        v_weights = torch.full(m, 1/m, dtype=dtype, device=device)

    u_weights = u_weights[..., u_sorter]
    u_cdf = torch.cumsum(u_weights, -1)

    v_weights = v_weights[..., v_sorter]
    v_cdf = torch.cumsum(v_weights, -1)

    cdf_axis, _ = torch.sort(torch.cat((u_cdf, v_cdf), -1), -1)
    
    u_index = torch.searchsorted(u_cdf, cdf_axis)
    v_index = torch.searchsorted(v_cdf, cdf_axis)

    u_icdf = torch.gather(u_values, -1, u_index.clip(0, n-1))
    v_icdf = torch.gather(v_values, -1, v_index.clip(0, m-1))

    cdf_axis = torch.nn.functional.pad(cdf_axis, (1, 0))
    delta = cdf_axis[..., 1:] - cdf_axis[..., :-1]

    if p == 1:
        return torch.sum(delta * torch.abs(u_icdf - v_icdf), axis=-1)
    if p == 2:
        return torch.sum(delta * torch.square(u_icdf - v_icdf), axis=-1)  # I actually don't think this is useful, a lot of core pytorch code uses .pow(2)
    return torch.sum(delta * torch.pow(torch.abs(u_icdf - v_icdf), p), axis=-1)

@AdrienCorenflos (Contributor, Author) commented Nov 24, 2020

Also, please note that the slices allow for higher-dimensional batching for free: for example, this will work:

nb_source = 3
nb_target = 5
d = 2
 
u_s = np.random.randn(2, d, nb_source)
u_t = np.random.randn(2, d, nb_target)

It can be useful if you're trying to do some SW barycenter.
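
For instance, a torch version of that shape example (just a sketch with made-up sizes; weights are passed explicitly and shared across the batch):

import torch

nb_source = 3
nb_target = 5
d = 2

u_s = torch.randn(2, d, nb_source, dtype=torch.float64)
u_t = torch.randn(2, d, nb_target, dtype=torch.float64)
w_s = torch.full((nb_source,), 1 / nb_source, dtype=torch.float64)
w_t = torch.full((nb_target,), 1 / nb_target, dtype=torch.float64)

# One 1D EMD per (batch, dimension) pair, so the output has shape (2, d)
out = emd1D_searchsorted_quantile_adrien(u_s, u_t, w_s, w_t, p=2)
print(out.shape)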

@ncourty (Collaborator) commented Nov 24, 2020

Nice, I see we have two different ways of coding. Do you know why u_index.clip(0, n-1) is necessary? This is precisely the thing that bugged me the whole morning. Also, did you check whether the two versions have similar performance?

@AdrienCorenflos (Contributor, Author) commented Nov 24, 2020

What exactly is tricky, so that I can comment on it further?
I'm going to remove the slice(None); it only works in 1D.

def emd1D_searchsorted_quantile_adrien(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True):
    
    n = u_values.shape[-1]
    m = v_values.shape[-1]

    device = u_values.device
    dtype = u_values.dtype

    if u_weights is None:
        u_weights = torch.full(n, 1/n, dtype=dtype, device=device)

    if v_weights is None:
        v_weights = torch.full(m, 1/m, dtype=dtype, device=device)

    if require_sort:
        u_values, u_sorter = torch.sort(u_values, -1)
        v_values, v_sorter = torch.sort(v_values, -1)

        u_weights = u_weights[..., u_sorter]
        v_weights = v_weights[..., v_sorter]

    zero = torch.zeros(1, dtype=dtype, device=device)
    
    u_cdf = torch.cumsum(u_weights, -1)
    v_cdf = torch.cumsum(v_weights, -1)

    cdf_axis, _ = torch.sort(torch.cat((u_cdf, v_cdf), -1), -1)
    
    u_index = torch.searchsorted(u_cdf, cdf_axis)
    v_index = torch.searchsorted(v_cdf, cdf_axis)

    u_icdf = torch.gather(u_values, -1, u_index.clip(0, n-1))
    v_icdf = torch.gather(v_values, -1, v_index.clip(0, m-1))

    cdf_axis = torch.nn.functional.pad(cdf_axis, (1, 0))
    delta = cdf_axis[..., 1:] - cdf_axis[..., :-1]

    if p == 1:
        return torch.sum(delta * torch.abs(u_icdf - v_icdf), axis=-1)
    if p == 2:
        return torch.sum(delta * torch.square(u_icdf - v_icdf), axis=-1)  # I actually don't think this is useful, a lot of core pytorch code uses .pow(2)
    return torch.sum(delta * torch.pow(torch.abs(u_icdf - v_icdf), p), axis=-1)

looks like this in the end
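
A quick sanity check against scipy for p = 1 (just a sketch with made-up sizes; weights are passed explicitly):

import torch
from scipy.stats import wasserstein_distance

torch.manual_seed(0)
u = torch.randn(2000, dtype=torch.float64)
v = torch.randn(3000, dtype=torch.float64) + 0.5
wu = torch.full((2000,), 1 / 2000, dtype=torch.float64)
wv = torch.full((3000,), 1 / 3000, dtype=torch.float64)

# With uniform weights this should match scipy's W_1 up to float precision.
w1 = emd1D_searchsorted_quantile_adrien(u, v, wu, wv, p=1)
print(float(w1), wasserstein_distance(u.numpy(), v.numpy()))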

@AdrienCorenflos (Contributor, Author)

@rflamary @ncourty should I proceed?

@ncourty (Collaborator) commented Nov 24, 2020

All good on my side! Minor comment: on Google Colab, I had to change

torch.full(n, 1/n, dtype=dtype, device=device)

to

torch.full((n,), 1/n, dtype=dtype, device=device)

I guess it is a torch version thing ?

@rflamary (Collaborator)

Yes, this is great,

you can add this very nice and efficient implementation to the PR. Note that some comments in the function, with a reference to where the algorithm can be found, would be a plus. If Nico had a hard time understanding it, it means it is not easy to understand, so please help the next reader ;).

About 3.5, I agree that since it has reached end of life I should probably remove its support; you can do that in the GitHub action, but then I want it tested on 3.9, which is the most recent release (it should work, but that's usually when we get surprises).

@AdrienCorenflos (Contributor, Author)

Yes, I'm actually going to add an auxiliary def _quantile_function for readability.
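
Something along these lines (just a sketch; the exact signature is not settled yet):

import torch

def _quantile_function(qs, cdf, values):
    # Evaluate the empirical quantile function (pseudo-inverse of the CDF) of a
    # distribution supported on `values` (sorted along the last axis) with
    # cumulative weights `cdf`, at the query probabilities `qs`.
    n = values.shape[-1]
    idx = torch.searchsorted(cdf, qs)
    return torch.gather(values, -1, idx.clip(0, n - 1))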

@AdrienCorenflos (Contributor, Author)

Torch for Python 3.5 doesn't have searchsorted.
Adrien

@rflamary changed the title from "small tentative for EMD 1D [WIP]" to "[WIP] small tentative for EMD 1D in torch" on Nov 25, 2020
@rflamary (Collaborator)

As I said, remove 3.5 from https://github.com/PythonOT/POT/blob/master/.github/workflows/build_tests.yml and add 3.9.

@AdrienCorenflos (Contributor, Author)

Ah, my bad, I thought you were doing it in the other PR.

Remove py3.5 from supported version and add 3.9 instead
@AdrienCorenflos (Contributor, Author) commented Nov 25, 2020

cvxopt doesn't build under py3.9, so I'm removing 3.9.

CVXOPT doesn't support 3.9
@AdrienCorenflos marked this pull request as ready for review on November 25, 2020, 09:29
@AdrienCorenflos (Contributor, Author)

I'll do the SWD a bit later

@ncourty (Collaborator) commented Nov 25, 2020

@AdrienCorenflos I can do the SWD if you want; I already have it coded.

@AdrienCorenflos (Contributor, Author)

Sure

@ncourty (Collaborator) commented Nov 25, 2020

OK! I'll do it once your PR is merged into the torch branch.

@rflamary (Collaborator) left a comment:

This is great, and I see you guys are eager to code new stuff from this.

Here are a few comments before merging:

Review comments on ot/torch/lp.py and test/test_torch.py (resolved).
…ytorch doesn't allow for a bunch of operations to be taken along a certain dimension (searchsorted, etc...)
@AdrienCorenflos (Contributor, Author)

@rflamary do you have further comments?

@rflamary (Collaborator) commented Dec 6, 2020

I will let @ncourty do the last code review; he will be the one using your code for sliced.

@ncourty (Collaborator) left a comment:

Nice work @AdrienCorenflos!

Review comments on ot/torch/lp.py (resolved).
@AdrienCorenflos (Contributor, Author)

For some reason GitHub doesn't send notifications about PR comments, so I only just saw this. Sorry about the delay; it would have taken 5 minutes...

@ncourty merged commit 0ef7362 into PythonOT:torch on Jan 3, 2021.