From 04af15c9d6a605e9b72bf6632314f5d7fd87ce5a Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Tue, 16 Apr 2024 17:12:27 -0700 Subject: [PATCH] Added support for bit to IVFFlat --- CHANGELOG.md | 2 +- README.md | 1 + sql/vector--0.6.2--0.7.0.sql | 12 +++ sql/vector.sql | 12 +++ src/ivfbuild.c | 5 + src/ivfflat.h | 3 +- src/ivfkmeans.c | 65 ++++++++++++- src/ivfscan.c | 3 + src/ivfutils.c | 3 + test/expected/ivfflat_bit_hamming.out | 32 +++++++ test/expected/ivfflat_bit_jaccard.out | 21 ++++ test/sql/ivfflat_bit_hamming.sql | 19 ++++ test/sql/ivfflat_bit_jaccard.sql | 12 +++ test/t/035_ivfflat_bit_build_recall.pl | 128 +++++++++++++++++++++++++ 14 files changed, 315 insertions(+), 3 deletions(-) create mode 100644 test/expected/ivfflat_bit_hamming.out create mode 100644 test/expected/ivfflat_bit_jaccard.out create mode 100644 test/sql/ivfflat_bit_hamming.sql create mode 100644 test/sql/ivfflat_bit_jaccard.sql create mode 100644 test/t/035_ivfflat_bit_build_recall.pl diff --git a/CHANGELOG.md b/CHANGELOG.md index 96d7c280..4c4e8b35 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ - Added `halfvec` type - Added `sparsevec` type -- Added support for `bit` vectors to HNSW +- Added support for indexing `bit` type - Added `binary_quantize` function - Added `hamming_distance` function - Added `jaccard_distance` function diff --git a/README.md b/README.md index 36908ad8..2f83350e 100644 --- a/README.md +++ b/README.md @@ -353,6 +353,7 @@ Supported types are: - `vector` - up to 2,000 dimensions - `halfvec` - up to 4,000 dimensions (unreleased) +- `bit` - up to 64,000 dimensions (unreleased) ### Query Options diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index 97b352fc..05677758 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -33,6 +33,18 @@ CREATE OPERATOR <%> ( COMMUTATOR = '<%>' ); +CREATE OPERATOR CLASS bit_hamming_ops + FOR TYPE bit USING ivfflat AS + OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, + FUNCTION 1 hamming_distance(bit, bit), + FUNCTION 3 hamming_distance(bit, bit); + +CREATE OPERATOR CLASS bit_jaccard_ops + FOR TYPE bit USING ivfflat AS + OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops, + FUNCTION 1 jaccard_distance(bit, bit), + FUNCTION 3 jaccard_distance(bit, bit); + CREATE OPERATOR CLASS bit_hamming_ops FOR TYPE bit USING hnsw AS OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, diff --git a/sql/vector.sql b/sql/vector.sql index 96fa64c9..d2fe6873 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -326,6 +326,18 @@ CREATE OPERATOR <%> ( -- bit opclasses +CREATE OPERATOR CLASS bit_hamming_ops + FOR TYPE bit USING ivfflat AS + OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, + FUNCTION 1 hamming_distance(bit, bit), + FUNCTION 3 hamming_distance(bit, bit); + +CREATE OPERATOR CLASS bit_jaccard_ops + FOR TYPE bit USING ivfflat AS + OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops, + FUNCTION 1 jaccard_distance(bit, bit), + FUNCTION 3 jaccard_distance(bit, bit); + CREATE OPERATOR CLASS bit_hamming_ops FOR TYPE bit USING hnsw AS OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, diff --git a/src/ivfbuild.c b/src/ivfbuild.c index 8cd54b6d..0350e9b4 100644 --- a/src/ivfbuild.c +++ b/src/ivfbuild.c @@ -6,6 +6,7 @@ #include "access/tableam.h" #include "access/parallel.h" #include "access/xact.h" +#include "bitvector.h" #include "catalog/index.h" #include "catalog/pg_operator_d.h" #include "catalog/pg_type_d.h" @@ -324,6 +325,8 @@ GetMaxDimensions(IvfflatType type) if (type == IVFFLAT_TYPE_HALFVEC) maxDimensions *= 2; + else if (type == IVFFLAT_TYPE_BIT) + maxDimensions *= 32; return maxDimensions; } @@ -338,6 +341,8 @@ GetItemSize(IvfflatType type, int dimensions) return VECTOR_SIZE(dimensions); else if (type == IVFFLAT_TYPE_HALFVEC) return HALFVEC_SIZE(dimensions); + else if (type == IVFFLAT_TYPE_BIT) + return VARBITTOTALLEN(dimensions); else elog(ERROR, "Unsupported type"); } diff --git a/src/ivfflat.h b/src/ivfflat.h index 3ca91783..c6e70996 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -46,7 +46,8 @@ typedef enum IvfflatType { IVFFLAT_TYPE_VECTOR, - IVFFLAT_TYPE_HALFVEC + IVFFLAT_TYPE_HALFVEC, + IVFFLAT_TYPE_BIT } IvfflatType; /* Build phases */ diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index 37ea549b..c0e5f361 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -3,10 +3,12 @@ #include #include +#include "bitvector.h" #include "halfutils.h" #include "halfvec.h" #include "ivfflat.h" #include "miscadmin.h" +#include "utils/builtins.h" #include "utils/datum.h" #include "utils/memutils.h" #include "vector.h" @@ -134,6 +136,15 @@ CompareHalfVectors(const void *a, const void *b) return halfvec_cmp_internal((HalfVector *) a, (HalfVector *) b); } +/* + * Compare bit vectors + */ +static int +CompareBitVectors(const void *a, const void *b) +{ + return DirectFunctionCall2(bitcmp, VarBitPGetDatum((VarBit *) a), VarBitPGetDatum((VarBit *) b)); +} + /* * Quick approach if we have little data */ @@ -151,6 +162,8 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy qsort(samples->items, samples->length, samples->itemsize, CompareVectors); else if (type == IVFFLAT_TYPE_HALFVEC) qsort(samples->items, samples->length, samples->itemsize, CompareHalfVectors); + else if (type == IVFFLAT_TYPE_BIT) + qsort(samples->items, samples->length, samples->itemsize, CompareBitVectors); else elog(ERROR, "Unsupported type"); @@ -191,6 +204,16 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy for (int j = 0; j < dimensions; j++) vec->x[j] = Float4ToHalfUnchecked((float) RandomDouble()); } + else if (type == IVFFLAT_TYPE_BIT) + { + VarBit *vec = DatumGetVarBitP(center); + + SET_VARSIZE(vec, VARBITTOTALLEN(dimensions)); + VARBITLEN(vec) = dimensions; + + for (int j = 0; j < dimensions; j++) + VARBITS(vec)[j / dimensions] |= (RandomDouble() > 0.5 ? 1 : 0) << (7 - (j % 8)); + } else elog(ERROR, "Unsupported type"); @@ -263,6 +286,17 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe aggCenter->x[k] += HalfToFloat4(vec->x[k]); } } + else if (type == IVFFLAT_TYPE_BIT) + { + for (int j = 0; j < numSamples; j++) + { + Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]); + VarBit *vec = (VarBit *) VectorArrayGet(samples, j); + + for (int k = 0; k < dimensions; k++) + aggCenter->x[k] += (float) (((VARBITS(vec)[k / 8]) >> (7 - (k % 8))) & 0x01); + } + } else elog(ERROR, "Unsupported type"); @@ -308,6 +342,21 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe newCenter->x[k] = Float4ToHalfUnchecked(aggCenter->x[k]); } } + else if (type == IVFFLAT_TYPE_BIT) + { + for (int j = 0; j < numCenters; j++) + { + Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j); + VarBit *newCenter = (VarBit *) VectorArrayGet(newCenters, j); + unsigned char *nx = VARBITS(newCenter); + + for (uint32 k = 0; k < VARBITBYTES(newCenter); k++) + nx[k] = 0; + + for (int k = 0; k < dimensions; k++) + nx[k / 8] |= (aggCenter->x[k] > 0.5) << (7 - (k % 8)); + } + } /* Normalize if needed */ if (normprocinfo != NULL) @@ -425,6 +474,18 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp vec->dim = dimensions; } } + else if (type == IVFFLAT_TYPE_BIT) + { + newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize); + + for (int j = 0; j < numCenters; j++) + { + VarBit *vec = (VarBit *) VectorArrayGet(newCenters, j); + + SET_VARSIZE(vec, VARBITTOTALLEN(dimensions)); + VARBITLEN(vec) = dimensions; + } + } else elog(ERROR, "Unsupported type"); @@ -642,7 +703,7 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type) elog(ERROR, "Infinite value detected. Please report a bug."); } } - else + else if (type != IVFFLAT_TYPE_BIT) elog(ERROR, "Unsupported type"); } @@ -652,6 +713,8 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type) qsort(centers->items, centers->length, centers->itemsize, CompareVectors); else if (type == IVFFLAT_TYPE_HALFVEC) qsort(centers->items, centers->length, centers->itemsize, CompareHalfVectors); + else if (type == IVFFLAT_TYPE_BIT) + qsort(centers->items, centers->length, centers->itemsize, CompareBitVectors); else elog(ERROR, "Unsupported type"); diff --git a/src/ivfscan.c b/src/ivfscan.c index 2b847bc8..7e45ef9f 100644 --- a/src/ivfscan.c +++ b/src/ivfscan.c @@ -3,6 +3,7 @@ #include #include "access/relscan.h" +#include "bitvector.h" #include "catalog/pg_operator_d.h" #include "catalog/pg_type_d.h" #include "halfvec.h" @@ -195,6 +196,8 @@ GetScanValue(IndexScanDesc scan) value = PointerGetDatum(InitVector(so->dimensions)); else if (type == IVFFLAT_TYPE_HALFVEC) value = PointerGetDatum(InitHalfVector(so->dimensions)); + else if (type == IVFFLAT_TYPE_BIT) + value = PointerGetDatum(InitBitVector(so->dimensions)); else elog(ERROR, "Unsupported type"); } diff --git a/src/ivfutils.c b/src/ivfutils.c index 6d416477..26e085d7 100644 --- a/src/ivfutils.c +++ b/src/ivfutils.c @@ -73,6 +73,9 @@ IvfflatGetType(Relation index) Form_pg_type type; IvfflatType result; + if (typid == BITOID) + return IVFFLAT_TYPE_BIT; + tuple = SearchSysCache1(TYPEOID, ObjectIdGetDatum(typid)); if (!HeapTupleIsValid(tuple)) elog(ERROR, "cache lookup failed for type %u", typid); diff --git a/test/expected/ivfflat_bit_hamming.out b/test/expected/ivfflat_bit_hamming.out new file mode 100644 index 00000000..cee8317e --- /dev/null +++ b/test/expected/ivfflat_bit_hamming.out @@ -0,0 +1,32 @@ +SET enable_seqscan = off; +CREATE TABLE t (val bit(3)); +INSERT INTO t (val) VALUES (B'000'), (B'100'), (B'111'), (NULL); +CREATE INDEX ON t USING ivfflat (val bit_hamming_ops) WITH (lists = 1); +INSERT INTO t (val) VALUES (B'110'); +SELECT * FROM t ORDER BY val <~> B'111'; + val +----- + 111 + 110 + 100 + 000 +(4 rows) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <~> (SELECT NULL::bit)) t2; + count +------- + 4 +(1 row) + +DROP TABLE t; +-- TODO move +CREATE TABLE t (val varbit(3)); +CREATE INDEX ON t USING ivfflat (val bit_hamming_ops) WITH (lists = 1); +ERROR: type not supported for ivfflat index +CREATE INDEX ON t USING ivfflat ((val::bit(3)) bit_hamming_ops) WITH (lists = 1); +NOTICE: ivfflat index created with little data +DETAIL: This will cause low recall. +HINT: Drop the index until the table has more data. +CREATE INDEX ON t USING ivfflat ((val::bit(64001)) bit_hamming_ops) WITH (lists = 1); +ERROR: column cannot have more than 64000 dimensions for ivfflat index +DROP TABLE t; diff --git a/test/expected/ivfflat_bit_jaccard.out b/test/expected/ivfflat_bit_jaccard.out new file mode 100644 index 00000000..0eb6b8df --- /dev/null +++ b/test/expected/ivfflat_bit_jaccard.out @@ -0,0 +1,21 @@ +SET enable_seqscan = off; +CREATE TABLE t (val bit(4)); +INSERT INTO t (val) VALUES (B'0000'), (B'1100'), (B'1111'), (NULL); +CREATE INDEX ON t USING ivfflat (val bit_jaccard_ops) WITH (lists = 1); +INSERT INTO t (val) VALUES (B'1110'); +SELECT * FROM t ORDER BY val <%> B'1111'; + val +------ + 1111 + 1110 + 1100 + 0000 +(4 rows) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <%> (SELECT NULL::bit)) t2; + count +------- + 4 +(1 row) + +DROP TABLE t; diff --git a/test/sql/ivfflat_bit_hamming.sql b/test/sql/ivfflat_bit_hamming.sql new file mode 100644 index 00000000..6c697ae1 --- /dev/null +++ b/test/sql/ivfflat_bit_hamming.sql @@ -0,0 +1,19 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val bit(3)); +INSERT INTO t (val) VALUES (B'000'), (B'100'), (B'111'), (NULL); +CREATE INDEX ON t USING ivfflat (val bit_hamming_ops) WITH (lists = 1); + +INSERT INTO t (val) VALUES (B'110'); + +SELECT * FROM t ORDER BY val <~> B'111'; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <~> (SELECT NULL::bit)) t2; + +DROP TABLE t; + +-- TODO move +CREATE TABLE t (val varbit(3)); +CREATE INDEX ON t USING ivfflat (val bit_hamming_ops) WITH (lists = 1); +CREATE INDEX ON t USING ivfflat ((val::bit(3)) bit_hamming_ops) WITH (lists = 1); +CREATE INDEX ON t USING ivfflat ((val::bit(64001)) bit_hamming_ops) WITH (lists = 1); +DROP TABLE t; diff --git a/test/sql/ivfflat_bit_jaccard.sql b/test/sql/ivfflat_bit_jaccard.sql new file mode 100644 index 00000000..8d8b6c98 --- /dev/null +++ b/test/sql/ivfflat_bit_jaccard.sql @@ -0,0 +1,12 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val bit(4)); +INSERT INTO t (val) VALUES (B'0000'), (B'1100'), (B'1111'), (NULL); +CREATE INDEX ON t USING ivfflat (val bit_jaccard_ops) WITH (lists = 1); + +INSERT INTO t (val) VALUES (B'1110'); + +SELECT * FROM t ORDER BY val <%> B'1111'; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <%> (SELECT NULL::bit)) t2; + +DROP TABLE t; diff --git a/test/t/035_ivfflat_bit_build_recall.pl b/test/t/035_ivfflat_bit_build_recall.pl new file mode 100644 index 00000000..4a9607a6 --- /dev/null +++ b/test/t/035_ivfflat_bit_build_recall.pl @@ -0,0 +1,128 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; +my $dim = 52; +my $max = 2**$dim; + +sub test_recall +{ + my ($probes, $min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = $probes; + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator $queries[0] LIMIT $limit; + )); + like($explain, qr/Index Scan using idx on tst/); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = $probes; + SELECT i FROM tst ORDER BY v $operator $queries[$i] LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + + my @expected_ids = split("\n", $expected[$i]); + my %expected_set = map { $_ => 1 } @expected_ids; + + foreach (@actual_ids) + { + if (exists($expected_set{$_})) + { + $correct++; + } + } + + $total += $limit; + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v bit($dim));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, (random() * $max)::bigint::bit($dim) FROM generate_series(1, 100000) i;" +); + +# Generate queries +for (1 .. 20) +{ + my $r = int(rand() * $max); + push(@queries, "${r}::bigint::bit($dim)"); +} + +# Check each index type +my @operators = ("<~>", "<\%>"); +my @opclasses = ("bit_hamming_ops", "bit_jaccard_ops"); + +for my $i (0 .. $#operators) +{ + my $operator = $operators[$i]; + my $opclass = $opclasses[$i]; + + # Get exact results + @expected = (); + foreach (@queries) + { + my $res = $node->safe_psql("postgres", qq( + WITH top AS ( + SELECT v $operator $_ AS distance FROM tst ORDER BY distance LIMIT $limit + ) + SELECT i FROM tst WHERE (v $operator $_) <= (SELECT MAX(distance) FROM top) + )); + push(@expected, $res); + } + + # Build index serially + $node->safe_psql("postgres", qq( + SET max_parallel_maintenance_workers = 0; + CREATE INDEX idx ON tst USING ivfflat (v $opclass); + )); + + # Test approximate results + test_recall(1, 0.10, $operator); + test_recall(10, 0.55, $operator); + + # Test probes equals lists + test_recall(100, 1.00, $operator); + + $node->safe_psql("postgres", "DROP INDEX idx;"); + + # Build index in parallel + my ($ret, $stdout, $stderr) = $node->psql("postgres", qq( + SET client_min_messages = DEBUG; + SET min_parallel_table_scan_size = 1; + CREATE INDEX idx ON tst USING ivfflat (v $opclass); + )); + is($ret, 0, $stderr); + like($stderr, qr/using \d+ parallel workers/); + + # Test approximate results + test_recall(1, 0.10, $operator); + test_recall(10, 0.55, $operator); + + # Test probes equals lists + test_recall(100, 1.00, $operator); + + $node->safe_psql("postgres", "DROP INDEX idx;"); +} + +done_testing();