diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py index 997e2ccd1f..cc1ac3150c 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py @@ -289,7 +289,7 @@ def triton_quantize_mx4_unpack( stochastic_casting (bool): Whether to use stochastic casting. Returns: - torch.Tensor: [M / 2] mx4 scaled tensor packed into in8 + torch.Tensor: [M / 2] mx4 scaled tensor packed into uint8 torch.Tensor: [M / group_size] mx4 shared exponents into int8 eg. diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py index f7e05758bc..f351ce0ee2 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py @@ -2385,7 +2385,7 @@ def quantize(self, x, w): def compute(self, xq, wq, x_scale, w_scale, global_scale): return torch.ops.fbgemm.f4f4bf16( - xq, wq, x_scale, w_scale, global_scale=global_scale, use_mx=False + xq, wq, x_scale, w_scale, global_scale=global_scale ) def quantize_and_compute(self, x, w): @@ -2471,7 +2471,7 @@ def quantize(self, x, w): def compute(self, xq, wq, x_scale, w_scale, global_scale): return torch.ops.fbgemm.f4f4bf16( - xq, wq, x_scale, w_scale, global_scale=global_scale, use_mx=False + xq, wq, x_scale, w_scale, global_scale=global_scale ) def quantize_and_compute(self, x, w): diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16.cu index 6c3665f542..29c36f2eef 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16.cu @@ -9,7 +9,6 @@ #include #include #include -// clang-format on #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) #include "f4f4bf16/f4f4bf16_manifest.cuh" @@ -19,147 +18,111 @@ namespace fbgemm_gpu { #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) -at::Tensor dispatch_f4f4bf16_kernel( - at::Tensor XQ, // FP4 - at::Tensor WQ, // FP4 - at::Tensor x_scale, - at::Tensor w_scale, - std::optional global_scale, - bool use_mx = true) { - auto M = XQ.size(0); - auto N = WQ.size(0); - auto K = XQ.size(1) * 2; // Since K is packed - auto BLOCK_SIZE = 16; - TORCH_CHECK( - N % BLOCK_SIZE == 0 && K % BLOCK_SIZE == 0, - "Weight dimensions N and K must be multiples of block size 16"); - - auto out_sizes = XQ.sizes().vec(); - out_sizes.back() = N; - if (M == 0 || N == 0 || K == 0) { - // Use zeros instead of empty for special case where K=0. 
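// [Review note] Why zeros rather than empty: when K == 0 the GEMM reduces over
// an empty axis, so the mathematically correct result is an all-zero [M, N]
// matrix, whereas at::empty would return uninitialized memory. The same
// early-return (comment included) reappears in the new public f4f4bf16()
// wrapper later in this patch.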
- return at::zeros(out_sizes, XQ.options().dtype(at::kBFloat16)); - } - +Kernel_f4f4bf16 get_kernel_via_heuristics(int M, int N, int K, bool use_mx) { // MXFP4 if (use_mx) { if (M <= 128) { if (N <= 1024) { - return f4f4bf16_256_128_2_4_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_128_2_4_1_t; } else if (N <= 2048) { - return f4f4bf16_256_192_4_1_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_192_4_1_1_t; } else { - return f4f4bf16_128_128_4_1_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_128_4_1_1_t; } } else if (M <= 2048) { if (M <= 256) { if (N == 896) { - return f4f4bf16_128_128_2_2_1_t( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_128_2_2_1_t; } else if (N == 5120) { if (K == 640 || K == 5120) { - return f4f4bf16_128_128_4_1_1_t( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_128_4_1_1_t; } else if ((K == 8192) || (K == 16384)) { - return f4f4bf16_256_128_2_2_1_t( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_128_2_2_1_t; } } else if (N == 5632) { - return f4f4bf16_128_192_2_2_1_t( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_192_2_2_1_t; } else if (N == 8192) { - return f4f4bf16_256_128_2_2_1_t( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_128_2_2_1_t; } } else if (M <= 512) { if (N == 896) { - return f4f4bf16_128_128_2_2_1_t( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_128_2_2_1_t; } else if (N == 5120) { - return f4f4bf16_256_192_4_1_1_t( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_192_4_1_1_t; } else if (N == 5632) { - return f4f4bf16_256_128_2_4_1_t( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_128_2_4_1_t; } else if (N == 8192) { - return f4f4bf16_256_128_2_2_1_t( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_128_2_2_1_t; } } else if (M <= 1024) { if (N == 896) { - return f4f4bf16_256_128_2_4_1_t( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_128_2_4_1_t; } else if (N == 5120) { if (K == 640) { - return f4f4bf16_128_128_1_4_1_t( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_128_1_4_1_t; } else if (K == 5120) { - return f4f4bf16_128_192_4_2_1_t( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_192_4_2_1_t; } else if (K == 5120 || K == 16384) { - return f4f4bf16_256_128_2_4_1_t( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_128_2_4_1_t; } } else if (N == 5632) { - return f4f4bf16_256_128_2_4_1_t( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_128_2_4_1_t; } else if (N == 8192) { - return f4f4bf16_256_256_4_1_1_t( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_4_1_1_t; } } if (N <= 2048) { - return f4f4bf16_256_128_2_2_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_128_2_2_1_t; } else if (N <= 8192) { - return f4f4bf16_128_256_2_1_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_256_2_1_1_t; } else { - return f4f4bf16_256_256_2_1_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_2_1_1_t; } } else if (M <= 4096) { if (N <= 4096) { - return f4f4bf16_256_256_4_1_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_4_1_1_t; } else if (N <= 8192) { - return f4f4bf16_256_256_2_1_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_2_1_1_t; } else { - return f4f4bf16_256_128_2_4_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return 
f4f4bf16_256_128_2_4_1_t; } } else if (M <= 8192) { if (N <= 4096) { - return f4f4bf16_256_256_2_2_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_2_2_1_t; } else if (N <= 8192) { - return f4f4bf16_256_256_2_4_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_2_4_1_t; } else { - return f4f4bf16_128_256_2_1_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_256_2_1_1_t; } } else if (M <= 16384) { if (N <= 2048) { - return f4f4bf16_256_256_2_4_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_2_4_1_t; } else if (N <= 8192) { - return f4f4bf16_128_192_2_2_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_192_2_2_1_t; } else { - return f4f4bf16_128_256_2_1_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_256_2_1_1_t; } } else if (M <= 32768) { if (N <= 1024) { - return f4f4bf16_256_256_2_1_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_2_1_1_t; } else if (N <= 4096) { - return f4f4bf16_128_192_2_2_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_192_2_2_1_t; } else { - return f4f4bf16_256_192_4_1_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_192_4_1_1_t; } } else if (M <= 65536) { if (N <= 2048) { - return f4f4bf16_256_192_2_4_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_192_2_4_1_t; } else if (N <= 4096) { - return f4f4bf16_256_192_2_2_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_192_2_2_1_t; } else { - return f4f4bf16_256_256_2_1_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_2_1_1_t; } } else { if (N <= 1024) { - return f4f4bf16_256_192_2_4_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_192_2_4_1_t; } else { - return f4f4bf16_256_256_2_2_1_t(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_2_2_1_t; } } } @@ -167,121 +130,106 @@ at::Tensor dispatch_f4f4bf16_kernel( else { if (M <= 128) { if (N <= 1024) { - return f4f4bf16_256_128_2_4_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_128_2_4_1_f; } else if (N <= 2048) { - return f4f4bf16_256_192_4_1_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_192_4_1_1_f; } else { - return f4f4bf16_128_128_4_1_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_128_4_1_1_f; } } else if (M <= 2048) { if (M <= 256) { if (N == 896) { - return f4f4bf16_128_128_2_2_1_f( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_128_2_2_1_f; } else if (N == 5120) { if (K == 640 || K == 5120) { - return f4f4bf16_128_128_4_1_1_f( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_128_4_1_1_f; } else if ((K == 8192) || (K == 16384)) { - return f4f4bf16_256_128_2_2_1_f( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_128_2_2_1_f; } } else if (N == 5632) { - return f4f4bf16_128_192_2_2_1_f( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_192_2_2_1_f; } else if (N == 8192 || N == 16384) { - return f4f4bf16_256_128_2_2_1_f( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_128_2_2_1_f; } } else if (M <= 512) { if (N == 896) { - return f4f4bf16_128_128_2_2_1_f( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_128_2_2_1_f; } else if (N == 5120) { - return f4f4bf16_256_192_4_1_1_f( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_192_4_1_1_f; } else if (N == 5632) { - return f4f4bf16_256_128_2_4_1_f( - XQ, WQ, x_scale, w_scale, global_scale); 
+ return f4f4bf16_256_128_2_4_1_f; } else if (N == 8192) { - return f4f4bf16_256_128_2_2_1_f( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_128_2_2_1_f; } } else if (M <= 1024) { if (N == 896) { - return f4f4bf16_256_128_2_4_1_f( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_128_2_4_1_f; } else if (N == 5120) { if (K == 640) { - return f4f4bf16_128_128_1_4_1_f( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_128_1_4_1_f; } else if (K == 5120) { - return f4f4bf16_128_192_4_2_1_f( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_192_4_2_1_f; } else if (K == 5120 || K == 16384) { - return f4f4bf16_256_128_2_4_1_f( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_128_2_4_1_f; } } else if (N == 5632) { - return f4f4bf16_256_128_2_4_1_f( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_128_2_4_1_f; } else if (N == 8192) { - return f4f4bf16_256_256_4_1_1_f( - XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_4_1_1_f; } } if (N <= 2048) { - return f4f4bf16_256_128_2_2_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_128_2_2_1_f; } else if (N <= 8192) { - return f4f4bf16_128_256_2_1_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_256_2_1_1_f; } else { - return f4f4bf16_256_256_2_1_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_2_1_1_f; } } else if (M <= 4096) { if (N <= 4096) { - return f4f4bf16_256_256_4_1_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_4_1_1_f; } else if (N <= 8192) { - return f4f4bf16_256_256_2_1_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_2_1_1_f; } else { - return f4f4bf16_256_128_2_4_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_128_2_4_1_f; } } else if (M <= 8192) { if (N <= 4096) { - return f4f4bf16_256_256_2_2_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_2_2_1_f; } else if (N <= 8192) { - return f4f4bf16_256_256_2_4_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_2_4_1_f; } else { - return f4f4bf16_128_256_2_1_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_256_2_1_1_f; } } else if (M <= 16384) { if (N <= 2048) { - return f4f4bf16_256_256_2_4_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_2_4_1_f; } else if (N <= 8192) { - return f4f4bf16_128_192_2_2_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_192_2_2_1_f; } else { - return f4f4bf16_128_256_2_1_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_256_2_1_1_f; } } else if (M <= 32768) { if (N <= 1024) { - return f4f4bf16_256_256_2_1_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_2_1_1_f; } else if (N <= 4096) { - return f4f4bf16_128_192_2_2_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_128_192_2_2_1_f; } else { - return f4f4bf16_256_192_4_1_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_192_4_1_1_f; } } else if (M <= 65536) { if (N <= 2048) { - return f4f4bf16_256_192_2_4_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_192_2_4_1_f; } else if (N <= 4096) { - return f4f4bf16_256_192_2_2_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_192_2_2_1_f; } else { - return f4f4bf16_256_256_2_1_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_2_1_1_f; } } else { if (N <= 1024) { - return f4f4bf16_256_192_2_4_1_f(XQ, WQ, x_scale, w_scale, 
global_scale); + return f4f4bf16_256_192_2_4_1_f; } else { - return f4f4bf16_256_256_2_2_1_f(XQ, WQ, x_scale, w_scale, global_scale); + return f4f4bf16_256_256_2_2_1_f; } } } @@ -292,10 +240,32 @@ at::Tensor f4f4bf16( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, - std::optional global_scale, - bool use_mx = true) { - return dispatch_f4f4bf16_kernel( - XQ, WQ, x_scale, w_scale, global_scale, use_mx); + std::optional output, + std::optional global_scale) { + TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); + TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); + TORCH_CHECK(x_scale.is_cuda() && x_scale.is_contiguous()); + TORCH_CHECK(w_scale.is_cuda() && w_scale.is_contiguous()); + + const auto M = XQ.size(0); + const auto N = WQ.size(0); + const auto K = XQ.size(1) * 2; // Since K is packed + constexpr auto BLOCK_SIZE = 16; + TORCH_CHECK( + N % BLOCK_SIZE == 0 && K % BLOCK_SIZE == 0, + "Weight dimensions N and K must be multiples of block size 16"); + + if (M == 0 || N == 0 || K == 0) { + // Use zeros instead of empty for special case where K=0. + return at::zeros({M, N}, XQ.options().dtype(at::kBFloat16)); + } + + at::Tensor out = output.has_value() + ? output.value() + : at::empty({M, N}, XQ.options().dtype(at::kBFloat16)); + + auto kernel = get_kernel_via_heuristics(M, N, K, !global_scale.has_value()); + return kernel(XQ, WQ, x_scale, w_scale, out, global_scale); } #else @@ -305,8 +275,8 @@ at::Tensor f4f4bf16( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, - std::optional global_scale, - bool use_mx = true) { + std::optional output, + std::optional global_scale) { throw std::runtime_error( "CUDA version is older than 12.8"); // requires CUDA>=12.8 } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_1_1_f.cu index 7e255d8753..1a8ae988db 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_1_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_1_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_128_1_1_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_128_1_1_1_f( 128, 1, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_1_1_t.cu index c66665beb3..5233c74e3c 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_1_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_1_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_128_1_1_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. 
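// [Review note] The public f4f4bf16() above no longer takes use_mx; the FP4
// recipe is inferred from global_scale (use_mx = !global_scale.has_value()):
// NVFP4 callers pass a global scale, MXFP4 callers do not. This is also why
// the Python benchmarks earlier in this patch drop the explicit use_mx=False.
// A caller-side sketch, assuming already-quantized 2D row-major operands:
at::Tensor y_mx = f4f4bf16(XQ, WQ, x_scale, w_scale,
                           /*output=*/std::nullopt,
                           /*global_scale=*/std::nullopt); // MXFP4 (_t) kernels
at::Tensor y_nv = f4f4bf16(XQ, WQ, x_scale, w_scale,
                           /*output=*/std::nullopt,
                           /*global_scale=*/global_scale); // NVFP4 (_f) kernels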
return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_128_1_1_1_t( 128, 1, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_2_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_2_1_f.cu index aee8e2de7e..7f07dc9cce 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_2_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_2_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_128_1_2_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_128_1_2_1_f( 128, 1, 2, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_2_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_2_1_t.cu index 103923cc8c..95438c386a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_2_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_2_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_128_1_2_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_128_1_2_1_t( 128, 1, 2, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_4_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_4_1_f.cu index 9706826224..76c7beace3 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_4_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_4_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_128_1_4_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. 
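// [Review note] get_kernel_via_heuristics() now returns a kernel handle that
// the wrapper invokes exactly once, replacing the dozens of per-branch call
// sites in the old dispatcher. A minimal sketch of the presumed alias,
// assuming the std::optional parameters throughout this patch are
// std::optional<at::Tensor>:
using Kernel_f4f4bf16 = at::Tensor (*)(
    at::Tensor,                  // XQ (packed FP4)
    at::Tensor,                  // WQ (packed FP4)
    at::Tensor,                  // x_scale
    at::Tensor,                  // w_scale
    at::Tensor,                  // output (preallocated [M, N] bf16)
    std::optional<at::Tensor>);  // global_scale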
return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_128_1_4_1_f( 128, 1, 4, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_4_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_4_1_t.cu index d9d204a72d..ba8d9a235b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_4_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_1_4_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_128_1_4_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_128_1_4_1_t( 128, 1, 4, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_2_2_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_2_2_1_f.cu index bf3b54e1ea..3214c2ce43 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_2_2_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_2_2_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_128_2_2_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_128_2_2_1_f( 128, 2, 2, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_2_2_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_2_2_1_t.cu index 0ee335f2f7..c44802b533 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_2_2_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_2_2_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_128_2_2_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. 
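// [Review note] Naming scheme: f4f4bf16_{A}_{B}_{C}_{D}_{E}_{t|f}. The numeric
// fields mirror the template arguments handed to _f4f4bf16 (by the looks of
// it, the tile shape M/N followed by the cluster shape), and the _t/_f suffix
// is the baked-in use_mx flag: _t kernels serve MXFP4, _f kernels serve NVFP4.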
return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_128_2_2_1_t( 128, 2, 2, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_1_1_f.cu index 2106e754bd..1e35f7b279 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_1_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_1_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_128_4_1_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_128_4_1_1_f( 128, 4, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_1_1_t.cu index 7f7dee62b6..1c876d4b9b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_1_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_1_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_128_4_1_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_128_4_1_1_t( 128, 4, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_2_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_2_1_f.cu index e6ebac9735..a2eda346bb 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_2_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_2_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_128_4_2_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. 
return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_128_4_2_1_f( 128, 4, 2, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_2_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_2_1_t.cu index 3c7f738e70..a34aa93914 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_2_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_2_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_128_4_2_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_128_4_2_1_t( 128, 4, 2, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_1_1_f.cu index e1196b7bef..50ebb4cf04 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_1_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_1_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_192_2_1_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_192_2_1_1_f( 192, 2, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_1_1_t.cu index f9c4993661..17609a427b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_1_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_1_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_192_2_1_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. 
return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_192_2_1_1_t( 192, 2, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_2_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_2_1_f.cu index aded1f49f4..3bf2cf2abd 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_2_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_2_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_192_2_2_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_192_2_2_1_f( 192, 2, 2, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_2_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_2_1_t.cu index eaacaa3cae..cdccbdace6 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_2_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_2_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_192_2_2_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_192_2_2_1_t( 192, 2, 2, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_4_2_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_4_2_1_f.cu index bfe528341a..a9f379d0a0 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_4_2_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_4_2_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_192_4_2_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. 
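// [Review note] Threading `output` through every wrapper lets a caller reuse a
// preallocated result buffer instead of paying an allocation per call. A
// minimal sketch, assuming a fixed [M, N] output shape across iterations
// (kSteps and the tensor names are hypothetical):
at::Tensor y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));
for (int step = 0; step < kSteps; ++step) {
  f4f4bf16(XQ, WQ, x_scale, w_scale, /*output=*/y, /*global_scale=*/std::nullopt);
}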
return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_192_4_2_1_f( 192, 4, 2, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_4_2_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_4_2_1_t.cu index 571f0d5fd3..d4974f7c53 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_4_2_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_4_2_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_192_4_2_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_192_4_2_1_t( 192, 4, 2, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_256_2_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_256_2_1_1_f.cu index 424573ce1f..dbb7b39d21 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_256_2_1_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_256_2_1_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_256_2_1_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_256_2_1_1_f( 256, 2, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_256_2_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_256_2_1_1_t.cu index 96a60337e6..558db85131 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_256_2_1_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_256_2_1_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_256_2_1_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. 
return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_128_256_2_1_1_t( 256, 2, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_1_1_f.cu index 3415c406d5..86bd90817c 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_1_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_1_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_128_2_1_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_128_2_1_1_f( 128, 2, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_1_1_t.cu index a0e005358c..e97c93b02a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_1_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_1_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_128_2_1_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_128_2_1_1_t( 128, 2, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_2_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_2_1_f.cu index d24e9a77fb..2d79e75c79 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_2_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_2_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_128_2_2_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. 
return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_128_2_2_1_f( 128, 2, 2, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_2_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_2_1_t.cu index ff9d4bc6f7..8fedbf289e 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_2_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_2_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_128_2_2_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_128_2_2_1_t( 128, 2, 2, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_4_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_4_1_f.cu index 3382e537ae..a4fd9245f2 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_4_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_4_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_128_2_4_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_128_2_4_1_f( 128, 2, 4, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_4_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_4_1_t.cu index 6d26d987f6..36395cb7af 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_4_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_4_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_128_2_4_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. 
return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_128_2_4_1_t( 128, 2, 4, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_4_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_4_1_1_f.cu index 6fa5fbfb57..c32746b996 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_4_1_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_4_1_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_128_4_1_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_128_4_1_1_f( 128, 4, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_4_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_4_1_1_t.cu index ee2d7773d5..9973f6861c 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_4_1_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_4_1_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_128_4_1_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_128_4_1_1_t( 128, 4, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_1_1_f.cu index 17b5139ff0..996884b364 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_1_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_1_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_192_2_1_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. 
return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_192_2_1_1_f( 192, 2, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_1_1_t.cu index 8825604164..e428244979 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_1_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_1_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_192_2_1_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_192_2_1_1_t( 192, 2, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_2_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_2_1_f.cu index c0156ec86f..51c7b8c3bc 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_2_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_2_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_192_2_2_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_192_2_2_1_f( 192, 2, 2, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_2_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_2_1_t.cu index 16668e5b56..401f079de9 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_2_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_2_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_192_2_2_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. 
return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_192_2_2_1_t( 192, 2, 2, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_4_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_4_1_f.cu index 6b6faad21f..cfdf4f5c7b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_4_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_4_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_192_2_4_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_192_2_4_1_f( 192, 2, 4, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_4_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_4_1_t.cu index 3559071604..37ac3376f3 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_4_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_4_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_192_2_4_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_192_2_4_1_t( 192, 2, 4, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_4_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_4_1_1_f.cu index e47e2860f9..9c6ecb2edc 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_4_1_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_4_1_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_192_4_1_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. 
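// [Review note] The remaining f4f4bf16_* wrapper files below change in
// lockstep with the ones above: each signature gains `at::Tensor output`,
// forwarded to _f4f4bf16 ahead of global_scale, with the tile/cluster template
// arguments untouched. All substantive changes live in f4f4bf16_common.cuh and
// the dispatcher in f4f4bf16.cu.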
return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_192_4_1_1_f( 192, 4, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_4_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_4_1_1_t.cu index 7e88d92850..4ddb1c9f34 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_4_1_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_4_1_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_192_4_1_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_192_4_1_1_t( 192, 4, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_1_1_f.cu index 1b0a5b10ed..c2d1c7b617 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_1_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_1_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_256_2_1_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_256_2_1_1_f( 256, 2, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_1_1_t.cu index ef28c41d58..1ca2c968c2 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_1_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_1_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_256_2_1_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. 
return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_256_2_1_1_t( 256, 2, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_2_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_2_1_f.cu index 950784da77..8bc8642972 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_2_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_2_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_256_2_2_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_256_2_2_1_f( 256, 2, 2, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_2_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_2_1_t.cu index da23ae2dea..ec9d010dce 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_2_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_2_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_256_2_2_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_256_2_2_1_t( 256, 2, 2, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_4_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_4_1_f.cu index 4bf7619d03..e013e79328 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_4_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_4_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_256_2_4_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. 
return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_256_2_4_1_f( 256, 2, 4, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_4_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_4_1_t.cu index dda05249b2..b3cf594325 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_4_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_4_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_256_2_4_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_256_2_4_1_t( 256, 2, 4, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_4_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_4_1_1_f.cu index e428d45554..6c5641195e 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_4_1_1_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_4_1_1_f.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_256_4_1_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_256_4_1_1_f( 256, 4, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_4_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_4_1_1_t.cu index e87039e674..ec4262664b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_4_1_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_4_1_1_t.cu @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_256_256_4_1_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale = std::nullopt) { // Dispatch this kernel to the correct underlying implementation. 
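// [Review note] In the f4f4bf16_common.cuh hunk below, _f4f4bf16 stops
// allocating its own Y and instead writes through the output tensor passed in:
// both epilogue pointers (source C and destination D) alias output.data_ptr().
// That is safe because the epilogue scalars {1, 0} read as alpha = 1,
// beta = 0, so C is never actually loaded; the kernel just stores the
// accumulator into the caller-provided buffer. Input validation
// (is_cuda/is_contiguous) moves up into f4f4bf16(), keeping the per-kernel
// wrappers thin.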
return _f4f4bf16< @@ -25,7 +26,7 @@ at::Tensor f4f4bf16_256_256_4_1_1_t( 256, 4, 1, - 1>(XQ, WQ, x_scale, w_scale, global_scale); + 1>(XQ, WQ, x_scale, w_scale, output, global_scale); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_common.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_common.cuh index 3ea04363ac..b40bdb30e5 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_common.cuh +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_common.cuh @@ -13,11 +13,11 @@ #include // clang-format off - // The fixed ordering of the headers is required for CUTLASS 3.2+ - #include - #include // @manual - #include // @manual - #include // @manual +// The fixed ordering of the headers is required for CUTLASS 3.2+ +#include +#include // @manual +#include // @manual +#include // @manual // clang-format on #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) @@ -34,19 +34,13 @@ at::Tensor _f4f4bf16( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale) { c10::cuda::CUDAGuard deviceGuard(XQ.device()); - int M = XQ.size(0); - int N = WQ.size(0); - int K = XQ.size(1) * 2; // Since K is packed - - TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); - TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); - TORCH_CHECK(x_scale.is_cuda() && x_scale.is_contiguous()); - TORCH_CHECK(w_scale.is_cuda() && w_scale.is_contiguous()); - - auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16)); + const int M = XQ.size(0); + const int N = WQ.size(0); + const int K = XQ.size(1) * 2; // Since K is packed constexpr int TileShapeK = 128 * 8 / cutlass::sizeof_bits::value; @@ -196,9 +190,9 @@ at::Tensor _f4f4bf16( layout_SFA}, {// Epilogue arguments {1, 0}, - reinterpret_cast(Y.data_ptr()), + reinterpret_cast(output.data_ptr()), stride_output, - reinterpret_cast(Y.data_ptr()), + reinterpret_cast(output.data_ptr()), stride_output}}; if constexpr (std::is_same_v< @@ -241,7 +235,7 @@ at::Tensor _f4f4bf16( } C10_CUDA_KERNEL_LAUNCH_CHECK(); - return Y; + return output; } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_manifest.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_manifest.cuh index 424c5ac22a..619e6d22da 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_manifest.cuh +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_manifest.cuh @@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_128_1_1_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale); at::Tensor f4f4bf16_128_128_1_1_1_t( @@ -24,6 +25,7 @@ at::Tensor f4f4bf16_128_128_1_1_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale); at::Tensor f4f4bf16_128_128_1_2_1_f( @@ -31,6 +33,7 @@ at::Tensor f4f4bf16_128_128_1_2_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale); at::Tensor f4f4bf16_128_128_1_2_1_t( @@ -38,6 +41,7 @@ at::Tensor f4f4bf16_128_128_1_2_1_t( at::Tensor WQ, // FP4 at::Tensor x_scale, at::Tensor w_scale, + at::Tensor output, std::optional global_scale); at::Tensor f4f4bf16_128_128_1_4_1_f( @@ -45,6 +49,7 @@ at::Tensor f4f4bf16_128_128_1_4_1_f( at::Tensor WQ, // FP4 at::Tensor x_scale, 
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_manifest.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_manifest.cuh
index 424c5ac22a..619e6d22da 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_manifest.cuh
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_manifest.cuh
@@ -17,6 +17,7 @@ at::Tensor f4f4bf16_128_128_1_1_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_128_1_1_1_t(
@@ -24,6 +25,7 @@ at::Tensor f4f4bf16_128_128_1_1_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_128_1_2_1_f(
@@ -31,6 +33,7 @@ at::Tensor f4f4bf16_128_128_1_2_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_128_1_2_1_t(
@@ -38,6 +41,7 @@ at::Tensor f4f4bf16_128_128_1_2_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_128_1_4_1_f(
@@ -45,6 +49,7 @@ at::Tensor f4f4bf16_128_128_1_4_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_128_1_4_1_t(
@@ -52,6 +57,7 @@ at::Tensor f4f4bf16_128_128_1_4_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_128_2_2_1_f(
@@ -59,6 +65,7 @@ at::Tensor f4f4bf16_128_128_2_2_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_128_2_2_1_t(
@@ -66,6 +73,7 @@ at::Tensor f4f4bf16_128_128_2_2_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_128_4_1_1_f(
@@ -73,6 +81,7 @@ at::Tensor f4f4bf16_128_128_4_1_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_128_4_1_1_t(
@@ -80,6 +89,7 @@ at::Tensor f4f4bf16_128_128_4_1_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_128_4_2_1_f(
@@ -87,6 +97,7 @@ at::Tensor f4f4bf16_128_128_4_2_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_128_4_2_1_t(
@@ -94,6 +105,7 @@ at::Tensor f4f4bf16_128_128_4_2_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_192_2_1_1_f(
@@ -101,6 +113,7 @@ at::Tensor f4f4bf16_128_192_2_1_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_192_2_1_1_t(
@@ -108,6 +121,7 @@ at::Tensor f4f4bf16_128_192_2_1_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_192_2_2_1_f(
@@ -115,6 +129,7 @@ at::Tensor f4f4bf16_128_192_2_2_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_192_2_2_1_t(
@@ -122,6 +137,7 @@ at::Tensor f4f4bf16_128_192_2_2_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_192_4_2_1_f(
@@ -129,6 +145,7 @@ at::Tensor f4f4bf16_128_192_4_2_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_192_4_2_1_t(
@@ -136,6 +153,7 @@ at::Tensor f4f4bf16_128_192_4_2_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_256_2_1_1_f(
@@ -143,6 +161,7 @@ at::Tensor f4f4bf16_128_256_2_1_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_128_256_2_1_1_t(
@@ -150,6 +169,7 @@ at::Tensor f4f4bf16_128_256_2_1_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_128_2_1_1_f(
@@ -157,6 +177,7 @@ at::Tensor f4f4bf16_256_128_2_1_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_128_2_1_1_t(
@@ -164,6 +185,7 @@ at::Tensor f4f4bf16_256_128_2_1_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_128_2_2_1_f(
@@ -171,6 +193,7 @@ at::Tensor f4f4bf16_256_128_2_2_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_128_2_2_1_t(
@@ -178,6 +201,7 @@ at::Tensor f4f4bf16_256_128_2_2_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_128_2_4_1_f(
@@ -185,6 +209,7 @@ at::Tensor f4f4bf16_256_128_2_4_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_128_2_4_1_t(
@@ -192,6 +217,7 @@ at::Tensor f4f4bf16_256_128_2_4_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_128_4_1_1_f(
@@ -199,6 +225,7 @@ at::Tensor f4f4bf16_256_128_4_1_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_128_4_1_1_t(
@@ -206,6 +233,7 @@ at::Tensor f4f4bf16_256_128_4_1_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_192_2_1_1_f(
@@ -213,6 +241,7 @@ at::Tensor f4f4bf16_256_192_2_1_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_192_2_1_1_t(
@@ -220,6 +249,7 @@ at::Tensor f4f4bf16_256_192_2_1_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_192_2_2_1_f(
@@ -227,6 +257,7 @@ at::Tensor f4f4bf16_256_192_2_2_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_192_2_2_1_t(
@@ -234,6 +265,7 @@ at::Tensor f4f4bf16_256_192_2_2_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_192_2_4_1_f(
@@ -241,6 +273,7 @@ at::Tensor f4f4bf16_256_192_2_4_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_192_2_4_1_t(
@@ -248,6 +281,7 @@ at::Tensor f4f4bf16_256_192_2_4_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_192_4_1_1_f(
@@ -255,6 +289,7 @@ at::Tensor f4f4bf16_256_192_4_1_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_192_4_1_1_t(
@@ -262,6 +297,7 @@ at::Tensor f4f4bf16_256_192_4_1_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_256_2_1_1_f(
@@ -269,6 +305,7 @@ at::Tensor f4f4bf16_256_256_2_1_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_256_2_1_1_t(
@@ -276,6 +313,7 @@ at::Tensor f4f4bf16_256_256_2_1_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_256_2_2_1_f(
@@ -283,6 +321,7 @@ at::Tensor f4f4bf16_256_256_2_2_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_256_2_2_1_t(
@@ -290,6 +329,7 @@ at::Tensor f4f4bf16_256_256_2_2_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_256_2_4_1_f(
@@ -297,6 +337,7 @@ at::Tensor f4f4bf16_256_256_2_4_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_256_2_4_1_t(
@@ -304,6 +345,7 @@ at::Tensor f4f4bf16_256_256_2_4_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_256_4_1_1_f(
@@ -311,6 +353,7 @@ at::Tensor f4f4bf16_256_256_4_1_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
 at::Tensor f4f4bf16_256_256_4_1_1_t(
@@ -318,7 +361,16 @@ at::Tensor f4f4bf16_256_256_4_1_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> global_scale);
 
+using Kernel_f4f4bf16 = at::Tensor (*)(
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    std::optional<at::Tensor>);
+
 #endif
 
 } // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/include/fbgemm_gpu/torch_ops.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/include/fbgemm_gpu/torch_ops.h
index 587d6ffb53..b16da537ac 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/include/fbgemm_gpu/torch_ops.h
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/include/fbgemm_gpu/torch_ops.h
@@ -48,6 +48,14 @@ at::Tensor f4f4bf16_grouped_mm(
     std::optional<at::Tensor> output = std::nullopt,
     std::optional<at::Tensor> global_scale = std::nullopt);
 
+at::Tensor f4f4bf16(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    std::optional<at::Tensor> output = std::nullopt,
+    std::optional<at::Tensor> global_scale = std::nullopt);
+
 #endif
 
 } // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
index 415ee37b1b..a8d78b859f 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -41,13 +41,6 @@ at::Tensor i8i8bf16_dynamic(
 at::Tensor silu_mul_quantize_i8(at::Tensor X1, at::Tensor X2, double scale);
 
 // Cutlass kernel
-at::Tensor f4f4bf16(
-    at::Tensor XQ,
-    at::Tensor WQ,
-    at::Tensor x_scale,
-    at::Tensor w_scale,
-    std::optional<at::Tensor> global_scale = std::nullopt,
-    bool use_mx = true);
 at::Tensor f4f4bf16_grouped_stacked(
     at::Tensor XQ,
     at::Tensor WQ,
@@ -412,8 +405,8 @@ at::Tensor f4f4bf16_meta(
     at::Tensor WQ, // FP4
     at::Tensor /* x_scale */,
     at::Tensor /* w_scale */,
-    std::optional<at::Tensor> /* global_scale = std::nullopt */,
-    bool /* use_mx */) {
+    std::optional<at::Tensor> /* output = std::nullopt */,
+    std::optional<at::Tensor> /* global_scale = std::nullopt */) {
   const at::SymInt M = XQ.sym_size(0);
   const at::SymInt N = WQ.sym_size(0);
   auto Y = at::empty_symint({M, N}, XQ.options().dtype(at::kBFloat16));
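The reworked `f4f4bf16_meta` keeps the new schema traceable: it reads only the symbolic `M` and `N` and returns an empty `[M, N]` bf16 tensor, ignoring both scales. A sketch of what that buys on the meta device (assumes an fbgemm_gpu build that registers the op; the import and shapes are illustrative):

```python
import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  # registers the fbgemm ops

# On the meta device only the shape function runs, so no kernel launches and
# the scale contents never matter; output is [M, N] bf16 regardless.
xq = torch.empty((128, 1024), dtype=torch.uint8, device="meta")  # M=128, K packed
wq = torch.empty((512, 1024), dtype=torch.uint8, device="meta")  # N=512
scale = torch.empty((1,), dtype=torch.uint8, device="meta")  # ignored by meta
y = torch.ops.fbgemm.f4f4bf16(xq, wq, scale, scale)
assert y.shape == (128, 512) and y.dtype == torch.bfloat16
```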
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp
index 7ba8a701b6..2f6b9900da 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp
@@ -19,7 +19,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   // torch.ops.load_library, similar to below for quantize_fp8_per_tensor
   m.def("i8i8bf16(Tensor XQ, Tensor WQ, float scale, int split_k=1) -> Tensor");
   m.def(
-      "f4f4bf16(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor? global_scale=None, bool use_mx=True) -> Tensor");
+      "f4f4bf16(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor? output=None, Tensor? global_scale=None) -> Tensor");
   m.def(
       "f4f4bf16_grouped_stacked(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor M_sizes, Tensor? global_scale=None, Tensor? starting_row_after_padding=None, bool use_mx=True) -> Tensor");
   m.def(
diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
index 711a1271cf..c530325f9a 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -2328,6 +2328,30 @@ def test_fake_quantize_nvfp4_per_tensor(
         y_ref = (x @ w.T).to(torch.bfloat16)
         torch.testing.assert_close(fake_quant_y, y_ref, atol=0.1, rtol=0.1)
 
+    @settings(deadline=None)
+    @given(
+        M=st.sampled_from([1, 250]),
+        N=st.sampled_from([256, 1024]),
+        K=st.sampled_from([2048, 3584]),
+    )
+    def test_gemm(self, M: int, N: int, K: int) -> None:
+        A = torch.randn((M, K), dtype=torch.bfloat16, device=self.device) * 0.1
+        B = torch.randn((N, K), dtype=torch.bfloat16, device=self.device) * 0.01
+
+        global_scales, a_global_scales, b_global_scales = get_nvfp4_global_scales_naive(
+            [A],
+            [B],
+        )
+        aqs, a_scales = quantize_nvfp4_naive([A], a_global_scales)
+        bqs, b_scales = quantize_nvfp4_naive([B], b_global_scales)
+
+        out_nvfp4 = torch.ops.fbgemm.f4f4bf16(
+            aqs[0], bqs[0], a_scales[0], b_scales[0], None, global_scales[0]
+        )
+        out_bf16 = A @ B.t()
+
+        torch.testing.assert_close(out_nvfp4, out_bf16, atol=5.0e-2, rtol=5.0e-2)
+
     @settings(deadline=None)
     @given(
         G=st.sampled_from([1, 4, 16]),
@@ -2437,6 +2461,24 @@ class MXFP4Tests(unittest.TestCase):
     def setUpClass(cls):
         cls.device = torch.accelerator.current_accelerator()
 
+    @settings(deadline=None)
+    @given(
+        M=st.sampled_from([1, 250]),
+        N=st.sampled_from([256, 1024]),
+        K=st.sampled_from([2048, 3584]),
+    )
+    def test_gemm(self, M: int, N: int, K: int) -> None:
+        A = torch.randn((M, K), dtype=torch.bfloat16, device=self.device) * 0.1
+        B = torch.randn((N, K), dtype=torch.bfloat16, device=self.device) * 0.01
+
+        aq, a_scale = triton_quantize_mx4_unpack(A)
+        bq, b_scale = triton_quantize_mx4_unpack(B)
+
+        out_mxfp4 = torch.ops.fbgemm.f4f4bf16(aq, bq, a_scale, b_scale)
+        out_bf16 = A @ B.t()
+
+        torch.testing.assert_close(out_mxfp4, out_bf16, atol=8.0e-2, rtol=8.0e-2)
+
     @given(
         G=st.sampled_from([1, 4, 16]),
         M=st.sampled_from([250, 500, 3500]),
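A migration note the new tests make visible: `use_mx` is gone and the fifth positional slot now holds `output`, which is why the NVFP4 test passes an explicit `None` ahead of `global_scale`. A hedged usage sketch of the updated op, reusing the quantizer the MXFP4 test calls (import path assumed from the quantizer's file location; tensor values are placeholders):

```python
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
    triton_quantize_mx4_unpack,
)

A = torch.randn((16, 256), dtype=torch.bfloat16, device="cuda")
B = torch.randn((32, 256), dtype=torch.bfloat16, device="cuda")
aq, a_scale = triton_quantize_mx4_unpack(A)
bq, b_scale = triton_quantize_mx4_unpack(B)

# MXFP4 path, defaults only: the op allocates the bf16 output itself.
y = torch.ops.fbgemm.f4f4bf16(aq, bq, a_scale, b_scale)

# New with this change: a preallocated [M, N] bf16 buffer in the fifth slot.
out = torch.empty((16, 32), dtype=torch.bfloat16, device="cuda")
torch.ops.fbgemm.f4f4bf16(aq, bq, a_scale, b_scale, output=out)

# Callers that used to pass global_scale positionally must now use the
# keyword (or insert an explicit None for output, as the NVFP4 test does):
# torch.ops.fbgemm.f4f4bf16(xq, wq, x_scale, w_scale, global_scale=gs)
```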