@@ -32,23 +32,27 @@ at::Tensor dispatch_fp8_rowwise_batched_kernel(
     at::Tensor w_scale, // FP32
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
-  const int arch = getDeviceArch();
 
   TORCH_CHECK(
       (XQ.dim() == 3 && WQ.dim() == 3),
       "FP8 rowwise batched GEMM only supports 3D inputs");
   int M, N;
   M = XQ.size(1);
   N = WQ.size(1);
 
-  const bool use_e5m2 = XQ.dtype() == at::kFloat8_e5m2;
-  if (use_e5m2) {
-    TORCH_CHECK(
-        arch == 9, "f8f8bf16_rowwise_batched only supports FP8 e5m2 on SM90");
-    return f8f8bf16_rowwise_batched_64_128_128_2_1_1_9_f_e5m2(
-        XQ, WQ, x_scale, w_scale, bias, output);
+  TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
+  TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());
+  TORCH_CHECK(XQ.dtype() == at::kFloat8_e4m3fn, "XQ must be FP8 e4m3fn");
+  TORCH_CHECK(WQ.dtype() == at::kFloat8_e4m3fn, "WQ must be FP8 e4m3fn");
+  TORCH_CHECK(
+      x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat,
+      "Scale tensors must be float32.");
+  if (bias.has_value()) {
+    TORCH_CHECK(
+        bias.value().dtype() == at::kFloat,
+        "Bias type must be float32 if provided.");
   }
 
+  const int arch = getDeviceArch();
   if (arch == 10) {
     if ((M * N <= 4096 * 4096) || (N % 256 > 0 && M % 256 == 0) ||
         (M % 256 > 0 && N % 256 > 0) || M >= 1024 && N >= 1024) {
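For orientation, here is a minimal, hypothetical caller sketch (not part of this PR) that satisfies the checks the dispatcher now performs up front: 3D, contiguous, CUDA tensors in FP8 e4m3fn with float32 scales. The helper name, the shapes, the scale layout, and the naive cast-based quantization are illustrative assumptions only; it also assumes dispatch_fp8_rowwise_batched_kernel is visible from the calling translation unit.

#include <ATen/ATen.h>

// Hypothetical usage sketch for dispatch_fp8_rowwise_batched_kernel.
at::Tensor run_fp8_rowwise_batched_example() {
  const int64_t B = 4, M = 512, N = 512, K = 256;
  auto opts = at::TensorOptions().device(at::kCUDA);

  // XQ is (B, M, K) and WQ is (B, N, K), cast naively to FP8 e4m3fn just to
  // satisfy the dtype checks; real callers would quantize with proper scales.
  at::Tensor XQ = at::randn({B, M, K}, opts).to(at::kFloat8_e4m3fn).contiguous();
  at::Tensor WQ = at::randn({B, N, K}, opts).to(at::kFloat8_e4m3fn).contiguous();

  // FP32 row-wise scales; the (B, M) / (B, N) layout is an assumption here,
  // since the dispatcher itself only checks the dtype.
  at::Tensor x_scale = at::ones({B, M}, opts.dtype(at::kFloat));
  at::Tensor w_scale = at::ones({B, N}, opts.dtype(at::kFloat));

  // Returns a (B, M, N) BF16 tensor; the SM90/SM100 kernel instance is picked
  // from getDeviceArch() and the problem shape.
  return dispatch_fp8_rowwise_batched_kernel(
      XQ, WQ, x_scale, w_scale, /*bias=*/std::nullopt, /*output=*/std::nullopt);
}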
@@ -18,16 +18,8 @@ at::Tensor f8f8bf16_rowwise_batched_128_128_128_1_2_1_10_t(
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<
-      128,
-      128,
-      128,
-      1,
-      2,
-      1,
-      10,
-      true,
-      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
+  return f8f8bf16_rowwise_batched_impl<128, 128, 128, 1, 2, 1, 10, true>(
+      XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu
@@ -18,16 +18,8 @@ at::Tensor f8f8bf16_rowwise_batched_128_128_128_1_2_1_9_t(
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<
-      128,
-      128,
-      128,
-      1,
-      2,
-      1,
-      9,
-      true,
-      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
+  return f8f8bf16_rowwise_batched_impl<128, 128, 128, 1, 2, 1, 9, true>(
+      XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu
@@ -18,16 +18,8 @@ at::Tensor f8f8bf16_rowwise_batched_128_128_128_2_1_1_10_t(
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<
-      128,
-      128,
-      128,
-      2,
-      1,
-      1,
-      10,
-      true,
-      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
+  return f8f8bf16_rowwise_batched_impl<128, 128, 128, 2, 1, 1, 10, true>(
+      XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu
@@ -18,16 +18,8 @@ at::Tensor f8f8bf16_rowwise_batched_128_128_128_2_1_1_9_t(
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<
-      128,
-      128,
-      128,
-      2,
-      1,
-      1,
-      9,
-      true,
-      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
+  return f8f8bf16_rowwise_batched_impl<128, 128, 128, 2, 1, 1, 9, true>(
+      XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu
@@ -18,16 +18,8 @@ at::Tensor f8f8bf16_rowwise_batched_64_128_128_1_2_1_10_f(
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<
-      64,
-      128,
-      128,
-      1,
-      2,
-      1,
-      10,
-      false,
-      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
+  return f8f8bf16_rowwise_batched_impl<64, 128, 128, 1, 2, 1, 10, false>(
+      XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu
@@ -18,16 +18,8 @@ at::Tensor f8f8bf16_rowwise_batched_64_128_128_1_2_1_9_f(
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<
-      64,
-      128,
-      128,
-      1,
-      2,
-      1,
-      9,
-      false,
-      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
+  return f8f8bf16_rowwise_batched_impl<64, 128, 128, 1, 2, 1, 9, false>(
+      XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu
@@ -18,16 +18,8 @@ at::Tensor f8f8bf16_rowwise_batched_64_128_128_2_1_1_10_f(
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<
-      64,
-      128,
-      128,
-      2,
-      1,
-      1,
-      10,
-      false,
-      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
+  return f8f8bf16_rowwise_batched_impl<64, 128, 128, 2, 1, 1, 10, false>(
+      XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu
@@ -18,16 +18,8 @@ at::Tensor f8f8bf16_rowwise_batched_64_128_128_2_1_1_9_f(
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<
-      64,
-      128,
-      128,
-      2,
-      1,
-      1,
-      9,
-      false,
-      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
+  return f8f8bf16_rowwise_batched_impl<64, 128, 128, 2, 1, 1, 9, false>(
+      XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu
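All eight instantiation files above follow the same mechanical change: the f8f8bf16_rowwise_batched_wrapper call (which carried the input dtype as a template parameter) is replaced by a direct call to f8f8bf16_rowwise_batched_impl, and the numbers in each file name map onto the non-type template arguments in order: tile shape TB_M/TB_N/TB_K, cluster shape TBS_M/TBS_N/TBS_K, target arch (9 for SM90, presumably 10 for SM100), and a trailing t/f for the PONG flag. As a hypothetical illustration of the pattern (this instance is not part of the PR), a new 128x128x128 tile with a 1x1x1 cluster on SM90 would be wired up like this:

// Hypothetical instance: tile 128x128x128, cluster 1x1x1, SM90, PONG = false.
at::Tensor f8f8bf16_rowwise_batched_128_128_128_1_1_1_9_f(
    at::Tensor XQ, // FP8
    at::Tensor WQ, // FP8
    at::Tensor x_scale, // FP32
    at::Tensor w_scale, // FP32
    std::optional<at::Tensor> bias = std::nullopt,
    std::optional<at::Tensor> output = std::nullopt) {
  // Dispatch this kernel to the correct underlying implementation.
  return f8f8bf16_rowwise_batched_impl<128, 128, 128, 1, 1, 1, 9, false>(
      XQ, WQ, x_scale, w_scale, bias, output);
}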

This file was deleted.

@@ -33,8 +33,7 @@ template <
     int TBS_N,
     int TBS_K,
     int ARCH,
-    bool PONG,
-    typename INPUT_DTYPE>
+    bool PONG>
 at::Tensor f8f8bf16_rowwise_batched_impl(
     at::Tensor XQ, // FP8
     at::Tensor WQ, // FP8
@@ -58,8 +57,6 @@ at::Tensor f8f8bf16_rowwise_batched_impl(
   }
 
   TORCH_CHECK(XQ.size(-1) == K);
-  TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
-  TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());
 
   at::Tensor Y;
   if (output.has_value()) {
@@ -69,7 +66,7 @@ at::Tensor f8f8bf16_rowwise_batched_impl(
     Y = at::empty({B, M, N}, XQ.options().dtype(at::kBFloat16));
   }
 
-  using ElementInputA = INPUT_DTYPE;
+  using ElementInputA = cutlass::float_e4m3_t;
   using LayoutInputA = cutlass::layout::RowMajor;
   constexpr int AlignmentInputA = 16 / sizeof(ElementInputA);
 
@@ -302,45 +299,6 @@ at::Tensor f8f8bf16_rowwise_batched_impl(
   return Y;
 }
 
-template <
-    int TB_M,
-    int TB_N,
-    int TB_K,
-    int TBS_M,
-    int TBS_N,
-    int TBS_K,
-    int ARCH,
-    bool PONG,
-    typename INPUT_DTYPE>
-at::Tensor f8f8bf16_rowwise_batched_wrapper(
-    at::Tensor XQ,
-    at::Tensor WQ,
-    at::Tensor x_scale,
-    at::Tensor w_scale,
-    std::optional<at::Tensor> bias,
-    std::optional<at::Tensor> output) {
-  // Check datatypes.
-  TORCH_CHECK(
-      x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat,
-      "Scale tensors must be float32.");
-  if (bias.has_value()) {
-    TORCH_CHECK(
-        bias.value().dtype() == at::kFloat,
-        "Bias type must be float32 if provided.");
-  }
-
-  return f8f8bf16_rowwise_batched_impl<
-      TB_M,
-      TB_N,
-      TB_K,
-      TBS_M,
-      TBS_N,
-      TBS_K,
-      ARCH,
-      PONG,
-      INPUT_DTYPE>(XQ, WQ, x_scale, w_scale, bias, output);
-}
-
 #else
 
 template <
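A brief note on the element type that is now hard-coded in f8f8bf16_rowwise_batched_impl: with ElementInputA fixed to cutlass::float_e4m3_t (a 1-byte type), the existing AlignmentInputA = 16 / sizeof(ElementInputA) expression works out to 16 elements, i.e. 16-byte vectorized accesses. A stand-alone worked check (the include path is an assumption and may differ across CUTLASS versions):

#include <cutlass/float8.h> // defines cutlass::float_e4m3_t; header layout may vary by CUTLASS version

// Worked check of the alignment arithmetic used in the impl above:
// FP8 e4m3 is 1 byte wide, so a 16-byte access spans 16 elements.
using ElementInputA = cutlass::float_e4m3_t;
constexpr int AlignmentInputA = 16 / sizeof(ElementInputA);
static_assert(sizeof(ElementInputA) == 1, "FP8 e4m3 is a 1-byte type");
static_assert(AlignmentInputA == 16, "16-byte accesses cover 16 FP8 elements");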