From c9175680dd5b64bdb151bc48ed5330b6d57f9dce Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Fri, 17 Jan 2025 12:57:10 -0800 Subject: [PATCH 1/2] [ET-VK][ez] Fix linear weight int4 test due to change in ATen API ## Context Recently the ATen API for 4-bit quantized linear has changed, so our test must adapt to the change in API. Concretely, the changes in API were: * The `_for_cpu` suffix was added to the operator name * The `_convert_weight_to_int4pack_mm` operator now expects unpacked 4-bit weights instead of a packed scheme where 2 4-bit values are packed into a single 8-bit value. Differential Revision: [D68333687](https://our.internmc.facebook.com/intern/diff/D68333687/) [ghstack-poisoned] --- .../test/op_tests/linear_weight_int4_test.cpp | 38 ++++++++++++++++--- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp index 63ebb96cfaa..9060e8806b4 100644 --- a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp +++ b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp @@ -30,16 +30,39 @@ at::Tensor linear_weight_int4_reference_impl( const size_t ndim = original_x_size.size(); const int64_t out_features = weights_4x2.size(0); const at::Tensor x_flattened = x.reshape({-1, original_x_size[ndim - 1]}); - const at::Tensor packed_weights = - at::_convert_weight_to_int4pack(weights_4x2, inner_k_tiles); - at::Tensor out = at::_weight_int4pack_mm( - x_flattened, packed_weights, groupsize, scales_and_zeros); + at::Tensor out = at::_weight_int4pack_mm_for_cpu( + x_flattened, weights_4x2, groupsize, scales_and_zeros); std::vector out_shape( original_x_size.begin(), original_x_size.end()); out_shape.at(ndim - 1) = out_features; return out.reshape(out_shape); } +at::Tensor unpack_weights_4x2(const at::Tensor& weights_4x2) { + std::vector weights_shape(weights_4x2.sizes().vec()); + weights_shape[1] *= 2; + + at::Tensor weights_unpacked = + at::empty(weights_shape, at::device(at::kCPU).dtype(at::kInt)); + + const int64_t N = weights_unpacked.size(0); + const int64_t K = weights_unpacked.size(1); + + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k += 2) { + // const int scale_idx = k_groups * n + group_idx; + const uint8_t packed_val = weights_4x2[n][k / 2].item().to(); + const uint8_t second_val = packed_val & 0x0F; + const uint8_t first_val = (packed_val & 0xF0) >> 4; + + weights_unpacked[n][k] = int(first_val); + weights_unpacked[n][k + 1] = int(second_val); + } + } + + return weights_unpacked; +} + at::Tensor dequantize_and_linear( const at::Tensor& x, const at::Tensor& weights_4x2, @@ -91,13 +114,18 @@ void test_reference_linear_int4( at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat)); at::Tensor weights_4x2 = at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte)); + at::Tensor weights_int = unpack_weights_4x2(weights_4x2); const int k_groups = K / group_size; at::Tensor scales_and_zeros = at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat)); at::Tensor out = linear_weight_int4_reference_impl( - x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles); + x, + at::_convert_weight_to_int4pack_for_cpu(weights_int, group_size), + group_size, + scales_and_zeros, + inner_k_tiles); at::Tensor out_ref = dequantize_and_linear( x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles); From b9c38a2352779155affc8c0712d8f443cf96cc45 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Fri, 17 Jan 2025 13:00:11 -0800 Subject: [PATCH 2/2] Update on "[ET-VK][ez] Fix linear weight int4 test due to change in ATen API" ## Context Recently the ATen API for 4-bit quantized linear has changed, so our test must adapt to the change in API. Concretely, the changes in API were: * The `_for_cpu` suffix was added to the operator name * The `_convert_weight_to_int4pack_mm` operator now expects unpacked 4-bit weights instead of a packed scheme where 2 4-bit values are packed into a single 8-bit value. Differential Revision: [D68333687](https://our.internmc.facebook.com/intern/diff/D68333687/) [ghstack-poisoned] --- backends/vulkan/test/op_tests/linear_weight_int4_test.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp index 9060e8806b4..66a585844cf 100644 --- a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp +++ b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp @@ -50,7 +50,6 @@ at::Tensor unpack_weights_4x2(const at::Tensor& weights_4x2) { for (int n = 0; n < N; n++) { for (int k = 0; k < K; k += 2) { - // const int scale_idx = k_groups * n + group_idx; const uint8_t packed_val = weights_4x2[n][k / 2].item().to(); const uint8_t second_val = packed_val & 0x0F; const uint8_t first_val = (packed_val & 0xF0) >> 4;