Added add/mul for the nested [B, *, D] + dense [B, 1, D] case (CUDA-only)
ghstack-source-id: e40878da8d04c71b787a5e19b83e3311a681dd9c
Pull Request resolved: #88289
1 parent: 99c0773. Commit: 5aa8e9e.
Showing 4 changed files with 187 additions and 0 deletions.
aten/src/ATen/native/nested/NestedTensorBinaryOps.h (16 additions, 0 deletions)
@@ -0,0 +1,16 @@
#pragma once

#include <ATen/core/ATen_fwd.h>
#include <ATen/native/DispatchStub.h>

namespace at {
namespace native {

enum class NESTED_DENSE_OP : uint8_t { ADD, MUL };

using nested_dense_elementwise_fn = void (*)(Tensor& result, const Tensor& self, const Tensor& other, const NESTED_DENSE_OP& op);

DECLARE_DISPATCH(nested_dense_elementwise_fn, nested_dense_elementwise_stub);

} // namespace native
} // namespace at
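DECLARE_DISPATCH only declares the stub; the CUDA file below supplies and registers the actual kernel via REGISTER_CUDA_DISPATCH. As a minimal sketch of the other half of the pattern, assuming the usual DispatchStub wiring (the defining .cpp is presumably one of the 4 changed files not shown in this excerpt, and the call site is hypothetical):

// Minimal sketch, assuming the usual DispatchStub wiring.
// The DEFINE_DISPATCH would live in a device-generic .cpp
// (not shown in this excerpt):
DEFINE_DISPATCH(nested_dense_elementwise_stub);

// Hypothetical call site: `self` nested [B, *, D], `other` dense [B, 1, D],
// `result` a preallocated nested tensor. The first argument selects the
// backend kernel registered for that device type.
nested_dense_elementwise_stub(kCUDA, result, self, other, NESTED_DENSE_OP::ADD);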
aten/src/ATen/native/nested/cuda/NestedTensorBinaryOps.cu (120 additions, 0 deletions)
@@ -0,0 +1,120 @@
#include <ATen/native/nested/NestedTensorBinaryOps.h>

#include <type_traits>

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>

#include <c10/cuda/CUDAMathCompat.h>
#include <c10/cuda/CUDAStream.h>

#include <ATen/native/nested/NestedTensorUtils.h>

#define BLOCK_DIM 256

namespace at {
namespace native {

// only for nested [B, *, D] with dense [B, 1, D]
template <typename T, typename func_t>
__global__ void op_dense_esuhm(
    const T* input,
    const T* dense,
    T* output,
    int64_t embedding_dim,
    const int64_t* offsets,
    const func_t& f)
{
  // each batch component is handled by one block
  const int64_t batch_idx = blockIdx.x;
  const int64_t grain_size = blockDim.x;
  const int64_t tid = threadIdx.x;
  const int64_t range = offsets[batch_idx + 1] - offsets[batch_idx];
  // each thread handles (embedding_dim // grain_size + (tid < embedding_dim % grain_size))
  // elems of the dense embedding
  for (int64_t idx = tid; idx < embedding_dim; idx += grain_size) {
    const T dense_elem = dense[batch_idx * embedding_dim + idx];
    for (int64_t nested_idx = idx; nested_idx < range; nested_idx += embedding_dim) {
      output[offsets[batch_idx] + nested_idx] = f(input[offsets[batch_idx] + nested_idx], dense_elem);
    }
  }
}
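To make the indexing concrete, here is an equivalent serial reference (illustration only, not part of the commit); the toy sizes in the comment are hypothetical:

// Serial reference for op_dense_esuhm (illustration only).
// Toy example: B = 2, D = 3, component row counts {2, 4}:
//   offsets = {0, 6, 18}, so component b owns buffer[offsets[b] .. offsets[b+1]).
// Every row of component b gets other[b, 0, :] applied element-wise.
template <typename T, typename func_t>
void op_dense_reference(
    const T* input, const T* dense, T* output,
    int64_t batch_size, int64_t embedding_dim,
    const int64_t* offsets, func_t f) {
  for (int64_t b = 0; b < batch_size; ++b) {
    const int64_t range = offsets[b + 1] - offsets[b];
    for (int64_t i = 0; i < range; ++i) {
      // i % embedding_dim is the column, which picks the dense element
      output[offsets[b] + i] = f(input[offsets[b] + i], dense[b * embedding_dim + i % embedding_dim]);
    }
  }
}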
template <typename T, typename func_t>
void nested_op_dense_kernelLauncher(
    const T* input,  // [sum(*) x embedding_dim]
    const T* dense,  // [batch_size x embedding_dim]
    T* output,       // [sum(*) x embedding_dim]
    int64_t batch_size,
    int64_t embedding_dim,
    const int64_t* input_offsets,  // [batch_size + 1]
    func_t f)
{
  // one block per nested component; BLOCK_DIM threads stride across embedding_dim
  dim3 grid;
  grid.x = batch_size;
  const auto stream = at::cuda::getDefaultCUDAStream();

  op_dense_esuhm<<<grid, BLOCK_DIM, 0, stream>>>(
      input,
      dense,
      output,
      embedding_dim,
      input_offsets,
      f);
  // surface any launch error (standard ATen practice)
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
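The launch configuration in numbers, for the same toy sizes used above (illustrative):

// For batch_size = 2 and BLOCK_DIM = 256, the launch is effectively
//   op_dense_esuhm<<<grid = {2, 1, 1}, block = {256, 1, 1}, 0, stream>>>(...)
// i.e. block b processes component b, and its 256 threads stride over the
// D columns; threads with tid >= D (here D = 3) simply do no work.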
template <typename scalar_t, typename func_t>
void _nested_op_dense_esuhm_kernel(Tensor& result, const Tensor& self, const Tensor& other, func_t f) {
  auto self_ptr = get_nested_tensor_impl(self);
  auto result_ptr = get_nested_tensor_impl(result);

  const auto self_buffer = self_ptr->get_buffer();
  const auto offsets = self_ptr->get_storage_offsets();
  const auto batch_size = other.size(0);
  const auto embedding_size = other.size(2);

  auto result_buffer = result_ptr->get_buffer();
  // append numel so the kernel can compute the range of every component,
  // including the last one
  auto result_offsets = at::cat({at::tensor(offsets), at::tensor(self_ptr->numel())});
  result_offsets = result_offsets.to(kCUDA);

  const scalar_t* self_data_ptr = self_buffer.data_ptr<scalar_t>();
  const scalar_t* other_data_ptr = other.data_ptr<scalar_t>();
  scalar_t* result_data_ptr = result_buffer.data_ptr<scalar_t>();
  int64_t* result_offsets_ptr = result_offsets.data_ptr<int64_t>();

  nested_op_dense_kernelLauncher(
      self_data_ptr,
      other_data_ptr,
      result_data_ptr,
      batch_size,
      embedding_size,
      result_offsets_ptr,
      f);
}
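For concreteness, the offsets bookkeeping above on a toy input (hypothetical sizes, not from this diff):

// Toy walk-through of the offsets construction (values illustrative):
//   self : nested [B=2, *, D=3] with component shapes (2, 3) and (4, 3)
//   buffer: 18 contiguous elements, get_storage_offsets() -> {0, 6}
//   result_offsets = cat({0, 6}, {numel = 18}) = {0, 6, 18}
// The kernel's `range` is then 6 for component 0 and 12 for component 1.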
void _nested_op_dense_esuhm_cuda(Tensor& result, const Tensor& self, const Tensor& other, const NESTED_DENSE_OP& op) {
  AT_DISPATCH_ALL_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, self.scalar_type(), "_nested_op_dense_esuhm", [&]() {
        switch (op) {
          case NESTED_DENSE_OP::ADD:
            _nested_op_dense_esuhm_kernel<scalar_t>(result, self, other, [] __host__ __device__ (scalar_t a, scalar_t b) -> scalar_t { return a + b; });
            break;
          case NESTED_DENSE_OP::MUL:
            _nested_op_dense_esuhm_kernel<scalar_t>(result, self, other, [] __host__ __device__ (scalar_t a, scalar_t b) -> scalar_t { return a * b; });
            break;
        }
      });
}

REGISTER_CUDA_DISPATCH(nested_dense_elementwise_stub, &_nested_op_dense_esuhm_cuda);

} // namespace native
} // namespace at
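Putting the pieces together, a hedged sketch of how a caller could reach the kernel through the stub registered above (everything except the stub call itself is hypothetical, and the construction of result/self/other is assumed):

// Illustration only: exercising the registered stub directly.
// Assumes `self` is a CUDA nested tensor [B, *, D], `other` a dense CUDA
// tensor [B, 1, D], and `result` a nested tensor preallocated with the
// same layout as `self` (allocation helper not shown in this diff).
#include <ATen/native/nested/NestedTensorBinaryOps.h>

void add_nested_dense(at::Tensor& result, const at::Tensor& self, const at::Tensor& other) {
  // routed to _nested_op_dense_esuhm_cuda by REGISTER_CUDA_DISPATCH above
  at::native::nested_dense_elementwise_stub(
      at::kCUDA, result, self, other, at::native::NESTED_DENSE_OP::ADD);
}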