Skip to content

Commit

Permalink
Quantized index (#321)
Browse files Browse the repository at this point in the history
  • Loading branch information
amallia committed Mar 5, 2020
1 parent 53ebd4a commit f3afb5d
Show file tree
Hide file tree
Showing 18 changed files with 563 additions and 338 deletions.
105 changes: 53 additions & 52 deletions include/pisa/configuration.hpp
Original file line number Diff line number Diff line change
@@ -1,71 +1,72 @@
#pragma once

#include <cstdlib>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <thread>

#include "boost/lexical_cast.hpp"

namespace pisa {

class configuration {
public:
static configuration const& get() {
static configuration instance;
return instance;
}

double eps1;
double eps2;
double eps3;

double eps1_wand;
double eps2_wand;
class configuration {
public:
static configuration const &get()
{
static configuration instance;
return instance;
}

double fixed_cost_wand_partition;
uint64_t fix_cost;
uint64_t k;
uint64_t block_size;
double eps1;
double eps2;
double eps3;

size_t log_partition_size;
size_t worker_threads;
size_t threshold_wand_list;
size_t reference_size;
double eps1_wand;
double eps2_wand;

double fixed_cost_wand_partition;
uint64_t fix_cost;
uint64_t k;
uint64_t block_size;

size_t log_partition_size;
size_t worker_threads;
size_t threshold_wand_list;
size_t reference_size;
size_t quantization_bits;

bool heuristic_greedy;
bool heuristic_greedy;

private:
configuration()
{
fillvar("PISA_K", k, 10);
fillvar("PISA_BLOCK_SIZE", block_size, 5);
fillvar("PISA_EPS1", eps1, 0.03);
fillvar("PISA_EPS2", eps2, 0.3);
fillvar("PISA_EPS3", eps3, 0.01);
fillvar("PISA_FIXCOST", fix_cost, 64);
fillvar("PISA_LOG_PART", log_partition_size, 7);
fillvar("PISA_THRESHOLD_WAND_LIST", threshold_wand_list, 0);
fillvar("PISA_THREADS", worker_threads, std::thread::hardware_concurrency());
fillvar("PISA_HEURISTIC_GREEDY", heuristic_greedy, false);
fillvar("PISA_FIXED_COST_WAND_PARTITION", fixed_cost_wand_partition, 12.0);
fillvar("PISA_EPS1_WAND", eps1_wand, 0.01);
fillvar("PISA_EPS2_WAND", eps2_wand, 0.4);
fillvar("PISA_SCORE_REFERENCES_SIZE", reference_size, 128);
}
private:
configuration()
{
fillvar("PISA_K", k, 10);
fillvar("PISA_BLOCK_SIZE", block_size, 5);
fillvar("PISA_EPS1", eps1, 0.03);
fillvar("PISA_EPS2", eps2, 0.3);
fillvar("PISA_EPS3", eps3, 0.01);
fillvar("PISA_FIXCOST", fix_cost, 64);
fillvar("PISA_LOG_PART", log_partition_size, 7);
fillvar("PISA_THRESHOLD_WAND_LIST", threshold_wand_list, 0);
fillvar("PISA_THREADS", worker_threads, std::thread::hardware_concurrency());
fillvar("PISA_HEURISTIC_GREEDY", heuristic_greedy, false);
fillvar("PISA_FIXED_COST_WAND_PARTITION", fixed_cost_wand_partition, 12.0);
fillvar("PISA_EPS1_WAND", eps1_wand, 0.01);
fillvar("PISA_EPS2_WAND", eps2_wand, 0.4);
fillvar("PISA_SCORE_REFERENCES_SIZE", reference_size, 128);
fillvar("PISA_QUANTIZTION_BITS", quantization_bits, 8);
}

template <typename T, typename T2>
void fillvar(const char* envvar, T& var, T2 def)
{
const char* val = std::getenv(envvar);
if (!val || !strlen(val)) {
var = def;
} else {
var = boost::lexical_cast<T>(val);
}
template <typename T, typename T2>
void fillvar(const char *envvar, T &var, T2 def)
{
const char *val = std::getenv(envvar);
if (!val || !strlen(val)) {
var = def;
} else {
var = boost::lexical_cast<T>(val);
}
};
}
};

}
} // namespace pisa
29 changes: 29 additions & 0 deletions include/pisa/linear_quantizer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once
#include <cmath>
#include <cstdint>
#include <stdexcept>

#include <gsl/gsl_assert>
#include "spdlog/spdlog.h"

namespace pisa {

/// Linearly rescales floating-point values so that `max` maps to `2^bits`,
/// rounding each result up to the next integer.
struct LinearQuantizer {
    /// @param max  largest value that will later be passed to `operator()`
    /// @param bits number of quantization bits, must be in [1, 32]
    /// @throws std::runtime_error when `bits` is outside [1, 32]
    explicit LinearQuantizer(float max, uint8_t bits)
        : m_max(max), m_scale(compute_scale(max, bits))
    {
    }

    /// Quantizes `value`; the caller must guarantee `value <= max`.
    ///
    /// NOTE(review): `value == max` yields `2^bits`, which needs `bits + 1`
    /// bits and does not fit in `std::uint32_t` when `bits == 32` -- confirm
    /// whether the scale was meant to be `2^bits - 1`.
    [[nodiscard]] auto operator()(float value) const -> std::uint32_t
    {
        Expects(value <= m_max);
        return std::ceil(value * m_scale);
    }

  private:
    /// Validates `bits` and returns `2^bits / max`.
    ///
    /// Validation has to happen before the shift: member initializers run
    /// before the constructor body, so checking there would be too late.
    [[nodiscard]] static auto compute_scale(float max, uint8_t bits) -> float
    {
        if (bits > 32 or bits == 0) {
            throw std::runtime_error(fmt::format(
                "Linear quantizer must take a number of bits between 1 and 32 but {} passed",
                bits));
        }
        // Shift in 64 bits: shifting a 32-bit value by 32 is undefined
        // behaviour, and 32 is an accepted number of bits.
        return static_cast<float>(std::uint64_t{1} << bits) / max;
    }

    float m_max;
    float m_scale;
};

} // namespace pisa
24 changes: 24 additions & 0 deletions include/pisa/scorer/quantized.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <utility>

#include "index_scorer.hpp"
namespace pisa {

/// Scorer whose per-posting score is the stored frequency value itself.
/// Presumably meant for indexes whose frequency slots already hold quantized
/// scores -- confirm against the index builder.
template <typename Wand>
struct quantized : public index_scorer<Wand> {
    using index_scorer<Wand>::index_scorer;

    /// Returns a scoring function that ignores the document ID and passes
    /// `freq` through unchanged. The result does not depend on `term_id`.
    term_scorer_t term_scorer(uint64_t term_id) const
    {
        return [](uint32_t /* doc */, uint32_t freq) { return freq; };
    }
};

} // namespace pisa
45 changes: 25 additions & 20 deletions include/pisa/scorer/scorer.hpp
Original file line number Diff line number Diff line change
@@ -1,30 +1,35 @@
#pragma once

#include <string>
#include <type_traits>

#include "spdlog/spdlog.h"
#include "index_scorer.hpp"
#include "qld.hpp"
#include "pl2.hpp"
#include "bm25.hpp"
#include "dph.hpp"
#include "index_scorer.hpp"
#include "pl2.hpp"
#include "qld.hpp"
#include "quantized.hpp"
#include "spdlog/spdlog.h"

namespace pisa {
namespace scorer {
    /// Creates the scorer registered under `scorer_name` for the given WAND
    /// data.
    ///
    /// Recognised names: "bm25", "qld", "pl2", "dph", "quantized". Any other
    /// name is logged as an error and the process is aborted.
    ///
    /// The scorer is parameterised on the decayed type of `wdata`, i.e.
    /// without the `const &` that `decltype(wdata)` carries.
    ///
    /// Declared `inline` because this header-defined variable would otherwise
    /// be defined once per including translation unit.
    inline auto from_name =
        [](std::string const &scorer_name,
           auto const &wdata) -> std::unique_ptr<index_scorer<std::decay_t<decltype(wdata)>>> {
        using wand_type = std::decay_t<decltype(wdata)>;
        if (scorer_name == "bm25") {
            return std::make_unique<bm25<wand_type>>(wdata);
        }
        if (scorer_name == "qld") {
            return std::make_unique<qld<wand_type>>(wdata);
        }
        if (scorer_name == "pl2") {
            return std::make_unique<pl2<wand_type>>(wdata);
        }
        if (scorer_name == "dph") {
            return std::make_unique<dph<wand_type>>(wdata);
        }
        if (scorer_name == "quantized") {
            return std::make_unique<quantized<wand_type>>(wdata);
        }
        spdlog::error("Unknown scorer {}", scorer_name);
        std::abort();
    };
} // namespace scorer
} // namespace pisa
67 changes: 45 additions & 22 deletions include/pisa/wand_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "util/util.hpp"
#include "wand_data_raw.hpp"

#include "linear_quantizer.hpp"
#include "scorer/scorer.hpp"

class enumerator;
Expand All @@ -31,7 +32,9 @@ class wand_data {
binary_freq_collection const &coll,
std::string const &scorer_name,
BlockSize block_size,
std::unordered_set<size_t> const &terms_to_drop) : m_num_docs(num_docs)
bool is_quantized,
std::unordered_set<size_t> const &terms_to_drop)
: m_num_docs(num_docs)
{
std::vector<uint32_t> doc_lens(num_docs);
std::vector<float> max_term_weight;
Expand All @@ -51,21 +54,22 @@ class wand_data {
typename block_wand_type::builder builder(coll, params);

{
pisa::progress progress("Storing terms statistics", coll.size());
size_t term_id = 0;
for (auto const &seq : coll) {
if(terms_to_drop.find(term_id) != terms_to_drop.end()){
progress.update(1);
term_id += 1;
continue;
}

size_t term_occurrence_count = std::accumulate(seq.freqs.begin(), seq.freqs.end(), 0);
term_occurrence_counts.push_back(term_occurrence_count);
term_posting_counts.push_back(seq.docs.size());
term_id += 1;
progress.update(1);
}
pisa::progress progress("Storing terms statistics", coll.size());
size_t term_id = 0;
for (auto const &seq : coll) {
if (terms_to_drop.find(term_id) != terms_to_drop.end()) {
progress.update(1);
term_id += 1;
continue;
}

size_t term_occurrence_count =
std::accumulate(seq.freqs.begin(), seq.freqs.end(), 0);
term_occurrence_counts.push_back(term_occurrence_count);
term_posting_counts.push_back(seq.docs.size());
term_id += 1;
progress.update(1);
}
}
m_doc_lens.steal(doc_lens);
m_term_occurrence_counts.steal(term_occurrence_counts);
Expand All @@ -77,17 +81,27 @@ class wand_data {
size_t term_id = 0;
size_t new_term_id = 0;
for (auto const &seq : coll) {
if(terms_to_drop.find(term_id) != terms_to_drop.end()){
if (terms_to_drop.find(term_id) != terms_to_drop.end()) {
progress.update(1);
term_id += 1;
continue;
}
auto v = builder.add_sequence(seq, coll, doc_lens, m_avg_len, scorer->term_scorer(new_term_id), block_size);
auto v = builder.add_sequence(
seq, coll, doc_lens, m_avg_len, scorer->term_scorer(new_term_id), block_size);
max_term_weight.push_back(v);
m_index_max_term_weight = std::max(m_index_max_term_weight, v);
term_id += 1;
new_term_id += 1;
progress.update(1);
}
if (is_quantized) {
LinearQuantizer quantizer(m_index_max_term_weight,
configuration::get().quantization_bits);
for (auto &&w : max_term_weight) {
w = quantizer(w);
}
builder.quantize_block_max_term_weitghts(m_index_max_term_weight);
}
}
builder.build(m_block_wand);
m_max_term_weight.steal(max_term_weight);
Expand All @@ -97,10 +111,15 @@ class wand_data {

/// Returns the entry of `m_doc_lens` at index `doc_id`.
size_t doc_len(uint64_t doc_id) const
{
    return m_doc_lens[doc_id];
}

/// Returns the entry of `m_term_occurrence_counts` at index `term_id`
/// (the sum of the term's posting frequencies recorded at build time).
size_t term_occurrence_count(uint64_t term_id) const
{
    return m_term_occurrence_counts[term_id];
}

/// Returns the entry of `m_term_posting_counts` at index `term_id`.
size_t term_posting_count(uint64_t term_id) const
{
    return m_term_posting_counts[term_id];
}

/// Returns `m_index_max_term_weight`, the largest per-term maximum weight
/// seen while the index statistics were built.
float index_max_term_weight() const
{
    return m_index_max_term_weight;
}

/// Returns the number of documents passed to the constructor.
size_t num_docs() const
{
    return m_num_docs;
}

/// Returns the stored `m_avg_len` value.
float avg_len() const
{
    return m_avg_len;
}
Expand All @@ -111,7 +130,7 @@ class wand_data {

/// Returns the block-max enumerator for term `i`, passing the index-wide
/// maximum term weight to `get_enum`.
wand_data_enumerator getenum(size_t i) const
{
    return m_block_wand.get_enum(i, index_max_term_weight());
}

/// Gives read-only access to the underlying block-max structure.
const block_wand_type &get_block_wand() const
{
    return m_block_wand;
}
Expand All @@ -121,14 +140,18 @@ class wand_data {
{
visit(m_block_wand, "m_block_wand")(m_doc_lens, "m_doc_lens")(

m_term_occurrence_counts, "m_term_occurrence_counts")(m_term_posting_counts, "m_term_posting_counts")(m_avg_len, "m_avg_len")(
m_collection_len, "m_collection_len")(m_num_docs, "m_num_docs")(m_max_term_weight,"m_max_term_weight");
m_term_occurrence_counts, "m_term_occurrence_counts")(m_term_posting_counts,
"m_term_posting_counts")(
m_avg_len, "m_avg_len")(m_collection_len, "m_collection_len")(m_num_docs, "m_num_docs")(
m_max_term_weight, "m_max_term_weight")(m_index_max_term_weight,
"m_index_max_term_weight");
}

private:
uint64_t m_num_docs = 0;
float m_avg_len = 0;
uint64_t m_collection_len = 0;
float m_index_max_term_weight = 0;
block_wand_type m_block_wand;
mapper::mappable_vector<uint32_t> m_doc_lens;
mapper::mappable_vector<uint32_t> m_term_occurrence_counts;
Expand Down
Loading

0 comments on commit f3afb5d

Please sign in to comment.