[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking #50590

Closed. Wants to merge 18 commits.
Commits (18)
87d5823  [QNNPACK, Sparsity] Sparse kernel with 4x8 blocking  (kimishpatel, Jan 15, 2021)
4c08d84  Update on "[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking"  (kimishpatel, Jan 26, 2021)
3c6c733  Update on "[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking"  (kimishpatel, Jan 26, 2021)
a987c75  Update on "[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking"  (kimishpatel, Jan 26, 2021)
7cb56ca  Update on "[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking"  (kimishpatel, Jan 26, 2021)
8706cbc  Update on "[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking"  (kimishpatel, Jan 27, 2021)
a53bbc1  Update on "[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking"  (kimishpatel, Jan 29, 2021)
1b1b85f  Update on "[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking"  (kimishpatel, Jan 29, 2021)
9f8c6da  Update on "[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking"  (kimishpatel, Jan 29, 2021)
5fa35ac  Update on "[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking"  (kimishpatel, Jan 29, 2021)
296fb45  Update on "[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking"  (kimishpatel, Jan 30, 2021)
1dedd14  Update on "[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking"  (kimishpatel, Jan 30, 2021)
d3db2b5  Update on "[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking"  (kimishpatel, Jan 31, 2021)
7f9dede  Update on "[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking"  (kimishpatel, Feb 1, 2021)
d13e846  Update on "[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking"  (kimishpatel, Feb 1, 2021)
5fb4300  Update on "[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking"  (kimishpatel, Feb 1, 2021)
bfb2033  Update on "[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking"  (kimishpatel, Feb 2, 2021)
f212e6c  Update on "[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking"  (kimishpatel, Feb 3, 2021)
2 changes: 2 additions & 0 deletions aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt
@@ -217,7 +217,9 @@ set(PYTORCH_QNNPACK_AARCH32_ASM_UKERNELS
src/q8gemm/4x8-dq-aarch32-neon.S
src/q8gemm/4x8c2-xzp-aarch32-neon.S
src/q8gemm_sparse/8x4-packA-aarch32-neon.S
src/q8gemm_sparse/4x4-packA-aarch32-neon.S
src/q8gemm_sparse/8x4c1x4-dq-packedA-aarch32-neon.S
src/q8gemm_sparse/4x8c1x4-dq-packedA-aarch32-neon.S
src/q8gemm_sparse/8x4c1x4-dq-aarch32-neon.S)

set(PYTORCH_QNNPACK_AARCH64_ASM_UKERNELS
50 changes: 47 additions & 3 deletions aten/src/ATen/native/quantized/cpu/qnnpack/bench/q8gemm_sparse.cc
@@ -459,9 +459,9 @@ BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMMSparse_Op, 8x4c1x4_prepacked__aarch32_neon, 8,
pytorch_q8gemm_sparse_packA_ukernel_8x4__aarch32_neon(
mrr,
kc(),
a(),
a() + m * kc(),
kc() * sizeof(uint8_t),
a_packed.data()
a_packed.data() + (m >> 3) * (k_blocks << 2) * mr()
);
}
}
@@ -473,7 +473,7 @@ BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMMSparse_Op, 8x4c1x4_prepacked__aarch32_neon, 8,
pytorch_q8gemm_dq_sparse_1x4_ukernel_8x4_packedA__aarch32_neon(
mrr,
nrr,
a_packed.data(),
a_packed.data() + (m >> 3) * (k_blocks << 2) * mr(),
bcsr_matrix_->values.data(),
bcsr_matrix_->row_values.data(),
bcsr_matrix_->col_indices.data(),
@@ -489,6 +489,50 @@ BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMMSparse_Op, 8x4c1x4_prepacked__aarch32_neon, 8,
BENCHMARK_REGISTER_F(Q8GEMMSparse_Op, 8x4c1x4_prepacked__aarch32_neon)
->Apply(SparseGEMMBenchGemmArguments);

BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMMSparse_Op, 4x8c1x4_prepacked__aarch32_neon, 4, 8, 4)
(benchmark::State& state) {
for (auto _ : state) {
auto m_blocks = (mc() + mr() - 1) / mr();
auto k_blocks = (kc() + 4 - 1) / 4;
std::vector<uint8_t> a_packed(m_blocks * k_blocks * mr() * 4 + 8);
for (uint32_t m = 0; m < mc(); m += mr()) {
const uint32_t mrr = min(mc() - m, mr());
for (uint32_t n = 0, channel_offset = 0; n < nc();
n += nr(), channel_offset += nr()) {
const uint32_t nrr = min(nc() - n, nr());
pytorch_q8gemm_sparse_packA_ukernel_4x4__aarch32_neon(
mrr,
kc(),
a() + m * kc(),
kc() * sizeof(uint8_t),
a_packed.data() + (m >> 2) * (k_blocks << 2) * mr()
);
}
}
for (uint32_t m = 0; m < mc(); m += mr()) {
const uint32_t mrr = min(mc() - m, mr());
for (uint32_t n = 0, channel_offset = 0; n < nc();
n += nr(), channel_offset += nr()) {
const uint32_t nrr = min(nc() - n, nr());
pytorch_q8gemm_dq_sparse_1x4_ukernel_4x8_packedA__aarch32_neon(
mrr,
nrr,
a_packed.data() + (m >> 2) * (k_blocks << 2) * mr(),
bcsr_matrix_->values.data(),
bcsr_matrix_->row_values.data(),
bcsr_matrix_->col_indices.data(),
b() + n,
c() + m * nc() + n,
nc(),
channel_offset,
quantizationParams());
}
}
}
}
BENCHMARK_REGISTER_F(Q8GEMMSparse_Op, 4x8c1x4_prepacked__aarch32_neon)
->Apply(SparseGEMMBenchGemmArguments);
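
As a reading aid for the offset arithmetic in the benchmarks above, here is a minimal C sketch of how the packed-A buffer is sized and indexed. It is not part of the PR, the helper names are invented, and it assumes mr = 4 and kr = 4 as registered for the 4x8 path, so that (m >> 2) equals m / mr and (k_blocks << 2) equals k_blocks * 4.

#include <stddef.h>

/* Hypothetical helpers mirroring the benchmark arithmetic above. */

/* Bytes needed for packed A: one mr x 4 block per (m-tile, k-block),
   plus the same 8-byte tail the benchmark allocates. */
static size_t packed_a_size(size_t mc, size_t kc, size_t mr) {
  const size_t m_blocks = (mc + mr - 1) / mr;
  const size_t k_blocks = (kc + 4 - 1) / 4;
  return m_blocks * k_blocks * mr * 4 + 8;
}

/* Byte offset of the packed tile that starts at row m (m a multiple of mr).
   With mr = 4 this is exactly (m >> 2) * (k_blocks << 2) * mr. */
static size_t packed_a_offset(size_t m, size_t kc, size_t mr) {
  const size_t k_blocks = (kc + 4 - 1) / 4;
  return (m / mr) * (k_blocks * 4) * mr;
}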

#endif

#if CPUINFO_ARCH_ARM64
@@ -86,7 +86,7 @@ enum pytorch_qnnp_status pytorch_qnnp_create_fully_connected_sparse_dq_nc_q8(
fully_connected->dynamic_conv_quantization_params.multipliers =
requantization_scales;

if (use_prepack_kernel) {
if (use_prepack_kernel || (pytorch_qnnp_params.q8gemm_sparse.gemm_dq == NULL)) {
fully_connected->ukernel_type =
pytorch_qnnp_ukernel_type_gemm_prepackA_sparse_dq;
} else {
17 changes: 17 additions & 0 deletions aten/src/ATen/native/quantized/cpu/qnnpack/src/init.c
@@ -59,6 +59,13 @@ static void init(void) {
.nr = 8,
.kr = 1,
};
/*
* The current q8gemm-sparse benchmark shows that, for the shapes benchmarked,
* the 4x8 kernels are better than the 8x4 ones.
* Without prepacking, sparse is always worse than dense.
* All benchmarking was done at 70% sparsity.
* Hence the 8x4 registration below is commented out and gemm_dq is left NULL,
* which makes operator creation always take the prepacked-A path.
*/
/*
pytorch_qnnp_params.q8gemm_sparse = (struct pytorch_q8gemm_sparse_parameters){
.gemm_dq = pytorch_q8gemm_dq_sparse_1x4_ukernel_8x4__aarch32_neon,
.packedA_gemm_dq = pytorch_q8gemm_dq_sparse_1x4_ukernel_8x4_packedA__aarch32_neon,
@@ -68,6 +75,16 @@ static void init(void) {
.kr = 4,
.log2_mr = 3,
};
*/
pytorch_qnnp_params.q8gemm_sparse = (struct pytorch_q8gemm_sparse_parameters){
.gemm_dq = NULL,
.packedA_gemm_dq = pytorch_q8gemm_dq_sparse_1x4_ukernel_4x8_packedA__aarch32_neon,
.packA = pytorch_q8gemm_sparse_packA_ukernel_4x4__aarch32_neon,
.mr = 4,
.nr = 8,
.kr = 4,
.log2_mr = 2,
};
#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
pytorch_qnnp_params.q8conv_xzp = (struct pytorch_q8conv_xzp_parameters){
.gemm = pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon,
@@ -0,0 +1,227 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <qnnpack/assembly.h>
#include <requantization/runtime-assembly.h>

# r0 mr
# r1 k
# r2 a
# r3 a_stride

.syntax unified

# Args passed via stack.
# TOS
# |----------------|
# |packed_a | 0
# |----------------|
#

# After pushing r4-r11 on the stack:
# |----------------|
# |r4 - r11 | 0
# |packed_a | 32
# |----------------|
#

# Packed A format.
# 4k x 4m blocks: all the blocks for a given set of 4 rows (4m) are placed in contiguous memory.
# Original A
# --------- K ----------- -- (K + 4 - 1) / 4 --
# | | | |
# | | (M + 4 - 1)/4 |
# | | Packed | |
# M | => |-------------------|
# | | Thus Packed A has (K + 4 - 1)/4 * (M + 4 -1)/4 blocks
# | |
# |---------------------|
#
# Each 4 x 4 blocks is transposed and stored.
# Each of the (K + 4 - 1)/4 blocks for a given group of 4 m blocks
# are stored adjacent in memory
# Thus, each block:
# |----4m-----|----4m-----|
# 4k | | ..... (K + 4 - 1)/4 blocks
# |-----------|-----------|
# This locality helps in loading 8kx4m blocks of activations
# Note: when M is not a multiple of 4, the remaining rows of packed A can
# contain arbitrary data, as those values are never written out.
# This is taken care of by copying out only the appropriate valid data.
# (A scalar C reference sketch of this layout is given after this file.)

# void pytorch_q8gemm_sparse_packA_ukernel_4x4__aarch32_neon(
# size_t mr,
# size_t K,
# const uint8_t* a,
# size_t a_stride,
# uint8_t* packed_a)
BEGIN_FUNCTION pytorch_q8gemm_sparse_packA_ukernel_4x4__aarch32_neon
.arm
#ifndef __APPLE__
.arch armv7-a
.fpu neon
#endif

PUSH {r4, r5, r6, r7, r8, r9, r10, r11}

# r4 = a0 = a pointer
MOV r4, r2
# r2 = packed_a pointer
LDR r2, [sp, 32]

CMP r0, 2
# r5 = a1
ADD r5, r4, r3
MOVLO r5, r4

# r6 = a2
ADD r6, r5, r3
MOVLS r6, r5

CMP r0, 4
# r7 = a3
ADD r7, r6, r3
MOVNE r7, r6

# num_k_blocks = (k + (4 - 1)) / 4
ADD r1, r1, 3
LSR r1, r1, 2
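# Main loop: each iteration packs two 4-k blocks (8 bytes per row);
# branch to the remainder path when fewer than two blocks remain.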

SUBS r1, r1, 2
BLO 1f

.p2align 5
k_loop:
VLD1.8 {d0}, [r4]!
VLD1.8 {d1}, [r5]!
VLD1.8 {d2}, [r6]!
VLD1.8 {d3}, [r7]!

# Now we have a 4x8 block of values that we will transpose
# A matrix
# --------------------------------
# | |
# |a0-----a3 a4-----a7....|
# |b0 B00 b3 b4 B01 b7....|
# |c0 c3 c4 c7....|
# |d0-----d3 d4-----d7....|
# | |
# | |
# -------------------------------
# {va01, va23} = B00 + B01 = 2 uint8x16_t
# Sequence:
# VTRN.8 d0, d1 // low(va01), high(va01)
# VTRN.8 d2, d3 // low(va23), high(va23)
# VTRN.16 q0, q1 // va01, va23
# Now we have
# d0 = d4, c4, b4, a4 : d0, c0, b0, a0
# d1 = d5, c5, b5, a5 : d1, c1, b1, a1
# d2 = d6, c6, b6, a6 : d2, c2, b2, a2
# d3 = d7, c7, b7, a7 : d3, c3, b3, a3
# Thus the two 4x4 blocks are transposed.
# Now both B00 and B01 are transposed, spread across d0-d3.

VTRN.8 d0, d1
VTRN.8 d2, d3
VTRN.16 q0, q1

# Now VTRN.32 d0, d1
# Now VTRN.32 d2, d3
# Thus we have
# d0 = d1, c1, b1, a1 : d0, c0, b0, a0
# d1 = d5, c5, b5, a5 : d4, c4, b4, a4
# d2 = d3, c3, b3, a3 : d2, c2, b2, a2
# d3 = d7, c7, b7, a7 : d6, c6, b6, a6
# Then we can do
# VSWP d1, d2
# d0 = d1, c1, b1, a1 : d0, c0, b0, a0
# d1 = d3, c3, b3, a3 : d2, c2, b2, a2
# d2 = d5, c5, b5, a5 : d4, c4, b4, a4
# d3 = d7, c7, b7, a7 : d6, c6, b6, a6
# Now we can store q0 contiguously, followed by q1.
VTRN.32 d0, d1
VTRN.32 d2, d3
VSWP d1, d2

# Now store the transposed values
# d0, d1, d2, d3
VST1.8 {q0}, [r2]!
VST1.8 {q1}, [r2]!

SUBS r1, r1, 2

BHS k_loop
1:
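# r1 is now negative: -1 means one 4-k block remains to be packed,
# -2 means the block count was even and there is nothing left to do.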
CMP r1, -2
BEQ 2f

VLD1.32 {d0[]}, [r4]
VLD1.32 {d1[]}, [r5]
VLD1.32 {d2[]}, [r6]
VLD1.32 {d3[]}, [r7]

# Now we have a 4x4 block of values that we will transpose
# _d{0-3} are arm neon vector registers
# va0 = _d0 = a0 a1 a2 a3
# va1 = _d1 = b0 b1 b2 b3
# va2 = _d2 = c0 c1 c2 c3
# va3 = _d3 = d0 d1 d2 d3
# A matrix
# ----------------------------
# | |
# | a0-----a3|
# | b0 B00 b3|
# | last block c0 c3|
# | d0-----d3|
# | |
# | |
# ---------------------------
# Sequence:
# VTRN.8 d0, d1 // va0, va1
# VTRN.8 d2, d3 // va2, va3
# Now we have
# d0 = b2, a2, b0, a0
# d1 = b3, a3, b1, a1
# d2 = d2, c2, d0, c0
# d3 = d3, c3, d1, c1
# Sequence:
# VTRN.16 d0, d2
# VTRN.16 d1, d3
# Now we have
# d0 = d0, c0, b0, a0
# d1 = d1, c1, b1, a1
# d2 = d2, c2, b2, a2
# d3 = d3, c3, b3, a3

VTRN.8 d0, d1
VTRN.8 d2, d3
VTRN.16 d0, d2
VTRN.16 d1, d3

# Since the upper halves of d0-d3 just contain duplicate values,
# we don't want to store those.
# So combine the upper half of d0 and the lower half of d1 into d0,
# and likewise the upper half of d2 and the lower half of d3 into d1.
VEXT.8 d0, d0, d1, #4
VEXT.8 d1, d2, d3, #4

# Now store the transposed values in d0, d1 (q0)
VST1.8 {q0}, [r2]
.p2align 4
2:
POP {r4, r5, r6, r7, r8, r9, r10, r11}
BX lr

END_FUNCTION pytorch_q8gemm_sparse_packA_ukernel_4x4__aarch32_neon

#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif
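
For reference, a scalar C sketch of the packed-A layout described in the comments of the new 4x4 packing kernel above. This is a hypothetical illustration, not part of the PR; ref_packA_4x4 is an invented name, the NEON kernel above is the actual implementation, and the handling of rows beyond mr and of k positions past K follows the comments (duplicated or arbitrary data that is never read back) rather than exact kernel behavior.

#include <stddef.h>
#include <stdint.h>

/* Hypothetical scalar reference for the packed-A layout:
 * A is mr x K, row-major, with rows a_stride bytes apart. Packed A holds
 * (K + 3) / 4 blocks for this 4-row tile; each block is a transposed
 * 4 (k) x 4 (m) tile, stored k-major, and consecutive k blocks of the
 * same tile are adjacent in memory. */
static void ref_packA_4x4(
    size_t mr,            /* number of valid rows, 1..4 */
    size_t k,
    const uint8_t* a,
    size_t a_stride,
    uint8_t* packed_a) {
  const size_t k_blocks = (k + 4 - 1) / 4;
  for (size_t kb = 0; kb < k_blocks; kb++) {
    for (size_t kk = 0; kk < 4; kk++) {      /* k index within the block */
      for (size_t m = 0; m < 4; m++) {       /* m index within the tile  */
        /* Rows past mr reuse the last valid row, mirroring the kernel's
         * duplicated row pointers; those values are never read back. */
        const size_t src_row = m < mr ? m : mr - 1;
        const size_t src_col = kb * 4 + kk;
        /* k positions past K are padding; zero is used as a placeholder. */
        const uint8_t v = src_col < k ? a[src_row * a_stride + src_col] : 0;
        packed_a[(kb * 4 + kk) * 4 + m] = v;
      }
    }
  }
}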