Skip to content

Commit

Permalink
Add smmla/ummla support in quantized Conv2d (apache#6802)
Browse files Browse the repository at this point in the history
* Add smmla/ummla support in quantized Conv2d

This introduces support for `smmla`/`ummla` instructions in TVM:
- Added `is_mmla_available` function in `arm_utils.py`
- Added the tiling node + tensorization schedule in `conv2d_gemm.py`
- Added the intrinsic support in `tensor_intrin.py`
- Added the test-case in `test_topi_conv2d_int8.py`

Change-Id: Iff48c77f16fe1e64ecb733da965a879651ce635f

* Address review comments and test failures

* Fix linting

* Rebasing
  • Loading branch information
Giuseppe Rossini authored and Trevor Morris committed Dec 4, 2020
1 parent fb0452b commit 9a801b4
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 66 deletions.
27 changes: 22 additions & 5 deletions python/tvm/topi/arm_cpu/arm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,21 @@ def get_arch_version(target_mattr):


def is_dotprod_available():
""" Checks whether the hardware has support for fast Int8 arithmetic operations. """
""" Checks whether the hardware has support for udot/sdot instructions. """
target = tvm.target.Target.current(allow_none=False)
arch_version = get_arch_version(target.mattr)
return arch_version >= 8.4 or ((arch_version in (8.2, 8.3)) and "+dotprod" in target.mattr)


def is_mmla_available():
""" Checks whether the hardware has support for ummla/smmla instructions. """
target = tvm.target.Target.current(allow_none=False)
arch_version = get_arch_version(target.mattr)
return arch_version >= 8.6 or (
(arch_version in (8.2, 8.3, 8.4, 8.5)) and "+i8mm" in target.mattr
)


def is_aarch64_arm():
""" Checks whether we are compiling for an AArch64 target. """
target = tvm.target.Target.current(allow_none=False)
Expand All @@ -63,8 +72,10 @@ def get_tiling_B_interleaved_t(interleave_A):
tile computation.
Please refer to:
- https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product
- Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
- https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-performance-for-armv8-architectures # pylint: disable=line-too-long
- https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product
- https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-through-mmla-instruction
- Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
In order to have more information
Parameters
Expand All @@ -77,7 +88,13 @@ def get_tiling_B_interleaved_t(interleave_A):
tile_rows_B: the output tile rows of B'
tile_cols_B: the output tile columns of B'
"""
if is_dotprod_available():
if is_mmla_available():
# If smmla/ummla is available, A must be interleaved.
# Each load from B' will contain 8 elements
# and we are loading 12 rows of B' (i.e., 12 columns of B)
tile_rows_B = 12
tile_cols_B = 8
elif is_dotprod_available():
# The number of tile rows of B' vary depending on the
# strategy:
# * If we are interleaving A, then we select 12 columns from B'(i.e.,
Expand All @@ -92,7 +109,7 @@ def get_tiling_B_interleaved_t(interleave_A):
# rows of the original matrix B) need to be 4.
tile_cols_B = 4
else:
# If dot product is not available, A must be interleaved. In this case
# If no acceleration is available, A must be interleaved. In this case
# we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements
tile_rows_B = 4
tile_cols_B = 16
Expand Down
179 changes: 122 additions & 57 deletions python/tvm/topi/arm_cpu/conv2d_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
gemm_quantized_impl,
gemm_acc_4x4_int8_int8_int32,
gemm_acc_nx16_int8_int8_int32,
gemm_acc_2x2_int8_int8_int32,
)
from .arm_utils import is_aarch64_arm, is_dotprod_available
from .arm_utils import is_aarch64_arm, is_dotprod_available, is_mmla_available


def configure_knobs(cfg, M, K):
Expand Down Expand Up @@ -130,11 +131,18 @@ def compute_conv2d_gemm_without_weight_transform(
# the tile computation.
#
# Please refer to:
# - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product
# - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
# - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-performance-for-armv8-architectures # pylint: disable=line-too-long
# - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product
# - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-through-mmla-instruction
# - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
# In order to have more information
#
if is_dotprod_available() and interleave_A:
if is_mmla_available():
# If smmla/ummla is enabled, we are loading 8 rows from A. Each row
# will contain 8 elements
tile_rows_A = 8
tile_cols_A = 8
elif is_dotprod_available() and interleave_A:
# If dot product has been enabled, and we are interleaving A
# tile size should be 8x4
tile_rows_A = 8
Expand Down Expand Up @@ -177,24 +185,71 @@ def compute_conv2d_gemm_without_weight_transform(
lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y],
name="A_interleaved",
)
# Execute GEMM
C_interleaved = te.compute(
(batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B),
lambda b, x, y, w, z: te.sum(
A_interleaved[b, x, k // tile_cols_A, w, idxm(k, tile_cols_A)].astype("int32")
* B_interleaved_t[y, k // tile_cols_B, z, idxm(k, tile_cols_B)].astype("int32"),
axis=k,
),
name="C_interleaved",
)
# Unpack the result
C = te.compute(
(batches, M, N),
lambda b, x, y: C_interleaved[
b, x // tile_rows_A, y // tile_rows_B, idxm(x, tile_rows_A), idxm(y, tile_rows_B)
].astype(out_dtype),
name="C",
)
if is_mmla_available():
# Execute GEMM. In the case of mmla, we need to enforce the tiling
# from the compute. This is because mmla is doing a tiled computation
# as well. So we have a big 8x12 tile, with small 2x2 sub-tiles
# generated by mmla. In theory we could make the tile 2x2 and
# fuse and split during scheduling, but this would not work
# because of possible padding
C_interleaved = te.compute(
(
batches,
M_padded // tile_rows_A,
N_transformed,
tile_rows_A // 2,
tile_rows_B // 2,
2,
2,
),
lambda b, x, y, w, z, s, t: te.sum(
A_interleaved[b, x, k // tile_cols_A, 2 * w + s, idxm(k, tile_cols_A)].astype(
"int32"
)
* B_interleaved_t[y, k // tile_cols_B, 2 * z + t, idxm(k, tile_cols_B)].astype(
"int32"
),
axis=k,
),
name="C_interleaved",
)
# Unpack the result
C = te.compute(
(batches, M, N),
lambda b, x, y: C_interleaved[
b,
x // tile_rows_A,
y // tile_rows_B,
idxm(x, tile_rows_A) // 2,
idxm(y, tile_rows_B) // 2,
idxm(idxm(x, tile_rows_A), 2),
idxm(idxm(y, tile_rows_B), 2),
].astype(out_dtype),
name="C",
)
else:
# Execute GEMM
C_interleaved = te.compute(
(batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B),
lambda b, x, y, w, z: te.sum(
A_interleaved[b, x, k // tile_cols_A, w, idxm(k, tile_cols_A)].astype("int32")
* B_interleaved_t[y, k // tile_cols_B, z, idxm(k, tile_cols_B)].astype("int32"),
axis=k,
),
name="C_interleaved",
)
# Unpack the result
C = te.compute(
(batches, M, N),
lambda b, x, y: C_interleaved[
b,
x // tile_rows_A,
y // tile_rows_B,
idxm(x, tile_rows_A),
idxm(y, tile_rows_B),
].astype(out_dtype),
name="C",
)
zero = tvm.tir.const(0)
else:
# No need to pack/unpack, execute GEMM directly
Expand Down Expand Up @@ -255,7 +310,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
s[data_im2col].compute_inline()

# Computation(through tensorize)
b, xo, yo, xi, yi = C_interleaved.op.axis
b, xo, yo, xi, yi = C_interleaved.op.axis[0:5]
outer_gemm, inner_gemm = cfg["reorder_gemm"].apply(s, C_interleaved, [xo, yo])

b_outer_gemm_fused = s[C_interleaved].fuse(b, outer_gemm)
Expand All @@ -271,40 +326,50 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):

k = C_interleaved.op.reduce_axis[0]
_, M, N = C.shape
if is_dotprod_available():
gemm_acc = gemm_acc_4x4_int8_int8_int32(in_type)
xi_outer, yi_outer, xi_inner, yi_inner = s[C_interleaved].tile(
xi, yi, x_factor=8, y_factor=4
)
k_outer, k_inner = s[C_interleaved].split(k, 4)
xi_inner_outer, xi_inner_inner = s[C_interleaved].split(xi_inner, 4)
s[C_interleaved].reorder(
b_outer_gemm_fused,
inner_gemm,
xi_outer,
yi_outer,
k_outer,
xi_inner_outer,
xi_inner_inner,
yi_inner,
k_inner,
)
s[C_interleaved].tensorize(xi_inner_inner, gemm_acc)
s[C_interleaved].unroll(xi_inner_outer)

elif is_aarch64_arm():
s[C_interleaved].reorder(yi, xi)
K = A_interleaved_input.shape[2]
assert in_type in ["int8", "uint8"], "Only int8 and uint8 gemm are supported"
unroll = cfg["gemm_quantized_unroll"].val
interleave = cfg["gemm_quantized_interleave"].val
gemm = gemm_quantized(M, N, K, unroll, interleave, in_type, out_type)
s[C_interleaved].pragma(
b_outer_gemm_fused,
"import_llvm",
gemm_quantized_impl(M, N, K, unroll, interleave, in_type),
)
s[C_interleaved].tensorize(yi, gemm)
if in_type in ["int8", "uint8"]:
if is_mmla_available():
gemm_acc = gemm_acc_2x2_int8_int8_int32(in_type)
xi_inner, yi_inner = C_interleaved.op.axis[-2:]
k_outer, k_inner = s[C_interleaved].split(k, 8)
s[C_interleaved].reorder(
b_outer_gemm_fused, inner_gemm, k_outer, xi, yi, xi_inner, yi_inner, k_inner
)
s[C_interleaved].tensorize(xi_inner, gemm_acc)
s[C_interleaved].unroll(xi)
s[C_interleaved].unroll(yi)
elif is_dotprod_available():
gemm_acc = gemm_acc_4x4_int8_int8_int32(in_type)
xi_outer, yi_outer, xi_inner, yi_inner = s[C_interleaved].tile(
xi, yi, x_factor=8, y_factor=4
)
k_outer, k_inner = s[C_interleaved].split(k, 4)
xi_inner_outer, xi_inner_inner = s[C_interleaved].split(xi_inner, 4)
s[C_interleaved].reorder(
b_outer_gemm_fused,
inner_gemm,
xi_outer,
yi_outer,
k_outer,
xi_inner_outer,
xi_inner_inner,
yi_inner,
k_inner,
)
s[C_interleaved].tensorize(xi_inner_inner, gemm_acc)
s[C_interleaved].unroll(xi_inner_outer)

elif is_aarch64_arm():
s[C_interleaved].reorder(yi, xi)
K = A_interleaved_input.shape[2]
unroll = cfg["gemm_quantized_unroll"].val
interleave = cfg["gemm_quantized_interleave"].val
gemm = gemm_quantized(M, N, K, unroll, interleave, in_type, out_type)
s[C_interleaved].pragma(
b_outer_gemm_fused,
"import_llvm",
gemm_quantized_impl(M, N, K, unroll, interleave, in_type),
)
s[C_interleaved].tensorize(yi, gemm)

# Output transform
if out != final_out:
Expand Down

0 comments on commit 9a801b4

Please sign in to comment.