Skip to content

Commit

Permalink
Generic TopK implementation (#744)
Browse files — browse the repository at this point in the history
* move TopK to generic

* partial genericization of kernel code

* introduce TopKTypeConfig, specialize radix type and conversion for floats

* implement topk for byte tensor

* implement for char tensor

* implement for int tensor, extend test to check indices as well

* works for longs too

* make bitfield set/get a struct, add support for 64-bit types

* extend to double tensor

* implement for half tensor

* asserts; test fix
  • Loading branch information
killeent authored and soumith committed Apr 25, 2017
1 parent 181a869 commit 93a6864
Show file tree
Hide file tree
Showing 11 changed files with 756 additions and 577 deletions.
10 changes: 10 additions & 0 deletions TensorMath.lua
Expand Up @@ -885,6 +885,16 @@ for k, Tensor_ in pairs(handledTypenames) do
{name="boolean", default=0}}
)

-- Registers the "topk" method on the current tensor type; the C entry point
-- name is produced by cname("topk").
-- Argument spec, in order:
--   1. Tensor          (default=true, returned=true)  -- first returned tensor
--   2. CudaLongTensor  (default=true, returned=true, noreadadd=true)
--                                                     -- second returned tensor
--   3. Tensor                                         -- input tensor
--   4. long,    default 1
--   5. index,   default lastdim(3)
--   6. boolean, default 0
--   7. boolean, default 0
-- NOTE(review): presumably (1)/(2) are the top-k values and their indices,
-- (4) is k, (5) is the dimension (defaulting to the last dim of argument 3),
-- and (6)/(7) are direction and "sorted" flags -- confirm against the
-- THCTensor_(topk) declaration in generic/THCTensorTopK.h.
wrap("topk",
cname("topk"),
{{name=Tensor, default=true, returned=true},
{name="CudaLongTensor", default=true, returned=true, noreadadd=true},
{name=Tensor},
{name="long", default=1},
{name="index", default=lastdim(3)},
{name="boolean", default=0},
{name="boolean", default=0}})

