Skip to content

Commit

Permalink
Fix: Report error if reserve hasn't been called
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Aug 1, 2023
1 parent 2779ffc commit f94f358
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
21 changes: 16 additions & 5 deletions include/usearch/index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1977,8 +1977,9 @@ class index_gt {
key_t key, value_at&& value, metric_at&& metric, //
index_add_config_t config = {}, callback_at&& callback = callback_at{}) usearch_noexcept_m {

usearch_assert_m(!is_immutable(), "Can't add to an immutable index");
add_result_t result;
if (is_immutable())
return result.failed("Can't add to an immutable index");

// Make sure we have enough local memory to perform this request
context_t& context = contexts_[config.thread];
Expand All @@ -2001,14 +2002,24 @@ class index_gt {
level_t max_level_copy = max_level_; // Copy under lock
std::size_t entry_idx_copy = entry_slot_; // Copy under lock
level_t target_level = choose_random_level_(context.level_generator);
if (target_level <= max_level_copy)
new_level_lock.unlock();

// Make sure we are not overflowing
std::size_t capacity = nodes_capacity_.load();
std::size_t new_slot = nodes_count_.fetch_add(1);
if (new_slot >= capacity) {
nodes_count_.fetch_sub(1);
return result.failed("Reserve capacity ahead of insertions!");
}

// Allocate the neighbors
node_t node = node_make_(key, target_level);
if (!node)
if (!node) {
nodes_count_.fetch_sub(1);
return result.failed("Out of memory!");
std::size_t new_slot = nodes_count_.fetch_add(1);
}
if (target_level <= max_level_copy)
new_level_lock.unlock();

nodes_[new_slot] = node;
result.new_size = new_slot + 1;
result.slot = new_slot;
Expand Down
3 changes: 2 additions & 1 deletion python/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ static void add_typed_to_index( //
config.thread = thread_idx;
key_t key = *reinterpret_cast<key_t const*>(keys_data + task_idx * keys_info.strides[0]);
scalar_at const* vector = reinterpret_cast<scalar_at const*>(vectors_data + task_idx * vectors_info.strides[0]);
index.add(key, vector, config).error.raise();
dense_add_result_t result = index.add(key, vector, config);
result.error.raise();
if (PyErr_CheckSignals() != 0)
throw py::error_already_set();
});
Expand Down

0 comments on commit f94f358

Please sign in to comment.