Skip to content

Commit

Permalink
[Fix] Feed fasttext language model with the pre-tokenized words
Browse files Browse the repository at this point in the history
  • Loading branch information
vstakhov committed May 2, 2023
1 parent 9158852 commit bf00268
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 26 deletions.
3 changes: 1 addition & 2 deletions src/libmime/lang_detection.c
Original file line number Diff line number Diff line change
Expand Up @@ -1836,8 +1836,7 @@ rspamd_language_detector_detect (struct rspamd_task *task,
if (rspamd_lang_detection_fasttext_is_enabled(d->fasttext_detector)) {
rspamd_fasttext_predict_result_t fasttext_predict_result =
rspamd_lang_detection_fasttext_detect(d->fasttext_detector,
part->utf_stripped_content->data,
part->utf_stripped_content->len, 4);
part->utf_words, 4);

ndetected = rspamd_lang_detection_fasttext_get_nlangs(fasttext_predict_result);

Expand Down
74 changes: 51 additions & 23 deletions src/libmime/lang_detection_fasttext.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@
#include "libserver/cfg_file.h"
#include "libserver/logger.h"
#include "fmt/core.h"
#include "stat_api.h"
#include <exception>
#include <string>
#include <string_view>
#include <vector>
#include <sstream>
#include <streambuf>
#endif

#ifdef WITH_FASTTEXT
Expand All @@ -37,12 +36,6 @@ class fasttext_langdet {
std::string model_fname;
bool loaded;

/**
 * Stream buffer that exposes an existing character range as its get area,
 * so a std::istream can read from caller-owned memory without copying it.
 * Only the get (read) area is configured; the caller must keep the memory
 * alive for as long as the buffer is in use.
 */
struct one_shot_buf : public std::streambuf {
	explicit one_shot_buf(const char *in, std::size_t sz) {
		/* setg() takes non-const pointers, hence the cast */
		char *const first = const_cast<char *>(in);
		char *const last = first + sz;

		/* get area is [first, last), with the read position at the start */
		setg(first, first, last);
	}
};
public:
explicit fasttext_langdet(struct rspamd_config *cfg) {
const auto *ucl_obj = cfg->rcl_obj;
Expand Down Expand Up @@ -74,27 +67,51 @@ class fasttext_langdet {
~fasttext_langdet() = default;

auto is_enabled() const -> bool { return loaded; }
auto detect_language(const char *in, size_t len, int k) const -> std::vector<std::pair<fasttext::real, std::string>> *
/**
 * Converts one pre-tokenized word into fasttext input ids and appends them
 * to `word_ngramms`.
 *
 * Does nothing when the model is not loaded, or when the dictionary does not
 * classify the token as a word entry.
 *
 * @param in pointer to the token bytes (does not need to be NUL terminated)
 * @param len length of the token in bytes
 * @param word_ngramms output vector; ids are appended, existing content is kept
 */
auto word2vec(const char *in, std::size_t len, std::vector<std::int32_t> &word_ngramms) const {
	if (!loaded) {
		return;
	}

	std::string tok{in, len};
	const auto &dic = ft.getDictionary();
	auto h = dic->hash(tok);
	auto wid = dic->getId(tok, h);
	/* A negative id has no dictionary slot to query, so classify by the token text instead */
	auto type = wid < 0 ? dic->getType(tok) : dic->getType(wid);

	if (type != fasttext::entry_type::word) {
		return;
	}

	if (wid < 0) {
		/* Token has no id: derive subword ids from the BOW/EOW-wrapped token */
		auto pipelined_word = fmt::format("{}{}{}", fasttext::Dictionary::BOW, tok, fasttext::Dictionary::EOW);
		dic->computeSubwords(pipelined_word, word_ngramms);
	}
	else if (ft.getArgs().maxn <= 0) {
		/* maxn <= 0: only the word id itself is used */
		word_ngramms.push_back(wid);
	}
	else {
		/*
		 * Bind by reference: the id-based getSubwords() overload hands out the
		 * dictionary's own vector, so a by-value `auto` would copy it per word.
		 */
		const auto &ngrams = dic->getSubwords(wid);
		word_ngramms.insert(word_ngramms.end(), ngrams.cbegin(), ngrams.cend());
	}
}
auto detect_language(std::vector<std::int32_t> &words, int k)
-> std::vector<std::pair<fasttext::real, std::string>> *
{
if (!loaded) {
return nullptr;
}

/* Hack to deal with streams without copies */
one_shot_buf buf{in, len};
auto stream = std::istream{&buf};
auto predictions = new std::vector<std::pair<fasttext::real, std::string>>;
predictions->reserve(k);
auto res = ft.predictLine(stream, *predictions, k, 0.0f);
fasttext::Predictions line_predictions;
line_predictions.reserve(k);
ft.predict(k, words, line_predictions, 0.0f);
const auto *dict = ft.getDictionary().get();

if (res) {
return predictions;
for (const auto &pred : line_predictions) {
predictions->push_back(std::make_pair(std::exp(pred.first), dict->getLabel(pred.second)));
}
else {
delete predictions;
}

return nullptr;
return predictions;
}

auto model_info(void) const -> std::string {
Expand Down Expand Up @@ -150,15 +167,26 @@ bool rspamd_lang_detection_fasttext_is_enabled(void *ud)
}

rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud,
const char *in, size_t len, int k)
GArray *utf_words,
int k)
{
#ifndef WITH_FASTTEXT
return nullptr;
#else
/* Avoid too long inputs */
static const size_t max_fasttext_input_len = 1024 * 1024 * 1;
static const guint max_fasttext_input_len = 1024 * 1024;
auto *real_model = FASTTEXT_MODEL_TO_C_API(ud);
auto *res = real_model->detect_language(in, std::min(max_fasttext_input_len, len), k);
std::vector<std::int32_t> words_vec;
words_vec.reserve(utf_words->len);

for (auto i = 0; i < std::min(utf_words->len, max_fasttext_input_len); i++) {
const auto *w = &g_array_index (utf_words, rspamd_stat_token_t, i);
if (w->original.len > 0) {
real_model->word2vec(w->original.begin, w->original.len, words_vec);
}
}

auto *res = real_model->detect_language(words_vec, k);

return (rspamd_fasttext_predict_result_t)res;
#endif
Expand Down
2 changes: 1 addition & 1 deletion src/libmime/lang_detection_fasttext.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ typedef void * rspamd_fasttext_predict_result_t;
* @return opaque prediction result handle (NULL if fasttext support is not compiled in or the model is not loaded)
*/
rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud,
const char *in, size_t len, int k);
GArray *utf_words, int k);

/**
* Get number of languages detected
Expand Down

0 comments on commit bf00268

Please sign in to comment.