wrap("mode",
cname("mode"),
{{name=Tensor, default=true, returned=true, noreadadd=true},
Expand Down
4 changes: 3 additions & 1 deletion lib/THC/CMakeLists.txt
Expand Up @@ -258,7 +258,6 @@ INSTALL(FILES
THCTensorRandom.h
THCTensorMath.h
THCTensorConv.h
THCTensorTopK.h
THCApply.cuh
THCReduce.cuh
THCReduceAll.cuh
Expand Down Expand Up @@ -295,6 +294,7 @@ INSTALL(FILES
THCTensorMathMagma.cuh
THCThrustAllocator.cuh
THCTensorMode.cuh
THCTensorTopK.cuh
DESTINATION "${THC_INSTALL_INCLUDE_SUBDIR}/THC")

INSTALL(FILES
Expand Down Expand Up @@ -341,4 +341,6 @@ INSTALL(FILES
generic/THCTensorRandom.cu
generic/THCTensorMode.h
generic/THCTensorMode.cu
generic/THCTensorTopK.h
generic/THCTensorTopK.cu
DESTINATION "${THC_INSTALL_INCLUDE_SUBDIR}/THC/generic")
1 change: 0 additions & 1 deletion lib/THC/THC.h
Expand Up @@ -15,6 +15,5 @@
#include "THCTensorRandom.h"
#include "THCTensorMath.h"
#include "THCTensorConv.h"
#include "THCTensorTopK.h"

#endif
50 changes: 37 additions & 13 deletions lib/THC/THCAsmUtils.cuh
Expand Up @@ -3,20 +3,44 @@

// Collection of direct PTX functions

// Extracts a bit field from a 32-bit value using the PTX `bfe.u32`
// instruction.
//   val: source word
//   pos: bit position at which the field starts
//   len: number of bits in the field
// Returns the extracted field in the low bits of the result (the .u32 form
// zero-extends, per the PTX ISA).
__device__ __forceinline__
unsigned int getBitfield(unsigned int val, int pos, int len) {
unsigned int ret;
// Operands: %0 = ret (output), %1 = val, %2 = pos, %3 = len; all 32-bit
// registers ("r" constraint).
asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
return ret;
}
// Primary template for bit-field helpers, keyed on the integer word type T.
// It declares no members, so getBitfield/setBitfield are available only
// through the explicit specializations (unsigned int and
// unsigned long long int).
template <typename T>
struct Bitfield {};

// Inserts a bit field into a 32-bit value using the PTX `bfi.b32`
// instruction.
//   val:      word that receives the field
//   toInsert: word whose low `len` bits are inserted
//   pos:      bit position in `val` at which the field starts
//   len:      number of bits to insert
// Returns `val` with the selected field replaced; `val` itself is passed by
// value and is not modified.
__device__ __forceinline__
unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
unsigned int ret;
// Note the PTX operand order: the field source (toInsert) comes before the
// base word (val). %0 = ret, %1 = toInsert, %2 = val, %3 = pos, %4 = len.
asm("bfi.b32 %0, %1, %2, %3, %4;" :
"=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
return ret;
}
// Bit-field helpers for 32-bit words, implemented with the PTX `bfe.u32`
// (extract) and `bfi.b32` (insert) instructions. Both members are
// device-only static functions.
template <>
struct Bitfield<unsigned int> {
// Extracts `len` bits of `val` starting at bit position `pos` and returns
// them in the low bits of the result (the .u32 form zero-extends, per the
// PTX ISA).
static __device__ __forceinline__
unsigned int getBitfield(unsigned int val, int pos, int len) {
unsigned int ret;
// Operands: %0 = ret (output), %1 = val, %2 = pos, %3 = len; all 32-bit
// registers ("r" constraint).
asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
return ret;
}

// Returns `val` with the `len`-bit field starting at bit position `pos`
// replaced by the low `len` bits of `toInsert`. `val` is passed by value and
// is not modified.
static __device__ __forceinline__
unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
unsigned int ret;
// Note the PTX operand order: the field source (toInsert) comes before the
// base word (val). %0 = ret, %1 = toInsert, %2 = val, %3 = pos, %4 = len.
asm("bfi.b32 %0, %1, %2, %3, %4;" :
"=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
return ret;
}
};

// Bit-field helpers for 64-bit words, implemented with the PTX `bfe.u64`
// (extract) and `bfi.b64` (insert) instructions. Both members are
// device-only static functions.
template <>
struct Bitfield<unsigned long long int> {
// Extracts `len` bits of `val` starting at bit position `pos` and returns
// them in the low bits of the result (the .u64 form zero-extends, per the
// PTX ISA).
static __device__ __forceinline__
unsigned long long int getBitfield(unsigned long long int val, int pos, int len) {
unsigned long long int ret;
// The data operands use 64-bit registers ("l" constraint); pos and len stay
// 32-bit ("r"), which is what the 64-bit form of the instruction takes.
asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
return ret;
}

// Returns `val` with the `len`-bit field starting at bit position `pos`
// replaced by the low `len` bits of `toInsert`. `val` is passed by value and
// is not modified.
static __device__ __forceinline__
unsigned long long int setBitfield(unsigned long long int val, unsigned long long int toInsert, int pos, int len) {
unsigned long long int ret;
// Note the PTX operand order: the field source (toInsert) comes before the
// base word (val). %0 = ret, %1 = toInsert, %2 = val (64-bit, "l");
// %3 = pos, %4 = len (32-bit, "r").
asm("bfi.b64 %0, %1, %2, %3, %4;" :
"=l"(ret) : "l"(toInsert), "l"(val), "r"(pos), "r"(len));
return ret;
}
};

__device__ __forceinline__ int getLaneId() {
int laneId;
Expand Down
3 changes: 3 additions & 0 deletions lib/THC/THCTensorMath.h
Expand Up @@ -46,6 +46,9 @@
#include "generic/THCTensorMode.h"
#include "THCGenerateAllTypes.h"

#include "generic/THCTensorTopK.h"
#include "THCGenerateAllTypes.h"

THC_API int THCudaByteTensor_logicalall(THCState *state, THCudaByteTensor *self);
THC_API int THCudaByteTensor_logicalany(THCState *state, THCudaByteTensor *self);

Expand Down

0 comments on commit 93a6864

Please sign in to comment.