From 59a89dd3823bf892e768e326e499c9785dc8d676 Mon Sep 17 00:00:00 2001 From: Joel Schlosser Date: Mon, 3 Jun 2024 15:03:30 -0400 Subject: [PATCH] Lift jagged -> padded dense forward / backward kernels from fbgemm_gpu ghstack-source-id: c9fca7bec054de558d1184560099441b7c177dc4 Pull Request resolved: https://github.com/pytorch/pytorch/pull/125946 --- aten/src/ATen/native/native_functions.yaml | 10 + .../cuda/NestedTensorTransformerFunctions.cu | 1081 +++++++++++++++++ ...asDecompTest.test_has_decomposition.expect | 2 + test/test_nestedtensor.py | 25 + 4 files changed, 1118 insertions(+) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a051f43e87eba..54b12a9a0b0c2 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -14644,6 +14644,16 @@ NestedTensorCUDA: NestedTensor_to_padded_tensor_cuda autogen: to_padded_tensor.out +- func: _jagged_to_padded_dense_forward(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value=0.0) -> Tensor + variants: function + dispatch: + CUDA: _fbgemm_jagged_to_padded_dense_forward + +- func: _padded_dense_to_jagged_forward(Tensor dense, Tensor[] offsets, SymInt? total_L=None) -> Tensor + variants: function + dispatch: + CUDA: _fbgemm_dense_to_jagged_forward_symint + - func: _nested_tensor_softmax_with_shape(Tensor self, Tensor query) -> Tensor dispatch: NestedTensorCPU: NestedTensor_softmax_dropout diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu index 56cac2a898034..c425cf504dc9e 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu @@ -1,3 +1,4 @@ +#include #include #include @@ -11,6 +12,7 @@ #include #include +#include #include #include @@ -462,5 +464,1084 @@ template void add_padding_kernelLauncher( const int batch_size, const int output_batch_size); +// NB: The following code covers jagged <-> padded dense conversions and was lifted +// from fbgemm_gpu. For more details, see +// https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/src/jagged_tensor_ops + +// Passing lambda exp argument by value instead of by reference to avoid +// "internal compiler error: in maybe_undo_parenthesized_ref" error for specific +// compiler version. 
+#define JAGGED_TENSOR_DISPATCH_DIMS() \ + AT_DISPATCH_INDEX_TYPES(x_offsets[0].scalar_type(), "jagged_indices", [=] { \ + switch (num_jagged_dim) { \ + case 1: \ + INVOKE_KERNEL_WITH_DIM(1); \ + break; \ + case 2: \ + INVOKE_KERNEL_WITH_DIM(2); \ + break; \ + case 3: \ + INVOKE_KERNEL_WITH_DIM(3); \ + break; \ + case 4: \ + INVOKE_KERNEL_WITH_DIM(4); \ + break; \ + case 5: \ + INVOKE_KERNEL_WITH_DIM(5); \ + break; \ + default: \ + TORCH_CHECK( \ + false, "unsupported number of jagged dim ", num_jagged_dim); \ + } \ + }); + +inline std::string torch_tensor_device_name(const at::Tensor& ten) { + return c10::DeviceTypeName(ten.device().type()); +} + +inline std::string torch_tensor_device_name( + const c10::optional& ten) { + if (ten.has_value()) { + return torch_tensor_device_name(ten.value()); + } else { + return "N/A"; + } +} + +inline bool torch_tensor_on_cuda_gpu_check(const at::Tensor& ten) { + return ten.is_cuda(); +} + +inline bool torch_tensor_on_cuda_gpu_check( + const c10::optional& ten) { + return !ten.has_value() || torch_tensor_on_cuda_gpu_check(ten.value()); +} + +#define TENSOR_ON_CUDA_GPU(x) \ + TORCH_CHECK( \ + torch_tensor_on_cuda_gpu_check(x), \ + #x " must be a CUDA tensor; it is currently on device ", \ + torch_tensor_device_name(x)) + +// A wrapper class for passing dynamically sized dimension information (e.g. +// tensor.dims()) from the host to device. +constexpr size_t kStackArrayMaxDims = 5; + +template +struct StackArray { + T vals[kStackArrayMaxDims]; + size_t ndim; +}; + +// Warp size +#ifdef USE_ROCM +static constexpr int32_t kWarpSize = 64; +#else +static constexpr int32_t kWarpSize = 32; +#endif +// Max thread num in one thread block +static constexpr int32_t kMaxThreads = 1024; + +#define DEVICE_INLINE __device__ C10_ALWAYS_INLINE + +__host__ DEVICE_INLINE int32_t div_round_up(int32_t a, int32_t b) { + return (a + b - 1) / b; +} + +__host__ DEVICE_INLINE int32_t round_down(int32_t a, int32_t b) { + return a / b * b; +} + +inline std::tuple> check_shape_and_partition_( + const Tensor& values, + const std::vector& offsets, + const Tensor& dense_tensor) { + const int outer_dense_size = dense_tensor.size(0); + TORCH_CHECK( + outer_dense_size == offsets[0].numel() - 1, + "outer_dense_size, ", + outer_dense_size, + " != offsets[0].numel() - 1, ", + offsets[0].numel() - 1); + const int inner_dense_size = dense_tensor.size(-1); + TORCH_CHECK( + inner_dense_size == values.size(-1), + "inner_dense_size, ", + inner_dense_size, + " != values.size(-1), ", + values.size(-1)); + const int jagged_folded_size = + dense_tensor.numel() / (outer_dense_size * inner_dense_size); + + const int threads_x = + inner_dense_size >= kWarpSize / 2 ? 
kWarpSize : inner_dense_size; + const int threads_y = kMaxThreads / kWarpSize; + const dim3 blocks( + div_round_up(outer_dense_size * jagged_folded_size, threads_y)); + + StackArray jagged_dims_tensor; + const int num_jagged_dim = dense_tensor.dim() - 2; + TORCH_CHECK(num_jagged_dim <= kStackArrayMaxDims); + jagged_dims_tensor.ndim = num_jagged_dim; + std::memcpy( + &(jagged_dims_tensor.vals[0]), + dense_tensor.sizes().data() + 1, + num_jagged_dim * sizeof(int64_t)); + return {dim3(threads_x, threads_y), blocks, jagged_dims_tensor}; +} + +template +DEVICE_INLINE bool walk_down_tensor_storage_tree_( + int& offset, + const int flattened_jagged_idx, + const StackArray& jagged_dims, + const StackArray& x_offsets) { + // compute coorindates + int jagged_coords[NUM_JAGGED_DIM]; + int j_temp = flattened_jagged_idx; +#pragma unroll + for (int d = NUM_JAGGED_DIM - 1; d >= 0; --d) { + const int jagged_size = jagged_dims.vals[d]; + jagged_coords[d] = j_temp % jagged_size; + j_temp /= jagged_size; + } + + // walk down the tree + bool is_zero = false; +#pragma unroll + for (int d = 0; d < NUM_JAGGED_DIM; ++d) { + const int begin = x_offsets.vals[d][offset]; + const int end = x_offsets.vals[d][offset + 1]; + if (jagged_coords[d] >= end - begin) { + is_zero = true; + break; + } + offset = begin + jagged_coords[d]; + } + return is_zero; +} + +// output = f(x, y) where x is jagged, y is dense, and output is dense. +// A generic elementwise operation between a jagged tensor and a dense tensor +// This kernel assumes jagged dims are clustered together, preceded by outer +// dense dimensions and followed by inner dense dimensions. +// The outer/inner dense dimensions, and jagged dimensions in between are +// assumed to be folded so physically the dense tensor is 3D and the value of +// jagged tensor is 2D. +// To support arbitrary number of jagged dimensions, we pass a vector of +// pointers to offset tensors (this is ugly and probably we can use nested +// tensor here). +// This kernel parallelizes the (folded) inner dense dimension across +// blockDim.x so the inner dense dimension should be similar to or bigger than +// warp size. +// We rely on compiler unrolling the compiler time constant NUM_JAGGED_DIM. 
+template +__global__ +__launch_bounds__(kMaxThreads) void jagged_dense_elementwise_dense_output_kernel_( + const at::PackedTensorAccessor32 + x_values, + StackArray x_offsets, + const at::PackedTensorAccessor32 y, + at::PackedTensorAccessor32 output, + StackArray jagged_dims, + F f, + const scalar_t padding_value) { + const int outer_dense_size = y.size(0); + const int jagged_folded_size = y.size(1); + const int inner_dense_size = y.size(2); + + const int outer_begin = blockIdx.x * blockDim.y + threadIdx.y; + const int outer_stride = gridDim.x * blockDim.y; + for (int outer = outer_begin; outer < outer_dense_size * jagged_folded_size; + outer += outer_stride) { + const int oidx = outer / jagged_folded_size; + const int jidx = outer % jagged_folded_size; + + int offset = oidx; + const bool is_zero = walk_down_tensor_storage_tree_( + offset, jidx, jagged_dims, x_offsets); + + if (is_zero) { + int iidx; + for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size; + iidx += blockDim.x) { + output[oidx][jidx][2 * iidx] = + f(padding_value, y[oidx][jidx][2 * iidx]); + output[oidx][jidx][2 * iidx + 1] = + f(padding_value, y[oidx][jidx][2 * iidx + 1]); + } + if (iidx * 2 + 1 == inner_dense_size) { + output[oidx][jidx][2 * iidx] = + f(padding_value, y[oidx][jidx][2 * iidx]); + } + } else { + int iidx; + for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size; + iidx += blockDim.x) { + output[oidx][jidx][2 * iidx] = + f(x_values[offset][2 * iidx], y[oidx][jidx][2 * iidx]); + output[oidx][jidx][2 * iidx + 1] = + f(x_values[offset][2 * iidx + 1], y[oidx][jidx][2 * iidx + 1]); + } + if (iidx * 2 + 1 == inner_dense_size) { + output[oidx][jidx][2 * iidx] = + f(x_values[offset][2 * iidx], y[oidx][jidx][2 * iidx]); + } + } + } +} + +template +void jagged_dense_elementwise_dense_output_( + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y, + const Tensor& output, + F f, + const scalar_t padding_value = static_cast(0)) { + TENSOR_ON_CUDA_GPU(x_values); + for (auto& x_offset : x_offsets) { + TENSOR_ON_CUDA_GPU(x_offset); + } + + const int num_jagged_dim = y.dim() - 2; + TORCH_CHECK( + x_offsets.size() == static_cast(num_jagged_dim), + "x_offsets.size(), ", + x_offsets.size(), + " != num_jagged_dim ", + num_jagged_dim); + + if (y.numel() == 0) { + return; + } + + dim3 threads, blocks; + StackArray jagged_dims_tensor; + std::tie(threads, blocks, jagged_dims_tensor) = + check_shape_and_partition_(x_values, x_offsets, y); + + // Canonicalize y and output to 3D, collapsing jagged dimensions. 
+ const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)}); + Tensor output_reshaped = output.view(y_reshaped.sizes()); + +#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \ + { \ + std::vector x_offsets_contig; \ + x_offsets_contig.resize(num_jagged_dim); \ + StackArray x_offset_ptrs; \ + x_offset_ptrs.ndim = num_jagged_dim; \ + for (int d = 0; d < num_jagged_dim; ++d) { \ + x_offsets_contig[d] = x_offsets[d].contiguous(); \ + x_offset_ptrs.vals[d] = \ + x_offsets_contig[d].template data_ptr(); \ + } \ + jagged_dense_elementwise_dense_output_kernel_ \ + <<>>( \ + x_values.packed_accessor32(), \ + x_offset_ptrs, \ + y_reshaped \ + .packed_accessor32(), \ + output_reshaped \ + .packed_accessor32(), \ + jagged_dims_tensor, \ + f, \ + padding_value); \ + } + + JAGGED_TENSOR_DISPATCH_DIMS(); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + +#undef INVOKE_KERNEL_WITH_DIM +} + +#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \ + { \ + dim3 threads, blocks; \ + StackArray jagged_dims_tensor; \ + std::tie(threads, blocks, jagged_dims_tensor) = \ + check_shape_and_partition_(x_values, x_offsets, y); \ + blocks.x = div_round_up(x_values.size(0), threads.y); \ + std::vector x_offsets_contig; \ + x_offsets_contig.resize(num_jagged_dim); \ + StackArray x_offset_ptrs; \ + x_offset_ptrs.ndim = num_jagged_dim; \ + StackArray x_offset_sizes; \ + x_offset_sizes.ndim = num_jagged_dim; \ + for (int d = 0; d < num_jagged_dim; ++d) { \ + x_offsets_contig[d] = x_offsets[d].contiguous(); \ + x_offset_ptrs.vals[d] = \ + x_offsets_contig[d].template data_ptr(); \ + x_offset_sizes.vals[d] = x_offsets[d].numel(); \ + } \ + jagged_dense_dense_elementwise_jagged_output_kernel_< \ + NUM_JAGGED_DIM, \ + index_t><<>>( \ + x_values.packed_accessor32(), \ + x_offset_ptrs, \ + x_offset_sizes, \ + y_reshaped.packed_accessor32(), \ + y_reshaped.packed_accessor32(), \ + output_values.packed_accessor32(), \ + jagged_dims_tensor, \ + [f] __device__(scalar_t x, scalar_t y, scalar_t /*unused*/) \ + -> scalar_t { return f(x, y); }); \ + } + +template +__global__ +__launch_bounds__(kMaxThreads) void jagged_dense_dense_elementwise_jagged_output_kernel_( + const at::PackedTensorAccessor32 + x_values, + StackArray x_offsets, + StackArray x_offsets_sizes, + const at::PackedTensorAccessor32 y_0, + const at::PackedTensorAccessor32 y_1, + at::PackedTensorAccessor32 + output_values, + StackArray jagged_dims, + F f) { + const int outer_dense_size = y_0.size(0); + const int inner_dense_size = y_0.size(2); + const int nnz = x_values.size(0); + + const int offset_begin = blockIdx.x * blockDim.y + threadIdx.y; + const int offset_stride = gridDim.x * blockDim.y; + for (int offset = offset_begin; offset < nnz; offset += offset_stride) { + int offset_temp = offset; + int jidx = 0; + bool truncated = false; + int dim_prod = 1; +#pragma unroll + for (int d = NUM_JAGGED_DIM - 1; d >= 0; --d) { + // Binary search the first that is bigger than offset + int count = x_offsets_sizes.vals[d] - 1; + int first = 1; + while (count > 0) { + int idx = first; + int step = count / 2; + idx += step; + if (x_offsets.vals[d][idx] <= offset_temp) { + first = ++idx; + count -= step + 1; + } else { + count = step; + } + } + + --first; + int coord = offset_temp - x_offsets.vals[d][first]; + if (coord >= jagged_dims.vals[d]) { + truncated = true; + break; + } + jidx += coord * dim_prod; + dim_prod *= jagged_dims.vals[d]; + offset_temp = first; + } + + if (offset_temp >= outer_dense_size) { + // This can happen when values have more elements than the last element of + // offset + 
truncated = true; + } + if (!truncated) { + const int oidx = offset_temp; + int iidx; + for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size; + iidx += blockDim.x) { + output_values[offset][2 * iidx] = + f(x_values[offset][2 * iidx], + y_0[oidx][jidx][2 * iidx], + y_1[oidx][jidx][2 * iidx]); + output_values[offset][2 * iidx + 1] = + f(x_values[offset][2 * iidx + 1], + y_0[oidx][jidx][2 * iidx + 1], + y_1[oidx][jidx][2 * iidx + 1]); + } + if (iidx * 2 + 1 == inner_dense_size) { + output_values[offset][2 * iidx] = + f(x_values[offset][2 * iidx], + y_0[oidx][jidx][2 * iidx], + y_1[oidx][jidx][2 * iidx]); + } + } else { + int iidx; + for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size; + iidx += blockDim.x) { + output_values[offset][2 * iidx] = f(x_values[offset][2 * iidx], 0, 0); + output_values[offset][2 * iidx + 1] = + f(x_values[offset][2 * iidx + 1], 0, 0); + } + if (iidx * 2 + 1 == inner_dense_size) { + output_values[offset][2 * iidx] = f(x_values[offset][2 * iidx], 0, 0); + } + } + } +} + +///@addtogroup jagged-tensor-ops-cuda +template +void jagged_dense_elementwise_jagged_output_( + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y, + const Tensor& output_values, + F f) { + TENSOR_ON_CUDA_GPU(x_values); + for (auto& x_offset : x_offsets) { + TENSOR_ON_CUDA_GPU(x_offset); + } + + const int num_jagged_dim = y.dim() - 2; + TORCH_CHECK( + x_offsets.size() == static_cast(num_jagged_dim), + "x_offsets.size(), ", + x_offsets.size(), + " != num_jagged_dim, ", + num_jagged_dim); + + if (y.numel() == 0 || x_values.numel() == 0) { + return; + } + + // Canonicalize y to 3D, collapsing jagged dimensions. + const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)}); + + JAGGED_TENSOR_DISPATCH_DIMS(); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +#undef INVOKE_KERNEL_WITH_DIM + +template +struct SharedMemory; + +template <> +struct SharedMemory { + __device__ int64_t* getPointer() { + extern __shared__ int64_t s_int64_t[]; + return s_int64_t; + } +}; + +template <> +struct SharedMemory { + __device__ int32_t* getPointer() { + extern __shared__ int32_t s_int32_t[]; + return s_int32_t; + } +}; + +template +__global__ void jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_( + const at::PackedTensorAccessor32 offsets, + at::PackedTensorAccessor32 rows, + at::PackedTensorAccessor32 cols, + int nnz, + int B) { + struct SharedMemory smem; + index_t* offsets_sh = smem.getPointer(); + + for (int i = threadIdx.x; i < B + 1; i += blockDim.x) { + offsets_sh[i] = offsets[i]; + } + __syncthreads(); + int row = threadIdx.x + blockIdx.x * blockDim.x; + if (row >= nnz) + return; + int first = -1; + int count = B - 1; + first = 1; + while (count > 0) { + int idx = first; + int step = count / 2; + idx += step; + if (offsets_sh[idx] <= row) { + first = ++idx; + count -= step + 1; + } else { + count = step; + } + } + --first; + + int dense_row = first; + int offset = offsets_sh[dense_row]; + int dense_col = row - offset; + rows[row] = dense_row; + cols[row] = dense_col; +} + +struct VecType128 { + typedef float4 TType; // Transaction Type + typedef struct __align__(16) { + __half a, b, c, d, w, x, y, z; + } + half8; + + union Data { + half8 val; + TType mask; + } data; + + __device__ VecType128() { + data.mask = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + } +}; + +struct VecType64 { + typedef float2 TType; // Transaction Type + typedef struct __align__(8) { + __half a, b, c, d; + } + half4; + + union Data { + half4 val; + TType mask; + } data; + + __device__ VecType64() { + 
data.mask = make_float2(0.0f, 0.0f); + } +}; + +struct VecType32 { + typedef float TType; // Transaction Type + + union Data { + __half2 val; + TType mask; + } data; + + __device__ VecType32() { + data.mask = 0.0f; + } +}; + +template +__device__ void f128( + VecType128& v_out, + const VecType128& x, + const VecType128& y0, + const VecType128& y1, + F f) { + v_out.data.val.a = f(x.data.val.a, y0.data.val.a, y1.data.val.a); + v_out.data.val.b = f(x.data.val.b, y0.data.val.b, y1.data.val.b); + v_out.data.val.c = f(x.data.val.c, y0.data.val.c, y1.data.val.c); + v_out.data.val.d = f(x.data.val.d, y0.data.val.d, y1.data.val.d); + v_out.data.val.w = f(x.data.val.w, y0.data.val.w, y1.data.val.w); + v_out.data.val.x = f(x.data.val.x, y0.data.val.x, y1.data.val.x); + v_out.data.val.y = f(x.data.val.y, y0.data.val.y, y1.data.val.y); + v_out.data.val.z = f(x.data.val.z, y0.data.val.z, y1.data.val.z); +} + +template +__device__ void f64( + VecType64& v_out, + const VecType64& x, + const VecType64& y0, + const VecType64& y1, + F f) { + v_out.data.val.a = f(x.data.val.a, y0.data.val.a, y1.data.val.a); + v_out.data.val.b = f(x.data.val.b, y0.data.val.b, y1.data.val.b); + v_out.data.val.c = f(x.data.val.c, y0.data.val.c, y1.data.val.c); + v_out.data.val.d = f(x.data.val.d, y0.data.val.d, y1.data.val.d); +} + +template +__device__ void f32( + VecType32& v_out, + const VecType32& x, + const VecType32& y0, + const VecType32& y1, + F f) { + v_out.data.val = __halves2half2( + f(__low2half(x.data.val), + __low2half(y0.data.val), + __low2half(y1.data.val)), + f(__high2half(x.data.val), + __high2half(y0.data.val), + __high2half(y1.data.val))); +} + +template +__device__ void +fh(__half& v_out, const __half& x, const __half& y0, const __half& y1, F f) { + v_out = f(x, y0, y1); +} + +template +__global__ void jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_( + at::PackedTensorAccessor32 values, + const at::PackedTensorAccessor32 + x_values, + const at::PackedTensorAccessor32 y0, + const at::PackedTensorAccessor32 y1, + const at::PackedTensorAccessor32 rows, + const at::PackedTensorAccessor32 cols, + const int nnz, + const int E, + F f) { + int values_row = threadIdx.y + blockIdx.y * blockDim.y; + if (values_row >= nnz) + return; + for (int real_row = values_row; real_row < nnz; + real_row += blockDim.y * gridDim.y) { + int dense_row = rows[real_row]; + int dense_col = cols[real_row]; + __half* values_ptr = reinterpret_cast<__half*>(&values[real_row][0]); + const __half* x_ptr = + reinterpret_cast(&x_values[real_row][0]); + const __half* y0_ptr = + reinterpret_cast(&y0[dense_row][dense_col][0]); + const __half* y1_ptr = + reinterpret_cast(&y1[dense_row][dense_col][0]); + if ((dense_col < y0.size(1)) && (dense_row < y0.size(0)) && + (dense_col < y1.size(1)) && (dense_row < y1.size(0)) && + (dense_col >= 0) && (dense_row >= 0)) { + for (int tid = threadIdx.x; tid < E / 8; tid += blockDim.x) { + VecType128 v_x, v_out, v_y0, v_y1; + v_x.data.mask = + (reinterpret_cast(x_ptr))[tid]; + v_y0.data.mask = + (reinterpret_cast(y0_ptr))[tid]; + v_y1.data.mask = + (reinterpret_cast(y1_ptr))[tid]; + f128(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (int tid = threadIdx.x + (E / 8) * 8; tid < E / 4; + tid += blockDim.x) { + VecType64 v_x, v_out, v_y0, v_y1; + v_x.data.mask = (reinterpret_cast(x_ptr))[tid]; + v_y0.data.mask = + (reinterpret_cast(y0_ptr))[tid]; + v_y1.data.mask = + (reinterpret_cast(y1_ptr))[tid]; + f64(v_out, v_x, v_y0, v_y1, f); + 
(reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (int tid = threadIdx.x + (E / 4) * 4; tid < E / 2; + tid += blockDim.x) { + VecType32 v_x, v_out, v_y0, v_y1; + v_x.data.mask = (reinterpret_cast(x_ptr))[tid]; + v_y0.data.mask = + (reinterpret_cast(y0_ptr))[tid]; + v_y1.data.mask = + (reinterpret_cast(y1_ptr))[tid]; + f32(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (int tid = threadIdx.x + (E / 2) * 2; tid < E; tid += blockDim.x) { + __half v_x, v_out, v_y0, v_y1; + v_x = static_cast<__half>(x_ptr[tid]); + v_y0 = static_cast<__half>(y0_ptr[tid]); + v_y1 = static_cast<__half>(y1_ptr[tid]); + fh(v_out, v_x, v_y0, v_y1, f); + values_ptr[tid] = v_out; + } + } else { + for (int tid = threadIdx.x; tid < E / 8; tid += blockDim.x) { + VecType128 v_x, v_out, v_y0, v_y1; + v_x.data.mask = + (reinterpret_cast(x_ptr))[tid]; + f128(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (int tid = threadIdx.x + (E / 8) * 8; tid < E / 4; + tid += blockDim.x) { + VecType64 v_x, v_out, v_y0, v_y1; + v_x.data.mask = (reinterpret_cast(x_ptr))[tid]; + f64(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (int tid = threadIdx.x + (E / 4) * 4; tid < E / 2; + tid += blockDim.x) { + VecType32 v_x, v_out, v_y0, v_y1; + v_x.data.mask = (reinterpret_cast(x_ptr))[tid]; + f32(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (int tid = threadIdx.x + (E / 2) * 2; tid < E; tid += blockDim.x) { + __half v_x, v_out, v_y0, v_y1; + v_x = static_cast<__half>(x_ptr[tid]); + fh(v_out, v_x, v_y0, v_y1, f); + values_ptr[tid] = v_out; + } + } + } +} + +// Check to see if the inputs to the op are amenable to the fast path +inline bool jagged_dense_dense_elementwise_jagged_output_matches_opt( + const int& num_jagged_dim, + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y_0_reshaped, + const Tensor& y_1_reshaped, + const Tensor& output_values) { + bool matches = true; + matches &= (num_jagged_dim == 1); + + // Unit stride embedding dim + matches &= (x_values.stride(-1) == 1); + matches &= (output_values.stride(-1) == 1); + matches &= (y_0_reshaped.stride(-1) == 1); + matches &= (y_1_reshaped.stride(-1) == 1); + + // Each row is aligned to 128-bit + matches &= (x_values.stride(-2) % 8 == 0); + matches &= (output_values.stride(-2) % 8 == 0); + matches &= (y_0_reshaped.stride(-2) % 8 == 0); + matches &= (y_1_reshaped.stride(-2) % 8 == 0); + + // Base addresses aligned to 128-bit + matches &= (reinterpret_cast(x_values.data_ptr()) % 16 == 0); + matches &= (reinterpret_cast(output_values.data_ptr()) % 16 == 0); + matches &= (reinterpret_cast(y_0_reshaped.data_ptr()) % 16 == 0); + matches &= (reinterpret_cast(y_1_reshaped.data_ptr()) % 16 == 0); + + // Rows and col fit into int32_t + matches &= (y_0_reshaped.size(0) < INT_MAX); + matches &= (y_0_reshaped.size(1) < INT_MAX); + + int max_shared_bytes; +#ifndef USE_ROCM + C10_CUDA_CHECK(cudaDeviceGetAttribute( + &max_shared_bytes, + cudaDevAttrMaxSharedMemoryPerBlockOptin, + y_0_reshaped.get_device())); +#else + // MI100 has 64 KB local memory (shared memory) per workgroup + max_shared_bytes = 64 << 10; +#endif + int shared_kb = max_shared_bytes >> 10; +#ifndef USE_ROCM + // Use 2/3 of the available GPU shared mem; leave rooms for L1$. 
+ int used_shared_kb = round_down(shared_kb * 2 / 3, 16); + TORCH_CHECK(used_shared_kb > 0); +#else + // MI100 has independent shared mem and L1 + int used_shared_kb = shared_kb; +#endif + int used_shared_bytes = used_shared_kb << 10; + AT_DISPATCH_INDEX_TYPES( + x_offsets[0].scalar_type(), "check_shared_memory", [&] { + auto B = y_0_reshaped.size(0); + // the default shared memory on V100/A100/H100 is 48 KB from + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-8-x + if ((B + 1) * sizeof(index_t) >= used_shared_bytes) { + matches = false; + } + }); + return matches; +} + +#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \ + { \ + dim3 threads, blocks; \ + StackArray jagged_dims_tensor; \ + std::tie(threads, blocks, jagged_dims_tensor) = \ + check_shape_and_partition_(x_values, x_offsets, y); \ + blocks.x = div_round_up(x_values.size(0), threads.y); \ + std::vector x_offsets_contig; \ + x_offsets_contig.resize(num_jagged_dim); \ + StackArray x_offset_ptrs; \ + x_offset_ptrs.ndim = num_jagged_dim; \ + StackArray x_offset_sizes; \ + x_offset_sizes.ndim = num_jagged_dim; \ + for (int d = 0; d < num_jagged_dim; ++d) { \ + x_offsets_contig[d] = x_offsets[d].contiguous(); \ + x_offset_ptrs.vals[d] = \ + x_offsets_contig[d].template data_ptr(); \ + x_offset_sizes.vals[d] = x_offsets[d].numel(); \ + } \ + jagged_dense_dense_elementwise_jagged_output_kernel_< \ + NUM_JAGGED_DIM, \ + index_t><<>>( \ + x_values.packed_accessor32(), \ + x_offset_ptrs, \ + x_offset_sizes, \ + y_reshaped.packed_accessor32(), \ + y_reshaped.packed_accessor32(), \ + output_values.packed_accessor32(), \ + jagged_dims_tensor, \ + [f] __device__(scalar_t x, scalar_t y, scalar_t /*unused*/) \ + -> scalar_t { return f(x, y); }); \ + } + +inline int calc_used_shared_bytes(const int device) { + int max_shared_bytes; +#ifndef USE_ROCM + C10_CUDA_CHECK(cudaDeviceGetAttribute( + &max_shared_bytes, + cudaDevAttrMaxSharedMemoryPerBlockOptin, + device)); +#else + // MI100 has 64 KB local memory (shared memory) per workgroup + max_shared_bytes = 64 << 10; +#endif + int shared_kb = max_shared_bytes >> 10; +#ifndef USE_ROCM + // Use 2/3 of the available GPU shared mem; leave rooms for L1$. + int used_shared_kb = round_down(shared_kb * 2 / 3, 16); + TORCH_CHECK(used_shared_kb > 0); +#else + // MI100 has independent shared mem and L1 + int used_shared_kb = shared_kb; +#endif + int used_shared_bytes = used_shared_kb << 10; + return used_shared_bytes; +} + +template +inline void set_max_dynamic_shared_mem_size_for_opt_search_kernel(const int used_shared_bytes) { +#ifndef USE_ROCM + C10_CUDA_CHECK(cudaFuncSetAttribute( + jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_< + index_t>, + cudaFuncAttributeMaxDynamicSharedMemorySize, + used_shared_bytes)); // V100: 64 KB; A100: 96 KB; H100: 144 KB +#endif +} + +///@addtogroup jagged-tensor-ops-cuda +template +void jagged_dense_elementwise_jagged_output_opt_( + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y, + const Tensor& output_values, + F f) { + TENSOR_ON_CUDA_GPU(x_values); + for (auto& x_offset : x_offsets) { + TENSOR_ON_CUDA_GPU(x_offset); + } + + const int num_jagged_dim = y.dim() - 2; + TORCH_CHECK( + x_offsets.size() == static_cast(num_jagged_dim), + "x_offsets.size(), ", + x_offsets.size(), + " != num_jagged_dim, ", + num_jagged_dim); + + if (y.numel() == 0 || x_values.numel() == 0) { + return; + } + + // Canonicalize y to 3D, collapsing jagged dimensions. 
+ const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)}); + if (jagged_dense_dense_elementwise_jagged_output_matches_opt( + num_jagged_dim, + x_values, + x_offsets, + y_reshaped, + y_reshaped, + output_values)) { + AT_DISPATCH_INDEX_TYPES( + x_offsets[0].scalar_type(), "jagged_indices_fast_path", [=] { + auto nnz = output_values.size(0); + auto B = y_reshaped.size(0); + auto E = y_reshaped.size(2); + Tensor t_rows_after_bs = at::empty( + {nnz}, + at::TensorOptions().dtype(at::kInt).device( + at::kCUDA, at::cuda::current_device())); + Tensor t_cols_after_bs = at::empty( + {nnz}, + at::TensorOptions().dtype(at::kInt).device( + at::kCUDA, at::cuda::current_device())); + + // Binary search + size_t dynamic_smem_size = (B + 1) * sizeof(index_t); + auto cur_max_shared_bytes = + at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock; + if (dynamic_smem_size > cur_max_shared_bytes) { + int used_shared_bytes = calc_used_shared_bytes(y_reshaped.get_device()); + set_max_dynamic_shared_mem_size_for_opt_search_kernel(used_shared_bytes); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + TORCH_CHECK(dynamic_smem_size <= used_shared_bytes); + } + dim3 threads_bs = dim3(1024, 1, 1); + dim3 blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1); + jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_< + index_t> + <<>>( + x_offsets[0] + .packed_accessor32(), + t_rows_after_bs + .packed_accessor32(), + t_cols_after_bs + .packed_accessor32(), + nnz, + B); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + // Gather kernel + dim3 threads = dim3(16, 16, 1); + dim3 blocks = dim3(1, div_round_up(nnz, threads.y), 1); + if (blocks.y > 65535) { + blocks.y = 65535; + } + jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_< + index_t> + <<>>( + output_values + .packed_accessor32(), + x_values + .packed_accessor32(), + y_reshaped + .packed_accessor32(), + y_reshaped + .packed_accessor32(), + t_rows_after_bs + .packed_accessor32(), + t_cols_after_bs + .packed_accessor32(), + nnz, + E, + [f] __device__(__half x, __half y0, __half) -> __half { + // NB: added the static_casts here + return static_cast<__half>( + f(static_cast(x), static_cast(y0)) + ); + }); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); // AT_DISPATCH + } else { + JAGGED_TENSOR_DISPATCH_DIMS(); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +at::Tensor _fbgemm_jagged_to_padded_dense_forward( + const Tensor& values, + TensorList offsets, + c10::IntArrayRef max_lengths, + const double padding_value) { + const size_t num_jagged_dim = offsets.size(); + TORCH_CHECK( + max_lengths.size() == num_jagged_dim, + "max_lengths.size(), ", + max_lengths.size(), + " != num_jagged_dim, ", + num_jagged_dim); + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(values.get_device()); + + const Tensor values_canonicalized = values.view( + {values.size(0), + std::accumulate( + values.sizes().begin() + 1, + values.sizes().end(), + 1, + std::multiplies())}); + at::SymDimVector padded_values_shape({at::SymInt(offsets[0].size(0) - 1)}); + padded_values_shape.insert( + padded_values_shape.end(), max_lengths.begin(), max_lengths.end()); + + // Canonicalize padded_values by unsqueeze the last dim if the inner dense + // dimension is 1 and folded. + const bool D_folded = values.dim() == 1; + if (!D_folded) { + padded_values_shape.push_back(values.size(-1)); + } + Tensor padded_values = + at::empty_symint(padded_values_shape, values.options()); + Tensor padded_values_view = + D_folded ? 
padded_values.unsqueeze(-1) : padded_values; + + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + values.scalar_type(), + "jagged_to_padded_dense", + [&] { + jagged_dense_elementwise_dense_output_( + values_canonicalized, + offsets.vec(), + padded_values_view, // dummy not used in the lambda function + padded_values_view, + [] __device__(scalar_t x, scalar_t /*unused*/) -> scalar_t { + return x; + }, + static_cast(padding_value)); + }); + + return padded_values; +} + +#define DISPATCH_DENSE_TO_JAGGED_CASE(TYPE) \ + AT_DISPATCH_CASE(TYPE, [&] { \ + jagged_dense_elementwise_jagged_output_opt_( \ + values, \ + offsets.vec(), \ + dense, \ + output, \ + [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t { \ + return y; \ + }); \ + }) + +Tensor _fbgemm_dense_to_jagged_forward_symint( + const Tensor& dense, + TensorList offsets, + c10::optional total_L) { + // D is the embedding dimension + auto D = dense.size(-1); + + // If total_L is not given then compute it + at::SymInt total_L_computed; + if (total_L.has_value()) { + total_L_computed = total_L.value(); + } else { + total_L_computed = (int64_t)offsets.back().max().item(); + } + auto values = at::empty_symint({total_L_computed, D}, dense.options()); + auto output = at::empty_like(values); + + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(dense.get_device()); + + // clang-format off + AT_DISPATCH_SWITCH( + values.scalar_type(), + "dense_to_jagged_gpu_op_forward", + DISPATCH_DENSE_TO_JAGGED_CASE(at::ScalarType::Half) + // NB: removed this to build + // DISPATCH_DENSE_TO_JAGGED_CASE(at::ScalarType::Int) + AT_DISPATCH_CASE_FLOATING_TYPES_AND2( + at::ScalarType::Long, + at::ScalarType::BFloat16, + [&] { + jagged_dense_elementwise_jagged_output_( + values, + offsets.vec(), + dense, + output, + [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t { + return y; + }); // device lambda + } // lambda + ) // CASE_FLOATING_TYPES_AND + ); // SWITCH + // clang-format on + +#undef DISPATCH_DENSE_TO_JAGGED_CASE + + return output; +} + } // namespace native } // namespace at diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index f9bf58c5f4746..ad9cf07d75503 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -388,6 +388,7 @@ aten::_int_mm aten::_int_mm.out aten::_is_all_true aten::_is_any_true +aten::_jagged_to_padded_dense_forward aten::_lazy_clone aten::_linalg_check_errors aten::_linalg_det @@ -477,6 +478,7 @@ aten::_nnpack_spatial_convolution.out aten::_nnz aten::_pack_padded_sequence aten::_pack_padded_sequence.out +aten::_padded_dense_to_jagged_forward aten::_pdist_backward aten::_pdist_backward.out aten::_pdist_forward diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index d369135a6e526..a468aac7160ce 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -4414,6 +4414,31 @@ def forward(self, query, value, offsets): self.assertTrue(torch.allclose(attn_output_eager, attn_output)) self.assertTrue(torch.allclose(value_grad, value.grad)) + @dtypes(torch.float64, torch.float32, torch.half) + @onlyCUDA + def test_fbgemm_jagged_to_padded_dense_kernels(self, device, dtype): + values = torch.randn(10, 5, device=device, dtype=dtype) + offsets = torch.tensor([0, 1, 3, 8, 10], device=device, dtype=torch.int64) + max_length = offsets.diff().max().item() + padding_value = 1.3 + + # convert jagged -> padded 
dense + padded = torch.ops.aten._jagged_to_padded_dense_forward( + values, [offsets], [max_length], padding_value + ) + + batch_size = offsets.shape[0] - 1 + expected_padded_shape = (batch_size, max_length, values.shape[-1]) + self.assertEqual(padded.shape, expected_padded_shape) + + # convert padded dense -> jagged + total_L = values.shape[0] + output_jagged = torch.ops.aten._padded_dense_to_jagged_forward( + padded, [offsets], total_L + ) + + # should be equivalent to the original values + self.assertEqual(values, output_jagged) instantiate_parametrized_tests(TestNestedTensor) instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
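
---

For reference, a minimal standalone sketch of how the two ops registered above can be exercised, mirroring the new test case. This assumes a CUDA device (only CUDA kernels are registered for these ops) and uses illustrative shapes/values rather than anything mandated by the patch:

import torch

# Jagged layout: 4 sequences of lengths 1, 2, 5, 2 packed into a (10, 5) values tensor,
# described by cumulative offsets [0, 1, 3, 8, 10].
values = torch.randn(10, 5, device="cuda")
offsets = torch.tensor([0, 1, 3, 8, 10], device="cuda", dtype=torch.int64)
max_length = int(offsets.diff().max())  # longest sequence length (5 here)

# jagged -> padded dense: output is (B, max_length, D); slots past each sequence's
# length are filled with padding_value.
padded = torch.ops.aten._jagged_to_padded_dense_forward(
    values, [offsets], [max_length], 0.0
)
assert padded.shape == (offsets.numel() - 1, max_length, values.shape[-1])

# padded dense -> jagged: recovers the packed (total_L, D) values. Passing total_L
# explicitly avoids the offsets.back().max().item() sync the op otherwise performs.
recovered = torch.ops.aten._padded_dense_to_jagged_forward(
    padded, [offsets], values.shape[0]
)
assert torch.equal(recovered, values)  # round trip is a pure copy for float32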