Skip to content

Commit

Permalink
Fix: JS keys should be bigint
Browse files Browse the repository at this point in the history
Closes #178
  • Loading branch information
ashvardanian committed Aug 4, 2023
1 parent 7da44a2 commit e1fbec4
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 54 deletions.
10 changes: 5 additions & 5 deletions javascript/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
## Installation

USearch is available both for Node.js backend runtime and WASM frontend runtime.
For first use the conventional `npm install`:
For the first option, use the conventional `npm install`:

```sh
npm install usearch
Expand All @@ -18,11 +18,11 @@ wasmer install unum/usearch
## Quickstart

```js
var index = new usearch.Index({ metric: 'cos', connectivity: 16, dimensions: 3 })
index.add(42, new Float32Array([0.2, 0.6, 0.4]))
var results = index.search(new Float32Array([0.2, 0.6, 0.4]), 10)
var index = new usearch.Index({ metric: 'cos', connectivity: 16n, dimensions: 3n })
index.add(42n, new Float32Array([0.2, 0.6, 0.4]))
var results = index.search(new Float32Array([0.2, 0.6, 0.4]), 10n)

assert.equal(index.size(), 1)
assert.equal(index.size(), 1n)
assert.deepEqual(results.keys, new BigUint64Array([42n]))
assert.deepEqual(results.distances, new Float32Array([0]))
```
Expand Down
57 changes: 38 additions & 19 deletions javascript/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,22 +74,28 @@ Index::Index(Napi::CallbackInfo const& ctx) : Napi::ObjectWrap<Index>(ctx) {
return;
}

bool lossless = true;
Napi::Object params = ctx[0].As<Napi::Object>();
std::size_t dimensions = params.Has("dimensions") ? params.Get("dimensions").As<Napi::Number>().Uint32Value() : 0;
std::size_t dimensions =
params.Has("dimensions") ? params.Get("dimensions").As<Napi::BigInt>().Uint64Value(&lossless) : 0;

index_limits_t limits;
std::size_t connectivity = default_connectivity();
std::size_t expansion_add = default_expansion_add();
std::size_t expansion_search = default_expansion_search();

if (params.Has("capacity"))
limits.members = params.Get("capacity").As<Napi::Number>().Uint32Value();
limits.members = params.Get("capacity").As<Napi::BigInt>().Uint64Value(&lossless);
if (params.Has("connectivity"))
connectivity = params.Get("connectivity").As<Napi::Number>().Uint32Value();
connectivity = params.Get("connectivity").As<Napi::BigInt>().Uint64Value(&lossless);
if (params.Has("expansion_add"))
expansion_add = params.Get("expansion_add").As<Napi::Number>().Uint32Value();
expansion_add = params.Get("expansion_add").As<Napi::BigInt>().Uint64Value(&lossless);
if (params.Has("expansion_search"))
expansion_search = params.Get("expansion_search").As<Napi::Number>().Uint32Value();
expansion_search = params.Get("expansion_search").As<Napi::BigInt>().Uint64Value(&lossless);
if (!lossless) {
Napi::TypeError::New(env, "Arguments must be unsigned integers").ThrowAsJavaScriptException();
return;
}

scalar_kind_t quantization = scalar_kind_t::f32_k;
if (params.Has("quantization")) {
Expand Down Expand Up @@ -121,14 +127,16 @@ Index::Index(Napi::CallbackInfo const& ctx) : Napi::ObjectWrap<Index>(ctx) {
}

Napi::Value Index::GetDimensions(Napi::CallbackInfo const& ctx) {
return Napi::Number::New(ctx.Env(), native_->dimensions());
return Napi::BigInt::New(ctx.Env(), static_cast<std::uint64_t>(native_->dimensions()));
}
Napi::Value Index::GetSize(Napi::CallbackInfo const& ctx) {
return Napi::BigInt::New(ctx.Env(), static_cast<std::uint64_t>(native_->size()));
}
Napi::Value Index::GetSize(Napi::CallbackInfo const& ctx) { return Napi::Number::New(ctx.Env(), native_->size()); }
Napi::Value Index::GetConnectivity(Napi::CallbackInfo const& ctx) {
return Napi::Number::New(ctx.Env(), native_->connectivity());
return Napi::BigInt::New(ctx.Env(), static_cast<std::uint64_t>(native_->connectivity()));
}
Napi::Value Index::GetCapacity(Napi::CallbackInfo const& ctx) {
return Napi::Number::New(ctx.Env(), native_->capacity());
return Napi::BigInt::New(ctx.Env(), static_cast<std::uint64_t>(native_->capacity()));
}

void Index::Save(Napi::CallbackInfo const& ctx) {
Expand Down Expand Up @@ -191,8 +199,12 @@ void Index::Add(Napi::CallbackInfo const& ctx) {
using key_t = typename index_dense_t::key_t;
std::size_t index_dimensions = native_->dimensions();

auto add = [&](Napi::Number key_js, Napi::Float32Array vector_js) {
key_t key = key_js.Uint32Value();
auto add = [&](Napi::BigInt key_js, Napi::Float32Array vector_js) {
bool lossless = true;
key_t key = static_cast<key_t>(key_js.Uint64Value(&lossless));
if (!lossless)
return Napi::TypeError::New(env, "Keys must be unsigned integers").ThrowAsJavaScriptException();

float const* vector = vector_js.Data();
std::size_t dimensions = static_cast<std::size_t>(vector_js.ElementLength());

Expand Down Expand Up @@ -223,45 +235,52 @@ void Index::Add(Napi::CallbackInfo const& ctx) {
for (std::size_t i = 0; i < length; i++) {
Napi::Value key_js = keys_js[i];
Napi::Value vector_js = vectors_js[i];
add(key_js.As<Napi::Number>(), vector_js.As<Napi::Float32Array>());
add(key_js.As<Napi::BigInt>(), vector_js.As<Napi::Float32Array>());
}

} else if (ctx[0].IsNumber() && ctx[1].IsTypedArray()) {
} else if (ctx[0].IsBigInt() && ctx[1].IsTypedArray()) {
if (native_->size() + 1 >= native_->capacity())
native_->reserve(ceil2(native_->size() + 1));
add(ctx[0].As<Napi::Number>(), ctx[1].As<Napi::Float32Array>());
add(ctx[0].As<Napi::BigInt>(), ctx[1].As<Napi::Float32Array>());
} else
return Napi::TypeError::New(env, "Invalid argument type, expects integral key(s) and float vector(s)")
.ThrowAsJavaScriptException();
}

Napi::Value Index::Search(Napi::CallbackInfo const& ctx) {
Napi::Env env = ctx.Env();
if (ctx.Length() < 2 || !ctx[0].IsTypedArray() || !ctx[1].IsNumber()) {
if (ctx.Length() < 2 || !ctx[0].IsTypedArray() || !ctx[1].IsBigInt()) {
Napi::TypeError::New(env, "Expects a and the number of wanted results").ThrowAsJavaScriptException();
return {};
}

Napi::Float32Array vector_js = ctx[0].As<Napi::Float32Array>();
Napi::Number wanted_js = ctx[1].As<Napi::Number>();
Napi::BigInt wanted_js = ctx[1].As<Napi::BigInt>();

float const* vector = vector_js.Data();
std::size_t dimensions = static_cast<std::size_t>(vector_js.ElementLength());
std::uint32_t wanted = wanted_js.Uint32Value();
if (dimensions != native_->dimensions()) {
Napi::TypeError::New(env, "Wrong number of dimensions").ThrowAsJavaScriptException();
return {};
}

bool lossless = true;
std::uint64_t wanted = wanted_js.Uint64Value(&lossless);
if (!lossless) {
Napi::TypeError::New(env, "Wanted number of matches must be an unsigned integer").ThrowAsJavaScriptException();
return {};
}

using key_t = typename index_dense_t::key_t;
Napi::TypedArrayOf<key_t> matches_js = Napi::TypedArrayOf<key_t>::New(env, wanted);
static_assert(std::is_same<std::uint64_t, key_t>::value, "Matches.key interface expects BigUint64Array");
Napi::Float32Array distances_js = Napi::Float32Array::New(env, wanted);
try {
std::size_t count = native_->search(vector, wanted).dump_to(matches_js.Data(), distances_js.Data());
std::uint64_t count = native_->search(vector, wanted).dump_to(matches_js.Data(), distances_js.Data());
Napi::Object result_js = Napi::Object::New(env);
result_js.Set("keys", matches_js);
result_js.Set("distances", distances_js);
result_js.Set("count", Napi::Number::New(env, count));
result_js.Set("count", Napi::BigInt::New(env, count));
return result_js;
} catch (std::bad_alloc const&) {
Napi::TypeError::New(env, "Out of memory").ThrowAsJavaScriptException();
Expand Down
14 changes: 7 additions & 7 deletions javascript/test.js
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
var assert = require('assert');
var usearch = require('bindings')('usearch');

var index = new usearch.Index({ metric: 'l2sq', connectivity: 16, dimensions: 2 })
var index = new usearch.Index({ metric: 'l2sq', connectivity: 16n, dimensions: 2n })
assert.equal(index.connectivity(), 16)
assert.equal(index.dimensions(), 2)
assert.equal(index.size(), 0)

index.add(15, new Float32Array([10, 20]))
index.add(16, new Float32Array([10, 25]))
index.add(15n, new Float32Array([10, 20]))
index.add(16n, new Float32Array([10, 25]))
assert.equal(index.size(), 2)

var results = index.search(new Float32Array([13, 14]), 2)
var results = index.search(new Float32Array([13, 14]), 2n)
assert.deepEqual(results.keys, new BigUint64Array([15n, 16n]))
assert.deepEqual(results.distances, new Float32Array([45, 130]))

// Batch

var index2 = new usearch.Index({ metric: 'l2sq', connectivity: 16, dimensions: 2 })
var index2 = new usearch.Index({ metric: 'l2sq', connectivity: 16n, dimensions: 2n })

const keys = [15, 16]
const keys = [15n, 16n]
const vectors = [new Float32Array([10, 20]), new Float32Array([10, 25])]
index2.add(keys, vectors)
assert.equal(index.size(), 2)

var results = index.search(new Float32Array([13, 14]), 2)
var results = index.search(new Float32Array([13, 14]), 2n)
assert.deepEqual(results.keys, new BigUint64Array([15n, 16n]))
assert.deepEqual(results.distances, new Float32Array([45, 130]))

Expand Down
46 changes: 23 additions & 23 deletions javascript/usearch.d.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@

/** Search result object. */
export interface Matches {
/** The labels of the nearest neighbors found, size n*k. */
labels: BigUint64Array,
/** The disances of the nearest negihbors found, size n*k. */
/** The keys of the nearest neighbors found, size n*k. */
keys: BigUint64Array,
/** The distances of the nearest neighbors found, size n*k. */
distances: Float32Array,
/** The disances of the nearest negihbors found, size n*k. */
count: number
/** The number of matches actually written into `keys` and `distances`. */
count: bigint
}

/** K-Approximate Nearest Neighbors search index. */
Expand All @@ -15,39 +15,39 @@ export class Index {
/**
* Constructs a new index.
*
* @param {number} dimensions
* @param {bigint} dimensions
* @param {string} metric
* @param {string} quantization
* @param {number} capacity
* @param {number} connectivity
* @param {number} expansion_add
* @param {number} expansion_search
* @param {bigint} capacity
* @param {bigint} connectivity
* @param {bigint} expansion_add
* @param {bigint} expansion_search
*/
constructor(...args);

/**
* Returns the dimensionality of vectors.
* @return {number} The dimensionality of vectors.
* @return {bigint} The dimensionality of vectors.
*/
dimensions(): number;
dimensions(): bigint;

/**
* Returns the number of vectors currently indexed.
* @return {number} The number of vectors currently indexed.
* Returns the number of vectors currently indexed.
* @return {bigint} The number of vectors currently indexed.
*/
size(): number;
size(): bigint;

/**
* Returns index capacity.
* @return {numbers} The capacity of index.
* @return {bigint} The capacity of index.
*/
capacity(): number;
capacity(): bigint;

/**
* Returns connectivity.
* @return {number} The connectivity of index.
* @return {bigint} The connectivity of index.
*/
connectivity(): number;
connectivity(): bigint;

/**
* Write index to a file.
Expand All @@ -70,18 +70,18 @@ export class Index {
/**
* Add n vectors of dimension d to the index.
*
* @param {number | number[]} keys Input identifiers for every vector.
* @param {bigint | bigint[]} keys Input identifiers for every vector.
* @param {Float32Array | Float32Array[]} mat Input matrix, matrix of size n * d.
*/
add(keys: number | number[], mat: Float32Array | Float32Array[]): void;
add(keys: bigint | bigint[], mat: Float32Array | Float32Array[]): void;

/**
* Query n vectors of dimension d to the index. Return at most k vectors for each.
* If there are not enough results for a query, the result array is padded with -1s.
*
* @param {Float32Array} mat Input vectors to search, matrix of size n * d.
* @param {number} k The number of nearest neighbors to search for.
* @param {bigint} k The number of nearest neighbors to search for.
* @return {Matches} Output of the search result.
*/
search(mat: Float32Array, k: number): Matches;
search(mat: Float32Array, k: bigint): Matches;
}

0 comments on commit e1fbec4

Please sign in to comment.