Is the implementation wrong? #5

Closed
ranery opened this issue Jun 6, 2024 · 2 comments

ranery commented Jun 6, 2024

sorted_indices = torch.argsort(topk_indices)

should be

sorted_indices = torch.sort(topk_indices)

Right?

sramshetty added a commit that referenced this issue Jun 7, 2024
sramshetty (Owner) commented

Nice catch, sorry about that!

I do in fact need the argsort() output for the gather here, but as you mentioned, the gather for x and y should use the sorted indices themselves.

Let me know if this makes sense to you, and thanks for pointing it out.
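For readers following along, here is a minimal sketch of the distinction discussed above. The tensor values and the names `sorted_values`, `order`, and `argsort_order` are made up for illustration and are not taken from the repository. It shows that `torch.sort` returns both the sorted values and the permutation, while `torch.argsort` returns only the permutation.

```python
import torch

# Hypothetical top-k token positions for a batch of 2 sequences (k = 3).
topk_indices = torch.tensor([[5, 1, 3],
                             [2, 7, 0]])

# torch.sort returns a (values, indices) pair:
#   values  -> the token positions themselves, in ascending order
#   indices -> the permutation that puts topk_indices into that order
sorted_values, order = torch.sort(topk_indices, dim=1)

# torch.argsort returns only the permutation (the second element above).
argsort_order = torch.argsort(topk_indices, dim=1)

print(sorted_values)   # tensor([[1, 3, 5], [0, 2, 7]])
print(order)           # tensor([[1, 2, 0], [2, 0, 1]])
print(torch.equal(order, argsort_order))  # True for this tie-free input
```

So the permutation is what reorders tensors that are already aligned with `topk_indices`, while the sorted values are what index back into the original sequence.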

ranery (Author) commented Jun 7, 2024

That makes sense now!

I also used:

sorted_indices, index = torch.sort(topk_indices, dim=1)

so that sorted_indices can be used for the x gather and index can be used for the output gather.
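A rough sketch of how the two results of `torch.sort` might feed the two gathers. The shapes, the `scores` tensor, and the variable names other than `topk_indices`, `sorted_indices`, and `index` are assumptions for illustration. The thread does not show the actual output tensor, so the second gather here simply reorders the top-k scores as a stand-in.

```python
import torch

torch.manual_seed(0)
batch, seq_len, dim, k = 2, 8, 4, 3          # hypothetical sizes

x = torch.randn(batch, seq_len, dim)         # token representations
scores = torch.randn(batch, seq_len)         # hypothetical per-token scores

topk_scores, topk_indices = torch.topk(scores, k, dim=1)

# sorted_indices: selected token positions in ascending (sequence) order
# index: permutation that sorts topk_indices
sorted_indices, index = torch.sort(topk_indices, dim=1)

# Gather the selected tokens from x using the sorted positions themselves.
gather_idx = sorted_indices.unsqueeze(-1).expand(-1, -1, dim)
x_selected = torch.gather(x, 1, gather_idx)             # (batch, k, dim)

# Reorder a tensor aligned with topk_indices using the permutation.
scores_selected = torch.gather(topk_scores, 1, index)   # (batch, k)
```

With this arrangement, row `j` of `x_selected` and entry `j` of `scores_selected` refer to the same token, and the tokens keep their original sequence order.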
