From 73e1000a465032b56339245d1cac5340c6f6bfee Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Tue, 22 Oct 2019 15:15:17 -0400 Subject: [PATCH 01/56] Add tl::optional dependency --- .gitmodules | 3 +++ external/optional | 1 + 2 files changed, 4 insertions(+) create mode 160000 external/optional diff --git a/.gitmodules b/.gitmodules index 6f49b5e95..8caf14a27 100644 --- a/.gitmodules +++ b/.gitmodules @@ -67,3 +67,6 @@ [submodule "external/wapopp"] path = external/wapopp url = https://github.com/pisa-engine/wapopp.git +[submodule "external/optional"] + path = external/optional + url = https://github.com/TartanLlama/optional.git diff --git a/external/optional b/external/optional new file mode 160000 index 000000000..5c4876059 --- /dev/null +++ b/external/optional @@ -0,0 +1 @@ +Subproject commit 5c4876059c1168d5fa3c945bd8dd05ebbe61b6fe From 58dd9c808a586a1dcd9f1c1996f4f50a268cebb3 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Tue, 22 Oct 2019 15:44:55 -0400 Subject: [PATCH 02/56] Minimal partial example --- CMakeLists.txt | 1 + external/CMakeLists.txt | 4 ++ include/pisa/v1/index.hpp | 121 ++++++++++++++++++++++++++++++++++++++ test/CMakeLists.txt | 6 ++ test/test_v1.cpp | 19 ++++++ 5 files changed, 151 insertions(+) create mode 100644 include/pisa/v1/index.hpp create mode 100644 test/test_v1.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index a76e366e9..84f95f8ae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -95,6 +95,7 @@ target_link_libraries(pisa INTERFACE spdlog fmt::fmt range-v3 + optional ) target_include_directories(pisa INTERFACE external) diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt index 75f130909..758aad91f 100644 --- a/external/CMakeLists.txt +++ b/external/CMakeLists.txt @@ -116,3 +116,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/spdlog) # Add range-v3 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/range-v3) + +# Add tl::optional +set(OPTIONAL_ENABLE_TESTS OFF CACHE BOOL "skip optional testing") 
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/optional) diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp new file mode 100644 index 000000000..60dc4b936 --- /dev/null +++ b/include/pisa/v1/index.hpp @@ -0,0 +1,121 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace pisa::v1 { + +template +constexpr auto bit_cast(gsl::span mem) -> T +{ + T dst; + std::memcpy(&dst, mem.data(), sizeof(T)); + return dst; +} + +using TermId = std::uint32_t; +using DocId = std::uint32_t; +using Frequency = std::uint32_t; + +template +auto payload(Cursor &cursor) -> typename std::tuple_element::type &; + +template +auto payload(Cursor const &cursor) -> + typename std::tuple_element::type const &; + +template +auto payload(Cursor &cursor) -> typename std::tuple_element<0, typename Cursor::Payload>::type & +{ + return payload<0>(cursor); +} + +template +auto payload(Cursor const &cursor) -> + typename std::tuple_element<0, typename Cursor::Payload>::type const & +{ + return payload<0>(cursor); +} + +template +struct RawCursor { + static_assert(std::is_trivial::value); + + explicit constexpr RawCursor(gsl::span bytes) : m_bytes(bytes) + { + Expects(bytes.size() % sizeof(T) == 0); + } + + constexpr auto operator*() -> T + { + return bit_cast(gsl::as_bytes(m_bytes.subspan(m_current, sizeof(T)))); + } + constexpr auto next() -> tl::optional + { + step(); + return empty() ? 
tl::nullopt : tl::make_optional(operator*()()); + } + constexpr void step() { m_current += sizeof(T); } + constexpr auto empty() -> bool { return m_current == m_bytes.size(); } + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_bytes.size() / sizeof(T); } + + private: + std::size_t m_current = 0; + gsl::span m_bytes; +}; + +template +struct RawReader { + static_assert(std::is_trivial::value); + + [[nodiscard]] auto read(gsl::span bytes) const -> RawCursor + { + return RawCursor(bytes.subspan(sizeof(std::uint64_t))); + } +}; + +struct IndexFactory { +}; + +template +struct Index { + using Cursor = decltype(std::declval().read(std::declval>())); + static_assert(std::is_same_v()), DocId>); + static_assert(std::is_same_v())), Frequency>); + + [[nodiscard]] auto cursor(TermId term) -> Cursor { return m_reader.read(fetch(term)); } + + private: + [[nodiscard]] auto fetch(TermId term) -> gsl::span; + + Reader m_reader; +}; + +template +struct ZipCursor; + +template +struct Index2 { + using DocumentCursor = + decltype(std::declval().read(std::declval>())); + using FrequencyCursor = + decltype(std::declval().read(std::declval>())); + static_assert(std::is_same_v()), DocId>); + static_assert(std::is_same_v()), Frequency>); + + [[nodiscard]] auto cursor(TermId term) -> ZipCursor; + + private: + [[nodiscard]] auto fetch(TermId term) -> gsl::span; + + DocumentReader m_document_reader; + FrequencyReader m_frequency_reader; +}; + +} // namespace pisa::v1 diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 17f65cefa..35f67b090 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -18,4 +18,10 @@ foreach(TEST_SRC ${TEST_SOURCES}) add_coverage(${TEST_SRC_NAME}) endforeach(TEST_SRC) +target_link_libraries(test_v1 + pisa + Catch2 + optional +) + file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/test_data DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/test/test_v1.cpp b/test/test_v1.cpp new file mode 100644 index 000000000..83e1b5c19 --- /dev/null +++ 
b/test/test_v1.cpp @@ -0,0 +1,19 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include "v1/index.hpp" + +using pisa::v1::RawReader; + +TEST_CASE("RawReader", "[v1][unit]") +{ + std::vector const mem{0, 1, 2, 3, 4}; + RawReader reader; + auto cursor = reader.read(gsl::as_bytes(gsl::make_span(mem))); + REQUIRE(cursor.next() == tl::make_optional(mem[0])); + REQUIRE(cursor.next() == tl::make_optional(mem[1])); + REQUIRE(cursor.next() == tl::make_optional(mem[2])); + REQUIRE(cursor.next() == tl::make_optional(mem[3])); + REQUIRE(cursor.next() == tl::make_optional(mem[4])); + REQUIRE(cursor.next() == tl::nullopt); +} From f58a72aaf3b7507b8a466726f9bce19ac5d2138a Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Wed, 23 Oct 2019 07:44:11 -0400 Subject: [PATCH 03/56] ZipCursor --- include/pisa/v1/index.hpp | 41 ++++++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index 60dc4b936..c69484d1f 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -97,25 +97,48 @@ struct Index { Reader m_reader; }; -template -struct ZipCursor; +template +struct ZipCursor { + using Key = decltype(*std::declval()); + using Payload = decltype(*std::declval()); -template -struct Index2 { + constexpr auto operator*() -> Key { return *m_key_cursor; } + constexpr auto next() -> tl::optional> + { + return m_key_cursor.next().and_then([&](Key key) { + return m_payload_cursor.next().map( + [key](Payload payload) { return std::make_pair(key, payload); }); + }); + } + constexpr void step() + { + m_key_cursor.step(); + m_payload_cursor.step(); + } + constexpr auto empty() -> bool { return m_key_cursor.empty(); } + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_key_cursor.size(); } + + private: + KeyCursor m_key_cursor; + PayloadCursor m_payload_cursor; +}; + +template +struct ZippedIndex { using DocumentCursor = 
decltype(std::declval().read(std::declval>())); - using FrequencyCursor = - decltype(std::declval().read(std::declval>())); + using PayloadCursor = + decltype(std::declval().read(std::declval>())); static_assert(std::is_same_v()), DocId>); - static_assert(std::is_same_v()), Frequency>); + // static_assert(std::is_same_v()), Frequency>); - [[nodiscard]] auto cursor(TermId term) -> ZipCursor; + [[nodiscard]] auto cursor(TermId term) -> ZipCursor; private: [[nodiscard]] auto fetch(TermId term) -> gsl::span; DocumentReader m_document_reader; - FrequencyReader m_frequency_reader; + PayloadReader m_payload_reader; }; } // namespace pisa::v1 From f59af7797931edc30814e4ace1487c37b817b541 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 25 Oct 2019 13:49:14 -0400 Subject: [PATCH 04/56] Additional methods --- CMakeLists.txt | 2 +- include/pisa/v1/index.hpp | 21 +++++++++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 84f95f8ae..a19526ee1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -48,7 +48,7 @@ list(APPEND LCOV_REMOVE_PATTERNS "'${PROJECT_SOURCE_DIR}/external/*'") if (UNIX) # For hardware popcount and other special instructions - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fconcepts -march=native") # Extensive warnings set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wno-missing-braces") diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index c69484d1f..1e9bc0a9a 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -23,6 +23,22 @@ using TermId = std::uint32_t; using DocId = std::uint32_t; using Frequency = std::uint32_t; +namespace concepts { + + // template + // concept bool CursorLike = requires(T cursor, DocId docid, Position pos) + //{ + // { cursor.reset() } -> void; + // { cursor.next() } -> void; + // { cursor.next_geq(docid) } -> void; + // { cursor.move(pos) } -> void; + // { 
cursor.docid() } -> DocId; + // { cursor.position() } -> Position; + // { cursor.size() } -> std::size_t; + //}; + +} + template auto payload(Cursor &cursor) -> typename std::tuple_element::type &; @@ -52,7 +68,7 @@ struct RawCursor { Expects(bytes.size() % sizeof(T) == 0); } - constexpr auto operator*() -> T + [[nodiscard]] constexpr auto operator*() const -> T { return bit_cast(gsl::as_bytes(m_bytes.subspan(m_current, sizeof(T)))); } @@ -62,7 +78,8 @@ struct RawCursor { return empty() ? tl::nullopt : tl::make_optional(operator*()()); } constexpr void step() { m_current += sizeof(T); } - constexpr auto empty() -> bool { return m_current == m_bytes.size(); } + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_current == m_bytes.size(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t { return m_current; } [[nodiscard]] constexpr auto size() const -> std::size_t { return m_bytes.size() / sizeof(T); } private: From 41e65498b750d7e450575f6091e93ebaaaeec16e Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Sat, 26 Oct 2019 17:40:26 -0400 Subject: [PATCH 05/56] Intersections + union + bigram index --- CMakeLists.txt | 2 +- include/pisa/v1/bit_cast.hpp | 17 + include/pisa/v1/cursor_intersection.hpp | 131 ++++++++ include/pisa/v1/cursor_union.hpp | 91 ++++++ include/pisa/v1/document_payload_cursor.hpp | 114 +++++++ include/pisa/v1/index.hpp | 326 +++++++++++++------- include/pisa/v1/raw_cursor.hpp | 171 ++++++++++ include/pisa/v1/types.hpp | 11 + include/pisa/v1/vector_lexicon.hpp | 33 ++ test/test_v1.cpp | 192 +++++++++++- 10 files changed, 971 insertions(+), 117 deletions(-) create mode 100644 include/pisa/v1/bit_cast.hpp create mode 100644 include/pisa/v1/cursor_intersection.hpp create mode 100644 include/pisa/v1/cursor_union.hpp create mode 100644 include/pisa/v1/document_payload_cursor.hpp create mode 100644 include/pisa/v1/raw_cursor.hpp create mode 100644 include/pisa/v1/types.hpp create mode 100644 
include/pisa/v1/vector_lexicon.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index a19526ee1..84f95f8ae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -48,7 +48,7 @@ list(APPEND LCOV_REMOVE_PATTERNS "'${PROJECT_SOURCE_DIR}/external/*'") if (UNIX) # For hardware popcount and other special instructions - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fconcepts -march=native") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") # Extensive warnings set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wno-missing-braces") diff --git a/include/pisa/v1/bit_cast.hpp b/include/pisa/v1/bit_cast.hpp new file mode 100644 index 000000000..54ab1aaf3 --- /dev/null +++ b/include/pisa/v1/bit_cast.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include + +#include + +namespace pisa::v1 { + +template +constexpr auto bit_cast(gsl::span mem) -> T +{ + T dst; + std::memcpy(&dst, mem.data(), sizeof(T)); + return dst; +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor_intersection.hpp b/include/pisa/v1/cursor_intersection.hpp new file mode 100644 index 000000000..23a382022 --- /dev/null +++ b/include/pisa/v1/cursor_intersection.hpp @@ -0,0 +1,131 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include "util/likely.hpp" + +namespace pisa { + +/// Transforms a list of cursors into one cursor by lazily merging them together +/// into an intersection. 
+template +struct CursorIntersection { + using Cursor = typename CursorContainer::value_type; + using iterator_category = + typename std::iterator_traits::iterator_category; + static_assert(std::is_base_of(), + "cursors must be stored in a random access container"); + using Value = std::decay_t())>; + + constexpr CursorIntersection(CursorContainer cursors, Payload init, AccumulateFn accumulate) + : m_unordered_cursors(std::move(cursors)), + m_init(std::move(init)), + m_accumulate(std::move(accumulate)), + m_cursor_mapping(m_unordered_cursors.size()) + { + Expects(not m_unordered_cursors.empty()); + std::iota(m_cursor_mapping.begin(), m_cursor_mapping.end(), 0); + auto order = [&](auto lhs, auto rhs) { + return m_unordered_cursors[lhs].size() < m_unordered_cursors[rhs].size(); + }; + std::sort(m_cursor_mapping.begin(), m_cursor_mapping.end(), order); + std::transform(m_cursor_mapping.begin(), + m_cursor_mapping.end(), + std::back_inserter(m_cursors), + [&](auto idx) { return std::ref(m_unordered_cursors[idx]); }); + m_sentinel = std::min_element(m_unordered_cursors.begin(), + m_unordered_cursors.end(), + [](auto const &lhs, auto const &rhs) { + return lhs.sentinel() < rhs.sentinel(); + }) + ->sentinel(); + m_candidate = *m_cursors[0].get(); + next(); + } + + [[nodiscard]] constexpr auto operator*() const -> Value { return m_current_value; } + [[nodiscard]] constexpr auto value() const noexcept -> tl::optional + { + if (PISA_LIKELY(m_candidate < sentinel())) { + return m_current_value; + } + return tl::nullopt; + } + + constexpr void step() + { + while (PISA_LIKELY(m_candidate < sentinel())) { + for (; m_next_cursor < m_cursors.size(); ++m_next_cursor) { + Cursor &cursor = m_cursors[m_next_cursor]; + cursor.step_to_geq(m_candidate); + if (*cursor != m_candidate) { + m_candidate = *cursor; + m_next_cursor = 0; + break; + } + } + if (m_next_cursor == m_cursors.size()) { + m_current_payload = m_init; + for (auto idx = 0; idx < m_cursors.size(); ++idx) { + //if 
(m_candidate == 116) { + // std::cout << *m_cursors[idx].get().payload() << ' '; + //} + m_current_payload = m_accumulate( + m_current_payload, m_cursors[idx].get(), m_cursor_mapping[idx]); + } + //if (m_candidate == 116) { + // std::cout << '\n'; + //} + m_cursors[0].get().step(); + m_current_value = std::exchange(m_candidate, *m_cursors[0].get()); + m_next_cursor = 1; + return; + } + } + m_current_value = sentinel(); + m_current_payload = m_init; + } + + [[nodiscard]] constexpr auto payload() const noexcept -> Payload const & + { + return m_current_payload; + } + + constexpr void step_to_position(std::size_t pos); // TODO(michal) + constexpr void step_to_geq(Value value); // TODO(michal) + constexpr auto next() -> tl::optional + { + step(); + return value(); + } + + [[nodiscard]] constexpr auto empty() const noexcept -> bool + { + return m_candidate >= sentinel(); + } + + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; // TODO(michal) + [[nodiscard]] constexpr auto sentinel() const noexcept -> std::size_t { return m_sentinel; } + [[nodiscard]] constexpr auto size() const -> std::size_t = delete; + + private: + CursorContainer m_unordered_cursors; + Payload m_init; + AccumulateFn m_accumulate; + std::vector m_cursor_mapping; + + std::vector> m_cursors; + Value m_current_value{}; + Value m_candidate{}; + Value m_sentinel{}; + Payload m_current_payload{}; + std::uint32_t m_next_cursor = 1; +}; + +} // namespace pisa diff --git a/include/pisa/v1/cursor_union.hpp b/include/pisa/v1/cursor_union.hpp new file mode 100644 index 000000000..02cf4d8bd --- /dev/null +++ b/include/pisa/v1/cursor_union.hpp @@ -0,0 +1,91 @@ +#pragma once + +#include +#include +#include + +#include + +#include "util/likely.hpp" + +namespace pisa { + +/// Transforms a list of cursors into one cursor by lazily merging them together. 
+template +struct CursorUnion { + using Cursor = typename CursorContainer::value_type; + using iterator_category = + typename std::iterator_traits::iterator_category; + static_assert(std::is_base_of(), + "cursors must be stored in a random access container"); + constexpr CursorUnion(CursorContainer cursors, + std::size_t max_docid, + Payload init, + AccumulateFn accumulate) + : m_cursors(std::move(cursors)), + m_init(init), + m_accumulate(std::move(accumulate)), + m_size(std::nullopt), + m_max_docid(max_docid) + { + Expects(not m_cursors.empty()); + auto order = [](auto const &lhs, auto const &rhs) { return lhs.docid() < rhs.docid(); }; + m_next_docid = [&]() { + auto pos = std::min_element(m_cursors.begin(), m_cursors.end(), order); + return pos->docid(); + }(); + next(); + } + + [[nodiscard]] constexpr auto size() const noexcept -> std::size_t + { + if (!m_size) { + m_size = std::accumulate(m_cursors.begin(), + m_cursors.end(), + std::size_t(0), + [](auto acc, auto const &elem) { return acc + elem.size(); }); + } + return *m_size; + } + [[nodiscard]] constexpr auto docid() const noexcept -> std::uint32_t { return m_current_docid; } + [[nodiscard]] constexpr auto payload() const noexcept -> Payload const & + { + return m_current_payload; + } + [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_max_docid; } + constexpr void next() + { + if (PISA_UNLIKELY(m_next_docid == m_max_docid)) { + m_current_docid = m_max_docid; + m_current_payload = m_init; + } else { + m_current_payload = m_init; + m_current_docid = m_next_docid; + m_next_docid = m_max_docid; + std::size_t cursor_idx = 0; + for (auto &cursor : m_cursors) { + if (cursor.docid() == m_current_docid) { + m_current_payload = m_accumulate(m_current_payload, cursor, cursor_idx); + cursor.next(); + } + if (cursor.docid() < m_next_docid) { + m_next_docid = cursor.docid(); + } + ++cursor_idx; + } + } + } + + private: + CursorContainer m_cursors; + Payload m_init; + AccumulateFn 
m_accumulate; + std::optional m_size; + std::uint32_t m_max_docid; + + std::uint32_t m_current_docid = 0; + Payload m_current_payload; + std::uint32_t m_next_docid; +}; + +} // namespace pisa diff --git a/include/pisa/v1/document_payload_cursor.hpp b/include/pisa/v1/document_payload_cursor.hpp new file mode 100644 index 000000000..2fecb7f4f --- /dev/null +++ b/include/pisa/v1/document_payload_cursor.hpp @@ -0,0 +1,114 @@ +#pragma once + +#include + +#include +#include + +template +struct DocumentPayloadCursor { + using Document = decltype(*std::declval()); + using Payload = decltype(*std::declval()); + + explicit constexpr DocumentPayloadCursor(DocumentCursor key_cursor, + PayloadCursor payload_cursor); + [[nodiscard]] constexpr auto operator*() const -> Document; + [[nodiscard]] constexpr auto value() const noexcept -> tl::optional; + [[nodiscard]] constexpr auto payload() const noexcept -> tl::optional; + constexpr void step(); + constexpr void step_to_position(std::size_t pos); + constexpr void step_to_geq(Document value); + constexpr auto next() -> tl::optional; + [[nodiscard]] constexpr auto empty() const noexcept -> bool; + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; + [[nodiscard]] constexpr auto size() const -> std::size_t; + [[nodiscard]] constexpr auto sentinel() const -> Document; + + private: + DocumentCursor m_key_cursor; + PayloadCursor m_payload_cursor; +}; + +template +constexpr DocumentPayloadCursor::DocumentPayloadCursor( + DocumentCursor key_cursor, PayloadCursor payload_cursor) + : m_key_cursor(std::move(key_cursor)), m_payload_cursor(std::move(payload_cursor)) +{ +} + +template +[[nodiscard]] constexpr auto DocumentPayloadCursor::operator*() const + -> Document +{ + return *m_key_cursor; +} +template +[[nodiscard]] constexpr auto DocumentPayloadCursor::sentinel() const + -> Document +{ + return m_key_cursor.sentinel(); +} + +template +[[nodiscard]] constexpr auto DocumentPayloadCursor::value() const + noexcept -> 
tl::optional +{ + return m_key_cursor.value(); +} + +template +[[nodiscard]] constexpr auto DocumentPayloadCursor::payload() const + noexcept -> tl::optional +{ + return m_payload_cursor.value(); +} + +template +constexpr void DocumentPayloadCursor::step() +{ + m_key_cursor.step(); + m_payload_cursor.step(); +} + +template +constexpr void DocumentPayloadCursor::step_to_position( + std::size_t pos) +{ + m_key_cursor.step_to_position(pos); + m_payload_cursor.step_to_position(pos); +} + +template +constexpr void DocumentPayloadCursor::step_to_geq(Document value) +{ + m_key_cursor.step_to_geq(value); + m_payload_cursor.step_to_position(m_key_cursor.position()); +} + +template +constexpr auto DocumentPayloadCursor::next() + -> tl::optional +{ + return m_key_cursor.next(); +} + +template +[[nodiscard]] constexpr auto DocumentPayloadCursor::empty() const + noexcept -> bool +{ + return m_key_cursor.empty(); +} + +template +[[nodiscard]] constexpr auto DocumentPayloadCursor::position() const + noexcept -> std::size_t +{ + return m_key_cursor.position(); +} + +template +[[nodiscard]] constexpr auto DocumentPayloadCursor::size() const + -> std::size_t +{ + return m_key_cursor.size(); +} diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index 1e9bc0a9a..fefc55af8 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -7,155 +7,253 @@ #include #include +#include #include -namespace pisa::v1 { - -template -constexpr auto bit_cast(gsl::span mem) -> T -{ - T dst; - std::memcpy(&dst, mem.data(), sizeof(T)); - return dst; -} +#include "binary_freq_collection.hpp" +#include "v1/bit_cast.hpp" +#include "v1/cursor_intersection.hpp" +#include "v1/document_payload_cursor.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/types.hpp" -using TermId = std::uint32_t; -using DocId = std::uint32_t; -using Frequency = std::uint32_t; - -namespace concepts { - - // template - // concept bool CursorLike = requires(T cursor, DocId docid, Position pos) - //{ - // { 
cursor.reset() } -> void; - // { cursor.next() } -> void; - // { cursor.next_geq(docid) } -> void; - // { cursor.move(pos) } -> void; - // { cursor.docid() } -> DocId; - // { cursor.position() } -> Position; - // { cursor.size() } -> std::size_t; - //}; - -} - -template -auto payload(Cursor &cursor) -> typename std::tuple_element::type &; - -template -auto payload(Cursor const &cursor) -> - typename std::tuple_element::type const &; +namespace pisa::v1 { -template -auto payload(Cursor &cursor) -> typename std::tuple_element<0, typename Cursor::Payload>::type & -{ - return payload<0>(cursor); -} +/// A generic type for an inverted index. +/// +/// \tparam DocumentReader Type of an object that reads document posting lists from bytes +/// It must read lists containing `DocId` objects. +/// \tparam PayloadReader Type of an object that reads payload posting lists from bytes. +/// It can read lists of arbitrary types, such as `Frequency`, +/// `Score`, or `std::pair` for a bigram scored index. +/// \tparam Source Can be used to store any owning data, like open `mmap`, since +/// index internally uses spans to manage encoded parts of memory. +template +struct Index { -template -auto payload(Cursor const &cursor) -> - typename std::tuple_element<0, typename Cursor::Payload>::type const & -{ - return payload<0>(cursor); -} + /// The type of cursor constructed by the document reader. Must read `DocId` values. + using DocumentCursor = + decltype(std::declval().read(std::declval>())); + static_assert(std::is_same_v()), DocId>); -template -struct RawCursor { - static_assert(std::is_trivial::value); + /// The type of cursor constructed by the payload reader. + using PayloadCursor = + decltype(std::declval().read(std::declval>())); - explicit constexpr RawCursor(gsl::span bytes) : m_bytes(bytes) + /// Constructs the index. + /// + /// \param document_reader Reads document posting lists from bytes. + /// \param payload_reader Reads payload posting lists from bytes. 
+ /// \param document_offsets Mapping from term ID to the position in memory of its + /// document posting list. + /// \param payload_offsets Mapping from term ID to the position in memory of its + /// payload posting list. + /// \param documents Encoded bytes for document postings. + /// \param payloads Encoded bytes for payload postings. + /// \param source This object (optionally) owns the raw data pointed at by + /// `documents` and `payloads` to ensure it is valid throughout + /// the lifetime of the index. + Index(DocumentReader document_reader, + PayloadReader payload_reader, + std::vector document_offsets, + std::vector payload_offsets, + gsl::span documents, + gsl::span payloads, + Source source) + : m_document_reader(std::move(document_reader)), + m_payload_reader(std::move(payload_reader)), + m_document_offsets(std::move(document_offsets)), + m_payload_offsets(std::move(payload_offsets)), + m_documents(documents), + m_payloads(payloads), + m_source(std::move(source)) { - Expects(bytes.size() % sizeof(T) == 0); } - [[nodiscard]] constexpr auto operator*() const -> T + /// Constructs a new document-payload cursor (see document_payload_cursor.hpp). + [[nodiscard]] auto cursor(TermId term) { - return bit_cast(gsl::as_bytes(m_bytes.subspan(m_current, sizeof(T)))); + return DocumentPayloadCursor(documents(term), + payloads(term)); } - constexpr auto next() -> tl::optional + + /// Constructs a new document cursor. + [[nodiscard]] auto documents(TermId term) { - step(); - return empty() ? 
tl::nullopt : tl::make_optional(operator*()()); + return m_document_reader.read(fetch_documents(term).subspan(4)); } - constexpr void step() { m_current += sizeof(T); } - [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_current == m_bytes.size(); } - [[nodiscard]] constexpr auto position() const noexcept -> std::size_t { return m_current; } - [[nodiscard]] constexpr auto size() const -> std::size_t { return m_bytes.size() / sizeof(T); } - private: - std::size_t m_current = 0; - gsl::span m_bytes; -}; + /// Constructs a new payload cursor. + [[nodiscard]] auto payloads(TermId term) + { + return m_payload_reader.read(fetch_payloads(term).subspan(4)); + } -template -struct RawReader { - static_assert(std::is_trivial::value); + /// Constructs a new payload cursor. + [[nodiscard]] auto num_terms() -> std::uint32_t { return m_document_offsets.size() - 1; } - [[nodiscard]] auto read(gsl::span bytes) const -> RawCursor + private: + [[nodiscard]] auto fetch_documents(TermId term) -> gsl::span { - return RawCursor(bytes.subspan(sizeof(std::uint64_t))); + Expects(term + 1 < m_document_offsets.size()); + return m_documents.subspan(m_document_offsets[term], + m_document_offsets[term + 1] - m_document_offsets[term]); + } + [[nodiscard]] auto fetch_payloads(TermId term) -> gsl::span + { + Expects(term + 1 < m_payload_offsets.size()); + return m_payloads.subspan(m_payload_offsets[term], + m_payload_offsets[term + 1] - m_payload_offsets[term]); } -}; -struct IndexFactory { + DocumentReader m_document_reader; + PayloadReader m_payload_reader; + std::vector m_document_offsets; + std::vector m_payload_offsets; + gsl::span m_documents; + gsl::span m_payloads; + Source m_source; }; -template -struct Index { - using Cursor = decltype(std::declval().read(std::declval>())); - static_assert(std::is_same_v()), DocId>); - static_assert(std::is_same_v())), Frequency>); - - [[nodiscard]] auto cursor(TermId term) -> Cursor { return m_reader.read(fetch(term)); } - - private: 
- [[nodiscard]] auto fetch(TermId term) -> gsl::span; +/// Initializes a memory mapped source with a given file. +inline void open_source(mio::mmap_source &source, std::string const &filename) +{ + std::error_code error; + source.map(filename, error); + if (error) { + spdlog::error("Error mapping file {}: {}", filename, error.message()); + throw std::runtime_error("Error mapping file"); + } +} - Reader m_reader; -}; +[[nodiscard]] inline auto binary_collection_index(std::string const &basename) +{ + binary_freq_collection collection(basename.c_str()); + std::vector document_offsets; + std::vector frequency_offsets; + document_offsets.push_back(8); + frequency_offsets.push_back(0); + for (auto const &postings : collection) { + auto offset = (1 + postings.docs.size()) * sizeof(std::uint32_t); + document_offsets.push_back(document_offsets.back() + offset); + frequency_offsets.push_back(frequency_offsets.back() + offset); + } + auto source = std::make_unique>(); + open_source(source->first, basename + ".docs"); + open_source(source->second, basename + ".freqs"); + auto documents = gsl::make_span( + reinterpret_cast(source->first.data()), source->first.size()); + auto frequencies = gsl::make_span( + reinterpret_cast(source->second.data()), source->second.size()); + return Index, + RawReader, + std::unique_ptr>>( + {}, + {}, + std::move(document_offsets), + std::move(frequency_offsets), + documents, + frequencies, + std::move(source)); +} -template -struct ZipCursor { - using Key = decltype(*std::declval()); - using Payload = decltype(*std::declval()); +template +struct BigramIndex : public Index { + using PairMapping = std::vector>; - constexpr auto operator*() -> Key { return *m_key_cursor; } - constexpr auto next() -> tl::optional> + BigramIndex(Index index, PairMapping pair_mapping) + : Index(std::move(index)), m_pair_mapping(std::move(pair_mapping)) { - return m_key_cursor.next().and_then([&](Key key) { - return m_payload_cursor.next().map( - [key](Payload payload) 
{ return std::make_pair(key, payload); }); - }); } - constexpr void step() + + [[nodiscard]] auto bigram_id(TermId left, TermId right) -> tl::optional { - m_key_cursor.step(); - m_payload_cursor.step(); + auto pos = + std::find(m_pair_mapping.begin(), m_pair_mapping.end(), std::make_pair(left, right)); + if (pos != m_pair_mapping.end()) { + return tl::make_optional(std::distance(m_pair_mapping.begin(), pos)); + } + return tl::nullopt; } - constexpr auto empty() -> bool { return m_key_cursor.empty(); } - [[nodiscard]] constexpr auto size() const -> std::size_t { return m_key_cursor.size(); } private: - KeyCursor m_key_cursor; - PayloadCursor m_payload_cursor; + PairMapping m_pair_mapping; }; -template -struct ZippedIndex { - using DocumentCursor = - decltype(std::declval().read(std::declval>())); - using PayloadCursor = - decltype(std::declval().read(std::declval>())); - static_assert(std::is_same_v()), DocId>); - // static_assert(std::is_same_v()), Frequency>); +/// Creates, on the fly, a bigram index with all pairs of adjecent terms. +/// Disclaimer: for testing purposes. +[[nodiscard]] inline auto binary_collection_bigram_index(std::string const &basename) +{ + using payload_type = std::array; - [[nodiscard]] auto cursor(TermId term) -> ZipCursor; + auto unigram_index = binary_collection_index(basename); - private: - [[nodiscard]] auto fetch(TermId term) -> gsl::span; + std::vector> pair_mapping; + std::vector documents; + std::vector payloads; - DocumentReader m_document_reader; - PayloadReader m_payload_reader; -}; + std::vector document_offsets; + std::vector payload_offsets; + + { + // Hack to be backwards-compatible with binary_freq_collection (for now). 
+ documents.insert(documents.begin(), 8, std::byte{0}); + } + + document_offsets.push_back(documents.size()); + payload_offsets.push_back(payloads.size()); + for (TermId left = 0; left < unigram_index.num_terms() - 1; left += 1) { + auto right = left + 1; + RawWriter document_writer; + RawWriter payload_writer; + auto inter = + CursorIntersection(std::vector{unigram_index.cursor(left), unigram_index.cursor(right)}, + payload_type{0, 0}, + [](payload_type &payload, auto &cursor, auto list_idx) { + payload[list_idx] = *cursor.payload(); + return payload; + }); + if (inter.empty()) { + // Include only non-empty intersections. + continue; + } + pair_mapping.emplace_back(left, right); + while (not inter.empty()) { + document_writer.push(*inter); + payload_writer.push(inter.payload()); + inter.step(); + } + document_writer.append(std::back_inserter(documents)); + payload_writer.append(std::back_inserter(payloads)); + document_offsets.push_back(documents.size()); + payload_offsets.push_back(payloads.size()); + } + + { + // Hack to be backwards-compatible with binary_freq_collection (for now). 
+ auto one_bytes = std::array{}; + auto size_bytes = std::array{}; + auto num_bigrams = static_cast(document_offsets.size()); + std::uint32_t one = 1; + std::memcpy(&size_bytes, &num_bigrams, 4); + std::memcpy(&one_bytes, &one, 4); + std::copy(one_bytes.begin(), one_bytes.end(), documents.begin()); + std::copy(size_bytes.begin(), size_bytes.end(), std::next(documents.begin(), 4)); + } + + auto source = std::array, 2>{std::move(documents), std::move(payloads)}; + auto document_span = gsl::make_span(source[0]); + auto payload_span = gsl::make_span(source[1]); + auto index = + Index, RawReader, std::array, 2>>( + {}, + {}, + std::move(document_offsets), + std::move(payload_offsets), + document_span, + payload_span, + std::move(source)); + return BigramIndex(std::move(index), std::move(pair_mapping)); +} } // namespace pisa::v1 diff --git a/include/pisa/v1/raw_cursor.hpp b/include/pisa/v1/raw_cursor.hpp new file mode 100644 index 000000000..73d42a1e9 --- /dev/null +++ b/include/pisa/v1/raw_cursor.hpp @@ -0,0 +1,171 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +#include "util/likely.hpp" +#include "v1/bit_cast.hpp" + +namespace pisa::v1 { + +/// Uncompressed example of implementation of a single value cursor. +template +struct RawCursor { + static_assert(std::is_trivially_copyable_v); + + /// Creates a cursor from the encoded bytes. + explicit constexpr RawCursor(gsl::span bytes); + + /// Dereferences the current value. + /// It is an undefined behavior to call this when `empty() == true`. + [[nodiscard]] constexpr auto operator*() const -> T; + + /// Safely returns the current value, or returns `nullopt` if `empty() == true`. + [[nodiscard]] constexpr auto value() const noexcept -> tl::optional; + + /// Moves the cursor to the next position. + constexpr void step(); + + /// Moves the cursor to the position `pos`. 
+ constexpr void step_to_position(std::size_t pos); + + /// Moves the cursor to the next value equal or greater than `value`. + constexpr void step_to_geq(T value); + + /// This is semantically equivalent to first calling `step()` and then `value()`. + constexpr auto next() -> tl::optional; + + /// Returns `true` if there is no elements left. + [[nodiscard]] constexpr auto empty() const noexcept -> bool; + + /// Returns the current position. + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; + + /// Returns the number of elements in the list. + [[nodiscard]] constexpr auto size() const -> std::size_t; + + /// The sentinel value, such that `value() != nullopt` is equivalent to `*(*this) < sentinel()`. + [[nodiscard]] constexpr auto sentinel() const -> T; + + private: + std::size_t m_current = 0; + gsl::span m_bytes; +}; + +template +constexpr RawCursor::RawCursor(gsl::span bytes) : m_bytes(bytes) +{ + Expects(bytes.size() % sizeof(T) == 0); +} + +template +[[nodiscard]] constexpr auto RawCursor::operator*() const -> T +{ + if (PISA_UNLIKELY(empty())) { + return sentinel(); + } + return bit_cast(gsl::as_bytes(m_bytes.subspan(m_current, sizeof(T)))); +} + +template +[[nodiscard]] constexpr auto RawCursor::sentinel() const -> T +{ + return std::numeric_limits::max(); +} + +template +[[nodiscard]] constexpr auto RawCursor::value() const noexcept -> tl::optional +{ + return empty() ? 
tl::nullopt : tl::make_optional(*(*this)); +} +template +constexpr auto RawCursor::next() -> tl::optional +{ + step(); + return value(); +} +template +constexpr void RawCursor::step() +{ + m_current += sizeof(T); +} + +template +[[nodiscard]] constexpr auto RawCursor::empty() const noexcept -> bool +{ + return m_current == m_bytes.size(); +} + +template +[[nodiscard]] constexpr auto RawCursor::position() const noexcept -> std::size_t +{ + return m_current; +} + +template +[[nodiscard]] constexpr auto RawCursor::size() const -> std::size_t +{ + return m_bytes.size() / sizeof(T); +} + +template +constexpr void RawCursor::step_to_position(std::size_t pos) +{ + m_current = pos; +} + +template +constexpr void RawCursor::step_to_geq(T value) +{ + while (not empty() && *(*this) < value) { + step(); + } +} + +template +struct RawReader { + static_assert(std::is_trivially_copyable::value); + + [[nodiscard]] auto read(gsl::span bytes) const -> RawCursor + { + return RawCursor(bytes); + } +}; + +template +struct RawWriter { + static_assert(std::is_trivially_copyable::value); + + void push(T const &posting) { m_postings.push_back(posting); } + void push(T &&posting) { m_postings.push_back(posting); } + + void write(std::ostream &os) const + { + assert(!m_postings.empty()); + auto memory = gsl::as_bytes(gsl::make_span(m_postings.data())); + os.write(memory, memory.size()); + } + + template + auto append(OutputByteIterator out) const -> OutputByteIterator + { + assert(!m_postings.empty()); + std::uint32_t length = m_postings.size(); + auto length_bytes = gsl::as_bytes(gsl::make_span(&length, 1)); + auto memory = gsl::as_bytes(gsl::make_span(m_postings.data(), m_postings.size())); + std::copy(length_bytes.begin(), length_bytes.end(), out); + std::copy(memory.begin(), memory.end(), out); + return out; + } + + private: + std::vector m_postings; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/types.hpp b/include/pisa/v1/types.hpp new file mode 100644 index 
000000000..e36e0ca1b --- /dev/null +++ b/include/pisa/v1/types.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace pisa::v1 { + +using TermId = std::uint32_t; +using DocId = std::uint32_t; +using Frequency = std::uint32_t; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/vector_lexicon.hpp b/include/pisa/v1/vector_lexicon.hpp new file mode 100644 index 000000000..9df28e1a8 --- /dev/null +++ b/include/pisa/v1/vector_lexicon.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include + +#include +#include + +#include "binary_freq_collection.hpp" +#include "v1/types.hpp" + +namespace pisa::v1 { + +struct VectorLexicon { + explicit VectorLexicon(binary_freq_collection const &collection) + { + m_offsets.push_back(0); + for (auto const &postings : collection) { + m_offsets.push_back(postings.docs.size()); + } + } + + [[nodiscard]] auto fetch(TermId term, gsl::span bytes) + -> gsl::span + { + Expects(term + 1 < m_offsets.size()); + return bytes.subspan(m_offsets[term], m_offsets[term + 1]); + } + + private: + std::vector m_offsets{}; +}; + +} // namespace pisa::v1 diff --git a/test/test_v1.cpp b/test/test_v1.cpp index 83e1b5c19..6fa0c7b89 100644 --- a/test/test_v1.cpp +++ b/test/test_v1.cpp @@ -1,19 +1,207 @@ #define CATCH_CONFIG_MAIN #include "catch2/catch.hpp" +#include + +#include + +//#include "cursor/scored_cursor.hpp" +#include "pisa_config.hpp" +//#include "query/queries.hpp" +//#include "test_common.hpp" #include "v1/index.hpp" +#include "v1/types.hpp" +//#include "wand_utils.hpp" +using pisa::v1::DocId; +using pisa::v1::Frequency; using pisa::v1::RawReader; +using pisa::v1::TermId; +using namespace pisa; + +// template +// struct IndexData { +// +// static std::unique_ptr data; +// +// IndexData() +// : collection(PISA_SOURCE_DIR "/test/test_data/test_collection"), +// document_sizes(PISA_SOURCE_DIR "/test/test_data/test_collection.sizes"), +// wdata(document_sizes.begin()->begin(), +// collection.num_docs(), +// collection, +// BlockSize(FixedBlock())) +// +// 
{ +// tbb::task_scheduler_init init; +// typename Index::builder builder(collection.num_docs(), params); +// for (auto const &plist : collection) { +// uint64_t freqs_sum = +// std::accumulate(plist.freqs.begin(), plist.freqs.end(), uint64_t(0)); +// builder.add_posting_list( +// plist.docs.size(), plist.docs.begin(), plist.freqs.begin(), freqs_sum); +// } +// builder.build(index); +// +// term_id_vec q; +// std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); +// auto push_query = [&](std::string const &query_line) { +// queries.push_back(parse_query_ids(query_line)); +// }; +// io::for_each_line(qfile, push_query); +// +// std::string t; +// std::ifstream tin(PISA_SOURCE_DIR "/test/test_data/top5_thresholds"); +// while (std::getline(tin, t)) { +// thresholds.push_back(std::stof(t)); +// } +// } +// +// [[nodiscard]] static auto get() +// { +// if (IndexData::data == nullptr) { +// IndexData::data = std::make_unique>(); +// } +// return IndexData::data.get(); +// } +// +// global_parameters params; +// binary_freq_collection collection; +// binary_collection document_sizes; +// Index index; +// std::vector queries; +// std::vector thresholds; +// wand_data wdata; +//}; +// +// template +// std::unique_ptr> IndexData::data = nullptr; + +template +std::ostream &operator<<(std::ostream &os, tl::optional const &val) +{ + if (val.has_value()) { + os << val.value(); + } else { + os << "nullopt"; + } + return os; +} TEST_CASE("RawReader", "[v1][unit]") { std::vector const mem{0, 1, 2, 3, 4}; - RawReader reader; + RawReader reader; auto cursor = reader.read(gsl::as_bytes(gsl::make_span(mem))); - REQUIRE(cursor.next() == tl::make_optional(mem[0])); + REQUIRE(cursor.value().value() == tl::make_optional(mem[0]).value()); REQUIRE(cursor.next() == tl::make_optional(mem[1])); REQUIRE(cursor.next() == tl::make_optional(mem[2])); REQUIRE(cursor.next() == tl::make_optional(mem[3])); REQUIRE(cursor.next() == tl::make_optional(mem[4])); REQUIRE(cursor.next() == 
/// Drains `cursor`, applying `transform` to the cursor at each position, and
/// returns the transformed values in order.
template <typename Cursor, typename Transform>
auto collect(Cursor &&cursor, Transform transform)
{
    using value_type = std::decay_t<decltype(transform(cursor))>;
    std::vector<value_type> collected;
    for (; not cursor.empty(); cursor.step()) {
        collected.push_back(transform(cursor));
    }
    return collected;
}

/// Drains `cursor` and returns its dereferenced values in order.
template <typename Cursor>
auto collect(Cursor &&cursor)
{
    auto dereference = [](auto &&cur) { return *cur; };
    return collect(std::forward<Cursor>(cursor), dereference);
}
intersection; + }; + auto to_vec = [](auto const &seq) { + std::vector> vec; + std::transform(seq.docs.begin(), + seq.docs.end(), + seq.freqs.begin(), + std::back_inserter(vec), + [](auto doc, auto freq) { return std::make_pair(doc, freq); }); + return vec; + }; + + binary_freq_collection collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); + auto index = + pisa::v1::binary_collection_bigram_index(PISA_SOURCE_DIR "/test/test_data/test_collection"); + + auto pos = collection.begin(); + auto prev = to_vec(*pos); + ++pos; + TermId term_id = 1; + for (; pos != collection.end(); ++pos, ++term_id) { + auto current = to_vec(*pos); + auto intersection = intersect(prev, current); + if (not intersection.empty()) { + auto id = index.bigram_id(term_id - 1, term_id); + REQUIRE(id.has_value()); + auto postings = collect(index.cursor(*id), [](auto &cursor) { + auto freqs = *cursor.payload(); + return std::make_tuple(*cursor, freqs[0], freqs[1]); + }); + REQUIRE(postings == intersection); + } + std::swap(prev, current); + break; + } +} From 2223a6e6b438ab6101244bc4e4f46898f29e138e Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Sun, 27 Oct 2019 10:46:20 -0400 Subject: [PATCH 06/56] Type-erase source and reader. --- include/pisa/v1/README.md | 64 +++++++++++++++++++++++++ include/pisa/v1/cursor_intersection.hpp | 15 +++--- include/pisa/v1/index.hpp | 64 ++++++++++++------------- include/pisa/v1/query.hpp | 42 ++++++++++++++++ include/pisa/v1/raw_cursor.hpp | 22 +++++++++ include/pisa/v1/types.hpp | 3 ++ 6 files changed, 171 insertions(+), 39 deletions(-) create mode 100644 include/pisa/v1/README.md create mode 100644 include/pisa/v1/query.hpp diff --git a/include/pisa/v1/README.md b/include/pisa/v1/README.md new file mode 100644 index 000000000..9fa95fff5 --- /dev/null +++ b/include/pisa/v1/README.md @@ -0,0 +1,64 @@ +> This document is a **work in progress**. 
+ +# Introduction + +In our efforts to come up with the v1.0 of both PISA and our index format, +we should start a discussion about the shape of things from the point of view +of both the binary format and how we can use it in our library. + +## Index Format specification + +This document mainly discusses the binary file format of each index component, +as well as how these components come together to form a cohesive structure. + +## Reference Implementation + +Along with format description and discussion, this directory includes some +reference implementation of the discussed structures and some algorithms working on them. + +The goal of this is to show how things work on certain examples, +and find out what works and what doesn't and still needs to be thought through. + +# Posting Files + +Each _posting file_ contains a list of blocks of data, each related to a single term, +preceded by a header encoding information about the type of payload. + +``` +Posting File := Header, [Posting Block] +``` + +Each posting block encodes a list of homogeneous values, called _postings_. +Encoding is not fixed. + +## Header + +We should store the type of the postings in the file, as well as encoding used. +**This might be tricky because we'd like it to be an open set of values/encodings.** + +``` +Header := Type, Encoding +``` + +## Posting Types + +I think supporting these types will be sufficient to express about anything we +would want to, including single-value lists, document-frequency (or score) lists, +positional indexes, etc. + +``` +Type := int32 | float32 | List[Type] | Tuple[Type] +``` + +## Encodings + +We can identify encodings by either a name or ID/hash, or both. +I can imagine that an index reader could **register** new encodings, +and default to whatever we define in PISA. +We should then also verify that this encoding implement a `Encoding` "concept". +This is not the same as our "codecs". +This would be more like posting list reader. + +``` +Encoding := ?? 
+``` diff --git a/include/pisa/v1/cursor_intersection.hpp b/include/pisa/v1/cursor_intersection.hpp index 23a382022..e41b27485 100644 --- a/include/pisa/v1/cursor_intersection.hpp +++ b/include/pisa/v1/cursor_intersection.hpp @@ -73,15 +73,9 @@ struct CursorIntersection { if (m_next_cursor == m_cursors.size()) { m_current_payload = m_init; for (auto idx = 0; idx < m_cursors.size(); ++idx) { - //if (m_candidate == 116) { - // std::cout << *m_cursors[idx].get().payload() << ' '; - //} m_current_payload = m_accumulate( m_current_payload, m_cursors[idx].get(), m_cursor_mapping[idx]); } - //if (m_candidate == 116) { - // std::cout << '\n'; - //} m_cursors[0].get().step(); m_current_value = std::exchange(m_candidate, *m_cursors[0].get()); m_next_cursor = 1; @@ -128,4 +122,13 @@ struct CursorIntersection { std::uint32_t m_next_cursor = 1; }; +template +[[nodiscard]] constexpr inline auto intersect(CursorContainer cursors, + Payload init, + AccumulateFn accumulate) +{ + return CursorIntersection( + std::move(cursors), std::move(init), std::move(accumulate)); +} + } // namespace pisa diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index fefc55af8..aa40ba9a9 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -26,19 +27,17 @@ namespace pisa::v1 { /// \tparam PayloadReader Type of an object that reads payload posting lists from bytes. /// It can read lists of arbitrary types, such as `Frequency`, /// `Score`, or `std::pair` for a bigram scored index. -/// \tparam Source Can be used to store any owning data, like open `mmap`, since -/// index internally uses spans to manage encoded parts of memory. -template +template struct Index { - /// The type of cursor constructed by the document reader. Must read `DocId` values. 
- using DocumentCursor = - decltype(std::declval().read(std::declval>())); - static_assert(std::is_same_v()), DocId>); + ///// The type of cursor constructed by the document reader. Must read `DocId` values. + // using DocumentCursor = + // decltype(std::declval().read(std::declval>())); + // static_assert(std::is_same_v()), DocId>); - /// The type of cursor constructed by the payload reader. - using PayloadCursor = - decltype(std::declval().read(std::declval>())); + ///// The type of cursor constructed by the payload reader. + // using PayloadCursor = + // decltype(std::declval().read(std::declval>())); /// Constructs the index. /// @@ -53,6 +52,10 @@ struct Index { /// \param source This object (optionally) owns the raw data pointed at by /// `documents` and `payloads` to ensure it is valid throughout /// the lifetime of the index. + /// + /// \tparam Source Can be used to store any owning data, like open `mmap`, since + /// index internally uses spans to manage encoded parts of memory. + template Index(DocumentReader document_reader, PayloadReader payload_reader, std::vector document_offsets, @@ -106,13 +109,13 @@ struct Index { m_payload_offsets[term + 1] - m_payload_offsets[term]); } - DocumentReader m_document_reader; - PayloadReader m_payload_reader; + Reader m_document_reader; + Reader m_payload_reader; std::vector m_document_offsets; std::vector m_payload_offsets; gsl::span m_documents; gsl::span m_payloads; - Source m_source; + std::any m_source; }; /// Initializes a memory mapped source with a given file. 
@@ -138,23 +141,20 @@ inline void open_source(mio::mmap_source &source, std::string const &filename) document_offsets.push_back(document_offsets.back() + offset); frequency_offsets.push_back(frequency_offsets.back() + offset); } - auto source = std::make_unique>(); + auto source = std::make_shared>(); open_source(source->first, basename + ".docs"); open_source(source->second, basename + ".freqs"); auto documents = gsl::make_span( reinterpret_cast(source->first.data()), source->first.size()); auto frequencies = gsl::make_span( reinterpret_cast(source->second.data()), source->second.size()); - return Index, - RawReader, - std::unique_ptr>>( - {}, - {}, - std::move(document_offsets), - std::move(frequency_offsets), - documents, - frequencies, - std::move(source)); + return Index, RawCursor>(RawReader{}, + RawReader{}, + std::move(document_offsets), + std::move(frequency_offsets), + documents, + frequencies, + std::move(source)); } template @@ -244,15 +244,13 @@ struct BigramIndex : public Index { auto source = std::array, 2>{std::move(documents), std::move(payloads)}; auto document_span = gsl::make_span(source[0]); auto payload_span = gsl::make_span(source[1]); - auto index = - Index, RawReader, std::array, 2>>( - {}, - {}, - std::move(document_offsets), - std::move(payload_offsets), - document_span, - payload_span, - std::move(source)); + auto index = Index, RawCursor>(RawReader{}, + RawReader{}, + std::move(document_offsets), + std::move(payload_offsets), + document_span, + payload_span, + std::move(source)); return BigramIndex(std::move(index), std::move(pair_mapping)); } diff --git a/include/pisa/v1/query.hpp b/include/pisa/v1/query.hpp new file mode 100644 index 000000000..2f4964ef0 --- /dev/null +++ b/include/pisa/v1/query.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +#include "topk_queue.hpp" +#include "v1/cursor_intersection.hpp" +#include "v1/types.hpp" + +namespace pisa::v1 { + +struct Query { + std::vector terms; +}; + +template +using 
QueryProcessor = std::function; + +struct ExhaustiveConjunctiveProcessor { + template + auto operator()(Index const &index, Query const &query, topk_queue que) -> topk_queue + { + using Cursor = std::decay_t; + std::vector cursors; + std::transform(query.terms.begin(), + query.terms.end(), + std::back_inserter(cursors), + [&index](auto term_id) { return index.cursor(term_id); }); + auto intersection = + intersect(std::move(cursors), + 0.0F, + [](float score, auto &cursor, [[maybe_unused]] auto cursor_idx) { + return score + static_cast(cursor.payload()); + }); + while (not intersection.empty()) { + que.insert(intersection.payload(), *intersection); + } + return que; + } +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/raw_cursor.hpp b/include/pisa/v1/raw_cursor.hpp index 73d42a1e9..963cee664 100644 --- a/include/pisa/v1/raw_cursor.hpp +++ b/include/pisa/v1/raw_cursor.hpp @@ -128,6 +128,28 @@ constexpr void RawCursor::step_to_geq(T value) } } +template +struct Reader { + using Value = std::decay_t())>; + static_assert(std::is_trivially_copyable::value); + + template + explicit constexpr Reader(ReaderImpl &&reader) + { + m_read = [reader = std::forward(reader)](gsl::span bytes) { + return reader.read(bytes); + }; + } + + [[nodiscard]] auto read(gsl::span bytes) const -> Cursor + { + return m_read(bytes); + } + + private: + std::function)> m_read; +}; + template struct RawReader { static_assert(std::is_trivially_copyable::value); diff --git a/include/pisa/v1/types.hpp b/include/pisa/v1/types.hpp index e36e0ca1b..510bb92d5 100644 --- a/include/pisa/v1/types.hpp +++ b/include/pisa/v1/types.hpp @@ -1,11 +1,14 @@ #pragma once #include +#include namespace pisa::v1 { using TermId = std::uint32_t; using DocId = std::uint32_t; using Frequency = std::uint32_t; +using Score = float; +using Result = std::pair; } // namespace pisa::v1 From 49712b553372b159b473e9313385cd343f419267 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Sun, 27 Oct 2019 12:17:02 -0400 
Subject: [PATCH 07/56] Add tl::expected --- .gitmodules | 3 +++ external/expected | 1 + include/pisa/v1/README.md | 15 +++++++++++++-- 3 files changed, 17 insertions(+), 2 deletions(-) create mode 160000 external/expected diff --git a/.gitmodules b/.gitmodules index 8caf14a27..8e4d1212a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -70,3 +70,6 @@ [submodule "external/optional"] path = external/optional url = https://github.com/TartanLlama/optional.git +[submodule "external/expected"] + path = external/expected + url = https://github.com/TartanLlama/expected.git diff --git a/external/expected b/external/expected new file mode 160000 index 000000000..0ca73ee30 --- /dev/null +++ b/external/expected @@ -0,0 +1 @@ +Subproject commit 0ca73ee30e72a54b285ef1e26fd9ecc83c81e5ea diff --git a/include/pisa/v1/README.md b/include/pisa/v1/README.md index 9fa95fff5..603120d30 100644 --- a/include/pisa/v1/README.md +++ b/include/pisa/v1/README.md @@ -24,6 +24,11 @@ and find out what works and what doesn't and still needs to be thought through. Each _posting file_ contains a list of blocks of data, each related to a single term, preceded by a header encoding information about the type of payload. +> Do we need the header? I would say "yes" because even if we store the information +> somewhere else, then we might want to (1) verify that we are reading what we think +> we are reading, and (2) verify format version compatibility. +> The latter should be further discussed. + ``` Posting File := Header, [Posting Block] ``` @@ -31,13 +36,18 @@ Posting File := Header, [Posting Block] Each posting block encodes a list of homogeneous values, called _postings_. Encoding is not fixed. +> Note that _block_ here means the entire posting list area. +> We can work on the terminology. + ## Header We should store the type of the postings in the file, as well as encoding used. 
**This might be tricky because we'd like it to be an open set of values/encodings.** ``` -Header := Type, Encoding +Header := Version, Type, Encoding +Version := Major, Minor, Path +Type := ValueId, Count ``` ## Posting Types @@ -47,7 +57,8 @@ would want to, including single-value lists, document-frequency (or score) lists positional indexes, etc. ``` -Type := int32 | float32 | List[Type] | Tuple[Type] +Type := Primitive | List[Type] | Tuple[Type] +Primitive := int32 | float32 ``` ## Encodings From 6833d254d8ecdb5c2dd814482f9ddc022e47ddc2 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Sun, 27 Oct 2019 12:25:06 -0400 Subject: [PATCH 08/56] Add tl::expected --- CMakeLists.txt | 1 + external/CMakeLists.txt | 10 ++++++++-- external/expected | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 84f95f8ae..6f138a65a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -96,6 +96,7 @@ target_link_libraries(pisa INTERFACE fmt::fmt range-v3 optional + expected ) target_include_directories(pisa INTERFACE external) diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt index 758aad91f..30c1c433f 100644 --- a/external/CMakeLists.txt +++ b/external/CMakeLists.txt @@ -118,5 +118,11 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/spdlog) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/range-v3) # Add tl::optional -set(OPTIONAL_ENABLE_TESTS OFF CACHE BOOL "skip optional testing") -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/optional) +set(OPTIONAL_ENABLE_TESTS OFF CACHE BOOL "skip tl::optional testing") +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/optional EXCLUDE_FROM_ALL) + +# Add tl::expected +set(EXPECTED_BUILD_TESTS OFF CACHE BOOL "skip tl::expected testing") +set(EXPECTED_BUILD_PACKAGE OFF CACHE BOOL "skip tl::expected package") +set(EXPECTED_BUILD_PACKAGE_DEB OFF CACHE BOOL "skip tl::expected package deb") +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/expected EXCLUDE_FROM_ALL) diff --git a/external/expected 
b/external/expected index 0ca73ee30..3d741708b 160000 --- a/external/expected +++ b/external/expected @@ -1 +1 @@ -Subproject commit 0ca73ee30e72a54b285ef1e26fd9ecc83c81e5ea +Subproject commit 3d741708b967b83ca1e2888239196c4a67f9f9b0 From b1892520a4489132ae0cd26d06bbc777f82fbf9d Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Sun, 27 Oct 2019 22:05:12 -0400 Subject: [PATCH 09/56] Posting header and builder --- CMakeLists.txt | 1 - external/CMakeLists.txt | 8 +- include/pisa/v1/bit_cast.hpp | 2 +- include/pisa/v1/index.hpp | 23 +-- include/pisa/v1/posting_builder.hpp | 54 +++++++ include/pisa/v1/posting_format_header.hpp | 168 ++++++++++++++++++++++ include/pisa/v1/raw_cursor.hpp | 40 ++---- include/pisa/v1/types.hpp | 86 +++++++++++ test/test_v1.cpp | 164 ++++++++++++++++++++- 9 files changed, 487 insertions(+), 59 deletions(-) create mode 100644 include/pisa/v1/posting_builder.hpp create mode 100644 include/pisa/v1/posting_format_header.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 6f138a65a..84f95f8ae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -96,7 +96,6 @@ target_link_libraries(pisa INTERFACE fmt::fmt range-v3 optional - expected ) target_include_directories(pisa INTERFACE external) diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt index 30c1c433f..f40eb3ee7 100644 --- a/external/CMakeLists.txt +++ b/external/CMakeLists.txt @@ -122,7 +122,7 @@ set(OPTIONAL_ENABLE_TESTS OFF CACHE BOOL "skip tl::optional testing") add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/optional EXCLUDE_FROM_ALL) # Add tl::expected -set(EXPECTED_BUILD_TESTS OFF CACHE BOOL "skip tl::expected testing") -set(EXPECTED_BUILD_PACKAGE OFF CACHE BOOL "skip tl::expected package") -set(EXPECTED_BUILD_PACKAGE_DEB OFF CACHE BOOL "skip tl::expected package deb") -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/expected EXCLUDE_FROM_ALL) +#set(EXPECTED_BUILD_TESTS OFF CACHE BOOL "skip tl::expected testing") +#set(EXPECTED_BUILD_PACKAGE OFF CACHE BOOL "skip tl::expected 
package") +#set(EXPECTED_BUILD_PACKAGE_DEB OFF CACHE BOOL "skip tl::expected package deb") +#add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/expected EXCLUDE_FROM_ALL) diff --git a/include/pisa/v1/bit_cast.hpp b/include/pisa/v1/bit_cast.hpp index 54ab1aaf3..dd7e87523 100644 --- a/include/pisa/v1/bit_cast.hpp +++ b/include/pisa/v1/bit_cast.hpp @@ -9,7 +9,7 @@ namespace pisa::v1 { template constexpr auto bit_cast(gsl::span mem) -> T { - T dst; + T dst{}; std::memcpy(&dst, mem.data(), sizeof(T)); return dst; } diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index aa40ba9a9..cebae0bd4 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -29,16 +29,6 @@ namespace pisa::v1 { /// `Score`, or `std::pair` for a bigram scored index. template struct Index { - - ///// The type of cursor constructed by the document reader. Must read `DocId` values. - // using DocumentCursor = - // decltype(std::declval().read(std::declval>())); - // static_assert(std::is_same_v()), DocId>); - - ///// The type of cursor constructed by the payload reader. - // using PayloadCursor = - // decltype(std::declval().read(std::declval>())); - /// Constructs the index. /// /// \param document_reader Reads document posting lists from bytes. @@ -51,10 +41,8 @@ struct Index { /// \param payloads Encoded bytes for payload postings. /// \param source This object (optionally) owns the raw data pointed at by /// `documents` and `payloads` to ensure it is valid throughout - /// the lifetime of the index. - /// - /// \tparam Source Can be used to store any owning data, like open `mmap`, since - /// index internally uses spans to manage encoded parts of memory. + /// the lifetime of the index. It should release any resources + /// in its destructor. template Index(DocumentReader document_reader, PayloadReader payload_reader, @@ -83,14 +71,11 @@ struct Index { /// Constructs a new document cursor. 
[[nodiscard]] auto documents(TermId term) { - return m_document_reader.read(fetch_documents(term).subspan(4)); + return m_document_reader.read(fetch_documents(term)); } /// Constructs a new payload cursor. - [[nodiscard]] auto payloads(TermId term) - { - return m_payload_reader.read(fetch_payloads(term).subspan(4)); - } + [[nodiscard]] auto payloads(TermId term) { return m_payload_reader.read(fetch_payloads(term)); } /// Constructs a new payload cursor. [[nodiscard]] auto num_terms() -> std::uint32_t { return m_document_offsets.size() - 1; } diff --git a/include/pisa/v1/posting_builder.hpp b/include/pisa/v1/posting_builder.hpp new file mode 100644 index 000000000..125687ef8 --- /dev/null +++ b/include/pisa/v1/posting_builder.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include +#include + +#include + +#include "v1/posting_format_header.hpp" +#include "v1/types.hpp" + +namespace pisa::v1 { + +/// Builds a "posting file" from passed values. +/// +/// TODO: Probably the offsets should be part of the file along with the size. 
+template +struct PostingBuilder { + template + explicit PostingBuilder(WriterImpl writer) : m_writer(Writer(std::move(writer))) + { + m_offsets.push_back(0); + } + + void write_header(std::ostream &os) const + { + std::array header{}; + PostingFormatHeader{ + .version = FormatVersion::current(), .type = value_type(), .encoding = 0} + .write(gsl::make_span(header)); + os.write(reinterpret_cast(header.data()), header.size()); + } + + template + auto write_segment(std::ostream &os, ValueIterator first, ValueIterator last) -> std::ostream & + { + std::for_each(first, last, [&](auto &&value) { m_writer.push(value); }); + m_offsets.push_back(m_offsets.back() + m_writer.write(os)); + m_writer.reset(); + return os; + } + + [[nodiscard]] auto offsets() const -> gsl::span + { + return gsl::make_span(m_offsets); + } + + [[nodiscard]] auto offsets() -> std::vector && { return std::move(m_offsets); } + + private: + Writer m_writer; + std::vector m_offsets{}; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/posting_format_header.hpp b/include/pisa/v1/posting_format_header.hpp new file mode 100644 index 000000000..8ba6cafb3 --- /dev/null +++ b/include/pisa/v1/posting_format_header.hpp @@ -0,0 +1,168 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include "v1/bit_cast.hpp" +#include "v1/types.hpp" + +namespace pisa::v1 { + +template +auto write_little_endian(Int number, gsl::span bytes) +{ + static_assert(std::is_integral_v); + Expects(bytes.size() == sizeof(Int)); + + Int mask{0xFF}; + for (unsigned int byte_num = 0; byte_num < sizeof(Int); byte_num += 1) { + auto byte_value = static_cast((number & mask) >> (8U * byte_num)); + mask <<= 8U; + } +} + +struct FormatVersion { + std::uint8_t major = 0; + std::uint8_t minor = 0; + std::uint8_t patch = 0; + + constexpr static auto parse(gsl::span bytes) -> FormatVersion + { + Expects(bytes.size() == 3); + return FormatVersion{ + bit_cast(bytes.first(1)), + bit_cast(bytes.subspan(1, 1)), + 
bit_cast(bytes.subspan(2, 1)), + }; + }; + + constexpr auto write(gsl::span bytes) -> void + { + bytes[0] = std::byte{major}; + bytes[1] = std::byte{minor}; + bytes[2] = std::byte{patch}; + }; + + constexpr static auto current() -> FormatVersion { return FormatVersion{0, 1, 0}; }; +}; + +enum class Primitive { Int = 0, Float = 1 }; + +struct Array { + Primitive type; +}; + +struct Tuple { + Primitive type; + std::uint8_t size; +}; + +using ValueType = std::variant; + +template +constexpr static auto value_type() -> ValueType +{ + if constexpr (std::is_integral_v) { + return Primitive::Int; + } else if constexpr (std::is_floating_point_v) { + return Primitive::Float; + } else { + // TODO(michal): array and tuple + throw std::runtime_error(""); + } +} + +constexpr auto parse_type(std::byte const byte) -> ValueType +{ + auto element_type = [byte]() { + switch (std::to_integer((byte & std::byte{0b00000100}) >> 2)) { + case 0U: + return Primitive::Int; + case 1U: + return Primitive::Float; + } + Unreachable(); + }; + switch (std::to_integer(byte & std::byte{0b00000011})) { + case 0U: + return Primitive::Int; + case 1U: + return Primitive::Float; + case 2U: + return Array{element_type()}; + case 3U: + return Tuple{element_type(), + std::to_integer((byte & std::byte{0b11111000}) >> 3)}; + } + Unreachable(); +}; + +constexpr auto to_byte(ValueType type) -> std::byte +{ + std::byte byte{}; + std::visit(overloaded{[&byte](Primitive primitive) { + switch (primitive) { + case Primitive::Int: + byte = std::byte{0b00000000}; + break; + case Primitive::Float: + byte = std::byte{0b00000001}; + break; + } + }, + [&byte](Array arr) { + switch (arr.type) { + case Primitive::Int: + byte = std::byte{0b00000010}; + break; + case Primitive::Float: + byte = std::byte{0b00000110}; + break; + } + }, + [&byte](Tuple tup) { + switch (tup.type) { + case Primitive::Int: + byte = std::byte{0b00000011}; + break; + case Primitive::Float: + byte = std::byte{0b00000111}; + break; + } + byte |= 
(std::byte{tup.size} << 3); + }}, + type); + return byte; +}; + +using Encoding = std::uint32_t; + +struct PostingFormatHeader { + FormatVersion version; + ValueType type; + Encoding encoding; + + constexpr static auto parse(gsl::span bytes) -> PostingFormatHeader + { + Expects(bytes.size() == 8); + auto version = FormatVersion::parse(bytes.first(3)); + auto type = parse_type(bytes[3]); + auto encoding = bit_cast(bytes.subspan(4)); + return {version, type, encoding}; + }; + + void write(gsl::span bytes) + { + Expects(bytes.size() == 8); + FormatVersion::current().write(bytes.first(3)); + bytes[3] = to_byte(type); + write_little_endian(encoding, bytes.subspan(4)); + }; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/raw_cursor.hpp b/include/pisa/v1/raw_cursor.hpp index 963cee664..e5a66a6dc 100644 --- a/include/pisa/v1/raw_cursor.hpp +++ b/include/pisa/v1/raw_cursor.hpp @@ -11,6 +11,7 @@ #include "util/likely.hpp" #include "v1/bit_cast.hpp" +#include "v1/types.hpp" namespace pisa::v1 { @@ -59,9 +60,9 @@ struct RawCursor { }; template -constexpr RawCursor::RawCursor(gsl::span bytes) : m_bytes(bytes) +constexpr RawCursor::RawCursor(gsl::span bytes) : m_bytes(bytes.subspan(4)) { - Expects(bytes.size() % sizeof(T) == 0); + Expects(m_bytes.size() % sizeof(T) == 0); } template @@ -128,28 +129,6 @@ constexpr void RawCursor::step_to_geq(T value) } } -template -struct Reader { - using Value = std::decay_t())>; - static_assert(std::is_trivially_copyable::value); - - template - explicit constexpr Reader(ReaderImpl &&reader) - { - m_read = [reader = std::forward(reader)](gsl::span bytes) { - return reader.read(bytes); - }; - } - - [[nodiscard]] auto read(gsl::span bytes) const -> Cursor - { - return m_read(bytes); - } - - private: - std::function)> m_read; -}; - template struct RawReader { static_assert(std::is_trivially_copyable::value); @@ -158,6 +137,8 @@ struct RawReader { { return RawCursor(bytes); } + + constexpr static auto encoding() -> std::uint64_t { return 
EncodingId::Raw; } }; template @@ -167,11 +148,14 @@ struct RawWriter { void push(T const &posting) { m_postings.push_back(posting); } void push(T &&posting) { m_postings.push_back(posting); } - void write(std::ostream &os) const + [[nodiscard]] auto write(std::ostream &os) const -> std::size_t { assert(!m_postings.empty()); - auto memory = gsl::as_bytes(gsl::make_span(m_postings.data())); - os.write(memory, memory.size()); + std::uint32_t length = m_postings.size(); + os.write(reinterpret_cast(&length), sizeof(length)); + auto memory = gsl::as_bytes(gsl::make_span(m_postings.data(), m_postings.size())); + os.write(reinterpret_cast(memory.data()), memory.size()); + return sizeof(length) + memory.size(); } template @@ -185,6 +169,8 @@ struct RawWriter { std::copy(memory.begin(), memory.end(), out); return out; } + constexpr static auto encoding() -> std::uint64_t { return EncodingId::Raw; } + void reset() { m_postings.clear(); } private: std::vector m_postings; diff --git a/include/pisa/v1/types.hpp b/include/pisa/v1/types.hpp index 510bb92d5..1bddda2cc 100644 --- a/include/pisa/v1/types.hpp +++ b/include/pisa/v1/types.hpp @@ -3,6 +3,10 @@ #include #include +#include + +#define Unreachable() std::abort(); + namespace pisa::v1 { using TermId = std::uint32_t; @@ -11,4 +15,86 @@ using Frequency = std::uint32_t; using Score = float; using Result = std::pair; +enum EncodingId { Raw = 0xda43 }; + +template +struct overloaded : Ts... 
{ + using Ts::operator()...; +}; + +template +overloaded(Ts...)->overloaded; + +template +struct Reader { + using Value = std::decay_t())>; + + template + explicit constexpr Reader(ReaderImpl &&reader) + { + m_read = [reader = std::forward(reader)](gsl::span bytes) { + return reader.read(bytes); + }; + } + + [[nodiscard]] auto read(gsl::span bytes) const -> Cursor + { + return m_read(bytes); + } + + private: + std::function)> m_read; +}; + +template +struct Writer { + + template + explicit constexpr Writer(W writer) : m_internal_writer(std::make_unique>(writer)) + { + } + + void push(T const &posting) { m_internal_writer->push(posting); } + void push(T &&posting) { m_internal_writer->push(posting); } + auto write(std::ostream &os) const -> std::size_t { return m_internal_writer->write(os); } + auto encoding() -> std::int32_t { return m_internal_writer->encoding(); } + void reset() { return m_internal_writer->reset(); } + + struct WriterInterface { + WriterInterface() = default; + WriterInterface(WriterInterface const &) = default; + WriterInterface(WriterInterface &&) noexcept = default; + WriterInterface &operator=(WriterInterface const &) = default; + WriterInterface &operator=(WriterInterface &&) noexcept = default; + virtual ~WriterInterface() = default; + virtual void push(T const &posting) = 0; + virtual void push(T &&posting) = 0; + virtual auto write(std::ostream &os) const -> std::size_t = 0; + virtual auto encoding() -> std::uint32_t = 0; + virtual void reset() = 0; + }; + + template + struct WriterImpl : WriterInterface { + explicit WriterImpl(W writer) : m_writer(std::move(writer)) {} + WriterImpl() = default; + WriterImpl(WriterImpl const &) = default; + WriterImpl(WriterImpl &&) noexcept = default; + WriterImpl &operator=(WriterImpl const &) = default; + WriterImpl &operator=(WriterImpl &&) noexcept = default; + ~WriterImpl() = default; + void push(T const &posting) override { m_writer.push(posting); } + void push(T &&posting) override { 
m_writer.push(posting); } + auto write(std::ostream &os) const -> std::size_t override { return m_writer.write(os); } + auto encoding() -> std::uint32_t override { return W::encoding(); } + void reset() override { return m_writer.reset(); } + + private: + W m_writer; + }; + + private: + std::unique_ptr m_internal_writer; +}; + } // namespace pisa::v1 diff --git a/test/test_v1.cpp b/test/test_v1.cpp index 6fa0c7b89..ab8b369d0 100644 --- a/test/test_v1.cpp +++ b/test/test_v1.cpp @@ -3,21 +3,35 @@ #include +#include +#include #include //#include "cursor/scored_cursor.hpp" +#include "io.hpp" #include "pisa_config.hpp" //#include "query/queries.hpp" //#include "test_common.hpp" #include "v1/index.hpp" +#include "v1/posting_builder.hpp" +#include "v1/posting_format_header.hpp" #include "v1/types.hpp" //#include "wand_utils.hpp" +using pisa::v1::Array; using pisa::v1::DocId; using pisa::v1::Frequency; +using pisa::v1::Index; +using pisa::v1::parse_type; +using pisa::v1::PostingBuilder; +using pisa::v1::PostingFormatHeader; +using pisa::v1::Primitive; +using pisa::v1::RawCursor; using pisa::v1::RawReader; +using pisa::v1::RawWriter; using pisa::v1::TermId; -using namespace pisa; +using pisa::v1::Tuple; +using pisa::v1::Writer; // template // struct IndexData { @@ -90,14 +104,14 @@ std::ostream &operator<<(std::ostream &os, tl::optional const &val) TEST_CASE("RawReader", "[v1][unit]") { - std::vector const mem{0, 1, 2, 3, 4}; - RawReader reader; + std::vector const mem{5, 0, 1, 2, 3, 4}; + RawReader reader; auto cursor = reader.read(gsl::as_bytes(gsl::make_span(mem))); - REQUIRE(cursor.value().value() == tl::make_optional(mem[0]).value()); - REQUIRE(cursor.next() == tl::make_optional(mem[1])); + REQUIRE(cursor.value().value() == tl::make_optional(mem[1]).value()); REQUIRE(cursor.next() == tl::make_optional(mem[2])); REQUIRE(cursor.next() == tl::make_optional(mem[3])); REQUIRE(cursor.next() == tl::make_optional(mem[4])); + REQUIRE(cursor.next() == tl::make_optional(mem[5])); 
REQUIRE(cursor.next() == tl::nullopt); } @@ -128,7 +142,7 @@ TEST_CASE("Binary collection index", "[v1][unit]") /* auto results = or_q.topk(); */ /* } */ - binary_freq_collection collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); + pisa::binary_freq_collection collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); auto index = pisa::v1::binary_collection_index(PISA_SOURCE_DIR "/test/test_data/test_collection"); auto term_id = 0; @@ -181,7 +195,7 @@ TEST_CASE("Bigram collection index", "[v1][unit]") return vec; }; - binary_freq_collection collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); + pisa::binary_freq_collection collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); auto index = pisa::v1::binary_collection_bigram_index(PISA_SOURCE_DIR "/test/test_data/test_collection"); @@ -205,3 +219,139 @@ TEST_CASE("Bigram collection index", "[v1][unit]") break; } } + +TEST_CASE("Test read header", "[v1][unit]") +{ + { + std::vector bytes{ + std::byte{0b00000000}, + std::byte{0b00000001}, + std::byte{0b00000000}, + std::byte{0b00000000}, + std::byte{0b00000000}, + std::byte{0b00000000}, + std::byte{0b00000000}, + std::byte{0b00000000}, + }; + auto header = PostingFormatHeader::parse(gsl::span(bytes)); + REQUIRE(header.version.major == 0); + REQUIRE(header.version.minor == 1); + REQUIRE(header.version.patch == 0); + REQUIRE(std::get(header.type) == Primitive::Int); + REQUIRE(header.encoding == 0); + } + { + std::vector bytes{ + std::byte{0b00000001}, + std::byte{0b00000001}, + std::byte{0b00000011}, + std::byte{0b00000001}, + std::byte{0b00000001}, + std::byte{0b00000000}, + std::byte{0b00000000}, + std::byte{0b00000000}, + }; + auto header = PostingFormatHeader::parse(gsl::span(bytes)); + REQUIRE(header.version.major == 1); + REQUIRE(header.version.minor == 1); + REQUIRE(header.version.patch == 3); + REQUIRE(std::get(header.type) == Primitive::Float); + REQUIRE(header.encoding == 1); + } + { + std::vector bytes{ + 
std::byte{0b00000001}, + std::byte{0b00000000}, + std::byte{0b00000011}, + std::byte{0b00000010}, + std::byte{0b00000011}, + std::byte{0b00000000}, + std::byte{0b00000000}, + std::byte{0b00000000}, + }; + auto header = PostingFormatHeader::parse(gsl::span(bytes)); + REQUIRE(header.version.major == 1); + REQUIRE(header.version.minor == 0); + REQUIRE(header.version.patch == 3); + REQUIRE(std::get(header.type).type == Primitive::Int); + REQUIRE(header.encoding == 3); + } +} + +TEST_CASE("Value type", "[v1][unit]") +{ + REQUIRE(std::get(parse_type(std::byte{0b00000000})) == Primitive::Int); + REQUIRE(std::get(parse_type(std::byte{0b00000001})) == Primitive::Float); + REQUIRE(std::get(parse_type(std::byte{0b00000010})).type == Primitive::Int); + REQUIRE(std::get(parse_type(std::byte{0b00000110})).type == Primitive::Float); + REQUIRE(std::get(parse_type(std::byte{0b00101011})).type == Primitive::Int); + REQUIRE(std::get(parse_type(std::byte{0b01000111})).type == Primitive::Float); + REQUIRE(std::get(parse_type(std::byte{0b00101011})).size == 5U); + REQUIRE(std::get(parse_type(std::byte{0b01000111})).size == 8U); +} + +TEST_CASE("Build raw document-frequency index", "[v1][unit]") +{ + using sink_type = boost::iostreams::back_insert_device>; + using vector_stream_type = boost::iostreams::stream; + GIVEN("A test binary collection") + { + pisa::binary_freq_collection collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); + WHEN("Built posting files for documents and frequencies") + { + std::vector docbuf; + std::vector freqbuf; + + PostingBuilder document_builder(Writer(RawWriter{})); + PostingBuilder frequency_builder(Writer(RawWriter{})); + { + vector_stream_type docstream{sink_type{docbuf}}; + vector_stream_type freqstream{sink_type{freqbuf}}; + + document_builder.write_header(docstream); + frequency_builder.write_header(freqstream); + + for (auto sequence : collection) { + document_builder.write_segment( + docstream, sequence.docs.begin(), sequence.docs.end()); + 
frequency_builder.write_segment( + freqstream, sequence.freqs.begin(), sequence.freqs.end()); + } + } + + auto document_offsets = document_builder.offsets(); + auto frequency_offsets = frequency_builder.offsets(); + + THEN("Bytes match with those of the collection") + { + auto document_bytes = + pisa::io::load_data(PISA_SOURCE_DIR "/test/test_data/test_collection.docs"); + auto frequency_bytes = + pisa::io::load_data(PISA_SOURCE_DIR "/test/test_data/test_collection.freqs"); + + // NOTE: the first 8 bytes of the document collection are different than those + // of the built document file. Also, the original frequency collection starts + // at byte 0 (no 8-byte "size vector" at the beginning), and thus is shorter. + CHECK(docbuf.size() == document_offsets.back() + 8); + CHECK(freqbuf.size() == frequency_offsets.back() + 8); + CHECK(docbuf.size() == document_bytes.size()); + CHECK(freqbuf.size() == frequency_bytes.size() + 8); + CHECK(gsl::make_span(docbuf.data(), docbuf.size()).subspan(8) + == gsl::make_span(document_bytes.data(), document_bytes.size()).subspan(8)); + CHECK(gsl::make_span(freqbuf.data(), freqbuf.size()).subspan(8) + == gsl::make_span(frequency_bytes.data(), frequency_bytes.size())); + } + + // Index, RawCursor>( + // RawReader{}, + // RawReader{}, + // document_offsets, + // frequency_offsets, + // gsl::span(reinterpret_cast(docbuf.data()), + // docbuf.size()), + // gsl::span(reinterpret_cast(freqbuf.data()), + // freqbuf.size()), + // true); + } + } +} From 36ec48c15c21c94fbccd91686801e7b79ad2c3b0 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 28 Oct 2019 12:09:29 -0400 Subject: [PATCH 10/56] Index runner --- include/pisa/v1/cursor/collect.hpp | 24 +++ include/pisa/v1/cursor/for_each.hpp | 16 ++ include/pisa/v1/index.hpp | 234 +++++++++++++++------- include/pisa/v1/posting_builder.hpp | 12 +- include/pisa/v1/posting_format_header.hpp | 61 +++++- include/pisa/v1/raw_cursor.hpp | 8 +- include/pisa/v1/types.hpp | 6 +- test/test_v1.cpp | 
150 +++++--------- 8 files changed, 333 insertions(+), 178 deletions(-) create mode 100644 include/pisa/v1/cursor/collect.hpp create mode 100644 include/pisa/v1/cursor/for_each.hpp diff --git a/include/pisa/v1/cursor/collect.hpp b/include/pisa/v1/cursor/collect.hpp new file mode 100644 index 000000000..83bf53738 --- /dev/null +++ b/include/pisa/v1/cursor/collect.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include + +namespace pisa::v1 { + +template +auto collect(Cursor &&cursor, Transform transform) +{ + std::vector> vec; + while (not cursor.empty()) { + vec.push_back(transform(cursor)); + cursor.step(); + } + return vec; +} + +template +auto collect(Cursor &&cursor) +{ + return collect(std::forward(cursor), [](auto &&cursor) { return *cursor; }); +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor/for_each.hpp b/include/pisa/v1/cursor/for_each.hpp new file mode 100644 index 000000000..5fd2df07e --- /dev/null +++ b/include/pisa/v1/cursor/for_each.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include + +namespace pisa::v1 { + +template +void for_each(Cursor &&cursor, UnaryOp op) +{ + while (not cursor.empty()) { + op(cursor); + cursor.step(); + } +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index cebae0bd4..8e045f038 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -6,6 +6,10 @@ #include #include +#include +#include +#include +#include #include #include #include @@ -13,8 +17,10 @@ #include "binary_freq_collection.hpp" #include "v1/bit_cast.hpp" +#include "v1/cursor/for_each.hpp" #include "v1/cursor_intersection.hpp" #include "v1/document_payload_cursor.hpp" +#include "v1/posting_builder.hpp" #include "v1/raw_cursor.hpp" #include "v1/types.hpp" @@ -103,6 +109,27 @@ struct Index { std::any m_source; }; +template +auto make_index(DocumentReader document_reader, + PayloadReader payload_reader, + std::vector document_offsets, + std::vector payload_offsets, + gsl::span documents, + gsl::span 
payloads, + Source source) +{ + using DocumentCursor = + decltype(document_reader.read(std::declval>())); + using PayloadCursor = decltype(payload_reader.read(std::declval>())); + return Index(std::move(document_reader), + std::move(payload_reader), + std::move(document_offsets), + std::move(payload_offsets), + documents, + payloads, + std::move(source)); +} + /// Initializes a memory mapped source with a given file. inline void open_source(mio::mmap_source &source, std::string const &filename) { @@ -116,32 +143,110 @@ inline void open_source(mio::mmap_source &source, std::string const &filename) [[nodiscard]] inline auto binary_collection_index(std::string const &basename) { + using sink_type = boost::iostreams::back_insert_device>; + using vector_stream_type = boost::iostreams::stream; + binary_freq_collection collection(basename.c_str()); - std::vector document_offsets; - std::vector frequency_offsets; - document_offsets.push_back(8); - frequency_offsets.push_back(0); - for (auto const &postings : collection) { - auto offset = (1 + postings.docs.size()) * sizeof(std::uint32_t); - document_offsets.push_back(document_offsets.back() + offset); - frequency_offsets.push_back(frequency_offsets.back() + offset); + std::vector docbuf; + std::vector freqbuf; + + PostingBuilder document_builder(Writer(RawWriter{})); + PostingBuilder frequency_builder(Writer(RawWriter{})); + { + vector_stream_type docstream{sink_type{docbuf}}; + vector_stream_type freqstream{sink_type{freqbuf}}; + + document_builder.write_header(docstream); + frequency_builder.write_header(freqstream); + + for (auto sequence : collection) { + document_builder.write_segment(docstream, sequence.docs.begin(), sequence.docs.end()); + frequency_builder.write_segment( + freqstream, sequence.freqs.begin(), sequence.freqs.end()); + } } - auto source = std::make_shared>(); - open_source(source->first, basename + ".docs"); - open_source(source->second, basename + ".freqs"); - auto documents = gsl::make_span( + + 
auto document_offsets = document_builder.offsets(); + auto frequency_offsets = frequency_builder.offsets(); + auto source = std::make_shared, std::vector>>( + std::move(docbuf), std::move(freqbuf)); + auto documents = gsl::span( reinterpret_cast(source->first.data()), source->first.size()); - auto frequencies = gsl::make_span( + auto frequencies = gsl::span( reinterpret_cast(source->second.data()), source->second.size()); + return Index, RawCursor>(RawReader{}, RawReader{}, - std::move(document_offsets), - std::move(frequency_offsets), - documents, - frequencies, + document_offsets, + frequency_offsets, + documents.subspan(8), + frequencies.subspan(8), std::move(source)); } +template +struct IndexRunner { + template + IndexRunner(std::vector document_offsets, + std::vector payload_offsets, + gsl::span documents, + gsl::span payloads, + Source source, + Readers... readers) + : m_document_offsets(std::move(document_offsets)), + m_payload_offsets(std::move(payload_offsets)), + m_documents(documents), + m_payloads(payloads), + m_source(std::move(source)), + m_readers(readers...) + { + } + + template + void operator()(Fn fn) + { + auto dheader = PostingFormatHeader::parse(m_documents.first(8)); + auto pheader = PostingFormatHeader::parse(m_payloads.first(8)); + auto run = [&](auto &&dreader, auto &&preader) -> bool { + if (std::decay_t::encoding() == dheader.encoding + && std::decay_t::encoding() == pheader.encoding + && is_type::value_type>(dheader.type) + && is_type::value_type>(pheader.type)) { + fn(make_index(std::forward(dreader), + std::forward(preader), + m_document_offsets, + m_payload_offsets, + m_documents.subspan(8), + m_payloads.subspan(8), + false)); + return true; + } + return false; + }; + bool success = std::apply( + [&](Readers... dreaders) { + auto with_document_reader = [&](auto dreader) { + return std::apply( + [&](Readers... 
preaders) { return (run(dreader, preaders) || ...); }, + m_readers); + }; + return (with_document_reader(dreaders) || ...); + }, + m_readers); + if (not success) { + throw std::domain_error("Unknown posting encoding"); + } + } + + private: + std::vector m_document_offsets; + std::vector m_payload_offsets; + gsl::span m_documents; + gsl::span m_payloads; + std::any m_source; + std::tuple m_readers; +}; + template struct BigramIndex : public Index { using PairMapping = std::vector>; @@ -170,71 +275,60 @@ struct BigramIndex : public Index { [[nodiscard]] inline auto binary_collection_bigram_index(std::string const &basename) { using payload_type = std::array; + using sink_type = boost::iostreams::back_insert_device>; + using vector_stream_type = boost::iostreams::stream; auto unigram_index = binary_collection_index(basename); std::vector> pair_mapping; - std::vector documents; - std::vector payloads; - - std::vector document_offsets; - std::vector payload_offsets; + std::vector docbuf; + std::vector freqbuf; + PostingBuilder document_builder(RawWriter{}); + PostingBuilder frequency_builder(RawWriter{}); { - // Hack to be backwards-compatible with binary_freq_collection (for now). - documents.insert(documents.begin(), 8, std::byte{0}); - } + vector_stream_type docstream{sink_type{docbuf}}; + vector_stream_type freqstream{sink_type{freqbuf}}; - document_offsets.push_back(documents.size()); - payload_offsets.push_back(payloads.size()); - for (TermId left = 0; left < unigram_index.num_terms() - 1; left += 1) { - auto right = left + 1; - RawWriter document_writer; - RawWriter payload_writer; - auto inter = - CursorIntersection(std::vector{unigram_index.cursor(left), unigram_index.cursor(right)}, - payload_type{0, 0}, - [](payload_type &payload, auto &cursor, auto list_idx) { - payload[list_idx] = *cursor.payload(); - return payload; - }); - if (inter.empty()) { - // Include only non-empty intersections. 
- continue; - } - pair_mapping.emplace_back(left, right); - while (not inter.empty()) { - document_writer.push(*inter); - payload_writer.push(inter.payload()); - inter.step(); - } - document_writer.append(std::back_inserter(documents)); - payload_writer.append(std::back_inserter(payloads)); - document_offsets.push_back(documents.size()); - payload_offsets.push_back(payloads.size()); - } + document_builder.write_header(docstream); + frequency_builder.write_header(freqstream); - { - // Hack to be backwards-compatible with binary_freq_collection (for now). - auto one_bytes = std::array{}; - auto size_bytes = std::array{}; - auto num_bigrams = static_cast(document_offsets.size()); - std::uint32_t one = 1; - std::memcpy(&size_bytes, &num_bigrams, 4); - std::memcpy(&one_bytes, &one, 4); - std::copy(one_bytes.begin(), one_bytes.end(), documents.begin()); - std::copy(size_bytes.begin(), size_bytes.end(), std::next(documents.begin(), 4)); + std::for_each(boost::counting_iterator(0), + boost::counting_iterator(unigram_index.num_terms() - 1), + [&](auto left) { + auto right = left + 1; + auto intersection = CursorIntersection( + std::vector{unigram_index.cursor(left), unigram_index.cursor(right)}, + payload_type{0, 0}, + [](payload_type &payload, auto &cursor, auto list_idx) { + payload[list_idx] = *cursor.payload(); + return payload; + }); + if (intersection.empty()) { + // Include only non-empty intersections. 
+ return; + } + pair_mapping.emplace_back(left, right); + for_each(intersection, [&](auto &cursor) { + document_builder.accumulate(*cursor); + frequency_builder.accumulate(cursor.payload()); + }); + document_builder.flush_segment(docstream); + frequency_builder.flush_segment(freqstream); + }); } - auto source = std::array, 2>{std::move(documents), std::move(payloads)}; - auto document_span = gsl::make_span(source[0]); - auto payload_span = gsl::make_span(source[1]); + auto source = std::array, 2>{std::move(docbuf), std::move(freqbuf)}; + auto document_span = gsl::span( + reinterpret_cast(source[0].data()), source[0].size()); + auto payload_span = gsl::span( + reinterpret_cast(source[1].data()), source[1].size()); auto index = Index, RawCursor>(RawReader{}, RawReader{}, - std::move(document_offsets), - std::move(payload_offsets), - document_span, - payload_span, + document_builder.offsets(), + frequency_builder.offsets(), + document_span.subspan(8), + payload_span.subspan(8), std::move(source)); return BigramIndex(std::move(index), std::move(pair_mapping)); } diff --git a/include/pisa/v1/posting_builder.hpp b/include/pisa/v1/posting_builder.hpp index 125687ef8..5c7455379 100644 --- a/include/pisa/v1/posting_builder.hpp +++ b/include/pisa/v1/posting_builder.hpp @@ -24,8 +24,9 @@ struct PostingBuilder { void write_header(std::ostream &os) const { std::array header{}; - PostingFormatHeader{ - .version = FormatVersion::current(), .type = value_type(), .encoding = 0} + PostingFormatHeader{.version = FormatVersion::current(), + .type = value_type(), + .encoding = m_writer.encoding()} .write(gsl::make_span(header)); os.write(reinterpret_cast(header.data()), header.size()); } @@ -34,6 +35,13 @@ struct PostingBuilder { auto write_segment(std::ostream &os, ValueIterator first, ValueIterator last) -> std::ostream & { std::for_each(first, last, [&](auto &&value) { m_writer.push(value); }); + return flush_segment(os); + } + + void accumulate(Value value) { m_writer.push(value); 
} + + auto flush_segment(std::ostream &os) -> std::ostream & + { m_offsets.push_back(m_offsets.back() + m_writer.write(os)); m_writer.reset(); return os; diff --git a/include/pisa/v1/posting_format_header.hpp b/include/pisa/v1/posting_format_header.hpp index 8ba6cafb3..2eba95f2d 100644 --- a/include/pisa/v1/posting_format_header.hpp +++ b/include/pisa/v1/posting_format_header.hpp @@ -22,6 +22,7 @@ auto write_little_endian(Int number, gsl::span bytes) Int mask{0xFF}; for (unsigned int byte_num = 0; byte_num < sizeof(Int); byte_num += 1) { auto byte_value = static_cast((number & mask) >> (8U * byte_num)); + bytes[byte_num] = byte_value; mask <<= 8U; } } @@ -62,8 +63,29 @@ struct Tuple { std::uint8_t size; }; +[[nodiscard]] inline auto operator==(Tuple const &lhs, Tuple const &rhs) +{ + return lhs.type == rhs.type && lhs.size == rhs.size; +} + using ValueType = std::variant; +template +struct is_array : std::false_type { +}; + +template +struct is_array> : std::true_type { +}; + +template +struct array_length : public std::integral_constant { +}; + +template +struct array_length> : public std::integral_constant { +}; + template constexpr static auto value_type() -> ValueType { @@ -71,9 +93,44 @@ constexpr static auto value_type() -> ValueType return Primitive::Int; } else if constexpr (std::is_floating_point_v) { return Primitive::Float; + } else if constexpr (is_array::value) { + auto len = array_length::value; + if constexpr (std::is_integral_v) { + return Tuple{Primitive::Int, len}; + } else if constexpr (std::is_floating_point_v) { + return Tuple{Primitive::Float, len}; + } else { + throw std::domain_error("Unsupported type"); + } + } else { + // TODO(michal): array + throw std::domain_error("Unsupported type"); + } +} + +template +constexpr static auto is_type(ValueType type) +{ + if constexpr (std::is_integral_v) { + return std::holds_alternative(type) + && std::get(type) == Primitive::Int; + } else if constexpr (std::is_floating_point_v) { + return 
std::holds_alternative(type) + && std::get(type) == Primitive::Float; + } else if constexpr (is_array::value) { + auto len = array_length::value; + if constexpr (std::is_integral_v) { + return std::holds_alternative(type) + && std::get(type) == Tuple{Primitive::Int, len}; + } else if constexpr (std::is_floating_point_v) { + return std::holds_alternative(type) + && std::get(type) == Tuple{Primitive::Float, len}; + } else { + throw std::domain_error("Unsupported type"); + } } else { - // TODO(michal): array and tuple - throw std::runtime_error(""); + // TODO(michal): array + throw std::domain_error("Unsupported type"); } } diff --git a/include/pisa/v1/raw_cursor.hpp b/include/pisa/v1/raw_cursor.hpp index e5a66a6dc..af42fabb4 100644 --- a/include/pisa/v1/raw_cursor.hpp +++ b/include/pisa/v1/raw_cursor.hpp @@ -19,6 +19,7 @@ namespace pisa::v1 { template struct RawCursor { static_assert(std::is_trivially_copyable_v); + using value_type = T; /// Creates a cursor from the encoded bytes. explicit constexpr RawCursor(gsl::span bytes); @@ -132,18 +133,22 @@ constexpr void RawCursor::step_to_geq(T value) template struct RawReader { static_assert(std::is_trivially_copyable::value); + using value_type = T; [[nodiscard]] auto read(gsl::span bytes) const -> RawCursor { return RawCursor(bytes); } - constexpr static auto encoding() -> std::uint64_t { return EncodingId::Raw; } + constexpr static auto encoding() -> std::uint32_t { return EncodingId::Raw; } }; template struct RawWriter { static_assert(std::is_trivially_copyable::value); + using value_type = T; + + constexpr static auto encoding() -> std::uint32_t { return EncodingId::Raw; } void push(T const &posting) { m_postings.push_back(posting); } void push(T &&posting) { m_postings.push_back(posting); } @@ -169,7 +174,6 @@ struct RawWriter { std::copy(memory.begin(), memory.end(), out); return out; } - constexpr static auto encoding() -> std::uint64_t { return EncodingId::Raw; } void reset() { m_postings.clear(); } private: diff 
--git a/include/pisa/v1/types.hpp b/include/pisa/v1/types.hpp index 1bddda2cc..2b3c52e03 100644 --- a/include/pisa/v1/types.hpp +++ b/include/pisa/v1/types.hpp @@ -57,7 +57,7 @@ struct Writer { void push(T const &posting) { m_internal_writer->push(posting); } void push(T &&posting) { m_internal_writer->push(posting); } auto write(std::ostream &os) const -> std::size_t { return m_internal_writer->write(os); } - auto encoding() -> std::int32_t { return m_internal_writer->encoding(); } + [[nodiscard]] auto encoding() const -> std::uint32_t { return m_internal_writer->encoding(); } void reset() { return m_internal_writer->reset(); } struct WriterInterface { @@ -70,7 +70,7 @@ struct Writer { virtual void push(T const &posting) = 0; virtual void push(T &&posting) = 0; virtual auto write(std::ostream &os) const -> std::size_t = 0; - virtual auto encoding() -> std::uint32_t = 0; + [[nodiscard]] virtual auto encoding() const -> std::uint32_t = 0; virtual void reset() = 0; }; @@ -86,7 +86,7 @@ struct Writer { void push(T const &posting) override { m_writer.push(posting); } void push(T &&posting) override { m_writer.push(posting); } auto write(std::ostream &os) const -> std::size_t override { return m_writer.write(os); } - auto encoding() -> std::uint32_t override { return W::encoding(); } + [[nodiscard]] auto encoding() const -> std::uint32_t override { return W::encoding(); } void reset() override { return m_writer.reset(); } private: diff --git a/test/test_v1.cpp b/test/test_v1.cpp index ab8b369d0..219b32bfe 100644 --- a/test/test_v1.cpp +++ b/test/test_v1.cpp @@ -7,90 +7,28 @@ #include #include -//#include "cursor/scored_cursor.hpp" #include "io.hpp" #include "pisa_config.hpp" -//#include "query/queries.hpp" -//#include "test_common.hpp" +#include "v1/cursor/collect.hpp" #include "v1/index.hpp" #include "v1/posting_builder.hpp" #include "v1/posting_format_header.hpp" #include "v1/types.hpp" -//#include "wand_utils.hpp" using pisa::v1::Array; using pisa::v1::DocId; using 
pisa::v1::Frequency; -using pisa::v1::Index; +using pisa::v1::IndexRunner; using pisa::v1::parse_type; using pisa::v1::PostingBuilder; using pisa::v1::PostingFormatHeader; using pisa::v1::Primitive; -using pisa::v1::RawCursor; using pisa::v1::RawReader; using pisa::v1::RawWriter; using pisa::v1::TermId; using pisa::v1::Tuple; using pisa::v1::Writer; -// template -// struct IndexData { -// -// static std::unique_ptr data; -// -// IndexData() -// : collection(PISA_SOURCE_DIR "/test/test_data/test_collection"), -// document_sizes(PISA_SOURCE_DIR "/test/test_data/test_collection.sizes"), -// wdata(document_sizes.begin()->begin(), -// collection.num_docs(), -// collection, -// BlockSize(FixedBlock())) -// -// { -// tbb::task_scheduler_init init; -// typename Index::builder builder(collection.num_docs(), params); -// for (auto const &plist : collection) { -// uint64_t freqs_sum = -// std::accumulate(plist.freqs.begin(), plist.freqs.end(), uint64_t(0)); -// builder.add_posting_list( -// plist.docs.size(), plist.docs.begin(), plist.freqs.begin(), freqs_sum); -// } -// builder.build(index); -// -// term_id_vec q; -// std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); -// auto push_query = [&](std::string const &query_line) { -// queries.push_back(parse_query_ids(query_line)); -// }; -// io::for_each_line(qfile, push_query); -// -// std::string t; -// std::ifstream tin(PISA_SOURCE_DIR "/test/test_data/top5_thresholds"); -// while (std::getline(tin, t)) { -// thresholds.push_back(std::stof(t)); -// } -// } -// -// [[nodiscard]] static auto get() -// { -// if (IndexData::data == nullptr) { -// IndexData::data = std::make_unique>(); -// } -// return IndexData::data.get(); -// } -// -// global_parameters params; -// binary_freq_collection collection; -// binary_collection document_sizes; -// Index index; -// std::vector queries; -// std::vector thresholds; -// wand_data wdata; -//}; -// -// template -// std::unique_ptr> IndexData::data = nullptr; - template 
std::ostream &operator<<(std::ostream &os, tl::optional const &val) { @@ -115,33 +53,8 @@ TEST_CASE("RawReader", "[v1][unit]") REQUIRE(cursor.next() == tl::nullopt); } -template -auto collect(Cursor &&cursor, Transform transform) -{ - std::vector> vec; - while (not cursor.empty()) { - vec.push_back(transform(cursor)); - cursor.step(); - } - return vec; -} - -template -auto collect(Cursor &&cursor) -{ - return collect(std::forward(cursor), [](auto &&cursor) { return *cursor; }); -} - TEST_CASE("Binary collection index", "[v1][unit]") { - /* auto data = IndexData::get(); */ - /* ranked_or_query or_q(10); */ - - /* for (auto const &q : data->queries) { */ - /* or_q(make_scored_cursors(data->index, data->wdata, q), data->index.num_docs()); */ - /* auto results = or_q.topk(); */ - /* } */ - pisa::binary_freq_collection collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); auto index = pisa::v1::binary_collection_index(PISA_SOURCE_DIR "/test/test_data/test_collection"); @@ -342,16 +255,55 @@ TEST_CASE("Build raw document-frequency index", "[v1][unit]") == gsl::make_span(frequency_bytes.data(), frequency_bytes.size())); } - // Index, RawCursor>( - // RawReader{}, - // RawReader{}, - // document_offsets, - // frequency_offsets, - // gsl::span(reinterpret_cast(docbuf.data()), - // docbuf.size()), - // gsl::span(reinterpret_cast(freqbuf.data()), - // freqbuf.size()), - // true); + THEN("Index runner is correctly constructed") + { + auto source = std::array, 2>{docbuf, freqbuf}; + auto document_span = gsl::span( + reinterpret_cast(source[0].data()), source[0].size()); + auto payload_span = gsl::span( + reinterpret_cast(source[1].data()), source[1].size()); + IndexRunner runner(document_offsets, + frequency_offsets, + document_span, + payload_span, + std::move(source), + RawReader{}, + RawReader{}); // Repeat to test that it only + // executes once + int counter = 0; + runner([&](auto index) { + counter += 1; + TermId term_id = 0; + for (auto sequence : collection) { 
+ CAPTURE(term_id); + REQUIRE( + std::vector(sequence.docs.begin(), sequence.docs.end()) + == collect(index.cursor(term_id))); + REQUIRE( + std::vector(sequence.freqs.begin(), sequence.freqs.end()) + == collect(index.cursor(term_id), + [](auto &&cursor) { return *cursor.payload(); })); + term_id += 1; + } + }); + REQUIRE(counter == 1); + } + + THEN("Index runner fails when wrong type") + { + auto source = std::array, 2>{docbuf, freqbuf}; + auto document_span = gsl::span( + reinterpret_cast(source[0].data()), source[0].size()); + auto payload_span = gsl::span( + reinterpret_cast(source[1].data()), source[1].size()); + IndexRunner runner(document_offsets, + frequency_offsets, + document_span, + payload_span, + std::move(source), + RawReader{}); // Correct encoding but not type! + REQUIRE_THROWS_AS(runner([&](auto index) {}), std::domain_error); + } } } } From 0b03bcc88bcaad5b45c099b025ca6456d876f016 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 28 Oct 2019 14:15:57 -0400 Subject: [PATCH 11/56] Update cursor API --- include/pisa/v1/README.md | 10 +- include/pisa/v1/cursor/collect.hpp | 2 +- include/pisa/v1/cursor/for_each.hpp | 2 +- include/pisa/v1/cursor_intersection.hpp | 17 +-- include/pisa/v1/document_payload_cursor.hpp | 130 +++++-------------- include/pisa/v1/index.hpp | 2 +- include/pisa/v1/raw_cursor.hpp | 136 ++++++-------------- test/test_v1.cpp | 19 +-- 8 files changed, 98 insertions(+), 220 deletions(-) diff --git a/include/pisa/v1/README.md b/include/pisa/v1/README.md index 603120d30..aa30112d7 100644 --- a/include/pisa/v1/README.md +++ b/include/pisa/v1/README.md @@ -19,8 +19,12 @@ reference implementation of the discussed structures and some algorithms working The goal of this is to show how things work on certain examples, and find out what works and what doesn't and still needs to be thought through. +> Look in `test/test_v1.cpp` for code examples. + # Posting Files +> Example: `v1/raw_cursor.hpp`. 
+ Each _posting file_ contains a list of blocks of data, each related to a single term, preceded by a header encoding information about the type of payload. @@ -41,6 +45,8 @@ Encoding is not fixed. ## Header +> Example: `v1/posting_format_header.hpp`. + We should store the type of the postings in the file, as well as encoding used. **This might be tricky because we'd like it to be an open set of values/encodings.** @@ -70,6 +76,4 @@ We should then also verify that this encoding implement a `Encoding` "conc This is not the same as our "codecs". This would be more like posting list reader. -``` -Encoding := ?? -``` +> Example: `IndexRunner` in `v1/index.hpp`. diff --git a/include/pisa/v1/cursor/collect.hpp b/include/pisa/v1/cursor/collect.hpp index 83bf53738..099f6c7cb 100644 --- a/include/pisa/v1/cursor/collect.hpp +++ b/include/pisa/v1/cursor/collect.hpp @@ -10,7 +10,7 @@ auto collect(Cursor &&cursor, Transform transform) std::vector> vec; while (not cursor.empty()) { vec.push_back(transform(cursor)); - cursor.step(); + cursor.advance(); } return vec; } diff --git a/include/pisa/v1/cursor/for_each.hpp b/include/pisa/v1/cursor/for_each.hpp index 5fd2df07e..b05dec9f1 100644 --- a/include/pisa/v1/cursor/for_each.hpp +++ b/include/pisa/v1/cursor/for_each.hpp @@ -9,7 +9,7 @@ void for_each(Cursor &&cursor, UnaryOp op) { while (not cursor.empty()) { op(cursor); - cursor.step(); + cursor.advance(); } } diff --git a/include/pisa/v1/cursor_intersection.hpp b/include/pisa/v1/cursor_intersection.hpp index e41b27485..ebbbe310c 100644 --- a/include/pisa/v1/cursor_intersection.hpp +++ b/include/pisa/v1/cursor_intersection.hpp @@ -46,7 +46,7 @@ struct CursorIntersection { }) ->sentinel(); m_candidate = *m_cursors[0].get(); - next(); + advance(); } [[nodiscard]] constexpr auto operator*() const -> Value { return m_current_value; } @@ -58,12 +58,12 @@ struct CursorIntersection { return tl::nullopt; } - constexpr void step() + constexpr void advance() { while (PISA_LIKELY(m_candidate 
< sentinel())) { for (; m_next_cursor < m_cursors.size(); ++m_next_cursor) { Cursor &cursor = m_cursors[m_next_cursor]; - cursor.step_to_geq(m_candidate); + cursor.advance_to_geq(m_candidate); if (*cursor != m_candidate) { m_candidate = *cursor; m_next_cursor = 0; @@ -76,7 +76,7 @@ struct CursorIntersection { m_current_payload = m_accumulate( m_current_payload, m_cursors[idx].get(), m_cursor_mapping[idx]); } - m_cursors[0].get().step(); + m_cursors[0].get().advance(); m_current_value = std::exchange(m_candidate, *m_cursors[0].get()); m_next_cursor = 1; return; @@ -91,13 +91,8 @@ struct CursorIntersection { return m_current_payload; } - constexpr void step_to_position(std::size_t pos); // TODO(michal) - constexpr void step_to_geq(Value value); // TODO(michal) - constexpr auto next() -> tl::optional - { - step(); - return value(); - } + constexpr void advance_to_position(std::size_t pos); // TODO(michal) + constexpr void advance_to_geq(Value value); // TODO(michal) [[nodiscard]] constexpr auto empty() const noexcept -> bool { diff --git a/include/pisa/v1/document_payload_cursor.hpp b/include/pisa/v1/document_payload_cursor.hpp index 2fecb7f4f..fa5988c84 100644 --- a/include/pisa/v1/document_payload_cursor.hpp +++ b/include/pisa/v1/document_payload_cursor.hpp @@ -3,7 +3,6 @@ #include #include -#include template struct DocumentPayloadCursor { @@ -11,104 +10,41 @@ struct DocumentPayloadCursor { using Payload = decltype(*std::declval()); explicit constexpr DocumentPayloadCursor(DocumentCursor key_cursor, - PayloadCursor payload_cursor); - [[nodiscard]] constexpr auto operator*() const -> Document; - [[nodiscard]] constexpr auto value() const noexcept -> tl::optional; - [[nodiscard]] constexpr auto payload() const noexcept -> tl::optional; - constexpr void step(); - constexpr void step_to_position(std::size_t pos); - constexpr void step_to_geq(Document value); - constexpr auto next() -> tl::optional; - [[nodiscard]] constexpr auto empty() const noexcept -> bool; - 
[[nodiscard]] constexpr auto position() const noexcept -> std::size_t; - [[nodiscard]] constexpr auto size() const -> std::size_t; - [[nodiscard]] constexpr auto sentinel() const -> Document; + PayloadCursor payload_cursor) + : m_key_cursor(std::move(key_cursor)), m_payload_cursor(std::move(payload_cursor)) + { + } + + [[nodiscard]] constexpr auto operator*() const -> Document { return value(); } + [[nodiscard]] constexpr auto value() const noexcept -> Document { return m_key_cursor.value(); } + [[nodiscard]] constexpr auto payload() const noexcept -> Payload + { + return m_payload_cursor.value(); + } + constexpr void advance() + { + m_key_cursor.advance(); + m_payload_cursor.advance(); + } + constexpr void advance_to_position(std::size_t pos) + { + m_key_cursor.advance_to_position(pos); + m_payload_cursor.advance_to_position(pos); + } + constexpr void advance_to_geq(Document value) + { + m_key_cursor.advance_to_geq(value); + m_payload_cursor.advance_to_position(m_key_cursor.position()); + } + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_key_cursor.empty(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return m_key_cursor.position(); + } + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_key_cursor.size(); } + [[nodiscard]] constexpr auto sentinel() const -> Document { return m_key_cursor.sentinel(); } private: DocumentCursor m_key_cursor; PayloadCursor m_payload_cursor; }; - -template -constexpr DocumentPayloadCursor::DocumentPayloadCursor( - DocumentCursor key_cursor, PayloadCursor payload_cursor) - : m_key_cursor(std::move(key_cursor)), m_payload_cursor(std::move(payload_cursor)) -{ -} - -template -[[nodiscard]] constexpr auto DocumentPayloadCursor::operator*() const - -> Document -{ - return *m_key_cursor; -} -template -[[nodiscard]] constexpr auto DocumentPayloadCursor::sentinel() const - -> Document -{ - return m_key_cursor.sentinel(); -} - -template -[[nodiscard]] constexpr auto 
DocumentPayloadCursor::value() const - noexcept -> tl::optional -{ - return m_key_cursor.value(); -} - -template -[[nodiscard]] constexpr auto DocumentPayloadCursor::payload() const - noexcept -> tl::optional -{ - return m_payload_cursor.value(); -} - -template -constexpr void DocumentPayloadCursor::step() -{ - m_key_cursor.step(); - m_payload_cursor.step(); -} - -template -constexpr void DocumentPayloadCursor::step_to_position( - std::size_t pos) -{ - m_key_cursor.step_to_position(pos); - m_payload_cursor.step_to_position(pos); -} - -template -constexpr void DocumentPayloadCursor::step_to_geq(Document value) -{ - m_key_cursor.step_to_geq(value); - m_payload_cursor.step_to_position(m_key_cursor.position()); -} - -template -constexpr auto DocumentPayloadCursor::next() - -> tl::optional -{ - return m_key_cursor.next(); -} - -template -[[nodiscard]] constexpr auto DocumentPayloadCursor::empty() const - noexcept -> bool -{ - return m_key_cursor.empty(); -} - -template -[[nodiscard]] constexpr auto DocumentPayloadCursor::position() const - noexcept -> std::size_t -{ - return m_key_cursor.position(); -} - -template -[[nodiscard]] constexpr auto DocumentPayloadCursor::size() const - -> std::size_t -{ - return m_key_cursor.size(); -} diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index 8e045f038..bfc185e8b 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -301,7 +301,7 @@ struct BigramIndex : public Index { std::vector{unigram_index.cursor(left), unigram_index.cursor(right)}, payload_type{0, 0}, [](payload_type &payload, auto &cursor, auto list_idx) { - payload[list_idx] = *cursor.payload(); + payload[list_idx] = cursor.payload(); return payload; }); if (intersection.empty()) { diff --git a/include/pisa/v1/raw_cursor.hpp b/include/pisa/v1/raw_cursor.hpp index af42fabb4..2953327a9 100644 --- a/include/pisa/v1/raw_cursor.hpp +++ b/include/pisa/v1/raw_cursor.hpp @@ -15,6 +15,16 @@ namespace pisa::v1 { +template +[[nodiscard]] auto 
next(Cursor &&cursor) -> tl::optional::value_type> +{ + cursor.advance(); + if (cursor.empty()) { + return tl::nullopt; + } + return tl::make_optional(cursor.value()); +} + /// Uncompressed example of implementation of a single value cursor. template struct RawCursor { @@ -22,114 +32,57 @@ struct RawCursor { using value_type = T; /// Creates a cursor from the encoded bytes. - explicit constexpr RawCursor(gsl::span bytes); + explicit constexpr RawCursor(gsl::span bytes) : m_bytes(bytes.subspan(4)) + { + Expects(m_bytes.size() % sizeof(T) == 0); + } /// Dereferences the current value. - /// It is an undefined behavior to call this when `empty() == true`. - [[nodiscard]] constexpr auto operator*() const -> T; + [[nodiscard]] constexpr auto operator*() const -> T + { + if (PISA_UNLIKELY(empty())) { + return sentinel(); + } + return bit_cast(gsl::as_bytes(m_bytes.subspan(m_current, sizeof(T)))); + } - /// Safely returns the current value, or returns `nullopt` if `empty() == true`. - [[nodiscard]] constexpr auto value() const noexcept -> tl::optional; + /// Alias for `operator*()`. + [[nodiscard]] constexpr auto value() const noexcept -> T { return *(*this); } - /// Moves the cursor to the next position. - constexpr void step(); + /// Advances the cursor to the next position. + constexpr void advance() { m_current += sizeof(T); } /// Moves the cursor to the position `pos`. - constexpr void step_to_position(std::size_t pos); + constexpr void advance_to_position(std::size_t pos) { m_current = pos; } /// Moves the cursor to the next value equal or greater than `value`. - constexpr void step_to_geq(T value); - - /// This is semantically equivalent to first calling `step()` and then `value()`. - constexpr auto next() -> tl::optional; + constexpr void advance_to_geq(T value) + { + while (not empty() && *(*this) < value) { + advance(); + } + } /// Returns `true` if there is no elements left. 
- [[nodiscard]] constexpr auto empty() const noexcept -> bool; + [[nodiscard]] constexpr auto empty() const noexcept -> bool + { + return m_current == m_bytes.size(); + } /// Returns the current position. - [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t { return m_current; } /// Returns the number of elements in the list. - [[nodiscard]] constexpr auto size() const -> std::size_t; + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_bytes.size() / sizeof(T); } /// The sentinel value, such that `value() != nullopt` is equivalent to `*(*this) < sentinel()`. - [[nodiscard]] constexpr auto sentinel() const -> T; + [[nodiscard]] constexpr auto sentinel() const -> T { return std::numeric_limits::max(); } private: std::size_t m_current = 0; gsl::span m_bytes; }; -template -constexpr RawCursor::RawCursor(gsl::span bytes) : m_bytes(bytes.subspan(4)) -{ - Expects(m_bytes.size() % sizeof(T) == 0); -} - -template -[[nodiscard]] constexpr auto RawCursor::operator*() const -> T -{ - if (PISA_UNLIKELY(empty())) { - return sentinel(); - } - return bit_cast(gsl::as_bytes(m_bytes.subspan(m_current, sizeof(T)))); -} - -template -[[nodiscard]] constexpr auto RawCursor::sentinel() const -> T -{ - return std::numeric_limits::max(); -} - -template -[[nodiscard]] constexpr auto RawCursor::value() const noexcept -> tl::optional -{ - return empty() ? 
tl::nullopt : tl::make_optional(*(*this)); -} -template -constexpr auto RawCursor::next() -> tl::optional -{ - step(); - return value(); -} -template -constexpr void RawCursor::step() -{ - m_current += sizeof(T); -} - -template -[[nodiscard]] constexpr auto RawCursor::empty() const noexcept -> bool -{ - return m_current == m_bytes.size(); -} - -template -[[nodiscard]] constexpr auto RawCursor::position() const noexcept -> std::size_t -{ - return m_current; -} - -template -[[nodiscard]] constexpr auto RawCursor::size() const -> std::size_t -{ - return m_bytes.size() / sizeof(T); -} - -template -constexpr void RawCursor::step_to_position(std::size_t pos) -{ - m_current = pos; -} - -template -constexpr void RawCursor::step_to_geq(T value) -{ - while (not empty() && *(*this) < value) { - step(); - } -} - template struct RawReader { static_assert(std::is_trivially_copyable::value); @@ -163,17 +116,6 @@ struct RawWriter { return sizeof(length) + memory.size(); } - template - auto append(OutputByteIterator out) const -> OutputByteIterator - { - assert(!m_postings.empty()); - std::uint32_t length = m_postings.size(); - auto length_bytes = gsl::as_bytes(gsl::make_span(&length, 1)); - auto memory = gsl::as_bytes(gsl::make_span(m_postings.data(), m_postings.size())); - std::copy(length_bytes.begin(), length_bytes.end(), out); - std::copy(memory.begin(), memory.end(), out); - return out; - } void reset() { m_postings.clear(); } private: diff --git a/test/test_v1.cpp b/test/test_v1.cpp index 219b32bfe..27a8e34c2 100644 --- a/test/test_v1.cpp +++ b/test/test_v1.cpp @@ -19,6 +19,7 @@ using pisa::v1::Array; using pisa::v1::DocId; using pisa::v1::Frequency; using pisa::v1::IndexRunner; +using pisa::v1::next; using pisa::v1::parse_type; using pisa::v1::PostingBuilder; using pisa::v1::PostingFormatHeader; @@ -45,12 +46,12 @@ TEST_CASE("RawReader", "[v1][unit]") std::vector const mem{5, 0, 1, 2, 3, 4}; RawReader reader; auto cursor = reader.read(gsl::as_bytes(gsl::make_span(mem))); - 
REQUIRE(cursor.value().value() == tl::make_optional(mem[1]).value()); - REQUIRE(cursor.next() == tl::make_optional(mem[2])); - REQUIRE(cursor.next() == tl::make_optional(mem[3])); - REQUIRE(cursor.next() == tl::make_optional(mem[4])); - REQUIRE(cursor.next() == tl::make_optional(mem[5])); - REQUIRE(cursor.next() == tl::nullopt); + REQUIRE(cursor.value() == mem[1]); + REQUIRE(next(cursor) == tl::make_optional(mem[2])); + REQUIRE(next(cursor) == tl::make_optional(mem[3])); + REQUIRE(next(cursor) == tl::make_optional(mem[4])); + REQUIRE(next(cursor) == tl::make_optional(mem[5])); + REQUIRE(next(cursor) == tl::nullopt); } TEST_CASE("Binary collection index", "[v1][unit]") @@ -73,7 +74,7 @@ TEST_CASE("Binary collection index", "[v1][unit]") REQUIRE(std::vector(sequence.docs.begin(), sequence.docs.end()) == collect(index.cursor(term_id))); REQUIRE(std::vector(sequence.freqs.begin(), sequence.freqs.end()) - == collect(index.cursor(term_id), [](auto &&cursor) { return *cursor.payload(); })); + == collect(index.cursor(term_id), [](auto &&cursor) { return cursor.payload(); })); term_id += 1; } } @@ -123,7 +124,7 @@ TEST_CASE("Bigram collection index", "[v1][unit]") auto id = index.bigram_id(term_id - 1, term_id); REQUIRE(id.has_value()); auto postings = collect(index.cursor(*id), [](auto &cursor) { - auto freqs = *cursor.payload(); + auto freqs = cursor.payload(); return std::make_tuple(*cursor, freqs[0], freqs[1]); }); REQUIRE(postings == intersection); @@ -282,7 +283,7 @@ TEST_CASE("Build raw document-frequency index", "[v1][unit]") REQUIRE( std::vector(sequence.freqs.begin(), sequence.freqs.end()) == collect(index.cursor(term_id), - [](auto &&cursor) { return *cursor.payload(); })); + [](auto &&cursor) { return cursor.payload(); })); term_id += 1; } }); From 8f9faa8e4fe2acd8b99e45f56f1a380c19aa2454 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Tue, 29 Oct 2019 17:43:00 -0400 Subject: [PATCH 12/56] On-the-fly BM25 scoring --- 
include/pisa/v1/cursor/scoring_cursor.hpp | 47 ++++++ include/pisa/v1/cursor_intersection.hpp | 15 +- include/pisa/v1/cursor_union.hpp | 78 ++++++--- include/pisa/v1/document_payload_cursor.hpp | 4 + include/pisa/v1/index.hpp | 82 ++++++++- include/pisa/v1/raw_cursor.hpp | 3 +- include/pisa/v1/scorer/bm25.hpp | 49 ++++++ test/test_v1.cpp | 9 + test/test_v1_queries.cpp | 177 ++++++++++++++++++++ 9 files changed, 420 insertions(+), 44 deletions(-) create mode 100644 include/pisa/v1/cursor/scoring_cursor.hpp create mode 100644 include/pisa/v1/scorer/bm25.hpp create mode 100644 test/test_v1_queries.cpp diff --git a/include/pisa/v1/cursor/scoring_cursor.hpp b/include/pisa/v1/cursor/scoring_cursor.hpp new file mode 100644 index 000000000..f9e3d9adf --- /dev/null +++ b/include/pisa/v1/cursor/scoring_cursor.hpp @@ -0,0 +1,47 @@ +#pragma once + +#include + +#include + +#include "v1/types.hpp" + +namespace pisa::v1 { + +template +struct ScoringCursor { + using Document = decltype(*std::declval()); + using Payload = float; + static_assert(std::is_same_v); + + explicit constexpr ScoringCursor(BaseCursor base_cursor, TermScorer scorer) + : m_base_cursor(std::move(base_cursor)), m_scorer(std::move(scorer)) + { + } + + [[nodiscard]] constexpr auto operator*() const -> Document { return value(); } + [[nodiscard]] constexpr auto value() const noexcept -> Document + { + return m_base_cursor.value(); + } + [[nodiscard]] constexpr auto payload() const noexcept -> Payload + { + return m_scorer(m_base_cursor.value(), m_base_cursor.payload()); + } + constexpr void advance() { m_base_cursor.advance(); } + constexpr void advance_to_position(std::size_t pos) { m_base_cursor.advance_to_position(pos); } + constexpr void advance_to_geq(Document value) { m_base_cursor.advance_to_geq(value); } + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_base_cursor.empty(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return 
m_base_cursor.position(); + } + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_base_cursor.size(); } + [[nodiscard]] constexpr auto sentinel() const -> Document { return m_base_cursor.sentinel(); } + + private: + BaseCursor m_base_cursor; + TermScorer m_scorer; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor_intersection.hpp b/include/pisa/v1/cursor_intersection.hpp index ebbbe310c..de10198c3 100644 --- a/include/pisa/v1/cursor_intersection.hpp +++ b/include/pisa/v1/cursor_intersection.hpp @@ -5,12 +5,13 @@ #include #include +#include #include #include #include "util/likely.hpp" -namespace pisa { +namespace pisa::v1 { /// Transforms a list of cursors into one cursor by lazily merging them together /// into an intersection. @@ -50,13 +51,7 @@ struct CursorIntersection { } [[nodiscard]] constexpr auto operator*() const -> Value { return m_current_value; } - [[nodiscard]] constexpr auto value() const noexcept -> tl::optional - { - if (PISA_LIKELY(m_candidate < sentinel())) { - return m_current_value; - } - return tl::nullopt; - } + [[nodiscard]] constexpr auto value() const noexcept -> Value { return m_current_value; } constexpr void advance() { @@ -96,7 +91,7 @@ struct CursorIntersection { [[nodiscard]] constexpr auto empty() const noexcept -> bool { - return m_candidate >= sentinel(); + return m_current_value >= sentinel(); } [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; // TODO(michal) @@ -126,4 +121,4 @@ template std::move(cursors), std::move(init), std::move(accumulate)); } -} // namespace pisa +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor_union.hpp b/include/pisa/v1/cursor_union.hpp index 02cf4d8bd..beb88c936 100644 --- a/include/pisa/v1/cursor_union.hpp +++ b/include/pisa/v1/cursor_union.hpp @@ -8,7 +8,7 @@ #include "util/likely.hpp" -namespace pisa { +namespace pisa::v1 { /// Transforms a list of cursors into one cursor by lazily merging them together. 
template @@ -18,23 +18,27 @@ struct CursorUnion { typename std::iterator_traits::iterator_category; static_assert(std::is_base_of(), "cursors must be stored in a random access container"); - constexpr CursorUnion(CursorContainer cursors, - std::size_t max_docid, - Payload init, - AccumulateFn accumulate) + using Value = std::decay_t())>; + + constexpr CursorUnion(CursorContainer cursors, Payload init, AccumulateFn accumulate) : m_cursors(std::move(cursors)), - m_init(init), + m_init(std::move(init)), m_accumulate(std::move(accumulate)), - m_size(std::nullopt), - m_max_docid(max_docid) + m_size(std::nullopt) { Expects(not m_cursors.empty()); - auto order = [](auto const &lhs, auto const &rhs) { return lhs.docid() < rhs.docid(); }; + auto order = [](auto const &lhs, auto const &rhs) { return lhs.value() < rhs.value(); }; m_next_docid = [&]() { auto pos = std::min_element(m_cursors.begin(), m_cursors.end(), order); - return pos->docid(); + return pos->value(); }(); - next(); + m_sentinel = std::min_element(m_cursors.begin(), + m_cursors.end(), + [](auto const &lhs, auto const &rhs) { + return lhs.sentinel() < rhs.sentinel(); + }) + ->sentinel(); + advance(); } [[nodiscard]] constexpr auto size() const noexcept -> std::size_t @@ -47,45 +51,65 @@ struct CursorUnion { } return *m_size; } - [[nodiscard]] constexpr auto docid() const noexcept -> std::uint32_t { return m_current_docid; } + [[nodiscard]] constexpr auto operator*() const noexcept -> Value { return m_current_value; } + [[nodiscard]] constexpr auto value() const noexcept -> Value { return m_current_value; } [[nodiscard]] constexpr auto payload() const noexcept -> Payload const & { return m_current_payload; } - [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_max_docid; } - constexpr void next() + [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_sentinel; } + + constexpr void advance() { - if (PISA_UNLIKELY(m_next_docid == m_max_docid)) { - 
m_current_docid = m_max_docid; + if (PISA_UNLIKELY(m_next_docid == m_sentinel)) { + m_current_value = m_sentinel; m_current_payload = m_init; } else { m_current_payload = m_init; - m_current_docid = m_next_docid; - m_next_docid = m_max_docid; + m_current_value = m_next_docid; + m_next_docid = m_sentinel; std::size_t cursor_idx = 0; for (auto &cursor : m_cursors) { - if (cursor.docid() == m_current_docid) { + if (cursor.value() == m_current_value) { m_current_payload = m_accumulate(m_current_payload, cursor, cursor_idx); - cursor.next(); + cursor.advance(); } - if (cursor.docid() < m_next_docid) { - m_next_docid = cursor.docid(); + if (cursor.value() < m_next_docid) { + m_next_docid = cursor.value(); } ++cursor_idx; } } } + constexpr void advance_to_position(std::size_t pos); // TODO(michal) + constexpr void advance_to_geq(Value value); // TODO(michal) + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; // TODO(michal) + + [[nodiscard]] constexpr auto empty() const noexcept -> bool + { + return m_current_value >= sentinel(); + } + private: CursorContainer m_cursors; Payload m_init; AccumulateFn m_accumulate; std::optional m_size; - std::uint32_t m_max_docid; - std::uint32_t m_current_docid = 0; - Payload m_current_payload; - std::uint32_t m_next_docid; + Value m_current_value{}; + Value m_sentinel{}; + Payload m_current_payload{}; + std::uint32_t m_next_docid{}; }; -} // namespace pisa +template +[[nodiscard]] constexpr inline auto union_merge(CursorContainer cursors, + Payload init, + AccumulateFn accumulate) +{ + return CursorUnion( + std::move(cursors), std::move(init), std::move(accumulate)); +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/document_payload_cursor.hpp b/include/pisa/v1/document_payload_cursor.hpp index fa5988c84..0a04aa45b 100644 --- a/include/pisa/v1/document_payload_cursor.hpp +++ b/include/pisa/v1/document_payload_cursor.hpp @@ -4,6 +4,8 @@ #include +namespace pisa::v1 { + template struct DocumentPayloadCursor { 
using Document = decltype(*std::declval()); @@ -48,3 +50,5 @@ struct DocumentPayloadCursor { DocumentCursor m_key_cursor; PayloadCursor m_payload_cursor; }; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index bfc185e8b..d9aed492f 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -18,10 +18,12 @@ #include "binary_freq_collection.hpp" #include "v1/bit_cast.hpp" #include "v1/cursor/for_each.hpp" +#include "v1/cursor/scoring_cursor.hpp" #include "v1/cursor_intersection.hpp" #include "v1/document_payload_cursor.hpp" #include "v1/posting_builder.hpp" #include "v1/raw_cursor.hpp" +#include "v1/scorer/bm25.hpp" #include "v1/types.hpp" namespace pisa::v1 { @@ -56,6 +58,8 @@ struct Index { std::vector payload_offsets, gsl::span documents, gsl::span payloads, + std::vector document_lengths, + tl::optional avg_document_length, Source source) : m_document_reader(std::move(document_reader)), m_payload_reader(std::move(payload_reader)), @@ -63,42 +67,82 @@ struct Index { m_payload_offsets(std::move(payload_offsets)), m_documents(documents), m_payloads(payloads), + m_document_lengths(std::move(document_lengths)), + m_avg_document_length(avg_document_length.map_or_else( + [](auto &&self) { return self; }, + [&]() { return calc_avg_length(m_document_lengths); })), m_source(std::move(source)) { } /// Constructs a new document-payload cursor (see document_payload_cursor.hpp). - [[nodiscard]] auto cursor(TermId term) + [[nodiscard]] auto cursor(TermId term) const { return DocumentPayloadCursor(documents(term), payloads(term)); } + /// Constructs a new document-score cursor. + template + [[nodiscard]] auto scoring_cursor(TermId term, Scorer &&scorer) const + { + return ScoringCursor(cursor(term), std::forward(scorer).term_scorer(term)); + } + /// Constructs a new document cursor. 
- [[nodiscard]] auto documents(TermId term) + [[nodiscard]] auto documents(TermId term) const { return m_document_reader.read(fetch_documents(term)); } /// Constructs a new payload cursor. - [[nodiscard]] auto payloads(TermId term) { return m_payload_reader.read(fetch_payloads(term)); } + [[nodiscard]] auto payloads(TermId term) const + { + return m_payload_reader.read(fetch_payloads(term)); + } /// Constructs a new payload cursor. - [[nodiscard]] auto num_terms() -> std::uint32_t { return m_document_offsets.size() - 1; } + [[nodiscard]] auto num_terms() const -> std::uint32_t { return m_document_offsets.size() - 1; } + + [[nodiscard]] auto num_documents() const -> std::uint32_t { return m_document_lengths.size(); } + + [[nodiscard]] auto term_posting_count(TermId term) const -> std::uint32_t + { + // TODO(michal): Should be done more efficiently. + return documents(term).size(); + } + + [[nodiscard]] auto document_length(DocId docid) const -> std::uint32_t + { + return m_document_lengths[docid]; + } + + [[nodiscard]] auto avg_document_length() const -> float { return m_avg_document_length; } + + [[nodiscard]] auto normalized_document_length(DocId docid) const -> float + { + return document_length(docid) / avg_document_length(); + } private: - [[nodiscard]] auto fetch_documents(TermId term) -> gsl::span + [[nodiscard]] auto fetch_documents(TermId term) const -> gsl::span { Expects(term + 1 < m_document_offsets.size()); return m_documents.subspan(m_document_offsets[term], m_document_offsets[term + 1] - m_document_offsets[term]); } - [[nodiscard]] auto fetch_payloads(TermId term) -> gsl::span + [[nodiscard]] auto fetch_payloads(TermId term) const -> gsl::span { Expects(term + 1 < m_payload_offsets.size()); return m_payloads.subspan(m_payload_offsets[term], m_payload_offsets[term + 1] - m_payload_offsets[term]); } + [[nodiscard]] static auto calc_avg_length(std::vector const &lengths) + -> std::uint32_t + { + auto sum = std::accumulate(lengths.begin(), lengths.end(), 
std::uint64_t(0), std::plus{}); + return static_cast(sum) / lengths.size(); + } Reader m_document_reader; Reader m_payload_reader; @@ -106,6 +150,8 @@ struct Index { std::vector m_payload_offsets; gsl::span m_documents; gsl::span m_payloads; + std::vector m_document_lengths; + std::uint32_t m_avg_document_length; std::any m_source; }; @@ -116,6 +162,8 @@ auto make_index(DocumentReader document_reader, std::vector payload_offsets, gsl::span documents, gsl::span payloads, + std::vector document_lengths, + tl::optional avg_document_length, Source source) { using DocumentCursor = @@ -127,6 +175,8 @@ auto make_index(DocumentReader document_reader, std::move(payload_offsets), documents, payloads, + std::move(document_lengths), + avg_document_length, std::move(source)); } @@ -141,6 +191,13 @@ inline void open_source(mio::mmap_source &source, std::string const &filename) } } +inline auto read_sizes(std::string_view basename) +{ + binary_collection sizes(fmt::format("{}.sizes", basename).c_str()); + auto sequence = *sizes.begin(); + return std::vector(sequence.begin(), sequence.end()); +} + [[nodiscard]] inline auto binary_collection_index(std::string const &basename) { using sink_type = boost::iostreams::back_insert_device>; @@ -181,6 +238,8 @@ inline void open_source(mio::mmap_source &source, std::string const &filename) frequency_offsets, documents.subspan(8), frequencies.subspan(8), + read_sizes(basename), + tl::nullopt, std::move(source)); } @@ -191,12 +250,16 @@ struct IndexRunner { std::vector payload_offsets, gsl::span documents, gsl::span payloads, + std::vector document_lengths, + tl::optional avg_document_length, Source source, Readers... readers) : m_document_offsets(std::move(document_offsets)), m_payload_offsets(std::move(payload_offsets)), m_documents(documents), m_payloads(payloads), + m_document_lengths(std::move(document_lengths)), + m_avg_document_length(avg_document_length), m_source(std::move(source)), m_readers(readers...) 
{ @@ -218,6 +281,8 @@ struct IndexRunner { m_payload_offsets, m_documents.subspan(8), m_payloads.subspan(8), + m_document_lengths, + m_avg_document_length, false)); return true; } @@ -243,6 +308,8 @@ struct IndexRunner { std::vector m_payload_offsets; gsl::span m_documents; gsl::span m_payloads; + std::vector m_document_lengths; + tl::optional m_avg_document_length; std::any m_source; std::tuple m_readers; }; @@ -323,12 +390,15 @@ struct BigramIndex : public Index { reinterpret_cast(source[0].data()), source[0].size()); auto payload_span = gsl::span( reinterpret_cast(source[1].data()), source[1].size()); + auto index = Index, RawCursor>(RawReader{}, RawReader{}, document_builder.offsets(), frequency_builder.offsets(), document_span.subspan(8), payload_span.subspan(8), + read_sizes(basename), + tl::nullopt, std::move(source)); return BigramIndex(std::move(index), std::move(pair_mapping)); } diff --git a/include/pisa/v1/raw_cursor.hpp b/include/pisa/v1/raw_cursor.hpp index 2953327a9..66949d98e 100644 --- a/include/pisa/v1/raw_cursor.hpp +++ b/include/pisa/v1/raw_cursor.hpp @@ -35,6 +35,7 @@ struct RawCursor { explicit constexpr RawCursor(gsl::span bytes) : m_bytes(bytes.subspan(4)) { Expects(m_bytes.size() % sizeof(T) == 0); + Expects(not m_bytes.empty()); } /// Dereferences the current value. @@ -58,7 +59,7 @@ struct RawCursor { /// Moves the cursor to the next value equal or greater than `value`. 
constexpr void advance_to_geq(T value) { - while (not empty() && *(*this) < value) { + while (this->value() < value) { advance(); } } diff --git a/include/pisa/v1/scorer/bm25.hpp b/include/pisa/v1/scorer/bm25.hpp new file mode 100644 index 000000000..d1d315a62 --- /dev/null +++ b/include/pisa/v1/scorer/bm25.hpp @@ -0,0 +1,49 @@ +#pragma once + +#include +#include +#include + +#include +#include + +#include "v1/types.hpp" + +namespace pisa::v1 { + +template +struct BM25 { + static constexpr float b = 0.4; + static constexpr float k1 = 0.9; + + explicit BM25(Index const &index) : m_index(index) {} + + static float doc_term_weight(uint64_t freq, float norm_len) + { + auto f = static_cast(freq); + return f / (f + k1 * (1.0F - b + b * norm_len)); + } + + static float query_term_weight(uint64_t df, uint64_t num_docs) + { + auto fdf = static_cast(df); + float idf = std::log((float(num_docs) - fdf + 0.5F) / (fdf + 0.5F)); + static const float epsilon_score = 1.0E-6; + return std::max(epsilon_score, idf) * (1.0F + k1); + } + + auto term_scorer(TermId term_id) const + { + auto term_weight = + query_term_weight(m_index.term_posting_count(term_id), m_index.num_documents()); + return [this, term_weight](uint32_t doc, uint32_t freq) { + return term_weight + * doc_term_weight(freq, this->m_index.normalized_document_length(doc)); + }; + } + + private: + Index const &m_index; +}; + +} // namespace pisa::v1 diff --git a/test/test_v1.cpp b/test/test_v1.cpp index 27a8e34c2..5cd230e11 100644 --- a/test/test_v1.cpp +++ b/test/test_v1.cpp @@ -13,6 +13,7 @@ #include "v1/index.hpp" #include "v1/posting_builder.hpp" #include "v1/posting_format_header.hpp" +#include "v1/scorer/bm25.hpp" #include "v1/types.hpp" using pisa::v1::Array; @@ -26,6 +27,7 @@ using pisa::v1::PostingFormatHeader; using pisa::v1::Primitive; using pisa::v1::RawReader; using pisa::v1::RawWriter; +using pisa::v1::read_sizes; using pisa::v1::TermId; using pisa::v1::Tuple; using pisa::v1::Writer; @@ -236,6 +238,8 @@ 
TEST_CASE("Build raw document-frequency index", "[v1][unit]") auto document_offsets = document_builder.offsets(); auto frequency_offsets = frequency_builder.offsets(); + auto document_sizes = read_sizes(PISA_SOURCE_DIR "/test/test_data/test_collection"); + THEN("Bytes match with those of the collection") { auto document_bytes = @@ -263,10 +267,13 @@ TEST_CASE("Build raw document-frequency index", "[v1][unit]") reinterpret_cast(source[0].data()), source[0].size()); auto payload_span = gsl::span( reinterpret_cast(source[1].data()), source[1].size()); + IndexRunner runner(document_offsets, frequency_offsets, document_span, payload_span, + document_sizes, + tl::nullopt, std::move(source), RawReader{}, RawReader{}); // Repeat to test that it only @@ -301,6 +308,8 @@ TEST_CASE("Build raw document-frequency index", "[v1][unit]") frequency_offsets, document_span, payload_span, + document_sizes, + tl::nullopt, std::move(source), RawReader{}); // Correct encoding but not type! REQUIRE_THROWS_AS(runner([&](auto index) {}), std::domain_error); diff --git a/test/test_v1_queries.cpp b/test/test_v1_queries.cpp new file mode 100644 index 000000000..8fd9019a8 --- /dev/null +++ b/test/test_v1_queries.cpp @@ -0,0 +1,177 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include + +#include +#include +#include + +#include "accumulator/lazy_accumulator.hpp" +#include "cursor/block_max_scored_cursor.hpp" +#include "cursor/max_scored_cursor.hpp" +#include "cursor/scored_cursor.hpp" +#include "index_types.hpp" +#include "io.hpp" +#include "pisa_config.hpp" +#include "query/queries.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/cursor_intersection.hpp" +#include "v1/cursor_union.hpp" +#include "v1/index.hpp" +#include "v1/posting_builder.hpp" +#include "v1/posting_format_header.hpp" +#include "v1/scorer/bm25.hpp" +#include "v1/types.hpp" + +namespace v1 = pisa::v1; +using namespace pisa; + +template +struct IndexData { + + static std::unique_ptr data; + + IndexData() + : 
collection(PISA_SOURCE_DIR "/test/test_data/test_collection"), + document_sizes(PISA_SOURCE_DIR "/test/test_data/test_collection.sizes"), + v1_index( + pisa::v1::binary_collection_index(PISA_SOURCE_DIR "/test/test_data/test_collection")), + wdata(document_sizes.begin()->begin(), + collection.num_docs(), + collection, + BlockSize(FixedBlock())) + + { + tbb::task_scheduler_init init; + typename v0_Index::builder builder(collection.num_docs(), params); + for (auto const &plist : collection) { + uint64_t freqs_sum = + std::accumulate(plist.freqs.begin(), plist.freqs.end(), uint64_t(0)); + builder.add_posting_list( + plist.docs.size(), plist.docs.begin(), plist.freqs.begin(), freqs_sum); + } + builder.build(v0_index); + + term_id_vec q; + std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); + auto push_query = [&](std::string const &query_line) { + queries.push_back(parse_query_ids(query_line)); + }; + io::for_each_line(qfile, push_query); + + std::string t; + std::ifstream tin(PISA_SOURCE_DIR "/test/test_data/top5_thresholds"); + while (std::getline(tin, t)) { + thresholds.push_back(std::stof(t)); + } + } + + [[nodiscard]] static auto get() + { + if (IndexData::data == nullptr) { + IndexData::data = std::make_unique>(); + } + return IndexData::data.get(); + } + + global_parameters params; + binary_freq_collection collection; + binary_collection document_sizes; + v0_Index v0_index; + v1_Index v1_index; + std::vector queries; + std::vector thresholds; + wand_data wdata; +}; + +template +std::unique_ptr> IndexData::data = nullptr; + +template +auto daat_and(Query const &query, Index const &index, topk_queue topk) +{ + v1::BM25 scorer(index); + std::vector cursors; + std::transform(query.terms.begin(), + query.terms.end(), + std::back_inserter(cursors), + [&](auto term) { return index.scoring_cursor(term, scorer); }); + auto intersection = + v1::intersect(std::move(cursors), 0.0F, [](auto &score, auto &cursor, auto /* term_idx */) { + score += cursor.payload(); 
+ return score; + }); + v1::for_each(intersection, [&](auto &cursor) { topk.insert(cursor.payload(), *cursor); }); + return topk; +} + +template +auto daat_or(Query const &query, Index const &index, topk_queue topk) +{ + v1::BM25 scorer(index); + std::vector cursors; + std::transform(query.terms.begin(), + query.terms.end(), + std::back_inserter(cursors), + [&](auto term) { return index.scoring_cursor(term, scorer); }); + auto cunion = v1::union_merge( + std::move(cursors), 0.0F, [](auto &score, auto &cursor, auto /* term_idx */) { + score += cursor.payload(); + return score; + }); + v1::for_each(cunion, [&](auto &cursor) { topk.insert(cursor.payload(), *cursor); }); + return topk; +} + +TEST_CASE("DAAT AND", "[v1][integration]") +{ + auto data = IndexData, v1::RawCursor>>::get(); + ranked_and_query and_q(10); + int idx = 0; + for (auto const &q : data->queries) { + CAPTURE(q.terms); + CAPTURE(idx++); + and_q(make_scored_cursors(data->v0_index, data->wdata, q), data->v0_index.num_docs()); + auto que = daat_and(q, data->v1_index, topk_queue(10)); + que.finalize(); + + auto expected = and_q.topk(); + std::sort(expected.begin(), expected.end(), std::greater{}); + auto actual = que.topk(); + std::sort(actual.begin(), actual.end(), std::greater{}); + + REQUIRE(expected.size() == actual.size()); + for (size_t i = 0; i < actual.size(); ++i) { + REQUIRE(actual[i].second == expected[i].second); + REQUIRE(actual[i].first == Approx(expected[i].first).epsilon(0.1)); + } + } +} + +TEST_CASE("DAAT OR", "[v1][integration]") +{ + auto data = IndexData, v1::RawCursor>>::get(); + ranked_or_query or_q(10); + int idx = 0; + for (auto const &q : data->queries) { + CAPTURE(q.terms); + CAPTURE(idx++); + or_q(make_scored_cursors(data->v0_index, data->wdata, q), data->v0_index.num_docs()); + auto que = daat_or(q, data->v1_index, topk_queue(10)); + que.finalize(); + + auto expected = or_q.topk(); + std::sort(expected.begin(), expected.end(), std::greater{}); + auto actual = que.topk(); + 
std::sort(actual.begin(), actual.end(), std::greater{}); + + REQUIRE(expected.size() == actual.size()); + for (size_t i = 0; i < actual.size(); ++i) { + REQUIRE(actual[i].second == expected[i].second); + REQUIRE(actual[i].first == Approx(expected[i].first).epsilon(0.1)); + } + } +} From 3c5ad7d44709624cd26222e4d43ae43a48cbee98 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Wed, 30 Oct 2019 11:29:05 -0400 Subject: [PATCH 13/56] Precomputed scores --- include/pisa/io.hpp | 14 ++ include/pisa/v1/index.hpp | 299 ++++++++++++++++++---------- include/pisa/v1/posting_builder.hpp | 13 +- include/pisa/v1/raw_cursor.hpp | 7 +- include/pisa/v1/scorer/bm25.hpp | 22 +- include/pisa/v1/scorer/runner.hpp | 43 ++++ include/pisa/v1/source.hpp | 18 ++ include/pisa/v1/types.hpp | 12 +- test/test_v1.cpp | 30 ++- test/test_v1_queries.cpp | 103 ++++++---- 10 files changed, 396 insertions(+), 165 deletions(-) create mode 100644 include/pisa/v1/scorer/runner.hpp create mode 100644 include/pisa/v1/source.hpp diff --git a/include/pisa/io.hpp b/include/pisa/io.hpp index e7c8bdc45..29446785d 100644 --- a/include/pisa/io.hpp +++ b/include/pisa/io.hpp @@ -61,4 +61,18 @@ void for_each_line(std::istream &is, Function fn) return data; } +[[nodiscard]] inline auto load_bytes(std::string const &data_file) +{ + std::vector data; + std::basic_ifstream in(data_file.c_str(), std::ios::binary); + in.seekg(0, std::ios::end); + std::streamsize size = in.tellg(); + in.seekg(0, std::ios::beg); + data.resize(size); + if (not in.read(data.data(), size)) { + throw std::runtime_error("Failed reading " + data_file); + } + return data; +} + } // namespace pisa::io diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index d9aed492f..f1ef690a1 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -24,6 +25,7 @@ #include "v1/posting_builder.hpp" #include "v1/raw_cursor.hpp" #include "v1/scorer/bm25.hpp" 
+#include "v1/source.hpp" #include "v1/types.hpp" namespace pisa::v1 { @@ -54,20 +56,20 @@ struct Index { template Index(DocumentReader document_reader, PayloadReader payload_reader, - std::vector document_offsets, - std::vector payload_offsets, + gsl::span document_offsets, + gsl::span payload_offsets, gsl::span documents, gsl::span payloads, - std::vector document_lengths, + gsl::span document_lengths, tl::optional avg_document_length, Source source) : m_document_reader(std::move(document_reader)), m_payload_reader(std::move(payload_reader)), - m_document_offsets(std::move(document_offsets)), - m_payload_offsets(std::move(payload_offsets)), + m_document_offsets(document_offsets), + m_payload_offsets(payload_offsets), m_documents(documents), m_payloads(payloads), - m_document_lengths(std::move(document_lengths)), + m_document_lengths(document_lengths), m_avg_document_length(avg_document_length.map_or_else( [](auto &&self) { return self; }, [&]() { return calc_avg_length(m_document_lengths); })), @@ -89,6 +91,18 @@ struct Index { return ScoringCursor(cursor(term), std::forward(scorer).term_scorer(term)); } + /// This is equivalent to the `scoring_cursor` unless the scorer is of type `VoidScorer`, + /// in which case index payloads are treated as scores. + template + [[nodiscard]] auto scored_cursor(TermId term, Scorer &&scorer) const + { + if constexpr (std::is_convertible_v) { + return cursor(term); + } else { + return scoring_cursor(term, std::forward(scorer)); + } + } + /// Constructs a new document cursor. 
[[nodiscard]] auto documents(TermId term) const { @@ -137,7 +151,7 @@ struct Index { return m_payloads.subspan(m_payload_offsets[term], m_payload_offsets[term + 1] - m_payload_offsets[term]); } - [[nodiscard]] static auto calc_avg_length(std::vector const &lengths) + [[nodiscard]] static auto calc_avg_length(gsl::span const &lengths) -> std::uint32_t { auto sum = std::accumulate(lengths.begin(), lengths.end(), std::uint64_t(0), std::plus{}); @@ -146,11 +160,11 @@ struct Index { Reader m_document_reader; Reader m_payload_reader; - std::vector m_document_offsets; - std::vector m_payload_offsets; + gsl::span m_document_offsets; + gsl::span m_payload_offsets; gsl::span m_documents; gsl::span m_payloads; - std::vector m_document_lengths; + gsl::span m_document_lengths; std::uint32_t m_avg_document_length; std::any m_source; }; @@ -158,11 +172,11 @@ struct Index { template auto make_index(DocumentReader document_reader, PayloadReader payload_reader, - std::vector document_offsets, - std::vector payload_offsets, + gsl::span document_offsets, + gsl::span payload_offsets, gsl::span documents, gsl::span payloads, - std::vector document_lengths, + gsl::span document_lengths, tl::optional avg_document_length, Source source) { @@ -171,15 +185,31 @@ auto make_index(DocumentReader document_reader, using PayloadCursor = decltype(payload_reader.read(std::declval>())); return Index(std::move(document_reader), std::move(payload_reader), - std::move(document_offsets), - std::move(payload_offsets), + document_offsets, + payload_offsets, documents, payloads, - std::move(document_lengths), + document_lengths, avg_document_length, std::move(source)); } +template +auto score_index(Index const &index, ByteOStream &os, Writer writer, Scorer scorer) + -> std::vector +{ + PostingBuilder score_builder(writer); + score_builder.write_header(os); + std::for_each(boost::counting_iterator(0), + boost::counting_iterator(index.num_terms()), + [&](auto term) { + for_each(index.scoring_cursor(term, 
scorer), + [&](auto &cursor) { score_builder.accumulate(cursor.payload()); }); + score_builder.flush_segment(os); + }); + return std::move(score_builder.offsets()); +} + /// Initializes a memory mapped source with a given file. inline void open_source(mio::mmap_source &source, std::string const &filename) { @@ -198,14 +228,15 @@ inline auto read_sizes(std::string_view basename) return std::vector(sequence.begin(), sequence.end()); } -[[nodiscard]] inline auto binary_collection_index(std::string const &basename) +[[nodiscard]] inline auto binary_collection_source(std::string const &basename) { - using sink_type = boost::iostreams::back_insert_device>; + using sink_type = boost::iostreams::back_insert_device>; using vector_stream_type = boost::iostreams::stream; binary_freq_collection collection(basename.c_str()); - std::vector docbuf; - std::vector freqbuf; + VectorSource source{{{}, {}}, {{}, {}}, {read_sizes(basename)}}; + std::vector &docbuf = source.bytes[0]; + std::vector &freqbuf = source.bytes[1]; PostingBuilder document_builder(Writer(RawWriter{})); PostingBuilder frequency_builder(Writer(RawWriter{})); @@ -223,96 +254,70 @@ inline auto read_sizes(std::string_view basename) } } - auto document_offsets = document_builder.offsets(); - auto frequency_offsets = frequency_builder.offsets(); - auto source = std::make_shared, std::vector>>( - std::move(docbuf), std::move(freqbuf)); - auto documents = gsl::span( - reinterpret_cast(source->first.data()), source->first.size()); - auto frequencies = gsl::span( - reinterpret_cast(source->second.data()), source->second.size()); + source.offsets[0] = std::move(document_builder.offsets()); + source.offsets[1] = std::move(frequency_builder.offsets()); + return source; +} + +[[nodiscard]] inline auto binary_collection_index(std::string const &basename) +{ + auto source = binary_collection_source(basename); + auto documents = gsl::span(source.bytes[0]); + auto frequencies = gsl::span(source.bytes[1]); + auto document_offsets 
= gsl::span(source.offsets[0]); + auto frequency_offsets = gsl::span(source.offsets[1]); + auto sizes = gsl::span(source.sizes[0]); return Index, RawCursor>(RawReader{}, RawReader{}, document_offsets, frequency_offsets, documents.subspan(8), frequencies.subspan(8), - read_sizes(basename), + sizes, tl::nullopt, std::move(source)); } -template -struct IndexRunner { - template - IndexRunner(std::vector document_offsets, - std::vector payload_offsets, - gsl::span documents, - gsl::span payloads, - std::vector document_lengths, - tl::optional avg_document_length, - Source source, - Readers... readers) - : m_document_offsets(std::move(document_offsets)), - m_payload_offsets(std::move(payload_offsets)), - m_documents(documents), - m_payloads(payloads), - m_document_lengths(std::move(document_lengths)), - m_avg_document_length(avg_document_length), - m_source(std::move(source)), - m_readers(readers...) - { - } - - template - void operator()(Fn fn) - { - auto dheader = PostingFormatHeader::parse(m_documents.first(8)); - auto pheader = PostingFormatHeader::parse(m_payloads.first(8)); - auto run = [&](auto &&dreader, auto &&preader) -> bool { - if (std::decay_t::encoding() == dheader.encoding - && std::decay_t::encoding() == pheader.encoding - && is_type::value_type>(dheader.type) - && is_type::value_type>(pheader.type)) { - fn(make_index(std::forward(dreader), - std::forward(preader), - m_document_offsets, - m_payload_offsets, - m_documents.subspan(8), - m_payloads.subspan(8), - m_document_lengths, - m_avg_document_length, - false)); - return true; - } - return false; - }; - bool success = std::apply( - [&](Readers... dreaders) { - auto with_document_reader = [&](auto dreader) { - return std::apply( - [&](Readers... 
preaders) { return (run(dreader, preaders) || ...); }, - m_readers); - }; - return (with_document_reader(dreaders) || ...); - }, - m_readers); - if (not success) { - throw std::domain_error("Unknown posting encoding"); - } - } +[[nodiscard]] inline auto binary_collection_scored_index(std::string const &basename) +{ + using sink_type = boost::iostreams::back_insert_device>; + using vector_stream_type = boost::iostreams::stream; - private: - std::vector m_document_offsets; - std::vector m_payload_offsets; - gsl::span m_documents; - gsl::span m_payloads; - std::vector m_document_lengths; - tl::optional m_avg_document_length; - std::any m_source; - std::tuple m_readers; -}; + auto source = binary_collection_source(basename); + auto documents = gsl::span(source.bytes[0]); + auto frequencies = gsl::span(source.bytes[1]); + auto sizes = gsl::span(source.sizes[0]); + auto document_offsets = gsl::span(source.offsets[0]); + auto frequency_offsets = gsl::span(source.offsets[1]); + auto freq_index = Index, RawCursor>(RawReader{}, + RawReader{}, + document_offsets, + frequency_offsets, + documents.subspan(8), + frequencies.subspan(8), + sizes, + tl::nullopt, + false); + + source.offsets.push_back([&freq_index, &source]() { + vector_stream_type score_stream{sink_type{source.bytes.emplace_back()}}; + return score_index(freq_index, score_stream, RawWriter{}, make_bm25(freq_index)); + }()); + auto scores = gsl::span(source.bytes.back()); + + document_offsets = gsl::span(source.offsets[0]); + auto score_offsets = gsl::span(source.offsets[2]); + return Index, RawCursor>(RawReader{}, + RawReader{}, + document_offsets, + score_offsets, + documents.subspan(8), + scores.subspan(8), + sizes, + tl::nullopt, + std::move(source)); +} template struct BigramIndex : public Index { @@ -342,14 +347,14 @@ struct BigramIndex : public Index { [[nodiscard]] inline auto binary_collection_bigram_index(std::string const &basename) { using payload_type = std::array; - using sink_type = 
boost::iostreams::back_insert_device>; + using sink_type = boost::iostreams::back_insert_device>; using vector_stream_type = boost::iostreams::stream; auto unigram_index = binary_collection_index(basename); std::vector> pair_mapping; - std::vector docbuf; - std::vector freqbuf; + std::vector docbuf; + std::vector freqbuf; PostingBuilder document_builder(RawWriter{}); PostingBuilder frequency_builder(RawWriter{}); @@ -385,22 +390,96 @@ struct BigramIndex : public Index { }); } - auto source = std::array, 2>{std::move(docbuf), std::move(freqbuf)}; - auto document_span = gsl::span( - reinterpret_cast(source[0].data()), source[0].size()); - auto payload_span = gsl::span( - reinterpret_cast(source[1].data()), source[1].size()); - + VectorSource source{ + {std::move(docbuf), std::move(freqbuf)}, + {std::move(document_builder.offsets()), std::move(frequency_builder.offsets())}, + {read_sizes(basename)}}; + auto document_span = gsl::span(source.bytes[0]); + auto payload_span = gsl::span(source.bytes[1]); + auto document_offsets = gsl::span(source.offsets[0]); + auto frequency_offsets = gsl::span(source.offsets[1]); + auto sizes = gsl::span(source.sizes[0]); auto index = Index, RawCursor>(RawReader{}, RawReader{}, - document_builder.offsets(), - frequency_builder.offsets(), + document_offsets, + frequency_offsets, document_span.subspan(8), payload_span.subspan(8), - read_sizes(basename), + sizes, tl::nullopt, std::move(source)); return BigramIndex(std::move(index), std::move(pair_mapping)); } +template +struct IndexRunner { + template + IndexRunner(gsl::span document_offsets, + gsl::span payload_offsets, + gsl::span documents, + gsl::span payloads, + gsl::span document_lengths, + tl::optional avg_document_length, + Source source, + Readers... 
readers) + : m_document_offsets(document_offsets), + m_payload_offsets(payload_offsets), + m_documents(documents), + m_payloads(payloads), + m_document_lengths(document_lengths), + m_avg_document_length(avg_document_length), + m_source(std::move(source)), + m_readers(readers...) + { + } + + template + void operator()(Fn fn) + { + auto dheader = PostingFormatHeader::parse(m_documents.first(8)); + auto pheader = PostingFormatHeader::parse(m_payloads.first(8)); + auto run = [&](auto &&dreader, auto &&preader) -> bool { + if (std::decay_t::encoding() == dheader.encoding + && std::decay_t::encoding() == pheader.encoding + && is_type::value_type>(dheader.type) + && is_type::value_type>(pheader.type)) { + fn(make_index(std::forward(dreader), + std::forward(preader), + m_document_offsets, + m_payload_offsets, + m_documents.subspan(8), + m_payloads.subspan(8), + m_document_lengths, + m_avg_document_length, + false)); + return true; + } + return false; + }; + bool success = std::apply( + [&](Readers... dreaders) { + auto with_document_reader = [&](auto dreader) { + return std::apply( + [&](Readers... 
preaders) { return (run(dreader, preaders) || ...); }, + m_readers); + }; + return (with_document_reader(dreaders) || ...); + }, + m_readers); + if (not success) { + throw std::domain_error("Unknown posting encoding"); + } + } + + private: + gsl::span m_document_offsets; + gsl::span m_payload_offsets; + gsl::span m_documents; + gsl::span m_payloads; + gsl::span m_document_lengths; + tl::optional m_avg_document_length; + std::any m_source; + std::tuple m_readers; +}; + } // namespace pisa::v1 diff --git a/include/pisa/v1/posting_builder.hpp b/include/pisa/v1/posting_builder.hpp index 5c7455379..699d486dd 100644 --- a/include/pisa/v1/posting_builder.hpp +++ b/include/pisa/v1/posting_builder.hpp @@ -21,18 +21,20 @@ struct PostingBuilder { m_offsets.push_back(0); } - void write_header(std::ostream &os) const + template + void write_header(std::basic_ostream &os) const { std::array header{}; PostingFormatHeader{.version = FormatVersion::current(), .type = value_type(), .encoding = m_writer.encoding()} .write(gsl::make_span(header)); - os.write(reinterpret_cast(header.data()), header.size()); + os.write(reinterpret_cast(header.data()), header.size()); } - template - auto write_segment(std::ostream &os, ValueIterator first, ValueIterator last) -> std::ostream & + template + auto write_segment(std::basic_ostream &os, ValueIterator first, ValueIterator last) + -> std::basic_ostream & { std::for_each(first, last, [&](auto &&value) { m_writer.push(value); }); return flush_segment(os); @@ -40,7 +42,8 @@ struct PostingBuilder { void accumulate(Value value) { m_writer.push(value); } - auto flush_segment(std::ostream &os) -> std::ostream & + template + auto flush_segment(std::basic_ostream &os) -> std::basic_ostream & { m_offsets.push_back(m_offsets.back() + m_writer.write(os)); m_writer.reset(); diff --git a/include/pisa/v1/raw_cursor.hpp b/include/pisa/v1/raw_cursor.hpp index 66949d98e..155c3cbee 100644 --- a/include/pisa/v1/raw_cursor.hpp +++ b/include/pisa/v1/raw_cursor.hpp 
@@ -107,13 +107,14 @@ struct RawWriter { void push(T const &posting) { m_postings.push_back(posting); } void push(T &&posting) { m_postings.push_back(posting); } - [[nodiscard]] auto write(std::ostream &os) const -> std::size_t + template + [[nodiscard]] auto write(std::basic_ostream &os) const -> std::size_t { assert(!m_postings.empty()); std::uint32_t length = m_postings.size(); - os.write(reinterpret_cast(&length), sizeof(length)); + os.write(reinterpret_cast(&length), sizeof(length)); auto memory = gsl::as_bytes(gsl::make_span(m_postings.data(), m_postings.size())); - os.write(reinterpret_cast(memory.data()), memory.size()); + os.write(reinterpret_cast(memory.data()), memory.size()); return sizeof(length) + memory.size(); } diff --git a/include/pisa/v1/scorer/bm25.hpp b/include/pisa/v1/scorer/bm25.hpp index d1d315a62..1da1b049c 100644 --- a/include/pisa/v1/scorer/bm25.hpp +++ b/include/pisa/v1/scorer/bm25.hpp @@ -18,13 +18,13 @@ struct BM25 { explicit BM25(Index const &index) : m_index(index) {} - static float doc_term_weight(uint64_t freq, float norm_len) + [[nodiscard]] static float doc_term_weight(uint64_t freq, float norm_len) { auto f = static_cast(freq); return f / (f + k1 * (1.0F - b + b * norm_len)); } - static float query_term_weight(uint64_t df, uint64_t num_docs) + [[nodiscard]] static float query_term_weight(uint64_t df, uint64_t num_docs) { auto fdf = static_cast(df); float idf = std::log((float(num_docs) - fdf + 0.5F) / (fdf + 0.5F)); @@ -32,7 +32,7 @@ struct BM25 { return std::max(epsilon_score, idf) * (1.0F + k1); } - auto term_scorer(TermId term_id) const + [[nodiscard]] auto term_scorer(TermId term_id) const { auto term_weight = query_term_weight(m_index.term_posting_count(term_id), m_index.num_documents()); @@ -46,4 +46,20 @@ struct BM25 { Index const &m_index; }; +template +auto make_bm25(Index const &index) +{ + return BM25(index); +} + } // namespace pisa::v1 + +namespace std { +template +struct hash<::pisa::v1::BM25> { + std::size_t 
operator()(::pisa::v1::BM25 const & /* bm25 */) const noexcept + { + return std::hash{}("bm25"); + } +}; +} // namespace std diff --git a/include/pisa/v1/scorer/runner.hpp b/include/pisa/v1/scorer/runner.hpp new file mode 100644 index 000000000..55573edde --- /dev/null +++ b/include/pisa/v1/scorer/runner.hpp @@ -0,0 +1,43 @@ +#pragma once + +#include +#include +#include + +#include + +namespace pisa::v1 { + +/// This is similar to the `IndexRunner` and it's used for running tasks +/// that require on-the-fly scoring. +template +struct ScorerRunner { + explicit ScorerRunner(Index const &index, Scorers... scorers) + : m_index(index), m_scorers(std::move(scorers...)) + { + } + + template + void operator()(std::string_view scorer_name, Fn fn) + { + auto run = [&](auto &&scorer) -> bool { + if (std::hash>{}(scorer) + == std::hash{}(scorer_name)) { + fn(std::forward>(scorer)); + return true; + } + return false; + }; + bool success = + std::apply([&](Scorers... scorers) { return (run(scorers) || ...); }, m_scorers); + if (not success) { + throw std::domain_error(fmt::format("Unknown scorer: ", scorer_name)); + } + } + + private: + Index const &m_index; + std::tuple m_scorers; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/source.hpp b/include/pisa/v1/source.hpp new file mode 100644 index 000000000..f4a5fc166 --- /dev/null +++ b/include/pisa/v1/source.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +#include + +#include "v1/types.hpp" + +namespace pisa::v1 { + +struct VectorSource { + std::vector> bytes{}; + std::vector> offsets{}; + std::vector> sizes{}; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/types.hpp b/include/pisa/v1/types.hpp index 2b3c52e03..fc5b17ab5 100644 --- a/include/pisa/v1/types.hpp +++ b/include/pisa/v1/types.hpp @@ -14,6 +14,7 @@ using DocId = std::uint32_t; using Frequency = std::uint32_t; using Score = float; using Result = std::pair; +using ByteOStream = std::basic_ostream; enum EncodingId { Raw = 0xda43 }; @@ 
-56,7 +57,7 @@ struct Writer { void push(T const &posting) { m_internal_writer->push(posting); } void push(T &&posting) { m_internal_writer->push(posting); } - auto write(std::ostream &os) const -> std::size_t { return m_internal_writer->write(os); } + auto write(ByteOStream &os) const -> std::size_t { return m_internal_writer->write(os); } [[nodiscard]] auto encoding() const -> std::uint32_t { return m_internal_writer->encoding(); } void reset() { return m_internal_writer->reset(); } @@ -69,7 +70,7 @@ struct Writer { virtual ~WriterInterface() = default; virtual void push(T const &posting) = 0; virtual void push(T &&posting) = 0; - virtual auto write(std::ostream &os) const -> std::size_t = 0; + virtual auto write(ByteOStream &os) const -> std::size_t = 0; [[nodiscard]] virtual auto encoding() const -> std::uint32_t = 0; virtual void reset() = 0; }; @@ -85,7 +86,7 @@ struct Writer { ~WriterImpl() = default; void push(T const &posting) override { m_writer.push(posting); } void push(T &&posting) override { m_writer.push(posting); } - auto write(std::ostream &os) const -> std::size_t override { return m_writer.write(os); } + auto write(ByteOStream &os) const -> std::size_t override { return m_writer.write(os); } [[nodiscard]] auto encoding() const -> std::uint32_t override { return W::encoding(); } void reset() override { return m_writer.reset(); } @@ -97,4 +98,9 @@ struct Writer { std::unique_ptr m_internal_writer; }; +/// Indicates that payloads should be treated as scores. +/// To be used with pre-computed scores, be it floats or quantized ints. 
+struct VoidScorer { +}; + } // namespace pisa::v1 diff --git a/test/test_v1.cpp b/test/test_v1.cpp index 5cd230e11..5fa19db8c 100644 --- a/test/test_v1.cpp +++ b/test/test_v1.cpp @@ -14,6 +14,7 @@ #include "v1/posting_builder.hpp" #include "v1/posting_format_header.hpp" #include "v1/scorer/bm25.hpp" +#include "v1/scorer/runner.hpp" #include "v1/types.hpp" using pisa::v1::Array; @@ -28,6 +29,7 @@ using pisa::v1::Primitive; using pisa::v1::RawReader; using pisa::v1::RawWriter; using pisa::v1::read_sizes; +using pisa::v1::ScorerRunner; using pisa::v1::TermId; using pisa::v1::Tuple; using pisa::v1::Writer; @@ -206,17 +208,31 @@ TEST_CASE("Value type", "[v1][unit]") REQUIRE(std::get(parse_type(std::byte{0b01000111})).size == 8U); } +[[nodiscard]] inline auto load_bytes(std::string const &data_file) +{ + std::vector data; + std::basic_ifstream in(data_file.c_str(), std::ios::binary); + in.seekg(0, std::ios::end); + std::streamsize size = in.tellg(); + in.seekg(0, std::ios::beg); + data.resize(size); + if (not in.read(reinterpret_cast(data.data()), size)) { + throw std::runtime_error("Failed reading " + data_file); + } + return data; +} + TEST_CASE("Build raw document-frequency index", "[v1][unit]") { - using sink_type = boost::iostreams::back_insert_device>; + using sink_type = boost::iostreams::back_insert_device>; using vector_stream_type = boost::iostreams::stream; GIVEN("A test binary collection") { pisa::binary_freq_collection collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); WHEN("Built posting files for documents and frequencies") { - std::vector docbuf; - std::vector freqbuf; + std::vector docbuf; + std::vector freqbuf; PostingBuilder document_builder(Writer(RawWriter{})); PostingBuilder frequency_builder(Writer(RawWriter{})); @@ -243,9 +259,9 @@ TEST_CASE("Build raw document-frequency index", "[v1][unit]") THEN("Bytes match with those of the collection") { auto document_bytes = - pisa::io::load_data(PISA_SOURCE_DIR 
"/test/test_data/test_collection.docs"); + load_bytes(PISA_SOURCE_DIR "/test/test_data/test_collection.docs"); auto frequency_bytes = - pisa::io::load_data(PISA_SOURCE_DIR "/test/test_data/test_collection.freqs"); + load_bytes(PISA_SOURCE_DIR "/test/test_data/test_collection.freqs"); // NOTE: the first 8 bytes of the document collection are different than those // of the built document file. Also, the original frequency collection starts @@ -262,7 +278,7 @@ TEST_CASE("Build raw document-frequency index", "[v1][unit]") THEN("Index runner is correctly constructed") { - auto source = std::array, 2>{docbuf, freqbuf}; + auto source = std::array, 2>{docbuf, freqbuf}; auto document_span = gsl::span( reinterpret_cast(source[0].data()), source[0].size()); auto payload_span = gsl::span( @@ -299,7 +315,7 @@ TEST_CASE("Build raw document-frequency index", "[v1][unit]") THEN("Index runner fails when wrong type") { - auto source = std::array, 2>{docbuf, freqbuf}; + auto source = std::array, 2>{docbuf, freqbuf}; auto document_span = gsl::span( reinterpret_cast(source[0].data()), source[0].size()); auto payload_span = gsl::span( diff --git a/test/test_v1_queries.cpp b/test/test_v1_queries.cpp index 8fd9019a8..7e9d78cfb 100644 --- a/test/test_v1_queries.cpp +++ b/test/test_v1_queries.cpp @@ -27,7 +27,7 @@ namespace v1 = pisa::v1; using namespace pisa; -template +template struct IndexData { static std::unique_ptr data; @@ -37,6 +37,8 @@ struct IndexData { document_sizes(PISA_SOURCE_DIR "/test/test_data/test_collection.sizes"), v1_index( pisa::v1::binary_collection_index(PISA_SOURCE_DIR "/test/test_data/test_collection")), + scored_index(pisa::v1::binary_collection_scored_index(PISA_SOURCE_DIR + "/test/test_data/test_collection")), wdata(document_sizes.begin()->begin(), collection.num_docs(), collection, @@ -70,7 +72,7 @@ struct IndexData { [[nodiscard]] static auto get() { if (IndexData::data == nullptr) { - IndexData::data = std::make_unique>(); + IndexData::data = 
std::make_unique>(); } return IndexData::data.get(); } @@ -80,23 +82,24 @@ struct IndexData { binary_collection document_sizes; v0_Index v0_index; v1_Index v1_index; + ScoredIndex scored_index; std::vector queries; std::vector thresholds; wand_data wdata; }; -template -std::unique_ptr> IndexData::data = nullptr; +template +std::unique_ptr> + IndexData::data = nullptr; -template -auto daat_and(Query const &query, Index const &index, topk_queue topk) +template +auto daat_and(Query const &query, Index const &index, topk_queue topk, Scorer &&scorer) { - v1::BM25 scorer(index); - std::vector cursors; + std::vector cursors; std::transform(query.terms.begin(), query.terms.end(), std::back_inserter(cursors), - [&](auto term) { return index.scoring_cursor(term, scorer); }); + [&](auto term) { return index.scored_cursor(term, scorer); }); auto intersection = v1::intersect(std::move(cursors), 0.0F, [](auto &score, auto &cursor, auto /* term_idx */) { score += cursor.payload(); @@ -106,15 +109,14 @@ auto daat_and(Query const &query, Index const &index, topk_queue topk) return topk; } -template -auto daat_or(Query const &query, Index const &index, topk_queue topk) +template +auto daat_or(Query const &query, Index const &index, topk_queue topk, Scorer &&scorer) { - v1::BM25 scorer(index); - std::vector cursors; + std::vector cursors; std::transform(query.terms.begin(), query.terms.end(), std::back_inserter(cursors), - [&](auto term) { return index.scoring_cursor(term, scorer); }); + [&](auto term) { return index.scored_cursor(term, scorer); }); auto cunion = v1::union_merge( std::move(cursors), 0.0F, [](auto &score, auto &cursor, auto /* term_idx */) { score += cursor.payload(); @@ -127,25 +129,42 @@ auto daat_or(Query const &query, Index const &index, topk_queue topk) TEST_CASE("DAAT AND", "[v1][integration]") { auto data = IndexData, v1::RawCursor>>::get(); + v1::Index, v1::RawCursor>, + v1::Index, v1::RawCursor>>::get(); ranked_and_query and_q(10); int idx = 0; for (auto 
const &q : data->queries) { + CAPTURE(q.terms); CAPTURE(idx++); - and_q(make_scored_cursors(data->v0_index, data->wdata, q), data->v0_index.num_docs()); - auto que = daat_and(q, data->v1_index, topk_queue(10)); - que.finalize(); + and_q(make_scored_cursors(data->v0_index, data->wdata, q), data->v0_index.num_docs()); auto expected = and_q.topk(); std::sort(expected.begin(), expected.end(), std::greater{}); - auto actual = que.topk(); - std::sort(actual.begin(), actual.end(), std::greater{}); - REQUIRE(expected.size() == actual.size()); - for (size_t i = 0; i < actual.size(); ++i) { - REQUIRE(actual[i].second == expected[i].second); - REQUIRE(actual[i].first == Approx(expected[i].first).epsilon(0.1)); + auto on_the_fly = [&]() { + auto que = daat_and(q, data->v1_index, topk_queue(10), make_bm25(data->v1_index)); + que.finalize(); + auto results = que.topk(); + std::sort(results.begin(), results.end(), std::greater{}); + return results; + }(); + + auto precomputed = [&]() { + auto que = daat_and(q, data->scored_index, topk_queue(10), v1::VoidScorer{}); + que.finalize(); + auto results = que.topk(); + std::sort(results.begin(), results.end(), std::greater{}); + return results; + }(); + + REQUIRE(expected.size() == on_the_fly.size()); + REQUIRE(expected.size() == precomputed.size()); + for (size_t i = 0; i < on_the_fly.size(); ++i) { + REQUIRE(on_the_fly[i].second == expected[i].second); + REQUIRE(on_the_fly[i].first == Approx(expected[i].first).epsilon(0.1)); + REQUIRE(precomputed[i].second == expected[i].second); + REQUIRE(precomputed[i].first == Approx(expected[i].first).epsilon(0.1)); } } } @@ -153,25 +172,41 @@ TEST_CASE("DAAT AND", "[v1][integration]") TEST_CASE("DAAT OR", "[v1][integration]") { auto data = IndexData, v1::RawCursor>>::get(); + v1::Index, v1::RawCursor>, + v1::Index, v1::RawCursor>>::get(); ranked_or_query or_q(10); int idx = 0; for (auto const &q : data->queries) { CAPTURE(q.terms); CAPTURE(idx++); - or_q(make_scored_cursors(data->v0_index, 
data->wdata, q), data->v0_index.num_docs()); - auto que = daat_or(q, data->v1_index, topk_queue(10)); - que.finalize(); + or_q(make_scored_cursors(data->v0_index, data->wdata, q), data->v0_index.num_docs()); auto expected = or_q.topk(); std::sort(expected.begin(), expected.end(), std::greater{}); - auto actual = que.topk(); - std::sort(actual.begin(), actual.end(), std::greater{}); - REQUIRE(expected.size() == actual.size()); - for (size_t i = 0; i < actual.size(); ++i) { - REQUIRE(actual[i].second == expected[i].second); - REQUIRE(actual[i].first == Approx(expected[i].first).epsilon(0.1)); + auto on_the_fly = [&]() { + auto que = daat_or(q, data->v1_index, topk_queue(10), make_bm25(data->v1_index)); + que.finalize(); + auto results = que.topk(); + std::sort(results.begin(), results.end(), std::greater{}); + return results; + }(); + + auto precomputed = [&]() { + auto que = daat_or(q, data->scored_index, topk_queue(10), v1::VoidScorer{}); + que.finalize(); + auto results = que.topk(); + std::sort(results.begin(), results.end(), std::greater{}); + return results; + }(); + + REQUIRE(expected.size() == on_the_fly.size()); + REQUIRE(expected.size() == precomputed.size()); + for (size_t i = 0; i < on_the_fly.size(); ++i) { + REQUIRE(on_the_fly[i].second == expected[i].second); + REQUIRE(on_the_fly[i].first == Approx(expected[i].first).epsilon(0.1)); + REQUIRE(precomputed[i].second == expected[i].second); + REQUIRE(precomputed[i].first == Approx(expected[i].first).epsilon(0.1)); } } } From f76fdd76377663c13718090fbfc2cd403179bb4b Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Wed, 30 Oct 2019 20:33:32 -0400 Subject: [PATCH 14/56] Index building tool --- CMakeLists.txt | 2 +- include/pisa/v1/index.hpp | 13 +- include/pisa/v1/index_builder.hpp | 200 +++++++++++++++++++++++++++++ include/pisa/v1/index_metadata.hpp | 74 +++++++++++ include/pisa/v1/raw_cursor.hpp | 2 +- include/pisa/v1/scorer/runner.hpp | 2 +- include/pisa/v1/source.hpp | 11 ++ 
include/pisa/v1/types.hpp | 26 +++- test/test_v1_index.cpp | 54 ++++++++ v1/CMakeLists.txt | 120 +++++++++++++++++ v1/compress.cpp | 37 ++++++ 11 files changed, 530 insertions(+), 11 deletions(-) create mode 100644 include/pisa/v1/index_builder.hpp create mode 100644 include/pisa/v1/index_metadata.hpp create mode 100644 test/test_v1_index.cpp create mode 100644 v1/CMakeLists.txt create mode 100644 v1/compress.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 84f95f8ae..d3f78ab03 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -99,7 +99,7 @@ target_link_libraries(pisa INTERFACE ) target_include_directories(pisa INTERFACE external) -add_subdirectory(src) +add_subdirectory(v1) if (PISA_ENABLE_TESTING AND BUILD_TESTING) enable_testing() diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index f1ef690a1..92b7c0934 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -30,6 +30,13 @@ namespace pisa::v1 { +[[nodiscard]] inline auto calc_avg_length(gsl::span const &lengths) + -> std::uint32_t +{ + auto sum = std::accumulate(lengths.begin(), lengths.end(), std::uint64_t(0), std::plus{}); + return static_cast(sum) / lengths.size(); +} + /// A generic type for an inverted index. 
/// /// \tparam DocumentReader Type of an object that reads document posting lists from bytes @@ -151,12 +158,6 @@ struct Index { return m_payloads.subspan(m_payload_offsets[term], m_payload_offsets[term + 1] - m_payload_offsets[term]); } - [[nodiscard]] static auto calc_avg_length(gsl::span const &lengths) - -> std::uint32_t - { - auto sum = std::accumulate(lengths.begin(), lengths.end(), std::uint64_t(0), std::plus{}); - return static_cast(sum) / lengths.size(); - } Reader m_document_reader; Reader m_payload_reader; diff --git a/include/pisa/v1/index_builder.hpp b/include/pisa/v1/index_builder.hpp new file mode 100644 index 000000000..e64f55796 --- /dev/null +++ b/include/pisa/v1/index_builder.hpp @@ -0,0 +1,200 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "util/progress.hpp" +#include "v1/index.hpp" + +namespace pisa::v1 { + +template +struct IndexBuilder { + explicit IndexBuilder(Writers... writers) : m_writers(std::move(writers...)) {} + + template + void operator()(Encoding encoding, Fn fn) + { + auto run = [&](auto &&writer) -> bool { + if (std::decay_t::encoding() == encoding) { + fn(writer); + return true; + } + return false; + }; + bool success = + std::apply([&](Writers... 
writers) { return (run(writers) || ...); }, m_writers); + if (not success) { + throw std::domain_error(fmt::format("Unknown writer")); + } + } + + private: + std::tuple m_writers; +}; + +using pisa::v1::ByteOStream; +using pisa::v1::calc_avg_length; +using pisa::v1::DocId; +using pisa::v1::Frequency; +using pisa::v1::PostingBuilder; +using pisa::v1::RawWriter; +using pisa::v1::read_sizes; + +template +auto compress_batch(CollectionIterator first, + CollectionIterator last, + std::ofstream &dout, + std::ofstream &fout, + Writer document_writer, + Writer frequency_writer, + tl::optional bar) + -> std::tuple, std::vector> +{ + PostingBuilder document_builder(std::move(document_writer)); + PostingBuilder frequency_builder(std::move(frequency_writer)); + for (auto pos = first; pos != last; ++pos) { + auto dseq = pos->docs; + auto fseq = pos->freqs; + for (auto doc : dseq) { + document_builder.accumulate(doc); + } + for (auto freq : fseq) { + frequency_builder.accumulate(freq); + } + document_builder.flush_segment(dout); + frequency_builder.flush_segment(fout); + bar->update(1); + } + return std::make_tuple(std::move(document_builder.offsets()), + std::move(frequency_builder.offsets())); +} + +template +void write_span(gsl::span offsets, std::string const &file) +{ + std::ofstream os(file); + auto bytes = gsl::as_bytes(offsets); + os.write(reinterpret_cast(bytes.data()), bytes.size()); +} + +inline void compress_binary_collection(std::string const &input, + std::string_view output, + std::size_t const threads, + Writer document_writer, + Writer frequency_writer) +{ + pisa::binary_freq_collection const collection(input.c_str()); + tbb::task_group group; + auto const num_terms = collection.size(); + std::vector> document_offsets(threads); + std::vector> frequency_offsets(threads); + std::vector document_paths; + std::vector frequency_paths; + std::vector document_streams; + std::vector frequency_streams; + auto for_each_batch = [threads](auto fn) { + for (auto thread_idx = 
0; thread_idx < threads; thread_idx += 1) { + fn(thread_idx); + } + }; + for_each_batch([&](auto thread_idx) { + auto document_batch = + document_paths.emplace_back(fmt::format("{}.doc.batch.{}", output, thread_idx)); + auto frequency_batch = + frequency_paths.emplace_back(fmt::format("{}.freq.batch.{}", output, thread_idx)); + document_streams.emplace_back(document_batch); + frequency_streams.emplace_back(frequency_batch); + }); + progress bar("Compressing in parallel", collection.size()); + auto batch_size = num_terms / threads; + for_each_batch([&](auto thread_idx) { + group.run([thread_idx, + batch_size, + threads, + &collection, + &document_streams, + &frequency_streams, + &document_offsets, + &frequency_offsets, + &bar, + &document_writer, + &frequency_writer]() { + auto first = std::next(collection.begin(), thread_idx * batch_size); + auto last = [&]() { + if (thread_idx == threads - 1) { + return collection.end(); + } + return std::next(collection.begin(), (thread_idx + 1) * batch_size); + }(); + auto &dout = document_streams[thread_idx]; + auto &fout = frequency_streams[thread_idx]; + std::tie(document_offsets[thread_idx], frequency_offsets[thread_idx]) = + compress_batch(first, + last, + dout, + fout, + document_writer, + frequency_writer, + tl::make_optional(bar)); + }); + }); + group.wait(); + document_streams.clear(); + frequency_streams.clear(); + + std::vector all_document_offsets; + std::vector all_frequency_offsets; + all_document_offsets.reserve(num_terms + 1); + all_frequency_offsets.reserve(num_terms + 1); + all_document_offsets.push_back(0); + all_frequency_offsets.push_back(0); + auto documents_file = fmt::format("{}.documents", output); + auto frequencies_file = fmt::format("{}.frequencies", output); + std::ofstream document_out(documents_file); + std::ofstream frequency_out(frequencies_file); + + PostingBuilder(RawWriter{}).write_header(document_out); + PostingBuilder(RawWriter{}).write_header(frequency_out); + + for_each_batch([&](auto 
thread_idx) { + std::transform(std::next(document_offsets[thread_idx].begin()), + document_offsets[thread_idx].end(), + std::back_inserter(all_document_offsets), + [base = all_document_offsets.back()](auto offset) { return base + offset; }); + std::transform( + std::next(frequency_offsets[thread_idx].begin()), + frequency_offsets[thread_idx].end(), + std::back_inserter(all_frequency_offsets), + [base = all_frequency_offsets.back()](auto offset) { return base + offset; }); + std::ifstream docbatch(document_paths[thread_idx]); + std::ifstream freqbatch(frequency_paths[thread_idx]); + document_out << docbatch.rdbuf(); + frequency_out << freqbatch.rdbuf(); + }); + + auto doc_offset_file = fmt::format("{}.document_offsets", output); + auto freq_offset_file = fmt::format("{}.frequency_offsets", output); + write_span(gsl::span(all_document_offsets), doc_offset_file); + write_span(gsl::span(all_frequency_offsets), freq_offset_file); + + auto lengths = read_sizes(input); + auto document_lengths_file = fmt::format("{}.document_lengths", output); + write_span(gsl::span(lengths), document_lengths_file); + auto avg_len = calc_avg_length(gsl::span(lengths)); + + boost::property_tree::ptree pt; + pt.put("documents.file", documents_file); + pt.put("documents.offsets", doc_offset_file); + pt.put("frequencies.file", frequencies_file); + pt.put("frequencies.offsets", freq_offset_file); + pt.put("stats.avg_document_length", avg_len); + pt.put("stats.document_lengths", document_lengths_file); + boost::property_tree::write_ini(fmt::format("{}.ini", output), pt); +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/index_metadata.hpp b/include/pisa/v1/index_metadata.hpp new file mode 100644 index 000000000..631457097 --- /dev/null +++ b/include/pisa/v1/index_metadata.hpp @@ -0,0 +1,74 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include "v1/index.hpp" +#include "v1/source.hpp" +#include "v1/types.hpp" + +namespace pisa::v1 { + +struct PostingFilePaths 
{ + std::string postings; + std::string offsets; +}; + +struct IndexMetadata { + PostingFilePaths documents; + PostingFilePaths frequencies; + std::string document_lengths_path; + float avg_document_length; + + [[nodiscard]] static auto from_file(std::string const &file) + { + boost::property_tree::ptree pt; + boost::property_tree::ini_parser::read_ini(file, pt); + return IndexMetadata{ + .documents = PostingFilePaths{.postings = pt.get("documents.file"), + .offsets = pt.get("documents.offsets")}, + .frequencies = PostingFilePaths{.postings = pt.get("frequencies.file"), + .offsets = pt.get("frequencies.offsets")}, + .document_lengths_path = pt.get("stats.document_lengths"), + .avg_document_length = pt.get("stats.avg_document_length")}; + } +}; + +template +[[nodiscard]] auto to_span(mio::mmap_source const *mmap) +{ + static_assert(std::is_trivially_constructible_v); + return gsl::span(reinterpret_cast(mmap->data()), mmap->size() / sizeof(T)); +}; + +template +[[nodiscard]] auto source_span(MMapSource &source, std::string const &file) +{ + return to_span( + source.file_sources.emplace_back(std::make_shared(file)).get()); +}; + +template +[[nodiscard]] inline auto index_runner(IndexMetadata metadata, Readers... 
readers) +{ + MMapSource source; + auto documents = source_span(source, metadata.documents.postings); + auto frequencies = source_span(source, metadata.frequencies.postings); + auto document_offsets = source_span(source, metadata.documents.offsets); + auto frequency_offsets = source_span(source, metadata.frequencies.offsets); + auto document_lengths = source_span(source, metadata.document_lengths_path); + return IndexRunner(document_offsets, + frequency_offsets, + documents, + frequencies, + document_lengths, + tl::make_optional(metadata.avg_document_length), + std::move(source), + std::move(readers)...); +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/raw_cursor.hpp b/include/pisa/v1/raw_cursor.hpp index 155c3cbee..4e2c8e3f9 100644 --- a/include/pisa/v1/raw_cursor.hpp +++ b/include/pisa/v1/raw_cursor.hpp @@ -121,7 +121,7 @@ struct RawWriter { void reset() { m_postings.clear(); } private: - std::vector m_postings; + std::vector m_postings{}; }; } // namespace pisa::v1 diff --git a/include/pisa/v1/scorer/runner.hpp b/include/pisa/v1/scorer/runner.hpp index 55573edde..d1034c053 100644 --- a/include/pisa/v1/scorer/runner.hpp +++ b/include/pisa/v1/scorer/runner.hpp @@ -31,7 +31,7 @@ struct ScorerRunner { bool success = std::apply([&](Scorers... 
scorers) { return (run(scorers) || ...); }, m_scorers); if (not success) { - throw std::domain_error(fmt::format("Unknown scorer: ", scorer_name)); + throw std::domain_error(fmt::format("Unknown scorer: {}", scorer_name)); } } diff --git a/include/pisa/v1/source.hpp b/include/pisa/v1/source.hpp index f4a5fc166..1335ae6a2 100644 --- a/include/pisa/v1/source.hpp +++ b/include/pisa/v1/source.hpp @@ -4,6 +4,7 @@ #include #include +#include #include "v1/types.hpp" @@ -15,4 +16,14 @@ struct VectorSource { std::vector> sizes{}; }; +struct MMapSource { + MMapSource() = default; + MMapSource(MMapSource &&) = default; + MMapSource(MMapSource const &) = default; + MMapSource &operator=(MMapSource &&) = default; + MMapSource &operator=(MMapSource const &) = default; + ~MMapSource() = default; + std::vector> file_sources{}; +}; + } // namespace pisa::v1 diff --git a/include/pisa/v1/types.hpp b/include/pisa/v1/types.hpp index fc5b17ab5..8d8aaacaa 100644 --- a/include/pisa/v1/types.hpp +++ b/include/pisa/v1/types.hpp @@ -49,15 +49,23 @@ struct Reader { template struct Writer { + using Value = T; template explicit constexpr Writer(W writer) : m_internal_writer(std::make_unique>(writer)) { } + Writer() = default; + Writer(Writer const &other) : m_internal_writer(other.m_internal_writer->clone()) {} + Writer(Writer &&other) noexcept = default; + Writer &operator=(Writer const &other) = delete; + Writer &operator=(Writer &&other) noexcept = default; + ~Writer() = default; void push(T const &posting) { m_internal_writer->push(posting); } void push(T &&posting) { m_internal_writer->push(posting); } auto write(ByteOStream &os) const -> std::size_t { return m_internal_writer->write(os); } + auto write(std::ostream &os) const -> std::size_t { return m_internal_writer->write(os); } [[nodiscard]] auto encoding() const -> std::uint32_t { return m_internal_writer->encoding(); } void reset() { return m_internal_writer->reset(); } @@ -71,8 +79,10 @@ struct Writer { virtual void push(T const 
&posting) = 0; virtual void push(T &&posting) = 0; virtual auto write(ByteOStream &os) const -> std::size_t = 0; - [[nodiscard]] virtual auto encoding() const -> std::uint32_t = 0; + virtual auto write(std::ostream &os) const -> std::size_t = 0; virtual void reset() = 0; + [[nodiscard]] virtual auto encoding() const -> std::uint32_t = 0; + [[nodiscard]] virtual auto clone() const -> std::unique_ptr = 0; }; template @@ -87,8 +97,14 @@ struct Writer { void push(T const &posting) override { m_writer.push(posting); } void push(T &&posting) override { m_writer.push(posting); } auto write(ByteOStream &os) const -> std::size_t override { return m_writer.write(os); } - [[nodiscard]] auto encoding() const -> std::uint32_t override { return W::encoding(); } + auto write(std::ostream &os) const -> std::size_t override { return m_writer.write(os); } void reset() override { return m_writer.reset(); } + [[nodiscard]] auto encoding() const -> std::uint32_t override { return W::encoding(); } + [[nodiscard]] virtual auto clone() const -> std::unique_ptr + { + auto copy = *this; + return std::make_unique>(std::move(copy)); + } private: W m_writer; @@ -98,6 +114,12 @@ struct Writer { std::unique_ptr m_internal_writer; }; +template +[[nodiscard]] inline auto make_writer(W writer) +{ + return Writer(writer); +} + /// Indicates that payloads should be treated as scores. /// To be used with pre-computed scores, be it floats or quantized ints. 
struct VoidScorer { diff --git a/test/test_v1_index.cpp b/test/test_v1_index.cpp new file mode 100644 index 000000000..cc28c1a0d --- /dev/null +++ b/test/test_v1_index.cpp @@ -0,0 +1,54 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include + +#include +#include +#include + +#include "pisa_config.hpp" +#include "temporary_directory.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/index.hpp" +#include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/types.hpp" + +using pisa::v1::binary_collection_index; +using pisa::v1::compress_binary_collection; +using pisa::v1::DocId; +using pisa::v1::Frequency; +using pisa::v1::index_runner; +using pisa::v1::IndexMetadata; +using pisa::v1::RawReader; +using pisa::v1::RawWriter; +using pisa::v1::TermId; + +TEST_CASE("Binary collection index", "[v1][unit]") +{ + tbb::task_scheduler_init init(8); + Temporary_Directory tmpdir; + auto bci = binary_collection_index(PISA_SOURCE_DIR "/test/test_data/test_collection"); + compress_binary_collection(PISA_SOURCE_DIR "/test/test_data/test_collection", + (tmpdir.path() / "index").string(), + 8, + make_writer(RawWriter{}), + make_writer(RawWriter{})); + auto meta = IndexMetadata::from_file((tmpdir.path() / "index.ini").string()); + REQUIRE(meta.documents.postings == (tmpdir.path() / "index.documents").string()); + REQUIRE(meta.documents.offsets == (tmpdir.path() / "index.document_offsets").string()); + REQUIRE(meta.frequencies.postings == (tmpdir.path() / "index.frequencies").string()); + REQUIRE(meta.frequencies.offsets == (tmpdir.path() / "index.frequency_offsets").string()); + REQUIRE(meta.document_lengths_path == (tmpdir.path() / "index.document_lengths").string()); + auto run = index_runner(meta, RawReader{}, RawReader{}); + run([&](auto index) { + REQUIRE(bci.avg_document_length() == index.avg_document_length()); + REQUIRE(bci.num_documents() == index.num_documents()); + REQUIRE(bci.num_terms() == index.num_terms()); + for (auto term = 0; term < 
bci.num_terms(); term += 1) { + REQUIRE(collect(bci.documents(term)) == collect(index.documents(term))); + REQUIRE(collect(bci.payloads(term)) == collect(index.payloads(term))); + } + }); +} diff --git a/v1/CMakeLists.txt b/v1/CMakeLists.txt new file mode 100644 index 000000000..cadf5b7d1 --- /dev/null +++ b/v1/CMakeLists.txt @@ -0,0 +1,120 @@ +add_executable(compress compress.cpp) +target_link_libraries(compress + pisa + CLI11 +) + +#add_executable(create_freq_index create_freq_index.cpp) +#target_link_libraries(create_freq_index +# pisa +# CLI11 +#) +# +#add_executable(optimal_hybrid_index optimal_hybrid_index.cpp) +#target_include_directories(optimal_hybrid_index PRIVATE ${STXXL_INCLUDE_DIRS}) +#target_link_libraries(optimal_hybrid_index +# ${STXXL_LIBRARIES} +# pisa +#) +#set_target_properties(optimal_hybrid_index PROPERTIES +# CXX_STANDARD 14 +#) +# +#add_executable(create_wand_data create_wand_data.cpp) +#target_link_libraries(create_wand_data +# pisa +# CLI11 +#) +# +#add_executable(queries queries.cpp) +#target_link_libraries(queries +# pisa +# CLI11 +#) +# +#add_executable(evaluate_queries evaluate_queries.cpp) +#target_link_libraries(evaluate_queries +# pisa +# CLI11 +#) +# +#add_executable(thresholds thresholds.cpp) +#target_link_libraries(thresholds +# pisa +# CLI11 +#) +# +#add_executable(profile_queries profile_queries.cpp) +#target_link_libraries(profile_queries +# pisa +#) +# +#add_executable(profile_decoding profile_decoding.cpp) +#target_link_libraries(profile_decoding +# pisa +#) +# +#add_executable(shuffle_docids shuffle_docids.cpp) +#target_link_libraries(shuffle_docids +# pisa +#) +# +#add_executable(recursive_graph_bisection recursive_graph_bisection.cpp) +#target_link_libraries(recursive_graph_bisection +# pisa +# CLI11 +#) +# +#add_executable(evaluate_collection_ordering evaluate_collection_ordering.cpp) +#target_link_libraries(evaluate_collection_ordering +# pisa +# ) +# +#add_executable(parse_collection parse_collection.cpp) 
+#target_link_libraries(parse_collection +# pisa +# CLI11 +# wapopp +#) +# +#add_executable(invert invert.cpp) +#target_link_libraries(invert +# CLI11 +# pisa +#) +# +#add_executable(read_collection read_collection.cpp) +#target_link_libraries(read_collection +# pisa +# CLI11 +#) +# +#add_executable(partition_fwd_index partition_fwd_index.cpp) +#target_link_libraries(partition_fwd_index +# pisa +# CLI11 +#) +# +#add_executable(compute_intersection compute_intersection.cpp) +#target_link_libraries(compute_intersection +# pisa +# CLI11 +#) +# +#add_executable(lexicon lexicon.cpp) +#target_link_libraries(lexicon +# pisa +# CLI11 +#) +# +#add_executable(extract_topics extract_topics.cpp) +#target_link_libraries(extract_topics +# pisa +# CLI11 +#) +# +#add_executable(sample_inverted_index sample_inverted_index.cpp) +#target_link_libraries(sample_inverted_index +# pisa +# CLI11 +#) diff --git a/v1/compress.cpp b/v1/compress.cpp new file mode 100644 index 000000000..3fcc4aa3a --- /dev/null +++ b/v1/compress.cpp @@ -0,0 +1,37 @@ +#include + +#include "binary_freq_collection.hpp" +#include "v1/index_builder.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/types.hpp" + +using pisa::v1::compress_binary_collection; +using pisa::v1::EncodingId; +using pisa::v1::IndexBuilder; +using pisa::v1::RawWriter; + +int main(int argc, char **argv) +{ + std::string input; + std::string output; + std::size_t threads = std::thread::hardware_concurrency(); + + CLI::App app{"Compresses a given binary collection to a v1 index."}; + app.add_option("-c,--collection", input, "Input collection basename")->required(); + app.add_option("-o,--output", output, "Output basename")->required(); + app.add_option("-j,--threads", threads, "Number of threads"); + CLI11_PARSE(app, argc, argv); + + tbb::task_scheduler_init init(threads); + IndexBuilder> build(RawWriter{}); + build(EncodingId::Raw, [&](auto writer) { + auto frequency_writer = writer; + compress_binary_collection(input, + output, + threads, + 
make_writer(std::move(writer)), + make_writer(std::move(frequency_writer))); + }); + + return 0; +} From d0994caf9ba898ce36614f37c94ff52c6e41f54c Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 4 Nov 2019 14:13:13 -0500 Subject: [PATCH 15/56] Query and postings tools --- CMakeLists.txt | 10 +- include/pisa/query/queries.hpp | 19 +--- include/pisa/query/query.hpp | 19 ++++ include/pisa/v1/index.hpp | 19 ++++ include/pisa/v1/index_builder.hpp | 23 +++-- include/pisa/v1/index_metadata.hpp | 50 +++++++--- include/pisa/v1/progress_status.hpp | 67 +++++++++++++ include/pisa/v1/query.hpp | 50 ++++++++++ include/pisa/v1/scorer/runner.hpp | 6 ++ src/v1/index_builder.cpp | 48 +++++++++ src/v1/index_metadata.cpp | 35 +++++++ src/v1/progress_status.cpp | 52 ++++++++++ test/test_v1_queries.cpp | 47 ++------- v1/CMakeLists.txt | 124 ++--------------------- v1/compress.cpp | 14 +++ v1/postings.cpp | 146 ++++++++++++++++++++++++++++ v1/query.cpp | 100 +++++++++++++++++++ 17 files changed, 635 insertions(+), 194 deletions(-) create mode 100644 include/pisa/query/query.hpp create mode 100644 include/pisa/v1/progress_status.hpp create mode 100644 src/v1/index_builder.cpp create mode 100644 src/v1/index_metadata.cpp create mode 100644 src/v1/progress_status.cpp create mode 100644 v1/postings.cpp create mode 100644 v1/query.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index d3f78ab03..06d981e66 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -69,12 +69,15 @@ endif() set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) +file(GLOB_RECURSE PISA_SRC_FILES FOLLOW_SYMLINKS "src/v1/*cpp") +list(SORT PISA_SRC_FILES) + include_directories(include) -add_library(pisa INTERFACE) -target_include_directories(pisa INTERFACE +add_library(pisa ${PISA_SRC_FILES}) +target_include_directories(pisa PUBLIC $ ) -target_link_libraries(pisa INTERFACE +target_link_libraries(pisa PUBLIC Threads::Threads Boost::boost QMX @@ -100,6 +103,7 @@ target_link_libraries(pisa INTERFACE 
target_include_directories(pisa INTERFACE external) add_subdirectory(v1) +add_subdirectory(src) if (PISA_ENABLE_TESTING AND BUILD_TESTING) enable_testing() diff --git a/include/pisa/query/queries.hpp b/include/pisa/query/queries.hpp index 93fcfdf5d..a0e22582e 100644 --- a/include/pisa/query/queries.hpp +++ b/include/pisa/query/queries.hpp @@ -12,7 +12,7 @@ #include #include "index_types.hpp" -#include "query/queries.hpp" +#include "query/query.hpp" #include "scorer/score_function.hpp" #include "term_processor.hpp" #include "tokenizer.hpp" @@ -24,15 +24,6 @@ namespace pisa { -using term_id_type = uint32_t; -using term_id_vec = std::vector; - -struct Query { - std::optional id; - std::vector terms; - std::vector term_weights; -}; - [[nodiscard]] auto split_query_at_colon(std::string const &query_string) -> std::pair, std::string_view> { @@ -98,10 +89,10 @@ struct Query { { if (terms_file) { auto term_processor = TermProcessor(terms_file, stopwords_filename, stemmer_type); - return [&queries, term_processor = std::move(term_processor)]( - std::string const &query_line) { - queries.push_back(parse_query_terms(query_line, term_processor)); - }; + return + [&queries, term_processor = std::move(term_processor)](std::string const &query_line) { + queries.push_back(parse_query_terms(query_line, term_processor)); + }; } else { return [&queries](std::string const &query_line) { queries.push_back(parse_query_ids(query_line)); diff --git a/include/pisa/query/query.hpp b/include/pisa/query/query.hpp new file mode 100644 index 000000000..12559f957 --- /dev/null +++ b/include/pisa/query/query.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include +#include +#include +#include + +namespace pisa { + +using term_id_type = std::uint32_t; +using term_id_vec = std::vector; + +struct Query { + std::optional id; + std::vector terms; + std::vector term_weights; +}; + +} // namespace pisa diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index 92b7c0934..568bafc4e 100644 --- 
a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -433,6 +433,25 @@ struct IndexRunner { m_readers(readers...) { } + template + IndexRunner(gsl::span document_offsets, + gsl::span payload_offsets, + gsl::span documents, + gsl::span payloads, + gsl::span document_lengths, + tl::optional avg_document_length, + Source source, + std::tuple readers) + : m_document_offsets(document_offsets), + m_payload_offsets(payload_offsets), + m_documents(documents), + m_payloads(payloads), + m_document_lengths(document_lengths), + m_avg_document_length(avg_document_length), + m_source(std::move(source)), + m_readers(std::move(readers)) + { + } template void operator()(Fn fn) diff --git a/include/pisa/v1/index_builder.hpp b/include/pisa/v1/index_builder.hpp index e64f55796..f929f8784 100644 --- a/include/pisa/v1/index_builder.hpp +++ b/include/pisa/v1/index_builder.hpp @@ -1,5 +1,7 @@ #pragma once +#include + #include #include #include @@ -7,8 +9,9 @@ #include #include -#include "util/progress.hpp" #include "v1/index.hpp" +#include "v1/index_metadata.hpp" +#include "v1/progress_status.hpp" namespace pisa::v1 { @@ -52,7 +55,7 @@ auto compress_batch(CollectionIterator first, std::ofstream &fout, Writer document_writer, Writer frequency_writer, - tl::optional bar) + tl::optional bar) -> std::tuple, std::vector> { PostingBuilder document_builder(std::move(document_writer)); @@ -68,7 +71,7 @@ auto compress_batch(CollectionIterator first, } document_builder.flush_segment(dout); frequency_builder.flush_segment(fout); - bar->update(1); + *bar += 1; } return std::make_tuple(std::move(document_builder.offsets()), std::move(frequency_builder.offsets())); @@ -110,7 +113,9 @@ inline void compress_binary_collection(std::string const &input, document_streams.emplace_back(document_batch); frequency_streams.emplace_back(frequency_batch); }); - progress bar("Compressing in parallel", collection.size()); + ProgressStatus status(collection.size(), + DefaultProgress("Compressing in 
parallel"), + std::chrono::milliseconds(500)); auto batch_size = num_terms / threads; for_each_batch([&](auto thread_idx) { group.run([thread_idx, @@ -121,7 +126,7 @@ inline void compress_binary_collection(std::string const &input, &frequency_streams, &document_offsets, &frequency_offsets, - &bar, + &status, &document_writer, &frequency_writer]() { auto first = std::next(collection.begin(), thread_idx * batch_size); @@ -140,7 +145,7 @@ inline void compress_binary_collection(std::string const &input, fout, document_writer, frequency_writer, - tl::make_optional(bar)); + tl::make_optional(status)); }); }); group.wait(); @@ -161,6 +166,8 @@ inline void compress_binary_collection(std::string const &input, PostingBuilder(RawWriter{}).write_header(document_out); PostingBuilder(RawWriter{}).write_header(frequency_out); + ProgressStatus merge_status( + threads, DefaultProgress("Merging files"), std::chrono::milliseconds(500)); for_each_batch([&](auto thread_idx) { std::transform(std::next(document_offsets[thread_idx].begin()), document_offsets[thread_idx].end(), @@ -175,6 +182,7 @@ inline void compress_binary_collection(std::string const &input, std::ifstream freqbatch(frequency_paths[thread_idx]); document_out << docbatch.rdbuf(); frequency_out << freqbatch.rdbuf(); + merge_status += 1; }); auto doc_offset_file = fmt::format("{}.document_offsets", output); @@ -197,4 +205,7 @@ inline void compress_binary_collection(std::string const &input, boost::property_tree::write_ini(fmt::format("{}.ini", output), pt); } +auto verify_compressed_index(std::string const &input, std::string_view output) + -> std::vector; + } // namespace pisa::v1 diff --git a/include/pisa/v1/index_metadata.hpp b/include/pisa/v1/index_metadata.hpp index 631457097..2fffedd50 100644 --- a/include/pisa/v1/index_metadata.hpp +++ b/include/pisa/v1/index_metadata.hpp @@ -3,9 +3,8 @@ #include #include -#include -#include #include +#include #include "v1/index.hpp" #include "v1/source.hpp" @@ -13,6 +12,29 @@ 
namespace pisa::v1 { +/// Return the passed file path if is not `nullopt`. +/// Otherwise, look for an `.ini` file in the current directory. +/// It will throw if no `.ini` file is found or there are multiple `.ini` files. +[[nodiscard]] auto resolve_ini(std::optional const &arg) -> std::string; + +template +[[nodiscard]] auto convert_optional(Optional opt) +{ + if (opt) { + return tl::make_optional(*opt); + } + return tl::optional>(); +} + +template +[[nodiscard]] auto to_std(tl::optional opt) -> std::optional +{ + if (opt) { + return std::make_optional(opt.take()); + } + return std::optional>(); +} + struct PostingFilePaths { std::string postings; std::string offsets; @@ -23,19 +45,11 @@ struct IndexMetadata { PostingFilePaths frequencies; std::string document_lengths_path; float avg_document_length; + tl::optional term_lexicon{}; + tl::optional document_lexicon{}; + tl::optional stemmer{}; - [[nodiscard]] static auto from_file(std::string const &file) - { - boost::property_tree::ptree pt; - boost::property_tree::ini_parser::read_ini(file, pt); - return IndexMetadata{ - .documents = PostingFilePaths{.postings = pt.get("documents.file"), - .offsets = pt.get("documents.offsets")}, - .frequencies = PostingFilePaths{.postings = pt.get("frequencies.file"), - .offsets = pt.get("frequencies.offsets")}, - .document_lengths_path = pt.get("stats.document_lengths"), - .avg_document_length = pt.get("stats.avg_document_length")}; - } + [[nodiscard]] static auto from_file(std::string const &file) -> IndexMetadata; }; template @@ -54,6 +68,12 @@ template template [[nodiscard]] inline auto index_runner(IndexMetadata metadata, Readers... 
readers) +{ + return index_runner(std::move(metadata), std::make_tuple(readers...)); +} + +template +[[nodiscard]] inline auto index_runner(IndexMetadata metadata, std::tuple readers) { MMapSource source; auto documents = source_span(source, metadata.documents.postings); @@ -68,7 +88,7 @@ template document_lengths, tl::make_optional(metadata.avg_document_length), std::move(source), - std::move(readers)...); + std::move(readers)); } } // namespace pisa::v1 diff --git a/include/pisa/v1/progress_status.hpp b/include/pisa/v1/progress_status.hpp new file mode 100644 index 000000000..6e662b0c8 --- /dev/null +++ b/include/pisa/v1/progress_status.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace pisa::v1 { + +std::ostream &format_interval(std::ostream &out, std::chrono::seconds time); + +using CallbackFunction = std::function)>; + +struct DefaultProgress { + DefaultProgress() = default; + explicit DefaultProgress(std::string caption); + DefaultProgress(DefaultProgress const &) = default; + DefaultProgress(DefaultProgress &&) noexcept = default; + DefaultProgress &operator=(DefaultProgress const &) = default; + DefaultProgress &operator=(DefaultProgress &&) noexcept = default; + ~DefaultProgress() = default; + + void operator()(std::size_t count, + std::size_t goal, + std::chrono::time_point start); + + private: + std::size_t m_previous = 0; + std::string m_caption; +}; + +struct ProgressStatus { + template + explicit ProgressStatus(std::size_t count, Callback &&callback, Duration interval) + : m_goal(count), m_callback(std::forward(callback)) + { + m_loop = std::thread([this, interval]() { + this->m_callback(this->m_count.load(), this->m_goal, this->m_start); + while (this->m_count.load() < this->m_goal) { + std::this_thread::sleep_for(interval); + this->m_callback(this->m_count.load(), this->m_goal, this->m_start); + } + }); + } + ProgressStatus(ProgressStatus const &) = delete; + ProgressStatus(ProgressStatus 
&&) = delete; + ProgressStatus &operator=(ProgressStatus const &) = delete; + ProgressStatus &operator=(ProgressStatus &&) = delete; + ~ProgressStatus(); + void operator+=(std::size_t inc) { m_count += inc; } + void operator++() { m_count += 1; } + void operator++(int) { m_count += 1; } + + private: + std::size_t const m_goal; + std::function)> + m_callback; + std::atomic_size_t m_count = 0; + std::chrono::time_point m_start = std::chrono::steady_clock::now(); + std::thread m_loop; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/query.hpp b/include/pisa/v1/query.hpp index 2f4964ef0..a21ad716b 100644 --- a/include/pisa/v1/query.hpp +++ b/include/pisa/v1/query.hpp @@ -4,7 +4,9 @@ #include #include "topk_queue.hpp" +#include "v1/cursor/for_each.hpp" #include "v1/cursor_intersection.hpp" +#include "v1/cursor_union.hpp" #include "v1/types.hpp" namespace pisa::v1 { @@ -39,4 +41,52 @@ struct ExhaustiveConjunctiveProcessor { } }; +template +auto daat_and(Query const &query, Index const &index, topk_queue topk, Scorer &&scorer) +{ + std::vector cursors; + std::transform(query.terms.begin(), + query.terms.end(), + std::back_inserter(cursors), + [&](auto term) { return index.scored_cursor(term, scorer); }); + auto intersection = + v1::intersect(std::move(cursors), 0.0F, [](auto &score, auto &cursor, auto /* term_idx */) { + score += cursor.payload(); + return score; + }); + v1::for_each(intersection, [&](auto &cursor) { topk.insert(cursor.payload(), *cursor); }); + return topk; +} + +template +auto daat_or(Query const &query, Index const &index, topk_queue topk, Scorer &&scorer) +{ + std::vector cursors; + std::transform(query.terms.begin(), + query.terms.end(), + std::back_inserter(cursors), + [&](auto term) { return index.scored_cursor(term, scorer); }); + auto cunion = v1::union_merge( + std::move(cursors), 0.0F, [](auto &score, auto &cursor, auto /* term_idx */) { + score += cursor.payload(); + return score; + }); + v1::for_each(cunion, [&](auto &cursor) { 
topk.insert(cursor.payload(), cursor.value()); }); + return topk; +} + +template +auto taat_or(Query const &query, Index const &index, topk_queue topk, Scorer &&scorer) +{ + std::vector accumulator(index.num_documents(), 0.0F); + for (auto term : query.terms) { + v1::for_each(index.scored_cursor(term, scorer), + [&accumulator](auto &&cursor) { accumulator[*cursor] += cursor.payload(); }); + } + for (auto document = 0; document < accumulator.size(); document += 1) { + topk.insert(accumulator[document], document); + } + return topk; +} + } // namespace pisa::v1 diff --git a/include/pisa/v1/scorer/runner.hpp b/include/pisa/v1/scorer/runner.hpp index d1034c053..e104e3764 100644 --- a/include/pisa/v1/scorer/runner.hpp +++ b/include/pisa/v1/scorer/runner.hpp @@ -40,4 +40,10 @@ struct ScorerRunner { std::tuple m_scorers; }; +template +auto scorer_runner(Index const &index, Scorers... scorers) +{ + return ScorerRunner(index, std::move(scorers)...); +} + +} // namespace pisa::v1 diff --git a/src/v1/index_builder.cpp b/src/v1/index_builder.cpp new file mode 100644 index 000000000..b6b0006fc --- /dev/null +++ b/src/v1/index_builder.cpp @@ -0,0 +1,48 @@ +#include "v1/index_builder.hpp" + +namespace pisa::v1 { + +auto verify_compressed_index(std::string const &input, std::string_view output) + -> std::vector +{ + std::vector errors; + pisa::binary_freq_collection const collection(input.c_str()); + auto meta = IndexMetadata::from_file(fmt::format("{}.ini", output)); + auto run = index_runner(meta, RawReader{}); + run([&](auto &&index) { + auto sequence_iter = collection.begin(); + for (auto term = 0; term < index.num_terms(); term += 1, ++sequence_iter) { + auto document_sequence = sequence_iter->docs; + auto frequency_sequence = sequence_iter->freqs; + auto cursor = index.cursor(term); + if (cursor.size() != document_sequence.size()) { + errors.push_back( + fmt::format("Posting list length mismatch for term {}: expected {} but is {}", + term, + document_sequence.size(), 
cursor.size())); + continue; + } + auto dit = document_sequence.begin(); + auto fit = frequency_sequence.begin(); + auto pos = 0; + while (not cursor.empty()) { + if (cursor.value() != *dit) { + errors.push_back( + fmt::format("Document mismatch for term {} at position {}", term, pos)); + } + if (cursor.payload() != *fit) { + errors.push_back( + fmt::format("Frequency mismatch for term {} at position {}", term, pos)); + } + cursor.advance(); + ++dit; + ++fit; + ++pos; + } + } + }); + return errors; +} + +} // namespace pisa::v1 diff --git a/src/v1/index_metadata.cpp b/src/v1/index_metadata.cpp new file mode 100644 index 000000000..666a5e5ac --- /dev/null +++ b/src/v1/index_metadata.cpp @@ -0,0 +1,35 @@ +#include +#include + +#include +#include + +#include "v1/index_metadata.hpp" + +namespace pisa::v1 { + +[[nodiscard]] auto resolve_ini(std::optional const &arg) -> std::string +{ + if (arg) { + return *arg; + } + throw std::runtime_error("Resolving .ini from the current folder not supported yet!"); +} + +[[nodiscard]] auto IndexMetadata::from_file(std::string const &file) -> IndexMetadata +{ + boost::property_tree::ptree pt; + boost::property_tree::ini_parser::read_ini(file, pt); + return IndexMetadata{ + .documents = PostingFilePaths{.postings = pt.get("documents.file"), + .offsets = pt.get("documents.offsets")}, + .frequencies = PostingFilePaths{.postings = pt.get("frequencies.file"), + .offsets = pt.get("frequencies.offsets")}, + .document_lengths_path = pt.get("stats.document_lengths"), + .avg_document_length = pt.get("stats.avg_document_length"), + .term_lexicon = convert_optional(pt.get_optional("lexicon.terms")), + .document_lexicon = convert_optional(pt.get_optional("lexicon.documents")), + .stemmer = convert_optional(pt.get_optional("lexicon.stemmer"))}; +} + +} // namespace pisa::v1 diff --git a/src/v1/progress_status.cpp b/src/v1/progress_status.cpp new file mode 100644 index 000000000..bc60eb582 --- /dev/null +++ b/src/v1/progress_status.cpp @@ -0,0 
+1,52 @@ +#include "v1/progress_status.hpp" + +namespace pisa::v1 { + +std::ostream &format_interval(std::ostream &out, std::chrono::seconds time) +{ + using std::chrono::hours; + using std::chrono::minutes; + using std::chrono::seconds; + hours h = std::chrono::duration_cast(time); + minutes m = std::chrono::duration_cast(time - h); + seconds s = std::chrono::duration_cast(time - h - m); + if (h.count() > 0) { + out << h.count() << "h "; + } + if (m.count() > 0) { + out << m.count() << "m "; + } + out << s.count() << "s"; + return out; +} + +DefaultProgress::DefaultProgress(std::string caption) : m_caption(std::move(caption)) +{ + if (not m_caption.empty()) { + m_caption.append(": "); + } +} + +void DefaultProgress::operator()(std::size_t count, + std::size_t goal, + std::chrono::time_point start) +{ + size_t progress = (100 * count) / goal; + if (progress == m_previous) { + return; + } + m_previous = progress; + std::chrono::seconds elapsed = + std::chrono::duration_cast(std::chrono::steady_clock::now() - start); + std::cerr << m_caption << progress << "% ["; + format_interval(std::cerr, elapsed); + std::cerr << "]\n"; +} + +ProgressStatus::~ProgressStatus() +{ + m_count = m_goal; + m_loop.join(); +} + +} // namespace pisa::v1 diff --git a/test/test_v1_queries.cpp b/test/test_v1_queries.cpp index 7e9d78cfb..7bdfc5520 100644 --- a/test/test_v1_queries.cpp +++ b/test/test_v1_queries.cpp @@ -21,6 +21,7 @@ #include "v1/index.hpp" #include "v1/posting_builder.hpp" #include "v1/posting_format_header.hpp" +#include "v1/query.hpp" #include "v1/scorer/bm25.hpp" #include "v1/types.hpp" @@ -92,40 +93,6 @@ template std::unique_ptr> IndexData::data = nullptr; -template -auto daat_and(Query const &query, Index const &index, topk_queue topk, Scorer &&scorer) -{ - std::vector cursors; - std::transform(query.terms.begin(), - query.terms.end(), - std::back_inserter(cursors), - [&](auto term) { return index.scored_cursor(term, scorer); }); - auto intersection = - 
v1::intersect(std::move(cursors), 0.0F, [](auto &score, auto &cursor, auto /* term_idx */) { - score += cursor.payload(); - return score; - }); - v1::for_each(intersection, [&](auto &cursor) { topk.insert(cursor.payload(), *cursor); }); - return topk; -} - -template -auto daat_or(Query const &query, Index const &index, topk_queue topk, Scorer &&scorer) -{ - std::vector cursors; - std::transform(query.terms.begin(), - query.terms.end(), - std::back_inserter(cursors), - [&](auto term) { return index.scored_cursor(term, scorer); }); - auto cunion = v1::union_merge( - std::move(cursors), 0.0F, [](auto &score, auto &cursor, auto /* term_idx */) { - score += cursor.payload(); - return score; - }); - v1::for_each(cunion, [&](auto &cursor) { topk.insert(cursor.payload(), *cursor); }); - return topk; -} - TEST_CASE("DAAT AND", "[v1][integration]") { auto data = IndexDatav1_index, topk_queue(10), make_bm25(data->v1_index)); + auto que = daat_and( + v1::Query{q.terms}, data->v1_index, topk_queue(10), make_bm25(data->v1_index)); que.finalize(); auto results = que.topk(); std::sort(results.begin(), results.end(), std::greater{}); @@ -151,7 +119,8 @@ TEST_CASE("DAAT AND", "[v1][integration]") }(); auto precomputed = [&]() { - auto que = daat_and(q, data->scored_index, topk_queue(10), v1::VoidScorer{}); + auto que = + daat_and(v1::Query{q.terms}, data->scored_index, topk_queue(10), v1::VoidScorer{}); que.finalize(); auto results = que.topk(); std::sort(results.begin(), results.end(), std::greater{}); @@ -185,7 +154,8 @@ TEST_CASE("DAAT OR", "[v1][integration]") std::sort(expected.begin(), expected.end(), std::greater{}); auto on_the_fly = [&]() { - auto que = daat_or(q, data->v1_index, topk_queue(10), make_bm25(data->v1_index)); + auto que = daat_or( + v1::Query{q.terms}, data->v1_index, topk_queue(10), make_bm25(data->v1_index)); que.finalize(); auto results = que.topk(); std::sort(results.begin(), results.end(), std::greater{}); @@ -193,7 +163,8 @@ TEST_CASE("DAAT OR", 
"[v1][integration]") }(); auto precomputed = [&]() { - auto que = daat_or(q, data->scored_index, topk_queue(10), v1::VoidScorer{}); + auto que = + daat_or(v1::Query{q.terms}, data->scored_index, topk_queue(10), v1::VoidScorer{}); que.finalize(); auto results = que.topk(); std::sort(results.begin(), results.end(), std::greater{}); diff --git a/v1/CMakeLists.txt b/v1/CMakeLists.txt index cadf5b7d1..d3bda2b3d 100644 --- a/v1/CMakeLists.txt +++ b/v1/CMakeLists.txt @@ -1,120 +1,8 @@ add_executable(compress compress.cpp) -target_link_libraries(compress - pisa - CLI11 -) +target_link_libraries(compress pisa CLI11) -#add_executable(create_freq_index create_freq_index.cpp) -#target_link_libraries(create_freq_index -# pisa -# CLI11 -#) -# -#add_executable(optimal_hybrid_index optimal_hybrid_index.cpp) -#target_include_directories(optimal_hybrid_index PRIVATE ${STXXL_INCLUDE_DIRS}) -#target_link_libraries(optimal_hybrid_index -# ${STXXL_LIBRARIES} -# pisa -#) -#set_target_properties(optimal_hybrid_index PROPERTIES -# CXX_STANDARD 14 -#) -# -#add_executable(create_wand_data create_wand_data.cpp) -#target_link_libraries(create_wand_data -# pisa -# CLI11 -#) -# -#add_executable(queries queries.cpp) -#target_link_libraries(queries -# pisa -# CLI11 -#) -# -#add_executable(evaluate_queries evaluate_queries.cpp) -#target_link_libraries(evaluate_queries -# pisa -# CLI11 -#) -# -#add_executable(thresholds thresholds.cpp) -#target_link_libraries(thresholds -# pisa -# CLI11 -#) -# -#add_executable(profile_queries profile_queries.cpp) -#target_link_libraries(profile_queries -# pisa -#) -# -#add_executable(profile_decoding profile_decoding.cpp) -#target_link_libraries(profile_decoding -# pisa -#) -# -#add_executable(shuffle_docids shuffle_docids.cpp) -#target_link_libraries(shuffle_docids -# pisa -#) -# -#add_executable(recursive_graph_bisection recursive_graph_bisection.cpp) -#target_link_libraries(recursive_graph_bisection -# pisa -# CLI11 -#) -# 
-#add_executable(evaluate_collection_ordering evaluate_collection_ordering.cpp) -#target_link_libraries(evaluate_collection_ordering -# pisa -# ) -# -#add_executable(parse_collection parse_collection.cpp) -#target_link_libraries(parse_collection -# pisa -# CLI11 -# wapopp -#) -# -#add_executable(invert invert.cpp) -#target_link_libraries(invert -# CLI11 -# pisa -#) -# -#add_executable(read_collection read_collection.cpp) -#target_link_libraries(read_collection -# pisa -# CLI11 -#) -# -#add_executable(partition_fwd_index partition_fwd_index.cpp) -#target_link_libraries(partition_fwd_index -# pisa -# CLI11 -#) -# -#add_executable(compute_intersection compute_intersection.cpp) -#target_link_libraries(compute_intersection -# pisa -# CLI11 -#) -# -#add_executable(lexicon lexicon.cpp) -#target_link_libraries(lexicon -# pisa -# CLI11 -#) -# -#add_executable(extract_topics extract_topics.cpp) -#target_link_libraries(extract_topics -# pisa -# CLI11 -#) -# -#add_executable(sample_inverted_index sample_inverted_index.cpp) -#target_link_libraries(sample_inverted_index -# pisa -# CLI11 -#) +add_executable(query query.cpp) +target_link_libraries(query pisa CLI11) + +add_executable(postings postings.cpp) +target_link_libraries(postings pisa CLI11) diff --git a/v1/compress.cpp b/v1/compress.cpp index 3fcc4aa3a..db0168af5 100644 --- a/v1/compress.cpp +++ b/v1/compress.cpp @@ -2,6 +2,7 @@ #include "binary_freq_collection.hpp" #include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" #include "v1/raw_cursor.hpp" #include "v1/types.hpp" @@ -9,6 +10,7 @@ using pisa::v1::compress_binary_collection; using pisa::v1::EncodingId; using pisa::v1::IndexBuilder; using pisa::v1::RawWriter; +using pisa::v1::verify_compressed_index; int main(int argc, char **argv) { @@ -32,6 +34,18 @@ int main(int argc, char **argv) make_writer(std::move(writer)), make_writer(std::move(frequency_writer))); }); + auto errors = verify_compressed_index(input, output); + if (not errors.empty()) { + if 
(errors.size() > 10) { + std::cerr << "Detected more than 10 errors, printing head:\n"; + errors.resize(10); + } + for (auto const &error : errors) { + std::cerr << error << '\n'; + } + return 1; + } + std::cout << "Success."; return 0; } diff --git a/v1/postings.cpp b/v1/postings.cpp new file mode 100644 index 000000000..ca3ec3cde --- /dev/null +++ b/v1/postings.cpp @@ -0,0 +1,146 @@ +#include +#include +#include +#include + +#include +#include + +#include "io.hpp" +#include "query/queries.hpp" +#include "topk_queue.hpp" +#include "v1/index_metadata.hpp" +#include "v1/query.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/scorer/bm25.hpp" +#include "v1/scorer/runner.hpp" +#include "v1/types.hpp" + +using pisa::Query; +using pisa::resolve_query_parser; +using pisa::v1::index_runner; +using pisa::v1::IndexMetadata; +using pisa::v1::RawReader; +using pisa::v1::resolve_ini; + +auto default_readers() { return std::make_tuple(RawReader{}, RawReader{}); } + +[[nodiscard]] auto load_source(std::optional const &file) + -> std::shared_ptr +{ + if (file) { + return std::make_shared(file->c_str()); + } + return nullptr; +} + +[[nodiscard]] auto load_payload_vector(std::shared_ptr const &source) + -> std::optional> +{ + if (source) { + return pisa::Payload_Vector<>::from(*source); + } + return std::nullopt; +} + +/// Returns the first value (not nullopt), or nullopt if no optional contains a value. +template +[[nodiscard]] auto value(First &&first, Optional &&... 
candidtes) +{ + std::optional> val = std::nullopt; + auto has_value = [&](auto &&opt) -> bool { + if (not val.has_value() && opt) { + val = *opt; + return true; + } + return false; + }; + has_value(first) || (has_value(candidtes) || ...); + return val; +} + +int main(int argc, char **argv) +{ + std::optional ini{}; + std::optional terms_file{}; + std::optional documents_file{}; + std::string query_input{}; + bool tid = false; + bool did = false; + bool print_frequencies = false; + bool print_scores = false; + + CLI::App app{"Queries a v1 index."}; + app.add_option("-i,--index", + ini, + "Path of .ini file of an index " + "(if not provided, it will be looked for in the current directory)", + false); + app.add_option("--terms", terms_file, "Overrides document lexicon from .ini (if defined)."); + app.add_option("--documents", + documents_file, + "Overrides document lexicon from .ini (if defined). Required otherwise."); + app.add_flag("--tid", tid, "Use term IDs instead of terms"); + app.add_flag("--did", did, "Print document IDs instead of titles"); + app.add_flag("-f,--frequencies", print_frequencies, "Print frequencies"); + app.add_flag("-s,--scores", print_scores, "Print BM25 scores"); + app.add_option("query", query_input, "List of terms", false)->required(); + CLI11_PARSE(app, argc, argv); + + auto meta = IndexMetadata::from_file(resolve_ini(ini)); + auto stemmer = meta.stemmer ? 
std::make_optional(*meta.stemmer) : std::optional{}; + if (tid) { + terms_file = std::nullopt; + } else { + terms_file = value(meta.term_lexicon, terms_file); + } + documents_file = value(meta.document_lexicon, documents_file); + + if (not did and not documents_file) { + spdlog::error("Document lexicon not defined"); + std::exit(1); + } + + if (not tid and not terms_file) { + spdlog::error("Term lexicon not defined"); + std::exit(1); + } + + std::shared_ptr const source = load_source(documents_file); + std::optional> const docmap = load_payload_vector(source); + + auto const query = [&]() { + std::vector queries; + auto parse_query = resolve_query_parser(queries, terms_file, std::nullopt, stemmer); + parse_query(query_input); + return queries[0]; + }(); + + if (query.terms.size() == 1) { + auto run = index_runner(meta, default_readers()); + run([&](auto &&index) { + auto bm25 = make_bm25(index); + auto scorer = bm25.term_scorer(query.terms.front()); + auto print = [&](auto &&cursor) { + if (did) { + std::cout << *cursor; + } else { + std::cout << docmap.value()[*cursor]; + } + if (print_frequencies) { + std::cout << " " << cursor.payload(); + } + if (print_scores) { + std::cout << " " << scorer(cursor.value(), cursor.payload()); + } + std::cout << '\n'; + }; + for_each(index.cursor(query.terms.front()), print); + }); + } else { + std::cerr << "Multiple terms unimplemented"; + std::exit(1); + } + + return 0; +} diff --git a/v1/query.cpp b/v1/query.cpp new file mode 100644 index 000000000..00779da1e --- /dev/null +++ b/v1/query.cpp @@ -0,0 +1,100 @@ +#include +#include +#include + +#include +#include + +#include "io.hpp" +#include "query/queries.hpp" +#include "topk_queue.hpp" +#include "v1/index_metadata.hpp" +#include "v1/query.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/scorer/bm25.hpp" +#include "v1/scorer/runner.hpp" +#include "v1/types.hpp" + +using pisa::Query; +using pisa::resolve_query_parser; +using pisa::v1::daat_or; +using pisa::v1::index_runner; 
+using pisa::v1::IndexMetadata; +using pisa::v1::RawReader; +using pisa::v1::resolve_ini; +using pisa::v1::taat_or; + +int main(int argc, char **argv) +{ + std::optional ini{}; + std::optional query_file{}; + std::optional terms_file{}; + std::optional documents_file{}; + int k = 1'000; + + CLI::App app{"Queries a v1 index."}; + app.add_option("-i,--index", + ini, + "Path of .ini file of an index " + "(if not provided, it will be looked for in the current directory)", + false); + app.add_option("-q,--query", query_file, "Path to file with queries", false); + app.add_option("-k", k, "The number of top results to return", true); + app.add_option("--terms", terms_file, "Overrides document lexicon from .ini (if defined)."); + app.add_option("--documents", + documents_file, + "Overrides document lexicon from .ini (if defined). Required otherwise."); + CLI11_PARSE(app, argc, argv); + + auto meta = IndexMetadata::from_file(resolve_ini(ini)); + auto stemmer = meta.stemmer ? std::make_optional(*meta.stemmer) : std::optional{}; + if (meta.term_lexicon) { + terms_file = meta.term_lexicon.value(); + } + if (meta.document_lexicon) { + documents_file = meta.document_lexicon.value(); + } + + std::vector queries; + auto parse_query = resolve_query_parser(queries, terms_file, std::nullopt, stemmer); + if (query_file) { + std::ifstream is(*query_file); + pisa::io::for_each_line(is, parse_query); + } else { + pisa::io::for_each_line(std::cin, parse_query); + } + + if (not documents_file) { + spdlog::error("Document lexicon not defined"); + std::exit(1); + } + auto source = std::make_shared(documents_file.value().c_str()); + auto docmap = pisa::Payload_Vector<>::from(*source); + + auto run = index_runner(meta, RawReader{}); + run([&](auto &&index) { + auto with_scorer = scorer_runner(index, make_bm25(index)); + with_scorer("bm25", [&](auto scorer) { + auto query_idx = 0; + for (auto const &query : queries) { + auto que = + taat_or(pisa::v1::Query{query.terms}, index, 
pisa::topk_queue(k), scorer); + que.finalize(); + auto rank = 0; + for (auto result : que.topk()) { + std::cout << fmt::format("{}\t{}\t{}\t{}\t{}\t{}\n", + query.id.value_or(std::to_string(query_idx)), + "Q0", + docmap[result.second], + rank, + result.first, + "R0"); + rank += 1; + } + query_idx += 1; + } + }); + }); + + return 0; +} From 27a62ba2e4aadba5622b82cb55016c5d919d0f62 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 11 Nov 2019 21:01:04 -0500 Subject: [PATCH 16/56] Blocked cursor + SIMDBP --- CMakeLists.txt | 4 +- include/pisa/codec/simdbp.hpp | 22 +- include/pisa/v1/bit_cast.hpp | 4 +- include/pisa/v1/blocked_cursor.hpp | 360 ++++++++++++++++++++ include/pisa/v1/cursor/collect.hpp | 13 + include/pisa/v1/cursor/scoring_cursor.hpp | 2 +- include/pisa/v1/cursor_traits.hpp | 10 + include/pisa/v1/document_payload_cursor.hpp | 23 +- include/pisa/v1/encoding_traits.hpp | 17 + include/pisa/v1/index.hpp | 46 +-- include/pisa/v1/index_builder.hpp | 96 +++--- include/pisa/v1/index_metadata.hpp | 28 ++ include/pisa/v1/io.hpp | 11 + include/pisa/v1/posting_format_header.hpp | 2 +- include/pisa/v1/raw_cursor.hpp | 14 +- include/pisa/v1/types.hpp | 24 +- include/pisa/v1/unaligned_span.hpp | 158 +++++++++ src/v1/index_builder.cpp | 11 +- src/v1/index_metadata.cpp | 30 ++ src/v1/io.cpp | 19 ++ src/v1/progress_status.cpp | 7 +- test/temporary_directory.hpp | 5 +- test/test_v1.cpp | 67 +++- test/test_v1_blocked_cursor.cpp | 195 +++++++++++ test/test_v1_document_payload_cursor.cpp | 98 ++++++ test/test_v1_index.cpp | 52 ++- test/test_v1_queries.cpp | 157 +++++---- v1/compress.cpp | 62 +++- v1/query.cpp | 84 ++++- 29 files changed, 1406 insertions(+), 215 deletions(-) create mode 100644 include/pisa/v1/blocked_cursor.hpp create mode 100644 include/pisa/v1/cursor_traits.hpp create mode 100644 include/pisa/v1/encoding_traits.hpp create mode 100644 include/pisa/v1/io.hpp create mode 100644 include/pisa/v1/unaligned_span.hpp create mode 100644 src/v1/io.cpp create 
mode 100644 test/test_v1_blocked_cursor.cpp create mode 100644 test/test_v1_document_payload_cursor.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 06d981e66..72807a1e1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -100,10 +100,10 @@ target_link_libraries(pisa PUBLIC range-v3 optional ) -target_include_directories(pisa INTERFACE external) +target_include_directories(pisa PUBLIC external) add_subdirectory(v1) -add_subdirectory(src) +#add_subdirectory(src) if (PISA_ENABLE_TESTING AND BUILD_TESTING) enable_testing() diff --git a/include/pisa/codec/simdbp.hpp b/include/pisa/codec/simdbp.hpp index 23a694657..fa446eebe 100644 --- a/include/pisa/codec/simdbp.hpp +++ b/include/pisa/codec/simdbp.hpp @@ -1,8 +1,8 @@ #pragma once -#include -#include "util/util.hpp" #include "codec/block_codecs.hpp" +#include "util/util.hpp" +#include extern "C" { #include "simdcomp/include/simdbitpacking.h" @@ -14,7 +14,8 @@ struct simdbp_block { static void encode(uint32_t const *in, uint32_t sum_of_values, size_t n, - std::vector &out) { + std::vector &out) + { assert(n <= block_size); uint32_t *src = const_cast(in); @@ -23,23 +24,22 @@ struct simdbp_block { return; } uint32_t b = maxbits(in); - thread_local std::vector buf(8*n); - uint8_t * buf_ptr = buf.data(); + thread_local std::vector buf(8 * n); + uint8_t *buf_ptr = buf.data(); *buf_ptr++ = b; simdpackwithoutmask(src, (__m128i *)buf_ptr, b); out.insert(out.end(), buf.data(), buf.data() + b * sizeof(__m128i) + 1); } - static uint8_t const *decode(uint8_t const *in, - uint32_t *out, - uint32_t sum_of_values, - size_t n) { + static uint8_t const *decode(uint8_t const *in, uint32_t *out, uint32_t sum_of_values, size_t n) + { assert(n <= block_size); if (PISA_UNLIKELY(n < block_size)) { return interpolative_block::decode(in, out, sum_of_values, n); } uint32_t b = *in++; simdunpack((const __m128i *)in, out, b); - return in + b * sizeof(__m128i); + return in + b * sizeof(__m128i); } }; -} // namespace pisa \ No newline at end of 
file + +} // namespace pisa diff --git a/include/pisa/v1/bit_cast.hpp b/include/pisa/v1/bit_cast.hpp index dd7e87523..dd4fa6413 100644 --- a/include/pisa/v1/bit_cast.hpp +++ b/include/pisa/v1/bit_cast.hpp @@ -7,9 +7,9 @@ namespace pisa::v1 { template -constexpr auto bit_cast(gsl::span mem) -> T +constexpr auto bit_cast(gsl::span mem) -> std::remove_const_t { - T dst{}; + std::remove_const_t dst{}; std::memcpy(&dst, mem.data(), sizeof(T)); return dst; } diff --git a/include/pisa/v1/blocked_cursor.hpp b/include/pisa/v1/blocked_cursor.hpp new file mode 100644 index 000000000..165253f2b --- /dev/null +++ b/include/pisa/v1/blocked_cursor.hpp @@ -0,0 +1,360 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "codec/block_codecs.hpp" +#include "util/likely.hpp" +#include "v1/bit_cast.hpp" +#include "v1/cursor_traits.hpp" +#include "v1/encoding_traits.hpp" +#include "v1/types.hpp" +#include "v1/unaligned_span.hpp" + +namespace pisa::v1 { + +/// Uncompressed example of implementation of a single value cursor. +template +struct BlockedCursor { + using value_type = std::uint32_t; + + /// Creates a cursor from the encoded bytes. + explicit constexpr BlockedCursor(gsl::span encoded_blocks, + UnalignedSpan block_endpoints, + UnalignedSpan block_last_values, + std::uint32_t length, + std::uint32_t num_blocks) + : m_encoded_blocks(encoded_blocks), + m_block_endpoints(block_endpoints), + m_block_last_values(block_last_values), + m_length(length), + m_num_blocks(num_blocks), + m_current_block( + {.number = 0, + .offset = 0, + .length = std::min(length, static_cast(Codec::block_size)), + .last_value = m_block_last_values[0]}) + { + static_assert(DeltaEncoded, + "Cannot initialize block_last_values for not delta-encoded list"); + m_decoded_block.resize(Codec::block_size); + reset(); + } + + /// Creates a cursor from the encoded bytes. 
+ explicit constexpr BlockedCursor(gsl::span encoded_blocks, + UnalignedSpan block_endpoints, + std::uint32_t length, + std::uint32_t num_blocks) + : m_encoded_blocks(encoded_blocks), + m_block_endpoints(block_endpoints), + m_length(length), + m_num_blocks(num_blocks), + m_current_block( + {.number = 0, + .offset = 0, + .length = std::min(length, static_cast(Codec::block_size)), + .last_value = 0}) + { + static_assert(not DeltaEncoded, "Must initialize block_last_values for delta-encoded list"); + m_decoded_block.resize(Codec::block_size); + reset(); + } + + void reset() { decode_and_update_block(0); } + + /// Dereferences the current value. + [[nodiscard]] constexpr auto operator*() const -> value_type { return m_current_value; } + + /// Alias for `operator*()`. + [[nodiscard]] constexpr auto value() const noexcept -> value_type { return *(*this); } + + /// Advances the cursor to the next position. + constexpr void advance() + { + m_current_block.offset += 1; + if (PISA_UNLIKELY(m_current_block.offset == m_current_block.length)) { + if (m_current_block.number + 1 == m_num_blocks) { + m_current_value = sentinel(); + return; + } + decode_and_update_block(m_current_block.number + 1); + } else { + if constexpr (DeltaEncoded) { + m_current_value += m_decoded_block[m_current_block.offset] + 1U; + } else { + m_current_value = m_decoded_block[m_current_block.offset] + 1U; + } + } + } + + /// Moves the cursor to the position `pos`. + constexpr void advance_to_position(std::uint32_t pos) + { + Expects(pos >= position()); + auto block = pos / Codec::block_size; + if (PISA_UNLIKELY(block != m_current_block.number)) { + decode_and_update_block(block); + } + while (position() < pos) { + if constexpr (DeltaEncoded) { + m_current_value += m_decoded_block[++m_current_block.offset] + 1U; + } else { + m_current_value = m_decoded_block[++m_current_block.offset] + 1U; + } + } + } + + /// Moves the cursor to the next value equal or greater than `value`. 
+ constexpr void advance_to_geq(value_type value) + { + static_assert(DeltaEncoded, "Cannot call advance_to_geq on a not delta-encoded list"); + Expects(value >= m_current_value || position() == 0); + if (PISA_UNLIKELY(value > m_current_block.last_value)) { + if (value > m_block_last_values.back()) { + m_current_value = sentinel(); + return; + } + auto block = m_current_block.number + 1U; + while (m_block_last_values[block] < value) { + ++block; + } + decode_and_update_block(block); + } + + while (m_current_value < value) { + if constexpr (DeltaEncoded) { + m_current_value += m_decoded_block[m_current_block.offset] + 1U; + } else { + m_current_value = m_decoded_block[m_current_block.offset] + 1U; + } + Ensures(m_current_block.offset < m_current_block.length); + } + } + + ///// Returns `true` if there is no elements left. + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return position() == m_length; } + + /// Returns the current position. + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return m_current_block.number * Codec::block_size + m_current_block.offset; + } + + ///// Returns the number of elements in the list. + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_length; } + + /// The sentinel value, such that `value() != nullopt` is equivalent to `*(*this) < sentinel()`. + [[nodiscard]] constexpr auto sentinel() const -> value_type + { + return std::numeric_limits::max(); + } + + private: + struct Block { + std::uint32_t number = 0; + std::uint32_t offset = 0; + std::uint32_t length = 0; + value_type last_value = 0; + }; + + void decode_and_update_block(std::uint32_t block) + { + constexpr auto block_size = Codec::block_size; + auto endpoint = block > 0U ? m_block_endpoints[block - 1] : static_cast(0U); + std::uint8_t const *block_data = + std::next(reinterpret_cast(m_encoded_blocks.data()), endpoint); + m_current_block.length = + ((block + 1) * block_size <= size()) ? 
block_size : (size() % block_size); + + if constexpr (DeltaEncoded) { + std::uint32_t first_value = block > 0U ? m_block_last_values[block - 1] + 1U : 0U; + m_current_block.last_value = m_block_last_values[block]; + Codec::decode(block_data, + m_decoded_block.data(), + m_current_block.last_value - first_value - (m_current_block.length - 1), + m_current_block.length); + m_decoded_block[0] += first_value; + } else { + Codec::decode(block_data, + m_decoded_block.data(), + std::numeric_limits::max(), + m_current_block.length); + m_decoded_block[0] += 1; + } + + m_current_block.number = block; + m_current_block.offset = 0U; + m_current_value = m_decoded_block[0]; + } + + gsl::span m_encoded_blocks; + UnalignedSpan m_block_endpoints; + UnalignedSpan m_block_last_values{}; + std::vector m_decoded_block; + + std::uint32_t m_length; + std::uint32_t m_num_blocks; + Block m_current_block{}; + value_type m_current_value{}; +}; + +template +constexpr auto block_encoding_type() -> std::uint32_t +{ + if constexpr (DeltaEncoded) { + return EncodingId::BlockDelta; + } else { + return EncodingId::Block; + } +} + +template +struct BlockedReader { + using value_type = std::uint32_t; + + [[nodiscard]] auto read(gsl::span bytes) const + -> BlockedCursor + { + std::uint32_t length; + auto begin = reinterpret_cast(bytes.data()); + auto after_length_ptr = pisa::TightVariableByte::decode(begin, &length, 1); + auto length_byte_size = std::distance(begin, after_length_ptr); + auto num_blocks = ceil_div(length, Codec::block_size); + UnalignedSpan block_last_values; + if constexpr (DeltaEncoded) { + block_last_values = UnalignedSpan( + bytes.subspan(length_byte_size, num_blocks * sizeof(value_type))); + } + UnalignedSpan block_endpoints( + bytes.subspan(length_byte_size + block_last_values.byte_size(), + (num_blocks - 1) * sizeof(value_type))); + auto encoded_blocks = bytes.subspan(length_byte_size + block_last_values.byte_size() + + block_endpoints.byte_size()); + if constexpr (DeltaEncoded) { 
+ return BlockedCursor( + encoded_blocks, block_endpoints, block_last_values, length, num_blocks); + } else { + return BlockedCursor( + encoded_blocks, block_endpoints, length, num_blocks); + } + } + + constexpr static auto encoding() -> std::uint32_t + { + return block_encoding_type() + | encoding_traits::encoding_tag::encoding(); + } +}; + +template +struct BlockedWriter { + using value_type = std::uint32_t; + + constexpr static auto encoding() -> std::uint32_t + { + return block_encoding_type() + | encoding_traits::encoding_tag::encoding(); + } + + void push(value_type const &posting) + { + if constexpr (DeltaEncoded) { + if (posting < m_last_value) { + throw std::runtime_error( + fmt::format("Delta-encoded sequences must be monotonic, but {} < {}", + posting, + m_last_value)); + } + } + m_postings.push_back(posting); + m_last_value = posting; + } + + template + [[nodiscard]] auto write(std::basic_ostream &os) const -> std::size_t + { + std::vector buffer; + std::uint32_t length = m_postings.size(); + TightVariableByte::encode_single(length, buffer); + auto block_size = Codec::block_size; + auto num_blocks = ceil_div(length, block_size); + auto begin_block_maxs = buffer.size(); + auto begin_block_endpoints = [&]() { + if constexpr (DeltaEncoded) { + return begin_block_maxs + 4U * num_blocks; + } else { + return begin_block_maxs; + } + }(); + auto begin_blocks = begin_block_endpoints + 4U * (num_blocks - 1); + buffer.resize(begin_blocks); + + auto iter = m_postings.begin(); + std::vector block_buffer(block_size); + std::uint32_t last_value(-1); + std::uint32_t block_base = 0; + for (auto block = 0; block < num_blocks; ++block) { + auto current_block_size = + ((block + 1) * block_size <= length) ? 
block_size : (length % block_size); + + std::for_each(block_buffer.begin(), + std::next(block_buffer.begin(), current_block_size), + [&](auto &&elem) { + if constexpr (DeltaEncoded) { + auto value = *iter++; + elem = value - (last_value + 1); + last_value = value; + } else { + elem = *iter++ - 1; + } + }); + + if constexpr (DeltaEncoded) { + std::memcpy( + &buffer[begin_block_maxs + 4U * block], &last_value, sizeof(last_value)); + auto size = buffer.size(); + Codec::encode(block_buffer.data(), + last_value - block_base - (current_block_size - 1), + current_block_size, + buffer); + } else { + Codec::encode(block_buffer.data(), std::uint32_t(-1), current_block_size, buffer); + } + if (block != num_blocks - 1) { + std::size_t endpoint = buffer.size() - begin_blocks; + std::memcpy( + &buffer[begin_block_endpoints + 4U * block], &endpoint, sizeof(last_value)); + } + block_base = last_value + 1; + } + os.write(reinterpret_cast(buffer.data()), buffer.size()); + return buffer.size(); + } + + void reset() + { + m_postings.clear(); + m_last_value = 0; + } + + private: + std::vector m_postings{}; + value_type m_last_value = 0U; +}; + +template +struct CursorTraits> { + using Writer = BlockedWriter; + using Reader = BlockedReader; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor/collect.hpp b/include/pisa/v1/cursor/collect.hpp index 099f6c7cb..f2d68f66b 100644 --- a/include/pisa/v1/cursor/collect.hpp +++ b/include/pisa/v1/cursor/collect.hpp @@ -21,4 +21,17 @@ auto collect(Cursor &&cursor) return collect(std::forward(cursor), [](auto &&cursor) { return *cursor; }); } +template +auto collect_with_payload(Cursor &&cursor) +{ + return collect(std::forward(cursor), + [](auto &&cursor) { return std::make_pair(*cursor, cursor.payload()); }); +} + +template +auto collect_payloads(Cursor &&cursor) +{ + return collect(std::forward(cursor), [](auto &&cursor) { return cursor.payload(); }); +} + } // namespace pisa::v1 diff --git 
a/include/pisa/v1/cursor/scoring_cursor.hpp b/include/pisa/v1/cursor/scoring_cursor.hpp index f9e3d9adf..779b4b135 100644 --- a/include/pisa/v1/cursor/scoring_cursor.hpp +++ b/include/pisa/v1/cursor/scoring_cursor.hpp @@ -24,7 +24,7 @@ struct ScoringCursor { { return m_base_cursor.value(); } - [[nodiscard]] constexpr auto payload() const noexcept -> Payload + [[nodiscard]] constexpr auto payload() noexcept -> Payload { return m_scorer(m_base_cursor.value(), m_base_cursor.payload()); } diff --git a/include/pisa/v1/cursor_traits.hpp b/include/pisa/v1/cursor_traits.hpp new file mode 100644 index 000000000..3e32c69f2 --- /dev/null +++ b/include/pisa/v1/cursor_traits.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "v1/types.hpp" + +namespace pisa::v1 { + +template +struct CursorTraits; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/document_payload_cursor.hpp b/include/pisa/v1/document_payload_cursor.hpp index 0a04aa45b..8c075497b 100644 --- a/include/pisa/v1/document_payload_cursor.hpp +++ b/include/pisa/v1/document_payload_cursor.hpp @@ -19,25 +19,16 @@ struct DocumentPayloadCursor { [[nodiscard]] constexpr auto operator*() const -> Document { return value(); } [[nodiscard]] constexpr auto value() const noexcept -> Document { return m_key_cursor.value(); } - [[nodiscard]] constexpr auto payload() const noexcept -> Payload + [[nodiscard]] constexpr auto payload() noexcept -> Payload { + if (auto pos = m_key_cursor.position(); pos != m_payload_cursor.position()) { + m_payload_cursor.advance_to_position(m_key_cursor.position()); + } return m_payload_cursor.value(); } - constexpr void advance() - { - m_key_cursor.advance(); - m_payload_cursor.advance(); - } - constexpr void advance_to_position(std::size_t pos) - { - m_key_cursor.advance_to_position(pos); - m_payload_cursor.advance_to_position(pos); - } - constexpr void advance_to_geq(Document value) - { - m_key_cursor.advance_to_geq(value); - m_payload_cursor.advance_to_position(m_key_cursor.position()); - } + 
constexpr void advance() { m_key_cursor.advance(); } + constexpr void advance_to_position(std::size_t pos) { m_key_cursor.advance_to_position(pos); } + constexpr void advance_to_geq(Document value) { m_key_cursor.advance_to_geq(value); } [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_key_cursor.empty(); } [[nodiscard]] constexpr auto position() const noexcept -> std::size_t { diff --git a/include/pisa/v1/encoding_traits.hpp b/include/pisa/v1/encoding_traits.hpp new file mode 100644 index 000000000..6ec372118 --- /dev/null +++ b/include/pisa/v1/encoding_traits.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include "codec/simdbp.hpp" +#include "v1/types.hpp" + +namespace pisa::v1 { + +struct SimdBPTag { + [[nodiscard]] static auto encoding() -> std::uint32_t { return EncodingId::SimdBP; } +}; + +template <> +struct encoding_traits<::pisa::simdbp_block> { + using encoding_tag = SimdBPTag; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index 568bafc4e..562e5a3f9 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -30,8 +30,7 @@ namespace pisa::v1 { -[[nodiscard]] inline auto calc_avg_length(gsl::span const &lengths) - -> std::uint32_t +[[nodiscard]] inline auto calc_avg_length(gsl::span const &lengths) -> float { auto sum = std::accumulate(lengths.begin(), lengths.end(), std::uint64_t(0), std::plus{}); return static_cast(sum) / lengths.size(); @@ -68,7 +67,7 @@ struct Index { gsl::span documents, gsl::span payloads, gsl::span document_lengths, - tl::optional avg_document_length, + tl::optional avg_document_length, Source source) : m_document_reader(std::move(document_reader)), m_payload_reader(std::move(payload_reader)), @@ -166,7 +165,7 @@ struct Index { gsl::span m_documents; gsl::span m_payloads; gsl::span m_document_lengths; - std::uint32_t m_avg_document_length; + float m_avg_document_length; std::any m_source; }; @@ -178,7 +177,7 @@ auto make_index(DocumentReader 
document_reader, gsl::span documents, gsl::span payloads, gsl::span document_lengths, - tl::optional avg_document_length, + tl::optional avg_document_length, Source source) { using DocumentCursor = @@ -195,8 +194,8 @@ auto make_index(DocumentReader document_reader, std::move(source)); } -template -auto score_index(Index const &index, ByteOStream &os, Writer writer, Scorer scorer) +template +auto score_index(Index const &index, std::basic_ostream &os, Writer writer, Scorer scorer) -> std::vector { PostingBuilder score_builder(writer); @@ -420,7 +419,7 @@ struct IndexRunner { gsl::span documents, gsl::span payloads, gsl::span document_lengths, - tl::optional avg_document_length, + tl::optional avg_document_length, Source source, Readers... readers) : m_document_offsets(document_offsets), @@ -439,7 +438,7 @@ struct IndexRunner { gsl::span documents, gsl::span payloads, gsl::span document_lengths, - tl::optional avg_document_length, + tl::optional avg_document_length, Source source, std::tuple readers) : m_document_offsets(document_offsets), @@ -454,29 +453,30 @@ struct IndexRunner { } template - void operator()(Fn fn) + auto operator()(Fn fn) { auto dheader = PostingFormatHeader::parse(m_documents.first(8)); auto pheader = PostingFormatHeader::parse(m_payloads.first(8)); - auto run = [&](auto &&dreader, auto &&preader) -> bool { + auto run = [&](auto &&dreader, auto &&preader) { if (std::decay_t::encoding() == dheader.encoding && std::decay_t::encoding() == pheader.encoding && is_type::value_type>(dheader.type) && is_type::value_type>(pheader.type)) { - fn(make_index(std::forward(dreader), - std::forward(preader), - m_document_offsets, - m_payload_offsets, - m_documents.subspan(8), - m_payloads.subspan(8), - m_document_lengths, - m_avg_document_length, - false)); + auto index = make_index(std::forward(dreader), + std::forward(preader), + m_document_offsets, + m_payload_offsets, + m_documents.subspan(8), + m_payloads.subspan(8), + m_document_lengths, + 
m_avg_document_length, + false); + fn(index); return true; } return false; }; - bool success = std::apply( + auto result = std::apply( [&](Readers... dreaders) { auto with_document_reader = [&](auto dreader) { return std::apply( @@ -486,7 +486,7 @@ struct IndexRunner { return (with_document_reader(dreaders) || ...); }, m_readers); - if (not success) { + if (not result) { throw std::domain_error("Unknown posting encoding"); } } @@ -497,7 +497,7 @@ struct IndexRunner { gsl::span m_documents; gsl::span m_payloads; gsl::span m_document_lengths; - tl::optional m_avg_document_length; + tl::optional m_avg_document_length; std::any m_source; std::tuple m_readers; }; diff --git a/include/pisa/v1/index_builder.hpp b/include/pisa/v1/index_builder.hpp index f929f8784..14f458d7b 100644 --- a/include/pisa/v1/index_builder.hpp +++ b/include/pisa/v1/index_builder.hpp @@ -17,22 +17,31 @@ namespace pisa::v1 { template struct IndexBuilder { - explicit IndexBuilder(Writers... writers) : m_writers(std::move(writers...)) {} + explicit IndexBuilder(Writers... writers) : m_writers(std::move(writers)...) {} template - void operator()(Encoding encoding, Fn fn) + void operator()(Encoding document_encoding, Encoding payload_encoding, Fn fn) { - auto run = [&](auto &&writer) -> bool { - if (std::decay_t::encoding() == encoding) { - fn(writer); + auto run = [&](auto &&dwriter, auto &&pwriter) -> bool { + if (std::decay_t::encoding() == document_encoding + && std::decay_t::encoding() == payload_encoding) { + fn(dwriter, pwriter); return true; } return false; }; - bool success = - std::apply([&](Writers... writers) { return (run(writers) || ...); }, m_writers); + bool success = std::apply( + [&](Writers... dwriters) { + auto with_document_writer = [&](auto dwriter) { + return std::apply( + [&](Writers... 
pwriters) { return (run(dwriter, pwriters) || ...); }, + m_writers); + }; + return (with_document_writer(dwriters) || ...); + }, + m_writers); if (not success) { - throw std::domain_error(fmt::format("Unknown writer")); + throw std::domain_error("Unknown posting encoding"); } } @@ -40,13 +49,11 @@ struct IndexBuilder { std::tuple m_writers; }; -using pisa::v1::ByteOStream; -using pisa::v1::calc_avg_length; -using pisa::v1::DocId; -using pisa::v1::Frequency; -using pisa::v1::PostingBuilder; -using pisa::v1::RawWriter; -using pisa::v1::read_sizes; +template +auto make_index_builder(Writers... writers) +{ + return IndexBuilder(std::move(writers)...); +} template auto compress_batch(CollectionIterator first, @@ -86,12 +93,16 @@ void write_span(gsl::span offsets, std::string const &file) } inline void compress_binary_collection(std::string const &input, + std::string_view fwd, std::string_view output, std::size_t const threads, Writer document_writer, Writer frequency_writer) { pisa::binary_freq_collection const collection(input.c_str()); + ProgressStatus status(collection.size(), + DefaultProgress("Compressing in parallel"), + std::chrono::milliseconds(100)); tbb::task_group group; auto const num_terms = collection.size(); std::vector> document_offsets(threads); @@ -113,9 +124,6 @@ inline void compress_binary_collection(std::string const &input, document_streams.emplace_back(document_batch); frequency_streams.emplace_back(frequency_batch); }); - ProgressStatus status(collection.size(), - DefaultProgress("Compressing in parallel"), - std::chrono::milliseconds(500)); auto batch_size = num_terms / threads; for_each_batch([&](auto thread_idx) { group.run([thread_idx, @@ -163,37 +171,44 @@ inline void compress_binary_collection(std::string const &input, std::ofstream document_out(documents_file); std::ofstream frequency_out(frequencies_file); - PostingBuilder(RawWriter{}).write_header(document_out); - PostingBuilder(RawWriter{}).write_header(frequency_out); + 
PostingBuilder(document_writer).write_header(document_out); + PostingBuilder(frequency_writer).write_header(frequency_out); - ProgressStatus merge_status( - threads, DefaultProgress("Merging files"), std::chrono::milliseconds(500)); - for_each_batch([&](auto thread_idx) { - std::transform(std::next(document_offsets[thread_idx].begin()), - document_offsets[thread_idx].end(), - std::back_inserter(all_document_offsets), - [base = all_document_offsets.back()](auto offset) { return base + offset; }); - std::transform( - std::next(frequency_offsets[thread_idx].begin()), - frequency_offsets[thread_idx].end(), - std::back_inserter(all_frequency_offsets), - [base = all_frequency_offsets.back()](auto offset) { return base + offset; }); - std::ifstream docbatch(document_paths[thread_idx]); - std::ifstream freqbatch(frequency_paths[thread_idx]); - document_out << docbatch.rdbuf(); - frequency_out << freqbatch.rdbuf(); - merge_status += 1; - }); + { + ProgressStatus merge_status( + threads, DefaultProgress("Merging files"), std::chrono::milliseconds(500)); + for_each_batch([&](auto thread_idx) { + std::transform( + std::next(document_offsets[thread_idx].begin()), + document_offsets[thread_idx].end(), + std::back_inserter(all_document_offsets), + [base = all_document_offsets.back()](auto offset) { return base + offset; }); + std::transform( + std::next(frequency_offsets[thread_idx].begin()), + frequency_offsets[thread_idx].end(), + std::back_inserter(all_frequency_offsets), + [base = all_frequency_offsets.back()](auto offset) { return base + offset; }); + std::ifstream docbatch(document_paths[thread_idx]); + std::ifstream freqbatch(frequency_paths[thread_idx]); + document_out << docbatch.rdbuf(); + frequency_out << freqbatch.rdbuf(); + merge_status += 1; + }); + } + std::cerr << "Writing offsets..."; auto doc_offset_file = fmt::format("{}.document_offsets", output); auto freq_offset_file = fmt::format("{}.frequency_offsets", output); write_span(gsl::span(all_document_offsets), 
doc_offset_file); write_span(gsl::span(all_frequency_offsets), freq_offset_file); + std::cerr << " Done.\n"; + std::cerr << "Writing sizes..."; auto lengths = read_sizes(input); auto document_lengths_file = fmt::format("{}.document_lengths", output); write_span(gsl::span(lengths), document_lengths_file); - auto avg_len = calc_avg_length(gsl::span(lengths)); + float avg_len = calc_avg_length(gsl::span(lengths)); + std::cerr << " Done.\n"; boost::property_tree::ptree pt; pt.put("documents.file", documents_file); @@ -202,6 +217,9 @@ inline void compress_binary_collection(std::string const &input, pt.put("frequencies.offsets", freq_offset_file); pt.put("stats.avg_document_length", avg_len); pt.put("stats.document_lengths", document_lengths_file); + pt.put("lexicon.stemmer", "porter2"); // TODO(michal): Parametrize + pt.put("lexicon.terms", fmt::format("{}.termlex", fwd)); + pt.put("lexicon.documents", fmt::format("{}.doclex", fwd)); boost::property_tree::write_ini(fmt::format("{}.ini", output), pt); } diff --git a/include/pisa/v1/index_metadata.hpp b/include/pisa/v1/index_metadata.hpp index 2fffedd50..e3c3e440d 100644 --- a/include/pisa/v1/index_metadata.hpp +++ b/include/pisa/v1/index_metadata.hpp @@ -43,12 +43,14 @@ struct PostingFilePaths { struct IndexMetadata { PostingFilePaths documents; PostingFilePaths frequencies; + std::vector scores{}; std::string document_lengths_path; float avg_document_length; tl::optional term_lexicon{}; tl::optional document_lexicon{}; tl::optional stemmer{}; + void write(std::string const &file); [[nodiscard]] static auto from_file(std::string const &file) -> IndexMetadata; }; @@ -91,4 +93,30 @@ template std::move(readers)); } +template +[[nodiscard]] inline auto scored_index_runner(IndexMetadata metadata, Readers... 
readers) +{ + return scored_index_runner(std::move(metadata), std::make_tuple(readers...)); +} + +template +[[nodiscard]] inline auto scored_index_runner(IndexMetadata metadata, + std::tuple readers) +{ + MMapSource source; + auto documents = source_span(source, metadata.documents.postings); + auto scores = source_span(source, metadata.scores.front().postings); + auto document_offsets = source_span(source, metadata.documents.offsets); + auto score_offsets = source_span(source, metadata.scores.front().offsets); + auto document_lengths = source_span(source, metadata.document_lengths_path); + return IndexRunner(document_offsets, + score_offsets, + documents, + scores, + document_lengths, + tl::make_optional(metadata.avg_document_length), + std::move(source), + std::move(readers)); +} + } // namespace pisa::v1 diff --git a/include/pisa/v1/io.hpp b/include/pisa/v1/io.hpp new file mode 100644 index 000000000..3b8b6bcb5 --- /dev/null +++ b/include/pisa/v1/io.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include +#include +#include + +namespace pisa::v1 { + +[[nodiscard]] auto load_bytes(std::string const &data_file) -> std::vector; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/posting_format_header.hpp b/include/pisa/v1/posting_format_header.hpp index 2eba95f2d..64f5678aa 100644 --- a/include/pisa/v1/posting_format_header.hpp +++ b/include/pisa/v1/posting_format_header.hpp @@ -204,7 +204,7 @@ struct PostingFormatHeader { ValueType type; Encoding encoding; - constexpr static auto parse(gsl::span bytes) -> PostingFormatHeader + static auto parse(gsl::span bytes) -> PostingFormatHeader { Expects(bytes.size() == 8); auto version = FormatVersion::parse(bytes.first(3)); diff --git a/include/pisa/v1/raw_cursor.hpp b/include/pisa/v1/raw_cursor.hpp index 4e2c8e3f9..389a17f8d 100644 --- a/include/pisa/v1/raw_cursor.hpp +++ b/include/pisa/v1/raw_cursor.hpp @@ -11,6 +11,7 @@ #include "util/likely.hpp" #include "v1/bit_cast.hpp" +#include "v1/cursor_traits.hpp" #include 
"v1/types.hpp" namespace pisa::v1 { @@ -54,7 +55,7 @@ struct RawCursor { constexpr void advance() { m_current += sizeof(T); } /// Moves the cursor to the position `pos`. - constexpr void advance_to_position(std::size_t pos) { m_current = pos; } + constexpr void advance_to_position(std::size_t pos) { m_current = pos * sizeof(T); } /// Moves the cursor to the next value equal or greater than `value`. constexpr void advance_to_geq(T value) @@ -71,7 +72,10 @@ struct RawCursor { } /// Returns the current position. - [[nodiscard]] constexpr auto position() const noexcept -> std::size_t { return m_current; } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return m_current / sizeof(T); + } /// Returns the number of elements in the list. [[nodiscard]] constexpr auto size() const -> std::size_t { return m_bytes.size() / sizeof(T); } @@ -124,4 +128,10 @@ struct RawWriter { std::vector m_postings{}; }; +template +struct CursorTraits> { + using Writer = RawWriter; + using Reader = RawReader; +}; + } // namespace pisa::v1 diff --git a/include/pisa/v1/types.hpp b/include/pisa/v1/types.hpp index 8d8aaacaa..92a06bc69 100644 --- a/include/pisa/v1/types.hpp +++ b/include/pisa/v1/types.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -16,7 +17,13 @@ using Score = float; using Result = std::pair; using ByteOStream = std::basic_ostream; -enum EncodingId { Raw = 0xda43 }; +enum EncodingId { + Raw = 0xda43, + BlockDelta = 0xEF00, + Block = 0xFF00, + SimdBP = 0x0001, + Varbyte = 0x0002 +}; template struct overloaded : Ts... 
{ @@ -100,7 +107,7 @@ struct Writer { auto write(std::ostream &os) const -> std::size_t override { return m_writer.write(os); } void reset() override { return m_writer.reset(); } [[nodiscard]] auto encoding() const -> std::uint32_t override { return W::encoding(); } - [[nodiscard]] virtual auto clone() const -> std::unique_ptr + [[nodiscard]] auto clone() const -> std::unique_ptr override { auto copy = *this; return std::make_unique>(std::move(copy)); @@ -115,9 +122,15 @@ struct Writer { }; template -[[nodiscard]] inline auto make_writer(W writer) +[[nodiscard]] inline auto make_writer(W &&writer) { - return Writer(writer); + return Writer(std::forward(writer)); +} + +template +[[nodiscard]] inline auto make_writer() +{ + return Writer(W{}); } /// Indicates that payloads should be treated as scores. @@ -125,4 +138,7 @@ template struct VoidScorer { }; +template +struct encoding_traits; + } // namespace pisa::v1 diff --git a/include/pisa/v1/unaligned_span.hpp b/include/pisa/v1/unaligned_span.hpp new file mode 100644 index 000000000..8ea845823 --- /dev/null +++ b/include/pisa/v1/unaligned_span.hpp @@ -0,0 +1,158 @@ +#pragma once + +#include +#include + +#include +#include + +#include "v1/bit_cast.hpp" + +namespace pisa::v1 { + +template +struct UnalignedSpan; + +template +struct UnalignedSpanIterator { + UnalignedSpanIterator(std::uint32_t index, UnalignedSpan const &span) + : m_index(index), m_span(span) + { + } + UnalignedSpanIterator(UnalignedSpanIterator const &) = default; + UnalignedSpanIterator(UnalignedSpanIterator &&) noexcept = default; + UnalignedSpanIterator &operator=(UnalignedSpanIterator const &) = default; + UnalignedSpanIterator &operator=(UnalignedSpanIterator &&) noexcept = default; + ~UnalignedSpanIterator() = default; + [[nodiscard]] auto operator==(UnalignedSpanIterator const &other) const + { + return m_span.bytes().data() == other.m_span.bytes().data() && m_index == other.m_index; + } + [[nodiscard]] auto operator!=(UnalignedSpanIterator const 
&other) const + { + return m_index != other.m_index || m_span.bytes().data() != other.m_span.bytes().data(); + } + [[nodiscard]] auto operator*() const { return m_span[m_index]; } + auto operator++() -> UnalignedSpanIterator & + { + m_index++; + return *this; + } + auto operator++(int) -> UnalignedSpanIterator + { + auto copy = *this; + m_index++; + return copy; + } + [[nodiscard]] auto operator+=(std::uint32_t n) -> UnalignedSpanIterator & + { + m_index += n; + return *this; + } + [[nodiscard]] auto operator+(std::uint32_t n) const -> UnalignedSpanIterator + { + return UnalignedSpanIterator(m_index + n, m_span); + } + auto operator--() -> UnalignedSpanIterator & + { + m_index--; + return *this; + } + auto operator--(int) -> UnalignedSpanIterator + { + auto copy = *this; + m_index--; + return copy; + } + [[nodiscard]] auto operator-=(std::uint32_t n) -> UnalignedSpanIterator & + { + m_index -= n; + return *this; + } + [[nodiscard]] auto operator-(std::uint32_t n) const -> UnalignedSpanIterator + { + return UnalignedSpanIterator(m_index - n, m_span); + } + [[nodiscard]] auto operator-(UnalignedSpanIterator const &other) const -> std::int32_t + { + return static_cast(m_index) - static_cast(other.m_index); + } + [[nodiscard]] auto operator<(UnalignedSpanIterator const &other) const -> bool + { + return m_index < other.m_index; + } + [[nodiscard]] auto operator<=(UnalignedSpanIterator const &other) const -> bool + { + return m_index <= other.m_index; + } + [[nodiscard]] auto operator>(UnalignedSpanIterator const &other) const -> bool + { + return m_index > other.m_index; + } + [[nodiscard]] auto operator>=(UnalignedSpanIterator const &other) const -> bool + { + return m_index >= other.m_index; + } + + private: + std::uint32_t m_index; + UnalignedSpan const &m_span; +}; + +template +struct UnalignedSpan { + static_assert(std::is_trivially_copyable_v); + using value_type = T; + + constexpr UnalignedSpan() = default; + explicit constexpr UnalignedSpan(gsl::span bytes) : 
m_bytes(bytes) + { + if (m_bytes.size() % sizeof(value_type) != 0) { + throw std::logic_error("Number of bytes must be a multiplier of type size"); + } + } + using iterator = UnalignedSpanIterator; + + [[nodiscard]] auto operator[](std::uint32_t index) const -> value_type + { + return bit_cast( + m_bytes.subspan(index * sizeof(value_type), sizeof(value_type))); + } + + [[nodiscard]] auto front() const -> value_type + { + return bit_cast(m_bytes.subspan(0, sizeof(value_type))); + } + + [[nodiscard]] auto back() const -> value_type + { + return bit_cast( + m_bytes.subspan(m_bytes.size() - sizeof(value_type), sizeof(value_type))); + } + + [[nodiscard]] auto begin() const -> iterator { return iterator(0, *this); } + [[nodiscard]] auto end() const -> iterator { return iterator(size(), *this); } + + [[nodiscard]] auto size() const -> std::size_t { return m_bytes.size() / sizeof(value_type); } + [[nodiscard]] auto byte_size() const -> std::size_t { return m_bytes.size(); } + [[nodiscard]] auto bytes() const -> gsl::span { return m_bytes; } + + private: + gsl::span m_bytes{}; +}; + +} // namespace pisa::v1 + +namespace std { + +template +struct iterator_traits<::pisa::v1::UnalignedSpanIterator> { + using size_type = std::uint32_t; + using difference_type = std::make_signed_t; + using value_type = T; + using pointer = T const *; + using reference = T const &; + using iterator_category = std::random_access_iterator_tag; +}; + +} // namespace std diff --git a/src/v1/index_builder.cpp b/src/v1/index_builder.cpp index b6b0006fc..f3b804491 100644 --- a/src/v1/index_builder.cpp +++ b/src/v1/index_builder.cpp @@ -1,4 +1,7 @@ #include "v1/index_builder.hpp" +#include "codec/simdbp.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/raw_cursor.hpp" namespace pisa::v1 { @@ -8,7 +11,12 @@ auto verify_compressed_index(std::string const &input, std::string_view output) std::vector errors; pisa::binary_freq_collection const collection(input.c_str()); auto meta = 
IndexMetadata::from_file(fmt::format("{}.ini", output)); - auto run = index_runner(meta, RawReader{}); + auto run = index_runner(meta, + RawReader{}, + BlockedReader<::pisa::simdbp_block, true>{}, + BlockedReader<::pisa::simdbp_block, false>{}); + ProgressStatus status( + collection.size(), DefaultProgress("Verifying"), std::chrono::milliseconds(100)); run([&](auto &&index) { auto sequence_iter = collection.begin(); for (auto term = 0; term < index.num_terms(); term += 1, ++sequence_iter) { @@ -40,6 +48,7 @@ auto verify_compressed_index(std::string const &input, std::string_view output) ++fit; ++pos; } + status += 1; } }); return errors; diff --git a/src/v1/index_metadata.cpp b/src/v1/index_metadata.cpp index 666a5e5ac..cdcd626c4 100644 --- a/src/v1/index_metadata.cpp +++ b/src/v1/index_metadata.cpp @@ -20,11 +20,18 @@ namespace pisa::v1 { { boost::property_tree::ptree pt; boost::property_tree::ini_parser::read_ini(file, pt); + std::vector scores; + if (pt.count("scores") > 0U) { + scores.push_back(PostingFilePaths{.postings = pt.get("scores.file"), + .offsets = pt.get("scores.offsets")}); + } return IndexMetadata{ .documents = PostingFilePaths{.postings = pt.get("documents.file"), .offsets = pt.get("documents.offsets")}, .frequencies = PostingFilePaths{.postings = pt.get("frequencies.file"), .offsets = pt.get("frequencies.offsets")}, + // TODO(michal): Once switched to YAML, parse an array. 
+ .scores = std::move(scores), .document_lengths_path = pt.get("stats.document_lengths"), .avg_document_length = pt.get("stats.avg_document_length"), .term_lexicon = convert_optional(pt.get_optional("lexicon.terms")), @@ -32,4 +39,27 @@ namespace pisa::v1 { .stemmer = convert_optional(pt.get_optional("lexicon.stemmer"))}; } +void IndexMetadata::write(std::string const &file) +{ + boost::property_tree::ptree pt; + pt.put("documents.file", documents.postings); + pt.put("documents.offsets", documents.offsets); + pt.put("frequencies.file", frequencies.postings); + pt.put("frequencies.offsets", frequencies.offsets); + pt.put("stats.avg_document_length", avg_document_length); + pt.put("stats.document_lengths", document_lengths_path); + pt.put("lexicon.stemmer", "porter2"); // TODO(michal): Parametrize + if (not scores.empty()) { + pt.put("scores.file", scores.front().postings); + pt.put("scores.offsets", scores.front().offsets); + } + if (term_lexicon) { + pt.put("lexicon.terms", *term_lexicon); + } + if (document_lexicon) { + pt.put("lexicon.documents", *document_lexicon); + } + boost::property_tree::write_ini(file, pt); +} + } // namespace pisa::v1 diff --git a/src/v1/io.cpp b/src/v1/io.cpp new file mode 100644 index 000000000..aec0b7492 --- /dev/null +++ b/src/v1/io.cpp @@ -0,0 +1,19 @@ +#include "v1/io.hpp" + +namespace pisa::v1 { + +[[nodiscard]] auto load_bytes(std::string const &data_file) -> std::vector +{ + std::vector data; + std::basic_ifstream in(data_file.c_str(), std::ios::binary); + in.seekg(0, std::ios::end); + std::streamsize size = in.tellg(); + in.seekg(0, std::ios::beg); + data.resize(size); + if (not in.read(reinterpret_cast(data.data()), size)) { + throw std::runtime_error("Failed reading " + data_file); + } + return data; +} + +} // namespace pisa::v1 diff --git a/src/v1/progress_status.cpp b/src/v1/progress_status.cpp index bc60eb582..10ad0e0ab 100644 --- a/src/v1/progress_status.cpp +++ b/src/v1/progress_status.cpp @@ -38,9 +38,12 @@ void 
DefaultProgress::operator()(std::size_t count, m_previous = progress; std::chrono::seconds elapsed = std::chrono::duration_cast(std::chrono::steady_clock::now() - start); - std::cerr << m_caption << progress << "% ["; + std::cerr << '\r' << m_caption << progress << "% ["; format_interval(std::cerr, elapsed); - std::cerr << "]\n"; + std::cerr << "]"; + if (progress == 100) { + std::cerr << '\n'; + } } ProgressStatus::~ProgressStatus() diff --git a/test/temporary_directory.hpp b/test/temporary_directory.hpp index d87e043cc..3a6b1f789 100644 --- a/test/temporary_directory.hpp +++ b/test/temporary_directory.hpp @@ -14,14 +14,15 @@ struct Temporary_Directory { boost::filesystem::create_directory(dir_); std::cerr << "Created a tmp dir " << dir_.c_str() << '\n'; } - ~Temporary_Directory() { + ~Temporary_Directory() + { if (boost::filesystem::exists(dir_)) { boost::filesystem::remove_all(dir_); } std::cerr << "Removed a tmp dir " << dir_.c_str() << '\n'; } - [[nodiscard]] auto path() -> boost::filesystem::path const & { return dir_; } + [[nodiscard]] auto path() const -> boost::filesystem::path const & { return dir_; } private: boost::filesystem::path dir_; diff --git a/test/test_v1.cpp b/test/test_v1.cpp index 5fa19db8c..754b2e4d1 100644 --- a/test/test_v1.cpp +++ b/test/test_v1.cpp @@ -11,16 +11,19 @@ #include "pisa_config.hpp" #include "v1/cursor/collect.hpp" #include "v1/index.hpp" +#include "v1/io.hpp" #include "v1/posting_builder.hpp" #include "v1/posting_format_header.hpp" #include "v1/scorer/bm25.hpp" #include "v1/scorer/runner.hpp" #include "v1/types.hpp" +#include "v1/unaligned_span.hpp" using pisa::v1::Array; using pisa::v1::DocId; using pisa::v1::Frequency; using pisa::v1::IndexRunner; +using pisa::v1::load_bytes; using pisa::v1::next; using pisa::v1::parse_type; using pisa::v1::PostingBuilder; @@ -29,9 +32,9 @@ using pisa::v1::Primitive; using pisa::v1::RawReader; using pisa::v1::RawWriter; using pisa::v1::read_sizes; -using pisa::v1::ScorerRunner; using 
pisa::v1::TermId; using pisa::v1::Tuple; +using pisa::v1::UnalignedSpan; using pisa::v1::Writer; template @@ -58,7 +61,7 @@ TEST_CASE("RawReader", "[v1][unit]") REQUIRE(next(cursor) == tl::nullopt); } -TEST_CASE("Binary collection index", "[v1][unit]") +TEST_CASE("Binary collection index", "[.][v1][unit]") { pisa::binary_freq_collection collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); auto index = @@ -83,7 +86,7 @@ TEST_CASE("Binary collection index", "[v1][unit]") } } -TEST_CASE("Bigram collection index", "[v1][unit]") +TEST_CASE("Bigram collection index", "[.][v1][unit]") { auto intersect = [](auto const &lhs, auto const &rhs) -> std::vector> { @@ -122,6 +125,7 @@ TEST_CASE("Bigram collection index", "[v1][unit]") ++pos; TermId term_id = 1; for (; pos != collection.end(); ++pos, ++term_id) { + CAPTURE(term_id); auto current = to_vec(*pos); auto intersection = intersect(prev, current); if (not intersection.empty()) { @@ -131,6 +135,12 @@ TEST_CASE("Bigram collection index", "[v1][unit]") auto freqs = cursor.payload(); return std::make_tuple(*cursor, freqs[0], freqs[1]); }); + for (auto idx = 0; idx < 10; idx++) { + std::cout << std::get<1>(postings[idx]) << " " << std::get<1>(intersection[idx]) + << '\n'; + std::cout << std::get<2>(postings[idx]) << " " << std::get<2>(intersection[idx]) + << "\n---\n"; + } REQUIRE(postings == intersection); } std::swap(prev, current); @@ -208,20 +218,6 @@ TEST_CASE("Value type", "[v1][unit]") REQUIRE(std::get(parse_type(std::byte{0b01000111})).size == 8U); } -[[nodiscard]] inline auto load_bytes(std::string const &data_file) -{ - std::vector data; - std::basic_ifstream in(data_file.c_str(), std::ios::binary); - in.seekg(0, std::ios::end); - std::streamsize size = in.tellg(); - in.seekg(0, std::ios::beg); - data.resize(size); - if (not in.read(reinterpret_cast(data.data()), size)) { - throw std::runtime_error("Failed reading " + data_file); - } - return data; -} - TEST_CASE("Build raw document-frequency index", 
"[v1][unit]") { using sink_type = boost::iostreams::back_insert_device>; @@ -333,3 +329,40 @@ TEST_CASE("Build raw document-frequency index", "[v1][unit]") } } } + +TEST_CASE("UnalignedSpan", "[v1][unit]") +{ + std::vector bytes{std::byte{0b00000001}, + std::byte{0b00000010}, + std::byte{0b00000011}, + std::byte{0b00000100}, + std::byte{0b00000101}, + std::byte{0b00000110}, + std::byte{0b00000111}}; + SECTION("Bytes one-to-one") + { + auto span = UnalignedSpan(gsl::make_span(bytes)); + REQUIRE(std::vector(span.begin(), span.end()) == bytes); + } + SECTION("Bytes shifted by offset") + { + auto span = UnalignedSpan(gsl::make_span(bytes).subspan(2)); + REQUIRE(std::vector(span.begin(), span.end()) + == std::vector(bytes.begin() + 2, bytes.end())); + } + SECTION("u16") + { + REQUIRE_THROWS_AS(UnalignedSpan(gsl::make_span(bytes)), std::logic_error); + auto span = UnalignedSpan(gsl::make_span(bytes).subspan(1)); + REQUIRE(std::vector(span.begin(), span.end()) + == std::vector{ + 0b0000001100000010, 0b0000010100000100, 0b0000011100000110}); + } + SECTION("u32") + { + REQUIRE_THROWS_AS(UnalignedSpan(gsl::make_span(bytes)), std::logic_error); + auto span = UnalignedSpan(gsl::make_span(bytes).subspan(1, 4)); + REQUIRE(std::vector(span.begin(), span.end()) + == std::vector{0b00000101000001000000001100000010}); + } +} diff --git a/test/test_v1_blocked_cursor.cpp b/test/test_v1_blocked_cursor.cpp new file mode 100644 index 000000000..f4f8cd71a --- /dev/null +++ b/test/test_v1_blocked_cursor.cpp @@ -0,0 +1,195 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include + +#include +#include +#include + +#include "codec/simdbp.hpp" +#include "pisa_config.hpp" +#include "temporary_directory.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/index.hpp" +#include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/io.hpp" +#include "v1/posting_builder.hpp" +#include "v1/types.hpp" + +using pisa::v1::BlockedReader; 
+using pisa::v1::BlockedWriter; +using pisa::v1::collect; +using pisa::v1::DocId; +using pisa::v1::Frequency; +using pisa::v1::IndexRunner; +using pisa::v1::PostingBuilder; +using pisa::v1::RawReader; +using pisa::v1::read_sizes; +using pisa::v1::TermId; + +TEST_CASE("Build single-block blocked document file", "[v1][unit]") +{ + using sink_type = boost::iostreams::back_insert_device>; + using vector_stream_type = boost::iostreams::stream; + + std::vector docids{3, 4, 5, 6, 7, 8, 9, 10, 51, 115}; + std::vector docbuf; + auto document_offsets = [&]() { + PostingBuilder document_builder(BlockedWriter{}); + vector_stream_type docstream{sink_type{docbuf}}; + document_builder.write_header(docstream); + document_builder.write_segment(docstream, docids.begin(), docids.end()); + return document_builder.offsets(); + }(); + + auto documents = gsl::span(docbuf).subspan(8); + CHECK(docbuf.size() == document_offsets.back() + 8); + BlockedReader document_reader; + auto term = 0; + auto actual = collect(document_reader.read(documents.subspan( + document_offsets[term], document_offsets[term + 1] - document_offsets[term]))); + CHECK(actual.size() == docids.size()); + REQUIRE(actual == docids); +} + +TEST_CASE("Build blocked document-frequency index", "[v1][unit]") +{ + // Temporary_Directory tmpdir; + + //{ + // std::vector document_data{1, 1, 4, 1, 3, 6, 11}; + // std::vector frequency_data{4, 5, 4, 3, 2}; + // std::ofstream dos((tmpdir.path() / "x.docs").string()); + // std::ofstream fos((tmpdir.path() / "x.freqs").string()); + // dos.write(reinterpret_cast(document_data.data()), document_data.size() * 4); + // fos.write(reinterpret_cast(frequency_data.data()), frequency_data.size() * + // 4); + //} + + using sink_type = boost::iostreams::back_insert_device>; + using vector_stream_type = boost::iostreams::stream; + GIVEN("A test binary collection") + { + pisa::binary_freq_collection collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); + // pisa::binary_freq_collection 
collection((tmpdir.path() / "x").string().c_str()); + WHEN("Built posting files for documents and frequencies") + { + std::vector docbuf; + std::vector freqbuf; + + PostingBuilder document_builder(BlockedWriter{}); + PostingBuilder frequency_builder(BlockedWriter{}); + { + vector_stream_type docstream{sink_type{docbuf}}; + vector_stream_type freqstream{sink_type{freqbuf}}; + + document_builder.write_header(docstream); + frequency_builder.write_header(freqstream); + + for (auto sequence : collection) { + document_builder.write_segment( + docstream, sequence.docs.begin(), sequence.docs.end()); + frequency_builder.write_segment( + freqstream, sequence.freqs.begin(), sequence.freqs.end()); + } + } + + auto document_offsets = document_builder.offsets(); + auto frequency_offsets = frequency_builder.offsets(); + + auto document_sizes = read_sizes(PISA_SOURCE_DIR "/test/test_data/test_collection"); + auto documents = gsl::span(docbuf).subspan(8); + auto frequencies = gsl::span(freqbuf).subspan(8); + + //{ + // std::ofstream dos("/home/elshize/test.documents"); + // std::ofstream fos("/home/elshize/test.frequencies"); + // dos.write(reinterpret_cast(docbuf.data()), docbuf.size()); + // fos.write(reinterpret_cast(freqbuf.data()), freqbuf.size()); + //} + + THEN("The values read back are euqual to the binary collection's") + { + CHECK(docbuf.size() == document_offsets.back() + 8); + BlockedReader document_reader; + BlockedReader frequency_reader; + auto term = 0; + std::for_each(collection.begin(), collection.end(), [&](auto &&seq) { + std::vector expected_documents(seq.docs.begin(), seq.docs.end()); + auto actual_documents = collect(document_reader.read( + documents.subspan(document_offsets[term], + document_offsets[term + 1] - document_offsets[term]))); + CHECK(actual_documents.size() == expected_documents.size()); + REQUIRE(actual_documents == expected_documents); + + std::vector expected_frequencies(seq.freqs.begin(), seq.freqs.end()); + auto actual_frequencies = 
collect(frequency_reader.read(frequencies.subspan( + frequency_offsets[term], + frequency_offsets[term + 1] - frequency_offsets[term]))); + CHECK(actual_frequencies.size() == expected_frequencies.size()); + REQUIRE(actual_frequencies == expected_frequencies); + term += 1; + }); + } + + THEN("Index runner is correctly constructed") + { + auto source = std::array, 2>{docbuf, freqbuf}; + auto document_span = gsl::span( + reinterpret_cast(source[0].data()), source[0].size()); + auto payload_span = gsl::span( + reinterpret_cast(source[1].data()), source[1].size()); + + IndexRunner runner(document_offsets, + frequency_offsets, + document_span, + payload_span, + document_sizes, + tl::nullopt, + std::move(source), + BlockedReader{}, + BlockedReader{}); + int counter = 0; + runner([&](auto index) { + counter += 1; + TermId term_id = 0; + for (auto sequence : collection) { + CAPTURE(term_id); + REQUIRE(sequence.docs.size() == index.cursor(term_id).size()); + REQUIRE( + std::vector(sequence.docs.begin(), sequence.docs.end()) + == collect(index.cursor(term_id))); + REQUIRE( + std::vector(sequence.freqs.begin(), sequence.freqs.end()) + //== collect(index.payloads(term_id))); + == collect(index.cursor(term_id), + [](auto &&cursor) { return cursor.payload(); })); + term_id += 1; + } + }); + REQUIRE(counter == 1); + } + + THEN("Index runner fails when wrong type") + { + auto source = std::array, 2>{docbuf, freqbuf}; + auto document_span = gsl::span( + reinterpret_cast(source[0].data()), source[0].size()); + auto payload_span = gsl::span( + reinterpret_cast(source[1].data()), source[1].size()); + IndexRunner runner(document_offsets, + frequency_offsets, + document_span, + payload_span, + document_sizes, + tl::nullopt, + std::move(source), + RawReader{}); // Correct encoding but not type! 
+ REQUIRE_THROWS_AS(runner([&](auto index) {}), std::domain_error); + } + } + } +} diff --git a/test/test_v1_document_payload_cursor.cpp b/test/test_v1_document_payload_cursor.cpp new file mode 100644 index 000000000..cc71baae3 --- /dev/null +++ b/test/test_v1_document_payload_cursor.cpp @@ -0,0 +1,98 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include + +#include +#include +#include + +#include "codec/simdbp.hpp" +#include "pisa_config.hpp" +#include "temporary_directory.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/index.hpp" +#include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/io.hpp" +#include "v1/posting_builder.hpp" +#include "v1/types.hpp" + +using pisa::v1::BlockedReader; +using pisa::v1::BlockedWriter; +using pisa::v1::collect; +using pisa::v1::compress_binary_collection; +using pisa::v1::DocId; +using pisa::v1::DocumentPayloadCursor; +using pisa::v1::Frequency; +using pisa::v1::index_runner; +using pisa::v1::IndexMetadata; +using pisa::v1::load_bytes; +using pisa::v1::PostingBuilder; +using pisa::v1::RawCursor; +using pisa::v1::RawReader; +using pisa::v1::read_sizes; +using pisa::v1::TermId; + +TEST_CASE("Document-payload cursor", "[v1][unit]") +{ + GIVEN("Document-payload cursor") + { + std::vector documents{4, 0, 1, 5, 7}; + std::vector frequencies{4, 2, 2, 1, 6}; + auto cursor = DocumentPayloadCursor, RawCursor>( + RawCursor(gsl::as_bytes(gsl::make_span(documents))), + RawCursor(gsl::as_bytes(gsl::make_span(frequencies)))); + + WHEN("Collected to document and frequency vectors") + { + std::vector collected_documents; + std::vector collected_frequencies; + for_each(cursor, [&](auto &&cursor) { + collected_documents.push_back(cursor.value()); + collected_frequencies.push_back(cursor.payload()); + }); + THEN("Vector equals to expected") + { + REQUIRE(collected_documents == std::vector{0, 1, 5, 7}); + REQUIRE(collected_frequencies == std::vector{2, 2, 1, 6}); + } + } + 
+ WHEN("Stepped with advance_to_pos") + { + cursor.advance_to_position(0); + REQUIRE(cursor.value() == 0); + REQUIRE(cursor.payload() == 2); + cursor.advance_to_position(1); + REQUIRE(cursor.value() == 1); + REQUIRE(cursor.payload() == 2); + cursor.advance_to_position(2); + REQUIRE(cursor.value() == 5); + REQUIRE(cursor.payload() == 1); + cursor.advance_to_position(3); + REQUIRE(cursor.value() == 7); + REQUIRE(cursor.payload() == 6); + } + + WHEN("Advanced to 1") + { + cursor.advance_to_position(1); + REQUIRE(cursor.value() == 1); + REQUIRE(cursor.payload() == 2); + } + WHEN("Advanced to 2") + { + cursor.advance_to_position(2); + REQUIRE(cursor.value() == 5); + REQUIRE(cursor.payload() == 1); + } + WHEN("Advanced to 3") + { + cursor.advance_to_position(3); + REQUIRE(cursor.value() == 7); + REQUIRE(cursor.payload() == 6); + } + } +} diff --git a/test/test_v1_index.cpp b/test/test_v1_index.cpp index cc28c1a0d..6903630d3 100644 --- a/test/test_v1_index.cpp +++ b/test/test_v1_index.cpp @@ -7,8 +7,10 @@ #include #include +#include "codec/simdbp.hpp" #include "pisa_config.hpp" #include "temporary_directory.hpp" +#include "v1/blocked_cursor.hpp" #include "v1/cursor/collect.hpp" #include "v1/index.hpp" #include "v1/index_builder.hpp" @@ -16,6 +18,8 @@ #include "v1/types.hpp" using pisa::v1::binary_collection_index; +using pisa::v1::BlockedReader; +using pisa::v1::BlockedWriter; using pisa::v1::compress_binary_collection; using pisa::v1::DocId; using pisa::v1::Frequency; @@ -31,21 +35,63 @@ TEST_CASE("Binary collection index", "[v1][unit]") Temporary_Directory tmpdir; auto bci = binary_collection_index(PISA_SOURCE_DIR "/test/test_data/test_collection"); compress_binary_collection(PISA_SOURCE_DIR "/test/test_data/test_collection", + (tmpdir.path() / "fwd").string(), (tmpdir.path() / "index").string(), 8, - make_writer(RawWriter{}), - make_writer(RawWriter{})); + make_writer(RawWriter{}), + make_writer(RawWriter{})); auto meta = IndexMetadata::from_file((tmpdir.path() / 
"index.ini").string()); REQUIRE(meta.documents.postings == (tmpdir.path() / "index.documents").string()); REQUIRE(meta.documents.offsets == (tmpdir.path() / "index.document_offsets").string()); REQUIRE(meta.frequencies.postings == (tmpdir.path() / "index.frequencies").string()); REQUIRE(meta.frequencies.offsets == (tmpdir.path() / "index.frequency_offsets").string()); REQUIRE(meta.document_lengths_path == (tmpdir.path() / "index.document_lengths").string()); - auto run = index_runner(meta, RawReader{}, RawReader{}); + auto run = index_runner(meta, + RawReader{}, + BlockedReader{}, + BlockedReader{}); run([&](auto index) { + REQUIRE(bci.num_documents() == index.num_documents()); + REQUIRE(bci.num_terms() == index.num_terms()); REQUIRE(bci.avg_document_length() == index.avg_document_length()); + for (auto doc = 0; doc < bci.num_documents(); doc += 1) { + REQUIRE(bci.document_length(doc) == index.document_length(doc)); + } + for (auto term = 0; term < bci.num_terms(); term += 1) { + REQUIRE(collect(bci.documents(term)) == collect(index.documents(term))); + REQUIRE(collect(bci.payloads(term)) == collect(index.payloads(term))); + } + }); +} + +TEST_CASE("Binary collection index -- SIMDBP", "[v1][unit]") +{ + tbb::task_scheduler_init init(8); + Temporary_Directory tmpdir; + auto bci = binary_collection_index(PISA_SOURCE_DIR "/test/test_data/test_collection"); + compress_binary_collection(PISA_SOURCE_DIR "/test/test_data/test_collection", + (tmpdir.path() / "fwd").string(), + (tmpdir.path() / "index").string(), + 8, + make_writer(BlockedWriter<::pisa::simdbp_block, true>{}), + make_writer(BlockedWriter<::pisa::simdbp_block, false>{})); + auto meta = IndexMetadata::from_file((tmpdir.path() / "index.ini").string()); + REQUIRE(meta.documents.postings == (tmpdir.path() / "index.documents").string()); + REQUIRE(meta.documents.offsets == (tmpdir.path() / "index.document_offsets").string()); + REQUIRE(meta.frequencies.postings == (tmpdir.path() / "index.frequencies").string()); 
+ REQUIRE(meta.frequencies.offsets == (tmpdir.path() / "index.frequency_offsets").string()); + REQUIRE(meta.document_lengths_path == (tmpdir.path() / "index.document_lengths").string()); + auto run = index_runner(meta, + RawReader{}, + BlockedReader{}, + BlockedReader{}); + run([&](auto index) { REQUIRE(bci.num_documents() == index.num_documents()); REQUIRE(bci.num_terms() == index.num_terms()); + REQUIRE(bci.avg_document_length() == index.avg_document_length()); + for (auto doc = 0; doc < bci.num_documents(); doc += 1) { + REQUIRE(bci.document_length(doc) == index.document_length(doc)); + } for (auto term = 0; term < bci.num_terms(); term += 1) { REQUIRE(collect(bci.documents(term)) == collect(index.documents(term))); REQUIRE(collect(bci.payloads(term)) == collect(index.payloads(term))); diff --git a/test/test_v1_queries.cpp b/test/test_v1_queries.cpp index 7bdfc5520..3f15034ef 100644 --- a/test/test_v1_queries.cpp +++ b/test/test_v1_queries.cpp @@ -15,10 +15,14 @@ #include "io.hpp" #include "pisa_config.hpp" #include "query/queries.hpp" +#include "temporary_directory.hpp" +#include "v1/blocked_cursor.hpp" #include "v1/cursor/collect.hpp" #include "v1/cursor_intersection.hpp" +#include "v1/cursor_traits.hpp" #include "v1/cursor_union.hpp" #include "v1/index.hpp" +#include "v1/index_builder.hpp" #include "v1/posting_builder.hpp" #include "v1/posting_format_header.hpp" #include "v1/query.hpp" @@ -28,6 +32,67 @@ namespace v1 = pisa::v1; using namespace pisa; +static constexpr auto RELATIVE_ERROR = 0.1F; + +template +struct IndexFixture { + using DocumentWriter = typename v1::CursorTraits::Writer; + using FrequencyWriter = typename v1::CursorTraits::Writer; + using ScoreWriter = typename v1::CursorTraits::Writer; + + using DocumentReader = typename v1::CursorTraits::Reader; + using FrequencyReader = typename v1::CursorTraits::Reader; + using ScoreReader = typename v1::CursorTraits::Reader; + + IndexFixture() : m_tmpdir(std::make_unique()) + { + auto index_basename = 
(tmpdir().path() / "inv").string(); + v1::compress_binary_collection(PISA_SOURCE_DIR "/test/test_data/test_collection", + PISA_SOURCE_DIR "/test/test_data/test_collection.fwd", + index_basename, + 2, + v1::make_writer(), + v1::make_writer()); + auto errors = v1::verify_compressed_index(PISA_SOURCE_DIR "/test/test_data/test_collection", + index_basename); + REQUIRE(errors.empty()); + auto meta = v1::IndexMetadata::from_file(fmt::format("{}.ini", index_basename)); + auto run = v1::index_runner(meta, document_reader(), frequency_reader()); + auto postings_path = fmt::format("{}.bm25", index_basename); + auto offsets_path = fmt::format("{}.bm25_offsets", index_basename); + run([&](auto &&index) { + std::ofstream score_file_stream(postings_path); + auto offsets = score_index(index, score_file_stream, ScoreWriter{}, make_bm25(index)); + v1::write_span(gsl::span(offsets), offsets_path); + }); + meta.scores.push_back( + v1::PostingFilePaths{.postings = postings_path, .offsets = offsets_path}); + meta.write(fmt::format("{}.ini", index_basename)); + } + + [[nodiscard]] auto const &tmpdir() const { return *m_tmpdir; } + [[nodiscard]] auto document_reader() const { return m_document_reader; } + [[nodiscard]] auto frequency_reader() const { return m_frequency_reader; } + [[nodiscard]] auto score_reader() const { return m_score_reader; } + + private: + std::unique_ptr m_tmpdir; + DocumentReader m_document_reader{}; + FrequencyReader m_frequency_reader{}; + ScoreReader m_score_reader{}; +}; + +[[nodiscard]] auto test_queries() -> std::vector +{ + std::vector queries; + std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); + auto push_query = [&](std::string const &query_line) { + queries.push_back(parse_query_ids(query_line)); + }; + io::for_each_line(qfile, push_query); + return queries; +} + template struct IndexData { @@ -46,7 +111,6 @@ struct IndexData { BlockSize(FixedBlock())) { - tbb::task_scheduler_init init; typename v0_Index::builder 
builder(collection.num_docs(), params); for (auto const &plist : collection) { uint64_t freqs_sum = @@ -93,59 +157,24 @@ template std::unique_ptr> IndexData::data = nullptr; -TEST_CASE("DAAT AND", "[v1][integration]") -{ - auto data = IndexData, v1::RawCursor>, - v1::Index, v1::RawCursor>>::get(); - ranked_and_query and_q(10); - int idx = 0; - for (auto const &q : data->queries) { - - CAPTURE(q.terms); - CAPTURE(idx++); - - and_q(make_scored_cursors(data->v0_index, data->wdata, q), data->v0_index.num_docs()); - auto expected = and_q.topk(); - std::sort(expected.begin(), expected.end(), std::greater{}); - - auto on_the_fly = [&]() { - auto que = daat_and( - v1::Query{q.terms}, data->v1_index, topk_queue(10), make_bm25(data->v1_index)); - que.finalize(); - auto results = que.topk(); - std::sort(results.begin(), results.end(), std::greater{}); - return results; - }(); - - auto precomputed = [&]() { - auto que = - daat_and(v1::Query{q.terms}, data->scored_index, topk_queue(10), v1::VoidScorer{}); - que.finalize(); - auto results = que.topk(); - std::sort(results.begin(), results.end(), std::greater{}); - return results; - }(); - - REQUIRE(expected.size() == on_the_fly.size()); - REQUIRE(expected.size() == precomputed.size()); - for (size_t i = 0; i < on_the_fly.size(); ++i) { - REQUIRE(on_the_fly[i].second == expected[i].second); - REQUIRE(on_the_fly[i].first == Approx(expected[i].first).epsilon(0.1)); - REQUIRE(precomputed[i].second == expected[i].second); - REQUIRE(precomputed[i].first == Approx(expected[i].first).epsilon(0.1)); - } - } -} - -TEST_CASE("DAAT OR", "[v1][integration]") +TEMPLATE_TEST_CASE( + "DAAT OR2", + "[v1][integration]", + (IndexFixture, v1::RawCursor, v1::RawCursor>), + (IndexFixture, + v1::BlockedCursor<::pisa::simdbp_block, false>, + v1::RawCursor>)) { + tbb::task_scheduler_init init(1); auto data = IndexData, v1::RawCursor>, v1::Index, v1::RawCursor>>::get(); + TestType fixture; + auto index_basename = (fixture.tmpdir().path() / 
"inv").string(); + auto meta = v1::IndexMetadata::from_file(fmt::format("{}.ini", index_basename)); ranked_or_query or_q(10); int idx = 0; - for (auto const &q : data->queries) { + for (auto const &q : test_queries()) { CAPTURE(q.terms); CAPTURE(idx++); @@ -154,20 +183,28 @@ TEST_CASE("DAAT OR", "[v1][integration]") std::sort(expected.begin(), expected.end(), std::greater{}); auto on_the_fly = [&]() { - auto que = daat_or( - v1::Query{q.terms}, data->v1_index, topk_queue(10), make_bm25(data->v1_index)); - que.finalize(); - auto results = que.topk(); - std::sort(results.begin(), results.end(), std::greater{}); + auto run = + v1::index_runner(meta, fixture.document_reader(), fixture.frequency_reader()); + std::vector results; + run([&](auto &&index) { + auto que = daat_or(v1::Query{q.terms}, index, topk_queue(10), make_bm25(index)); + que.finalize(); + results = que.topk(); + std::sort(results.begin(), results.end(), std::greater{}); + }); return results; }(); auto precomputed = [&]() { - auto que = - daat_or(v1::Query{q.terms}, data->scored_index, topk_queue(10), v1::VoidScorer{}); - que.finalize(); - auto results = que.topk(); - std::sort(results.begin(), results.end(), std::greater{}); + auto run = + v1::scored_index_runner(meta, fixture.document_reader(), fixture.score_reader()); + std::vector results; + run([&](auto &&index) { + auto que = daat_or(v1::Query{q.terms}, index, topk_queue(10), v1::VoidScorer{}); + que.finalize(); + results = que.topk(); + std::sort(results.begin(), results.end(), std::greater{}); + }); return results; }(); @@ -175,9 +212,9 @@ TEST_CASE("DAAT OR", "[v1][integration]") REQUIRE(expected.size() == precomputed.size()); for (size_t i = 0; i < on_the_fly.size(); ++i) { REQUIRE(on_the_fly[i].second == expected[i].second); - REQUIRE(on_the_fly[i].first == Approx(expected[i].first).epsilon(0.1)); + REQUIRE(on_the_fly[i].first == Approx(expected[i].first).epsilon(RELATIVE_ERROR)); REQUIRE(precomputed[i].second == expected[i].second); - 
REQUIRE(precomputed[i].first == Approx(expected[i].first).epsilon(0.1)); + REQUIRE(precomputed[i].first == Approx(expected[i].first).epsilon(RELATIVE_ERROR)); } } } diff --git a/v1/compress.cpp b/v1/compress.cpp index db0168af5..4f7a6531d 100644 --- a/v1/compress.cpp +++ b/v1/compress.cpp @@ -1,39 +1,79 @@ +#include + #include +#include #include "binary_freq_collection.hpp" +#include "v1/blocked_cursor.hpp" #include "v1/index_builder.hpp" #include "v1/index_metadata.hpp" #include "v1/raw_cursor.hpp" #include "v1/types.hpp" +using std::literals::string_view_literals::operator""sv; + +using pisa::v1::BlockedWriter; using pisa::v1::compress_binary_collection; using pisa::v1::EncodingId; -using pisa::v1::IndexBuilder; +using pisa::v1::make_index_builder; using pisa::v1::RawWriter; using pisa::v1::verify_compressed_index; +auto document_encoding(std::string_view name) -> std::uint32_t +{ + if (name == "raw"sv) { + return EncodingId::Raw; + } + if (name == "simdbp"sv) { + return EncodingId::BlockDelta | EncodingId::SimdBP; + } + spdlog::error("Unknown encoding: {}", name); + std::exit(1); +} + +auto frequency_encoding(std::string_view name) -> std::uint32_t +{ + if (name == "raw"sv) { + return EncodingId::Raw; + } + if (name == "simdbp"sv) { + return EncodingId::Block | EncodingId::SimdBP; + } + spdlog::error("Unknown encoding: {}", name); + std::exit(1); +} + int main(int argc, char **argv) { std::string input; + std::string fwd; std::string output; + std::string encoding; std::size_t threads = std::thread::hardware_concurrency(); CLI::App app{"Compresses a given binary collection to a v1 index."}; - app.add_option("-c,--collection", input, "Input collection basename")->required(); + app.add_option("-i,--inv", input, "Input collection basename")->required(); + // TODO(michal): Potentially, this would be removed once inv contains necessary info. 
+ app.add_option("-f,--fwd", fwd, "Input forward index")->required(); app.add_option("-o,--output", output, "Output basename")->required(); app.add_option("-j,--threads", threads, "Number of threads"); + app.add_option("-e,--encoding", encoding, "Index encoding")->required(); CLI11_PARSE(app, argc, argv); tbb::task_scheduler_init init(threads); - IndexBuilder> build(RawWriter{}); - build(EncodingId::Raw, [&](auto writer) { - auto frequency_writer = writer; - compress_binary_collection(input, - output, - threads, - make_writer(std::move(writer)), - make_writer(std::move(frequency_writer))); - }); + auto build = make_index_builder(RawWriter{}, + BlockedWriter<::pisa::simdbp_block, true>{}, + BlockedWriter<::pisa::simdbp_block, false>{}); + build(document_encoding(encoding), + frequency_encoding(encoding), + [&](auto document_writer, auto payload_writer) { + compress_binary_collection(input, + fwd, + output, + threads, + make_writer(std::move(document_writer)), + make_writer(std::move(payload_writer))); + }); auto errors = verify_compressed_index(input, output); if (not errors.empty()) { if (errors.size() > 10) { diff --git a/v1/query.cpp b/v1/query.cpp index 00779da1e..ed5bf7968 100644 --- a/v1/query.cpp +++ b/v1/query.cpp @@ -7,7 +7,9 @@ #include "io.hpp" #include "query/queries.hpp" +#include "timer.hpp" #include "topk_queue.hpp" +#include "v1/blocked_cursor.hpp" #include "v1/index_metadata.hpp" #include "v1/query.hpp" #include "v1/raw_cursor.hpp" @@ -17,6 +19,7 @@ using pisa::Query; using pisa::resolve_query_parser; +using pisa::v1::BlockedReader; using pisa::v1::daat_or; using pisa::v1::index_runner; using pisa::v1::IndexMetadata; @@ -24,6 +27,59 @@ using pisa::v1::RawReader; using pisa::v1::resolve_ini; using pisa::v1::taat_or; +template +void evaluate(std::vector const &queries, + Index &&index, + Scorer &&scorer, + int k, + pisa::Payload_Vector<> const &docmap) +{ + auto query_idx = 0; + for (auto const &query : queries) { + auto que = 
daat_or(pisa::v1::Query{query.terms}, index, pisa::topk_queue(k), scorer); + que.finalize(); + auto rank = 0; + for (auto result : que.topk()) { + std::cout << fmt::format("{}\t{}\t{}\t{}\t{}\t{}\n", + query.id.value_or(std::to_string(query_idx)), + "Q0", + docmap[result.second], + rank, + result.first, + "R0"); + rank += 1; + } + query_idx += 1; + } +} + +template +void benchmark(std::vector const &queries, Index &&index, Scorer &&scorer, int k) + +{ + std::vector times(queries.size(), std::numeric_limits::max()); + for (auto run = 0; run < 5; run += 1) { + for (auto query = 0; query < queries.size(); query += 1) { + auto usecs = ::pisa::run_with_timer([&]() { + auto que = daat_or( + pisa::v1::Query{queries[query].terms}, index, pisa::topk_queue(k), scorer); + que.finalize(); + do_not_optimize_away(que); + }); + times[query] = std::min(times[query], static_cast(usecs.count())); + } + } + std::sort(times.begin(), times.end()); + double avg = std::accumulate(times.begin(), times.end(), double()) / times.size(); + double q50 = times[times.size() / 2]; + double q90 = times[90 * times.size() / 100]; + double q95 = times[95 * times.size() / 100]; + spdlog::info("Mean: {}", avg); + spdlog::info("50% quantile: {}", q50); + spdlog::info("90% quantile: {}", q90); + spdlog::info("95% quantile: {}", q95); +} + int main(int argc, char **argv) { std::optional ini{}; @@ -31,6 +87,7 @@ int main(int argc, char **argv) std::optional terms_file{}; std::optional documents_file{}; int k = 1'000; + bool is_benchmark = false; CLI::App app{"Queries a v1 index."}; app.add_option("-i,--index", @@ -44,6 +101,7 @@ int main(int argc, char **argv) app.add_option("--documents", documents_file, "Overrides document lexicon from .ini (if defined). 
Required otherwise."); + app.add_flag("--benchmark", is_benchmark, "Run benchmark"); CLI11_PARSE(app, argc, argv); auto meta = IndexMetadata::from_file(resolve_ini(ini)); @@ -71,27 +129,17 @@ int main(int argc, char **argv) auto source = std::make_shared(documents_file.value().c_str()); auto docmap = pisa::Payload_Vector<>::from(*source); - auto run = index_runner(meta, RawReader{}); + auto run = index_runner(meta, + RawReader{}, + BlockedReader<::pisa::simdbp_block, true>{}, + BlockedReader<::pisa::simdbp_block, false>{}); run([&](auto &&index) { auto with_scorer = scorer_runner(index, make_bm25(index)); with_scorer("bm25", [&](auto scorer) { - auto query_idx = 0; - for (auto const &query : queries) { - auto que = - taat_or(pisa::v1::Query{query.terms}, index, pisa::topk_queue(k), scorer); - que.finalize(); - auto rank = 0; - for (auto result : que.topk()) { - std::cout << fmt::format("{}\t{}\t{}\t{}\t{}\t{}\n", - query.id.value_or(std::to_string(query_idx)), - "Q0", - docmap[result.second], - rank, - result.first, - "R0"); - rank += 1; - } - query_idx += 1; + if (is_benchmark) { + benchmark(queries, index, scorer, k); + } else { + evaluate(queries, index, scorer, k, docmap); } }); }); From 242223b38ee975036a8d2693b4725e4ad173e670 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 11 Nov 2019 21:49:07 -0500 Subject: [PATCH 17/56] Precomputed scores --- include/pisa/v1/index.hpp | 20 +++++++++++++ v1/CMakeLists.txt | 3 ++ v1/query.cpp | 40 ++++++++++++++++++-------- v1/score.cpp | 59 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 110 insertions(+), 12 deletions(-) create mode 100644 v1/score.cpp diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index 562e5a3f9..0ee1010dc 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -194,6 +194,26 @@ auto make_index(DocumentReader document_reader, std::move(source)); } +template +auto score_index(Index const &index, + std::basic_ostream &os, + Writer writer, + Scorer 
scorer, + Callback &&callback) -> std::vector +{ + PostingBuilder score_builder(writer); + score_builder.write_header(os); + std::for_each(boost::counting_iterator(0), + boost::counting_iterator(index.num_terms()), + [&](auto term) { + for_each(index.scoring_cursor(term, scorer), + [&](auto &cursor) { score_builder.accumulate(cursor.payload()); }); + score_builder.flush_segment(os); + callback(); + }); + return std::move(score_builder.offsets()); +} + template auto score_index(Index const &index, std::basic_ostream &os, Writer writer, Scorer scorer) -> std::vector diff --git a/v1/CMakeLists.txt b/v1/CMakeLists.txt index d3bda2b3d..2812c9e4c 100644 --- a/v1/CMakeLists.txt +++ b/v1/CMakeLists.txt @@ -6,3 +6,6 @@ target_link_libraries(query pisa CLI11) add_executable(postings postings.cpp) target_link_libraries(postings pisa CLI11) + +add_executable(score score.cpp) +target_link_libraries(score pisa CLI11) diff --git a/v1/query.cpp b/v1/query.cpp index ed5bf7968..36fee7f80 100644 --- a/v1/query.cpp +++ b/v1/query.cpp @@ -25,7 +25,7 @@ using pisa::v1::index_runner; using pisa::v1::IndexMetadata; using pisa::v1::RawReader; using pisa::v1::resolve_ini; -using pisa::v1::taat_or; +using pisa::v1::VoidScorer; template void evaluate(std::vector const &queries, @@ -88,6 +88,7 @@ int main(int argc, char **argv) std::optional documents_file{}; int k = 1'000; bool is_benchmark = false; + bool precomputed = false; CLI::App app{"Queries a v1 index."}; app.add_option("-i,--index", @@ -102,6 +103,7 @@ int main(int argc, char **argv) documents_file, "Overrides document lexicon from .ini (if defined). 
Required otherwise."); app.add_flag("--benchmark", is_benchmark, "Run benchmark"); + app.add_flag("--precomputed", precomputed, "Use precomputed scores"); CLI11_PARSE(app, argc, argv); auto meta = IndexMetadata::from_file(resolve_ini(ini)); @@ -129,20 +131,34 @@ int main(int argc, char **argv) auto source = std::make_shared(documents_file.value().c_str()); auto docmap = pisa::Payload_Vector<>::from(*source); - auto run = index_runner(meta, - RawReader{}, - BlockedReader<::pisa::simdbp_block, true>{}, - BlockedReader<::pisa::simdbp_block, false>{}); - run([&](auto &&index) { - auto with_scorer = scorer_runner(index, make_bm25(index)); - with_scorer("bm25", [&](auto scorer) { + if (precomputed) { + auto run = scored_index_runner(meta, + RawReader{}, + RawReader{}, + BlockedReader<::pisa::simdbp_block, true>{}, + BlockedReader<::pisa::simdbp_block, false>{}); + run([&](auto &&index) { if (is_benchmark) { - benchmark(queries, index, scorer, k); + benchmark(queries, index, VoidScorer{}, k); } else { - evaluate(queries, index, scorer, k, docmap); + evaluate(queries, index, VoidScorer{}, k, docmap); } }); - }); - + } else { + auto run = index_runner(meta, + RawReader{}, + BlockedReader<::pisa::simdbp_block, true>{}, + BlockedReader<::pisa::simdbp_block, false>{}); + run([&](auto &&index) { + auto with_scorer = scorer_runner(index, make_bm25(index)); + with_scorer("bm25", [&](auto scorer) { + if (is_benchmark) { + benchmark(queries, index, scorer, k); + } else { + evaluate(queries, index, scorer, k, docmap); + } + }); + }); + } return 0; } diff --git a/v1/score.cpp b/v1/score.cpp new file mode 100644 index 000000000..a1f9a8038 --- /dev/null +++ b/v1/score.cpp @@ -0,0 +1,59 @@ +#include + +#include +#include + +#include "binary_freq_collection.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/progress_status.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/types.hpp" + +using pisa::v1::BlockedReader; 
+using pisa::v1::DefaultProgress; +using pisa::v1::IndexMetadata; +using pisa::v1::PostingFilePaths; +using pisa::v1::ProgressStatus; +using pisa::v1::RawReader; +using pisa::v1::RawWriter; +using pisa::v1::resolve_ini; +using pisa::v1::write_span; + +int main(int argc, char **argv) +{ + std::optional ini{}; + + CLI::App app{"Scores v1 index."}; + app.add_option("-i,--index", + ini, + "Path of .ini file of an index " + "(if not provided, it will be looked for in the current directory)", + false); + CLI11_PARSE(app, argc, argv); + + auto resolved_ini = resolve_ini(ini); + auto meta = IndexMetadata::from_file(resolved_ini); + auto stemmer = meta.stemmer ? std::make_optional(*meta.stemmer) : std::optional{}; + + auto run = index_runner(meta, + RawReader{}, + BlockedReader<::pisa::simdbp_block, true>{}, + BlockedReader<::pisa::simdbp_block, false>{}); + auto index_basename = resolved_ini.substr(0, resolved_ini.size() - 4); + auto postings_path = fmt::format("{}.bm25", index_basename); + auto offsets_path = fmt::format("{}.bm25_offsets", index_basename); + run([&](auto &&index) { + ProgressStatus status( + index.num_terms(), DefaultProgress("Scoring"), std::chrono::milliseconds(100)); + std::ofstream score_file_stream(postings_path); + auto offsets = score_index( + index, score_file_stream, RawWriter{}, make_bm25(index), [&]() { status += 1; }); + write_span(gsl::span(offsets), offsets_path); + }); + meta.scores.push_back(PostingFilePaths{.postings = postings_path, .offsets = offsets_path}); + meta.write(resolved_ini); + + return 0; +} From 4a40e137a7b4719fc88ac32c8f9476bf564b1636 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Tue, 12 Nov 2019 10:24:05 -0500 Subject: [PATCH 18/56] Quantized scores --- include/pisa/v1/index.hpp | 17 +++++--- include/pisa/v1/io.hpp | 1 + include/pisa/v1/raw_cursor.hpp | 21 ++++++++-- v1/postings.cpp | 74 ++++++++++++++++++++++++---------- v1/query.cpp | 2 +- v1/score.cpp | 43 +++++++++++++++++++- 6 files changed, 125 
insertions(+), 33 deletions(-) diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index 0ee1010dc..7006d5666 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -194,20 +194,27 @@ auto make_index(DocumentReader document_reader, std::move(source)); } -template +template auto score_index(Index const &index, std::basic_ostream &os, Writer writer, Scorer scorer, + Quantizer &&quantizer, Callback &&callback) -> std::vector { - PostingBuilder score_builder(writer); + PostingBuilder score_builder(writer); score_builder.write_header(os); std::for_each(boost::counting_iterator(0), boost::counting_iterator(index.num_terms()), [&](auto term) { - for_each(index.scoring_cursor(term, scorer), - [&](auto &cursor) { score_builder.accumulate(cursor.payload()); }); + for_each(index.scoring_cursor(term, scorer), [&](auto &cursor) { + score_builder.accumulate(quantizer(cursor.payload())); + }); score_builder.flush_segment(os); callback(); }); @@ -218,7 +225,7 @@ template auto score_index(Index const &index, std::basic_ostream &os, Writer writer, Scorer scorer) -> std::vector { - PostingBuilder score_builder(writer); + PostingBuilder score_builder(writer); score_builder.write_header(os); std::for_each(boost::counting_iterator(0), boost::counting_iterator(index.num_terms()), diff --git a/include/pisa/v1/io.hpp b/include/pisa/v1/io.hpp index 3b8b6bcb5..1409714d7 100644 --- a/include/pisa/v1/io.hpp +++ b/include/pisa/v1/io.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include diff --git a/include/pisa/v1/raw_cursor.hpp b/include/pisa/v1/raw_cursor.hpp index 389a17f8d..53b1cdd02 100644 --- a/include/pisa/v1/raw_cursor.hpp +++ b/include/pisa/v1/raw_cursor.hpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -26,6 +27,14 @@ template return tl::make_optional(cursor.value()); } +template +inline void contract(bool condition, std::string const &message, Args &&... 
args) +{ + if (not condition) { + throw std::logic_error(fmt::format(message, std::forward(args)...)); + } +} + /// Uncompressed example of implementation of a single value cursor. template struct RawCursor { @@ -35,8 +44,11 @@ struct RawCursor { /// Creates a cursor from the encoded bytes. explicit constexpr RawCursor(gsl::span bytes) : m_bytes(bytes.subspan(4)) { - Expects(m_bytes.size() % sizeof(T) == 0); - Expects(not m_bytes.empty()); + contract(m_bytes.size() % sizeof(T) == 0, + "Raw cursor memory size must be multiplier of element size ({}) but is {}", + sizeof(T), + m_bytes.size()); + contract(not m_bytes.empty(), "Raw cursor memory must not be empty"); } /// Dereferences the current value. @@ -98,7 +110,7 @@ struct RawReader { return RawCursor(bytes); } - constexpr static auto encoding() -> std::uint32_t { return EncodingId::Raw; } + constexpr static auto encoding() -> std::uint32_t { return EncodingId::Raw + sizeof(T); } }; template @@ -106,7 +118,7 @@ struct RawWriter { static_assert(std::is_trivially_copyable::value); using value_type = T; - constexpr static auto encoding() -> std::uint32_t { return EncodingId::Raw; } + constexpr static auto encoding() -> std::uint32_t { return EncodingId::Raw + sizeof(T); } void push(T const &posting) { m_postings.push_back(posting); } void push(T &&posting) { m_postings.push_back(posting); } @@ -132,6 +144,7 @@ template struct CursorTraits> { using Writer = RawWriter; using Reader = RawReader; + constexpr static auto encoding() -> std::uint32_t { return EncodingId::Raw + sizeof(T); } }; } // namespace pisa::v1 diff --git a/v1/postings.cpp b/v1/postings.cpp index ca3ec3cde..13fa4e3ff 100644 --- a/v1/postings.cpp +++ b/v1/postings.cpp @@ -9,6 +9,7 @@ #include "io.hpp" #include "query/queries.hpp" #include "topk_queue.hpp" +#include "v1/blocked_cursor.hpp" #include "v1/index_metadata.hpp" #include "v1/query.hpp" #include "v1/raw_cursor.hpp" @@ -18,12 +19,20 @@ using pisa::Query; using pisa::resolve_query_parser; +using 
pisa::v1::BlockedReader; using pisa::v1::index_runner; using pisa::v1::IndexMetadata; using pisa::v1::RawReader; using pisa::v1::resolve_ini; -auto default_readers() { return std::make_tuple(RawReader{}, RawReader{}); } +auto default_readers() +{ + return std::make_tuple(RawReader{}, + RawReader{}, + RawReader{}, + BlockedReader<::pisa::simdbp_block, true>{}, + BlockedReader<::pisa::simdbp_block, false>{}); +} [[nodiscard]] auto load_source(std::optional const &file) -> std::shared_ptr @@ -69,6 +78,7 @@ int main(int argc, char **argv) bool did = false; bool print_frequencies = false; bool print_scores = false; + bool precomputed = false; CLI::App app{"Queries a v1 index."}; app.add_option("-i,--index", @@ -83,7 +93,8 @@ int main(int argc, char **argv) app.add_flag("--tid", tid, "Use term IDs instead of terms"); app.add_flag("--did", did, "Print document IDs instead of titles"); app.add_flag("-f,--frequencies", print_frequencies, "Print frequencies"); - app.add_flag("-s,--scores", print_scores, "Print BM25 scores"); + auto *scores_option = app.add_flag("-s,--scores", print_scores, "Print BM25 scores"); + app.add_flag("--precomputed", precomputed, "Use BM25 precomputed scores")->needs(scores_option); app.add_option("query", query_input, "List of terms", false)->required(); CLI11_PARSE(app, argc, argv); @@ -117,26 +128,45 @@ int main(int argc, char **argv) }(); if (query.terms.size() == 1) { - auto run = index_runner(meta, default_readers()); - run([&](auto &&index) { - auto bm25 = make_bm25(index); - auto scorer = bm25.term_scorer(query.terms.front()); - auto print = [&](auto &&cursor) { - if (did) { - std::cout << *cursor; - } else { - std::cout << docmap.value()[*cursor]; - } - if (print_frequencies) { - std::cout << " " << cursor.payload(); - } - if (print_scores) { - std::cout << " " << scorer(cursor.value(), cursor.payload()); - } - std::cout << '\n'; - }; - for_each(index.cursor(query.terms.front()), print); - }); + if (precomputed) { + auto run = 
scored_index_runner(meta, default_readers()); + run([&](auto &&index) { + auto print = [&](auto &&cursor) { + if (did) { + std::cout << *cursor; + } else { + std::cout << docmap.value()[*cursor]; + } + if constexpr (std::is_same_v) { + std::cout << " " << static_cast(cursor.payload()) << '\n'; + } else { + std::cout << " " << cursor.payload() << '\n'; + } + }; + for_each(index.cursor(query.terms.front()), print); + }); + } else { + auto run = index_runner(meta, default_readers()); + run([&](auto &&index) { + auto bm25 = make_bm25(index); + auto scorer = bm25.term_scorer(query.terms.front()); + auto print = [&](auto &&cursor) { + if (did) { + std::cout << *cursor; + } else { + std::cout << docmap.value()[*cursor]; + } + if (print_frequencies) { + std::cout << " " << cursor.payload(); + } + if (print_scores) { + std::cout << " " << scorer(cursor.value(), cursor.payload()); + } + std::cout << '\n'; + }; + for_each(index.cursor(query.terms.front()), print); + }); + } } else { std::cerr << "Multiple terms unimplemented"; std::exit(1); diff --git a/v1/query.cpp b/v1/query.cpp index 36fee7f80..4e2d50855 100644 --- a/v1/query.cpp +++ b/v1/query.cpp @@ -134,7 +134,7 @@ int main(int argc, char **argv) if (precomputed) { auto run = scored_index_runner(meta, RawReader{}, - RawReader{}, + RawReader{}, BlockedReader<::pisa::simdbp_block, true>{}, BlockedReader<::pisa::simdbp_block, false>{}); run([&](auto &&index) { diff --git a/v1/score.cpp b/v1/score.cpp index a1f9a8038..02ab5c8d1 100644 --- a/v1/score.cpp +++ b/v1/score.cpp @@ -19,11 +19,14 @@ using pisa::v1::ProgressStatus; using pisa::v1::RawReader; using pisa::v1::RawWriter; using pisa::v1::resolve_ini; +using pisa::v1::TermId; using pisa::v1::write_span; int main(int argc, char **argv) { std::optional ini{}; + int bytes_per_score = 1; + std::size_t threads = std::thread::hardware_concurrency(); CLI::App app{"Scores v1 index."}; app.add_option("-i,--index", @@ -31,6 +34,10 @@ int main(int argc, char **argv) "Path of .ini 
file of an index " "(if not provided, it will be looked for in the current directory)", false); + app.add_option("-j,--threads", threads, "Number of threads"); + // TODO(michal): enable + // app.add_option( + // "-b,--bytes-per-score", ini, "Quantize computed scores to this many bytes", true); CLI11_PARSE(app, argc, argv); auto resolved_ini = resolve_ini(ini); @@ -45,11 +52,45 @@ int main(int argc, char **argv) auto postings_path = fmt::format("{}.bm25", index_basename); auto offsets_path = fmt::format("{}.bm25_offsets", index_basename); run([&](auto &&index) { + ProgressStatus calc_max_status(index.num_terms(), + DefaultProgress("Calculating max partial score"), + std::chrono::milliseconds(100)); + std::vector max_scores(threads, 0.0F); + tbb::task_group group; + auto batch_size = index.num_terms() / threads; + for (auto thread_id = 0; thread_id < threads; thread_id += 1) { + auto first_term = thread_id * batch_size; + auto end_term = + thread_id < threads - 1 ? (thread_id + 1) * batch_size : index.num_terms(); + std::for_each( + boost::counting_iterator(first_term), + boost::counting_iterator(end_term), + [&](auto term) { + for_each(index.scoring_cursor(term, make_bm25(index)), [&](auto &cursor) { + max_scores[thread_id] = std::max(max_scores[thread_id], cursor.payload()); + }); + calc_max_status += 1; + }); + } + group.wait(); + auto max_score = *std::max_element(max_scores.begin(), max_scores.end()); + std::cerr << fmt::format("Max partial score is: {}. 
It will be scaled to {}.", + max_score, + std::numeric_limits::max()); + ProgressStatus status( index.num_terms(), DefaultProgress("Scoring"), std::chrono::milliseconds(100)); std::ofstream score_file_stream(postings_path); auto offsets = score_index( - index, score_file_stream, RawWriter{}, make_bm25(index), [&]() { status += 1; }); + index, + score_file_stream, + RawWriter{}, + make_bm25(index), + [&](float score) { + return static_cast(score * std::numeric_limits::max() + / max_score); + }, + [&]() { status += 1; }); write_span(gsl::span(offsets), offsets_path); }); meta.scores.push_back(PostingFilePaths{.postings = postings_path, .offsets = offsets_path}); From c4e3e35a444b3dff9a51dca797024c0dea366a8e Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Wed, 13 Nov 2019 09:51:19 -0500 Subject: [PATCH 19/56] Add yaml-cpp dependency --- .gitmodules | 3 +++ external/yaml-cpp | 1 + 2 files changed, 4 insertions(+) create mode 160000 external/yaml-cpp diff --git a/.gitmodules b/.gitmodules index 8e4d1212a..14cc7eac9 100644 --- a/.gitmodules +++ b/.gitmodules @@ -73,3 +73,6 @@ [submodule "external/expected"] path = external/expected url = https://github.com/TartanLlama/expected.git +[submodule "external/yaml-cpp"] + path = external/yaml-cpp + url = https://github.com/jbeder/yaml-cpp.git diff --git a/external/yaml-cpp b/external/yaml-cpp new file mode 160000 index 000000000..72f699f5c --- /dev/null +++ b/external/yaml-cpp @@ -0,0 +1 @@ +Subproject commit 72f699f5ce2d22a8f70cd59d73ce2fbb42dd960e From 66b3dd447a978e52261ddb5be134a340d7b5a8ec Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Wed, 13 Nov 2019 14:04:50 -0500 Subject: [PATCH 20/56] Creating bigram index from query terms --- .clang-format | 2 +- CMakeLists.txt | 1 + external/CMakeLists.txt | 5 + include/pisa/v1/blocked_cursor.hpp | 6 +- include/pisa/v1/cursor_intersection.hpp | 10 ++ include/pisa/v1/index.hpp | 8 +- include/pisa/v1/index_builder.hpp | 47 ++++--- include/pisa/v1/index_metadata.hpp | 24 
++-- include/pisa/v1/zip_cursor.hpp | 43 ++++++ src/v1/index_builder.cpp | 6 +- src/v1/index_metadata.cpp | 118 +++++++++++----- src/v1/progress_status.cpp | 6 +- v1/CMakeLists.txt | 3 + v1/bigram_index.cpp | 171 ++++++++++++++++++++++++ v1/postings.cpp | 34 ++--- v1/query.cpp | 32 ++--- v1/score.cpp | 24 ++-- 17 files changed, 419 insertions(+), 121 deletions(-) create mode 100644 include/pisa/v1/zip_cursor.hpp create mode 100644 v1/bigram_index.cpp diff --git a/.clang-format b/.clang-format index 90a245f4c..21f3a956f 100644 --- a/.clang-format +++ b/.clang-format @@ -34,7 +34,7 @@ PenaltyBreakFirstLessLess: 20 PenaltyBreakString: 1000 PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 200 -PointerAlignment: Right +PointerAlignment: Left SpaceAfterControlStatementKeyword: true SpaceBeforeAssignmentOperators: true SpaceInEmptyParentheses: false diff --git a/CMakeLists.txt b/CMakeLists.txt index 72807a1e1..9faabf7de 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -99,6 +99,7 @@ target_link_libraries(pisa PUBLIC fmt::fmt range-v3 optional + yaml-cpp ) target_include_directories(pisa PUBLIC external) diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt index f40eb3ee7..865ec1a94 100644 --- a/external/CMakeLists.txt +++ b/external/CMakeLists.txt @@ -121,6 +121,11 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/range-v3) set(OPTIONAL_ENABLE_TESTS OFF CACHE BOOL "skip tl::optional testing") add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/optional EXCLUDE_FROM_ALL) +# Add yaml-cpp +set(YAML_CPP_BUILD_TOOLS OFF CACHE BOOL "skip building YAML tools") +set(YAML_CPP_BUILD_TESTS OFF CACHE BOOL "skip building YAML tests") +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/yaml-cpp EXCLUDE_FROM_ALL) + # Add tl::expected #set(EXPECTED_BUILD_TESTS OFF CACHE BOOL "skip tl::expected testing") #set(EXPECTED_BUILD_PACKAGE OFF CACHE BOOL "skip tl::expected package") diff --git a/include/pisa/v1/blocked_cursor.hpp b/include/pisa/v1/blocked_cursor.hpp index 
165253f2b..f1dd96479 100644 --- a/include/pisa/v1/blocked_cursor.hpp +++ b/include/pisa/v1/blocked_cursor.hpp @@ -116,7 +116,11 @@ struct BlockedCursor { /// Moves the cursor to the next value equal or greater than `value`. constexpr void advance_to_geq(value_type value) { - static_assert(DeltaEncoded, "Cannot call advance_to_geq on a not delta-encoded list"); + //static_assert(DeltaEncoded, "Cannot call advance_to_geq on a not delta-encoded list"); + // TODO(michal): This should be `static_assert` like above. But currently, + // it would not compile. What needs to be done is separating document + // and payload readers for the index runner. + assert(DeltaEncoded); Expects(value >= m_current_value || position() == 0); if (PISA_UNLIKELY(value > m_current_block.last_value)) { if (value > m_block_last_values.back()) { diff --git a/include/pisa/v1/cursor_intersection.hpp b/include/pisa/v1/cursor_intersection.hpp index de10198c3..ccb8812c7 100644 --- a/include/pisa/v1/cursor_intersection.hpp +++ b/include/pisa/v1/cursor_intersection.hpp @@ -121,4 +121,14 @@ template std::move(cursors), std::move(init), std::move(accumulate)); } +template +[[nodiscard]] constexpr inline auto intersect(std::initializer_list cursors, + Payload init, + AccumulateFn accumulate) +{ + std::vector cursor_container(cursors); + return CursorIntersection, Payload, AccumulateFn>( + std::move(cursor_container), std::move(init), std::move(accumulate)); +} + } // namespace pisa::v1 diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index 7006d5666..33cc4fccf 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -45,6 +45,10 @@ namespace pisa::v1 { /// `Score`, or `std::pair` for a bigram scored index. template struct Index { + + using document_cursor_type = DocumentCursor; + using payload_cursor_type = PayloadCursor; + /// Constructs the index. /// /// \param document_reader Reads document posting lists from bytes. 
@@ -122,9 +126,9 @@ struct Index { } /// Constructs a new payload cursor. - [[nodiscard]] auto num_terms() const -> std::uint32_t { return m_document_offsets.size() - 1; } + [[nodiscard]] auto num_terms() const -> std::size_t { return m_document_offsets.size() - 1; } - [[nodiscard]] auto num_documents() const -> std::uint32_t { return m_document_lengths.size(); } + [[nodiscard]] auto num_documents() const -> std::size_t { return m_document_lengths.size(); } [[nodiscard]] auto term_posting_count(TermId term) const -> std::uint32_t { diff --git a/include/pisa/v1/index_builder.hpp b/include/pisa/v1/index_builder.hpp index 14f458d7b..4821d9e29 100644 --- a/include/pisa/v1/index_builder.hpp +++ b/include/pisa/v1/index_builder.hpp @@ -1,13 +1,13 @@ #pragma once #include +#include -#include -#include #include #include #include #include +#include #include "v1/index.hpp" #include "v1/index_metadata.hpp" @@ -22,7 +22,7 @@ struct IndexBuilder { template void operator()(Encoding document_encoding, Encoding payload_encoding, Fn fn) { - auto run = [&](auto &&dwriter, auto &&pwriter) -> bool { + auto run = [&](auto&& dwriter, auto&& pwriter) -> bool { if (std::decay_t::encoding() == document_encoding && std::decay_t::encoding() == payload_encoding) { fn(dwriter, pwriter); @@ -58,11 +58,11 @@ auto make_index_builder(Writers... 
writers) template auto compress_batch(CollectionIterator first, CollectionIterator last, - std::ofstream &dout, - std::ofstream &fout, + std::ofstream& dout, + std::ofstream& fout, Writer document_writer, Writer frequency_writer, - tl::optional bar) + tl::optional bar) -> std::tuple, std::vector> { PostingBuilder document_builder(std::move(document_writer)); @@ -85,14 +85,14 @@ auto compress_batch(CollectionIterator first, } template -void write_span(gsl::span offsets, std::string const &file) +void write_span(gsl::span offsets, std::string const& file) { std::ofstream os(file); auto bytes = gsl::as_bytes(offsets); - os.write(reinterpret_cast(bytes.data()), bytes.size()); + os.write(reinterpret_cast(bytes.data()), bytes.size()); } -inline void compress_binary_collection(std::string const &input, +inline void compress_binary_collection(std::string const& input, std::string_view fwd, std::string_view output, std::size_t const threads, @@ -144,8 +144,8 @@ inline void compress_binary_collection(std::string const &input, } return std::next(collection.begin(), (thread_idx + 1) * batch_size); }(); - auto &dout = document_streams[thread_idx]; - auto &fout = frequency_streams[thread_idx]; + auto& dout = document_streams[thread_idx]; + auto& fout = frequency_streams[thread_idx]; std::tie(document_offsets[thread_idx], frequency_offsets[thread_idx]) = compress_batch(first, last, @@ -153,7 +153,7 @@ inline void compress_binary_collection(std::string const &input, fout, document_writer, frequency_writer, - tl::make_optional(status)); + tl::make_optional(status)); }); }); group.wait(); @@ -210,20 +210,19 @@ inline void compress_binary_collection(std::string const &input, float avg_len = calc_avg_length(gsl::span(lengths)); std::cerr << " Done.\n"; - boost::property_tree::ptree pt; - pt.put("documents.file", documents_file); - pt.put("documents.offsets", doc_offset_file); - pt.put("frequencies.file", frequencies_file); - pt.put("frequencies.offsets", freq_offset_file); - 
pt.put("stats.avg_document_length", avg_len); - pt.put("stats.document_lengths", document_lengths_file); - pt.put("lexicon.stemmer", "porter2"); // TODO(michal): Parametrize - pt.put("lexicon.terms", fmt::format("{}.termlex", fwd)); - pt.put("lexicon.documents", fmt::format("{}.doclex", fwd)); - boost::property_tree::write_ini(fmt::format("{}.ini", output), pt); + IndexMetadata{ + .documents = PostingFilePaths{.postings = documents_file, .offsets = doc_offset_file}, + .frequencies = PostingFilePaths{.postings = frequencies_file, .offsets = freq_offset_file}, + .scores = {}, + .document_lengths_path = document_lengths_file, + .avg_document_length = avg_len, + .term_lexicon = tl::make_optional(fmt::format("{}.termlex", fwd)), + .document_lexicon = tl::make_optional(fmt::format("{}.doclex", fwd)), + .stemmer = tl::make_optional("porter2")} + .write(fmt::format("{}.yml", output)); } -auto verify_compressed_index(std::string const &input, std::string_view output) +auto verify_compressed_index(std::string const& input, std::string_view output) -> std::vector; } // namespace pisa::v1 diff --git a/include/pisa/v1/index_metadata.hpp b/include/pisa/v1/index_metadata.hpp index e3c3e440d..2568008d4 100644 --- a/include/pisa/v1/index_metadata.hpp +++ b/include/pisa/v1/index_metadata.hpp @@ -13,9 +13,9 @@ namespace pisa::v1 { /// Return the passed file path if is not `nullopt`. -/// Otherwise, look for an `.ini` file in the current directory. -/// It will throw if no `.ini` file is found or there are multiple `.ini` files. -[[nodiscard]] auto resolve_ini(std::optional const &arg) -> std::string; +/// Otherwise, look for an `.yml` file in the current directory. +/// It will throw if no `.yml` file is found or there are multiple `.yml` files. 
+[[nodiscard]] auto resolve_yml(std::optional const& arg) -> std::string; template [[nodiscard]] auto convert_optional(Optional opt) @@ -40,6 +40,13 @@ struct PostingFilePaths { std::string offsets; }; +struct BigramMetadata { + PostingFilePaths documents; + std::pair frequencies; + std::string mapping; + std::size_t count; +}; + struct IndexMetadata { PostingFilePaths documents; PostingFilePaths frequencies; @@ -49,20 +56,21 @@ struct IndexMetadata { tl::optional term_lexicon{}; tl::optional document_lexicon{}; tl::optional stemmer{}; + tl::optional bigrams{}; - void write(std::string const &file); - [[nodiscard]] static auto from_file(std::string const &file) -> IndexMetadata; + void write(std::string const& file); + [[nodiscard]] static auto from_file(std::string const& file) -> IndexMetadata; }; template -[[nodiscard]] auto to_span(mio::mmap_source const *mmap) +[[nodiscard]] auto to_span(mio::mmap_source const* mmap) { static_assert(std::is_trivially_constructible_v); - return gsl::span(reinterpret_cast(mmap->data()), mmap->size() / sizeof(T)); + return gsl::span(reinterpret_cast(mmap->data()), mmap->size() / sizeof(T)); }; template -[[nodiscard]] auto source_span(MMapSource &source, std::string const &file) +[[nodiscard]] auto source_span(MMapSource& source, std::string const& file) { return to_span( source.file_sources.emplace_back(std::make_shared(file)).get()); diff --git a/include/pisa/v1/zip_cursor.hpp b/include/pisa/v1/zip_cursor.hpp new file mode 100644 index 000000000..e9e04b511 --- /dev/null +++ b/include/pisa/v1/zip_cursor.hpp @@ -0,0 +1,43 @@ +#pragma once + +#include + +#include + +namespace pisa::v1 { + +template +struct ZipCursor { + using Value = std::tuple())>; + + explicit constexpr ZipCursor(Cursors... cursors) : m_cursors(std::move(cursors...)) {} + + [[nodiscard]] constexpr auto operator*() const -> Value { return value(); } + [[nodiscard]] constexpr auto value() const noexcept -> Value + { + auto deref = [](auto... 
cursors) { return std::make_tuple(cursors.value()...); }; + return std::apply(deref, m_cursors); + } + constexpr void advance() + { + auto advance_all = [](auto... cursors) { (cursors.advance(), ...); }; + std::apply(advance_all, m_cursors); + } + constexpr void advance_to_position(std::size_t pos) + { + auto advance_all = [pos](auto... cursors) { (cursors.advance_to_position(pos), ...); }; + std::apply(advance_all, m_cursors); + } + //[[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_key_cursor.empty(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return std::get<0>(m_cursors).position(); + } + //[[nodiscard]] constexpr auto size() const -> std::size_t { return m_key_cursor.size(); } + //[[nodiscard]] constexpr auto sentinel() const -> Document { return m_key_cursor.sentinel(); } + + private: + std::tuple m_cursors; +}; + +} // namespace pisa::v1 diff --git a/src/v1/index_builder.cpp b/src/v1/index_builder.cpp index f3b804491..3bb1d051c 100644 --- a/src/v1/index_builder.cpp +++ b/src/v1/index_builder.cpp @@ -5,19 +5,19 @@ namespace pisa::v1 { -auto verify_compressed_index(std::string const &input, std::string_view output) +auto verify_compressed_index(std::string const& input, std::string_view output) -> std::vector { std::vector errors; pisa::binary_freq_collection const collection(input.c_str()); - auto meta = IndexMetadata::from_file(fmt::format("{}.ini", output)); + auto meta = IndexMetadata::from_file(fmt::format("{}.yml", output)); auto run = index_runner(meta, RawReader{}, BlockedReader<::pisa::simdbp_block, true>{}, BlockedReader<::pisa::simdbp_block, false>{}); ProgressStatus status( collection.size(), DefaultProgress("Verifying"), std::chrono::milliseconds(100)); - run([&](auto &&index) { + run([&](auto&& index) { auto sequence_iter = collection.begin(); for (auto term = 0; term < index.num_terms(); term += 1, ++sequence_iter) { auto document_sequence = sequence_iter->docs; diff --git 
a/src/v1/index_metadata.cpp b/src/v1/index_metadata.cpp index cdcd626c4..754bc9b04 100644 --- a/src/v1/index_metadata.cpp +++ b/src/v1/index_metadata.cpp @@ -1,65 +1,115 @@ +#include #include #include -#include -#include +#include #include "v1/index_metadata.hpp" namespace pisa::v1 { -[[nodiscard]] auto resolve_ini(std::optional const &arg) -> std::string +constexpr char const* DOCUMENTS = "documents"; +constexpr char const* FREQUENCIES = "frequencies"; +constexpr char const* SCORES = "scores"; +constexpr char const* POSTINGS = "file"; +constexpr char const* OFFSETS = "offsets"; +constexpr char const* STATS = "stats"; +constexpr char const* LEXICON = "lexicon"; +constexpr char const* TERMS = "terms"; +constexpr char const* BIGRAM = "bigram"; + +[[nodiscard]] auto resolve_yml(std::optional const& arg) -> std::string { if (arg) { return *arg; } - throw std::runtime_error("Resolving .ini from the current folder not supported yet!"); + throw std::runtime_error("Resolving .yml from the current folder not supported yet!"); } -[[nodiscard]] auto IndexMetadata::from_file(std::string const &file) -> IndexMetadata +[[nodiscard]] auto IndexMetadata::from_file(std::string const& file) -> IndexMetadata { - boost::property_tree::ptree pt; - boost::property_tree::ini_parser::read_ini(file, pt); + YAML::Node config = YAML::LoadFile(file); std::vector scores; - if (pt.count("scores") > 0U) { - scores.push_back(PostingFilePaths{.postings = pt.get("scores.file"), - .offsets = pt.get("scores.offsets")}); + if (config[SCORES]) { + // TODO(michal): Once switched to YAML, parse an array. 
+ scores.push_back(PostingFilePaths{.postings = config[SCORES][POSTINGS].as(), + .offsets = config[SCORES][OFFSETS].as()}); } return IndexMetadata{ - .documents = PostingFilePaths{.postings = pt.get("documents.file"), - .offsets = pt.get("documents.offsets")}, - .frequencies = PostingFilePaths{.postings = pt.get("frequencies.file"), - .offsets = pt.get("frequencies.offsets")}, - // TODO(michal): Once switched to YAML, parse an array. + .documents = PostingFilePaths{.postings = config[DOCUMENTS][POSTINGS].as(), + .offsets = config[DOCUMENTS][OFFSETS].as()}, + .frequencies = PostingFilePaths{.postings = config[FREQUENCIES][POSTINGS].as(), + .offsets = config[FREQUENCIES][OFFSETS].as()}, .scores = std::move(scores), - .document_lengths_path = pt.get("stats.document_lengths"), - .avg_document_length = pt.get("stats.avg_document_length"), - .term_lexicon = convert_optional(pt.get_optional("lexicon.terms")), - .document_lexicon = convert_optional(pt.get_optional("lexicon.documents")), - .stemmer = convert_optional(pt.get_optional("lexicon.stemmer"))}; + .document_lengths_path = config[STATS]["document_lengths"].as(), + .avg_document_length = config[STATS]["avg_document_length"].as(), + .term_lexicon = [&]() -> tl::optional { + if (config[LEXICON][TERMS]) { + return config[LEXICON][TERMS].as(); + } + return tl::nullopt; + }(), + .document_lexicon = [&]() -> tl::optional { + if (config[LEXICON][DOCUMENTS]) { + return config[LEXICON][DOCUMENTS].as(); + } + return tl::nullopt; + }(), + .stemmer = [&]() -> tl::optional { + if (config[LEXICON]["stemmer"]) { + return config[LEXICON]["stemmer"].as(); + } + return tl::nullopt; + }(), + .bigrams = [&]() -> tl::optional { + if (config[BIGRAM]) { + return BigramMetadata{ + .documents = {.postings = config[DOCUMENTS][POSTINGS].as(), + .offsets = config[DOCUMENTS][OFFSETS].as()}, + .frequencies = + {{.postings = config["frequencies_0"][POSTINGS].as(), + .offsets = config["frequencies_0"][OFFSETS].as()}, + {.postings = 
config["frequencies_1"][POSTINGS].as(), + .offsets = config["frequencies_1"][OFFSETS].as()}}, + .mapping = config[BIGRAM]["mapping"].as(), + .count = config[BIGRAM]["count"].as()}; + } + return tl::nullopt; + }()}; } -void IndexMetadata::write(std::string const &file) +void IndexMetadata::write(std::string const& file) { - boost::property_tree::ptree pt; - pt.put("documents.file", documents.postings); - pt.put("documents.offsets", documents.offsets); - pt.put("frequencies.file", frequencies.postings); - pt.put("frequencies.offsets", frequencies.offsets); - pt.put("stats.avg_document_length", avg_document_length); - pt.put("stats.document_lengths", document_lengths_path); - pt.put("lexicon.stemmer", "porter2"); // TODO(michal): Parametrize + YAML::Node root; + root[DOCUMENTS][POSTINGS] = documents.postings; + root[DOCUMENTS][OFFSETS] = documents.offsets; + root[FREQUENCIES][POSTINGS] = frequencies.postings; + root[FREQUENCIES][OFFSETS] = frequencies.offsets; + root[STATS]["avg_document_length"] = avg_document_length; + root[STATS]["document_lengths"] = document_lengths_path; + root[LEXICON]["stemmer"] = "porter2"; if (not scores.empty()) { - pt.put("scores.file", scores.front().postings); - pt.put("scores.offsets", scores.front().offsets); + root[SCORES][POSTINGS] = scores.front().postings; + root[SCORES][OFFSETS] = scores.front().offsets; } if (term_lexicon) { - pt.put("lexicon.terms", *term_lexicon); + root[LEXICON][TERMS] = *term_lexicon; } if (document_lexicon) { - pt.put("lexicon.documents", *document_lexicon); + root[LEXICON][DOCUMENTS] = *document_lexicon; + } + if (bigrams) { + root[BIGRAM][DOCUMENTS][POSTINGS] = bigrams->documents.postings; + root[BIGRAM][DOCUMENTS][OFFSETS] = bigrams->documents.offsets; + root[BIGRAM]["frequencies_0"][POSTINGS] = bigrams->frequencies.first.postings; + root[BIGRAM]["frequencies_0"][OFFSETS] = bigrams->frequencies.first.offsets; + root[BIGRAM]["frequencies_1"][POSTINGS] = bigrams->frequencies.second.postings; + 
root[BIGRAM]["frequencies_1"][OFFSETS] = bigrams->frequencies.second.offsets; + root[BIGRAM]["mapping"] = bigrams->mapping; + root[BIGRAM]["count"] = bigrams->count; } - boost::property_tree::write_ini(file, pt); + std::ofstream fout(file); + fout << root; } } // namespace pisa::v1 diff --git a/src/v1/progress_status.cpp b/src/v1/progress_status.cpp index 10ad0e0ab..5bef98191 100644 --- a/src/v1/progress_status.cpp +++ b/src/v1/progress_status.cpp @@ -32,9 +32,9 @@ void DefaultProgress::operator()(std::size_t count, std::chrono::time_point start) { size_t progress = (100 * count) / goal; - if (progress == m_previous) { - return; - } + // if (progress == m_previous) { + // return; + //} m_previous = progress; std::chrono::seconds elapsed = std::chrono::duration_cast(std::chrono::steady_clock::now() - start); diff --git a/v1/CMakeLists.txt b/v1/CMakeLists.txt index 2812c9e4c..4723e9cf3 100644 --- a/v1/CMakeLists.txt +++ b/v1/CMakeLists.txt @@ -9,3 +9,6 @@ target_link_libraries(postings pisa CLI11) add_executable(score score.cpp) target_link_libraries(score pisa CLI11) + +add_executable(bigram_index bigram_index.cpp) +target_link_libraries(bigram_index pisa CLI11) diff --git a/v1/bigram_index.cpp b/v1/bigram_index.cpp new file mode 100644 index 000000000..beffda3c7 --- /dev/null +++ b/v1/bigram_index.cpp @@ -0,0 +1,171 @@ +#include +#include +#include + +#include +#include + +#include "io.hpp" +#include "payload_vector.hpp" +#include "query/queries.hpp" +#include "timer.hpp" +#include "topk_queue.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/cursor_intersection.hpp" +#include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/query.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/scorer/bm25.hpp" +#include "v1/scorer/runner.hpp" +#include "v1/types.hpp" + +using pisa::Payload_Vector_Buffer; +using pisa::v1::BigramMetadata; +using pisa::v1::BlockedReader; +using pisa::v1::CursorTraits; +using pisa::v1::DefaultProgress; +using pisa::v1::DocId; 
+using pisa::v1::Frequency; +using pisa::v1::index_runner; +using pisa::v1::IndexMetadata; +using pisa::v1::intersect; +using pisa::v1::PostingBuilder; +using pisa::v1::ProgressStatus; +using pisa::v1::RawReader; +using pisa::v1::resolve_yml; +using pisa::v1::TermId; +using pisa::v1::write_span; + +auto collect_unique_bigrams(std::vector const& queries) + -> std::vector> +{ + std::vector> bigrams; + for (auto&& query : queries) { + for (auto left = 0; left < query.terms.size(); left += 1) { + auto right = left + 1; + bigrams.emplace_back(query.terms[left], query.terms[right]); + } + } + std::sort(bigrams.begin(), bigrams.end()); + bigrams.erase(std::unique(bigrams.begin(), bigrams.end()), bigrams.end()); + return bigrams; +} + +int main(int argc, char** argv) +{ + std::optional yml{}; + std::optional query_file{}; + std::optional terms_file{}; + + CLI::App app{"Creates a v1 bigram index."}; + app.add_option("-i,--index", + yml, + "Path of .yml file of an index " + "(if not provided, it will be looked for in the current directory)", + false); + app.add_option("-q,--query", query_file, "Path to file with queries", false); + app.add_option("--terms", terms_file, "Overrides document lexicon from .yml (if defined)."); + CLI11_PARSE(app, argc, argv); + + auto resolved_yml = resolve_yml(yml); + auto meta = IndexMetadata::from_file(resolved_yml); + auto stemmer = meta.stemmer ? 
std::make_optional(*meta.stemmer) : std::optional{}; + if (meta.term_lexicon) { + terms_file = meta.term_lexicon.value(); + } + + std::vector queries; + auto parse_query = resolve_query_parser(queries, terms_file, std::nullopt, stemmer); + if (query_file) { + std::ifstream is(*query_file); + pisa::io::for_each_line(is, parse_query); + } else { + pisa::io::for_each_line(std::cin, parse_query); + } + + auto bigrams = collect_unique_bigrams(queries); + + auto index_basename = resolved_yml.substr(0, resolved_yml.size() - 4); + auto run = index_runner(meta, + RawReader{}, + BlockedReader<::pisa::simdbp_block, true>{}, + BlockedReader<::pisa::simdbp_block, false>{}); + + std::vector> pair_mapping; + auto documents_file = fmt::format("{}.bigram_documents", index_basename); + auto frequencies_file_0 = fmt::format("{}.bigram_frequencies_0", index_basename); + auto frequencies_file_1 = fmt::format("{}.bigram_frequencies_1", index_basename); + auto document_offsets_file = fmt::format("{}.bigram_document_offsets", index_basename); + auto frequency_offsets_file_0 = fmt::format("{}.bigram_frequency_offsets_0", index_basename); + auto frequency_offsets_file_1 = fmt::format("{}.bigram_frequency_offsets_1", index_basename); + std::ofstream document_out(documents_file); + std::ofstream frequency_out_0(frequencies_file_0); + std::ofstream frequency_out_1(frequencies_file_1); + + run([&](auto&& index) { + ProgressStatus status(bigrams.size(), + DefaultProgress("Building bigram index"), + std::chrono::milliseconds(100)); + using index_type = std::decay_t; + using document_writer_type = + typename CursorTraits::Writer; + using frequency_writer_type = + typename CursorTraits::Writer; + + PostingBuilder document_builder(document_writer_type{}); + PostingBuilder frequency_builder_0(frequency_writer_type{}); + PostingBuilder frequency_builder_1(frequency_writer_type{}); + + document_builder.write_header(document_out); + frequency_builder_0.write_header(frequency_out_0); + 
frequency_builder_1.write_header(frequency_out_1); + + for (auto [left_term, right_term] : bigrams) { + auto intersection = intersect({index.cursor(left_term), index.cursor(right_term)}, + std::array{0, 0}, + [](auto& payload, auto& cursor, auto list_idx) { + payload[list_idx] = cursor.payload(); + return payload; + }); + if (intersection.empty()) { + status += 1; + continue; + } + pair_mapping.push_back({left_term, right_term}); + for_each(intersection, [&](auto& cursor) { + document_builder.accumulate(*cursor); + auto payload = cursor.payload(); + frequency_builder_0.accumulate(std::get<0>(payload)); + frequency_builder_1.accumulate(std::get<1>(payload)); + }); + document_builder.flush_segment(document_out); + frequency_builder_0.flush_segment(frequency_out_0); + frequency_builder_1.flush_segment(frequency_out_1); + status += 1; + } + std::cerr << "Writing offsets..."; + write_span(gsl::make_span(document_builder.offsets()), document_offsets_file); + write_span(gsl::make_span(frequency_builder_0.offsets()), frequency_offsets_file_0); + write_span(gsl::make_span(frequency_builder_1.offsets()), frequency_offsets_file_1); + std::cerr << " Done.\n"; + }); + std::cerr << "Writing metadata..."; + meta.bigrams = BigramMetadata{ + .documents = {.postings = documents_file, .offsets = document_offsets_file}, + .frequencies = {{.postings = frequencies_file_0, .offsets = frequency_offsets_file_0}, + {.postings = frequencies_file_1, .offsets = frequency_offsets_file_1}}, + .mapping = fmt::format("{}.bigram_mapping", index_basename), + .count = pair_mapping.size()}; + meta.write(resolved_yml); + std::cerr << " Done.\nWriting bigram mapping..."; + Payload_Vector_Buffer::make(pair_mapping.begin(), + pair_mapping.end(), + [](auto&& terms, auto out) { + auto bytes = gsl::as_bytes(gsl::make_span(terms)); + std::copy(bytes.begin(), bytes.end(), out); + }) + .to_file(meta.bigrams->mapping); + std::cerr << " Done.\n"; + return 0; +} diff --git a/v1/postings.cpp b/v1/postings.cpp index 
13fa4e3ff..647bd1471 100644 --- a/v1/postings.cpp +++ b/v1/postings.cpp @@ -23,7 +23,7 @@ using pisa::v1::BlockedReader; using pisa::v1::index_runner; using pisa::v1::IndexMetadata; using pisa::v1::RawReader; -using pisa::v1::resolve_ini; +using pisa::v1::resolve_yml; auto default_readers() { @@ -34,7 +34,7 @@ auto default_readers() BlockedReader<::pisa::simdbp_block, false>{}); } -[[nodiscard]] auto load_source(std::optional const &file) +[[nodiscard]] auto load_source(std::optional const& file) -> std::shared_ptr { if (file) { @@ -43,7 +43,7 @@ auto default_readers() return nullptr; } -[[nodiscard]] auto load_payload_vector(std::shared_ptr const &source) +[[nodiscard]] auto load_payload_vector(std::shared_ptr const& source) -> std::optional> { if (source) { @@ -54,10 +54,10 @@ auto default_readers() /// Returns the first value (not nullopt), or nullopt if no optional contains a value. template -[[nodiscard]] auto value(First &&first, Optional &&... candidtes) +[[nodiscard]] auto value(First&& first, Optional&&... 
candidtes) { std::optional> val = std::nullopt; - auto has_value = [&](auto &&opt) -> bool { + auto has_value = [&](auto&& opt) -> bool { if (not val.has_value() && opt) { val = *opt; return true; @@ -68,9 +68,9 @@ template return val; } -int main(int argc, char **argv) +int main(int argc, char** argv) { - std::optional ini{}; + std::optional yml{}; std::optional terms_file{}; std::optional documents_file{}; std::string query_input{}; @@ -82,23 +82,23 @@ int main(int argc, char **argv) CLI::App app{"Queries a v1 index."}; app.add_option("-i,--index", - ini, - "Path of .ini file of an index " + yml, + "Path of .yml file of an index " "(if not provided, it will be looked for in the current directory)", false); - app.add_option("--terms", terms_file, "Overrides document lexicon from .ini (if defined)."); + app.add_option("--terms", terms_file, "Overrides document lexicon from .yml (if defined)."); app.add_option("--documents", documents_file, - "Overrides document lexicon from .ini (if defined). Required otherwise."); + "Overrides document lexicon from .yml (if defined). Required otherwise."); app.add_flag("--tid", tid, "Use term IDs instead of terms"); app.add_flag("--did", did, "Print document IDs instead of titles"); app.add_flag("-f,--frequencies", print_frequencies, "Print frequencies"); - auto *scores_option = app.add_flag("-s,--scores", print_scores, "Print BM25 scores"); + auto* scores_option = app.add_flag("-s,--scores", print_scores, "Print BM25 scores"); app.add_flag("--precomputed", precomputed, "Use BM25 precomputed scores")->needs(scores_option); app.add_option("query", query_input, "List of terms", false)->required(); CLI11_PARSE(app, argc, argv); - auto meta = IndexMetadata::from_file(resolve_ini(ini)); + auto meta = IndexMetadata::from_file(resolve_yml(yml)); auto stemmer = meta.stemmer ? 
std::make_optional(*meta.stemmer) : std::optional{}; if (tid) { terms_file = std::nullopt; @@ -130,8 +130,8 @@ int main(int argc, char **argv) if (query.terms.size() == 1) { if (precomputed) { auto run = scored_index_runner(meta, default_readers()); - run([&](auto &&index) { - auto print = [&](auto &&cursor) { + run([&](auto&& index) { + auto print = [&](auto&& cursor) { if (did) { std::cout << *cursor; } else { @@ -147,10 +147,10 @@ int main(int argc, char **argv) }); } else { auto run = index_runner(meta, default_readers()); - run([&](auto &&index) { + run([&](auto&& index) { auto bm25 = make_bm25(index); auto scorer = bm25.term_scorer(query.terms.front()); - auto print = [&](auto &&cursor) { + auto print = [&](auto&& cursor) { if (did) { std::cout << *cursor; } else { diff --git a/v1/query.cpp b/v1/query.cpp index 4e2d50855..03dd9f87c 100644 --- a/v1/query.cpp +++ b/v1/query.cpp @@ -24,18 +24,18 @@ using pisa::v1::daat_or; using pisa::v1::index_runner; using pisa::v1::IndexMetadata; using pisa::v1::RawReader; -using pisa::v1::resolve_ini; +using pisa::v1::resolve_yml; using pisa::v1::VoidScorer; template -void evaluate(std::vector const &queries, - Index &&index, - Scorer &&scorer, +void evaluate(std::vector const& queries, + Index&& index, + Scorer&& scorer, int k, - pisa::Payload_Vector<> const &docmap) + pisa::Payload_Vector<> const& docmap) { auto query_idx = 0; - for (auto const &query : queries) { + for (auto const& query : queries) { auto que = daat_or(pisa::v1::Query{query.terms}, index, pisa::topk_queue(k), scorer); que.finalize(); auto rank = 0; @@ -54,7 +54,7 @@ void evaluate(std::vector const &queries, } template -void benchmark(std::vector const &queries, Index &&index, Scorer &&scorer, int k) +void benchmark(std::vector const& queries, Index&& index, Scorer&& scorer, int k) { std::vector times(queries.size(), std::numeric_limits::max()); @@ -80,9 +80,9 @@ void benchmark(std::vector const &queries, Index &&index, Scorer && spdlog::info("95% 
quantile: {}", q95); } -int main(int argc, char **argv) +int main(int argc, char** argv) { - std::optional ini{}; + std::optional yml{}; std::optional query_file{}; std::optional terms_file{}; std::optional documents_file{}; @@ -92,21 +92,21 @@ int main(int argc, char **argv) CLI::App app{"Queries a v1 index."}; app.add_option("-i,--index", - ini, - "Path of .ini file of an index " + yml, + "Path of .yml file of an index " "(if not provided, it will be looked for in the current directory)", false); app.add_option("-q,--query", query_file, "Path to file with queries", false); app.add_option("-k", k, "The number of top results to return", true); - app.add_option("--terms", terms_file, "Overrides document lexicon from .ini (if defined)."); + app.add_option("--terms", terms_file, "Overrides document lexicon from .yml (if defined)."); app.add_option("--documents", documents_file, - "Overrides document lexicon from .ini (if defined). Required otherwise."); + "Overrides document lexicon from .yml (if defined). Required otherwise."); app.add_flag("--benchmark", is_benchmark, "Run benchmark"); app.add_flag("--precomputed", precomputed, "Use precomputed scores"); CLI11_PARSE(app, argc, argv); - auto meta = IndexMetadata::from_file(resolve_ini(ini)); + auto meta = IndexMetadata::from_file(resolve_yml(yml)); auto stemmer = meta.stemmer ? 
std::make_optional(*meta.stemmer) : std::optional{}; if (meta.term_lexicon) { terms_file = meta.term_lexicon.value(); @@ -137,7 +137,7 @@ int main(int argc, char **argv) RawReader{}, BlockedReader<::pisa::simdbp_block, true>{}, BlockedReader<::pisa::simdbp_block, false>{}); - run([&](auto &&index) { + run([&](auto&& index) { if (is_benchmark) { benchmark(queries, index, VoidScorer{}, k); } else { @@ -149,7 +149,7 @@ int main(int argc, char **argv) RawReader{}, BlockedReader<::pisa::simdbp_block, true>{}, BlockedReader<::pisa::simdbp_block, false>{}); - run([&](auto &&index) { + run([&](auto&& index) { auto with_scorer = scorer_runner(index, make_bm25(index)); with_scorer("bm25", [&](auto scorer) { if (is_benchmark) { diff --git a/v1/score.cpp b/v1/score.cpp index 02ab5c8d1..310d00ad6 100644 --- a/v1/score.cpp +++ b/v1/score.cpp @@ -18,40 +18,40 @@ using pisa::v1::PostingFilePaths; using pisa::v1::ProgressStatus; using pisa::v1::RawReader; using pisa::v1::RawWriter; -using pisa::v1::resolve_ini; +using pisa::v1::resolve_yml; using pisa::v1::TermId; using pisa::v1::write_span; -int main(int argc, char **argv) +int main(int argc, char** argv) { - std::optional ini{}; + std::optional yml{}; int bytes_per_score = 1; std::size_t threads = std::thread::hardware_concurrency(); CLI::App app{"Scores v1 index."}; app.add_option("-i,--index", - ini, - "Path of .ini file of an index " + yml, + "Path of .yml file of an index " "(if not provided, it will be looked for in the current directory)", false); app.add_option("-j,--threads", threads, "Number of threads"); // TODO(michal): enable // app.add_option( - // "-b,--bytes-per-score", ini, "Quantize computed scores to this many bytes", true); + // "-b,--bytes-per-score", yml, "Quantize computed scores to this many bytes", true); CLI11_PARSE(app, argc, argv); - auto resolved_ini = resolve_ini(ini); - auto meta = IndexMetadata::from_file(resolved_ini); + auto resolved_yml = resolve_yml(yml); + auto meta = 
IndexMetadata::from_file(resolved_yml); auto stemmer = meta.stemmer ? std::make_optional(*meta.stemmer) : std::optional{}; auto run = index_runner(meta, RawReader{}, BlockedReader<::pisa::simdbp_block, true>{}, BlockedReader<::pisa::simdbp_block, false>{}); - auto index_basename = resolved_ini.substr(0, resolved_ini.size() - 4); + auto index_basename = resolved_yml.substr(0, resolved_yml.size() - 4); auto postings_path = fmt::format("{}.bm25", index_basename); auto offsets_path = fmt::format("{}.bm25_offsets", index_basename); - run([&](auto &&index) { + run([&](auto&& index) { ProgressStatus calc_max_status(index.num_terms(), DefaultProgress("Calculating max partial score"), std::chrono::milliseconds(100)); @@ -66,7 +66,7 @@ int main(int argc, char **argv) boost::counting_iterator(first_term), boost::counting_iterator(end_term), [&](auto term) { - for_each(index.scoring_cursor(term, make_bm25(index)), [&](auto &cursor) { + for_each(index.scoring_cursor(term, make_bm25(index)), [&](auto& cursor) { max_scores[thread_id] = std::max(max_scores[thread_id], cursor.payload()); }); calc_max_status += 1; @@ -94,7 +94,7 @@ int main(int argc, char **argv) write_span(gsl::span(offsets), offsets_path); }); meta.scores.push_back(PostingFilePaths{.postings = postings_path, .offsets = offsets_path}); - meta.write(resolved_ini); + meta.write(resolved_yml); return 0; } From 9c2442f02f19714d8a50f0a3b7e2e158ad29fe9a Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 15 Nov 2019 18:00:50 -0500 Subject: [PATCH 21/56] Union-lookup query (without precomptued scores and tool) --- include/pisa/payload_vector.hpp | 130 ++++++------ include/pisa/v1/blocked_cursor.hpp | 23 +-- include/pisa/v1/cursor/scoring_cursor.hpp | 6 +- include/pisa/v1/cursor_union.hpp | 131 ++++++++++-- include/pisa/v1/index.hpp | 234 ++++++++++++++++++---- include/pisa/v1/index_metadata.hpp | 34 ++++ include/pisa/v1/query.hpp | 129 +++++++++++- include/pisa/v1/zip_cursor.hpp | 10 +- src/v1/query.cpp | 13 ++ 
test/test_v1.cpp | 32 ++- test/test_v1_blocked_cursor.cpp | 57 +++--- test/test_v1_index.cpp | 4 +- test/test_v1_queries.cpp | 93 +++++++-- 13 files changed, 692 insertions(+), 204 deletions(-) create mode 100644 src/v1/query.cpp diff --git a/include/pisa/payload_vector.hpp b/include/pisa/payload_vector.hpp index 582f5ec06..58c3cd88b 100644 --- a/include/pisa/payload_vector.hpp +++ b/include/pisa/payload_vector.hpp @@ -11,8 +11,6 @@ #include #include -#include "payload_vector.hpp" - namespace pisa { namespace detail { @@ -28,78 +26,77 @@ namespace detail { typename gsl::span::iterator offset_iter; typename gsl::span::iterator payload_iter; - constexpr auto operator++() -> Payload_Vector_Iterator & + constexpr auto operator++() -> Payload_Vector_Iterator& { ++offset_iter; std::advance(payload_iter, *offset_iter - *std::prev(offset_iter)); return *this; } - [[nodiscard]] - constexpr auto operator++(int) -> Payload_Vector_Iterator + [[nodiscard]] constexpr auto operator++(int) -> Payload_Vector_Iterator { Payload_Vector_Iterator next_iter{offset_iter, payload_iter}; ++(*this); return next_iter; } - constexpr auto operator--() -> Payload_Vector_Iterator & { return *this -= 1; } + constexpr auto operator--() -> Payload_Vector_Iterator& { return *this -= 1; } - [[nodiscard]] - constexpr auto operator--(int) -> Payload_Vector_Iterator + [[nodiscard]] constexpr auto operator--(int) -> Payload_Vector_Iterator { Payload_Vector_Iterator next_iter{offset_iter, payload_iter}; --(*this); return next_iter; } - [[nodiscard]] - constexpr auto operator+(size_type n) const -> Payload_Vector_Iterator + [[nodiscard]] constexpr auto operator+(size_type n) const -> Payload_Vector_Iterator { return {std::next(offset_iter, n), std::next(payload_iter, *std::next(offset_iter, n) - *offset_iter)}; } - [[nodiscard]] - constexpr auto operator+=(size_type n) -> Payload_Vector_Iterator & + [[nodiscard]] constexpr auto operator+=(size_type n) -> Payload_Vector_Iterator& { 
std::advance(payload_iter, *std::next(offset_iter, n) - *offset_iter); std::advance(offset_iter, n); return *this; } - [[nodiscard]] - constexpr auto operator-(size_type n) const -> Payload_Vector_Iterator + [[nodiscard]] constexpr auto operator-(size_type n) const -> Payload_Vector_Iterator { return {std::prev(offset_iter, n), std::prev(payload_iter, *offset_iter - *std::prev(offset_iter, n))}; } - [[nodiscard]] - constexpr auto operator-=(size_type n) -> Payload_Vector_Iterator & + [[nodiscard]] constexpr auto operator-=(size_type n) -> Payload_Vector_Iterator& { return *this = *this - n; } - [[nodiscard]] - constexpr auto operator-(Payload_Vector_Iterator const& other) -> difference_type + [[nodiscard]] constexpr auto operator-(Payload_Vector_Iterator const& other) + -> difference_type { return offset_iter - other.offset_iter; } - [[nodiscard]] - constexpr auto operator*() -> value_type + [[nodiscard]] constexpr auto operator*() -> value_type { - return value_type(reinterpret_cast(&*payload_iter), - *std::next(offset_iter) - *offset_iter); + if constexpr (std::is_trivially_copyable_v) { + value_type value; + std::memcpy(&value, reinterpret_cast(&*payload_iter), sizeof(value)); + return value; + } else { + return value_type(reinterpret_cast(&*payload_iter), + *std::next(offset_iter) - *offset_iter); + } } - [[nodiscard]] constexpr auto operator==(Payload_Vector_Iterator const &other) const -> bool + [[nodiscard]] constexpr auto operator==(Payload_Vector_Iterator const& other) const -> bool { return offset_iter == other.offset_iter; } - [[nodiscard]] constexpr auto operator!=(Payload_Vector_Iterator const &other) const -> bool + [[nodiscard]] constexpr auto operator!=(Payload_Vector_Iterator const& other) const -> bool { return offset_iter != other.offset_iter; } @@ -138,12 +135,12 @@ namespace detail { }; template - [[nodiscard]] static constexpr auto unpack(std::byte const *ptr) -> std::tuple + [[nodiscard]] static constexpr auto unpack(std::byte const* ptr) -> 
std::tuple { - if constexpr (sizeof...(Ts) == 0u) { - return std::tuple(*reinterpret_cast(ptr)); + if constexpr (sizeof...(Ts) == 0U) { + return std::tuple(*reinterpret_cast(ptr)); } else { - return std::tuple_cat(std::tuple(*reinterpret_cast(ptr)), + return std::tuple_cat(std::tuple(*reinterpret_cast(ptr)), unpack(ptr + sizeof(T))); } } @@ -156,39 +153,39 @@ struct Payload_Vector_Buffer { std::vector const offsets; std::vector const payloads; - [[nodiscard]] static auto from_file(std::string const &filename) -> Payload_Vector_Buffer + [[nodiscard]] static auto from_file(std::string const& filename) -> Payload_Vector_Buffer { boost::system::error_code ec; auto file_size = boost::filesystem::file_size(boost::filesystem::path(filename)); std::ifstream is(filename); size_type len; - is.read(reinterpret_cast(&len), sizeof(size_type)); + is.read(reinterpret_cast(&len), sizeof(size_type)); auto offsets_bytes = (len + 1) * sizeof(size_type); std::vector offsets(len + 1); - is.read(reinterpret_cast(offsets.data()), offsets_bytes); + is.read(reinterpret_cast(offsets.data()), offsets_bytes); auto payloads_bytes = file_size - offsets_bytes - sizeof(size_type); std::vector payloads(payloads_bytes); - is.read(reinterpret_cast(payloads.data()), payloads_bytes); + is.read(reinterpret_cast(payloads.data()), payloads_bytes); return Payload_Vector_Buffer{std::move(offsets), std::move(payloads)}; } - void to_file(std::string const &filename) const + void to_file(std::string const& filename) const { std::ofstream is(filename); to_stream(is); } - void to_stream(std::ostream &is) const + void to_stream(std::ostream& is) const { - size_type length = offsets.size() - 1u; - is.write(reinterpret_cast(&length), sizeof(length)); - is.write(reinterpret_cast(offsets.data()), + size_type length = offsets.size() - 1U; + is.write(reinterpret_cast(&length), sizeof(length)); + is.write(reinterpret_cast(offsets.data()), offsets.size() * sizeof(offsets[0])); - 
is.write(reinterpret_cast(payloads.data()), payloads.size()); + is.write(reinterpret_cast(payloads.data()), payloads.size()); } template @@ -197,7 +194,7 @@ struct Payload_Vector_Buffer { PayloadEncodingFn encoding_fn) -> Payload_Vector_Buffer { std::vector offsets; - offsets.push_back(0u); + offsets.push_back(0U); std::vector payloads; for (; first != last; ++first) { encoding_fn(*first, std::back_inserter(payloads)); @@ -228,13 +225,14 @@ auto encode_payload_vector(InputIterator first, InputIterator last) }); } -auto encode_payload_vector(gsl::span values) +inline auto encode_payload_vector(gsl::span values) { return encode_payload_vector(values.begin(), values.end()); } template -constexpr auto unpack_head(gsl::span mem) -> std::tuple> +constexpr auto unpack_head(gsl::span mem) + -> std::tuple> { static_assert(detail::all_pod::value); auto offset = detail::sizeofs::value; @@ -247,7 +245,7 @@ constexpr auto unpack_head(gsl::span mem) -> std::tuple>(tail)); } -[[nodiscard]] auto split(gsl::span mem, std::size_t offset) +[[nodiscard]] inline auto split(gsl::span mem, std::size_t offset) { if (offset > mem.size()) { throw std::runtime_error( @@ -264,7 +262,7 @@ template throw std::runtime_error( fmt::format("Failed to cast byte-span to span of T of size {}", type_size)); } - return gsl::make_span(reinterpret_cast(mem.data()), mem.size() / type_size); + return gsl::make_span(reinterpret_cast(mem.data()), mem.size() / type_size); } template @@ -275,46 +273,45 @@ class Payload_Vector { using payload_type = Payload_View; using iterator = detail::Payload_Vector_Iterator; - Payload_Vector(Payload_Vector_Buffer const &container) + explicit Payload_Vector(Payload_Vector_Buffer const& container) : offsets_(container.offsets), payloads_(container.payloads) - {} + { + } Payload_Vector(gsl::span offsets, gsl::span payloads) - : offsets_(std::move(offsets)), - payloads_(std::move(payloads)) - {} + : offsets_(offsets), payloads_(payloads) + { + } template - [[nodiscard]] - 
constexpr static auto from(ContiguousContainer&& mem) -> Payload_Vector + [[nodiscard]] constexpr static auto from(ContiguousContainer&& mem) -> Payload_Vector { - return from(gsl::make_span(reinterpret_cast(mem.data()), mem.size())); + return from(gsl::make_span(reinterpret_cast(mem.data()), mem.size())); } - [[nodiscard]] - static auto from(gsl::span mem) -> Payload_Vector + [[nodiscard]] static auto from(gsl::span mem) -> Payload_Vector { size_type length; gsl::span tail; try { std::tie(length, tail) = unpack_head(mem); - } catch (std::runtime_error const &err) { - throw std::runtime_error(std::string("Failed to parse payload vector length: ") + - err.what()); + } catch (std::runtime_error const& err) { + throw std::runtime_error(std::string("Failed to parse payload vector length: ") + + err.what()); } - gsl::span offsets, payloads; + gsl::span offsets; + gsl::span payloads; try { - std::tie(offsets, payloads) = split(tail, (length + 1u) * sizeof(size_type)); + std::tie(offsets, payloads) = split(tail, (length + 1U) * sizeof(size_type)); } catch (std::runtime_error const& err) { - throw std::runtime_error(std::string("Failed to parse payload vector offset table: ") + - err.what()); + throw std::runtime_error(std::string("Failed to parse payload vector offset table: ") + + err.what()); } return Payload_Vector(cast_span(offsets), payloads); } - [[nodiscard]] - constexpr auto operator[](size_type idx) const -> payload_type + [[nodiscard]] constexpr auto operator[](size_type idx) const -> payload_type { if (idx >= offsets_.size()) { throw std::out_of_range(fmt::format( @@ -329,20 +326,17 @@ class Payload_Vector { return *(begin() + idx); } - [[nodiscard]] - constexpr auto begin() const -> iterator + [[nodiscard]] constexpr auto begin() const -> iterator { return {offsets_.begin(), payloads_.begin()}; } - [[nodiscard]] - constexpr auto end() const -> iterator + [[nodiscard]] constexpr auto end() const -> iterator { return {std::prev(offsets_.end()), payloads_.end()}; 
} [[nodiscard]] constexpr auto cbegin() const -> iterator { return begin(); } [[nodiscard]] constexpr auto cend() const -> iterator { return end(); } - [[nodiscard]] - constexpr auto size() const -> size_type + [[nodiscard]] constexpr auto size() const -> size_type { return offsets_.size() - size_type{1}; } diff --git a/include/pisa/v1/blocked_cursor.hpp b/include/pisa/v1/blocked_cursor.hpp index f1dd96479..df45926eb 100644 --- a/include/pisa/v1/blocked_cursor.hpp +++ b/include/pisa/v1/blocked_cursor.hpp @@ -116,12 +116,11 @@ struct BlockedCursor { /// Moves the cursor to the next value equal or greater than `value`. constexpr void advance_to_geq(value_type value) { - //static_assert(DeltaEncoded, "Cannot call advance_to_geq on a not delta-encoded list"); + // static_assert(DeltaEncoded, "Cannot call advance_to_geq on a not delta-encoded list"); // TODO(michal): This should be `static_assert` like above. But currently, // it would not compile. What needs to be done is separating document // and payload readers for the index runner. assert(DeltaEncoded); - Expects(value >= m_current_value || position() == 0); if (PISA_UNLIKELY(value > m_current_block.last_value)) { if (value > m_block_last_values.back()) { m_current_value = sentinel(); @@ -135,11 +134,7 @@ struct BlockedCursor { } while (m_current_value < value) { - if constexpr (DeltaEncoded) { - m_current_value += m_decoded_block[m_current_block.offset] + 1U; - } else { - m_current_value = m_decoded_block[m_current_block.offset] + 1U; - } + m_current_value += m_decoded_block[++m_current_block.offset] + 1U; Ensures(m_current_block.offset < m_current_block.length); } } @@ -174,8 +169,8 @@ struct BlockedCursor { { constexpr auto block_size = Codec::block_size; auto endpoint = block > 0U ? 
m_block_endpoints[block - 1] : static_cast(0U); - std::uint8_t const *block_data = - std::next(reinterpret_cast(m_encoded_blocks.data()), endpoint); + std::uint8_t const* block_data = + std::next(reinterpret_cast(m_encoded_blocks.data()), endpoint); m_current_block.length = ((block + 1) * block_size <= size()) ? block_size : (size() % block_size); @@ -229,7 +224,7 @@ struct BlockedReader { -> BlockedCursor { std::uint32_t length; - auto begin = reinterpret_cast(bytes.data()); + auto begin = reinterpret_cast(bytes.data()); auto after_length_ptr = pisa::TightVariableByte::decode(begin, &length, 1); auto length_byte_size = std::distance(begin, after_length_ptr); auto num_blocks = ceil_div(length, Codec::block_size); @@ -269,7 +264,7 @@ struct BlockedWriter { | encoding_traits::encoding_tag::encoding(); } - void push(value_type const &posting) + void push(value_type const& posting) { if constexpr (DeltaEncoded) { if (posting < m_last_value) { @@ -284,7 +279,7 @@ struct BlockedWriter { } template - [[nodiscard]] auto write(std::basic_ostream &os) const -> std::size_t + [[nodiscard]] auto write(std::basic_ostream& os) const -> std::size_t { std::vector buffer; std::uint32_t length = m_postings.size(); @@ -312,7 +307,7 @@ struct BlockedWriter { std::for_each(block_buffer.begin(), std::next(block_buffer.begin(), current_block_size), - [&](auto &&elem) { + [&](auto&& elem) { if constexpr (DeltaEncoded) { auto value = *iter++; elem = value - (last_value + 1); @@ -340,7 +335,7 @@ struct BlockedWriter { } block_base = last_value + 1; } - os.write(reinterpret_cast(buffer.data()), buffer.size()); + os.write(reinterpret_cast(buffer.data()), buffer.size()); return buffer.size(); } diff --git a/include/pisa/v1/cursor/scoring_cursor.hpp b/include/pisa/v1/cursor/scoring_cursor.hpp index 779b4b135..dbe84ea5c 100644 --- a/include/pisa/v1/cursor/scoring_cursor.hpp +++ b/include/pisa/v1/cursor/scoring_cursor.hpp @@ -11,8 +11,8 @@ namespace pisa::v1 { template struct ScoringCursor { using 
Document = decltype(*std::declval()); - using Payload = float; - static_assert(std::is_same_v); + using Payload = decltype((std::declval())(std::declval(), + std::declval().payload())); explicit constexpr ScoringCursor(BaseCursor base_cursor, TermScorer scorer) : m_base_cursor(std::move(base_cursor)), m_scorer(std::move(scorer)) @@ -24,7 +24,7 @@ struct ScoringCursor { { return m_base_cursor.value(); } - [[nodiscard]] constexpr auto payload() noexcept -> Payload + [[nodiscard]] constexpr auto payload() noexcept { return m_scorer(m_base_cursor.value(), m_base_cursor.payload()); } diff --git a/include/pisa/v1/cursor_union.hpp b/include/pisa/v1/cursor_union.hpp index beb88c936..1cf9596d5 100644 --- a/include/pisa/v1/cursor_union.hpp +++ b/include/pisa/v1/cursor_union.hpp @@ -26,19 +26,22 @@ struct CursorUnion { m_accumulate(std::move(accumulate)), m_size(std::nullopt) { - Expects(not m_cursors.empty()); - auto order = [](auto const &lhs, auto const &rhs) { return lhs.value() < rhs.value(); }; - m_next_docid = [&]() { - auto pos = std::min_element(m_cursors.begin(), m_cursors.end(), order); - return pos->value(); - }(); - m_sentinel = std::min_element(m_cursors.begin(), - m_cursors.end(), - [](auto const &lhs, auto const &rhs) { - return lhs.sentinel() < rhs.sentinel(); - }) - ->sentinel(); - advance(); + if (m_cursors.empty()) { + m_current_value = std::numeric_limits::max(); + } else { + auto order = [](auto const& lhs, auto const& rhs) { return lhs.value() < rhs.value(); }; + m_next_docid = [&]() { + auto pos = std::min_element(m_cursors.begin(), m_cursors.end(), order); + return pos->value(); + }(); + m_sentinel = std::min_element(m_cursors.begin(), + m_cursors.end(), + [](auto const& lhs, auto const& rhs) { + return lhs.sentinel() < rhs.sentinel(); + }) + ->sentinel(); + advance(); + } } [[nodiscard]] constexpr auto size() const noexcept -> std::size_t @@ -47,13 +50,13 @@ struct CursorUnion { m_size = std::accumulate(m_cursors.begin(), m_cursors.end(), 
std::size_t(0), - [](auto acc, auto const &elem) { return acc + elem.size(); }); + [](auto acc, auto const& elem) { return acc + elem.size(); }); } return *m_size; } [[nodiscard]] constexpr auto operator*() const noexcept -> Value { return m_current_value; } [[nodiscard]] constexpr auto value() const noexcept -> Value { return m_current_value; } - [[nodiscard]] constexpr auto payload() const noexcept -> Payload const & + [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& { return m_current_payload; } @@ -69,7 +72,7 @@ struct CursorUnion { m_current_value = m_next_docid; m_next_docid = m_sentinel; std::size_t cursor_idx = 0; - for (auto &cursor : m_cursors) { + for (auto& cursor : m_cursors) { if (cursor.value() == m_current_value) { m_current_payload = m_accumulate(m_current_payload, cursor, cursor_idx); cursor.advance(); @@ -103,6 +106,93 @@ struct CursorUnion { std::uint32_t m_next_docid{}; }; +/// Transforms a list of cursors into one cursor by lazily merging them together. +template +struct VariadicCursorUnion { + using Value = std::decay_t(std::declval()))>; + + constexpr VariadicCursorUnion(Payload init, + CursorsTuple cursors, + std::tuple accumulate) + : m_cursors(std::move(cursors)), + m_init(std::move(init)), + m_accumulate(std::move(accumulate)), + m_size(std::nullopt) + { + m_next_docid = std::numeric_limits::max(); + m_sentinel = std::numeric_limits::max(); + for_each_cursor([&](auto&& cursor, [[maybe_unused]] auto&& fn) { + if (cursor.value() < m_next_docid) { + m_next_docid = cursor.value(); + } + }); + for_each_cursor([&](auto&& cursor, [[maybe_unused]] auto&& fn) { + if (cursor.sentinel() < m_next_docid) { + m_next_docid = cursor.sentinel(); + } + }); + advance(); + } + + template + void for_each_cursor(Fn&& fn) + { + std::apply( + [&](auto&&... cursor) { + std::apply([&](auto&&... 
accumulate) { (fn(cursor, accumulate), ...); }, + m_accumulate); + }, + m_cursors); + } + + [[nodiscard]] constexpr auto operator*() const noexcept -> Value { return m_current_value; } + [[nodiscard]] constexpr auto value() const noexcept -> Value { return m_current_value; } + [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& + { + return m_current_payload; + } + [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_sentinel; } + + constexpr void advance() + { + if (PISA_UNLIKELY(m_next_docid == m_sentinel)) { + m_current_value = m_sentinel; + m_current_payload = m_init; + } else { + m_current_payload = m_init; + m_current_value = m_next_docid; + m_next_docid = m_sentinel; + std::size_t cursor_idx = 0; + for_each_cursor([&](auto&& cursor, auto&& accumulate) { + if (cursor.value() == m_current_value) { + m_current_payload = accumulate(m_current_payload, cursor, cursor_idx); + cursor.advance(); + } + if (cursor.value() < m_next_docid) { + m_next_docid = cursor.value(); + } + ++cursor_idx; + }); + } + } + + [[nodiscard]] constexpr auto empty() const noexcept -> bool + { + return m_current_value >= sentinel(); + } + + private: + CursorsTuple m_cursors; + Payload m_init; + std::tuple m_accumulate; + std::optional m_size; + + Value m_current_value{}; + Value m_sentinel{}; + Payload m_current_payload{}; + std::uint32_t m_next_docid{}; +}; + template [[nodiscard]] constexpr inline auto union_merge(CursorContainer cursors, Payload init, @@ -112,4 +202,13 @@ template std::move(cursors), std::move(init), std::move(accumulate)); } +template +[[nodiscard]] constexpr inline auto variadic_union_merge(Payload init, + std::tuple cursors, + std::tuple accumulate) +{ + return VariadicCursorUnion, AccumulateFn...>( + std::move(init), std::move(cursors), std::move(accumulate)); +} + } // namespace pisa::v1 diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index 33cc4fccf..f355d8f23 100644 --- 
a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -17,6 +17,7 @@ #include #include "binary_freq_collection.hpp" +#include "payload_vector.hpp" #include "v1/bit_cast.hpp" #include "v1/cursor/for_each.hpp" #include "v1/cursor/scoring_cursor.hpp" @@ -27,10 +28,14 @@ #include "v1/scorer/bm25.hpp" #include "v1/source.hpp" #include "v1/types.hpp" +#include "v1/zip_cursor.hpp" namespace pisa::v1 { -[[nodiscard]] inline auto calc_avg_length(gsl::span const &lengths) -> float +using OffsetSpan = gsl::span; +using BinarySpan = gsl::span; + +[[nodiscard]] inline auto calc_avg_length(gsl::span const& lengths) -> float { auto sum = std::accumulate(lengths.begin(), lengths.end(), std::uint64_t(0), std::plus{}); return static_cast(sum) / lengths.size(); @@ -68,21 +73,31 @@ struct Index { PayloadReader payload_reader, gsl::span document_offsets, gsl::span payload_offsets, + tl::optional bigram_document_offsets, + tl::optional> bigram_frequency_offsets, gsl::span documents, gsl::span payloads, + tl::optional bigram_documents, + tl::optional> bigram_frequencies, gsl::span document_lengths, tl::optional avg_document_length, + tl::optional<::pisa::Payload_Vector>> bigram_mapping, Source source) : m_document_reader(std::move(document_reader)), m_payload_reader(std::move(payload_reader)), m_document_offsets(document_offsets), m_payload_offsets(payload_offsets), + m_bigram_document_offsets(bigram_document_offsets), + m_bigram_frequency_offsets(bigram_frequency_offsets), m_documents(documents), m_payloads(payloads), + m_bigram_documents(bigram_documents), + m_bigram_frequencies(bigram_frequencies), m_document_lengths(document_lengths), m_avg_document_length(avg_document_length.map_or_else( - [](auto &&self) { return self; }, + [](auto&& self) { return self; }, [&]() { return calc_avg_length(m_document_lengths); })), + m_bigram_mapping(bigram_mapping), m_source(std::move(source)) { } @@ -94,9 +109,28 @@ struct Index { payloads(term)); } + [[nodiscard]] auto 
bigram_cursor(TermId left_term, TermId right_term) const + { + if (not m_bigram_mapping) { + throw std::logic_error("Bigrams are missing"); + } + if (auto pos = std::lower_bound(m_bigram_mapping->begin(), + m_bigram_mapping->end(), + std::array{left_term, right_term}); + pos != m_bigram_mapping->end()) { + auto bigram_id = std::distance(m_bigram_mapping->begin(), pos); + return DocumentPayloadCursor>( + m_document_reader.read(fetch_bigram_documents(bigram_id)), + zip(m_payload_reader.read(fetch_bigram_payloads<0>(bigram_id)), + m_payload_reader.read(fetch_bigram_payloads<1>(bigram_id)))); + } + throw std::invalid_argument( + fmt::format("Bigram for <{}, {}> not found.", left_term, right_term)); + } + /// Constructs a new document-score cursor. template - [[nodiscard]] auto scoring_cursor(TermId term, Scorer &&scorer) const + [[nodiscard]] auto scoring_cursor(TermId term, Scorer&& scorer) const { return ScoringCursor(cursor(term), std::forward(scorer).term_scorer(term)); } @@ -104,7 +138,7 @@ struct Index { /// This is equivalent to the `scoring_cursor` unless the scorer is of type `VoidScorer`, /// in which case index payloads are treated as scores. template - [[nodiscard]] auto scored_cursor(TermId term, Scorer &&scorer) const + [[nodiscard]] auto scored_cursor(TermId term, Scorer&& scorer) const { if constexpr (std::is_convertible_v) { return cursor(term); @@ -113,6 +147,34 @@ struct Index { } } + /// Constructs a new document-score cursor. 
+ template + [[nodiscard]] auto scoring_bigram_cursor(TermId left_term, + TermId right_term, + Scorer&& scorer) const + { + return ScoringCursor( + bigram_cursor(left_term, right_term), + [scorers = + std::make_tuple(scorer.term_scorer(left_term), scorer.term_scorer(right_term))]( + auto&& docid, auto&& payload) { + return std::array{std::get<0>(scorers)(docid, std::get<0>(payload)), + std::get<1>(scorers)(docid, std::get<1>(payload))}; + }); + } + + template + [[nodiscard]] auto scored_bigram_cursor(TermId left_term, + TermId right_term, + Scorer&& scorer) const + { + if constexpr (std::is_convertible_v) { + return bigram_cursor(left_term, right_term); + } else { + return scoring_bigram_cursor(left_term, right_term, std::forward(scorer)); + } + } + /// Constructs a new document cursor. [[nodiscard]] auto documents(TermId term) const { @@ -161,27 +223,62 @@ struct Index { return m_payloads.subspan(m_payload_offsets[term], m_payload_offsets[term + 1] - m_payload_offsets[term]); } + [[nodiscard]] auto fetch_bigram_documents(TermId term) const -> gsl::span + { + if (not m_bigram_documents) { + throw std::logic_error("Bigrams are missing"); + } + Expects(term + 1 < m_bigram_document_offsets->size()); + return m_bigram_documents->subspan( + (*m_bigram_document_offsets)[term], + (*m_bigram_document_offsets)[term + 1] - (*m_bigram_document_offsets)[term]); + } + template + [[nodiscard]] auto fetch_bigram_payloads(TermId term) const -> gsl::span + { + if (not m_bigram_frequencies) { + throw std::logic_error("Bigrams are missing"); + } + Expects(term + 1 < std::get(*m_bigram_frequency_offsets).size()); + return std::get(*m_bigram_frequencies) + .subspan(std::get(*m_bigram_frequency_offsets)[term], + std::get(*m_bigram_frequency_offsets)[term + 1] + - std::get(*m_bigram_frequency_offsets)[term]); + } Reader m_document_reader; Reader m_payload_reader; - gsl::span m_document_offsets; - gsl::span m_payload_offsets; - gsl::span m_documents; - gsl::span m_payloads; + + OffsetSpan 
m_document_offsets; + OffsetSpan m_payload_offsets; + tl::optional m_bigram_document_offsets{}; + tl::optional> m_bigram_frequency_offsets{}; + + BinarySpan m_documents; + BinarySpan m_payloads; + tl::optional m_bigram_documents{}; + tl::optional> m_bigram_frequencies{}; + gsl::span m_document_lengths; float m_avg_document_length; + tl::optional<::pisa::Payload_Vector>> m_bigram_mapping; std::any m_source; }; template auto make_index(DocumentReader document_reader, PayloadReader payload_reader, - gsl::span document_offsets, - gsl::span payload_offsets, - gsl::span documents, - gsl::span payloads, + OffsetSpan document_offsets, + OffsetSpan payload_offsets, + tl::optional bigram_document_offsets, + tl::optional> bigram_frequency_offsets, + BinarySpan documents, + BinarySpan payloads, + tl::optional bigram_documents, + tl::optional> bigram_frequencies, gsl::span document_lengths, tl::optional avg_document_length, + tl::optional<::pisa::Payload_Vector>> bigram_mapping, Source source) { using DocumentCursor = @@ -191,10 +288,15 @@ auto make_index(DocumentReader document_reader, std::move(payload_reader), document_offsets, payload_offsets, + bigram_document_offsets, + bigram_frequency_offsets, documents, payloads, + bigram_documents, + bigram_frequencies, document_lengths, avg_document_length, + bigram_mapping, std::move(source)); } @@ -204,19 +306,19 @@ template -auto score_index(Index const &index, - std::basic_ostream &os, +auto score_index(Index const& index, + std::basic_ostream& os, Writer writer, Scorer scorer, - Quantizer &&quantizer, - Callback &&callback) -> std::vector + Quantizer&& quantizer, + Callback&& callback) -> std::vector { PostingBuilder score_builder(writer); score_builder.write_header(os); std::for_each(boost::counting_iterator(0), boost::counting_iterator(index.num_terms()), [&](auto term) { - for_each(index.scoring_cursor(term, scorer), [&](auto &cursor) { + for_each(index.scoring_cursor(term, scorer), [&](auto& cursor) { 
score_builder.accumulate(quantizer(cursor.payload())); }); score_builder.flush_segment(os); @@ -226,7 +328,7 @@ auto score_index(Index const &index, } template -auto score_index(Index const &index, std::basic_ostream &os, Writer writer, Scorer scorer) +auto score_index(Index const& index, std::basic_ostream& os, Writer writer, Scorer scorer) -> std::vector { PostingBuilder score_builder(writer); @@ -235,14 +337,14 @@ auto score_index(Index const &index, std::basic_ostream &os, Writer write boost::counting_iterator(index.num_terms()), [&](auto term) { for_each(index.scoring_cursor(term, scorer), - [&](auto &cursor) { score_builder.accumulate(cursor.payload()); }); + [&](auto& cursor) { score_builder.accumulate(cursor.payload()); }); score_builder.flush_segment(os); }); return std::move(score_builder.offsets()); } /// Initializes a memory mapped source with a given file. -inline void open_source(mio::mmap_source &source, std::string const &filename) +inline void open_source(mio::mmap_source& source, std::string const& filename) { std::error_code error; source.map(filename, error); @@ -259,15 +361,15 @@ inline auto read_sizes(std::string_view basename) return std::vector(sequence.begin(), sequence.end()); } -[[nodiscard]] inline auto binary_collection_source(std::string const &basename) +[[nodiscard]] inline auto binary_collection_source(std::string const& basename) { using sink_type = boost::iostreams::back_insert_device>; using vector_stream_type = boost::iostreams::stream; binary_freq_collection collection(basename.c_str()); VectorSource source{{{}, {}}, {{}, {}}, {read_sizes(basename)}}; - std::vector &docbuf = source.bytes[0]; - std::vector &freqbuf = source.bytes[1]; + std::vector& docbuf = source.bytes[0]; + std::vector& freqbuf = source.bytes[1]; PostingBuilder document_builder(Writer(RawWriter{})); PostingBuilder frequency_builder(Writer(RawWriter{})); @@ -291,7 +393,7 @@ inline auto read_sizes(std::string_view basename) return source; } -[[nodiscard]] inline 
auto binary_collection_index(std::string const &basename) +[[nodiscard]] inline auto binary_collection_index(std::string const& basename) { auto source = binary_collection_source(basename); auto documents = gsl::span(source.bytes[0]); @@ -303,14 +405,19 @@ inline auto read_sizes(std::string_view basename) RawReader{}, document_offsets, frequency_offsets, + {}, + {}, documents.subspan(8), frequencies.subspan(8), + {}, + {}, sizes, tl::nullopt, + tl::nullopt, std::move(source)); } -[[nodiscard]] inline auto binary_collection_scored_index(std::string const &basename) +[[nodiscard]] inline auto binary_collection_scored_index(std::string const& basename) { using sink_type = boost::iostreams::back_insert_device>; using vector_stream_type = boost::iostreams::stream; @@ -325,10 +432,15 @@ inline auto read_sizes(std::string_view basename) RawReader{}, document_offsets, frequency_offsets, + {}, + {}, documents.subspan(8), frequencies.subspan(8), + {}, + {}, sizes, tl::nullopt, + tl::nullopt, false); source.offsets.push_back([&freq_index, &source]() { @@ -343,10 +455,15 @@ inline auto read_sizes(std::string_view basename) RawReader{}, document_offsets, score_offsets, + {}, + {}, documents.subspan(8), scores.subspan(8), + {}, + {}, sizes, tl::nullopt, + tl::nullopt, std::move(source)); } @@ -375,7 +492,7 @@ struct BigramIndex : public Index { /// Creates, on the fly, a bigram index with all pairs of adjecent terms. /// Disclaimer: for testing purposes. 
-[[nodiscard]] inline auto binary_collection_bigram_index(std::string const &basename) +[[nodiscard]] inline auto binary_collection_bigram_index(std::string const& basename) { using payload_type = std::array; using sink_type = boost::iostreams::back_insert_device>; @@ -403,7 +520,7 @@ struct BigramIndex : public Index { auto intersection = CursorIntersection( std::vector{unigram_index.cursor(left), unigram_index.cursor(right)}, payload_type{0, 0}, - [](payload_type &payload, auto &cursor, auto list_idx) { + [](payload_type& payload, auto& cursor, auto list_idx) { payload[list_idx] = cursor.payload(); return payload; }); @@ -412,7 +529,7 @@ struct BigramIndex : public Index { return; } pair_mapping.emplace_back(left, right); - for_each(intersection, [&](auto &cursor) { + for_each(intersection, [&](auto& cursor) { document_builder.accumulate(*cursor); frequency_builder.accumulate(cursor.payload()); }); @@ -434,10 +551,15 @@ struct BigramIndex : public Index { RawReader{}, document_offsets, frequency_offsets, + {}, + {}, document_span.subspan(8), payload_span.subspan(8), + {}, + {}, sizes, tl::nullopt, + tl::nullopt, std::move(source)); return BigramIndex(std::move(index), std::move(pair_mapping)); } @@ -447,18 +569,28 @@ struct IndexRunner { template IndexRunner(gsl::span document_offsets, gsl::span payload_offsets, + tl::optional bigram_document_offsets, + tl::optional> bigram_frequency_offsets, gsl::span documents, gsl::span payloads, + tl::optional bigram_documents, + tl::optional> bigram_frequencies, gsl::span document_lengths, tl::optional avg_document_length, + tl::optional<::pisa::Payload_Vector>> bigram_mapping, Source source, Readers... 
readers) : m_document_offsets(document_offsets), m_payload_offsets(payload_offsets), + m_bigram_document_offsets(bigram_document_offsets), + m_bigram_frequency_offsets(bigram_frequency_offsets), m_documents(documents), m_payloads(payloads), + m_bigram_documents(bigram_documents), + m_bigram_frequencies(bigram_frequencies), m_document_lengths(document_lengths), m_avg_document_length(avg_document_length), + m_bigram_mapping(bigram_mapping), m_source(std::move(source)), m_readers(readers...) { @@ -466,18 +598,28 @@ struct IndexRunner { template IndexRunner(gsl::span document_offsets, gsl::span payload_offsets, + tl::optional bigram_document_offsets, + tl::optional> bigram_frequency_offsets, gsl::span documents, gsl::span payloads, + tl::optional bigram_documents, + tl::optional> bigram_frequencies, gsl::span document_lengths, tl::optional avg_document_length, + tl::optional<::pisa::Payload_Vector>> bigram_mapping, Source source, std::tuple readers) : m_document_offsets(document_offsets), m_payload_offsets(payload_offsets), + m_bigram_document_offsets(bigram_document_offsets), + m_bigram_frequency_offsets(bigram_frequency_offsets), m_documents(documents), m_payloads(payloads), + m_bigram_documents(bigram_documents), + m_bigram_frequencies(bigram_frequencies), m_document_lengths(document_lengths), m_avg_document_length(avg_document_length), + m_bigram_mapping(bigram_mapping), m_source(std::move(source)), m_readers(std::move(readers)) { @@ -488,20 +630,29 @@ struct IndexRunner { { auto dheader = PostingFormatHeader::parse(m_documents.first(8)); auto pheader = PostingFormatHeader::parse(m_payloads.first(8)); - auto run = [&](auto &&dreader, auto &&preader) { + auto run = [&](auto&& dreader, auto&& preader) { if (std::decay_t::encoding() == dheader.encoding && std::decay_t::encoding() == pheader.encoding && is_type::value_type>(dheader.type) && is_type::value_type>(pheader.type)) { - auto index = make_index(std::forward(dreader), - std::forward(preader), - 
m_document_offsets, - m_payload_offsets, - m_documents.subspan(8), - m_payloads.subspan(8), - m_document_lengths, - m_avg_document_length, - false); + auto index = make_index( + std::forward(dreader), + std::forward(preader), + m_document_offsets, + m_payload_offsets, + m_bigram_document_offsets, + m_bigram_frequency_offsets, + m_documents.subspan(8), + m_payloads.subspan(8), + m_bigram_documents.map([](auto&& bytes) { return bytes.subspan(8); }), + m_bigram_frequencies.map([](auto&& bytes) { + return std::array{std::get<0>(bytes).subspan(8), + std::get<1>(bytes).subspan(8)}; + }), + m_document_lengths, + m_avg_document_length, + m_bigram_mapping, + false); fn(index); return true; } @@ -525,10 +676,17 @@ struct IndexRunner { private: gsl::span m_document_offsets; gsl::span m_payload_offsets; + tl::optional m_bigram_document_offsets{}; + tl::optional> m_bigram_frequency_offsets{}; + gsl::span m_documents; gsl::span m_payloads; + tl::optional m_bigram_documents{}; + tl::optional> m_bigram_frequencies{}; + gsl::span m_document_lengths; tl::optional m_avg_document_length; + tl::optional<::pisa::Payload_Vector>> m_bigram_mapping; std::any m_source; std::tuple m_readers; }; diff --git a/include/pisa/v1/index_metadata.hpp b/include/pisa/v1/index_metadata.hpp index 2568008d4..fadbd1c3e 100644 --- a/include/pisa/v1/index_metadata.hpp +++ b/include/pisa/v1/index_metadata.hpp @@ -91,12 +91,41 @@ template auto document_offsets = source_span(source, metadata.documents.offsets); auto frequency_offsets = source_span(source, metadata.frequencies.offsets); auto document_lengths = source_span(source, metadata.document_lengths_path); + tl::optional> bigram_document_offsets{}; + tl::optional, 2>> bigram_frequency_offsets{}; + tl::optional> bigram_documents{}; + tl::optional, 2>> bigram_frequencies{}; + tl::optional<::pisa::Payload_Vector>> bigram_mapping{}; + if (metadata.bigrams) { + bigram_document_offsets = + source_span(source, metadata.bigrams->documents.offsets); + 
bigram_frequency_offsets = { + source_span(source, metadata.bigrams->frequencies.first.offsets), + source_span(source, metadata.bigrams->frequencies.second.offsets)}; + bigram_documents = source_span(source, metadata.bigrams->documents.postings); + bigram_frequencies = { + source_span(source, metadata.bigrams->frequencies.first.postings), + source_span(source, metadata.bigrams->frequencies.second.postings)}; + auto mapping_span = source_span(source, metadata.bigrams->mapping).subspan(8); + auto num_offset_bytes = (metadata.bigrams->count + 1U) * 8U; + auto mapping_offsets = mapping_span.first(num_offset_bytes); + bigram_mapping = Payload_Vector>( + gsl::span( + reinterpret_cast(mapping_offsets.data()), + mapping_offsets.size() * sizeof(std::size_t)), + mapping_span.subspan(num_offset_bytes)); + } return IndexRunner(document_offsets, frequency_offsets, + bigram_document_offsets, + bigram_frequency_offsets, documents, frequencies, + bigram_documents, + bigram_frequencies, document_lengths, tl::make_optional(metadata.avg_document_length), + bigram_mapping, std::move(source), std::move(readers)); } @@ -119,10 +148,15 @@ template auto document_lengths = source_span(source, metadata.document_lengths_path); return IndexRunner(document_offsets, score_offsets, + {}, + {}, documents, scores, + {}, // TODO(michal): support scored bigrams + {}, document_lengths, tl::make_optional(metadata.avg_document_length), + {}, // TODO std::move(source), std::move(readers)); } diff --git a/include/pisa/v1/query.hpp b/include/pisa/v1/query.hpp index a21ad716b..84f72756f 100644 --- a/include/pisa/v1/query.hpp +++ b/include/pisa/v1/query.hpp @@ -3,6 +3,9 @@ #include #include +#include +#include + #include "topk_queue.hpp" #include "v1/cursor/for_each.hpp" #include "v1/cursor_intersection.hpp" @@ -13,14 +16,15 @@ namespace pisa::v1 { struct Query { std::vector terms; + std::vector> bigrams{}; }; template -using QueryProcessor = std::function; +using QueryProcessor = std::function; struct 
ExhaustiveConjunctiveProcessor { template - auto operator()(Index const &index, Query const &query, topk_queue que) -> topk_queue + auto operator()(Index const& index, Query const& query, topk_queue que) -> topk_queue { using Cursor = std::decay_t; std::vector cursors; @@ -31,7 +35,7 @@ struct ExhaustiveConjunctiveProcessor { auto intersection = intersect(std::move(cursors), 0.0F, - [](float score, auto &cursor, [[maybe_unused]] auto cursor_idx) { + [](float score, auto& cursor, [[maybe_unused]] auto cursor_idx) { return score + static_cast(cursor.payload()); }); while (not intersection.empty()) { @@ -42,7 +46,7 @@ struct ExhaustiveConjunctiveProcessor { }; template -auto daat_and(Query const &query, Index const &index, topk_queue topk, Scorer &&scorer) +auto daat_and(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) { std::vector cursors; std::transform(query.terms.begin(), @@ -50,16 +54,16 @@ auto daat_and(Query const &query, Index const &index, topk_queue topk, Scorer && std::back_inserter(cursors), [&](auto term) { return index.scored_cursor(term, scorer); }); auto intersection = - v1::intersect(std::move(cursors), 0.0F, [](auto &score, auto &cursor, auto /* term_idx */) { + v1::intersect(std::move(cursors), 0.0F, [](auto& score, auto& cursor, auto /* term_idx */) { score += cursor.payload(); return score; }); - v1::for_each(intersection, [&](auto &cursor) { topk.insert(cursor.payload(), *cursor); }); + v1::for_each(intersection, [&](auto& cursor) { topk.insert(cursor.payload(), *cursor); }); return topk; } template -auto daat_or(Query const &query, Index const &index, topk_queue topk, Scorer &&scorer) +auto daat_or(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) { std::vector cursors; std::transform(query.terms.begin(), @@ -67,21 +71,21 @@ auto daat_or(Query const &query, Index const &index, topk_queue topk, Scorer &&s std::back_inserter(cursors), [&](auto term) { return index.scored_cursor(term, scorer); }); 
auto cunion = v1::union_merge( - std::move(cursors), 0.0F, [](auto &score, auto &cursor, auto /* term_idx */) { + std::move(cursors), 0.0F, [](auto& score, auto& cursor, auto /* term_idx */) { score += cursor.payload(); return score; }); - v1::for_each(cunion, [&](auto &cursor) { topk.insert(cursor.payload(), cursor.value()); }); + v1::for_each(cunion, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); return topk; } template -auto taat_or(Query const &query, Index const &index, topk_queue topk, Scorer &&scorer) +auto taat_or(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) { std::vector accumulator(index.num_documents(), 0.0F); for (auto term : query.terms) { v1::for_each(index.scored_cursor(term, scorer), - [&accumulator](auto &&cursor) { accumulator[*cursor] += cursor.payload(); }); + [&accumulator](auto&& cursor) { accumulator[*cursor] += cursor.payload(); }); } for (auto document = 0; document < accumulator.size(); document += 1) { topk.insert(accumulator[document], document); @@ -89,4 +93,107 @@ auto taat_or(Query const &query, Index const &index, topk_queue topk, Scorer &&s return topk; } +/// Returns only unique terms, in sorted order. +[[nodiscard]] auto filter_unique_terms(Query const& query) -> std::vector; + +template +auto transform() +{ +} + +/// Performs a "union-lookup" query (name pending). +/// +/// \param query Full query, as received, possibly with duplicates. +/// \param index Inverted index, with access to both unigrams and bigrams. +/// \param topk Top-k heap. +/// \param scorer An object capable of constructing term scorers. +/// \param essential_unigrams A list of essential single-term posting lists. +/// Elements of this vector point to the index of the term +/// in the query. In other words, for each position `i` in this vector, +/// `query.terms[essential_unigrams[i]]` is an essential unigram. 
+/// \param essential_bigrams Similar to the above, but represents intersections between two +/// posting lists. These must exist in the index, or else this +/// algorithm will fail. +template +auto union_lookup(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + std::vector essential_unigrams, + std::vector> essential_bigrams) +{ + ranges::sort(essential_unigrams); + ranges::actions::unique(essential_unigrams); + ranges::sort(essential_bigrams); + ranges::actions::unique(essential_bigrams); + + std::vector initial_payload(query.terms.size(), 0.0); + + std::vector essential_unigram_cursors; + std::transform(essential_unigrams.begin(), + essential_unigrams.end(), + std::back_inserter(essential_unigram_cursors), + [&](auto idx) { return index.scored_cursor(query.terms[idx], scorer); }); + auto merged_unigrams = v1::union_merge( + essential_unigram_cursors, initial_payload, [&](auto& acc, auto& cursor, auto term_idx) { + acc[essential_unigrams[term_idx]] = cursor.payload(); + return acc; + }); + + std::vector essential_bigram_cursors; + std::transform(essential_bigrams.begin(), + essential_bigrams.end(), + std::back_inserter(essential_bigram_cursors), + [&](auto intersection) { + return index.scored_bigram_cursor(query.terms[intersection.first], + query.terms[intersection.second], + scorer); + }); + auto merged_bigrams = + v1::union_merge(std::move(essential_bigram_cursors), + initial_payload, + [&](auto& acc, auto& cursor, auto term_idx) { + auto payload = cursor.payload(); + acc[essential_bigrams[term_idx].first] = std::get<0>(payload); + acc[essential_bigrams[term_idx].second] = std::get<1>(payload); + return acc; + }); + + auto accumulate = [&](auto& acc, auto& cursor, auto /* union_idx */) { + auto payload = cursor.payload(); + for (auto idx = 0; idx < acc.size(); idx += 1) { + acc[idx] = payload[idx]; + } + return acc; + }; + auto merged = v1::variadic_union_merge( + initial_payload, + std::make_tuple(std::move(merged_unigrams), 
std::move(merged_bigrams)), + std::make_tuple(accumulate, accumulate)); + + std::vector lookup_cursors; + std::transform(query.terms.begin(), + query.terms.end(), + std::back_inserter(lookup_cursors), + [&](auto term) { return index.scored_cursor(term, scorer); }); + + v1::for_each(merged, [&](auto& cursor) { + auto docid = cursor.value(); + auto partial_scores = cursor.payload(); + float score = 0.0F; + for (auto idx = 0; idx < partial_scores.size(); idx += 1) { + if (partial_scores[idx] > 0.0F) { + score += partial_scores[idx]; + } else { + lookup_cursors[idx].advance_to_geq(docid); + if (lookup_cursors[idx].value() == docid) { + score += lookup_cursors[idx].payload(); + } + } + } + topk.insert(score, docid); + }); + return topk; +} + } // namespace pisa::v1 diff --git a/include/pisa/v1/zip_cursor.hpp b/include/pisa/v1/zip_cursor.hpp index e9e04b511..fdc66afaa 100644 --- a/include/pisa/v1/zip_cursor.hpp +++ b/include/pisa/v1/zip_cursor.hpp @@ -8,9 +8,9 @@ namespace pisa::v1 { template struct ZipCursor { - using Value = std::tuple())>; + using Value = std::tuple())...>; - explicit constexpr ZipCursor(Cursors... cursors) : m_cursors(std::move(cursors...)) {} + explicit constexpr ZipCursor(Cursors... cursors) : m_cursors(std::move(cursors)...) {} [[nodiscard]] constexpr auto operator*() const -> Value { return value(); } [[nodiscard]] constexpr auto value() const noexcept -> Value @@ -40,4 +40,10 @@ struct ZipCursor { std::tuple m_cursors; }; +template +auto zip(Cursors... 
cursors) +{ + return ZipCursor(cursors...); +} + } // namespace pisa::v1 diff --git a/src/v1/query.cpp b/src/v1/query.cpp new file mode 100644 index 000000000..a9c68f2c4 --- /dev/null +++ b/src/v1/query.cpp @@ -0,0 +1,13 @@ +#include "v1/query.hpp" + +namespace pisa::v1 { + +[[nodiscard]] auto filter_unique_terms(Query const& query) -> std::vector +{ + auto terms = query.terms; + std::sort(terms.begin(), terms.end()); + terms.erase(std::unique(terms.begin(), terms.end()), terms.end()); + return terms; +} + +} // namespace pisa::v1 diff --git a/test/test_v1.cpp b/test/test_v1.cpp index 754b2e4d1..22e9bb87d 100644 --- a/test/test_v1.cpp +++ b/test/test_v1.cpp @@ -38,7 +38,7 @@ using pisa::v1::UnalignedSpan; using pisa::v1::Writer; template -std::ostream &operator<<(std::ostream &os, tl::optional const &val) +std::ostream& operator<<(std::ostream& os, tl::optional const& val) { if (val.has_value()) { os << val.value(); @@ -81,15 +81,15 @@ TEST_CASE("Binary collection index", "[.][v1][unit]") REQUIRE(std::vector(sequence.docs.begin(), sequence.docs.end()) == collect(index.cursor(term_id))); REQUIRE(std::vector(sequence.freqs.begin(), sequence.freqs.end()) - == collect(index.cursor(term_id), [](auto &&cursor) { return cursor.payload(); })); + == collect(index.cursor(term_id), [](auto&& cursor) { return cursor.payload(); })); term_id += 1; } } TEST_CASE("Bigram collection index", "[.][v1][unit]") { - auto intersect = [](auto const &lhs, - auto const &rhs) -> std::vector> { + auto intersect = [](auto const& lhs, + auto const& rhs) -> std::vector> { std::vector> intersection; auto left = lhs.begin(); auto right = rhs.begin(); @@ -106,7 +106,7 @@ TEST_CASE("Bigram collection index", "[.][v1][unit]") } return intersection; }; - auto to_vec = [](auto const &seq) { + auto to_vec = [](auto const& seq) { std::vector> vec; std::transform(seq.docs.begin(), seq.docs.end(), @@ -131,7 +131,7 @@ TEST_CASE("Bigram collection index", "[.][v1][unit]") if (not intersection.empty()) { auto 
id = index.bigram_id(term_id - 1, term_id); REQUIRE(id.has_value()); - auto postings = collect(index.cursor(*id), [](auto &cursor) { + auto postings = collect(index.cursor(*id), [](auto& cursor) { auto freqs = cursor.payload(); return std::make_tuple(*cursor, freqs[0], freqs[1]); }); @@ -276,16 +276,21 @@ TEST_CASE("Build raw document-frequency index", "[v1][unit]") { auto source = std::array, 2>{docbuf, freqbuf}; auto document_span = gsl::span( - reinterpret_cast(source[0].data()), source[0].size()); + reinterpret_cast(source[0].data()), source[0].size()); auto payload_span = gsl::span( - reinterpret_cast(source[1].data()), source[1].size()); + reinterpret_cast(source[1].data()), source[1].size()); IndexRunner runner(document_offsets, frequency_offsets, + {}, + {}, document_span, payload_span, + {}, + {}, document_sizes, tl::nullopt, + tl::nullopt, std::move(source), RawReader{}, RawReader{}); // Repeat to test that it only @@ -302,7 +307,7 @@ TEST_CASE("Build raw document-frequency index", "[v1][unit]") REQUIRE( std::vector(sequence.freqs.begin(), sequence.freqs.end()) == collect(index.cursor(term_id), - [](auto &&cursor) { return cursor.payload(); })); + [](auto&& cursor) { return cursor.payload(); })); term_id += 1; } }); @@ -313,15 +318,20 @@ TEST_CASE("Build raw document-frequency index", "[v1][unit]") { auto source = std::array, 2>{docbuf, freqbuf}; auto document_span = gsl::span( - reinterpret_cast(source[0].data()), source[0].size()); + reinterpret_cast(source[0].data()), source[0].size()); auto payload_span = gsl::span( - reinterpret_cast(source[1].data()), source[1].size()); + reinterpret_cast(source[1].data()), source[1].size()); IndexRunner runner(document_offsets, frequency_offsets, + {}, + {}, document_span, payload_span, + {}, + {}, document_sizes, tl::nullopt, + tl::nullopt, std::move(source), RawReader{}); // Correct encoding but not type! 
REQUIRE_THROWS_AS(runner([&](auto index) {}), std::domain_error); diff --git a/test/test_v1_blocked_cursor.cpp b/test/test_v1_blocked_cursor.cpp index f4f8cd71a..fa40ef9d5 100644 --- a/test/test_v1_blocked_cursor.cpp +++ b/test/test_v1_blocked_cursor.cpp @@ -57,24 +57,11 @@ TEST_CASE("Build single-block blocked document file", "[v1][unit]") TEST_CASE("Build blocked document-frequency index", "[v1][unit]") { - // Temporary_Directory tmpdir; - - //{ - // std::vector document_data{1, 1, 4, 1, 3, 6, 11}; - // std::vector frequency_data{4, 5, 4, 3, 2}; - // std::ofstream dos((tmpdir.path() / "x.docs").string()); - // std::ofstream fos((tmpdir.path() / "x.freqs").string()); - // dos.write(reinterpret_cast(document_data.data()), document_data.size() * 4); - // fos.write(reinterpret_cast(frequency_data.data()), frequency_data.size() * - // 4); - //} - using sink_type = boost::iostreams::back_insert_device>; using vector_stream_type = boost::iostreams::stream; GIVEN("A test binary collection") { pisa::binary_freq_collection collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); - // pisa::binary_freq_collection collection((tmpdir.path() / "x").string().c_str()); WHEN("Built posting files for documents and frequencies") { std::vector docbuf; @@ -104,20 +91,13 @@ TEST_CASE("Build blocked document-frequency index", "[v1][unit]") auto documents = gsl::span(docbuf).subspan(8); auto frequencies = gsl::span(freqbuf).subspan(8); - //{ - // std::ofstream dos("/home/elshize/test.documents"); - // std::ofstream fos("/home/elshize/test.frequencies"); - // dos.write(reinterpret_cast(docbuf.data()), docbuf.size()); - // fos.write(reinterpret_cast(freqbuf.data()), freqbuf.size()); - //} - THEN("The values read back are euqual to the binary collection's") { CHECK(docbuf.size() == document_offsets.back() + 8); BlockedReader document_reader; BlockedReader frequency_reader; auto term = 0; - std::for_each(collection.begin(), collection.end(), [&](auto &&seq) { + 
std::for_each(collection.begin(), collection.end(), [&](auto&& seq) { std::vector expected_documents(seq.docs.begin(), seq.docs.end()); auto actual_documents = collect(document_reader.read( documents.subspan(document_offsets[term], @@ -139,16 +119,21 @@ TEST_CASE("Build blocked document-frequency index", "[v1][unit]") { auto source = std::array, 2>{docbuf, freqbuf}; auto document_span = gsl::span( - reinterpret_cast(source[0].data()), source[0].size()); + reinterpret_cast(source[0].data()), source[0].size()); auto payload_span = gsl::span( - reinterpret_cast(source[1].data()), source[1].size()); + reinterpret_cast(source[1].data()), source[1].size()); IndexRunner runner(document_offsets, frequency_offsets, + {}, + {}, document_span, payload_span, + {}, + {}, document_sizes, tl::nullopt, + tl::nullopt, std::move(source), BlockedReader{}, BlockedReader{}); @@ -164,9 +149,22 @@ TEST_CASE("Build blocked document-frequency index", "[v1][unit]") == collect(index.cursor(term_id))); REQUIRE( std::vector(sequence.freqs.begin(), sequence.freqs.end()) - //== collect(index.payloads(term_id))); == collect(index.cursor(term_id), - [](auto &&cursor) { return cursor.payload(); })); + [](auto&& cursor) { return cursor.payload(); })); + { + auto cursor = index.cursor(term_id); + for (auto doc : sequence.docs) { + cursor.advance_to_geq(doc); + REQUIRE(cursor.value() == doc); + } + } + { + auto cursor = index.cursor(term_id); + for (auto doc : sequence.docs) { + REQUIRE(cursor.value() == doc); + cursor.advance_to_geq(doc + 1); + } + } term_id += 1; } }); @@ -177,15 +175,20 @@ TEST_CASE("Build blocked document-frequency index", "[v1][unit]") { auto source = std::array, 2>{docbuf, freqbuf}; auto document_span = gsl::span( - reinterpret_cast(source[0].data()), source[0].size()); + reinterpret_cast(source[0].data()), source[0].size()); auto payload_span = gsl::span( - reinterpret_cast(source[1].data()), source[1].size()); + reinterpret_cast(source[1].data()), source[1].size()); 
IndexRunner runner(document_offsets, frequency_offsets, + {}, + {}, document_span, payload_span, + {}, + {}, document_sizes, tl::nullopt, + tl::nullopt, std::move(source), RawReader{}); // Correct encoding but not type! REQUIRE_THROWS_AS(runner([&](auto index) {}), std::domain_error); diff --git a/test/test_v1_index.cpp b/test/test_v1_index.cpp index 6903630d3..c0cdede87 100644 --- a/test/test_v1_index.cpp +++ b/test/test_v1_index.cpp @@ -40,7 +40,7 @@ TEST_CASE("Binary collection index", "[v1][unit]") 8, make_writer(RawWriter{}), make_writer(RawWriter{})); - auto meta = IndexMetadata::from_file((tmpdir.path() / "index.ini").string()); + auto meta = IndexMetadata::from_file((tmpdir.path() / "index.yml").string()); REQUIRE(meta.documents.postings == (tmpdir.path() / "index.documents").string()); REQUIRE(meta.documents.offsets == (tmpdir.path() / "index.document_offsets").string()); REQUIRE(meta.frequencies.postings == (tmpdir.path() / "index.frequencies").string()); @@ -75,7 +75,7 @@ TEST_CASE("Binary collection index -- SIMDBP", "[v1][unit]") 8, make_writer(BlockedWriter<::pisa::simdbp_block, true>{}), make_writer(BlockedWriter<::pisa::simdbp_block, false>{})); - auto meta = IndexMetadata::from_file((tmpdir.path() / "index.ini").string()); + auto meta = IndexMetadata::from_file((tmpdir.path() / "index.yml").string()); REQUIRE(meta.documents.postings == (tmpdir.path() / "index.documents").string()); REQUIRE(meta.documents.offsets == (tmpdir.path() / "index.document_offsets").string()); REQUIRE(meta.frequencies.postings == (tmpdir.path() / "index.frequencies").string()); diff --git a/test/test_v1_queries.cpp b/test/test_v1_queries.cpp index 3f15034ef..acfc74e4f 100644 --- a/test/test_v1_queries.cpp +++ b/test/test_v1_queries.cpp @@ -56,21 +56,21 @@ struct IndexFixture { auto errors = v1::verify_compressed_index(PISA_SOURCE_DIR "/test/test_data/test_collection", index_basename); REQUIRE(errors.empty()); - auto meta = v1::IndexMetadata::from_file(fmt::format("{}.ini", 
index_basename)); + auto meta = v1::IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); auto run = v1::index_runner(meta, document_reader(), frequency_reader()); auto postings_path = fmt::format("{}.bm25", index_basename); auto offsets_path = fmt::format("{}.bm25_offsets", index_basename); - run([&](auto &&index) { + run([&](auto&& index) { std::ofstream score_file_stream(postings_path); auto offsets = score_index(index, score_file_stream, ScoreWriter{}, make_bm25(index)); v1::write_span(gsl::span(offsets), offsets_path); }); meta.scores.push_back( v1::PostingFilePaths{.postings = postings_path, .offsets = offsets_path}); - meta.write(fmt::format("{}.ini", index_basename)); + meta.write(fmt::format("{}.yml", index_basename)); } - [[nodiscard]] auto const &tmpdir() const { return *m_tmpdir; } + [[nodiscard]] auto const& tmpdir() const { return *m_tmpdir; } [[nodiscard]] auto document_reader() const { return m_document_reader; } [[nodiscard]] auto frequency_reader() const { return m_frequency_reader; } [[nodiscard]] auto score_reader() const { return m_score_reader; } @@ -86,7 +86,7 @@ struct IndexFixture { { std::vector queries; std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); - auto push_query = [&](std::string const &query_line) { + auto push_query = [&](std::string const& query_line) { queries.push_back(parse_query_ids(query_line)); }; io::for_each_line(qfile, push_query); @@ -112,7 +112,7 @@ struct IndexData { { typename v0_Index::builder builder(collection.num_docs(), params); - for (auto const &plist : collection) { + for (auto const& plist : collection) { uint64_t freqs_sum = std::accumulate(plist.freqs.begin(), plist.freqs.end(), uint64_t(0)); builder.add_posting_list( @@ -122,7 +122,7 @@ struct IndexData { term_id_vec q; std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); - auto push_query = [&](std::string const &query_line) { + auto push_query = [&](std::string const& query_line) { 
queries.push_back(parse_query_ids(query_line)); }; io::for_each_line(qfile, push_query); @@ -158,7 +158,7 @@ std::unique_ptr> IndexData::data = nullptr; TEMPLATE_TEST_CASE( - "DAAT OR2", + "DAAT OR", "[v1][integration]", (IndexFixture, v1::RawCursor, v1::RawCursor>), (IndexFixture, @@ -171,10 +171,10 @@ TEMPLATE_TEST_CASE( v1::Index, v1::RawCursor>>::get(); TestType fixture; auto index_basename = (fixture.tmpdir().path() / "inv").string(); - auto meta = v1::IndexMetadata::from_file(fmt::format("{}.ini", index_basename)); + auto meta = v1::IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); ranked_or_query or_q(10); int idx = 0; - for (auto const &q : test_queries()) { + for (auto const& q : test_queries()) { CAPTURE(q.terms); CAPTURE(idx++); @@ -186,7 +186,7 @@ TEMPLATE_TEST_CASE( auto run = v1::index_runner(meta, fixture.document_reader(), fixture.frequency_reader()); std::vector results; - run([&](auto &&index) { + run([&](auto&& index) { auto que = daat_or(v1::Query{q.terms}, index, topk_queue(10), make_bm25(index)); que.finalize(); results = que.topk(); @@ -199,7 +199,7 @@ TEMPLATE_TEST_CASE( auto run = v1::scored_index_runner(meta, fixture.document_reader(), fixture.score_reader()); std::vector results; - run([&](auto &&index) { + run([&](auto&& index) { auto que = daat_or(v1::Query{q.terms}, index, topk_queue(10), v1::VoidScorer{}); que.finalize(); results = que.topk(); @@ -218,3 +218,72 @@ TEMPLATE_TEST_CASE( } } } + +TEMPLATE_TEST_CASE( + "UnionLookup", + "[v1][integration]", + (IndexFixture, v1::RawCursor, v1::RawCursor>), + (IndexFixture, + v1::BlockedCursor<::pisa::simdbp_block, false>, + v1::RawCursor>)) +{ + tbb::task_scheduler_init init(1); + auto data = IndexData, v1::RawCursor>, + v1::Index, v1::RawCursor>>::get(); + TestType fixture; + auto index_basename = (fixture.tmpdir().path() / "inv").string(); + auto meta = v1::IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); + ranked_or_query or_q(10); + int idx = 0; + for 
(auto& q : test_queries()) { + CAPTURE(q.terms); + CAPTURE(idx++); + + or_q(make_scored_cursors(data->v0_index, data->wdata, q), data->v0_index.num_docs()); + auto expected = or_q.topk(); + std::sort(expected.begin(), expected.end(), std::greater{}); + + auto on_the_fly = [&]() { + auto run = + v1::index_runner(meta, fixture.document_reader(), fixture.frequency_reader()); + std::vector results; + run([&](auto&& index) { + std::vector unigrams(q.terms.size()); + std::iota(unigrams.begin(), unigrams.end(), 0); + auto que = union_lookup(v1::Query{q.terms}, + index, + topk_queue(10), + make_bm25(index), + std::move(unigrams), + {}); + que.finalize(); + results = que.topk(); + std::sort(results.begin(), results.end(), std::greater{}); + }); + return results; + }(); + + // auto precomputed = [&]() { + // auto run = + // v1::scored_index_runner(meta, fixture.document_reader(), fixture.score_reader()); + // std::vector results; + // run([&](auto&& index) { + // auto que = daat_or(v1::Query{q.terms}, index, topk_queue(10), v1::VoidScorer{}); + // que.finalize(); + // results = que.topk(); + // std::sort(results.begin(), results.end(), std::greater{}); + // }); + // return results; + //}(); + + REQUIRE(expected.size() == on_the_fly.size()); + // REQUIRE(expected.size() == precomputed.size()); + for (size_t i = 0; i < on_the_fly.size(); ++i) { + REQUIRE(on_the_fly[i].second == expected[i].second); + REQUIRE(on_the_fly[i].first == Approx(expected[i].first).epsilon(RELATIVE_ERROR)); + // REQUIRE(precomputed[i].second == expected[i].second); + // REQUIRE(precomputed[i].first == Approx(expected[i].first).epsilon(RELATIVE_ERROR)); + } + } +} From 024d01743bd2473b9ed9e0ccfa4b71647d9e47ed Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Wed, 20 Nov 2019 14:58:38 -0500 Subject: [PATCH 22/56] Max scores + maxscore + union-lookup --- .clang-tidy | 16 + CMakeLists.txt | 5 +- include/pisa/payload_vector.hpp | 8 +- include/pisa/v1/algorithm.hpp | 27 ++ 
include/pisa/v1/cursor/scoring_cursor.hpp | 32 ++ include/pisa/v1/cursor_accumulator.hpp | 16 + include/pisa/v1/cursor_union.hpp | 14 +- include/pisa/v1/index.hpp | 274 ++++++----------- include/pisa/v1/index_builder.hpp | 1 - include/pisa/v1/index_metadata.hpp | 38 ++- include/pisa/v1/maxscore.hpp | 278 ++++++++++++++++++ include/pisa/v1/query.hpp | 121 -------- include/pisa/v1/score_index.hpp | 9 + include/pisa/v1/union_lookup.hpp | 151 ++++++++++ src/v1/index_metadata.cpp | 40 ++- src/v1/score_index.cpp | 93 ++++++ test/v1/CMakeLists.txt | 27 ++ test/v1/index_fixture.hpp | 64 ++++ test/{ => v1}/test_v1.cpp | 91 +----- test/{ => v1}/test_v1_blocked_cursor.cpp | 6 +- .../test_v1_document_payload_cursor.cpp | 4 +- test/{ => v1}/test_v1_index.cpp | 45 +-- test/v1/test_v1_maxscore_join.cpp | 132 +++++++++ test/{ => v1}/test_v1_queries.cpp | 212 ++++++------- test/v1/test_v1_score_index.cpp | 84 ++++++ v1/CMakeLists.txt | 7 +- v1/app.hpp | 37 +++ v1/bigram_index.cpp | 10 +- v1/compress.cpp | 5 +- v1/query.cpp | 123 +++++--- v1/score.cpp | 81 +---- v1/union_lookup.cpp | 227 ++++++++++++++ 32 files changed, 1601 insertions(+), 677 deletions(-) create mode 100644 .clang-tidy create mode 100644 include/pisa/v1/algorithm.hpp create mode 100644 include/pisa/v1/cursor_accumulator.hpp create mode 100644 include/pisa/v1/maxscore.hpp create mode 100644 include/pisa/v1/score_index.hpp create mode 100644 include/pisa/v1/union_lookup.hpp create mode 100644 src/v1/score_index.cpp create mode 100644 test/v1/CMakeLists.txt create mode 100644 test/v1/index_fixture.hpp rename test/{ => v1}/test_v1.cpp (77%) rename test/{ => v1}/test_v1_blocked_cursor.cpp (97%) rename test/{ => v1}/test_v1_document_payload_cursor.cpp (97%) rename test/{ => v1}/test_v1_index.cpp (70%) create mode 100644 test/v1/test_v1_maxscore_join.cpp rename test/{ => v1}/test_v1_queries.cpp (57%) create mode 100644 test/v1/test_v1_score_index.cpp create mode 100644 v1/app.hpp create mode 100644 v1/union_lookup.cpp 
diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 000000000..df6d7beae --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,16 @@ +HeaderFilterRegex: '.*include/pisa.*\.hpp' +Checks: | + *, + -clang-diagnostic-c++17-extensions, + -llvm-header-guard, + -cppcoreguidelines-pro-type-reinterpret-cast, + -google-runtime-references, + -fuchsia-*, + -google-readability-namespace-comments, + -llvm-namespace-comment, + -clang-diagnostic-error, + -cppcoreguidelines-pro-bounds-pointer-arithmetic, + -cppcoreguidelines-avoid-magic-numbers, + -cppcoreguidelines-pro-bounds-array-to-pointer-decay, + -modernize-use-trailing-return-type, + -misc-non-private-member-variables-in-classes diff --git a/CMakeLists.txt b/CMakeLists.txt index 9faabf7de..eba106995 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,7 +51,7 @@ if (UNIX) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") # Extensive warnings - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wno-missing-braces") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wno-missing-braces -Wfatal-errors") if (USE_SANITIZERS) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer") @@ -108,7 +108,8 @@ add_subdirectory(v1) if (PISA_ENABLE_TESTING AND BUILD_TESTING) enable_testing() - add_subdirectory(test) + #add_subdirectory(test) + add_subdirectory(test/v1) endif() if (PISA_ENABLE_BENCHMARKING) diff --git a/include/pisa/payload_vector.hpp b/include/pisa/payload_vector.hpp index 58c3cd88b..0e8d8f593 100644 --- a/include/pisa/payload_vector.hpp +++ b/include/pisa/payload_vector.hpp @@ -81,13 +81,13 @@ namespace detail { [[nodiscard]] constexpr auto operator*() -> value_type { - if constexpr (std::is_trivially_copyable_v) { + if constexpr (std::is_same_v) { + return value_type(reinterpret_cast(&*payload_iter), + *std::next(offset_iter) - *offset_iter); + } else { value_type value; std::memcpy(&value, reinterpret_cast(&*payload_iter), sizeof(value)); return value; - } else { - return 
value_type(reinterpret_cast(&*payload_iter), - *std::next(offset_iter) - *offset_iter); } } diff --git a/include/pisa/v1/algorithm.hpp b/include/pisa/v1/algorithm.hpp new file mode 100644 index 000000000..19a3f93be --- /dev/null +++ b/include/pisa/v1/algorithm.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include + +namespace pisa::v1 { + +template +[[nodiscard]] auto min_value(Container&& cursors) +{ + auto pos = + std::min_element(cursors.begin(), cursors.end(), [](auto const& lhs, auto const& rhs) { + return lhs.value() < rhs.value(); + }); + return pos->value(); +} + +template +[[nodiscard]] auto min_sentinel(Container&& cursors) +{ + auto pos = + std::min_element(cursors.begin(), cursors.end(), [](auto const& lhs, auto const& rhs) { + return lhs.sentinel() < rhs.sentinel(); + }); + return pos->sentinel(); +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor/scoring_cursor.hpp b/include/pisa/v1/cursor/scoring_cursor.hpp index dbe84ea5c..bfd9516d1 100644 --- a/include/pisa/v1/cursor/scoring_cursor.hpp +++ b/include/pisa/v1/cursor/scoring_cursor.hpp @@ -44,4 +44,36 @@ struct ScoringCursor { TermScorer m_scorer; }; +template +struct MaxScoreCursor { + using Document = decltype(*std::declval()); + using Payload = decltype(std::declval().payload()); + + constexpr MaxScoreCursor(BaseCursor base_cursor, ScoreT max_score) + : m_base_cursor(std::move(base_cursor)), m_max_score(max_score) + { + } + [[nodiscard]] constexpr auto operator*() const -> Document { return m_base_cursor.value(); } + [[nodiscard]] constexpr auto value() const noexcept -> Document + { + return m_base_cursor.value(); + } + [[nodiscard]] constexpr auto payload() noexcept { return m_base_cursor.payload(); } + constexpr void advance() { m_base_cursor.advance(); } + constexpr void advance_to_position(std::size_t pos) { m_base_cursor.advance_to_position(pos); } + constexpr void advance_to_geq(Document value) { m_base_cursor.advance_to_geq(value); } + [[nodiscard]] constexpr auto empty() const 
noexcept -> bool { return m_base_cursor.empty(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return m_base_cursor.position(); + } + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_base_cursor.size(); } + [[nodiscard]] constexpr auto sentinel() const -> Document { return m_base_cursor.sentinel(); } + [[nodiscard]] constexpr auto max_score() const -> ScoreT { return m_max_score; } + + private: + BaseCursor m_base_cursor; + float m_max_score; +}; + } // namespace pisa::v1 diff --git a/include/pisa/v1/cursor_accumulator.hpp b/include/pisa/v1/cursor_accumulator.hpp new file mode 100644 index 000000000..5eb1df5a4 --- /dev/null +++ b/include/pisa/v1/cursor_accumulator.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include + +namespace pisa::v1::accumulate { + +struct Add { + template + auto operator()(Score&& score, Cursor&& cursor, std::size_t /* term_idx */) + { + score += cursor.payload(); + return score; + } +}; + +} // namespace pisa::v1::accumulate diff --git a/include/pisa/v1/cursor_union.hpp b/include/pisa/v1/cursor_union.hpp index 1cf9596d5..cd30d090e 100644 --- a/include/pisa/v1/cursor_union.hpp +++ b/include/pisa/v1/cursor_union.hpp @@ -7,6 +7,7 @@ #include #include "util/likely.hpp" +#include "v1/algorithm.hpp" namespace pisa::v1 { @@ -29,17 +30,8 @@ struct CursorUnion { if (m_cursors.empty()) { m_current_value = std::numeric_limits::max(); } else { - auto order = [](auto const& lhs, auto const& rhs) { return lhs.value() < rhs.value(); }; - m_next_docid = [&]() { - auto pos = std::min_element(m_cursors.begin(), m_cursors.end(), order); - return pos->value(); - }(); - m_sentinel = std::min_element(m_cursors.begin(), - m_cursors.end(), - [](auto const& lhs, auto const& rhs) { - return lhs.sentinel() < rhs.sentinel(); - }) - ->sentinel(); + m_next_docid = min_value(m_cursors); + m_sentinel = min_sentinel(m_cursors); advance(); } } diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index 
f355d8f23..a5a59d494 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -41,6 +41,18 @@ using BinarySpan = gsl::span; return static_cast(sum) / lengths.size(); } +[[nodiscard]] inline auto compare_arrays(std::array const& lhs, + std::array const& rhs) -> bool +{ + if (std::get<0>(lhs) < std::get<0>(rhs)) { + return true; + } + if (std::get<0>(lhs) > std::get<0>(rhs)) { + return false; + } + return std::get<1>(lhs) < std::get<1>(rhs); +} + /// A generic type for an inverted index. /// /// \tparam DocumentReader Type of an object that reads document posting lists from bytes @@ -81,7 +93,9 @@ struct Index { tl::optional> bigram_frequencies, gsl::span document_lengths, tl::optional avg_document_length, - tl::optional<::pisa::Payload_Vector>> bigram_mapping, + std::unordered_map> max_scores, + gsl::span quantized_max_scores, + tl::optional const>> bigram_mapping, Source source) : m_document_reader(std::move(document_reader)), m_payload_reader(std::move(payload_reader)), @@ -97,6 +111,8 @@ struct Index { m_avg_document_length(avg_document_length.map_or_else( [](auto&& self) { return self; }, [&]() { return calc_avg_length(m_document_lengths); })), + m_max_scores(std::move(max_scores)), + m_max_quantized_scores(quantized_max_scores), m_bigram_mapping(bigram_mapping), m_source(std::move(source)) { @@ -109,6 +125,15 @@ struct Index { payloads(term)); } + [[nodiscard]] auto cursors(gsl::span terms) const + { + std::vector cursors; + std::transform(terms.begin(), terms.end(), std::back_inserter(cursors), [&](auto term) { + return cursor(term); + }); + return cursors; + } + [[nodiscard]] auto bigram_cursor(TermId left_term, TermId right_term) const { if (not m_bigram_mapping) { @@ -116,7 +141,8 @@ struct Index { } if (auto pos = std::lower_bound(m_bigram_mapping->begin(), m_bigram_mapping->end(), - std::array{left_term, right_term}); + std::array{left_term, right_term}, + compare_arrays); pos != m_bigram_mapping->end()) { auto bigram_id = 
std::distance(m_bigram_mapping->begin(), pos); return DocumentPayloadCursor>( @@ -135,6 +161,16 @@ struct Index { return ScoringCursor(cursor(term), std::forward(scorer).term_scorer(term)); } + template + [[nodiscard]] auto scoring_cursors(gsl::span terms, Scorer&& scorer) const + { + std::vector cursors; + std::transform(terms.begin(), terms.end(), std::back_inserter(cursors), [&](auto term) { + return scoring_cursor(term, scorer); + }); + return cursors; + } + /// This is equivalent to the `scoring_cursor` unless the scorer is of type `VoidScorer`, /// in which case index payloads are treated as scores. template @@ -147,6 +183,41 @@ struct Index { } } + template + [[nodiscard]] auto scored_cursors(gsl::span terms, Scorer&& scorer) const + { + std::vector cursors; + std::transform(terms.begin(), terms.end(), std::back_inserter(cursors), [&](auto term) { + return scored_cursor(term, scorer); + }); + return cursors; + } + + template + [[nodiscard]] auto max_scored_cursor(TermId term, Scorer&& scorer) const + { + using cursor_type = + std::decay_t(scorer)))>; + if constexpr (std::is_convertible_v) { + return MaxScoreCursor( + scored_cursor(term, std::forward(scorer)), m_max_quantized_scores[term]); + } else { + return MaxScoreCursor( + scored_cursor(term, std::forward(scorer)), + m_max_scores.at(std::hash>{}(scorer))[term]); + } + } + + template + [[nodiscard]] auto max_scored_cursors(gsl::span terms, Scorer&& scorer) const + { + std::vector cursors; + std::transform(terms.begin(), terms.end(), std::back_inserter(cursors), [&](auto term) { + return max_scored_cursor(term, scorer); + }); + return cursors; + } + /// Constructs a new document-score cursor. 
template [[nodiscard]] auto scoring_bigram_cursor(TermId left_term, @@ -261,7 +332,9 @@ struct Index { gsl::span m_document_lengths; float m_avg_document_length; - tl::optional<::pisa::Payload_Vector>> m_bigram_mapping; + std::unordered_map> m_max_scores; + gsl::span m_max_quantized_scores; + tl::optional const>> m_bigram_mapping; std::any m_source; }; @@ -278,7 +351,9 @@ auto make_index(DocumentReader document_reader, tl::optional> bigram_frequencies, gsl::span document_lengths, tl::optional avg_document_length, - tl::optional<::pisa::Payload_Vector>> bigram_mapping, + std::unordered_map> max_scores, + gsl::span quantized_max_scores, + tl::optional const>> bigram_mapping, Source source) { using DocumentCursor = @@ -296,6 +371,8 @@ auto make_index(DocumentReader document_reader, bigram_frequencies, document_lengths, avg_document_length, + max_scores, + quantized_max_scores, bigram_mapping, std::move(source)); } @@ -393,177 +470,6 @@ inline auto read_sizes(std::string_view basename) return source; } -[[nodiscard]] inline auto binary_collection_index(std::string const& basename) -{ - auto source = binary_collection_source(basename); - auto documents = gsl::span(source.bytes[0]); - auto frequencies = gsl::span(source.bytes[1]); - auto document_offsets = gsl::span(source.offsets[0]); - auto frequency_offsets = gsl::span(source.offsets[1]); - auto sizes = gsl::span(source.sizes[0]); - return Index, RawCursor>(RawReader{}, - RawReader{}, - document_offsets, - frequency_offsets, - {}, - {}, - documents.subspan(8), - frequencies.subspan(8), - {}, - {}, - sizes, - tl::nullopt, - tl::nullopt, - std::move(source)); -} - -[[nodiscard]] inline auto binary_collection_scored_index(std::string const& basename) -{ - using sink_type = boost::iostreams::back_insert_device>; - using vector_stream_type = boost::iostreams::stream; - - auto source = binary_collection_source(basename); - auto documents = gsl::span(source.bytes[0]); - auto frequencies = gsl::span(source.bytes[1]); - auto 
sizes = gsl::span(source.sizes[0]); - auto document_offsets = gsl::span(source.offsets[0]); - auto frequency_offsets = gsl::span(source.offsets[1]); - auto freq_index = Index, RawCursor>(RawReader{}, - RawReader{}, - document_offsets, - frequency_offsets, - {}, - {}, - documents.subspan(8), - frequencies.subspan(8), - {}, - {}, - sizes, - tl::nullopt, - tl::nullopt, - false); - - source.offsets.push_back([&freq_index, &source]() { - vector_stream_type score_stream{sink_type{source.bytes.emplace_back()}}; - return score_index(freq_index, score_stream, RawWriter{}, make_bm25(freq_index)); - }()); - auto scores = gsl::span(source.bytes.back()); - - document_offsets = gsl::span(source.offsets[0]); - auto score_offsets = gsl::span(source.offsets[2]); - return Index, RawCursor>(RawReader{}, - RawReader{}, - document_offsets, - score_offsets, - {}, - {}, - documents.subspan(8), - scores.subspan(8), - {}, - {}, - sizes, - tl::nullopt, - tl::nullopt, - std::move(source)); -} - -template -struct BigramIndex : public Index { - using PairMapping = std::vector>; - - BigramIndex(Index index, PairMapping pair_mapping) - : Index(std::move(index)), m_pair_mapping(std::move(pair_mapping)) - { - } - - [[nodiscard]] auto bigram_id(TermId left, TermId right) -> tl::optional - { - auto pos = - std::find(m_pair_mapping.begin(), m_pair_mapping.end(), std::make_pair(left, right)); - if (pos != m_pair_mapping.end()) { - return tl::make_optional(std::distance(m_pair_mapping.begin(), pos)); - } - return tl::nullopt; - } - - private: - PairMapping m_pair_mapping; -}; - -/// Creates, on the fly, a bigram index with all pairs of adjecent terms. -/// Disclaimer: for testing purposes. 
-[[nodiscard]] inline auto binary_collection_bigram_index(std::string const& basename) -{ - using payload_type = std::array; - using sink_type = boost::iostreams::back_insert_device>; - using vector_stream_type = boost::iostreams::stream; - - auto unigram_index = binary_collection_index(basename); - - std::vector> pair_mapping; - std::vector docbuf; - std::vector freqbuf; - - PostingBuilder document_builder(RawWriter{}); - PostingBuilder frequency_builder(RawWriter{}); - { - vector_stream_type docstream{sink_type{docbuf}}; - vector_stream_type freqstream{sink_type{freqbuf}}; - - document_builder.write_header(docstream); - frequency_builder.write_header(freqstream); - - std::for_each(boost::counting_iterator(0), - boost::counting_iterator(unigram_index.num_terms() - 1), - [&](auto left) { - auto right = left + 1; - auto intersection = CursorIntersection( - std::vector{unigram_index.cursor(left), unigram_index.cursor(right)}, - payload_type{0, 0}, - [](payload_type& payload, auto& cursor, auto list_idx) { - payload[list_idx] = cursor.payload(); - return payload; - }); - if (intersection.empty()) { - // Include only non-empty intersections. 
- return; - } - pair_mapping.emplace_back(left, right); - for_each(intersection, [&](auto& cursor) { - document_builder.accumulate(*cursor); - frequency_builder.accumulate(cursor.payload()); - }); - document_builder.flush_segment(docstream); - frequency_builder.flush_segment(freqstream); - }); - } - - VectorSource source{ - {std::move(docbuf), std::move(freqbuf)}, - {std::move(document_builder.offsets()), std::move(frequency_builder.offsets())}, - {read_sizes(basename)}}; - auto document_span = gsl::span(source.bytes[0]); - auto payload_span = gsl::span(source.bytes[1]); - auto document_offsets = gsl::span(source.offsets[0]); - auto frequency_offsets = gsl::span(source.offsets[1]); - auto sizes = gsl::span(source.sizes[0]); - auto index = Index, RawCursor>(RawReader{}, - RawReader{}, - document_offsets, - frequency_offsets, - {}, - {}, - document_span.subspan(8), - payload_span.subspan(8), - {}, - {}, - sizes, - tl::nullopt, - tl::nullopt, - std::move(source)); - return BigramIndex(std::move(index), std::move(pair_mapping)); -} - template struct IndexRunner { template @@ -577,7 +483,9 @@ struct IndexRunner { tl::optional> bigram_frequencies, gsl::span document_lengths, tl::optional avg_document_length, - tl::optional<::pisa::Payload_Vector>> bigram_mapping, + std::unordered_map> max_scores, + gsl::span quantized_max_scores, + tl::optional> const> bigram_mapping, Source source, Readers... readers) : m_document_offsets(document_offsets), @@ -590,6 +498,8 @@ struct IndexRunner { m_bigram_frequencies(bigram_frequencies), m_document_lengths(document_lengths), m_avg_document_length(avg_document_length), + m_max_scores(std::move(max_scores)), + m_max_quantized_scores(quantized_max_scores), m_bigram_mapping(bigram_mapping), m_source(std::move(source)), m_readers(readers...) 
@@ -606,7 +516,9 @@ struct IndexRunner { tl::optional> bigram_frequencies, gsl::span document_lengths, tl::optional avg_document_length, - tl::optional<::pisa::Payload_Vector>> bigram_mapping, + std::unordered_map> max_scores, + gsl::span quantized_max_scores, + tl::optional const>> bigram_mapping, Source source, std::tuple readers) : m_document_offsets(document_offsets), @@ -619,6 +531,8 @@ struct IndexRunner { m_bigram_frequencies(bigram_frequencies), m_document_lengths(document_lengths), m_avg_document_length(avg_document_length), + m_max_scores(std::move(max_scores)), + m_max_quantized_scores(quantized_max_scores), m_bigram_mapping(bigram_mapping), m_source(std::move(source)), m_readers(std::move(readers)) @@ -651,6 +565,8 @@ struct IndexRunner { }), m_document_lengths, m_avg_document_length, + m_max_scores, + m_max_quantized_scores, m_bigram_mapping, false); fn(index); @@ -686,7 +602,9 @@ struct IndexRunner { gsl::span m_document_lengths; tl::optional m_avg_document_length; - tl::optional<::pisa::Payload_Vector>> m_bigram_mapping; + std::unordered_map> m_max_scores; + gsl::span m_max_quantized_scores; + tl::optional const>> m_bigram_mapping; std::any m_source; std::tuple m_readers; }; diff --git a/include/pisa/v1/index_builder.hpp b/include/pisa/v1/index_builder.hpp index 4821d9e29..9ee6d3ffc 100644 --- a/include/pisa/v1/index_builder.hpp +++ b/include/pisa/v1/index_builder.hpp @@ -5,7 +5,6 @@ #include #include -#include #include #include diff --git a/include/pisa/v1/index_metadata.hpp b/include/pisa/v1/index_metadata.hpp index fadbd1c3e..59364936c 100644 --- a/include/pisa/v1/index_metadata.hpp +++ b/include/pisa/v1/index_metadata.hpp @@ -57,6 +57,8 @@ struct IndexMetadata { tl::optional document_lexicon{}; tl::optional stemmer{}; tl::optional bigrams{}; + std::map max_scores{}; + std::map quantized_max_scores{}; void write(std::string const& file); [[nodiscard]] static auto from_file(std::string const& file) -> IndexMetadata; @@ -95,7 +97,7 @@ template 
tl::optional, 2>> bigram_frequency_offsets{}; tl::optional> bigram_documents{}; tl::optional, 2>> bigram_frequencies{}; - tl::optional<::pisa::Payload_Vector>> bigram_mapping{}; + tl::optional const>> bigram_mapping{}; if (metadata.bigrams) { bigram_document_offsets = source_span(source, metadata.bigrams->documents.offsets); @@ -106,14 +108,18 @@ template bigram_frequencies = { source_span(source, metadata.bigrams->frequencies.first.postings), source_span(source, metadata.bigrams->frequencies.second.postings)}; - auto mapping_span = source_span(source, metadata.bigrams->mapping).subspan(8); - auto num_offset_bytes = (metadata.bigrams->count + 1U) * 8U; - auto mapping_offsets = mapping_span.first(num_offset_bytes); - bigram_mapping = Payload_Vector>( - gsl::span( - reinterpret_cast(mapping_offsets.data()), - mapping_offsets.size() * sizeof(std::size_t)), - mapping_span.subspan(num_offset_bytes)); + auto mapping_span = source_span(source, metadata.bigrams->mapping); + bigram_mapping = gsl::span const>( + reinterpret_cast const*>(mapping_span.data()), + mapping_span.size() / (sizeof(TermId) * 2)); + } + std::unordered_map> max_scores; + if (not metadata.max_scores.empty()) { + for (auto [name, file] : metadata.max_scores) { + auto bytes = source_span(source, file); + max_scores[std::hash{}(name)] = gsl::span( + reinterpret_cast(bytes.data()), bytes.size() / (sizeof(float))); + } } return IndexRunner(document_offsets, frequency_offsets, @@ -125,6 +131,8 @@ template bigram_frequencies, document_lengths, tl::make_optional(metadata.avg_document_length), + std::move(max_scores), + {}, bigram_mapping, std::move(source), std::move(readers)); @@ -142,10 +150,18 @@ template { MMapSource source; auto documents = source_span(source, metadata.documents.postings); + // TODO(michal): support many precomputed scores auto scores = source_span(source, metadata.scores.front().postings); auto document_offsets = source_span(source, metadata.documents.offsets); auto score_offsets = 
source_span(source, metadata.scores.front().offsets); auto document_lengths = source_span(source, metadata.document_lengths_path); + gsl::span quantized_max_scores; + if (not metadata.quantized_max_scores.empty()) { + // TODO(michal): support many precomputed scores + for (auto [name, file] : metadata.quantized_max_scores) { + quantized_max_scores = source_span(source, file); + } + } return IndexRunner(document_offsets, score_offsets, {}, @@ -156,7 +172,9 @@ template {}, document_lengths, tl::make_optional(metadata.avg_document_length), - {}, // TODO + {}, + quantized_max_scores, + {}, std::move(source), std::move(readers)); } diff --git a/include/pisa/v1/maxscore.hpp b/include/pisa/v1/maxscore.hpp new file mode 100644 index 000000000..485800925 --- /dev/null +++ b/include/pisa/v1/maxscore.hpp @@ -0,0 +1,278 @@ +#pragma once + +#include +#include +#include + +#include +#include + +#include "v1/algorithm.hpp" +#include "v1/query.hpp" + +namespace pisa::v1 { + +template +struct MaxScoreJoin { + using cursor_type = typename CursorContainer::value_type; + using payload_type = Payload; + using value_type = std::decay_t())>; + + using iterator_category = + typename std::iterator_traits::iterator_category; + static_assert(std::is_base_of(), + "cursors must be stored in a random access container"); + + constexpr MaxScoreJoin(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold) + : m_cursors(std::move(cursors)), + m_sorted_cursors(m_cursors.size()), + m_cursor_idx(m_cursors.size()), + m_upper_bounds(m_cursors.size()), + m_init(std::move(init)), + m_accumulate(std::move(accumulate)), + m_above_threshold(std::move(above_threshold)), + m_size(std::nullopt) + { + std::transform(m_cursors.begin(), + m_cursors.end(), + m_sorted_cursors.begin(), + [](auto&& cursor) { return &cursor; }); + std::sort(m_sorted_cursors.begin(), m_sorted_cursors.end(), [](auto&& lhs, auto&& rhs) { + return lhs->max_score() < rhs->max_score(); + }); + 
std::iota(m_cursor_idx.begin(), m_cursor_idx.end(), 0); + std::sort(m_cursor_idx.begin(), m_cursor_idx.end(), [this](auto&& lhs, auto&& rhs) { + return m_cursors[lhs].max_score() < m_cursors[rhs].max_score(); + }); + + m_upper_bounds[0] = m_sorted_cursors[0]->max_score(); + for (size_t i = 1; i < m_sorted_cursors.size(); ++i) { + m_upper_bounds[i] = m_upper_bounds[i - 1] + m_sorted_cursors[i]->max_score(); + } + + m_next_docid = min_value(m_cursors); + m_sentinel = min_sentinel(m_cursors); + advance(); + } + + [[nodiscard]] constexpr auto operator*() const noexcept -> value_type + { + return m_current_value; + } + [[nodiscard]] constexpr auto value() const noexcept -> value_type { return m_current_value; } + [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& + { + return m_current_payload; + } + [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_sentinel; } + + constexpr void advance() + { + bool exit = false; + while (not exit) { + if (PISA_UNLIKELY(m_non_essential_count == m_cursors.size() + || m_next_docid >= sentinel())) { + m_current_value = sentinel(); + m_current_payload = m_init; + return; + } + m_current_payload = m_init; + m_current_value = std::exchange(m_next_docid, sentinel()); + + for (auto sorted_position = m_non_essential_count; + sorted_position < m_sorted_cursors.size(); + sorted_position += 1) { + + auto& cursor = m_sorted_cursors[sorted_position]; + if (cursor->value() == m_current_value) { + m_current_payload = + m_accumulate(m_current_payload, *cursor, m_cursor_idx[sorted_position]); + cursor->advance(); + } + if (auto docid = cursor->value(); docid < m_next_docid) { + m_next_docid = docid; + } + } + + exit = true; + for (auto sorted_position = m_non_essential_count - 1; sorted_position + 1 > 0; + sorted_position -= 1) { + if (not m_above_threshold(m_current_payload + m_upper_bounds[sorted_position])) { + exit = false; + break; + } + auto& cursor = m_sorted_cursors[sorted_position]; + 
cursor->advance_to_geq(m_current_value); + if (cursor->value() == m_current_value) { + m_current_payload = + m_accumulate(m_current_payload, *cursor, m_cursor_idx[sorted_position]); + } + } + } + + while (m_non_essential_count < m_cursors.size() + && not m_above_threshold(m_upper_bounds[m_non_essential_count])) { + m_non_essential_count += 1; + } + } + + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; // TODO(michal) + [[nodiscard]] constexpr auto empty() const noexcept -> bool + { + return m_current_value >= sentinel(); + } + + private: + CursorContainer m_cursors; + std::vector m_sorted_cursors; + std::vector m_cursor_idx; + std::vector m_upper_bounds; + payload_type m_init; + AccumulateFn m_accumulate; + ThresholdFn m_above_threshold; + std::optional m_size; + + value_type m_current_value{}; + value_type m_sentinel{}; + payload_type m_current_payload{}; + std::uint32_t m_next_docid{}; + std::size_t m_non_essential_count = 0; + payload_type m_previous_threshold{}; +}; + +template +auto join_maxscore(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn threshold) +{ + return MaxScoreJoin( + std::move(cursors), std::move(init), std::move(accumulate), std::move(threshold)); +} + +template +auto maxscore(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + decltype(index.max_scored_cursor(0, scorer).max_score()) initial_threshold) +{ + using cursor_type = decltype(index.max_scored_cursor(0, scorer)); + using score_type = decltype(index.max_scored_cursor(0, scorer).max_score()); + using value_type = decltype(index.max_scored_cursor(0, scorer).value()); + + std::vector cursors; + std::transform(query.terms.begin(), + query.terms.end(), + std::back_inserter(cursors), + [&](auto term) { return index.max_scored_cursor(term, scorer); }); + + auto joined = join_maxscore( + std::move(cursors), + 0.0F, + [](auto& score, auto& cursor, auto /* term_position */) { + score += cursor.payload(); + 
return score; + }, + [&](auto score) { return topk.would_enter(score) && score > initial_threshold; }); + v1::for_each(joined, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); + return topk; +} + +template +auto maxscore(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) +{ + using cursor_type = decltype(index.max_scored_cursor(0, scorer)); + using score_type = decltype(index.max_scored_cursor(0, scorer).max_score()); + using value_type = decltype(index.max_scored_cursor(0, scorer).value()); + + std::vector cursors; + std::transform(query.terms.begin(), + query.terms.end(), + std::back_inserter(cursors), + [&](auto term) { return index.max_scored_cursor(term, scorer); }); + + auto joined = join_maxscore( + std::move(cursors), + 0.0F, + [](auto& score, auto& cursor, auto /* term_position */) { + score += cursor.payload(); + return score; + }, + [&](auto score) { return topk.would_enter(score); }); + v1::for_each(joined, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); + return topk; + // template MaxScoreJoin { + + // std::vector sorted_cursors; + // std::transform(cursors.begin(), + // cursors.end(), + // std::back_inserter(sorted_cursors), + // [](auto&& cursor) { return &cursor; }); + // std::sort(sorted_cursors.begin(), sorted_cursors.end(), [](auto&& lhs, auto&& rhs) { + // return lhs->max_score() < rhs->max_score(); + //}); + + // std::vector upper_bounds(sorted_cursors.size()); + //// upper_bounds.push_back(cursors); + //// for (auto* cursor : sorted_cursors) { + //// + ////} + // upper_bounds[0] = sorted_cursors[0]->max_score(); + // for (size_t i = 1; i < sorted_cursors.size(); ++i) { + // upper_bounds[i] = upper_bounds[i - 1] + sorted_cursors[i]->max_score(); + //} + //// std::partial_sum(sorted_cursors.begin(), + //// sorted_cursors.end(), + //// upper_bounds.begin(), + //// [](auto sum, auto* cursor) { return sum + cursor->max_score(); }); + // auto essential = 
gsl::make_span(sorted_cursors); + // auto non_essential = essential.subspan(essential.size()); // Empty + + // DocId current_docid = + // std::min_element(cursors.begin(), cursors.end(), [](auto&& lhs, auto&& rhs) { + // return *lhs < *rhs; + // })->value(); + + // while (not essential.empty() && current_docid < std::numeric_limits::max()) { + // score_type score{}; + // auto next_docid = std::numeric_limits::max(); + + // for (auto cursor : essential) { + // if (cursor->value() == current_docid) { + // score += cursor->payload(); + // cursor->advance(); + // } + // if (auto docid = cursor->value(); docid < next_docid) { + // next_docid = docid; + // } + // } + + // for (auto idx = non_essential.size() - 1; idx + 1 > 0; idx -= 1) { + // if (not topk.would_enter(score + upper_bounds[idx])) { + // break; + // } + // sorted_cursors[idx]->advance_to_geq(current_docid); + // if (sorted_cursors[idx]->value() == current_docid) { + // score += sorted_cursors[idx]->payload(); + // } + // } + // if (topk.insert(score, current_docid)) { + // //// update non-essential lists + // while (not essential.empty() + // && not topk.would_enter(upper_bounds[non_essential.size()])) { + // essential = essential.first(essential.size() - 1); + // non_essential = gsl::make_span(sorted_cursors).subspan(essential.size()); + // } + // } + // current_docid = next_docid; + //} + + // return topk; +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/query.hpp b/include/pisa/v1/query.hpp index 84f72756f..14d92c5cf 100644 --- a/include/pisa/v1/query.hpp +++ b/include/pisa/v1/query.hpp @@ -19,32 +19,6 @@ struct Query { std::vector> bigrams{}; }; -template -using QueryProcessor = std::function; - -struct ExhaustiveConjunctiveProcessor { - template - auto operator()(Index const& index, Query const& query, topk_queue que) -> topk_queue - { - using Cursor = std::decay_t; - std::vector cursors; - std::transform(query.terms.begin(), - query.terms.end(), - std::back_inserter(cursors), - [&index](auto 
term_id) { return index.cursor(term_id); }); - auto intersection = - intersect(std::move(cursors), - 0.0F, - [](float score, auto& cursor, [[maybe_unused]] auto cursor_idx) { - return score + static_cast(cursor.payload()); - }); - while (not intersection.empty()) { - que.insert(intersection.payload(), *intersection); - } - return que; - } -}; - template auto daat_and(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) { @@ -101,99 +75,4 @@ auto transform() { } -/// Performs a "union-lookup" query (name pending). -/// -/// \param query Full query, as received, possibly with duplicates. -/// \param index Inverted index, with access to both unigrams and bigrams. -/// \param topk Top-k heap. -/// \param scorer An object capable of constructing term scorers. -/// \param essential_unigrams A list of essential single-term posting lists. -/// Elements of this vector point to the index of the term -/// in the query. In other words, for each position `i` in this vector, -/// `query.terms[essential_unigrams[i]]` is an essential unigram. -/// \param essential_bigrams Similar to the above, but represents intersections between two -/// posting lists. These must exist in the index, or else this -/// algorithm will fail. 
-template -auto union_lookup(Query const& query, - Index const& index, - topk_queue topk, - Scorer&& scorer, - std::vector essential_unigrams, - std::vector> essential_bigrams) -{ - ranges::sort(essential_unigrams); - ranges::actions::unique(essential_unigrams); - ranges::sort(essential_bigrams); - ranges::actions::unique(essential_bigrams); - - std::vector initial_payload(query.terms.size(), 0.0); - - std::vector essential_unigram_cursors; - std::transform(essential_unigrams.begin(), - essential_unigrams.end(), - std::back_inserter(essential_unigram_cursors), - [&](auto idx) { return index.scored_cursor(query.terms[idx], scorer); }); - auto merged_unigrams = v1::union_merge( - essential_unigram_cursors, initial_payload, [&](auto& acc, auto& cursor, auto term_idx) { - acc[essential_unigrams[term_idx]] = cursor.payload(); - return acc; - }); - - std::vector essential_bigram_cursors; - std::transform(essential_bigrams.begin(), - essential_bigrams.end(), - std::back_inserter(essential_bigram_cursors), - [&](auto intersection) { - return index.scored_bigram_cursor(query.terms[intersection.first], - query.terms[intersection.second], - scorer); - }); - auto merged_bigrams = - v1::union_merge(std::move(essential_bigram_cursors), - initial_payload, - [&](auto& acc, auto& cursor, auto term_idx) { - auto payload = cursor.payload(); - acc[essential_bigrams[term_idx].first] = std::get<0>(payload); - acc[essential_bigrams[term_idx].second] = std::get<1>(payload); - return acc; - }); - - auto accumulate = [&](auto& acc, auto& cursor, auto /* union_idx */) { - auto payload = cursor.payload(); - for (auto idx = 0; idx < acc.size(); idx += 1) { - acc[idx] = payload[idx]; - } - return acc; - }; - auto merged = v1::variadic_union_merge( - initial_payload, - std::make_tuple(std::move(merged_unigrams), std::move(merged_bigrams)), - std::make_tuple(accumulate, accumulate)); - - std::vector lookup_cursors; - std::transform(query.terms.begin(), - query.terms.end(), - 
std::back_inserter(lookup_cursors), - [&](auto term) { return index.scored_cursor(term, scorer); }); - - v1::for_each(merged, [&](auto& cursor) { - auto docid = cursor.value(); - auto partial_scores = cursor.payload(); - float score = 0.0F; - for (auto idx = 0; idx < partial_scores.size(); idx += 1) { - if (partial_scores[idx] > 0.0F) { - score += partial_scores[idx]; - } else { - lookup_cursors[idx].advance_to_geq(docid); - if (lookup_cursors[idx].value() == docid) { - score += lookup_cursors[idx].payload(); - } - } - } - topk.insert(score, docid); - }); - return topk; -} - } // namespace pisa::v1 diff --git a/include/pisa/v1/score_index.hpp b/include/pisa/v1/score_index.hpp new file mode 100644 index 000000000..bf1eeace6 --- /dev/null +++ b/include/pisa/v1/score_index.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace pisa::v1 { + +void score_index(std::string const& yml, std::size_t threads); + +} // namespace pisa::v1 diff --git a/include/pisa/v1/union_lookup.hpp b/include/pisa/v1/union_lookup.hpp new file mode 100644 index 000000000..1c6468bb7 --- /dev/null +++ b/include/pisa/v1/union_lookup.hpp @@ -0,0 +1,151 @@ +#pragma once + +#include "v1/query.hpp" + +namespace pisa::v1 { + +/// Performs a "union-lookup" query (name pending). +/// +/// \param query Full query, as received, possibly with duplicates. +/// \param index Inverted index, with access to both unigrams and bigrams. +/// \param topk Top-k heap. +/// \param scorer An object capable of constructing term scorers. +/// \param essential_unigrams A list of essential single-term posting lists. +/// Elements of this vector point to the index of the term +/// in the query. In other words, for each position `i` in this vector, +/// `query.terms[essential_unigrams[i]]` is an essential unigram. +/// \param essential_bigrams Similar to the above, but represents intersections between two +/// posting lists. These must exist in the index, or else this +/// algorithm will fail. 
+template +auto union_lookup(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + std::vector essential_unigrams, + std::vector> essential_bigrams) +{ + ranges::sort(essential_unigrams); + ranges::actions::unique(essential_unigrams); + ranges::sort(essential_bigrams); + ranges::actions::unique(essential_bigrams); + + std::vector is_essential(query.terms.size(), false); + // std::cerr << "essential: "; + for (auto idx : essential_unigrams) { + // std::cerr << idx << ' '; + is_essential[idx] = true; + } + // std::cerr << '\n'; + + // std::vector initial_payload(query.terms.size(), 0.0); + + // std::vector essential_unigram_cursors; + // std::transform(essential_unigrams.begin(), + // essential_unigrams.end(), + // std::back_inserter(essential_unigram_cursors), + // [&](auto idx) { return index.scored_cursor(query.terms[idx], scorer); }); + // auto merged_unigrams = v1::union_merge( + // essential_unigram_cursors, initial_payload, [&](auto& acc, auto& cursor, auto term_idx) { + // acc[essential_unigrams[term_idx]] = cursor.payload(); + // return acc; + // }); + + std::vector essential_unigram_cursors; + std::transform(essential_unigrams.begin(), + essential_unigrams.end(), + std::back_inserter(essential_unigram_cursors), + [&](auto idx) { return index.scored_cursor(query.terms[idx], scorer); }); + // std::cerr << "No. 
essential: " << essential_unigram_cursors.size() << '\n'; + + auto merged_unigrams = v1::union_merge( + essential_unigram_cursors, 0.0F, [&](auto acc, auto& cursor, auto /*term_idx*/) { + // acc[essential_unigrams[term_idx]] = cursor.payload(); + return acc + cursor.payload(); + }); + + // std::vector essential_bigram_cursors; + // std::transform(essential_bigrams.begin(), + // essential_bigrams.end(), + // std::back_inserter(essential_bigram_cursors), + // [&](auto intersection) { + // return index.scored_bigram_cursor(query.terms[intersection.first], + // query.terms[intersection.second], + // scorer); + // }); + // auto merged_bigrams = + // v1::union_merge(std::move(essential_bigram_cursors), + // initial_payload, + // [&](auto& acc, auto& cursor, auto term_idx) { + // auto payload = cursor.payload(); + // acc[essential_bigrams[term_idx].first] = std::get<0>(payload); + // acc[essential_bigrams[term_idx].second] = std::get<1>(payload); + // return acc; + // }); + + // auto accumulate = [&](auto& acc, auto& cursor, auto /* union_idx */) { + // auto payload = cursor.payload(); + // for (auto idx = 0; idx < acc.size(); idx += 1) { + // acc[idx] = payload[idx]; + // } + // return acc; + //}; + // auto merged = v1::variadic_union_merge( + // initial_payload, + // std::make_tuple(std::move(merged_unigrams), std::move(merged_bigrams)), + // std::make_tuple(accumulate, accumulate)); + + std::vector lookup_cursors; + for (auto idx = 0; idx < query.terms.size(); idx += 1) { + if (not is_essential[idx]) { + lookup_cursors.push_back(index.max_scored_cursor(query.terms[idx], scorer)); + } + } + // std::transform(query.terms.begin(), + // query.terms.end(), + // std::back_inserter(lookup_cursors), + // [&](auto term) { return index.scored_cursor(term, scorer); }); + + // v1::for_each(merged, [&](auto& cursor) { + v1::for_each(merged_unigrams, [&](auto& cursor) { + auto docid = cursor.value(); + auto score = cursor.payload(); + auto score_bound = std::accumulate( + 
lookup_cursors.begin(), lookup_cursors.end(), score, [](auto acc, auto&& cursor) { + return acc + cursor.max_score(); + }); + if (not topk.would_enter(score_bound)) { + return; + } + for (auto lookup_cursor : lookup_cursors) { + // lookup_cursor.advance(); + lookup_cursor.advance_to_geq(docid); + if (PISA_UNLIKELY(lookup_cursor.value() == docid)) { + score += lookup_cursor.payload(); + } + if (not topk.would_enter(score - lookup_cursor.max_score())) { + return; + } + } + topk.insert(score, docid); + // auto docid = cursor.value(); + // auto partial_scores = cursor.payload(); + // float score = 0.0F; + // for (auto idx = 0; idx < partial_scores.size(); idx += 1) { + // score += partial_scores[idx]; + // // if (partial_scores[idx] > 0.0F) { + // // score += partial_scores[idx]; + // //} + // // else if (not is_essential[idx]) { + // // lookup_cursors[idx].advance_to_geq(docid); + // // if (lookup_cursors[idx].value() == docid) { + // // score += lookup_cursors[idx].payload(); + // // } + // //} + //} + // topk.insert(score, docid); + }); + return topk; +} + +} // namespace pisa::v1 diff --git a/src/v1/index_metadata.cpp b/src/v1/index_metadata.cpp index 754bc9b04..c835c2f0d 100644 --- a/src/v1/index_metadata.cpp +++ b/src/v1/index_metadata.cpp @@ -17,6 +17,8 @@ constexpr char const* STATS = "stats"; constexpr char const* LEXICON = "lexicon"; constexpr char const* TERMS = "terms"; constexpr char const* BIGRAM = "bigram"; +constexpr char const* MAX_SCORES = "max_scores"; +constexpr char const* QUANTIZED_MAX_SCORES = "quantized_max_scores"; [[nodiscard]] auto resolve_yml(std::optional const& arg) -> std::string { @@ -64,18 +66,32 @@ constexpr char const* BIGRAM = "bigram"; .bigrams = [&]() -> tl::optional { if (config[BIGRAM]) { return BigramMetadata{ - .documents = {.postings = config[DOCUMENTS][POSTINGS].as(), - .offsets = config[DOCUMENTS][OFFSETS].as()}, + .documents = {.postings = config[BIGRAM][DOCUMENTS][POSTINGS].as(), + .offsets = 
config[BIGRAM][DOCUMENTS][OFFSETS].as()}, .frequencies = - {{.postings = config["frequencies_0"][POSTINGS].as(), - .offsets = config["frequencies_0"][OFFSETS].as()}, - {.postings = config["frequencies_1"][POSTINGS].as(), - .offsets = config["frequencies_1"][OFFSETS].as()}}, + {{.postings = config[BIGRAM]["frequencies_0"][POSTINGS].as(), + .offsets = config[BIGRAM]["frequencies_0"][OFFSETS].as()}, + {.postings = config[BIGRAM]["frequencies_1"][POSTINGS].as(), + .offsets = config[BIGRAM]["frequencies_1"][OFFSETS].as()}}, .mapping = config[BIGRAM]["mapping"].as(), .count = config[BIGRAM]["count"].as()}; } return tl::nullopt; - }()}; + }(), + .max_scores = + [&]() { + if (config[MAX_SCORES]) { + return config[MAX_SCORES].as>(); + } + return std::map{}; + }(), + .quantized_max_scores = + [&]() { + if (config[QUANTIZED_MAX_SCORES]) { + return config[QUANTIZED_MAX_SCORES].as>(); + } + return std::map{}; + }()}; } void IndexMetadata::write(std::string const& file) @@ -108,6 +124,16 @@ void IndexMetadata::write(std::string const& file) root[BIGRAM]["mapping"] = bigrams->mapping; root[BIGRAM]["count"] = bigrams->count; } + if (not max_scores.empty()) { + for (auto [key, value] : max_scores) { + root[MAX_SCORES][key] = value; + } + } + if (not quantized_max_scores.empty()) { + for (auto [key, value] : quantized_max_scores) { + root[QUANTIZED_MAX_SCORES][key] = value; + } + } std::ofstream fout(file); fout << root; } diff --git a/src/v1/score_index.cpp b/src/v1/score_index.cpp new file mode 100644 index 000000000..aee096191 --- /dev/null +++ b/src/v1/score_index.cpp @@ -0,0 +1,93 @@ +#include + +#include "codec/simdbp.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/progress_status.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/score_index.hpp" + +using pisa::v1::BlockedReader; +using pisa::v1::DefaultProgress; +using pisa::v1::IndexMetadata; +using pisa::v1::PostingFilePaths; +using 
pisa::v1::ProgressStatus; +using pisa::v1::RawReader; +using pisa::v1::RawWriter; +using pisa::v1::TermId; +using pisa::v1::write_span; + +namespace pisa::v1 { + +void score_index(std::string const& yml, std::size_t threads) +{ + auto meta = IndexMetadata::from_file(yml); + auto run = index_runner(meta, + RawReader{}, + BlockedReader<::pisa::simdbp_block, true>{}, + BlockedReader<::pisa::simdbp_block, false>{}); + auto index_basename = yml.substr(0, yml.size() - 4); + auto postings_path = fmt::format("{}.bm25", index_basename); + auto offsets_path = fmt::format("{}.bm25_offsets", index_basename); + auto max_scores_path = fmt::format("{}.bm25.maxf", index_basename); + auto quantized_max_scores_path = fmt::format("{}.bm25.maxq", index_basename); + run([&](auto&& index) { + ProgressStatus calc_max_status(index.num_terms(), + DefaultProgress("Calculating max partial score"), + std::chrono::milliseconds(100)); + std::vector max_scores(index.num_terms(), 0.0F); + tbb::task_group group; + auto batch_size = index.num_terms() / threads; + for (auto thread_id = 0; thread_id < threads; thread_id += 1) { + auto first_term = thread_id * batch_size; + auto end_term = + thread_id < threads - 1 ? (thread_id + 1) * batch_size : index.num_terms(); + std::for_each(boost::counting_iterator(first_term), + boost::counting_iterator(end_term), + [&](auto term) { + for_each( + index.scoring_cursor(term, make_bm25(index)), [&](auto& cursor) { + if (auto score = cursor.payload(); max_scores[term] < score) { + max_scores[term] = score; + } + }); + calc_max_status += 1; + }); + } + group.wait(); + auto max_score = *std::max_element(max_scores.begin(), max_scores.end()); + std::cerr << fmt::format("Max partial score is: {}. 
It will be scaled to {}.", + max_score, + std::numeric_limits::max()); + + auto quantizer = [&](float score) { + return static_cast(score * std::numeric_limits::max() + / max_score); + }; + std::vector quantized_max_scores; + std::transform(max_scores.begin(), + max_scores.end(), + std::back_inserter(quantized_max_scores), + quantizer); + + ProgressStatus status( + index.num_terms(), DefaultProgress("Scoring"), std::chrono::milliseconds(100)); + std::ofstream score_file_stream(postings_path); + auto offsets = score_index(index, + score_file_stream, + RawWriter{}, + make_bm25(index), + quantizer, + [&]() { status += 1; }); + write_span(gsl::span(offsets), offsets_path); + write_span(gsl::span(max_scores), max_scores_path); + write_span(gsl::span(quantized_max_scores), quantized_max_scores_path); + }); + meta.scores.push_back(PostingFilePaths{.postings = postings_path, .offsets = offsets_path}); + meta.max_scores["bm25"] = max_scores_path; + meta.quantized_max_scores["bm25"] = quantized_max_scores_path; + meta.write(yml); +} + +} // namespace pisa::v1 diff --git a/test/v1/CMakeLists.txt b/test/v1/CMakeLists.txt new file mode 100644 index 000000000..663367684 --- /dev/null +++ b/test/v1/CMakeLists.txt @@ -0,0 +1,27 @@ +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/test) + +list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../external/Catch2/contrib) +include(CTest) +include(Catch) + +file(GLOB TEST_SOURCES test_*.cpp) +foreach(TEST_SRC ${TEST_SOURCES}) + get_filename_component (TEST_SRC_NAME ${TEST_SRC} NAME_WE) + add_executable(${TEST_SRC_NAME} ${TEST_SRC}) + target_link_libraries(${TEST_SRC_NAME} + pisa + Catch2 + ) + catch_discover_tests(${TEST_SRC_NAME} TEST_PREFIX "${TEST_SRC_NAME}:") + + # enable code coverage + add_coverage(${TEST_SRC_NAME}) +endforeach(TEST_SRC) + +target_link_libraries(test_v1 + pisa + Catch2 + optional +) + +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/../test_data DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) diff --git 
a/test/v1/index_fixture.hpp b/test/v1/index_fixture.hpp new file mode 100644 index 000000000..89125c2eb --- /dev/null +++ b/test/v1/index_fixture.hpp @@ -0,0 +1,64 @@ +#pragma once + +#include "../temporary_directory.hpp" +#include "pisa_config.hpp" +#include "query/queries.hpp" +#include "v1/index.hpp" +#include "v1/index_builder.hpp" +#include "v1/query.hpp" +#include "v1/score_index.hpp" + +namespace v1 = pisa::v1; + +[[nodiscard]] inline auto test_queries() -> std::vector +{ + std::vector queries; + std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); + auto push_query = [&](std::string const& query_line) { + queries.push_back(pisa::parse_query_ids(query_line)); + }; + pisa::io::for_each_line(qfile, push_query); + return queries; +} + +template +struct IndexFixture { + using DocumentWriter = typename v1::CursorTraits::Writer; + using FrequencyWriter = typename v1::CursorTraits::Writer; + using ScoreWriter = typename v1::CursorTraits::Writer; + + using DocumentReader = typename v1::CursorTraits::Reader; + using FrequencyReader = typename v1::CursorTraits::Reader; + using ScoreReader = typename v1::CursorTraits::Reader; + + IndexFixture() : m_tmpdir(std::make_unique()) + { + auto index_basename = (tmpdir().path() / "inv").string(); + v1::compress_binary_collection(PISA_SOURCE_DIR "/test/test_data/test_collection", + PISA_SOURCE_DIR "/test/test_data/test_collection.fwd", + index_basename, + 2, + v1::make_writer(), + v1::make_writer()); + auto errors = v1::verify_compressed_index(PISA_SOURCE_DIR "/test/test_data/test_collection", + index_basename); + REQUIRE(errors.empty()); + v1::score_index(fmt::format("{}.yml", index_basename), 1); + } + + [[nodiscard]] auto const& tmpdir() const { return *m_tmpdir; } + [[nodiscard]] auto document_reader() const { return m_document_reader; } + [[nodiscard]] auto frequency_reader() const { return m_frequency_reader; } + [[nodiscard]] auto score_reader() const { return m_score_reader; } + [[nodiscard]] auto meta() const 
+ { + auto index_basename = (tmpdir().path() / "inv").string(); + return v1::IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); + } + + private: + std::unique_ptr m_tmpdir; + DocumentReader m_document_reader{}; + FrequencyReader m_frequency_reader{}; + ScoreReader m_score_reader{}; +}; diff --git a/test/test_v1.cpp b/test/v1/test_v1.cpp similarity index 77% rename from test/test_v1.cpp rename to test/v1/test_v1.cpp index 22e9bb87d..40c89643a 100644 --- a/test/test_v1.cpp +++ b/test/v1/test_v1.cpp @@ -61,93 +61,6 @@ TEST_CASE("RawReader", "[v1][unit]") REQUIRE(next(cursor) == tl::nullopt); } -TEST_CASE("Binary collection index", "[.][v1][unit]") -{ - pisa::binary_freq_collection collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); - auto index = - pisa::v1::binary_collection_index(PISA_SOURCE_DIR "/test/test_data/test_collection"); - auto term_id = 0; - for (auto sequence : collection) { - CAPTURE(term_id); - REQUIRE(std::vector(sequence.docs.begin(), sequence.docs.end()) - == collect(index.documents(term_id))); - REQUIRE(std::vector(sequence.freqs.begin(), sequence.freqs.end()) - == collect(index.payloads(term_id))); - term_id += 1; - } - term_id = 0; - for (auto sequence : collection) { - CAPTURE(term_id); - REQUIRE(std::vector(sequence.docs.begin(), sequence.docs.end()) - == collect(index.cursor(term_id))); - REQUIRE(std::vector(sequence.freqs.begin(), sequence.freqs.end()) - == collect(index.cursor(term_id), [](auto&& cursor) { return cursor.payload(); })); - term_id += 1; - } -} - -TEST_CASE("Bigram collection index", "[.][v1][unit]") -{ - auto intersect = [](auto const& lhs, - auto const& rhs) -> std::vector> { - std::vector> intersection; - auto left = lhs.begin(); - auto right = rhs.begin(); - while (left != lhs.end() && right != rhs.end()) { - if (left->first == right->first) { - intersection.emplace_back(left->first, left->second, right->second); - ++right; - ++left; - } else if (left->first < right->first) { - ++left; - } else { - 
++right; - } - } - return intersection; - }; - auto to_vec = [](auto const& seq) { - std::vector> vec; - std::transform(seq.docs.begin(), - seq.docs.end(), - seq.freqs.begin(), - std::back_inserter(vec), - [](auto doc, auto freq) { return std::make_pair(doc, freq); }); - return vec; - }; - - pisa::binary_freq_collection collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); - auto index = - pisa::v1::binary_collection_bigram_index(PISA_SOURCE_DIR "/test/test_data/test_collection"); - - auto pos = collection.begin(); - auto prev = to_vec(*pos); - ++pos; - TermId term_id = 1; - for (; pos != collection.end(); ++pos, ++term_id) { - CAPTURE(term_id); - auto current = to_vec(*pos); - auto intersection = intersect(prev, current); - if (not intersection.empty()) { - auto id = index.bigram_id(term_id - 1, term_id); - REQUIRE(id.has_value()); - auto postings = collect(index.cursor(*id), [](auto& cursor) { - auto freqs = cursor.payload(); - return std::make_tuple(*cursor, freqs[0], freqs[1]); - }); - for (auto idx = 0; idx < 10; idx++) { - std::cout << std::get<1>(postings[idx]) << " " << std::get<1>(intersection[idx]) - << '\n'; - std::cout << std::get<2>(postings[idx]) << " " << std::get<2>(intersection[idx]) - << "\n---\n"; - } - REQUIRE(postings == intersection); - } - std::swap(prev, current); - break; - } -} - TEST_CASE("Test read header", "[v1][unit]") { { @@ -290,6 +203,8 @@ TEST_CASE("Build raw document-frequency index", "[v1][unit]") {}, document_sizes, tl::nullopt, + {}, + {}, tl::nullopt, std::move(source), RawReader{}, @@ -331,6 +246,8 @@ TEST_CASE("Build raw document-frequency index", "[v1][unit]") {}, document_sizes, tl::nullopt, + {}, + {}, tl::nullopt, std::move(source), RawReader{}); // Correct encoding but not type! 
diff --git a/test/test_v1_blocked_cursor.cpp b/test/v1/test_v1_blocked_cursor.cpp similarity index 97% rename from test/test_v1_blocked_cursor.cpp rename to test/v1/test_v1_blocked_cursor.cpp index fa40ef9d5..0dca9ca64 100644 --- a/test/test_v1_blocked_cursor.cpp +++ b/test/v1/test_v1_blocked_cursor.cpp @@ -7,9 +7,9 @@ #include #include +#include "../temporary_directory.hpp" #include "codec/simdbp.hpp" #include "pisa_config.hpp" -#include "temporary_directory.hpp" #include "v1/blocked_cursor.hpp" #include "v1/cursor/collect.hpp" #include "v1/index.hpp" @@ -133,6 +133,8 @@ TEST_CASE("Build blocked document-frequency index", "[v1][unit]") {}, document_sizes, tl::nullopt, + {}, + {}, tl::nullopt, std::move(source), BlockedReader{}, @@ -188,6 +190,8 @@ TEST_CASE("Build blocked document-frequency index", "[v1][unit]") {}, document_sizes, tl::nullopt, + {}, + {}, tl::nullopt, std::move(source), RawReader{}); // Correct encoding but not type! diff --git a/test/test_v1_document_payload_cursor.cpp b/test/v1/test_v1_document_payload_cursor.cpp similarity index 97% rename from test/test_v1_document_payload_cursor.cpp rename to test/v1/test_v1_document_payload_cursor.cpp index cc71baae3..dcd4d8d9a 100644 --- a/test/test_v1_document_payload_cursor.cpp +++ b/test/v1/test_v1_document_payload_cursor.cpp @@ -7,9 +7,9 @@ #include #include +#include "../temporary_directory.hpp" #include "codec/simdbp.hpp" #include "pisa_config.hpp" -#include "temporary_directory.hpp" #include "v1/blocked_cursor.hpp" #include "v1/cursor/collect.hpp" #include "v1/index.hpp" @@ -49,7 +49,7 @@ TEST_CASE("Document-payload cursor", "[v1][unit]") { std::vector collected_documents; std::vector collected_frequencies; - for_each(cursor, [&](auto &&cursor) { + for_each(cursor, [&](auto&& cursor) { collected_documents.push_back(cursor.value()); collected_frequencies.push_back(cursor.payload()); }); diff --git a/test/test_v1_index.cpp b/test/v1/test_v1_index.cpp similarity index 70% rename from 
test/test_v1_index.cpp rename to test/v1/test_v1_index.cpp index c0cdede87..cedb6cd1c 100644 --- a/test/test_v1_index.cpp +++ b/test/v1/test_v1_index.cpp @@ -7,9 +7,10 @@ #include #include +#include "../temporary_directory.hpp" +#include "binary_collection.hpp" #include "codec/simdbp.hpp" #include "pisa_config.hpp" -#include "temporary_directory.hpp" #include "v1/blocked_cursor.hpp" #include "v1/cursor/collect.hpp" #include "v1/index.hpp" @@ -17,7 +18,7 @@ #include "v1/index_metadata.hpp" #include "v1/types.hpp" -using pisa::v1::binary_collection_index; +using pisa::binary_freq_collection; using pisa::v1::BlockedReader; using pisa::v1::BlockedWriter; using pisa::v1::compress_binary_collection; @@ -33,7 +34,7 @@ TEST_CASE("Binary collection index", "[v1][unit]") { tbb::task_scheduler_init init(8); Temporary_Directory tmpdir; - auto bci = binary_collection_index(PISA_SOURCE_DIR "/test/test_data/test_collection"); + auto bci = binary_freq_collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); compress_binary_collection(PISA_SOURCE_DIR "/test/test_data/test_collection", (tmpdir.path() / "fwd").string(), (tmpdir.path() / "index").string(), @@ -51,15 +52,15 @@ TEST_CASE("Binary collection index", "[v1][unit]") BlockedReader{}, BlockedReader{}); run([&](auto index) { - REQUIRE(bci.num_documents() == index.num_documents()); - REQUIRE(bci.num_terms() == index.num_terms()); - REQUIRE(bci.avg_document_length() == index.avg_document_length()); - for (auto doc = 0; doc < bci.num_documents(); doc += 1) { - REQUIRE(bci.document_length(doc) == index.document_length(doc)); - } - for (auto term = 0; term < bci.num_terms(); term += 1) { - REQUIRE(collect(bci.documents(term)) == collect(index.documents(term))); - REQUIRE(collect(bci.payloads(term)) == collect(index.payloads(term))); + REQUIRE(bci.num_docs() == index.num_documents()); + REQUIRE(bci.size() == index.num_terms()); + auto bci_iter = bci.begin(); + for (auto term = 0; term < 1'000; term += 1) { + 
REQUIRE(std::vector(bci_iter->docs.begin(), bci_iter->docs.end()) + == collect(index.documents(term))); + REQUIRE(std::vector(bci_iter->freqs.begin(), bci_iter->freqs.end()) + == collect(index.payloads(term))); + ++bci_iter; } }); } @@ -68,7 +69,7 @@ TEST_CASE("Binary collection index -- SIMDBP", "[v1][unit]") { tbb::task_scheduler_init init(8); Temporary_Directory tmpdir; - auto bci = binary_collection_index(PISA_SOURCE_DIR "/test/test_data/test_collection"); + auto bci = binary_freq_collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); compress_binary_collection(PISA_SOURCE_DIR "/test/test_data/test_collection", (tmpdir.path() / "fwd").string(), (tmpdir.path() / "index").string(), @@ -86,15 +87,15 @@ TEST_CASE("Binary collection index -- SIMDBP", "[v1][unit]") BlockedReader{}, BlockedReader{}); run([&](auto index) { - REQUIRE(bci.num_documents() == index.num_documents()); - REQUIRE(bci.num_terms() == index.num_terms()); - REQUIRE(bci.avg_document_length() == index.avg_document_length()); - for (auto doc = 0; doc < bci.num_documents(); doc += 1) { - REQUIRE(bci.document_length(doc) == index.document_length(doc)); - } - for (auto term = 0; term < bci.num_terms(); term += 1) { - REQUIRE(collect(bci.documents(term)) == collect(index.documents(term))); - REQUIRE(collect(bci.payloads(term)) == collect(index.payloads(term))); + REQUIRE(bci.num_docs() == index.num_documents()); + REQUIRE(bci.size() == index.num_terms()); + auto bci_iter = bci.begin(); + for (auto term = 0; term < 1'000; term += 1) { + REQUIRE(std::vector(bci_iter->docs.begin(), bci_iter->docs.end()) + == collect(index.documents(term))); + REQUIRE(std::vector(bci_iter->freqs.begin(), bci_iter->freqs.end()) + == collect(index.payloads(term))); + ++bci_iter; } }); } diff --git a/test/v1/test_v1_maxscore_join.cpp b/test/v1/test_v1_maxscore_join.cpp new file mode 100644 index 000000000..072d23b7e --- /dev/null +++ b/test/v1/test_v1_maxscore_join.cpp @@ -0,0 +1,132 @@ +#define CATCH_CONFIG_MAIN 
+#include "catch2/catch.hpp" + +#include + +#include +#include +#include + +#include "../temporary_directory.hpp" +#include "codec/simdbp.hpp" +#include "index_fixture.hpp" +#include "pisa_config.hpp" +#include "topk_queue.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/cursor_accumulator.hpp" +#include "v1/index.hpp" +#include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/io.hpp" +#include "v1/maxscore.hpp" +#include "v1/posting_builder.hpp" +#include "v1/types.hpp" + +using pisa::v1::BlockedReader; +using pisa::v1::BlockedWriter; +using pisa::v1::collect; +using pisa::v1::DocId; +using pisa::v1::Frequency; +using pisa::v1::IndexRunner; +using pisa::v1::join_maxscore; +using pisa::v1::PostingBuilder; +using pisa::v1::RawReader; +using pisa::v1::read_sizes; +using pisa::v1::TermId; +using pisa::v1::accumulate::Add; + +TEMPLATE_TEST_CASE("", + "[v1][integration]", + (IndexFixture, + v1::RawCursor, + v1::RawCursor>)) +//(IndexFixture, +// v1::BlockedCursor<::pisa::simdbp_block, false>, +// v1::RawCursor>)) +{ + tbb::task_scheduler_init init(1); + TestType fixture; + + SECTION("Zero threshold -- equivalent to union") + { + auto index_basename = (fixture.tmpdir().path() / "inv").string(); + auto meta = v1::IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); + // auto run_query = [](std::string const& name, auto query, auto&& index, auto scorer) { + // if (name == "daat_or") { + // return daat_or(query, index, topk_queue(10), scorer); + // } + // if (name == "maxscore") { + // return maxscore(query, index, topk_queue(10), scorer); + // } + // std::abort(); + //}; + int idx = 0; + for (auto& q : test_queries()) { + CAPTURE(q.terms); + CAPTURE(idx++); + + auto run = + v1::index_runner(meta, fixture.document_reader(), fixture.frequency_reader()); + run([&](auto&& index) { + auto union_results = collect(v1::union_merge( + index.scored_cursors(gsl::make_span(q.terms), make_bm25(index)), 0.0F, 
Add{})); + auto maxscore_results = collect(v1::join_maxscore( + index.max_scored_cursors(gsl::make_span(q.terms), make_bm25(index)), + 0.0F, + Add{}, + [](auto /* score */) { return true; })); + REQUIRE(union_results == maxscore_results); + }); + + run([&](auto&& index) { + auto union_results = collect_with_payload(v1::union_merge( + index.scored_cursors(gsl::make_span(q.terms), make_bm25(index)), 0.0F, Add{})); + union_results.erase(std::remove_if(union_results.begin(), + union_results.end(), + [](auto score) { return score.second <= 5.0F; }), + union_results.end()); + auto maxscore_results = collect_with_payload(v1::join_maxscore( + index.max_scored_cursors(gsl::make_span(q.terms), make_bm25(index)), + 0.0F, + Add{}, + [](auto score) { return score > 5.0F; })); + REQUIRE(union_results.size() == maxscore_results.size()); + for (size_t i = 0; i < union_results.size(); ++i) { + CAPTURE(i); + REQUIRE(union_results[i].first == union_results[i].first); + REQUIRE(union_results[i].second + == Approx(union_results[i].second).epsilon(0.01)); + // REQUIRE(precomputed[i].second == expected[i].second); + // REQUIRE(precomputed[i].first == + // Approx(expected[i].first).epsilon(RELATIVE_ERROR)); + } + }); + + // // auto precomputed = [&]() { + // // auto run = + // // v1::scored_index_runner(meta, fixture.document_reader(), + // // fixture.score_reader()); + // // std::vector results; + // // run([&](auto&& index) { + // // // auto que = run_query(algorithm, v1::Query{q.terms}, index, + // // v1::VoidScorer{}); auto que = daat_or(v1::Query{q.terms}, index, + // // topk_queue(10),v1::VoidScorer{}); que.finalize(); results = que.topk(); + // // std::sort(results.begin(), results.end(), std::greater{}); + // // }); + // // return results; + // // }(); + + // REQUIRE(expected.size() == on_the_fly.size()); + // // REQUIRE(expected.size() == precomputed.size()); + // for (size_t i = 0; i < on_the_fly.size(); ++i) { + // REQUIRE(on_the_fly[i].second == expected[i].second); + // 
REQUIRE(on_the_fly[i].first == + // Approx(expected[i].first).epsilon(RELATIVE_ERROR)); + // // REQUIRE(precomputed[i].second == expected[i].second); + // // REQUIRE(precomputed[i].first == + // // Approx(expected[i].first).epsilon(RELATIVE_ERROR)); + // } + } + } +} diff --git a/test/test_v1_queries.cpp b/test/v1/test_v1_queries.cpp similarity index 57% rename from test/test_v1_queries.cpp rename to test/v1/test_v1_queries.cpp index acfc74e4f..de84365a3 100644 --- a/test/test_v1_queries.cpp +++ b/test/v1/test_v1_queries.cpp @@ -7,15 +7,16 @@ #include #include +#include "../temporary_directory.hpp" #include "accumulator/lazy_accumulator.hpp" #include "cursor/block_max_scored_cursor.hpp" #include "cursor/max_scored_cursor.hpp" #include "cursor/scored_cursor.hpp" +#include "index_fixture.hpp" #include "index_types.hpp" #include "io.hpp" #include "pisa_config.hpp" #include "query/queries.hpp" -#include "temporary_directory.hpp" #include "v1/blocked_cursor.hpp" #include "v1/cursor/collect.hpp" #include "v1/cursor_intersection.hpp" @@ -23,76 +24,20 @@ #include "v1/cursor_union.hpp" #include "v1/index.hpp" #include "v1/index_builder.hpp" +#include "v1/maxscore.hpp" #include "v1/posting_builder.hpp" #include "v1/posting_format_header.hpp" #include "v1/query.hpp" +#include "v1/score_index.hpp" #include "v1/scorer/bm25.hpp" #include "v1/types.hpp" +#include "v1/union_lookup.hpp" namespace v1 = pisa::v1; using namespace pisa; static constexpr auto RELATIVE_ERROR = 0.1F; -template -struct IndexFixture { - using DocumentWriter = typename v1::CursorTraits::Writer; - using FrequencyWriter = typename v1::CursorTraits::Writer; - using ScoreWriter = typename v1::CursorTraits::Writer; - - using DocumentReader = typename v1::CursorTraits::Reader; - using FrequencyReader = typename v1::CursorTraits::Reader; - using ScoreReader = typename v1::CursorTraits::Reader; - - IndexFixture() : m_tmpdir(std::make_unique()) - { - auto index_basename = (tmpdir().path() / "inv").string(); - 
v1::compress_binary_collection(PISA_SOURCE_DIR "/test/test_data/test_collection", - PISA_SOURCE_DIR "/test/test_data/test_collection.fwd", - index_basename, - 2, - v1::make_writer(), - v1::make_writer()); - auto errors = v1::verify_compressed_index(PISA_SOURCE_DIR "/test/test_data/test_collection", - index_basename); - REQUIRE(errors.empty()); - auto meta = v1::IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); - auto run = v1::index_runner(meta, document_reader(), frequency_reader()); - auto postings_path = fmt::format("{}.bm25", index_basename); - auto offsets_path = fmt::format("{}.bm25_offsets", index_basename); - run([&](auto&& index) { - std::ofstream score_file_stream(postings_path); - auto offsets = score_index(index, score_file_stream, ScoreWriter{}, make_bm25(index)); - v1::write_span(gsl::span(offsets), offsets_path); - }); - meta.scores.push_back( - v1::PostingFilePaths{.postings = postings_path, .offsets = offsets_path}); - meta.write(fmt::format("{}.yml", index_basename)); - } - - [[nodiscard]] auto const& tmpdir() const { return *m_tmpdir; } - [[nodiscard]] auto document_reader() const { return m_document_reader; } - [[nodiscard]] auto frequency_reader() const { return m_frequency_reader; } - [[nodiscard]] auto score_reader() const { return m_score_reader; } - - private: - std::unique_ptr m_tmpdir; - DocumentReader m_document_reader{}; - FrequencyReader m_frequency_reader{}; - ScoreReader m_score_reader{}; -}; - -[[nodiscard]] auto test_queries() -> std::vector -{ - std::vector queries; - std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); - auto push_query = [&](std::string const& query_line) { - queries.push_back(parse_query_ids(query_line)); - }; - io::for_each_line(qfile, push_query); - return queries; -} - template struct IndexData { @@ -101,10 +46,6 @@ struct IndexData { IndexData() : collection(PISA_SOURCE_DIR "/test/test_data/test_collection"), document_sizes(PISA_SOURCE_DIR "/test/test_data/test_collection.sizes"), 
- v1_index( - pisa::v1::binary_collection_index(PISA_SOURCE_DIR "/test/test_data/test_collection")), - scored_index(pisa::v1::binary_collection_scored_index(PISA_SOURCE_DIR - "/test/test_data/test_collection")), wdata(document_sizes.begin()->begin(), collection.num_docs(), collection, @@ -146,86 +87,159 @@ struct IndexData { binary_freq_collection collection; binary_collection document_sizes; v0_Index v0_index; - v1_Index v1_index; - ScoredIndex scored_index; std::vector queries; std::vector thresholds; wand_data wdata; }; +/// Inefficient, do not use in production code. +[[nodiscard]] auto approximate_order(std::pair lhs, + std::pair rhs) -> bool +{ + return std::make_pair(fmt::format("{:0.4f}", lhs.first), lhs.second) + < std::make_pair(fmt::format("{:0.4f}", rhs.first), rhs.second); +} + +/// Inefficient, do not use in production code. +[[nodiscard]] auto approximate_order_f(float lhs, float rhs) -> bool +{ + return fmt::format("{:0.4f}", lhs) < fmt::format("{:0.4f}", rhs); +} + template std::unique_ptr> IndexData::data = nullptr; -TEMPLATE_TEST_CASE( - "DAAT OR", - "[v1][integration]", - (IndexFixture, v1::RawCursor, v1::RawCursor>), - (IndexFixture, - v1::BlockedCursor<::pisa::simdbp_block, false>, - v1::RawCursor>)) +TEMPLATE_TEST_CASE("Query", + "[v1][integration]", + (IndexFixture, + v1::RawCursor, + v1::RawCursor>), + (IndexFixture, + v1::BlockedCursor<::pisa::simdbp_block, false>, + v1::RawCursor>)) { tbb::task_scheduler_init init(1); auto data = IndexData, v1::RawCursor>, v1::Index, v1::RawCursor>>::get(); TestType fixture; + auto algorithm = GENERATE(std::string("daat_or"), std::string("maxscore")); + CAPTURE(algorithm); auto index_basename = (fixture.tmpdir().path() / "inv").string(); auto meta = v1::IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); ranked_or_query or_q(10); + auto run_query = [](std::string const& name, auto query, auto&& index, auto scorer) { + if (name == "daat_or") { + return daat_or(query, index, topk_queue(10), 
scorer); + } + if (name == "maxscore") { + return maxscore(query, index, topk_queue(10), scorer); + } + std::abort(); + }; int idx = 0; - for (auto const& q : test_queries()) { + for (auto& q : test_queries()) { + std::sort(q.terms.begin(), q.terms.end()); + q.terms.erase(std::unique(q.terms.begin(), q.terms.end()), q.terms.end()); CAPTURE(q.terms); CAPTURE(idx++); or_q(make_scored_cursors(data->v0_index, data->wdata, q), data->v0_index.num_docs()); auto expected = or_q.topk(); - std::sort(expected.begin(), expected.end(), std::greater{}); + std::sort(expected.begin(), expected.end(), approximate_order); auto on_the_fly = [&]() { auto run = v1::index_runner(meta, fixture.document_reader(), fixture.frequency_reader()); std::vector results; run([&](auto&& index) { - auto que = daat_or(v1::Query{q.terms}, index, topk_queue(10), make_bm25(index)); + auto que = run_query(algorithm, v1::Query{q.terms}, index, make_bm25(index)); que.finalize(); results = que.topk(); - std::sort(results.begin(), results.end(), std::greater{}); + std::sort(results.begin(), results.end(), approximate_order); }); return results; }(); + REQUIRE(expected.size() == on_the_fly.size()); + for (size_t i = 0; i < on_the_fly.size(); ++i) { + REQUIRE(on_the_fly[i].second == expected[i].second); + REQUIRE(on_the_fly[i].first == Approx(expected[i].first).epsilon(RELATIVE_ERROR)); + } + auto precomputed = [&]() { auto run = v1::scored_index_runner(meta, fixture.document_reader(), fixture.score_reader()); std::vector results; run([&](auto&& index) { - auto que = daat_or(v1::Query{q.terms}, index, topk_queue(10), v1::VoidScorer{}); + auto que = run_query(algorithm, v1::Query{q.terms}, index, v1::VoidScorer{}); que.finalize(); results = que.topk(); - std::sort(results.begin(), results.end(), std::greater{}); }); + // Remove the tail that might be different due to quantization error. 
+ // Note that `precomputed` will have summed quantized score, while the + // vector we compare to will have quantized sum---that's why whe remove anything + // that's withing 2 of the last result. + // auto last_score = results.back().first; + // results.erase(std::remove_if( + // results.begin(), + // results.end(), + // [last_score](auto&& entry) { return entry.first <= last_score + 3; + // }), + // results.end()); + // results.resize(5); + // std::sort(results.begin(), results.end(), [](auto&& lhs, auto&& rhs) { + // return lhs.second < rhs.second; + //}); return results; }(); - REQUIRE(expected.size() == on_the_fly.size()); - REQUIRE(expected.size() == precomputed.size()); - for (size_t i = 0; i < on_the_fly.size(); ++i) { - REQUIRE(on_the_fly[i].second == expected[i].second); - REQUIRE(on_the_fly[i].first == Approx(expected[i].first).epsilon(RELATIVE_ERROR)); - REQUIRE(precomputed[i].second == expected[i].second); - REQUIRE(precomputed[i].first == Approx(expected[i].first).epsilon(RELATIVE_ERROR)); + constexpr float max_partial_score = 16.5724F; + auto quantizer = [&](float score) { + return static_cast(score * std::numeric_limits::max() + / max_partial_score); + }; + + auto expected_quantized = expected; + std::sort(expected_quantized.begin(), expected_quantized.end(), [](auto&& lhs, auto&& rhs) { + return lhs.first > rhs.first; + }); + for (auto& v : expected_quantized) { + v.first = quantizer(v.first); } + + // TODO(michal): test the quantized results + + // expected_quantized.resize(precomputed.size()); + // std::sort(expected_quantized.begin(), expected_quantized.end(), [](auto&& lhs, auto&& + // rhs) { + // return lhs.second < rhs.second; + //}); + + // for (size_t i = 0; i < precomputed.size(); ++i) { + // std::cerr << fmt::format("{}, {:f} -- {}, {:f}\n", + // precomputed[i].second, + // precomputed[i].first, + // expected_quantized[i].second, + // expected_quantized[i].first); + //} + + // for (size_t i = 0; i < precomputed.size(); ++i) { + // 
REQUIRE(std::abs(precomputed[i].first - expected_quantized[i].first) + // <= static_cast(q.terms.size())); + //} } } -TEMPLATE_TEST_CASE( - "UnionLookup", - "[v1][integration]", - (IndexFixture, v1::RawCursor, v1::RawCursor>), - (IndexFixture, - v1::BlockedCursor<::pisa::simdbp_block, false>, - v1::RawCursor>)) +TEMPLATE_TEST_CASE("UnionLookup", + "[v1][integration]", + (IndexFixture, + v1::RawCursor, + v1::RawCursor>), + (IndexFixture, + v1::BlockedCursor<::pisa::simdbp_block, false>, + v1::RawCursor>)) { tbb::task_scheduler_init init(1); auto data = IndexData unigrams(q.terms.size()); std::iota(unigrams.begin(), unigrams.end(), 0); - auto que = union_lookup(v1::Query{q.terms}, - index, - topk_queue(10), - make_bm25(index), - std::move(unigrams), - {}); + auto que = v1::union_lookup(v1::Query{q.terms}, + index, + topk_queue(10), + make_bm25(index), + std::move(unigrams), + {}); que.finalize(); results = que.topk(); std::sort(results.begin(), results.end(), std::greater{}); diff --git a/test/v1/test_v1_score_index.cpp b/test/v1/test_v1_score_index.cpp new file mode 100644 index 000000000..0ba952fbd --- /dev/null +++ b/test/v1/test_v1_score_index.cpp @@ -0,0 +1,84 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include + +#include +#include +#include + +#include "../temporary_directory.hpp" +#include "codec/simdbp.hpp" +#include "index_fixture.hpp" +#include "pisa_config.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/cursor/for_each.hpp" +#include "v1/index.hpp" +#include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/types.hpp" + +using pisa::v1::BlockedCursor; +using pisa::v1::BlockedReader; +using pisa::v1::BlockedWriter; +using pisa::v1::compress_binary_collection; +using pisa::v1::DocId; +using pisa::v1::Frequency; +using pisa::v1::index_runner; +using pisa::v1::IndexMetadata; +using pisa::v1::make_bm25; +using pisa::v1::RawCursor; +using pisa::v1::RawReader; +using 
pisa::v1::RawWriter; +using pisa::v1::TermId; + +TEMPLATE_TEST_CASE("DAAT OR", + "[v1][integration]", + (IndexFixture, RawCursor, RawCursor>), + (IndexFixture, + BlockedCursor<::pisa::simdbp_block, false>, + RawCursor>)) +{ + tbb::task_scheduler_init init(1); + GIVEN("Index fixture (built and scored index)") + { + TestType fixture; + THEN("Float max scores are correct") + { + auto run = v1::index_runner( + fixture.meta(), fixture.document_reader(), fixture.frequency_reader()); + run([&](auto&& index) { + for (auto term = 0; term < index.num_terms(); term += 1) { + CAPTURE(term); + auto cursor = index.max_scored_cursor(term, make_bm25(index)); + auto precomputed_max = cursor.max_score(); + float calculated_max = 0.0F; + ::pisa::v1::for_each(cursor, [&](auto&& cursor) { + calculated_max = std::max(cursor.payload(), calculated_max); + }); + REQUIRE(precomputed_max == calculated_max); + } + }); + } + THEN("Quantized max scores are correct") + { + auto run = v1::scored_index_runner( + fixture.meta(), fixture.document_reader(), fixture.score_reader()); + run([&](auto&& index) { + for (auto term = 0; term < index.num_terms(); term += 1) { + CAPTURE(term); + auto cursor = index.max_scored_cursor(term, pisa::v1::VoidScorer{}); + auto precomputed_max = cursor.max_score(); + std::uint8_t calculated_max = 0; + ::pisa::v1::for_each(cursor, [&](auto&& cursor) { + if (cursor.payload() > calculated_max) { + calculated_max = cursor.payload(); + } + }); + REQUIRE(precomputed_max == calculated_max); + } + }); + } + } +} diff --git a/v1/CMakeLists.txt b/v1/CMakeLists.txt index 4723e9cf3..b4b62944f 100644 --- a/v1/CMakeLists.txt +++ b/v1/CMakeLists.txt @@ -4,11 +4,14 @@ target_link_libraries(compress pisa CLI11) add_executable(query query.cpp) target_link_libraries(query pisa CLI11) +add_executable(union-lookup union_lookup.cpp) +target_link_libraries(union-lookup pisa CLI11) + add_executable(postings postings.cpp) target_link_libraries(postings pisa CLI11) add_executable(score 
score.cpp) target_link_libraries(score pisa CLI11) -add_executable(bigram_index bigram_index.cpp) -target_link_libraries(bigram_index pisa CLI11) +add_executable(bigram-index bigram_index.cpp) +target_link_libraries(bigram-index pisa CLI11) diff --git a/v1/app.hpp b/v1/app.hpp new file mode 100644 index 000000000..7fcddd14e --- /dev/null +++ b/v1/app.hpp @@ -0,0 +1,37 @@ +#pragma once + +#include +#include + +#include + +namespace pisa { + +struct QueryApp : public CLI::App { + explicit QueryApp(std::string description) : CLI::App(std::move(description)) + { + add_option("-i,--index", + yml, + "Path of .yml file of an index " + "(if not provided, it will be looked for in the current directory)", + false); + add_option("-q,--query", query_file, "Path to file with queries", false); + add_option("-k", k, "The number of top results to return", true); + add_option("--terms", terms_file, "Overrides document lexicon from .yml (if defined)."); + add_option("--documents", + documents_file, + "Overrides document lexicon from .yml (if defined). 
Required otherwise."); + add_flag("--benchmark", is_benchmark, "Run benchmark"); + add_flag("--precomputed", precomputed, "Use precomputed scores"); + } + + std::optional yml{}; + std::optional query_file{}; + std::optional terms_file{}; + std::optional documents_file{}; + int k = 1'000; + bool is_benchmark = false; + bool precomputed = false; +}; + +} // namespace pisa diff --git a/v1/bigram_index.cpp b/v1/bigram_index.cpp index beffda3c7..8affa592b 100644 --- a/v1/bigram_index.cpp +++ b/v1/bigram_index.cpp @@ -6,7 +6,6 @@ #include #include "io.hpp" -#include "payload_vector.hpp" #include "query/queries.hpp" #include "timer.hpp" #include "topk_queue.hpp" @@ -20,7 +19,6 @@ #include "v1/scorer/runner.hpp" #include "v1/types.hpp" -using pisa::Payload_Vector_Buffer; using pisa::v1::BigramMetadata; using pisa::v1::BlockedReader; using pisa::v1::CursorTraits; @@ -159,13 +157,7 @@ int main(int argc, char** argv) .count = pair_mapping.size()}; meta.write(resolved_yml); std::cerr << " Done.\nWriting bigram mapping..."; - Payload_Vector_Buffer::make(pair_mapping.begin(), - pair_mapping.end(), - [](auto&& terms, auto out) { - auto bytes = gsl::as_bytes(gsl::make_span(terms)); - std::copy(bytes.begin(), bytes.end(), out); - }) - .to_file(meta.bigrams->mapping); + write_span(gsl::make_span(pair_mapping), meta.bigrams->mapping); std::cerr << " Done.\n"; return 0; } diff --git a/v1/compress.cpp b/v1/compress.cpp index 4f7a6531d..aa63efda6 100644 --- a/v1/compress.cpp +++ b/v1/compress.cpp @@ -2,6 +2,7 @@ #include #include +#include #include "binary_freq_collection.hpp" #include "v1/blocked_cursor.hpp" @@ -43,7 +44,7 @@ auto frequency_encoding(std::string_view name) -> std::uint32_t std::exit(1); } -int main(int argc, char **argv) +int main(int argc, char** argv) { std::string input; std::string fwd; @@ -80,7 +81,7 @@ int main(int argc, char **argv) std::cerr << "Detected more than 10 errors, printing head:\n"; errors.resize(10); } - for (auto const &error : errors) { + for (auto 
const& error : errors) { std::cerr << error << '\n'; } return 1; diff --git a/v1/query.cpp b/v1/query.cpp index 03dd9f87c..6c9cd46f1 100644 --- a/v1/query.cpp +++ b/v1/query.cpp @@ -5,12 +5,14 @@ #include #include +#include "app.hpp" #include "io.hpp" #include "query/queries.hpp" #include "timer.hpp" #include "topk_queue.hpp" #include "v1/blocked_cursor.hpp" #include "v1/index_metadata.hpp" +#include "v1/maxscore.hpp" #include "v1/query.hpp" #include "v1/raw_cursor.hpp" #include "v1/scorer/bm25.hpp" @@ -27,16 +29,44 @@ using pisa::v1::RawReader; using pisa::v1::resolve_yml; using pisa::v1::VoidScorer; +using RetrievalAlgorithm = + std::function<::pisa::topk_queue(pisa::v1::Query, ::pisa::topk_queue, tl::optional)>; + template +auto resolve_algorithm(std::string const& name, Index const& index, Scorer&& scorer) + -> RetrievalAlgorithm +{ + if (name == "daat_or") { + return RetrievalAlgorithm([&](pisa::v1::Query const& query, + ::pisa::topk_queue topk, + [[maybe_unused]] tl::optional threshold) { + return pisa::v1::daat_or(query, index, std::move(topk), std::forward(scorer)); + }); + } + if (name == "maxscore") { + return RetrievalAlgorithm([&](pisa::v1::Query const& query, + ::pisa::topk_queue topk, + tl::optional threshold) { + return pisa::v1::maxscore(query, index, std::move(topk), std::forward(scorer)); + }); + } + spdlog::error("Unknown algorithm: {}", name); + std::exit(1); +} + +template void evaluate(std::vector const& queries, Index&& index, Scorer&& scorer, int k, - pisa::Payload_Vector<> const& docmap) + pisa::Payload_Vector<> const& docmap, + Algorithm&& retrieve, + tl::optional> thresholds) { auto query_idx = 0; for (auto const& query : queries) { - auto que = daat_or(pisa::v1::Query{query.terms}, index, pisa::topk_queue(k), scorer); + auto threshold = thresholds.map([query_idx](auto&& vec) { return vec[query_idx]; }); + auto que = retrieve(pisa::v1::Query{query.terms}, pisa::topk_queue(k), threshold); que.finalize(); auto rank = 0; for (auto result : 
que.topk()) { @@ -53,16 +83,23 @@ void evaluate(std::vector const& queries, } } -template -void benchmark(std::vector const& queries, Index&& index, Scorer&& scorer, int k) +template +void benchmark(std::vector const& queries, + Index&& index, + Scorer&& scorer, + int k, + Algorithm&& retrieve, + tl::optional> thresholds) { std::vector times(queries.size(), std::numeric_limits::max()); for (auto run = 0; run < 5; run += 1) { for (auto query = 0; query < queries.size(); query += 1) { + float threshold = + thresholds.map([query](auto&& vec) { return vec[query]; }).value_or(0.0F); auto usecs = ::pisa::run_with_timer([&]() { - auto que = daat_or( - pisa::v1::Query{queries[query].terms}, index, pisa::topk_queue(k), scorer); + auto que = + retrieve(pisa::v1::Query{queries[query].terms}, pisa::topk_queue(k), threshold); que.finalize(); do_not_optimize_away(que); }); @@ -82,66 +119,65 @@ void benchmark(std::vector const& queries, Index&& index, Scorer&& int main(int argc, char** argv) { - std::optional yml{}; - std::optional query_file{}; - std::optional terms_file{}; - std::optional documents_file{}; - int k = 1'000; - bool is_benchmark = false; - bool precomputed = false; - - CLI::App app{"Queries a v1 index."}; - app.add_option("-i,--index", - yml, - "Path of .yml file of an index " - "(if not provided, it will be looked for in the current directory)", - false); - app.add_option("-q,--query", query_file, "Path to file with queries", false); - app.add_option("-k", k, "The number of top results to return", true); - app.add_option("--terms", terms_file, "Overrides document lexicon from .yml (if defined)."); - app.add_option("--documents", - documents_file, - "Overrides document lexicon from .yml (if defined). 
Required otherwise."); - app.add_flag("--benchmark", is_benchmark, "Run benchmark"); - app.add_flag("--precomputed", precomputed, "Use precomputed scores"); + std::string algorithm = "daat_or"; + tl::optional threshold_file; + pisa::QueryApp app("Queries a v1 index."); + app.add_option("--algorithm", algorithm, "Query retrieval algorithm.", true); + app.add_option("--thresholds", threshold_file, "File with (estimated) thresholds.", false); CLI11_PARSE(app, argc, argv); - auto meta = IndexMetadata::from_file(resolve_yml(yml)); + auto meta = IndexMetadata::from_file(resolve_yml(app.yml)); auto stemmer = meta.stemmer ? std::make_optional(*meta.stemmer) : std::optional{}; if (meta.term_lexicon) { - terms_file = meta.term_lexicon.value(); + app.terms_file = meta.term_lexicon.value(); } if (meta.document_lexicon) { - documents_file = meta.document_lexicon.value(); + app.documents_file = meta.document_lexicon.value(); } std::vector queries; - auto parse_query = resolve_query_parser(queries, terms_file, std::nullopt, stemmer); - if (query_file) { - std::ifstream is(*query_file); + auto parse_query = resolve_query_parser(queries, app.terms_file, std::nullopt, stemmer); + if (app.query_file) { + std::ifstream is(*app.query_file); pisa::io::for_each_line(is, parse_query); } else { pisa::io::for_each_line(std::cin, parse_query); } - if (not documents_file) { + if (not app.documents_file) { spdlog::error("Document lexicon not defined"); std::exit(1); } - auto source = std::make_shared(documents_file.value().c_str()); + auto source = std::make_shared(app.documents_file.value().c_str()); auto docmap = pisa::Payload_Vector<>::from(*source); - if (precomputed) { + auto thresholds = [&threshold_file, &queries]() { + if (threshold_file) { + std::vector thresholds; + std::ifstream is(*threshold_file); + pisa::io::for_each_line( + is, [&thresholds](auto&& line) { thresholds.push_back(std::stof(line)); }); + if (thresholds.size() != queries.size()) { + spdlog::error("Number of thresholds 
not equal to number of queries"); + std::exit(1); + } + return tl::make_optional(thresholds); + } + return tl::optional>{}; + }(); + + if (app.precomputed) { auto run = scored_index_runner(meta, RawReader{}, RawReader{}, BlockedReader<::pisa::simdbp_block, true>{}, BlockedReader<::pisa::simdbp_block, false>{}); run([&](auto&& index) { - if (is_benchmark) { - benchmark(queries, index, VoidScorer{}, k); + auto retrieve = resolve_algorithm(algorithm, index, VoidScorer{}); + if (app.is_benchmark) { + benchmark(queries, index, VoidScorer{}, app.k, retrieve, thresholds); } else { - evaluate(queries, index, VoidScorer{}, k, docmap); + evaluate(queries, index, VoidScorer{}, app.k, docmap, retrieve, thresholds); } }); } else { @@ -152,10 +188,11 @@ int main(int argc, char** argv) run([&](auto&& index) { auto with_scorer = scorer_runner(index, make_bm25(index)); with_scorer("bm25", [&](auto scorer) { - if (is_benchmark) { - benchmark(queries, index, scorer, k); + auto retrieve = resolve_algorithm(algorithm, index, scorer); + if (app.is_benchmark) { + benchmark(queries, index, scorer, app.k, retrieve, thresholds); } else { - evaluate(queries, index, scorer, k, docmap); + evaluate(queries, index, scorer, app.k, docmap, retrieve, thresholds); } }); }); diff --git a/v1/score.cpp b/v1/score.cpp index 310d00ad6..971308684 100644 --- a/v1/score.cpp +++ b/v1/score.cpp @@ -1,26 +1,11 @@ -#include +#include +#include +#include #include -#include -#include "binary_freq_collection.hpp" -#include "v1/blocked_cursor.hpp" -#include "v1/index_builder.hpp" #include "v1/index_metadata.hpp" -#include "v1/progress_status.hpp" -#include "v1/raw_cursor.hpp" -#include "v1/types.hpp" - -using pisa::v1::BlockedReader; -using pisa::v1::DefaultProgress; -using pisa::v1::IndexMetadata; -using pisa::v1::PostingFilePaths; -using pisa::v1::ProgressStatus; -using pisa::v1::RawReader; -using pisa::v1::RawWriter; -using pisa::v1::resolve_yml; -using pisa::v1::TermId; -using pisa::v1::write_span; +#include 
"v1/score_index.hpp" int main(int argc, char** argv) { @@ -39,62 +24,6 @@ int main(int argc, char** argv) // app.add_option( // "-b,--bytes-per-score", yml, "Quantize computed scores to this many bytes", true); CLI11_PARSE(app, argc, argv); - - auto resolved_yml = resolve_yml(yml); - auto meta = IndexMetadata::from_file(resolved_yml); - auto stemmer = meta.stemmer ? std::make_optional(*meta.stemmer) : std::optional{}; - - auto run = index_runner(meta, - RawReader{}, - BlockedReader<::pisa::simdbp_block, true>{}, - BlockedReader<::pisa::simdbp_block, false>{}); - auto index_basename = resolved_yml.substr(0, resolved_yml.size() - 4); - auto postings_path = fmt::format("{}.bm25", index_basename); - auto offsets_path = fmt::format("{}.bm25_offsets", index_basename); - run([&](auto&& index) { - ProgressStatus calc_max_status(index.num_terms(), - DefaultProgress("Calculating max partial score"), - std::chrono::milliseconds(100)); - std::vector max_scores(threads, 0.0F); - tbb::task_group group; - auto batch_size = index.num_terms() / threads; - for (auto thread_id = 0; thread_id < threads; thread_id += 1) { - auto first_term = thread_id * batch_size; - auto end_term = - thread_id < threads - 1 ? (thread_id + 1) * batch_size : index.num_terms(); - std::for_each( - boost::counting_iterator(first_term), - boost::counting_iterator(end_term), - [&](auto term) { - for_each(index.scoring_cursor(term, make_bm25(index)), [&](auto& cursor) { - max_scores[thread_id] = std::max(max_scores[thread_id], cursor.payload()); - }); - calc_max_status += 1; - }); - } - group.wait(); - auto max_score = *std::max_element(max_scores.begin(), max_scores.end()); - std::cerr << fmt::format("Max partial score is: {}. 
It will be scaled to {}.", - max_score, - std::numeric_limits::max()); - - ProgressStatus status( - index.num_terms(), DefaultProgress("Scoring"), std::chrono::milliseconds(100)); - std::ofstream score_file_stream(postings_path); - auto offsets = score_index( - index, - score_file_stream, - RawWriter{}, - make_bm25(index), - [&](float score) { - return static_cast(score * std::numeric_limits::max() - / max_score); - }, - [&]() { status += 1; }); - write_span(gsl::span(offsets), offsets_path); - }); - meta.scores.push_back(PostingFilePaths{.postings = postings_path, .offsets = offsets_path}); - meta.write(resolved_yml); - + pisa::v1::score_index(pisa::v1::resolve_yml(yml), threads); return 0; } diff --git a/v1/union_lookup.cpp b/v1/union_lookup.cpp new file mode 100644 index 000000000..6c303d6a1 --- /dev/null +++ b/v1/union_lookup.cpp @@ -0,0 +1,227 @@ +#include +#include +#include + +#include +#include +#include +#include + +#include "app.hpp" +#include "io.hpp" +#include "query/queries.hpp" +#include "timer.hpp" +#include "topk_queue.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/index_metadata.hpp" +#include "v1/query.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/scorer/bm25.hpp" +#include "v1/scorer/runner.hpp" +#include "v1/types.hpp" +#include "v1/union_lookup.hpp" + +using pisa::Query; +using pisa::resolve_query_parser; +using pisa::v1::BlockedReader; +using pisa::v1::index_runner; +using pisa::v1::IndexMetadata; +using pisa::v1::RawReader; +using pisa::v1::resolve_yml; +using pisa::v1::union_lookup; +using pisa::v1::VoidScorer; + +template +void evaluate(std::vector const& queries, + Index&& index, + Scorer&& scorer, + int k, + pisa::Payload_Vector<> const& docmap, + std::vector> essential_unigrams, + std::vector>> essential_bigrams) +{ + auto query_idx = 0; + for (auto const& query : queries) { + std::vector uni(query.terms.size()); + std::iota(uni.begin(), uni.end(), 0); + auto que = union_lookup(pisa::v1::Query{query.terms}, + index, + 
pisa::topk_queue(k), + scorer, + // uni, {}); + essential_unigrams[query_idx], + essential_bigrams[query_idx]); + que.finalize(); + auto rank = 0; + for (auto result : que.topk()) { + std::cout << fmt::format("{}\t{}\t{}\t{}\t{}\t{}\n", + query.id.value_or(std::to_string(query_idx)), + "Q0", + docmap[result.second], + rank, + result.first, + "R0"); + rank += 1; + } + query_idx += 1; + } +} + +template +void benchmark(std::vector const& queries, + Index&& index, + Scorer&& scorer, + int k, + std::vector> essential_unigrams, + std::vector>> essential_bigrams) + +{ + std::vector times(queries.size(), std::numeric_limits::max()); + for (auto run = 0; run < 5; run += 1) { + for (auto query = 0; query < queries.size(); query += 1) { + std::vector uni(queries[query].terms.size()); + std::iota(uni.begin(), uni.end(), 0); + auto usecs = ::pisa::run_with_timer([&]() { + auto que = union_lookup(pisa::v1::Query{queries[query].terms}, + index, + pisa::topk_queue(k), + scorer, + // uni, + //{}); + essential_unigrams[query], + essential_bigrams[query]); + que.finalize(); + do_not_optimize_away(que); + }); + times[query] = std::min(times[query], static_cast(usecs.count())); + } + } + std::sort(times.begin(), times.end()); + double avg = std::accumulate(times.begin(), times.end(), double()) / times.size(); + double q50 = times[times.size() / 2]; + double q90 = times[90 * times.size() / 100]; + double q95 = times[95 * times.size() / 100]; + spdlog::info("Mean: {}", avg); + spdlog::info("50% quantile: {}", q50); + spdlog::info("90% quantile: {}", q90); + spdlog::info("95% quantile: {}", q95); +} + +int main(int argc, char** argv) +{ + std::string inter_filename; + pisa::QueryApp app("Queries a v1 index."); + app.add_option("--intersections", inter_filename, "Intersections filename")->required(); + CLI11_PARSE(app, argc, argv); + + auto meta = IndexMetadata::from_file(resolve_yml(app.yml)); + auto stemmer = meta.stemmer ? 
std::make_optional(*meta.stemmer) : std::optional{}; + if (meta.term_lexicon) { + app.terms_file = meta.term_lexicon.value(); + } + if (meta.document_lexicon) { + app.documents_file = meta.document_lexicon.value(); + } + + std::vector queries; + auto parse_query = resolve_query_parser(queries, app.terms_file, std::nullopt, stemmer); + if (app.query_file) { + std::ifstream is(*app.query_file); + pisa::io::for_each_line(is, parse_query); + } else { + pisa::io::for_each_line(std::cin, parse_query); + } + + auto intersections = [&]() { + std::vector>> intersections; + std::ifstream is(inter_filename); + pisa::io::for_each_line(is, [&](auto const& query_line) { + intersections.emplace_back(); + std::istringstream iss(query_line); + std::transform( + std::istream_iterator(iss), + std::istream_iterator(), + std::back_inserter(intersections.back()), + [&](auto const& n) { + auto bits = std::bitset<64>(std::stoul(n)); + if (bits.count() > 2) { + spdlog::error("Intersections of more than 2 terms not supported yet!"); + std::exit(1); + } + return bits; + }); + }); + return intersections; + }(); + auto bitset_to_vec = [](auto bits) { + std::vector vec; + for (auto idx = 0; idx < bits.size(); idx += 1) { + if (bits.test(idx)) { + vec.push_back(idx); + } + } + return vec; + }; + auto is_n_gram = [](auto n) { return [n](auto bits) { return bits.count() == n; }; }; + std::vector> unigrams = + intersections | ranges::views::transform([&](auto&& query_intersections) { + return query_intersections | ranges::views::filter(is_n_gram(1)) + | ranges::views::transform([&](auto bits) { return bitset_to_vec(bits)[0]; }) + | ranges::to_vector; + }) + | ranges::to_vector; + std::vector>> bigrams = + intersections | ranges::views::transform([&](auto&& query_intersections) { + return query_intersections | ranges::views::filter(is_n_gram(2)) + | ranges::views::transform([&](auto bits) { + auto vec = bitset_to_vec(bits); + return std::make_pair(vec[0], vec[1]); + }) + | ranges::to_vector; + }) + 
| ranges::to_vector; + + if (intersections.size() != queries.size()) { + spdlog::error("Number of intersections is not equal to number of queries"); + std::exit(1); + } + + if (not app.documents_file) { + spdlog::error("Document lexicon not defined"); + std::exit(1); + } + auto source = std::make_shared(app.documents_file.value().c_str()); + auto docmap = pisa::Payload_Vector<>::from(*source); + + if (app.precomputed) { + std::abort(); + // auto run = scored_index_runner(meta, + // RawReader{}, + // RawReader{}, + // BlockedReader<::pisa::simdbp_block, true>{}, + // BlockedReader<::pisa::simdbp_block, false>{}); + // run([&](auto&& index) { + // if (app.is_benchmark) { + // benchmark(queries, index, VoidScorer{}, app.k, unigrams, bigrams); + // } else { + // evaluate(queries, index, VoidScorer{}, app.k, docmap, unigrams, bigrams); + // } + //}); + } else { + auto run = index_runner(meta, + RawReader{}, + BlockedReader<::pisa::simdbp_block, true>{}, + BlockedReader<::pisa::simdbp_block, false>{}); + run([&](auto&& index) { + auto with_scorer = scorer_runner(index, make_bm25(index)); + with_scorer("bm25", [&](auto scorer) { + if (app.is_benchmark) { + benchmark(queries, index, scorer, app.k, unigrams, bigrams); + } else { + evaluate(queries, index, scorer, app.k, docmap, unigrams, bigrams); + } + }); + }); + } + return 0; +} From 008efb7b9b1a1609d10ced71cc2daa5905d0d730 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Sat, 23 Nov 2019 15:32:43 -0500 Subject: [PATCH 23/56] Add rapidcheck --- .gitmodules | 3 + CMakeLists.txt | 2 +- external/rapidcheck | 1 + include/pisa/topk_queue.hpp | 43 ++- include/pisa/v1/algorithm.hpp | 26 ++ include/pisa/v1/blocked_cursor.hpp | 6 + include/pisa/v1/cursor/scoring_cursor.hpp | 11 + include/pisa/v1/document_payload_cursor.hpp | 5 + include/pisa/v1/index.hpp | 6 + include/pisa/v1/maxscore.hpp | 247 +++++++------ include/pisa/v1/query.hpp | 68 +++- include/pisa/v1/raw_cursor.hpp | 19 +- include/pisa/v1/scorer/bm25.hpp | 30 +- 
include/pisa/v1/unaligned_span.hpp | 44 ++- include/pisa/v1/union_lookup.hpp | 382 +++++++++++++++----- test/v1/test_v1.cpp | 43 +++ test/v1/test_v1_index.cpp | 8 +- test/v1/test_v1_queries.cpp | 114 +++--- v1/query.cpp | 154 +++++--- v1/union_lookup.cpp | 59 ++- 20 files changed, 919 insertions(+), 352 deletions(-) create mode 160000 external/rapidcheck diff --git a/.gitmodules b/.gitmodules index 14cc7eac9..c9f749b0a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -76,3 +76,6 @@ [submodule "external/yaml-cpp"] path = external/yaml-cpp url = https://github.com/jbeder/yaml-cpp.git +[submodule "external/rapidcheck"] + path = external/rapidcheck + url = https://github.com/emil-e/rapidcheck.git diff --git a/CMakeLists.txt b/CMakeLists.txt index eba106995..8d6cd6c77 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,7 +51,7 @@ if (UNIX) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") # Extensive warnings - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wno-missing-braces -Wfatal-errors") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wno-missing-braces") if (USE_SANITIZERS) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer") diff --git a/external/rapidcheck b/external/rapidcheck new file mode 160000 index 000000000..4df02602a --- /dev/null +++ b/external/rapidcheck @@ -0,0 +1 @@ +Subproject commit 4df02602aae74ff1711611b64630d3fd8ae40571 diff --git a/include/pisa/topk_queue.hpp b/include/pisa/topk_queue.hpp index 70dba7b56..89a754125 100644 --- a/include/pisa/topk_queue.hpp +++ b/include/pisa/topk_queue.hpp @@ -1,33 +1,38 @@ #pragma once -#include -#include "util/util.hpp" #include "util/likely.hpp" +#include "util/util.hpp" +#include namespace pisa { struct topk_queue { using entry_type = std::pair; - explicit topk_queue(uint64_t k) : m_threshold(0), m_k(k) { m_q.reserve(m_k + 1); } - topk_queue(topk_queue const &q) = default; - topk_queue &operator=(topk_queue const &q) = default; + explicit 
topk_queue(uint64_t k) : m_threshold(std::numeric_limits::lowest()), m_k(k) + { + m_q.reserve(m_k + 1); + } + topk_queue(topk_queue const& q) = default; + topk_queue& operator=(topk_queue const& q) = default; - [[nodiscard]] constexpr static auto min_heap_order(entry_type const &lhs, - entry_type const &rhs) noexcept -> bool { + [[nodiscard]] constexpr static auto min_heap_order(entry_type const& lhs, + entry_type const& rhs) noexcept -> bool + { return lhs.first > rhs.first; } bool insert(float score) { return insert(score, 0); } - bool insert(float score, uint64_t docid) { + bool insert(float score, uint64_t docid) + { if (PISA_UNLIKELY(score <= m_threshold)) { return false; } m_q.emplace_back(score, docid); if (PISA_UNLIKELY(m_q.size() <= m_k)) { std::push_heap(m_q.begin(), m_q.end(), min_heap_order); - if(PISA_UNLIKELY(m_q.size() == m_k)) { + if (PISA_UNLIKELY(m_q.size() == m_k)) { m_threshold = m_q.front().first; } } else { @@ -38,31 +43,35 @@ struct topk_queue { return true; } - bool would_enter(float score) const { return m_q.size() < m_k || score > m_threshold; } + [[nodiscard]] bool would_enter(float score) const { return score > m_threshold; } - void finalize() { + void finalize() + { std::sort_heap(m_q.begin(), m_q.end(), min_heap_order); size_t size = std::lower_bound(m_q.begin(), m_q.end(), 0, - [](std::pair l, float r) { return l.first > r; }) - - m_q.begin(); + [](std::pair l, float r) { return l.first > r; }) + - m_q.begin(); m_q.resize(size); } - [[nodiscard]] std::vector const &topk() const noexcept { return m_q; } + [[nodiscard]] std::vector const& topk() const noexcept { return m_q; } - void clear() noexcept { + void clear() noexcept + { m_q.clear(); m_threshold = 0; } [[nodiscard]] uint64_t size() const noexcept { return m_k; } + void set_threshold(float threshold) noexcept { m_threshold = threshold; } + private: - float m_threshold; - uint64_t m_k; + float m_threshold; + uint64_t m_k; std::vector m_q; }; diff --git 
a/include/pisa/v1/algorithm.hpp b/include/pisa/v1/algorithm.hpp index 19a3f93be..08cd10941 100644 --- a/include/pisa/v1/algorithm.hpp +++ b/include/pisa/v1/algorithm.hpp @@ -2,6 +2,8 @@ #include +#include + namespace pisa::v1 { template @@ -24,4 +26,28 @@ template return pos->sentinel(); } +template +void partition_by_index(gsl::span range, gsl::span right_indices) +{ + if (right_indices.empty()) { + return; + } + std::sort(right_indices.begin(), right_indices.end()); + if (right_indices[right_indices.size() - 1] >= range.size()) { + throw std::logic_error("Essential index too large"); + } + auto left = 0; + auto right = range.size() - 1; + auto eidx = 0; + while (left < right && eidx < right_indices.size()) { + if (left < right_indices[eidx]) { + left += 1; + } else { + std::swap(range[left], range[right]); + right -= 1; + eidx += 1; + } + } +} + } // namespace pisa::v1 diff --git a/include/pisa/v1/blocked_cursor.hpp b/include/pisa/v1/blocked_cursor.hpp index df45926eb..0419d5f2c 100644 --- a/include/pisa/v1/blocked_cursor.hpp +++ b/include/pisa/v1/blocked_cursor.hpp @@ -69,6 +69,12 @@ struct BlockedCursor { reset(); } + constexpr BlockedCursor(BlockedCursor const&) = default; + constexpr BlockedCursor(BlockedCursor&&) noexcept = default; + constexpr BlockedCursor& operator=(BlockedCursor const&) = default; + constexpr BlockedCursor& operator=(BlockedCursor&&) noexcept = default; + ~BlockedCursor() = default; + void reset() { decode_and_update_block(0); } /// Dereferences the current value. 
diff --git a/include/pisa/v1/cursor/scoring_cursor.hpp b/include/pisa/v1/cursor/scoring_cursor.hpp index bfd9516d1..9e992c9b4 100644 --- a/include/pisa/v1/cursor/scoring_cursor.hpp +++ b/include/pisa/v1/cursor/scoring_cursor.hpp @@ -18,6 +18,11 @@ struct ScoringCursor { : m_base_cursor(std::move(base_cursor)), m_scorer(std::move(scorer)) { } + constexpr ScoringCursor(ScoringCursor const&) = default; + constexpr ScoringCursor(ScoringCursor&&) noexcept = default; + constexpr ScoringCursor& operator=(ScoringCursor const&) = default; + constexpr ScoringCursor& operator=(ScoringCursor&&) noexcept = default; + ~ScoringCursor() = default; [[nodiscard]] constexpr auto operator*() const -> Document { return value(); } [[nodiscard]] constexpr auto value() const noexcept -> Document @@ -53,6 +58,12 @@ struct MaxScoreCursor { : m_base_cursor(std::move(base_cursor)), m_max_score(max_score) { } + constexpr MaxScoreCursor(MaxScoreCursor const&) = default; + constexpr MaxScoreCursor(MaxScoreCursor&&) noexcept = default; + constexpr MaxScoreCursor& operator=(MaxScoreCursor const&) = default; + constexpr MaxScoreCursor& operator=(MaxScoreCursor&&) noexcept = default; + ~MaxScoreCursor() = default; + [[nodiscard]] constexpr auto operator*() const -> Document { return m_base_cursor.value(); } [[nodiscard]] constexpr auto value() const noexcept -> Document { diff --git a/include/pisa/v1/document_payload_cursor.hpp b/include/pisa/v1/document_payload_cursor.hpp index 8c075497b..ec06e885c 100644 --- a/include/pisa/v1/document_payload_cursor.hpp +++ b/include/pisa/v1/document_payload_cursor.hpp @@ -16,6 +16,11 @@ struct DocumentPayloadCursor { : m_key_cursor(std::move(key_cursor)), m_payload_cursor(std::move(payload_cursor)) { } + constexpr DocumentPayloadCursor(DocumentPayloadCursor const&) = default; + constexpr DocumentPayloadCursor(DocumentPayloadCursor&&) noexcept = default; + constexpr DocumentPayloadCursor& operator=(DocumentPayloadCursor const&) = default; + constexpr 
DocumentPayloadCursor& operator=(DocumentPayloadCursor&&) noexcept = default; + ~DocumentPayloadCursor() = default; [[nodiscard]] constexpr auto operator*() const -> Document { return value(); } [[nodiscard]] constexpr auto value() const noexcept -> Document { return m_key_cursor.value(); } diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index a5a59d494..9204d8dae 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -199,9 +199,15 @@ struct Index { using cursor_type = std::decay_t(scorer)))>; if constexpr (std::is_convertible_v) { + if (m_max_quantized_scores.empty()) { + throw std::logic_error("Missing quantized max scores."); + } return MaxScoreCursor( scored_cursor(term, std::forward(scorer)), m_max_quantized_scores[term]); } else { + if (m_max_scores.empty()) { + throw std::logic_error("Missing max scores."); + } return MaxScoreCursor( scored_cursor(term, std::forward(scorer)), m_max_scores.at(std::hash>{}(scorer))[term]); diff --git a/include/pisa/v1/maxscore.hpp b/include/pisa/v1/maxscore.hpp index 485800925..b84365fe5 100644 --- a/include/pisa/v1/maxscore.hpp +++ b/include/pisa/v1/maxscore.hpp @@ -12,7 +12,11 @@ namespace pisa::v1 { -template +template struct MaxScoreJoin { using cursor_type = typename CursorContainer::value_type; using payload_type = Payload; @@ -35,6 +39,29 @@ struct MaxScoreJoin { m_accumulate(std::move(accumulate)), m_above_threshold(std::move(above_threshold)), m_size(std::nullopt) + { + initialize(); + } + + constexpr MaxScoreJoin(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold, + Analyzer* analyzer) + : m_cursors(std::move(cursors)), + m_sorted_cursors(m_cursors.size()), + m_cursor_idx(m_cursors.size()), + m_upper_bounds(m_cursors.size()), + m_init(std::move(init)), + m_accumulate(std::move(accumulate)), + m_above_threshold(std::move(above_threshold)), + m_size(std::nullopt), + m_analyzer(analyzer) + { + initialize(); + } + + void initialize() { 
std::transform(m_cursors.begin(), m_cursors.end(), @@ -82,12 +109,19 @@ struct MaxScoreJoin { m_current_payload = m_init; m_current_value = std::exchange(m_next_docid, sentinel()); + if constexpr (not std::is_void_v) { + m_analyzer->document(); + } + for (auto sorted_position = m_non_essential_count; sorted_position < m_sorted_cursors.size(); sorted_position += 1) { auto& cursor = m_sorted_cursors[sorted_position]; if (cursor->value() == m_current_value) { + if constexpr (not std::is_void_v) { + m_analyzer->posting(); + } m_current_payload = m_accumulate(m_current_payload, *cursor, m_cursor_idx[sorted_position]); cursor->advance(); @@ -106,6 +140,9 @@ struct MaxScoreJoin { } auto& cursor = m_sorted_cursors[sorted_position]; cursor->advance_to_geq(m_current_value); + if constexpr (not std::is_void_v) { + m_analyzer->lookup(); + } if (cursor->value() == m_current_value) { m_current_payload = m_accumulate(m_current_payload, *cursor, m_cursor_idx[sorted_position]); @@ -141,6 +178,8 @@ struct MaxScoreJoin { std::uint32_t m_next_docid{}; std::size_t m_non_essential_count = 0; payload_type m_previous_threshold{}; + + Analyzer* m_analyzer; }; template @@ -149,44 +188,29 @@ auto join_maxscore(CursorContainer cursors, AccumulateFn accumulate, ThresholdFn threshold) { - return MaxScoreJoin( + return MaxScoreJoin( std::move(cursors), std::move(init), std::move(accumulate), std::move(threshold)); } -template -auto maxscore(Query const& query, - Index const& index, - topk_queue topk, - Scorer&& scorer, - decltype(index.max_scored_cursor(0, scorer).max_score()) initial_threshold) +template +auto join_maxscore(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn threshold, + Analyzer* analyzer) { - using cursor_type = decltype(index.max_scored_cursor(0, scorer)); - using score_type = decltype(index.max_scored_cursor(0, scorer).max_score()); - using value_type = decltype(index.max_scored_cursor(0, scorer).value()); - - std::vector cursors; - 
std::transform(query.terms.begin(), - query.terms.end(), - std::back_inserter(cursors), - [&](auto term) { return index.max_scored_cursor(term, scorer); }); - - auto joined = join_maxscore( - std::move(cursors), - 0.0F, - [](auto& score, auto& cursor, auto /* term_position */) { - score += cursor.payload(); - return score; - }, - [&](auto score) { return topk.would_enter(score) && score > initial_threshold; }); - v1::for_each(joined, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); - return topk; + return MaxScoreJoin( + std::move(cursors), std::move(init), std::move(accumulate), std::move(threshold), analyzer); } template auto maxscore(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) { using cursor_type = decltype(index.max_scored_cursor(0, scorer)); - using score_type = decltype(index.max_scored_cursor(0, scorer).max_score()); using value_type = decltype(index.max_scored_cursor(0, scorer).value()); std::vector cursors; @@ -195,84 +219,99 @@ auto maxscore(Query const& query, Index const& index, topk_queue topk, Scorer&& std::back_inserter(cursors), [&](auto term) { return index.max_scored_cursor(term, scorer); }); + auto accumulate = [](auto& score, auto& cursor, auto /* term_position */) { + score += cursor.payload(); + return score; + }; + if (query.threshold) { + topk.set_threshold(*query.threshold); + } auto joined = join_maxscore( - std::move(cursors), - 0.0F, - [](auto& score, auto& cursor, auto /* term_position */) { - score += cursor.payload(); - return score; - }, - [&](auto score) { return topk.would_enter(score); }); + std::move(cursors), 0.0F, accumulate, [&](auto score) { return topk.would_enter(score); }); v1::for_each(joined, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); return topk; - // template MaxScoreJoin { - - // std::vector sorted_cursors; - // std::transform(cursors.begin(), - // cursors.end(), - // std::back_inserter(sorted_cursors), - // [](auto&& cursor) { return 
&cursor; }); - // std::sort(sorted_cursors.begin(), sorted_cursors.end(), [](auto&& lhs, auto&& rhs) { - // return lhs->max_score() < rhs->max_score(); - //}); - - // std::vector upper_bounds(sorted_cursors.size()); - //// upper_bounds.push_back(cursors); - //// for (auto* cursor : sorted_cursors) { - //// - ////} - // upper_bounds[0] = sorted_cursors[0]->max_score(); - // for (size_t i = 1; i < sorted_cursors.size(); ++i) { - // upper_bounds[i] = upper_bounds[i - 1] + sorted_cursors[i]->max_score(); - //} - //// std::partial_sum(sorted_cursors.begin(), - //// sorted_cursors.end(), - //// upper_bounds.begin(), - //// [](auto sum, auto* cursor) { return sum + cursor->max_score(); }); - // auto essential = gsl::make_span(sorted_cursors); - // auto non_essential = essential.subspan(essential.size()); // Empty - - // DocId current_docid = - // std::min_element(cursors.begin(), cursors.end(), [](auto&& lhs, auto&& rhs) { - // return *lhs < *rhs; - // })->value(); - - // while (not essential.empty() && current_docid < std::numeric_limits::max()) { - // score_type score{}; - // auto next_docid = std::numeric_limits::max(); - - // for (auto cursor : essential) { - // if (cursor->value() == current_docid) { - // score += cursor->payload(); - // cursor->advance(); - // } - // if (auto docid = cursor->value(); docid < next_docid) { - // next_docid = docid; - // } - // } - - // for (auto idx = non_essential.size() - 1; idx + 1 > 0; idx -= 1) { - // if (not topk.would_enter(score + upper_bounds[idx])) { - // break; - // } - // sorted_cursors[idx]->advance_to_geq(current_docid); - // if (sorted_cursors[idx]->value() == current_docid) { - // score += sorted_cursors[idx]->payload(); - // } - // } - // if (topk.insert(score, current_docid)) { - // //// update non-essential lists - // while (not essential.empty() - // && not topk.would_enter(upper_bounds[non_essential.size()])) { - // essential = essential.first(essential.size() - 1); - // non_essential = 
gsl::make_span(sorted_cursors).subspan(essential.size()); - // } - // } - // current_docid = next_docid; - //} - - // return topk; } +template +struct MaxscoreAnalyzer { + MaxscoreAnalyzer(Index const& index, Scorer scorer) + : m_index(index), m_scorer(std::move(scorer)) + { + std::cout << fmt::format("documents\tpostings\tinserts\tlookups\n"); + } + + void operator()(Query const& query) + { + using cursor_type = decltype(m_index.max_scored_cursor(0, m_scorer)); + using value_type = decltype(m_index.max_scored_cursor(0, m_scorer).value()); + + m_current_documents = 0; + m_current_postings = 0; + m_current_lookups = 0; + + std::vector cursors; + std::transform(query.terms.begin(), + query.terms.end(), + std::back_inserter(cursors), + [&](auto term) { return m_index.max_scored_cursor(term, m_scorer); }); + + std::size_t inserts = 0; + topk_queue topk(query.k); + auto initial_threshold = query.threshold.value_or(-1.0); + topk.set_threshold(initial_threshold); + auto joined = join_maxscore( + std::move(cursors), + 0.0F, + [&](auto& score, auto& cursor, auto /* term_position */) { + score += cursor.payload(); + return score; + }, + [&](auto score) { return topk.would_enter(score); }, + this); + v1::for_each(joined, [&](auto& cursor) { + if (topk.insert(cursor.payload(), cursor.value())) { + inserts += 1; + }; + }); + std::cout << fmt::format("{}\t{}\t{}\t{}\n", + m_current_documents, + m_current_postings, + inserts, + m_current_lookups); + m_documents += m_current_documents; + m_postings += m_current_postings; + m_lookups += m_current_lookups; + m_inserts += inserts; + m_count += 1; + } + + void summarize() && + { + std::cerr << fmt::format( + "=== SUMMARY ===\nAverage:\n- documents:\t{}\n" + "- postings:\t{}\n- inserts:\t{}\n- lookups:\t{}\n", + static_cast(m_documents) / m_count, + static_cast(m_postings) / m_count, + static_cast(m_inserts) / m_count, + static_cast(m_lookups) / m_count); + } + + void document() { m_current_documents += 1; } + void posting() { 
m_current_postings += 1; } + void lookup() { m_current_lookups += 1; } + + private: + std::size_t m_current_documents = 0; + std::size_t m_current_postings = 0; + std::size_t m_current_lookups = 0; + + std::size_t m_documents = 0; + std::size_t m_postings = 0; + std::size_t m_lookups = 0; + std::size_t m_inserts = 0; + std::size_t m_count = 0; + Index const& m_index; + Scorer m_scorer; +}; + } // namespace pisa::v1 diff --git a/include/pisa/v1/query.hpp b/include/pisa/v1/query.hpp index 14d92c5cf..52ebb616f 100644 --- a/include/pisa/v1/query.hpp +++ b/include/pisa/v1/query.hpp @@ -2,11 +2,14 @@ #include #include +#include #include #include +#include #include "topk_queue.hpp" +#include "v1/analyze_query.hpp" #include "v1/cursor/for_each.hpp" #include "v1/cursor_intersection.hpp" #include "v1/cursor_union.hpp" @@ -14,9 +17,17 @@ namespace pisa::v1 { +struct ListSelection { + std::vector unigrams{}; + std::vector> bigrams{}; +}; + struct Query { std::vector terms; - std::vector> bigrams{}; + tl::optional list_selection{}; + tl::optional threshold{}; + tl::optional id{}; + int k{}; }; template @@ -53,6 +64,61 @@ auto daat_or(Query const& query, Index const& index, topk_queue topk, Scorer&& s return topk; } +template +struct DaatOrAnalyzer { + DaatOrAnalyzer(Index const& index, Scorer scorer) : m_index(index), m_scorer(std::move(scorer)) + { + std::cout << fmt::format("documents\tpostings\n"); + } + + void operator()(Query const& query) + { + std::vector cursors; + std::transform(query.terms.begin(), + query.terms.end(), + std::back_inserter(cursors), + [&](auto term) { return m_index.scored_cursor(term, m_scorer); }); + std::size_t postings = 0; + auto cunion = v1::union_merge( + std::move(cursors), 0.0F, [&](auto& score, auto& cursor, auto /* term_idx */) { + postings += 1; + score += cursor.payload(); + return score; + }); + std::size_t documents = 0; + std::size_t inserts = 0; + topk_queue topk(query.k); + v1::for_each(cunion, [&](auto& cursor) { + if 
(topk.insert(cursor.payload(), cursor.value())) { + inserts += 1; + }; + documents += 1; + }); + std::cout << fmt::format("{}\t{}\t{}\n", documents, postings, inserts); + m_documents += documents; + m_postings += postings; + m_inserts += inserts; + m_count += 1; + } + + void summarize() && + { + std::cerr << fmt::format( + "=== SUMMARY ===\nAverage:\n- documents:\t{}\n- postings:\t{}\n- inserts:\t{}\n", + static_cast(m_documents) / m_count, + static_cast(m_postings) / m_count, + static_cast(m_inserts) / m_count); + } + + private: + std::size_t m_documents = 0; + std::size_t m_postings = 0; + std::size_t m_inserts = 0; + std::size_t m_count = 0; + Index const& m_index; + Scorer m_scorer; +}; + template auto taat_or(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) { diff --git a/include/pisa/v1/raw_cursor.hpp b/include/pisa/v1/raw_cursor.hpp index 53b1cdd02..b23229ccd 100644 --- a/include/pisa/v1/raw_cursor.hpp +++ b/include/pisa/v1/raw_cursor.hpp @@ -18,7 +18,7 @@ namespace pisa::v1 { template -[[nodiscard]] auto next(Cursor &&cursor) -> tl::optional::value_type> +[[nodiscard]] auto next(Cursor&& cursor) -> tl::optional::value_type> { cursor.advance(); if (cursor.empty()) { @@ -28,7 +28,7 @@ template } template -inline void contract(bool condition, std::string const &message, Args &&... args) +inline void contract(bool condition, std::string const& message, Args&&... args) { if (not condition) { throw std::logic_error(fmt::format(message, std::forward(args)...)); @@ -50,6 +50,11 @@ struct RawCursor { m_bytes.size()); contract(not m_bytes.empty(), "Raw cursor memory must not be empty"); } + constexpr RawCursor(RawCursor const&) = default; + constexpr RawCursor(RawCursor&&) noexcept = default; + constexpr RawCursor& operator=(RawCursor const&) = default; + constexpr RawCursor& operator=(RawCursor&&) noexcept = default; + ~RawCursor() = default; /// Dereferences the current value. 
[[nodiscard]] constexpr auto operator*() const -> T @@ -120,17 +125,17 @@ struct RawWriter { constexpr static auto encoding() -> std::uint32_t { return EncodingId::Raw + sizeof(T); } - void push(T const &posting) { m_postings.push_back(posting); } - void push(T &&posting) { m_postings.push_back(posting); } + void push(T const& posting) { m_postings.push_back(posting); } + void push(T&& posting) { m_postings.push_back(posting); } template - [[nodiscard]] auto write(std::basic_ostream &os) const -> std::size_t + [[nodiscard]] auto write(std::basic_ostream& os) const -> std::size_t { assert(!m_postings.empty()); std::uint32_t length = m_postings.size(); - os.write(reinterpret_cast(&length), sizeof(length)); + os.write(reinterpret_cast(&length), sizeof(length)); auto memory = gsl::as_bytes(gsl::make_span(m_postings.data(), m_postings.size())); - os.write(reinterpret_cast(memory.data()), memory.size()); + os.write(reinterpret_cast(memory.data()), memory.size()); return sizeof(length) + memory.size(); } diff --git a/include/pisa/v1/scorer/bm25.hpp b/include/pisa/v1/scorer/bm25.hpp index 1da1b049c..31bfe7921 100644 --- a/include/pisa/v1/scorer/bm25.hpp +++ b/include/pisa/v1/scorer/bm25.hpp @@ -16,7 +16,24 @@ struct BM25 { static constexpr float b = 0.4; static constexpr float k1 = 0.9; - explicit BM25(Index const &index) : m_index(index) {} + explicit BM25(Index const& index) : m_index(index) {} + + struct TermScorer { + TermScorer(Index const& index, float term_weight) + : m_index(&index), m_term_weight(term_weight) + { + } + + auto operator()(std::uint32_t docid, std::uint32_t frequency) const + { + return m_term_weight + * doc_term_weight(frequency, m_index->normalized_document_length(docid)); + } + + private: + Index const* m_index; + float m_term_weight; + }; [[nodiscard]] static float doc_term_weight(uint64_t freq, float norm_len) { @@ -36,18 +53,15 @@ struct BM25 { { auto term_weight = query_term_weight(m_index.term_posting_count(term_id), m_index.num_documents()); 
- return [this, term_weight](uint32_t doc, uint32_t freq) { - return term_weight - * doc_term_weight(freq, this->m_index.normalized_document_length(doc)); - }; + return TermScorer(m_index, term_weight); } private: - Index const &m_index; + Index const& m_index; }; template -auto make_bm25(Index const &index) +auto make_bm25(Index const& index) { return BM25(index); } @@ -57,7 +71,7 @@ auto make_bm25(Index const &index) namespace std { template struct hash<::pisa::v1::BM25> { - std::size_t operator()(::pisa::v1::BM25 const & /* bm25 */) const noexcept + std::size_t operator()(::pisa::v1::BM25 const& /* bm25 */) const noexcept { return std::hash{}("bm25"); } diff --git a/include/pisa/v1/unaligned_span.hpp b/include/pisa/v1/unaligned_span.hpp index 8ea845823..d7f200d84 100644 --- a/include/pisa/v1/unaligned_span.hpp +++ b/include/pisa/v1/unaligned_span.hpp @@ -15,25 +15,25 @@ struct UnalignedSpan; template struct UnalignedSpanIterator { - UnalignedSpanIterator(std::uint32_t index, UnalignedSpan const &span) + UnalignedSpanIterator(std::uint32_t index, UnalignedSpan const& span) : m_index(index), m_span(span) { } - UnalignedSpanIterator(UnalignedSpanIterator const &) = default; - UnalignedSpanIterator(UnalignedSpanIterator &&) noexcept = default; - UnalignedSpanIterator &operator=(UnalignedSpanIterator const &) = default; - UnalignedSpanIterator &operator=(UnalignedSpanIterator &&) noexcept = default; + UnalignedSpanIterator(UnalignedSpanIterator const&) = default; + UnalignedSpanIterator(UnalignedSpanIterator&&) noexcept = default; + UnalignedSpanIterator& operator=(UnalignedSpanIterator const&) = default; + UnalignedSpanIterator& operator=(UnalignedSpanIterator&&) noexcept = default; ~UnalignedSpanIterator() = default; - [[nodiscard]] auto operator==(UnalignedSpanIterator const &other) const + [[nodiscard]] auto operator==(UnalignedSpanIterator const& other) const { return m_span.bytes().data() == other.m_span.bytes().data() && m_index == other.m_index; } - 
[[nodiscard]] auto operator!=(UnalignedSpanIterator const &other) const + [[nodiscard]] auto operator!=(UnalignedSpanIterator const& other) const { return m_index != other.m_index || m_span.bytes().data() != other.m_span.bytes().data(); } [[nodiscard]] auto operator*() const { return m_span[m_index]; } - auto operator++() -> UnalignedSpanIterator & + auto operator++() -> UnalignedSpanIterator& { m_index++; return *this; @@ -44,7 +44,7 @@ struct UnalignedSpanIterator { m_index++; return copy; } - [[nodiscard]] auto operator+=(std::uint32_t n) -> UnalignedSpanIterator & + [[nodiscard]] auto operator+=(std::uint32_t n) -> UnalignedSpanIterator& { m_index += n; return *this; @@ -53,7 +53,7 @@ struct UnalignedSpanIterator { { return UnalignedSpanIterator(m_index + n, m_span); } - auto operator--() -> UnalignedSpanIterator & + auto operator--() -> UnalignedSpanIterator& { m_index--; return *this; @@ -64,7 +64,7 @@ struct UnalignedSpanIterator { m_index--; return copy; } - [[nodiscard]] auto operator-=(std::uint32_t n) -> UnalignedSpanIterator & + [[nodiscard]] auto operator-=(std::uint32_t n) -> UnalignedSpanIterator& { m_index -= n; return *this; @@ -73,30 +73,30 @@ struct UnalignedSpanIterator { { return UnalignedSpanIterator(m_index - n, m_span); } - [[nodiscard]] auto operator-(UnalignedSpanIterator const &other) const -> std::int32_t + [[nodiscard]] auto operator-(UnalignedSpanIterator const& other) const -> std::int32_t { return static_cast(m_index) - static_cast(other.m_index); } - [[nodiscard]] auto operator<(UnalignedSpanIterator const &other) const -> bool + [[nodiscard]] auto operator<(UnalignedSpanIterator const& other) const -> bool { return m_index < other.m_index; } - [[nodiscard]] auto operator<=(UnalignedSpanIterator const &other) const -> bool + [[nodiscard]] auto operator<=(UnalignedSpanIterator const& other) const -> bool { return m_index <= other.m_index; } - [[nodiscard]] auto operator>(UnalignedSpanIterator const &other) const -> bool + 
[[nodiscard]] auto operator>(UnalignedSpanIterator const& other) const -> bool { return m_index > other.m_index; } - [[nodiscard]] auto operator>=(UnalignedSpanIterator const &other) const -> bool + [[nodiscard]] auto operator>=(UnalignedSpanIterator const& other) const -> bool { return m_index >= other.m_index; } private: std::uint32_t m_index; - UnalignedSpan const &m_span; + UnalignedSpan const& m_span; }; template @@ -111,6 +111,12 @@ struct UnalignedSpan { throw std::logic_error("Number of bytes must be a multiplier of type size"); } } + constexpr UnalignedSpan(UnalignedSpan const&) = default; + constexpr UnalignedSpan(UnalignedSpan&&) noexcept = default; + constexpr UnalignedSpan& operator=(UnalignedSpan const&) = default; + constexpr UnalignedSpan& operator=(UnalignedSpan&&) noexcept = default; + ~UnalignedSpan() = default; + using iterator = UnalignedSpanIterator; [[nodiscard]] auto operator[](std::uint32_t index) const -> value_type @@ -150,8 +156,8 @@ struct iterator_traits<::pisa::v1::UnalignedSpanIterator> { using size_type = std::uint32_t; using difference_type = std::make_signed_t; using value_type = T; - using pointer = T const *; - using reference = T const &; + using pointer = T const*; + using reference = T const&; using iterator_category = std::random_access_iterator_tag; }; diff --git a/include/pisa/v1/union_lookup.hpp b/include/pisa/v1/union_lookup.hpp index 1c6468bb7..e7e99a86c 100644 --- a/include/pisa/v1/union_lookup.hpp +++ b/include/pisa/v1/union_lookup.hpp @@ -1,9 +1,253 @@ #pragma once +#include + +#include + #include "v1/query.hpp" namespace pisa::v1 { +template +void partition_by_essential(gsl::span cursors, gsl::span essential_indices) +{ + if (essential_indices.empty()) { + return; + } + std::sort(essential_indices.begin(), essential_indices.end()); + if (essential_indices[essential_indices.size() - 1] >= cursors.size()) { + throw std::logic_error("Essential index too large"); + } + auto left = 0; + auto right = cursors.size() - 1; + 
auto eidx = 0; + while (left < right && eidx < essential_indices.size()) { + if (left < essential_indices[eidx]) { + left += 1; + } else { + std::swap(cursors[left], cursors[right]); + right -= 1; + eidx += 1; + } + } +} + +template +auto unigram_union_lookup( + Query query, Index const& index, topk_queue topk, Scorer&& scorer, Analyzer* analyzer = nullptr) +{ + if (not query.threshold) { + throw std::invalid_argument("Must provide threshold to the query"); + } + if (not query.list_selection) { + throw std::invalid_argument("Must provide essential list selection"); + } + if (not query.list_selection->bigrams.empty()) { + throw std::invalid_argument("This algorithm only supports unigrams"); + } + + topk.set_threshold(*query.threshold); + + using cursor_type = decltype(index.max_scored_cursor(0, scorer)); + using payload_type = decltype(std::declval().payload()); + + auto cursors = index.max_scored_cursors(gsl::make_span(query.terms), scorer); + partition_by_essential(gsl::make_span(cursors), gsl::make_span(query.list_selection->unigrams)); + auto non_essential_count = cursors.size() - query.list_selection->unigrams.size(); + std::sort(cursors.begin(), + std::next(cursors.begin(), non_essential_count), + [](auto&& lhs, auto&& rhs) { return lhs.max_score() < rhs.max_score(); }); + + std::vector upper_bounds(non_essential_count); + upper_bounds[0] = cursors[0].max_score(); + for (size_t idx = 1; idx < non_essential_count; idx += 1) { + upper_bounds[idx] = upper_bounds[idx - 1] + cursors[idx].max_score(); + } + + auto merged_essential = v1::union_merge(gsl::make_span(cursors).subspan(non_essential_count), + 0.0F, + [&](auto acc, auto& cursor, auto /*term_idx*/) { + if constexpr (not std::is_void_v) { + analyzer->posting(); + } + return acc + cursor.payload(); + }); + + auto lookup_cursors = gsl::make_span(cursors).first(non_essential_count); + v1::for_each(merged_essential, [&](auto& cursor) { + if constexpr (not std::is_void_v) { + analyzer->document(); + } + auto 
docid = cursor.value(); + auto score = cursor.payload(); + for (auto lookup_cursor_idx = non_essential_count - 1; lookup_cursor_idx + 1 > 0; + lookup_cursor_idx -= 1) { + if (not topk.would_enter(score + upper_bounds[lookup_cursor_idx])) { + return; + } + cursors[lookup_cursor_idx].advance_to_geq(docid); + if constexpr (not std::is_void_v) { + analyzer->lookup(); + } + if (PISA_UNLIKELY(cursors[lookup_cursor_idx].value() == docid)) { + score += cursors[lookup_cursor_idx].payload(); + } + } + if constexpr (not std::is_void_v) { + analyzer->insert(); + } + topk.insert(score, docid); + }); + return topk; +} + +template +auto maxscore_union_lookup(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + Analyzer* analyzer = nullptr) +{ + if (not query.threshold) { + throw std::invalid_argument("Must provide threshold to the query"); + } + + topk.set_threshold(*query.threshold); + + using cursor_type = decltype(index.max_scored_cursor(0, scorer)); + using payload_type = decltype(std::declval().payload()); + + auto cursors = index.max_scored_cursors(gsl::make_span(query.terms), scorer); + std::sort(cursors.begin(), cursors.end(), [](auto&& lhs, auto&& rhs) { + return lhs.max_score() < rhs.max_score(); + }); + + std::vector upper_bounds(cursors.size()); + upper_bounds[0] = cursors[0].max_score(); + for (size_t i = 1; i < cursors.size(); ++i) { + upper_bounds[i] = upper_bounds[i - 1] + cursors[i].max_score(); + } + std::size_t non_essential_count = 0; + while (non_essential_count < cursors.size() + && upper_bounds[non_essential_count] < *query.threshold) { + non_essential_count += 1; + } + + std::vector unigrams(cursors.size() - non_essential_count); + std::iota(unigrams.begin(), unigrams.end(), non_essential_count); + Query query_with_selections = query; + query_with_selections.list_selection = + tl::make_optional(ListSelection{.unigrams = {}, .bigrams = {}}); + return unigram_union_lookup(std::move(query_with_selections), + index, + 
std::move(topk), + std::forward(scorer), + analyzer); + + // auto merged_essential = v1::union_merge(gsl::make_span(cursors).subspan(non_essential_count), + // 0.0F, + // [&](auto acc, auto& cursor, auto /*term_idx*/) { + // if constexpr (not std::is_void_v) { + // analyzer->posting(); + // } + // return acc + cursor.payload(); + // }); + + // auto lookup_cursors = gsl::make_span(cursors).first(non_essential_count); + // v1::for_each(merged_essential, [&](auto& cursor) { + // if constexpr (not std::is_void_v) { + // analyzer->document(); + // } + // auto docid = cursor.value(); + // auto score = cursor.payload(); + // for (auto lookup_cursor_idx = non_essential_count - 1; lookup_cursor_idx + 1 > 0; + // lookup_cursor_idx -= 1) { + // if (not topk.would_enter(score + upper_bounds[lookup_cursor_idx])) { + // return; + // } + // cursors[lookup_cursor_idx].advance_to_geq(docid); + // if constexpr (not std::is_void_v) { + // analyzer->lookup(); + // } + // if (PISA_UNLIKELY(cursors[lookup_cursor_idx].value() == docid)) { + // score += cursors[lookup_cursor_idx].payload(); + // } + // } + // if constexpr (not std::is_void_v) { + // analyzer->insert(); + // } + // topk.insert(score, docid); + //}); + // return topk; +} + +template +struct MaxscoreUnionLookupAnalyzer { + MaxscoreUnionLookupAnalyzer(Index const& index, Scorer scorer) + : m_index(index), m_scorer(std::move(scorer)) + { + std::cout << fmt::format("documents\tpostings\tinserts\tlookups\n"); + } + + void reset_current() + { + m_current_documents = 0; + m_current_postings = 0; + m_current_lookups = 0; + m_current_inserts = 0; + } + + void operator()(Query const& query) + { + using cursor_type = decltype(m_index.max_scored_cursor(0, m_scorer)); + using value_type = decltype(m_index.max_scored_cursor(0, m_scorer).value()); + + reset_current(); + topk_queue topk(query.k); + maxscore_union_lookup(query, m_index, topk, m_scorer, this); + std::cout << fmt::format("{}\t{}\t{}\t{}\n", + m_current_documents, + 
m_current_postings, + m_current_inserts, + m_current_lookups); + m_documents += m_current_documents; + m_postings += m_current_postings; + m_lookups += m_current_lookups; + m_inserts += m_current_inserts; + m_count += 1; + } + + void summarize() && + { + std::cerr << fmt::format( + "=== SUMMARY ===\nAverage:\n- documents:\t{}\n" + "- postings:\t{}\n- inserts:\t{}\n- lookups:\t{}\n", + static_cast(m_documents) / m_count, + static_cast(m_postings) / m_count, + static_cast(m_inserts) / m_count, + static_cast(m_lookups) / m_count); + } + + void document() { m_current_documents += 1; } + void posting() { m_current_postings += 1; } + void lookup() { m_current_lookups += 1; } + void insert() { m_current_inserts += 1; } + + private: + std::size_t m_current_documents = 0; + std::size_t m_current_postings = 0; + std::size_t m_current_lookups = 0; + std::size_t m_current_inserts = 0; + + std::size_t m_documents = 0; + std::size_t m_postings = 0; + std::size_t m_lookups = 0; + std::size_t m_inserts = 0; + std::size_t m_count = 0; + Index const& m_index; + Scorer m_scorer; +}; + /// Performs a "union-lookup" query (name pending). /// /// \param query Full query, as received, possibly with duplicates. 
@@ -30,70 +274,57 @@ auto union_lookup(Query const& query, ranges::sort(essential_bigrams); ranges::actions::unique(essential_bigrams); + topk.set_threshold(*query.threshold); + std::vector is_essential(query.terms.size(), false); - // std::cerr << "essential: "; for (auto idx : essential_unigrams) { - // std::cerr << idx << ' '; is_essential[idx] = true; } - // std::cerr << '\n'; - - // std::vector initial_payload(query.terms.size(), 0.0); - // std::vector essential_unigram_cursors; - // std::transform(essential_unigrams.begin(), - // essential_unigrams.end(), - // std::back_inserter(essential_unigram_cursors), - // [&](auto idx) { return index.scored_cursor(query.terms[idx], scorer); }); - // auto merged_unigrams = v1::union_merge( - // essential_unigram_cursors, initial_payload, [&](auto& acc, auto& cursor, auto term_idx) { - // acc[essential_unigrams[term_idx]] = cursor.payload(); - // return acc; - // }); + std::vector initial_payload(query.terms.size(), 0.0); std::vector essential_unigram_cursors; std::transform(essential_unigrams.begin(), essential_unigrams.end(), std::back_inserter(essential_unigram_cursors), [&](auto idx) { return index.scored_cursor(query.terms[idx], scorer); }); - // std::cerr << "No. 
essential: " << essential_unigram_cursors.size() << '\n'; auto merged_unigrams = v1::union_merge( - essential_unigram_cursors, 0.0F, [&](auto acc, auto& cursor, auto /*term_idx*/) { - // acc[essential_unigrams[term_idx]] = cursor.payload(); - return acc + cursor.payload(); + essential_unigram_cursors, initial_payload, [&](auto& acc, auto& cursor, auto term_idx) { + acc[essential_unigrams[term_idx]] = cursor.payload(); + return acc; }); - // std::vector essential_bigram_cursors; - // std::transform(essential_bigrams.begin(), - // essential_bigrams.end(), - // std::back_inserter(essential_bigram_cursors), - // [&](auto intersection) { - // return index.scored_bigram_cursor(query.terms[intersection.first], - // query.terms[intersection.second], - // scorer); - // }); - // auto merged_bigrams = - // v1::union_merge(std::move(essential_bigram_cursors), - // initial_payload, - // [&](auto& acc, auto& cursor, auto term_idx) { - // auto payload = cursor.payload(); - // acc[essential_bigrams[term_idx].first] = std::get<0>(payload); - // acc[essential_bigrams[term_idx].second] = std::get<1>(payload); - // return acc; - // }); - - // auto accumulate = [&](auto& acc, auto& cursor, auto /* union_idx */) { - // auto payload = cursor.payload(); - // for (auto idx = 0; idx < acc.size(); idx += 1) { - // acc[idx] = payload[idx]; - // } - // return acc; - //}; - // auto merged = v1::variadic_union_merge( - // initial_payload, - // std::make_tuple(std::move(merged_unigrams), std::move(merged_bigrams)), - // std::make_tuple(accumulate, accumulate)); + std::vector essential_bigram_cursors; + std::transform(essential_bigrams.begin(), + essential_bigrams.end(), + std::back_inserter(essential_bigram_cursors), + [&](auto intersection) { + return index.scored_bigram_cursor(query.terms[intersection.first], + query.terms[intersection.second], + scorer); + }); + auto merged_bigrams = + v1::union_merge(std::move(essential_bigram_cursors), + initial_payload, + [&](auto& acc, auto& cursor, auto 
term_idx) { + auto payload = cursor.payload(); + acc[essential_bigrams[term_idx].first] = std::get<0>(payload); + acc[essential_bigrams[term_idx].second] = std::get<1>(payload); + return acc; + }); + + auto accumulate = [&](auto& acc, auto& cursor, auto /* union_idx */) { + auto payload = cursor.payload(); + for (auto idx = 0; idx < acc.size(); idx += 1) { + acc[idx] = payload[idx]; + } + return acc; + }; + auto merged = v1::variadic_union_merge( + initial_payload, + std::make_tuple(std::move(merged_unigrams), std::move(merged_bigrams)), + std::make_tuple(accumulate, accumulate)); std::vector lookup_cursors; for (auto idx = 0; idx < query.terms.size(); idx += 1) { @@ -101,49 +332,32 @@ auto union_lookup(Query const& query, lookup_cursors.push_back(index.max_scored_cursor(query.terms[idx], scorer)); } } - // std::transform(query.terms.begin(), - // query.terms.end(), - // std::back_inserter(lookup_cursors), - // [&](auto term) { return index.scored_cursor(term, scorer); }); + std::sort(lookup_cursors.begin(), lookup_cursors.end(), [](auto&& lhs, auto&& rhs) { + return lhs.max_score() > rhs.max_score(); + }); + auto lookup_cursors_upper_bound = std::accumulate( + lookup_cursors.begin(), lookup_cursors.end(), 0.0F, [](auto acc, auto&& cursor) { + return acc + cursor.max_score(); + }); - // v1::for_each(merged, [&](auto& cursor) { - v1::for_each(merged_unigrams, [&](auto& cursor) { + v1::for_each(merged, [&](auto& cursor) { auto docid = cursor.value(); - auto score = cursor.payload(); - auto score_bound = std::accumulate( - lookup_cursors.begin(), lookup_cursors.end(), score, [](auto acc, auto&& cursor) { - return acc + cursor.max_score(); - }); - if (not topk.would_enter(score_bound)) { - return; - } + auto scores = cursor.payload(); + auto score = std::accumulate(scores.begin(), scores.end(), 0.0F, std::plus{}); + auto upper_bound = score + lookup_cursors_upper_bound; for (auto lookup_cursor : lookup_cursors) { - // lookup_cursor.advance(); + if (not 
topk.would_enter(upper_bound)) { + return; + } lookup_cursor.advance_to_geq(docid); if (PISA_UNLIKELY(lookup_cursor.value() == docid)) { - score += lookup_cursor.payload(); - } - if (not topk.would_enter(score - lookup_cursor.max_score())) { - return; + auto partial_score = lookup_cursor.payload(); + score += partial_score; + upper_bound += partial_score; } + upper_bound -= lookup_cursor.max_score(); } topk.insert(score, docid); - // auto docid = cursor.value(); - // auto partial_scores = cursor.payload(); - // float score = 0.0F; - // for (auto idx = 0; idx < partial_scores.size(); idx += 1) { - // score += partial_scores[idx]; - // // if (partial_scores[idx] > 0.0F) { - // // score += partial_scores[idx]; - // //} - // // else if (not is_essential[idx]) { - // // lookup_cursors[idx].advance_to_geq(docid); - // // if (lookup_cursors[idx].value() == docid) { - // // score += lookup_cursors[idx].payload(); - // // } - // //} - //} - // topk.insert(score, docid); }); return topk; } diff --git a/test/v1/test_v1.cpp b/test/v1/test_v1.cpp index 40c89643a..4b23ed14e 100644 --- a/test/v1/test_v1.cpp +++ b/test/v1/test_v1.cpp @@ -5,10 +5,12 @@ #include #include +#include #include #include "io.hpp" #include "pisa_config.hpp" +#include "v1/algorithm.hpp" #include "v1/cursor/collect.hpp" #include "v1/index.hpp" #include "v1/io.hpp" @@ -48,6 +50,47 @@ std::ostream& operator<<(std::ostream& os, tl::optional const& val) return os; } +TEST_CASE("partition_by_index", "[v1][unit]") +{ + std::vector values{5, 0, 1, 2, 3, 4}; + // auto input_data = GENERATE(table, std::vector>( + // {{{}, {5, 0, 1, 2, 3, 4}}, + // {{0, 1, 2}, {2, 3, 4, 5, 0, 1}}, + // {{3, 4, 5}, {}}, + // {{0, 4, 5}, {}}})); + // auto input_data = GENERATE(table, std::vector>( + // {{{}, {5, 0, 1, 2, 3, 4}}, + // {{0, 1, 2}, {2, 3, 4, 5, 0, 1}}, + // {{3, 4, 5}, {}}, + // {{0, 4, 5}, {}}})); + + auto expected = [](auto input_vec, auto right_indices) { + std::vector essential; + std::sort(right_indices.begin(), 
right_indices.end(), std::greater{}); + for (auto idx : right_indices) { + essential.push_back(input_vec[idx]); + input_vec.erase(std::next(input_vec.begin(), idx)); + } + std::sort(essential.begin(), essential.end()); + std::sort(input_vec.begin(), input_vec.end()); + input_vec.insert(input_vec.end(), essential.begin(), essential.end()); + return input_vec; + }; + + std::vector input{5, 0, 1, 2, 3, 4}; + std::vector right_indices{}; + auto expected_output = expected(input, right_indices); + pisa::v1::partition_by_index(gsl::make_span(input), gsl::make_span(right_indices)); + std::sort(input.begin(), std::next(input.begin(), input.size() - right_indices.size())); + std::sort(std::next(input.begin(), right_indices.size()), input.end()); + REQUIRE(input == expected_output); + + // auto [right_indices, expected] = GENERATE( + // table, std::vector>({{}, {5, 0, 1, 2, 3, 4}})); + // pisa::v1::partition_by_index(gsl::make_span(values), gsl::make_span(right_indices)); + // REQUIRE(values == expected); +} + TEST_CASE("RawReader", "[v1][unit]") { std::vector const mem{5, 0, 1, 2, 3, 4}; diff --git a/test/v1/test_v1_index.cpp b/test/v1/test_v1_index.cpp index cedb6cd1c..a6f4a4f26 100644 --- a/test/v1/test_v1_index.cpp +++ b/test/v1/test_v1_index.cpp @@ -56,9 +56,9 @@ TEST_CASE("Binary collection index", "[v1][unit]") REQUIRE(bci.size() == index.num_terms()); auto bci_iter = bci.begin(); for (auto term = 0; term < 1'000; term += 1) { - REQUIRE(std::vector(bci_iter->docs.begin(), bci_iter->docs.end()) + REQUIRE(std::vector(bci_iter->docs.begin(), bci_iter->docs.end()) == collect(index.documents(term))); - REQUIRE(std::vector(bci_iter->freqs.begin(), bci_iter->freqs.end()) + REQUIRE(std::vector(bci_iter->freqs.begin(), bci_iter->freqs.end()) == collect(index.payloads(term))); ++bci_iter; } @@ -91,9 +91,9 @@ TEST_CASE("Binary collection index -- SIMDBP", "[v1][unit]") REQUIRE(bci.size() == index.num_terms()); auto bci_iter = bci.begin(); for (auto term = 0; term < 1'000; term += 
1) { - REQUIRE(std::vector(bci_iter->docs.begin(), bci_iter->docs.end()) + REQUIRE(std::vector(bci_iter->docs.begin(), bci_iter->docs.end()) == collect(index.documents(term))); - REQUIRE(std::vector(bci_iter->freqs.begin(), bci_iter->freqs.end()) + REQUIRE(std::vector(bci_iter->freqs.begin(), bci_iter->freqs.end()) == collect(index.payloads(term))); ++bci_iter; } diff --git a/test/v1/test_v1_queries.cpp b/test/v1/test_v1_queries.cpp index de84365a3..8c388c1c8 100644 --- a/test/v1/test_v1_queries.cpp +++ b/test/v1/test_v1_queries.cpp @@ -2,6 +2,7 @@ #include "catch2/catch.hpp" #include +#include #include #include @@ -124,8 +125,14 @@ TEMPLATE_TEST_CASE("Query", v1::Index, v1::RawCursor>, v1::Index, v1::RawCursor>>::get(); TestType fixture; - auto algorithm = GENERATE(std::string("daat_or"), std::string("maxscore")); + auto input_data = GENERATE(table({{"daat_or", false}, + {"maxscore", false}, + {"maxscore", true}, + {"maxscore_union_lookup", true}})); + std::string algorithm = std::get<0>(input_data); + bool with_threshold = std::get<1>(input_data); CAPTURE(algorithm); + CAPTURE(with_threshold); auto index_basename = (fixture.tmpdir().path() / "inv").string(); auto meta = v1::IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); ranked_or_query or_q(10); @@ -136,6 +143,9 @@ TEMPLATE_TEST_CASE("Query", if (name == "maxscore") { return maxscore(query, index, topk_queue(10), scorer); } + if (name == "maxscore_union_lookup") { + return maxscore_union_lookup(query, index, topk_queue(10), scorer); + } std::abort(); }; int idx = 0; @@ -147,67 +157,87 @@ TEMPLATE_TEST_CASE("Query", or_q(make_scored_cursors(data->v0_index, data->wdata, q), data->v0_index.num_docs()); auto expected = or_q.topk(); - std::sort(expected.begin(), expected.end(), approximate_order); + auto threshold = [&]() { + if (with_threshold) { + return tl::make_optional(expected.back().first - 1.0F); + } + return tl::optional{}; + }(); auto on_the_fly = [&]() { auto run = v1::index_runner(meta, 
fixture.document_reader(), fixture.frequency_reader()); std::vector results; run([&](auto&& index) { - auto que = run_query(algorithm, v1::Query{q.terms}, index, make_bm25(index)); + auto que = run_query(algorithm, + v1::Query{.terms = q.terms, + .list_selection = {}, + .threshold = threshold, + .id = {}, + .k = 10}, + index, + make_bm25(index)); que.finalize(); results = que.topk(); + results.erase(std::remove_if(results.begin(), + results.end(), + [last_score = results.back().first](auto&& entry) { + return entry.first <= last_score; + }), + results.end()); std::sort(results.begin(), results.end(), approximate_order); }); return results; }(); + expected.resize(on_the_fly.size()); + std::sort(expected.begin(), expected.end(), approximate_order); - REQUIRE(expected.size() == on_the_fly.size()); for (size_t i = 0; i < on_the_fly.size(); ++i) { REQUIRE(on_the_fly[i].second == expected[i].second); REQUIRE(on_the_fly[i].first == Approx(expected[i].first).epsilon(RELATIVE_ERROR)); } - auto precomputed = [&]() { - auto run = - v1::scored_index_runner(meta, fixture.document_reader(), fixture.score_reader()); - std::vector results; - run([&](auto&& index) { - auto que = run_query(algorithm, v1::Query{q.terms}, index, v1::VoidScorer{}); - que.finalize(); - results = que.topk(); - }); - // Remove the tail that might be different due to quantization error. - // Note that `precomputed` will have summed quantized score, while the - // vector we compare to will have quantized sum---that's why whe remove anything - // that's withing 2 of the last result. 
- // auto last_score = results.back().first; - // results.erase(std::remove_if( - // results.begin(), - // results.end(), - // [last_score](auto&& entry) { return entry.first <= last_score + 3; - // }), - // results.end()); - // results.resize(5); - // std::sort(results.begin(), results.end(), [](auto&& lhs, auto&& rhs) { - // return lhs.second < rhs.second; - //}); - return results; - }(); + // auto precomputed = [&]() { + // auto run = + // v1::scored_index_runner(meta, fixture.document_reader(), fixture.score_reader()); + // std::vector results; + // run([&](auto&& index) { + // auto que = run_query(algorithm, v1::Query{q.terms}, index, v1::VoidScorer{}); + // que.finalize(); + // results = que.topk(); + // }); + // // Remove the tail that might be different due to quantization error. + // // Note that `precomputed` will have summed quantized score, while the + // // vector we compare to will have quantized sum---that's why whe remove anything + // // that's withing 2 of the last result. 
+ // // auto last_score = results.back().first; + // // results.erase(std::remove_if( + // // results.begin(), + // // results.end(), + // // [last_score](auto&& entry) { return entry.first <= last_score + 3; + // // }), + // // results.end()); + // // results.resize(5); + // // std::sort(results.begin(), results.end(), [](auto&& lhs, auto&& rhs) { + // // return lhs.second < rhs.second; + // //}); + // return results; + //}(); - constexpr float max_partial_score = 16.5724F; - auto quantizer = [&](float score) { - return static_cast(score * std::numeric_limits::max() - / max_partial_score); - }; + // constexpr float max_partial_score = 16.5724F; + // auto quantizer = [&](float score) { + // return static_cast(score * std::numeric_limits::max() + // / max_partial_score); + //}; - auto expected_quantized = expected; - std::sort(expected_quantized.begin(), expected_quantized.end(), [](auto&& lhs, auto&& rhs) { - return lhs.first > rhs.first; - }); - for (auto& v : expected_quantized) { - v.first = quantizer(v.first); - } + // auto expected_quantized = expected; + // std::sort(expected_quantized.begin(), expected_quantized.end(), [](auto&& lhs, auto&& + // rhs) { + // return lhs.first > rhs.first; + //}); + // for (auto& v : expected_quantized) { + // v.first = quantizer(v.first); + //} // TODO(michal): test the quantized results diff --git a/v1/query.cpp b/v1/query.cpp index 6c9cd46f1..dd8571f79 100644 --- a/v1/query.cpp +++ b/v1/query.cpp @@ -10,6 +10,7 @@ #include "query/queries.hpp" #include "timer.hpp" #include "topk_queue.hpp" +#include "v1/analyze_query.hpp" #include "v1/blocked_cursor.hpp" #include "v1/index_metadata.hpp" #include "v1/maxscore.hpp" @@ -18,55 +19,76 @@ #include "v1/scorer/bm25.hpp" #include "v1/scorer/runner.hpp" #include "v1/types.hpp" +#include "v1/union_lookup.hpp" -using pisa::Query; using pisa::resolve_query_parser; using pisa::v1::BlockedReader; using pisa::v1::daat_or; +using pisa::v1::DaatOrAnalyzer; using pisa::v1::index_runner; using 
pisa::v1::IndexMetadata; +using pisa::v1::maxscore_union_lookup; +using pisa::v1::MaxscoreAnalyzer; +using pisa::v1::MaxscoreUnionLookupAnalyzer; +using pisa::v1::Query; +using pisa::v1::QueryAnalyzer; using pisa::v1::RawReader; using pisa::v1::resolve_yml; using pisa::v1::VoidScorer; -using RetrievalAlgorithm = - std::function<::pisa::topk_queue(pisa::v1::Query, ::pisa::topk_queue, tl::optional)>; +using RetrievalAlgorithm = std::function<::pisa::topk_queue(pisa::v1::Query, ::pisa::topk_queue)>; template auto resolve_algorithm(std::string const& name, Index const& index, Scorer&& scorer) -> RetrievalAlgorithm { if (name == "daat_or") { - return RetrievalAlgorithm([&](pisa::v1::Query const& query, - ::pisa::topk_queue topk, - [[maybe_unused]] tl::optional threshold) { + return RetrievalAlgorithm([&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { return pisa::v1::daat_or(query, index, std::move(topk), std::forward(scorer)); }); } if (name == "maxscore") { - return RetrievalAlgorithm([&](pisa::v1::Query const& query, - ::pisa::topk_queue topk, - tl::optional threshold) { + return RetrievalAlgorithm([&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + if (query.threshold) { + topk.set_threshold(*query.threshold); + } return pisa::v1::maxscore(query, index, std::move(topk), std::forward(scorer)); }); } + if (name == "maxscore-union-lookup") { + return RetrievalAlgorithm([&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + return pisa::v1::maxscore_union_lookup( + query, index, std::move(topk), std::forward(scorer)); + }); + } + spdlog::error("Unknown algorithm: {}", name); + std::exit(1); +} + +template +auto resolve_analyze(std::string const& name, Index const& index, Scorer&& scorer) -> QueryAnalyzer +{ + if (name == "daat_or") { + return QueryAnalyzer(DaatOrAnalyzer(index, std::forward(scorer))); + } + if (name == "maxscore") { + return QueryAnalyzer(MaxscoreAnalyzer(index, std::forward(scorer))); + } + if (name == 
"maxscore-union-lookup") { + return QueryAnalyzer(MaxscoreUnionLookupAnalyzer(index, std::forward(scorer))); + } spdlog::error("Unknown algorithm: {}", name); std::exit(1); } -template -void evaluate(std::vector const& queries, - Index&& index, - Scorer&& scorer, +void evaluate(std::vector const& queries, int k, pisa::Payload_Vector<> const& docmap, - Algorithm&& retrieve, - tl::optional> thresholds) + RetrievalAlgorithm const& retrieve) { auto query_idx = 0; for (auto const& query : queries) { - auto threshold = thresholds.map([query_idx](auto&& vec) { return vec[query_idx]; }); - auto que = retrieve(pisa::v1::Query{query.terms}, pisa::topk_queue(k), threshold); + auto que = retrieve(query, pisa::topk_queue(k)); que.finalize(); auto rank = 0; for (auto result : que.topk()) { @@ -83,23 +105,14 @@ void evaluate(std::vector const& queries, } } -template -void benchmark(std::vector const& queries, - Index&& index, - Scorer&& scorer, - int k, - Algorithm&& retrieve, - tl::optional> thresholds) +void benchmark(std::vector const& queries, int k, RetrievalAlgorithm retrieve) { std::vector times(queries.size(), std::numeric_limits::max()); for (auto run = 0; run < 5; run += 1) { for (auto query = 0; query < queries.size(); query += 1) { - float threshold = - thresholds.map([query](auto&& vec) { return vec[query]; }).value_or(0.0F); auto usecs = ::pisa::run_with_timer([&]() { - auto que = - retrieve(pisa::v1::Query{queries[query].terms}, pisa::topk_queue(k), threshold); + auto que = retrieve(queries[query], pisa::topk_queue(k)); que.finalize(); do_not_optimize_away(que); }); @@ -117,13 +130,27 @@ void benchmark(std::vector const& queries, spdlog::info("95% quantile: {}", q95); } +void analyze_queries(std::vector const& queries, QueryAnalyzer analyzer) +{ + std::vector times(queries.size(), std::numeric_limits::max()); + for (auto run = 0; run < 5; run += 1) { + for (auto query = 0; query < queries.size(); query += 1) { + analyzer(queries[query]); + } + } + 
std::move(analyzer).summarize(); +} + int main(int argc, char** argv) { std::string algorithm = "daat_or"; tl::optional threshold_file; + bool analyze = false; + pisa::QueryApp app("Queries a v1 index."); app.add_option("--algorithm", algorithm, "Query retrieval algorithm.", true); - app.add_option("--thredsholds", algorithm, "File with (estimated) thresholds.", false); + app.add_option("--thresholds", threshold_file, "File with (estimated) thresholds.", false); + app.add_flag("--analyze", analyze, "Analyze query execution and stats"); CLI11_PARSE(app, argc, argv); auto meta = IndexMetadata::from_file(resolve_yml(app.yml)); @@ -135,14 +162,31 @@ int main(int argc, char** argv) app.documents_file = meta.document_lexicon.value(); } - std::vector queries; - auto parse_query = resolve_query_parser(queries, app.terms_file, std::nullopt, stemmer); - if (app.query_file) { - std::ifstream is(*app.query_file); - pisa::io::for_each_line(is, parse_query); - } else { - pisa::io::for_each_line(std::cin, parse_query); - } + auto queries = [&]() { + std::vector<::pisa::Query> queries; + auto parse_query = resolve_query_parser(queries, app.terms_file, std::nullopt, stemmer); + if (app.query_file) { + std::ifstream is(*app.query_file); + pisa::io::for_each_line(is, parse_query); + } else { + pisa::io::for_each_line(std::cin, parse_query); + } + std::vector v1_queries(queries.size()); + std::transform(queries.begin(), queries.end(), v1_queries.begin(), [&](auto&& query) { + return Query{.terms = query.terms, + .list_selection = {}, + .threshold = {}, + .id = + [&]() { + if (query.id) { + return tl::make_optional(*query.id); + } + return tl::optional{}; + }(), + .k = app.k}; + }); + return v1_queries; + }(); if (not app.documents_file) { spdlog::error("Document lexicon not defined"); @@ -151,20 +195,22 @@ int main(int argc, char** argv) auto source = std::make_shared(app.documents_file.value().c_str()); auto docmap = pisa::Payload_Vector<>::from(*source); - auto thresholds = 
[&threshold_file, &queries]() { - if (threshold_file) { - std::vector thresholds; - std::ifstream is(*threshold_file); - pisa::io::for_each_line( - is, [&thresholds](auto&& line) { thresholds.push_back(std::stof(line)); }); - if (thresholds.size() != queries.size()) { + if (threshold_file) { + std::ifstream is(*threshold_file); + auto queries_iter = queries.begin(); + pisa::io::for_each_line(is, [&](auto&& line) { + if (queries_iter == queries.end()) { spdlog::error("Number of thresholds not equal to number of queries"); std::exit(1); } - return tl::make_optional(thresholds); + queries_iter->threshold = tl::make_optional(std::stof(line)); + ++queries_iter; + }); + if (queries_iter != queries.end()) { + spdlog::error("Number of thresholds not equal to number of queries"); + std::exit(1); } - return tl::optional>{}; - }(); + } if (app.precomputed) { auto run = scored_index_runner(meta, @@ -173,11 +219,12 @@ int main(int argc, char** argv) BlockedReader<::pisa::simdbp_block, true>{}, BlockedReader<::pisa::simdbp_block, false>{}); run([&](auto&& index) { - auto retrieve = resolve_algorithm(algorithm, index, VoidScorer{}); if (app.is_benchmark) { - benchmark(queries, index, VoidScorer{}, app.k, retrieve, thresholds); + benchmark(queries, app.k, resolve_algorithm(algorithm, index, VoidScorer{})); + } else if (analyze) { + analyze_queries(queries, resolve_analyze(algorithm, index, VoidScorer{})); } else { - evaluate(queries, index, VoidScorer{}, app.k, docmap, retrieve, thresholds); + evaluate(queries, app.k, docmap, resolve_algorithm(algorithm, index, VoidScorer{})); } }); } else { @@ -188,11 +235,12 @@ int main(int argc, char** argv) run([&](auto&& index) { auto with_scorer = scorer_runner(index, make_bm25(index)); with_scorer("bm25", [&](auto scorer) { - auto retrieve = resolve_algorithm(algorithm, index, scorer); if (app.is_benchmark) { - benchmark(queries, index, scorer, app.k, retrieve, thresholds); + benchmark(queries, app.k, resolve_algorithm(algorithm, index, 
scorer)); + } else if (analyze) { + analyze_queries(queries, resolve_analyze(algorithm, index, scorer)); } else { - evaluate(queries, index, scorer, app.k, docmap, retrieve, thresholds); + evaluate(queries, app.k, docmap, resolve_algorithm(algorithm, index, scorer)); } }); }); diff --git a/v1/union_lookup.cpp b/v1/union_lookup.cpp index 6c303d6a1..e8a507f7d 100644 --- a/v1/union_lookup.cpp +++ b/v1/union_lookup.cpp @@ -21,18 +21,18 @@ #include "v1/types.hpp" #include "v1/union_lookup.hpp" -using pisa::Query; using pisa::resolve_query_parser; using pisa::v1::BlockedReader; using pisa::v1::index_runner; using pisa::v1::IndexMetadata; +using pisa::v1::Query; using pisa::v1::RawReader; using pisa::v1::resolve_yml; using pisa::v1::union_lookup; using pisa::v1::VoidScorer; template -void evaluate(std::vector const& queries, +void evaluate(std::vector const& queries, Index&& index, Scorer&& scorer, int k, @@ -44,7 +44,7 @@ void evaluate(std::vector const& queries, for (auto const& query : queries) { std::vector uni(query.terms.size()); std::iota(uni.begin(), uni.end(), 0); - auto que = union_lookup(pisa::v1::Query{query.terms}, + auto que = union_lookup(query, index, pisa::topk_queue(k), scorer, @@ -68,7 +68,7 @@ void evaluate(std::vector const& queries, } template -void benchmark(std::vector const& queries, +void benchmark(std::vector const& queries, Index&& index, Scorer&& scorer, int k, @@ -82,7 +82,7 @@ void benchmark(std::vector const& queries, std::vector uni(queries[query].terms.size()); std::iota(uni.begin(), uni.end(), 0); auto usecs = ::pisa::run_with_timer([&]() { - auto que = union_lookup(pisa::v1::Query{queries[query].terms}, + auto que = union_lookup(queries[query], index, pisa::topk_queue(k), scorer, @@ -110,8 +110,11 @@ void benchmark(std::vector const& queries, int main(int argc, char** argv) { std::string inter_filename; + std::string threshold_file; + pisa::QueryApp app("Queries a v1 index."); app.add_option("--intersections", inter_filename, 
"Intersections filename")->required(); + app.add_option("--thresholds", threshold_file, "File with (estimated) thresholds.")->required(); CLI11_PARSE(app, argc, argv); auto meta = IndexMetadata::from_file(resolve_yml(app.yml)); @@ -123,13 +126,45 @@ int main(int argc, char** argv) app.documents_file = meta.document_lexicon.value(); } - std::vector queries; - auto parse_query = resolve_query_parser(queries, app.terms_file, std::nullopt, stemmer); - if (app.query_file) { - std::ifstream is(*app.query_file); - pisa::io::for_each_line(is, parse_query); - } else { - pisa::io::for_each_line(std::cin, parse_query); + auto queries = [&]() { + std::vector<::pisa::Query> queries; + auto parse_query = resolve_query_parser(queries, app.terms_file, std::nullopt, stemmer); + if (app.query_file) { + std::ifstream is(*app.query_file); + pisa::io::for_each_line(is, parse_query); + } else { + pisa::io::for_each_line(std::cin, parse_query); + } + std::vector v1_queries(queries.size()); + std::transform(queries.begin(), queries.end(), v1_queries.begin(), [&](auto&& query) { + return Query{.terms = query.terms, + .list_selection = {}, + .threshold = {}, + .id = + [&]() { + if (query.id) { + return tl::make_optional(*query.id); + } + return tl::optional{}; + }(), + .k = app.k}; + }); + return v1_queries; + }(); + + std::ifstream is(threshold_file); + auto queries_iter = queries.begin(); + pisa::io::for_each_line(is, [&](auto&& line) { + if (queries_iter == queries.end()) { + spdlog::error("Number of thresholds not equal to number of queries"); + std::exit(1); + } + queries_iter->threshold = tl::make_optional(std::stof(line)); + ++queries_iter; + }); + if (queries_iter != queries.end()) { + spdlog::error("Number of thresholds not equal to number of queries"); + std::exit(1); } auto intersections = [&]() { From 945b64423dda87e299c1e386354051fe336289cd Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Thu, 28 Nov 2019 08:16:48 -0500 Subject: [PATCH 24/56] Union-lookup with bigrams --- 
external/CMakeLists.txt | 8 +- include/pisa/v1/algorithm.hpp | 23 +- include/pisa/v1/analyze_query.hpp | 69 +++++ include/pisa/v1/cursor_union.hpp | 88 +++++- include/pisa/v1/index.hpp | 78 +++-- include/pisa/v1/index_builder.hpp | 8 + include/pisa/v1/intersection.hpp | 50 +++ include/pisa/v1/maxscore.hpp | 10 + include/pisa/v1/query.hpp | 12 +- include/pisa/v1/union_lookup.hpp | 412 ++++++++++++++---------- include/pisa/v1/zip_cursor.hpp | 25 +- src/CMakeLists.txt | 214 ++++++------- src/compute_intersection.cpp | 30 +- src/v1/index_builder.cpp | 106 +++++++ src/v1/intersection.cpp | 77 +++++ src/v1/query.cpp | 39 +++ test/test_data/top10_selections | 500 ++++++++++++++++++++++++++++++ test/test_data/top10_thresholds | 500 ++++++++++++++++++++++++++++++ test/v1/CMakeLists.txt | 1 + test/v1/index_fixture.hpp | 35 ++- test/v1/test_v1.cpp | 40 +-- test/v1/test_v1_bigram_index.cpp | 104 +++++++ test/v1/test_v1_queries.cpp | 135 +++----- v1/CMakeLists.txt | 3 - v1/bigram_index.cpp | 146 +++------ v1/query.cpp | 59 +++- 26 files changed, 2210 insertions(+), 562 deletions(-) create mode 100644 include/pisa/v1/analyze_query.hpp create mode 100644 include/pisa/v1/intersection.hpp create mode 100644 src/v1/intersection.cpp create mode 100644 test/test_data/top10_selections create mode 100644 test/test_data/top10_thresholds create mode 100644 test/v1/test_v1_bigram_index.cpp diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt index 865ec1a94..879a73d45 100644 --- a/external/CMakeLists.txt +++ b/external/CMakeLists.txt @@ -126,8 +126,6 @@ set(YAML_CPP_BUILD_TOOLS OFF CACHE BOOL "skip building YAML tools") set(YAML_CPP_BUILD_TESTS OFF CACHE BOOL "skip building YAML tests") add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/yaml-cpp EXCLUDE_FROM_ALL) -# Add tl::expected -#set(EXPECTED_BUILD_TESTS OFF CACHE BOOL "skip tl::expected testing") -#set(EXPECTED_BUILD_PACKAGE OFF CACHE BOOL "skip tl::expected package") -#set(EXPECTED_BUILD_PACKAGE_DEB OFF CACHE BOOL "skip 
tl::expected package deb") -#add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/expected EXCLUDE_FROM_ALL) +# Add RapidCheck +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/rapidcheck) +target_compile_options(rapidcheck PRIVATE -Wno-error=all) diff --git a/include/pisa/v1/algorithm.hpp b/include/pisa/v1/algorithm.hpp index 08cd10941..7d8f5481a 100644 --- a/include/pisa/v1/algorithm.hpp +++ b/include/pisa/v1/algorithm.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include @@ -36,18 +37,26 @@ void partition_by_index(gsl::span range, gsl::span right_indices if (right_indices[right_indices.size() - 1] >= range.size()) { throw std::logic_error("Essential index too large"); } - auto left = 0; - auto right = range.size() - 1; + std::vector essential; + essential.reserve(right_indices.size()); + std::vector non_essential; + non_essential.reserve(range.size() - right_indices.size()); + + auto cidx = 0; auto eidx = 0; - while (left < right && eidx < right_indices.size()) { - if (left < right_indices[eidx]) { - left += 1; + while (eidx < right_indices.size()) { + if (cidx < right_indices[eidx]) { + non_essential.push_back(std::move(range[cidx])); + cidx += 1; } else { - std::swap(range[left], range[right]); - right -= 1; + essential.push_back(std::move(range[cidx])); eidx += 1; + cidx += 1; } } + std::move(std::next(range.begin(), cidx), range.end(), std::back_inserter(non_essential)); + auto pos = std::move(non_essential.begin(), non_essential.end(), range.begin()); + std::move(essential.begin(), essential.end(), pos); } } // namespace pisa::v1 diff --git a/include/pisa/v1/analyze_query.hpp b/include/pisa/v1/analyze_query.hpp new file mode 100644 index 000000000..77f30ee3c --- /dev/null +++ b/include/pisa/v1/analyze_query.hpp @@ -0,0 +1,69 @@ +#pragma once + +#include + +#include +#include + +namespace pisa::v1 { + +struct Query; + +struct QueryAnalyzer { + + template + explicit constexpr QueryAnalyzer(R writer) + : m_internal_analyzer(std::make_unique>(writer)) + { + } + 
QueryAnalyzer() = default; + QueryAnalyzer(QueryAnalyzer const& other) + : m_internal_analyzer(other.m_internal_analyzer->clone()) + { + } + QueryAnalyzer(QueryAnalyzer&& other) noexcept = default; + QueryAnalyzer& operator=(QueryAnalyzer const& other) = delete; + QueryAnalyzer& operator=(QueryAnalyzer&& other) noexcept = default; + ~QueryAnalyzer() = default; + + void operator()(Query const& query) { m_internal_analyzer->operator()(query); } + void summarize() && { std::move(*m_internal_analyzer).summarize(); } + + struct AnalyzerInterface { + AnalyzerInterface() = default; + AnalyzerInterface(AnalyzerInterface const&) = default; + AnalyzerInterface(AnalyzerInterface&&) noexcept = default; + AnalyzerInterface& operator=(AnalyzerInterface const&) = default; + AnalyzerInterface& operator=(AnalyzerInterface&&) noexcept = default; + virtual ~AnalyzerInterface() = default; + virtual void operator()(Query const& query) = 0; + virtual void summarize() && = 0; + [[nodiscard]] virtual auto clone() const -> std::unique_ptr = 0; + }; + + template + struct AnalyzerImpl : AnalyzerInterface { + explicit AnalyzerImpl(R analyzer) : m_analyzer(std::move(analyzer)) {} + AnalyzerImpl() = default; + AnalyzerImpl(AnalyzerImpl const&) = default; + AnalyzerImpl(AnalyzerImpl&&) noexcept = default; + AnalyzerImpl& operator=(AnalyzerImpl const&) = default; + AnalyzerImpl& operator=(AnalyzerImpl&&) noexcept = default; + ~AnalyzerImpl() override = default; + void operator()(Query const& query) override { m_analyzer(query); } + void summarize() && override { std::move(m_analyzer).summarize(); } + [[nodiscard]] auto clone() const -> std::unique_ptr override + { + auto copy = *this; + return std::make_unique>(std::move(copy)); + } + + private: + R m_analyzer; + }; + + private: + std::unique_ptr m_internal_analyzer; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor_union.hpp b/include/pisa/v1/cursor_union.hpp index cd30d090e..adeb28097 100644 --- 
a/include/pisa/v1/cursor_union.hpp +++ b/include/pisa/v1/cursor_union.hpp @@ -3,6 +3,7 @@ #include #include #include +#include #include @@ -11,6 +12,18 @@ namespace pisa::v1 { +template +void init_payload(T& payload, T const& initial_value) +{ + payload = initial_value; +} + +template <> +inline void init_payload(std::vector& payload, std::vector const& initial_value) +{ + std::copy(initial_value.begin(), initial_value.end(), payload.begin()); +} + /// Transforms a list of cursors into one cursor by lazily merging them together. template struct CursorUnion { @@ -27,6 +40,7 @@ struct CursorUnion { m_accumulate(std::move(accumulate)), m_size(std::nullopt) { + m_current_payload = m_init; if (m_cursors.empty()) { m_current_value = std::numeric_limits::max(); } else { @@ -58,9 +72,9 @@ struct CursorUnion { { if (PISA_UNLIKELY(m_next_docid == m_sentinel)) { m_current_value = m_sentinel; - m_current_payload = m_init; + ::pisa::v1::init_payload(m_current_payload, m_init); } else { - m_current_payload = m_init; + ::pisa::v1::init_payload(m_current_payload, m_init); m_current_value = m_next_docid; m_next_docid = m_sentinel; std::size_t cursor_idx = 0; @@ -98,6 +112,76 @@ struct CursorUnion { std::uint32_t m_next_docid{}; }; +/// Transforms a list of cursors into one cursor by lazily merging them together. 
+//template +//struct CursorFlatUnion { +// using Cursor = typename CursorContainer::value_type; +// using iterator_category = +// typename std::iterator_traits::iterator_category; +// static_assert(std::is_base_of(), +// "cursors must be stored in a random access container"); +// using Value = std::decay_t())>; +// +// explicit constexpr CursorFlatUnion(CursorContainer cursors) : m_cursors(std::move(cursors)) +// { +// m_current_payload = m_init; +// if (m_cursors.empty()) { +// m_current_value = std::numeric_limits::max(); +// } else { +// m_next_docid = min_value(m_cursors); +// m_sentinel = min_sentinel(m_cursors); +// advance(); +// } +// } +// +// [[nodiscard]] constexpr auto operator*() const noexcept -> Value { return m_current_value; } +// [[nodiscard]] constexpr auto value() const noexcept -> Value { return m_current_value; } +// [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& +// { +// return m_current_payload; +// } +// [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_sentinel; } +// +// constexpr void advance() +// { +// //if (PISA_UNLIKELY(m_next_docid == m_sentinel)) { +// // m_current_value = m_sentinel; +// //} else { +// // m_current_value = m_next_docid; +// // m_next_docid = m_sentinel; +// // std::size_t cursor_idx = 0; +// // for (auto& cursor : m_cursors) { +// // if (cursor.value() == m_current_value) { +// // m_current_payload = m_accumulate(m_current_payload, cursor, cursor_idx); +// // cursor.advance(); +// // } +// // if (cursor.value() < m_next_docid) { +// // m_next_docid = cursor.value(); +// // } +// // ++cursor_idx; +// // } +// //} +// } +// +// constexpr void advance_to_position(std::size_t pos); // TODO(michal) +// constexpr void advance_to_geq(Value value); // TODO(michal) +// [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; // TODO(michal) +// +// [[nodiscard]] constexpr auto empty() const noexcept -> bool +// { +// return m_current_value >= 
sentinel(); +// } +// +// private: +// CursorContainer m_cursors; +// +// std::size_t m_cursor_idx = 0; +// Value m_current_value{}; +// Value m_sentinel{}; +// Payload m_current_payload{}; +// std::uint32_t m_next_docid{}; +//}; + /// Transforms a list of cursors into one cursor by lazily merging them together. template struct VariadicCursorUnion { diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index 9204d8dae..26dff12e9 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -133,25 +133,50 @@ struct Index { }); return cursors; } - - [[nodiscard]] auto bigram_cursor(TermId left_term, TermId right_term) const + [[nodiscard]] auto bigram_id(TermId left_term, TermId right_term) const -> tl::optional { if (not m_bigram_mapping) { throw std::logic_error("Bigrams are missing"); } - if (auto pos = std::lower_bound(m_bigram_mapping->begin(), - m_bigram_mapping->end(), - std::array{left_term, right_term}, - compare_arrays); + if (right_term == left_term) { + throw std::logic_error("Requested bigram of two identical terms"); + } + auto bigram = std::array{left_term, right_term}; + if (right_term < left_term) { + std::swap(bigram[0], bigram[1]); + } + if (auto pos = std::lower_bound( + m_bigram_mapping->begin(), m_bigram_mapping->end(), bigram, compare_arrays); pos != m_bigram_mapping->end()) { - auto bigram_id = std::distance(m_bigram_mapping->begin(), pos); - return DocumentPayloadCursor>( - m_document_reader.read(fetch_bigram_documents(bigram_id)), - zip(m_payload_reader.read(fetch_bigram_payloads<0>(bigram_id)), - m_payload_reader.read(fetch_bigram_payloads<1>(bigram_id)))); + if (*pos == bigram) { + return tl::make_optional(std::distance(m_bigram_mapping->begin(), pos)); + } } - throw std::invalid_argument( - fmt::format("Bigram for <{}, {}> not found.", left_term, right_term)); + return tl::nullopt; + } + + [[nodiscard]] auto bigram_payloads_0(TermId left_term, TermId right_term) const + { + return bigram_id(left_term, 
right_term).map([this](auto bid) { + return m_payload_reader.read(fetch_bigram_payloads<0>(bid)); + }); + } + + [[nodiscard]] auto bigram_payloads_1(TermId left_term, TermId right_term) const + { + return bigram_id(left_term, right_term).map([this](auto bid) { + return m_payload_reader.read(fetch_bigram_payloads<1>(bid)); + }); + } + + [[nodiscard]] auto bigram_cursor(TermId left_term, TermId right_term) const + { + return bigram_id(left_term, right_term).map([this](auto bid) { + return DocumentPayloadCursor>( + m_document_reader.read(fetch_bigram_documents(bid)), + zip(m_payload_reader.read(fetch_bigram_payloads<0>(bid)), + m_payload_reader.read(fetch_bigram_payloads<1>(bid)))); + }); } /// Constructs a new document-score cursor. @@ -230,13 +255,17 @@ struct Index { TermId right_term, Scorer&& scorer) const { - return ScoringCursor( - bigram_cursor(left_term, right_term), - [scorers = - std::make_tuple(scorer.term_scorer(left_term), scorer.term_scorer(right_term))]( - auto&& docid, auto&& payload) { - return std::array{std::get<0>(scorers)(docid, std::get<0>(payload)), - std::get<1>(scorers)(docid, std::get<1>(payload))}; + return bigram_cursor(left_term, right_term) + .take() + .map([scorer = std::forward(scorer), left_term, right_term](auto cursor) { + return ScoringCursor(cursor, + [scorers = std::make_tuple(scorer.term_scorer(left_term), + scorer.term_scorer(right_term))]( + auto&& docid, auto&& payload) { + return std::array{ + std::get<0>(scorers)(docid, std::get<0>(payload)), + std::get<1>(scorers)(docid, std::get<1>(payload))}; + }); }); } @@ -255,12 +284,14 @@ struct Index { /// Constructs a new document cursor. [[nodiscard]] auto documents(TermId term) const { + assert_term_in_bounds(term); return m_document_reader.read(fetch_documents(term)); } /// Constructs a new payload cursor. 
[[nodiscard]] auto payloads(TermId term) const { + assert_term_in_bounds(term); return m_payload_reader.read(fetch_payloads(term)); } @@ -288,6 +319,13 @@ struct Index { } private: + void assert_term_in_bounds(TermId term) const + { + if (term >= num_terms()) { + std::invalid_argument( + fmt::format("Requested term ID out of bounds [0-{}): {}", num_terms(), term)); + } + } [[nodiscard]] auto fetch_documents(TermId term) const -> gsl::span { Expects(term + 1 < m_document_offsets.size()); diff --git a/include/pisa/v1/index_builder.hpp b/include/pisa/v1/index_builder.hpp index 9ee6d3ffc..373ac6635 100644 --- a/include/pisa/v1/index_builder.hpp +++ b/include/pisa/v1/index_builder.hpp @@ -11,6 +11,7 @@ #include "v1/index.hpp" #include "v1/index_metadata.hpp" #include "v1/progress_status.hpp" +#include "v1/query.hpp" namespace pisa::v1 { @@ -224,4 +225,11 @@ inline void compress_binary_collection(std::string const& input, auto verify_compressed_index(std::string const& input, std::string_view output) -> std::vector; +auto collect_unique_bigrams(std::vector const& queries, + std::function const& callback) + -> std::vector>; + +void build_bigram_index(std::string const& yml, + std::vector> const& bigrams); + } // namespace pisa::v1 diff --git a/include/pisa/v1/intersection.hpp b/include/pisa/v1/intersection.hpp new file mode 100644 index 000000000..c44a57989 --- /dev/null +++ b/include/pisa/v1/intersection.hpp @@ -0,0 +1,50 @@ +#pragma once + +#include +#include +#include +#include + +namespace pisa::v1 { + +/// Read a list of intersections. +/// +/// Each line in the format relates to one query, and each space-separated value +/// is an integer intersection representation. These numbers are converted to +/// bitsets, and each 1 at position `i` means that the `i`-th term in the query +/// is present in the intersection. +/// +/// # Example +/// +/// Let `q = a b c d e` be our query. 
The following line: +/// ``` +/// 1 2 5 16 +/// ``` +/// can be represented as bitsets: +/// ``` +/// 00001 00010 00101 10000 +/// ``` +/// which in turn represent four intersection: a, b, ac, e. +[[nodiscard]] auto read_intersections(std::string const& filename) + -> std::vector>>; +[[nodiscard]] auto read_intersections(std::istream& is) + -> std::vector>>; + +/// Converts a bitset to a vector of positions set to 1. +[[nodiscard]] auto to_vector(std::bitset<64> const& bits) -> std::vector; + +/// Returns a lambda taking a bitset and returning `true` if it has `n` set bits. +[[nodiscard]] inline auto is_n_gram(std::size_t n) +{ + return [n](std::bitset<64> const& bits) -> bool { return bits.count() == n; }; +} + +/// Returns only positions of terms in unigrams. +[[nodiscard]] auto filter_unigrams(std::vector>> const& intersections) + -> std::vector>; + +/// Returns only positions of terms in bigrams. +[[nodiscard]] auto filter_bigrams(std::vector>> const& intersections) + -> std::vector>>; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/maxscore.hpp b/include/pisa/v1/maxscore.hpp index b84365fe5..cd541749f 100644 --- a/include/pisa/v1/maxscore.hpp +++ b/include/pisa/v1/maxscore.hpp @@ -63,6 +63,10 @@ struct MaxScoreJoin { void initialize() { + if (m_cursors.empty()) { + m_current_value = sentinel(); + m_current_payload = m_init; + } std::transform(m_cursors.begin(), m_cursors.end(), m_sorted_cursors.begin(), @@ -210,6 +214,9 @@ auto join_maxscore(CursorContainer cursors, template auto maxscore(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) { + if (query.terms.empty()) { + return topk; + } using cursor_type = decltype(index.max_scored_cursor(0, scorer)); using value_type = decltype(index.max_scored_cursor(0, scorer).value()); @@ -242,6 +249,9 @@ struct MaxscoreAnalyzer { void operator()(Query const& query) { + if (query.terms.empty()) { + return; + } using cursor_type = decltype(m_index.max_scored_cursor(0, m_scorer)); using 
value_type = decltype(m_index.max_scored_cursor(0, m_scorer).value()); diff --git a/include/pisa/v1/query.hpp b/include/pisa/v1/query.hpp index 52ebb616f..a5507e39d 100644 --- a/include/pisa/v1/query.hpp +++ b/include/pisa/v1/query.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -13,13 +14,14 @@ #include "v1/cursor/for_each.hpp" #include "v1/cursor_intersection.hpp" #include "v1/cursor_union.hpp" +#include "v1/intersection.hpp" #include "v1/types.hpp" namespace pisa::v1 { struct ListSelection { - std::vector unigrams{}; - std::vector> bigrams{}; + std::vector unigrams{}; + std::vector> bigrams{}; }; struct Query { @@ -28,6 +30,12 @@ struct Query { tl::optional threshold{}; tl::optional id{}; int k{}; + + void add_selections(gsl::span const> selections); + void remove_duplicates(); + + private: + auto resolve_term(std::size_t pos) -> TermId; }; template diff --git a/include/pisa/v1/union_lookup.hpp b/include/pisa/v1/union_lookup.hpp index e7e99a86c..986c76db9 100644 --- a/include/pisa/v1/union_lookup.hpp +++ b/include/pisa/v1/union_lookup.hpp @@ -2,40 +2,72 @@ #include +#include +#include #include +#include +#include "v1/algorithm.hpp" #include "v1/query.hpp" namespace pisa::v1 { -template -void partition_by_essential(gsl::span cursors, gsl::span essential_indices) -{ - if (essential_indices.empty()) { - return; - } - std::sort(essential_indices.begin(), essential_indices.end()); - if (essential_indices[essential_indices.size() - 1] >= cursors.size()) { - throw std::logic_error("Essential index too large"); - } - auto left = 0; - auto right = cursors.size() - 1; - auto eidx = 0; - while (left < right && eidx < essential_indices.size()) { - if (left < essential_indices[eidx]) { - left += 1; - } else { - std::swap(cursors[left], cursors[right]); - right -= 1; - eidx += 1; - } +namespace detail { + template + auto unigram_union_lookup(Cursors cursors, + UpperBounds upper_bounds, + std::size_t non_essential_count, + topk_queue topk, + Analyzer* 
analyzer = nullptr) + { + using payload_type = decltype(std::declval().payload()); + + auto merged_essential = + v1::union_merge(gsl::make_span(cursors).subspan(non_essential_count), + 0.0F, + [&](auto acc, auto& cursor, auto /*term_idx*/) { + if constexpr (not std::is_void_v) { + analyzer->posting(); + } + return acc + cursor.payload(); + }); + + auto lookup_cursors = gsl::make_span(cursors).first(non_essential_count); + v1::for_each(merged_essential, [&](auto& cursor) { + if constexpr (not std::is_void_v) { + analyzer->document(); + } + auto docid = cursor.value(); + auto score = cursor.payload(); + for (auto lookup_cursor_idx = non_essential_count - 1; lookup_cursor_idx + 1 > 0; + lookup_cursor_idx -= 1) { + if (not topk.would_enter(score + upper_bounds[lookup_cursor_idx])) { + return; + } + cursors[lookup_cursor_idx].advance_to_geq(docid); + if constexpr (not std::is_void_v) { + analyzer->lookup(); + } + if (PISA_UNLIKELY(cursors[lookup_cursor_idx].value() == docid)) { + score += cursors[lookup_cursor_idx].payload(); + } + } + if constexpr (not std::is_void_v) { + analyzer->insert(); + } + topk.insert(score, docid); + }); + return topk; } -} +} // namespace detail template auto unigram_union_lookup( Query query, Index const& index, topk_queue topk, Scorer&& scorer, Analyzer* analyzer = nullptr) { + if (query.terms.empty()) { + return topk; + } if (not query.threshold) { throw std::invalid_argument("Must provide threshold to the query"); } @@ -51,54 +83,35 @@ auto unigram_union_lookup( using cursor_type = decltype(index.max_scored_cursor(0, scorer)); using payload_type = decltype(std::declval().payload()); - auto cursors = index.max_scored_cursors(gsl::make_span(query.terms), scorer); - partition_by_essential(gsl::make_span(cursors), gsl::make_span(query.list_selection->unigrams)); - auto non_essential_count = cursors.size() - query.list_selection->unigrams.size(); - std::sort(cursors.begin(), - std::next(cursors.begin(), non_essential_count), - [](auto&& lhs, 
auto&& rhs) { return lhs.max_score() < rhs.max_score(); }); + ranges::sort(query.terms); + ranges::sort(query.list_selection->unigrams); - std::vector upper_bounds(non_essential_count); - upper_bounds[0] = cursors[0].max_score(); - for (size_t idx = 1; idx < non_essential_count; idx += 1) { - upper_bounds[idx] = upper_bounds[idx - 1] + cursors[idx].max_score(); - } + auto non_essential_terms = + ranges::views::set_difference(query.terms, query.list_selection->unigrams) + | ranges::to_vector; - auto merged_essential = v1::union_merge(gsl::make_span(cursors).subspan(non_essential_count), - 0.0F, - [&](auto acc, auto& cursor, auto /*term_idx*/) { - if constexpr (not std::is_void_v) { - analyzer->posting(); - } - return acc + cursor.payload(); - }); - - auto lookup_cursors = gsl::make_span(cursors).first(non_essential_count); - v1::for_each(merged_essential, [&](auto& cursor) { - if constexpr (not std::is_void_v) { - analyzer->document(); - } - auto docid = cursor.value(); - auto score = cursor.payload(); - for (auto lookup_cursor_idx = non_essential_count - 1; lookup_cursor_idx + 1 > 0; - lookup_cursor_idx -= 1) { - if (not topk.would_enter(score + upper_bounds[lookup_cursor_idx])) { - return; - } - cursors[lookup_cursor_idx].advance_to_geq(docid); - if constexpr (not std::is_void_v) { - analyzer->lookup(); - } - if (PISA_UNLIKELY(cursors[lookup_cursor_idx].value() == docid)) { - score += cursors[lookup_cursor_idx].payload(); - } - } - if constexpr (not std::is_void_v) { - analyzer->insert(); - } - topk.insert(score, docid); + std::vector cursors; + for (auto non_essential_term : non_essential_terms) { + cursors.push_back(index.max_scored_cursor(non_essential_term, scorer)); + } + auto non_essential_count = cursors.size(); + std::sort(cursors.begin(), cursors.end(), [](auto&& lhs, auto&& rhs) { + return lhs.max_score() < rhs.max_score(); }); - return topk; + for (auto essential_term : query.list_selection->unigrams) { + 
cursors.push_back(index.max_scored_cursor(essential_term, scorer)); + } + + std::vector upper_bounds(cursors.size()); + upper_bounds[0] = cursors[0].max_score(); + for (size_t i = 1; i < cursors.size(); ++i) { + upper_bounds[i] = upper_bounds[i - 1] + cursors[i].max_score(); + } + return detail::unigram_union_lookup(std::move(cursors), + std::move(upper_bounds), + non_essential_count, + std::move(topk), + analyzer); } template @@ -108,6 +121,9 @@ auto maxscore_union_lookup(Query const& query, Scorer&& scorer, Analyzer* analyzer = nullptr) { + if (query.terms.empty()) { + return topk; + } if (not query.threshold) { throw std::invalid_argument("Must provide threshold to the query"); } @@ -132,58 +148,16 @@ auto maxscore_union_lookup(Query const& query, && upper_bounds[non_essential_count] < *query.threshold) { non_essential_count += 1; } - - std::vector unigrams(cursors.size() - non_essential_count); - std::iota(unigrams.begin(), unigrams.end(), non_essential_count); - Query query_with_selections = query; - query_with_selections.list_selection = - tl::make_optional(ListSelection{.unigrams = {}, .bigrams = {}}); - return unigram_union_lookup(std::move(query_with_selections), - index, - std::move(topk), - std::forward(scorer), - analyzer); - - // auto merged_essential = v1::union_merge(gsl::make_span(cursors).subspan(non_essential_count), - // 0.0F, - // [&](auto acc, auto& cursor, auto /*term_idx*/) { - // if constexpr (not std::is_void_v) { - // analyzer->posting(); - // } - // return acc + cursor.payload(); - // }); - - // auto lookup_cursors = gsl::make_span(cursors).first(non_essential_count); - // v1::for_each(merged_essential, [&](auto& cursor) { - // if constexpr (not std::is_void_v) { - // analyzer->document(); - // } - // auto docid = cursor.value(); - // auto score = cursor.payload(); - // for (auto lookup_cursor_idx = non_essential_count - 1; lookup_cursor_idx + 1 > 0; - // lookup_cursor_idx -= 1) { - // if (not topk.would_enter(score + 
upper_bounds[lookup_cursor_idx])) { - // return; - // } - // cursors[lookup_cursor_idx].advance_to_geq(docid); - // if constexpr (not std::is_void_v) { - // analyzer->lookup(); - // } - // if (PISA_UNLIKELY(cursors[lookup_cursor_idx].value() == docid)) { - // score += cursors[lookup_cursor_idx].payload(); - // } - // } - // if constexpr (not std::is_void_v) { - // analyzer->insert(); - // } - // topk.insert(score, docid); - //}); - // return topk; + return detail::unigram_union_lookup(std::move(cursors), + std::move(upper_bounds), + non_essential_count, + std::move(topk), + analyzer); } template -struct MaxscoreUnionLookupAnalyzer { - MaxscoreUnionLookupAnalyzer(Index const& index, Scorer scorer) +struct BaseUnionLookupAnalyzer { + BaseUnionLookupAnalyzer(Index const& index, Scorer scorer) : m_index(index), m_scorer(std::move(scorer)) { std::cout << fmt::format("documents\tpostings\tinserts\tlookups\n"); @@ -197,14 +171,18 @@ struct MaxscoreUnionLookupAnalyzer { m_current_inserts = 0; } + virtual void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) = 0; + void operator()(Query const& query) { + if (query.terms.empty()) { + return; + } using cursor_type = decltype(m_index.max_scored_cursor(0, m_scorer)); using value_type = decltype(m_index.max_scored_cursor(0, m_scorer).value()); reset_current(); - topk_queue topk(query.k); - maxscore_union_lookup(query, m_index, topk, m_scorer, this); + run(query, m_index, m_scorer, topk_queue(query.k)); std::cout << fmt::format("{}\t{}\t{}\t{}\n", m_current_documents, m_current_postings, @@ -248,6 +226,42 @@ struct MaxscoreUnionLookupAnalyzer { Scorer m_scorer; }; +template +struct MaxscoreUnionLookupAnalyzer : public BaseUnionLookupAnalyzer { + MaxscoreUnionLookupAnalyzer(Index const& index, Scorer scorer) + : BaseUnionLookupAnalyzer(index, std::move(scorer)) + { + } + void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override + { + maxscore_union_lookup(query, index, 
std::move(topk), scorer, this); + } +}; + +template +struct UnigramUnionLookupAnalyzer : public BaseUnionLookupAnalyzer { + UnigramUnionLookupAnalyzer(Index const& index, Scorer scorer) + : BaseUnionLookupAnalyzer(index, std::move(scorer)) + { + } + void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override + { + unigram_union_lookup(query, index, std::move(topk), scorer, this); + } +}; + +template +struct UnionLookupAnalyzer : public BaseUnionLookupAnalyzer { + UnionLookupAnalyzer(Index const& index, Scorer scorer) + : BaseUnionLookupAnalyzer(index, std::move(scorer)) + { + } + void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override + { + union_lookup(query, index, std::move(topk), scorer, this); + } +}; + /// Performs a "union-lookup" query (name pending). /// /// \param query Full query, as received, possibly with duplicates. @@ -261,63 +275,114 @@ struct MaxscoreUnionLookupAnalyzer { /// \param essential_bigrams Similar to the above, but represents intersections between two /// posting lists. These must exist in the index, or else this /// algorithm will fail. 
-template -auto union_lookup(Query const& query, - Index const& index, - topk_queue topk, - Scorer&& scorer, - std::vector essential_unigrams, - std::vector> essential_bigrams) +template +auto union_lookup( + Query query, Index const& index, topk_queue topk, Scorer&& scorer, Analyzer* analyzer = nullptr) { + if (query.terms.empty()) { + return topk; + } + if (query.terms.size() > 8) { + throw std::invalid_argument( + "Generic version of union-Lookup supported only for queries of length <= 8"); + } + if (not query.threshold) { + throw std::invalid_argument("Must provide threshold to the query"); + } + if (not query.list_selection) { + throw std::invalid_argument("Must provide essential list selection"); + } + + using bigram_cursor_type = std::decay_t; + + auto& essential_unigrams = query.list_selection->unigrams; + auto& essential_bigrams = query.list_selection->bigrams; + ranges::sort(essential_unigrams); ranges::actions::unique(essential_unigrams); ranges::sort(essential_bigrams); ranges::actions::unique(essential_bigrams); + ranges::sort(query.terms); - topk.set_threshold(*query.threshold); + auto non_essential_terms = + ranges::views::set_difference(query.terms, essential_unigrams) | ranges::to_vector; - std::vector is_essential(query.terms.size(), false); - for (auto idx : essential_unigrams) { - is_essential[idx] = true; - } + topk.set_threshold(*query.threshold); - std::vector initial_payload(query.terms.size(), 0.0); + std::array initial_payload{ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; //, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; std::vector essential_unigram_cursors; std::transform(essential_unigrams.begin(), essential_unigrams.end(), std::back_inserter(essential_unigram_cursors), - [&](auto idx) { return index.scored_cursor(query.terms[idx], scorer); }); + [&](auto term) { return index.scored_cursor(term, scorer); }); + + std::vector essential_unigram_positions; + [&]() { + auto pos = query.terms.begin(); + for (auto term : essential_unigrams) { + pos = 
std::find(pos, query.terms.end(), term); + assert(pos != query.terms.end()); + essential_unigram_positions.push_back(std::distance(query.terms.begin(), pos)); + } + }(); auto merged_unigrams = v1::union_merge( essential_unigram_cursors, initial_payload, [&](auto& acc, auto& cursor, auto term_idx) { - acc[essential_unigrams[term_idx]] = cursor.payload(); + if constexpr (not std::is_void_v) { + analyzer->posting(); + } + acc[essential_unigram_positions[term_idx]] = cursor.payload(); return acc; }); - std::vector essential_bigram_cursors; - std::transform(essential_bigrams.begin(), - essential_bigrams.end(), - std::back_inserter(essential_bigram_cursors), - [&](auto intersection) { - return index.scored_bigram_cursor(query.terms[intersection.first], - query.terms[intersection.second], - scorer); - }); - auto merged_bigrams = - v1::union_merge(std::move(essential_bigram_cursors), - initial_payload, - [&](auto& acc, auto& cursor, auto term_idx) { - auto payload = cursor.payload(); - acc[essential_bigrams[term_idx].first] = std::get<0>(payload); - acc[essential_bigrams[term_idx].second] = std::get<1>(payload); - return acc; - }); + std::vector essential_bigram_cursors; + std::vector> essential_bigram_positions; + for (auto [left, right] : essential_bigrams) { + if (left > right) { + std::swap(left, right); + } + auto cursor = index.scored_bigram_cursor(left, right, scorer); + if (not cursor) { + throw std::runtime_error(fmt::format("Bigram not found: <{}, {}>", left, right)); + } + essential_bigram_cursors.push_back(cursor.take().value()); + std::pair bigram_positions; + if (auto pos = std::lower_bound(query.terms.begin(), query.terms.end(), left); + pos != query.terms.end()) { + bigram_positions.first = std::distance(query.terms.begin(), pos); + } else { + throw std::logic_error("Term from selected intersection not part of query"); + } + if (auto pos = std::lower_bound(query.terms.begin(), query.terms.end(), right); + pos != query.terms.end()) { + 
bigram_positions.second = std::distance(query.terms.begin(), pos); + } else { + throw std::logic_error("Term from selected intersection not part of query"); + } + essential_bigram_positions.push_back(bigram_positions); + } + + auto merged_bigrams = v1::union_merge( + std::move(essential_bigram_cursors), + initial_payload, + [&](auto& acc, auto& cursor, auto bigram_idx) { + if constexpr (not std::is_void_v) { + analyzer->posting(); + } + auto payload = cursor.payload(); + acc[essential_bigram_positions[bigram_idx].first] = std::get<0>(payload); + acc[essential_bigram_positions[bigram_idx].second] = std::get<1>(payload); + return acc; + }); auto accumulate = [&](auto& acc, auto& cursor, auto /* union_idx */) { auto payload = cursor.payload(); for (auto idx = 0; idx < acc.size(); idx += 1) { - acc[idx] = payload[idx]; + if (acc[idx] == 0.0F) { + acc[idx] = payload[idx]; + } } return acc; }; @@ -326,38 +391,55 @@ auto union_lookup(Query const& query, std::make_tuple(std::move(merged_unigrams), std::move(merged_bigrams)), std::make_tuple(accumulate, accumulate)); - std::vector lookup_cursors; - for (auto idx = 0; idx < query.terms.size(); idx += 1) { - if (not is_essential[idx]) { - lookup_cursors.push_back(index.max_scored_cursor(query.terms[idx], scorer)); + auto lookup_cursors = [&]() { + std::vector> + lookup_cursors; + auto pos = query.terms.begin(); + for (auto non_essential_term : non_essential_terms) { + pos = std::find(pos, query.terms.end(), non_essential_term); + assert(pos != query.terms.end()); + auto idx = std::distance(query.terms.begin(), pos); + lookup_cursors.emplace_back(idx, index.max_scored_cursor(non_essential_term, scorer)); } - } + return lookup_cursors; + }(); std::sort(lookup_cursors.begin(), lookup_cursors.end(), [](auto&& lhs, auto&& rhs) { - return lhs.max_score() > rhs.max_score(); + return lhs.second.max_score() > rhs.second.max_score(); }); auto lookup_cursors_upper_bound = std::accumulate( lookup_cursors.begin(), lookup_cursors.end(), 
0.0F, [](auto acc, auto&& cursor) { - return acc + cursor.max_score(); + return acc + cursor.second.max_score(); }); v1::for_each(merged, [&](auto& cursor) { + if constexpr (not std::is_void_v) { + analyzer->document(); + } auto docid = cursor.value(); auto scores = cursor.payload(); auto score = std::accumulate(scores.begin(), scores.end(), 0.0F, std::plus{}); auto upper_bound = score + lookup_cursors_upper_bound; - for (auto lookup_cursor : lookup_cursors) { + for (auto& [idx, lookup_cursor] : lookup_cursors) { if (not topk.would_enter(upper_bound)) { return; } - lookup_cursor.advance_to_geq(docid); - if (PISA_UNLIKELY(lookup_cursor.value() == docid)) { - auto partial_score = lookup_cursor.payload(); - score += partial_score; - upper_bound += partial_score; + if (scores[idx] == 0) { + lookup_cursor.advance_to_geq(docid); + if constexpr (not std::is_void_v) { + analyzer->lookup(); + } + if (PISA_UNLIKELY(lookup_cursor.value() == docid)) { + auto partial_score = lookup_cursor.payload(); + score += partial_score; + upper_bound += partial_score; + } } upper_bound -= lookup_cursor.max_score(); } topk.insert(score, docid); + if constexpr (not std::is_void_v) { + analyzer->insert(); + } }); return topk; } diff --git a/include/pisa/v1/zip_cursor.hpp b/include/pisa/v1/zip_cursor.hpp index fdc66afaa..2fe4e5241 100644 --- a/include/pisa/v1/zip_cursor.hpp +++ b/include/pisa/v1/zip_cursor.hpp @@ -10,7 +10,11 @@ template struct ZipCursor { using Value = std::tuple())...>; - explicit constexpr ZipCursor(Cursors... cursors) : m_cursors(std::move(cursors)...) {} + explicit constexpr ZipCursor(Cursors... cursors) : m_cursors(std::move(cursors)...) 
+ { + static_assert(std::tuple_size_v> == 2, + "Zip of more than two lists is currently not supported"); + } [[nodiscard]] constexpr auto operator*() const -> Value { return value(); } [[nodiscard]] constexpr auto value() const noexcept -> Value @@ -20,15 +24,24 @@ struct ZipCursor { } constexpr void advance() { - auto advance_all = [](auto... cursors) { (cursors.advance(), ...); }; - std::apply(advance_all, m_cursors); + // TODO: Why generic doesn't work? + // auto advance_all = [](auto... cursors) { (cursors.advance(), ...); }; + // std::apply(advance_all, m_cursors); + std::get<0>(m_cursors).advance(); + std::get<1>(m_cursors).advance(); } constexpr void advance_to_position(std::size_t pos) { - auto advance_all = [pos](auto... cursors) { (cursors.advance_to_position(pos), ...); }; - std::apply(advance_all, m_cursors); + // TODO: Why generic doesn't work? + // auto advance_all = [pos](auto... cursors) { (cursors.advance_to_position(pos), ...); }; + // std::apply(advance_all, m_cursors); + std::get<0>(m_cursors).advance(); + std::get<1>(m_cursors).advance(); + } + [[nodiscard]] constexpr auto empty() const noexcept -> bool + { + return std::get<0>(m_cursors).empty() || std::get<1>(m_cursors).empty(); } - //[[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_key_cursor.empty(); } [[nodiscard]] constexpr auto position() const noexcept -> std::size_t { return std::get<0>(m_cursors).position(); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5ea8e7e49..81ab1d268 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,93 +1,93 @@ -add_executable(create_freq_index create_freq_index.cpp) -target_link_libraries(create_freq_index - pisa - CLI11 -) - -add_executable(optimal_hybrid_index optimal_hybrid_index.cpp) -target_include_directories(optimal_hybrid_index PRIVATE ${STXXL_INCLUDE_DIRS}) -target_link_libraries(optimal_hybrid_index - ${STXXL_LIBRARIES} - pisa -) -set_target_properties(optimal_hybrid_index PROPERTIES - CXX_STANDARD 14 
-) - -add_executable(create_wand_data create_wand_data.cpp) -target_link_libraries(create_wand_data - pisa - CLI11 -) - -add_executable(queries queries.cpp) -target_link_libraries(queries - pisa - CLI11 -) - -add_executable(evaluate_queries evaluate_queries.cpp) -target_link_libraries(evaluate_queries - pisa - CLI11 -) - -add_executable(thresholds thresholds.cpp) -target_link_libraries(thresholds - pisa - CLI11 -) - -add_executable(profile_queries profile_queries.cpp) -target_link_libraries(profile_queries - pisa -) - -add_executable(profile_decoding profile_decoding.cpp) -target_link_libraries(profile_decoding - pisa -) - -add_executable(shuffle_docids shuffle_docids.cpp) -target_link_libraries(shuffle_docids - pisa -) - -add_executable(recursive_graph_bisection recursive_graph_bisection.cpp) -target_link_libraries(recursive_graph_bisection - pisa - CLI11 -) - -add_executable(evaluate_collection_ordering evaluate_collection_ordering.cpp) -target_link_libraries(evaluate_collection_ordering - pisa - ) - -add_executable(parse_collection parse_collection.cpp) -target_link_libraries(parse_collection - pisa - CLI11 - wapopp -) - -add_executable(invert invert.cpp) -target_link_libraries(invert - CLI11 - pisa -) - -add_executable(read_collection read_collection.cpp) -target_link_libraries(read_collection - pisa - CLI11 -) - -add_executable(partition_fwd_index partition_fwd_index.cpp) -target_link_libraries(partition_fwd_index - pisa - CLI11 -) +#add_executable(create_freq_index create_freq_index.cpp) +#target_link_libraries(create_freq_index +# pisa +# CLI11 +#) +# +#add_executable(optimal_hybrid_index optimal_hybrid_index.cpp) +#target_include_directories(optimal_hybrid_index PRIVATE ${STXXL_INCLUDE_DIRS}) +#target_link_libraries(optimal_hybrid_index +# ${STXXL_LIBRARIES} +# pisa +#) +#set_target_properties(optimal_hybrid_index PROPERTIES +# CXX_STANDARD 14 +#) +# +#add_executable(create_wand_data create_wand_data.cpp) +#target_link_libraries(create_wand_data +# pisa +# 
CLI11 +#) +# +#add_executable(queries queries.cpp) +#target_link_libraries(queries +# pisa +# CLI11 +#) +# +#add_executable(evaluate_queries evaluate_queries.cpp) +#target_link_libraries(evaluate_queries +# pisa +# CLI11 +#) +# +#add_executable(thresholds thresholds.cpp) +#target_link_libraries(thresholds +# pisa +# CLI11 +#) +# +#add_executable(profile_queries profile_queries.cpp) +#target_link_libraries(profile_queries +# pisa +#) +# +#add_executable(profile_decoding profile_decoding.cpp) +#target_link_libraries(profile_decoding +# pisa +#) +# +#add_executable(shuffle_docids shuffle_docids.cpp) +#target_link_libraries(shuffle_docids +# pisa +#) +# +#add_executable(recursive_graph_bisection recursive_graph_bisection.cpp) +#target_link_libraries(recursive_graph_bisection +# pisa +# CLI11 +#) +# +#add_executable(evaluate_collection_ordering evaluate_collection_ordering.cpp) +#target_link_libraries(evaluate_collection_ordering +# pisa +# ) +# +#add_executable(parse_collection parse_collection.cpp) +#target_link_libraries(parse_collection +# pisa +# CLI11 +# wapopp +#) +# +#add_executable(invert invert.cpp) +#target_link_libraries(invert +# CLI11 +# pisa +#) +# +#add_executable(read_collection read_collection.cpp) +#target_link_libraries(read_collection +# pisa +# CLI11 +#) +# +#add_executable(partition_fwd_index partition_fwd_index.cpp) +#target_link_libraries(partition_fwd_index +# pisa +# CLI11 +#) add_executable(compute_intersection compute_intersection.cpp) target_link_libraries(compute_intersection @@ -95,20 +95,20 @@ target_link_libraries(compute_intersection CLI11 ) -add_executable(lexicon lexicon.cpp) -target_link_libraries(lexicon - pisa - CLI11 -) - -add_executable(extract_topics extract_topics.cpp) -target_link_libraries(extract_topics - pisa - CLI11 -) - -add_executable(sample_inverted_index sample_inverted_index.cpp) -target_link_libraries(sample_inverted_index - pisa - CLI11 -) +#add_executable(lexicon lexicon.cpp) +#target_link_libraries(lexicon +# 
pisa +# CLI11 +#) +# +#add_executable(extract_topics extract_topics.cpp) +#target_link_libraries(extract_topics +# pisa +# CLI11 +#) +# +#add_executable(sample_inverted_index sample_inverted_index.cpp) +#target_link_libraries(sample_inverted_index +# pisa +# CLI11 +#) diff --git a/src/compute_intersection.cpp b/src/compute_intersection.cpp index 2f7a23501..d137184bc 100644 --- a/src/compute_intersection.cpp +++ b/src/compute_intersection.cpp @@ -18,10 +18,10 @@ using pisa::intersection::IntersectionType; using pisa::intersection::Mask; template -void intersect(std::string const &index_filename, - std::optional const &wand_data_filename, - std::vector const &queries, - std::string const &type, +void intersect(std::string const& index_filename, + std::optional const& wand_data_filename, + std::vector const& queries, + std::string const& type, IntersectionType intersection_type, std::optional max_term_count = std::nullopt) { @@ -44,16 +44,18 @@ void intersect(std::string const &index_filename, std::size_t qid = 0u; - auto print_intersection = [&](auto const &query, auto const &mask) { + auto print_intersection = [&](auto const& query, auto const& mask) { auto intersection = Intersection::compute(index, wdata, query, mask); - std::cout << fmt::format("{}\t{}\t{}\t{}\n", - query.id ? *query.id : std::to_string(qid), - mask.to_ulong(), - intersection.length, - intersection.max_score); + if (intersection.length > 0) { + std::cout << fmt::format("{}\t{}\t{}\t{}\n", + query.id ? 
*query.id : std::to_string(qid), + mask.to_ulong(), + intersection.length, + intersection.max_score); + } }; - for (auto const &query : queries) { + for (auto const& query : queries) { if (intersection_type == IntersectionType::Combinations) { for_all_subsets(query, max_term_count, print_intersection); } else { @@ -70,7 +72,7 @@ void intersect(std::string const &index_filename, using wand_raw_index = wand_data; using wand_uniform_index = wand_data; -int main(int argc, const char **argv) +int main(int argc, const char** argv) { std::string type; std::string index_filename; @@ -92,9 +94,9 @@ int main(int argc, const char **argv) app.add_option("-w,--wand", wand_data_filename, "Wand data filename"); app.add_option("-q,--query", query_filename, "Queries filename"); app.add_flag("--compressed-wand", compressed, "Compressed wand input file"); - auto *terms_opt = app.add_option("--terms", terms_file, "Term lexicon"); + auto* terms_opt = app.add_option("--terms", terms_file, "Term lexicon"); app.add_option("--stemmer", stemmer, "Stemmer type")->needs(terms_opt); - auto *combinations_flag = app.add_flag( + auto* combinations_flag = app.add_flag( "--combinations", combinations, "Compute intersections for combinations of terms in query"); app.add_option("--max-term-count,--mtc", max_term_count, diff --git a/src/v1/index_builder.cpp b/src/v1/index_builder.cpp index 3bb1d051c..757062560 100644 --- a/src/v1/index_builder.cpp +++ b/src/v1/index_builder.cpp @@ -1,10 +1,34 @@ #include "v1/index_builder.hpp" #include "codec/simdbp.hpp" #include "v1/blocked_cursor.hpp" +#include "v1/query.hpp" #include "v1/raw_cursor.hpp" namespace pisa::v1 { +auto collect_unique_bigrams(std::vector const& queries, + std::function const& callback) + -> std::vector> +{ + std::vector> bigrams; + auto idx = 0; + for (auto query : queries) { + if (query.terms.empty()) { + continue; + } + callback(); + std::sort(query.terms.begin(), query.terms.end()); + for (auto left = 0; left < query.terms.size() - 1; 
left += 1) { + for (auto right = left + 1; right < query.terms.size(); right += 1) { + bigrams.emplace_back(query.terms[left], query.terms[right]); + } + } + } + std::sort(bigrams.begin(), bigrams.end()); + bigrams.erase(std::unique(bigrams.begin(), bigrams.end()), bigrams.end()); + return bigrams; +} + auto verify_compressed_index(std::string const& input, std::string_view output) -> std::vector { @@ -54,4 +78,86 @@ auto verify_compressed_index(std::string const& input, std::string_view output) return errors; } +void build_bigram_index(std::string const& yml, + std::vector> const& bigrams) +{ + Expects(not bigrams.empty()); + auto index_basename = yml.substr(0, yml.size() - 4); + auto meta = IndexMetadata::from_file(yml); + auto run = index_runner(meta, + RawReader{}, + BlockedReader<::pisa::simdbp_block, true>{}, + BlockedReader<::pisa::simdbp_block, false>{}); + + std::vector> pair_mapping; + auto documents_file = fmt::format("{}.bigram_documents", index_basename); + auto frequencies_file_0 = fmt::format("{}.bigram_frequencies_0", index_basename); + auto frequencies_file_1 = fmt::format("{}.bigram_frequencies_1", index_basename); + auto document_offsets_file = fmt::format("{}.bigram_document_offsets", index_basename); + auto frequency_offsets_file_0 = fmt::format("{}.bigram_frequency_offsets_0", index_basename); + auto frequency_offsets_file_1 = fmt::format("{}.bigram_frequency_offsets_1", index_basename); + std::ofstream document_out(documents_file); + std::ofstream frequency_out_0(frequencies_file_0); + std::ofstream frequency_out_1(frequencies_file_1); + + run([&](auto&& index) { + ProgressStatus status(bigrams.size(), + DefaultProgress("Building bigram index"), + std::chrono::milliseconds(100)); + using index_type = std::decay_t; + using document_writer_type = + typename CursorTraits::Writer; + using frequency_writer_type = + typename CursorTraits::Writer; + + PostingBuilder document_builder(document_writer_type{}); + PostingBuilder 
frequency_builder_0(frequency_writer_type{}); + PostingBuilder frequency_builder_1(frequency_writer_type{}); + + document_builder.write_header(document_out); + frequency_builder_0.write_header(frequency_out_0); + frequency_builder_1.write_header(frequency_out_1); + + for (auto [left_term, right_term] : bigrams) { + auto intersection = intersect({index.cursor(left_term), index.cursor(right_term)}, + std::array{0, 0}, + [](auto& payload, auto& cursor, auto list_idx) { + payload[list_idx] = cursor.payload(); + return payload; + }); + if (intersection.empty()) { + status += 1; + continue; + } + pair_mapping.push_back({left_term, right_term}); + for_each(intersection, [&](auto& cursor) { + document_builder.accumulate(*cursor); + auto payload = cursor.payload(); + frequency_builder_0.accumulate(std::get<0>(payload)); + frequency_builder_1.accumulate(std::get<1>(payload)); + }); + document_builder.flush_segment(document_out); + frequency_builder_0.flush_segment(frequency_out_0); + frequency_builder_1.flush_segment(frequency_out_1); + status += 1; + } + std::cerr << "Writing offsets..."; + write_span(gsl::make_span(document_builder.offsets()), document_offsets_file); + write_span(gsl::make_span(frequency_builder_0.offsets()), frequency_offsets_file_0); + write_span(gsl::make_span(frequency_builder_1.offsets()), frequency_offsets_file_1); + std::cerr << " Done.\n"; + }); + std::cerr << "Writing metadata..."; + meta.bigrams = BigramMetadata{ + .documents = {.postings = documents_file, .offsets = document_offsets_file}, + .frequencies = {{.postings = frequencies_file_0, .offsets = frequency_offsets_file_0}, + {.postings = frequencies_file_1, .offsets = frequency_offsets_file_1}}, + .mapping = fmt::format("{}.bigram_mapping", index_basename), + .count = pair_mapping.size()}; + meta.write(yml); + std::cerr << " Done.\nWriting bigram mapping..."; + write_span(gsl::make_span(pair_mapping), meta.bigrams->mapping); + std::cerr << " Done.\n"; +} + } // namespace pisa::v1 diff --git 
a/src/v1/intersection.cpp b/src/v1/intersection.cpp new file mode 100644 index 000000000..d3c3b5d38 --- /dev/null +++ b/src/v1/intersection.cpp @@ -0,0 +1,77 @@ +#include +#include + +#include +#include +#include + +#include "io.hpp" +#include "v1/intersection.hpp" + +namespace pisa::v1 { + +auto read_intersections(std::string const& filename) -> std::vector>> +{ + std::ifstream is(filename); + return read_intersections(is); +} + +auto read_intersections(std::istream& is) -> std::vector>> +{ + std::vector>> intersections; + ::pisa::io::for_each_line(is, [&](auto const& query_line) { + intersections.emplace_back(); + std::istringstream iss(query_line); + std::transform( + std::istream_iterator(iss), + std::istream_iterator(), + std::back_inserter(intersections.back()), + [&](auto const& n) { + auto bits = std::bitset<64>(std::stoul(n)); + if (bits.count() > 2) { + spdlog::error("Intersections of more than 2 terms not supported yet!"); + std::exit(1); + } + return bits; + }); + }); + return intersections; +} + +[[nodiscard]] auto to_vector(std::bitset<64> const& bits) -> std::vector +{ + std::vector vec; + for (auto idx = 0; idx < bits.size(); idx += 1) { + if (bits.test(idx)) { + vec.push_back(idx); + } + } + return vec; +} + +[[nodiscard]] auto filter_unigrams(std::vector>> const& intersections) + -> std::vector> +{ + return intersections | ranges::views::transform([&](auto&& query_intersections) { + return query_intersections | ranges::views::filter(is_n_gram(1)) + | ranges::views::transform([&](auto bits) { return to_vector(bits)[0]; }) + | ranges::to_vector; + }) + | ranges::to_vector; +} + +[[nodiscard]] auto filter_bigrams(std::vector>> const& intersections) + -> std::vector>> +{ + return intersections | ranges::views::transform([&](auto&& query_intersections) { + return query_intersections | ranges::views::filter(is_n_gram(2)) + | ranges::views::transform([&](auto bits) { + auto vec = to_vector(bits); + return std::make_pair(vec[0], vec[1]); + }) + | 
ranges::to_vector; + }) + | ranges::to_vector; +} + +} // namespace pisa::v1 diff --git a/src/v1/query.cpp b/src/v1/query.cpp index a9c68f2c4..ec85b0fa0 100644 --- a/src/v1/query.cpp +++ b/src/v1/query.cpp @@ -1,3 +1,5 @@ +#include + #include "v1/query.hpp" namespace pisa::v1 { @@ -10,4 +12,41 @@ namespace pisa::v1 { return terms; } +void Query::add_selections(gsl::span const> selections) +{ + list_selection = ListSelection{}; + for (auto intersection : selections) { + if (intersection.count() > 2) { + throw std::invalid_argument("Intersections of more than 2 terms not supported yet!"); + } + auto positions = to_vector(intersection); + if (positions.size() == 1) { + list_selection->unigrams.push_back(resolve_term(positions.front())); + } else { + list_selection->bigrams.emplace_back(resolve_term(positions[0]), + resolve_term(positions[1])); + } + } +} + +void Query::remove_duplicates() +{ + ranges::sort(terms); + ranges::actions::unique(terms); + if (list_selection) { + ranges::sort(list_selection->unigrams); + ranges::actions::unique(list_selection->unigrams); + ranges::sort(list_selection->bigrams); + ranges::actions::unique(list_selection->bigrams); + } +} + +auto Query::resolve_term(std::size_t pos) -> TermId +{ + if (pos >= terms.size()) { + throw std::out_of_range("Invalid intersections: term position out of bounds"); + } + return terms[pos]; +} + } // namespace pisa::v1 diff --git a/test/test_data/top10_selections b/test/test_data/top10_selections new file mode 100644 index 000000000..2c9284012 --- /dev/null +++ b/test/test_data/top10_selections @@ -0,0 +1,500 @@ +1 +1 +10 18 24 +1 2 4 +1 +1 +2 +16 32 192 256 +1 6 +2 +1 +1 2 +3 +1 6 +4 8 +1 +1 2 +1 +2 +4 +1 +1 +1 +1 +3 +1 2 +9 +2 +1 2 +1 12 +3 9 12 +5 9 12 33 40 +6 +1 6 +2 +1 +1 2 +1 +2 +2 9 40 72 96 136 160 192 +12 +3 5 6 +1 +1 2 +3 5 6 +2 5 +1 2 4 +1 2 4 +1 8 +1 6 +1 2 +2 +1 +1 10 66 72 +1 +4 +1 6 +1 +3 5 6 17 18 20 +2 +5 8 17 20 +1 2 +3 5 16 64 129 132 640 +1 2 +1 2 +1 2 4 +9 16 33 65 72 96 129 136 160 
192 +5 17 20 32 +2 4 +5 6 9 +3 5 6 +5 9 12 +1 2 +1 2 +1 2 +1 +9 64 +2 5 +1 2 +8 32 +1 2 +2 4 +2 +1 4 +1 2 +1 32 64 +1 +1 4 +4 9 17 18 24 +8 +5 9 12 17 20 24 +1 2 +3 4 9 10 18 +1 10 34 40 +2 4 +10 +1 4 +1 6 +1 +3 6 +2 +1 2 +12 32 68 72 +2 4 +4 +1 +1 2 +36 64 +2 +1 8 +1 2 +3 +1 6 +1 2 +1 2 +5 +3 9 10 17 18 24 129 130 136 144 +4 10 +1 2 4 +3 5 12 +1 2 +1 10 18 24 +3 6 9 +1 2 4 +1 2 +1 2 +2 4 8 +4 64 136 +1 6 +1 +1 2 +2 9 17 +1 6 +4 +3 4 +2 4 33 +3 9 17 18 +2 +2 5 +2 +1 4 +2 9 +3 4 9 10 +9 +3 4 +4 8 +1 2 +3 6 +1 +1 2 +2 9 +1 +3 +1 12 +1 2 +3 5 +6 10 12 +1 2 +3 5 6 +5 6 +1 +1 4 +1 2 +5 6 8 +1 64 128 +1 +3 6 +1 2 12 +1 4 +2 8 +1 +1 2 +1 8 16 +1 +1 2 +1 2 +2 +5 +3 +2 +1 +1 +1 2 +4 16 +16 32 +1 12 +4 9 10 18 24 34 40 48 +1 12 +1 2 +3 +1 2 +1 8 +1 2 +1 +20 +1 +1 2 +1 2 +1 +1 2 +1 2 4 +1 2 +2 5 +3 5 6 +1 4 +1 6 +2 9 20 24 +1 10 18 24 34 40 48 +3 4 +1 2 +1 4 32 +1 2 +1 2 +2 +3 5 +10 12 +1 +1 2 +1 2 +3 +1 +2 4 40 +1 6 +1 6 +1 2 +3 9 10 +3 4 9 +3 5 6 16 +4 +3 4 +6 +1 2 +1 +3 5 6 +1 +3 5 6 33 34 36 +1 2 +1 +6 +3 4 9 10 33 34 40 +1 +1 +4 9 10 +1 +2 +1 2 +1 2 4 +2 +6 10 12 18 +1 +5 9 12 +2 +1 +2 +1 +1 +1 +5 +2 +2 4 +1 2 4 +1 +1 2 +48 +1 6 +3 +1 +1 2 +1 +1 2 +1 2 4 +1 +5 16 +2 +3 +3 5 6 8 +2 +1 2 +3 +1 +1 +3 +20 +1 8 20 36 +3 9 10 17 24 +3 4 +4 8 +1 2 4 +3 +1 8 +3 9 10 16 +4 +1 2 +1 2 +1 2 +3 +5 9 +3 8 +1 4 +1 2 12 +1 2 4 24 +2 +2 +1 2 +1 4 +2 9 12 20 33 36 40 +1 2 +3 5 6 +1 +4 9 17 24 +1 +3 +1 6 +6 8 +2 +1 +5 9 12 +3 5 6 +1 2 4 +1 8 18 +3 4 +1 2 +1 2 +6 10 12 18 20 24 +3 +3 +3 17 18 +2 +2 +1 +1 12 +1 2 +4 +9 16 65 72 129 132 136 192 +3 +3 5 +1 2 4 +1 6 8 18 20 +3 +2 +1 6 +2 4 +2 +1 2 4 +1 2 4 +1 2 +1 2 +1 +3 +1 +3 +1 2 +5 +1 +1 2 +5 +5 +1 2 +1 +1 2 +33 48 +3 9 10 +1 4 +1 2 +1 +3 8 17 18 +2 +1 2 +1 2 +1 +1 +1 +1 2 +3 5 6 +1 2 +1 +2 +1 +1 2 +1 +1 8 18 +6 10 12 +5 +3 4 +3 4 +1 +3 5 6 +5 9 12 17 20 24 +2 12 +1 20 64 +2 +2 20 24 +4 10 18 24 +1 2 +1 2 8 +4 10 18 24 +1 +3 9 10 +1 +12 20 24 36 40 48 68 72 80 96 +2 +1 2 4 +3 17 18 65 66 80 +1 +2 +1 12 20 24 68 72 80 +1 +1 +3 +1 +1 6 +2 +17 
20 32 +4 9 33 65 72 96 +1 +1 2 +2 +3 +1 2 +1 +1 6 8 +2 +4 9 10 +5 8 65 68 +1 6 +5 17 18 20 33 36 48 +1 12 +1 4 +2 +1 +1 2 +4 +1 2 +2 4 8 +1 40 +1 2 +1 +1 2 +2 +3 5 6 +24 36 40 48 260 264 272 288 +6 +4 8 +5 9 12 16 +1 2 +1 24 +1 2 +1 2 +2 12 16 +1 2 +1 +1 2 12 40 48 +8 +3 9 10 +1 2 12 +1 +5 8 +3 4 9 10 +2 4 9 33 40 +1 2 +3 +1 4 +1 2 4 +4 9 +4 129 +1 20 +1 2 +1 6 +1 2 +1 2 +3 5 +3 5 6 +1 +1 2 +3 5 +2 +1 +8 16 32 +2 +1 +5 9 10 17 18 24 +1 2 +6 +5 +1 2 +3 +1 2 +1 +5 +1 6 +1 2 +1 2 +1 +1 +3 +1 +1 +6 10 12 +1 +1 diff --git a/test/test_data/top10_thresholds b/test/test_data/top10_thresholds new file mode 100644 index 000000000..e90c41493 --- /dev/null +++ b/test/test_data/top10_thresholds @@ -0,0 +1,500 @@ +8.1191 +6.40702 +8.53867 +7.33842 +4.62827 +7.43321 +6.07675 +9.76084 +8.25249 +6.33509 +5.94863 +8.06593 +6.77912 +8.00327 +8.06362 +6.78048 +7.16929 +6.22611 +4.86989 +8.08515 +5.00578 +5.72889 +3.3509 +6.03476 +7.0362 +7.55356 +10.4887 +7.63555 +5.8102 +9.09428 +9.71736 +10.1705 +7.12889 +6.97783 +6.79471 +5.62071 +0 +0 +4.61359 +16.0051 +7.7722 +10.6181 +7.65763 +2.58132 +10.6563 +9.43986 +9.1898 +6.66688 +6.36398 +5.66858 +8.04628 +7.27758 +8.27132 +16.886 +6.54159 +6.16754 +8.786 +0 +9.97472 +6.3738 +9.62928 +5.61858 +15.1814 +7.34612 +7.48248 +1.88119e-06 +9.75801 +7.72713 +8.63993 +9.11045 +11.1941 +12.6965 +9.2991 +6.23525 +6.18397 +7.16673 +7.3081 +9.34341 +7.56361 +8.35106 +6.32008 +9.17452 +7.44254 +3.29669 +4.81384 +13.2702 +2.66406 +8.50921 +9.88172 +7.06583 +11.3807 +2.72187 +11.1195 +8.54663 +6.3199 +9.95984 +2.11432 +9.89071 +0 +9.52282 +8.62435 +6.07234 +9.51326 +8.03819 +8.19029 +4.81047 +7.06047 +15.864 +6.73753 +6.07235 +7.07044 +9.949 +7.07569 +1.3461 +6.15922 +7.38229 +9.95886 +6.96017 +7.56859 +8.89754 +7.31538 +10.0964 +10.802 +5.87512 +4.65612 +7.31016 +8.06353 +14.7184 +5.10633 +6.54296 +7.29809 +8.5369 +8.21982 +6.63908 +7.93862 +8.48126 +9.1978 +6.47504 +9.09038 +6.65207 +7.87977 +8.91329 +8.21564 +7.90121 +7.21129 +14.6396 +7.7758 +11.0978 
+6.53356 +5.61717 +8.16117 +7.21347 +9.12 +8.98664 +6.14115 +9.136 +8.8644 +4.25504 +9.91241 +8.2076 +8.59041 +6.81703 +6.4299 +9.45913 +11.1704 +7.23268 +7.89493 +8.27872 +6.45686 +10.1443 +7.79637 +5.45131 +6.81683 +3.88536 +7.02361 +7.18851 +7.42748 +8.95018 +11.2902 +8.81273 +6.72542 +8.52207 +6.91805 +7.30957 +7.42604 +8.68603 +12.3029 +9.5102 +6.58983 +9.3897 +0 +7.50808 +6.99158 +0 +6.10725 +8.43093 +6.20851 +0 +6.88064 +5.59466 +8.55461 +3.07427 +6.58394 +8.60744 +3.13404 +7.16225 +10.3509 +9.13109 +8.54642 +7.0724 +7.90035 +6.53838 +7.0606 +6.20906 +8.03361 +8.33321 +6.2458 +5.85291 +4.86942 +6.39161 +8.01369 +9.68012 +8.3957 +9.56297 +5.30648 +7.9272 +8.41178 +9.78377 +5.95837 +7.02013 +7.29275 +7.72767 +7.13853 +7.19798 +6.81257 +9.9947 +6.52238 +1.84112 +8.02057 +9.75506 +4.24206 +8.10603 +7.62294 +6.64202 +8.31228 +6.39894 +7.64641 +7.18985 +10.9307 +0 +9.35913 +7.14301 +2.62868 +7.82687 +7.78378 +4.47242 +5.99799 +8.01446 +7.29016 +6.96249 +7.31208 +7.26219 +1.02171 +7.90062 +8.52454 +7.79633 +0 +6.39334 +3.65948 +6.13462 +8.53574 +2.42495 +8.34661 +6.5509 +9.19673 +8.72243 +0.975942 +2.97355 +17.1217 +8.16323 +5.56541 +6.57861 +5.56972 +9.25555 +8.9732 +9.58803 +7.21451 +7.85729 +7.96213 +9.15531 +8.90842 +7.63755 +4.98901 +6.89196 +8.80033 +3.0826 +8.63325 +8.71573 +2.71681 +9.0253 +7.51434 +8.20914 +7.16011 +5.08073 +7.0777 +12.4818 +5.96502 +10.091 +0 +8.57939 +7.44738 +8.60938 +6.30986 +15.5856 +7.38011 +0 +10.2552 +8.81642 +2.74022 +7.3057 +8.29425 +5.7883 +6.793 +12.642 +7.0631 +7.46704 +13.6006 +8.07774 +6.36285 +0 +7.83312 +6.49256 +10.4431 +11.8368 +7.7008 +7.39802 +5.43342 +8.84415 +9.69069 +3.62001 +7.9487 +6.91921 +9.79507 +6.22026 +6.91993 +3.88733 +2.97617 +8.70682 +5.30657 +6.33269 +10.1033 +6.96088 +4.79799 +5.73728 +7.03829 +8.05473 +7.29462 +5.77349 +2.82017 +6.793 +14.9411 +9.10034 +6.73388 +6.46703 +7.13749 +9.87858 +7.78961 +6.41109 +6.69017 +0 +8.75548 +7.01149 +7.10874 +8.80585 +7.39948 +6.35082 +5.5849 +7.18396 +0.990798 
+5.72014 +8.44618 +8.44117 +4.89387 +9.79468 +7.04739 +8.89953 +8.64338 +12.7184 +10.9665 +6.86001 +8.71301 +6.50483 +10.4117 +8.92076 +7.14701 +9.26171 +7.06198 +7.50738 +7.10056 +9.75535 +7.04357 +7.95306 +10.1646 +0 +5.94793 +7.94278 +6.07347 +6.6129 +5.74331 +6.27284 +7.83404 +5.58244 +8.4209 +8.93761 +0 +5.68759 +8.45663 +7.6864 +7.32927 +6.34879 +7.65884 +6.00451 +8.86703 +9.56031 +8.25177 +17.2111 +7.36363 +7.45831 +7.88611 +0 +7.44071 +5.64288 +5.34061 +6.47387 +8.15098 +7.30174 +5.3787 +7.22029 +6.87507 +6.78256 +13.5872 +9.50327 +7.79518 +12.4835 +2.97089 +7.6197 +6.54403 +8.7382 +8.07254 +8.09486 +6.58483 +10.2843 +6.67195 +9.55084 +6.63688 +0 +12.5787 +9.6835 +9.46256 +4.68514 +9.70017 +7.34846 +6.35295 +7.57669 +16.6089 +9.65398 +7.08425 +8.39853 +0 +7.6837 +8.62668 +10.1491 +0 +5.10292 +8.00893 +4.33105 +0 +7.55501 +8.08855 +5.86112 +11.9459 +6.54911 +9.99107 +8.97999 +7.07643 +7.32234 +5.08513 +6.66545 +11.1706 +8.37998 +7.17167 +6.5358 +5.87431 +8.89896 +6.3356 +7.97252 +6.52537 +9.99582 +6.36973 +0 diff --git a/test/v1/CMakeLists.txt b/test/v1/CMakeLists.txt index 663367684..0ae89c35d 100644 --- a/test/v1/CMakeLists.txt +++ b/test/v1/CMakeLists.txt @@ -11,6 +11,7 @@ foreach(TEST_SRC ${TEST_SOURCES}) target_link_libraries(${TEST_SRC_NAME} pisa Catch2 + rapidcheck ) catch_discover_tests(${TEST_SRC_NAME} TEST_PREFIX "${TEST_SRC_NAME}:") diff --git a/test/v1/index_fixture.hpp b/test/v1/index_fixture.hpp index 89125c2eb..efc6cdf40 100644 --- a/test/v1/index_fixture.hpp +++ b/test/v1/index_fixture.hpp @@ -1,26 +1,41 @@ #pragma once +#include +#include + #include "../temporary_directory.hpp" #include "pisa_config.hpp" #include "query/queries.hpp" #include "v1/index.hpp" #include "v1/index_builder.hpp" +#include "v1/intersection.hpp" #include "v1/query.hpp" #include "v1/score_index.hpp" namespace v1 = pisa::v1; -[[nodiscard]] inline auto test_queries() -> std::vector +[[nodiscard]] inline auto test_queries() -> std::vector { - std::vector queries; + 
std::vector queries; std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); auto push_query = [&](std::string const& query_line) { - queries.push_back(pisa::parse_query_ids(query_line)); + auto q = pisa::parse_query_ids(query_line); + queries.push_back( + v1::Query{.terms = q.terms, .list_selection = {}, .threshold = {}, .id = {}, .k = 10}); }; pisa::io::for_each_line(qfile, push_query); return queries; } +[[nodiscard]] inline auto test_intersection_selections() +{ + auto intersections = + pisa::v1::read_intersections(PISA_SOURCE_DIR "/test/test_data/top10_selections"); + auto unigrams = pisa::v1::filter_unigrams(intersections); + auto bigrams = pisa::v1::filter_bigrams(intersections); + return std::make_pair(unigrams, bigrams); +} + template struct IndexFixture { using DocumentWriter = typename v1::CursorTraits::Writer; @@ -43,7 +58,19 @@ struct IndexFixture { auto errors = v1::verify_compressed_index(PISA_SOURCE_DIR "/test/test_data/test_collection", index_basename); REQUIRE(errors.empty()); - v1::score_index(fmt::format("{}.yml", index_basename), 1); + auto yml = fmt::format("{}.yml", index_basename); + + auto queries = [&]() { + std::vector queries; + auto qs = test_queries(); + int idx = 0; + std::transform(qs.begin(), qs.end(), std::back_inserter(queries), [&](auto q) { + return v1::Query{q.terms}; + }); + return queries; + }(); + v1::build_bigram_index(yml, collect_unique_bigrams(queries, []() {})); + v1::score_index(yml, 1); } [[nodiscard]] auto const& tmpdir() const { return *m_tmpdir; } diff --git a/test/v1/test_v1.cpp b/test/v1/test_v1.cpp index 4b23ed14e..6728dd0a3 100644 --- a/test/v1/test_v1.cpp +++ b/test/v1/test_v1.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include "io.hpp" @@ -52,18 +53,6 @@ std::ostream& operator<<(std::ostream& os, tl::optional const& val) TEST_CASE("partition_by_index", "[v1][unit]") { - std::vector values{5, 0, 1, 2, 3, 4}; - // auto input_data = GENERATE(table, std::vector>( - // {{{}, {5, 0, 
1, 2, 3, 4}}, - // {{0, 1, 2}, {2, 3, 4, 5, 0, 1}}, - // {{3, 4, 5}, {}}, - // {{0, 4, 5}, {}}})); - // auto input_data = GENERATE(table, std::vector>( - // {{{}, {5, 0, 1, 2, 3, 4}}, - // {{0, 1, 2}, {2, 3, 4, 5, 0, 1}}, - // {{3, 4, 5}, {}}, - // {{0, 4, 5}, {}}})); - auto expected = [](auto input_vec, auto right_indices) { std::vector essential; std::sort(right_indices.begin(), right_indices.end(), std::greater{}); @@ -77,18 +66,21 @@ TEST_CASE("partition_by_index", "[v1][unit]") return input_vec; }; - std::vector input{5, 0, 1, 2, 3, 4}; - std::vector right_indices{}; - auto expected_output = expected(input, right_indices); - pisa::v1::partition_by_index(gsl::make_span(input), gsl::make_span(right_indices)); - std::sort(input.begin(), std::next(input.begin(), input.size() - right_indices.size())); - std::sort(std::next(input.begin(), right_indices.size()), input.end()); - REQUIRE(input == expected_output); - - // auto [right_indices, expected] = GENERATE( - // table, std::vector>({{}, {5, 0, 1, 2, 3, 4}})); - // pisa::v1::partition_by_index(gsl::make_span(values), gsl::make_span(right_indices)); - // REQUIRE(values == expected); + rc::check([&](std::vector input) { + CAPTURE(input); + std::vector all_indices(input.size()); + std::iota(all_indices.begin(), all_indices.end(), 0); + auto right_indices = + *rc::gen::unique>(rc::gen::elementOf(all_indices)); + CAPTURE(right_indices); + auto expected_output = expected(input, right_indices); + auto essential_count = right_indices.size(); + auto non_essential_count = input.size() - essential_count; + pisa::v1::partition_by_index(gsl::make_span(input), gsl::make_span(right_indices)); + std::sort(input.begin(), std::next(input.begin(), non_essential_count)); + std::sort(std::next(input.begin(), non_essential_count), input.end()); + REQUIRE(input == expected_output); + }); } TEST_CASE("RawReader", "[v1][unit]") diff --git a/test/v1/test_v1_bigram_index.cpp b/test/v1/test_v1_bigram_index.cpp new file mode 100644 index 
000000000..e2028bb43 --- /dev/null +++ b/test/v1/test_v1_bigram_index.cpp @@ -0,0 +1,104 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include +#include + +#include +#include +#include + +#include "../temporary_directory.hpp" +#include "accumulator/lazy_accumulator.hpp" +#include "cursor/block_max_scored_cursor.hpp" +#include "cursor/max_scored_cursor.hpp" +#include "cursor/scored_cursor.hpp" +#include "index_fixture.hpp" +#include "index_types.hpp" +#include "io.hpp" +#include "pisa_config.hpp" +#include "query/queries.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/cursor_accumulator.hpp" +#include "v1/cursor_intersection.hpp" +#include "v1/cursor_traits.hpp" +#include "v1/cursor_union.hpp" +#include "v1/index.hpp" +#include "v1/index_builder.hpp" +#include "v1/maxscore.hpp" +#include "v1/posting_builder.hpp" +#include "v1/posting_format_header.hpp" +#include "v1/query.hpp" +#include "v1/score_index.hpp" +#include "v1/scorer/bm25.hpp" +#include "v1/types.hpp" +#include "v1/union_lookup.hpp" + +namespace v1 = pisa::v1; + +static constexpr auto RELATIVE_ERROR = 0.1F; + +TEMPLATE_TEST_CASE("Bigram v intersection", + "[v1][integration]", + (IndexFixture, + v1::RawCursor, + v1::RawCursor>), + (IndexFixture, + v1::BlockedCursor<::pisa::simdbp_block, false>, + v1::RawCursor>)) +{ + tbb::task_scheduler_init init(1); + TestType fixture; + auto index_basename = (fixture.tmpdir().path() / "inv").string(); + auto meta = v1::IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); + int idx = 0; + for (auto& q : test_queries()) { + std::sort(q.terms.begin(), q.terms.end()); + q.terms.erase(std::unique(q.terms.begin(), q.terms.end()), q.terms.end()); + CAPTURE(q.terms); + CAPTURE(idx++); + + auto run = v1::index_runner(meta, fixture.document_reader(), fixture.frequency_reader()); + std::vector results; + run([&](auto&& index) { + auto scorer = make_bm25(index); + for (auto left = 0; left < q.terms.size(); left += 
1) { + for (auto right = left + 1; right < q.terms.size(); right += 1) { + auto left_cursor = index.cursor(q.terms[left]); + auto right_cursor = index.cursor(q.terms[right]); + auto intersection = v1::intersect({left_cursor, right_cursor}, + std::array{0, 0}, + [](auto& acc, auto&& cursor, auto idx) { + acc[idx] = cursor.payload(); + return acc; + }); + if (not intersection.empty()) { + auto bigram_cursor = *index.bigram_cursor(q.terms[left], q.terms[right]); + std::vector bigram_documents; + std::vector bigram_frequencies_0; + std::vector bigram_frequencies_1; + v1::for_each(bigram_cursor, [&](auto&& cursor) { + bigram_documents.push_back(*cursor); + auto payload = cursor.payload(); + bigram_frequencies_0.push_back(std::get<0>(payload)); + bigram_frequencies_1.push_back(std::get<1>(payload)); + }); + std::vector intersection_documents; + std::vector intersection_frequencies_0; + std::vector intersection_frequencies_1; + v1::for_each(intersection, [&](auto&& cursor) { + intersection_documents.push_back(*cursor); + auto payload = cursor.payload(); + intersection_frequencies_0.push_back(std::get<0>(payload)); + intersection_frequencies_1.push_back(std::get<1>(payload)); + }); + CHECK(bigram_documents == intersection_documents); + CHECK(bigram_frequencies_0 == intersection_frequencies_0); + REQUIRE(bigram_frequencies_1 == intersection_frequencies_1); + } + } + } + }); + } +} diff --git a/test/v1/test_v1_queries.cpp b/test/v1/test_v1_queries.cpp index 8c388c1c8..ccacf0a78 100644 --- a/test/v1/test_v1_queries.cpp +++ b/test/v1/test_v1_queries.cpp @@ -128,7 +128,9 @@ TEMPLATE_TEST_CASE("Query", auto input_data = GENERATE(table({{"daat_or", false}, {"maxscore", false}, {"maxscore", true}, - {"maxscore_union_lookup", true}})); + {"maxscore_union_lookup", true}, + {"unigram_union_lookup", true}, + {"union_lookup", true}})); std::string algorithm = std::get<0>(input_data); bool with_threshold = std::get<1>(input_data); CAPTURE(algorithm); @@ -146,37 +148,45 @@ 
TEMPLATE_TEST_CASE("Query", if (name == "maxscore_union_lookup") { return maxscore_union_lookup(query, index, topk_queue(10), scorer); } + if (name == "unigram_union_lookup") { + query.list_selection = + tl::make_optional(v1::ListSelection{.unigrams = query.terms, .bigrams = {}}); + return unigram_union_lookup(query, index, topk_queue(10), scorer); + } + if (name == "union_lookup") { + if (query.terms.size() > 8) { + return maxscore_union_lookup(query, index, topk_queue(10), scorer); + } + return union_lookup(query, index, topk_queue(10), scorer); + } std::abort(); }; int idx = 0; - for (auto& q : test_queries()) { - std::sort(q.terms.begin(), q.terms.end()); - q.terms.erase(std::unique(q.terms.begin(), q.terms.end()), q.terms.end()); - CAPTURE(q.terms); - CAPTURE(idx++); + auto const intersections = + pisa::v1::read_intersections(PISA_SOURCE_DIR "/test/test_data/top10_selections"); + for (auto& query : test_queries()) { + if (algorithm == "union_lookup") { + query.add_selections(gsl::make_span(intersections[idx])); + } + query.remove_duplicates(); - or_q(make_scored_cursors(data->v0_index, data->wdata, q), data->v0_index.num_docs()); + CAPTURE(query.terms); + CAPTURE(idx); + CAPTURE(intersections[idx]); + + or_q(make_scored_cursors(data->v0_index, data->wdata, ::pisa::Query{{}, query.terms, {}}), + data->v0_index.num_docs()); auto expected = or_q.topk(); - auto threshold = [&]() { - if (with_threshold) { - return tl::make_optional(expected.back().first - 1.0F); - } - return tl::optional{}; - }(); + if (with_threshold) { + query.threshold = expected.back().first - 1.0F; + } auto on_the_fly = [&]() { auto run = v1::index_runner(meta, fixture.document_reader(), fixture.frequency_reader()); std::vector results; run([&](auto&& index) { - auto que = run_query(algorithm, - v1::Query{.terms = q.terms, - .list_selection = {}, - .threshold = threshold, - .id = {}, - .k = 10}, - index, - make_bm25(index)); + auto que = run_query(algorithm, query, index, make_bm25(index)); 
que.finalize(); results = que.topk(); results.erase(std::remove_if(results.begin(), @@ -192,11 +202,24 @@ TEMPLATE_TEST_CASE("Query", expected.resize(on_the_fly.size()); std::sort(expected.begin(), expected.end(), approximate_order); + // if (algorithm == "union_lookup") { + // for (size_t i = 0; i < on_the_fly.size(); ++i) { + // std::cerr << fmt::format("{}, {:f} -- {}, {:f}\n", + // on_the_fly[i].second, + // on_the_fly[i].first, + // expected[i].second, + // expected[i].first); + // } + //} + // std::cerr << '\n'; + for (size_t i = 0; i < on_the_fly.size(); ++i) { REQUIRE(on_the_fly[i].second == expected[i].second); REQUIRE(on_the_fly[i].first == Approx(expected[i].first).epsilon(RELATIVE_ERROR)); } + idx += 1; + // auto precomputed = [&]() { // auto run = // v1::scored_index_runner(meta, fixture.document_reader(), fixture.score_reader()); @@ -261,73 +284,3 @@ TEMPLATE_TEST_CASE("Query", //} } } - -TEMPLATE_TEST_CASE("UnionLookup", - "[v1][integration]", - (IndexFixture, - v1::RawCursor, - v1::RawCursor>), - (IndexFixture, - v1::BlockedCursor<::pisa::simdbp_block, false>, - v1::RawCursor>)) -{ - tbb::task_scheduler_init init(1); - auto data = IndexData, v1::RawCursor>, - v1::Index, v1::RawCursor>>::get(); - TestType fixture; - auto index_basename = (fixture.tmpdir().path() / "inv").string(); - auto meta = v1::IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); - ranked_or_query or_q(10); - int idx = 0; - for (auto& q : test_queries()) { - CAPTURE(q.terms); - CAPTURE(idx++); - - or_q(make_scored_cursors(data->v0_index, data->wdata, q), data->v0_index.num_docs()); - auto expected = or_q.topk(); - std::sort(expected.begin(), expected.end(), std::greater{}); - - auto on_the_fly = [&]() { - auto run = - v1::index_runner(meta, fixture.document_reader(), fixture.frequency_reader()); - std::vector results; - run([&](auto&& index) { - std::vector unigrams(q.terms.size()); - std::iota(unigrams.begin(), unigrams.end(), 0); - auto que = 
v1::union_lookup(v1::Query{q.terms}, - index, - topk_queue(10), - make_bm25(index), - std::move(unigrams), - {}); - que.finalize(); - results = que.topk(); - std::sort(results.begin(), results.end(), std::greater{}); - }); - return results; - }(); - - // auto precomputed = [&]() { - // auto run = - // v1::scored_index_runner(meta, fixture.document_reader(), fixture.score_reader()); - // std::vector results; - // run([&](auto&& index) { - // auto que = daat_or(v1::Query{q.terms}, index, topk_queue(10), v1::VoidScorer{}); - // que.finalize(); - // results = que.topk(); - // std::sort(results.begin(), results.end(), std::greater{}); - // }); - // return results; - //}(); - - REQUIRE(expected.size() == on_the_fly.size()); - // REQUIRE(expected.size() == precomputed.size()); - for (size_t i = 0; i < on_the_fly.size(); ++i) { - REQUIRE(on_the_fly[i].second == expected[i].second); - REQUIRE(on_the_fly[i].first == Approx(expected[i].first).epsilon(RELATIVE_ERROR)); - // REQUIRE(precomputed[i].second == expected[i].second); - // REQUIRE(precomputed[i].first == Approx(expected[i].first).epsilon(RELATIVE_ERROR)); - } - } -} diff --git a/v1/CMakeLists.txt b/v1/CMakeLists.txt index b4b62944f..3fefab289 100644 --- a/v1/CMakeLists.txt +++ b/v1/CMakeLists.txt @@ -4,9 +4,6 @@ target_link_libraries(compress pisa CLI11) add_executable(query query.cpp) target_link_libraries(query pisa CLI11) -add_executable(union-lookup union_lookup.cpp) -target_link_libraries(union-lookup pisa CLI11) - add_executable(postings postings.cpp) target_link_libraries(postings pisa CLI11) diff --git a/v1/bigram_index.cpp b/v1/bigram_index.cpp index 8affa592b..7e4c0aa7a 100644 --- a/v1/bigram_index.cpp +++ b/v1/bigram_index.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -13,42 +14,23 @@ #include "v1/cursor_intersection.hpp" #include "v1/index_builder.hpp" #include "v1/index_metadata.hpp" +#include "v1/progress_status.hpp" #include "v1/query.hpp" #include "v1/raw_cursor.hpp" #include 
"v1/scorer/bm25.hpp" #include "v1/scorer/runner.hpp" #include "v1/types.hpp" -using pisa::v1::BigramMetadata; -using pisa::v1::BlockedReader; -using pisa::v1::CursorTraits; +using pisa::v1::build_bigram_index; +using pisa::v1::collect_unique_bigrams; using pisa::v1::DefaultProgress; using pisa::v1::DocId; using pisa::v1::Frequency; -using pisa::v1::index_runner; using pisa::v1::IndexMetadata; -using pisa::v1::intersect; -using pisa::v1::PostingBuilder; using pisa::v1::ProgressStatus; -using pisa::v1::RawReader; +using pisa::v1::Query; using pisa::v1::resolve_yml; using pisa::v1::TermId; -using pisa::v1::write_span; - -auto collect_unique_bigrams(std::vector const& queries) - -> std::vector> -{ - std::vector> bigrams; - for (auto&& query : queries) { - for (auto left = 0; left < query.terms.size(); left += 1) { - auto right = left + 1; - bigrams.emplace_back(query.terms[left], query.terms[right]); - } - } - std::sort(bigrams.begin(), bigrams.end()); - bigrams.erase(std::unique(bigrams.begin(), bigrams.end()), bigrams.end()); - return bigrams; -} int main(int argc, char** argv) { @@ -73,91 +55,41 @@ int main(int argc, char** argv) terms_file = meta.term_lexicon.value(); } - std::vector queries; - auto parse_query = resolve_query_parser(queries, terms_file, std::nullopt, stemmer); - if (query_file) { - std::ifstream is(*query_file); - pisa::io::for_each_line(is, parse_query); - } else { - pisa::io::for_each_line(std::cin, parse_query); - } - - auto bigrams = collect_unique_bigrams(queries); - - auto index_basename = resolved_yml.substr(0, resolved_yml.size() - 4); - auto run = index_runner(meta, - RawReader{}, - BlockedReader<::pisa::simdbp_block, true>{}, - BlockedReader<::pisa::simdbp_block, false>{}); - - std::vector> pair_mapping; - auto documents_file = fmt::format("{}.bigram_documents", index_basename); - auto frequencies_file_0 = fmt::format("{}.bigram_frequencies_0", index_basename); - auto frequencies_file_1 = fmt::format("{}.bigram_frequencies_1", 
index_basename); - auto document_offsets_file = fmt::format("{}.bigram_document_offsets", index_basename); - auto frequency_offsets_file_0 = fmt::format("{}.bigram_frequency_offsets_0", index_basename); - auto frequency_offsets_file_1 = fmt::format("{}.bigram_frequency_offsets_1", index_basename); - std::ofstream document_out(documents_file); - std::ofstream frequency_out_0(frequencies_file_0); - std::ofstream frequency_out_1(frequencies_file_1); - - run([&](auto&& index) { - ProgressStatus status(bigrams.size(), - DefaultProgress("Building bigram index"), - std::chrono::milliseconds(100)); - using index_type = std::decay_t; - using document_writer_type = - typename CursorTraits::Writer; - using frequency_writer_type = - typename CursorTraits::Writer; - - PostingBuilder document_builder(document_writer_type{}); - PostingBuilder frequency_builder_0(frequency_writer_type{}); - PostingBuilder frequency_builder_1(frequency_writer_type{}); - - document_builder.write_header(document_out); - frequency_builder_0.write_header(frequency_out_0); - frequency_builder_1.write_header(frequency_out_1); - - for (auto [left_term, right_term] : bigrams) { - auto intersection = intersect({index.cursor(left_term), index.cursor(right_term)}, - std::array{0, 0}, - [](auto& payload, auto& cursor, auto list_idx) { - payload[list_idx] = cursor.payload(); - return payload; - }); - if (intersection.empty()) { - status += 1; - continue; + spdlog::info("Collecting queries..."); + auto queries = [&]() { + std::vector<::pisa::Query> queries; + auto parse_query = resolve_query_parser(queries, terms_file, std::nullopt, stemmer); + if (query_file) { + std::ifstream is(*query_file); + pisa::io::for_each_line(is, parse_query); + } else { + pisa::io::for_each_line(std::cin, parse_query); + } + std::vector v1_queries; + v1_queries.reserve(queries.size()); + for (auto query : queries) { + if (not query.terms.empty()) { + v1_queries.push_back(Query{.terms = query.terms, + .list_selection = {}, + 
.threshold = {}, + .id = + [&]() { + if (query.id) { + return tl::make_optional(*query.id); + } + return tl::optional{}; + }(), + .k = 10}); } - pair_mapping.push_back({left_term, right_term}); - for_each(intersection, [&](auto& cursor) { - document_builder.accumulate(*cursor); - auto payload = cursor.payload(); - frequency_builder_0.accumulate(std::get<0>(payload)); - frequency_builder_1.accumulate(std::get<1>(payload)); - }); - document_builder.flush_segment(document_out); - frequency_builder_0.flush_segment(frequency_out_0); - frequency_builder_1.flush_segment(frequency_out_1); - status += 1; } - std::cerr << "Writing offsets..."; - write_span(gsl::make_span(document_builder.offsets()), document_offsets_file); - write_span(gsl::make_span(frequency_builder_0.offsets()), frequency_offsets_file_0); - write_span(gsl::make_span(frequency_builder_1.offsets()), frequency_offsets_file_1); - std::cerr << " Done.\n"; - }); - std::cerr << "Writing metadata..."; - meta.bigrams = BigramMetadata{ - .documents = {.postings = documents_file, .offsets = document_offsets_file}, - .frequencies = {{.postings = frequencies_file_0, .offsets = frequency_offsets_file_0}, - {.postings = frequencies_file_1, .offsets = frequency_offsets_file_1}}, - .mapping = fmt::format("{}.bigram_mapping", index_basename), - .count = pair_mapping.size()}; - meta.write(resolved_yml); - std::cerr << " Done.\nWriting bigram mapping..."; - write_span(gsl::make_span(pair_mapping), meta.bigrams->mapping); - std::cerr << " Done.\n"; + return v1_queries; + }(); + + spdlog::info("Collected {} queries", queries.size()); + spdlog::info("Collecting bigrams..."); + ProgressStatus status(queries.size(), DefaultProgress{}, std::chrono::milliseconds(1000)); + auto bigrams = collect_unique_bigrams(queries, [&]() { status += 1; }); + spdlog::info("Collected {} bigrams", bigrams.size()); + build_bigram_index(resolved_yml, bigrams); return 0; } diff --git a/v1/query.cpp b/v1/query.cpp index dd8571f79..5f949482e 100644 --- 
a/v1/query.cpp +++ b/v1/query.cpp @@ -3,6 +3,8 @@ #include #include +#include +#include #include #include "app.hpp" @@ -13,6 +15,7 @@ #include "v1/analyze_query.hpp" #include "v1/blocked_cursor.hpp" #include "v1/index_metadata.hpp" +#include "v1/intersection.hpp" #include "v1/maxscore.hpp" #include "v1/query.hpp" #include "v1/raw_cursor.hpp" @@ -27,6 +30,7 @@ using pisa::v1::daat_or; using pisa::v1::DaatOrAnalyzer; using pisa::v1::index_runner; using pisa::v1::IndexMetadata; +using pisa::v1::ListSelection; using pisa::v1::maxscore_union_lookup; using pisa::v1::MaxscoreAnalyzer; using pisa::v1::MaxscoreUnionLookupAnalyzer; @@ -34,6 +38,10 @@ using pisa::v1::Query; using pisa::v1::QueryAnalyzer; using pisa::v1::RawReader; using pisa::v1::resolve_yml; +using pisa::v1::unigram_union_lookup; +using pisa::v1::UnigramUnionLookupAnalyzer; +using pisa::v1::union_lookup; +using pisa::v1::UnionLookupAnalyzer; using pisa::v1::VoidScorer; using RetrievalAlgorithm = std::function<::pisa::topk_queue(pisa::v1::Query, ::pisa::topk_queue)>; @@ -61,6 +69,22 @@ auto resolve_algorithm(std::string const& name, Index const& index, Scorer&& sco query, index, std::move(topk), std::forward(scorer)); }); } + if (name == "unigram-union-lookup") { + return RetrievalAlgorithm([&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + return pisa::v1::unigram_union_lookup( + query, index, std::move(topk), std::forward(scorer)); + }); + } + if (name == "union-lookup") { + return RetrievalAlgorithm([&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + if (query.list_selection && query.list_selection->bigrams.empty()) { + return pisa::v1::unigram_union_lookup( + query, index, std::move(topk), std::forward(scorer)); + } + return pisa::v1::union_lookup( + query, index, std::move(topk), std::forward(scorer)); + }); + } spdlog::error("Unknown algorithm: {}", name); std::exit(1); } @@ -75,7 +99,15 @@ auto resolve_analyze(std::string const& name, Index const& index, Scorer&& score return 
QueryAnalyzer(MaxscoreAnalyzer(index, std::forward(scorer))); } if (name == "maxscore-union-lookup") { - return QueryAnalyzer(MaxscoreUnionLookupAnalyzer(index, std::forward(scorer))); + return QueryAnalyzer( + MaxscoreUnionLookupAnalyzer>(index, scorer)); + } + if (name == "unigram-union-lookup") { + return QueryAnalyzer( + UnigramUnionLookupAnalyzer>(index, scorer)); + } + if (name == "union-lookup") { + return QueryAnalyzer(UnionLookupAnalyzer>(index, scorer)); } spdlog::error("Unknown algorithm: {}", name); std::exit(1); @@ -124,10 +156,10 @@ void benchmark(std::vector const& queries, int k, RetrievalAlgorithm retr double q50 = times[times.size() / 2]; double q90 = times[90 * times.size() / 100]; double q95 = times[95 * times.size() / 100]; - spdlog::info("Mean: {}", avg); - spdlog::info("50% quantile: {}", q50); - spdlog::info("90% quantile: {}", q90); - spdlog::info("95% quantile: {}", q95); + spdlog::info("Mean: {} microsec.", avg); + spdlog::info("50% quantile: {} microsec.", q50); + spdlog::info("90% quantile: {} microsec.", q90); + spdlog::info("95% quantile: {} microsec.", q95); } void analyze_queries(std::vector const& queries, QueryAnalyzer analyzer) @@ -145,11 +177,13 @@ int main(int argc, char** argv) { std::string algorithm = "daat_or"; tl::optional threshold_file; + tl::optional inter_filename; bool analyze = false; pisa::QueryApp app("Queries a v1 index."); app.add_option("--algorithm", algorithm, "Query retrieval algorithm.", true); app.add_option("--thresholds", threshold_file, "File with (estimated) thresholds.", false); + app.add_option("--intersections", inter_filename, "Intersections filename"); app.add_flag("--analyze", analyze, "Analyze query execution and stats"); CLI11_PARSE(app, argc, argv); @@ -212,6 +246,21 @@ int main(int argc, char** argv) } } + if (inter_filename) { + auto const intersections = pisa::v1::read_intersections(*inter_filename); + if (intersections.size() != queries.size()) { + spdlog::error("Number of intersections is 
not equal to number of queries"); + std::exit(1); + } + /* auto unigrams = pisa::v1::filter_unigrams(intersections); */ + /* auto bigrams = pisa::v1::filter_bigrams(intersections); */ + + for (auto query_idx = 0; query_idx < queries.size(); query_idx += 1) { + queries[query_idx].add_selections(gsl::make_span(intersections[query_idx])); + // ListSelection{std::move(unigrams[query_idx]), std::move(bigrams[query_idx])}; + } + } + if (app.precomputed) { auto run = scored_index_runner(meta, RawReader{}, From 319895cf794b20d5468882bfdfe530f3ea15d7f4 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Thu, 28 Nov 2019 08:46:57 -0500 Subject: [PATCH 25/56] Add json library --- .gitmodules | 3 +++ external/json | 1 + 2 files changed, 4 insertions(+) create mode 160000 external/json diff --git a/.gitmodules b/.gitmodules index c9f749b0a..97ddee4b6 100644 --- a/.gitmodules +++ b/.gitmodules @@ -79,3 +79,6 @@ [submodule "external/rapidcheck"] path = external/rapidcheck url = https://github.com/emil-e/rapidcheck.git +[submodule "external/json"] + path = external/json + url = https://github.com/nlohmann/json.git diff --git a/external/json b/external/json new file mode 160000 index 000000000..e7b3b40b5 --- /dev/null +++ b/external/json @@ -0,0 +1 @@ +Subproject commit e7b3b40b5a95bc74b9a7f662830a27c49ffc01b4 From ba7f62c992204f5fb4c80c91b53dfa5aaf3d6bb5 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 29 Nov 2019 09:58:41 -0500 Subject: [PATCH 26/56] Refactoring --- CMakeLists.txt | 5 +- external/CMakeLists.txt | 6 + include/pisa/v1/daat_and.hpp | 26 ++ include/pisa/v1/daat_or.hpp | 83 ++++ include/pisa/v1/maxscore.hpp | 25 +- include/pisa/v1/query.hpp | 205 +++++----- include/pisa/v1/taat_or.hpp | 25 ++ include/pisa/v1/union_lookup.hpp | 132 +++---- src/compute_intersection.cpp | 9 +- src/v1/index_builder.cpp | 10 +- src/v1/query.cpp | 159 +++++++- test/test_data/top10_selections_unigram | 500 ++++++++++++++++++++++++ test/v1/index_fixture.hpp | 17 +- 
test/v1/test_v1_bigram_index.cpp | 15 +- test/v1/test_v1_maxscore_join.cpp | 14 +- test/v1/test_v1_queries.cpp | 18 +- v1/CMakeLists.txt | 3 + v1/bigram_index.cpp | 22 +- v1/filter_queries.cpp | 74 ++++ v1/query.cpp | 36 +- 20 files changed, 1103 insertions(+), 281 deletions(-) create mode 100644 include/pisa/v1/daat_and.hpp create mode 100644 include/pisa/v1/daat_or.hpp create mode 100644 include/pisa/v1/taat_or.hpp create mode 100644 test/test_data/top10_selections_unigram create mode 100644 v1/filter_queries.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d6cd6c77..c3cc78cda 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -59,6 +59,8 @@ if (UNIX) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ggdb") # Add debug info anyway + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer") + endif() find_package(OpenMP) @@ -100,11 +102,12 @@ target_link_libraries(pisa PUBLIC range-v3 optional yaml-cpp + nlohmann_json ) target_include_directories(pisa PUBLIC external) add_subdirectory(v1) -#add_subdirectory(src) +add_subdirectory(src) if (PISA_ENABLE_TESTING AND BUILD_TESTING) enable_testing() diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt index 879a73d45..cb705fe45 100644 --- a/external/CMakeLists.txt +++ b/external/CMakeLists.txt @@ -129,3 +129,9 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/yaml-cpp EXCLUDE_FROM_ALL) # Add RapidCheck add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/rapidcheck) target_compile_options(rapidcheck PRIVATE -Wno-error=all) + +# Add json +# TODO(michal): I had to comment this out because `wapocpp` already adds this target. +# How should we deal with it? 
+# set(JSON_BuildTests OFF CACHE BOOL "skip building JSON tests") +# add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/json) diff --git a/include/pisa/v1/daat_and.hpp b/include/pisa/v1/daat_and.hpp new file mode 100644 index 000000000..0f3bc0343 --- /dev/null +++ b/include/pisa/v1/daat_and.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include "topk_queue.hpp" +#include "v1/cursor/for_each.hpp" +#include "v1/query.hpp" + +namespace pisa::v1 { + +template +auto daat_and(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) +{ + auto const& term_ids = query.get_term_ids(); + std::vector cursors; + std::transform(term_ids.begin(), term_ids.end(), std::back_inserter(cursors), [&](auto term) { + return index.scored_cursor(term, scorer); + }); + auto intersection = + v1::intersect(std::move(cursors), 0.0F, [](auto& score, auto& cursor, auto /* term_idx */) { + score += cursor.payload(); + return score; + }); + v1::for_each(intersection, [&](auto& cursor) { topk.insert(cursor.payload(), *cursor); }); + return topk; +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/daat_or.hpp b/include/pisa/v1/daat_or.hpp new file mode 100644 index 000000000..f9b1f9cd7 --- /dev/null +++ b/include/pisa/v1/daat_or.hpp @@ -0,0 +1,83 @@ +#pragma once + +#include + +#include "topk_queue.hpp" +#include "v1/cursor/for_each.hpp" +#include "v1/query.hpp" + +namespace pisa::v1 { + +template +auto daat_or(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) +{ + std::vector cursors; + std::transform(query.get_term_ids().begin(), + query.get_term_ids().end(), + std::back_inserter(cursors), + [&](auto term) { return index.scored_cursor(term, scorer); }); + auto cunion = v1::union_merge( + std::move(cursors), 0.0F, [](auto& score, auto& cursor, auto /* term_idx */) { + score += cursor.payload(); + return score; + }); + v1::for_each(cunion, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); + return topk; +} + +template +struct DaatOrAnalyzer { + 
DaatOrAnalyzer(Index const& index, Scorer scorer) : m_index(index), m_scorer(std::move(scorer)) + { + std::cout << fmt::format("documents\tpostings\n"); + } + + void operator()(Query const& query) + { + std::vector cursors; + std::transform(query.get_term_ids().begin(), + query.get_term_ids().end(), + std::back_inserter(cursors), + [&](auto term) { return m_index.scored_cursor(term, m_scorer); }); + std::size_t postings = 0; + auto cunion = v1::union_merge( + std::move(cursors), 0.0F, [&](auto& score, auto& cursor, auto /* term_idx */) { + postings += 1; + score += cursor.payload(); + return score; + }); + std::size_t documents = 0; + std::size_t inserts = 0; + topk_queue topk(query.k()); + v1::for_each(cunion, [&](auto& cursor) { + if (topk.insert(cursor.payload(), cursor.value())) { + inserts += 1; + }; + documents += 1; + }); + std::cout << fmt::format("{}\t{}\t{}\n", documents, postings, inserts); + m_documents += documents; + m_postings += postings; + m_inserts += inserts; + m_count += 1; + } + + void summarize() && + { + std::cerr << fmt::format( + "=== SUMMARY ===\nAverage:\n- documents:\t{}\n- postings:\t{}\n- inserts:\t{}\n", + static_cast(m_documents) / m_count, + static_cast(m_postings) / m_count, + static_cast(m_inserts) / m_count); + } + + private: + std::size_t m_documents = 0; + std::size_t m_postings = 0; + std::size_t m_inserts = 0; + std::size_t m_count = 0; + Index const& m_index; + Scorer m_scorer; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/maxscore.hpp b/include/pisa/v1/maxscore.hpp index cd541749f..ecd86f363 100644 --- a/include/pisa/v1/maxscore.hpp +++ b/include/pisa/v1/maxscore.hpp @@ -214,24 +214,24 @@ auto join_maxscore(CursorContainer cursors, template auto maxscore(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) { - if (query.terms.empty()) { + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { return topk; } using cursor_type = decltype(index.max_scored_cursor(0, scorer)); 
using value_type = decltype(index.max_scored_cursor(0, scorer).value()); std::vector cursors; - std::transform(query.terms.begin(), - query.terms.end(), - std::back_inserter(cursors), - [&](auto term) { return index.max_scored_cursor(term, scorer); }); + std::transform(term_ids.begin(), term_ids.end(), std::back_inserter(cursors), [&](auto term) { + return index.max_scored_cursor(term, scorer); + }); auto accumulate = [](auto& score, auto& cursor, auto /* term_position */) { score += cursor.payload(); return score; }; - if (query.threshold) { - topk.set_threshold(*query.threshold); + if (query.threshold()) { + topk.set_threshold(*query.threshold()); } auto joined = join_maxscore( std::move(cursors), 0.0F, accumulate, [&](auto score) { return topk.would_enter(score); }); @@ -249,7 +249,8 @@ struct MaxscoreAnalyzer { void operator()(Query const& query) { - if (query.terms.empty()) { + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { return; } using cursor_type = decltype(m_index.max_scored_cursor(0, m_scorer)); @@ -260,14 +261,14 @@ struct MaxscoreAnalyzer { m_current_lookups = 0; std::vector cursors; - std::transform(query.terms.begin(), - query.terms.end(), + std::transform(term_ids.begin(), + term_ids.end(), std::back_inserter(cursors), [&](auto term) { return m_index.max_scored_cursor(term, m_scorer); }); std::size_t inserts = 0; - topk_queue topk(query.k); - auto initial_threshold = query.threshold.value_or(-1.0); + topk_queue topk(query.k()); + auto initial_threshold = query.threshold().value_or(-1.0); topk.set_threshold(initial_threshold); auto joined = join_maxscore( std::move(cursors), diff --git a/include/pisa/v1/query.hpp b/include/pisa/v1/query.hpp index a5507e39d..7c0a071b5 100644 --- a/include/pisa/v1/query.hpp +++ b/include/pisa/v1/query.hpp @@ -24,129 +24,136 @@ struct ListSelection { std::vector> bigrams{}; }; -struct Query { - std::vector terms; - tl::optional list_selection{}; - tl::optional threshold{}; - tl::optional id{}; - 
int k{}; - - void add_selections(gsl::span const> selections); - void remove_duplicates(); - - private: - auto resolve_term(std::size_t pos) -> TermId; -}; +template +std::ostream& operator<<(std::ostream& os, std::pair const& p) +{ + return os << '(' << p.first << ", " << p.second << ')'; +} -template -auto daat_and(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) +template +std::ostream& operator<<(std::ostream& os, std::vector const& vec) { - std::vector cursors; - std::transform(query.terms.begin(), - query.terms.end(), - std::back_inserter(cursors), - [&](auto term) { return index.scored_cursor(term, scorer); }); - auto intersection = - v1::intersect(std::move(cursors), 0.0F, [](auto& score, auto& cursor, auto /* term_idx */) { - score += cursor.payload(); - return score; - }); - v1::for_each(intersection, [&](auto& cursor) { topk.insert(cursor.payload(), *cursor); }); - return topk; + auto pos = vec.begin(); + os << '['; + if (pos != vec.end()) { + os << *pos++; + } + for (; pos != vec.end(); ++pos) { + os << ' ' << *pos; + } + os << ']'; + return os; } -template -auto daat_or(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) +inline std::ostream& operator<<(std::ostream& os, ListSelection const& selection) { - std::vector cursors; - std::transform(query.terms.begin(), - query.terms.end(), - std::back_inserter(cursors), - [&](auto term) { return index.scored_cursor(term, scorer); }); - auto cunion = v1::union_merge( - std::move(cursors), 0.0F, [](auto& score, auto& cursor, auto /* term_idx */) { - score += cursor.payload(); - return score; - }); - v1::for_each(cunion, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); - return topk; + return os << "ListSelection { unigrams: " << selection.unigrams + << ", bigrams: " << selection.bigrams << " }"; } -template -struct DaatOrAnalyzer { - DaatOrAnalyzer(Index const& index, Scorer scorer) : m_index(index), m_scorer(std::move(scorer)) +struct 
TermIdSet { + explicit TermIdSet(std::vector terms) : m_term_list(std::move(terms)) { - std::cout << fmt::format("documents\tpostings\n"); + m_term_set = m_term_list; + ranges::sort(m_term_set); + ranges::actions::unique(m_term_set); + std::size_t pos = 0; + for (auto term_id : m_term_set) { + m_sorted_positions[term_id] = pos++; + } } - void operator()(Query const& query) + [[nodiscard]] auto sorted_position(TermId term) const -> std::size_t { - std::vector cursors; - std::transform(query.terms.begin(), - query.terms.end(), - std::back_inserter(cursors), - [&](auto term) { return m_index.scored_cursor(term, m_scorer); }); - std::size_t postings = 0; - auto cunion = v1::union_merge( - std::move(cursors), 0.0F, [&](auto& score, auto& cursor, auto /* term_idx */) { - postings += 1; - score += cursor.payload(); - return score; - }); - std::size_t documents = 0; - std::size_t inserts = 0; - topk_queue topk(query.k); - v1::for_each(cunion, [&](auto& cursor) { - if (topk.insert(cursor.payload(), cursor.value())) { - inserts += 1; - }; - documents += 1; - }); - std::cout << fmt::format("{}\t{}\t{}\n", documents, postings, inserts); - m_documents += documents; - m_postings += postings; - m_inserts += inserts; - m_count += 1; + return m_sorted_positions.at(term); } - void summarize() && + [[nodiscard]] auto term_at_pos(std::size_t pos) const -> TermId { - std::cerr << fmt::format( - "=== SUMMARY ===\nAverage:\n- documents:\t{}\n- postings:\t{}\n- inserts:\t{}\n", - static_cast(m_documents) / m_count, - static_cast(m_postings) / m_count, - static_cast(m_inserts) / m_count); + if (pos >= m_term_list.size()) { + throw std::out_of_range("Invalid intersections: term position out of bounds"); + } + return m_term_list[pos]; } + [[nodiscard]] auto get() const -> std::vector const& { return m_term_set; } + private: - std::size_t m_documents = 0; - std::size_t m_postings = 0; - std::size_t m_inserts = 0; - std::size_t m_count = 0; - Index const& m_index; - Scorer m_scorer; + friend 
std::ostream& operator<<(std::ostream& os, TermIdSet const& term_ids); + std::vector m_term_list; + std::vector m_term_set{}; + std::unordered_map m_sorted_positions{}; }; -template -auto taat_or(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) +inline std::ostream& operator<<(std::ostream& os, TermIdSet const& term_ids) { - std::vector accumulator(index.num_documents(), 0.0F); - for (auto term : query.terms) { - v1::for_each(index.scored_cursor(term, scorer), - [&accumulator](auto&& cursor) { accumulator[*cursor] += cursor.payload(); }); - } - for (auto document = 0; document < accumulator.size(); document += 1) { - topk.insert(accumulator[document], document); - } - return topk; + return os << "TermIdSet { original: " << term_ids.m_term_list + << ", unique: " << term_ids.m_term_set << " }"; } -/// Returns only unique terms, in sorted order. -[[nodiscard]] auto filter_unique_terms(Query const& query) -> std::vector; +struct Query { + Query() = default; + + explicit Query(std::string query, tl::optional id = tl::nullopt); + explicit Query(std::vector term_ids, tl::optional id = tl::nullopt); + + /// Setters for optional values (or ones with default value). 
+ auto term_ids(std::vector term_ids) -> Query&; + auto id(std::string) -> Query&; + auto k(int k) -> Query&; + auto selections(gsl::span const> selections) -> Query&; + auto selections(ListSelection selections) -> Query&; + auto threshold(float threshold) -> Query&; + + /// Non-throwing getters + [[nodiscard]] auto term_ids() const -> tl::optional const&>; + [[nodiscard]] auto id() const -> tl::optional const&; + [[nodiscard]] auto k() const -> int; + [[nodiscard]] auto selections() const -> tl::optional; + [[nodiscard]] auto threshold() const -> tl::optional; + + /// Throwing getters + [[nodiscard]] auto get_term_ids() const -> std::vector const&; + [[nodiscard]] auto get_id() const -> std::string const&; + [[nodiscard]] auto get_selections() const -> ListSelection const&; + [[nodiscard]] auto get_threshold() const -> float; + + [[nodiscard]] auto sorted_position(TermId term) const -> std::size_t; + [[nodiscard]] auto term_at_pos(std::size_t pos) const -> TermId; + + void add_selections(gsl::span const> selections); + [[nodiscard]] static auto from_json(std::string_view) -> Query; + + private: + friend std::ostream& operator<<(std::ostream& os, Query const& query); + auto resolve_term(std::size_t pos) -> TermId; + + tl::optional m_term_ids{}; + tl::optional m_selections{}; + tl::optional m_threshold{}; + tl::optional m_id{}; + tl::optional m_raw_string; + int m_k = 1000; +}; -template -auto transform() +template +std::ostream& operator<<(std::ostream& os, tl::optional const& value) { + if (not value) { + os << "None"; + } else { + os << "Some(" << *value << ")"; + } + return os; +} + +inline std::ostream& operator<<(std::ostream& os, Query const& query) +{ + return os << "Query { term_ids: " << query.m_term_ids << ", selections: " << query.m_selections + << " }"; } +/// Returns only unique terms, in sorted order. 
+[[nodiscard]] auto filter_unique_terms(Query const& query) -> std::vector; + } // namespace pisa::v1 diff --git a/include/pisa/v1/taat_or.hpp b/include/pisa/v1/taat_or.hpp new file mode 100644 index 000000000..d1e7e1f46 --- /dev/null +++ b/include/pisa/v1/taat_or.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include + +#include "topk_queue.hpp" +#include "v1/cursor/for_each.hpp" +#include "v1/query.hpp" + +namespace pisa::v1 { + +template +auto taat_or(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) +{ + std::vector accumulator(index.num_documents(), 0.0F); + for (auto term : query.get_term_ids()) { + v1::for_each(index.scored_cursor(term, scorer), + [&accumulator](auto&& cursor) { accumulator[*cursor] += cursor.payload(); }); + } + for (auto document = 0; document < accumulator.size(); document += 1) { + topk.insert(accumulator[document], document); + } + return topk; +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/union_lookup.hpp b/include/pisa/v1/union_lookup.hpp index 986c76db9..8c7245f3c 100644 --- a/include/pisa/v1/union_lookup.hpp +++ b/include/pisa/v1/union_lookup.hpp @@ -62,33 +62,34 @@ namespace detail { } // namespace detail template -auto unigram_union_lookup( - Query query, Index const& index, topk_queue topk, Scorer&& scorer, Analyzer* analyzer = nullptr) +auto unigram_union_lookup(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + Analyzer* analyzer = nullptr) { - if (query.terms.empty()) { + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { return topk; } - if (not query.threshold) { + if (not query.threshold()) { throw std::invalid_argument("Must provide threshold to the query"); } - if (not query.list_selection) { + if (not query.selections()) { throw std::invalid_argument("Must provide essential list selection"); } - if (not query.list_selection->bigrams.empty()) { + if (not query.selections()->bigrams.empty()) { throw std::invalid_argument("This algorithm only 
supports unigrams"); } + auto const& selections = query.get_selections(); - topk.set_threshold(*query.threshold); + topk.set_threshold(*query.threshold()); using cursor_type = decltype(index.max_scored_cursor(0, scorer)); using payload_type = decltype(std::declval().payload()); - ranges::sort(query.terms); - ranges::sort(query.list_selection->unigrams); - auto non_essential_terms = - ranges::views::set_difference(query.terms, query.list_selection->unigrams) - | ranges::to_vector; + ranges::views::set_difference(term_ids, selections.unigrams) | ranges::to_vector; std::vector cursors; for (auto non_essential_term : non_essential_terms) { @@ -98,7 +99,7 @@ auto unigram_union_lookup( std::sort(cursors.begin(), cursors.end(), [](auto&& lhs, auto&& rhs) { return lhs.max_score() < rhs.max_score(); }); - for (auto essential_term : query.list_selection->unigrams) { + for (auto essential_term : selections.unigrams) { cursors.push_back(index.max_scored_cursor(essential_term, scorer)); } @@ -121,19 +122,17 @@ auto maxscore_union_lookup(Query const& query, Scorer&& scorer, Analyzer* analyzer = nullptr) { - if (query.terms.empty()) { + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { return topk; } - if (not query.threshold) { - throw std::invalid_argument("Must provide threshold to the query"); - } - - topk.set_threshold(*query.threshold); + auto threshold = query.get_threshold(); + topk.set_threshold(threshold); using cursor_type = decltype(index.max_scored_cursor(0, scorer)); using payload_type = decltype(std::declval().payload()); - auto cursors = index.max_scored_cursors(gsl::make_span(query.terms), scorer); + auto cursors = index.max_scored_cursors(gsl::make_span(term_ids), scorer); std::sort(cursors.begin(), cursors.end(), [](auto&& lhs, auto&& rhs) { return lhs.max_score() < rhs.max_score(); }); @@ -144,8 +143,7 @@ auto maxscore_union_lookup(Query const& query, upper_bounds[i] = upper_bounds[i - 1] + cursors[i].max_score(); } std::size_t 
non_essential_count = 0; - while (non_essential_count < cursors.size() - && upper_bounds[non_essential_count] < *query.threshold) { + while (non_essential_count < cursors.size() && upper_bounds[non_essential_count] < threshold) { non_essential_count += 1; } return detail::unigram_union_lookup(std::move(cursors), @@ -175,14 +173,15 @@ struct BaseUnionLookupAnalyzer { void operator()(Query const& query) { - if (query.terms.empty()) { + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { return; } using cursor_type = decltype(m_index.max_scored_cursor(0, m_scorer)); using value_type = decltype(m_index.max_scored_cursor(0, m_scorer).value()); reset_current(); - run(query, m_index, m_scorer, topk_queue(query.k)); + run(query, m_index, m_scorer, topk_queue(query.k())); std::cout << fmt::format("{}\t{}\t{}\t{}\n", m_current_documents, m_current_postings, @@ -258,7 +257,11 @@ struct UnionLookupAnalyzer : public BaseUnionLookupAnalyzer { } void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override { - union_lookup(query, index, std::move(topk), scorer, this); + if (query.get_term_ids().size() > 8) { + maxscore_union_lookup(query, index, std::move(topk), scorer, this); + } else { + union_lookup(query, index, std::move(topk), scorer, this); + } } }; @@ -276,38 +279,33 @@ struct UnionLookupAnalyzer : public BaseUnionLookupAnalyzer { /// posting lists. These must exist in the index, or else this /// algorithm will fail. 
template -auto union_lookup( - Query query, Index const& index, topk_queue topk, Scorer&& scorer, Analyzer* analyzer = nullptr) +auto union_lookup(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + Analyzer* analyzer = nullptr) { - if (query.terms.empty()) { + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { return topk; } - if (query.terms.size() > 8) { + if (term_ids.size() > 8) { throw std::invalid_argument( "Generic version of union-Lookup supported only for queries of length <= 8"); } - if (not query.threshold) { - throw std::invalid_argument("Must provide threshold to the query"); - } - if (not query.list_selection) { - throw std::invalid_argument("Must provide essential list selection"); - } - using bigram_cursor_type = std::decay_t; + auto threshold = query.get_threshold(); + auto const& selections = query.get_selections(); - auto& essential_unigrams = query.list_selection->unigrams; - auto& essential_bigrams = query.list_selection->bigrams; + using bigram_cursor_type = std::decay_t; - ranges::sort(essential_unigrams); - ranges::actions::unique(essential_unigrams); - ranges::sort(essential_bigrams); - ranges::actions::unique(essential_bigrams); - ranges::sort(query.terms); + auto& essential_unigrams = selections.unigrams; + auto& essential_bigrams = selections.bigrams; auto non_essential_terms = - ranges::views::set_difference(query.terms, essential_unigrams) | ranges::to_vector; + ranges::views::set_difference(term_ids, essential_unigrams) | ranges::to_vector; - topk.set_threshold(*query.threshold); + topk.set_threshold(threshold); std::array initial_payload{ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; //, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; @@ -318,50 +316,22 @@ auto union_lookup( std::back_inserter(essential_unigram_cursors), [&](auto term) { return index.scored_cursor(term, scorer); }); - std::vector essential_unigram_positions; - [&]() { - auto pos = query.terms.begin(); - for (auto term : 
essential_unigrams) { - pos = std::find(pos, query.terms.end(), term); - assert(pos != query.terms.end()); - essential_unigram_positions.push_back(std::distance(query.terms.begin(), pos)); - } - }(); - auto merged_unigrams = v1::union_merge( essential_unigram_cursors, initial_payload, [&](auto& acc, auto& cursor, auto term_idx) { if constexpr (not std::is_void_v) { analyzer->posting(); } - acc[essential_unigram_positions[term_idx]] = cursor.payload(); + acc[query.sorted_position(essential_unigrams[term_idx])] = cursor.payload(); return acc; }); std::vector essential_bigram_cursors; - std::vector> essential_bigram_positions; for (auto [left, right] : essential_bigrams) { - if (left > right) { - std::swap(left, right); - } auto cursor = index.scored_bigram_cursor(left, right, scorer); if (not cursor) { throw std::runtime_error(fmt::format("Bigram not found: <{}, {}>", left, right)); } essential_bigram_cursors.push_back(cursor.take().value()); - std::pair bigram_positions; - if (auto pos = std::lower_bound(query.terms.begin(), query.terms.end(), left); - pos != query.terms.end()) { - bigram_positions.first = std::distance(query.terms.begin(), pos); - } else { - throw std::logic_error("Term from selected intersection not part of query"); - } - if (auto pos = std::lower_bound(query.terms.begin(), query.terms.end(), right); - pos != query.terms.end()) { - bigram_positions.second = std::distance(query.terms.begin(), pos); - } else { - throw std::logic_error("Term from selected intersection not part of query"); - } - essential_bigram_positions.push_back(bigram_positions); } auto merged_bigrams = v1::union_merge( @@ -372,8 +342,8 @@ auto union_lookup( analyzer->posting(); } auto payload = cursor.payload(); - acc[essential_bigram_positions[bigram_idx].first] = std::get<0>(payload); - acc[essential_bigram_positions[bigram_idx].second] = std::get<1>(payload); + acc[query.sorted_position(essential_bigrams[bigram_idx].first)] = std::get<0>(payload); + 
acc[query.sorted_position(essential_bigrams[bigram_idx].second)] = std::get<1>(payload); return acc; }); @@ -394,11 +364,11 @@ auto union_lookup( auto lookup_cursors = [&]() { std::vector> lookup_cursors; - auto pos = query.terms.begin(); + auto pos = term_ids.begin(); for (auto non_essential_term : non_essential_terms) { - pos = std::find(pos, query.terms.end(), non_essential_term); - assert(pos != query.terms.end()); - auto idx = std::distance(query.terms.begin(), pos); + pos = std::find(pos, term_ids.end(), non_essential_term); + assert(pos != term_ids.end()); + auto idx = std::distance(term_ids.begin(), pos); lookup_cursors.emplace_back(idx, index.max_scored_cursor(non_essential_term, scorer)); } return lookup_cursors; @@ -411,7 +381,7 @@ auto union_lookup( return acc + cursor.second.max_score(); }); - v1::for_each(merged, [&](auto& cursor) { + v1::for_each(merged_unigrams, [&](auto& cursor) { if constexpr (not std::is_void_v) { analyzer->document(); } diff --git a/src/compute_intersection.cpp b/src/compute_intersection.cpp index d137184bc..74065084c 100644 --- a/src/compute_intersection.cpp +++ b/src/compute_intersection.cpp @@ -42,9 +42,16 @@ void intersect(std::string const& index_filename, mapper::map(wdata, md, mapper::map_flags::warmup); } - std::size_t qid = 0u; + std::size_t qid = 0U; auto print_intersection = [&](auto const& query, auto const& mask) { + // FIXME(michal): Quick workaround to not compute intersection (a, a) + auto filtered = intersection::filter(query, mask); + std::sort(filtered.terms.begin(), filtered.terms.end()); + if (std::unique(filtered.terms.begin(), filtered.terms.end()) != filtered.terms.end()) { + // Do not compute: contains duplicates + return; + } auto intersection = Intersection::compute(index, wdata, query, mask); if (intersection.length > 0) { std::cout << fmt::format("{}\t{}\t{}\t{}\n", diff --git a/src/v1/index_builder.cpp b/src/v1/index_builder.cpp index 757062560..8ba9469e0 100644 --- a/src/v1/index_builder.cpp +++ 
b/src/v1/index_builder.cpp @@ -13,14 +13,14 @@ auto collect_unique_bigrams(std::vector const& queries, std::vector> bigrams; auto idx = 0; for (auto query : queries) { - if (query.terms.empty()) { + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { continue; } callback(); - std::sort(query.terms.begin(), query.terms.end()); - for (auto left = 0; left < query.terms.size() - 1; left += 1) { - for (auto right = left + 1; right < query.terms.size(); right += 1) { - bigrams.emplace_back(query.terms[left], query.terms[right]); + for (auto left = 0; left < term_ids.size() - 1; left += 1) { + for (auto right = left + 1; right < term_ids.size(); right += 1) { + bigrams.emplace_back(term_ids[left], term_ids[right]); } } } diff --git a/src/v1/query.cpp b/src/v1/query.cpp index ec85b0fa0..170574270 100644 --- a/src/v1/query.cpp +++ b/src/v1/query.cpp @@ -1,12 +1,15 @@ +#include #include #include "v1/query.hpp" namespace pisa::v1 { +using json = nlohmann::json; + [[nodiscard]] auto filter_unique_terms(Query const& query) -> std::vector { - auto terms = query.terms; + auto terms = query.get_term_ids(); std::sort(terms.begin(), terms.end()); terms.erase(std::unique(terms.begin(), terms.end()), terms.end()); return terms; @@ -14,39 +17,161 @@ namespace pisa::v1 { void Query::add_selections(gsl::span const> selections) { - list_selection = ListSelection{}; + m_selections = ListSelection{}; for (auto intersection : selections) { if (intersection.count() > 2) { throw std::invalid_argument("Intersections of more than 2 terms not supported yet!"); } auto positions = to_vector(intersection); if (positions.size() == 1) { - list_selection->unigrams.push_back(resolve_term(positions.front())); + m_selections->unigrams.push_back(resolve_term(positions.front())); } else { - list_selection->bigrams.emplace_back(resolve_term(positions[0]), - resolve_term(positions[1])); + m_selections->bigrams.emplace_back(resolve_term(positions[0]), + resolve_term(positions[1])); + auto& 
added = m_selections->bigrams.back(); + if (added.first > added.second) { + std::swap(added.first, added.second); + } } } } -void Query::remove_duplicates() +auto Query::resolve_term(std::size_t pos) -> TermId { - ranges::sort(terms); - ranges::actions::unique(terms); - if (list_selection) { - ranges::sort(list_selection->unigrams); - ranges::actions::unique(list_selection->unigrams); - ranges::sort(list_selection->bigrams); - ranges::actions::unique(list_selection->bigrams); + if (not m_term_ids) { + throw std::runtime_error("Term IDs are not set"); } + return m_term_ids->term_at_pos(pos); } -auto Query::resolve_term(std::size_t pos) -> TermId +template +[[nodiscard]] auto get(json const& node, std::string_view field) -> tl::optional +{ + if (auto pos = node.find(field); pos != node.end()) { + return tl::make_optional(pos->get()); + } + return tl::optional{}; +} + +[[nodiscard]] auto Query::from_json(std::string_view json_string) -> Query +{ + // auto query_json = json::parse(json_string); + // auto terms = get>(query_json, "terms"); + // auto term_ids = get>(query_json, "term_ids"); + Query query; + // query.m_raw_string = get(query_json, "query"); + return query; +} + +Query::Query(std::vector term_ids, tl::optional id) + : m_term_ids(std::move(term_ids)), m_id(std::move(id)) +{ +} + +Query::Query(std::string query, tl::optional id) + : m_raw_string(std::move(query)), m_id(std::move(id)) +{ +} + +auto Query::term_ids(std::vector term_ids) -> Query& +{ + m_term_ids = TermIdSet(std::move(term_ids)); + return *this; +} + +auto Query::id(std::string id) -> Query& +{ + m_id = std::move(id); + return *this; +} + +auto Query::k(int k) -> Query& +{ + m_k = k; + return *this; +} + +auto Query::selections(gsl::span const> selections) -> Query& +{ + add_selections(selections); + return *this; +} + +auto Query::selections(ListSelection selections) -> Query& +{ + m_selections = std::move(selections); + return *this; +} + +auto Query::threshold(float threshold) -> Query& +{ + 
m_threshold = threshold; + return *this; +} + +auto Query::term_ids() const -> tl::optional const&> +{ + return m_term_ids.map( + [](auto const& terms) -> std::vector const& { return terms.get(); }); +} +auto Query::id() const -> tl::optional const& { return m_id; } +auto Query::k() const -> int { return m_k; } +auto Query::selections() const -> tl::optional +{ + if (m_selections) { + return *m_selections; + } + return tl::nullopt; +} +auto Query::threshold() const -> tl::optional { return m_threshold; } + +/// Throwing getters +auto Query::get_term_ids() const -> std::vector const& +{ + if (not m_term_ids) { + throw std::runtime_error("Term IDs are not set"); + } + return m_term_ids->get(); +} + +auto Query::get_id() const -> std::string const& +{ + if (not m_id) { + throw std::runtime_error("ID is not set"); + } + return *m_id; +} + +auto Query::get_selections() const -> ListSelection const& +{ + if (not m_selections) { + throw std::runtime_error("Selections are not set"); + } + return *m_selections; +} + +auto Query::get_threshold() const -> float +{ + if (not m_threshold) { + throw std::runtime_error("Threshold is not set"); + } + return *m_threshold; +} + +auto Query::sorted_position(TermId term) const -> std::size_t +{ + if (not m_term_ids) { + throw std::runtime_error("Term IDs are not set"); + } + return m_term_ids->sorted_position(term); +} + +auto Query::term_at_pos(std::size_t pos) const -> TermId { - if (pos >= terms.size()) { - throw std::out_of_range("Invalid intersections: term position out of bounds"); + if (not m_term_ids) { + throw std::runtime_error("Term IDs are not set"); } - return terms[pos]; + return m_term_ids->term_at_pos(pos); } } // namespace pisa::v1 diff --git a/test/test_data/top10_selections_unigram b/test/test_data/top10_selections_unigram new file mode 100644 index 000000000..b6551bb17 --- /dev/null +++ b/test/test_data/top10_selections_unigram @@ -0,0 +1,500 @@ +1 +1 +8 16 +1 2 4 +1 +1 +2 +16 32 64 256 +1 2 +2 +1 +1 2 +2 +1 4 +4 8 
+1 +1 2 +1 +2 +4 +1 +1 +1 +1 +1 +1 2 +8 +2 +1 2 +1 4 +2 8 +1 8 +2 +1 2 +2 4 +1 +1 2 +1 +2 +2 8 32 128 +4 +1 4 +1 +1 2 +2 4 +2 4 +1 2 4 +1 2 4 +1 4 8 +1 4 +1 2 +2 +1 +1 8 64 +1 +4 +1 4 +1 +1 2 4 +2 +4 8 16 +1 2 +1 2 16 64 +1 2 +1 2 +1 2 4 +1 16 64 128 +1 16 32 +2 4 +4 8 +2 4 +4 8 +1 2 +1 2 +1 2 +1 +8 64 +1 2 +1 2 +8 32 +1 2 +2 4 +2 +1 4 +1 2 +1 32 64 +1 +1 4 +4 8 16 +8 +1 4 16 +1 2 +1 2 4 16 +1 2 8 +2 4 +2 +1 4 +1 2 +1 +1 2 +2 +1 2 +4 32 64 +2 4 +4 +1 +1 2 +4 64 +2 +1 8 +1 2 +1 +1 4 +1 2 +1 2 +1 +1 2 16 128 +4 8 +1 2 4 +2 4 +1 2 +1 2 8 +2 8 +1 2 4 +1 2 +1 2 +2 4 8 +4 64 128 +1 2 +1 +1 2 +1 2 +1 2 +4 +1 4 +1 2 4 +1 2 +2 +1 2 +2 +1 4 +2 8 +1 4 8 +1 +2 4 +4 8 +1 2 +1 4 +1 +1 2 +2 8 +1 +2 4 +1 4 +1 2 +1 2 +4 8 +1 2 +2 4 +2 4 +1 4 +1 4 +1 2 +4 8 +1 64 128 +1 +2 4 +1 2 4 +1 4 +2 8 +1 +1 2 +1 8 16 +1 +1 2 +1 2 +2 +1 2 +1 +2 +1 +1 +1 2 +4 16 +16 32 +1 4 +4 8 16 32 +1 8 +1 2 +1 +1 2 +1 8 +1 2 +1 +4 +1 +1 2 +1 2 +1 +1 2 +1 2 4 +1 2 +1 2 +1 2 +1 4 +1 4 +2 8 16 +1 2 8 16 +1 4 +1 2 +1 4 32 +1 2 +1 2 +2 +1 4 +8 +1 +1 2 +1 2 +1 +1 +2 4 8 +1 2 +1 2 +1 2 +1 8 +1 2 4 +1 2 16 +4 +1 4 +2 +1 2 +1 +1 2 +1 +1 4 32 +1 2 +1 +4 +2 4 8 32 +1 +1 +1 4 8 +1 +2 +1 2 +1 2 4 +2 +2 8 +1 +1 4 +2 +1 +2 +1 +1 +1 +4 +2 +2 4 +1 2 4 +1 +1 2 +16 +1 4 +2 +1 +1 2 +1 +1 2 +1 2 4 +1 +1 16 +2 +1 2 +1 4 8 +2 +1 2 +16 +1 +1 +2 +16 +1 4 8 +1 8 +1 4 +4 8 +1 2 4 +1 +1 8 +1 8 16 +4 +1 2 +1 2 +1 2 +1 +1 +1 2 8 +1 4 +1 2 8 +1 2 4 16 +2 +2 +1 2 +1 4 +1 2 4 8 +1 2 +1 4 +1 +1 4 16 +1 +1 +1 4 +2 8 +1 2 +1 +1 8 +2 4 +1 2 4 +1 2 8 +1 4 +1 2 +1 2 +2 4 16 +1 +1 +1 2 +2 +2 +1 +1 4 +1 2 +4 +8 16 64 128 +2 +1 4 +1 2 4 +1 2 8 +2 +2 +1 4 +2 4 +2 +1 2 4 +1 2 4 +1 2 +1 2 +1 +1 +1 +2 +1 2 +1 +1 +1 2 +1 +1 +1 2 +1 +1 2 +16 32 +1 8 +1 4 +1 2 +1 +2 8 16 +2 +1 2 +1 2 +1 +1 +1 +1 2 +1 2 +1 2 +1 +2 +1 +1 2 +1 +1 8 16 +2 8 +1 +1 4 +1 4 +1 +1 4 +4 8 16 +2 4 +1 4 64 +2 +2 16 +2 4 16 +1 2 +1 2 8 +2 4 8 +1 +1 2 +1 +4 8 16 64 +2 +1 2 4 +2 16 64 +1 +2 +1 4 16 64 +1 +1 +1 +1 +1 2 +2 +16 32 +1 4 64 +1 +1 2 +2 4 +2 +1 2 +1 +1 4 8 +2 +1 2 4 +4 8 64 +1 
2 +1 16 32 +1 8 +1 4 +2 +1 +1 2 +4 +1 2 +2 4 8 +1 32 +1 2 +1 +1 2 +2 +2 4 +8 32 256 +1 +4 8 +4 8 16 +1 2 +1 16 +1 2 +1 2 +2 4 16 +1 2 +1 +1 2 4 16 32 +8 +2 8 +1 2 4 +1 +1 8 +4 8 +1 2 4 32 +1 2 +2 +1 4 +1 2 4 +4 8 +1 4 +1 4 +1 2 +1 2 +1 2 +1 2 +1 2 +2 4 +1 +1 2 +1 +2 +1 +8 16 32 +2 +1 +1 2 8 +1 2 +2 +4 +1 2 +2 +1 2 +1 +4 +1 4 +1 2 +1 2 +1 +1 +2 +1 +1 +1 2 8 +1 +1 diff --git a/test/v1/index_fixture.hpp b/test/v1/index_fixture.hpp index efc6cdf40..753014c98 100644 --- a/test/v1/index_fixture.hpp +++ b/test/v1/index_fixture.hpp @@ -20,8 +20,9 @@ namespace v1 = pisa::v1; std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); auto push_query = [&](std::string const& query_line) { auto q = pisa::parse_query_ids(query_line); - queries.push_back( - v1::Query{.terms = q.terms, .list_selection = {}, .threshold = {}, .id = {}, .k = 10}); + v1::Query query(q.terms); + query.k(1000); + queries.push_back(std::move(query)); }; pisa::io::for_each_line(qfile, push_query); return queries; @@ -59,17 +60,7 @@ struct IndexFixture { index_basename); REQUIRE(errors.empty()); auto yml = fmt::format("{}.yml", index_basename); - - auto queries = [&]() { - std::vector queries; - auto qs = test_queries(); - int idx = 0; - std::transform(qs.begin(), qs.end(), std::back_inserter(queries), [&](auto q) { - return v1::Query{q.terms}; - }); - return queries; - }(); - v1::build_bigram_index(yml, collect_unique_bigrams(queries, []() {})); + v1::build_bigram_index(yml, collect_unique_bigrams(test_queries(), []() {})); v1::score_index(yml, 1); } diff --git a/test/v1/test_v1_bigram_index.cpp b/test/v1/test_v1_bigram_index.cpp index e2028bb43..6b982b173 100644 --- a/test/v1/test_v1_bigram_index.cpp +++ b/test/v1/test_v1_bigram_index.cpp @@ -54,19 +54,17 @@ TEMPLATE_TEST_CASE("Bigram v intersection", auto meta = v1::IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); int idx = 0; for (auto& q : test_queries()) { - std::sort(q.terms.begin(), q.terms.end()); - 
q.terms.erase(std::unique(q.terms.begin(), q.terms.end()), q.terms.end()); - CAPTURE(q.terms); + CAPTURE(q.get_term_ids()); CAPTURE(idx++); auto run = v1::index_runner(meta, fixture.document_reader(), fixture.frequency_reader()); std::vector results; run([&](auto&& index) { auto scorer = make_bm25(index); - for (auto left = 0; left < q.terms.size(); left += 1) { - for (auto right = left + 1; right < q.terms.size(); right += 1) { - auto left_cursor = index.cursor(q.terms[left]); - auto right_cursor = index.cursor(q.terms[right]); + for (auto left = 0; left < q.get_term_ids().size(); left += 1) { + for (auto right = left + 1; right < q.get_term_ids().size(); right += 1) { + auto left_cursor = index.cursor(q.get_term_ids()[left]); + auto right_cursor = index.cursor(q.get_term_ids()[right]); auto intersection = v1::intersect({left_cursor, right_cursor}, std::array{0, 0}, [](auto& acc, auto&& cursor, auto idx) { @@ -74,7 +72,8 @@ TEMPLATE_TEST_CASE("Bigram v intersection", return acc; }); if (not intersection.empty()) { - auto bigram_cursor = *index.bigram_cursor(q.terms[left], q.terms[right]); + auto bigram_cursor = + *index.bigram_cursor(q.get_term_ids()[left], q.get_term_ids()[right]); std::vector bigram_documents; std::vector bigram_frequencies_0; std::vector bigram_frequencies_1; diff --git a/test/v1/test_v1_maxscore_join.cpp b/test/v1/test_v1_maxscore_join.cpp index 072d23b7e..cba225954 100644 --- a/test/v1/test_v1_maxscore_join.cpp +++ b/test/v1/test_v1_maxscore_join.cpp @@ -63,16 +63,18 @@ TEMPLATE_TEST_CASE("", //}; int idx = 0; for (auto& q : test_queries()) { - CAPTURE(q.terms); + CAPTURE(q.get_term_ids()); CAPTURE(idx++); auto run = v1::index_runner(meta, fixture.document_reader(), fixture.frequency_reader()); run([&](auto&& index) { auto union_results = collect(v1::union_merge( - index.scored_cursors(gsl::make_span(q.terms), make_bm25(index)), 0.0F, Add{})); + index.scored_cursors(gsl::make_span(q.get_term_ids()), make_bm25(index)), + 0.0F, + Add{})); auto 
maxscore_results = collect(v1::join_maxscore( - index.max_scored_cursors(gsl::make_span(q.terms), make_bm25(index)), + index.max_scored_cursors(gsl::make_span(q.get_term_ids()), make_bm25(index)), 0.0F, Add{}, [](auto /* score */) { return true; })); @@ -81,13 +83,15 @@ TEMPLATE_TEST_CASE("", run([&](auto&& index) { auto union_results = collect_with_payload(v1::union_merge( - index.scored_cursors(gsl::make_span(q.terms), make_bm25(index)), 0.0F, Add{})); + index.scored_cursors(gsl::make_span(q.get_term_ids()), make_bm25(index)), + 0.0F, + Add{})); union_results.erase(std::remove_if(union_results.begin(), union_results.end(), [](auto score) { return score.second <= 5.0F; }), union_results.end()); auto maxscore_results = collect_with_payload(v1::join_maxscore( - index.max_scored_cursors(gsl::make_span(q.terms), make_bm25(index)), + index.max_scored_cursors(gsl::make_span(q.get_term_ids()), make_bm25(index)), 0.0F, Add{}, [](auto score) { return score > 5.0F; })); diff --git a/test/v1/test_v1_queries.cpp b/test/v1/test_v1_queries.cpp index ccacf0a78..264dbf09f 100644 --- a/test/v1/test_v1_queries.cpp +++ b/test/v1/test_v1_queries.cpp @@ -23,6 +23,8 @@ #include "v1/cursor_intersection.hpp" #include "v1/cursor_traits.hpp" #include "v1/cursor_union.hpp" +#include "v1/daat_and.hpp" +#include "v1/daat_or.hpp" #include "v1/index.hpp" #include "v1/index_builder.hpp" #include "v1/maxscore.hpp" @@ -31,6 +33,7 @@ #include "v1/query.hpp" #include "v1/score_index.hpp" #include "v1/scorer/bm25.hpp" +#include "v1/taat_or.hpp" #include "v1/types.hpp" #include "v1/union_lookup.hpp" @@ -149,12 +152,11 @@ TEMPLATE_TEST_CASE("Query", return maxscore_union_lookup(query, index, topk_queue(10), scorer); } if (name == "unigram_union_lookup") { - query.list_selection = - tl::make_optional(v1::ListSelection{.unigrams = query.terms, .bigrams = {}}); + query.selections(v1::ListSelection{.unigrams = query.get_term_ids(), .bigrams = {}}); return unigram_union_lookup(query, index, topk_queue(10), 
scorer); } if (name == "union_lookup") { - if (query.terms.size() > 8) { + if (query.get_term_ids().size() > 8) { return maxscore_union_lookup(query, index, topk_queue(10), scorer); } return union_lookup(query, index, topk_queue(10), scorer); @@ -166,19 +168,19 @@ TEMPLATE_TEST_CASE("Query", pisa::v1::read_intersections(PISA_SOURCE_DIR "/test/test_data/top10_selections"); for (auto& query : test_queries()) { if (algorithm == "union_lookup") { - query.add_selections(gsl::make_span(intersections[idx])); + query.selections(gsl::make_span(intersections[idx])); } - query.remove_duplicates(); - CAPTURE(query.terms); + CAPTURE(query); CAPTURE(idx); CAPTURE(intersections[idx]); - or_q(make_scored_cursors(data->v0_index, data->wdata, ::pisa::Query{{}, query.terms, {}}), + or_q(make_scored_cursors( + data->v0_index, data->wdata, ::pisa::Query{{}, query.get_term_ids(), {}}), data->v0_index.num_docs()); auto expected = or_q.topk(); if (with_threshold) { - query.threshold = expected.back().first - 1.0F; + query.threshold(expected.back().first - 1.0F); } auto on_the_fly = [&]() { diff --git a/v1/CMakeLists.txt b/v1/CMakeLists.txt index 3fefab289..ebac433fe 100644 --- a/v1/CMakeLists.txt +++ b/v1/CMakeLists.txt @@ -12,3 +12,6 @@ target_link_libraries(score pisa CLI11) add_executable(bigram-index bigram_index.cpp) target_link_libraries(bigram-index pisa CLI11) + +add_executable(filter-queries filter_queries.cpp) +target_link_libraries(filter-queries pisa CLI11) diff --git a/v1/bigram_index.cpp b/v1/bigram_index.cpp index 7e4c0aa7a..168e9a3ce 100644 --- a/v1/bigram_index.cpp +++ b/v1/bigram_index.cpp @@ -67,19 +67,15 @@ int main(int argc, char** argv) } std::vector v1_queries; v1_queries.reserve(queries.size()); - for (auto query : queries) { - if (not query.terms.empty()) { - v1_queries.push_back(Query{.terms = query.terms, - .list_selection = {}, - .threshold = {}, - .id = - [&]() { - if (query.id) { - return tl::make_optional(*query.id); - } - return tl::optional{}; - }(), - .k 
= 10}); + for (auto q : queries) { + if (not q.terms.empty()) { + Query query(q.terms, [&]() { + if (q.id) { + return tl::make_optional(*q.id); + } + return tl::optional{}; + }()); + v1_queries.push_back(query); } } return v1_queries; diff --git a/v1/filter_queries.cpp b/v1/filter_queries.cpp new file mode 100644 index 000000000..fedb44194 --- /dev/null +++ b/v1/filter_queries.cpp @@ -0,0 +1,74 @@ +#include +#include +#include + +#include +#include +#include +#include + +#include "app.hpp" +#include "io.hpp" +#include "query/queries.hpp" +#include "timer.hpp" +#include "topk_queue.hpp" +#include "v1/analyze_query.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/daat_or.hpp" +#include "v1/index_metadata.hpp" +#include "v1/intersection.hpp" +#include "v1/maxscore.hpp" +#include "v1/query.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/scorer/bm25.hpp" +#include "v1/scorer/runner.hpp" +#include "v1/types.hpp" +#include "v1/union_lookup.hpp" + +using pisa::resolve_query_parser; +using pisa::TermProcessor; +using pisa::v1::BlockedReader; +using pisa::v1::daat_or; +using pisa::v1::DaatOrAnalyzer; +using pisa::v1::index_runner; +using pisa::v1::IndexMetadata; +using pisa::v1::ListSelection; +using pisa::v1::maxscore_union_lookup; +using pisa::v1::MaxscoreAnalyzer; +using pisa::v1::MaxscoreUnionLookupAnalyzer; +using pisa::v1::Query; +using pisa::v1::QueryAnalyzer; +using pisa::v1::RawReader; +using pisa::v1::resolve_yml; +using pisa::v1::unigram_union_lookup; +using pisa::v1::UnigramUnionLookupAnalyzer; +using pisa::v1::union_lookup; +using pisa::v1::UnionLookupAnalyzer; +using pisa::v1::VoidScorer; + +int main(int argc, char** argv) +{ + pisa::QueryApp app("Filters out empty queries against a v1 index."); + CLI11_PARSE(app, argc, argv); + + auto meta = IndexMetadata::from_file(resolve_yml(app.yml)); + auto stemmer = meta.stemmer ? 
std::make_optional(*meta.stemmer) : std::optional{}; + if (meta.term_lexicon) { + app.terms_file = meta.term_lexicon.value(); + } + + auto term_processor = TermProcessor(app.terms_file, {}, stemmer); + auto filter = [&](auto&& line) { + auto query = parse_query_terms(line, term_processor); + if (not query.terms.empty()) { + std::cout << line << '\n'; + } + }; + if (app.query_file) { + std::ifstream is(*app.query_file); + pisa::io::for_each_line(is, filter); + } else { + pisa::io::for_each_line(std::cin, filter); + } + return 0; +} diff --git a/v1/query.cpp b/v1/query.cpp index 5f949482e..d2079101d 100644 --- a/v1/query.cpp +++ b/v1/query.cpp @@ -14,6 +14,7 @@ #include "topk_queue.hpp" #include "v1/analyze_query.hpp" #include "v1/blocked_cursor.hpp" +#include "v1/daat_or.hpp" #include "v1/index_metadata.hpp" #include "v1/intersection.hpp" #include "v1/maxscore.hpp" @@ -57,8 +58,8 @@ auto resolve_algorithm(std::string const& name, Index const& index, Scorer&& sco } if (name == "maxscore") { return RetrievalAlgorithm([&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { - if (query.threshold) { - topk.set_threshold(*query.threshold); + if (query.threshold()) { + topk.set_threshold(query.get_threshold()); } return pisa::v1::maxscore(query, index, std::move(topk), std::forward(scorer)); }); @@ -77,10 +78,14 @@ auto resolve_algorithm(std::string const& name, Index const& index, Scorer&& sco } if (name == "union-lookup") { return RetrievalAlgorithm([&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { - if (query.list_selection && query.list_selection->bigrams.empty()) { - return pisa::v1::unigram_union_lookup( + if (query.get_term_ids().size() > 8) { + return pisa::v1::maxscore_union_lookup( query, index, std::move(topk), std::forward(scorer)); } + // if (query.selections() && query.selections()->bigrams.empty()) { + // return pisa::v1::unigram_union_lookup( + // query, index, std::move(topk), std::forward(scorer)); + //} return pisa::v1::union_lookup( 
query, index, std::move(topk), std::forward(scorer)); }); @@ -125,7 +130,7 @@ void evaluate(std::vector const& queries, auto rank = 0; for (auto result : que.topk()) { std::cout << fmt::format("{}\t{}\t{}\t{}\t{}\t{}\n", - query.id.value_or(std::to_string(query_idx)), + query.id().value_or(std::to_string(query_idx)), "Q0", docmap[result.second], rank, @@ -206,18 +211,13 @@ int main(int argc, char** argv) pisa::io::for_each_line(std::cin, parse_query); } std::vector v1_queries(queries.size()); - std::transform(queries.begin(), queries.end(), v1_queries.begin(), [&](auto&& query) { - return Query{.terms = query.terms, - .list_selection = {}, - .threshold = {}, - .id = - [&]() { - if (query.id) { - return tl::make_optional(*query.id); - } - return tl::optional{}; - }(), - .k = app.k}; + std::transform(queries.begin(), queries.end(), v1_queries.begin(), [&](auto&& parsed) { + Query query(parsed.terms); + if (parsed.id) { + query.id(*parsed.id); + } + query.k(app.k); + return query; }); return v1_queries; }(); @@ -237,7 +237,7 @@ int main(int argc, char** argv) spdlog::error("Number of thresholds not equal to number of queries"); std::exit(1); } - queries_iter->threshold = tl::make_optional(std::stof(line)); + queries_iter->threshold(std::stof(line)); ++queries_iter; }); if (queries_iter != queries.end()) { From db2862764b4bd94a9565e4c7b21aedd45ea7e767 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 29 Nov 2019 12:05:43 -0500 Subject: [PATCH 27/56] Add scripts --- script/cw09b.sh | 65 ++++++++++++++++++++++++++++++++++++++++++++++++ script/robust.sh | 63 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+) create mode 100644 script/cw09b.sh create mode 100644 script/robust.sh diff --git a/script/cw09b.sh b/script/cw09b.sh new file mode 100644 index 000000000..6c76c2d5f --- /dev/null +++ b/script/cw09b.sh @@ -0,0 +1,65 @@ +PISA_BIN="/home/michal/work/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" 
+BINARY_FREQ_COLL="/mnt/michal/work/cw09b/inv"
+FWD="/mnt/michal/work/cw09b/fwd"
+INV="/mnt/michal/work/cw09b/inv"
+BASENAME="/mnt/michal/work/v1/cw09b/cw09b"
+THREADS=16
+TYPE="block_simdbp" # v0.6
+ENCODING="simdbp" # v1
+QUERIES="/home/michal/biscorer/data/queries/05.efficiency_topics.no_dups.1k"
+K=1000
+OUTPUT_DIR=`pwd`
+FILTERED_QUERIES="${OUTPUT_DIR}/filtered_queries"
+
+set -x
+set -e
+
+## Compress an inverted index in `binary_freq_collection` format.
+#./bin/compress -i ${BINARY_FREQ_COLL} --fwd ${FWD} -o ${BASENAME} -j ${THREADS} -e ${ENCODING}
+#
+## This will produce both quantized scores and max scores (both quantized and not).
+#./bin/bigram-index -i "${BASENAME}.yml" -q ${QUERIES}
+#
+## This will produce both quantized scores and max scores (both quantized and not).
+#./bin/score -i "${BASENAME}.yml" -j ${THREADS}
+
+# Filter out queries without existing terms.
+${PISA_BIN}/filter-queries -i ${BASENAME}.yml -q ${QUERIES} | grep -v "\[warning\]" \
+    > ${FILTERED_QUERIES}
+
+# Extract thresholds (TODO: estimates)
+${PISA_BIN}/thresholds -t ${TYPE} -i ${INV}.${TYPE} \
+    -w ${INV}.wand -q ${FILTERED_QUERIES} -k ${K} --terms "${FWD}.termlex" --stemmer porter2 \
+    | grep -v "\[warning\]" \
+    > ${OUTPUT_DIR}/thresholds
+cut -d: -f1 ${FILTERED_QUERIES} | paste - ${OUTPUT_DIR}/thresholds > ${OUTPUT_DIR}/thresholds.tsv
+
+# Extract intersections
+${PISA_BIN}/compute_intersection -t ${TYPE} -i ${INV}.${TYPE} \
+    -w ${INV}.wand -q ${FILTERED_QUERIES} --combinations --terms "${FWD}.termlex" --stemmer porter2 \
+    | grep -v "\[warning\]" \
+    > ${OUTPUT_DIR}/intersections.tsv
+
+# Select unigrams
+${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \
+    --terse --max 1 > ${OUTPUT_DIR}/selections.1
+
+# Select unigrams and bigrams
+${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \
+    --terse --max 2 > ${OUTPUT_DIR}/selections.2
+
+# Run benchmarks
+#${PISA_BIN}/query -i
"${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore \ +# --thresholds ${OUTPUT_DIR}/thresholds +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup \ +# --thresholds ${OUTPUT_DIR}/thresholds +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm unigram-union-lookup \ +# --thresholds ${OUTPUT_DIR}/thresholds --intersections ${OUTPUT_DIR}/selections.1 +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ +# --thresholds ${OUTPUT_DIR}/thresholds --intersections ${OUTPUT_DIR}/selections.1 +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ +# --thresholds ${OUTPUT_DIR}/thresholds --intersections ${OUTPUT_DIR}/selections.2 +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm two-phase-union-lookup \ + --thresholds ${OUTPUT_DIR}/thresholds --intersections ${OUTPUT_DIR}/selections.2 diff --git a/script/robust.sh b/script/robust.sh new file mode 100644 index 000000000..4b7cd3ee1 --- /dev/null +++ b/script/robust.sh @@ -0,0 +1,63 @@ +PISA_BIN="/home/michal/work/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/mnt/michal/work/robust/inv" +FWD="/mnt/michal/work/robust/fwd" +INV="/mnt/michal/work/robust/inv" +BASENAME="/mnt/michal/work/v1/robust/robust" +THREADS=16 +TYPE="block_simdbp" # v0.6 +ENCODING="simdbp" # v1 +QUERIES="/home/michal/biscorer/data/queries/05.efficiency_topics.no_dups.1k" +K=1000 +OUTPUT_DIR=`pwd` + +set -x +set -e + +# Compress an inverted index in `binary_freq_collection` format. +${PISA_BIN}/compress -i ${BINARY_FREQ_COLL} --fwd ${FWD} -o ${BASENAME} -j ${THREADS} -e ${ENCODING} + +## This will produce both quantized scores and max scores (both quantized and not). 
+${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${QUERIES}
+#
+## This will produce both quantized scores and max scores (both quantized and not).
+${PISA_BIN}/score -i "${BASENAME}.yml" -j ${THREADS}
+
+# Filter out queries without existing terms.
+FILTERED_QUERIES="${OUTPUT_DIR}/filtered_queries"
+${PISA_BIN}/filter-queries -i ${BASENAME}.yml -q ${QUERIES} | grep -v "\[warning\]" \
+    > ${FILTERED_QUERIES}
+
+# Extract thresholds (TODO: estimates)
+${PISA_BIN}/thresholds -t ${TYPE} -i ${INV}.${TYPE} \
+    -w ${INV}.wand -q ${FILTERED_QUERIES} -k ${K} --terms "${FWD}.termlex" --stemmer porter2 \
+    | grep -v "\[warning\]" \
+    > ${OUTPUT_DIR}/thresholds
+cut -d: -f1 ${FILTERED_QUERIES} | paste - ${OUTPUT_DIR}/thresholds > ${OUTPUT_DIR}/thresholds.tsv
+
+# Extract intersections
+${PISA_BIN}/compute_intersection -t ${TYPE} -i ${INV}.${TYPE} \
+    -w ${INV}.wand -q ${FILTERED_QUERIES} --combinations --terms "${FWD}.termlex" --stemmer porter2 \
+    | grep -v "\[warning\]" \
+    > ${OUTPUT_DIR}/intersections.tsv
+
+# Select unigrams
+${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \
+    --terse --max 1 > ${OUTPUT_DIR}/selections.1
+
+# Select unigrams and bigrams
+${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \
+    --terse --max 2 > ${OUTPUT_DIR}/selections.2
+
+# Run benchmarks
+${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore
+${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore \
+    --thresholds ${OUTPUT_DIR}/thresholds
+${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup \
+    --thresholds ${OUTPUT_DIR}/thresholds
+${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm unigram-union-lookup \
+    --thresholds ${OUTPUT_DIR}/thresholds --intersections ${OUTPUT_DIR}/selections.1
+${PISA_BIN}/query -i "${BASENAME}.yml"
-q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ + --thresholds ${OUTPUT_DIR}/thresholds --intersections ${OUTPUT_DIR}/selections.1 +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ + --thresholds ${OUTPUT_DIR}/thresholds --intersections ${OUTPUT_DIR}/selections.2 From a49dc0de4f9c12062cc54e8970ed14f4e35552f0 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 29 Nov 2019 12:11:33 -0500 Subject: [PATCH 28/56] Two-phase union-lookup --- include/pisa/v1/cursor_union.hpp | 105 ++++++++++++++++++- include/pisa/v1/union_lookup.hpp | 169 +++++++++++++++++++++++++++++++ v1/query.cpp | 15 +++ 3 files changed, 284 insertions(+), 5 deletions(-) diff --git a/include/pisa/v1/cursor_union.hpp b/include/pisa/v1/cursor_union.hpp index adeb28097..e490a19c2 100644 --- a/include/pisa/v1/cursor_union.hpp +++ b/include/pisa/v1/cursor_union.hpp @@ -113,8 +113,8 @@ struct CursorUnion { }; /// Transforms a list of cursors into one cursor by lazily merging them together. -//template -//struct CursorFlatUnion { +// template +// struct CursorFlatUnion { // using Cursor = typename CursorContainer::value_type; // using iterator_category = // typename std::iterator_traits::iterator_category; @@ -196,15 +196,15 @@ struct VariadicCursorUnion { m_size(std::nullopt) { m_next_docid = std::numeric_limits::max(); - m_sentinel = std::numeric_limits::max(); + m_sentinel = std::numeric_limits::min(); for_each_cursor([&](auto&& cursor, [[maybe_unused]] auto&& fn) { if (cursor.value() < m_next_docid) { m_next_docid = cursor.value(); } }); for_each_cursor([&](auto&& cursor, [[maybe_unused]] auto&& fn) { - if (cursor.sentinel() < m_next_docid) { - m_next_docid = cursor.sentinel(); + if (cursor.sentinel() > m_sentinel) { + m_sentinel = cursor.sentinel(); } }); advance(); @@ -269,6 +269,101 @@ struct VariadicCursorUnion { std::uint32_t m_next_docid{}; }; +///// Transforms a list of cursors into one cursor by lazily merging them together. 
+// template +// struct PairCursorUnion { +// using Value = std::pair())>, +// std::decay_t())>>; +// +// constexpr PairCursorUnion(Payload init, +// Cursor1 cursor1, +// Cursor2 cursor2, +// AccumulateFn1 accumulate1, +// AccumulateFn2 accumulate2) +// : m_cursors(std::move(cursors)), +// m_init(std::move(init)), +// m_accumulate(std::move(accumulate)), +// m_size(std::nullopt) +// { +// m_next_docid = std::numeric_limits::max(); +// m_sentinel = std::numeric_limits::max(); +// for_each_cursor([&](auto&& cursor, [[maybe_unused]] auto&& fn) { +// if (cursor.value() < m_next_docid) { +// m_next_docid = cursor.value(); +// } +// }); +// for_each_cursor([&](auto&& cursor, [[maybe_unused]] auto&& fn) { +// std::cerr << fmt::format("Sentinel: {} v. current {}\n", cursor.sentinel(), +// m_sentinel); if (cursor.sentinel() < m_sentinel) { +// m_sentinel = cursor.sentinel(); +// } +// }); +// advance(); +// } +// +// template +// void for_each_cursor(Fn&& fn) +// { +// std::apply( +// [&](auto&&... cursor) { +// std::apply([&](auto&&... 
accumulate) { (fn(cursor, accumulate), ...); }, +// m_accumulate); +// }, +// m_cursors); +// } +// +// [[nodiscard]] constexpr auto operator*() const noexcept -> Value { return m_current_value; } +// [[nodiscard]] constexpr auto value() const noexcept -> Value { return m_current_value; } +// [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& +// { +// return m_current_payload; +// } +// [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_sentinel; } +// +// constexpr void advance() +// { +// if (PISA_UNLIKELY(m_next_docid == m_sentinel)) { +// m_current_value = m_sentinel; +// m_current_payload = m_init; +// } else { +// m_current_payload = m_init; +// m_current_value = m_next_docid; +// m_next_docid = m_sentinel; +// std::size_t cursor_idx = 0; +// for_each_cursor([&](auto&& cursor, auto&& accumulate) { +// if (cursor.value() == m_current_value) { +// m_current_payload = accumulate(m_current_payload, cursor, cursor_idx); +// cursor.advance(); +// } +// if (cursor.value() < m_next_docid) { +// m_next_docid = cursor.value(); +// } +// ++cursor_idx; +// }); +// } +// } +// +// [[nodiscard]] constexpr auto empty() const noexcept -> bool +// { +// return m_current_value >= sentinel(); +// } +// +// private: +// CursorsTuple m_cursors; +// Payload m_init; +// std::tuple m_accumulate; +// std::optional m_size; +// +// Value m_current_value{}; +// Value m_sentinel{}; +// Payload m_current_payload{}; +// std::uint32_t m_next_docid{}; +//}; + template [[nodiscard]] constexpr inline auto union_merge(CursorContainer cursors, Payload init, diff --git a/include/pisa/v1/union_lookup.hpp b/include/pisa/v1/union_lookup.hpp index 8c7245f3c..d5ae14825 100644 --- a/include/pisa/v1/union_lookup.hpp +++ b/include/pisa/v1/union_lookup.hpp @@ -265,6 +265,22 @@ struct UnionLookupAnalyzer : public BaseUnionLookupAnalyzer { } }; +template +struct TwoPhaseUnionLookupAnalyzer : public BaseUnionLookupAnalyzer { + 
TwoPhaseUnionLookupAnalyzer(Index const& index, Scorer scorer) + : BaseUnionLookupAnalyzer(index, std::move(scorer)) + { + } + void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override + { + if (query.get_term_ids().size() > 8) { + maxscore_union_lookup(query, index, std::move(topk), scorer, this); + } else { + two_phase_union_lookup(query, index, std::move(topk), scorer, this); + } + } +}; + /// Performs a "union-lookup" query (name pending). /// /// \param query Full query, as received, possibly with duplicates. @@ -381,6 +397,128 @@ auto union_lookup(Query const& query, return acc + cursor.second.max_score(); }); + v1::for_each(merged, [&](auto& cursor) { + if constexpr (not std::is_void_v) { + analyzer->document(); + } + auto docid = cursor.value(); + auto scores = cursor.payload(); + auto score = std::accumulate(scores.begin(), scores.end(), 0.0F, std::plus{}); + auto upper_bound = score + lookup_cursors_upper_bound; + for (auto& [idx, lookup_cursor] : lookup_cursors) { + if (not topk.would_enter(upper_bound)) { + return; + } + if (scores[idx] == 0) { + lookup_cursor.advance_to_geq(docid); + if constexpr (not std::is_void_v) { + analyzer->lookup(); + } + if (PISA_UNLIKELY(lookup_cursor.value() == docid)) { + auto partial_score = lookup_cursor.payload(); + score += partial_score; + upper_bound += partial_score; + } + } + upper_bound -= lookup_cursor.max_score(); + } + topk.insert(score, docid); + if constexpr (not std::is_void_v) { + analyzer->insert(); + } + }); + return topk; +} + +template +auto two_phase_union_lookup(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + Analyzer* analyzer = nullptr) +{ + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { + return topk; + } + if (term_ids.size() > 8) { + throw std::invalid_argument( + "Generic version of union-Lookup supported only for queries of length <= 8"); + } + + auto threshold = query.get_threshold(); + auto const& 
selections = query.get_selections(); + + using bigram_cursor_type = std::decay_t; + + auto& essential_unigrams = selections.unigrams; + auto& essential_bigrams = selections.bigrams; + + auto non_essential_terms = + ranges::views::set_difference(term_ids, essential_unigrams) | ranges::to_vector; + + topk.set_threshold(threshold); + + std::array initial_payload{ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; //, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + + std::vector essential_unigram_cursors; + std::transform(essential_unigrams.begin(), + essential_unigrams.end(), + std::back_inserter(essential_unigram_cursors), + [&](auto term) { return index.scored_cursor(term, scorer); }); + + auto merged_unigrams = v1::union_merge( + essential_unigram_cursors, initial_payload, [&](auto& acc, auto& cursor, auto term_idx) { + if constexpr (not std::is_void_v) { + analyzer->posting(); + } + acc[query.sorted_position(essential_unigrams[term_idx])] = cursor.payload(); + return acc; + }); + + std::vector essential_bigram_cursors; + for (auto [left, right] : essential_bigrams) { + auto cursor = index.scored_bigram_cursor(left, right, scorer); + if (not cursor) { + throw std::runtime_error(fmt::format("Bigram not found: <{}, {}>", left, right)); + } + essential_bigram_cursors.push_back(cursor.take().value()); + } + + auto merged_bigrams = v1::union_merge( + std::move(essential_bigram_cursors), + initial_payload, + [&](auto& acc, auto& cursor, auto bigram_idx) { + if constexpr (not std::is_void_v) { + analyzer->posting(); + } + auto payload = cursor.payload(); + acc[query.sorted_position(essential_bigrams[bigram_idx].first)] = std::get<0>(payload); + acc[query.sorted_position(essential_bigrams[bigram_idx].second)] = std::get<1>(payload); + return acc; + }); + + auto lookup_cursors = [&]() { + std::vector> + lookup_cursors; + auto pos = term_ids.begin(); + for (auto non_essential_term : non_essential_terms) { + pos = std::find(pos, term_ids.end(), non_essential_term); + assert(pos != 
term_ids.end()); + auto idx = std::distance(term_ids.begin(), pos); + lookup_cursors.emplace_back(idx, index.max_scored_cursor(non_essential_term, scorer)); + } + return lookup_cursors; + }(); + std::sort(lookup_cursors.begin(), lookup_cursors.end(), [](auto&& lhs, auto&& rhs) { + return lhs.second.max_score() > rhs.second.max_score(); + }); + auto lookup_cursors_upper_bound = std::accumulate( + lookup_cursors.begin(), lookup_cursors.end(), 0.0F, [](auto acc, auto&& cursor) { + return acc + cursor.second.max_score(); + }); + v1::for_each(merged_unigrams, [&](auto& cursor) { if constexpr (not std::is_void_v) { analyzer->document(); @@ -411,6 +549,37 @@ auto union_lookup(Query const& query, analyzer->insert(); } }); + + v1::for_each(merged_bigrams, [&](auto& cursor) { + if constexpr (not std::is_void_v) { + analyzer->document(); + } + auto docid = cursor.value(); + auto scores = cursor.payload(); + auto score = std::accumulate(scores.begin(), scores.end(), 0.0F, std::plus{}); + auto upper_bound = score + lookup_cursors_upper_bound; + for (auto& [idx, lookup_cursor] : lookup_cursors) { + if (not topk.would_enter(upper_bound)) { + return; + } + if (scores[idx] == 0) { + lookup_cursor.advance_to_geq(docid); + if constexpr (not std::is_void_v) { + analyzer->lookup(); + } + if (PISA_UNLIKELY(lookup_cursor.value() == docid)) { + auto partial_score = lookup_cursor.payload(); + score += partial_score; + upper_bound += partial_score; + } + } + upper_bound -= lookup_cursor.max_score(); + } + topk.insert(score, docid); + if constexpr (not std::is_void_v) { + analyzer->insert(); + } + }); return topk; } diff --git a/v1/query.cpp b/v1/query.cpp index d2079101d..3a8457233 100644 --- a/v1/query.cpp +++ b/v1/query.cpp @@ -39,6 +39,7 @@ using pisa::v1::Query; using pisa::v1::QueryAnalyzer; using pisa::v1::RawReader; using pisa::v1::resolve_yml; +using pisa::v1::TwoPhaseUnionLookupAnalyzer; using pisa::v1::unigram_union_lookup; using pisa::v1::UnigramUnionLookupAnalyzer; using 
pisa::v1::union_lookup; @@ -90,6 +91,16 @@ auto resolve_algorithm(std::string const& name, Index const& index, Scorer&& sco query, index, std::move(topk), std::forward(scorer)); }); } + if (name == "two-phase-union-lookup") { + return RetrievalAlgorithm([&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + if (query.get_term_ids().size() > 8) { + return pisa::v1::maxscore_union_lookup( + query, index, std::move(topk), std::forward(scorer)); + } + return pisa::v1::two_phase_union_lookup( + query, index, std::move(topk), std::forward(scorer)); + }); + } spdlog::error("Unknown algorithm: {}", name); std::exit(1); } @@ -114,6 +125,10 @@ auto resolve_analyze(std::string const& name, Index const& index, Scorer&& score if (name == "union-lookup") { return QueryAnalyzer(UnionLookupAnalyzer>(index, scorer)); } + if (name == "two-phase-union-lookup") { + return QueryAnalyzer( + TwoPhaseUnionLookupAnalyzer>(index, scorer)); + } spdlog::error("Unknown algorithm: {}", name); std::exit(1); } From 37edd687ee620faab177ab3fdc817315c12cdd96 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 2 Dec 2019 19:11:12 +0000 Subject: [PATCH 29/56] Precomputed scores for bigram index --- include/pisa/query/queries.hpp | 4 +- include/pisa/v1/cursor/for_each.hpp | 2 +- include/pisa/v1/cursor_union.hpp | 14 +- include/pisa/v1/default_index_runner.hpp | 12 + include/pisa/v1/index.hpp | 50 +- include/pisa/v1/index_metadata.hpp | 31 +- include/pisa/v1/maxscore.hpp | 65 +-- include/pisa/v1/scorer/runner.hpp | 14 +- include/pisa/v1/union_lookup.hpp | 619 +++++++++++++++++++---- script/cw09b.sh | 144 ++++-- src/CMakeLists.txt | 28 +- src/thresholds.cpp | 51 +- src/v1/index_builder.cpp | 68 ++- src/v1/index_metadata.cpp | 11 + src/v1/score_index.cpp | 2 +- test/v1/test_v1_bigram_index.cpp | 1 - test/v1/test_v1_queries.cpp | 28 +- v1/query.cpp | 198 ++++---- 18 files changed, 959 insertions(+), 383 deletions(-) create mode 100644 include/pisa/v1/default_index_runner.hpp diff --git 
a/include/pisa/query/queries.hpp b/include/pisa/query/queries.hpp index a0e22582e..ed580d524 100644 --- a/include/pisa/query/queries.hpp +++ b/include/pisa/query/queries.hpp @@ -51,10 +51,10 @@ namespace pisa { if (!term_processor.is_stopword(*term)) { parsed_query.push_back(std::move(*term)); } else { - spdlog::warn("Term `{}` is a stopword and will be ignored", raw_term); + //spdlog::warn("Term `{}` is a stopword and will be ignored", raw_term); } } else { - spdlog::warn("Term `{}` not found and will be ignored", raw_term); + //spdlog::warn("Term `{}` not found and will be ignored", raw_term); } } return {std::move(id), std::move(parsed_query), {}}; diff --git a/include/pisa/v1/cursor/for_each.hpp b/include/pisa/v1/cursor/for_each.hpp index b05dec9f1..170b9c2f0 100644 --- a/include/pisa/v1/cursor/for_each.hpp +++ b/include/pisa/v1/cursor/for_each.hpp @@ -8,7 +8,7 @@ template void for_each(Cursor &&cursor, UnaryOp op) { while (not cursor.empty()) { - op(cursor); + op(std::forward(cursor)); cursor.advance(); } } diff --git a/include/pisa/v1/cursor_union.hpp b/include/pisa/v1/cursor_union.hpp index e490a19c2..81fd9d303 100644 --- a/include/pisa/v1/cursor_union.hpp +++ b/include/pisa/v1/cursor_union.hpp @@ -50,16 +50,6 @@ struct CursorUnion { } } - [[nodiscard]] constexpr auto size() const noexcept -> std::size_t - { - if (!m_size) { - m_size = std::accumulate(m_cursors.begin(), - m_cursors.end(), - std::size_t(0), - [](auto acc, auto const& elem) { return acc + elem.size(); }); - } - return *m_size; - } [[nodiscard]] constexpr auto operator*() const noexcept -> Value { return m_current_value; } [[nodiscard]] constexpr auto value() const noexcept -> Value { return m_current_value; } [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& @@ -83,8 +73,8 @@ struct CursorUnion { m_current_payload = m_accumulate(m_current_payload, cursor, cursor_idx); cursor.advance(); } - if (cursor.value() < m_next_docid) { - m_next_docid = cursor.value(); + if (auto 
value = cursor.value(); value < m_next_docid) { + m_next_docid = value; } ++cursor_idx; } diff --git a/include/pisa/v1/default_index_runner.hpp b/include/pisa/v1/default_index_runner.hpp new file mode 100644 index 000000000..eaefaf027 --- /dev/null +++ b/include/pisa/v1/default_index_runner.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include "v1/index.hpp" + +namespace pisa::v1 { + +using DefaultIndexRunner = IndexRunner{}, + RawReader{}, + RawReader{}, + BlockedReader<::pisa::simdbp_block, true>{}, + BlockedReader<::pisa::simdbp_block, false>{}>; +} diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index 26dff12e9..23eadc3cc 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -114,12 +114,10 @@ struct Index { m_max_scores(std::move(max_scores)), m_max_quantized_scores(quantized_max_scores), m_bigram_mapping(bigram_mapping), - m_source(std::move(source)) - { - } + m_source(std::move(source)){} - /// Constructs a new document-payload cursor (see document_payload_cursor.hpp). - [[nodiscard]] auto cursor(TermId term) const + /// Constructs a new document-payload cursor (see document_payload_cursor.hpp). 
+ [[nodiscard]] auto cursor(TermId term) const { return DocumentPayloadCursor(documents(term), payloads(term)); @@ -439,7 +437,7 @@ auto score_index(Index const& index, std::for_each(boost::counting_iterator(0), boost::counting_iterator(index.num_terms()), [&](auto term) { - for_each(index.scoring_cursor(term, scorer), [&](auto& cursor) { + for_each(index.scoring_cursor(term, scorer), [&](auto&& cursor) { score_builder.accumulate(quantizer(cursor.payload())); }); score_builder.flush_segment(os); @@ -593,27 +591,25 @@ struct IndexRunner { && std::decay_t::encoding() == pheader.encoding && is_type::value_type>(dheader.type) && is_type::value_type>(pheader.type)) { - auto index = make_index( - std::forward(dreader), - std::forward(preader), - m_document_offsets, - m_payload_offsets, - m_bigram_document_offsets, - m_bigram_frequency_offsets, - m_documents.subspan(8), - m_payloads.subspan(8), - m_bigram_documents.map([](auto&& bytes) { return bytes.subspan(8); }), - m_bigram_frequencies.map([](auto&& bytes) { - return std::array{std::get<0>(bytes).subspan(8), - std::get<1>(bytes).subspan(8)}; - }), - m_document_lengths, - m_avg_document_length, - m_max_scores, - m_max_quantized_scores, - m_bigram_mapping, - false); - fn(index); + fn(make_index(std::forward(dreader), + std::forward(preader), + m_document_offsets, + m_payload_offsets, + m_bigram_document_offsets, + m_bigram_frequency_offsets, + m_documents.subspan(8), + m_payloads.subspan(8), + m_bigram_documents.map([](auto&& bytes) { return bytes.subspan(8); }), + m_bigram_frequencies.map([](auto&& bytes) { + return std::array{std::get<0>(bytes).subspan(8), + std::get<1>(bytes).subspan(8)}; + }), + m_document_lengths, + m_avg_document_length, + m_max_scores, + m_max_quantized_scores, + m_bigram_mapping, + false)); return true; } return false; diff --git a/include/pisa/v1/index_metadata.hpp b/include/pisa/v1/index_metadata.hpp index 59364936c..f7d0623d2 100644 --- a/include/pisa/v1/index_metadata.hpp +++ 
b/include/pisa/v1/index_metadata.hpp @@ -43,6 +43,7 @@ struct PostingFilePaths { struct BigramMetadata { PostingFilePaths documents; std::pair frequencies; + std::vector> scores{}; std::string mapping; std::size_t count; }; @@ -155,6 +156,26 @@ template auto document_offsets = source_span(source, metadata.documents.offsets); auto score_offsets = source_span(source, metadata.scores.front().offsets); auto document_lengths = source_span(source, metadata.document_lengths_path); + tl::optional> bigram_document_offsets{}; + tl::optional, 2>> bigram_score_offsets{}; + tl::optional> bigram_documents{}; + tl::optional, 2>> bigram_scores{}; + tl::optional const>> bigram_mapping{}; + if (metadata.bigrams) { + bigram_document_offsets = + source_span(source, metadata.bigrams->documents.offsets); + bigram_score_offsets = { + source_span(source, metadata.bigrams->scores[0].first.offsets), + source_span(source, metadata.bigrams->scores[0].second.offsets)}; + bigram_documents = source_span(source, metadata.bigrams->documents.postings); + bigram_scores = { + source_span(source, metadata.bigrams->scores[0].first.postings), + source_span(source, metadata.bigrams->scores[0].second.postings)}; + auto mapping_span = source_span(source, metadata.bigrams->mapping); + bigram_mapping = gsl::span const>( + reinterpret_cast const*>(mapping_span.data()), + mapping_span.size() / (sizeof(TermId) * 2)); + } gsl::span quantized_max_scores; if (not metadata.quantized_max_scores.empty()) { // TODO(michal): support many precomputed scores @@ -164,17 +185,17 @@ template } return IndexRunner(document_offsets, score_offsets, - {}, - {}, + bigram_document_offsets, + bigram_score_offsets, documents, scores, - {}, // TODO(michal): support scored bigrams - {}, + bigram_documents, + bigram_scores, document_lengths, tl::make_optional(metadata.avg_document_length), {}, quantized_max_scores, - {}, + bigram_mapping, std::move(source), std::move(readers)); } diff --git a/include/pisa/v1/maxscore.hpp 
b/include/pisa/v1/maxscore.hpp index ecd86f363..b35b17219 100644 --- a/include/pisa/v1/maxscore.hpp +++ b/include/pisa/v1/maxscore.hpp @@ -32,7 +32,6 @@ struct MaxScoreJoin { AccumulateFn accumulate, ThresholdFn above_threshold) : m_cursors(std::move(cursors)), - m_sorted_cursors(m_cursors.size()), m_cursor_idx(m_cursors.size()), m_upper_bounds(m_cursors.size()), m_init(std::move(init)), @@ -49,7 +48,6 @@ struct MaxScoreJoin { ThresholdFn above_threshold, Analyzer* analyzer) : m_cursors(std::move(cursors)), - m_sorted_cursors(m_cursors.size()), m_cursor_idx(m_cursors.size()), m_upper_bounds(m_cursors.size()), m_init(std::move(init)), @@ -67,21 +65,17 @@ struct MaxScoreJoin { m_current_value = sentinel(); m_current_payload = m_init; } - std::transform(m_cursors.begin(), - m_cursors.end(), - m_sorted_cursors.begin(), - [](auto&& cursor) { return &cursor; }); - std::sort(m_sorted_cursors.begin(), m_sorted_cursors.end(), [](auto&& lhs, auto&& rhs) { - return lhs->max_score() < rhs->max_score(); - }); std::iota(m_cursor_idx.begin(), m_cursor_idx.end(), 0); std::sort(m_cursor_idx.begin(), m_cursor_idx.end(), [this](auto&& lhs, auto&& rhs) { return m_cursors[lhs].max_score() < m_cursors[rhs].max_score(); }); + std::sort(m_cursors.begin(), m_cursors.end(), [](auto&& lhs, auto&& rhs) { + return lhs.max_score() < rhs.max_score(); + }); - m_upper_bounds[0] = m_sorted_cursors[0]->max_score(); - for (size_t i = 1; i < m_sorted_cursors.size(); ++i) { - m_upper_bounds[i] = m_upper_bounds[i - 1] + m_sorted_cursors[i]->max_score(); + m_upper_bounds[0] = m_cursors[0].max_score(); + for (size_t i = 1; i < m_cursors.size(); ++i) { + m_upper_bounds[i] = m_upper_bounds[i - 1] + m_cursors[i].max_score(); } m_next_docid = min_value(m_cursors); @@ -117,20 +111,19 @@ struct MaxScoreJoin { m_analyzer->document(); } - for (auto sorted_position = m_non_essential_count; - sorted_position < m_sorted_cursors.size(); + for (auto sorted_position = m_non_essential_count; sorted_position < 
m_cursors.size(); sorted_position += 1) { - auto& cursor = m_sorted_cursors[sorted_position]; - if (cursor->value() == m_current_value) { + auto& cursor = m_cursors[sorted_position]; + if (cursor.value() == m_current_value) { if constexpr (not std::is_void_v) { m_analyzer->posting(); } m_current_payload = - m_accumulate(m_current_payload, *cursor, m_cursor_idx[sorted_position]); - cursor->advance(); + m_accumulate(m_current_payload, cursor, m_cursor_idx[sorted_position]); + cursor.advance(); } - if (auto docid = cursor->value(); docid < m_next_docid) { + if (auto docid = cursor.value(); docid < m_next_docid) { m_next_docid = docid; } } @@ -142,14 +135,14 @@ struct MaxScoreJoin { exit = false; break; } - auto& cursor = m_sorted_cursors[sorted_position]; - cursor->advance_to_geq(m_current_value); + auto& cursor = m_cursors[sorted_position]; + cursor.advance_to_geq(m_current_value); if constexpr (not std::is_void_v) { m_analyzer->lookup(); } - if (cursor->value() == m_current_value) { + if (cursor.value() == m_current_value) { m_current_payload = - m_accumulate(m_current_payload, *cursor, m_cursor_idx[sorted_position]); + m_accumulate(m_current_payload, cursor, m_cursor_idx[sorted_position]); } } } @@ -168,7 +161,6 @@ struct MaxScoreJoin { private: CursorContainer m_cursors; - std::vector m_sorted_cursors; std::vector m_cursor_idx; std::vector m_upper_bounds; payload_type m_init; @@ -221,12 +213,8 @@ auto maxscore(Query const& query, Index const& index, topk_queue topk, Scorer&& using cursor_type = decltype(index.max_scored_cursor(0, scorer)); using value_type = decltype(index.max_scored_cursor(0, scorer).value()); - std::vector cursors; - std::transform(term_ids.begin(), term_ids.end(), std::back_inserter(cursors), [&](auto term) { - return index.max_scored_cursor(term, scorer); - }); - - auto accumulate = [](auto& score, auto& cursor, auto /* term_position */) { + auto cursors = index.max_scored_cursors(gsl::make_span(term_ids), scorer); + auto accumulate = 
[](float& score, auto& cursor, auto /* term_position */) { score += cursor.payload(); return score; }; @@ -270,15 +258,14 @@ struct MaxscoreAnalyzer { topk_queue topk(query.k()); auto initial_threshold = query.threshold().value_or(-1.0); topk.set_threshold(initial_threshold); - auto joined = join_maxscore( - std::move(cursors), - 0.0F, - [&](auto& score, auto& cursor, auto /* term_position */) { - score += cursor.payload(); - return score; - }, - [&](auto score) { return topk.would_enter(score); }, - this); + auto joined = join_maxscore(std::move(cursors), + 0.0F, + [&](auto& score, auto& cursor, auto /* term_position */) { + score += cursor.payload(); + return score; + }, + [&](auto score) { return topk.would_enter(score); }, + this); v1::for_each(joined, [&](auto& cursor) { if (topk.insert(cursor.payload(), cursor.value())) { inserts += 1; diff --git a/include/pisa/v1/scorer/runner.hpp b/include/pisa/v1/scorer/runner.hpp index e104e3764..c042ec001 100644 --- a/include/pisa/v1/scorer/runner.hpp +++ b/include/pisa/v1/scorer/runner.hpp @@ -12,7 +12,7 @@ namespace pisa::v1 { /// that require on-the-fly scoring. template struct ScorerRunner { - explicit ScorerRunner(Index const &index, Scorers... scorers) + explicit ScorerRunner(Index const& index, Scorers... scorers) : m_index(index), m_scorers(std::move(scorers...)) { } @@ -20,28 +20,28 @@ struct ScorerRunner { template void operator()(std::string_view scorer_name, Fn fn) { - auto run = [&](auto &&scorer) -> bool { + auto run = [&](auto scorer) -> bool { if (std::hash>{}(scorer) == std::hash{}(scorer_name)) { - fn(std::forward>(scorer)); + fn(std::move(scorer)); return true; } return false; }; - bool success = - std::apply([&](Scorers... scorers) { return (run(scorers) || ...); }, m_scorers); + bool success = std::apply( + [&](Scorers... 
scorers) { return (run(std::move(scorers)) || ...); }, m_scorers); if (not success) { throw std::domain_error(fmt::format("Unknown scorer: {}", scorer_name)); } } private: - Index const &m_index; + Index const& m_index; std::tuple m_scorers; }; template -auto scorer_runner(Index const &index, Scorers... scorers) +auto scorer_runner(Index const& index, Scorers... scorers) { return ScorerRunner(index, std::move(scorers...)); } diff --git a/include/pisa/v1/union_lookup.hpp b/include/pisa/v1/union_lookup.hpp index d5ae14825..59f3ce5ec 100644 --- a/include/pisa/v1/union_lookup.hpp +++ b/include/pisa/v1/union_lookup.hpp @@ -12,28 +12,226 @@ namespace pisa::v1 { +template +struct UnionLookupJoin { + using cursor_type = typename CursorContainer::value_type; + using payload_type = Payload; + using value_type = std::decay_t())>; + + using iterator_category = + typename std::iterator_traits::iterator_category; + static_assert(std::is_base_of(), + "cursors must be stored in a random access container"); + + constexpr UnionLookupJoin(CursorContainer cursors, + std::vector cursor_idx, + std::vector upper_bounds, + std::size_t non_essential_count, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold) + : m_cursors(std::move(cursors)), + m_cursor_idx(std::move(cursor_idx)), + m_upper_bounds(std::move(upper_bounds)), + m_non_essential_count(non_essential_count), + m_init(std::move(init)), + m_accumulate(std::move(accumulate)), + m_above_threshold(std::move(above_threshold)), + m_size(std::nullopt) + { + initialize(); + } + + constexpr UnionLookupJoin(CursorContainer cursors, + std::vector cursor_idx, + std::vector upper_bounds, + std::size_t non_essential_count, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold, + Analyzer* analyzer) + : m_cursors(std::move(cursors)), + m_cursor_idx(std::move(cursor_idx)), + m_upper_bounds(std::move(upper_bounds)), + m_non_essential_count(non_essential_count), + m_init(std::move(init)), + 
m_accumulate(std::move(accumulate)), + m_above_threshold(std::move(above_threshold)), + m_size(std::nullopt), + m_analyzer(analyzer) + { + initialize(); + } + + void initialize() + { + if (m_cursors.empty()) { + m_current_value = sentinel(); + m_current_payload = m_init; + } + m_next_docid = min_value(m_cursors); + m_sentinel = min_sentinel(m_cursors); + advance(); + } + + [[nodiscard]] constexpr auto operator*() const noexcept -> value_type + { + return m_current_value; + } + [[nodiscard]] constexpr auto value() const noexcept -> value_type { return m_current_value; } + [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& + { + return m_current_payload; + } + [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_sentinel; } + + constexpr void advance() + { + bool exit = false; + while (not exit) { + if (PISA_UNLIKELY(m_non_essential_count == m_cursors.size() + || m_next_docid >= sentinel())) { + m_current_value = sentinel(); + m_current_payload = m_init; + return; + } + m_current_payload = m_init; + m_current_value = std::exchange(m_next_docid, sentinel()); + + if constexpr (not std::is_void_v) { + m_analyzer->document(); + } + + for (auto sorted_position = m_non_essential_count; sorted_position < m_cursors.size(); + sorted_position += 1) { + + auto& cursor = m_cursors[sorted_position]; + if (cursor.value() == m_current_value) { + if constexpr (not std::is_void_v) { + m_analyzer->posting(); + } + m_current_payload = + m_accumulate(m_current_payload, cursor, m_cursor_idx[sorted_position]); + cursor.advance(); + } + if (auto docid = cursor.value(); docid < m_next_docid) { + m_next_docid = docid; + } + } + + exit = true; + for (auto sorted_position = m_non_essential_count - 1; sorted_position + 1 > 0; + sorted_position -= 1) { + if (not m_above_threshold(m_current_payload + m_upper_bounds[sorted_position])) { + exit = false; + break; + } + auto& cursor = m_cursors[sorted_position]; + 
cursor.advance_to_geq(m_current_value); + if constexpr (not std::is_void_v) { + m_analyzer->lookup(); + } + if (cursor.value() == m_current_value) { + m_current_payload = + m_accumulate(m_current_payload, cursor, m_cursor_idx[sorted_position]); + } + } + } + } + + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; // TODO(michal) + [[nodiscard]] constexpr auto empty() const noexcept -> bool + { + return m_current_value >= sentinel(); + } + + private: + CursorContainer m_cursors; + std::vector m_cursor_idx; + std::vector m_upper_bounds; + std::size_t m_non_essential_count; + payload_type m_init; + AccumulateFn m_accumulate; + ThresholdFn m_above_threshold; + std::optional m_size; + + value_type m_current_value{}; + value_type m_sentinel{}; + payload_type m_current_payload{}; + std::uint32_t m_next_docid{}; + payload_type m_previous_threshold{}; + + Analyzer* m_analyzer; +}; + +template +auto join_union_lookup(CursorContainer cursors, + std::vector cursor_idx, + std::vector upper_bounds, + std::size_t non_essential_count, + Payload init, + AccumulateFn accumulate, + ThresholdFn threshold) +{ + return UnionLookupJoin( + std::move(cursors), + std::move(cursor_idx), + std::move(upper_bounds), + non_essential_count, + std::move(init), + std::move(accumulate), + std::move(threshold)); +} + +template +auto join_union_lookup(CursorContainer cursors, + std::vector cursor_idx, + std::vector upper_bounds, + std::size_t non_essential_count, + Payload init, + AccumulateFn accumulate, + ThresholdFn threshold, + Analyzer* analyzer) +{ + return UnionLookupJoin( + std::move(cursors), + std::move(cursor_idx), + std::move(upper_bounds), + non_essential_count, + std::move(init), + std::move(accumulate), + std::move(threshold), + analyzer); +} + namespace detail { template auto unigram_union_lookup(Cursors cursors, UpperBounds upper_bounds, std::size_t non_essential_count, topk_queue topk, - Analyzer* analyzer = nullptr) + [[maybe_unused]] Analyzer* analyzer = 
nullptr) { - using payload_type = decltype(std::declval().payload()); - auto merged_essential = v1::union_merge(gsl::make_span(cursors).subspan(non_essential_count), 0.0F, - [&](auto acc, auto& cursor, auto /*term_idx*/) { + [&](auto& acc, auto& cursor, auto /*term_idx*/) { if constexpr (not std::is_void_v) { analyzer->posting(); } - return acc + cursor.payload(); + acc += cursor.payload(); + return acc; }); - auto lookup_cursors = gsl::make_span(cursors).first(non_essential_count); - v1::for_each(merged_essential, [&](auto& cursor) { + v1::for_each(merged_essential, [&](auto&& cursor) { if constexpr (not std::is_void_v) { analyzer->document(); } @@ -44,18 +242,22 @@ namespace detail { if (not topk.would_enter(score + upper_bounds[lookup_cursor_idx])) { return; } - cursors[lookup_cursor_idx].advance_to_geq(docid); + auto& lookup_cursor = cursors[lookup_cursor_idx]; + lookup_cursor.advance_to_geq(docid); if constexpr (not std::is_void_v) { analyzer->lookup(); } - if (PISA_UNLIKELY(cursors[lookup_cursor_idx].value() == docid)) { - score += cursors[lookup_cursor_idx].payload(); + if (lookup_cursor.value() == docid) { + score += lookup_cursor.payload(); } } if constexpr (not std::is_void_v) { - analyzer->insert(); + if (topk.insert(score, docid)) { + analyzer->insert(); + } + } else { + topk.insert(score, docid); } - topk.insert(score, docid); }); return topk; } @@ -66,7 +268,7 @@ auto unigram_union_lookup(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer, - Analyzer* analyzer = nullptr) + [[maybe_unused]] Analyzer* analyzer = nullptr) { auto const& term_ids = query.get_term_ids(); if (term_ids.empty()) { @@ -108,11 +310,22 @@ auto unigram_union_lookup(Query const& query, for (size_t i = 1; i < cursors.size(); ++i) { upper_bounds[i] = upper_bounds[i - 1] + cursors[i].max_score(); } - return detail::unigram_union_lookup(std::move(cursors), - std::move(upper_bounds), - non_essential_count, - std::move(topk), - analyzer); + // TODO: + std::vector 
cursor_idx(cursors.size()); + std::iota(cursor_idx.begin(), cursor_idx.end(), 0); + auto accumulate = [](auto& score, auto& cursor, auto /* term_position */) { + score += cursor.payload(); + return score; + }; + auto joined = join_union_lookup(std::move(cursors), + std::move(cursor_idx), + std::move(upper_bounds), + non_essential_count, + payload_type{}, + accumulate, + [&](auto score) { return topk.would_enter(score); }); + v1::for_each(joined, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); + return topk; } template @@ -120,8 +333,11 @@ auto maxscore_union_lookup(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer, - Analyzer* analyzer = nullptr) + [[maybe_unused]] Analyzer* analyzer = nullptr) { + using cursor_type = decltype(index.max_scored_cursor(0, scorer)); + using payload_type = decltype(std::declval().payload()); + auto const& term_ids = query.get_term_ids(); if (term_ids.empty()) { return topk; @@ -129,10 +345,12 @@ auto maxscore_union_lookup(Query const& query, auto threshold = query.get_threshold(); topk.set_threshold(threshold); - using cursor_type = decltype(index.max_scored_cursor(0, scorer)); - using payload_type = decltype(std::declval().payload()); - auto cursors = index.max_scored_cursors(gsl::make_span(term_ids), scorer); + std::vector cursor_idx(cursors.size()); + std::iota(cursor_idx.begin(), cursor_idx.end(), 0); + std::sort(cursor_idx.begin(), cursor_idx.end(), [&](auto&& lhs, auto&& rhs) { + return cursors[lhs].max_score() < cursors[rhs].max_score(); + }); std::sort(cursors.begin(), cursors.end(), [](auto&& lhs, auto&& rhs) { return lhs.max_score() < rhs.max_score(); }); @@ -146,11 +364,20 @@ auto maxscore_union_lookup(Query const& query, while (non_essential_count < cursors.size() && upper_bounds[non_essential_count] < threshold) { non_essential_count += 1; } - return detail::unigram_union_lookup(std::move(cursors), - std::move(upper_bounds), - non_essential_count, - std::move(topk), - 
analyzer); + + auto accumulate = [](auto& score, auto& cursor, auto /* term_position */) { + score += cursor.payload(); + return score; + }; + auto joined = join_union_lookup(std::move(cursors), + std::move(cursor_idx), + std::move(upper_bounds), + non_essential_count, + payload_type{}, + accumulate, + [&](auto score) { return topk.would_enter(score); }); + v1::for_each(joined, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); + return topk; } template @@ -281,19 +508,206 @@ struct TwoPhaseUnionLookupAnalyzer : public BaseUnionLookupAnalyzer +struct BigramUnionLookupJoin { + using cursor_type = typename CursorContainer::value_type; + using payload_type = Payload; + using value_type = std::decay_t())>; + + using iterator_category = + typename std::iterator_traits::iterator_category; + static_assert(std::is_base_of(), + "cursors must be stored in a random access container"); + + constexpr BigramUnionLookupJoin(CursorContainer cursors, + std::vector cursor_idx, + std::vector upper_bounds, + std::size_t non_essential_count, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold) + : m_cursors(std::move(cursors)), + m_cursor_idx(std::move(cursor_idx)), + m_upper_bounds(std::move(upper_bounds)), + m_non_essential_count(non_essential_count), + m_init(std::move(init)), + m_accumulate(std::move(accumulate)), + m_above_threshold(std::move(above_threshold)), + m_size(std::nullopt) + { + initialize(); + } + + constexpr BigramUnionLookupJoin(CursorContainer cursors, + std::vector cursor_idx, + std::vector upper_bounds, + std::size_t non_essential_count, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold, + Analyzer* analyzer) + : m_cursors(std::move(cursors)), + m_cursor_idx(std::move(cursor_idx)), + m_upper_bounds(std::move(upper_bounds)), + m_non_essential_count(non_essential_count), + m_init(std::move(init)), + m_accumulate(std::move(accumulate)), + m_above_threshold(std::move(above_threshold)), + 
m_size(std::nullopt), + m_analyzer(analyzer) + { + initialize(); + } + + void initialize() + { + if (m_cursors.empty()) { + m_current_value = sentinel(); + m_current_payload = m_init; + } + m_next_docid = min_value(m_cursors); + m_sentinel = min_sentinel(m_cursors); + advance(); + } + + [[nodiscard]] constexpr auto operator*() const noexcept -> value_type + { + return m_current_value; + } + [[nodiscard]] constexpr auto value() const noexcept -> value_type { return m_current_value; } + [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& + { + return m_current_payload; + } + [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_sentinel; } + + constexpr void advance() + { + bool exit = false; + while (not exit) { + if (PISA_UNLIKELY(m_non_essential_count == m_cursors.size() + || m_next_docid >= sentinel())) { + m_current_value = sentinel(); + m_current_payload = m_init; + return; + } + m_current_payload = m_init; + m_current_value = std::exchange(m_next_docid, sentinel()); + + if constexpr (not std::is_void_v) { + m_analyzer->document(); + } + + for (auto sorted_position = m_non_essential_count; sorted_position < m_cursors.size(); + sorted_position += 1) { + + auto& cursor = m_cursors[sorted_position]; + if (cursor.value() == m_current_value) { + if constexpr (not std::is_void_v) { + m_analyzer->posting(); + } + m_current_payload = + m_accumulate(m_current_payload, cursor, m_cursor_idx[sorted_position]); + cursor.advance(); + } + if (auto docid = cursor.value(); docid < m_next_docid) { + m_next_docid = docid; + } + } + + exit = true; + for (auto sorted_position = m_non_essential_count - 1; sorted_position + 1 > 0; + sorted_position -= 1) { + if (not m_above_threshold(m_current_payload + m_upper_bounds[sorted_position])) { + exit = false; + break; + } + auto& cursor = m_cursors[sorted_position]; + cursor.advance_to_geq(m_current_value); + if constexpr (not std::is_void_v) { + m_analyzer->lookup(); + } + if 
(cursor.value() == m_current_value) { + m_current_payload = + m_accumulate(m_current_payload, cursor, m_cursor_idx[sorted_position]); + } + } + } + } + + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; // TODO(michal) + [[nodiscard]] constexpr auto empty() const noexcept -> bool + { + return m_current_value >= sentinel(); + } + + private: + CursorContainer m_cursors; + std::vector m_cursor_idx; + std::vector m_upper_bounds; + std::size_t m_non_essential_count; + payload_type m_init; + AccumulateFn m_accumulate; + ThresholdFn m_above_threshold; + std::optional m_size; + + value_type m_current_value{}; + value_type m_sentinel{}; + payload_type m_current_payload{}; + std::uint32_t m_next_docid{}; + payload_type m_previous_threshold{}; + + Analyzer* m_analyzer; +}; + +// template auto join_union_lookup(CursorContainer cursors, +// std::vector cursor_idx, +// std::vector upper_bounds, +// std::size_t non_essential_count, +// Payload init, +// AccumulateFn accumulate, +// ThresholdFn threshold) +//{ +// return BigramUnionLookupJoin( +// std::move(cursors), +// std::move(cursor_idx), +// std::move(upper_bounds), +// non_essential_count, +// std::move(init), +// std::move(accumulate), +// std::move(threshold)); +//} +// +// template +// auto join_union_lookup(CursorContainer cursors, +// std::vector cursor_idx, +// std::vector upper_bounds, +// std::size_t non_essential_count, +// Payload init, +// AccumulateFn accumulate, +// ThresholdFn threshold, +// Analyzer* analyzer) +//{ +// return BigramUnionLookupJoin( +// std::move(cursors), +// std::move(cursor_idx), +// std::move(upper_bounds), +// non_essential_count, +// std::move(init), +// std::move(accumulate), +// std::move(threshold), +// analyzer); +//} + template auto union_lookup(Query const& query, Index const& index, @@ -332,12 +746,18 @@ auto union_lookup(Query const& query, std::back_inserter(essential_unigram_cursors), [&](auto term) { return index.scored_cursor(term, scorer); }); + 
std::vector unigram_query_positions(essential_unigrams.size()); + for (std::size_t unigram_position = 0; unigram_position < essential_unigrams.size(); + unigram_position += 1) { + unigram_query_positions[unigram_position] = + query.sorted_position(essential_unigrams[unigram_position]); + } auto merged_unigrams = v1::union_merge( essential_unigram_cursors, initial_payload, [&](auto& acc, auto& cursor, auto term_idx) { if constexpr (not std::is_void_v) { analyzer->posting(); } - acc[query.sorted_position(essential_unigrams[term_idx])] = cursor.payload(); + acc[unigram_query_positions[term_idx]] = cursor.payload(); return acc; }); @@ -350,18 +770,27 @@ auto union_lookup(Query const& query, essential_bigram_cursors.push_back(cursor.take().value()); } - auto merged_bigrams = v1::union_merge( - std::move(essential_bigram_cursors), - initial_payload, - [&](auto& acc, auto& cursor, auto bigram_idx) { - if constexpr (not std::is_void_v) { - analyzer->posting(); - } - auto payload = cursor.payload(); - acc[query.sorted_position(essential_bigrams[bigram_idx].first)] = std::get<0>(payload); - acc[query.sorted_position(essential_bigrams[bigram_idx].second)] = std::get<1>(payload); - return acc; - }); + std::vector> bigram_query_positions( + essential_bigrams.size()); + for (std::size_t bigram_position = 0; bigram_position < essential_bigrams.size(); + bigram_position += 1) { + bigram_query_positions[bigram_position] = + std::make_pair(query.sorted_position(essential_bigrams[bigram_position].first), + query.sorted_position(essential_bigrams[bigram_position].second)); + } + auto merged_bigrams = v1::union_merge(std::move(essential_bigram_cursors), + initial_payload, + [&](auto& acc, auto& cursor, auto bigram_idx) { + if constexpr (not std::is_void_v) { + analyzer->posting(); + } + auto payload = cursor.payload(); + auto query_positions = + bigram_query_positions[bigram_idx]; + acc[query_positions.first] = std::get<0>(payload); + acc[query_positions.second] = std::get<1>(payload); 
+ return acc; + }); auto accumulate = [&](auto& acc, auto& cursor, auto /* union_idx */) { auto payload = cursor.payload(); @@ -422,9 +851,12 @@ auto union_lookup(Query const& query, } upper_bound -= lookup_cursor.max_score(); } - topk.insert(score, docid); if constexpr (not std::is_void_v) { - analyzer->insert(); + if (topk.insert(score, docid)) { + analyzer->insert(); + } + } else { + topk.insert(score, docid); } }); return topk; @@ -468,12 +900,18 @@ auto two_phase_union_lookup(Query const& query, std::back_inserter(essential_unigram_cursors), [&](auto term) { return index.scored_cursor(term, scorer); }); + std::vector unigram_query_positions(essential_unigrams.size()); + for (std::size_t unigram_position = 0; unigram_position < essential_unigrams.size(); + unigram_position += 1) { + unigram_query_positions[unigram_position] = + query.sorted_position(essential_unigrams[unigram_position]); + } auto merged_unigrams = v1::union_merge( essential_unigram_cursors, initial_payload, [&](auto& acc, auto& cursor, auto term_idx) { if constexpr (not std::is_void_v) { analyzer->posting(); } - acc[query.sorted_position(essential_unigrams[term_idx])] = cursor.payload(); + acc[unigram_query_positions[term_idx]] = cursor.payload(); return acc; }); @@ -486,18 +924,27 @@ auto two_phase_union_lookup(Query const& query, essential_bigram_cursors.push_back(cursor.take().value()); } - auto merged_bigrams = v1::union_merge( - std::move(essential_bigram_cursors), - initial_payload, - [&](auto& acc, auto& cursor, auto bigram_idx) { - if constexpr (not std::is_void_v) { - analyzer->posting(); - } - auto payload = cursor.payload(); - acc[query.sorted_position(essential_bigrams[bigram_idx].first)] = std::get<0>(payload); - acc[query.sorted_position(essential_bigrams[bigram_idx].second)] = std::get<1>(payload); - return acc; - }); + std::vector> bigram_query_positions( + essential_bigrams.size()); + for (std::size_t bigram_position = 0; bigram_position < essential_bigrams.size(); + 
bigram_position += 1) { + bigram_query_positions[bigram_position] = + std::make_pair(query.sorted_position(essential_bigrams[bigram_position].first), + query.sorted_position(essential_bigrams[bigram_position].second)); + } + auto merged_bigrams = v1::union_merge(std::move(essential_bigram_cursors), + initial_payload, + [&](auto& acc, auto& cursor, auto bigram_idx) { + if constexpr (not std::is_void_v) { + analyzer->posting(); + } + auto payload = cursor.payload(); + auto query_positions = + bigram_query_positions[bigram_idx]; + acc[query_positions.first] = std::get<0>(payload); + acc[query_positions.second] = std::get<1>(payload); + return acc; + }); auto lookup_cursors = [&]() { std::vector> @@ -519,7 +966,7 @@ auto two_phase_union_lookup(Query const& query, return acc + cursor.second.max_score(); }); - v1::for_each(merged_unigrams, [&](auto& cursor) { + auto accumulate_document = [&](auto& cursor) { if constexpr (not std::is_void_v) { analyzer->document(); } @@ -548,38 +995,10 @@ auto two_phase_union_lookup(Query const& query, if constexpr (not std::is_void_v) { analyzer->insert(); } - }); + }; - v1::for_each(merged_bigrams, [&](auto& cursor) { - if constexpr (not std::is_void_v) { - analyzer->document(); - } - auto docid = cursor.value(); - auto scores = cursor.payload(); - auto score = std::accumulate(scores.begin(), scores.end(), 0.0F, std::plus{}); - auto upper_bound = score + lookup_cursors_upper_bound; - for (auto& [idx, lookup_cursor] : lookup_cursors) { - if (not topk.would_enter(upper_bound)) { - return; - } - if (scores[idx] == 0) { - lookup_cursor.advance_to_geq(docid); - if constexpr (not std::is_void_v) { - analyzer->lookup(); - } - if (PISA_UNLIKELY(lookup_cursor.value() == docid)) { - auto partial_score = lookup_cursor.payload(); - score += partial_score; - upper_bound += partial_score; - } - } - upper_bound -= lookup_cursor.max_score(); - } - topk.insert(score, docid); - if constexpr (not std::is_void_v) { - analyzer->insert(); - } - }); + 
v1::for_each(merged_unigrams, accumulate_document); + v1::for_each(merged_bigrams, accumulate_document); return topk; } diff --git a/script/cw09b.sh b/script/cw09b.sh index 6c76c2d5f..e83bdf90a 100644 --- a/script/cw09b.sh +++ b/script/cw09b.sh @@ -1,65 +1,133 @@ -PISA_BIN="/home/michal/work/pisa/build/bin" +PISA_BIN="/home/michal/pisa/build/bin" INTERSECT_BIN="/home/michal/intersect/target/release/intersect" -BINARY_FREQ_COLL="/mnt/michal/work/cw09b/inv" -FWD="/mnt/michal/work/cw09b/fwd" -INV="/mnt/michal/work/cw09b/inv" -BASENAME="/mnt/michal/work/v1/cw09b/cw09b" -THREADS=16 +BINARY_FREQ_COLL="/data/michal/work/cw09b/inv" +FWD="/data/michal/work/cw09b/fwd" +INV="/data/michal/work/cw09b/inv" +BASENAME="/data/michal/work/v1/cw09b/cw09b" +THREADS=4 TYPE="block_simdbp" # v0.6 ENCODING="simdbp" # v1 QUERIES="/home/michal/biscorer/data/queries/05.efficiency_topics.no_dups.1k" K=1000 -OUTPUT_DIR=`pwd` +OUTPUT_DIR="/data/michal/intersect/cw09b" FILTERED_QUERIES="${OUTPUT_DIR}/filtered_queries" +#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.web.51-200" +THRESHOLDS="${OUTPUT_DIR}/thresholds" -set -x set -e +set -x ## Compress an inverted index in `binary_freq_collection` format. #./bin/compress -i ${BINARY_FREQ_COLL} --fwd ${FWD} -o ${BASENAME} -j ${THREADS} -e ${ENCODING} -# -## This will produce both quantized scores and max scores (both quantized and not). -#./bin/bigram-index -i "${BASENAME}.yml" -q ${QUERIES} -# -## This will produce both quantized scores and max scores (both quantized and not). + +# This will produce both quantized scores and max scores (both quantized and not). #./bin/score -i "${BASENAME}.yml" -j ${THREADS} +# This will produce both quantized scores and max scores (both quantized and not). +# ./bin/bigram-index -i "${BASENAME}.yml" -q ${QUERIES} + # Filter out queries witout existing terms. 
-${PISA_BIN}/filter-queries -i ${BASENAME}.yml -q ${QUERIES} | grep -v "\[warning\]" \ - > ${FILTERED_QUERIES} +#${PISA_BIN}/filter-queries -i ${BASENAME}.yml -q ${QUERIES} | grep -v "\[warning\]" \ +# > ${FILTERED_QUERIES} # Extract thresholds (TODO: estimates) -${PISA_BIN}/thresholds -t ${TYPE} -i ${INV}.${TYPE} \ - -w ${INV}.wand -q ${FILTERED_QUERIES} -k ${K} --terms "${FWD}.termlex" --stemmer porter2 \ - | grep -v "\[warning\]" \ - > ${OUTPUT_DIR}/thresholds -cut -d: -f1 ${FILTERED_QUERIES} | paste - ${OUTPUT_DIR}/thresholds > ${OUTPUT_DIR}/thresholds.tsv +#${PISA_BIN}/thresholds -t ${TYPE} -i ${INV}.${TYPE} \ +# -w ${INV}.wand -q ${FILTERED_QUERIES} -k ${K} --terms "${FWD}.termlex" --stemmer porter2 \ +# | grep -v "\[warning\]" \ +# > ${THRESHOLDS} +#cut -d: -f1 ${FILTERED_QUERIES} | paste - ${OUTPUT_DIR}/thresholds > ${OUTPUT_DIR}/thresholds.tsv +#cut -d: -f1 ${FILTERED_QUERIES} | paste - ${THRESHOLDS} > ${OUTPUT_DIR}/thresholds.tsv # Extract intersections -${PISA_BIN}/compute_intersection -t ${TYPE} -i ${INV}.${TYPE} \ - -w ${INV}.wand -q ${FILTERED_QUERIES} --combinations --terms "${FWD}.termlex" --stemmer porter2 \ - | grep -v "\[warning\]" \ - > ${OUTPUT_DIR}/intersections.tsv - -# Select unigrams -${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ - --terse --max 1 > ${OUTPUT_DIR}/selections.1 +#${PISA_BIN}/compute_intersection -t ${TYPE} -i ${INV}.${TYPE} \ +# -w ${INV}.wand -q ${FILTERED_QUERIES} --combinations --terms "${FWD}.termlex" --stemmer porter2 \ +# | grep -v "\[warning\]" \ +# > ${OUTPUT_DIR}/intersections.tsv -# Select unigrams and bigrams +## Select unigrams +#${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ +# --terse --max 1 > ${OUTPUT_DIR}/selections.1 +# +## Select unigrams and bigrams +#${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ +# --terse --max 2 > ${OUTPUT_DIR}/selections.2 +# +## 
Select unigrams and bigrams scaled +#${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ +# --terse --max 2 --scale 1.5 > ${OUTPUT_DIR}/selections.2.scaled-1.5 +# +# Select unigrams and bigrams scaled ${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ - --terse --max 2 > ${OUTPUT_DIR}/selections.2 + --terse --max 2 --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 +# +#${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ +# --terse --max 2 --scale-by-query-len > ${OUTPUT_DIR}/selections.2.scaled-smart # Run benchmarks -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore > ${OUTPUT_DIR}/bench.maxscore #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore \ -# --thresholds ${OUTPUT_DIR}/thresholds +# --thresholds ${THRESHOLDS} > ${OUTPUT_DIR}/bench.maxscore-threshold #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup \ -# --thresholds ${OUTPUT_DIR}/thresholds +# --thresholds ${THRESHOLDS} > ${OUTPUT_DIR}/bench.maxscore-union-lookup #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm unigram-union-lookup \ -# --thresholds ${OUTPUT_DIR}/thresholds --intersections ${OUTPUT_DIR}/selections.1 +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.1 > ${OUTPUT_DIR}/bench.unigram-union-lookup +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.1 > ${OUTPUT_DIR}/bench.union-lookup.1 #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ -# --thresholds ${OUTPUT_DIR}/thresholds 
--intersections ${OUTPUT_DIR}/selections.1 +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/bench.union-lookup.2 +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm two-phase-union-lookup \ +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/bench.two-phase-union-lookup #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ -# --thresholds ${OUTPUT_DIR}/thresholds --intersections ${OUTPUT_DIR}/selections.2 -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm two-phase-union-lookup \ - --thresholds ${OUTPUT_DIR}/thresholds --intersections ${OUTPUT_DIR}/selections.2 +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-1.5 > ${OUTPUT_DIR}/bench.union-lookup.scaled-1.5 +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-2 > ${OUTPUT_DIR}/bench.union-lookup.scaled-2 +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-3 > ${OUTPUT_DIR}/bench.union-lookup.scaled-3 +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-smart > ${OUTPUT_DIR}/bench.union-lookup.scaled-smart + +# Analyze +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm maxscore \ + --thresholds ${THRESHOLDS} > ${OUTPUT_DIR}/stats.maxscore-thresholds +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm maxscore-union-lookup 
\ + --thresholds ${THRESHOLDS} > ${OUTPUT_DIR}/stats.maxscore-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm unigram-union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.1 > ${OUTPUT_DIR}/stats.unigram-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.1 > ${OUTPUT_DIR}/stats.union-lookup.1 +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/stats.union-lookup.2 +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm two-phase-union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/stats.two-phase-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-1.5 > ${OUTPUT_DIR}/stats.union-lookup.scaled-1.5 +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-2 > ${OUTPUT_DIR}/stats.union-lookup.scaled-2 +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-3 > ${OUTPUT_DIR}/stats.union-lookup.scaled-3 +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-smart > ${OUTPUT_DIR}/stats.union-lookup.scaled-smart + +# Evaluate +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" 
+${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore \ + --thresholds ${THRESHOLDS} > "${OUTPUT_DIR}/eval.maxscore-threshold" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore-union-lookup \ + --thresholds ${THRESHOLDS} > "${OUTPUT_DIR}/eval.maxscore-union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm unigram-union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.1 > "${OUTPUT_DIR}/eval.unigram-union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.1 > "${OUTPUT_DIR}/eval.union-lookup.1" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > "${OUTPUT_DIR}/eval.union-lookup.2" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm two-phase-union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > "${OUTPUT_DIR}/eval.two-phase-union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-1.5 > "${OUTPUT_DIR}/eval.union-lookup.scale-1.5" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-2 > "${OUTPUT_DIR}/eval.union-lookup.scale-2" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-3 > "${OUTPUT_DIR}/eval.union-lookup.scale-3" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-smart > 
"${OUTPUT_DIR}/eval.union-lookup.scale-smart" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 81ab1d268..842f15449 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -19,25 +19,25 @@ # pisa # CLI11 #) -# -#add_executable(queries queries.cpp) -#target_link_libraries(queries -# pisa -# CLI11 -#) -# + +add_executable(queries queries.cpp) +target_link_libraries(queries + pisa + CLI11 +) + #add_executable(evaluate_queries evaluate_queries.cpp) #target_link_libraries(evaluate_queries # pisa # CLI11 #) -# -#add_executable(thresholds thresholds.cpp) -#target_link_libraries(thresholds -# pisa -# CLI11 -#) -# + +add_executable(thresholds thresholds.cpp) +target_link_libraries(thresholds + pisa + CLI11 +) + #add_executable(profile_queries profile_queries.cpp) #target_link_libraries(profile_queries # pisa diff --git a/src/thresholds.cpp b/src/thresholds.cpp index 9a7c9d248..6fa69333e 100644 --- a/src/thresholds.cpp +++ b/src/thresholds.cpp @@ -9,24 +9,24 @@ #include "mappable/mapper.hpp" +#include "cursor/max_scored_cursor.hpp" #include "index_types.hpp" #include "io.hpp" #include "query/queries.hpp" #include "util/util.hpp" #include "wand_data_compressed.hpp" #include "wand_data_raw.hpp" -#include "cursor/max_scored_cursor.hpp" #include "CLI/CLI.hpp" using namespace pisa; template -void thresholds(const std::string &index_filename, - const std::optional &wand_data_filename, - const std::vector &queries, - const std::optional &thresholds_filename, - std::string const &type, +void thresholds(const std::string& index_filename, + const std::optional& wand_data_filename, + const std::vector& queries, + const std::optional& thresholds_filename, + std::string const& type, uint64_t k) { IndexType index; @@ -47,11 +47,11 @@ void thresholds(const std::string &index_filename, } wand_query wand_q(k); - for (auto const &query : queries) { + for (auto const& query : queries) { wand_q(make_max_scored_cursors(index, wdata, query), index.num_docs()); - auto results = 
wand_q.topk(); + auto results = wand_q.topk(); float threshold = 0.0; - if (results.size() == k) { + if (not results.empty()) { threshold = results.back().first; } std::cout << threshold << '\n'; @@ -61,7 +61,7 @@ void thresholds(const std::string &index_filename, using wand_raw_index = wand_data; using wand_uniform_index = wand_data; -int main(int argc, const char **argv) +int main(int argc, const char** argv) { std::string type; std::string index_filename; @@ -82,7 +82,7 @@ int main(int argc, const char **argv) app.add_option("-q,--query", query_filename, "Queries filename"); app.add_flag("--compressed-wand", compressed, "Compressed wand input file"); app.add_option("-k", k, "k value"); - auto *terms_opt = + auto* terms_opt = app.add_option("--terms", terms_file, "Text file with terms in separate lines"); app.add_option("--stemmer", stemmer, "Stemmer type")->needs(terms_opt); CLI11_PARSE(app, argc, argv); @@ -98,24 +98,17 @@ int main(int argc, const char **argv) /**/ if (false) { -#define LOOP_BODY(R, DATA, T) \ - } \ - else if (type == BOOST_PP_STRINGIZE(T)) { \ - if (compressed) { \ - thresholds(index_filename, \ - wand_data_filename, \ - queries, \ - thresholds_filename, \ - type, \ - k); \ - } else { \ - thresholds(index_filename, \ - wand_data_filename, \ - queries, \ - thresholds_filename, \ - type, \ - k); \ - } \ +#define LOOP_BODY(R, DATA, T) \ + } \ + else if (type == BOOST_PP_STRINGIZE(T)) \ + { \ + if (compressed) { \ + thresholds( \ + index_filename, wand_data_filename, queries, thresholds_filename, type, k); \ + } else { \ + thresholds( \ + index_filename, wand_data_filename, queries, thresholds_filename, type, k); \ + } \ /**/ BOOST_PP_SEQ_FOR_EACH(LOOP_BODY, _, PISA_INDEX_TYPES); diff --git a/src/v1/index_builder.cpp b/src/v1/index_builder.cpp index 8ba9469e0..a2335f994 100644 --- a/src/v1/index_builder.cpp +++ b/src/v1/index_builder.cpp @@ -78,6 +78,67 @@ auto verify_compressed_index(std::string const& input, std::string_view output) return 
errors; } +[[nodiscard]] auto build_scored_bigram_index(IndexMetadata meta, + std::string const& index_basename, + std::vector> const& bigrams) + -> std::pair +{ + auto run = scored_index_runner(meta, + RawReader{}, + RawReader{}, + BlockedReader<::pisa::simdbp_block, true>{}, + BlockedReader<::pisa::simdbp_block, false>{}); + + std::vector> pair_mapping; + auto scores_file_0 = fmt::format("{}.bigram_bm25_0", index_basename); + auto scores_file_1 = fmt::format("{}.bigram_bm25_1", index_basename); + auto score_offsets_file_0 = fmt::format("{}.bigram_bm25_offsets_0", index_basename); + auto score_offsets_file_1 = fmt::format("{}.bigram_bm25_offsets_1", index_basename); + std::ofstream score_out_0(scores_file_0); + std::ofstream score_out_1(scores_file_1); + + run([&](auto&& index) { + ProgressStatus status(bigrams.size(), + DefaultProgress("Building scored index"), + std::chrono::milliseconds(100)); + using index_type = std::decay_t; + using score_writer_type = + typename CursorTraits::Writer; + + PostingBuilder score_builder_0(score_writer_type{}); + PostingBuilder score_builder_1(score_writer_type{}); + + score_builder_0.write_header(score_out_0); + score_builder_1.write_header(score_out_1); + + for (auto [left_term, right_term] : bigrams) { + auto intersection = intersect({index.scored_cursor(left_term, VoidScorer{}), + index.scored_cursor(right_term, VoidScorer{})}, + std::array{0, 0}, + [](auto& payload, auto& cursor, auto list_idx) { + payload[list_idx] = cursor.payload(); + return payload; + }); + if (intersection.empty()) { + status += 1; + continue; + } + for_each(intersection, [&](auto& cursor) { + auto payload = cursor.payload(); + score_builder_0.accumulate(std::get<0>(payload)); + score_builder_1.accumulate(std::get<1>(payload)); + }); + score_builder_0.flush_segment(score_out_0); + score_builder_1.flush_segment(score_out_1); + status += 1; + } + write_span(gsl::make_span(score_builder_0.offsets()), score_offsets_file_0); + 
write_span(gsl::make_span(score_builder_1.offsets()), score_offsets_file_1); + }); + return {PostingFilePaths{scores_file_0, score_offsets_file_0}, + PostingFilePaths{scores_file_1, score_offsets_file_1}}; +} + void build_bigram_index(std::string const& yml, std::vector> const& bigrams) { @@ -147,13 +208,18 @@ void build_bigram_index(std::string const& yml, write_span(gsl::make_span(frequency_builder_1.offsets()), frequency_offsets_file_1); std::cerr << " Done.\n"; }); - std::cerr << "Writing metadata..."; meta.bigrams = BigramMetadata{ .documents = {.postings = documents_file, .offsets = document_offsets_file}, .frequencies = {{.postings = frequencies_file_0, .offsets = frequency_offsets_file_0}, {.postings = frequencies_file_1, .offsets = frequency_offsets_file_1}}, + .scores = {}, .mapping = fmt::format("{}.bigram_mapping", index_basename), .count = pair_mapping.size()}; + if (not meta.scores.empty()) { + meta.bigrams->scores.push_back(build_scored_bigram_index(meta, index_basename, bigrams)); + } + + std::cerr << "Writing metadata..."; meta.write(yml); std::cerr << " Done.\nWriting bigram mapping..."; write_span(gsl::make_span(pair_mapping), meta.bigrams->mapping); diff --git a/src/v1/index_metadata.cpp b/src/v1/index_metadata.cpp index c835c2f0d..717b79bd4 100644 --- a/src/v1/index_metadata.cpp +++ b/src/v1/index_metadata.cpp @@ -73,6 +73,11 @@ constexpr char const* QUANTIZED_MAX_SCORES = "quantized_max_scores"; .offsets = config[BIGRAM]["frequencies_0"][OFFSETS].as()}, {.postings = config[BIGRAM]["frequencies_1"][POSTINGS].as(), .offsets = config[BIGRAM]["frequencies_1"][OFFSETS].as()}}, + .scores = {{{.postings = config[BIGRAM]["scores_0"][POSTINGS].as(), + .offsets = config[BIGRAM]["scores_0"][OFFSETS].as()}, + {.postings = config[BIGRAM]["scores_1"][POSTINGS].as(), + .offsets = + config[BIGRAM]["scores_1"][OFFSETS].as()}}}, .mapping = config[BIGRAM]["mapping"].as(), .count = config[BIGRAM]["count"].as()}; } @@ -121,6 +126,12 @@ void 
IndexMetadata::write(std::string const& file) root[BIGRAM]["frequencies_0"][OFFSETS] = bigrams->frequencies.first.offsets; root[BIGRAM]["frequencies_1"][POSTINGS] = bigrams->frequencies.second.postings; root[BIGRAM]["frequencies_1"][OFFSETS] = bigrams->frequencies.second.offsets; + if (not bigrams->scores.empty()) { + root[BIGRAM]["scores_0"][POSTINGS] = bigrams->scores.front().first.postings; + root[BIGRAM]["scores_0"][OFFSETS] = bigrams->scores.front().first.offsets; + root[BIGRAM]["scores_1"][POSTINGS] = bigrams->scores.front().second.postings; + root[BIGRAM]["scores_1"][OFFSETS] = bigrams->scores.front().second.offsets; + } root[BIGRAM]["mapping"] = bigrams->mapping; root[BIGRAM]["count"] = bigrams->count; } diff --git a/src/v1/score_index.cpp b/src/v1/score_index.cpp index aee096191..9a00a9e83 100644 --- a/src/v1/score_index.cpp +++ b/src/v1/score_index.cpp @@ -47,7 +47,7 @@ void score_index(std::string const& yml, std::size_t threads) boost::counting_iterator(end_term), [&](auto term) { for_each( - index.scoring_cursor(term, make_bm25(index)), [&](auto& cursor) { + index.scoring_cursor(term, make_bm25(index)), [&](auto&& cursor) { if (auto score = cursor.payload(); max_scores[term] < score) { max_scores[term] = score; } diff --git a/test/v1/test_v1_bigram_index.cpp b/test/v1/test_v1_bigram_index.cpp index 6b982b173..bbe762a54 100644 --- a/test/v1/test_v1_bigram_index.cpp +++ b/test/v1/test_v1_bigram_index.cpp @@ -60,7 +60,6 @@ TEMPLATE_TEST_CASE("Bigram v intersection", auto run = v1::index_runner(meta, fixture.document_reader(), fixture.frequency_reader()); std::vector results; run([&](auto&& index) { - auto scorer = make_bm25(index); for (auto left = 0; left < q.get_term_ids().size(); left += 1) { for (auto right = left + 1; right < q.get_term_ids().size(); right += 1) { auto left_cursor = index.cursor(q.get_term_ids()[left]); diff --git a/test/v1/test_v1_queries.cpp b/test/v1/test_v1_queries.cpp index 264dbf09f..606983466 100644 --- 
a/test/v1/test_v1_queries.cpp +++ b/test/v1/test_v1_queries.cpp @@ -37,8 +37,14 @@ #include "v1/types.hpp" #include "v1/union_lookup.hpp" -namespace v1 = pisa::v1; using namespace pisa; +using pisa::v1::BlockedCursor; +using pisa::v1::DocId; +using pisa::v1::Frequency; +using pisa::v1::Index; +using pisa::v1::IndexMetadata; +using pisa::v1::ListSelection; +using pisa::v1::RawCursor; static constexpr auto RELATIVE_ERROR = 0.1F; @@ -116,17 +122,15 @@ std::unique_ptr> TEMPLATE_TEST_CASE("Query", "[v1][integration]", - (IndexFixture, - v1::RawCursor, - v1::RawCursor>), - (IndexFixture, - v1::BlockedCursor<::pisa::simdbp_block, false>, - v1::RawCursor>)) + (IndexFixture, RawCursor, RawCursor>), + (IndexFixture, + BlockedCursor<::pisa::simdbp_block, false>, + RawCursor>)) { tbb::task_scheduler_init init(1); auto data = IndexData, v1::RawCursor>, - v1::Index, v1::RawCursor>>::get(); + Index, RawCursor>, + Index, RawCursor>>::get(); TestType fixture; auto input_data = GENERATE(table({{"daat_or", false}, {"maxscore", false}, @@ -139,7 +143,7 @@ TEMPLATE_TEST_CASE("Query", CAPTURE(algorithm); CAPTURE(with_threshold); auto index_basename = (fixture.tmpdir().path() / "inv").string(); - auto meta = v1::IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); + auto meta = IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); ranked_or_query or_q(10); auto run_query = [](std::string const& name, auto query, auto&& index, auto scorer) { if (name == "daat_or") { @@ -152,7 +156,7 @@ TEMPLATE_TEST_CASE("Query", return maxscore_union_lookup(query, index, topk_queue(10), scorer); } if (name == "unigram_union_lookup") { - query.selections(v1::ListSelection{.unigrams = query.get_term_ids(), .bigrams = {}}); + query.selections(ListSelection{.unigrams = query.get_term_ids(), .bigrams = {}}); return unigram_union_lookup(query, index, topk_queue(10), scorer); } if (name == "union_lookup") { @@ -185,7 +189,7 @@ TEMPLATE_TEST_CASE("Query", auto on_the_fly = [&]() { auto run 
= - v1::index_runner(meta, fixture.document_reader(), fixture.frequency_reader()); + pisa::v1::index_runner(meta, fixture.document_reader(), fixture.frequency_reader()); std::vector results; run([&](auto&& index) { auto que = run_query(algorithm, query, index, make_bm25(index)); diff --git a/v1/query.cpp b/v1/query.cpp index 3a8457233..334158eeb 100644 --- a/v1/query.cpp +++ b/v1/query.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include "app.hpp" @@ -79,14 +80,14 @@ auto resolve_algorithm(std::string const& name, Index const& index, Scorer&& sco } if (name == "union-lookup") { return RetrievalAlgorithm([&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + if (query.selections()->bigrams.empty()) { + return pisa::v1::unigram_union_lookup( + query, index, std::move(topk), std::forward(scorer)); + } if (query.get_term_ids().size() > 8) { - return pisa::v1::maxscore_union_lookup( + return pisa::v1::maxscore( query, index, std::move(topk), std::forward(scorer)); } - // if (query.selections() && query.selections()->bigrams.empty()) { - // return pisa::v1::unigram_union_lookup( - // query, index, std::move(topk), std::forward(scorer)); - //} return pisa::v1::union_lookup( query, index, std::move(topk), std::forward(scorer)); }); @@ -171,15 +172,18 @@ void benchmark(std::vector const& queries, int k, RetrievalAlgorithm retr times[query] = std::min(times[query], static_cast(usecs.count())); } } + for (auto time : times) { + std::cout << time << '\n'; + } std::sort(times.begin(), times.end()); double avg = std::accumulate(times.begin(), times.end(), double()) / times.size(); double q50 = times[times.size() / 2]; double q90 = times[90 * times.size() / 100]; double q95 = times[95 * times.size() / 100]; - spdlog::info("Mean: {} microsec.", avg); - spdlog::info("50% quantile: {} microsec.", q50); - spdlog::info("90% quantile: {} microsec.", q90); - spdlog::info("95% quantile: {} microsec.", q95); + spdlog::info("Mean: {} us", avg); + 
spdlog::info("50% quantile: {} us", q50); + spdlog::info("90% quantile: {} us", q90); + spdlog::info("95% quantile: {} us", q95); } void analyze_queries(std::vector const& queries, QueryAnalyzer analyzer) @@ -195,6 +199,9 @@ void analyze_queries(std::vector const& queries, QueryAnalyzer analyzer) int main(int argc, char** argv) { + spdlog::drop(""); + spdlog::set_default_logger(spdlog::stderr_color_mt("")); + std::string algorithm = "daat_or"; tl::optional threshold_file; tl::optional inter_filename; @@ -207,107 +214,110 @@ int main(int argc, char** argv) app.add_flag("--analyze", analyze, "Analyze query execution and stats"); CLI11_PARSE(app, argc, argv); - auto meta = IndexMetadata::from_file(resolve_yml(app.yml)); - auto stemmer = meta.stemmer ? std::make_optional(*meta.stemmer) : std::optional{}; - if (meta.term_lexicon) { - app.terms_file = meta.term_lexicon.value(); - } - if (meta.document_lexicon) { - app.documents_file = meta.document_lexicon.value(); - } - - auto queries = [&]() { - std::vector<::pisa::Query> queries; - auto parse_query = resolve_query_parser(queries, app.terms_file, std::nullopt, stemmer); - if (app.query_file) { - std::ifstream is(*app.query_file); - pisa::io::for_each_line(is, parse_query); - } else { - pisa::io::for_each_line(std::cin, parse_query); + try { + auto meta = IndexMetadata::from_file(resolve_yml(app.yml)); + auto stemmer = + meta.stemmer ? 
std::make_optional(*meta.stemmer) : std::optional{}; + if (meta.term_lexicon) { + app.terms_file = meta.term_lexicon.value(); + } + if (meta.document_lexicon) { + app.documents_file = meta.document_lexicon.value(); } - std::vector v1_queries(queries.size()); - std::transform(queries.begin(), queries.end(), v1_queries.begin(), [&](auto&& parsed) { - Query query(parsed.terms); - if (parsed.id) { - query.id(*parsed.id); + + auto queries = [&]() { + std::vector<::pisa::Query> queries; + auto parse_query = resolve_query_parser(queries, app.terms_file, std::nullopt, stemmer); + if (app.query_file) { + std::ifstream is(*app.query_file); + pisa::io::for_each_line(is, parse_query); + } else { + pisa::io::for_each_line(std::cin, parse_query); } - query.k(app.k); - return query; - }); - return v1_queries; - }(); + std::vector v1_queries(queries.size()); + std::transform(queries.begin(), queries.end(), v1_queries.begin(), [&](auto&& parsed) { + Query query(parsed.terms); + if (parsed.id) { + query.id(*parsed.id); + } + query.k(app.k); + return query; + }); + return v1_queries; + }(); - if (not app.documents_file) { - spdlog::error("Document lexicon not defined"); - std::exit(1); - } - auto source = std::make_shared(app.documents_file.value().c_str()); - auto docmap = pisa::Payload_Vector<>::from(*source); + if (not app.documents_file) { + spdlog::error("Document lexicon not defined"); + std::exit(1); + } + auto source = std::make_shared(app.documents_file.value().c_str()); + auto docmap = pisa::Payload_Vector<>::from(*source); - if (threshold_file) { - std::ifstream is(*threshold_file); - auto queries_iter = queries.begin(); - pisa::io::for_each_line(is, [&](auto&& line) { - if (queries_iter == queries.end()) { + if (threshold_file) { + std::ifstream is(*threshold_file); + auto queries_iter = queries.begin(); + pisa::io::for_each_line(is, [&](auto&& line) { + if (queries_iter == queries.end()) { + spdlog::error("Number of thresholds not equal to number of queries"); + 
std::exit(1); + } + queries_iter->threshold(std::stof(line)); + ++queries_iter; + }); + if (queries_iter != queries.end()) { spdlog::error("Number of thresholds not equal to number of queries"); std::exit(1); } - queries_iter->threshold(std::stof(line)); - ++queries_iter; - }); - if (queries_iter != queries.end()) { - spdlog::error("Number of thresholds not equal to number of queries"); - std::exit(1); } - } - if (inter_filename) { - auto const intersections = pisa::v1::read_intersections(*inter_filename); - if (intersections.size() != queries.size()) { - spdlog::error("Number of intersections is not equal to number of queries"); - std::exit(1); - } - /* auto unigrams = pisa::v1::filter_unigrams(intersections); */ - /* auto bigrams = pisa::v1::filter_bigrams(intersections); */ - - for (auto query_idx = 0; query_idx < queries.size(); query_idx += 1) { - queries[query_idx].add_selections(gsl::make_span(intersections[query_idx])); - // ListSelection{std::move(unigrams[query_idx]), std::move(bigrams[query_idx])}; + if (inter_filename) { + auto const intersections = pisa::v1::read_intersections(*inter_filename); + if (intersections.size() != queries.size()) { + spdlog::error("Number of intersections is not equal to number of queries"); + std::exit(1); + } + for (auto query_idx = 0; query_idx < queries.size(); query_idx += 1) { + queries[query_idx].add_selections(gsl::make_span(intersections[query_idx])); + } } - } - if (app.precomputed) { - auto run = scored_index_runner(meta, - RawReader{}, - RawReader{}, - BlockedReader<::pisa::simdbp_block, true>{}, - BlockedReader<::pisa::simdbp_block, false>{}); - run([&](auto&& index) { - if (app.is_benchmark) { - benchmark(queries, app.k, resolve_algorithm(algorithm, index, VoidScorer{})); - } else if (analyze) { - analyze_queries(queries, resolve_analyze(algorithm, index, VoidScorer{})); - } else { - evaluate(queries, app.k, docmap, resolve_algorithm(algorithm, index, VoidScorer{})); - } - }); - } else { - auto run = 
index_runner(meta, - RawReader{}, - BlockedReader<::pisa::simdbp_block, true>{}, - BlockedReader<::pisa::simdbp_block, false>{}); - run([&](auto&& index) { - auto with_scorer = scorer_runner(index, make_bm25(index)); - with_scorer("bm25", [&](auto scorer) { + if (app.precomputed) { + auto run = scored_index_runner(meta, + RawReader{}, + RawReader{}, + BlockedReader<::pisa::simdbp_block, true>{}, + BlockedReader<::pisa::simdbp_block, false>{}); + run([&](auto&& index) { if (app.is_benchmark) { - benchmark(queries, app.k, resolve_algorithm(algorithm, index, scorer)); + benchmark(queries, app.k, resolve_algorithm(algorithm, index, VoidScorer{})); } else if (analyze) { - analyze_queries(queries, resolve_analyze(algorithm, index, scorer)); + analyze_queries(queries, resolve_analyze(algorithm, index, VoidScorer{})); } else { - evaluate(queries, app.k, docmap, resolve_algorithm(algorithm, index, scorer)); + evaluate( + queries, app.k, docmap, resolve_algorithm(algorithm, index, VoidScorer{})); } }); - }); + } else { + auto run = index_runner(meta, + RawReader{}, + BlockedReader<::pisa::simdbp_block, true>{}, + BlockedReader<::pisa::simdbp_block, false>{}); + run([&](auto&& index) { + auto with_scorer = scorer_runner(index, make_bm25(index)); + with_scorer("bm25", [&](auto scorer) { + if (app.is_benchmark) { + benchmark(queries, app.k, resolve_algorithm(algorithm, index, scorer)); + } else if (analyze) { + analyze_queries(queries, resolve_analyze(algorithm, index, scorer)); + } else { + evaluate( + queries, app.k, docmap, resolve_algorithm(algorithm, index, scorer)); + } + }); + }); + } + } catch (std::exception const& error) { + spdlog::error("{}", error.what()); } return 0; } From c5cedf2e4702693fdef278ac6ea7ab6e4a82ab1c Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Thu, 5 Dec 2019 16:33:19 +0000 Subject: [PATCH 30/56] Union-lookup updates --- include/pisa/v1/analyze_query.hpp | 69 --- include/pisa/v1/cursor/transform.hpp | 75 +++ 
include/pisa/v1/cursor_accumulator.hpp | 19 + include/pisa/v1/daat_or.hpp | 4 +- include/pisa/v1/index_metadata.hpp | 2 +- include/pisa/v1/inspect_query.hpp | 69 +++ include/pisa/v1/maxscore.hpp | 49 +- include/pisa/v1/query.hpp | 70 +-- include/pisa/v1/union_lookup.hpp | 755 ++++++++++++------------- include/pisa/v1/zip_cursor.hpp | 9 +- script/cw09b.sh | 99 ++-- src/v1/index_builder.cpp | 11 +- src/v1/index_metadata.cpp | 13 +- src/v1/query.cpp | 24 + test/v1/index_fixture.hpp | 2 +- test/v1/test_v1_queries.cpp | 50 +- test/v1/test_v1_query.cpp | 23 + v1/CMakeLists.txt | 3 + v1/filter_queries.cpp | 11 - v1/query.cpp | 69 ++- v1/threshold.cpp | 108 ++++ 21 files changed, 873 insertions(+), 661 deletions(-) delete mode 100644 include/pisa/v1/analyze_query.hpp create mode 100644 include/pisa/v1/cursor/transform.hpp create mode 100644 include/pisa/v1/inspect_query.hpp create mode 100644 test/v1/test_v1_query.cpp create mode 100644 v1/threshold.cpp diff --git a/include/pisa/v1/analyze_query.hpp b/include/pisa/v1/analyze_query.hpp deleted file mode 100644 index 77f30ee3c..000000000 --- a/include/pisa/v1/analyze_query.hpp +++ /dev/null @@ -1,69 +0,0 @@ -#pragma once - -#include - -#include -#include - -namespace pisa::v1 { - -struct Query; - -struct QueryAnalyzer { - - template - explicit constexpr QueryAnalyzer(R writer) - : m_internal_analyzer(std::make_unique>(writer)) - { - } - QueryAnalyzer() = default; - QueryAnalyzer(QueryAnalyzer const& other) - : m_internal_analyzer(other.m_internal_analyzer->clone()) - { - } - QueryAnalyzer(QueryAnalyzer&& other) noexcept = default; - QueryAnalyzer& operator=(QueryAnalyzer const& other) = delete; - QueryAnalyzer& operator=(QueryAnalyzer&& other) noexcept = default; - ~QueryAnalyzer() = default; - - void operator()(Query const& query) { m_internal_analyzer->operator()(query); } - void summarize() && { std::move(*m_internal_analyzer).summarize(); } - - struct AnalyzerInterface { - AnalyzerInterface() = default; - 
AnalyzerInterface(AnalyzerInterface const&) = default; - AnalyzerInterface(AnalyzerInterface&&) noexcept = default; - AnalyzerInterface& operator=(AnalyzerInterface const&) = default; - AnalyzerInterface& operator=(AnalyzerInterface&&) noexcept = default; - virtual ~AnalyzerInterface() = default; - virtual void operator()(Query const& query) = 0; - virtual void summarize() && = 0; - [[nodiscard]] virtual auto clone() const -> std::unique_ptr = 0; - }; - - template - struct AnalyzerImpl : AnalyzerInterface { - explicit AnalyzerImpl(R analyzer) : m_analyzer(std::move(analyzer)) {} - AnalyzerImpl() = default; - AnalyzerImpl(AnalyzerImpl const&) = default; - AnalyzerImpl(AnalyzerImpl&&) noexcept = default; - AnalyzerImpl& operator=(AnalyzerImpl const&) = default; - AnalyzerImpl& operator=(AnalyzerImpl&&) noexcept = default; - ~AnalyzerImpl() override = default; - void operator()(Query const& query) override { m_analyzer(query); } - void summarize() && override { std::move(m_analyzer).summarize(); } - [[nodiscard]] auto clone() const -> std::unique_ptr override - { - auto copy = *this; - return std::make_unique>(std::move(copy)); - } - - private: - R m_analyzer; - }; - - private: - std::unique_ptr m_internal_analyzer; -}; - -} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor/transform.hpp b/include/pisa/v1/cursor/transform.hpp new file mode 100644 index 000000000..b53deae8e --- /dev/null +++ b/include/pisa/v1/cursor/transform.hpp @@ -0,0 +1,75 @@ +#pragma once + +#include + +#include + +namespace pisa::v1 { + +template +struct TransformCursor { + using Value = + decltype(std::declval(std::declval>())); + + constexpr TransformCursor(Cursor cursor, TransformFn transform) + : m_cursor(std::move(cursor)), m_transform(std::move(transform)) + { + } + + [[nodiscard]] constexpr auto operator*() const -> Value { return value(); } + [[nodiscard]] constexpr auto value() const noexcept -> Value + { + return m_transform(m_cursor.value()); + } + constexpr void advance() { 
m_cursor.advance(); } + constexpr void advance_to_position(std::size_t pos) { m_cursor.advance_to_position(pos); } + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_cursor.empty(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return m_cursor.position(); + } + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_cursor.size(); } + + private: + Cursor m_cursor; + TransformFn m_transform; +}; + +template +auto transform(Cursor cursor, TransformFn transform) +{ + return TransformCursor(std::move(cursor), std::move(transform)); +} + +template +struct TransformPayloadCursor { + constexpr TransformPayloadCursor(Cursor cursor, TransformFn transform) + : m_cursor(std::move(cursor)), m_transform(std::move(transform)) + { + } + + [[nodiscard]] constexpr auto operator*() const { return value(); } + [[nodiscard]] constexpr auto value() const noexcept { return m_cursor.value(); } + [[nodiscard]] constexpr auto payload() { return m_transform(m_cursor); } + constexpr void advance() { m_cursor.advance(); } + constexpr void advance_to_position(std::size_t pos) { m_cursor.advance_to_position(pos); } + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_cursor.empty(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return m_cursor.position(); + } + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_cursor.size(); } + [[nodiscard]] constexpr auto sentinel() const { return m_cursor.sentinel(); } + + private: + Cursor m_cursor; + TransformFn m_transform; +}; + +template +auto transform_payload(Cursor cursor, TransformFn transform) +{ + return TransformPayloadCursor(std::move(cursor), std::move(transform)); +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor_accumulator.hpp b/include/pisa/v1/cursor_accumulator.hpp index 5eb1df5a4..760cc92c8 100644 --- a/include/pisa/v1/cursor_accumulator.hpp +++ b/include/pisa/v1/cursor_accumulator.hpp 
@@ -1,6 +1,7 @@ #pragma once #include +#include namespace pisa::v1::accumulate { @@ -13,4 +14,22 @@ struct Add { } }; +template +struct InspectAdd { + constexpr explicit InspectAdd(Inspect* inspect) : m_inspect(inspect) {} + + template + auto operator()(Score&& score, Cursor&& cursor, std::size_t /* term_idx */) + { + if constexpr (not std::is_void_v) { + m_inspect->posting(); + } + score += cursor.payload(); + return score; + } + + private: + Inspect* m_inspect; +}; + } // namespace pisa::v1::accumulate diff --git a/include/pisa/v1/daat_or.hpp b/include/pisa/v1/daat_or.hpp index f9b1f9cd7..9c394044c 100644 --- a/include/pisa/v1/daat_or.hpp +++ b/include/pisa/v1/daat_or.hpp @@ -26,8 +26,8 @@ auto daat_or(Query const& query, Index const& index, topk_queue topk, Scorer&& s } template -struct DaatOrAnalyzer { - DaatOrAnalyzer(Index const& index, Scorer scorer) : m_index(index), m_scorer(std::move(scorer)) +struct DaatOrInspector { + DaatOrInspector(Index const& index, Scorer scorer) : m_index(index), m_scorer(std::move(scorer)) { std::cout << fmt::format("documents\tpostings\n"); } diff --git a/include/pisa/v1/index_metadata.hpp b/include/pisa/v1/index_metadata.hpp index f7d0623d2..475428b15 100644 --- a/include/pisa/v1/index_metadata.hpp +++ b/include/pisa/v1/index_metadata.hpp @@ -161,7 +161,7 @@ template tl::optional> bigram_documents{}; tl::optional, 2>> bigram_scores{}; tl::optional const>> bigram_mapping{}; - if (metadata.bigrams) { + if (metadata.bigrams && not metadata.bigrams->scores.empty()) { bigram_document_offsets = source_span(source, metadata.bigrams->documents.offsets); bigram_score_offsets = { diff --git a/include/pisa/v1/inspect_query.hpp b/include/pisa/v1/inspect_query.hpp new file mode 100644 index 000000000..bdf7b236f --- /dev/null +++ b/include/pisa/v1/inspect_query.hpp @@ -0,0 +1,69 @@ +#pragma once + +#include + +#include +#include + +namespace pisa::v1 { + +struct Query; + +struct QueryInspector { + + template + explicit constexpr 
QueryInspector(R writer) + : m_internal_analyzer(std::make_unique>(writer)) + { + } + QueryInspector() = default; + QueryInspector(QueryInspector const& other) + : m_internal_analyzer(other.m_internal_analyzer->clone()) + { + } + QueryInspector(QueryInspector&& other) noexcept = default; + QueryInspector& operator=(QueryInspector const& other) = delete; + QueryInspector& operator=(QueryInspector&& other) noexcept = default; + ~QueryInspector() = default; + + void operator()(Query const& query) { m_internal_analyzer->operator()(query); } + void summarize() && { std::move(*m_internal_analyzer).summarize(); } + + struct InspectorInterface { + InspectorInterface() = default; + InspectorInterface(InspectorInterface const&) = default; + InspectorInterface(InspectorInterface&&) noexcept = default; + InspectorInterface& operator=(InspectorInterface const&) = default; + InspectorInterface& operator=(InspectorInterface&&) noexcept = default; + virtual ~InspectorInterface() = default; + virtual void operator()(Query const& query) = 0; + virtual void summarize() && = 0; + [[nodiscard]] virtual auto clone() const -> std::unique_ptr = 0; + }; + + template + struct InspectorImpl : InspectorInterface { + explicit InspectorImpl(R analyzer) : m_analyzer(std::move(analyzer)) {} + InspectorImpl() = default; + InspectorImpl(InspectorImpl const&) = default; + InspectorImpl(InspectorImpl&&) noexcept = default; + InspectorImpl& operator=(InspectorImpl const&) = default; + InspectorImpl& operator=(InspectorImpl&&) noexcept = default; + ~InspectorImpl() override = default; + void operator()(Query const& query) override { m_analyzer(query); } + void summarize() && override { std::move(m_analyzer).summarize(); } + [[nodiscard]] auto clone() const -> std::unique_ptr override + { + auto copy = *this; + return std::make_unique>(std::move(copy)); + } + + private: + R m_analyzer; + }; + + private: + std::unique_ptr m_internal_analyzer; +}; + +} // namespace pisa::v1 diff --git 
a/include/pisa/v1/maxscore.hpp b/include/pisa/v1/maxscore.hpp index b35b17219..80531e886 100644 --- a/include/pisa/v1/maxscore.hpp +++ b/include/pisa/v1/maxscore.hpp @@ -16,7 +16,7 @@ template + typename Inspect = void> struct MaxScoreJoin { using cursor_type = typename CursorContainer::value_type; using payload_type = Payload; @@ -46,7 +46,7 @@ struct MaxScoreJoin { Payload init, AccumulateFn accumulate, ThresholdFn above_threshold, - Analyzer* analyzer) + Inspect* inspect) : m_cursors(std::move(cursors)), m_cursor_idx(m_cursors.size()), m_upper_bounds(m_cursors.size()), @@ -54,7 +54,7 @@ struct MaxScoreJoin { m_accumulate(std::move(accumulate)), m_above_threshold(std::move(above_threshold)), m_size(std::nullopt), - m_analyzer(analyzer) + m_inspect(inspect) { initialize(); } @@ -107,8 +107,8 @@ struct MaxScoreJoin { m_current_payload = m_init; m_current_value = std::exchange(m_next_docid, sentinel()); - if constexpr (not std::is_void_v) { - m_analyzer->document(); + if constexpr (not std::is_void_v) { + m_inspect->document(); } for (auto sorted_position = m_non_essential_count; sorted_position < m_cursors.size(); @@ -116,8 +116,8 @@ struct MaxScoreJoin { auto& cursor = m_cursors[sorted_position]; if (cursor.value() == m_current_value) { - if constexpr (not std::is_void_v) { - m_analyzer->posting(); + if constexpr (not std::is_void_v) { + m_inspect->posting(); } m_current_payload = m_accumulate(m_current_payload, cursor, m_cursor_idx[sorted_position]); @@ -137,8 +137,8 @@ struct MaxScoreJoin { } auto& cursor = m_cursors[sorted_position]; cursor.advance_to_geq(m_current_value); - if constexpr (not std::is_void_v) { - m_analyzer->lookup(); + if constexpr (not std::is_void_v) { + m_inspect->lookup(); } if (cursor.value() == m_current_value) { m_current_payload = @@ -175,7 +175,7 @@ struct MaxScoreJoin { std::size_t m_non_essential_count = 0; payload_type m_previous_threshold{}; - Analyzer* m_analyzer; + Inspect* m_inspect; }; template @@ -192,15 +192,15 @@ template + 
typename Inspect> auto join_maxscore(CursorContainer cursors, Payload init, AccumulateFn accumulate, ThresholdFn threshold, - Analyzer* analyzer) + Inspect* inspect) { - return MaxScoreJoin( - std::move(cursors), std::move(init), std::move(accumulate), std::move(threshold), analyzer); + return MaxScoreJoin( + std::move(cursors), std::move(init), std::move(accumulate), std::move(threshold), inspect); } template @@ -228,8 +228,8 @@ auto maxscore(Query const& query, Index const& index, topk_queue topk, Scorer&& } template -struct MaxscoreAnalyzer { - MaxscoreAnalyzer(Index const& index, Scorer scorer) +struct MaxscoreInspector { + MaxscoreInspector(Index const& index, Scorer scorer) : m_index(index), m_scorer(std::move(scorer)) { std::cout << fmt::format("documents\tpostings\tinserts\tlookups\n"); @@ -258,14 +258,15 @@ struct MaxscoreAnalyzer { topk_queue topk(query.k()); auto initial_threshold = query.threshold().value_or(-1.0); topk.set_threshold(initial_threshold); - auto joined = join_maxscore(std::move(cursors), - 0.0F, - [&](auto& score, auto& cursor, auto /* term_position */) { - score += cursor.payload(); - return score; - }, - [&](auto score) { return topk.would_enter(score); }, - this); + auto joined = join_maxscore( + std::move(cursors), + 0.0F, + [&](auto& score, auto& cursor, auto /* term_position */) { + score += cursor.payload(); + return score; + }, + [&](auto score) { return topk.would_enter(score); }, + this); v1::for_each(joined, [&](auto& cursor) { if (topk.insert(cursor.payload(), cursor.value())) { inserts += 1; diff --git a/include/pisa/v1/query.hpp b/include/pisa/v1/query.hpp index 7c0a071b5..c8e8564a3 100644 --- a/include/pisa/v1/query.hpp +++ b/include/pisa/v1/query.hpp @@ -10,10 +10,10 @@ #include #include "topk_queue.hpp" -#include "v1/analyze_query.hpp" #include "v1/cursor/for_each.hpp" #include "v1/cursor_intersection.hpp" #include "v1/cursor_union.hpp" +#include "v1/inspect_query.hpp" #include "v1/intersection.hpp" #include 
"v1/types.hpp" @@ -22,34 +22,9 @@ namespace pisa::v1 { struct ListSelection { std::vector unigrams{}; std::vector> bigrams{}; -}; - -template -std::ostream& operator<<(std::ostream& os, std::pair const& p) -{ - return os << '(' << p.first << ", " << p.second << ')'; -} - -template -std::ostream& operator<<(std::ostream& os, std::vector const& vec) -{ - auto pos = vec.begin(); - os << '['; - if (pos != vec.end()) { - os << *pos++; - } - for (; pos != vec.end(); ++pos) { - os << ' ' << *pos; - } - os << ']'; - return os; -} -inline std::ostream& operator<<(std::ostream& os, ListSelection const& selection) -{ - return os << "ListSelection { unigrams: " << selection.unigrams - << ", bigrams: " << selection.bigrams << " }"; -} + [[nodiscard]] auto overlapping() const -> bool; +}; struct TermIdSet { explicit TermIdSet(std::vector terms) : m_term_list(std::move(terms)) @@ -85,12 +60,6 @@ struct TermIdSet { std::unordered_map m_sorted_positions{}; }; -inline std::ostream& operator<<(std::ostream& os, TermIdSet const& term_ids) -{ - return os << "TermIdSet { original: " << term_ids.m_term_list - << ", unique: " << term_ids.m_term_set << " }"; -} - struct Query { Query() = default; @@ -147,6 +116,39 @@ std::ostream& operator<<(std::ostream& os, tl::optional const& value) return os; } +template +std::ostream& operator<<(std::ostream& os, std::pair const& p) +{ + return os << '(' << p.first << ", " << p.second << ')'; +} + +template +std::ostream& operator<<(std::ostream& os, std::vector const& vec) +{ + auto pos = vec.begin(); + os << '['; + if (pos != vec.end()) { + os << *pos++; + } + for (; pos != vec.end(); ++pos) { + os << ' ' << *pos; + } + os << ']'; + return os; +} + +inline std::ostream& operator<<(std::ostream& os, ListSelection const& selection) +{ + return os << "ListSelection { unigrams: " << selection.unigrams + << ", bigrams: " << selection.bigrams << " }"; +} + +inline std::ostream& operator<<(std::ostream& os, TermIdSet const& term_ids) +{ + return os << 
"TermIdSet { original: " << term_ids.m_term_list + << ", unique: " << term_ids.m_term_set << " }"; +} + inline std::ostream& operator<<(std::ostream& os, Query const& query) { return os << "Query { term_ids: " << query.m_term_ids << ", selections: " << query.m_selections diff --git a/include/pisa/v1/union_lookup.hpp b/include/pisa/v1/union_lookup.hpp index 59f3ce5ec..fb88a6c42 100644 --- a/include/pisa/v1/union_lookup.hpp +++ b/include/pisa/v1/union_lookup.hpp @@ -8,6 +8,8 @@ #include #include "v1/algorithm.hpp" +#include "v1/cursor/transform.hpp" +#include "v1/cursor_accumulator.hpp" #include "v1/query.hpp" namespace pisa::v1 { @@ -16,7 +18,7 @@ template + typename Inspect = void> struct UnionLookupJoin { using cursor_type = typename CursorContainer::value_type; using payload_type = Payload; @@ -27,13 +29,13 @@ struct UnionLookupJoin { static_assert(std::is_base_of(), "cursors must be stored in a random access container"); - constexpr UnionLookupJoin(CursorContainer cursors, - std::vector cursor_idx, - std::vector upper_bounds, - std::size_t non_essential_count, - Payload init, - AccumulateFn accumulate, - ThresholdFn above_threshold) + UnionLookupJoin(CursorContainer cursors, + std::vector cursor_idx, + std::vector upper_bounds, + std::size_t non_essential_count, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold) : m_cursors(std::move(cursors)), m_cursor_idx(std::move(cursor_idx)), m_upper_bounds(std::move(upper_bounds)), @@ -46,14 +48,14 @@ struct UnionLookupJoin { initialize(); } - constexpr UnionLookupJoin(CursorContainer cursors, - std::vector cursor_idx, - std::vector upper_bounds, - std::size_t non_essential_count, - Payload init, - AccumulateFn accumulate, - ThresholdFn above_threshold, - Analyzer* analyzer) + UnionLookupJoin(CursorContainer cursors, + std::vector cursor_idx, + std::vector upper_bounds, + std::size_t non_essential_count, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold, + Inspect* inspect) : 
m_cursors(std::move(cursors)), m_cursor_idx(std::move(cursor_idx)), m_upper_bounds(std::move(upper_bounds)), @@ -62,7 +64,7 @@ struct UnionLookupJoin { m_accumulate(std::move(accumulate)), m_above_threshold(std::move(above_threshold)), m_size(std::nullopt), - m_analyzer(analyzer) + m_inspect(inspect) { initialize(); } @@ -102,8 +104,8 @@ struct UnionLookupJoin { m_current_payload = m_init; m_current_value = std::exchange(m_next_docid, sentinel()); - if constexpr (not std::is_void_v) { - m_analyzer->document(); + if constexpr (not std::is_void_v) { + m_inspect->document(); } for (auto sorted_position = m_non_essential_count; sorted_position < m_cursors.size(); @@ -111,8 +113,8 @@ struct UnionLookupJoin { auto& cursor = m_cursors[sorted_position]; if (cursor.value() == m_current_value) { - if constexpr (not std::is_void_v) { - m_analyzer->posting(); + if constexpr (not std::is_void_v) { + m_inspect->posting(); } m_current_payload = m_accumulate(m_current_payload, cursor, m_cursor_idx[sorted_position]); @@ -132,8 +134,8 @@ struct UnionLookupJoin { } auto& cursor = m_cursors[sorted_position]; cursor.advance_to_geq(m_current_value); - if constexpr (not std::is_void_v) { - m_analyzer->lookup(); + if constexpr (not std::is_void_v) { + m_inspect->lookup(); } if (cursor.value() == m_current_value) { m_current_payload = @@ -165,7 +167,7 @@ struct UnionLookupJoin { std::uint32_t m_next_docid{}; payload_type m_previous_threshold{}; - Analyzer* m_analyzer; + Inspect* m_inspect; }; template @@ -191,7 +193,7 @@ template + typename Inspect> auto join_union_lookup(CursorContainer cursors, std::vector cursor_idx, std::vector upper_bounds, @@ -199,9 +201,9 @@ auto join_union_lookup(CursorContainer cursors, Payload init, AccumulateFn accumulate, ThresholdFn threshold, - Analyzer* analyzer) + Inspect* inspect) { - return UnionLookupJoin( + return UnionLookupJoin( std::move(cursors), std::move(cursor_idx), std::move(upper_bounds), @@ -209,31 +211,31 @@ auto 
join_union_lookup(CursorContainer cursors, std::move(init), std::move(accumulate), std::move(threshold), - analyzer); + inspect); } namespace detail { - template + template auto unigram_union_lookup(Cursors cursors, UpperBounds upper_bounds, std::size_t non_essential_count, topk_queue topk, - [[maybe_unused]] Analyzer* analyzer = nullptr) + [[maybe_unused]] Inspect* inspect = nullptr) { auto merged_essential = v1::union_merge(gsl::make_span(cursors).subspan(non_essential_count), 0.0F, [&](auto& acc, auto& cursor, auto /*term_idx*/) { - if constexpr (not std::is_void_v) { - analyzer->posting(); + if constexpr (not std::is_void_v) { + inspect->posting(); } acc += cursor.payload(); return acc; }); v1::for_each(merged_essential, [&](auto&& cursor) { - if constexpr (not std::is_void_v) { - analyzer->document(); + if constexpr (not std::is_void_v) { + inspect->document(); } auto docid = cursor.value(); auto score = cursor.payload(); @@ -244,16 +246,16 @@ namespace detail { } auto& lookup_cursor = cursors[lookup_cursor_idx]; lookup_cursor.advance_to_geq(docid); - if constexpr (not std::is_void_v) { - analyzer->lookup(); + if constexpr (not std::is_void_v) { + inspect->lookup(); } if (lookup_cursor.value() == docid) { score += lookup_cursor.payload(); } } - if constexpr (not std::is_void_v) { + if constexpr (not std::is_void_v) { if (topk.insert(score, docid)) { - analyzer->insert(); + inspect->insert(); } } else { topk.insert(score, docid); @@ -263,12 +265,12 @@ namespace detail { } } // namespace detail -template +template auto unigram_union_lookup(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer, - [[maybe_unused]] Analyzer* analyzer = nullptr) + [[maybe_unused]] Inspect* inspect = nullptr) { auto const& term_ids = query.get_term_ids(); if (term_ids.empty()) { @@ -310,30 +312,42 @@ auto unigram_union_lookup(Query const& query, for (size_t i = 1; i < cursors.size(); ++i) { upper_bounds[i] = upper_bounds[i - 1] + cursors[i].max_score(); } - // 
TODO: + + // We're not interested in these being correct std::vector cursor_idx(cursors.size()); std::iota(cursor_idx.begin(), cursor_idx.end(), 0); + auto accumulate = [](auto& score, auto& cursor, auto /* term_position */) { score += cursor.payload(); return score; }; - auto joined = join_union_lookup(std::move(cursors), - std::move(cursor_idx), - std::move(upper_bounds), - non_essential_count, - payload_type{}, - accumulate, - [&](auto score) { return topk.would_enter(score); }); - v1::for_each(joined, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); + auto joined = join_union_lookup( + std::move(cursors), + std::move(cursor_idx), + std::move(upper_bounds), + non_essential_count, + payload_type{}, + accumulate, + [&](auto score) { return topk.would_enter(score); }, + inspect); + v1::for_each(joined, [&](auto& cursor) { + if constexpr (not std::is_void_v) { + if (topk.insert(cursor.payload(), cursor.value())) { + inspect->insert(); + } + } else { + topk.insert(cursor.payload(), cursor.value()); + } + }); return topk; } -template +template auto maxscore_union_lookup(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer, - [[maybe_unused]] Analyzer* analyzer = nullptr) + [[maybe_unused]] Inspect* inspect = nullptr) { using cursor_type = decltype(index.max_scored_cursor(0, scorer)); using payload_type = decltype(std::declval().payload()); @@ -369,20 +383,30 @@ auto maxscore_union_lookup(Query const& query, score += cursor.payload(); return score; }; - auto joined = join_union_lookup(std::move(cursors), - std::move(cursor_idx), - std::move(upper_bounds), - non_essential_count, - payload_type{}, - accumulate, - [&](auto score) { return topk.would_enter(score); }); - v1::for_each(joined, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); + auto joined = join_union_lookup( + std::move(cursors), + std::move(cursor_idx), + std::move(upper_bounds), + non_essential_count, + payload_type{}, + accumulate, + [&](auto 
score) { return topk.would_enter(score); }, + inspect); + v1::for_each(joined, [&](auto& cursor) { + if constexpr (not std::is_void_v) { + if (topk.insert(cursor.payload(), cursor.value())) { + inspect->insert(); + } + } else { + topk.insert(cursor.payload(), cursor.value()); + } + }); return topk; } template -struct BaseUnionLookupAnalyzer { - BaseUnionLookupAnalyzer(Index const& index, Scorer scorer) +struct BaseUnionLookupInspect { + BaseUnionLookupInspect(Index const& index, Scorer scorer) : m_index(index), m_scorer(std::move(scorer)) { std::cout << fmt::format("documents\tpostings\tinserts\tlookups\n"); @@ -453,9 +477,9 @@ struct BaseUnionLookupAnalyzer { }; template -struct MaxscoreUnionLookupAnalyzer : public BaseUnionLookupAnalyzer { - MaxscoreUnionLookupAnalyzer(Index const& index, Scorer scorer) - : BaseUnionLookupAnalyzer(index, std::move(scorer)) +struct MaxscoreUnionLookupInspect : public BaseUnionLookupInspect { + MaxscoreUnionLookupInspect(Index const& index, Scorer scorer) + : BaseUnionLookupInspect(index, std::move(scorer)) { } void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override @@ -465,9 +489,9 @@ struct MaxscoreUnionLookupAnalyzer : public BaseUnionLookupAnalyzer -struct UnigramUnionLookupAnalyzer : public BaseUnionLookupAnalyzer { - UnigramUnionLookupAnalyzer(Index const& index, Scorer scorer) - : BaseUnionLookupAnalyzer(index, std::move(scorer)) +struct UnigramUnionLookupInspect : public BaseUnionLookupInspect { + UnigramUnionLookupInspect(Index const& index, Scorer scorer) + : BaseUnionLookupInspect(index, std::move(scorer)) { } void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override @@ -477,14 +501,16 @@ struct UnigramUnionLookupAnalyzer : public BaseUnionLookupAnalyzer -struct UnionLookupAnalyzer : public BaseUnionLookupAnalyzer { - UnionLookupAnalyzer(Index const& index, Scorer scorer) - : BaseUnionLookupAnalyzer(index, std::move(scorer)) +struct UnionLookupInspect : 
public BaseUnionLookupInspect { + UnionLookupInspect(Index const& index, Scorer scorer) + : BaseUnionLookupInspect(index, std::move(scorer)) { } void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override { - if (query.get_term_ids().size() > 8) { + if (query.selections()->bigrams.empty()) { + unigram_union_lookup(query, index, std::move(topk), scorer, this); + } else if (query.get_term_ids().size() > 8) { maxscore_union_lookup(query, index, std::move(topk), scorer, this); } else { union_lookup(query, index, std::move(topk), scorer, this); @@ -493,241 +519,214 @@ struct UnionLookupAnalyzer : public BaseUnionLookupAnalyzer { }; template -struct TwoPhaseUnionLookupAnalyzer : public BaseUnionLookupAnalyzer { - TwoPhaseUnionLookupAnalyzer(Index const& index, Scorer scorer) - : BaseUnionLookupAnalyzer(index, std::move(scorer)) +struct LookupUnionInspector : public BaseUnionLookupInspect { + LookupUnionInspector(Index const& index, Scorer scorer) + : BaseUnionLookupInspect(index, std::move(scorer)) { } void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override { - if (query.get_term_ids().size() > 8) { - maxscore_union_lookup(query, index, std::move(topk), scorer, this); + if (query.selections()->bigrams.empty()) { + unigram_union_lookup(query, index, std::move(topk), scorer, this); } else { - two_phase_union_lookup(query, index, std::move(topk), scorer, this); + lookup_union(query, index, std::move(topk), scorer, this); } } }; -template -struct BigramUnionLookupJoin { - using cursor_type = typename CursorContainer::value_type; - using payload_type = Payload; - using value_type = std::decay_t())>; - - using iterator_category = - typename std::iterator_traits::iterator_category; - static_assert(std::is_base_of(), - "cursors must be stored in a random access container"); +// template +// struct TwoPhaseUnionLookupInspect : public BaseUnionLookupInspect { +// TwoPhaseUnionLookupInspect(Index const& index, Scorer 
scorer) +// : BaseUnionLookupInspect(index, std::move(scorer)) +// { +// } +// void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override +// { +// // if (query.get_term_ids().size() > 8) { +// maxscore_union_lookup(query, index, std::move(topk), scorer, this); +// //} else { +// // two_phase_union_lookup(query, index, std::move(topk), scorer, this); +// //} +// } +//}; + +/// This one assumes that terms do not repeat in essential lists and intersections. +template +auto disjoint_union_lookup(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + Inspect* inspect = nullptr) +{ + using bigram_cursor_type = std::decay_t; - constexpr BigramUnionLookupJoin(CursorContainer cursors, - std::vector cursor_idx, - std::vector upper_bounds, - std::size_t non_essential_count, - Payload init, - AccumulateFn accumulate, - ThresholdFn above_threshold) - : m_cursors(std::move(cursors)), - m_cursor_idx(std::move(cursor_idx)), - m_upper_bounds(std::move(upper_bounds)), - m_non_essential_count(non_essential_count), - m_init(std::move(init)), - m_accumulate(std::move(accumulate)), - m_above_threshold(std::move(above_threshold)), - m_size(std::nullopt) - { - initialize(); + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { + return topk; } - constexpr BigramUnionLookupJoin(CursorContainer cursors, - std::vector cursor_idx, - std::vector upper_bounds, - std::size_t non_essential_count, - Payload init, - AccumulateFn accumulate, - ThresholdFn above_threshold, - Analyzer* analyzer) - : m_cursors(std::move(cursors)), - m_cursor_idx(std::move(upper_bounds)), - m_upper_bounds(std::move(upper_bounds)), - m_non_essential_count(non_essential_count), - m_init(std::move(init)), - m_accumulate(std::move(accumulate)), - m_above_threshold(std::move(above_threshold)), - m_size(std::nullopt), - m_analyzer(analyzer) - { - initialize(); - } + auto threshold = query.get_threshold(); + topk.set_threshold(threshold); - void 
initialize() - { - if (m_cursors.empty()) { - m_current_value = sentinel(); - m_current_payload = m_init; + auto const& selections = query.get_selections(); + auto const& essential_unigrams = selections.unigrams; + auto const& essential_bigrams = selections.bigrams; + + auto const non_essential_terms = [&]() { + auto all_essential_terms = essential_unigrams; + for (auto [left, right] : essential_bigrams) { + all_essential_terms.push_back(left); + all_essential_terms.push_back(right); } - m_next_docid = min_value(m_cursors); - m_sentinel = min_sentinel(m_cursors); - advance(); - } + ranges::sort(all_essential_terms); + ranges::actions::unique(all_essential_terms); + return ranges::views::set_difference(term_ids, all_essential_terms) | ranges::to_vector; + }(); - [[nodiscard]] constexpr auto operator*() const noexcept -> value_type - { - return m_current_value; - } - [[nodiscard]] constexpr auto value() const noexcept -> value_type { return m_current_value; } - [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& - { - return m_current_payload; + auto essential_unigram_cursors = + index.scored_cursors(gsl::make_span(essential_unigrams), scorer); + auto merged_unigrams = v1::union_merge( + essential_unigram_cursors, 0.0F, pisa::v1::accumulate::InspectAdd(inspect)); + + std::vector essential_bigram_cursors; + for (auto [left, right] : essential_bigrams) { + auto cursor = index.scored_bigram_cursor(left, right, scorer); + if (not cursor) { + throw std::runtime_error(fmt::format("Bigram not found: <{}, {}>", left, right)); + } + essential_bigram_cursors.push_back(cursor.take().value()); } - [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_sentinel; } - constexpr void advance() - { - bool exit = false; - while (not exit) { - if (PISA_UNLIKELY(m_non_essential_count == m_cursors.size() - || m_next_docid >= sentinel())) { - m_current_value = sentinel(); - m_current_payload = m_init; + auto merged_bigrams = + 
v1::union_merge(std::move(essential_bigram_cursors), + 0.0F, + [&](auto& acc, auto& cursor, [[maybe_unused]] auto bigram_idx) { + if constexpr (not std::is_void_v) { + inspect->posting(); + } + auto payload = cursor.payload(); + acc += std::get<0>(payload) + std::get<1>(payload); + return acc; + }); + auto merged = v1::variadic_union_merge( + 0.0F, + std::make_tuple(std::move(merged_unigrams), std::move(merged_bigrams)), + std::make_tuple(pisa::v1::accumulate::Add{}, pisa::v1::accumulate::Add{})); + + auto lookup_cursors = index.max_scored_cursors(non_essential_terms, scorer); + std::sort(lookup_cursors.begin(), lookup_cursors.end(), [](auto&& lhs, auto&& rhs) { + return lhs.max_score() > rhs.max_score(); + }); + auto lookup_cursors_upper_bound = std::accumulate( + lookup_cursors.begin(), lookup_cursors.end(), 0.0F, [](auto acc, auto&& cursor) { + return acc + cursor.max_score(); + }); + + v1::for_each(merged, [&](auto& cursor) { + if constexpr (not std::is_void_v) { + inspect->document(); + } + auto docid = cursor.value(); + auto score = cursor.payload(); + auto upper_bound = score + lookup_cursors_upper_bound; + for (auto& lookup_cursor : lookup_cursors) { + if (not topk.would_enter(upper_bound)) { return; } - m_current_payload = m_init; - m_current_value = std::exchange(m_next_docid, sentinel()); - - if constexpr (not std::is_void_v) { - m_analyzer->document(); + lookup_cursor.advance_to_geq(docid); + if constexpr (not std::is_void_v) { + inspect->lookup(); } - - for (auto sorted_position = m_non_essential_count; sorted_position < m_cursors.size(); - sorted_position += 1) { - - auto& cursor = m_cursors[sorted_position]; - if (cursor.value() == m_current_value) { - if constexpr (not std::is_void_v) { - m_analyzer->posting(); - } - m_current_payload = - m_accumulate(m_current_payload, cursor, m_cursor_idx[sorted_position]); - cursor.advance(); - } - if (auto docid = cursor.value(); docid < m_next_docid) { - m_next_docid = docid; - } + if 
(PISA_UNLIKELY(lookup_cursor.value() == docid)) { + auto partial_score = lookup_cursor.payload(); + score += partial_score; + upper_bound += partial_score; } - - exit = true; - for (auto sorted_position = m_non_essential_count - 1; sorted_position + 1 > 0; - sorted_position -= 1) { - if (not m_above_threshold(m_current_payload + m_upper_bounds[sorted_position])) { - exit = false; - break; - } - auto& cursor = m_cursors[sorted_position]; - cursor.advance_to_geq(m_current_value); - if constexpr (not std::is_void_v) { - m_analyzer->lookup(); - } - if (cursor.value() == m_current_value) { - m_current_payload = - m_accumulate(m_current_payload, cursor, m_cursor_idx[sorted_position]); - } + upper_bound -= lookup_cursor.max_score(); + } + if constexpr (not std::is_void_v) { + if (topk.insert(score, docid)) { + inspect->insert(); } + } else { + topk.insert(score, docid); } + }); + return topk; +} + +template +struct LookupTransform { + + LookupTransform(std::vector lookup_cursors, + std::vector upper_bounds, + float lookup_cursors_upper_bound, + AboveThresholdFn above_threshold, + Inspector* inspect = nullptr) + : m_lookup_cursors(std::move(lookup_cursors)), + m_upper_bounds(std::move(upper_bounds)), + m_lookup_cursors_upper_bound(lookup_cursors_upper_bound), + m_above_threshold(std::move(above_threshold)), + m_inspect(inspect) + { } - [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; // TODO(michal) - [[nodiscard]] constexpr auto empty() const noexcept -> bool + auto operator()(Cursor& cursor) { - return m_current_value >= sentinel(); + if constexpr (not std::is_void_v) { + m_inspect->document(); + m_inspect->posting(); + } + auto docid = cursor.value(); + auto scores = cursor.payload(); + float score = std::get<0>(scores) + std::get<1>(scores); + auto upper_bound = score + m_lookup_cursors_upper_bound; + for (auto& lookup_cursor : m_lookup_cursors) { + if (not m_above_threshold(upper_bound)) { + return score; + } + 
lookup_cursor.advance_to_geq(docid); + if constexpr (not std::is_void_v) { + m_inspect->lookup(); + } + if (PISA_UNLIKELY(lookup_cursor.value() == docid)) { + auto partial_score = lookup_cursor.payload(); + score += partial_score; + upper_bound += partial_score; + } + upper_bound -= lookup_cursor.max_score(); + } + return score; } private: - CursorContainer m_cursors; - std::vector m_cursor_idx; - std::vector m_upper_bounds; - std::size_t m_non_essential_count; - payload_type m_init; - AccumulateFn m_accumulate; - ThresholdFn m_above_threshold; - std::optional m_size; - - value_type m_current_value{}; - value_type m_sentinel{}; - payload_type m_current_payload{}; - std::uint32_t m_next_docid{}; - payload_type m_previous_threshold{}; - - Analyzer* m_analyzer; + std::vector m_lookup_cursors; + std::vector m_upper_bounds; + float m_lookup_cursors_upper_bound; + AboveThresholdFn m_above_threshold; + Inspector* m_inspect; }; -// template auto join_union_lookup(CursorContainer cursors, -// std::vector cursor_idx, -// std::vector upper_bounds, -// std::size_t non_essential_count, -// Payload init, -// AccumulateFn accumulate, -// ThresholdFn threshold) -//{ -// return BigramUnionLookupJoin( -// std::move(cursors), -// std::move(cursor_idx), -// std::move(upper_bounds), -// non_essential_count, -// std::move(init), -// std::move(accumulate), -// std::move(threshold)); -//} -// -// template -// auto join_union_lookup(CursorContainer cursors, -// std::vector cursor_idx, -// std::vector upper_bounds, -// std::size_t non_essential_count, -// Payload init, -// AccumulateFn accumulate, -// ThresholdFn threshold, -// Analyzer* analyzer) -//{ -// return BigramUnionLookupJoin( -// std::move(cursors), -// std::move(cursor_idx), -// std::move(upper_bounds), -// non_essential_count, -// std::move(init), -// std::move(accumulate), -// std::move(threshold), -// analyzer); -//} - -template -auto union_lookup(Query const& query, +template +auto lookup_union(Query const& query, Index 
const& index, topk_queue topk, Scorer&& scorer, - Analyzer* analyzer = nullptr) + Inspector* inspector = nullptr) { auto const& term_ids = query.get_term_ids(); if (term_ids.empty()) { return topk; } - if (term_ids.size() > 8) { - throw std::invalid_argument( - "Generic version of union-Lookup supported only for queries of length <= 8"); - } auto threshold = query.get_threshold(); auto const& selections = query.get_selections(); using bigram_cursor_type = std::decay_t; + using lookup_cursor_type = std::decay_t; auto& essential_unigrams = selections.unigrams; auto& essential_bigrams = selections.bigrams; @@ -736,138 +735,104 @@ auto union_lookup(Query const& query, ranges::views::set_difference(term_ids, essential_unigrams) | ranges::to_vector; topk.set_threshold(threshold); + auto is_above_threshold = [&](auto score) { return topk.would_enter(score); }; + + auto unigram_cursor = [&]() { + auto cursors = index.max_scored_cursors(gsl::make_span(non_essential_terms), scorer); + ranges::sort(cursors, + [](auto&& lhs, auto&& rhs) { return lhs.max_score() > rhs.max_score(); }); + auto non_essential_count = cursors.size(); + for (auto term : essential_unigrams) { + cursors.push_back(index.max_scored_cursor(term, scorer)); + } - std::array initial_payload{ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; //, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; - - std::vector essential_unigram_cursors; - std::transform(essential_unigrams.begin(), - essential_unigrams.end(), - std::back_inserter(essential_unigram_cursors), - [&](auto term) { return index.scored_cursor(term, scorer); }); + std::vector upper_bounds(cursors.size()); + upper_bounds[0] = cursors[0].max_score(); + for (size_t i = 1; i < cursors.size(); ++i) { + upper_bounds[i] = upper_bounds[i - 1] + cursors[i].max_score(); + } - std::vector unigram_query_positions(essential_unigrams.size()); - for (std::size_t unigram_position = 0; unigram_position < essential_unigrams.size(); - unigram_position += 1) { - 
unigram_query_positions[unigram_position] = - query.sorted_position(essential_unigrams[unigram_position]); - } - auto merged_unigrams = v1::union_merge( - essential_unigram_cursors, initial_payload, [&](auto& acc, auto& cursor, auto term_idx) { - if constexpr (not std::is_void_v) { - analyzer->posting(); - } - acc[unigram_query_positions[term_idx]] = cursor.payload(); - return acc; - }); + // We're not interested in these being correct + std::vector cursor_idx(cursors.size()); + std::iota(cursor_idx.begin(), cursor_idx.end(), 0); + + return join_union_lookup(std::move(cursors), + std::move(cursor_idx), + std::move(upper_bounds), + non_essential_count, + 0.0F, + pisa::v1::accumulate::Add{}, + is_above_threshold, + inspector); + }(); - std::vector essential_bigram_cursors; + using lookup_transform_type = LookupTransform; + using transform_payload_cursor_type = + TransformPayloadCursor; + std::vector bigram_cursors; for (auto [left, right] : essential_bigrams) { auto cursor = index.scored_bigram_cursor(left, right, scorer); if (not cursor) { throw std::runtime_error(fmt::format("Bigram not found: <{}, {}>", left, right)); } - essential_bigram_cursors.push_back(cursor.take().value()); - } + std::vector essential_terms{left, right}; + auto lookup_terms = + ranges::views::set_difference(non_essential_terms, essential_terms) | ranges::to_vector; + + auto lookup_cursors = index.max_scored_cursors(lookup_terms, scorer); + ranges::sort(lookup_cursors, + [](auto&& lhs, auto&& rhs) { return lhs.max_score() > rhs.max_score(); }); + + std::vector upper_bounds(lookup_cursors.size()); + if (not lookup_cursors.empty()) { + upper_bounds[0] = lookup_cursors[0].max_score(); + for (size_t i = 1; i < lookup_cursors.size(); ++i) { + upper_bounds[i] = upper_bounds[i - 1] + lookup_cursors[i].max_score(); + } + } + auto lookup_cursors_upper_bound = std::accumulate( + lookup_cursors.begin(), lookup_cursors.end(), 0.0F, [](auto acc, auto&& cursor) { + return acc + cursor.max_score(); + }); - 
std::vector> bigram_query_positions( - essential_bigrams.size()); - for (std::size_t bigram_position = 0; bigram_position < essential_bigrams.size(); - bigram_position += 1) { - bigram_query_positions[bigram_position] = - std::make_pair(query.sorted_position(essential_bigrams[bigram_position].first), - query.sorted_position(essential_bigrams[bigram_position].second)); + bigram_cursors.emplace_back(std::move(*cursor.take()), + lookup_transform_type(std::move(lookup_cursors), + std::move(upper_bounds), + lookup_cursors_upper_bound, + is_above_threshold, + inspector)); } - auto merged_bigrams = v1::union_merge(std::move(essential_bigram_cursors), - initial_payload, - [&](auto& acc, auto& cursor, auto bigram_idx) { - if constexpr (not std::is_void_v) { - analyzer->posting(); - } - auto payload = cursor.payload(); - auto query_positions = - bigram_query_positions[bigram_idx]; - acc[query_positions.first] = std::get<0>(payload); - acc[query_positions.second] = std::get<1>(payload); - return acc; - }); - auto accumulate = [&](auto& acc, auto& cursor, auto /* union_idx */) { - auto payload = cursor.payload(); - for (auto idx = 0; idx < acc.size(); idx += 1) { - if (acc[idx] == 0.0F) { - acc[idx] = payload[idx]; - } - } - return acc; + auto accumulate = [&](float acc, auto& cursor, [[maybe_unused]] auto idx) { + return std::max(acc, cursor.payload()); }; + auto bigram_cursor = union_merge(std::move(bigram_cursors), 0.0F, accumulate); auto merged = v1::variadic_union_merge( - initial_payload, - std::make_tuple(std::move(merged_unigrams), std::move(merged_bigrams)), + 0.0F, + std::make_tuple(std::move(unigram_cursor), std::move(bigram_cursor)), std::make_tuple(accumulate, accumulate)); - auto lookup_cursors = [&]() { - std::vector> - lookup_cursors; - auto pos = term_ids.begin(); - for (auto non_essential_term : non_essential_terms) { - pos = std::find(pos, term_ids.end(), non_essential_term); - assert(pos != term_ids.end()); - auto idx = std::distance(term_ids.begin(), pos); 
- lookup_cursors.emplace_back(idx, index.max_scored_cursor(non_essential_term, scorer)); - } - return lookup_cursors; - }(); - std::sort(lookup_cursors.begin(), lookup_cursors.end(), [](auto&& lhs, auto&& rhs) { - return lhs.second.max_score() > rhs.second.max_score(); - }); - auto lookup_cursors_upper_bound = std::accumulate( - lookup_cursors.begin(), lookup_cursors.end(), 0.0F, [](auto acc, auto&& cursor) { - return acc + cursor.second.max_score(); - }); - - v1::for_each(merged, [&](auto& cursor) { - if constexpr (not std::is_void_v) { - analyzer->document(); - } - auto docid = cursor.value(); - auto scores = cursor.payload(); - auto score = std::accumulate(scores.begin(), scores.end(), 0.0F, std::plus{}); - auto upper_bound = score + lookup_cursors_upper_bound; - for (auto& [idx, lookup_cursor] : lookup_cursors) { - if (not topk.would_enter(upper_bound)) { - return; - } - if (scores[idx] == 0) { - lookup_cursor.advance_to_geq(docid); - if constexpr (not std::is_void_v) { - analyzer->lookup(); - } - if (PISA_UNLIKELY(lookup_cursor.value() == docid)) { - auto partial_score = lookup_cursor.payload(); - score += partial_score; - upper_bound += partial_score; - } - } - upper_bound -= lookup_cursor.max_score(); - } - if constexpr (not std::is_void_v) { - if (topk.insert(score, docid)) { - analyzer->insert(); + v1::for_each(merged, [&](auto&& cursor) { + if constexpr (not std::is_void_v) { + if (topk.insert(cursor.payload(), cursor.value())) { + inspector->insert(); } } else { - topk.insert(score, docid); + topk.insert(cursor.payload(), cursor.value()); } }); return topk; } -template -auto two_phase_union_lookup(Query const& query, - Index const& index, - topk_queue topk, - Scorer&& scorer, - Analyzer* analyzer = nullptr) +template +auto union_lookup(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + Inspect* inspect = nullptr) { auto const& term_ids = query.get_term_ids(); if (term_ids.empty()) { @@ -908,8 +873,8 @@ auto 
two_phase_union_lookup(Query const& query, } auto merged_unigrams = v1::union_merge( essential_unigram_cursors, initial_payload, [&](auto& acc, auto& cursor, auto term_idx) { - if constexpr (not std::is_void_v) { - analyzer->posting(); + if constexpr (not std::is_void_v) { + inspect->posting(); } acc[unigram_query_positions[term_idx]] = cursor.payload(); return acc; @@ -935,8 +900,8 @@ auto two_phase_union_lookup(Query const& query, auto merged_bigrams = v1::union_merge(std::move(essential_bigram_cursors), initial_payload, [&](auto& acc, auto& cursor, auto bigram_idx) { - if constexpr (not std::is_void_v) { - analyzer->posting(); + if constexpr (not std::is_void_v) { + inspect->posting(); } auto payload = cursor.payload(); auto query_positions = @@ -946,6 +911,20 @@ auto two_phase_union_lookup(Query const& query, return acc; }); + auto accumulate = [&](auto& acc, auto& cursor, auto /* union_idx */) { + auto payload = cursor.payload(); + for (auto idx = 0; idx < acc.size(); idx += 1) { + if (acc[idx] == 0.0F) { + acc[idx] = payload[idx]; + } + } + return acc; + }; + auto merged = v1::variadic_union_merge( + initial_payload, + std::make_tuple(std::move(merged_unigrams), std::move(merged_bigrams)), + std::make_tuple(accumulate, accumulate)); + auto lookup_cursors = [&]() { std::vector> lookup_cursors; @@ -966,9 +945,9 @@ auto two_phase_union_lookup(Query const& query, return acc + cursor.second.max_score(); }); - auto accumulate_document = [&](auto& cursor) { - if constexpr (not std::is_void_v) { - analyzer->document(); + v1::for_each(merged, [&](auto& cursor) { + if constexpr (not std::is_void_v) { + inspect->document(); } auto docid = cursor.value(); auto scores = cursor.payload(); @@ -980,8 +959,8 @@ auto two_phase_union_lookup(Query const& query, } if (scores[idx] == 0) { lookup_cursor.advance_to_geq(docid); - if constexpr (not std::is_void_v) { - analyzer->lookup(); + if constexpr (not std::is_void_v) { + inspect->lookup(); } if 
(PISA_UNLIKELY(lookup_cursor.value() == docid)) { auto partial_score = lookup_cursor.payload(); @@ -991,14 +970,14 @@ auto two_phase_union_lookup(Query const& query, } upper_bound -= lookup_cursor.max_score(); } - topk.insert(score, docid); - if constexpr (not std::is_void_v) { - analyzer->insert(); + if constexpr (not std::is_void_v) { + if (topk.insert(score, docid)) { + inspect->insert(); + } + } else { + topk.insert(score, docid); } - }; - - v1::for_each(merged_unigrams, accumulate_document); - v1::for_each(merged_bigrams, accumulate_document); + }); return topk; } diff --git a/include/pisa/v1/zip_cursor.hpp b/include/pisa/v1/zip_cursor.hpp index 2fe4e5241..0aa4f98b4 100644 --- a/include/pisa/v1/zip_cursor.hpp +++ b/include/pisa/v1/zip_cursor.hpp @@ -19,8 +19,9 @@ struct ZipCursor { [[nodiscard]] constexpr auto operator*() const -> Value { return value(); } [[nodiscard]] constexpr auto value() const noexcept -> Value { - auto deref = [](auto... cursors) { return std::make_tuple(cursors.value()...); }; - return std::apply(deref, m_cursors); + // auto deref = [](auto... cursors) { return std::make_tuple(cursors.value()...); }; + // return std::apply(deref, m_cursors); + return std::make_tuple(std::get<0>(m_cursors).value(), std::get<1>(m_cursors).value()); } constexpr void advance() { @@ -35,8 +36,8 @@ struct ZipCursor { // TODO: Why generic doesn't work? // auto advance_all = [pos](auto... 
cursors) { (cursors.advance_to_position(pos), ...); }; // std::apply(advance_all, m_cursors); - std::get<0>(m_cursors).advance(); - std::get<1>(m_cursors).advance(); + std::get<0>(m_cursors).advance_to_position(pos); + std::get<1>(m_cursors).advance_to_position(pos); } [[nodiscard]] constexpr auto empty() const noexcept -> bool { diff --git a/script/cw09b.sh b/script/cw09b.sh index e83bdf90a..2040c90cf 100644 --- a/script/cw09b.sh +++ b/script/cw09b.sh @@ -24,18 +24,14 @@ set -x #./bin/score -i "${BASENAME}.yml" -j ${THREADS} # This will produce both quantized scores and max scores (both quantized and not). -# ./bin/bigram-index -i "${BASENAME}.yml" -q ${QUERIES} +#./bin/bigram-index -i "${BASENAME}.yml" -q ${QUERIES} # Filter out queries witout existing terms. -#${PISA_BIN}/filter-queries -i ${BASENAME}.yml -q ${QUERIES} | grep -v "\[warning\]" \ -# > ${FILTERED_QUERIES} +#${PISA_BIN}/filter-queries -i ${BASENAME}.yml -q ${QUERIES} | grep -v "\[warning\]" > ${FILTERED_QUERIES} # Extract thresholds (TODO: estimates) -#${PISA_BIN}/thresholds -t ${TYPE} -i ${INV}.${TYPE} \ -# -w ${INV}.wand -q ${FILTERED_QUERIES} -k ${K} --terms "${FWD}.termlex" --stemmer porter2 \ -# | grep -v "\[warning\]" \ -# > ${THRESHOLDS} -#cut -d: -f1 ${FILTERED_QUERIES} | paste - ${OUTPUT_DIR}/thresholds > ${OUTPUT_DIR}/thresholds.tsv +#${PISA_BIN}/threshold -i ${BASENAME}.yml -q ${FILTERED_QUERIES} -k ${K} \ +# | grep -v "\[warning\]" > ${THRESHOLDS} #cut -d: -f1 ${FILTERED_QUERIES} | paste - ${THRESHOLDS} > ${OUTPUT_DIR}/thresholds.tsv # Extract intersections @@ -44,45 +40,36 @@ set -x # | grep -v "\[warning\]" \ # > ${OUTPUT_DIR}/intersections.tsv -## Select unigrams +# Select unigrams #${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ # --terse --max 1 > ${OUTPUT_DIR}/selections.1 -# -## Select unigrams and bigrams #${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ # --terse --max 2 > 
${OUTPUT_DIR}/selections.2 -# -## Select unigrams and bigrams scaled #${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ # --terse --max 2 --scale 1.5 > ${OUTPUT_DIR}/selections.2.scaled-1.5 -# -# Select unigrams and bigrams scaled -${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ - --terse --max 2 --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 -# +#${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ +# --terse --max 2 --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 #${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ # --terse --max 2 --scale-by-query-len > ${OUTPUT_DIR}/selections.2.scaled-smart # Run benchmarks -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore > ${OUTPUT_DIR}/bench.maxscore -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore \ -# --thresholds ${THRESHOLDS} > ${OUTPUT_DIR}/bench.maxscore-threshold -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup \ -# --thresholds ${THRESHOLDS} > ${OUTPUT_DIR}/bench.maxscore-union-lookup -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm unigram-union-lookup \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.1 > ${OUTPUT_DIR}/bench.unigram-union-lookup -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.1 > ${OUTPUT_DIR}/bench.union-lookup.1 -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/bench.union-lookup.2 -#${PISA_BIN}/query -i "${BASENAME}.yml" -q 
${FILTERED_QUERIES} --benchmark --algorithm two-phase-union-lookup \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/bench.two-phase-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore > ${OUTPUT_DIR}/bench.maxscore +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore \ + --thresholds ${THRESHOLDS} > ${OUTPUT_DIR}/bench.maxscore-threshold +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup \ + --thresholds ${THRESHOLDS} > ${OUTPUT_DIR}/bench.maxscore-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm unigram-union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.1 > ${OUTPUT_DIR}/bench.unigram-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/bench.union-lookup.2 +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm lookup-union \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/bench.lookup-union #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ # --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-1.5 > ${OUTPUT_DIR}/bench.union-lookup.scaled-1.5 #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ # --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-2 > ${OUTPUT_DIR}/bench.union-lookup.scaled-2 -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-3 > ${OUTPUT_DIR}/bench.union-lookup.scaled-3 
+#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-3 > ${OUTPUT_DIR}/bench.union-lookup.scaled-3 #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ # --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-smart > ${OUTPUT_DIR}/bench.union-lookup.scaled-smart @@ -94,22 +81,20 @@ ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algori --thresholds ${THRESHOLDS} > ${OUTPUT_DIR}/stats.maxscore-union-lookup ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm unigram-union-lookup \ --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.1 > ${OUTPUT_DIR}/stats.unigram-union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.1 > ${OUTPUT_DIR}/stats.union-lookup.1 ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/stats.union-lookup.2 -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm two-phase-union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/stats.two-phase-union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-1.5 > ${OUTPUT_DIR}/stats.union-lookup.scaled-1.5 -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-2 > ${OUTPUT_DIR}/stats.union-lookup.scaled-2 -${PISA_BIN}/query -i "${BASENAME}.yml" -q 
${FILTERED_QUERIES} --analyze --algorithm union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-3 > ${OUTPUT_DIR}/stats.union-lookup.scaled-3 -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-smart > ${OUTPUT_DIR}/stats.union-lookup.scaled-smart +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm lookup-union \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/stats.lookup-union +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-1.5 > ${OUTPUT_DIR}/stats.union-lookup.scaled-1.5 +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-2 > ${OUTPUT_DIR}/stats.union-lookup.scaled-2 +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-3 > ${OUTPUT_DIR}/stats.union-lookup.scaled-3 +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-smart > ${OUTPUT_DIR}/stats.union-lookup.scaled-smart -# Evaluate +## Evaluate ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore \ --thresholds ${THRESHOLDS} > "${OUTPUT_DIR}/eval.maxscore-threshold" @@ -117,17 +102,15 @@ ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxsco --thresholds ${THRESHOLDS} > 
"${OUTPUT_DIR}/eval.maxscore-union-lookup" ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm unigram-union-lookup \ --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.1 > "${OUTPUT_DIR}/eval.unigram-union-lookup" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.1 > "${OUTPUT_DIR}/eval.union-lookup.1" ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > "${OUTPUT_DIR}/eval.union-lookup.2" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm two-phase-union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > "${OUTPUT_DIR}/eval.two-phase-union-lookup" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-1.5 > "${OUTPUT_DIR}/eval.union-lookup.scale-1.5" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-2 > "${OUTPUT_DIR}/eval.union-lookup.scale-2" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-3 > "${OUTPUT_DIR}/eval.union-lookup.scale-3" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-smart > "${OUTPUT_DIR}/eval.union-lookup.scale-smart" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm lookup-union \ + --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > "${OUTPUT_DIR}/eval.lookup-union" +#${PISA_BIN}/query -i "${BASENAME}.yml" -q 
${FILTERED_QUERIES} --algorithm union-lookup \ +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-1.5 > "${OUTPUT_DIR}/eval.union-lookup.scale-1.5" +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-2 > "${OUTPUT_DIR}/eval.union-lookup.scale-2" +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-3 > "${OUTPUT_DIR}/eval.union-lookup.scale-3" +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-smart > "${OUTPUT_DIR}/eval.union-lookup.scale-smart" diff --git a/src/v1/index_builder.cpp b/src/v1/index_builder.cpp index a2335f994..69133a474 100644 --- a/src/v1/index_builder.cpp +++ b/src/v1/index_builder.cpp @@ -83,7 +83,7 @@ auto verify_compressed_index(std::string const& input, std::string_view output) std::vector> const& bigrams) -> std::pair { - auto run = scored_index_runner(meta, + auto run = scored_index_runner(std::move(meta), RawReader{}, RawReader{}, BlockedReader<::pisa::simdbp_block, true>{}, @@ -92,10 +92,13 @@ auto verify_compressed_index(std::string const& input, std::string_view output) std::vector> pair_mapping; auto scores_file_0 = fmt::format("{}.bigram_bm25_0", index_basename); auto scores_file_1 = fmt::format("{}.bigram_bm25_1", index_basename); + // auto compound_scores_file = fmt::format("{}.bigram_bm25", index_basename); auto score_offsets_file_0 = fmt::format("{}.bigram_bm25_offsets_0", index_basename); auto score_offsets_file_1 = fmt::format("{}.bigram_bm25_offsets_1", index_basename); + // auto compound_score_offsets_file = fmt::format("{}.bigram_bm25_offsets", index_basename); std::ofstream score_out_0(scores_file_0); std::ofstream score_out_1(scores_file_1); + // 
std::ofstream compound_score_out(compound_scores_file); run([&](auto&& index) { ProgressStatus status(bigrams.size(), @@ -107,6 +110,7 @@ auto verify_compressed_index(std::string const& input, std::string_view output) PostingBuilder score_builder_0(score_writer_type{}); PostingBuilder score_builder_1(score_writer_type{}); + PostingBuilder(score_writer_type{}); score_builder_0.write_header(score_out_0); score_builder_1.write_header(score_out_1); @@ -208,7 +212,7 @@ void build_bigram_index(std::string const& yml, write_span(gsl::make_span(frequency_builder_1.offsets()), frequency_offsets_file_1); std::cerr << " Done.\n"; }); - meta.bigrams = BigramMetadata{ + BigramMetadata bigram_meta{ .documents = {.postings = documents_file, .offsets = document_offsets_file}, .frequencies = {{.postings = frequencies_file_0, .offsets = frequency_offsets_file_0}, {.postings = frequencies_file_1, .offsets = frequency_offsets_file_1}}, @@ -216,8 +220,9 @@ void build_bigram_index(std::string const& yml, .mapping = fmt::format("{}.bigram_mapping", index_basename), .count = pair_mapping.size()}; if (not meta.scores.empty()) { - meta.bigrams->scores.push_back(build_scored_bigram_index(meta, index_basename, bigrams)); + bigram_meta.scores.push_back(build_scored_bigram_index(meta, index_basename, bigrams)); } + meta.bigrams = bigram_meta; std::cerr << "Writing metadata..."; meta.write(yml); diff --git a/src/v1/index_metadata.cpp b/src/v1/index_metadata.cpp index 717b79bd4..dd680d969 100644 --- a/src/v1/index_metadata.cpp +++ b/src/v1/index_metadata.cpp @@ -65,6 +65,13 @@ constexpr char const* QUANTIZED_MAX_SCORES = "quantized_max_scores"; }(), .bigrams = [&]() -> tl::optional { if (config[BIGRAM]) { + std::vector> scores; + if (config[BIGRAM]["scores_0"]) { + scores = {{{.postings = config[BIGRAM]["scores_0"][POSTINGS].as(), + .offsets = config[BIGRAM]["scores_0"][OFFSETS].as()}, + {.postings = config[BIGRAM]["scores_1"][POSTINGS].as(), + .offsets = 
config[BIGRAM]["scores_1"][OFFSETS].as()}}}; + } return BigramMetadata{ .documents = {.postings = config[BIGRAM][DOCUMENTS][POSTINGS].as(), .offsets = config[BIGRAM][DOCUMENTS][OFFSETS].as()}, @@ -73,11 +80,7 @@ constexpr char const* QUANTIZED_MAX_SCORES = "quantized_max_scores"; .offsets = config[BIGRAM]["frequencies_0"][OFFSETS].as()}, {.postings = config[BIGRAM]["frequencies_1"][POSTINGS].as(), .offsets = config[BIGRAM]["frequencies_1"][OFFSETS].as()}}, - .scores = {{{.postings = config[BIGRAM]["scores_0"][POSTINGS].as(), - .offsets = config[BIGRAM]["scores_0"][OFFSETS].as()}, - {.postings = config[BIGRAM]["scores_1"][POSTINGS].as(), - .offsets = - config[BIGRAM]["scores_1"][OFFSETS].as()}}}, + .scores = std::move(scores), .mapping = config[BIGRAM]["mapping"].as(), .count = config[BIGRAM]["count"].as()}; } diff --git a/src/v1/query.cpp b/src/v1/query.cpp index 170574270..30f984b5a 100644 --- a/src/v1/query.cpp +++ b/src/v1/query.cpp @@ -1,5 +1,8 @@ +#include + #include #include +#include #include "v1/query.hpp" @@ -15,6 +18,25 @@ using json = nlohmann::json; return terms; } +[[nodiscard]] auto ListSelection::overlapping() const -> bool +{ + std::unordered_set terms; + for (auto term : unigrams) { + terms.insert(term); + } + for (auto [left, right] : bigrams) { + if (terms.find(left) != terms.end()) { + return true; + } + if (terms.find(right) != terms.end()) { + return true; + } + terms.insert(left); + terms.insert(right); + } + return false; +} + void Query::add_selections(gsl::span const> selections) { m_selections = ListSelection{}; @@ -34,6 +56,8 @@ void Query::add_selections(gsl::span const> selections) } } } + ranges::sort(m_selections->unigrams); + ranges::sort(m_selections->bigrams); } auto Query::resolve_term(std::size_t pos) -> TermId diff --git a/test/v1/index_fixture.hpp b/test/v1/index_fixture.hpp index 753014c98..7269408d3 100644 --- a/test/v1/index_fixture.hpp +++ b/test/v1/index_fixture.hpp @@ -60,8 +60,8 @@ struct IndexFixture { index_basename); 
REQUIRE(errors.empty()); auto yml = fmt::format("{}.yml", index_basename); - v1::build_bigram_index(yml, collect_unique_bigrams(test_queries(), []() {})); v1::score_index(yml, 1); + v1::build_bigram_index(yml, collect_unique_bigrams(test_queries(), []() {})); } [[nodiscard]] auto const& tmpdir() const { return *m_tmpdir; } diff --git a/test/v1/test_v1_queries.cpp b/test/v1/test_v1_queries.cpp index 606983466..8f547588f 100644 --- a/test/v1/test_v1_queries.cpp +++ b/test/v1/test_v1_queries.cpp @@ -132,12 +132,15 @@ TEMPLATE_TEST_CASE("Query", Index, RawCursor>, Index, RawCursor>>::get(); TestType fixture; - auto input_data = GENERATE(table({{"daat_or", false}, - {"maxscore", false}, - {"maxscore", true}, - {"maxscore_union_lookup", true}, - {"unigram_union_lookup", true}, - {"union_lookup", true}})); + auto input_data = GENERATE(table({ + {"daat_or", false}, + {"maxscore", false}, + {"maxscore", true}, + {"maxscore_union_lookup", true}, + {"unigram_union_lookup", true}, + {"union_lookup", true}, + //{"disjoint_union_lookup", true} + })); std::string algorithm = std::get<0>(input_data); bool with_threshold = std::get<1>(input_data); CAPTURE(algorithm); @@ -165,13 +168,19 @@ TEMPLATE_TEST_CASE("Query", } return union_lookup(query, index, topk_queue(10), scorer); } + if (name == "disjoint_union_lookup") { + if (query.get_term_ids().size() > 8 || query.get_selections().overlapping()) { + return maxscore_union_lookup(query, index, topk_queue(10), scorer); + } + return disjoint_union_lookup(query, index, topk_queue(10), scorer); + } std::abort(); }; int idx = 0; auto const intersections = pisa::v1::read_intersections(PISA_SOURCE_DIR "/test/test_data/top10_selections"); for (auto& query : test_queries()) { - if (algorithm == "union_lookup") { + if (algorithm == "union_lookup" || algorithm == "disjoint_union_lookup") { query.selections(gsl::make_span(intersections[idx])); } @@ -195,30 +204,21 @@ TEMPLATE_TEST_CASE("Query", auto que = run_query(algorithm, query, index, 
make_bm25(index)); que.finalize(); results = que.topk(); - results.erase(std::remove_if(results.begin(), - results.end(), - [last_score = results.back().first](auto&& entry) { - return entry.first <= last_score; - }), - results.end()); - std::sort(results.begin(), results.end(), approximate_order); + if (not results.empty()) { + results.erase(std::remove_if(results.begin(), + results.end(), + [last_score = results.back().first](auto&& entry) { + return entry.first <= last_score; + }), + results.end()); + std::sort(results.begin(), results.end(), approximate_order); + } }); return results; }(); expected.resize(on_the_fly.size()); std::sort(expected.begin(), expected.end(), approximate_order); - // if (algorithm == "union_lookup") { - // for (size_t i = 0; i < on_the_fly.size(); ++i) { - // std::cerr << fmt::format("{}, {:f} -- {}, {:f}\n", - // on_the_fly[i].second, - // on_the_fly[i].first, - // expected[i].second, - // expected[i].first); - // } - //} - // std::cerr << '\n'; - for (size_t i = 0; i < on_the_fly.size(); ++i) { REQUIRE(on_the_fly[i].second == expected[i].second); REQUIRE(on_the_fly[i].first == Approx(expected[i].first).epsilon(RELATIVE_ERROR)); diff --git a/test/v1/test_v1_query.cpp b/test/v1/test_v1_query.cpp new file mode 100644 index 000000000..a7c3c06d8 --- /dev/null +++ b/test/v1/test_v1_query.cpp @@ -0,0 +1,23 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include + +#include +#include +#include +#include + +#include "pisa_config.hpp" +#include "v1/query.hpp" + +using pisa::v1::ListSelection; + +TEST_CASE("List selections are overlapping", "[v1][unit]") +{ + REQUIRE(not ListSelection{.unigrams = {0, 1, 2}, .bigrams = {}}.overlapping()); + REQUIRE(not ListSelection{.unigrams = {0, 1, 2}, .bigrams = {{0, 1}, {2, 3}}}.overlapping()); + REQUIRE(ListSelection{.unigrams = {0, 1, 1, 2}, .bigrams = {}}.overlapping()); + REQUIRE(ListSelection{.unigrams = {0, 1, 2}, .bigrams = {{0, 3}}}.overlapping()); + REQUIRE(ListSelection{.unigrams = 
{}, .bigrams = {{0, 1}, {1, 3}}}.overlapping()); +} diff --git a/v1/CMakeLists.txt b/v1/CMakeLists.txt index ebac433fe..f16294273 100644 --- a/v1/CMakeLists.txt +++ b/v1/CMakeLists.txt @@ -15,3 +15,6 @@ target_link_libraries(bigram-index pisa CLI11) add_executable(filter-queries filter_queries.cpp) target_link_libraries(filter-queries pisa CLI11) + +add_executable(threshold threshold.cpp) +target_link_libraries(threshold pisa CLI11) diff --git a/v1/filter_queries.cpp b/v1/filter_queries.cpp index fedb44194..52c775026 100644 --- a/v1/filter_queries.cpp +++ b/v1/filter_queries.cpp @@ -12,7 +12,6 @@ #include "query/queries.hpp" #include "timer.hpp" #include "topk_queue.hpp" -#include "v1/analyze_query.hpp" #include "v1/blocked_cursor.hpp" #include "v1/daat_or.hpp" #include "v1/index_metadata.hpp" @@ -29,21 +28,11 @@ using pisa::resolve_query_parser; using pisa::TermProcessor; using pisa::v1::BlockedReader; using pisa::v1::daat_or; -using pisa::v1::DaatOrAnalyzer; using pisa::v1::index_runner; using pisa::v1::IndexMetadata; -using pisa::v1::ListSelection; -using pisa::v1::maxscore_union_lookup; -using pisa::v1::MaxscoreAnalyzer; -using pisa::v1::MaxscoreUnionLookupAnalyzer; using pisa::v1::Query; -using pisa::v1::QueryAnalyzer; using pisa::v1::RawReader; using pisa::v1::resolve_yml; -using pisa::v1::unigram_union_lookup; -using pisa::v1::UnigramUnionLookupAnalyzer; -using pisa::v1::union_lookup; -using pisa::v1::UnionLookupAnalyzer; using pisa::v1::VoidScorer; int main(int argc, char** argv) diff --git a/v1/query.cpp b/v1/query.cpp index 334158eeb..6d8e7124d 100644 --- a/v1/query.cpp +++ b/v1/query.cpp @@ -13,10 +13,10 @@ #include "query/queries.hpp" #include "timer.hpp" #include "topk_queue.hpp" -#include "v1/analyze_query.hpp" #include "v1/blocked_cursor.hpp" #include "v1/daat_or.hpp" #include "v1/index_metadata.hpp" +#include "v1/inspect_query.hpp" #include "v1/intersection.hpp" #include "v1/maxscore.hpp" #include "v1/query.hpp" @@ -29,22 +29,23 @@ using 
pisa::resolve_query_parser; using pisa::v1::BlockedReader; using pisa::v1::daat_or; -using pisa::v1::DaatOrAnalyzer; +using pisa::v1::DaatOrInspector; using pisa::v1::index_runner; using pisa::v1::IndexMetadata; using pisa::v1::ListSelection; +using pisa::v1::lookup_union; +using pisa::v1::LookupUnionInspector; using pisa::v1::maxscore_union_lookup; -using pisa::v1::MaxscoreAnalyzer; -using pisa::v1::MaxscoreUnionLookupAnalyzer; +using pisa::v1::MaxscoreInspector; +using pisa::v1::MaxscoreUnionLookupInspect; using pisa::v1::Query; -using pisa::v1::QueryAnalyzer; +using pisa::v1::QueryInspector; using pisa::v1::RawReader; using pisa::v1::resolve_yml; -using pisa::v1::TwoPhaseUnionLookupAnalyzer; using pisa::v1::unigram_union_lookup; -using pisa::v1::UnigramUnionLookupAnalyzer; +using pisa::v1::UnigramUnionLookupInspect; using pisa::v1::union_lookup; -using pisa::v1::UnionLookupAnalyzer; +using pisa::v1::UnionLookupInspect; using pisa::v1::VoidScorer; using RetrievalAlgorithm = std::function<::pisa::topk_queue(pisa::v1::Query, ::pisa::topk_queue)>; @@ -92,13 +93,13 @@ auto resolve_algorithm(std::string const& name, Index const& index, Scorer&& sco query, index, std::move(topk), std::forward(scorer)); }); } - if (name == "two-phase-union-lookup") { + if (name == "lookup-union") { return RetrievalAlgorithm([&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { - if (query.get_term_ids().size() > 8) { - return pisa::v1::maxscore_union_lookup( + if (query.selections()->bigrams.empty()) { + return pisa::v1::unigram_union_lookup( query, index, std::move(topk), std::forward(scorer)); } - return pisa::v1::two_phase_union_lookup( + return pisa::v1::lookup_union( query, index, std::move(topk), std::forward(scorer)); }); } @@ -107,28 +108,27 @@ auto resolve_algorithm(std::string const& name, Index const& index, Scorer&& sco } template -auto resolve_analyze(std::string const& name, Index const& index, Scorer&& scorer) -> QueryAnalyzer +auto resolve_inspect(std::string 
const& name, Index const& index, Scorer&& scorer) -> QueryInspector { if (name == "daat_or") { - return QueryAnalyzer(DaatOrAnalyzer(index, std::forward(scorer))); + return QueryInspector(DaatOrInspector(index, std::forward(scorer))); } if (name == "maxscore") { - return QueryAnalyzer(MaxscoreAnalyzer(index, std::forward(scorer))); + return QueryInspector(MaxscoreInspector(index, std::forward(scorer))); } if (name == "maxscore-union-lookup") { - return QueryAnalyzer( - MaxscoreUnionLookupAnalyzer>(index, scorer)); + return QueryInspector( + MaxscoreUnionLookupInspect>(index, scorer)); } if (name == "unigram-union-lookup") { - return QueryAnalyzer( - UnigramUnionLookupAnalyzer>(index, scorer)); + return QueryInspector( + UnigramUnionLookupInspect>(index, scorer)); } if (name == "union-lookup") { - return QueryAnalyzer(UnionLookupAnalyzer>(index, scorer)); + return QueryInspector(UnionLookupInspect>(index, scorer)); } - if (name == "two-phase-union-lookup") { - return QueryAnalyzer( - TwoPhaseUnionLookupAnalyzer>(index, scorer)); + if (name == "lookup-union") { + return QueryInspector(LookupUnionInspector>(index, scorer)); } spdlog::error("Unknown algorithm: {}", name); std::exit(1); @@ -186,15 +186,12 @@ void benchmark(std::vector const& queries, int k, RetrievalAlgorithm retr spdlog::info("95% quantile: {} us", q95); } -void analyze_queries(std::vector const& queries, QueryAnalyzer analyzer) +void inspect_queries(std::vector const& queries, QueryInspector inspect) { - std::vector times(queries.size(), std::numeric_limits::max()); - for (auto run = 0; run < 5; run += 1) { - for (auto query = 0; query < queries.size(); query += 1) { - analyzer(queries[query]); - } + for (auto query = 0; query < queries.size(); query += 1) { + inspect(queries[query]); } - std::move(analyzer).summarize(); + std::move(inspect).summarize(); } int main(int argc, char** argv) @@ -205,13 +202,13 @@ int main(int argc, char** argv) std::string algorithm = "daat_or"; tl::optional 
threshold_file; tl::optional inter_filename; - bool analyze = false; + bool inspect = false; pisa::QueryApp app("Queries a v1 index."); app.add_option("--algorithm", algorithm, "Query retrieval algorithm.", true); app.add_option("--thresholds", threshold_file, "File with (estimated) thresholds.", false); app.add_option("--intersections", inter_filename, "Intersections filename"); - app.add_flag("--analyze", analyze, "Analyze query execution and stats"); + app.add_flag("--inspect", inspect, "Analyze query execution and stats"); CLI11_PARSE(app, argc, argv); try { @@ -290,8 +287,8 @@ int main(int argc, char** argv) run([&](auto&& index) { if (app.is_benchmark) { benchmark(queries, app.k, resolve_algorithm(algorithm, index, VoidScorer{})); - } else if (analyze) { - analyze_queries(queries, resolve_analyze(algorithm, index, VoidScorer{})); + } else if (inspect) { + inspect_queries(queries, resolve_inspect(algorithm, index, VoidScorer{})); } else { evaluate( queries, app.k, docmap, resolve_algorithm(algorithm, index, VoidScorer{})); @@ -307,8 +304,8 @@ int main(int argc, char** argv) with_scorer("bm25", [&](auto scorer) { if (app.is_benchmark) { benchmark(queries, app.k, resolve_algorithm(algorithm, index, scorer)); - } else if (analyze) { - analyze_queries(queries, resolve_analyze(algorithm, index, scorer)); + } else if (inspect) { + inspect_queries(queries, resolve_inspect(algorithm, index, scorer)); } else { evaluate( queries, app.k, docmap, resolve_algorithm(algorithm, index, scorer)); diff --git a/v1/threshold.cpp b/v1/threshold.cpp new file mode 100644 index 000000000..a33fecfa8 --- /dev/null +++ b/v1/threshold.cpp @@ -0,0 +1,108 @@ +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "app.hpp" +#include "io.hpp" +#include "query/queries.hpp" +#include "timer.hpp" +#include "topk_queue.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/daat_or.hpp" +#include "v1/index_metadata.hpp" +#include "v1/query.hpp" +#include 
"v1/raw_cursor.hpp" +#include "v1/scorer/bm25.hpp" +#include "v1/scorer/runner.hpp" +#include "v1/types.hpp" + +using pisa::resolve_query_parser; +using pisa::v1::BlockedReader; +using pisa::v1::index_runner; +using pisa::v1::IndexMetadata; +using pisa::v1::Query; +using pisa::v1::RawReader; +using pisa::v1::resolve_yml; +using pisa::v1::VoidScorer; + +template +void calculate_thresholds(Index&& index, Scorer&& scorer, std::vector const& queries, int k) +{ + for (auto const& query : queries) { + auto results = + pisa::v1::daat_or(query, index, ::pisa::topk_queue(k), std::forward(scorer)); + results.finalize(); + float threshold = 0.0; + if (not results.topk().empty()) { + threshold = results.topk().back().first; + } + std::cout << threshold << '\n'; + } +} + +int main(int argc, char** argv) +{ + spdlog::drop(""); + spdlog::set_default_logger(spdlog::stderr_color_mt("")); + pisa::QueryApp app("Calculates thresholds for a v1 index."); + CLI11_PARSE(app, argc, argv); + + try { + auto meta = IndexMetadata::from_file(resolve_yml(app.yml)); + auto stemmer = + meta.stemmer ? 
std::make_optional(*meta.stemmer) : std::optional{}; + if (meta.term_lexicon) { + app.terms_file = meta.term_lexicon.value(); + } + + auto queries = [&]() { + std::vector<::pisa::Query> queries; + auto parse_query = resolve_query_parser(queries, app.terms_file, std::nullopt, stemmer); + if (app.query_file) { + std::ifstream is(*app.query_file); + pisa::io::for_each_line(is, parse_query); + } else { + pisa::io::for_each_line(std::cin, parse_query); + } + std::vector v1_queries(queries.size()); + std::transform(queries.begin(), queries.end(), v1_queries.begin(), [&](auto&& parsed) { + Query query(parsed.terms); + if (parsed.id) { + query.id(*parsed.id); + } + query.k(app.k); + return query; + }); + return v1_queries; + }(); + + if (app.precomputed) { + auto run = scored_index_runner(meta, + RawReader{}, + RawReader{}, + BlockedReader<::pisa::simdbp_block, true>{}, + BlockedReader<::pisa::simdbp_block, false>{}); + run([&](auto&& index) { calculate_thresholds(index, VoidScorer{}, queries, app.k); }); + } else { + auto run = index_runner(meta, + RawReader{}, + BlockedReader<::pisa::simdbp_block, true>{}, + BlockedReader<::pisa::simdbp_block, false>{}); + run([&](auto&& index) { + auto with_scorer = scorer_runner(index, make_bm25(index)); + with_scorer("bm25", [&](auto scorer) { + calculate_thresholds(index, scorer, queries, app.k); + }); + }); + } + } catch (std::exception const& error) { + spdlog::error("{}", error.what()); + } + return 0; +} From f44507a7b4af8874925cab535d0eddcca0ca4613 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 6 Dec 2019 13:38:53 +0000 Subject: [PATCH 31/56] Union-lookup cleanup --- include/pisa/v1/cursor_accumulator.hpp | 2 +- include/pisa/v1/maxscore.hpp | 28 +- include/pisa/v1/query.hpp | 2 - include/pisa/v1/union_lookup.hpp | 792 +++++++++---------------- script/cw09b.sh | 22 +- src/v1/query.cpp | 19 - test/v1/test_v1_maxscore_join.cpp | 70 +-- test/v1/test_v1_queries.cpp | 95 +-- test/v1/test_v1_query.cpp | 23 - 
v1/union_lookup.cpp | 262 -------- 10 files changed, 343 insertions(+), 972 deletions(-) delete mode 100644 test/v1/test_v1_query.cpp delete mode 100644 v1/union_lookup.cpp diff --git a/include/pisa/v1/cursor_accumulator.hpp b/include/pisa/v1/cursor_accumulator.hpp index 760cc92c8..bde0ba447 100644 --- a/include/pisa/v1/cursor_accumulator.hpp +++ b/include/pisa/v1/cursor_accumulator.hpp @@ -7,7 +7,7 @@ namespace pisa::v1::accumulate { struct Add { template - auto operator()(Score&& score, Cursor&& cursor, std::size_t /* term_idx */) + auto operator()(Score&& score, Cursor&& cursor) { score += cursor.payload(); return score; diff --git a/include/pisa/v1/maxscore.hpp b/include/pisa/v1/maxscore.hpp index 80531e886..5b69b584f 100644 --- a/include/pisa/v1/maxscore.hpp +++ b/include/pisa/v1/maxscore.hpp @@ -8,6 +8,7 @@ #include #include "v1/algorithm.hpp" +#include "v1/cursor_accumulator.hpp" #include "v1/query.hpp" namespace pisa::v1 { @@ -32,7 +33,6 @@ struct MaxScoreJoin { AccumulateFn accumulate, ThresholdFn above_threshold) : m_cursors(std::move(cursors)), - m_cursor_idx(m_cursors.size()), m_upper_bounds(m_cursors.size()), m_init(std::move(init)), m_accumulate(std::move(accumulate)), @@ -48,7 +48,6 @@ struct MaxScoreJoin { ThresholdFn above_threshold, Inspect* inspect) : m_cursors(std::move(cursors)), - m_cursor_idx(m_cursors.size()), m_upper_bounds(m_cursors.size()), m_init(std::move(init)), m_accumulate(std::move(accumulate)), @@ -65,10 +64,6 @@ struct MaxScoreJoin { m_current_value = sentinel(); m_current_payload = m_init; } - std::iota(m_cursor_idx.begin(), m_cursor_idx.end(), 0); - std::sort(m_cursor_idx.begin(), m_cursor_idx.end(), [this](auto&& lhs, auto&& rhs) { - return m_cursors[lhs].max_score() < m_cursors[rhs].max_score(); - }); std::sort(m_cursors.begin(), m_cursors.end(), [](auto&& lhs, auto&& rhs) { return lhs.max_score() < rhs.max_score(); }); @@ -119,8 +114,7 @@ struct MaxScoreJoin { if constexpr (not std::is_void_v) { m_inspect->posting(); } - 
m_current_payload = - m_accumulate(m_current_payload, cursor, m_cursor_idx[sorted_position]); + m_current_payload = m_accumulate(m_current_payload, cursor); cursor.advance(); } if (auto docid = cursor.value(); docid < m_next_docid) { @@ -141,8 +135,7 @@ struct MaxScoreJoin { m_inspect->lookup(); } if (cursor.value() == m_current_value) { - m_current_payload = - m_accumulate(m_current_payload, cursor, m_cursor_idx[sorted_position]); + m_current_payload = m_accumulate(m_current_payload, cursor); } } } @@ -161,7 +154,6 @@ struct MaxScoreJoin { private: CursorContainer m_cursors; - std::vector m_cursor_idx; std::vector m_upper_bounds; payload_type m_init; AccumulateFn m_accumulate; @@ -214,15 +206,12 @@ auto maxscore(Query const& query, Index const& index, topk_queue topk, Scorer&& using value_type = decltype(index.max_scored_cursor(0, scorer).value()); auto cursors = index.max_scored_cursors(gsl::make_span(term_ids), scorer); - auto accumulate = [](float& score, auto& cursor, auto /* term_position */) { - score += cursor.payload(); - return score; - }; if (query.threshold()) { topk.set_threshold(*query.threshold()); } - auto joined = join_maxscore( - std::move(cursors), 0.0F, accumulate, [&](auto score) { return topk.would_enter(score); }); + auto joined = join_maxscore(std::move(cursors), 0.0F, accumulate::Add{}, [&](auto score) { + return topk.would_enter(score); + }); v1::for_each(joined, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); return topk; } @@ -261,10 +250,7 @@ struct MaxscoreInspector { auto joined = join_maxscore( std::move(cursors), 0.0F, - [&](auto& score, auto& cursor, auto /* term_position */) { - score += cursor.payload(); - return score; - }, + accumulate::Add{}, [&](auto score) { return topk.would_enter(score); }, this); v1::for_each(joined, [&](auto& cursor) { diff --git a/include/pisa/v1/query.hpp b/include/pisa/v1/query.hpp index c8e8564a3..0153b99ee 100644 --- a/include/pisa/v1/query.hpp +++ b/include/pisa/v1/query.hpp 
@@ -22,8 +22,6 @@ namespace pisa::v1 { struct ListSelection { std::vector unigrams{}; std::vector> bigrams{}; - - [[nodiscard]] auto overlapping() const -> bool; }; struct TermIdSet { diff --git a/include/pisa/v1/union_lookup.hpp b/include/pisa/v1/union_lookup.hpp index fb88a6c42..6580bd8b8 100644 --- a/include/pisa/v1/union_lookup.hpp +++ b/include/pisa/v1/union_lookup.hpp @@ -3,8 +3,6 @@ #include #include -#include -#include #include #include "v1/algorithm.hpp" @@ -14,69 +12,102 @@ namespace pisa::v1 { -template struct UnionLookupJoin { - using cursor_type = typename CursorContainer::value_type; + + using essential_cursor_type = typename EssentialCursors::value_type; + using lookup_cursor_type = typename LookupCursors::value_type; + using payload_type = Payload; - using value_type = std::decay_t())>; + using value_type = std::decay_t())>; - using iterator_category = - typename std::iterator_traits::iterator_category; - static_assert(std::is_base_of(), + using essential_iterator_category = + typename std::iterator_traits::iterator_category; + static_assert(std::is_base_of(), "cursors must be stored in a random access container"); - UnionLookupJoin(CursorContainer cursors, - std::vector cursor_idx, - std::vector upper_bounds, - std::size_t non_essential_count, - Payload init, - AccumulateFn accumulate, - ThresholdFn above_threshold) - : m_cursors(std::move(cursors)), - m_cursor_idx(std::move(cursor_idx)), - m_upper_bounds(std::move(upper_bounds)), - m_non_essential_count(non_essential_count), - m_init(std::move(init)), - m_accumulate(std::move(accumulate)), - m_above_threshold(std::move(above_threshold)), - m_size(std::nullopt) - { - initialize(); - } - - UnionLookupJoin(CursorContainer cursors, - std::vector cursor_idx, - std::vector upper_bounds, - std::size_t non_essential_count, + UnionLookupJoin(EssentialCursors essential_cursors, + LookupCursors lookup_cursors, Payload init, AccumulateFn accumulate, ThresholdFn above_threshold, - Inspect* inspect) - : 
m_cursors(std::move(cursors)), - m_cursor_idx(std::move(cursor_idx)), - m_upper_bounds(std::move(upper_bounds)), - m_non_essential_count(non_essential_count), + Inspect* inspect = nullptr) + : m_essential_cursors(std::move(essential_cursors)), + m_lookup_cursors(std::move(lookup_cursors)), m_init(std::move(init)), m_accumulate(std::move(accumulate)), m_above_threshold(std::move(above_threshold)), - m_size(std::nullopt), m_inspect(inspect) { - initialize(); - } - - void initialize() - { - if (m_cursors.empty()) { - m_current_value = sentinel(); + if (m_essential_cursors.empty()) { + m_sentinel = std::numeric_limits::max(); + m_current_value = m_sentinel; m_current_payload = m_init; + return; } - m_next_docid = min_value(m_cursors); - m_sentinel = min_sentinel(m_cursors); + m_lookup_cumulative_upper_bound = std::accumulate( + m_lookup_cursors.begin(), m_lookup_cursors.end(), 0.0F, [](auto acc, auto&& cursor) { + return acc + cursor.max_score(); + }); + m_next_docid = min_value(m_essential_cursors); + m_sentinel = min_sentinel(m_essential_cursors); advance(); } @@ -95,8 +126,7 @@ struct UnionLookupJoin { { bool exit = false; while (not exit) { - if (PISA_UNLIKELY(m_non_essential_count == m_cursors.size() - || m_next_docid >= sentinel())) { + if (PISA_UNLIKELY(m_next_docid >= sentinel())) { m_current_value = sentinel(); m_current_payload = m_init; return; @@ -108,16 +138,12 @@ struct UnionLookupJoin { m_inspect->document(); } - for (auto sorted_position = m_non_essential_count; sorted_position < m_cursors.size(); - sorted_position += 1) { - - auto& cursor = m_cursors[sorted_position]; + for (auto&& cursor : m_essential_cursors) { if (cursor.value() == m_current_value) { if constexpr (not std::is_void_v) { m_inspect->posting(); } - m_current_payload = - m_accumulate(m_current_payload, cursor, m_cursor_idx[sorted_position]); + m_current_payload = m_accumulate(m_current_payload, cursor); cursor.advance(); } if (auto docid = cursor.value(); docid < m_next_docid) { @@ 
-126,145 +152,81 @@ struct UnionLookupJoin { } exit = true; - for (auto sorted_position = m_non_essential_count - 1; sorted_position + 1 > 0; - sorted_position -= 1) { - if (not m_above_threshold(m_current_payload + m_upper_bounds[sorted_position])) { + auto lookup_bound = m_lookup_cumulative_upper_bound; + for (auto&& cursor : m_lookup_cursors) { + if (not m_above_threshold(m_current_payload + lookup_bound)) { exit = false; break; } - auto& cursor = m_cursors[sorted_position]; cursor.advance_to_geq(m_current_value); if constexpr (not std::is_void_v) { m_inspect->lookup(); } if (cursor.value() == m_current_value) { - m_current_payload = - m_accumulate(m_current_payload, cursor, m_cursor_idx[sorted_position]); + m_current_payload = m_accumulate(m_current_payload, cursor); } + lookup_bound -= cursor.max_score(); } } + m_position += 1; } - [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; // TODO(michal) + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t { return m_position; } [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_current_value >= sentinel(); } private: - CursorContainer m_cursors; - std::vector m_cursor_idx; - std::vector m_upper_bounds; - std::size_t m_non_essential_count; + EssentialCursors m_essential_cursors; + LookupCursors m_lookup_cursors; payload_type m_init; AccumulateFn m_accumulate; ThresholdFn m_above_threshold; - std::optional m_size; value_type m_current_value{}; value_type m_sentinel{}; payload_type m_current_payload{}; std::uint32_t m_next_docid{}; payload_type m_previous_threshold{}; + payload_type m_lookup_cumulative_upper_bound{}; + std::size_t m_position = 0; Inspect* m_inspect; }; -template -auto join_union_lookup(CursorContainer cursors, - std::vector cursor_idx, - std::vector upper_bounds, - std::size_t non_essential_count, - Payload init, - AccumulateFn accumulate, - ThresholdFn threshold) -{ - return UnionLookupJoin( - std::move(cursors), - std::move(cursor_idx), - 
std::move(upper_bounds), - non_essential_count, - std::move(init), - std::move(accumulate), - std::move(threshold)); -} - -template -auto join_union_lookup(CursorContainer cursors, - std::vector cursor_idx, - std::vector upper_bounds, - std::size_t non_essential_count, + typename Inspect = void> +auto join_union_lookup(EssentialCursors essential_cursors, + LookupCursors lookup_cursors, Payload init, AccumulateFn accumulate, ThresholdFn threshold, Inspect* inspect) { - return UnionLookupJoin( - std::move(cursors), - std::move(cursor_idx), - std::move(upper_bounds), - non_essential_count, - std::move(init), - std::move(accumulate), - std::move(threshold), - inspect); + return UnionLookupJoin(std::move(essential_cursors), + std::move(lookup_cursors), + std::move(init), + std::move(accumulate), + std::move(threshold), + inspect); } -namespace detail { - template - auto unigram_union_lookup(Cursors cursors, - UpperBounds upper_bounds, - std::size_t non_essential_count, - topk_queue topk, - [[maybe_unused]] Inspect* inspect = nullptr) - { - auto merged_essential = - v1::union_merge(gsl::make_span(cursors).subspan(non_essential_count), - 0.0F, - [&](auto& acc, auto& cursor, auto /*term_idx*/) { - if constexpr (not std::is_void_v) { - inspect->posting(); - } - acc += cursor.payload(); - return acc; - }); - - v1::for_each(merged_essential, [&](auto&& cursor) { - if constexpr (not std::is_void_v) { - inspect->document(); - } - auto docid = cursor.value(); - auto score = cursor.payload(); - for (auto lookup_cursor_idx = non_essential_count - 1; lookup_cursor_idx + 1 > 0; - lookup_cursor_idx -= 1) { - if (not topk.would_enter(score + upper_bounds[lookup_cursor_idx])) { - return; - } - auto& lookup_cursor = cursors[lookup_cursor_idx]; - lookup_cursor.advance_to_geq(docid); - if constexpr (not std::is_void_v) { - inspect->lookup(); - } - if (lookup_cursor.value() == docid) { - score += lookup_cursor.payload(); - } - } - if constexpr (not std::is_void_v) { - if 
(topk.insert(score, docid)) { - inspect->insert(); - } - } else { - topk.insert(score, docid); - } - }); - return topk; - } -} // namespace detail - +/// Processes documents with the Union-Lookup method. +/// +/// This is an optimized version that works **only on single-term posting lists**. +/// It will throw an exception if bigram selections are passed to it. template auto unigram_union_lookup(Query const& query, Index const& index, @@ -272,65 +234,35 @@ auto unigram_union_lookup(Query const& query, Scorer&& scorer, [[maybe_unused]] Inspect* inspect = nullptr) { + using cursor_type = decltype(index.max_scored_cursor(0, scorer)); + using payload_type = decltype(std::declval().payload()); + auto const& term_ids = query.get_term_ids(); if (term_ids.empty()) { return topk; } - if (not query.threshold()) { - throw std::invalid_argument("Must provide threshold to the query"); - } - if (not query.selections()) { - throw std::invalid_argument("Must provide essential list selection"); - } - if (not query.selections()->bigrams.empty()) { - throw std::invalid_argument("This algorithm only supports unigrams"); - } + auto const& selections = query.get_selections(); + ensure(not selections.bigrams.empty(), "This algorithm only supports unigrams"); - topk.set_threshold(*query.threshold()); - - using cursor_type = decltype(index.max_scored_cursor(0, scorer)); - using payload_type = decltype(std::declval().payload()); + topk.set_threshold(query.get_threshold()); auto non_essential_terms = ranges::views::set_difference(term_ids, selections.unigrams) | ranges::to_vector; - std::vector cursors; - for (auto non_essential_term : non_essential_terms) { - cursors.push_back(index.max_scored_cursor(non_essential_term, scorer)); - } - auto non_essential_count = cursors.size(); - std::sort(cursors.begin(), cursors.end(), [](auto&& lhs, auto&& rhs) { - return lhs.max_score() < rhs.max_score(); - }); - for (auto essential_term : selections.unigrams) { - 
cursors.push_back(index.max_scored_cursor(essential_term, scorer)); - } - - std::vector upper_bounds(cursors.size()); - upper_bounds[0] = cursors[0].max_score(); - for (size_t i = 1; i < cursors.size(); ++i) { - upper_bounds[i] = upper_bounds[i - 1] + cursors[i].max_score(); - } - - // We're not interested in these being correct - std::vector cursor_idx(cursors.size()); - std::iota(cursor_idx.begin(), cursor_idx.end(), 0); + std::vector lookup_cursors = index.max_scored_cursors(non_essential_terms, scorer); + ranges::sort(lookup_cursors, [](auto&& l, auto&& r) { return l.max_score() > r.max_score(); }); + std::vector essential_cursors = + index.max_scored_cursors(selections.unigrams, scorer); - auto accumulate = [](auto& score, auto& cursor, auto /* term_position */) { - score += cursor.payload(); - return score; - }; auto joined = join_union_lookup( - std::move(cursors), - std::move(cursor_idx), - std::move(upper_bounds), - non_essential_count, + std::move(essential_cursors), + std::move(lookup_cursors), payload_type{}, - accumulate, + accumulate::Add{}, [&](auto score) { return topk.would_enter(score); }, inspect); - v1::for_each(joined, [&](auto& cursor) { + v1::for_each(joined, [&](auto&& cursor) { if constexpr (not std::is_void_v) { if (topk.insert(cursor.payload(), cursor.value())) { inspect->insert(); @@ -342,6 +274,10 @@ auto unigram_union_lookup(Query const& query, return topk; } +/// This is a special case of Union-Lookup algorithm that does not use user-defined selections, +/// but rather uses the same way of determining essential list as Maxscore does. +/// The difference is that this algorithm will never update the threshold whereas Maxscore will +/// try to improve the estimate after each accumulated document. 
template auto maxscore_union_lookup(Query const& query, Index const& index, @@ -360,14 +296,7 @@ auto maxscore_union_lookup(Query const& query, topk.set_threshold(threshold); auto cursors = index.max_scored_cursors(gsl::make_span(term_ids), scorer); - std::vector cursor_idx(cursors.size()); - std::iota(cursor_idx.begin(), cursor_idx.end(), 0); - std::sort(cursor_idx.begin(), cursor_idx.end(), [&](auto&& lhs, auto&& rhs) { - return cursors[lhs].max_score() < cursors[rhs].max_score(); - }); - std::sort(cursors.begin(), cursors.end(), [](auto&& lhs, auto&& rhs) { - return lhs.max_score() < rhs.max_score(); - }); + ranges::sort(cursors, [](auto&& lhs, auto&& rhs) { return lhs.max_score() < rhs.max_score(); }); std::vector upper_bounds(cursors.size()); upper_bounds[0] = cursors[0].max_score(); @@ -379,17 +308,18 @@ auto maxscore_union_lookup(Query const& query, non_essential_count += 1; } - auto accumulate = [](auto& score, auto& cursor, auto /* term_position */) { - score += cursor.payload(); - return score; - }; + auto lookup_cursors = gsl::span(&cursors[0], non_essential_count); + auto essential_cursors = + gsl::span(&cursors[non_essential_count], cursors.size() - non_essential_count); + if (not lookup_cursors.empty()) { + std::reverse(lookup_cursors.begin(), lookup_cursors.end()); + } + auto joined = join_union_lookup( - std::move(cursors), - std::move(cursor_idx), - std::move(upper_bounds), - non_essential_count, + std::move(essential_cursors), + std::move(lookup_cursors), payload_type{}, - accumulate, + accumulate::Add{}, [&](auto score) { return topk.would_enter(score); }, inspect); v1::for_each(joined, [&](auto& cursor) { @@ -404,257 +334,9 @@ auto maxscore_union_lookup(Query const& query, return topk; } -template -struct BaseUnionLookupInspect { - BaseUnionLookupInspect(Index const& index, Scorer scorer) - : m_index(index), m_scorer(std::move(scorer)) - { - std::cout << fmt::format("documents\tpostings\tinserts\tlookups\n"); - } - - void reset_current() - { - 
m_current_documents = 0; - m_current_postings = 0; - m_current_lookups = 0; - m_current_inserts = 0; - } - - virtual void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) = 0; - - void operator()(Query const& query) - { - auto const& term_ids = query.get_term_ids(); - if (term_ids.empty()) { - return; - } - using cursor_type = decltype(m_index.max_scored_cursor(0, m_scorer)); - using value_type = decltype(m_index.max_scored_cursor(0, m_scorer).value()); - - reset_current(); - run(query, m_index, m_scorer, topk_queue(query.k())); - std::cout << fmt::format("{}\t{}\t{}\t{}\n", - m_current_documents, - m_current_postings, - m_current_inserts, - m_current_lookups); - m_documents += m_current_documents; - m_postings += m_current_postings; - m_lookups += m_current_lookups; - m_inserts += m_current_inserts; - m_count += 1; - } - - void summarize() && - { - std::cerr << fmt::format( - "=== SUMMARY ===\nAverage:\n- documents:\t{}\n" - "- postings:\t{}\n- inserts:\t{}\n- lookups:\t{}\n", - static_cast(m_documents) / m_count, - static_cast(m_postings) / m_count, - static_cast(m_inserts) / m_count, - static_cast(m_lookups) / m_count); - } - - void document() { m_current_documents += 1; } - void posting() { m_current_postings += 1; } - void lookup() { m_current_lookups += 1; } - void insert() { m_current_inserts += 1; } - - private: - std::size_t m_current_documents = 0; - std::size_t m_current_postings = 0; - std::size_t m_current_lookups = 0; - std::size_t m_current_inserts = 0; - - std::size_t m_documents = 0; - std::size_t m_postings = 0; - std::size_t m_lookups = 0; - std::size_t m_inserts = 0; - std::size_t m_count = 0; - Index const& m_index; - Scorer m_scorer; -}; - -template -struct MaxscoreUnionLookupInspect : public BaseUnionLookupInspect { - MaxscoreUnionLookupInspect(Index const& index, Scorer scorer) - : BaseUnionLookupInspect(index, std::move(scorer)) - { - } - void run(Query const& query, Index const& index, Scorer& scorer, topk_queue 
topk) override - { - maxscore_union_lookup(query, index, std::move(topk), scorer, this); - } -}; - -template -struct UnigramUnionLookupInspect : public BaseUnionLookupInspect { - UnigramUnionLookupInspect(Index const& index, Scorer scorer) - : BaseUnionLookupInspect(index, std::move(scorer)) - { - } - void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override - { - unigram_union_lookup(query, index, std::move(topk), scorer, this); - } -}; - -template -struct UnionLookupInspect : public BaseUnionLookupInspect { - UnionLookupInspect(Index const& index, Scorer scorer) - : BaseUnionLookupInspect(index, std::move(scorer)) - { - } - void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override - { - if (query.selections()->bigrams.empty()) { - unigram_union_lookup(query, index, std::move(topk), scorer, this); - } else if (query.get_term_ids().size() > 8) { - maxscore_union_lookup(query, index, std::move(topk), scorer, this); - } else { - union_lookup(query, index, std::move(topk), scorer, this); - } - } -}; - -template -struct LookupUnionInspector : public BaseUnionLookupInspect { - LookupUnionInspector(Index const& index, Scorer scorer) - : BaseUnionLookupInspect(index, std::move(scorer)) - { - } - void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override - { - if (query.selections()->bigrams.empty()) { - unigram_union_lookup(query, index, std::move(topk), scorer, this); - } else { - lookup_union(query, index, std::move(topk), scorer, this); - } - } -}; - -// template -// struct TwoPhaseUnionLookupInspect : public BaseUnionLookupInspect { -// TwoPhaseUnionLookupInspect(Index const& index, Scorer scorer) -// : BaseUnionLookupInspect(index, std::move(scorer)) -// { -// } -// void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override -// { -// // if (query.get_term_ids().size() > 8) { -// maxscore_union_lookup(query, index, std::move(topk), scorer, 
this); -// //} else { -// // two_phase_union_lookup(query, index, std::move(topk), scorer, this); -// //} -// } -//}; - -/// This one assumes that terms do not repeat in essential lists and intersections. -template -auto disjoint_union_lookup(Query const& query, - Index const& index, - topk_queue topk, - Scorer&& scorer, - Inspect* inspect = nullptr) -{ - using bigram_cursor_type = std::decay_t; - - auto const& term_ids = query.get_term_ids(); - if (term_ids.empty()) { - return topk; - } - - auto threshold = query.get_threshold(); - topk.set_threshold(threshold); - - auto const& selections = query.get_selections(); - auto const& essential_unigrams = selections.unigrams; - auto const& essential_bigrams = selections.bigrams; - - auto const non_essential_terms = [&]() { - auto all_essential_terms = essential_unigrams; - for (auto [left, right] : essential_bigrams) { - all_essential_terms.push_back(left); - all_essential_terms.push_back(right); - } - ranges::sort(all_essential_terms); - ranges::actions::unique(all_essential_terms); - return ranges::views::set_difference(term_ids, all_essential_terms) | ranges::to_vector; - }(); - - auto essential_unigram_cursors = - index.scored_cursors(gsl::make_span(essential_unigrams), scorer); - auto merged_unigrams = v1::union_merge( - essential_unigram_cursors, 0.0F, pisa::v1::accumulate::InspectAdd(inspect)); - - std::vector essential_bigram_cursors; - for (auto [left, right] : essential_bigrams) { - auto cursor = index.scored_bigram_cursor(left, right, scorer); - if (not cursor) { - throw std::runtime_error(fmt::format("Bigram not found: <{}, {}>", left, right)); - } - essential_bigram_cursors.push_back(cursor.take().value()); - } - - auto merged_bigrams = - v1::union_merge(std::move(essential_bigram_cursors), - 0.0F, - [&](auto& acc, auto& cursor, [[maybe_unused]] auto bigram_idx) { - if constexpr (not std::is_void_v) { - inspect->posting(); - } - auto payload = cursor.payload(); - acc += std::get<0>(payload) + 
std::get<1>(payload); - return acc; - }); - auto merged = v1::variadic_union_merge( - 0.0F, - std::make_tuple(std::move(merged_unigrams), std::move(merged_bigrams)), - std::make_tuple(pisa::v1::accumulate::Add{}, pisa::v1::accumulate::Add{})); - - auto lookup_cursors = index.max_scored_cursors(non_essential_terms, scorer); - std::sort(lookup_cursors.begin(), lookup_cursors.end(), [](auto&& lhs, auto&& rhs) { - return lhs.max_score() > rhs.max_score(); - }); - auto lookup_cursors_upper_bound = std::accumulate( - lookup_cursors.begin(), lookup_cursors.end(), 0.0F, [](auto acc, auto&& cursor) { - return acc + cursor.max_score(); - }); - - v1::for_each(merged, [&](auto& cursor) { - if constexpr (not std::is_void_v) { - inspect->document(); - } - auto docid = cursor.value(); - auto score = cursor.payload(); - auto upper_bound = score + lookup_cursors_upper_bound; - for (auto& lookup_cursor : lookup_cursors) { - if (not topk.would_enter(upper_bound)) { - return; - } - lookup_cursor.advance_to_geq(docid); - if constexpr (not std::is_void_v) { - inspect->lookup(); - } - if (PISA_UNLIKELY(lookup_cursor.value() == docid)) { - auto partial_score = lookup_cursor.payload(); - score += partial_score; - upper_bound += partial_score; - } - upper_bound -= lookup_cursor.max_score(); - } - if constexpr (not std::is_void_v) { - if (topk.insert(score, docid)) { - inspect->insert(); - } - } else { - topk.insert(score, docid); - } - }); - return topk; -} - +/// This callable transforms a cursor by performing lookups to the current document +/// in the given lookup cursors, and then adding the scores that were found. +/// It uses the same short-circuiting rules before each lookup as `UnionLookupJoin`. 
template lookup_cursors, - std::vector upper_bounds, float lookup_cursors_upper_bound, AboveThresholdFn above_threshold, Inspector* inspect = nullptr) : m_lookup_cursors(std::move(lookup_cursors)), - m_upper_bounds(std::move(upper_bounds)), m_lookup_cursors_upper_bound(lookup_cursors_upper_bound), m_above_threshold(std::move(above_threshold)), m_inspect(inspect) @@ -704,12 +384,12 @@ struct LookupTransform { private: std::vector m_lookup_cursors; - std::vector m_upper_bounds; float m_lookup_cursors_upper_bound; AboveThresholdFn m_above_threshold; Inspector* m_inspect; }; +/// This algorithm... template auto lookup_union(Query const& query, Index const& index, @@ -717,51 +397,36 @@ auto lookup_union(Query const& query, Scorer&& scorer, Inspector* inspector = nullptr) { + using bigram_cursor_type = std::decay_t; + using lookup_cursor_type = std::decay_t; + auto const& term_ids = query.get_term_ids(); if (term_ids.empty()) { return topk; } auto threshold = query.get_threshold(); - auto const& selections = query.get_selections(); - - using bigram_cursor_type = std::decay_t; - using lookup_cursor_type = std::decay_t; + topk.set_threshold(threshold); + auto is_above_threshold = [&](auto score) { return topk.would_enter(score); }; + auto const& selections = query.get_selections(); auto& essential_unigrams = selections.unigrams; auto& essential_bigrams = selections.bigrams; auto non_essential_terms = ranges::views::set_difference(term_ids, essential_unigrams) | ranges::to_vector; - topk.set_threshold(threshold); - auto is_above_threshold = [&](auto score) { return topk.would_enter(score); }; - auto unigram_cursor = [&]() { - auto cursors = index.max_scored_cursors(gsl::make_span(non_essential_terms), scorer); - ranges::sort(cursors, + auto lookup_cursors = index.max_scored_cursors(gsl::make_span(non_essential_terms), scorer); + ranges::sort(lookup_cursors, [](auto&& lhs, auto&& rhs) { return lhs.max_score() > rhs.max_score(); }); - auto non_essential_count = 
cursors.size(); - for (auto term : essential_unigrams) { - cursors.push_back(index.max_scored_cursor(term, scorer)); - } - - std::vector upper_bounds(cursors.size()); - upper_bounds[0] = cursors[0].max_score(); - for (size_t i = 1; i < cursors.size(); ++i) { - upper_bounds[i] = upper_bounds[i - 1] + cursors[i].max_score(); - } - - // We're not interested in these being correct - std::vector cursor_idx(cursors.size()); - std::iota(cursor_idx.begin(), cursor_idx.end(), 0); + auto essential_cursors = + index.max_scored_cursors(gsl::make_span(essential_unigrams), scorer); - return join_union_lookup(std::move(cursors), - std::move(cursor_idx), - std::move(upper_bounds), - non_essential_count, + return join_union_lookup(std::move(essential_cursors), + std::move(lookup_cursors), 0.0F, - pisa::v1::accumulate::Add{}, + accumulate::Add{}, is_above_threshold, inspector); }(); @@ -772,6 +437,7 @@ auto lookup_union(Query const& query, Inspector>; using transform_payload_cursor_type = TransformPayloadCursor; + std::vector bigram_cursors; for (auto [left, right] : essential_bigrams) { auto cursor = index.scored_bigram_cursor(left, right, scorer); @@ -786,13 +452,6 @@ auto lookup_union(Query const& query, ranges::sort(lookup_cursors, [](auto&& lhs, auto&& rhs) { return lhs.max_score() > rhs.max_score(); }); - std::vector upper_bounds(lookup_cursors.size()); - if (not lookup_cursors.empty()) { - upper_bounds[0] = lookup_cursors[0].max_score(); - for (size_t i = 1; i < lookup_cursors.size(); ++i) { - upper_bounds[i] = upper_bounds[i - 1] + lookup_cursors[i].max_score(); - } - } auto lookup_cursors_upper_bound = std::accumulate( lookup_cursors.begin(), lookup_cursors.end(), 0.0F, [](auto acc, auto&& cursor) { return acc + cursor.max_score(); @@ -800,7 +459,6 @@ auto lookup_union(Query const& query, bigram_cursors.emplace_back(std::move(*cursor.take()), lookup_transform_type(std::move(lookup_cursors), - std::move(upper_bounds), lookup_cursors_upper_bound, is_above_threshold, 
inspector)); @@ -981,4 +639,134 @@ auto union_lookup(Query const& query, return topk; } +template +struct BaseUnionLookupInspect { + BaseUnionLookupInspect(Index const& index, Scorer scorer) + : m_index(index), m_scorer(std::move(scorer)) + { + std::cout << fmt::format("documents\tpostings\tinserts\tlookups\n"); + } + + void reset_current() + { + m_current_documents = 0; + m_current_postings = 0; + m_current_lookups = 0; + m_current_inserts = 0; + } + + virtual void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) = 0; + + void operator()(Query const& query) + { + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { + return; + } + using cursor_type = decltype(m_index.max_scored_cursor(0, m_scorer)); + using value_type = decltype(m_index.max_scored_cursor(0, m_scorer).value()); + + reset_current(); + run(query, m_index, m_scorer, topk_queue(query.k())); + std::cout << fmt::format("{}\t{}\t{}\t{}\n", + m_current_documents, + m_current_postings, + m_current_inserts, + m_current_lookups); + m_documents += m_current_documents; + m_postings += m_current_postings; + m_lookups += m_current_lookups; + m_inserts += m_current_inserts; + m_count += 1; + } + + void summarize() && + { + std::cerr << fmt::format( + "=== SUMMARY ===\nAverage:\n- documents:\t{}\n" + "- postings:\t{}\n- inserts:\t{}\n- lookups:\t{}\n", + static_cast(m_documents) / m_count, + static_cast(m_postings) / m_count, + static_cast(m_inserts) / m_count, + static_cast(m_lookups) / m_count); + } + + void document() { m_current_documents += 1; } + void posting() { m_current_postings += 1; } + void lookup() { m_current_lookups += 1; } + void insert() { m_current_inserts += 1; } + + private: + std::size_t m_current_documents = 0; + std::size_t m_current_postings = 0; + std::size_t m_current_lookups = 0; + std::size_t m_current_inserts = 0; + + std::size_t m_documents = 0; + std::size_t m_postings = 0; + std::size_t m_lookups = 0; + std::size_t m_inserts = 0; + 
std::size_t m_count = 0; + Index const& m_index; + Scorer m_scorer; +}; + +template +struct MaxscoreUnionLookupInspect : public BaseUnionLookupInspect { + MaxscoreUnionLookupInspect(Index const& index, Scorer scorer) + : BaseUnionLookupInspect(index, std::move(scorer)) + { + } + void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override + { + maxscore_union_lookup(query, index, std::move(topk), scorer, this); + } +}; + +template +struct UnigramUnionLookupInspect : public BaseUnionLookupInspect { + UnigramUnionLookupInspect(Index const& index, Scorer scorer) + : BaseUnionLookupInspect(index, std::move(scorer)) + { + } + void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override + { + unigram_union_lookup(query, index, std::move(topk), scorer, this); + } +}; + +template +struct UnionLookupInspect : public BaseUnionLookupInspect { + UnionLookupInspect(Index const& index, Scorer scorer) + : BaseUnionLookupInspect(index, std::move(scorer)) + { + } + void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override + { + if (query.selections()->bigrams.empty()) { + unigram_union_lookup(query, index, std::move(topk), scorer, this); + } else if (query.get_term_ids().size() > 8) { + maxscore_union_lookup(query, index, std::move(topk), scorer, this); + } else { + union_lookup(query, index, std::move(topk), scorer, this); + } + } +}; + +template +struct LookupUnionInspector : public BaseUnionLookupInspect { + LookupUnionInspector(Index const& index, Scorer scorer) + : BaseUnionLookupInspect(index, std::move(scorer)) + { + } + void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override + { + if (query.selections()->bigrams.empty()) { + unigram_union_lookup(query, index, std::move(topk), scorer, this); + } else { + lookup_union(query, index, std::move(topk), scorer, this); + } + } +}; + } // namespace pisa::v1 diff --git a/script/cw09b.sh b/script/cw09b.sh 
index 2040c90cf..52323885f 100644 --- a/script/cw09b.sh +++ b/script/cw09b.sh @@ -74,17 +74,17 @@ ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algo # --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-smart > ${OUTPUT_DIR}/bench.union-lookup.scaled-smart # Analyze -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm maxscore \ - --thresholds ${THRESHOLDS} > ${OUTPUT_DIR}/stats.maxscore-thresholds -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm maxscore-union-lookup \ - --thresholds ${THRESHOLDS} > ${OUTPUT_DIR}/stats.maxscore-union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm unigram-union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.1 > ${OUTPUT_DIR}/stats.unigram-union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/stats.union-lookup.2 -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm lookup-union \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/stats.lookup-union +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm maxscore \ +# --thresholds ${THRESHOLDS} > ${OUTPUT_DIR}/stats.maxscore-thresholds +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm maxscore-union-lookup \ +# --thresholds ${THRESHOLDS} > ${OUTPUT_DIR}/stats.maxscore-union-lookup +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm 
unigram-union-lookup \ +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.1 > ${OUTPUT_DIR}/stats.unigram-union-lookup +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/stats.union-lookup.2 +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm lookup-union \ +# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/stats.lookup-union #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ # --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-1.5 > ${OUTPUT_DIR}/stats.union-lookup.scaled-1.5 #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ diff --git a/src/v1/query.cpp b/src/v1/query.cpp index 30f984b5a..3086cbcc2 100644 --- a/src/v1/query.cpp +++ b/src/v1/query.cpp @@ -18,25 +18,6 @@ using json = nlohmann::json; return terms; } -[[nodiscard]] auto ListSelection::overlapping() const -> bool -{ - std::unordered_set terms; - for (auto term : unigrams) { - terms.insert(term); - } - for (auto [left, right] : bigrams) { - if (terms.find(left) != terms.end()) { - return true; - } - if (terms.find(right) != terms.end()) { - return true; - } - terms.insert(left); - terms.insert(right); - } - return false; -} - void Query::add_selections(gsl::span const> selections) { m_selections = ListSelection{}; diff --git a/test/v1/test_v1_maxscore_join.cpp b/test/v1/test_v1_maxscore_join.cpp index cba225954..15e491cb5 100644 --- a/test/v1/test_v1_maxscore_join.cpp +++ b/test/v1/test_v1_maxscore_join.cpp @@ -36,14 +36,11 @@ using pisa::v1::read_sizes; using pisa::v1::TermId; using pisa::v1::accumulate::Add; -TEMPLATE_TEST_CASE("", +TEMPLATE_TEST_CASE("Max score join", "[v1][integration]", (IndexFixture, v1::RawCursor, v1::RawCursor>)) 
-//(IndexFixture, -// v1::BlockedCursor<::pisa::simdbp_block, false>, -// v1::RawCursor>)) { tbb::task_scheduler_init init(1); TestType fixture; @@ -52,27 +49,21 @@ TEMPLATE_TEST_CASE("", { auto index_basename = (fixture.tmpdir().path() / "inv").string(); auto meta = v1::IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); - // auto run_query = [](std::string const& name, auto query, auto&& index, auto scorer) { - // if (name == "daat_or") { - // return daat_or(query, index, topk_queue(10), scorer); - // } - // if (name == "maxscore") { - // return maxscore(query, index, topk_queue(10), scorer); - // } - // std::abort(); - //}; int idx = 0; for (auto& q : test_queries()) { CAPTURE(q.get_term_ids()); CAPTURE(idx++); + auto add = [](auto score, auto&& cursor, [[maybe_unused]] auto idx) { + return score + cursor.payload(); + }; auto run = v1::index_runner(meta, fixture.document_reader(), fixture.frequency_reader()); run([&](auto&& index) { auto union_results = collect(v1::union_merge( index.scored_cursors(gsl::make_span(q.get_term_ids()), make_bm25(index)), 0.0F, - Add{})); + add)); auto maxscore_results = collect(v1::join_maxscore( index.max_scored_cursors(gsl::make_span(q.get_term_ids()), make_bm25(index)), 0.0F, @@ -80,57 +71,6 @@ TEMPLATE_TEST_CASE("", [](auto /* score */) { return true; })); REQUIRE(union_results == maxscore_results); }); - - run([&](auto&& index) { - auto union_results = collect_with_payload(v1::union_merge( - index.scored_cursors(gsl::make_span(q.get_term_ids()), make_bm25(index)), - 0.0F, - Add{})); - union_results.erase(std::remove_if(union_results.begin(), - union_results.end(), - [](auto score) { return score.second <= 5.0F; }), - union_results.end()); - auto maxscore_results = collect_with_payload(v1::join_maxscore( - index.max_scored_cursors(gsl::make_span(q.get_term_ids()), make_bm25(index)), - 0.0F, - Add{}, - [](auto score) { return score > 5.0F; })); - REQUIRE(union_results.size() == maxscore_results.size()); - for (size_t 
i = 0; i < union_results.size(); ++i) { - CAPTURE(i); - REQUIRE(union_results[i].first == union_results[i].first); - REQUIRE(union_results[i].second - == Approx(union_results[i].second).epsilon(0.01)); - // REQUIRE(precomputed[i].second == expected[i].second); - // REQUIRE(precomputed[i].first == - // Approx(expected[i].first).epsilon(RELATIVE_ERROR)); - } - }); - - // // auto precomputed = [&]() { - // // auto run = - // // v1::scored_index_runner(meta, fixture.document_reader(), - // // fixture.score_reader()); - // // std::vector results; - // // run([&](auto&& index) { - // // // auto que = run_query(algorithm, v1::Query{q.terms}, index, - // // v1::VoidScorer{}); auto que = daat_or(v1::Query{q.terms}, index, - // // topk_queue(10),v1::VoidScorer{}); que.finalize(); results = que.topk(); - // // std::sort(results.begin(), results.end(), std::greater{}); - // // }); - // // return results; - // // }(); - - // REQUIRE(expected.size() == on_the_fly.size()); - // // REQUIRE(expected.size() == precomputed.size()); - // for (size_t i = 0; i < on_the_fly.size(); ++i) { - // REQUIRE(on_the_fly[i].second == expected[i].second); - // REQUIRE(on_the_fly[i].first == - // Approx(expected[i].first).epsilon(RELATIVE_ERROR)); - // // REQUIRE(precomputed[i].second == expected[i].second); - // // REQUIRE(precomputed[i].first == - // // Approx(expected[i].first).epsilon(RELATIVE_ERROR)); - // } } } } diff --git a/test/v1/test_v1_queries.cpp b/test/v1/test_v1_queries.cpp index 8f547588f..778fee7a3 100644 --- a/test/v1/test_v1_queries.cpp +++ b/test/v1/test_v1_queries.cpp @@ -139,7 +139,7 @@ TEMPLATE_TEST_CASE("Query", {"maxscore_union_lookup", true}, {"unigram_union_lookup", true}, {"union_lookup", true}, - //{"disjoint_union_lookup", true} + {"lookup_union", true}, })); std::string algorithm = std::get<0>(input_data); bool with_threshold = std::get<1>(input_data); @@ -168,11 +168,8 @@ TEMPLATE_TEST_CASE("Query", } return union_lookup(query, index, topk_queue(10), scorer); } - if 
(name == "disjoint_union_lookup") { - if (query.get_term_ids().size() > 8 || query.get_selections().overlapping()) { - return maxscore_union_lookup(query, index, topk_queue(10), scorer); - } - return disjoint_union_lookup(query, index, topk_queue(10), scorer); + if (name == "lookup_union") { + return lookup_union(query, index, topk_queue(10), scorer); } std::abort(); }; @@ -180,7 +177,7 @@ TEMPLATE_TEST_CASE("Query", auto const intersections = pisa::v1::read_intersections(PISA_SOURCE_DIR "/test/test_data/top10_selections"); for (auto& query : test_queries()) { - if (algorithm == "union_lookup" || algorithm == "disjoint_union_lookup") { + if (algorithm == "union_lookup" || algorithm == "lookup_union") { query.selections(gsl::make_span(intersections[idx])); } @@ -226,67 +223,33 @@ TEMPLATE_TEST_CASE("Query", idx += 1; - // auto precomputed = [&]() { - // auto run = - // v1::scored_index_runner(meta, fixture.document_reader(), fixture.score_reader()); - // std::vector results; - // run([&](auto&& index) { - // auto que = run_query(algorithm, v1::Query{q.terms}, index, v1::VoidScorer{}); - // que.finalize(); - // results = que.topk(); - // }); - // // Remove the tail that might be different due to quantization error. - // // Note that `precomputed` will have summed quantized score, while the - // // vector we compare to will have quantized sum---that's why whe remove anything - // // that's withing 2 of the last result. 
- // // auto last_score = results.back().first; - // // results.erase(std::remove_if( - // // results.begin(), - // // results.end(), - // // [last_score](auto&& entry) { return entry.first <= last_score + 3; - // // }), - // // results.end()); - // // results.resize(5); - // // std::sort(results.begin(), results.end(), [](auto&& lhs, auto&& rhs) { - // // return lhs.second < rhs.second; - // //}); - // return results; - //}(); - - // constexpr float max_partial_score = 16.5724F; - // auto quantizer = [&](float score) { - // return static_cast(score * std::numeric_limits::max() - // / max_partial_score); - //}; + auto precomputed = [&]() { + auto run = + v1::scored_index_runner(meta, fixture.document_reader(), fixture.score_reader()); + std::vector results; + run([&](auto&& index) { + auto que = run_query(algorithm, query, index, v1::VoidScorer{}); + que.finalize(); + results = que.topk(); + }); + if (not results.empty()) { + results.erase(std::remove_if(results.begin(), + results.end(), + [last_score = results.back().first](auto&& entry) { + return entry.first <= last_score; + }), + results.end()); + std::sort(results.begin(), results.end(), approximate_order); + } + return results; + }(); - // auto expected_quantized = expected; - // std::sort(expected_quantized.begin(), expected_quantized.end(), [](auto&& lhs, auto&& - // rhs) { - // return lhs.first > rhs.first; - //}); - // for (auto& v : expected_quantized) { - // v.first = quantizer(v.first); - //} + constexpr float max_partial_score = 16.5724F; + auto quantizer = [&](float score) { + return static_cast(score * std::numeric_limits::max() + / max_partial_score); + }; // TODO(michal): test the quantized results - - // expected_quantized.resize(precomputed.size()); - // std::sort(expected_quantized.begin(), expected_quantized.end(), [](auto&& lhs, auto&& - // rhs) { - // return lhs.second < rhs.second; - //}); - - // for (size_t i = 0; i < precomputed.size(); ++i) { - // std::cerr << fmt::format("{}, {:f} -- 
{}, {:f}\n", - // precomputed[i].second, - // precomputed[i].first, - // expected_quantized[i].second, - // expected_quantized[i].first); - //} - - // for (size_t i = 0; i < precomputed.size(); ++i) { - // REQUIRE(std::abs(precomputed[i].first - expected_quantized[i].first) - // <= static_cast(q.terms.size())); - //} } } diff --git a/test/v1/test_v1_query.cpp b/test/v1/test_v1_query.cpp deleted file mode 100644 index a7c3c06d8..000000000 --- a/test/v1/test_v1_query.cpp +++ /dev/null @@ -1,23 +0,0 @@ -#define CATCH_CONFIG_MAIN -#include "catch2/catch.hpp" - -#include - -#include -#include -#include -#include - -#include "pisa_config.hpp" -#include "v1/query.hpp" - -using pisa::v1::ListSelection; - -TEST_CASE("List selections are overlapping", "[v1][unit]") -{ - REQUIRE(not ListSelection{.unigrams = {0, 1, 2}, .bigrams = {}}.overlapping()); - REQUIRE(not ListSelection{.unigrams = {0, 1, 2}, .bigrams = {{0, 1}, {2, 3}}}.overlapping()); - REQUIRE(ListSelection{.unigrams = {0, 1, 1, 2}, .bigrams = {}}.overlapping()); - REQUIRE(ListSelection{.unigrams = {0, 1, 2}, .bigrams = {{0, 3}}}.overlapping()); - REQUIRE(ListSelection{.unigrams = {}, .bigrams = {{0, 1}, {1, 3}}}.overlapping()); -} diff --git a/v1/union_lookup.cpp b/v1/union_lookup.cpp deleted file mode 100644 index e8a507f7d..000000000 --- a/v1/union_lookup.cpp +++ /dev/null @@ -1,262 +0,0 @@ -#include -#include -#include - -#include -#include -#include -#include - -#include "app.hpp" -#include "io.hpp" -#include "query/queries.hpp" -#include "timer.hpp" -#include "topk_queue.hpp" -#include "v1/blocked_cursor.hpp" -#include "v1/index_metadata.hpp" -#include "v1/query.hpp" -#include "v1/raw_cursor.hpp" -#include "v1/scorer/bm25.hpp" -#include "v1/scorer/runner.hpp" -#include "v1/types.hpp" -#include "v1/union_lookup.hpp" - -using pisa::resolve_query_parser; -using pisa::v1::BlockedReader; -using pisa::v1::index_runner; -using pisa::v1::IndexMetadata; -using pisa::v1::Query; -using pisa::v1::RawReader; -using 
pisa::v1::resolve_yml; -using pisa::v1::union_lookup; -using pisa::v1::VoidScorer; - -template -void evaluate(std::vector const& queries, - Index&& index, - Scorer&& scorer, - int k, - pisa::Payload_Vector<> const& docmap, - std::vector> essential_unigrams, - std::vector>> essential_bigrams) -{ - auto query_idx = 0; - for (auto const& query : queries) { - std::vector uni(query.terms.size()); - std::iota(uni.begin(), uni.end(), 0); - auto que = union_lookup(query, - index, - pisa::topk_queue(k), - scorer, - // uni, {}); - essential_unigrams[query_idx], - essential_bigrams[query_idx]); - que.finalize(); - auto rank = 0; - for (auto result : que.topk()) { - std::cout << fmt::format("{}\t{}\t{}\t{}\t{}\t{}\n", - query.id.value_or(std::to_string(query_idx)), - "Q0", - docmap[result.second], - rank, - result.first, - "R0"); - rank += 1; - } - query_idx += 1; - } -} - -template -void benchmark(std::vector const& queries, - Index&& index, - Scorer&& scorer, - int k, - std::vector> essential_unigrams, - std::vector>> essential_bigrams) - -{ - std::vector times(queries.size(), std::numeric_limits::max()); - for (auto run = 0; run < 5; run += 1) { - for (auto query = 0; query < queries.size(); query += 1) { - std::vector uni(queries[query].terms.size()); - std::iota(uni.begin(), uni.end(), 0); - auto usecs = ::pisa::run_with_timer([&]() { - auto que = union_lookup(queries[query], - index, - pisa::topk_queue(k), - scorer, - // uni, - //{}); - essential_unigrams[query], - essential_bigrams[query]); - que.finalize(); - do_not_optimize_away(que); - }); - times[query] = std::min(times[query], static_cast(usecs.count())); - } - } - std::sort(times.begin(), times.end()); - double avg = std::accumulate(times.begin(), times.end(), double()) / times.size(); - double q50 = times[times.size() / 2]; - double q90 = times[90 * times.size() / 100]; - double q95 = times[95 * times.size() / 100]; - spdlog::info("Mean: {}", avg); - spdlog::info("50% quantile: {}", q50); - spdlog::info("90% 
quantile: {}", q90); - spdlog::info("95% quantile: {}", q95); -} - -int main(int argc, char** argv) -{ - std::string inter_filename; - std::string threshold_file; - - pisa::QueryApp app("Queries a v1 index."); - app.add_option("--intersections", inter_filename, "Intersections filename")->required(); - app.add_option("--thresholds", threshold_file, "File with (estimated) thresholds.")->required(); - CLI11_PARSE(app, argc, argv); - - auto meta = IndexMetadata::from_file(resolve_yml(app.yml)); - auto stemmer = meta.stemmer ? std::make_optional(*meta.stemmer) : std::optional{}; - if (meta.term_lexicon) { - app.terms_file = meta.term_lexicon.value(); - } - if (meta.document_lexicon) { - app.documents_file = meta.document_lexicon.value(); - } - - auto queries = [&]() { - std::vector<::pisa::Query> queries; - auto parse_query = resolve_query_parser(queries, app.terms_file, std::nullopt, stemmer); - if (app.query_file) { - std::ifstream is(*app.query_file); - pisa::io::for_each_line(is, parse_query); - } else { - pisa::io::for_each_line(std::cin, parse_query); - } - std::vector v1_queries(queries.size()); - std::transform(queries.begin(), queries.end(), v1_queries.begin(), [&](auto&& query) { - return Query{.terms = query.terms, - .list_selection = {}, - .threshold = {}, - .id = - [&]() { - if (query.id) { - return tl::make_optional(*query.id); - } - return tl::optional{}; - }(), - .k = app.k}; - }); - return v1_queries; - }(); - - std::ifstream is(threshold_file); - auto queries_iter = queries.begin(); - pisa::io::for_each_line(is, [&](auto&& line) { - if (queries_iter == queries.end()) { - spdlog::error("Number of thresholds not equal to number of queries"); - std::exit(1); - } - queries_iter->threshold = tl::make_optional(std::stof(line)); - ++queries_iter; - }); - if (queries_iter != queries.end()) { - spdlog::error("Number of thresholds not equal to number of queries"); - std::exit(1); - } - - auto intersections = [&]() { - std::vector>> intersections; - std::ifstream 
is(inter_filename); - pisa::io::for_each_line(is, [&](auto const& query_line) { - intersections.emplace_back(); - std::istringstream iss(query_line); - std::transform( - std::istream_iterator(iss), - std::istream_iterator(), - std::back_inserter(intersections.back()), - [&](auto const& n) { - auto bits = std::bitset<64>(std::stoul(n)); - if (bits.count() > 2) { - spdlog::error("Intersections of more than 2 terms not supported yet!"); - std::exit(1); - } - return bits; - }); - }); - return intersections; - }(); - auto bitset_to_vec = [](auto bits) { - std::vector vec; - for (auto idx = 0; idx < bits.size(); idx += 1) { - if (bits.test(idx)) { - vec.push_back(idx); - } - } - return vec; - }; - auto is_n_gram = [](auto n) { return [n](auto bits) { return bits.count() == n; }; }; - std::vector> unigrams = - intersections | ranges::views::transform([&](auto&& query_intersections) { - return query_intersections | ranges::views::filter(is_n_gram(1)) - | ranges::views::transform([&](auto bits) { return bitset_to_vec(bits)[0]; }) - | ranges::to_vector; - }) - | ranges::to_vector; - std::vector>> bigrams = - intersections | ranges::views::transform([&](auto&& query_intersections) { - return query_intersections | ranges::views::filter(is_n_gram(2)) - | ranges::views::transform([&](auto bits) { - auto vec = bitset_to_vec(bits); - return std::make_pair(vec[0], vec[0]); - }) - | ranges::to_vector; - }) - | ranges::to_vector; - - if (intersections.size() != queries.size()) { - spdlog::error("Number of intersections is not equal to number of queries"); - std::exit(1); - } - - if (not app.documents_file) { - spdlog::error("Document lexicon not defined"); - std::exit(1); - } - auto source = std::make_shared(app.documents_file.value().c_str()); - auto docmap = pisa::Payload_Vector<>::from(*source); - - if (app.precomputed) { - std::abort(); - // auto run = scored_index_runner(meta, - // RawReader{}, - // RawReader{}, - // BlockedReader<::pisa::simdbp_block, true>{}, - // 
BlockedReader<::pisa::simdbp_block, false>{}); - // run([&](auto&& index) { - // if (app.is_benchmark) { - // benchmark(queries, index, VoidScorer{}, app.k, unigrams, bigrams); - // } else { - // evaluate(queries, index, VoidScorer{}, app.k, docmap, unigrams, bigrams); - // } - //}); - } else { - auto run = index_runner(meta, - RawReader{}, - BlockedReader<::pisa::simdbp_block, true>{}, - BlockedReader<::pisa::simdbp_block, false>{}); - run([&](auto&& index) { - auto with_scorer = scorer_runner(index, make_bm25(index)); - with_scorer("bm25", [&](auto scorer) { - if (app.is_benchmark) { - benchmark(queries, index, scorer, app.k, unigrams, bigrams); - } else { - evaluate(queries, index, scorer, app.k, docmap, unigrams, bigrams); - } - }); - }); - } - return 0; -} From 547389cd190c3568ae00dea519fed0d73c913aaf Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 6 Dec 2019 19:50:14 +0000 Subject: [PATCH 32/56] Update porter2 --- external/Porter2 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/Porter2 b/external/Porter2 index d9a7b8297..ac4f2021c 160000 --- a/external/Porter2 +++ b/external/Porter2 @@ -1 +1 @@ -Subproject commit d9a7b8297be4f026f73dc5b57584e90122702fba +Subproject commit ac4f2021cf34638595313a82acaab2d56b24535d From 9c61991b946c5a6fa931afa33d0068cddbff9dbd Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Tue, 17 Dec 2019 22:46:29 +0000 Subject: [PATCH 33/56] JSON list queries and improved CLI --- include/pisa/codec/integer_codes.hpp | 80 +++++---- .../pisa/cursor/block_max_scored_cursor.hpp | 41 ++--- include/pisa/cursor/max_scored_cursor.hpp | 33 ++-- include/pisa/query/queries.hpp | 56 +++---- include/pisa/query/term_processor.hpp | 12 +- include/pisa/v1/index_builder.hpp | 4 +- include/pisa/v1/index_metadata.hpp | 9 +- include/pisa/v1/query.hpp | 9 ++ include/pisa/v1/score_index.hpp | 4 +- src/queries.cpp | 113 +++++++------ src/thresholds.cpp | 1 + src/v1/index_builder.cpp | 10 +- src/v1/index_metadata.cpp | 
33 +++- src/v1/query.cpp | 84 +++++++++- src/v1/score_index.cpp | 10 +- test/v1/index_fixture.hpp | 4 +- test/v1/test_v1_queries.cpp | 1 + test/v1/test_v1_query.cpp | 41 +++++ v1/app.hpp | 152 +++++++++++++++--- v1/bigram_index.cpp | 49 +----- v1/filter_queries.cpp | 11 +- v1/postings.cpp | 14 +- v1/query.cpp | 60 ++----- v1/score.cpp | 14 +- v1/threshold.cpp | 44 ++--- 25 files changed, 533 insertions(+), 356 deletions(-) create mode 100644 test/v1/test_v1_query.cpp diff --git a/include/pisa/codec/integer_codes.hpp b/include/pisa/codec/integer_codes.hpp index 055ddf259..4bdacb288 100644 --- a/include/pisa/codec/integer_codes.hpp +++ b/include/pisa/codec/integer_codes.hpp @@ -4,45 +4,43 @@ namespace pisa { - // note: n can be 0 - void write_gamma(bit_vector_builder& bvb, uint64_t n) - { - uint64_t nn = n + 1; - uint64_t l = broadword::msb(nn); - uint64_t hb = uint64_t(1) << l; - bvb.append_bits(hb, l + 1); - bvb.append_bits(nn ^ hb, l); - } - - void write_gamma_nonzero(bit_vector_builder& bvb, uint64_t n) - { - assert(n > 0); - write_gamma(bvb, n - 1); - } - - uint64_t read_gamma(bit_vector::enumerator& it) - { - uint64_t l = it.skip_zeros(); - return (it.take(l) | (uint64_t(1) << l)) - 1; - } - - uint64_t read_gamma_nonzero(bit_vector::enumerator& it) - { - return read_gamma(it) + 1; - } - - void write_delta(bit_vector_builder& bvb, uint64_t n) - { - uint64_t nn = n + 1; - uint64_t l = broadword::msb(nn); - uint64_t hb = uint64_t(1) << l; - write_gamma(bvb, l); - bvb.append_bits(nn ^ hb, l); - } - - uint64_t read_delta(bit_vector::enumerator& it) - { - uint64_t l = read_gamma(it); - return (it.take(l) | (uint64_t(1) << l)) - 1; - } +// note: n can be 0 +inline void write_gamma(bit_vector_builder& bvb, uint64_t n) +{ + uint64_t nn = n + 1; + uint64_t l = broadword::msb(nn); + uint64_t hb = uint64_t(1) << l; + bvb.append_bits(hb, l + 1); + bvb.append_bits(nn ^ hb, l); } + +inline void write_gamma_nonzero(bit_vector_builder& bvb, uint64_t n) +{ + assert(n > 0); + 
write_gamma(bvb, n - 1); +} + +inline uint64_t read_gamma(bit_vector::enumerator& it) +{ + uint64_t l = it.skip_zeros(); + return (it.take(l) | (uint64_t(1) << l)) - 1; +} + +inline uint64_t read_gamma_nonzero(bit_vector::enumerator& it) { return read_gamma(it) + 1; } + +inline void write_delta(bit_vector_builder& bvb, uint64_t n) +{ + uint64_t nn = n + 1; + uint64_t l = broadword::msb(nn); + uint64_t hb = uint64_t(1) << l; + write_gamma(bvb, l); + bvb.append_bits(nn ^ hb, l); +} + +inline uint64_t read_delta(bit_vector::enumerator& it) +{ + uint64_t l = read_gamma(it); + return (it.take(l) | (uint64_t(1) << l)) - 1; +} + +} // namespace pisa diff --git a/include/pisa/cursor/block_max_scored_cursor.hpp b/include/pisa/cursor/block_max_scored_cursor.hpp index 0ab76e51e..59875fdbd 100644 --- a/include/pisa/cursor/block_max_scored_cursor.hpp +++ b/include/pisa/cursor/block_max_scored_cursor.hpp @@ -1,40 +1,45 @@ #pragma once -#include -#include "wand_data.hpp" #include "query/queries.hpp" #include "scorer/bm25.hpp" +#include "scorer/score_function.hpp" +#include "wand_data.hpp" +#include namespace pisa { template struct block_max_scored_cursor { - using enum_type = typename Index::document_enumerator; + using enum_type = typename Index::document_enumerator; using wdata_enum = typename WandType::wand_data_enumerator; - enum_type docs_enum; + enum_type docs_enum; wdata_enum w; - float q_weight; - Scorer scorer; - float max_weight; + float q_weight; + Scorer scorer; + float max_weight; }; template -[[nodiscard]] auto make_block_max_scored_cursors(Index const &index, WandType const &wdata, - Query query) { +[[nodiscard]] auto make_block_max_scored_cursors(Index const& index, + WandType const& wdata, + Query query) +{ auto terms = query.terms; auto query_term_freqs = query_freqs(terms); - using scorer_type = bm25; - using Scorer = Score_Function; + using scorer_type = bm25; + using Scorer = Score_Function; std::vector> cursors; cursors.reserve(query_term_freqs.size()); - 
std::transform(query_term_freqs.begin(), query_term_freqs.end(), std::back_inserter(cursors), - [&](auto &&term) { - auto list = index[term.first]; - auto w_enum = wdata.getenum(term.first); - auto q_weight = scorer_type::query_term_weight(term.second, wdata.term_len(term.first), - index.num_docs()); + std::transform(query_term_freqs.begin(), + query_term_freqs.end(), + std::back_inserter(cursors), + [&](auto&& term) { + auto list = index[term.first]; + auto w_enum = wdata.getenum(term.first); + auto q_weight = scorer_type::query_term_weight( + term.second, wdata.term_len(term.first), index.num_docs()); auto max_weight = q_weight * wdata.max_term_weight(term.first); return block_max_scored_cursor{ std::move(list), w_enum, q_weight, {q_weight, wdata}, max_weight}; @@ -42,4 +47,4 @@ template return cursors; } -} // namespace pisa \ No newline at end of file +} // namespace pisa diff --git a/include/pisa/cursor/max_scored_cursor.hpp b/include/pisa/cursor/max_scored_cursor.hpp index b4c15ed31..2faf19553 100644 --- a/include/pisa/cursor/max_scored_cursor.hpp +++ b/include/pisa/cursor/max_scored_cursor.hpp @@ -1,9 +1,10 @@ #pragma once -#include -#include "wand_data.hpp" #include "query/queries.hpp" #include "scorer/bm25.hpp" +#include "scorer/score_function.hpp" +#include "wand_data.hpp" +#include namespace pisa { @@ -11,26 +12,28 @@ template struct max_scored_cursor { using enum_type = typename Index::document_enumerator; enum_type docs_enum; - float q_weight; - Scorer scorer; - float max_weight; + float q_weight; + Scorer scorer; + float max_weight; }; template -[[nodiscard]] auto make_max_scored_cursors(Index const &index, WandType const &wdata, - Query query) { +[[nodiscard]] auto make_max_scored_cursors(Index const& index, WandType const& wdata, Query query) +{ auto terms = query.terms; auto query_term_freqs = query_freqs(terms); - using scorer_type = bm25; - using Scorer = Score_Function; + using scorer_type = bm25; + using Scorer = Score_Function; std::vector> 
cursors; cursors.reserve(query_term_freqs.size()); - std::transform(query_term_freqs.begin(), query_term_freqs.end(), std::back_inserter(cursors), - [&](auto &&term) { - auto list = index[term.first]; - auto q_weight = scorer_type::query_term_weight(term.second, wdata.term_len(term.first), - index.num_docs()); + std::transform(query_term_freqs.begin(), + query_term_freqs.end(), + std::back_inserter(cursors), + [&](auto&& term) { + auto list = index[term.first]; + auto q_weight = scorer_type::query_term_weight( + term.second, wdata.term_len(term.first), index.num_docs()); auto max_weight = q_weight * wdata.max_term_weight(term.first); return max_scored_cursor{ std::move(list), q_weight, {q_weight, wdata}, max_weight}; @@ -38,4 +41,4 @@ template return cursors; } -} // namespace pisa +} // namespace pisa diff --git a/include/pisa/query/queries.hpp b/include/pisa/query/queries.hpp index ed580d524..2dfd43793 100644 --- a/include/pisa/query/queries.hpp +++ b/include/pisa/query/queries.hpp @@ -11,20 +11,14 @@ #include #include -#include "index_types.hpp" #include "query/query.hpp" -#include "scorer/score_function.hpp" #include "term_processor.hpp" #include "tokenizer.hpp" -#include "topk_queue.hpp" #include "util/util.hpp" -#include "wand_data.hpp" -#include "wand_data_compressed.hpp" -#include "wand_data_raw.hpp" namespace pisa { -[[nodiscard]] auto split_query_at_colon(std::string const &query_string) +[[nodiscard]] inline auto split_query_at_colon(std::string const& query_string) -> std::pair, std::string_view> { // query id : terms (or ids) @@ -38,8 +32,8 @@ namespace pisa { return {std::move(id), std::move(raw_query)}; } -[[nodiscard]] auto parse_query_terms(std::string const &query_string, TermProcessor term_processor) - -> Query +[[nodiscard]] inline auto parse_query_terms(std::string const& query_string, + TermProcessor term_processor) -> Query { auto [id, raw_query] = split_query_at_colon(query_string); TermTokenizer tokenizer(raw_query); @@ -51,56 +45,56 @@ 
namespace pisa { if (!term_processor.is_stopword(*term)) { parsed_query.push_back(std::move(*term)); } else { - //spdlog::warn("Term `{}` is a stopword and will be ignored", raw_term); + // spdlog::warn("Term `{}` is a stopword and will be ignored", raw_term); } } else { - //spdlog::warn("Term `{}` not found and will be ignored", raw_term); + // spdlog::warn("Term `{}` not found and will be ignored", raw_term); } } return {std::move(id), std::move(parsed_query), {}}; } -[[nodiscard]] auto parse_query_ids(std::string const &query_string) -> Query +[[nodiscard]] inline auto parse_query_ids(std::string const& query_string) -> Query { auto [id, raw_query] = split_query_at_colon(query_string); std::vector parsed_query; std::vector term_ids; boost::split(term_ids, raw_query, boost::is_any_of("\t, ,\v,\f,\r,\n")); - auto is_empty = [](const std::string &val) { return val.empty(); }; + auto is_empty = [](const std::string& val) { return val.empty(); }; // remove_if move matching elements to the end, preparing them for erase. 
term_ids.erase(std::remove_if(term_ids.begin(), term_ids.end(), is_empty), term_ids.end()); try { - auto to_int = [](const std::string &val) { return std::stoi(val); }; + auto to_int = [](const std::string& val) { return std::stoi(val); }; std::transform(term_ids.begin(), term_ids.end(), std::back_inserter(parsed_query), to_int); - } catch (std::invalid_argument &err) { + } catch (std::invalid_argument& err) { spdlog::error("Could not parse term identifiers of query `{}`", raw_query); exit(1); } return {std::move(id), std::move(parsed_query), {}}; } -[[nodiscard]] std::function resolve_query_parser( - std::vector &queries, - std::optional const &terms_file, - std::optional const &stopwords_filename, - std::optional const &stemmer_type) +[[nodiscard]] inline std::function resolve_query_parser( + std::vector& queries, + std::optional const& terms_file, + std::optional const& stopwords_filename, + std::optional const& stemmer_type) { if (terms_file) { auto term_processor = TermProcessor(terms_file, stopwords_filename, stemmer_type); return - [&queries, term_processor = std::move(term_processor)](std::string const &query_line) { + [&queries, term_processor = std::move(term_processor)](std::string const& query_line) { queries.push_back(parse_query_terms(query_line, term_processor)); }; } else { - return [&queries](std::string const &query_line) { + return [&queries](std::string const& query_line) { queries.push_back(parse_query_ids(query_line)); }; } } -bool read_query(term_id_vec &ret, std::istream &is = std::cin) +inline bool read_query(term_id_vec& ret, std::istream& is = std::cin) { ret.clear(); std::string line; @@ -111,7 +105,7 @@ bool read_query(term_id_vec &ret, std::istream &is = std::cin) return true; } -void remove_duplicate_terms(term_id_vec &terms) +inline void remove_duplicate_terms(term_id_vec& terms) { std::sort(terms.begin(), terms.end()); terms.erase(std::unique(terms.begin(), terms.end()), terms.end()); @@ -120,7 +114,7 @@ void 
remove_duplicate_terms(term_id_vec &terms) typedef std::pair term_freq_pair; typedef std::vector term_freq_vec; -term_freq_vec query_freqs(term_id_vec terms) +inline term_freq_vec query_freqs(term_id_vec terms) { term_freq_vec query_term_freqs; std::sort(terms.begin(), terms.end()); @@ -136,15 +130,3 @@ term_freq_vec query_freqs(term_id_vec terms) } } // namespace pisa - -#include "algorithm/and_query.hpp" -#include "algorithm/block_max_maxscore_query.hpp" -#include "algorithm/block_max_ranked_and_query.hpp" -#include "algorithm/block_max_wand_query.hpp" -#include "algorithm/maxscore_query.hpp" -#include "algorithm/or_query.hpp" -#include "algorithm/range_query.hpp" -#include "algorithm/ranked_and_query.hpp" -#include "algorithm/ranked_or_query.hpp" -#include "algorithm/ranked_or_taat_query.hpp" -#include "algorithm/wand_query.hpp" diff --git a/include/pisa/query/term_processor.hpp b/include/pisa/query/term_processor.hpp index 81d995d56..5933e217d 100644 --- a/include/pisa/query/term_processor.hpp +++ b/include/pisa/query/term_processor.hpp @@ -5,7 +5,7 @@ #include #include -#include +#include #include #include @@ -24,9 +24,9 @@ class TermProcessor { std::function(std::string)> _to_id; public: - TermProcessor(std::optional const &terms_file, - std::optional const &stopwords_filename, - std::optional const &stemmer_type) + TermProcessor(std::optional const& terms_file, + std::optional const& stopwords_filename, + std::optional const& stemmer_type) { auto source = std::make_shared(terms_file->c_str()); auto terms = Payload_Vector<>::from(*source); @@ -49,7 +49,7 @@ class TermProcessor { } else if (*stemmer_type == "porter2") { _to_id = [=](auto str) { boost::algorithm::to_lower(str); - stem::Porter2 stemmer{}; + porter2::Stemmer stemmer{}; return to_id(std::move(stemmer.stem(str))); }; } else if (*stemmer_type == "krovetz") { @@ -66,7 +66,7 @@ class TermProcessor { // Loads stopwords. 
if (stopwords_filename) { std::ifstream is(*stopwords_filename); - io::for_each_line(is, [&](auto &&word) { + io::for_each_line(is, [&](auto&& word) { if (auto processed_term = _to_id(std::move(word)); processed_term.has_value()) { stopwords.insert(*processed_term); } diff --git a/include/pisa/v1/index_builder.hpp b/include/pisa/v1/index_builder.hpp index 373ac6635..8d29d9fd2 100644 --- a/include/pisa/v1/index_builder.hpp +++ b/include/pisa/v1/index_builder.hpp @@ -229,7 +229,7 @@ auto collect_unique_bigrams(std::vector const& queries, std::function const& callback) -> std::vector>; -void build_bigram_index(std::string const& yml, - std::vector> const& bigrams); +auto build_bigram_index(IndexMetadata meta, std::vector> const& bigrams) + -> IndexMetadata; } // namespace pisa::v1 diff --git a/include/pisa/v1/index_metadata.hpp b/include/pisa/v1/index_metadata.hpp index 475428b15..baf5efe7f 100644 --- a/include/pisa/v1/index_metadata.hpp +++ b/include/pisa/v1/index_metadata.hpp @@ -7,6 +7,7 @@ #include #include "v1/index.hpp" +#include "v1/query.hpp" #include "v1/source.hpp" #include "v1/types.hpp" @@ -15,7 +16,7 @@ namespace pisa::v1 { /// Return the passed file path if is not `nullopt`. /// Otherwise, look for an `.yml` file in the current directory. /// It will throw if no `.yml` file is found or there are multiple `.yml` files. 
-[[nodiscard]] auto resolve_yml(std::optional const& arg) -> std::string; +[[nodiscard]] auto resolve_yml(tl::optional const& arg) -> std::string; template [[nodiscard]] auto convert_optional(Optional opt) @@ -49,6 +50,7 @@ struct BigramMetadata { }; struct IndexMetadata { + tl::optional basename{}; PostingFilePaths documents; PostingFilePaths frequencies; std::vector scores{}; @@ -61,7 +63,10 @@ struct IndexMetadata { std::map max_scores{}; std::map quantized_max_scores{}; - void write(std::string const& file); + void write(std::string const& file) const; + void update() const; + [[nodiscard]] auto query_parser() const -> std::function; + [[nodiscard]] auto get_basename() const -> std::string const&; [[nodiscard]] static auto from_file(std::string const& file) -> IndexMetadata; }; diff --git a/include/pisa/v1/query.hpp b/include/pisa/v1/query.hpp index 0153b99ee..560150a33 100644 --- a/include/pisa/v1/query.hpp +++ b/include/pisa/v1/query.hpp @@ -78,18 +78,27 @@ struct Query { [[nodiscard]] auto k() const -> int; [[nodiscard]] auto selections() const -> tl::optional; [[nodiscard]] auto threshold() const -> tl::optional; + [[nodiscard]] auto raw() const -> tl::optional; /// Throwing getters [[nodiscard]] auto get_term_ids() const -> std::vector const&; [[nodiscard]] auto get_id() const -> std::string const&; [[nodiscard]] auto get_selections() const -> ListSelection const&; [[nodiscard]] auto get_threshold() const -> float; + [[nodiscard]] auto get_raw() const -> std::string const&; [[nodiscard]] auto sorted_position(TermId term) const -> std::size_t; [[nodiscard]] auto term_at_pos(std::size_t pos) const -> TermId; + template + [[nodiscard]] auto parse(Parser&& parser) + { + parser(*this); + } + void add_selections(gsl::span const> selections); [[nodiscard]] static auto from_json(std::string_view) -> Query; + [[nodiscard]] static auto from_plain(std::string_view) -> Query; private: friend std::ostream& operator<<(std::ostream& os, Query const& query); diff --git 
a/include/pisa/v1/score_index.hpp b/include/pisa/v1/score_index.hpp index bf1eeace6..43f131ae5 100644 --- a/include/pisa/v1/score_index.hpp +++ b/include/pisa/v1/score_index.hpp @@ -2,8 +2,10 @@ #include +#include "v1/index_metadata.hpp" + namespace pisa::v1 { -void score_index(std::string const& yml, std::size_t threads); +auto score_index(IndexMetadata meta, std::size_t threads) -> IndexMetadata; } // namespace pisa::v1 diff --git a/src/queries.cpp b/src/queries.cpp index c7d3d880e..a509909ae 100644 --- a/src/queries.cpp +++ b/src/queries.cpp @@ -8,23 +8,34 @@ #include #include #include -#include #include +#include #include #include "mappable/mapper.hpp" -#include "index_types.hpp" #include "accumulator/lazy_accumulator.hpp" +#include "cursor/block_max_scored_cursor.hpp" +#include "cursor/cursor.hpp" +#include "cursor/max_scored_cursor.hpp" +#include "cursor/scored_cursor.hpp" +#include "index_types.hpp" +#include "query/algorithm/and_query.hpp" +#include "query/algorithm/block_max_maxscore_query.hpp" +#include "query/algorithm/block_max_ranked_and_query.hpp" +#include "query/algorithm/block_max_wand_query.hpp" +#include "query/algorithm/maxscore_query.hpp" +#include "query/algorithm/or_query.hpp" +#include "query/algorithm/ranked_and_query.hpp" +#include "query/algorithm/ranked_or_query.hpp" +#include "query/algorithm/ranked_or_taat_query.hpp" +#include "query/algorithm/wand_query.hpp" #include "query/queries.hpp" #include "timer.hpp" +#include "util/do_not_optimize_away.hpp" #include "util/util.hpp" #include "wand_data_compressed.hpp" #include "wand_data_raw.hpp" -#include "cursor/cursor.hpp" -#include "cursor/scored_cursor.hpp" -#include "cursor/max_scored_cursor.hpp" -#include "cursor/block_max_scored_cursor.hpp" #include "CLI/CLI.hpp" @@ -33,37 +44,37 @@ using ranges::view::enumerate; template void extract_times(Fn fn, - std::vector const &queries, - std::string const &index_type, - std::string const &query_type, + std::vector const& queries, + std::string 
const& index_type, + std::string const& query_type, size_t runs, - std::ostream &os) + std::ostream& os) { std::vector times(runs); - for (auto &&[qid, query] : enumerate(queries)) { + for (auto&& [qid, query] : enumerate(queries)) { do_not_optimize_away(fn(query)); std::generate(times.begin(), times.end(), [&fn, &q = query]() { - return run_with_timer( - [&]() { do_not_optimize_away(fn(q)); }) + return run_with_timer([&]() { do_not_optimize_away(fn(q)); }) .count(); }); - auto mean = std::accumulate(times.begin(), times.end(), std::size_t{0}, std::plus<>()) / runs; + auto mean = + std::accumulate(times.begin(), times.end(), std::size_t{0}, std::plus<>()) / runs; os << fmt::format("{}\t{}\n", query.id.value_or(std::to_string(qid)), mean); } } template void op_perftest(Functor query_func, - std::vector const &queries, - std::string const &index_type, - std::string const &query_type, + std::vector const& queries, + std::string const& index_type, + std::string const& query_type, size_t runs) { std::vector query_times; for (size_t run = 0; run <= runs; ++run) { - for (auto const &query : queries) { + for (auto const& query : queries) { auto usecs = run_with_timer([&]() { uint64_t result = query_func(query); do_not_optimize_away(result); @@ -98,12 +109,12 @@ void op_perftest(Functor query_func, } template -void perftest(const std::string &index_filename, - const std::optional &wand_data_filename, - const std::vector &queries, - const std::optional &thresholds_filename, - std::string const &type, - std::string const &query_type, +void perftest(const std::string& index_filename, + const std::optional& wand_data_filename, + const std::vector& queries, + const std::optional& thresholds_filename, + std::string const& type, + std::string const& query_type, uint64_t k, bool extract) { @@ -114,7 +125,7 @@ void perftest(const std::string &index_filename, spdlog::info("Warming up posting lists"); std::unordered_set warmed_up; - for (auto const &q : queries) { + for (auto const& q 
: queries) { for (auto t : q.terms) { if (!warmed_up.count(t)) { index.warmup(t); @@ -131,7 +142,7 @@ void perftest(const std::string &index_filename, if (wand_data_filename) { std::error_code error; md.map(*wand_data_filename, error); - if(error){ + if (error) { std::cerr << "error mapping file: " << error.message() << ", exiting..." << std::endl; throw std::runtime_error("Error opening file"); } @@ -150,61 +161,64 @@ void perftest(const std::string &index_filename, spdlog::info("Performing {} queries", type); spdlog::info("K: {}", k); - for (auto &&t : query_types) { + for (auto&& t : query_types) { spdlog::info("Query type: {}", t); std::function query_fun; if (t == "and") { - query_fun = [&](Query query){ + query_fun = [&](Query query) { and_query and_q; return and_q(make_scored_cursors(index, wdata, query), index.num_docs()).size(); }; } else if (t == "and_freq") { - query_fun = [&](Query query){ + query_fun = [&](Query query) { and_query and_q; return and_q(make_scored_cursors(index, wdata, query), index.num_docs()).size(); }; } else if (t == "or") { - query_fun = [&](Query query){ + query_fun = [&](Query query) { or_query or_q; return or_q(make_cursors(index, query), index.num_docs()); }; } else if (t == "or_freq") { - query_fun = [&](Query query){ + query_fun = [&](Query query) { or_query or_q; return or_q(make_cursors(index, query), index.num_docs()); }; } else if (t == "wand" && wand_data_filename) { - query_fun = [&](Query query){ + query_fun = [&](Query query) { wand_query wand_q(k); return wand_q(make_max_scored_cursors(index, wdata, query), index.num_docs()); }; } else if (t == "block_max_wand" && wand_data_filename) { - query_fun = [&](Query query){ + query_fun = [&](Query query) { block_max_wand_query block_max_wand_q(k); - return block_max_wand_q(make_block_max_scored_cursors(index, wdata, query), index.num_docs()); + return block_max_wand_q(make_block_max_scored_cursors(index, wdata, query), + index.num_docs()); }; } else if (t == 
"block_max_maxscore" && wand_data_filename) { - query_fun = [&](Query query){ + query_fun = [&](Query query) { block_max_maxscore_query block_max_maxscore_q(k); - return block_max_maxscore_q(make_block_max_scored_cursors(index, wdata, query), index.num_docs()); + return block_max_maxscore_q(make_block_max_scored_cursors(index, wdata, query), + index.num_docs()); }; - } else if (t == "ranked_and" && wand_data_filename) { - query_fun = [&](Query query){ + } else if (t == "ranked_and" && wand_data_filename) { + query_fun = [&](Query query) { ranked_and_query ranked_and_q(k); return ranked_and_q(make_scored_cursors(index, wdata, query), index.num_docs()); }; } else if (t == "block_max_ranked_and" && wand_data_filename) { - query_fun = [&](Query query){ + query_fun = [&](Query query) { block_max_ranked_and_query block_max_ranked_and_q(k); - return block_max_ranked_and_q(make_block_max_scored_cursors(index, wdata, query), index.num_docs()); + return block_max_ranked_and_q(make_block_max_scored_cursors(index, wdata, query), + index.num_docs()); }; - } else if (t == "ranked_or" && wand_data_filename) { - query_fun = [&](Query query){ + } else if (t == "ranked_or" && wand_data_filename) { + query_fun = [&](Query query) { ranked_or_query ranked_or_q(k); return ranked_or_q(make_scored_cursors(index, wdata, query), index.num_docs()); }; } else if (t == "maxscore" && wand_data_filename) { - query_fun = [&](Query query){ + query_fun = [&](Query query) { maxscore_query maxscore_q(k); return maxscore_q(make_max_scored_cursors(index, wdata, query), index.num_docs()); }; @@ -212,13 +226,15 @@ void perftest(const std::string &index_filename, Simple_Accumulator accumulator(index.num_docs()); ranked_or_taat_query ranked_or_taat_q(k); query_fun = [&, ranked_or_taat_q, accumulator](Query query) mutable { - return ranked_or_taat_q(make_scored_cursors(index, wdata, query), index.num_docs(), accumulator); + return ranked_or_taat_q( + make_scored_cursors(index, wdata, query), 
index.num_docs(), accumulator); }; } else if (t == "ranked_or_taat_lazy" && wand_data_filename) { Lazy_Accumulator<4> accumulator(index.num_docs()); ranked_or_taat_query ranked_or_taat_q(k); query_fun = [&, ranked_or_taat_q, accumulator](Query query) mutable { - return ranked_or_taat_q(make_scored_cursors(index, wdata, query), index.num_docs(), accumulator); + return ranked_or_taat_q( + make_scored_cursors(index, wdata, query), index.num_docs(), accumulator); }; } else { spdlog::error("Unsupported query type: {}", t); @@ -235,7 +251,7 @@ void perftest(const std::string &index_filename, using wand_raw_index = wand_data; using wand_uniform_index = wand_data; -int main(int argc, const char **argv) +int main(int argc, const char** argv) { std::string type; std::string query_type; @@ -261,8 +277,9 @@ int main(int argc, const char **argv) app.add_flag("--compressed-wand", compressed, "Compressed wand input file"); app.add_option("-k", k, "k value"); app.add_option("-T,--thresholds", thresholds_filename, "k value"); - auto *terms_opt = app.add_option("--terms", terms_file, "Term lexicon"); - app.add_option("--stopwords", stopwords_filename, "File containing stopwords to ignore")->needs(terms_opt); + auto* terms_opt = app.add_option("--terms", terms_file, "Term lexicon"); + app.add_option("--stopwords", stopwords_filename, "File containing stopwords to ignore") + ->needs(terms_opt); app.add_option("--stemmer", stemmer, "Stemmer type")->needs(terms_opt); app.add_flag("--extract", extract, "Extract individual query times"); app.add_flag("--silent", silent, "Suppress logging"); diff --git a/src/thresholds.cpp b/src/thresholds.cpp index 6fa69333e..848e0eb27 100644 --- a/src/thresholds.cpp +++ b/src/thresholds.cpp @@ -12,6 +12,7 @@ #include "cursor/max_scored_cursor.hpp" #include "index_types.hpp" #include "io.hpp" +#include "query/algorithm/wand_query.hpp" #include "query/queries.hpp" #include "util/util.hpp" #include "wand_data_compressed.hpp" diff --git 
a/src/v1/index_builder.cpp b/src/v1/index_builder.cpp index 69133a474..d702fc7c2 100644 --- a/src/v1/index_builder.cpp +++ b/src/v1/index_builder.cpp @@ -143,12 +143,11 @@ auto verify_compressed_index(std::string const& input, std::string_view output) PostingFilePaths{scores_file_1, score_offsets_file_1}}; } -void build_bigram_index(std::string const& yml, - std::vector> const& bigrams) +auto build_bigram_index(IndexMetadata meta, std::vector> const& bigrams) + -> IndexMetadata { Expects(not bigrams.empty()); - auto index_basename = yml.substr(0, yml.size() - 4); - auto meta = IndexMetadata::from_file(yml); + auto index_basename = meta.get_basename(); auto run = index_runner(meta, RawReader{}, BlockedReader<::pisa::simdbp_block, true>{}, @@ -225,10 +224,11 @@ void build_bigram_index(std::string const& yml, meta.bigrams = bigram_meta; std::cerr << "Writing metadata..."; - meta.write(yml); + meta.update(); std::cerr << " Done.\nWriting bigram mapping..."; write_span(gsl::make_span(pair_mapping), meta.bigrams->mapping); std::cerr << " Done.\n"; + return meta; } } // namespace pisa::v1 diff --git a/src/v1/index_metadata.cpp b/src/v1/index_metadata.cpp index dd680d969..0a8693c74 100644 --- a/src/v1/index_metadata.cpp +++ b/src/v1/index_metadata.cpp @@ -4,6 +4,7 @@ #include +#include "query/queries.hpp" #include "v1/index_metadata.hpp" namespace pisa::v1 { @@ -20,7 +21,7 @@ constexpr char const* BIGRAM = "bigram"; constexpr char const* MAX_SCORES = "max_scores"; constexpr char const* QUANTIZED_MAX_SCORES = "quantized_max_scores"; -[[nodiscard]] auto resolve_yml(std::optional const& arg) -> std::string +[[nodiscard]] auto resolve_yml(tl::optional const& arg) -> std::string { if (arg) { return *arg; @@ -38,6 +39,7 @@ constexpr char const* QUANTIZED_MAX_SCORES = "quantized_max_scores"; .offsets = config[SCORES][OFFSETS].as()}); } return IndexMetadata{ + .basename = file.substr(0, file.size() - 4), .documents = PostingFilePaths{.postings = config[DOCUMENTS][POSTINGS].as(), 
.offsets = config[DOCUMENTS][OFFSETS].as()}, .frequencies = PostingFilePaths{.postings = config[FREQUENCIES][POSTINGS].as(), @@ -102,7 +104,9 @@ constexpr char const* QUANTIZED_MAX_SCORES = "quantized_max_scores"; }()}; } -void IndexMetadata::write(std::string const& file) +void IndexMetadata::update() const { write(fmt::format("{}.yml", get_basename())); } + +void IndexMetadata::write(std::string const& file) const { YAML::Node root; root[DOCUMENTS][POSTINGS] = documents.postings; @@ -152,4 +156,29 @@ void IndexMetadata::write(std::string const& file) fout << root; } +[[nodiscard]] auto IndexMetadata::query_parser() const -> std::function +{ + if (term_lexicon) { + auto term_processor = + ::pisa::TermProcessor(*term_lexicon, {}, [&]() -> std::optional { + if (stemmer) { + return *stemmer; + } + return std::nullopt; + }()); + return [term_processor = std::move(term_processor)](Query& query) { + query.term_ids(parse_query_terms(query.get_raw(), term_processor).terms); + }; + } + throw std::runtime_error("Unable to parse query: undefined term lexicon"); +} + +[[nodiscard]] auto IndexMetadata::get_basename() const -> std::string const& +{ + if (not basename) { + throw std::runtime_error("Unable to resolve index basename"); + } + return basename.value(); +} + } // namespace pisa::v1 diff --git a/src/v1/query.cpp b/src/v1/query.cpp index 3086cbcc2..0494f47e0 100644 --- a/src/v1/query.cpp +++ b/src/v1/query.cpp @@ -58,14 +58,71 @@ template return tl::optional{}; } +[[nodiscard]] auto Query::from_plain(std::string_view query_string) -> Query +{ + auto colon = std::find(query_string.begin(), query_string.end(), ':'); + tl::optional id; + if (colon != query_string.end()) { + id = std::string(query_string.begin(), colon); + } + auto pos = colon == query_string.end() ? 
query_string.begin() : std::next(colon); + return Query(std::string(&*pos, std::distance(pos, query_string.end())), std::move(id)); +} + [[nodiscard]] auto Query::from_json(std::string_view json_string) -> Query { - // auto query_json = json::parse(json_string); - // auto terms = get>(query_json, "terms"); - // auto term_ids = get>(query_json, "term_ids"); - Query query; - // query.m_raw_string = get(query_json, "query"); - return query; + try { + auto query_json = json::parse(json_string); + auto id = get(query_json, "id"); + auto raw_string = get(query_json, "query"); + auto term_ids = get>(query_json, "term_ids"); + Query query = [&]() { + if (raw_string) { + auto query = Query(*raw_string.take(), id); + if (term_ids) { + query.term_ids(*term_ids.take()); + } + return query; + } + if (term_ids) { + return Query(*term_ids.take(), id); + } + throw std::invalid_argument( + "Failed to parse query: must define either raw string or term IDs"); + }(); + if (auto threshold = get(query_json, "threshold"); threshold) { + query.threshold(*threshold); + } + if (auto k = get(query_json, "k"); k) { + query.k(*k); + } + if (auto pos = query_json.find("selections"); pos != query_json.end()) { + auto const& selections = *pos; + auto unigrams = get>(selections, "unigrams") + .value_or(std::vector{}); + auto bigrams = + get>>(selections, "bigrams") + .value_or(std::vector>{}); + std::vector> bitsets; + std::transform( + unigrams.begin(), unigrams.end(), std::back_inserter(bitsets), [](auto idx) { + std::bitset<64> bs; + bs.set(idx); + return bs; + }); + std::transform( + bigrams.begin(), bigrams.end(), std::back_inserter(bitsets), [](auto bigram) { + std::bitset<64> bs; + bs.set(bigram.first); + bs.set(bigram.second); + return bs; + }); + query.selections(gsl::span>(bitsets)); + } + return query; + } catch (json::parse_error const& err) { + throw std::runtime_error(fmt::format("Failed to parse query: {}", err.what())); + } } Query::Query(std::vector term_ids, tl::optional id) @@ 
-129,6 +186,13 @@ auto Query::selections() const -> tl::optional return tl::nullopt; } auto Query::threshold() const -> tl::optional { return m_threshold; } +auto Query::raw() const -> tl::optional +{ + if (m_raw_string) { + return *m_raw_string; + } + return tl::nullopt; +} /// Throwing getters auto Query::get_term_ids() const -> std::vector const& @@ -163,6 +227,14 @@ auto Query::get_threshold() const -> float return *m_threshold; } +auto Query::get_raw() const -> std::string const& +{ + if (not m_raw_string) { + throw std::runtime_error("Raw query string is not set"); + } + return *m_raw_string; +} + auto Query::sorted_position(TermId term) const -> std::size_t { if (not m_term_ids) { diff --git a/src/v1/score_index.cpp b/src/v1/score_index.cpp index 9a00a9e83..b3fb44468 100644 --- a/src/v1/score_index.cpp +++ b/src/v1/score_index.cpp @@ -1,3 +1,5 @@ +#include + #include #include "codec/simdbp.hpp" @@ -20,14 +22,13 @@ using pisa::v1::write_span; namespace pisa::v1 { -void score_index(std::string const& yml, std::size_t threads) +auto score_index(IndexMetadata meta, std::size_t threads) -> IndexMetadata { - auto meta = IndexMetadata::from_file(yml); auto run = index_runner(meta, RawReader{}, BlockedReader<::pisa::simdbp_block, true>{}, BlockedReader<::pisa::simdbp_block, false>{}); - auto index_basename = yml.substr(0, yml.size() - 4); + auto const& index_basename = meta.get_basename(); auto postings_path = fmt::format("{}.bm25", index_basename); auto offsets_path = fmt::format("{}.bm25_offsets", index_basename); auto max_scores_path = fmt::format("{}.bm25.maxf", index_basename); @@ -87,7 +88,8 @@ void score_index(std::string const& yml, std::size_t threads) meta.scores.push_back(PostingFilePaths{.postings = postings_path, .offsets = offsets_path}); meta.max_scores["bm25"] = max_scores_path; meta.quantized_max_scores["bm25"] = quantized_max_scores_path; - meta.write(yml); + meta.update(); + return meta; } } // namespace pisa::v1 diff --git 
a/test/v1/index_fixture.hpp b/test/v1/index_fixture.hpp index 7269408d3..64ceb6926 100644 --- a/test/v1/index_fixture.hpp +++ b/test/v1/index_fixture.hpp @@ -60,8 +60,8 @@ struct IndexFixture { index_basename); REQUIRE(errors.empty()); auto yml = fmt::format("{}.yml", index_basename); - v1::score_index(yml, 1); - v1::build_bigram_index(yml, collect_unique_bigrams(test_queries(), []() {})); + auto meta = v1::score_index(v1::IndexMetadata::from_file(yml), 1); + v1::build_bigram_index(meta, collect_unique_bigrams(test_queries(), []() {})); } [[nodiscard]] auto const& tmpdir() const { return *m_tmpdir; } diff --git a/test/v1/test_v1_queries.cpp b/test/v1/test_v1_queries.cpp index 778fee7a3..daf11c708 100644 --- a/test/v1/test_v1_queries.cpp +++ b/test/v1/test_v1_queries.cpp @@ -17,6 +17,7 @@ #include "index_types.hpp" #include "io.hpp" #include "pisa_config.hpp" +#include "query/algorithm/ranked_or_query.hpp" #include "query/queries.hpp" #include "v1/blocked_cursor.hpp" #include "v1/cursor/collect.hpp" diff --git a/test/v1/test_v1_query.cpp b/test/v1/test_v1_query.cpp new file mode 100644 index 000000000..fcd454ea0 --- /dev/null +++ b/test/v1/test_v1_query.cpp @@ -0,0 +1,41 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include + +#include +#include +#include +#include +#include + +#include "v1/query.hpp" + +using pisa::v1::Query; +using pisa::v1::TermId; + +TEST_CASE("Parse query from JSON", "[v1][unit]") +{ + REQUIRE_THROWS(Query::from_json("{}")); + REQUIRE(Query::from_json("{\"query\": \"tell your dog I said hi\"}").get_raw() + == "tell your dog I said hi"); + REQUIRE(Query::from_json("{\"term_ids\": [0, 32, 4]}").get_term_ids() + == std::vector{0, 4, 32}); + REQUIRE(Query::from_json("{\"term_ids\": [0, 32, 4]}").k() == 1000); + auto query = Query::from_json( + R"({"id": "Q0", "query": "send dog pics", "term_ids": [0, 32, 4], "k": 15, )" + R"("threshold": 40.5, "selections": )" + R"({ "unigrams": [0, 2], "bigrams": [[0, 2], [2, 1]]}})"); + 
REQUIRE(query.get_id() == "Q0"); + REQUIRE(query.k() == 15); + REQUIRE(query.get_term_ids() == std::vector{0, 4, 32}); + REQUIRE(query.get_threshold() == 40.5); + REQUIRE(query.get_raw() == "send dog pics"); + REQUIRE(query.get_selections().unigrams == std::vector{0, 4}); + REQUIRE(query.get_selections().bigrams + == std::vector>{{0, 4}, {4, 32}}); + REQUIRE_THROWS(Query::from_json( + R"({"id": "Q0", "query": "send dog pics", "term_ids": [0, 32, 4], "k": 15, )" + R"("threshold": 40.5, "selections": )" + R"({ "unigrams": [0, 4], "bigrams": [[0, 4], [4, 5]]}})")); +} diff --git a/v1/app.hpp b/v1/app.hpp index 7fcddd14e..1ec096637 100644 --- a/v1/app.hpp +++ b/v1/app.hpp @@ -2,36 +2,138 @@ #include #include +#include #include +#include + +#include "io.hpp" +#include "v1/index_metadata.hpp" namespace pisa { -struct QueryApp : public CLI::App { - explicit QueryApp(std::string description) : CLI::App(std::move(description)) - { - add_option("-i,--index", - yml, - "Path of .yml file of an index " - "(if not provided, it will be looked for in the current directory)", - false); - add_option("-q,--query", query_file, "Path to file with queries", false); - add_option("-k", k, "The number of top results to return", true); - add_option("--terms", terms_file, "Overrides document lexicon from .yml (if defined)."); - add_option("--documents", - documents_file, - "Overrides document lexicon from .yml (if defined). 
Required otherwise."); - add_flag("--benchmark", is_benchmark, "Run benchmark"); - add_flag("--precomputed", precomputed, "Use precomputed scores"); - } - - std::optional yml{}; - std::optional query_file{}; - std::optional terms_file{}; - std::optional documents_file{}; - int k = 1'000; - bool is_benchmark = false; - bool precomputed = false; +namespace arg { + + struct Index { + explicit Index(CLI::App* app) + { + app->add_option("-i,--index", + m_metadata_path, + "Path of .yml file of an index " + "(if not provided, it will be looked for in the current directory)", + false); + } + + [[nodiscard]] auto index_metadata() const -> v1::IndexMetadata + { + return v1::IndexMetadata::from_file(v1::resolve_yml(m_metadata_path)); + } + + private: + tl::optional m_metadata_path; + }; + + enum class QueryMode : bool { Ranked, Unranked }; + + template + struct Query { + explicit Query(CLI::App* app) + { + app->add_option("-q,--query", m_query_file, "Path to file with queries", false); + app->add_option("--qf,--query-fmt", m_query_input_format, "Input file format", true); + if constexpr (Mode == QueryMode::Ranked) { + app->add_option("-k", m_k, "The number of top results to return", true); + } + } + + [[nodiscard]] auto query_file() -> tl::optional + { + if (m_query_file) { + tl::make_optional(m_query_file.value()); + } + return tl::nullopt; + } + + [[nodiscard]] auto queries(v1::IndexMetadata const& meta) const -> std::vector + { + std::vector queries; + auto parser = meta.query_parser(); + auto parse_line = [&](auto&& line) { + auto query = [&line, this]() { + if (m_query_input_format == "jl") { + return v1::Query::from_json(line); + } + return v1::Query::from_plain(line); + }(); + query.parse(parser); + if constexpr (Mode == QueryMode::Ranked) { + query.k(m_k); + } + queries.push_back(std::move(query)); + }; + if (m_query_file) { + std::ifstream is(*m_query_file); + pisa::io::for_each_line(is, parse_line); + } else { + pisa::io::for_each_line(std::cin, parse_line); + } + 
return queries; + } + + private: + tl::optional m_query_file; + tl::optional m_query_input_format = "jl"; + int m_k = DefaultK; + }; + + struct Benchmark { + explicit Benchmark(CLI::App* app) + { + app->add_flag("--benchmark", m_is_benchmark, "Run benchmark"); + } + + [[nodiscard]] auto is_benchmark() const -> bool { return m_is_benchmark; } + + private: + bool m_is_benchmark = false; + }; + + struct QuantizedScores { + explicit QuantizedScores(CLI::App* app) + { + app->add_flag("--quantized", m_use_quantized, "Use quantized scores"); + } + + [[nodiscard]] auto use_quantized() const -> bool { return m_use_quantized; } + + private: + bool m_use_quantized = false; + }; + + struct Threads { + explicit Threads(CLI::App* app) + { + app->add_option("-j,--threads", m_threads, "Number of threads"); + } + + [[nodiscard]] auto threads() const -> std::size_t { return m_threads; } + + private: + std::size_t m_threads = std::thread::hardware_concurrency(); + }; + +} // namespace arg + +template +struct App : public CLI::App, public Mixin... { + explicit App(std::string const& description) : CLI::App(description), Mixin(this)... 
{} +}; + +struct QueryApp : public App, + arg::Benchmark, + arg::QuantizedScores> { + explicit QueryApp(std::string const& description) : App(description) {} }; } // namespace pisa diff --git a/v1/bigram_index.cpp b/v1/bigram_index.cpp index 168e9a3ce..583f0d8b9 100644 --- a/v1/bigram_index.cpp +++ b/v1/bigram_index.cpp @@ -6,6 +6,7 @@ #include #include +#include "app.hpp" #include "io.hpp" #include "query/queries.hpp" #include "timer.hpp" @@ -21,6 +22,7 @@ #include "v1/scorer/runner.hpp" #include "v1/types.hpp" +using pisa::App; using pisa::v1::build_bigram_index; using pisa::v1::collect_unique_bigrams; using pisa::v1::DefaultProgress; @@ -32,60 +34,25 @@ using pisa::v1::Query; using pisa::v1::resolve_yml; using pisa::v1::TermId; +namespace arg = pisa::arg; + int main(int argc, char** argv) { - std::optional yml{}; - std::optional query_file{}; std::optional terms_file{}; - CLI::App app{"Creates a v1 bigram index."}; - app.add_option("-i,--index", - yml, - "Path of .yml file of an index " - "(if not provided, it will be looked for in the current directory)", - false); - app.add_option("-q,--query", query_file, "Path to file with queries", false); - app.add_option("--terms", terms_file, "Overrides document lexicon from .yml (if defined)."); + App> app{"Creates a v1 bigram index."}; CLI11_PARSE(app, argc, argv); - auto resolved_yml = resolve_yml(yml); - auto meta = IndexMetadata::from_file(resolved_yml); - auto stemmer = meta.stemmer ? 
std::make_optional(*meta.stemmer) : std::optional{}; - if (meta.term_lexicon) { - terms_file = meta.term_lexicon.value(); - } + auto meta = app.index_metadata(); spdlog::info("Collecting queries..."); - auto queries = [&]() { - std::vector<::pisa::Query> queries; - auto parse_query = resolve_query_parser(queries, terms_file, std::nullopt, stemmer); - if (query_file) { - std::ifstream is(*query_file); - pisa::io::for_each_line(is, parse_query); - } else { - pisa::io::for_each_line(std::cin, parse_query); - } - std::vector v1_queries; - v1_queries.reserve(queries.size()); - for (auto q : queries) { - if (not q.terms.empty()) { - Query query(q.terms, [&]() { - if (q.id) { - return tl::make_optional(*q.id); - } - return tl::optional{}; - }()); - v1_queries.push_back(query); - } - } - return v1_queries; - }(); + auto queries = app.queries(meta); spdlog::info("Collected {} queries", queries.size()); spdlog::info("Collecting bigrams..."); ProgressStatus status(queries.size(), DefaultProgress{}, std::chrono::milliseconds(1000)); auto bigrams = collect_unique_bigrams(queries, [&]() { status += 1; }); spdlog::info("Collected {} bigrams", bigrams.size()); - build_bigram_index(resolved_yml, bigrams); + build_bigram_index(meta, bigrams); return 0; } diff --git a/v1/filter_queries.cpp b/v1/filter_queries.cpp index 52c775026..3376d08da 100644 --- a/v1/filter_queries.cpp +++ b/v1/filter_queries.cpp @@ -40,21 +40,22 @@ int main(int argc, char** argv) pisa::QueryApp app("Filters out empty queries against a v1 index."); CLI11_PARSE(app, argc, argv); - auto meta = IndexMetadata::from_file(resolve_yml(app.yml)); + auto meta = app.index_metadata(); auto stemmer = meta.stemmer ? 
std::make_optional(*meta.stemmer) : std::optional{}; + std::optional term_lexicon = std::nullopt; if (meta.term_lexicon) { - app.terms_file = meta.term_lexicon.value(); + term_lexicon = *meta.term_lexicon; } - auto term_processor = TermProcessor(app.terms_file, {}, stemmer); + auto term_processor = TermProcessor(term_lexicon, {}, stemmer); auto filter = [&](auto&& line) { auto query = parse_query_terms(line, term_processor); if (not query.terms.empty()) { std::cout << line << '\n'; } }; - if (app.query_file) { - std::ifstream is(*app.query_file); + if (app.query_file()) { + std::ifstream is(*app.query_file()); pisa::io::for_each_line(is, filter); } else { pisa::io::for_each_line(std::cin, filter); diff --git a/v1/postings.cpp b/v1/postings.cpp index 647bd1471..a46be9f96 100644 --- a/v1/postings.cpp +++ b/v1/postings.cpp @@ -6,6 +6,7 @@ #include #include +#include "app.hpp" #include "io.hpp" #include "query/queries.hpp" #include "topk_queue.hpp" @@ -17,6 +18,7 @@ #include "v1/scorer/runner.hpp" #include "v1/types.hpp" +using pisa::App; using pisa::Query; using pisa::resolve_query_parser; using pisa::v1::BlockedReader; @@ -25,6 +27,8 @@ using pisa::v1::IndexMetadata; using pisa::v1::RawReader; using pisa::v1::resolve_yml; +namespace arg = pisa::arg; + auto default_readers() { return std::make_tuple(RawReader{}, @@ -70,7 +74,6 @@ template int main(int argc, char** argv) { - std::optional yml{}; std::optional terms_file{}; std::optional documents_file{}; std::string query_input{}; @@ -80,12 +83,7 @@ int main(int argc, char** argv) bool print_scores = false; bool precomputed = false; - CLI::App app{"Queries a v1 index."}; - app.add_option("-i,--index", - yml, - "Path of .yml file of an index " - "(if not provided, it will be looked for in the current directory)", - false); + App app{"Queries a v1 index."}; app.add_option("--terms", terms_file, "Overrides document lexicon from .yml (if defined)."); app.add_option("--documents", documents_file, @@ -98,7 +96,7 @@ int 
main(int argc, char** argv) app.add_option("query", query_input, "List of terms", false)->required(); CLI11_PARSE(app, argc, argv); - auto meta = IndexMetadata::from_file(resolve_yml(yml)); + auto meta = app.index_metadata(); auto stemmer = meta.stemmer ? std::make_optional(*meta.stemmer) : std::optional{}; if (tid) { terms_file = std::nullopt; diff --git a/v1/query.cpp b/v1/query.cpp index 6d8e7124d..9970d031e 100644 --- a/v1/query.cpp +++ b/v1/query.cpp @@ -13,6 +13,7 @@ #include "query/queries.hpp" #include "timer.hpp" #include "topk_queue.hpp" +#include "util/do_not_optimize_away.hpp" #include "v1/blocked_cursor.hpp" #include "v1/daat_or.hpp" #include "v1/index_metadata.hpp" @@ -135,13 +136,12 @@ auto resolve_inspect(std::string const& name, Index const& index, Scorer&& score } void evaluate(std::vector const& queries, - int k, pisa::Payload_Vector<> const& docmap, RetrievalAlgorithm const& retrieve) { auto query_idx = 0; for (auto const& query : queries) { - auto que = retrieve(query, pisa::topk_queue(k)); + auto que = retrieve(query, pisa::topk_queue(query.k())); que.finalize(); auto rank = 0; for (auto result : que.topk()) { @@ -158,14 +158,14 @@ void evaluate(std::vector const& queries, } } -void benchmark(std::vector const& queries, int k, RetrievalAlgorithm retrieve) +void benchmark(std::vector const& queries, RetrievalAlgorithm retrieve) { std::vector times(queries.size(), std::numeric_limits::max()); for (auto run = 0; run < 5; run += 1) { for (auto query = 0; query < queries.size(); query += 1) { auto usecs = ::pisa::run_with_timer([&]() { - auto que = retrieve(queries[query], pisa::topk_queue(k)); + auto que = retrieve(queries[query], pisa::topk_queue(queries[query].k())); que.finalize(); do_not_optimize_away(que); }); @@ -212,42 +212,14 @@ int main(int argc, char** argv) CLI11_PARSE(app, argc, argv); try { - auto meta = IndexMetadata::from_file(resolve_yml(app.yml)); - auto stemmer = - meta.stemmer ? 
std::make_optional(*meta.stemmer) : std::optional{}; - if (meta.term_lexicon) { - app.terms_file = meta.term_lexicon.value(); - } - if (meta.document_lexicon) { - app.documents_file = meta.document_lexicon.value(); - } - - auto queries = [&]() { - std::vector<::pisa::Query> queries; - auto parse_query = resolve_query_parser(queries, app.terms_file, std::nullopt, stemmer); - if (app.query_file) { - std::ifstream is(*app.query_file); - pisa::io::for_each_line(is, parse_query); - } else { - pisa::io::for_each_line(std::cin, parse_query); - } - std::vector v1_queries(queries.size()); - std::transform(queries.begin(), queries.end(), v1_queries.begin(), [&](auto&& parsed) { - Query query(parsed.terms); - if (parsed.id) { - query.id(*parsed.id); - } - query.k(app.k); - return query; - }); - return v1_queries; - }(); + auto meta = app.index_metadata(); + auto queries = app.queries(meta); - if (not app.documents_file) { + if (not meta.document_lexicon) { spdlog::error("Document lexicon not defined"); std::exit(1); } - auto source = std::make_shared(app.documents_file.value().c_str()); + auto source = std::make_shared(meta.document_lexicon.value().c_str()); auto docmap = pisa::Payload_Vector<>::from(*source); if (threshold_file) { @@ -278,20 +250,19 @@ int main(int argc, char** argv) } } - if (app.precomputed) { + if (app.use_quantized()) { auto run = scored_index_runner(meta, RawReader{}, RawReader{}, BlockedReader<::pisa::simdbp_block, true>{}, BlockedReader<::pisa::simdbp_block, false>{}); run([&](auto&& index) { - if (app.is_benchmark) { - benchmark(queries, app.k, resolve_algorithm(algorithm, index, VoidScorer{})); + if (app.is_benchmark()) { + benchmark(queries, resolve_algorithm(algorithm, index, VoidScorer{})); } else if (inspect) { inspect_queries(queries, resolve_inspect(algorithm, index, VoidScorer{})); } else { - evaluate( - queries, app.k, docmap, resolve_algorithm(algorithm, index, VoidScorer{})); + evaluate(queries, docmap, resolve_algorithm(algorithm, index, 
VoidScorer{})); } }); } else { @@ -302,13 +273,12 @@ int main(int argc, char** argv) run([&](auto&& index) { auto with_scorer = scorer_runner(index, make_bm25(index)); with_scorer("bm25", [&](auto scorer) { - if (app.is_benchmark) { - benchmark(queries, app.k, resolve_algorithm(algorithm, index, scorer)); + if (app.is_benchmark()) { + benchmark(queries, resolve_algorithm(algorithm, index, scorer)); } else if (inspect) { inspect_queries(queries, resolve_inspect(algorithm, index, scorer)); } else { - evaluate( - queries, app.k, docmap, resolve_algorithm(algorithm, index, scorer)); + evaluate(queries, docmap, resolve_algorithm(algorithm, index, scorer)); } }); }); diff --git a/v1/score.cpp b/v1/score.cpp index 971308684..bc47cc972 100644 --- a/v1/score.cpp +++ b/v1/score.cpp @@ -4,26 +4,24 @@ #include +#include "app.hpp" #include "v1/index_metadata.hpp" #include "v1/score_index.hpp" +using pisa::App; +namespace arg = pisa::arg; + int main(int argc, char** argv) { std::optional yml{}; int bytes_per_score = 1; std::size_t threads = std::thread::hardware_concurrency(); - CLI::App app{"Scores v1 index."}; - app.add_option("-i,--index", - yml, - "Path of .yml file of an index " - "(if not provided, it will be looked for in the current directory)", - false); - app.add_option("-j,--threads", threads, "Number of threads"); + App app{"Scores v1 index."}; // TODO(michal): enable // app.add_option( // "-b,--bytes-per-score", yml, "Quantize computed scores to this many bytes", true); CLI11_PARSE(app, argc, argv); - pisa::v1::score_index(pisa::v1::resolve_yml(yml), threads); + pisa::v1::score_index(app.index_metadata(), app.threads()); return 0; } diff --git a/v1/threshold.cpp b/v1/threshold.cpp index a33fecfa8..91b702b04 100644 --- a/v1/threshold.cpp +++ b/v1/threshold.cpp @@ -32,11 +32,11 @@ using pisa::v1::resolve_yml; using pisa::v1::VoidScorer; template -void calculate_thresholds(Index&& index, Scorer&& scorer, std::vector const& queries, int k) +void 
calculate_thresholds(Index&& index, Scorer&& scorer, std::vector const& queries) { for (auto const& query : queries) { - auto results = - pisa::v1::daat_or(query, index, ::pisa::topk_queue(k), std::forward(scorer)); + auto results = pisa::v1::daat_or( + query, index, ::pisa::topk_queue(query.k()), std::forward(scorer)); results.finalize(); float threshold = 0.0; if (not results.topk().empty()) { @@ -54,41 +54,16 @@ int main(int argc, char** argv) CLI11_PARSE(app, argc, argv); try { - auto meta = IndexMetadata::from_file(resolve_yml(app.yml)); - auto stemmer = - meta.stemmer ? std::make_optional(*meta.stemmer) : std::optional{}; - if (meta.term_lexicon) { - app.terms_file = meta.term_lexicon.value(); - } - - auto queries = [&]() { - std::vector<::pisa::Query> queries; - auto parse_query = resolve_query_parser(queries, app.terms_file, std::nullopt, stemmer); - if (app.query_file) { - std::ifstream is(*app.query_file); - pisa::io::for_each_line(is, parse_query); - } else { - pisa::io::for_each_line(std::cin, parse_query); - } - std::vector v1_queries(queries.size()); - std::transform(queries.begin(), queries.end(), v1_queries.begin(), [&](auto&& parsed) { - Query query(parsed.terms); - if (parsed.id) { - query.id(*parsed.id); - } - query.k(app.k); - return query; - }); - return v1_queries; - }(); + auto meta = app.index_metadata(); + auto queries = app.queries(meta); - if (app.precomputed) { + if (app.use_quantized()) { auto run = scored_index_runner(meta, RawReader{}, RawReader{}, BlockedReader<::pisa::simdbp_block, true>{}, BlockedReader<::pisa::simdbp_block, false>{}); - run([&](auto&& index) { calculate_thresholds(index, VoidScorer{}, queries, app.k); }); + run([&](auto&& index) { calculate_thresholds(index, VoidScorer{}, queries); }); } else { auto run = index_runner(meta, RawReader{}, @@ -96,9 +71,8 @@ int main(int argc, char** argv) BlockedReader<::pisa::simdbp_block, false>{}); run([&](auto&& index) { auto with_scorer = scorer_runner(index, make_bm25(index)); 
- with_scorer("bm25", [&](auto scorer) { - calculate_thresholds(index, scorer, queries, app.k); - }); + with_scorer("bm25", + [&](auto scorer) { calculate_thresholds(index, scorer, queries); }); }); } } catch (std::exception const& error) { From 92dec029e9d4ad12c32939770367fed8e1b6cbad Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Wed, 18 Dec 2019 13:42:14 +0000 Subject: [PATCH 34/56] Fixes to filtering queries --- include/pisa/v1/query.hpp | 3 +++ src/v1/query.cpp | 18 ++++++++++++++++++ v1/app.hpp | 2 +- v1/filter_queries.cpp | 26 ++++++++------------------ 4 files changed, 30 insertions(+), 19 deletions(-) diff --git a/include/pisa/v1/query.hpp b/include/pisa/v1/query.hpp index 560150a33..8b2add33d 100644 --- a/include/pisa/v1/query.hpp +++ b/include/pisa/v1/query.hpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -97,6 +98,8 @@ struct Query { } void add_selections(gsl::span const> selections); + + [[nodiscard]] auto to_json() const -> nlohmann::json; [[nodiscard]] static auto from_json(std::string_view) -> Query; [[nodiscard]] static auto from_plain(std::string_view) -> Query; diff --git a/src/v1/query.cpp b/src/v1/query.cpp index 0494f47e0..bc22c717a 100644 --- a/src/v1/query.cpp +++ b/src/v1/query.cpp @@ -125,6 +125,24 @@ template } } +[[nodiscard]] auto Query::to_json() const -> nlohmann::json +{ + json query; + if (m_id) { + query["id"] = *m_id; + } + if (m_raw_string) { + query["query"] = *m_raw_string; + } + // TODO(michal) + // tl::optional m_term_ids{}; + // tl::optional m_selections{}; + // tl::optional m_threshold{}; + // tl::optional m_raw_string; + // int m_k = 1000; + return query; +} + Query::Query(std::vector term_ids, tl::optional id) : m_term_ids(std::move(term_ids)), m_id(std::move(id)) { diff --git a/v1/app.hpp b/v1/app.hpp index 1ec096637..e2987ca3b 100644 --- a/v1/app.hpp +++ b/v1/app.hpp @@ -49,7 +49,7 @@ namespace arg { [[nodiscard]] auto query_file() -> tl::optional { if (m_query_file) { - 
tl::make_optional(m_query_file.value()); + return m_query_file.value(); } return tl::nullopt; } diff --git a/v1/filter_queries.cpp b/v1/filter_queries.cpp index 3376d08da..d2462754b 100644 --- a/v1/filter_queries.cpp +++ b/v1/filter_queries.cpp @@ -35,30 +35,20 @@ using pisa::v1::RawReader; using pisa::v1::resolve_yml; using pisa::v1::VoidScorer; +namespace arg = pisa::arg; + int main(int argc, char** argv) { - pisa::QueryApp app("Filters out empty queries against a v1 index."); + pisa::App> app( + "Filters out empty queries against a v1 index."); CLI11_PARSE(app, argc, argv); auto meta = app.index_metadata(); - auto stemmer = meta.stemmer ? std::make_optional(*meta.stemmer) : std::optional{}; - std::optional term_lexicon = std::nullopt; - if (meta.term_lexicon) { - term_lexicon = *meta.term_lexicon; - } - - auto term_processor = TermProcessor(term_lexicon, {}, stemmer); - auto filter = [&](auto&& line) { - auto query = parse_query_terms(line, term_processor); - if (not query.terms.empty()) { - std::cout << line << '\n'; + auto queries = app.queries(meta); + for (auto&& query : queries) { + if (query.term_ids()) { + std::cout << query.to_json() << '\n'; } - }; - if (app.query_file()) { - std::ifstream is(*app.query_file()); - pisa::io::for_each_line(is, filter); - } else { - pisa::io::for_each_line(std::cin, filter); } return 0; } From baeb423c41a9b2528bc2000a22edf5eebf6e3658 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Wed, 18 Dec 2019 13:43:11 +0000 Subject: [PATCH 35/56] Update script --- script/cw09b.sh | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/script/cw09b.sh b/script/cw09b.sh index 52323885f..753140475 100644 --- a/script/cw09b.sh +++ b/script/cw09b.sh @@ -1,16 +1,17 @@ PISA_BIN="/home/michal/pisa/build/bin" INTERSECT_BIN="/home/michal/intersect/target/release/intersect" -BINARY_FREQ_COLL="/data/michal/work/cw09b/inv" -FWD="/data/michal/work/cw09b/fwd" 
-INV="/data/michal/work/cw09b/inv" +BINARY_FREQ_COLL="/home/amallia/cw09b/CW09B.inv" +FWD="/home/amallia/cw09b/CW09B.fwd" +INV="/home/amallia/cw09b/CW09B" BASENAME="/data/michal/work/v1/cw09b/cw09b" THREADS=4 TYPE="block_simdbp" # v0.6 ENCODING="simdbp" # v1 -QUERIES="/home/michal/biscorer/data/queries/05.efficiency_topics.no_dups.1k" +#QUERIES="/home/michal/biscorer/data/queries/05.efficiency_topics.no_dups.1k" +QUERIES="/home/michal/topics.web.51-200.jl" K=1000 OUTPUT_DIR="/data/michal/intersect/cw09b" -FILTERED_QUERIES="${OUTPUT_DIR}/filtered_queries" +FILTERED_QUERIES="${OUTPUT_DIR}/topics.web.51-200.filtered" #THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.web.51-200" THRESHOLDS="${OUTPUT_DIR}/thresholds" @@ -35,22 +36,23 @@ set -x #cut -d: -f1 ${FILTERED_QUERIES} | paste - ${THRESHOLDS} > ${OUTPUT_DIR}/thresholds.tsv # Extract intersections -#${PISA_BIN}/compute_intersection -t ${TYPE} -i ${INV}.${TYPE} \ -# -w ${INV}.wand -q ${FILTERED_QUERIES} --combinations --terms "${FWD}.termlex" --stemmer porter2 \ -# | grep -v "\[warning\]" \ -# > ${OUTPUT_DIR}/intersections.tsv +${PISA_BIN}/compute_intersection -t ${TYPE} -i ${INV}.${TYPE} \ + -w ${INV}.wand -q <(jq '.id + ":" + .query' ${FILTERED_QUERIES} -r) \ + --combinations --terms "${FWD}.termlex" --stemmer porter2 \ + | grep -v "\[warning\]" \ + > ${OUTPUT_DIR}/intersections.tsv # Select unigrams -#${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ -# --terse --max 1 > ${OUTPUT_DIR}/selections.1 -#${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ -# --terse --max 2 > ${OUTPUT_DIR}/selections.2 -#${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ -# --terse --max 2 --scale 1.5 > ${OUTPUT_DIR}/selections.2.scaled-1.5 -#${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ -# --terse --max 
2 --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 -#${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ -# --terse --max 2 --scale-by-query-len > ${OUTPUT_DIR}/selections.2.scaled-smart +${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ + --terse --max 1 > ${OUTPUT_DIR}/selections.1 +${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ + --terse --max 2 > ${OUTPUT_DIR}/selections.2 +${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ + --terse --max 2 --scale 1.5 > ${OUTPUT_DIR}/selections.2.scaled-1.5 +${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ + --terse --max 2 --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 +${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ + --terse --max 2 --scale-by-query-len > ${OUTPUT_DIR}/selections.2.scaled-smart # Run benchmarks ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore > ${OUTPUT_DIR}/bench.maxscore From ac51df98bfcdad541f43e3bf0a056715bd335093 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Wed, 18 Dec 2019 18:02:41 +0000 Subject: [PATCH 36/56] Test fixes after merge --- include/pisa/intersection.hpp | 3 +- include/pisa/payload_vector.hpp | 29 ++++++------- .../algorithm/block_max_maxscore_query.hpp | 2 + test/CMakeLists.txt | 6 --- test/test_bmw_queries.cpp | 33 ++++++++++---- test/test_forward_index_builder.cpp | 43 ++++++++++--------- test/test_ranked_queries.cpp | 36 ++++++++++------ test/test_wand_data.cpp | 19 ++++---- test/v1/test_v1_queries.cpp | 10 +++-- 9 files changed, 106 insertions(+), 75 deletions(-) diff --git a/include/pisa/intersection.hpp b/include/pisa/intersection.hpp index 7f69df6d3..e149b4022 100644 --- a/include/pisa/intersection.hpp +++ 
b/include/pisa/intersection.hpp @@ -6,6 +6,7 @@ #include "query/algorithm/and_query.hpp" #include "query/queries.hpp" +#include "scorer/scorer.hpp" namespace pisa { @@ -64,7 +65,7 @@ inline auto Intersection::compute(Index const &index, { auto filtered_query = term_mask ? intersection::filter(query, *term_mask) : query; scored_and_query retrieve{}; - auto scorer = scorer::from_name("bm25", wand); + auto scorer = ::pisa::scorer::from_name("bm25", wand); auto results = retrieve(make_scored_cursors(index, *scorer, filtered_query), index.num_docs()); auto max_element = [&](auto const &vec) -> float { auto order = [](auto const &lhs, auto const &rhs) { return lhs.second < rhs.second; }; diff --git a/include/pisa/payload_vector.hpp b/include/pisa/payload_vector.hpp index 0e8d8f593..c18342072 100644 --- a/include/pisa/payload_vector.hpp +++ b/include/pisa/payload_vector.hpp @@ -153,25 +153,24 @@ struct Payload_Vector_Buffer { std::vector const offsets; std::vector const payloads; - [[nodiscard]] static auto from_file(std::string const& filename) -> Payload_Vector_Buffer - { - boost::system::error_code ec; - auto file_size = boost::filesystem::file_size(boost::filesystem::path(filename)); - std::ifstream is(filename); + //[[nodiscard]] static auto from_file(std::string const& filename) -> Payload_Vector_Buffer + //{ + // auto file_size = boost::filesystem::file_size(boost::filesystem::path(filename)); + // std::ifstream is(filename); - size_type len; - is.read(reinterpret_cast(&len), sizeof(size_type)); + // size_type len; + // is.read(reinterpret_cast(&len), sizeof(size_type)); - auto offsets_bytes = (len + 1) * sizeof(size_type); - std::vector offsets(len + 1); - is.read(reinterpret_cast(offsets.data()), offsets_bytes); + // auto offsets_bytes = (len + 1) * sizeof(size_type); + // std::vector offsets(len + 1); + // is.read(reinterpret_cast(offsets.data()), offsets_bytes); - auto payloads_bytes = file_size - offsets_bytes - sizeof(size_type); - std::vector 
payloads(payloads_bytes); - is.read(reinterpret_cast(payloads.data()), payloads_bytes); + // auto payloads_bytes = file_size - offsets_bytes - sizeof(size_type); + // std::vector payloads(payloads_bytes); + // is.read(reinterpret_cast(payloads.data()), payloads_bytes); - return Payload_Vector_Buffer{std::move(offsets), std::move(payloads)}; - } + // return Payload_Vector_Buffer{std::move(offsets), std::move(payloads)}; + //} void to_file(std::string const& filename) const { diff --git a/include/pisa/query/algorithm/block_max_maxscore_query.hpp b/include/pisa/query/algorithm/block_max_maxscore_query.hpp index 95303ac1a..b471c2a57 100644 --- a/include/pisa/query/algorithm/block_max_maxscore_query.hpp +++ b/include/pisa/query/algorithm/block_max_maxscore_query.hpp @@ -1,7 +1,9 @@ #pragma once #include + #include "query/queries.hpp" +#include "topk_queue.hpp" namespace pisa { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 40bf2b4cb..6fd510dbf 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -19,10 +19,4 @@ foreach(TEST_SRC ${TEST_SOURCES}) add_coverage(${TEST_SRC_NAME}) endforeach(TEST_SRC) -target_link_libraries(test_v1 - pisa - Catch2 - optional -) - file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/test_data DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/test/test_bmw_queries.cpp b/test/test_bmw_queries.cpp index 4623d996c..cf7b283be 100644 --- a/test/test_bmw_queries.cpp +++ b/test/test_bmw_queries.cpp @@ -9,7 +9,21 @@ #include "cursor/max_scored_cursor.hpp" #include "index_types.hpp" #include "pisa_config.hpp" +#include "query/algorithm/and_query.hpp" +#include "query/algorithm/block_max_maxscore_query.hpp" +#include "query/algorithm/block_max_ranked_and_query.hpp" +#include "query/algorithm/block_max_wand_query.hpp" +#include "query/algorithm/maxscore_query.hpp" +#include "query/algorithm/or_query.hpp" +#include "query/algorithm/ranked_and_query.hpp" +#include "query/algorithm/ranked_or_query.hpp" +#include 
"query/algorithm/ranked_or_taat_query.hpp" +#include "query/algorithm/wand_query.hpp" #include "query/queries.hpp" +#include "topk_queue.hpp" +#include "wand_data.hpp" +#include "wand_data_compressed.hpp" +#include "wand_data_range.hpp" using namespace pisa; @@ -21,8 +35,7 @@ struct IndexData { static std::unordered_map> data; - - IndexData(std::string const &scorer_name, std::unordered_set const &dropped_term_ids) + IndexData(std::string const& scorer_name, std::unordered_set const& dropped_term_ids) : collection(PISA_SOURCE_DIR "/test/test_data/test_collection"), document_sizes(PISA_SOURCE_DIR "/test/test_data/test_collection.sizes"), wdata(document_sizes.begin()->begin(), @@ -34,7 +47,7 @@ struct IndexData { { typename Index::builder builder(collection.num_docs(), params); - for (auto const &plist : collection) { + for (auto const& plist : collection) { uint64_t freqs_sum = std::accumulate(plist.freqs.begin(), plist.freqs.end(), uint64_t(0)); builder.add_posting_list( @@ -43,7 +56,7 @@ struct IndexData { builder.build(index); term_id_vec q; std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); - auto push_query = [&](std::string const &query_line) { + auto push_query = [&](std::string const& query_line) { queries.push_back(parse_query_ids(query_line)); }; io::for_each_line(qfile, push_query); @@ -56,7 +69,8 @@ struct IndexData { std::vector queries; WandTypePlain wdata; - [[nodiscard]] static auto get(std::string const &s_name, std::unordered_set const &dropped_term_ids) + [[nodiscard]] static auto get(std::string const& s_name, + std::unordered_set const& dropped_term_ids) { if (IndexData::data.find(s_name) == IndexData::data.end()) { IndexData::data[s_name] = std::make_unique>(s_name, dropped_term_ids); @@ -69,7 +83,7 @@ template std::unordered_map>> IndexData::data = {}; template -auto test(Wand &wdata, std::string const &s_name) +auto test(Wand& wdata, std::string const& s_name) { std::unordered_set dropped_term_ids; auto data = 
IndexData::get(s_name, dropped_term_ids); @@ -77,8 +91,9 @@ auto test(Wand &wdata, std::string const &s_name) wand_query wand_q(10); auto scorer = scorer::from_name(s_name, data->wdata); - for (auto const &q : data->queries) { - wand_q(make_max_scored_cursors(data->index, data->wdata, *scorer, q), data->index.num_docs()); + for (auto const& q : data->queries) { + wand_q(make_max_scored_cursors(data->index, data->wdata, *scorer, q), + data->index.num_docs()); op_q(make_block_max_scored_cursors(data->index, wdata, *scorer, q), data->index.num_docs()); REQUIRE(wand_q.topk().size() == op_q.topk().size()); @@ -92,7 +107,7 @@ auto test(Wand &wdata, std::string const &s_name) TEST_CASE("block_max_wand", "[bmw][query][ranked][integration]", ) { - for (auto &&s_name : {"bm25", "qld"}) { + for (auto&& s_name : {"bm25", "qld"}) { std::unordered_set dropped_term_ids; auto data = IndexData::get(s_name, dropped_term_ids); diff --git a/test/test_forward_index_builder.cpp b/test/test_forward_index_builder.cpp index 26f3ddbc8..cbd5ec1b1 100644 --- a/test/test_forward_index_builder.cpp +++ b/test/test_forward_index_builder.cpp @@ -3,13 +3,14 @@ #include #include -#include #include #include #include +#include #include #include +#include "binary_collection.hpp" #include "filesystem.hpp" #include "forward_index_builder.hpp" #include "parsing/html.hpp" @@ -58,7 +59,7 @@ TEST_CASE("Write header", "[parsing][forward_index]") } } -[[nodiscard]] std::vector load_lines(std::istream &is) +[[nodiscard]] std::vector load_lines(std::istream& is) { std::string line; std::vector vec; @@ -68,22 +69,22 @@ TEST_CASE("Write header", "[parsing][forward_index]") return vec; } -[[nodiscard]] std::vector load_lines(std::string const &filename) +[[nodiscard]] std::vector load_lines(std::string const& filename) { std::ifstream is(filename); return load_lines(is); } template -void write_lines(std::ostream &os, gsl::span &&elements) +void write_lines(std::ostream& os, gsl::span&& elements) { - for (auto const 
&element : elements) { + for (auto const& element : elements) { os << element << '\n'; } } template -void write_lines(std::string const &filename, gsl::span &&elements) +void write_lines(std::string const& filename, gsl::span&& elements) { std::ofstream os(filename); write_lines(os, std::forward>(elements)); @@ -91,7 +92,7 @@ void write_lines(std::string const &filename, gsl::span &&elements) TEST_CASE("Build forward index batch", "[parsing][forward_index]") { - auto identity = [](std::string const &term) -> std::string { return term; }; + auto identity = [](std::string const& term) -> std::string { return term; }; GIVEN("a few test records") { @@ -149,10 +150,10 @@ TEST_CASE("Build forward index batch", "[parsing][forward_index]") } } -void write_batch(std::string const &basename, - std::vector const &documents, - std::vector const &terms, - std::vector> const &collection) +void write_batch(std::string const& basename, + std::vector const& documents, + std::vector const& terms, + std::vector> const& collection) { std::string document_file = basename + ".documents"; std::string term_file = basename + ".terms"; @@ -160,7 +161,7 @@ void write_batch(std::string const &basename, write_lines(term_file, gsl::make_span(terms)); std::ofstream os(basename); Forward_Index_Builder::write_header(os, collection.size()); - for (auto const &seq : collection) { + for (auto const& seq : collection) { Forward_Index_Builder::write_document(os, seq.begin(), seq.end()); } } @@ -271,7 +272,7 @@ TEST_CASE("Merge forward index batches", "[parsing][forward_index]") TEST_CASE("Parse HTML content", "[parsing][forward_index][unit]") { std::vector vec; - auto map_word = [&](std::string &&word) { vec.push_back(word); }; + auto map_word = [&](std::string&& word) { vec.push_back(word); }; SECTION("empty") { parse_html_content( @@ -301,7 +302,7 @@ TEST_CASE("Parse HTML content", "[parsing][forward_index][unit]") } } -[[nodiscard]] auto load_term_map(std::string const &basename) -> std::vector 
+[[nodiscard]] auto load_term_map(std::string const& basename) -> std::vector { std::vector map; std::ifstream is(basename + ".terms"); @@ -315,7 +316,7 @@ TEST_CASE("Parse HTML content", "[parsing][forward_index][unit]") TEST_CASE("Build forward index", "[parsing][forward_index][integration]") { tbb::task_scheduler_init init; - auto next_record = [](std::istream &in) -> std::optional { + auto next_record = [](std::istream& in) -> std::optional { Plaintext_Record record; if (in >> record) { return Document_Record(record.trecid(), record.content(), record.url()); @@ -341,7 +342,7 @@ TEST_CASE("Build forward index", "[parsing][forward_index][integration]") is, output, next_record, - [](std::string &&term) -> std::string { return std::forward(term); }, + [](std::string&& term) -> std::string { return std::forward(term); }, parse_plaintext_content, batch_size, thread_count); @@ -349,8 +350,8 @@ TEST_CASE("Build forward index", "[parsing][forward_index][integration]") THEN("The collection mapped to terms matches input") { auto term_map = load_term_map(output); - auto term_lexicon_buffer = Payload_Vector_Buffer::from_file(output + ".termlex"); - auto term_lexicon = Payload_Vector(term_lexicon_buffer); + mio::mmap_source m((output + ".termlex").c_str()); + auto term_lexicon = Payload_Vector<>::from(m); REQUIRE(std::vector(term_lexicon.begin(), term_lexicon.end()) == term_map); binary_collection coll((output).c_str()); @@ -373,7 +374,7 @@ TEST_CASE("Build forward index", "[parsing][forward_index][integration]") REQUIRE(produced_body == original_body); ++seq_iter; } - auto batch_files = ls(dir, [](auto const &filename) { + auto batch_files = ls(dir, [](auto const& filename) { return filename.find("batch") != std::string::npos; }); REQUIRE(batch_files.empty()); @@ -381,8 +382,8 @@ TEST_CASE("Build forward index", "[parsing][forward_index][integration]") AND_THEN("Document lexicon contains the same titles as text file") { auto documents = io::read_string_vector(output + 
".documents"); - auto doc_lexicon_buffer = Payload_Vector_Buffer::from_file(output + ".doclex"); - auto doc_lexicon = Payload_Vector(doc_lexicon_buffer); + mio::mmap_source m((output + ".doclex").c_str()); + auto doc_lexicon = Payload_Vector<>::from(m); REQUIRE(std::vector(doc_lexicon.begin(), doc_lexicon.end()) == documents); } diff --git a/test/test_ranked_queries.cpp b/test/test_ranked_queries.cpp index 331536e8d..e247c4ffd 100644 --- a/test/test_ranked_queries.cpp +++ b/test/test_ranked_queries.cpp @@ -13,6 +13,17 @@ #include "cursor/scored_cursor.hpp" #include "index_types.hpp" #include "pisa_config.hpp" +#include "query/algorithm/and_query.hpp" +#include "query/algorithm/block_max_maxscore_query.hpp" +#include "query/algorithm/block_max_ranked_and_query.hpp" +#include "query/algorithm/block_max_wand_query.hpp" +#include "query/algorithm/maxscore_query.hpp" +#include "query/algorithm/or_query.hpp" +#include "query/algorithm/range_query.hpp" +#include "query/algorithm/ranked_and_query.hpp" +#include "query/algorithm/ranked_or_query.hpp" +#include "query/algorithm/ranked_or_taat_query.hpp" +#include "query/algorithm/wand_query.hpp" #include "query/queries.hpp" using namespace pisa; @@ -22,7 +33,7 @@ struct IndexData { static std::unordered_map> data; - IndexData(std::string const &scorer_name, std::unordered_set const &dropped_term_ids) + IndexData(std::string const& scorer_name, std::unordered_set const& dropped_term_ids) : collection(PISA_SOURCE_DIR "/test/test_data/test_collection"), document_sizes(PISA_SOURCE_DIR "/test/test_data/test_collection.sizes"), wdata(document_sizes.begin()->begin(), @@ -35,7 +46,7 @@ struct IndexData { { tbb::task_scheduler_init init; typename Index::builder builder(collection.num_docs(), params); - for (auto const &plist : collection) { + for (auto const& plist : collection) { uint64_t freqs_sum = std::accumulate(plist.freqs.begin(), plist.freqs.end(), uint64_t(0)); builder.add_posting_list( @@ -45,7 +56,7 @@ struct IndexData { 
term_id_vec q; std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); - auto push_query = [&](std::string const &query_line) { + auto push_query = [&](std::string const& query_line) { queries.push_back(parse_query_ids(query_line)); }; io::for_each_line(qfile, push_query); @@ -53,7 +64,8 @@ struct IndexData { std::string t; } - [[nodiscard]] static auto get(std::string const &s_name, std::unordered_set const &dropped_term_ids) + [[nodiscard]] static auto get(std::string const& s_name, + std::unordered_set const& dropped_term_ids) { if (IndexData::data.find(s_name) == IndexData::data.end()) { IndexData::data[s_name] = std::make_unique>(s_name, dropped_term_ids); @@ -78,7 +90,7 @@ class ranked_or_taat_query_acc : public ranked_or_taat_query { using ranked_or_taat_query::ranked_or_taat_query; template - uint64_t operator()(CursorRange &&cursors, uint64_t max_docid) + uint64_t operator()(CursorRange&& cursors, uint64_t max_docid) { Acc accumulator(max_docid); return ranked_or_taat_query::operator()(cursors, max_docid, accumulator); @@ -91,7 +103,7 @@ class range_query_128 : public range_query { using range_query::range_query; template - uint64_t operator()(CursorRange &&cursors, uint64_t max_docid) + uint64_t operator()(CursorRange&& cursors, uint64_t max_docid) { return range_query::operator()(cursors, max_docid, 128); } @@ -112,14 +124,14 @@ TEMPLATE_TEST_CASE("Ranked query test", range_query_128, range_query_128) { - for (auto &&s_name : {"bm25", "qld"}) { + for (auto&& s_name : {"bm25", "qld"}) { std::unordered_set dropped_term_ids; auto data = IndexData::get(s_name, dropped_term_ids); TestType op_q(10); ranked_or_query or_q(10); auto scorer = scorer::from_name(s_name, data->wdata); - for (auto const &q : data->queries) { + for (auto const& q : data->queries) { or_q(make_scored_cursors(data->index, *scorer, q), data->index.num_docs()); op_q(make_block_max_scored_cursors(data->index, data->wdata, *scorer, q), data->index.num_docs()); @@ -136,7 +148,7 @@ 
TEMPLATE_TEST_CASE("Ranked AND query test", "[query][ranked][integration]", block_max_ranked_and_query) { - for (auto &&s_name : {"bm25", "qld"}) { + for (auto&& s_name : {"bm25", "qld"}) { std::unordered_set dropped_term_ids; auto data = IndexData::get(s_name, dropped_term_ids); TestType op_q(10); @@ -144,7 +156,7 @@ TEMPLATE_TEST_CASE("Ranked AND query test", auto scorer = scorer::from_name(s_name, data->wdata); - for (auto const &q : data->queries) { + for (auto const& q : data->queries) { and_q(make_scored_cursors(data->index, *scorer, q), data->index.num_docs()); op_q(make_block_max_scored_cursors(data->index, data->wdata, *scorer, q), data->index.num_docs()); @@ -159,7 +171,7 @@ TEMPLATE_TEST_CASE("Ranked AND query test", TEST_CASE("Top k") { - for (auto &&s_name : {"bm25", "qld"}) { + for (auto&& s_name : {"bm25", "qld"}) { std::unordered_set dropped_term_ids; auto data = IndexData::get(s_name, dropped_term_ids); ranked_or_query or_10(10); @@ -167,7 +179,7 @@ TEST_CASE("Top k") auto scorer = scorer::from_name(s_name, data->wdata); - for (auto const &q : data->queries) { + for (auto const& q : data->queries) { or_10(make_scored_cursors(data->index, *scorer, q), data->index.num_docs()); or_1(make_scored_cursors(data->index, *scorer, q), data->index.num_docs()); if (not or_10.topk().empty()) { diff --git a/test/test_wand_data.cpp b/test/test_wand_data.cpp index 734f38db5..99f2ffc45 100644 --- a/test/test_wand_data.cpp +++ b/test/test_wand_data.cpp @@ -11,7 +11,10 @@ #include "index_types.hpp" #include "pisa_config.hpp" #include "query/queries.hpp" +#include "wand_data.hpp" +#include "wand_data_compressed.hpp" #include "wand_data_range.hpp" +#include "wand_data_raw.hpp" #include "scorer/scorer.hpp" @@ -40,12 +43,12 @@ TEST_CASE("wand_data_range") SECTION("Precomputed block-max scores") { size_t term_id = 0; - for (auto const &seq : collection) { + for (auto const& seq : collection) { if (seq.docs.size() >= 1024) { auto max = wdata_range.max_term_weight(term_id); 
auto w = wdata_range.getenum(term_id); auto s = scorer->term_scorer(term_id); - for (auto &&[docid, freq] : ranges::views::zip(seq.docs, seq.freqs)) { + for (auto&& [docid, freq] : ranges::views::zip(seq.docs, seq.freqs)) { float score = s(docid, freq); w.next_geq(docid); CHECKED_ELSE(w.score() >= score) @@ -63,7 +66,7 @@ TEST_CASE("wand_data_range") index_type index; global_parameters params; index_type::builder builder(collection.num_docs(), params); - for (auto const &plist : collection) { + for (auto const& plist : collection) { uint64_t freqs_sum = std::accumulate(plist.freqs.begin(), plist.freqs.end(), uint64_t(0)); builder.add_posting_list( plist.docs.size(), plist.docs.begin(), plist.freqs.begin(), freqs_sum); @@ -73,15 +76,15 @@ TEST_CASE("wand_data_range") SECTION("Compute at run time") { size_t term_id = 0; - for (auto const &seq : collection) { + for (auto const& seq : collection) { auto list = index[term_id]; if (seq.docs.size() < 1024) { auto max = wdata_range.max_term_weight(term_id); - auto &w = wdata_range.get_block_wand(); + auto& w = wdata_range.get_block_wand(); auto s = scorer->term_scorer(term_id); const mapper::mappable_vector bm = w.compute_block_max_scores(list, s); WandTypeRange::enumerator we(0, bm); - for (auto &&[pos, docid, freq] : + for (auto&& [pos, docid, freq] : ranges::views::zip(ranges::views::iota(0), seq.docs, seq.freqs)) { float score = s(docid, freq); we.next_geq(docid); @@ -101,7 +104,7 @@ TEST_CASE("wand_data_range") { size_t i = 0; std::vector enums; - for (auto const &seq : collection) { + for (auto const& seq : collection) { if (seq.docs.size() >= 1024) { enums.push_back(wdata_range.getenum(i)); } @@ -114,7 +117,7 @@ TEST_CASE("wand_data_range") for (int i = 0; i < live_blocks.size(); ++i) { if (live_blocks[i] == 0) { - for (auto &&e : enums) { + for (auto&& e : enums) { e.next_geq(i * 64); REQUIRE(e.score() == 0); } diff --git a/test/v1/test_v1_queries.cpp b/test/v1/test_v1_queries.cpp index daf11c708..841d67ac0 100644 
--- a/test/v1/test_v1_queries.cpp +++ b/test/v1/test_v1_queries.cpp @@ -19,6 +19,7 @@ #include "pisa_config.hpp" #include "query/algorithm/ranked_or_query.hpp" #include "query/queries.hpp" +#include "scorer/bm25.hpp" #include "v1/blocked_cursor.hpp" #include "v1/cursor/collect.hpp" #include "v1/cursor_intersection.hpp" @@ -60,7 +61,9 @@ struct IndexData { wdata(document_sizes.begin()->begin(), collection.num_docs(), collection, - BlockSize(FixedBlock())) + "bm25", + BlockSize(FixedBlock()), + {}) { typename v0_Index::builder builder(collection.num_docs(), params); @@ -186,8 +189,9 @@ TEMPLATE_TEST_CASE("Query", CAPTURE(idx); CAPTURE(intersections[idx]); - or_q(make_scored_cursors( - data->v0_index, data->wdata, ::pisa::Query{{}, query.get_term_ids(), {}}), + or_q(make_scored_cursors(data->v0_index, + ::pisa::bm25>(data->wdata), + ::pisa::Query{{}, query.get_term_ids(), {}}), data->v0_index.num_docs()); auto expected = or_q.topk(); if (with_threshold) { From 6a457f611b9fa4decbcb904816894b8d980e1c5a Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Thu, 19 Dec 2019 22:51:47 +0000 Subject: [PATCH 37/56] Intersections with JSON --- CMakeLists.txt | 2 +- include/pisa/v1/default_index_runner.hpp | 4 + include/pisa/v1/query.hpp | 1 + script/cw09b.sh | 109 +++++++++---------- src/v1/query.cpp | 47 ++++---- v1/CMakeLists.txt | 3 + v1/app.hpp | 4 +- v1/intersection.cpp | 131 +++++++++++++++++++++++ v1/query.cpp | 32 ------ v1/threshold.cpp | 37 +++++-- 10 files changed, 251 insertions(+), 119 deletions(-) create mode 100644 v1/intersection.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index c3cc78cda..f0642da26 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -111,7 +111,7 @@ add_subdirectory(src) if (PISA_ENABLE_TESTING AND BUILD_TESTING) enable_testing() - #add_subdirectory(test) + add_subdirectory(test) add_subdirectory(test/v1) endif() diff --git a/include/pisa/v1/default_index_runner.hpp b/include/pisa/v1/default_index_runner.hpp index eaefaf027..be24db39b 
100644 --- a/include/pisa/v1/default_index_runner.hpp +++ b/include/pisa/v1/default_index_runner.hpp @@ -1,6 +1,9 @@ #pragma once +#include "index_types.hpp" +#include "v1/blocked_cursor.hpp" #include "v1/index.hpp" +#include "v1/raw_cursor.hpp" namespace pisa::v1 { @@ -9,4 +12,5 @@ using DefaultIndexRunner = IndexRunner{}, RawReader{}, BlockedReader<::pisa::simdbp_block, true>{}, BlockedReader<::pisa::simdbp_block, false>{}>; + } diff --git a/include/pisa/v1/query.hpp b/include/pisa/v1/query.hpp index 8b2add33d..6d0665a16 100644 --- a/include/pisa/v1/query.hpp +++ b/include/pisa/v1/query.hpp @@ -99,6 +99,7 @@ struct Query { void add_selections(gsl::span const> selections); + [[nodiscard]] auto filtered_terms(std::bitset<64> selection) const -> std::vector; [[nodiscard]] auto to_json() const -> nlohmann::json; [[nodiscard]] static auto from_json(std::string_view) -> Query; [[nodiscard]] static auto from_plain(std::string_view) -> Query; diff --git a/script/cw09b.sh b/script/cw09b.sh index 753140475..80db3766c 100644 --- a/script/cw09b.sh +++ b/script/cw09b.sh @@ -12,6 +12,7 @@ QUERIES="/home/michal/topics.web.51-200.jl" K=1000 OUTPUT_DIR="/data/michal/intersect/cw09b" FILTERED_QUERIES="${OUTPUT_DIR}/topics.web.51-200.filtered" +FILTERED_QUERIES="${OUTPUT_DIR}/topics.web.51-200.filtered" #THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.web.51-200" THRESHOLDS="${OUTPUT_DIR}/thresholds" @@ -31,88 +32,84 @@ set -x #${PISA_BIN}/filter-queries -i ${BASENAME}.yml -q ${QUERIES} | grep -v "\[warning\]" > ${FILTERED_QUERIES} # Extract thresholds (TODO: estimates) -#${PISA_BIN}/threshold -i ${BASENAME}.yml -q ${FILTERED_QUERIES} -k ${K} \ -# | grep -v "\[warning\]" > ${THRESHOLDS} -#cut -d: -f1 ${FILTERED_QUERIES} | paste - ${THRESHOLDS} > ${OUTPUT_DIR}/thresholds.tsv +#${PISA_BIN}/threshold -i ${BASENAME}.yml -q ${FILTERED_QUERIES} -k ${K} --in-place + +#####jq '.id + "\t" + (.threshold | tostring)' ${FILTERED_QUERIES} -r > 
${OUTPUT_DIR}/thresholds.tsv # Extract intersections -${PISA_BIN}/compute_intersection -t ${TYPE} -i ${INV}.${TYPE} \ - -w ${INV}.wand -q <(jq '.id + ":" + .query' ${FILTERED_QUERIES} -r) \ - --combinations --terms "${FWD}.termlex" --stemmer porter2 \ - | grep -v "\[warning\]" \ - > ${OUTPUT_DIR}/intersections.tsv +#${PISA_BIN}/intersection -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations \ +# | grep -v "\[warning\]" \ +# > ${OUTPUT_DIR}/intersections.jl # Select unigrams -${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ - --terse --max 1 > ${OUTPUT_DIR}/selections.1 -${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ - --terse --max 2 > ${OUTPUT_DIR}/selections.2 -${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ - --terse --max 2 --scale 1.5 > ${OUTPUT_DIR}/selections.2.scaled-1.5 -${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ - --terse --max 2 --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 -${INTERSECT_BIN} -t ${OUTPUT_DIR}/thresholds.tsv -m graph-greedy ${OUTPUT_DIR}/intersections.tsv \ - --terse --max 2 --scale-by-query-len > ${OUTPUT_DIR}/selections.2.scaled-smart +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl --max 1 > ${OUTPUT_DIR}/selections.1 +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl --max 2 > ${OUTPUT_DIR}/selections.2 +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl \ + --max 2 --scale 1.5 > ${OUTPUT_DIR}/selections.2.scaled-1.5 +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl \ + --max 2 --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl \ + --max 2 --scale-by-query-len > ${OUTPUT_DIR}/selections.2.scaled-smart # Run benchmarks -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm 
maxscore > ${OUTPUT_DIR}/bench.maxscore +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm maxscore > ${OUTPUT_DIR}/bench.maxscore ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore \ - --thresholds ${THRESHOLDS} > ${OUTPUT_DIR}/bench.maxscore-threshold + > ${OUTPUT_DIR}/bench.maxscore-threshold ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup \ - --thresholds ${THRESHOLDS} > ${OUTPUT_DIR}/bench.maxscore-union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm unigram-union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.1 > ${OUTPUT_DIR}/bench.unigram-union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/bench.union-lookup.2 -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm lookup-union \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/bench.lookup-union + > ${OUTPUT_DIR}/bench.maxscore-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup \ + > ${OUTPUT_DIR}/bench.unigram-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup \ + > ${OUTPUT_DIR}/bench.union-lookup.2 +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm lookup-union \ + > ${OUTPUT_DIR}/bench.lookup-union #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-1.5 > ${OUTPUT_DIR}/bench.union-lookup.scaled-1.5 +# --intersections 
${OUTPUT_DIR}/selections.2.scaled-1.5 > ${OUTPUT_DIR}/bench.union-lookup.scaled-1.5 #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-2 > ${OUTPUT_DIR}/bench.union-lookup.scaled-2 +# --intersections ${OUTPUT_DIR}/selections.2.scaled-2 > ${OUTPUT_DIR}/bench.union-lookup.scaled-2 #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-3 > ${OUTPUT_DIR}/bench.union-lookup.scaled-3 +# --intersections ${OUTPUT_DIR}/selections.2.scaled-3 > ${OUTPUT_DIR}/bench.union-lookup.scaled-3 #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-smart > ${OUTPUT_DIR}/bench.union-lookup.scaled-smart +# --intersections ${OUTPUT_DIR}/selections.2.scaled-smart > ${OUTPUT_DIR}/bench.union-lookup.scaled-smart # Analyze -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore +#${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --analyze --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm maxscore \ -# --thresholds ${THRESHOLDS} > ${OUTPUT_DIR}/stats.maxscore-thresholds +# > ${OUTPUT_DIR}/stats.maxscore-thresholds #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm maxscore-union-lookup \ -# --thresholds ${THRESHOLDS} > ${OUTPUT_DIR}/stats.maxscore-union-lookup +# > ${OUTPUT_DIR}/stats.maxscore-union-lookup #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm unigram-union-lookup \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.1 > 
${OUTPUT_DIR}/stats.unigram-union-lookup +# --intersections ${OUTPUT_DIR}/selections.1 > ${OUTPUT_DIR}/stats.unigram-union-lookup #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/stats.union-lookup.2 +# --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/stats.union-lookup.2 #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm lookup-union \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/stats.lookup-union +# --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/stats.lookup-union #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-1.5 > ${OUTPUT_DIR}/stats.union-lookup.scaled-1.5 +# --intersections ${OUTPUT_DIR}/selections.2.scaled-1.5 > ${OUTPUT_DIR}/stats.union-lookup.scaled-1.5 #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-2 > ${OUTPUT_DIR}/stats.union-lookup.scaled-2 +# --intersections ${OUTPUT_DIR}/selections.2.scaled-2 > ${OUTPUT_DIR}/stats.union-lookup.scaled-2 #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-3 > ${OUTPUT_DIR}/stats.union-lookup.scaled-3 +# --intersections ${OUTPUT_DIR}/selections.2.scaled-3 > ${OUTPUT_DIR}/stats.union-lookup.scaled-3 #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-smart > ${OUTPUT_DIR}/stats.union-lookup.scaled-smart +# --intersections ${OUTPUT_DIR}/selections.2.scaled-smart > 
${OUTPUT_DIR}/stats.union-lookup.scaled-smart ## Evaluate -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore \ - --thresholds ${THRESHOLDS} > "${OUTPUT_DIR}/eval.maxscore-threshold" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore-union-lookup \ - --thresholds ${THRESHOLDS} > "${OUTPUT_DIR}/eval.maxscore-union-lookup" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm unigram-union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.1 > "${OUTPUT_DIR}/eval.unigram-union-lookup" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > "${OUTPUT_DIR}/eval.union-lookup.2" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm lookup-union \ - --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2 > "${OUTPUT_DIR}/eval.lookup-union" +#${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore \ +# > "${OUTPUT_DIR}/eval.maxscore-threshold" +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore-union-lookup \ +# > "${OUTPUT_DIR}/eval.maxscore-union-lookup" +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm unigram-union-lookup \ +# --intersections ${OUTPUT_DIR}/selections.1 > "${OUTPUT_DIR}/eval.unigram-union-lookup" +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ +# --intersections ${OUTPUT_DIR}/selections.2 > "${OUTPUT_DIR}/eval.union-lookup.2" +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm lookup-union \ +# 
--intersections ${OUTPUT_DIR}/selections.2 > "${OUTPUT_DIR}/eval.lookup-union" #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-1.5 > "${OUTPUT_DIR}/eval.union-lookup.scale-1.5" +# --intersections ${OUTPUT_DIR}/selections.2.scaled-1.5 > "${OUTPUT_DIR}/eval.union-lookup.scale-1.5" #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-2 > "${OUTPUT_DIR}/eval.union-lookup.scale-2" +# --intersections ${OUTPUT_DIR}/selections.2.scaled-2 > "${OUTPUT_DIR}/eval.union-lookup.scale-2" #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-3 > "${OUTPUT_DIR}/eval.union-lookup.scale-3" +# --intersections ${OUTPUT_DIR}/selections.2.scaled-3 > "${OUTPUT_DIR}/eval.union-lookup.scale-3" #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ -# --thresholds ${THRESHOLDS} --intersections ${OUTPUT_DIR}/selections.2.scaled-smart > "${OUTPUT_DIR}/eval.union-lookup.scale-smart" +# --intersections ${OUTPUT_DIR}/selections.2.scaled-smart > "${OUTPUT_DIR}/eval.union-lookup.scale-smart" diff --git a/src/v1/query.cpp b/src/v1/query.cpp index bc22c717a..00e601804 100644 --- a/src/v1/query.cpp +++ b/src/v1/query.cpp @@ -41,6 +41,19 @@ void Query::add_selections(gsl::span const> selections) ranges::sort(m_selections->bigrams); } +[[nodiscard]] auto Query::filtered_terms(std::bitset<64> selection) const -> std::vector +{ + auto&& term_ids = get_term_ids(); + std::vector terms; + std::vector weights; + for (std::size_t bitpos = 0; bitpos < term_ids.size(); ++bitpos) { + if (((1U << bitpos) & selection.to_ulong()) > 0) { + terms.push_back(term_ids.at(bitpos)); + } + } + return terms; +} + auto 
Query::resolve_term(std::size_t pos) -> TermId { if (not m_term_ids) { @@ -96,27 +109,12 @@ template if (auto k = get(query_json, "k"); k) { query.k(*k); } - if (auto pos = query_json.find("selections"); pos != query_json.end()) { - auto const& selections = *pos; - auto unigrams = get>(selections, "unigrams") - .value_or(std::vector{}); - auto bigrams = - get>>(selections, "bigrams") - .value_or(std::vector>{}); + if (auto selections = get>(query_json, "selections"); selections) { std::vector> bitsets; - std::transform( - unigrams.begin(), unigrams.end(), std::back_inserter(bitsets), [](auto idx) { - std::bitset<64> bs; - bs.set(idx); - return bs; - }); - std::transform( - bigrams.begin(), bigrams.end(), std::back_inserter(bitsets), [](auto bigram) { - std::bitset<64> bs; - bs.set(bigram.first); - bs.set(bigram.second); - return bs; - }); + std::transform(selections->begin(), + selections->end(), + std::back_inserter(bitsets), + [](auto selection) { return std::bitset<64>(selection); }); query.selections(gsl::span>(bitsets)); } return query; @@ -134,11 +132,14 @@ template if (m_raw_string) { query["query"] = *m_raw_string; } + if (m_term_ids) { + query["term_ids"] = m_term_ids->get(); + } + if (m_threshold) { + query["threshold"] = *m_threshold; + } // TODO(michal) - // tl::optional m_term_ids{}; // tl::optional m_selections{}; - // tl::optional m_threshold{}; - // tl::optional m_raw_string; // int m_k = 1000; return query; } diff --git a/v1/CMakeLists.txt b/v1/CMakeLists.txt index f16294273..f9e02ae03 100644 --- a/v1/CMakeLists.txt +++ b/v1/CMakeLists.txt @@ -18,3 +18,6 @@ target_link_libraries(filter-queries pisa CLI11) add_executable(threshold threshold.cpp) target_link_libraries(threshold pisa CLI11) + +add_executable(intersection intersection.cpp) +target_link_libraries(intersection pisa CLI11) diff --git a/v1/app.hpp b/v1/app.hpp index e2987ca3b..6a3d251a1 100644 --- a/v1/app.hpp +++ b/v1/app.hpp @@ -65,7 +65,9 @@ namespace arg { } return 
v1::Query::from_plain(line); }(); - query.parse(parser); + if (not query.term_ids()) { + query.parse(parser); + } if constexpr (Mode == QueryMode::Ranked) { query.k(m_k); } diff --git a/v1/intersection.cpp b/v1/intersection.cpp new file mode 100644 index 000000000..fa662b8af --- /dev/null +++ b/v1/intersection.cpp @@ -0,0 +1,131 @@ +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "app.hpp" +#include "intersection.hpp" +#include "query/queries.hpp" +#include "v1/blocked_cursor.hpp" +#include "v1/cursor/for_each.hpp" +#include "v1/cursor_intersection.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/scorer/bm25.hpp" + +using pisa::App; +using pisa::Intersection; +using pisa::intersection::IntersectionType; +using pisa::intersection::Mask; +using pisa::v1::BlockedReader; +using pisa::v1::intersect; +using pisa::v1::make_bm25; +using pisa::v1::RawReader; + +namespace arg = pisa::arg; +namespace v1 = pisa::v1; + +template +auto compute_intersection(Index const& index, + pisa::v1::Query const& query, + tl::optional> term_selection) -> Intersection +{ + auto const term_ids = + term_selection ? query.filtered_terms(*term_selection) : query.get_term_ids(); + auto cursors = index.scored_cursors(gsl::make_span(term_ids), make_bm25(index)); + auto intersection = + intersect(cursors, 0.0F, [](auto score, auto&& cursor, [[maybe_unused]] auto idx) { + return score + cursor.payload(); + }); + std::size_t postings = 0; + float max_score = 0.0F; + v1::for_each(intersection, [&](auto& cursor) { + postings += 1; + if (auto score = cursor.payload(); score > max_score) { + max_score = score; + } + }); + return Intersection{postings, max_score}; +} + +/// Do `func` for all intersections in a query that have a given maximum number of terms. +/// `Fn` takes `Query` and `Mask`. 
+template +auto for_all_subsets(v1::Query const& query, tl::optional max_term_count, Fn func) +{ + auto&& term_ids = query.get_term_ids(); + auto subset_count = 1U << term_ids.size(); + for (auto subset = 1U; subset < subset_count; ++subset) { + auto mask = std::bitset<64>(subset); + if (!max_term_count || mask.count() <= *max_term_count) { + func(query, mask); + } + } +} + +template +void compute_intersections(Index const& index, + std::vector const& queries, + IntersectionType intersection_type, + tl::optional max_term_count) +{ + for (auto const& query : queries) { + auto intersections = nlohmann::json::array(); + auto inter = [&](auto&& query, tl::optional> const& mask) { + auto intersection = compute_intersection(index, query, mask); + intersections.push_back(nlohmann::json{{"intersection", mask.value_or(0).to_ulong()}, + {"cost", intersection.length}, + {"max_score", intersection.max_score}}); + }; + if (intersection_type == IntersectionType::Combinations) { + for_all_subsets(query, max_term_count, inter); + } else { + inter(query, tl::nullopt); + } + auto output = query.to_json(); + output["intersections"] = intersections; + std::cout << output << '\n'; + } +} + +int main(int argc, const char** argv) +{ + spdlog::drop(""); + spdlog::set_default_logger(spdlog::stderr_color_mt("")); + + bool combinations = false; + std::optional max_term_count; + + pisa::App> app( + "Calculates intersections for a v1 index."); + auto* combinations_flag = app.add_flag( + "--combinations", combinations, "Compute intersections for combinations of terms in query"); + app.add_option("--max-term-count,--mtc", + max_term_count, + "Max number of terms when computing combinations") + ->needs(combinations_flag); + CLI11_PARSE(app, argc, argv); + auto mtc = max_term_count ? tl::make_optional(*max_term_count) : tl::optional{}; + + IntersectionType intersection_type = + combinations ? 
IntersectionType::Combinations : IntersectionType::Query; + + try { + auto meta = app.index_metadata(); + auto queries = app.queries(meta); + + auto run = index_runner(meta, + // RawReader{}, + BlockedReader<::pisa::simdbp_block, true>{}, + BlockedReader<::pisa::simdbp_block, false>{}); + run([&](auto&& index) { compute_intersections(index, queries, intersection_type, mtc); }); + } catch (std::exception const& error) { + spdlog::error("{}", error.what()); + } + return 0; +} diff --git a/v1/query.cpp b/v1/query.cpp index 9970d031e..1ba9cba39 100644 --- a/v1/query.cpp +++ b/v1/query.cpp @@ -200,14 +200,10 @@ int main(int argc, char** argv) spdlog::set_default_logger(spdlog::stderr_color_mt("")); std::string algorithm = "daat_or"; - tl::optional threshold_file; - tl::optional inter_filename; bool inspect = false; pisa::QueryApp app("Queries a v1 index."); app.add_option("--algorithm", algorithm, "Query retrieval algorithm.", true); - app.add_option("--thresholds", threshold_file, "File with (estimated) thresholds.", false); - app.add_option("--intersections", inter_filename, "Intersections filename"); app.add_flag("--inspect", inspect, "Analyze query execution and stats"); CLI11_PARSE(app, argc, argv); @@ -222,34 +218,6 @@ int main(int argc, char** argv) auto source = std::make_shared(meta.document_lexicon.value().c_str()); auto docmap = pisa::Payload_Vector<>::from(*source); - if (threshold_file) { - std::ifstream is(*threshold_file); - auto queries_iter = queries.begin(); - pisa::io::for_each_line(is, [&](auto&& line) { - if (queries_iter == queries.end()) { - spdlog::error("Number of thresholds not equal to number of queries"); - std::exit(1); - } - queries_iter->threshold(std::stof(line)); - ++queries_iter; - }); - if (queries_iter != queries.end()) { - spdlog::error("Number of thresholds not equal to number of queries"); - std::exit(1); - } - } - - if (inter_filename) { - auto const intersections = pisa::v1::read_intersections(*inter_filename); - if 
(intersections.size() != queries.size()) { - spdlog::error("Number of intersections is not equal to number of queries"); - std::exit(1); - } - for (auto query_idx = 0; query_idx < queries.size(); query_idx += 1) { - queries[query_idx].add_selections(gsl::make_span(intersections[query_idx])); - } - } - if (app.use_quantized()) { auto run = scored_index_runner(meta, RawReader{}, diff --git a/v1/threshold.cpp b/v1/threshold.cpp index 91b702b04..be11d1fc6 100644 --- a/v1/threshold.cpp +++ b/v1/threshold.cpp @@ -32,9 +32,12 @@ using pisa::v1::resolve_yml; using pisa::v1::VoidScorer; template -void calculate_thresholds(Index&& index, Scorer&& scorer, std::vector const& queries) +void calculate_thresholds(Index&& index, + Scorer&& scorer, + std::vector& queries, + std::ostream& os) { - for (auto const& query : queries) { + for (auto&& query : queries) { auto results = pisa::v1::daat_or( query, index, ::pisa::topk_queue(query.k()), std::forward(scorer)); results.finalize(); @@ -42,7 +45,8 @@ void calculate_thresholds(Index&& index, Scorer&& scorer, std::vector con if (not results.topk().empty()) { threshold = results.topk().back().first; } - std::cout << threshold << '\n'; + query.threshold(threshold); + os << query.to_json() << '\n'; } } @@ -50,9 +54,17 @@ int main(int argc, char** argv) { spdlog::drop(""); spdlog::set_default_logger(spdlog::stderr_color_mt("")); + + bool in_place = false; pisa::QueryApp app("Calculates thresholds for a v1 index."); + app.add_flag("--in-place", in_place, "Edit the input file"); CLI11_PARSE(app, argc, argv); + if (in_place && not app.query_file()) { + spdlog::error("Cannot edit in place when no query file passed"); + std::exit(1); + } + try { auto meta = app.index_metadata(); auto queries = app.queries(meta); @@ -63,7 +75,14 @@ int main(int argc, char** argv) RawReader{}, BlockedReader<::pisa::simdbp_block, true>{}, BlockedReader<::pisa::simdbp_block, false>{}); - run([&](auto&& index) { calculate_thresholds(index, VoidScorer{}, queries); 
}); + run([&](auto&& index) { + if (in_place) { + std::ofstream os(app.query_file().value()); + calculate_thresholds(index, VoidScorer{}, queries, os); + } else { + calculate_thresholds(index, VoidScorer{}, queries, std::cout); + } + }); } else { auto run = index_runner(meta, RawReader{}, @@ -71,8 +90,14 @@ int main(int argc, char** argv) BlockedReader<::pisa::simdbp_block, false>{}); run([&](auto&& index) { auto with_scorer = scorer_runner(index, make_bm25(index)); - with_scorer("bm25", - [&](auto scorer) { calculate_thresholds(index, scorer, queries); }); + with_scorer("bm25", [&](auto scorer) { + if (in_place) { + std::ofstream os(app.query_file().value()); + calculate_thresholds(index, scorer, queries, os); + } else { + calculate_thresholds(index, scorer, queries, std::cout); + } + }); }); } } catch (std::exception const& error) { From d428402d2141fd477466e604ae72de2700dd36eb Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 20 Dec 2019 20:15:17 +0000 Subject: [PATCH 38/56] Small fixes --- include/pisa/v1/union_lookup.hpp | 64 ++++++++++++------ script/cw09b.sh | 112 ++++++++++++------------------- v1/filter_queries.cpp | 4 ++ v1/intersection.cpp | 9 ++- 4 files changed, 97 insertions(+), 92 deletions(-) diff --git a/include/pisa/v1/union_lookup.hpp b/include/pisa/v1/union_lookup.hpp index 6580bd8b8..aa18c1bbc 100644 --- a/include/pisa/v1/union_lookup.hpp +++ b/include/pisa/v1/union_lookup.hpp @@ -255,6 +255,10 @@ auto unigram_union_lookup(Query const& query, std::vector essential_cursors = index.max_scored_cursors(selections.unigrams, scorer); + if constexpr (not std::is_void_v) { + inspect->essential(essential_cursors.size()); + } + auto joined = join_union_lookup( std::move(essential_cursors), std::move(lookup_cursors), @@ -304,25 +308,28 @@ auto maxscore_union_lookup(Query const& query, upper_bounds[i] = upper_bounds[i - 1] + cursors[i].max_score(); } std::size_t non_essential_count = 0; - while (non_essential_count < cursors.size() && 
upper_bounds[non_essential_count] < threshold) { + while (non_essential_count < cursors.size() && upper_bounds[non_essential_count] <= threshold) { non_essential_count += 1; } - - auto lookup_cursors = gsl::span(&cursors[0], non_essential_count); - auto essential_cursors = - gsl::span(&cursors[non_essential_count], cursors.size() - non_essential_count); - if (not lookup_cursors.empty()) { - std::reverse(lookup_cursors.begin(), lookup_cursors.end()); + if constexpr (not std::is_void_v) { + inspect->essential(cursors.size() - non_essential_count); } + std::vector essential_cursors; + std::move(std::next(cursors.begin(), non_essential_count), + cursors.end(), + std::back_inserter(essential_cursors)); + cursors.erase(std::next(cursors.begin(), non_essential_count), cursors.end()); + std::reverse(cursors.begin(), cursors.end()); + auto joined = join_union_lookup( std::move(essential_cursors), - std::move(lookup_cursors), + std::move(cursors), payload_type{}, accumulate::Add{}, [&](auto score) { return topk.would_enter(score); }, inspect); - v1::for_each(joined, [&](auto& cursor) { + v1::for_each(joined, [&](auto&& cursor) { if constexpr (not std::is_void_v) { if (topk.insert(cursor.payload(), cursor.value())) { inspect->insert(); @@ -390,12 +397,12 @@ struct LookupTransform { }; /// This algorithm... 
-template +template auto lookup_union(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer, - Inspector* inspector = nullptr) + Inspect* inspect = nullptr) { using bigram_cursor_type = std::decay_t; using lookup_cursor_type = std::decay_t; @@ -413,6 +420,10 @@ auto lookup_union(Query const& query, auto& essential_unigrams = selections.unigrams; auto& essential_bigrams = selections.bigrams; + if constexpr (not std::is_void_v) { + inspect->essential(essential_unigrams.size() + essential_bigrams.size()); + } + auto non_essential_terms = ranges::views::set_difference(term_ids, essential_unigrams) | ranges::to_vector; @@ -428,13 +439,13 @@ auto lookup_union(Query const& query, 0.0F, accumulate::Add{}, is_above_threshold, - inspector); + inspect); }(); using lookup_transform_type = LookupTransform; + Inspect>; using transform_payload_cursor_type = TransformPayloadCursor; @@ -461,7 +472,7 @@ auto lookup_union(Query const& query, lookup_transform_type(std::move(lookup_cursors), lookup_cursors_upper_bound, is_above_threshold, - inspector)); + inspect)); } auto accumulate = [&](float acc, auto& cursor, [[maybe_unused]] auto idx) { @@ -474,9 +485,9 @@ auto lookup_union(Query const& query, std::make_tuple(accumulate, accumulate)); v1::for_each(merged, [&](auto&& cursor) { - if constexpr (not std::is_void_v) { + if constexpr (not std::is_void_v) { if (topk.insert(cursor.payload(), cursor.value())) { - inspector->insert(); + inspect->insert(); } } else { topk.insert(cursor.payload(), cursor.value()); @@ -517,6 +528,10 @@ auto union_lookup(Query const& query, std::array initial_payload{ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; //, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + if constexpr (not std::is_void_v) { + inspect->essential(essential_unigrams.size() + essential_bigrams.size()); + } + std::vector essential_unigram_cursors; std::transform(essential_unigrams.begin(), essential_unigrams.end(), @@ -644,7 +659,7 @@ struct BaseUnionLookupInspect { 
BaseUnionLookupInspect(Index const& index, Scorer scorer) : m_index(index), m_scorer(std::move(scorer)) { - std::cout << fmt::format("documents\tpostings\tinserts\tlookups\n"); + std::cout << fmt::format("documents\tpostings\tinserts\tlookups\tessential_lists\n"); } void reset_current() @@ -653,6 +668,7 @@ struct BaseUnionLookupInspect { m_current_postings = 0; m_current_lookups = 0; m_current_inserts = 0; + m_essential_lists = 0; } virtual void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) = 0; @@ -668,15 +684,17 @@ struct BaseUnionLookupInspect { reset_current(); run(query, m_index, m_scorer, topk_queue(query.k())); - std::cout << fmt::format("{}\t{}\t{}\t{}\n", + std::cout << fmt::format("{}\t{}\t{}\t{}\t{}\n", m_current_documents, m_current_postings, m_current_inserts, - m_current_lookups); + m_current_lookups, + m_current_essential_lists); m_documents += m_current_documents; m_postings += m_current_postings; m_lookups += m_current_lookups; m_inserts += m_current_inserts; + m_essential_lists += m_current_essential_lists; m_count += 1; } @@ -684,29 +702,33 @@ struct BaseUnionLookupInspect { { std::cerr << fmt::format( "=== SUMMARY ===\nAverage:\n- documents:\t{}\n" - "- postings:\t{}\n- inserts:\t{}\n- lookups:\t{}\n", + "- postings:\t{}\n- inserts:\t{}\n- lookups:\t{}\n- essential lists:\t{}\n", static_cast(m_documents) / m_count, static_cast(m_postings) / m_count, static_cast(m_inserts) / m_count, - static_cast(m_lookups) / m_count); + static_cast(m_lookups) / m_count, + static_cast(m_essential_lists) / m_count); } void document() { m_current_documents += 1; } void posting() { m_current_postings += 1; } void lookup() { m_current_lookups += 1; } void insert() { m_current_inserts += 1; } + void essential(std::size_t n) { m_current_essential_lists = n; } private: std::size_t m_current_documents = 0; std::size_t m_current_postings = 0; std::size_t m_current_lookups = 0; std::size_t m_current_inserts = 0; + std::size_t 
m_current_essential_lists = 0; std::size_t m_documents = 0; std::size_t m_postings = 0; std::size_t m_lookups = 0; std::size_t m_inserts = 0; std::size_t m_count = 0; + std::size_t m_essential_lists = 0; Index const& m_index; Scorer m_scorer; }; diff --git a/script/cw09b.sh b/script/cw09b.sh index 80db3766c..444baa445 100644 --- a/script/cw09b.sh +++ b/script/cw09b.sh @@ -7,37 +7,33 @@ BASENAME="/data/michal/work/v1/cw09b/cw09b" THREADS=4 TYPE="block_simdbp" # v0.6 ENCODING="simdbp" # v1 -#QUERIES="/home/michal/biscorer/data/queries/05.efficiency_topics.no_dups.1k" -QUERIES="/home/michal/topics.web.51-200.jl" +QUERIES="/home/michal/biscorer/data/queries/05.efficiency_topics.no_dups.1k" +#QUERIES="/home/michal/topics.web.51-200.jl" K=1000 OUTPUT_DIR="/data/michal/intersect/cw09b" -FILTERED_QUERIES="${OUTPUT_DIR}/topics.web.51-200.filtered" -FILTERED_QUERIES="${OUTPUT_DIR}/topics.web.51-200.filtered" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" #THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.web.51-200" -THRESHOLDS="${OUTPUT_DIR}/thresholds" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.05.efficiency_topics.no_dups.1k" set -e set -x ## Compress an inverted index in `binary_freq_collection` format. -#./bin/compress -i ${BINARY_FREQ_COLL} --fwd ${FWD} -o ${BASENAME} -j ${THREADS} -e ${ENCODING} +#${PISA_BIN}/compress -i ${BINARY_FREQ_COLL} --fwd ${FWD} -o ${BASENAME} -j ${THREADS} -e ${ENCODING} # This will produce both quantized scores and max scores (both quantized and not). -#./bin/score -i "${BASENAME}.yml" -j ${THREADS} - -# This will produce both quantized scores and max scores (both quantized and not). -#./bin/bigram-index -i "${BASENAME}.yml" -q ${QUERIES} +#${PISA_BIN}/score -i "${BASENAME}.yml" -j ${THREADS} # Filter out queries witout existing terms. 
-#${PISA_BIN}/filter-queries -i ${BASENAME}.yml -q ${QUERIES} | grep -v "\[warning\]" > ${FILTERED_QUERIES} - -# Extract thresholds (TODO: estimates) -#${PISA_BIN}/threshold -i ${BASENAME}.yml -q ${FILTERED_QUERIES} -k ${K} --in-place +#paste -d: ${QUERIES} ${THRESHOLDS} \ +# | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ +# | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml > ${FILTERED_QUERIES} -#####jq '.id + "\t" + (.threshold | tostring)' ${FILTERED_QUERIES} -r > ${OUTPUT_DIR}/thresholds.tsv +# This will produce both quantized scores and max scores (both quantized and not). +#${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} # Extract intersections -#${PISA_BIN}/intersection -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations \ +#${PISA_BIN}/intersection -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations --mtc 2 \ # | grep -v "\[warning\]" \ # > ${OUTPUT_DIR}/intersections.jl @@ -46,10 +42,6 @@ ${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl --max 1 > ${OUTP ${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl --max 2 > ${OUTPUT_DIR}/selections.2 ${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl \ --max 2 --scale 1.5 > ${OUTPUT_DIR}/selections.2.scaled-1.5 -${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl \ - --max 2 --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 -${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl \ - --max 2 --scale-by-query-len > ${OUTPUT_DIR}/selections.2.scaled-smart # Run benchmarks ${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm maxscore > ${OUTPUT_DIR}/bench.maxscore @@ -60,56 +52,40 @@ ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algo ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup \ > ${OUTPUT_DIR}/bench.unigram-union-lookup 
${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup \ - > ${OUTPUT_DIR}/bench.union-lookup.2 + > ${OUTPUT_DIR}/bench.union-lookup ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm lookup-union \ > ${OUTPUT_DIR}/bench.lookup-union -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ -# --intersections ${OUTPUT_DIR}/selections.2.scaled-1.5 > ${OUTPUT_DIR}/bench.union-lookup.scaled-1.5 -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ -# --intersections ${OUTPUT_DIR}/selections.2.scaled-2 > ${OUTPUT_DIR}/bench.union-lookup.scaled-2 -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ -# --intersections ${OUTPUT_DIR}/selections.2.scaled-3 > ${OUTPUT_DIR}/bench.union-lookup.scaled-3 -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm union-lookup \ -# --intersections ${OUTPUT_DIR}/selections.2.scaled-smart > ${OUTPUT_DIR}/bench.union-lookup.scaled-smart +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ + --benchmark --algorithm lookup-union \ + > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 # Analyze -#${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --analyze --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm maxscore \ -# > ${OUTPUT_DIR}/stats.maxscore-thresholds -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm maxscore-union-lookup \ -# > ${OUTPUT_DIR}/stats.maxscore-union-lookup -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm unigram-union-lookup \ -# --intersections ${OUTPUT_DIR}/selections.1 > ${OUTPUT_DIR}/stats.unigram-union-lookup -#${PISA_BIN}/query 
-i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ -# --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/stats.union-lookup.2 -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm lookup-union \ -# --intersections ${OUTPUT_DIR}/selections.2 > ${OUTPUT_DIR}/stats.lookup-union -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ -# --intersections ${OUTPUT_DIR}/selections.2.scaled-1.5 > ${OUTPUT_DIR}/stats.union-lookup.scaled-1.5 -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ -# --intersections ${OUTPUT_DIR}/selections.2.scaled-2 > ${OUTPUT_DIR}/stats.union-lookup.scaled-2 -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ -# --intersections ${OUTPUT_DIR}/selections.2.scaled-3 > ${OUTPUT_DIR}/stats.union-lookup.scaled-3 -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --analyze --algorithm union-lookup \ -# --intersections ${OUTPUT_DIR}/selections.2.scaled-smart > ${OUTPUT_DIR}/stats.union-lookup.scaled-smart +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore \ + > ${OUTPUT_DIR}/stats.maxscore-threshold +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore-union-lookup \ + > ${OUTPUT_DIR}/stats.maxscore-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --inspect --algorithm unigram-union-lookup \ + > ${OUTPUT_DIR}/stats.unigram-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup \ + > ${OUTPUT_DIR}/stats.union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm 
lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ + --inspect --algorithm lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union.scaled-1.5 -## Evaluate -#${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore \ -# > "${OUTPUT_DIR}/eval.maxscore-threshold" -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore-union-lookup \ -# > "${OUTPUT_DIR}/eval.maxscore-union-lookup" -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm unigram-union-lookup \ -# --intersections ${OUTPUT_DIR}/selections.1 > "${OUTPUT_DIR}/eval.unigram-union-lookup" -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ -# --intersections ${OUTPUT_DIR}/selections.2 > "${OUTPUT_DIR}/eval.union-lookup.2" -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm lookup-union \ -# --intersections ${OUTPUT_DIR}/selections.2 > "${OUTPUT_DIR}/eval.lookup-union" -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ -# --intersections ${OUTPUT_DIR}/selections.2.scaled-1.5 > "${OUTPUT_DIR}/eval.union-lookup.scale-1.5" -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ -# --intersections ${OUTPUT_DIR}/selections.2.scaled-2 > "${OUTPUT_DIR}/eval.union-lookup.scale-2" -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ -# --intersections ${OUTPUT_DIR}/selections.2.scaled-3 > "${OUTPUT_DIR}/eval.union-lookup.scale-3" -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm union-lookup \ -# --intersections ${OUTPUT_DIR}/selections.2.scaled-smart > "${OUTPUT_DIR}/eval.union-lookup.scale-smart" +# Evaluate +${PISA_BIN}/query 
-i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore \ + > "${OUTPUT_DIR}/eval.maxscore-threshold" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore-union-lookup \ + > "${OUTPUT_DIR}/eval.maxscore-union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --algorithm unigram-union-lookup \ + > "${OUTPUT_DIR}/eval.unigram-union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm union-lookup \ + > "${OUTPUT_DIR}/eval.union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm lookup-union \ + > "${OUTPUT_DIR}/eval.lookup-union" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --algorithm lookup-union \ + > "${OUTPUT_DIR}/eval.lookup-union.scaled-1.5" diff --git a/v1/filter_queries.cpp b/v1/filter_queries.cpp index d2462754b..c6ecfcd19 100644 --- a/v1/filter_queries.cpp +++ b/v1/filter_queries.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include "app.hpp" @@ -39,6 +40,9 @@ namespace arg = pisa::arg; int main(int argc, char** argv) { + spdlog::drop(""); + spdlog::set_default_logger(spdlog::stderr_color_mt("")); + pisa::App> app( "Filters out empty queries against a v1 index."); CLI11_PARSE(app, argc, argv); diff --git a/v1/intersection.cpp b/v1/intersection.cpp index fa662b8af..d9ea60982 100644 --- a/v1/intersection.cpp +++ b/v1/intersection.cpp @@ -78,9 +78,12 @@ void compute_intersections(Index const& index, auto intersections = nlohmann::json::array(); auto inter = [&](auto&& query, tl::optional> const& mask) { auto intersection = compute_intersection(index, query, mask); - intersections.push_back(nlohmann::json{{"intersection", mask.value_or(0).to_ulong()}, - {"cost", intersection.length}, - {"max_score", 
intersection.max_score}}); + if (intersection.length > 0) { + intersections.push_back( + nlohmann::json{{"intersection", mask.value_or(0).to_ulong()}, + {"cost", intersection.length}, + {"max_score", intersection.max_score}}); + } }; if (intersection_type == IntersectionType::Combinations) { for_all_subsets(query, max_term_count, inter); From abd480ea2b15e2fb9ff4c04f529c39176f8df953 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Sat, 28 Dec 2019 15:11:38 +0000 Subject: [PATCH 39/56] Translation units + WAND --- CMakeLists.txt | 30 +- external/CMakeLists.txt | 8 +- include/pisa/dec_time_prediction.hpp | 8 +- include/pisa/mixed_block.hpp | 2 +- include/pisa/topk_queue.hpp | 3 +- include/pisa/v1/blocked_cursor.hpp | 329 ++++++----- include/pisa/v1/cursor/scoring_cursor.hpp | 69 +++ include/pisa/v1/default_index_runner.hpp | 21 +- include/pisa/v1/document_payload_cursor.hpp | 10 +- include/pisa/v1/index.hpp | 592 ++++++++------------ include/pisa/v1/index_builder.hpp | 1 + include/pisa/v1/index_metadata.hpp | 251 ++++++--- include/pisa/v1/io.hpp | 1 - include/pisa/v1/query.hpp | 6 +- include/pisa/v1/raw_cursor.hpp | 14 + include/pisa/v1/score_index.hpp | 47 ++ include/pisa/v1/sequence_cursor.hpp | 326 +++++++++++ include/pisa/v1/types.hpp | 3 +- include/pisa/v1/unaligned_span.hpp | 1 + include/pisa/v1/wand.hpp | 444 +++++++++++++++ script/cw09b-est.sh | 100 ++++ script/cw09b.sh | 15 +- src/CMakeLists.txt | 10 +- src/v1/blocked_cursor.cpp | 44 ++ src/v1/index.cpp | 170 ++++++ src/v1/index_builder.cpp | 25 +- src/v1/index_metadata.cpp | 13 + src/v1/io.cpp | 6 +- src/v1/query.cpp | 12 +- src/v1/raw_cursor.cpp | 18 + src/v1/score_index.cpp | 59 +- test/v1/index_fixture.hpp | 4 + test/v1/test_v1.cpp | 36 +- test/v1/test_v1_bigram_index.cpp | 8 +- test/v1/test_v1_blocked_cursor.cpp | 48 +- test/v1/test_v1_document_payload_cursor.cpp | 10 - test/v1/test_v1_index.cpp | 19 +- test/v1/test_v1_maxscore_join.cpp | 13 +- test/v1/test_v1_queries.cpp | 43 +- 
test/v1/test_v1_query.cpp | 6 +- test/v1/test_v1_score_index.cpp | 72 ++- v1/CMakeLists.txt | 3 + v1/bigram_index.cpp | 19 - v1/bmscore.cpp | 25 + v1/compress.cpp | 8 +- v1/filter_queries.cpp | 33 +- v1/intersection.cpp | 11 +- v1/postings.cpp | 15 +- v1/query.cpp | 186 ++++-- v1/score.cpp | 2 +- v1/threshold.cpp | 25 +- 51 files changed, 2298 insertions(+), 926 deletions(-) create mode 100644 include/pisa/v1/sequence_cursor.hpp create mode 100644 include/pisa/v1/wand.hpp create mode 100644 script/cw09b-est.sh create mode 100644 src/v1/blocked_cursor.cpp create mode 100644 src/v1/index.cpp create mode 100644 src/v1/raw_cursor.cpp create mode 100644 v1/bmscore.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f0642da26..791574e24 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,6 +8,9 @@ set(CMAKE_CXX_EXTENSIONS OFF) option(PISA_ENABLE_TESTING "Enable testing of the library." ON) option(PISA_ENABLE_BENCHMARKING "Enable benchmarking of the library." ON) +option(PISA_COMPILE_TOOLS "Compile CLI tools." ON) +option(FORCE_COLORED_OUTPUT "Always produce ANSI-colored output (GNU/Clang only)." ON) +option(PISA_LIBCXX "Use libc++ standard library." 
OFF) configure_file( ${PISA_SOURCE_DIR}/include/pisa/pisa_config.hpp.in @@ -29,6 +32,7 @@ ExternalProject_Add(gumbo-external SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/gumbo-parser BINARY_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/gumbo-parser CONFIGURE_COMMAND ./autogen.sh && ./configure --prefix=${CMAKE_BINARY_DIR}/gumbo-parser + BUILD_BYPRODUCTS ${CMAKE_BINARY_DIR}/gumbo-parser/lib/libgumbo.a BUILD_COMMAND ${MAKE}) add_library(gumbo::gumbo STATIC IMPORTED) set_property(TARGET gumbo::gumbo APPEND PROPERTY INTERFACE_INCLUDE_DIRECTORIES @@ -45,7 +49,8 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/external/CMake-codecov/cmake" find_package(codecov) list(APPEND LCOV_REMOVE_PATTERNS "'${PROJECT_SOURCE_DIR}/external/*'") - +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-strict-aliasing") +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DGSL_UNENFORCED_ON_CONTRACT_VIOLATION") if (UNIX) # For hardware popcount and other special instructions set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") @@ -59,8 +64,21 @@ if (UNIX) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ggdb") # Add debug info anyway - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer -Wfatal-errors") + + if (${FORCE_COLORED_OUTPUT}) + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fdiagnostics-color=always") + elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fcolor-diagnostics") + endif () + endif () + +endif() +if (PISA_LIBCXX) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++") + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++ -lc++abi") endif() find_package(OpenMP) @@ -106,12 +124,14 @@ target_link_libraries(pisa PUBLIC ) target_include_directories(pisa PUBLIC external) -add_subdirectory(v1) -add_subdirectory(src) +if (PISA_COMPILE_TOOLS) + add_subdirectory(v1) + add_subdirectory(src) 
+endif() if (PISA_ENABLE_TESTING AND BUILD_TESTING) enable_testing() - add_subdirectory(test) + #add_subdirectory(test) add_subdirectory(test/v1) endif() diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt index 81069c45a..464a348ad 100644 --- a/external/CMakeLists.txt +++ b/external/CMakeLists.txt @@ -106,8 +106,9 @@ set(TRECPP_BUILD_TOOL OFF CACHE BOOL "skip trecpp testing") add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/trecpp) # Add trecpp +set(JSON_MultipleHeaders ON CACHE BOOL "") set(WAPOPP_ENABLE_TESTING OFF CACHE BOOL "skip wapopp testing") -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/wapopp) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/wapopp EXCLUDE_FROM_ALL) # Add fmt add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/fmt) @@ -134,5 +135,6 @@ target_compile_options(rapidcheck PRIVATE -Wno-error=all) # Add json # TODO(michal): I had to comment this out because `wapocpp` already adds this target. # How should we deal with it? -# set(JSON_BuildTests OFF CACHE BOOL "skip building JSON tests") -# add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/json) +#set(JSON_MultipleHeaders ON CACHE BOOL "") +#set(JSON_BuildTests OFF CACHE BOOL "skip building JSON tests") +#add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/json) diff --git a/include/pisa/dec_time_prediction.hpp b/include/pisa/dec_time_prediction.hpp index 5e36c4689..530ebd0e8 100644 --- a/include/pisa/dec_time_prediction.hpp +++ b/include/pisa/dec_time_prediction.hpp @@ -21,7 +21,7 @@ namespace pisa { namespace time_prediction { BOOST_PP_SEQ_ENUM(PISA_FEATURE_TYPES), end }; - feature_type parse_feature_type(std::string const& name) + inline feature_type parse_feature_type(std::string const& name) { if (false) { #define LOOP_BODY(R, DATA, T) \ @@ -36,7 +36,7 @@ namespace pisa { namespace time_prediction { } - std::string feature_name(feature_type f) + inline std::string feature_name(feature_type f) { switch (f) { #define LOOP_BODY(R, DATA, T) \ @@ -106,7 +106,7 @@ namespace pisa { namespace time_prediction { 
float m_bias; }; - void values_statistics(std::vector values, feature_vector& f) + inline void values_statistics(std::vector values, feature_vector& f) { std::sort(values.begin(), values.end()); f[feature_type::n] = values.size(); @@ -143,7 +143,7 @@ namespace pisa { namespace time_prediction { f[feature_type::max_b] = max_b; } - bool read_block_stats(std::istream& is, uint32_t& list_id, std::vector& block_counts) + inline bool read_block_stats(std::istream& is, uint32_t& list_id, std::vector& block_counts) { thread_local std::string line; uint32_t count; diff --git a/include/pisa/mixed_block.hpp b/include/pisa/mixed_block.hpp index b54b6ba48..57f4f33a2 100644 --- a/include/pisa/mixed_block.hpp +++ b/include/pisa/mixed_block.hpp @@ -218,7 +218,7 @@ namespace pisa { typedef std::vector predictors_vec_type; - predictors_vec_type load_predictors(const char* predictors_filename) + inline predictors_vec_type load_predictors(const char* predictors_filename) { std::vector predictors(mixed_block::block_types); diff --git a/include/pisa/topk_queue.hpp b/include/pisa/topk_queue.hpp index 89a754125..00f678066 100644 --- a/include/pisa/topk_queue.hpp +++ b/include/pisa/topk_queue.hpp @@ -26,7 +26,7 @@ struct topk_queue { bool insert(float score, uint64_t docid) { - if (PISA_UNLIKELY(score <= m_threshold)) { + if (PISA_UNLIKELY(score < m_threshold)) { return false; } m_q.emplace_back(score, docid); @@ -66,6 +66,7 @@ struct topk_queue { } [[nodiscard]] uint64_t size() const noexcept { return m_k; } + [[nodiscard]] auto full() const noexcept -> bool { return m_q.size() == m_k; } void set_threshold(float threshold) noexcept { m_threshold = threshold; } diff --git a/include/pisa/v1/blocked_cursor.hpp b/include/pisa/v1/blocked_cursor.hpp index 0419d5f2c..db2e3e681 100644 --- a/include/pisa/v1/blocked_cursor.hpp +++ b/include/pisa/v1/blocked_cursor.hpp @@ -11,7 +11,6 @@ #include #include -#include "codec/block_codecs.hpp" #include "util/likely.hpp" #include "v1/bit_cast.hpp" #include 
"v1/cursor_traits.hpp" @@ -21,195 +20,223 @@ namespace pisa::v1 { -/// Uncompressed example of implementation of a single value cursor. -template -struct BlockedCursor { +/// Non-template base of blocked cursors. +struct BaseBlockedCursor { using value_type = std::uint32_t; + using offset_type = std::uint32_t; + using size_type = std::uint32_t; /// Creates a cursor from the encoded bytes. - explicit constexpr BlockedCursor(gsl::span encoded_blocks, - UnalignedSpan block_endpoints, - UnalignedSpan block_last_values, - std::uint32_t length, - std::uint32_t num_blocks) + BaseBlockedCursor(gsl::span encoded_blocks, + UnalignedSpan block_endpoints, + size_type length, + size_type num_blocks, + size_type block_length) : m_encoded_blocks(encoded_blocks), m_block_endpoints(block_endpoints), - m_block_last_values(block_last_values), + m_decoded_block(block_length), m_length(length), m_num_blocks(num_blocks), - m_current_block( - {.number = 0, - .offset = 0, - .length = std::min(length, static_cast(Codec::block_size)), - .last_value = m_block_last_values[0]}) + m_block_length(block_length), + m_current_block({.number = 0, + .offset = 0, + .length = std::min(length, static_cast(m_block_length))}) { - static_assert(DeltaEncoded, - "Cannot initialize block_last_values for not delta-encoded list"); - m_decoded_block.resize(Codec::block_size); - reset(); } - /// Creates a cursor from the encoded bytes. 
- explicit constexpr BlockedCursor(gsl::span encoded_blocks, - UnalignedSpan block_endpoints, - std::uint32_t length, - std::uint32_t num_blocks) - : m_encoded_blocks(encoded_blocks), - m_block_endpoints(block_endpoints), - m_length(length), - m_num_blocks(num_blocks), - m_current_block( - {.number = 0, - .offset = 0, - .length = std::min(length, static_cast(Codec::block_size)), - .last_value = 0}) + BaseBlockedCursor(BaseBlockedCursor const&) = default; + BaseBlockedCursor(BaseBlockedCursor&&) noexcept = default; + BaseBlockedCursor& operator=(BaseBlockedCursor const&) = default; + BaseBlockedCursor& operator=(BaseBlockedCursor&&) noexcept = default; + ~BaseBlockedCursor() = default; + + [[nodiscard]] auto operator*() const -> value_type; + [[nodiscard]] auto value() const noexcept -> value_type; + [[nodiscard]] auto empty() const noexcept -> bool; + [[nodiscard]] auto position() const noexcept -> std::size_t; + [[nodiscard]] auto size() const -> std::size_t; + [[nodiscard]] auto sentinel() const -> value_type; + + protected: + struct Block { + std::uint32_t number = 0; + std::uint32_t offset = 0; + std::uint32_t length = 0; + }; + + [[nodiscard]] auto block_offset(size_type block) const -> offset_type; + [[nodiscard]] auto decoded_block() -> value_type*; + [[nodiscard]] auto decoded_value(size_type n) -> value_type; + [[nodiscard]] auto encoded_block(offset_type offset) -> uint8_t const*; + [[nodiscard]] auto length() const -> size_type; + [[nodiscard]] auto num_blocks() const -> size_type; + [[nodiscard]] auto current_block() -> Block&; + void update_current_value(value_type val); + void increase_current_value(value_type val); + + private: + gsl::span m_encoded_blocks; + UnalignedSpan m_block_endpoints; + std::vector m_decoded_block{}; + size_type m_length; + size_type m_num_blocks; + size_type m_block_length; + Block m_current_block; + + value_type m_current_value{}; +}; + +template +struct GenericBlockedCursor : public BaseBlockedCursor { + using 
BaseBlockedCursor::offset_type; + using BaseBlockedCursor::size_type; + using BaseBlockedCursor::value_type; + + GenericBlockedCursor(gsl::span encoded_blocks, + UnalignedSpan block_endpoints, + UnalignedSpan block_last_values, + std::uint32_t length, + std::uint32_t num_blocks) + : BaseBlockedCursor(encoded_blocks, block_endpoints, length, num_blocks, Codec::block_size), + m_block_last_values(block_last_values), + m_current_block_last_value(m_block_last_values.empty() ? value_type{} + : m_block_last_values[0]) { - static_assert(not DeltaEncoded, "Must initialize block_last_values for delta-encoded list"); - m_decoded_block.resize(Codec::block_size); reset(); } - constexpr BlockedCursor(BlockedCursor const&) = default; - constexpr BlockedCursor(BlockedCursor&&) noexcept = default; - constexpr BlockedCursor& operator=(BlockedCursor const&) = default; - constexpr BlockedCursor& operator=(BlockedCursor&&) noexcept = default; - ~BlockedCursor() = default; - void reset() { decode_and_update_block(0); } - /// Dereferences the current value. - [[nodiscard]] constexpr auto operator*() const -> value_type { return m_current_value; } - - /// Alias for `operator*()`. - [[nodiscard]] constexpr auto value() const noexcept -> value_type { return *(*this); } - /// Advances the cursor to the next position. 
- constexpr void advance() + void advance() { - m_current_block.offset += 1; - if (PISA_UNLIKELY(m_current_block.offset == m_current_block.length)) { - if (m_current_block.number + 1 == m_num_blocks) { - m_current_value = sentinel(); + auto& current_block = this->current_block(); + current_block.offset += 1; + if (PISA_UNLIKELY(current_block.offset == current_block.length)) { + if (current_block.number + 1 == num_blocks()) { + update_current_value(sentinel()); return; } - decode_and_update_block(m_current_block.number + 1); + decode_and_update_block(current_block.number + 1); } else { if constexpr (DeltaEncoded) { - m_current_value += m_decoded_block[m_current_block.offset] + 1U; + increase_current_value(decoded_value(current_block.offset)); } else { - m_current_value = m_decoded_block[m_current_block.offset] + 1U; + update_current_value(decoded_value(current_block.offset)); } } } /// Moves the cursor to the position `pos`. - constexpr void advance_to_position(std::uint32_t pos) + void advance_to_position(std::uint32_t pos) { Expects(pos >= position()); + auto& current_block = this->current_block(); auto block = pos / Codec::block_size; - if (PISA_UNLIKELY(block != m_current_block.number)) { + if (PISA_UNLIKELY(block != current_block.number)) { decode_and_update_block(block); } while (position() < pos) { + current_block.offset += 1; if constexpr (DeltaEncoded) { - m_current_value += m_decoded_block[++m_current_block.offset] + 1U; + increase_current_value(decoded_value(current_block.offset)); } else { - m_current_value = m_decoded_block[++m_current_block.offset] + 1U; - } - } - } - - /// Moves the cursor to the next value equal or greater than `value`. - constexpr void advance_to_geq(value_type value) - { - // static_assert(DeltaEncoded, "Cannot call advance_to_geq on a not delta-encoded list"); - // TODO(michal): This should be `static_assert` like above. But currently, - // it would not compile. 
What needs to be done is separating document - // and payload readers for the index runner. - assert(DeltaEncoded); - if (PISA_UNLIKELY(value > m_current_block.last_value)) { - if (value > m_block_last_values.back()) { - m_current_value = sentinel(); - return; + update_current_value(decoded_value(current_block.offset)); } - auto block = m_current_block.number + 1U; - while (m_block_last_values[block] < value) { - ++block; - } - decode_and_update_block(block); - } - - while (m_current_value < value) { - m_current_value += m_decoded_block[++m_current_block.offset] + 1U; - Ensures(m_current_block.offset < m_current_block.length); } } - ///// Returns `true` if there is no elements left. - [[nodiscard]] constexpr auto empty() const noexcept -> bool { return position() == m_length; } + protected: + [[nodiscard]] auto& block_last_values() { return m_block_last_values; } + [[nodiscard]] auto& current_block_last_value() { return m_current_block_last_value; } - /// Returns the current position. - [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + void decode_and_update_block(size_type block) { - return m_current_block.number * Codec::block_size + m_current_block.offset; - } - - ///// Returns the number of elements in the list. - [[nodiscard]] constexpr auto size() const -> std::size_t { return m_length; } - - /// The sentinel value, such that `value() != nullopt` is equivalent to `*(*this) < sentinel()`. - [[nodiscard]] constexpr auto sentinel() const -> value_type - { - return std::numeric_limits::max(); - } - - private: - struct Block { - std::uint32_t number = 0; - std::uint32_t offset = 0; - std::uint32_t length = 0; - value_type last_value = 0; - }; - - void decode_and_update_block(std::uint32_t block) - { - constexpr auto block_size = Codec::block_size; - auto endpoint = block > 0U ? 
m_block_endpoints[block - 1] : static_cast(0U); - std::uint8_t const* block_data = - std::next(reinterpret_cast(m_encoded_blocks.data()), endpoint); - m_current_block.length = + auto block_size = Codec::block_size; + auto const* block_data = encoded_block(block_offset(block)); + auto& current_block = this->current_block(); + current_block.length = ((block + 1) * block_size <= size()) ? block_size : (size() % block_size); if constexpr (DeltaEncoded) { std::uint32_t first_value = block > 0U ? m_block_last_values[block - 1] + 1U : 0U; - m_current_block.last_value = m_block_last_values[block]; + m_current_block_last_value = m_block_last_values[block]; Codec::decode(block_data, - m_decoded_block.data(), - m_current_block.last_value - first_value - (m_current_block.length - 1), - m_current_block.length); - m_decoded_block[0] += first_value; + decoded_block(), + m_current_block_last_value - first_value - (current_block.length - 1), + current_block.length); + decoded_block()[0] += first_value; } else { Codec::decode(block_data, - m_decoded_block.data(), + decoded_block(), std::numeric_limits::max(), - m_current_block.length); - m_decoded_block[0] += 1; + current_block.length); + decoded_block()[0] += 1; } - m_current_block.number = block; - m_current_block.offset = 0U; - m_current_value = m_decoded_block[0]; + current_block.number = block; + current_block.offset = 0U; + update_current_value(decoded_block()[0]); } - gsl::span m_encoded_blocks; - UnalignedSpan m_block_endpoints; + private: UnalignedSpan m_block_last_values{}; - std::vector m_decoded_block; + value_type m_current_block_last_value{}; +}; - std::uint32_t m_length; - std::uint32_t m_num_blocks; - Block m_current_block{}; - value_type m_current_value{}; +template +struct DocumentBlockedCursor : public GenericBlockedCursor { + using offset_type = typename GenericBlockedCursor::offset_type; + using size_type = typename GenericBlockedCursor::size_type; + using value_type = typename GenericBlockedCursor::value_type; 
+ + DocumentBlockedCursor(gsl::span encoded_blocks, + UnalignedSpan block_endpoints, + UnalignedSpan block_last_values, + std::uint32_t length, + std::uint32_t num_blocks) + : GenericBlockedCursor( + encoded_blocks, block_endpoints, block_last_values, length, num_blocks) + { + } + + /// Moves the cursor to the next value equal or greater than `value`. + void advance_to_geq(value_type value) + { + auto& current_block = this->current_block(); + if (PISA_UNLIKELY(value > this->current_block_last_value())) { + if (value > this->block_last_values().back()) { + this->update_current_value(this->sentinel()); + return; + } + auto block = current_block.number + 1U; + while (this->block_last_values()[block] < value) { + ++block; + } + this->decode_and_update_block(block); + } + + while (this->value() < value) { + this->increase_current_value(this->decoded_value(++current_block.offset)); + Ensures(current_block.offset < current_block.length); + } + } +}; + +template +struct PayloadBlockedCursor : public GenericBlockedCursor { + using offset_type = typename GenericBlockedCursor::offset_type; + using size_type = typename GenericBlockedCursor::size_type; + using value_type = typename GenericBlockedCursor::value_type; + + PayloadBlockedCursor(gsl::span encoded_blocks, + UnalignedSpan block_endpoints, + std::uint32_t length, + std::uint32_t num_blocks) + : GenericBlockedCursor( + encoded_blocks, block_endpoints, {}, length, num_blocks) + { + } }; template @@ -223,11 +250,10 @@ constexpr auto block_encoding_type() -> std::uint32_t } template -struct BlockedReader { +struct GenericBlockedReader { using value_type = std::uint32_t; [[nodiscard]] auto read(gsl::span bytes) const - -> BlockedCursor { std::uint32_t length; auto begin = reinterpret_cast(bytes.data()); @@ -245,11 +271,10 @@ struct BlockedReader { auto encoded_blocks = bytes.subspan(length_byte_size + block_last_values.byte_size() + block_endpoints.byte_size()); if constexpr (DeltaEncoded) { - return BlockedCursor( + return 
DocumentBlockedCursor( encoded_blocks, block_endpoints, block_last_values, length, num_blocks); } else { - return BlockedCursor( - encoded_blocks, block_endpoints, length, num_blocks); + return PayloadBlockedCursor(encoded_blocks, block_endpoints, length, num_blocks); } } @@ -260,8 +285,13 @@ struct BlockedReader { } }; +template +using DocumentBlockedReader = GenericBlockedReader; +template +using PayloadBlockedReader = GenericBlockedReader; + template -struct BlockedWriter { +struct GenericBlockedWriter { using value_type = std::uint32_t; constexpr static auto encoding() -> std::uint32_t @@ -356,10 +386,21 @@ struct BlockedWriter { value_type m_last_value = 0U; }; -template -struct CursorTraits> { - using Writer = BlockedWriter; - using Reader = BlockedReader; +template +using DocumentBlockedWriter = GenericBlockedWriter; +template +using PayloadBlockedWriter = GenericBlockedWriter; + +template +struct CursorTraits> { + using Writer = DocumentBlockedWriter; + using Reader = DocumentBlockedReader; +}; + +template +struct CursorTraits> { + using Writer = PayloadBlockedWriter; + using Reader = PayloadBlockedReader; }; } // namespace pisa::v1 diff --git a/include/pisa/v1/cursor/scoring_cursor.hpp b/include/pisa/v1/cursor/scoring_cursor.hpp index 9e992c9b4..8f7604d83 100644 --- a/include/pisa/v1/cursor/scoring_cursor.hpp +++ b/include/pisa/v1/cursor/scoring_cursor.hpp @@ -8,11 +8,19 @@ namespace pisa::v1 { +struct ScoringCursorTag { +}; +struct MaxScoreCursorTag { +}; +struct BlockMaxScoreCursorTag { +}; + template struct ScoringCursor { using Document = decltype(*std::declval()); using Payload = decltype((std::declval())(std::declval(), std::declval().payload())); + using Tag = ScoringCursorTag; explicit constexpr ScoringCursor(BaseCursor base_cursor, TermScorer scorer) : m_base_cursor(std::move(base_cursor)), m_scorer(std::move(scorer)) @@ -53,6 +61,7 @@ template struct MaxScoreCursor { using Document = decltype(*std::declval()); using Payload = 
decltype(std::declval().payload()); + using Tag = MaxScoreCursorTag; constexpr MaxScoreCursor(BaseCursor base_cursor, ScoreT max_score) : m_base_cursor(std::move(base_cursor)), m_max_score(max_score) @@ -87,4 +96,64 @@ struct MaxScoreCursor { float m_max_score; }; +template +struct BlockMaxScoreCursor { + using Document = decltype(*std::declval()); + using Payload = decltype(std::declval().payload()); + using Tag = BlockMaxScoreCursorTag; + + constexpr BlockMaxScoreCursor(BaseScoredCursor base_cursor, + BlockMaxCursor block_max_cursor, + ScoreT max_score) + : m_base_cursor(std::move(base_cursor)), + m_block_max_cursor(block_max_cursor), + m_max_score(max_score) + { + } + constexpr BlockMaxScoreCursor(BlockMaxScoreCursor const&) = default; + constexpr BlockMaxScoreCursor(BlockMaxScoreCursor&&) noexcept = default; + constexpr BlockMaxScoreCursor& operator=(BlockMaxScoreCursor const&) = default; + constexpr BlockMaxScoreCursor& operator=(BlockMaxScoreCursor&&) noexcept = default; + ~BlockMaxScoreCursor() = default; + + [[nodiscard]] constexpr auto operator*() const -> Document { return m_base_cursor.value(); } + [[nodiscard]] constexpr auto value() const noexcept -> Document + { + return m_base_cursor.value(); + } + [[nodiscard]] constexpr auto payload() noexcept { return m_base_cursor.payload(); } + constexpr void advance() { m_base_cursor.advance(); } + constexpr void advance_to_position(std::size_t pos) { m_base_cursor.advance_to_position(pos); } + constexpr void advance_to_geq(Document value) { m_base_cursor.advance_to_geq(value); } + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_base_cursor.empty(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return m_base_cursor.position(); + } + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_base_cursor.size(); } + [[nodiscard]] constexpr auto sentinel() const -> Document { return m_base_cursor.sentinel(); } + [[nodiscard]] constexpr auto 
max_score() const -> float { return m_max_score; } + [[nodiscard]] constexpr auto block_max_docid() { return m_block_max_cursor.value(); } + [[nodiscard]] constexpr auto block_max_score() { return m_block_max_cursor.payload(); } + [[nodiscard]] constexpr auto block_max_score(DocId docid) + { + m_block_max_cursor.advance_to_geq(docid); + return m_block_max_cursor.payload(); + } + + private: + BaseScoredCursor m_base_cursor; + BlockMaxCursor m_block_max_cursor; + float m_max_score; +}; + +template +[[nodiscard]] auto block_max_score_cursor(BaseScoredCursor base_cursor, + BlockMaxCursor block_max_cursor, + ScoreT max_score) +{ + return BlockMaxScoreCursor( + std::move(base_cursor), std::move(block_max_cursor), max_score); +} + } // namespace pisa::v1 diff --git a/include/pisa/v1/default_index_runner.hpp b/include/pisa/v1/default_index_runner.hpp index be24db39b..6fba14cd3 100644 --- a/include/pisa/v1/default_index_runner.hpp +++ b/include/pisa/v1/default_index_runner.hpp @@ -3,14 +3,25 @@ #include "index_types.hpp" #include "v1/blocked_cursor.hpp" #include "v1/index.hpp" +#include "v1/index_metadata.hpp" #include "v1/raw_cursor.hpp" namespace pisa::v1 { -using DefaultIndexRunner = IndexRunner{}, - RawReader{}, - RawReader{}, - BlockedReader<::pisa::simdbp_block, true>{}, - BlockedReader<::pisa::simdbp_block, false>{}>; +[[nodiscard]] inline auto index_runner(IndexMetadata metadata) +{ + return index_runner( + std::move(metadata), + std::make_tuple(RawReader{}, DocumentBlockedReader<::pisa::simdbp_block>{}), + std::make_tuple(RawReader{}, PayloadBlockedReader<::pisa::simdbp_block>{})); +} +[[nodiscard]] inline auto scored_index_runner(IndexMetadata metadata) +{ + return scored_index_runner( + std::move(metadata), + std::make_tuple(RawReader{}, DocumentBlockedReader<::pisa::simdbp_block>{}), + std::make_tuple(RawReader{})); } + +} // namespace pisa::v1 diff --git a/include/pisa/v1/document_payload_cursor.hpp b/include/pisa/v1/document_payload_cursor.hpp index 
ec06e885c..d00a675fc 100644 --- a/include/pisa/v1/document_payload_cursor.hpp +++ b/include/pisa/v1/document_payload_cursor.hpp @@ -11,8 +11,7 @@ struct DocumentPayloadCursor { using Document = decltype(*std::declval()); using Payload = decltype(*std::declval()); - explicit constexpr DocumentPayloadCursor(DocumentCursor key_cursor, - PayloadCursor payload_cursor) + constexpr DocumentPayloadCursor(DocumentCursor key_cursor, PayloadCursor payload_cursor) : m_key_cursor(std::move(key_cursor)), m_payload_cursor(std::move(payload_cursor)) { } @@ -47,4 +46,11 @@ struct DocumentPayloadCursor { PayloadCursor m_payload_cursor; }; +template +[[nodiscard]] auto document_payload_cursor(DocumentCursor key_cursor, PayloadCursor payload_cursor) +{ + return DocumentPayloadCursor(std::move(key_cursor), + std::move(payload_cursor)); +} + } // namespace pisa::v1 diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index 23eadc3cc..e7c46c05d 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -2,30 +2,17 @@ #include #include -#include -#include -#include -#include - -#include -#include -#include + #include -#include #include -#include #include -#include "binary_freq_collection.hpp" -#include "payload_vector.hpp" #include "v1/bit_cast.hpp" #include "v1/cursor/for_each.hpp" #include "v1/cursor/scoring_cursor.hpp" -#include "v1/cursor_intersection.hpp" #include "v1/document_payload_cursor.hpp" #include "v1/posting_builder.hpp" #include "v1/raw_cursor.hpp" -#include "v1/scorer/bm25.hpp" #include "v1/source.hpp" #include "v1/types.hpp" #include "v1/zip_cursor.hpp" @@ -35,23 +22,97 @@ namespace pisa::v1 { using OffsetSpan = gsl::span; using BinarySpan = gsl::span; -[[nodiscard]] inline auto calc_avg_length(gsl::span const& lengths) -> float -{ - auto sum = std::accumulate(lengths.begin(), lengths.end(), std::uint64_t(0), std::plus{}); - return static_cast(sum) / lengths.size(); -} +[[nodiscard]] auto calc_avg_length(gsl::span const& lengths) -> float; 
+[[nodiscard]] auto read_sizes(std::string_view basename) -> std::vector; -[[nodiscard]] inline auto compare_arrays(std::array const& lhs, - std::array const& rhs) -> bool -{ - if (std::get<0>(lhs) < std::get<0>(rhs)) { - return true; - } - if (std::get<0>(lhs) > std::get<0>(rhs)) { - return false; +/// Lexicographically compares bigrams. +/// Used for looking up bigram mappings. +[[nodiscard]] auto compare_arrays(std::array const& lhs, + std::array const& rhs) -> bool; + +struct PostingData { + BinarySpan postings; + OffsetSpan offsets; +}; + +struct UnigramData { + PostingData documents; + PostingData payloads; +}; + +struct BigramData { + PostingData documents; + std::array payloads; + gsl::span const> mapping; +}; + +/// Parts of the index independent of the template parameters. +struct BaseIndex { + + template + BaseIndex(PostingData documents, + PostingData payloads, + tl::optional bigrams, + gsl::span document_lengths, + tl::optional avg_document_length, + std::unordered_map> max_scores, + std::unordered_map block_max_scores, + gsl::span quantized_max_scores, + Source source) + : m_documents(documents), + m_payloads(payloads), + m_bigrams(bigrams), + m_document_lengths(document_lengths), + m_avg_document_length(avg_document_length.map_or_else( + [](auto&& self) { return self; }, + [&]() { return calc_avg_length(m_document_lengths); })), + m_max_scores(std::move(max_scores)), + m_block_max_scores(std::move(block_max_scores)), + m_quantized_max_scores(quantized_max_scores), + m_source(std::move(source)) + { } - return std::get<1>(lhs) < std::get<1>(rhs); -} + + [[nodiscard]] auto num_terms() const -> std::size_t; + [[nodiscard]] auto num_documents() const -> std::size_t; + [[nodiscard]] auto document_length(DocId docid) const -> std::uint32_t; + [[nodiscard]] auto avg_document_length() const -> float; + [[nodiscard]] auto normalized_document_length(DocId docid) const -> float; + [[nodiscard]] auto bigram_id(TermId left_term, TermId right_term) const -> 
tl::optional; + + protected: + void assert_term_in_bounds(TermId term) const; + [[nodiscard]] auto fetch_documents(TermId term) const -> gsl::span; + [[nodiscard]] auto fetch_payloads(TermId term) const -> gsl::span; + [[nodiscard]] auto fetch_bigram_documents(TermId bigram) const -> gsl::span; + [[nodiscard]] auto fetch_bigram_payloads(TermId bigram) const + -> std::array, 2>; + template + [[nodiscard]] auto fetch_bigram_payloads(TermId bigram) const -> gsl::span; + + [[nodiscard]] auto max_score(std::size_t scorer_hash, TermId term) const -> float; + [[nodiscard]] auto block_max_scores(std::size_t scorer_hash) const -> UnigramData const&; + [[nodiscard]] auto quantized_max_score(TermId term) const -> std::uint8_t; + [[nodiscard]] auto block_max_document_reader() const -> Reader> const&; + [[nodiscard]] auto block_max_score_reader() const -> Reader> const&; + + private: + PostingData m_documents; + PostingData m_payloads; + tl::optional m_bigrams; + + gsl::span m_document_lengths; + float m_avg_document_length; + std::unordered_map> m_max_scores; + std::unordered_map m_block_max_scores; + gsl::span m_quantized_max_scores; + std::any m_source; + + Reader> m_block_max_document_reader = + Reader>(RawReader{}); + Reader> m_block_max_score_reader = + Reader>(RawReader{}); +}; /// A generic type for an inverted index. /// @@ -61,7 +122,7 @@ using BinarySpan = gsl::span; /// It can read lists of arbitrary types, such as `Frequency`, /// `Score`, or `std::pair` for a bigram scored index. template -struct Index { +struct Index : public BaseIndex { using document_cursor_type = DocumentCursor; using payload_cursor_type = PayloadCursor; @@ -70,12 +131,7 @@ struct Index { /// /// \param document_reader Reads document posting lists from bytes. /// \param payload_reader Reads payload posting lists from bytes. - /// \param document_offsets Mapping from term ID to the position in memory of its - /// document posting list. 
- /// \param payload_offsets Mapping from term ID to the position in memory of its - /// payload posting list. - /// \param documents Encoded bytes for document postings. - /// \param payloads Encoded bytes for payload postings. + /// TODO(michal)... /// \param source This object (optionally) owns the raw data pointed at by /// `documents` and `payloads` to ensure it is valid throughout /// the lifetime of the index. It should release any resources @@ -83,41 +139,31 @@ struct Index { template Index(DocumentReader document_reader, PayloadReader payload_reader, - gsl::span document_offsets, - gsl::span payload_offsets, - tl::optional bigram_document_offsets, - tl::optional> bigram_frequency_offsets, - gsl::span documents, - gsl::span payloads, - tl::optional bigram_documents, - tl::optional> bigram_frequencies, + PostingData documents, + PostingData payloads, + tl::optional bigrams, gsl::span document_lengths, tl::optional avg_document_length, std::unordered_map> max_scores, + std::unordered_map block_max_scores, gsl::span quantized_max_scores, - tl::optional const>> bigram_mapping, Source source) - : m_document_reader(std::move(document_reader)), - m_payload_reader(std::move(payload_reader)), - m_document_offsets(document_offsets), - m_payload_offsets(payload_offsets), - m_bigram_document_offsets(bigram_document_offsets), - m_bigram_frequency_offsets(bigram_frequency_offsets), - m_documents(documents), - m_payloads(payloads), - m_bigram_documents(bigram_documents), - m_bigram_frequencies(bigram_frequencies), - m_document_lengths(document_lengths), - m_avg_document_length(avg_document_length.map_or_else( - [](auto&& self) { return self; }, - [&]() { return calc_avg_length(m_document_lengths); })), - m_max_scores(std::move(max_scores)), - m_max_quantized_scores(quantized_max_scores), - m_bigram_mapping(bigram_mapping), - m_source(std::move(source)){} + : BaseIndex(documents, + payloads, + bigrams, + document_lengths, + avg_document_length, + std::move(max_scores), + 
std::move(block_max_scores), + quantized_max_scores, + source), + m_document_reader(std::move(document_reader)), + m_payload_reader(std::move(payload_reader)) + { + } - /// Constructs a new document-payload cursor (see document_payload_cursor.hpp). - [[nodiscard]] auto cursor(TermId term) const + /// Constructs a new document-payload cursor (see document_payload_cursor.hpp). + [[nodiscard]] auto cursor(TermId term) const { return DocumentPayloadCursor(documents(term), payloads(term)); @@ -131,27 +177,6 @@ struct Index { }); return cursors; } - [[nodiscard]] auto bigram_id(TermId left_term, TermId right_term) const -> tl::optional - { - if (not m_bigram_mapping) { - throw std::logic_error("Bigrams are missing"); - } - if (right_term == left_term) { - throw std::logic_error("Requested bigram of two identical terms"); - } - auto bigram = std::array{left_term, right_term}; - if (right_term < left_term) { - std::swap(bigram[0], bigram[1]); - } - if (auto pos = std::lower_bound( - m_bigram_mapping->begin(), m_bigram_mapping->end(), bigram, compare_arrays); - pos != m_bigram_mapping->end()) { - if (*pos == bigram) { - return tl::make_optional(std::distance(m_bigram_mapping->begin(), pos)); - } - } - return tl::nullopt; - } [[nodiscard]] auto bigram_payloads_0(TermId left_term, TermId right_term) const { @@ -170,7 +195,7 @@ struct Index { [[nodiscard]] auto bigram_cursor(TermId left_term, TermId right_term) const { return bigram_id(left_term, right_term).map([this](auto bid) { - return DocumentPayloadCursor>( + return document_payload_cursor( m_document_reader.read(fetch_bigram_documents(bid)), zip(m_payload_reader.read(fetch_bigram_payloads<0>(bid)), m_payload_reader.read(fetch_bigram_payloads<1>(bid)))); @@ -222,18 +247,12 @@ struct Index { using cursor_type = std::decay_t(scorer)))>; if constexpr (std::is_convertible_v) { - if (m_max_quantized_scores.empty()) { - throw std::logic_error("Missing quantized max scores."); - } return MaxScoreCursor( - scored_cursor(term, 
std::forward(scorer)), m_max_quantized_scores[term]); + scored_cursor(term, std::forward(scorer)), quantized_max_score(term)); } else { - if (m_max_scores.empty()) { - throw std::logic_error("Missing max scores."); - } return MaxScoreCursor( scored_cursor(term, std::forward(scorer)), - m_max_scores.at(std::hash>{}(scorer))[term]); + max_score(std::hash>{}(scorer), term)); } } @@ -247,6 +266,49 @@ struct Index { return cursors; } + template + [[nodiscard]] auto block_max_scored_cursor(TermId term, Scorer&& scorer) const + { + auto const& document_reader = block_max_document_reader(); + auto const& score_reader = block_max_score_reader(); + // Expects(term + 1 < m_payloads.offsets.size()); + if constexpr (std::is_convertible_v) { + if (false) { // TODO(michal): Workaround for now to avoid explicitly defining return + // type. + return block_max_score_cursor( + scored_cursor(term, std::forward(scorer)), + document_payload_cursor(document_reader.read({}), score_reader.read({})), + 0.0F); + } + throw std::logic_error("Quantized block max scores uimplemented"); + } else { + auto const& data = block_max_scores(std::hash>{}(scorer)); + auto block_max_document_subspan = data.documents.postings.subspan( + data.documents.offsets[term], + data.documents.offsets[term + 1] - data.documents.offsets[term]); + auto block_max_score_subspan = data.payloads.postings.subspan( + data.payloads.offsets[term], + data.payloads.offsets[term + 1] - data.payloads.offsets[term]); + return block_max_score_cursor( + scored_cursor(term, std::forward(scorer)), + document_payload_cursor(document_reader.read(block_max_document_subspan), + score_reader.read(block_max_score_subspan)), + max_score(std::hash>{}(scorer), term)); + } + } + + template + [[nodiscard]] auto block_max_scored_cursors(gsl::span terms, + Scorer&& scorer) const + { + std::vector(0, std::forward(scorer)))> + cursors; + std::transform(terms.begin(), terms.end(), std::back_inserter(cursors), [&](auto term) { + return 
block_max_scored_cursor(term, scorer); + }); + return cursors; + } + /// Constructs a new document-score cursor. template [[nodiscard]] auto scoring_bigram_cursor(TermId left_term, @@ -293,109 +355,28 @@ struct Index { return m_payload_reader.read(fetch_payloads(term)); } - /// Constructs a new payload cursor. - [[nodiscard]] auto num_terms() const -> std::size_t { return m_document_offsets.size() - 1; } - - [[nodiscard]] auto num_documents() const -> std::size_t { return m_document_lengths.size(); } - [[nodiscard]] auto term_posting_count(TermId term) const -> std::uint32_t { // TODO(michal): Should be done more efficiently. return documents(term).size(); } - [[nodiscard]] auto document_length(DocId docid) const -> std::uint32_t - { - return m_document_lengths[docid]; - } - - [[nodiscard]] auto avg_document_length() const -> float { return m_avg_document_length; } - - [[nodiscard]] auto normalized_document_length(DocId docid) const -> float - { - return document_length(docid) / avg_document_length(); - } - private: - void assert_term_in_bounds(TermId term) const - { - if (term >= num_terms()) { - std::invalid_argument( - fmt::format("Requested term ID out of bounds [0-{}): {}", num_terms(), term)); - } - } - [[nodiscard]] auto fetch_documents(TermId term) const -> gsl::span - { - Expects(term + 1 < m_document_offsets.size()); - return m_documents.subspan(m_document_offsets[term], - m_document_offsets[term + 1] - m_document_offsets[term]); - } - [[nodiscard]] auto fetch_payloads(TermId term) const -> gsl::span - { - Expects(term + 1 < m_payload_offsets.size()); - return m_payloads.subspan(m_payload_offsets[term], - m_payload_offsets[term + 1] - m_payload_offsets[term]); - } - [[nodiscard]] auto fetch_bigram_documents(TermId term) const -> gsl::span - { - if (not m_bigram_documents) { - throw std::logic_error("Bigrams are missing"); - } - Expects(term + 1 < m_bigram_document_offsets->size()); - return m_bigram_documents->subspan( - 
(*m_bigram_document_offsets)[term], - (*m_bigram_document_offsets)[term + 1] - (*m_bigram_document_offsets)[term]); - } - template - [[nodiscard]] auto fetch_bigram_payloads(TermId term) const -> gsl::span - { - if (not m_bigram_frequencies) { - throw std::logic_error("Bigrams are missing"); - } - Expects(term + 1 < std::get(*m_bigram_frequency_offsets).size()); - return std::get(*m_bigram_frequencies) - .subspan(std::get(*m_bigram_frequency_offsets)[term], - std::get(*m_bigram_frequency_offsets)[term + 1] - - std::get(*m_bigram_frequency_offsets)[term]); - } - Reader m_document_reader; Reader m_payload_reader; - - OffsetSpan m_document_offsets; - OffsetSpan m_payload_offsets; - tl::optional m_bigram_document_offsets{}; - tl::optional> m_bigram_frequency_offsets{}; - - BinarySpan m_documents; - BinarySpan m_payloads; - tl::optional m_bigram_documents{}; - tl::optional> m_bigram_frequencies{}; - - gsl::span m_document_lengths; - float m_avg_document_length; - std::unordered_map> m_max_scores; - gsl::span m_max_quantized_scores; - tl::optional const>> m_bigram_mapping; - std::any m_source; }; template auto make_index(DocumentReader document_reader, PayloadReader payload_reader, - OffsetSpan document_offsets, - OffsetSpan payload_offsets, - tl::optional bigram_document_offsets, - tl::optional> bigram_frequency_offsets, - BinarySpan documents, - BinarySpan payloads, - tl::optional bigram_documents, - tl::optional> bigram_frequencies, + PostingData documents, + PostingData payloads, + tl::optional bigrams, gsl::span document_lengths, tl::optional avg_document_length, std::unordered_map> max_scores, + std::unordered_map block_max_scores, gsl::span quantized_max_scores, - tl::optional const>> bigram_mapping, Source source) { using DocumentCursor = @@ -403,250 +384,123 @@ auto make_index(DocumentReader document_reader, using PayloadCursor = decltype(payload_reader.read(std::declval>())); return Index(std::move(document_reader), std::move(payload_reader), - 
document_offsets, - payload_offsets, - bigram_document_offsets, - bigram_frequency_offsets, documents, payloads, - bigram_documents, - bigram_frequencies, + bigrams, document_lengths, avg_document_length, - max_scores, + std::move(max_scores), + std::move(block_max_scores), quantized_max_scores, - bigram_mapping, std::move(source)); } -template -auto score_index(Index const& index, - std::basic_ostream& os, - Writer writer, - Scorer scorer, - Quantizer&& quantizer, - Callback&& callback) -> std::vector -{ - PostingBuilder score_builder(writer); - score_builder.write_header(os); - std::for_each(boost::counting_iterator(0), - boost::counting_iterator(index.num_terms()), - [&](auto term) { - for_each(index.scoring_cursor(term, scorer), [&](auto&& cursor) { - score_builder.accumulate(quantizer(cursor.payload())); - }); - score_builder.flush_segment(os); - callback(); - }); - return std::move(score_builder.offsets()); -} - -template -auto score_index(Index const& index, std::basic_ostream& os, Writer writer, Scorer scorer) - -> std::vector -{ - PostingBuilder score_builder(writer); - score_builder.write_header(os); - std::for_each(boost::counting_iterator(0), - boost::counting_iterator(index.num_terms()), - [&](auto term) { - for_each(index.scoring_cursor(term, scorer), - [&](auto& cursor) { score_builder.accumulate(cursor.payload()); }); - score_builder.flush_segment(os); - }); - return std::move(score_builder.offsets()); -} - -/// Initializes a memory mapped source with a given file. 
-inline void open_source(mio::mmap_source& source, std::string const& filename) -{ - std::error_code error; - source.map(filename, error); - if (error) { - spdlog::error("Error mapping file {}: {}", filename, error.message()); - throw std::runtime_error("Error mapping file"); - } -} - -inline auto read_sizes(std::string_view basename) -{ - binary_collection sizes(fmt::format("{}.sizes", basename).c_str()); - auto sequence = *sizes.begin(); - return std::vector(sequence.begin(), sequence.end()); -} - -[[nodiscard]] inline auto binary_collection_source(std::string const& basename) -{ - using sink_type = boost::iostreams::back_insert_device>; - using vector_stream_type = boost::iostreams::stream; - - binary_freq_collection collection(basename.c_str()); - VectorSource source{{{}, {}}, {{}, {}}, {read_sizes(basename)}}; - std::vector& docbuf = source.bytes[0]; - std::vector& freqbuf = source.bytes[1]; - - PostingBuilder document_builder(Writer(RawWriter{})); - PostingBuilder frequency_builder(Writer(RawWriter{})); - { - vector_stream_type docstream{sink_type{docbuf}}; - vector_stream_type freqstream{sink_type{freqbuf}}; - - document_builder.write_header(docstream); - frequency_builder.write_header(freqstream); - - for (auto sequence : collection) { - document_builder.write_segment(docstream, sequence.docs.begin(), sequence.docs.end()); - frequency_builder.write_segment( - freqstream, sequence.freqs.begin(), sequence.freqs.end()); - } - } - - source.offsets[0] = std::move(document_builder.offsets()); - source.offsets[1] = std::move(frequency_builder.offsets()); - - return source; -} - -template +template struct IndexRunner { template - IndexRunner(gsl::span document_offsets, - gsl::span payload_offsets, - tl::optional bigram_document_offsets, - tl::optional> bigram_frequency_offsets, - gsl::span documents, - gsl::span payloads, - tl::optional bigram_documents, - tl::optional> bigram_frequencies, + IndexRunner(PostingData documents, + PostingData payloads, + tl::optional 
bigrams, gsl::span document_lengths, tl::optional avg_document_length, std::unordered_map> max_scores, + std::unordered_map block_max_scores, gsl::span quantized_max_scores, - tl::optional> const> bigram_mapping, Source source, - Readers... readers) - : m_document_offsets(document_offsets), - m_payload_offsets(payload_offsets), - m_bigram_document_offsets(bigram_document_offsets), - m_bigram_frequency_offsets(bigram_frequency_offsets), - m_documents(documents), + DocumentReaders document_readers, + PayloadReaders payload_readers) + : m_documents(documents), m_payloads(payloads), - m_bigram_documents(bigram_documents), - m_bigram_frequencies(bigram_frequencies), + m_bigrams(bigrams), m_document_lengths(document_lengths), m_avg_document_length(avg_document_length), m_max_scores(std::move(max_scores)), + m_block_max_scores(std::move(block_max_scores)), m_max_quantized_scores(quantized_max_scores), - m_bigram_mapping(bigram_mapping), m_source(std::move(source)), - m_readers(readers...) - { - } - template - IndexRunner(gsl::span document_offsets, - gsl::span payload_offsets, - tl::optional bigram_document_offsets, - tl::optional> bigram_frequency_offsets, - gsl::span documents, - gsl::span payloads, - tl::optional bigram_documents, - tl::optional> bigram_frequencies, - gsl::span document_lengths, - tl::optional avg_document_length, - std::unordered_map> max_scores, - gsl::span quantized_max_scores, - tl::optional const>> bigram_mapping, - Source source, - std::tuple readers) - : m_document_offsets(document_offsets), - m_payload_offsets(payload_offsets), - m_bigram_document_offsets(bigram_document_offsets), - m_bigram_frequency_offsets(bigram_frequency_offsets), - m_documents(documents), - m_payloads(payloads), - m_bigram_documents(bigram_documents), - m_bigram_frequencies(bigram_frequencies), - m_document_lengths(document_lengths), - m_avg_document_length(avg_document_length), - m_max_scores(std::move(max_scores)), - m_max_quantized_scores(quantized_max_scores), - 
m_bigram_mapping(bigram_mapping), - m_source(std::move(source)), - m_readers(std::move(readers)) + m_document_readers(std::move(document_readers)), + m_payload_readers(std::move(payload_readers)) { } template auto operator()(Fn fn) { - auto dheader = PostingFormatHeader::parse(m_documents.first(8)); - auto pheader = PostingFormatHeader::parse(m_payloads.first(8)); + auto dheader = PostingFormatHeader::parse(m_documents.postings.first(8)); + auto pheader = PostingFormatHeader::parse(m_payloads.postings.first(8)); auto run = [&](auto&& dreader, auto&& preader) { if (std::decay_t::encoding() == dheader.encoding && std::decay_t::encoding() == pheader.encoding && is_type::value_type>(dheader.type) && is_type::value_type>(pheader.type)) { - fn(make_index(std::forward(dreader), - std::forward(preader), - m_document_offsets, - m_payload_offsets, - m_bigram_document_offsets, - m_bigram_frequency_offsets, - m_documents.subspan(8), - m_payloads.subspan(8), - m_bigram_documents.map([](auto&& bytes) { return bytes.subspan(8); }), - m_bigram_frequencies.map([](auto&& bytes) { - return std::array{std::get<0>(bytes).subspan(8), - std::get<1>(bytes).subspan(8)}; - }), - m_document_lengths, - m_avg_document_length, - m_max_scores, - m_max_quantized_scores, - m_bigram_mapping, - false)); + auto block_max_scores = m_block_max_scores; + for (auto& [key, data] : block_max_scores) { + data.documents.postings = data.documents.postings.subspan(8); + data.payloads.postings = data.payloads.postings.subspan(8); + } + fn(make_index( + std::forward(dreader), + std::forward(preader), + PostingData{.postings = m_documents.postings.subspan(8), + .offsets = m_documents.offsets}, + PostingData{.postings = m_payloads.postings.subspan(8), + .offsets = m_payloads.offsets}, + m_bigrams.map([](auto&& bigram_data) { + return BigramData{ + .documents = + PostingData{.postings = bigram_data.documents.postings.subspan(8), + .offsets = bigram_data.documents.offsets}, + .payloads = + std::array{ + PostingData{ 
+ .postings = + std::get<0>(bigram_data.payloads).postings.subspan(8), + .offsets = std::get<0>(bigram_data.payloads).offsets}, + PostingData{ + .postings = + std::get<1>(bigram_data.payloads).postings.subspan(8), + .offsets = std::get<1>(bigram_data.payloads).offsets}}, + .mapping = bigram_data.mapping}; + }), + m_document_lengths, + m_avg_document_length, + m_max_scores, + block_max_scores, + m_max_quantized_scores, + false)); return true; } return false; }; auto result = std::apply( - [&](Readers... dreaders) { + [&](auto... dreaders) { auto with_document_reader = [&](auto dreader) { return std::apply( - [&](Readers... preaders) { return (run(dreader, preaders) || ...); }, - m_readers); + [&](auto... preaders) { return (run(dreader, preaders) || ...); }, + m_payload_readers); }; return (with_document_reader(dreaders) || ...); }, - m_readers); + m_document_readers); if (not result) { throw std::domain_error("Unknown posting encoding"); } } private: - gsl::span m_document_offsets; - gsl::span m_payload_offsets; - tl::optional m_bigram_document_offsets{}; - tl::optional> m_bigram_frequency_offsets{}; - - gsl::span m_documents; - gsl::span m_payloads; - tl::optional m_bigram_documents{}; - tl::optional> m_bigram_frequencies{}; + PostingData m_documents; + PostingData m_payloads; + tl::optional m_bigrams; gsl::span m_document_lengths; tl::optional m_avg_document_length; std::unordered_map> m_max_scores; + std::unordered_map m_block_max_scores; gsl::span m_max_quantized_scores; tl::optional const>> m_bigram_mapping; std::any m_source; - std::tuple m_readers; + DocumentReaders m_document_readers; + PayloadReaders m_payload_readers; }; } // namespace pisa::v1 diff --git a/include/pisa/v1/index_builder.hpp b/include/pisa/v1/index_builder.hpp index 8d29d9fd2..153040574 100644 --- a/include/pisa/v1/index_builder.hpp +++ b/include/pisa/v1/index_builder.hpp @@ -8,6 +8,7 @@ #include #include +#include "binary_freq_collection.hpp" #include "v1/index.hpp" #include 
"v1/index_metadata.hpp" #include "v1/progress_status.hpp" diff --git a/include/pisa/v1/index_metadata.hpp b/include/pisa/v1/index_metadata.hpp index baf5efe7f..acb08ce70 100644 --- a/include/pisa/v1/index_metadata.hpp +++ b/include/pisa/v1/index_metadata.hpp @@ -5,6 +5,7 @@ #include #include +#include #include "v1/index.hpp" #include "v1/query.hpp" @@ -41,6 +42,11 @@ struct PostingFilePaths { std::string offsets; }; +struct UnigramFilePaths { + PostingFilePaths documents; + PostingFilePaths payloads; +}; + struct BigramMetadata { PostingFilePaths documents; std::pair frequencies; @@ -49,7 +55,7 @@ struct BigramMetadata { std::size_t count; }; -struct IndexMetadata { +struct IndexMetadata final { tl::optional basename{}; PostingFilePaths documents; PostingFilePaths frequencies; @@ -61,6 +67,7 @@ struct IndexMetadata { tl::optional stemmer{}; tl::optional bigrams{}; std::map max_scores{}; + std::map block_max_scores{}; std::map quantized_max_scores{}; void write(std::string const& file) const; @@ -84,14 +91,10 @@ template source.file_sources.emplace_back(std::make_shared(file)).get()); }; -template -[[nodiscard]] inline auto index_runner(IndexMetadata metadata, Readers... 
readers) -{ - return index_runner(std::move(metadata), std::make_tuple(readers...)); -} - -template -[[nodiscard]] inline auto index_runner(IndexMetadata metadata, std::tuple readers) +template +[[nodiscard]] inline auto index_runner(IndexMetadata metadata, + DocumentReaders document_readers, + PayloadReaders payload_readers) { MMapSource source; auto documents = source_span(source, metadata.documents.postings); @@ -99,26 +102,38 @@ template auto document_offsets = source_span(source, metadata.documents.offsets); auto frequency_offsets = source_span(source, metadata.frequencies.offsets); auto document_lengths = source_span(source, metadata.document_lengths_path); - tl::optional> bigram_document_offsets{}; - tl::optional, 2>> bigram_frequency_offsets{}; - tl::optional> bigram_documents{}; - tl::optional, 2>> bigram_frequencies{}; - tl::optional const>> bigram_mapping{}; - if (metadata.bigrams) { - bigram_document_offsets = - source_span(source, metadata.bigrams->documents.offsets); - bigram_frequency_offsets = { - source_span(source, metadata.bigrams->frequencies.first.offsets), - source_span(source, metadata.bigrams->frequencies.second.offsets)}; - bigram_documents = source_span(source, metadata.bigrams->documents.postings); - bigram_frequencies = { - source_span(source, metadata.bigrams->frequencies.first.postings), - source_span(source, metadata.bigrams->frequencies.second.postings)}; - auto mapping_span = source_span(source, metadata.bigrams->mapping); - bigram_mapping = gsl::span const>( - reinterpret_cast const*>(mapping_span.data()), - mapping_span.size() / (sizeof(TermId) * 2)); - } + auto bigrams = [&]() -> tl::optional { + gsl::span bigram_document_offsets{}; + std::array, 2> bigram_frequency_offsets{}; + gsl::span bigram_documents{}; + std::array, 2> bigram_frequencies{}; + gsl::span const> bigram_mapping{}; + if (metadata.bigrams) { + bigram_document_offsets = + source_span(source, metadata.bigrams->documents.offsets); + bigram_frequency_offsets = { + 
source_span(source, metadata.bigrams->frequencies.first.offsets), + source_span(source, metadata.bigrams->frequencies.second.offsets)}; + bigram_documents = source_span(source, metadata.bigrams->documents.postings); + bigram_frequencies = { + source_span(source, metadata.bigrams->frequencies.first.postings), + source_span(source, metadata.bigrams->frequencies.second.postings)}; + auto mapping_span = source_span(source, metadata.bigrams->mapping); + bigram_mapping = gsl::span const>( + reinterpret_cast const*>(mapping_span.data()), + mapping_span.size() / (sizeof(TermId) * 2)); + return BigramData{ + .documents = {.postings = bigram_documents, .offsets = bigram_document_offsets}, + .payloads = + std::array{ + PostingData{.postings = std::get<0>(bigram_frequencies), + .offsets = std::get<0>(bigram_frequency_offsets)}, + PostingData{.postings = std::get<1>(bigram_frequencies), + .offsets = std::get<1>(bigram_frequency_offsets)}}, + .mapping = bigram_mapping}; + } + return tl::nullopt; + }(); std::unordered_map> max_scores; if (not metadata.max_scores.empty()) { for (auto [name, file] : metadata.max_scores) { @@ -127,32 +142,36 @@ template reinterpret_cast(bytes.data()), bytes.size() / (sizeof(float))); } } - return IndexRunner(document_offsets, - frequency_offsets, - bigram_document_offsets, - bigram_frequency_offsets, - documents, - frequencies, - bigram_documents, - bigram_frequencies, - document_lengths, - tl::make_optional(metadata.avg_document_length), - std::move(max_scores), - {}, - bigram_mapping, - std::move(source), - std::move(readers)); -} - -template -[[nodiscard]] inline auto scored_index_runner(IndexMetadata metadata, Readers... 
readers) -{ - return scored_index_runner(std::move(metadata), std::make_tuple(readers...)); -} + std::unordered_map block_max_scores; + if (not metadata.block_max_scores.empty()) { + for (auto [name, files] : metadata.block_max_scores) { + auto document_bytes = source_span(source, files.documents.postings); + auto document_offsets = source_span(source, files.documents.offsets); + auto payload_bytes = source_span(source, files.payloads.postings); + auto payload_offsets = source_span(source, files.payloads.offsets); + block_max_scores[std::hash{}(name)] = + UnigramData{.documents = {.postings = document_bytes, .offsets = document_offsets}, + .payloads = {.postings = payload_bytes, .offsets = payload_offsets}}; + } + } + return IndexRunner( + {.postings = documents, .offsets = document_offsets}, + {.postings = frequencies, .offsets = frequency_offsets}, + bigrams, + document_lengths, + tl::make_optional(metadata.avg_document_length), + std::move(max_scores), + std::move(block_max_scores), + {}, + std::move(source), + std::move(document_readers), + std::move(payload_readers)); +} // namespace pisa::v1 -template +template [[nodiscard]] inline auto scored_index_runner(IndexMetadata metadata, - std::tuple readers) + DocumentReaders document_readers, + PayloadReaders payload_readers) { MMapSource source; auto documents = source_span(source, metadata.documents.postings); @@ -161,26 +180,38 @@ template auto document_offsets = source_span(source, metadata.documents.offsets); auto score_offsets = source_span(source, metadata.scores.front().offsets); auto document_lengths = source_span(source, metadata.document_lengths_path); - tl::optional> bigram_document_offsets{}; - tl::optional, 2>> bigram_score_offsets{}; - tl::optional> bigram_documents{}; - tl::optional, 2>> bigram_scores{}; - tl::optional const>> bigram_mapping{}; - if (metadata.bigrams && not metadata.bigrams->scores.empty()) { - bigram_document_offsets = - source_span(source, metadata.bigrams->documents.offsets); - 
bigram_score_offsets = { - source_span(source, metadata.bigrams->scores[0].first.offsets), - source_span(source, metadata.bigrams->scores[0].second.offsets)}; - bigram_documents = source_span(source, metadata.bigrams->documents.postings); - bigram_scores = { - source_span(source, metadata.bigrams->scores[0].first.postings), - source_span(source, metadata.bigrams->scores[0].second.postings)}; - auto mapping_span = source_span(source, metadata.bigrams->mapping); - bigram_mapping = gsl::span const>( - reinterpret_cast const*>(mapping_span.data()), - mapping_span.size() / (sizeof(TermId) * 2)); - } + auto bigrams = [&]() -> tl::optional { + gsl::span bigram_document_offsets{}; + std::array, 2> bigram_score_offsets{}; + gsl::span bigram_documents{}; + std::array, 2> bigram_scores{}; + gsl::span const> bigram_mapping{}; + if (metadata.bigrams && not metadata.bigrams->scores.empty()) { + bigram_document_offsets = + source_span(source, metadata.bigrams->documents.offsets); + bigram_score_offsets = { + source_span(source, metadata.bigrams->scores[0].first.offsets), + source_span(source, metadata.bigrams->scores[0].second.offsets)}; + bigram_documents = source_span(source, metadata.bigrams->documents.postings); + bigram_scores = { + source_span(source, metadata.bigrams->scores[0].first.postings), + source_span(source, metadata.bigrams->scores[0].second.postings)}; + auto mapping_span = source_span(source, metadata.bigrams->mapping); + bigram_mapping = gsl::span const>( + reinterpret_cast const*>(mapping_span.data()), + mapping_span.size() / (sizeof(TermId) * 2)); + return BigramData{ + .documents = {.postings = bigram_documents, .offsets = bigram_document_offsets}, + .payloads = + std::array{ + PostingData{.postings = std::get<0>(bigram_scores), + .offsets = std::get<0>(bigram_score_offsets)}, + PostingData{.postings = std::get<1>(bigram_scores), + .offsets = std::get<0>(bigram_score_offsets)}}, + .mapping = bigram_mapping}; + } + return tl::nullopt; + }(); gsl::span 
quantized_max_scores; if (not metadata.quantized_max_scores.empty()) { // TODO(michal): support many precomputed scores @@ -188,21 +219,65 @@ template quantized_max_scores = source_span(source, file); } } - return IndexRunner(document_offsets, - score_offsets, - bigram_document_offsets, - bigram_score_offsets, - documents, - scores, - bigram_documents, - bigram_scores, - document_lengths, - tl::make_optional(metadata.avg_document_length), - {}, - quantized_max_scores, - bigram_mapping, - std::move(source), - std::move(readers)); + return IndexRunner( + {.postings = documents, .offsets = document_offsets}, + {.postings = scores, .offsets = score_offsets}, + bigrams, + document_lengths, + tl::make_optional(metadata.avg_document_length), + {}, + {}, + quantized_max_scores, + std::move(source), + std::move(document_readers), + std::move(payload_readers)); } } // namespace pisa::v1 + +namespace YAML { +template <> +struct convert<::pisa::v1::PostingFilePaths> { + static Node encode(const ::pisa::v1::PostingFilePaths& rhs) + { + Node node; + node["file"] = rhs.postings; + node["offsets"] = rhs.offsets; + return node; + } + + static bool decode(const Node& node, ::pisa::v1::PostingFilePaths& rhs) + { + if (!node.IsMap()) { + return false; + } + + rhs.postings = node["file"].as(); + rhs.offsets = node["offsets"].as(); + return true; + } +}; + +template <> +struct convert<::pisa::v1::UnigramFilePaths> { + static Node encode(const ::pisa::v1::UnigramFilePaths& rhs) + { + Node node; + node["documents"] = convert<::pisa::v1::PostingFilePaths>::encode(rhs.documents); + node["payloads"] = convert<::pisa::v1::PostingFilePaths>::encode(rhs.payloads); + return node; + } + + static bool decode(const Node& node, ::pisa::v1::UnigramFilePaths& rhs) + { + if (!node.IsMap()) { + return false; + } + + rhs.documents = node["documents"].as<::pisa::v1::PostingFilePaths>(); + rhs.payloads = node["payloads"].as<::pisa::v1::PostingFilePaths>(); + return true; + } +}; + +} // namespace YAML diff 
--git a/include/pisa/v1/io.hpp b/include/pisa/v1/io.hpp index 1409714d7..c413ef324 100644 --- a/include/pisa/v1/io.hpp +++ b/include/pisa/v1/io.hpp @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include diff --git a/include/pisa/v1/query.hpp b/include/pisa/v1/query.hpp index 6d0665a16..14fd92dcc 100644 --- a/include/pisa/v1/query.hpp +++ b/include/pisa/v1/query.hpp @@ -2,10 +2,10 @@ #include #include -#include +//#include #include -#include +#include #include #include #include @@ -100,7 +100,7 @@ struct Query { void add_selections(gsl::span const> selections); [[nodiscard]] auto filtered_terms(std::bitset<64> selection) const -> std::vector; - [[nodiscard]] auto to_json() const -> nlohmann::json; + [[nodiscard]] auto to_json() const -> std::unique_ptr; [[nodiscard]] static auto from_json(std::string_view) -> Query; [[nodiscard]] static auto from_plain(std::string_view) -> Query; diff --git a/include/pisa/v1/raw_cursor.hpp b/include/pisa/v1/raw_cursor.hpp index b23229ccd..d18e18b1e 100644 --- a/include/pisa/v1/raw_cursor.hpp +++ b/include/pisa/v1/raw_cursor.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -152,4 +153,17 @@ struct CursorTraits> { constexpr static auto encoding() -> std::uint32_t { return EncodingId::Raw + sizeof(T); } }; +extern template struct RawCursor; +extern template struct RawCursor; +extern template struct RawCursor; +extern template struct RawReader; +extern template struct RawReader; +extern template struct RawReader; +extern template struct RawWriter; +extern template struct RawWriter; +extern template struct RawWriter; +extern template struct CursorTraits>; +extern template struct CursorTraits>; +extern template struct CursorTraits>; + } // namespace pisa::v1 diff --git a/include/pisa/v1/score_index.hpp b/include/pisa/v1/score_index.hpp index 43f131ae5..d81d412e8 100644 --- a/include/pisa/v1/score_index.hpp +++ b/include/pisa/v1/score_index.hpp @@ -2,10 +2,57 @@ #include +#include + #include 
"v1/index_metadata.hpp" namespace pisa::v1 { +template +auto score_index(Index const& index, + std::basic_ostream& os, + Writer writer, + Scorer scorer, + Quantizer&& quantizer, + Callback&& callback) -> std::vector +{ + PostingBuilder score_builder(writer); + score_builder.write_header(os); + std::for_each(boost::counting_iterator(0), + boost::counting_iterator(index.num_terms()), + [&](auto term) { + for_each(index.scoring_cursor(term, scorer), [&](auto&& cursor) { + score_builder.accumulate(quantizer(cursor.payload())); + }); + score_builder.flush_segment(os); + callback(); + }); + return std::move(score_builder.offsets()); +} + +template +auto score_index(Index const& index, std::basic_ostream& os, Writer writer, Scorer scorer) + -> std::vector +{ + PostingBuilder score_builder(writer); + score_builder.write_header(os); + std::for_each(boost::counting_iterator(0), + boost::counting_iterator(index.num_terms()), + [&](auto term) { + for_each(index.scoring_cursor(term, scorer), + [&](auto& cursor) { score_builder.accumulate(cursor.payload()); }); + score_builder.flush_segment(os); + }); + return std::move(score_builder.offsets()); +} + auto score_index(IndexMetadata meta, std::size_t threads) -> IndexMetadata; +auto bm_score_index(IndexMetadata meta, std::size_t block_size, std::size_t threads) + -> IndexMetadata; } // namespace pisa::v1 diff --git a/include/pisa/v1/sequence_cursor.hpp b/include/pisa/v1/sequence_cursor.hpp new file mode 100644 index 000000000..bb8f12d70 --- /dev/null +++ b/include/pisa/v1/sequence_cursor.hpp @@ -0,0 +1,326 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "codec/block_codecs.hpp" +#include "util/likely.hpp" +#include "v1/bit_cast.hpp" +#include "v1/cursor_traits.hpp" +#include "v1/encoding_traits.hpp" +#include "v1/types.hpp" +#include "v1/unaligned_span.hpp" + +namespace pisa::v1 { + +/// Uncompressed example of implementation of a single value cursor. 
+template +struct SequenceCursor { + using value_type = std::uint32_t; + + /// Creates a cursor from the encoded bytes. + explicit constexpr SequenceCursor(gsl::span encoded_blocks, + UnalignedSpan block_endpoints, + UnalignedSpan block_last_values, + std::uint32_t length, + std::uint32_t num_blocks) + : m_encoded_blocks(encoded_blocks), + m_block_endpoints(block_endpoints), + m_block_last_values(block_last_values), + m_length(length), + m_num_blocks(num_blocks), + m_current_block( + {.number = 0, + .offset = 0, + .length = std::min(length, static_cast(Codec::block_size)), + .last_value = m_block_last_values[0]}) + { + static_assert(DeltaEncoded, + "Cannot initialize block_last_values for not delta-encoded list"); + m_decoded_block.resize(Codec::block_size); + reset(); + } + + /// Creates a cursor from the encoded bytes. + explicit constexpr SequenceCursor(gsl::span encoded_blocks, + UnalignedSpan block_endpoints, + std::uint32_t length, + std::uint32_t num_blocks) + : m_encoded_blocks(encoded_blocks), + m_block_endpoints(block_endpoints), + m_length(length), + m_num_blocks(num_blocks), + m_current_block( + {.number = 0, + .offset = 0, + .length = std::min(length, static_cast(Codec::block_size)), + .last_value = 0}) + { + static_assert(not DeltaEncoded, "Must initialize block_last_values for delta-encoded list"); + m_decoded_block.resize(Codec::block_size); + reset(); + } + + constexpr SequenceCursor(SequenceCursor const&) = default; + constexpr SequenceCursor(SequenceCursor&&) noexcept = default; + constexpr SequenceCursor& operator=(SequenceCursor const&) = default; + constexpr SequenceCursor& operator=(SequenceCursor&&) noexcept = default; + ~SequenceCursor() = default; + + void reset() { decode_and_update_block(0); } + + /// Dereferences the current value. + [[nodiscard]] constexpr auto operator*() const -> value_type { return m_current_value; } + + /// Alias for `operator*()`. 
+ [[nodiscard]] constexpr auto value() const noexcept -> value_type { return *(*this); } + + /// Advances the cursor to the next position. + constexpr void advance() + { + m_current_block.offset += 1; + if (PISA_UNLIKELY(m_current_block.offset == m_current_block.length)) { + if (m_current_block.number + 1 == m_num_blocks) { + m_current_value = sentinel(); + return; + } + decode_and_update_block(m_current_block.number + 1); + } else { + if constexpr (DeltaEncoded) { + m_current_value += m_decoded_block[m_current_block.offset] + 1U; + } else { + m_current_value = m_decoded_block[m_current_block.offset] + 1U; + } + } + } + + /// Moves the cursor to the position `pos`. + constexpr void advance_to_position(std::uint32_t pos) + { + Expects(pos >= position()); + auto block = pos / Codec::block_size; + if (PISA_UNLIKELY(block != m_current_block.number)) { + decode_and_update_block(block); + } + while (position() < pos) { + if constexpr (DeltaEncoded) { + m_current_value += m_decoded_block[++m_current_block.offset] + 1U; + } else { + m_current_value = m_decoded_block[++m_current_block.offset] + 1U; + } + } + } + + /// Moves the cursor to the next value equal or greater than `value`. + constexpr void advance_to_geq(value_type value) + { + // static_assert(DeltaEncoded, "Cannot call advance_to_geq on a not delta-encoded list"); + // TODO(michal): This should be `static_assert` like above. But currently, + // it would not compile. What needs to be done is separating document + // and payload readers for the index runner. 
+ assert(DeltaEncoded); + if (PISA_UNLIKELY(value > m_current_block.last_value)) { + if (value > m_block_last_values.back()) { + m_current_value = sentinel(); + return; + } + auto block = m_current_block.number + 1U; + while (m_block_last_values[block] < value) { + ++block; + } + decode_and_update_block(block); + } + + while (m_current_value < value) { + m_current_value += m_decoded_block[++m_current_block.offset] + 1U; + Ensures(m_current_block.offset < m_current_block.length); + } + } + + ///// Returns `true` if there is no elements left. + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return position() == m_length; } + + /// Returns the current position. + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return m_current_block.number * Codec::block_size + m_current_block.offset; + } + + ///// Returns the number of elements in the list. + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_length; } + + /// The sentinel value, such that `value() != nullopt` is equivalent to `*(*this) < sentinel()`. 
+ [[nodiscard]] constexpr auto sentinel() const -> value_type + { + return std::numeric_limits::max(); + } + + private: + + //gsl::span m_encoded_blocks; + //UnalignedSpan m_block_endpoints; + //UnalignedSpan m_block_last_values{}; + //std::vector m_decoded_block; + + //std::uint32_t m_length; + //std::uint32_t m_num_blocks; + //Block m_current_block{}; + //value_type m_current_value{}; +}; + +// template +// constexpr auto block_encoding_type() -> std::uint32_t +//{ +// if constexpr (DeltaEncoded) { +// return EncodingId::BlockDelta; +// } else { +// return EncodingId::Block; +// } +//} + +template +struct SequenceReader { + using value_type = std::uint32_t; + + [[nodiscard]] auto read(gsl::span bytes) const -> SequenceCursor + { + // std::uint32_t length; + // auto begin = reinterpret_cast(bytes.data()); + // auto after_length_ptr = pisa::TightVariableByte::decode(begin, &length, 1); + // auto length_byte_size = std::distance(begin, after_length_ptr); + // auto num_blocks = ceil_div(length, Codec::block_size); + // UnalignedSpan block_last_values; + // if constexpr (DeltaEncoded) { + // block_last_values = UnalignedSpan( + // bytes.subspan(length_byte_size, num_blocks * sizeof(value_type))); + //} + // UnalignedSpan block_endpoints( + // bytes.subspan(length_byte_size + block_last_values.byte_size(), + // (num_blocks - 1) * sizeof(value_type))); + // auto encoded_blocks = bytes.subspan(length_byte_size + block_last_values.byte_size() + // + block_endpoints.byte_size()); + // if constexpr (DeltaEncoded) { + // return SequenceCursor( + // encoded_blocks, block_endpoints, block_last_values, length, num_blocks); + //} else { + // return SequenceCursor( + // encoded_blocks, block_endpoints, length, num_blocks); + //} + } + + constexpr static auto encoding() -> std::uint32_t + { + return EncodingId::Sequence | encoding_traits::encoding_tag::encoding(); + } +}; + +template +struct SequenceWriter { + using value_type = std::uint32_t; + + constexpr static auto encoding() 
-> std::uint32_t + { + return EncodingId::Sequence | encoding_traits::encoding_tag::encoding(); + } + + void push(value_type const& posting) + { + // if constexpr (DeltaEncoded) { + // if (posting < m_last_value) { + // throw std::runtime_error( + // fmt::format("Delta-encoded sequences must be monotonic, but {} < {}", + // posting, + // m_last_value)); + // } + //} + // m_postings.push_back(posting); + // m_last_value = posting; + } + + template + [[nodiscard]] auto write(std::basic_ostream& os) const -> std::size_t + { + // std::vector buffer; + // std::uint32_t length = m_postings.size(); + // TightVariableByte::encode_single(length, buffer); + // auto block_size = Codec::block_size; + // auto num_blocks = ceil_div(length, block_size); + // auto begin_block_maxs = buffer.size(); + // auto begin_block_endpoints = [&]() { + // if constexpr (DeltaEncoded) { + // return begin_block_maxs + 4U * num_blocks; + // } else { + // return begin_block_maxs; + // } + //}(); + // auto begin_blocks = begin_block_endpoints + 4U * (num_blocks - 1); + // buffer.resize(begin_blocks); + + // auto iter = m_postings.begin(); + // std::vector block_buffer(block_size); + // std::uint32_t last_value(-1); + // std::uint32_t block_base = 0; + // for (auto block = 0; block < num_blocks; ++block) { + // auto current_block_size = + // ((block + 1) * block_size <= length) ? 
block_size : (length % block_size); + + // std::for_each(block_buffer.begin(), + // std::next(block_buffer.begin(), current_block_size), + // [&](auto&& elem) { + // if constexpr (DeltaEncoded) { + // auto value = *iter++; + // elem = value - (last_value + 1); + // last_value = value; + // } else { + // elem = *iter++ - 1; + // } + // }); + + // if constexpr (DeltaEncoded) { + // std::memcpy( + // &buffer[begin_block_maxs + 4U * block], &last_value, sizeof(last_value)); + // auto size = buffer.size(); + // Codec::encode(block_buffer.data(), + // last_value - block_base - (current_block_size - 1), + // current_block_size, + // buffer); + // } else { + // Codec::encode(block_buffer.data(), std::uint32_t(-1), current_block_size, buffer); + // } + // if (block != num_blocks - 1) { + // std::size_t endpoint = buffer.size() - begin_blocks; + // std::memcpy( + // &buffer[begin_block_endpoints + 4U * block], &endpoint, sizeof(last_value)); + // } + // block_base = last_value + 1; + //} + // os.write(reinterpret_cast(buffer.data()), buffer.size()); + // return buffer.size(); + } + + void reset() + { + m_postings.clear(); + m_last_value = 0; + } + + private: + std::vector m_postings{}; + value_type m_last_value = 0U; +}; + +template +struct CursorTraits> { + using Writer = SequenceWriter; + using Reader = SequenceReader; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/types.hpp b/include/pisa/v1/types.hpp index 92a06bc69..4b0881626 100644 --- a/include/pisa/v1/types.hpp +++ b/include/pisa/v1/types.hpp @@ -18,9 +18,10 @@ using Result = std::pair; using ByteOStream = std::basic_ostream; enum EncodingId { - Raw = 0xda43, + Raw = 0xDA43, BlockDelta = 0xEF00, Block = 0xFF00, + Sequence = 0xDF00, SimdBP = 0x0001, Varbyte = 0x0002 }; diff --git a/include/pisa/v1/unaligned_span.hpp b/include/pisa/v1/unaligned_span.hpp index d7f200d84..2f57b3c97 100644 --- a/include/pisa/v1/unaligned_span.hpp +++ b/include/pisa/v1/unaligned_span.hpp @@ -142,6 +142,7 @@ struct 
UnalignedSpan { [[nodiscard]] auto size() const -> std::size_t { return m_bytes.size() / sizeof(value_type); } [[nodiscard]] auto byte_size() const -> std::size_t { return m_bytes.size(); } [[nodiscard]] auto bytes() const -> gsl::span { return m_bytes; } + [[nodiscard]] auto empty() const -> bool { return m_bytes.empty(); } private: gsl::span m_bytes{}; diff --git a/include/pisa/v1/wand.hpp b/include/pisa/v1/wand.hpp new file mode 100644 index 000000000..603633b0d --- /dev/null +++ b/include/pisa/v1/wand.hpp @@ -0,0 +1,444 @@ +#pragma once + +#include +#include +#include + +#include +#include + +#include "util/compiler_attribute.hpp" +#include "v1/algorithm.hpp" +#include "v1/cursor/scoring_cursor.hpp" +#include "v1/cursor_accumulator.hpp" +#include "v1/query.hpp" + +namespace pisa::v1 { + +template +struct BlockMaxWandJoin; + +template +struct WandJoin { + using cursor_type = typename CursorContainer::value_type; + using payload_type = Payload; + using value_type = std::decay_t())>; + + friend BlockMaxWandJoin; + + using iterator_category = + typename std::iterator_traits::iterator_category; + static_assert(std::is_base_of(), + "cursors must be stored in a random access container"); + + constexpr WandJoin(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold) + : m_cursors(std::move(cursors)), + m_cursor_pointers(m_cursors.size()), + m_init(std::move(init)), + m_accumulate(std::move(accumulate)), + m_above_threshold(std::move(above_threshold)), + m_size(std::nullopt) + { + initialize(); + } + + constexpr WandJoin(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold, + Inspect* inspect) + : m_cursors(std::move(cursors)), + m_cursor_pointers(m_cursors.size()), + m_init(std::move(init)), + m_accumulate(std::move(accumulate)), + m_above_threshold(std::move(above_threshold)), + m_size(std::nullopt), + m_inspect(inspect) + { + initialize(); + } + + void initialize() + { + if 
(m_cursors.empty()) { + m_current_value = sentinel(); + m_current_payload = m_init; + } + std::transform(m_cursors.begin(), + m_cursors.end(), + m_cursor_pointers.begin(), + [](auto&& cursor) { return &cursor; }); + + std::sort(m_cursor_pointers.begin(), m_cursor_pointers.end(), [](auto&& lhs, auto&& rhs) { + return lhs->value() < rhs->value(); + }); + + m_next_docid = min_value(m_cursors); + m_sentinel = min_sentinel(m_cursors); + advance(); + } + + [[nodiscard]] constexpr auto operator*() const noexcept -> value_type + { + return m_current_value; + } + [[nodiscard]] constexpr auto value() const noexcept -> value_type { return m_current_value; } + [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& + { + return m_current_payload; + } + [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_sentinel; } + + PISA_ALWAYSINLINE void advance() + { + while (true) { + auto pivot = find_pivot(); + if (pivot == m_cursor_pointers.end()) { + m_current_value = sentinel(); + return; + } + + auto pivot_docid = (*pivot)->value(); + if (pivot_docid == m_cursor_pointers.front()->value()) { + m_current_value = pivot_docid; + m_current_payload = m_init; + + [&]() { + auto iter = m_cursor_pointers.begin(); + for (; iter != m_cursor_pointers.end(); ++iter) { + auto* cursor = *iter; + if (cursor->value() != pivot_docid) { + break; + } + m_current_payload = m_accumulate(m_current_payload, *cursor); + cursor->advance(); + } + return iter; + }(); + + auto by_docid = [](auto&& lhs, auto&& rhs) { return lhs->value() < rhs->value(); }; + std::sort(m_cursor_pointers.begin(), m_cursor_pointers.end(), by_docid); + return; + } + + auto next_list = std::distance(m_cursor_pointers.begin(), pivot); + for (; m_cursor_pointers[next_list]->value() == pivot_docid; --next_list) { + } + m_cursor_pointers[next_list]->advance_to_geq(pivot_docid); + bubble_down(next_list); + } + } + + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; // 
TODO(michal) + [[nodiscard]] constexpr auto empty() const noexcept -> bool + { + return m_current_value >= sentinel(); + } + + private: + PISA_ALWAYSINLINE void bubble_down(std::size_t list_idx) + { + for (size_t idx = list_idx + 1; idx < m_cursor_pointers.size(); idx += 1) { + if (m_cursor_pointers[idx]->value() < m_cursor_pointers[idx - 1]->value()) { + std::swap(m_cursor_pointers[idx], m_cursor_pointers[idx - 1]); + } else { + break; + } + } + } + + PISA_ALWAYSINLINE auto find_pivot() + { + auto upper_bound = 0.0F; + for (auto pivot = m_cursor_pointers.begin(); pivot != m_cursor_pointers.end(); ++pivot) { + auto&& cursor = **pivot; + if (cursor.empty()) { + break; + } + upper_bound += cursor.max_score(); + if (m_above_threshold(upper_bound)) { + auto pivot_docid = (*pivot)->value(); + while (std::next(pivot) != m_cursor_pointers.end()) { + if ((*std::next(pivot))->value() != pivot_docid) { + break; + } + pivot = std::next(pivot); + } + return pivot; + } + } + return m_cursor_pointers.end(); + } + + CursorContainer m_cursors; + std::vector m_cursor_pointers; + payload_type m_init; + AccumulateFn m_accumulate; + ThresholdFn m_above_threshold; + std::optional m_size; + + value_type m_current_value{}; + value_type m_sentinel{}; + payload_type m_current_payload{}; + std::uint32_t m_next_docid{}; + payload_type m_previous_threshold{}; + + Inspect* m_inspect; +}; + +template +struct BlockMaxWandJoin { + using cursor_type = typename CursorContainer::value_type; + using payload_type = Payload; + using value_type = std::decay_t())>; + + using iterator_category = + typename std::iterator_traits::iterator_category; + static_assert(std::is_base_of(), + "cursors must be stored in a random access container"); + + constexpr BlockMaxWandJoin(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold) + : m_wand_join(std::move(cursors), init, std::move(accumulate), std::move(above_threshold)) + { + } + + constexpr 
BlockMaxWandJoin(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold, + Inspect* inspect) + : WandJoin( + std::move(cursors), init, std::move(accumulate), std::move(above_threshold), inspect) + { + } + + [[nodiscard]] constexpr auto operator*() const noexcept -> value_type + { + return m_wand_join.value(); + } + [[nodiscard]] constexpr auto value() const noexcept -> value_type + { + return m_wand_join.value(); + } + [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& + { + return m_wand_join.payload(); + } + [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t + { + return m_wand_join.sentinel(); + } + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_wand_join.empty(); } + + constexpr void advance() + { + while (true) { + auto pivot = m_wand_join.find_pivot(); + if (pivot == m_wand_join.m_cursor_pointers.end()) { + m_wand_join.m_current_value = m_wand_join.sentinel(); + return; + } + + auto pivot_docid = (*pivot)->value(); + // auto block_upper_bound = std::accumulate( + // m_wand_join.m_cursor_pointers.begin(), + // std::next(pivot), + // 0.0F, + // [&](auto acc, auto* cursor) { return acc + cursor->block_max_score(pivot_docid); + // }); + // if (not m_wand_join.m_above_threshold(block_upper_bound)) { + // block_max_advance(pivot, pivot_docid); + // continue; + //} + if (pivot_docid == m_wand_join.m_cursor_pointers.front()->value()) { + m_wand_join.m_current_value = pivot_docid; + m_wand_join.m_current_payload = m_wand_join.m_init; + + [&]() { + auto iter = m_wand_join.m_cursor_pointers.begin(); + for (; iter != m_wand_join.m_cursor_pointers.end(); ++iter) { + auto* cursor = *iter; + if (cursor->value() != pivot_docid) { + break; + } + m_wand_join.m_current_payload = + m_wand_join.m_accumulate(m_wand_join.m_current_payload, *cursor); + cursor->advance(); + } + return iter; + }(); + // for (auto* cursor : m_wand_join.m_cursor_pointers) { + // if 
(cursor->value() != pivot_docid) { + // break; + // } + // m_wand_join.m_current_payload = + // m_wand_join.m_accumulate(m_wand_join.m_current_payload, *cursor); + // block_upper_bound -= cursor->block_max_score() - cursor->payload(); + // if (not m_wand_join.m_above_threshold(block_upper_bound)) { + // break; + // } + //} + + // for (auto* cursor : m_wand_join.m_cursor_pointers) { + // if (cursor->value() != pivot_docid) { + // break; + // } + // cursor->advance(); + //} + + auto by_docid = [](auto&& lhs, auto&& rhs) { return lhs->value() < rhs->value(); }; + std::sort(m_wand_join.m_cursor_pointers.begin(), + m_wand_join.m_cursor_pointers.end(), + by_docid); + return; + } + + auto next_list = std::distance(m_wand_join.m_cursor_pointers.begin(), pivot); + for (; m_wand_join.m_cursor_pointers[next_list]->value() == pivot_docid; --next_list) { + } + m_wand_join.m_cursor_pointers[next_list]->advance_to_geq(pivot_docid); + m_wand_join.bubble_down(next_list); + } + } + + private: + template + void block_max_advance(Iter pivot, DocId pivot_id) + { + auto next_list = std::max_element( + m_wand_join.m_cursor_pointers.begin(), std::next(pivot), [](auto* lhs, auto* rhs) { + return lhs->max_score() < rhs->max_score(); + }); + + auto next_docid = + (*std::min_element(m_wand_join.m_cursor_pointers.begin(), + std::next(pivot), + [](auto* lhs, auto* rhs) { + return lhs->block_max_docid() < rhs->block_max_docid(); + })) + ->value(); + next_docid += 1; + + if (auto iter = std::next(pivot); iter != m_wand_join.m_cursor_pointers.end()) { + if (auto docid = (*iter)->value(); docid < next_docid) { + next_docid = docid; + } + } + + if (next_docid <= pivot_id) { + next_docid = pivot_id + 1; + } + + (*next_list)->advance_to_geq(next_docid); + m_wand_join.bubble_down(std::distance(m_wand_join.m_cursor_pointers.begin(), next_list)); + } + + WandJoin m_wand_join; +}; + +template +auto join_wand(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn threshold) +{ + 
return WandJoin( + std::move(cursors), std::move(init), std::move(accumulate), std::move(threshold)); +} + +template +auto join_wand(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn threshold, + Inspect* inspect) +{ + return WandJoin( + std::move(cursors), std::move(init), std::move(accumulate), std::move(threshold), inspect); +} + +template +auto join_block_max_wand(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn threshold) +{ + return BlockMaxWandJoin( + std::move(cursors), std::move(init), std::move(accumulate), std::move(threshold)); +} + +template +auto join_block_max_wand(CursorContainer cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn threshold, + Inspect* inspect) +{ + return BlockMaxWandJoin( + std::move(cursors), std::move(init), std::move(accumulate), std::move(threshold), inspect); +} + +template +auto wand(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) +{ + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { + return topk; + } + auto cursors = index.max_scored_cursors(gsl::make_span(term_ids), scorer); + if (query.threshold()) { + topk.set_threshold(*query.threshold()); + } + auto joined = join_wand(std::move(cursors), 0.0F, accumulate::Add{}, [&](auto score) { + return topk.would_enter(score); + }); + v1::for_each(joined, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); + return topk; +} + +template +auto bmw(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) +{ + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { + return topk; + } + auto cursors = index.block_max_scored_cursors(gsl::make_span(term_ids), scorer); + if (query.threshold()) { + topk.set_threshold(*query.threshold()); + } + auto joined = join_block_max_wand(std::move(cursors), 0.0F, accumulate::Add{}, [&](auto score) { + return topk.would_enter(score); + }); + v1::for_each(joined, [&](auto& 
cursor) { topk.insert(cursor.payload(), cursor.value()); });
+    return topk;
+}
+
+} // namespace pisa::v1
diff --git a/script/cw09b-est.sh b/script/cw09b-est.sh
new file mode 100644
index 000000000..8739d3ee5
--- /dev/null
+++ b/script/cw09b-est.sh
@@ -0,0 +1,100 @@
+PISA_BIN="/home/michal/pisa/build/bin"
+INTERSECT_BIN="/home/michal/intersect/target/release/intersect"
+BINARY_FREQ_COLL="/home/amallia/cw09b/CW09B.inv"
+FWD="/home/amallia/cw09b/CW09B.fwd"
+BASENAME="/data/michal/work/v1/cw09b/cw09b"
+THREADS=4
+TYPE="block_simdbp" # v0.6
+ENCODING="simdbp" # v1
+QUERIES="/home/michal/biscorer/data/queries/05.efficiency_topics.no_dups.1k"
+#QUERIES="/home/michal/topics.web.51-200.jl"
+K=1000
+OUTPUT_DIR="/data/michal/intersect/cw09b-est"
+#OUTPUT_DIR="/data/michal/intersect/cw09b-est-top20"
+FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered"
+#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.web.51-200"
+#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.05.efficiency_topics.no_dups.1k"
+THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top15.bm25.05.efficiency_topics.no_dups.1k"
+#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.efficiency_topics.no_dups.1k"
+
+set -e
+set -x
+
+## Compress an inverted index in `binary_freq_collection` format.
+#${PISA_BIN}/compress -i ${BINARY_FREQ_COLL} --fwd ${FWD} -o ${BASENAME} -j ${THREADS} -e ${ENCODING}
+
+# This will produce both quantized scores and max scores (both quantized and not).
+#${PISA_BIN}/score -i "${BASENAME}.yml" -j ${THREADS}
+
+# This will produce both quantized scores and max scores (both quantized and not).
+#${PISA_BIN}/bmscore -i "${BASENAME}.yml" -j ${THREADS} --block-size 128
+
+# Filter out queries without existing terms.
+paste -d: ${QUERIES} ${THRESHOLDS} \ + | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ + | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml > ${FILTERED_QUERIES} + +# This will produce both quantized scores and max scores (both quantized and not). +#${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} + +# Extract intersections +${PISA_BIN}/intersection -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations --mtc 2 \ + | grep -v "\[warning\]" \ + > ${OUTPUT_DIR}/intersections.jl + +# Select unigrams +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl --max 1 > ${OUTPUT_DIR}/selections.1 +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl --max 2 > ${OUTPUT_DIR}/selections.2 +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl \ + --max 2 --scale 1.5 > ${OUTPUT_DIR}/selections.2.scaled-1.5 + +# Run benchmarks +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm wand > ${OUTPUT_DIR}/bench.wand +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm bmw > ${OUTPUT_DIR}/bench.bmw +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm maxscore > ${OUTPUT_DIR}/bench.maxscore +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm bmw \ + > ${OUTPUT_DIR}/bench.bmw-threshold +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore \ + > ${OUTPUT_DIR}/bench.maxscore-threshold +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup \ + > ${OUTPUT_DIR}/bench.maxscore-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup \ + > ${OUTPUT_DIR}/bench.unigram-union-lookup +${PISA_BIN}/query -i 
"${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup \ + > ${OUTPUT_DIR}/bench.union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm lookup-union \ + > ${OUTPUT_DIR}/bench.lookup-union +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ + --benchmark --algorithm lookup-union \ + > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 + +# Analyze +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore \ + > ${OUTPUT_DIR}/stats.maxscore-threshold +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore-union-lookup \ + > ${OUTPUT_DIR}/stats.maxscore-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --inspect --algorithm unigram-union-lookup \ + > ${OUTPUT_DIR}/stats.unigram-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup \ + > ${OUTPUT_DIR}/stats.union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ + --inspect --algorithm lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union.scaled-1.5 + +# Evaluate +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore \ + > "${OUTPUT_DIR}/eval.maxscore-threshold" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore-union-lookup \ + > "${OUTPUT_DIR}/eval.maxscore-union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q 
${OUTPUT_DIR}/selections.1 --algorithm unigram-union-lookup \ + > "${OUTPUT_DIR}/eval.unigram-union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm union-lookup \ + > "${OUTPUT_DIR}/eval.union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm lookup-union \ + > "${OUTPUT_DIR}/eval.lookup-union" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --algorithm lookup-union \ + > "${OUTPUT_DIR}/eval.lookup-union.scaled-1.5" diff --git a/script/cw09b.sh b/script/cw09b.sh index 444baa445..630b2391c 100644 --- a/script/cw09b.sh +++ b/script/cw09b.sh @@ -44,19 +44,20 @@ ${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl \ --max 2 --scale 1.5 > ${OUTPUT_DIR}/selections.2.scaled-1.5 # Run benchmarks -${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm maxscore > ${OUTPUT_DIR}/bench.maxscore -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore \ +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark \ + --algorithm maxscore > ${OUTPUT_DIR}/bench.maxscore +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore --safe \ > ${OUTPUT_DIR}/bench.maxscore-threshold -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup \ +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup --safe \ > ${OUTPUT_DIR}/bench.maxscore-union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup \ +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup --safe \ > ${OUTPUT_DIR}/bench.unigram-union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q 
${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup \ +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup --safe \ > ${OUTPUT_DIR}/bench.union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm lookup-union \ +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm lookup-union --safe \ > ${OUTPUT_DIR}/bench.lookup-union ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ - --benchmark --algorithm lookup-union \ + --benchmark --algorithm lookup-union --safe \ > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 # Analyze diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 842f15449..c1cdcacd8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -32,11 +32,11 @@ target_link_libraries(queries # CLI11 #) -add_executable(thresholds thresholds.cpp) -target_link_libraries(thresholds - pisa - CLI11 -) +#add_executable(thresholds thresholds.cpp) +#target_link_libraries(thresholds +# pisa +# CLI11 +#) #add_executable(profile_queries profile_queries.cpp) #target_link_libraries(profile_queries diff --git a/src/v1/blocked_cursor.cpp b/src/v1/blocked_cursor.cpp new file mode 100644 index 000000000..3e2bbd8c6 --- /dev/null +++ b/src/v1/blocked_cursor.cpp @@ -0,0 +1,44 @@ +#include "v1/blocked_cursor.hpp" + +namespace pisa::v1 { + +[[nodiscard]] auto BaseBlockedCursor::block_offset(size_type block) const -> offset_type +{ + return block > 0U ? 
m_block_endpoints[block - 1] : static_cast(0U); +} +[[nodiscard]] auto BaseBlockedCursor::decoded_block() -> value_type* +{ + return m_decoded_block.data(); +} +[[nodiscard]] auto BaseBlockedCursor::encoded_block(offset_type offset) -> uint8_t const* +{ + return std::next(reinterpret_cast(m_encoded_blocks.data()), offset); +} +[[nodiscard]] auto BaseBlockedCursor::length() const -> size_type { return m_length; } +[[nodiscard]] auto BaseBlockedCursor::num_blocks() const -> size_type { return m_num_blocks; } + +[[nodiscard]] auto BaseBlockedCursor::operator*() const -> value_type { return m_current_value; } +[[nodiscard]] auto BaseBlockedCursor::value() const noexcept -> value_type { return *(*this); } +[[nodiscard]] auto BaseBlockedCursor::empty() const noexcept -> bool +{ + return value() >= sentinel(); +} +[[nodiscard]] auto BaseBlockedCursor::position() const noexcept -> std::size_t +{ + return m_current_block.number * m_block_length + m_current_block.offset; +} +[[nodiscard]] auto BaseBlockedCursor::size() const -> std::size_t { return m_length; } +[[nodiscard]] auto BaseBlockedCursor::sentinel() const -> value_type +{ + return std::numeric_limits::max(); +} +[[nodiscard]] auto BaseBlockedCursor::current_block() -> Block& { return m_current_block; } +[[nodiscard]] auto BaseBlockedCursor::decoded_value(size_type n) -> value_type +{ + return m_decoded_block[n] + 1U; +} + +void BaseBlockedCursor::update_current_value(value_type val) { m_current_value = val; } +void BaseBlockedCursor::increase_current_value(value_type val) { m_current_value += val; } + +} // namespace pisa::v1 diff --git a/src/v1/index.cpp b/src/v1/index.cpp new file mode 100644 index 000000000..fe538056d --- /dev/null +++ b/src/v1/index.cpp @@ -0,0 +1,170 @@ +#include + +#include + +#include "binary_collection.hpp" +#include "v1/index.hpp" + +namespace pisa::v1 { + +[[nodiscard]] auto calc_avg_length(gsl::span const& lengths) -> float +{ + auto sum = std::accumulate(lengths.begin(), lengths.end(), 
std::uint64_t(0), std::plus{}); + return static_cast(sum) / lengths.size(); +} + +[[nodiscard]] auto compare_arrays(std::array const& lhs, + std::array const& rhs) -> bool +{ + if (std::get<0>(lhs) < std::get<0>(rhs)) { + return true; + } + if (std::get<0>(lhs) > std::get<0>(rhs)) { + return false; + } + return std::get<1>(lhs) < std::get<1>(rhs); +} + +[[nodiscard]] auto read_sizes(std::string_view basename) -> std::vector +{ + binary_collection sizes(fmt::format("{}.sizes", basename).c_str()); + auto sequence = *sizes.begin(); + return std::vector(sequence.begin(), sequence.end()); +} + +[[nodiscard]] auto BaseIndex::num_terms() const -> std::size_t +{ + return m_documents.offsets.size() - 1; +} + +[[nodiscard]] auto BaseIndex::num_documents() const -> std::size_t +{ + return m_document_lengths.size(); +} + +[[nodiscard]] auto BaseIndex::document_length(DocId docid) const -> std::uint32_t +{ + return m_document_lengths[docid]; +} + +[[nodiscard]] auto BaseIndex::avg_document_length() const -> float { return m_avg_document_length; } + +[[nodiscard]] auto BaseIndex::normalized_document_length(DocId docid) const -> float +{ + return document_length(docid) / avg_document_length(); +} + +void BaseIndex::assert_term_in_bounds(TermId term) const +{ + if (term >= num_terms()) { + throw std::invalid_argument( + fmt::format("Requested term ID out of bounds [0-{}): {}", num_terms(), term)); + } +} +[[nodiscard]] auto BaseIndex::fetch_documents(TermId term) const -> gsl::span +{ + Expects(term + 1 < m_documents.offsets.size()); + return m_documents.postings.subspan(m_documents.offsets[term], + m_documents.offsets[term + 1] - m_documents.offsets[term]); +} +[[nodiscard]] auto BaseIndex::fetch_payloads(TermId term) const -> gsl::span +{ + Expects(term + 1 < m_payloads.offsets.size()); + return m_payloads.postings.subspan(m_payloads.offsets[term], + m_payloads.offsets[term + 1] - m_payloads.offsets[term]); +} +[[nodiscard]] auto BaseIndex::fetch_bigram_documents(TermId bigram) 
const + -> gsl::span +{ + if (not m_bigrams) { + throw std::logic_error("Bigrams are missing"); + } + Expects(bigram + 1 < m_bigrams->documents.offsets.size()); + return m_bigrams->documents.postings.subspan( + m_bigrams->documents.offsets[bigram], + m_bigrams->documents.offsets[bigram + 1] - m_bigrams->documents.offsets[bigram]); +} + +template +[[nodiscard]] auto BaseIndex::fetch_bigram_payloads(TermId bigram) const + -> gsl::span +{ + if (not m_bigrams) { + throw std::logic_error("Bigrams are missing"); + } + Expects(bigram + 1 < std::get(m_bigrams->payloads).offsets.size()); + return std::get(m_bigrams->payloads) + .postings.subspan(std::get(m_bigrams->payloads).offsets[bigram], + std::get(m_bigrams->payloads).offsets[bigram + 1] + - std::get(m_bigrams->payloads).offsets[bigram]); +} + +template auto BaseIndex::fetch_bigram_payloads<0>(TermId bigram) const + -> gsl::span; +template auto BaseIndex::fetch_bigram_payloads<1>(TermId bigram) const + -> gsl::span; + +[[nodiscard]] auto BaseIndex::fetch_bigram_payloads(TermId bigram) const + -> std::array, 2> +{ + return {fetch_bigram_payloads<0>(bigram), fetch_bigram_payloads<1>(bigram)}; +} + +[[nodiscard]] auto BaseIndex::bigram_id(TermId left_term, TermId right_term) const + -> tl::optional +{ + if (not m_bigrams) { + throw std::logic_error("Bigrams are missing"); + } + if (right_term == left_term) { + throw std::logic_error("Requested bigram of two identical terms"); + } + auto bigram = std::array{left_term, right_term}; + if (right_term < left_term) { + std::swap(bigram[0], bigram[1]); + } + if (auto pos = std::lower_bound( + m_bigrams->mapping.begin(), m_bigrams->mapping.end(), bigram, compare_arrays); + pos != m_bigrams->mapping.end()) { + if (*pos == bigram) { + return tl::make_optional(std::distance(m_bigrams->mapping.begin(), pos)); + } + } + return tl::nullopt; +} + +[[nodiscard]] auto BaseIndex::max_score(std::size_t scorer_hash, TermId term) const -> float +{ + if (m_max_scores.empty()) { + throw 
std::logic_error("Missing max scores."); + } + return m_max_scores.at(scorer_hash)[term]; +} + +[[nodiscard]] auto BaseIndex::block_max_scores(std::size_t scorer_hash) const -> UnigramData const& +{ + if (auto pos = m_block_max_scores.find(scorer_hash); pos != m_block_max_scores.end()) { + return pos->second; + } + throw std::logic_error("Missing block-max scores."); +} + +[[nodiscard]] auto BaseIndex::quantized_max_score(TermId term) const -> std::uint8_t +{ + if (m_quantized_max_scores.empty()) { + throw std::logic_error("Missing quantized max scores."); + } + return m_quantized_max_scores.at(term); +} + +[[nodiscard]] auto BaseIndex::block_max_document_reader() const -> Reader> const& +{ + return m_block_max_document_reader; +} + +[[nodiscard]] auto BaseIndex::block_max_score_reader() const -> Reader> const& +{ + return m_block_max_score_reader; +} + +} // namespace pisa::v1 diff --git a/src/v1/index_builder.cpp b/src/v1/index_builder.cpp index d702fc7c2..cc7aad831 100644 --- a/src/v1/index_builder.cpp +++ b/src/v1/index_builder.cpp @@ -1,6 +1,7 @@ #include "v1/index_builder.hpp" #include "codec/simdbp.hpp" #include "v1/blocked_cursor.hpp" +#include "v1/default_index_runner.hpp" #include "v1/query.hpp" #include "v1/raw_cursor.hpp" @@ -12,7 +13,7 @@ auto collect_unique_bigrams(std::vector const& queries, { std::vector> bigrams; auto idx = 0; - for (auto query : queries) { + for (auto const& query : queries) { auto const& term_ids = query.get_term_ids(); if (term_ids.empty()) { continue; @@ -35,10 +36,7 @@ auto verify_compressed_index(std::string const& input, std::string_view output) std::vector errors; pisa::binary_freq_collection const collection(input.c_str()); auto meta = IndexMetadata::from_file(fmt::format("{}.yml", output)); - auto run = index_runner(meta, - RawReader{}, - BlockedReader<::pisa::simdbp_block, true>{}, - BlockedReader<::pisa::simdbp_block, false>{}); + auto run = index_runner(meta); ProgressStatus status( collection.size(), 
DefaultProgress("Verifying"), std::chrono::milliseconds(100)); run([&](auto&& index) { @@ -65,7 +63,11 @@ auto verify_compressed_index(std::string const& input, std::string_view output) } if (cursor.payload() != *fit) { errors.push_back( - fmt::format("Frequency mismatch for term {} at position {}", term, pos)); + fmt::format("Frequency mismatch for term {} at position {}: {} != {}", + term, + pos, + cursor.payload(), + *fit)); } cursor.advance(); ++dit; @@ -83,11 +85,7 @@ auto verify_compressed_index(std::string const& input, std::string_view output) std::vector> const& bigrams) -> std::pair { - auto run = scored_index_runner(std::move(meta), - RawReader{}, - RawReader{}, - BlockedReader<::pisa::simdbp_block, true>{}, - BlockedReader<::pisa::simdbp_block, false>{}); + auto run = scored_index_runner(std::move(meta)); std::vector> pair_mapping; auto scores_file_0 = fmt::format("{}.bigram_bm25_0", index_basename); @@ -148,10 +146,7 @@ auto build_bigram_index(IndexMetadata meta, std::vector{}, - BlockedReader<::pisa::simdbp_block, true>{}, - BlockedReader<::pisa::simdbp_block, false>{}); + auto run = index_runner(meta); std::vector> pair_mapping; auto documents_file = fmt::format("{}.bigram_documents", index_basename); diff --git a/src/v1/index_metadata.cpp b/src/v1/index_metadata.cpp index 0a8693c74..1fda06cbf 100644 --- a/src/v1/index_metadata.cpp +++ b/src/v1/index_metadata.cpp @@ -19,6 +19,7 @@ constexpr char const* LEXICON = "lexicon"; constexpr char const* TERMS = "terms"; constexpr char const* BIGRAM = "bigram"; constexpr char const* MAX_SCORES = "max_scores"; +constexpr char const* BLOCK_MAX_SCORES = "block_max_scores"; constexpr char const* QUANTIZED_MAX_SCORES = "quantized_max_scores"; [[nodiscard]] auto resolve_yml(tl::optional const& arg) -> std::string @@ -95,6 +96,13 @@ constexpr char const* QUANTIZED_MAX_SCORES = "quantized_max_scores"; } return std::map{}; }(), + .block_max_scores = + [&]() { + if (config[BLOCK_MAX_SCORES]) { + return 
config[BLOCK_MAX_SCORES].as>(); + } + return std::map{}; + }(), .quantized_max_scores = [&]() { if (config[QUANTIZED_MAX_SCORES]) { @@ -147,6 +155,11 @@ void IndexMetadata::write(std::string const& file) const root[MAX_SCORES][key] = value; } } + if (not block_max_scores.empty()) { + for (auto [key, value] : block_max_scores) { + root[BLOCK_MAX_SCORES][key] = value; + } + } if (not quantized_max_scores.empty()) { for (auto [key, value] : quantized_max_scores) { root[QUANTIZED_MAX_SCORES][key] = value; diff --git a/src/v1/io.cpp b/src/v1/io.cpp index aec0b7492..82c2db350 100644 --- a/src/v1/io.cpp +++ b/src/v1/io.cpp @@ -1,8 +1,10 @@ +#include + #include "v1/io.hpp" namespace pisa::v1 { -[[nodiscard]] auto load_bytes(std::string const &data_file) -> std::vector +[[nodiscard]] auto load_bytes(std::string const& data_file) -> std::vector { std::vector data; std::basic_ifstream in(data_file.c_str(), std::ios::binary); @@ -10,7 +12,7 @@ namespace pisa::v1 { std::streamsize size = in.tellg(); in.seekg(0, std::ios::beg); data.resize(size); - if (not in.read(reinterpret_cast(data.data()), size)) { + if (not in.read(reinterpret_cast(data.data()), size)) { throw std::runtime_error("Failed reading " + data_file); } return data; diff --git a/src/v1/query.cpp b/src/v1/query.cpp index 00e601804..d7fe741ca 100644 --- a/src/v1/query.cpp +++ b/src/v1/query.cpp @@ -123,20 +123,20 @@ template } } -[[nodiscard]] auto Query::to_json() const -> nlohmann::json +[[nodiscard]] auto Query::to_json() const -> std::unique_ptr { - json query; + auto query = std::make_unique(); if (m_id) { - query["id"] = *m_id; + (*query)["id"] = *m_id; } if (m_raw_string) { - query["query"] = *m_raw_string; + (*query)["query"] = *m_raw_string; } if (m_term_ids) { - query["term_ids"] = m_term_ids->get(); + (*query)["term_ids"] = m_term_ids->get(); } if (m_threshold) { - query["threshold"] = *m_threshold; + (*query)["threshold"] = *m_threshold; } // TODO(michal) // tl::optional m_selections{}; diff --git 
a/src/v1/raw_cursor.cpp b/src/v1/raw_cursor.cpp new file mode 100644 index 000000000..0e2dfaf04 --- /dev/null +++ b/src/v1/raw_cursor.cpp @@ -0,0 +1,18 @@ +#include "v1/raw_cursor.hpp" + +namespace pisa::v1 { + +template struct RawCursor; +template struct RawCursor; +template struct RawCursor; +template struct RawReader; +template struct RawReader; +template struct RawReader; +template struct RawWriter; +template struct RawWriter; +template struct RawWriter; +template struct CursorTraits>; +template struct CursorTraits>; +template struct CursorTraits>; + +} // namespace pisa::v1 diff --git a/src/v1/score_index.cpp b/src/v1/score_index.cpp index b3fb44468..6b202a16e 100644 --- a/src/v1/score_index.cpp +++ b/src/v1/score_index.cpp @@ -4,13 +4,14 @@ #include "codec/simdbp.hpp" #include "v1/blocked_cursor.hpp" +#include "v1/default_index_runner.hpp" #include "v1/index_builder.hpp" #include "v1/index_metadata.hpp" #include "v1/progress_status.hpp" #include "v1/raw_cursor.hpp" #include "v1/score_index.hpp" +#include "v1/scorer/bm25.hpp" -using pisa::v1::BlockedReader; using pisa::v1::DefaultProgress; using pisa::v1::IndexMetadata; using pisa::v1::PostingFilePaths; @@ -24,10 +25,7 @@ namespace pisa::v1 { auto score_index(IndexMetadata meta, std::size_t threads) -> IndexMetadata { - auto run = index_runner(meta, - RawReader{}, - BlockedReader<::pisa::simdbp_block, true>{}, - BlockedReader<::pisa::simdbp_block, false>{}); + auto run = index_runner(meta); auto const& index_basename = meta.get_basename(); auto postings_path = fmt::format("{}.bm25", index_basename); auto offsets_path = fmt::format("{}.bm25_offsets", index_basename); @@ -92,4 +90,55 @@ auto score_index(IndexMetadata meta, std::size_t threads) -> IndexMetadata return meta; } +// TODO: Use multiple threads +auto bm_score_index(IndexMetadata meta, std::size_t block_size, std::size_t threads) + -> IndexMetadata +{ + auto run = index_runner(meta); + auto const& index_basename = meta.get_basename(); + auto prefix = 
fmt::format("{}.bm25_block_max", meta.get_basename()); + UnigramFilePaths paths{ + .documents = PostingFilePaths{.postings = fmt::format("{}_documents", prefix), + .offsets = fmt::format("{}_document_offsets", prefix)}, + .payloads = PostingFilePaths{.postings = fmt::format("{}.scores", prefix), + .offsets = fmt::format("{}.score_offsets", prefix)}, + }; + run([&](auto&& index) { + auto scorer = make_bm25(index); + ProgressStatus status(index.num_terms(), + DefaultProgress("Calculating max-blocks"), + std::chrono::milliseconds(100)); + std::ofstream document_out(paths.documents.postings); + std::ofstream score_out(paths.payloads.postings); + PostingBuilder document_builder(RawWriter{}); + PostingBuilder score_builder(RawWriter{}); + document_builder.write_header(document_out); + score_builder.write_header(score_out); + for (TermId term_id = 0; term_id < index.num_terms(); term_id += 1) { + auto cursor = index.scored_cursor(term_id, scorer); + while (not cursor.empty()) { + auto max_score = 0.0F; + auto last_docid = 0; + for (auto idx = 0; idx < block_size && not cursor.empty(); ++idx) { + if (auto score = cursor.payload(); score > max_score) { + max_score = score; + } + last_docid = cursor.value(); + cursor.advance(); + } + document_builder.accumulate(last_docid); + score_builder.accumulate(max_score); + } + document_builder.flush_segment(document_out); + score_builder.flush_segment(score_out); + status += 1; + } + write_span(gsl::make_span(document_builder.offsets()), paths.documents.offsets); + write_span(gsl::make_span(score_builder.offsets()), paths.payloads.offsets); + }); + meta.block_max_scores["bm25"] = paths; + meta.update(); + return meta; +} + } // namespace pisa::v1 diff --git a/test/v1/index_fixture.hpp b/test/v1/index_fixture.hpp index 64ceb6926..f381089eb 100644 --- a/test/v1/index_fixture.hpp +++ b/test/v1/index_fixture.hpp @@ -58,9 +58,13 @@ struct IndexFixture { v1::make_writer()); auto errors = v1::verify_compressed_index(PISA_SOURCE_DIR 
"/test/test_data/test_collection", index_basename); + for (auto&& error : errors) { + std::cerr << error << '\n'; + } REQUIRE(errors.empty()); auto yml = fmt::format("{}.yml", index_basename); auto meta = v1::score_index(v1::IndexMetadata::from_file(yml), 1); + meta = v1::bm_score_index(meta, 5, 1); v1::build_bigram_index(meta, collect_unique_bigrams(test_queries(), []() {})); } diff --git a/test/v1/test_v1.cpp b/test/v1/test_v1.cpp index 6728dd0a3..48e0cc9f9 100644 --- a/test/v1/test_v1.cpp +++ b/test/v1/test_v1.cpp @@ -9,6 +9,7 @@ #include #include +#include "binary_freq_collection.hpp" #include "io.hpp" #include "pisa_config.hpp" #include "v1/algorithm.hpp" @@ -30,6 +31,7 @@ using pisa::v1::load_bytes; using pisa::v1::next; using pisa::v1::parse_type; using pisa::v1::PostingBuilder; +using pisa::v1::PostingData; using pisa::v1::PostingFormatHeader; using pisa::v1::Primitive; using pisa::v1::RawReader; @@ -228,23 +230,20 @@ TEST_CASE("Build raw document-frequency index", "[v1][unit]") auto payload_span = gsl::span( reinterpret_cast(source[1].data()), source[1].size()); - IndexRunner runner(document_offsets, - frequency_offsets, - {}, - {}, - document_span, - payload_span, - {}, + IndexRunner runner(PostingData{document_span, document_offsets}, + PostingData{payload_span, frequency_offsets}, {}, document_sizes, tl::nullopt, {}, {}, - tl::nullopt, + {}, std::move(source), - RawReader{}, - RawReader{}); // Repeat to test that it only - // executes once + std::make_tuple(RawReader{}, + RawReader{}), // Repeat to test + // that it only + // executes once + std::make_tuple(RawReader{})); int counter = 0; runner([&](auto index) { counter += 1; @@ -271,21 +270,18 @@ TEST_CASE("Build raw document-frequency index", "[v1][unit]") reinterpret_cast(source[0].data()), source[0].size()); auto payload_span = gsl::span( reinterpret_cast(source[1].data()), source[1].size()); - IndexRunner runner(document_offsets, - frequency_offsets, - {}, - {}, - document_span, - payload_span, - 
{}, + IndexRunner runner(PostingData{document_span, document_offsets}, + PostingData{payload_span, frequency_offsets}, {}, document_sizes, tl::nullopt, {}, {}, - tl::nullopt, + {}, std::move(source), - RawReader{}); // Correct encoding but not type! + std::make_tuple(RawReader{}), // Correct encoding but not + // type! + std::make_tuple()); REQUIRE_THROWS_AS(runner([&](auto index) {}), std::domain_error); } } diff --git a/test/v1/test_v1_bigram_index.cpp b/test/v1/test_v1_bigram_index.cpp index bbe762a54..aaa430e12 100644 --- a/test/v1/test_v1_bigram_index.cpp +++ b/test/v1/test_v1_bigram_index.cpp @@ -44,8 +44,8 @@ TEMPLATE_TEST_CASE("Bigram v intersection", (IndexFixture, v1::RawCursor, v1::RawCursor>), - (IndexFixture, - v1::BlockedCursor<::pisa::simdbp_block, false>, + (IndexFixture, + v1::PayloadBlockedCursor<::pisa::simdbp_block>, v1::RawCursor>)) { tbb::task_scheduler_init init(1); @@ -57,7 +57,9 @@ TEMPLATE_TEST_CASE("Bigram v intersection", CAPTURE(q.get_term_ids()); CAPTURE(idx++); - auto run = v1::index_runner(meta, fixture.document_reader(), fixture.frequency_reader()); + auto run = v1::index_runner(meta, + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.frequency_reader())); std::vector results; run([&](auto&& index) { for (auto left = 0; left < q.get_term_ids().size(); left += 1) { diff --git a/test/v1/test_v1_blocked_cursor.cpp b/test/v1/test_v1_blocked_cursor.cpp index 0dca9ca64..a44af3e36 100644 --- a/test/v1/test_v1_blocked_cursor.cpp +++ b/test/v1/test_v1_blocked_cursor.cpp @@ -19,13 +19,16 @@ #include "v1/posting_builder.hpp" #include "v1/types.hpp" -using pisa::v1::BlockedReader; -using pisa::v1::BlockedWriter; using pisa::v1::collect; using pisa::v1::DocId; +using pisa::v1::DocumentBlockedReader; +using pisa::v1::DocumentBlockedWriter; using pisa::v1::Frequency; using pisa::v1::IndexRunner; +using pisa::v1::PayloadBlockedReader; +using pisa::v1::PayloadBlockedWriter; using pisa::v1::PostingBuilder; +using 
pisa::v1::PostingData; using pisa::v1::RawReader; using pisa::v1::read_sizes; using pisa::v1::TermId; @@ -38,7 +41,7 @@ TEST_CASE("Build single-block blocked document file", "[v1][unit]") std::vector docids{3, 4, 5, 6, 7, 8, 9, 10, 51, 115}; std::vector docbuf; auto document_offsets = [&]() { - PostingBuilder document_builder(BlockedWriter{}); + PostingBuilder document_builder(DocumentBlockedWriter{}); vector_stream_type docstream{sink_type{docbuf}}; document_builder.write_header(docstream); document_builder.write_segment(docstream, docids.begin(), docids.end()); @@ -47,7 +50,7 @@ TEST_CASE("Build single-block blocked document file", "[v1][unit]") auto documents = gsl::span(docbuf).subspan(8); CHECK(docbuf.size() == document_offsets.back() + 8); - BlockedReader document_reader; + DocumentBlockedReader document_reader; auto term = 0; auto actual = collect(document_reader.read(documents.subspan( document_offsets[term], document_offsets[term + 1] - document_offsets[term]))); @@ -67,8 +70,8 @@ TEST_CASE("Build blocked document-frequency index", "[v1][unit]") std::vector docbuf; std::vector freqbuf; - PostingBuilder document_builder(BlockedWriter{}); - PostingBuilder frequency_builder(BlockedWriter{}); + PostingBuilder document_builder(DocumentBlockedWriter{}); + PostingBuilder frequency_builder(PayloadBlockedWriter{}); { vector_stream_type docstream{sink_type{docbuf}}; vector_stream_type freqstream{sink_type{freqbuf}}; @@ -94,8 +97,8 @@ TEST_CASE("Build blocked document-frequency index", "[v1][unit]") THEN("The values read back are euqual to the binary collection's") { CHECK(docbuf.size() == document_offsets.back() + 8); - BlockedReader document_reader; - BlockedReader frequency_reader; + DocumentBlockedReader document_reader; + PayloadBlockedReader frequency_reader; auto term = 0; std::for_each(collection.begin(), collection.end(), [&](auto&& seq) { std::vector expected_documents(seq.docs.begin(), seq.docs.end()); @@ -123,22 +126,17 @@ TEST_CASE("Build blocked 
document-frequency index", "[v1][unit]") auto payload_span = gsl::span( reinterpret_cast(source[1].data()), source[1].size()); - IndexRunner runner(document_offsets, - frequency_offsets, - {}, - {}, - document_span, - payload_span, - {}, + IndexRunner runner(PostingData{document_span, document_offsets}, + PostingData{payload_span, frequency_offsets}, {}, document_sizes, tl::nullopt, {}, {}, - tl::nullopt, + {}, std::move(source), - BlockedReader{}, - BlockedReader{}); + std::make_tuple(DocumentBlockedReader{}), + std::make_tuple(PayloadBlockedReader{})); int counter = 0; runner([&](auto index) { counter += 1; @@ -180,21 +178,17 @@ TEST_CASE("Build blocked document-frequency index", "[v1][unit]") reinterpret_cast(source[0].data()), source[0].size()); auto payload_span = gsl::span( reinterpret_cast(source[1].data()), source[1].size()); - IndexRunner runner(document_offsets, - frequency_offsets, - {}, - {}, - document_span, - payload_span, - {}, + IndexRunner runner(PostingData{document_span, document_offsets}, + PostingData{payload_span, frequency_offsets}, {}, document_sizes, tl::nullopt, {}, {}, - tl::nullopt, + {}, std::move(source), - RawReader{}); // Correct encoding but not type! 
+ std::make_tuple(RawReader{}), + std::make_tuple()); REQUIRE_THROWS_AS(runner([&](auto index) {}), std::domain_error); } } diff --git a/test/v1/test_v1_document_payload_cursor.cpp b/test/v1/test_v1_document_payload_cursor.cpp index dcd4d8d9a..220ccb315 100644 --- a/test/v1/test_v1_document_payload_cursor.cpp +++ b/test/v1/test_v1_document_payload_cursor.cpp @@ -19,20 +19,10 @@ #include "v1/posting_builder.hpp" #include "v1/types.hpp" -using pisa::v1::BlockedReader; -using pisa::v1::BlockedWriter; -using pisa::v1::collect; -using pisa::v1::compress_binary_collection; using pisa::v1::DocId; using pisa::v1::DocumentPayloadCursor; using pisa::v1::Frequency; -using pisa::v1::index_runner; -using pisa::v1::IndexMetadata; -using pisa::v1::load_bytes; -using pisa::v1::PostingBuilder; using pisa::v1::RawCursor; -using pisa::v1::RawReader; -using pisa::v1::read_sizes; using pisa::v1::TermId; TEST_CASE("Document-payload cursor", "[v1][unit]") diff --git a/test/v1/test_v1_index.cpp b/test/v1/test_v1_index.cpp index a6f4a4f26..7b349db82 100644 --- a/test/v1/test_v1_index.cpp +++ b/test/v1/test_v1_index.cpp @@ -13,19 +13,20 @@ #include "pisa_config.hpp" #include "v1/blocked_cursor.hpp" #include "v1/cursor/collect.hpp" +#include "v1/default_index_runner.hpp" #include "v1/index.hpp" #include "v1/index_builder.hpp" #include "v1/index_metadata.hpp" #include "v1/types.hpp" using pisa::binary_freq_collection; -using pisa::v1::BlockedReader; -using pisa::v1::BlockedWriter; using pisa::v1::compress_binary_collection; using pisa::v1::DocId; +using pisa::v1::DocumentBlockedWriter; using pisa::v1::Frequency; using pisa::v1::index_runner; using pisa::v1::IndexMetadata; +using pisa::v1::PayloadBlockedWriter; using pisa::v1::RawReader; using pisa::v1::RawWriter; using pisa::v1::TermId; @@ -47,10 +48,7 @@ TEST_CASE("Binary collection index", "[v1][unit]") REQUIRE(meta.frequencies.postings == (tmpdir.path() / "index.frequencies").string()); REQUIRE(meta.frequencies.offsets == (tmpdir.path() / 
"index.frequency_offsets").string()); REQUIRE(meta.document_lengths_path == (tmpdir.path() / "index.document_lengths").string()); - auto run = index_runner(meta, - RawReader{}, - BlockedReader{}, - BlockedReader{}); + auto run = index_runner(meta); run([&](auto index) { REQUIRE(bci.num_docs() == index.num_documents()); REQUIRE(bci.size() == index.num_terms()); @@ -74,18 +72,15 @@ TEST_CASE("Binary collection index -- SIMDBP", "[v1][unit]") (tmpdir.path() / "fwd").string(), (tmpdir.path() / "index").string(), 8, - make_writer(BlockedWriter<::pisa::simdbp_block, true>{}), - make_writer(BlockedWriter<::pisa::simdbp_block, false>{})); + make_writer(DocumentBlockedWriter<::pisa::simdbp_block>{}), + make_writer(PayloadBlockedWriter<::pisa::simdbp_block>{})); auto meta = IndexMetadata::from_file((tmpdir.path() / "index.yml").string()); REQUIRE(meta.documents.postings == (tmpdir.path() / "index.documents").string()); REQUIRE(meta.documents.offsets == (tmpdir.path() / "index.document_offsets").string()); REQUIRE(meta.frequencies.postings == (tmpdir.path() / "index.frequencies").string()); REQUIRE(meta.frequencies.offsets == (tmpdir.path() / "index.frequency_offsets").string()); REQUIRE(meta.document_lengths_path == (tmpdir.path() / "index.document_lengths").string()); - auto run = index_runner(meta, - RawReader{}, - BlockedReader{}, - BlockedReader{}); + auto run = index_runner(meta); run([&](auto index) { REQUIRE(bci.num_docs() == index.num_documents()); REQUIRE(bci.size() == index.num_terms()); diff --git a/test/v1/test_v1_maxscore_join.cpp b/test/v1/test_v1_maxscore_join.cpp index 15e491cb5..45978b951 100644 --- a/test/v1/test_v1_maxscore_join.cpp +++ b/test/v1/test_v1_maxscore_join.cpp @@ -21,19 +21,13 @@ #include "v1/io.hpp" #include "v1/maxscore.hpp" #include "v1/posting_builder.hpp" +#include "v1/scorer/bm25.hpp" #include "v1/types.hpp" -using pisa::v1::BlockedReader; -using pisa::v1::BlockedWriter; using pisa::v1::collect; using pisa::v1::DocId; using 
pisa::v1::Frequency; -using pisa::v1::IndexRunner; using pisa::v1::join_maxscore; -using pisa::v1::PostingBuilder; -using pisa::v1::RawReader; -using pisa::v1::read_sizes; -using pisa::v1::TermId; using pisa::v1::accumulate::Add; TEMPLATE_TEST_CASE("Max score join", @@ -57,8 +51,9 @@ TEMPLATE_TEST_CASE("Max score join", auto add = [](auto score, auto&& cursor, [[maybe_unused]] auto idx) { return score + cursor.payload(); }; - auto run = - v1::index_runner(meta, fixture.document_reader(), fixture.frequency_reader()); + auto run = v1::index_runner(meta, + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.frequency_reader())); run([&](auto&& index) { auto union_results = collect(v1::union_merge( index.scored_cursors(gsl::make_span(q.get_term_ids()), make_bm25(index)), diff --git a/test/v1/test_v1_queries.cpp b/test/v1/test_v1_queries.cpp index 841d67ac0..adbe629a8 100644 --- a/test/v1/test_v1_queries.cpp +++ b/test/v1/test_v1_queries.cpp @@ -38,14 +38,16 @@ #include "v1/taat_or.hpp" #include "v1/types.hpp" #include "v1/union_lookup.hpp" +#include "v1/wand.hpp" using namespace pisa; -using pisa::v1::BlockedCursor; using pisa::v1::DocId; +using pisa::v1::DocumentBlockedCursor; using pisa::v1::Frequency; using pisa::v1::Index; using pisa::v1::IndexMetadata; using pisa::v1::ListSelection; +using pisa::v1::PayloadBlockedCursor; using pisa::v1::RawCursor; static constexpr auto RELATIVE_ERROR = 0.1F; @@ -127,8 +129,8 @@ std::unique_ptr> TEMPLATE_TEST_CASE("Query", "[v1][integration]", (IndexFixture, RawCursor, RawCursor>), - (IndexFixture, - BlockedCursor<::pisa::simdbp_block, false>, + (IndexFixture, + PayloadBlockedCursor<::pisa::simdbp_block>, RawCursor>)) { tbb::task_scheduler_init init(1); @@ -140,6 +142,10 @@ TEMPLATE_TEST_CASE("Query", {"daat_or", false}, {"maxscore", false}, {"maxscore", true}, + {"wand", false}, + {"wand", true}, + {"bmw", false}, + {"bmw", true}, {"maxscore_union_lookup", true}, {"unigram_union_lookup", true}, {"union_lookup", 
true}, @@ -159,6 +165,12 @@ TEMPLATE_TEST_CASE("Query", if (name == "maxscore") { return maxscore(query, index, topk_queue(10), scorer); } + if (name == "wand") { + return wand(query, index, topk_queue(10), scorer); + } + if (name == "bmw") { + return bmw(query, index, topk_queue(10), scorer); + } if (name == "maxscore_union_lookup") { return maxscore_union_lookup(query, index, topk_queue(10), scorer); } @@ -199,8 +211,9 @@ TEMPLATE_TEST_CASE("Query", } auto on_the_fly = [&]() { - auto run = - pisa::v1::index_runner(meta, fixture.document_reader(), fixture.frequency_reader()); + auto run = pisa::v1::index_runner(meta, + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.frequency_reader())); std::vector results; run([&](auto&& index) { auto que = run_query(algorithm, query, index, make_bm25(index)); @@ -221,6 +234,17 @@ TEMPLATE_TEST_CASE("Query", expected.resize(on_the_fly.size()); std::sort(expected.begin(), expected.end(), approximate_order); + if (algorithm == "bmw") { + for (size_t i = 0; i < on_the_fly.size(); ++i) { + std::cerr << fmt::format("{} {} -- {} {}\n", + on_the_fly[i].second, + on_the_fly[i].first, + expected[i].second, + expected[i].first); + } + std::cerr << '\n'; + } + for (size_t i = 0; i < on_the_fly.size(); ++i) { REQUIRE(on_the_fly[i].second == expected[i].second); REQUIRE(on_the_fly[i].first == Approx(expected[i].first).epsilon(RELATIVE_ERROR)); @@ -228,9 +252,14 @@ TEMPLATE_TEST_CASE("Query", idx += 1; + if (algorithm == "bmw") { + continue; + } + auto precomputed = [&]() { - auto run = - v1::scored_index_runner(meta, fixture.document_reader(), fixture.score_reader()); + auto run = pisa::v1::scored_index_runner(meta, + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.score_reader())); std::vector results; run([&](auto&& index) { auto que = run_query(algorithm, query, index, v1::VoidScorer{}); diff --git a/test/v1/test_v1_query.cpp b/test/v1/test_v1_query.cpp index fcd454ea0..6dcbb1554 100644 --- 
a/test/v1/test_v1_query.cpp +++ b/test/v1/test_v1_query.cpp @@ -25,7 +25,7 @@ TEST_CASE("Parse query from JSON", "[v1][unit]") auto query = Query::from_json( R"({"id": "Q0", "query": "send dog pics", "term_ids": [0, 32, 4], "k": 15, )" R"("threshold": 40.5, "selections": )" - R"({ "unigrams": [0, 2], "bigrams": [[0, 2], [2, 1]]}})"); + R"([1, 4, 5, 6]})"); REQUIRE(query.get_id() == "Q0"); REQUIRE(query.k() == 15); REQUIRE(query.get_term_ids() == std::vector{0, 4, 32}); @@ -34,8 +34,4 @@ TEST_CASE("Parse query from JSON", "[v1][unit]") REQUIRE(query.get_selections().unigrams == std::vector{0, 4}); REQUIRE(query.get_selections().bigrams == std::vector>{{0, 4}, {4, 32}}); - REQUIRE_THROWS(Query::from_json( - R"({"id": "Q0", "query": "send dog pics", "term_ids": [0, 32, 4], "k": 15, )" - R"("threshold": 40.5, "selections": )" - R"({ "unigrams": [0, 4], "bigrams": [[0, 4], [4, 5]]}})")); } diff --git a/test/v1/test_v1_score_index.cpp b/test/v1/test_v1_score_index.cpp index 0ba952fbd..d040b389f 100644 --- a/test/v1/test_v1_score_index.cpp +++ b/test/v1/test_v1_score_index.cpp @@ -17,27 +17,28 @@ #include "v1/index.hpp" #include "v1/index_builder.hpp" #include "v1/index_metadata.hpp" +#include "v1/scorer/bm25.hpp" #include "v1/types.hpp" -using pisa::v1::BlockedCursor; -using pisa::v1::BlockedReader; -using pisa::v1::BlockedWriter; -using pisa::v1::compress_binary_collection; using pisa::v1::DocId; +using pisa::v1::DocumentBlockedCursor; +using pisa::v1::DocumentBlockedReader; +using pisa::v1::DocumentBlockedWriter; using pisa::v1::Frequency; using pisa::v1::index_runner; using pisa::v1::IndexMetadata; using pisa::v1::make_bm25; +using pisa::v1::PayloadBlockedCursor; +using pisa::v1::PayloadBlockedReader; +using pisa::v1::PayloadBlockedWriter; using pisa::v1::RawCursor; -using pisa::v1::RawReader; -using pisa::v1::RawWriter; using pisa::v1::TermId; -TEMPLATE_TEST_CASE("DAAT OR", +TEMPLATE_TEST_CASE("Score index", "[v1][integration]", (IndexFixture, RawCursor, RawCursor>), 
- (IndexFixture, - BlockedCursor<::pisa::simdbp_block, false>, + (IndexFixture, + PayloadBlockedCursor<::pisa::simdbp_block>, RawCursor>)) { tbb::task_scheduler_init init(1); @@ -46,8 +47,9 @@ TEMPLATE_TEST_CASE("DAAT OR", TestType fixture; THEN("Float max scores are correct") { - auto run = v1::index_runner( - fixture.meta(), fixture.document_reader(), fixture.frequency_reader()); + auto run = v1::index_runner(fixture.meta(), + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.frequency_reader())); run([&](auto&& index) { for (auto term = 0; term < index.num_terms(); term += 1) { CAPTURE(term); @@ -63,8 +65,9 @@ TEMPLATE_TEST_CASE("DAAT OR", } THEN("Quantized max scores are correct") { - auto run = v1::scored_index_runner( - fixture.meta(), fixture.document_reader(), fixture.score_reader()); + auto run = v1::scored_index_runner(fixture.meta(), + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.score_reader())); run([&](auto&& index) { for (auto term = 0; term < index.num_terms(); term += 1) { CAPTURE(term); @@ -82,3 +85,46 @@ TEMPLATE_TEST_CASE("DAAT OR", } } } + +TEMPLATE_TEST_CASE("Construct max-score lists", + "[v1][integration]", + (IndexFixture, RawCursor, RawCursor>), + (IndexFixture, + PayloadBlockedCursor<::pisa::simdbp_block>, + RawCursor>)) +{ + tbb::task_scheduler_init init(1); + GIVEN("Index fixture (built and (max) scored index)") + { + TestType fixture; + THEN("Float max scores are correct") + { + auto run = v1::index_runner(fixture.meta(), + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.frequency_reader())); + run([&](auto&& index) { + for (auto term = 0; term < index.num_terms(); term += 1) { + CAPTURE(term); + auto cursor = index.block_max_scored_cursor(term, make_bm25(index)); + auto term_max_score = 0.0F; + while (not cursor.empty()) { + auto max_score = 0.0F; + auto block_max_score = cursor.block_max_score(*cursor); + for (auto idx = 0; idx < 5 && not cursor.empty(); ++idx) { + 
REQUIRE(cursor.block_max_score(*cursor) == block_max_score); + if (auto score = cursor.payload(); score > max_score) { + max_score = score; + } + cursor.advance(); + } + if (max_score > term_max_score) { + term_max_score = max_score; + } + REQUIRE(max_score == block_max_score); + } + REQUIRE(term_max_score == cursor.max_score()); + } + }); + } + } +} diff --git a/v1/CMakeLists.txt b/v1/CMakeLists.txt index f9e02ae03..7e5d0f829 100644 --- a/v1/CMakeLists.txt +++ b/v1/CMakeLists.txt @@ -10,6 +10,9 @@ target_link_libraries(postings pisa CLI11) add_executable(score score.cpp) target_link_libraries(score pisa CLI11) +add_executable(bmscore bmscore.cpp) +target_link_libraries(bmscore pisa CLI11) + add_executable(bigram-index bigram_index.cpp) target_link_libraries(bigram-index pisa CLI11) diff --git a/v1/bigram_index.cpp b/v1/bigram_index.cpp index 583f0d8b9..e48d952c4 100644 --- a/v1/bigram_index.cpp +++ b/v1/bigram_index.cpp @@ -1,5 +1,3 @@ -#include -#include #include #include @@ -7,32 +5,15 @@ #include #include "app.hpp" -#include "io.hpp" -#include "query/queries.hpp" -#include "timer.hpp" -#include "topk_queue.hpp" -#include "v1/blocked_cursor.hpp" -#include "v1/cursor_intersection.hpp" #include "v1/index_builder.hpp" -#include "v1/index_metadata.hpp" #include "v1/progress_status.hpp" -#include "v1/query.hpp" -#include "v1/raw_cursor.hpp" -#include "v1/scorer/bm25.hpp" -#include "v1/scorer/runner.hpp" #include "v1/types.hpp" using pisa::App; using pisa::v1::build_bigram_index; using pisa::v1::collect_unique_bigrams; using pisa::v1::DefaultProgress; -using pisa::v1::DocId; -using pisa::v1::Frequency; -using pisa::v1::IndexMetadata; using pisa::v1::ProgressStatus; -using pisa::v1::Query; -using pisa::v1::resolve_yml; -using pisa::v1::TermId; namespace arg = pisa::arg; diff --git a/v1/bmscore.cpp b/v1/bmscore.cpp new file mode 100644 index 000000000..76ddc6a00 --- /dev/null +++ b/v1/bmscore.cpp @@ -0,0 +1,25 @@ +#include +#include +#include + +#include + +#include 
"app.hpp" +#include "v1/score_index.hpp" + +using pisa::App; +namespace arg = pisa::arg; + +int main(int argc, char** argv) +{ + std::optional yml{}; + std::size_t block_size; + std::size_t threads = std::thread::hardware_concurrency(); + + App app{"Constructs block-max score lists for v1 index."}; + app.add_option("--block-size", block_size, "The size of a block for max scores", false) + ->required(); + CLI11_PARSE(app, argc, argv); + pisa::v1::bm_score_index(app.index_metadata(), block_size, app.threads()); + return 0; +} diff --git a/v1/compress.cpp b/v1/compress.cpp index aa63efda6..dda0db0e0 100644 --- a/v1/compress.cpp +++ b/v1/compress.cpp @@ -4,7 +4,6 @@ #include #include -#include "binary_freq_collection.hpp" #include "v1/blocked_cursor.hpp" #include "v1/index_builder.hpp" #include "v1/index_metadata.hpp" @@ -13,10 +12,11 @@ using std::literals::string_view_literals::operator""sv; -using pisa::v1::BlockedWriter; using pisa::v1::compress_binary_collection; +using pisa::v1::DocumentBlockedWriter; using pisa::v1::EncodingId; using pisa::v1::make_index_builder; +using pisa::v1::PayloadBlockedWriter; using pisa::v1::RawWriter; using pisa::v1::verify_compressed_index; @@ -63,8 +63,8 @@ int main(int argc, char** argv) tbb::task_scheduler_init init(threads); auto build = make_index_builder(RawWriter{}, - BlockedWriter<::pisa::simdbp_block, true>{}, - BlockedWriter<::pisa::simdbp_block, false>{}); + DocumentBlockedWriter<::pisa::simdbp_block>{}, + PayloadBlockedWriter<::pisa::simdbp_block>{}); build(document_encoding(encoding), frequency_encoding(encoding), [&](auto document_writer, auto payload_writer) { diff --git a/v1/filter_queries.cpp b/v1/filter_queries.cpp index c6ecfcd19..f844fc233 100644 --- a/v1/filter_queries.cpp +++ b/v1/filter_queries.cpp @@ -1,40 +1,11 @@ -#include #include -#include #include -#include -#include +#include #include #include #include "app.hpp" -#include "io.hpp" -#include "query/queries.hpp" -#include "timer.hpp" -#include 
"topk_queue.hpp" -#include "v1/blocked_cursor.hpp" -#include "v1/daat_or.hpp" -#include "v1/index_metadata.hpp" -#include "v1/intersection.hpp" -#include "v1/maxscore.hpp" -#include "v1/query.hpp" -#include "v1/raw_cursor.hpp" -#include "v1/scorer/bm25.hpp" -#include "v1/scorer/runner.hpp" -#include "v1/types.hpp" -#include "v1/union_lookup.hpp" - -using pisa::resolve_query_parser; -using pisa::TermProcessor; -using pisa::v1::BlockedReader; -using pisa::v1::daat_or; -using pisa::v1::index_runner; -using pisa::v1::IndexMetadata; -using pisa::v1::Query; -using pisa::v1::RawReader; -using pisa::v1::resolve_yml; -using pisa::v1::VoidScorer; namespace arg = pisa::arg; @@ -51,7 +22,7 @@ int main(int argc, char** argv) auto queries = app.queries(meta); for (auto&& query : queries) { if (query.term_ids()) { - std::cout << query.to_json() << '\n'; + std::cout << *query.to_json() << '\n'; } } return 0; diff --git a/v1/intersection.cpp b/v1/intersection.cpp index d9ea60982..b538667d5 100644 --- a/v1/intersection.cpp +++ b/v1/intersection.cpp @@ -4,7 +4,7 @@ #include #include -#include +#include #include #include #include @@ -15,6 +15,7 @@ #include "v1/blocked_cursor.hpp" #include "v1/cursor/for_each.hpp" #include "v1/cursor_intersection.hpp" +#include "v1/default_index_runner.hpp" #include "v1/raw_cursor.hpp" #include "v1/scorer/bm25.hpp" @@ -22,7 +23,6 @@ using pisa::App; using pisa::Intersection; using pisa::intersection::IntersectionType; using pisa::intersection::Mask; -using pisa::v1::BlockedReader; using pisa::v1::intersect; using pisa::v1::make_bm25; using pisa::v1::RawReader; @@ -90,7 +90,7 @@ void compute_intersections(Index const& index, } else { inter(query, tl::nullopt); } - auto output = query.to_json(); + auto output = *query.to_json(); output["intersections"] = intersections; std::cout << output << '\n'; } @@ -122,10 +122,7 @@ int main(int argc, const char** argv) auto meta = app.index_metadata(); auto queries = app.queries(meta); - auto run = 
index_runner(meta, - // RawReader{}, - BlockedReader<::pisa::simdbp_block, true>{}, - BlockedReader<::pisa::simdbp_block, false>{}); + auto run = index_runner(meta); run([&](auto&& index) { compute_intersections(index, queries, intersection_type, mtc); }); } catch (std::exception const& error) { spdlog::error("{}", error.what()); diff --git a/v1/postings.cpp b/v1/postings.cpp index a46be9f96..9bfa1d520 100644 --- a/v1/postings.cpp +++ b/v1/postings.cpp @@ -11,6 +11,7 @@ #include "query/queries.hpp" #include "topk_queue.hpp" #include "v1/blocked_cursor.hpp" +#include "v1/default_index_runner.hpp" #include "v1/index_metadata.hpp" #include "v1/query.hpp" #include "v1/raw_cursor.hpp" @@ -21,7 +22,6 @@ using pisa::App; using pisa::Query; using pisa::resolve_query_parser; -using pisa::v1::BlockedReader; using pisa::v1::index_runner; using pisa::v1::IndexMetadata; using pisa::v1::RawReader; @@ -29,15 +29,6 @@ using pisa::v1::resolve_yml; namespace arg = pisa::arg; -auto default_readers() -{ - return std::make_tuple(RawReader{}, - RawReader{}, - RawReader{}, - BlockedReader<::pisa::simdbp_block, true>{}, - BlockedReader<::pisa::simdbp_block, false>{}); -} - [[nodiscard]] auto load_source(std::optional const& file) -> std::shared_ptr { @@ -127,7 +118,7 @@ int main(int argc, char** argv) if (query.terms.size() == 1) { if (precomputed) { - auto run = scored_index_runner(meta, default_readers()); + auto run = scored_index_runner(meta); run([&](auto&& index) { auto print = [&](auto&& cursor) { if (did) { @@ -144,7 +135,7 @@ int main(int argc, char** argv) for_each(index.cursor(query.terms.front()), print); }); } else { - auto run = index_runner(meta, default_readers()); + auto run = index_runner(meta); run([&](auto&& index) { auto bm25 = make_bm25(index); auto scorer = bm25.term_scorer(query.terms.front()); diff --git a/v1/query.cpp b/v1/query.cpp index 1ba9cba39..565d1dc05 100644 --- a/v1/query.cpp +++ b/v1/query.cpp @@ -3,13 +3,10 @@ #include #include -#include -#include 
#include #include #include "app.hpp" -#include "io.hpp" #include "query/queries.hpp" #include "timer.hpp" #include "topk_queue.hpp" @@ -18,7 +15,6 @@ #include "v1/daat_or.hpp" #include "v1/index_metadata.hpp" #include "v1/inspect_query.hpp" -#include "v1/intersection.hpp" #include "v1/maxscore.hpp" #include "v1/query.hpp" #include "v1/raw_cursor.hpp" @@ -26,83 +22,150 @@ #include "v1/scorer/runner.hpp" #include "v1/types.hpp" #include "v1/union_lookup.hpp" +#include "v1/wand.hpp" -using pisa::resolve_query_parser; -using pisa::v1::BlockedReader; using pisa::v1::daat_or; using pisa::v1::DaatOrInspector; +using pisa::v1::DocumentBlockedReader; using pisa::v1::index_runner; -using pisa::v1::IndexMetadata; -using pisa::v1::ListSelection; using pisa::v1::lookup_union; using pisa::v1::LookupUnionInspector; using pisa::v1::maxscore_union_lookup; using pisa::v1::MaxscoreInspector; using pisa::v1::MaxscoreUnionLookupInspect; +using pisa::v1::PayloadBlockedReader; using pisa::v1::Query; using pisa::v1::QueryInspector; using pisa::v1::RawReader; -using pisa::v1::resolve_yml; using pisa::v1::unigram_union_lookup; using pisa::v1::UnigramUnionLookupInspect; using pisa::v1::union_lookup; using pisa::v1::UnionLookupInspect; using pisa::v1::VoidScorer; +using pisa::v1::wand; -using RetrievalAlgorithm = std::function<::pisa::topk_queue(pisa::v1::Query, ::pisa::topk_queue)>; +struct RetrievalAlgorithm { + template + explicit RetrievalAlgorithm(Fn fn, FallbackFn fallback, bool safe) + : m_retrieve(std::move(fn)), m_fallback(std::move(fallback)), m_safe(safe) + { + } + + [[nodiscard]] auto operator()(pisa::v1::Query const& query, ::pisa::topk_queue topk) const + -> ::pisa::topk_queue + { + topk = m_retrieve(query, topk); + if (m_safe && not topk.full()) { + topk.clear(); + topk = m_fallback(query, topk); + } + return topk; + } + + private: + std::function<::pisa::topk_queue(pisa::v1::Query, ::pisa::topk_queue)> m_retrieve; + std::function<::pisa::topk_queue(pisa::v1::Query, 
::pisa::topk_queue)> m_fallback; + bool m_safe; +}; template -auto resolve_algorithm(std::string const& name, Index const& index, Scorer&& scorer) +auto resolve_algorithm(std::string const& name, Index const& index, Scorer&& scorer, bool safe) -> RetrievalAlgorithm { + auto fallback = [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + topk.clear(); + return pisa::v1::maxscore(query, index, std::move(topk), std::forward(scorer)); + }; if (name == "daat_or") { - return RetrievalAlgorithm([&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { - return pisa::v1::daat_or(query, index, std::move(topk), std::forward(scorer)); - }); + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + return pisa::v1::daat_or( + query, index, std::move(topk), std::forward(scorer)); + }, + fallback, + safe); + } + if (name == "wand") { + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + if (query.threshold()) { + topk.set_threshold(query.get_threshold()); + } + return pisa::v1::wand(query, index, std::move(topk), std::forward(scorer)); + }, + fallback, + safe); + } + if (name == "bmw") { + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + if (query.threshold()) { + topk.set_threshold(query.get_threshold()); + } + return pisa::v1::bmw(query, index, std::move(topk), std::forward(scorer)); + }, + fallback, + safe); } if (name == "maxscore") { - return RetrievalAlgorithm([&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { - if (query.threshold()) { - topk.set_threshold(query.get_threshold()); - } - return pisa::v1::maxscore(query, index, std::move(topk), std::forward(scorer)); - }); + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + if (query.threshold()) { + topk.set_threshold(query.get_threshold()); + } + return pisa::v1::maxscore( + query, index, std::move(topk), std::forward(scorer)); + }, + 
fallback, + safe); } if (name == "maxscore-union-lookup") { - return RetrievalAlgorithm([&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { - return pisa::v1::maxscore_union_lookup( - query, index, std::move(topk), std::forward(scorer)); - }); + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + return pisa::v1::maxscore_union_lookup( + query, index, std::move(topk), std::forward(scorer)); + }, + fallback, + safe); } if (name == "unigram-union-lookup") { - return RetrievalAlgorithm([&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { - return pisa::v1::unigram_union_lookup( - query, index, std::move(topk), std::forward(scorer)); - }); - } - if (name == "union-lookup") { - return RetrievalAlgorithm([&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { - if (query.selections()->bigrams.empty()) { + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { return pisa::v1::unigram_union_lookup( query, index, std::move(topk), std::forward(scorer)); - } - if (query.get_term_ids().size() > 8) { - return pisa::v1::maxscore( + }, + fallback, + safe); + } + if (name == "union-lookup") { + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + if (query.selections()->bigrams.empty()) { + return pisa::v1::unigram_union_lookup( + query, index, std::move(topk), std::forward(scorer)); + } + if (query.get_term_ids().size() > 8) { + return pisa::v1::maxscore( + query, index, std::move(topk), std::forward(scorer)); + } + return pisa::v1::union_lookup( query, index, std::move(topk), std::forward(scorer)); - } - return pisa::v1::union_lookup( - query, index, std::move(topk), std::forward(scorer)); - }); + }, + fallback, + safe); } if (name == "lookup-union") { - return RetrievalAlgorithm([&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { - if (query.selections()->bigrams.empty()) { - return pisa::v1::unigram_union_lookup( + return 
RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + if (query.selections()->bigrams.empty()) { + return pisa::v1::unigram_union_lookup( + query, index, std::move(topk), std::forward(scorer)); + } + return pisa::v1::lookup_union( query, index, std::move(topk), std::forward(scorer)); - } - return pisa::v1::lookup_union( - query, index, std::move(topk), std::forward(scorer)); - }); + }, + fallback, + safe); } spdlog::error("Unknown algorithm: {}", name); std::exit(1); @@ -201,10 +264,12 @@ int main(int argc, char** argv) std::string algorithm = "daat_or"; bool inspect = false; + bool safe = false; pisa::QueryApp app("Queries a v1 index."); - app.add_option("--algorithm", algorithm, "Query retrieval algorithm.", true); + app.add_option("--algorithm", algorithm, "Query retrieval algorithm", true); app.add_flag("--inspect", inspect, "Analyze query execution and stats"); + app.add_flag("--safe", safe, "Repeats without threshold if it was overestimated"); CLI11_PARSE(app, argc, argv); try { @@ -219,34 +284,37 @@ int main(int argc, char** argv) auto docmap = pisa::Payload_Vector<>::from(*source); if (app.use_quantized()) { - auto run = scored_index_runner(meta, - RawReader{}, - RawReader{}, - BlockedReader<::pisa::simdbp_block, true>{}, - BlockedReader<::pisa::simdbp_block, false>{}); + auto run = + scored_index_runner(meta, + std::make_tuple(RawReader{}, + DocumentBlockedReader<::pisa::simdbp_block>{}), + std::make_tuple(RawReader{}, + PayloadBlockedReader<::pisa::simdbp_block>{})); run([&](auto&& index) { if (app.is_benchmark()) { - benchmark(queries, resolve_algorithm(algorithm, index, VoidScorer{})); + benchmark(queries, resolve_algorithm(algorithm, index, VoidScorer{}, safe)); } else if (inspect) { inspect_queries(queries, resolve_inspect(algorithm, index, VoidScorer{})); } else { - evaluate(queries, docmap, resolve_algorithm(algorithm, index, VoidScorer{})); + evaluate( + queries, docmap, resolve_algorithm(algorithm, index, 
VoidScorer{}, safe)); } }); } else { auto run = index_runner(meta, - RawReader{}, - BlockedReader<::pisa::simdbp_block, true>{}, - BlockedReader<::pisa::simdbp_block, false>{}); + std::make_tuple(RawReader{}, + DocumentBlockedReader<::pisa::simdbp_block>{}), + std::make_tuple(PayloadBlockedReader<::pisa::simdbp_block>{})); run([&](auto&& index) { auto with_scorer = scorer_runner(index, make_bm25(index)); with_scorer("bm25", [&](auto scorer) { if (app.is_benchmark()) { - benchmark(queries, resolve_algorithm(algorithm, index, scorer)); + benchmark(queries, resolve_algorithm(algorithm, index, scorer, safe)); } else if (inspect) { inspect_queries(queries, resolve_inspect(algorithm, index, scorer)); } else { - evaluate(queries, docmap, resolve_algorithm(algorithm, index, scorer)); + evaluate( + queries, docmap, resolve_algorithm(algorithm, index, scorer, safe)); } }); }); diff --git a/v1/score.cpp b/v1/score.cpp index bc47cc972..4d622bcdd 100644 --- a/v1/score.cpp +++ b/v1/score.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -5,7 +6,6 @@ #include #include "app.hpp" -#include "v1/index_metadata.hpp" #include "v1/score_index.hpp" using pisa::App; diff --git a/v1/threshold.cpp b/v1/threshold.cpp index be11d1fc6..4b9ca52d0 100644 --- a/v1/threshold.cpp +++ b/v1/threshold.cpp @@ -1,20 +1,17 @@ #include #include -#include #include -#include -#include +#include #include #include #include "app.hpp" -#include "io.hpp" #include "query/queries.hpp" -#include "timer.hpp" #include "topk_queue.hpp" #include "v1/blocked_cursor.hpp" #include "v1/daat_or.hpp" +#include "v1/default_index_runner.hpp" #include "v1/index_metadata.hpp" #include "v1/query.hpp" #include "v1/raw_cursor.hpp" @@ -22,13 +19,8 @@ #include "v1/scorer/runner.hpp" #include "v1/types.hpp" -using pisa::resolve_query_parser; -using pisa::v1::BlockedReader; using pisa::v1::index_runner; -using pisa::v1::IndexMetadata; using pisa::v1::Query; -using pisa::v1::RawReader; -using pisa::v1::resolve_yml; using 
pisa::v1::VoidScorer; template @@ -46,7 +38,7 @@ void calculate_thresholds(Index&& index, threshold = results.topk().back().first; } query.threshold(threshold); - os << query.to_json() << '\n'; + os << *query.to_json() << '\n'; } } @@ -70,11 +62,7 @@ int main(int argc, char** argv) auto queries = app.queries(meta); if (app.use_quantized()) { - auto run = scored_index_runner(meta, - RawReader{}, - RawReader{}, - BlockedReader<::pisa::simdbp_block, true>{}, - BlockedReader<::pisa::simdbp_block, false>{}); + auto run = scored_index_runner(meta); run([&](auto&& index) { if (in_place) { std::ofstream os(app.query_file().value()); @@ -84,10 +72,7 @@ int main(int argc, char** argv) } }); } else { - auto run = index_runner(meta, - RawReader{}, - BlockedReader<::pisa::simdbp_block, true>{}, - BlockedReader<::pisa::simdbp_block, false>{}); + auto run = index_runner(meta); run([&](auto&& index) { auto with_scorer = scorer_runner(index, make_bm25(index)); with_scorer("bm25", [&](auto scorer) { From e2b87385e36c30387ad75cf97ab4b8263a2caf26 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Tue, 31 Dec 2019 18:40:48 +0000 Subject: [PATCH 40/56] PEF index --- include/pisa/codec/integer_codes.hpp | 12 +- include/pisa/query/algorithm/wand_query.hpp | 2 +- .../pisa/sequence/partitioned_sequence.hpp | 2 + include/pisa/v1/base_index.hpp | 104 ++ include/pisa/v1/bit_sequence_cursor.hpp | 247 +++++ include/pisa/v1/bit_vector.hpp | 772 ++++++++++++++ include/pisa/v1/blocked_cursor.hpp | 6 + include/pisa/v1/cursor/compact_elias_fano.hpp | 67 ++ include/pisa/v1/cursor_traits.hpp | 3 + include/pisa/v1/default_index_runner.hpp | 25 +- include/pisa/v1/encoding_traits.hpp | 30 + include/pisa/v1/index.hpp | 132 +-- include/pisa/v1/index_builder.hpp | 28 +- include/pisa/v1/index_metadata.hpp | 2 +- include/pisa/v1/maxscore.hpp | 9 +- include/pisa/v1/raw_cursor.hpp | 6 + include/pisa/v1/runtime_assert.hpp | 20 + include/pisa/v1/sequence/indexed_sequence.hpp | 939 ++++++++++++++++++ 
.../pisa/v1/sequence/partitioned_sequence.hpp | 403 ++++++++ .../pisa/v1/sequence/positive_sequence.hpp | 311 ++++++ include/pisa/v1/sequence_cursor.hpp | 326 ------ include/pisa/v1/types.hpp | 128 ++- include/pisa/v1/union_lookup.hpp | 6 +- include/pisa/v1/wand.hpp | 254 +++-- src/v1/bit_sequence_cursor.cpp | 1 + src/v1/default_index_runner.cpp | 23 + src/v1/index.cpp | 10 - src/v1/index_builder.cpp | 6 +- test/v1/index_fixture.hpp | 33 +- test/v1/test_v1_index.cpp | 48 +- test/v1/test_v1_queries.cpp | 84 +- v1/compress.cpp | 24 +- 32 files changed, 3403 insertions(+), 660 deletions(-) create mode 100644 include/pisa/v1/base_index.hpp create mode 100644 include/pisa/v1/bit_sequence_cursor.hpp create mode 100644 include/pisa/v1/bit_vector.hpp create mode 100644 include/pisa/v1/cursor/compact_elias_fano.hpp create mode 100644 include/pisa/v1/runtime_assert.hpp create mode 100644 include/pisa/v1/sequence/indexed_sequence.hpp create mode 100644 include/pisa/v1/sequence/partitioned_sequence.hpp create mode 100644 include/pisa/v1/sequence/positive_sequence.hpp delete mode 100644 include/pisa/v1/sequence_cursor.hpp create mode 100644 src/v1/bit_sequence_cursor.cpp create mode 100644 src/v1/default_index_runner.cpp diff --git a/include/pisa/codec/integer_codes.hpp b/include/pisa/codec/integer_codes.hpp index 677e785d2..7a6ee3a29 100644 --- a/include/pisa/codec/integer_codes.hpp +++ b/include/pisa/codec/integer_codes.hpp @@ -20,14 +20,19 @@ inline void write_gamma_nonzero(bit_vector_builder& bvb, uint64_t n) write_gamma(bvb, n - 1); } -inline uint64_t read_gamma(bit_vector::enumerator& it) +template +inline uint64_t read_gamma(BitVectorEnumerator& it) { uint64_t l = it.skip_zeros(); assert(l < 64); return (it.take(l) | (uint64_t(1) << l)) - 1; } -inline uint64_t read_gamma_nonzero(bit_vector::enumerator& it) { return read_gamma(it) + 1; } +template +inline uint64_t read_gamma_nonzero(BitVectorEnumerator& it) +{ + return read_gamma(it) + 1; +} inline void 
write_delta(bit_vector_builder& bvb, uint64_t n) { @@ -38,7 +43,8 @@ inline void write_delta(bit_vector_builder& bvb, uint64_t n) bvb.append_bits(nn ^ hb, l); } -inline uint64_t read_delta(bit_vector::enumerator& it) +template +inline uint64_t read_delta(BitVectorEnumerator& it) { uint64_t l = read_gamma(it); return (it.take(l) | (uint64_t(1) << l)) - 1; diff --git a/include/pisa/query/algorithm/wand_query.hpp b/include/pisa/query/algorithm/wand_query.hpp index 9833e08fe..6ba348765 100644 --- a/include/pisa/query/algorithm/wand_query.hpp +++ b/include/pisa/query/algorithm/wand_query.hpp @@ -97,4 +97,4 @@ struct wand_query { topk_queue m_topk; }; -} // namespace pisa \ No newline at end of file +} // namespace pisa diff --git a/include/pisa/sequence/partitioned_sequence.hpp b/include/pisa/sequence/partitioned_sequence.hpp index 4ebc5c3e0..132b70dbd 100644 --- a/include/pisa/sequence/partitioned_sequence.hpp +++ b/include/pisa/sequence/partitioned_sequence.hpp @@ -231,6 +231,8 @@ namespace pisa { return m_partitions; } + [[nodiscard]] auto universe() const -> std::uint64_t { return m_universe; } + friend class partitioned_sequence_test; private: diff --git a/include/pisa/v1/base_index.hpp b/include/pisa/v1/base_index.hpp new file mode 100644 index 000000000..1a520fcfb --- /dev/null +++ b/include/pisa/v1/base_index.hpp @@ -0,0 +1,104 @@ +#pragma once + +#include +#include +#include + +#include +#include + +#include "v1/posting_builder.hpp" +#include "v1/source.hpp" +#include "v1/types.hpp" + +namespace pisa::v1 { + +using OffsetSpan = gsl::span; +using BinarySpan = gsl::span; + +[[nodiscard]] auto calc_avg_length(gsl::span const& lengths) -> float; +[[nodiscard]] auto read_sizes(std::string_view basename) -> std::vector; + +/// Lexicographically compares bigrams. +/// Used for looking up bigram mappings. 
+[[nodiscard]] auto compare_arrays(std::array const& lhs, + std::array const& rhs) -> bool; + +struct PostingData { + BinarySpan postings; + OffsetSpan offsets; +}; + +struct UnigramData { + PostingData documents; + PostingData payloads; +}; + +struct BigramData { + PostingData documents; + std::array payloads; + gsl::span const> mapping; +}; + +/// Parts of the index independent of the template parameters. +struct BaseIndex { + + template + BaseIndex(PostingData documents, + PostingData payloads, + tl::optional bigrams, + gsl::span document_lengths, + tl::optional avg_document_length, + std::unordered_map> max_scores, + std::unordered_map block_max_scores, + gsl::span quantized_max_scores, + Source source) + : m_documents(documents), + m_payloads(payloads), + m_bigrams(bigrams), + m_document_lengths(document_lengths), + m_avg_document_length(avg_document_length.map_or_else( + [](auto&& self) { return self; }, + [&]() { return calc_avg_length(m_document_lengths); })), + m_max_scores(std::move(max_scores)), + m_block_max_scores(std::move(block_max_scores)), + m_quantized_max_scores(quantized_max_scores), + m_source(std::move(source)) + { + } + + [[nodiscard]] auto num_terms() const -> std::size_t; + [[nodiscard]] auto num_documents() const -> std::size_t; + [[nodiscard]] auto document_length(DocId docid) const -> std::uint32_t; + [[nodiscard]] auto avg_document_length() const -> float; + [[nodiscard]] auto normalized_document_length(DocId docid) const -> float; + [[nodiscard]] auto bigram_id(TermId left_term, TermId right_term) const -> tl::optional; + + protected: + void assert_term_in_bounds(TermId term) const; + [[nodiscard]] auto fetch_documents(TermId term) const -> gsl::span; + [[nodiscard]] auto fetch_payloads(TermId term) const -> gsl::span; + [[nodiscard]] auto fetch_bigram_documents(TermId bigram) const -> gsl::span; + [[nodiscard]] auto fetch_bigram_payloads(TermId bigram) const + -> std::array, 2>; + template + [[nodiscard]] auto 
fetch_bigram_payloads(TermId bigram) const -> gsl::span; + + [[nodiscard]] auto max_score(std::size_t scorer_hash, TermId term) const -> float; + [[nodiscard]] auto block_max_scores(std::size_t scorer_hash) const -> UnigramData const&; + [[nodiscard]] auto quantized_max_score(TermId term) const -> std::uint8_t; + + private: + PostingData m_documents; + PostingData m_payloads; + tl::optional m_bigrams; + + gsl::span m_document_lengths; + float m_avg_document_length; + std::unordered_map> m_max_scores; + std::unordered_map m_block_max_scores; + gsl::span m_quantized_max_scores; + std::any m_source; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/bit_sequence_cursor.hpp b/include/pisa/v1/bit_sequence_cursor.hpp new file mode 100644 index 000000000..d2a52efac --- /dev/null +++ b/include/pisa/v1/bit_sequence_cursor.hpp @@ -0,0 +1,247 @@ +#pragma once + +#include +#include +#include + +#include + +#include "codec/block_codecs.hpp" +#include "codec/integer_codes.hpp" +#include "global_parameters.hpp" +#include "util/compiler_attribute.hpp" +#include "v1/base_index.hpp" +#include "v1/bit_cast.hpp" +#include "v1/bit_vector.hpp" +#include "v1/cursor_traits.hpp" +#include "v1/runtime_assert.hpp" +#include "v1/types.hpp" + +namespace pisa::v1 { + +template +struct BitSequenceCursor { + using value_type = std::uint32_t; + using sequence_enumerator_type = typename BitSequence::enumerator; + + BitSequenceCursor(std::shared_ptr bits, sequence_enumerator_type sequence_enumerator) + : m_sequence_enumerator(std::move(sequence_enumerator)), m_bits(std::move(bits)) + { + reset(); + } + + void reset() + { + m_position = 0; + m_current_value = m_sequence_enumerator.move(0).second; + } + + [[nodiscard]] constexpr auto operator*() const -> value_type + { + if (PISA_UNLIKELY(empty())) { + return sentinel(); + } + return m_current_value; + } + [[nodiscard]] constexpr auto value() const noexcept -> value_type { return *(*this); } + + constexpr void advance() + { + 
std::tie(m_position, m_current_value) = m_sequence_enumerator.next(); + } + + PISA_FLATTEN_FUNC void advance_to_position(std::size_t position) + { + std::tie(m_position, m_current_value) = m_sequence_enumerator.move(position); + } + + constexpr void advance_to_geq(value_type value) + { + std::tie(m_position, m_current_value) = m_sequence_enumerator.next_geq(value); + } + + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_position == size(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t { return m_position; } + [[nodiscard]] constexpr auto size() const -> std::size_t + { + return m_sequence_enumerator.size(); + } + [[nodiscard]] constexpr auto sentinel() const -> value_type + { + return m_sequence_enumerator.universe(); + } + + private: + std::uint64_t m_position = 0; + std::uint64_t m_current_value{}; + sequence_enumerator_type m_sequence_enumerator; + std::shared_ptr m_bits; +}; + +template +struct DocumentBitSequenceCursor : public BitSequenceCursor { + using value_type = std::uint32_t; + using sequence_enumerator_type = typename BitSequence::enumerator; + explicit DocumentBitSequenceCursor(std::shared_ptr bits, + sequence_enumerator_type sequence_enumerator) + : BitSequenceCursor(std::move(bits), std::move(sequence_enumerator)) + { + } +}; + +template +struct PayloadBitSequenceCursor : public BitSequenceCursor { + using value_type = std::uint32_t; + using sequence_enumerator_type = typename BitSequence::enumerator; + explicit PayloadBitSequenceCursor(std::shared_ptr bits, + sequence_enumerator_type sequence_enumerator) + : BitSequenceCursor(std::move(bits), std::move(sequence_enumerator)) + { + } +}; + +template +struct BitSequenceReader { + using value_type = std::uint32_t; + + [[nodiscard]] auto read(gsl::span bytes) const -> Cursor + { + runtime_assert(bytes.size() % sizeof(BitVector::storage_type) == 0, [&]() { + return fmt::format( + "Attempted to read no. 
bytes ({}) not aligned with the storage type of size {}", + bytes.size(), + sizeof(typename BitVector::storage_type)); + }); + + auto true_bit_length = bit_cast( + bytes.first(sizeof(typename BitVector::storage_type))); + bytes = bytes.subspan(sizeof(typename BitVector::storage_type)); + auto bits = std::make_shared( + gsl::span( + reinterpret_cast(bytes.data()), + bytes.size() / sizeof(BitVector::storage_type)), + true_bit_length); + BitVector::enumerator enumerator(*bits, 0); + std::uint64_t universe = read_gamma_nonzero(enumerator); + std::uint64_t n = 1; + if (universe > 1) { + n = enumerator.take(ceil_log2(universe + 1)); + } + return Cursor(bits, + typename BitSequence::enumerator( + *bits, enumerator.position(), universe + 1, n, global_parameters())); + } + + void init([[maybe_unused]] BaseIndex const& index) {} + constexpr static auto encoding() -> std::uint32_t + { + return EncodingId::BitSequence | encoding_traits::encoding_tag::encoding(); + } +}; + +template +struct DocumentBitSequenceReader + : public BitSequenceReader> { +}; +template +struct PayloadBitSequenceReader + : public BitSequenceReader> { +}; + +template +struct BitSequenceWriter { + using value_type = std::uint32_t; + BitSequenceWriter() = default; + explicit BitSequenceWriter(std::size_t num_documents) : m_num_documents(num_documents) {} + BitSequenceWriter(BitSequenceWriter const&) = default; + BitSequenceWriter(BitSequenceWriter&&) noexcept = default; + BitSequenceWriter& operator=(BitSequenceWriter const&) = default; + BitSequenceWriter& operator=(BitSequenceWriter&&) noexcept = default; + ~BitSequenceWriter() = default; + + constexpr static auto encoding() -> std::uint32_t + { + return EncodingId::BitSequence | encoding_traits::encoding_tag::encoding(); + } + + void init(pisa::binary_freq_collection const& collection) + { + m_num_documents = collection.num_docs(); + } + void push(value_type const& posting) + { + m_sum += posting; + m_postings.push_back(posting); + } + void 
push(value_type&& posting) + { + m_sum += posting; + m_postings.push_back(posting); + } + + template + [[nodiscard]] auto write(std::basic_ostream& os) const -> std::size_t + { + runtime_assert(m_num_documents.has_value(), + "Uninitialized writer. Must call `init()` before writing."); + runtime_assert(!m_postings.empty(), "Tried to write an empty posting list"); + bit_vector_builder builder; + auto universe = [&]() { + if constexpr (DocumentWriter) { + return m_num_documents.value() - 1; + } else { + return m_sum; + } + }(); + write_gamma_nonzero(builder, universe); + if (universe > 1) { + builder.append_bits(m_postings.size(), ceil_log2(universe + 1)); + } + BitSequence::write( + builder, m_postings.begin(), universe + 1, m_postings.size(), global_parameters()); + typename BitVector::storage_type true_bit_length = builder.size(); + auto data = builder.move_bits(); + os.write(reinterpret_cast(&true_bit_length), sizeof(true_bit_length)); + auto memory = gsl::as_bytes(gsl::make_span(data.data(), data.size())); + os.write(reinterpret_cast(memory.data()), memory.size()); + auto bytes_written = sizeof(true_bit_length) + memory.size(); + runtime_assert(bytes_written % sizeof(typename BitVector::storage_type) == 0, [&]() { + return fmt::format( + "Bytes written ({}) are not aligned with the storage type of size {}", + bytes_written, + sizeof(typename BitVector::storage_type)); + }); + return bytes_written; + } + + void reset() + { + m_postings.clear(); + m_sum = 0; + } + + private: + std::vector m_postings{}; + value_type m_sum = 0; + tl::optional m_num_documents{}; +}; + +template +using DocumentBitSequenceWriter = BitSequenceWriter; + +template +using PayloadBitSequenceWriter = BitSequenceWriter; + +template +struct CursorTraits> { + using Writer = DocumentBitSequenceWriter; + using Reader = DocumentBitSequenceReader; +}; + +template +struct CursorTraits> { + using Writer = PayloadBitSequenceWriter; + using Reader = PayloadBitSequenceReader; +}; + +} // namespace 
pisa::v1 diff --git a/include/pisa/v1/bit_vector.hpp b/include/pisa/v1/bit_vector.hpp new file mode 100644 index 000000000..d7ed29922 --- /dev/null +++ b/include/pisa/v1/bit_vector.hpp @@ -0,0 +1,772 @@ +#pragma once + +#include +#include + +#include "util/broadword.hpp" +#include "util/util.hpp" + +namespace pisa::v1 { + +namespace detail { + inline size_t words_for(uint64_t n) { return ceil_div(n, 64); } +} // namespace detail + +class BitVectorBuilder { + public: + using bits_type = std::vector; + + BitVectorBuilder(uint64_t size = 0, bool init = 0) : m_size(size) + { + m_bits.resize(detail::words_for(size), uint64_t(-init)); + if (size) { + m_cur_word = &m_bits.back(); + // clear padding bits + if (init && size % 64) { + *m_cur_word >>= 64 - (size % 64); + } + } + } + BitVectorBuilder(const BitVectorBuilder&) = delete; + BitVectorBuilder& operator=(const BitVectorBuilder&) = delete; + + void reserve(uint64_t size) { m_bits.reserve(detail::words_for(size)); } + + inline void push_back(bool b) + { + uint64_t pos_in_word = m_size % 64; + if (pos_in_word == 0) { + m_bits.push_back(0); + m_cur_word = &m_bits.back(); + } + *m_cur_word |= (uint64_t)b << pos_in_word; + ++m_size; + } + + inline void set(uint64_t pos, bool b) + { + uint64_t word = pos / 64; + uint64_t pos_in_word = pos % 64; + + m_bits[word] &= ~(uint64_t(1) << pos_in_word); + m_bits[word] |= uint64_t(b) << pos_in_word; + } + + inline void set_bits(uint64_t pos, uint64_t bits, size_t len) + { + assert(pos + len <= size()); + // check there are no spurious bits + assert(len == 64 || (bits >> len) == 0); + if (!len) + return; + uint64_t mask = (len == 64) ? 
uint64_t(-1) : ((uint64_t(1) << len) - 1); + uint64_t word = pos / 64; + uint64_t pos_in_word = pos % 64; + + m_bits[word] &= ~(mask << pos_in_word); + m_bits[word] |= bits << pos_in_word; + + uint64_t stored = 64 - pos_in_word; + if (stored < len) { + m_bits[word + 1] &= ~(mask >> stored); + m_bits[word + 1] |= bits >> stored; + } + } + + inline void append_bits(uint64_t bits, size_t len) + { + // check there are no spurious bits + assert(len == 64 || (bits >> len) == 0); + if (!len) + return; + uint64_t pos_in_word = m_size % 64; + m_size += len; + if (pos_in_word == 0) { + m_bits.push_back(bits); + } else { + *m_cur_word |= bits << pos_in_word; + if (len > 64 - pos_in_word) { + m_bits.push_back(bits >> (64 - pos_in_word)); + } + } + m_cur_word = &m_bits.back(); + } + + inline void zero_extend(uint64_t n) + { + m_size += n; + uint64_t needed = detail::words_for(m_size) - m_bits.size(); + if (needed) { + m_bits.insert(m_bits.end(), needed, 0); + m_cur_word = &m_bits.back(); + } + } + + inline void one_extend(uint64_t n) + { + while (n >= 64) { + append_bits(uint64_t(-1), 64); + n -= 64; + } + if (n) { + append_bits(uint64_t(-1) >> (64 - n), n); + } + } + + void append(BitVectorBuilder const& rhs) + { + if (!rhs.size()) + return; + + uint64_t pos = m_bits.size(); + uint64_t shift = size() % 64; + m_size = size() + rhs.size(); + m_bits.resize(detail::words_for(m_size)); + + if (shift == 0) { // word-aligned, easy case + std::copy(rhs.m_bits.begin(), rhs.m_bits.end(), m_bits.begin() + ptrdiff_t(pos)); + } else { + uint64_t* cur_word = &m_bits.front() + pos - 1; + for (size_t i = 0; i < rhs.m_bits.size() - 1; ++i) { + uint64_t w = rhs.m_bits[i]; + *cur_word |= w << shift; + *++cur_word = w >> (64 - shift); + } + *cur_word |= rhs.m_bits.back() << shift; + if (cur_word < &m_bits.back()) { + *++cur_word = rhs.m_bits.back() >> (64 - shift); + } + } + m_cur_word = &m_bits.back(); + } + + // reverse in place + void reverse() + { + uint64_t shift = 64 - (size() % 64); + + 
uint64_t remainder = 0; + for (size_t i = 0; i < m_bits.size(); ++i) { + uint64_t cur_word; + if (shift != 64) { // this should be hoisted out + cur_word = remainder | (m_bits[i] << shift); + remainder = m_bits[i] >> (64 - shift); + } else { + cur_word = m_bits[i]; + } + m_bits[i] = broadword::reverse_bits(cur_word); + } + assert(remainder == 0); + std::reverse(m_bits.begin(), m_bits.end()); + } + + bits_type& move_bits() + { + assert(detail::words_for(m_size) == m_bits.size()); + return m_bits; + } + + uint64_t size() const { return m_size; } + + void swap(BitVectorBuilder& other) + { + m_bits.swap(other.m_bits); + std::swap(m_size, other.m_size); + std::swap(m_cur_word, other.m_cur_word); + } + + private: + bits_type m_bits; + uint64_t m_size; + uint64_t* m_cur_word; +}; + +class BitVector { + public: + using storage_type = std::uint64_t; + BitVector() = default; + BitVector(gsl::span bits, std::size_t size) : m_bits(bits), m_size(size) {} + BitVector(BitVector const&) = default; + BitVector(BitVector&&) noexcept = default; + BitVector& operator=(BitVector const&) = default; + BitVector& operator=(BitVector&&) noexcept = default; + ~BitVector() = default; + + // template + // void map(Visitor& visit) + //{ + // visit(m_size, "m_size")(m_bits, "m_bits"); + //} + + // void swap(BitVector &other) + //{ + // std::swap(other.m_size, m_size); + // other.m_bits.swap(m_bits); + //} + + inline size_t size() const { return m_size; } + + inline bool operator[](uint64_t pos) const + { + assert(pos < m_size); + uint64_t block = pos / 64; + assert(block < m_bits.size()); + uint64_t shift = pos % 64; + return (m_bits[block] >> shift) & 1; + } + + inline uint64_t get_bits(uint64_t pos, uint64_t len) const + { + assert(pos + len <= size()); + assert(len <= 64); + if (len == 0U) { + return 0; + } + uint64_t block = pos / 64; + uint64_t shift = pos % 64; + uint64_t mask = std::numeric_limits::max() + >> (std::numeric_limits::digits - len); + if (shift + len <= 64) { + return 
m_bits[block] >> shift & mask; + } else { + return (m_bits[block] >> shift) | (m_bits[block + 1] << (64 - shift) & mask); + } + } + + // same as get_bits(pos, 64) but it can extend further size(), padding with zeros + inline uint64_t get_word(uint64_t pos) const + { + assert(pos < size()); + uint64_t block = pos / 64; + uint64_t shift = pos % 64; + uint64_t word = m_bits[block] >> shift; + if (shift && block + 1 < m_bits.size()) { + word |= m_bits[block + 1] << (64 - shift); + } + return word; + } + + // unsafe and fast version of get_word, it retrieves at least 56 bits + inline uint64_t get_word56(uint64_t pos) const + { + // XXX check endianness? + const char* ptr = reinterpret_cast(m_bits.data()); + return *(reinterpret_cast(ptr + pos / 8)) >> (pos % 8); + } + + inline uint64_t predecessor0(uint64_t pos) const + { + assert(pos < m_size); + uint64_t block = pos / 64; + uint64_t shift = 64 - pos % 64 - 1; + uint64_t word = ~m_bits[block]; + word = (word << shift) >> shift; + + unsigned long ret; + while (!broadword::msb(word, ret)) { + assert(block); + word = ~m_bits[--block]; + }; + return block * 64 + ret; + } + + inline uint64_t successor0(uint64_t pos) const + { + assert(pos < m_size); + uint64_t block = pos / 64; + uint64_t shift = pos % 64; + uint64_t word = (~m_bits[block] >> shift) << shift; + + unsigned long ret; + while (!broadword::lsb(word, ret)) { + ++block; + assert(block < m_bits.size()); + word = ~m_bits[block]; + }; + return block * 64 + ret; + } + + inline uint64_t predecessor1(uint64_t pos) const + { + assert(pos < m_size); + uint64_t block = pos / 64; + uint64_t shift = 64 - pos % 64 - 1; + uint64_t word = m_bits[block]; + word = (word << shift) >> shift; + + unsigned long ret; + while (!broadword::msb(word, ret)) { + assert(block); + word = m_bits[--block]; + }; + return block * 64 + ret; + } + + inline uint64_t successor1(uint64_t pos) const + { + assert(pos < m_size); + uint64_t block = pos / 64; + uint64_t shift = pos % 64; + uint64_t word 
= (m_bits[block] >> shift) << shift; + + unsigned long ret; + while (!broadword::lsb(word, ret)) { + ++block; + assert(block < m_bits.size()); + word = m_bits[block]; + }; + return block * 64 + ret; + } + + std::uint64_t const* data() const { return m_bits.data(); } + + struct enumerator { + enumerator() : m_bv(0), m_pos(uint64_t(-1)) {} + + enumerator(BitVector const& bv, size_t pos) : m_bv(&bv), m_pos(pos), m_buf(0), m_avail(0) + { + // m_bv->data().prefetch(m_pos / 64); + } + + inline bool next() + { + if (!m_avail) + fill_buf(); + bool b = m_buf & 1; + m_buf >>= 1; + m_avail -= 1; + m_pos += 1; + return b; + } + + inline uint64_t take(size_t l) + { + if (m_avail < l) + fill_buf(); + uint64_t val; + if (l != 64) { + val = m_buf & ((uint64_t(1) << l) - 1); + m_buf >>= l; + } else { + val = m_buf; + } + m_avail -= l; + m_pos += l; + return val; + } + + inline uint64_t skip_zeros() + { + uint64_t zs = 0; + // XXX the loop may be optimized by aligning access + while (!m_buf) { + m_pos += m_avail; + zs += m_avail; + m_avail = 0; + fill_buf(); + } + + uint64_t l = broadword::lsb(m_buf); + m_buf >>= l; + m_buf >>= 1; + m_avail -= l + 1; + m_pos += l + 1; + return zs + l; + } + + inline uint64_t position() const { return m_pos; } + + private: + inline void fill_buf() + { + m_buf = m_bv->get_word(m_pos); + m_avail = 64; + } + + BitVector const* m_bv; + size_t m_pos; + uint64_t m_buf; + size_t m_avail; + }; + + struct unary_enumerator { + unary_enumerator() : m_data(0), m_position(0), m_buf(0) {} + + unary_enumerator(BitVector const& bv, uint64_t pos) + { + m_data = bv.data(); + m_position = pos; + m_buf = m_data[pos / 64]; + // clear low bits + m_buf &= uint64_t(-1) << (pos % 64); + } + + uint64_t position() const { return m_position; } + + uint64_t next() + { + unsigned long pos_in_word; + uint64_t buf = m_buf; + while (!broadword::lsb(buf, pos_in_word)) { + m_position += 64; + buf = m_data[m_position / 64]; + } + + m_buf = buf & (buf - 1); // clear LSB + m_position = 
(m_position & ~uint64_t(63)) + pos_in_word; + return m_position; + } + + // skip to the k-th one after the current position + void skip(uint64_t k) + { + uint64_t skipped = 0; + uint64_t buf = m_buf; + uint64_t w = 0; + while (skipped + (w = broadword::popcount(buf)) <= k) { + skipped += w; + m_position += 64; + buf = m_data[m_position / 64]; + } + assert(buf); + uint64_t pos_in_word = broadword::select_in_word(buf, k - skipped); + m_buf = buf & (uint64_t(-1) << pos_in_word); + m_position = (m_position & ~uint64_t(63)) + pos_in_word; + } + + // return the position of the k-th one after the current position. + uint64_t skip_no_move(uint64_t k) + { + uint64_t position = m_position; + uint64_t skipped = 0; + uint64_t buf = m_buf; + uint64_t w = 0; + while (skipped + (w = broadword::popcount(buf)) <= k) { + skipped += w; + position += 64; + buf = m_data[position / 64]; + } + assert(buf); + uint64_t pos_in_word = broadword::select_in_word(buf, k - skipped); + position = (position & ~uint64_t(63)) + pos_in_word; + return position; + } + + // skip to the k-th zero after the current position + void skip0(uint64_t k) + { + uint64_t skipped = 0; + uint64_t pos_in_word = m_position % 64; + uint64_t buf = ~m_buf & (uint64_t(-1) << pos_in_word); + uint64_t w = 0; + while (skipped + (w = broadword::popcount(buf)) <= k) { + skipped += w; + m_position += 64; + buf = ~m_data[m_position / 64]; + } + assert(buf); + pos_in_word = broadword::select_in_word(buf, k - skipped); + m_buf = ~buf & (uint64_t(-1) << pos_in_word); + m_position = (m_position & ~uint64_t(63)) + pos_in_word; + } + + private: + uint64_t const* m_data; + uint64_t m_position; + uint64_t m_buf; + }; + + protected: + gsl::span m_bits{}; + std::size_t m_size = 0; +}; + +// struct BitVector { +// using storage_type = std::uint64_t; +// +// BitVector() = default; +// BitVector(gsl::span bits, std::size_t size) : m_bits(bits), m_size(size) +// {} BitVector(BitVector const&) = default; BitVector(BitVector&&) noexcept = 
default; +// BitVector& operator=(BitVector const&) = default; +// BitVector& operator=(BitVector&&) noexcept = default; +// ~BitVector() = default; +// +// [[nodiscard]] inline auto size() const -> std::size_t { return m_size; } +// +// [[nodiscard]] inline auto operator[](std::size_t pos) const -> bool +// { +// Expects(pos < m_size); +// std::uint64_t block = pos / 64; +// Expects(block < m_bits.size()); +// std::uint64_t shift = pos % 64; +// return ((m_bits[block] >> shift) & 1U) != 0U; +// } +// +// [[nodiscard]] inline auto get_bits(uint64_t pos, uint64_t len) const -> std::uint64_t +// { +// Expects(pos + len <= size()); +// Expects(len <= 64); +// if (len == 0U) { +// return 0; +// } +// uint64_t block = pos / 64; +// uint64_t shift = pos % 64; +// uint64_t mask = std::numeric_limits::max() +// >> (std::numeric_limits::digits - len); +// if (shift + len <= 64) { +// return m_bits[block] >> shift & mask; +// } +// return (m_bits[block] >> shift) | (m_bits[block + 1] << (64 - shift) & mask); +// } +// +// // same as get_bits(pos, 64) but it can extend further size(), padding with zeros +// [[nodiscard]] inline auto get_word(std::uint64_t pos) const -> std::uint64_t +// { +// Expects(pos < size()); +// std::uint64_t block = pos / 64; +// std::uint64_t shift = pos % 64; +// std::uint64_t word = m_bits[block] >> shift; +// if (shift > 0U && block + 1U < m_bits.size()) { +// word |= m_bits[block + 1] << (64 - shift); +// } +// return word; +// } +// +// // unsafe and fast version of get_word, it retrieves at least 56 bits +// [[nodiscard]] inline auto get_word56(std::uint64_t pos) const -> std::uint64_t +// { +// // XXX check endianness? 
+// const char* ptr = reinterpret_cast(m_bits.data()); +// return *(reinterpret_cast(ptr + pos / 8)) >> (pos % 8); +// } +// +// [[nodiscard]] inline auto predecessor0(std::uint64_t pos) const -> std::uint64_t +// { +// Expects(pos < m_size); +// std::uint64_t block = pos / 64; +// std::uint64_t shift = 64 - pos % 64 - 1; +// std::uint64_t word = ~m_bits[block]; +// word = (word << shift) >> shift; +// +// std::uint64_t ret; +// while (broadword::msb(word, ret) == 0U) { +// Expects(block); +// word = ~m_bits[--block]; +// }; +// return block * 64 + ret; +// } +// +// [[nodiscard]] inline auto successor0(std::uint64_t pos) const -> std::uint64_t +// { +// Expects(pos < m_size); +// std::uint64_t block = pos / 64; +// std::uint64_t shift = pos % 64; +// std::uint64_t word = (~m_bits[block] >> shift) << shift; +// +// std::uint64_t ret; +// while (broadword::lsb(word, ret) == 0U) { +// ++block; +// Expects(block < m_bits.size()); +// word = ~m_bits[block]; +// }; +// return block * 64 + ret; +// } +// +// [[nodiscard]] inline auto predecessor1(std::uint64_t pos) const -> std::uint64_t +// { +// Expects(pos < m_size); +// std::uint64_t block = pos / 64; +// std::uint64_t shift = 64 - pos % 64 - 1; +// std::uint64_t word = m_bits[block]; +// word = (word << shift) >> shift; +// +// std::uint64_t ret; +// while (broadword::msb(word, ret) == 0) { +// Expects(block); +// word = m_bits[--block]; +// }; +// return block * 64 + ret; +// } +// +// [[nodiscard]] inline auto successor1(std::uint64_t pos) const -> std::uint64_t +// { +// Expects(pos < m_size); +// std::uint64_t block = pos / 64; +// std::uint64_t shift = pos % 64; +// std::uint64_t word = (m_bits[block] >> shift) << shift; +// +// std::uint64_t ret; +// while (broadword::lsb(word, ret) == 0U) { +// ++block; +// Expects(block < m_bits.size()); +// word = m_bits[block]; +// }; +// return block * 64 + ret; +// } +// +// [[nodiscard]] inline auto data() const -> std::uint64_t const* { return m_bits.data(); } +// +// 
struct enumerator { +// enumerator() : m_pos(uint64_t(-1)) {} +// +// enumerator(BitVector const& bv, size_t pos) : m_bv(&bv), m_pos(pos) +// { +// intrinsics::prefetch(std::next(m_bv->data(), m_pos / 64)); +// fill_buf(); +// } +// +// inline auto next() -> bool +// { +// if (m_avail == 0) { +// fill_buf(); +// } +// bool b = (m_buf & 1U) > 0; +// m_buf >>= 1U; +// m_avail -= 1; +// m_pos += 1; +// return b; +// } +// +// inline auto take(size_t l) -> std::uint64_t +// { +// if (m_avail < l) { +// fill_buf(); +// } +// std::uint64_t val; +// if (l != 64) { +// val = m_buf & ((std::uint64_t(1) << l) - 1); +// m_buf >>= l; +// } else { +// val = m_buf; +// } +// m_avail -= l; +// m_pos += l; +// return val; +// } +// +// inline uint64_t skip_zeros() +// { +// uint64_t zs = 0; +// // XXX the loop may be optimized by aligning access +// while (m_buf == 0) { +// m_pos += m_avail; +// zs += m_avail; +// m_avail = 0; +// fill_buf(); +// } +// +// uint64_t l = broadword::lsb(m_buf); +// m_buf >>= l; +// m_buf >>= 1U; +// m_avail -= l + 1; +// m_pos += l + 1; +// return zs + l; +// } +// +// [[nodiscard]] inline auto position() const -> std::uint64_t { return m_pos; } +// +// private: +// inline void fill_buf() +// { +// m_buf = m_bv->get_word(m_pos); +// m_avail = 64; +// } +// +// BitVector const* m_bv = nullptr; +// std::size_t m_pos; +// std::uint64_t m_buf = 0; +// std::size_t m_avail = 0; +// }; +// +// struct unary_enumerator { +// unary_enumerator() = default; +// unary_enumerator(BitVector const& bv, uint64_t pos) +// { +// m_data = bv.data(); +// m_position = pos; +// m_buf = m_data[pos / 64]; +// // clear low bits +// m_buf &= uint64_t(-1) << (pos % 64); +// } +// +// [[nodiscard]] auto position() const -> std::uint64_t { return m_position; } +// +// uint64_t next() +// { +// std::uint64_t pos_in_word; +// std::uint64_t buf = m_buf; +// while (broadword::lsb(buf, pos_in_word) == 0) { +// m_position += 64; +// buf = m_data[m_position / 64]; +// } +// +// m_buf = 
buf & (buf - 1); // clear LSB +// m_position = (m_position & ~uint64_t(63)) + pos_in_word; +// return m_position; +// } +// +// // skip to the k-th one after the current position +// void skip(uint64_t k) +// { +// uint64_t skipped = 0; +// uint64_t buf = m_buf; +// uint64_t w = 0; +// while (skipped + (w = broadword::popcount(buf)) <= k) { +// skipped += w; +// m_position += 64; +// buf = m_data[m_position / 64]; +// } +// Expects(buf); +// uint64_t pos_in_word = broadword::select_in_word(buf, k - skipped); +// m_buf = buf & (uint64_t(-1) << pos_in_word); +// m_position = (m_position & ~uint64_t(63)) + pos_in_word; +// } +// +// // return the position of the k-th one after the current position. +// uint64_t skip_no_move(uint64_t k) +// { +// uint64_t position = m_position; +// uint64_t skipped = 0; +// uint64_t buf = m_buf; +// uint64_t w = 0; +// while (skipped + (w = broadword::popcount(buf)) <= k) { +// skipped += w; +// position += 64; +// buf = m_data[position / 64]; +// } +// Expects(buf); +// uint64_t pos_in_word = broadword::select_in_word(buf, k - skipped); +// position = (position & ~uint64_t(63)) + pos_in_word; +// return position; +// } +// +// // skip to the k-th zero after the current position +// void skip0(uint64_t k) +// { +// uint64_t skipped = 0; +// uint64_t pos_in_word = m_position % 64; +// uint64_t buf = ~m_buf & (uint64_t(-1) << pos_in_word); +// uint64_t w = 0; +// while (skipped + (w = broadword::popcount(buf)) <= k) { +// skipped += w; +// m_position += 64; +// buf = ~m_data[m_position / 64]; +// } +// Expects(buf); +// pos_in_word = broadword::select_in_word(buf, k - skipped); +// m_buf = ~buf & (uint64_t(-1) << pos_in_word); +// m_position = (m_position & ~uint64_t(63)) + pos_in_word; +// } +// +// private: +// std::uint64_t const* m_data; +// std::uint64_t m_position; +// std::uint64_t m_buf; +// }; +// +// private: +// gsl::span m_bits{}; +// std::size_t m_size = 0; +//}; + +} // namespace pisa::v1 diff --git 
a/include/pisa/v1/blocked_cursor.hpp b/include/pisa/v1/blocked_cursor.hpp index db2e3e681..5d9d90d23 100644 --- a/include/pisa/v1/blocked_cursor.hpp +++ b/include/pisa/v1/blocked_cursor.hpp @@ -12,6 +12,7 @@ #include #include "util/likely.hpp" +#include "v1/base_index.hpp" #include "v1/bit_cast.hpp" #include "v1/cursor_traits.hpp" #include "v1/encoding_traits.hpp" @@ -253,6 +254,7 @@ template struct GenericBlockedReader { using value_type = std::uint32_t; + void init(BaseIndex const& index) {} [[nodiscard]] auto read(gsl::span bytes) const { std::uint32_t length; @@ -294,12 +296,16 @@ template struct GenericBlockedWriter { using value_type = std::uint32_t; + GenericBlockedWriter() = default; + explicit GenericBlockedWriter([[maybe_unused]] std::size_t num_documents) {} + constexpr static auto encoding() -> std::uint32_t { return block_encoding_type() | encoding_traits::encoding_tag::encoding(); } + void init([[maybe_unused]] pisa::binary_freq_collection const& collection) {} void push(value_type const& posting) { if constexpr (DeltaEncoded) { diff --git a/include/pisa/v1/cursor/compact_elias_fano.hpp b/include/pisa/v1/cursor/compact_elias_fano.hpp new file mode 100644 index 000000000..22e050cc2 --- /dev/null +++ b/include/pisa/v1/cursor/compact_elias_fano.hpp @@ -0,0 +1,67 @@ +#include +#include + +#include "util/likely.hpp" + +namespace pisa::v1 { + +struct CompactEliasFanoCursor { + using value_type = std::uint32_t; + + /// Dereferences the current value. + //[[nodiscard]] auto operator*() const -> value_type + //{ + // if (PISA_UNLIKELY(empty())) { + // return sentinel(); + // } + // return bit_cast(gsl::as_bytes(m_bytes.subspan(m_current, sizeof(T)))); + //} + + ///// Alias for `operator*()`. + //[[nodiscard]] auto value() const noexcept -> value_type { return *(*this); } + + ///// Advances the cursor to the next position. + // constexpr void advance() { m_current += sizeof(T); } + + ///// Moves the cursor to the position `pos`. 
+ // constexpr void advance_to_position(std::size_t pos) { m_current = pos * sizeof(T); } + + ///// Moves the cursor to the next value equal or greater than `value`. + // constexpr void advance_to_geq(T value) + //{ + // while (this->value() < value) { + // advance(); + // } + //} + + ///// Returns `true` if there is no elements left. + //[[nodiscard]] constexpr auto empty() const noexcept -> bool + //{ + // return m_current == m_bytes.size(); + //} + + ///// Returns the current position. + //[[nodiscard]] constexpr auto position() const noexcept -> std::size_t + //{ + // return m_current / sizeof(T); + //} + + ///// Returns the number of elements in the list. + //[[nodiscard]] constexpr auto size() const -> std::size_t { return m_bytes.size() / sizeof(T); + //} + + [[nodiscard]] constexpr auto sentinel() const -> value_type + { + return std::numeric_limits::max(); + } + + private: + // BitVector const* m_bv; + // offsets m_of; + + std::size_t m_position; + value_type m_value; + // BitVector::unary_enumerator m_high_enumerator; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor_traits.hpp b/include/pisa/v1/cursor_traits.hpp index 3e32c69f2..186962256 100644 --- a/include/pisa/v1/cursor_traits.hpp +++ b/include/pisa/v1/cursor_traits.hpp @@ -7,4 +7,7 @@ namespace pisa::v1 { template struct CursorTraits; +template +struct EncodingTraits; + } // namespace pisa::v1 diff --git a/include/pisa/v1/default_index_runner.hpp b/include/pisa/v1/default_index_runner.hpp index 6fba14cd3..cec99a48a 100644 --- a/include/pisa/v1/default_index_runner.hpp +++ b/include/pisa/v1/default_index_runner.hpp @@ -1,27 +1,36 @@ #pragma once #include "index_types.hpp" +#include "v1/bit_sequence_cursor.hpp" #include "v1/blocked_cursor.hpp" #include "v1/index.hpp" #include "v1/index_metadata.hpp" #include "v1/raw_cursor.hpp" +#include "v1/sequence/partitioned_sequence.hpp" +#include "v1/sequence/positive_sequence.hpp" namespace pisa::v1 { [[nodiscard]] inline auto 
index_runner(IndexMetadata metadata) { - return index_runner( - std::move(metadata), - std::make_tuple(RawReader{}, DocumentBlockedReader<::pisa::simdbp_block>{}), - std::make_tuple(RawReader{}, PayloadBlockedReader<::pisa::simdbp_block>{})); + return index_runner(std::move(metadata), + std::make_tuple(RawReader{}, + DocumentBlockedReader<::pisa::simdbp_block>{}, + DocumentBitSequenceReader{}, + DocumentBitSequenceReader>{}), + std::make_tuple(RawReader{}, + PayloadBlockedReader<::pisa::simdbp_block>{}, + PayloadBitSequenceReader>{})); } [[nodiscard]] inline auto scored_index_runner(IndexMetadata metadata) { - return scored_index_runner( - std::move(metadata), - std::make_tuple(RawReader{}, DocumentBlockedReader<::pisa::simdbp_block>{}), - std::make_tuple(RawReader{})); + return scored_index_runner(std::move(metadata), + std::make_tuple(RawReader{}, + DocumentBlockedReader<::pisa::simdbp_block>{}, + DocumentBitSequenceReader{}, + DocumentBitSequenceReader>{}), + std::make_tuple(RawReader{})); } } // namespace pisa::v1 diff --git a/include/pisa/v1/encoding_traits.hpp b/include/pisa/v1/encoding_traits.hpp index 6ec372118..983fcf2d2 100644 --- a/include/pisa/v1/encoding_traits.hpp +++ b/include/pisa/v1/encoding_traits.hpp @@ -1,6 +1,9 @@ #pragma once #include "codec/simdbp.hpp" +#include "v1/sequence/indexed_sequence.hpp" +#include "v1/sequence/partitioned_sequence.hpp" +#include "v1/sequence/positive_sequence.hpp" #include "v1/types.hpp" namespace pisa::v1 { @@ -14,4 +17,31 @@ struct encoding_traits<::pisa::simdbp_block> { using encoding_tag = SimdBPTag; }; +struct PartitionedSequenceTag { + [[nodiscard]] static auto encoding() -> std::uint32_t { return EncodingId::PEF; } +}; + +template <> +struct encoding_traits> { + using encoding_tag = PartitionedSequenceTag; +}; + +struct IndexedSequenceTag { + [[nodiscard]] static auto encoding() -> std::uint32_t { return 17U; } +}; + +template <> +struct encoding_traits { + using encoding_tag = IndexedSequenceTag; +}; + 
+struct PositiveSequenceTag { + [[nodiscard]] static auto encoding() -> std::uint32_t { return EncodingId::PositiveSeq; } +}; + +template <> +struct encoding_traits> { + using encoding_tag = PositiveSequenceTag; +}; + } // namespace pisa::v1 diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index e7c46c05d..88b5c3c82 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -7,6 +7,7 @@ #include #include +#include "v1/base_index.hpp" #include "v1/bit_cast.hpp" #include "v1/cursor/for_each.hpp" #include "v1/cursor/scoring_cursor.hpp" @@ -19,101 +20,6 @@ namespace pisa::v1 { -using OffsetSpan = gsl::span; -using BinarySpan = gsl::span; - -[[nodiscard]] auto calc_avg_length(gsl::span const& lengths) -> float; -[[nodiscard]] auto read_sizes(std::string_view basename) -> std::vector; - -/// Lexicographically compares bigrams. -/// Used for looking up bigram mappings. -[[nodiscard]] auto compare_arrays(std::array const& lhs, - std::array const& rhs) -> bool; - -struct PostingData { - BinarySpan postings; - OffsetSpan offsets; -}; - -struct UnigramData { - PostingData documents; - PostingData payloads; -}; - -struct BigramData { - PostingData documents; - std::array payloads; - gsl::span const> mapping; -}; - -/// Parts of the index independent of the template parameters. 
-struct BaseIndex { - - template - BaseIndex(PostingData documents, - PostingData payloads, - tl::optional bigrams, - gsl::span document_lengths, - tl::optional avg_document_length, - std::unordered_map> max_scores, - std::unordered_map block_max_scores, - gsl::span quantized_max_scores, - Source source) - : m_documents(documents), - m_payloads(payloads), - m_bigrams(bigrams), - m_document_lengths(document_lengths), - m_avg_document_length(avg_document_length.map_or_else( - [](auto&& self) { return self; }, - [&]() { return calc_avg_length(m_document_lengths); })), - m_max_scores(std::move(max_scores)), - m_block_max_scores(std::move(block_max_scores)), - m_quantized_max_scores(quantized_max_scores), - m_source(std::move(source)) - { - } - - [[nodiscard]] auto num_terms() const -> std::size_t; - [[nodiscard]] auto num_documents() const -> std::size_t; - [[nodiscard]] auto document_length(DocId docid) const -> std::uint32_t; - [[nodiscard]] auto avg_document_length() const -> float; - [[nodiscard]] auto normalized_document_length(DocId docid) const -> float; - [[nodiscard]] auto bigram_id(TermId left_term, TermId right_term) const -> tl::optional; - - protected: - void assert_term_in_bounds(TermId term) const; - [[nodiscard]] auto fetch_documents(TermId term) const -> gsl::span; - [[nodiscard]] auto fetch_payloads(TermId term) const -> gsl::span; - [[nodiscard]] auto fetch_bigram_documents(TermId bigram) const -> gsl::span; - [[nodiscard]] auto fetch_bigram_payloads(TermId bigram) const - -> std::array, 2>; - template - [[nodiscard]] auto fetch_bigram_payloads(TermId bigram) const -> gsl::span; - - [[nodiscard]] auto max_score(std::size_t scorer_hash, TermId term) const -> float; - [[nodiscard]] auto block_max_scores(std::size_t scorer_hash) const -> UnigramData const&; - [[nodiscard]] auto quantized_max_score(TermId term) const -> std::uint8_t; - [[nodiscard]] auto block_max_document_reader() const -> Reader> const&; - [[nodiscard]] auto block_max_score_reader() 
const -> Reader> const&; - - private: - PostingData m_documents; - PostingData m_payloads; - tl::optional m_bigrams; - - gsl::span m_document_lengths; - float m_avg_document_length; - std::unordered_map> m_max_scores; - std::unordered_map m_block_max_scores; - gsl::span m_quantized_max_scores; - std::any m_source; - - Reader> m_block_max_document_reader = - Reader>(RawReader{}); - Reader> m_block_max_score_reader = - Reader>(RawReader{}); -}; - /// A generic type for an inverted index. /// /// \tparam DocumentReader Type of an object that reads document posting lists from bytes @@ -361,9 +267,24 @@ struct Index : public BaseIndex { return documents(term).size(); } + [[nodiscard]] auto block_max_document_reader() const -> Reader> const& + { + return m_block_max_document_reader; + } + + [[nodiscard]] auto block_max_score_reader() const -> Reader> const& + { + return m_block_max_score_reader; + } + private: Reader m_document_reader; Reader m_payload_reader; + + Reader> m_block_max_document_reader = + Reader>(RawReader{}); + Reader> m_block_max_score_reader = + Reader>(RawReader{}); }; template @@ -483,7 +404,26 @@ struct IndexRunner { }, m_document_readers); if (not result) { - throw std::domain_error("Unknown posting encoding"); + std::ostringstream os; + os << fmt::format( + "Unknown posting encoding. Requested document: " + "{:x} ({:b}), payload: {:x} ({:b})\n", + dheader.encoding, + static_cast(to_byte(dheader.type)), + pheader.encoding, + static_cast(to_byte(pheader.type))); + auto print_reader = [&](auto&& reader) { + os << fmt::format( + "\t{:x} ({:b})\n", + reader.encoding(), + static_cast(to_byte( + value_type::value_type>()))); + }; + os << "Available document readers: \n"; + std::apply([&](auto... readers) { (print_reader(readers), ...); }, m_document_readers); + os << "Available payload readers: \n"; + std::apply([&](auto... 
readers) { (print_reader(readers), ...); }, m_payload_readers); + throw std::domain_error(os.str()); } } diff --git a/include/pisa/v1/index_builder.hpp b/include/pisa/v1/index_builder.hpp index 153040574..844a2aa06 100644 --- a/include/pisa/v1/index_builder.hpp +++ b/include/pisa/v1/index_builder.hpp @@ -16,9 +16,13 @@ namespace pisa::v1 { -template +template struct IndexBuilder { - explicit IndexBuilder(Writers... writers) : m_writers(std::move(writers)...) {} + explicit IndexBuilder(DocumentWriters document_writers, PayloadWriters payload_writers) + : m_document_writers(std::move(document_writers)), + m_payload_writers(std::move(payload_writers)) + { + } template void operator()(Encoding document_encoding, Encoding payload_encoding, Fn fn) @@ -32,28 +36,30 @@ struct IndexBuilder { return false; }; bool success = std::apply( - [&](Writers... dwriters) { + [&](auto... dwriters) { auto with_document_writer = [&](auto dwriter) { return std::apply( - [&](Writers... pwriters) { return (run(dwriter, pwriters) || ...); }, - m_writers); + [&](auto... pwriters) { return (run(dwriter, pwriters) || ...); }, + m_payload_writers); }; return (with_document_writer(dwriters) || ...); }, - m_writers); + m_document_writers); if (not success) { throw std::domain_error("Unknown posting encoding"); } } private: - std::tuple m_writers; + DocumentWriters m_document_writers; + PayloadWriters m_payload_writers; }; -template -auto make_index_builder(Writers... 
writers) +template +auto make_index_builder(DocumentWriters document_writers, PayloadWriters payload_writers) { - return IndexBuilder(std::move(writers)...); + return IndexBuilder(std::move(document_writers), + std::move(payload_writers)); } template @@ -101,6 +107,8 @@ inline void compress_binary_collection(std::string const& input, Writer frequency_writer) { pisa::binary_freq_collection const collection(input.c_str()); + document_writer.init(collection); + frequency_writer.init(collection); ProgressStatus status(collection.size(), DefaultProgress("Compressing in parallel"), std::chrono::milliseconds(100)); diff --git a/include/pisa/v1/index_metadata.hpp b/include/pisa/v1/index_metadata.hpp index acb08ce70..0f76d9ea6 100644 --- a/include/pisa/v1/index_metadata.hpp +++ b/include/pisa/v1/index_metadata.hpp @@ -166,7 +166,7 @@ template std::move(source), std::move(document_readers), std::move(payload_readers)); -} // namespace pisa::v1 +} template [[nodiscard]] inline auto scored_index_runner(IndexMetadata metadata, diff --git a/include/pisa/v1/maxscore.hpp b/include/pisa/v1/maxscore.hpp index 5b69b584f..2552b97e8 100644 --- a/include/pisa/v1/maxscore.hpp +++ b/include/pisa/v1/maxscore.hpp @@ -91,6 +91,10 @@ struct MaxScoreJoin { constexpr void advance() { + while (m_non_essential_count < m_cursors.size() + && not m_above_threshold(m_upper_bounds[m_non_essential_count])) { + m_non_essential_count += 1; + } bool exit = false; while (not exit) { if (PISA_UNLIKELY(m_non_essential_count == m_cursors.size() @@ -139,11 +143,6 @@ struct MaxScoreJoin { } } } - - while (m_non_essential_count < m_cursors.size() - && not m_above_threshold(m_upper_bounds[m_non_essential_count])) { - m_non_essential_count += 1; - } } [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; // TODO(michal) diff --git a/include/pisa/v1/raw_cursor.hpp b/include/pisa/v1/raw_cursor.hpp index d18e18b1e..2c47d0640 100644 --- a/include/pisa/v1/raw_cursor.hpp +++ 
b/include/pisa/v1/raw_cursor.hpp @@ -12,6 +12,7 @@ #include #include "util/likely.hpp" +#include "v1/base_index.hpp" #include "v1/bit_cast.hpp" #include "v1/cursor_traits.hpp" #include "v1/types.hpp" @@ -116,6 +117,7 @@ struct RawReader { return RawCursor(bytes); } + void init(BaseIndex const& index) {} constexpr static auto encoding() -> std::uint32_t { return EncodingId::Raw + sizeof(T); } }; @@ -124,8 +126,12 @@ struct RawWriter { static_assert(std::is_trivially_copyable::value); using value_type = T; + RawWriter() = default; + explicit RawWriter([[maybe_unused]] std::size_t num_documents) {} + constexpr static auto encoding() -> std::uint32_t { return EncodingId::Raw + sizeof(T); } + void init([[maybe_unused]] pisa::binary_freq_collection const& collection) {} void push(T const& posting) { m_postings.push_back(posting); } void push(T&& posting) { m_postings.push_back(posting); } diff --git a/include/pisa/v1/runtime_assert.hpp b/include/pisa/v1/runtime_assert.hpp new file mode 100644 index 000000000..82635289a --- /dev/null +++ b/include/pisa/v1/runtime_assert.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +namespace pisa::v1 { + +template +inline void runtime_assert(bool condition, Message&& message) +{ + if (not condition) { + if constexpr (std::is_invocable_r_v) { + throw std::runtime_error(message()); + } else { + throw std::runtime_error(std::forward(message)); + } + } +} + +} // namespace pisa::v1 diff --git a/include/pisa/v1/sequence/indexed_sequence.hpp b/include/pisa/v1/sequence/indexed_sequence.hpp new file mode 100644 index 000000000..1f44fe7dc --- /dev/null +++ b/include/pisa/v1/sequence/indexed_sequence.hpp @@ -0,0 +1,939 @@ +#pragma once + +#include + +#include + +#include "bit_vector.hpp" +#include "global_parameters.hpp" +#include "util/compiler_attribute.hpp" +#include "util/likely.hpp" +#include "v1/bit_vector.hpp" + +namespace pisa::v1 { + +[[nodiscard]] constexpr auto positive(std::uint64_t n) -> std::uint64_t +{ + if (n == 0) { + 
throw std::logic_error("argument must be positive"); + } + return n; +} + +struct CompactEliasFano { + + struct offsets { + offsets() {} + + offsets(uint64_t base_offset, + uint64_t universe, + uint64_t n, + global_parameters const& params) + : universe(universe), + n(positive(n)), + log_sampling0(params.ef_log_sampling0), + log_sampling1(params.ef_log_sampling1), + lower_bits(universe > n ? broadword::msb(universe / n) : 0), + mask((uint64_t(1) << lower_bits) - 1), + // pad with a zero on both sides as sentinels + higher_bits_length(n + (universe >> lower_bits) + 2), + pointer_size(ceil_log2(higher_bits_length)), + pointers0((higher_bits_length - n) >> log_sampling0), // XXX + pointers1(n >> log_sampling1), + pointers0_offset(base_offset), + pointers1_offset(pointers0_offset + pointers0 * pointer_size), + higher_bits_offset(pointers1_offset + pointers1 * pointer_size), + lower_bits_offset(higher_bits_offset + higher_bits_length), + end(lower_bits_offset + n * lower_bits) + { + } + + uint64_t universe; + uint64_t n; + uint64_t log_sampling0; + uint64_t log_sampling1; + + uint64_t lower_bits; + uint64_t mask; + uint64_t higher_bits_length; + uint64_t pointer_size; + uint64_t pointers0; + uint64_t pointers1; + + uint64_t pointers0_offset; + uint64_t pointers1_offset; + uint64_t higher_bits_offset; + uint64_t lower_bits_offset; + uint64_t end; + }; + + static PISA_FLATTEN_FUNC uint64_t bitsize(global_parameters const& params, + uint64_t universe, + uint64_t n) + { + return offsets(0, universe, n, params).end; + } + + template + static void write(bit_vector_builder& bvb, + Iterator begin, + uint64_t universe, + uint64_t n, + global_parameters const& params) + { + uint64_t base_offset = bvb.size(); + offsets of(base_offset, universe, n, params); + // initialize all the bits to 0 + bvb.zero_extend(of.end - base_offset); + + uint64_t sample1_mask = (uint64_t(1) << of.log_sampling1) - 1; + uint64_t offset; + + // utility function to set 0 pointers + auto set_ptr0s = 
[&](uint64_t begin, uint64_t end, uint64_t rank_end) { + uint64_t begin_zeros = begin - rank_end; + uint64_t end_zeros = end - rank_end; + + for (uint64_t ptr0 = ceil_div(begin_zeros, uint64_t(1) << of.log_sampling0); + (ptr0 << of.log_sampling0) < end_zeros; + ++ptr0) { + if (!ptr0) + continue; + offset = of.pointers0_offset + (ptr0 - 1) * of.pointer_size; + assert(offset + of.pointer_size <= of.pointers1_offset); + bvb.set_bits(offset, (ptr0 << of.log_sampling0) + rank_end, of.pointer_size); + } + }; + + uint64_t last = 0; + uint64_t last_high = 0; + Iterator it = begin; + for (size_t i = 0; i < n; ++i) { + uint64_t v = *it++; + + if (i && v < last) { + throw std::runtime_error("Sequence is not sorted"); + } + + assert(v < universe); + + uint64_t high = (v >> of.lower_bits) + i + 1; + uint64_t low = v & of.mask; + + bvb.set(of.higher_bits_offset + high, 1); + + offset = of.lower_bits_offset + i * of.lower_bits; + assert(offset + of.lower_bits <= of.end); + bvb.set_bits(offset, low, of.lower_bits); + + if (i && (i & sample1_mask) == 0) { + uint64_t ptr1 = i >> of.log_sampling1; + assert(ptr1 > 0); + offset = of.pointers1_offset + (ptr1 - 1) * of.pointer_size; + assert(offset + of.pointer_size <= of.higher_bits_offset); + bvb.set_bits(offset, high, of.pointer_size); + } + + // write pointers for the run of zeros in [last_high, high) + set_ptr0s(last_high + 1, high, i); + last_high = high; + last = v; + } + + // pointers to zeros after the last 1 + set_ptr0s(last_high + 1, of.higher_bits_length, n); // XXX + } + + class enumerator { + public: + typedef std::pair value_type; // (position, value) + + enumerator() {} + + enumerator(BitVector const& bv, + uint64_t offset, + uint64_t universe, + uint64_t n, + global_parameters const& params) + : m_bv(&bv), + m_of(offset, universe, n, params), + m_position(size()), + m_value(m_of.universe) + { + } + + value_type move(uint64_t position) + { + assert(position <= m_of.n); + + if (position == m_position) { + return value(); + 
} + + uint64_t skip = position - m_position; + // optimize small forward skips + if (PISA_LIKELY(position > m_position && skip <= linear_scan_threshold)) { + m_position = position; + if (PISA_UNLIKELY(m_position == size())) { + m_value = m_of.universe; + } else { + BitVector::unary_enumerator he = m_high_enumerator; + for (size_t i = 0; i < skip; ++i) { + he.next(); + } + m_value = ((he.position() - m_of.higher_bits_offset - m_position - 1) + << m_of.lower_bits) + | read_low(); + m_high_enumerator = he; + } + return value(); + } + + return slow_move(position); + } + + value_type next_geq(uint64_t lower_bound) + { + if (lower_bound == m_value) { + return value(); + } + + uint64_t high_lower_bound = lower_bound >> m_of.lower_bits; + uint64_t cur_high = m_value >> m_of.lower_bits; + uint64_t high_diff = high_lower_bound - cur_high; + + if (PISA_LIKELY(lower_bound > m_value && high_diff <= linear_scan_threshold)) { + // optimize small skips + next_reader next_value(*this, m_position + 1); + uint64_t val; + do { + m_position += 1; + if (PISA_LIKELY(m_position < size())) { + val = next_value(); + } else { + m_position = size(); + val = m_of.universe; + break; + } + } while (val < lower_bound); + + m_value = val; + return value(); + } else { + return slow_next_geq(lower_bound); + } + } + + uint64_t size() const { return m_of.n; } + + value_type next() + { + m_position += 1; + assert(m_position <= size()); + + if (PISA_LIKELY(m_position < size())) { + m_value = read_next(); + } else { + m_value = m_of.universe; + } + return value(); + } + + uint64_t prev_value() const + { + if (m_position == 0) { + return 0; + } + + uint64_t prev_high = 0; + if (PISA_LIKELY(m_position < size())) { + prev_high = m_bv->predecessor1(m_high_enumerator.position() - 1); + } else { + prev_high = m_bv->predecessor1(m_of.lower_bits_offset - 1); + } + prev_high -= m_of.higher_bits_offset; + + uint64_t prev_pos = m_position - 1; + uint64_t prev_low = + m_bv->get_word56(m_of.lower_bits_offset + 
prev_pos * m_of.lower_bits) & m_of.mask; + return ((prev_high - prev_pos - 1) << m_of.lower_bits) | prev_low; + } + + uint64_t position() const { return m_position; } + + inline value_type value() const { return value_type(m_position, m_value); } + + private: + value_type PISA_NOINLINE slow_move(uint64_t position) + { + if (PISA_UNLIKELY(position == size())) { + m_position = position; + m_value = m_of.universe; + return value(); + } + + uint64_t skip = position - m_position; + uint64_t to_skip; + if (position > m_position && (skip >> m_of.log_sampling1) == 0) { + to_skip = skip - 1; + } else { + uint64_t ptr = position >> m_of.log_sampling1; + uint64_t high_pos = pointer1(ptr); + uint64_t high_rank = ptr << m_of.log_sampling1; + m_high_enumerator = + BitVector::unary_enumerator(*m_bv, m_of.higher_bits_offset + high_pos); + to_skip = position - high_rank; + } + + m_high_enumerator.skip(to_skip); + m_position = position; + m_value = read_next(); + return value(); + } + + value_type PISA_NOINLINE slow_next_geq(uint64_t lower_bound) + { + if (PISA_UNLIKELY(lower_bound >= m_of.universe)) { + return move(size()); + } + + uint64_t high_lower_bound = lower_bound >> m_of.lower_bits; + uint64_t cur_high = m_value >> m_of.lower_bits; + uint64_t high_diff = high_lower_bound - cur_high; + + // XXX bounds checking! 
+ uint64_t to_skip; + if (lower_bound > m_value && (high_diff >> m_of.log_sampling0) == 0) { + // note: at the current position in the bitvector there + // should be a 1, but since we already consumed it, it + // is 0 in the enumerator, so we need to skip it + to_skip = high_diff; + } else { + uint64_t ptr = high_lower_bound >> m_of.log_sampling0; + uint64_t high_pos = pointer0(ptr); + uint64_t high_rank0 = ptr << m_of.log_sampling0; + + m_high_enumerator = + BitVector::unary_enumerator(*m_bv, m_of.higher_bits_offset + high_pos); + to_skip = high_lower_bound - high_rank0; + } + + m_high_enumerator.skip0(to_skip); + m_position = m_high_enumerator.position() - m_of.higher_bits_offset - high_lower_bound; + + next_reader read_value(*this, m_position); + while (true) { + if (PISA_UNLIKELY(m_position == size())) { + m_value = m_of.universe; + return value(); + } + auto val = read_value(); + if (val >= lower_bound) { + m_value = val; + return value(); + } + m_position++; + } + } + + static const uint64_t linear_scan_threshold = 8; + + inline uint64_t read_low() + { + return m_bv->get_word56(m_of.lower_bits_offset + m_position * m_of.lower_bits) + & m_of.mask; + } + + inline uint64_t read_next() + { + assert(m_position < size()); + uint64_t high = m_high_enumerator.next() - m_of.higher_bits_offset; + return ((high - m_position - 1) << m_of.lower_bits) | read_low(); + } + + struct next_reader { + next_reader(enumerator& e, uint64_t position) + : e(e), + high_enumerator(e.m_high_enumerator), + high_base(e.m_of.higher_bits_offset + position + 1), + lower_bits(e.m_of.lower_bits), + lower_base(e.m_of.lower_bits_offset + position * lower_bits), + mask(e.m_of.mask), + bv(*e.m_bv) + { + } + + ~next_reader() { e.m_high_enumerator = high_enumerator; } + + uint64_t operator()() + { + uint64_t high = high_enumerator.next() - high_base; + uint64_t low = bv.get_word56(lower_base) & mask; + high_base += 1; + lower_base += lower_bits; + return (high << lower_bits) | low; + } + + 
enumerator& e; + BitVector::unary_enumerator high_enumerator; + uint64_t high_base, lower_bits, lower_base, mask; + BitVector const& bv; + }; + + inline uint64_t pointer(uint64_t offset, uint64_t i) const + { + if (i == 0) { + return 0; + } else { + return m_bv->get_word56(offset + (i - 1) * m_of.pointer_size) + & ((uint64_t(1) << m_of.pointer_size) - 1); + } + } + + inline uint64_t pointer0(uint64_t i) const { return pointer(m_of.pointers0_offset, i); } + + inline uint64_t pointer1(uint64_t i) const { return pointer(m_of.pointers1_offset, i); } + + BitVector const* m_bv; + offsets m_of; + + uint64_t m_position; + uint64_t m_value; + BitVector::unary_enumerator m_high_enumerator; + }; +}; + +struct CompactRankedBitvector { + + struct offsets { + offsets(uint64_t base_offset, + uint64_t universe, + uint64_t n, + global_parameters const& params) + : universe(universe), + n(n), + log_rank1_sampling(params.rb_log_rank1_sampling), + log_sampling1(params.rb_log_sampling1) + + , + rank1_sample_size(ceil_log2(n + 1)), + pointer_size(ceil_log2(universe)), + rank1_samples(universe >> params.rb_log_rank1_sampling), + pointers1(n >> params.rb_log_sampling1) + + , + rank1_samples_offset(base_offset), + pointers1_offset(rank1_samples_offset + rank1_samples * rank1_sample_size), + bits_offset(pointers1_offset + pointers1 * pointer_size), + end(bits_offset + universe) + { + } + + uint64_t universe; + uint64_t n; + uint64_t log_rank1_sampling; + uint64_t log_sampling1; + + uint64_t rank1_sample_size; + uint64_t pointer_size; + + uint64_t rank1_samples; + uint64_t pointers1; + + uint64_t rank1_samples_offset; + uint64_t pointers1_offset; + uint64_t bits_offset; + uint64_t end; + }; + + static PISA_FLATTEN_FUNC uint64_t bitsize(global_parameters const& params, + uint64_t universe, + uint64_t n) + { + return offsets(0, universe, n, params).end; + } + + template + static void write(bit_vector_builder& bvb, + Iterator begin, + uint64_t universe, + uint64_t n, + global_parameters const& 
params) + { + uint64_t base_offset = bvb.size(); + offsets of(base_offset, universe, n, params); + // initialize all the bits to 0 + bvb.zero_extend(of.end - base_offset); + + uint64_t offset; + + auto set_rank1_samples = [&](uint64_t begin, uint64_t end, uint64_t rank) { + for (uint64_t sample = ceil_div(begin, uint64_t(1) << of.log_rank1_sampling); + (sample << of.log_rank1_sampling) < end; + ++sample) { + if (!sample) + continue; + offset = of.rank1_samples_offset + (sample - 1) * of.rank1_sample_size; + assert(offset + of.rank1_sample_size <= of.pointers1_offset); + bvb.set_bits(offset, rank, of.rank1_sample_size); + } + }; + + uint64_t sample1_mask = (uint64_t(1) << of.log_sampling1) - 1; + uint64_t last = 0; + Iterator it = begin; + for (size_t i = 0; i < n; ++i) { + uint64_t v = *it++; + if (i && v == last) { + throw std::runtime_error("Duplicate element"); + } + if (i && v < last) { + throw std::runtime_error("Sequence is not sorted"); + } + + assert(!i || v > last); + assert(v <= universe); + + bvb.set(of.bits_offset + v, 1); + + if (i && (i & sample1_mask) == 0) { + uint64_t ptr1 = i >> of.log_sampling1; + assert(ptr1 > 0); + offset = of.pointers1_offset + (ptr1 - 1) * of.pointer_size; + assert(offset + of.pointer_size <= of.bits_offset); + bvb.set_bits(offset, v, of.pointer_size); + } + + set_rank1_samples(last + 1, v + 1, i); + last = v; + } + + set_rank1_samples(last + 1, universe, n); + } + + class enumerator { + public: + typedef std::pair value_type; // (position, value) + + enumerator(BitVector const& bv, + uint64_t offset, + uint64_t universe, + uint64_t n, + global_parameters const& params) + : m_bv(&bv), + m_of(offset, universe, n, params), + m_position(size()), + m_value(m_of.universe) + { + } + + value_type move(uint64_t position) + { + assert(position <= size()); + + if (position == m_position) { + return value(); + } + + // optimize small forward skips + uint64_t skip = position - m_position; + if (PISA_LIKELY(position > m_position && skip 
<= linear_scan_threshold)) { + m_position = position; + if (PISA_UNLIKELY(m_position == size())) { + m_value = m_of.universe; + } else { + BitVector::unary_enumerator he = m_enumerator; + for (size_t i = 0; i < skip; ++i) { + he.next(); + } + m_value = he.position() - m_of.bits_offset; + m_enumerator = he; + } + + return value(); + } + + return slow_move(position); + } + + value_type next_geq(uint64_t lower_bound) + { + if (lower_bound == m_value) { + return value(); + } + + uint64_t diff = lower_bound - m_value; + if (PISA_LIKELY(lower_bound > m_value && diff <= linear_scan_threshold)) { + // optimize small skips + BitVector::unary_enumerator he = m_enumerator; + uint64_t val; + do { + m_position += 1; + if (PISA_LIKELY(m_position < size())) { + val = he.next() - m_of.bits_offset; + } else { + m_position = size(); + val = m_of.universe; + break; + } + } while (val < lower_bound); + + m_value = val; + m_enumerator = he; + return value(); + } else { + return slow_next_geq(lower_bound); + } + } + + value_type next() + { + m_position += 1; + assert(m_position <= size()); + + if (PISA_LIKELY(m_position < size())) { + m_value = read_next(); + } else { + m_value = m_of.universe; + } + return value(); + } + + uint64_t size() const { return m_of.n; } + + uint64_t prev_value() const + { + if (m_position == 0) { + return 0; + } + + uint64_t pos = 0; + if (PISA_LIKELY(m_position < size())) { + pos = m_bv->predecessor1(m_enumerator.position() - 1); + } else { + pos = m_bv->predecessor1(m_of.end - 1); + } + + return pos - m_of.bits_offset; + } + + private: + value_type PISA_NOINLINE slow_move(uint64_t position) + { + uint64_t skip = position - m_position; + if (PISA_UNLIKELY(position == size())) { + m_position = position; + m_value = m_of.universe; + return value(); + } + + uint64_t to_skip; + if (position > m_position && (skip >> m_of.log_sampling1) == 0) { + to_skip = skip - 1; + } else { + uint64_t ptr = position >> m_of.log_sampling1; + uint64_t ptr_pos = pointer1(ptr); + + 
m_enumerator = BitVector::unary_enumerator(*m_bv, m_of.bits_offset + ptr_pos); + to_skip = position - (ptr << m_of.log_sampling1); + } + + m_enumerator.skip(to_skip); + m_position = position; + m_value = read_next(); + + return value(); + } + + value_type PISA_NOINLINE slow_next_geq(uint64_t lower_bound) + { + using broadword::popcount; + + if (PISA_UNLIKELY(lower_bound >= m_of.universe)) { + return move(size()); + } + + uint64_t skip = lower_bound - m_value; + m_enumerator = BitVector::unary_enumerator(*m_bv, m_of.bits_offset + lower_bound); + + uint64_t begin; + if (lower_bound > m_value && (skip >> m_of.log_rank1_sampling) == 0) { + begin = m_of.bits_offset + m_value; + } else { + uint64_t block = lower_bound >> m_of.log_rank1_sampling; + m_position = rank1_sample(block); + + begin = m_of.bits_offset + (block << m_of.log_rank1_sampling); + } + + uint64_t end = m_of.bits_offset + lower_bound; + uint64_t begin_word = begin / 64; + uint64_t begin_shift = begin % 64; + uint64_t end_word = end / 64; + uint64_t end_shift = end % 64; + uint64_t word = (m_bv->data()[begin_word] >> begin_shift) << begin_shift; + + while (begin_word < end_word) { + m_position += popcount(word); + word = m_bv->data()[++begin_word]; + } + if (end_shift) { + m_position += popcount(word << (64 - end_shift)); + } + + if (m_position < size()) { + m_value = read_next(); + } else { + m_value = m_of.universe; + } + + return value(); + } + + static const uint64_t linear_scan_threshold = 8; + + inline value_type value() const { return value_type(m_position, m_value); } + + inline uint64_t read_next() { return m_enumerator.next() - m_of.bits_offset; } + + inline uint64_t pointer(uint64_t offset, uint64_t i, uint64_t size) const + { + if (i == 0) { + return 0; + } else { + return m_bv->get_word56(offset + (i - 1) * size) & ((uint64_t(1) << size) - 1); + } + } + + inline uint64_t pointer1(uint64_t i) const + { + return pointer(m_of.pointers1_offset, i, m_of.pointer_size); + } + + inline uint64_t 
rank1_sample(uint64_t i) const + { + return pointer(m_of.rank1_samples_offset, i, m_of.rank1_sample_size); + } + + BitVector const* m_bv; + offsets m_of; + + uint64_t m_position; + uint64_t m_value; + BitVector::unary_enumerator m_enumerator; + }; +}; + +struct AllOnesSequence { + + inline static uint64_t bitsize(global_parameters const& /* params */, + uint64_t universe, + uint64_t n) + { + return (universe == n) ? 0 : uint64_t(-1); + } + + template + static void write( + bit_vector_builder&, Iterator, uint64_t universe, uint64_t n, global_parameters const&) + { + assert(universe == n); + (void)universe; + (void)n; + } + + class enumerator { + public: + typedef std::pair value_type; // (position, value) + + enumerator( + BitVector const&, uint64_t, uint64_t universe, uint64_t n, global_parameters const&) + : m_universe(universe), m_position(size()) + { + assert(universe == n); + (void)n; + } + + value_type move(uint64_t position) + { + assert(position <= size()); + m_position = position; + return value_type(m_position, m_position); + } + + value_type next_geq(uint64_t lower_bound) + { + assert(lower_bound <= size()); + m_position = lower_bound; + return value_type(m_position, m_position); + } + + value_type next() + { + m_position += 1; + return value_type(m_position, m_position); + } + + uint64_t size() const { return m_universe; } + + uint64_t prev_value() const + { + if (m_position == 0) { + return 0; + } + return m_position - 1; + } + + private: + uint64_t m_universe; + uint64_t m_position; + }; +}; + +struct IndexedSequence { + + enum index_type { + elias_fano = 0, + ranked_bitvector = 1, + all_ones = 2, + + index_types = 3 + }; + + static const uint64_t type_bits = 1; // all_ones is implicit + + static PISA_FLATTEN_FUNC uint64_t bitsize(global_parameters const& params, + uint64_t universe, + uint64_t n) + { + uint64_t best_cost = AllOnesSequence::bitsize(params, universe, n); + + uint64_t ef_cost = CompactEliasFano::bitsize(params, universe, n) + type_bits; 
+ if (ef_cost < best_cost) { + best_cost = ef_cost; + } + + uint64_t rb_cost = CompactRankedBitvector::bitsize(params, universe, n) + type_bits; + if (rb_cost < best_cost) { + best_cost = rb_cost; + } + + return best_cost; + } + + template + static void write(bit_vector_builder& bvb, + Iterator begin, + uint64_t universe, + uint64_t n, + global_parameters const& params) + { + uint64_t best_cost = AllOnesSequence::bitsize(params, universe, n); + int best_type = all_ones; + + if (best_cost) { + uint64_t ef_cost = CompactEliasFano::bitsize(params, universe, n) + type_bits; + if (ef_cost < best_cost) { + best_cost = ef_cost; + best_type = elias_fano; + } + + uint64_t rb_cost = CompactRankedBitvector::bitsize(params, universe, n) + type_bits; + if (rb_cost < best_cost) { + best_cost = rb_cost; + best_type = ranked_bitvector; + } + + bvb.append_bits(best_type, type_bits); + } + + switch (best_type) { + case elias_fano: + CompactEliasFano::write(bvb, begin, universe, n, params); + break; + case ranked_bitvector: + CompactRankedBitvector::write(bvb, begin, universe, n, params); + break; + case all_ones: + AllOnesSequence::write(bvb, begin, universe, n, params); + break; + default: + assert(false); + } + } + + class enumerator { + public: + typedef std::pair value_type; // (position, value) + + enumerator() {} + + enumerator(BitVector const& bv, + uint64_t offset, + uint64_t universe, + uint64_t n, + global_parameters const& params) + : m_universe(universe) + { + if (AllOnesSequence::bitsize(params, universe, n) == 0) { + m_type = all_ones; + } else { + m_type = index_type(bv.get_word56(offset) & ((uint64_t(1) << type_bits) - 1)); + } + + switch (m_type) { + case elias_fano: + m_enumerator = + CompactEliasFano::enumerator(bv, offset + type_bits, universe, n, params); + break; + case ranked_bitvector: + m_enumerator = + CompactRankedBitvector::enumerator(bv, offset + type_bits, universe, n, params); + break; + case all_ones: + m_enumerator = + AllOnesSequence::enumerator(bv, 
offset + type_bits, universe, n, params); + break; + default: + throw std::invalid_argument("Unsupported type"); + } + } + + value_type move(uint64_t position) + { + return boost::apply_visitor([&position](auto&& e) { return e.move(position); }, + m_enumerator); + } + value_type next_geq(uint64_t lower_bound) + { + return boost::apply_visitor( + [&lower_bound](auto&& e) { return e.next_geq(lower_bound); }, m_enumerator); + } + value_type next() + { + return boost::apply_visitor([](auto&& e) { return e.next(); }, m_enumerator); + } + + uint64_t size() const + { + return boost::apply_visitor([](auto&& e) { return e.size(); }, m_enumerator); + } + + uint64_t prev_value() const + { + return boost::apply_visitor([](auto&& e) { return e.prev_value(); }, m_enumerator); + } + + [[nodiscard]] auto universe() const -> std::uint64_t { return m_universe; } + + private: + index_type m_type; + boost::variant + m_enumerator; + std::uint32_t m_universe; + }; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/sequence/partitioned_sequence.hpp b/include/pisa/v1/sequence/partitioned_sequence.hpp new file mode 100644 index 000000000..6dce91c36 --- /dev/null +++ b/include/pisa/v1/sequence/partitioned_sequence.hpp @@ -0,0 +1,403 @@ +#pragma once + +#include "tbb/task_group.h" +#include + +#include "codec/integer_codes.hpp" +#include "configuration.hpp" +#include "global_parameters.hpp" +#include "optimal_partition.hpp" +#include "util/util.hpp" +#include "v1/bit_vector.hpp" +#include "v1/sequence/indexed_sequence.hpp" + +namespace pisa::v1 { + +template +struct PartitionedSequence { + + using base_sequence_type = BaseSequence; + using base_sequence_enumerator = typename base_sequence_type::enumerator; + + template + static void write(bit_vector_builder& bvb, + Iterator begin, + std::uint64_t universe, + std::uint64_t n, + global_parameters const& params) + { + assert(n > 0); + auto partition = compute_partition(begin, universe, n, params); + + size_t partitions = 
partition.size(); + assert(partitions > 0); + assert(partition.front() != 0); + assert(partition.back() == n); + write_gamma_nonzero(bvb, partitions); + + std::vector cur_partition; + std::uint64_t cur_base = 0; + if (partitions == 1) { + cur_base = *begin; + Iterator it = begin; + + for (size_t i = 0; i < n; ++i, ++it) { + cur_partition.push_back(*it - cur_base); + } + + std::uint64_t universe_bits = ceil_log2(universe); + bvb.append_bits(cur_base, universe_bits); + + // write universe only if non-singleton and not tight + if (n > 1) { + if (cur_base + cur_partition.back() + 1 == universe) { + // tight universe + write_delta(bvb, 0); + } else { + write_delta(bvb, cur_partition.back()); + } + } + + base_sequence_type::write( + bvb, cur_partition.begin(), cur_partition.back() + 1, cur_partition.size(), params); + } else { + bit_vector_builder bv_sequences; + std::vector endpoints; + std::vector upper_bounds; + + std::uint64_t cur_i = 0; + Iterator it = begin; + cur_base = *begin; + upper_bounds.push_back(cur_base); + + for (size_t p = 0; p < partition.size(); ++p) { + cur_partition.clear(); + std::uint64_t value = 0; + for (; cur_i < partition[p]; ++cur_i, ++it) { + value = *it; + cur_partition.push_back(value - cur_base); + } + + std::uint64_t upper_bound = value; + assert(cur_partition.size() > 0); + base_sequence_type::write(bv_sequences, + cur_partition.begin(), + cur_partition.back() + 1, + cur_partition.size(), // XXX skip last one? 
+ params); + endpoints.push_back(bv_sequences.size()); + upper_bounds.push_back(upper_bound); + cur_base = upper_bound + 1; + } + + bit_vector_builder bv_sizes; + CompactEliasFano::write(bv_sizes, partition.begin(), n, partitions - 1, params); + + bit_vector_builder bv_upper_bounds; + CompactEliasFano::write( + bv_upper_bounds, upper_bounds.begin(), universe, partitions + 1, params); + + std::uint64_t endpoint_bits = ceil_log2(bv_sequences.size() + 1); + write_gamma(bvb, endpoint_bits); + + bvb.append(bv_sizes); + bvb.append(bv_upper_bounds); + + for (std::uint64_t p = 0; p < endpoints.size() - 1; ++p) { + bvb.append_bits(endpoints[p], endpoint_bits); + } + + bvb.append(bv_sequences); + } + } + + class enumerator { + public: + using value_type = std::pair; // (position, value) + + enumerator(BitVector const& bv, + std::uint64_t offset, + std::uint64_t universe, + std::uint64_t n, + global_parameters const& params) + : m_params(params), m_size(n), m_universe(universe), m_bv(&bv) + { + BitVector::enumerator it(bv, offset); + m_partitions = read_gamma_nonzero(it); + if (m_partitions == 1) { + m_cur_partition = 0; + m_cur_begin = 0; + m_cur_end = n; + + std::uint64_t universe_bits = ceil_log2(universe); + m_cur_base = it.take(universe_bits); + auto ub = 0; + if (n > 1) { + std::uint64_t universe_delta = read_delta(it); + ub = universe_delta > 0U ? 
universe_delta : (universe - m_cur_base - 1); + } + + m_partition_enum = + base_sequence_enumerator(*m_bv, it.position(), ub + 1, n, m_params); + + m_cur_upper_bound = m_cur_base + ub; + } else { + m_endpoint_bits = read_gamma(it); + + std::uint64_t cur_offset = it.position(); + m_sizes = + CompactEliasFano::enumerator(bv, cur_offset, n, m_partitions - 1, params); + cur_offset += CompactEliasFano::bitsize(params, n, m_partitions - 1); + + m_upper_bounds = CompactEliasFano::enumerator( + bv, cur_offset, universe, m_partitions + 1, params); + cur_offset += CompactEliasFano::bitsize(params, universe, m_partitions + 1); + + m_endpoints_offset = cur_offset; + std::uint64_t endpoints_size = m_endpoint_bits * (m_partitions - 1); + cur_offset += endpoints_size; + + m_sequences_offset = cur_offset; + } + + m_position = size(); + slow_move(); + } + + value_type PISA_ALWAYSINLINE move(std::uint64_t position) + { + assert(position <= size()); + m_position = position; + + if (m_position >= m_cur_begin && m_position < m_cur_end) { + std::uint64_t val = + m_cur_base + m_partition_enum.move(m_position - m_cur_begin).second; + return value_type(m_position, val); + } + + return slow_move(); + } + + // note: this is instantiated oly if BaseSequence has next_geq + template > + value_type PISA_ALWAYSINLINE next_geq(std::uint64_t lower_bound) + { + if (PISA_LIKELY(lower_bound >= m_cur_base && lower_bound <= m_cur_upper_bound)) { + auto val = m_partition_enum.next_geq(lower_bound - m_cur_base); + m_position = m_cur_begin + val.first; + return value_type(m_position, m_cur_base + val.second); + } + return slow_next_geq(lower_bound); + } + + value_type PISA_ALWAYSINLINE next() + { + ++m_position; + + if (PISA_LIKELY(m_position < m_cur_end)) { + std::uint64_t val = m_cur_base + m_partition_enum.next().second; + return value_type(m_position, val); + } + return slow_next(); + } + + [[nodiscard]] auto size() const -> std::uint64_t { return m_size; } + + [[nodiscard]] auto prev_value() const -> 
std::uint64_t + { + if (PISA_UNLIKELY(m_position == m_cur_begin)) { + return m_cur_partition ? m_cur_base - 1 : 0; + } + return m_cur_base + m_partition_enum.prev_value(); + } + + [[nodiscard]] auto num_partitions() const -> std::uint64_t { return m_partitions; } + [[nodiscard]] auto universe() const -> std::uint64_t { return m_universe; } + + friend class partitioned_sequence_test; + + private: + // the compiler does not seem smart enough to figure out that this + // is a very unlikely condition, and inlines the move(0) inside the + // next(), causing the code to grow. Since next is called in very + // tight loops, on microbenchmarks this causes an improvement of + // about 3ns on my i7 3Ghz + value_type PISA_NOINLINE slow_next() + { + if (PISA_UNLIKELY(m_position == m_size)) { + assert(m_cur_partition == m_partitions - 1); + auto val = m_partition_enum.next(); + assert(val.first == m_partition_enum.size()); + (void)val; + return value_type(m_position, m_universe); + } + + switch_partition(m_cur_partition + 1); + std::uint64_t val = m_cur_base + m_partition_enum.move(0).second; + return value_type(m_position, val); + } + + value_type PISA_NOINLINE slow_move() + { + if (m_position == size()) { + if (m_partitions > 1) { + switch_partition(m_partitions - 1); + } + m_partition_enum.move(m_partition_enum.size()); + return value_type(m_position, m_universe); + } + auto size_it = m_sizes.next_geq(m_position + 1); // need endpoint strictly > m_position + switch_partition(size_it.first); + std::uint64_t val = m_cur_base + m_partition_enum.move(m_position - m_cur_begin).second; + return value_type(m_position, val); + } + + value_type PISA_NOINLINE slow_next_geq(std::uint64_t lower_bound) + { + if (m_partitions == 1) { + if (lower_bound < m_cur_base) { + return move(0); + } + return move(size()); + } + + auto ub_it = m_upper_bounds.next_geq(lower_bound); + if (ub_it.first == 0) { + return move(0); + } + + if (ub_it.first == m_upper_bounds.size()) { + return move(size()); + } 
+ + switch_partition(ub_it.first - 1); + return next_geq(lower_bound); + } + + void switch_partition(std::uint64_t partition) + { + assert(m_partitions > 1); + + std::uint64_t endpoint = + partition > 0U + ? (m_bv->get_word56(m_endpoints_offset + (partition - 1) * m_endpoint_bits) + & ((std::uint64_t(1) << m_endpoint_bits) - 1)) + : 0; + + std::uint64_t partition_begin = m_sequences_offset + endpoint; + intrinsics::prefetch(std::next(m_bv->data(), partition_begin / 64)); + + m_cur_partition = partition; + auto size_it = m_sizes.move(partition); + m_cur_end = size_it.second; + m_cur_begin = m_sizes.prev_value(); + + auto ub_it = m_upper_bounds.move(partition + 1); + m_cur_upper_bound = ub_it.second; + m_cur_base = m_upper_bounds.prev_value() + (partition > 0 ? 1 : 0); + + m_partition_enum = base_sequence_enumerator(*m_bv, + partition_begin, + m_cur_upper_bound - m_cur_base + 1, + m_cur_end - m_cur_begin, + m_params); + } + + global_parameters m_params; + std::uint64_t m_partitions; + std::uint64_t m_endpoints_offset; + std::uint64_t m_endpoint_bits; + std::uint64_t m_sequences_offset; + std::uint64_t m_size; + std::uint64_t m_universe; + + std::uint64_t m_position; + std::uint64_t m_cur_partition; + std::uint64_t m_cur_begin; + std::uint64_t m_cur_end; + std::uint64_t m_cur_base; + std::uint64_t m_cur_upper_bound; + + BitVector const* m_bv; + CompactEliasFano::enumerator m_sizes; + CompactEliasFano::enumerator m_upper_bounds; + base_sequence_enumerator m_partition_enum; + }; + + private: + template + static std::vector compute_partition(Iterator begin, + std::uint64_t universe, + std::uint64_t n, + global_parameters const& params) + { + std::vector partition; + + auto const& conf = configuration::get(); + + if (base_sequence_type::bitsize(params, universe, n) < 2 * conf.fix_cost) { + partition.push_back(n); + return partition; + } + + auto cost_fun = [&](std::uint64_t universe, std::uint64_t n) { + return base_sequence_type::bitsize(params, universe, n) + 
conf.fix_cost; + }; + + const size_t superblock_bound = conf.eps3 != 0 ? size_t(conf.fix_cost / conf.eps3) : n; + + std::deque> superblock_partitions; + tbb::task_group tg; + + size_t superblock_pos = 0; + auto superblock_begin = begin; + auto superblock_base = *begin; + + while (superblock_pos < n) { + size_t superblock_size = std::min(superblock_bound, n - superblock_pos); + // If the remainder is smaller than the bound (possibly + // empty), merge it to the current (now last) superblock. + if (n - (superblock_pos + superblock_size) < superblock_bound) { + superblock_size = n - superblock_pos; + } + auto superblock_last = std::next(superblock_begin, superblock_size - 1); + auto superblock_end = std::next(superblock_last); + + // If this is the last superblock, its universe is the + // list universe. + size_t superblock_universe = + superblock_pos + superblock_size == n ? universe : *superblock_last + 1; + + superblock_partitions.emplace_back(); + auto& superblock_partition = superblock_partitions.back(); + + tg.run([=, &cost_fun, &conf, &superblock_partition] { + optimal_partition opt(superblock_begin, + superblock_base, + superblock_universe, + superblock_size, + cost_fun, + conf.eps1, + conf.eps2); + + superblock_partition.reserve(opt.partition.size()); + for (auto& endpoint : opt.partition) { + superblock_partition.push_back(superblock_pos + endpoint); + } + }); + + superblock_pos += superblock_size; + superblock_begin = superblock_end; + superblock_base = superblock_universe; + } + tg.wait(); + + for (const auto& superblock_partition : superblock_partitions) { + partition.insert( + partition.end(), superblock_partition.begin(), superblock_partition.end()); + } + + return partition; + } +}; +} // namespace pisa::v1 diff --git a/include/pisa/v1/sequence/positive_sequence.hpp b/include/pisa/v1/sequence/positive_sequence.hpp new file mode 100644 index 000000000..c2cc3ddd2 --- /dev/null +++ b/include/pisa/v1/sequence/positive_sequence.hpp @@ -0,0 +1,311 @@ 
+#pragma once + +#include + +#include + +#include "global_parameters.hpp" +#include "util/util.hpp" +#include "v1/bit_vector.hpp" +#include "v1/sequence/indexed_sequence.hpp" + +namespace pisa::v1 { + +struct StrictEliasFano { + + static PISA_FLATTEN_FUNC uint64_t bitsize(global_parameters const& params, + uint64_t universe, + uint64_t n) + { + assert(universe >= n); + return CompactEliasFano::bitsize(params, universe - n + 1, n); + } + + template + static void write(bit_vector_builder& bvb, + Iterator begin, + uint64_t universe, + uint64_t n, + global_parameters const& params) + { + uint64_t new_universe = universe - n + 1; + typedef typename std::iterator_traits::value_type value_type; + auto new_begin = make_function_iterator( + std::make_pair(value_type(0), begin), + [](std::pair& state) { + ++state.first; + ++state.second; + }, + [](std::pair const& state) { + return *state.second - state.first; + }); + CompactEliasFano::write(bvb, new_begin, new_universe, n, params); + } + + class enumerator { + public: + typedef std::pair value_type; // (position, value) + + enumerator() {} + + enumerator(BitVector const& bv, + uint64_t offset, + uint64_t universe, + uint64_t n, + global_parameters const& params) + : m_ef_enum(bv, offset, universe - n + 1, n, params) + { + } + + value_type move(uint64_t position) + { + auto val = m_ef_enum.move(position); + return value_type(val.first, val.second + val.first); + } + + value_type next() + { + auto val = m_ef_enum.next(); + return value_type(val.first, val.second + val.first); + } + + uint64_t size() const { return m_ef_enum.size(); } + + uint64_t prev_value() const + { + if (m_ef_enum.position()) { + return m_ef_enum.prev_value() + m_ef_enum.position() - 1; + } else { + return 0; + } + } + + private: + CompactEliasFano::enumerator m_ef_enum; + }; +}; + +struct StrictSequence { + + enum index_type { + elias_fano = 0, + ranked_bitvector = 1, + all_ones = 2, + + index_types = 3 + }; + + static const uint64_t type_bits = 1; // 
all_ones is implicit + + static global_parameters strict_params(global_parameters params) + { + // we do not need to index the zeros + params.ef_log_sampling0 = 63; + params.rb_log_rank1_sampling = 63; + return params; + } + + static PISA_FLATTEN_FUNC uint64_t bitsize(global_parameters const& params, + uint64_t universe, + uint64_t n) + { + uint64_t best_cost = AllOnesSequence::bitsize(params, universe, n); + auto sparams = strict_params(params); + + uint64_t ef_cost = StrictEliasFano::bitsize(sparams, universe, n) + type_bits; + if (ef_cost < best_cost) { + best_cost = ef_cost; + } + + uint64_t rb_cost = CompactRankedBitvector::bitsize(sparams, universe, n) + type_bits; + if (rb_cost < best_cost) { + best_cost = rb_cost; + } + + return best_cost; + } + + template + static void write(bit_vector_builder& bvb, + Iterator begin, + uint64_t universe, + uint64_t n, + global_parameters const& params) + { + auto sparams = strict_params(params); + uint64_t best_cost = AllOnesSequence::bitsize(params, universe, n); + int best_type = all_ones; + + if (best_cost) { + uint64_t ef_cost = StrictEliasFano::bitsize(sparams, universe, n) + type_bits; + if (ef_cost < best_cost) { + best_cost = ef_cost; + best_type = elias_fano; + } + + uint64_t rb_cost = CompactRankedBitvector::bitsize(sparams, universe, n) + type_bits; + if (rb_cost < best_cost) { + best_cost = rb_cost; + best_type = ranked_bitvector; + } + + bvb.append_bits(best_type, type_bits); + } + + switch (best_type) { + case elias_fano: + StrictEliasFano::write(bvb, begin, universe, n, sparams); + break; + case ranked_bitvector: + CompactRankedBitvector::write(bvb, begin, universe, n, sparams); + break; + case all_ones: + AllOnesSequence::write(bvb, begin, universe, n, sparams); + break; + default: + assert(false); + } + } + + class enumerator { + public: + typedef std::pair value_type; // (position, value) + + enumerator() {} + + enumerator(BitVector const& bv, + uint64_t offset, + uint64_t universe, + uint64_t n, + 
global_parameters const& params) + { + + auto sparams = strict_params(params); + + if (AllOnesSequence::bitsize(params, universe, n) == 0) { + m_type = all_ones; + } else { + m_type = index_type(bv.get_word56(offset) & ((uint64_t(1) << type_bits) - 1)); + } + + switch (m_type) { + case elias_fano: + m_enumerator = + StrictEliasFano::enumerator(bv, offset + type_bits, universe, n, sparams); + break; + case ranked_bitvector: + m_enumerator = CompactRankedBitvector::enumerator( + bv, offset + type_bits, universe, n, sparams); + break; + case all_ones: + m_enumerator = + AllOnesSequence::enumerator(bv, offset + type_bits, universe, n, sparams); + break; + default: + throw std::invalid_argument("Unsupported type"); + } + } + + value_type move(uint64_t position) + { + return boost::apply_visitor([&position](auto&& e) { return e.move(position); }, + m_enumerator); + } + + value_type next() + { + return boost::apply_visitor([](auto&& e) { return e.next(); }, m_enumerator); + } + + uint64_t size() const + { + return boost::apply_visitor([](auto&& e) { return e.size(); }, m_enumerator); + } + + uint64_t prev_value() const + { + return boost::apply_visitor([](auto&& e) { return e.prev_value(); }, m_enumerator); + } + + private: + index_type m_type; + boost::variant + m_enumerator; + }; +}; + +template +struct PositiveSequence { + + typedef BaseSequence base_sequence_type; + typedef typename base_sequence_type::enumerator base_sequence_enumerator; + + template + static void write(bit_vector_builder& bvb, + Iterator begin, + uint64_t universe, + uint64_t n, + global_parameters const& params) + { + assert(n > 0); + auto cumulative_begin = make_function_iterator( + std::make_pair(uint64_t(0), begin), + [](std::pair& state) { state.first += *state.second++; }, + [](std::pair const& state) { return state.first + *state.second; }); + base_sequence_type::write(bvb, cumulative_begin, universe, n, params); + } + + class enumerator { + public: + typedef std::pair value_type; // 
(position, value) + + enumerator() = delete; + + enumerator(BitVector const& bv, + uint64_t offset, + uint64_t universe, + uint64_t n, + global_parameters const& params) + : m_base_enum(bv, offset, universe, n, params), + m_position(m_base_enum.size()), + m_universe(universe) + { + } + + value_type next() { return move(m_position + 1); } + auto size() const { return m_base_enum.size(); } + auto universe() const { return m_universe; } + + value_type move(uint64_t position) + { + // we cache m_position and m_cur to avoid the call overhead in + // the most common cases + uint64_t prev = m_cur; + if (position != m_position + 1) { + if (PISA_UNLIKELY(position == 0)) { + // we need to special-case position 0 + m_cur = m_base_enum.move(0).second; + m_position = 0; + return value_type(m_position, m_cur); + } + prev = m_base_enum.move(position - 1).second; + } + + m_cur = m_base_enum.next().second; + m_position = position; + return value_type(position, m_cur - prev); + } + + base_sequence_enumerator const& base() const { return m_base_enum; } + + private: + base_sequence_enumerator m_base_enum; + uint64_t m_position; + uint64_t m_cur{}; + uint64_t m_universe{0}; + }; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/sequence_cursor.hpp b/include/pisa/v1/sequence_cursor.hpp deleted file mode 100644 index bb8f12d70..000000000 --- a/include/pisa/v1/sequence_cursor.hpp +++ /dev/null @@ -1,326 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include "codec/block_codecs.hpp" -#include "util/likely.hpp" -#include "v1/bit_cast.hpp" -#include "v1/cursor_traits.hpp" -#include "v1/encoding_traits.hpp" -#include "v1/types.hpp" -#include "v1/unaligned_span.hpp" - -namespace pisa::v1 { - -/// Uncompressed example of implementation of a single value cursor. -template -struct SequenceCursor { - using value_type = std::uint32_t; - - /// Creates a cursor from the encoded bytes. 
- explicit constexpr SequenceCursor(gsl::span encoded_blocks, - UnalignedSpan block_endpoints, - UnalignedSpan block_last_values, - std::uint32_t length, - std::uint32_t num_blocks) - : m_encoded_blocks(encoded_blocks), - m_block_endpoints(block_endpoints), - m_block_last_values(block_last_values), - m_length(length), - m_num_blocks(num_blocks), - m_current_block( - {.number = 0, - .offset = 0, - .length = std::min(length, static_cast(Codec::block_size)), - .last_value = m_block_last_values[0]}) - { - static_assert(DeltaEncoded, - "Cannot initialize block_last_values for not delta-encoded list"); - m_decoded_block.resize(Codec::block_size); - reset(); - } - - /// Creates a cursor from the encoded bytes. - explicit constexpr SequenceCursor(gsl::span encoded_blocks, - UnalignedSpan block_endpoints, - std::uint32_t length, - std::uint32_t num_blocks) - : m_encoded_blocks(encoded_blocks), - m_block_endpoints(block_endpoints), - m_length(length), - m_num_blocks(num_blocks), - m_current_block( - {.number = 0, - .offset = 0, - .length = std::min(length, static_cast(Codec::block_size)), - .last_value = 0}) - { - static_assert(not DeltaEncoded, "Must initialize block_last_values for delta-encoded list"); - m_decoded_block.resize(Codec::block_size); - reset(); - } - - constexpr SequenceCursor(SequenceCursor const&) = default; - constexpr SequenceCursor(SequenceCursor&&) noexcept = default; - constexpr SequenceCursor& operator=(SequenceCursor const&) = default; - constexpr SequenceCursor& operator=(SequenceCursor&&) noexcept = default; - ~SequenceCursor() = default; - - void reset() { decode_and_update_block(0); } - - /// Dereferences the current value. - [[nodiscard]] constexpr auto operator*() const -> value_type { return m_current_value; } - - /// Alias for `operator*()`. - [[nodiscard]] constexpr auto value() const noexcept -> value_type { return *(*this); } - - /// Advances the cursor to the next position. 
- constexpr void advance() - { - m_current_block.offset += 1; - if (PISA_UNLIKELY(m_current_block.offset == m_current_block.length)) { - if (m_current_block.number + 1 == m_num_blocks) { - m_current_value = sentinel(); - return; - } - decode_and_update_block(m_current_block.number + 1); - } else { - if constexpr (DeltaEncoded) { - m_current_value += m_decoded_block[m_current_block.offset] + 1U; - } else { - m_current_value = m_decoded_block[m_current_block.offset] + 1U; - } - } - } - - /// Moves the cursor to the position `pos`. - constexpr void advance_to_position(std::uint32_t pos) - { - Expects(pos >= position()); - auto block = pos / Codec::block_size; - if (PISA_UNLIKELY(block != m_current_block.number)) { - decode_and_update_block(block); - } - while (position() < pos) { - if constexpr (DeltaEncoded) { - m_current_value += m_decoded_block[++m_current_block.offset] + 1U; - } else { - m_current_value = m_decoded_block[++m_current_block.offset] + 1U; - } - } - } - - /// Moves the cursor to the next value equal or greater than `value`. - constexpr void advance_to_geq(value_type value) - { - // static_assert(DeltaEncoded, "Cannot call advance_to_geq on a not delta-encoded list"); - // TODO(michal): This should be `static_assert` like above. But currently, - // it would not compile. What needs to be done is separating document - // and payload readers for the index runner. - assert(DeltaEncoded); - if (PISA_UNLIKELY(value > m_current_block.last_value)) { - if (value > m_block_last_values.back()) { - m_current_value = sentinel(); - return; - } - auto block = m_current_block.number + 1U; - while (m_block_last_values[block] < value) { - ++block; - } - decode_and_update_block(block); - } - - while (m_current_value < value) { - m_current_value += m_decoded_block[++m_current_block.offset] + 1U; - Ensures(m_current_block.offset < m_current_block.length); - } - } - - ///// Returns `true` if there is no elements left. 
- [[nodiscard]] constexpr auto empty() const noexcept -> bool { return position() == m_length; } - - /// Returns the current position. - [[nodiscard]] constexpr auto position() const noexcept -> std::size_t - { - return m_current_block.number * Codec::block_size + m_current_block.offset; - } - - ///// Returns the number of elements in the list. - [[nodiscard]] constexpr auto size() const -> std::size_t { return m_length; } - - /// The sentinel value, such that `value() != nullopt` is equivalent to `*(*this) < sentinel()`. - [[nodiscard]] constexpr auto sentinel() const -> value_type - { - return std::numeric_limits::max(); - } - - private: - - //gsl::span m_encoded_blocks; - //UnalignedSpan m_block_endpoints; - //UnalignedSpan m_block_last_values{}; - //std::vector m_decoded_block; - - //std::uint32_t m_length; - //std::uint32_t m_num_blocks; - //Block m_current_block{}; - //value_type m_current_value{}; -}; - -// template -// constexpr auto block_encoding_type() -> std::uint32_t -//{ -// if constexpr (DeltaEncoded) { -// return EncodingId::BlockDelta; -// } else { -// return EncodingId::Block; -// } -//} - -template -struct SequenceReader { - using value_type = std::uint32_t; - - [[nodiscard]] auto read(gsl::span bytes) const -> SequenceCursor - { - // std::uint32_t length; - // auto begin = reinterpret_cast(bytes.data()); - // auto after_length_ptr = pisa::TightVariableByte::decode(begin, &length, 1); - // auto length_byte_size = std::distance(begin, after_length_ptr); - // auto num_blocks = ceil_div(length, Codec::block_size); - // UnalignedSpan block_last_values; - // if constexpr (DeltaEncoded) { - // block_last_values = UnalignedSpan( - // bytes.subspan(length_byte_size, num_blocks * sizeof(value_type))); - //} - // UnalignedSpan block_endpoints( - // bytes.subspan(length_byte_size + block_last_values.byte_size(), - // (num_blocks - 1) * sizeof(value_type))); - // auto encoded_blocks = bytes.subspan(length_byte_size + block_last_values.byte_size() - // + 
block_endpoints.byte_size()); - // if constexpr (DeltaEncoded) { - // return SequenceCursor( - // encoded_blocks, block_endpoints, block_last_values, length, num_blocks); - //} else { - // return SequenceCursor( - // encoded_blocks, block_endpoints, length, num_blocks); - //} - } - - constexpr static auto encoding() -> std::uint32_t - { - return EncodingId::Sequence | encoding_traits::encoding_tag::encoding(); - } -}; - -template -struct SequenceWriter { - using value_type = std::uint32_t; - - constexpr static auto encoding() -> std::uint32_t - { - return EncodingId::Sequence | encoding_traits::encoding_tag::encoding(); - } - - void push(value_type const& posting) - { - // if constexpr (DeltaEncoded) { - // if (posting < m_last_value) { - // throw std::runtime_error( - // fmt::format("Delta-encoded sequences must be monotonic, but {} < {}", - // posting, - // m_last_value)); - // } - //} - // m_postings.push_back(posting); - // m_last_value = posting; - } - - template - [[nodiscard]] auto write(std::basic_ostream& os) const -> std::size_t - { - // std::vector buffer; - // std::uint32_t length = m_postings.size(); - // TightVariableByte::encode_single(length, buffer); - // auto block_size = Codec::block_size; - // auto num_blocks = ceil_div(length, block_size); - // auto begin_block_maxs = buffer.size(); - // auto begin_block_endpoints = [&]() { - // if constexpr (DeltaEncoded) { - // return begin_block_maxs + 4U * num_blocks; - // } else { - // return begin_block_maxs; - // } - //}(); - // auto begin_blocks = begin_block_endpoints + 4U * (num_blocks - 1); - // buffer.resize(begin_blocks); - - // auto iter = m_postings.begin(); - // std::vector block_buffer(block_size); - // std::uint32_t last_value(-1); - // std::uint32_t block_base = 0; - // for (auto block = 0; block < num_blocks; ++block) { - // auto current_block_size = - // ((block + 1) * block_size <= length) ? 
block_size : (length % block_size); - - // std::for_each(block_buffer.begin(), - // std::next(block_buffer.begin(), current_block_size), - // [&](auto&& elem) { - // if constexpr (DeltaEncoded) { - // auto value = *iter++; - // elem = value - (last_value + 1); - // last_value = value; - // } else { - // elem = *iter++ - 1; - // } - // }); - - // if constexpr (DeltaEncoded) { - // std::memcpy( - // &buffer[begin_block_maxs + 4U * block], &last_value, sizeof(last_value)); - // auto size = buffer.size(); - // Codec::encode(block_buffer.data(), - // last_value - block_base - (current_block_size - 1), - // current_block_size, - // buffer); - // } else { - // Codec::encode(block_buffer.data(), std::uint32_t(-1), current_block_size, buffer); - // } - // if (block != num_blocks - 1) { - // std::size_t endpoint = buffer.size() - begin_blocks; - // std::memcpy( - // &buffer[begin_block_endpoints + 4U * block], &endpoint, sizeof(last_value)); - // } - // block_base = last_value + 1; - //} - // os.write(reinterpret_cast(buffer.data()), buffer.size()); - // return buffer.size(); - } - - void reset() - { - m_postings.clear(); - m_last_value = 0; - } - - private: - std::vector m_postings{}; - value_type m_last_value = 0U; -}; - -template -struct CursorTraits> { - using Writer = SequenceWriter; - using Reader = SequenceReader; -}; - -} // namespace pisa::v1 diff --git a/include/pisa/v1/types.hpp b/include/pisa/v1/types.hpp index 4b0881626..7e81b87b5 100644 --- a/include/pisa/v1/types.hpp +++ b/include/pisa/v1/types.hpp @@ -6,6 +6,8 @@ #include +#include "binary_freq_collection.hpp" + #define Unreachable() std::abort(); namespace pisa::v1 { @@ -21,9 +23,11 @@ enum EncodingId { Raw = 0xDA43, BlockDelta = 0xEF00, Block = 0xFF00, - Sequence = 0xDF00, + BitSequence = 0xDF00, SimdBP = 0x0001, - Varbyte = 0x0002 + Varbyte = 0x0002, + PEF = 0x0003, + PositiveSeq = 0x0004 }; template @@ -34,25 +38,70 @@ struct overloaded : Ts... 
{ template overloaded(Ts...)->overloaded; +struct BaseIndex; + template struct Reader { using Value = std::decay_t())>; - template - explicit constexpr Reader(ReaderImpl &&reader) + template + explicit constexpr Reader(R reader) : m_internal_reader(std::make_unique>(reader)) { - m_read = [reader = std::forward(reader)](gsl::span bytes) { - return reader.read(bytes); - }; } - + Reader() = default; + Reader(Reader const& other) : m_internal_reader(other.m_internal_reader->clone()) {} + Reader(Reader&& other) noexcept = default; + Reader& operator=(Reader const& other) = delete; + Reader& operator=(Reader&& other) noexcept = default; + ~Reader() = default; + + void init(BaseIndex const& index) { m_internal_reader->init(index); } [[nodiscard]] auto read(gsl::span bytes) const -> Cursor { - return m_read(bytes); + return m_internal_reader->read(bytes); } + [[nodiscard]] auto encoding() const -> std::uint32_t { return m_internal_reader->encoding(); } + + struct ReaderInterface { + ReaderInterface() = default; + ReaderInterface(ReaderInterface const&) = default; + ReaderInterface(ReaderInterface&&) noexcept = default; + ReaderInterface& operator=(ReaderInterface const&) = default; + ReaderInterface& operator=(ReaderInterface&&) noexcept = default; + virtual ~ReaderInterface() = default; + virtual void init(BaseIndex const& index) = 0; + [[nodiscard]] virtual auto read(gsl::span bytes) const -> Cursor = 0; + [[nodiscard]] virtual auto encoding() const -> std::uint32_t = 0; + [[nodiscard]] virtual auto clone() const -> std::unique_ptr = 0; + }; + + template + struct ReaderImpl : ReaderInterface { + explicit ReaderImpl(R reader) : m_reader(std::move(reader)) {} + ReaderImpl() = default; + ReaderImpl(ReaderImpl const&) = default; + ReaderImpl(ReaderImpl&&) noexcept = default; + ReaderImpl& operator=(ReaderImpl const&) = default; + ReaderImpl& operator=(ReaderImpl&&) noexcept = default; + ~ReaderImpl() = default; + void init(BaseIndex const& index) override { 
m_reader.init(index); } + [[nodiscard]] auto read(gsl::span bytes) const -> Cursor override + { + return m_reader.read(bytes); + } + [[nodiscard]] auto encoding() const -> std::uint32_t override { return R::encoding(); } + [[nodiscard]] auto clone() const -> std::unique_ptr override + { + auto copy = *this; + return std::make_unique>(std::move(copy)); + } + + private: + R m_reader; + }; private: - std::function)> m_read; + std::unique_ptr m_internal_reader; }; template @@ -64,30 +113,35 @@ struct Writer { { } Writer() = default; - Writer(Writer const &other) : m_internal_writer(other.m_internal_writer->clone()) {} - Writer(Writer &&other) noexcept = default; - Writer &operator=(Writer const &other) = delete; - Writer &operator=(Writer &&other) noexcept = default; + Writer(Writer const& other) : m_internal_writer(other.m_internal_writer->clone()) {} + Writer(Writer&& other) noexcept = default; + Writer& operator=(Writer const& other) = delete; + Writer& operator=(Writer&& other) noexcept = default; ~Writer() = default; - void push(T const &posting) { m_internal_writer->push(posting); } - void push(T &&posting) { m_internal_writer->push(posting); } - auto write(ByteOStream &os) const -> std::size_t { return m_internal_writer->write(os); } - auto write(std::ostream &os) const -> std::size_t { return m_internal_writer->write(os); } + void init(pisa::binary_freq_collection const& collection) + { + m_internal_writer->init(collection); + } + void push(T const& posting) { m_internal_writer->push(posting); } + void push(T&& posting) { m_internal_writer->push(posting); } + auto write(ByteOStream& os) const -> std::size_t { return m_internal_writer->write(os); } + auto write(std::ostream& os) const -> std::size_t { return m_internal_writer->write(os); } [[nodiscard]] auto encoding() const -> std::uint32_t { return m_internal_writer->encoding(); } void reset() { return m_internal_writer->reset(); } struct WriterInterface { WriterInterface() = default; - 
WriterInterface(WriterInterface const &) = default; - WriterInterface(WriterInterface &&) noexcept = default; - WriterInterface &operator=(WriterInterface const &) = default; - WriterInterface &operator=(WriterInterface &&) noexcept = default; + WriterInterface(WriterInterface const&) = default; + WriterInterface(WriterInterface&&) noexcept = default; + WriterInterface& operator=(WriterInterface const&) = default; + WriterInterface& operator=(WriterInterface&&) noexcept = default; virtual ~WriterInterface() = default; - virtual void push(T const &posting) = 0; - virtual void push(T &&posting) = 0; - virtual auto write(ByteOStream &os) const -> std::size_t = 0; - virtual auto write(std::ostream &os) const -> std::size_t = 0; + virtual void init(pisa::binary_freq_collection const& collection) = 0; + virtual void push(T const& posting) = 0; + virtual void push(T&& posting) = 0; + virtual auto write(ByteOStream& os) const -> std::size_t = 0; + virtual auto write(std::ostream& os) const -> std::size_t = 0; virtual void reset() = 0; [[nodiscard]] virtual auto encoding() const -> std::uint32_t = 0; [[nodiscard]] virtual auto clone() const -> std::unique_ptr = 0; @@ -97,15 +151,19 @@ struct Writer { struct WriterImpl : WriterInterface { explicit WriterImpl(W writer) : m_writer(std::move(writer)) {} WriterImpl() = default; - WriterImpl(WriterImpl const &) = default; - WriterImpl(WriterImpl &&) noexcept = default; - WriterImpl &operator=(WriterImpl const &) = default; - WriterImpl &operator=(WriterImpl &&) noexcept = default; + WriterImpl(WriterImpl const&) = default; + WriterImpl(WriterImpl&&) noexcept = default; + WriterImpl& operator=(WriterImpl const&) = default; + WriterImpl& operator=(WriterImpl&&) noexcept = default; ~WriterImpl() = default; - void push(T const &posting) override { m_writer.push(posting); } - void push(T &&posting) override { m_writer.push(posting); } - auto write(ByteOStream &os) const -> std::size_t override { return m_writer.write(os); } - auto 
write(std::ostream &os) const -> std::size_t override { return m_writer.write(os); } + void init(pisa::binary_freq_collection const& collection) override + { + m_writer.init(collection); + } + void push(T const& posting) override { m_writer.push(posting); } + void push(T&& posting) override { m_writer.push(posting); } + auto write(ByteOStream& os) const -> std::size_t override { return m_writer.write(os); } + auto write(std::ostream& os) const -> std::size_t override { return m_writer.write(os); } void reset() override { return m_writer.reset(); } [[nodiscard]] auto encoding() const -> std::uint32_t override { return W::encoding(); } [[nodiscard]] auto clone() const -> std::unique_ptr override @@ -123,7 +181,7 @@ struct Writer { }; template -[[nodiscard]] inline auto make_writer(W &&writer) +[[nodiscard]] inline auto make_writer(W&& writer) { return Writer(std::forward(writer)); } diff --git a/include/pisa/v1/union_lookup.hpp b/include/pisa/v1/union_lookup.hpp index aa18c1bbc..5d0bd841c 100644 --- a/include/pisa/v1/union_lookup.hpp +++ b/include/pisa/v1/union_lookup.hpp @@ -97,7 +97,11 @@ struct UnionLookupJoin { m_inspect(inspect) { if (m_essential_cursors.empty()) { - m_sentinel = std::numeric_limits::max(); + if (m_lookup_cursors.empty()) { + m_sentinel = std::numeric_limits::max(); + } else { + m_sentinel = min_sentinel(m_lookup_cursors); + } m_current_value = m_sentinel; m_current_payload = m_init; return; diff --git a/include/pisa/v1/wand.hpp b/include/pisa/v1/wand.hpp index 603633b0d..533b8ff28 100644 --- a/include/pisa/v1/wand.hpp +++ b/include/pisa/v1/wand.hpp @@ -84,7 +84,6 @@ struct WandJoin { return lhs->value() < rhs->value(); }); - m_next_docid = min_value(m_cursors); m_sentinel = min_sentinel(m_cursors); advance(); } @@ -102,45 +101,67 @@ struct WandJoin { PISA_ALWAYSINLINE void advance() { - while (true) { - auto pivot = find_pivot(); - if (pivot == m_cursor_pointers.end()) { + bool exit = false; + while (not exit) { + auto upper_bound = 0.0F; + 
std::size_t pivot; + bool found_pivot = false; + for (pivot = 0; pivot < m_cursor_pointers.size(); ++pivot) { + if (m_cursor_pointers[pivot]->empty()) { + break; + } + upper_bound += m_cursor_pointers[pivot]->max_score(); + if (m_above_threshold(upper_bound)) { + found_pivot = true; + break; + } + } + // auto pivot = find_pivot(); + // if (PISA_UNLIKELY(not pivot)) { + // m_current_value = sentinel(); + // exit = true; + // break; + //} + if (not found_pivot) { m_current_value = sentinel(); - return; + exit = true; + break; } - auto pivot_docid = (*pivot)->value(); + // auto pivot_docid = (*pivot)->value(); + auto pivot_docid = m_cursor_pointers[pivot]->value(); if (pivot_docid == m_cursor_pointers.front()->value()) { m_current_value = pivot_docid; m_current_payload = m_init; - [&]() { - auto iter = m_cursor_pointers.begin(); - for (; iter != m_cursor_pointers.end(); ++iter) { - auto* cursor = *iter; - if (cursor->value() != pivot_docid) { - break; - } - m_current_payload = m_accumulate(m_current_payload, *cursor); - cursor->advance(); + for (auto* cursor : m_cursor_pointers) { + if (cursor->value() != pivot_docid) { + break; } - return iter; - }(); + m_current_payload = m_accumulate(m_current_payload, *cursor); + cursor->advance(); + } auto by_docid = [](auto&& lhs, auto&& rhs) { return lhs->value() < rhs->value(); }; std::sort(m_cursor_pointers.begin(), m_cursor_pointers.end(), by_docid); - return; - } - - auto next_list = std::distance(m_cursor_pointers.begin(), pivot); - for (; m_cursor_pointers[next_list]->value() == pivot_docid; --next_list) { + exit = true; + } else { + auto next_list = pivot; + for (; m_cursor_pointers[next_list]->value() == pivot_docid; --next_list) { + } + m_cursor_pointers[next_list]->advance_to_geq(pivot_docid); + // bubble_down(next_list); + for (size_t idx = next_list + 1; idx < m_cursor_pointers.size(); idx += 1) { + if (m_cursor_pointers[idx]->value() < m_cursor_pointers[idx - 1]->value()) { + std::swap(m_cursor_pointers[idx], 
m_cursor_pointers[idx - 1]); + } else { + break; + } + } } - m_cursor_pointers[next_list]->advance_to_geq(pivot_docid); - bubble_down(next_list); } } - [[nodiscard]] constexpr auto position() const noexcept -> std::size_t; // TODO(michal) [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_current_value >= sentinel(); @@ -158,27 +179,39 @@ struct WandJoin { } } - PISA_ALWAYSINLINE auto find_pivot() + PISA_ALWAYSINLINE auto find_pivot() -> tl::optional { auto upper_bound = 0.0F; - for (auto pivot = m_cursor_pointers.begin(); pivot != m_cursor_pointers.end(); ++pivot) { - auto&& cursor = **pivot; - if (cursor.empty()) { + std::size_t pivot; + for (pivot = 0; pivot < m_cursor_pointers.size(); ++pivot) { + if (m_cursor_pointers[pivot]->empty()) { break; } - upper_bound += cursor.max_score(); + upper_bound += m_cursor_pointers[pivot]->max_score(); if (m_above_threshold(upper_bound)) { - auto pivot_docid = (*pivot)->value(); - while (std::next(pivot) != m_cursor_pointers.end()) { - if ((*std::next(pivot))->value() != pivot_docid) { - break; - } - pivot = std::next(pivot); - } - return pivot; + return tl::make_optional(pivot); } } - return m_cursor_pointers.end(); + return tl::nullopt; + // auto upper_bound = 0.0F; + // for (auto pivot = m_cursor_pointers.begin(); pivot != m_cursor_pointers.end(); ++pivot) { + // auto&& cursor = **pivot; + // if (cursor.empty()) { + // break; + // } + // upper_bound += cursor.max_score(); + // if (m_above_threshold(upper_bound)) { + // auto pivot_docid = (*pivot)->value(); + // while (std::next(pivot) != m_cursor_pointers.end()) { + // if ((*std::next(pivot))->value() != pivot_docid) { + // break; + // } + // pivot = std::next(pivot); + // } + // return pivot; + // } + //} + // return m_cursor_pointers.end(); } CursorContainer m_cursors; @@ -191,7 +224,6 @@ struct WandJoin { value_type m_current_value{}; value_type m_sentinel{}; payload_type m_current_payload{}; - std::uint32_t m_next_docid{}; payload_type 
m_previous_threshold{}; Inspect* m_inspect; @@ -248,76 +280,78 @@ struct BlockMaxWandJoin { } [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_wand_join.empty(); } - constexpr void advance() - { - while (true) { - auto pivot = m_wand_join.find_pivot(); - if (pivot == m_wand_join.m_cursor_pointers.end()) { - m_wand_join.m_current_value = m_wand_join.sentinel(); - return; - } - - auto pivot_docid = (*pivot)->value(); - // auto block_upper_bound = std::accumulate( - // m_wand_join.m_cursor_pointers.begin(), - // std::next(pivot), - // 0.0F, - // [&](auto acc, auto* cursor) { return acc + cursor->block_max_score(pivot_docid); - // }); - // if (not m_wand_join.m_above_threshold(block_upper_bound)) { - // block_max_advance(pivot, pivot_docid); - // continue; - //} - if (pivot_docid == m_wand_join.m_cursor_pointers.front()->value()) { - m_wand_join.m_current_value = pivot_docid; - m_wand_join.m_current_payload = m_wand_join.m_init; - - [&]() { - auto iter = m_wand_join.m_cursor_pointers.begin(); - for (; iter != m_wand_join.m_cursor_pointers.end(); ++iter) { - auto* cursor = *iter; - if (cursor->value() != pivot_docid) { - break; - } - m_wand_join.m_current_payload = - m_wand_join.m_accumulate(m_wand_join.m_current_payload, *cursor); - cursor->advance(); - } - return iter; - }(); - // for (auto* cursor : m_wand_join.m_cursor_pointers) { - // if (cursor->value() != pivot_docid) { - // break; - // } - // m_wand_join.m_current_payload = - // m_wand_join.m_accumulate(m_wand_join.m_current_payload, *cursor); - // block_upper_bound -= cursor->block_max_score() - cursor->payload(); - // if (not m_wand_join.m_above_threshold(block_upper_bound)) { - // break; - // } - //} - - // for (auto* cursor : m_wand_join.m_cursor_pointers) { - // if (cursor->value() != pivot_docid) { - // break; - // } - // cursor->advance(); - //} - - auto by_docid = [](auto&& lhs, auto&& rhs) { return lhs->value() < rhs->value(); }; - std::sort(m_wand_join.m_cursor_pointers.begin(), - 
m_wand_join.m_cursor_pointers.end(), - by_docid); - return; - } - - auto next_list = std::distance(m_wand_join.m_cursor_pointers.begin(), pivot); - for (; m_wand_join.m_cursor_pointers[next_list]->value() == pivot_docid; --next_list) { - } - m_wand_join.m_cursor_pointers[next_list]->advance_to_geq(pivot_docid); - m_wand_join.bubble_down(next_list); - } - } + constexpr void advance() { m_wand_join.advance(); } + // constexpr void advance() + //{ + // while (true) { + // auto pivot = m_wand_join.find_pivot(); + // if (pivot == m_wand_join.m_cursor_pointers.end()) { + // m_wand_join.m_current_value = m_wand_join.sentinel(); + // return; + // } + + // auto pivot_docid = (*pivot)->value(); + // // auto block_upper_bound = std::accumulate( + // // m_wand_join.m_cursor_pointers.begin(), + // // std::next(pivot), + // // 0.0F, + // // [&](auto acc, auto* cursor) { return acc + cursor->block_max_score(pivot_docid); + // // }); + // // if (not m_wand_join.m_above_threshold(block_upper_bound)) { + // // block_max_advance(pivot, pivot_docid); + // // continue; + // //} + // if (pivot_docid == m_wand_join.m_cursor_pointers.front()->value()) { + // m_wand_join.m_current_value = pivot_docid; + // m_wand_join.m_current_payload = m_wand_join.m_init; + + // [&]() { + // auto iter = m_wand_join.m_cursor_pointers.begin(); + // for (; iter != m_wand_join.m_cursor_pointers.end(); ++iter) { + // auto* cursor = *iter; + // if (cursor->value() != pivot_docid) { + // break; + // } + // m_wand_join.m_current_payload = + // m_wand_join.m_accumulate(m_wand_join.m_current_payload, *cursor); + // cursor->advance(); + // } + // return iter; + // }(); + // // for (auto* cursor : m_wand_join.m_cursor_pointers) { + // // if (cursor->value() != pivot_docid) { + // // break; + // // } + // // m_wand_join.m_current_payload = + // // m_wand_join.m_accumulate(m_wand_join.m_current_payload, *cursor); + // // block_upper_bound -= cursor->block_max_score() - cursor->payload(); + // // if (not 
m_wand_join.m_above_threshold(block_upper_bound)) { + // // break; + // // } + // //} + + // // for (auto* cursor : m_wand_join.m_cursor_pointers) { + // // if (cursor->value() != pivot_docid) { + // // break; + // // } + // // cursor->advance(); + // //} + + // auto by_docid = [](auto&& lhs, auto&& rhs) { return lhs->value() < rhs->value(); + // }; std::sort(m_wand_join.m_cursor_pointers.begin(), + // m_wand_join.m_cursor_pointers.end(), + // by_docid); + // return; + // } + + // auto next_list = std::distance(m_wand_join.m_cursor_pointers.begin(), pivot); + // for (; m_wand_join.m_cursor_pointers[next_list]->value() == pivot_docid; --next_list) + // { + // } + // m_wand_join.m_cursor_pointers[next_list]->advance_to_geq(pivot_docid); + // m_wand_join.bubble_down(next_list); + // } + //} private: template diff --git a/src/v1/bit_sequence_cursor.cpp b/src/v1/bit_sequence_cursor.cpp new file mode 100644 index 000000000..343b466c1 --- /dev/null +++ b/src/v1/bit_sequence_cursor.cpp @@ -0,0 +1 @@ +#include "v1/bit_sequence_cursor.hpp" diff --git a/src/v1/default_index_runner.cpp b/src/v1/default_index_runner.cpp new file mode 100644 index 000000000..649434d51 --- /dev/null +++ b/src/v1/default_index_runner.cpp @@ -0,0 +1,23 @@ +#include "v1/default_index_runner.hpp" + +namespace pisa::v1 { + +//[[nodiscard]] auto index_runner(IndexMetadata metadata) +//{ +// return index_runner( +// std::move(metadata), +// std::make_tuple(RawReader{}, +// DocumentBlockedReader<::pisa::simdbp_block>{}, +// DocumentBitSequenceReader>{}), +// std::make_tuple(RawReader{}, PayloadBlockedReader<::pisa::simdbp_block>{})); +//} +// +//[[nodiscard]] auto scored_index_runner(IndexMetadata metadata) +//{ +// return scored_index_runner( +// std::move(metadata), +// std::make_tuple(RawReader{}, DocumentBlockedReader<::pisa::simdbp_block>{}), +// std::make_tuple(RawReader{})); +//} + +} // namespace pisa::v1 diff --git a/src/v1/index.cpp b/src/v1/index.cpp index fe538056d..61128fe13 100644 --- 
a/src/v1/index.cpp +++ b/src/v1/index.cpp @@ -157,14 +157,4 @@ template auto BaseIndex::fetch_bigram_payloads<1>(TermId bigram) const return m_quantized_max_scores.at(term); } -[[nodiscard]] auto BaseIndex::block_max_document_reader() const -> Reader> const& -{ - return m_block_max_document_reader; -} - -[[nodiscard]] auto BaseIndex::block_max_score_reader() const -> Reader> const& -{ - return m_block_max_score_reader; -} - } // namespace pisa::v1 diff --git a/src/v1/index_builder.cpp b/src/v1/index_builder.cpp index cc7aad831..a589dde97 100644 --- a/src/v1/index_builder.cpp +++ b/src/v1/index_builder.cpp @@ -169,9 +169,9 @@ auto build_bigram_index(IndexMetadata meta, std::vector::Writer; - PostingBuilder document_builder(document_writer_type{}); - PostingBuilder frequency_builder_0(frequency_writer_type{}); - PostingBuilder frequency_builder_1(frequency_writer_type{}); + PostingBuilder document_builder(document_writer_type{index.num_documents()}); + PostingBuilder frequency_builder_0(frequency_writer_type{index.num_documents()}); + PostingBuilder frequency_builder_1(frequency_writer_type{index.num_documents()}); document_builder.write_header(document_out); frequency_builder_0.write_header(frequency_out_0); diff --git a/test/v1/index_fixture.hpp b/test/v1/index_fixture.hpp index f381089eb..1ce83f25b 100644 --- a/test/v1/index_fixture.hpp +++ b/test/v1/index_fixture.hpp @@ -47,25 +47,38 @@ struct IndexFixture { using FrequencyReader = typename v1::CursorTraits::Reader; using ScoreReader = typename v1::CursorTraits::Reader; - IndexFixture() : m_tmpdir(std::make_unique()) + explicit IndexFixture(bool verify = true, + bool score = true, + bool bm_score = true, + bool build_bigrams = true) + : m_tmpdir(std::make_unique()) { auto index_basename = (tmpdir().path() / "inv").string(); v1::compress_binary_collection(PISA_SOURCE_DIR "/test/test_data/test_collection", PISA_SOURCE_DIR "/test/test_data/test_collection.fwd", index_basename, - 2, + 1, v1::make_writer(), 
v1::make_writer()); - auto errors = v1::verify_compressed_index(PISA_SOURCE_DIR "/test/test_data/test_collection", - index_basename); - for (auto&& error : errors) { - std::cerr << error << '\n'; + if (verify) { + auto errors = v1::verify_compressed_index( + PISA_SOURCE_DIR "/test/test_data/test_collection", index_basename); + for (auto&& error : errors) { + std::cerr << error << '\n'; + } + REQUIRE(errors.empty()); } - REQUIRE(errors.empty()); auto yml = fmt::format("{}.yml", index_basename); - auto meta = v1::score_index(v1::IndexMetadata::from_file(yml), 1); - meta = v1::bm_score_index(meta, 5, 1); - v1::build_bigram_index(meta, collect_unique_bigrams(test_queries(), []() {})); + auto meta = v1::IndexMetadata::from_file(yml); + if (score) { + meta = v1::score_index(v1::IndexMetadata::from_file(yml), 1); + } + if (bm_score) { + meta = v1::bm_score_index(meta, 5, 1); + } + if (build_bigrams) { + v1::build_bigram_index(meta, collect_unique_bigrams(test_queries(), []() {})); + } } [[nodiscard]] auto const& tmpdir() const { return *m_tmpdir; } diff --git a/test/v1/test_v1_index.cpp b/test/v1/test_v1_index.cpp index 7b349db82..5a53ab484 100644 --- a/test/v1/test_v1_index.cpp +++ b/test/v1/test_v1_index.cpp @@ -10,24 +10,35 @@ #include "../temporary_directory.hpp" #include "binary_collection.hpp" #include "codec/simdbp.hpp" +#include "index_fixture.hpp" #include "pisa_config.hpp" +#include "v1/bit_sequence_cursor.hpp" #include "v1/blocked_cursor.hpp" #include "v1/cursor/collect.hpp" #include "v1/default_index_runner.hpp" #include "v1/index.hpp" #include "v1/index_builder.hpp" #include "v1/index_metadata.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/sequence/partitioned_sequence.hpp" +#include "v1/sequence/positive_sequence.hpp" #include "v1/types.hpp" using pisa::binary_freq_collection; using pisa::v1::compress_binary_collection; using pisa::v1::DocId; +using pisa::v1::DocumentBitSequenceCursor; +using pisa::v1::DocumentBlockedCursor; using 
pisa::v1::DocumentBlockedWriter; using pisa::v1::Frequency; using pisa::v1::index_runner; using pisa::v1::IndexMetadata; +using pisa::v1::PartitionedSequence; +using pisa::v1::PayloadBitSequenceCursor; +using pisa::v1::PayloadBlockedCursor; using pisa::v1::PayloadBlockedWriter; -using pisa::v1::RawReader; +using pisa::v1::PositiveSequence; +using pisa::v1::RawCursor; using pisa::v1::RawWriter; using pisa::v1::TermId; @@ -94,3 +105,38 @@ TEST_CASE("Binary collection index -- SIMDBP", "[v1][unit]") } }); } + +TEMPLATE_TEST_CASE("Index", + "[v1][integration]", + (IndexFixture, RawCursor, RawCursor>), + (IndexFixture, + PayloadBlockedCursor<::pisa::simdbp_block>, + RawCursor>), + (IndexFixture, + PayloadBitSequenceCursor>, + RawCursor>), + (IndexFixture>, + PayloadBitSequenceCursor>, + RawCursor>)) +{ + tbb::task_scheduler_init init{1}; + TestType fixture(false, false, false, false); + auto index_basename = (fixture.tmpdir().path() / "inv").string(); + auto meta = IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); + auto run = pisa::v1::index_runner(meta, + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.frequency_reader())); + auto bci = binary_freq_collection(PISA_SOURCE_DIR "/test/test_data/test_collection"); + run([&](auto&& index) { + REQUIRE(bci.num_docs() == index.num_documents()); + REQUIRE(bci.size() == index.num_terms()); + auto bci_iter = bci.begin(); + for (auto term = 0; term < 1'000; term += 1) { + REQUIRE(std::vector(bci_iter->docs.begin(), bci_iter->docs.end()) + == collect(index.documents(term))); + REQUIRE(std::vector(bci_iter->freqs.begin(), bci_iter->freqs.end()) + == collect(index.payloads(term))); + ++bci_iter; + } + }); +} diff --git a/test/v1/test_v1_queries.cpp b/test/v1/test_v1_queries.cpp index adbe629a8..f2a47f190 100644 --- a/test/v1/test_v1_queries.cpp +++ b/test/v1/test_v1_queries.cpp @@ -20,6 +20,7 @@ #include "query/algorithm/ranked_or_query.hpp" #include "query/queries.hpp" #include "scorer/bm25.hpp" 
+#include "v1/bit_sequence_cursor.hpp" #include "v1/blocked_cursor.hpp" #include "v1/cursor/collect.hpp" #include "v1/cursor_intersection.hpp" @@ -35,19 +36,24 @@ #include "v1/query.hpp" #include "v1/score_index.hpp" #include "v1/scorer/bm25.hpp" +#include "v1/sequence/partitioned_sequence.hpp" +#include "v1/sequence/positive_sequence.hpp" #include "v1/taat_or.hpp" #include "v1/types.hpp" #include "v1/union_lookup.hpp" #include "v1/wand.hpp" -using namespace pisa; using pisa::v1::DocId; +using pisa::v1::DocumentBitSequenceCursor; using pisa::v1::DocumentBlockedCursor; using pisa::v1::Frequency; using pisa::v1::Index; using pisa::v1::IndexMetadata; using pisa::v1::ListSelection; +using pisa::v1::PartitionedSequence; +using pisa::v1::PayloadBitSequenceCursor; using pisa::v1::PayloadBlockedCursor; +using pisa::v1::PositiveSequence; using pisa::v1::RawCursor; static constexpr auto RELATIVE_ERROR = 0.1F; @@ -64,7 +70,7 @@ struct IndexData { collection.num_docs(), collection, "bm25", - BlockSize(FixedBlock()), + ::pisa::BlockSize(::pisa::FixedBlock()), {}) { @@ -77,12 +83,12 @@ struct IndexData { } builder.build(v0_index); - term_id_vec q; + ::pisa::term_id_vec q; std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); auto push_query = [&](std::string const& query_line) { - queries.push_back(parse_query_ids(query_line)); + queries.push_back(::pisa::parse_query_ids(query_line)); }; - io::for_each_line(qfile, push_query); + ::pisa::io::for_each_line(qfile, push_query); std::string t; std::ifstream tin(PISA_SOURCE_DIR "/test/test_data/top5_thresholds"); @@ -99,13 +105,13 @@ struct IndexData { return IndexData::data.get(); } - global_parameters params; - binary_freq_collection collection; - binary_collection document_sizes; + ::pisa::global_parameters params; + ::pisa::binary_freq_collection collection; + ::pisa::binary_collection document_sizes; v0_Index v0_index; - std::vector queries; + std::vector<::pisa::Query> queries; std::vector thresholds; - wand_data 
wdata; + ::pisa::wand_data<::pisa::wand_data_raw> wdata; }; /// Inefficient, do not use in production code. @@ -131,10 +137,13 @@ TEMPLATE_TEST_CASE("Query", (IndexFixture, RawCursor, RawCursor>), (IndexFixture, PayloadBlockedCursor<::pisa::simdbp_block>, + RawCursor>), + (IndexFixture>, + PayloadBitSequenceCursor>, RawCursor>)) { tbb::task_scheduler_init init(1); - auto data = IndexData, RawCursor>, Index, RawCursor>>::get(); TestType fixture; @@ -157,35 +166,35 @@ TEMPLATE_TEST_CASE("Query", CAPTURE(with_threshold); auto index_basename = (fixture.tmpdir().path() / "inv").string(); auto meta = IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); - ranked_or_query or_q(10); + ::pisa::ranked_or_query or_q(10); auto run_query = [](std::string const& name, auto query, auto&& index, auto scorer) { if (name == "daat_or") { - return daat_or(query, index, topk_queue(10), scorer); + return daat_or(query, index, ::pisa::topk_queue(10), scorer); } if (name == "maxscore") { - return maxscore(query, index, topk_queue(10), scorer); + return maxscore(query, index, ::pisa::topk_queue(10), scorer); } if (name == "wand") { - return wand(query, index, topk_queue(10), scorer); + return wand(query, index, ::pisa::topk_queue(10), scorer); } if (name == "bmw") { - return bmw(query, index, topk_queue(10), scorer); + return bmw(query, index, ::pisa::topk_queue(10), scorer); } if (name == "maxscore_union_lookup") { - return maxscore_union_lookup(query, index, topk_queue(10), scorer); + return maxscore_union_lookup(query, index, ::pisa::topk_queue(10), scorer); } if (name == "unigram_union_lookup") { query.selections(ListSelection{.unigrams = query.get_term_ids(), .bigrams = {}}); - return unigram_union_lookup(query, index, topk_queue(10), scorer); + return unigram_union_lookup(query, index, ::pisa::topk_queue(10), scorer); } if (name == "union_lookup") { if (query.get_term_ids().size() > 8) { - return maxscore_union_lookup(query, index, topk_queue(10), scorer); + return 
maxscore_union_lookup(query, index, ::pisa::topk_queue(10), scorer); } - return union_lookup(query, index, topk_queue(10), scorer); + return union_lookup(query, index, ::pisa::topk_queue(10), scorer); } if (name == "lookup_union") { - return lookup_union(query, index, topk_queue(10), scorer); + return lookup_union(query, index, ::pisa::topk_queue(10), scorer); } std::abort(); }; @@ -201,10 +210,11 @@ TEMPLATE_TEST_CASE("Query", CAPTURE(idx); CAPTURE(intersections[idx]); - or_q(make_scored_cursors(data->v0_index, - ::pisa::bm25>(data->wdata), - ::pisa::Query{{}, query.get_term_ids(), {}}), - data->v0_index.num_docs()); + or_q( + make_scored_cursors(data->v0_index, + ::pisa::bm25<::pisa::wand_data<::pisa::wand_data_raw>>(data->wdata), + ::pisa::Query{{}, query.get_term_ids(), {}}), + data->v0_index.num_docs()); auto expected = or_q.topk(); if (with_threshold) { query.threshold(expected.back().first - 1.0F); @@ -214,7 +224,7 @@ TEMPLATE_TEST_CASE("Query", auto run = pisa::v1::index_runner(meta, std::make_tuple(fixture.document_reader()), std::make_tuple(fixture.frequency_reader())); - std::vector results; + std::vector results; run([&](auto&& index) { auto que = run_query(algorithm, query, index, make_bm25(index)); que.finalize(); @@ -234,16 +244,16 @@ TEMPLATE_TEST_CASE("Query", expected.resize(on_the_fly.size()); std::sort(expected.begin(), expected.end(), approximate_order); - if (algorithm == "bmw") { - for (size_t i = 0; i < on_the_fly.size(); ++i) { - std::cerr << fmt::format("{} {} -- {} {}\n", - on_the_fly[i].second, - on_the_fly[i].first, - expected[i].second, - expected[i].first); - } - std::cerr << '\n'; - } + // if (algorithm == "bmw") { + // for (size_t i = 0; i < on_the_fly.size(); ++i) { + // std::cerr << fmt::format("{} {} -- {} {}\n", + // on_the_fly[i].second, + // on_the_fly[i].first, + // expected[i].second, + // expected[i].first); + // } + // std::cerr << '\n'; + //} for (size_t i = 0; i < on_the_fly.size(); ++i) { REQUIRE(on_the_fly[i].second == 
expected[i].second); @@ -260,7 +270,7 @@ TEMPLATE_TEST_CASE("Query", auto run = pisa::v1::scored_index_runner(meta, std::make_tuple(fixture.document_reader()), std::make_tuple(fixture.score_reader())); - std::vector results; + std::vector results; run([&](auto&& index) { auto que = run_query(algorithm, query, index, v1::VoidScorer{}); que.finalize(); diff --git a/v1/compress.cpp b/v1/compress.cpp index dda0db0e0..1d2521a1b 100644 --- a/v1/compress.cpp +++ b/v1/compress.cpp @@ -4,19 +4,27 @@ #include #include +#include "sequence/partitioned_sequence.hpp" +#include "v1/bit_sequence_cursor.hpp" #include "v1/blocked_cursor.hpp" #include "v1/index_builder.hpp" #include "v1/index_metadata.hpp" #include "v1/raw_cursor.hpp" +#include "v1/sequence/partitioned_sequence.hpp" +#include "v1/sequence/positive_sequence.hpp" #include "v1/types.hpp" using std::literals::string_view_literals::operator""sv; using pisa::v1::compress_binary_collection; +using pisa::v1::DocumentBitSequenceWriter; using pisa::v1::DocumentBlockedWriter; using pisa::v1::EncodingId; using pisa::v1::make_index_builder; +using pisa::v1::PartitionedSequence; +using pisa::v1::PayloadBitSequenceWriter; using pisa::v1::PayloadBlockedWriter; +using pisa::v1::PositiveSequence; using pisa::v1::RawWriter; using pisa::v1::verify_compressed_index; @@ -28,6 +36,9 @@ auto document_encoding(std::string_view name) -> std::uint32_t if (name == "simdbp"sv) { return EncodingId::BlockDelta | EncodingId::SimdBP; } + if (name == "pef"sv) { + return EncodingId::BitSequence | EncodingId::PEF; + } spdlog::error("Unknown encoding: {}", name); std::exit(1); } @@ -40,6 +51,9 @@ auto frequency_encoding(std::string_view name) -> std::uint32_t if (name == "simdbp"sv) { return EncodingId::Block | EncodingId::SimdBP; } + if (name == "pef"sv) { + return EncodingId::BitSequence | EncodingId::PositiveSeq; + } spdlog::error("Unknown encoding: {}", name); std::exit(1); } @@ -62,9 +76,13 @@ int main(int argc, char** argv) CLI11_PARSE(app, argc, 
argv); tbb::task_scheduler_init init(threads); - auto build = make_index_builder(RawWriter{}, - DocumentBlockedWriter<::pisa::simdbp_block>{}, - PayloadBlockedWriter<::pisa::simdbp_block>{}); + auto build = + make_index_builder(std::make_tuple(RawWriter{}, + DocumentBlockedWriter<::pisa::simdbp_block>{}, + DocumentBitSequenceWriter>{}), + std::make_tuple(RawWriter{}, + PayloadBlockedWriter<::pisa::simdbp_block>{}, + PayloadBitSequenceWriter>{})); build(document_encoding(encoding), frequency_encoding(encoding), [&](auto document_writer, auto payload_writer) { From 14b779007566efe162db5bae0cacce35b35f07a3 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 6 Jan 2020 14:23:39 +0000 Subject: [PATCH 41/56] Minor fixes --- CMakeLists.txt | 9 +-- include/pisa/v1/cursor/accumulate.hpp | 12 ++++ script/cw09b-est-pef.sh | 99 +++++++++++++++++++++++++++ script/cw09b-est.sh | 91 ++---------------------- script/cw09b-exact.sh | 92 +++++++++++++++++++++++++ script/cw09b.sh | 80 +++++++++++++--------- v1/filter_queries.cpp | 2 +- v1/intersection.cpp | 55 +++++++++++++-- v1/query.cpp | 8 +-- 9 files changed, 312 insertions(+), 136 deletions(-) create mode 100644 include/pisa/v1/cursor/accumulate.hpp create mode 100644 script/cw09b-est-pef.sh create mode 100644 script/cw09b-exact.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index 791574e24..d2c35e9e1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -50,10 +50,10 @@ find_package(codecov) list(APPEND LCOV_REMOVE_PATTERNS "'${PROJECT_SOURCE_DIR}/external/*'") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-strict-aliasing") -set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DGSL_UNENFORCED_ON_CONTRACT_VIOLATION") +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DGSL_UNENFORCED_ON_CONTRACT_VIOLATION -flto") if (UNIX) # For hardware popcount and other special instructions - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -Wno-odr -Wfatal-errors") 
# Extensive warnings set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wno-missing-braces") @@ -62,9 +62,10 @@ if (UNIX) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer") endif () - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ggdb") # Add debug info anyway + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ggdb -gdwarf") # Add debug info anyway - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer -Wfatal-errors") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer") + #set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer -Wfatal-errors") if (${FORCE_COLORED_OUTPUT}) if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") diff --git a/include/pisa/v1/cursor/accumulate.hpp b/include/pisa/v1/cursor/accumulate.hpp new file mode 100644 index 000000000..94b1d6b56 --- /dev/null +++ b/include/pisa/v1/cursor/accumulate.hpp @@ -0,0 +1,12 @@ +#include "v1/cursor/for_each.hpp" + +namespace pisa::v1 { + +template +[[nodiscard]] constexpr inline auto accumulate(Cursor cursor, Payload init, AccumulateFn accumulate) +{ + for_each(cursor, [&](auto&& cursor) { init = accumulate(init, cursor); }); + return init; +} + +} // namespace pisa::v1 diff --git a/script/cw09b-est-pef.sh b/script/cw09b-est-pef.sh new file mode 100644 index 000000000..7bd029a2b --- /dev/null +++ b/script/cw09b-est-pef.sh @@ -0,0 +1,99 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/home/amallia/cw09b/CW09B.inv" +FWD="/home/amallia/cw09b/CW09B.fwd" +ENCODING="pef" # v1 +BASENAME="/data/michal/work/v1/cw09b/cw09b-pef" +THREADS=4 +QUERIES="/home/michal/biscorer/data/queries/05.efficiency_topics.no_dups.1k" +#QUERIES="/home/michal/topics.web.51-200.jl" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-est-pef" +#OUTPUT_DIR="/data/michal/intersect/cw09b-est-top20" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" 
+#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.web.51-200" +#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.05.efficiency_topics.no_dups.1k" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top15.bm25.05.efficiency_topics.no_dups.1k" +#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.efficiency_topics.no_dups.1k" + +set -e +set -x + +## Compress an inverted index in `binary_freq_collection` format. +${PISA_BIN}/compress -i ${BINARY_FREQ_COLL} --fwd ${FWD} -o ${BASENAME} -j ${THREADS} -e ${ENCODING} + +# This will produce both quantized scores and max scores (both quantized and not). +${PISA_BIN}/score -i "${BASENAME}.yml" -j ${THREADS} + +# This will produce both quantized scores and max scores (both quantized and not). +${PISA_BIN}/bmscore -i "${BASENAME}.yml" -j ${THREADS} --block-size 128 + +# Filter out queries witout existing terms. +paste -d: ${QUERIES} ${THRESHOLDS} \ + | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ + | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml > ${FILTERED_QUERIES} + +# This will produce both quantized scores and max scores (both quantized and not). 
+${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} + +# Extract intersections +${PISA_BIN}/intersection -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations --mtc 2 \ + | grep -v "\[warning\]" \ + > ${OUTPUT_DIR}/intersections.jl + +# Select unigrams +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl --max 1 > ${OUTPUT_DIR}/selections.1 +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl --max 2 > ${OUTPUT_DIR}/selections.2 +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl \ + --max 2 --scale 1.5 > ${OUTPUT_DIR}/selections.2.scaled-1.5 + +# Run benchmarks +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm wand > ${OUTPUT_DIR}/bench.wand +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm bmw > ${OUTPUT_DIR}/bench.bmw +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm maxscore > ${OUTPUT_DIR}/bench.maxscore +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm bmw \ + > ${OUTPUT_DIR}/bench.bmw-threshold +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore \ + > ${OUTPUT_DIR}/bench.maxscore-threshold +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup \ + > ${OUTPUT_DIR}/bench.maxscore-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup \ + > ${OUTPUT_DIR}/bench.unigram-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup \ + > ${OUTPUT_DIR}/bench.union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm lookup-union \ + > ${OUTPUT_DIR}/bench.lookup-union +${PISA_BIN}/query -i "${BASENAME}.yml" -q 
${OUTPUT_DIR}/selections.2.scaled-1.5 \ + --benchmark --algorithm lookup-union \ + > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 + +# Analyze +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore \ + > ${OUTPUT_DIR}/stats.maxscore-threshold +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore-union-lookup \ + > ${OUTPUT_DIR}/stats.maxscore-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --inspect --algorithm unigram-union-lookup \ + > ${OUTPUT_DIR}/stats.unigram-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup \ + > ${OUTPUT_DIR}/stats.union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ + --inspect --algorithm lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union.scaled-1.5 + +# Evaluate +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore \ + > "${OUTPUT_DIR}/eval.maxscore-threshold" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore-union-lookup \ + > "${OUTPUT_DIR}/eval.maxscore-union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --algorithm unigram-union-lookup \ + > "${OUTPUT_DIR}/eval.unigram-union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm union-lookup \ + > "${OUTPUT_DIR}/eval.union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 
--algorithm lookup-union \ + > "${OUTPUT_DIR}/eval.lookup-union" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --algorithm lookup-union \ + > "${OUTPUT_DIR}/eval.lookup-union.scaled-1.5" diff --git a/script/cw09b-est.sh b/script/cw09b-est.sh index 8739d3ee5..e6252643c 100644 --- a/script/cw09b-est.sh +++ b/script/cw09b-est.sh @@ -2,10 +2,9 @@ PISA_BIN="/home/michal/pisa/build/bin" INTERSECT_BIN="/home/michal/intersect/target/release/intersect" BINARY_FREQ_COLL="/home/amallia/cw09b/CW09B.inv" FWD="/home/amallia/cw09b/CW09B.fwd" -BASENAME="/data/michal/work/v1/cw09b/cw09b" -THREADS=4 -TYPE="block_simdbp" # v0.6 ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b/cw09b-${ENCODING}" +THREADS=4 QUERIES="/home/michal/biscorer/data/queries/05.efficiency_topics.no_dups.1k" #QUERIES="/home/michal/topics.web.51-200.jl" K=1000 @@ -14,87 +13,5 @@ OUTPUT_DIR="/data/michal/intersect/cw09b-est" FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" #THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.web.51-200" #THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.05.efficiency_topics.no_dups.1k" -THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top15.bm25.05.efficiency_topics.no_dups.1k" -#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.efficiency_topics.no_dups.1k" - -set -e -set -x - -## Compress an inverted index in `binary_freq_collection` format. -#${PISA_BIN}/compress -i ${BINARY_FREQ_COLL} --fwd ${FWD} -o ${BASENAME} -j ${THREADS} -e ${ENCODING} - -# This will produce both quantized scores and max scores (both quantized and not). -#${PISA_BIN}/score -i "${BASENAME}.yml" -j ${THREADS} - -# This will produce both quantized scores and max scores (both quantized and not). 
-#${PISA_BIN}/bmscore -i "${BASENAME}.yml" -j ${THREADS} --block-size 128 - -# Filter out queries witout existing terms. -paste -d: ${QUERIES} ${THRESHOLDS} \ - | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ - | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml > ${FILTERED_QUERIES} - -# This will produce both quantized scores and max scores (both quantized and not). -#${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} - -# Extract intersections -${PISA_BIN}/intersection -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations --mtc 2 \ - | grep -v "\[warning\]" \ - > ${OUTPUT_DIR}/intersections.jl - -# Select unigrams -${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl --max 1 > ${OUTPUT_DIR}/selections.1 -${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl --max 2 > ${OUTPUT_DIR}/selections.2 -${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl \ - --max 2 --scale 1.5 > ${OUTPUT_DIR}/selections.2.scaled-1.5 - -# Run benchmarks -${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm wand > ${OUTPUT_DIR}/bench.wand -${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm bmw > ${OUTPUT_DIR}/bench.bmw -${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm maxscore > ${OUTPUT_DIR}/bench.maxscore -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm bmw \ - > ${OUTPUT_DIR}/bench.bmw-threshold -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore \ - > ${OUTPUT_DIR}/bench.maxscore-threshold -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup \ - > ${OUTPUT_DIR}/bench.maxscore-union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 
--benchmark --algorithm unigram-union-lookup \ - > ${OUTPUT_DIR}/bench.unigram-union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup \ - > ${OUTPUT_DIR}/bench.union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm lookup-union \ - > ${OUTPUT_DIR}/bench.lookup-union -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ - --benchmark --algorithm lookup-union \ - > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 - -# Analyze -${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore \ - > ${OUTPUT_DIR}/stats.maxscore-threshold -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore-union-lookup \ - > ${OUTPUT_DIR}/stats.maxscore-union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --inspect --algorithm unigram-union-lookup \ - > ${OUTPUT_DIR}/stats.unigram-union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup \ - > ${OUTPUT_DIR}/stats.union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm lookup-union \ - > ${OUTPUT_DIR}/stats.lookup-union -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ - --inspect --algorithm lookup-union \ - > ${OUTPUT_DIR}/stats.lookup-union.scaled-1.5 - -# Evaluate -${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore \ - > "${OUTPUT_DIR}/eval.maxscore-threshold" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm 
maxscore-union-lookup \ - > "${OUTPUT_DIR}/eval.maxscore-union-lookup" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --algorithm unigram-union-lookup \ - > "${OUTPUT_DIR}/eval.unigram-union-lookup" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm union-lookup \ - > "${OUTPUT_DIR}/eval.union-lookup" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm lookup-union \ - > "${OUTPUT_DIR}/eval.lookup-union" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --algorithm lookup-union \ - > "${OUTPUT_DIR}/eval.lookup-union.scaled-1.5" +#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top15.bm25.05.efficiency_topics.no_dups.1k" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.efficiency_topics.no_dups.1k" diff --git a/script/cw09b-exact.sh b/script/cw09b-exact.sh new file mode 100644 index 000000000..630b2391c --- /dev/null +++ b/script/cw09b-exact.sh @@ -0,0 +1,92 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/home/amallia/cw09b/CW09B.inv" +FWD="/home/amallia/cw09b/CW09B.fwd" +INV="/home/amallia/cw09b/CW09B" +BASENAME="/data/michal/work/v1/cw09b/cw09b" +THREADS=4 +TYPE="block_simdbp" # v0.6 +ENCODING="simdbp" # v1 +QUERIES="/home/michal/biscorer/data/queries/05.efficiency_topics.no_dups.1k" +#QUERIES="/home/michal/topics.web.51-200.jl" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.web.51-200" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.05.efficiency_topics.no_dups.1k" + +set -e +set -x + +## Compress an inverted index in `binary_freq_collection` format. 
+#${PISA_BIN}/compress -i ${BINARY_FREQ_COLL} --fwd ${FWD} -o ${BASENAME} -j ${THREADS} -e ${ENCODING} + +# This will produce both quantized scores and max scores (both quantized and not). +#${PISA_BIN}/score -i "${BASENAME}.yml" -j ${THREADS} + +# Filter out queries witout existing terms. +#paste -d: ${QUERIES} ${THRESHOLDS} \ +# | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ +# | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml > ${FILTERED_QUERIES} + +# This will produce both quantized scores and max scores (both quantized and not). +#${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} + +# Extract intersections +#${PISA_BIN}/intersection -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations --mtc 2 \ +# | grep -v "\[warning\]" \ +# > ${OUTPUT_DIR}/intersections.jl + +# Select unigrams +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl --max 1 > ${OUTPUT_DIR}/selections.1 +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl --max 2 > ${OUTPUT_DIR}/selections.2 +${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl \ + --max 2 --scale 1.5 > ${OUTPUT_DIR}/selections.2.scaled-1.5 + +# Run benchmarks +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark \ + --algorithm maxscore > ${OUTPUT_DIR}/bench.maxscore +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore --safe \ + > ${OUTPUT_DIR}/bench.maxscore-threshold +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup --safe \ + > ${OUTPUT_DIR}/bench.maxscore-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup --safe \ + > ${OUTPUT_DIR}/bench.unigram-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup --safe \ + > 
${OUTPUT_DIR}/bench.union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm lookup-union --safe \ + > ${OUTPUT_DIR}/bench.lookup-union +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ + --benchmark --algorithm lookup-union --safe \ + > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 + +# Analyze +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore \ + > ${OUTPUT_DIR}/stats.maxscore-threshold +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore-union-lookup \ + > ${OUTPUT_DIR}/stats.maxscore-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --inspect --algorithm unigram-union-lookup \ + > ${OUTPUT_DIR}/stats.unigram-union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup \ + > ${OUTPUT_DIR}/stats.union-lookup +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ + --inspect --algorithm lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union.scaled-1.5 + +# Evaluate +${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore \ + > "${OUTPUT_DIR}/eval.maxscore-threshold" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore-union-lookup \ + > "${OUTPUT_DIR}/eval.maxscore-union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --algorithm unigram-union-lookup \ + > 
"${OUTPUT_DIR}/eval.unigram-union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm union-lookup \ + > "${OUTPUT_DIR}/eval.union-lookup" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm lookup-union \ + > "${OUTPUT_DIR}/eval.lookup-union" +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --algorithm lookup-union \ + > "${OUTPUT_DIR}/eval.lookup-union.scaled-1.5" diff --git a/script/cw09b.sh b/script/cw09b.sh index 630b2391c..f59e2fc3b 100644 --- a/script/cw09b.sh +++ b/script/cw09b.sh @@ -1,21 +1,25 @@ -PISA_BIN="/home/michal/pisa/build/bin" -INTERSECT_BIN="/home/michal/intersect/target/release/intersect" -BINARY_FREQ_COLL="/home/amallia/cw09b/CW09B.inv" -FWD="/home/amallia/cw09b/CW09B.fwd" -INV="/home/amallia/cw09b/CW09B" -BASENAME="/data/michal/work/v1/cw09b/cw09b" -THREADS=4 -TYPE="block_simdbp" # v0.6 -ENCODING="simdbp" # v1 -QUERIES="/home/michal/biscorer/data/queries/05.efficiency_topics.no_dups.1k" -#QUERIES="/home/michal/topics.web.51-200.jl" -K=1000 -OUTPUT_DIR="/data/michal/intersect/cw09b" -FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" -#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.web.51-200" -THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.05.efficiency_topics.no_dups.1k" - +# Fail if any variable is unset +set -u set -e + +source $1 + +echo "Running experiment with the following environment:" +echo "" +echo " PISA_BIN = ${PISA_BIN}" +echo " INTERSECT_BIN = ${INTERSECT_BIN}" +echo " BINARY_FREQ_COLL = ${BINARY_FREQ_COLL}" +echo " FWD = ${FWD}" +echo " BASENAME = ${BASENAME}" +echo " THREADS = ${THREADS}" +echo " ENCODING = ${ENCODING}" +echo " OUTPUT_DIR = ${OUTPUT_DIR}" +echo " QUERIES = ${QUERIES}" +echo " FILTERED_QUERIES = ${FILTERED_QUERIES}" +echo " K = ${K}" +echo " THRESHOLDS = ${THRESHOLDS}" +echo "" + set -x ## Compress an inverted index in 
`binary_freq_collection` format. @@ -24,6 +28,9 @@ set -x # This will produce both quantized scores and max scores (both quantized and not). #${PISA_BIN}/score -i "${BASENAME}.yml" -j ${THREADS} +# This will produce both quantized scores and max scores (both quantized and not). +#${PISA_BIN}/bmscore -i "${BASENAME}.yml" -j ${THREADS} --block-size 128 + # Filter out queries witout existing terms. #paste -d: ${QUERIES} ${THRESHOLDS} \ # | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ @@ -33,32 +40,40 @@ set -x #${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} # Extract intersections -#${PISA_BIN}/intersection -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations --mtc 2 \ -# | grep -v "\[warning\]" \ -# > ${OUTPUT_DIR}/intersections.jl +${PISA_BIN}/intersection -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations --existing \ + | grep -v "\[warning\]" \ + > ${OUTPUT_DIR}/intersections.jl # Select unigrams -${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl --max 1 > ${OUTPUT_DIR}/selections.1 -${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl --max 2 > ${OUTPUT_DIR}/selections.2 -${INTERSECT_BIN} -m graph-greedy ${OUTPUT_DIR}/intersections.jl \ - --max 2 --scale 1.5 > ${OUTPUT_DIR}/selections.2.scaled-1.5 +${INTERSECT_BIN} -m unigram ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.1 +${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.2 +${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl \ + --scale 1.5 > ${OUTPUT_DIR}/selections.2.scaled-1.5 +${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl \ + --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 # Run benchmarks -${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark \ - --algorithm maxscore > ${OUTPUT_DIR}/bench.maxscore -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm 
maxscore --safe \ - > ${OUTPUT_DIR}/bench.maxscore-threshold -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup --safe \ - > ${OUTPUT_DIR}/bench.maxscore-union-lookup +#${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm wand --safe > ${OUTPUT_DIR}/bench.wand +#${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm bmw --safe > ${OUTPUT_DIR}/bench.bmw +#${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm maxscore --safe > ${OUTPUT_DIR}/bench.maxscore +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm bmw --safe \ +# > ${OUTPUT_DIR}/bench.bmw-threshold +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore --safe \ +# > ${OUTPUT_DIR}/bench.maxscore-threshold +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup --safe \ +# > ${OUTPUT_DIR}/bench.maxscore-union-lookup ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup --safe \ > ${OUTPUT_DIR}/bench.unigram-union-lookup ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup --safe \ > ${OUTPUT_DIR}/bench.union-lookup ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm lookup-union --safe \ > ${OUTPUT_DIR}/bench.lookup-union -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ - --benchmark --algorithm lookup-union --safe \ +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --safe \ + --benchmark --algorithm lookup-union \ > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 +${PISA_BIN}/query -i "${BASENAME}.yml" -q 
${OUTPUT_DIR}/selections.2.scaled-2 --safe \ + --benchmark --algorithm lookup-union \ + > ${OUTPUT_DIR}/bench.lookup-union.scaled-2 # Analyze ${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore @@ -90,3 +105,4 @@ ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm > "${OUTPUT_DIR}/eval.lookup-union" ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --algorithm lookup-union \ > "${OUTPUT_DIR}/eval.lookup-union.scaled-1.5" + diff --git a/v1/filter_queries.cpp b/v1/filter_queries.cpp index f844fc233..ce17f6e07 100644 --- a/v1/filter_queries.cpp +++ b/v1/filter_queries.cpp @@ -21,7 +21,7 @@ int main(int argc, char** argv) auto meta = app.index_metadata(); auto queries = app.queries(meta); for (auto&& query : queries) { - if (query.term_ids()) { + if (not query.get_term_ids().empty()) { std::cout << *query.to_json() << '\n'; } } diff --git a/v1/intersection.cpp b/v1/intersection.cpp index b538667d5..b5b825b04 100644 --- a/v1/intersection.cpp +++ b/v1/intersection.cpp @@ -13,6 +13,7 @@ #include "intersection.hpp" #include "query/queries.hpp" #include "v1/blocked_cursor.hpp" +#include "v1/cursor/accumulate.hpp" #include "v1/cursor/for_each.hpp" #include "v1/cursor_intersection.hpp" #include "v1/default_index_runner.hpp" @@ -72,7 +73,8 @@ template void compute_intersections(Index const& index, std::vector const& queries, IntersectionType intersection_type, - tl::optional max_term_count) + tl::optional max_term_count, + bool existing) { for (auto const& query : queries) { auto intersections = nlohmann::json::array(); @@ -86,7 +88,40 @@ void compute_intersections(Index const& index, } }; if (intersection_type == IntersectionType::Combinations) { - for_all_subsets(query, max_term_count, inter); + if (existing) { + std::uint64_t left_mask = 1; + auto term_ids = query.get_term_ids(); + for (auto left = 0; left < 
term_ids.size(); left += 1) { + auto cursor = index.max_scored_cursor(term_ids[left], make_bm25(index)); + intersections.push_back(nlohmann::json{{"intersection", left_mask}, + {"cost", cursor.size()}, + {"max_score", cursor.max_score()}}); + std::uint64_t right_mask = left_mask << 1; + for (auto right = left + 1; right < term_ids.size(); right += 1) { + index + .scored_bigram_cursor(term_ids[left], term_ids[right], make_bm25(index)) + .map([&](auto&& cursor) { + // TODO(michal): Do not traverse once max scores for bigrams are + // implemented + auto cost = cursor.size(); + auto max_score = + accumulate(cursor, 0.0F, [](float acc, auto&& cursor) { + auto score = std::get<0>(cursor.payload()) + + std::get<1>(cursor.payload()); + return std::max(acc, score); + }); + intersections.push_back( + nlohmann::json{{"intersection", left_mask | right_mask}, + {"cost", cost}, + {"max_score", max_score}}); + }); + right_mask <<= 1; + } + left_mask <<= 1; + } + } else { + for_all_subsets(query, max_term_count, inter); + } } else { inter(query, tl::nullopt); } @@ -102,16 +137,20 @@ int main(int argc, const char** argv) spdlog::set_default_logger(spdlog::stderr_color_mt("")); bool combinations = false; + bool existing = false; std::optional max_term_count; pisa::App> app( "Calculates intersections for a v1 index."); auto* combinations_flag = app.add_flag( "--combinations", combinations, "Compute intersections for combinations of terms in query"); - app.add_option("--max-term-count,--mtc", - max_term_count, - "Max number of terms when computing combinations") - ->needs(combinations_flag); + auto* mtc_flag = app.add_option("--max-term-count,--mtc", + max_term_count, + "Max number of terms when computing combinations"); + mtc_flag->needs(combinations_flag); + app.add_flag("--existing", existing, "Use only existing bigrams") + ->needs(combinations_flag) + ->excludes(mtc_flag); CLI11_PARSE(app, argc, argv); auto mtc = max_term_count ? 
tl::make_optional(*max_term_count) : tl::optional{}; @@ -123,7 +162,9 @@ int main(int argc, const char** argv) auto queries = app.queries(meta); auto run = index_runner(meta); - run([&](auto&& index) { compute_intersections(index, queries, intersection_type, mtc); }); + run([&](auto&& index) { + compute_intersections(index, queries, intersection_type, mtc, existing); + }); } catch (std::exception const& error) { spdlog::error("{}", error.what()); } diff --git a/v1/query.cpp b/v1/query.cpp index 565d1dc05..9ff210a2e 100644 --- a/v1/query.cpp +++ b/v1/query.cpp @@ -13,6 +13,7 @@ #include "util/do_not_optimize_away.hpp" #include "v1/blocked_cursor.hpp" #include "v1/daat_or.hpp" +#include "v1/default_index_runner.hpp" #include "v1/index_metadata.hpp" #include "v1/inspect_query.hpp" #include "v1/maxscore.hpp" @@ -144,7 +145,7 @@ auto resolve_algorithm(std::string const& name, Index const& index, Scorer&& sco return pisa::v1::unigram_union_lookup( query, index, std::move(topk), std::forward(scorer)); } - if (query.get_term_ids().size() > 8) { + if (query.get_term_ids().size() >= 8) { return pisa::v1::maxscore( query, index, std::move(topk), std::forward(scorer)); } @@ -301,10 +302,7 @@ int main(int argc, char** argv) } }); } else { - auto run = index_runner(meta, - std::make_tuple(RawReader{}, - DocumentBlockedReader<::pisa::simdbp_block>{}), - std::make_tuple(PayloadBlockedReader<::pisa::simdbp_block>{})); + auto run = index_runner(meta); run([&](auto&& index) { auto with_scorer = scorer_runner(index, make_bm25(index)); with_scorer("bm25", [&](auto scorer) { From 131b3058f7aa2ed4372108174014846a5e8c14a5 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 6 Jan 2020 19:45:37 +0000 Subject: [PATCH 42/56] Selecting best bigrams --- CMakeLists.txt | 2 +- include/pisa/v1/index_builder.hpp | 5 ++ include/pisa/v1/query.hpp | 4 ++ script/cw09b.sh | 2 +- src/v1/index_builder.cpp | 95 +++++++++++++++++++++++++++++-- src/v1/query.cpp | 18 ++++++ 6 files changed, 119 
insertions(+), 7 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d2c35e9e1..0f5d343bb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,7 +53,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-strict-aliasing") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DGSL_UNENFORCED_ON_CONTRACT_VIOLATION -flto") if (UNIX) # For hardware popcount and other special instructions - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -Wno-odr -Wfatal-errors") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -Wno-odr") # Extensive warnings set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wno-missing-braces") diff --git a/include/pisa/v1/index_builder.hpp b/include/pisa/v1/index_builder.hpp index 844a2aa06..09dea43b6 100644 --- a/include/pisa/v1/index_builder.hpp +++ b/include/pisa/v1/index_builder.hpp @@ -238,6 +238,11 @@ auto collect_unique_bigrams(std::vector const& queries, std::function const& callback) -> std::vector>; +[[nodiscard]] auto select_best_bigrams(IndexMetadata const& meta, + std::vector const& queries, + std::size_t num_bigrams_to_select) + -> std::vector>; + auto build_bigram_index(IndexMetadata meta, std::vector> const& bigrams) -> IndexMetadata; diff --git a/include/pisa/v1/query.hpp b/include/pisa/v1/query.hpp index 14fd92dcc..0370b66e8 100644 --- a/include/pisa/v1/query.hpp +++ b/include/pisa/v1/query.hpp @@ -72,6 +72,7 @@ struct Query { auto selections(gsl::span const> selections) -> Query&; auto selections(ListSelection selections) -> Query&; auto threshold(float threshold) -> Query&; + auto probability(float probability) -> Query&; /// Non-throwing getters [[nodiscard]] auto term_ids() const -> tl::optional const&>; @@ -79,6 +80,7 @@ struct Query { [[nodiscard]] auto k() const -> int; [[nodiscard]] auto selections() const -> tl::optional; [[nodiscard]] auto threshold() const -> tl::optional; + [[nodiscard]] auto probability() const -> tl::optional; [[nodiscard]] auto raw() const -> tl::optional; /// Throwing 
getters @@ -86,6 +88,7 @@ struct Query { [[nodiscard]] auto get_id() const -> std::string const&; [[nodiscard]] auto get_selections() const -> ListSelection const&; [[nodiscard]] auto get_threshold() const -> float; + [[nodiscard]] auto get_probability() const -> float; [[nodiscard]] auto get_raw() const -> std::string const&; [[nodiscard]] auto sorted_position(TermId term) const -> std::size_t; @@ -113,6 +116,7 @@ struct Query { tl::optional m_threshold{}; tl::optional m_id{}; tl::optional m_raw_string; + tl::optional m_probability; int m_k = 1000; }; diff --git a/script/cw09b.sh b/script/cw09b.sh index f59e2fc3b..04cfe1b60 100644 --- a/script/cw09b.sh +++ b/script/cw09b.sh @@ -37,7 +37,7 @@ set -x # | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml > ${FILTERED_QUERIES} # This will produce both quantized scores and max scores (both quantized and not). -#${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} +${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} # Extract intersections ${PISA_BIN}/intersection -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations --existing \ diff --git a/src/v1/index_builder.cpp b/src/v1/index_builder.cpp index a589dde97..bc6c068dc 100644 --- a/src/v1/index_builder.cpp +++ b/src/v1/index_builder.cpp @@ -1,9 +1,16 @@ -#include "v1/index_builder.hpp" +#include +#include + +#include + #include "codec/simdbp.hpp" #include "v1/blocked_cursor.hpp" +#include "v1/cursor/accumulate.hpp" +#include "v1/cursor_intersection.hpp" #include "v1/default_index_runner.hpp" +#include "v1/index_builder.hpp" #include "v1/query.hpp" -#include "v1/raw_cursor.hpp" +#include "v1/scorer/bm25.hpp" namespace pisa::v1 { @@ -90,13 +97,10 @@ auto verify_compressed_index(std::string const& input, std::string_view output) std::vector> pair_mapping; auto scores_file_0 = fmt::format("{}.bigram_bm25_0", index_basename); auto scores_file_1 = fmt::format("{}.bigram_bm25_1", index_basename); - // auto compound_scores_file = 
fmt::format("{}.bigram_bm25", index_basename); auto score_offsets_file_0 = fmt::format("{}.bigram_bm25_offsets_0", index_basename); auto score_offsets_file_1 = fmt::format("{}.bigram_bm25_offsets_1", index_basename); - // auto compound_score_offsets_file = fmt::format("{}.bigram_bm25_offsets", index_basename); std::ofstream score_out_0(scores_file_0); std::ofstream score_out_1(scores_file_1); - // std::ofstream compound_score_out(compound_scores_file); run([&](auto&& index) { ProgressStatus status(bigrams.size(), @@ -141,6 +145,87 @@ auto verify_compressed_index(std::string const& input, std::string_view output) PostingFilePaths{scores_file_1, score_offsets_file_1}}; } +template > +struct HeapPriorityQueue { + using value_type = T; + + explicit HeapPriorityQueue(std::size_t capacity, Order order = Order()) + : m_capacity(capacity), m_order(std::move(order)) + { + m_elements.reserve(m_capacity + 1); + } + HeapPriorityQueue(HeapPriorityQueue const&) = default; + HeapPriorityQueue(HeapPriorityQueue&&) noexcept = default; + HeapPriorityQueue& operator=(HeapPriorityQueue const&) = default; + HeapPriorityQueue& operator=(HeapPriorityQueue&&) noexcept = default; + ~HeapPriorityQueue() = default; + + void push(value_type value) + { + m_elements.push_back(value); + if (PISA_UNLIKELY(m_elements.size() <= m_capacity)) { + std::push_heap(m_elements.begin(), m_elements.end(), m_order); + } else { + std::pop_heap(m_elements.begin(), m_elements.end(), m_order); + m_elements.pop_back(); + } + } + + [[nodiscard]] auto size() const noexcept { return m_elements.size(); } + [[nodiscard]] auto capacity() const noexcept { return m_capacity; } + + [[nodiscard]] auto take() && -> std::vector + { + std::sort(m_elements.begin(), m_elements.end(), m_order); + return std::move(m_elements); + } + + private: + std::size_t m_capacity; + std::vector m_elements; + Order m_order; +}; + +[[nodiscard]] auto select_best_bigrams(IndexMetadata const& meta, + std::vector const& queries, + std::size_t 
num_bigrams_to_select) + -> std::vector> +{ + using Bigram = std::pair; + auto order = [](auto&& lhs, auto&& rhs) { return lhs.second > rhs.second; }; + auto top_bigrams = HeapPriorityQueue(num_bigrams_to_select, order); + + auto run = index_runner(meta); + run([&](auto&& index) { + auto bigram_gain = [&](Query const& bigram) -> float { + auto&& term_ids = bigram.get_term_ids(); + runtime_assert(term_ids.size() == 2, "Queries must be of exactly two unique terms"); + auto cursors = index.scored_cursors(term_ids, make_bm25(index)); + auto union_length = cursors[0].size() + cursors[1].size(); + auto intersection_length = + accumulate(intersect(std::move(cursors), + false, + []([[maybe_unused]] auto count, + [[maybe_unused]] auto&& cursor, + [[maybe_unused]] auto idx) { return true; }), + std::size_t{0}, + [](auto count, [[maybe_unused]] auto&& cursor) { return count + 1; }); + return static_cast(bigram.get_probability()) * static_cast(union_length) + / static_cast(intersection_length); + }; + for (auto&& query : queries) { + top_bigrams.push(std::make_pair(&query, bigram_gain(query))); + } + }); + auto top = std::move(top_bigrams).take(); + return ranges::views::transform(top, + [](auto&& elem) { + auto&& term_ids = elem.first->get_term_ids(); + return std::make_pair(term_ids[0], term_ids[1]); + }) + | ranges::to_vector; +} + auto build_bigram_index(IndexMetadata meta, std::vector> const& bigrams) -> IndexMetadata { diff --git a/src/v1/query.cpp b/src/v1/query.cpp index d7fe741ca..c2feeb65a 100644 --- a/src/v1/query.cpp +++ b/src/v1/query.cpp @@ -138,6 +138,9 @@ template if (m_threshold) { (*query)["threshold"] = *m_threshold; } + if (m_probability) { + (*query)["probability"] = *m_probability; + } // TODO(michal) // tl::optional m_selections{}; // int m_k = 1000; @@ -190,6 +193,12 @@ auto Query::threshold(float threshold) -> Query& return *this; } +auto Query::probability(float probability) -> Query& +{ + m_probability = probability; + return *this; +} + auto 
Query::term_ids() const -> tl::optional const&> { return m_term_ids.map( @@ -205,6 +214,7 @@ auto Query::selections() const -> tl::optional return tl::nullopt; } auto Query::threshold() const -> tl::optional { return m_threshold; } +auto Query::probability() const -> tl::optional { return m_probability; } auto Query::raw() const -> tl::optional { if (m_raw_string) { @@ -246,6 +256,14 @@ auto Query::get_threshold() const -> float return *m_threshold; } +auto Query::get_probability() const -> float +{ + if (not m_probability) { + throw std::runtime_error("Probability is not set"); + } + return *m_probability; +} + auto Query::get_raw() const -> std::string const& { if (not m_raw_string) { From 36e8da9452e1294f63f2a58c08117cdc32b1c78f Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 6 Jan 2020 19:48:03 +0000 Subject: [PATCH 43/56] Add cereal library submodule --- .gitmodules | 3 +++ external/cereal | 1 + 2 files changed, 4 insertions(+) create mode 160000 external/cereal diff --git a/.gitmodules b/.gitmodules index 97ddee4b6..51a1937f0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -82,3 +82,6 @@ [submodule "external/json"] path = external/json url = https://github.com/nlohmann/json.git +[submodule "external/cereal"] + path = external/cereal + url = https://github.com/USCiLab/cereal.git diff --git a/external/cereal b/external/cereal new file mode 160000 index 000000000..a5a309531 --- /dev/null +++ b/external/cereal @@ -0,0 +1 @@ +Subproject commit a5a30953125e70b115a28dd76b64adf3c97cc883 From 826a772c7774a9c81452e5773659c64e25c4a0e5 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Tue, 7 Jan 2020 19:19:39 +0000 Subject: [PATCH 44/56] Fixes to selecting pairs for indexing --- include/pisa/v1/query.hpp | 17 ++++++++++++ script/cw09b-est.sh | 4 +-- script/cw09b.sh | 6 ++--- src/v1/index_builder.cpp | 13 +++++---- src/v1/query.cpp | 40 ++++++++++++++++++++++++++++ test/v1/test_v1_index.cpp | 56 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 126 
insertions(+), 10 deletions(-) diff --git a/include/pisa/v1/query.hpp b/include/pisa/v1/query.hpp index 0370b66e8..27f5bb2e2 100644 --- a/include/pisa/v1/query.hpp +++ b/include/pisa/v1/query.hpp @@ -65,6 +65,14 @@ struct Query { explicit Query(std::string query, tl::optional id = tl::nullopt); explicit Query(std::vector term_ids, tl::optional id = tl::nullopt); + template + static auto from_ids(Ids... ids) -> Query + { + std::vector id_vec; + (id_vec.push_back(ids), ...); + return Query(std::move(id_vec)); + } + /// Setters for optional values (or ones with default value). auto term_ids(std::vector term_ids) -> Query&; auto id(std::string) -> Query&; @@ -74,6 +82,15 @@ struct Query { auto threshold(float threshold) -> Query&; auto probability(float probability) -> Query&; + /// Consuming setters. + auto with_term_ids(std::vector term_ids) && -> Query; + auto with_id(std::string) && -> Query; + auto with_k(int k) && -> Query; + auto with_selections(gsl::span const> selections) && -> Query; + auto with_selections(ListSelection selections) && -> Query; + auto with_threshold(float threshold) && -> Query; + auto with_probability(float probability) && -> Query; + /// Non-throwing getters [[nodiscard]] auto term_ids() const -> tl::optional const&>; [[nodiscard]] auto id() const -> tl::optional const&; diff --git a/script/cw09b-est.sh b/script/cw09b-est.sh index e6252643c..e4a8a9d1d 100644 --- a/script/cw09b-est.sh +++ b/script/cw09b-est.sh @@ -5,7 +5,7 @@ FWD="/home/amallia/cw09b/CW09B.fwd" ENCODING="simdbp" # v1 BASENAME="/data/michal/work/v1/cw09b/cw09b-${ENCODING}" THREADS=4 -QUERIES="/home/michal/biscorer/data/queries/05.efficiency_topics.no_dups.1k" +QUERIES="/home/michal/05.clean.shuf.test" #QUERIES="/home/michal/topics.web.51-200.jl" K=1000 OUTPUT_DIR="/data/michal/intersect/cw09b-est" @@ -14,4 +14,4 @@ FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" #THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.web.51-200" 
#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.05.efficiency_topics.no_dups.1k" #THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top15.bm25.05.efficiency_topics.no_dups.1k" -THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.efficiency_topics.no_dups.1k" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.test" diff --git a/script/cw09b.sh b/script/cw09b.sh index 04cfe1b60..c8ab75398 100644 --- a/script/cw09b.sh +++ b/script/cw09b.sh @@ -32,9 +32,9 @@ set -x #${PISA_BIN}/bmscore -i "${BASENAME}.yml" -j ${THREADS} --block-size 128 # Filter out queries witout existing terms. -#paste -d: ${QUERIES} ${THRESHOLDS} \ -# | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ -# | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml > ${FILTERED_QUERIES} +paste -d: ${QUERIES} ${THRESHOLDS} \ + | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ + | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml > ${FILTERED_QUERIES} # This will produce both quantized scores and max scores (both quantized and not). 
${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} diff --git a/src/v1/index_builder.cpp b/src/v1/index_builder.cpp index bc6c068dc..fb7b498e6 100644 --- a/src/v1/index_builder.cpp +++ b/src/v1/index_builder.cpp @@ -122,7 +122,7 @@ auto verify_compressed_index(std::string const& input, std::string_view output) index.scored_cursor(right_term, VoidScorer{})}, std::array{0, 0}, [](auto& payload, auto& cursor, auto list_idx) { - payload[list_idx] = cursor.payload(); + gsl::at(payload, list_idx) = cursor.payload(); return payload; }); if (intersection.empty()) { @@ -163,9 +163,8 @@ struct HeapPriorityQueue { void push(value_type value) { m_elements.push_back(value); - if (PISA_UNLIKELY(m_elements.size() <= m_capacity)) { - std::push_heap(m_elements.begin(), m_elements.end(), m_order); - } else { + std::push_heap(m_elements.begin(), m_elements.end(), m_order); + if (PISA_LIKELY(m_elements.size() > m_capacity)) { std::pop_heap(m_elements.begin(), m_elements.end(), m_order); m_elements.pop_back(); } @@ -210,10 +209,14 @@ struct HeapPriorityQueue { [[maybe_unused]] auto idx) { return true; }), std::size_t{0}, [](auto count, [[maybe_unused]] auto&& cursor) { return count + 1; }); + if (intersection_length == 0) { + return 0.0; + } return static_cast(bigram.get_probability()) * static_cast(union_length) / static_cast(intersection_length); }; for (auto&& query : queries) { + auto&& term_ids = query.get_term_ids(); top_bigrams.push(std::make_pair(&query, bigram_gain(query))); } }); @@ -266,7 +269,7 @@ auto build_bigram_index(IndexMetadata meta, std::vector{0, 0}, [](auto& payload, auto& cursor, auto list_idx) { - payload[list_idx] = cursor.payload(); + gsl::at(payload, list_idx) = cursor.payload(); return payload; }); if (intersection.empty()) { diff --git a/src/v1/query.cpp b/src/v1/query.cpp index c2feeb65a..93c91cba9 100644 --- a/src/v1/query.cpp +++ b/src/v1/query.cpp @@ -106,6 +106,9 @@ template if (auto threshold = get(query_json, "threshold"); 
threshold) { query.threshold(*threshold); } + if (auto probability = get(query_json, "probability"); probability) { + query.probability(*probability); + } if (auto k = get(query_json, "k"); k) { query.k(*k); } @@ -199,6 +202,43 @@ auto Query::probability(float probability) -> Query& return *this; } +auto Query::with_term_ids(std::vector term_ids) && -> Query +{ + this->term_ids(std::move(term_ids)); + return std::move(*this); +} +auto Query::with_id(std::string id) && -> Query +{ + m_id = std::move(id); + return std::move(*this); +} +auto Query::with_k(int k) && -> Query +{ + m_k = k; + return std::move(*this); +} +auto Query::with_selections(gsl::span const> selections) && -> Query +{ + this->selections(selections); + return std::move(*this); +} +auto Query::with_selections(ListSelection selections) && -> Query +{ + this->selections(std::move(selections)); + return std::move(*this); +} +auto Query::with_threshold(float threshold) && -> Query +{ + m_threshold = threshold; + return std::move(*this); +} +auto Query::with_probability(float probability) && -> Query +{ + m_probability = probability; + return std::move(*this); +} + +/// Getters auto Query::term_ids() const -> tl::optional const&> { return m_term_ids.map( diff --git a/test/v1/test_v1_index.cpp b/test/v1/test_v1_index.cpp index 5a53ab484..36727887e 100644 --- a/test/v1/test_v1_index.cpp +++ b/test/v1/test_v1_index.cpp @@ -15,11 +15,13 @@ #include "v1/bit_sequence_cursor.hpp" #include "v1/blocked_cursor.hpp" #include "v1/cursor/collect.hpp" +#include "v1/cursor/for_each.hpp" #include "v1/default_index_runner.hpp" #include "v1/index.hpp" #include "v1/index_builder.hpp" #include "v1/index_metadata.hpp" #include "v1/raw_cursor.hpp" +#include "v1/scorer/bm25.hpp" #include "v1/sequence/partitioned_sequence.hpp" #include "v1/sequence/positive_sequence.hpp" #include "v1/types.hpp" @@ -30,14 +32,17 @@ using pisa::v1::DocId; using pisa::v1::DocumentBitSequenceCursor; using pisa::v1::DocumentBlockedCursor; using 
pisa::v1::DocumentBlockedWriter; +using pisa::v1::for_each; using pisa::v1::Frequency; using pisa::v1::index_runner; using pisa::v1::IndexMetadata; +using pisa::v1::make_bm25; using pisa::v1::PartitionedSequence; using pisa::v1::PayloadBitSequenceCursor; using pisa::v1::PayloadBlockedCursor; using pisa::v1::PayloadBlockedWriter; using pisa::v1::PositiveSequence; +using pisa::v1::Query; using pisa::v1::RawCursor; using pisa::v1::RawWriter; using pisa::v1::TermId; @@ -140,3 +145,54 @@ TEMPLATE_TEST_CASE("Index", } }); } + +TEST_CASE("Select best bigrams", "[v1][integration]") +{ + tbb::task_scheduler_init init; + IndexFixture, RawCursor, RawCursor> fixture( + false, false, false, false); + std::string index_basename = (fixture.tmpdir().path() / "inv").string(); + auto meta = IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); + { + std::vector queries; + auto best_bigrams = select_best_bigrams(meta, queries, 10); + REQUIRE(best_bigrams.empty()); + } + { + std::vector queries = {Query::from_ids(0, 1).with_probability(0.1)}; + auto best_bigrams = select_best_bigrams(meta, queries, 10); + REQUIRE(best_bigrams == std::vector>{{0, 1}}); + } + { + std::vector queries = { + Query::from_ids(0, 1).with_probability(0.2), // u: 3758, i: 808 u/i: 4.650990099009901 + Query::from_ids(1, 2).with_probability(0.2), // u: 3961, i: 734 u/i: 5.3964577656675745 + Query::from_ids(2, 3).with_probability(0.2), // u: 2298, i: 61 u/i: 37.67213114754098 + Query::from_ids(3, 4).with_probability(0.2), // u: 90, i: 3 u/i: 30.0 + Query::from_ids(4, 5).with_probability(0.2), // u: 21, i: 1 u/i: 21.0 + Query::from_ids(5, 6).with_probability(0.2), // u: 8, i: 3 u/i: 2.6666666666666665 + Query::from_ids(6, 7).with_probability(0.2), // u: 4, i: 0 u/i: --- + Query::from_ids(7, 8).with_probability(0.2), // u: 2, i: 1 u/i: 2 + Query::from_ids(8, 9).with_probability(0.2), // u: 2, i: 1 u/i: 2 + Query::from_ids(9, 10).with_probability(0.2) // u: 2, i: 1 u/i: 2 + }; + auto best_bigrams = 
select_best_bigrams(meta, queries, 3); + REQUIRE(best_bigrams == std::vector>{{2, 3}, {3, 4}, {4, 5}}); + } + { + std::vector queries = { + Query::from_ids(0, 1).with_probability(0.2), // u: 3758, i: 808 u/i: 4.650990099009901 + Query::from_ids(1, 2).with_probability(0.2), // u: 3961, i: 734 u/i: 5.3964577656675745 + Query::from_ids(2, 3).with_probability(0.2), // u: 2298, i: 61 u/i: 37.67213114754098 + Query::from_ids(3, 4).with_probability(0.4), // u: 90, i: 3 u/i: 30.0 + Query::from_ids(4, 5).with_probability(0.01), // u: 21, i: 1 u/i: 21.0 + Query::from_ids(5, 6).with_probability(0.2), // u: 8, i: 3 u/i: 2.6666666666666665 + Query::from_ids(6, 7).with_probability(0.2), // u: 4, i: 0 u/i: --- + Query::from_ids(7, 8).with_probability(0.2), // u: 2, i: 1 u/i: 2 + Query::from_ids(8, 9).with_probability(0.2), // u: 2, i: 1 u/i: 2 + Query::from_ids(9, 10).with_probability(0.2) // u: 2, i: 1 u/i: 2 + }; + auto best_bigrams = select_best_bigrams(meta, queries, 3); + REQUIRE(best_bigrams == std::vector>{{3, 4}, {2, 3}, {1, 2}}); + } +} From 2dc73a220205c3e1b902e10ee63c744e3e925070 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Wed, 8 Jan 2020 17:22:04 +0000 Subject: [PATCH 45/56] Support posting stats --- v1/postings.cpp | 114 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 82 insertions(+), 32 deletions(-) diff --git a/v1/postings.cpp b/v1/postings.cpp index 9bfa1d520..f01138c78 100644 --- a/v1/postings.cpp +++ b/v1/postings.cpp @@ -11,6 +11,7 @@ #include "query/queries.hpp" #include "topk_queue.hpp" #include "v1/blocked_cursor.hpp" +#include "v1/cursor/collect.hpp" #include "v1/default_index_runner.hpp" #include "v1/index_metadata.hpp" #include "v1/query.hpp" @@ -22,10 +23,8 @@ using pisa::App; using pisa::Query; using pisa::resolve_query_parser; +using pisa::v1::collect_payloads; using pisa::v1::index_runner; -using pisa::v1::IndexMetadata; -using pisa::v1::RawReader; -using pisa::v1::resolve_yml; namespace arg = pisa::arg; @@ -63,6 +62,67 @@ 
template return val; } +template +void calc_stats(Cursor&& cursor) +{ + using payload_type = std::decay_t; + auto length = cursor.size(); + auto payloads = collect_payloads(cursor); + std::sort(payloads.begin(), payloads.end(), std::greater<>{}); + auto kth = [&](auto k) { + if (k < payloads.size()) { + return payloads[k]; + } + return payload_type{}; + }; + std::cout << fmt::format("{}\t{}\t{}\t{}\t{}\n", length, kth(0), kth(9), kth(99), kth(999)); +} + +template +auto print_postings(Cursor&& cursor, + Scorer&& scorer, + std::optional> const& docmap, + bool did, + bool print_frequencies, + bool print_scores) +{ + auto print = [&](auto&& cursor) { + if (did) { + std::cout << *cursor; + } else { + std::cout << docmap.value()[*cursor]; + } + if (print_frequencies) { + std::cout << " " << cursor.payload(); + } + if (print_scores) { + std::cout << " " << scorer(cursor.value(), cursor.payload()); + } + std::cout << '\n'; + }; + for_each(cursor, print); +}; + +template +auto print_precomputed_postings(Cursor&& cursor, + std::optional> const& docmap, + bool did) +{ + auto print = [&](auto&& cursor) { + if (did) { + std::cout << *cursor; + } else { + std::cout << docmap.value()[*cursor]; + } + if constexpr (std::is_same_v) { + std::cout << " " << static_cast(cursor.payload()) << '\n'; + } else { + std::cout << " " << cursor.payload() << '\n'; + } + }; + for_each(std::forward(cursor), print); +}; + int main(int argc, char** argv) { std::optional terms_file{}; @@ -73,6 +133,7 @@ int main(int argc, char** argv) bool print_frequencies = false; bool print_scores = false; bool precomputed = false; + bool stats = false; App app{"Queries a v1 index."}; app.add_option("--terms", terms_file, "Overrides document lexicon from .yml (if defined)."); @@ -84,6 +145,7 @@ int main(int argc, char** argv) app.add_flag("-f,--frequencies", print_frequencies, "Print frequencies"); auto* scores_option = app.add_flag("-s,--scores", print_scores, "Print BM25 scores"); app.add_flag("--precomputed", 
precomputed, "Use BM25 precomputed scores")->needs(scores_option); + app.add_flag("--stats", stats, "Print stats instead of listing postings"); app.add_option("query", query_input, "List of terms", false)->required(); CLI11_PARSE(app, argc, argv); @@ -120,40 +182,28 @@ int main(int argc, char** argv) if (precomputed) { auto run = scored_index_runner(meta); run([&](auto&& index) { - auto print = [&](auto&& cursor) { - if (did) { - std::cout << *cursor; - } else { - std::cout << docmap.value()[*cursor]; - } - if constexpr (std::is_same_v) { - std::cout << " " << static_cast(cursor.payload()) << '\n'; - } else { - std::cout << " " << cursor.payload() << '\n'; - } - }; - for_each(index.cursor(query.terms.front()), print); + auto cursor = index.cursor(query.terms.front()); + if (stats) { + calc_stats(cursor); + } else { + print_precomputed_postings(cursor, docmap, did); + } }); } else { auto run = index_runner(meta); run([&](auto&& index) { auto bm25 = make_bm25(index); - auto scorer = bm25.term_scorer(query.terms.front()); - auto print = [&](auto&& cursor) { - if (did) { - std::cout << *cursor; - } else { - std::cout << docmap.value()[*cursor]; - } - if (print_frequencies) { - std::cout << " " << cursor.payload(); - } - if (print_scores) { - std::cout << " " << scorer(cursor.value(), cursor.payload()); - } - std::cout << '\n'; - }; - for_each(index.cursor(query.terms.front()), print); + if (stats) { + calc_stats(index.scored_cursor(query.terms.front(), bm25)); + } else { + auto scorer = bm25.term_scorer(query.terms.front()); + print_postings(index.cursor(query.terms.front()), + scorer, + docmap, + did, + print_frequencies, + print_scores); + } }); } } else { From 087d51bd19b0bb3d603cb0c035f245e8cddc146b Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 10 Jan 2020 22:22:14 +0000 Subject: [PATCH 46/56] Selecting term-pairs and refactoring --- .clang-tidy | 3 +- include/pisa/v1/bit_sequence_cursor.hpp | 21 ++--- include/pisa/v1/cursor/accumulate.hpp | 2 + 
include/pisa/v1/cursor_accumulator.hpp | 2 +- include/pisa/v1/index_builder.hpp | 26 ++++++ include/pisa/v1/maxscore.hpp | 4 +- include/pisa/v1/runtime_assert.hpp | 42 ++++++++-- include/pisa/v1/union_lookup.hpp | 6 +- include/pisa/v1/wand.hpp | 4 +- script/cw09b-est-val.sh | 12 +++ script/cw09b-est.sh | 2 +- script/cw09b.sh | 68 ++++++++-------- src/CMakeLists.txt | 12 +-- src/v1/index_builder.cpp | 3 +- test/v1/test_v1_maxscore_join.cpp | 2 +- v1/CMakeLists.txt | 3 + v1/app.hpp | 49 +++++++++++- v1/filter_queries.cpp | 3 +- v1/intersection.cpp | 12 +-- v1/postings.cpp | 102 ++++++++++++++++++++++-- v1/select_pairs.cpp | 41 ++++++++++ 21 files changed, 331 insertions(+), 88 deletions(-) create mode 100644 script/cw09b-est-val.sh create mode 100644 v1/select_pairs.cpp diff --git a/.clang-tidy b/.clang-tidy index df6d7beae..834809eb1 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -13,4 +13,5 @@ Checks: | -cppcoreguidelines-avoid-magic-numbers, -cppcoreguidelines-pro-bounds-array-to-pointer-decay, -modernize-use-trailing-return-type, - -misc-non-private-member-variables-in-classes + -misc-non-private-member-variables-in-classes, + -readability-magic-numbers diff --git a/include/pisa/v1/bit_sequence_cursor.hpp b/include/pisa/v1/bit_sequence_cursor.hpp index d2a52efac..26f3f43a6 100644 --- a/include/pisa/v1/bit_sequence_cursor.hpp +++ b/include/pisa/v1/bit_sequence_cursor.hpp @@ -106,7 +106,7 @@ struct BitSequenceReader { [[nodiscard]] auto read(gsl::span bytes) const -> Cursor { - runtime_assert(bytes.size() % sizeof(BitVector::storage_type) == 0, [&]() { + runtime_assert(bytes.size() % sizeof(BitVector::storage_type) == 0).or_throw([&]() { return fmt::format( "Attempted to read no. bytes ({}) not aligned with the storage type of size {}", bytes.size(), @@ -182,9 +182,9 @@ struct BitSequenceWriter { template [[nodiscard]] auto write(std::basic_ostream& os) const -> std::size_t { - runtime_assert(m_num_documents.has_value(), - "Uninitialized writer. 
Must call `init()` before writing."); - runtime_assert(!m_postings.empty(), "Tried to write an empty posting list"); + runtime_assert(m_num_documents.has_value()) + .or_throw("Uninitialized writer. Must call `init()` before writing."); + runtime_assert(!m_postings.empty()).or_throw("Tried to write an empty posting list"); bit_vector_builder builder; auto universe = [&]() { if constexpr (DocumentWriter) { @@ -205,12 +205,13 @@ struct BitSequenceWriter { auto memory = gsl::as_bytes(gsl::make_span(data.data(), data.size())); os.write(reinterpret_cast(memory.data()), memory.size()); auto bytes_written = sizeof(true_bit_length) + memory.size(); - runtime_assert(bytes_written % sizeof(typename BitVector::storage_type) == 0, [&]() { - return fmt::format( - "Bytes written ({}) are not aligned with the storage type of size {}", - bytes_written, - sizeof(typename BitVector::storage_type)); - }); + runtime_assert(bytes_written % sizeof(typename BitVector::storage_type) == 0) + .or_throw([&]() { + return fmt::format( + "Bytes written ({}) are not aligned with the storage type of size {}", + bytes_written, + sizeof(typename BitVector::storage_type)); + }); return bytes_written; } diff --git a/include/pisa/v1/cursor/accumulate.hpp b/include/pisa/v1/cursor/accumulate.hpp index 94b1d6b56..9a5f58680 100644 --- a/include/pisa/v1/cursor/accumulate.hpp +++ b/include/pisa/v1/cursor/accumulate.hpp @@ -1,3 +1,5 @@ +#pragma once + #include "v1/cursor/for_each.hpp" namespace pisa::v1 { diff --git a/include/pisa/v1/cursor_accumulator.hpp b/include/pisa/v1/cursor_accumulator.hpp index bde0ba447..f12b76374 100644 --- a/include/pisa/v1/cursor_accumulator.hpp +++ b/include/pisa/v1/cursor_accumulator.hpp @@ -3,7 +3,7 @@ #include #include -namespace pisa::v1::accumulate { +namespace pisa::v1::accumulators { struct Add { template diff --git a/include/pisa/v1/index_builder.hpp b/include/pisa/v1/index_builder.hpp index 09dea43b6..b53069eb0 100644 --- a/include/pisa/v1/index_builder.hpp +++ 
b/include/pisa/v1/index_builder.hpp @@ -9,10 +9,14 @@ #include #include "binary_freq_collection.hpp" +#include "v1/cursor/accumulate.hpp" +#include "v1/cursor_intersection.hpp" #include "v1/index.hpp" #include "v1/index_metadata.hpp" #include "v1/progress_status.hpp" #include "v1/query.hpp" +#include "v1/runtime_assert.hpp" +#include "v1/scorer/bm25.hpp" namespace pisa::v1 { @@ -231,6 +235,28 @@ inline void compress_binary_collection(std::string const& input, .write(fmt::format("{}.yml", output)); } +template +auto bigram_gain(Index&& index, Query const& bigram) -> float +{ + auto&& term_ids = bigram.get_term_ids(); + runtime_assert(term_ids.size() == 2).or_throw("Queries must be of exactly two unique terms"); + auto cursors = index.scored_cursors(term_ids, make_bm25(index)); + auto union_length = cursors[0].size() + cursors[1].size(); + auto intersection_length = + accumulate(intersect(std::move(cursors), + false, + []([[maybe_unused]] auto count, + [[maybe_unused]] auto&& cursor, + [[maybe_unused]] auto idx) { return true; }), + std::size_t{0}, + [](auto count, [[maybe_unused]] auto&& cursor) { return count + 1; }); + if (intersection_length == 0) { + return 0.0; + } + return static_cast(bigram.get_probability()) * static_cast(union_length) + / static_cast(intersection_length); +} + auto verify_compressed_index(std::string const& input, std::string_view output) -> std::vector; diff --git a/include/pisa/v1/maxscore.hpp b/include/pisa/v1/maxscore.hpp index 2552b97e8..dc1fecf77 100644 --- a/include/pisa/v1/maxscore.hpp +++ b/include/pisa/v1/maxscore.hpp @@ -208,7 +208,7 @@ auto maxscore(Query const& query, Index const& index, topk_queue topk, Scorer&& if (query.threshold()) { topk.set_threshold(*query.threshold()); } - auto joined = join_maxscore(std::move(cursors), 0.0F, accumulate::Add{}, [&](auto score) { + auto joined = join_maxscore(std::move(cursors), 0.0F, accumulators::Add{}, [&](auto score) { return topk.would_enter(score); }); v1::for_each(joined, 
[&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); @@ -249,7 +249,7 @@ struct MaxscoreInspector { auto joined = join_maxscore( std::move(cursors), 0.0F, - accumulate::Add{}, + accumulators::Add{}, [&](auto score) { return topk.would_enter(score); }, this); v1::for_each(joined, [&](auto& cursor) { diff --git a/include/pisa/v1/runtime_assert.hpp b/include/pisa/v1/runtime_assert.hpp index 82635289a..736342334 100644 --- a/include/pisa/v1/runtime_assert.hpp +++ b/include/pisa/v1/runtime_assert.hpp @@ -1,20 +1,46 @@ #pragma once +#include + #include #include namespace pisa::v1 { -template -inline void runtime_assert(bool condition, Message&& message) -{ - if (not condition) { - if constexpr (std::is_invocable_r_v) { - throw std::runtime_error(message()); - } else { - throw std::runtime_error(std::forward(message)); +struct RuntimeAssert { + explicit RuntimeAssert(bool condition) : passes_(condition) {} + + template + void or_exit(Message&& message) + { + if (not passes_) { + if constexpr (std::is_invocable_r_v) { + spdlog::error(message()); + } else { + spdlog::error("{}", message); + } } } + + template + void or_throw(Message&& message) + { + if (not passes_) { + if constexpr (std::is_invocable_r_v) { + throw std::runtime_error(message()); + } else { + throw std::runtime_error(std::forward(message)); + } + } + } + + private: + bool passes_; +}; + +[[nodiscard]] inline auto runtime_assert(bool condition) -> RuntimeAssert +{ + return RuntimeAssert(condition); } } // namespace pisa::v1 diff --git a/include/pisa/v1/union_lookup.hpp b/include/pisa/v1/union_lookup.hpp index 5d0bd841c..6cdbf414e 100644 --- a/include/pisa/v1/union_lookup.hpp +++ b/include/pisa/v1/union_lookup.hpp @@ -267,7 +267,7 @@ auto unigram_union_lookup(Query const& query, std::move(essential_cursors), std::move(lookup_cursors), payload_type{}, - accumulate::Add{}, + accumulators::Add{}, [&](auto score) { return topk.would_enter(score); }, inspect); v1::for_each(joined, [&](auto&& 
cursor) { @@ -330,7 +330,7 @@ auto maxscore_union_lookup(Query const& query, std::move(essential_cursors), std::move(cursors), payload_type{}, - accumulate::Add{}, + accumulators::Add{}, [&](auto score) { return topk.would_enter(score); }, inspect); v1::for_each(joined, [&](auto&& cursor) { @@ -441,7 +441,7 @@ auto lookup_union(Query const& query, return join_union_lookup(std::move(essential_cursors), std::move(lookup_cursors), 0.0F, - accumulate::Add{}, + accumulators::Add{}, is_above_threshold, inspect); }(); diff --git a/include/pisa/v1/wand.hpp b/include/pisa/v1/wand.hpp index 533b8ff28..bbee82b56 100644 --- a/include/pisa/v1/wand.hpp +++ b/include/pisa/v1/wand.hpp @@ -450,7 +450,7 @@ auto wand(Query const& query, Index const& index, topk_queue topk, Scorer&& scor if (query.threshold()) { topk.set_threshold(*query.threshold()); } - auto joined = join_wand(std::move(cursors), 0.0F, accumulate::Add{}, [&](auto score) { + auto joined = join_wand(std::move(cursors), 0.0F, accumulators::Add{}, [&](auto score) { return topk.would_enter(score); }); v1::for_each(joined, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); @@ -468,7 +468,7 @@ auto bmw(Query const& query, Index const& index, topk_queue topk, Scorer&& score if (query.threshold()) { topk.set_threshold(*query.threshold()); } - auto joined = join_block_max_wand(std::move(cursors), 0.0F, accumulate::Add{}, [&](auto score) { + auto joined = join_block_max_wand(std::move(cursors), 0.0F, accumulators::Add{}, [&](auto score) { return topk.would_enter(score); }); v1::for_each(joined, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); diff --git a/script/cw09b-est-val.sh b/script/cw09b-est-val.sh new file mode 100644 index 000000000..4b7868291 --- /dev/null +++ b/script/cw09b-est-val.sh @@ -0,0 +1,12 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/home/amallia/cw09b/CW09B.inv" 
+FWD="/home/amallia/cw09b/CW09B.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.val" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-est-val" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.val" diff --git a/script/cw09b-est.sh b/script/cw09b-est.sh index e4a8a9d1d..8cfaa4886 100644 --- a/script/cw09b-est.sh +++ b/script/cw09b-est.sh @@ -8,7 +8,7 @@ THREADS=4 QUERIES="/home/michal/05.clean.shuf.test" #QUERIES="/home/michal/topics.web.51-200.jl" K=1000 -OUTPUT_DIR="/data/michal/intersect/cw09b-est" +OUTPUT_DIR="/data/michal/intersect/cw09b-est-lm" #OUTPUT_DIR="/data/michal/intersect/cw09b-est-top20" FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" #THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.web.51-200" diff --git a/script/cw09b.sh b/script/cw09b.sh index c8ab75398..35eb77d6c 100644 --- a/script/cw09b.sh +++ b/script/cw09b.sh @@ -21,6 +21,7 @@ echo " THRESHOLDS = ${THRESHOLDS}" echo "" set -x +mkdir -p ${OUTPUT_DIR} ## Compress an inverted index in `binary_freq_collection` format. 
#${PISA_BIN}/compress -i ${BINARY_FREQ_COLL} --fwd ${FWD} -o ${BASENAME} -j ${THREADS} -e ${ENCODING} @@ -58,51 +59,50 @@ ${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl \ #${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm maxscore --safe > ${OUTPUT_DIR}/bench.maxscore #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm bmw --safe \ # > ${OUTPUT_DIR}/bench.bmw-threshold -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore --safe \ -# > ${OUTPUT_DIR}/bench.maxscore-threshold +${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore --safe \ + > ${OUTPUT_DIR}/bench.maxscore-threshold #${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup --safe \ # > ${OUTPUT_DIR}/bench.maxscore-union-lookup ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup --safe \ > ${OUTPUT_DIR}/bench.unigram-union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup --safe \ - > ${OUTPUT_DIR}/bench.union-lookup +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup --safe \ +# > ${OUTPUT_DIR}/bench.union-lookup ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm lookup-union --safe \ > ${OUTPUT_DIR}/bench.lookup-union -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --safe \ - --benchmark --algorithm lookup-union \ - > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ - --benchmark --algorithm lookup-union \ - > ${OUTPUT_DIR}/bench.lookup-union.scaled-2 +#${PISA_BIN}/query -i "${BASENAME}.yml" -q 
${OUTPUT_DIR}/selections.2.scaled-1.5 --safe \ +# --benchmark --algorithm lookup-union \ +# > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ +# --benchmark --algorithm lookup-union \ +# > ${OUTPUT_DIR}/bench.lookup-union.scaled-2 # Analyze -${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore \ - > ${OUTPUT_DIR}/stats.maxscore-threshold -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore-union-lookup \ - > ${OUTPUT_DIR}/stats.maxscore-union-lookup +# ${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore +# ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore \ +# > ${OUTPUT_DIR}/stats.maxscore-threshold +# ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore-union-lookup \ +# > ${OUTPUT_DIR}/stats.maxscore-union-lookup ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --inspect --algorithm unigram-union-lookup \ > ${OUTPUT_DIR}/stats.unigram-union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup \ - > ${OUTPUT_DIR}/stats.union-lookup +# ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup \ +# > ${OUTPUT_DIR}/stats.union-lookup ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm lookup-union \ > ${OUTPUT_DIR}/stats.lookup-union -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ - --inspect --algorithm lookup-union \ - > ${OUTPUT_DIR}/stats.lookup-union.scaled-1.5 +# 
${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ +# --inspect --algorithm lookup-union \ +# > ${OUTPUT_DIR}/stats.lookup-union.scaled-1.5 # Evaluate -${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore \ - > "${OUTPUT_DIR}/eval.maxscore-threshold" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore-union-lookup \ - > "${OUTPUT_DIR}/eval.maxscore-union-lookup" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --algorithm unigram-union-lookup \ - > "${OUTPUT_DIR}/eval.unigram-union-lookup" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm union-lookup \ - > "${OUTPUT_DIR}/eval.union-lookup" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm lookup-union \ - > "${OUTPUT_DIR}/eval.lookup-union" -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --algorithm lookup-union \ - > "${OUTPUT_DIR}/eval.lookup-union.scaled-1.5" - +#${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore \ +# > "${OUTPUT_DIR}/eval.maxscore-threshold" +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore-union-lookup \ +# > "${OUTPUT_DIR}/eval.maxscore-union-lookup" +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --algorithm unigram-union-lookup \ +# > "${OUTPUT_DIR}/eval.unigram-union-lookup" +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm union-lookup \ +# > "${OUTPUT_DIR}/eval.union-lookup" +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm lookup-union \ +# > 
"${OUTPUT_DIR}/eval.lookup-union" +#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --algorithm lookup-union \ +# > "${OUTPUT_DIR}/eval.lookup-union.scaled-1.5" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c1cdcacd8..88ca50602 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -95,12 +95,12 @@ target_link_libraries(compute_intersection CLI11 ) -#add_executable(lexicon lexicon.cpp) -#target_link_libraries(lexicon -# pisa -# CLI11 -#) -# +add_executable(lexicon lexicon.cpp) +target_link_libraries(lexicon + pisa + CLI11 +) + #add_executable(extract_topics extract_topics.cpp) #target_link_libraries(extract_topics # pisa diff --git a/src/v1/index_builder.cpp b/src/v1/index_builder.cpp index fb7b498e6..21c342061 100644 --- a/src/v1/index_builder.cpp +++ b/src/v1/index_builder.cpp @@ -198,7 +198,8 @@ struct HeapPriorityQueue { run([&](auto&& index) { auto bigram_gain = [&](Query const& bigram) -> float { auto&& term_ids = bigram.get_term_ids(); - runtime_assert(term_ids.size() == 2, "Queries must be of exactly two unique terms"); + runtime_assert(term_ids.size() == 2) + .or_throw("Queries must be of exactly two unique terms"); auto cursors = index.scored_cursors(term_ids, make_bm25(index)); auto union_length = cursors[0].size() + cursors[1].size(); auto intersection_length = diff --git a/test/v1/test_v1_maxscore_join.cpp b/test/v1/test_v1_maxscore_join.cpp index 45978b951..086e8ee1d 100644 --- a/test/v1/test_v1_maxscore_join.cpp +++ b/test/v1/test_v1_maxscore_join.cpp @@ -28,7 +28,7 @@ using pisa::v1::collect; using pisa::v1::DocId; using pisa::v1::Frequency; using pisa::v1::join_maxscore; -using pisa::v1::accumulate::Add; +using pisa::v1::accumulators::Add; TEMPLATE_TEST_CASE("Max score join", "[v1][integration]", diff --git a/v1/CMakeLists.txt b/v1/CMakeLists.txt index 7e5d0f829..8d62026f8 100644 --- a/v1/CMakeLists.txt +++ b/v1/CMakeLists.txt @@ -24,3 +24,6 @@ target_link_libraries(threshold pisa CLI11) 
add_executable(intersection intersection.cpp) target_link_libraries(intersection pisa CLI11) + +add_executable(select-pairs select_pairs.cpp) +target_link_libraries(select-pairs pisa CLI11) diff --git a/v1/app.hpp b/v1/app.hpp index 6a3d251a1..958ea2b36 100644 --- a/v1/app.hpp +++ b/v1/app.hpp @@ -5,10 +5,13 @@ #include #include +#include +#include #include #include "io.hpp" #include "v1/index_metadata.hpp" +#include "v1/runtime_assert.hpp" namespace pisa { @@ -44,6 +47,9 @@ namespace arg { if constexpr (Mode == QueryMode::Ranked) { app->add_option("-k", m_k, "The number of top results to return", true); } + app->add_flag("--force-parse", + m_force_parse, + "Force parsing of query string even ifterm IDs already available"); } [[nodiscard]] auto query_file() -> tl::optional @@ -65,7 +71,7 @@ namespace arg { } return v1::Query::from_plain(line); }(); - if (not query.term_ids()) { + if (not query.term_ids() || m_force_parse) { query.parse(parser); } if constexpr (Mode == QueryMode::Ranked) { @@ -82,10 +88,49 @@ namespace arg { return queries; } + [[nodiscard]] auto query_range(v1::IndexMetadata const& meta) + { + auto lines = [&] { + if (m_query_file) { + m_query_file_handle = std::make_unique(*m_query_file); + return ranges::getlines(*m_query_file_handle); + } + return ranges::getlines(std::cin); + }(); + return ranges::views::transform(lines, + [force_parse = m_force_parse, + k = m_k, + parser = meta.query_parser(), + qfmt = m_query_input_format](auto&& line) { + auto query = [&]() { + if (qfmt == "jl") { + return v1::Query::from_json(line); + } + if (qfmt == "plain") { + return v1::Query::from_plain(line); + } + spdlog::error("Unknown query format: {}", + qfmt); + std::exit(1); + }(); + if (not query.term_ids() || force_parse) { + query.parse(parser); + } + // Not constexpr to silence unused k value warning. + // Performance is not a concern. 
+ if (Mode == QueryMode::Ranked) { + query.k(k); + } + return query; + }); + } + private: + std::unique_ptr m_query_file_handle = nullptr; tl::optional m_query_file; - tl::optional m_query_input_format = "jl"; + std::string m_query_input_format = "jl"; int m_k = DefaultK; + bool m_force_parse{false}; }; struct Benchmark { diff --git a/v1/filter_queries.cpp b/v1/filter_queries.cpp index ce17f6e07..627209208 100644 --- a/v1/filter_queries.cpp +++ b/v1/filter_queries.cpp @@ -19,8 +19,7 @@ int main(int argc, char** argv) CLI11_PARSE(app, argc, argv); auto meta = app.index_metadata(); - auto queries = app.queries(meta); - for (auto&& query : queries) { + for (auto&& query : app.query_range(meta)) { if (not query.get_term_ids().empty()) { std::cout << *query.to_json() << '\n'; } diff --git a/v1/intersection.cpp b/v1/intersection.cpp index b5b825b04..7f9ce0287 100644 --- a/v1/intersection.cpp +++ b/v1/intersection.cpp @@ -69,9 +69,9 @@ auto for_all_subsets(v1::Query const& query, tl::optional max_term_ } } -template +template void compute_intersections(Index const& index, - std::vector const& queries, + QRng queries, IntersectionType intersection_type, tl::optional max_term_count, bool existing) @@ -96,7 +96,7 @@ void compute_intersections(Index const& index, intersections.push_back(nlohmann::json{{"intersection", left_mask}, {"cost", cursor.size()}, {"max_score", cursor.max_score()}}); - std::uint64_t right_mask = left_mask << 1; + std::uint64_t right_mask = left_mask << 1U; for (auto right = left + 1; right < term_ids.size(); right += 1) { index .scored_bigram_cursor(term_ids[left], term_ids[right], make_bm25(index)) @@ -115,9 +115,9 @@ void compute_intersections(Index const& index, {"cost", cost}, {"max_score", max_score}}); }); - right_mask <<= 1; + right_mask <<= 1U; } - left_mask <<= 1; + left_mask <<= 1U; } } else { for_all_subsets(query, max_term_count, inter); @@ -159,7 +159,7 @@ int main(int argc, const char** argv) try { auto meta = app.index_metadata(); - auto 
queries = app.queries(meta); + auto queries = app.query_range(meta); auto run = index_runner(meta); run([&](auto&& index) { diff --git a/v1/postings.cpp b/v1/postings.cpp index f01138c78..221d65162 100644 --- a/v1/postings.cpp +++ b/v1/postings.cpp @@ -12,10 +12,13 @@ #include "topk_queue.hpp" #include "v1/blocked_cursor.hpp" #include "v1/cursor/collect.hpp" +#include "v1/cursor_accumulator.hpp" +#include "v1/cursor_intersection.hpp" #include "v1/default_index_runner.hpp" #include "v1/index_metadata.hpp" #include "v1/query.hpp" #include "v1/raw_cursor.hpp" +#include "v1/runtime_assert.hpp" #include "v1/scorer/bm25.hpp" #include "v1/scorer/runner.hpp" #include "v1/types.hpp" @@ -25,6 +28,7 @@ using pisa::Query; using pisa::resolve_query_parser; using pisa::v1::collect_payloads; using pisa::v1::index_runner; +using pisa::v1::runtime_assert; namespace arg = pisa::arg; @@ -62,12 +66,26 @@ template return val; } +void print_header(std::vector const& percentiles, std::vector const& cutoffs) +{ + std::cout << "length"; + for (auto cutoff : cutoffs) { + std::cout << '\t' << "top-" << cutoff + 1; + } + for (auto percentile : percentiles) { + std::cout << '\t' << "perc-" << percentile; + } + std::cout << '\n'; +} + template -void calc_stats(Cursor&& cursor) +void calc_stats(Cursor&& cursor, + std::vector const& percentiles, + std::vector const& cutoffs) { using payload_type = std::decay_t; - auto length = cursor.size(); auto payloads = collect_payloads(cursor); + auto length = payloads.size(); std::sort(payloads.begin(), payloads.end(), std::greater<>{}); auto kth = [&](auto k) { if (k < payloads.size()) { @@ -75,7 +93,15 @@ void calc_stats(Cursor&& cursor) } return payload_type{}; }; - std::cout << fmt::format("{}\t{}\t{}\t{}\t{}\n", length, kth(0), kth(9), kth(99), kth(999)); + std::cout << length; + for (auto cutoff : cutoffs) { + std::cout << '\t' << kth(cutoff); + } + for (auto percentile : percentiles) { + std::cout << '\t' + << payloads[std::min(percentile * 
payloads.size() / 100, payloads.size() - 1)]; + } + std::cout << '\n'; } template @@ -125,6 +151,9 @@ auto print_precomputed_postings(Cursor&& cursor, int main(int argc, char** argv) { + std::vector percentiles = {0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}; + std::vector cutoffs = { + 0, 9, 99, 999, 9'999, 99'999, 999'999, 9'999'999, 99'999'999}; std::optional terms_file{}; std::optional documents_file{}; std::string query_input{}; @@ -134,6 +163,7 @@ int main(int argc, char** argv) bool print_scores = false; bool precomputed = false; bool stats = false; + bool header = false; App app{"Queries a v1 index."}; app.add_option("--terms", terms_file, "Overrides document lexicon from .yml (if defined)."); @@ -145,7 +175,13 @@ int main(int argc, char** argv) app.add_flag("-f,--frequencies", print_frequencies, "Print frequencies"); auto* scores_option = app.add_flag("-s,--scores", print_scores, "Print BM25 scores"); app.add_flag("--precomputed", precomputed, "Use BM25 precomputed scores")->needs(scores_option); - app.add_flag("--stats", stats, "Print stats instead of listing postings"); + + auto* stats_option = app.add_flag("--stats", stats, "Print stats instead of listing postings"); + app.add_option("--percentiles", percentiles, "Percentiles for stats", true) + ->needs(stats_option); + app.add_option("--cutoffs", cutoffs, "Cut-offs for stats", true)->needs(stats_option); + app.add_flag("--header", header, "Print stats header")->needs(stats_option); + app.add_option("query", query_input, "List of terms", false)->required(); CLI11_PARSE(app, argc, argv); @@ -178,13 +214,24 @@ int main(int argc, char** argv) return queries[0]; }(); + if (header) { + print_header(percentiles, cutoffs); + } + + if (stats) { + for (auto percentile : percentiles) { + pisa::v1::runtime_assert(percentile >= 0 && percentile <= 100) + .or_exit("Percentiles must be in [0, 100]"); + } + } + if (query.terms.size() == 1) { if (precomputed) { auto run = scored_index_runner(meta); run([&](auto&& index) 
{ auto cursor = index.cursor(query.terms.front()); if (stats) { - calc_stats(cursor); + calc_stats(cursor, percentiles, cutoffs); } else { print_precomputed_postings(cursor, docmap, did); } @@ -194,7 +241,8 @@ int main(int argc, char** argv) run([&](auto&& index) { auto bm25 = make_bm25(index); if (stats) { - calc_stats(index.scored_cursor(query.terms.front(), bm25)); + calc_stats( + index.scored_cursor(query.terms.front(), bm25), percentiles, cutoffs); } else { auto scorer = bm25.term_scorer(query.terms.front()); print_postings(index.cursor(query.terms.front()), @@ -207,8 +255,46 @@ int main(int argc, char** argv) }); } } else { - std::cerr << "Multiple terms unimplemented"; - std::exit(1); + if (precomputed) { + auto run = scored_index_runner(meta); + run([&](auto&& index) { + auto cursor = + ::pisa::v1::intersect(index.cursors(gsl::make_span(query.terms)), + 0.0, + [](auto acc, auto&& cursor, [[maybe_unused]] auto idx) { + return acc + cursor.payload(); + }); + if (stats) { + calc_stats(cursor, percentiles, cutoffs); + } else { + print_precomputed_postings(cursor, docmap, did); + } + }); + } else { + auto run = index_runner(meta); + run([&](auto&& index) { + auto bm25 = make_bm25(index); + if (stats) { + auto cursor = ::pisa::v1::intersect( + index.scored_cursors(gsl::make_span(query.terms), bm25), + 0.0, + [](auto acc, auto&& cursor, [[maybe_unused]] auto idx) { + return acc + cursor.payload(); + }); + calc_stats(cursor, percentiles, cutoffs); + } else { + runtime_assert(query.terms.size() == 1) + .or_exit("Printing scoring intersections not supported yet."); + auto scorer = bm25.term_scorer(query.terms.front()); + print_postings(index.cursor(query.terms.front()), + scorer, + docmap, + did, + print_frequencies, + print_scores); + } + }); + } } return 0; diff --git a/v1/select_pairs.cpp b/v1/select_pairs.cpp new file mode 100644 index 000000000..9db65b0f9 --- /dev/null +++ b/v1/select_pairs.cpp @@ -0,0 +1,41 @@ +#include +#include + +#include +#include +#include 
+ +#include "app.hpp" +#include "v1/default_index_runner.hpp" +#include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/progress_status.hpp" +#include "v1/types.hpp" + +using pisa::App; +using pisa::v1::bigram_gain; +using pisa::v1::index_runner; + +namespace arg = pisa::arg; + +int main(int argc, char** argv) +{ + std::optional terms_file{}; + //std::size_t num_pairs_to_select; + + App> app{"Creates a v1 bigram index."}; + // app.add_option("--count", num_pairs_to_select, "Number of pairs to select")->required(); + CLI11_PARSE(app, argc, argv); + + auto meta = app.index_metadata(); + + auto run = index_runner(meta); + run([&](auto&& index) { + for (auto&& query : app.query_range(meta)) { + auto term_ids = query.get_term_ids(); + std::cout + << fmt::format("{}\t{}\t{}\n", term_ids[0], term_ids[1], bigram_gain(index, query)); + } + }); + return 0; +} From 83d41206b3f14c78be01371147a7d5f31967e592 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 10 Jan 2020 22:23:19 +0000 Subject: [PATCH 47/56] Update gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 077c235de..1f0bdf6f1 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,4 @@ docs/_build/ node_modules .clangd/ +compile_commands.json From 50078563c75deb734fd6a8c378c8c79069b80cce Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Wed, 15 Jan 2020 17:28:56 +0000 Subject: [PATCH 48/56] Multi-threaded pair index building --- CMakeLists.txt | 3 +- include/pisa/v1/base_index.hpp | 2 + include/pisa/v1/bit_sequence_cursor.hpp | 2 + include/pisa/v1/blocked_cursor.hpp | 2 + include/pisa/v1/index_builder.hpp | 24 +- include/pisa/v1/index_metadata.hpp | 2 + include/pisa/v1/io.hpp | 29 +- include/pisa/v1/posting_builder.hpp | 15 +- include/pisa/v1/progress_status.hpp | 108 ++++-- include/pisa/v1/raw_cursor.hpp | 1 + include/pisa/v1/score_index.hpp | 13 +- script/cw09b-est.sh | 8 +- script/cw09b.sh | 62 ++-- src/CMakeLists.txt | 20 +- 
src/v1/index.cpp | 13 + src/v1/index_builder.cpp | 418 ++++++++++++++++++++---- src/v1/index_metadata.cpp | 17 + src/v1/progress_status.cpp | 65 ++-- src/v1/score_index.cpp | 57 ++-- test/v1/index_fixture.hpp | 6 +- test/v1/test_v1_bigram_index.cpp | 74 ++++- test/v1/test_v1_index.cpp | 48 +++ test/v1/test_v1_queries.cpp | 11 +- v1/CMakeLists.txt | 6 + v1/bigram_index.cpp | 17 +- v1/bmscore.cpp | 20 +- v1/count_postings.cpp | 51 +++ v1/filter_queries.cpp | 5 +- v1/intersection.cpp | 4 + v1/stats.cpp | 29 ++ 30 files changed, 906 insertions(+), 226 deletions(-) create mode 100644 v1/count_postings.cpp create mode 100644 v1/stats.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 0f5d343bb..b72fa3fb3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -56,7 +56,8 @@ if (UNIX) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -Wno-odr") # Extensive warnings - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wno-missing-braces") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wno-missing-braces") + #set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wno-missing-braces") if (USE_SANITIZERS) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer") diff --git a/include/pisa/v1/base_index.hpp b/include/pisa/v1/base_index.hpp index 1a520fcfb..c41bace75 100644 --- a/include/pisa/v1/base_index.hpp +++ b/include/pisa/v1/base_index.hpp @@ -69,10 +69,12 @@ struct BaseIndex { [[nodiscard]] auto num_terms() const -> std::size_t; [[nodiscard]] auto num_documents() const -> std::size_t; + [[nodiscard]] auto num_pairs() const -> std::size_t; [[nodiscard]] auto document_length(DocId docid) const -> std::uint32_t; [[nodiscard]] auto avg_document_length() const -> float; [[nodiscard]] auto normalized_document_length(DocId docid) const -> float; [[nodiscard]] auto bigram_id(TermId left_term, TermId right_term) const -> tl::optional; + [[nodiscard]] auto pairs() const -> tl::optional const>>; protected: void assert_term_in_bounds(TermId term) 
const; diff --git a/include/pisa/v1/bit_sequence_cursor.hpp b/include/pisa/v1/bit_sequence_cursor.hpp index 26f3f43a6..d9895a3e5 100644 --- a/include/pisa/v1/bit_sequence_cursor.hpp +++ b/include/pisa/v1/bit_sequence_cursor.hpp @@ -235,12 +235,14 @@ using PayloadBitSequenceWriter = BitSequenceWriter; template struct CursorTraits> { + using Value = std::uint32_t; using Writer = DocumentBitSequenceWriter; using Reader = DocumentBitSequenceReader; }; template struct CursorTraits> { + using Value = std::uint32_t; using Writer = PayloadBitSequenceWriter; using Reader = PayloadBitSequenceReader; }; diff --git a/include/pisa/v1/blocked_cursor.hpp b/include/pisa/v1/blocked_cursor.hpp index 5d9d90d23..37b60d75c 100644 --- a/include/pisa/v1/blocked_cursor.hpp +++ b/include/pisa/v1/blocked_cursor.hpp @@ -399,12 +399,14 @@ using PayloadBlockedWriter = GenericBlockedWriter; template struct CursorTraits> { + using Value = std::uint32_t; using Writer = DocumentBlockedWriter; using Reader = DocumentBlockedReader; }; template struct CursorTraits> { + using Value = std::uint32_t; using Writer = PayloadBlockedWriter; using Reader = PayloadBlockedReader; }; diff --git a/include/pisa/v1/index_builder.hpp b/include/pisa/v1/index_builder.hpp index b53069eb0..8746b6306 100644 --- a/include/pisa/v1/index_builder.hpp +++ b/include/pisa/v1/index_builder.hpp @@ -96,13 +96,19 @@ auto compress_batch(CollectionIterator first, } template -void write_span(gsl::span offsets, std::string const& file) +void write_span(gsl::span offsets, std::ofstream& os) { - std::ofstream os(file); auto bytes = gsl::as_bytes(offsets); os.write(reinterpret_cast(bytes.data()), bytes.size()); } +template +void write_span(gsl::span offsets, std::string const& file) +{ + std::ofstream os(file); + write_span(offsets, os); +} + inline void compress_binary_collection(std::string const& input, std::string_view fwd, std::string_view output, @@ -114,7 +120,7 @@ inline void compress_binary_collection(std::string const& input, 
document_writer.init(collection); frequency_writer.init(collection); ProgressStatus status(collection.size(), - DefaultProgress("Compressing in parallel"), + DefaultProgressCallback("Compressing in parallel"), std::chrono::milliseconds(100)); tbb::task_group group; auto const num_terms = collection.size(); @@ -189,7 +195,7 @@ inline void compress_binary_collection(std::string const& input, { ProgressStatus merge_status( - threads, DefaultProgress("Merging files"), std::chrono::milliseconds(500)); + threads, DefaultProgressCallback("Merging files"), std::chrono::milliseconds(500)); for_each_batch([&](auto thread_idx) { std::transform( std::next(document_offsets[thread_idx].begin()), @@ -269,7 +275,13 @@ auto collect_unique_bigrams(std::vector const& queries, std::size_t num_bigrams_to_select) -> std::vector>; -auto build_bigram_index(IndexMetadata meta, std::vector> const& bigrams) - -> IndexMetadata; +auto build_bigram_index(IndexMetadata meta, + std::vector> const& bigrams, + tl::optional const& clone_path) -> IndexMetadata; + +auto build_pair_index(IndexMetadata meta, + std::vector> const& pairs, + tl::optional const& clone_path, + std::size_t threads) -> IndexMetadata; } // namespace pisa::v1 diff --git a/include/pisa/v1/index_metadata.hpp b/include/pisa/v1/index_metadata.hpp index 0f76d9ea6..cd281f371 100644 --- a/include/pisa/v1/index_metadata.hpp +++ b/include/pisa/v1/index_metadata.hpp @@ -14,6 +14,8 @@ namespace pisa::v1 { +[[nodiscard]] auto append_extension(std::string file_path) -> std::string; + /// Return the passed file path if is not `nullopt`. /// Otherwise, look for an `.yml` file in the current directory. /// It will throw if no `.yml` file is found or there are multiple `.yml` files. 
diff --git a/include/pisa/v1/io.hpp b/include/pisa/v1/io.hpp index c413ef324..06e792083 100644 --- a/include/pisa/v1/io.hpp +++ b/include/pisa/v1/io.hpp @@ -1,11 +1,38 @@ #pragma once #include +#include #include #include +#include + +#include + namespace pisa::v1 { -[[nodiscard]] auto load_bytes(std::string const &data_file) -> std::vector; +[[nodiscard]] auto load_bytes(std::string const& data_file) -> std::vector; + +template +[[nodiscard]] auto load_vector(std::string const& data_file) -> std::vector +{ + std::vector data; + std::basic_ifstream in(data_file.c_str(), std::ios::binary); + in.seekg(0, std::ios::end); + std::streamsize size = in.tellg(); + in.seekg(0, std::ios::beg); + + runtime_assert(size % sizeof(T) == 0).or_exit([&] { + return fmt::format("Tried loading a vector of elements of size {} but size of file is {}", + sizeof(T), + size); + }); + data.resize(size / sizeof(T)); + + runtime_assert(in.read(reinterpret_cast(data.data()), size).good()).or_exit([&] { + return fmt::format("Failed reading ", data_file); + }); + return data; +} } // namespace pisa::v1 diff --git a/include/pisa/v1/posting_builder.hpp b/include/pisa/v1/posting_builder.hpp index 699d486dd..a381aa568 100644 --- a/include/pisa/v1/posting_builder.hpp +++ b/include/pisa/v1/posting_builder.hpp @@ -5,6 +5,7 @@ #include +#include "v1/cursor_traits.hpp" #include "v1/posting_format_header.hpp" #include "v1/types.hpp" @@ -22,28 +23,28 @@ struct PostingBuilder { } template - void write_header(std::basic_ostream &os) const + void write_header(std::basic_ostream& os) const { std::array header{}; PostingFormatHeader{.version = FormatVersion::current(), .type = value_type(), .encoding = m_writer.encoding()} .write(gsl::make_span(header)); - os.write(reinterpret_cast(header.data()), header.size()); + os.write(reinterpret_cast(header.data()), header.size()); } template - auto write_segment(std::basic_ostream &os, ValueIterator first, ValueIterator last) - -> std::basic_ostream & + auto 
write_segment(std::basic_ostream& os, ValueIterator first, ValueIterator last) + -> std::basic_ostream& { - std::for_each(first, last, [&](auto &&value) { m_writer.push(value); }); + std::for_each(first, last, [&](auto&& value) { m_writer.push(value); }); return flush_segment(os); } void accumulate(Value value) { m_writer.push(value); } template - auto flush_segment(std::basic_ostream &os) -> std::basic_ostream & + auto flush_segment(std::basic_ostream& os) -> std::basic_ostream& { m_offsets.push_back(m_offsets.back() + m_writer.write(os)); m_writer.reset(); @@ -55,7 +56,7 @@ struct PostingBuilder { return gsl::make_span(m_offsets); } - [[nodiscard]] auto offsets() -> std::vector && { return std::move(m_offsets); } + [[nodiscard]] auto offsets() -> std::vector&& { return std::move(m_offsets); } private: Writer m_writer; diff --git a/include/pisa/v1/progress_status.hpp b/include/pisa/v1/progress_status.hpp index 6e662b0c8..84e8ca477 100644 --- a/include/pisa/v1/progress_status.hpp +++ b/include/pisa/v1/progress_status.hpp @@ -7,61 +7,101 @@ #include #include -namespace pisa::v1 { - -std::ostream &format_interval(std::ostream &out, std::chrono::seconds time); - -using CallbackFunction = std::function)>; +#include "runtime_assert.hpp" +#include "type_safe.hpp" -struct DefaultProgress { - DefaultProgress() = default; - explicit DefaultProgress(std::string caption); - DefaultProgress(DefaultProgress const &) = default; - DefaultProgress(DefaultProgress &&) noexcept = default; - DefaultProgress &operator=(DefaultProgress const &) = default; - DefaultProgress &operator=(DefaultProgress &&) noexcept = default; - ~DefaultProgress() = default; - - void operator()(std::size_t count, - std::size_t goal, - std::chrono::time_point start); +namespace pisa::v1 { - private: - std::size_t m_previous = 0; - std::string m_caption; +/// Represents progress of a certain operation. +struct Progress { + /// A number between 0 and `target` to indicate the current progress. 
+ std::size_t count; + /// The target value `count` reaches at completion. + std::size_t target; }; +/// An alias of the type of callback function used in a progress status. +using CallbackFunction = + std::function)>; + +/// This thread-safe object is responsible for keeping the current progress of an operation. +/// At a defined interval, it invokes a callback function with the current progress, and +/// the starting time of the operation (see `CallbackFunction` type alias). +/// +/// In order to ensure that terminal updates are not iterfered with, there should be no +/// writing to stdin or stderr outside of the callback function between `ProgressStatus` +/// construction and either its destruction (e.g., by going out of scope) or closing it +/// explicitly by calling `close()` member function. struct ProgressStatus { + + /// Constructs a new progress status. + /// + /// \tparam Callback A callable that conforms to `CallbackFunction` signature. + /// \tparam Duration A duration type that will be called to pause the status thread. + /// + /// \param target Usually a number of processed elements that will be incremented + /// throughout an operation. + /// \param callback A function that prints out the current status. See + /// `DefaultProgressCallback`. + /// \param interval A time interval between printing progress. 
template - explicit ProgressStatus(std::size_t count, Callback &&callback, Duration interval) - : m_goal(count), m_callback(std::forward(callback)) + explicit ProgressStatus(std::size_t target, Callback&& callback, Duration interval) + : m_target(target), m_callback(std::forward(callback)) { m_loop = std::thread([this, interval]() { - this->m_callback(this->m_count.load(), this->m_goal, this->m_start); - while (this->m_count.load() < this->m_goal) { + this->m_callback(Progress{this->m_count.load(), this->m_target}, this->m_start); + while (this->m_count.load() < this->m_target) { std::this_thread::sleep_for(interval); - this->m_callback(this->m_count.load(), this->m_goal, this->m_start); + this->m_callback(Progress{this->m_count.load(), this->m_target}, this->m_start); } }); } - ProgressStatus(ProgressStatus const &) = delete; - ProgressStatus(ProgressStatus &&) = delete; - ProgressStatus &operator=(ProgressStatus const &) = delete; - ProgressStatus &operator=(ProgressStatus &&) = delete; + ProgressStatus(ProgressStatus const&) = delete; + ProgressStatus(ProgressStatus&&) = delete; + ProgressStatus& operator=(ProgressStatus const&) = delete; + ProgressStatus& operator=(ProgressStatus&&) = delete; ~ProgressStatus(); + + /// Increments the counter by `inc`. void operator+=(std::size_t inc) { m_count += inc; } + /// Increments the counter by 1. void operator++() { m_count += 1; } + /// Increments the counter by 1. void operator++(int) { m_count += 1; } + /// Sets the progress to 100% and terminates the progress thread. + /// This function should be only called if it is known that the operation is finished. + /// An example is a situation when it is one wants to print a message after finishing a task + /// but before the progress status goes out of the current scope. + /// However, otherwise it is better to just let it go out of scope, which will clean up + /// automatically to enable further writing to the standard output. 
+ void close(); private: - std::size_t const m_goal; - std::function)> - m_callback; + std::size_t const m_target; + CallbackFunction m_callback; std::atomic_size_t m_count = 0; std::chrono::time_point m_start = std::chrono::steady_clock::now(); std::thread m_loop; + bool m_open = true; +}; + +/// This is the default callback that prints status in the following format: +/// `Building bigram index: 1% [1m 29s] [<1h 46m 49s]` or +/// `% [] [<]`. +struct DefaultProgressCallback { + DefaultProgressCallback() = default; + explicit DefaultProgressCallback(std::string caption); + DefaultProgressCallback(DefaultProgressCallback const&) = default; + DefaultProgressCallback(DefaultProgressCallback&&) noexcept = default; + DefaultProgressCallback& operator=(DefaultProgressCallback const&) = default; + DefaultProgressCallback& operator=(DefaultProgressCallback&&) noexcept = default; + ~DefaultProgressCallback() = default; + void operator()(Progress progress, std::chrono::time_point start); + + private: + std::size_t m_previous = 0; + std::string m_caption; + std::size_t m_prev_msg_len = 0; }; } // namespace pisa::v1 diff --git a/include/pisa/v1/raw_cursor.hpp b/include/pisa/v1/raw_cursor.hpp index 2c47d0640..0d7fefa39 100644 --- a/include/pisa/v1/raw_cursor.hpp +++ b/include/pisa/v1/raw_cursor.hpp @@ -154,6 +154,7 @@ struct RawWriter { template struct CursorTraits> { + using Value = T; using Writer = RawWriter; using Reader = RawReader; constexpr static auto encoding() -> std::uint32_t { return EncodingId::Raw + sizeof(T); } diff --git a/include/pisa/v1/score_index.hpp b/include/pisa/v1/score_index.hpp index d81d412e8..ea1ac580c 100644 --- a/include/pisa/v1/score_index.hpp +++ b/include/pisa/v1/score_index.hpp @@ -8,6 +8,16 @@ namespace pisa::v1 { +struct FixedBlock { + std::size_t size; +}; + +struct VariableBlock { + float lambda; +}; + +using BlockType = std::variant; + template & os, Writer write } auto score_index(IndexMetadata meta, std::size_t threads) -> IndexMetadata; 
-auto bm_score_index(IndexMetadata meta, std::size_t block_size, std::size_t threads) - -> IndexMetadata; +auto bm_score_index(IndexMetadata meta, BlockType block_type, std::size_t threads) -> IndexMetadata; } // namespace pisa::v1 diff --git a/script/cw09b-est.sh b/script/cw09b-est.sh index 8cfaa4886..bae1ea935 100644 --- a/script/cw09b-est.sh +++ b/script/cw09b-est.sh @@ -6,12 +6,10 @@ ENCODING="simdbp" # v1 BASENAME="/data/michal/work/v1/cw09b/cw09b-${ENCODING}" THREADS=4 QUERIES="/home/michal/05.clean.shuf.test" -#QUERIES="/home/michal/topics.web.51-200.jl" K=1000 OUTPUT_DIR="/data/michal/intersect/cw09b-est-lm" -#OUTPUT_DIR="/data/michal/intersect/cw09b-est-top20" FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" -#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.web.51-200" -#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.top1000.bm25.05.efficiency_topics.no_dups.1k" -#THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top15.bm25.05.efficiency_topics.no_dups.1k" +PAIRS="/home/michal/real.aol.top6m.jl" +PAIR_INDEX_BASENAME="${BASENAME}-pair" THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.test" +QUERY_LIMIT=100 diff --git a/script/cw09b.sh b/script/cw09b.sh index 35eb77d6c..07253225b 100644 --- a/script/cw09b.sh +++ b/script/cw09b.sh @@ -18,6 +18,7 @@ echo " QUERIES = ${QUERIES}" echo " FILTERED_QUERIES = ${FILTERED_QUERIES}" echo " K = ${K}" echo " THRESHOLDS = ${THRESHOLDS}" +echo " QUERY_LIMIT = ${QUERY_LIMIT}" echo "" set -x @@ -33,15 +34,16 @@ mkdir -p ${OUTPUT_DIR} #${PISA_BIN}/bmscore -i "${BASENAME}.yml" -j ${THREADS} --block-size 128 # Filter out queries witout existing terms. 
-paste -d: ${QUERIES} ${THRESHOLDS} \ - | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ - | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml > ${FILTERED_QUERIES} +#paste -d: ${QUERIES} ${THRESHOLDS} \ +# | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ +# | head -${QUERY_LIMIT} \ +# | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml > ${FILTERED_QUERIES} # This will produce both quantized scores and max scores (both quantized and not). -${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} +${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${PAIRS} --clone ${PAIR_INDEX_BASENAME} -j 4 # Extract intersections -${PISA_BIN}/intersection -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations --existing \ +${PISA_BIN}/intersection -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations --existing \ | grep -v "\[warning\]" \ > ${OUTPUT_DIR}/intersections.jl @@ -54,55 +56,55 @@ ${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl \ --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 # Run benchmarks -#${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm wand --safe > ${OUTPUT_DIR}/bench.wand -#${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm bmw --safe > ${OUTPUT_DIR}/bench.bmw -#${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm maxscore --safe > ${OUTPUT_DIR}/bench.maxscore -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm bmw --safe \ +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm wand --safe > ${OUTPUT_DIR}/bench.wand +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm 
bmw --safe > ${OUTPUT_DIR}/bench.bmw +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm maxscore --safe > ${OUTPUT_DIR}/bench.maxscore +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm bmw --safe \ # > ${OUTPUT_DIR}/bench.bmw-threshold -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore --safe \ +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore --safe \ > ${OUTPUT_DIR}/bench.maxscore-threshold -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup --safe \ +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup --safe \ # > ${OUTPUT_DIR}/bench.maxscore-union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup --safe \ +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup --safe \ > ${OUTPUT_DIR}/bench.unigram-union-lookup -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup --safe \ +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup --safe \ # > ${OUTPUT_DIR}/bench.union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm lookup-union --safe \ +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm lookup-union --safe \ > ${OUTPUT_DIR}/bench.lookup-union -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --safe \ +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --safe \ # --benchmark --algorithm 
lookup-union \ # > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ # --benchmark --algorithm lookup-union \ # > ${OUTPUT_DIR}/bench.lookup-union.scaled-2 # Analyze -# ${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore -# ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore \ +# ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore +# ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore \ # > ${OUTPUT_DIR}/stats.maxscore-threshold -# ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore-union-lookup \ +# ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore-union-lookup \ # > ${OUTPUT_DIR}/stats.maxscore-union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --inspect --algorithm unigram-union-lookup \ +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --inspect --algorithm unigram-union-lookup \ > ${OUTPUT_DIR}/stats.unigram-union-lookup -# ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup \ +# ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup \ # > ${OUTPUT_DIR}/stats.union-lookup -${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm lookup-union \ +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm lookup-union \ > 
${OUTPUT_DIR}/stats.lookup-union -# ${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ +# ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ # --inspect --algorithm lookup-union \ # > ${OUTPUT_DIR}/stats.lookup-union.scaled-1.5 # Evaluate -#${PISA_BIN}/query -i "${BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore \ +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore \ # > "${OUTPUT_DIR}/eval.maxscore-threshold" -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore-union-lookup \ +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --algorithm maxscore-union-lookup \ # > "${OUTPUT_DIR}/eval.maxscore-union-lookup" -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --algorithm unigram-union-lookup \ +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --algorithm unigram-union-lookup \ # > "${OUTPUT_DIR}/eval.unigram-union-lookup" -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm union-lookup \ +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm union-lookup \ # > "${OUTPUT_DIR}/eval.union-lookup" -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm lookup-union \ +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --algorithm lookup-union \ # > "${OUTPUT_DIR}/eval.lookup-union" -#${PISA_BIN}/query -i "${BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --algorithm lookup-union \ +#${PISA_BIN}/query -i 
"${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --algorithm lookup-union \ # > "${OUTPUT_DIR}/eval.lookup-union.scaled-1.5" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 88ca50602..4ffd34d80 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -20,11 +20,11 @@ # CLI11 #) -add_executable(queries queries.cpp) -target_link_libraries(queries - pisa - CLI11 -) +#add_executable(queries queries.cpp) +#target_link_libraries(queries +# pisa +# CLI11 +#) #add_executable(evaluate_queries evaluate_queries.cpp) #target_link_libraries(evaluate_queries @@ -89,11 +89,11 @@ target_link_libraries(queries # CLI11 #) -add_executable(compute_intersection compute_intersection.cpp) -target_link_libraries(compute_intersection - pisa - CLI11 -) +#add_executable(compute_intersection compute_intersection.cpp) +#target_link_libraries(compute_intersection +# pisa +# CLI11 +#) add_executable(lexicon lexicon.cpp) target_link_libraries(lexicon diff --git a/src/v1/index.cpp b/src/v1/index.cpp index 61128fe13..7b02777e8 100644 --- a/src/v1/index.cpp +++ b/src/v1/index.cpp @@ -42,6 +42,14 @@ namespace pisa::v1 { return m_document_lengths.size(); } +[[nodiscard]] auto BaseIndex::num_pairs() const -> std::size_t +{ + if (not m_bigrams) { + throw std::logic_error("Bigrams are missing"); + } + return m_bigrams->mapping.size(); +} + [[nodiscard]] auto BaseIndex::document_length(DocId docid) const -> std::uint32_t { return m_document_lengths[docid]; @@ -157,4 +165,9 @@ template auto BaseIndex::fetch_bigram_payloads<1>(TermId bigram) const return m_quantized_max_scores.at(term); } +[[nodiscard]] auto BaseIndex::pairs() const -> tl::optional const>> +{ + return m_bigrams.map([](auto&& bigrams) { return bigrams.mapping; }); +} + } // namespace pisa::v1 diff --git a/src/v1/index_builder.cpp b/src/v1/index_builder.cpp index 21c342061..16954b80b 100644 --- a/src/v1/index_builder.cpp +++ b/src/v1/index_builder.cpp @@ -1,7 +1,12 @@ +#include #include #include +#include 
+#include +#include #include +#include #include "codec/simdbp.hpp" #include "v1/blocked_cursor.hpp" @@ -9,11 +14,130 @@ #include "v1/cursor_intersection.hpp" #include "v1/default_index_runner.hpp" #include "v1/index_builder.hpp" +#include "v1/index_metadata.hpp" +#include "v1/io.hpp" #include "v1/query.hpp" #include "v1/scorer/bm25.hpp" namespace pisa::v1 { +[[nodiscard]] auto make_temp() -> std::string +{ + return (boost::filesystem::temp_directory_path() / boost::filesystem::unique_path()).string(); +} + +template +void write_document_header(Index&& index, std::ostream& os) +{ + using index_type = std::decay_t; + using writer_type = typename CursorTraits::Writer; + PostingBuilder(writer_type{index.num_documents()}).write_header(os); +} + +template +void write_payload_header(Index&& index, std::ostream& os) +{ + using index_type = std::decay_t; + using writer_type = typename CursorTraits::Writer; + using value_type = typename CursorTraits::Value; + PostingBuilder(writer_type{index.num_documents()}).write_header(os); +} + +auto merge_into(PostingFilePaths const& batch, + std::ofstream& posting_sink, + std::ofstream& offsets_sink, + std::size_t shift) -> std::size_t +{ + std::ifstream batch_stream(batch.postings); + batch_stream.seekg(0, std::ios::end); + std::streamsize size = batch_stream.tellg(); + batch_stream.seekg(0, std::ios::beg); + posting_sink << batch_stream.rdbuf(); + auto offsets = load_vector(batch.offsets); + std::transform(offsets.begin(), offsets.end(), offsets.begin(), [shift](auto offset) { + return shift + offset; + }); + + // Because each batch has one superfluous offset indicating the end of data. + // Effectively, the last offset of batch `i` overlaps with the first offsets of batch `i + 1`. + std::size_t start = shift > 0 ? 
8U : 0U; + + offsets_sink.write(reinterpret_cast(offsets.data()) + start, + offsets.size() * sizeof(std::size_t) - start); + return shift + size; +} + +template +void merge_postings(std::string const& message, + std::ofstream& posting_sink, + std::ofstream& offset_sink, + Rng&& batches) +{ + ProgressStatus status( + batches.size(), DefaultProgressCallback(message), std::chrono::milliseconds(500)); + std::size_t shift = 0; + for (auto&& batch : batches) { + shift = merge_into(batch, posting_sink, offset_sink, shift); + boost::filesystem::remove(batch.postings); + boost::filesystem::remove(batch.offsets); + status += 1; + }; +} + +/// Represents open-ended interval [begin, end). +template +struct Interval { + Interval(std::size_t begin, std::size_t end, gsl::span const& elements) + : m_begin(begin), m_end(end), m_elements(elements) + { + } + [[nodiscard]] auto span() const -> gsl::span + { + return m_elements.subspan(begin(), end() - begin()); + } + [[nodiscard]] auto begin() const -> std::size_t { return m_begin; } + [[nodiscard]] auto end() const -> std::size_t { return m_end; } + + private: + std::size_t m_begin; + std::size_t m_end; + gsl::span const& m_elements; +}; + +template +struct BatchConcurrentBuilder { + explicit BatchConcurrentBuilder(std::size_t num_batches, gsl::span elements) + : m_num_batches(num_batches), m_elements(elements) + { + runtime_assert(elements.size() >= num_batches).or_exit([&] { + return fmt::format( + "The number of elements ({}) must be at least the number of batches ({})", + elements.size(), + num_batches); + }); + } + + template + void execute_batch_jobs(BatchFn batch_job) + { + auto batch_size = m_elements.size() / m_num_batches; + tbb::task_group group; + for (auto batch = 0; batch < m_num_batches; batch += 1) { + group.run([batch_size, batch_job, batch, this] { + auto first_idx = batch * batch_size; + auto last_idx = batch < this->m_num_batches - 1 ? 
(batch + 1) * batch_size + : this->m_elements.size(); + batch_job(batch, Interval(first_idx, last_idx, m_elements)); + }); + } + group.wait(); + } + + private: + std::size_t m_num_batches; + gsl::span m_elements; +}; + auto collect_unique_bigrams(std::vector const& queries, std::function const& callback) -> std::vector> @@ -45,7 +169,7 @@ auto verify_compressed_index(std::string const& input, std::string_view output) auto meta = IndexMetadata::from_file(fmt::format("{}.yml", output)); auto run = index_runner(meta); ProgressStatus status( - collection.size(), DefaultProgress("Verifying"), std::chrono::milliseconds(100)); + collection.size(), DefaultProgressCallback("Verifying"), std::chrono::milliseconds(500)); run([&](auto&& index) { auto sequence_iter = collection.begin(); for (auto term = 0; term < index.num_terms(); term += 1, ++sequence_iter) { @@ -87,6 +211,97 @@ auto verify_compressed_index(std::string const& input, std::string_view output) return errors; } +[[nodiscard]] auto build_scored_pair_index(IndexMetadata meta, + std::string const& index_basename, + std::vector> const& pairs, + std::size_t threads) + -> std::pair +{ + auto run = scored_index_runner(std::move(meta)); + + PostingFilePaths scores_0 = {.postings = fmt::format("{}.bigram_bm25_0", index_basename), + .offsets = + fmt::format("{}.bigram_bm25_offsets_0", index_basename)}; + PostingFilePaths scores_1 = {.postings = fmt::format("{}.bigram_bm25_1", index_basename), + .offsets = + fmt::format("{}.bigram_bm25_offsets_1", index_basename)}; + std::ofstream score_out_0(scores_0.postings); + std::ofstream score_out_1(scores_1.postings); + + ProgressStatus status(pairs.size(), + DefaultProgressCallback("Building scored pair index"), + std::chrono::milliseconds(500)); + + std::vector> batch_files(threads); + + BatchConcurrentBuilder batch_builder(threads, + gsl::span const>(pairs)); + run([&](auto&& index) { + using index_type = std::decay_t; + using score_writer_type = + typename CursorTraits::Writer; 
+ + batch_builder.execute_batch_jobs( + [&status, &batch_files, &index](auto batch_idx, auto interval) { + auto scores_file_0 = make_temp(); + auto scores_file_1 = make_temp(); + auto score_offsets_file_0 = make_temp(); + auto score_offsets_file_1 = make_temp(); + std::ofstream score_out_0(scores_file_0); + std::ofstream score_out_1(scores_file_1); + + PostingBuilder score_builder_0(score_writer_type{}); + PostingBuilder score_builder_1(score_writer_type{}); + + for (auto [left_term, right_term] : interval.span()) { + auto intersection = intersect({index.scored_cursor(left_term, VoidScorer{}), + index.scored_cursor(right_term, VoidScorer{})}, + std::array{0, 0}, + [](auto& payload, auto& cursor, auto list_idx) { + gsl::at(payload, list_idx) = cursor.payload(); + return payload; + }); + if (intersection.empty()) { + status += 1; + continue; + } + for_each(intersection, [&](auto& cursor) { + auto payload = cursor.payload(); + score_builder_0.accumulate(std::get<0>(payload)); + score_builder_1.accumulate(std::get<1>(payload)); + }); + score_builder_0.flush_segment(score_out_0); + score_builder_1.flush_segment(score_out_1); + status += 1; + } + write_span(gsl::make_span(score_builder_0.offsets()), score_offsets_file_0); + write_span(gsl::make_span(score_builder_1.offsets()), score_offsets_file_1); + batch_files[batch_idx] = {PostingFilePaths{scores_file_0, score_offsets_file_0}, + PostingFilePaths{scores_file_1, score_offsets_file_1}}; + }); + + write_payload_header(index, score_out_0); + write_payload_header(index, score_out_1); + }); + + std::ofstream score_offsets_out_0(scores_0.offsets); + std::ofstream score_offsets_out_1(scores_1.offsets); + + merge_postings( + "Merging scores<0>", + score_out_0, + score_offsets_out_0, + ranges::views::transform(batch_files, [](auto&& files) { return std::get<0>(files); })); + + merge_postings( + "Merging scores<1>", + score_out_1, + score_offsets_out_1, + ranges::views::transform(batch_files, [](auto&& files) { return 
std::get<1>(files); })); + + return {scores_0, scores_1}; +} + [[nodiscard]] auto build_scored_bigram_index(IndexMetadata meta, std::string const& index_basename, std::vector> const& bigrams) @@ -94,7 +309,6 @@ auto verify_compressed_index(std::string const& input, std::string_view output) { auto run = scored_index_runner(std::move(meta)); - std::vector> pair_mapping; auto scores_file_0 = fmt::format("{}.bigram_bm25_0", index_basename); auto scores_file_1 = fmt::format("{}.bigram_bm25_1", index_basename); auto score_offsets_file_0 = fmt::format("{}.bigram_bm25_offsets_0", index_basename); @@ -104,8 +318,8 @@ auto verify_compressed_index(std::string const& input, std::string_view output) run([&](auto&& index) { ProgressStatus status(bigrams.size(), - DefaultProgress("Building scored index"), - std::chrono::milliseconds(100)); + DefaultProgressCallback("Building scored index"), + std::chrono::milliseconds(500)); using index_type = std::decay_t; using score_writer_type = typename CursorTraits::Writer; @@ -230,87 +444,159 @@ struct HeapPriorityQueue { | ranges::to_vector; } -auto build_bigram_index(IndexMetadata meta, std::vector> const& bigrams) - -> IndexMetadata +template +auto build_pair_batch(Index&& index, + gsl::span const> pairs, + ProgressStatus& status) + -> std::pair>> { - Expects(not bigrams.empty()); - auto index_basename = meta.get_basename(); - auto run = index_runner(meta); + using index_type = std::decay_t; + using document_writer_type = + typename CursorTraits::Writer; + using frequency_writer_type = + typename CursorTraits::Writer; std::vector> pair_mapping; - auto documents_file = fmt::format("{}.bigram_documents", index_basename); - auto frequencies_file_0 = fmt::format("{}.bigram_frequencies_0", index_basename); - auto frequencies_file_1 = fmt::format("{}.bigram_frequencies_1", index_basename); - auto document_offsets_file = fmt::format("{}.bigram_document_offsets", index_basename); - auto frequency_offsets_file_0 = 
fmt::format("{}.bigram_frequency_offsets_0", index_basename); - auto frequency_offsets_file_1 = fmt::format("{}.bigram_frequency_offsets_1", index_basename); - std::ofstream document_out(documents_file); - std::ofstream frequency_out_0(frequencies_file_0); - std::ofstream frequency_out_1(frequencies_file_1); + BigramMetadata batch_meta{.documents = {.postings = make_temp(), .offsets = make_temp()}, + .frequencies = {{.postings = make_temp(), .offsets = make_temp()}, + {.postings = make_temp(), .offsets = make_temp()}}}; + std::ofstream document_out(batch_meta.documents.postings); + std::ofstream frequency_out_0(batch_meta.frequencies.first.postings); + std::ofstream frequency_out_1(batch_meta.frequencies.second.postings); + + PostingBuilder document_builder(document_writer_type{index.num_documents()}); + PostingBuilder frequency_builder_0(frequency_writer_type{index.num_documents()}); + PostingBuilder frequency_builder_1(frequency_writer_type{index.num_documents()}); + + for (auto [left_term, right_term] : pairs) { + auto intersection = intersect({index.cursor(left_term), index.cursor(right_term)}, + std::array{0, 0}, + [](auto& payload, auto& cursor, auto list_idx) { + gsl::at(payload, list_idx) = cursor.payload(); + return payload; + }); + if (intersection.empty()) { + status += 1; + continue; + } + pair_mapping.push_back({left_term, right_term}); + for_each(intersection, [&](auto& cursor) { + document_builder.accumulate(*cursor); + auto payload = cursor.payload(); + frequency_builder_0.accumulate(std::get<0>(payload)); + frequency_builder_1.accumulate(std::get<1>(payload)); + }); + document_builder.flush_segment(document_out); + frequency_builder_0.flush_segment(frequency_out_0); + frequency_builder_1.flush_segment(frequency_out_1); + status += 1; + } + write_span(gsl::make_span(document_builder.offsets()), batch_meta.documents.offsets); + write_span(gsl::make_span(frequency_builder_0.offsets()), batch_meta.frequencies.first.offsets); + 
write_span(gsl::make_span(frequency_builder_1.offsets()), + batch_meta.frequencies.second.offsets); + return {std::move(batch_meta), std::move(pair_mapping)}; +} + +auto build_pair_index(IndexMetadata meta, + std::vector> const& pairs, + tl::optional const& clone_path, + std::size_t threads) -> IndexMetadata +{ + runtime_assert(not pairs.empty()).or_throw("Pair index must contain pairs but none passed"); + std::string index_basename = clone_path.value_or(std::string(meta.get_basename())); + auto run = index_runner(meta); + + std::vector>> pair_mapping(threads); + std::vector batch_meta(threads); + + PostingFilePaths documents{.postings = fmt::format("{}.bigram_documents", index_basename), + .offsets = + fmt::format("{}.bigram_document_offsets", index_basename)}; + PostingFilePaths frequencies_0{ + .postings = fmt::format("{}.bigram_frequencies_0", index_basename), + .offsets = fmt::format("{}.bigram_frequency_offsets_0", index_basename)}; + PostingFilePaths frequencies_1{ + .postings = fmt::format("{}.bigram_frequencies_1", index_basename), + .offsets = fmt::format("{}.bigram_frequency_offsets_1", index_basename)}; + + std::ofstream document_out(documents.postings); + std::ofstream frequency_out_0(frequencies_0.postings); + std::ofstream frequency_out_1(frequencies_1.postings); + + std::ofstream document_offsets_out(documents.offsets); + std::ofstream frequency_offsets_out_0(frequencies_0.offsets); + std::ofstream frequency_offsets_out_1(frequencies_1.offsets); + + ProgressStatus status(pairs.size(), + DefaultProgressCallback("Building bigram index"), + std::chrono::milliseconds(500)); + + BatchConcurrentBuilder batch_builder(threads, + gsl::span const>(pairs)); run([&](auto&& index) { - ProgressStatus status(bigrams.size(), - DefaultProgress("Building bigram index"), - std::chrono::milliseconds(100)); using index_type = std::decay_t; using document_writer_type = typename CursorTraits::Writer; using frequency_writer_type = typename CursorTraits::Writer; - 
PostingBuilder document_builder(document_writer_type{index.num_documents()}); - PostingBuilder frequency_builder_0(frequency_writer_type{index.num_documents()}); - PostingBuilder frequency_builder_1(frequency_writer_type{index.num_documents()}); - - document_builder.write_header(document_out); - frequency_builder_0.write_header(frequency_out_0); - frequency_builder_1.write_header(frequency_out_1); + batch_builder.execute_batch_jobs([&](auto batch_idx, auto interval) { + auto res = build_pair_batch(index, interval.span(), status); + batch_meta[batch_idx] = res.first; + pair_mapping[batch_idx] = std::move(res.second); + }); - for (auto [left_term, right_term] : bigrams) { - auto intersection = intersect({index.cursor(left_term), index.cursor(right_term)}, - std::array{0, 0}, - [](auto& payload, auto& cursor, auto list_idx) { - gsl::at(payload, list_idx) = cursor.payload(); - return payload; - }); - if (intersection.empty()) { - status += 1; - continue; - } - pair_mapping.push_back({left_term, right_term}); - for_each(intersection, [&](auto& cursor) { - document_builder.accumulate(*cursor); - auto payload = cursor.payload(); - frequency_builder_0.accumulate(std::get<0>(payload)); - frequency_builder_1.accumulate(std::get<1>(payload)); - }); - document_builder.flush_segment(document_out); - frequency_builder_0.flush_segment(frequency_out_0); - frequency_builder_1.flush_segment(frequency_out_1); - status += 1; - } - std::cerr << "Writing offsets..."; - write_span(gsl::make_span(document_builder.offsets()), document_offsets_file); - write_span(gsl::make_span(frequency_builder_0.offsets()), frequency_offsets_file_0); - write_span(gsl::make_span(frequency_builder_1.offsets()), frequency_offsets_file_1); - std::cerr << " Done.\n"; + write_document_header(index, document_out); + write_payload_header(index, frequency_out_0); + write_payload_header(index, frequency_out_1); }); - BigramMetadata bigram_meta{ - .documents = {.postings = documents_file, .offsets = 
document_offsets_file}, - .frequencies = {{.postings = frequencies_file_0, .offsets = frequency_offsets_file_0}, - {.postings = frequencies_file_1, .offsets = frequency_offsets_file_1}}, - .scores = {}, - .mapping = fmt::format("{}.bigram_mapping", index_basename), - .count = pair_mapping.size()}; + + merge_postings( + "Merging documents", + document_out, + document_offsets_out, + ranges::views::transform(batch_meta, [](auto&& meta) { return meta.documents; })); + + merge_postings( + "Merging frequencies<0>", + frequency_out_0, + frequency_offsets_out_0, + ranges::views::transform(batch_meta, [](auto&& meta) { return meta.frequencies.first; })); + + merge_postings( + "Merging frequencies<1>", + frequency_out_1, + frequency_offsets_out_1, + ranges::views::transform(batch_meta, [](auto&& meta) { return meta.frequencies.second; })); + + auto count = std::accumulate(pair_mapping.begin(), + pair_mapping.end(), + std::size_t{0}, + [](auto acc, auto&& m) { return acc + m.size(); }); + BigramMetadata bigram_meta{.documents = documents, + .frequencies = {frequencies_0, frequencies_1}, + .scores = {}, + .mapping = fmt::format("{}.bigram_mapping", index_basename), + .count = count}; + if (not meta.scores.empty()) { - bigram_meta.scores.push_back(build_scored_bigram_index(meta, index_basename, bigrams)); + bigram_meta.scores.push_back(build_scored_pair_index(meta, index_basename, pairs, threads)); } meta.bigrams = bigram_meta; std::cerr << "Writing metadata..."; - meta.update(); + if (clone_path) { + meta.write(append_extension(clone_path.value())); + } else { + meta.update(); + } std::cerr << " Done.\nWriting bigram mapping..."; - write_span(gsl::make_span(pair_mapping), meta.bigrams->mapping); + std::ofstream os(meta.bigrams->mapping); + for (auto mapping_batch : pair_mapping) { + write_span(gsl::make_span(mapping_batch), os); + } std::cerr << " Done.\n"; return meta; } diff --git a/src/v1/index_metadata.cpp b/src/v1/index_metadata.cpp index 1fda06cbf..08b6863a8 100644 --- 
a/src/v1/index_metadata.cpp +++ b/src/v1/index_metadata.cpp @@ -22,6 +22,23 @@ constexpr char const* MAX_SCORES = "max_scores"; constexpr char const* BLOCK_MAX_SCORES = "block_max_scores"; constexpr char const* QUANTIZED_MAX_SCORES = "quantized_max_scores"; +[[nodiscard]] auto has_extension(std::string_view file_path, std::string_view extension) -> bool +{ + if (file_path.size() <= 4) { + return false; + } + return std::string_view(file_path.data(), file_path.size() - 4) == extension; +} + +[[nodiscard]] auto append_extension(std::string file_path) -> std::string +{ + using namespace std::literals; + if (has_extension(file_path, ".yml"sv)) { + return file_path; + } + return fmt::format("{}{}", file_path, ".yml"sv); +} + [[nodiscard]] auto resolve_yml(tl::optional const& arg) -> std::string { if (arg) { diff --git a/src/v1/progress_status.cpp b/src/v1/progress_status.cpp index 5bef98191..fa03c296e 100644 --- a/src/v1/progress_status.cpp +++ b/src/v1/progress_status.cpp @@ -1,55 +1,76 @@ #include "v1/progress_status.hpp" +#include +#include + +#include + namespace pisa::v1 { -std::ostream &format_interval(std::ostream &out, std::chrono::seconds time) +using std::chrono::hours; +using std::chrono::minutes; +using std::chrono::seconds; + +[[nodiscard]] auto format_interval(std::chrono::seconds time) -> std::string { - using std::chrono::hours; - using std::chrono::minutes; - using std::chrono::seconds; hours h = std::chrono::duration_cast(time); minutes m = std::chrono::duration_cast(time - h); seconds s = std::chrono::duration_cast(time - h - m); + std::ostringstream os; if (h.count() > 0) { - out << h.count() << "h "; + os << h.count() << "h "; } if (m.count() > 0) { - out << m.count() << "m "; + os << m.count() << "m "; } - out << s.count() << "s"; - return out; + os << s.count() << "s"; + return os.str(); +} + +[[nodiscard]] auto estimate_remaining(seconds elapsed, Progress const& progress) -> seconds +{ + return seconds(elapsed.count() * (progress.target - 
progress.count) / progress.count); } -DefaultProgress::DefaultProgress(std::string caption) : m_caption(std::move(caption)) +DefaultProgressCallback::DefaultProgressCallback(std::string caption) + : m_caption(std::move(caption)) { if (not m_caption.empty()) { m_caption.append(": "); } } -void DefaultProgress::operator()(std::size_t count, - std::size_t goal, - std::chrono::time_point start) +void DefaultProgressCallback::operator()(Progress current_progress, + std::chrono::time_point start) { - size_t progress = (100 * count) / goal; - // if (progress == m_previous) { - // return; - //} + std::size_t progress = (100 * current_progress.count) / current_progress.target; m_previous = progress; std::chrono::seconds elapsed = std::chrono::duration_cast(std::chrono::steady_clock::now() - start); - std::cerr << '\r' << m_caption << progress << "% ["; - format_interval(std::cerr, elapsed); - std::cerr << "]"; + auto msg = fmt::format( + "{}{}\% [{}]{}", m_caption, progress, format_interval(elapsed), [&]() -> std::string { + if (current_progress.count > 0 && progress < 100) { + return fmt::format(" [<{}]", + format_interval(estimate_remaining(elapsed, current_progress))); + } + return ""; + }()); + std::cerr << '\r' << std::left << std::setfill(' ') << std::setw(m_prev_msg_len) << msg; if (progress == 100) { std::cerr << '\n'; } + m_prev_msg_len = msg.size(); } -ProgressStatus::~ProgressStatus() +void ProgressStatus::close() { - m_count = m_goal; - m_loop.join(); + if (m_open) { + m_count = m_target; + m_loop.join(); + m_open = false; + } } +ProgressStatus::~ProgressStatus() { close(); } + } // namespace pisa::v1 diff --git a/src/v1/score_index.cpp b/src/v1/score_index.cpp index 6b202a16e..fd5ac039f 100644 --- a/src/v1/score_index.cpp +++ b/src/v1/score_index.cpp @@ -3,7 +3,9 @@ #include #include "codec/simdbp.hpp" +#include "score_opt_partition.hpp" #include "v1/blocked_cursor.hpp" +#include "v1/cursor/collect.hpp" #include "v1/default_index_runner.hpp" #include 
"v1/index_builder.hpp" #include "v1/index_metadata.hpp" @@ -12,11 +14,10 @@ #include "v1/score_index.hpp" #include "v1/scorer/bm25.hpp" -using pisa::v1::DefaultProgress; +using pisa::v1::DefaultProgressCallback; using pisa::v1::IndexMetadata; using pisa::v1::PostingFilePaths; using pisa::v1::ProgressStatus; -using pisa::v1::RawReader; using pisa::v1::RawWriter; using pisa::v1::TermId; using pisa::v1::write_span; @@ -33,10 +34,10 @@ auto score_index(IndexMetadata meta, std::size_t threads) -> IndexMetadata auto quantized_max_scores_path = fmt::format("{}.bm25.maxq", index_basename); run([&](auto&& index) { ProgressStatus calc_max_status(index.num_terms(), - DefaultProgress("Calculating max partial score"), + DefaultProgressCallback("Calculating max partial score"), std::chrono::milliseconds(100)); std::vector max_scores(index.num_terms(), 0.0F); - tbb::task_group group; + tbb::task_group group; // TODO(michal): Unused? auto batch_size = index.num_terms() / threads; for (auto thread_id = 0; thread_id < threads; thread_id += 1) { auto first_term = thread_id * batch_size; @@ -55,8 +56,9 @@ auto score_index(IndexMetadata meta, std::size_t threads) -> IndexMetadata }); } group.wait(); + calc_max_status.close(); auto max_score = *std::max_element(max_scores.begin(), max_scores.end()); - std::cerr << fmt::format("Max partial score is: {}. It will be scaled to {}.", + std::cerr << fmt::format("Max partial score is: {}. 
It will be scaled to {}.\n", max_score, std::numeric_limits::max()); @@ -71,7 +73,7 @@ auto score_index(IndexMetadata meta, std::size_t threads) -> IndexMetadata quantizer); ProgressStatus status( - index.num_terms(), DefaultProgress("Scoring"), std::chrono::milliseconds(100)); + index.num_terms(), DefaultProgressCallback("Scoring"), std::chrono::milliseconds(100)); std::ofstream score_file_stream(postings_path); auto offsets = score_index(index, score_file_stream, @@ -91,8 +93,7 @@ auto score_index(IndexMetadata meta, std::size_t threads) -> IndexMetadata } // TODO: Use multiple threads -auto bm_score_index(IndexMetadata meta, std::size_t block_size, std::size_t threads) - -> IndexMetadata +auto bm_score_index(IndexMetadata meta, BlockType block_type, std::size_t threads) -> IndexMetadata { auto run = index_runner(meta); auto const& index_basename = meta.get_basename(); @@ -106,7 +107,7 @@ auto bm_score_index(IndexMetadata meta, std::size_t block_size, std::size_t thre run([&](auto&& index) { auto scorer = make_bm25(index); ProgressStatus status(index.num_terms(), - DefaultProgress("Calculating max-blocks"), + DefaultProgressCallback("Calculating max-blocks"), std::chrono::milliseconds(100)); std::ofstream document_out(paths.documents.postings); std::ofstream score_out(paths.payloads.postings); @@ -116,21 +117,35 @@ auto bm_score_index(IndexMetadata meta, std::size_t block_size, std::size_t thre score_builder.write_header(score_out); for (TermId term_id = 0; term_id < index.num_terms(); term_id += 1) { auto cursor = index.scored_cursor(term_id, scorer); - while (not cursor.empty()) { - auto max_score = 0.0F; - auto last_docid = 0; - for (auto idx = 0; idx < block_size && not cursor.empty(); ++idx) { - if (auto score = cursor.payload(); score > max_score) { - max_score = score; + if (std::holds_alternative(block_type)) { + auto block_size = std::get(block_type).size; + while (not cursor.empty()) { + auto max_score = 0.0F; + auto last_docid = 0; + for (auto idx = 0; 
idx < block_size && not cursor.empty(); ++idx) { + if (auto score = cursor.payload(); score > max_score) { + max_score = score; + } + last_docid = cursor.value(); + cursor.advance(); } - last_docid = cursor.value(); - cursor.advance(); + document_builder.accumulate(last_docid); + score_builder.accumulate(max_score); } - document_builder.accumulate(last_docid); - score_builder.accumulate(max_score); + document_builder.flush_segment(document_out); + score_builder.flush_segment(score_out); + } else { + auto lambda = std::get(block_type).lambda; + auto eps1 = configuration::get().eps1_wand; + auto eps2 = configuration::get().eps2_wand; + auto vec = collect_with_payload(cursor); + auto partition = + score_opt_partition(vec.begin(), 0, vec.size(), eps1, eps2, lambda); + document_builder.write_segment( + document_out, partition.docids.begin(), partition.docids.end()); + score_builder.write_segment( + score_out, partition.max_values.begin(), partition.max_values.end()); } - document_builder.flush_segment(document_out); - score_builder.flush_segment(score_out); status += 1; } write_span(gsl::make_span(document_builder.offsets()), paths.documents.offsets); diff --git a/test/v1/index_fixture.hpp b/test/v1/index_fixture.hpp index 1ce83f25b..7430492a1 100644 --- a/test/v1/index_fixture.hpp +++ b/test/v1/index_fixture.hpp @@ -9,6 +9,7 @@ #include "v1/index.hpp" #include "v1/index_builder.hpp" #include "v1/intersection.hpp" +#include "v1/io.hpp" #include "v1/query.hpp" #include "v1/score_index.hpp" @@ -74,10 +75,11 @@ struct IndexFixture { meta = v1::score_index(v1::IndexMetadata::from_file(yml), 1); } if (bm_score) { - meta = v1::bm_score_index(meta, 5, 1); + meta = v1::bm_score_index(meta, pisa::v1::FixedBlock{5}, 1); } if (build_bigrams) { - v1::build_bigram_index(meta, collect_unique_bigrams(test_queries(), []() {})); + v1::build_pair_index( + meta, collect_unique_bigrams(test_queries(), []() {}), tl::nullopt, 4); } } diff --git a/test/v1/test_v1_bigram_index.cpp 
b/test/v1/test_v1_bigram_index.cpp index aaa430e12..712c0276d 100644 --- a/test/v1/test_v1_bigram_index.cpp +++ b/test/v1/test_v1_bigram_index.cpp @@ -64,17 +64,20 @@ TEMPLATE_TEST_CASE("Bigram v intersection", run([&](auto&& index) { for (auto left = 0; left < q.get_term_ids().size(); left += 1) { for (auto right = left + 1; right < q.get_term_ids().size(); right += 1) { + CAPTURE(q.get_term_ids()[left]); + CAPTURE(q.get_term_ids()[right]); auto left_cursor = index.cursor(q.get_term_ids()[left]); auto right_cursor = index.cursor(q.get_term_ids()[right]); auto intersection = v1::intersect({left_cursor, right_cursor}, std::array{0, 0}, [](auto& acc, auto&& cursor, auto idx) { - acc[idx] = cursor.payload(); + gsl::at(acc, idx) = cursor.payload(); return acc; }); if (not intersection.empty()) { auto bigram_cursor = - *index.bigram_cursor(q.get_term_ids()[left], q.get_term_ids()[right]); + index.bigram_cursor(q.get_term_ids()[left], q.get_term_ids()[right]) + .value(); std::vector bigram_documents; std::vector bigram_frequencies_0; std::vector bigram_frequencies_1; @@ -102,3 +105,70 @@ TEMPLATE_TEST_CASE("Bigram v intersection", }); } } + +TEMPLATE_TEST_CASE("Scored pair index v. 
intersection", + "[v1][integration]", + (IndexFixture, + v1::RawCursor, + v1::RawCursor>), + (IndexFixture, + v1::PayloadBlockedCursor<::pisa::simdbp_block>, + v1::RawCursor>)) +{ + tbb::task_scheduler_init init(1); + TestType fixture; + auto index_basename = (fixture.tmpdir().path() / "inv").string(); + auto meta = v1::IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); + int idx = 0; + for (auto& q : test_queries()) { + CAPTURE(q.get_term_ids()); + CAPTURE(idx++); + + auto run = v1::scored_index_runner(meta, + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.score_reader())); + std::vector results; + run([&](auto&& index) { + for (auto left = 0; left < q.get_term_ids().size(); left += 1) { + for (auto right = left + 1; right < q.get_term_ids().size(); right += 1) { + CAPTURE(q.get_term_ids()[left]); + CAPTURE(q.get_term_ids()[right]); + auto left_cursor = index.cursor(q.get_term_ids()[left]); + auto right_cursor = index.cursor(q.get_term_ids()[right]); + auto intersection = v1::intersect({left_cursor, right_cursor}, + std::array{0.0F, 0.0F}, + [](auto& acc, auto&& cursor, auto idx) { + gsl::at(acc, idx) = cursor.payload(); + return acc; + }); + if (not intersection.empty()) { + auto bigram_cursor = + index.bigram_cursor(q.get_term_ids()[left], q.get_term_ids()[right]) + .value(); + std::vector bigram_documents; + std::vector bigram_scores_0; + std::vector bigram_scores_1; + v1::for_each(bigram_cursor, [&](auto&& cursor) { + bigram_documents.push_back(*cursor); + auto payload = cursor.payload(); + bigram_scores_0.push_back(std::get<0>(payload)); + bigram_scores_1.push_back(std::get<1>(payload)); + }); + std::vector intersection_documents; + std::vector intersection_scores_0; + std::vector intersection_scores_1; + v1::for_each(intersection, [&](auto&& cursor) { + intersection_documents.push_back(*cursor); + auto payload = cursor.payload(); + intersection_scores_0.push_back(std::get<0>(payload)); + 
intersection_scores_1.push_back(std::get<1>(payload)); + }); + CHECK(bigram_documents == intersection_documents); + CHECK(bigram_scores_0 == intersection_scores_0); + REQUIRE(bigram_scores_1 == intersection_scores_1); + } + } + } + }); + } +} diff --git a/test/v1/test_v1_index.cpp b/test/v1/test_v1_index.cpp index 36727887e..4d87f2875 100644 --- a/test/v1/test_v1_index.cpp +++ b/test/v1/test_v1_index.cpp @@ -27,6 +27,8 @@ #include "v1/types.hpp" using pisa::binary_freq_collection; +using pisa::v1::BigramMetadata; +using pisa::v1::build_pair_index; using pisa::v1::compress_binary_collection; using pisa::v1::DocId; using pisa::v1::DocumentBitSequenceCursor; @@ -42,6 +44,7 @@ using pisa::v1::PayloadBitSequenceCursor; using pisa::v1::PayloadBlockedCursor; using pisa::v1::PayloadBlockedWriter; using pisa::v1::PositiveSequence; +using pisa::v1::PostingFilePaths; using pisa::v1::Query; using pisa::v1::RawCursor; using pisa::v1::RawWriter; @@ -196,3 +199,48 @@ TEST_CASE("Select best bigrams", "[v1][integration]") REQUIRE(best_bigrams == std::vector>{{3, 4}, {2, 3}, {1, 2}}); } } + +TEST_CASE("Build pair index", "[v1][integration]") +{ + tbb::task_scheduler_init init; + IndexFixture, RawCursor, RawCursor> fixture( + true, true, true, false); + std::string index_basename = (fixture.tmpdir().path() / "inv").string(); + auto meta = IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); + SECTION("In place") + { + build_pair_index(meta, {{0, 1}, {0, 2}, {0, 3}, {0, 4}, {0, 5}}, tl::nullopt, 4); + auto run = index_runner(IndexMetadata::from_file(fmt::format("{}.yml", index_basename))); + run([&](auto&& index) { + REQUIRE(index.bigram_cursor(0, 1).has_value()); + REQUIRE(index.bigram_cursor(1, 0).has_value()); + REQUIRE(not index.bigram_cursor(1, 2).has_value()); + REQUIRE(not index.bigram_cursor(2, 1).has_value()); + }); + } + SECTION("Cloned") + { + auto cloned_basename = (fixture.tmpdir().path() / "cloned").string(); + build_pair_index( + meta, {{0, 1}, {0, 2}, {0, 
3}, {0, 4}, {0, 5}}, tl::make_optional(cloned_basename), 4); + SECTION("Original index has no pairs") + { + auto run = + index_runner(IndexMetadata::from_file(fmt::format("{}.yml", index_basename))); + run([&](auto&& index) { + REQUIRE_THROWS_AS(index.bigram_cursor(0, 1), std::logic_error); + }); + } + SECTION("New index has pairs") + { + auto run = + index_runner(IndexMetadata::from_file(fmt::format("{}.yml", cloned_basename))); + run([&](auto&& index) { + REQUIRE(index.bigram_cursor(0, 1).has_value()); + REQUIRE(index.bigram_cursor(1, 0).has_value()); + REQUIRE(not index.bigram_cursor(1, 2).has_value()); + REQUIRE(not index.bigram_cursor(2, 1).has_value()); + }); + } + } +} diff --git a/test/v1/test_v1_queries.cpp b/test/v1/test_v1_queries.cpp index f2a47f190..260f9e863 100644 --- a/test/v1/test_v1_queries.cpp +++ b/test/v1/test_v1_queries.cpp @@ -288,12 +288,11 @@ TEMPLATE_TEST_CASE("Query", return results; }(); - constexpr float max_partial_score = 16.5724F; - auto quantizer = [&](float score) { - return static_cast(score * std::numeric_limits::max() - / max_partial_score); - }; - // TODO(michal): test the quantized results + // constexpr float max_partial_score = 16.5724F; + // auto quantizer = [&](float score) { + // return static_cast(score * std::numeric_limits::max() + // / max_partial_score); + //}; } } diff --git a/v1/CMakeLists.txt b/v1/CMakeLists.txt index 8d62026f8..a96dcc423 100644 --- a/v1/CMakeLists.txt +++ b/v1/CMakeLists.txt @@ -27,3 +27,9 @@ target_link_libraries(intersection pisa CLI11) add_executable(select-pairs select_pairs.cpp) target_link_libraries(select-pairs pisa CLI11) + +add_executable(count-postings count_postings.cpp) +target_link_libraries(count-postings pisa CLI11) + +add_executable(stats stats.cpp) +target_link_libraries(stats pisa CLI11) diff --git a/v1/bigram_index.cpp b/v1/bigram_index.cpp index e48d952c4..0355d8f35 100644 --- a/v1/bigram_index.cpp +++ b/v1/bigram_index.cpp @@ -3,6 +3,7 @@ #include #include +#include #include 
"app.hpp" #include "v1/index_builder.hpp" @@ -10,18 +11,20 @@ #include "v1/types.hpp" using pisa::App; -using pisa::v1::build_bigram_index; +using pisa::v1::build_pair_index; using pisa::v1::collect_unique_bigrams; -using pisa::v1::DefaultProgress; +using pisa::v1::DefaultProgressCallback; using pisa::v1::ProgressStatus; namespace arg = pisa::arg; int main(int argc, char** argv) { - std::optional terms_file{}; + tl::optional clone_path{}; - App> app{"Creates a v1 bigram index."}; + App, arg::Threads> app{ + "Creates a v1 bigram index."}; + app.add_option("--clone", clone_path, "Instead", false); CLI11_PARSE(app, argc, argv); auto meta = app.index_metadata(); @@ -31,9 +34,11 @@ int main(int argc, char** argv) spdlog::info("Collected {} queries", queries.size()); spdlog::info("Collecting bigrams..."); - ProgressStatus status(queries.size(), DefaultProgress{}, std::chrono::milliseconds(1000)); + ProgressStatus status( + queries.size(), DefaultProgressCallback{}, std::chrono::milliseconds(1000)); auto bigrams = collect_unique_bigrams(queries, [&]() { status += 1; }); + status.close(); spdlog::info("Collected {} bigrams", bigrams.size()); - build_bigram_index(meta, bigrams); + build_pair_index(meta, bigrams, clone_path, app.threads()); return 0; } diff --git a/v1/bmscore.cpp b/v1/bmscore.cpp index 76ddc6a00..9f5c868f7 100644 --- a/v1/bmscore.cpp +++ b/v1/bmscore.cpp @@ -8,18 +8,30 @@ #include "v1/score_index.hpp" using pisa::App; +using pisa::v1::BlockType; +using pisa::v1::FixedBlock; +using pisa::v1::VariableBlock; namespace arg = pisa::arg; int main(int argc, char** argv) { std::optional yml{}; - std::size_t block_size; + std::size_t block_size = 128; std::size_t threads = std::thread::hardware_concurrency(); + std::optional lambda{}; App app{"Constructs block-max score lists for v1 index."}; - app.add_option("--block-size", block_size, "The size of a block for max scores", false) - ->required(); + app.add_option("--block-size", block_size, "The size of a block for 
max scores", true); + app.add_option("--variable-blocks", lambda, "The size of a block for max scores", false); CLI11_PARSE(app, argc, argv); - pisa::v1::bm_score_index(app.index_metadata(), block_size, app.threads()); + + auto block_type = [&]() -> BlockType { + if (lambda) { + return VariableBlock{*lambda}; + } + return FixedBlock{block_size}; + }(); + + pisa::v1::bm_score_index(app.index_metadata(), block_type, app.threads()); return 0; } diff --git a/v1/count_postings.cpp b/v1/count_postings.cpp new file mode 100644 index 000000000..6ebf286af --- /dev/null +++ b/v1/count_postings.cpp @@ -0,0 +1,51 @@ +#include + +#include +#include +#include +#include + +#include "app.hpp" +#include "v1/default_index_runner.hpp" +#include "v1/progress_status.hpp" + +namespace arg = pisa::arg; + +int main(int argc, char** argv) +{ + spdlog::drop(""); + spdlog::set_default_logger(spdlog::stderr_color_mt("")); + + bool pair_index = false; + + pisa::App app("Simply counts all postings in the index"); + app.add_flag("--pairs", pair_index, "Count postings in the pair index instead"); + CLI11_PARSE(app, argc, argv); + + auto meta = app.index_metadata(); + std::size_t count{0}; + pisa::v1::index_runner(meta)([&](auto&& index) { + if (pair_index) { + pisa::v1::ProgressStatus status( + index.pairs()->size(), + pisa::v1::DefaultProgressCallback("Counting pair postings"), + std::chrono::milliseconds(500)); + for (auto term_pair : index.pairs().value()) { + count += + index.bigram_cursor(std::get<0>(term_pair), std::get<1>(term_pair))->size(); + status += 1; + } + } else { + pisa::v1::ProgressStatus status( + index.num_terms(), + pisa::v1::DefaultProgressCallback("Counting term postings"), + std::chrono::milliseconds(500)); + for (pisa::v1::TermId id = 0; id < index.num_terms(); id += 1) { + count += index.term_posting_count(id); + status += 1; + } + } + }); + std::cout << count << '\n'; + return 0; +} diff --git a/v1/filter_queries.cpp b/v1/filter_queries.cpp index 627209208..4b00dad6f 
100644 --- a/v1/filter_queries.cpp +++ b/v1/filter_queries.cpp @@ -14,13 +14,16 @@ int main(int argc, char** argv) spdlog::drop(""); spdlog::set_default_logger(spdlog::stderr_color_mt("")); + std::size_t min_query_len = 1; + pisa::App> app( "Filters out empty queries against a v1 index."); + app.add_option("--min", min_query_len, "Minimum query legth to consider"); CLI11_PARSE(app, argc, argv); auto meta = app.index_metadata(); for (auto&& query : app.query_range(meta)) { - if (not query.get_term_ids().empty()) { + if (query.get_term_ids().size() >= min_query_len) { std::cout << *query.to_json() << '\n'; } } diff --git a/v1/intersection.cpp b/v1/intersection.cpp index 7f9ce0287..7834000d5 100644 --- a/v1/intersection.cpp +++ b/v1/intersection.cpp @@ -38,6 +38,10 @@ auto compute_intersection(Index const& index, { auto const term_ids = term_selection ? query.filtered_terms(*term_selection) : query.get_term_ids(); + if (term_ids.size() == 1) { + auto cursor = index.max_scored_cursor(term_ids[0], make_bm25(index)); + return Intersection{cursor.size(), cursor.max_score()}; + } auto cursors = index.scored_cursors(gsl::make_span(term_ids), make_bm25(index)); auto intersection = intersect(cursors, 0.0F, [](auto score, auto&& cursor, [[maybe_unused]] auto idx) { diff --git a/v1/stats.cpp b/v1/stats.cpp new file mode 100644 index 000000000..3376997bc --- /dev/null +++ b/v1/stats.cpp @@ -0,0 +1,29 @@ +#include + +#include +#include +#include + +#include "app.hpp" +#include "v1/default_index_runner.hpp" + +namespace arg = pisa::arg; + +int main(int argc, char** argv) +{ + spdlog::drop(""); + spdlog::set_default_logger(spdlog::stderr_color_mt("")); + + pisa::App app("Simply counts all postings in the index"); + CLI11_PARSE(app, argc, argv); + + auto meta = app.index_metadata(); + std::size_t count{0}; + pisa::v1::index_runner(meta)([&](auto&& index) { + std::cout << fmt::format("#terms: {}\n", index.num_terms()); + std::cout << fmt::format("#documents: {}\n", 
index.num_documents()); + std::cout << fmt::format("#pairs: {}\n", index.num_pairs()); + std::cout << fmt::format("avg. document length: {}\n", index.avg_document_length()); + }); + return 0; +} From 7f5eb8bfa42e7472caf95d32578f4f33f722e834 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Wed, 15 Jan 2020 18:40:53 +0000 Subject: [PATCH 49/56] Fix queries test after merge --- test/v1/index_fixture.hpp | 5 +++++ test/v1/test_v1_queries.cpp | 35 ++++++++++++++++++++++------------- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/test/v1/index_fixture.hpp b/test/v1/index_fixture.hpp index 7430492a1..17519bdc2 100644 --- a/test/v1/index_fixture.hpp +++ b/test/v1/index_fixture.hpp @@ -83,6 +83,11 @@ struct IndexFixture { } } + void rebuild_bm_scores(pisa::v1::BlockType block_type) + { + v1::bm_score_index(meta(), block_type, 1); + } + [[nodiscard]] auto const& tmpdir() const { return *m_tmpdir; } [[nodiscard]] auto document_reader() const { return m_document_reader; } [[nodiscard]] auto frequency_reader() const { return m_frequency_reader; } diff --git a/test/v1/test_v1_queries.cpp b/test/v1/test_v1_queries.cpp index 260f9e863..e3aa1eff5 100644 --- a/test/v1/test_v1_queries.cpp +++ b/test/v1/test_v1_queries.cpp @@ -147,26 +147,33 @@ TEMPLATE_TEST_CASE("Query", Index, RawCursor>, Index, RawCursor>>::get(); TestType fixture; - auto input_data = GENERATE(table({ - {"daat_or", false}, - {"maxscore", false}, - {"maxscore", true}, - {"wand", false}, - {"wand", true}, - {"bmw", false}, - {"bmw", true}, - {"maxscore_union_lookup", true}, - {"unigram_union_lookup", true}, - {"union_lookup", true}, - {"lookup_union", true}, + auto input_data = GENERATE(table({ + {"daat_or", false, false}, + {"maxscore", false, false}, + {"maxscore", true, false}, + {"wand", false, false}, + {"wand", true, false}, + {"bmw", false, false}, + {"bmw", true, false}, + {"bmw", false, true}, + {"bmw", true, true}, + {"maxscore_union_lookup", true, false}, + {"unigram_union_lookup", 
true, false}, + {"union_lookup", true, false}, + {"lookup_union", true, false}, })); std::string algorithm = std::get<0>(input_data); bool with_threshold = std::get<1>(input_data); + bool rebuild_with_variable_blocks = std::get<2>(input_data); + if (rebuild_with_variable_blocks) { + fixture.rebuild_bm_scores(pisa::v1::VariableBlock{12.0}); + } CAPTURE(algorithm); CAPTURE(with_threshold); auto index_basename = (fixture.tmpdir().path() / "inv").string(); auto meta = IndexMetadata::from_file(fmt::format("{}.yml", index_basename)); - ::pisa::ranked_or_query or_q(10); + auto heap = ::pisa::topk_queue(10); + ::pisa::ranked_or_query or_q(heap); auto run_query = [](std::string const& name, auto query, auto&& index, auto scorer) { if (name == "daat_or") { return daat_or(query, index, ::pisa::topk_queue(10), scorer); @@ -202,6 +209,7 @@ TEMPLATE_TEST_CASE("Query", auto const intersections = pisa::v1::read_intersections(PISA_SOURCE_DIR "/test/test_data/top10_selections"); for (auto& query : test_queries()) { + heap.clear(); if (algorithm == "union_lookup" || algorithm == "lookup_union") { query.selections(gsl::make_span(intersections[idx])); } @@ -215,6 +223,7 @@ TEMPLATE_TEST_CASE("Query", ::pisa::bm25<::pisa::wand_data<::pisa::wand_data_raw>>(data->wdata), ::pisa::Query{{}, query.get_term_ids(), {}}), data->v0_index.num_docs()); + heap.finalize(); auto expected = or_q.topk(); if (with_threshold) { query.threshold(expected.back().first - 1.0F); From 72e7ef43665c78dc10eae6c0cadc88870a02f70e Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Tue, 21 Jan 2020 20:43:45 +0000 Subject: [PATCH 50/56] Improved UL --- include/pisa/topk_queue.hpp | 2 +- include/pisa/v1/cursor/labeled_cursor.hpp | 63 ++++++ include/pisa/v1/cursor/scoring_cursor.hpp | 16 ++ include/pisa/v1/document_payload_cursor.hpp | 7 + include/pisa/v1/index.hpp | 8 + include/pisa/v1/index_metadata.hpp | 3 +- include/pisa/v1/score_index.hpp | 5 +- include/pisa/v1/union_lookup.hpp | 234 +++++++++++++++++++- 
script/cw09b-bp.sh | 21 ++ script/cw09b-bpq.sh | 21 ++ script/cw09b-est.sh | 2 +- script/cw09b-url.sh | 15 ++ script/cw09b.sh | 72 +++--- script/cw12-url.sh | 15 ++ src/v1/index_metadata.cpp | 14 +- src/v1/score_index.cpp | 15 +- test/v1/index_fixture.hpp | 4 +- test/v1/test_v1_queries.cpp | 34 +-- v1/app.hpp | 7 +- v1/bmscore.cpp | 8 +- v1/filter_queries.cpp | 4 +- 21 files changed, 498 insertions(+), 72 deletions(-) create mode 100644 include/pisa/v1/cursor/labeled_cursor.hpp create mode 100644 script/cw09b-bp.sh create mode 100644 script/cw09b-bpq.sh create mode 100644 script/cw09b-url.sh create mode 100644 script/cw12-url.sh diff --git a/include/pisa/topk_queue.hpp b/include/pisa/topk_queue.hpp index 3a840ca1d..96644f46a 100644 --- a/include/pisa/topk_queue.hpp +++ b/include/pisa/topk_queue.hpp @@ -44,7 +44,7 @@ struct topk_queue { return true; } - [[nodiscard]] bool would_enter(float score) const { return score > m_threshold; } + [[nodiscard]] bool would_enter(float score) const { return score >= m_threshold; } void finalize() { diff --git a/include/pisa/v1/cursor/labeled_cursor.hpp b/include/pisa/v1/cursor/labeled_cursor.hpp new file mode 100644 index 000000000..3f44e0fff --- /dev/null +++ b/include/pisa/v1/cursor/labeled_cursor.hpp @@ -0,0 +1,63 @@ +#pragma once + +#include + +#include "v1/cursor_traits.hpp" + +namespace pisa::v1 { + +template +struct LabeledCursor { + using Value = typename CursorTraits::Value; + + explicit constexpr LabeledCursor(Cursor cursor, Label label) + : m_cursor(std::move(cursor)), m_label(std::move(label)) + { + } + + [[nodiscard]] constexpr auto operator*() const -> Value { return value(); } + [[nodiscard]] constexpr auto value() const noexcept -> Value { return m_cursor.value(); } + [[nodiscard]] constexpr auto payload() noexcept { return m_cursor.payload(); } + constexpr void advance() { m_cursor.advance(); } + constexpr void advance_to_position(std::size_t pos) { m_cursor.advance_to_position(pos); } + constexpr void 
advance_to_geq(Value value) { m_cursor.advance_to_geq(value); } + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_cursor.empty(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return m_cursor.position(); + } + // TODO: Support not sized + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_cursor.size(); } + [[nodiscard]] constexpr auto sentinel() const -> Value { return m_cursor.sentinel(); } + [[nodiscard]] constexpr auto label() const -> Label const& { return m_label; } + [[nodiscard]] constexpr auto max_score() const { return m_cursor.max_score(); } + + private: + Cursor m_cursor; + Label m_label; +}; + +template +auto label(Cursor cursor, Label label) +{ + return LabeledCursor(std::move(cursor), std::move(label)); +} + +template +auto label(std::vector cursors, LabelFn&& label_fn) +{ + using label_type = std::decay_t()))>; + std::vector> labeled; + std::transform(cursors.begin(), + cursors.end(), + std::back_inserter(labeled), + [&](Cursor&& cursor) { return label(cursor, label_fn(cursor)); }); + return labeled; +} + +template +struct CursorTraits> { + using Value = typename CursorTraits::Value; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor/scoring_cursor.hpp b/include/pisa/v1/cursor/scoring_cursor.hpp index 8f7604d83..1fe0d929e 100644 --- a/include/pisa/v1/cursor/scoring_cursor.hpp +++ b/include/pisa/v1/cursor/scoring_cursor.hpp @@ -4,6 +4,7 @@ #include +#include "v1/cursor_traits.hpp" #include "v1/types.hpp" namespace pisa::v1 { @@ -156,4 +157,19 @@ template std::move(base_cursor), std::move(block_max_cursor), max_score); } +template +struct CursorTraits> { + using Value = typename CursorTraits::Value; +}; + +template +struct CursorTraits> { + using Value = typename CursorTraits::Value; +}; + +template +struct CursorTraits> { + using Value = typename CursorTraits::Value; +}; + } // namespace pisa::v1 diff --git a/include/pisa/v1/document_payload_cursor.hpp 
b/include/pisa/v1/document_payload_cursor.hpp index d00a675fc..a137873fb 100644 --- a/include/pisa/v1/document_payload_cursor.hpp +++ b/include/pisa/v1/document_payload_cursor.hpp @@ -4,6 +4,8 @@ #include +#include "v1/cursor_traits.hpp" + namespace pisa::v1 { template @@ -53,4 +55,9 @@ template std::move(payload_cursor)); } +template +struct CursorTraits> { + using Value = typename CursorTraits::Value; +}; + } // namespace pisa::v1 diff --git a/include/pisa/v1/index.hpp b/include/pisa/v1/index.hpp index 88b5c3c82..41c20472d 100644 --- a/include/pisa/v1/index.hpp +++ b/include/pisa/v1/index.hpp @@ -5,6 +5,7 @@ #include #include +#include #include #include "v1/base_index.hpp" @@ -137,6 +138,13 @@ struct Index : public BaseIndex { } } + template + [[nodiscard]] auto cursors(gsl::span terms, Fn&& fn) const + { + return ranges::views::transform(terms, [this, &fn](auto term) { return fn(*this, term); }) + | ranges::to_vector; + } + template [[nodiscard]] auto scored_cursors(gsl::span terms, Scorer&& scorer) const { diff --git a/include/pisa/v1/index_metadata.hpp b/include/pisa/v1/index_metadata.hpp index cd281f371..6f8651cce 100644 --- a/include/pisa/v1/index_metadata.hpp +++ b/include/pisa/v1/index_metadata.hpp @@ -74,7 +74,8 @@ struct IndexMetadata final { void write(std::string const& file) const; void update() const; - [[nodiscard]] auto query_parser() const -> std::function; + [[nodiscard]] auto query_parser(tl::optional const& stop_words = tl::nullopt) const + -> std::function; [[nodiscard]] auto get_basename() const -> std::string const&; [[nodiscard]] static auto from_file(std::string const& file) -> IndexMetadata; }; diff --git a/include/pisa/v1/score_index.hpp b/include/pisa/v1/score_index.hpp index ea1ac580c..e19aa2b85 100644 --- a/include/pisa/v1/score_index.hpp +++ b/include/pisa/v1/score_index.hpp @@ -62,6 +62,9 @@ auto score_index(Index const& index, std::basic_ostream& os, Writer write } auto score_index(IndexMetadata meta, std::size_t threads) -> 
IndexMetadata; -auto bm_score_index(IndexMetadata meta, BlockType block_type, std::size_t threads) -> IndexMetadata; +auto bm_score_index(IndexMetadata meta, + BlockType block_type, + tl::optional const& clone_path, + std::size_t threads) -> IndexMetadata; } // namespace pisa::v1 diff --git a/include/pisa/v1/union_lookup.hpp b/include/pisa/v1/union_lookup.hpp index 6cdbf414e..55a93a19c 100644 --- a/include/pisa/v1/union_lookup.hpp +++ b/include/pisa/v1/union_lookup.hpp @@ -6,19 +6,14 @@ #include #include "v1/algorithm.hpp" +#include "v1/cursor/labeled_cursor.hpp" #include "v1/cursor/transform.hpp" #include "v1/cursor_accumulator.hpp" #include "v1/query.hpp" +#include "v1/runtime_assert.hpp" namespace pisa::v1 { -inline void ensure(bool condition, char const* message) -{ - if (condition) { - throw std::invalid_argument(message); - } -} - /// This cursor operator takes a number of essential cursors (in an arbitrary order) /// and a list of lookup cursors. The documents traversed will be in the DaaT order, /// and the following documents will be skipped: @@ -247,7 +242,7 @@ auto unigram_union_lookup(Query const& query, } auto const& selections = query.get_selections(); - ensure(not selections.bigrams.empty(), "This algorithm only supports unigrams"); + runtime_assert(selections.bigrams.empty()).or_exit("This algorithm only supports unigrams"); topk.set_threshold(query.get_threshold()); @@ -529,8 +524,7 @@ auto union_lookup(Query const& query, topk.set_threshold(threshold); - std::array initial_payload{ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; //, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + std::array initial_payload{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; if constexpr (not std::is_void_v) { inspect->essential(essential_unigrams.size() + essential_bigrams.size()); @@ -542,6 +536,7 @@ auto union_lookup(Query const& query, std::back_inserter(essential_unigram_cursors), [&](auto term) { return index.scored_cursor(term, scorer); }); + /// TODO: remap according to max 
score instead of term sorted order std::vector unigram_query_positions(essential_unigrams.size()); for (std::size_t unigram_position = 0; unigram_position < essential_unigrams.size(); unigram_position += 1) { @@ -658,6 +653,225 @@ auto union_lookup(Query const& query, return topk; } +inline auto precompute_next_lookup(std::size_t essential_count, + std::size_t non_essential_count, + std::vector> essential_bigrams) +{ + runtime_assert(essential_count + non_essential_count <= 8).or_throw("Must be shorter than 9"); + std::uint32_t term_count = essential_count + non_essential_count; + std::vector next_lookup((term_count + 1) * (1U << term_count), -1); + auto unnecessary = [&](auto p, auto state) { + if (((1U << p) & state) > 0) { + return true; + } + for (auto k : essential_bigrams[p]) { + if (((1U << k) & state) > 0) { + return true; + } + } + return false; + }; + for (auto term_idx = essential_count; term_idx < term_count; term_idx += 1) { + for (std::uint32_t state = 0; state < (1U << term_count); state += 1) { + auto p = term_idx; + while (p < term_count && unnecessary(p, state)) { + ++p; + } + if (p == term_count) { + next_lookup[(term_idx << term_count) + state] = -1; + } else { + next_lookup[(term_idx << term_count) + state] = p; + } + } + } + return next_lookup; +} + +template +auto union_lookup_plus(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + Inspect* inspect = nullptr) +{ + using bigram_cursor_type = + LabeledCursor, + std::pair>; + + auto term_ids = gsl::make_span(query.get_term_ids()); + std::size_t term_count = term_ids.size(); + if (term_ids.empty()) { + return topk; + } + runtime_assert(term_ids.size() <= 8) + .or_throw("Generic version of union-Lookup supported only for queries of length <= 8"); + topk.set_threshold(query.get_threshold()); + auto const& selections = query.get_selections(); + auto& essential_unigrams = selections.unigrams; + auto& essential_bigrams = selections.bigrams; + + auto non_essential_terms = 
+ ranges::views::set_difference(term_ids, essential_unigrams) | ranges::to_vector; + + std::array initial_payload{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + + if constexpr (not std::is_void_v) { + inspect->essential(essential_unigrams.size() + essential_bigrams.size()); + } + + auto essential_unigram_cursors = index.cursors(term_ids, [&](auto&& index, auto term) { + return label(index.scored_cursor(term, scorer), term); + }); + + auto lookup_cursors = + index.cursors(gsl::make_span(non_essential_terms), [&](auto&& index, auto term) { + return label(index.max_scored_cursor(term, scorer), term); + }); + std::sort(lookup_cursors.begin(), lookup_cursors.end(), [](auto&& lhs, auto&& rhs) { + return lhs.max_score() > rhs.max_score(); + }); + + auto merged_unigrams = v1::union_merge( + essential_unigram_cursors, initial_payload, [&](auto& acc, auto& cursor, auto idx) { + if constexpr (not std::is_void_v) { + inspect->posting(); + } + acc[idx] = cursor.payload(); + return acc; + }); + + auto term_to_position = [&] { + std::unordered_map term_to_position; + std::uint32_t position = 0; + for (auto&& cursor : essential_unigram_cursors) { + term_to_position[cursor.label()] = position++; + } + for (auto&& cursor : lookup_cursors) { + term_to_position[cursor.label()] = position++; + } + return term_to_position; + }(); + + std::vector essential_bigram_cursors; + for (auto [left, right] : essential_bigrams) { + auto cursor = index.scored_bigram_cursor(left, right, scorer); + if (not cursor) { + throw std::runtime_error(fmt::format("Bigram not found: <{}, {}>", left, right)); + } + essential_bigram_cursors.push_back( + label(cursor.take().value(), + std::make_pair(term_to_position[left], term_to_position[right]))); + } + + auto merged_bigrams = v1::union_merge(std::move(essential_bigram_cursors), + initial_payload, + [&](auto& acc, auto& cursor, auto /* bigram_idx */) { + if constexpr (not std::is_void_v) { + inspect->posting(); + } + auto payload = cursor.payload(); + 
acc[cursor.label().first] = std::get<0>(payload); + acc[cursor.label().second] = std::get<1>(payload); + return acc; + }); + + auto accumulate = [&](auto& acc, auto& cursor, auto /* union_idx */) { + auto payload = cursor.payload(); + for (auto idx = 0; idx < acc.size(); idx += 1) { + if (acc[idx] == 0.0F) { + acc[idx] = payload[idx]; + } + } + return acc; + }; + auto merged = v1::variadic_union_merge( + initial_payload, + std::make_tuple(std::move(merged_unigrams), std::move(merged_bigrams)), + std::make_tuple(accumulate, accumulate)); + + auto lookup_cursors_upper_bound = std::accumulate( + lookup_cursors.begin(), lookup_cursors.end(), 0.0F, [](auto acc, auto&& cursor) { + return acc + cursor.max_score(); + }); + + auto next_lookup = + precompute_next_lookup(essential_unigrams.size(), lookup_cursors.size(), [&] { + std::unordered_map> bigram_map; + for (auto [left, right] : essential_bigrams) { + bigram_map[left].push_back(right); + bigram_map[right].push_back(left); + } + std::uint32_t position = essential_unigrams.size(); + std::vector> mapping(term_ids.size()); + for (auto&& cursor : lookup_cursors) { + for (auto b : bigram_map[cursor.label()]) { + mapping[position].push_back(term_to_position[b]); + } + position++; + } + return mapping; + }()); + auto mus = [&] { + std::size_t term_count = term_ids.size(); + std::vector mus((term_count + 1) * (1U << term_count), 0.0); + for (auto term_idx = term_count; term_idx + 1 >= 1; term_idx -= 1) { + for (std::uint32_t state = (1U << term_count) - 1; state + 1 >= 1; state -= 1) { + auto nt = next_lookup[(term_idx << term_count) + state]; + if (nt == -1) { + mus[(term_idx << term_count) + state] = 0.0F; + } else { + mus[(term_idx << term_count) + state] = + std::max(lookup_cursors[term_idx - essential_unigrams.size()].max_score() + + mus[((term_idx + 1) << term_count) + (state | (1 << nt))], + mus[((term_idx + 1) << term_count) + state]); + } + } + } + return mus; + }(); + + v1::for_each(merged, [&](auto& cursor) { + if 
constexpr (not std::is_void_v) { + inspect->document(); + } + auto docid = cursor.value(); + auto scores = cursor.payload(); + auto score = std::accumulate(scores.begin(), scores.end(), 0.0F, std::plus{}); + + std::uint32_t state = essential_unigrams.size() << term_count; + for (auto pos = 0U; pos < scores.size(); pos += 1) { + if (score > 0) { + state |= 1U << pos; + } + } + + auto next_idx = next_lookup[state]; + while (next_idx >= 0 && topk.would_enter(score + mus[state])) { + auto lookup_idx = next_idx - essential_unigrams.size(); + assert(lookup_idx >= 0 && lookup_idx < lookup_cursors.size()); + auto&& lookup_cursor = lookup_cursors[lookup_idx]; + lookup_cursor.advance_to_geq(docid); + if constexpr (not std::is_void_v) { + inspect->lookup(); + } + if (PISA_UNLIKELY(lookup_cursor.value() == docid)) { + score += lookup_cursor.payload(); + state |= (1U << next_idx); + } + state = (state & ((1U << term_count) - 1)) + ((next_idx + 1) << term_count); + next_idx = next_lookup[state]; + } + if constexpr (not std::is_void_v) { + if (topk.insert(score, docid)) { + inspect->insert(); + } + } else { + topk.insert(score, docid); + } + }); + return topk; +} + template struct BaseUnionLookupInspect { BaseUnionLookupInspect(Index const& index, Scorer scorer) diff --git a/script/cw09b-bp.sh b/script/cw09b-bp.sh new file mode 100644 index 000000000..03c9793ab --- /dev/null +++ b/script/cw09b-bp.sh @@ -0,0 +1,21 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/home/amallia/cw09b/CW09B.inv.bp" +FWD="/home/amallia/cw09b/CW09B.fwd.bp" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-bp/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-bp" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +#PAIRS="/home/michal/real.aol.top50k.jl" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair" 
+THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.test" +QUERY_LIMIT=1000 + +# 134384 149026 1648 128376 4 # UL +# 149026 149026 1648 109040 4 # 15542 LU +# 2261867 2308897 1648 94280 3 # 44935 MS +# T = 12.985799789428713 diff --git a/script/cw09b-bpq.sh b/script/cw09b-bpq.sh new file mode 100644 index 000000000..03c9793ab --- /dev/null +++ b/script/cw09b-bpq.sh @@ -0,0 +1,21 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/home/amallia/cw09b/CW09B.inv.bp" +FWD="/home/amallia/cw09b/CW09B.fwd.bp" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-bp/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-bp" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +#PAIRS="/home/michal/real.aol.top50k.jl" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.test" +QUERY_LIMIT=1000 + +# 134384 149026 1648 128376 4 # UL +# 149026 149026 1648 109040 4 # 15542 LU +# 2261867 2308897 1648 94280 3 # 44935 MS +# T = 12.985799789428713 diff --git a/script/cw09b-est.sh b/script/cw09b-est.sh index bae1ea935..0bffb773d 100644 --- a/script/cw09b-est.sh +++ b/script/cw09b-est.sh @@ -9,7 +9,7 @@ QUERIES="/home/michal/05.clean.shuf.test" K=1000 OUTPUT_DIR="/data/michal/intersect/cw09b-est-lm" FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" -PAIRS="/home/michal/real.aol.top6m.jl" +PAIRS="/home/michal/real.aol.top50k.jl" PAIR_INDEX_BASENAME="${BASENAME}-pair" THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.test" QUERY_LIMIT=100 diff --git a/script/cw09b-url.sh b/script/cw09b-url.sh new file mode 100644 index 000000000..4742252ce --- /dev/null +++ b/script/cw09b-url.sh @@ -0,0 
+1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW09B/CW09B.url.inv" +FWD="/data/CW09B/CW09B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-url" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.test" +QUERY_LIMIT=5000 diff --git a/script/cw09b.sh b/script/cw09b.sh index 07253225b..94b047efe 100644 --- a/script/cw09b.sh +++ b/script/cw09b.sh @@ -25,19 +25,20 @@ set -x mkdir -p ${OUTPUT_DIR} ## Compress an inverted index in `binary_freq_collection` format. -#${PISA_BIN}/compress -i ${BINARY_FREQ_COLL} --fwd ${FWD} -o ${BASENAME} -j ${THREADS} -e ${ENCODING} +${PISA_BIN}/compress -i ${BINARY_FREQ_COLL} --fwd ${FWD} -o ${BASENAME} -j ${THREADS} -e ${ENCODING} # This will produce both quantized scores and max scores (both quantized and not). -#${PISA_BIN}/score -i "${BASENAME}.yml" -j ${THREADS} +${PISA_BIN}/score -i "${BASENAME}.yml" -j ${THREADS} # This will produce both quantized scores and max scores (both quantized and not). -#${PISA_BIN}/bmscore -i "${BASENAME}.yml" -j ${THREADS} --block-size 128 +${PISA_BIN}/bmscore -i "${BASENAME}.yml" -j ${THREADS} --block-size 128 +${PISA_BIN}/bmscore -i "${BASENAME}.yml" -j ${THREADS} --variable-blocks 22.5 --clone ${BASENAME}-var # Filter out queries witout existing terms. 
-#paste -d: ${QUERIES} ${THRESHOLDS} \ -# | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ -# | head -${QUERY_LIMIT} \ -# | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml > ${FILTERED_QUERIES} +paste -d: ${QUERIES} ${THRESHOLDS} \ + | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ + | head -${QUERY_LIMIT} \ + | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml --min 2 --max 8 > ${FILTERED_QUERIES} # This will produce both quantized scores and max scores (both quantized and not). ${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${PAIRS} --clone ${PAIR_INDEX_BASENAME} -j 4 @@ -50,49 +51,56 @@ ${PISA_BIN}/intersection -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} # Select unigrams ${INTERSECT_BIN} -m unigram ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.1 ${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.2 +${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl --time > ${OUTPUT_DIR}/selections.2.time ${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl \ --scale 1.5 > ${OUTPUT_DIR}/selections.2.scaled-1.5 ${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl \ --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 # Run benchmarks -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm wand --safe > ${OUTPUT_DIR}/bench.wand -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm bmw --safe > ${OUTPUT_DIR}/bench.bmw -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm maxscore --safe > ${OUTPUT_DIR}/bench.maxscore -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm bmw --safe \ -# > ${OUTPUT_DIR}/bench.bmw-threshold +${PISA_BIN}/query -i 
"${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm wand --safe > ${OUTPUT_DIR}/bench.wand +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm bmw --safe > ${OUTPUT_DIR}/bench.bmw +${PISA_BIN}/query -i "${BASENAME}-var.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm bmw --safe > ${OUTPUT_DIR}/bench.vbmw +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm maxscore --safe > ${OUTPUT_DIR}/bench.maxscore +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm bmw --safe \ + > ${OUTPUT_DIR}/bench.bmw-threshold +${PISA_BIN}/query -i "${BASENAME}-var.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm bmw --safe \ + > ${OUTPUT_DIR}/bench.vbmw-threshold ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore --safe \ > ${OUTPUT_DIR}/bench.maxscore-threshold -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup --safe \ -# > ${OUTPUT_DIR}/bench.maxscore-union-lookup +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup --safe \ + > ${OUTPUT_DIR}/bench.maxscore-union-lookup ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup --safe \ > ${OUTPUT_DIR}/bench.unigram-union-lookup -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup --safe \ -# > ${OUTPUT_DIR}/bench.union-lookup +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup --safe \ + > ${OUTPUT_DIR}/bench.union-lookup ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q 
${OUTPUT_DIR}/selections.2 --benchmark --algorithm lookup-union --safe \ > ${OUTPUT_DIR}/bench.lookup-union -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --safe \ -# --benchmark --algorithm lookup-union \ -# > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ -# --benchmark --algorithm lookup-union \ -# > ${OUTPUT_DIR}/bench.lookup-union.scaled-2 +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --safe \ + --benchmark --algorithm lookup-union \ + > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ + --benchmark --algorithm lookup-union \ + > ${OUTPUT_DIR}/bench.lookup-union.scaled-2 # Analyze -# ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore -# ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore \ -# > ${OUTPUT_DIR}/stats.maxscore-threshold -# ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore-union-lookup \ -# > ${OUTPUT_DIR}/stats.maxscore-union-lookup +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore \ + > ${OUTPUT_DIR}/stats.maxscore-threshold +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore-union-lookup \ + > ${OUTPUT_DIR}/stats.maxscore-union-lookup ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --inspect --algorithm unigram-union-lookup \ > 
${OUTPUT_DIR}/stats.unigram-union-lookup -# ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup \ -# > ${OUTPUT_DIR}/stats.union-lookup +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup \ + > ${OUTPUT_DIR}/stats.union-lookup ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm lookup-union \ > ${OUTPUT_DIR}/stats.lookup-union -# ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ -# --inspect --algorithm lookup-union \ -# > ${OUTPUT_DIR}/stats.lookup-union.scaled-1.5 +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ + --inspect --algorithm lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union.scaled-1.5 +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 \ + --inspect --algorithm lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union.scaled-2 # Evaluate #${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" diff --git a/script/cw12-url.sh b/script/cw12-url.sh new file mode 100644 index 000000000..03767dd8d --- /dev/null +++ b/script/cw12-url.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW12B/CW12B.url.inv" +FWD="/data/CW12B/CW12B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw12b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw12b-url" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.test" 
+QUERY_LIMIT=5000 diff --git a/src/v1/index_metadata.cpp b/src/v1/index_metadata.cpp index 08b6863a8..a4c07982e 100644 --- a/src/v1/index_metadata.cpp +++ b/src/v1/index_metadata.cpp @@ -186,11 +186,19 @@ void IndexMetadata::write(std::string const& file) const fout << root; } -[[nodiscard]] auto IndexMetadata::query_parser() const -> std::function +[[nodiscard]] auto IndexMetadata::query_parser(tl::optional const& stop_words) const + -> std::function { if (term_lexicon) { - auto term_processor = - ::pisa::TermProcessor(*term_lexicon, {}, [&]() -> std::optional { + auto term_processor = ::pisa::TermProcessor( + *term_lexicon, + [&]() -> std::optional { + if (stop_words) { + return *stop_words; + } + return std::nullopt; + }(), + [&]() -> std::optional { if (stemmer) { return *stemmer; } diff --git a/src/v1/score_index.cpp b/src/v1/score_index.cpp index fd5ac039f..161d2d26e 100644 --- a/src/v1/score_index.cpp +++ b/src/v1/score_index.cpp @@ -93,11 +93,14 @@ auto score_index(IndexMetadata meta, std::size_t threads) -> IndexMetadata } // TODO: Use multiple threads -auto bm_score_index(IndexMetadata meta, BlockType block_type, std::size_t threads) -> IndexMetadata +auto bm_score_index(IndexMetadata meta, + BlockType block_type, + tl::optional const& clone_path, + std::size_t threads) -> IndexMetadata { auto run = index_runner(meta); - auto const& index_basename = meta.get_basename(); - auto prefix = fmt::format("{}.bm25_block_max", meta.get_basename()); + std::string index_basename = clone_path.value_or(std::string(meta.get_basename())); + auto prefix = fmt::format("{}.bm25_block_max", index_basename); UnigramFilePaths paths{ .documents = PostingFilePaths{.postings = fmt::format("{}_documents", prefix), .offsets = fmt::format("{}_document_offsets", prefix)}, @@ -152,7 +155,11 @@ auto bm_score_index(IndexMetadata meta, BlockType block_type, std::size_t thread write_span(gsl::make_span(score_builder.offsets()), paths.payloads.offsets); }); meta.block_max_scores["bm25"] = 
paths; - meta.update(); + if (clone_path) { + meta.write(append_extension(clone_path.value())); + } else { + meta.update(); + } return meta; } diff --git a/test/v1/index_fixture.hpp b/test/v1/index_fixture.hpp index 17519bdc2..2e22e15a5 100644 --- a/test/v1/index_fixture.hpp +++ b/test/v1/index_fixture.hpp @@ -75,7 +75,7 @@ struct IndexFixture { meta = v1::score_index(v1::IndexMetadata::from_file(yml), 1); } if (bm_score) { - meta = v1::bm_score_index(meta, pisa::v1::FixedBlock{5}, 1); + meta = v1::bm_score_index(meta, pisa::v1::FixedBlock{5}, tl::nullopt, 1); } if (build_bigrams) { v1::build_pair_index( @@ -85,7 +85,7 @@ struct IndexFixture { void rebuild_bm_scores(pisa::v1::BlockType block_type) { - v1::bm_score_index(meta(), block_type, 1); + v1::bm_score_index(meta(), block_type, tl::nullopt, 1); } [[nodiscard]] auto const& tmpdir() const { return *m_tmpdir; } diff --git a/test/v1/test_v1_queries.cpp b/test/v1/test_v1_queries.cpp index e3aa1eff5..007863454 100644 --- a/test/v1/test_v1_queries.cpp +++ b/test/v1/test_v1_queries.cpp @@ -148,18 +148,19 @@ TEMPLATE_TEST_CASE("Query", Index, RawCursor>>::get(); TestType fixture; auto input_data = GENERATE(table({ - {"daat_or", false, false}, - {"maxscore", false, false}, - {"maxscore", true, false}, - {"wand", false, false}, - {"wand", true, false}, - {"bmw", false, false}, - {"bmw", true, false}, - {"bmw", false, true}, - {"bmw", true, true}, - {"maxscore_union_lookup", true, false}, - {"unigram_union_lookup", true, false}, - {"union_lookup", true, false}, + //{"daat_or", false, false}, + //{"maxscore", false, false}, + //{"maxscore", true, false}, + //{"wand", false, false}, + //{"wand", true, false}, + //{"bmw", false, false}, + //{"bmw", true, false}, + //{"bmw", false, true}, + //{"bmw", true, true}, + //{"maxscore_union_lookup", true, false}, + //{"unigram_union_lookup", true, false}, + //{"union_lookup", true, false}, + {"union_lookup_plus", true, false}, {"lookup_union", true, false}, })); std::string 
algorithm = std::get<0>(input_data); @@ -200,6 +201,12 @@ TEMPLATE_TEST_CASE("Query", } return union_lookup(query, index, ::pisa::topk_queue(10), scorer); } + if (name == "union_lookup_plus") { + if (query.get_term_ids().size() > 8) { + return maxscore_union_lookup(query, index, ::pisa::topk_queue(10), scorer); + } + return union_lookup_plus(query, index, ::pisa::topk_queue(10), scorer); + } if (name == "lookup_union") { return lookup_union(query, index, ::pisa::topk_queue(10), scorer); } @@ -210,7 +217,8 @@ TEMPLATE_TEST_CASE("Query", pisa::v1::read_intersections(PISA_SOURCE_DIR "/test/test_data/top10_selections"); for (auto& query : test_queries()) { heap.clear(); - if (algorithm == "union_lookup" || algorithm == "lookup_union") { + if (algorithm == "union_lookup" || algorithm == "union_lookup_plus" + || algorithm == "lookup_union") { query.selections(gsl::make_span(intersections[idx])); } diff --git a/v1/app.hpp b/v1/app.hpp index 958ea2b36..0824fa76c 100644 --- a/v1/app.hpp +++ b/v1/app.hpp @@ -50,6 +50,8 @@ namespace arg { app->add_flag("--force-parse", m_force_parse, "Force parsing of query string even ifterm IDs already available"); + app->add_option( + "--stopwords", m_stop_words, "List of blacklisted stop words to filter out"); } [[nodiscard]] auto query_file() -> tl::optional @@ -63,7 +65,7 @@ namespace arg { [[nodiscard]] auto queries(v1::IndexMetadata const& meta) const -> std::vector { std::vector queries; - auto parser = meta.query_parser(); + auto parser = meta.query_parser(m_stop_words); auto parse_line = [&](auto&& line) { auto query = [&line, this]() { if (m_query_input_format == "jl") { @@ -100,7 +102,7 @@ namespace arg { return ranges::views::transform(lines, [force_parse = m_force_parse, k = m_k, - parser = meta.query_parser(), + parser = meta.query_parser(m_stop_words), qfmt = m_query_input_format](auto&& line) { auto query = [&]() { if (qfmt == "jl") { @@ -131,6 +133,7 @@ namespace arg { std::string m_query_input_format = "jl"; int m_k = 
DefaultK; bool m_force_parse{false}; + tl::optional m_stop_words{tl::nullopt}; }; struct Benchmark { diff --git a/v1/bmscore.cpp b/v1/bmscore.cpp index 9f5c868f7..78e184969 100644 --- a/v1/bmscore.cpp +++ b/v1/bmscore.cpp @@ -19,10 +19,16 @@ int main(int argc, char** argv) std::size_t block_size = 128; std::size_t threads = std::thread::hardware_concurrency(); std::optional lambda{}; + tl::optional clone_path{}; App app{"Constructs block-max score lists for v1 index."}; app.add_option("--block-size", block_size, "The size of a block for max scores", true); app.add_option("--variable-blocks", lambda, "The size of a block for max scores", false); + app.add_option( + "--clone", + clone_path, + "Clone .yml metadata to another path, and then score (won't affect the initial index)", + false); CLI11_PARSE(app, argc, argv); auto block_type = [&]() -> BlockType { @@ -32,6 +38,6 @@ int main(int argc, char** argv) return FixedBlock{block_size}; }(); - pisa::v1::bm_score_index(app.index_metadata(), block_type, app.threads()); + pisa::v1::bm_score_index(app.index_metadata(), block_type, clone_path, app.threads()); return 0; } diff --git a/v1/filter_queries.cpp b/v1/filter_queries.cpp index 4b00dad6f..444ee43ad 100644 --- a/v1/filter_queries.cpp +++ b/v1/filter_queries.cpp @@ -15,15 +15,17 @@ int main(int argc, char** argv) spdlog::set_default_logger(spdlog::stderr_color_mt("")); std::size_t min_query_len = 1; + std::size_t max_query_len = std::numeric_limits::max(); pisa::App> app( "Filters out empty queries against a v1 index."); app.add_option("--min", min_query_len, "Minimum query legth to consider"); + app.add_option("--max", max_query_len, "Maximum query legth to consider"); CLI11_PARSE(app, argc, argv); auto meta = app.index_metadata(); for (auto&& query : app.query_range(meta)) { - if (query.get_term_ids().size() >= min_query_len) { + if (auto len = query.get_term_ids().size(); len >= min_query_len && len <= max_query_len) { std::cout << *query.to_json() << '\n'; } } 
From fde4d9843363444b55ba496554fe5c26aeaa4820 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Sat, 25 Jan 2020 16:00:52 +0000 Subject: [PATCH 51/56] Scripts and tweaks --- include/pisa/v1/union_lookup.hpp | 82 ++++++++++++--------- script/cw09b-url-10.sh | 15 ++++ script/cw09b-url-100.sh | 15 ++++ script/cw09b-url-10000.sh | 15 ++++ script/cw09b-url-bi-trec06.sh | 15 ++++ script/cw09b-url-bi.sh | 15 ++++ script/cw09b-url-trec06-2.sh | 15 ++++ script/cw09b-url-trec06.sh | 15 ++++ script/cw09b-url.sh | 2 +- script/cw09b.sh | 109 +++++++++++++++------------ script/cw12-url-bi-trec06.sh | 15 ++++ script/cw12-url-bi.sh | 15 ++++ script/cw12-url-trec06.sh | 15 ++++ script/cw12-url.sh | 4 +- test/v1/test_v1_queries.cpp | 31 ++++---- v1/CMakeLists.txt | 3 + v1/id_to_term.cpp | 44 +++++++++++ v1/intersection.cpp | 123 +++++++++++++++++++++++++------ v1/query.cpp | 26 +++++++ v1/term_to_id.cpp | 44 +++++++++++ 20 files changed, 499 insertions(+), 119 deletions(-) create mode 100644 script/cw09b-url-10.sh create mode 100644 script/cw09b-url-100.sh create mode 100644 script/cw09b-url-10000.sh create mode 100644 script/cw09b-url-bi-trec06.sh create mode 100644 script/cw09b-url-bi.sh create mode 100644 script/cw09b-url-trec06-2.sh create mode 100644 script/cw09b-url-trec06.sh create mode 100644 script/cw12-url-bi-trec06.sh create mode 100644 script/cw12-url-bi.sh create mode 100644 script/cw12-url-trec06.sh create mode 100644 v1/id_to_term.cpp create mode 100644 v1/term_to_id.cpp diff --git a/include/pisa/v1/union_lookup.hpp b/include/pisa/v1/union_lookup.hpp index 55a93a19c..b3f399aa5 100644 --- a/include/pisa/v1/union_lookup.hpp +++ b/include/pisa/v1/union_lookup.hpp @@ -536,7 +536,6 @@ auto union_lookup(Query const& query, std::back_inserter(essential_unigram_cursors), [&](auto term) { return index.scored_cursor(term, scorer); }); - /// TODO: remap according to max score instead of term sorted order std::vector unigram_query_positions(essential_unigrams.size()); for 
(std::size_t unigram_position = 0; unigram_position < essential_unigrams.size(); unigram_position += 1) { @@ -655,7 +654,7 @@ auto union_lookup(Query const& query, inline auto precompute_next_lookup(std::size_t essential_count, std::size_t non_essential_count, - std::vector> essential_bigrams) + std::vector> const& essential_bigrams) { runtime_assert(essential_count + non_essential_count <= 8).or_throw("Must be shorter than 9"); std::uint32_t term_count = essential_count + non_essential_count; @@ -719,9 +718,10 @@ auto union_lookup_plus(Query const& query, inspect->essential(essential_unigrams.size() + essential_bigrams.size()); } - auto essential_unigram_cursors = index.cursors(term_ids, [&](auto&& index, auto term) { - return label(index.scored_cursor(term, scorer), term); - }); + auto essential_unigram_cursors = + index.cursors(essential_unigrams, [&](auto&& index, auto term) { + return label(index.scored_cursor(term, scorer), term); + }); auto lookup_cursors = index.cursors(gsl::make_span(non_essential_terms), [&](auto&& index, auto term) { @@ -731,15 +731,6 @@ auto union_lookup_plus(Query const& query, return lhs.max_score() > rhs.max_score(); }); - auto merged_unigrams = v1::union_merge( - essential_unigram_cursors, initial_payload, [&](auto& acc, auto& cursor, auto idx) { - if constexpr (not std::is_void_v) { - inspect->posting(); - } - acc[idx] = cursor.payload(); - return acc; - }); - auto term_to_position = [&] { std::unordered_map term_to_position; std::uint32_t position = 0; @@ -752,6 +743,15 @@ auto union_lookup_plus(Query const& query, return term_to_position; }(); + auto merged_unigrams = v1::union_merge( + essential_unigram_cursors, initial_payload, [&](auto& acc, auto& cursor, auto idx) { + if constexpr (not std::is_void_v) { + inspect->posting(); + } + acc[idx] = cursor.payload(); + return acc; + }); + std::vector essential_bigram_cursors; for (auto [left, right] : essential_bigrams) { auto cursor = index.scored_bigram_cursor(left, right, scorer); 
@@ -796,34 +796,27 @@ auto union_lookup_plus(Query const& query, auto next_lookup = precompute_next_lookup(essential_unigrams.size(), lookup_cursors.size(), [&] { - std::unordered_map> bigram_map; - for (auto [left, right] : essential_bigrams) { - bigram_map[left].push_back(right); - bigram_map[right].push_back(left); - } - std::uint32_t position = essential_unigrams.size(); std::vector> mapping(term_ids.size()); - for (auto&& cursor : lookup_cursors) { - for (auto b : bigram_map[cursor.label()]) { - mapping[position].push_back(term_to_position[b]); - } - position++; + for (auto&& cursor : essential_bigram_cursors) { + auto [left, right] = cursor.label(); + mapping[left].push_back(right); + mapping[right].push_back(left); } return mapping; }()); auto mus = [&] { - std::size_t term_count = term_ids.size(); std::vector mus((term_count + 1) * (1U << term_count), 0.0); for (auto term_idx = term_count; term_idx + 1 >= 1; term_idx -= 1) { - for (std::uint32_t state = (1U << term_count) - 1; state + 1 >= 1; state -= 1) { - auto nt = next_lookup[(term_idx << term_count) + state]; + for (std::uint32_t j = (1U << term_count) - 1; j + 1 >= 1; j -= 1) { + auto state = (term_idx << term_count) + j; + auto nt = next_lookup[state]; if (nt == -1) { - mus[(term_idx << term_count) + state] = 0.0F; + mus[state] = 0.0F; } else { - mus[(term_idx << term_count) + state] = - std::max(lookup_cursors[term_idx - essential_unigrams.size()].max_score() - + mus[((term_idx + 1) << term_count) + (state | (1 << nt))], - mus[((term_idx + 1) << term_count) + state]); + auto a = lookup_cursors[nt - essential_unigrams.size()].max_score() + + mus[((nt + 1) << term_count) + (j | (1 << nt))]; + auto b = mus[((term_idx + 1) << term_count) + j]; + mus[state] = std::max(a, b); } } } @@ -839,12 +832,13 @@ auto union_lookup_plus(Query const& query, auto score = std::accumulate(scores.begin(), scores.end(), 0.0F, std::plus{}); std::uint32_t state = essential_unigrams.size() << term_count; - for (auto pos = 
0U; pos < scores.size(); pos += 1) { - if (score > 0) { + for (auto pos = 0U; pos < term_count; pos += 1) { + if (scores[pos] > 0) { state |= 1U << pos; } } + assert(state >= 0 && state < next_lookup.size()); auto next_idx = next_lookup[state]; while (next_idx >= 0 && topk.would_enter(score + mus[state])) { auto lookup_idx = next_idx - essential_unigrams.size(); @@ -993,6 +987,24 @@ struct UnionLookupInspect : public BaseUnionLookupInspect { } }; +template +struct UnionLookupPlusInspect : public BaseUnionLookupInspect { + UnionLookupPlusInspect(Index const& index, Scorer scorer) + : BaseUnionLookupInspect(index, std::move(scorer)) + { + } + void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override + { + if (query.selections()->bigrams.empty()) { + unigram_union_lookup(query, index, std::move(topk), scorer, this); + } else if (query.get_term_ids().size() > 8) { + maxscore_union_lookup(query, index, std::move(topk), scorer, this); + } else { + union_lookup_plus(query, index, std::move(topk), scorer, this); + } + } +}; + template struct LookupUnionInspector : public BaseUnionLookupInspect { LookupUnionInspector(Index const& index, Scorer scorer) diff --git a/script/cw09b-url-10.sh b/script/cw09b-url-10.sh new file mode 100644 index 000000000..bfda430ba --- /dev/null +++ b/script/cw09b-url-10.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW09B/CW09B.url.inv" +FWD="/data/CW09B/CW09B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.test" +K=10 +OUTPUT_DIR="/data/michal/intersect/cw09b-url-10-bi" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top2.bm25.05.clean.shuf.test" 
+QUERY_LIMIT=1000 diff --git a/script/cw09b-url-100.sh b/script/cw09b-url-100.sh new file mode 100644 index 000000000..43c4d7d94 --- /dev/null +++ b/script/cw09b-url-100.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW09B/CW09B.url.inv" +FWD="/data/CW09B/CW09B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.test" +K=100 +OUTPUT_DIR="/data/michal/intersect/cw09b-url-100-bi" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top5.bm25.05.clean.shuf.test" +QUERY_LIMIT=1000 diff --git a/script/cw09b-url-10000.sh b/script/cw09b-url-10000.sh new file mode 100644 index 000000000..c9f2613d8 --- /dev/null +++ b/script/cw09b-url-10000.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW09B/CW09B.url.inv" +FWD="/data/CW09B/CW09B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.test" +K=10000 +OUTPUT_DIR="/data/michal/intersect/cw09b-url-10k-bi" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top130.bm25.05.clean.shuf.test" +QUERY_LIMIT=1000 diff --git a/script/cw09b-url-bi-trec06.sh b/script/cw09b-url-bi-trec06.sh new file mode 100644 index 000000000..ba686627d --- /dev/null +++ b/script/cw09b-url-bi-trec06.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" 
+BINARY_FREQ_COLL="/data/CW09B/CW09B.url.inv" +FWD="/data/CW09B/CW09B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/06.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-url-bi-trec06" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair-trec06-2" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.06.clean.shuf.test" +QUERY_LIMIT=1000 diff --git a/script/cw09b-url-bi.sh b/script/cw09b-url-bi.sh new file mode 100644 index 000000000..39dab34c5 --- /dev/null +++ b/script/cw09b-url-bi.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW09B/CW09B.url.inv" +FWD="/data/CW09B/CW09B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-url-bi" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS="/home/michal/real.aol.top100k.jl" +PAIR_INDEX_BASENAME="/data/michal/work/v1/cw09b-url/cw09b-simdbp-pair" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.test" +QUERY_LIMIT=5000 diff --git a/script/cw09b-url-trec06-2.sh b/script/cw09b-url-trec06-2.sh new file mode 100644 index 000000000..e9f39d84c --- /dev/null +++ b/script/cw09b-url-trec06-2.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW09B/CW09B.url.inv" +FWD="/data/CW09B/CW09B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/06.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-url-trec06-2" 
+FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair-trec06-2" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.06.clean.shuf.test" +QUERY_LIMIT=1000 diff --git a/script/cw09b-url-trec06.sh b/script/cw09b-url-trec06.sh new file mode 100644 index 000000000..e3a0ed67c --- /dev/null +++ b/script/cw09b-url-trec06.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW09B/CW09B.url.inv" +FWD="/data/CW09B/CW09B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/06.clean.shuf.val" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw09b-url-trec06" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair-trec06" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.06.clean.shuf.val" +QUERY_LIMIT=5000 diff --git a/script/cw09b-url.sh b/script/cw09b-url.sh index 4742252ce..7dbc48ea3 100644 --- a/script/cw09b-url.sh +++ b/script/cw09b-url.sh @@ -10,6 +10,6 @@ K=1000 OUTPUT_DIR="/data/michal/intersect/cw09b-url" FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" PAIRS=${FILTERED_QUERIES} -PAIR_INDEX_BASENAME="${BASENAME}-pair" +PAIR_INDEX_BASENAME="${BASENAME}-pair-2" THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.test" QUERY_LIMIT=5000 diff --git a/script/cw09b.sh b/script/cw09b.sh index 94b047efe..b2f55f9bf 100644 --- a/script/cw09b.sh +++ b/script/cw09b.sh @@ -25,14 +25,14 @@ set -x mkdir -p ${OUTPUT_DIR} ## Compress an inverted index in `binary_freq_collection` format. 
-${PISA_BIN}/compress -i ${BINARY_FREQ_COLL} --fwd ${FWD} -o ${BASENAME} -j ${THREADS} -e ${ENCODING} +#${PISA_BIN}/compress -i ${BINARY_FREQ_COLL} --fwd ${FWD} -o ${BASENAME} -j ${THREADS} -e ${ENCODING} # This will produce both quantized scores and max scores (both quantized and not). -${PISA_BIN}/score -i "${BASENAME}.yml" -j ${THREADS} +#${PISA_BIN}/score -i "${BASENAME}.yml" -j ${THREADS} # This will produce both quantized scores and max scores (both quantized and not). -${PISA_BIN}/bmscore -i "${BASENAME}.yml" -j ${THREADS} --block-size 128 -${PISA_BIN}/bmscore -i "${BASENAME}.yml" -j ${THREADS} --variable-blocks 22.5 --clone ${BASENAME}-var +#${PISA_BIN}/bmscore -i "${BASENAME}.yml" -j ${THREADS} --block-size 128 +#${PISA_BIN}/bmscore -i "${BASENAME}.yml" -j ${THREADS} --variable-blocks 22.5 --clone ${BASENAME}-var # Filter out queries witout existing terms. paste -d: ${QUERIES} ${THRESHOLDS} \ @@ -41,64 +41,79 @@ paste -d: ${QUERIES} ${THRESHOLDS} \ | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml --min 2 --max 8 > ${FILTERED_QUERIES} # This will produce both quantized scores and max scores (both quantized and not). 
-${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${PAIRS} --clone ${PAIR_INDEX_BASENAME} -j 4 +#${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${PAIRS} --clone ${PAIR_INDEX_BASENAME} -j 4 # Extract intersections -${PISA_BIN}/intersection -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations --existing \ +#${PISA_BIN}/intersection -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations --existing \ +# | grep -v "\[warning\]" \ +# > ${OUTPUT_DIR}/intersections.jl +${PISA_BIN}/intersection -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations \ + --in /home/michal/real.aol.top1m.tsv \ | grep -v "\[warning\]" \ > ${OUTPUT_DIR}/intersections.jl # Select unigrams ${INTERSECT_BIN} -m unigram ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.1 +${INTERSECT_BIN} -m unigram ${OUTPUT_DIR}/intersections.jl --time > ${OUTPUT_DIR}/selections.1.time ${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.2 ${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl --time > ${OUTPUT_DIR}/selections.2.time -${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl \ - --scale 1.5 > ${OUTPUT_DIR}/selections.2.scaled-1.5 -${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl \ - --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 +${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 +${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl --scale 2 --time > ${OUTPUT_DIR}/selections.2.scaled-2.time +${INTERSECT_BIN} -m exact ${OUTPUT_DIR}/intersections.jl --scale 2 > ${OUTPUT_DIR}/selections.2.exact # Run benchmarks -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm wand --safe > ${OUTPUT_DIR}/bench.wand -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm bmw --safe > ${OUTPUT_DIR}/bench.bmw 
-${PISA_BIN}/query -i "${BASENAME}-var.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm bmw --safe > ${OUTPUT_DIR}/bench.vbmw -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm maxscore --safe > ${OUTPUT_DIR}/bench.maxscore -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm bmw --safe \ - > ${OUTPUT_DIR}/bench.bmw-threshold -${PISA_BIN}/query -i "${BASENAME}-var.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm bmw --safe \ - > ${OUTPUT_DIR}/bench.vbmw-threshold -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore --safe \ - > ${OUTPUT_DIR}/bench.maxscore-threshold -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup --safe \ - > ${OUTPUT_DIR}/bench.maxscore-union-lookup -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup --safe \ - > ${OUTPUT_DIR}/bench.unigram-union-lookup -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup --safe \ - > ${OUTPUT_DIR}/bench.union-lookup -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm lookup-union --safe \ - > ${OUTPUT_DIR}/bench.lookup-union -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --safe \ - --benchmark --algorithm lookup-union \ - > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm wand --safe > ${OUTPUT_DIR}/bench.wand +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm bmw --safe > ${OUTPUT_DIR}/bench.bmw +#${PISA_BIN}/query 
-i "${BASENAME}-var.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm bmw --safe > ${OUTPUT_DIR}/bench.vbmw +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm bmw --safe \ +# > ${OUTPUT_DIR}/bench.bmw-threshold +#${PISA_BIN}/query -i "${BASENAME}-var.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm bmw --safe \ +# > ${OUTPUT_DIR}/bench.vbmw-threshold + +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm maxscore --safe -k ${K} > ${OUTPUT_DIR}/bench.maxscore +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore --safe -k ${K} \ +# > ${OUTPUT_DIR}/bench.maxscore-threshold +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup --safe \ +# > ${OUTPUT_DIR}/bench.maxscore-union-lookup +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup --safe -k ${K} \ +# > ${OUTPUT_DIR}/bench.unigram-union-lookup +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup --safe \ +# > ${OUTPUT_DIR}/bench.union-lookup +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm lookup-union --safe \ +# > ${OUTPUT_DIR}/bench.lookup-union +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup-plus --safe \ +# > ${OUTPUT_DIR}/bench.plus +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --safe \ +# --benchmark --algorithm lookup-union \ +# > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ +# --benchmark --algorithm union-lookup-plus -k ${K} \ +# > 
${OUTPUT_DIR}/bench.plus.scaled-2 ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ - --benchmark --algorithm lookup-union \ + --benchmark --algorithm lookup-union -k ${K} \ > ${OUTPUT_DIR}/bench.lookup-union.scaled-2 # Analyze -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore > ${OUTPUT_DIR}/stats.maxscore -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore \ - > ${OUTPUT_DIR}/stats.maxscore-threshold -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore-union-lookup \ - > ${OUTPUT_DIR}/stats.maxscore-union-lookup -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --inspect --algorithm unigram-union-lookup \ - > ${OUTPUT_DIR}/stats.unigram-union-lookup -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup \ - > ${OUTPUT_DIR}/stats.union-lookup -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm lookup-union \ - > ${OUTPUT_DIR}/stats.lookup-union -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ - --inspect --algorithm lookup-union \ - > ${OUTPUT_DIR}/stats.lookup-union.scaled-1.5 -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 \ +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore -k ${K} > ${OUTPUT_DIR}/stats.maxscore +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore -k ${K} \ +# > ${OUTPUT_DIR}/stats.maxscore-threshold +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore-union-lookup \ +# > 
${OUTPUT_DIR}/stats.maxscore-union-lookup +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --inspect --algorithm unigram-union-lookup -k ${K} \ +# > ${OUTPUT_DIR}/stats.unigram-union-lookup +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup \ +# > ${OUTPUT_DIR}/stats.union-lookup +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup-plus \ +# > ${OUTPUT_DIR}/stats.plus +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm lookup-union \ +# > ${OUTPUT_DIR}/stats.lookup-union +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ +# --inspect --algorithm lookup-union \ +# > ${OUTPUT_DIR}/stats.lookup-union.scaled-1.5 +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 -k ${K} \ +# --inspect --algorithm union-lookup-plus \ +# > ${OUTPUT_DIR}/stats.plus.scaled-2 +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 -k ${K} \ --inspect --algorithm lookup-union \ > ${OUTPUT_DIR}/stats.lookup-union.scaled-2 diff --git a/script/cw12-url-bi-trec06.sh b/script/cw12-url-bi-trec06.sh new file mode 100644 index 000000000..23489dacd --- /dev/null +++ b/script/cw12-url-bi-trec06.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW12B/CW12B.url.inv" +FWD="/data/CW12B/CW12B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw12b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/06.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw12b-url-bi-trec06" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair-trec06" 
+THRESHOLDS="/home/michal/biscorer/data/thresholds/cw12b/thresholds.cw12b.0_01.top20.bm25.06.clean.shuf.test" +QUERY_LIMIT=1000 diff --git a/script/cw12-url-bi.sh b/script/cw12-url-bi.sh new file mode 100644 index 000000000..7604b025f --- /dev/null +++ b/script/cw12-url-bi.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW12B/CW12B.url.inv" +FWD="/data/CW12B/CW12B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw12b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/05.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw12b-url-bi" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS="/home/michal/real.aol.top100k.jl" +PAIR_INDEX_BASENAME="/data/michal/work/v1/cw12b-url/cw09b-simdbp-pair" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw12b/thresholds.cw12b.0_01.top20.bm25.05.clean.shuf.test" +QUERY_LIMIT=5000 diff --git a/script/cw12-url-trec06.sh b/script/cw12-url-trec06.sh new file mode 100644 index 000000000..04eee7b25 --- /dev/null +++ b/script/cw12-url-trec06.sh @@ -0,0 +1,15 @@ +PISA_BIN="/home/michal/pisa/build/bin" +INTERSECT_BIN="/home/michal/intersect/target/release/intersect" +BINARY_FREQ_COLL="/data/CW12B/CW12B.url.inv" +FWD="/data/CW12B/CW12B.url.fwd" +ENCODING="simdbp" # v1 +BASENAME="/data/michal/work/v1/cw12b-url/cw09b-${ENCODING}" +THREADS=4 +QUERIES="/home/michal/06.clean.shuf.test" +K=1000 +OUTPUT_DIR="/data/michal/intersect/cw12b-url-trec06" +FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" +PAIRS=${FILTERED_QUERIES} +PAIR_INDEX_BASENAME="${BASENAME}-pair-trec06" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw12b/thresholds.cw12b.0_01.top20.bm25.06.clean.shuf.test" +QUERY_LIMIT=1000 diff --git a/script/cw12-url.sh b/script/cw12-url.sh index 03767dd8d..e0a5e20fe 100644 --- a/script/cw12-url.sh +++ b/script/cw12-url.sh @@ -10,6 +10,6 @@ K=1000 
OUTPUT_DIR="/data/michal/intersect/cw12b-url" FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" PAIRS=${FILTERED_QUERIES} -PAIR_INDEX_BASENAME="${BASENAME}-pair" -THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.test" +PAIR_INDEX_BASENAME="${BASENAME}-pair-trec05" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw12b/thresholds.cw12b.0_01.top20.bm25.05.clean.shuf.test" QUERY_LIMIT=5000 diff --git a/test/v1/test_v1_queries.cpp b/test/v1/test_v1_queries.cpp index 007863454..b68246485 100644 --- a/test/v1/test_v1_queries.cpp +++ b/test/v1/test_v1_queries.cpp @@ -148,18 +148,18 @@ TEMPLATE_TEST_CASE("Query", Index, RawCursor>>::get(); TestType fixture; auto input_data = GENERATE(table({ - //{"daat_or", false, false}, - //{"maxscore", false, false}, - //{"maxscore", true, false}, - //{"wand", false, false}, - //{"wand", true, false}, - //{"bmw", false, false}, - //{"bmw", true, false}, - //{"bmw", false, true}, - //{"bmw", true, true}, - //{"maxscore_union_lookup", true, false}, - //{"unigram_union_lookup", true, false}, - //{"union_lookup", true, false}, + {"daat_or", false, false}, + {"maxscore", false, false}, + {"maxscore", true, false}, + {"wand", false, false}, + {"wand", true, false}, + {"bmw", false, false}, + {"bmw", true, false}, + {"bmw", false, true}, + {"bmw", true, true}, + {"maxscore_union_lookup", true, false}, + {"unigram_union_lookup", true, false}, + {"union_lookup", true, false}, {"union_lookup_plus", true, false}, {"lookup_union", true, false}, })); @@ -226,6 +226,11 @@ TEMPLATE_TEST_CASE("Query", CAPTURE(idx); CAPTURE(intersections[idx]); + // if (query.get_selections().unigrams.empty() or idx < 8) { + // idx += 1; + // continue; + //} + or_q( make_scored_cursors(data->v0_index, ::pisa::bm25<::pisa::wand_data<::pisa::wand_data_raw>>(data->wdata), @@ -261,7 +266,7 @@ TEMPLATE_TEST_CASE("Query", expected.resize(on_the_fly.size()); std::sort(expected.begin(), expected.end(), 
approximate_order); - // if (algorithm == "bmw") { + // if (algorithm == "union_lookup_plus") { // for (size_t i = 0; i < on_the_fly.size(); ++i) { // std::cerr << fmt::format("{} {} -- {} {}\n", // on_the_fly[i].second, diff --git a/v1/CMakeLists.txt b/v1/CMakeLists.txt index a96dcc423..6bb640447 100644 --- a/v1/CMakeLists.txt +++ b/v1/CMakeLists.txt @@ -33,3 +33,6 @@ target_link_libraries(count-postings pisa CLI11) add_executable(stats stats.cpp) target_link_libraries(stats pisa CLI11) + +add_executable(id-to-term id_to_term.cpp) +target_link_libraries(id-to-term pisa CLI11) diff --git a/v1/id_to_term.cpp b/v1/id_to_term.cpp new file mode 100644 index 000000000..a13bf36e5 --- /dev/null +++ b/v1/id_to_term.cpp @@ -0,0 +1,44 @@ +#include +#include + +#include +#include +#include + +#include "app.hpp" +#include "io.hpp" +#include "payload_vector.hpp" + +namespace arg = pisa::arg; +using pisa::Payload_Vector; + +int main(int argc, char** argv) +{ + spdlog::drop(""); + spdlog::set_default_logger(spdlog::stderr_color_mt("")); + + pisa::App app("Each ID from input translated to term"); + CLI11_PARSE(app, argc, argv); + + auto meta = app.index_metadata(); + if (not meta.term_lexicon) { + spdlog::error("Term lexicon not defined"); + std::exit(1); + } + auto source = std::make_shared(meta.term_lexicon.value().c_str()); + auto lex = pisa::Payload_Vector<>::from(*source); + + pisa::io::for_each_line(std::cin, [&](auto&& line) { + std::istringstream is(line); + std::string next; + if (is >> next) { + std::cout << lex[std::stoi(next)]; + } + while (is >> next) { + std::cout << fmt::format(" {}", lex[std::stoi(next)]); + } + std::cout << '\n'; + }); + + return 0; +} diff --git a/v1/intersection.cpp b/v1/intersection.cpp index 7834000d5..27a80e89d 100644 --- a/v1/intersection.cpp +++ b/v1/intersection.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -27,6 +28,7 @@ using pisa::intersection::Mask; using pisa::v1::intersect; using pisa::v1::make_bm25; using 
pisa::v1::RawReader; +using pisa::v1::TermId; namespace arg = pisa::arg; namespace v1 = pisa::v1; @@ -78,7 +80,8 @@ void compute_intersections(Index const& index, QRng queries, IntersectionType intersection_type, tl::optional max_term_count, - bool existing) + bool existing, + tl::optional>> const& in_set) { for (auto const& query : queries) { auto intersections = nlohmann::json::array(); @@ -92,7 +95,7 @@ void compute_intersections(Index const& index, } }; if (intersection_type == IntersectionType::Combinations) { - if (existing) { + if (in_set) { std::uint64_t left_mask = 1; auto term_ids = query.get_term_ids(); for (auto left = 0; left < term_ids.size(); left += 1) { @@ -102,23 +105,87 @@ void compute_intersections(Index const& index, {"max_score", cursor.max_score()}}); std::uint64_t right_mask = left_mask << 1U; for (auto right = left + 1; right < term_ids.size(); right += 1) { - index - .scored_bigram_cursor(term_ids[left], term_ids[right], make_bm25(index)) - .map([&](auto&& cursor) { - // TODO(michal): Do not traverse once max scores for bigrams are - // implemented - auto cost = cursor.size(); - auto max_score = - accumulate(cursor, 0.0F, [](float acc, auto&& cursor) { - auto score = std::get<0>(cursor.payload()) - + std::get<1>(cursor.payload()); - return std::max(acc, score); - }); - intersections.push_back( - nlohmann::json{{"intersection", left_mask | right_mask}, - {"cost", cost}, - {"max_score", max_score}}); + if (auto bid = index.bigram_id(term_ids[left], term_ids[right]); bid) { + if (in_set->find(std::make_pair(term_ids[left], term_ids[right])) + == in_set->end()) { + continue; + } + std::vector const terms{term_ids[left], term_ids[right]}; + auto cursors = + index.scored_cursors(gsl::make_span(terms), make_bm25(index)); + auto intersection = + intersect(cursors, + 0.0F, + [](auto score, auto&& cursor, [[maybe_unused]] auto idx) { + return score + cursor.payload(); + }); + std::size_t postings = 0; + float max_score = 0.0F; + 
v1::for_each(intersection, [&](auto& cursor) { + postings += 1; + if (auto score = cursor.payload(); score > max_score) { + max_score = score; + } }); + intersections.push_back( + nlohmann::json{{"intersection", left_mask | right_mask}, + {"cost", postings}, + {"max_score", max_score}}); + } + right_mask <<= 1U; + } + left_mask <<= 1U; + } + } else if (existing) { + std::uint64_t left_mask = 1; + auto term_ids = query.get_term_ids(); + for (auto left = 0; left < term_ids.size(); left += 1) { + auto cursor = index.max_scored_cursor(term_ids[left], make_bm25(index)); + intersections.push_back(nlohmann::json{{"intersection", left_mask}, + {"cost", cursor.size()}, + {"max_score", cursor.max_score()}}); + std::uint64_t right_mask = left_mask << 1U; + for (auto right = left + 1; right < term_ids.size(); right += 1) { + if (auto bid = index.bigram_id(term_ids[left], term_ids[right]); bid) { + std::vector const terms{term_ids[left], term_ids[right]}; + auto cursors = + index.scored_cursors(gsl::make_span(terms), make_bm25(index)); + auto intersection = + intersect(cursors, + 0.0F, + [](auto score, auto&& cursor, [[maybe_unused]] auto idx) { + return score + cursor.payload(); + }); + std::size_t postings = 0; + float max_score = 0.0F; + v1::for_each(intersection, [&](auto& cursor) { + postings += 1; + if (auto score = cursor.payload(); score > max_score) { + max_score = score; + } + }); + intersections.push_back( + nlohmann::json{{"intersection", left_mask | right_mask}, + {"cost", postings}, + {"max_score", max_score}}); + } + // index + // .scored_bigram_cursor(term_ids[left], term_ids[right], + // make_bm25(index)) .map([&](auto&& cursor) { + // // TODO(michal): Do not traverse once max scores for bigrams are + // // implemented + // auto cost = cursor.size(); + // auto max_score = + // accumulate(cursor, 0.0F, [](float acc, auto&& cursor) { + // auto score = std::get<0>(cursor.payload()) + // + std::get<1>(cursor.payload()); + // return std::max(acc, score); + // }); + 
// intersections.push_back( + // nlohmann::json{{"intersection", left_mask | right_mask}, + // {"cost", cost}, + // {"max_score", max_score}}); + // }); right_mask <<= 1U; } left_mask <<= 1U; @@ -143,6 +210,7 @@ int main(int argc, const char** argv) bool combinations = false; bool existing = false; std::optional max_term_count; + std::optional in_set_path; pisa::App> app( "Calculates intersections for a v1 index."); @@ -152,11 +220,24 @@ int main(int argc, const char** argv) max_term_count, "Max number of terms when computing combinations"); mtc_flag->needs(combinations_flag); - app.add_flag("--existing", existing, "Use only existing bigrams") + auto* existing_flag = app.add_flag("--existing", existing, "Use only existing bigrams") + ->needs(combinations_flag) + ->excludes(mtc_flag); + app.add_option("--in", in_set_path, "Use only bigrams from this list") ->needs(combinations_flag) - ->excludes(mtc_flag); + ->excludes(mtc_flag) + ->excludes(existing_flag); CLI11_PARSE(app, argc, argv); auto mtc = max_term_count ? tl::make_optional(*max_term_count) : tl::optional{}; + tl::optional>> in_set{}; + if (in_set_path) { + in_set = std::set>{}; + std::ifstream is(*in_set_path); + std::string left, right; + while (is >> left >> right) { + in_set->emplace(std::stoi(left), std::stoi(right)); + } + } IntersectionType intersection_type = combinations ? 
IntersectionType::Combinations : IntersectionType::Query; @@ -167,7 +248,7 @@ int main(int argc, const char** argv) auto run = index_runner(meta); run([&](auto&& index) { - compute_intersections(index, queries, intersection_type, mtc, existing); + compute_intersections(index, queries, intersection_type, mtc, existing, in_set); }); } catch (std::exception const& error) { spdlog::error("{}", error.what()); diff --git a/v1/query.cpp b/v1/query.cpp index 9ff210a2e..8eb198220 100644 --- a/v1/query.cpp +++ b/v1/query.cpp @@ -41,7 +41,9 @@ using pisa::v1::RawReader; using pisa::v1::unigram_union_lookup; using pisa::v1::UnigramUnionLookupInspect; using pisa::v1::union_lookup; +using pisa::v1::union_lookup_plus; using pisa::v1::UnionLookupInspect; +using pisa::v1::UnionLookupPlusInspect; using pisa::v1::VoidScorer; using pisa::v1::wand; @@ -155,10 +157,31 @@ auto resolve_algorithm(std::string const& name, Index const& index, Scorer&& sco fallback, safe); } + if (name == "union-lookup-plus") { + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + if (query.selections()->bigrams.empty()) { + return pisa::v1::unigram_union_lookup( + query, index, std::move(topk), std::forward(scorer)); + } + if (query.get_term_ids().size() >= 8) { + return pisa::v1::maxscore( + query, index, std::move(topk), std::forward(scorer)); + } + return pisa::v1::union_lookup_plus( + query, index, std::move(topk), std::forward(scorer)); + }, + fallback, + safe); + } if (name == "lookup-union") { return RetrievalAlgorithm( [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { if (query.selections()->bigrams.empty()) { + if (query.selections()->unigrams.empty()) { + return pisa::v1::maxscore( + query, index, std::move(topk), std::forward(scorer)); + } return pisa::v1::unigram_union_lookup( query, index, std::move(topk), std::forward(scorer)); } @@ -195,6 +218,9 @@ auto resolve_inspect(std::string const& name, Index const& index, Scorer&& score if (name == 
"lookup-union") { return QueryInspector(LookupUnionInspector>(index, scorer)); } + if (name == "union-lookup-plus") { + return QueryInspector(UnionLookupPlusInspect>(index, scorer)); + } spdlog::error("Unknown algorithm: {}", name); std::exit(1); } diff --git a/v1/term_to_id.cpp b/v1/term_to_id.cpp new file mode 100644 index 000000000..adb96924f --- /dev/null +++ b/v1/term_to_id.cpp @@ -0,0 +1,44 @@ +#include +#include + +#include +#include +#include + +#include "app.hpp" +#include "io.hpp" +#include "payload_vector.hpp" + +namespace arg = pisa::arg; +using pisa::Payload_Vector; + +int main(int argc, char** argv) +{ + spdlog::drop(""); + spdlog::set_default_logger(spdlog::stderr_color_mt("")); + + pisa::App app("Each term from input translated to ID"); + CLI11_PARSE(app, argc, argv); + + auto meta = app.index_metadata(); + if (not meta.term_lexicon) { + spdlog::error("Term lexicon not defined"); + std::exit(1); + } + auto source = std::make_shared(meta.term_lexicon.value().c_str()); + auto lex = pisa::Payload_Vector<>::from(*source); + + pisa::io::for_each_line(std::cin, [&](auto&& line) { + std::istringstream is(line); + std::string next; + if (is >> next) { + std::cout << lex[std::stoi(next)]; + } + while (is >> next) { + std::cout << fmt::format(" {}", lex[std::stoi(next)]); + } + std::cout << '\n'; + }); + + return 0; +} From f1470463fb153e2dc26cd0410d2b9e945689f439 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 31 Jan 2020 20:02:59 +0000 Subject: [PATCH 52/56] Script update --- script/cw09b-url-trec06.sh | 4 ++-- script/cw09b.sh | 26 +++++++++++++------------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/script/cw09b-url-trec06.sh b/script/cw09b-url-trec06.sh index e3a0ed67c..aa1491e25 100644 --- a/script/cw09b-url-trec06.sh +++ b/script/cw09b-url-trec06.sh @@ -5,11 +5,11 @@ FWD="/data/CW09B/CW09B.url.fwd" ENCODING="simdbp" # v1 BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" THREADS=4 
-QUERIES="/home/michal/06.clean.shuf.val" +QUERIES="/home/michal/06.clean.shuf.test" K=1000 OUTPUT_DIR="/data/michal/intersect/cw09b-url-trec06" FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" PAIRS=${FILTERED_QUERIES} PAIR_INDEX_BASENAME="${BASENAME}-pair-trec06" -THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.06.clean.shuf.val" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.06.clean.shuf.test" QUERY_LIMIT=5000 diff --git a/script/cw09b.sh b/script/cw09b.sh index b2f55f9bf..90a4cf0f4 100644 --- a/script/cw09b.sh +++ b/script/cw09b.sh @@ -35,10 +35,10 @@ mkdir -p ${OUTPUT_DIR} #${PISA_BIN}/bmscore -i "${BASENAME}.yml" -j ${THREADS} --variable-blocks 22.5 --clone ${BASENAME}-var # Filter out queries witout existing terms. -paste -d: ${QUERIES} ${THRESHOLDS} \ - | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ - | head -${QUERY_LIMIT} \ - | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml --min 2 --max 8 > ${FILTERED_QUERIES} +#paste -d: ${QUERIES} ${THRESHOLDS} \ +# | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ +# | head -${QUERY_LIMIT} \ +# | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml --min 2 --max 8 > ${FILTERED_QUERIES} # This will produce both quantized scores and max scores (both quantized and not). 
#${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${PAIRS} --clone ${PAIR_INDEX_BASENAME} -j 4 @@ -47,19 +47,19 @@ paste -d: ${QUERIES} ${THRESHOLDS} \ #${PISA_BIN}/intersection -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations --existing \ # | grep -v "\[warning\]" \ # > ${OUTPUT_DIR}/intersections.jl -${PISA_BIN}/intersection -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations \ - --in /home/michal/real.aol.top1m.tsv \ - | grep -v "\[warning\]" \ - > ${OUTPUT_DIR}/intersections.jl +#${PISA_BIN}/intersection -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations \ +# --in /home/michal/real.aol.top1m.tsv \ +# | grep -v "\[warning\]" \ +# > ${OUTPUT_DIR}/intersections.jl # Select unigrams ${INTERSECT_BIN} -m unigram ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.1 ${INTERSECT_BIN} -m unigram ${OUTPUT_DIR}/intersections.jl --time > ${OUTPUT_DIR}/selections.1.time -${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.2 -${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl --time > ${OUTPUT_DIR}/selections.2.time -${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 -${INTERSECT_BIN} -m bigram ${OUTPUT_DIR}/intersections.jl --scale 2 --time > ${OUTPUT_DIR}/selections.2.scaled-2.time -${INTERSECT_BIN} -m exact ${OUTPUT_DIR}/intersections.jl --scale 2 > ${OUTPUT_DIR}/selections.2.exact +${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.2 +${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --time > ${OUTPUT_DIR}/selections.2.time +${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 +${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --scale 2 --time > ${OUTPUT_DIR}/selections.2.scaled-2.time +#${INTERSECT_BIN} -m exact ${OUTPUT_DIR}/intersections.jl --scale 2 > ${OUTPUT_DIR}/selections.2.exact # Run benchmarks 
#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm wand --safe > ${OUTPUT_DIR}/bench.wand From 43acf123a2b1a1f712005936cb9ab5075167dd99 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 3 Feb 2020 22:33:58 +0000 Subject: [PATCH 53/56] Expand LookupUnion stats --- CMakeLists.txt | 4 +- include/pisa/v1/union_lookup.hpp | 163 +++++++++++++++++++++++++++++-- v1/query.cpp | 4 +- 3 files changed, 158 insertions(+), 13 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b72fa3fb3..89ed0244e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -65,8 +65,8 @@ if (UNIX) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ggdb -gdwarf") # Add debug info anyway - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer") - #set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer -Wfatal-errors") + #set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer -Wfatal-errors") if (${FORCE_COLORED_OUTPUT}) if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") diff --git a/include/pisa/v1/union_lookup.hpp b/include/pisa/v1/union_lookup.hpp index b3f399aa5..20bb7889f 100644 --- a/include/pisa/v1/union_lookup.hpp +++ b/include/pisa/v1/union_lookup.hpp @@ -396,12 +396,16 @@ struct LookupTransform { }; /// This algorithm... 
-template +template auto lookup_union(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer, - Inspect* inspect = nullptr) + InspectUnigram* inspect_unigram = nullptr, + InspectBigram* inspect_bigram = nullptr) { using bigram_cursor_type = std::decay_t; using lookup_cursor_type = std::decay_t; @@ -419,8 +423,11 @@ auto lookup_union(Query const& query, auto& essential_unigrams = selections.unigrams; auto& essential_bigrams = selections.bigrams; - if constexpr (not std::is_void_v) { - inspect->essential(essential_unigrams.size() + essential_bigrams.size()); + if constexpr (not std::is_void_v) { + inspect_unigram->essential(essential_unigrams.size()); + } + if constexpr (not std::is_void_v) { + inspect_bigram->essential(essential_bigrams.size()); } auto non_essential_terms = @@ -438,13 +445,13 @@ auto lookup_union(Query const& query, 0.0F, accumulators::Add{}, is_above_threshold, - inspect); + inspect_unigram); }(); using lookup_transform_type = LookupTransform; + InspectBigram>; using transform_payload_cursor_type = TransformPayloadCursor; @@ -471,7 +478,7 @@ auto lookup_union(Query const& query, lookup_transform_type(std::move(lookup_cursors), lookup_cursors_upper_bound, is_above_threshold, - inspect)); + inspect_bigram)); } auto accumulate = [&](float acc, auto& cursor, [[maybe_unused]] auto idx) { @@ -484,9 +491,9 @@ auto lookup_union(Query const& query, std::make_tuple(accumulate, accumulate)); v1::for_each(merged, [&](auto&& cursor) { - if constexpr (not std::is_void_v) { + if constexpr (not std::is_void_v) { if (topk.insert(cursor.payload(), cursor.value())) { - inspect->insert(); + inspect_unigram->insert(); } } else { topk.insert(cursor.payload(), cursor.value()); @@ -1005,6 +1012,56 @@ struct UnionLookupPlusInspect : public BaseUnionLookupInspect { } }; +struct ComponentInspect { + void reset_current() + { + m_documents += m_current_documents; + m_postings += m_current_postings; + m_lookups += m_current_lookups; + m_inserts += 
m_current_inserts; + m_essential_lists += m_current_essential_lists; + m_count += 1; + + m_current_documents = 0; + m_current_postings = 0; + m_current_lookups = 0; + m_current_inserts = 0; + m_essential_lists = 0; + } + + void document() { m_current_documents += 1; } + void posting() { m_current_postings += 1; } + void lookup() { m_current_lookups += 1; } + void insert() { m_current_inserts += 1; } + void essential(std::size_t n) { m_current_essential_lists = n; } + + [[nodiscard]] auto current_documents() const { return m_current_documents; } + [[nodiscard]] auto current_postings() const { return m_current_postings; } + [[nodiscard]] auto current_lookups() const { return m_current_lookups; } + [[nodiscard]] auto current_inserts() const { return m_current_inserts; } + [[nodiscard]] auto current_essential_lists() const { return m_current_essential_lists; } + + [[nodiscard]] auto documents() const { return m_documents; } + [[nodiscard]] auto postings() const { return m_postings; } + [[nodiscard]] auto lookups() const { return m_lookups; } + [[nodiscard]] auto inserts() const { return m_inserts; } + [[nodiscard]] auto essential_lists() const { return m_essential_lists; } + + private: + std::size_t m_current_documents = 0; + std::size_t m_current_postings = 0; + std::size_t m_current_lookups = 0; + std::size_t m_current_inserts = 0; + std::size_t m_current_essential_lists = 0; + + std::size_t m_documents = 0; + std::size_t m_postings = 0; + std::size_t m_lookups = 0; + std::size_t m_inserts = 0; + std::size_t m_count = 0; + std::size_t m_essential_lists = 0; +}; + template struct LookupUnionInspector : public BaseUnionLookupInspect { LookupUnionInspector(Index const& index, Scorer scorer) @@ -1021,4 +1078,92 @@ struct LookupUnionInspector : public BaseUnionLookupInspect { } }; +template +struct LookupUnionInspect { + LookupUnionInspect(Index const& index, Scorer scorer) + : m_index(index), m_scorer(std::move(scorer)) + { + std::cout << fmt::format( + 
"documents\tpostings\tinserts\tlookups\tessential_lists\t" + "postings-uni\tlookups-uni\tpostings-bi\tlookups-bi\n"); + } + + void reset_current() + { + m_unigram_inspect.reset_current(); + m_bigram_inspect.reset_current(); + m_count += 1; + } + + void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) + { + if (query.selections()->bigrams.empty()) { + unigram_union_lookup(query, index, std::move(topk), scorer, &m_unigram_inspect); + } else { + lookup_union( + query, index, std::move(topk), scorer, &m_unigram_inspect, &m_bigram_inspect); + } + } + + void operator()(Query const& query) + { + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { + return; + } + using cursor_type = decltype(m_index.max_scored_cursor(0, m_scorer)); + using value_type = decltype(m_index.max_scored_cursor(0, m_scorer).value()); + + run(query, m_index, m_scorer, topk_queue(query.k())); + std::cout << fmt::format( + "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n", + m_unigram_inspect.current_documents() + m_bigram_inspect.current_documents(), + m_unigram_inspect.current_postings() + m_bigram_inspect.current_postings(), + m_unigram_inspect.current_inserts() + m_bigram_inspect.current_inserts(), + m_unigram_inspect.current_lookups() + m_bigram_inspect.current_lookups(), + m_unigram_inspect.current_essential_lists() + + m_bigram_inspect.current_essential_lists(), + m_unigram_inspect.current_postings(), + m_unigram_inspect.current_lookups(), + m_bigram_inspect.current_postings(), + m_bigram_inspect.current_lookups()); + reset_current(); + } + + void summarize() && + { + auto documents = m_unigram_inspect.documents() + m_bigram_inspect.documents(); + auto postings = m_unigram_inspect.postings() + m_bigram_inspect.postings(); + auto inserts = m_unigram_inspect.inserts() + m_bigram_inspect.inserts(); + auto lookups = m_unigram_inspect.lookups() + m_bigram_inspect.lookups(); + auto essential_lists = + m_unigram_inspect.essential_lists() + 
m_bigram_inspect.essential_lists(); + auto uni_postings = m_unigram_inspect.postings(); + auto uni_lookups = m_unigram_inspect.lookups(); + auto bi_postings = m_bigram_inspect.postings(); + auto bi_lookups = m_bigram_inspect.lookups(); + std::cerr << fmt::format( + "=== SUMMARY ===\nAverage:\n- documents:\t{}\n" + "- postings:\t{}\n- inserts:\t{}\n- lookups:\t{}\n- essential lists:\t{}\n" + "- uni-postings:\t{}\n- uni-lookups:\t{}\n" + "- bi-postings:\t{}\n- bi-lookups:\t{}\n", + static_cast(documents) / m_count, + static_cast(postings) / m_count, + static_cast(inserts) / m_count, + static_cast(lookups) / m_count, + static_cast(essential_lists) / m_count, + static_cast(uni_postings) / m_count, + static_cast(uni_lookups) / m_count, + static_cast(bi_postings) / m_count, + static_cast(bi_lookups) / m_count); + } + + private: + ComponentInspect m_unigram_inspect; + ComponentInspect m_bigram_inspect; + std::size_t m_count = 0; + Index const& m_index; + Scorer m_scorer; +}; + } // namespace pisa::v1 diff --git a/v1/query.cpp b/v1/query.cpp index 8eb198220..46e8e407d 100644 --- a/v1/query.cpp +++ b/v1/query.cpp @@ -30,7 +30,7 @@ using pisa::v1::DaatOrInspector; using pisa::v1::DocumentBlockedReader; using pisa::v1::index_runner; using pisa::v1::lookup_union; -using pisa::v1::LookupUnionInspector; +using pisa::v1::LookupUnionInspect; using pisa::v1::maxscore_union_lookup; using pisa::v1::MaxscoreInspector; using pisa::v1::MaxscoreUnionLookupInspect; @@ -216,7 +216,7 @@ auto resolve_inspect(std::string const& name, Index const& index, Scorer&& score return QueryInspector(UnionLookupInspect>(index, scorer)); } if (name == "lookup-union") { - return QueryInspector(LookupUnionInspector>(index, scorer)); + return QueryInspector(LookupUnionInspect>(index, scorer)); } if (name == "union-lookup-plus") { return QueryInspector(UnionLookupPlusInspect>(index, scorer)); From 5e853de2d3719b2c8fa3fb3287e93abc64a52ed6 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 17 Feb 2020 
16:24:07 +0000 Subject: [PATCH 54/56] Refactor and test query inspection --- include/pisa/v1/cursor/lookup_transform.hpp | 61 ++ include/pisa/v1/cursor/reference.hpp | 45 + include/pisa/v1/cursor/transform.hpp | 10 + include/pisa/v1/daat_or.hpp | 76 +- include/pisa/v1/inspect_query.hpp | 362 +++++++- include/pisa/v1/maxscore.hpp | 119 +-- include/pisa/v1/maxscore_union_lookup.hpp | 97 +++ include/pisa/v1/query.hpp | 1 - include/pisa/v1/unigram_union_lookup.hpp | 101 +++ include/pisa/v1/union_lookup.hpp | 864 ++++++-------------- include/pisa/v1/union_lookup_join.hpp | 277 +++++++ script/cw09b-url.sh | 7 +- script/cw09b.sh | 99 ++- test/v1/test_union_lookup_join.cpp | 143 ++++ test/v1/test_v1_queries.cpp | 35 +- test/v1/test_v1_union_lookup.cpp | 155 ++++ v1/query.cpp | 69 +- 17 files changed, 1675 insertions(+), 846 deletions(-) create mode 100644 include/pisa/v1/cursor/lookup_transform.hpp create mode 100644 include/pisa/v1/cursor/reference.hpp create mode 100644 include/pisa/v1/maxscore_union_lookup.hpp create mode 100644 include/pisa/v1/unigram_union_lookup.hpp create mode 100644 include/pisa/v1/union_lookup_join.hpp create mode 100644 test/v1/test_union_lookup_join.cpp create mode 100644 test/v1/test_v1_union_lookup.cpp diff --git a/include/pisa/v1/cursor/lookup_transform.hpp b/include/pisa/v1/cursor/lookup_transform.hpp new file mode 100644 index 000000000..33ff24cbb --- /dev/null +++ b/include/pisa/v1/cursor/lookup_transform.hpp @@ -0,0 +1,61 @@ +#pragma once + +#include + +namespace pisa::v1 { + +/// **Note**: This currently works only for pair cursor with single-term lookup cursors. +/// This callable transforms a cursor by performing lookups to the current document +/// in the given lookup cursors, and then adding the scores that were found. +/// It uses the same short-circuiting rules before each lookup as `UnionLookupJoin`. 
+template +struct LookupTransform { + + LookupTransform(std::vector lookup_cursors, + float lookup_cursors_upper_bound, + AboveThresholdFn above_threshold, + Inspector* inspect = nullptr) + : m_lookup_cursors(std::move(lookup_cursors)), + m_lookup_cursors_upper_bound(lookup_cursors_upper_bound), + m_above_threshold(std::move(above_threshold)), + m_inspect(inspect) + { + } + + template + auto operator()(Cursor& cursor) + { + auto docid = cursor.value(); + auto scores = cursor.payload(); + float score = std::get<0>(scores) + std::get<1>(scores); + auto upper_bound = score + m_lookup_cursors_upper_bound; + for (auto&& lookup_cursor : m_lookup_cursors) { + //if (docid == 2288) { + // std::cout << fmt::format("[checking] doc: {}\tbound: {}\n", docid, upper_bound); + //} + if (not m_above_threshold(upper_bound)) { + return score; + } + lookup_cursor.advance_to_geq(docid); + //std::cout << fmt::format("[b] doc: {}\tbound: {}\n", docid, upper_bound); + if constexpr (not std::is_void_v) { + m_inspect->lookup(); + } + if (PISA_UNLIKELY(lookup_cursor.value() == docid)) { + auto partial_score = lookup_cursor.payload(); + score += partial_score; + upper_bound += partial_score; + } + upper_bound -= lookup_cursor.max_score(); + } + return score; + } + + private: + std::vector m_lookup_cursors; + float m_lookup_cursors_upper_bound; + AboveThresholdFn m_above_threshold; + Inspector* m_inspect; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor/reference.hpp b/include/pisa/v1/cursor/reference.hpp new file mode 100644 index 000000000..398e8e387 --- /dev/null +++ b/include/pisa/v1/cursor/reference.hpp @@ -0,0 +1,45 @@ +#pragma once + +#include + +#include + +#include "v1/cursor_traits.hpp" + +namespace pisa::v1 { + +template +struct CursorRef { + using Value = typename CursorTraits>::Value; + + constexpr CursorRef(Cursor&& cursor) : m_cursor(std::ref(cursor)) {} + [[nodiscard]] constexpr auto operator*() const -> Value { return value(); } + [[nodiscard]] constexpr auto 
value() const noexcept -> Value { return m_cursor.get().value(); } + [[nodiscard]] constexpr auto payload() { return m_cursor.get().payload(); } + constexpr void advance() { m_cursor.get().advance(); } + constexpr void advance_to_geq(Value value) { m_cursor.get().advance_to_geq(value); } + constexpr void advance_to_position(std::size_t pos) { m_cursor.get().advance_to_position(pos); } + [[nodiscard]] constexpr auto empty() const noexcept -> bool { return m_cursor.get().empty(); } + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t + { + return m_cursor.get().position(); + } + [[nodiscard]] constexpr auto size() const -> std::size_t { return m_cursor.get().size(); } + [[nodiscard]] constexpr auto max_score() const { return m_cursor.get().max_score(); } + + private: + std::reference_wrapper> m_cursor; +}; + +template +auto ref(Cursor&& cursor) +{ + return CursorRef(std::forward(cursor)); +} + +template +struct CursorTraits> { + using Value = typename CursorTraits::Value; +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/cursor/transform.hpp b/include/pisa/v1/cursor/transform.hpp index b53deae8e..f19688981 100644 --- a/include/pisa/v1/cursor/transform.hpp +++ b/include/pisa/v1/cursor/transform.hpp @@ -15,6 +15,11 @@ struct TransformCursor { : m_cursor(std::move(cursor)), m_transform(std::move(transform)) { } + TransformCursor(TransformCursor&&) noexcept = default; + TransformCursor(TransformCursor const&) = default; + TransformCursor& operator=(TransformCursor&&) noexcept = default; + TransformCursor& operator=(TransformCursor const&) = default; + ~TransformCursor() = default; [[nodiscard]] constexpr auto operator*() const -> Value { return value(); } [[nodiscard]] constexpr auto value() const noexcept -> Value @@ -47,6 +52,11 @@ struct TransformPayloadCursor { : m_cursor(std::move(cursor)), m_transform(std::move(transform)) { } + TransformPayloadCursor(TransformPayloadCursor&&) noexcept = default; + 
TransformPayloadCursor(TransformPayloadCursor const&) = default; + TransformPayloadCursor& operator=(TransformPayloadCursor&&) noexcept = default; + TransformPayloadCursor& operator=(TransformPayloadCursor const&) = default; + ~TransformPayloadCursor() = default; [[nodiscard]] constexpr auto operator*() const { return value(); } [[nodiscard]] constexpr auto value() const noexcept { return m_cursor.value(); } diff --git a/include/pisa/v1/daat_or.hpp b/include/pisa/v1/daat_or.hpp index 9c394044c..e9e919b16 100644 --- a/include/pisa/v1/daat_or.hpp +++ b/include/pisa/v1/daat_or.hpp @@ -4,12 +4,17 @@ #include "topk_queue.hpp" #include "v1/cursor/for_each.hpp" +#include "v1/inspect_query.hpp" #include "v1/query.hpp" namespace pisa::v1 { -template -auto daat_or(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) +template +auto daat_or(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + Inspect* inspect = nullptr) { std::vector cursors; std::transform(query.get_term_ids().begin(), @@ -17,67 +22,38 @@ auto daat_or(Query const& query, Index const& index, topk_queue topk, Scorer&& s std::back_inserter(cursors), [&](auto term) { return index.scored_cursor(term, scorer); }); auto cunion = v1::union_merge( - std::move(cursors), 0.0F, [](auto& score, auto& cursor, auto /* term_idx */) { + std::move(cursors), 0.0F, [&](auto& score, auto& cursor, auto /* term_idx */) { + if constexpr (not std::is_void_v) { + inspect->posting(); + } score += cursor.payload(); return score; }); - v1::for_each(cunion, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); + v1::for_each(cunion, [&](auto& cursor) { + if constexpr (not std::is_void_v) { + inspect->document(); + if (topk.insert(cursor.payload(), cursor.value())) { + inspect->insert(); + } + } else { + topk.insert(cursor.payload(), cursor.value()); + } + }); return topk; } template -struct DaatOrInspector { - DaatOrInspector(Index const& index, Scorer scorer) : 
m_index(index), m_scorer(std::move(scorer)) - { - std::cout << fmt::format("documents\tpostings\n"); - } +struct InspectDaatOr : Inspect { - void operator()(Query const& query) + InspectDaatOr(Index const& index, Scorer const& scorer) + : Inspect(index, scorer) { - std::vector cursors; - std::transform(query.get_term_ids().begin(), - query.get_term_ids().end(), - std::back_inserter(cursors), - [&](auto term) { return m_index.scored_cursor(term, m_scorer); }); - std::size_t postings = 0; - auto cunion = v1::union_merge( - std::move(cursors), 0.0F, [&](auto& score, auto& cursor, auto /* term_idx */) { - postings += 1; - score += cursor.payload(); - return score; - }); - std::size_t documents = 0; - std::size_t inserts = 0; - topk_queue topk(query.k()); - v1::for_each(cunion, [&](auto& cursor) { - if (topk.insert(cursor.payload(), cursor.value())) { - inserts += 1; - }; - documents += 1; - }); - std::cout << fmt::format("{}\t{}\t{}\n", documents, postings, inserts); - m_documents += documents; - m_postings += postings; - m_inserts += inserts; - m_count += 1; } - void summarize() && + void run(Query const& query, Index const& index, Scorer const& scorer, topk_queue topk) override { - std::cerr << fmt::format( - "=== SUMMARY ===\nAverage:\n- documents:\t{}\n- postings:\t{}\n- inserts:\t{}\n", - static_cast(m_documents) / m_count, - static_cast(m_postings) / m_count, - static_cast(m_inserts) / m_count); + daat_or(query, index, std::move(topk), scorer, this); } - - private: - std::size_t m_documents = 0; - std::size_t m_postings = 0; - std::size_t m_inserts = 0; - std::size_t m_count = 0; - Index const& m_index; - Scorer m_scorer; }; } // namespace pisa::v1 diff --git a/include/pisa/v1/inspect_query.hpp b/include/pisa/v1/inspect_query.hpp index bdf7b236f..31e3a4c5c 100644 --- a/include/pisa/v1/inspect_query.hpp +++ b/include/pisa/v1/inspect_query.hpp @@ -1,33 +1,360 @@ #pragma once #include +#include #include #include +#include "topk_queue.hpp" +#include "v1/query.hpp" + 
namespace pisa::v1 { -struct Query; +template +auto write_delimited(std::ostream& os, + std::string_view sep, + FirstValue first_value, + Values&&... values) -> std::ostream&; + +template +auto write_delimited(std::ostream& os, [[maybe_unused]] std::string_view sep, Value value) + -> std::ostream& +{ + return os << value; +} + +template +auto write_delimited(std::ostream& os, std::string_view sep, std::tuple t) -> std::ostream& +{ + std::apply([&](auto... vals) { write_delimited(os, sep, vals...); }, t); + return os; +} + +template +auto write_delimited(std::ostream& os, std::string_view sep, std::pair p) -> std::ostream& +{ + write_delimited(os, sep, p.first); + os << sep; + return write_delimited(os, sep, p.second); +} + +/// Writes a list of values into a stream separated by `sep`. +template +auto write_delimited(std::ostream& os, + std::string_view sep, + FirstValue first_value, + Values&&... values) -> std::ostream& +{ + write_delimited(os, sep, first_value); + ( + [&] { + os << sep; + write_delimited(os, sep, values); + }(), + ...); + return os; +} + +struct InspectCount { + using value_type = std::size_t; + void reset() { m_current_count = 0; } + void inc(std::size_t n = 1) + { + m_current_count += n; + m_total_count += n; + } + [[nodiscard]] auto get() const -> std::size_t { return m_current_count; } + [[nodiscard]] auto mean(std::size_t n) const -> float + { + return static_cast(m_total_count) / n; + } + + struct Result { + explicit Result(std::size_t value) : m_value(value) {} + [[nodiscard]] auto get() const { return m_value; } + [[nodiscard]] auto operator+(Result const& other) const { return m_value + other.m_value; } + + private: + std::size_t m_value; + }; + + private: + std::size_t m_current_count = 0; + std::size_t m_total_count = 0; +}; + +struct InspectPostings : InspectCount { + struct Result : InspectCount::Result { + explicit Result(std::size_t value) : InspectCount::Result(value) {} + [[nodiscard]] auto postings() const { return get(); } + }; 
+ void posting() { inc(); } + static auto header(std::string_view suffix) { return fmt::format("postings{}", suffix); } +}; + +struct InspectDocuments : InspectCount { + struct Result : InspectCount::Result { + explicit Result(std::size_t value) : InspectCount::Result(value) {} + [[nodiscard]] auto documents() const { return get(); } + }; + void document() { inc(); } + static auto header(std::string_view suffix) { return fmt::format("documents{}", suffix); } +}; + +struct InspectLookups : InspectCount { + struct Result : InspectCount::Result { + explicit Result(std::size_t value) : InspectCount::Result(value) {} + [[nodiscard]] auto lookups() const { return get(); } + }; + void lookup() { inc(); } + static auto header(std::string_view suffix) { return fmt::format("lookups{}", suffix); } +}; + +struct InspectInserts : InspectCount { + struct Result : InspectCount::Result { + explicit Result(std::size_t value) : InspectCount::Result(value) {} + [[nodiscard]] auto inserts() const { return get(); } + }; + void insert() { inc(); } + static auto header(std::string_view suffix) { return fmt::format("inserts{}", suffix); } +}; + +struct InspectEssential : InspectCount { + struct Result : InspectCount::Result { + explicit Result(std::size_t value) : InspectCount::Result(value) {} + [[nodiscard]] auto essentials() const { return get(); } + }; + void essential(std::size_t n) { inc(n); } + static auto header(std::string_view suffix) { return fmt::format("essential-terms{}", suffix); } +}; + +template +struct InspectPartitioned { + struct Result { + explicit Result(typename Inspect::Result first, typename Inspect::Result second) + : first(first), second(second), sum(first + second) + { + } + [[nodiscard]] auto get() const + { + return std::make_tuple(sum.get(), first.get(), second.get()); + } + + typename Inspect::Result first; + typename Inspect::Result second; + typename Inspect::Result sum; + }; + using value_type = Result; + + void reset() + { + m_components.first.reset(); 
+ m_components.second.reset(); + } + void inc(std::size_t n = 1) + { + m_components.first.inc(n); + m_components.second.inc(n); + } + [[nodiscard]] auto get() const + { + return Result(m_components.first.get(), m_components.second.get()); + } + [[nodiscard]] auto mean(std::size_t n) const + { + return Result(m_components.first.mean(n), m_components.second.mean(n)); + } + [[nodiscard]] auto first() -> Inspect* { return &m_components.first; } + [[nodiscard]] auto second() -> Inspect* { return &m_components.second; } + static auto header(std::string_view suffix) + { + return std::make_tuple(Inspect::header(fmt::format("{}", suffix)), + Inspect::header(fmt::format("{}_1", suffix)), + Inspect::header(fmt::format("{}_2", suffix))); + } + + private: + std::pair m_components; +}; + +template +struct InspectPair { + struct Result { + explicit Result(typename First::Result first, typename Second::Result second) + : first(first), second(second) + { + } + [[nodiscard]] auto get() const { return std::make_pair(first.get(), second.get()); } + + typename First::Result first; + typename Second::Result second; + }; + using value_type = Result; + + void reset() + { + m_components.first.reset(); + m_components.second.reset(); + } + void inc(std::size_t n = 1) + { + m_components.first.inc(n); + m_components.second.inc(n); + } + [[nodiscard]] auto get() const + { + return Result(m_components.first.get(), m_components.second.get()); + } + [[nodiscard]] auto mean(std::size_t n) const + { + return Result(m_components.first.mean(n), m_components.second.mean(n)); + } + [[nodiscard]] auto first() -> First* { return &m_components.first; } + [[nodiscard]] auto second() -> Second* { return &m_components.second; } + static auto header(std::string_view suffix) + { + return std::make_pair(First::header(fmt::format("{}_1", suffix)), + Second::header(fmt::format("{}_2", suffix))); + } + + private: + std::pair m_components; +}; + +template +struct InspectMany : Stat... 
{ + struct Result : Stat::Result... { + explicit Result(typename Stat::value_type... values) : Stat::Result(values)... {} + Result(Result const& result) = default; + Result(Result&& result) noexcept = default; + Result& operator=(Result const& result) = default; + Result& operator=(Result&& result) noexcept = default; + ~Result() = default; + [[nodiscard]] auto get() const { return std::make_tuple(Stat::Result::get()...); } + [[nodiscard]] auto operator+(Result const& other) const + { + return Result((Stat::Result::get() + other.Stat::Result::get())...); + } + }; + using value_type = Result; + void reset() { (Stat::reset(), ...); } + void inc(std::size_t n = 1) { (Stat::inc(n), ...); } + [[nodiscard]] auto get() const { return Result(Stat::get()...); } + [[nodiscard]] auto mean(std::size_t n) const { return Result(Stat::mean(n)...); } + static auto header(std::string_view suffix) { return std::make_tuple(Stat::header(suffix)...); } +}; + +template +struct Inspect : Stat... { + Inspect(Index const& index, Scorer scorer) : m_index(index), m_scorer(std::move(scorer)) {} + + virtual void run(Query const& query, + Index const& index, + Scorer const& scorer, + topk_queue topk) = 0; + + [[nodiscard]] auto mean() const { return InspectResult(Stat::mean(m_count)...); } + + struct InspectResult : Stat::Result... { + explicit InspectResult(typename Stat::value_type... values) : Stat::Result(values)... 
{} + std::ostream& write(std::ostream& os, std::string_view sep = "\t") + { + return write_delimited(os, sep, Stat::Result::get()...); + } + }; + + [[nodiscard]] auto operator()(Query const& query) -> InspectResult + { + (Stat::reset(), ...); + run(query, m_index, m_scorer, topk_queue(query.k())); + m_count += 1; + return InspectResult(Stat::get()...); + } + static std::ostream& header(std::ostream& os, std::string_view sep = "\t") + { + return write_delimited(os, sep, Stat::header("")...); + } + + private: + std::size_t m_count = 0; + Index const& m_index; + Scorer m_scorer; +}; + +struct InspectResult { + + template + explicit constexpr InspectResult(R result) : m_inner(std::make_unique>(result)) + { + } + InspectResult() = default; + InspectResult(InspectResult const& other) : m_inner(other.m_inner->clone()) {} + InspectResult(InspectResult&& other) noexcept = default; + InspectResult& operator=(InspectResult const& other) = delete; + InspectResult& operator=(InspectResult&& other) noexcept = default; + ~InspectResult() = default; + + std::ostream& write(std::ostream& os, std::string_view sep = "\t") + { + return m_inner->write(os, sep); + } + + struct ResultInterface { + ResultInterface() = default; + ResultInterface(ResultInterface const&) = default; + ResultInterface(ResultInterface&&) noexcept = default; + ResultInterface& operator=(ResultInterface const&) = default; + ResultInterface& operator=(ResultInterface&&) noexcept = default; + virtual ~ResultInterface() = default; + virtual std::ostream& write(std::ostream& os, std::string_view sep) = 0; + [[nodiscard]] virtual auto clone() const -> std::unique_ptr = 0; + }; + + template + struct ResultImpl : ResultInterface { + explicit ResultImpl(R result) : m_result(std::move(result)) {} + ResultImpl() = default; + ResultImpl(ResultImpl const&) = default; + ResultImpl(ResultImpl&&) noexcept = default; + ResultImpl& operator=(ResultImpl const&) = default; + ResultImpl& operator=(ResultImpl&&) noexcept = default; + 
~ResultImpl() override = default; + std::ostream& write(std::ostream& os, std::string_view sep) override + { + return m_result.write(os, sep); + } + [[nodiscard]] auto clone() const -> std::unique_ptr override + { + auto copy = *this; + return std::make_unique>(std::move(copy)); + } + + private: + R m_result; + }; + + private: + std::unique_ptr m_inner; +}; struct QueryInspector { template explicit constexpr QueryInspector(R writer) - : m_internal_analyzer(std::make_unique>(writer)) + : m_inner(std::make_unique>(writer)) { } QueryInspector() = default; - QueryInspector(QueryInspector const& other) - : m_internal_analyzer(other.m_internal_analyzer->clone()) - { - } + QueryInspector(QueryInspector const& other) : m_inner(other.m_inner->clone()) {} QueryInspector(QueryInspector&& other) noexcept = default; QueryInspector& operator=(QueryInspector const& other) = delete; QueryInspector& operator=(QueryInspector&& other) noexcept = default; ~QueryInspector() = default; - void operator()(Query const& query) { m_internal_analyzer->operator()(query); } - void summarize() && { std::move(*m_internal_analyzer).summarize(); } + InspectResult operator()(Query const& query) { return m_inner->operator()(query); } + InspectResult mean() { return m_inner->mean(); } + std::ostream& header(std::ostream& os) { return m_inner->header(os); } struct InspectorInterface { InspectorInterface() = default; @@ -36,34 +363,39 @@ struct QueryInspector { InspectorInterface& operator=(InspectorInterface const&) = default; InspectorInterface& operator=(InspectorInterface&&) noexcept = default; virtual ~InspectorInterface() = default; - virtual void operator()(Query const& query) = 0; - virtual void summarize() && = 0; + virtual InspectResult operator()(Query const& query) = 0; + virtual InspectResult mean() = 0; + virtual std::ostream& header(std::ostream& os) = 0; [[nodiscard]] virtual auto clone() const -> std::unique_ptr = 0; }; template struct InspectorImpl : InspectorInterface { - explicit 
InspectorImpl(R analyzer) : m_analyzer(std::move(analyzer)) {} + explicit InspectorImpl(R inspect) : m_inspect(std::move(inspect)) {} InspectorImpl() = default; InspectorImpl(InspectorImpl const&) = default; InspectorImpl(InspectorImpl&&) noexcept = default; InspectorImpl& operator=(InspectorImpl const&) = default; InspectorImpl& operator=(InspectorImpl&&) noexcept = default; ~InspectorImpl() override = default; - void operator()(Query const& query) override { m_analyzer(query); } - void summarize() && override { std::move(m_analyzer).summarize(); } + InspectResult operator()(Query const& query) override + { + return InspectResult(m_inspect(query)); + } + InspectResult mean() override { return InspectResult(m_inspect.mean()); } [[nodiscard]] auto clone() const -> std::unique_ptr override { auto copy = *this; return std::make_unique>(std::move(copy)); } + std::ostream& header(std::ostream& os) override { return m_inspect.header(os); } private: - R m_analyzer; + R m_inspect; }; private: - std::unique_ptr m_internal_analyzer; + std::unique_ptr m_inner; }; } // namespace pisa::v1 diff --git a/include/pisa/v1/maxscore.hpp b/include/pisa/v1/maxscore.hpp index dc1fecf77..5902832bd 100644 --- a/include/pisa/v1/maxscore.hpp +++ b/include/pisa/v1/maxscore.hpp @@ -9,6 +9,7 @@ #include "v1/algorithm.hpp" #include "v1/cursor_accumulator.hpp" +#include "v1/inspect_query.hpp" #include "v1/query.hpp" namespace pisa::v1 { @@ -194,8 +195,12 @@ auto join_maxscore(CursorContainer cursors, std::move(cursors), std::move(init), std::move(accumulate), std::move(threshold), inspect); } -template -auto maxscore(Query const& query, Index const& index, topk_queue topk, Scorer&& scorer) +template +auto maxscore(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + Inspect* inspect = nullptr) { auto const& term_ids = query.get_term_ids(); if (term_ids.empty()) { @@ -208,94 +213,48 @@ auto maxscore(Query const& query, Index const& index, topk_queue topk, Scorer&& if 
(query.threshold()) { topk.set_threshold(*query.threshold()); } - auto joined = join_maxscore(std::move(cursors), 0.0F, accumulators::Add{}, [&](auto score) { - return topk.would_enter(score); + auto joined = join_maxscore( + std::move(cursors), + 0.0F, + accumulators::Add{}, + [&](auto score) { return topk.would_enter(score); }, + inspect); + v1::for_each(joined, [&](auto& cursor) { + if constexpr (not std::is_void_v) { + if (topk.insert(cursor.payload(), cursor.value())) { + inspect->insert(); + } + } else { + topk.insert(cursor.payload(), cursor.value()); + } }); - v1::for_each(joined, [&](auto& cursor) { topk.insert(cursor.payload(), cursor.value()); }); return topk; } template -struct MaxscoreInspector { - MaxscoreInspector(Index const& index, Scorer scorer) - : m_index(index), m_scorer(std::move(scorer)) +struct InspectMaxScore : Inspect { + + InspectMaxScore(Index const& index, Scorer const& scorer) + : Inspect(index, scorer) { - std::cout << fmt::format("documents\tpostings\tinserts\tlookups\n"); } - void operator()(Query const& query) + void run(Query const& query, Index const& index, Scorer const& scorer, topk_queue topk) override { - auto const& term_ids = query.get_term_ids(); - if (term_ids.empty()) { - return; - } - using cursor_type = decltype(m_index.max_scored_cursor(0, m_scorer)); - using value_type = decltype(m_index.max_scored_cursor(0, m_scorer).value()); - - m_current_documents = 0; - m_current_postings = 0; - m_current_lookups = 0; - - std::vector cursors; - std::transform(term_ids.begin(), - term_ids.end(), - std::back_inserter(cursors), - [&](auto term) { return m_index.max_scored_cursor(term, m_scorer); }); - - std::size_t inserts = 0; - topk_queue topk(query.k()); - auto initial_threshold = query.threshold().value_or(-1.0); - topk.set_threshold(initial_threshold); - auto joined = join_maxscore( - std::move(cursors), - 0.0F, - accumulators::Add{}, - [&](auto score) { return topk.would_enter(score); }, - this); - v1::for_each(joined, 
[&](auto& cursor) { - if (topk.insert(cursor.payload(), cursor.value())) { - inserts += 1; - }; - }); - std::cout << fmt::format("{}\t{}\t{}\t{}\n", - m_current_documents, - m_current_postings, - inserts, - m_current_lookups); - m_documents += m_current_documents; - m_postings += m_current_postings; - m_lookups += m_current_lookups; - m_inserts += inserts; - m_count += 1; - } - - void summarize() && - { - std::cerr << fmt::format( - "=== SUMMARY ===\nAverage:\n- documents:\t{}\n" - "- postings:\t{}\n- inserts:\t{}\n- lookups:\t{}\n", - static_cast(m_documents) / m_count, - static_cast(m_postings) / m_count, - static_cast(m_inserts) / m_count, - static_cast(m_lookups) / m_count); + maxscore(query, index, std::move(topk), scorer, this); } - - void document() { m_current_documents += 1; } - void posting() { m_current_postings += 1; } - void lookup() { m_current_lookups += 1; } - - private: - std::size_t m_current_documents = 0; - std::size_t m_current_postings = 0; - std::size_t m_current_lookups = 0; - - std::size_t m_documents = 0; - std::size_t m_postings = 0; - std::size_t m_lookups = 0; - std::size_t m_inserts = 0; - std::size_t m_count = 0; - Index const& m_index; - Scorer m_scorer; }; } // namespace pisa::v1 diff --git a/include/pisa/v1/maxscore_union_lookup.hpp b/include/pisa/v1/maxscore_union_lookup.hpp new file mode 100644 index 000000000..a4539ba99 --- /dev/null +++ b/include/pisa/v1/maxscore_union_lookup.hpp @@ -0,0 +1,97 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#include "v1/algorithm.hpp" +#include "v1/cursor/labeled_cursor.hpp" +#include "v1/cursor/reference.hpp" +#include "v1/cursor/transform.hpp" +#include "v1/cursor_accumulator.hpp" +#include "v1/inspect_query.hpp" +#include "v1/query.hpp" +#include "v1/runtime_assert.hpp" +#include "v1/union_lookup_join.hpp" + +namespace pisa::v1 { + +/// This is a special case of Union-Lookup algorithm that does not use user-defined selections, +/// but rather uses the same way of 
determining essential list as Maxscore does. +/// The difference is that this algorithm will never update the threshold whereas Maxscore will +/// try to improve the estimate after each accumulated document. +template +auto maxscore_union_lookup(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + [[maybe_unused]] Inspect* inspect = nullptr) +{ + using cursor_type = decltype(index.max_scored_cursor(0, scorer)); + using payload_type = decltype(std::declval().payload()); + + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { + return topk; + } + auto threshold = query.get_threshold(); + topk.set_threshold(threshold); + + auto cursors = index.max_scored_cursors(gsl::make_span(term_ids), scorer); + auto [non_essential, essential] = maxscore_partition(gsl::make_span(cursors), threshold); + + std::vector essential_cursors; + std::move(essential.begin(), essential.end(), std::back_inserter(essential_cursors)); + std::vector lookup_cursors; + std::move(non_essential.begin(), non_essential.end(), std::back_inserter(lookup_cursors)); + std::reverse(lookup_cursors.begin(), lookup_cursors.end()); + + auto joined = join_union_lookup( + std::move(essential_cursors), + std::move(lookup_cursors), + payload_type{}, + accumulators::Add{}, + [&](auto score) { return topk.would_enter(score); }, + inspect); + v1::for_each(joined, [&](auto&& cursor) { + if constexpr (not std::is_void_v) { + if (topk.insert(cursor.payload(), cursor.value())) { + inspect->insert(); + } + } else { + topk.insert(cursor.payload(), cursor.value()); + } + }); + return topk; +} + +template +struct InspectMaxScoreUnionLookup : Inspect { + + InspectMaxScoreUnionLookup(Index const& index, Scorer const& scorer) + : Inspect(index, scorer) + { + } + + void run(Query const& query, Index const& index, Scorer const& scorer, topk_queue topk) override + { + maxscore_union_lookup(query, index, std::move(topk), scorer, this); + } +}; + +} // namespace pisa::v1 diff --git 
a/include/pisa/v1/query.hpp b/include/pisa/v1/query.hpp index 27f5bb2e2..c1db38fd7 100644 --- a/include/pisa/v1/query.hpp +++ b/include/pisa/v1/query.hpp @@ -14,7 +14,6 @@ #include "v1/cursor/for_each.hpp" #include "v1/cursor_intersection.hpp" #include "v1/cursor_union.hpp" -#include "v1/inspect_query.hpp" #include "v1/intersection.hpp" #include "v1/types.hpp" diff --git a/include/pisa/v1/unigram_union_lookup.hpp b/include/pisa/v1/unigram_union_lookup.hpp new file mode 100644 index 000000000..98cdfcada --- /dev/null +++ b/include/pisa/v1/unigram_union_lookup.hpp @@ -0,0 +1,101 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#include "v1/algorithm.hpp" +#include "v1/cursor/labeled_cursor.hpp" +#include "v1/cursor/reference.hpp" +#include "v1/cursor/transform.hpp" +#include "v1/cursor_accumulator.hpp" +#include "v1/inspect_query.hpp" +#include "v1/query.hpp" +#include "v1/runtime_assert.hpp" +#include "v1/union_lookup_join.hpp" + +namespace pisa::v1 { + +/// Processes documents with the Union-Lookup method. +/// This is an optimized version that works **only on single-term posting lists**. +/// It will throw an exception if bigram selections are passed to it. 
+template +auto unigram_union_lookup(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + [[maybe_unused]] Inspect* inspect = nullptr) +{ + using cursor_type = decltype(index.max_scored_cursor(0, scorer)); + using payload_type = decltype(std::declval().payload()); + + auto const& term_ids = query.get_term_ids(); + if (term_ids.empty()) { + return topk; + } + + auto const& selections = query.get_selections(); + runtime_assert(selections.bigrams.empty()).or_throw("This algorithm only supports unigrams"); + + topk.set_threshold(query.get_threshold()); + + auto non_essential_terms = + ranges::views::set_difference(term_ids, selections.unigrams) | ranges::to_vector; + + auto essential_cursors = index.max_scored_cursors(selections.unigrams, scorer); + auto lookup_cursors = index.max_scored_cursors(non_essential_terms, scorer); + ranges::sort(lookup_cursors, [](auto&& l, auto&& r) { return l.max_score() > r.max_score(); }); + + if constexpr (not std::is_void_v) { + inspect->essential(essential_cursors.size()); + } + + auto joined = join_union_lookup( + std::move(essential_cursors), + std::move(lookup_cursors), + payload_type{}, + accumulators::Add{}, + [&](auto score) { return topk.would_enter(score); }, + inspect); + v1::for_each(joined, [&](auto&& cursor) { + if constexpr (not std::is_void_v) { + if (topk.insert(cursor.payload(), cursor.value())) { + inspect->insert(); + } + } else { + topk.insert(cursor.payload(), cursor.value()); + } + }); + return topk; +} + +template +struct InspectUnigramUnionLookup : Inspect { + + InspectUnigramUnionLookup(Index const& index, Scorer const& scorer) + : Inspect(index, scorer) + { + } + + void run(Query const& query, Index const& index, Scorer const& scorer, topk_queue topk) override + { + unigram_union_lookup(query, index, std::move(topk), scorer, this); + } +}; + +} // namespace pisa::v1 diff --git a/include/pisa/v1/union_lookup.hpp b/include/pisa/v1/union_lookup.hpp index 20bb7889f..147248619 100644 
--- a/include/pisa/v1/union_lookup.hpp +++ b/include/pisa/v1/union_lookup.hpp @@ -3,272 +3,134 @@ #include #include +#include #include +#include #include "v1/algorithm.hpp" #include "v1/cursor/labeled_cursor.hpp" +#include "v1/cursor/lookup_transform.hpp" +#include "v1/cursor/reference.hpp" #include "v1/cursor/transform.hpp" #include "v1/cursor_accumulator.hpp" +#include "v1/maxscore_union_lookup.hpp" #include "v1/query.hpp" #include "v1/runtime_assert.hpp" +#include "v1/unigram_union_lookup.hpp" +#include "v1/union_lookup_join.hpp" namespace pisa::v1 { -/// This cursor operator takes a number of essential cursors (in an arbitrary order) -/// and a list of lookup cursors. The documents traversed will be in the DaaT order, -/// and the following documents will be skipped: -/// - documents that do not appear in any of the essential cursors, -/// - documents that at the moment of their traversal are irrelevant (see below). -/// -/// # Threshold -/// -/// This operator takes a callable object that returns `true` only if a given score -/// has a chance to be in the final result set. It is used to decide whether or not -/// to perform further lookups for the given document. The score passed to the function -/// is such that when it returns `false`, we know that it will return `false` for the -/// rest of the lookup cursors, and therefore we can skip that document. -/// Note that such document will never be returned by this cursor. Instead, we will -/// proceed to the next document to see if it can land in the final result set, and so on. -/// -/// # Accumulating Scores -/// -/// Another parameter taken by this operator is a callable that accumulates payloads -/// for one document ID. The function is very similar to what you would pass to -/// `std::accumulate`: it takes the accumulator (either by reference or value), -/// and a reference to the cursor. It must return an updated accumulator. 
-/// For example, a simple accumulator that simply sums all payloads for each document, -/// can be: `[](float score, auto&& cursor) { return score + cursor.payload(); }`. -/// Note that you can accumulate "heavier" objects by taking and returning a reference: -/// ``` -/// [](auto& acc, auto&& cursor) { -/// // Do something with acc -/// return acc; -/// } -/// ``` -/// Before the first call to the accumulating function, the accumulated payload will be -/// initialized to the value `init` passed in the constructor. This will also be the -/// type of the payload returned by this cursor. -/// -/// # Passing Cursors -/// -/// Both essential and lookup cursors are passed by value and moved into a member. -/// It is thus important to pass either a temporary, a view, or a moved object to the constructor. -/// It is recommended to pass the ownership through an rvalue, as the cursors will be consumed -/// either way. However, in rare cases when the cursors need to be read after use -/// (for example to get their size or max score) or if essential and lookup cursors are in one -/// container and you want to avoid moving them, you may pass a view such as `gsl::span`. -/// However, it is discouraged in general case due to potential lifetime issues and dangling -/// references. 
-template -struct UnionLookupJoin { - - using essential_cursor_type = typename EssentialCursors::value_type; - using lookup_cursor_type = typename LookupCursors::value_type; - - using payload_type = Payload; - using value_type = std::decay_t())>; - - using essential_iterator_category = - typename std::iterator_traits::iterator_category; - static_assert(std::is_base_of(), - "cursors must be stored in a random access container"); - - UnionLookupJoin(EssentialCursors essential_cursors, - LookupCursors lookup_cursors, - Payload init, - AccumulateFn accumulate, - ThresholdFn above_threshold, - Inspect* inspect = nullptr) - : m_essential_cursors(std::move(essential_cursors)), - m_lookup_cursors(std::move(lookup_cursors)), - m_init(std::move(init)), - m_accumulate(std::move(accumulate)), - m_above_threshold(std::move(above_threshold)), - m_inspect(inspect) - { - if (m_essential_cursors.empty()) { - if (m_lookup_cursors.empty()) { - m_sentinel = std::numeric_limits::max(); - } else { - m_sentinel = min_sentinel(m_lookup_cursors); - } - m_current_value = m_sentinel; - m_current_payload = m_init; - return; - } - m_lookup_cumulative_upper_bound = std::accumulate( - m_lookup_cursors.begin(), m_lookup_cursors.end(), 0.0F, [](auto acc, auto&& cursor) { - return acc + cursor.max_score(); - }); - m_next_docid = min_value(m_essential_cursors); - m_sentinel = min_sentinel(m_essential_cursors); - advance(); - } - - [[nodiscard]] constexpr auto operator*() const noexcept -> value_type - { - return m_current_value; - } - [[nodiscard]] constexpr auto value() const noexcept -> value_type { return m_current_value; } - [[nodiscard]] constexpr auto payload() const noexcept -> Payload const& - { - return m_current_payload; - } - [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_sentinel; } - - constexpr void advance() - { - bool exit = false; - while (not exit) { - if (PISA_UNLIKELY(m_next_docid >= sentinel())) { - m_current_value = sentinel(); - 
m_current_payload = m_init; - return; - } - m_current_payload = m_init; - m_current_value = std::exchange(m_next_docid, sentinel()); - - if constexpr (not std::is_void_v) { - m_inspect->document(); - } - - for (auto&& cursor : m_essential_cursors) { - if (cursor.value() == m_current_value) { - if constexpr (not std::is_void_v) { - m_inspect->posting(); - } - m_current_payload = m_accumulate(m_current_payload, cursor); - cursor.advance(); - } - if (auto docid = cursor.value(); docid < m_next_docid) { - m_next_docid = docid; - } - } - - exit = true; - auto lookup_bound = m_lookup_cumulative_upper_bound; - for (auto&& cursor : m_lookup_cursors) { - if (not m_above_threshold(m_current_payload + lookup_bound)) { - exit = false; - break; - } - cursor.advance_to_geq(m_current_value); - if constexpr (not std::is_void_v) { - m_inspect->lookup(); - } - if (cursor.value() == m_current_value) { - m_current_payload = m_accumulate(m_current_payload, cursor); - } - lookup_bound -= cursor.max_score(); - } - } - m_position += 1; - } - - [[nodiscard]] constexpr auto position() const noexcept -> std::size_t { return m_position; } - [[nodiscard]] constexpr auto empty() const noexcept -> bool - { - return m_current_value >= sentinel(); - } - - private: - EssentialCursors m_essential_cursors; - LookupCursors m_lookup_cursors; - payload_type m_init; - AccumulateFn m_accumulate; - ThresholdFn m_above_threshold; - - value_type m_current_value{}; - value_type m_sentinel{}; - payload_type m_current_payload{}; - std::uint32_t m_next_docid{}; - payload_type m_previous_threshold{}; - payload_type m_lookup_cumulative_upper_bound{}; - std::size_t m_position = 0; - - Inspect* m_inspect; -}; - -/// Convenience function to construct a `UnionLookupJoin` cursor operator. -/// See the struct documentation for more information. 
-template -auto join_union_lookup(EssentialCursors essential_cursors, - LookupCursors lookup_cursors, - Payload init, - AccumulateFn accumulate, - ThresholdFn threshold, - Inspect* inspect) +template +auto filter_bigram_lookup_cursors( + Index const& index, Scorer&& scorer, LookupCursors&& lookup_cursors, TermId left, TermId right) { - return UnionLookupJoin(std::move(essential_cursors), - std::move(lookup_cursors), - std::move(init), - std::move(accumulate), - std::move(threshold), - inspect); + return ranges::views::filter( + lookup_cursors, + [&](auto&& cursor) { return cursor.label() != left && cursor.label() != right; }) + //| ranges::views::transform([](auto&& cursor) { return ref(cursor); }) + | ranges::views::transform( + [&](auto&& cursor) { return index.max_scored_cursor(cursor.label(), scorer); }) + | ranges::to_vector; } -/// Processes documents with the Union-Lookup method. -/// -/// This is an optimized version that works **only on single-term posting lists**. -/// It will throw an exception if bigram selections are passed to it. -template -auto unigram_union_lookup(Query const& query, - Index const& index, - topk_queue topk, - Scorer&& scorer, - [[maybe_unused]] Inspect* inspect = nullptr) +/// This algorithm... 
+template +auto lookup_union(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + InspectUnigram* inspect_unigram = nullptr, + InspectBigram* inspect_bigram = nullptr) { - using cursor_type = decltype(index.max_scored_cursor(0, scorer)); - using payload_type = decltype(std::declval().payload()); + using bigram_cursor_type = std::decay_t; auto const& term_ids = query.get_term_ids(); if (term_ids.empty()) { return topk; } + auto threshold = query.get_threshold(); + topk.set_threshold(threshold); + auto is_above_threshold = [&](auto score) { return topk.would_enter(score); }; + auto const& selections = query.get_selections(); - runtime_assert(selections.bigrams.empty()).or_exit("This algorithm only supports unigrams"); + auto& essential_unigrams = selections.unigrams; + auto& essential_bigrams = selections.bigrams; - topk.set_threshold(query.get_threshold()); + if constexpr (not std::is_void_v) { + inspect_unigram->essential(essential_unigrams.size()); + } + if constexpr (not std::is_void_v) { + inspect_bigram->essential(essential_bigrams.size()); + } auto non_essential_terms = - ranges::views::set_difference(term_ids, selections.unigrams) | ranges::to_vector; + ranges::views::set_difference(term_ids, essential_unigrams) | ranges::to_vector; + + auto lookup_cursors = index.cursors(non_essential_terms, [&](auto&& index, auto term) { + return label(index.max_scored_cursor(term, scorer), term); + }); + ranges::sort(lookup_cursors, std::greater{}, func::max_score{}); + auto unigram_cursor = + join_union_lookup(index.max_scored_cursors(gsl::make_span(essential_unigrams), scorer), + gsl::make_span(lookup_cursors), + 0.0F, + accumulators::Add{}, + is_above_threshold, + inspect_unigram); + + using lookup_transform_type = + LookupTransform; + using transform_payload_cursor_type = + TransformPayloadCursor; - std::vector lookup_cursors = index.max_scored_cursors(non_essential_terms, scorer); - ranges::sort(lookup_cursors, [](auto&& l, auto&& r) { 
return l.max_score() > r.max_score(); }); - std::vector essential_cursors = - index.max_scored_cursors(selections.unigrams, scorer); + std::vector bigram_cursors; - if constexpr (not std::is_void_v) { - inspect->essential(essential_cursors.size()); + for (auto [left, right] : essential_bigrams) { + auto cursor = index.scored_bigram_cursor(left, right, scorer); + if (not cursor) { + throw std::runtime_error(fmt::format("Bigram not found: <{}, {}>", left, right)); + } + auto bigram_lookup_cursors = + filter_bigram_lookup_cursors(index, scorer, lookup_cursors, left, right); + auto lookup_cursors_upper_bound = + std::accumulate(bigram_lookup_cursors.begin(), + bigram_lookup_cursors.end(), + 0.0F, + [](auto acc, auto&& cursor) { return acc + cursor.max_score(); }); + bigram_cursors.emplace_back(std::move(*cursor.take()), + LookupTransform(std::move(bigram_lookup_cursors), + lookup_cursors_upper_bound, + is_above_threshold, + inspect_bigram)); } - auto joined = join_union_lookup( - std::move(essential_cursors), - std::move(lookup_cursors), - payload_type{}, - accumulators::Add{}, - [&](auto score) { return topk.would_enter(score); }, - inspect); - v1::for_each(joined, [&](auto&& cursor) { - if constexpr (not std::is_void_v) { + auto accumulate = [&](float acc, auto& cursor, [[maybe_unused]] auto idx) { + return acc == 0 ? cursor.payload() : acc; + }; + auto bigram_cursor = union_merge( + std::move(bigram_cursors), 0.0F, [&](float acc, auto& cursor, [[maybe_unused]] auto idx) { + if constexpr (not std::is_void_v) { + inspect_bigram->posting(); + } + return acc == 0 ? 
cursor.payload() : acc; + }); + auto merged = v1::variadic_union_merge( + 0.0F, + std::make_tuple(std::move(unigram_cursor), std::move(bigram_cursor)), + std::make_tuple(accumulate, accumulate)); + + v1::for_each(merged, [&](auto&& cursor) { + if constexpr (not std::is_void_v) { if (topk.insert(cursor.payload(), cursor.value())) { - inspect->insert(); + inspect_unigram->insert(); } } else { topk.insert(cursor.payload(), cursor.value()); @@ -277,135 +139,41 @@ auto unigram_union_lookup(Query const& query, return topk; } -/// This is a special case of Union-Lookup algorithm that does not use user-defined selections, -/// but rather uses the same way of determining essential list as Maxscore does. -/// The difference is that this algorithm will never update the threshold whereas Maxscore will -/// try to improve the estimate after each accumulated document. -template -auto maxscore_union_lookup(Query const& query, - Index const& index, - topk_queue topk, - Scorer&& scorer, - [[maybe_unused]] Inspect* inspect = nullptr) +template +auto accumulate_cursor_to_heap(Cursor&& cursor, + std::size_t k, + float threshold = 0.0, + InspectInserts* inspect_inserts = nullptr, + InspectPostings* inspect_postings = nullptr) { - using cursor_type = decltype(index.max_scored_cursor(0, scorer)); - using payload_type = decltype(std::declval().payload()); - - auto const& term_ids = query.get_term_ids(); - if (term_ids.empty()) { - return topk; - } - auto threshold = query.get_threshold(); - topk.set_threshold(threshold); - - auto cursors = index.max_scored_cursors(gsl::make_span(term_ids), scorer); - ranges::sort(cursors, [](auto&& lhs, auto&& rhs) { return lhs.max_score() < rhs.max_score(); }); - - std::vector upper_bounds(cursors.size()); - upper_bounds[0] = cursors[0].max_score(); - for (size_t i = 1; i < cursors.size(); ++i) { - upper_bounds[i] = upper_bounds[i - 1] + cursors[i].max_score(); - } - std::size_t non_essential_count = 0; - while (non_essential_count < cursors.size() && 
upper_bounds[non_essential_count] <= threshold) { - non_essential_count += 1; - } - if constexpr (not std::is_void_v) { - inspect->essential(cursors.size() - non_essential_count); - } - - std::vector essential_cursors; - std::move(std::next(cursors.begin(), non_essential_count), - cursors.end(), - std::back_inserter(essential_cursors)); - cursors.erase(std::next(cursors.begin(), non_essential_count), cursors.end()); - std::reverse(cursors.begin(), cursors.end()); - - auto joined = join_union_lookup( - std::move(essential_cursors), - std::move(cursors), - payload_type{}, - accumulators::Add{}, - [&](auto score) { return topk.would_enter(score); }, - inspect); - v1::for_each(joined, [&](auto&& cursor) { - if constexpr (not std::is_void_v) { - if (topk.insert(cursor.payload(), cursor.value())) { - inspect->insert(); + topk_queue heap(k); + heap.set_threshold(threshold); + v1::for_each(cursor, [&](auto&& cursor) { + if constexpr (not std::is_void_v) { + inspect_postings->posting(); + } + if constexpr (not std::is_void_v) { + if (heap.insert(cursor.payload(), cursor.value())) { + inspect_inserts->insert(); } } else { - topk.insert(cursor.payload(), cursor.value()); + heap.insert(cursor.payload(), cursor.value()); } }); - return topk; + return heap; } -/// This callable transforms a cursor by performing lookups to the current document -/// in the given lookup cursors, and then adding the scores that were found. -/// It uses the same short-circuiting rules before each lookup as `UnionLookupJoin`. 
-template -struct LookupTransform { - - LookupTransform(std::vector lookup_cursors, - float lookup_cursors_upper_bound, - AboveThresholdFn above_threshold, - Inspector* inspect = nullptr) - : m_lookup_cursors(std::move(lookup_cursors)), - m_lookup_cursors_upper_bound(lookup_cursors_upper_bound), - m_above_threshold(std::move(above_threshold)), - m_inspect(inspect) - { - } - - auto operator()(Cursor& cursor) - { - if constexpr (not std::is_void_v) { - m_inspect->document(); - m_inspect->posting(); - } - auto docid = cursor.value(); - auto scores = cursor.payload(); - float score = std::get<0>(scores) + std::get<1>(scores); - auto upper_bound = score + m_lookup_cursors_upper_bound; - for (auto& lookup_cursor : m_lookup_cursors) { - if (not m_above_threshold(upper_bound)) { - return score; - } - lookup_cursor.advance_to_geq(docid); - if constexpr (not std::is_void_v) { - m_inspect->lookup(); - } - if (PISA_UNLIKELY(lookup_cursor.value() == docid)) { - auto partial_score = lookup_cursor.payload(); - score += partial_score; - upper_bound += partial_score; - } - upper_bound -= lookup_cursor.max_score(); - } - return score; - } - - private: - std::vector m_lookup_cursors; - float m_lookup_cursors_upper_bound; - AboveThresholdFn m_above_threshold; - Inspector* m_inspect; -}; - /// This algorithm... 
template -auto lookup_union(Query const& query, - Index const& index, - topk_queue topk, - Scorer&& scorer, - InspectUnigram* inspect_unigram = nullptr, - InspectBigram* inspect_bigram = nullptr) +auto lookup_union_eaat(Query const& query, + Index const& index, + topk_queue topk, + Scorer&& scorer, + InspectUnigram* inspect_unigram = nullptr, + InspectBigram* inspect_bigram = nullptr) { using bigram_cursor_type = std::decay_t; using lookup_cursor_type = std::decay_t; @@ -448,14 +216,17 @@ auto lookup_union(Query const& query, inspect_unigram); }(); - using lookup_transform_type = LookupTransform; + auto unigram_heap = + accumulate_cursor_to_heap(unigram_cursor, topk.size(), threshold, inspect_unigram); + + using lookup_transform_type = + LookupTransform; using transform_payload_cursor_type = TransformPayloadCursor; - std::vector bigram_cursors; + std::vector entries(unigram_heap.topk().begin(), + unigram_heap.topk().end()); + for (auto [left, right] : essential_bigrams) { auto cursor = index.scored_bigram_cursor(left, right, scorer); if (not cursor) { @@ -474,31 +245,39 @@ auto lookup_union(Query const& query, return acc + cursor.max_score(); }); - bigram_cursors.emplace_back(std::move(*cursor.take()), - lookup_transform_type(std::move(lookup_cursors), - lookup_cursors_upper_bound, - is_above_threshold, - inspect_bigram)); + auto heap = accumulate_cursor_to_heap( + transform_payload_cursor_type(std::move(*cursor.take()), + lookup_transform_type(std::move(lookup_cursors), + lookup_cursors_upper_bound, + is_above_threshold, + inspect_bigram)), + topk.size(), + threshold, + inspect_bigram, + inspect_bigram); + std::copy(heap.topk().begin(), heap.topk().end(), std::back_inserter(entries)); + } + std::sort(entries.begin(), entries.end(), [](auto&& lhs, auto&& rhs) { + if (lhs.second == rhs.second) { + return lhs.first > rhs.first; + } + return lhs.second < rhs.second; + }); + auto end = std::unique(entries.begin(), entries.end(), [](auto&& lhs, auto&& rhs) { + return 
lhs.second == rhs.second; + }); + entries.erase(end, entries.end()); + std::sort(entries.begin(), entries.end(), [](auto&& lhs, auto&& rhs) { + return lhs.first > rhs.first; + }); + if (entries.size() > topk.size()) { + entries.erase(std::next(entries.begin(), topk.size()), entries.end()); } - auto accumulate = [&](float acc, auto& cursor, [[maybe_unused]] auto idx) { - return std::max(acc, cursor.payload()); - }; - auto bigram_cursor = union_merge(std::move(bigram_cursors), 0.0F, accumulate); - auto merged = v1::variadic_union_merge( - 0.0F, - std::make_tuple(std::move(unigram_cursor), std::move(bigram_cursor)), - std::make_tuple(accumulate, accumulate)); + for (auto entry : entries) { + topk.insert(entry.first, entry.second); + } - v1::for_each(merged, [&](auto&& cursor) { - if constexpr (not std::is_void_v) { - if (topk.insert(cursor.payload(), cursor.value())) { - inspect_unigram->insert(); - } - } else { - topk.insert(cursor.payload(), cursor.value()); - } - }); return topk; } @@ -592,7 +371,7 @@ auto union_lookup(Query const& query, auto accumulate = [&](auto& acc, auto& cursor, auto /* union_idx */) { auto payload = cursor.payload(); for (auto idx = 0; idx < acc.size(); idx += 1) { - if (acc[idx] == 0.0F) { + if (acc[idx] == 0) { acc[idx] = payload[idx]; } } @@ -830,17 +609,21 @@ auto union_lookup_plus(Query const& query, return mus; }(); + auto const state_mask = (1U << term_count) - 1; + v1::for_each(merged, [&](auto& cursor) { if constexpr (not std::is_void_v) { inspect->document(); } auto docid = cursor.value(); auto scores = cursor.payload(); - auto score = std::accumulate(scores.begin(), scores.end(), 0.0F, std::plus{}); + // auto score = std::accumulate(scores.begin(), scores.end(), 0.0F, std::plus{}); + float score = 0.0F; std::uint32_t state = essential_unigrams.size() << term_count; for (auto pos = 0U; pos < term_count; pos += 1) { if (scores[pos] > 0) { + score += scores[pos]; state |= 1U << pos; } } @@ -855,11 +638,11 @@ auto 
union_lookup_plus(Query const& query, if constexpr (not std::is_void_v) { inspect->lookup(); } - if (PISA_UNLIKELY(lookup_cursor.value() == docid)) { + if (lookup_cursor.value() == docid) { score += lookup_cursor.payload(); state |= (1U << next_idx); } - state = (state & ((1U << term_count) - 1)) + ((next_idx + 1) << term_count); + state = (state & state_mask) + ((next_idx + 1) << term_count); next_idx = next_lookup[state]; } if constexpr (not std::is_void_v) { @@ -874,115 +657,26 @@ auto union_lookup_plus(Query const& query, } template -struct BaseUnionLookupInspect { - BaseUnionLookupInspect(Index const& index, Scorer scorer) - : m_index(index), m_scorer(std::move(scorer)) - { - std::cout << fmt::format("documents\tpostings\tinserts\tlookups\tessential_lists\n"); - } - - void reset_current() - { - m_current_documents = 0; - m_current_postings = 0; - m_current_lookups = 0; - m_current_inserts = 0; - m_essential_lists = 0; - } - - virtual void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) = 0; - - void operator()(Query const& query) - { - auto const& term_ids = query.get_term_ids(); - if (term_ids.empty()) { - return; - } - using cursor_type = decltype(m_index.max_scored_cursor(0, m_scorer)); - using value_type = decltype(m_index.max_scored_cursor(0, m_scorer).value()); - - reset_current(); - run(query, m_index, m_scorer, topk_queue(query.k())); - std::cout << fmt::format("{}\t{}\t{}\t{}\t{}\n", - m_current_documents, - m_current_postings, - m_current_inserts, - m_current_lookups, - m_current_essential_lists); - m_documents += m_current_documents; - m_postings += m_current_postings; - m_lookups += m_current_lookups; - m_inserts += m_current_inserts; - m_essential_lists += m_current_essential_lists; - m_count += 1; - } - - void summarize() && - { - std::cerr << fmt::format( - "=== SUMMARY ===\nAverage:\n- documents:\t{}\n" - "- postings:\t{}\n- inserts:\t{}\n- lookups:\t{}\n- essential lists:\t{}\n", - static_cast(m_documents) / 
m_count, - static_cast(m_postings) / m_count, - static_cast(m_inserts) / m_count, - static_cast(m_lookups) / m_count, - static_cast(m_essential_lists) / m_count); - } - - void document() { m_current_documents += 1; } - void posting() { m_current_postings += 1; } - void lookup() { m_current_lookups += 1; } - void insert() { m_current_inserts += 1; } - void essential(std::size_t n) { m_current_essential_lists = n; } - - private: - std::size_t m_current_documents = 0; - std::size_t m_current_postings = 0; - std::size_t m_current_lookups = 0; - std::size_t m_current_inserts = 0; - std::size_t m_current_essential_lists = 0; - - std::size_t m_documents = 0; - std::size_t m_postings = 0; - std::size_t m_lookups = 0; - std::size_t m_inserts = 0; - std::size_t m_count = 0; - std::size_t m_essential_lists = 0; - Index const& m_index; - Scorer m_scorer; -}; - -template -struct MaxscoreUnionLookupInspect : public BaseUnionLookupInspect { - MaxscoreUnionLookupInspect(Index const& index, Scorer scorer) - : BaseUnionLookupInspect(index, std::move(scorer)) - { - } - void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override - { - maxscore_union_lookup(query, index, std::move(topk), scorer, this); - } -}; - -template -struct UnigramUnionLookupInspect : public BaseUnionLookupInspect { - UnigramUnionLookupInspect(Index const& index, Scorer scorer) - : BaseUnionLookupInspect(index, std::move(scorer)) - { - } - void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override +struct InspectUnionLookup : Inspect { + + InspectUnionLookup(Index const& index, Scorer const& scorer) + : Inspect(index, scorer) { - unigram_union_lookup(query, index, std::move(topk), scorer, this); } -}; -template -struct UnionLookupInspect : public BaseUnionLookupInspect { - UnionLookupInspect(Index const& index, Scorer scorer) - : BaseUnionLookupInspect(index, std::move(scorer)) - { - } - void run(Query const& query, Index const& index, Scorer& scorer, 
topk_queue topk) override + void run(Query const& query, Index const& index, Scorer const& scorer, topk_queue topk) override { if (query.selections()->bigrams.empty()) { unigram_union_lookup(query, index, std::move(topk), scorer, this); @@ -995,12 +689,26 @@ struct UnionLookupInspect : public BaseUnionLookupInspect { }; template -struct UnionLookupPlusInspect : public BaseUnionLookupInspect { - UnionLookupPlusInspect(Index const& index, Scorer scorer) - : BaseUnionLookupInspect(index, std::move(scorer)) +struct InspectUnionLookupPlus : Inspect { + + InspectUnionLookupPlus(Index const& index, Scorer const& scorer) + : Inspect(index, scorer) { } - void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override + + void run(Query const& query, Index const& index, Scorer const& scorer, topk_queue topk) override { if (query.selections()->bigrams.empty()) { unigram_union_lookup(query, index, std::move(topk), scorer, this); @@ -1012,158 +720,64 @@ struct UnionLookupPlusInspect : public BaseUnionLookupInspect { } }; -struct ComponentInspect { - void reset_current() - { - m_documents += m_current_documents; - m_postings += m_current_postings; - m_lookups += m_current_lookups; - m_inserts += m_current_inserts; - m_essential_lists += m_current_essential_lists; - m_count += 1; - - m_current_documents = 0; - m_current_postings = 0; - m_current_lookups = 0; - m_current_inserts = 0; - m_essential_lists = 0; - } - - void document() { m_current_documents += 1; } - void posting() { m_current_postings += 1; } - void lookup() { m_current_lookups += 1; } - void insert() { m_current_inserts += 1; } - void essential(std::size_t n) { m_current_essential_lists = n; } - - [[nodiscard]] auto current_documents() const { return m_current_documents; } - [[nodiscard]] auto current_postings() const { return m_current_postings; } - [[nodiscard]] auto current_lookups() const { return m_current_lookups; } - [[nodiscard]] auto current_inserts() const { return 
m_current_inserts; } - [[nodiscard]] auto current_essential_lists() const { return m_current_essential_lists; } - - [[nodiscard]] auto documents() const { return m_documents; } - [[nodiscard]] auto postings() const { return m_postings; } - [[nodiscard]] auto lookups() const { return m_lookups; } - [[nodiscard]] auto inserts() const { return m_inserts; } - [[nodiscard]] auto essential_lists() const { return m_essential_lists; } - - private: - std::size_t m_current_documents = 0; - std::size_t m_current_postings = 0; - std::size_t m_current_lookups = 0; - std::size_t m_current_inserts = 0; - std::size_t m_current_essential_lists = 0; - - std::size_t m_documents = 0; - std::size_t m_postings = 0; - std::size_t m_lookups = 0; - std::size_t m_inserts = 0; - std::size_t m_count = 0; - std::size_t m_essential_lists = 0; -}; +using LookupUnionComponent = InspectMany; template -struct LookupUnionInspector : public BaseUnionLookupInspect { - LookupUnionInspector(Index const& index, Scorer scorer) - : BaseUnionLookupInspect(index, std::move(scorer)) +struct InspectLookupUnion : Inspect> { + + InspectLookupUnion(Index const& index, Scorer scorer) + : Inspect>(index, scorer) { } - void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) override + + void run(Query const& query, Index const& index, Scorer const& scorer, topk_queue topk) override { if (query.selections()->bigrams.empty()) { - unigram_union_lookup(query, index, std::move(topk), scorer, this); + unigram_union_lookup(query, + index, + std::move(topk), + scorer, + InspectPartitioned::first()); } else { - lookup_union(query, index, std::move(topk), scorer, this); + lookup_union(query, + index, + std::move(topk), + scorer, + InspectPartitioned::first(), + InspectPartitioned::second()); } } }; template -struct LookupUnionInspect { - LookupUnionInspect(Index const& index, Scorer scorer) - : m_index(index), m_scorer(std::move(scorer)) - { - std::cout << fmt::format( - 
"documents\tpostings\tinserts\tlookups\tessential_lists\t" - "postings-uni\tlookups-uni\tpostings-bi\tlookups-bi\n"); - } +struct InspectLookupUnionEaat : Inspect> { - void reset_current() + InspectLookupUnionEaat(Index const& index, Scorer scorer) + : Inspect>(index, scorer) { - m_unigram_inspect.reset_current(); - m_bigram_inspect.reset_current(); - m_count += 1; } - void run(Query const& query, Index const& index, Scorer& scorer, topk_queue topk) + void run(Query const& query, Index const& index, Scorer const& scorer, topk_queue topk) override { if (query.selections()->bigrams.empty()) { - unigram_union_lookup(query, index, std::move(topk), scorer, &m_unigram_inspect); + unigram_union_lookup(query, + index, + std::move(topk), + scorer, + InspectPartitioned::first()); } else { - lookup_union( - query, index, std::move(topk), scorer, &m_unigram_inspect, &m_bigram_inspect); + lookup_union_eaat(query, + index, + std::move(topk), + scorer, + InspectPartitioned::first(), + InspectPartitioned::second()); } } - - void operator()(Query const& query) - { - auto const& term_ids = query.get_term_ids(); - if (term_ids.empty()) { - return; - } - using cursor_type = decltype(m_index.max_scored_cursor(0, m_scorer)); - using value_type = decltype(m_index.max_scored_cursor(0, m_scorer).value()); - - run(query, m_index, m_scorer, topk_queue(query.k())); - std::cout << fmt::format( - "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n", - m_unigram_inspect.current_documents() + m_bigram_inspect.current_documents(), - m_unigram_inspect.current_postings() + m_bigram_inspect.current_postings(), - m_unigram_inspect.current_inserts() + m_bigram_inspect.current_inserts(), - m_unigram_inspect.current_lookups() + m_bigram_inspect.current_lookups(), - m_unigram_inspect.current_essential_lists() - + m_bigram_inspect.current_essential_lists(), - m_unigram_inspect.current_postings(), - m_unigram_inspect.current_lookups(), - m_bigram_inspect.current_postings(), - m_bigram_inspect.current_lookups()); - 
reset_current(); - } - - void summarize() && - { - auto documents = m_unigram_inspect.documents() + m_bigram_inspect.documents(); - auto postings = m_unigram_inspect.postings() + m_bigram_inspect.postings(); - auto inserts = m_unigram_inspect.inserts() + m_bigram_inspect.inserts(); - auto lookups = m_unigram_inspect.lookups() + m_bigram_inspect.lookups(); - auto essential_lists = - m_unigram_inspect.essential_lists() + m_bigram_inspect.essential_lists(); - auto uni_postings = m_unigram_inspect.postings(); - auto uni_lookups = m_unigram_inspect.lookups(); - auto bi_postings = m_bigram_inspect.postings(); - auto bi_lookups = m_bigram_inspect.lookups(); - std::cerr << fmt::format( - "=== SUMMARY ===\nAverage:\n- documents:\t{}\n" - "- postings:\t{}\n- inserts:\t{}\n- lookups:\t{}\n- essential lists:\t{}\n" - "- uni-postings:\t{}\n- uni-lookups:\t{}\n" - "- bi-postings:\t{}\n- bi-lookups:\t{}\n", - static_cast(documents) / m_count, - static_cast(postings) / m_count, - static_cast(inserts) / m_count, - static_cast(lookups) / m_count, - static_cast(essential_lists) / m_count, - static_cast(uni_postings) / m_count, - static_cast(uni_lookups) / m_count, - static_cast(bi_postings) / m_count, - static_cast(bi_lookups) / m_count); - } - - private: - ComponentInspect m_unigram_inspect; - ComponentInspect m_bigram_inspect; - std::size_t m_count = 0; - Index const& m_index; - Scorer m_scorer; }; } // namespace pisa::v1 diff --git a/include/pisa/v1/union_lookup_join.hpp b/include/pisa/v1/union_lookup_join.hpp new file mode 100644 index 000000000..416df35a2 --- /dev/null +++ b/include/pisa/v1/union_lookup_join.hpp @@ -0,0 +1,277 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace pisa::v1 { + +namespace func { + + /// Calls `max_score()` method on any passed object. The default projection for + /// `maxscore_partition`. 
+ struct max_score { + template + auto operator()(Cursor&& cursor) -> float + { + return cursor.max_score(); + } + }; + +} // namespace func + +/// Partitions a list of cursors into essential and non-essential parts as in MaxScore algorithm +/// first proposed by [Turtle and +/// Flood](https://www.sciencedirect.com/science/article/pii/030645739500020H). +/// +/// # Details +/// +/// This function takes a span of (max-score) cursors that participate in a query, and the current +/// threshold. By default, it retrieves the max scores from cursors by calling `max_score()` method. +/// However, a different can be used by passing a projection. For example, if instead of +/// partitioning actual posting list cursors, you want to partition a vector of pairs `(term, +/// max-score)`, then you may pass +/// `[](auto&& c) { return c.second; }` as the projection. +/// +/// # Complexity +/// +/// Note that this function **will** sort the cursors by their max scores to ensure correct +/// partitioning, and therefore it may not be suitable to update an existing partition. +template +auto maxscore_partition(gsl::span cursors, float threshold, P projection = func::max_score{}) + -> std::pair, gsl::span> +{ + ranges::sort(cursors.begin(), cursors.end(), std::less{}, projection); + float bound = 0; + auto mid = ranges::find_if_not(cursors, [&](auto&& cursor) { + bound += projection(cursor); + return bound < threshold; + }); + auto non_essential_count = std::distance(cursors.begin(), mid); + return std::make_pair(cursors.first(non_essential_count), cursors.subspan(non_essential_count)); +} + +/// This cursor operator takes a number of essential cursors (in an arbitrary order) +/// and a list of lookup cursors. The documents traversed will be in the DaaT order, +/// and the following documents will be skipped: +/// - documents that do not appear in any of the essential cursors, +/// - documents that at the moment of their traversal are irrelevant (see below). 
+/// +/// # Threshold +/// +/// This operator takes a callable object that returns `true` only if a given score +/// has a chance to be in the final result set. It is used to decide whether or not +/// to perform further lookups for the given document. The score passed to the function +/// is such that when it returns `false`, we know that it will return `false` for the +/// rest of the lookup cursors, and therefore we can skip that document. +/// Note that such document will never be returned by this cursor. Instead, we will +/// proceed to the next document to see if it can land in the final result set, and so on. +/// +/// # Accumulating Scores +/// +/// Another parameter taken by this operator is a callable that accumulates payloads +/// for one document ID. The function is very similar to what you would pass to +/// `std::accumulate`: it takes the accumulator (either by reference or value), +/// and a reference to the cursor. It must return an updated accumulator. +/// For example, a simple accumulator that simply sums all payloads for each document, +/// can be: `[](float score, auto&& cursor) { return score + cursor.payload(); }`. +/// Note that you can accumulate "heavier" objects by taking and returning a reference: +/// ``` +/// [](auto& acc, auto&& cursor) { +/// // Do something with acc +/// return acc; +/// } +/// ``` +/// Before the first call to the accumulating function, the accumulated payload will be +/// initialized to the value `init` passed in the constructor. This will also be the +/// type of the payload returned by this cursor. +/// +/// # Passing Cursors +/// +/// Both essential and lookup cursors are passed by value and moved into a member. +/// It is thus important to pass either a temporary, a view, or a moved object to the constructor. +/// It is recommended to pass the ownership through an rvalue, as the cursors will be consumed +/// either way. 
However, in rare cases when the cursors need to be read after use +/// (for example to get their size or max score) or if essential and lookup cursors are in one +/// container and you want to avoid moving them, you may pass a view such as `gsl::span`. +/// However, it is discouraged in general case due to potential lifetime issues and dangling +/// references. +template +struct UnionLookupJoin { + + using essential_cursor_type = typename EssentialCursors::value_type; + using lookup_cursor_type = typename LookupCursors::value_type; + + using payload_type = Payload; + using value_type = std::decay_t())>; + + using essential_iterator_category = + typename std::iterator_traits::iterator_category; + + static_assert(std::is_base_of(), + "cursors must be stored in a random access container"); + + UnionLookupJoin(EssentialCursors essential_cursors, + LookupCursors lookup_cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn above_threshold, + Inspect* inspect = nullptr) + : m_essential_cursors(std::move(essential_cursors)), + m_lookup_cursors(std::move(lookup_cursors)), + m_init(std::move(init)), + m_accumulate(std::move(accumulate)), + m_above_threshold(std::move(above_threshold)), + m_inspect(inspect) + { + if (m_essential_cursors.empty()) { + if (m_lookup_cursors.empty()) { + m_sentinel = std::numeric_limits::max(); + } else { + m_sentinel = min_sentinel(m_lookup_cursors); + } + m_current_value = m_sentinel; + m_current_payload = m_init; + return; + } + m_lookup_cumulative_upper_bound = std::accumulate( + m_lookup_cursors.begin(), m_lookup_cursors.end(), 0.0F, [](auto acc, auto&& cursor) { + return acc + cursor.max_score(); + }); + m_next_docid = min_value(m_essential_cursors); + m_sentinel = min_sentinel(m_essential_cursors); + advance(); + } + + [[nodiscard]] constexpr auto operator*() const noexcept -> value_type + { + return m_current_value; + } + [[nodiscard]] constexpr auto value() const noexcept -> value_type { return m_current_value; } + 
[[nodiscard]] constexpr auto payload() const noexcept -> Payload const& + { + return m_current_payload; + } + [[nodiscard]] constexpr auto sentinel() const noexcept -> std::uint32_t { return m_sentinel; } + + constexpr void advance() + { + bool exit = false; + while (not exit) { + if (PISA_UNLIKELY(m_next_docid >= sentinel())) { + m_current_value = sentinel(); + m_current_payload = m_init; + return; + } + m_current_payload = m_init; + m_current_value = std::exchange(m_next_docid, sentinel()); + + if constexpr (not std::is_void_v) { + m_inspect->document(); + } + + for (auto&& cursor : m_essential_cursors) { + if (cursor.value() == m_current_value) { + if constexpr (not std::is_void_v) { + m_inspect->posting(); + } + m_current_payload = m_accumulate(m_current_payload, cursor); + cursor.advance(); + } + if (auto docid = cursor.value(); docid < m_next_docid) { + m_next_docid = docid; + } + } + + exit = true; + auto lookup_bound = m_lookup_cumulative_upper_bound; + for (auto&& cursor : m_lookup_cursors) { + //if (m_current_value == 2288) { + // std::cout << fmt::format( + // "[checking] doc: {}\tscore: {}\tbound: {} (is above = {})\tms = {}\n", + // m_current_value, + // m_current_payload, + // m_current_payload + lookup_bound, + // m_above_threshold(m_current_payload + lookup_bound), + // cursor.max_score()); + //} + if (not m_above_threshold(m_current_payload + lookup_bound)) { + exit = false; + break; + } + cursor.advance_to_geq(m_current_value); + //std::cout << fmt::format( + // "doc: {}\tbound: {}\n", m_current_value, m_current_payload + lookup_bound); + if constexpr (not std::is_void_v) { + m_inspect->lookup(); + } + if (cursor.value() == m_current_value) { + m_current_payload = m_accumulate(m_current_payload, cursor); + } + lookup_bound -= cursor.max_score(); + } + } + m_position += 1; + } + + [[nodiscard]] constexpr auto position() const noexcept -> std::size_t { return m_position; } + [[nodiscard]] constexpr auto empty() const noexcept -> bool + { + return 
m_current_value >= sentinel(); + } + + private: + EssentialCursors m_essential_cursors; + LookupCursors m_lookup_cursors; + payload_type m_init; + AccumulateFn m_accumulate; + ThresholdFn m_above_threshold; + + value_type m_current_value{}; + value_type m_sentinel{}; + payload_type m_current_payload{}; + std::uint32_t m_next_docid{}; + payload_type m_previous_threshold{}; + payload_type m_lookup_cumulative_upper_bound{}; + std::size_t m_position = 0; + + Inspect* m_inspect; +}; + +/// Convenience function to construct a `UnionLookupJoin` cursor operator. +/// See the struct documentation for more information. +template +auto join_union_lookup(EssentialCursors essential_cursors, + LookupCursors lookup_cursors, + Payload init, + AccumulateFn accumulate, + ThresholdFn threshold, + Inspect* inspect = nullptr) +{ + return UnionLookupJoin(std::move(essential_cursors), + std::move(lookup_cursors), + std::move(init), + std::move(accumulate), + std::move(threshold), + inspect); +} + +} // namespace pisa::v1 diff --git a/script/cw09b-url.sh b/script/cw09b-url.sh index 7dbc48ea3..225f47e36 100644 --- a/script/cw09b-url.sh +++ b/script/cw09b-url.sh @@ -5,11 +5,12 @@ FWD="/data/CW09B/CW09B.url.fwd" ENCODING="simdbp" # v1 BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" THREADS=4 -QUERIES="/home/michal/05.clean.shuf.test" +#QUERIES="/home/michal/05.clean.shuf.test" +QUERIES="/home/michal/biscorer/data/queries/topics.web.51-200" K=1000 OUTPUT_DIR="/data/michal/intersect/cw09b-url" FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" PAIRS=${FILTERED_QUERIES} -PAIR_INDEX_BASENAME="${BASENAME}-pair-2" -THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.test" +PAIR_INDEX_BASENAME="${BASENAME}-pair" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.topics.web.51-200" QUERY_LIMIT=5000 diff --git a/script/cw09b.sh b/script/cw09b.sh index 90a4cf0f4..029d6113a 100644 --- 
a/script/cw09b.sh +++ b/script/cw09b.sh @@ -6,19 +6,21 @@ source $1 echo "Running experiment with the following environment:" echo "" -echo " PISA_BIN = ${PISA_BIN}" -echo " INTERSECT_BIN = ${INTERSECT_BIN}" -echo " BINARY_FREQ_COLL = ${BINARY_FREQ_COLL}" -echo " FWD = ${FWD}" -echo " BASENAME = ${BASENAME}" -echo " THREADS = ${THREADS}" -echo " ENCODING = ${ENCODING}" -echo " OUTPUT_DIR = ${OUTPUT_DIR}" -echo " QUERIES = ${QUERIES}" -echo " FILTERED_QUERIES = ${FILTERED_QUERIES}" -echo " K = ${K}" -echo " THRESHOLDS = ${THRESHOLDS}" -echo " QUERY_LIMIT = ${QUERY_LIMIT}" +echo " PISA_BIN = ${PISA_BIN}" +echo " INTERSECT_BIN = ${INTERSECT_BIN}" +echo " BINARY_FREQ_COLL = ${BINARY_FREQ_COLL}" +echo " FWD = ${FWD}" +echo " BASENAME = ${BASENAME}" +echo " THREADS = ${THREADS}" +echo " ENCODING = ${ENCODING}" +echo " OUTPUT_DIR = ${OUTPUT_DIR}" +echo " QUERIES = ${QUERIES}" +echo " FILTERED_QUERIES = ${FILTERED_QUERIES}" +echo " K = ${K}" +echo " THRESHOLDS = ${THRESHOLDS}" +echo " QUERY_LIMIT = ${QUERY_LIMIT}" +echo " PAIRS = ${PAIRS}" +echo " PAIR_INDEX_BASENAME = ${PAIR_INDEX_BASENAME}" echo "" set -x @@ -53,12 +55,12 @@ mkdir -p ${OUTPUT_DIR} # > ${OUTPUT_DIR}/intersections.jl # Select unigrams -${INTERSECT_BIN} -m unigram ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.1 -${INTERSECT_BIN} -m unigram ${OUTPUT_DIR}/intersections.jl --time > ${OUTPUT_DIR}/selections.1.time -${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.2 -${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --time > ${OUTPUT_DIR}/selections.2.time -${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 -${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --scale 2 --time > ${OUTPUT_DIR}/selections.2.scaled-2.time +#${INTERSECT_BIN} -m unigram ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.1 +#${INTERSECT_BIN} -m unigram ${OUTPUT_DIR}/intersections.jl --time > 
${OUTPUT_DIR}/selections.1.time +#${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.2 +#${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --time > ${OUTPUT_DIR}/selections.2.time +#${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 +#${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --scale 2 --time > ${OUTPUT_DIR}/selections.2.scaled-2.time #${INTERSECT_BIN} -m exact ${OUTPUT_DIR}/intersections.jl --scale 2 > ${OUTPUT_DIR}/selections.2.exact # Run benchmarks @@ -89,33 +91,39 @@ ${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --scale 2 --time > ${O #${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ # --benchmark --algorithm union-lookup-plus -k ${K} \ # > ${OUTPUT_DIR}/bench.plus.scaled-2 -${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ - --benchmark --algorithm lookup-union -k ${K} \ - > ${OUTPUT_DIR}/bench.lookup-union.scaled-2 +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ +# --benchmark --algorithm lookup-union -k ${K} \ +# > ${OUTPUT_DIR}/bench.lookup-union.scaled-2 +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ +# --benchmark --algorithm lookup-union-eaat -k ${K} \ +# > ${OUTPUT_DIR}/bench.lookup-union-eaat.scaled-2 # Analyze -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore -k ${K} > ${OUTPUT_DIR}/stats.maxscore -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore -k ${K} \ -# > ${OUTPUT_DIR}/stats.maxscore-threshold -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore-union-lookup \ -# > ${OUTPUT_DIR}/stats.maxscore-union-lookup -#${PISA_BIN}/query -i 
"${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --inspect --algorithm unigram-union-lookup -k ${K} \ -# > ${OUTPUT_DIR}/stats.unigram-union-lookup -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup \ -# > ${OUTPUT_DIR}/stats.union-lookup -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup-plus \ -# > ${OUTPUT_DIR}/stats.plus -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm lookup-union \ -# > ${OUTPUT_DIR}/stats.lookup-union -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ -# --inspect --algorithm lookup-union \ -# > ${OUTPUT_DIR}/stats.lookup-union.scaled-1.5 -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 -k ${K} \ -# --inspect --algorithm union-lookup-plus \ -# > ${OUTPUT_DIR}/stats.plus.scaled-2 +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore -k ${K} > ${OUTPUT_DIR}/stats.maxscore +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore -k ${K} \ + > ${OUTPUT_DIR}/stats.maxscore-threshold +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --inspect --algorithm maxscore-union-lookup \ + > ${OUTPUT_DIR}/stats.maxscore-union-lookup +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --inspect --algorithm unigram-union-lookup -k ${K} \ + > ${OUTPUT_DIR}/stats.unigram-union-lookup +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup \ + > ${OUTPUT_DIR}/stats.union-lookup +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm union-lookup-plus \ + > ${OUTPUT_DIR}/stats.plus +${PISA_BIN}/query -i 
"${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --inspect --algorithm lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 \ + --inspect --algorithm lookup-union \ + > ${OUTPUT_DIR}/stats.lookup-union.scaled-1.5 +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 -k ${K} \ + --inspect --algorithm union-lookup-plus \ + > ${OUTPUT_DIR}/stats.plus.scaled-2 ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 -k ${K} \ --inspect --algorithm lookup-union \ > ${OUTPUT_DIR}/stats.lookup-union.scaled-2 +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 -k ${K} \ + --inspect --algorithm lookup-union-eaat \ + > ${OUTPUT_DIR}/stats.lookup-union-eaat.scaled-2 # Evaluate #${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --algorithm maxscore > "${OUTPUT_DIR}/eval.maxscore" @@ -131,3 +139,12 @@ ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2. 
# > "${OUTPUT_DIR}/eval.lookup-union" #${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --algorithm lookup-union \ # > "${OUTPUT_DIR}/eval.lookup-union.scaled-1.5" +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 -k ${K} \ +# --algorithm union-lookup-plus \ +# > ${OUTPUT_DIR}/eval.plus.scaled-2 +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 -k ${K} --safe \ +# --algorithm lookup-union \ +# > ${OUTPUT_DIR}/eval.lookup-union.scaled-2 +#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 -k ${K} --safe \ +# --algorithm lookup-union-eaat \ +# > ${OUTPUT_DIR}/eval.lookup-union-eaat.scaled-2 diff --git a/test/v1/test_union_lookup_join.cpp b/test/v1/test_union_lookup_join.cpp new file mode 100644 index 000000000..9cd5c97b7 --- /dev/null +++ b/test/v1/test_union_lookup_join.cpp @@ -0,0 +1,143 @@ +#include +#include + +template +std::ostream& operator<<(std::ostream& os, std::pair const& p) +{ + os << fmt::format("({}, {})", p.first, p.second); + return os; +} + +#define CATCH_CONFIG_MAIN +#include + +#include +#include + +#include +#include +#include +#include +#include + +#include "index_fixture.hpp" +#include "topk_queue.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/cursor_accumulator.hpp" +#include "v1/index_metadata.hpp" +#include "v1/maxscore.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/union_lookup_join.hpp" + +using pisa::v1::collect_payloads; +using pisa::v1::collect_with_payload; +using pisa::v1::DocId; +using pisa::v1::Frequency; +using pisa::v1::index_runner; +using pisa::v1::join_union_lookup; +using pisa::v1::maxscore_partition; +using pisa::v1::RawCursor; +using pisa::v1::accumulators::Add; + +TEST_CASE("Maxscore partition", "[maxscore][v1][unit]") +{ + rc::check("paritition vector of max scores according to maxscore", []() { + auto max_scores = *rc::gen::nonEmpty>(); + 
ranges::sort(max_scores); + float total_sum = ranges::accumulate(max_scores, 0.0F); + auto threshold = + *rc::gen::suchThat([=](float x) { return x >= 0.0 && x < total_sum; }); + auto [non_essential, essential] = + maxscore_partition(gsl::make_span(max_scores), threshold, [](auto&& x) { return x; }); + auto non_essential_sum = ranges::accumulate(non_essential, 0.0F); + auto first_essential = essential.empty() ? 0.0 : essential[0]; + REQUIRE(non_essential_sum <= Approx(threshold)); + REQUIRE(non_essential_sum + first_essential >= threshold); + }); +} + +struct InspectMock { + std::size_t documents = 0; + std::size_t postings = 0; + std::size_t lookups = 0; + + void document() { documents += 1; } + void posting() { postings += 1; } + void lookup() { lookups += 1; } +}; + +TEST_CASE("UnionLookupJoin v Union", "[union-lookup][v1][unit]") +{ + tbb::task_scheduler_init init; + IndexFixture, RawCursor, RawCursor> fixture; + + auto result_order = [](auto&& lhs, auto&& rhs) { + if (lhs.second == rhs.second) { + return lhs.first > rhs.first; + } + return lhs.second > rhs.second; + }; + + auto add = [](auto score, auto&& cursor, [[maybe_unused]] auto idx) { + return score + cursor.payload(); + }; + + index_runner(fixture.meta(), + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.frequency_reader()))([&](auto&& index) { + auto queries = test_queries(); + for (auto&& [idx, q] : ranges::views::enumerate(queries)) { + CAPTURE(q.get_term_ids()); + CAPTURE(idx); + + auto term_ids = gsl::make_span(q.get_term_ids()); + + auto union_results = collect_with_payload( + v1::union_merge(index.scored_cursors(term_ids, make_bm25(index)), 0.0F, add)); + std::sort(union_results.begin(), union_results.end(), result_order); + std::size_t num_results = std::min(union_results.size(), 10UL); + if (num_results == 0) { + continue; + } + float threshold = std::next(union_results.begin(), num_results - 1)->second; + + auto cursors = index.max_scored_cursors(term_ids, 
make_bm25(index)); + auto [non_essential, essential] = + maxscore_partition(gsl::make_span(cursors), threshold); + CAPTURE(non_essential.size()); + CAPTURE(essential.size()); + + InspectMock inspect; + auto ul_results = collect_with_payload(join_union_lookup( + essential, + non_essential, + 0.0F, + Add{}, + [=](auto score) { return score >= threshold; }, + &inspect)); + std::sort(ul_results.begin(), ul_results.end(), result_order); + REQUIRE(ul_results.size() >= num_results); + union_results.erase(std::next(union_results.begin(), num_results), union_results.end()); + ul_results.erase(std::next(ul_results.begin(), num_results), ul_results.end()); + for (auto pos = 0; pos < num_results; pos++) { + CAPTURE(pos); + REQUIRE(union_results[pos].first == ul_results[pos].first); + REQUIRE(union_results[pos].second == Approx(ul_results[pos].second)); + } + + auto essential_counts = [&] { + auto cursors = index.max_scored_cursors(term_ids, make_bm25(index)); + auto [non_essential, essential] = + maxscore_partition(gsl::make_span(cursors), threshold); + return collect_payloads(v1::union_merge( + essential, + 0, + [](auto count, [[maybe_unused]] auto&& cursor, [[maybe_unused]] auto idx) { + return count + 1; + })); + }(); + REQUIRE(essential_counts.size() == inspect.documents); + REQUIRE(ranges::accumulate(essential_counts, 0) == inspect.postings); + } + }); +} diff --git a/test/v1/test_v1_queries.cpp b/test/v1/test_v1_queries.cpp index b68246485..5f62f5698 100644 --- a/test/v1/test_v1_queries.cpp +++ b/test/v1/test_v1_queries.cpp @@ -31,6 +31,7 @@ #include "v1/index.hpp" #include "v1/index_builder.hpp" #include "v1/maxscore.hpp" +#include "v1/maxscore_union_lookup.hpp" #include "v1/posting_builder.hpp" #include "v1/posting_format_header.hpp" #include "v1/query.hpp" @@ -40,6 +41,7 @@ #include "v1/sequence/positive_sequence.hpp" #include "v1/taat_or.hpp" #include "v1/types.hpp" +#include "v1/unigram_union_lookup.hpp" #include "v1/union_lookup.hpp" #include "v1/wand.hpp" @@ 
-134,10 +136,11 @@ std::unique_ptr> TEMPLATE_TEST_CASE("Query", "[v1][integration]", - (IndexFixture, RawCursor, RawCursor>), - (IndexFixture, - PayloadBlockedCursor<::pisa::simdbp_block>, - RawCursor>), + //(IndexFixture, RawCursor, + // RawCursor>), + //(IndexFixture, + // PayloadBlockedCursor<::pisa::simdbp_block>, + // RawCursor>), (IndexFixture>, PayloadBitSequenceCursor>, RawCursor>)) @@ -148,20 +151,21 @@ TEMPLATE_TEST_CASE("Query", Index, RawCursor>>::get(); TestType fixture; auto input_data = GENERATE(table({ - {"daat_or", false, false}, + //{"daat_or", false, false}, {"maxscore", false, false}, {"maxscore", true, false}, - {"wand", false, false}, - {"wand", true, false}, - {"bmw", false, false}, - {"bmw", true, false}, - {"bmw", false, true}, - {"bmw", true, true}, + //{"wand", false, false}, + //{"wand", true, false}, + //{"bmw", false, false}, + //{"bmw", true, false}, + //{"bmw", false, true}, + //{"bmw", true, true}, {"maxscore_union_lookup", true, false}, {"unigram_union_lookup", true, false}, - {"union_lookup", true, false}, - {"union_lookup_plus", true, false}, + //{"union_lookup", true, false}, + //{"union_lookup_plus", true, false}, {"lookup_union", true, false}, + {"lookup_union_eaat", true, false}, })); std::string algorithm = std::get<0>(input_data); bool with_threshold = std::get<1>(input_data); @@ -210,6 +214,9 @@ TEMPLATE_TEST_CASE("Query", if (name == "lookup_union") { return lookup_union(query, index, ::pisa::topk_queue(10), scorer); } + if (name == "lookup_union_eaat") { + return lookup_union_eaat(query, index, ::pisa::topk_queue(10), scorer); + } std::abort(); }; int idx = 0; @@ -218,7 +225,7 @@ TEMPLATE_TEST_CASE("Query", for (auto& query : test_queries()) { heap.clear(); if (algorithm == "union_lookup" || algorithm == "union_lookup_plus" - || algorithm == "lookup_union") { + || algorithm == "lookup_union" || algorithm == "lookup_union_eaat") { query.selections(gsl::make_span(intersections[idx])); } diff --git 
a/test/v1/test_v1_union_lookup.cpp b/test/v1/test_v1_union_lookup.cpp new file mode 100644 index 000000000..47b85089c --- /dev/null +++ b/test/v1/test_v1_union_lookup.cpp @@ -0,0 +1,155 @@ +#define CATCH_CONFIG_MAIN +#include + +#include +#include + +#include +#include +#include +#include +#include + +#include "index_fixture.hpp" +#include "topk_queue.hpp" +#include "v1/cursor/collect.hpp" +#include "v1/cursor_accumulator.hpp" +#include "v1/index_metadata.hpp" +#include "v1/maxscore.hpp" +#include "v1/raw_cursor.hpp" +#include "v1/scorer/bm25.hpp" +#include "v1/union_lookup.hpp" + +using pisa::v1::BM25; +using pisa::v1::collect_payloads; +using pisa::v1::collect_with_payload; +using pisa::v1::DocId; +using pisa::v1::Frequency; +using pisa::v1::index_runner; +using pisa::v1::InspectLookupUnion; +using pisa::v1::InspectLookupUnionEaat; +using pisa::v1::InspectMaxScore; +using pisa::v1::InspectUnionLookup; +using pisa::v1::InspectUnionLookupPlus; +using pisa::v1::join_union_lookup; +using pisa::v1::lookup_union; +using pisa::v1::lookup_union_eaat; +using pisa::v1::maxscore_partition; +using pisa::v1::RawCursor; + +template +void test_write(T&& result) +{ + std::ostringstream os; + result.write(os); + REQUIRE(fmt::format("{}\t{}\t{}\t{}\t{}", + result.postings(), + result.documents(), + result.lookups(), + result.inserts(), + result.essentials()) + == os.str()); +} + +template +void test_write_partitioned(T&& result) +{ + std::ostringstream os; + result.write(os); + REQUIRE(fmt::format("{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}", + result.sum.postings(), + result.sum.documents(), + result.sum.lookups(), + result.sum.inserts(), + result.sum.essentials(), + result.first.postings(), + result.first.documents(), + result.first.lookups(), + result.first.inserts(), + result.first.essentials(), + result.second.postings(), + result.second.documents(), + result.second.lookups(), + result.second.inserts(), + result.second.essentials()) + == os.str()); +} + 
+TEST_CASE("UnionLookup statistics", "[union-lookup][v1][unit]") +{ + tbb::task_scheduler_init init; + IndexFixture, RawCursor, RawCursor> fixture; + index_runner(fixture.meta(), + std::make_tuple(fixture.document_reader()), + std::make_tuple(fixture.frequency_reader()))([&](auto&& index) { + auto union_lookup_inspect = InspectUnionLookup(index, make_bm25(index)); + auto union_lookup_plus_inspect = InspectUnionLookupPlus(index, make_bm25(index)); + auto lookup_union_inspect = InspectLookupUnion(index, make_bm25(index)); + auto lookup_union_eaat_inspect = InspectLookupUnionEaat(index, make_bm25(index)); + auto queries = test_queries(); + auto const intersections = + pisa::v1::read_intersections(PISA_SOURCE_DIR "/test/test_data/top10_selections"); + for (auto&& [idx, q] : ranges::views::enumerate(queries)) { + // if (idx == 230 || idx == 272 || idx == 380) { + // // Skipping these because of the false positives caused by floating point + // precision. continue; + //} + if (q.get_term_ids().size() > 8) { + continue; + } + auto heap = maxscore(q, index, ::pisa::topk_queue(10), make_bm25(index)); + q.selections(intersections[idx]); + q.threshold(heap.topk().back().first); + CAPTURE(q.get_term_ids()); + CAPTURE(intersections[idx]); + CAPTURE(q.get_threshold()); + CAPTURE(idx); + + auto ul = union_lookup_inspect(q); + auto ulp = union_lookup_plus_inspect(q); + auto lu = lookup_union_inspect(q); + auto lue = lookup_union_eaat_inspect(q); + test_write(ul); + test_write(ulp); + test_write_partitioned(lu); + test_write_partitioned(lue); + + CHECK(ul.documents() == ulp.documents()); + CHECK(ul.postings() == ulp.postings()); + + // +2 because of the false positives caused by floating point + CHECK(ul.lookups() + 2 >= ulp.lookups()); + + CAPTURE(ulp.lookups()); + CAPTURE(ul.lookups()); + CAPTURE(lu.first.lookups()); + CAPTURE(lu.second.lookups()); + + CAPTURE(ul.essentials()); + CAPTURE(lu.first.essentials()); + CAPTURE(lu.second.essentials()); + + CHECK(lu.first.lookups() + 
lu.second.lookups() == lu.sum.lookups()); + CHECK(lue.first.lookups() + lue.second.lookups() == lue.sum.lookups()); + CHECK(ul.postings() == lu.sum.postings()); + CHECK(ul.postings() == lue.sum.postings()); + + // +3 because of the false positives caused by floating point + CHECK(ulp.lookups() <= lu.sum.lookups() + 3); + CHECK(ulp.lookups() <= lue.sum.lookups() + 3); + } + auto ul = union_lookup_inspect.mean(); + auto ulp = union_lookup_plus_inspect.mean(); + auto lu = lookup_union_inspect.mean(); + auto lue = lookup_union_inspect.mean(); + CHECK(ul.documents() == ulp.documents()); + CHECK(ul.postings() == ulp.postings()); + CHECK(ul.lookups() >= ulp.lookups()); + CHECK(ul.postings() == lu.first.postings() + lu.second.postings()); + CHECK(ul.postings() == lue.first.postings() + lue.second.postings()); + CHECK(ulp.lookups() <= lu.first.lookups() + lu.second.lookups()); + CHECK(ulp.lookups() <= lue.first.lookups() + lue.second.lookups()); + CHECK(lu.first.lookups() + lu.second.lookups() == lu.sum.lookups()); + CHECK(lue.first.lookups() + lue.second.lookups() == lue.sum.lookups()); + }); +} diff --git a/v1/query.cpp b/v1/query.cpp index 46e8e407d..90839f49a 100644 --- a/v1/query.cpp +++ b/v1/query.cpp @@ -17,33 +17,36 @@ #include "v1/index_metadata.hpp" #include "v1/inspect_query.hpp" #include "v1/maxscore.hpp" +#include "v1/maxscore_union_lookup.hpp" #include "v1/query.hpp" #include "v1/raw_cursor.hpp" #include "v1/scorer/bm25.hpp" #include "v1/scorer/runner.hpp" #include "v1/types.hpp" +#include "v1/unigram_union_lookup.hpp" #include "v1/union_lookup.hpp" #include "v1/wand.hpp" using pisa::v1::daat_or; -using pisa::v1::DaatOrInspector; using pisa::v1::DocumentBlockedReader; using pisa::v1::index_runner; +using pisa::v1::InspectDaatOr; +using pisa::v1::InspectLookupUnion; +using pisa::v1::InspectLookupUnionEaat; +using pisa::v1::InspectMaxScore; +using pisa::v1::InspectMaxScoreUnionLookup; +using pisa::v1::InspectUnigramUnionLookup; +using 
pisa::v1::InspectUnionLookup; +using pisa::v1::InspectUnionLookupPlus; using pisa::v1::lookup_union; -using pisa::v1::LookupUnionInspect; using pisa::v1::maxscore_union_lookup; -using pisa::v1::MaxscoreInspector; -using pisa::v1::MaxscoreUnionLookupInspect; using pisa::v1::PayloadBlockedReader; using pisa::v1::Query; using pisa::v1::QueryInspector; using pisa::v1::RawReader; using pisa::v1::unigram_union_lookup; -using pisa::v1::UnigramUnionLookupInspect; using pisa::v1::union_lookup; using pisa::v1::union_lookup_plus; -using pisa::v1::UnionLookupInspect; -using pisa::v1::UnionLookupPlusInspect; using pisa::v1::VoidScorer; using pisa::v1::wand; @@ -59,6 +62,9 @@ struct RetrievalAlgorithm { { topk = m_retrieve(query, topk); if (m_safe && not topk.full()) { + spdlog::debug("Retrieved {} out of {} documents. Rerunning without threshold.", + topk.topk().size(), + topk.size()); topk.clear(); topk = m_fallback(query, topk); } @@ -164,7 +170,7 @@ auto resolve_algorithm(std::string const& name, Index const& index, Scorer&& sco return pisa::v1::unigram_union_lookup( query, index, std::move(topk), std::forward(scorer)); } - if (query.get_term_ids().size() >= 8) { + if (query.get_term_ids().size() > 8) { return pisa::v1::maxscore( query, index, std::move(topk), std::forward(scorer)); } @@ -191,6 +197,23 @@ auto resolve_algorithm(std::string const& name, Index const& index, Scorer&& sco fallback, safe); } + if (name == "lookup-union-eaat") { + return RetrievalAlgorithm( + [&](pisa::v1::Query const& query, ::pisa::topk_queue topk) { + if (query.selections()->bigrams.empty()) { + if (query.selections()->unigrams.empty()) { + return pisa::v1::maxscore( + query, index, std::move(topk), std::forward(scorer)); + } + return pisa::v1::unigram_union_lookup( + query, index, std::move(topk), std::forward(scorer)); + } + return pisa::v1::lookup_union_eaat( + query, index, std::move(topk), std::forward(scorer)); + }, + fallback, + safe); + } spdlog::error("Unknown algorithm: {}", name); 
std::exit(1); } @@ -199,27 +222,30 @@ template auto resolve_inspect(std::string const& name, Index const& index, Scorer&& scorer) -> QueryInspector { if (name == "daat_or") { - return QueryInspector(DaatOrInspector(index, std::forward(scorer))); + return QueryInspector(InspectDaatOr(index, std::forward(scorer))); } if (name == "maxscore") { - return QueryInspector(MaxscoreInspector(index, std::forward(scorer))); + return QueryInspector(InspectMaxScore(index, std::forward(scorer))); } if (name == "maxscore-union-lookup") { return QueryInspector( - MaxscoreUnionLookupInspect>(index, scorer)); + InspectMaxScoreUnionLookup>(index, scorer)); } if (name == "unigram-union-lookup") { return QueryInspector( - UnigramUnionLookupInspect>(index, scorer)); + InspectUnigramUnionLookup>(index, scorer)); } if (name == "union-lookup") { - return QueryInspector(UnionLookupInspect>(index, scorer)); + return QueryInspector(InspectUnionLookup>(index, scorer)); } if (name == "lookup-union") { - return QueryInspector(LookupUnionInspect>(index, scorer)); + return QueryInspector(InspectLookupUnion>(index, scorer)); + } + if (name == "lookup-union-eaat") { + return QueryInspector(InspectLookupUnionEaat>(index, scorer)); } if (name == "union-lookup-plus") { - return QueryInspector(UnionLookupPlusInspect>(index, scorer)); + return QueryInspector(InspectUnionLookupPlus>(index, scorer)); } spdlog::error("Unknown algorithm: {}", name); std::exit(1); @@ -278,10 +304,19 @@ void benchmark(std::vector const& queries, RetrievalAlgorithm retrieve) void inspect_queries(std::vector const& queries, QueryInspector inspect) { + inspect.header(std::cout); + std::cout << '\n'; for (auto query = 0; query < queries.size(); query += 1) { - inspect(queries[query]); + inspect(queries[query]).write(std::cout); + std::cout << '\n'; } - std::move(inspect).summarize(); + + std::cerr << "========== Avg ==========\n"; + inspect.header(std::cerr); + std::cerr << '\n'; + inspect.mean().write(std::cerr); + std::cerr << 
'\n'; + std::cerr << "=========================\n"; } int main(int argc, char** argv) From 8480dc970e8490ac5647c4149dc67f1ae075f66f Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 21 Feb 2020 13:14:27 +0000 Subject: [PATCH 55/56] cmake --- script/cw09b-url.sh | 7 ++-- script/cw09b.sh | 79 ++++++++++++++++++++++---------------------- tools/CMakeLists.txt | 28 ++++++++-------- 3 files changed, 58 insertions(+), 56 deletions(-) diff --git a/script/cw09b-url.sh b/script/cw09b-url.sh index 225f47e36..eac47f6e9 100644 --- a/script/cw09b-url.sh +++ b/script/cw09b-url.sh @@ -5,12 +5,13 @@ FWD="/data/CW09B/CW09B.url.fwd" ENCODING="simdbp" # v1 BASENAME="/data/michal/work/v1/cw09b-url/cw09b-${ENCODING}" THREADS=4 -#QUERIES="/home/michal/05.clean.shuf.test" -QUERIES="/home/michal/biscorer/data/queries/topics.web.51-200" +QUERIES="/home/michal/05.clean.shuf.test" +#QUERIES="/home/michal/biscorer/data/queries/topics.web.51-200" K=1000 OUTPUT_DIR="/data/michal/intersect/cw09b-url" FILTERED_QUERIES="${OUTPUT_DIR}/$(basename ${QUERIES}).filtered" PAIRS=${FILTERED_QUERIES} PAIR_INDEX_BASENAME="${BASENAME}-pair" -THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.topics.web.51-200" +#THRESHOLDS="/home/michal/topics.web.51-200.thresholds" +THRESHOLDS="/home/michal/biscorer/data/thresholds/cw09b/thresholds.cw09b.0_01.top20.bm25.05.clean.shuf.test" QUERY_LIMIT=5000 diff --git a/script/cw09b.sh b/script/cw09b.sh index 029d6113a..49a7cbcc5 100644 --- a/script/cw09b.sh +++ b/script/cw09b.sh @@ -37,30 +37,31 @@ mkdir -p ${OUTPUT_DIR} #${PISA_BIN}/bmscore -i "${BASENAME}.yml" -j ${THREADS} --variable-blocks 22.5 --clone ${BASENAME}-var # Filter out queries witout existing terms. 
-#paste -d: ${QUERIES} ${THRESHOLDS} \ -# | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ -# | head -${QUERY_LIMIT} \ -# | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml --min 2 --max 8 > ${FILTERED_QUERIES} +paste -d: ${QUERIES} ${THRESHOLDS} \ + | jq '{"id": split(":")[0], "query": split(":")[1], "threshold": split(":")[2] | tonumber}' -R -c \ + | head -${QUERY_LIMIT} \ + | ${PISA_BIN}/filter-queries -i ${BASENAME}.yml --min 2 --max 8 > ${FILTERED_QUERIES} # This will produce both quantized scores and max scores (both quantized and not). -#${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${PAIRS} --clone ${PAIR_INDEX_BASENAME} -j 4 +${PISA_BIN}/bigram-index -i "${BASENAME}.yml" -q ${PAIRS} --clone ${PAIR_INDEX_BASENAME} -j 4 # Extract intersections -#${PISA_BIN}/intersection -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations --existing \ -# | grep -v "\[warning\]" \ -# > ${OUTPUT_DIR}/intersections.jl +${PISA_BIN}/intersection -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations --existing \ + | grep -v "\[warning\]" \ + > ${OUTPUT_DIR}/intersections.jl #${PISA_BIN}/intersection -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --combinations \ # --in /home/michal/real.aol.top1m.tsv \ # | grep -v "\[warning\]" \ # > ${OUTPUT_DIR}/intersections.jl # Select unigrams -#${INTERSECT_BIN} -m unigram ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.1 -#${INTERSECT_BIN} -m unigram ${OUTPUT_DIR}/intersections.jl --time > ${OUTPUT_DIR}/selections.1.time -#${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.2 -#${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --time > ${OUTPUT_DIR}/selections.2.time -#${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 -#${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --scale 2 --time > ${OUTPUT_DIR}/selections.2.scaled-2.time 
+${INTERSECT_BIN} -m unigram ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.1 +${INTERSECT_BIN} -m unigram ${OUTPUT_DIR}/intersections.jl --time > ${OUTPUT_DIR}/selections.1.time +${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.1.5 +${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl > ${OUTPUT_DIR}/selections.2 +${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --time > ${OUTPUT_DIR}/selections.2.time +${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --scale 2 > ${OUTPUT_DIR}/selections.2.scaled-2 +${INTERSECT_BIN} -m greedy ${OUTPUT_DIR}/intersections.jl --scale 2 --time > ${OUTPUT_DIR}/selections.2.scaled-2.time #${INTERSECT_BIN} -m exact ${OUTPUT_DIR}/intersections.jl --scale 2 > ${OUTPUT_DIR}/selections.2.exact # Run benchmarks @@ -72,31 +73,31 @@ mkdir -p ${OUTPUT_DIR} #${PISA_BIN}/query -i "${BASENAME}-var.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm bmw --safe \ # > ${OUTPUT_DIR}/bench.vbmw-threshold -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm maxscore --safe -k ${K} > ${OUTPUT_DIR}/bench.maxscore -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore --safe -k ${K} \ -# > ${OUTPUT_DIR}/bench.maxscore-threshold -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup --safe \ -# > ${OUTPUT_DIR}/bench.maxscore-union-lookup -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup --safe -k ${K} \ -# > ${OUTPUT_DIR}/bench.unigram-union-lookup -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup --safe \ -# > ${OUTPUT_DIR}/bench.union-lookup -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm 
lookup-union --safe \ -# > ${OUTPUT_DIR}/bench.lookup-union -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup-plus --safe \ -# > ${OUTPUT_DIR}/bench.plus -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --safe \ -# --benchmark --algorithm lookup-union \ -# > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ -# --benchmark --algorithm union-lookup-plus -k ${K} \ -# > ${OUTPUT_DIR}/bench.plus.scaled-2 -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ -# --benchmark --algorithm lookup-union -k ${K} \ -# > ${OUTPUT_DIR}/bench.lookup-union.scaled-2 -#${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ -# --benchmark --algorithm lookup-union-eaat -k ${K} \ -# > ${OUTPUT_DIR}/bench.lookup-union-eaat.scaled-2 +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --benchmark --algorithm maxscore --safe -k ${K} > ${OUTPUT_DIR}/bench.maxscore +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore --safe -k ${K} \ + > ${OUTPUT_DIR}/bench.maxscore-threshold +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${FILTERED_QUERIES} --benchmark --algorithm maxscore-union-lookup --safe \ + > ${OUTPUT_DIR}/bench.maxscore-union-lookup +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.1 --benchmark --algorithm unigram-union-lookup --safe -k ${K} \ + > ${OUTPUT_DIR}/bench.unigram-union-lookup +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup --safe \ + > ${OUTPUT_DIR}/bench.union-lookup +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm 
lookup-union --safe \ + > ${OUTPUT_DIR}/bench.lookup-union +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2 --benchmark --algorithm union-lookup-plus --safe \ + > ${OUTPUT_DIR}/bench.plus +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-1.5 --safe \ + --benchmark --algorithm lookup-union \ + > ${OUTPUT_DIR}/bench.lookup-union.scaled-1.5 +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ + --benchmark --algorithm union-lookup-plus -k ${K} \ + > ${OUTPUT_DIR}/bench.plus.scaled-2 +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ + --benchmark --algorithm lookup-union -k ${K} \ + > ${OUTPUT_DIR}/bench.lookup-union.scaled-2 +${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q ${OUTPUT_DIR}/selections.2.scaled-2 --safe \ + --benchmark --algorithm lookup-union-eaat -k ${K} \ + > ${OUTPUT_DIR}/bench.lookup-union-eaat.scaled-2 # Analyze ${PISA_BIN}/query -i "${PAIR_INDEX_BASENAME}.yml" -q <(jq 'del(.threshold)' ${FILTERED_QUERIES} -c) --inspect --algorithm maxscore -k ${K} > ${OUTPUT_DIR}/stats.maxscore diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 61a21e951..2d8b6aaca 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -63,20 +63,20 @@ #target_link_libraries(evaluate_collection_ordering # pisa # ) -# -#add_executable(parse_collection parse_collection.cpp) -#target_link_libraries(parse_collection -# pisa -# CLI11 -# wapopp -#) -# -#add_executable(invert invert.cpp) -#target_link_libraries(invert -# CLI11 -# pisa -#) -# + +add_executable(parse_collection parse_collection.cpp) +target_link_libraries(parse_collection + pisa + CLI11 + wapopp +) + +add_executable(invert invert.cpp) +target_link_libraries(invert + CLI11 + pisa +) + #add_executable(read_collection read_collection.cpp) #target_link_libraries(read_collection # pisa From c931089c02612ac47f4e89eb8cde4e17222df73b Mon Sep 17 
00:00:00 2001 From: Michal Siedlaczek Date: Fri, 21 Feb 2020 18:32:07 +0000 Subject: [PATCH 56/56] Add counting individual term postings --- v1/count_postings.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/v1/count_postings.cpp b/v1/count_postings.cpp index 6ebf286af..b52462ecb 100644 --- a/v1/count_postings.cpp +++ b/v1/count_postings.cpp @@ -17,9 +17,13 @@ int main(int argc, char** argv) spdlog::set_default_logger(spdlog::stderr_color_mt("")); bool pair_index = false; + bool term_by_term = false; pisa::App app("Simply counts all postings in the index"); - app.add_flag("--pairs", pair_index, "Count postings in the pair index instead"); + auto* pairs_opt = + app.add_flag("--pairs", pair_index, "Count postings in the pair index instead"); + app.add_flag("-t,--terms", term_by_term, "Print posting counts for each term in the index") + ->excludes(pairs_opt); CLI11_PARSE(app, argc, argv); auto meta = app.index_metadata(); @@ -35,6 +39,10 @@ int main(int argc, char** argv) index.bigram_cursor(std::get<0>(term_pair), std::get<1>(term_pair))->size(); status += 1; } + } else if (term_by_term) { + for (pisa::v1::TermId id = 0; id < index.num_terms(); id += 1) { + std::cout << index.term_posting_count(id) << '\n'; + } } else { pisa::v1::ProgressStatus status( index.num_terms(),