-
-
Notifications
You must be signed in to change notification settings - Fork 64
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
18 changed files
with
563 additions
and
338 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,71 +1,72 @@ | ||
#pragma once | ||
|
||
#include <cstdlib> | ||
#include <cstdint> | ||
#include <cstdlib> | ||
#include <memory> | ||
#include <thread> | ||
|
||
#include "boost/lexical_cast.hpp" | ||
|
||
namespace pisa { | ||
|
||
class configuration { | ||
public: | ||
static configuration const& get() { | ||
static configuration instance; | ||
return instance; | ||
} | ||
|
||
double eps1; | ||
double eps2; | ||
double eps3; | ||
|
||
double eps1_wand; | ||
double eps2_wand; | ||
class configuration { | ||
public: | ||
static configuration const &get() | ||
{ | ||
static configuration instance; | ||
return instance; | ||
} | ||
|
||
double fixed_cost_wand_partition; | ||
uint64_t fix_cost; | ||
uint64_t k; | ||
uint64_t block_size; | ||
double eps1; | ||
double eps2; | ||
double eps3; | ||
|
||
size_t log_partition_size; | ||
size_t worker_threads; | ||
size_t threshold_wand_list; | ||
size_t reference_size; | ||
double eps1_wand; | ||
double eps2_wand; | ||
|
||
double fixed_cost_wand_partition; | ||
uint64_t fix_cost; | ||
uint64_t k; | ||
uint64_t block_size; | ||
|
||
size_t log_partition_size; | ||
size_t worker_threads; | ||
size_t threshold_wand_list; | ||
size_t reference_size; | ||
size_t quantization_bits; | ||
|
||
bool heuristic_greedy; | ||
bool heuristic_greedy; | ||
|
||
private: | ||
configuration() | ||
{ | ||
fillvar("PISA_K", k, 10); | ||
fillvar("PISA_BLOCK_SIZE", block_size, 5); | ||
fillvar("PISA_EPS1", eps1, 0.03); | ||
fillvar("PISA_EPS2", eps2, 0.3); | ||
fillvar("PISA_EPS3", eps3, 0.01); | ||
fillvar("PISA_FIXCOST", fix_cost, 64); | ||
fillvar("PISA_LOG_PART", log_partition_size, 7); | ||
fillvar("PISA_THRESHOLD_WAND_LIST", threshold_wand_list, 0); | ||
fillvar("PISA_THREADS", worker_threads, std::thread::hardware_concurrency()); | ||
fillvar("PISA_HEURISTIC_GREEDY", heuristic_greedy, false); | ||
fillvar("PISA_FIXED_COST_WAND_PARTITION", fixed_cost_wand_partition, 12.0); | ||
fillvar("PISA_EPS1_WAND", eps1_wand, 0.01); | ||
fillvar("PISA_EPS2_WAND", eps2_wand, 0.4); | ||
fillvar("PISA_SCORE_REFERENCES_SIZE", reference_size, 128); | ||
} | ||
private: | ||
configuration() | ||
{ | ||
fillvar("PISA_K", k, 10); | ||
fillvar("PISA_BLOCK_SIZE", block_size, 5); | ||
fillvar("PISA_EPS1", eps1, 0.03); | ||
fillvar("PISA_EPS2", eps2, 0.3); | ||
fillvar("PISA_EPS3", eps3, 0.01); | ||
fillvar("PISA_FIXCOST", fix_cost, 64); | ||
fillvar("PISA_LOG_PART", log_partition_size, 7); | ||
fillvar("PISA_THRESHOLD_WAND_LIST", threshold_wand_list, 0); | ||
fillvar("PISA_THREADS", worker_threads, std::thread::hardware_concurrency()); | ||
fillvar("PISA_HEURISTIC_GREEDY", heuristic_greedy, false); | ||
fillvar("PISA_FIXED_COST_WAND_PARTITION", fixed_cost_wand_partition, 12.0); | ||
fillvar("PISA_EPS1_WAND", eps1_wand, 0.01); | ||
fillvar("PISA_EPS2_WAND", eps2_wand, 0.4); | ||
fillvar("PISA_SCORE_REFERENCES_SIZE", reference_size, 128); | ||
fillvar("PISA_QUANTIZTION_BITS", quantization_bits, 8); | ||
} | ||
|
||
template <typename T, typename T2> | ||
void fillvar(const char* envvar, T& var, T2 def) | ||
{ | ||
const char* val = std::getenv(envvar); | ||
if (!val || !strlen(val)) { | ||
var = def; | ||
} else { | ||
var = boost::lexical_cast<T>(val); | ||
} | ||
template <typename T, typename T2> | ||
void fillvar(const char *envvar, T &var, T2 def) | ||
{ | ||
const char *val = std::getenv(envvar); | ||
if (!val || !strlen(val)) { | ||
var = def; | ||
} else { | ||
var = boost::lexical_cast<T>(val); | ||
} | ||
}; | ||
} | ||
}; | ||
|
||
} | ||
} // namespace pisa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
#pragma once | ||
#include <cmath> | ||
#include <gsl/gsl_assert> | ||
#include "spdlog/spdlog.h" | ||
|
||
namespace pisa { | ||
|
||
struct LinearQuantizer { | ||
explicit LinearQuantizer(float max, uint8_t bits) | ||
: m_max(max), m_scale(static_cast<float>(1u << (bits)) / max) | ||
{ | ||
if (bits > 32 or bits == 0) { | ||
throw std::runtime_error(fmt::format( | ||
"Linear quantizer must take a number of bits between 1 and 32 but {} passed", | ||
bits)); | ||
} | ||
} | ||
[[nodiscard]] auto operator()(float value) const -> std::uint32_t | ||
{ | ||
Expects(value <= m_max); | ||
return std::ceil(value * m_scale); | ||
} | ||
|
||
private: | ||
float m_max; | ||
float m_scale; | ||
}; | ||
|
||
} // namespace pisa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#pragma once | ||
|
||
#include <algorithm> | ||
#include <cmath> | ||
#include <cstdint> | ||
#include <utility> | ||
|
||
#include "index_scorer.hpp" | ||
namespace pisa { | ||
|
||
template <typename Wand> | ||
struct quantized : public index_scorer<Wand> { | ||
|
||
using index_scorer<Wand>::index_scorer; | ||
term_scorer_t term_scorer(uint64_t term_id) const | ||
{ | ||
auto s = [](uint32_t doc, uint32_t freq) { | ||
return freq; | ||
}; | ||
return s; | ||
} | ||
}; | ||
|
||
} // namespace pisa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,35 @@ | ||
#pragma once | ||
|
||
#include <string> | ||
#include <type_traits> | ||
|
||
#include "spdlog/spdlog.h" | ||
#include "index_scorer.hpp" | ||
#include "qld.hpp" | ||
#include "pl2.hpp" | ||
#include "bm25.hpp" | ||
#include "dph.hpp" | ||
#include "index_scorer.hpp" | ||
#include "pl2.hpp" | ||
#include "qld.hpp" | ||
#include "quantized.hpp" | ||
#include "spdlog/spdlog.h" | ||
|
||
namespace pisa { | ||
namespace scorer{ | ||
auto from_name = [](std::string const &scorer_name, auto const &wdata) -> std::unique_ptr<index_scorer<decltype(wdata)>> { | ||
if (scorer_name == "bm25") { | ||
return std::make_unique<bm25<decltype(wdata)>>(wdata); | ||
} else if (scorer_name == "qld") { | ||
return std::make_unique<qld<decltype(wdata)>>(wdata); | ||
} else if (scorer_name == "pl2") { | ||
return std::make_unique<pl2<decltype(wdata)>>(wdata); | ||
} else if (scorer_name == "dph") { | ||
return std::make_unique<dph<decltype(wdata)>>(wdata); | ||
} else { | ||
spdlog::error("Unknown scorer {}", scorer_name); | ||
std::abort(); | ||
|
||
} | ||
}; | ||
namespace scorer { | ||
/// Factory that builds the scorer selected by `scorer_name`.
///
/// Recognized names are "bm25", "qld", "pl2", "dph" and "quantized"; each
/// scorer is constructed from `wdata` and returned through a pointer to the
/// common `index_scorer` base. Any other name is logged as an error and the
/// process is terminated with `std::abort()`.
///
/// Declared `inline` because this variable is defined in a header: without it
/// every translation unit including the header would emit its own definition
/// of `pisa::scorer::from_name`, which clashes at link time as soon as two of
/// them end up in the same program.
inline auto from_name =
    [](std::string const &scorer_name,
       auto const &wdata) -> std::unique_ptr<index_scorer<std::decay_t<decltype(wdata)>>> {
    // `decltype(wdata)` is a const reference; strip it to get the plain type
    // the scorer templates are parameterized on.
    using wand_type = std::decay_t<decltype(wdata)>;

    if (scorer_name == "bm25") {
        return std::make_unique<bm25<wand_type>>(wdata);
    }
    if (scorer_name == "qld") {
        return std::make_unique<qld<wand_type>>(wdata);
    }
    if (scorer_name == "pl2") {
        return std::make_unique<pl2<wand_type>>(wdata);
    }
    if (scorer_name == "dph") {
        return std::make_unique<dph<wand_type>>(wdata);
    }
    if (scorer_name == "quantized") {
        return std::make_unique<quantized<wand_type>>(wdata);
    }

    // No scorer registered under this name: report and terminate.
    spdlog::error("Unknown scorer {}", scorer_name);
    std::abort();
};
} | ||
} // namespace pisa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.