Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AVX-512 support to Hamming and Jaccard distance functions. #519

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
101 changes: 99 additions & 2 deletions src/bitvector.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#include "postgres.h"

#ifdef __AVX512VPOPCNTDQ__
#include <immintrin.h>
#endif

#include "bitvector.h"
#include "port/pg_bitutils.h"
#include "utils/varbit.h"
Expand All @@ -20,8 +24,8 @@
#endif

/* target_clones requires glibc */
#if defined(__x86_64__) && defined(__gnu_linux__) && defined(__has_attribute) && __has_attribute(target_clones) && !defined(__POPCNT__)
#define BIT_DISPATCH __attribute__((target_clones("default", "popcnt")))
#if defined(__x86_64__) && defined(__gnu_linux__) && defined(__has_attribute) && __has_attribute(target_clones)
#define BIT_DISPATCH __attribute__((target_clones("default", "popcnt", "popcnt,avx512f")))
#else
#define BIT_DISPATCH
#endif
Expand Down Expand Up @@ -55,9 +59,52 @@ CheckDims(VarBit *a, VarBit *b)
errmsg("different bit lengths %u and %u", VARBITLEN(a), VARBITLEN(b))));
}

#ifdef __AVX512VPOPCNTDQ__
static inline uint64
pg_popcount_xor(const char *buf1, const char *buf2, int bytes)
{
uint64 popcnt;
__m512i accum = _mm512_setzero_si512();

for (; bytes >= sizeof(__m512i); bytes -= sizeof(__m512i))
{
const __m512i val1 = _mm512_loadu_si512((const __m512i *) buf1);
const __m512i val2 = _mm512_loadu_si512((const __m512i *) buf2);
const __m512i diff = _mm512_xor_si512(val1, val2);
const __m512i count = _mm512_popcnt_epi64(diff);

accum = _mm512_add_epi64(accum, count);
buf1 += sizeof(__m512i);
buf2 += sizeof(__m512i);
}
popcnt = _mm512_reduce_add_epi64(accum);

for (; bytes >= sizeof(uint64); bytes -= sizeof(uint64))
{
const uint64 *word1 = (const uint64 *) buf1;
const uint64 *word2 = (const uint64 *) buf2;

popcnt += popcount64(*word1 ^ *word2);
buf1 += sizeof(uint64);
buf2 += sizeof(uint64);
}

for (int i = 0; i < bytes; i++)
popcnt += pg_number_of_ones[(unsigned char) (*buf1++ ^ *buf2++)];

return popcnt;
}
#endif

BIT_DISPATCH static uint64
BitHammingDistance(uint32 bytes, unsigned char *ax, unsigned char *bx)
{
#ifdef __AVX512VPOPCNTDQ__

return pg_popcount_xor((const char *) ax, (const char *) bx, bytes);

#else

uint64 distance = 0;
uint32 i;
uint32 count = (bytes / 8) * 8;
Expand All @@ -77,6 +124,8 @@ BitHammingDistance(uint32 bytes, unsigned char *ax, unsigned char *bx)
distance += pg_number_of_ones[ax[i] ^ bx[i]];

return distance;

#endif
}

/*
Expand All @@ -94,12 +143,58 @@ hamming_distance(PG_FUNCTION_ARGS)
PG_RETURN_FLOAT8((double) BitHammingDistance(VARBITBYTES(a), VARBITS(a), VARBITS(b)));
}

#ifdef __AVX512VPOPCNTDQ__
static inline uint64
pg_popcount_and(const char *buf1, const char *buf2, int bytes)
{
uint64 popcnt;
__m512i accum = _mm512_setzero_si512();

for (; bytes >= sizeof(__m512i); bytes -= sizeof(__m512i))
{
const __m512i val1 = _mm512_loadu_si512((const __m512i *) buf1);
const __m512i val2 = _mm512_loadu_si512((const __m512i *) buf2);
const __m512i diff = _mm512_and_si512(val1, val2);
const __m512i count = _mm512_popcnt_epi64(diff);

accum = _mm512_add_epi64(accum, count);
buf1 += sizeof(__m512i);
buf2 += sizeof(__m512i);
}
popcnt = _mm512_reduce_add_epi64(accum);

for (; bytes >= sizeof(uint64); bytes -= sizeof(uint64))
{
const uint64 *word1 = (const uint64 *) buf1;
const uint64 *word2 = (const uint64 *) buf2;

popcnt += popcount64(*word1 & *word2);
buf1 += sizeof(uint64);
buf2 += sizeof(uint64);
}

for (int i = 0; i < bytes; i++)
popcnt += pg_number_of_ones[(unsigned char) (*buf1++ & *buf2++)];

return popcnt;
}
#endif

BIT_DISPATCH static double
BitJaccardDistance(uint32 bytes, unsigned char *ax, unsigned char *bx)
{
uint64 ab = 0;
uint64 aa = 0;
uint64 bb = 0;

#ifdef __AVX512VPOPCNTDQ__

aa = pg_popcount((const char *) ax, bytes);
bb = pg_popcount((const char *) bx, bytes);
ab = pg_popcount_and((const char *) ax, (const char *) bx, bytes);

#else

uint32 i;
uint32 count = (bytes / 8) * 8;

Expand All @@ -123,6 +218,8 @@ BitJaccardDistance(uint32 bytes, unsigned char *ax, unsigned char *bx)
bb += pg_number_of_ones[bx[i]];
}

#endif

if (ab == 0)
return 1;
else
Expand Down