diff --git a/.clang-format b/.clang-format index 90a245f4c..21f3a956f 100644 --- a/.clang-format +++ b/.clang-format @@ -34,7 +34,7 @@ PenaltyBreakFirstLessLess: 20 PenaltyBreakString: 1000 PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 200 -PointerAlignment: Right +PointerAlignment: Left SpaceAfterControlStatementKeyword: true SpaceBeforeAssignmentOperators: true SpaceInEmptyParentheses: false diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 000000000..834809eb1 --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,17 @@ +HeaderFilterRegex: '.*include/pisa.*\.hpp' +Checks: | + *, + -clang-diagnostic-c++17-extensions, + -llvm-header-guard, + -cppcoreguidelines-pro-type-reinterpret-cast, + -google-runtime-references, + -fuchsia-*, + -google-readability-namespace-comments, + -llvm-namespace-comment, + -clang-diagnostic-error, + -cppcoreguidelines-pro-bounds-pointer-arithmetic, + -cppcoreguidelines-avoid-magic-numbers, + -cppcoreguidelines-pro-bounds-array-to-pointer-decay, + -modernize-use-trailing-return-type, + -misc-non-private-member-variables-in-classes, + -readability-magic-numbers diff --git a/.gitignore b/.gitignore index 077c235de..1f0bdf6f1 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,4 @@ docs/_build/ node_modules .clangd/ +compile_commands.json diff --git a/.gitmodules b/.gitmodules index c9f3cb426..51a1937f0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -67,6 +67,21 @@ [submodule "external/wapopp"] path = external/wapopp url = https://github.com/pisa-engine/wapopp.git +[submodule "external/optional"] + path = external/optional + url = https://github.com/TartanLlama/optional.git +[submodule "external/expected"] + path = external/expected + url = https://github.com/TartanLlama/expected.git +[submodule "external/yaml-cpp"] + path = external/yaml-cpp + url = https://github.com/jbeder/yaml-cpp.git [submodule "external/rapidcheck"] path = external/rapidcheck url = https://github.com/emil-e/rapidcheck.git +[submodule 
"external/json"] + path = external/json + url = https://github.com/nlohmann/json.git +[submodule "external/cereal"] + path = external/cereal + url = https://github.com/USCiLab/cereal.git diff --git a/CMakeLists.txt b/CMakeLists.txt index ca1b47d1d..d1418f7ec 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,6 +9,9 @@ set(CMAKE_CXX_EXTENSIONS OFF) option(PISA_BUILD_TOOLS "Build command line tools." ON) option(PISA_ENABLE_TESTING "Enable testing of the library." ON) option(PISA_ENABLE_BENCHMARKING "Enable benchmarking of the library." ON) +option(PISA_COMPILE_TOOLS "Compile CLI tools." ON) +option(FORCE_COLORED_OUTPUT "Always produce ANSI-colored output (GNU/Clang only)." ON) +option(PISA_LIBCXX "Use libc++ standard library." OFF) configure_file( ${PISA_SOURCE_DIR}/include/pisa/pisa_config.hpp.in @@ -30,7 +33,7 @@ ExternalProject_Add(gumbo-external SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/gumbo-parser BINARY_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/gumbo-parser CONFIGURE_COMMAND ./autogen.sh && ./configure --prefix=${CMAKE_BINARY_DIR}/gumbo-parser - BUILD_BYPRODUCTS ${CMAKE_BINARY_DIR}/gumbo-parser/lib/libgumbo.a + BUILD_BYPRODUCTS ${CMAKE_BINARY_DIR}/gumbo-parser/lib/libgumbo.a BUILD_COMMAND ${MAKE}) add_library(gumbo::gumbo STATIC IMPORTED) set_property(TARGET gumbo::gumbo APPEND PROPERTY INTERFACE_INCLUDE_DIRECTORIES @@ -47,20 +50,38 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/external/CMake-codecov/cmake" find_package(codecov) list(APPEND LCOV_REMOVE_PATTERNS "'${PROJECT_SOURCE_DIR}/external/*'") - +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-strict-aliasing") +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DGSL_UNENFORCED_ON_CONTRACT_VIOLATION -flto") if (UNIX) # For hardware popcount and other special instructions - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -Wno-odr") # Extensive warnings - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra 
-Wno-missing-braces") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wno-missing-braces") + #set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wno-missing-braces") if (USE_SANITIZERS) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer") endif () - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ggdb") # Add debug info anyway + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ggdb -gdwarf") # Add debug info anyway + + #set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer -Wfatal-errors") + + if (${FORCE_COLORED_OUTPUT}) + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fdiagnostics-color=always") + elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fcolor-diagnostics") + endif () + endif () + +endif() +if (PISA_LIBCXX) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++") + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++ -lc++abi") endif() find_package(OpenMP) @@ -100,16 +121,21 @@ target_link_libraries(pisa PUBLIC # TODO(michal): are there any of these we can spdlog fmt::fmt range-v3 + optional + yaml-cpp + nlohmann_json ) target_include_directories(pisa PUBLIC external) -if (PISA_BUILD_TOOLS) +if (PISA_COMPILE_TOOLS) + add_subdirectory(v1) add_subdirectory(tools) endif() if (PISA_ENABLE_TESTING AND BUILD_TESTING) enable_testing() - add_subdirectory(test) + #add_subdirectory(test) + add_subdirectory(test/v1) endif() if (PISA_ENABLE_BENCHMARKING) diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt index 580bd9af4..464a348ad 100644 --- a/external/CMakeLists.txt +++ b/external/CMakeLists.txt @@ -106,8 +106,9 @@ set(TRECPP_BUILD_TOOL OFF CACHE BOOL "skip trecpp testing") add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/trecpp) # Add trecpp +set(JSON_MultipleHeaders ON CACHE BOOL "") set(WAPOPP_ENABLE_TESTING OFF CACHE BOOL "skip wapopp testing") 
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/wapopp) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/wapopp EXCLUDE_FROM_ALL) # Add fmt add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/fmt) @@ -118,6 +119,22 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/spdlog) # Add range-v3 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/range-v3) +# Add tl::optional +set(OPTIONAL_ENABLE_TESTS OFF CACHE BOOL "skip tl::optional testing") +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/optional EXCLUDE_FROM_ALL) + +# Add yaml-cpp +set(YAML_CPP_BUILD_TOOLS OFF CACHE BOOL "skip building YAML tools") +set(YAML_CPP_BUILD_TESTS OFF CACHE BOOL "skip building YAML tests") +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/yaml-cpp EXCLUDE_FROM_ALL) + # Add RapidCheck add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/rapidcheck) target_compile_options(rapidcheck PRIVATE -Wno-error=all) + +# Add json +# TODO(michal): I had to comment this out because `wapopp` already adds this target. +# How should we deal with it? +#set(JSON_MultipleHeaders ON CACHE BOOL "") +#set(JSON_BuildTests OFF CACHE BOOL "skip building JSON tests") +#add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/json) diff --git a/external/cereal b/external/cereal new file mode 160000 index 000000000..a5a309531 --- /dev/null +++ b/external/cereal @@ -0,0 +1 @@ +Subproject commit a5a30953125e70b115a28dd76b64adf3c97cc883 diff --git a/external/expected b/external/expected new file mode 160000 index 000000000..3d741708b --- /dev/null +++ b/external/expected @@ -0,0 +1 @@ +Subproject commit 3d741708b967b83ca1e2888239196c4a67f9f9b0 diff --git a/external/json b/external/json new file mode 160000 index 000000000..e7b3b40b5 --- /dev/null +++ b/external/json @@ -0,0 +1 @@ +Subproject commit e7b3b40b5a95bc74b9a7f662830a27c49ffc01b4 diff --git a/external/optional b/external/optional new file mode 160000 index 000000000..5c4876059 --- /dev/null +++ b/external/optional @@ -0,0 +1 @@ +Subproject commit 5c4876059c1168d5fa3c945bd8dd05ebbe61b6fe diff --git
a/external/yaml-cpp b/external/yaml-cpp new file mode 160000 index 000000000..72f699f5c --- /dev/null +++ b/external/yaml-cpp @@ -0,0 +1 @@ +Subproject commit 72f699f5ce2d22a8f70cd59d73ce2fbb42dd960e diff --git a/include/pisa/codec/integer_codes.hpp b/include/pisa/codec/integer_codes.hpp index 425857e96..7a6ee3a29 100644 --- a/include/pisa/codec/integer_codes.hpp +++ b/include/pisa/codec/integer_codes.hpp @@ -5,7 +5,7 @@ namespace pisa { // note: n can be 0 -inline void write_gamma(bit_vector_builder &bvb, uint64_t n) +inline void write_gamma(bit_vector_builder& bvb, uint64_t n) { uint64_t nn = n + 1; uint64_t l = broadword::msb(nn); @@ -14,22 +14,27 @@ inline void write_gamma(bit_vector_builder &bvb, uint64_t n) bvb.append_bits(nn ^ hb, l); } -inline void write_gamma_nonzero(bit_vector_builder &bvb, uint64_t n) +inline void write_gamma_nonzero(bit_vector_builder& bvb, uint64_t n) { assert(n > 0); write_gamma(bvb, n - 1); } -inline uint64_t read_gamma(bit_vector::enumerator &it) +template +inline uint64_t read_gamma(BitVectorEnumerator& it) { uint64_t l = it.skip_zeros(); assert(l < 64); return (it.take(l) | (uint64_t(1) << l)) - 1; } -inline uint64_t read_gamma_nonzero(bit_vector::enumerator &it) { return read_gamma(it) + 1; } +template +inline uint64_t read_gamma_nonzero(BitVectorEnumerator& it) +{ + return read_gamma(it) + 1; +} -inline void write_delta(bit_vector_builder &bvb, uint64_t n) +inline void write_delta(bit_vector_builder& bvb, uint64_t n) { uint64_t nn = n + 1; uint64_t l = broadword::msb(nn); @@ -38,9 +43,11 @@ inline void write_delta(bit_vector_builder &bvb, uint64_t n) bvb.append_bits(nn ^ hb, l); } -inline uint64_t read_delta(bit_vector::enumerator &it) +template +inline uint64_t read_delta(BitVectorEnumerator& it) { uint64_t l = read_gamma(it); return (it.take(l) | (uint64_t(1) << l)) - 1; } + } // namespace pisa diff --git a/include/pisa/codec/simdbp.hpp b/include/pisa/codec/simdbp.hpp index 23a694657..fa446eebe 100644 --- 
a/include/pisa/codec/simdbp.hpp +++ b/include/pisa/codec/simdbp.hpp @@ -1,8 +1,8 @@ #pragma once -#include -#include "util/util.hpp" #include "codec/block_codecs.hpp" +#include "util/util.hpp" +#include extern "C" { #include "simdcomp/include/simdbitpacking.h" @@ -14,7 +14,8 @@ struct simdbp_block { static void encode(uint32_t const *in, uint32_t sum_of_values, size_t n, - std::vector &out) { + std::vector &out) + { assert(n <= block_size); uint32_t *src = const_cast(in); @@ -23,23 +24,22 @@ struct simdbp_block { return; } uint32_t b = maxbits(in); - thread_local std::vector buf(8*n); - uint8_t * buf_ptr = buf.data(); + thread_local std::vector buf(8 * n); + uint8_t *buf_ptr = buf.data(); *buf_ptr++ = b; simdpackwithoutmask(src, (__m128i *)buf_ptr, b); out.insert(out.end(), buf.data(), buf.data() + b * sizeof(__m128i) + 1); } - static uint8_t const *decode(uint8_t const *in, - uint32_t *out, - uint32_t sum_of_values, - size_t n) { + static uint8_t const *decode(uint8_t const *in, uint32_t *out, uint32_t sum_of_values, size_t n) + { assert(n <= block_size); if (PISA_UNLIKELY(n < block_size)) { return interpolative_block::decode(in, out, sum_of_values, n); } uint32_t b = *in++; simdunpack((const __m128i *)in, out, b); - return in + b * sizeof(__m128i); + return in + b * sizeof(__m128i); } }; -} // namespace pisa \ No newline at end of file + +} // namespace pisa diff --git a/include/pisa/cursor/block_max_scored_cursor.hpp b/include/pisa/cursor/block_max_scored_cursor.hpp index 640e53cee..d1d023eb2 100644 --- a/include/pisa/cursor/block_max_scored_cursor.hpp +++ b/include/pisa/cursor/block_max_scored_cursor.hpp @@ -1,9 +1,11 @@ #pragma once +#include + +#include "query/queries.hpp" +#include "scorer/bm25.hpp" #include "scorer/index_scorer.hpp" #include "wand_data.hpp" -#include "query/queries.hpp" -#include namespace pisa { @@ -20,9 +22,9 @@ struct block_max_scored_cursor { }; template -[[nodiscard]] auto make_block_max_scored_cursors(Index const &index, - WandType 
const &wdata, - Scorer const &scorer, +[[nodiscard]] auto make_block_max_scored_cursors(Index const& index, + WandType const& wdata, + Scorer const& scorer, Query query) { auto terms = query.terms; @@ -34,7 +36,7 @@ template query_term_freqs.begin(), query_term_freqs.end(), std::back_inserter(cursors), - [&](auto &&term) { + [&](auto&& term) { auto list = index[term.first]; auto w_enum = wdata.getenum(term.first); float q_weight = term.second; @@ -45,4 +47,4 @@ template return cursors; } -} // namespace pisa \ No newline at end of file +} // namespace pisa diff --git a/include/pisa/cursor/max_scored_cursor.hpp b/include/pisa/cursor/max_scored_cursor.hpp index b4c81a5cc..09cd42ff3 100644 --- a/include/pisa/cursor/max_scored_cursor.hpp +++ b/include/pisa/cursor/max_scored_cursor.hpp @@ -1,9 +1,11 @@ #pragma once +#include + +#include "query/queries.hpp" +#include "scorer/bm25.hpp" #include "scorer/index_scorer.hpp" #include "wand_data.hpp" -#include "query/queries.hpp" -#include namespace pisa { @@ -17,9 +19,9 @@ struct max_scored_cursor { }; template -[[nodiscard]] auto make_max_scored_cursors(Index const &index, - WandType const &wdata, - Scorer const &scorer, +[[nodiscard]] auto make_max_scored_cursors(Index const& index, + WandType const& wdata, + Scorer const& scorer, Query query) { auto terms = query.terms; @@ -30,7 +32,7 @@ template std::transform(query_term_freqs.begin(), query_term_freqs.end(), std::back_inserter(cursors), - [&](auto &&term) { + [&](auto&& term) { auto list = index[term.first]; float q_weight = term.second; auto max_weight = q_weight * wdata.max_term_weight(term.first); diff --git a/include/pisa/dec_time_prediction.hpp b/include/pisa/dec_time_prediction.hpp index 73868ef00..54e5edae4 100644 --- a/include/pisa/dec_time_prediction.hpp +++ b/include/pisa/dec_time_prediction.hpp @@ -20,7 +20,7 @@ namespace time_prediction { enum class feature_type { BOOST_PP_SEQ_ENUM(PISA_FEATURE_TYPES), end }; - inline feature_type 
parse_feature_type(std::string const &name) + inline feature_type parse_feature_type(std::string const& name) { if (false) { #define LOOP_BODY(R, DATA, T) \ @@ -55,10 +55,10 @@ namespace time_prediction { public: feature_vector() { std::fill(m_features.begin(), m_features.end(), 0); } - float &operator[](feature_type f) { return m_features[(size_t)f]; } - float const &operator[](feature_type f) const { return m_features[(size_t)f]; } + float& operator[](feature_type f) { return m_features[(size_t)f]; } + float const& operator[](feature_type f) const { return m_features[(size_t)f]; } - stats_line &dump(stats_line &sl) const + stats_line& dump(stats_line& sl) const { for (size_t i = 0; i < num_features; ++i) { feature_type ft = (feature_type)i; @@ -75,9 +75,9 @@ namespace time_prediction { public: predictor() : m_bias(0) {} - predictor(std::vector> const &values) + predictor(std::vector> const& values) { - for (auto const &kv : values) { + for (auto const& kv : values) { if (kv.first == "bias") { bias() = kv.second; } else { @@ -86,10 +86,10 @@ namespace time_prediction { } } - float &bias() { return m_bias; } - float const &bias() const { return m_bias; } + float& bias() { return m_bias; } + float const& bias() const { return m_bias; } - float operator()(feature_vector const &f) const + float operator()(feature_vector const& f) const { float result = bias(); for (size_t i = 0; i < num_features; ++i) { @@ -103,7 +103,7 @@ namespace time_prediction { float m_bias; }; - inline void values_statistics(std::vector values, feature_vector &f) + inline void values_statistics(std::vector values, feature_vector& f) { std::sort(values.begin(), values.end()); f[feature_type::n] = values.size(); @@ -141,9 +141,9 @@ namespace time_prediction { f[feature_type::max_b] = max_b; } - inline bool read_block_stats(std::istream &is, - uint32_t &list_id, - std::vector &block_counts) + inline bool read_block_stats(std::istream& is, + uint32_t& list_id, + std::vector& block_counts) { 
thread_local std::string line; uint32_t count; diff --git a/include/pisa/intersection.hpp b/include/pisa/intersection.hpp index 187f40ede..e149b4022 100644 --- a/include/pisa/intersection.hpp +++ b/include/pisa/intersection.hpp @@ -65,7 +65,7 @@ inline auto Intersection::compute(Index const &index, { auto filtered_query = term_mask ? intersection::filter(query, *term_mask) : query; scored_and_query retrieve{}; - auto scorer = scorer::from_name("bm25", wand); + auto scorer = ::pisa::scorer::from_name("bm25", wand); auto results = retrieve(make_scored_cursors(index, *scorer, filtered_query), index.num_docs()); auto max_element = [&](auto const &vec) -> float { auto order = [](auto const &lhs, auto const &rhs) { return lhs.second < rhs.second; }; diff --git a/include/pisa/io.hpp b/include/pisa/io.hpp index e7c8bdc45..29446785d 100644 --- a/include/pisa/io.hpp +++ b/include/pisa/io.hpp @@ -61,4 +61,18 @@ void for_each_line(std::istream &is, Function fn) return data; } +[[nodiscard]] inline auto load_bytes(std::string const &data_file) +{ + std::vector data; + std::basic_ifstream in(data_file.c_str(), std::ios::binary); + in.seekg(0, std::ios::end); + std::streamsize size = in.tellg(); + in.seekg(0, std::ios::beg); + data.resize(size); + if (not in.read(data.data(), size)) { + throw std::runtime_error("Failed reading " + data_file); + } + return data; +} + } // namespace pisa::io diff --git a/include/pisa/mixed_block.hpp b/include/pisa/mixed_block.hpp index 9c4470ae2..707da2249 100644 --- a/include/pisa/mixed_block.hpp +++ b/include/pisa/mixed_block.hpp @@ -25,17 +25,17 @@ struct mixed_block { static const size_t block_types = 3; static const uint64_t block_size = 128; - static void encode(uint32_t const *, uint32_t, size_t, std::vector &) + static void encode(uint32_t const*, uint32_t, size_t, std::vector&) { throw std::runtime_error("Mixed block indexes can only be created by transformation"); } static void encode_type(block_type type, compr_param_type param, - 
uint32_t const *in, + uint32_t const* in, uint32_t sum_of_values, size_t n, - std::vector &out) + std::vector& out) { assert(n <= block_size); if (n < block_size) { @@ -65,11 +65,11 @@ struct mixed_block { static bool compression_stats(block_type type, compr_param_type param, - uint32_t const *in, + uint32_t const* in, uint32_t sum_of_values, size_t n, - std::vector &buf, - time_prediction::feature_vector &fv) + std::vector& buf, + time_prediction::feature_vector& fv) { assert(buf.empty()); using namespace time_prediction; @@ -83,7 +83,7 @@ struct mixed_block { // codec-specific stats if (type == block_type::pfor) { - auto const &possLogs = optpfor_block::codec_type::possLogs; + auto const& possLogs = optpfor_block::codec_type::possLogs; uint32_t b = possLogs[param]; uint32_t max_b = (uint32_t)fv[feature_type::max_b]; // float is exact up to 2^24 if (b > max_b && possLogs[param - 1] >= max_b) @@ -112,16 +112,16 @@ struct mixed_block { block_type type; compr_param_type param; - bool operator<(space_time_point const &other) const + bool operator<(space_time_point const& other) const { return std::make_pair(space, time) < std::make_pair(other.space, other.time); } }; static std::vector compute_space_time( - std::vector const &values, + std::vector const& values, uint32_t sum_of_values, - std::vector const &predictors, + std::vector const& predictors, uint32_t access_count) { using namespace time_prediction; @@ -175,14 +175,14 @@ struct mixed_block { uint32_t size; uint32_t doc_gaps_universe; - void append_docs_block(std::vector &out) const + void append_docs_block(std::vector& out) const { thread_local std::vector buf; m_input_block.decode_doc_gaps(buf); encode_type(m_docs_type, m_docs_param, buf.data(), doc_gaps_universe, size, out); } - void append_freqs_block(std::vector &out) const + void append_freqs_block(std::vector& out) const { thread_local std::vector buf; m_input_block.decode_freqs(buf); @@ -195,7 +195,7 @@ struct mixed_block { compr_param_type 
m_docs_param, m_freqs_param; }; - static uint8_t const *decode(uint8_t const *in, uint32_t *out, uint32_t sum_of_values, size_t n) + static uint8_t const* decode(uint8_t const* in, uint32_t* out, uint32_t sum_of_values, size_t n) { block_type type = block_type::interpolative; if (PISA_LIKELY(n == block_size)) { @@ -218,7 +218,7 @@ struct mixed_block { using predictors_vec_type = std::vector; -inline predictors_vec_type load_predictors(const char *predictors_filename) +inline predictors_vec_type load_predictors(const char* predictors_filename) { std::vector predictors(mixed_block::block_types); diff --git a/include/pisa/payload_vector.hpp b/include/pisa/payload_vector.hpp index 2dc1418fa..c18342072 100644 --- a/include/pisa/payload_vector.hpp +++ b/include/pisa/payload_vector.hpp @@ -26,7 +26,7 @@ namespace detail { typename gsl::span::iterator offset_iter; typename gsl::span::iterator payload_iter; - constexpr auto operator++() -> Payload_Vector_Iterator & + constexpr auto operator++() -> Payload_Vector_Iterator& { ++offset_iter; std::advance(payload_iter, *offset_iter - *std::prev(offset_iter)); @@ -40,7 +40,7 @@ namespace detail { return next_iter; } - constexpr auto operator--() -> Payload_Vector_Iterator & { return *this -= 1; } + constexpr auto operator--() -> Payload_Vector_Iterator& { return *this -= 1; } [[nodiscard]] constexpr auto operator--(int) -> Payload_Vector_Iterator { @@ -55,7 +55,7 @@ namespace detail { std::next(payload_iter, *std::next(offset_iter, n) - *offset_iter)}; } - [[nodiscard]] constexpr auto operator+=(size_type n) -> Payload_Vector_Iterator & + [[nodiscard]] constexpr auto operator+=(size_type n) -> Payload_Vector_Iterator& { std::advance(payload_iter, *std::next(offset_iter, n) - *offset_iter); std::advance(offset_iter, n); @@ -68,12 +68,12 @@ namespace detail { std::prev(payload_iter, *offset_iter - *std::prev(offset_iter, n))}; } - [[nodiscard]] constexpr auto operator-=(size_type n) -> Payload_Vector_Iterator & + [[nodiscard]] 
constexpr auto operator-=(size_type n) -> Payload_Vector_Iterator& { return *this = *this - n; } - [[nodiscard]] constexpr auto operator-(Payload_Vector_Iterator const &other) + [[nodiscard]] constexpr auto operator-(Payload_Vector_Iterator const& other) -> difference_type { return offset_iter - other.offset_iter; @@ -81,16 +81,22 @@ namespace detail { [[nodiscard]] constexpr auto operator*() -> value_type { - return value_type(reinterpret_cast(&*payload_iter), - *std::next(offset_iter) - *offset_iter); + if constexpr (std::is_same_v) { + return value_type(reinterpret_cast(&*payload_iter), + *std::next(offset_iter) - *offset_iter); + } else { + value_type value; + std::memcpy(&value, reinterpret_cast(&*payload_iter), sizeof(value)); + return value; + } } - [[nodiscard]] constexpr auto operator==(Payload_Vector_Iterator const &other) const -> bool + [[nodiscard]] constexpr auto operator==(Payload_Vector_Iterator const& other) const -> bool { return offset_iter == other.offset_iter; } - [[nodiscard]] constexpr auto operator!=(Payload_Vector_Iterator const &other) const -> bool + [[nodiscard]] constexpr auto operator!=(Payload_Vector_Iterator const& other) const -> bool { return offset_iter != other.offset_iter; } @@ -129,12 +135,12 @@ namespace detail { }; template - [[nodiscard]] static constexpr auto unpack(std::byte const *ptr) -> std::tuple + [[nodiscard]] static constexpr auto unpack(std::byte const* ptr) -> std::tuple { - if constexpr (sizeof...(Ts) == 0u) { - return std::tuple(*reinterpret_cast(ptr)); + if constexpr (sizeof...(Ts) == 0U) { + return std::tuple(*reinterpret_cast(ptr)); } else { - return std::tuple_cat(std::tuple(*reinterpret_cast(ptr)), + return std::tuple_cat(std::tuple(*reinterpret_cast(ptr)), unpack(ptr + sizeof(T))); } } @@ -147,39 +153,38 @@ struct Payload_Vector_Buffer { std::vector const offsets; std::vector const payloads; - [[nodiscard]] static auto from_file(std::string const &filename) -> Payload_Vector_Buffer - { - 
boost::system::error_code ec; - auto file_size = boost::filesystem::file_size(boost::filesystem::path(filename)); - std::ifstream is(filename); + //[[nodiscard]] static auto from_file(std::string const& filename) -> Payload_Vector_Buffer + //{ + // auto file_size = boost::filesystem::file_size(boost::filesystem::path(filename)); + // std::ifstream is(filename); - size_type len; - is.read(reinterpret_cast(&len), sizeof(size_type)); + // size_type len; + // is.read(reinterpret_cast(&len), sizeof(size_type)); - auto offsets_bytes = (len + 1) * sizeof(size_type); - std::vector offsets(len + 1); - is.read(reinterpret_cast(offsets.data()), offsets_bytes); + // auto offsets_bytes = (len + 1) * sizeof(size_type); + // std::vector offsets(len + 1); + // is.read(reinterpret_cast(offsets.data()), offsets_bytes); - auto payloads_bytes = file_size - offsets_bytes - sizeof(size_type); - std::vector payloads(payloads_bytes); - is.read(reinterpret_cast(payloads.data()), payloads_bytes); + // auto payloads_bytes = file_size - offsets_bytes - sizeof(size_type); + // std::vector payloads(payloads_bytes); + // is.read(reinterpret_cast(payloads.data()), payloads_bytes); - return Payload_Vector_Buffer{std::move(offsets), std::move(payloads)}; - } + // return Payload_Vector_Buffer{std::move(offsets), std::move(payloads)}; + //} - void to_file(std::string const &filename) const + void to_file(std::string const& filename) const { std::ofstream is(filename); to_stream(is); } - void to_stream(std::ostream &is) const + void to_stream(std::ostream& is) const { - size_type length = offsets.size() - 1u; - is.write(reinterpret_cast(&length), sizeof(length)); - is.write(reinterpret_cast(offsets.data()), + size_type length = offsets.size() - 1U; + is.write(reinterpret_cast(&length), sizeof(length)); + is.write(reinterpret_cast(offsets.data()), offsets.size() * sizeof(offsets[0])); - is.write(reinterpret_cast(payloads.data()), payloads.size()); + is.write(reinterpret_cast(payloads.data()), 
payloads.size()); } template @@ -188,7 +193,7 @@ struct Payload_Vector_Buffer { PayloadEncodingFn encoding_fn) -> Payload_Vector_Buffer { std::vector offsets; - offsets.push_back(0u); + offsets.push_back(0U); std::vector payloads; for (; first != last; ++first) { encoding_fn(*first, std::back_inserter(payloads)); @@ -256,7 +261,7 @@ template throw std::runtime_error( fmt::format("Failed to cast byte-span to span of T of size {}", type_size)); } - return gsl::make_span(reinterpret_cast(mem.data()), mem.size() / type_size); + return gsl::make_span(reinterpret_cast(mem.data()), mem.size() / type_size); } template @@ -267,20 +272,20 @@ class Payload_Vector { using payload_type = Payload_View; using iterator = detail::Payload_Vector_Iterator; - Payload_Vector(Payload_Vector_Buffer const &container) + explicit Payload_Vector(Payload_Vector_Buffer const& container) : offsets_(container.offsets), payloads_(container.payloads) { } Payload_Vector(gsl::span offsets, gsl::span payloads) - : offsets_(std::move(offsets)), payloads_(std::move(payloads)) + : offsets_(offsets), payloads_(payloads) { } template - [[nodiscard]] constexpr static auto from(ContiguousContainer &&mem) -> Payload_Vector + [[nodiscard]] constexpr static auto from(ContiguousContainer&& mem) -> Payload_Vector { - return from(gsl::make_span(reinterpret_cast(mem.data()), mem.size())); + return from(gsl::make_span(reinterpret_cast(mem.data()), mem.size())); } [[nodiscard]] static auto from(gsl::span mem) -> Payload_Vector @@ -289,15 +294,16 @@ class Payload_Vector { gsl::span tail; try { std::tie(length, tail) = unpack_head(mem); - } catch (std::runtime_error const &err) { + } catch (std::runtime_error const& err) { throw std::runtime_error(std::string("Failed to parse payload vector length: ") + err.what()); } - gsl::span offsets, payloads; + gsl::span offsets; + gsl::span payloads; try { - std::tie(offsets, payloads) = split(tail, (length + 1u) * sizeof(size_type)); - } catch (std::runtime_error const &err) { 
+ std::tie(offsets, payloads) = split(tail, (length + 1U) * sizeof(size_type)); + } catch (std::runtime_error const& err) { throw std::runtime_error(std::string("Failed to parse payload vector offset table: ") + err.what()); } diff --git a/include/pisa/query/algorithm/block_max_maxscore_query.hpp b/include/pisa/query/algorithm/block_max_maxscore_query.hpp index 46f6520f2..186c7aaa0 100644 --- a/include/pisa/query/algorithm/block_max_maxscore_query.hpp +++ b/include/pisa/query/algorithm/block_max_maxscore_query.hpp @@ -1,6 +1,7 @@ #pragma once #include + #include "query/queries.hpp" #include "topk_queue.hpp" diff --git a/include/pisa/query/algorithm/wand_query.hpp b/include/pisa/query/algorithm/wand_query.hpp index 3a04086f4..22f7f2ddc 100644 --- a/include/pisa/query/algorithm/wand_query.hpp +++ b/include/pisa/query/algorithm/wand_query.hpp @@ -94,4 +94,4 @@ struct wand_query { topk_queue &m_topk; }; -} // namespace pisa \ No newline at end of file +} // namespace pisa diff --git a/include/pisa/query/queries.hpp b/include/pisa/query/queries.hpp index 419704f73..b42c0be79 100644 --- a/include/pisa/query/queries.hpp +++ b/include/pisa/query/queries.hpp @@ -22,23 +22,23 @@ struct Query { std::vector term_weights; }; -[[nodiscard]] auto split_query_at_colon(std::string const &query_string) +[[nodiscard]] auto split_query_at_colon(std::string const& query_string) -> std::pair, std::string_view>; -[[nodiscard]] auto parse_query_terms(std::string const &query_string, TermProcessor term_processor) +[[nodiscard]] auto parse_query_terms(std::string const& query_string, TermProcessor term_processor) -> Query; -[[nodiscard]] auto parse_query_ids(std::string const &query_string) -> Query; +[[nodiscard]] auto parse_query_ids(std::string const& query_string) -> Query; [[nodiscard]] std::function resolve_query_parser( - std::vector &queries, - std::optional const &terms_file, - std::optional const &stopwords_filename, - std::optional const &stemmer_type); + std::vector& queries, + 
std::optional const& terms_file, + std::optional const& stopwords_filename, + std::optional const& stemmer_type); -bool read_query(term_id_vec &ret, std::istream &is = std::cin); +bool read_query(term_id_vec& ret, std::istream& is = std::cin); -void remove_duplicate_terms(term_id_vec &terms); +void remove_duplicate_terms(term_id_vec& terms); term_freq_vec query_freqs(term_id_vec terms); diff --git a/include/pisa/query/query.hpp b/include/pisa/query/query.hpp new file mode 100644 index 000000000..12559f957 --- /dev/null +++ b/include/pisa/query/query.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include +#include +#include +#include + +namespace pisa { + +using term_id_type = std::uint32_t; +using term_id_vec = std::vector; + +struct Query { + std::optional id; + std::vector terms; + std::vector term_weights; +}; + +} // namespace pisa diff --git a/include/pisa/query/term_processor.hpp b/include/pisa/query/term_processor.hpp index 3f4d0b231..5933e217d 100644 --- a/include/pisa/query/term_processor.hpp +++ b/include/pisa/query/term_processor.hpp @@ -24,9 +24,9 @@ class TermProcessor { std::function(std::string)> _to_id; public: - TermProcessor(std::optional const &terms_file, - std::optional const &stopwords_filename, - std::optional const &stemmer_type) + TermProcessor(std::optional const& terms_file, + std::optional const& stopwords_filename, + std::optional const& stemmer_type) { auto source = std::make_shared(terms_file->c_str()); auto terms = Payload_Vector<>::from(*source); @@ -66,7 +66,7 @@ class TermProcessor { // Loads stopwords. 
if (stopwords_filename) { std::ifstream is(*stopwords_filename); - io::for_each_line(is, [&](auto &&word) { + io::for_each_line(is, [&](auto&& word) { if (auto processed_term = _to_id(std::move(word)); processed_term.has_value()) { stopwords.insert(*processed_term); } diff --git a/include/pisa/sequence/partitioned_sequence.hpp b/include/pisa/sequence/partitioned_sequence.hpp index 4ebc5c3e0..132b70dbd 100644 --- a/include/pisa/sequence/partitioned_sequence.hpp +++ b/include/pisa/sequence/partitioned_sequence.hpp @@ -231,6 +231,8 @@ namespace pisa { return m_partitions; } + [[nodiscard]] auto universe() const -> std::uint64_t { return m_universe; } + friend class partitioned_sequence_test; private: diff --git a/include/pisa/topk_queue.hpp b/include/pisa/topk_queue.hpp index 6dfa5f516..b737d2de8 100644 --- a/include/pisa/topk_queue.hpp +++ b/include/pisa/topk_queue.hpp @@ -1,8 +1,8 @@ #pragma once -#include -#include "util/util.hpp" #include "util/likely.hpp" +#include "util/util.hpp" +#include namespace pisa { @@ -10,12 +10,16 @@ using Threshold = float; struct topk_queue { using entry_type = std::pair; - explicit topk_queue(uint64_t k) : m_threshold(0), m_k(k) { m_q.reserve(m_k + 1); } - topk_queue(topk_queue const &q) = default; - topk_queue &operator=(topk_queue const &q) = default; + explicit topk_queue(uint64_t k) : m_threshold(std::numeric_limits::lowest()), m_k(k) + { + m_q.reserve(m_k + 1); + } + topk_queue(topk_queue const& q) = default; + topk_queue& operator=(topk_queue const& q) = default; - [[nodiscard]] constexpr static auto min_heap_order(entry_type const &lhs, - entry_type const &rhs) noexcept -> bool { + [[nodiscard]] constexpr static auto min_heap_order(entry_type const& lhs, + entry_type const& rhs) noexcept -> bool + { return lhs.first > rhs.first; } @@ -28,7 +32,7 @@ struct topk_queue { m_q.emplace_back(score, docid); if (PISA_UNLIKELY(m_q.size() <= m_k)) { std::push_heap(m_q.begin(), m_q.end(), min_heap_order); - if(PISA_UNLIKELY(m_q.size() == 
m_k)) { + if (PISA_UNLIKELY(m_q.size() == m_k)) { m_threshold = m_q.front().first; } } else { @@ -39,35 +43,36 @@ struct topk_queue { return true; } - bool would_enter(float score) const { return score >= m_threshold; } + [[nodiscard]] bool would_enter(float score) const { return score >= m_threshold; } - void finalize() { + void finalize() + { std::sort_heap(m_q.begin(), m_q.end(), min_heap_order); size_t size = std::lower_bound(m_q.begin(), m_q.end(), 0, - [](std::pair l, float r) { return l.first > r; }) - - m_q.begin(); + [](std::pair l, float r) { return l.first > r; }) + - m_q.begin(); m_q.resize(size); } - [[nodiscard]] std::vector const &topk() const noexcept { return m_q; } - - void set_threshold(Threshold t) noexcept { - m_threshold = t; - } + [[nodiscard]] std::vector const& topk() const noexcept { return m_q; } - void clear() noexcept { + void clear() noexcept + { m_q.clear(); m_threshold = 0; } [[nodiscard]] uint64_t size() const noexcept { return m_k; } + [[nodiscard]] auto full() const noexcept -> bool { return m_q.size() == m_k; } + + void set_threshold(float threshold) noexcept { m_threshold = threshold; } private: - float m_threshold; - uint64_t m_k; + float m_threshold; + uint64_t m_k; std::vector m_q; }; diff --git a/include/pisa/v1/README.md b/include/pisa/v1/README.md new file mode 100644 index 000000000..aa30112d7 --- /dev/null +++ b/include/pisa/v1/README.md @@ -0,0 +1,79 @@ +> This document is a **work in progress**. + +# Introduction + +In our efforts to come up with the v1.0 of both PISA and our index format, +we should start a discussion about the shape of things from the point of view +of both the binary format and how we can use it in our library. + +## Index Format specification + +This document mainly discusses the binary file format of each index component, +as well as how these components come together to form a cohesive structure. 
## Reference Implementation

Along with the format description and discussion, this directory includes a
reference implementation of the discussed structures and some algorithms working on them.

The goal of this is to show how things work on certain examples,
and find out what works and what doesn't and still needs to be thought through.

> Look in `test/test_v1.cpp` for code examples.

# Posting Files

> Example: `v1/raw_cursor.hpp`.

Each _posting file_ contains a list of blocks of data, each related to a single term,
preceded by a header encoding information about the type of payload.

> Do we need the header? I would say "yes" because even if we store the information
> somewhere else, then we might want to (1) verify that we are reading what we think
> we are reading, and (2) verify format version compatibility.
> The latter should be further discussed.

```
Posting File := Header, [Posting Block]
```

Each posting block encodes a list of homogeneous values, called _postings_.
The encoding is not fixed.

> Note that _block_ here means the entire posting list area.
> We can work on the terminology.

## Header

> Example: `v1/posting_format_header.hpp`.

We should store the type of the postings in the file, as well as the encoding used.
**This might be tricky because we'd like it to be an open set of values/encodings.**

```
Header := Version, Type, Encoding
Version := Major, Minor, Patch
Type := ValueId, Count
```

## Posting Types

I think supporting these types will be sufficient to express just about anything we
would want to, including single-value lists, document-frequency (or score) lists,
positional indexes, etc.

```
Type := Primitive | List[Type] | Tuple[Type]
Primitive := int32 | float32
```

## Encodings

We can identify encodings by either a name or ID/hash, or both.
I can imagine that an index reader could **register** new encodings,
and default to whatever we define in PISA.
+We should then also verify that this encoding implement a `Encoding` "concept". +This is not the same as our "codecs". +This would be more like posting list reader. + +> Example: `IndexRunner` in `v1/index.hpp`. diff --git a/include/pisa/v1/algorithm.hpp b/include/pisa/v1/algorithm.hpp new file mode 100644 index 000000000..7d8f5481a --- /dev/null +++ b/include/pisa/v1/algorithm.hpp @@ -0,0 +1,62 @@ +#pragma once + +#include +#include + +#include + +namespace pisa::v1 { + +template +[[nodiscard]] auto min_value(Container&& cursors) +{ + auto pos = + std::min_element(cursors.begin(), cursors.end(), [](auto const& lhs, auto const& rhs) { + return lhs.value() < rhs.value(); + }); + return pos->value(); +} + +template +[[nodiscard]] auto min_sentinel(Container&& cursors) +{ + auto pos = + std::min_element(cursors.begin(), cursors.end(), [](auto const& lhs, auto const& rhs) { + return lhs.sentinel() < rhs.sentinel(); + }); + return pos->sentinel(); +} + +template +void partition_by_index(gsl::span range, gsl::span right_indices) +{ + if (right_indices.empty()) { + return; + } + std::sort(right_indices.begin(), right_indices.end()); + if (right_indices[right_indices.size() - 1] >= range.size()) { + throw std::logic_error("Essential index too large"); + } + std::vector essential; + essential.reserve(right_indices.size()); + std::vector non_essential; + non_essential.reserve(range.size() - right_indices.size()); + + auto cidx = 0; + auto eidx = 0; + while (eidx < right_indices.size()) { + if (cidx < right_indices[eidx]) { + non_essential.push_back(std::move(range[cidx])); + cidx += 1; + } else { + essential.push_back(std::move(range[cidx])); + eidx += 1; + cidx += 1; + } + } + std::move(std::next(range.begin(), cidx), range.end(), std::back_inserter(non_essential)); + auto pos = std::move(non_essential.begin(), non_essential.end(), range.begin()); + std::move(essential.begin(), essential.end(), pos); +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/base_index.hpp 
b/include/pisa/v1/base_index.hpp new file mode 100644 index 000000000..c41bace75 --- /dev/null +++ b/include/pisa/v1/base_index.hpp @@ -0,0 +1,106 @@ +#pragma once + +#include +#include +#include + +#include +#include + +#include "v1/posting_builder.hpp" +#include "v1/source.hpp" +#include "v1/types.hpp" + +namespace pisa::v1 { + +using OffsetSpan = gsl::span; +using BinarySpan = gsl::span; + +[[nodiscard]] auto calc_avg_length(gsl::span const& lengths) -> float; +[[nodiscard]] auto read_sizes(std::string_view basename) -> std::vector; + +/// Lexicographically compares bigrams. +/// Used for looking up bigram mappings. +[[nodiscard]] auto compare_arrays(std::array const& lhs, + std::array const& rhs) -> bool; + +struct PostingData { + BinarySpan postings; + OffsetSpan offsets; +}; + +struct UnigramData { + PostingData documents; + PostingData payloads; +}; + +struct BigramData { + PostingData documents; + std::array payloads; + gsl::span const> mapping; +}; + +/// Parts of the index independent of the template parameters. 
+struct BaseIndex { + + template + BaseIndex(PostingData documents, + PostingData payloads, + tl::optional bigrams, + gsl::span document_lengths, + tl::optional avg_document_length, + std::unordered_map> max_scores, + std::unordered_map block_max_scores, + gsl::span quantized_max_scores, + Source source) + : m_documents(documents), + m_payloads(payloads), + m_bigrams(bigrams), + m_document_lengths(document_lengths), + m_avg_document_length(avg_document_length.map_or_else( + [](auto&& self) { return self; }, + [&]() { return calc_avg_length(m_document_lengths); })), + m_max_scores(std::move(max_scores)), + m_block_max_scores(std::move(block_max_scores)), + m_quantized_max_scores(quantized_max_scores), + m_source(std::move(source)) + { + } + + [[nodiscard]] auto num_terms() const -> std::size_t; + [[nodiscard]] auto num_documents() const -> std::size_t; + [[nodiscard]] auto num_pairs() const -> std::size_t; + [[nodiscard]] auto document_length(DocId docid) const -> std::uint32_t; + [[nodiscard]] auto avg_document_length() const -> float; + [[nodiscard]] auto normalized_document_length(DocId docid) const -> float; + [[nodiscard]] auto bigram_id(TermId left_term, TermId right_term) const -> tl::optional; + [[nodiscard]] auto pairs() const -> tl::optional const>>; + + protected: + void assert_term_in_bounds(TermId term) const; + [[nodiscard]] auto fetch_documents(TermId term) const -> gsl::span; + [[nodiscard]] auto fetch_payloads(TermId term) const -> gsl::span; + [[nodiscard]] auto fetch_bigram_documents(TermId bigram) const -> gsl::span; + [[nodiscard]] auto fetch_bigram_payloads(TermId bigram) const + -> std::array, 2>; + template + [[nodiscard]] auto fetch_bigram_payloads(TermId bigram) const -> gsl::span; + + [[nodiscard]] auto max_score(std::size_t scorer_hash, TermId term) const -> float; + [[nodiscard]] auto block_max_scores(std::size_t scorer_hash) const -> UnigramData const&; + [[nodiscard]] auto quantized_max_score(TermId term) const -> std::uint8_t; + + 
private: + PostingData m_documents; + PostingData m_payloads; + tl::optional m_bigrams; + + gsl::span m_document_lengths; + float m_avg_document_length; + std::unordered_map> m_max_scores; + std::unordered_map m_block_max_scores; + gsl::span m_quantized_max_scores; + std::any m_source; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/bit_cast.hpp b/include/pisa/v1/bit_cast.hpp new file mode 100644 index 000000000..dd4fa6413 --- /dev/null +++ b/include/pisa/v1/bit_cast.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include + +#include + +namespace pisa::v1 { + +template +constexpr auto bit_cast(gsl::span mem) -> std::remove_const_t +{ + std::remove_const_t dst{}; + std::memcpy(&dst, mem.data(), sizeof(T)); + return dst; +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/bit_sequence_cursor.hpp b/include/pisa/v1/bit_sequence_cursor.hpp new file mode 100644 index 000000000..d9895a3e5 --- /dev/null +++ b/include/pisa/v1/bit_sequence_cursor.hpp @@ -0,0 +1,250 @@ +#pragma once + +#include +#include +#include + +#include + +#include "codec/block_codecs.hpp" +#include "codec/integer_codes.hpp" +#include "global_parameters.hpp" +#include "util/compiler_attribute.hpp" +#include "v1/base_index.hpp" +#include "v1/bit_cast.hpp" +#include "v1/bit_vector.hpp" +#include "v1/cursor_traits.hpp" +#include "v1/runtime_assert.hpp" +#include "v1/types.hpp" + +namespace pisa::v1 { + +template +struct BitSequenceCursor { + using value_type = std::uint32_t; + using sequence_enumerator_type = typename BitSequence::enumerator; + + BitSequenceCursor(std::shared_ptr bits, sequence_enumerator_type sequence_enumerator) + : m_sequence_enumerator(std::move(sequence_enumerator)), m_bits(std::move(bits)) + { + reset(); + } + + void reset() + { + m_position = 0; + m_current_value = m_sequence_enumerator.move(0).second; + } + + [[nodiscard]] constexpr auto operator*() const -> value_type + { + if (PISA_UNLIKELY(empty())) { + return sentinel(); + } + return m_current_value; + } + 
[[nodiscard]] constexpr auto value() const noexcept -> value_type { return *(*this); } + + constexpr void advance() + { + std::tie(m_position, m_current_value) = m_sequence_enumerator.next(); + } + + PISA_FLATTEN_FUNC void advance_to_position(std::size_t position) + { + std::tie(m_position, m_current_value) = m_sequence_enumerator.move(position); + } + + constexpr void advance_to_geq(value_type value) + { + std::tie(m_position, m_current_value) = m_sequence_enumerator.next_geq(value); + } + + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_position == size(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t { return m_position; } + [[nodiscard]] constexpr auto size() const -> std::size_t + { + return m_sequence_enumerator.size(); + } + [[nodiscard]] constexpr auto sentinel() const -> value_type + { + return m_sequence_enumerator.universe(); + } + + private: + std::uint64_t m_position = 0; + std::uint64_t m_current_value{}; + sequence_enumerator_type m_sequence_enumerator; + std::shared_ptr m_bits; +}; + +template +struct DocumentBitSequenceCursor : public BitSequenceCursor { + using value_type = std::uint32_t; + using sequence_enumerator_type = typename BitSequence::enumerator; + explicit DocumentBitSequenceCursor(std::shared_ptr bits, + sequence_enumerator_type sequence_enumerator) + : BitSequenceCursor(std::move(bits), std::move(sequence_enumerator)) + { + } +}; + +template +struct PayloadBitSequenceCursor : public BitSequenceCursor { + using value_type = std::uint32_t; + using sequence_enumerator_type = typename BitSequence::enumerator; + explicit PayloadBitSequenceCursor(std::shared_ptr bits, + sequence_enumerator_type sequence_enumerator) + : BitSequenceCursor(std::move(bits), std::move(sequence_enumerator)) + { + } +}; + +template +struct BitSequenceReader { + using value_type = std::uint32_t; + + [[nodiscard]] auto read(gsl::span bytes) const -> Cursor + { + runtime_assert(bytes.size() % 
sizeof(BitVector::storage_type) == 0).or_throw([&]() { + return fmt::format( + "Attempted to read no. bytes ({}) not aligned with the storage type of size {}", + bytes.size(), + sizeof(typename BitVector::storage_type)); + }); + + auto true_bit_length = bit_cast( + bytes.first(sizeof(typename BitVector::storage_type))); + bytes = bytes.subspan(sizeof(typename BitVector::storage_type)); + auto bits = std::make_shared( + gsl::span( + reinterpret_cast(bytes.data()), + bytes.size() / sizeof(BitVector::storage_type)), + true_bit_length); + BitVector::enumerator enumerator(*bits, 0); + std::uint64_t universe = read_gamma_nonzero(enumerator); + std::uint64_t n = 1; + if (universe > 1) { + n = enumerator.take(ceil_log2(universe + 1)); + } + return Cursor(bits, + typename BitSequence::enumerator( + *bits, enumerator.position(), universe + 1, n, global_parameters())); + } + + void init([[maybe_unused]] BaseIndex const& index) {} + constexpr static auto encoding() -> std::uint32_t + { + return EncodingId::BitSequence | encoding_traits::encoding_tag::encoding(); + } +}; + +template +struct DocumentBitSequenceReader + : public BitSequenceReader> { +}; +template +struct PayloadBitSequenceReader + : public BitSequenceReader> { +}; + +template +struct BitSequenceWriter { + using value_type = std::uint32_t; + BitSequenceWriter() = default; + explicit BitSequenceWriter(std::size_t num_documents) : m_num_documents(num_documents) {} + BitSequenceWriter(BitSequenceWriter const&) = default; + BitSequenceWriter(BitSequenceWriter&&) noexcept = default; + BitSequenceWriter& operator=(BitSequenceWriter const&) = default; + BitSequenceWriter& operator=(BitSequenceWriter&&) noexcept = default; + ~BitSequenceWriter() = default; + + constexpr static auto encoding() -> std::uint32_t + { + return EncodingId::BitSequence | encoding_traits::encoding_tag::encoding(); + } + + void init(pisa::binary_freq_collection const& collection) + { + m_num_documents = collection.num_docs(); + } + void 
push(value_type const& posting) + { + m_sum += posting; + m_postings.push_back(posting); + } + void push(value_type&& posting) + { + m_sum += posting; + m_postings.push_back(posting); + } + + template + [[nodiscard]] auto write(std::basic_ostream& os) const -> std::size_t + { + runtime_assert(m_num_documents.has_value()) + .or_throw("Uninitialized writer. Must call `init()` before writing."); + runtime_assert(!m_postings.empty()).or_throw("Tried to write an empty posting list"); + bit_vector_builder builder; + auto universe = [&]() { + if constexpr (DocumentWriter) { + return m_num_documents.value() - 1; + } else { + return m_sum; + } + }(); + write_gamma_nonzero(builder, universe); + if (universe > 1) { + builder.append_bits(m_postings.size(), ceil_log2(universe + 1)); + } + BitSequence::write( + builder, m_postings.begin(), universe + 1, m_postings.size(), global_parameters()); + typename BitVector::storage_type true_bit_length = builder.size(); + auto data = builder.move_bits(); + os.write(reinterpret_cast(&true_bit_length), sizeof(true_bit_length)); + auto memory = gsl::as_bytes(gsl::make_span(data.data(), data.size())); + os.write(reinterpret_cast(memory.data()), memory.size()); + auto bytes_written = sizeof(true_bit_length) + memory.size(); + runtime_assert(bytes_written % sizeof(typename BitVector::storage_type) == 0) + .or_throw([&]() { + return fmt::format( + "Bytes written ({}) are not aligned with the storage type of size {}", + bytes_written, + sizeof(typename BitVector::storage_type)); + }); + return bytes_written; + } + + void reset() + { + m_postings.clear(); + m_sum = 0; + } + + private: + std::vector m_postings{}; + value_type m_sum = 0; + tl::optional m_num_documents{}; +}; + +template +using DocumentBitSequenceWriter = BitSequenceWriter; + +template +using PayloadBitSequenceWriter = BitSequenceWriter; + +template +struct CursorTraits> { + using Value = std::uint32_t; + using Writer = DocumentBitSequenceWriter; + using Reader = 
DocumentBitSequenceReader; +}; + +template +struct CursorTraits> { + using Value = std::uint32_t; + using Writer = PayloadBitSequenceWriter; + using Reader = PayloadBitSequenceReader; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/bit_vector.hpp b/include/pisa/v1/bit_vector.hpp new file mode 100644 index 000000000..d7ed29922 --- /dev/null +++ b/include/pisa/v1/bit_vector.hpp @@ -0,0 +1,772 @@ +#pragma once + +#include +#include + +#include "util/broadword.hpp" +#include "util/util.hpp" + +namespace pisa::v1 { + +namespace detail { + inline size_t words_for(uint64_t n) { return ceil_div(n, 64); } +} // namespace detail + +class BitVectorBuilder { + public: + using bits_type = std::vector; + + BitVectorBuilder(uint64_t size = 0, bool init = 0) : m_size(size) + { + m_bits.resize(detail::words_for(size), uint64_t(-init)); + if (size) { + m_cur_word = &m_bits.back(); + // clear padding bits + if (init && size % 64) { + *m_cur_word >>= 64 - (size % 64); + } + } + } + BitVectorBuilder(const BitVectorBuilder&) = delete; + BitVectorBuilder& operator=(const BitVectorBuilder&) = delete; + + void reserve(uint64_t size) { m_bits.reserve(detail::words_for(size)); } + + inline void push_back(bool b) + { + uint64_t pos_in_word = m_size % 64; + if (pos_in_word == 0) { + m_bits.push_back(0); + m_cur_word = &m_bits.back(); + } + *m_cur_word |= (uint64_t)b << pos_in_word; + ++m_size; + } + + inline void set(uint64_t pos, bool b) + { + uint64_t word = pos / 64; + uint64_t pos_in_word = pos % 64; + + m_bits[word] &= ~(uint64_t(1) << pos_in_word); + m_bits[word] |= uint64_t(b) << pos_in_word; + } + + inline void set_bits(uint64_t pos, uint64_t bits, size_t len) + { + assert(pos + len <= size()); + // check there are no spurious bits + assert(len == 64 || (bits >> len) == 0); + if (!len) + return; + uint64_t mask = (len == 64) ? 
uint64_t(-1) : ((uint64_t(1) << len) - 1); + uint64_t word = pos / 64; + uint64_t pos_in_word = pos % 64; + + m_bits[word] &= ~(mask << pos_in_word); + m_bits[word] |= bits << pos_in_word; + + uint64_t stored = 64 - pos_in_word; + if (stored < len) { + m_bits[word + 1] &= ~(mask >> stored); + m_bits[word + 1] |= bits >> stored; + } + } + + inline void append_bits(uint64_t bits, size_t len) + { + // check there are no spurious bits + assert(len == 64 || (bits >> len) == 0); + if (!len) + return; + uint64_t pos_in_word = m_size % 64; + m_size += len; + if (pos_in_word == 0) { + m_bits.push_back(bits); + } else { + *m_cur_word |= bits << pos_in_word; + if (len > 64 - pos_in_word) { + m_bits.push_back(bits >> (64 - pos_in_word)); + } + } + m_cur_word = &m_bits.back(); + } + + inline void zero_extend(uint64_t n) + { + m_size += n; + uint64_t needed = detail::words_for(m_size) - m_bits.size(); + if (needed) { + m_bits.insert(m_bits.end(), needed, 0); + m_cur_word = &m_bits.back(); + } + } + + inline void one_extend(uint64_t n) + { + while (n >= 64) { + append_bits(uint64_t(-1), 64); + n -= 64; + } + if (n) { + append_bits(uint64_t(-1) >> (64 - n), n); + } + } + + void append(BitVectorBuilder const& rhs) + { + if (!rhs.size()) + return; + + uint64_t pos = m_bits.size(); + uint64_t shift = size() % 64; + m_size = size() + rhs.size(); + m_bits.resize(detail::words_for(m_size)); + + if (shift == 0) { // word-aligned, easy case + std::copy(rhs.m_bits.begin(), rhs.m_bits.end(), m_bits.begin() + ptrdiff_t(pos)); + } else { + uint64_t* cur_word = &m_bits.front() + pos - 1; + for (size_t i = 0; i < rhs.m_bits.size() - 1; ++i) { + uint64_t w = rhs.m_bits[i]; + *cur_word |= w << shift; + *++cur_word = w >> (64 - shift); + } + *cur_word |= rhs.m_bits.back() << shift; + if (cur_word < &m_bits.back()) { + *++cur_word = rhs.m_bits.back() >> (64 - shift); + } + } + m_cur_word = &m_bits.back(); + } + + // reverse in place + void reverse() + { + uint64_t shift = 64 - (size() % 64); + + 
uint64_t remainder = 0; + for (size_t i = 0; i < m_bits.size(); ++i) { + uint64_t cur_word; + if (shift != 64) { // this should be hoisted out + cur_word = remainder | (m_bits[i] << shift); + remainder = m_bits[i] >> (64 - shift); + } else { + cur_word = m_bits[i]; + } + m_bits[i] = broadword::reverse_bits(cur_word); + } + assert(remainder == 0); + std::reverse(m_bits.begin(), m_bits.end()); + } + + bits_type& move_bits() + { + assert(detail::words_for(m_size) == m_bits.size()); + return m_bits; + } + + uint64_t size() const { return m_size; } + + void swap(BitVectorBuilder& other) + { + m_bits.swap(other.m_bits); + std::swap(m_size, other.m_size); + std::swap(m_cur_word, other.m_cur_word); + } + + private: + bits_type m_bits; + uint64_t m_size; + uint64_t* m_cur_word; +}; + +class BitVector { + public: + using storage_type = std::uint64_t; + BitVector() = default; + BitVector(gsl::span bits, std::size_t size) : m_bits(bits), m_size(size) {} + BitVector(BitVector const&) = default; + BitVector(BitVector&&) noexcept = default; + BitVector& operator=(BitVector const&) = default; + BitVector& operator=(BitVector&&) noexcept = default; + ~BitVector() = default; + + // template + // void map(Visitor& visit) + //{ + // visit(m_size, "m_size")(m_bits, "m_bits"); + //} + + // void swap(BitVector &other) + //{ + // std::swap(other.m_size, m_size); + // other.m_bits.swap(m_bits); + //} + + inline size_t size() const { return m_size; } + + inline bool operator[](uint64_t pos) const + { + assert(pos < m_size); + uint64_t block = pos / 64; + assert(block < m_bits.size()); + uint64_t shift = pos % 64; + return (m_bits[block] >> shift) & 1; + } + + inline uint64_t get_bits(uint64_t pos, uint64_t len) const + { + assert(pos + len <= size()); + assert(len <= 64); + if (len == 0U) { + return 0; + } + uint64_t block = pos / 64; + uint64_t shift = pos % 64; + uint64_t mask = std::numeric_limits::max() + >> (std::numeric_limits::digits - len); + if (shift + len <= 64) { + return 
m_bits[block] >> shift & mask; + } else { + return (m_bits[block] >> shift) | (m_bits[block + 1] << (64 - shift) & mask); + } + } + + // same as get_bits(pos, 64) but it can extend further size(), padding with zeros + inline uint64_t get_word(uint64_t pos) const + { + assert(pos < size()); + uint64_t block = pos / 64; + uint64_t shift = pos % 64; + uint64_t word = m_bits[block] >> shift; + if (shift && block + 1 < m_bits.size()) { + word |= m_bits[block + 1] << (64 - shift); + } + return word; + } + + // unsafe and fast version of get_word, it retrieves at least 56 bits + inline uint64_t get_word56(uint64_t pos) const + { + // XXX check endianness? + const char* ptr = reinterpret_cast(m_bits.data()); + return *(reinterpret_cast(ptr + pos / 8)) >> (pos % 8); + } + + inline uint64_t predecessor0(uint64_t pos) const + { + assert(pos < m_size); + uint64_t block = pos / 64; + uint64_t shift = 64 - pos % 64 - 1; + uint64_t word = ~m_bits[block]; + word = (word << shift) >> shift; + + unsigned long ret; + while (!broadword::msb(word, ret)) { + assert(block); + word = ~m_bits[--block]; + }; + return block * 64 + ret; + } + + inline uint64_t successor0(uint64_t pos) const + { + assert(pos < m_size); + uint64_t block = pos / 64; + uint64_t shift = pos % 64; + uint64_t word = (~m_bits[block] >> shift) << shift; + + unsigned long ret; + while (!broadword::lsb(word, ret)) { + ++block; + assert(block < m_bits.size()); + word = ~m_bits[block]; + }; + return block * 64 + ret; + } + + inline uint64_t predecessor1(uint64_t pos) const + { + assert(pos < m_size); + uint64_t block = pos / 64; + uint64_t shift = 64 - pos % 64 - 1; + uint64_t word = m_bits[block]; + word = (word << shift) >> shift; + + unsigned long ret; + while (!broadword::msb(word, ret)) { + assert(block); + word = m_bits[--block]; + }; + return block * 64 + ret; + } + + inline uint64_t successor1(uint64_t pos) const + { + assert(pos < m_size); + uint64_t block = pos / 64; + uint64_t shift = pos % 64; + uint64_t word 
= (m_bits[block] >> shift) << shift; + + unsigned long ret; + while (!broadword::lsb(word, ret)) { + ++block; + assert(block < m_bits.size()); + word = m_bits[block]; + }; + return block * 64 + ret; + } + + std::uint64_t const* data() const { return m_bits.data(); } + + struct enumerator { + enumerator() : m_bv(0), m_pos(uint64_t(-1)) {} + + enumerator(BitVector const& bv, size_t pos) : m_bv(&bv), m_pos(pos), m_buf(0), m_avail(0) + { + // m_bv->data().prefetch(m_pos / 64); + } + + inline bool next() + { + if (!m_avail) + fill_buf(); + bool b = m_buf & 1; + m_buf >>= 1; + m_avail -= 1; + m_pos += 1; + return b; + } + + inline uint64_t take(size_t l) + { + if (m_avail < l) + fill_buf(); + uint64_t val; + if (l != 64) { + val = m_buf & ((uint64_t(1) << l) - 1); + m_buf >>= l; + } else { + val = m_buf; + } + m_avail -= l; + m_pos += l; + return val; + } + + inline uint64_t skip_zeros() + { + uint64_t zs = 0; + // XXX the loop may be optimized by aligning access + while (!m_buf) { + m_pos += m_avail; + zs += m_avail; + m_avail = 0; + fill_buf(); + } + + uint64_t l = broadword::lsb(m_buf); + m_buf >>= l; + m_buf >>= 1; + m_avail -= l + 1; + m_pos += l + 1; + return zs + l; + } + + inline uint64_t position() const { return m_pos; } + + private: + inline void fill_buf() + { + m_buf = m_bv->get_word(m_pos); + m_avail = 64; + } + + BitVector const* m_bv; + size_t m_pos; + uint64_t m_buf; + size_t m_avail; + }; + + struct unary_enumerator { + unary_enumerator() : m_data(0), m_position(0), m_buf(0) {} + + unary_enumerator(BitVector const& bv, uint64_t pos) + { + m_data = bv.data(); + m_position = pos; + m_buf = m_data[pos / 64]; + // clear low bits + m_buf &= uint64_t(-1) << (pos % 64); + } + + uint64_t position() const { return m_position; } + + uint64_t next() + { + unsigned long pos_in_word; + uint64_t buf = m_buf; + while (!broadword::lsb(buf, pos_in_word)) { + m_position += 64; + buf = m_data[m_position / 64]; + } + + m_buf = buf & (buf - 1); // clear LSB + m_position = 
(m_position & ~uint64_t(63)) + pos_in_word; + return m_position; + } + + // skip to the k-th one after the current position + void skip(uint64_t k) + { + uint64_t skipped = 0; + uint64_t buf = m_buf; + uint64_t w = 0; + while (skipped + (w = broadword::popcount(buf)) <= k) { + skipped += w; + m_position += 64; + buf = m_data[m_position / 64]; + } + assert(buf); + uint64_t pos_in_word = broadword::select_in_word(buf, k - skipped); + m_buf = buf & (uint64_t(-1) << pos_in_word); + m_position = (m_position & ~uint64_t(63)) + pos_in_word; + } + + // return the position of the k-th one after the current position. + uint64_t skip_no_move(uint64_t k) + { + uint64_t position = m_position; + uint64_t skipped = 0; + uint64_t buf = m_buf; + uint64_t w = 0; + while (skipped + (w = broadword::popcount(buf)) <= k) { + skipped += w; + position += 64; + buf = m_data[position / 64]; + } + assert(buf); + uint64_t pos_in_word = broadword::select_in_word(buf, k - skipped); + position = (position & ~uint64_t(63)) + pos_in_word; + return position; + } + + // skip to the k-th zero after the current position + void skip0(uint64_t k) + { + uint64_t skipped = 0; + uint64_t pos_in_word = m_position % 64; + uint64_t buf = ~m_buf & (uint64_t(-1) << pos_in_word); + uint64_t w = 0; + while (skipped + (w = broadword::popcount(buf)) <= k) { + skipped += w; + m_position += 64; + buf = ~m_data[m_position / 64]; + } + assert(buf); + pos_in_word = broadword::select_in_word(buf, k - skipped); + m_buf = ~buf & (uint64_t(-1) << pos_in_word); + m_position = (m_position & ~uint64_t(63)) + pos_in_word; + } + + private: + uint64_t const* m_data; + uint64_t m_position; + uint64_t m_buf; + }; + + protected: + gsl::span m_bits{}; + std::size_t m_size = 0; +}; + +// struct BitVector { +// using storage_type = std::uint64_t; +// +// BitVector() = default; +// BitVector(gsl::span bits, std::size_t size) : m_bits(bits), m_size(size) +// {} BitVector(BitVector const&) = default; BitVector(BitVector&&) noexcept = 
default; +// BitVector& operator=(BitVector const&) = default; +// BitVector& operator=(BitVector&&) noexcept = default; +// ~BitVector() = default; +// +// [[nodiscard]] inline auto size() const -> std::size_t { return m_size; } +// +// [[nodiscard]] inline auto operator[](std::size_t pos) const -> bool +// { +// Expects(pos < m_size); +// std::uint64_t block = pos / 64; +// Expects(block < m_bits.size()); +// std::uint64_t shift = pos % 64; +// return ((m_bits[block] >> shift) & 1U) != 0U; +// } +// +// [[nodiscard]] inline auto get_bits(uint64_t pos, uint64_t len) const -> std::uint64_t +// { +// Expects(pos + len <= size()); +// Expects(len <= 64); +// if (len == 0U) { +// return 0; +// } +// uint64_t block = pos / 64; +// uint64_t shift = pos % 64; +// uint64_t mask = std::numeric_limits::max() +// >> (std::numeric_limits::digits - len); +// if (shift + len <= 64) { +// return m_bits[block] >> shift & mask; +// } +// return (m_bits[block] >> shift) | (m_bits[block + 1] << (64 - shift) & mask); +// } +// +// // same as get_bits(pos, 64) but it can extend further size(), padding with zeros +// [[nodiscard]] inline auto get_word(std::uint64_t pos) const -> std::uint64_t +// { +// Expects(pos < size()); +// std::uint64_t block = pos / 64; +// std::uint64_t shift = pos % 64; +// std::uint64_t word = m_bits[block] >> shift; +// if (shift > 0U && block + 1U < m_bits.size()) { +// word |= m_bits[block + 1] << (64 - shift); +// } +// return word; +// } +// +// // unsafe and fast version of get_word, it retrieves at least 56 bits +// [[nodiscard]] inline auto get_word56(std::uint64_t pos) const -> std::uint64_t +// { +// // XXX check endianness? 
+// const char* ptr = reinterpret_cast(m_bits.data()); +// return *(reinterpret_cast(ptr + pos / 8)) >> (pos % 8); +// } +// +// [[nodiscard]] inline auto predecessor0(std::uint64_t pos) const -> std::uint64_t +// { +// Expects(pos < m_size); +// std::uint64_t block = pos / 64; +// std::uint64_t shift = 64 - pos % 64 - 1; +// std::uint64_t word = ~m_bits[block]; +// word = (word << shift) >> shift; +// +// std::uint64_t ret; +// while (broadword::msb(word, ret) == 0U) { +// Expects(block); +// word = ~m_bits[--block]; +// }; +// return block * 64 + ret; +// } +// +// [[nodiscard]] inline auto successor0(std::uint64_t pos) const -> std::uint64_t +// { +// Expects(pos < m_size); +// std::uint64_t block = pos / 64; +// std::uint64_t shift = pos % 64; +// std::uint64_t word = (~m_bits[block] >> shift) << shift; +// +// std::uint64_t ret; +// while (broadword::lsb(word, ret) == 0U) { +// ++block; +// Expects(block < m_bits.size()); +// word = ~m_bits[block]; +// }; +// return block * 64 + ret; +// } +// +// [[nodiscard]] inline auto predecessor1(std::uint64_t pos) const -> std::uint64_t +// { +// Expects(pos < m_size); +// std::uint64_t block = pos / 64; +// std::uint64_t shift = 64 - pos % 64 - 1; +// std::uint64_t word = m_bits[block]; +// word = (word << shift) >> shift; +// +// std::uint64_t ret; +// while (broadword::msb(word, ret) == 0) { +// Expects(block); +// word = m_bits[--block]; +// }; +// return block * 64 + ret; +// } +// +// [[nodiscard]] inline auto successor1(std::uint64_t pos) const -> std::uint64_t +// { +// Expects(pos < m_size); +// std::uint64_t block = pos / 64; +// std::uint64_t shift = pos % 64; +// std::uint64_t word = (m_bits[block] >> shift) << shift; +// +// std::uint64_t ret; +// while (broadword::lsb(word, ret) == 0U) { +// ++block; +// Expects(block < m_bits.size()); +// word = m_bits[block]; +// }; +// return block * 64 + ret; +// } +// +// [[nodiscard]] inline auto data() const -> std::uint64_t const* { return m_bits.data(); } +// +// 
struct enumerator { +// enumerator() : m_pos(uint64_t(-1)) {} +// +// enumerator(BitVector const& bv, size_t pos) : m_bv(&bv), m_pos(pos) +// { +// intrinsics::prefetch(std::next(m_bv->data(), m_pos / 64)); +// fill_buf(); +// } +// +// inline auto next() -> bool +// { +// if (m_avail == 0) { +// fill_buf(); +// } +// bool b = (m_buf & 1U) > 0; +// m_buf >>= 1U; +// m_avail -= 1; +// m_pos += 1; +// return b; +// } +// +// inline auto take(size_t l) -> std::uint64_t +// { +// if (m_avail < l) { +// fill_buf(); +// } +// std::uint64_t val; +// if (l != 64) { +// val = m_buf & ((std::uint64_t(1) << l) - 1); +// m_buf >>= l; +// } else { +// val = m_buf; +// } +// m_avail -= l; +// m_pos += l; +// return val; +// } +// +// inline uint64_t skip_zeros() +// { +// uint64_t zs = 0; +// // XXX the loop may be optimized by aligning access +// while (m_buf == 0) { +// m_pos += m_avail; +// zs += m_avail; +// m_avail = 0; +// fill_buf(); +// } +// +// uint64_t l = broadword::lsb(m_buf); +// m_buf >>= l; +// m_buf >>= 1U; +// m_avail -= l + 1; +// m_pos += l + 1; +// return zs + l; +// } +// +// [[nodiscard]] inline auto position() const -> std::uint64_t { return m_pos; } +// +// private: +// inline void fill_buf() +// { +// m_buf = m_bv->get_word(m_pos); +// m_avail = 64; +// } +// +// BitVector const* m_bv = nullptr; +// std::size_t m_pos; +// std::uint64_t m_buf = 0; +// std::size_t m_avail = 0; +// }; +// +// struct unary_enumerator { +// unary_enumerator() = default; +// unary_enumerator(BitVector const& bv, uint64_t pos) +// { +// m_data = bv.data(); +// m_position = pos; +// m_buf = m_data[pos / 64]; +// // clear low bits +// m_buf &= uint64_t(-1) << (pos % 64); +// } +// +// [[nodiscard]] auto position() const -> std::uint64_t { return m_position; } +// +// uint64_t next() +// { +// std::uint64_t pos_in_word; +// std::uint64_t buf = m_buf; +// while (broadword::lsb(buf, pos_in_word) == 0) { +// m_position += 64; +// buf = m_data[m_position / 64]; +// } +// +// m_buf = 
buf & (buf - 1); // clear LSB +// m_position = (m_position & ~uint64_t(63)) + pos_in_word; +// return m_position; +// } +// +// // skip to the k-th one after the current position +// void skip(uint64_t k) +// { +// uint64_t skipped = 0; +// uint64_t buf = m_buf; +// uint64_t w = 0; +// while (skipped + (w = broadword::popcount(buf)) <= k) { +// skipped += w; +// m_position += 64; +// buf = m_data[m_position / 64]; +// } +// Expects(buf); +// uint64_t pos_in_word = broadword::select_in_word(buf, k - skipped); +// m_buf = buf & (uint64_t(-1) << pos_in_word); +// m_position = (m_position & ~uint64_t(63)) + pos_in_word; +// } +// +// // return the position of the k-th one after the current position. +// uint64_t skip_no_move(uint64_t k) +// { +// uint64_t position = m_position; +// uint64_t skipped = 0; +// uint64_t buf = m_buf; +// uint64_t w = 0; +// while (skipped + (w = broadword::popcount(buf)) <= k) { +// skipped += w; +// position += 64; +// buf = m_data[position / 64]; +// } +// Expects(buf); +// uint64_t pos_in_word = broadword::select_in_word(buf, k - skipped); +// position = (position & ~uint64_t(63)) + pos_in_word; +// return position; +// } +// +// // skip to the k-th zero after the current position +// void skip0(uint64_t k) +// { +// uint64_t skipped = 0; +// uint64_t pos_in_word = m_position % 64; +// uint64_t buf = ~m_buf & (uint64_t(-1) << pos_in_word); +// uint64_t w = 0; +// while (skipped + (w = broadword::popcount(buf)) <= k) { +// skipped += w; +// m_position += 64; +// buf = ~m_data[m_position / 64]; +// } +// Expects(buf); +// pos_in_word = broadword::select_in_word(buf, k - skipped); +// m_buf = ~buf & (uint64_t(-1) << pos_in_word); +// m_position = (m_position & ~uint64_t(63)) + pos_in_word; +// } +// +// private: +// std::uint64_t const* m_data; +// std::uint64_t m_position; +// std::uint64_t m_buf; +// }; +// +// private: +// gsl::span m_bits{}; +// std::size_t m_size = 0; +//}; + +} // namespace pisa::v1 diff --git 
a/include/pisa/v1/blocked_cursor.hpp b/include/pisa/v1/blocked_cursor.hpp new file mode 100644 index 000000000..37b60d75c --- /dev/null +++ b/include/pisa/v1/blocked_cursor.hpp @@ -0,0 +1,414 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "util/likely.hpp" +#include "v1/base_index.hpp" +#include "v1/bit_cast.hpp" +#include "v1/cursor_traits.hpp" +#include "v1/encoding_traits.hpp" +#include "v1/types.hpp" +#include "v1/unaligned_span.hpp" + +namespace pisa::v1 { + +/// Non-template base of blocked cursors. +struct BaseBlockedCursor { + using value_type = std::uint32_t; + using offset_type = std::uint32_t; + using size_type = std::uint32_t; + + /// Creates a cursor from the encoded bytes. + BaseBlockedCursor(gsl::span encoded_blocks, + UnalignedSpan block_endpoints, + size_type length, + size_type num_blocks, + size_type block_length) + : m_encoded_blocks(encoded_blocks), + m_block_endpoints(block_endpoints), + m_decoded_block(block_length), + m_length(length), + m_num_blocks(num_blocks), + m_block_length(block_length), + m_current_block({.number = 0, + .offset = 0, + .length = std::min(length, static_cast(m_block_length))}) + { + } + + BaseBlockedCursor(BaseBlockedCursor const&) = default; + BaseBlockedCursor(BaseBlockedCursor&&) noexcept = default; + BaseBlockedCursor& operator=(BaseBlockedCursor const&) = default; + BaseBlockedCursor& operator=(BaseBlockedCursor&&) noexcept = default; + ~BaseBlockedCursor() = default; + + [[nodiscard]] auto operator*() const -> value_type; + [[nodiscard]] auto value() const noexcept -> value_type; + [[nodiscard]] auto empty() const noexcept -> bool; + [[nodiscard]] auto position() const noexcept -> std::size_t; + [[nodiscard]] auto size() const -> std::size_t; + [[nodiscard]] auto sentinel() const -> value_type; + + protected: + struct Block { + std::uint32_t number = 0; + std::uint32_t offset = 0; + std::uint32_t length = 0; + }; + + [[nodiscard]] auto 
block_offset(size_type block) const -> offset_type; + [[nodiscard]] auto decoded_block() -> value_type*; + [[nodiscard]] auto decoded_value(size_type n) -> value_type; + [[nodiscard]] auto encoded_block(offset_type offset) -> uint8_t const*; + [[nodiscard]] auto length() const -> size_type; + [[nodiscard]] auto num_blocks() const -> size_type; + [[nodiscard]] auto current_block() -> Block&; + void update_current_value(value_type val); + void increase_current_value(value_type val); + + private: + gsl::span m_encoded_blocks; + UnalignedSpan m_block_endpoints; + std::vector m_decoded_block{}; + size_type m_length; + size_type m_num_blocks; + size_type m_block_length; + Block m_current_block; + + value_type m_current_value{}; +}; + +template +struct GenericBlockedCursor : public BaseBlockedCursor { + using BaseBlockedCursor::offset_type; + using BaseBlockedCursor::size_type; + using BaseBlockedCursor::value_type; + + GenericBlockedCursor(gsl::span encoded_blocks, + UnalignedSpan block_endpoints, + UnalignedSpan block_last_values, + std::uint32_t length, + std::uint32_t num_blocks) + : BaseBlockedCursor(encoded_blocks, block_endpoints, length, num_blocks, Codec::block_size), + m_block_last_values(block_last_values), + m_current_block_last_value(m_block_last_values.empty() ? value_type{} + : m_block_last_values[0]) + { + reset(); + } + + void reset() { decode_and_update_block(0); } + + /// Advances the cursor to the next position. 
+ void advance() + { + auto& current_block = this->current_block(); + current_block.offset += 1; + if (PISA_UNLIKELY(current_block.offset == current_block.length)) { + if (current_block.number + 1 == num_blocks()) { + update_current_value(sentinel()); + return; + } + decode_and_update_block(current_block.number + 1); + } else { + if constexpr (DeltaEncoded) { + increase_current_value(decoded_value(current_block.offset)); + } else { + update_current_value(decoded_value(current_block.offset)); + } + } + } + + /// Moves the cursor to the position `pos`. + void advance_to_position(std::uint32_t pos) + { + Expects(pos >= position()); + auto& current_block = this->current_block(); + auto block = pos / Codec::block_size; + if (PISA_UNLIKELY(block != current_block.number)) { + decode_and_update_block(block); + } + while (position() < pos) { + current_block.offset += 1; + if constexpr (DeltaEncoded) { + increase_current_value(decoded_value(current_block.offset)); + } else { + update_current_value(decoded_value(current_block.offset)); + } + } + } + + protected: + [[nodiscard]] auto& block_last_values() { return m_block_last_values; } + [[nodiscard]] auto& current_block_last_value() { return m_current_block_last_value; } + + void decode_and_update_block(size_type block) + { + auto block_size = Codec::block_size; + auto const* block_data = encoded_block(block_offset(block)); + auto& current_block = this->current_block(); + current_block.length = + ((block + 1) * block_size <= size()) ? block_size : (size() % block_size); + + if constexpr (DeltaEncoded) { + std::uint32_t first_value = block > 0U ? 
m_block_last_values[block - 1] + 1U : 0U; + m_current_block_last_value = m_block_last_values[block]; + Codec::decode(block_data, + decoded_block(), + m_current_block_last_value - first_value - (current_block.length - 1), + current_block.length); + decoded_block()[0] += first_value; + } else { + Codec::decode(block_data, + decoded_block(), + std::numeric_limits::max(), + current_block.length); + decoded_block()[0] += 1; + } + + current_block.number = block; + current_block.offset = 0U; + update_current_value(decoded_block()[0]); + } + + private: + UnalignedSpan m_block_last_values{}; + value_type m_current_block_last_value{}; +}; + +template +struct DocumentBlockedCursor : public GenericBlockedCursor { + using offset_type = typename GenericBlockedCursor::offset_type; + using size_type = typename GenericBlockedCursor::size_type; + using value_type = typename GenericBlockedCursor::value_type; + + DocumentBlockedCursor(gsl::span encoded_blocks, + UnalignedSpan block_endpoints, + UnalignedSpan block_last_values, + std::uint32_t length, + std::uint32_t num_blocks) + : GenericBlockedCursor( + encoded_blocks, block_endpoints, block_last_values, length, num_blocks) + { + } + + /// Moves the cursor to the next value equal or greater than `value`. 
+ void advance_to_geq(value_type value) + { + auto& current_block = this->current_block(); + if (PISA_UNLIKELY(value > this->current_block_last_value())) { + if (value > this->block_last_values().back()) { + this->update_current_value(this->sentinel()); + return; + } + auto block = current_block.number + 1U; + while (this->block_last_values()[block] < value) { + ++block; + } + this->decode_and_update_block(block); + } + + while (this->value() < value) { + this->increase_current_value(this->decoded_value(++current_block.offset)); + Ensures(current_block.offset < current_block.length); + } + } +}; + +template +struct PayloadBlockedCursor : public GenericBlockedCursor { + using offset_type = typename GenericBlockedCursor::offset_type; + using size_type = typename GenericBlockedCursor::size_type; + using value_type = typename GenericBlockedCursor::value_type; + + PayloadBlockedCursor(gsl::span encoded_blocks, + UnalignedSpan block_endpoints, + std::uint32_t length, + std::uint32_t num_blocks) + : GenericBlockedCursor( + encoded_blocks, block_endpoints, {}, length, num_blocks) + { + } +}; + +template +constexpr auto block_encoding_type() -> std::uint32_t +{ + if constexpr (DeltaEncoded) { + return EncodingId::BlockDelta; + } else { + return EncodingId::Block; + } +} + +template +struct GenericBlockedReader { + using value_type = std::uint32_t; + + void init(BaseIndex const& index) {} + [[nodiscard]] auto read(gsl::span bytes) const + { + std::uint32_t length; + auto begin = reinterpret_cast(bytes.data()); + auto after_length_ptr = pisa::TightVariableByte::decode(begin, &length, 1); + auto length_byte_size = std::distance(begin, after_length_ptr); + auto num_blocks = ceil_div(length, Codec::block_size); + UnalignedSpan block_last_values; + if constexpr (DeltaEncoded) { + block_last_values = UnalignedSpan( + bytes.subspan(length_byte_size, num_blocks * sizeof(value_type))); + } + UnalignedSpan block_endpoints( + bytes.subspan(length_byte_size + 
block_last_values.byte_size(), + (num_blocks - 1) * sizeof(value_type))); + auto encoded_blocks = bytes.subspan(length_byte_size + block_last_values.byte_size() + + block_endpoints.byte_size()); + if constexpr (DeltaEncoded) { + return DocumentBlockedCursor( + encoded_blocks, block_endpoints, block_last_values, length, num_blocks); + } else { + return PayloadBlockedCursor(encoded_blocks, block_endpoints, length, num_blocks); + } + } + + constexpr static auto encoding() -> std::uint32_t + { + return block_encoding_type() + | encoding_traits::encoding_tag::encoding(); + } +}; + +template +using DocumentBlockedReader = GenericBlockedReader; +template +using PayloadBlockedReader = GenericBlockedReader; + +template +struct GenericBlockedWriter { + using value_type = std::uint32_t; + + GenericBlockedWriter() = default; + explicit GenericBlockedWriter([[maybe_unused]] std::size_t num_documents) {} + + constexpr static auto encoding() -> std::uint32_t + { + return block_encoding_type() + | encoding_traits::encoding_tag::encoding(); + } + + void init([[maybe_unused]] pisa::binary_freq_collection const& collection) {} + void push(value_type const& posting) + { + if constexpr (DeltaEncoded) { + if (posting < m_last_value) { + throw std::runtime_error( + fmt::format("Delta-encoded sequences must be monotonic, but {} < {}", + posting, + m_last_value)); + } + } + m_postings.push_back(posting); + m_last_value = posting; + } + + template + [[nodiscard]] auto write(std::basic_ostream& os) const -> std::size_t + { + std::vector buffer; + std::uint32_t length = m_postings.size(); + TightVariableByte::encode_single(length, buffer); + auto block_size = Codec::block_size; + auto num_blocks = ceil_div(length, block_size); + auto begin_block_maxs = buffer.size(); + auto begin_block_endpoints = [&]() { + if constexpr (DeltaEncoded) { + return begin_block_maxs + 4U * num_blocks; + } else { + return begin_block_maxs; + } + }(); + auto begin_blocks = begin_block_endpoints + 4U * (num_blocks - 
1); + buffer.resize(begin_blocks); + + auto iter = m_postings.begin(); + std::vector block_buffer(block_size); + std::uint32_t last_value(-1); + std::uint32_t block_base = 0; + for (auto block = 0; block < num_blocks; ++block) { + auto current_block_size = + ((block + 1) * block_size <= length) ? block_size : (length % block_size); + + std::for_each(block_buffer.begin(), + std::next(block_buffer.begin(), current_block_size), + [&](auto&& elem) { + if constexpr (DeltaEncoded) { + auto value = *iter++; + elem = value - (last_value + 1); + last_value = value; + } else { + elem = *iter++ - 1; + } + }); + + if constexpr (DeltaEncoded) { + std::memcpy( + &buffer[begin_block_maxs + 4U * block], &last_value, sizeof(last_value)); + auto size = buffer.size(); + Codec::encode(block_buffer.data(), + last_value - block_base - (current_block_size - 1), + current_block_size, + buffer); + } else { + Codec::encode(block_buffer.data(), std::uint32_t(-1), current_block_size, buffer); + } + if (block != num_blocks - 1) { + std::size_t endpoint = buffer.size() - begin_blocks; + std::memcpy( + &buffer[begin_block_endpoints + 4U * block], &endpoint, sizeof(last_value)); + } + block_base = last_value + 1; + } + os.write(reinterpret_cast(buffer.data()), buffer.size()); + return buffer.size(); + } + + void reset() + { + m_postings.clear(); + m_last_value = 0; + } + + private: + std::vector m_postings{}; + value_type m_last_value = 0U; +}; + +template +using DocumentBlockedWriter = GenericBlockedWriter; +template +using PayloadBlockedWriter = GenericBlockedWriter; + +template +struct CursorTraits> { + using Value = std::uint32_t; + using Writer = DocumentBlockedWriter; + using Reader = DocumentBlockedReader; +}; + +template +struct CursorTraits> { + using Value = std::uint32_t; + using Writer = PayloadBlockedWriter; + using Reader = PayloadBlockedReader; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor/accumulate.hpp b/include/pisa/v1/cursor/accumulate.hpp new file mode 
100644 index 000000000..9a5f58680 --- /dev/null +++ b/include/pisa/v1/cursor/accumulate.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include "v1/cursor/for_each.hpp" + +namespace pisa::v1 { + +template +[[nodiscard]] constexpr inline auto accumulate(Cursor cursor, Payload init, AccumulateFn accumulate) +{ + for_each(cursor, [&](auto&& cursor) { init = accumulate(init, cursor); }); + return init; +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor/collect.hpp b/include/pisa/v1/cursor/collect.hpp new file mode 100644 index 000000000..f2d68f66b --- /dev/null +++ b/include/pisa/v1/cursor/collect.hpp @@ -0,0 +1,37 @@ +#pragma once + +#include + +namespace pisa::v1 { + +template +auto collect(Cursor &&cursor, Transform transform) +{ + std::vector> vec; + while (not cursor.empty()) { + vec.push_back(transform(cursor)); + cursor.advance(); + } + return vec; +} + +template +auto collect(Cursor &&cursor) +{ + return collect(std::forward(cursor), [](auto &&cursor) { return *cursor; }); +} + +template +auto collect_with_payload(Cursor &&cursor) +{ + return collect(std::forward(cursor), + [](auto &&cursor) { return std::make_pair(*cursor, cursor.payload()); }); +} + +template +auto collect_payloads(Cursor &&cursor) +{ + return collect(std::forward(cursor), [](auto &&cursor) { return cursor.payload(); }); +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor/compact_elias_fano.hpp b/include/pisa/v1/cursor/compact_elias_fano.hpp new file mode 100644 index 000000000..22e050cc2 --- /dev/null +++ b/include/pisa/v1/cursor/compact_elias_fano.hpp @@ -0,0 +1,67 @@ +#include +#include + +#include "util/likely.hpp" + +namespace pisa::v1 { + +struct CompactEliasFanoCursor { + using value_type = std::uint32_t; + + /// Dereferences the current value. 
+ //[[nodiscard]] auto operator*() const -> value_type + //{ + // if (PISA_UNLIKELY(empty())) { + // return sentinel(); + // } + // return bit_cast(gsl::as_bytes(m_bytes.subspan(m_current, sizeof(T)))); + //} + + ///// Alias for `operator*()`. + //[[nodiscard]] auto value() const noexcept -> value_type { return *(*this); } + + ///// Advances the cursor to the next position. + // constexpr void advance() { m_current += sizeof(T); } + + ///// Moves the cursor to the position `pos`. + // constexpr void advance_to_position(std::size_t pos) { m_current = pos * sizeof(T); } + + ///// Moves the cursor to the next value equal or greater than `value`. + // constexpr void advance_to_geq(T value) + //{ + // while (this->value() < value) { + // advance(); + // } + //} + + ///// Returns `true` if there is no elements left. + //[[nodiscard]] constexpr auto empty() const noexcept -> bool + //{ + // return m_current == m_bytes.size(); + //} + + ///// Returns the current position. + //[[nodiscard]] constexpr auto position() const noexcept -> std::size_t + //{ + // return m_current / sizeof(T); + //} + + ///// Returns the number of elements in the list. 
+ //[[nodiscard]] constexpr auto size() const -> std::size_t { return m_bytes.size() / sizeof(T); + //} + + [[nodiscard]] constexpr auto sentinel() const -> value_type + { + return std::numeric_limits::max(); + } + + private: + // BitVector const* m_bv; + // offsets m_of; + + std::size_t m_position; + value_type m_value; + // BitVector::unary_enumerator m_high_enumerator; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor/for_each.hpp b/include/pisa/v1/cursor/for_each.hpp new file mode 100644 index 000000000..170b9c2f0 --- /dev/null +++ b/include/pisa/v1/cursor/for_each.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include + +namespace pisa::v1 { + +template +void for_each(Cursor &&cursor, UnaryOp op) +{ + while (not cursor.empty()) { + op(std::forward(cursor)); + cursor.advance(); + } +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor/labeled_cursor.hpp b/include/pisa/v1/cursor/labeled_cursor.hpp new file mode 100644 index 000000000..3f44e0fff --- /dev/null +++ b/include/pisa/v1/cursor/labeled_cursor.hpp @@ -0,0 +1,63 @@ +#pragma once + +#include + +#include "v1/cursor_traits.hpp" + +namespace pisa::v1 { + +template +struct LabeledCursor { + using Value = typename CursorTraits::Value; + + explicit constexpr LabeledCursor(Cursor cursor, Label label) + : m_cursor(std::move(cursor)), m_label(std::move(label)) + { + } + + [[nodiscard]] constexpr auto operator*() const -> Value { return value(); } + [[nodiscard]] constexpr auto value() const noexcept -> Value { return m_cursor.value(); } + [[nodiscard]] constexpr auto payload() noexcept { return m_cursor.payload(); } + constexpr void advance() { m_cursor.advance(); } + constexpr void advance_to_position(std::size_t pos) { m_cursor.advance_to_position(pos); } + constexpr void advance_to_geq(Value value) { m_cursor.advance_to_geq(value); } + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_cursor.empty(); } + [[nodiscard]] constexpr auto position() const noexcept -> 
std::size_t + { + return m_cursor.position(); + } + // TODO: Support not sized + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_cursor.size(); } + [[nodiscard]] constexpr auto sentinel() const -> Value { return m_cursor.sentinel(); } + [[nodiscard]] constexpr auto label() const -> Label const& { return m_label; } + [[nodiscard]] constexpr auto max_score() const { return m_cursor.max_score(); } + + private: + Cursor m_cursor; + Label m_label; +}; + +template +auto label(Cursor cursor, Label label) +{ + return LabeledCursor(std::move(cursor), std::move(label)); +} + +template +auto label(std::vector cursors, LabelFn&& label_fn) +{ + using label_type = std::decay_t()))>; + std::vector> labeled; + std::transform(cursors.begin(), + cursors.end(), + std::back_inserter(labeled), + [&](Cursor&& cursor) { return label(cursor, label_fn(cursor)); }); + return labeled; +} + +template +struct CursorTraits> { + using Value = typename CursorTraits::Value; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor/lookup_transform.hpp b/include/pisa/v1/cursor/lookup_transform.hpp new file mode 100644 index 000000000..33ff24cbb --- /dev/null +++ b/include/pisa/v1/cursor/lookup_transform.hpp @@ -0,0 +1,61 @@ +#pragma once + +#include + +namespace pisa::v1 { + +/// **Note**: This currently works only for pair cursor with single-term lookup cursors. +/// This callable transforms a cursor by performing lookups to the current document +/// in the given lookup cursors, and then adding the scores that were found. +/// It uses the same short-circuiting rules before each lookup as `UnionLookupJoin`. 
+template +struct LookupTransform { + + LookupTransform(std::vector lookup_cursors, + float lookup_cursors_upper_bound, + AboveThresholdFn above_threshold, + Inspector* inspect = nullptr) + : m_lookup_cursors(std::move(lookup_cursors)), + m_lookup_cursors_upper_bound(lookup_cursors_upper_bound), + m_above_threshold(std::move(above_threshold)), + m_inspect(inspect) + { + } + + template + auto operator()(Cursor& cursor) + { + auto docid = cursor.value(); + auto scores = cursor.payload(); + float score = std::get<0>(scores) + std::get<1>(scores); + auto upper_bound = score + m_lookup_cursors_upper_bound; + for (auto&& lookup_cursor : m_lookup_cursors) { + //if (docid == 2288) { + // std::cout << fmt::format("[checking] doc: {}\tbound: {}\n", docid, upper_bound); + //} + if (not m_above_threshold(upper_bound)) { + return score; + } + lookup_cursor.advance_to_geq(docid); + //std::cout << fmt::format("[b] doc: {}\tbound: {}\n", docid, upper_bound); + if constexpr (not std::is_void_v) { + m_inspect->lookup(); + } + if (PISA_UNLIKELY(lookup_cursor.value() == docid)) { + auto partial_score = lookup_cursor.payload(); + score += partial_score; + upper_bound += partial_score; + } + upper_bound -= lookup_cursor.max_score(); + } + return score; + } + + private: + std::vector m_lookup_cursors; + float m_lookup_cursors_upper_bound; + AboveThresholdFn m_above_threshold; + Inspector* m_inspect; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor/reference.hpp b/include/pisa/v1/cursor/reference.hpp new file mode 100644 index 000000000..398e8e387 --- /dev/null +++ b/include/pisa/v1/cursor/reference.hpp @@ -0,0 +1,45 @@ +#pragma once + +#include + +#include + +#include "v1/cursor_traits.hpp" + +namespace pisa::v1 { + +template +struct CursorRef { + using Value = typename CursorTraits>::Value; + + constexpr CursorRef(Cursor&& cursor) : m_cursor(std::ref(cursor)) {} + [[nodiscard]] constexpr auto operator*() const -> Value { return value(); } + [[nodiscard]] constexpr auto 
value() const noexcept -> Value { return m_cursor.get().value(); } + [[nodiscard]] constexpr auto payload() { return m_cursor.get().payload(); } + constexpr void advance() { m_cursor.get().advance(); } + constexpr void advance_to_geq(Value value) { m_cursor.get().advance_to_geq(value); } + constexpr void advance_to_position(std::size_t pos) { m_cursor.get().advance_to_position(pos); } + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_cursor.get().empty(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return m_cursor.get().position(); + } + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_cursor.get().size(); } + [[nodiscard]] constexpr auto max_score() const { return m_cursor.get().max_score(); } + + private: + std::reference_wrapper> m_cursor; +}; + +template +auto ref(Cursor&& cursor) +{ + return CursorRef(std::forward(cursor)); +} + +template +struct CursorTraits> { + using Value = typename CursorTraits::Value; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor/scoring_cursor.hpp b/include/pisa/v1/cursor/scoring_cursor.hpp new file mode 100644 index 000000000..1fe0d929e --- /dev/null +++ b/include/pisa/v1/cursor/scoring_cursor.hpp @@ -0,0 +1,175 @@ +#pragma once + +#include + +#include + +#include "v1/cursor_traits.hpp" +#include "v1/types.hpp" + +namespace pisa::v1 { + +struct ScoringCursorTag { +}; +struct MaxScoreCursorTag { +}; +struct BlockMaxScoreCursorTag { +}; + +template +struct ScoringCursor { + using Document = decltype(*std::declval()); + using Payload = decltype((std::declval())(std::declval(), + std::declval().payload())); + using Tag = ScoringCursorTag; + + explicit constexpr ScoringCursor(BaseCursor base_cursor, TermScorer scorer) + : m_base_cursor(std::move(base_cursor)), m_scorer(std::move(scorer)) + { + } + constexpr ScoringCursor(ScoringCursor const&) = default; + constexpr ScoringCursor(ScoringCursor&&) noexcept = default; + constexpr ScoringCursor& 
operator=(ScoringCursor const&) = default; + constexpr ScoringCursor& operator=(ScoringCursor&&) noexcept = default; + ~ScoringCursor() = default; + + [[nodiscard]] constexpr auto operator*() const -> Document { return value(); } + [[nodiscard]] constexpr auto value() const noexcept -> Document + { + return m_base_cursor.value(); + } + [[nodiscard]] constexpr auto payload() noexcept + { + return m_scorer(m_base_cursor.value(), m_base_cursor.payload()); + } + constexpr void advance() { m_base_cursor.advance(); } + constexpr void advance_to_position(std::size_t pos) { m_base_cursor.advance_to_position(pos); } + constexpr void advance_to_geq(Document value) { m_base_cursor.advance_to_geq(value); } + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_base_cursor.empty(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return m_base_cursor.position(); + } + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_base_cursor.size(); } + [[nodiscard]] constexpr auto sentinel() const -> Document { return m_base_cursor.sentinel(); } + + private: + BaseCursor m_base_cursor; + TermScorer m_scorer; +}; + +template +struct MaxScoreCursor { + using Document = decltype(*std::declval()); + using Payload = decltype(std::declval().payload()); + using Tag = MaxScoreCursorTag; + + constexpr MaxScoreCursor(BaseCursor base_cursor, ScoreT max_score) + : m_base_cursor(std::move(base_cursor)), m_max_score(max_score) + { + } + constexpr MaxScoreCursor(MaxScoreCursor const&) = default; + constexpr MaxScoreCursor(MaxScoreCursor&&) noexcept = default; + constexpr MaxScoreCursor& operator=(MaxScoreCursor const&) = default; + constexpr MaxScoreCursor& operator=(MaxScoreCursor&&) noexcept = default; + ~MaxScoreCursor() = default; + + [[nodiscard]] constexpr auto operator*() const -> Document { return m_base_cursor.value(); } + [[nodiscard]] constexpr auto value() const noexcept -> Document + { + return m_base_cursor.value(); + } + 
[[nodiscard]] constexpr auto payload() noexcept { return m_base_cursor.payload(); } + constexpr void advance() { m_base_cursor.advance(); } + constexpr void advance_to_position(std::size_t pos) { m_base_cursor.advance_to_position(pos); } + constexpr void advance_to_geq(Document value) { m_base_cursor.advance_to_geq(value); } + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_base_cursor.empty(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return m_base_cursor.position(); + } + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_base_cursor.size(); } + [[nodiscard]] constexpr auto sentinel() const -> Document { return m_base_cursor.sentinel(); } + [[nodiscard]] constexpr auto max_score() const -> ScoreT { return m_max_score; } + + private: + BaseCursor m_base_cursor; + float m_max_score; +}; + +template +struct BlockMaxScoreCursor { + using Document = decltype(*std::declval()); + using Payload = decltype(std::declval().payload()); + using Tag = BlockMaxScoreCursorTag; + + constexpr BlockMaxScoreCursor(BaseScoredCursor base_cursor, + BlockMaxCursor block_max_cursor, + ScoreT max_score) + : m_base_cursor(std::move(base_cursor)), + m_block_max_cursor(block_max_cursor), + m_max_score(max_score) + { + } + constexpr BlockMaxScoreCursor(BlockMaxScoreCursor const&) = default; + constexpr BlockMaxScoreCursor(BlockMaxScoreCursor&&) noexcept = default; + constexpr BlockMaxScoreCursor& operator=(BlockMaxScoreCursor const&) = default; + constexpr BlockMaxScoreCursor& operator=(BlockMaxScoreCursor&&) noexcept = default; + ~BlockMaxScoreCursor() = default; + + [[nodiscard]] constexpr auto operator*() const -> Document { return m_base_cursor.value(); } + [[nodiscard]] constexpr auto value() const noexcept -> Document + { + return m_base_cursor.value(); + } + [[nodiscard]] constexpr auto payload() noexcept { return m_base_cursor.payload(); } + constexpr void advance() { m_base_cursor.advance(); } + constexpr 
void advance_to_position(std::size_t pos) { m_base_cursor.advance_to_position(pos); } + constexpr void advance_to_geq(Document value) { m_base_cursor.advance_to_geq(value); } + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_base_cursor.empty(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return m_base_cursor.position(); + } + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_base_cursor.size(); } + [[nodiscard]] constexpr auto sentinel() const -> Document { return m_base_cursor.sentinel(); } + [[nodiscard]] constexpr auto max_score() const -> float { return m_max_score; } + [[nodiscard]] constexpr auto block_max_docid() { return m_block_max_cursor.value(); } + [[nodiscard]] constexpr auto block_max_score() { return m_block_max_cursor.payload(); } + [[nodiscard]] constexpr auto block_max_score(DocId docid) + { + m_block_max_cursor.advance_to_geq(docid); + return m_block_max_cursor.payload(); + } + + private: + BaseScoredCursor m_base_cursor; + BlockMaxCursor m_block_max_cursor; + float m_max_score; +}; + +template +[[nodiscard]] auto block_max_score_cursor(BaseScoredCursor base_cursor, + BlockMaxCursor block_max_cursor, + ScoreT max_score) +{ + return BlockMaxScoreCursor( + std::move(base_cursor), std::move(block_max_cursor), max_score); +} + +template +struct CursorTraits> { + using Value = typename CursorTraits::Value; +}; + +template +struct CursorTraits> { + using Value = typename CursorTraits::Value; +}; + +template +struct CursorTraits> { + using Value = typename CursorTraits::Value; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor/transform.hpp b/include/pisa/v1/cursor/transform.hpp new file mode 100644 index 000000000..f19688981 --- /dev/null +++ b/include/pisa/v1/cursor/transform.hpp @@ -0,0 +1,85 @@ +#pragma once + +#include + +#include + +namespace pisa::v1 { + +template +struct TransformCursor { + using Value = + decltype(std::declval(std::declval>())); + + 
constexpr TransformCursor(Cursor cursor, TransformFn transform) + : m_cursor(std::move(cursor)), m_transform(std::move(transform)) + { + } + TransformCursor(TransformCursor&&) noexcept = default; + TransformCursor(TransformCursor const&) = default; + TransformCursor& operator=(TransformCursor&&) noexcept = default; + TransformCursor& operator=(TransformCursor const&) = default; + ~TransformCursor() = default; + + [[nodiscard]] constexpr auto operator*() const -> Value { return value(); } + [[nodiscard]] constexpr auto value() const noexcept -> Value + { + return m_transform(m_cursor.value()); + } + constexpr void advance() { m_cursor.advance(); } + constexpr void advance_to_position(std::size_t pos) { m_cursor.advance_to_position(pos); } + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_cursor.empty(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return m_cursor.position(); + } + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_cursor.size(); } + + private: + Cursor m_cursor; + TransformFn m_transform; +}; + +template +auto transform(Cursor cursor, TransformFn transform) +{ + return TransformCursor(std::move(cursor), std::move(transform)); +} + +template +struct TransformPayloadCursor { + constexpr TransformPayloadCursor(Cursor cursor, TransformFn transform) + : m_cursor(std::move(cursor)), m_transform(std::move(transform)) + { + } + TransformPayloadCursor(TransformPayloadCursor&&) noexcept = default; + TransformPayloadCursor(TransformPayloadCursor const&) = default; + TransformPayloadCursor& operator=(TransformPayloadCursor&&) noexcept = default; + TransformPayloadCursor& operator=(TransformPayloadCursor const&) = default; + ~TransformPayloadCursor() = default; + + [[nodiscard]] constexpr auto operator*() const { return value(); } + [[nodiscard]] constexpr auto value() const noexcept { return m_cursor.value(); } + [[nodiscard]] constexpr auto payload() { return m_transform(m_cursor); } + 
constexpr void advance() { m_cursor.advance(); } + constexpr void advance_to_position(std::size_t pos) { m_cursor.advance_to_position(pos); } + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_cursor.empty(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return m_cursor.position(); + } + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_cursor.size(); } + [[nodiscard]] constexpr auto sentinel() const { return m_cursor.sentinel(); } + + private: + Cursor m_cursor; + TransformFn m_transform; +}; + +template +auto transform_payload(Cursor cursor, TransformFn transform) +{ + return TransformPayloadCursor(std::move(cursor), std::move(transform)); +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor_accumulator.hpp b/include/pisa/v1/cursor_accumulator.hpp new file mode 100644 index 000000000..f12b76374 --- /dev/null +++ b/include/pisa/v1/cursor_accumulator.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include +#include + +namespace pisa::v1::accumulators { + +struct Add { + template + auto operator()(Score&& score, Cursor&& cursor) + { + score += cursor.payload(); + return score; + } +}; + +template +struct InspectAdd { + constexpr explicit InspectAdd(Inspect* inspect) : m_inspect(inspect) {} + + template + auto operator()(Score&& score, Cursor&& cursor, std::size_t /* term_idx */) + { + if constexpr (not std::is_void_v) { + m_inspect->posting(); + } + score += cursor.payload(); + return score; + } + + private: + Inspect* m_inspect; +}; + +} // namespace pisa::v1::accumulators diff --git a/include/pisa/v1/cursor_intersection.hpp b/include/pisa/v1/cursor_intersection.hpp new file mode 100644 index 000000000..ccb8812c7 --- /dev/null +++ b/include/pisa/v1/cursor_intersection.hpp @@ -0,0 +1,134 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +#include "util/likely.hpp" + +namespace pisa::v1 { + +/// Transforms a list of cursors into one cursor by lazily 
merging them together +/// into an intersection. +template +struct CursorIntersection { + using Cursor = typename CursorContainer::value_type; + using iterator_category = + typename std::iterator_traits::iterator_category; + static_assert(std::is_base_of(), + "cursors must be stored in a random access container"); + using Value = std::decay_t())>; + + constexpr CursorIntersection(CursorContainer cursors, Payload init, AccumulateFn accumulate) + : m_unordered_cursors(std::move(cursors)), + m_init(std::move(init)), + m_accumulate(std::move(accumulate)), + m_cursor_mapping(m_unordered_cursors.size()) + { + Expects(not m_unordered_cursors.empty()); + std::iota(m_cursor_mapping.begin(), m_cursor_mapping.end(), 0); + auto order = [&](auto lhs, auto rhs) { + return m_unordered_cursors[lhs].size() < m_unordered_cursors[rhs].size(); + }; + std::sort(m_cursor_mapping.begin(), m_cursor_mapping.end(), order); + std::transform(m_cursor_mapping.begin(), + m_cursor_mapping.end(), + std::back_inserter(m_cursors), + [&](auto idx) { return std::ref(m_unordered_cursors[idx]); }); + m_sentinel = std::min_element(m_unordered_cursors.begin(), + m_unordered_cursors.end(), + [](auto const &lhs, auto const &rhs) { + return lhs.sentinel() < rhs.sentinel(); + }) + ->sentinel(); + m_candidate = *m_cursors[0].get(); + advance(); + } + + [[nodiscard]] constexpr auto operator*() const -> Value { return m_current_value; } + [[nodiscard]] constexpr auto value() const noexcept -> Value { return m_current_value; } + + constexpr void advance() + { + while (PISA_LIKELY(m_candidate < sentinel())) { + for (; m_next_cursor < m_cursors.size(); ++m_next_cursor) { + Cursor &cursor = m_cursors[m_next_cursor]; + cursor.advance_to_geq(m_candidate); + if (*cursor != m_candidate) { + m_candidate = *cursor; + m_next_cursor = 0; + break; + } + } + if (m_next_cursor == m_cursors.size()) { + m_current_payload = m_init; + for (auto idx = 0; idx < m_cursors.size(); ++idx) { + m_current_payload = m_accumulate( + 
m_current_payload, m_cursors[idx].get(), m_cursor_mapping[idx]); + } + m_cursors[0].get().advance(); + m_current_value = std::exchange(m_candidate, *m_cursors[0].get()); + m_next_cursor = 1; + return; + } + } + m_current_value = sentinel(); + m_current_payload = m_init; + } + + [[nodiscard]] constexpr auto payload() const noexcept -> Payload const & + { + return m_current_payload; + } + + constexpr void advance_to_position(std::size_t pos); // TODO(michal) + constexpr void advance_to_geq(Value value); // TODO(michal) + + [[nodiscard]] constexpr auto empty() const noexcept -> bool + { + return m_current_value >= sentinel(); + } + + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; // TODO(michal) + [[nodiscard]] constexpr auto sentinel() const noexcept -> std::size_t { return m_sentinel; } + [[nodiscard]] constexpr auto size() const -> std::size_t = delete; + + private: + CursorContainer m_unordered_cursors; + Payload m_init; + AccumulateFn m_accumulate; + std::vector m_cursor_mapping; + + std::vector> m_cursors; + Value m_current_value{}; + Value m_candidate{}; + Value m_sentinel{}; + Payload m_current_payload{}; + std::uint32_t m_next_cursor = 1; +}; + +template +[[nodiscard]] constexpr inline auto intersect(CursorContainer cursors, + Payload init, + AccumulateFn accumulate) +{ + return CursorIntersection( + std::move(cursors), std::move(init), std::move(accumulate)); +} + +template +[[nodiscard]] constexpr inline auto intersect(std::initializer_list cursors, + Payload init, + AccumulateFn accumulate) +{ + std::vector cursor_container(cursors); + return CursorIntersection, Payload, AccumulateFn>( + std::move(cursor_container), std::move(init), std::move(accumulate)); +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor_traits.hpp b/include/pisa/v1/cursor_traits.hpp new file mode 100644 index 000000000..186962256 --- /dev/null +++ b/include/pisa/v1/cursor_traits.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include "v1/types.hpp" + 
+namespace pisa::v1 { + +template +struct CursorTraits; + +template +struct EncodingTraits; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor_union.hpp b/include/pisa/v1/cursor_union.hpp new file mode 100644 index 000000000..81fd9d303 --- /dev/null +++ b/include/pisa/v1/cursor_union.hpp @@ -0,0 +1,375 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#include "util/likely.hpp" +#include "v1/algorithm.hpp" + +namespace pisa::v1 { + +template +void init_payload(T& payload, T const& initial_value) +{ + payload = initial_value; +} + +template <> +inline void init_payload(std::vector& payload, std::vector const& initial_value) +{ + std::copy(initial_value.begin(), initial_value.end(), payload.begin()); +} + +/// Transforms a list of cursors into one cursor by lazily merging them together. +template +struct CursorUnion { + using Cursor = typename CursorContainer::value_type; + using iterator_category = + typename std::iterator_traits::iterator_category; + static_assert(std::is_base_of(), + "cursors must be stored in a random access container"); + using Value = std::decay_t())>; + + constexpr CursorUnion(CursorContainer cursors, Payload init, AccumulateFn accumulate) + : m_cursors(std::move(cursors)), + m_init(std::move(init)), + m_accumulate(std::move(accumulate)), + m_size(std::nullopt) + { + m_current_payload = m_init; + if (m_cursors.empty()) { + m_current_value = std::numeric_limits::max(); + } else { + m_next_docid = min_value(m_cursors); + m_sentinel = min_sentinel(m_cursors); + advance(); + } + } + + [[nodiscard]] constexpr auto operator*() const noexcept -> Value { return m_current_value; } + [[nodiscard]] constexpr auto value() const noexcept -> Value { return m_current_value; } + [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& + { + return m_current_payload; + } + [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_sentinel; } + + constexpr void advance() + { + if 
(PISA_UNLIKELY(m_next_docid == m_sentinel)) { + m_current_value = m_sentinel; + ::pisa::v1::init_payload(m_current_payload, m_init); + } else { + ::pisa::v1::init_payload(m_current_payload, m_init); + m_current_value = m_next_docid; + m_next_docid = m_sentinel; + std::size_t cursor_idx = 0; + for (auto& cursor : m_cursors) { + if (cursor.value() == m_current_value) { + m_current_payload = m_accumulate(m_current_payload, cursor, cursor_idx); + cursor.advance(); + } + if (auto value = cursor.value(); value < m_next_docid) { + m_next_docid = value; + } + ++cursor_idx; + } + } + } + + constexpr void advance_to_position(std::size_t pos); // TODO(michal) + constexpr void advance_to_geq(Value value); // TODO(michal) + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; // TODO(michal) + + [[nodiscard]] constexpr auto empty() const noexcept -> bool + { + return m_current_value >= sentinel(); + } + + private: + CursorContainer m_cursors; + Payload m_init; + AccumulateFn m_accumulate; + std::optional m_size; + + Value m_current_value{}; + Value m_sentinel{}; + Payload m_current_payload{}; + std::uint32_t m_next_docid{}; +}; + +/// Transforms a list of cursors into one cursor by lazily merging them together. 
+// template +// struct CursorFlatUnion { +// using Cursor = typename CursorContainer::value_type; +// using iterator_category = +// typename std::iterator_traits::iterator_category; +// static_assert(std::is_base_of(), +// "cursors must be stored in a random access container"); +// using Value = std::decay_t())>; +// +// explicit constexpr CursorFlatUnion(CursorContainer cursors) : m_cursors(std::move(cursors)) +// { +// m_current_payload = m_init; +// if (m_cursors.empty()) { +// m_current_value = std::numeric_limits::max(); +// } else { +// m_next_docid = min_value(m_cursors); +// m_sentinel = min_sentinel(m_cursors); +// advance(); +// } +// } +// +// [[nodiscard]] constexpr auto operator*() const noexcept -> Value { return m_current_value; } +// [[nodiscard]] constexpr auto value() const noexcept -> Value { return m_current_value; } +// [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& +// { +// return m_current_payload; +// } +// [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_sentinel; } +// +// constexpr void advance() +// { +// //if (PISA_UNLIKELY(m_next_docid == m_sentinel)) { +// // m_current_value = m_sentinel; +// //} else { +// // m_current_value = m_next_docid; +// // m_next_docid = m_sentinel; +// // std::size_t cursor_idx = 0; +// // for (auto& cursor : m_cursors) { +// // if (cursor.value() == m_current_value) { +// // m_current_payload = m_accumulate(m_current_payload, cursor, cursor_idx); +// // cursor.advance(); +// // } +// // if (cursor.value() < m_next_docid) { +// // m_next_docid = cursor.value(); +// // } +// // ++cursor_idx; +// // } +// //} +// } +// +// constexpr void advance_to_position(std::size_t pos); // TODO(michal) +// constexpr void advance_to_geq(Value value); // TODO(michal) +// [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; // TODO(michal) +// +// [[nodiscard]] constexpr auto empty() const noexcept -> bool +// { +// return m_current_value >= 
sentinel(); +// } +// +// private: +// CursorContainer m_cursors; +// +// std::size_t m_cursor_idx = 0; +// Value m_current_value{}; +// Value m_sentinel{}; +// Payload m_current_payload{}; +// std::uint32_t m_next_docid{}; +//}; + +/// Transforms a list of cursors into one cursor by lazily merging them together. +template +struct VariadicCursorUnion { + using Value = std::decay_t(std::declval()))>; + + constexpr VariadicCursorUnion(Payload init, + CursorsTuple cursors, + std::tuple accumulate) + : m_cursors(std::move(cursors)), + m_init(std::move(init)), + m_accumulate(std::move(accumulate)), + m_size(std::nullopt) + { + m_next_docid = std::numeric_limits::max(); + m_sentinel = std::numeric_limits::min(); + for_each_cursor([&](auto&& cursor, [[maybe_unused]] auto&& fn) { + if (cursor.value() < m_next_docid) { + m_next_docid = cursor.value(); + } + }); + for_each_cursor([&](auto&& cursor, [[maybe_unused]] auto&& fn) { + if (cursor.sentinel() > m_sentinel) { + m_sentinel = cursor.sentinel(); + } + }); + advance(); + } + + template + void for_each_cursor(Fn&& fn) + { + std::apply( + [&](auto&&... cursor) { + std::apply([&](auto&&... 
accumulate) { (fn(cursor, accumulate), ...); }, + m_accumulate); + }, + m_cursors); + } + + [[nodiscard]] constexpr auto operator*() const noexcept -> Value { return m_current_value; } + [[nodiscard]] constexpr auto value() const noexcept -> Value { return m_current_value; } + [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& + { + return m_current_payload; + } + [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_sentinel; } + + constexpr void advance() + { + if (PISA_UNLIKELY(m_next_docid == m_sentinel)) { + m_current_value = m_sentinel; + m_current_payload = m_init; + } else { + m_current_payload = m_init; + m_current_value = m_next_docid; + m_next_docid = m_sentinel; + std::size_t cursor_idx = 0; + for_each_cursor([&](auto&& cursor, auto&& accumulate) { + if (cursor.value() == m_current_value) { + m_current_payload = accumulate(m_current_payload, cursor, cursor_idx); + cursor.advance(); + } + if (cursor.value() < m_next_docid) { + m_next_docid = cursor.value(); + } + ++cursor_idx; + }); + } + } + + [[nodiscard]] constexpr auto empty() const noexcept -> bool + { + return m_current_value >= sentinel(); + } + + private: + CursorsTuple m_cursors; + Payload m_init; + std::tuple m_accumulate; + std::optional m_size; + + Value m_current_value{}; + Value m_sentinel{}; + Payload m_current_payload{}; + std::uint32_t m_next_docid{}; +}; + +///// Transforms a list of cursors into one cursor by lazily merging them together. 
+// template +// struct PairCursorUnion { +// using Value = std::pair())>, +// std::decay_t())>>; +// +// constexpr PairCursorUnion(Payload init, +// Cursor1 cursor1, +// Cursor2 cursor2, +// AccumulateFn1 accumulate1, +// AccumulateFn2 accumulate2) +// : m_cursors(std::move(cursors)), +// m_init(std::move(init)), +// m_accumulate(std::move(accumulate)), +// m_size(std::nullopt) +// { +// m_next_docid = std::numeric_limits::max(); +// m_sentinel = std::numeric_limits::max(); +// for_each_cursor([&](auto&& cursor, [[maybe_unused]] auto&& fn) { +// if (cursor.value() < m_next_docid) { +// m_next_docid = cursor.value(); +// } +// }); +// for_each_cursor([&](auto&& cursor, [[maybe_unused]] auto&& fn) { +// std::cerr << fmt::format("Sentinel: {} v. current {}\n", cursor.sentinel(), +// m_sentinel); if (cursor.sentinel() < m_sentinel) { +// m_sentinel = cursor.sentinel(); +// } +// }); +// advance(); +// } +// +// template +// void for_each_cursor(Fn&& fn) +// { +// std::apply( +// [&](auto&&... cursor) { +// std::apply([&](auto&&... 
accumulate) { (fn(cursor, accumulate), ...); }, +// m_accumulate); +// }, +// m_cursors); +// } +// +// [[nodiscard]] constexpr auto operator*() const noexcept -> Value { return m_current_value; } +// [[nodiscard]] constexpr auto value() const noexcept -> Value { return m_current_value; } +// [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& +// { +// return m_current_payload; +// } +// [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_sentinel; } +// +// constexpr void advance() +// { +// if (PISA_UNLIKELY(m_next_docid == m_sentinel)) { +// m_current_value = m_sentinel; +// m_current_payload = m_init; +// } else { +// m_current_payload = m_init; +// m_current_value = m_next_docid; +// m_next_docid = m_sentinel; +// std::size_t cursor_idx = 0; +// for_each_cursor([&](auto&& cursor, auto&& accumulate) { +// if (cursor.value() == m_current_value) { +// m_current_payload = accumulate(m_current_payload, cursor, cursor_idx); +// cursor.advance(); +// } +// if (cursor.value() < m_next_docid) { +// m_next_docid = cursor.value(); +// } +// ++cursor_idx; +// }); +// } +// } +// +// [[nodiscard]] constexpr auto empty() const noexcept -> bool +// { +// return m_current_value >= sentinel(); +// } +// +// private: +// CursorsTuple m_cursors; +// Payload m_init; +// std::tuple m_accumulate; +// std::optional m_size; +// +// Value m_current_value{}; +// Value m_sentinel{}; +// Payload m_current_payload{}; +// std::uint32_t m_next_docid{}; +//}; + +template +[[nodiscard]] constexpr inline auto union_merge(CursorContainer cursors, + Payload init, + AccumulateFn accumulate) +{ + return CursorUnion( + std::move(cursors), std::move(init), std::move(accumulate)); +} + +template +[[nodiscard]] constexpr inline auto variadic_union_merge(Payload init, + std::tuple cursors, + std::tuple accumulate) +{ + return VariadicCursorUnion, AccumulateFn...>( + std::move(init), std::move(cursors), std::move(accumulate)); +} + +} // namespace 
pisa::v1 diff --git a/include/pisa/v1/daat_and.hpp b/include/pisa/v1/daat_and.hpp new file mode 100644 index 000000000..0f3bc0343 --- /dev/null +++ b/include/pisa/v1/daat_and.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include "topk_queue.hpp" +#include "v1/cursor/for_each.hpp" +#include "v1/query.hpp" + +namespace pisa::v1 { + +template +auto daat_and(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) +{ + auto const& term_ids = query.get_term_ids(); + std::vector cursors; + std::transform(term_ids.begin(), term_ids.end(), std::back_inserter(cursors), [&](auto term) { + return index.scored_cursor(term, scorer); + }); + auto intersection = + v1::intersect(std::move(cursors), 0.0F, [](auto& score, auto& cursor, auto /* term_idx */) { + score += cursor.payload(); + return score; + }); + v1::for_each(intersection, [&](auto& cursor) { topk.insert(cursor.payload(), *cursor); }); + return topk; +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/daat_or.hpp b/include/pisa/v1/daat_or.hpp new file mode 100644 index 000000000..e9e919b16 --- /dev/null +++ b/include/pisa/v1/daat_or.hpp @@ -0,0 +1,59 @@ +#pragma once + +#include + +#include "topk_queue.hpp" +#include "v1/cursor/for_each.hpp" +#include "v1/inspect_query.hpp" +#include "v1/query.hpp" + +namespace pisa::v1 { + +template +auto daat_or(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + Inspect* inspect = nullptr) +{ + std::vector cursors; + std::transform(query.get_term_ids().begin(), + query.get_term_ids().end(), + std::back_inserter(cursors), + [&](auto term) { return index.scored_cursor(term, scorer); }); + auto cunion = v1::union_merge( + std::move(cursors), 0.0F, [&](auto& score, auto& cursor, auto /* term_idx */) { + if constexpr (not std::is_void_v) { + inspect->posting(); + } + score += cursor.payload(); + return score; + }); + v1::for_each(cunion, [&](auto& cursor) { + if constexpr (not std::is_void_v) { + inspect->document(); + if 
(topk.insert(cursor.payload(), cursor.value())) { + inspect->insert(); + } + } else { + topk.insert(cursor.payload(), cursor.value()); + } + }); + return topk; +} + +template +struct InspectDaatOr : Inspect { + + InspectDaatOr(Index const& index, Scorer const& scorer) + : Inspect(index, scorer) + { + } + + void run(Query const& query, Index const& index, Scorer const& scorer, topk_queue topk) override + { + daat_or(query, index, std::move(topk), scorer, this); + } +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/default_index_runner.hpp b/include/pisa/v1/default_index_runner.hpp new file mode 100644 index 000000000..cec99a48a --- /dev/null +++ b/include/pisa/v1/default_index_runner.hpp @@ -0,0 +1,36 @@ +#pragma once + +#include "index_types.hpp" +#include "v1/bit_sequence_cursor.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/index.hpp" +#include "v1/index_metadata.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/sequence/partitioned_sequence.hpp" +#include "v1/sequence/positive_sequence.hpp" + +namespace pisa::v1 { + +[[nodiscard]] inline auto index_runner(IndexMetadata metadata) +{ + return index_runner(std::move(metadata), + std::make_tuple(RawReader{}, + DocumentBlockedReader<::pisa::simdbp_block>{}, + DocumentBitSequenceReader{}, + DocumentBitSequenceReader>{}), + std::make_tuple(RawReader{}, + PayloadBlockedReader<::pisa::simdbp_block>{}, + PayloadBitSequenceReader>{})); +} + +[[nodiscard]] inline auto scored_index_runner(IndexMetadata metadata) +{ + return scored_index_runner(std::move(metadata), + std::make_tuple(RawReader{}, + DocumentBlockedReader<::pisa::simdbp_block>{}, + DocumentBitSequenceReader{}, + DocumentBitSequenceReader>{}), + std::make_tuple(RawReader{})); +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/document_payload_cursor.hpp b/include/pisa/v1/document_payload_cursor.hpp new file mode 100644 index 000000000..a137873fb --- /dev/null +++ b/include/pisa/v1/document_payload_cursor.hpp @@ -0,0 +1,63 @@ +#pragma once + 
+#include + +#include + +#include "v1/cursor_traits.hpp" + +namespace pisa::v1 { + +template +struct DocumentPayloadCursor { + using Document = decltype(*std::declval()); + using Payload = decltype(*std::declval()); + + constexpr DocumentPayloadCursor(DocumentCursor key_cursor, PayloadCursor payload_cursor) + : m_key_cursor(std::move(key_cursor)), m_payload_cursor(std::move(payload_cursor)) + { + } + constexpr DocumentPayloadCursor(DocumentPayloadCursor const&) = default; + constexpr DocumentPayloadCursor(DocumentPayloadCursor&&) noexcept = default; + constexpr DocumentPayloadCursor& operator=(DocumentPayloadCursor const&) = default; + constexpr DocumentPayloadCursor& operator=(DocumentPayloadCursor&&) noexcept = default; + ~DocumentPayloadCursor() = default; + + [[nodiscard]] constexpr auto operator*() const -> Document { return value(); } + [[nodiscard]] constexpr auto value() const noexcept -> Document { return m_key_cursor.value(); } + [[nodiscard]] constexpr auto payload() noexcept -> Payload + { + if (auto pos = m_key_cursor.position(); pos != m_payload_cursor.position()) { + m_payload_cursor.advance_to_position(m_key_cursor.position()); + } + return m_payload_cursor.value(); + } + constexpr void advance() { m_key_cursor.advance(); } + constexpr void advance_to_position(std::size_t pos) { m_key_cursor.advance_to_position(pos); } + constexpr void advance_to_geq(Document value) { m_key_cursor.advance_to_geq(value); } + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_key_cursor.empty(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return m_key_cursor.position(); + } + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_key_cursor.size(); } + [[nodiscard]] constexpr auto sentinel() const -> Document { return m_key_cursor.sentinel(); } + + private: + DocumentCursor m_key_cursor; + PayloadCursor m_payload_cursor; +}; + +template +[[nodiscard]] auto document_payload_cursor(DocumentCursor 
key_cursor, PayloadCursor payload_cursor) +{ + return DocumentPayloadCursor(std::move(key_cursor), + std::move(payload_cursor)); +} + +template +struct CursorTraits> { + using Value = typename CursorTraits::Value; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/encoding_traits.hpp b/include/pisa/v1/encoding_traits.hpp new file mode 100644 index 000000000..983fcf2d2 --- /dev/null +++ b/include/pisa/v1/encoding_traits.hpp @@ -0,0 +1,47 @@ +#pragma once + +#include "codec/simdbp.hpp" +#include "v1/sequence/indexed_sequence.hpp" +#include "v1/sequence/partitioned_sequence.hpp" +#include "v1/sequence/positive_sequence.hpp" +#include "v1/types.hpp" + +namespace pisa::v1 { + +struct SimdBPTag { + [[nodiscard]] static auto encoding() -> std::uint32_t { return EncodingId::SimdBP; } +}; + +template <> +struct encoding_traits<::pisa::simdbp_block> { + using encoding_tag = SimdBPTag; +}; + +struct PartitionedSequenceTag { + [[nodiscard]] static auto encoding() -> std::uint32_t { return EncodingId::PEF; } +}; + +template <> +struct encoding_traits> { + using encoding_tag = PartitionedSequenceTag; +}; + +struct IndexedSequenceTag { + [[nodiscard]] static auto encoding() -> std::uint32_t { return 17U; } +}; + +template <> +struct encoding_traits { + using encoding_tag = IndexedSequenceTag; +}; + +struct PositiveSequenceTag { + [[nodiscard]] static auto encoding() -> std::uint32_t { return EncodingId::PositiveSeq; } +}; + +template <> +struct encoding_traits> { + using encoding_tag = PositiveSequenceTag; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp new file mode 100644 index 000000000..41c20472d --- /dev/null +++ b/include/pisa/v1/index.hpp @@ -0,0 +1,454 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include "v1/base_index.hpp" +#include "v1/bit_cast.hpp" +#include "v1/cursor/for_each.hpp" +#include "v1/cursor/scoring_cursor.hpp" +#include "v1/document_payload_cursor.hpp" 
+#include "v1/posting_builder.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/source.hpp" +#include "v1/types.hpp" +#include "v1/zip_cursor.hpp" + +namespace pisa::v1 { + +/// A generic type for an inverted index. +/// +/// \tparam DocumentReader Type of an object that reads document posting lists from bytes +/// It must read lists containing `DocId` objects. +/// \tparam PayloadReader Type of an object that reads payload posting lists from bytes. +/// It can read lists of arbitrary types, such as `Frequency`, +/// `Score`, or `std::pair` for a bigram scored index. +template +struct Index : public BaseIndex { + + using document_cursor_type = DocumentCursor; + using payload_cursor_type = PayloadCursor; + + /// Constructs the index. + /// + /// \param document_reader Reads document posting lists from bytes. + /// \param payload_reader Reads payload posting lists from bytes. + /// TODO(michal)... + /// \param source This object (optionally) owns the raw data pointed at by + /// `documents` and `payloads` to ensure it is valid throughout + /// the lifetime of the index. It should release any resources + /// in its destructor. + template + Index(DocumentReader document_reader, + PayloadReader payload_reader, + PostingData documents, + PostingData payloads, + tl::optional bigrams, + gsl::span document_lengths, + tl::optional avg_document_length, + std::unordered_map> max_scores, + std::unordered_map block_max_scores, + gsl::span quantized_max_scores, + Source source) + : BaseIndex(documents, + payloads, + bigrams, + document_lengths, + avg_document_length, + std::move(max_scores), + std::move(block_max_scores), + quantized_max_scores, + source), + m_document_reader(std::move(document_reader)), + m_payload_reader(std::move(payload_reader)) + { + } + + /// Constructs a new document-payload cursor (see document_payload_cursor.hpp). 
+ [[nodiscard]] auto cursor(TermId term) const + { + return DocumentPayloadCursor(documents(term), + payloads(term)); + } + + [[nodiscard]] auto cursors(gsl::span terms) const + { + std::vector cursors; + std::transform(terms.begin(), terms.end(), std::back_inserter(cursors), [&](auto term) { + return cursor(term); + }); + return cursors; + } + + [[nodiscard]] auto bigram_payloads_0(TermId left_term, TermId right_term) const + { + return bigram_id(left_term, right_term).map([this](auto bid) { + return m_payload_reader.read(fetch_bigram_payloads<0>(bid)); + }); + } + + [[nodiscard]] auto bigram_payloads_1(TermId left_term, TermId right_term) const + { + return bigram_id(left_term, right_term).map([this](auto bid) { + return m_payload_reader.read(fetch_bigram_payloads<1>(bid)); + }); + } + + [[nodiscard]] auto bigram_cursor(TermId left_term, TermId right_term) const + { + return bigram_id(left_term, right_term).map([this](auto bid) { + return document_payload_cursor( + m_document_reader.read(fetch_bigram_documents(bid)), + zip(m_payload_reader.read(fetch_bigram_payloads<0>(bid)), + m_payload_reader.read(fetch_bigram_payloads<1>(bid)))); + }); + } + + /// Constructs a new document-score cursor. + template + [[nodiscard]] auto scoring_cursor(TermId term, Scorer&& scorer) const + { + return ScoringCursor(cursor(term), std::forward(scorer).term_scorer(term)); + } + + template + [[nodiscard]] auto scoring_cursors(gsl::span terms, Scorer&& scorer) const + { + std::vector cursors; + std::transform(terms.begin(), terms.end(), std::back_inserter(cursors), [&](auto term) { + return scoring_cursor(term, scorer); + }); + return cursors; + } + + /// This is equivalent to the `scoring_cursor` unless the scorer is of type `VoidScorer`, + /// in which case index payloads are treated as scores. 
+ template + [[nodiscard]] auto scored_cursor(TermId term, Scorer&& scorer) const + { + if constexpr (std::is_convertible_v) { + return cursor(term); + } else { + return scoring_cursor(term, std::forward(scorer)); + } + } + + template + [[nodiscard]] auto cursors(gsl::span terms, Fn&& fn) const + { + return ranges::views::transform(terms, [this, &fn](auto term) { return fn(*this, term); }) + | ranges::to_vector; + } + + template + [[nodiscard]] auto scored_cursors(gsl::span terms, Scorer&& scorer) const + { + std::vector cursors; + std::transform(terms.begin(), terms.end(), std::back_inserter(cursors), [&](auto term) { + return scored_cursor(term, scorer); + }); + return cursors; + } + + template + [[nodiscard]] auto max_scored_cursor(TermId term, Scorer&& scorer) const + { + using cursor_type = + std::decay_t(scorer)))>; + if constexpr (std::is_convertible_v) { + return MaxScoreCursor( + scored_cursor(term, std::forward(scorer)), quantized_max_score(term)); + } else { + return MaxScoreCursor( + scored_cursor(term, std::forward(scorer)), + max_score(std::hash>{}(scorer), term)); + } + } + + template + [[nodiscard]] auto max_scored_cursors(gsl::span terms, Scorer&& scorer) const + { + std::vector cursors; + std::transform(terms.begin(), terms.end(), std::back_inserter(cursors), [&](auto term) { + return max_scored_cursor(term, scorer); + }); + return cursors; + } + + template + [[nodiscard]] auto block_max_scored_cursor(TermId term, Scorer&& scorer) const + { + auto const& document_reader = block_max_document_reader(); + auto const& score_reader = block_max_score_reader(); + // Expects(term + 1 < m_payloads.offsets.size()); + if constexpr (std::is_convertible_v) { + if (false) { // TODO(michal): Workaround for now to avoid explicitly defining return + // type. 
+ return block_max_score_cursor( + scored_cursor(term, std::forward(scorer)), + document_payload_cursor(document_reader.read({}), score_reader.read({})), + 0.0F); + } + throw std::logic_error("Quantized block max scores uimplemented"); + } else { + auto const& data = block_max_scores(std::hash>{}(scorer)); + auto block_max_document_subspan = data.documents.postings.subspan( + data.documents.offsets[term], + data.documents.offsets[term + 1] - data.documents.offsets[term]); + auto block_max_score_subspan = data.payloads.postings.subspan( + data.payloads.offsets[term], + data.payloads.offsets[term + 1] - data.payloads.offsets[term]); + return block_max_score_cursor( + scored_cursor(term, std::forward(scorer)), + document_payload_cursor(document_reader.read(block_max_document_subspan), + score_reader.read(block_max_score_subspan)), + max_score(std::hash>{}(scorer), term)); + } + } + + template + [[nodiscard]] auto block_max_scored_cursors(gsl::span terms, + Scorer&& scorer) const + { + std::vector(0, std::forward(scorer)))> + cursors; + std::transform(terms.begin(), terms.end(), std::back_inserter(cursors), [&](auto term) { + return block_max_scored_cursor(term, scorer); + }); + return cursors; + } + + /// Constructs a new document-score cursor. 
+ template + [[nodiscard]] auto scoring_bigram_cursor(TermId left_term, + TermId right_term, + Scorer&& scorer) const + { + return bigram_cursor(left_term, right_term) + .take() + .map([scorer = std::forward(scorer), left_term, right_term](auto cursor) { + return ScoringCursor(cursor, + [scorers = std::make_tuple(scorer.term_scorer(left_term), + scorer.term_scorer(right_term))]( + auto&& docid, auto&& payload) { + return std::array{ + std::get<0>(scorers)(docid, std::get<0>(payload)), + std::get<1>(scorers)(docid, std::get<1>(payload))}; + }); + }); + } + + template + [[nodiscard]] auto scored_bigram_cursor(TermId left_term, + TermId right_term, + Scorer&& scorer) const + { + if constexpr (std::is_convertible_v) { + return bigram_cursor(left_term, right_term); + } else { + return scoring_bigram_cursor(left_term, right_term, std::forward(scorer)); + } + } + + /// Constructs a new document cursor. + [[nodiscard]] auto documents(TermId term) const + { + assert_term_in_bounds(term); + return m_document_reader.read(fetch_documents(term)); + } + + /// Constructs a new payload cursor. + [[nodiscard]] auto payloads(TermId term) const + { + assert_term_in_bounds(term); + return m_payload_reader.read(fetch_payloads(term)); + } + + [[nodiscard]] auto term_posting_count(TermId term) const -> std::uint32_t + { + // TODO(michal): Should be done more efficiently. 
+ return documents(term).size(); + } + + [[nodiscard]] auto block_max_document_reader() const -> Reader> const& + { + return m_block_max_document_reader; + } + + [[nodiscard]] auto block_max_score_reader() const -> Reader> const& + { + return m_block_max_score_reader; + } + + private: + Reader m_document_reader; + Reader m_payload_reader; + + Reader> m_block_max_document_reader = + Reader>(RawReader{}); + Reader> m_block_max_score_reader = + Reader>(RawReader{}); +}; + +template +auto make_index(DocumentReader document_reader, + PayloadReader payload_reader, + PostingData documents, + PostingData payloads, + tl::optional bigrams, + gsl::span document_lengths, + tl::optional avg_document_length, + std::unordered_map> max_scores, + std::unordered_map block_max_scores, + gsl::span quantized_max_scores, + Source source) +{ + using DocumentCursor = + decltype(document_reader.read(std::declval>())); + using PayloadCursor = decltype(payload_reader.read(std::declval>())); + return Index(std::move(document_reader), + std::move(payload_reader), + documents, + payloads, + bigrams, + document_lengths, + avg_document_length, + std::move(max_scores), + std::move(block_max_scores), + quantized_max_scores, + std::move(source)); +} + +template +struct IndexRunner { + template + IndexRunner(PostingData documents, + PostingData payloads, + tl::optional bigrams, + gsl::span document_lengths, + tl::optional avg_document_length, + std::unordered_map> max_scores, + std::unordered_map block_max_scores, + gsl::span quantized_max_scores, + Source source, + DocumentReaders document_readers, + PayloadReaders payload_readers) + : m_documents(documents), + m_payloads(payloads), + m_bigrams(bigrams), + m_document_lengths(document_lengths), + m_avg_document_length(avg_document_length), + m_max_scores(std::move(max_scores)), + m_block_max_scores(std::move(block_max_scores)), + m_max_quantized_scores(quantized_max_scores), + m_source(std::move(source)), + 
m_document_readers(std::move(document_readers)), + m_payload_readers(std::move(payload_readers)) + { + } + + template + auto operator()(Fn fn) + { + auto dheader = PostingFormatHeader::parse(m_documents.postings.first(8)); + auto pheader = PostingFormatHeader::parse(m_payloads.postings.first(8)); + auto run = [&](auto&& dreader, auto&& preader) { + if (std::decay_t::encoding() == dheader.encoding + && std::decay_t::encoding() == pheader.encoding + && is_type::value_type>(dheader.type) + && is_type::value_type>(pheader.type)) { + auto block_max_scores = m_block_max_scores; + for (auto& [key, data] : block_max_scores) { + data.documents.postings = data.documents.postings.subspan(8); + data.payloads.postings = data.payloads.postings.subspan(8); + } + fn(make_index( + std::forward(dreader), + std::forward(preader), + PostingData{.postings = m_documents.postings.subspan(8), + .offsets = m_documents.offsets}, + PostingData{.postings = m_payloads.postings.subspan(8), + .offsets = m_payloads.offsets}, + m_bigrams.map([](auto&& bigram_data) { + return BigramData{ + .documents = + PostingData{.postings = bigram_data.documents.postings.subspan(8), + .offsets = bigram_data.documents.offsets}, + .payloads = + std::array{ + PostingData{ + .postings = + std::get<0>(bigram_data.payloads).postings.subspan(8), + .offsets = std::get<0>(bigram_data.payloads).offsets}, + PostingData{ + .postings = + std::get<1>(bigram_data.payloads).postings.subspan(8), + .offsets = std::get<1>(bigram_data.payloads).offsets}}, + .mapping = bigram_data.mapping}; + }), + m_document_lengths, + m_avg_document_length, + m_max_scores, + block_max_scores, + m_max_quantized_scores, + false)); + return true; + } + return false; + }; + auto result = std::apply( + [&](auto... dreaders) { + auto with_document_reader = [&](auto dreader) { + return std::apply( + [&](auto... 
preaders) { return (run(dreader, preaders) || ...); }, + m_payload_readers); + }; + return (with_document_reader(dreaders) || ...); + }, + m_document_readers); + if (not result) { + std::ostringstream os; + os << fmt::format( + "Unknown posting encoding. Requested document: " + "{:x} ({:b}), payload: {:x} ({:b})\n", + dheader.encoding, + static_cast(to_byte(dheader.type)), + pheader.encoding, + static_cast(to_byte(pheader.type))); + auto print_reader = [&](auto&& reader) { + os << fmt::format( + "\t{:x} ({:b})\n", + reader.encoding(), + static_cast(to_byte( + value_type::value_type>()))); + }; + os << "Available document readers: \n"; + std::apply([&](auto... readers) { (print_reader(readers), ...); }, m_document_readers); + os << "Available payload readers: \n"; + std::apply([&](auto... readers) { (print_reader(readers), ...); }, m_payload_readers); + throw std::domain_error(os.str()); + } + } + + private: + PostingData m_documents; + PostingData m_payloads; + tl::optional m_bigrams; + + gsl::span m_document_lengths; + tl::optional m_avg_document_length; + std::unordered_map> m_max_scores; + std::unordered_map m_block_max_scores; + gsl::span m_max_quantized_scores; + tl::optional const>> m_bigram_mapping; + std::any m_source; + DocumentReaders m_document_readers; + PayloadReaders m_payload_readers; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/index_builder.hpp b/include/pisa/v1/index_builder.hpp new file mode 100644 index 000000000..8746b6306 --- /dev/null +++ b/include/pisa/v1/index_builder.hpp @@ -0,0 +1,287 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include "binary_freq_collection.hpp" +#include "v1/cursor/accumulate.hpp" +#include "v1/cursor_intersection.hpp" +#include "v1/index.hpp" +#include "v1/index_metadata.hpp" +#include "v1/progress_status.hpp" +#include "v1/query.hpp" +#include "v1/runtime_assert.hpp" +#include "v1/scorer/bm25.hpp" + +namespace pisa::v1 { + +template +struct IndexBuilder { + 
explicit IndexBuilder(DocumentWriters document_writers, PayloadWriters payload_writers) + : m_document_writers(std::move(document_writers)), + m_payload_writers(std::move(payload_writers)) + { + } + + template + void operator()(Encoding document_encoding, Encoding payload_encoding, Fn fn) + { + auto run = [&](auto&& dwriter, auto&& pwriter) -> bool { + if (std::decay_t::encoding() == document_encoding + && std::decay_t::encoding() == payload_encoding) { + fn(dwriter, pwriter); + return true; + } + return false; + }; + bool success = std::apply( + [&](auto... dwriters) { + auto with_document_writer = [&](auto dwriter) { + return std::apply( + [&](auto... pwriters) { return (run(dwriter, pwriters) || ...); }, + m_payload_writers); + }; + return (with_document_writer(dwriters) || ...); + }, + m_document_writers); + if (not success) { + throw std::domain_error("Unknown posting encoding"); + } + } + + private: + DocumentWriters m_document_writers; + PayloadWriters m_payload_writers; +}; + +template +auto make_index_builder(DocumentWriters document_writers, PayloadWriters payload_writers) +{ + return IndexBuilder(std::move(document_writers), + std::move(payload_writers)); +} + +template +auto compress_batch(CollectionIterator first, + CollectionIterator last, + std::ofstream& dout, + std::ofstream& fout, + Writer document_writer, + Writer frequency_writer, + tl::optional bar) + -> std::tuple, std::vector> +{ + PostingBuilder document_builder(std::move(document_writer)); + PostingBuilder frequency_builder(std::move(frequency_writer)); + for (auto pos = first; pos != last; ++pos) { + auto dseq = pos->docs; + auto fseq = pos->freqs; + for (auto doc : dseq) { + document_builder.accumulate(doc); + } + for (auto freq : fseq) { + frequency_builder.accumulate(freq); + } + document_builder.flush_segment(dout); + frequency_builder.flush_segment(fout); + *bar += 1; + } + return std::make_tuple(std::move(document_builder.offsets()), + std::move(frequency_builder.offsets())); +} + 
+template +void write_span(gsl::span offsets, std::ofstream& os) +{ + auto bytes = gsl::as_bytes(offsets); + os.write(reinterpret_cast(bytes.data()), bytes.size()); +} + +template +void write_span(gsl::span offsets, std::string const& file) +{ + std::ofstream os(file); + write_span(offsets, os); +} + +inline void compress_binary_collection(std::string const& input, + std::string_view fwd, + std::string_view output, + std::size_t const threads, + Writer document_writer, + Writer frequency_writer) +{ + pisa::binary_freq_collection const collection(input.c_str()); + document_writer.init(collection); + frequency_writer.init(collection); + ProgressStatus status(collection.size(), + DefaultProgressCallback("Compressing in parallel"), + std::chrono::milliseconds(100)); + tbb::task_group group; + auto const num_terms = collection.size(); + std::vector> document_offsets(threads); + std::vector> frequency_offsets(threads); + std::vector document_paths; + std::vector frequency_paths; + std::vector document_streams; + std::vector frequency_streams; + auto for_each_batch = [threads](auto fn) { + for (auto thread_idx = 0; thread_idx < threads; thread_idx += 1) { + fn(thread_idx); + } + }; + for_each_batch([&](auto thread_idx) { + auto document_batch = + document_paths.emplace_back(fmt::format("{}.doc.batch.{}", output, thread_idx)); + auto frequency_batch = + frequency_paths.emplace_back(fmt::format("{}.freq.batch.{}", output, thread_idx)); + document_streams.emplace_back(document_batch); + frequency_streams.emplace_back(frequency_batch); + }); + auto batch_size = num_terms / threads; + for_each_batch([&](auto thread_idx) { + group.run([thread_idx, + batch_size, + threads, + &collection, + &document_streams, + &frequency_streams, + &document_offsets, + &frequency_offsets, + &status, + &document_writer, + &frequency_writer]() { + auto first = std::next(collection.begin(), thread_idx * batch_size); + auto last = [&]() { + if (thread_idx == threads - 1) { + return collection.end(); 
+ } + return std::next(collection.begin(), (thread_idx + 1) * batch_size); + }(); + auto& dout = document_streams[thread_idx]; + auto& fout = frequency_streams[thread_idx]; + std::tie(document_offsets[thread_idx], frequency_offsets[thread_idx]) = + compress_batch(first, + last, + dout, + fout, + document_writer, + frequency_writer, + tl::make_optional(status)); + }); + }); + group.wait(); + document_streams.clear(); + frequency_streams.clear(); + + std::vector all_document_offsets; + std::vector all_frequency_offsets; + all_document_offsets.reserve(num_terms + 1); + all_frequency_offsets.reserve(num_terms + 1); + all_document_offsets.push_back(0); + all_frequency_offsets.push_back(0); + auto documents_file = fmt::format("{}.documents", output); + auto frequencies_file = fmt::format("{}.frequencies", output); + std::ofstream document_out(documents_file); + std::ofstream frequency_out(frequencies_file); + + PostingBuilder(document_writer).write_header(document_out); + PostingBuilder(frequency_writer).write_header(frequency_out); + + { + ProgressStatus merge_status( + threads, DefaultProgressCallback("Merging files"), std::chrono::milliseconds(500)); + for_each_batch([&](auto thread_idx) { + std::transform( + std::next(document_offsets[thread_idx].begin()), + document_offsets[thread_idx].end(), + std::back_inserter(all_document_offsets), + [base = all_document_offsets.back()](auto offset) { return base + offset; }); + std::transform( + std::next(frequency_offsets[thread_idx].begin()), + frequency_offsets[thread_idx].end(), + std::back_inserter(all_frequency_offsets), + [base = all_frequency_offsets.back()](auto offset) { return base + offset; }); + std::ifstream docbatch(document_paths[thread_idx]); + std::ifstream freqbatch(frequency_paths[thread_idx]); + document_out << docbatch.rdbuf(); + frequency_out << freqbatch.rdbuf(); + merge_status += 1; + }); + } + + std::cerr << "Writing offsets..."; + auto doc_offset_file = fmt::format("{}.document_offsets", output); + 
auto freq_offset_file = fmt::format("{}.frequency_offsets", output); + write_span(gsl::span(all_document_offsets), doc_offset_file); + write_span(gsl::span(all_frequency_offsets), freq_offset_file); + std::cerr << " Done.\n"; + + std::cerr << "Writing sizes..."; + auto lengths = read_sizes(input); + auto document_lengths_file = fmt::format("{}.document_lengths", output); + write_span(gsl::span(lengths), document_lengths_file); + float avg_len = calc_avg_length(gsl::span(lengths)); + std::cerr << " Done.\n"; + + IndexMetadata{ + .documents = PostingFilePaths{.postings = documents_file, .offsets = doc_offset_file}, + .frequencies = PostingFilePaths{.postings = frequencies_file, .offsets = freq_offset_file}, + .scores = {}, + .document_lengths_path = document_lengths_file, + .avg_document_length = avg_len, + .term_lexicon = tl::make_optional(fmt::format("{}.termlex", fwd)), + .document_lexicon = tl::make_optional(fmt::format("{}.doclex", fwd)), + .stemmer = tl::make_optional("porter2")} + .write(fmt::format("{}.yml", output)); +} + +template +auto bigram_gain(Index&& index, Query const& bigram) -> float +{ + auto&& term_ids = bigram.get_term_ids(); + runtime_assert(term_ids.size() == 2).or_throw("Queries must be of exactly two unique terms"); + auto cursors = index.scored_cursors(term_ids, make_bm25(index)); + auto union_length = cursors[0].size() + cursors[1].size(); + auto intersection_length = + accumulate(intersect(std::move(cursors), + false, + []([[maybe_unused]] auto count, + [[maybe_unused]] auto&& cursor, + [[maybe_unused]] auto idx) { return true; }), + std::size_t{0}, + [](auto count, [[maybe_unused]] auto&& cursor) { return count + 1; }); + if (intersection_length == 0) { + return 0.0; + } + return static_cast(bigram.get_probability()) * static_cast(union_length) + / static_cast(intersection_length); +} + +auto verify_compressed_index(std::string const& input, std::string_view output) + -> std::vector; + +auto collect_unique_bigrams(std::vector const& 
queries, + std::function const& callback) + -> std::vector>; + +[[nodiscard]] auto select_best_bigrams(IndexMetadata const& meta, + std::vector const& queries, + std::size_t num_bigrams_to_select) + -> std::vector>; + +auto build_bigram_index(IndexMetadata meta, + std::vector> const& bigrams, + tl::optional const& clone_path) -> IndexMetadata; + +auto build_pair_index(IndexMetadata meta, + std::vector> const& pairs, + tl::optional const& clone_path, + std::size_t threads) -> IndexMetadata; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/index_metadata.hpp b/include/pisa/v1/index_metadata.hpp new file mode 100644 index 000000000..6f8651cce --- /dev/null +++ b/include/pisa/v1/index_metadata.hpp @@ -0,0 +1,286 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include "v1/index.hpp" +#include "v1/query.hpp" +#include "v1/source.hpp" +#include "v1/types.hpp" + +namespace pisa::v1 { + +[[nodiscard]] auto append_extension(std::string file_path) -> std::string; + +/// Return the passed file path if is not `nullopt`. +/// Otherwise, look for an `.yml` file in the current directory. +/// It will throw if no `.yml` file is found or there are multiple `.yml` files. 
+[[nodiscard]] auto resolve_yml(tl::optional const& arg) -> std::string; + +template +[[nodiscard]] auto convert_optional(Optional opt) +{ + if (opt) { + return tl::make_optional(*opt); + } + return tl::optional>(); +} + +template +[[nodiscard]] auto to_std(tl::optional opt) -> std::optional +{ + if (opt) { + return std::make_optional(opt.take()); + } + return std::optional>(); +} + +struct PostingFilePaths { + std::string postings; + std::string offsets; +}; + +struct UnigramFilePaths { + PostingFilePaths documents; + PostingFilePaths payloads; +}; + +struct BigramMetadata { + PostingFilePaths documents; + std::pair frequencies; + std::vector> scores{}; + std::string mapping; + std::size_t count; +}; + +struct IndexMetadata final { + tl::optional basename{}; + PostingFilePaths documents; + PostingFilePaths frequencies; + std::vector scores{}; + std::string document_lengths_path; + float avg_document_length; + tl::optional term_lexicon{}; + tl::optional document_lexicon{}; + tl::optional stemmer{}; + tl::optional bigrams{}; + std::map max_scores{}; + std::map block_max_scores{}; + std::map quantized_max_scores{}; + + void write(std::string const& file) const; + void update() const; + [[nodiscard]] auto query_parser(tl::optional const& stop_words = tl::nullopt) const + -> std::function; + [[nodiscard]] auto get_basename() const -> std::string const&; + [[nodiscard]] static auto from_file(std::string const& file) -> IndexMetadata; +}; + +template +[[nodiscard]] auto to_span(mio::mmap_source const* mmap) +{ + static_assert(std::is_trivially_constructible_v); + return gsl::span(reinterpret_cast(mmap->data()), mmap->size() / sizeof(T)); +}; + +template +[[nodiscard]] auto source_span(MMapSource& source, std::string const& file) +{ + return to_span( + source.file_sources.emplace_back(std::make_shared(file)).get()); +}; + +template +[[nodiscard]] inline auto index_runner(IndexMetadata metadata, + DocumentReaders document_readers, + PayloadReaders payload_readers) +{ + 
MMapSource source; + auto documents = source_span(source, metadata.documents.postings); + auto frequencies = source_span(source, metadata.frequencies.postings); + auto document_offsets = source_span(source, metadata.documents.offsets); + auto frequency_offsets = source_span(source, metadata.frequencies.offsets); + auto document_lengths = source_span(source, metadata.document_lengths_path); + auto bigrams = [&]() -> tl::optional { + gsl::span bigram_document_offsets{}; + std::array, 2> bigram_frequency_offsets{}; + gsl::span bigram_documents{}; + std::array, 2> bigram_frequencies{}; + gsl::span const> bigram_mapping{}; + if (metadata.bigrams) { + bigram_document_offsets = + source_span(source, metadata.bigrams->documents.offsets); + bigram_frequency_offsets = { + source_span(source, metadata.bigrams->frequencies.first.offsets), + source_span(source, metadata.bigrams->frequencies.second.offsets)}; + bigram_documents = source_span(source, metadata.bigrams->documents.postings); + bigram_frequencies = { + source_span(source, metadata.bigrams->frequencies.first.postings), + source_span(source, metadata.bigrams->frequencies.second.postings)}; + auto mapping_span = source_span(source, metadata.bigrams->mapping); + bigram_mapping = gsl::span const>( + reinterpret_cast const*>(mapping_span.data()), + mapping_span.size() / (sizeof(TermId) * 2)); + return BigramData{ + .documents = {.postings = bigram_documents, .offsets = bigram_document_offsets}, + .payloads = + std::array{ + PostingData{.postings = std::get<0>(bigram_frequencies), + .offsets = std::get<0>(bigram_frequency_offsets)}, + PostingData{.postings = std::get<1>(bigram_frequencies), + .offsets = std::get<1>(bigram_frequency_offsets)}}, + .mapping = bigram_mapping}; + } + return tl::nullopt; + }(); + std::unordered_map> max_scores; + if (not metadata.max_scores.empty()) { + for (auto [name, file] : metadata.max_scores) { + auto bytes = source_span(source, file); + max_scores[std::hash{}(name)] = gsl::span( + 
reinterpret_cast(bytes.data()), bytes.size() / (sizeof(float))); + } + } + std::unordered_map block_max_scores; + if (not metadata.block_max_scores.empty()) { + for (auto [name, files] : metadata.block_max_scores) { + auto document_bytes = source_span(source, files.documents.postings); + auto document_offsets = source_span(source, files.documents.offsets); + auto payload_bytes = source_span(source, files.payloads.postings); + auto payload_offsets = source_span(source, files.payloads.offsets); + block_max_scores[std::hash{}(name)] = + UnigramData{.documents = {.postings = document_bytes, .offsets = document_offsets}, + .payloads = {.postings = payload_bytes, .offsets = payload_offsets}}; + } + } + return IndexRunner( + {.postings = documents, .offsets = document_offsets}, + {.postings = frequencies, .offsets = frequency_offsets}, + bigrams, + document_lengths, + tl::make_optional(metadata.avg_document_length), + std::move(max_scores), + std::move(block_max_scores), + {}, + std::move(source), + std::move(document_readers), + std::move(payload_readers)); +} + +template +[[nodiscard]] inline auto scored_index_runner(IndexMetadata metadata, + DocumentReaders document_readers, + PayloadReaders payload_readers) +{ + MMapSource source; + auto documents = source_span(source, metadata.documents.postings); + // TODO(michal): support many precomputed scores + auto scores = source_span(source, metadata.scores.front().postings); + auto document_offsets = source_span(source, metadata.documents.offsets); + auto score_offsets = source_span(source, metadata.scores.front().offsets); + auto document_lengths = source_span(source, metadata.document_lengths_path); + auto bigrams = [&]() -> tl::optional { + gsl::span bigram_document_offsets{}; + std::array, 2> bigram_score_offsets{}; + gsl::span bigram_documents{}; + std::array, 2> bigram_scores{}; + gsl::span const> bigram_mapping{}; + if (metadata.bigrams && not metadata.bigrams->scores.empty()) { + bigram_document_offsets = + 
source_span(source, metadata.bigrams->documents.offsets); + bigram_score_offsets = { + source_span(source, metadata.bigrams->scores[0].first.offsets), + source_span(source, metadata.bigrams->scores[0].second.offsets)}; + bigram_documents = source_span(source, metadata.bigrams->documents.postings); + bigram_scores = { + source_span(source, metadata.bigrams->scores[0].first.postings), + source_span(source, metadata.bigrams->scores[0].second.postings)}; + auto mapping_span = source_span(source, metadata.bigrams->mapping); + bigram_mapping = gsl::span const>( + reinterpret_cast const*>(mapping_span.data()), + mapping_span.size() / (sizeof(TermId) * 2)); + return BigramData{ + .documents = {.postings = bigram_documents, .offsets = bigram_document_offsets}, + .payloads = + std::array{ + PostingData{.postings = std::get<0>(bigram_scores), + .offsets = std::get<0>(bigram_score_offsets)}, + PostingData{.postings = std::get<1>(bigram_scores), + .offsets = std::get<0>(bigram_score_offsets)}}, + .mapping = bigram_mapping}; + } + return tl::nullopt; + }(); + gsl::span quantized_max_scores; + if (not metadata.quantized_max_scores.empty()) { + // TODO(michal): support many precomputed scores + for (auto [name, file] : metadata.quantized_max_scores) { + quantized_max_scores = source_span(source, file); + } + } + return IndexRunner( + {.postings = documents, .offsets = document_offsets}, + {.postings = scores, .offsets = score_offsets}, + bigrams, + document_lengths, + tl::make_optional(metadata.avg_document_length), + {}, + {}, + quantized_max_scores, + std::move(source), + std::move(document_readers), + std::move(payload_readers)); +} + +} // namespace pisa::v1 + +namespace YAML { +template <> +struct convert<::pisa::v1::PostingFilePaths> { + static Node encode(const ::pisa::v1::PostingFilePaths& rhs) + { + Node node; + node["file"] = rhs.postings; + node["offsets"] = rhs.offsets; + return node; + } + + static bool decode(const Node& node, ::pisa::v1::PostingFilePaths& rhs) + { + 
if (!node.IsMap()) { + return false; + } + + rhs.postings = node["file"].as(); + rhs.offsets = node["offsets"].as(); + return true; + } +}; + +template <> +struct convert<::pisa::v1::UnigramFilePaths> { + static Node encode(const ::pisa::v1::UnigramFilePaths& rhs) + { + Node node; + node["documents"] = convert<::pisa::v1::PostingFilePaths>::encode(rhs.documents); + node["payloads"] = convert<::pisa::v1::PostingFilePaths>::encode(rhs.payloads); + return node; + } + + static bool decode(const Node& node, ::pisa::v1::UnigramFilePaths& rhs) + { + if (!node.IsMap()) { + return false; + } + + rhs.documents = node["documents"].as<::pisa::v1::PostingFilePaths>(); + rhs.payloads = node["payloads"].as<::pisa::v1::PostingFilePaths>(); + return true; + } +}; + +} // namespace YAML diff --git a/include/pisa/v1/inspect_query.hpp b/include/pisa/v1/inspect_query.hpp new file mode 100644 index 000000000..31e3a4c5c --- /dev/null +++ b/include/pisa/v1/inspect_query.hpp @@ -0,0 +1,401 @@ +#pragma once + +#include +#include + +#include +#include + +#include "topk_queue.hpp" +#include "v1/query.hpp" + +namespace pisa::v1 { + +template +auto write_delimited(std::ostream& os, + std::string_view sep, + FirstValue first_value, + Values&&... values) -> std::ostream&; + +template +auto write_delimited(std::ostream& os, [[maybe_unused]] std::string_view sep, Value value) + -> std::ostream& +{ + return os << value; +} + +template +auto write_delimited(std::ostream& os, std::string_view sep, std::tuple t) -> std::ostream& +{ + std::apply([&](auto... vals) { write_delimited(os, sep, vals...); }, t); + return os; +} + +template +auto write_delimited(std::ostream& os, std::string_view sep, std::pair p) -> std::ostream& +{ + write_delimited(os, sep, p.first); + os << sep; + return write_delimited(os, sep, p.second); +} + +/// Writes a list of values into a stream separated by `sep`. +template +auto write_delimited(std::ostream& os, + std::string_view sep, + FirstValue first_value, + Values&&... 
values) -> std::ostream& +{ + write_delimited(os, sep, first_value); + ( + [&] { + os << sep; + write_delimited(os, sep, values); + }(), + ...); + return os; +} + +struct InspectCount { + using value_type = std::size_t; + void reset() { m_current_count = 0; } + void inc(std::size_t n = 1) + { + m_current_count += n; + m_total_count += n; + } + [[nodiscard]] auto get() const -> std::size_t { return m_current_count; } + [[nodiscard]] auto mean(std::size_t n) const -> float + { + return static_cast(m_total_count) / n; + } + + struct Result { + explicit Result(std::size_t value) : m_value(value) {} + [[nodiscard]] auto get() const { return m_value; } + [[nodiscard]] auto operator+(Result const& other) const { return m_value + other.m_value; } + + private: + std::size_t m_value; + }; + + private: + std::size_t m_current_count = 0; + std::size_t m_total_count = 0; +}; + +struct InspectPostings : InspectCount { + struct Result : InspectCount::Result { + explicit Result(std::size_t value) : InspectCount::Result(value) {} + [[nodiscard]] auto postings() const { return get(); } + }; + void posting() { inc(); } + static auto header(std::string_view suffix) { return fmt::format("postings{}", suffix); } +}; + +struct InspectDocuments : InspectCount { + struct Result : InspectCount::Result { + explicit Result(std::size_t value) : InspectCount::Result(value) {} + [[nodiscard]] auto documents() const { return get(); } + }; + void document() { inc(); } + static auto header(std::string_view suffix) { return fmt::format("documents{}", suffix); } +}; + +struct InspectLookups : InspectCount { + struct Result : InspectCount::Result { + explicit Result(std::size_t value) : InspectCount::Result(value) {} + [[nodiscard]] auto lookups() const { return get(); } + }; + void lookup() { inc(); } + static auto header(std::string_view suffix) { return fmt::format("lookups{}", suffix); } +}; + +struct InspectInserts : InspectCount { + struct Result : InspectCount::Result { + explicit 
Result(std::size_t value) : InspectCount::Result(value) {} + [[nodiscard]] auto inserts() const { return get(); } + }; + void insert() { inc(); } + static auto header(std::string_view suffix) { return fmt::format("inserts{}", suffix); } +}; + +struct InspectEssential : InspectCount { + struct Result : InspectCount::Result { + explicit Result(std::size_t value) : InspectCount::Result(value) {} + [[nodiscard]] auto essentials() const { return get(); } + }; + void essential(std::size_t n) { inc(n); } + static auto header(std::string_view suffix) { return fmt::format("essential-terms{}", suffix); } +}; + +template +struct InspectPartitioned { + struct Result { + explicit Result(typename Inspect::Result first, typename Inspect::Result second) + : first(first), second(second), sum(first + second) + { + } + [[nodiscard]] auto get() const + { + return std::make_tuple(sum.get(), first.get(), second.get()); + } + + typename Inspect::Result first; + typename Inspect::Result second; + typename Inspect::Result sum; + }; + using value_type = Result; + + void reset() + { + m_components.first.reset(); + m_components.second.reset(); + } + void inc(std::size_t n = 1) + { + m_components.first.inc(n); + m_components.second.inc(n); + } + [[nodiscard]] auto get() const + { + return Result(m_components.first.get(), m_components.second.get()); + } + [[nodiscard]] auto mean(std::size_t n) const + { + return Result(m_components.first.mean(n), m_components.second.mean(n)); + } + [[nodiscard]] auto first() -> Inspect* { return &m_components.first; } + [[nodiscard]] auto second() -> Inspect* { return &m_components.second; } + static auto header(std::string_view suffix) + { + return std::make_tuple(Inspect::header(fmt::format("{}", suffix)), + Inspect::header(fmt::format("{}_1", suffix)), + Inspect::header(fmt::format("{}_2", suffix))); + } + + private: + std::pair m_components; +}; + +template +struct InspectPair { + struct Result { + explicit Result(typename First::Result first, typename 
Second::Result second) + : first(first), second(second) + { + } + [[nodiscard]] auto get() const { return std::make_pair(first.get(), second.get()); } + + typename First::Result first; + typename Second::Result second; + }; + using value_type = Result; + + void reset() + { + m_components.first.reset(); + m_components.second.reset(); + } + void inc(std::size_t n = 1) + { + m_components.first.inc(n); + m_components.second.inc(n); + } + [[nodiscard]] auto get() const + { + return Result(m_components.first.get(), m_components.second.get()); + } + [[nodiscard]] auto mean(std::size_t n) const + { + return Result(m_components.first.mean(n), m_components.second.mean(n)); + } + [[nodiscard]] auto first() -> First* { return &m_components.first; } + [[nodiscard]] auto second() -> Second* { return &m_components.second; } + static auto header(std::string_view suffix) + { + return std::make_pair(First::header(fmt::format("{}_1", suffix)), + Second::header(fmt::format("{}_2", suffix))); + } + + private: + std::pair m_components; +}; + +template +struct InspectMany : Stat... { + struct Result : Stat::Result... { + explicit Result(typename Stat::value_type... values) : Stat::Result(values)... 
{} + Result(Result const& result) = default; + Result(Result&& result) noexcept = default; + Result& operator=(Result const& result) = default; + Result& operator=(Result&& result) noexcept = default; + ~Result() = default; + [[nodiscard]] auto get() const { return std::make_tuple(Stat::Result::get()...); } + [[nodiscard]] auto operator+(Result const& other) const + { + return Result((Stat::Result::get() + other.Stat::Result::get())...); + } + }; + using value_type = Result; + void reset() { (Stat::reset(), ...); } + void inc(std::size_t n = 1) { (Stat::inc(n), ...); } + [[nodiscard]] auto get() const { return Result(Stat::get()...); } + [[nodiscard]] auto mean(std::size_t n) const { return Result(Stat::mean(n)...); } + static auto header(std::string_view suffix) { return std::make_tuple(Stat::header(suffix)...); } +}; + +template +struct Inspect : Stat... { + Inspect(Index const& index, Scorer scorer) : m_index(index), m_scorer(std::move(scorer)) {} + + virtual void run(Query const& query, + Index const& index, + Scorer const& scorer, + topk_queue topk) = 0; + + [[nodiscard]] auto mean() const { return InspectResult(Stat::mean(m_count)...); } + + struct InspectResult : Stat::Result... { + explicit InspectResult(typename Stat::value_type... values) : Stat::Result(values)... 
{} + std::ostream& write(std::ostream& os, std::string_view sep = "\t") + { + return write_delimited(os, sep, Stat::Result::get()...); + } + }; + + [[nodiscard]] auto operator()(Query const& query) -> InspectResult + { + (Stat::reset(), ...); + run(query, m_index, m_scorer, topk_queue(query.k())); + m_count += 1; + return InspectResult(Stat::get()...); + } + static std::ostream& header(std::ostream& os, std::string_view sep = "\t") + { + return write_delimited(os, sep, Stat::header("")...); + } + + private: + std::size_t m_count = 0; + Index const& m_index; + Scorer m_scorer; +}; + +struct InspectResult { + + template + explicit constexpr InspectResult(R result) : m_inner(std::make_unique>(result)) + { + } + InspectResult() = default; + InspectResult(InspectResult const& other) : m_inner(other.m_inner->clone()) {} + InspectResult(InspectResult&& other) noexcept = default; + InspectResult& operator=(InspectResult const& other) = delete; + InspectResult& operator=(InspectResult&& other) noexcept = default; + ~InspectResult() = default; + + std::ostream& write(std::ostream& os, std::string_view sep = "\t") + { + return m_inner->write(os, sep); + } + + struct ResultInterface { + ResultInterface() = default; + ResultInterface(ResultInterface const&) = default; + ResultInterface(ResultInterface&&) noexcept = default; + ResultInterface& operator=(ResultInterface const&) = default; + ResultInterface& operator=(ResultInterface&&) noexcept = default; + virtual ~ResultInterface() = default; + virtual std::ostream& write(std::ostream& os, std::string_view sep) = 0; + [[nodiscard]] virtual auto clone() const -> std::unique_ptr = 0; + }; + + template + struct ResultImpl : ResultInterface { + explicit ResultImpl(R result) : m_result(std::move(result)) {} + ResultImpl() = default; + ResultImpl(ResultImpl const&) = default; + ResultImpl(ResultImpl&&) noexcept = default; + ResultImpl& operator=(ResultImpl const&) = default; + ResultImpl& operator=(ResultImpl&&) noexcept = default; + 
~ResultImpl() override = default; + std::ostream& write(std::ostream& os, std::string_view sep) override + { + return m_result.write(os, sep); + } + [[nodiscard]] auto clone() const -> std::unique_ptr override + { + auto copy = *this; + return std::make_unique>(std::move(copy)); + } + + private: + R m_result; + }; + + private: + std::unique_ptr m_inner; +}; + +struct QueryInspector { + + template + explicit constexpr QueryInspector(R writer) + : m_inner(std::make_unique>(writer)) + { + } + QueryInspector() = default; + QueryInspector(QueryInspector const& other) : m_inner(other.m_inner->clone()) {} + QueryInspector(QueryInspector&& other) noexcept = default; + QueryInspector& operator=(QueryInspector const& other) = delete; + QueryInspector& operator=(QueryInspector&& other) noexcept = default; + ~QueryInspector() = default; + + InspectResult operator()(Query const& query) { return m_inner->operator()(query); } + InspectResult mean() { return m_inner->mean(); } + std::ostream& header(std::ostream& os) { return m_inner->header(os); } + + struct InspectorInterface { + InspectorInterface() = default; + InspectorInterface(InspectorInterface const&) = default; + InspectorInterface(InspectorInterface&&) noexcept = default; + InspectorInterface& operator=(InspectorInterface const&) = default; + InspectorInterface& operator=(InspectorInterface&&) noexcept = default; + virtual ~InspectorInterface() = default; + virtual InspectResult operator()(Query const& query) = 0; + virtual InspectResult mean() = 0; + virtual std::ostream& header(std::ostream& os) = 0; + [[nodiscard]] virtual auto clone() const -> std::unique_ptr = 0; + }; + + template + struct InspectorImpl : InspectorInterface { + explicit InspectorImpl(R inspect) : m_inspect(std::move(inspect)) {} + InspectorImpl() = default; + InspectorImpl(InspectorImpl const&) = default; + InspectorImpl(InspectorImpl&&) noexcept = default; + InspectorImpl& operator=(InspectorImpl const&) = default; + InspectorImpl& 
operator=(InspectorImpl&&) noexcept = default; + ~InspectorImpl() override = default; + InspectResult operator()(Query const& query) override + { + return InspectResult(m_inspect(query)); + } + InspectResult mean() override { return InspectResult(m_inspect.mean()); } + [[nodiscard]] auto clone() const -> std::unique_ptr override + { + auto copy = *this; + return std::make_unique>(std::move(copy)); + } + std::ostream& header(std::ostream& os) override { return m_inspect.header(os); } + + private: + R m_inspect; + }; + + private: + std::unique_ptr m_inner; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/intersection.hpp b/include/pisa/v1/intersection.hpp new file mode 100644 index 000000000..c44a57989 --- /dev/null +++ b/include/pisa/v1/intersection.hpp @@ -0,0 +1,50 @@ +#pragma once + +#include +#include +#include +#include + +namespace pisa::v1 { + +/// Read a list of intersections. +/// +/// Each line in the format relates to one query, and each space-separated value +/// is an integer intersection representation. These numbers are converted to +/// bitsets, and each 1 at position `i` means that the `i`-th term in the query +/// is present in the intersection. +/// +/// # Example +/// +/// Let `q = a b c d e` be our query. The following line: +/// ``` +/// 1 2 5 16 +/// ``` +/// can be represented as bitsets: +/// ``` +/// 00001 00010 00101 10000 +/// ``` +/// which in turn represent four intersection: a, b, ac, e. +[[nodiscard]] auto read_intersections(std::string const& filename) + -> std::vector>>; +[[nodiscard]] auto read_intersections(std::istream& is) + -> std::vector>>; + +/// Converts a bitset to a vector of positions set to 1. +[[nodiscard]] auto to_vector(std::bitset<64> const& bits) -> std::vector; + +/// Returns a lambda taking a bitset and returning `true` if it has `n` set bits. 
+[[nodiscard]] inline auto is_n_gram(std::size_t n) +{ + return [n](std::bitset<64> const& bits) -> bool { return bits.count() == n; }; +} + +/// Returns only positions of terms in unigrams. +[[nodiscard]] auto filter_unigrams(std::vector>> const& intersections) + -> std::vector>; + +/// Returns only positions of terms in bigrams. +[[nodiscard]] auto filter_bigrams(std::vector>> const& intersections) + -> std::vector>>; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/io.hpp b/include/pisa/v1/io.hpp new file mode 100644 index 000000000..06e792083 --- /dev/null +++ b/include/pisa/v1/io.hpp @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#include + +namespace pisa::v1 { + +[[nodiscard]] auto load_bytes(std::string const& data_file) -> std::vector; + +template +[[nodiscard]] auto load_vector(std::string const& data_file) -> std::vector +{ + std::vector data; + std::basic_ifstream in(data_file.c_str(), std::ios::binary); + in.seekg(0, std::ios::end); + std::streamsize size = in.tellg(); + in.seekg(0, std::ios::beg); + + runtime_assert(size % sizeof(T) == 0).or_exit([&] { + return fmt::format("Tried loading a vector of elements of size {} but size of file is {}", + sizeof(T), + size); + }); + data.resize(size / sizeof(T)); + + runtime_assert(in.read(reinterpret_cast(data.data()), size).good()).or_exit([&] { + return fmt::format("Failed reading ", data_file); + }); + return data; +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/maxscore.hpp b/include/pisa/v1/maxscore.hpp new file mode 100644 index 000000000..5902832bd --- /dev/null +++ b/include/pisa/v1/maxscore.hpp @@ -0,0 +1,260 @@ +#pragma once + +#include +#include +#include + +#include +#include + +#include "v1/algorithm.hpp" +#include "v1/cursor_accumulator.hpp" +#include "v1/inspect_query.hpp" +#include "v1/query.hpp" + +namespace pisa::v1 { + +template +struct MaxScoreJoin { + using cursor_type = typename CursorContainer::value_type; + using payload_type = 
Payload; + using value_type = std::decay_t())>; + + using iterator_category = + typename std::iterator_traits::iterator_category; + static_assert(std::is_base_of(), + "cursors must be stored in a random access container"); + + constexpr MaxScoreJoin(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold) + : m_cursors(std::move(cursors)), + m_upper_bounds(m_cursors.size()), + m_init(std::move(init)), + m_accumulate(std::move(accumulate)), + m_above_threshold(std::move(above_threshold)), + m_size(std::nullopt) + { + initialize(); + } + + constexpr MaxScoreJoin(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold, + Inspect* inspect) + : m_cursors(std::move(cursors)), + m_upper_bounds(m_cursors.size()), + m_init(std::move(init)), + m_accumulate(std::move(accumulate)), + m_above_threshold(std::move(above_threshold)), + m_size(std::nullopt), + m_inspect(inspect) + { + initialize(); + } + + void initialize() + { + if (m_cursors.empty()) { + m_current_value = sentinel(); + m_current_payload = m_init; + } + std::sort(m_cursors.begin(), m_cursors.end(), [](auto&& lhs, auto&& rhs) { + return lhs.max_score() < rhs.max_score(); + }); + + m_upper_bounds[0] = m_cursors[0].max_score(); + for (size_t i = 1; i < m_cursors.size(); ++i) { + m_upper_bounds[i] = m_upper_bounds[i - 1] + m_cursors[i].max_score(); + } + + m_next_docid = min_value(m_cursors); + m_sentinel = min_sentinel(m_cursors); + advance(); + } + + [[nodiscard]] constexpr auto operator*() const noexcept -> value_type + { + return m_current_value; + } + [[nodiscard]] constexpr auto value() const noexcept -> value_type { return m_current_value; } + [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& + { + return m_current_payload; + } + [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_sentinel; } + + constexpr void advance() + { + while (m_non_essential_count < m_cursors.size() + 
&& not m_above_threshold(m_upper_bounds[m_non_essential_count])) { + m_non_essential_count += 1; + } + bool exit = false; + while (not exit) { + if (PISA_UNLIKELY(m_non_essential_count == m_cursors.size() + || m_next_docid >= sentinel())) { + m_current_value = sentinel(); + m_current_payload = m_init; + return; + } + m_current_payload = m_init; + m_current_value = std::exchange(m_next_docid, sentinel()); + + if constexpr (not std::is_void_v) { + m_inspect->document(); + } + + for (auto sorted_position = m_non_essential_count; sorted_position < m_cursors.size(); + sorted_position += 1) { + + auto& cursor = m_cursors[sorted_position]; + if (cursor.value() == m_current_value) { + if constexpr (not std::is_void_v) { + m_inspect->posting(); + } + m_current_payload = m_accumulate(m_current_payload, cursor); + cursor.advance(); + } + if (auto docid = cursor.value(); docid < m_next_docid) { + m_next_docid = docid; + } + } + + exit = true; + for (auto sorted_position = m_non_essential_count - 1; sorted_position + 1 > 0; + sorted_position -= 1) { + if (not m_above_threshold(m_current_payload + m_upper_bounds[sorted_position])) { + exit = false; + break; + } + auto& cursor = m_cursors[sorted_position]; + cursor.advance_to_geq(m_current_value); + if constexpr (not std::is_void_v) { + m_inspect->lookup(); + } + if (cursor.value() == m_current_value) { + m_current_payload = m_accumulate(m_current_payload, cursor); + } + } + } + } + + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; // TODO(michal) + [[nodiscard]] constexpr auto empty() const noexcept -> bool + { + return m_current_value >= sentinel(); + } + + private: + CursorContainer m_cursors; + std::vector m_upper_bounds; + payload_type m_init; + AccumulateFn m_accumulate; + ThresholdFn m_above_threshold; + std::optional m_size; + + value_type m_current_value{}; + value_type m_sentinel{}; + payload_type m_current_payload{}; + std::uint32_t m_next_docid{}; + std::size_t m_non_essential_count = 0; + 
payload_type m_previous_threshold{}; + + Inspect* m_inspect; +}; + +template +auto join_maxscore(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn threshold) +{ + return MaxScoreJoin( + std::move(cursors), std::move(init), std::move(accumulate), std::move(threshold)); +} + +template +auto join_maxscore(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn threshold, + Inspect* inspect) +{ + return MaxScoreJoin( + std::move(cursors), std::move(init), std::move(accumulate), std::move(threshold), inspect); +} + +template +auto maxscore(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + Inspect* inspect = nullptr) +{ + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { + return topk; + } + using cursor_type = decltype(index.max_scored_cursor(0, scorer)); + using value_type = decltype(index.max_scored_cursor(0, scorer).value()); + + auto cursors = index.max_scored_cursors(gsl::make_span(term_ids), scorer); + if (query.threshold()) { + topk.set_threshold(*query.threshold()); + } + auto joined = join_maxscore( + std::move(cursors), + 0.0F, + accumulators::Add{}, + [&](auto score) { return topk.would_enter(score); }, + inspect); + v1::for_each(joined, [&](auto& cursor) { + if constexpr (not std::is_void_v) { + if (topk.insert(cursor.payload(), cursor.value())) { + inspect->insert(); + } + } else { + topk.insert(cursor.payload(), cursor.value()); + } + }); + return topk; +} + +template +struct InspectMaxScore : Inspect { + + InspectMaxScore(Index const& index, Scorer const& scorer) + : Inspect(index, scorer) + { + } + + void run(Query const& query, Index const& index, Scorer const& scorer, topk_queue topk) override + { + maxscore(query, index, std::move(topk), scorer, this); + } +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/maxscore_union_lookup.hpp b/include/pisa/v1/maxscore_union_lookup.hpp new file mode 100644 index 000000000..a4539ba99 --- 
/dev/null +++ b/include/pisa/v1/maxscore_union_lookup.hpp @@ -0,0 +1,97 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#include "v1/algorithm.hpp" +#include "v1/cursor/labeled_cursor.hpp" +#include "v1/cursor/reference.hpp" +#include "v1/cursor/transform.hpp" +#include "v1/cursor_accumulator.hpp" +#include "v1/inspect_query.hpp" +#include "v1/query.hpp" +#include "v1/runtime_assert.hpp" +#include "v1/union_lookup_join.hpp" + +namespace pisa::v1 { + +/// This is a special case of Union-Lookup algorithm that does not use user-defined selections, +/// but rather uses the same way of determining essential list as Maxscore does. +/// The difference is that this algorithm will never update the threshold whereas Maxscore will +/// try to improve the estimate after each accumulated document. +template +auto maxscore_union_lookup(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + [[maybe_unused]] Inspect* inspect = nullptr) +{ + using cursor_type = decltype(index.max_scored_cursor(0, scorer)); + using payload_type = decltype(std::declval().payload()); + + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { + return topk; + } + auto threshold = query.get_threshold(); + topk.set_threshold(threshold); + + auto cursors = index.max_scored_cursors(gsl::make_span(term_ids), scorer); + auto [non_essential, essential] = maxscore_partition(gsl::make_span(cursors), threshold); + + std::vector essential_cursors; + std::move(essential.begin(), essential.end(), std::back_inserter(essential_cursors)); + std::vector lookup_cursors; + std::move(non_essential.begin(), non_essential.end(), std::back_inserter(lookup_cursors)); + std::reverse(lookup_cursors.begin(), lookup_cursors.end()); + + auto joined = join_union_lookup( + std::move(essential_cursors), + std::move(lookup_cursors), + payload_type{}, + accumulators::Add{}, + [&](auto score) { return topk.would_enter(score); }, + inspect); + v1::for_each(joined, 
[&](auto&& cursor) { + if constexpr (not std::is_void_v) { + if (topk.insert(cursor.payload(), cursor.value())) { + inspect->insert(); + } + } else { + topk.insert(cursor.payload(), cursor.value()); + } + }); + return topk; +} + +template +struct InspectMaxScoreUnionLookup : Inspect { + + InspectMaxScoreUnionLookup(Index const& index, Scorer const& scorer) + : Inspect(index, scorer) + { + } + + void run(Query const& query, Index const& index, Scorer const& scorer, topk_queue topk) override + { + maxscore_union_lookup(query, index, std::move(topk), scorer, this); + } +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/posting_builder.hpp b/include/pisa/v1/posting_builder.hpp new file mode 100644 index 000000000..a381aa568 --- /dev/null +++ b/include/pisa/v1/posting_builder.hpp @@ -0,0 +1,66 @@ +#pragma once + +#include +#include + +#include + +#include "v1/cursor_traits.hpp" +#include "v1/posting_format_header.hpp" +#include "v1/types.hpp" + +namespace pisa::v1 { + +/// Builds a "posting file" from passed values. +/// +/// TODO: Probably the offsets should be part of the file along with the size. 
+template +struct PostingBuilder { + template + explicit PostingBuilder(WriterImpl writer) : m_writer(Writer(std::move(writer))) + { + m_offsets.push_back(0); + } + + template + void write_header(std::basic_ostream& os) const + { + std::array header{}; + PostingFormatHeader{.version = FormatVersion::current(), + .type = value_type(), + .encoding = m_writer.encoding()} + .write(gsl::make_span(header)); + os.write(reinterpret_cast(header.data()), header.size()); + } + + template + auto write_segment(std::basic_ostream& os, ValueIterator first, ValueIterator last) + -> std::basic_ostream& + { + std::for_each(first, last, [&](auto&& value) { m_writer.push(value); }); + return flush_segment(os); + } + + void accumulate(Value value) { m_writer.push(value); } + + template + auto flush_segment(std::basic_ostream& os) -> std::basic_ostream& + { + m_offsets.push_back(m_offsets.back() + m_writer.write(os)); + m_writer.reset(); + return os; + } + + [[nodiscard]] auto offsets() const -> gsl::span + { + return gsl::make_span(m_offsets); + } + + [[nodiscard]] auto offsets() -> std::vector&& { return std::move(m_offsets); } + + private: + Writer m_writer; + std::vector m_offsets{}; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/posting_format_header.hpp b/include/pisa/v1/posting_format_header.hpp new file mode 100644 index 000000000..64f5678aa --- /dev/null +++ b/include/pisa/v1/posting_format_header.hpp @@ -0,0 +1,225 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include "v1/bit_cast.hpp" +#include "v1/types.hpp" + +namespace pisa::v1 { + +template +auto write_little_endian(Int number, gsl::span bytes) +{ + static_assert(std::is_integral_v); + Expects(bytes.size() == sizeof(Int)); + + Int mask{0xFF}; + for (unsigned int byte_num = 0; byte_num < sizeof(Int); byte_num += 1) { + auto byte_value = static_cast((number & mask) >> (8U * byte_num)); + bytes[byte_num] = byte_value; + mask <<= 8U; + } +} + +struct FormatVersion { + 
std::uint8_t major = 0; + std::uint8_t minor = 0; + std::uint8_t patch = 0; + + constexpr static auto parse(gsl::span bytes) -> FormatVersion + { + Expects(bytes.size() == 3); + return FormatVersion{ + bit_cast(bytes.first(1)), + bit_cast(bytes.subspan(1, 1)), + bit_cast(bytes.subspan(2, 1)), + }; + }; + + constexpr auto write(gsl::span bytes) -> void + { + bytes[0] = std::byte{major}; + bytes[1] = std::byte{minor}; + bytes[2] = std::byte{patch}; + }; + + constexpr static auto current() -> FormatVersion { return FormatVersion{0, 1, 0}; }; +}; + +enum class Primitive { Int = 0, Float = 1 }; + +struct Array { + Primitive type; +}; + +struct Tuple { + Primitive type; + std::uint8_t size; +}; + +[[nodiscard]] inline auto operator==(Tuple const &lhs, Tuple const &rhs) +{ + return lhs.type == rhs.type && lhs.size == rhs.size; +} + +using ValueType = std::variant; + +template +struct is_array : std::false_type { +}; + +template +struct is_array> : std::true_type { +}; + +template +struct array_length : public std::integral_constant { +}; + +template +struct array_length> : public std::integral_constant { +}; + +template +constexpr static auto value_type() -> ValueType +{ + if constexpr (std::is_integral_v) { + return Primitive::Int; + } else if constexpr (std::is_floating_point_v) { + return Primitive::Float; + } else if constexpr (is_array::value) { + auto len = array_length::value; + if constexpr (std::is_integral_v) { + return Tuple{Primitive::Int, len}; + } else if constexpr (std::is_floating_point_v) { + return Tuple{Primitive::Float, len}; + } else { + throw std::domain_error("Unsupported type"); + } + } else { + // TODO(michal): array + throw std::domain_error("Unsupported type"); + } +} + +template +constexpr static auto is_type(ValueType type) +{ + if constexpr (std::is_integral_v) { + return std::holds_alternative(type) + && std::get(type) == Primitive::Int; + } else if constexpr (std::is_floating_point_v) { + return std::holds_alternative(type) + && 
std::get(type) == Primitive::Float; + } else if constexpr (is_array::value) { + auto len = array_length::value; + if constexpr (std::is_integral_v) { + return std::holds_alternative(type) + && std::get(type) == Tuple{Primitive::Int, len}; + } else if constexpr (std::is_floating_point_v) { + return std::holds_alternative(type) + && std::get(type) == Tuple{Primitive::Float, len}; + } else { + throw std::domain_error("Unsupported type"); + } + } else { + // TODO(michal): array + throw std::domain_error("Unsupported type"); + } +} + +constexpr auto parse_type(std::byte const byte) -> ValueType +{ + auto element_type = [byte]() { + switch (std::to_integer((byte & std::byte{0b00000100}) >> 2)) { + case 0U: + return Primitive::Int; + case 1U: + return Primitive::Float; + } + Unreachable(); + }; + switch (std::to_integer(byte & std::byte{0b00000011})) { + case 0U: + return Primitive::Int; + case 1U: + return Primitive::Float; + case 2U: + return Array{element_type()}; + case 3U: + return Tuple{element_type(), + std::to_integer((byte & std::byte{0b11111000}) >> 3)}; + } + Unreachable(); +}; + +constexpr auto to_byte(ValueType type) -> std::byte +{ + std::byte byte{}; + std::visit(overloaded{[&byte](Primitive primitive) { + switch (primitive) { + case Primitive::Int: + byte = std::byte{0b00000000}; + break; + case Primitive::Float: + byte = std::byte{0b00000001}; + break; + } + }, + [&byte](Array arr) { + switch (arr.type) { + case Primitive::Int: + byte = std::byte{0b00000010}; + break; + case Primitive::Float: + byte = std::byte{0b00000110}; + break; + } + }, + [&byte](Tuple tup) { + switch (tup.type) { + case Primitive::Int: + byte = std::byte{0b00000011}; + break; + case Primitive::Float: + byte = std::byte{0b00000111}; + break; + } + byte |= (std::byte{tup.size} << 3); + }}, + type); + return byte; +}; + +using Encoding = std::uint32_t; + +struct PostingFormatHeader { + FormatVersion version; + ValueType type; + Encoding encoding; + + static auto parse(gsl::span bytes) 
-> PostingFormatHeader + { + Expects(bytes.size() == 8); + auto version = FormatVersion::parse(bytes.first(3)); + auto type = parse_type(bytes[3]); + auto encoding = bit_cast(bytes.subspan(4)); + return {version, type, encoding}; + }; + + void write(gsl::span bytes) + { + Expects(bytes.size() == 8); + FormatVersion::current().write(bytes.first(3)); + bytes[3] = to_byte(type); + write_little_endian(encoding, bytes.subspan(4)); + }; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/progress_status.hpp b/include/pisa/v1/progress_status.hpp new file mode 100644 index 000000000..84e8ca477 --- /dev/null +++ b/include/pisa/v1/progress_status.hpp @@ -0,0 +1,107 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "runtime_assert.hpp" +#include "type_safe.hpp" + +namespace pisa::v1 { + +/// Represents progress of a certain operation. +struct Progress { + /// A number between 0 and `target` to indicate the current progress. + std::size_t count; + /// The target value `count` reaches at completion. + std::size_t target; +}; + +/// An alias of the type of callback function used in a progress status. +using CallbackFunction = + std::function)>; + +/// This thread-safe object is responsible for keeping the current progress of an operation. +/// At a defined interval, it invokes a callback function with the current progress, and +/// the starting time of the operation (see `CallbackFunction` type alias). +/// +/// In order to ensure that terminal updates are not iterfered with, there should be no +/// writing to stdin or stderr outside of the callback function between `ProgressStatus` +/// construction and either its destruction (e.g., by going out of scope) or closing it +/// explicitly by calling `close()` member function. +struct ProgressStatus { + + /// Constructs a new progress status. + /// + /// \tparam Callback A callable that conforms to `CallbackFunction` signature. 
+ /// \tparam Duration A duration type that will be called to pause the status thread. + /// + /// \param target Usually a number of processed elements that will be incremented + /// throughout an operation. + /// \param callback A function that prints out the current status. See + /// `DefaultProgressCallback`. + /// \param interval A time interval between printing progress. + template + explicit ProgressStatus(std::size_t target, Callback&& callback, Duration interval) + : m_target(target), m_callback(std::forward(callback)) + { + m_loop = std::thread([this, interval]() { + this->m_callback(Progress{this->m_count.load(), this->m_target}, this->m_start); + while (this->m_count.load() < this->m_target) { + std::this_thread::sleep_for(interval); + this->m_callback(Progress{this->m_count.load(), this->m_target}, this->m_start); + } + }); + } + ProgressStatus(ProgressStatus const&) = delete; + ProgressStatus(ProgressStatus&&) = delete; + ProgressStatus& operator=(ProgressStatus const&) = delete; + ProgressStatus& operator=(ProgressStatus&&) = delete; + ~ProgressStatus(); + + /// Increments the counter by `inc`. + void operator+=(std::size_t inc) { m_count += inc; } + /// Increments the counter by 1. + void operator++() { m_count += 1; } + /// Increments the counter by 1. + void operator++(int) { m_count += 1; } + /// Sets the progress to 100% and terminates the progress thread. + /// This function should be only called if it is known that the operation is finished. + /// An example is a situation when it is one wants to print a message after finishing a task + /// but before the progress status goes out of the current scope. + /// However, otherwise it is better to just let it go out of scope, which will clean up + /// automatically to enable further writing to the standard output. 
+ void close(); + + private: + std::size_t const m_target; + CallbackFunction m_callback; + std::atomic_size_t m_count = 0; + std::chrono::time_point m_start = std::chrono::steady_clock::now(); + std::thread m_loop; + bool m_open = true; +}; + +/// This is the default callback that prints status in the following format: +/// `Building bigram index: 1% [1m 29s] [<1h 46m 49s]` or +/// `% [] [<]`. +struct DefaultProgressCallback { + DefaultProgressCallback() = default; + explicit DefaultProgressCallback(std::string caption); + DefaultProgressCallback(DefaultProgressCallback const&) = default; + DefaultProgressCallback(DefaultProgressCallback&&) noexcept = default; + DefaultProgressCallback& operator=(DefaultProgressCallback const&) = default; + DefaultProgressCallback& operator=(DefaultProgressCallback&&) noexcept = default; + ~DefaultProgressCallback() = default; + void operator()(Progress progress, std::chrono::time_point start); + + private: + std::size_t m_previous = 0; + std::string m_caption; + std::size_t m_prev_msg_len = 0; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/query.hpp b/include/pisa/v1/query.hpp new file mode 100644 index 000000000..c1db38fd7 --- /dev/null +++ b/include/pisa/v1/query.hpp @@ -0,0 +1,192 @@ +#pragma once + +#include +#include +//#include +#include + +#include +#include +#include +#include + +#include "topk_queue.hpp" +#include "v1/cursor/for_each.hpp" +#include "v1/cursor_intersection.hpp" +#include "v1/cursor_union.hpp" +#include "v1/intersection.hpp" +#include "v1/types.hpp" + +namespace pisa::v1 { + +struct ListSelection { + std::vector unigrams{}; + std::vector> bigrams{}; +}; + +struct TermIdSet { + explicit TermIdSet(std::vector terms) : m_term_list(std::move(terms)) + { + m_term_set = m_term_list; + ranges::sort(m_term_set); + ranges::actions::unique(m_term_set); + std::size_t pos = 0; + for (auto term_id : m_term_set) { + m_sorted_positions[term_id] = pos++; + } + } + + [[nodiscard]] auto sorted_position(TermId 
term) const -> std::size_t + { + return m_sorted_positions.at(term); + } + + [[nodiscard]] auto term_at_pos(std::size_t pos) const -> TermId + { + if (pos >= m_term_list.size()) { + throw std::out_of_range("Invalid intersections: term position out of bounds"); + } + return m_term_list[pos]; + } + + [[nodiscard]] auto get() const -> std::vector const& { return m_term_set; } + + private: + friend std::ostream& operator<<(std::ostream& os, TermIdSet const& term_ids); + std::vector m_term_list; + std::vector m_term_set{}; + std::unordered_map m_sorted_positions{}; +}; + +struct Query { + Query() = default; + + explicit Query(std::string query, tl::optional id = tl::nullopt); + explicit Query(std::vector term_ids, tl::optional id = tl::nullopt); + + template + static auto from_ids(Ids... ids) -> Query + { + std::vector id_vec; + (id_vec.push_back(ids), ...); + return Query(std::move(id_vec)); + } + + /// Setters for optional values (or ones with default value). + auto term_ids(std::vector term_ids) -> Query&; + auto id(std::string) -> Query&; + auto k(int k) -> Query&; + auto selections(gsl::span const> selections) -> Query&; + auto selections(ListSelection selections) -> Query&; + auto threshold(float threshold) -> Query&; + auto probability(float probability) -> Query&; + + /// Consuming setters. 
+ auto with_term_ids(std::vector term_ids) && -> Query; + auto with_id(std::string) && -> Query; + auto with_k(int k) && -> Query; + auto with_selections(gsl::span const> selections) && -> Query; + auto with_selections(ListSelection selections) && -> Query; + auto with_threshold(float threshold) && -> Query; + auto with_probability(float probability) && -> Query; + + /// Non-throwing getters + [[nodiscard]] auto term_ids() const -> tl::optional const&>; + [[nodiscard]] auto id() const -> tl::optional const&; + [[nodiscard]] auto k() const -> int; + [[nodiscard]] auto selections() const -> tl::optional; + [[nodiscard]] auto threshold() const -> tl::optional; + [[nodiscard]] auto probability() const -> tl::optional; + [[nodiscard]] auto raw() const -> tl::optional; + + /// Throwing getters + [[nodiscard]] auto get_term_ids() const -> std::vector const&; + [[nodiscard]] auto get_id() const -> std::string const&; + [[nodiscard]] auto get_selections() const -> ListSelection const&; + [[nodiscard]] auto get_threshold() const -> float; + [[nodiscard]] auto get_probability() const -> float; + [[nodiscard]] auto get_raw() const -> std::string const&; + + [[nodiscard]] auto sorted_position(TermId term) const -> std::size_t; + [[nodiscard]] auto term_at_pos(std::size_t pos) const -> TermId; + + template + [[nodiscard]] auto parse(Parser&& parser) + { + parser(*this); + } + + void add_selections(gsl::span const> selections); + + [[nodiscard]] auto filtered_terms(std::bitset<64> selection) const -> std::vector; + [[nodiscard]] auto to_json() const -> std::unique_ptr; + [[nodiscard]] static auto from_json(std::string_view) -> Query; + [[nodiscard]] static auto from_plain(std::string_view) -> Query; + + private: + friend std::ostream& operator<<(std::ostream& os, Query const& query); + auto resolve_term(std::size_t pos) -> TermId; + + tl::optional m_term_ids{}; + tl::optional m_selections{}; + tl::optional m_threshold{}; + tl::optional m_id{}; + tl::optional m_raw_string; + 
tl::optional m_probability; + int m_k = 1000; +}; + +template +std::ostream& operator<<(std::ostream& os, tl::optional const& value) +{ + if (not value) { + os << "None"; + } else { + os << "Some(" << *value << ")"; + } + return os; +} + +template +std::ostream& operator<<(std::ostream& os, std::pair const& p) +{ + return os << '(' << p.first << ", " << p.second << ')'; +} + +template +std::ostream& operator<<(std::ostream& os, std::vector const& vec) +{ + auto pos = vec.begin(); + os << '['; + if (pos != vec.end()) { + os << *pos++; + } + for (; pos != vec.end(); ++pos) { + os << ' ' << *pos; + } + os << ']'; + return os; +} + +inline std::ostream& operator<<(std::ostream& os, ListSelection const& selection) +{ + return os << "ListSelection { unigrams: " << selection.unigrams + << ", bigrams: " << selection.bigrams << " }"; +} + +inline std::ostream& operator<<(std::ostream& os, TermIdSet const& term_ids) +{ + return os << "TermIdSet { original: " << term_ids.m_term_list + << ", unique: " << term_ids.m_term_set << " }"; +} + +inline std::ostream& operator<<(std::ostream& os, Query const& query) +{ + return os << "Query { term_ids: " << query.m_term_ids << ", selections: " << query.m_selections + << " }"; +} + +/// Returns only unique terms, in sorted order. 
+[[nodiscard]] auto filter_unique_terms(Query const& query) -> std::vector; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/raw_cursor.hpp b/include/pisa/v1/raw_cursor.hpp new file mode 100644 index 000000000..0d7fefa39 --- /dev/null +++ b/include/pisa/v1/raw_cursor.hpp @@ -0,0 +1,176 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "util/likely.hpp" +#include "v1/base_index.hpp" +#include "v1/bit_cast.hpp" +#include "v1/cursor_traits.hpp" +#include "v1/types.hpp" + +namespace pisa::v1 { + +template +[[nodiscard]] auto next(Cursor&& cursor) -> tl::optional::value_type> +{ + cursor.advance(); + if (cursor.empty()) { + return tl::nullopt; + } + return tl::make_optional(cursor.value()); +} + +template +inline void contract(bool condition, std::string const& message, Args&&... args) +{ + if (not condition) { + throw std::logic_error(fmt::format(message, std::forward(args)...)); + } +} + +/// Uncompressed example of implementation of a single value cursor. +template +struct RawCursor { + static_assert(std::is_trivially_copyable_v); + using value_type = T; + + /// Creates a cursor from the encoded bytes. + explicit constexpr RawCursor(gsl::span bytes) : m_bytes(bytes.subspan(4)) + { + contract(m_bytes.size() % sizeof(T) == 0, + "Raw cursor memory size must be multiplier of element size ({}) but is {}", + sizeof(T), + m_bytes.size()); + contract(not m_bytes.empty(), "Raw cursor memory must not be empty"); + } + constexpr RawCursor(RawCursor const&) = default; + constexpr RawCursor(RawCursor&&) noexcept = default; + constexpr RawCursor& operator=(RawCursor const&) = default; + constexpr RawCursor& operator=(RawCursor&&) noexcept = default; + ~RawCursor() = default; + + /// Dereferences the current value. 
+ [[nodiscard]] constexpr auto operator*() const -> T + { + if (PISA_UNLIKELY(empty())) { + return sentinel(); + } + return bit_cast(gsl::as_bytes(m_bytes.subspan(m_current, sizeof(T)))); + } + + /// Alias for `operator*()`. + [[nodiscard]] constexpr auto value() const noexcept -> T { return *(*this); } + + /// Advances the cursor to the next position. + constexpr void advance() { m_current += sizeof(T); } + + /// Moves the cursor to the position `pos`. + constexpr void advance_to_position(std::size_t pos) { m_current = pos * sizeof(T); } + + /// Moves the cursor to the next value equal or greater than `value`. + constexpr void advance_to_geq(T value) + { + while (this->value() < value) { + advance(); + } + } + + /// Returns `true` if there is no elements left. + [[nodiscard]] constexpr auto empty() const noexcept -> bool + { + return m_current == m_bytes.size(); + } + + /// Returns the current position. + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return m_current / sizeof(T); + } + + /// Returns the number of elements in the list. + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_bytes.size() / sizeof(T); } + + /// The sentinel value, such that `value() != nullopt` is equivalent to `*(*this) < sentinel()`. 
+ [[nodiscard]] constexpr auto sentinel() const -> T { return std::numeric_limits::max(); } + + private: + std::size_t m_current = 0; + gsl::span m_bytes; +}; + +template +struct RawReader { + static_assert(std::is_trivially_copyable::value); + using value_type = T; + + [[nodiscard]] auto read(gsl::span bytes) const -> RawCursor + { + return RawCursor(bytes); + } + + void init(BaseIndex const& index) {} + constexpr static auto encoding() -> std::uint32_t { return EncodingId::Raw + sizeof(T); } +}; + +template +struct RawWriter { + static_assert(std::is_trivially_copyable::value); + using value_type = T; + + RawWriter() = default; + explicit RawWriter([[maybe_unused]] std::size_t num_documents) {} + + constexpr static auto encoding() -> std::uint32_t { return EncodingId::Raw + sizeof(T); } + + void init([[maybe_unused]] pisa::binary_freq_collection const& collection) {} + void push(T const& posting) { m_postings.push_back(posting); } + void push(T&& posting) { m_postings.push_back(posting); } + + template + [[nodiscard]] auto write(std::basic_ostream& os) const -> std::size_t + { + assert(!m_postings.empty()); + std::uint32_t length = m_postings.size(); + os.write(reinterpret_cast(&length), sizeof(length)); + auto memory = gsl::as_bytes(gsl::make_span(m_postings.data(), m_postings.size())); + os.write(reinterpret_cast(memory.data()), memory.size()); + return sizeof(length) + memory.size(); + } + + void reset() { m_postings.clear(); } + + private: + std::vector m_postings{}; +}; + +template +struct CursorTraits> { + using Value = T; + using Writer = RawWriter; + using Reader = RawReader; + constexpr static auto encoding() -> std::uint32_t { return EncodingId::Raw + sizeof(T); } +}; + +extern template struct RawCursor; +extern template struct RawCursor; +extern template struct RawCursor; +extern template struct RawReader; +extern template struct RawReader; +extern template struct RawReader; +extern template struct RawWriter; +extern template struct RawWriter; +extern 
template struct RawWriter; +extern template struct CursorTraits>; +extern template struct CursorTraits>; +extern template struct CursorTraits>; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/runtime_assert.hpp b/include/pisa/v1/runtime_assert.hpp new file mode 100644 index 000000000..736342334 --- /dev/null +++ b/include/pisa/v1/runtime_assert.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include + +#include +#include + +namespace pisa::v1 { + +struct RuntimeAssert { + explicit RuntimeAssert(bool condition) : passes_(condition) {} + + template + void or_exit(Message&& message) + { + if (not passes_) { + if constexpr (std::is_invocable_r_v) { + spdlog::error(message()); + } else { + spdlog::error("{}", message); + } + } + } + + template + void or_throw(Message&& message) + { + if (not passes_) { + if constexpr (std::is_invocable_r_v) { + throw std::runtime_error(message()); + } else { + throw std::runtime_error(std::forward(message)); + } + } + } + + private: + bool passes_; +}; + +[[nodiscard]] inline auto runtime_assert(bool condition) -> RuntimeAssert +{ + return RuntimeAssert(condition); +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/score_index.hpp b/include/pisa/v1/score_index.hpp new file mode 100644 index 000000000..e19aa2b85 --- /dev/null +++ b/include/pisa/v1/score_index.hpp @@ -0,0 +1,70 @@ +#pragma once + +#include + +#include + +#include "v1/index_metadata.hpp" + +namespace pisa::v1 { + +struct FixedBlock { + std::size_t size; +}; + +struct VariableBlock { + float lambda; +}; + +using BlockType = std::variant; + +template +auto score_index(Index const& index, + std::basic_ostream& os, + Writer writer, + Scorer scorer, + Quantizer&& quantizer, + Callback&& callback) -> std::vector +{ + PostingBuilder score_builder(writer); + score_builder.write_header(os); + std::for_each(boost::counting_iterator(0), + boost::counting_iterator(index.num_terms()), + [&](auto term) { + for_each(index.scoring_cursor(term, scorer), [&](auto&& cursor) { + 
score_builder.accumulate(quantizer(cursor.payload())); + }); + score_builder.flush_segment(os); + callback(); + }); + return std::move(score_builder.offsets()); +} + +template +auto score_index(Index const& index, std::basic_ostream& os, Writer writer, Scorer scorer) + -> std::vector +{ + PostingBuilder score_builder(writer); + score_builder.write_header(os); + std::for_each(boost::counting_iterator(0), + boost::counting_iterator(index.num_terms()), + [&](auto term) { + for_each(index.scoring_cursor(term, scorer), + [&](auto& cursor) { score_builder.accumulate(cursor.payload()); }); + score_builder.flush_segment(os); + }); + return std::move(score_builder.offsets()); +} + +auto score_index(IndexMetadata meta, std::size_t threads) -> IndexMetadata; +auto bm_score_index(IndexMetadata meta, + BlockType block_type, + tl::optional const& clone_path, + std::size_t threads) -> IndexMetadata; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/scorer/bm25.hpp b/include/pisa/v1/scorer/bm25.hpp new file mode 100644 index 000000000..31bfe7921 --- /dev/null +++ b/include/pisa/v1/scorer/bm25.hpp @@ -0,0 +1,79 @@ +#pragma once + +#include +#include +#include + +#include +#include + +#include "v1/types.hpp" + +namespace pisa::v1 { + +template +struct BM25 { + static constexpr float b = 0.4; + static constexpr float k1 = 0.9; + + explicit BM25(Index const& index) : m_index(index) {} + + struct TermScorer { + TermScorer(Index const& index, float term_weight) + : m_index(&index), m_term_weight(term_weight) + { + } + + auto operator()(std::uint32_t docid, std::uint32_t frequency) const + { + return m_term_weight + * doc_term_weight(frequency, m_index->normalized_document_length(docid)); + } + + private: + Index const* m_index; + float m_term_weight; + }; + + [[nodiscard]] static float doc_term_weight(uint64_t freq, float norm_len) + { + auto f = static_cast(freq); + return f / (f + k1 * (1.0F - b + b * norm_len)); + } + + [[nodiscard]] static float query_term_weight(uint64_t df, 
uint64_t num_docs) + { + auto fdf = static_cast(df); + float idf = std::log((float(num_docs) - fdf + 0.5F) / (fdf + 0.5F)); + static const float epsilon_score = 1.0E-6; + return std::max(epsilon_score, idf) * (1.0F + k1); + } + + [[nodiscard]] auto term_scorer(TermId term_id) const + { + auto term_weight = + query_term_weight(m_index.term_posting_count(term_id), m_index.num_documents()); + return TermScorer(m_index, term_weight); + } + + private: + Index const& m_index; +}; + +template +auto make_bm25(Index const& index) +{ + return BM25(index); +} + +} // namespace pisa::v1 + +namespace std { +template +struct hash<::pisa::v1::BM25> { + std::size_t operator()(::pisa::v1::BM25 const& /* bm25 */) const noexcept + { + return std::hash{}("bm25"); + } +}; +} // namespace std diff --git a/include/pisa/v1/scorer/runner.hpp b/include/pisa/v1/scorer/runner.hpp new file mode 100644 index 000000000..c042ec001 --- /dev/null +++ b/include/pisa/v1/scorer/runner.hpp @@ -0,0 +1,49 @@ +#pragma once + +#include +#include +#include + +#include + +namespace pisa::v1 { + +/// This is similar to the `IndexRunner` and it's used for running tasks +/// that require on-the-fly scoring. +template +struct ScorerRunner { + explicit ScorerRunner(Index const& index, Scorers... scorers) + : m_index(index), m_scorers(std::move(scorers...)) + { + } + + template + void operator()(std::string_view scorer_name, Fn fn) + { + auto run = [&](auto scorer) -> bool { + if (std::hash>{}(scorer) + == std::hash{}(scorer_name)) { + fn(std::move(scorer)); + return true; + } + return false; + }; + bool success = std::apply( + [&](Scorers... scorers) { return (run(std::move(scorers)) || ...); }, m_scorers); + if (not success) { + throw std::domain_error(fmt::format("Unknown scorer: {}", scorer_name)); + } + } + + private: + Index const& m_index; + std::tuple m_scorers; +}; + +template +auto scorer_runner(Index const& index, Scorers... 
scorers) +{ + return ScorerRunner(index, std::move(scorers...)); +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/sequence/indexed_sequence.hpp b/include/pisa/v1/sequence/indexed_sequence.hpp new file mode 100644 index 000000000..1f44fe7dc --- /dev/null +++ b/include/pisa/v1/sequence/indexed_sequence.hpp @@ -0,0 +1,939 @@ +#pragma once + +#include + +#include + +#include "bit_vector.hpp" +#include "global_parameters.hpp" +#include "util/compiler_attribute.hpp" +#include "util/likely.hpp" +#include "v1/bit_vector.hpp" + +namespace pisa::v1 { + +[[nodiscard]] constexpr auto positive(std::uint64_t n) -> std::uint64_t +{ + if (n == 0) { + throw std::logic_error("argument must be positive"); + } + return n; +} + +struct CompactEliasFano { + + struct offsets { + offsets() {} + + offsets(uint64_t base_offset, + uint64_t universe, + uint64_t n, + global_parameters const& params) + : universe(universe), + n(positive(n)), + log_sampling0(params.ef_log_sampling0), + log_sampling1(params.ef_log_sampling1), + lower_bits(universe > n ? 
broadword::msb(universe / n) : 0), + mask((uint64_t(1) << lower_bits) - 1), + // pad with a zero on both sides as sentinels + higher_bits_length(n + (universe >> lower_bits) + 2), + pointer_size(ceil_log2(higher_bits_length)), + pointers0((higher_bits_length - n) >> log_sampling0), // XXX + pointers1(n >> log_sampling1), + pointers0_offset(base_offset), + pointers1_offset(pointers0_offset + pointers0 * pointer_size), + higher_bits_offset(pointers1_offset + pointers1 * pointer_size), + lower_bits_offset(higher_bits_offset + higher_bits_length), + end(lower_bits_offset + n * lower_bits) + { + } + + uint64_t universe; + uint64_t n; + uint64_t log_sampling0; + uint64_t log_sampling1; + + uint64_t lower_bits; + uint64_t mask; + uint64_t higher_bits_length; + uint64_t pointer_size; + uint64_t pointers0; + uint64_t pointers1; + + uint64_t pointers0_offset; + uint64_t pointers1_offset; + uint64_t higher_bits_offset; + uint64_t lower_bits_offset; + uint64_t end; + }; + + static PISA_FLATTEN_FUNC uint64_t bitsize(global_parameters const& params, + uint64_t universe, + uint64_t n) + { + return offsets(0, universe, n, params).end; + } + + template + static void write(bit_vector_builder& bvb, + Iterator begin, + uint64_t universe, + uint64_t n, + global_parameters const& params) + { + uint64_t base_offset = bvb.size(); + offsets of(base_offset, universe, n, params); + // initialize all the bits to 0 + bvb.zero_extend(of.end - base_offset); + + uint64_t sample1_mask = (uint64_t(1) << of.log_sampling1) - 1; + uint64_t offset; + + // utility function to set 0 pointers + auto set_ptr0s = [&](uint64_t begin, uint64_t end, uint64_t rank_end) { + uint64_t begin_zeros = begin - rank_end; + uint64_t end_zeros = end - rank_end; + + for (uint64_t ptr0 = ceil_div(begin_zeros, uint64_t(1) << of.log_sampling0); + (ptr0 << of.log_sampling0) < end_zeros; + ++ptr0) { + if (!ptr0) + continue; + offset = of.pointers0_offset + (ptr0 - 1) * of.pointer_size; + assert(offset + of.pointer_size <= 
of.pointers1_offset); + bvb.set_bits(offset, (ptr0 << of.log_sampling0) + rank_end, of.pointer_size); + } + }; + + uint64_t last = 0; + uint64_t last_high = 0; + Iterator it = begin; + for (size_t i = 0; i < n; ++i) { + uint64_t v = *it++; + + if (i && v < last) { + throw std::runtime_error("Sequence is not sorted"); + } + + assert(v < universe); + + uint64_t high = (v >> of.lower_bits) + i + 1; + uint64_t low = v & of.mask; + + bvb.set(of.higher_bits_offset + high, 1); + + offset = of.lower_bits_offset + i * of.lower_bits; + assert(offset + of.lower_bits <= of.end); + bvb.set_bits(offset, low, of.lower_bits); + + if (i && (i & sample1_mask) == 0) { + uint64_t ptr1 = i >> of.log_sampling1; + assert(ptr1 > 0); + offset = of.pointers1_offset + (ptr1 - 1) * of.pointer_size; + assert(offset + of.pointer_size <= of.higher_bits_offset); + bvb.set_bits(offset, high, of.pointer_size); + } + + // write pointers for the run of zeros in [last_high, high) + set_ptr0s(last_high + 1, high, i); + last_high = high; + last = v; + } + + // pointers to zeros after the last 1 + set_ptr0s(last_high + 1, of.higher_bits_length, n); // XXX + } + + class enumerator { + public: + typedef std::pair value_type; // (position, value) + + enumerator() {} + + enumerator(BitVector const& bv, + uint64_t offset, + uint64_t universe, + uint64_t n, + global_parameters const& params) + : m_bv(&bv), + m_of(offset, universe, n, params), + m_position(size()), + m_value(m_of.universe) + { + } + + value_type move(uint64_t position) + { + assert(position <= m_of.n); + + if (position == m_position) { + return value(); + } + + uint64_t skip = position - m_position; + // optimize small forward skips + if (PISA_LIKELY(position > m_position && skip <= linear_scan_threshold)) { + m_position = position; + if (PISA_UNLIKELY(m_position == size())) { + m_value = m_of.universe; + } else { + BitVector::unary_enumerator he = m_high_enumerator; + for (size_t i = 0; i < skip; ++i) { + he.next(); + } + m_value = 
((he.position() - m_of.higher_bits_offset - m_position - 1) + << m_of.lower_bits) + | read_low(); + m_high_enumerator = he; + } + return value(); + } + + return slow_move(position); + } + + value_type next_geq(uint64_t lower_bound) + { + if (lower_bound == m_value) { + return value(); + } + + uint64_t high_lower_bound = lower_bound >> m_of.lower_bits; + uint64_t cur_high = m_value >> m_of.lower_bits; + uint64_t high_diff = high_lower_bound - cur_high; + + if (PISA_LIKELY(lower_bound > m_value && high_diff <= linear_scan_threshold)) { + // optimize small skips + next_reader next_value(*this, m_position + 1); + uint64_t val; + do { + m_position += 1; + if (PISA_LIKELY(m_position < size())) { + val = next_value(); + } else { + m_position = size(); + val = m_of.universe; + break; + } + } while (val < lower_bound); + + m_value = val; + return value(); + } else { + return slow_next_geq(lower_bound); + } + } + + uint64_t size() const { return m_of.n; } + + value_type next() + { + m_position += 1; + assert(m_position <= size()); + + if (PISA_LIKELY(m_position < size())) { + m_value = read_next(); + } else { + m_value = m_of.universe; + } + return value(); + } + + uint64_t prev_value() const + { + if (m_position == 0) { + return 0; + } + + uint64_t prev_high = 0; + if (PISA_LIKELY(m_position < size())) { + prev_high = m_bv->predecessor1(m_high_enumerator.position() - 1); + } else { + prev_high = m_bv->predecessor1(m_of.lower_bits_offset - 1); + } + prev_high -= m_of.higher_bits_offset; + + uint64_t prev_pos = m_position - 1; + uint64_t prev_low = + m_bv->get_word56(m_of.lower_bits_offset + prev_pos * m_of.lower_bits) & m_of.mask; + return ((prev_high - prev_pos - 1) << m_of.lower_bits) | prev_low; + } + + uint64_t position() const { return m_position; } + + inline value_type value() const { return value_type(m_position, m_value); } + + private: + value_type PISA_NOINLINE slow_move(uint64_t position) + { + if (PISA_UNLIKELY(position == size())) { + m_position = position; + 
m_value = m_of.universe; + return value(); + } + + uint64_t skip = position - m_position; + uint64_t to_skip; + if (position > m_position && (skip >> m_of.log_sampling1) == 0) { + to_skip = skip - 1; + } else { + uint64_t ptr = position >> m_of.log_sampling1; + uint64_t high_pos = pointer1(ptr); + uint64_t high_rank = ptr << m_of.log_sampling1; + m_high_enumerator = + BitVector::unary_enumerator(*m_bv, m_of.higher_bits_offset + high_pos); + to_skip = position - high_rank; + } + + m_high_enumerator.skip(to_skip); + m_position = position; + m_value = read_next(); + return value(); + } + + value_type PISA_NOINLINE slow_next_geq(uint64_t lower_bound) + { + if (PISA_UNLIKELY(lower_bound >= m_of.universe)) { + return move(size()); + } + + uint64_t high_lower_bound = lower_bound >> m_of.lower_bits; + uint64_t cur_high = m_value >> m_of.lower_bits; + uint64_t high_diff = high_lower_bound - cur_high; + + // XXX bounds checking! + uint64_t to_skip; + if (lower_bound > m_value && (high_diff >> m_of.log_sampling0) == 0) { + // note: at the current position in the bitvector there + // should be a 1, but since we already consumed it, it + // is 0 in the enumerator, so we need to skip it + to_skip = high_diff; + } else { + uint64_t ptr = high_lower_bound >> m_of.log_sampling0; + uint64_t high_pos = pointer0(ptr); + uint64_t high_rank0 = ptr << m_of.log_sampling0; + + m_high_enumerator = + BitVector::unary_enumerator(*m_bv, m_of.higher_bits_offset + high_pos); + to_skip = high_lower_bound - high_rank0; + } + + m_high_enumerator.skip0(to_skip); + m_position = m_high_enumerator.position() - m_of.higher_bits_offset - high_lower_bound; + + next_reader read_value(*this, m_position); + while (true) { + if (PISA_UNLIKELY(m_position == size())) { + m_value = m_of.universe; + return value(); + } + auto val = read_value(); + if (val >= lower_bound) { + m_value = val; + return value(); + } + m_position++; + } + } + + static const uint64_t linear_scan_threshold = 8; + + inline uint64_t 
read_low() + { + return m_bv->get_word56(m_of.lower_bits_offset + m_position * m_of.lower_bits) + & m_of.mask; + } + + inline uint64_t read_next() + { + assert(m_position < size()); + uint64_t high = m_high_enumerator.next() - m_of.higher_bits_offset; + return ((high - m_position - 1) << m_of.lower_bits) | read_low(); + } + + struct next_reader { + next_reader(enumerator& e, uint64_t position) + : e(e), + high_enumerator(e.m_high_enumerator), + high_base(e.m_of.higher_bits_offset + position + 1), + lower_bits(e.m_of.lower_bits), + lower_base(e.m_of.lower_bits_offset + position * lower_bits), + mask(e.m_of.mask), + bv(*e.m_bv) + { + } + + ~next_reader() { e.m_high_enumerator = high_enumerator; } + + uint64_t operator()() + { + uint64_t high = high_enumerator.next() - high_base; + uint64_t low = bv.get_word56(lower_base) & mask; + high_base += 1; + lower_base += lower_bits; + return (high << lower_bits) | low; + } + + enumerator& e; + BitVector::unary_enumerator high_enumerator; + uint64_t high_base, lower_bits, lower_base, mask; + BitVector const& bv; + }; + + inline uint64_t pointer(uint64_t offset, uint64_t i) const + { + if (i == 0) { + return 0; + } else { + return m_bv->get_word56(offset + (i - 1) * m_of.pointer_size) + & ((uint64_t(1) << m_of.pointer_size) - 1); + } + } + + inline uint64_t pointer0(uint64_t i) const { return pointer(m_of.pointers0_offset, i); } + + inline uint64_t pointer1(uint64_t i) const { return pointer(m_of.pointers1_offset, i); } + + BitVector const* m_bv; + offsets m_of; + + uint64_t m_position; + uint64_t m_value; + BitVector::unary_enumerator m_high_enumerator; + }; +}; + +struct CompactRankedBitvector { + + struct offsets { + offsets(uint64_t base_offset, + uint64_t universe, + uint64_t n, + global_parameters const& params) + : universe(universe), + n(n), + log_rank1_sampling(params.rb_log_rank1_sampling), + log_sampling1(params.rb_log_sampling1) + + , + rank1_sample_size(ceil_log2(n + 1)), + pointer_size(ceil_log2(universe)), + 
rank1_samples(universe >> params.rb_log_rank1_sampling), + pointers1(n >> params.rb_log_sampling1) + + , + rank1_samples_offset(base_offset), + pointers1_offset(rank1_samples_offset + rank1_samples * rank1_sample_size), + bits_offset(pointers1_offset + pointers1 * pointer_size), + end(bits_offset + universe) + { + } + + uint64_t universe; + uint64_t n; + uint64_t log_rank1_sampling; + uint64_t log_sampling1; + + uint64_t rank1_sample_size; + uint64_t pointer_size; + + uint64_t rank1_samples; + uint64_t pointers1; + + uint64_t rank1_samples_offset; + uint64_t pointers1_offset; + uint64_t bits_offset; + uint64_t end; + }; + + static PISA_FLATTEN_FUNC uint64_t bitsize(global_parameters const& params, + uint64_t universe, + uint64_t n) + { + return offsets(0, universe, n, params).end; + } + + template + static void write(bit_vector_builder& bvb, + Iterator begin, + uint64_t universe, + uint64_t n, + global_parameters const& params) + { + uint64_t base_offset = bvb.size(); + offsets of(base_offset, universe, n, params); + // initialize all the bits to 0 + bvb.zero_extend(of.end - base_offset); + + uint64_t offset; + + auto set_rank1_samples = [&](uint64_t begin, uint64_t end, uint64_t rank) { + for (uint64_t sample = ceil_div(begin, uint64_t(1) << of.log_rank1_sampling); + (sample << of.log_rank1_sampling) < end; + ++sample) { + if (!sample) + continue; + offset = of.rank1_samples_offset + (sample - 1) * of.rank1_sample_size; + assert(offset + of.rank1_sample_size <= of.pointers1_offset); + bvb.set_bits(offset, rank, of.rank1_sample_size); + } + }; + + uint64_t sample1_mask = (uint64_t(1) << of.log_sampling1) - 1; + uint64_t last = 0; + Iterator it = begin; + for (size_t i = 0; i < n; ++i) { + uint64_t v = *it++; + if (i && v == last) { + throw std::runtime_error("Duplicate element"); + } + if (i && v < last) { + throw std::runtime_error("Sequence is not sorted"); + } + + assert(!i || v > last); + assert(v <= universe); + + bvb.set(of.bits_offset + v, 1); + + if (i && 
(i & sample1_mask) == 0) { + uint64_t ptr1 = i >> of.log_sampling1; + assert(ptr1 > 0); + offset = of.pointers1_offset + (ptr1 - 1) * of.pointer_size; + assert(offset + of.pointer_size <= of.bits_offset); + bvb.set_bits(offset, v, of.pointer_size); + } + + set_rank1_samples(last + 1, v + 1, i); + last = v; + } + + set_rank1_samples(last + 1, universe, n); + } + + class enumerator { + public: + typedef std::pair value_type; // (position, value) + + enumerator(BitVector const& bv, + uint64_t offset, + uint64_t universe, + uint64_t n, + global_parameters const& params) + : m_bv(&bv), + m_of(offset, universe, n, params), + m_position(size()), + m_value(m_of.universe) + { + } + + value_type move(uint64_t position) + { + assert(position <= size()); + + if (position == m_position) { + return value(); + } + + // optimize small forward skips + uint64_t skip = position - m_position; + if (PISA_LIKELY(position > m_position && skip <= linear_scan_threshold)) { + m_position = position; + if (PISA_UNLIKELY(m_position == size())) { + m_value = m_of.universe; + } else { + BitVector::unary_enumerator he = m_enumerator; + for (size_t i = 0; i < skip; ++i) { + he.next(); + } + m_value = he.position() - m_of.bits_offset; + m_enumerator = he; + } + + return value(); + } + + return slow_move(position); + } + + value_type next_geq(uint64_t lower_bound) + { + if (lower_bound == m_value) { + return value(); + } + + uint64_t diff = lower_bound - m_value; + if (PISA_LIKELY(lower_bound > m_value && diff <= linear_scan_threshold)) { + // optimize small skips + BitVector::unary_enumerator he = m_enumerator; + uint64_t val; + do { + m_position += 1; + if (PISA_LIKELY(m_position < size())) { + val = he.next() - m_of.bits_offset; + } else { + m_position = size(); + val = m_of.universe; + break; + } + } while (val < lower_bound); + + m_value = val; + m_enumerator = he; + return value(); + } else { + return slow_next_geq(lower_bound); + } + } + + value_type next() + { + m_position += 1; + 
assert(m_position <= size()); + + if (PISA_LIKELY(m_position < size())) { + m_value = read_next(); + } else { + m_value = m_of.universe; + } + return value(); + } + + uint64_t size() const { return m_of.n; } + + uint64_t prev_value() const + { + if (m_position == 0) { + return 0; + } + + uint64_t pos = 0; + if (PISA_LIKELY(m_position < size())) { + pos = m_bv->predecessor1(m_enumerator.position() - 1); + } else { + pos = m_bv->predecessor1(m_of.end - 1); + } + + return pos - m_of.bits_offset; + } + + private: + value_type PISA_NOINLINE slow_move(uint64_t position) + { + uint64_t skip = position - m_position; + if (PISA_UNLIKELY(position == size())) { + m_position = position; + m_value = m_of.universe; + return value(); + } + + uint64_t to_skip; + if (position > m_position && (skip >> m_of.log_sampling1) == 0) { + to_skip = skip - 1; + } else { + uint64_t ptr = position >> m_of.log_sampling1; + uint64_t ptr_pos = pointer1(ptr); + + m_enumerator = BitVector::unary_enumerator(*m_bv, m_of.bits_offset + ptr_pos); + to_skip = position - (ptr << m_of.log_sampling1); + } + + m_enumerator.skip(to_skip); + m_position = position; + m_value = read_next(); + + return value(); + } + + value_type PISA_NOINLINE slow_next_geq(uint64_t lower_bound) + { + using broadword::popcount; + + if (PISA_UNLIKELY(lower_bound >= m_of.universe)) { + return move(size()); + } + + uint64_t skip = lower_bound - m_value; + m_enumerator = BitVector::unary_enumerator(*m_bv, m_of.bits_offset + lower_bound); + + uint64_t begin; + if (lower_bound > m_value && (skip >> m_of.log_rank1_sampling) == 0) { + begin = m_of.bits_offset + m_value; + } else { + uint64_t block = lower_bound >> m_of.log_rank1_sampling; + m_position = rank1_sample(block); + + begin = m_of.bits_offset + (block << m_of.log_rank1_sampling); + } + + uint64_t end = m_of.bits_offset + lower_bound; + uint64_t begin_word = begin / 64; + uint64_t begin_shift = begin % 64; + uint64_t end_word = end / 64; + uint64_t end_shift = end % 64; + 
uint64_t word = (m_bv->data()[begin_word] >> begin_shift) << begin_shift; + + while (begin_word < end_word) { + m_position += popcount(word); + word = m_bv->data()[++begin_word]; + } + if (end_shift) { + m_position += popcount(word << (64 - end_shift)); + } + + if (m_position < size()) { + m_value = read_next(); + } else { + m_value = m_of.universe; + } + + return value(); + } + + static const uint64_t linear_scan_threshold = 8; + + inline value_type value() const { return value_type(m_position, m_value); } + + inline uint64_t read_next() { return m_enumerator.next() - m_of.bits_offset; } + + inline uint64_t pointer(uint64_t offset, uint64_t i, uint64_t size) const + { + if (i == 0) { + return 0; + } else { + return m_bv->get_word56(offset + (i - 1) * size) & ((uint64_t(1) << size) - 1); + } + } + + inline uint64_t pointer1(uint64_t i) const + { + return pointer(m_of.pointers1_offset, i, m_of.pointer_size); + } + + inline uint64_t rank1_sample(uint64_t i) const + { + return pointer(m_of.rank1_samples_offset, i, m_of.rank1_sample_size); + } + + BitVector const* m_bv; + offsets m_of; + + uint64_t m_position; + uint64_t m_value; + BitVector::unary_enumerator m_enumerator; + }; +}; + +struct AllOnesSequence { + + inline static uint64_t bitsize(global_parameters const& /* params */, + uint64_t universe, + uint64_t n) + { + return (universe == n) ? 
0 : uint64_t(-1); + } + + template + static void write( + bit_vector_builder&, Iterator, uint64_t universe, uint64_t n, global_parameters const&) + { + assert(universe == n); + (void)universe; + (void)n; + } + + class enumerator { + public: + typedef std::pair value_type; // (position, value) + + enumerator( + BitVector const&, uint64_t, uint64_t universe, uint64_t n, global_parameters const&) + : m_universe(universe), m_position(size()) + { + assert(universe == n); + (void)n; + } + + value_type move(uint64_t position) + { + assert(position <= size()); + m_position = position; + return value_type(m_position, m_position); + } + + value_type next_geq(uint64_t lower_bound) + { + assert(lower_bound <= size()); + m_position = lower_bound; + return value_type(m_position, m_position); + } + + value_type next() + { + m_position += 1; + return value_type(m_position, m_position); + } + + uint64_t size() const { return m_universe; } + + uint64_t prev_value() const + { + if (m_position == 0) { + return 0; + } + return m_position - 1; + } + + private: + uint64_t m_universe; + uint64_t m_position; + }; +}; + +struct IndexedSequence { + + enum index_type { + elias_fano = 0, + ranked_bitvector = 1, + all_ones = 2, + + index_types = 3 + }; + + static const uint64_t type_bits = 1; // all_ones is implicit + + static PISA_FLATTEN_FUNC uint64_t bitsize(global_parameters const& params, + uint64_t universe, + uint64_t n) + { + uint64_t best_cost = AllOnesSequence::bitsize(params, universe, n); + + uint64_t ef_cost = CompactEliasFano::bitsize(params, universe, n) + type_bits; + if (ef_cost < best_cost) { + best_cost = ef_cost; + } + + uint64_t rb_cost = CompactRankedBitvector::bitsize(params, universe, n) + type_bits; + if (rb_cost < best_cost) { + best_cost = rb_cost; + } + + return best_cost; + } + + template + static void write(bit_vector_builder& bvb, + Iterator begin, + uint64_t universe, + uint64_t n, + global_parameters const& params) + { + uint64_t best_cost = 
AllOnesSequence::bitsize(params, universe, n); + int best_type = all_ones; + + if (best_cost) { + uint64_t ef_cost = CompactEliasFano::bitsize(params, universe, n) + type_bits; + if (ef_cost < best_cost) { + best_cost = ef_cost; + best_type = elias_fano; + } + + uint64_t rb_cost = CompactRankedBitvector::bitsize(params, universe, n) + type_bits; + if (rb_cost < best_cost) { + best_cost = rb_cost; + best_type = ranked_bitvector; + } + + bvb.append_bits(best_type, type_bits); + } + + switch (best_type) { + case elias_fano: + CompactEliasFano::write(bvb, begin, universe, n, params); + break; + case ranked_bitvector: + CompactRankedBitvector::write(bvb, begin, universe, n, params); + break; + case all_ones: + AllOnesSequence::write(bvb, begin, universe, n, params); + break; + default: + assert(false); + } + } + + class enumerator { + public: + typedef std::pair value_type; // (position, value) + + enumerator() {} + + enumerator(BitVector const& bv, + uint64_t offset, + uint64_t universe, + uint64_t n, + global_parameters const& params) + : m_universe(universe) + { + if (AllOnesSequence::bitsize(params, universe, n) == 0) { + m_type = all_ones; + } else { + m_type = index_type(bv.get_word56(offset) & ((uint64_t(1) << type_bits) - 1)); + } + + switch (m_type) { + case elias_fano: + m_enumerator = + CompactEliasFano::enumerator(bv, offset + type_bits, universe, n, params); + break; + case ranked_bitvector: + m_enumerator = + CompactRankedBitvector::enumerator(bv, offset + type_bits, universe, n, params); + break; + case all_ones: + m_enumerator = + AllOnesSequence::enumerator(bv, offset + type_bits, universe, n, params); + break; + default: + throw std::invalid_argument("Unsupported type"); + } + } + + value_type move(uint64_t position) + { + return boost::apply_visitor([&position](auto&& e) { return e.move(position); }, + m_enumerator); + } + value_type next_geq(uint64_t lower_bound) + { + return boost::apply_visitor( + [&lower_bound](auto&& e) { return 
e.next_geq(lower_bound); }, m_enumerator); + } + value_type next() + { + return boost::apply_visitor([](auto&& e) { return e.next(); }, m_enumerator); + } + + uint64_t size() const + { + return boost::apply_visitor([](auto&& e) { return e.size(); }, m_enumerator); + } + + uint64_t prev_value() const + { + return boost::apply_visitor([](auto&& e) { return e.prev_value(); }, m_enumerator); + } + + [[nodiscard]] auto universe() const -> std::uint64_t { return m_universe; } + + private: + index_type m_type; + boost::variant + m_enumerator; + std::uint32_t m_universe; + }; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/sequence/partitioned_sequence.hpp b/include/pisa/v1/sequence/partitioned_sequence.hpp new file mode 100644 index 000000000..6dce91c36 --- /dev/null +++ b/include/pisa/v1/sequence/partitioned_sequence.hpp @@ -0,0 +1,403 @@ +#pragma once + +#include "tbb/task_group.h" +#include + +#include "codec/integer_codes.hpp" +#include "configuration.hpp" +#include "global_parameters.hpp" +#include "optimal_partition.hpp" +#include "util/util.hpp" +#include "v1/bit_vector.hpp" +#include "v1/sequence/indexed_sequence.hpp" + +namespace pisa::v1 { + +template +struct PartitionedSequence { + + using base_sequence_type = BaseSequence; + using base_sequence_enumerator = typename base_sequence_type::enumerator; + + template + static void write(bit_vector_builder& bvb, + Iterator begin, + std::uint64_t universe, + std::uint64_t n, + global_parameters const& params) + { + assert(n > 0); + auto partition = compute_partition(begin, universe, n, params); + + size_t partitions = partition.size(); + assert(partitions > 0); + assert(partition.front() != 0); + assert(partition.back() == n); + write_gamma_nonzero(bvb, partitions); + + std::vector cur_partition; + std::uint64_t cur_base = 0; + if (partitions == 1) { + cur_base = *begin; + Iterator it = begin; + + for (size_t i = 0; i < n; ++i, ++it) { + cur_partition.push_back(*it - cur_base); + } + + std::uint64_t 
universe_bits = ceil_log2(universe); + bvb.append_bits(cur_base, universe_bits); + + // write universe only if non-singleton and not tight + if (n > 1) { + if (cur_base + cur_partition.back() + 1 == universe) { + // tight universe + write_delta(bvb, 0); + } else { + write_delta(bvb, cur_partition.back()); + } + } + + base_sequence_type::write( + bvb, cur_partition.begin(), cur_partition.back() + 1, cur_partition.size(), params); + } else { + bit_vector_builder bv_sequences; + std::vector endpoints; + std::vector upper_bounds; + + std::uint64_t cur_i = 0; + Iterator it = begin; + cur_base = *begin; + upper_bounds.push_back(cur_base); + + for (size_t p = 0; p < partition.size(); ++p) { + cur_partition.clear(); + std::uint64_t value = 0; + for (; cur_i < partition[p]; ++cur_i, ++it) { + value = *it; + cur_partition.push_back(value - cur_base); + } + + std::uint64_t upper_bound = value; + assert(cur_partition.size() > 0); + base_sequence_type::write(bv_sequences, + cur_partition.begin(), + cur_partition.back() + 1, + cur_partition.size(), // XXX skip last one? 
+ params); + endpoints.push_back(bv_sequences.size()); + upper_bounds.push_back(upper_bound); + cur_base = upper_bound + 1; + } + + bit_vector_builder bv_sizes; + CompactEliasFano::write(bv_sizes, partition.begin(), n, partitions - 1, params); + + bit_vector_builder bv_upper_bounds; + CompactEliasFano::write( + bv_upper_bounds, upper_bounds.begin(), universe, partitions + 1, params); + + std::uint64_t endpoint_bits = ceil_log2(bv_sequences.size() + 1); + write_gamma(bvb, endpoint_bits); + + bvb.append(bv_sizes); + bvb.append(bv_upper_bounds); + + for (std::uint64_t p = 0; p < endpoints.size() - 1; ++p) { + bvb.append_bits(endpoints[p], endpoint_bits); + } + + bvb.append(bv_sequences); + } + } + + class enumerator { + public: + using value_type = std::pair; // (position, value) + + enumerator(BitVector const& bv, + std::uint64_t offset, + std::uint64_t universe, + std::uint64_t n, + global_parameters const& params) + : m_params(params), m_size(n), m_universe(universe), m_bv(&bv) + { + BitVector::enumerator it(bv, offset); + m_partitions = read_gamma_nonzero(it); + if (m_partitions == 1) { + m_cur_partition = 0; + m_cur_begin = 0; + m_cur_end = n; + + std::uint64_t universe_bits = ceil_log2(universe); + m_cur_base = it.take(universe_bits); + auto ub = 0; + if (n > 1) { + std::uint64_t universe_delta = read_delta(it); + ub = universe_delta > 0U ? 
universe_delta : (universe - m_cur_base - 1); + } + + m_partition_enum = + base_sequence_enumerator(*m_bv, it.position(), ub + 1, n, m_params); + + m_cur_upper_bound = m_cur_base + ub; + } else { + m_endpoint_bits = read_gamma(it); + + std::uint64_t cur_offset = it.position(); + m_sizes = + CompactEliasFano::enumerator(bv, cur_offset, n, m_partitions - 1, params); + cur_offset += CompactEliasFano::bitsize(params, n, m_partitions - 1); + + m_upper_bounds = CompactEliasFano::enumerator( + bv, cur_offset, universe, m_partitions + 1, params); + cur_offset += CompactEliasFano::bitsize(params, universe, m_partitions + 1); + + m_endpoints_offset = cur_offset; + std::uint64_t endpoints_size = m_endpoint_bits * (m_partitions - 1); + cur_offset += endpoints_size; + + m_sequences_offset = cur_offset; + } + + m_position = size(); + slow_move(); + } + + value_type PISA_ALWAYSINLINE move(std::uint64_t position) + { + assert(position <= size()); + m_position = position; + + if (m_position >= m_cur_begin && m_position < m_cur_end) { + std::uint64_t val = + m_cur_base + m_partition_enum.move(m_position - m_cur_begin).second; + return value_type(m_position, val); + } + + return slow_move(); + } + + // note: this is instantiated oly if BaseSequence has next_geq + template > + value_type PISA_ALWAYSINLINE next_geq(std::uint64_t lower_bound) + { + if (PISA_LIKELY(lower_bound >= m_cur_base && lower_bound <= m_cur_upper_bound)) { + auto val = m_partition_enum.next_geq(lower_bound - m_cur_base); + m_position = m_cur_begin + val.first; + return value_type(m_position, m_cur_base + val.second); + } + return slow_next_geq(lower_bound); + } + + value_type PISA_ALWAYSINLINE next() + { + ++m_position; + + if (PISA_LIKELY(m_position < m_cur_end)) { + std::uint64_t val = m_cur_base + m_partition_enum.next().second; + return value_type(m_position, val); + } + return slow_next(); + } + + [[nodiscard]] auto size() const -> std::uint64_t { return m_size; } + + [[nodiscard]] auto prev_value() const -> 
std::uint64_t + { + if (PISA_UNLIKELY(m_position == m_cur_begin)) { + return m_cur_partition ? m_cur_base - 1 : 0; + } + return m_cur_base + m_partition_enum.prev_value(); + } + + [[nodiscard]] auto num_partitions() const -> std::uint64_t { return m_partitions; } + [[nodiscard]] auto universe() const -> std::uint64_t { return m_universe; } + + friend class partitioned_sequence_test; + + private: + // the compiler does not seem smart enough to figure out that this + // is a very unlikely condition, and inlines the move(0) inside the + // next(), causing the code to grow. Since next is called in very + // tight loops, on microbenchmarks this causes an improvement of + // about 3ns on my i7 3Ghz + value_type PISA_NOINLINE slow_next() + { + if (PISA_UNLIKELY(m_position == m_size)) { + assert(m_cur_partition == m_partitions - 1); + auto val = m_partition_enum.next(); + assert(val.first == m_partition_enum.size()); + (void)val; + return value_type(m_position, m_universe); + } + + switch_partition(m_cur_partition + 1); + std::uint64_t val = m_cur_base + m_partition_enum.move(0).second; + return value_type(m_position, val); + } + + value_type PISA_NOINLINE slow_move() + { + if (m_position == size()) { + if (m_partitions > 1) { + switch_partition(m_partitions - 1); + } + m_partition_enum.move(m_partition_enum.size()); + return value_type(m_position, m_universe); + } + auto size_it = m_sizes.next_geq(m_position + 1); // need endpoint strictly > m_position + switch_partition(size_it.first); + std::uint64_t val = m_cur_base + m_partition_enum.move(m_position - m_cur_begin).second; + return value_type(m_position, val); + } + + value_type PISA_NOINLINE slow_next_geq(std::uint64_t lower_bound) + { + if (m_partitions == 1) { + if (lower_bound < m_cur_base) { + return move(0); + } + return move(size()); + } + + auto ub_it = m_upper_bounds.next_geq(lower_bound); + if (ub_it.first == 0) { + return move(0); + } + + if (ub_it.first == m_upper_bounds.size()) { + return move(size()); + } 
+ + switch_partition(ub_it.first - 1); + return next_geq(lower_bound); + } + + void switch_partition(std::uint64_t partition) + { + assert(m_partitions > 1); + + std::uint64_t endpoint = + partition > 0U + ? (m_bv->get_word56(m_endpoints_offset + (partition - 1) * m_endpoint_bits) + & ((std::uint64_t(1) << m_endpoint_bits) - 1)) + : 0; + + std::uint64_t partition_begin = m_sequences_offset + endpoint; + intrinsics::prefetch(std::next(m_bv->data(), partition_begin / 64)); + + m_cur_partition = partition; + auto size_it = m_sizes.move(partition); + m_cur_end = size_it.second; + m_cur_begin = m_sizes.prev_value(); + + auto ub_it = m_upper_bounds.move(partition + 1); + m_cur_upper_bound = ub_it.second; + m_cur_base = m_upper_bounds.prev_value() + (partition > 0 ? 1 : 0); + + m_partition_enum = base_sequence_enumerator(*m_bv, + partition_begin, + m_cur_upper_bound - m_cur_base + 1, + m_cur_end - m_cur_begin, + m_params); + } + + global_parameters m_params; + std::uint64_t m_partitions; + std::uint64_t m_endpoints_offset; + std::uint64_t m_endpoint_bits; + std::uint64_t m_sequences_offset; + std::uint64_t m_size; + std::uint64_t m_universe; + + std::uint64_t m_position; + std::uint64_t m_cur_partition; + std::uint64_t m_cur_begin; + std::uint64_t m_cur_end; + std::uint64_t m_cur_base; + std::uint64_t m_cur_upper_bound; + + BitVector const* m_bv; + CompactEliasFano::enumerator m_sizes; + CompactEliasFano::enumerator m_upper_bounds; + base_sequence_enumerator m_partition_enum; + }; + + private: + template + static std::vector compute_partition(Iterator begin, + std::uint64_t universe, + std::uint64_t n, + global_parameters const& params) + { + std::vector partition; + + auto const& conf = configuration::get(); + + if (base_sequence_type::bitsize(params, universe, n) < 2 * conf.fix_cost) { + partition.push_back(n); + return partition; + } + + auto cost_fun = [&](std::uint64_t universe, std::uint64_t n) { + return base_sequence_type::bitsize(params, universe, n) + 
conf.fix_cost; + }; + + const size_t superblock_bound = conf.eps3 != 0 ? size_t(conf.fix_cost / conf.eps3) : n; + + std::deque> superblock_partitions; + tbb::task_group tg; + + size_t superblock_pos = 0; + auto superblock_begin = begin; + auto superblock_base = *begin; + + while (superblock_pos < n) { + size_t superblock_size = std::min(superblock_bound, n - superblock_pos); + // If the remainder is smaller than the bound (possibly + // empty), merge it to the current (now last) superblock. + if (n - (superblock_pos + superblock_size) < superblock_bound) { + superblock_size = n - superblock_pos; + } + auto superblock_last = std::next(superblock_begin, superblock_size - 1); + auto superblock_end = std::next(superblock_last); + + // If this is the last superblock, its universe is the + // list universe. + size_t superblock_universe = + superblock_pos + superblock_size == n ? universe : *superblock_last + 1; + + superblock_partitions.emplace_back(); + auto& superblock_partition = superblock_partitions.back(); + + tg.run([=, &cost_fun, &conf, &superblock_partition] { + optimal_partition opt(superblock_begin, + superblock_base, + superblock_universe, + superblock_size, + cost_fun, + conf.eps1, + conf.eps2); + + superblock_partition.reserve(opt.partition.size()); + for (auto& endpoint : opt.partition) { + superblock_partition.push_back(superblock_pos + endpoint); + } + }); + + superblock_pos += superblock_size; + superblock_begin = superblock_end; + superblock_base = superblock_universe; + } + tg.wait(); + + for (const auto& superblock_partition : superblock_partitions) { + partition.insert( + partition.end(), superblock_partition.begin(), superblock_partition.end()); + } + + return partition; + } +}; +} // namespace pisa::v1 diff --git a/include/pisa/v1/sequence/positive_sequence.hpp b/include/pisa/v1/sequence/positive_sequence.hpp new file mode 100644 index 000000000..c2cc3ddd2 --- /dev/null +++ b/include/pisa/v1/sequence/positive_sequence.hpp @@ -0,0 +1,311 @@ 
+#pragma once + +#include + +#include + +#include "global_parameters.hpp" +#include "util/util.hpp" +#include "v1/bit_vector.hpp" +#include "v1/sequence/indexed_sequence.hpp" + +namespace pisa::v1 { + +struct StrictEliasFano { + + static PISA_FLATTEN_FUNC uint64_t bitsize(global_parameters const& params, + uint64_t universe, + uint64_t n) + { + assert(universe >= n); + return CompactEliasFano::bitsize(params, universe - n + 1, n); + } + + template + static void write(bit_vector_builder& bvb, + Iterator begin, + uint64_t universe, + uint64_t n, + global_parameters const& params) + { + uint64_t new_universe = universe - n + 1; + typedef typename std::iterator_traits::value_type value_type; + auto new_begin = make_function_iterator( + std::make_pair(value_type(0), begin), + [](std::pair& state) { + ++state.first; + ++state.second; + }, + [](std::pair const& state) { + return *state.second - state.first; + }); + CompactEliasFano::write(bvb, new_begin, new_universe, n, params); + } + + class enumerator { + public: + typedef std::pair value_type; // (position, value) + + enumerator() {} + + enumerator(BitVector const& bv, + uint64_t offset, + uint64_t universe, + uint64_t n, + global_parameters const& params) + : m_ef_enum(bv, offset, universe - n + 1, n, params) + { + } + + value_type move(uint64_t position) + { + auto val = m_ef_enum.move(position); + return value_type(val.first, val.second + val.first); + } + + value_type next() + { + auto val = m_ef_enum.next(); + return value_type(val.first, val.second + val.first); + } + + uint64_t size() const { return m_ef_enum.size(); } + + uint64_t prev_value() const + { + if (m_ef_enum.position()) { + return m_ef_enum.prev_value() + m_ef_enum.position() - 1; + } else { + return 0; + } + } + + private: + CompactEliasFano::enumerator m_ef_enum; + }; +}; + +struct StrictSequence { + + enum index_type { + elias_fano = 0, + ranked_bitvector = 1, + all_ones = 2, + + index_types = 3 + }; + + static const uint64_t type_bits = 1; // 
all_ones is implicit + + static global_parameters strict_params(global_parameters params) + { + // we do not need to index the zeros + params.ef_log_sampling0 = 63; + params.rb_log_rank1_sampling = 63; + return params; + } + + static PISA_FLATTEN_FUNC uint64_t bitsize(global_parameters const& params, + uint64_t universe, + uint64_t n) + { + uint64_t best_cost = AllOnesSequence::bitsize(params, universe, n); + auto sparams = strict_params(params); + + uint64_t ef_cost = StrictEliasFano::bitsize(sparams, universe, n) + type_bits; + if (ef_cost < best_cost) { + best_cost = ef_cost; + } + + uint64_t rb_cost = CompactRankedBitvector::bitsize(sparams, universe, n) + type_bits; + if (rb_cost < best_cost) { + best_cost = rb_cost; + } + + return best_cost; + } + + template + static void write(bit_vector_builder& bvb, + Iterator begin, + uint64_t universe, + uint64_t n, + global_parameters const& params) + { + auto sparams = strict_params(params); + uint64_t best_cost = AllOnesSequence::bitsize(params, universe, n); + int best_type = all_ones; + + if (best_cost) { + uint64_t ef_cost = StrictEliasFano::bitsize(sparams, universe, n) + type_bits; + if (ef_cost < best_cost) { + best_cost = ef_cost; + best_type = elias_fano; + } + + uint64_t rb_cost = CompactRankedBitvector::bitsize(sparams, universe, n) + type_bits; + if (rb_cost < best_cost) { + best_cost = rb_cost; + best_type = ranked_bitvector; + } + + bvb.append_bits(best_type, type_bits); + } + + switch (best_type) { + case elias_fano: + StrictEliasFano::write(bvb, begin, universe, n, sparams); + break; + case ranked_bitvector: + CompactRankedBitvector::write(bvb, begin, universe, n, sparams); + break; + case all_ones: + AllOnesSequence::write(bvb, begin, universe, n, sparams); + break; + default: + assert(false); + } + } + + class enumerator { + public: + typedef std::pair value_type; // (position, value) + + enumerator() {} + + enumerator(BitVector const& bv, + uint64_t offset, + uint64_t universe, + uint64_t n, + 
global_parameters const& params) + { + + auto sparams = strict_params(params); + + if (AllOnesSequence::bitsize(params, universe, n) == 0) { + m_type = all_ones; + } else { + m_type = index_type(bv.get_word56(offset) & ((uint64_t(1) << type_bits) - 1)); + } + + switch (m_type) { + case elias_fano: + m_enumerator = + StrictEliasFano::enumerator(bv, offset + type_bits, universe, n, sparams); + break; + case ranked_bitvector: + m_enumerator = CompactRankedBitvector::enumerator( + bv, offset + type_bits, universe, n, sparams); + break; + case all_ones: + m_enumerator = + AllOnesSequence::enumerator(bv, offset + type_bits, universe, n, sparams); + break; + default: + throw std::invalid_argument("Unsupported type"); + } + } + + value_type move(uint64_t position) + { + return boost::apply_visitor([&position](auto&& e) { return e.move(position); }, + m_enumerator); + } + + value_type next() + { + return boost::apply_visitor([](auto&& e) { return e.next(); }, m_enumerator); + } + + uint64_t size() const + { + return boost::apply_visitor([](auto&& e) { return e.size(); }, m_enumerator); + } + + uint64_t prev_value() const + { + return boost::apply_visitor([](auto&& e) { return e.prev_value(); }, m_enumerator); + } + + private: + index_type m_type; + boost::variant + m_enumerator; + }; +}; + +template +struct PositiveSequence { + + typedef BaseSequence base_sequence_type; + typedef typename base_sequence_type::enumerator base_sequence_enumerator; + + template + static void write(bit_vector_builder& bvb, + Iterator begin, + uint64_t universe, + uint64_t n, + global_parameters const& params) + { + assert(n > 0); + auto cumulative_begin = make_function_iterator( + std::make_pair(uint64_t(0), begin), + [](std::pair& state) { state.first += *state.second++; }, + [](std::pair const& state) { return state.first + *state.second; }); + base_sequence_type::write(bvb, cumulative_begin, universe, n, params); + } + + class enumerator { + public: + typedef std::pair value_type; // 
(position, value) + + enumerator() = delete; + + enumerator(BitVector const& bv, + uint64_t offset, + uint64_t universe, + uint64_t n, + global_parameters const& params) + : m_base_enum(bv, offset, universe, n, params), + m_position(m_base_enum.size()), + m_universe(universe) + { + } + + value_type next() { return move(m_position + 1); } + auto size() const { return m_base_enum.size(); } + auto universe() const { return m_universe; } + + value_type move(uint64_t position) + { + // we cache m_position and m_cur to avoid the call overhead in + // the most common cases + uint64_t prev = m_cur; + if (position != m_position + 1) { + if (PISA_UNLIKELY(position == 0)) { + // we need to special-case position 0 + m_cur = m_base_enum.move(0).second; + m_position = 0; + return value_type(m_position, m_cur); + } + prev = m_base_enum.move(position - 1).second; + } + + m_cur = m_base_enum.next().second; + m_position = position; + return value_type(position, m_cur - prev); + } + + base_sequence_enumerator const& base() const { return m_base_enum; } + + private: + base_sequence_enumerator m_base_enum; + uint64_t m_position; + uint64_t m_cur{}; + uint64_t m_universe{0}; + }; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/source.hpp b/include/pisa/v1/source.hpp new file mode 100644 index 000000000..1335ae6a2 --- /dev/null +++ b/include/pisa/v1/source.hpp @@ -0,0 +1,29 @@ +#pragma once + +#include +#include + +#include +#include + +#include "v1/types.hpp" + +namespace pisa::v1 { + +struct VectorSource { + std::vector> bytes{}; + std::vector> offsets{}; + std::vector> sizes{}; +}; + +struct MMapSource { + MMapSource() = default; + MMapSource(MMapSource &&) = default; + MMapSource(MMapSource const &) = default; + MMapSource &operator=(MMapSource &&) = default; + MMapSource &operator=(MMapSource const &) = default; + ~MMapSource() = default; + std::vector> file_sources{}; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/taat_or.hpp b/include/pisa/v1/taat_or.hpp 
new file mode 100644 index 000000000..d1e7e1f46 --- /dev/null +++ b/include/pisa/v1/taat_or.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include + +#include "topk_queue.hpp" +#include "v1/cursor/for_each.hpp" +#include "v1/query.hpp" + +namespace pisa::v1 { + +template +auto taat_or(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) +{ + std::vector accumulator(index.num_documents(), 0.0F); + for (auto term : query.get_term_ids()) { + v1::for_each(index.scored_cursor(term, scorer), + [&accumulator](auto&& cursor) { accumulator[*cursor] += cursor.payload(); }); + } + for (auto document = 0; document < accumulator.size(); document += 1) { + topk.insert(accumulator[document], document); + } + return topk; +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/types.hpp b/include/pisa/v1/types.hpp new file mode 100644 index 000000000..7e81b87b5 --- /dev/null +++ b/include/pisa/v1/types.hpp @@ -0,0 +1,203 @@ +#pragma once + +#include +#include +#include + +#include + +#include "binary_freq_collection.hpp" + +#define Unreachable() std::abort(); + +namespace pisa::v1 { + +using TermId = std::uint32_t; +using DocId = std::uint32_t; +using Frequency = std::uint32_t; +using Score = float; +using Result = std::pair; +using ByteOStream = std::basic_ostream; + +enum EncodingId { + Raw = 0xDA43, + BlockDelta = 0xEF00, + Block = 0xFF00, + BitSequence = 0xDF00, + SimdBP = 0x0001, + Varbyte = 0x0002, + PEF = 0x0003, + PositiveSeq = 0x0004 +}; + +template +struct overloaded : Ts... 
{ + using Ts::operator()...; +}; + +template +overloaded(Ts...)->overloaded; + +struct BaseIndex; + +template +struct Reader { + using Value = std::decay_t())>; + + template + explicit constexpr Reader(R reader) : m_internal_reader(std::make_unique>(reader)) + { + } + Reader() = default; + Reader(Reader const& other) : m_internal_reader(other.m_internal_reader->clone()) {} + Reader(Reader&& other) noexcept = default; + Reader& operator=(Reader const& other) = delete; + Reader& operator=(Reader&& other) noexcept = default; + ~Reader() = default; + + void init(BaseIndex const& index) { m_internal_reader->init(index); } + [[nodiscard]] auto read(gsl::span bytes) const -> Cursor + { + return m_internal_reader->read(bytes); + } + [[nodiscard]] auto encoding() const -> std::uint32_t { return m_internal_reader->encoding(); } + + struct ReaderInterface { + ReaderInterface() = default; + ReaderInterface(ReaderInterface const&) = default; + ReaderInterface(ReaderInterface&&) noexcept = default; + ReaderInterface& operator=(ReaderInterface const&) = default; + ReaderInterface& operator=(ReaderInterface&&) noexcept = default; + virtual ~ReaderInterface() = default; + virtual void init(BaseIndex const& index) = 0; + [[nodiscard]] virtual auto read(gsl::span bytes) const -> Cursor = 0; + [[nodiscard]] virtual auto encoding() const -> std::uint32_t = 0; + [[nodiscard]] virtual auto clone() const -> std::unique_ptr = 0; + }; + + template + struct ReaderImpl : ReaderInterface { + explicit ReaderImpl(R reader) : m_reader(std::move(reader)) {} + ReaderImpl() = default; + ReaderImpl(ReaderImpl const&) = default; + ReaderImpl(ReaderImpl&&) noexcept = default; + ReaderImpl& operator=(ReaderImpl const&) = default; + ReaderImpl& operator=(ReaderImpl&&) noexcept = default; + ~ReaderImpl() = default; + void init(BaseIndex const& index) override { m_reader.init(index); } + [[nodiscard]] auto read(gsl::span bytes) const -> Cursor override + { + return m_reader.read(bytes); + } + [[nodiscard]] 
auto encoding() const -> std::uint32_t override { return R::encoding(); } + [[nodiscard]] auto clone() const -> std::unique_ptr override + { + auto copy = *this; + return std::make_unique>(std::move(copy)); + } + + private: + R m_reader; + }; + + private: + std::unique_ptr m_internal_reader; +}; + +template +struct Writer { + using Value = T; + + template + explicit constexpr Writer(W writer) : m_internal_writer(std::make_unique>(writer)) + { + } + Writer() = default; + Writer(Writer const& other) : m_internal_writer(other.m_internal_writer->clone()) {} + Writer(Writer&& other) noexcept = default; + Writer& operator=(Writer const& other) = delete; + Writer& operator=(Writer&& other) noexcept = default; + ~Writer() = default; + + void init(pisa::binary_freq_collection const& collection) + { + m_internal_writer->init(collection); + } + void push(T const& posting) { m_internal_writer->push(posting); } + void push(T&& posting) { m_internal_writer->push(posting); } + auto write(ByteOStream& os) const -> std::size_t { return m_internal_writer->write(os); } + auto write(std::ostream& os) const -> std::size_t { return m_internal_writer->write(os); } + [[nodiscard]] auto encoding() const -> std::uint32_t { return m_internal_writer->encoding(); } + void reset() { return m_internal_writer->reset(); } + + struct WriterInterface { + WriterInterface() = default; + WriterInterface(WriterInterface const&) = default; + WriterInterface(WriterInterface&&) noexcept = default; + WriterInterface& operator=(WriterInterface const&) = default; + WriterInterface& operator=(WriterInterface&&) noexcept = default; + virtual ~WriterInterface() = default; + virtual void init(pisa::binary_freq_collection const& collection) = 0; + virtual void push(T const& posting) = 0; + virtual void push(T&& posting) = 0; + virtual auto write(ByteOStream& os) const -> std::size_t = 0; + virtual auto write(std::ostream& os) const -> std::size_t = 0; + virtual void reset() = 0; + [[nodiscard]] virtual auto 
encoding() const -> std::uint32_t = 0; + [[nodiscard]] virtual auto clone() const -> std::unique_ptr = 0; + }; + + template + struct WriterImpl : WriterInterface { + explicit WriterImpl(W writer) : m_writer(std::move(writer)) {} + WriterImpl() = default; + WriterImpl(WriterImpl const&) = default; + WriterImpl(WriterImpl&&) noexcept = default; + WriterImpl& operator=(WriterImpl const&) = default; + WriterImpl& operator=(WriterImpl&&) noexcept = default; + ~WriterImpl() = default; + void init(pisa::binary_freq_collection const& collection) override + { + m_writer.init(collection); + } + void push(T const& posting) override { m_writer.push(posting); } + void push(T&& posting) override { m_writer.push(posting); } + auto write(ByteOStream& os) const -> std::size_t override { return m_writer.write(os); } + auto write(std::ostream& os) const -> std::size_t override { return m_writer.write(os); } + void reset() override { return m_writer.reset(); } + [[nodiscard]] auto encoding() const -> std::uint32_t override { return W::encoding(); } + [[nodiscard]] auto clone() const -> std::unique_ptr override + { + auto copy = *this; + return std::make_unique>(std::move(copy)); + } + + private: + W m_writer; + }; + + private: + std::unique_ptr m_internal_writer; +}; + +template +[[nodiscard]] inline auto make_writer(W&& writer) +{ + return Writer(std::forward(writer)); +} + +template +[[nodiscard]] inline auto make_writer() +{ + return Writer(W{}); +} + +/// Indicates that payloads should be treated as scores. +/// To be used with pre-computed scores, be it floats or quantized ints. 
+struct VoidScorer { +}; + +template +struct encoding_traits; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/unaligned_span.hpp b/include/pisa/v1/unaligned_span.hpp new file mode 100644 index 000000000..2f57b3c97 --- /dev/null +++ b/include/pisa/v1/unaligned_span.hpp @@ -0,0 +1,165 @@ +#pragma once + +#include +#include + +#include +#include + +#include "v1/bit_cast.hpp" + +namespace pisa::v1 { + +template +struct UnalignedSpan; + +template +struct UnalignedSpanIterator { + UnalignedSpanIterator(std::uint32_t index, UnalignedSpan const& span) + : m_index(index), m_span(span) + { + } + UnalignedSpanIterator(UnalignedSpanIterator const&) = default; + UnalignedSpanIterator(UnalignedSpanIterator&&) noexcept = default; + UnalignedSpanIterator& operator=(UnalignedSpanIterator const&) = default; + UnalignedSpanIterator& operator=(UnalignedSpanIterator&&) noexcept = default; + ~UnalignedSpanIterator() = default; + [[nodiscard]] auto operator==(UnalignedSpanIterator const& other) const + { + return m_span.bytes().data() == other.m_span.bytes().data() && m_index == other.m_index; + } + [[nodiscard]] auto operator!=(UnalignedSpanIterator const& other) const + { + return m_index != other.m_index || m_span.bytes().data() != other.m_span.bytes().data(); + } + [[nodiscard]] auto operator*() const { return m_span[m_index]; } + auto operator++() -> UnalignedSpanIterator& + { + m_index++; + return *this; + } + auto operator++(int) -> UnalignedSpanIterator + { + auto copy = *this; + m_index++; + return copy; + } + [[nodiscard]] auto operator+=(std::uint32_t n) -> UnalignedSpanIterator& + { + m_index += n; + return *this; + } + [[nodiscard]] auto operator+(std::uint32_t n) const -> UnalignedSpanIterator + { + return UnalignedSpanIterator(m_index + n, m_span); + } + auto operator--() -> UnalignedSpanIterator& + { + m_index--; + return *this; + } + auto operator--(int) -> UnalignedSpanIterator + { + auto copy = *this; + m_index--; + return copy; + } + [[nodiscard]] auto 
operator-=(std::uint32_t n) -> UnalignedSpanIterator& + { + m_index -= n; + return *this; + } + [[nodiscard]] auto operator-(std::uint32_t n) const -> UnalignedSpanIterator + { + return UnalignedSpanIterator(m_index - n, m_span); + } + [[nodiscard]] auto operator-(UnalignedSpanIterator const& other) const -> std::int32_t + { + return static_cast(m_index) - static_cast(other.m_index); + } + [[nodiscard]] auto operator<(UnalignedSpanIterator const& other) const -> bool + { + return m_index < other.m_index; + } + [[nodiscard]] auto operator<=(UnalignedSpanIterator const& other) const -> bool + { + return m_index <= other.m_index; + } + [[nodiscard]] auto operator>(UnalignedSpanIterator const& other) const -> bool + { + return m_index > other.m_index; + } + [[nodiscard]] auto operator>=(UnalignedSpanIterator const& other) const -> bool + { + return m_index >= other.m_index; + } + + private: + std::uint32_t m_index; + UnalignedSpan const& m_span; +}; + +template +struct UnalignedSpan { + static_assert(std::is_trivially_copyable_v); + using value_type = T; + + constexpr UnalignedSpan() = default; + explicit constexpr UnalignedSpan(gsl::span bytes) : m_bytes(bytes) + { + if (m_bytes.size() % sizeof(value_type) != 0) { + throw std::logic_error("Number of bytes must be a multiplier of type size"); + } + } + constexpr UnalignedSpan(UnalignedSpan const&) = default; + constexpr UnalignedSpan(UnalignedSpan&&) noexcept = default; + constexpr UnalignedSpan& operator=(UnalignedSpan const&) = default; + constexpr UnalignedSpan& operator=(UnalignedSpan&&) noexcept = default; + ~UnalignedSpan() = default; + + using iterator = UnalignedSpanIterator; + + [[nodiscard]] auto operator[](std::uint32_t index) const -> value_type + { + return bit_cast( + m_bytes.subspan(index * sizeof(value_type), sizeof(value_type))); + } + + [[nodiscard]] auto front() const -> value_type + { + return bit_cast(m_bytes.subspan(0, sizeof(value_type))); + } + + [[nodiscard]] auto back() const -> value_type + { 
+ return bit_cast( + m_bytes.subspan(m_bytes.size() - sizeof(value_type), sizeof(value_type))); + } + + [[nodiscard]] auto begin() const -> iterator { return iterator(0, *this); } + [[nodiscard]] auto end() const -> iterator { return iterator(size(), *this); } + + [[nodiscard]] auto size() const -> std::size_t { return m_bytes.size() / sizeof(value_type); } + [[nodiscard]] auto byte_size() const -> std::size_t { return m_bytes.size(); } + [[nodiscard]] auto bytes() const -> gsl::span { return m_bytes; } + [[nodiscard]] auto empty() const -> bool { return m_bytes.empty(); } + + private: + gsl::span m_bytes{}; +}; + +} // namespace pisa::v1 + +namespace std { + +template +struct iterator_traits<::pisa::v1::UnalignedSpanIterator> { + using size_type = std::uint32_t; + using difference_type = std::make_signed_t; + using value_type = T; + using pointer = T const*; + using reference = T const&; + using iterator_category = std::random_access_iterator_tag; +}; + +} // namespace std diff --git a/include/pisa/v1/unigram_union_lookup.hpp b/include/pisa/v1/unigram_union_lookup.hpp new file mode 100644 index 000000000..98cdfcada --- /dev/null +++ b/include/pisa/v1/unigram_union_lookup.hpp @@ -0,0 +1,101 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#include "v1/algorithm.hpp" +#include "v1/cursor/labeled_cursor.hpp" +#include "v1/cursor/reference.hpp" +#include "v1/cursor/transform.hpp" +#include "v1/cursor_accumulator.hpp" +#include "v1/inspect_query.hpp" +#include "v1/query.hpp" +#include "v1/runtime_assert.hpp" +#include "v1/union_lookup_join.hpp" + +namespace pisa::v1 { + +/// Processes documents with the Union-Lookup method. +/// This is an optimized version that works **only on single-term posting lists**. +/// It will throw an exception if bigram selections are passed to it. 
+template +auto unigram_union_lookup(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + [[maybe_unused]] Inspect* inspect = nullptr) +{ + using cursor_type = decltype(index.max_scored_cursor(0, scorer)); + using payload_type = decltype(std::declval().payload()); + + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { + return topk; + } + + auto const& selections = query.get_selections(); + runtime_assert(selections.bigrams.empty()).or_throw("This algorithm only supports unigrams"); + + topk.set_threshold(query.get_threshold()); + + auto non_essential_terms = + ranges::views::set_difference(term_ids, selections.unigrams) | ranges::to_vector; + + auto essential_cursors = index.max_scored_cursors(selections.unigrams, scorer); + auto lookup_cursors = index.max_scored_cursors(non_essential_terms, scorer); + ranges::sort(lookup_cursors, [](auto&& l, auto&& r) { return l.max_score() > r.max_score(); }); + + if constexpr (not std::is_void_v) { + inspect->essential(essential_cursors.size()); + } + + auto joined = join_union_lookup( + std::move(essential_cursors), + std::move(lookup_cursors), + payload_type{}, + accumulators::Add{}, + [&](auto score) { return topk.would_enter(score); }, + inspect); + v1::for_each(joined, [&](auto&& cursor) { + if constexpr (not std::is_void_v) { + if (topk.insert(cursor.payload(), cursor.value())) { + inspect->insert(); + } + } else { + topk.insert(cursor.payload(), cursor.value()); + } + }); + return topk; +} + +template +struct InspectUnigramUnionLookup : Inspect { + + InspectUnigramUnionLookup(Index const& index, Scorer const& scorer) + : Inspect(index, scorer) + { + } + + void run(Query const& query, Index const& index, Scorer const& scorer, topk_queue topk) override + { + unigram_union_lookup(query, index, std::move(topk), scorer, this); + } +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/union_lookup.hpp b/include/pisa/v1/union_lookup.hpp new file mode 100644 index 
000000000..147248619 --- /dev/null +++ b/include/pisa/v1/union_lookup.hpp @@ -0,0 +1,783 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#include "v1/algorithm.hpp" +#include "v1/cursor/labeled_cursor.hpp" +#include "v1/cursor/lookup_transform.hpp" +#include "v1/cursor/reference.hpp" +#include "v1/cursor/transform.hpp" +#include "v1/cursor_accumulator.hpp" +#include "v1/maxscore_union_lookup.hpp" +#include "v1/query.hpp" +#include "v1/runtime_assert.hpp" +#include "v1/unigram_union_lookup.hpp" +#include "v1/union_lookup_join.hpp" + +namespace pisa::v1 { + +template +auto filter_bigram_lookup_cursors( + Index const& index, Scorer&& scorer, LookupCursors&& lookup_cursors, TermId left, TermId right) +{ + return ranges::views::filter( + lookup_cursors, + [&](auto&& cursor) { return cursor.label() != left && cursor.label() != right; }) + //| ranges::views::transform([](auto&& cursor) { return ref(cursor); }) + | ranges::views::transform( + [&](auto&& cursor) { return index.max_scored_cursor(cursor.label(), scorer); }) + | ranges::to_vector; +} + +/// This algorithm... 
+template +auto lookup_union(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + InspectUnigram* inspect_unigram = nullptr, + InspectBigram* inspect_bigram = nullptr) +{ + using bigram_cursor_type = std::decay_t; + + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { + return topk; + } + + auto threshold = query.get_threshold(); + topk.set_threshold(threshold); + auto is_above_threshold = [&](auto score) { return topk.would_enter(score); }; + + auto const& selections = query.get_selections(); + auto& essential_unigrams = selections.unigrams; + auto& essential_bigrams = selections.bigrams; + + if constexpr (not std::is_void_v) { + inspect_unigram->essential(essential_unigrams.size()); + } + if constexpr (not std::is_void_v) { + inspect_bigram->essential(essential_bigrams.size()); + } + + auto non_essential_terms = + ranges::views::set_difference(term_ids, essential_unigrams) | ranges::to_vector; + + auto lookup_cursors = index.cursors(non_essential_terms, [&](auto&& index, auto term) { + return label(index.max_scored_cursor(term, scorer), term); + }); + ranges::sort(lookup_cursors, std::greater{}, func::max_score{}); + auto unigram_cursor = + join_union_lookup(index.max_scored_cursors(gsl::make_span(essential_unigrams), scorer), + gsl::make_span(lookup_cursors), + 0.0F, + accumulators::Add{}, + is_above_threshold, + inspect_unigram); + + using lookup_transform_type = + LookupTransform; + using transform_payload_cursor_type = + TransformPayloadCursor; + + std::vector bigram_cursors; + + for (auto [left, right] : essential_bigrams) { + auto cursor = index.scored_bigram_cursor(left, right, scorer); + if (not cursor) { + throw std::runtime_error(fmt::format("Bigram not found: <{}, {}>", left, right)); + } + auto bigram_lookup_cursors = + filter_bigram_lookup_cursors(index, scorer, lookup_cursors, left, right); + auto lookup_cursors_upper_bound = + std::accumulate(bigram_lookup_cursors.begin(), + 
bigram_lookup_cursors.end(), + 0.0F, + [](auto acc, auto&& cursor) { return acc + cursor.max_score(); }); + bigram_cursors.emplace_back(std::move(*cursor.take()), + LookupTransform(std::move(bigram_lookup_cursors), + lookup_cursors_upper_bound, + is_above_threshold, + inspect_bigram)); + } + + auto accumulate = [&](float acc, auto& cursor, [[maybe_unused]] auto idx) { + return acc == 0 ? cursor.payload() : acc; + }; + auto bigram_cursor = union_merge( + std::move(bigram_cursors), 0.0F, [&](float acc, auto& cursor, [[maybe_unused]] auto idx) { + if constexpr (not std::is_void_v) { + inspect_bigram->posting(); + } + return acc == 0 ? cursor.payload() : acc; + }); + auto merged = v1::variadic_union_merge( + 0.0F, + std::make_tuple(std::move(unigram_cursor), std::move(bigram_cursor)), + std::make_tuple(accumulate, accumulate)); + + v1::for_each(merged, [&](auto&& cursor) { + if constexpr (not std::is_void_v) { + if (topk.insert(cursor.payload(), cursor.value())) { + inspect_unigram->insert(); + } + } else { + topk.insert(cursor.payload(), cursor.value()); + } + }); + return topk; +} + +template +auto accumulate_cursor_to_heap(Cursor&& cursor, + std::size_t k, + float threshold = 0.0, + InspectInserts* inspect_inserts = nullptr, + InspectPostings* inspect_postings = nullptr) +{ + topk_queue heap(k); + heap.set_threshold(threshold); + v1::for_each(cursor, [&](auto&& cursor) { + if constexpr (not std::is_void_v) { + inspect_postings->posting(); + } + if constexpr (not std::is_void_v) { + if (heap.insert(cursor.payload(), cursor.value())) { + inspect_inserts->insert(); + } + } else { + heap.insert(cursor.payload(), cursor.value()); + } + }); + return heap; +} + +/// This algorithm... 
+template +auto lookup_union_eaat(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + InspectUnigram* inspect_unigram = nullptr, + InspectBigram* inspect_bigram = nullptr) +{ + using bigram_cursor_type = std::decay_t; + using lookup_cursor_type = std::decay_t; + + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { + return topk; + } + + auto threshold = query.get_threshold(); + topk.set_threshold(threshold); + auto is_above_threshold = [&](auto score) { return topk.would_enter(score); }; + + auto const& selections = query.get_selections(); + auto& essential_unigrams = selections.unigrams; + auto& essential_bigrams = selections.bigrams; + + if constexpr (not std::is_void_v) { + inspect_unigram->essential(essential_unigrams.size()); + } + if constexpr (not std::is_void_v) { + inspect_bigram->essential(essential_bigrams.size()); + } + + auto non_essential_terms = + ranges::views::set_difference(term_ids, essential_unigrams) | ranges::to_vector; + + auto unigram_cursor = [&]() { + auto lookup_cursors = index.max_scored_cursors(gsl::make_span(non_essential_terms), scorer); + ranges::sort(lookup_cursors, + [](auto&& lhs, auto&& rhs) { return lhs.max_score() > rhs.max_score(); }); + auto essential_cursors = + index.max_scored_cursors(gsl::make_span(essential_unigrams), scorer); + + return join_union_lookup(std::move(essential_cursors), + std::move(lookup_cursors), + 0.0F, + accumulators::Add{}, + is_above_threshold, + inspect_unigram); + }(); + + auto unigram_heap = + accumulate_cursor_to_heap(unigram_cursor, topk.size(), threshold, inspect_unigram); + + using lookup_transform_type = + LookupTransform; + using transform_payload_cursor_type = + TransformPayloadCursor; + + std::vector entries(unigram_heap.topk().begin(), + unigram_heap.topk().end()); + + for (auto [left, right] : essential_bigrams) { + auto cursor = index.scored_bigram_cursor(left, right, scorer); + if (not cursor) { + throw 
std::runtime_error(fmt::format("Bigram not found: <{}, {}>", left, right)); + } + std::vector essential_terms{left, right}; + auto lookup_terms = + ranges::views::set_difference(non_essential_terms, essential_terms) | ranges::to_vector; + + auto lookup_cursors = index.max_scored_cursors(lookup_terms, scorer); + ranges::sort(lookup_cursors, + [](auto&& lhs, auto&& rhs) { return lhs.max_score() > rhs.max_score(); }); + + auto lookup_cursors_upper_bound = std::accumulate( + lookup_cursors.begin(), lookup_cursors.end(), 0.0F, [](auto acc, auto&& cursor) { + return acc + cursor.max_score(); + }); + + auto heap = accumulate_cursor_to_heap( + transform_payload_cursor_type(std::move(*cursor.take()), + lookup_transform_type(std::move(lookup_cursors), + lookup_cursors_upper_bound, + is_above_threshold, + inspect_bigram)), + topk.size(), + threshold, + inspect_bigram, + inspect_bigram); + std::copy(heap.topk().begin(), heap.topk().end(), std::back_inserter(entries)); + } + std::sort(entries.begin(), entries.end(), [](auto&& lhs, auto&& rhs) { + if (lhs.second == rhs.second) { + return lhs.first > rhs.first; + } + return lhs.second < rhs.second; + }); + auto end = std::unique(entries.begin(), entries.end(), [](auto&& lhs, auto&& rhs) { + return lhs.second == rhs.second; + }); + entries.erase(end, entries.end()); + std::sort(entries.begin(), entries.end(), [](auto&& lhs, auto&& rhs) { + return lhs.first > rhs.first; + }); + if (entries.size() > topk.size()) { + entries.erase(std::next(entries.begin(), topk.size()), entries.end()); + } + + for (auto entry : entries) { + topk.insert(entry.first, entry.second); + } + + return topk; +} + +template +auto union_lookup(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + Inspect* inspect = nullptr) +{ + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { + return topk; + } + if (term_ids.size() > 8) { + throw std::invalid_argument( + "Generic version of union-Lookup supported only for 
queries of length <= 8"); + } + + auto threshold = query.get_threshold(); + auto const& selections = query.get_selections(); + + using bigram_cursor_type = std::decay_t; + + auto& essential_unigrams = selections.unigrams; + auto& essential_bigrams = selections.bigrams; + + auto non_essential_terms = + ranges::views::set_difference(term_ids, essential_unigrams) | ranges::to_vector; + + topk.set_threshold(threshold); + + std::array initial_payload{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + + if constexpr (not std::is_void_v) { + inspect->essential(essential_unigrams.size() + essential_bigrams.size()); + } + + std::vector essential_unigram_cursors; + std::transform(essential_unigrams.begin(), + essential_unigrams.end(), + std::back_inserter(essential_unigram_cursors), + [&](auto term) { return index.scored_cursor(term, scorer); }); + + std::vector unigram_query_positions(essential_unigrams.size()); + for (std::size_t unigram_position = 0; unigram_position < essential_unigrams.size(); + unigram_position += 1) { + unigram_query_positions[unigram_position] = + query.sorted_position(essential_unigrams[unigram_position]); + } + auto merged_unigrams = v1::union_merge( + essential_unigram_cursors, initial_payload, [&](auto& acc, auto& cursor, auto term_idx) { + if constexpr (not std::is_void_v) { + inspect->posting(); + } + acc[unigram_query_positions[term_idx]] = cursor.payload(); + return acc; + }); + + std::vector essential_bigram_cursors; + for (auto [left, right] : essential_bigrams) { + auto cursor = index.scored_bigram_cursor(left, right, scorer); + if (not cursor) { + throw std::runtime_error(fmt::format("Bigram not found: <{}, {}>", left, right)); + } + essential_bigram_cursors.push_back(cursor.take().value()); + } + + std::vector> bigram_query_positions( + essential_bigrams.size()); + for (std::size_t bigram_position = 0; bigram_position < essential_bigrams.size(); + bigram_position += 1) { + bigram_query_positions[bigram_position] = + 
std::make_pair(query.sorted_position(essential_bigrams[bigram_position].first), + query.sorted_position(essential_bigrams[bigram_position].second)); + } + auto merged_bigrams = v1::union_merge(std::move(essential_bigram_cursors), + initial_payload, + [&](auto& acc, auto& cursor, auto bigram_idx) { + if constexpr (not std::is_void_v) { + inspect->posting(); + } + auto payload = cursor.payload(); + auto query_positions = + bigram_query_positions[bigram_idx]; + acc[query_positions.first] = std::get<0>(payload); + acc[query_positions.second] = std::get<1>(payload); + return acc; + }); + + auto accumulate = [&](auto& acc, auto& cursor, auto /* union_idx */) { + auto payload = cursor.payload(); + for (auto idx = 0; idx < acc.size(); idx += 1) { + if (acc[idx] == 0) { + acc[idx] = payload[idx]; + } + } + return acc; + }; + auto merged = v1::variadic_union_merge( + initial_payload, + std::make_tuple(std::move(merged_unigrams), std::move(merged_bigrams)), + std::make_tuple(accumulate, accumulate)); + + auto lookup_cursors = [&]() { + std::vector> + lookup_cursors; + auto pos = term_ids.begin(); + for (auto non_essential_term : non_essential_terms) { + pos = std::find(pos, term_ids.end(), non_essential_term); + assert(pos != term_ids.end()); + auto idx = std::distance(term_ids.begin(), pos); + lookup_cursors.emplace_back(idx, index.max_scored_cursor(non_essential_term, scorer)); + } + return lookup_cursors; + }(); + std::sort(lookup_cursors.begin(), lookup_cursors.end(), [](auto&& lhs, auto&& rhs) { + return lhs.second.max_score() > rhs.second.max_score(); + }); + auto lookup_cursors_upper_bound = std::accumulate( + lookup_cursors.begin(), lookup_cursors.end(), 0.0F, [](auto acc, auto&& cursor) { + return acc + cursor.second.max_score(); + }); + + v1::for_each(merged, [&](auto& cursor) { + if constexpr (not std::is_void_v) { + inspect->document(); + } + auto docid = cursor.value(); + auto scores = cursor.payload(); + auto score = std::accumulate(scores.begin(), scores.end(), 
0.0F, std::plus{}); + auto upper_bound = score + lookup_cursors_upper_bound; + for (auto& [idx, lookup_cursor] : lookup_cursors) { + if (not topk.would_enter(upper_bound)) { + return; + } + if (scores[idx] == 0) { + lookup_cursor.advance_to_geq(docid); + if constexpr (not std::is_void_v) { + inspect->lookup(); + } + if (PISA_UNLIKELY(lookup_cursor.value() == docid)) { + auto partial_score = lookup_cursor.payload(); + score += partial_score; + upper_bound += partial_score; + } + } + upper_bound -= lookup_cursor.max_score(); + } + if constexpr (not std::is_void_v) { + if (topk.insert(score, docid)) { + inspect->insert(); + } + } else { + topk.insert(score, docid); + } + }); + return topk; +} + +inline auto precompute_next_lookup(std::size_t essential_count, + std::size_t non_essential_count, + std::vector> const& essential_bigrams) +{ + runtime_assert(essential_count + non_essential_count <= 8).or_throw("Must be shorter than 9"); + std::uint32_t term_count = essential_count + non_essential_count; + std::vector next_lookup((term_count + 1) * (1U << term_count), -1); + auto unnecessary = [&](auto p, auto state) { + if (((1U << p) & state) > 0) { + return true; + } + for (auto k : essential_bigrams[p]) { + if (((1U << k) & state) > 0) { + return true; + } + } + return false; + }; + for (auto term_idx = essential_count; term_idx < term_count; term_idx += 1) { + for (std::uint32_t state = 0; state < (1U << term_count); state += 1) { + auto p = term_idx; + while (p < term_count && unnecessary(p, state)) { + ++p; + } + if (p == term_count) { + next_lookup[(term_idx << term_count) + state] = -1; + } else { + next_lookup[(term_idx << term_count) + state] = p; + } + } + } + return next_lookup; +} + +template +auto union_lookup_plus(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + Inspect* inspect = nullptr) +{ + using bigram_cursor_type = + LabeledCursor, + std::pair>; + + auto term_ids = gsl::make_span(query.get_term_ids()); + std::size_t 
term_count = term_ids.size(); + if (term_ids.empty()) { + return topk; + } + runtime_assert(term_ids.size() <= 8) + .or_throw("Generic version of union-Lookup supported only for queries of length <= 8"); + topk.set_threshold(query.get_threshold()); + auto const& selections = query.get_selections(); + auto& essential_unigrams = selections.unigrams; + auto& essential_bigrams = selections.bigrams; + + auto non_essential_terms = + ranges::views::set_difference(term_ids, essential_unigrams) | ranges::to_vector; + + std::array initial_payload{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + + if constexpr (not std::is_void_v) { + inspect->essential(essential_unigrams.size() + essential_bigrams.size()); + } + + auto essential_unigram_cursors = + index.cursors(essential_unigrams, [&](auto&& index, auto term) { + return label(index.scored_cursor(term, scorer), term); + }); + + auto lookup_cursors = + index.cursors(gsl::make_span(non_essential_terms), [&](auto&& index, auto term) { + return label(index.max_scored_cursor(term, scorer), term); + }); + std::sort(lookup_cursors.begin(), lookup_cursors.end(), [](auto&& lhs, auto&& rhs) { + return lhs.max_score() > rhs.max_score(); + }); + + auto term_to_position = [&] { + std::unordered_map term_to_position; + std::uint32_t position = 0; + for (auto&& cursor : essential_unigram_cursors) { + term_to_position[cursor.label()] = position++; + } + for (auto&& cursor : lookup_cursors) { + term_to_position[cursor.label()] = position++; + } + return term_to_position; + }(); + + auto merged_unigrams = v1::union_merge( + essential_unigram_cursors, initial_payload, [&](auto& acc, auto& cursor, auto idx) { + if constexpr (not std::is_void_v) { + inspect->posting(); + } + acc[idx] = cursor.payload(); + return acc; + }); + + std::vector essential_bigram_cursors; + for (auto [left, right] : essential_bigrams) { + auto cursor = index.scored_bigram_cursor(left, right, scorer); + if (not cursor) { + throw std::runtime_error(fmt::format("Bigram not 
found: <{}, {}>", left, right)); + } + essential_bigram_cursors.push_back( + label(cursor.take().value(), + std::make_pair(term_to_position[left], term_to_position[right]))); + } + + auto merged_bigrams = v1::union_merge(std::move(essential_bigram_cursors), + initial_payload, + [&](auto& acc, auto& cursor, auto /* bigram_idx */) { + if constexpr (not std::is_void_v) { + inspect->posting(); + } + auto payload = cursor.payload(); + acc[cursor.label().first] = std::get<0>(payload); + acc[cursor.label().second] = std::get<1>(payload); + return acc; + }); + + auto accumulate = [&](auto& acc, auto& cursor, auto /* union_idx */) { + auto payload = cursor.payload(); + for (auto idx = 0; idx < acc.size(); idx += 1) { + if (acc[idx] == 0.0F) { + acc[idx] = payload[idx]; + } + } + return acc; + }; + auto merged = v1::variadic_union_merge( + initial_payload, + std::make_tuple(std::move(merged_unigrams), std::move(merged_bigrams)), + std::make_tuple(accumulate, accumulate)); + + auto lookup_cursors_upper_bound = std::accumulate( + lookup_cursors.begin(), lookup_cursors.end(), 0.0F, [](auto acc, auto&& cursor) { + return acc + cursor.max_score(); + }); + + auto next_lookup = + precompute_next_lookup(essential_unigrams.size(), lookup_cursors.size(), [&] { + std::vector> mapping(term_ids.size()); + for (auto&& cursor : essential_bigram_cursors) { + auto [left, right] = cursor.label(); + mapping[left].push_back(right); + mapping[right].push_back(left); + } + return mapping; + }()); + auto mus = [&] { + std::vector mus((term_count + 1) * (1U << term_count), 0.0); + for (auto term_idx = term_count; term_idx + 1 >= 1; term_idx -= 1) { + for (std::uint32_t j = (1U << term_count) - 1; j + 1 >= 1; j -= 1) { + auto state = (term_idx << term_count) + j; + auto nt = next_lookup[state]; + if (nt == -1) { + mus[state] = 0.0F; + } else { + auto a = lookup_cursors[nt - essential_unigrams.size()].max_score() + + mus[((nt + 1) << term_count) + (j | (1 << nt))]; + auto b = mus[((term_idx + 1) << 
term_count) + j]; + mus[state] = std::max(a, b); + } + } + } + return mus; + }(); + + auto const state_mask = (1U << term_count) - 1; + + v1::for_each(merged, [&](auto& cursor) { + if constexpr (not std::is_void_v) { + inspect->document(); + } + auto docid = cursor.value(); + auto scores = cursor.payload(); + + // auto score = std::accumulate(scores.begin(), scores.end(), 0.0F, std::plus{}); + float score = 0.0F; + std::uint32_t state = essential_unigrams.size() << term_count; + for (auto pos = 0U; pos < term_count; pos += 1) { + if (scores[pos] > 0) { + score += scores[pos]; + state |= 1U << pos; + } + } + + assert(state >= 0 && state < next_lookup.size()); + auto next_idx = next_lookup[state]; + while (next_idx >= 0 && topk.would_enter(score + mus[state])) { + auto lookup_idx = next_idx - essential_unigrams.size(); + assert(lookup_idx >= 0 && lookup_idx < lookup_cursors.size()); + auto&& lookup_cursor = lookup_cursors[lookup_idx]; + lookup_cursor.advance_to_geq(docid); + if constexpr (not std::is_void_v) { + inspect->lookup(); + } + if (lookup_cursor.value() == docid) { + score += lookup_cursor.payload(); + state |= (1U << next_idx); + } + state = (state & state_mask) + ((next_idx + 1) << term_count); + next_idx = next_lookup[state]; + } + if constexpr (not std::is_void_v) { + if (topk.insert(score, docid)) { + inspect->insert(); + } + } else { + topk.insert(score, docid); + } + }); + return topk; +} + +template +struct InspectUnionLookup : Inspect { + + InspectUnionLookup(Index const& index, Scorer const& scorer) + : Inspect(index, scorer) + { + } + + void run(Query const& query, Index const& index, Scorer const& scorer, topk_queue topk) override + { + if (query.selections()->bigrams.empty()) { + unigram_union_lookup(query, index, std::move(topk), scorer, this); + } else if (query.get_term_ids().size() > 8) { + maxscore_union_lookup(query, index, std::move(topk), scorer, this); + } else { + union_lookup(query, index, std::move(topk), scorer, this); + } + } +}; + 
+template +struct InspectUnionLookupPlus : Inspect { + + InspectUnionLookupPlus(Index const& index, Scorer const& scorer) + : Inspect(index, scorer) + { + } + + void run(Query const& query, Index const& index, Scorer const& scorer, topk_queue topk) override + { + if (query.selections()->bigrams.empty()) { + unigram_union_lookup(query, index, std::move(topk), scorer, this); + } else if (query.get_term_ids().size() > 8) { + maxscore_union_lookup(query, index, std::move(topk), scorer, this); + } else { + union_lookup_plus(query, index, std::move(topk), scorer, this); + } + } +}; + +using LookupUnionComponent = InspectMany; + +template +struct InspectLookupUnion : Inspect> { + + InspectLookupUnion(Index const& index, Scorer scorer) + : Inspect>(index, scorer) + { + } + + void run(Query const& query, Index const& index, Scorer const& scorer, topk_queue topk) override + { + if (query.selections()->bigrams.empty()) { + unigram_union_lookup(query, + index, + std::move(topk), + scorer, + InspectPartitioned::first()); + } else { + lookup_union(query, + index, + std::move(topk), + scorer, + InspectPartitioned::first(), + InspectPartitioned::second()); + } + } +}; + +template +struct InspectLookupUnionEaat : Inspect> { + + InspectLookupUnionEaat(Index const& index, Scorer scorer) + : Inspect>(index, scorer) + { + } + + void run(Query const& query, Index const& index, Scorer const& scorer, topk_queue topk) override + { + if (query.selections()->bigrams.empty()) { + unigram_union_lookup(query, + index, + std::move(topk), + scorer, + InspectPartitioned::first()); + } else { + lookup_union_eaat(query, + index, + std::move(topk), + scorer, + InspectPartitioned::first(), + InspectPartitioned::second()); + } + } +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/union_lookup_join.hpp b/include/pisa/v1/union_lookup_join.hpp new file mode 100644 index 000000000..416df35a2 --- /dev/null +++ b/include/pisa/v1/union_lookup_join.hpp @@ -0,0 +1,277 @@ +#pragma once + +#include 
+#include + +#include +#include +#include + +namespace pisa::v1 { + +namespace func { + + /// Calls `max_score()` method on any passed object. The default projection for + /// `maxscore_partition`. + struct max_score { + template + auto operator()(Cursor&& cursor) -> float + { + return cursor.max_score(); + } + }; + +} // namespace func + +/// Partitions a list of cursors into essential and non-essential parts as in MaxScore algorithm +/// first proposed by [Turtle and +/// Flood](https://www.sciencedirect.com/science/article/pii/030645739500020H). +/// +/// # Details +/// +/// This function takes a span of (max-score) cursors that participate in a query, and the current +/// threshold. By default, it retrieves the max scores from cursors by calling `max_score()` method. +/// However, a different can be used by passing a projection. For example, if instead of +/// partitioning actual posting list cursors, you want to partition a vector of pairs `(term, +/// max-score)`, then you may pass +/// `[](auto&& c) { return c.second; }` as the projection. +/// +/// # Complexity +/// +/// Note that this function **will** sort the cursors by their max scores to ensure correct +/// partitioning, and therefore it may not be suitable to update an existing partition. +template +auto maxscore_partition(gsl::span cursors, float threshold, P projection = func::max_score{}) + -> std::pair, gsl::span> +{ + ranges::sort(cursors.begin(), cursors.end(), std::less{}, projection); + float bound = 0; + auto mid = ranges::find_if_not(cursors, [&](auto&& cursor) { + bound += projection(cursor); + return bound < threshold; + }); + auto non_essential_count = std::distance(cursors.begin(), mid); + return std::make_pair(cursors.first(non_essential_count), cursors.subspan(non_essential_count)); +} + +/// This cursor operator takes a number of essential cursors (in an arbitrary order) +/// and a list of lookup cursors. 
The documents traversed will be in the DaaT order, +/// and the following documents will be skipped: +/// - documents that do not appear in any of the essential cursors, +/// - documents that at the moment of their traversal are irrelevant (see below). +/// +/// # Threshold +/// +/// This operator takes a callable object that returns `true` only if a given score +/// has a chance to be in the final result set. It is used to decide whether or not +/// to perform further lookups for the given document. The score passed to the function +/// is such that when it returns `false`, we know that it will return `false` for the +/// rest of the lookup cursors, and therefore we can skip that document. +/// Note that such document will never be returned by this cursor. Instead, we will +/// proceed to the next document to see if it can land in the final result set, and so on. +/// +/// # Accumulating Scores +/// +/// Another parameter taken by this operator is a callable that accumulates payloads +/// for one document ID. The function is very similar to what you would pass to +/// `std::accumulate`: it takes the accumulator (either by reference or value), +/// and a reference to the cursor. It must return an updated accumulator. +/// For example, a simple accumulator that simply sums all payloads for each document, +/// can be: `[](float score, auto&& cursor) { return score + cursor.payload(); }`. +/// Note that you can accumulate "heavier" objects by taking and returning a reference: +/// ``` +/// [](auto& acc, auto&& cursor) { +/// // Do something with acc +/// return acc; +/// } +/// ``` +/// Before the first call to the accumulating function, the accumulated payload will be +/// initialized to the value `init` passed in the constructor. This will also be the +/// type of the payload returned by this cursor. +/// +/// # Passing Cursors +/// +/// Both essential and lookup cursors are passed by value and moved into a member. 
+/// It is thus important to pass either a temporary, a view, or a moved object to the constructor. +/// It is recommended to pass the ownership through an rvalue, as the cursors will be consumed +/// either way. However, in rare cases when the cursors need to be read after use +/// (for example to get their size or max score) or if essential and lookup cursors are in one +/// container and you want to avoid moving them, you may pass a view such as `gsl::span`. +/// However, it is discouraged in general case due to potential lifetime issues and dangling +/// references. +template +struct UnionLookupJoin { + + using essential_cursor_type = typename EssentialCursors::value_type; + using lookup_cursor_type = typename LookupCursors::value_type; + + using payload_type = Payload; + using value_type = std::decay_t())>; + + using essential_iterator_category = + typename std::iterator_traits::iterator_category; + + static_assert(std::is_base_of(), + "cursors must be stored in a random access container"); + + UnionLookupJoin(EssentialCursors essential_cursors, + LookupCursors lookup_cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold, + Inspect* inspect = nullptr) + : m_essential_cursors(std::move(essential_cursors)), + m_lookup_cursors(std::move(lookup_cursors)), + m_init(std::move(init)), + m_accumulate(std::move(accumulate)), + m_above_threshold(std::move(above_threshold)), + m_inspect(inspect) + { + if (m_essential_cursors.empty()) { + if (m_lookup_cursors.empty()) { + m_sentinel = std::numeric_limits::max(); + } else { + m_sentinel = min_sentinel(m_lookup_cursors); + } + m_current_value = m_sentinel; + m_current_payload = m_init; + return; + } + m_lookup_cumulative_upper_bound = std::accumulate( + m_lookup_cursors.begin(), m_lookup_cursors.end(), 0.0F, [](auto acc, auto&& cursor) { + return acc + cursor.max_score(); + }); + m_next_docid = min_value(m_essential_cursors); + m_sentinel = min_sentinel(m_essential_cursors); + advance(); + } + + 
[[nodiscard]] constexpr auto operator*() const noexcept -> value_type + { + return m_current_value; + } + [[nodiscard]] constexpr auto value() const noexcept -> value_type { return m_current_value; } + [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& + { + return m_current_payload; + } + [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_sentinel; } + + constexpr void advance() + { + bool exit = false; + while (not exit) { + if (PISA_UNLIKELY(m_next_docid >= sentinel())) { + m_current_value = sentinel(); + m_current_payload = m_init; + return; + } + m_current_payload = m_init; + m_current_value = std::exchange(m_next_docid, sentinel()); + + if constexpr (not std::is_void_v) { + m_inspect->document(); + } + + for (auto&& cursor : m_essential_cursors) { + if (cursor.value() == m_current_value) { + if constexpr (not std::is_void_v) { + m_inspect->posting(); + } + m_current_payload = m_accumulate(m_current_payload, cursor); + cursor.advance(); + } + if (auto docid = cursor.value(); docid < m_next_docid) { + m_next_docid = docid; + } + } + + exit = true; + auto lookup_bound = m_lookup_cumulative_upper_bound; + for (auto&& cursor : m_lookup_cursors) { + //if (m_current_value == 2288) { + // std::cout << fmt::format( + // "[checking] doc: {}\tscore: {}\tbound: {} (is above = {})\tms = {}\n", + // m_current_value, + // m_current_payload, + // m_current_payload + lookup_bound, + // m_above_threshold(m_current_payload + lookup_bound), + // cursor.max_score()); + //} + if (not m_above_threshold(m_current_payload + lookup_bound)) { + exit = false; + break; + } + cursor.advance_to_geq(m_current_value); + //std::cout << fmt::format( + // "doc: {}\tbound: {}\n", m_current_value, m_current_payload + lookup_bound); + if constexpr (not std::is_void_v) { + m_inspect->lookup(); + } + if (cursor.value() == m_current_value) { + m_current_payload = m_accumulate(m_current_payload, cursor); + } + lookup_bound -= cursor.max_score(); + } 
+ } + m_position += 1; + } + + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t { return m_position; } + [[nodiscard]] constexpr auto empty() const noexcept -> bool + { + return m_current_value >= sentinel(); + } + + private: + EssentialCursors m_essential_cursors; + LookupCursors m_lookup_cursors; + payload_type m_init; + AccumulateFn m_accumulate; + ThresholdFn m_above_threshold; + + value_type m_current_value{}; + value_type m_sentinel{}; + payload_type m_current_payload{}; + std::uint32_t m_next_docid{}; + payload_type m_previous_threshold{}; + payload_type m_lookup_cumulative_upper_bound{}; + std::size_t m_position = 0; + + Inspect* m_inspect; +}; + +/// Convenience function to construct a `UnionLookupJoin` cursor operator. +/// See the struct documentation for more information. +template +auto join_union_lookup(EssentialCursors essential_cursors, + LookupCursors lookup_cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn threshold, + Inspect* inspect = nullptr) +{ + return UnionLookupJoin(std::move(essential_cursors), + std::move(lookup_cursors), + std::move(init), + std::move(accumulate), + std::move(threshold), + inspect); +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/vector_lexicon.hpp b/include/pisa/v1/vector_lexicon.hpp new file mode 100644 index 000000000..9df28e1a8 --- /dev/null +++ b/include/pisa/v1/vector_lexicon.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include + +#include +#include + +#include "binary_freq_collection.hpp" +#include "v1/types.hpp" + +namespace pisa::v1 { + +struct VectorLexicon { + explicit VectorLexicon(binary_freq_collection const &collection) + { + m_offsets.push_back(0); + for (auto const &postings : collection) { + m_offsets.push_back(postings.docs.size()); + } + } + + [[nodiscard]] auto fetch(TermId term, gsl::span bytes) + -> gsl::span + { + Expects(term + 1 < m_offsets.size()); + return bytes.subspan(m_offsets[term], m_offsets[term + 1]); + } + + private: + std::vector m_offsets{}; 
+}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/wand.hpp b/include/pisa/v1/wand.hpp new file mode 100644 index 000000000..bbee82b56 --- /dev/null +++ b/include/pisa/v1/wand.hpp @@ -0,0 +1,478 @@ +#pragma once + +#include +#include +#include + +#include +#include + +#include "util/compiler_attribute.hpp" +#include "v1/algorithm.hpp" +#include "v1/cursor/scoring_cursor.hpp" +#include "v1/cursor_accumulator.hpp" +#include "v1/query.hpp" + +namespace pisa::v1 { + +template +struct BlockMaxWandJoin; + +template +struct WandJoin { + using cursor_type = typename CursorContainer::value_type; + using payload_type = Payload; + using value_type = std::decay_t())>; + + friend BlockMaxWandJoin; + + using iterator_category = + typename std::iterator_traits::iterator_category; + static_assert(std::is_base_of(), + "cursors must be stored in a random access container"); + + constexpr WandJoin(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold) + : m_cursors(std::move(cursors)), + m_cursor_pointers(m_cursors.size()), + m_init(std::move(init)), + m_accumulate(std::move(accumulate)), + m_above_threshold(std::move(above_threshold)), + m_size(std::nullopt) + { + initialize(); + } + + constexpr WandJoin(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold, + Inspect* inspect) + : m_cursors(std::move(cursors)), + m_cursor_pointers(m_cursors.size()), + m_init(std::move(init)), + m_accumulate(std::move(accumulate)), + m_above_threshold(std::move(above_threshold)), + m_size(std::nullopt), + m_inspect(inspect) + { + initialize(); + } + + void initialize() + { + if (m_cursors.empty()) { + m_current_value = sentinel(); + m_current_payload = m_init; + } + std::transform(m_cursors.begin(), + m_cursors.end(), + m_cursor_pointers.begin(), + [](auto&& cursor) { return &cursor; }); + + std::sort(m_cursor_pointers.begin(), m_cursor_pointers.end(), [](auto&& lhs, auto&& rhs) { + return lhs->value() < 
rhs->value(); + }); + + m_sentinel = min_sentinel(m_cursors); + advance(); + } + + [[nodiscard]] constexpr auto operator*() const noexcept -> value_type + { + return m_current_value; + } + [[nodiscard]] constexpr auto value() const noexcept -> value_type { return m_current_value; } + [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& + { + return m_current_payload; + } + [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_sentinel; } + + PISA_ALWAYSINLINE void advance() + { + bool exit = false; + while (not exit) { + auto upper_bound = 0.0F; + std::size_t pivot; + bool found_pivot = false; + for (pivot = 0; pivot < m_cursor_pointers.size(); ++pivot) { + if (m_cursor_pointers[pivot]->empty()) { + break; + } + upper_bound += m_cursor_pointers[pivot]->max_score(); + if (m_above_threshold(upper_bound)) { + found_pivot = true; + break; + } + } + // auto pivot = find_pivot(); + // if (PISA_UNLIKELY(not pivot)) { + // m_current_value = sentinel(); + // exit = true; + // break; + //} + if (not found_pivot) { + m_current_value = sentinel(); + exit = true; + break; + } + + // auto pivot_docid = (*pivot)->value(); + auto pivot_docid = m_cursor_pointers[pivot]->value(); + if (pivot_docid == m_cursor_pointers.front()->value()) { + m_current_value = pivot_docid; + m_current_payload = m_init; + + for (auto* cursor : m_cursor_pointers) { + if (cursor->value() != pivot_docid) { + break; + } + m_current_payload = m_accumulate(m_current_payload, *cursor); + cursor->advance(); + } + + auto by_docid = [](auto&& lhs, auto&& rhs) { return lhs->value() < rhs->value(); }; + std::sort(m_cursor_pointers.begin(), m_cursor_pointers.end(), by_docid); + exit = true; + } else { + auto next_list = pivot; + for (; m_cursor_pointers[next_list]->value() == pivot_docid; --next_list) { + } + m_cursor_pointers[next_list]->advance_to_geq(pivot_docid); + // bubble_down(next_list); + for (size_t idx = next_list + 1; idx < m_cursor_pointers.size(); idx += 
1) { + if (m_cursor_pointers[idx]->value() < m_cursor_pointers[idx - 1]->value()) { + std::swap(m_cursor_pointers[idx], m_cursor_pointers[idx - 1]); + } else { + break; + } + } + } + } + } + + [[nodiscard]] constexpr auto empty() const noexcept -> bool + { + return m_current_value >= sentinel(); + } + + private: + PISA_ALWAYSINLINE void bubble_down(std::size_t list_idx) + { + for (size_t idx = list_idx + 1; idx < m_cursor_pointers.size(); idx += 1) { + if (m_cursor_pointers[idx]->value() < m_cursor_pointers[idx - 1]->value()) { + std::swap(m_cursor_pointers[idx], m_cursor_pointers[idx - 1]); + } else { + break; + } + } + } + + PISA_ALWAYSINLINE auto find_pivot() -> tl::optional + { + auto upper_bound = 0.0F; + std::size_t pivot; + for (pivot = 0; pivot < m_cursor_pointers.size(); ++pivot) { + if (m_cursor_pointers[pivot]->empty()) { + break; + } + upper_bound += m_cursor_pointers[pivot]->max_score(); + if (m_above_threshold(upper_bound)) { + return tl::make_optional(pivot); + } + } + return tl::nullopt; + // auto upper_bound = 0.0F; + // for (auto pivot = m_cursor_pointers.begin(); pivot != m_cursor_pointers.end(); ++pivot) { + // auto&& cursor = **pivot; + // if (cursor.empty()) { + // break; + // } + // upper_bound += cursor.max_score(); + // if (m_above_threshold(upper_bound)) { + // auto pivot_docid = (*pivot)->value(); + // while (std::next(pivot) != m_cursor_pointers.end()) { + // if ((*std::next(pivot))->value() != pivot_docid) { + // break; + // } + // pivot = std::next(pivot); + // } + // return pivot; + // } + //} + // return m_cursor_pointers.end(); + } + + CursorContainer m_cursors; + std::vector m_cursor_pointers; + payload_type m_init; + AccumulateFn m_accumulate; + ThresholdFn m_above_threshold; + std::optional m_size; + + value_type m_current_value{}; + value_type m_sentinel{}; + payload_type m_current_payload{}; + payload_type m_previous_threshold{}; + + Inspect* m_inspect; +}; + +template +struct BlockMaxWandJoin { + using cursor_type = typename 
CursorContainer::value_type; + using payload_type = Payload; + using value_type = std::decay_t())>; + + using iterator_category = + typename std::iterator_traits::iterator_category; + static_assert(std::is_base_of(), + "cursors must be stored in a random access container"); + + constexpr BlockMaxWandJoin(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold) + : m_wand_join(std::move(cursors), init, std::move(accumulate), std::move(above_threshold)) + { + } + + constexpr BlockMaxWandJoin(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold, + Inspect* inspect) + : WandJoin( + std::move(cursors), init, std::move(accumulate), std::move(above_threshold), inspect) + { + } + + [[nodiscard]] constexpr auto operator*() const noexcept -> value_type + { + return m_wand_join.value(); + } + [[nodiscard]] constexpr auto value() const noexcept -> value_type + { + return m_wand_join.value(); + } + [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& + { + return m_wand_join.payload(); + } + [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t + { + return m_wand_join.sentinel(); + } + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_wand_join.empty(); } + + constexpr void advance() { m_wand_join.advance(); } + // constexpr void advance() + //{ + // while (true) { + // auto pivot = m_wand_join.find_pivot(); + // if (pivot == m_wand_join.m_cursor_pointers.end()) { + // m_wand_join.m_current_value = m_wand_join.sentinel(); + // return; + // } + + // auto pivot_docid = (*pivot)->value(); + // // auto block_upper_bound = std::accumulate( + // // m_wand_join.m_cursor_pointers.begin(), + // // std::next(pivot), + // // 0.0F, + // // [&](auto acc, auto* cursor) { return acc + cursor->block_max_score(pivot_docid); + // // }); + // // if (not m_wand_join.m_above_threshold(block_upper_bound)) { + // // block_max_advance(pivot, pivot_docid); + 
// // continue; + // //} + // if (pivot_docid == m_wand_join.m_cursor_pointers.front()->value()) { + // m_wand_join.m_current_value = pivot_docid; + // m_wand_join.m_current_payload = m_wand_join.m_init; + + // [&]() { + // auto iter = m_wand_join.m_cursor_pointers.begin(); + // for (; iter != m_wand_join.m_cursor_pointers.end(); ++iter) { + // auto* cursor = *iter; + // if (cursor->value() != pivot_docid) { + // break; + // } + // m_wand_join.m_current_payload = + // m_wand_join.m_accumulate(m_wand_join.m_current_payload, *cursor); + // cursor->advance(); + // } + // return iter; + // }(); + // // for (auto* cursor : m_wand_join.m_cursor_pointers) { + // // if (cursor->value() != pivot_docid) { + // // break; + // // } + // // m_wand_join.m_current_payload = + // // m_wand_join.m_accumulate(m_wand_join.m_current_payload, *cursor); + // // block_upper_bound -= cursor->block_max_score() - cursor->payload(); + // // if (not m_wand_join.m_above_threshold(block_upper_bound)) { + // // break; + // // } + // //} + + // // for (auto* cursor : m_wand_join.m_cursor_pointers) { + // // if (cursor->value() != pivot_docid) { + // // break; + // // } + // // cursor->advance(); + // //} + + // auto by_docid = [](auto&& lhs, auto&& rhs) { return lhs->value() < rhs->value(); + // }; std::sort(m_wand_join.m_cursor_pointers.begin(), + // m_wand_join.m_cursor_pointers.end(), + // by_docid); + // return; + // } + + // auto next_list = std::distance(m_wand_join.m_cursor_pointers.begin(), pivot); + // for (; m_wand_join.m_cursor_pointers[next_list]->value() == pivot_docid; --next_list) + // { + // } + // m_wand_join.m_cursor_pointers[next_list]->advance_to_geq(pivot_docid); + // m_wand_join.bubble_down(next_list); + // } + //} + + private: + template + void block_max_advance(Iter pivot, DocId pivot_id) + { + auto next_list = std::max_element( + m_wand_join.m_cursor_pointers.begin(), std::next(pivot), [](auto* lhs, auto* rhs) { + return lhs->max_score() < rhs->max_score(); + }); + + auto 
next_docid = + (*std::min_element(m_wand_join.m_cursor_pointers.begin(), + std::next(pivot), + [](auto* lhs, auto* rhs) { + return lhs->block_max_docid() < rhs->block_max_docid(); + })) + ->value(); + next_docid += 1; + + if (auto iter = std::next(pivot); iter != m_wand_join.m_cursor_pointers.end()) { + if (auto docid = (*iter)->value(); docid < next_docid) { + next_docid = docid; + } + } + + if (next_docid <= pivot_id) { + next_docid = pivot_id + 1; + } + + (*next_list)->advance_to_geq(next_docid); + m_wand_join.bubble_down(std::distance(m_wand_join.m_cursor_pointers.begin(), next_list)); + } + + WandJoin m_wand_join; +}; + +template +auto join_wand(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn threshold) +{ + return WandJoin( + std::move(cursors), std::move(init), std::move(accumulate), std::move(threshold)); +} + +template +auto join_wand(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn threshold, + Inspect* inspect) +{ + return WandJoin( + std::move(cursors), std::move(init), std::move(accumulate), std::move(threshold), inspect); +} + +template +auto join_block_max_wand(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn threshold) +{ + return BlockMaxWandJoin( + std::move(cursors), std::move(init), std::move(accumulate), std::move(threshold)); +} + +template +auto join_block_max_wand(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn threshold, + Inspect* inspect) +{ + return BlockMaxWandJoin( + std::move(cursors), std::move(init), std::move(accumulate), std::move(threshold), inspect); +} + +template +auto wand(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) +{ + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { + return topk; + } + auto cursors = index.max_scored_cursors(gsl::make_span(term_ids), scorer); + if (query.threshold()) { + topk.set_threshold(*query.threshold()); + } + auto 
joined = join_wand(std::move(cursors), 0.0F, accumulators::Add{}, [&](auto score) { + return topk.would_enter(score); + }); + v1::for_each(joined, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); + return topk; +} + +template +auto bmw(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) +{ + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { + return topk; + } + auto cursors = index.block_max_scored_cursors(gsl::make_span(term_ids), scorer); + if (query.threshold()) { + topk.set_threshold(*query.threshold()); + } + auto joined = join_block_max_wand(std::move(cursors), 0.0F, accumulators::Add{}, [&](auto score) { + return topk.would_enter(score); + }); + v1::for_each(joined, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); + return topk; +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/zip_cursor.hpp b/include/pisa/v1/zip_cursor.hpp new file mode 100644 index 000000000..0aa4f98b4 --- /dev/null +++ b/include/pisa/v1/zip_cursor.hpp @@ -0,0 +1,63 @@ +#pragma once + +#include + +#include + +namespace pisa::v1 { + +template +struct ZipCursor { + using Value = std::tuple())...>; + + explicit constexpr ZipCursor(Cursors... cursors) : m_cursors(std::move(cursors)...) + { + static_assert(std::tuple_size_v> == 2, + "Zip of more than two lists is currently not supported"); + } + + [[nodiscard]] constexpr auto operator*() const -> Value { return value(); } + [[nodiscard]] constexpr auto value() const noexcept -> Value + { + // auto deref = [](auto... cursors) { return std::make_tuple(cursors.value()...); }; + // return std::apply(deref, m_cursors); + return std::make_tuple(std::get<0>(m_cursors).value(), std::get<1>(m_cursors).value()); + } + constexpr void advance() + { + // TODO: Why generic doesn't work? + // auto advance_all = [](auto... 
cursors) { (cursors.advance(), ...); }; + // std::apply(advance_all, m_cursors); + std::get<0>(m_cursors).advance(); + std::get<1>(m_cursors).advance(); + } + constexpr void advance_to_position(std::size_t pos) + { + // TODO: Why generic doesn't work? + // auto advance_all = [pos](auto... cursors) { (cursors.advance_to_position(pos), ...); }; + // std::apply(advance_all, m_cursors); + std::get<0>(m_cursors).advance_to_position(pos); + std::get<1>(m_cursors).advance_to_position(pos); + } + [[nodiscard]] constexpr auto empty() const noexcept -> bool + { + return std::get<0>(m_cursors).empty() || std::get<1>(m_cursors).empty(); + } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return std::get<0>(m_cursors).position(); + } + //[[nodiscard]] constexpr auto size() const -> std::size_t { return m_key_cursor.size(); } + //[[nodiscard]] constexpr auto sentinel() const -> Document { return m_key_cursor.sentinel(); } + + private: + std::tuple m_cursors; +}; + +template +auto zip(Cursors... 
cursors) +{ + return ZipCursor(cursors...); +} + +} // namespace pisa::v1 diff --git a/script/cw09b-bp.sh b/script/cw09b-bp.sh new file mode 100644 index 000000000..03c9793ab --- /dev/null +++ b/script/cw09b-bp.sh @@ -0,0 +1,21 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/home/amallia/cw09b/CW09B.inv.bp" +FWD="/home/amallia/cw09b/CW09B.fwd.bp" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-bp/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-bp" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +#PAIRS="/home/michal/real.aol.top50k.jl" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.test" +QUERY_LIMIT=1000 + +# 134384 149026 1648 128376 4 # UL +# 149026 149026 1648 109040 4 # 15542 LU +# 2261867 2308897 1648 94280 3 # 44935 MS +# T = 12.985799789428713 diff --git a/script/cw09b-bpq.sh b/script/cw09b-bpq.sh new file mode 100644 index 000000000..03c9793ab --- /dev/null +++ b/script/cw09b-bpq.sh @@ -0,0 +1,21 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/home/amallia/cw09b/CW09B.inv.bp" +FWD="/home/amallia/cw09b/CW09B.fwd.bp" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-bp/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-bp" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +#PAIRS="/home/michal/real.aol.top50k.jl" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.test" +QUERY_LIMIT=1000 + +# 134384 149026 1648 128376 4 # UL +# 149026 149026 1648 109040 4 # 
15542 LU +# 2261867 2308897 1648 94280 3 # 44935 MS +# T = 12.985799789428713 diff --git a/script/cw09b-est-pef.sh b/script/cw09b-est-pef.sh new file mode 100644 index 000000000..7bd029a2b --- /dev/null +++ b/script/cw09b-est-pef.sh @@ -0,0 +1,99 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/home/amallia/cw09b/CW09B.inv" +FWD="/home/amallia/cw09b/CW09B.fwd" +ENCODING="pef" # v1 +BASENAME="/data/michal/work/v1/cw09b/cw09b-pef" +THREADS=4 +QUERIES="/home/michal/biscorer/data/queries/05.efficiency_topics.no_dups.1k" +#QUERIES="/home/michal/topics.web.51-200.jl" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-est-pef" +#OUTPUT_DIR="/data/michal/intersect/cw09b-est-top20" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.web.51-200" +#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.05.efficiency_topics.no_dups.1k" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top15.bm25.05.efficiency_topics.no_dups.1k" +#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.efficiency_topics.no_dups.1k" + +set -e +set -x + +## Compress an inverted index in `binary_freq_collection` format. +${PISA_BIN}/compress -i ${BINARY_FREQ_COLL} --fwd ${FWD} -o ${BASENAME} -j ${THREADS} -e ${ENCODING} + +# This will produce both quantized scores and max scores (both quantized and not). +${PISA_BIN}/score -i "${BASENAME}.yml" -j ${THREADS} + +# This will produce both quantized scores and max scores (both quantized and not). +${PISA_BIN}/bmscore -i "${BASENAME}.yml" -j ${THREADS} --block-size 128 + +# Filter out queries witout existing terms. 
+paste -d: ${QUERIES} ${THRESHOLDS} \ + | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ + | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml > ${FILTERED_QUERIES} + +# This will produce both quantized scores and max scores (both quantized and not). +${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} + +# Extract intersections +${PISA_BIN}/intersection -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations --mtc 2 \ + | grep -v "\[warning\]" \ + > ${OUTPUT_DIR}/intersections.jl + +# Select unigrams +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl --max 1 > ${OUTPUT_DIR}/selections.1 +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl --max 2 > ${OUTPUT_DIR}/selections.2 +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl \ + --max 2 --scale 1.5 > ${OUTPUT_DIR}/selections.2.scaled-1.5 + +# Run benchmarks +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm wand > ${OUTPUT_DIR}/bench.wand +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm bmw > ${OUTPUT_DIR}/bench.bmw +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm maxscore > ${OUTPUT_DIR}/bench.maxscore +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm bmw \ + > ${OUTPUT_DIR}/bench.bmw-threshold +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore \ + > ${OUTPUT_DIR}/bench.maxscore-threshold +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup \ + > ${OUTPUT_DIR}/bench.maxscore-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup \ + > ${OUTPUT_DIR}/bench.unigram-union-lookup +${PISA_BIN}/query -i 
"${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup \ + > ${OUTPUT_DIR}/bench.union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm lookup-union \ + > ${OUTPUT_DIR}/bench.lookup-union +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ + --benchmark --algorithm lookup-union \ + > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 + +# Analyze +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore \ + > ${OUTPUT_DIR}/stats.maxscore-threshold +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore-union-lookup \ + > ${OUTPUT_DIR}/stats.maxscore-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --inspect --algorithm unigram-union-lookup \ + > ${OUTPUT_DIR}/stats.unigram-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup \ + > ${OUTPUT_DIR}/stats.union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ + --inspect --algorithm lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union.scaled-1.5 + +# Evaluate +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore \ + > "${OUTPUT_DIR}/eval.maxscore-threshold" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore-union-lookup \ + > "${OUTPUT_DIR}/eval.maxscore-union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q 
${OUTPUT_DIR}/selections.1 --algorithm unigram-union-lookup \ + > "${OUTPUT_DIR}/eval.unigram-union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm union-lookup \ + > "${OUTPUT_DIR}/eval.union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm lookup-union \ + > "${OUTPUT_DIR}/eval.lookup-union" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --algorithm lookup-union \ + > "${OUTPUT_DIR}/eval.lookup-union.scaled-1.5" diff --git a/script/cw09b-est-val.sh b/script/cw09b-est-val.sh new file mode 100644 index 000000000..4b7868291 --- /dev/null +++ b/script/cw09b-est-val.sh @@ -0,0 +1,12 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/home/amallia/cw09b/CW09B.inv" +FWD="/home/amallia/cw09b/CW09B.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.val" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-est-val" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.val" diff --git a/script/cw09b-est.sh b/script/cw09b-est.sh new file mode 100644 index 000000000..0bffb773d --- /dev/null +++ b/script/cw09b-est.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/home/amallia/cw09b/CW09B.inv" +FWD="/home/amallia/cw09b/CW09B.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-est-lm" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS="/home/michal/real.aol.top50k.jl" +PAIR_INDEX_BASENAME="${BASENAME}-pair" 
+THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.test" +QUERY_LIMIT=100 diff --git a/script/cw09b-exact.sh b/script/cw09b-exact.sh new file mode 100644 index 000000000..630b2391c --- /dev/null +++ b/script/cw09b-exact.sh @@ -0,0 +1,92 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/home/amallia/cw09b/CW09B.inv" +FWD="/home/amallia/cw09b/CW09B.fwd" +INV="/home/amallia/cw09b/CW09B" +BASENAME="/data/michal/work/v1/cw09b/cw09b" +THREADS=4 +TYPE="block_simdbp" # v0.6 +ENCODING="simdbp" # v1 +QUERIES="/home/michal/biscorer/data/queries/05.efficiency_topics.no_dups.1k" +#QUERIES="/home/michal/topics.web.51-200.jl" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.web.51-200" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.05.efficiency_topics.no_dups.1k" + +set -e +set -x + +## Compress an inverted index in `binary_freq_collection` format. +#${PISA_BIN}/compress -i ${BINARY_FREQ_COLL} --fwd ${FWD} -o ${BASENAME} -j ${THREADS} -e ${ENCODING} + +# This will produce both quantized scores and max scores (both quantized and not). +#${PISA_BIN}/score -i "${BASENAME}.yml" -j ${THREADS} + +# Filter out queries witout existing terms. +#paste -d: ${QUERIES} ${THRESHOLDS} \ +# | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ +# | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml > ${FILTERED_QUERIES} + +# This will produce both quantized scores and max scores (both quantized and not). 
+#${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} + +# Extract intersections +#${PISA_BIN}/intersection -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations --mtc 2 \ +# | grep -v "\[warning\]" \ +# > ${OUTPUT_DIR}/intersections.jl + +# Select unigrams +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl --max 1 > ${OUTPUT_DIR}/selections.1 +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl --max 2 > ${OUTPUT_DIR}/selections.2 +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl \ + --max 2 --scale 1.5 > ${OUTPUT_DIR}/selections.2.scaled-1.5 + +# Run benchmarks +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark \ + --algorithm maxscore > ${OUTPUT_DIR}/bench.maxscore +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore --safe \ + > ${OUTPUT_DIR}/bench.maxscore-threshold +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup --safe \ + > ${OUTPUT_DIR}/bench.maxscore-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup --safe \ + > ${OUTPUT_DIR}/bench.unigram-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup --safe \ + > ${OUTPUT_DIR}/bench.union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm lookup-union --safe \ + > ${OUTPUT_DIR}/bench.lookup-union +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ + --benchmark --algorithm lookup-union --safe \ + > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 + +# Analyze +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect 
--algorithm maxscore \ + > ${OUTPUT_DIR}/stats.maxscore-threshold +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore-union-lookup \ + > ${OUTPUT_DIR}/stats.maxscore-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --inspect --algorithm unigram-union-lookup \ + > ${OUTPUT_DIR}/stats.unigram-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup \ + > ${OUTPUT_DIR}/stats.union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ + --inspect --algorithm lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union.scaled-1.5 + +# Evaluate +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore \ + > "${OUTPUT_DIR}/eval.maxscore-threshold" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore-union-lookup \ + > "${OUTPUT_DIR}/eval.maxscore-union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --algorithm unigram-union-lookup \ + > "${OUTPUT_DIR}/eval.unigram-union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm union-lookup \ + > "${OUTPUT_DIR}/eval.union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm lookup-union \ + > "${OUTPUT_DIR}/eval.lookup-union" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --algorithm lookup-union \ + > "${OUTPUT_DIR}/eval.lookup-union.scaled-1.5" diff --git a/script/cw09b-url-10.sh b/script/cw09b-url-10.sh new file mode 100644 index 000000000..bfda430ba --- /dev/null +++ 
b/script/cw09b-url-10.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW09B/CW09B.url.inv" +FWD="/data/CW09B/CW09B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.test" +K=10 +OUTPUT_DIR="/data/michal/intersect/cw09b-url-10-bi" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top2.bm25.05.clean.shuf.test" +QUERY_LIMIT=1000 diff --git a/script/cw09b-url-100.sh b/script/cw09b-url-100.sh new file mode 100644 index 000000000..43c4d7d94 --- /dev/null +++ b/script/cw09b-url-100.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW09B/CW09B.url.inv" +FWD="/data/CW09B/CW09B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.test" +K=100 +OUTPUT_DIR="/data/michal/intersect/cw09b-url-100-bi" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top5.bm25.05.clean.shuf.test" +QUERY_LIMIT=1000 diff --git a/script/cw09b-url-10000.sh b/script/cw09b-url-10000.sh new file mode 100644 index 000000000..c9f2613d8 --- /dev/null +++ b/script/cw09b-url-10000.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW09B/CW09B.url.inv" +FWD="/data/CW09B/CW09B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" +THREADS=4 
+QUERIES="/home/michal/05.clean.shuf.test" +K=10000 +OUTPUT_DIR="/data/michal/intersect/cw09b-url-10k-bi" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top130.bm25.05.clean.shuf.test" +QUERY_LIMIT=1000 diff --git a/script/cw09b-url-bi-trec06.sh b/script/cw09b-url-bi-trec06.sh new file mode 100644 index 000000000..ba686627d --- /dev/null +++ b/script/cw09b-url-bi-trec06.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW09B/CW09B.url.inv" +FWD="/data/CW09B/CW09B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/06.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-url-bi-trec06" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair-trec06-2" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.06.clean.shuf.test" +QUERY_LIMIT=1000 diff --git a/script/cw09b-url-bi.sh b/script/cw09b-url-bi.sh new file mode 100644 index 000000000..39dab34c5 --- /dev/null +++ b/script/cw09b-url-bi.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW09B/CW09B.url.inv" +FWD="/data/CW09B/CW09B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-url-bi" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS="/home/michal/real.aol.top100k.jl" +PAIR_INDEX_BASENAME="/data/michal/work/v1/cw09b-url/cw09b-simdbp-pair" 
+THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.test" +QUERY_LIMIT=5000 diff --git a/script/cw09b-url-trec06-2.sh b/script/cw09b-url-trec06-2.sh new file mode 100644 index 000000000..e9f39d84c --- /dev/null +++ b/script/cw09b-url-trec06-2.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW09B/CW09B.url.inv" +FWD="/data/CW09B/CW09B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/06.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-url-trec06-2" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair-trec06-2" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.06.clean.shuf.test" +QUERY_LIMIT=1000 diff --git a/script/cw09b-url-trec06.sh b/script/cw09b-url-trec06.sh new file mode 100644 index 000000000..aa1491e25 --- /dev/null +++ b/script/cw09b-url-trec06.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW09B/CW09B.url.inv" +FWD="/data/CW09B/CW09B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/06.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-url-trec06" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair-trec06" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.06.clean.shuf.test" +QUERY_LIMIT=5000 diff --git a/script/cw09b-url.sh b/script/cw09b-url.sh new file mode 100644 index 000000000..eac47f6e9 --- /dev/null +++ b/script/cw09b-url.sh @@ -0,0 +1,17 @@ 
+PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW09B/CW09B.url.inv" +FWD="/data/CW09B/CW09B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.test" +#QUERIES="/home/michal/biscorer/data/queries/topics.web.51-200" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-url" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair" +#THRESHOLDS="/home/michal/topics.web.51-200.thresholds" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.test" +QUERY_LIMIT=5000 diff --git a/script/cw09b.sh b/script/cw09b.sh new file mode 100644 index 000000000..49a7cbcc5 --- /dev/null +++ b/script/cw09b.sh @@ -0,0 +1,151 @@ +# Fail if any variable is unset +set -u +set -e + +source $1 + +echo "Running experiment with the following environment:" +echo "" +echo " PISA_BIN = ${PISA_BIN}" +echo " INTERSECT_BIN = ${INTERSECT_BIN}" +echo " BINARY_FREQ_COLL = ${BINARY_FREQ_COLL}" +echo " FWD = ${FWD}" +echo " BASENAME = ${BASENAME}" +echo " THREADS = ${THREADS}" +echo " ENCODING = ${ENCODING}" +echo " OUTPUT_DIR = ${OUTPUT_DIR}" +echo " QUERIES = ${QUERIES}" +echo " FILTERED_QUERIES = ${FILTERED_QUERIES}" +echo " K = ${K}" +echo " THRESHOLDS = ${THRESHOLDS}" +echo " QUERY_LIMIT = ${QUERY_LIMIT}" +echo " PAIRS = ${PAIRS}" +echo " PAIR_INDEX_BASENAME = ${PAIR_INDEX_BASENAME}" +echo "" + +set -x +mkdir -p ${OUTPUT_DIR} + +## Compress an inverted index in `binary_freq_collection` format. +#${PISA_BIN}/compress -i ${BINARY_FREQ_COLL} --fwd ${FWD} -o ${BASENAME} -j ${THREADS} -e ${ENCODING} + +# This will produce both quantized scores and max scores (both quantized and not). 
+#${PISA_BIN}/score -i "${BASENAME}.yml" -j ${THREADS} + +# This will produce both quantized scores and max scores (both quantized and not). +#${PISA_BIN}/bmscore -i "${BASENAME}.yml" -j ${THREADS} --block-size 128 +#${PISA_BIN}/bmscore -i "${BASENAME}.yml" -j ${THREADS} --variable-blocks 22.5 --clone ${BASENAME}-var + +# Filter out queries witout existing terms. +paste -d: ${QUERIES} ${THRESHOLDS} \ + | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ + | head -${QUERY_LIMIT} \ + | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml --min 2 --max 8 > ${FILTERED_QUERIES} + +# This will produce both quantized scores and max scores (both quantized and not). +${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${PAIRS} --clone ${PAIR_INDEX_BASENAME} -j 4 + +# Extract intersections +${PISA_BIN}/intersection -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations --existing \ + | grep -v "\[warning\]" \ + > ${OUTPUT_DIR}/intersections.jl +#${PISA_BIN}/intersection -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations \ +# --in /home/michal/real.aol.top1m.tsv \ +# | grep -v "\[warning\]" \ +# > ${OUTPUT_DIR}/intersections.jl + +# Select unigrams +${INTERSECT_BIN} -m unigram ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.1 +${INTERSECT_BIN} -m unigram ${OUTPUT_DIR}/intersections.jl --time > ${OUTPUT_DIR}/selections.1.time +${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.1.5 +${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.2 +${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --time > ${OUTPUT_DIR}/selections.2.time +${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 +${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --scale 2 --time > ${OUTPUT_DIR}/selections.2.scaled-2.time +#${INTERSECT_BIN} -m exact ${OUTPUT_DIR}/intersections.jl --scale 2 > 
${OUTPUT_DIR}/selections.2.exact + +# Run benchmarks +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm wand --safe > ${OUTPUT_DIR}/bench.wand +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm bmw --safe > ${OUTPUT_DIR}/bench.bmw +#${PISA_BIN}/query -i "${BASENAME}-var.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm bmw --safe > ${OUTPUT_DIR}/bench.vbmw +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm bmw --safe \ +# > ${OUTPUT_DIR}/bench.bmw-threshold +#${PISA_BIN}/query -i "${BASENAME}-var.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm bmw --safe \ +# > ${OUTPUT_DIR}/bench.vbmw-threshold + +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm maxscore --safe -k ${K} > ${OUTPUT_DIR}/bench.maxscore +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore --safe -k ${K} \ + > ${OUTPUT_DIR}/bench.maxscore-threshold +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup --safe \ + > ${OUTPUT_DIR}/bench.maxscore-union-lookup +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup --safe -k ${K} \ + > ${OUTPUT_DIR}/bench.unigram-union-lookup +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup --safe \ + > ${OUTPUT_DIR}/bench.union-lookup +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm lookup-union --safe \ + > ${OUTPUT_DIR}/bench.lookup-union +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup-plus 
--safe \ + > ${OUTPUT_DIR}/bench.plus +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --safe \ + --benchmark --algorithm lookup-union \ + > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ + --benchmark --algorithm union-lookup-plus -k ${K} \ + > ${OUTPUT_DIR}/bench.plus.scaled-2 +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ + --benchmark --algorithm lookup-union -k ${K} \ + > ${OUTPUT_DIR}/bench.lookup-union.scaled-2 +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ + --benchmark --algorithm lookup-union-eaat -k ${K} \ + > ${OUTPUT_DIR}/bench.lookup-union-eaat.scaled-2 + +# Analyze +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore -k ${K} > ${OUTPUT_DIR}/stats.maxscore +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore -k ${K} \ + > ${OUTPUT_DIR}/stats.maxscore-threshold +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore-union-lookup \ + > ${OUTPUT_DIR}/stats.maxscore-union-lookup +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --inspect --algorithm unigram-union-lookup -k ${K} \ + > ${OUTPUT_DIR}/stats.unigram-union-lookup +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup \ + > ${OUTPUT_DIR}/stats.union-lookup +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup-plus \ + > ${OUTPUT_DIR}/stats.plus +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union +${PISA_BIN}/query -i 
"${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ + --inspect --algorithm lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union.scaled-1.5 +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 -k ${K} \ + --inspect --algorithm union-lookup-plus \ + > ${OUTPUT_DIR}/stats.plus.scaled-2 +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 -k ${K} \ + --inspect --algorithm lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union.scaled-2 +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 -k ${K} \ + --inspect --algorithm lookup-union-eaat \ + > ${OUTPUT_DIR}/stats.lookup-union-eaat.scaled-2 + +# Evaluate +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore \ +# > "${OUTPUT_DIR}/eval.maxscore-threshold" +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore-union-lookup \ +# > "${OUTPUT_DIR}/eval.maxscore-union-lookup" +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --algorithm unigram-union-lookup \ +# > "${OUTPUT_DIR}/eval.unigram-union-lookup" +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm union-lookup \ +# > "${OUTPUT_DIR}/eval.union-lookup" +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm lookup-union \ +# > "${OUTPUT_DIR}/eval.lookup-union" +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --algorithm lookup-union \ +# > "${OUTPUT_DIR}/eval.lookup-union.scaled-1.5" +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 -k ${K} \ +# --algorithm union-lookup-plus \ +# > ${OUTPUT_DIR}/eval.plus.scaled-2 
+#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 -k ${K} --safe \ +# --algorithm lookup-union \ +# > ${OUTPUT_DIR}/eval.lookup-union.scaled-2 +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 -k ${K} --safe \ +# --algorithm lookup-union-eaat \ +# > ${OUTPUT_DIR}/eval.lookup-union-eaat.scaled-2 diff --git a/script/cw12-url-bi-trec06.sh b/script/cw12-url-bi-trec06.sh new file mode 100644 index 000000000..23489dacd --- /dev/null +++ b/script/cw12-url-bi-trec06.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW12B/CW12B.url.inv" +FWD="/data/CW12B/CW12B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw12b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/06.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw12b-url-bi-trec06" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair-trec06" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw12b/thresholds.cw12b.0_01.top20.bm25.06.clean.shuf.test" +QUERY_LIMIT=1000 diff --git a/script/cw12-url-bi.sh b/script/cw12-url-bi.sh new file mode 100644 index 000000000..7604b025f --- /dev/null +++ b/script/cw12-url-bi.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW12B/CW12B.url.inv" +FWD="/data/CW12B/CW12B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw12b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw12b-url-bi" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS="/home/michal/real.aol.top100k.jl" +PAIR_INDEX_BASENAME="/data/michal/work/v1/cw12b-url/cw09b-simdbp-pair" 
+THRESHOLDS="/home/michal/biscorer/data/thresholds/cw12b/thresholds.cw12b.0_01.top20.bm25.05.clean.shuf.test" +QUERY_LIMIT=5000 diff --git a/script/cw12-url-trec06.sh b/script/cw12-url-trec06.sh new file mode 100644 index 000000000..04eee7b25 --- /dev/null +++ b/script/cw12-url-trec06.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW12B/CW12B.url.inv" +FWD="/data/CW12B/CW12B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw12b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/06.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw12b-url-trec06" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair-trec06" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw12b/thresholds.cw12b.0_01.top20.bm25.06.clean.shuf.test" +QUERY_LIMIT=1000 diff --git a/script/cw12-url.sh b/script/cw12-url.sh new file mode 100644 index 000000000..e0a5e20fe --- /dev/null +++ b/script/cw12-url.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW12B/CW12B.url.inv" +FWD="/data/CW12B/CW12B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw12b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw12b-url" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair-trec05" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw12b/thresholds.cw12b.0_01.top20.bm25.05.clean.shuf.test" +QUERY_LIMIT=5000 diff --git a/script/robust.sh b/script/robust.sh new file mode 100644 index 000000000..4b7cd3ee1 --- /dev/null +++ b/script/robust.sh @@ -0,0 +1,63 @@ +PISA_BIN="/home/michal/work/pisa/build/bin" 
+INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/mnt/michal/work/robust/inv" +FWD="/mnt/michal/work/robust/fwd" +INV="/mnt/michal/work/robust/inv" +BASENAME="/mnt/michal/work/v1/robust/robust" +THREADS=16 +TYPE="block_simdbp" # v0.6 +ENCODING="simdbp" # v1 +QUERIES="/home/michal/biscorer/data/queries/05.efficiency_topics.no_dups.1k" +K=1000 +OUTPUT_DIR=`pwd` + +set -x +set -e + +# Compress an inverted index in `binary_freq_collection` format. +${PISA_BIN}/compress -i ${BINARY_FREQ_COLL} --fwd ${FWD} -o ${BASENAME} -j ${THREADS} -e ${ENCODING} + +## This will produce both quantized scores and max scores (both quantized and not). +${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${QUERIES} +# +## This will produce both quantized scores and max scores (both quantized and not). +${PISA_BIN}/score -i "${BASENAME}.yml" -j ${THREADS} + +# Filter out queries witout existing terms. +FILTERED_QUERIES="${OUTPUT_DIR}/filtered_queries" +${PISA_BIN}/filter-queries -i ${BASENAME}.yml -q ${QUERIES} | grep -v "\[warning\]" \ + > ${FILTERED_QUERIES} + +# Extract thresholds (TODO: estimates) +${PISA_BIN}/thresholds -t ${TYPE} -i ${INV}.${TYPE} \ + -w ${INV}.wand -q ${FILTERED_QUERIES} -k ${K} --terms "${FWD}.termlex" --stemmer porter2 \ + | grep -v "\[warning\]" \ + > ${OUTPUT_DIR}/thresholds +cut -d: -f1 ${FILTERED_QUERIES} | paste - ${OUTPUT_DIR}/thresholds > ${OUTPUT_DIR}/thresholds.tsv + +# Extract intersections +${PISA_BIN}/compute_intersection -t ${TYPE} -i ${INV}.${TYPE} \ + -w ${INV}.wand -q ${FILTERED_QUERIES} --combinations --terms "${FWD}.termlex" --stemmer porter2 \ + | grep -v "\[warning\]" \ + > ${OUTPUT_DIR}/intersections.tsv + +# Select unigrams +${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ + --terse --max 1 > ${OUTPUT_DIR}/selections.1 + +# Select unigrams and bigrams +${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ + --terse 
--max 2 > ${OUTPUT_DIR}/selections.2 + +# Run benchmarks +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore \ + --thresholds ${OUTPUT_DIR}/thresholds +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup \ + --thresholds ${OUTPUT_DIR}/thresholds +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm unigram-union-lookup \ + --thresholds ${OUTPUT_DIR}/thresholds --intersections ${OUTPUT_DIR}/selections.1 +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ + --thresholds ${OUTPUT_DIR}/thresholds --intersections ${OUTPUT_DIR}/selections.1 +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ + --thresholds ${OUTPUT_DIR}/thresholds --intersections ${OUTPUT_DIR}/selections.2 diff --git a/src/v1/bit_sequence_cursor.cpp b/src/v1/bit_sequence_cursor.cpp new file mode 100644 index 000000000..343b466c1 --- /dev/null +++ b/src/v1/bit_sequence_cursor.cpp @@ -0,0 +1 @@ +#include "v1/bit_sequence_cursor.hpp" diff --git a/src/v1/blocked_cursor.cpp b/src/v1/blocked_cursor.cpp new file mode 100644 index 000000000..3e2bbd8c6 --- /dev/null +++ b/src/v1/blocked_cursor.cpp @@ -0,0 +1,44 @@ +#include "v1/blocked_cursor.hpp" + +namespace pisa::v1 { + +[[nodiscard]] auto BaseBlockedCursor::block_offset(size_type block) const -> offset_type +{ + return block > 0U ? 
m_block_endpoints[block - 1] : static_cast(0U); +} +[[nodiscard]] auto BaseBlockedCursor::decoded_block() -> value_type* +{ + return m_decoded_block.data(); +} +[[nodiscard]] auto BaseBlockedCursor::encoded_block(offset_type offset) -> uint8_t const* +{ + return std::next(reinterpret_cast(m_encoded_blocks.data()), offset); +} +[[nodiscard]] auto BaseBlockedCursor::length() const -> size_type { return m_length; } +[[nodiscard]] auto BaseBlockedCursor::num_blocks() const -> size_type { return m_num_blocks; } + +[[nodiscard]] auto BaseBlockedCursor::operator*() const -> value_type { return m_current_value; } +[[nodiscard]] auto BaseBlockedCursor::value() const noexcept -> value_type { return *(*this); } +[[nodiscard]] auto BaseBlockedCursor::empty() const noexcept -> bool +{ + return value() >= sentinel(); +} +[[nodiscard]] auto BaseBlockedCursor::position() const noexcept -> std::size_t +{ + return m_current_block.number * m_block_length + m_current_block.offset; +} +[[nodiscard]] auto BaseBlockedCursor::size() const -> std::size_t { return m_length; } +[[nodiscard]] auto BaseBlockedCursor::sentinel() const -> value_type +{ + return std::numeric_limits::max(); +} +[[nodiscard]] auto BaseBlockedCursor::current_block() -> Block& { return m_current_block; } +[[nodiscard]] auto BaseBlockedCursor::decoded_value(size_type n) -> value_type +{ + return m_decoded_block[n] + 1U; +} + +void BaseBlockedCursor::update_current_value(value_type val) { m_current_value = val; } +void BaseBlockedCursor::increase_current_value(value_type val) { m_current_value += val; } + +} // namespace pisa::v1 diff --git a/src/v1/default_index_runner.cpp b/src/v1/default_index_runner.cpp new file mode 100644 index 000000000..649434d51 --- /dev/null +++ b/src/v1/default_index_runner.cpp @@ -0,0 +1,23 @@ +#include "v1/default_index_runner.hpp" + +namespace pisa::v1 { + +//[[nodiscard]] auto index_runner(IndexMetadata metadata) +//{ +// return index_runner( +// std::move(metadata), +// 
std::make_tuple(RawReader{}, +// DocumentBlockedReader<::pisa::simdbp_block>{}, +// DocumentBitSequenceReader>{}), +// std::make_tuple(RawReader{}, PayloadBlockedReader<::pisa::simdbp_block>{})); +//} +// +//[[nodiscard]] auto scored_index_runner(IndexMetadata metadata) +//{ +// return scored_index_runner( +// std::move(metadata), +// std::make_tuple(RawReader{}, DocumentBlockedReader<::pisa::simdbp_block>{}), +// std::make_tuple(RawReader{})); +//} + +} // namespace pisa::v1 diff --git a/src/v1/index.cpp b/src/v1/index.cpp new file mode 100644 index 000000000..7b02777e8 --- /dev/null +++ b/src/v1/index.cpp @@ -0,0 +1,173 @@ +#include + +#include + +#include "binary_collection.hpp" +#include "v1/index.hpp" + +namespace pisa::v1 { + +[[nodiscard]] auto calc_avg_length(gsl::span const& lengths) -> float +{ + auto sum = std::accumulate(lengths.begin(), lengths.end(), std::uint64_t(0), std::plus{}); + return static_cast(sum) / lengths.size(); +} + +[[nodiscard]] auto compare_arrays(std::array const& lhs, + std::array const& rhs) -> bool +{ + if (std::get<0>(lhs) < std::get<0>(rhs)) { + return true; + } + if (std::get<0>(lhs) > std::get<0>(rhs)) { + return false; + } + return std::get<1>(lhs) < std::get<1>(rhs); +} + +[[nodiscard]] auto read_sizes(std::string_view basename) -> std::vector +{ + binary_collection sizes(fmt::format("{}.sizes", basename).c_str()); + auto sequence = *sizes.begin(); + return std::vector(sequence.begin(), sequence.end()); +} + +[[nodiscard]] auto BaseIndex::num_terms() const -> std::size_t +{ + return m_documents.offsets.size() - 1; +} + +[[nodiscard]] auto BaseIndex::num_documents() const -> std::size_t +{ + return m_document_lengths.size(); +} + +[[nodiscard]] auto BaseIndex::num_pairs() const -> std::size_t +{ + if (not m_bigrams) { + throw std::logic_error("Bigrams are missing"); + } + return m_bigrams->mapping.size(); +} + +[[nodiscard]] auto BaseIndex::document_length(DocId docid) const -> std::uint32_t +{ + return 
m_document_lengths[docid]; +} + +[[nodiscard]] auto BaseIndex::avg_document_length() const -> float { return m_avg_document_length; } + +[[nodiscard]] auto BaseIndex::normalized_document_length(DocId docid) const -> float +{ + return document_length(docid) / avg_document_length(); +} + +void BaseIndex::assert_term_in_bounds(TermId term) const +{ + if (term >= num_terms()) { + throw std::invalid_argument( + fmt::format("Requested term ID out of bounds [0-{}): {}", num_terms(), term)); + } +} +[[nodiscard]] auto BaseIndex::fetch_documents(TermId term) const -> gsl::span +{ + Expects(term + 1 < m_documents.offsets.size()); + return m_documents.postings.subspan(m_documents.offsets[term], + m_documents.offsets[term + 1] - m_documents.offsets[term]); +} +[[nodiscard]] auto BaseIndex::fetch_payloads(TermId term) const -> gsl::span +{ + Expects(term + 1 < m_payloads.offsets.size()); + return m_payloads.postings.subspan(m_payloads.offsets[term], + m_payloads.offsets[term + 1] - m_payloads.offsets[term]); +} +[[nodiscard]] auto BaseIndex::fetch_bigram_documents(TermId bigram) const + -> gsl::span +{ + if (not m_bigrams) { + throw std::logic_error("Bigrams are missing"); + } + Expects(bigram + 1 < m_bigrams->documents.offsets.size()); + return m_bigrams->documents.postings.subspan( + m_bigrams->documents.offsets[bigram], + m_bigrams->documents.offsets[bigram + 1] - m_bigrams->documents.offsets[bigram]); +} + +template +[[nodiscard]] auto BaseIndex::fetch_bigram_payloads(TermId bigram) const + -> gsl::span +{ + if (not m_bigrams) { + throw std::logic_error("Bigrams are missing"); + } + Expects(bigram + 1 < std::get(m_bigrams->payloads).offsets.size()); + return std::get(m_bigrams->payloads) + .postings.subspan(std::get(m_bigrams->payloads).offsets[bigram], + std::get(m_bigrams->payloads).offsets[bigram + 1] + - std::get(m_bigrams->payloads).offsets[bigram]); +} + +template auto BaseIndex::fetch_bigram_payloads<0>(TermId bigram) const + -> gsl::span; +template auto 
BaseIndex::fetch_bigram_payloads<1>(TermId bigram) const + -> gsl::span; + +[[nodiscard]] auto BaseIndex::fetch_bigram_payloads(TermId bigram) const + -> std::array, 2> +{ + return {fetch_bigram_payloads<0>(bigram), fetch_bigram_payloads<1>(bigram)}; +} + +[[nodiscard]] auto BaseIndex::bigram_id(TermId left_term, TermId right_term) const + -> tl::optional +{ + if (not m_bigrams) { + throw std::logic_error("Bigrams are missing"); + } + if (right_term == left_term) { + throw std::logic_error("Requested bigram of two identical terms"); + } + auto bigram = std::array{left_term, right_term}; + if (right_term < left_term) { + std::swap(bigram[0], bigram[1]); + } + if (auto pos = std::lower_bound( + m_bigrams->mapping.begin(), m_bigrams->mapping.end(), bigram, compare_arrays); + pos != m_bigrams->mapping.end()) { + if (*pos == bigram) { + return tl::make_optional(std::distance(m_bigrams->mapping.begin(), pos)); + } + } + return tl::nullopt; +} + +[[nodiscard]] auto BaseIndex::max_score(std::size_t scorer_hash, TermId term) const -> float +{ + if (m_max_scores.empty()) { + throw std::logic_error("Missing max scores."); + } + return m_max_scores.at(scorer_hash)[term]; +} + +[[nodiscard]] auto BaseIndex::block_max_scores(std::size_t scorer_hash) const -> UnigramData const& +{ + if (auto pos = m_block_max_scores.find(scorer_hash); pos != m_block_max_scores.end()) { + return pos->second; + } + throw std::logic_error("Missing block-max scores."); +} + +[[nodiscard]] auto BaseIndex::quantized_max_score(TermId term) const -> std::uint8_t +{ + if (m_quantized_max_scores.empty()) { + throw std::logic_error("Missing quantized max scores."); + } + return m_quantized_max_scores.at(term); +} + +[[nodiscard]] auto BaseIndex::pairs() const -> tl::optional const>> +{ + return m_bigrams.map([](auto&& bigrams) { return bigrams.mapping; }); +} + +} // namespace pisa::v1 diff --git a/src/v1/index_builder.cpp b/src/v1/index_builder.cpp new file mode 100644 index 000000000..16954b80b --- 
/dev/null +++ b/src/v1/index_builder.cpp @@ -0,0 +1,604 @@ +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "codec/simdbp.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/cursor/accumulate.hpp" +#include "v1/cursor_intersection.hpp" +#include "v1/default_index_runner.hpp" +#include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/io.hpp" +#include "v1/query.hpp" +#include "v1/scorer/bm25.hpp" + +namespace pisa::v1 { + +[[nodiscard]] auto make_temp() -> std::string +{ + return (boost::filesystem::temp_directory_path() / boost::filesystem::unique_path()).string(); +} + +template +void write_document_header(Index&& index, std::ostream& os) +{ + using index_type = std::decay_t; + using writer_type = typename CursorTraits::Writer; + PostingBuilder(writer_type{index.num_documents()}).write_header(os); +} + +template +void write_payload_header(Index&& index, std::ostream& os) +{ + using index_type = std::decay_t; + using writer_type = typename CursorTraits::Writer; + using value_type = typename CursorTraits::Value; + PostingBuilder(writer_type{index.num_documents()}).write_header(os); +} + +auto merge_into(PostingFilePaths const& batch, + std::ofstream& posting_sink, + std::ofstream& offsets_sink, + std::size_t shift) -> std::size_t +{ + std::ifstream batch_stream(batch.postings); + batch_stream.seekg(0, std::ios::end); + std::streamsize size = batch_stream.tellg(); + batch_stream.seekg(0, std::ios::beg); + posting_sink << batch_stream.rdbuf(); + auto offsets = load_vector(batch.offsets); + std::transform(offsets.begin(), offsets.end(), offsets.begin(), [shift](auto offset) { + return shift + offset; + }); + + // Because each batch has one superfluous offset indicating the end of data. + // Effectively, the last offset of batch `i` overlaps with the first offsets of batch `i + 1`. + std::size_t start = shift > 0 ? 
8U : 0U; + + offsets_sink.write(reinterpret_cast(offsets.data()) + start, + offsets.size() * sizeof(std::size_t) - start); + return shift + size; +} + +template +void merge_postings(std::string const& message, + std::ofstream& posting_sink, + std::ofstream& offset_sink, + Rng&& batches) +{ + ProgressStatus status( + batches.size(), DefaultProgressCallback(message), std::chrono::milliseconds(500)); + std::size_t shift = 0; + for (auto&& batch : batches) { + shift = merge_into(batch, posting_sink, offset_sink, shift); + boost::filesystem::remove(batch.postings); + boost::filesystem::remove(batch.offsets); + status += 1; + }; +} + +/// Represents open-ended interval [begin, end). +template +struct Interval { + Interval(std::size_t begin, std::size_t end, gsl::span const& elements) + : m_begin(begin), m_end(end), m_elements(elements) + { + } + [[nodiscard]] auto span() const -> gsl::span + { + return m_elements.subspan(begin(), end() - begin()); + } + [[nodiscard]] auto begin() const -> std::size_t { return m_begin; } + [[nodiscard]] auto end() const -> std::size_t { return m_end; } + + private: + std::size_t m_begin; + std::size_t m_end; + gsl::span const& m_elements; +}; + +template +struct BatchConcurrentBuilder { + explicit BatchConcurrentBuilder(std::size_t num_batches, gsl::span elements) + : m_num_batches(num_batches), m_elements(elements) + { + runtime_assert(elements.size() >= num_batches).or_exit([&] { + return fmt::format( + "The number of elements ({}) must be at least the number of batches ({})", + elements.size(), + num_batches); + }); + } + + template + void execute_batch_jobs(BatchFn batch_job) + { + auto batch_size = m_elements.size() / m_num_batches; + tbb::task_group group; + for (auto batch = 0; batch < m_num_batches; batch += 1) { + group.run([batch_size, batch_job, batch, this] { + auto first_idx = batch * batch_size; + auto last_idx = batch < this->m_num_batches - 1 ? 
(batch + 1) * batch_size + : this->m_elements.size(); + batch_job(batch, Interval(first_idx, last_idx, m_elements)); + }); + } + group.wait(); + } + + private: + std::size_t m_num_batches; + gsl::span m_elements; +}; + +auto collect_unique_bigrams(std::vector const& queries, + std::function const& callback) + -> std::vector> +{ + std::vector> bigrams; + auto idx = 0; + for (auto const& query : queries) { + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { + continue; + } + callback(); + for (auto left = 0; left < term_ids.size() - 1; left += 1) { + for (auto right = left + 1; right < term_ids.size(); right += 1) { + bigrams.emplace_back(term_ids[left], term_ids[right]); + } + } + } + std::sort(bigrams.begin(), bigrams.end()); + bigrams.erase(std::unique(bigrams.begin(), bigrams.end()), bigrams.end()); + return bigrams; +} + +auto verify_compressed_index(std::string const& input, std::string_view output) + -> std::vector +{ + std::vector errors; + pisa::binary_freq_collection const collection(input.c_str()); + auto meta = IndexMetadata::from_file(fmt::format("{}.yml", output)); + auto run = index_runner(meta); + ProgressStatus status( + collection.size(), DefaultProgressCallback("Verifying"), std::chrono::milliseconds(500)); + run([&](auto&& index) { + auto sequence_iter = collection.begin(); + for (auto term = 0; term < index.num_terms(); term += 1, ++sequence_iter) { + auto document_sequence = sequence_iter->docs; + auto frequency_sequence = sequence_iter->freqs; + auto cursor = index.cursor(term); + if (cursor.size() != document_sequence.size()) { + errors.push_back( + fmt::format("Posting list length mismatch for term {}: expected {} but is {}", + term, + document_sequence.size(), + cursor.size())); + continue; + } + auto dit = document_sequence.begin(); + auto fit = frequency_sequence.begin(); + auto pos = 0; + while (not cursor.empty()) { + if (cursor.value() != *dit) { + errors.push_back( + fmt::format("Document mismatch for term {} at 
position {}", term, pos)); + } + if (cursor.payload() != *fit) { + errors.push_back( + fmt::format("Frequency mismatch for term {} at position {}: {} != {}", + term, + pos, + cursor.payload(), + *fit)); + } + cursor.advance(); + ++dit; + ++fit; + ++pos; + } + status += 1; + } + }); + return errors; +} + +[[nodiscard]] auto build_scored_pair_index(IndexMetadata meta, + std::string const& index_basename, + std::vector> const& pairs, + std::size_t threads) + -> std::pair +{ + auto run = scored_index_runner(std::move(meta)); + + PostingFilePaths scores_0 = {.postings = fmt::format("{}.bigram_bm25_0", index_basename), + .offsets = + fmt::format("{}.bigram_bm25_offsets_0", index_basename)}; + PostingFilePaths scores_1 = {.postings = fmt::format("{}.bigram_bm25_1", index_basename), + .offsets = + fmt::format("{}.bigram_bm25_offsets_1", index_basename)}; + std::ofstream score_out_0(scores_0.postings); + std::ofstream score_out_1(scores_1.postings); + + ProgressStatus status(pairs.size(), + DefaultProgressCallback("Building scored pair index"), + std::chrono::milliseconds(500)); + + std::vector> batch_files(threads); + + BatchConcurrentBuilder batch_builder(threads, + gsl::span const>(pairs)); + run([&](auto&& index) { + using index_type = std::decay_t; + using score_writer_type = + typename CursorTraits::Writer; + + batch_builder.execute_batch_jobs( + [&status, &batch_files, &index](auto batch_idx, auto interval) { + auto scores_file_0 = make_temp(); + auto scores_file_1 = make_temp(); + auto score_offsets_file_0 = make_temp(); + auto score_offsets_file_1 = make_temp(); + std::ofstream score_out_0(scores_file_0); + std::ofstream score_out_1(scores_file_1); + + PostingBuilder score_builder_0(score_writer_type{}); + PostingBuilder score_builder_1(score_writer_type{}); + + for (auto [left_term, right_term] : interval.span()) { + auto intersection = intersect({index.scored_cursor(left_term, VoidScorer{}), + index.scored_cursor(right_term, VoidScorer{})}, + std::array{0, 0}, + 
[](auto& payload, auto& cursor, auto list_idx) { + gsl::at(payload, list_idx) = cursor.payload(); + return payload; + }); + if (intersection.empty()) { + status += 1; + continue; + } + for_each(intersection, [&](auto& cursor) { + auto payload = cursor.payload(); + score_builder_0.accumulate(std::get<0>(payload)); + score_builder_1.accumulate(std::get<1>(payload)); + }); + score_builder_0.flush_segment(score_out_0); + score_builder_1.flush_segment(score_out_1); + status += 1; + } + write_span(gsl::make_span(score_builder_0.offsets()), score_offsets_file_0); + write_span(gsl::make_span(score_builder_1.offsets()), score_offsets_file_1); + batch_files[batch_idx] = {PostingFilePaths{scores_file_0, score_offsets_file_0}, + PostingFilePaths{scores_file_1, score_offsets_file_1}}; + }); + + write_payload_header(index, score_out_0); + write_payload_header(index, score_out_1); + }); + + std::ofstream score_offsets_out_0(scores_0.offsets); + std::ofstream score_offsets_out_1(scores_1.offsets); + + merge_postings( + "Merging scores<0>", + score_out_0, + score_offsets_out_0, + ranges::views::transform(batch_files, [](auto&& files) { return std::get<0>(files); })); + + merge_postings( + "Merging scores<1>", + score_out_1, + score_offsets_out_1, + ranges::views::transform(batch_files, [](auto&& files) { return std::get<1>(files); })); + + return {scores_0, scores_1}; +} + +[[nodiscard]] auto build_scored_bigram_index(IndexMetadata meta, + std::string const& index_basename, + std::vector> const& bigrams) + -> std::pair +{ + auto run = scored_index_runner(std::move(meta)); + + auto scores_file_0 = fmt::format("{}.bigram_bm25_0", index_basename); + auto scores_file_1 = fmt::format("{}.bigram_bm25_1", index_basename); + auto score_offsets_file_0 = fmt::format("{}.bigram_bm25_offsets_0", index_basename); + auto score_offsets_file_1 = fmt::format("{}.bigram_bm25_offsets_1", index_basename); + std::ofstream score_out_0(scores_file_0); + std::ofstream score_out_1(scores_file_1); + + 
run([&](auto&& index) { + ProgressStatus status(bigrams.size(), + DefaultProgressCallback("Building scored index"), + std::chrono::milliseconds(500)); + using index_type = std::decay_t; + using score_writer_type = + typename CursorTraits::Writer; + + PostingBuilder score_builder_0(score_writer_type{}); + PostingBuilder score_builder_1(score_writer_type{}); + PostingBuilder(score_writer_type{}); + + score_builder_0.write_header(score_out_0); + score_builder_1.write_header(score_out_1); + + for (auto [left_term, right_term] : bigrams) { + auto intersection = intersect({index.scored_cursor(left_term, VoidScorer{}), + index.scored_cursor(right_term, VoidScorer{})}, + std::array{0, 0}, + [](auto& payload, auto& cursor, auto list_idx) { + gsl::at(payload, list_idx) = cursor.payload(); + return payload; + }); + if (intersection.empty()) { + status += 1; + continue; + } + for_each(intersection, [&](auto& cursor) { + auto payload = cursor.payload(); + score_builder_0.accumulate(std::get<0>(payload)); + score_builder_1.accumulate(std::get<1>(payload)); + }); + score_builder_0.flush_segment(score_out_0); + score_builder_1.flush_segment(score_out_1); + status += 1; + } + write_span(gsl::make_span(score_builder_0.offsets()), score_offsets_file_0); + write_span(gsl::make_span(score_builder_1.offsets()), score_offsets_file_1); + }); + return {PostingFilePaths{scores_file_0, score_offsets_file_0}, + PostingFilePaths{scores_file_1, score_offsets_file_1}}; +} + +template > +struct HeapPriorityQueue { + using value_type = T; + + explicit HeapPriorityQueue(std::size_t capacity, Order order = Order()) + : m_capacity(capacity), m_order(std::move(order)) + { + m_elements.reserve(m_capacity + 1); + } + HeapPriorityQueue(HeapPriorityQueue const&) = default; + HeapPriorityQueue(HeapPriorityQueue&&) noexcept = default; + HeapPriorityQueue& operator=(HeapPriorityQueue const&) = default; + HeapPriorityQueue& operator=(HeapPriorityQueue&&) noexcept = default; + ~HeapPriorityQueue() = default; + 
+ void push(value_type value) + { + m_elements.push_back(value); + std::push_heap(m_elements.begin(), m_elements.end(), m_order); + if (PISA_LIKELY(m_elements.size() > m_capacity)) { + std::pop_heap(m_elements.begin(), m_elements.end(), m_order); + m_elements.pop_back(); + } + } + + [[nodiscard]] auto size() const noexcept { return m_elements.size(); } + [[nodiscard]] auto capacity() const noexcept { return m_capacity; } + + [[nodiscard]] auto take() && -> std::vector + { + std::sort(m_elements.begin(), m_elements.end(), m_order); + return std::move(m_elements); + } + + private: + std::size_t m_capacity; + std::vector m_elements; + Order m_order; +}; + +[[nodiscard]] auto select_best_bigrams(IndexMetadata const& meta, + std::vector const& queries, + std::size_t num_bigrams_to_select) + -> std::vector> +{ + using Bigram = std::pair; + auto order = [](auto&& lhs, auto&& rhs) { return lhs.second > rhs.second; }; + auto top_bigrams = HeapPriorityQueue(num_bigrams_to_select, order); + + auto run = index_runner(meta); + run([&](auto&& index) { + auto bigram_gain = [&](Query const& bigram) -> float { + auto&& term_ids = bigram.get_term_ids(); + runtime_assert(term_ids.size() == 2) + .or_throw("Queries must be of exactly two unique terms"); + auto cursors = index.scored_cursors(term_ids, make_bm25(index)); + auto union_length = cursors[0].size() + cursors[1].size(); + auto intersection_length = + accumulate(intersect(std::move(cursors), + false, + []([[maybe_unused]] auto count, + [[maybe_unused]] auto&& cursor, + [[maybe_unused]] auto idx) { return true; }), + std::size_t{0}, + [](auto count, [[maybe_unused]] auto&& cursor) { return count + 1; }); + if (intersection_length == 0) { + return 0.0; + } + return static_cast(bigram.get_probability()) * static_cast(union_length) + / static_cast(intersection_length); + }; + for (auto&& query : queries) { + auto&& term_ids = query.get_term_ids(); + top_bigrams.push(std::make_pair(&query, bigram_gain(query))); + } + }); + auto top 
= std::move(top_bigrams).take(); + return ranges::views::transform(top, + [](auto&& elem) { + auto&& term_ids = elem.first->get_term_ids(); + return std::make_pair(term_ids[0], term_ids[1]); + }) + | ranges::to_vector; +} + +template +auto build_pair_batch(Index&& index, + gsl::span const> pairs, + ProgressStatus& status) + -> std::pair>> +{ + using index_type = std::decay_t; + using document_writer_type = + typename CursorTraits::Writer; + using frequency_writer_type = + typename CursorTraits::Writer; + + std::vector> pair_mapping; + + BigramMetadata batch_meta{.documents = {.postings = make_temp(), .offsets = make_temp()}, + .frequencies = {{.postings = make_temp(), .offsets = make_temp()}, + {.postings = make_temp(), .offsets = make_temp()}}}; + std::ofstream document_out(batch_meta.documents.postings); + std::ofstream frequency_out_0(batch_meta.frequencies.first.postings); + std::ofstream frequency_out_1(batch_meta.frequencies.second.postings); + + PostingBuilder document_builder(document_writer_type{index.num_documents()}); + PostingBuilder frequency_builder_0(frequency_writer_type{index.num_documents()}); + PostingBuilder frequency_builder_1(frequency_writer_type{index.num_documents()}); + + for (auto [left_term, right_term] : pairs) { + auto intersection = intersect({index.cursor(left_term), index.cursor(right_term)}, + std::array{0, 0}, + [](auto& payload, auto& cursor, auto list_idx) { + gsl::at(payload, list_idx) = cursor.payload(); + return payload; + }); + if (intersection.empty()) { + status += 1; + continue; + } + pair_mapping.push_back({left_term, right_term}); + for_each(intersection, [&](auto& cursor) { + document_builder.accumulate(*cursor); + auto payload = cursor.payload(); + frequency_builder_0.accumulate(std::get<0>(payload)); + frequency_builder_1.accumulate(std::get<1>(payload)); + }); + document_builder.flush_segment(document_out); + frequency_builder_0.flush_segment(frequency_out_0); + frequency_builder_1.flush_segment(frequency_out_1); + 
status += 1; + } + write_span(gsl::make_span(document_builder.offsets()), batch_meta.documents.offsets); + write_span(gsl::make_span(frequency_builder_0.offsets()), batch_meta.frequencies.first.offsets); + write_span(gsl::make_span(frequency_builder_1.offsets()), + batch_meta.frequencies.second.offsets); + return {std::move(batch_meta), std::move(pair_mapping)}; +} + +auto build_pair_index(IndexMetadata meta, + std::vector> const& pairs, + tl::optional const& clone_path, + std::size_t threads) -> IndexMetadata +{ + runtime_assert(not pairs.empty()).or_throw("Pair index must contain pairs but none passed"); + std::string index_basename = clone_path.value_or(std::string(meta.get_basename())); + auto run = index_runner(meta); + + std::vector>> pair_mapping(threads); + std::vector batch_meta(threads); + + PostingFilePaths documents{.postings = fmt::format("{}.bigram_documents", index_basename), + .offsets = + fmt::format("{}.bigram_document_offsets", index_basename)}; + PostingFilePaths frequencies_0{ + .postings = fmt::format("{}.bigram_frequencies_0", index_basename), + .offsets = fmt::format("{}.bigram_frequency_offsets_0", index_basename)}; + PostingFilePaths frequencies_1{ + .postings = fmt::format("{}.bigram_frequencies_1", index_basename), + .offsets = fmt::format("{}.bigram_frequency_offsets_1", index_basename)}; + + std::ofstream document_out(documents.postings); + std::ofstream frequency_out_0(frequencies_0.postings); + std::ofstream frequency_out_1(frequencies_1.postings); + + std::ofstream document_offsets_out(documents.offsets); + std::ofstream frequency_offsets_out_0(frequencies_0.offsets); + std::ofstream frequency_offsets_out_1(frequencies_1.offsets); + + ProgressStatus status(pairs.size(), + DefaultProgressCallback("Building bigram index"), + std::chrono::milliseconds(500)); + + BatchConcurrentBuilder batch_builder(threads, + gsl::span const>(pairs)); + run([&](auto&& index) { + using index_type = std::decay_t; + using document_writer_type = + typename 
CursorTraits::Writer; + using frequency_writer_type = + typename CursorTraits::Writer; + + batch_builder.execute_batch_jobs([&](auto batch_idx, auto interval) { + auto res = build_pair_batch(index, interval.span(), status); + batch_meta[batch_idx] = res.first; + pair_mapping[batch_idx] = std::move(res.second); + }); + + write_document_header(index, document_out); + write_payload_header(index, frequency_out_0); + write_payload_header(index, frequency_out_1); + }); + + merge_postings( + "Merging documents", + document_out, + document_offsets_out, + ranges::views::transform(batch_meta, [](auto&& meta) { return meta.documents; })); + + merge_postings( + "Merging frequencies<0>", + frequency_out_0, + frequency_offsets_out_0, + ranges::views::transform(batch_meta, [](auto&& meta) { return meta.frequencies.first; })); + + merge_postings( + "Merging frequencies<1>", + frequency_out_1, + frequency_offsets_out_1, + ranges::views::transform(batch_meta, [](auto&& meta) { return meta.frequencies.second; })); + + auto count = std::accumulate(pair_mapping.begin(), + pair_mapping.end(), + std::size_t{0}, + [](auto acc, auto&& m) { return acc + m.size(); }); + BigramMetadata bigram_meta{.documents = documents, + .frequencies = {frequencies_0, frequencies_1}, + .scores = {}, + .mapping = fmt::format("{}.bigram_mapping", index_basename), + .count = count}; + + if (not meta.scores.empty()) { + bigram_meta.scores.push_back(build_scored_pair_index(meta, index_basename, pairs, threads)); + } + meta.bigrams = bigram_meta; + + std::cerr << "Writing metadata..."; + if (clone_path) { + meta.write(append_extension(clone_path.value())); + } else { + meta.update(); + } + std::cerr << " Done.\nWriting bigram mapping..."; + std::ofstream os(meta.bigrams->mapping); + for (auto mapping_batch : pair_mapping) { + write_span(gsl::make_span(mapping_batch), os); + } + std::cerr << " Done.\n"; + return meta; +} + +} // namespace pisa::v1 diff --git a/src/v1/index_metadata.cpp b/src/v1/index_metadata.cpp 
new file mode 100644 index 000000000..a4c07982e --- /dev/null +++ b/src/v1/index_metadata.cpp @@ -0,0 +1,222 @@ +#include +#include +#include + +#include + +#include "query/queries.hpp" +#include "v1/index_metadata.hpp" + +namespace pisa::v1 { + +constexpr char const* DOCUMENTS = "documents"; +constexpr char const* FREQUENCIES = "frequencies"; +constexpr char const* SCORES = "scores"; +constexpr char const* POSTINGS = "file"; +constexpr char const* OFFSETS = "offsets"; +constexpr char const* STATS = "stats"; +constexpr char const* LEXICON = "lexicon"; +constexpr char const* TERMS = "terms"; +constexpr char const* BIGRAM = "bigram"; +constexpr char const* MAX_SCORES = "max_scores"; +constexpr char const* BLOCK_MAX_SCORES = "block_max_scores"; +constexpr char const* QUANTIZED_MAX_SCORES = "quantized_max_scores"; + +[[nodiscard]] auto has_extension(std::string_view file_path, std::string_view extension) -> bool +{ + if (file_path.size() <= 4) { + return false; + } + return std::string_view(file_path.data() + file_path.size() - 4, 4) == extension; +} + +[[nodiscard]] auto append_extension(std::string file_path) -> std::string +{ + using namespace std::literals; + if (has_extension(file_path, ".yml"sv)) { + return file_path; + } + return fmt::format("{}{}", file_path, ".yml"sv); +} + +[[nodiscard]] auto resolve_yml(tl::optional const& arg) -> std::string +{ + if (arg) { + return *arg; + } + throw std::runtime_error("Resolving .yml from the current folder not supported yet!"); +} + +[[nodiscard]] auto IndexMetadata::from_file(std::string const& file) -> IndexMetadata +{ + YAML::Node config = YAML::LoadFile(file); + std::vector scores; + if (config[SCORES]) { + // TODO(michal): Once switched to YAML, parse an array. 
+ scores.push_back(PostingFilePaths{.postings = config[SCORES][POSTINGS].as(), + .offsets = config[SCORES][OFFSETS].as()}); + } + return IndexMetadata{ + .basename = file.substr(0, file.size() - 4), + .documents = PostingFilePaths{.postings = config[DOCUMENTS][POSTINGS].as(), + .offsets = config[DOCUMENTS][OFFSETS].as()}, + .frequencies = PostingFilePaths{.postings = config[FREQUENCIES][POSTINGS].as(), + .offsets = config[FREQUENCIES][OFFSETS].as()}, + .scores = std::move(scores), + .document_lengths_path = config[STATS]["document_lengths"].as(), + .avg_document_length = config[STATS]["avg_document_length"].as(), + .term_lexicon = [&]() -> tl::optional { + if (config[LEXICON][TERMS]) { + return config[LEXICON][TERMS].as(); + } + return tl::nullopt; + }(), + .document_lexicon = [&]() -> tl::optional { + if (config[LEXICON][DOCUMENTS]) { + return config[LEXICON][DOCUMENTS].as(); + } + return tl::nullopt; + }(), + .stemmer = [&]() -> tl::optional { + if (config[LEXICON]["stemmer"]) { + return config[LEXICON]["stemmer"].as(); + } + return tl::nullopt; + }(), + .bigrams = [&]() -> tl::optional { + if (config[BIGRAM]) { + std::vector> scores; + if (config[BIGRAM]["scores_0"]) { + scores = {{{.postings = config[BIGRAM]["scores_0"][POSTINGS].as(), + .offsets = config[BIGRAM]["scores_0"][OFFSETS].as()}, + {.postings = config[BIGRAM]["scores_1"][POSTINGS].as(), + .offsets = config[BIGRAM]["scores_1"][OFFSETS].as()}}}; + } + return BigramMetadata{ + .documents = {.postings = config[BIGRAM][DOCUMENTS][POSTINGS].as(), + .offsets = config[BIGRAM][DOCUMENTS][OFFSETS].as()}, + .frequencies = + {{.postings = config[BIGRAM]["frequencies_0"][POSTINGS].as(), + .offsets = config[BIGRAM]["frequencies_0"][OFFSETS].as()}, + {.postings = config[BIGRAM]["frequencies_1"][POSTINGS].as(), + .offsets = config[BIGRAM]["frequencies_1"][OFFSETS].as()}}, + .scores = std::move(scores), + .mapping = config[BIGRAM]["mapping"].as(), + .count = config[BIGRAM]["count"].as()}; + } + return tl::nullopt; + 
}(), + .max_scores = + [&]() { + if (config[MAX_SCORES]) { + return config[MAX_SCORES].as>(); + } + return std::map{}; + }(), + .block_max_scores = + [&]() { + if (config[BLOCK_MAX_SCORES]) { + return config[BLOCK_MAX_SCORES].as>(); + } + return std::map{}; + }(), + .quantized_max_scores = + [&]() { + if (config[QUANTIZED_MAX_SCORES]) { + return config[QUANTIZED_MAX_SCORES].as>(); + } + return std::map{}; + }()}; +} + +void IndexMetadata::update() const { write(fmt::format("{}.yml", get_basename())); } + +void IndexMetadata::write(std::string const& file) const +{ + YAML::Node root; + root[DOCUMENTS][POSTINGS] = documents.postings; + root[DOCUMENTS][OFFSETS] = documents.offsets; + root[FREQUENCIES][POSTINGS] = frequencies.postings; + root[FREQUENCIES][OFFSETS] = frequencies.offsets; + root[STATS]["avg_document_length"] = avg_document_length; + root[STATS]["document_lengths"] = document_lengths_path; + root[LEXICON]["stemmer"] = "porter2"; + if (not scores.empty()) { + root[SCORES][POSTINGS] = scores.front().postings; + root[SCORES][OFFSETS] = scores.front().offsets; + } + if (term_lexicon) { + root[LEXICON][TERMS] = *term_lexicon; + } + if (document_lexicon) { + root[LEXICON][DOCUMENTS] = *document_lexicon; + } + if (bigrams) { + root[BIGRAM][DOCUMENTS][POSTINGS] = bigrams->documents.postings; + root[BIGRAM][DOCUMENTS][OFFSETS] = bigrams->documents.offsets; + root[BIGRAM]["frequencies_0"][POSTINGS] = bigrams->frequencies.first.postings; + root[BIGRAM]["frequencies_0"][OFFSETS] = bigrams->frequencies.first.offsets; + root[BIGRAM]["frequencies_1"][POSTINGS] = bigrams->frequencies.second.postings; + root[BIGRAM]["frequencies_1"][OFFSETS] = bigrams->frequencies.second.offsets; + if (not bigrams->scores.empty()) { + root[BIGRAM]["scores_0"][POSTINGS] = bigrams->scores.front().first.postings; + root[BIGRAM]["scores_0"][OFFSETS] = bigrams->scores.front().first.offsets; + root[BIGRAM]["scores_1"][POSTINGS] = bigrams->scores.front().second.postings; + 
root[BIGRAM]["scores_1"][OFFSETS] = bigrams->scores.front().second.offsets; + } + root[BIGRAM]["mapping"] = bigrams->mapping; + root[BIGRAM]["count"] = bigrams->count; + } + if (not max_scores.empty()) { + for (auto [key, value] : max_scores) { + root[MAX_SCORES][key] = value; + } + } + if (not block_max_scores.empty()) { + for (auto [key, value] : block_max_scores) { + root[BLOCK_MAX_SCORES][key] = value; + } + } + if (not quantized_max_scores.empty()) { + for (auto [key, value] : quantized_max_scores) { + root[QUANTIZED_MAX_SCORES][key] = value; + } + } + std::ofstream fout(file); + fout << root; +} + +[[nodiscard]] auto IndexMetadata::query_parser(tl::optional const& stop_words) const + -> std::function +{ + if (term_lexicon) { + auto term_processor = ::pisa::TermProcessor( + *term_lexicon, + [&]() -> std::optional { + if (stop_words) { + return *stop_words; + } + return std::nullopt; + }(), + [&]() -> std::optional { + if (stemmer) { + return *stemmer; + } + return std::nullopt; + }()); + return [term_processor = std::move(term_processor)](Query& query) { + query.term_ids(parse_query_terms(query.get_raw(), term_processor).terms); + }; + } + throw std::runtime_error("Unable to parse query: undefined term lexicon"); +} + +[[nodiscard]] auto IndexMetadata::get_basename() const -> std::string const& +{ + if (not basename) { + throw std::runtime_error("Unable to resolve index basename"); + } + return basename.value(); +} + +} // namespace pisa::v1 diff --git a/src/v1/intersection.cpp b/src/v1/intersection.cpp new file mode 100644 index 000000000..d3c3b5d38 --- /dev/null +++ b/src/v1/intersection.cpp @@ -0,0 +1,77 @@ +#include +#include + +#include +#include +#include + +#include "io.hpp" +#include "v1/intersection.hpp" + +namespace pisa::v1 { + +auto read_intersections(std::string const& filename) -> std::vector>> +{ + std::ifstream is(filename); + return read_intersections(is); +} + +auto read_intersections(std::istream& is) -> std::vector>> +{ + std::vector>> 
intersections; + ::pisa::io::for_each_line(is, [&](auto const& query_line) { + intersections.emplace_back(); + std::istringstream iss(query_line); + std::transform( + std::istream_iterator(iss), + std::istream_iterator(), + std::back_inserter(intersections.back()), + [&](auto const& n) { + auto bits = std::bitset<64>(std::stoul(n)); + if (bits.count() > 2) { + spdlog::error("Intersections of more than 2 terms not supported yet!"); + std::exit(1); + } + return bits; + }); + }); + return intersections; +} + +[[nodiscard]] auto to_vector(std::bitset<64> const& bits) -> std::vector +{ + std::vector vec; + for (auto idx = 0; idx < bits.size(); idx += 1) { + if (bits.test(idx)) { + vec.push_back(idx); + } + } + return vec; +} + +[[nodiscard]] auto filter_unigrams(std::vector>> const& intersections) + -> std::vector> +{ + return intersections | ranges::views::transform([&](auto&& query_intersections) { + return query_intersections | ranges::views::filter(is_n_gram(1)) + | ranges::views::transform([&](auto bits) { return to_vector(bits)[0]; }) + | ranges::to_vector; + }) + | ranges::to_vector; +} + +[[nodiscard]] auto filter_bigrams(std::vector>> const& intersections) + -> std::vector>> +{ + return intersections | ranges::views::transform([&](auto&& query_intersections) { + return query_intersections | ranges::views::filter(is_n_gram(2)) + | ranges::views::transform([&](auto bits) { + auto vec = to_vector(bits); + return std::make_pair(vec[0], vec[1]); + }) + | ranges::to_vector; + }) + | ranges::to_vector; +} + +} // namespace pisa::v1 diff --git a/src/v1/io.cpp b/src/v1/io.cpp new file mode 100644 index 000000000..82c2db350 --- /dev/null +++ b/src/v1/io.cpp @@ -0,0 +1,21 @@ +#include + +#include "v1/io.hpp" + +namespace pisa::v1 { + +[[nodiscard]] auto load_bytes(std::string const& data_file) -> std::vector +{ + std::vector data; + std::basic_ifstream in(data_file.c_str(), std::ios::binary); + in.seekg(0, std::ios::end); + std::streamsize size = in.tellg(); + in.seekg(0, 
std::ios::beg); + data.resize(size); + if (not in.read(reinterpret_cast(data.data()), size)) { + throw std::runtime_error("Failed reading " + data_file); + } + return data; +} + +} // namespace pisa::v1 diff --git a/src/v1/progress_status.cpp b/src/v1/progress_status.cpp new file mode 100644 index 000000000..fa03c296e --- /dev/null +++ b/src/v1/progress_status.cpp @@ -0,0 +1,76 @@ +#include "v1/progress_status.hpp" + +#include +#include + +#include + +namespace pisa::v1 { + +using std::chrono::hours; +using std::chrono::minutes; +using std::chrono::seconds; + +[[nodiscard]] auto format_interval(std::chrono::seconds time) -> std::string +{ + hours h = std::chrono::duration_cast(time); + minutes m = std::chrono::duration_cast(time - h); + seconds s = std::chrono::duration_cast(time - h - m); + std::ostringstream os; + if (h.count() > 0) { + os << h.count() << "h "; + } + if (m.count() > 0) { + os << m.count() << "m "; + } + os << s.count() << "s"; + return os.str(); +} + +[[nodiscard]] auto estimate_remaining(seconds elapsed, Progress const& progress) -> seconds +{ + return seconds(elapsed.count() * (progress.target - progress.count) / progress.count); +} + +DefaultProgressCallback::DefaultProgressCallback(std::string caption) + : m_caption(std::move(caption)) +{ + if (not m_caption.empty()) { + m_caption.append(": "); + } +} + +void DefaultProgressCallback::operator()(Progress current_progress, + std::chrono::time_point start) +{ + std::size_t progress = (100 * current_progress.count) / current_progress.target; + m_previous = progress; + std::chrono::seconds elapsed = + std::chrono::duration_cast(std::chrono::steady_clock::now() - start); + auto msg = fmt::format( + "{}{}\% [{}]{}", m_caption, progress, format_interval(elapsed), [&]() -> std::string { + if (current_progress.count > 0 && progress < 100) { + return fmt::format(" [<{}]", + format_interval(estimate_remaining(elapsed, current_progress))); + } + return ""; + }()); + std::cerr << '\r' << std::left << 
std::setfill(' ') << std::setw(m_prev_msg_len) << msg; + if (progress == 100) { + std::cerr << '\n'; + } + m_prev_msg_len = msg.size(); +} + +void ProgressStatus::close() +{ + if (m_open) { + m_count = m_target; + m_loop.join(); + m_open = false; + } +} + +ProgressStatus::~ProgressStatus() { close(); } + +} // namespace pisa::v1 diff --git a/src/v1/query.cpp b/src/v1/query.cpp new file mode 100644 index 000000000..93c91cba9 --- /dev/null +++ b/src/v1/query.cpp @@ -0,0 +1,331 @@ +#include + +#include +#include +#include + +#include "v1/query.hpp" + +namespace pisa::v1 { + +using json = nlohmann::json; + +[[nodiscard]] auto filter_unique_terms(Query const& query) -> std::vector +{ + auto terms = query.get_term_ids(); + std::sort(terms.begin(), terms.end()); + terms.erase(std::unique(terms.begin(), terms.end()), terms.end()); + return terms; +} + +void Query::add_selections(gsl::span const> selections) +{ + m_selections = ListSelection{}; + for (auto intersection : selections) { + if (intersection.count() > 2) { + throw std::invalid_argument("Intersections of more than 2 terms not supported yet!"); + } + auto positions = to_vector(intersection); + if (positions.size() == 1) { + m_selections->unigrams.push_back(resolve_term(positions.front())); + } else { + m_selections->bigrams.emplace_back(resolve_term(positions[0]), + resolve_term(positions[1])); + auto& added = m_selections->bigrams.back(); + if (added.first > added.second) { + std::swap(added.first, added.second); + } + } + } + ranges::sort(m_selections->unigrams); + ranges::sort(m_selections->bigrams); +} + +[[nodiscard]] auto Query::filtered_terms(std::bitset<64> selection) const -> std::vector +{ + auto&& term_ids = get_term_ids(); + std::vector terms; + std::vector weights; + for (std::size_t bitpos = 0; bitpos < term_ids.size(); ++bitpos) { + if (((1U << bitpos) & selection.to_ulong()) > 0) { + terms.push_back(term_ids.at(bitpos)); + } + } + return terms; +} + +auto Query::resolve_term(std::size_t pos) -> 
TermId +{ + if (not m_term_ids) { + throw std::runtime_error("Term IDs are not set"); + } + return m_term_ids->term_at_pos(pos); +} + +template +[[nodiscard]] auto get(json const& node, std::string_view field) -> tl::optional +{ + if (auto pos = node.find(field); pos != node.end()) { + return tl::make_optional(pos->get()); + } + return tl::optional{}; +} + +[[nodiscard]] auto Query::from_plain(std::string_view query_string) -> Query +{ + auto colon = std::find(query_string.begin(), query_string.end(), ':'); + tl::optional id; + if (colon != query_string.end()) { + id = std::string(query_string.begin(), colon); + } + auto pos = colon == query_string.end() ? query_string.begin() : std::next(colon); + return Query(std::string(&*pos, std::distance(pos, query_string.end())), std::move(id)); +} + +[[nodiscard]] auto Query::from_json(std::string_view json_string) -> Query +{ + try { + auto query_json = json::parse(json_string); + auto id = get(query_json, "id"); + auto raw_string = get(query_json, "query"); + auto term_ids = get>(query_json, "term_ids"); + Query query = [&]() { + if (raw_string) { + auto query = Query(*raw_string.take(), id); + if (term_ids) { + query.term_ids(*term_ids.take()); + } + return query; + } + if (term_ids) { + return Query(*term_ids.take(), id); + } + throw std::invalid_argument( + "Failed to parse query: must define either raw string or term IDs"); + }(); + if (auto threshold = get(query_json, "threshold"); threshold) { + query.threshold(*threshold); + } + if (auto probability = get(query_json, "probability"); probability) { + query.probability(*probability); + } + if (auto k = get(query_json, "k"); k) { + query.k(*k); + } + if (auto selections = get>(query_json, "selections"); selections) { + std::vector> bitsets; + std::transform(selections->begin(), + selections->end(), + std::back_inserter(bitsets), + [](auto selection) { return std::bitset<64>(selection); }); + query.selections(gsl::span>(bitsets)); + } + return query; + } catch 
(json::parse_error const& err) { + throw std::runtime_error(fmt::format("Failed to parse query: {}", err.what())); + } +} + +[[nodiscard]] auto Query::to_json() const -> std::unique_ptr +{ + auto query = std::make_unique(); + if (m_id) { + (*query)["id"] = *m_id; + } + if (m_raw_string) { + (*query)["query"] = *m_raw_string; + } + if (m_term_ids) { + (*query)["term_ids"] = m_term_ids->get(); + } + if (m_threshold) { + (*query)["threshold"] = *m_threshold; + } + if (m_probability) { + (*query)["probability"] = *m_probability; + } + // TODO(michal) + // tl::optional m_selections{}; + // int m_k = 1000; + return query; +} + +Query::Query(std::vector term_ids, tl::optional id) + : m_term_ids(std::move(term_ids)), m_id(std::move(id)) +{ +} + +Query::Query(std::string query, tl::optional id) + : m_raw_string(std::move(query)), m_id(std::move(id)) +{ +} + +auto Query::term_ids(std::vector term_ids) -> Query& +{ + m_term_ids = TermIdSet(std::move(term_ids)); + return *this; +} + +auto Query::id(std::string id) -> Query& +{ + m_id = std::move(id); + return *this; +} + +auto Query::k(int k) -> Query& +{ + m_k = k; + return *this; +} + +auto Query::selections(gsl::span const> selections) -> Query& +{ + add_selections(selections); + return *this; +} + +auto Query::selections(ListSelection selections) -> Query& +{ + m_selections = std::move(selections); + return *this; +} + +auto Query::threshold(float threshold) -> Query& +{ + m_threshold = threshold; + return *this; +} + +auto Query::probability(float probability) -> Query& +{ + m_probability = probability; + return *this; +} + +auto Query::with_term_ids(std::vector term_ids) && -> Query +{ + this->term_ids(std::move(term_ids)); + return std::move(*this); +} +auto Query::with_id(std::string id) && -> Query +{ + m_id = std::move(id); + return std::move(*this); +} +auto Query::with_k(int k) && -> Query +{ + m_k = k; + return std::move(*this); +} +auto Query::with_selections(gsl::span const> selections) && -> Query +{ + 
this->selections(selections); + return std::move(*this); +} +auto Query::with_selections(ListSelection selections) && -> Query +{ + this->selections(std::move(selections)); + return std::move(*this); +} +auto Query::with_threshold(float threshold) && -> Query +{ + m_threshold = threshold; + return std::move(*this); +} +auto Query::with_probability(float probability) && -> Query +{ + m_probability = probability; + return std::move(*this); +} + +/// Getters +auto Query::term_ids() const -> tl::optional const&> +{ + return m_term_ids.map( + [](auto const& terms) -> std::vector const& { return terms.get(); }); +} +auto Query::id() const -> tl::optional const& { return m_id; } +auto Query::k() const -> int { return m_k; } +auto Query::selections() const -> tl::optional +{ + if (m_selections) { + return *m_selections; + } + return tl::nullopt; +} +auto Query::threshold() const -> tl::optional { return m_threshold; } +auto Query::probability() const -> tl::optional { return m_probability; } +auto Query::raw() const -> tl::optional +{ + if (m_raw_string) { + return *m_raw_string; + } + return tl::nullopt; +} + +/// Throwing getters +auto Query::get_term_ids() const -> std::vector const& +{ + if (not m_term_ids) { + throw std::runtime_error("Term IDs are not set"); + } + return m_term_ids->get(); +} + +auto Query::get_id() const -> std::string const& +{ + if (not m_id) { + throw std::runtime_error("ID is not set"); + } + return *m_id; +} + +auto Query::get_selections() const -> ListSelection const& +{ + if (not m_selections) { + throw std::runtime_error("Selections are not set"); + } + return *m_selections; +} + +auto Query::get_threshold() const -> float +{ + if (not m_threshold) { + throw std::runtime_error("Threshold is not set"); + } + return *m_threshold; +} + +auto Query::get_probability() const -> float +{ + if (not m_probability) { + throw std::runtime_error("Probability is not set"); + } + return *m_probability; +} + +auto Query::get_raw() const -> std::string 
const& +{ + if (not m_raw_string) { + throw std::runtime_error("Raw query string is not set"); + } + return *m_raw_string; +} + +auto Query::sorted_position(TermId term) const -> std::size_t +{ + if (not m_term_ids) { + throw std::runtime_error("Term IDs are not set"); + } + return m_term_ids->sorted_position(term); +} + +auto Query::term_at_pos(std::size_t pos) const -> TermId +{ + if (not m_term_ids) { + throw std::runtime_error("Term IDs are not set"); + } + return m_term_ids->term_at_pos(pos); +} + +} // namespace pisa::v1 diff --git a/src/v1/raw_cursor.cpp b/src/v1/raw_cursor.cpp new file mode 100644 index 000000000..0e2dfaf04 --- /dev/null +++ b/src/v1/raw_cursor.cpp @@ -0,0 +1,18 @@ +#include "v1/raw_cursor.hpp" + +namespace pisa::v1 { + +template struct RawCursor; +template struct RawCursor; +template struct RawCursor; +template struct RawReader; +template struct RawReader; +template struct RawReader; +template struct RawWriter; +template struct RawWriter; +template struct RawWriter; +template struct CursorTraits>; +template struct CursorTraits>; +template struct CursorTraits>; + +} // namespace pisa::v1 diff --git a/src/v1/score_index.cpp b/src/v1/score_index.cpp new file mode 100644 index 000000000..161d2d26e --- /dev/null +++ b/src/v1/score_index.cpp @@ -0,0 +1,166 @@ +#include + +#include + +#include "codec/simdbp.hpp" +#include "score_opt_partition.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/default_index_runner.hpp" +#include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/progress_status.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/score_index.hpp" +#include "v1/scorer/bm25.hpp" + +using pisa::v1::DefaultProgressCallback; +using pisa::v1::IndexMetadata; +using pisa::v1::PostingFilePaths; +using pisa::v1::ProgressStatus; +using pisa::v1::RawWriter; +using pisa::v1::TermId; +using pisa::v1::write_span; + +namespace pisa::v1 { + +auto score_index(IndexMetadata meta, std::size_t 
threads) -> IndexMetadata +{ + auto run = index_runner(meta); + auto const& index_basename = meta.get_basename(); + auto postings_path = fmt::format("{}.bm25", index_basename); + auto offsets_path = fmt::format("{}.bm25_offsets", index_basename); + auto max_scores_path = fmt::format("{}.bm25.maxf", index_basename); + auto quantized_max_scores_path = fmt::format("{}.bm25.maxq", index_basename); + run([&](auto&& index) { + ProgressStatus calc_max_status(index.num_terms(), + DefaultProgressCallback("Calculating max partial score"), + std::chrono::milliseconds(100)); + std::vector max_scores(index.num_terms(), 0.0F); + tbb::task_group group; // TODO(michal): Unused? + auto batch_size = index.num_terms() / threads; + for (auto thread_id = 0; thread_id < threads; thread_id += 1) { + auto first_term = thread_id * batch_size; + auto end_term = + thread_id < threads - 1 ? (thread_id + 1) * batch_size : index.num_terms(); + std::for_each(boost::counting_iterator(first_term), + boost::counting_iterator(end_term), + [&](auto term) { + for_each( + index.scoring_cursor(term, make_bm25(index)), [&](auto&& cursor) { + if (auto score = cursor.payload(); max_scores[term] < score) { + max_scores[term] = score; + } + }); + calc_max_status += 1; + }); + } + group.wait(); + calc_max_status.close(); + auto max_score = *std::max_element(max_scores.begin(), max_scores.end()); + std::cerr << fmt::format("Max partial score is: {}. 
It will be scaled to {}.\n", + max_score, + std::numeric_limits::max()); + + auto quantizer = [&](float score) { + return static_cast(score * std::numeric_limits::max() + / max_score); + }; + std::vector quantized_max_scores; + std::transform(max_scores.begin(), + max_scores.end(), + std::back_inserter(quantized_max_scores), + quantizer); + + ProgressStatus status( + index.num_terms(), DefaultProgressCallback("Scoring"), std::chrono::milliseconds(100)); + std::ofstream score_file_stream(postings_path); + auto offsets = score_index(index, + score_file_stream, + RawWriter{}, + make_bm25(index), + quantizer, + [&]() { status += 1; }); + write_span(gsl::span(offsets), offsets_path); + write_span(gsl::span(max_scores), max_scores_path); + write_span(gsl::span(quantized_max_scores), quantized_max_scores_path); + }); + meta.scores.push_back(PostingFilePaths{.postings = postings_path, .offsets = offsets_path}); + meta.max_scores["bm25"] = max_scores_path; + meta.quantized_max_scores["bm25"] = quantized_max_scores_path; + meta.update(); + return meta; +} + +// TODO: Use multiple threads +auto bm_score_index(IndexMetadata meta, + BlockType block_type, + tl::optional const& clone_path, + std::size_t threads) -> IndexMetadata +{ + auto run = index_runner(meta); + std::string index_basename = clone_path.value_or(std::string(meta.get_basename())); + auto prefix = fmt::format("{}.bm25_block_max", index_basename); + UnigramFilePaths paths{ + .documents = PostingFilePaths{.postings = fmt::format("{}_documents", prefix), + .offsets = fmt::format("{}_document_offsets", prefix)}, + .payloads = PostingFilePaths{.postings = fmt::format("{}.scores", prefix), + .offsets = fmt::format("{}.score_offsets", prefix)}, + }; + run([&](auto&& index) { + auto scorer = make_bm25(index); + ProgressStatus status(index.num_terms(), + DefaultProgressCallback("Calculating max-blocks"), + std::chrono::milliseconds(100)); + std::ofstream document_out(paths.documents.postings); + std::ofstream 
score_out(paths.payloads.postings); + PostingBuilder document_builder(RawWriter{}); + PostingBuilder score_builder(RawWriter{}); + document_builder.write_header(document_out); + score_builder.write_header(score_out); + for (TermId term_id = 0; term_id < index.num_terms(); term_id += 1) { + auto cursor = index.scored_cursor(term_id, scorer); + if (std::holds_alternative(block_type)) { + auto block_size = std::get(block_type).size; + while (not cursor.empty()) { + auto max_score = 0.0F; + auto last_docid = 0; + for (auto idx = 0; idx < block_size && not cursor.empty(); ++idx) { + if (auto score = cursor.payload(); score > max_score) { + max_score = score; + } + last_docid = cursor.value(); + cursor.advance(); + } + document_builder.accumulate(last_docid); + score_builder.accumulate(max_score); + } + document_builder.flush_segment(document_out); + score_builder.flush_segment(score_out); + } else { + auto lambda = std::get(block_type).lambda; + auto eps1 = configuration::get().eps1_wand; + auto eps2 = configuration::get().eps2_wand; + auto vec = collect_with_payload(cursor); + auto partition = + score_opt_partition(vec.begin(), 0, vec.size(), eps1, eps2, lambda); + document_builder.write_segment( + document_out, partition.docids.begin(), partition.docids.end()); + score_builder.write_segment( + score_out, partition.max_values.begin(), partition.max_values.end()); + } + status += 1; + } + write_span(gsl::make_span(document_builder.offsets()), paths.documents.offsets); + write_span(gsl::make_span(score_builder.offsets()), paths.payloads.offsets); + }); + meta.block_max_scores["bm25"] = paths; + if (clone_path) { + meta.write(append_extension(clone_path.value())); + } else { + meta.update(); + } + return meta; +} + +} // namespace pisa::v1 diff --git a/test/temporary_directory.hpp b/test/temporary_directory.hpp index d87e043cc..3a6b1f789 100644 --- a/test/temporary_directory.hpp +++ b/test/temporary_directory.hpp @@ -14,14 +14,15 @@ struct Temporary_Directory { 
boost::filesystem::create_directory(dir_); std::cerr << "Created a tmp dir " << dir_.c_str() << '\n'; } - ~Temporary_Directory() { + ~Temporary_Directory() + { if (boost::filesystem::exists(dir_)) { boost::filesystem::remove_all(dir_); } std::cerr << "Removed a tmp dir " << dir_.c_str() << '\n'; } - [[nodiscard]] auto path() -> boost::filesystem::path const & { return dir_; } + [[nodiscard]] auto path() const -> boost::filesystem::path const & { return dir_; } private: boost::filesystem::path dir_; diff --git a/test/test_bmw_queries.cpp b/test/test_bmw_queries.cpp index c6a213889..6a706c403 100644 --- a/test/test_bmw_queries.cpp +++ b/test/test_bmw_queries.cpp @@ -24,7 +24,7 @@ struct IndexData { static std::unordered_map> data; - IndexData(std::string const &scorer_name, std::unordered_set const &dropped_term_ids) + IndexData(std::string const& scorer_name, std::unordered_set const& dropped_term_ids) : collection(PISA_SOURCE_DIR "/test/test_data/test_collection"), document_sizes(PISA_SOURCE_DIR "/test/test_data/test_collection.sizes"), wdata(document_sizes.begin()->begin(), @@ -36,7 +36,7 @@ struct IndexData { { typename Index::builder builder(collection.num_docs(), params); - for (auto const &plist : collection) { + for (auto const& plist : collection) { uint64_t freqs_sum = std::accumulate(plist.freqs.begin(), plist.freqs.end(), uint64_t(0)); builder.add_posting_list( @@ -45,7 +45,7 @@ struct IndexData { builder.build(index); term_id_vec q; std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); - auto push_query = [&](std::string const &query_line) { + auto push_query = [&](std::string const& query_line) { queries.push_back(parse_query_ids(query_line)); }; io::for_each_line(qfile, push_query); @@ -58,8 +58,8 @@ struct IndexData { std::vector queries; WandTypePlain wdata; - [[nodiscard]] static auto get(std::string const &s_name, - std::unordered_set const &dropped_term_ids) + [[nodiscard]] static auto get(std::string const& s_name, + std::unordered_set 
const& dropped_term_ids) { if (IndexData::data.find(s_name) == IndexData::data.end()) { IndexData::data[s_name] = std::make_unique>(s_name, dropped_term_ids); @@ -72,7 +72,7 @@ template std::unordered_map>> IndexData::data = {}; template -auto test(Wand &wdata, std::string const &s_name) +auto test(Wand& wdata, std::string const& s_name) { std::unordered_set dropped_term_ids; auto data = IndexData::get(s_name, dropped_term_ids); @@ -82,7 +82,7 @@ auto test(Wand &wdata, std::string const &s_name) wand_query wand_q(topk_2); auto scorer = scorer::from_name(s_name, data->wdata); - for (auto const &q : data->queries) { + for (auto const& q : data->queries) { wand_q(make_max_scored_cursors(data->index, data->wdata, *scorer, q), data->index.num_docs()); op_q(make_block_max_scored_cursors(data->index, wdata, *scorer, q), data->index.num_docs()); @@ -101,7 +101,7 @@ auto test(Wand &wdata, std::string const &s_name) TEST_CASE("block_max_wand", "[bmw][query][ranked][integration]", ) { - for (auto &&s_name : {"bm25", "qld"}) { + for (auto&& s_name : {"bm25", "qld"}) { std::unordered_set dropped_term_ids; auto data = IndexData::get(s_name, dropped_term_ids); diff --git a/test/test_data/top10_selections b/test/test_data/top10_selections new file mode 100644 index 000000000..2c9284012 --- /dev/null +++ b/test/test_data/top10_selections @@ -0,0 +1,500 @@ +1 +1 +10 18 24 +1 2 4 +1 +1 +2 +16 32 192 256 +1 6 +2 +1 +1 2 +3 +1 6 +4 8 +1 +1 2 +1 +2 +4 +1 +1 +1 +1 +3 +1 2 +9 +2 +1 2 +1 12 +3 9 12 +5 9 12 33 40 +6 +1 6 +2 +1 +1 2 +1 +2 +2 9 40 72 96 136 160 192 +12 +3 5 6 +1 +1 2 +3 5 6 +2 5 +1 2 4 +1 2 4 +1 8 +1 6 +1 2 +2 +1 +1 10 66 72 +1 +4 +1 6 +1 +3 5 6 17 18 20 +2 +5 8 17 20 +1 2 +3 5 16 64 129 132 640 +1 2 +1 2 +1 2 4 +9 16 33 65 72 96 129 136 160 192 +5 17 20 32 +2 4 +5 6 9 +3 5 6 +5 9 12 +1 2 +1 2 +1 2 +1 +9 64 +2 5 +1 2 +8 32 +1 2 +2 4 +2 +1 4 +1 2 +1 32 64 +1 +1 4 +4 9 17 18 24 +8 +5 9 12 17 20 24 +1 2 +3 4 9 10 18 +1 10 34 40 +2 4 +10 +1 4 +1 6 +1 +3 6 +2 +1 2 +12 32 68 72 +2 
4 +4 +1 +1 2 +36 64 +2 +1 8 +1 2 +3 +1 6 +1 2 +1 2 +5 +3 9 10 17 18 24 129 130 136 144 +4 10 +1 2 4 +3 5 12 +1 2 +1 10 18 24 +3 6 9 +1 2 4 +1 2 +1 2 +2 4 8 +4 64 136 +1 6 +1 +1 2 +2 9 17 +1 6 +4 +3 4 +2 4 33 +3 9 17 18 +2 +2 5 +2 +1 4 +2 9 +3 4 9 10 +9 +3 4 +4 8 +1 2 +3 6 +1 +1 2 +2 9 +1 +3 +1 12 +1 2 +3 5 +6 10 12 +1 2 +3 5 6 +5 6 +1 +1 4 +1 2 +5 6 8 +1 64 128 +1 +3 6 +1 2 12 +1 4 +2 8 +1 +1 2 +1 8 16 +1 +1 2 +1 2 +2 +5 +3 +2 +1 +1 +1 2 +4 16 +16 32 +1 12 +4 9 10 18 24 34 40 48 +1 12 +1 2 +3 +1 2 +1 8 +1 2 +1 +20 +1 +1 2 +1 2 +1 +1 2 +1 2 4 +1 2 +2 5 +3 5 6 +1 4 +1 6 +2 9 20 24 +1 10 18 24 34 40 48 +3 4 +1 2 +1 4 32 +1 2 +1 2 +2 +3 5 +10 12 +1 +1 2 +1 2 +3 +1 +2 4 40 +1 6 +1 6 +1 2 +3 9 10 +3 4 9 +3 5 6 16 +4 +3 4 +6 +1 2 +1 +3 5 6 +1 +3 5 6 33 34 36 +1 2 +1 +6 +3 4 9 10 33 34 40 +1 +1 +4 9 10 +1 +2 +1 2 +1 2 4 +2 +6 10 12 18 +1 +5 9 12 +2 +1 +2 +1 +1 +1 +5 +2 +2 4 +1 2 4 +1 +1 2 +48 +1 6 +3 +1 +1 2 +1 +1 2 +1 2 4 +1 +5 16 +2 +3 +3 5 6 8 +2 +1 2 +3 +1 +1 +3 +20 +1 8 20 36 +3 9 10 17 24 +3 4 +4 8 +1 2 4 +3 +1 8 +3 9 10 16 +4 +1 2 +1 2 +1 2 +3 +5 9 +3 8 +1 4 +1 2 12 +1 2 4 24 +2 +2 +1 2 +1 4 +2 9 12 20 33 36 40 +1 2 +3 5 6 +1 +4 9 17 24 +1 +3 +1 6 +6 8 +2 +1 +5 9 12 +3 5 6 +1 2 4 +1 8 18 +3 4 +1 2 +1 2 +6 10 12 18 20 24 +3 +3 +3 17 18 +2 +2 +1 +1 12 +1 2 +4 +9 16 65 72 129 132 136 192 +3 +3 5 +1 2 4 +1 6 8 18 20 +3 +2 +1 6 +2 4 +2 +1 2 4 +1 2 4 +1 2 +1 2 +1 +3 +1 +3 +1 2 +5 +1 +1 2 +5 +5 +1 2 +1 +1 2 +33 48 +3 9 10 +1 4 +1 2 +1 +3 8 17 18 +2 +1 2 +1 2 +1 +1 +1 +1 2 +3 5 6 +1 2 +1 +2 +1 +1 2 +1 +1 8 18 +6 10 12 +5 +3 4 +3 4 +1 +3 5 6 +5 9 12 17 20 24 +2 12 +1 20 64 +2 +2 20 24 +4 10 18 24 +1 2 +1 2 8 +4 10 18 24 +1 +3 9 10 +1 +12 20 24 36 40 48 68 72 80 96 +2 +1 2 4 +3 17 18 65 66 80 +1 +2 +1 12 20 24 68 72 80 +1 +1 +3 +1 +1 6 +2 +17 20 32 +4 9 33 65 72 96 +1 +1 2 +2 +3 +1 2 +1 +1 6 8 +2 +4 9 10 +5 8 65 68 +1 6 +5 17 18 20 33 36 48 +1 12 +1 4 +2 +1 +1 2 +4 +1 2 +2 4 8 +1 40 +1 2 +1 +1 2 +2 +3 5 6 +24 36 40 48 260 264 272 288 +6 +4 8 +5 9 12 16 +1 2 +1 24 +1 2 +1 2 +2 
12 16 +1 2 +1 +1 2 12 40 48 +8 +3 9 10 +1 2 12 +1 +5 8 +3 4 9 10 +2 4 9 33 40 +1 2 +3 +1 4 +1 2 4 +4 9 +4 129 +1 20 +1 2 +1 6 +1 2 +1 2 +3 5 +3 5 6 +1 +1 2 +3 5 +2 +1 +8 16 32 +2 +1 +5 9 10 17 18 24 +1 2 +6 +5 +1 2 +3 +1 2 +1 +5 +1 6 +1 2 +1 2 +1 +1 +3 +1 +1 +6 10 12 +1 +1 diff --git a/test/test_data/top10_selections_unigram b/test/test_data/top10_selections_unigram new file mode 100644 index 000000000..b6551bb17 --- /dev/null +++ b/test/test_data/top10_selections_unigram @@ -0,0 +1,500 @@ +1 +1 +8 16 +1 2 4 +1 +1 +2 +16 32 64 256 +1 2 +2 +1 +1 2 +2 +1 4 +4 8 +1 +1 2 +1 +2 +4 +1 +1 +1 +1 +1 +1 2 +8 +2 +1 2 +1 4 +2 8 +1 8 +2 +1 2 +2 4 +1 +1 2 +1 +2 +2 8 32 128 +4 +1 4 +1 +1 2 +2 4 +2 4 +1 2 4 +1 2 4 +1 4 8 +1 4 +1 2 +2 +1 +1 8 64 +1 +4 +1 4 +1 +1 2 4 +2 +4 8 16 +1 2 +1 2 16 64 +1 2 +1 2 +1 2 4 +1 16 64 128 +1 16 32 +2 4 +4 8 +2 4 +4 8 +1 2 +1 2 +1 2 +1 +8 64 +1 2 +1 2 +8 32 +1 2 +2 4 +2 +1 4 +1 2 +1 32 64 +1 +1 4 +4 8 16 +8 +1 4 16 +1 2 +1 2 4 16 +1 2 8 +2 4 +2 +1 4 +1 2 +1 +1 2 +2 +1 2 +4 32 64 +2 4 +4 +1 +1 2 +4 64 +2 +1 8 +1 2 +1 +1 4 +1 2 +1 2 +1 +1 2 16 128 +4 8 +1 2 4 +2 4 +1 2 +1 2 8 +2 8 +1 2 4 +1 2 +1 2 +2 4 8 +4 64 128 +1 2 +1 +1 2 +1 2 +1 2 +4 +1 4 +1 2 4 +1 2 +2 +1 2 +2 +1 4 +2 8 +1 4 8 +1 +2 4 +4 8 +1 2 +1 4 +1 +1 2 +2 8 +1 +2 4 +1 4 +1 2 +1 2 +4 8 +1 2 +2 4 +2 4 +1 4 +1 4 +1 2 +4 8 +1 64 128 +1 +2 4 +1 2 4 +1 4 +2 8 +1 +1 2 +1 8 16 +1 +1 2 +1 2 +2 +1 2 +1 +2 +1 +1 +1 2 +4 16 +16 32 +1 4 +4 8 16 32 +1 8 +1 2 +1 +1 2 +1 8 +1 2 +1 +4 +1 +1 2 +1 2 +1 +1 2 +1 2 4 +1 2 +1 2 +1 2 +1 4 +1 4 +2 8 16 +1 2 8 16 +1 4 +1 2 +1 4 32 +1 2 +1 2 +2 +1 4 +8 +1 +1 2 +1 2 +1 +1 +2 4 8 +1 2 +1 2 +1 2 +1 8 +1 2 4 +1 2 16 +4 +1 4 +2 +1 2 +1 +1 2 +1 +1 4 32 +1 2 +1 +4 +2 4 8 32 +1 +1 +1 4 8 +1 +2 +1 2 +1 2 4 +2 +2 8 +1 +1 4 +2 +1 +2 +1 +1 +1 +4 +2 +2 4 +1 2 4 +1 +1 2 +16 +1 4 +2 +1 +1 2 +1 +1 2 +1 2 4 +1 +1 16 +2 +1 2 +1 4 8 +2 +1 2 +16 +1 +1 +2 +16 +1 4 8 +1 8 +1 4 +4 8 +1 2 4 +1 +1 8 +1 8 16 +4 +1 2 +1 2 +1 2 +1 +1 +1 2 8 +1 4 +1 2 8 +1 2 4 16 +2 +2 +1 2 +1 4 +1 2 4 8 +1 2 
+1 4 +1 +1 4 16 +1 +1 +1 4 +2 8 +1 2 +1 +1 8 +2 4 +1 2 4 +1 2 8 +1 4 +1 2 +1 2 +2 4 16 +1 +1 +1 2 +2 +2 +1 +1 4 +1 2 +4 +8 16 64 128 +2 +1 4 +1 2 4 +1 2 8 +2 +2 +1 4 +2 4 +2 +1 2 4 +1 2 4 +1 2 +1 2 +1 +1 +1 +2 +1 2 +1 +1 +1 2 +1 +1 +1 2 +1 +1 2 +16 32 +1 8 +1 4 +1 2 +1 +2 8 16 +2 +1 2 +1 2 +1 +1 +1 +1 2 +1 2 +1 2 +1 +2 +1 +1 2 +1 +1 8 16 +2 8 +1 +1 4 +1 4 +1 +1 4 +4 8 16 +2 4 +1 4 64 +2 +2 16 +2 4 16 +1 2 +1 2 8 +2 4 8 +1 +1 2 +1 +4 8 16 64 +2 +1 2 4 +2 16 64 +1 +2 +1 4 16 64 +1 +1 +1 +1 +1 2 +2 +16 32 +1 4 64 +1 +1 2 +2 4 +2 +1 2 +1 +1 4 8 +2 +1 2 4 +4 8 64 +1 2 +1 16 32 +1 8 +1 4 +2 +1 +1 2 +4 +1 2 +2 4 8 +1 32 +1 2 +1 +1 2 +2 +2 4 +8 32 256 +1 +4 8 +4 8 16 +1 2 +1 16 +1 2 +1 2 +2 4 16 +1 2 +1 +1 2 4 16 32 +8 +2 8 +1 2 4 +1 +1 8 +4 8 +1 2 4 32 +1 2 +2 +1 4 +1 2 4 +4 8 +1 4 +1 4 +1 2 +1 2 +1 2 +1 2 +1 2 +2 4 +1 +1 2 +1 +2 +1 +8 16 32 +2 +1 +1 2 8 +1 2 +2 +4 +1 2 +2 +1 2 +1 +4 +1 4 +1 2 +1 2 +1 +1 +2 +1 +1 +1 2 8 +1 +1 diff --git a/test/test_data/top10_thresholds b/test/test_data/top10_thresholds new file mode 100644 index 000000000..e90c41493 --- /dev/null +++ b/test/test_data/top10_thresholds @@ -0,0 +1,500 @@ +8.1191 +6.40702 +8.53867 +7.33842 +4.62827 +7.43321 +6.07675 +9.76084 +8.25249 +6.33509 +5.94863 +8.06593 +6.77912 +8.00327 +8.06362 +6.78048 +7.16929 +6.22611 +4.86989 +8.08515 +5.00578 +5.72889 +3.3509 +6.03476 +7.0362 +7.55356 +10.4887 +7.63555 +5.8102 +9.09428 +9.71736 +10.1705 +7.12889 +6.97783 +6.79471 +5.62071 +0 +0 +4.61359 +16.0051 +7.7722 +10.6181 +7.65763 +2.58132 +10.6563 +9.43986 +9.1898 +6.66688 +6.36398 +5.66858 +8.04628 +7.27758 +8.27132 +16.886 +6.54159 +6.16754 +8.786 +0 +9.97472 +6.3738 +9.62928 +5.61858 +15.1814 +7.34612 +7.48248 +1.88119e-06 +9.75801 +7.72713 +8.63993 +9.11045 +11.1941 +12.6965 +9.2991 +6.23525 +6.18397 +7.16673 +7.3081 +9.34341 +7.56361 +8.35106 +6.32008 +9.17452 +7.44254 +3.29669 +4.81384 +13.2702 +2.66406 +8.50921 +9.88172 +7.06583 +11.3807 +2.72187 +11.1195 +8.54663 +6.3199 +9.95984 +2.11432 +9.89071 +0 +9.52282 
+8.62435 +6.07234 +9.51326 +8.03819 +8.19029 +4.81047 +7.06047 +15.864 +6.73753 +6.07235 +7.07044 +9.949 +7.07569 +1.3461 +6.15922 +7.38229 +9.95886 +6.96017 +7.56859 +8.89754 +7.31538 +10.0964 +10.802 +5.87512 +4.65612 +7.31016 +8.06353 +14.7184 +5.10633 +6.54296 +7.29809 +8.5369 +8.21982 +6.63908 +7.93862 +8.48126 +9.1978 +6.47504 +9.09038 +6.65207 +7.87977 +8.91329 +8.21564 +7.90121 +7.21129 +14.6396 +7.7758 +11.0978 +6.53356 +5.61717 +8.16117 +7.21347 +9.12 +8.98664 +6.14115 +9.136 +8.8644 +4.25504 +9.91241 +8.2076 +8.59041 +6.81703 +6.4299 +9.45913 +11.1704 +7.23268 +7.89493 +8.27872 +6.45686 +10.1443 +7.79637 +5.45131 +6.81683 +3.88536 +7.02361 +7.18851 +7.42748 +8.95018 +11.2902 +8.81273 +6.72542 +8.52207 +6.91805 +7.30957 +7.42604 +8.68603 +12.3029 +9.5102 +6.58983 +9.3897 +0 +7.50808 +6.99158 +0 +6.10725 +8.43093 +6.20851 +0 +6.88064 +5.59466 +8.55461 +3.07427 +6.58394 +8.60744 +3.13404 +7.16225 +10.3509 +9.13109 +8.54642 +7.0724 +7.90035 +6.53838 +7.0606 +6.20906 +8.03361 +8.33321 +6.2458 +5.85291 +4.86942 +6.39161 +8.01369 +9.68012 +8.3957 +9.56297 +5.30648 +7.9272 +8.41178 +9.78377 +5.95837 +7.02013 +7.29275 +7.72767 +7.13853 +7.19798 +6.81257 +9.9947 +6.52238 +1.84112 +8.02057 +9.75506 +4.24206 +8.10603 +7.62294 +6.64202 +8.31228 +6.39894 +7.64641 +7.18985 +10.9307 +0 +9.35913 +7.14301 +2.62868 +7.82687 +7.78378 +4.47242 +5.99799 +8.01446 +7.29016 +6.96249 +7.31208 +7.26219 +1.02171 +7.90062 +8.52454 +7.79633 +0 +6.39334 +3.65948 +6.13462 +8.53574 +2.42495 +8.34661 +6.5509 +9.19673 +8.72243 +0.975942 +2.97355 +17.1217 +8.16323 +5.56541 +6.57861 +5.56972 +9.25555 +8.9732 +9.58803 +7.21451 +7.85729 +7.96213 +9.15531 +8.90842 +7.63755 +4.98901 +6.89196 +8.80033 +3.0826 +8.63325 +8.71573 +2.71681 +9.0253 +7.51434 +8.20914 +7.16011 +5.08073 +7.0777 +12.4818 +5.96502 +10.091 +0 +8.57939 +7.44738 +8.60938 +6.30986 +15.5856 +7.38011 +0 +10.2552 +8.81642 +2.74022 +7.3057 +8.29425 +5.7883 +6.793 +12.642 +7.0631 +7.46704 +13.6006 +8.07774 +6.36285 +0 +7.83312 
+6.49256 +10.4431 +11.8368 +7.7008 +7.39802 +5.43342 +8.84415 +9.69069 +3.62001 +7.9487 +6.91921 +9.79507 +6.22026 +6.91993 +3.88733 +2.97617 +8.70682 +5.30657 +6.33269 +10.1033 +6.96088 +4.79799 +5.73728 +7.03829 +8.05473 +7.29462 +5.77349 +2.82017 +6.793 +14.9411 +9.10034 +6.73388 +6.46703 +7.13749 +9.87858 +7.78961 +6.41109 +6.69017 +0 +8.75548 +7.01149 +7.10874 +8.80585 +7.39948 +6.35082 +5.5849 +7.18396 +0.990798 +5.72014 +8.44618 +8.44117 +4.89387 +9.79468 +7.04739 +8.89953 +8.64338 +12.7184 +10.9665 +6.86001 +8.71301 +6.50483 +10.4117 +8.92076 +7.14701 +9.26171 +7.06198 +7.50738 +7.10056 +9.75535 +7.04357 +7.95306 +10.1646 +0 +5.94793 +7.94278 +6.07347 +6.6129 +5.74331 +6.27284 +7.83404 +5.58244 +8.4209 +8.93761 +0 +5.68759 +8.45663 +7.6864 +7.32927 +6.34879 +7.65884 +6.00451 +8.86703 +9.56031 +8.25177 +17.2111 +7.36363 +7.45831 +7.88611 +0 +7.44071 +5.64288 +5.34061 +6.47387 +8.15098 +7.30174 +5.3787 +7.22029 +6.87507 +6.78256 +13.5872 +9.50327 +7.79518 +12.4835 +2.97089 +7.6197 +6.54403 +8.7382 +8.07254 +8.09486 +6.58483 +10.2843 +6.67195 +9.55084 +6.63688 +0 +12.5787 +9.6835 +9.46256 +4.68514 +9.70017 +7.34846 +6.35295 +7.57669 +16.6089 +9.65398 +7.08425 +8.39853 +0 +7.6837 +8.62668 +10.1491 +0 +5.10292 +8.00893 +4.33105 +0 +7.55501 +8.08855 +5.86112 +11.9459 +6.54911 +9.99107 +8.97999 +7.07643 +7.32234 +5.08513 +6.66545 +11.1706 +8.37998 +7.17167 +6.5358 +5.87431 +8.89896 +6.3356 +7.97252 +6.52537 +9.99582 +6.36973 +0 diff --git a/test/test_forward_index_builder.cpp b/test/test_forward_index_builder.cpp index 212f54e09..cbd5ec1b1 100644 --- a/test/test_forward_index_builder.cpp +++ b/test/test_forward_index_builder.cpp @@ -6,9 +6,11 @@ #include #include #include +#include #include #include +#include "binary_collection.hpp" #include "filesystem.hpp" #include "forward_index_builder.hpp" #include "parsing/html.hpp" @@ -57,7 +59,7 @@ TEST_CASE("Write header", "[parsing][forward_index]") } } -[[nodiscard]] std::vector load_lines(std::istream &is) 
+[[nodiscard]] std::vector load_lines(std::istream& is) { std::string line; std::vector vec; @@ -67,22 +69,22 @@ TEST_CASE("Write header", "[parsing][forward_index]") return vec; } -[[nodiscard]] std::vector load_lines(std::string const &filename) +[[nodiscard]] std::vector load_lines(std::string const& filename) { std::ifstream is(filename); return load_lines(is); } template -void write_lines(std::ostream &os, gsl::span &&elements) +void write_lines(std::ostream& os, gsl::span&& elements) { - for (auto const &element : elements) { + for (auto const& element : elements) { os << element << '\n'; } } template -void write_lines(std::string const &filename, gsl::span &&elements) +void write_lines(std::string const& filename, gsl::span&& elements) { std::ofstream os(filename); write_lines(os, std::forward>(elements)); @@ -90,7 +92,7 @@ void write_lines(std::string const &filename, gsl::span &&elements) TEST_CASE("Build forward index batch", "[parsing][forward_index]") { - auto identity = [](std::string const &term) -> std::string { return term; }; + auto identity = [](std::string const& term) -> std::string { return term; }; GIVEN("a few test records") { @@ -148,10 +150,10 @@ TEST_CASE("Build forward index batch", "[parsing][forward_index]") } } -void write_batch(std::string const &basename, - std::vector const &documents, - std::vector const &terms, - std::vector> const &collection) +void write_batch(std::string const& basename, + std::vector const& documents, + std::vector const& terms, + std::vector> const& collection) { std::string document_file = basename + ".documents"; std::string term_file = basename + ".terms"; @@ -159,7 +161,7 @@ void write_batch(std::string const &basename, write_lines(term_file, gsl::make_span(terms)); std::ofstream os(basename); Forward_Index_Builder::write_header(os, collection.size()); - for (auto const &seq : collection) { + for (auto const& seq : collection) { Forward_Index_Builder::write_document(os, seq.begin(), seq.end()); } } @@ 
-270,7 +272,7 @@ TEST_CASE("Merge forward index batches", "[parsing][forward_index]") TEST_CASE("Parse HTML content", "[parsing][forward_index][unit]") { std::vector vec; - auto map_word = [&](std::string &&word) { vec.push_back(word); }; + auto map_word = [&](std::string&& word) { vec.push_back(word); }; SECTION("empty") { parse_html_content( @@ -300,7 +302,7 @@ TEST_CASE("Parse HTML content", "[parsing][forward_index][unit]") } } -[[nodiscard]] auto load_term_map(std::string const &basename) -> std::vector +[[nodiscard]] auto load_term_map(std::string const& basename) -> std::vector { std::vector map; std::ifstream is(basename + ".terms"); @@ -314,7 +316,7 @@ TEST_CASE("Parse HTML content", "[parsing][forward_index][unit]") TEST_CASE("Build forward index", "[parsing][forward_index][integration]") { tbb::task_scheduler_init init; - auto next_record = [](std::istream &in) -> std::optional { + auto next_record = [](std::istream& in) -> std::optional { Plaintext_Record record; if (in >> record) { return Document_Record(record.trecid(), record.content(), record.url()); @@ -340,7 +342,7 @@ TEST_CASE("Build forward index", "[parsing][forward_index][integration]") is, output, next_record, - [](std::string &&term) -> std::string { return std::forward(term); }, + [](std::string&& term) -> std::string { return std::forward(term); }, parse_plaintext_content, batch_size, thread_count); @@ -348,8 +350,8 @@ TEST_CASE("Build forward index", "[parsing][forward_index][integration]") THEN("The collection mapped to terms matches input") { auto term_map = load_term_map(output); - auto term_lexicon_buffer = Payload_Vector_Buffer::from_file(output + ".termlex"); - auto term_lexicon = Payload_Vector(term_lexicon_buffer); + mio::mmap_source m((output + ".termlex").c_str()); + auto term_lexicon = Payload_Vector<>::from(m); REQUIRE(std::vector(term_lexicon.begin(), term_lexicon.end()) == term_map); binary_collection coll((output).c_str()); @@ -372,7 +374,7 @@ TEST_CASE("Build forward 
index", "[parsing][forward_index][integration]") REQUIRE(produced_body == original_body); ++seq_iter; } - auto batch_files = ls(dir, [](auto const &filename) { + auto batch_files = ls(dir, [](auto const& filename) { return filename.find("batch") != std::string::npos; }); REQUIRE(batch_files.empty()); @@ -380,8 +382,8 @@ TEST_CASE("Build forward index", "[parsing][forward_index][integration]") AND_THEN("Document lexicon contains the same titles as text file") { auto documents = io::read_string_vector(output + ".documents"); - auto doc_lexicon_buffer = Payload_Vector_Buffer::from_file(output + ".doclex"); - auto doc_lexicon = Payload_Vector(doc_lexicon_buffer); + mio::mmap_source m((output + ".doclex").c_str()); + auto doc_lexicon = Payload_Vector<>::from(m); REQUIRE(std::vector(doc_lexicon.begin(), doc_lexicon.end()) == documents); } diff --git a/test/test_ranked_queries.cpp b/test/test_ranked_queries.cpp index 7cc546903..14eecaabf 100644 --- a/test/test_ranked_queries.cpp +++ b/test/test_ranked_queries.cpp @@ -21,7 +21,7 @@ struct IndexData { static std::unordered_map> data; - IndexData(std::string const &scorer_name, std::unordered_set const &dropped_term_ids) + IndexData(std::string const& scorer_name, std::unordered_set const& dropped_term_ids) : collection(PISA_SOURCE_DIR "/test/test_data/test_collection"), document_sizes(PISA_SOURCE_DIR "/test/test_data/test_collection.sizes"), wdata(document_sizes.begin()->begin(), @@ -34,7 +34,7 @@ struct IndexData { { tbb::task_scheduler_init init; typename Index::builder builder(collection.num_docs(), params); - for (auto const &plist : collection) { + for (auto const& plist : collection) { uint64_t freqs_sum = std::accumulate(plist.freqs.begin(), plist.freqs.end(), uint64_t(0)); builder.add_posting_list( @@ -44,7 +44,7 @@ struct IndexData { term_id_vec q; std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); - auto push_query = [&](std::string const &query_line) { + auto push_query = [&](std::string const& 
query_line) { queries.push_back(parse_query_ids(query_line)); }; io::for_each_line(qfile, push_query); @@ -52,8 +52,8 @@ struct IndexData { std::string t; } - [[nodiscard]] static auto get(std::string const &s_name, - std::unordered_set const &dropped_term_ids) + [[nodiscard]] static auto get(std::string const& s_name, + std::unordered_set const& dropped_term_ids) { if (IndexData::data.find(s_name) == IndexData::data.end()) { IndexData::data[s_name] = std::make_unique>(s_name, dropped_term_ids); @@ -78,7 +78,7 @@ class ranked_or_taat_query_acc : public ranked_or_taat_query { using ranked_or_taat_query::ranked_or_taat_query; template - void operator()(CursorRange &&cursors, uint64_t max_docid) + void operator()(CursorRange&& cursors, uint64_t max_docid) { Acc accumulator(max_docid); ranked_or_taat_query::operator()(cursors, max_docid, accumulator); @@ -91,7 +91,7 @@ class range_query_128 : public range_query { using range_query::range_query; template - void operator()(CursorRange &&cursors, uint64_t max_docid) + void operator()(CursorRange&& cursors, uint64_t max_docid) { range_query::operator()(cursors, max_docid, 128); } @@ -112,7 +112,7 @@ TEMPLATE_TEST_CASE("Ranked query test", range_query_128, range_query_128) { - for (auto &&s_name : {"bm25", "qld"}) { + for (auto&& s_name : {"bm25", "qld"}) { std::unordered_set dropped_term_ids; auto data = IndexData::get(s_name, dropped_term_ids); topk_queue topk_1(10); @@ -121,7 +121,7 @@ TEMPLATE_TEST_CASE("Ranked query test", ranked_or_query or_q(topk_2); auto scorer = scorer::from_name(s_name, data->wdata); - for (auto const &q : data->queries) { + for (auto const& q : data->queries) { or_q(make_scored_cursors(data->index, *scorer, q), data->index.num_docs()); op_q(make_block_max_scored_cursors(data->index, data->wdata, *scorer, q), data->index.num_docs()); @@ -142,7 +142,7 @@ TEMPLATE_TEST_CASE("Ranked AND query test", "[query][ranked][integration]", block_max_ranked_and_query) { - for (auto &&s_name : {"bm25", "qld"}) 
{ + for (auto&& s_name : {"bm25", "qld"}) { std::unordered_set dropped_term_ids; auto data = IndexData::get(s_name, dropped_term_ids); topk_queue topk_1(10); @@ -152,7 +152,7 @@ TEMPLATE_TEST_CASE("Ranked AND query test", auto scorer = scorer::from_name(s_name, data->wdata); - for (auto const &q : data->queries) { + for (auto const& q : data->queries) { and_q(make_scored_cursors(data->index, *scorer, q), data->index.num_docs()); op_q(make_block_max_scored_cursors(data->index, data->wdata, *scorer, q), data->index.num_docs()); @@ -171,7 +171,7 @@ TEMPLATE_TEST_CASE("Ranked AND query test", TEST_CASE("Top k") { - for (auto &&s_name : {"bm25", "qld"}) { + for (auto&& s_name : {"bm25", "qld"}) { std::unordered_set dropped_term_ids; auto data = IndexData::get(s_name, dropped_term_ids); topk_queue topk_1(10); @@ -181,7 +181,7 @@ TEST_CASE("Top k") auto scorer = scorer::from_name(s_name, data->wdata); - for (auto const &q : data->queries) { + for (auto const& q : data->queries) { or_10(make_scored_cursors(data->index, *scorer, q), data->index.num_docs()); or_1(make_scored_cursors(data->index, *scorer, q), data->index.num_docs()); topk_1.finalize(); diff --git a/test/test_wand_data.cpp b/test/test_wand_data.cpp index 3601b13a6..ae559824f 100644 --- a/test/test_wand_data.cpp +++ b/test/test_wand_data.cpp @@ -14,6 +14,7 @@ #include "query/queries.hpp" #include "wand_data.hpp" #include "wand_data_range.hpp" +#include "wand_data_raw.hpp" #include "scorer/scorer.hpp" @@ -42,12 +43,12 @@ TEST_CASE("wand_data_range") SECTION("Precomputed block-max scores") { size_t term_id = 0; - for (auto const &seq : collection) { + for (auto const& seq : collection) { if (seq.docs.size() >= 1024) { auto max = wdata_range.max_term_weight(term_id); auto w = wdata_range.getenum(term_id); auto s = scorer->term_scorer(term_id); - for (auto &&[docid, freq] : ranges::views::zip(seq.docs, seq.freqs)) { + for (auto&& [docid, freq] : ranges::views::zip(seq.docs, seq.freqs)) { float score = s(docid, 
freq); w.next_geq(docid); CHECKED_ELSE(w.score() >= score) @@ -65,7 +66,7 @@ TEST_CASE("wand_data_range") index_type index; global_parameters params; index_type::builder builder(collection.num_docs(), params); - for (auto const &plist : collection) { + for (auto const& plist : collection) { uint64_t freqs_sum = std::accumulate(plist.freqs.begin(), plist.freqs.end(), uint64_t(0)); builder.add_posting_list( plist.docs.size(), plist.docs.begin(), plist.freqs.begin(), freqs_sum); @@ -75,15 +76,15 @@ TEST_CASE("wand_data_range") SECTION("Compute at run time") { size_t term_id = 0; - for (auto const &seq : collection) { + for (auto const& seq : collection) { auto list = index[term_id]; if (seq.docs.size() < 1024) { auto max = wdata_range.max_term_weight(term_id); - auto &w = wdata_range.get_block_wand(); + auto& w = wdata_range.get_block_wand(); auto s = scorer->term_scorer(term_id); const mapper::mappable_vector bm = w.compute_block_max_scores(list, s); WandTypeRange::enumerator we(0, bm); - for (auto &&[pos, docid, freq] : + for (auto&& [pos, docid, freq] : ranges::views::zip(ranges::views::iota(0), seq.docs, seq.freqs)) { float score = s(docid, freq); we.next_geq(docid); @@ -103,7 +104,7 @@ TEST_CASE("wand_data_range") { size_t i = 0; std::vector enums; - for (auto const &seq : collection) { + for (auto const& seq : collection) { if (seq.docs.size() >= 1024) { enums.push_back(wdata_range.getenum(i)); } @@ -116,7 +117,7 @@ TEST_CASE("wand_data_range") for (int i = 0; i < live_blocks.size(); ++i) { if (live_blocks[i] == 0) { - for (auto &&e : enums) { + for (auto&& e : enums) { e.next_geq(i * 64); REQUIRE(e.score() == 0); } diff --git a/test/v1/CMakeLists.txt b/test/v1/CMakeLists.txt new file mode 100644 index 000000000..0ae89c35d --- /dev/null +++ b/test/v1/CMakeLists.txt @@ -0,0 +1,28 @@ +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/test) + +list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../external/Catch2/contrib) +include(CTest) 
+include(Catch) + +file(GLOB TEST_SOURCES test_*.cpp) +foreach(TEST_SRC ${TEST_SOURCES}) + get_filename_component (TEST_SRC_NAME ${TEST_SRC} NAME_WE) + add_executable(${TEST_SRC_NAME} ${TEST_SRC}) + target_link_libraries(${TEST_SRC_NAME} + pisa + Catch2 + rapidcheck + ) + catch_discover_tests(${TEST_SRC_NAME} TEST_PREFIX "${TEST_SRC_NAME}:") + + # enable code coverage + add_coverage(${TEST_SRC_NAME}) +endforeach(TEST_SRC) + +target_link_libraries(test_v1 + pisa + Catch2 + optional +) + +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/../test_data DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/test/v1/index_fixture.hpp b/test/v1/index_fixture.hpp new file mode 100644 index 000000000..2e22e15a5 --- /dev/null +++ b/test/v1/index_fixture.hpp @@ -0,0 +1,106 @@ +#pragma once + +#include +#include + +#include "../temporary_directory.hpp" +#include "pisa_config.hpp" +#include "query/queries.hpp" +#include "v1/index.hpp" +#include "v1/index_builder.hpp" +#include "v1/intersection.hpp" +#include "v1/io.hpp" +#include "v1/query.hpp" +#include "v1/score_index.hpp" + +namespace v1 = pisa::v1; + +[[nodiscard]] inline auto test_queries() -> std::vector +{ + std::vector queries; + std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); + auto push_query = [&](std::string const& query_line) { + auto q = pisa::parse_query_ids(query_line); + v1::Query query(q.terms); + query.k(1000); + queries.push_back(std::move(query)); + }; + pisa::io::for_each_line(qfile, push_query); + return queries; +} + +[[nodiscard]] inline auto test_intersection_selections() +{ + auto intersections = + pisa::v1::read_intersections(PISA_SOURCE_DIR "/test/test_data/top10_selections"); + auto unigrams = pisa::v1::filter_unigrams(intersections); + auto bigrams = pisa::v1::filter_bigrams(intersections); + return std::make_pair(unigrams, bigrams); +} + +template +struct IndexFixture { + using DocumentWriter = typename v1::CursorTraits::Writer; + using FrequencyWriter = typename v1::CursorTraits::Writer; + 
using ScoreWriter = typename v1::CursorTraits::Writer; + + using DocumentReader = typename v1::CursorTraits::Reader; + using FrequencyReader = typename v1::CursorTraits::Reader; + using ScoreReader = typename v1::CursorTraits::Reader; + + explicit IndexFixture(bool verify = true, + bool score = true, + bool bm_score = true, + bool build_bigrams = true) + : m_tmpdir(std::make_unique()) + { + auto index_basename = (tmpdir().path() / "inv").string(); + v1::compress_binary_collection(PISA_SOURCE_DIR "/test/test_data/test_collection", + PISA_SOURCE_DIR "/test/test_data/test_collection.fwd", + index_basename, + 1, + v1::make_writer(), + v1::make_writer()); + if (verify) { + auto errors = v1::verify_compressed_index( + PISA_SOURCE_DIR "/test/test_data/test_collection", index_basename); + for (auto&& error : errors) { + std::cerr << error << '\n'; + } + REQUIRE(errors.empty()); + } + auto yml = fmt::format("{}.yml", index_basename); + auto meta = v1::IndexMetadata::from_file(yml); + if (score) { + meta = v1::score_index(v1::IndexMetadata::from_file(yml), 1); + } + if (bm_score) { + meta = v1::bm_score_index(meta, pisa::v1::FixedBlock{5}, tl::nullopt, 1); + } + if (build_bigrams) { + v1::build_pair_index( + meta, collect_unique_bigrams(test_queries(), []() {}), tl::nullopt, 4); + } + } + + void rebuild_bm_scores(pisa::v1::BlockType block_type) + { + v1::bm_score_index(meta(), block_type, tl::nullopt, 1); + } + + [[nodiscard]] auto const& tmpdir() const { return *m_tmpdir; } + [[nodiscard]] auto document_reader() const { return m_document_reader; } + [[nodiscard]] auto frequency_reader() const { return m_frequency_reader; } + [[nodiscard]] auto score_reader() const { return m_score_reader; } + [[nodiscard]] auto meta() const + { + auto index_basename = (tmpdir().path() / "inv").string(); + return v1::IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); + } + + private: + std::unique_ptr m_tmpdir; + DocumentReader m_document_reader{}; + FrequencyReader 
m_frequency_reader{}; + ScoreReader m_score_reader{}; +}; diff --git a/test/v1/test_union_lookup_join.cpp b/test/v1/test_union_lookup_join.cpp new file mode 100644 index 000000000..9cd5c97b7 --- /dev/null +++ b/test/v1/test_union_lookup_join.cpp @@ -0,0 +1,143 @@ +#include +#include + +template +std::ostream& operator<<(std::ostream& os, std::pair const& p) +{ + os << fmt::format("({}, {})", p.first, p.second); + return os; +} + +#define CATCH_CONFIG_MAIN +#include + +#include +#include + +#include +#include +#include +#include +#include + +#include "index_fixture.hpp" +#include "topk_queue.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/cursor_accumulator.hpp" +#include "v1/index_metadata.hpp" +#include "v1/maxscore.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/union_lookup_join.hpp" + +using pisa::v1::collect_payloads; +using pisa::v1::collect_with_payload; +using pisa::v1::DocId; +using pisa::v1::Frequency; +using pisa::v1::index_runner; +using pisa::v1::join_union_lookup; +using pisa::v1::maxscore_partition; +using pisa::v1::RawCursor; +using pisa::v1::accumulators::Add; + +TEST_CASE("Maxscore partition", "[maxscore][v1][unit]") +{ + rc::check("paritition vector of max scores according to maxscore", []() { + auto max_scores = *rc::gen::nonEmpty>(); + ranges::sort(max_scores); + float total_sum = ranges::accumulate(max_scores, 0.0F); + auto threshold = + *rc::gen::suchThat([=](float x) { return x >= 0.0 && x < total_sum; }); + auto [non_essential, essential] = + maxscore_partition(gsl::make_span(max_scores), threshold, [](auto&& x) { return x; }); + auto non_essential_sum = ranges::accumulate(non_essential, 0.0F); + auto first_essential = essential.empty() ? 
0.0 : essential[0]; + REQUIRE(non_essential_sum <= Approx(threshold)); + REQUIRE(non_essential_sum + first_essential >= threshold); + }); +} + +struct InspectMock { + std::size_t documents = 0; + std::size_t postings = 0; + std::size_t lookups = 0; + + void document() { documents += 1; } + void posting() { postings += 1; } + void lookup() { lookups += 1; } +}; + +TEST_CASE("UnionLookupJoin v Union", "[union-lookup][v1][unit]") +{ + tbb::task_scheduler_init init; + IndexFixture, RawCursor, RawCursor> fixture; + + auto result_order = [](auto&& lhs, auto&& rhs) { + if (lhs.second == rhs.second) { + return lhs.first > rhs.first; + } + return lhs.second > rhs.second; + }; + + auto add = [](auto score, auto&& cursor, [[maybe_unused]] auto idx) { + return score + cursor.payload(); + }; + + index_runner(fixture.meta(), + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.frequency_reader()))([&](auto&& index) { + auto queries = test_queries(); + for (auto&& [idx, q] : ranges::views::enumerate(queries)) { + CAPTURE(q.get_term_ids()); + CAPTURE(idx); + + auto term_ids = gsl::make_span(q.get_term_ids()); + + auto union_results = collect_with_payload( + v1::union_merge(index.scored_cursors(term_ids, make_bm25(index)), 0.0F, add)); + std::sort(union_results.begin(), union_results.end(), result_order); + std::size_t num_results = std::min(union_results.size(), 10UL); + if (num_results == 0) { + continue; + } + float threshold = std::next(union_results.begin(), num_results - 1)->second; + + auto cursors = index.max_scored_cursors(term_ids, make_bm25(index)); + auto [non_essential, essential] = + maxscore_partition(gsl::make_span(cursors), threshold); + CAPTURE(non_essential.size()); + CAPTURE(essential.size()); + + InspectMock inspect; + auto ul_results = collect_with_payload(join_union_lookup( + essential, + non_essential, + 0.0F, + Add{}, + [=](auto score) { return score >= threshold; }, + &inspect)); + std::sort(ul_results.begin(), ul_results.end(), 
result_order); + REQUIRE(ul_results.size() >= num_results); + union_results.erase(std::next(union_results.begin(), num_results), union_results.end()); + ul_results.erase(std::next(ul_results.begin(), num_results), ul_results.end()); + for (auto pos = 0; pos < num_results; pos++) { + CAPTURE(pos); + REQUIRE(union_results[pos].first == ul_results[pos].first); + REQUIRE(union_results[pos].second == Approx(ul_results[pos].second)); + } + + auto essential_counts = [&] { + auto cursors = index.max_scored_cursors(term_ids, make_bm25(index)); + auto [non_essential, essential] = + maxscore_partition(gsl::make_span(cursors), threshold); + return collect_payloads(v1::union_merge( + essential, + 0, + [](auto count, [[maybe_unused]] auto&& cursor, [[maybe_unused]] auto idx) { + return count + 1; + })); + }(); + REQUIRE(essential_counts.size() == inspect.documents); + REQUIRE(ranges::accumulate(essential_counts, 0) == inspect.postings); + } + }); +} diff --git a/test/v1/test_v1.cpp b/test/v1/test_v1.cpp new file mode 100644 index 000000000..48e0cc9f9 --- /dev/null +++ b/test/v1/test_v1.cpp @@ -0,0 +1,326 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include + +#include +#include +#include +#include +#include + +#include "binary_freq_collection.hpp" +#include "io.hpp" +#include "pisa_config.hpp" +#include "v1/algorithm.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/index.hpp" +#include "v1/io.hpp" +#include "v1/posting_builder.hpp" +#include "v1/posting_format_header.hpp" +#include "v1/scorer/bm25.hpp" +#include "v1/scorer/runner.hpp" +#include "v1/types.hpp" +#include "v1/unaligned_span.hpp" + +using pisa::v1::Array; +using pisa::v1::DocId; +using pisa::v1::Frequency; +using pisa::v1::IndexRunner; +using pisa::v1::load_bytes; +using pisa::v1::next; +using pisa::v1::parse_type; +using pisa::v1::PostingBuilder; +using pisa::v1::PostingData; +using pisa::v1::PostingFormatHeader; +using pisa::v1::Primitive; +using pisa::v1::RawReader; +using 
pisa::v1::RawWriter; +using pisa::v1::read_sizes; +using pisa::v1::TermId; +using pisa::v1::Tuple; +using pisa::v1::UnalignedSpan; +using pisa::v1::Writer; + +template +std::ostream& operator<<(std::ostream& os, tl::optional const& val) +{ + if (val.has_value()) { + os << val.value(); + } else { + os << "nullopt"; + } + return os; +} + +TEST_CASE("partition_by_index", "[v1][unit]") +{ + auto expected = [](auto input_vec, auto right_indices) { + std::vector essential; + std::sort(right_indices.begin(), right_indices.end(), std::greater{}); + for (auto idx : right_indices) { + essential.push_back(input_vec[idx]); + input_vec.erase(std::next(input_vec.begin(), idx)); + } + std::sort(essential.begin(), essential.end()); + std::sort(input_vec.begin(), input_vec.end()); + input_vec.insert(input_vec.end(), essential.begin(), essential.end()); + return input_vec; + }; + + rc::check([&](std::vector input) { + CAPTURE(input); + std::vector all_indices(input.size()); + std::iota(all_indices.begin(), all_indices.end(), 0); + auto right_indices = + *rc::gen::unique>(rc::gen::elementOf(all_indices)); + CAPTURE(right_indices); + auto expected_output = expected(input, right_indices); + auto essential_count = right_indices.size(); + auto non_essential_count = input.size() - essential_count; + pisa::v1::partition_by_index(gsl::make_span(input), gsl::make_span(right_indices)); + std::sort(input.begin(), std::next(input.begin(), non_essential_count)); + std::sort(std::next(input.begin(), non_essential_count), input.end()); + REQUIRE(input == expected_output); + }); +} + +TEST_CASE("RawReader", "[v1][unit]") +{ + std::vector const mem{5, 0, 1, 2, 3, 4}; + RawReader reader; + auto cursor = reader.read(gsl::as_bytes(gsl::make_span(mem))); + REQUIRE(cursor.value() == mem[1]); + REQUIRE(next(cursor) == tl::make_optional(mem[2])); + REQUIRE(next(cursor) == tl::make_optional(mem[3])); + REQUIRE(next(cursor) == tl::make_optional(mem[4])); + REQUIRE(next(cursor) == tl::make_optional(mem[5])); 
+ REQUIRE(next(cursor) == tl::nullopt); +} + +TEST_CASE("Test read header", "[v1][unit]") +{ + { + std::vector bytes{ + std::byte{0b00000000}, + std::byte{0b00000001}, + std::byte{0b00000000}, + std::byte{0b00000000}, + std::byte{0b00000000}, + std::byte{0b00000000}, + std::byte{0b00000000}, + std::byte{0b00000000}, + }; + auto header = PostingFormatHeader::parse(gsl::span(bytes)); + REQUIRE(header.version.major == 0); + REQUIRE(header.version.minor == 1); + REQUIRE(header.version.patch == 0); + REQUIRE(std::get(header.type) == Primitive::Int); + REQUIRE(header.encoding == 0); + } + { + std::vector bytes{ + std::byte{0b00000001}, + std::byte{0b00000001}, + std::byte{0b00000011}, + std::byte{0b00000001}, + std::byte{0b00000001}, + std::byte{0b00000000}, + std::byte{0b00000000}, + std::byte{0b00000000}, + }; + auto header = PostingFormatHeader::parse(gsl::span(bytes)); + REQUIRE(header.version.major == 1); + REQUIRE(header.version.minor == 1); + REQUIRE(header.version.patch == 3); + REQUIRE(std::get(header.type) == Primitive::Float); + REQUIRE(header.encoding == 1); + } + { + std::vector bytes{ + std::byte{0b00000001}, + std::byte{0b00000000}, + std::byte{0b00000011}, + std::byte{0b00000010}, + std::byte{0b00000011}, + std::byte{0b00000000}, + std::byte{0b00000000}, + std::byte{0b00000000}, + }; + auto header = PostingFormatHeader::parse(gsl::span(bytes)); + REQUIRE(header.version.major == 1); + REQUIRE(header.version.minor == 0); + REQUIRE(header.version.patch == 3); + REQUIRE(std::get(header.type).type == Primitive::Int); + REQUIRE(header.encoding == 3); + } +} + +TEST_CASE("Value type", "[v1][unit]") +{ + REQUIRE(std::get(parse_type(std::byte{0b00000000})) == Primitive::Int); + REQUIRE(std::get(parse_type(std::byte{0b00000001})) == Primitive::Float); + REQUIRE(std::get(parse_type(std::byte{0b00000010})).type == Primitive::Int); + REQUIRE(std::get(parse_type(std::byte{0b00000110})).type == Primitive::Float); + 
REQUIRE(std::get(parse_type(std::byte{0b00101011})).type == Primitive::Int); + REQUIRE(std::get(parse_type(std::byte{0b01000111})).type == Primitive::Float); + REQUIRE(std::get(parse_type(std::byte{0b00101011})).size == 5U); + REQUIRE(std::get(parse_type(std::byte{0b01000111})).size == 8U); +} + +TEST_CASE("Build raw document-frequency index", "[v1][unit]") +{ + using sink_type = boost::iostreams::back_insert_device>; + using vector_stream_type = boost::iostreams::stream; + GIVEN("A test binary collection") + { + pisa::binary_freq_collection collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); + WHEN("Built posting files for documents and frequencies") + { + std::vector docbuf; + std::vector freqbuf; + + PostingBuilder document_builder(Writer(RawWriter{})); + PostingBuilder frequency_builder(Writer(RawWriter{})); + { + vector_stream_type docstream{sink_type{docbuf}}; + vector_stream_type freqstream{sink_type{freqbuf}}; + + document_builder.write_header(docstream); + frequency_builder.write_header(freqstream); + + for (auto sequence : collection) { + document_builder.write_segment( + docstream, sequence.docs.begin(), sequence.docs.end()); + frequency_builder.write_segment( + freqstream, sequence.freqs.begin(), sequence.freqs.end()); + } + } + + auto document_offsets = document_builder.offsets(); + auto frequency_offsets = frequency_builder.offsets(); + + auto document_sizes = read_sizes(PISA_SOURCE_DIR "/test/test_data/test_collection"); + + THEN("Bytes match with those of the collection") + { + auto document_bytes = + load_bytes(PISA_SOURCE_DIR "/test/test_data/test_collection.docs"); + auto frequency_bytes = + load_bytes(PISA_SOURCE_DIR "/test/test_data/test_collection.freqs"); + + // NOTE: the first 8 bytes of the document collection are different than those + // of the built document file. Also, the original frequency collection starts + // at byte 0 (no 8-byte "size vector" at the beginning), and thus is shorter. 
+ CHECK(docbuf.size() == document_offsets.back() + 8); + CHECK(freqbuf.size() == frequency_offsets.back() + 8); + CHECK(docbuf.size() == document_bytes.size()); + CHECK(freqbuf.size() == frequency_bytes.size() + 8); + CHECK(gsl::make_span(docbuf.data(), docbuf.size()).subspan(8) + == gsl::make_span(document_bytes.data(), document_bytes.size()).subspan(8)); + CHECK(gsl::make_span(freqbuf.data(), freqbuf.size()).subspan(8) + == gsl::make_span(frequency_bytes.data(), frequency_bytes.size())); + } + + THEN("Index runner is correctly constructed") + { + auto source = std::array, 2>{docbuf, freqbuf}; + auto document_span = gsl::span( + reinterpret_cast(source[0].data()), source[0].size()); + auto payload_span = gsl::span( + reinterpret_cast(source[1].data()), source[1].size()); + + IndexRunner runner(PostingData{document_span, document_offsets}, + PostingData{payload_span, frequency_offsets}, + {}, + document_sizes, + tl::nullopt, + {}, + {}, + {}, + std::move(source), + std::make_tuple(RawReader{}, + RawReader{}), // Repeat to test + // that it only + // executes once + std::make_tuple(RawReader{})); + int counter = 0; + runner([&](auto index) { + counter += 1; + TermId term_id = 0; + for (auto sequence : collection) { + CAPTURE(term_id); + REQUIRE( + std::vector(sequence.docs.begin(), sequence.docs.end()) + == collect(index.cursor(term_id))); + REQUIRE( + std::vector(sequence.freqs.begin(), sequence.freqs.end()) + == collect(index.cursor(term_id), + [](auto&& cursor) { return cursor.payload(); })); + term_id += 1; + } + }); + REQUIRE(counter == 1); + } + + THEN("Index runner fails when wrong type") + { + auto source = std::array, 2>{docbuf, freqbuf}; + auto document_span = gsl::span( + reinterpret_cast(source[0].data()), source[0].size()); + auto payload_span = gsl::span( + reinterpret_cast(source[1].data()), source[1].size()); + IndexRunner runner(PostingData{document_span, document_offsets}, + PostingData{payload_span, frequency_offsets}, + {}, + document_sizes, + 
tl::nullopt, + {}, + {}, + {}, + std::move(source), + std::make_tuple(RawReader{}), // Correct encoding but not + // type! + std::make_tuple()); + REQUIRE_THROWS_AS(runner([&](auto index) {}), std::domain_error); + } + } + } +} + +TEST_CASE("UnalignedSpan", "[v1][unit]") +{ + std::vector bytes{std::byte{0b00000001}, + std::byte{0b00000010}, + std::byte{0b00000011}, + std::byte{0b00000100}, + std::byte{0b00000101}, + std::byte{0b00000110}, + std::byte{0b00000111}}; + SECTION("Bytes one-to-one") + { + auto span = UnalignedSpan(gsl::make_span(bytes)); + REQUIRE(std::vector(span.begin(), span.end()) == bytes); + } + SECTION("Bytes shifted by offset") + { + auto span = UnalignedSpan(gsl::make_span(bytes).subspan(2)); + REQUIRE(std::vector(span.begin(), span.end()) + == std::vector(bytes.begin() + 2, bytes.end())); + } + SECTION("u16") + { + REQUIRE_THROWS_AS(UnalignedSpan(gsl::make_span(bytes)), std::logic_error); + auto span = UnalignedSpan(gsl::make_span(bytes).subspan(1)); + REQUIRE(std::vector(span.begin(), span.end()) + == std::vector{ + 0b0000001100000010, 0b0000010100000100, 0b0000011100000110}); + } + SECTION("u32") + { + REQUIRE_THROWS_AS(UnalignedSpan(gsl::make_span(bytes)), std::logic_error); + auto span = UnalignedSpan(gsl::make_span(bytes).subspan(1, 4)); + REQUIRE(std::vector(span.begin(), span.end()) + == std::vector{0b00000101000001000000001100000010}); + } +} diff --git a/test/v1/test_v1_bigram_index.cpp b/test/v1/test_v1_bigram_index.cpp new file mode 100644 index 000000000..712c0276d --- /dev/null +++ b/test/v1/test_v1_bigram_index.cpp @@ -0,0 +1,174 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include +#include + +#include +#include +#include + +#include "../temporary_directory.hpp" +#include "accumulator/lazy_accumulator.hpp" +#include "cursor/block_max_scored_cursor.hpp" +#include "cursor/max_scored_cursor.hpp" +#include "cursor/scored_cursor.hpp" +#include "index_fixture.hpp" +#include "index_types.hpp" +#include "io.hpp" 
+#include "pisa_config.hpp" +#include "query/queries.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/cursor_accumulator.hpp" +#include "v1/cursor_intersection.hpp" +#include "v1/cursor_traits.hpp" +#include "v1/cursor_union.hpp" +#include "v1/index.hpp" +#include "v1/index_builder.hpp" +#include "v1/maxscore.hpp" +#include "v1/posting_builder.hpp" +#include "v1/posting_format_header.hpp" +#include "v1/query.hpp" +#include "v1/score_index.hpp" +#include "v1/scorer/bm25.hpp" +#include "v1/types.hpp" +#include "v1/union_lookup.hpp" + +namespace v1 = pisa::v1; + +static constexpr auto RELATIVE_ERROR = 0.1F; + +TEMPLATE_TEST_CASE("Bigram v intersection", + "[v1][integration]", + (IndexFixture, + v1::RawCursor, + v1::RawCursor>), + (IndexFixture, + v1::PayloadBlockedCursor<::pisa::simdbp_block>, + v1::RawCursor>)) +{ + tbb::task_scheduler_init init(1); + TestType fixture; + auto index_basename = (fixture.tmpdir().path() / "inv").string(); + auto meta = v1::IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); + int idx = 0; + for (auto& q : test_queries()) { + CAPTURE(q.get_term_ids()); + CAPTURE(idx++); + + auto run = v1::index_runner(meta, + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.frequency_reader())); + std::vector results; + run([&](auto&& index) { + for (auto left = 0; left < q.get_term_ids().size(); left += 1) { + for (auto right = left + 1; right < q.get_term_ids().size(); right += 1) { + CAPTURE(q.get_term_ids()[left]); + CAPTURE(q.get_term_ids()[right]); + auto left_cursor = index.cursor(q.get_term_ids()[left]); + auto right_cursor = index.cursor(q.get_term_ids()[right]); + auto intersection = v1::intersect({left_cursor, right_cursor}, + std::array{0, 0}, + [](auto& acc, auto&& cursor, auto idx) { + gsl::at(acc, idx) = cursor.payload(); + return acc; + }); + if (not intersection.empty()) { + auto bigram_cursor = + index.bigram_cursor(q.get_term_ids()[left], q.get_term_ids()[right]) + 
.value(); + std::vector bigram_documents; + std::vector bigram_frequencies_0; + std::vector bigram_frequencies_1; + v1::for_each(bigram_cursor, [&](auto&& cursor) { + bigram_documents.push_back(*cursor); + auto payload = cursor.payload(); + bigram_frequencies_0.push_back(std::get<0>(payload)); + bigram_frequencies_1.push_back(std::get<1>(payload)); + }); + std::vector intersection_documents; + std::vector intersection_frequencies_0; + std::vector intersection_frequencies_1; + v1::for_each(intersection, [&](auto&& cursor) { + intersection_documents.push_back(*cursor); + auto payload = cursor.payload(); + intersection_frequencies_0.push_back(std::get<0>(payload)); + intersection_frequencies_1.push_back(std::get<1>(payload)); + }); + CHECK(bigram_documents == intersection_documents); + CHECK(bigram_frequencies_0 == intersection_frequencies_0); + REQUIRE(bigram_frequencies_1 == intersection_frequencies_1); + } + } + } + }); + } +} + +TEMPLATE_TEST_CASE("Scored pair index v. intersection", + "[v1][integration]", + (IndexFixture, + v1::RawCursor, + v1::RawCursor>), + (IndexFixture, + v1::PayloadBlockedCursor<::pisa::simdbp_block>, + v1::RawCursor>)) +{ + tbb::task_scheduler_init init(1); + TestType fixture; + auto index_basename = (fixture.tmpdir().path() / "inv").string(); + auto meta = v1::IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); + int idx = 0; + for (auto& q : test_queries()) { + CAPTURE(q.get_term_ids()); + CAPTURE(idx++); + + auto run = v1::scored_index_runner(meta, + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.score_reader())); + std::vector results; + run([&](auto&& index) { + for (auto left = 0; left < q.get_term_ids().size(); left += 1) { + for (auto right = left + 1; right < q.get_term_ids().size(); right += 1) { + CAPTURE(q.get_term_ids()[left]); + CAPTURE(q.get_term_ids()[right]); + auto left_cursor = index.cursor(q.get_term_ids()[left]); + auto right_cursor = index.cursor(q.get_term_ids()[right]); + auto 
intersection = v1::intersect({left_cursor, right_cursor}, + std::array{0.0F, 0.0F}, + [](auto& acc, auto&& cursor, auto idx) { + gsl::at(acc, idx) = cursor.payload(); + return acc; + }); + if (not intersection.empty()) { + auto bigram_cursor = + index.bigram_cursor(q.get_term_ids()[left], q.get_term_ids()[right]) + .value(); + std::vector bigram_documents; + std::vector bigram_scores_0; + std::vector bigram_scores_1; + v1::for_each(bigram_cursor, [&](auto&& cursor) { + bigram_documents.push_back(*cursor); + auto payload = cursor.payload(); + bigram_scores_0.push_back(std::get<0>(payload)); + bigram_scores_1.push_back(std::get<1>(payload)); + }); + std::vector intersection_documents; + std::vector intersection_scores_0; + std::vector intersection_scores_1; + v1::for_each(intersection, [&](auto&& cursor) { + intersection_documents.push_back(*cursor); + auto payload = cursor.payload(); + intersection_scores_0.push_back(std::get<0>(payload)); + intersection_scores_1.push_back(std::get<1>(payload)); + }); + CHECK(bigram_documents == intersection_documents); + CHECK(bigram_scores_0 == intersection_scores_0); + REQUIRE(bigram_scores_1 == intersection_scores_1); + } + } + } + }); + } +} diff --git a/test/v1/test_v1_blocked_cursor.cpp b/test/v1/test_v1_blocked_cursor.cpp new file mode 100644 index 000000000..a44af3e36 --- /dev/null +++ b/test/v1/test_v1_blocked_cursor.cpp @@ -0,0 +1,196 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include + +#include +#include +#include + +#include "../temporary_directory.hpp" +#include "codec/simdbp.hpp" +#include "pisa_config.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/index.hpp" +#include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/io.hpp" +#include "v1/posting_builder.hpp" +#include "v1/types.hpp" + +using pisa::v1::collect; +using pisa::v1::DocId; +using pisa::v1::DocumentBlockedReader; +using pisa::v1::DocumentBlockedWriter; +using 
pisa::v1::Frequency; +using pisa::v1::IndexRunner; +using pisa::v1::PayloadBlockedReader; +using pisa::v1::PayloadBlockedWriter; +using pisa::v1::PostingBuilder; +using pisa::v1::PostingData; +using pisa::v1::RawReader; +using pisa::v1::read_sizes; +using pisa::v1::TermId; + +TEST_CASE("Build single-block blocked document file", "[v1][unit]") +{ + using sink_type = boost::iostreams::back_insert_device>; + using vector_stream_type = boost::iostreams::stream; + + std::vector docids{3, 4, 5, 6, 7, 8, 9, 10, 51, 115}; + std::vector docbuf; + auto document_offsets = [&]() { + PostingBuilder document_builder(DocumentBlockedWriter{}); + vector_stream_type docstream{sink_type{docbuf}}; + document_builder.write_header(docstream); + document_builder.write_segment(docstream, docids.begin(), docids.end()); + return document_builder.offsets(); + }(); + + auto documents = gsl::span(docbuf).subspan(8); + CHECK(docbuf.size() == document_offsets.back() + 8); + DocumentBlockedReader document_reader; + auto term = 0; + auto actual = collect(document_reader.read(documents.subspan( + document_offsets[term], document_offsets[term + 1] - document_offsets[term]))); + CHECK(actual.size() == docids.size()); + REQUIRE(actual == docids); +} + +TEST_CASE("Build blocked document-frequency index", "[v1][unit]") +{ + using sink_type = boost::iostreams::back_insert_device>; + using vector_stream_type = boost::iostreams::stream; + GIVEN("A test binary collection") + { + pisa::binary_freq_collection collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); + WHEN("Built posting files for documents and frequencies") + { + std::vector docbuf; + std::vector freqbuf; + + PostingBuilder document_builder(DocumentBlockedWriter{}); + PostingBuilder frequency_builder(PayloadBlockedWriter{}); + { + vector_stream_type docstream{sink_type{docbuf}}; + vector_stream_type freqstream{sink_type{freqbuf}}; + + document_builder.write_header(docstream); + frequency_builder.write_header(freqstream); + + for (auto 
sequence : collection) { + document_builder.write_segment( + docstream, sequence.docs.begin(), sequence.docs.end()); + frequency_builder.write_segment( + freqstream, sequence.freqs.begin(), sequence.freqs.end()); + } + } + + auto document_offsets = document_builder.offsets(); + auto frequency_offsets = frequency_builder.offsets(); + + auto document_sizes = read_sizes(PISA_SOURCE_DIR "/test/test_data/test_collection"); + auto documents = gsl::span(docbuf).subspan(8); + auto frequencies = gsl::span(freqbuf).subspan(8); + + THEN("The values read back are euqual to the binary collection's") + { + CHECK(docbuf.size() == document_offsets.back() + 8); + DocumentBlockedReader document_reader; + PayloadBlockedReader frequency_reader; + auto term = 0; + std::for_each(collection.begin(), collection.end(), [&](auto&& seq) { + std::vector expected_documents(seq.docs.begin(), seq.docs.end()); + auto actual_documents = collect(document_reader.read( + documents.subspan(document_offsets[term], + document_offsets[term + 1] - document_offsets[term]))); + CHECK(actual_documents.size() == expected_documents.size()); + REQUIRE(actual_documents == expected_documents); + + std::vector expected_frequencies(seq.freqs.begin(), seq.freqs.end()); + auto actual_frequencies = collect(frequency_reader.read(frequencies.subspan( + frequency_offsets[term], + frequency_offsets[term + 1] - frequency_offsets[term]))); + CHECK(actual_frequencies.size() == expected_frequencies.size()); + REQUIRE(actual_frequencies == expected_frequencies); + term += 1; + }); + } + + THEN("Index runner is correctly constructed") + { + auto source = std::array, 2>{docbuf, freqbuf}; + auto document_span = gsl::span( + reinterpret_cast(source[0].data()), source[0].size()); + auto payload_span = gsl::span( + reinterpret_cast(source[1].data()), source[1].size()); + + IndexRunner runner(PostingData{document_span, document_offsets}, + PostingData{payload_span, frequency_offsets}, + {}, + document_sizes, + tl::nullopt, + {}, + 
{}, + {}, + std::move(source), + std::make_tuple(DocumentBlockedReader{}), + std::make_tuple(PayloadBlockedReader{})); + int counter = 0; + runner([&](auto index) { + counter += 1; + TermId term_id = 0; + for (auto sequence : collection) { + CAPTURE(term_id); + REQUIRE(sequence.docs.size() == index.cursor(term_id).size()); + REQUIRE( + std::vector(sequence.docs.begin(), sequence.docs.end()) + == collect(index.cursor(term_id))); + REQUIRE( + std::vector(sequence.freqs.begin(), sequence.freqs.end()) + == collect(index.cursor(term_id), + [](auto&& cursor) { return cursor.payload(); })); + { + auto cursor = index.cursor(term_id); + for (auto doc : sequence.docs) { + cursor.advance_to_geq(doc); + REQUIRE(cursor.value() == doc); + } + } + { + auto cursor = index.cursor(term_id); + for (auto doc : sequence.docs) { + REQUIRE(cursor.value() == doc); + cursor.advance_to_geq(doc + 1); + } + } + term_id += 1; + } + }); + REQUIRE(counter == 1); + } + + THEN("Index runner fails when wrong type") + { + auto source = std::array, 2>{docbuf, freqbuf}; + auto document_span = gsl::span( + reinterpret_cast(source[0].data()), source[0].size()); + auto payload_span = gsl::span( + reinterpret_cast(source[1].data()), source[1].size()); + IndexRunner runner(PostingData{document_span, document_offsets}, + PostingData{payload_span, frequency_offsets}, + {}, + document_sizes, + tl::nullopt, + {}, + {}, + {}, + std::move(source), + std::make_tuple(RawReader{}), + std::make_tuple()); + REQUIRE_THROWS_AS(runner([&](auto index) {}), std::domain_error); + } + } + } +} diff --git a/test/v1/test_v1_document_payload_cursor.cpp b/test/v1/test_v1_document_payload_cursor.cpp new file mode 100644 index 000000000..220ccb315 --- /dev/null +++ b/test/v1/test_v1_document_payload_cursor.cpp @@ -0,0 +1,88 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include + +#include +#include +#include + +#include "../temporary_directory.hpp" +#include "codec/simdbp.hpp" +#include "pisa_config.hpp" 
+#include "v1/blocked_cursor.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/index.hpp" +#include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/io.hpp" +#include "v1/posting_builder.hpp" +#include "v1/types.hpp" + +using pisa::v1::DocId; +using pisa::v1::DocumentPayloadCursor; +using pisa::v1::Frequency; +using pisa::v1::RawCursor; +using pisa::v1::TermId; + +TEST_CASE("Document-payload cursor", "[v1][unit]") +{ + GIVEN("Document-payload cursor") + { + std::vector documents{4, 0, 1, 5, 7}; + std::vector frequencies{4, 2, 2, 1, 6}; + auto cursor = DocumentPayloadCursor, RawCursor>( + RawCursor(gsl::as_bytes(gsl::make_span(documents))), + RawCursor(gsl::as_bytes(gsl::make_span(frequencies)))); + + WHEN("Collected to document and frequency vectors") + { + std::vector collected_documents; + std::vector collected_frequencies; + for_each(cursor, [&](auto&& cursor) { + collected_documents.push_back(cursor.value()); + collected_frequencies.push_back(cursor.payload()); + }); + THEN("Vector equals to expected") + { + REQUIRE(collected_documents == std::vector{0, 1, 5, 7}); + REQUIRE(collected_frequencies == std::vector{2, 2, 1, 6}); + } + } + + WHEN("Stepped with advance_to_pos") + { + cursor.advance_to_position(0); + REQUIRE(cursor.value() == 0); + REQUIRE(cursor.payload() == 2); + cursor.advance_to_position(1); + REQUIRE(cursor.value() == 1); + REQUIRE(cursor.payload() == 2); + cursor.advance_to_position(2); + REQUIRE(cursor.value() == 5); + REQUIRE(cursor.payload() == 1); + cursor.advance_to_position(3); + REQUIRE(cursor.value() == 7); + REQUIRE(cursor.payload() == 6); + } + + WHEN("Advanced to 1") + { + cursor.advance_to_position(1); + REQUIRE(cursor.value() == 1); + REQUIRE(cursor.payload() == 2); + } + WHEN("Advanced to 2") + { + cursor.advance_to_position(2); + REQUIRE(cursor.value() == 5); + REQUIRE(cursor.payload() == 1); + } + WHEN("Advanced to 3") + { + cursor.advance_to_position(3); + REQUIRE(cursor.value() == 7); + 
REQUIRE(cursor.payload() == 6); + } + } +} diff --git a/test/v1/test_v1_index.cpp b/test/v1/test_v1_index.cpp new file mode 100644 index 000000000..4d87f2875 --- /dev/null +++ b/test/v1/test_v1_index.cpp @@ -0,0 +1,246 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include + +#include +#include +#include + +#include "../temporary_directory.hpp" +#include "binary_collection.hpp" +#include "codec/simdbp.hpp" +#include "index_fixture.hpp" +#include "pisa_config.hpp" +#include "v1/bit_sequence_cursor.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/cursor/for_each.hpp" +#include "v1/default_index_runner.hpp" +#include "v1/index.hpp" +#include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/scorer/bm25.hpp" +#include "v1/sequence/partitioned_sequence.hpp" +#include "v1/sequence/positive_sequence.hpp" +#include "v1/types.hpp" + +using pisa::binary_freq_collection; +using pisa::v1::BigramMetadata; +using pisa::v1::build_pair_index; +using pisa::v1::compress_binary_collection; +using pisa::v1::DocId; +using pisa::v1::DocumentBitSequenceCursor; +using pisa::v1::DocumentBlockedCursor; +using pisa::v1::DocumentBlockedWriter; +using pisa::v1::for_each; +using pisa::v1::Frequency; +using pisa::v1::index_runner; +using pisa::v1::IndexMetadata; +using pisa::v1::make_bm25; +using pisa::v1::PartitionedSequence; +using pisa::v1::PayloadBitSequenceCursor; +using pisa::v1::PayloadBlockedCursor; +using pisa::v1::PayloadBlockedWriter; +using pisa::v1::PositiveSequence; +using pisa::v1::PostingFilePaths; +using pisa::v1::Query; +using pisa::v1::RawCursor; +using pisa::v1::RawWriter; +using pisa::v1::TermId; + +TEST_CASE("Binary collection index", "[v1][unit]") +{ + tbb::task_scheduler_init init(8); + Temporary_Directory tmpdir; + auto bci = binary_freq_collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); + compress_binary_collection(PISA_SOURCE_DIR 
"/test/test_data/test_collection", + (tmpdir.path() / "fwd").string(), + (tmpdir.path() / "index").string(), + 8, + make_writer(RawWriter{}), + make_writer(RawWriter{})); + auto meta = IndexMetadata::from_file((tmpdir.path() / "index.yml").string()); + REQUIRE(meta.documents.postings == (tmpdir.path() / "index.documents").string()); + REQUIRE(meta.documents.offsets == (tmpdir.path() / "index.document_offsets").string()); + REQUIRE(meta.frequencies.postings == (tmpdir.path() / "index.frequencies").string()); + REQUIRE(meta.frequencies.offsets == (tmpdir.path() / "index.frequency_offsets").string()); + REQUIRE(meta.document_lengths_path == (tmpdir.path() / "index.document_lengths").string()); + auto run = index_runner(meta); + run([&](auto index) { + REQUIRE(bci.num_docs() == index.num_documents()); + REQUIRE(bci.size() == index.num_terms()); + auto bci_iter = bci.begin(); + for (auto term = 0; term < 1'000; term += 1) { + REQUIRE(std::vector(bci_iter->docs.begin(), bci_iter->docs.end()) + == collect(index.documents(term))); + REQUIRE(std::vector(bci_iter->freqs.begin(), bci_iter->freqs.end()) + == collect(index.payloads(term))); + ++bci_iter; + } + }); +} + +TEST_CASE("Binary collection index -- SIMDBP", "[v1][unit]") +{ + tbb::task_scheduler_init init(8); + Temporary_Directory tmpdir; + auto bci = binary_freq_collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); + compress_binary_collection(PISA_SOURCE_DIR "/test/test_data/test_collection", + (tmpdir.path() / "fwd").string(), + (tmpdir.path() / "index").string(), + 8, + make_writer(DocumentBlockedWriter<::pisa::simdbp_block>{}), + make_writer(PayloadBlockedWriter<::pisa::simdbp_block>{})); + auto meta = IndexMetadata::from_file((tmpdir.path() / "index.yml").string()); + REQUIRE(meta.documents.postings == (tmpdir.path() / "index.documents").string()); + REQUIRE(meta.documents.offsets == (tmpdir.path() / "index.document_offsets").string()); + REQUIRE(meta.frequencies.postings == (tmpdir.path() / 
"index.frequencies").string()); + REQUIRE(meta.frequencies.offsets == (tmpdir.path() / "index.frequency_offsets").string()); + REQUIRE(meta.document_lengths_path == (tmpdir.path() / "index.document_lengths").string()); + auto run = index_runner(meta); + run([&](auto index) { + REQUIRE(bci.num_docs() == index.num_documents()); + REQUIRE(bci.size() == index.num_terms()); + auto bci_iter = bci.begin(); + for (auto term = 0; term < 1'000; term += 1) { + REQUIRE(std::vector(bci_iter->docs.begin(), bci_iter->docs.end()) + == collect(index.documents(term))); + REQUIRE(std::vector(bci_iter->freqs.begin(), bci_iter->freqs.end()) + == collect(index.payloads(term))); + ++bci_iter; + } + }); +} + +TEMPLATE_TEST_CASE("Index", + "[v1][integration]", + (IndexFixture, RawCursor, RawCursor>), + (IndexFixture, + PayloadBlockedCursor<::pisa::simdbp_block>, + RawCursor>), + (IndexFixture, + PayloadBitSequenceCursor>, + RawCursor>), + (IndexFixture>, + PayloadBitSequenceCursor>, + RawCursor>)) +{ + tbb::task_scheduler_init init{1}; + TestType fixture(false, false, false, false); + auto index_basename = (fixture.tmpdir().path() / "inv").string(); + auto meta = IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); + auto run = pisa::v1::index_runner(meta, + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.frequency_reader())); + auto bci = binary_freq_collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); + run([&](auto&& index) { + REQUIRE(bci.num_docs() == index.num_documents()); + REQUIRE(bci.size() == index.num_terms()); + auto bci_iter = bci.begin(); + for (auto term = 0; term < 1'000; term += 1) { + REQUIRE(std::vector(bci_iter->docs.begin(), bci_iter->docs.end()) + == collect(index.documents(term))); + REQUIRE(std::vector(bci_iter->freqs.begin(), bci_iter->freqs.end()) + == collect(index.payloads(term))); + ++bci_iter; + } + }); +} + +TEST_CASE("Select best bigrams", "[v1][integration]") +{ + tbb::task_scheduler_init init; + IndexFixture, 
RawCursor, RawCursor> fixture( + false, false, false, false); + std::string index_basename = (fixture.tmpdir().path() / "inv").string(); + auto meta = IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); + { + std::vector queries; + auto best_bigrams = select_best_bigrams(meta, queries, 10); + REQUIRE(best_bigrams.empty()); + } + { + std::vector queries = {Query::from_ids(0, 1).with_probability(0.1)}; + auto best_bigrams = select_best_bigrams(meta, queries, 10); + REQUIRE(best_bigrams == std::vector>{{0, 1}}); + } + { + std::vector queries = { + Query::from_ids(0, 1).with_probability(0.2), // u: 3758, i: 808 u/i: 4.650990099009901 + Query::from_ids(1, 2).with_probability(0.2), // u: 3961, i: 734 u/i: 5.3964577656675745 + Query::from_ids(2, 3).with_probability(0.2), // u: 2298, i: 61 u/i: 37.67213114754098 + Query::from_ids(3, 4).with_probability(0.2), // u: 90, i: 3 u/i: 30.0 + Query::from_ids(4, 5).with_probability(0.2), // u: 21, i: 1 u/i: 21.0 + Query::from_ids(5, 6).with_probability(0.2), // u: 8, i: 3 u/i: 2.6666666666666665 + Query::from_ids(6, 7).with_probability(0.2), // u: 4, i: 0 u/i: --- + Query::from_ids(7, 8).with_probability(0.2), // u: 2, i: 1 u/i: 2 + Query::from_ids(8, 9).with_probability(0.2), // u: 2, i: 1 u/i: 2 + Query::from_ids(9, 10).with_probability(0.2) // u: 2, i: 1 u/i: 2 + }; + auto best_bigrams = select_best_bigrams(meta, queries, 3); + REQUIRE(best_bigrams == std::vector>{{2, 3}, {3, 4}, {4, 5}}); + } + { + std::vector queries = { + Query::from_ids(0, 1).with_probability(0.2), // u: 3758, i: 808 u/i: 4.650990099009901 + Query::from_ids(1, 2).with_probability(0.2), // u: 3961, i: 734 u/i: 5.3964577656675745 + Query::from_ids(2, 3).with_probability(0.2), // u: 2298, i: 61 u/i: 37.67213114754098 + Query::from_ids(3, 4).with_probability(0.4), // u: 90, i: 3 u/i: 30.0 + Query::from_ids(4, 5).with_probability(0.01), // u: 21, i: 1 u/i: 21.0 + Query::from_ids(5, 6).with_probability(0.2), // u: 8, i: 3 u/i: 2.6666666666666665 + 
Query::from_ids(6, 7).with_probability(0.2), // u: 4, i: 0 u/i: --- + Query::from_ids(7, 8).with_probability(0.2), // u: 2, i: 1 u/i: 2 + Query::from_ids(8, 9).with_probability(0.2), // u: 2, i: 1 u/i: 2 + Query::from_ids(9, 10).with_probability(0.2) // u: 2, i: 1 u/i: 2 + }; + auto best_bigrams = select_best_bigrams(meta, queries, 3); + REQUIRE(best_bigrams == std::vector>{{3, 4}, {2, 3}, {1, 2}}); + } +} + +TEST_CASE("Build pair index", "[v1][integration]") +{ + tbb::task_scheduler_init init; + IndexFixture, RawCursor, RawCursor> fixture( + true, true, true, false); + std::string index_basename = (fixture.tmpdir().path() / "inv").string(); + auto meta = IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); + SECTION("In place") + { + build_pair_index(meta, {{0, 1}, {0, 2}, {0, 3}, {0, 4}, {0, 5}}, tl::nullopt, 4); + auto run = index_runner(IndexMetadata::from_file(fmt::format("{}.yml", index_basename))); + run([&](auto&& index) { + REQUIRE(index.bigram_cursor(0, 1).has_value()); + REQUIRE(index.bigram_cursor(1, 0).has_value()); + REQUIRE(not index.bigram_cursor(1, 2).has_value()); + REQUIRE(not index.bigram_cursor(2, 1).has_value()); + }); + } + SECTION("Cloned") + { + auto cloned_basename = (fixture.tmpdir().path() / "cloned").string(); + build_pair_index( + meta, {{0, 1}, {0, 2}, {0, 3}, {0, 4}, {0, 5}}, tl::make_optional(cloned_basename), 4); + SECTION("Original index has no pairs") + { + auto run = + index_runner(IndexMetadata::from_file(fmt::format("{}.yml", index_basename))); + run([&](auto&& index) { + REQUIRE_THROWS_AS(index.bigram_cursor(0, 1), std::logic_error); + }); + } + SECTION("New index has pairs") + { + auto run = + index_runner(IndexMetadata::from_file(fmt::format("{}.yml", cloned_basename))); + run([&](auto&& index) { + REQUIRE(index.bigram_cursor(0, 1).has_value()); + REQUIRE(index.bigram_cursor(1, 0).has_value()); + REQUIRE(not index.bigram_cursor(1, 2).has_value()); + REQUIRE(not index.bigram_cursor(2, 1).has_value()); + }); + } + 
} +} diff --git a/test/v1/test_v1_maxscore_join.cpp b/test/v1/test_v1_maxscore_join.cpp new file mode 100644 index 000000000..086e8ee1d --- /dev/null +++ b/test/v1/test_v1_maxscore_join.cpp @@ -0,0 +1,71 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include + +#include +#include +#include + +#include "../temporary_directory.hpp" +#include "codec/simdbp.hpp" +#include "index_fixture.hpp" +#include "pisa_config.hpp" +#include "topk_queue.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/cursor_accumulator.hpp" +#include "v1/index.hpp" +#include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/io.hpp" +#include "v1/maxscore.hpp" +#include "v1/posting_builder.hpp" +#include "v1/scorer/bm25.hpp" +#include "v1/types.hpp" + +using pisa::v1::collect; +using pisa::v1::DocId; +using pisa::v1::Frequency; +using pisa::v1::join_maxscore; +using pisa::v1::accumulators::Add; + +TEMPLATE_TEST_CASE("Max score join", + "[v1][integration]", + (IndexFixture, + v1::RawCursor, + v1::RawCursor>)) +{ + tbb::task_scheduler_init init(1); + TestType fixture; + + SECTION("Zero threshold -- equivalent to union") + { + auto index_basename = (fixture.tmpdir().path() / "inv").string(); + auto meta = v1::IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); + int idx = 0; + for (auto& q : test_queries()) { + CAPTURE(q.get_term_ids()); + CAPTURE(idx++); + + auto add = [](auto score, auto&& cursor, [[maybe_unused]] auto idx) { + return score + cursor.payload(); + }; + auto run = v1::index_runner(meta, + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.frequency_reader())); + run([&](auto&& index) { + auto union_results = collect(v1::union_merge( + index.scored_cursors(gsl::make_span(q.get_term_ids()), make_bm25(index)), + 0.0F, + add)); + auto maxscore_results = collect(v1::join_maxscore( + index.max_scored_cursors(gsl::make_span(q.get_term_ids()), make_bm25(index)), + 0.0F, + Add{}, + [](auto 
/* score */) { return true; })); + REQUIRE(union_results == maxscore_results); + }); + } + } +} diff --git a/test/v1/test_v1_queries.cpp b/test/v1/test_v1_queries.cpp new file mode 100644 index 000000000..5f62f5698 --- /dev/null +++ b/test/v1/test_v1_queries.cpp @@ -0,0 +1,327 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include +#include + +#include +#include +#include + +#include "../temporary_directory.hpp" +#include "accumulator/lazy_accumulator.hpp" +#include "cursor/block_max_scored_cursor.hpp" +#include "cursor/max_scored_cursor.hpp" +#include "cursor/scored_cursor.hpp" +#include "index_fixture.hpp" +#include "index_types.hpp" +#include "io.hpp" +#include "pisa_config.hpp" +#include "query/algorithm/ranked_or_query.hpp" +#include "query/queries.hpp" +#include "scorer/bm25.hpp" +#include "v1/bit_sequence_cursor.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/cursor_intersection.hpp" +#include "v1/cursor_traits.hpp" +#include "v1/cursor_union.hpp" +#include "v1/daat_and.hpp" +#include "v1/daat_or.hpp" +#include "v1/index.hpp" +#include "v1/index_builder.hpp" +#include "v1/maxscore.hpp" +#include "v1/maxscore_union_lookup.hpp" +#include "v1/posting_builder.hpp" +#include "v1/posting_format_header.hpp" +#include "v1/query.hpp" +#include "v1/score_index.hpp" +#include "v1/scorer/bm25.hpp" +#include "v1/sequence/partitioned_sequence.hpp" +#include "v1/sequence/positive_sequence.hpp" +#include "v1/taat_or.hpp" +#include "v1/types.hpp" +#include "v1/unigram_union_lookup.hpp" +#include "v1/union_lookup.hpp" +#include "v1/wand.hpp" + +using pisa::v1::DocId; +using pisa::v1::DocumentBitSequenceCursor; +using pisa::v1::DocumentBlockedCursor; +using pisa::v1::Frequency; +using pisa::v1::Index; +using pisa::v1::IndexMetadata; +using pisa::v1::ListSelection; +using pisa::v1::PartitionedSequence; +using pisa::v1::PayloadBitSequenceCursor; +using pisa::v1::PayloadBlockedCursor; +using pisa::v1::PositiveSequence; +using 
pisa::v1::RawCursor; + +static constexpr auto RELATIVE_ERROR = 0.1F; + +template +struct IndexData { + + static std::unique_ptr data; + + IndexData() + : collection(PISA_SOURCE_DIR "/test/test_data/test_collection"), + document_sizes(PISA_SOURCE_DIR "/test/test_data/test_collection.sizes"), + wdata(document_sizes.begin()->begin(), + collection.num_docs(), + collection, + "bm25", + ::pisa::BlockSize(::pisa::FixedBlock()), + {}) + + { + typename v0_Index::builder builder(collection.num_docs(), params); + for (auto const& plist : collection) { + uint64_t freqs_sum = + std::accumulate(plist.freqs.begin(), plist.freqs.end(), uint64_t(0)); + builder.add_posting_list( + plist.docs.size(), plist.docs.begin(), plist.freqs.begin(), freqs_sum); + } + builder.build(v0_index); + + ::pisa::term_id_vec q; + std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); + auto push_query = [&](std::string const& query_line) { + queries.push_back(::pisa::parse_query_ids(query_line)); + }; + ::pisa::io::for_each_line(qfile, push_query); + + std::string t; + std::ifstream tin(PISA_SOURCE_DIR "/test/test_data/top5_thresholds"); + while (std::getline(tin, t)) { + thresholds.push_back(std::stof(t)); + } + } + + [[nodiscard]] static auto get() + { + if (IndexData::data == nullptr) { + IndexData::data = std::make_unique>(); + } + return IndexData::data.get(); + } + + ::pisa::global_parameters params; + ::pisa::binary_freq_collection collection; + ::pisa::binary_collection document_sizes; + v0_Index v0_index; + std::vector<::pisa::Query> queries; + std::vector thresholds; + ::pisa::wand_data<::pisa::wand_data_raw> wdata; +}; + +/// Inefficient, do not use in production code. +[[nodiscard]] auto approximate_order(std::pair lhs, + std::pair rhs) -> bool +{ + return std::make_pair(fmt::format("{:0.4f}", lhs.first), lhs.second) + < std::make_pair(fmt::format("{:0.4f}", rhs.first), rhs.second); +} + +/// Inefficient, do not use in production code. 
+[[nodiscard]] auto approximate_order_f(float lhs, float rhs) -> bool +{ + return fmt::format("{:0.4f}", lhs) < fmt::format("{:0.4f}", rhs); +} + +template +std::unique_ptr> + IndexData::data = nullptr; + +TEMPLATE_TEST_CASE("Query", + "[v1][integration]", + //(IndexFixture, RawCursor, + // RawCursor>), + //(IndexFixture, + // PayloadBlockedCursor<::pisa::simdbp_block>, + // RawCursor>), + (IndexFixture>, + PayloadBitSequenceCursor>, + RawCursor>)) +{ + tbb::task_scheduler_init init(1); + auto data = IndexData<::pisa::single_index, + Index, RawCursor>, + Index, RawCursor>>::get(); + TestType fixture; + auto input_data = GENERATE(table({ + //{"daat_or", false, false}, + {"maxscore", false, false}, + {"maxscore", true, false}, + //{"wand", false, false}, + //{"wand", true, false}, + //{"bmw", false, false}, + //{"bmw", true, false}, + //{"bmw", false, true}, + //{"bmw", true, true}, + {"maxscore_union_lookup", true, false}, + {"unigram_union_lookup", true, false}, + //{"union_lookup", true, false}, + //{"union_lookup_plus", true, false}, + {"lookup_union", true, false}, + {"lookup_union_eaat", true, false}, + })); + std::string algorithm = std::get<0>(input_data); + bool with_threshold = std::get<1>(input_data); + bool rebuild_with_variable_blocks = std::get<2>(input_data); + if (rebuild_with_variable_blocks) { + fixture.rebuild_bm_scores(pisa::v1::VariableBlock{12.0}); + } + CAPTURE(algorithm); + CAPTURE(with_threshold); + auto index_basename = (fixture.tmpdir().path() / "inv").string(); + auto meta = IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); + auto heap = ::pisa::topk_queue(10); + ::pisa::ranked_or_query or_q(heap); + auto run_query = [](std::string const& name, auto query, auto&& index, auto scorer) { + if (name == "daat_or") { + return daat_or(query, index, ::pisa::topk_queue(10), scorer); + } + if (name == "maxscore") { + return maxscore(query, index, ::pisa::topk_queue(10), scorer); + } + if (name == "wand") { + return wand(query, index, 
::pisa::topk_queue(10), scorer); + } + if (name == "bmw") { + return bmw(query, index, ::pisa::topk_queue(10), scorer); + } + if (name == "maxscore_union_lookup") { + return maxscore_union_lookup(query, index, ::pisa::topk_queue(10), scorer); + } + if (name == "unigram_union_lookup") { + query.selections(ListSelection{.unigrams = query.get_term_ids(), .bigrams = {}}); + return unigram_union_lookup(query, index, ::pisa::topk_queue(10), scorer); + } + if (name == "union_lookup") { + if (query.get_term_ids().size() > 8) { + return maxscore_union_lookup(query, index, ::pisa::topk_queue(10), scorer); + } + return union_lookup(query, index, ::pisa::topk_queue(10), scorer); + } + if (name == "union_lookup_plus") { + if (query.get_term_ids().size() > 8) { + return maxscore_union_lookup(query, index, ::pisa::topk_queue(10), scorer); + } + return union_lookup_plus(query, index, ::pisa::topk_queue(10), scorer); + } + if (name == "lookup_union") { + return lookup_union(query, index, ::pisa::topk_queue(10), scorer); + } + if (name == "lookup_union_eaat") { + return lookup_union_eaat(query, index, ::pisa::topk_queue(10), scorer); + } + std::abort(); + }; + int idx = 0; + auto const intersections = + pisa::v1::read_intersections(PISA_SOURCE_DIR "/test/test_data/top10_selections"); + for (auto& query : test_queries()) { + heap.clear(); + if (algorithm == "union_lookup" || algorithm == "union_lookup_plus" + || algorithm == "lookup_union" || algorithm == "lookup_union_eaat") { + query.selections(gsl::make_span(intersections[idx])); + } + + CAPTURE(query); + CAPTURE(idx); + CAPTURE(intersections[idx]); + + // if (query.get_selections().unigrams.empty() or idx < 8) { + // idx += 1; + // continue; + //} + + or_q( + make_scored_cursors(data->v0_index, + ::pisa::bm25<::pisa::wand_data<::pisa::wand_data_raw>>(data->wdata), + ::pisa::Query{{}, query.get_term_ids(), {}}), + data->v0_index.num_docs()); + heap.finalize(); + auto expected = or_q.topk(); + if (with_threshold) { + 
query.threshold(expected.back().first - 1.0F); + } + + auto on_the_fly = [&]() { + auto run = pisa::v1::index_runner(meta, + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.frequency_reader())); + std::vector results; + run([&](auto&& index) { + auto que = run_query(algorithm, query, index, make_bm25(index)); + que.finalize(); + results = que.topk(); + if (not results.empty()) { + results.erase(std::remove_if(results.begin(), + results.end(), + [last_score = results.back().first](auto&& entry) { + return entry.first <= last_score; + }), + results.end()); + std::sort(results.begin(), results.end(), approximate_order); + } + }); + return results; + }(); + expected.resize(on_the_fly.size()); + std::sort(expected.begin(), expected.end(), approximate_order); + + // if (algorithm == "union_lookup_plus") { + // for (size_t i = 0; i < on_the_fly.size(); ++i) { + // std::cerr << fmt::format("{} {} -- {} {}\n", + // on_the_fly[i].second, + // on_the_fly[i].first, + // expected[i].second, + // expected[i].first); + // } + // std::cerr << '\n'; + //} + + for (size_t i = 0; i < on_the_fly.size(); ++i) { + REQUIRE(on_the_fly[i].second == expected[i].second); + REQUIRE(on_the_fly[i].first == Approx(expected[i].first).epsilon(RELATIVE_ERROR)); + } + + idx += 1; + + if (algorithm == "bmw") { + continue; + } + + auto precomputed = [&]() { + auto run = pisa::v1::scored_index_runner(meta, + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.score_reader())); + std::vector results; + run([&](auto&& index) { + auto que = run_query(algorithm, query, index, v1::VoidScorer{}); + que.finalize(); + results = que.topk(); + }); + if (not results.empty()) { + results.erase(std::remove_if(results.begin(), + results.end(), + [last_score = results.back().first](auto&& entry) { + return entry.first <= last_score; + }), + results.end()); + std::sort(results.begin(), results.end(), approximate_order); + } + return results; + }(); + + // TODO(michal): test the 
quantized results + // constexpr float max_partial_score = 16.5724F; + // auto quantizer = [&](float score) { + // return static_cast(score * std::numeric_limits::max() + // / max_partial_score); + //}; + } +} diff --git a/test/v1/test_v1_query.cpp b/test/v1/test_v1_query.cpp new file mode 100644 index 000000000..6dcbb1554 --- /dev/null +++ b/test/v1/test_v1_query.cpp @@ -0,0 +1,37 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include + +#include +#include +#include +#include +#include + +#include "v1/query.hpp" + +using pisa::v1::Query; +using pisa::v1::TermId; + +TEST_CASE("Parse query from JSON", "[v1][unit]") +{ + REQUIRE_THROWS(Query::from_json("{}")); + REQUIRE(Query::from_json("{\"query\": \"tell your dog I said hi\"}").get_raw() + == "tell your dog I said hi"); + REQUIRE(Query::from_json("{\"term_ids\": [0, 32, 4]}").get_term_ids() + == std::vector{0, 4, 32}); + REQUIRE(Query::from_json("{\"term_ids\": [0, 32, 4]}").k() == 1000); + auto query = Query::from_json( + R"({"id": "Q0", "query": "send dog pics", "term_ids": [0, 32, 4], "k": 15, )" + R"("threshold": 40.5, "selections": )" + R"([1, 4, 5, 6]})"); + REQUIRE(query.get_id() == "Q0"); + REQUIRE(query.k() == 15); + REQUIRE(query.get_term_ids() == std::vector{0, 4, 32}); + REQUIRE(query.get_threshold() == 40.5); + REQUIRE(query.get_raw() == "send dog pics"); + REQUIRE(query.get_selections().unigrams == std::vector{0, 4}); + REQUIRE(query.get_selections().bigrams + == std::vector>{{0, 4}, {4, 32}}); +} diff --git a/test/v1/test_v1_score_index.cpp b/test/v1/test_v1_score_index.cpp new file mode 100644 index 000000000..d040b389f --- /dev/null +++ b/test/v1/test_v1_score_index.cpp @@ -0,0 +1,130 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include + +#include +#include +#include + +#include "../temporary_directory.hpp" +#include "codec/simdbp.hpp" +#include "index_fixture.hpp" +#include "pisa_config.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/cursor/collect.hpp" 
+#include "v1/cursor/for_each.hpp" +#include "v1/index.hpp" +#include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/scorer/bm25.hpp" +#include "v1/types.hpp" + +using pisa::v1::DocId; +using pisa::v1::DocumentBlockedCursor; +using pisa::v1::DocumentBlockedReader; +using pisa::v1::DocumentBlockedWriter; +using pisa::v1::Frequency; +using pisa::v1::index_runner; +using pisa::v1::IndexMetadata; +using pisa::v1::make_bm25; +using pisa::v1::PayloadBlockedCursor; +using pisa::v1::PayloadBlockedReader; +using pisa::v1::PayloadBlockedWriter; +using pisa::v1::RawCursor; +using pisa::v1::TermId; + +TEMPLATE_TEST_CASE("Score index", + "[v1][integration]", + (IndexFixture, RawCursor, RawCursor>), + (IndexFixture, + PayloadBlockedCursor<::pisa::simdbp_block>, + RawCursor>)) +{ + tbb::task_scheduler_init init(1); + GIVEN("Index fixture (built and scored index)") + { + TestType fixture; + THEN("Float max scores are correct") + { + auto run = v1::index_runner(fixture.meta(), + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.frequency_reader())); + run([&](auto&& index) { + for (auto term = 0; term < index.num_terms(); term += 1) { + CAPTURE(term); + auto cursor = index.max_scored_cursor(term, make_bm25(index)); + auto precomputed_max = cursor.max_score(); + float calculated_max = 0.0F; + ::pisa::v1::for_each(cursor, [&](auto&& cursor) { + calculated_max = std::max(cursor.payload(), calculated_max); + }); + REQUIRE(precomputed_max == calculated_max); + } + }); + } + THEN("Quantized max scores are correct") + { + auto run = v1::scored_index_runner(fixture.meta(), + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.score_reader())); + run([&](auto&& index) { + for (auto term = 0; term < index.num_terms(); term += 1) { + CAPTURE(term); + auto cursor = index.max_scored_cursor(term, pisa::v1::VoidScorer{}); + auto precomputed_max = cursor.max_score(); + std::uint8_t calculated_max = 0; + ::pisa::v1::for_each(cursor, 
[&](auto&& cursor) { + if (cursor.payload() > calculated_max) { + calculated_max = cursor.payload(); + } + }); + REQUIRE(precomputed_max == calculated_max); + } + }); + } + } +} + +TEMPLATE_TEST_CASE("Construct max-score lists", + "[v1][integration]", + (IndexFixture, RawCursor, RawCursor>), + (IndexFixture, + PayloadBlockedCursor<::pisa::simdbp_block>, + RawCursor>)) +{ + tbb::task_scheduler_init init(1); + GIVEN("Index fixture (built and (max) scored index)") + { + TestType fixture; + THEN("Float max scores are correct") + { + auto run = v1::index_runner(fixture.meta(), + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.frequency_reader())); + run([&](auto&& index) { + for (auto term = 0; term < index.num_terms(); term += 1) { + CAPTURE(term); + auto cursor = index.block_max_scored_cursor(term, make_bm25(index)); + auto term_max_score = 0.0F; + while (not cursor.empty()) { + auto max_score = 0.0F; + auto block_max_score = cursor.block_max_score(*cursor); + for (auto idx = 0; idx < 5 && not cursor.empty(); ++idx) { + REQUIRE(cursor.block_max_score(*cursor) == block_max_score); + if (auto score = cursor.payload(); score > max_score) { + max_score = score; + } + cursor.advance(); + } + if (max_score > term_max_score) { + term_max_score = max_score; + } + REQUIRE(max_score == block_max_score); + } + REQUIRE(term_max_score == cursor.max_score()); + } + }); + } + } +} diff --git a/test/v1/test_v1_union_lookup.cpp b/test/v1/test_v1_union_lookup.cpp new file mode 100644 index 000000000..47b85089c --- /dev/null +++ b/test/v1/test_v1_union_lookup.cpp @@ -0,0 +1,155 @@ +#define CATCH_CONFIG_MAIN +#include + +#include +#include + +#include +#include +#include +#include +#include + +#include "index_fixture.hpp" +#include "topk_queue.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/cursor_accumulator.hpp" +#include "v1/index_metadata.hpp" +#include "v1/maxscore.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/scorer/bm25.hpp" +#include 
"v1/union_lookup.hpp" + +using pisa::v1::BM25; +using pisa::v1::collect_payloads; +using pisa::v1::collect_with_payload; +using pisa::v1::DocId; +using pisa::v1::Frequency; +using pisa::v1::index_runner; +using pisa::v1::InspectLookupUnion; +using pisa::v1::InspectLookupUnionEaat; +using pisa::v1::InspectMaxScore; +using pisa::v1::InspectUnionLookup; +using pisa::v1::InspectUnionLookupPlus; +using pisa::v1::join_union_lookup; +using pisa::v1::lookup_union; +using pisa::v1::lookup_union_eaat; +using pisa::v1::maxscore_partition; +using pisa::v1::RawCursor; + +template +void test_write(T&& result) +{ + std::ostringstream os; + result.write(os); + REQUIRE(fmt::format("{}\t{}\t{}\t{}\t{}", + result.postings(), + result.documents(), + result.lookups(), + result.inserts(), + result.essentials()) + == os.str()); +} + +template +void test_write_partitioned(T&& result) +{ + std::ostringstream os; + result.write(os); + REQUIRE(fmt::format("{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}", + result.sum.postings(), + result.sum.documents(), + result.sum.lookups(), + result.sum.inserts(), + result.sum.essentials(), + result.first.postings(), + result.first.documents(), + result.first.lookups(), + result.first.inserts(), + result.first.essentials(), + result.second.postings(), + result.second.documents(), + result.second.lookups(), + result.second.inserts(), + result.second.essentials()) + == os.str()); +} + +TEST_CASE("UnionLookup statistics", "[union-lookup][v1][unit]") +{ + tbb::task_scheduler_init init; + IndexFixture, RawCursor, RawCursor> fixture; + index_runner(fixture.meta(), + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.frequency_reader()))([&](auto&& index) { + auto union_lookup_inspect = InspectUnionLookup(index, make_bm25(index)); + auto union_lookup_plus_inspect = InspectUnionLookupPlus(index, make_bm25(index)); + auto lookup_union_inspect = InspectLookupUnion(index, make_bm25(index)); + auto lookup_union_eaat_inspect = 
InspectLookupUnionEaat(index, make_bm25(index)); + auto queries = test_queries(); + auto const intersections = + pisa::v1::read_intersections(PISA_SOURCE_DIR "/test/test_data/top10_selections"); + for (auto&& [idx, q] : ranges::views::enumerate(queries)) { + // if (idx == 230 || idx == 272 || idx == 380) { + // // Skipping these because of the false positives caused by floating point + // precision. continue; + //} + if (q.get_term_ids().size() > 8) { + continue; + } + auto heap = maxscore(q, index, ::pisa::topk_queue(10), make_bm25(index)); + q.selections(intersections[idx]); + q.threshold(heap.topk().back().first); + CAPTURE(q.get_term_ids()); + CAPTURE(intersections[idx]); + CAPTURE(q.get_threshold()); + CAPTURE(idx); + + auto ul = union_lookup_inspect(q); + auto ulp = union_lookup_plus_inspect(q); + auto lu = lookup_union_inspect(q); + auto lue = lookup_union_eaat_inspect(q); + test_write(ul); + test_write(ulp); + test_write_partitioned(lu); + test_write_partitioned(lue); + + CHECK(ul.documents() == ulp.documents()); + CHECK(ul.postings() == ulp.postings()); + + // +2 because of the false positives caused by floating point + CHECK(ul.lookups() + 2 >= ulp.lookups()); + + CAPTURE(ulp.lookups()); + CAPTURE(ul.lookups()); + CAPTURE(lu.first.lookups()); + CAPTURE(lu.second.lookups()); + + CAPTURE(ul.essentials()); + CAPTURE(lu.first.essentials()); + CAPTURE(lu.second.essentials()); + + CHECK(lu.first.lookups() + lu.second.lookups() == lu.sum.lookups()); + CHECK(lue.first.lookups() + lue.second.lookups() == lue.sum.lookups()); + CHECK(ul.postings() == lu.sum.postings()); + CHECK(ul.postings() == lue.sum.postings()); + + // +3 because of the false positives caused by floating point + CHECK(ulp.lookups() <= lu.sum.lookups() + 3); + CHECK(ulp.lookups() <= lue.sum.lookups() + 3); + } + auto ul = union_lookup_inspect.mean(); + auto ulp = union_lookup_plus_inspect.mean(); + auto lu = lookup_union_inspect.mean(); + auto lue = lookup_union_inspect.mean(); + 
CHECK(ul.documents() == ulp.documents()); + CHECK(ul.postings() == ulp.postings()); + CHECK(ul.lookups() >= ulp.lookups()); + CHECK(ul.postings() == lu.first.postings() + lu.second.postings()); + CHECK(ul.postings() == lue.first.postings() + lue.second.postings()); + CHECK(ulp.lookups() <= lu.first.lookups() + lu.second.lookups()); + CHECK(ulp.lookups() <= lue.first.lookups() + lue.second.lookups()); + CHECK(lu.first.lookups() + lu.second.lookups() == lu.sum.lookups()); + CHECK(lue.first.lookups() + lue.second.lookups() == lue.sum.lookups()); + }); +} diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index e35c30c29..2d8b6aaca 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -1,69 +1,68 @@ -add_executable(create_freq_index create_freq_index.cpp) -target_link_libraries(create_freq_index - pisa - CLI11 -) - -add_executable(optimal_hybrid_index optimal_hybrid_index.cpp) -target_include_directories(optimal_hybrid_index PRIVATE ${STXXL_INCLUDE_DIRS}) -target_link_libraries(optimal_hybrid_index - ${STXXL_LIBRARIES} - pisa -) -set_target_properties(optimal_hybrid_index PROPERTIES - CXX_STANDARD 14 -) - -add_executable(create_wand_data create_wand_data.cpp) -target_link_libraries(create_wand_data - pisa - CLI11 -) - -add_executable(queries queries.cpp) -target_link_libraries(queries - pisa - CLI11 -) - -add_executable(evaluate_queries evaluate_queries.cpp) -target_link_libraries(evaluate_queries - pisa - CLI11 -) - -add_executable(thresholds thresholds.cpp) -target_link_libraries(thresholds - pisa - CLI11 -) - -add_executable(profile_queries profile_queries.cpp) -target_link_libraries(profile_queries - pisa -) - -add_executable(profile_decoding profile_decoding.cpp) -target_link_libraries(profile_decoding - pisa -) - -add_executable(shuffle_docids shuffle_docids.cpp) -target_link_libraries(shuffle_docids - pisa - CLI11 -) - -add_executable(recursive_graph_bisection recursive_graph_bisection.cpp) -target_link_libraries(recursive_graph_bisection - pisa 
- CLI11 -) - -add_executable(evaluate_collection_ordering evaluate_collection_ordering.cpp) -target_link_libraries(evaluate_collection_ordering - pisa - ) +#add_executable(create_freq_index create_freq_index.cpp) +#target_link_libraries(create_freq_index +# pisa +# CLI11 +#) +# +#add_executable(optimal_hybrid_index optimal_hybrid_index.cpp) +#target_include_directories(optimal_hybrid_index PRIVATE ${STXXL_INCLUDE_DIRS}) +#target_link_libraries(optimal_hybrid_index +# ${STXXL_LIBRARIES} +# pisa +#) +#set_target_properties(optimal_hybrid_index PROPERTIES +# CXX_STANDARD 14 +#) +# +#add_executable(create_wand_data create_wand_data.cpp) +#target_link_libraries(create_wand_data +# pisa +# CLI11 +#) + +#add_executable(queries queries.cpp) +#target_link_libraries(queries +# pisa +# CLI11 +#) + +#add_executable(evaluate_queries evaluate_queries.cpp) +#target_link_libraries(evaluate_queries +# pisa +# CLI11 +#) + +#add_executable(thresholds thresholds.cpp) +#target_link_libraries(thresholds +# pisa +# CLI11 +#) + +#add_executable(profile_queries profile_queries.cpp) +#target_link_libraries(profile_queries +# pisa +#) +# +#add_executable(profile_decoding profile_decoding.cpp) +#target_link_libraries(profile_decoding +# pisa +#) +# +#add_executable(shuffle_docids shuffle_docids.cpp) +#target_link_libraries(shuffle_docids +# pisa +#) +# +#add_executable(recursive_graph_bisection recursive_graph_bisection.cpp) +#target_link_libraries(recursive_graph_bisection +# pisa +# CLI11 +#) +# +#add_executable(evaluate_collection_ordering evaluate_collection_ordering.cpp) +#target_link_libraries(evaluate_collection_ordering +# pisa +# ) add_executable(parse_collection parse_collection.cpp) target_link_libraries(parse_collection @@ -78,23 +77,23 @@ target_link_libraries(invert pisa ) -add_executable(read_collection read_collection.cpp) -target_link_libraries(read_collection - pisa - CLI11 -) +#add_executable(read_collection read_collection.cpp) +#target_link_libraries(read_collection +# 
pisa +# CLI11 +#) +# +#add_executable(partition_fwd_index partition_fwd_index.cpp) +#target_link_libraries(partition_fwd_index +# pisa +# CLI11 +#) -add_executable(partition_fwd_index partition_fwd_index.cpp) -target_link_libraries(partition_fwd_index - pisa - CLI11 -) - -add_executable(compute_intersection compute_intersection.cpp) -target_link_libraries(compute_intersection - pisa - CLI11 -) +#add_executable(compute_intersection compute_intersection.cpp) +#target_link_libraries(compute_intersection +# pisa +# CLI11 +#) add_executable(lexicon lexicon.cpp) target_link_libraries(lexicon @@ -102,26 +101,26 @@ target_link_libraries(lexicon CLI11 ) -add_executable(extract_topics extract_topics.cpp) -target_link_libraries(extract_topics - pisa - CLI11 -) +#add_executable(extract_topics extract_topics.cpp) +#target_link_libraries(extract_topics +# pisa +# CLI11 +#) +# +#add_executable(sample_inverted_index sample_inverted_index.cpp) +#target_link_libraries(sample_inverted_index +# pisa +# CLI11 +#) -add_executable(sample_inverted_index sample_inverted_index.cpp) -target_link_libraries(sample_inverted_index - pisa - CLI11 -) +#add_executable(count_postings count_postings.cpp) +#target_link_libraries(count_postings +# pisa +# CLI11 +#) -add_executable(map_queries map_queries.cpp) -target_link_libraries(map_queries - pisa - CLI11 -) - -add_executable(count_postings count_postings.cpp) -target_link_libraries(count_postings - pisa - CLI11 -) +#add_executable(map_queries map_queries.cpp) +#target_link_libraries(map_queries +# pisa +# CLI11 +#) diff --git a/tools/compute_intersection.cpp b/tools/compute_intersection.cpp index 9baca9fea..1d5554ce8 100644 --- a/tools/compute_intersection.cpp +++ b/tools/compute_intersection.cpp @@ -20,10 +20,10 @@ using pisa::intersection::IntersectionType; using pisa::intersection::Mask; template -void intersect(std::string const &index_filename, - std::optional const &wand_data_filename, - std::vector const &queries, - std::string const &type, 
+void intersect(std::string const& index_filename, + std::optional const& wand_data_filename, + std::vector const& queries, + std::string const& type, IntersectionType intersection_type, std::optional max_term_count = std::nullopt) { @@ -44,18 +44,27 @@ void intersect(std::string const &index_filename, mapper::map(wdata, md, mapper::map_flags::warmup); } - std::size_t qid = 0u; + std::size_t qid = 0U; - auto print_intersection = [&](auto const &query, auto const &mask) { + auto print_intersection = [&](auto const& query, auto const& mask) { + // FIXME(michal): Quick workaround to not compute intersection (a, a) + auto filtered = intersection::filter(query, mask); + std::sort(filtered.terms.begin(), filtered.terms.end()); + if (std::unique(filtered.terms.begin(), filtered.terms.end()) != filtered.terms.end()) { + // Do not compute: contains duplicates + return; + } auto intersection = Intersection::compute(index, wdata, query, mask); - std::cout << fmt::format("{}\t{}\t{}\t{}\n", - query.id ? *query.id : std::to_string(qid), - mask.to_ulong(), - intersection.length, - intersection.max_score); + if (intersection.length > 0) { + std::cout << fmt::format("{}\t{}\t{}\t{}\n", + query.id ? 
*query.id : std::to_string(qid), + mask.to_ulong(), + intersection.length, + intersection.max_score); + } }; - for (auto const &query : queries) { + for (auto const& query : queries) { if (intersection_type == IntersectionType::Combinations) { for_all_subsets(query, max_term_count, print_intersection); } else { @@ -72,7 +81,7 @@ void intersect(std::string const &index_filename, using wand_raw_index = wand_data; using wand_uniform_index = wand_data; -int main(int argc, const char **argv) +int main(int argc, const char** argv) { spdlog::drop(""); spdlog::set_default_logger(spdlog::stderr_color_mt("")); @@ -97,9 +106,9 @@ int main(int argc, const char **argv) app.add_option("-w,--wand", wand_data_filename, "Wand data filename"); app.add_option("-q,--query", query_filename, "Queries filename"); app.add_flag("--compressed-wand", compressed, "Compressed wand input file"); - auto *terms_opt = app.add_option("--terms", terms_file, "Term lexicon"); + auto* terms_opt = app.add_option("--terms", terms_file, "Term lexicon"); app.add_option("--stemmer", stemmer, "Stemmer type")->needs(terms_opt); - auto *combinations_flag = app.add_flag( + auto* combinations_flag = app.add_flag( "--combinations", combinations, "Compute intersections for combinations of terms in query"); app.add_option("--max-term-count,--mtc", max_term_count, diff --git a/tools/queries.cpp b/tools/queries.cpp index fb85c5606..17d0362e2 100644 --- a/tools/queries.cpp +++ b/tools/queries.cpp @@ -33,15 +33,15 @@ using ranges::views::enumerate; template void extract_times(Fn fn, - std::vector const &queries, - std::vector const &thresholds, - std::string const &index_type, - std::string const &query_type, + std::vector const& queries, + std::vector const& thresholds, + std::string const& index_type, + std::string const& query_type, size_t runs, - std::ostream &os) + std::ostream& os) { std::vector times(runs); - for (auto &&[qid, query] : enumerate(queries)) { + for (auto&& [qid, query] : enumerate(queries)) { 
do_not_optimize_away(fn(query, thresholds[qid])); std::generate(times.begin(), times.end(), [&fn, &q = query, &t = thresholds[qid]]() { return run_with_timer( @@ -56,10 +56,10 @@ void extract_times(Fn fn, template void op_perftest(Functor query_func, - std::vector const &queries, - std::vector const &thresholds, - std::string const &index_type, - std::string const &query_type, + std::vector const& queries, + std::vector const& thresholds, + std::string const& index_type, + std::string const& query_type, size_t runs) { @@ -67,7 +67,7 @@ void op_perftest(Functor query_func, for (size_t run = 0; run <= runs; ++run) { size_t idx = 0; - for (auto const &query : queries) { + for (auto const& query : queries) { auto usecs = run_with_timer([&]() { uint64_t result = query_func(query, thresholds[idx]); do_not_optimize_away(result); @@ -103,14 +103,14 @@ void op_perftest(Functor query_func, } template -void perftest(const std::string &index_filename, - const std::optional &wand_data_filename, - const std::vector &queries, - const std::optional &thresholds_filename, - std::string const &type, - std::string const &query_type, +void perftest(const std::string& index_filename, + const std::optional& wand_data_filename, + const std::vector& queries, + const std::optional& thresholds_filename, + std::string const& type, + std::string const& query_type, uint64_t k, - std::string const &scorer_name, + std::string const& scorer_name, bool extract) { IndexType index; @@ -120,7 +120,7 @@ void perftest(const std::string &index_filename, spdlog::info("Warming up posting lists"); std::unordered_set warmed_up; - for (auto const &q : queries) { + for (auto const& q : queries) { for (auto t : q.terms) { if (!warmed_up.count(t)) { index.warmup(t); @@ -163,7 +163,7 @@ void perftest(const std::string &index_filename, spdlog::info("Performing {} queries", type); spdlog::info("K: {}", k); - for (auto &&t : query_types) { + for (auto&& t : query_types) { spdlog::info("Query type: {}", t); 
std::function query_fun; if (t == "and") { @@ -284,7 +284,7 @@ void perftest(const std::string &index_filename, using wand_raw_index = wand_data; using wand_uniform_index = wand_data; -int main(int argc, const char **argv) +int main(int argc, const char** argv) { std::string type; std::string query_type; @@ -312,7 +312,7 @@ int main(int argc, const char **argv) app.add_flag("--compressed-wand", compressed, "Compressed wand input file"); app.add_option("-k", k, "k value"); app.add_option("-T,--thresholds", thresholds_filename, "k value"); - auto *terms_opt = app.add_option("--terms", terms_file, "Term lexicon"); + auto* terms_opt = app.add_option("--terms", terms_file, "Term lexicon"); app.add_option("--stopwords", stopwords_filename, "File containing stopwords to ignore") ->needs(terms_opt); app.add_option("--stemmer", stemmer, "Stemmer type")->needs(terms_opt); diff --git a/tools/thresholds.cpp b/tools/thresholds.cpp index dc6472277..55e6e634d 100644 --- a/tools/thresholds.cpp +++ b/tools/thresholds.cpp @@ -25,12 +25,12 @@ using namespace pisa; template -void thresholds(const std::string &index_filename, - const std::optional &wand_data_filename, - const std::vector &queries, - const std::optional &thresholds_filename, - std::string const &type, - std::string const &scorer_name, +void thresholds(const std::string& index_filename, + const std::optional& wand_data_filename, + const std::vector& queries, + const std::optional& thresholds_filename, + std::string const& type, + std::string const& scorer_name, uint64_t k) { IndexType index; @@ -51,15 +51,16 @@ void thresholds(const std::string &index_filename, } mapper::map(wdata, md, mapper::map_flags::warmup); } + topk_queue topk(k); wand_query wand_q(topk); - for (auto const &query : queries) { + for (auto const& query : queries) { wand_q(make_max_scored_cursors(index, wdata, *scorer, query), index.num_docs()); topk.finalize(); auto results = topk.topk(); topk.clear(); float threshold = 0.0; - if (results.size() == 
k) { + if (not results.empty()) { threshold = results.back().first; } std::cout << threshold << '\n'; @@ -69,7 +70,7 @@ void thresholds(const std::string &index_filename, using wand_raw_index = wand_data; using wand_uniform_index = wand_data; -int main(int argc, const char **argv) +int main(int argc, const char** argv) { spdlog::drop(""); spdlog::set_default_logger(spdlog::stderr_color_mt("")); @@ -95,7 +96,7 @@ int main(int argc, const char **argv) app.add_option("-s,--scorer", scorer_name, "Scorer function")->required(); app.add_flag("--compressed-wand", compressed, "Compressed wand input file"); app.add_option("-k", k, "k value"); - auto *terms_opt = + auto* terms_opt = app.add_option("--terms", terms_file, "Text file with terms in separate lines"); app.add_option("--stemmer", stemmer, "Stemmer type")->needs(terms_opt); CLI11_PARSE(app, argc, argv); @@ -131,7 +132,7 @@ int main(int argc, const char **argv) type, \ scorer_name, \ k); \ - } \ + } /**/ BOOST_PP_SEQ_FOR_EACH(LOOP_BODY, _, PISA_INDEX_TYPES); diff --git a/v1/CMakeLists.txt b/v1/CMakeLists.txt new file mode 100644 index 000000000..6bb640447 --- /dev/null +++ b/v1/CMakeLists.txt @@ -0,0 +1,38 @@ +add_executable(compress compress.cpp) +target_link_libraries(compress pisa CLI11) + +add_executable(query query.cpp) +target_link_libraries(query pisa CLI11) + +add_executable(postings postings.cpp) +target_link_libraries(postings pisa CLI11) + +add_executable(score score.cpp) +target_link_libraries(score pisa CLI11) + +add_executable(bmscore bmscore.cpp) +target_link_libraries(bmscore pisa CLI11) + +add_executable(bigram-index bigram_index.cpp) +target_link_libraries(bigram-index pisa CLI11) + +add_executable(filter-queries filter_queries.cpp) +target_link_libraries(filter-queries pisa CLI11) + +add_executable(threshold threshold.cpp) +target_link_libraries(threshold pisa CLI11) + +add_executable(intersection intersection.cpp) +target_link_libraries(intersection pisa CLI11) + +add_executable(select-pairs 
select_pairs.cpp) +target_link_libraries(select-pairs pisa CLI11) + +add_executable(count-postings count_postings.cpp) +target_link_libraries(count-postings pisa CLI11) + +add_executable(stats stats.cpp) +target_link_libraries(stats pisa CLI11) + +add_executable(id-to-term id_to_term.cpp) +target_link_libraries(id-to-term pisa CLI11) diff --git a/v1/app.hpp b/v1/app.hpp new file mode 100644 index 000000000..0824fa76c --- /dev/null +++ b/v1/app.hpp @@ -0,0 +1,189 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +#include "io.hpp" +#include "v1/index_metadata.hpp" +#include "v1/runtime_assert.hpp" + +namespace pisa { + +namespace arg { + + struct Index { + explicit Index(CLI::App* app) + { + app->add_option("-i,--index", + m_metadata_path, + "Path of .yml file of an index " + "(if not provided, it will be looked for in the current directory)", + false); + } + + [[nodiscard]] auto index_metadata() const -> v1::IndexMetadata + { + return v1::IndexMetadata::from_file(v1::resolve_yml(m_metadata_path)); + } + + private: + tl::optional m_metadata_path; + }; + + enum class QueryMode : bool { Ranked, Unranked }; + + template + struct Query { + explicit Query(CLI::App* app) + { + app->add_option("-q,--query", m_query_file, "Path to file with queries", false); + app->add_option("--qf,--query-fmt", m_query_input_format, "Input file format", true); + if constexpr (Mode == QueryMode::Ranked) { + app->add_option("-k", m_k, "The number of top results to return", true); + } + app->add_flag("--force-parse", + m_force_parse, + "Force parsing of query string even ifterm IDs already available"); + app->add_option( + "--stopwords", m_stop_words, "List of blacklisted stop words to filter out"); + } + + [[nodiscard]] auto query_file() -> tl::optional + { + if (m_query_file) { + return m_query_file.value(); + } + return tl::nullopt; + } + + [[nodiscard]] auto queries(v1::IndexMetadata const& meta) const -> std::vector + { + std::vector queries; + 
auto parser = meta.query_parser(m_stop_words); + auto parse_line = [&](auto&& line) { + auto query = [&line, this]() { + if (m_query_input_format == "jl") { + return v1::Query::from_json(line); + } + return v1::Query::from_plain(line); + }(); + if (not query.term_ids() || m_force_parse) { + query.parse(parser); + } + if constexpr (Mode == QueryMode::Ranked) { + query.k(m_k); + } + queries.push_back(std::move(query)); + }; + if (m_query_file) { + std::ifstream is(*m_query_file); + pisa::io::for_each_line(is, parse_line); + } else { + pisa::io::for_each_line(std::cin, parse_line); + } + return queries; + } + + [[nodiscard]] auto query_range(v1::IndexMetadata const& meta) + { + auto lines = [&] { + if (m_query_file) { + m_query_file_handle = std::make_unique(*m_query_file); + return ranges::getlines(*m_query_file_handle); + } + return ranges::getlines(std::cin); + }(); + return ranges::views::transform(lines, + [force_parse = m_force_parse, + k = m_k, + parser = meta.query_parser(m_stop_words), + qfmt = m_query_input_format](auto&& line) { + auto query = [&]() { + if (qfmt == "jl") { + return v1::Query::from_json(line); + } + if (qfmt == "plain") { + return v1::Query::from_plain(line); + } + spdlog::error("Unknown query format: {}", + qfmt); + std::exit(1); + }(); + if (not query.term_ids() || force_parse) { + query.parse(parser); + } + // Not constexpr to silence unused k value warning. + // Performance is not a concern. 
+ if (Mode == QueryMode::Ranked) { + query.k(k); + } + return query; + }); + } + + private: + std::unique_ptr m_query_file_handle = nullptr; + tl::optional m_query_file; + std::string m_query_input_format = "jl"; + int m_k = DefaultK; + bool m_force_parse{false}; + tl::optional m_stop_words{tl::nullopt}; + }; + + struct Benchmark { + explicit Benchmark(CLI::App* app) + { + app->add_flag("--benchmark", m_is_benchmark, "Run benchmark"); + } + + [[nodiscard]] auto is_benchmark() const -> bool { return m_is_benchmark; } + + private: + bool m_is_benchmark = false; + }; + + struct QuantizedScores { + explicit QuantizedScores(CLI::App* app) + { + app->add_flag("--quantized", m_use_quantized, "Use quantized scores"); + } + + [[nodiscard]] auto use_quantized() const -> bool { return m_use_quantized; } + + private: + bool m_use_quantized = false; + }; + + struct Threads { + explicit Threads(CLI::App* app) + { + app->add_option("-j,--threads", m_threads, "Number of threads"); + } + + [[nodiscard]] auto threads() const -> std::size_t { return m_threads; } + + private: + std::size_t m_threads = std::thread::hardware_concurrency(); + }; + +} // namespace arg + +template +struct App : public CLI::App, public Mixin... { + explicit App(std::string const& description) : CLI::App(description), Mixin(this)... 
{} +}; + +struct QueryApp : public App, + arg::Benchmark, + arg::QuantizedScores> { + explicit QueryApp(std::string const& description) : App(description) {} +}; + +} // namespace pisa diff --git a/v1/bigram_index.cpp b/v1/bigram_index.cpp new file mode 100644 index 000000000..0355d8f35 --- /dev/null +++ b/v1/bigram_index.cpp @@ -0,0 +1,44 @@ +#include +#include + +#include +#include +#include + +#include "app.hpp" +#include "v1/index_builder.hpp" +#include "v1/progress_status.hpp" +#include "v1/types.hpp" + +using pisa::App; +using pisa::v1::build_pair_index; +using pisa::v1::collect_unique_bigrams; +using pisa::v1::DefaultProgressCallback; +using pisa::v1::ProgressStatus; + +namespace arg = pisa::arg; + +int main(int argc, char** argv) +{ + tl::optional clone_path{}; + + App, arg::Threads> app{ + "Creates a v1 bigram index."}; + app.add_option("--clone", clone_path, "Instead", false); + CLI11_PARSE(app, argc, argv); + + auto meta = app.index_metadata(); + + spdlog::info("Collecting queries..."); + auto queries = app.queries(meta); + + spdlog::info("Collected {} queries", queries.size()); + spdlog::info("Collecting bigrams..."); + ProgressStatus status( + queries.size(), DefaultProgressCallback{}, std::chrono::milliseconds(1000)); + auto bigrams = collect_unique_bigrams(queries, [&]() { status += 1; }); + status.close(); + spdlog::info("Collected {} bigrams", bigrams.size()); + build_pair_index(meta, bigrams, clone_path, app.threads()); + return 0; +} diff --git a/v1/bmscore.cpp b/v1/bmscore.cpp new file mode 100644 index 000000000..78e184969 --- /dev/null +++ b/v1/bmscore.cpp @@ -0,0 +1,43 @@ +#include +#include +#include + +#include + +#include "app.hpp" +#include "v1/score_index.hpp" + +using pisa::App; +using pisa::v1::BlockType; +using pisa::v1::FixedBlock; +using pisa::v1::VariableBlock; +namespace arg = pisa::arg; + +int main(int argc, char** argv) +{ + std::optional yml{}; + std::size_t block_size = 128; + std::size_t threads = 
std::thread::hardware_concurrency(); + std::optional lambda{}; + tl::optional clone_path{}; + + App app{"Constructs block-max score lists for v1 index."}; + app.add_option("--block-size", block_size, "The size of a block for max scores", true); + app.add_option("--variable-blocks", lambda, "The size of a block for max scores", false); + app.add_option( + "--clone", + clone_path, + "Clone .yml metadata to another path, and then score (won't affect the initial index)", + false); + CLI11_PARSE(app, argc, argv); + + auto block_type = [&]() -> BlockType { + if (lambda) { + return VariableBlock{*lambda}; + } + return FixedBlock{block_size}; + }(); + + pisa::v1::bm_score_index(app.index_metadata(), block_type, clone_path, app.threads()); + return 0; +} diff --git a/v1/compress.cpp b/v1/compress.cpp new file mode 100644 index 000000000..1d2521a1b --- /dev/null +++ b/v1/compress.cpp @@ -0,0 +1,110 @@ +#include + +#include +#include +#include + +#include "sequence/partitioned_sequence.hpp" +#include "v1/bit_sequence_cursor.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/sequence/partitioned_sequence.hpp" +#include "v1/sequence/positive_sequence.hpp" +#include "v1/types.hpp" + +using std::literals::string_view_literals::operator""sv; + +using pisa::v1::compress_binary_collection; +using pisa::v1::DocumentBitSequenceWriter; +using pisa::v1::DocumentBlockedWriter; +using pisa::v1::EncodingId; +using pisa::v1::make_index_builder; +using pisa::v1::PartitionedSequence; +using pisa::v1::PayloadBitSequenceWriter; +using pisa::v1::PayloadBlockedWriter; +using pisa::v1::PositiveSequence; +using pisa::v1::RawWriter; +using pisa::v1::verify_compressed_index; + +auto document_encoding(std::string_view name) -> std::uint32_t +{ + if (name == "raw"sv) { + return EncodingId::Raw; + } + if (name == "simdbp"sv) { + return EncodingId::BlockDelta | EncodingId::SimdBP; + } + if (name == "pef"sv) 
{
+        return EncodingId::BitSequence | EncodingId::PEF;
+    }
+    spdlog::error("Unknown encoding: {}", name);
+    std::exit(1);
+}
+
+/// Maps an encoding name to the frequency-payload encoding ID; exits on unknown names.
+auto frequency_encoding(std::string_view name) -> std::uint32_t
+{
+    if (name == "raw"sv) {
+        return EncodingId::Raw;
+    }
+    if (name == "simdbp"sv) {
+        return EncodingId::Block | EncodingId::SimdBP;
+    }
+    if (name == "pef"sv) {
+        return EncodingId::BitSequence | EncodingId::PositiveSeq;
+    }
+    spdlog::error("Unknown encoding: {}", name);
+    std::exit(1);
+}
+
+int main(int argc, char** argv)
+{
+    std::string input;
+    std::string fwd;
+    std::string output;
+    std::string encoding;
+    std::size_t threads = std::thread::hardware_concurrency();
+
+    CLI::App app{"Compresses a given binary collection to a v1 index."};
+    app.add_option("-i,--inv", input, "Input collection basename")->required();
+    // TODO(michal): Potentially, this would be removed once inv contains necessary info.
+    app.add_option("-f,--fwd", fwd, "Input forward index")->required();
+    app.add_option("-o,--output", output, "Output basename")->required();
+    app.add_option("-j,--threads", threads, "Number of threads");
+    // Fixed copy-pasted help text (previously said "Number of threads").
+    app.add_option("-e,--encoding", encoding, "Encoding type")->required();
+    CLI11_PARSE(app, argc, argv);
+
+    tbb::task_scheduler_init init(threads);
+    // NOTE(review): writer template arguments below were partially lost in this hunk;
+    // reconstructed from the stripping residue -- verify against upstream.
+    auto build =
+        make_index_builder(std::make_tuple(RawWriter<std::uint32_t>{},
+                                           DocumentBlockedWriter<::pisa::simdbp_block>{},
+                                           DocumentBitSequenceWriter<PartitionedSequence<>>{}),
+                           std::make_tuple(RawWriter<std::uint32_t>{},
+                                           PayloadBlockedWriter<::pisa::simdbp_block>{},
+                                           PayloadBitSequenceWriter<PositiveSequence<>>{}));
+    // Dispatch on the resolved encoding IDs, then compress with the selected writers.
+    build(document_encoding(encoding),
+          frequency_encoding(encoding),
+          [&](auto document_writer, auto payload_writer) {
+              compress_binary_collection(input,
+                                         fwd,
+                                         output,
+                                         threads,
+                                         make_writer(std::move(document_writer)),
+                                         make_writer(std::move(payload_writer)));
+          });
+    auto errors = verify_compressed_index(input, output);
+    if (not errors.empty()) {
+        if (errors.size() > 10) {
+            std::cerr << "Detected more than 10 errors, printing head:\n";
+
errors.resize(10); + } + for (auto const& error : errors) { + std::cerr << error << '\n'; + } + return 1; + } + + std::cout << "Success."; + return 0; +} diff --git a/v1/count_postings.cpp b/v1/count_postings.cpp new file mode 100644 index 000000000..b52462ecb --- /dev/null +++ b/v1/count_postings.cpp @@ -0,0 +1,59 @@ +#include + +#include +#include +#include +#include + +#include "app.hpp" +#include "v1/default_index_runner.hpp" +#include "v1/progress_status.hpp" + +namespace arg = pisa::arg; + +int main(int argc, char** argv) +{ + spdlog::drop(""); + spdlog::set_default_logger(spdlog::stderr_color_mt("")); + + bool pair_index = false; + bool term_by_term = false; + + pisa::App app("Simply counts all postings in the index"); + auto* pairs_opt = + app.add_flag("--pairs", pair_index, "Count postings in the pair index instead"); + app.add_flag("-t,--terms", term_by_term, "Print posting counts for each term in the index") + ->excludes(pairs_opt); + CLI11_PARSE(app, argc, argv); + + auto meta = app.index_metadata(); + std::size_t count{0}; + pisa::v1::index_runner(meta)([&](auto&& index) { + if (pair_index) { + pisa::v1::ProgressStatus status( + index.pairs()->size(), + pisa::v1::DefaultProgressCallback("Counting pair postings"), + std::chrono::milliseconds(500)); + for (auto term_pair : index.pairs().value()) { + count += + index.bigram_cursor(std::get<0>(term_pair), std::get<1>(term_pair))->size(); + status += 1; + } + } else if (term_by_term) { + for (pisa::v1::TermId id = 0; id < index.num_terms(); id += 1) { + std::cout << index.term_posting_count(id) << '\n'; + } + } else { + pisa::v1::ProgressStatus status( + index.num_terms(), + pisa::v1::DefaultProgressCallback("Counting term postings"), + std::chrono::milliseconds(500)); + for (pisa::v1::TermId id = 0; id < index.num_terms(); id += 1) { + count += index.term_posting_count(id); + status += 1; + } + } + }); + std::cout << count << '\n'; + return 0; +} diff --git a/v1/filter_queries.cpp b/v1/filter_queries.cpp 
new file mode 100644
index 000000000..444ee43ad
--- /dev/null
+++ b/v1/filter_queries.cpp
@@ -0,0 +1,33 @@
+// NOTE(review): angle-bracket #include targets were lost in this hunk; reconstructed
+// below -- verify against upstream before applying.
+#include <limits>
+
+#include <CLI/CLI.hpp>
+#include <spdlog/sinks/stdout_color_sinks.h>
+#include <spdlog/spdlog.h>
+#include <tl/optional.hpp>
+
+#include "app.hpp"
+
+namespace arg = pisa::arg;
+
+int main(int argc, char** argv)
+{
+    // Log to stderr so that the filtered queries on stdout stay clean.
+    spdlog::drop("");
+    spdlog::set_default_logger(spdlog::stderr_color_mt(""));
+
+    std::size_t min_query_len = 1;
+    std::size_t max_query_len = std::numeric_limits<std::size_t>::max();
+
+    pisa::App<arg::Index, arg::Query<arg::QueryMode::Unranked>> app(
+        "Filters out empty queries against a v1 index.");
+    // Fixed typos in help text ("legth" -> "length").
+    app.add_option("--min", min_query_len, "Minimum query length to consider");
+    app.add_option("--max", max_query_len, "Maximum query length to consider");
+    CLI11_PARSE(app, argc, argv);
+
+    auto meta = app.index_metadata();
+    for (auto&& query : app.query_range(meta)) {
+        if (auto len = query.get_term_ids().size(); len >= min_query_len && len <= max_query_len) {
+            std::cout << *query.to_json() << '\n';
+        }
+    }
+    return 0;
+}
diff --git a/v1/id_to_term.cpp b/v1/id_to_term.cpp
new file mode 100644
index 000000000..a13bf36e5
--- /dev/null
+++ b/v1/id_to_term.cpp
@@ -0,0 +1,44 @@
+#include <iostream>
+#include <sstream>
+
+#include <fmt/format.h>
+#include <mio/mmap.hpp>
+#include <spdlog/sinks/stdout_color_sinks.h>
+
+#include "app.hpp"
+#include "io.hpp"
+#include "payload_vector.hpp"
+
+namespace arg = pisa::arg;
+using pisa::Payload_Vector;
+
+int main(int argc, char** argv)
+{
+    spdlog::drop("");
+    spdlog::set_default_logger(spdlog::stderr_color_mt(""));
+
+    pisa::App<arg::Index> app("Each ID from input translated to term");
+    CLI11_PARSE(app, argc, argv);
+
+    auto meta = app.index_metadata();
+    if (not meta.term_lexicon) {
+        spdlog::error("Term lexicon not defined");
+        std::exit(1);
+    }
+    // The lexicon file is memory-mapped; the shared_ptr keeps the mapping alive
+    // for the lifetime of `lex`.
+    auto source = std::make_shared<mio::mmap_source>(meta.term_lexicon.value().c_str());
+    auto lex = pisa::Payload_Vector<>::from(*source);
+
+    // Each input line is a whitespace-separated list of term IDs; print the
+    // corresponding terms, space-separated, one line per input line.
+    pisa::io::for_each_line(std::cin, [&](auto&& line) {
+        std::istringstream is(line);
+        std::string next;
+        if (is >> next) {
+            std::cout << lex[std::stoi(next)];
+        }
+        while (is >> next) {
+            std::cout << fmt::format(" {}", lex[std::stoi(next)]);
+        }
+        std::cout << '\n';
}); + + return 0; +} diff --git a/v1/intersection.cpp b/v1/intersection.cpp new file mode 100644 index 000000000..27a80e89d --- /dev/null +++ b/v1/intersection.cpp @@ -0,0 +1,257 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "app.hpp" +#include "intersection.hpp" +#include "query/queries.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/cursor/accumulate.hpp" +#include "v1/cursor/for_each.hpp" +#include "v1/cursor_intersection.hpp" +#include "v1/default_index_runner.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/scorer/bm25.hpp" + +using pisa::App; +using pisa::Intersection; +using pisa::intersection::IntersectionType; +using pisa::intersection::Mask; +using pisa::v1::intersect; +using pisa::v1::make_bm25; +using pisa::v1::RawReader; +using pisa::v1::TermId; + +namespace arg = pisa::arg; +namespace v1 = pisa::v1; + +template +auto compute_intersection(Index const& index, + pisa::v1::Query const& query, + tl::optional> term_selection) -> Intersection +{ + auto const term_ids = + term_selection ? query.filtered_terms(*term_selection) : query.get_term_ids(); + if (term_ids.size() == 1) { + auto cursor = index.max_scored_cursor(term_ids[0], make_bm25(index)); + return Intersection{cursor.size(), cursor.max_score()}; + } + auto cursors = index.scored_cursors(gsl::make_span(term_ids), make_bm25(index)); + auto intersection = + intersect(cursors, 0.0F, [](auto score, auto&& cursor, [[maybe_unused]] auto idx) { + return score + cursor.payload(); + }); + std::size_t postings = 0; + float max_score = 0.0F; + v1::for_each(intersection, [&](auto& cursor) { + postings += 1; + if (auto score = cursor.payload(); score > max_score) { + max_score = score; + } + }); + return Intersection{postings, max_score}; +} + +/// Do `func` for all intersections in a query that have a given maximum number of terms. +/// `Fn` takes `Query` and `Mask`. 
+template +auto for_all_subsets(v1::Query const& query, tl::optional max_term_count, Fn func) +{ + auto&& term_ids = query.get_term_ids(); + auto subset_count = 1U << term_ids.size(); + for (auto subset = 1U; subset < subset_count; ++subset) { + auto mask = std::bitset<64>(subset); + if (!max_term_count || mask.count() <= *max_term_count) { + func(query, mask); + } + } +} + +template +void compute_intersections(Index const& index, + QRng queries, + IntersectionType intersection_type, + tl::optional max_term_count, + bool existing, + tl::optional>> const& in_set) +{ + for (auto const& query : queries) { + auto intersections = nlohmann::json::array(); + auto inter = [&](auto&& query, tl::optional> const& mask) { + auto intersection = compute_intersection(index, query, mask); + if (intersection.length > 0) { + intersections.push_back( + nlohmann::json{{"intersection", mask.value_or(0).to_ulong()}, + {"cost", intersection.length}, + {"max_score", intersection.max_score}}); + } + }; + if (intersection_type == IntersectionType::Combinations) { + if (in_set) { + std::uint64_t left_mask = 1; + auto term_ids = query.get_term_ids(); + for (auto left = 0; left < term_ids.size(); left += 1) { + auto cursor = index.max_scored_cursor(term_ids[left], make_bm25(index)); + intersections.push_back(nlohmann::json{{"intersection", left_mask}, + {"cost", cursor.size()}, + {"max_score", cursor.max_score()}}); + std::uint64_t right_mask = left_mask << 1U; + for (auto right = left + 1; right < term_ids.size(); right += 1) { + if (auto bid = index.bigram_id(term_ids[left], term_ids[right]); bid) { + if (in_set->find(std::make_pair(term_ids[left], term_ids[right])) + == in_set->end()) { + continue; + } + std::vector const terms{term_ids[left], term_ids[right]}; + auto cursors = + index.scored_cursors(gsl::make_span(terms), make_bm25(index)); + auto intersection = + intersect(cursors, + 0.0F, + [](auto score, auto&& cursor, [[maybe_unused]] auto idx) { + return score + cursor.payload(); + 
}); + std::size_t postings = 0; + float max_score = 0.0F; + v1::for_each(intersection, [&](auto& cursor) { + postings += 1; + if (auto score = cursor.payload(); score > max_score) { + max_score = score; + } + }); + intersections.push_back( + nlohmann::json{{"intersection", left_mask | right_mask}, + {"cost", postings}, + {"max_score", max_score}}); + } + right_mask <<= 1U; + } + left_mask <<= 1U; + } + } else if (existing) { + std::uint64_t left_mask = 1; + auto term_ids = query.get_term_ids(); + for (auto left = 0; left < term_ids.size(); left += 1) { + auto cursor = index.max_scored_cursor(term_ids[left], make_bm25(index)); + intersections.push_back(nlohmann::json{{"intersection", left_mask}, + {"cost", cursor.size()}, + {"max_score", cursor.max_score()}}); + std::uint64_t right_mask = left_mask << 1U; + for (auto right = left + 1; right < term_ids.size(); right += 1) { + if (auto bid = index.bigram_id(term_ids[left], term_ids[right]); bid) { + std::vector const terms{term_ids[left], term_ids[right]}; + auto cursors = + index.scored_cursors(gsl::make_span(terms), make_bm25(index)); + auto intersection = + intersect(cursors, + 0.0F, + [](auto score, auto&& cursor, [[maybe_unused]] auto idx) { + return score + cursor.payload(); + }); + std::size_t postings = 0; + float max_score = 0.0F; + v1::for_each(intersection, [&](auto& cursor) { + postings += 1; + if (auto score = cursor.payload(); score > max_score) { + max_score = score; + } + }); + intersections.push_back( + nlohmann::json{{"intersection", left_mask | right_mask}, + {"cost", postings}, + {"max_score", max_score}}); + } + // index + // .scored_bigram_cursor(term_ids[left], term_ids[right], + // make_bm25(index)) .map([&](auto&& cursor) { + // // TODO(michal): Do not traverse once max scores for bigrams are + // // implemented + // auto cost = cursor.size(); + // auto max_score = + // accumulate(cursor, 0.0F, [](float acc, auto&& cursor) { + // auto score = std::get<0>(cursor.payload()) + // + 
std::get<1>(cursor.payload()); + // return std::max(acc, score); + // }); + // intersections.push_back( + // nlohmann::json{{"intersection", left_mask | right_mask}, + // {"cost", cost}, + // {"max_score", max_score}}); + // }); + right_mask <<= 1U; + } + left_mask <<= 1U; + } + } else { + for_all_subsets(query, max_term_count, inter); + } + } else { + inter(query, tl::nullopt); + } + auto output = *query.to_json(); + output["intersections"] = intersections; + std::cout << output << '\n'; + } +} + +int main(int argc, const char** argv) +{ + spdlog::drop(""); + spdlog::set_default_logger(spdlog::stderr_color_mt("")); + + bool combinations = false; + bool existing = false; + std::optional max_term_count; + std::optional in_set_path; + + pisa::App> app( + "Calculates intersections for a v1 index."); + auto* combinations_flag = app.add_flag( + "--combinations", combinations, "Compute intersections for combinations of terms in query"); + auto* mtc_flag = app.add_option("--max-term-count,--mtc", + max_term_count, + "Max number of terms when computing combinations"); + mtc_flag->needs(combinations_flag); + auto* existing_flag = app.add_flag("--existing", existing, "Use only existing bigrams") + ->needs(combinations_flag) + ->excludes(mtc_flag); + app.add_option("--in", in_set_path, "Use only bigrams from this list") + ->needs(combinations_flag) + ->excludes(mtc_flag) + ->excludes(existing_flag); + CLI11_PARSE(app, argc, argv); + auto mtc = max_term_count ? tl::make_optional(*max_term_count) : tl::optional{}; + tl::optional>> in_set{}; + if (in_set_path) { + in_set = std::set>{}; + std::ifstream is(*in_set_path); + std::string left, right; + while (is >> left >> right) { + in_set->emplace(std::stoi(left), std::stoi(right)); + } + } + + IntersectionType intersection_type = + combinations ? 
IntersectionType::Combinations : IntersectionType::Query; + + try { + auto meta = app.index_metadata(); + auto queries = app.query_range(meta); + + auto run = index_runner(meta); + run([&](auto&& index) { + compute_intersections(index, queries, intersection_type, mtc, existing, in_set); + }); + } catch (std::exception const& error) { + spdlog::error("{}", error.what()); + } + return 0; +} diff --git a/v1/postings.cpp b/v1/postings.cpp new file mode 100644 index 000000000..221d65162 --- /dev/null +++ b/v1/postings.cpp @@ -0,0 +1,301 @@ +#include +#include +#include +#include + +#include +#include + +#include "app.hpp" +#include "io.hpp" +#include "query/queries.hpp" +#include "topk_queue.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/cursor_accumulator.hpp" +#include "v1/cursor_intersection.hpp" +#include "v1/default_index_runner.hpp" +#include "v1/index_metadata.hpp" +#include "v1/query.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/runtime_assert.hpp" +#include "v1/scorer/bm25.hpp" +#include "v1/scorer/runner.hpp" +#include "v1/types.hpp" + +using pisa::App; +using pisa::Query; +using pisa::resolve_query_parser; +using pisa::v1::collect_payloads; +using pisa::v1::index_runner; +using pisa::v1::runtime_assert; + +namespace arg = pisa::arg; + +[[nodiscard]] auto load_source(std::optional const& file) + -> std::shared_ptr +{ + if (file) { + return std::make_shared(file->c_str()); + } + return nullptr; +} + +[[nodiscard]] auto load_payload_vector(std::shared_ptr const& source) + -> std::optional> +{ + if (source) { + return pisa::Payload_Vector<>::from(*source); + } + return std::nullopt; +} + +/// Returns the first value (not nullopt), or nullopt if no optional contains a value. +template +[[nodiscard]] auto value(First&& first, Optional&&... 
candidtes) +{ + std::optional> val = std::nullopt; + auto has_value = [&](auto&& opt) -> bool { + if (not val.has_value() && opt) { + val = *opt; + return true; + } + return false; + }; + has_value(first) || (has_value(candidtes) || ...); + return val; +} + +void print_header(std::vector const& percentiles, std::vector const& cutoffs) +{ + std::cout << "length"; + for (auto cutoff : cutoffs) { + std::cout << '\t' << "top-" << cutoff + 1; + } + for (auto percentile : percentiles) { + std::cout << '\t' << "perc-" << percentile; + } + std::cout << '\n'; +} + +template +void calc_stats(Cursor&& cursor, + std::vector const& percentiles, + std::vector const& cutoffs) +{ + using payload_type = std::decay_t; + auto payloads = collect_payloads(cursor); + auto length = payloads.size(); + std::sort(payloads.begin(), payloads.end(), std::greater<>{}); + auto kth = [&](auto k) { + if (k < payloads.size()) { + return payloads[k]; + } + return payload_type{}; + }; + std::cout << length; + for (auto cutoff : cutoffs) { + std::cout << '\t' << kth(cutoff); + } + for (auto percentile : percentiles) { + std::cout << '\t' + << payloads[std::min(percentile * payloads.size() / 100, payloads.size() - 1)]; + } + std::cout << '\n'; +} + +template +auto print_postings(Cursor&& cursor, + Scorer&& scorer, + std::optional> const& docmap, + bool did, + bool print_frequencies, + bool print_scores) +{ + auto print = [&](auto&& cursor) { + if (did) { + std::cout << *cursor; + } else { + std::cout << docmap.value()[*cursor]; + } + if (print_frequencies) { + std::cout << " " << cursor.payload(); + } + if (print_scores) { + std::cout << " " << scorer(cursor.value(), cursor.payload()); + } + std::cout << '\n'; + }; + for_each(cursor, print); +}; + +template +auto print_precomputed_postings(Cursor&& cursor, + std::optional> const& docmap, + bool did) +{ + auto print = [&](auto&& cursor) { + if (did) { + std::cout << *cursor; + } else { + std::cout << docmap.value()[*cursor]; + } + if constexpr 
(std::is_same_v) { + std::cout << " " << static_cast(cursor.payload()) << '\n'; + } else { + std::cout << " " << cursor.payload() << '\n'; + } + }; + for_each(std::forward(cursor), print); +}; + +int main(int argc, char** argv) +{ + std::vector percentiles = {0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}; + std::vector cutoffs = { + 0, 9, 99, 999, 9'999, 99'999, 999'999, 9'999'999, 99'999'999}; + std::optional terms_file{}; + std::optional documents_file{}; + std::string query_input{}; + bool tid = false; + bool did = false; + bool print_frequencies = false; + bool print_scores = false; + bool precomputed = false; + bool stats = false; + bool header = false; + + App app{"Queries a v1 index."}; + app.add_option("--terms", terms_file, "Overrides document lexicon from .yml (if defined)."); + app.add_option("--documents", + documents_file, + "Overrides document lexicon from .yml (if defined). Required otherwise."); + app.add_flag("--tid", tid, "Use term IDs instead of terms"); + app.add_flag("--did", did, "Print document IDs instead of titles"); + app.add_flag("-f,--frequencies", print_frequencies, "Print frequencies"); + auto* scores_option = app.add_flag("-s,--scores", print_scores, "Print BM25 scores"); + app.add_flag("--precomputed", precomputed, "Use BM25 precomputed scores")->needs(scores_option); + + auto* stats_option = app.add_flag("--stats", stats, "Print stats instead of listing postings"); + app.add_option("--percentiles", percentiles, "Percentiles for stats", true) + ->needs(stats_option); + app.add_option("--cutoffs", cutoffs, "Cut-offs for stats", true)->needs(stats_option); + app.add_flag("--header", header, "Print stats header")->needs(stats_option); + + app.add_option("query", query_input, "List of terms", false)->required(); + CLI11_PARSE(app, argc, argv); + + auto meta = app.index_metadata(); + auto stemmer = meta.stemmer ? 
std::make_optional(*meta.stemmer) : std::optional{}; + if (tid) { + terms_file = std::nullopt; + } else { + terms_file = value(meta.term_lexicon, terms_file); + } + documents_file = value(meta.document_lexicon, documents_file); + + if (not did and not documents_file) { + spdlog::error("Document lexicon not defined"); + std::exit(1); + } + + if (not tid and not terms_file) { + spdlog::error("Term lexicon not defined"); + std::exit(1); + } + + std::shared_ptr const source = load_source(documents_file); + std::optional> const docmap = load_payload_vector(source); + + auto const query = [&]() { + std::vector queries; + auto parse_query = resolve_query_parser(queries, terms_file, std::nullopt, stemmer); + parse_query(query_input); + return queries[0]; + }(); + + if (header) { + print_header(percentiles, cutoffs); + } + + if (stats) { + for (auto percentile : percentiles) { + pisa::v1::runtime_assert(percentile >= 0 && percentile <= 100) + .or_exit("Percentiles must be in [0, 100]"); + } + } + + if (query.terms.size() == 1) { + if (precomputed) { + auto run = scored_index_runner(meta); + run([&](auto&& index) { + auto cursor = index.cursor(query.terms.front()); + if (stats) { + calc_stats(cursor, percentiles, cutoffs); + } else { + print_precomputed_postings(cursor, docmap, did); + } + }); + } else { + auto run = index_runner(meta); + run([&](auto&& index) { + auto bm25 = make_bm25(index); + if (stats) { + calc_stats( + index.scored_cursor(query.terms.front(), bm25), percentiles, cutoffs); + } else { + auto scorer = bm25.term_scorer(query.terms.front()); + print_postings(index.cursor(query.terms.front()), + scorer, + docmap, + did, + print_frequencies, + print_scores); + } + }); + } + } else { + if (precomputed) { + auto run = scored_index_runner(meta); + run([&](auto&& index) { + auto cursor = + ::pisa::v1::intersect(index.cursors(gsl::make_span(query.terms)), + 0.0, + [](auto acc, auto&& cursor, [[maybe_unused]] auto idx) { + return acc + cursor.payload(); + }); + if 
(stats) { + calc_stats(cursor, percentiles, cutoffs); + } else { + print_precomputed_postings(cursor, docmap, did); + } + }); + } else { + auto run = index_runner(meta); + run([&](auto&& index) { + auto bm25 = make_bm25(index); + if (stats) { + auto cursor = ::pisa::v1::intersect( + index.scored_cursors(gsl::make_span(query.terms), bm25), + 0.0, + [](auto acc, auto&& cursor, [[maybe_unused]] auto idx) { + return acc + cursor.payload(); + }); + calc_stats(cursor, percentiles, cutoffs); + } else { + runtime_assert(query.terms.size() == 1) + .or_exit("Printing scoring intersections not supported yet."); + auto scorer = bm25.term_scorer(query.terms.front()); + print_postings(index.cursor(query.terms.front()), + scorer, + docmap, + did, + print_frequencies, + print_scores); + } + }); + } + } + + return 0; +} diff --git a/v1/query.cpp b/v1/query.cpp new file mode 100644 index 000000000..90839f49a --- /dev/null +++ b/v1/query.cpp @@ -0,0 +1,385 @@ +#include +#include +#include + +#include +#include +#include + +#include "app.hpp" +#include "query/queries.hpp" +#include "timer.hpp" +#include "topk_queue.hpp" +#include "util/do_not_optimize_away.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/daat_or.hpp" +#include "v1/default_index_runner.hpp" +#include "v1/index_metadata.hpp" +#include "v1/inspect_query.hpp" +#include "v1/maxscore.hpp" +#include "v1/maxscore_union_lookup.hpp" +#include "v1/query.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/scorer/bm25.hpp" +#include "v1/scorer/runner.hpp" +#include "v1/types.hpp" +#include "v1/unigram_union_lookup.hpp" +#include "v1/union_lookup.hpp" +#include "v1/wand.hpp" + +using pisa::v1::daat_or; +using pisa::v1::DocumentBlockedReader; +using pisa::v1::index_runner; +using pisa::v1::InspectDaatOr; +using pisa::v1::InspectLookupUnion; +using pisa::v1::InspectLookupUnionEaat; +using pisa::v1::InspectMaxScore; +using pisa::v1::InspectMaxScoreUnionLookup; +using pisa::v1::InspectUnigramUnionLookup; +using 
pisa::v1::InspectUnionLookup; +using pisa::v1::InspectUnionLookupPlus; +using pisa::v1::lookup_union; +using pisa::v1::maxscore_union_lookup; +using pisa::v1::PayloadBlockedReader; +using pisa::v1::Query; +using pisa::v1::QueryInspector; +using pisa::v1::RawReader; +using pisa::v1::unigram_union_lookup; +using pisa::v1::union_lookup; +using pisa::v1::union_lookup_plus; +using pisa::v1::VoidScorer; +using pisa::v1::wand; + +struct RetrievalAlgorithm { + template + explicit RetrievalAlgorithm(Fn fn, FallbackFn fallback, bool safe) + : m_retrieve(std::move(fn)), m_fallback(std::move(fallback)), m_safe(safe) + { + } + + [[nodiscard]] auto operator()(pisa::v1::Query const& query, ::pisa::topk_queue topk) const + -> ::pisa::topk_queue + { + topk = m_retrieve(query, topk); + if (m_safe && not topk.full()) { + spdlog::debug("Retrieved {} out of {} documents. Rerunning without threshold.", + topk.topk().size(), + topk.size()); + topk.clear(); + topk = m_fallback(query, topk); + } + return topk; + } + + private: + std::function<::pisa::topk_queue(pisa::v1::Query, ::pisa::topk_queue)> m_retrieve; + std::function<::pisa::topk_queue(pisa::v1::Query, ::pisa::topk_queue)> m_fallback; + bool m_safe; +}; + +template +auto resolve_algorithm(std::string const& name, Index const& index, Scorer&& scorer, bool safe) + -> RetrievalAlgorithm +{ + auto fallback = [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + topk.clear(); + return pisa::v1::maxscore(query, index, std::move(topk), std::forward(scorer)); + }; + if (name == "daat_or") { + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + return pisa::v1::daat_or( + query, index, std::move(topk), std::forward(scorer)); + }, + fallback, + safe); + } + if (name == "wand") { + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + if (query.threshold()) { + topk.set_threshold(query.get_threshold()); + } + return pisa::v1::wand(query, index, 
std::move(topk), std::forward(scorer)); + }, + fallback, + safe); + } + if (name == "bmw") { + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + if (query.threshold()) { + topk.set_threshold(query.get_threshold()); + } + return pisa::v1::bmw(query, index, std::move(topk), std::forward(scorer)); + }, + fallback, + safe); + } + if (name == "maxscore") { + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + if (query.threshold()) { + topk.set_threshold(query.get_threshold()); + } + return pisa::v1::maxscore( + query, index, std::move(topk), std::forward(scorer)); + }, + fallback, + safe); + } + if (name == "maxscore-union-lookup") { + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + return pisa::v1::maxscore_union_lookup( + query, index, std::move(topk), std::forward(scorer)); + }, + fallback, + safe); + } + if (name == "unigram-union-lookup") { + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + return pisa::v1::unigram_union_lookup( + query, index, std::move(topk), std::forward(scorer)); + }, + fallback, + safe); + } + if (name == "union-lookup") { + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + if (query.selections()->bigrams.empty()) { + return pisa::v1::unigram_union_lookup( + query, index, std::move(topk), std::forward(scorer)); + } + if (query.get_term_ids().size() >= 8) { + return pisa::v1::maxscore( + query, index, std::move(topk), std::forward(scorer)); + } + return pisa::v1::union_lookup( + query, index, std::move(topk), std::forward(scorer)); + }, + fallback, + safe); + } + if (name == "union-lookup-plus") { + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + if (query.selections()->bigrams.empty()) { + return pisa::v1::unigram_union_lookup( + query, index, std::move(topk), std::forward(scorer)); + } + if 
(query.get_term_ids().size() > 8) { + return pisa::v1::maxscore( + query, index, std::move(topk), std::forward(scorer)); + } + return pisa::v1::union_lookup_plus( + query, index, std::move(topk), std::forward(scorer)); + }, + fallback, + safe); + } + if (name == "lookup-union") { + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + if (query.selections()->bigrams.empty()) { + if (query.selections()->unigrams.empty()) { + return pisa::v1::maxscore( + query, index, std::move(topk), std::forward(scorer)); + } + return pisa::v1::unigram_union_lookup( + query, index, std::move(topk), std::forward(scorer)); + } + return pisa::v1::lookup_union( + query, index, std::move(topk), std::forward(scorer)); + }, + fallback, + safe); + } + if (name == "lookup-union-eaat") { + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + if (query.selections()->bigrams.empty()) { + if (query.selections()->unigrams.empty()) { + return pisa::v1::maxscore( + query, index, std::move(topk), std::forward(scorer)); + } + return pisa::v1::unigram_union_lookup( + query, index, std::move(topk), std::forward(scorer)); + } + return pisa::v1::lookup_union_eaat( + query, index, std::move(topk), std::forward(scorer)); + }, + fallback, + safe); + } + spdlog::error("Unknown algorithm: {}", name); + std::exit(1); +} + +template +auto resolve_inspect(std::string const& name, Index const& index, Scorer&& scorer) -> QueryInspector +{ + if (name == "daat_or") { + return QueryInspector(InspectDaatOr(index, std::forward(scorer))); + } + if (name == "maxscore") { + return QueryInspector(InspectMaxScore(index, std::forward(scorer))); + } + if (name == "maxscore-union-lookup") { + return QueryInspector( + InspectMaxScoreUnionLookup>(index, scorer)); + } + if (name == "unigram-union-lookup") { + return QueryInspector( + InspectUnigramUnionLookup>(index, scorer)); + } + if (name == "union-lookup") { + return 
QueryInspector(InspectUnionLookup>(index, scorer)); + } + if (name == "lookup-union") { + return QueryInspector(InspectLookupUnion>(index, scorer)); + } + if (name == "lookup-union-eaat") { + return QueryInspector(InspectLookupUnionEaat>(index, scorer)); + } + if (name == "union-lookup-plus") { + return QueryInspector(InspectUnionLookupPlus>(index, scorer)); + } + spdlog::error("Unknown algorithm: {}", name); + std::exit(1); +} + +void evaluate(std::vector const& queries, + pisa::Payload_Vector<> const& docmap, + RetrievalAlgorithm const& retrieve) +{ + auto query_idx = 0; + for (auto const& query : queries) { + auto que = retrieve(query, pisa::topk_queue(query.k())); + que.finalize(); + auto rank = 0; + for (auto result : que.topk()) { + std::cout << fmt::format("{}\t{}\t{}\t{}\t{}\t{}\n", + query.id().value_or(std::to_string(query_idx)), + "Q0", + docmap[result.second], + rank, + result.first, + "R0"); + rank += 1; + } + query_idx += 1; + } +} + +void benchmark(std::vector const& queries, RetrievalAlgorithm retrieve) + +{ + std::vector times(queries.size(), std::numeric_limits::max()); + for (auto run = 0; run < 5; run += 1) { + for (auto query = 0; query < queries.size(); query += 1) { + auto usecs = ::pisa::run_with_timer([&]() { + auto que = retrieve(queries[query], pisa::topk_queue(queries[query].k())); + que.finalize(); + do_not_optimize_away(que); + }); + times[query] = std::min(times[query], static_cast(usecs.count())); + } + } + for (auto time : times) { + std::cout << time << '\n'; + } + std::sort(times.begin(), times.end()); + double avg = std::accumulate(times.begin(), times.end(), double()) / times.size(); + double q50 = times[times.size() / 2]; + double q90 = times[90 * times.size() / 100]; + double q95 = times[95 * times.size() / 100]; + spdlog::info("Mean: {} us", avg); + spdlog::info("50% quantile: {} us", q50); + spdlog::info("90% quantile: {} us", q90); + spdlog::info("95% quantile: {} us", q95); +} + +void inspect_queries(std::vector const& 
queries, QueryInspector inspect) +{ + inspect.header(std::cout); + std::cout << '\n'; + for (auto query = 0; query < queries.size(); query += 1) { + inspect(queries[query]).write(std::cout); + std::cout << '\n'; + } + + std::cerr << "========== Avg ==========\n"; + inspect.header(std::cerr); + std::cerr << '\n'; + inspect.mean().write(std::cerr); + std::cerr << '\n'; + std::cerr << "=========================\n"; +} + +int main(int argc, char** argv) +{ + spdlog::drop(""); + spdlog::set_default_logger(spdlog::stderr_color_mt("")); + + std::string algorithm = "daat_or"; + bool inspect = false; + bool safe = false; + + pisa::QueryApp app("Queries a v1 index."); + app.add_option("--algorithm", algorithm, "Query retrieval algorithm", true); + app.add_flag("--inspect", inspect, "Analyze query execution and stats"); + app.add_flag("--safe", safe, "Repeats without threshold if it was overestimated"); + CLI11_PARSE(app, argc, argv); + + try { + auto meta = app.index_metadata(); + auto queries = app.queries(meta); + + if (not meta.document_lexicon) { + spdlog::error("Document lexicon not defined"); + std::exit(1); + } + auto source = std::make_shared(meta.document_lexicon.value().c_str()); + auto docmap = pisa::Payload_Vector<>::from(*source); + + if (app.use_quantized()) { + auto run = + scored_index_runner(meta, + std::make_tuple(RawReader{}, + DocumentBlockedReader<::pisa::simdbp_block>{}), + std::make_tuple(RawReader{}, + PayloadBlockedReader<::pisa::simdbp_block>{})); + run([&](auto&& index) { + if (app.is_benchmark()) { + benchmark(queries, resolve_algorithm(algorithm, index, VoidScorer{}, safe)); + } else if (inspect) { + inspect_queries(queries, resolve_inspect(algorithm, index, VoidScorer{})); + } else { + evaluate( + queries, docmap, resolve_algorithm(algorithm, index, VoidScorer{}, safe)); + } + }); + } else { + auto run = index_runner(meta); + run([&](auto&& index) { + auto with_scorer = scorer_runner(index, make_bm25(index)); + with_scorer("bm25", [&](auto 
scorer) { + if (app.is_benchmark()) { + benchmark(queries, resolve_algorithm(algorithm, index, scorer, safe)); + } else if (inspect) { + inspect_queries(queries, resolve_inspect(algorithm, index, scorer)); + } else { + evaluate( + queries, docmap, resolve_algorithm(algorithm, index, scorer, safe)); + } + }); + }); + } + } catch (std::exception const& error) { + spdlog::error("{}", error.what()); + } + return 0; +} diff --git a/v1/score.cpp b/v1/score.cpp new file mode 100644 index 000000000..4d622bcdd --- /dev/null +++ b/v1/score.cpp @@ -0,0 +1,27 @@ +#include +#include +#include +#include + +#include + +#include "app.hpp" +#include "v1/score_index.hpp" + +using pisa::App; +namespace arg = pisa::arg; + +int main(int argc, char** argv) +{ + std::optional yml{}; + int bytes_per_score = 1; + std::size_t threads = std::thread::hardware_concurrency(); + + App app{"Scores v1 index."}; + // TODO(michal): enable + // app.add_option( + // "-b,--bytes-per-score", yml, "Quantize computed scores to this many bytes", true); + CLI11_PARSE(app, argc, argv); + pisa::v1::score_index(app.index_metadata(), app.threads()); + return 0; +} diff --git a/v1/select_pairs.cpp b/v1/select_pairs.cpp new file mode 100644 index 000000000..9db65b0f9 --- /dev/null +++ b/v1/select_pairs.cpp @@ -0,0 +1,41 @@ +#include +#include + +#include +#include +#include + +#include "app.hpp" +#include "v1/default_index_runner.hpp" +#include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/progress_status.hpp" +#include "v1/types.hpp" + +using pisa::App; +using pisa::v1::bigram_gain; +using pisa::v1::index_runner; + +namespace arg = pisa::arg; + +int main(int argc, char** argv) +{ + std::optional terms_file{}; + //std::size_t num_pairs_to_select; + + App> app{"Creates a v1 bigram index."}; + // app.add_option("--count", num_pairs_to_select, "Number of pairs to select")->required(); + CLI11_PARSE(app, argc, argv); + + auto meta = app.index_metadata(); + + auto run = index_runner(meta); + 
run([&](auto&& index) { + for (auto&& query : app.query_range(meta)) { + auto term_ids = query.get_term_ids(); + std::cout + << fmt::format("{}\t{}\t{}\n", term_ids[0], term_ids[1], bigram_gain(index, query)); + } + }); + return 0; +} diff --git a/v1/stats.cpp b/v1/stats.cpp new file mode 100644 index 000000000..3376997bc --- /dev/null +++ b/v1/stats.cpp @@ -0,0 +1,29 @@ +#include + +#include +#include +#include + +#include "app.hpp" +#include "v1/default_index_runner.hpp" + +namespace arg = pisa::arg; + +int main(int argc, char** argv) +{ + spdlog::drop(""); + spdlog::set_default_logger(spdlog::stderr_color_mt("")); + + pisa::App app("Simply counts all postings in the index"); + CLI11_PARSE(app, argc, argv); + + auto meta = app.index_metadata(); + std::size_t count{0}; + pisa::v1::index_runner(meta)([&](auto&& index) { + std::cout << fmt::format("#terms: {}\n", index.num_terms()); + std::cout << fmt::format("#documents: {}\n", index.num_documents()); + std::cout << fmt::format("#pairs: {}\n", index.num_pairs()); + std::cout << fmt::format("avg. 
document length: {}\n", index.avg_document_length()); + }); + return 0; +} diff --git a/v1/term_to_id.cpp b/v1/term_to_id.cpp new file mode 100644 index 000000000..adb96924f --- /dev/null +++ b/v1/term_to_id.cpp @@ -0,0 +1,44 @@ +#include +#include + +#include +#include +#include + +#include "app.hpp" +#include "io.hpp" +#include "payload_vector.hpp" + +namespace arg = pisa::arg; +using pisa::Payload_Vector; + +int main(int argc, char** argv) +{ + spdlog::drop(""); + spdlog::set_default_logger(spdlog::stderr_color_mt("")); + + pisa::App app("Each term from input translated to ID"); + CLI11_PARSE(app, argc, argv); + + auto meta = app.index_metadata(); + if (not meta.term_lexicon) { + spdlog::error("Term lexicon not defined"); + std::exit(1); + } + auto source = std::make_shared(meta.term_lexicon.value().c_str()); + auto lex = pisa::Payload_Vector<>::from(*source); + + pisa::io::for_each_line(std::cin, [&](auto&& line) { + std::istringstream is(line); + std::string next; + if (is >> next) { + std::cout << lex[std::stoi(next)]; + } + while (is >> next) { + std::cout << fmt::format(" {}", lex[std::stoi(next)]); + } + std::cout << '\n'; + }); + + return 0; +} diff --git a/v1/threshold.cpp b/v1/threshold.cpp new file mode 100644 index 000000000..4b9ca52d0 --- /dev/null +++ b/v1/threshold.cpp @@ -0,0 +1,92 @@ +#include +#include + +#include +#include +#include +#include + +#include "app.hpp" +#include "query/queries.hpp" +#include "topk_queue.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/daat_or.hpp" +#include "v1/default_index_runner.hpp" +#include "v1/index_metadata.hpp" +#include "v1/query.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/scorer/bm25.hpp" +#include "v1/scorer/runner.hpp" +#include "v1/types.hpp" + +using pisa::v1::index_runner; +using pisa::v1::Query; +using pisa::v1::VoidScorer; + +template +void calculate_thresholds(Index&& index, + Scorer&& scorer, + std::vector& queries, + std::ostream& os) +{ + for (auto&& query : queries) { + auto results = 
pisa::v1::daat_or( + query, index, ::pisa::topk_queue(query.k()), std::forward(scorer)); + results.finalize(); + float threshold = 0.0; + if (not results.topk().empty()) { + threshold = results.topk().back().first; + } + query.threshold(threshold); + os << *query.to_json() << '\n'; + } +} + +int main(int argc, char** argv) +{ + spdlog::drop(""); + spdlog::set_default_logger(spdlog::stderr_color_mt("")); + + bool in_place = false; + pisa::QueryApp app("Calculates thresholds for a v1 index."); + app.add_flag("--in-place", in_place, "Edit the input file"); + CLI11_PARSE(app, argc, argv); + + if (in_place && not app.query_file()) { + spdlog::error("Cannot edit in place when no query file passed"); + std::exit(1); + } + + try { + auto meta = app.index_metadata(); + auto queries = app.queries(meta); + + if (app.use_quantized()) { + auto run = scored_index_runner(meta); + run([&](auto&& index) { + if (in_place) { + std::ofstream os(app.query_file().value()); + calculate_thresholds(index, VoidScorer{}, queries, os); + } else { + calculate_thresholds(index, VoidScorer{}, queries, std::cout); + } + }); + } else { + auto run = index_runner(meta); + run([&](auto&& index) { + auto with_scorer = scorer_runner(index, make_bm25(index)); + with_scorer("bm25", [&](auto scorer) { + if (in_place) { + std::ofstream os(app.query_file().value()); + calculate_thresholds(index, scorer, queries, os); + } else { + calculate_thresholds(index, scorer, queries, std::cout); + } + }); + }); + } + } catch (std::exception const& error) { + spdlog::error("{}", error.what()); + } + return 0; +}