Skip to content

Commit

Permalink
Add: Changing the metric at runtime
Browse files Browse the repository at this point in the history
This feature useful for cases, when a weighted sum of metrics is used
and the weights change from query to query.
  • Loading branch information
ashvardanian committed Aug 5, 2023
1 parent 1e9dbe5 commit d7bfac7
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 9 deletions.
2 changes: 2 additions & 0 deletions include/usearch/index_dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,8 @@ class index_dense_gt {

// The metric and its properties
metric_t const& metric() const { return metric_; }
void change_metric(metric_t metric) { metric_ = std::move(metric); }

scalar_kind_t scalar_kind() const noexcept { return metric_.scalar_kind(); }
std::size_t bytes_per_vector() const noexcept { return metric_.bytes_per_vector(); }
std::size_t scalar_words() const noexcept { return metric_.scalar_words(); }
Expand Down
21 changes: 19 additions & 2 deletions python/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ metric_t udf(
static dense_index_py_t make_index( //
std::size_t dimensions, //
scalar_kind_t scalar_kind, //
metric_kind_t metric_kind, //
std::size_t connectivity, //
std::size_t expansion_add, //
std::size_t expansion_search, //
metric_kind_t metric_kind, //
metric_signature_t metric_signature, //
std::uintptr_t metric_uintptr) {

Expand Down Expand Up @@ -696,10 +696,10 @@ PYBIND11_MODULE(compiled, m) {
py::kw_only(), //
py::arg("ndim") = 0, //
py::arg("dtype") = scalar_kind_t::f32_k, //
py::arg("metric_kind") = metric_kind_t::cos_k, //
py::arg("connectivity") = default_connectivity(), //
py::arg("expansion_add") = default_expansion_add(), //
py::arg("expansion_search") = default_expansion_search(), //
py::arg("metric_kind") = metric_kind_t::cos_k, //
py::arg("metric_signature") = metric_signature_t::array_array_k, //
py::arg("metric_pointer") = 0 //
);
Expand Down Expand Up @@ -781,6 +781,23 @@ PYBIND11_MODULE(compiled, m) {
i.def_property("expansion_add", &dense_index_py_t::expansion_add, &dense_index_py_t::change_expansion_add);
i.def_property("expansion_search", &dense_index_py_t::expansion_search, &dense_index_py_t::change_expansion_search);

i.def(
"change_metric",
[](dense_index_py_t& index, metric_kind_t metric_kind, metric_signature_t metric_signature,
std::uintptr_t metric_uintptr) {
scalar_kind_t scalar_kind = index.scalar_kind();
std::size_t dimensions = index.dimensions();
metric_t metric = //
metric_uintptr //
? udf(metric_kind, metric_signature, metric_uintptr, scalar_kind, dimensions)
: metric_t(dimensions, metric_kind, scalar_kind);
index.change_metric(std::move(metric));
},
py::arg("metric_kind") = metric_kind_t::cos_k, //
py::arg("metric_signature") = metric_signature_t::array_array_k, //
py::arg("metric_pointer") = 0 //
);

i.def_property_readonly("keys", &get_all_keys<dense_index_py_t>);
i.def("get_keys", &get_keys<dense_index_py_t>, py::arg("offset") = 0,
py::arg("limit") = std::numeric_limits<std::size_t>::max());
Expand Down
47 changes: 40 additions & 7 deletions python/usearch/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ class Index:

def __init__(
self,
*,
ndim: int = 0,
metric: Union[str, MetricKind, CompiledMetric] = MetricKind.Cos,
dtype: Optional[Union[str, ScalarKind]] = None,
Expand Down Expand Up @@ -441,12 +442,12 @@ def __init__(
self._compiled = _CompiledIndex(
ndim=ndim,
dtype=dtype,
metric_kind=self._metric_kind,
metric_pointer=self._metric_pointer,
metric_signature=self._metric_signature,
connectivity=connectivity,
expansion_add=expansion_add,
expansion_search=expansion_search,
metric_kind=self._metric_kind,
metric_pointer=self._metric_pointer,
metric_signature=self._metric_signature,
)

self.path = path
Expand Down Expand Up @@ -643,6 +644,28 @@ def ndim(self) -> int:
def metric(self) -> Union[MetricKind, CompiledMetric]:
return self._metric_jit if self._metric_jit else self._metric_kind

@metric.setter
def metric(self, metric: Union[str, MetricKind, CompiledMetric]):
metric = _normalize_metric(metric)
if isinstance(metric, MetricKind):
metric_kind = metric
metric_pointer = 0
metric_signature = MetricSignature.ArrayArraySize
elif isinstance(metric, CompiledMetric):
metric_kind = metric.kind
metric_pointer = metric.pointer
metric_signature = metric.signature
else:
raise ValueError(
"The `metric` must be a `CompiledMetric` or a `MetricKind`"
)

return self._compiled.change_metric(
metric_kind=metric_kind,
metric_pointer=metric_pointer,
metric_signature=metric_signature,
)

@property
def dtype(self) -> ScalarKind:
return self._compiled.dtype
Expand Down Expand Up @@ -921,6 +944,13 @@ def search(
metric=metric,
dtype=dataset.dtype,
)
index.add(
None,
dataset,
threads=threads,
log=log,
batch_size=batch_size,
)
return index.search(
query,
k,
Expand All @@ -947,10 +977,13 @@ def __init__(self) -> None:

def search(self, query, k, **kwargs):
kwargs.pop("exact")
kwargs["metric_kind"] = metric_kind
kwargs["metric_pointer"] = metric_pointer
kwargs["metric_signature"] = metric_signature

kwargs.update(
dict(
metric_kind=metric_kind,
metric_pointer=metric_pointer,
metric_signature=metric_signature,
)
)
assert dataset.shape[1] == query.shape[1], "Number of dimensions differs"
if dataset.dtype != query.dtype:
query = query.astype(dataset.dtype)
Expand Down

0 comments on commit d7bfac7

Please sign in to comment.