diff --git a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp
index 63ebb96cfaa..66a585844cf 100644
--- a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp
+++ b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp
@@ -30,16 +30,38 @@ at::Tensor linear_weight_int4_reference_impl(
   const size_t ndim = original_x_size.size();
   const int64_t out_features = weights_4x2.size(0);
   const at::Tensor x_flattened = x.reshape({-1, original_x_size[ndim - 1]});
-  const at::Tensor packed_weights =
-      at::_convert_weight_to_int4pack(weights_4x2, inner_k_tiles);
-  at::Tensor out = at::_weight_int4pack_mm(
-      x_flattened, packed_weights, groupsize, scales_and_zeros);
+  at::Tensor out = at::_weight_int4pack_mm_for_cpu(
+      x_flattened, weights_4x2, groupsize, scales_and_zeros);
   std::vector<int64_t> out_shape(
       original_x_size.begin(), original_x_size.end());
   out_shape.at(ndim - 1) = out_features;
   return out.reshape(out_shape);
 }
 
+at::Tensor unpack_weights_4x2(const at::Tensor& weights_4x2) {
+  std::vector<int64_t> weights_shape(weights_4x2.sizes().vec());
+  weights_shape[1] *= 2;
+
+  at::Tensor weights_unpacked =
+      at::empty(weights_shape, at::device(at::kCPU).dtype(at::kInt));
+
+  const int64_t N = weights_unpacked.size(0);
+  const int64_t K = weights_unpacked.size(1);
+
+  for (int n = 0; n < N; n++) {
+    for (int k = 0; k < K; k += 2) {
+      const uint8_t packed_val = weights_4x2[n][k / 2].item().to<uint8_t>();
+      const uint8_t second_val = packed_val & 0x0F;
+      const uint8_t first_val = (packed_val & 0xF0) >> 4;
+
+      weights_unpacked[n][k] = int(first_val);
+      weights_unpacked[n][k + 1] = int(second_val);
+    }
+  }
+
+  return weights_unpacked;
+}
+
 at::Tensor dequantize_and_linear(
     const at::Tensor& x,
     const at::Tensor& weights_4x2,
@@ -91,13 +113,18 @@ void test_reference_linear_int4(
   at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat));
   at::Tensor weights_4x2 =
       at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte));
+  at::Tensor weights_int = unpack_weights_4x2(weights_4x2);
 
   const int k_groups = K / group_size;
   at::Tensor scales_and_zeros =
       at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat));
 
   at::Tensor out = linear_weight_int4_reference_impl(
-      x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles);
+      x,
+      at::_convert_weight_to_int4pack_for_cpu(weights_int, group_size),
+      group_size,
+      scales_and_zeros,
+      inner_k_tiles);
 
   at::Tensor out_ref = dequantize_and_linear(
       x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles);
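
Note for reviewers: the new unpack_weights_4x2 helper assumes the 4x2 layout where each byte holds two 4-bit weights, with the first logical weight in the high nibble and the second in the low nibble. The following is a minimal standalone sketch (plain C++, no ATen dependency; pack_4x2 is a hypothetical helper introduced here purely for illustration) of that nibble layout and the same bit operations the unpacking loop in the diff uses:

#include <cassert>
#include <cstdint>
#include <cstdio>

// Hypothetical helper for illustration: packs two 4-bit values into one
// byte, first value in the high nibble, second in the low nibble,
// matching the layout unpack_weights_4x2 decodes above.
uint8_t pack_4x2(uint8_t first_val, uint8_t second_val) {
  assert(first_val < 16 && second_val < 16);
  return static_cast<uint8_t>((first_val << 4) | (second_val & 0x0F));
}

int main() {
  const uint8_t packed_val = pack_4x2(0xA, 0x3); // 0xA3
  // Same bit operations as the unpacking loop in the diff.
  const uint8_t first_val = (packed_val & 0xF0) >> 4;
  const uint8_t second_val = packed_val & 0x0F;
  std::printf(
      "packed=0x%02X first=%u second=%u\n",
      (unsigned)packed_val,
      (unsigned)first_val,
      (unsigned)second_val);
  assert(first_val == 0xA && second_val == 0x3);
  return 0;
}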