From fe841adfe5a6d272d02b6687eb0d7b85b77035c6 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Mon, 29 Sep 2025 16:37:01 -0400 Subject: [PATCH] Revert "[CPU] Add ops for float8 linear (#3052)" This reverts commit 5e90c477f6d49a17540bd7431728bfb2457610ee. --- test/test_ops.py | 93 --- .../csrc/cpu/aten_kernels/float8_linear.cpp | 607 ------------------ torchao/csrc/cpu/aten_kernels/utils.h | 111 ---- torchao/ops.py | 75 +-- 4 files changed, 1 insertion(+), 885 deletions(-) delete mode 100644 torchao/csrc/cpu/aten_kernels/float8_linear.cpp delete mode 100644 torchao/csrc/cpu/aten_kernels/utils.h diff --git a/test/test_ops.py b/test/test_ops.py index 11ad59db27..c55c2cdb81 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -40,14 +40,7 @@ except RuntimeError: pytest.skip("torchao.ops not available") -from torchao.quantization import PerGroup, PerRow, PerTensor -from torchao.quantization.quant_primitives import ( - _choose_scale_float8, - _dequantize_affine_float8, - _quantize_affine_float8, -) from torchao.quantization.utils import ( - get_block_size, get_groupwise_affine_qparams, groupwise_affine_dequantize_tensor_from_qparams, groupwise_affine_quantize_tensor_from_qparams, @@ -908,91 +901,5 @@ def _test_scaled_embedding_bag_cpu_helper( torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5) -@pytest.mark.skipif( - "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"), - reason="cpp kernels not built", -) -@pytest.mark.parametrize( - "multi_hot, batch_size, vector_size, index_type", - EMBEDINGBAG_TEST_PARAMS, - ids=str, -) -def test_scaled_embedding_bag_int8_cpu(multi_hot, batch_size, vector_size, index_type): - _test_scaled_embedding_bag_cpu_helper( - multi_hot, batch_size, vector_size, index_type, torch.int8 - ) - - -@pytest.mark.skipif( - "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"), - reason="cpp kernels not built", -) -@pytest.mark.parametrize( - "multi_hot, batch_size, vector_size, index_type", - EMBEDINGBAG_TEST_PARAMS, - ids=str, -) -def test_scaled_embedding_bag_fp8_cpu(multi_hot, batch_size, vector_size, index_type): - _test_scaled_embedding_bag_cpu_helper( - multi_hot, batch_size, vector_size, index_type, torch.float8_e4m3fn - ) - - -@pytest.mark.skipif( - "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_prepack_cpu") - or "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"), - reason="cpp kernels not built", -) -@pytest.mark.skipif( - not torch_version_at_least("2.6.0"), reason="Test only enabled for 2.6+" -) -@pytest.mark.parametrize("shape", [(64, 64), (256, 256)]) -@pytest.mark.parametrize("bs", [1, 160]) -@pytest.mark.parametrize("out_dtype", [torch.float, torch.bfloat16, torch.half]) -@pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("x_granularity", [PerTensor(), PerRow(), PerGroup(128)]) -@pytest.mark.parametrize("w_granularity", [PerTensor(), PerRow(), PerGroup(128)]) -def test_float8_linear_cpu(shape, bs, out_dtype, bias, x_granularity, w_granularity): - in_feature, out_feature = shape - if isinstance(x_granularity, PerGroup): - if x_granularity.group_size >= in_feature: - return - if not isinstance(w_granularity, PerGroup): - return - if isinstance(w_granularity, PerGroup): - if w_granularity.group_size >= in_feature: - return - m = torch.nn.Linear(in_feature, out_feature, bias=bias).eval() - b = m.bias - x = torch.randn(bs, in_feature) - x_block_size = get_block_size(x.shape, x_granularity) - x_scale = _choose_scale_float8( - x, - float8_dtype=torch.float8_e4m3fn, - block_size=x_block_size, - ) - x_fp8 = _quantize_affine_float8(x, x_scale, torch.float8_e4m3fn) - - w = m.weight.detach() - w_block_size = get_block_size(w.shape, w_granularity) - w_scale = _choose_scale_float8( - w, - float8_dtype=torch.float8_e4m3fn, - block_size=w_block_size, - ) - w_fp8 = _quantize_affine_float8(w, w_scale, torch.float8_e4m3fn) - - x_dq = _dequantize_affine_float8(x_fp8, x_scale) - w_dq = _dequantize_affine_float8(w_fp8, w_scale) - ref = torch.nn.functional.linear(x_dq, w_dq, b).to(out_dtype) - - packed_w, packed_scale = torch.ops.torchao.float8_linear_prepack_cpu(w_fp8, w_scale) - y = torch.ops.torchao.float8_linear_cpu( - x_fp8, x_scale, packed_w, packed_scale, b, out_dtype - ) - - torch.testing.assert_close(y, ref, atol=1e-2, rtol=1e-2) - - if __name__ == "__main__": pytest.main(sys.argv) diff --git a/torchao/csrc/cpu/aten_kernels/float8_linear.cpp b/torchao/csrc/cpu/aten_kernels/float8_linear.cpp deleted file mode 100644 index e9b290771d..0000000000 --- a/torchao/csrc/cpu/aten_kernels/float8_linear.cpp +++ /dev/null @@ -1,607 +0,0 @@ -#include -#include -#include -#include -#include "utils.h" -#ifndef AT_PER_OPERATOR_HEADERS -#include -#else -#include -#endif - -namespace torchao { - -namespace { - -#define BLOCK_N 32 - -#define PER_TENSOR 1 -#define PER_ROW 2 -#define PER_GROUP 3 - -static bool cpublas_checked = false; -static bool cpublas_can_pack = false; - -bool cpublas_could_pack() { - // the could_pack check requires AMX support implicitly - if (cpublas_checked) { - return cpublas_can_pack; - } -#ifdef CPUBLAS_BRGEMM_F8F8F32 - cpublas_can_pack = at::native::cpublas::could_pack(at::kFloat8_e4m3fn); -#else - cpublas_can_pack = at::native::cpublas::could_pack(at::kBFloat16); -#endif - cpublas_checked = true; - return cpublas_can_pack; -} - -/* -return: packed_weight, packed_scales -*/ -std::tuple -float8_linear_prepack_impl( - const at::Tensor& weight, - const at::Tensor& scales) { - // weight shape = [N, K] - // scales shape = [N, G] - TORCH_CHECK(weight.dim() == 2, - "Float8 linear CPU: Weight should be a 2D tensor for packing"); - int N = weight.size(0); - int K = weight.size(1); - constexpr int block_n = BLOCK_N; - // Case to fall back - if (N % block_n != 0 || K % 32 != 0) { - return std::make_tuple(weight, scales); - } - - auto new_scales = scales; - bool is_per_tensor = new_scales.numel() == 1; - if (new_scales.dim() == 1 && !is_per_tensor) { - new_scales.unsqueeze_(1); - } - new_scales = new_scales.to(at::kFloat); - int G = is_per_tensor ? 1 : new_scales.size(1); - TORCH_CHECK(K % G == 0, "K should be divisible by num_groups"); - int group_size = K / G; - int block_k = group_size > 128 ? 128 : group_size; - while (K % block_k != 0) { - block_k /= 2; - } - TORCH_CHECK(block_k > 0 && block_k <= group_size, - "Float8 linear CPU: Invalid block_k size, should be in (0, group_size]"); - int Nc = N / block_n; - int Kc = K / block_k; - - // Reorder weight to [N/block_n, K/block_k, block_k, block_n] - // Reorder scales to [N/block_n, G, block_n] - auto weight_view = weight.view({Nc, block_n, Kc, block_k}); - at::Tensor weight_reordered = weight_view.permute({0, 2, 3, 1}).contiguous(); - at::Tensor blocked_weight; - at::Tensor blocked_scales = is_per_tensor ? new_scales.view({1}) : new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous(); - -#if defined(CPU_CAPABILITY_AVX512) - if (cpublas_could_pack()) { -#ifdef CPUBLAS_BRGEMM_F8F8F32 - constexpr int vnni_size = get_vnni_size(); // for fp8 -#else - constexpr int vnni_size = get_vnni_size(); // for bfloat16 -#endif - blocked_weight = at::empty({Nc, Kc, block_k, block_n}, weight.options()); - auto weight_ptr = reinterpret_cast(weight_reordered.data_ptr()); - auto blocked_weight_ptr = reinterpret_cast(blocked_weight.data_ptr()); - int64_t num_blocks = Nc * Kc; - at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) { - for (const auto i : c10::irange(begin, end)) { - auto in_ptr = weight_ptr + i * block_k * block_n; - auto out_ptr = blocked_weight_ptr + i * block_k * block_n; - - // Reorder weight block to VNNI - // plain shape = [block_k, block_n] - // packed shape = [block_k / VNNI_SIZE, block_n, VNNI_SIZE] viewed as [block_k, block_n] - constexpr int n_group_size = 8; - constexpr int n_group = block_n / n_group_size; // 4 - for (int nb = 0; nb < n_group; ++nb) { - for (int k = 0; k < block_k; k += vnni_size) { - for (int ni = 0; ni < n_group_size; ++ni) { - for (int ki = 0; ki < vnni_size; ++ki) { - int src_idx = nb * n_group_size + ni + (k + ki) * block_n; - int dst_idx = (nb * n_group_size + ni) * vnni_size + k * block_n + ki; - *(out_ptr + dst_idx) = *(in_ptr + src_idx); - } - } - } - } - } - }); - } else -#endif - { - blocked_weight = weight_reordered; - } - - return std::make_tuple(std::move(blocked_weight), std::move(blocked_scales)); -} - -#if defined(CPU_CAPABILITY_AVX512) -// this doesn't handle NaN. -inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) { - const __m512i x = _mm512_cvtepu8_epi16(fp8_vec); - - const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4); - const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3); - const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7); - const __m512i nonsign = _mm512_or_si512(exp, mant); - - const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8); - const __m512i combined = _mm512_or_si512(nonsign, sign); - - const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512()); - return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined); -} - -static void cvt_f8e4m3_to_bf16( - const at::Float8_e4m3fn* __restrict__ in, - at::BFloat16* out, - int64_t rows, - int64_t cols, - int64_t stride) { - if (stride == cols) { - // A contiguous buffer - size_t len = rows * cols; - size_t i = 0; - for (; i < len; i += 32) { - __m256i fp8_vec = _mm256_loadu_si256((__m256i*)&in[i]); - __m512bh bf16_vec = cvt_e4m3_bf16_intrinsic_no_nan(fp8_vec); - _mm512_storeu_si512((__m512i*)(out + i), (__m512i)bf16_vec); - } - for (; i < len; ++i) { - out[i] = (at::BFloat16)in[i]; - } - } else { - // Non-contiguous. Access each row with stride - TORCH_CHECK(stride > cols); - for (int r = 0; r < rows; ++r) { - size_t i = 0; - size_t vec_len = cols / 32 * 32; - for (; i < vec_len; i += 32) { - __m256i fp8_vec = _mm256_loadu_si256((__m256i*)&in[r * stride + i]); - __m512bh bf16_vec = cvt_e4m3_bf16_intrinsic_no_nan(fp8_vec); - _mm512_storeu_si512((__m512i*)(out + r * cols + i), (__m512i)bf16_vec); - } - for (; i < cols; ++i) { - out[r * cols + i] = (at::BFloat16)in[r * stride + i]; - } - } - } -} - - -// accumulate and store result to buffer -// if act/wei are per_group quantized, apply scales -template -static void _accumulate_result( - float* __restrict__ output, - const float* __restrict__ input, - const float* __restrict__ scale_a, - const float* __restrict__ scale_b, - int M, - int ldi, - int ldo, - int ldsa = 1) { - float a_scale, b_scale; - __m512 va_scale; - __m512 vb_scale; - for (int m = 0; m < M; ++m) { - if constexpr (act_quant_mode == PER_GROUP) { - a_scale = *(scale_a + m * ldsa); - va_scale = _mm512_set1_ps(a_scale); - } - constexpr int N_UNROLL = N / 16; - c10::ForcedUnroll{}([&](auto i) { - constexpr int n = i * 16; - __m512 vc_f = _mm512_loadu_ps(input + m * ldi + n); - if constexpr (act_quant_mode == PER_GROUP) { - vc_f = _mm512_mul_ps(vc_f, va_scale); - } - if constexpr (wei_quant_mode == PER_GROUP) { - vb_scale = _mm512_loadu_ps(scale_b + n); - vc_f = _mm512_mul_ps(vc_f, vb_scale); - } - if constexpr (accum) { - __m512 vo = _mm512_loadu_ps(output + m * ldo + n); - _mm512_storeu_ps(output + m * ldo + n, _mm512_add_ps(vo, vc_f)); - } else { - _mm512_storeu_ps(output + m * ldo + n, vc_f); - } - }); - constexpr int tail_start = N / 16 * 16; - for (int n = tail_start; n < N; ++n) { - float dq_val = input[m * ldi + n]; - if constexpr (act_quant_mode == PER_GROUP) { - dq_val = dq_val * a_scale; - } - if constexpr (wei_quant_mode == PER_GROUP) { - b_scale = scale_b[n]; - dq_val = dq_val * b_scale; - } - if constexpr (accum) { - output[m * ldo + n] += dq_val; - } else { - output[m * ldo + n] = dq_val; - } - } - } -} - -// Store result to output buffer with dtype conversion -// If act/wei are per_row or per_tensor quantized, apply scales -// If bias is not null, add bias -template -inline void store_out( - const float* y_buf, - out_dtype* c_ptr, - int64_t M, - int64_t lda, - const float* scales_a, - const float* scales_b, - const float* bias) { - float a_scale = 1.0, b_scale = 1.0; - __m512 va_scale, vb_scale; - if constexpr (act_quant_mode == PER_TENSOR) { - a_scale = *scales_a; - } - if constexpr (wei_quant_mode == PER_TENSOR) { - b_scale = *scales_b; - vb_scale = _mm512_set1_ps(b_scale); - } - for (int i = 0; i < M; ++i) { - if constexpr (act_quant_mode == PER_ROW) { - a_scale = *(scales_a + i); - } - if constexpr (act_quant_mode != PER_GROUP) { - va_scale = _mm512_set1_ps(a_scale); - } - constexpr int N_UNROLL = N / 16; - c10::ForcedUnroll{}([&](auto idx) { - constexpr int j = idx * 16; - __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); - __m512 bias_vec = bias ? _mm512_loadu_ps(bias + j) : _mm512_setzero_ps(); - if constexpr (act_quant_mode != PER_GROUP) { - y_vec = _mm512_mul_ps(y_vec, va_scale); - } - if constexpr (wei_quant_mode == PER_ROW) { - vb_scale = _mm512_loadu_ps(scales_b + j); - } - if constexpr (wei_quant_mode != PER_GROUP) { - y_vec = _mm512_mul_ps(y_vec, vb_scale); - } - y_vec = _mm512_add_ps(y_vec, bias_vec); - if constexpr (std::is_same::value) { - _mm512_storeu_ps(c_ptr + i * lda + j, y_vec); - } else if constexpr (std::is_same::value) { - __m256i y_bf16_vec = at::vec::cvtfp32_bf16(y_vec); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_bf16_vec); - } else if constexpr (std::is_same::value) { - __m256i y_fp16_vec = at::vec::cvtfp32_fp16(y_vec); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_fp16_vec); - } else { - TORCH_CHECK(false, "Unsupported output dtype"); - } - }); - constexpr int tail_start = N / 16 * 16; - for (int j = tail_start; j < N; ++j) { - if constexpr (wei_quant_mode == PER_ROW) { - b_scale = scales_b[j]; - } - c_ptr[i * lda + j] = static_cast(y_buf[i * N + j] * a_scale * b_scale); - } - } // for M -} - -#else // no AVX512 - -static void cvt_f8e4m3_to_bf16( - const at::Float8_e4m3fn* __restrict__ in, - at::BFloat16* out, - int64_t rows, - int64_t cols, - int64_t stride) { - for (int r = 0; r < rows; ++r) { - for (int c = 0; c < cols; ++c) { - out[r * cols + c] = (at::BFloat16)in[r * stride + c]; - } - } -} - -// Store result to output buffer with dtype conversion -// If act/wei are per_row or per_tensor quantized, apply scales -// If bias is not null, add bias -template -inline void store_out( - const float* y_buf, - out_dtype* c_ptr, - int64_t M, - int64_t lda, - const float* scales_a, - const float* scales_b, - const float* bias) { - float a_scale = 1.0, b_scale = 1.0; - if constexpr (act_quant_mode == PER_TENSOR) { - a_scale = *scales_a; - } - if constexpr (wei_quant_mode == PER_TENSOR) { - b_scale = *scales_b; - } - for (int i = 0; i < M; ++i) { - if constexpr (act_quant_mode == PER_ROW) { - a_scale = *(scales_a + i); - } - for (int j = 0; j < N; ++j) { - if constexpr (wei_quant_mode == PER_ROW) { - b_scale = scales_b[j]; - } - c_ptr[i * lda + j] = static_cast(y_buf[i * N + j] * a_scale * b_scale); - } - } // for M -} - -#endif // CPU_CAPABILITY_AVX512 - -template -void _micro_gemm( - float* C, - const at::Float8_e4m3fn* A, - const float* scales_a, - const at::Float8_e4m3fn* B, - const float* scales_b, - int64_t M, - int64_t K, - int64_t lda, - int64_t ldc, - int64_t ldsa, - float* ukernel_buf, - at::BFloat16* dqA_buf, - at::BFloat16* dqB_buf) { - // If FP8 brgemm is not available, convert A/B to bf16 for computation - // Compute GEMM fp8 * fp8 -> fp32 (or bf16 * bf16 -> fp32) - // If per_group quant, apply scales. Otherwise, don't apply scales here - // Finally accumulate and store results -#if defined(CPU_CAPABILITY_AVX512) - if constexpr (cpublas_can_pack) { -#ifdef CPUBLAS_BRGEMM_F8F8F32 - at::native::cpublas::brgemm( - M, - N, - K, - lda /*lda*/, - N /*ldb*/, - N /*ldc*/, - false /* add_C */, - A, - B, - ukernel_buf, - true /* is_vnni */); -#else - cvt_f8e4m3_to_bf16(A, dqA_buf, M, K, lda); - cvt_f8e4m3_to_bf16(B, dqB_buf, K, N, N); - at::native::cpublas::brgemm( - M, - N, - K, - K /*lda*/, - N /*ldb*/, - N /*ldc*/, - false /* add_C */, - dqA_buf, - dqB_buf, - ukernel_buf, - true /* is_vnni */); -#endif - _mm_prefetch(B + N * (K + 128), _MM_HINT_T0); - _mm_prefetch(A + K + 128, _MM_HINT_T0); - _accumulate_result( - C, - ukernel_buf, - scales_a, - scales_b, - M, - N /*ldi*/, - ldc, - ldsa); - } else -#endif - { - for (int64_t i = 0; i < M; ++i) { - for (int64_t j = 0; j < N; ++j) { - float sum = 0; - for (int64_t k = 0; k < K; ++k) { - sum += ((float)A[i * lda + k] * (float)B[k * N + j]); - } - if constexpr (act_quant_mode == PER_GROUP) { - sum *= scales_a[i * ldsa]; - } - if constexpr (wei_quant_mode == PER_GROUP) { - sum *= scales_b[j]; - } - C[i * ldc + j] += sum; - } - } - } -} - -template -void _float8_linear_impl( - const at::Tensor& input, - const at::Tensor& input_scales, - const at::Tensor& weight, - const at::Tensor& weight_scales, - const std::optional& bias, - at::Tensor& output) { - // input shape = [..., K] - // input is per token quantized - int64_t K = input.size(-1); - auto input_view = input.view({-1, K}); - int64_t M = input_view.size(0); - - // weight shape = [Nc, Kc, block_k, block_n] - // scales shape = [Nc, G, block_n] - int64_t Nc = weight.size(0); - int64_t Kc = weight.size(1); - int64_t block_k = weight.size(2); - constexpr int64_t block_n = BLOCK_N; - TORCH_CHECK(weight.size(3) == block_n, "Float8 linear: unexpected weight shape"); - int64_t N = Nc * block_n; - TORCH_CHECK(K == Kc * block_k, "Float8 linear: weight and input shapes mismatch"); - auto [parallel_on_M, block_m, Mc, Mc_parallel] = get_m_blocking(M); - int64_t num_parallel_blocks = Mc_parallel * Nc; - - // scales shape = [Nc, G, block_n] - int64_t num_groups = wei_quant_mode == PER_TENSOR ? 1 : weight_scales.size(1); - TORCH_CHECK(K % num_groups == 0, "K should be divisible by num_groups"); - int64_t group_size = K / num_groups; - TORCH_CHECK(group_size % block_k == 0, - "Float8 linear: group_size should be divisible by block_k"); - int64_t block_per_group = group_size / block_k; - TORCH_CHECK(!(act_quant_mode == PER_GROUP && wei_quant_mode != PER_GROUP), - "Float8 linear: if activation is per_group quantized, weight must be per_group quantized too"); - TORCH_CHECK(input_scales.numel() == 1 || input_scales.numel() == M || input_scales.numel() == M * num_groups, - "Float8 linear: unexpected input scales shape, scale shape:", input_scales.sizes(), ", M:", M, ", num_groups:", num_groups); - auto ldsa = act_quant_mode == PER_TENSOR ? 0 : act_quant_mode == PER_ROW ? 1 : num_groups; - - const at::Float8_e4m3fn* a_ptr = input_view.data_ptr(); - const float* a_scales_ptr = input_scales.data_ptr(); - const at::Float8_e4m3fn* b_ptr = weight.data_ptr(); - const float* b_scales_ptr = weight_scales.data_ptr(); - out_dtype* c_ptr = output.data_ptr(); - const float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr; - - int64_t block_size = block_m * block_n; - int64_t num_thread = at::get_num_threads(); - at::Tensor y_buffer = at::empty({num_thread, block_size}, output.options().dtype(at::kFloat)); - // Create buffer for brgemm output and dqA/dqB (optional) -#if defined(CPU_CAPABILITY_AVX512) - // buffer for brgemm output in float32 - int64_t buffer_size = block_size * 2; // float32 = bfloat16 * 2 -#ifndef CPUBLAS_BRGEMM_F8F8F32 - // buffers for dqA & dqB in bf16 - buffer_size += (block_k * block_n + block_m * block_k); -#endif - at::Tensor micro_gemm_buffer = at::empty({num_thread, buffer_size}, output.options().dtype(at::kBFloat16)); -#endif - - at::parallel_for(0, num_parallel_blocks, 1, [&](int64_t begin, int64_t end) { - // Get the address of pre-allocated buffers - int tid = at::get_thread_num(); - float* y_buf = y_buffer.data_ptr() + tid * block_size; - at::BFloat16 *dqA_buffer = nullptr, *dqB_buffer = nullptr; - float* ukernel_buf = nullptr; -#if defined(CPU_CAPABILITY_AVX512) - at::BFloat16* micro_gemm_buf = micro_gemm_buffer.data_ptr() + tid * buffer_size; - ukernel_buf = reinterpret_cast(micro_gemm_buf); -#ifndef CPUBLAS_BRGEMM_F8F8F32 - dqA_buffer = micro_gemm_buf; - dqB_buffer = micro_gemm_buf + block_m * block_k; - ukernel_buf = reinterpret_cast(micro_gemm_buf + block_m * block_k + block_k * block_n); -#endif -#endif - int64_t mc = 0, nc = 0; - at::native::data_index_init(begin, mc, Mc_parallel, nc, Nc); - for (const auto i : c10::irange(begin, end)) { - (void)i; // Suppress unused variable - int64_t mc_end = parallel_on_M ? mc + 1 : Mc; - - for (int mci = mc; mci < mc_end; ++mci) { - int64_t m_size = mci * block_m + block_m > M ? M - mci * block_m : block_m; - zero_buffer(y_buf, m_size * block_n); - for (int kci = 0; kci < Kc; ++kci) { - auto scales_a = a_scales_ptr + mci * block_m * num_groups + kci / block_per_group; - auto scales_b = b_scales_ptr + nc * block_n * num_groups + kci / block_per_group * block_n; - _micro_gemm( - /* C */ y_buf, - /* A */ a_ptr + mci * block_m * K + kci * block_k, - /* scales_a */ scales_a, - /* B */ b_ptr + (nc * Kc + kci) * block_n * block_k, - /* scales_b */ scales_b, - /* M */ m_size, - /* K */ block_k, - /* lda */ K, - /* ldc */ block_n, - /* ldsa */ ldsa, - /* ukernel_buf */ ukernel_buf, - /* dqA_buf */ dqA_buffer, - /* dqB_buf */ dqB_buffer); - } - // store y_buf to output with dtype conversion - auto scales_a = act_quant_mode == PER_TENSOR ? a_scales_ptr : - act_quant_mode == PER_ROW ? a_scales_ptr + mci * block_m : nullptr; - auto scales_b = wei_quant_mode == PER_TENSOR ? b_scales_ptr : - wei_quant_mode == PER_ROW ? b_scales_ptr + nc * block_n : nullptr; - auto bias_data = bias_ptr ? bias_ptr + nc * block_n : nullptr; - store_out( - y_buf, - c_ptr + mci * block_m * N + nc * block_n, - m_size, - N /*lda*/, - scales_a, - scales_b, - bias_data); - } - at::native::data_index_step(mc, Mc_parallel, nc, Nc); - } - if constexpr (cpublas_can_pack) { - at::native::cpublas::brgemm_release(); - } - }); -} - -at::Tensor float8_linear_impl( - const at::Tensor& input, - const at::Tensor& input_scales, - const at::Tensor& weight, - const at::Tensor& weight_scales, - const std::optional& bias, - at::ScalarType output_dtype) { - int64_t N = weight.dim() == 4 ? weight.size(0) * weight.size(-1) : weight.size(0); - int act_quant_mode = input_scales.numel() == 1 ? PER_TENSOR : - input_scales.numel() == input.numel() / input.size(-1) ? PER_ROW : - PER_GROUP; - int wei_quant_mode = weight_scales.numel() == 1 ? PER_TENSOR : - weight_scales.numel() == N ? PER_ROW : - PER_GROUP; - // Case to fall back - if (weight.dim() == 2) { - TORCH_CHECK(act_quant_mode != PER_GROUP && wei_quant_mode != PER_GROUP, - "FP8 linear: Per-group quantization is not supported in the fallback path"); - auto y_fp32 = at::linear( - input.to(at::kFloat).mul_(input_scales), - weight.to(at::kFloat).mul_(weight_scales), - bias); - return y_fp32.to(output_dtype); - } - - static bool cpublas_can_pack = cpublas_could_pack(); - auto out_sizes = input.sizes().vec(); - out_sizes.back() = N; - auto output = at::empty(out_sizes, input.options().dtype(output_dtype)); - - AT_DISPATCH_LINEAR_KERNEL(output_dtype, cpublas_can_pack, act_quant_mode, wei_quant_mode, [&](){ - _float8_linear_impl( - input, - input_scales, - weight, - weight_scales, - bias, - output); - }); - return output; -} - -} // anonymous namespace - -TORCH_LIBRARY_IMPL(torchao, CPU, m) { - m.impl("torchao::float8_linear_prepack_cpu", &float8_linear_prepack_impl); - m.impl("torchao::float8_linear_cpu", &float8_linear_impl); -} - -} // namespace torchao diff --git a/torchao/csrc/cpu/aten_kernels/utils.h b/torchao/csrc/cpu/aten_kernels/utils.h deleted file mode 100644 index 0a238a316e..0000000000 --- a/torchao/csrc/cpu/aten_kernels/utils.h +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD 3-Clause license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once - -#include -#include -#include - -int64_t get_m_block(int64_t M) { - if (M <= 48) { - return M; - } else if (M < 64) { - return 32; - } else if (M < 96) { - return 64; - } else { - return 128; - } -} - -std::tuple -get_m_blocking(int64_t M) { - bool parallel_on_M = M > 128; - int64_t block_m = get_m_block(M); - int64_t Mc = (M + block_m - 1) / block_m; - int64_t Mc_parallel = parallel_on_M ? Mc : 1; - return std::make_tuple(parallel_on_M, block_m, Mc, Mc_parallel); -} - -#if defined(CPU_CAPABILITY_AVX512) -template -void zero_buffer(T* data, int64_t size) { - const int32_t vec_size = at::vec::Vectorized::size(); - auto zero_vec = at::vec::Vectorized(0); - int64_t d = 0; - for (; d < size - (size % vec_size); d += vec_size) { - zero_vec.store(data + d); - } - if (d < size) { - zero_vec.store(data + d, size - d); - } -} -#else -template -void zero_buffer(T* data, int64_t size) { - memset(data, 0, sizeof(T) * size); -} -#endif - -template struct vnni_traits; -template <> struct vnni_traits { static constexpr int size = 2; }; -template <> struct vnni_traits { static constexpr int size = 2; }; -template <> struct vnni_traits { static constexpr int size = 4; }; -template <> struct vnni_traits { static constexpr int size = 4; }; - -template constexpr int get_vnni_size() { return vnni_traits::size; } - - -// Utilities for dispatch -#define AT_DISPATCH_OUT_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, \ - NAME, \ - AT_PRIVATE_CASE_TYPE_USING_HINT( \ - at::ScalarType::Float, out_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE_USING_HINT( \ - at::ScalarType::BFloat16, out_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE_USING_HINT( \ - at::ScalarType::Half, out_t, __VA_ARGS__)) - -#define AT_DISPATCH_CASE_ENUM(VALUE, TYPE, HINT, ...) \ - case VALUE: { \ - constexpr TYPE HINT = VALUE; \ - __VA_ARGS__; \ - break; \ - } - -#define AT_DISPATCH_BOOL(VALUE, NAME, HINT, ...) \ - [&]() { \ - switch (VALUE) { \ - AT_DISPATCH_CASE_ENUM(true, bool, HINT, __VA_ARGS__) \ - AT_DISPATCH_CASE_ENUM(false, bool, HINT, __VA_ARGS__) \ - } \ - }() - -#define AT_DISPATCH_QUANT_MODE(MODE, NAME, HINT, ...) \ - [&]() { \ - switch (MODE) { \ - AT_DISPATCH_CASE_ENUM(PER_TENSOR, int, HINT, __VA_ARGS__) \ - AT_DISPATCH_CASE_ENUM(PER_ROW, int, HINT, __VA_ARGS__) \ - AT_DISPATCH_CASE_ENUM(PER_GROUP, int, HINT, __VA_ARGS__) \ - } \ - }() - -#define AT_DISPATCH_LINEAR_KERNEL(OUT_DTYPE, CAN_PACK, A_QUANT_MODE, B_QUANT_MODE, ...) \ - AT_DISPATCH_BOOL( \ - CAN_PACK, "cpublas_can_pack", can_pack, \ - AT_DISPATCH_QUANT_MODE( \ - A_QUANT_MODE, "act_quant_mode", a_quant_mode, \ - AT_DISPATCH_QUANT_MODE( \ - B_QUANT_MODE, "wei_quant_mode", b_quant_mode, \ - AT_DISPATCH_OUT_TYPES( \ - OUT_DTYPE, "out_dtype", __VA_ARGS__ \ - ) \ - ) \ - ) \ - ) diff --git a/torchao/ops.py b/torchao/ops.py index f4191d60b5..b6348f90a5 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -71,12 +71,6 @@ lib.define( "_scaled_embedding_bag(Tensor qweight, Tensor indices, Tensor offsets, Tensor weight_scale, float o_scale, int mode, bool include_last_offset) -> Tensor" ) -lib.define( - "float8_linear_prepack_cpu(Tensor weight, Tensor scales) -> (Tensor, Tensor)" -) -lib.define( - "float8_linear_cpu(Tensor input, Tensor input_scales, Tensor weight, Tensor weight_scales, Tensor? bias, ScalarType output_dtype) -> Tensor" -) def register_custom_op(name): @@ -1122,71 +1116,4 @@ def _( # Only support include_last_offset == True assert include_last_offset == True batch_size = offsets.shape[0] - 1 - # Only support out_dtype == torch.float32 - # Next setp: support more out_dtype - out_dtype = torch.float32 - return qweight.new_empty(batch_size, qweight.shape[1], dtype=out_dtype) - - -def float8_linear_prepack_cpu( - weight: Tensor, - scales: Tensor, -) -> Tensor: - """ - Prepack weights for float8 linear operator on CPU. - Args: - weight: weight tensor. - scales: scales for weight tensor. - Returns: - packed weight, packed scales - """ - return torch.ops.torchao.float8_linear_prepack_cpu.default(weight, scales) - - -@register_custom_op("torchao::float8_linear_prepack_cpu") -def _(weight: Tensor, scales: Tensor) -> Tensor: - return weight, scales - - -def float8_linear_cpu( - input: Tensor, - input_scales: Tensor, - weight: Tensor, - weight_scales: Tensor, - bias: Optional[Tensor], - out_dtype: torch.dtype, -): - """ - float8 linear operator on CPU. - Args: - input: input tensor. - input_scales: scales for input tensor. - weight: weight tensor. - weight_scales: scales for weight tensor. - bias: optional bias tensor. - out_dtype: output data type. - Returns: - output tensor in out_dtype. - """ - return torch.ops.torchao.float8_linear_cpu.default( - input, - input_scales, - weight, - weight_scales, - bias, - out_dtype, - ) - - -@register_custom_op("torchao::float8_linear_cpu") -def _( - input: Tensor, - input_scales: Tensor, - weight: Tensor, - weight_scales: Tensor, - bias: Optional[Tensor], - out_dtype: torch.dtype, -) -> Tensor: - assert weight.dim() in (2, 4) - N = weight.size(0) * weight.size(3) if weight.dim() == 4 else weight.size(0) - return input.new_empty(*input.shape[:-1], N, dtype=out_dtype) + return qweight.new_empty(batch_size, qweight.shape[1], dtype=qweight.dtype)