Skip to content

Commit

Permalink
Added HNSW index type - #181
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Aug 8, 2023
1 parent 19a6c81 commit 51d292c
Show file tree
Hide file tree
Showing 29 changed files with 3,927 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -1,5 +1,6 @@
## 0.5.0 (unreleased)

- Added HNSW index type
- Added support for parallel index builds
- Added `l1_distance` function
- Added element-wise multiplication for vectors
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Expand Up @@ -3,7 +3,7 @@ EXTVERSION = 0.4.4

MODULE_big = vector
DATA = $(wildcard sql/*--*.sql)
OBJS = src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o
OBJS = src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o

TESTS = $(wildcard test/sql/*.sql)
REGRESS = $(patsubst test/sql/%.sql,%,$(TESTS))
Expand Down
2 changes: 1 addition & 1 deletion Makefile.win
@@ -1,7 +1,7 @@
EXTENSION = vector
EXTVERSION = 0.4.4

OBJS = src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj
OBJS = src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj

REGRESS = btree cast copy functions input ivfflat_cosine ivfflat_ip ivfflat_l2 ivfflat_options ivfflat_unlogged
REGRESS_OPTS = --inputdir=test --load-extension=vector
Expand Down
23 changes: 23 additions & 0 deletions sql/vector--0.4.4--0.5.0.sql
Expand Up @@ -18,3 +18,26 @@ CREATE AGGREGATE sum(vector) (
COMBINEFUNC = vector_add,
PARALLEL = SAFE
);

CREATE FUNCTION hnswhandler(internal) RETURNS index_am_handler
AS 'MODULE_PATHNAME' LANGUAGE C;

CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnswhandler;

COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method';

CREATE OPERATOR CLASS vector_l2_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_l2_squared_distance(vector, vector);

CREATE OPERATOR CLASS vector_ip_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector);

CREATE OPERATOR CLASS vector_cosine_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector),
FUNCTION 2 vector_norm(vector);
25 changes: 24 additions & 1 deletion sql/vector.sql
Expand Up @@ -227,7 +227,7 @@ CREATE OPERATOR > (
RESTRICT = scalargtsel, JOIN = scalargtjoinsel
);

-- access method
-- access methods

CREATE FUNCTION ivfflathandler(internal) RETURNS index_am_handler
AS 'MODULE_PATHNAME' LANGUAGE C;
Expand All @@ -236,6 +236,13 @@ CREATE ACCESS METHOD ivfflat TYPE INDEX HANDLER ivfflathandler;

COMMENT ON ACCESS METHOD ivfflat IS 'ivfflat index access method';

CREATE FUNCTION hnswhandler(internal) RETURNS index_am_handler
AS 'MODULE_PATHNAME' LANGUAGE C;

CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnswhandler;

COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method';

-- opclasses

CREATE OPERATOR CLASS vector_ops
Expand Down Expand Up @@ -267,3 +274,19 @@ CREATE OPERATOR CLASS vector_cosine_ops
FUNCTION 2 vector_norm(vector),
FUNCTION 3 vector_spherical_distance(vector, vector),
FUNCTION 4 vector_norm(vector);

CREATE OPERATOR CLASS vector_l2_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_l2_squared_distance(vector, vector);

CREATE OPERATOR CLASS vector_ip_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector);

CREATE OPERATOR CLASS vector_cosine_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector),
FUNCTION 2 vector_norm(vector);
224 changes: 224 additions & 0 deletions src/hnsw.c
@@ -0,0 +1,224 @@
#include "postgres.h"

#include <float.h>
#include <math.h>

#include "access/amapi.h"
#include "commands/vacuum.h"
#include "hnsw.h"
#include "utils/guc.h"
#include "utils/selfuncs.h"

#if PG_VERSION_NUM >= 120000
#include "commands/progress.h"
#endif

int hnsw_ef_search;
static relopt_kind hnsw_relopt_kind;

/*
* Initialize index options and variables
*/
void
HnswInit(void)
{
hnsw_relopt_kind = add_reloption_kind();
add_int_reloption(hnsw_relopt_kind, "m", "Max number of connections",
HNSW_DEFAULT_M, HNSW_MIN_M, HNSW_MAX_M
#if PG_VERSION_NUM >= 130000
,AccessExclusiveLock
#endif
);
add_int_reloption(hnsw_relopt_kind, "ef_construction", "Size of the dynamic candidate list for construction",
HNSW_DEFAULT_EF_CONSTRUCTION, HNSW_MIN_EF_CONSTRUCTION, HNSW_MAX_EF_CONSTRUCTION
#if PG_VERSION_NUM >= 130000
,AccessExclusiveLock
#endif
);

DefineCustomIntVariable("hnsw.ef_search", "Sets the size of the dynamic candidate list for search",
"Valid range is 10..1000.", &hnsw_ef_search,
HNSW_DEFAULT_EF_SEARCH, HNSW_MIN_EF_SEARCH, HNSW_MAX_EF_SEARCH, PGC_USERSET, 0, NULL, NULL, NULL);
}

/*
* Get the name of index build phase
*/
#if PG_VERSION_NUM >= 120000
static char *
hnswbuildphasename(int64 phasenum)
{
switch (phasenum)
{
case PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE:
return "initializing";
case PROGRESS_HNSW_PHASE_LOAD:
return "loading tuples";
default:
return NULL;
}
}
#endif

/*
* Estimate the cost of an index scan
*/
static void
hnswcostestimate(PlannerInfo *root, IndexPath *path, double loop_count,
Cost *indexStartupCost, Cost *indexTotalCost,
Selectivity *indexSelectivity, double *indexCorrelation,
double *indexPages)
{
GenericCosts costs;
int m;
int entryLevel;
Relation index;
#if PG_VERSION_NUM < 120000
List *qinfos;
#endif

/* Never use index without order */
if (path->indexorderbys == NULL)
{
*indexStartupCost = DBL_MAX;
*indexTotalCost = DBL_MAX;
*indexSelectivity = 0;
*indexCorrelation = 0;
*indexPages = 0;
return;
}

MemSet(&costs, 0, sizeof(costs));

index = index_open(path->indexinfo->indexoid, NoLock);
m = HnswGetM(index);
index_close(index, NoLock);

/* Approximate entry level */
entryLevel = (int) -log(1.0 / path->indexinfo->tuples) * HnswGetMl(m);

/* TODO Improve estimate of visited tuples (currently underestimates) */
/* Account for number of tuples (or entry level), m, and ef_search */
costs.numIndexTuples = (entryLevel + 2) * m;

#if PG_VERSION_NUM >= 120000
genericcostestimate(root, path, loop_count, &costs);
#else
qinfos = deconstruct_indexquals(path);
genericcostestimate(root, path, loop_count, qinfos, &costs);
#endif

/* Use total cost since most work happens before first tuple is returned */
*indexStartupCost = costs.indexTotalCost;
*indexTotalCost = costs.indexTotalCost;
*indexSelectivity = costs.indexSelectivity;
*indexCorrelation = costs.indexCorrelation;
*indexPages = costs.numIndexPages;
}

/*
* Parse and validate the reloptions
*/
static bytea *
hnswoptions(Datum reloptions, bool validate)
{
static const relopt_parse_elt tab[] = {
{"m", RELOPT_TYPE_INT, offsetof(HnswOptions, m)},
{"ef_construction", RELOPT_TYPE_INT, offsetof(HnswOptions, efConstruction)},
};

#if PG_VERSION_NUM >= 130000
return (bytea *) build_reloptions(reloptions, validate,
hnsw_relopt_kind,
sizeof(HnswOptions),
tab, lengthof(tab));
#else
relopt_value *options;
int numoptions;
HnswOptions *rdopts;

options = parseRelOptions(reloptions, validate, hnsw_relopt_kind, &numoptions);
rdopts = allocateReloptStruct(sizeof(HnswOptions), options, numoptions);
fillRelOptions((void *) rdopts, sizeof(HnswOptions), options, numoptions,
validate, tab, lengthof(tab));

return (bytea *) rdopts;
#endif
}

/*
* Validate catalog entries for the specified operator class
*/
static bool
hnswvalidate(Oid opclassoid)
{
return true;
}

/*
* Define index handler
*
* See https://www.postgresql.org/docs/current/index-api.html
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(hnswhandler);
Datum
hnswhandler(PG_FUNCTION_ARGS)
{
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);

amroutine->amstrategies = 0;
amroutine->amsupport = 2;
#if PG_VERSION_NUM >= 130000
amroutine->amoptsprocnum = 0;
#endif
amroutine->amcanorder = false;
amroutine->amcanorderbyop = true;
amroutine->amcanbackward = false; /* can change direction mid-scan */
amroutine->amcanunique = false;
amroutine->amcanmulticol = false;
amroutine->amoptionalkey = true;
amroutine->amsearcharray = false;
amroutine->amsearchnulls = false;
amroutine->amstorage = false;
amroutine->amclusterable = false;
amroutine->ampredlocks = false;
amroutine->amcanparallel = false;
amroutine->amcaninclude = false;
#if PG_VERSION_NUM >= 130000
amroutine->amusemaintenanceworkmem = false; /* not used during VACUUM */
amroutine->amparallelvacuumoptions = VACUUM_OPTION_PARALLEL_BULKDEL;
#endif
amroutine->amkeytype = InvalidOid;

/* Interface functions */
amroutine->ambuild = hnswbuild;
amroutine->ambuildempty = hnswbuildempty;
amroutine->aminsert = hnswinsert;
amroutine->ambulkdelete = hnswbulkdelete;
amroutine->amvacuumcleanup = hnswvacuumcleanup;
amroutine->amcanreturn = NULL; /* tuple not included in heapsort */
amroutine->amcostestimate = hnswcostestimate;
amroutine->amoptions = hnswoptions;
amroutine->amproperty = NULL; /* TODO AMPROP_DISTANCE_ORDERABLE */
#if PG_VERSION_NUM >= 120000
amroutine->ambuildphasename = hnswbuildphasename;
#endif
amroutine->amvalidate = hnswvalidate;
#if PG_VERSION_NUM >= 140000
amroutine->amadjustmembers = NULL;
#endif
amroutine->ambeginscan = hnswbeginscan;
amroutine->amrescan = hnswrescan;
amroutine->amgettuple = hnswgettuple;
amroutine->amgetbitmap = NULL;
amroutine->amendscan = hnswendscan;
amroutine->ammarkpos = NULL;
amroutine->amrestrpos = NULL;

/* Interface functions to support parallel index scans */
amroutine->amestimateparallelscan = NULL;
amroutine->aminitparallelscan = NULL;
amroutine->amparallelrescan = NULL;

PG_RETURN_POINTER(amroutine);
}

0 comments on commit 51d292c

Please sign in to comment.