
[pytorch] [feature request] Pairwise distances between all points in a set (a true pdist) #9406

Closed
vadimkantorov opened this Issue Jul 13, 2018 · 25 comments


vadimkantorov commented Jul 13, 2018

Currently F.pairwise_distance and F.cosine_similarity accept two sets of vectors of the same size and compute similarity between corresponding vectors.

However, it's often useful to compute pairwise similarities or distances between all points of a single set (in mini-batch metric learning scenarios), or between all possible pairs across two sets (e.g. to build a bipartite weighted graph). It's trivial to write ad-hoc functions for this, but I believe it's also a useful and frequent primitive (for Euclidean, squared Euclidean, etc.).

One way to introduce this would be to modify pairwise_distance to allow other shapes and to add an argument for the metric type.

Such a function for example exists in SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.pdist.html (computes all pairwise distances within one set of vectors)

It also exists in MATLAB: https://www.mathworks.com/help/stats/pdist.html
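
For reference, the kind of ad-hoc helper alluded to above, as a minimal sketch (plain Euclidean only; the function name is just for illustration, and it materializes an n x m x d intermediate):

import torch

def all_pairs_euclidean(A, B=None):
    # A: n x d, B: m x d (defaults to A) -> n x m matrix of Euclidean distances
    B = A if B is None else B
    diff = A[:, None, :] - B[None, :, :]  # n x m x d via broadcasting
    return diff.norm(p=2, dim=2)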

ezyang commented Jul 13, 2018

@lena-kashtelyan has graciously agreed to implement this. Notes on how to implement this:

  1. Decide how you want the API to look. You won't be able to add extra arguments to pairwise_distance to directly add support for this, since pdist should take a single tensor, not a pair of tensors. SciPy's pdist function may not be a bad one to emulate, but many of the metrics it supports are not implemented in fused form by PyTorch, so getting support for all of the metric types is probably beyond a bootcamp task. pairwise_distance only supports p-norms, so that's a decent place to start.
  2. Write an implementation of pdist. You may choose to look at an existing implementation (e.g. in SciPy) to get algorithmic ideas, or maybe first do a simple version that just calls the underlying norm n choose 2 times (a naive sketch of that baseline follows after this list). There is plenty of opportunity for parallelism, which can give you a nontrivial speedup.
  3. Make your implementation available in PyTorch by adding it as a native function. Grep for pairwise_distance to see how it's done; also see the documentation at https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native
  4. Now do it again for CUDA!
  5. If time allows, consider generalizing your code to support other metrics.
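
A naive baseline along the lines of step 2, as a sketch (it calls the norm n-choose-2 times, so it is slow but easy to check against):

import torch

def naive_pdist(x, p=2):
    # x: n x d -> condensed vector of length n * (n - 1) / 2
    n = x.size(0)
    out = []
    for i in range(n):
        for j in range(i + 1, n):
            out.append((x[i] - x[j]).norm(p=p))
    return torch.stack(out) if out else x.new_empty(0)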
vadimkantorov commented Jul 13, 2018

@ezyang pdist for two tensors is also often useful, e.g. when matching sets of local descriptors in two images.

If pairwise_distance is extended, could it be made something like the following?

def pairwise_distance(A, B=None):
   B = B if B is not None else A  # or better, treat the cases separately to avoid duplicating computation when only one tensor is passed
   ...
hglaude commented Jul 30, 2018

Hi,

It would be great if this function "new_distance_func" could take a list of pairs of indices "pairs" such that

pairs = stack((indices1, indices2), dim=1)
distances = new_distance_func(matrix1, matrix2, pairs)

is equivalent to

expanded_matrix1 = matrix1.index_select(0, indices1)
expanded_matrix2 = matrix2.index_select(0, indices2)
distances = old_pairwise_distance_func(expanded_matrix1, expanded_matrix2)

but much more memory-efficient. That would also suit the use case above.
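
A minimal sketch of what such a function could look like (the name pairwise_distance_by_pairs is hypothetical, and only p-norms are assumed); it gathers rows in chunks so the fully expanded matrices never need to be materialized at once:

import torch

def pairwise_distance_by_pairs(matrix1, matrix2, pairs, p=2, chunk_size=4096):
    # pairs: LongTensor of shape (P, 2); pairs[k] = (i, j) compares matrix1[i] with matrix2[j]
    out = torch.empty(pairs.size(0), dtype=matrix1.dtype, device=matrix1.device)
    for start in range(0, pairs.size(0), chunk_size):
        chunk = pairs[start:start + chunk_size]
        rows1 = matrix1.index_select(0, chunk[:, 0])
        rows2 = matrix2.index_select(0, chunk[:, 1])
        out[start:start + chunk.size(0)] = (rows1 - rows2).norm(p=p, dim=1)
    return out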

Thank you

vadimkantorov commented Jul 31, 2018

One thing to keep in mind: if pdist = lambda A, B: torch.sqrt(A.pow(2).sum(1, keepdim=True) - 2 * torch.mm(A, B.t()) + B.pow(2).sum(1, keepdim=True).t()) is used for Euclidean distance, the values under the square root can go slightly negative due to floating-point error, so clamping may be required (which is unfortunately not free if done from Python).
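
For reference, a clamped version of that expansion as a sketch (the helper names are just for illustration):

import torch

def squared_euclidean_cdist(A, B):
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, computed without broadcasting A - B
    sq = A.pow(2).sum(1, keepdim=True) - 2 * A.mm(B.t()) + B.pow(2).sum(1, keepdim=True).t()
    return sq.clamp(min=0)  # clamp away small negatives caused by floating-point error

def euclidean_cdist(A, B):
    return squared_euclidean_cdist(A, B).sqrt()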

vadimkantorov commented Aug 30, 2018

Worked on in #10782 and #11102

erikbrinkman commented Sep 10, 2018

I'm not sure I understand all of the discussion about what an appropriate function should look like, but I just submitted a bunch of diffs that essentially add a native pdist implementation for arbitrary p-norm distances. It shouldn't be too difficult to add any other set of distances, but we opted out of it because it was unclear how best to have a native function handle the switch. As of writing, native functions don't support enums, and the pythonic way to do this would be strings, which are supported but may not be the preferred method here.

For the case of all possible combinations between two sets of vectors, that can be achieved with pairwise_distance:

x, y = torch.randn(...), torch.randn(...)
torch.pairwise_distance(x[:, None], y)

Assuming pairwise handles broadcasting appropriately, then this should be about as efficient as one could hope.
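
For reference, an equivalent that does not rely on pairwise_distance broadcasting, written out with arbitrary example shapes (note it materializes an N x M x D intermediate):

import torch

x = torch.randn(5, 3)  # N x D
y = torch.randn(7, 3)  # M x D
dists = torch.norm(x[:, None] - y, dim=2, p=2)  # N x M matrix of pairwise distances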

fmassa commented Sep 10, 2018

@erikbrinkman about enums in C++ and strings in Python, have a look at how this is handled in the losses.

erikbrinkman commented Sep 10, 2018

@fmassa Yeah, that's possible. I personally really dislike using ints as enums in C++. It seems really hacky / illegible to have at::norm(Tensor& self, int64_t type), but I acknowledge that is an option.

vadimkantorov commented Sep 27, 2018

Even the docs are in :) #12126

Closing this

vadimkantorov commented Sep 27, 2018

Some questions on torch.pdist:

  1. Does it return the results in a "condensed distance matrix" format? Maybe the docs could be a little clearer about it, e.g. with a "Returns:" section.
  2. Should there be an argument to return the full format directly? In SciPy there is a function squareform for this.
  3. Does it support batch mode, i.e. work for the setting B x N x D -> B x N x N (or B x (N*(N-1)/2) for condensed)?
  4. What about the existing pairwise_distance and cosine_similarity? Should the relation between pairwise_distance and pdist be explained more clearly in the docs?

@erikbrinkman

@vadimkantorov vadimkantorov reopened this Sep 27, 2018

erikbrinkman commented Sep 28, 2018

@vadimkantorov

  1. It returns it in a condensed format. The Python documentation, i.e. help(torch.pdist), states:
Computes the p-norm distance between every pair of row vectors in the input.
This is identical to the upper triangular portion, excluding the diagonal, of
`torch.norm(input[:, None] - input, dim=2, p=p)`. This function will be faster
if the rows are contiguous.
  2. This is certainly up for discussion, but to me this isn't the most helpful API. As noted in the help, there are pretty easy ways to calculate this with the built-in operators if you want the square form, e.g. torch.pairwise_distance(X[..., None, :, :], X[..., None, :]).

  3. I have been asked about a batch mode, and the implementation should probably be modified to account for it. I should get around to it soonish. However, it should be noted that if you want B x N x D -> B x N x N, the line above will work.

  4. It should probably be explained better. At least, I should probably use pairwise_distance instead of norm in the example. However, part of these pains rest in a murky API. There hasn't really been a decision about how vector distances should be handled more generally. My guess is that the other distances will be added after ATen native functions get enum support, and then there will be more parity in these functions.

vadimkantorov commented Sep 28, 2018

Thanks for the clarifications :)

  1. For me it was not clear what exactly the format of "identical to the upper triangular portion of ..." is. It is a row-order, flattened vector representation I assume, but an explanation of what exactly is in the i-th element would be clearer. SciPy tries to clarify it in their docs, but also not very well IMO:
Returns a condensed distance matrix Y. For each i and j (where i < j < m), where m is the number of original observations, the metric dist(u=X[i], v=X[j]) is computed and stored in entry ij (what is "entry ij"? is that true for the condensed format?).
  2. This way wouldn't work for two matrices with different numbers of rows, would it?

  3. Definitely, for the future, more cohesion between pairwise_distance, pdist, and cosine_similarity would be nice.

erikbrinkman commented Sep 28, 2018

  1. Fair point. What about, for the batched version:

    For an input tensor X with shape B x N x K, returns a condensed distance matrix Y with shape B x D where D = N * (N - 1) / 2. Each element Y[b, k] of the output is equal to pairwise_distance(X[b, i], X[b, j]) where k = N * i - i * (i + 3) / 2 + j - 1. This is equivalent to the row order of the upper triangular portion of the full B x N x N distance matrix.

  2. It will. It's using broadcasting. If the first X is B x N x K and the second X is B x M x K, the resulting tensor will be B x M x N.
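
A quick sanity check of that index mapping for the (non-batched) 2-D case, as a sketch:

import torch

n = 6
x = torch.randn(n, 4)
condensed = torch.pdist(x)                     # length n * (n - 1) / 2
full = torch.norm(x[:, None] - x, dim=2, p=2)  # full n x n distance matrix
for i in range(n):
    for j in range(i + 1, n):
        k = n * i - i * (i + 3) // 2 + j - 1   # condensed index of pair (i, j)
        assert torch.allclose(condensed[k], full[i, j], atol=1e-6)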

erikbrinkman added a commit to erikbrinkman/pytorch that referenced this issue Oct 3, 2018

Add support for batched pdist
Summary: This updates pdist to work for batched inputs, and updates the
documentation to reflect issues raised.

closes pytorch#9406

Test Plan: Batched versions were added to old tests, and the old scipy
test switched to using the internal norm calculation, which also
illustrates alternate ways to compute pdist with norm.


vadimkantorov commented Oct 4, 2018

@erikbrinkman The full matrix format is useful if we want to apply some custom mask or weights to pairs of points (in deep metric learning scenarios). And in any case, utilities for conversion between the full and condensed formats would be very useful.

erikbrinkman commented Oct 4, 2018

@vadimkantorov I'm not sure I actually understand. If you want to weight specific pairs, that should be possible in a condensed format. Either way, if you want the full format then you can use norm as specified in the help text for the batch version. I'm not clear on when you would need to convert to and from the condensed version, but I think the better way to do that is with something like triu / tril that numpy has.

vadimkantorov commented Oct 4, 2018

@erikbrinkman It is, of course, possible to build masks or weights in the condensed format, but debugging in the full format is easier IMO. If I understand correctly, both numpy's and torch's triu/tril still operate on the full format.

The SciPy utility to convert to and from the condensed format is scipy.spatial.distance.squareform (modeled on the original MATLAB one).

E.g. one can have a class-label vector, and compute a weight mask in full format as:

Y = torch.bernoulli(torch.ones(10) * 0.5)  # binary class labels
W_ = torch.ger(Y, Y)                       # outer product: 1 for positive pairs
W = W_ * weight_pos_pair + (1 - W_) * weight_neg_pair  # scalar weights for positive / negative pairs

Or, if every example has many class labels, positive and negative pairs can be computed as torch.mm(Y, Y.t()) > 0.

I mean, when working with small data or small batch sizes, it can be quite convenient to work with the full format despite the inefficiency, and to benefit from other linear algebra ops already implemented (such as torch.ger and torch.mm).

erikbrinkman commented Oct 4, 2018

@vadimkantorov I'm more suggesting that if you want things in the full format, use norm and broadcasting. If you want the condensed version, use pdist. I don't see a reason to specifically offer a conversion.

More generally, all of the operations we've been talking about can easily be accomplished with the appropriate index functions, which don't exist yet but soon might / can be generated somewhat easily.

In your examples:

import numpy as np
import torch

a, b = np.triu_indices(10, 1)          # row/column indices of the strict upper triangle
W_c = Y[a] * Y[b]
assert torch.allclose(W_c, W_[a, b])   # condensed format
W__ = torch.zeros((10, 10))
W__[a, b] = W_c
W__[b, a] = W_c
W__ += torch.diag(Y * Y)               # restore the diagonal, which torch.ger(Y, Y) also fills
assert torch.allclose(W__, W_)         # square format
vadimkantorov commented Oct 4, 2018

@erikbrinkman Well, that's what I did. At some point I got tired of carefully typing out None in the right places and filed this issue :) My point was just about convenience, but at the end of the day it's the PyTorch designers' call.

jacobrgardner commented Dec 19, 2018

Just to add on to the discussion here, we use pairwise distances pretty heavily for kernel computations in GPyTorch, and one of the issues with solutions like:

torch.norm(input[:, None] - input, dim=2, p=p)

or

x, y = torch.randn(...), torch.randn(...)
torch.pairwise_distance(x[:, None], y)

compared to something like pdist is that input[:, None] - input creates an n x n x d intermediate tensor, thus using O(dn^2) memory to compute an n x n matrix. So if my data is 20-dimensional, I effectively store 20 copies of the distance matrix on the way to computing it, which is infeasible with limited GPU memory.

It would be great to have an upstream, numerically stable distance computation option in PyTorch that doesn't use a full factor of d more memory than necessary. In my testing, the built-in pdist is up to 4000x faster than a Python PyTorch implementation of the (squared) distance matrix using the expanded quadratic form.

cc/ @gpleiss @Balandat
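
One workaround until such a primitive lands is to compute the cross-distance matrix in row chunks, so that only a chunk x m x d intermediate is ever live at once. A sketch (the function name and chunk size are arbitrary):

import torch

def chunked_cdist(x1, x2, p=2, chunk_size=256):
    # x1: n x d, x2: m x d -> n x m, without allocating a full n x m x d intermediate
    rows = []
    for start in range(0, x1.size(0), chunk_size):
        chunk = x1[start:start + chunk_size]                      # chunk x d
        rows.append(torch.norm(chunk[:, None] - x2, dim=2, p=p))  # chunk x m
    return torch.cat(rows, dim=0)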

soumith commented Dec 30, 2018

@jacobrgardner batched pdist is now in master. It has a known bug tracked in #15511, but is the API design sufficient?

jacobrgardner commented Dec 30, 2018

@soumith Yup! We use torch.pdist wherever possible on gpytorch master now as of cornellius-gp/gpytorch#440. It's much faster and more numerically stable than the quadratic expansion, and seems to be just as memory-efficient!

At the moment, we still fall back to a manual quadratic expansion of the squared distance in the case where we need distances between pairs of points in two sets x1 and x2, due to the O(dn^2) memory usage of the torch.norm / torch.pairwise_distance solutions mentioned above, but at least for the large n x n training kernel matrices this works well.

vadimkantorov commented Dec 30, 2018

@soumith There are also related issues for cdist: #15253 and #11202 (asking for a cosine-similarity version of pdist/cdist).

I think it'd be nice to have a single interface for various pairwise distance/similarity computations (for instance, squared or shifted L2 distance can be useful as well). I also second @jacobrgardner on the painful materialization of gigantic tensors in the broadcasting+norm solution. One design could be a generic map-reduce kind of function, where the user can provide two reduce functions: one elementwise and one across elements (for dot product this would be multiplication-and-addition, but max-and-addition or multiplication-and-max can be useful sometimes as well). Too bad TensorComprehensions is not in active development anymore - it would be nice to have even a reduced version in core.

Another point is the absence of squareform. Maybe with triu_indices now in master, this could easily be added as well (a sketch follows below).

If the generic map-reduce way is adopted (or more distance methods are supported), some useful distances may be asymmetric, and square-form output should then be supported. One way could be to add a squareform argument to pdist / cdist for brevity.
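
A minimal sketch of what squareform-like conversions could look like on top of torch.triu_indices (hypothetical helper names; symmetric distances and a zero diagonal assumed):

import torch

def condensed_to_square(condensed, n):
    # condensed: vector of length n * (n - 1) / 2 -> symmetric n x n matrix with zero diagonal
    i, j = torch.triu_indices(n, n, offset=1)
    full = condensed.new_zeros(n, n)
    full[i, j] = condensed
    full[j, i] = condensed
    return full

def square_to_condensed(full):
    n = full.size(0)
    i, j = torch.triu_indices(n, n, offset=1)
    return full[i, j]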

vadimkantorov commented Jan 2, 2019

squareform is being worked on in #15679

gchanan added a commit to gchanan/pytorch that referenced this issue Jan 17, 2019

Add support for batched pdist (pytorch#12302)
Summary:
This updates pdist to work for batched inputs, and updates the
documentation to reflect issues raised.

closes pytorch#9406
Pull Request resolved: pytorch#12302

Reviewed By: ezyang

Differential Revision: D13528485

Pulled By: erikbrinkman

fbshipit-source-id: 63d93a6e1cc95b483fb58e9ff021758b341cd4de

josauder commented Mar 12, 2019

Reading the above thread has not made it clear to me what the current best feasible solution for batched pairwise distance is. Could someone who understands the details summarize the thread and provide the current best way to implement the function pairwise_dist as used below?

X = torch.from_numpy(np.random.normal(size=(B, N, D)))
Y = torch.from_numpy(np.random.normal(size=(B, M, D)))
pairwise_dist(X, Y)  # should be B x N x M
erikbrinkman commented Mar 13, 2019

Generally what you'd do is pairwise_dist(X[:, :, None], Y[:, None], axis=3). The documentation doesn't list the ability to add a dimension to pairwise_distance. If that's actually up to date, then you should be able to do

pairwise_dist(
    X[:, :, None].expand(B, N, M, D).reshape((-1, D)),
    Y[:, None].expand(B, N, M, D).reshape((-1, D))
).reshape((B, N, M))

This will do some "unnecessary" memory copies, but should work.
