
top-k implementation + sort works for all cases now #296

Merged
merged 1 commit into from Jan 19, 2016

Conversation

wickedfoo
Contributor

This pull request contains:
- an implementation of values, indices = tensor:topk(k, dim, dir, sorted) for CudaTensor via efficient radix selection;
- a fix so that sort works for all possible inputs, regardless of size, stride, transpositions, holes, whatever;
- better testing for sort, and more refactoring so that we can eventually use test tensors with holes/transpositions/multiple dimensions in all tests.

The top-k elements along the slices defined by dimension dim are found and returned using radix selection. A radix size of 4 is used here; it was faster than radix-2 or radix-16 in my tests. In general, a radix of 2^k is possible whenever k evenly divides the total number of bits in the value being selected (so k = 1, 2, 4, 8, 16, ...), but each pass maintains 2^k buckets, so radix sizes of 2^1, 2^2 and 2^4 are the only really feasible ones.

Radix selection is possible for floating-point numbers via bit manipulation: if float32 f1 < f2, then convert(f1) < convert(f2), where convert(f) produces an unsigned int and unsigned integer comparison is used. Radix selection for top-k requires multiple passes over each input slice, but no scratch space (beyond registers or shared memory) is used in the process.

By default, the selected elements are returned unsorted. If sorting is desired, a separate pass sorts the key/value pairs in place.

Because topk() uses the sorting mechanism, and I wanted a completely in-place implementation, I refactored the sort() code to sort a key/value tensor pair. Even though this implementation performs an additional read and uses slightly more shared memory, it is slightly faster, about the same, or slightly slower than the old sort implementation depending on size; not enough either way (+/- 3% or so) to be of real concern. The old values, indices = t:sort(...) API now uses this one (key, value) sort implementation, so I don't need two implementations.

The old sort code was broken in many ways: it failed on large numbers of slices of size > 2048, failed on non-contiguous index/sorted-value result writes (the result tensors can be non-contiguous as passed by torch.sort(resval, resind, ...)), and had other failure modes (tensor collapsing was broken, etc.). I now have a backup implementation that uses Thrust to perform a segmented vectorized sort as the fallback case, and it handles all possible sizes/dimensions/strides. In cases where the sorting dimension is not innermost-contiguous, or where any of the tensors (input, resval or resind) is not contiguous, memory allocations and copies are performed. What is important is that sort() now works for all possible inputs and no longer has any problems. The sort is very fast when the slice to sort is <= 2048 elements, since all work is done in shared memory, and is reasonably fast for the Thrust backup when no memory is allocated.

Added new sort and top-k tests that cover all manner of sizes and non-contiguous cases. The sort and top-k tests are slow, but there are lots of corner cases in the implementation (slices with size <= 2048 vs. > 2048, etc.). The cutorch test suite as a whole needs to become a lot more bulletproof and test many dimension sizes and strides, since with all of the template specialization in the cutorch code there are many corner cases. Unfortunately, testing many of them in a randomized fashion will require test.lua to take longer to run.

The sort and top-k code works for compute capability 2.0+ (the only interesting intrinsic I use is __ballot(), for prefix sums of binary flags). I think CC 2.0+ is our minimum for cutorch anyway (1.x is ancient).

I did add an ldg feature that uses the __ldg() intrinsic for compute capability 3.5+ code, but falls back to a normal dereference for compute capability < 3.5.

@wickedfoo
Contributor Author

This fixes:

#268
#291

default: \
assert(false); \
} \
#define HANDLE_SORT_CASE(TYPE, A) \
Contributor

Consider using boo... err never mind

Contributor Author

yeah i know what you were going to say, but nope :)

Contributor Author

if it were C++11, constexpr and some template magic could do this macro expansion crap, but as a series of successive if statements rather than a switch.

Contributor Author

but that code would be more difficult to write and less readable than the macro crap i think

@soumith
Member

soumith commented Jan 11, 2016

@nicolasvasilache does this look good to you to merge?

@nicolasvasilache
Contributor

@soumith yeah looks good and I like the random unit tests for corner cases.
Good to go from my POV.

@wickedfoo
Contributor Author

ping @soumith

soumith added a commit that referenced this pull request Jan 19, 2016
top-k implementation + sort works for all cases now
@soumith merged commit ba67a5b into torch:master Jan 19, 2016
@soumith
Member

soumith commented Jan 19, 2016

Sorry for the delay.

@wickedfoo deleted the topk-sort branch March 14, 2016