migrate cuda implementation of take() from TH to ATen (#45430)
Summary: Pull Request resolved: #45430

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D24037297

Pulled By: bdhirsh

fbshipit-source-id: 7c5f2c08e895fb0c25eec1d68c7455e4f2b1c64e
bdhirsh authored and facebook-github-bot committed Oct 2, 2020
1 parent a015ba8 commit 1552a92
Showing 5 changed files with 96 additions and 170 deletions.
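
For context, take() reads from its input as if the input were flattened to one dimension: it gathers the elements at the given int64 indices, the result takes the shape of the index tensor, and negative indices count back from the end. A minimal usage sketch of the migrated operator through the ATen C++ API (not part of this commit; the tensor-factory calls are the standard at:: helpers):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // self viewed as flat data: [0, 1, 2, 3, 4, 5]
  auto self  = at::arange(6, at::device(at::kCUDA).dtype(at::kFloat)).reshape({2, 3});
  // index must be an int64 tensor on the same GPU; -1 wraps to the last element
  auto index = at::tensor({0, 5, -1}, at::device(at::kCUDA).dtype(at::kLong));
  auto out   = at::take(self, index);  // dispatches to take_cuda after this change
  std::cout << out << std::endl;       // 1-D float tensor: 0, 5, 5
  return 0;
}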
138 changes: 0 additions & 138 deletions aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp
@@ -435,144 +435,6 @@ Tensor & _th_index_copy_(Tensor & self, int64_t dim, const Tensor & index, const
}
return self;
}
Tensor & _th_take_out(Tensor & result, const Tensor & self, const Tensor & index) {
// DeviceGuard omitted
auto dispatch_scalar_type = infer_scalar_type(self);

switch (dispatch_scalar_type) {
case ScalarType::Bool: {
auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CUDA, ScalarType::Long);
THCudaBoolTensor_take(globalContext().getTHCState(), result_, self_, index_);
break;
}
case ScalarType::Byte: {
auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CUDA, ScalarType::Long);
THCudaByteTensor_take(globalContext().getTHCState(), result_, self_, index_);
break;
}
case ScalarType::Char: {
auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CUDA, ScalarType::Long);
THCudaCharTensor_take(globalContext().getTHCState(), result_, self_, index_);
break;
}
case ScalarType::Double: {
auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CUDA, ScalarType::Long);
THCudaDoubleTensor_take(globalContext().getTHCState(), result_, self_, index_);
break;
}
case ScalarType::Float: {
auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CUDA, ScalarType::Long);
THCudaTensor_take(globalContext().getTHCState(), result_, self_, index_);
break;
}
case ScalarType::Int: {
auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CUDA, ScalarType::Long);
THCudaIntTensor_take(globalContext().getTHCState(), result_, self_, index_);
break;
}
case ScalarType::Long: {
auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CUDA, ScalarType::Long);
THCudaLongTensor_take(globalContext().getTHCState(), result_, self_, index_);
break;
}
case ScalarType::Short: {
auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CUDA, ScalarType::Long);
THCudaShortTensor_take(globalContext().getTHCState(), result_, self_, index_);
break;
}
case ScalarType::Half: {
auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CUDA, ScalarType::Long);
THCudaHalfTensor_take(globalContext().getTHCState(), result_, self_, index_);
break;
}
default:
AT_ERROR("_th_take_out not supported on CUDAType for ", dispatch_scalar_type);
}
return result;
}
Tensor _th_take(const Tensor & self, const Tensor & index) {
// DeviceGuard omitted
auto dispatch_scalar_type = infer_scalar_type(self);
auto result_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
auto result = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(result_));
switch (dispatch_scalar_type) {
case ScalarType::Bool: {
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CUDA, dispatch_scalar_type);
auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CUDA, ScalarType::Long);
THCudaBoolTensor_take(globalContext().getTHCState(), result_, self_, index_);
break;
}
case ScalarType::Byte: {
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CUDA, dispatch_scalar_type);
auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CUDA, ScalarType::Long);
THCudaByteTensor_take(globalContext().getTHCState(), result_, self_, index_);
break;
}
case ScalarType::Char: {
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CUDA, dispatch_scalar_type);
auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CUDA, ScalarType::Long);
THCudaCharTensor_take(globalContext().getTHCState(), result_, self_, index_);
break;
}
case ScalarType::Double: {
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CUDA, dispatch_scalar_type);
auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CUDA, ScalarType::Long);
THCudaDoubleTensor_take(globalContext().getTHCState(), result_, self_, index_);
break;
}
case ScalarType::Float: {
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CUDA, dispatch_scalar_type);
auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CUDA, ScalarType::Long);
THCudaTensor_take(globalContext().getTHCState(), result_, self_, index_);
break;
}
case ScalarType::Int: {
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CUDA, dispatch_scalar_type);
auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CUDA, ScalarType::Long);
THCudaIntTensor_take(globalContext().getTHCState(), result_, self_, index_);
break;
}
case ScalarType::Long: {
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CUDA, dispatch_scalar_type);
auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CUDA, ScalarType::Long);
THCudaLongTensor_take(globalContext().getTHCState(), result_, self_, index_);
break;
}
case ScalarType::Short: {
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CUDA, dispatch_scalar_type);
auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CUDA, ScalarType::Long);
THCudaShortTensor_take(globalContext().getTHCState(), result_, self_, index_);
break;
}
case ScalarType::Half: {
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CUDA, dispatch_scalar_type);
auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CUDA, ScalarType::Long);
THCudaHalfTensor_take(globalContext().getTHCState(), result_, self_, index_);
break;
}
default:
AT_ERROR("_th_take not supported on CUDAType for ", dispatch_scalar_type);
}
return result;
}
Tensor & _th_put_(Tensor & self, const Tensor & index, const Tensor & source, bool accumulate) {
// DeviceGuard omitted
auto dispatch_scalar_type = infer_scalar_type(self);
95 changes: 94 additions & 1 deletion aten/src/ATen/native/cuda/IndexKernel.cu
@@ -4,16 +4,67 @@
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/core/Array.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/ExpandUtils.h>
#include <THC/THCTensorInfo.cuh>

namespace at { namespace native {

static constexpr int launch_bound2 = 4;

static constexpr int launch_size_nd = 128;

template <int Dims, typename T, typename IndexType>
__device__ __forceinline__ IndexType indexToOffset(
    const cuda::detail::TensorInfo<T, IndexType>& info,
    int64_t index,
    IndexType size) {
  IndexType linearIndex = static_cast<IndexType>(index);
  CUDA_KERNEL_ASSERT(linearIndex < size && linearIndex >= -size);
  if (linearIndex < 0) {
    linearIndex += size;
  }
  return cuda::detail::IndexToOffset<T, IndexType, Dims>::get(linearIndex, info);
}

template<typename IndexType, typename T>
void dispatchTakePutImpl(const Tensor& input, Tensor& output, const Tensor& index) {
  auto inputInfo = cuda::detail::getTensorInfo<T, IndexType>(input);
  inputInfo.collapseDims();
  auto numel = input.numel();
  if (inputInfo.isContiguous()) {
    cuda::CUDA_tensor_apply2<T, int64_t>(
        output,
        index,
        [inputInfo, numel] __device__ (
            T & out, const int64_t& idx) {
          auto offset = indexToOffset<-2, T, IndexType>(inputInfo, idx, numel);
          out = inputInfo.data[offset];
        });
  } else {
    cuda::CUDA_tensor_apply2<T, int64_t>(
        output,
        index,
        [inputInfo, numel] __device__ (
            T & out, const int64_t& idx) {
          auto offset = indexToOffset<-1, T, IndexType>(inputInfo, idx, numel);
          out = inputInfo.data[offset];
        });
  }
}

template<typename T>
void dispatchTakePut(const Tensor& input, Tensor& output, const Tensor& index) {
  if (cuda::detail::canUse32BitIndexMath(input)) {
    dispatchTakePutImpl<int32_t, T>(input, output, index);
  } else {
    dispatchTakePutImpl<int64_t, T>(input, output, index);
  }
}

template<int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, launch_bound2)
__global__ void index_elementwise_kernel(int N, func_t f) {
@@ -154,6 +205,48 @@ Tensor & masked_select_out_cuda(Tensor & result, const Tensor & self, const Tens
  return masked_select_out_cuda_impl(result, self, mask);
}

void take_out_cuda_template(Tensor& output, const Tensor& input, const Tensor& index) {
  TORCH_CHECK(output.device().type() == at::kCUDA, "device type of output (", output.device().type(), ") is not GPU");
  TORCH_CHECK(input.device().type() == at::kCUDA, "device type of input (", input.device().type(), ") is not GPU");
  TORCH_CHECK(index.device().type() == at::kCUDA, "device type of index (", index.device().type(), ") is not GPU");

  TORCH_CHECK(output.layout() == Layout::Strided, "take() only supports strided layout, got layout: ", output.layout(), " on output tensor");
  TORCH_CHECK(input.layout() == Layout::Strided, "take() only supports strided layout, got layout: ", input.layout(), " on input tensor");
  TORCH_CHECK(index.layout() == Layout::Strided, "take() only supports strided layout, got layout: ", index.layout(), " on index tensor");

  TORCH_CHECK(output.scalar_type() == input.scalar_type(),
              "output and input scalar type must match. but got different types: ", output.scalar_type(), " and ", input.scalar_type());
  TORCH_CHECK(index.scalar_type() == kLong, "index must be an int64 tensor");

  TensorArg output_arg{ output, "output", 1 };
  TensorArg input_arg{ input, "input", 2 };
  TensorArg index_arg{ index, "index", 3 };
  checkAllSameGPU("take", {output_arg, input_arg, index_arg});

  TORCH_CHECK(input.dim() < MAX_CUTORCH_DIMS, CUTORCH_DIM_WARNING);
  TORCH_CHECK(output.dim() < MAX_CUTORCH_DIMS, CUTORCH_DIM_WARNING);
  TORCH_CHECK(index.dim() < MAX_CUTORCH_DIMS, CUTORCH_DIM_WARNING);

  TORCH_CHECK(!(input.numel() == 0 && index.numel() != 0), "tried to take from an empty tensor");

  output.resize_(index.sizes());

  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::Half, input.scalar_type(), "take_cuda", [&] {
    dispatchTakePut<scalar_t>(input, output, index);
  });
}

Tensor take_cuda(const Tensor& self, const Tensor& index) {
  auto out = at::empty(index.sizes(), self.options());
  take_out_cuda_template(out, self, index);
  return out;
}

Tensor& take_out_cuda(Tensor& out, const Tensor& self, const Tensor& index) {
  take_out_cuda_template(out, self, index);
  return out;
}

REGISTER_DISPATCH(index_stub, &index_kernel);
REGISTER_DISPATCH(index_put_stub, &index_put_kernel);

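The heart of the new CUDA path is the indexToOffset helper above: every element of index must lie in [-numel, numel), negative values wrap around from the end, and CUDA_KERNEL_ASSERT traps out-of-range values on the device; dispatchTakePut then chooses 32-bit or 64-bit index arithmetic depending on canUse32BitIndexMath. A host-side restatement of the wrapping rule (an editorial sketch, not part of the diff):

#include <cassert>
#include <cstdint>

// Mirrors, on the host, the per-element bounds check and wrap-around that the
// device-side indexToOffset() helper performs with CUDA_KERNEL_ASSERT.
int64_t wrap_take_index(int64_t idx, int64_t numel) {
  assert(idx < numel && idx >= -numel);
  return idx < 0 ? idx + numel : idx;
}

int main() {
  assert(wrap_take_index(3, 6) == 3);
  assert(wrap_take_index(-1, 6) == 5);  // -1 refers to the last element
  // wrap_take_index(6, 6) would fail the assert, as the kernel does on device.
  return 0;
}

Note that only take() moves off TH here: _th_put_ in LegacyTHFunctionsCUDA.cpp and TensorPutOp in THCTensorIndex.cu remain on the legacy path in this commit.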
4 changes: 2 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -5503,14 +5503,14 @@
 - func: take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
     CPU: take_out_cpu
-    CUDA: legacy::cuda::_th_take_out
+    CUDA: take_out_cuda

 - func: take(Tensor self, Tensor index) -> Tensor
   use_c10_dispatcher: full
   variants: method, function
   dispatch:
     CPU: take_cpu
-    CUDA: legacy::cuda::_th_take
+    CUDA: take_cuda

 - func: take_backward(Tensor grad, Tensor input, Tensor index) -> Tensor
   use_c10_dispatcher: full
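These registrations are what the dispatcher consults at runtime: with the CUDA entries now pointing at take_cuda and take_out_cuda, both the functional and the out= variants of the operator reach the new ATen kernels, and the deleted _th_take/_th_take_out shims have no remaining callers. A sketch of exercising both variants through the public API (editorial; the at::take_out(out, self, index) signature is assumed from the generated ATen functions):

#include <ATen/ATen.h>

void take_both_variants() {
  auto opts  = at::device(at::kCUDA).dtype(at::kFloat);
  auto self  = at::arange(10, opts);
  auto index = at::tensor({9, 0, 4}, at::device(at::kCUDA).dtype(at::kLong));

  auto y = at::take(self, index);   // take(Tensor, Tensor)          -> CUDA: take_cuda
  auto out = at::empty({0}, opts);  // take.out resizes out to index.sizes()
  at::take_out(out, self, index);   // take.out(..., Tensor(a!) out) -> CUDA: take_out_cuda
}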
14 changes: 0 additions & 14 deletions aten/src/THC/THCTensorIndex.cu
@@ -218,20 +218,6 @@ struct WrapIndexOp {
  int64_t size;
};

template <typename T, typename IndexType, int Dims>
struct TensorTakeOp {
  TensorTakeOp(TensorInfo<T, IndexType> info, IndexType numel, int64_t*, int64_t*)
    : info(info), numel(numel) {}

  __device__ __forceinline__ void operator()(T* out, int64_t* index) {
    auto offset = indexToOffset<Dims>(info, *index, numel);
    *out = info.data[offset];
  }

  const TensorInfo<T, IndexType> info;
  IndexType numel;
};

template <typename T, typename IndexType, int Dims>
struct TensorPutOp {
  TensorPutOp(TensorInfo<T, IndexType> info, IndexType numel, int64_t*, int64_t*)
15 changes: 0 additions & 15 deletions aten/src/THC/generic/THCTensorIndex.cu
@@ -220,21 +220,6 @@ void THCTensor_(indexCopy)(THCState *state, THCTensor *dst, int dim, THCudaLongT
#undef LARGE_INDEX
}

void THCTensor_(take)(THCState *state, THCTensor *dst, THCTensor *src, THCudaLongTensor *index)
{
  THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, dst, src));
  THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index));

  THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, src) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
  THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, dst) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
  THArgCheck(THCudaLongTensor_nDimensionLegacyNoScalars(state, index) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
  THArgCheck(!(THCTensor_(numel)(state, src) == 0 && THCudaLongTensor_numel(state, index) != 0), 2,
             "tried to take from an empty tensor");

  THCTensor_(resizeNd)(state, dst, index->dim(), THTensor_getSizePtr(index), NULL);
  dispatchTakePut<scalar_t, TensorTakeOp>(state, src, dst, index);
}

static void THCTensor_(sort_indices)(THCState *state, THCudaLongTensor *index, THCTensor *src) {
THCThrustAllocator thrustAlloc(state);

