Skip to content

Commit

Permalink
Add the needed changes to support AVX512VNNI
Browse files Browse the repository at this point in the history
  • Loading branch information
amitdo authored and stweil committed Nov 25, 2022
1 parent 8d0bd68 commit 232093f
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ noinst_LTLIBRARIES += libtesseract_avx512.la
endif

if HAVE_AVX512VNNI
libtesseract_avx512vnni_la_CXXFLAGS = -march=icelake-client
libtesseract_avx512vnni_la_CXXFLAGS = -mavx512vnni -mavx512vl
libtesseract_avx512vnni_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil
libtesseract_avx512vnni_la_SOURCES = src/arch/intsimdmatrixavx512vnni.cpp
libtesseract_la_LIBADD += libtesseract_avx512vnni.la
Expand Down
2 changes: 1 addition & 1 deletion configure.ac
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ case "${host_cpu}" in
AC_DEFINE([HAVE_AVX512F], [1], [Enable AVX512F instructions])
fi

AX_CHECK_COMPILE_FLAG([-march=icelake-client], [avx512vnni=true], [avx512vnni=false], [$WERROR])
AX_CHECK_COMPILE_FLAG([-mavx512vnni], [avx512vnni=true], [avx512vnni=false], [$WERROR])
AM_CONDITIONAL([HAVE_AVX512VNNI], $avx512vnni)
if $avx512vnni; then
AC_DEFINE([HAVE_AVX512VNNI], [1], [Enable AVX512VNNI instructions])
Expand Down
20 changes: 8 additions & 12 deletions src/arch/intsimdmatrixavx512vnni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

#include "intsimdmatrix.h"

#if !defined(__AVX2__)
#if !defined(__AVX512VNNI__) || !defined(__AVX512VL__)
# if defined(__i686__) || defined(__x86_64__)
# error Implementation only for AVX2 capable architectures
# error Implementation only for AVX512VNNI capable architectures
# endif
#else
# include <immintrin.h>
Expand Down Expand Up @@ -73,16 +73,12 @@ static inline void MultiplyGroup(const __m256i &rep_input, const __m256i &ones,
// Normalize the signs on rep_input, weights, so weights is always +ve.
reps = _mm256_sign_epi8(rep_input, weights);
weights = _mm256_sign_epi8(weights, weights);
// Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results,
// with adjacent pairs added.
weights = _mm256_maddubs_epi16(weights, reps);
// Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results,
// with adjacent pairs added. What we really want is a horizontal add of
// 16+16=32 bit result, but there is no such instruction, so multiply by
// 16-bit ones instead. It is probably faster than all the sign-extending,
// permuting and adding that would otherwise be required.
weights = _mm256_madd_epi16(weights, ones);
result = _mm256_add_epi32(result, weights);

// VNNI instruction. It replaces 3 AVX2 instructions:
//weights = _mm256_maddubs_epi16(weights, reps);
//weights = _mm256_madd_epi16(weights, ones);
//result = _mm256_add_epi32(result, weights);
result = _mm256_dpbusd_epi32(result, weights, reps);
}

// Load 64 bits into the bottom of a 128bit register.
Expand Down

0 comments on commit 232093f

Please sign in to comment.