Added add/mul for nested dense [B, *, D], [B, 1, D] case (CUDA-only)
ghstack-source-id: e40878da8d04c71b787a5e19b83e3311a681dd9c
Pull Request resolved: #88289
mikaylagawarecki committed Nov 2, 2022
1 parent 99c0773 commit 5aa8e9e
Showing 4 changed files with 187 additions and 0 deletions.
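The change adds a CUDA-only fast path for elementwise add/mul between a 3-d nested tensor of shape [B, *, D] and a regular (dense) tensor of shape [B, 1, D]: the dense row for batch element b is broadcast across that element's ragged dimension. A minimal usage sketch mirroring the new test below (shapes and names are illustrative):

import torch

# nested [B, *, D]: three sequences of different lengths, embedding dim 8
nt = torch.nested.nested_tensor(
    [torch.randn(2, 8), torch.randn(5, 8), torch.randn(3, 8)],
    device="cuda")
# dense [B, 1, D]: one row per batch element
t = torch.randn(3, 1, 8, device="cuda")

out_add = nt.add(t)  # should take the new CUDA fast path (see the shape check below)
out_mul = nt.mul(t)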
36 changes: 36 additions & 0 deletions aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp
@@ -1,4 +1,5 @@
#include <ATen/native/nested/NestedTensorMath.h>
#include <ATen/native/nested/NestedTensorBinaryOps.h>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
@@ -18,6 +19,9 @@
namespace at {
namespace native {

DEFINE_DISPATCH(nested_dense_elementwise_stub);
REGISTER_NO_CPU_DISPATCH(nested_dense_elementwise_stub);

std::pair<NestedTensorImpl*, NestedTensorImpl*>
get_elementwise_nested_tensor_impl(
const Tensor& self,
@@ -95,6 +99,38 @@ Tensor NestedTensor_elementwise_Tensor(
self_impl->get_storage_offsets()
);
}
  // special case when other is dense
  if (self.is_nested() && !other.is_nested()) {
    // check for the [B, *, D], [B, 1, D] esuhm case
    // TODO: this if statement is ugly and hopefully we will remove this in the near future
    auto self_ptr = get_nested_tensor_impl(self);
    if (self_ptr->dim() == 3 &&
        other.dim() == 3 &&
        self_ptr->size(0) == other.size(0) &&
        other.size(1) == 1 &&
        self_ptr->opt_size(2).has_value() &&
        self_ptr->opt_size(2).value() == other.size(2) &&
        self.is_cuda() &&
        other.is_cuda()) {
      if (!nested_tensor_impl_is_contiguous(self_ptr)) {
        self_ptr = get_nested_tensor_impl(self.contiguous());
      }
      const auto self_buffer = self_ptr->get_buffer();
      const auto self_sizes = self_ptr->get_nested_size_tensor();
      auto result_buffer = at::empty_like(self_buffer);
      auto result = wrap_buffer(result_buffer, self_sizes);
      if (op_name == "add") {
        nested_dense_elementwise_stub(self.device().type(), result, self, other, NESTED_DENSE_OP::ADD);
      } else if (op_name == "mul") {
        nested_dense_elementwise_stub(self.device().type(), result, self, other, NESTED_DENSE_OP::MUL);
      } else {
        TORCH_CHECK(false, "Unsupported nested dense elementwise op");
      }
      return result;
    }
    TORCH_CHECK(false, "Expected both self and other to be nested, but got a nested self and non-nested other.");
  }

  NestedTensorImpl* self_impl = nullptr;
  NestedTensorImpl* other_impl = nullptr;
  std::tie(self_impl, other_impl) =
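Semantically, the fast path computes the same thing as unbinding the nested tensor and broadcasting the matching dense row against each constituent, which is how the new test builds its reference values. A short sketch of that reference computation (the function name is illustrative):

import torch

def nested_dense_ref(nt, t, op):
    # nt: nested tensor of shape [B, *, D]; t: dense tensor of shape [B, 1, D]
    # op: a binary elementwise function such as torch.add or torch.mul
    return torch.nested.nested_tensor(
        [op(a, b) for a, b in zip(nt.unbind(), t.unbind())])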
16 changes: 16 additions & 0 deletions aten/src/ATen/native/nested/NestedTensorBinaryOps.h
@@ -0,0 +1,16 @@
#pragma once

#include <ATen/core/ATen_fwd.h>
#include <ATen/native/DispatchStub.h>

namespace at {
namespace native {

enum class NESTED_DENSE_OP: uint8_t {ADD, MUL};

using nested_dense_elementwise_fn = void (*)(Tensor& result, const Tensor & self, const Tensor & other, const NESTED_DENSE_OP& op);

DECLARE_DISPATCH(nested_dense_elementwise_fn, nested_dense_elementwise_stub);

} // namespace native
} // namespace at
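The header only declares the op tag and the dispatch stub; the CUDA file below registers the actual kernel, and REGISTER_NO_CPU_DISPATCH leaves the CPU slot empty. A rough Python analogue of that wiring, purely to illustrate the mechanism (this is not the real ATen DispatchStub API):

# hypothetical per-device registry: no CPU entry, so a CPU call fails,
# mirroring REGISTER_NO_CPU_DISPATCH
_impls = {}

def register(device_type, fn):
    _impls[device_type] = fn

def nested_dense_elementwise_stub(device_type, result, self, other, op):
    if device_type not in _impls:
        raise RuntimeError(f"no nested-dense elementwise kernel registered for {device_type}")
    _impls[device_type](result, self, other, op)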
120 changes: 120 additions & 0 deletions aten/src/ATen/native/nested/cuda/NestedTensorBinaryOps.cu
@@ -0,0 +1,120 @@
#include <ATen/native/nested/NestedTensorBinaryOps.h>

#include <type_traits>

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>

#include <c10/cuda/CUDAMathCompat.h>
#include <c10/cuda/CUDAStream.h>


#include <ATen/native/nested/NestedTensorUtils.h>

#define BLOCK_DIM 256

namespace at {
namespace native {


// only for nested [B, *, D], dense [B, 1, D]
template <typename T, typename func_t>
__global__ void op_dense_esuhm(
    const T* input,
    const T* dense,
    T* output,
    int64_t embedding_dim,
    const int64_t* offsets,
    const func_t& f)
{
  // each batch is handled by a block
  const int64_t batch_idx = blockIdx.x;
  const int64_t grain_size = blockDim.x;
  const int64_t tid = threadIdx.x;
  const int64_t range = offsets[batch_idx + 1] - offsets[batch_idx];
  // each thread handles (embedding_dim // grain_size + (tid < embedding_dim % grain_size)) elems
  // of the dense embedding
  for (int64_t idx = tid; idx < embedding_dim; idx += grain_size) {
    const T dense_elem = dense[batch_idx * embedding_dim + idx];
    for (int64_t nested_idx = idx; nested_idx < range; nested_idx += embedding_dim) {
      output[offsets[batch_idx] + nested_idx] = f(input[offsets[batch_idx] + nested_idx], dense_elem);
    }
  }
}

template <typename T, typename func_t>
void nested_op_dense_kernelLauncher(
    const T* input,  // [sum(*) x embedding_dim]
    const T* dense,  // [batch_size x embedding_dim]
    T* output,       // [sum(*) x embedding_dim]
    int64_t batch_size,
    int64_t embedding_dim,
    const int64_t* input_offsets,  // [batch_size + 1]
    func_t f)
{
  dim3 grid;
  grid.x = batch_size;
  const auto stream = at::cuda::getDefaultCUDAStream();

  op_dense_esuhm<<<grid, BLOCK_DIM, 0, stream>>>(
      input,
      dense,
      output,
      embedding_dim,
      input_offsets,
      f);
}

template <typename scalar_t, typename func_t>
void _nested_op_dense_esuhm_kernel(Tensor& result, const Tensor& self, const Tensor& other, func_t f) {
  auto self_ptr = get_nested_tensor_impl(self);
  auto result_ptr = get_nested_tensor_impl(result);

  const auto self_buffer = self_ptr->get_buffer();
  const auto offsets = self_ptr->get_storage_offsets();
  const auto batch_size = other.size(0);
  const auto embedding_size = other.size(2);

  auto result_buffer = result_ptr->get_buffer();
  auto result_offsets = at::cat({at::tensor(offsets), at::tensor(self_ptr->numel())});
  result_offsets = result_offsets.to(kCUDA);

  const scalar_t* self_data_ptr = self_buffer.data_ptr<scalar_t>();
  const scalar_t* other_data_ptr = other.data_ptr<scalar_t>();
  scalar_t* result_data_ptr = result_buffer.data_ptr<scalar_t>();
  int64_t* result_offsets_ptr = result_offsets.data_ptr<int64_t>();

  nested_op_dense_kernelLauncher(
      self_data_ptr,
      other_data_ptr,
      result_data_ptr,
      batch_size,
      embedding_size,
      result_offsets_ptr,
      f);
}

void _nested_op_dense_esuhm_cuda(Tensor& result, const Tensor& self, const Tensor& other, const NESTED_DENSE_OP& op) {
  AT_DISPATCH_ALL_TYPES_AND2(
    ScalarType::Half, ScalarType::BFloat16, self.scalar_type(), "_nested_op_dense_esuhm", [&]() {
      switch (op) {
        case NESTED_DENSE_OP::ADD :
          _nested_op_dense_esuhm_kernel<scalar_t>(result, self, other, [] __host__ __device__ (scalar_t a, scalar_t b) -> scalar_t { return a + b; });
          break;
        case NESTED_DENSE_OP::MUL :
          _nested_op_dense_esuhm_kernel<scalar_t>(result, self, other, [] __host__ __device__ (scalar_t a, scalar_t b) -> scalar_t { return a * b; });
          break;
      }
    });
}

REGISTER_CUDA_DISPATCH(nested_dense_elementwise_stub, &_nested_op_dense_esuhm_cuda);

} // namespace native
} // namespace at
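To follow the kernel's indexing: the nested tensor's values sit in one flat buffer, and offsets[b]..offsets[b+1] delimit the elements of batch entry b (the launcher receives the storage offsets extended with the total numel, built with at::cat above). One block handles one batch entry and each thread strides over the embedding dimension. A sequential Python sketch of the same computation on the flat buffer (illustrative only):

import torch

def nested_dense_op_ref(buffer, offsets, dense, embedding_dim, f):
    # buffer:  flat nested-tensor values, length sum(*) * embedding_dim
    # offsets: length batch_size + 1; offsets[b]:offsets[b+1] spans batch entry b
    # dense:   [batch_size, 1, embedding_dim]
    out = torch.empty_like(buffer)
    for b in range(dense.size(0)):
        chunk = buffer[offsets[b]:offsets[b + 1]].view(-1, embedding_dim)
        out[offsets[b]:offsets[b + 1]] = f(chunk, dense[b]).reshape(-1)
    return out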
15 changes: 15 additions & 0 deletions test/test_nestedtensor.py
@@ -831,6 +831,21 @@ def test_nested_tensor_add(self, device, dtype):
        out = nt1 + nt2
        self.assertEqual(ref, out)

    @onlyCUDA
    @dtypes(torch.float, torch.float16)
    @torch.inference_mode()
    @parametrize("embedding_dim", [8, 128, 256, 384])
    def test_nested_tensor_dense_elementwise(self, device, dtype, embedding_dim):
        batch_size = 32
        seq_lens = torch.randint(low=0, high=10, size=(batch_size,))
        ts = [torch.randn((seq_len, embedding_dim)) for seq_len in seq_lens]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        t = torch.randn((batch_size, 1, embedding_dim), device=device, dtype=dtype)
        ref_add = torch.nested.nested_tensor([t1 + t2 for (t1, t2) in zip(nt.unbind(), t.unbind())])
        ref_mul = torch.nested.nested_tensor([t1 * t2 for (t1, t2) in zip(nt.unbind(), t.unbind())])
        self.assertEqual(nt.add(t), ref_add)
        self.assertEqual(nt.mul(t), ref_mul)

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
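Shapes outside the [B, *, D] / [B, 1, D] pattern (or non-CUDA tensors) still fall through to the existing TORCH_CHECK, so mixing a nested tensor with a dense tensor of any other shape keeps raising an error instead of silently broadcasting. A quick illustration of that expectation (shapes are illustrative):

import torch

nt = torch.nested.nested_tensor([torch.randn(2, 8), torch.randn(5, 8)], device="cuda")
bad = torch.randn(2, 8, device="cuda")  # dense, but not [B, 1, D]

try:
    nt.add(bad)
except RuntimeError as e:
    print(e)  # expected: "Expected both self and other to be nested ..." per the check above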
