Skip to content

Commit

Permalink
tune the tile sizes for jagged_dense_bmm (#1692)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1692

Tune the tile sizes based on the input tensor size. If M > N, then use larger tile size in M dimension, otherwise use larger tile size in N dimension.

Reviewed By: brad-mengchi

Differential Revision: D44791699

fbshipit-source-id: fe45b508f1b8a1e61bd4be231aebf3bd7b26443e
  • Loading branch information
Rengan Xu authored and facebook-github-bot committed Apr 7, 2023
1 parent 194265f commit 2d46d4e
Showing 1 changed file with 82 additions and 39 deletions.
121 changes: 82 additions & 39 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2530,48 +2530,91 @@ Tensor jagged_dense_bmm_forward(
// memory size limit and occupancy
// TODO: autotune these parameters based on max_L and input and output
// tensor sizes
constexpr int BLOCK_TILE_M = 64;
constexpr int BLOCK_TILE_N = 8;
constexpr int BLOCK_TILE_K = 8;
constexpr int THREAD_TILE_M = 4;
constexpr int THREAD_TILE_N = 4;
constexpr int THREAD_TILE_M = 2;
constexpr int THREAD_TILE_N = 2;
const dim3 block(
(BLOCK_TILE_M * BLOCK_TILE_N) / (THREAD_TILE_M * THREAD_TILE_N));
const auto grid_dim_x = div_round_up(N, BLOCK_TILE_N);
const auto grid_dim_y = div_round_up(max_L, BLOCK_TILE_M);
TORCH_CHECK(
grid_dim_y <= kMaxBlockYDim,
"max_L cannot be larger than",
grid_dim_y * BLOCK_TILE_M + 1 - BLOCK_TILE_M);
const auto grid_dim_z = std::min(B, kMaxBlockZDim);
const dim3 grid(grid_dim_x, grid_dim_y, grid_dim_z);
if (M > N) {
constexpr int BLOCK_TILE_M = 32;
constexpr int BLOCK_TILE_N = 8;
AT_DISPATCH_INDEX_TYPES(
x_offsets.scalar_type(), "jagged_dense_bmm_kernel_1", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
x_values.scalar_type(),
"jagged_dense_bmm_kernel_2",
[&] {
jagged_dense_bmm_kernel<
BLOCK_TILE_M,
BLOCK_TILE_N,
BLOCK_TILE_K,
THREAD_TILE_M,
THREAD_TILE_N,
index_t,
scalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
x_values.packed_accessor32<scalar_t, 2>(),
x_offsets.packed_accessor32<index_t, 1>(),
y.packed_accessor32<scalar_t, 3>(),
output.packed_accessor32<scalar_t, 2>(),
(int)max_L);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
const dim3 block(
(BLOCK_TILE_M * BLOCK_TILE_N) / (THREAD_TILE_M * THREAD_TILE_N));
const auto grid_dim_x = div_round_up(N, BLOCK_TILE_N);
const auto grid_dim_y = div_round_up(max_L, BLOCK_TILE_M);
TORCH_CHECK(
grid_dim_y <= kMaxBlockYDim,
"max_L cannot be larger than",
kMaxBlockYDim * BLOCK_TILE_M + 1 - BLOCK_TILE_M);
const auto grid_dim_z = std::min(B, kMaxBlockZDim);
const dim3 grid(grid_dim_x, grid_dim_y, grid_dim_z);
AT_DISPATCH_INDEX_TYPES(
x_offsets.scalar_type(), "jagged_dense_bmm_kernel_1", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
x_values.scalar_type(),
"jagged_dense_bmm_kernel_2",
[&] {
jagged_dense_bmm_kernel<
BLOCK_TILE_M,
BLOCK_TILE_N,
BLOCK_TILE_K,
THREAD_TILE_M,
THREAD_TILE_N,
index_t,
scalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
x_values.packed_accessor32<scalar_t, 2>(),
x_offsets.packed_accessor32<index_t, 1>(),
y.packed_accessor32<scalar_t, 3>(),
output.packed_accessor32<scalar_t, 2>(),
(int)max_L);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
} else {
constexpr int BLOCK_TILE_M = 8;
constexpr int BLOCK_TILE_N = 32;
const dim3 block(
(BLOCK_TILE_M * BLOCK_TILE_N) / (THREAD_TILE_M * THREAD_TILE_N));
const auto grid_dim_x = div_round_up(N, BLOCK_TILE_N);
const auto grid_dim_y = div_round_up(max_L, BLOCK_TILE_M);
TORCH_CHECK(
grid_dim_y <= kMaxBlockYDim,
"max_L cannot be larger than",
kMaxBlockYDim * BLOCK_TILE_M + 1 - BLOCK_TILE_M);
const auto grid_dim_z = std::min(B, kMaxBlockZDim);
const dim3 grid(grid_dim_x, grid_dim_y, grid_dim_z);
AT_DISPATCH_INDEX_TYPES(
x_offsets.scalar_type(), "jagged_dense_bmm_kernel_1", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
x_values.scalar_type(),
"jagged_dense_bmm_kernel_2",
[&] {
jagged_dense_bmm_kernel<
BLOCK_TILE_M,
BLOCK_TILE_N,
BLOCK_TILE_K,
THREAD_TILE_M,
THREAD_TILE_N,
index_t,
scalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
x_values.packed_accessor32<scalar_t, 2>(),
x_offsets.packed_accessor32<index_t, 1>(),
y.packed_accessor32<scalar_t, 3>(),
output.packed_accessor32<scalar_t, 2>(),
(int)max_L);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
}
}
return output;
Expand Down

0 comments on commit 2d46d4e

Please sign in to comment.