diff --git a/test/test_ops.py b/test/test_ops.py index c55c2cdb81..11ad59db27 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -40,7 +40,14 @@ except RuntimeError: pytest.skip("torchao.ops not available") +from torchao.quantization import PerGroup, PerRow, PerTensor +from torchao.quantization.quant_primitives import ( + _choose_scale_float8, + _dequantize_affine_float8, + _quantize_affine_float8, +) from torchao.quantization.utils import ( + get_block_size, get_groupwise_affine_qparams, groupwise_affine_dequantize_tensor_from_qparams, groupwise_affine_quantize_tensor_from_qparams, @@ -901,5 +908,91 @@ def _test_scaled_embedding_bag_cpu_helper( torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5) +@pytest.mark.skipif( + "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"), + reason="cpp kernels not built", +) +@pytest.mark.parametrize( + "multi_hot, batch_size, vector_size, index_type", + EMBEDINGBAG_TEST_PARAMS, + ids=str, +) +def test_scaled_embedding_bag_int8_cpu(multi_hot, batch_size, vector_size, index_type): + _test_scaled_embedding_bag_cpu_helper( + multi_hot, batch_size, vector_size, index_type, torch.int8 + ) + + +@pytest.mark.skipif( + "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"), + reason="cpp kernels not built", +) +@pytest.mark.parametrize( + "multi_hot, batch_size, vector_size, index_type", + EMBEDINGBAG_TEST_PARAMS, + ids=str, +) +def test_scaled_embedding_bag_fp8_cpu(multi_hot, batch_size, vector_size, index_type): + _test_scaled_embedding_bag_cpu_helper( + multi_hot, batch_size, vector_size, index_type, torch.float8_e4m3fn + ) + + +@pytest.mark.skipif( + "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_prepack_cpu") + or "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"), + reason="cpp kernels not built", +) +@pytest.mark.skipif( + not torch_version_at_least("2.6.0"), reason="Test only enabled for 2.6+" +) +@pytest.mark.parametrize("shape", [(64, 64), (256, 256)]) +@pytest.mark.parametrize("bs", [1, 160]) +@pytest.mark.parametrize("out_dtype", [torch.float, torch.bfloat16, torch.half]) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("x_granularity", [PerTensor(), PerRow(), PerGroup(128)]) +@pytest.mark.parametrize("w_granularity", [PerTensor(), PerRow(), PerGroup(128)]) +def test_float8_linear_cpu(shape, bs, out_dtype, bias, x_granularity, w_granularity): + in_feature, out_feature = shape + if isinstance(x_granularity, PerGroup): + if x_granularity.group_size >= in_feature: + return + if not isinstance(w_granularity, PerGroup): + return + if isinstance(w_granularity, PerGroup): + if w_granularity.group_size >= in_feature: + return + m = torch.nn.Linear(in_feature, out_feature, bias=bias).eval() + b = m.bias + x = torch.randn(bs, in_feature) + x_block_size = get_block_size(x.shape, x_granularity) + x_scale = _choose_scale_float8( + x, + float8_dtype=torch.float8_e4m3fn, + block_size=x_block_size, + ) + x_fp8 = _quantize_affine_float8(x, x_scale, torch.float8_e4m3fn) + + w = m.weight.detach() + w_block_size = get_block_size(w.shape, w_granularity) + w_scale = _choose_scale_float8( + w, + float8_dtype=torch.float8_e4m3fn, + block_size=w_block_size, + ) + w_fp8 = _quantize_affine_float8(w, w_scale, torch.float8_e4m3fn) + + x_dq = _dequantize_affine_float8(x_fp8, x_scale) + w_dq = _dequantize_affine_float8(w_fp8, w_scale) + ref = torch.nn.functional.linear(x_dq, w_dq, b).to(out_dtype) + + packed_w, packed_scale = 
torch.ops.torchao.float8_linear_prepack_cpu(w_fp8, w_scale) + y = torch.ops.torchao.float8_linear_cpu( + x_fp8, x_scale, packed_w, packed_scale, b, out_dtype + ) + + torch.testing.assert_close(y, ref, atol=1e-2, rtol=1e-2) + + if __name__ == "__main__": pytest.main(sys.argv) diff --git a/torchao/csrc/cpu/aten_kernels/float8_linear.cpp b/torchao/csrc/cpu/aten_kernels/float8_linear.cpp new file mode 100644 index 0000000000..e9b290771d --- /dev/null +++ b/torchao/csrc/cpu/aten_kernels/float8_linear.cpp @@ -0,0 +1,607 @@ +#include +#include +#include +#include +#include "utils.h" +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace torchao { + +namespace { + +#define BLOCK_N 32 + +#define PER_TENSOR 1 +#define PER_ROW 2 +#define PER_GROUP 3 + +static bool cpublas_checked = false; +static bool cpublas_can_pack = false; + +bool cpublas_could_pack() { + // the could_pack check requires AMX support implicitly + if (cpublas_checked) { + return cpublas_can_pack; + } +#ifdef CPUBLAS_BRGEMM_F8F8F32 + cpublas_can_pack = at::native::cpublas::could_pack(at::kFloat8_e4m3fn); +#else + cpublas_can_pack = at::native::cpublas::could_pack(at::kBFloat16); +#endif + cpublas_checked = true; + return cpublas_can_pack; +} + +/* +return: packed_weight, packed_scales +*/ +std::tuple +float8_linear_prepack_impl( + const at::Tensor& weight, + const at::Tensor& scales) { + // weight shape = [N, K] + // scales shape = [N, G] + TORCH_CHECK(weight.dim() == 2, + "Float8 linear CPU: Weight should be a 2D tensor for packing"); + int N = weight.size(0); + int K = weight.size(1); + constexpr int block_n = BLOCK_N; + // Case to fall back + if (N % block_n != 0 || K % 32 != 0) { + return std::make_tuple(weight, scales); + } + + auto new_scales = scales; + bool is_per_tensor = new_scales.numel() == 1; + if (new_scales.dim() == 1 && !is_per_tensor) { + new_scales.unsqueeze_(1); + } + new_scales = new_scales.to(at::kFloat); + int G = is_per_tensor ? 1 : new_scales.size(1); + TORCH_CHECK(K % G == 0, "K should be divisible by num_groups"); + int group_size = K / G; + int block_k = group_size > 128 ? 128 : group_size; + while (K % block_k != 0) { + block_k /= 2; + } + TORCH_CHECK(block_k > 0 && block_k <= group_size, + "Float8 linear CPU: Invalid block_k size, should be in (0, group_size]"); + int Nc = N / block_n; + int Kc = K / block_k; + + // Reorder weight to [N/block_n, K/block_k, block_k, block_n] + // Reorder scales to [N/block_n, G, block_n] + auto weight_view = weight.view({Nc, block_n, Kc, block_k}); + at::Tensor weight_reordered = weight_view.permute({0, 2, 3, 1}).contiguous(); + at::Tensor blocked_weight; + at::Tensor blocked_scales = is_per_tensor ? 
new_scales.view({1}) : new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous(); + +#if defined(CPU_CAPABILITY_AVX512) + if (cpublas_could_pack()) { +#ifdef CPUBLAS_BRGEMM_F8F8F32 + constexpr int vnni_size = get_vnni_size(); // for fp8 +#else + constexpr int vnni_size = get_vnni_size(); // for bfloat16 +#endif + blocked_weight = at::empty({Nc, Kc, block_k, block_n}, weight.options()); + auto weight_ptr = reinterpret_cast(weight_reordered.data_ptr()); + auto blocked_weight_ptr = reinterpret_cast(blocked_weight.data_ptr()); + int64_t num_blocks = Nc * Kc; + at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) { + for (const auto i : c10::irange(begin, end)) { + auto in_ptr = weight_ptr + i * block_k * block_n; + auto out_ptr = blocked_weight_ptr + i * block_k * block_n; + + // Reorder weight block to VNNI + // plain shape = [block_k, block_n] + // packed shape = [block_k / VNNI_SIZE, block_n, VNNI_SIZE] viewed as [block_k, block_n] + constexpr int n_group_size = 8; + constexpr int n_group = block_n / n_group_size; // 4 + for (int nb = 0; nb < n_group; ++nb) { + for (int k = 0; k < block_k; k += vnni_size) { + for (int ni = 0; ni < n_group_size; ++ni) { + for (int ki = 0; ki < vnni_size; ++ki) { + int src_idx = nb * n_group_size + ni + (k + ki) * block_n; + int dst_idx = (nb * n_group_size + ni) * vnni_size + k * block_n + ki; + *(out_ptr + dst_idx) = *(in_ptr + src_idx); + } + } + } + } + } + }); + } else +#endif + { + blocked_weight = weight_reordered; + } + + return std::make_tuple(std::move(blocked_weight), std::move(blocked_scales)); +} + +#if defined(CPU_CAPABILITY_AVX512) +// this doesn't handle NaN. +inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) { + const __m512i x = _mm512_cvtepu8_epi16(fp8_vec); + + const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4); + const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3); + const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7); + const __m512i nonsign = _mm512_or_si512(exp, mant); + + const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8); + const __m512i combined = _mm512_or_si512(nonsign, sign); + + const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512()); + return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined); +} + +static void cvt_f8e4m3_to_bf16( + const at::Float8_e4m3fn* __restrict__ in, + at::BFloat16* out, + int64_t rows, + int64_t cols, + int64_t stride) { + if (stride == cols) { + // A contiguous buffer + size_t len = rows * cols; + size_t i = 0; + for (; i < len; i += 32) { + __m256i fp8_vec = _mm256_loadu_si256((__m256i*)&in[i]); + __m512bh bf16_vec = cvt_e4m3_bf16_intrinsic_no_nan(fp8_vec); + _mm512_storeu_si512((__m512i*)(out + i), (__m512i)bf16_vec); + } + for (; i < len; ++i) { + out[i] = (at::BFloat16)in[i]; + } + } else { + // Non-contiguous. 
Access each row with stride + TORCH_CHECK(stride > cols); + for (int r = 0; r < rows; ++r) { + size_t i = 0; + size_t vec_len = cols / 32 * 32; + for (; i < vec_len; i += 32) { + __m256i fp8_vec = _mm256_loadu_si256((__m256i*)&in[r * stride + i]); + __m512bh bf16_vec = cvt_e4m3_bf16_intrinsic_no_nan(fp8_vec); + _mm512_storeu_si512((__m512i*)(out + r * cols + i), (__m512i)bf16_vec); + } + for (; i < cols; ++i) { + out[r * cols + i] = (at::BFloat16)in[r * stride + i]; + } + } + } +} + + +// accumulate and store result to buffer +// if act/wei are per_group quantized, apply scales +template +static void _accumulate_result( + float* __restrict__ output, + const float* __restrict__ input, + const float* __restrict__ scale_a, + const float* __restrict__ scale_b, + int M, + int ldi, + int ldo, + int ldsa = 1) { + float a_scale, b_scale; + __m512 va_scale; + __m512 vb_scale; + for (int m = 0; m < M; ++m) { + if constexpr (act_quant_mode == PER_GROUP) { + a_scale = *(scale_a + m * ldsa); + va_scale = _mm512_set1_ps(a_scale); + } + constexpr int N_UNROLL = N / 16; + c10::ForcedUnroll{}([&](auto i) { + constexpr int n = i * 16; + __m512 vc_f = _mm512_loadu_ps(input + m * ldi + n); + if constexpr (act_quant_mode == PER_GROUP) { + vc_f = _mm512_mul_ps(vc_f, va_scale); + } + if constexpr (wei_quant_mode == PER_GROUP) { + vb_scale = _mm512_loadu_ps(scale_b + n); + vc_f = _mm512_mul_ps(vc_f, vb_scale); + } + if constexpr (accum) { + __m512 vo = _mm512_loadu_ps(output + m * ldo + n); + _mm512_storeu_ps(output + m * ldo + n, _mm512_add_ps(vo, vc_f)); + } else { + _mm512_storeu_ps(output + m * ldo + n, vc_f); + } + }); + constexpr int tail_start = N / 16 * 16; + for (int n = tail_start; n < N; ++n) { + float dq_val = input[m * ldi + n]; + if constexpr (act_quant_mode == PER_GROUP) { + dq_val = dq_val * a_scale; + } + if constexpr (wei_quant_mode == PER_GROUP) { + b_scale = scale_b[n]; + dq_val = dq_val * b_scale; + } + if constexpr (accum) { + output[m * ldo + n] += dq_val; + } else { + output[m * ldo + n] = dq_val; + } + } + } +} + +// Store result to output buffer with dtype conversion +// If act/wei are per_row or per_tensor quantized, apply scales +// If bias is not null, add bias +template +inline void store_out( + const float* y_buf, + out_dtype* c_ptr, + int64_t M, + int64_t lda, + const float* scales_a, + const float* scales_b, + const float* bias) { + float a_scale = 1.0, b_scale = 1.0; + __m512 va_scale, vb_scale; + if constexpr (act_quant_mode == PER_TENSOR) { + a_scale = *scales_a; + } + if constexpr (wei_quant_mode == PER_TENSOR) { + b_scale = *scales_b; + vb_scale = _mm512_set1_ps(b_scale); + } + for (int i = 0; i < M; ++i) { + if constexpr (act_quant_mode == PER_ROW) { + a_scale = *(scales_a + i); + } + if constexpr (act_quant_mode != PER_GROUP) { + va_scale = _mm512_set1_ps(a_scale); + } + constexpr int N_UNROLL = N / 16; + c10::ForcedUnroll{}([&](auto idx) { + constexpr int j = idx * 16; + __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); + __m512 bias_vec = bias ? 
_mm512_loadu_ps(bias + j) : _mm512_setzero_ps(); + if constexpr (act_quant_mode != PER_GROUP) { + y_vec = _mm512_mul_ps(y_vec, va_scale); + } + if constexpr (wei_quant_mode == PER_ROW) { + vb_scale = _mm512_loadu_ps(scales_b + j); + } + if constexpr (wei_quant_mode != PER_GROUP) { + y_vec = _mm512_mul_ps(y_vec, vb_scale); + } + y_vec = _mm512_add_ps(y_vec, bias_vec); + if constexpr (std::is_same::value) { + _mm512_storeu_ps(c_ptr + i * lda + j, y_vec); + } else if constexpr (std::is_same::value) { + __m256i y_bf16_vec = at::vec::cvtfp32_bf16(y_vec); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_bf16_vec); + } else if constexpr (std::is_same::value) { + __m256i y_fp16_vec = at::vec::cvtfp32_fp16(y_vec); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_fp16_vec); + } else { + TORCH_CHECK(false, "Unsupported output dtype"); + } + }); + constexpr int tail_start = N / 16 * 16; + for (int j = tail_start; j < N; ++j) { + if constexpr (wei_quant_mode == PER_ROW) { + b_scale = scales_b[j]; + } + c_ptr[i * lda + j] = static_cast(y_buf[i * N + j] * a_scale * b_scale); + } + } // for M +} + +#else // no AVX512 + +static void cvt_f8e4m3_to_bf16( + const at::Float8_e4m3fn* __restrict__ in, + at::BFloat16* out, + int64_t rows, + int64_t cols, + int64_t stride) { + for (int r = 0; r < rows; ++r) { + for (int c = 0; c < cols; ++c) { + out[r * cols + c] = (at::BFloat16)in[r * stride + c]; + } + } +} + +// Store result to output buffer with dtype conversion +// If act/wei are per_row or per_tensor quantized, apply scales +// If bias is not null, add bias +template +inline void store_out( + const float* y_buf, + out_dtype* c_ptr, + int64_t M, + int64_t lda, + const float* scales_a, + const float* scales_b, + const float* bias) { + float a_scale = 1.0, b_scale = 1.0; + if constexpr (act_quant_mode == PER_TENSOR) { + a_scale = *scales_a; + } + if constexpr (wei_quant_mode == PER_TENSOR) { + b_scale = *scales_b; + } + for (int i = 0; i < M; ++i) { + if constexpr (act_quant_mode == PER_ROW) { + a_scale = *(scales_a + i); + } + for (int j = 0; j < N; ++j) { + if constexpr (wei_quant_mode == PER_ROW) { + b_scale = scales_b[j]; + } + c_ptr[i * lda + j] = static_cast(y_buf[i * N + j] * a_scale * b_scale); + } + } // for M +} + +#endif // CPU_CAPABILITY_AVX512 + +template +void _micro_gemm( + float* C, + const at::Float8_e4m3fn* A, + const float* scales_a, + const at::Float8_e4m3fn* B, + const float* scales_b, + int64_t M, + int64_t K, + int64_t lda, + int64_t ldc, + int64_t ldsa, + float* ukernel_buf, + at::BFloat16* dqA_buf, + at::BFloat16* dqB_buf) { + // If FP8 brgemm is not available, convert A/B to bf16 for computation + // Compute GEMM fp8 * fp8 -> fp32 (or bf16 * bf16 -> fp32) + // If per_group quant, apply scales. 
Otherwise, don't apply scales here + // Finally accumulate and store results +#if defined(CPU_CAPABILITY_AVX512) + if constexpr (cpublas_can_pack) { +#ifdef CPUBLAS_BRGEMM_F8F8F32 + at::native::cpublas::brgemm( + M, + N, + K, + lda /*lda*/, + N /*ldb*/, + N /*ldc*/, + false /* add_C */, + A, + B, + ukernel_buf, + true /* is_vnni */); +#else + cvt_f8e4m3_to_bf16(A, dqA_buf, M, K, lda); + cvt_f8e4m3_to_bf16(B, dqB_buf, K, N, N); + at::native::cpublas::brgemm( + M, + N, + K, + K /*lda*/, + N /*ldb*/, + N /*ldc*/, + false /* add_C */, + dqA_buf, + dqB_buf, + ukernel_buf, + true /* is_vnni */); +#endif + _mm_prefetch(B + N * (K + 128), _MM_HINT_T0); + _mm_prefetch(A + K + 128, _MM_HINT_T0); + _accumulate_result( + C, + ukernel_buf, + scales_a, + scales_b, + M, + N /*ldi*/, + ldc, + ldsa); + } else +#endif + { + for (int64_t i = 0; i < M; ++i) { + for (int64_t j = 0; j < N; ++j) { + float sum = 0; + for (int64_t k = 0; k < K; ++k) { + sum += ((float)A[i * lda + k] * (float)B[k * N + j]); + } + if constexpr (act_quant_mode == PER_GROUP) { + sum *= scales_a[i * ldsa]; + } + if constexpr (wei_quant_mode == PER_GROUP) { + sum *= scales_b[j]; + } + C[i * ldc + j] += sum; + } + } + } +} + +template +void _float8_linear_impl( + const at::Tensor& input, + const at::Tensor& input_scales, + const at::Tensor& weight, + const at::Tensor& weight_scales, + const std::optional& bias, + at::Tensor& output) { + // input shape = [..., K] + // input is per token quantized + int64_t K = input.size(-1); + auto input_view = input.view({-1, K}); + int64_t M = input_view.size(0); + + // weight shape = [Nc, Kc, block_k, block_n] + // scales shape = [Nc, G, block_n] + int64_t Nc = weight.size(0); + int64_t Kc = weight.size(1); + int64_t block_k = weight.size(2); + constexpr int64_t block_n = BLOCK_N; + TORCH_CHECK(weight.size(3) == block_n, "Float8 linear: unexpected weight shape"); + int64_t N = Nc * block_n; + TORCH_CHECK(K == Kc * block_k, "Float8 linear: weight and input shapes mismatch"); + auto [parallel_on_M, block_m, Mc, Mc_parallel] = get_m_blocking(M); + int64_t num_parallel_blocks = Mc_parallel * Nc; + + // scales shape = [Nc, G, block_n] + int64_t num_groups = wei_quant_mode == PER_TENSOR ? 1 : weight_scales.size(1); + TORCH_CHECK(K % num_groups == 0, "K should be divisible by num_groups"); + int64_t group_size = K / num_groups; + TORCH_CHECK(group_size % block_k == 0, + "Float8 linear: group_size should be divisible by block_k"); + int64_t block_per_group = group_size / block_k; + TORCH_CHECK(!(act_quant_mode == PER_GROUP && wei_quant_mode != PER_GROUP), + "Float8 linear: if activation is per_group quantized, weight must be per_group quantized too"); + TORCH_CHECK(input_scales.numel() == 1 || input_scales.numel() == M || input_scales.numel() == M * num_groups, + "Float8 linear: unexpected input scales shape, scale shape:", input_scales.sizes(), ", M:", M, ", num_groups:", num_groups); + auto ldsa = act_quant_mode == PER_TENSOR ? 0 : act_quant_mode == PER_ROW ? 1 : num_groups; + + const at::Float8_e4m3fn* a_ptr = input_view.data_ptr(); + const float* a_scales_ptr = input_scales.data_ptr(); + const at::Float8_e4m3fn* b_ptr = weight.data_ptr(); + const float* b_scales_ptr = weight_scales.data_ptr(); + out_dtype* c_ptr = output.data_ptr(); + const float* bias_ptr = bias.has_value() ? 
bias.value().data_ptr() : nullptr; + + int64_t block_size = block_m * block_n; + int64_t num_thread = at::get_num_threads(); + at::Tensor y_buffer = at::empty({num_thread, block_size}, output.options().dtype(at::kFloat)); + // Create buffer for brgemm output and dqA/dqB (optional) +#if defined(CPU_CAPABILITY_AVX512) + // buffer for brgemm output in float32 + int64_t buffer_size = block_size * 2; // float32 = bfloat16 * 2 +#ifndef CPUBLAS_BRGEMM_F8F8F32 + // buffers for dqA & dqB in bf16 + buffer_size += (block_k * block_n + block_m * block_k); +#endif + at::Tensor micro_gemm_buffer = at::empty({num_thread, buffer_size}, output.options().dtype(at::kBFloat16)); +#endif + + at::parallel_for(0, num_parallel_blocks, 1, [&](int64_t begin, int64_t end) { + // Get the address of pre-allocated buffers + int tid = at::get_thread_num(); + float* y_buf = y_buffer.data_ptr() + tid * block_size; + at::BFloat16 *dqA_buffer = nullptr, *dqB_buffer = nullptr; + float* ukernel_buf = nullptr; +#if defined(CPU_CAPABILITY_AVX512) + at::BFloat16* micro_gemm_buf = micro_gemm_buffer.data_ptr() + tid * buffer_size; + ukernel_buf = reinterpret_cast(micro_gemm_buf); +#ifndef CPUBLAS_BRGEMM_F8F8F32 + dqA_buffer = micro_gemm_buf; + dqB_buffer = micro_gemm_buf + block_m * block_k; + ukernel_buf = reinterpret_cast(micro_gemm_buf + block_m * block_k + block_k * block_n); +#endif +#endif + int64_t mc = 0, nc = 0; + at::native::data_index_init(begin, mc, Mc_parallel, nc, Nc); + for (const auto i : c10::irange(begin, end)) { + (void)i; // Suppress unused variable + int64_t mc_end = parallel_on_M ? mc + 1 : Mc; + + for (int mci = mc; mci < mc_end; ++mci) { + int64_t m_size = mci * block_m + block_m > M ? M - mci * block_m : block_m; + zero_buffer(y_buf, m_size * block_n); + for (int kci = 0; kci < Kc; ++kci) { + auto scales_a = a_scales_ptr + mci * block_m * num_groups + kci / block_per_group; + auto scales_b = b_scales_ptr + nc * block_n * num_groups + kci / block_per_group * block_n; + _micro_gemm( + /* C */ y_buf, + /* A */ a_ptr + mci * block_m * K + kci * block_k, + /* scales_a */ scales_a, + /* B */ b_ptr + (nc * Kc + kci) * block_n * block_k, + /* scales_b */ scales_b, + /* M */ m_size, + /* K */ block_k, + /* lda */ K, + /* ldc */ block_n, + /* ldsa */ ldsa, + /* ukernel_buf */ ukernel_buf, + /* dqA_buf */ dqA_buffer, + /* dqB_buf */ dqB_buffer); + } + // store y_buf to output with dtype conversion + auto scales_a = act_quant_mode == PER_TENSOR ? a_scales_ptr : + act_quant_mode == PER_ROW ? a_scales_ptr + mci * block_m : nullptr; + auto scales_b = wei_quant_mode == PER_TENSOR ? b_scales_ptr : + wei_quant_mode == PER_ROW ? b_scales_ptr + nc * block_n : nullptr; + auto bias_data = bias_ptr ? bias_ptr + nc * block_n : nullptr; + store_out( + y_buf, + c_ptr + mci * block_m * N + nc * block_n, + m_size, + N /*lda*/, + scales_a, + scales_b, + bias_data); + } + at::native::data_index_step(mc, Mc_parallel, nc, Nc); + } + if constexpr (cpublas_can_pack) { + at::native::cpublas::brgemm_release(); + } + }); +} + +at::Tensor float8_linear_impl( + const at::Tensor& input, + const at::Tensor& input_scales, + const at::Tensor& weight, + const at::Tensor& weight_scales, + const std::optional& bias, + at::ScalarType output_dtype) { + int64_t N = weight.dim() == 4 ? weight.size(0) * weight.size(-1) : weight.size(0); + int act_quant_mode = input_scales.numel() == 1 ? PER_TENSOR : + input_scales.numel() == input.numel() / input.size(-1) ? PER_ROW : + PER_GROUP; + int wei_quant_mode = weight_scales.numel() == 1 ? 
+      PER_TENSOR :
+      weight_scales.numel() == N ? PER_ROW :
+      PER_GROUP;
+  // Case to fall back
+  if (weight.dim() == 2) {
+    TORCH_CHECK(act_quant_mode != PER_GROUP && wei_quant_mode != PER_GROUP,
+        "FP8 linear: Per-group quantization is not supported in the fallback path");
+    auto y_fp32 = at::linear(
+        input.to(at::kFloat).mul_(input_scales),
+        weight.to(at::kFloat).mul_(weight_scales),
+        bias);
+    return y_fp32.to(output_dtype);
+  }
+
+  static bool cpublas_can_pack = cpublas_could_pack();
+  auto out_sizes = input.sizes().vec();
+  out_sizes.back() = N;
+  auto output = at::empty(out_sizes, input.options().dtype(output_dtype));
+
+  AT_DISPATCH_LINEAR_KERNEL(output_dtype, cpublas_can_pack, act_quant_mode, wei_quant_mode, [&](){
+    _float8_linear_impl<out_t, can_pack, a_quant_mode, b_quant_mode>(
+        input,
+        input_scales,
+        weight,
+        weight_scales,
+        bias,
+        output);
+  });
+  return output;
+}
+
+} // anonymous namespace
+
+TORCH_LIBRARY_IMPL(torchao, CPU, m) {
+  m.impl("torchao::float8_linear_prepack_cpu", &float8_linear_prepack_impl);
+  m.impl("torchao::float8_linear_cpu", &float8_linear_impl);
+}
+
+} // namespace torchao
diff --git a/torchao/csrc/cpu/aten_kernels/utils.h b/torchao/csrc/cpu/aten_kernels/utils.h
new file mode 100644
index 0000000000..16a54f2abb
--- /dev/null
+++ b/torchao/csrc/cpu/aten_kernels/utils.h
@@ -0,0 +1,114 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD 3-Clause license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include
+#include
+#include
+
+int64_t get_m_block(int64_t M) {
+  if (M <= 48) {
+    return M;
+  } else if (M < 64) {
+    return 32;
+  } else if (M < 96) {
+    return 64;
+  } else {
+    return 128;
+  }
+}
+
+std::tuple<bool, int64_t, int64_t, int64_t>
+get_m_blocking(int64_t M) {
+  bool parallel_on_M = M > 128;
+  int64_t block_m = get_m_block(M);
+  int64_t Mc = (M + block_m - 1) / block_m;
+  int64_t Mc_parallel = parallel_on_M ? Mc : 1;
+  return std::make_tuple(parallel_on_M, block_m, Mc, Mc_parallel);
+}
+
+#if defined(CPU_CAPABILITY_AVX512)
+template <typename T>
+void zero_buffer(T* data, int64_t size) {
+  const int32_t vec_size = at::vec::Vectorized<T>::size();
+  auto zero_vec = at::vec::Vectorized<T>(0);
+  int64_t d = 0;
+  for (; d < size - (size % vec_size); d += vec_size) {
+    zero_vec.store(data + d);
+  }
+  if (d < size) {
+    zero_vec.store(data + d, size - d);
+  }
+}
+#else
+template <typename T>
+void zero_buffer(T* data, int64_t size) {
+  memset(data, 0, sizeof(T) * size);
+}
+#endif
+
+template <typename T> struct vnni_traits;
+template <> struct vnni_traits<at::BFloat16> { static constexpr int size = 2; };
+template <> struct vnni_traits<at::Half> { static constexpr int size = 2; };
+template <> struct vnni_traits<at::Float8_e4m3fn> { static constexpr int size = 4; };
+template <> struct vnni_traits<uint8_t> { static constexpr int size = 4; };
+
+template <typename T> constexpr int get_vnni_size() { return vnni_traits<T>::size; }
+
+
+// Utilities for dispatch
+#define AT_DISPATCH_OUT_TYPES(TYPE, NAME, ...)            \
+  AT_DISPATCH_SWITCH(                                     \
+      TYPE,                                               \
+      NAME,                                               \
+      AT_PRIVATE_CASE_TYPE_USING_HINT(                    \
+          at::ScalarType::Float, out_t, __VA_ARGS__)      \
+      AT_PRIVATE_CASE_TYPE_USING_HINT(                    \
+          at::ScalarType::BFloat16, out_t, __VA_ARGS__)   \
+      AT_PRIVATE_CASE_TYPE_USING_HINT(                    \
+          at::ScalarType::Half, out_t, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_ENUM(VALUE, TYPE, HINT, ...)     \
+  case VALUE: {                                           \
+    constexpr TYPE HINT = VALUE;                          \
+    __VA_ARGS__;                                          \
+    break;                                                \
+  }
+
+#define AT_DISPATCH_BOOL(VALUE, NAME, HINT, ...)          \
+  [&]() {                                                 \
+    if (VALUE) {                                          \
+      constexpr bool HINT = true;                         \
+      __VA_ARGS__;                                        \
+    } else {                                              \
+      constexpr bool HINT = false;                        \
+      __VA_ARGS__;                                        \
+    }                                                     \
+  }()
+
+#define AT_DISPATCH_QUANT_MODE(MODE, NAME, HINT, ...)     \
+  [&]() {                                                 \
+    switch (MODE) {                                       \
+      AT_DISPATCH_CASE_ENUM(PER_TENSOR, int, HINT, __VA_ARGS__) \
+      AT_DISPATCH_CASE_ENUM(PER_ROW, int, HINT, __VA_ARGS__)    \
+      AT_DISPATCH_CASE_ENUM(PER_GROUP, int, HINT, __VA_ARGS__)  \
+    }                                                     \
+  }()
+
+#define AT_DISPATCH_LINEAR_KERNEL(OUT_DTYPE, CAN_PACK, A_QUANT_MODE, B_QUANT_MODE, ...) \
+  AT_DISPATCH_BOOL(                                       \
+      CAN_PACK, "cpublas_can_pack", can_pack,             \
+      AT_DISPATCH_QUANT_MODE(                             \
+          A_QUANT_MODE, "act_quant_mode", a_quant_mode,   \
+          AT_DISPATCH_QUANT_MODE(                         \
+              B_QUANT_MODE, "wei_quant_mode", b_quant_mode, \
+              AT_DISPATCH_OUT_TYPES(                      \
+                  OUT_DTYPE, "out_dtype", __VA_ARGS__     \
+              )                                           \
+          )                                               \
+      )                                                   \
+  )
diff --git a/torchao/ops.py b/torchao/ops.py
index b6348f90a5..f4191d60b5 100644
--- a/torchao/ops.py
+++ b/torchao/ops.py
@@ -71,6 +71,12 @@ lib.define(
     "_scaled_embedding_bag(Tensor qweight, Tensor indices, Tensor offsets, Tensor weight_scale, float o_scale, int mode, bool include_last_offset) -> Tensor"
 )
+lib.define(
+    "float8_linear_prepack_cpu(Tensor weight, Tensor scales) -> (Tensor, Tensor)"
+)
+lib.define(
+    "float8_linear_cpu(Tensor input, Tensor input_scales, Tensor weight, Tensor weight_scales, Tensor? bias, ScalarType output_dtype) -> Tensor"
+)
 
 
 def register_custom_op(name):
@@ -1116,4 +1122,71 @@ def _(
     # Only support include_last_offset == True
     assert include_last_offset == True
     batch_size = offsets.shape[0] - 1
-    return qweight.new_empty(batch_size, qweight.shape[1], dtype=qweight.dtype)
+    # Only support out_dtype == torch.float32
+    # Next step: support more out_dtype
+    out_dtype = torch.float32
+    return qweight.new_empty(batch_size, qweight.shape[1], dtype=out_dtype)
+
+
+def float8_linear_prepack_cpu(
+    weight: Tensor,
+    scales: Tensor,
+) -> tuple[Tensor, Tensor]:
+    """
+    Prepack weights for the float8 linear operator on CPU.
+    Args:
+        weight: weight tensor.
+        scales: scales for weight tensor.
+    Returns:
+        the packed weight and the packed scales.
+    """
+    return torch.ops.torchao.float8_linear_prepack_cpu.default(weight, scales)
+
+
+@register_custom_op("torchao::float8_linear_prepack_cpu")
+def _(weight: Tensor, scales: Tensor) -> tuple[Tensor, Tensor]:
+    return weight, scales
+
+
+def float8_linear_cpu(
+    input: Tensor,
+    input_scales: Tensor,
+    weight: Tensor,
+    weight_scales: Tensor,
+    bias: Optional[Tensor],
+    out_dtype: torch.dtype,
+) -> Tensor:
+    """
+    Float8 linear operator on CPU.
+    Args:
+        input: input tensor.
+        input_scales: scales for input tensor.
+        weight: weight tensor.
+        weight_scales: scales for weight tensor.
+        bias: optional bias tensor.
+        out_dtype: output data type.
+    Returns:
+        output tensor in out_dtype.
+    """
+    return torch.ops.torchao.float8_linear_cpu.default(
+        input,
+        input_scales,
+        weight,
+        weight_scales,
+        bias,
+        out_dtype,
+    )
+
+
+@register_custom_op("torchao::float8_linear_cpu")
+def _(
+    input: Tensor,
+    input_scales: Tensor,
+    weight: Tensor,
+    weight_scales: Tensor,
+    bias: Optional[Tensor],
+    out_dtype: torch.dtype,
+) -> Tensor:
+    assert weight.dim() in (2, 4)
+    N = weight.size(0) * weight.size(3) if weight.dim() == 4 else weight.size(0)
+    return input.new_empty(*input.shape[:-1], N, dtype=out_dtype)
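For reference (not part of the diff): a minimal usage sketch of the two new ops, mirroring test_float8_linear_cpu above. The per-row granularity, the shapes, and the bfloat16 output dtype are illustrative choices, not requirements of the kernel.

# Sketch: quantize activation and weight to float8_e4m3fn, prepack the weight once,
# then call the fused CPU op. Mirrors the flow in test_float8_linear_cpu.
import torch
from torchao.quantization import PerRow
from torchao.quantization.quant_primitives import (
    _choose_scale_float8,
    _quantize_affine_float8,
)
from torchao.quantization.utils import get_block_size

M, K, N = 16, 64, 64  # illustrative shapes
x = torch.randn(M, K)
w = torch.randn(N, K)

# Per-row float8 quantization of activation and weight.
x_scale = _choose_scale_float8(
    x, float8_dtype=torch.float8_e4m3fn, block_size=get_block_size(x.shape, PerRow())
)
x_fp8 = _quantize_affine_float8(x, x_scale, torch.float8_e4m3fn)
w_scale = _choose_scale_float8(
    w, float8_dtype=torch.float8_e4m3fn, block_size=get_block_size(w.shape, PerRow())
)
w_fp8 = _quantize_affine_float8(w, w_scale, torch.float8_e4m3fn)

# Pack the weight once (the prepack op returns the inputs unchanged if the shape
# cannot be blocked), then run the CPU kernel; bias is None here.
packed_w, packed_scale = torch.ops.torchao.float8_linear_prepack_cpu(w_fp8, w_scale)
y = torch.ops.torchao.float8_linear_cpu(
    x_fp8, x_scale, packed_w, packed_scale, None, torch.bfloat16
)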