Skip to content

Commit

Permalink
Add: Vector alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
mgevor committed Jul 24, 2023
1 parent 2c1f805 commit ea230e0
Showing 1 changed file with 58 additions and 32 deletions.
90 changes: 58 additions & 32 deletions include/usearch/index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2335,7 +2335,9 @@ class index_gt {
return total;
}

std::size_t memory_usage_per_node(dim_t dim, level_t level) const noexcept { return node_bytes_(dim, level); }
std::size_t memory_usage_per_node(dim_t dim, level_t level) const noexcept {
return node_capacity_bytes_(dim, level);
}

void change_metric(metric_t const& m) noexcept {
metric_ = m;
Expand Down Expand Up @@ -2424,13 +2426,11 @@ class index_gt {
// Serialize nodes one by one
for (std::size_t i = 0; i != state.size; ++i) {
node_t node = node_with_id_(i);
std::size_t node_bytes = node_bytes_(node);
std::size_t node_vector_bytes = node_vector_bytes_(node);
// Dump neighbors and vectors, as vectors may be in a disjoint location
write_chunk(node.tape(), node_bytes - node_vector_bytes);
write_chunk(node.tape(), node_head_bytes_() + node_neighbors_bytes_(node));
if (result.error)
return result;
write_chunk(node.vector(), node_vector_bytes);
write_chunk(node.vector(), node_vector_bytes_(node));
if (result.error)
return result;
progress(i, state.size);
Expand Down Expand Up @@ -2513,12 +2513,14 @@ class index_gt {
if (result.error)
return result;

std::size_t node_bytes = node_bytes_(dim, level);
node_t node = node_malloc_(dim, level);
node.label(label);
node.dim(dim);
node.level(level);
read_chunk(node.tape() + node_head_bytes_(), node_bytes - node_head_bytes_());
read_chunk(node.tape() + node_head_bytes_(), node_neighbors_bytes_(level));
if (result.error)
return result;
read_chunk(node.vector(), node_vector_bytes_(dim));
if (result.error)
return result;
nodes_[i] = node;
Expand Down Expand Up @@ -2630,10 +2632,10 @@ class index_gt {
dim_t dim = misaligned_load<dim_t>(tape + sizeof(label_t));
level_t level = misaligned_load<level_t>(tape + sizeof(label_t) + sizeof(dim_t));

std::size_t node_bytes = node_bytes_(dim, level);
std::size_t node_vector_bytes = dim * sizeof(scalar_t);
nodes_[i] = node_t{tape, (scalar_t*)(tape + node_bytes - node_vector_bytes)};
progress_bytes += node_bytes;
std::size_t node_neighbors_bytes = node_neighbors_bytes_(level);
std::size_t node_vector_bytes = node_vector_bytes_(dim);
nodes_[i] = node_t{tape, (scalar_t*)(tape + node_head_bytes_() + node_neighbors_bytes)};
progress_bytes += node_head_bytes_() + node_neighbors_bytes + node_vector_bytes;
progress(i, size);
}

Expand Down Expand Up @@ -2927,53 +2929,73 @@ class index_gt {
return pre;
}

inline std::size_t node_bytes_(node_t node) const noexcept { return node_bytes_(node.dim(), node.level()); }
inline std::size_t node_bytes_(dim_t dim, level_t level) const noexcept {
return node_head_bytes_() + pre_.neighbors_base_bytes + pre_.neighbors_bytes * level + sizeof(scalar_t) * dim;
inline std::size_t node_bytes_(node_t node) const noexcept {
return node_head_bytes_() + node_neighbors_bytes_(node.level()) + node_vector_bytes_(node.dim());
}
inline std::size_t node_capacity_bytes_(dim_t dim, level_t level) const noexcept {
std::size_t vector_space_bytes = node_vector_bytes_(dim);
vector_space_bytes += bool(vector_space_bytes) * config_.vector_alignment; // Extra space for alignment
return node_head_bytes_() + node_neighbors_bytes_(level) + vector_space_bytes;
}

using span_bytes_t = span_gt<byte_t>;
struct node_bytes_split_t {
span_bytes_t tape{};
span_bytes_t vector{};
std::size_t continuous_buffer_bytes; // Indicates the buffer size where `tape` and `vector` are stored.
// If zero, it means both are stored in different buffers

node_bytes_split_t() {}
node_bytes_split_t(span_bytes_t tape, span_bytes_t vector) noexcept : tape(tape), vector(vector) {}
node_bytes_split_t(span_bytes_t tape, span_bytes_t vector, std::size_t continuous_buffer_bytes = 0) noexcept
: tape(tape), vector(vector), continuous_buffer_bytes(continuous_buffer_bytes) {}

std::size_t memory_usage() const noexcept { return tape.size() + vector.size(); }
bool colocated() const noexcept { return tape.end() == vector.begin(); }
std::size_t memory_usage() const noexcept {
return continuous_buffer_bytes ? continuous_buffer_bytes : tape.size() + vector.size();
}
bool colocated() const noexcept { return continuous_buffer_bytes != 0; }
operator node_t() const noexcept { return node_t{tape.begin(), reinterpret_cast<scalar_t*>(vector.begin())}; }
};

inline node_bytes_split_t node_bytes_split_(node_t node) const noexcept {
std::size_t levels_bytes = pre_.neighbors_base_bytes + pre_.neighbors_bytes * node.level();
std::size_t bytes_in_tape = node_head_bytes_() + levels_bytes;
return {{node.tape(), bytes_in_tape}, {(byte_t*)node.vector(), node_vector_bytes_(node)}};
std::size_t bytes_in_tape = node_head_bytes_() + node_neighbors_bytes_(node.level());
std::size_t vector_bytes = node_vector_bytes_(node);
bool colocated =
std::size_t((byte_t*)node.vector() - (node.tape() + bytes_in_tape)) <= config_.vector_alignment;
std::size_t continuous_buffer_bytes = colocated ? bytes_in_tape + vector_bytes + config_.vector_alignment : 0;
return {{node.tape(), bytes_in_tape}, {(byte_t*)node.vector(), vector_bytes}, continuous_buffer_bytes};
}

inline std::size_t node_neighbors_bytes_(node_t node) const noexcept { return node_neighbors_bytes_(node.level()); }
inline std::size_t node_neighbors_bytes_(level_t level) const noexcept {
return pre_.neighbors_base_bytes + pre_.neighbors_bytes * level;
}

inline std::size_t node_vector_bytes_(dim_t dim) const noexcept { return dim * sizeof(scalar_t); }
inline std::size_t node_vector_bytes_(node_t node) const noexcept { return node_vector_bytes_(node.dim()); }
inline std::size_t node_vector_bytes_(dim_t dim) const noexcept { return dim * sizeof(scalar_t); }

node_bytes_split_t node_malloc_(dim_t dims_to_store, level_t level) noexcept {

std::size_t node_bytes = node_capacity_bytes_(dims_to_store, level);
std::size_t vector_bytes = node_vector_bytes_(dims_to_store);
std::size_t node_bytes = node_bytes_(dims_to_store, level);
std::size_t non_vector_bytes = node_bytes - vector_bytes;
std::size_t tape_bytes = node_bytes - vector_bytes - bool(vector_bytes) * config_.vector_alignment;

byte_t* data = (byte_t*)tape_allocator_.allocate(node_bytes);
if (!data)
return node_bytes_split_t{};
return {{data, non_vector_bytes}, {data + non_vector_bytes, vector_bytes}};

// Place vector on the memory regarding to alignment
byte_t* vector = data + tape_bytes;
vector += bool(vector_bytes) * (config_.vector_alignment - ((uintptr_t)vector % config_.vector_alignment));

return {{data, tape_bytes}, {vector, vector_bytes}, node_bytes};
}

node_t node_make_(label_t label, vector_view_t vector, level_t level, bool store_vector) noexcept {
node_bytes_split_t node_bytes = node_malloc_(vector.size() * store_vector, level);
if (store_vector) {
std::memset(node_bytes.tape.data(), 0, node_bytes.tape.size());
std::memset(node_bytes.tape.data(), 0, node_bytes.tape.size());
if (store_vector)
std::memcpy(node_bytes.vector.data(), vector.data(), node_bytes.vector.size());
} else {
std::memset(node_bytes.tape.data(), 0, node_bytes.memory_usage());
}

node_t node = node_bytes;
node.label(label);
node.dim(static_cast<dim_t>(vector.size()));
Expand All @@ -2984,8 +3006,9 @@ class index_gt {
node_t node_make_copy_(node_bytes_split_t old_bytes) noexcept {
if (old_bytes.colocated()) {
byte_t* data = (byte_t*)tape_allocator_.allocate(old_bytes.memory_usage());
byte_t* vector = data + std::size_t(old_bytes.vector.begin() - old_bytes.tape.begin());
std::memcpy(data, old_bytes.tape.data(), old_bytes.memory_usage());
return node_t{data, reinterpret_cast<scalar_t*>(data + old_bytes.tape.size())};
return node_t{data, reinterpret_cast<scalar_t*>(vector)};
} else {
node_t old_node = old_bytes;
node_bytes_split_t node_bytes = node_malloc_(old_node.vector_view().size(), old_node.level());
Expand All @@ -3001,8 +3024,11 @@ class index_gt {
return;

node_t& node = nodes_[id];
std::size_t node_bytes = node_bytes_(node) - node_vector_bytes_(node) * !node_bytes_split_(node).colocated();
tape_allocator_.deallocate(node.tape(), node_bytes);
node_bytes_split_t node_bytes = node_bytes_split_(node);
if (node_bytes.colocated())
tape_allocator_.deallocate(node.tape(), node_bytes.memory_usage());
else
tape_allocator_.deallocate(node.tape(), node_bytes.tape.size());
node = node_t{};
}

Expand Down

0 comments on commit ea230e0

Please sign in to comment.