Batched knn_graph #16
Yes, and the number of points per example can vary too! We encode this by stacking the points of all examples along the node dimension and storing the example index of each point in a `batch` vector, e.g.:

```python
batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3])
```

For a fixed-size `B x N x D` tensor, flatten the points first:

```python
pos = pos.view(B * N, D)
```
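To make the encoding concrete, here is a minimal sketch in plain PyTorch that builds the `batch` vector both from a list of per-example point counts and from a fixed-size `B x N x D` tensor. The helper names `make_batch_vector` and `flatten_dense` are hypothetical, not part of torch_cluster; the resulting `(pos, batch)` pair is what you would pass to `torch_cluster.knn_graph(pos, k, batch)`. Each example's points stay contiguous, matching the stacking scheme described above.

```python
import torch


def make_batch_vector(counts):
    """Hypothetical helper: label every stacked point with its example index.

    counts: per-example point counts, e.g. [4, 5, 4, 4].
    """
    counts = torch.as_tensor(counts, dtype=torch.long)
    # Example b contributes counts[b] copies of the label b.
    return torch.repeat_interleave(torch.arange(counts.numel()), counts)


def flatten_dense(pos):
    """Hypothetical helper: turn B x N x D into (B*N) x D plus a batch vector."""
    B, N, D = pos.shape
    flat = pos.reshape(B * N, D)
    batch = torch.arange(B, device=pos.device).repeat_interleave(N)
    return flat, batch


# Variable-size examples with 4, 5, 4 and 4 points.
batch = make_batch_vector([4, 5, 4, 4])
print(batch.tolist())
# [0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]

# Fixed-size examples: B=2, N=3, D=2.
flat, dense_batch = flatten_dense(torch.zeros(2, 3, 2))
print(dense_batch.tolist())  # [0, 0, 0, 1, 1, 1]
```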
Thanks, works great :) Still, it took me a while to figure this out, even after seeing the second example. I believe the documentation could be improved with an end-to-end example of the batched version!
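For future users, the code below implements batched kNN on top of `torch_cluster`:

```python
import torch
import torch_cluster


def knn(x, y, k):
    """For every point in x, find its k nearest neighbours in y (per example).

    x: B x N x C
    y: B x M x C
    idx: B x N x k, indices into the M dimension of y
    """
    assert x.size(0) == y.size(0)
    B, N, _ = x.size()
    _, M, _ = y.size()
    # Batch vectors: each point is labelled with the index of its example.
    x_batch = torch.arange(B, dtype=torch.long, device=x.device).repeat_interleave(N)
    y_batch = torch.arange(B, dtype=torch.long, device=y.device).repeat_interleave(M)
    # torch_cluster.knn(a, b, k, batch_a, batch_b) returns, for each point in b,
    # its k nearest points in a: row 0 indexes b, row 1 indexes a.
    # y goes first so that the neighbours of each x point are searched in y.
    idx = torch_cluster.knn(y.reshape(B * M, -1), x.reshape(B * N, -1), k,
                            y_batch, x_batch)
    # Map flat indices into y back to per-example indices (requires M >= k).
    idx = (idx[1] % M).view(B, N, k)
    return idx
```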
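When checking the batched `knn` function on small inputs, a brute-force version can serve as a reference. The sketch below swaps in a different technique (full pairwise distances via `torch.cdist` followed by `topk`), so it needs O(B·N·M) memory and is only meant for testing; `knn_dense_reference` is a hypothetical helper name. It returns neighbours sorted nearest first, while the order within each row of torch_cluster's output may differ, so compare the two as sets.

```python
import torch


def knn_dense_reference(x, y, k):
    """Hypothetical brute-force batched kNN, not part of torch_cluster.

    x: B x N x C queries, y: B x M x C candidates.
    Returns idx: B x N x k, indices into the M dimension of y,
    ordered from nearest to farthest.
    """
    assert x.size(0) == y.size(0) and k <= y.size(1)
    # Pairwise Euclidean distances within each example: B x N x M.
    dist = torch.cdist(x, y)
    # Indices of the k smallest distances along the M dimension.
    return dist.topk(k, dim=-1, largest=False).indices


x = torch.tensor([[[0.0], [10.0]]])         # B=1, N=2, C=1
y = torch.tensor([[[9.0], [1.0], [3.0]]])   # B=1, M=3, C=1
print(knn_dense_reference(x, y, 2).tolist())  # [[[1, 2], [0, 2]]]
```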
I don't quite understand the documentation for the `knn` or `knn_graph` functions. Can these functions be used for batched computation, i.e. finding nearest neighbors between a `B x M x D` and a `B x N x D` tensor, or constructing a `knn_graph` in a `B x N x D` tensor? If so, how? Thanks in advance!
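For the second part of the question, a kNN graph inside each example of a dense `B x N x D` tensor, here is a brute-force sketch in plain PyTorch. It excludes self-matches and then converts the per-example indices into a flat edge list over the stacked `(B*N) x D` points. The helpers `knn_graph_dense` and `to_edge_index` are hypothetical and are not torch_cluster's implementation; the edge-list layout (row 0 = neighbour, row 1 = centre point) is an assumption chosen here, so check it against the `flow` convention of `torch_cluster.knn_graph` before mixing the two.

```python
import torch


def knn_graph_dense(pos, k):
    """Hypothetical brute-force batched kNN graph, not torch_cluster's code.

    pos: B x N x D. Returns idx: B x N x k, where idx[b, i] lists the k nearest
    other points of point i inside example b (self-matches excluded).
    """
    B, N, _ = pos.shape
    assert k < N, "need at least k + 1 points per example"
    dist = torch.cdist(pos, pos)  # B x N x N
    # Exclude self-loops by setting the diagonal to +inf.
    eye = torch.eye(N, dtype=torch.bool, device=pos.device)
    dist = dist.masked_fill(eye, float("inf"))
    return dist.topk(k, dim=-1, largest=False).indices


def to_edge_index(idx):
    """Hypothetical helper: B x N x k local indices -> 2 x (B*N*k) edge list
    over the stacked points. Row 0 = neighbour, row 1 = centre point."""
    B, N, k = idx.shape
    # Shift example b's local indices by b * N to get flat indices.
    offset = (torch.arange(B, device=idx.device) * N).view(B, 1, 1)
    neighbour = (idx + offset).reshape(-1)
    centre = torch.arange(B * N, device=idx.device).repeat_interleave(k)
    return torch.stack([neighbour, centre])


pos = torch.tensor([[[0.0], [1.0], [5.0]],
                    [[0.0], [10.0], [11.0]]])  # B=2, N=3, D=1
idx = knn_graph_dense(pos, k=1)
print(idx.tolist())  # [[[1], [0], [1]], [[1], [2], [1]]]
print(to_edge_index(idx).tolist())
# [[1, 0, 1, 4, 5, 4], [0, 1, 2, 3, 4, 5]]
```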