[QNNPACK, Sparsity] Sparse kernel with 4x8 blocking (#50590)
Summary:
Pull Request resolved: #50590

Larger blocking across the M dimension, such as the 8 used in the previous
PR, likely introduces wasted compute on the shapes being benchmarked.
Here we introduce 4x8 blocking of mr x nr. This helps because 1) it packs
less data for small values of M, and 2) the compute kernel writes the same
number of bytes but more contiguously. The second effect is not certain,
but it likely helps.
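
As a rough illustration of the packing argument (a minimal sketch, not code from this change; packed_a_bytes is a made-up helper, and the small slack bytes the benchmark adds are omitted), the amount of activation data packed, and later swept by the compute kernel, is proportional to M rounded up to a multiple of mr, so small M pads far less with mr = 4 than with mr = 8:

#include <cstddef>
#include <cstdio>

// Bytes of packed activations for an M x K uint8 matrix when M is padded up
// to a multiple of mr and K to a multiple of kr (kr = 4 for these kernels).
static size_t packed_a_bytes(size_t M, size_t K, size_t mr, size_t kr) {
  const size_t m_blocks = (M + mr - 1) / mr;
  const size_t k_blocks = (K + kr - 1) / kr;
  return m_blocks * mr * k_blocks * kr;
}

int main() {
  // For M = 1..4 and K = 1024, mr = 8 packs twice as many rows as mr = 4.
  for (size_t M = 1; M <= 4; ++M) {
    std::printf("M=%zu  mr=8: %zu bytes  mr=4: %zu bytes\n",
                M,
                packed_a_bytes(M, 1024, 8, 4),
                packed_a_bytes(M, 1024, 4, 4));
  }
  return 0;
}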

Test Plan:
q8gemm-sparse-test
fully-connected-sparse-test

Imported from OSS

Reviewed By: AshkanAliabadi

Differential Revision: D25925499

fbshipit-source-id: 01c661ceea38bd6ee8321bb85cf1d5da5de4e984
kimishpatel authored and facebook-github-bot committed Feb 5, 2021
1 parent e8ee35a commit 70830b5
Showing 13 changed files with 1,208 additions and 9 deletions.
2 changes: 2 additions & 0 deletions aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt
@@ -217,7 +217,9 @@ set(PYTORCH_QNNPACK_AARCH32_ASM_UKERNELS
src/q8gemm/4x8-dq-aarch32-neon.S
src/q8gemm/4x8c2-xzp-aarch32-neon.S
src/q8gemm_sparse/8x4-packA-aarch32-neon.S
src/q8gemm_sparse/4x4-packA-aarch32-neon.S
src/q8gemm_sparse/8x4c1x4-dq-packedA-aarch32-neon.S
src/q8gemm_sparse/4x8c1x4-dq-packedA-aarch32-neon.S
src/q8gemm_sparse/8x4c1x4-dq-aarch32-neon.S)

set(PYTORCH_QNNPACK_AARCH64_ASM_UKERNELS
50 changes: 47 additions & 3 deletions aten/src/ATen/native/quantized/cpu/qnnpack/bench/q8gemm_sparse.cc
@@ -459,9 +459,9 @@ BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMMSparse_Op, 8x4c1x4_prepacked__aarch32_neon, 8,
pytorch_q8gemm_sparse_packA_ukernel_8x4__aarch32_neon(
mrr,
kc(),
a(),
a() + m * kc(),
kc() * sizeof(uint8_t),
a_packed.data()
a_packed.data() + (m >> 3) * (k_blocks << 2) * mr()
);
}
}
@@ -473,7 +473,7 @@ BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMMSparse_Op, 8x4c1x4_prepacked__aarch32_neon, 8,
pytorch_q8gemm_dq_sparse_1x4_ukernel_8x4_packedA__aarch32_neon(
mrr,
nrr,
a_packed.data(),
a_packed.data() + (m >> 3) * (k_blocks << 2) * mr(),
bcsr_matrix_->values.data(),
bcsr_matrix_->row_values.data(),
bcsr_matrix_->col_indices.data(),
@@ -489,6 +489,50 @@ BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMMSparse_Op, 8x4c1x4_prepacked__aarch32_neon, 8,
BENCHMARK_REGISTER_F(Q8GEMMSparse_Op, 8x4c1x4_prepacked__aarch32_neon)
->Apply(SparseGEMMBenchGemmArguments);

BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMMSparse_Op, 4x8c1x4_prepacked__aarch32_neon, 4, 8, 4)
(benchmark::State& state) {
for (auto _ : state) {
auto m_blocks = (mc() + mr() - 1) / mr();
auto k_blocks = (kc() + 4 - 1) / 4;
std::vector<uint8_t> a_packed(m_blocks * k_blocks * mr() * 4 + 8);
for (uint32_t m = 0; m < mc(); m += mr()) {
const uint32_t mrr = min(mc() - m, mr());
for (uint32_t n = 0, channel_offset = 0; n < nc();
n += nr(), channel_offset += nr()) {
const uint32_t nrr = min(nc() - n, nr());
pytorch_q8gemm_sparse_packA_ukernel_4x4__aarch32_neon(
mrr,
kc(),
a() + m * kc(),
kc() * sizeof(uint8_t),
a_packed.data() + (m >> 2) * (k_blocks << 2) * mr()
);
}
}
for (uint32_t m = 0; m < mc(); m += mr()) {
const uint32_t mrr = min(mc() - m, mr());
for (uint32_t n = 0, channel_offset = 0; n < nc();
n += nr(), channel_offset += nr()) {
const uint32_t nrr = min(nc() - n, nr());
pytorch_q8gemm_dq_sparse_1x4_ukernel_4x8_packedA__aarch32_neon(
mrr,
nrr,
a_packed.data() + (m >> 2) * (k_blocks << 2) * mr(),
bcsr_matrix_->values.data(),
bcsr_matrix_->row_values.data(),
bcsr_matrix_->col_indices.data(),
b() + n,
c() + m * nc() + n,
nc(),
channel_offset,
quantizationParams());
}
}
}
}
BENCHMARK_REGISTER_F(Q8GEMMSparse_Op, 4x8c1x4_prepacked__aarch32_neon)
->Apply(SparseGEMMBenchGemmArguments);

#endif

#if CPUINFO_ARCH_ARM64
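
Aside on the indexing used in the benchmark above (an illustrative sketch, not code from this commit; packed_a_offset is a made-up name): the expression a_packed.data() + (m >> 2) * (k_blocks << 2) * mr() simply skips whole 4-row blocks of packed A, since each row block holds k_blocks tiles of 4 x mr bytes:

#include <cstddef>

// Byte offset of the packed-A row block that starts at row m, for the 4x8
// kernels above (mr = 4 rows per block, 4 bytes of K per tile).
static size_t packed_a_offset(size_t m, size_t k_blocks, size_t mr) {
  const size_t row_block = m / mr;                       // (m >> 2) when mr == 4
  const size_t bytes_per_row_block = k_blocks * 4 * mr;  // (k_blocks << 2) * mr
  return row_block * bytes_per_row_block;
}
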
@@ -86,7 +86,7 @@ enum pytorch_qnnp_status pytorch_qnnp_create_fully_connected_sparse_dq_nc_q8(
fully_connected->dynamic_conv_quantization_params.multipliers =
requantization_scales;

if (use_prepack_kernel) {
if (use_prepack_kernel || (pytorch_qnnp_params.q8gemm_sparse.gemm_dq == NULL)) {
fully_connected->ukernel_type =
pytorch_qnnp_ukernel_type_gemm_prepackA_sparse_dq;
} else {
17 changes: 17 additions & 0 deletions aten/src/ATen/native/quantized/cpu/qnnpack/src/init.c
@@ -59,6 +59,13 @@ static void init(void) {
.nr = 8,
.kr = 1,
};
/*
* Current q8gemm-sparse-bench results show that, for the shapes benchmarked,
* the 4x8 kernels are better than the 8x4 ones.
* Without prepacking it is always worse than dense.
* All benchmarking is at 70% sparsity.
*/
/*
pytorch_qnnp_params.q8gemm_sparse = (struct pytorch_q8gemm_sparse_parameters){
.gemm_dq = pytorch_q8gemm_dq_sparse_1x4_ukernel_8x4__aarch32_neon,
.packedA_gemm_dq = pytorch_q8gemm_dq_sparse_1x4_ukernel_8x4_packedA__aarch32_neon,
@@ -68,6 +75,16 @@ static void init(void) {
.kr = 4,
.log2_mr = 3,
};
*/
pytorch_qnnp_params.q8gemm_sparse = (struct pytorch_q8gemm_sparse_parameters){
.gemm_dq = NULL,
.packedA_gemm_dq = pytorch_q8gemm_dq_sparse_1x4_ukernel_4x8_packedA__aarch32_neon,
.packA = pytorch_q8gemm_sparse_packA_ukernel_4x4__aarch32_neon,
.mr = 4,
.nr = 8,
.kr = 4,
.log2_mr = 2,
};
#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
pytorch_qnnp_params.q8conv_xzp = (struct pytorch_q8conv_xzp_parameters){
.gemm = pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon,
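
One note on the new parameters (a hedged sketch; how callers outside this diff use the field is an assumption): log2_mr = 2 lets callers turn a row index into a packed row-block index with a shift instead of a division, exactly as the benchmark's m >> 2 does for mr = 4:

#include <cstdint>

// With mr == (1u << log2_mr) (here mr = 4, log2_mr = 2), the packed-A row
// block containing row m is found with a shift rather than a division.
inline uint32_t packed_row_block(uint32_t m, uint32_t log2_mr) {
  return m >> log2_mr;  // equivalent to m / mr
}
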
@@ -0,0 +1,227 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <qnnpack/assembly.h>
#include <requantization/runtime-assembly.h>

# r0 mr
# r1 k
# r2 a
# r3 a_stride

.syntax unified

# Args passed via stack.
# TOS
# |----------------|
# |packed_a | 0
# |----------------|
#

# After pushing r4-r11 on the stack
# |----------------|
# |r4 - r11 | 0
# |packed_a | 32
# |----------------|
#

# Packed A format.
# 4kx4m blocks; all blocks for a given group of 4 rows (4m) are placed in contiguous memory.
# Original A
# --------- K ----------- -- (K + 4 - 1) / 4 --
# | | | |
# | | (M + 4 - 1)/4 |
# | | Packed | |
# M | => |-------------------|
# | | Thus Packed A has (K + 4 - 1)/4 * (M + 4 -1)/4 blocks
# | |
# |---------------------|
#
# Each 4x4 block is transposed and stored.
# Each of the (K + 4 - 1)/4 blocks for a given group of 4 rows (one m block)
# is stored adjacent in memory.
# Thus, each block:
# |----4m-----|----4m-----|
# 4k | | ..... (K + 4 - 1)/4 blocks
# |-----------|-----------|
# This locality helps in loading 8kx4m blocks of activations
# Note: when M is not a multiple of 4, the remainder of packed A may contain
# arbitrary data, since those entries are never written out.
# This is taken care of by copying only the appropriate valid data.
# (A C-style reference sketch of this layout follows after this listing.)

# void pytorch_q8gemm_sparse_packA_ukernel_4x4__aarch32_neon(
# size_t mr,
# size_t K,
# const uint8_t* a,
# size_t a_stride,
# uint8_t* packed_a,
BEGIN_FUNCTION pytorch_q8gemm_sparse_packA_ukernel_4x4__aarch32_neon
.arm
#ifndef __APPLE__
.arch armv7-a
.fpu neon
#endif

PUSH {r4, r5, r6, r7, r8, r9, r10, r11}

# r4 = a0 = a pointer
MOV r4, r2
# r2 = packed_a pointer
LDR r2, [sp, 32]

CMP r0, 2
# r5 = a1
ADD r5, r4, r3
MOVLO r5, r4

# r6 = a2
ADD r6, r5, r3
MOVLS r6, r5

CMP r0, 4
# r7 = a3
ADD r7, r6, r3
MOVNE r7, r6

# num_k_blocks = (k + (4 - 1)) / 4
ADD r1, r1, 3
LSR r1, r1, 2

SUBS r1, r1, 2
BLO 1f

.p2align 5
k_loop:
VLD1.8 {d0}, [r4]!
VLD1.8 {d1}, [r5]!
VLD1.8 {d2}, [r6]!
VLD1.8 {d3}, [r7]!

# Now we have a 4x8 block of values that we will transpose
# A matrix
# --------------------------------
# | |
# |a0-----a3 a4-----a7....|
# |b0 B00 b3 b4 B01 b7....|
# |c0 c3 c4 c7....|
# |d0-----d3 d4-----d7....|
# | |
# | |
# -------------------------------
# {va01, va23} = B00 + B01 = 2 uint8x16_t
# Sequence:
# VTRN.8 d0, d1 // low(va01), high(va01)
# VTRN.8 d2, d3 // low(va23), high(va23)
# VTRN.16 q0, q1 // va01, va23
# Now we have
# d0 = d4, c4, b4, a4 : d0, c0, b0, a0
# d1 = d5, c5, b5, a5 : d1, c1, b1, a1
# d2 = d6, c6, b6, a6 : d2, c2, b2, a2
# d3 = d7, c7, b7, a7 : d3, c3, b3, a3
# Thus the two 4x4 blocks are transposed.
# Now both B00 and B01 are transposed.

VTRN.8 d0, d1
VTRN.8 d2, d3
VTRN.16 q0, q1

# Now VTRN.32 d0, d1
# Now VTRN.32 d2, d3
# Thus we have
# d0 = d1, c1, b1, a1 : d0, c0, b0, a0
# d1 = d5, c5, b5, a5 : d4, c4, b4, a4
# d2 = d3, c3, b3, a3 : d2, c2, b2, a2
# d3 = d7, c7, b7, a7 : d6, c6, b6, a6
# Then we can do
# VSWP d1, d2
# d0 = d1, c1, b1, a1 : d0, c0, b0, a0
# d1 = d3, c3, b3, a3 : d2, c2, b2, a2
# d2 = d5, c5, b5, a5 : d4, c4, b4, a4
# d3 = d7, c7, b7, a7 : d6, c6, b6, a6
# Now we can store q0 contiguously, followed by q1.
VTRN.32 d0, d1
VTRN.32 d2, d3
VSWP d1, d2

# Now store the transposed values
# d0, d1, d2, d3
VST1.8 {q0}, [r2]!
VST1.8 {q1}, [r2]!

SUBS r1, r1, 2

BHS k_loop
1:
CMP r1, -2
BEQ 2f

VLD1.32 {d0[]}, [r4]
VLD1.32 {d1[]}, [r5]
VLD1.32 {d2[]}, [r6]
VLD1.32 {d3[]}, [r7]

# Now we have a 4x4 block of values (each row duplicated) that we will transpose
# _d{0-3} are arm neon vector registers
# va0 = _d0 = a0 a1 a2 a3
# va1 = _d1 = b0 b1 b2 b3
# va2 = _d2 = c0 c1 c2 c3
# va3 = _d3 = d0 d1 d2 d3
# A matrix
# ----------------------------
# | |
# | a0-----a3|
# | b0 B00 b3|
# | last block c0 c3|
# | d0-----d3|
# | |
# | |
# ---------------------------
# Sequence:
# VTRN.8 d0, d1 // va0, va1
# VTRN.8 d2, d3 // va2, va3
# Now we have
# d0 = b2, a2, b0, a0
# d1 = b3, a3, b1, a1
# d2 = d2, c2, d0, c0
# d3 = d3, c3, d1, c1
# Sequence:
# VTRN.16 d0, d2
# VTRN.16 d1, d3
# Now we have
# d0 = d0, c0, b0, a0
# d1 = d1, c1, b1, a1
# d2 = d2, c2, b2, a2
# d3 = d3, c3, b3, a3

VTRN.8 d0, d1
VTRN.8 d2, d3
VTRN.16 d0, d2
VTRN.16 d1, d3

# Since the upper half of d0 just contains duplicate values,
# we don't want to store it.
# So combine the upper half of d0 into the lower half of the new d0
# and the lower half of d1 into the upper half of the new d0.
# Same for d2, d3 (combined into d1).
VEXT.8 d0, d0, d1, #4
VEXT.8 d1, d2, d3, #4

# Now store the transposed values
# d0, d1
VST1.8 {q0}, [r2]
.p2align 4
2:
POP {r4, r5, r6, r7, r8, r9, r10, r11}
BX lr

END_FUNCTION pytorch_q8gemm_sparse_packA_ukernel_4x4__aarch32_neon

#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif
