STBE GPU coalescing kernel (#2275)
Summary:
Pull Request resolved: #2275

* This is a follow-up for GPU coalescing of the STBE output; we reuse the existing sparse-feature rebatching idea to enable it. A sequence embedding is, after all, a jagged tensor with dense dimension D, and the STBE output is aligned with its input as [sum_l(T, B), D] (see the reference sketch after this list).
* We explicitly fork the code to avoid cramming every function into a single reorder kernel suite.
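
For context, below is a minimal, unoptimized PyTorch sketch of the reorder this kernel performs, reconstructed from the CPU/CUDA implementations in this diff. The function and argument names are illustrative only; the registered operator is torch.ops.fbgemm.reorder_batched_sequence_embeddings (see the schema added in sparse_ops_cpu.cpp), which takes the same five arguments.

import torch

def reorder_sequence_embeddings_reference(
    cat_offsets: torch.Tensor,        # one row offset per (batch, table, item) segment, plus 1
    cat_embeddings: torch.Tensor,     # [sum(lengths), D] rows, grouped as [B][T][item][length][D]
    reordered_offsets: torch.Tensor,  # row offsets for the [T][B][item][length][D] grouping
    batch_offsets: torch.Tensor,      # [B + 1] cumulative item counts per batch
    num_items_in_batch: int,
) -> torch.Tensor:
    B = batch_offsets.numel() - 1
    T = (reordered_offsets.numel() - 1) // num_items_in_batch
    out = torch.empty_like(cat_embeddings)
    for b in range(B):
        items_b = int(batch_offsets[b + 1] - batch_offsets[b])
        for t in range(T):
            # Segment (b, t) covers items_b contiguous sub-segments of embedding rows.
            in_seg = T * int(batch_offsets[b]) + t * items_b
            out_seg = t * num_items_in_batch + int(batch_offsets[b])
            src_lo = int(cat_offsets[in_seg])
            src_hi = int(cat_offsets[in_seg + items_b])
            dst_lo = int(reordered_offsets[out_seg])
            out[dst_lo : dst_lo + (src_hi - src_lo)] = cat_embeddings[src_lo:src_hi]
    return out

Because whole rows of width D are moved contiguously, the offsets produced by reorder_batched_ad_lengths and asynchronous_complete_cumsum can be reused for the embedding reorder, which is exactly what the benchmark below does.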

Reviewed By: jspark1105, 842974287

Differential Revision: D52903658

fbshipit-source-id: cbb5f229c54d4701c6563e90e9a19980dcf781e8
YazhiGao authored and facebook-github-bot committed Jan 25, 2024
1 parent 5e9722a commit 17a0604
Showing 6 changed files with 509 additions and 0 deletions.
96 changes: 96 additions & 0 deletions fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -533,6 +533,102 @@ def reorder_batched_ad_lengths_bench(
)


@cli.command()
@click.option(
"--batch-size", default=32
) # 32 is the representative inference batch size
@click.option("--table-size", default=20)
@click.option("--length", default=512) # long sequence representative case
@click.option("--num-items", default=100)
@click.option("--dim", default=256)
@click.option("--dtype", type=click.Choice(["half", "float"]), default="half")
@click.option("--itype", type=click.Choice(["int", "long"]), default="int")
@click.option("--device", type=str, default="cpu")
def reorder_batched_sequence_embeddings_bench(
batch_size: int,
table_size: int,
length: int,
num_items: int,
dim: int,
dtype: str,
itype: str,
device: str,
) -> None:
assert (
dtype == "float" or dtype == "half"
), "Only 32/16bits floating point number are supported"
data_type = torch.half if dtype == "half" else torch.float

assert itype == "int" or itype == "long", "Only int and long are supported"
index_type = torch.int64 if itype == "long" else torch.int32

    cat_sequence_embeddings = torch.rand(
        (batch_size * table_size * num_items * length, dim),
        dtype=data_type,
    ).to(device)
cat_sequence_embeddings_lengths = (
torch.cat(
[
torch.tensor([length for _ in range(table_size * num_items)])
for _ in range(batch_size)
],
0,
)
.to(index_type)
.to(device)
)

    batch_offsets = (
        torch.tensor([num_items * b for b in range(batch_size + 1)])
        .to(index_type)
        .to(device)
    )
num_items_in_batch = batch_size * num_items
reordered_cat_sequence_embeddings_lengths = (
torch.ops.fbgemm.reorder_batched_ad_lengths(
cat_sequence_embeddings_lengths,
batch_offsets,
num_items_in_batch,
).to(device)
)

cat_sequence_embeddings_offsets = (
torch.ops.fbgemm.asynchronous_complete_cumsum(cat_sequence_embeddings_lengths)
.to(index_type)
.to(device)
)
reordered_cat_sequence_embeddings_offsets = (
torch.ops.fbgemm.asynchronous_complete_cumsum(
reordered_cat_sequence_embeddings_lengths
)
.to(index_type)
.to(device)
)
time, _ = benchmark_torch_function(
torch.ops.fbgemm.reorder_batched_sequence_embeddings,
(
cat_sequence_embeddings_offsets,
cat_sequence_embeddings,
reordered_cat_sequence_embeddings_offsets,
batch_offsets,
num_items_in_batch,
),
num_warmups=100,
iters=1000,
)
    num_bytes = (
        batch_size
        * table_size
        * num_items
        * length
        * dim
        * cat_sequence_embeddings.element_size()
    )
logging.info(
f"fbgemm_gpu time: {time * 1000:.5f} ms ({num_bytes / time / 1e9:.5f} GB/s)"
)


@cli.command()
@click.option("--num-inputs", default=1024)
@click.option("--rows", default=100)
15 changes: 15 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -351,6 +351,14 @@ at::Tensor reorder_batched_ad_indices_gpu(
const bool broadcast_indices = false,
const int64_t num_indices_after_broadcast = -1);

///@ingroup sparse-data-cuda
at::Tensor reorder_batched_sequence_embeddings_gpu(
const at::Tensor& cat_sequence_embeddings_offsets,
const at::Tensor& cat_sequence_embeddings,
const at::Tensor& reordered_cat_sequence_embeddings_offsets,
const at::Tensor& batch_offsets,
const int64_t num_items_in_batch);

///@ingroup sparse-data-cpu
at::Tensor reorder_batched_ad_lengths_cpu(
const at::Tensor& cat_ad_lengths,
@@ -367,6 +375,13 @@ at::Tensor reorder_batched_ad_indices_cpu(
const bool broadcast_indices = false,
const int64_t num_indices_after_broadcast = -1);
///@ingroup sparse-data-cpu
at::Tensor reorder_batched_sequence_embeddings_cpu(
const at::Tensor& cat_sequence_embeddings_offsets,
const at::Tensor& cat_sequence_embeddings,
const at::Tensor& reordered_cat_sequence_embeddings_offsets,
const at::Tensor& batch_offsets,
const int64_t num_items_in_batch);
///@ingroup sparse-data-cpu
at::Tensor cat_reorder_batched_ad_indices_cpu(
const at::Tensor& cat_ad_offsets,
const std::vector<at::Tensor>& cat_ad_indices,
106 changes: 106 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -1457,6 +1457,107 @@ void cat_reorder_batched_ad_indices_cpu_(
});
}

template <typename index_t, typename scalar_t>
void reorder_batched_sequence_embeddings_cpu_(
const Tensor& cat_sequence_embeddings_offsets,
const Tensor& cat_sequence_embeddings,
const Tensor& reordered_cat_sequence_embeddings_offsets,
const Tensor& batch_offsets,
const int64_t num_items_in_batch,
const int32_t dim,
Tensor& output) {
const int64_t nB = batch_offsets.numel() - 1;
const int64_t nT = (reordered_cat_sequence_embeddings_offsets.numel() - 1) /
num_items_in_batch;

const auto* batch_offsets_data = batch_offsets.data_ptr<index_t>();
const auto* cat_sequence_embeddings_offsets_data =
cat_sequence_embeddings_offsets.data_ptr<index_t>();
const auto* reordered_cat_sequence_embeddings_offsets_data =
reordered_cat_sequence_embeddings_offsets.data_ptr<index_t>();
const auto* cat_sequence_embeddings_data =
cat_sequence_embeddings.data_ptr<scalar_t>();
auto* output_data = output.data_ptr<scalar_t>();
at::parallel_for(
0, nB * nT, FALSE_SHARING_PAD, [&](int64_t tb_begin, int64_t tb_end) {
auto b_begin = tb_begin / nT;
auto b_end = (tb_end + nT - 1) / nT;

for (const auto b : c10::irange(b_begin, b_end)) {
const auto num_ads_b =
batch_offsets_data[b + 1] - batch_offsets_data[b];
int64_t t_begin = (b == b_begin) ? tb_begin % nT : 0;
int64_t t_end =
(b == b_end - 1 && tb_end % nT != 0) ? tb_end % nT : nT;
for (const auto t : c10::irange(t_begin, t_end)) {
const auto output_segment_offset_start =
t * num_items_in_batch + batch_offsets_data[b];
const auto output_segment_start =
reordered_cat_sequence_embeddings_offsets_data
[output_segment_offset_start] *
dim;
const int32_t input_segment_offset_start =
nT * batch_offsets_data[b] + t * num_ads_b;
const int32_t input_segment_offset_end =
input_segment_offset_start + num_ads_b;
const auto input_segment_start =
cat_sequence_embeddings_offsets_data
[input_segment_offset_start] *
dim;
const auto input_segment_end =
cat_sequence_embeddings_offsets_data[input_segment_offset_end] *
dim;
const auto num_elements = (input_segment_end - input_segment_start);

for (auto i : c10::irange(num_elements)) {
// TODO memcpy once this path is heavily used?
output_data[output_segment_start + i] =
cat_sequence_embeddings_data[input_segment_start + i];
}
}
}
});
}

Tensor reorder_batched_sequence_embeddings_cpu(
const Tensor& cat_sequence_embeddings_offsets,
const Tensor& cat_sequence_embeddings,
const Tensor& reordered_cat_sequence_embeddings_offsets,
const Tensor& batch_offsets,
const int64_t num_items_in_batch) {
TENSOR_ON_CPU(cat_sequence_embeddings_offsets);
TENSOR_ON_CPU(cat_sequence_embeddings);
TENSOR_ON_CPU(reordered_cat_sequence_embeddings_offsets);
TENSOR_ON_CPU(batch_offsets);
TORCH_CHECK(cat_sequence_embeddings.dim() == 2);
  // Reorder embeddings from the ragged layout
  // [B][T][#num_ads_{b}][length_{b, t, a}][D] to
  // [T][B][#num_ads_{b}][length_{b, t, a}][D], i.e. a
  // [sum(length_{b, t, a}), D] tensor.
  Tensor reordered_cat_sequence_embeddings = at::empty_like(
      cat_sequence_embeddings, cat_sequence_embeddings.options());

AT_DISPATCH_INDEX_TYPES(
cat_sequence_embeddings_offsets.scalar_type(),
"reorder_batched_sequence_embeddings_cpu_kernel_1",
[&] {
AT_DISPATCH_ALL_TYPES(
cat_sequence_embeddings.scalar_type(),
"reorder_eorder_batched_sequence_embeddings_cpu_kernel_2",
[&] {
reorder_batched_sequence_embeddings_cpu_<index_t, scalar_t>(
cat_sequence_embeddings_offsets,
cat_sequence_embeddings,
reordered_cat_sequence_embeddings_offsets,
batch_offsets,
num_items_in_batch,
cat_sequence_embeddings.size(1),
                  reordered_cat_sequence_embeddings);
});
});

  return reordered_cat_sequence_embeddings;
}

Tensor reorder_batched_ad_indices_cpu(
const Tensor& cat_ad_offsets,
const Tensor& cat_ad_indices,
@@ -2752,6 +2853,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"asynchronous_complete_cumsum(Tensor t_in) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"reorder_batched_sequence_embeddings(Tensor cat_sequence_embeddings_offsets, Tensor cat_sequence_embeddings, Tensor reordered_cat_sequence_embeddings_offsets, Tensor batch_offsets, SymInt num_items_in_batch) -> Tensor");
m.def(
"reorder_batched_ad_lengths(Tensor cat_ad_lengths, Tensor batch_offsets, SymInt num_ads_in_batch, bool broadcast_lengths=False) -> Tensor");
m.def(
@@ -2856,6 +2959,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
DISPATCH_TO_CPU(
"cat_reorder_batched_ad_indices",
fbgemm_gpu::cat_reorder_batched_ad_indices_cpu);
DISPATCH_TO_CPU(
"reorder_batched_sequence_embeddings",
fbgemm_gpu::reorder_batched_sequence_embeddings_cpu);
DISPATCH_TO_CPU("offsets_range", fbgemm_gpu::offsets_range_cpu);
DISPATCH_TO_CPU(
"batched_unary_embeddings",
3 changes: 3 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -751,6 +751,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
"reorder_batched_ad_lengths", fbgemm_gpu::reorder_batched_ad_lengths_gpu);
DISPATCH_TO_CUDA(
"reorder_batched_ad_indices", fbgemm_gpu::reorder_batched_ad_indices_gpu);
DISPATCH_TO_CUDA(
"reorder_batched_sequence_embeddings",
fbgemm_gpu::reorder_batched_sequence_embeddings_gpu);
DISPATCH_TO_CUDA(
"batched_unary_embeddings",
fbgemm_gpu::lookup_batched_unary_embedding_function);
113 changes: 113 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu
@@ -273,4 +273,117 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu(
return reordered_cat_ad_indices;
}
template <typename Dtype, typename index_t = int32_t>
__global__
__launch_bounds__(kMaxThreads) void reorder_batched_sequence_embeddings_kernel(
    // Reorder embeddings from the ragged layout
    // [B][T][#num_ads_{b}][length_{b, t, a}][D] to
    // [T][B][#num_ads_{b}][length_{b, t, a}][D], i.e. a
    // [sum(length_{b, t, a}), D] tensor.
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
cat_sequence_embeddings_offsets,
const at::PackedTensorAccessor32<Dtype, 2, at::RestrictPtrTraits>
cat_sequence_embeddings,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
reordered_cat_sequence_embeddings_offsets,
at::PackedTensorAccessor32<Dtype, 2, at::RestrictPtrTraits>
reordered_cat_sequence_embeddings,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
batch_offsets,
const int32_t T,
const int32_t D) {
const int32_t B = batch_offsets.size(0) - 1;
const int32_t num_items_in_batch = batch_offsets[B];
// warp-per-segment.
const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
const int32_t b = b_t % B;
const int32_t t = b_t / B;
if (t >= T) {
return;
}
const auto num_ads_b = batch_offsets[b + 1] - batch_offsets[b];
const auto output_segment_offset_start =
t * num_items_in_batch + batch_offsets[b];
const auto output_segment_start =
reordered_cat_sequence_embeddings_offsets[output_segment_offset_start];
const int32_t input_segment_offset_start =
T * batch_offsets[b] + t * num_ads_b;
const int32_t input_segment_offset_end =
input_segment_offset_start + num_ads_b;
const auto input_segment_start =
cat_sequence_embeddings_offsets[input_segment_offset_start];
const auto input_segment_end =
cat_sequence_embeddings_offsets[input_segment_offset_end];
const auto num_elements = input_segment_end - input_segment_start;
  for (size_t i = 0; i < num_elements; i++) {
const auto output_offset = output_segment_start + i;
const auto input_offset = input_segment_start + i;
for (int32_t d = threadIdx.x; d < D; d += blockDim.x) {
reordered_cat_sequence_embeddings[output_offset][d] =
cat_sequence_embeddings[input_offset][d];
}
}
}
DLL_PUBLIC Tensor reorder_batched_sequence_embeddings_gpu(
const Tensor& cat_sequence_embeddings_offsets,
const Tensor& cat_sequence_embeddings,
const Tensor& reordered_cat_sequence_embeddings_offsets,
const Tensor& batch_offsets,
const int64_t num_items_in_batch) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
cat_sequence_embeddings_offsets,
cat_sequence_embeddings,
reordered_cat_sequence_embeddings_offsets,
batch_offsets);
const auto cat_sequence_embeddings_contig =
cat_sequence_embeddings.expect_contiguous();
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(cat_sequence_embeddings_offsets.get_device());
const int64_t B = batch_offsets.numel() - 1;
const int64_t T = (reordered_cat_sequence_embeddings_offsets.numel() - 1) /
num_items_in_batch;
const int64_t D = cat_sequence_embeddings.size(1);
Tensor reordered_cat_sequence_embeddings =
at::empty_like(cat_sequence_embeddings);
const dim3 threads(32, 32);
const dim3 blocks((B * T + 32 - 1) / 32);
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
cat_sequence_embeddings.scalar_type(),
"reorder_batched_sequence_embeddings_gpu_kernel_1",
[&] {
AT_DISPATCH_INDEX_TYPES(
cat_sequence_embeddings_offsets.scalar_type(),
"reorder_batched_sequence_embeddings_gpu_kernel_2",
[&] {
reorder_batched_sequence_embeddings_kernel<scalar_t, index_t><<<
blocks,
threads,
0,
at::cuda::getCurrentCUDAStream()>>>(
cat_sequence_embeddings_offsets
.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
cat_sequence_embeddings_contig
->packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(),
reordered_cat_sequence_embeddings_offsets
.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
reordered_cat_sequence_embeddings
.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(),
batch_offsets
.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
T,
D);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
return reordered_cat_sequence_embeddings;
}
} // namespace fbgemm_gpu