Skip to content

Commit

Permalink
feat(contextual_translation): weight and re-order phrases by context
Browse files Browse the repository at this point in the history
  • Loading branch information
lotem committed Apr 16, 2019
1 parent 6ae34de commit 2390da3
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 7 deletions.
60 changes: 60 additions & 0 deletions src/rime/gear/contextual_translation.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include <algorithm>
#include <iterator>
#include <rime/gear/contextual_translation.h>
#include <rime/gear/translator_commons.h>

namespace rime {

// Upper bound on how many candidates (already cached plus still queued)
// a single Replenish() pass will collect. Declared as size_t because it is
// compared against container sizes; an int here causes a signed/unsigned
// comparison.
constexpr size_t kContextualSearchLimit = 32;

bool ContextualTranslation::Replenish() {
vector<of<Phrase>> queue;
size_t end_pos = 0;
while (!translation_->exhausted() &&
cache_.size() + queue.size() < kContextualSearchLimit) {
auto cand = translation_->Peek();
DLOG(INFO) << cand->text() << " cache/queue: "
<< cache_.size() << "/" << queue.size();
if (cand->type() == "phrase" || cand->type() == "table") {
if (end_pos != cand->end()) {
end_pos = cand->end();
AppendToCache(queue);
}
queue.push_back(Evaluate(As<Phrase>(cand)));
} else {
AppendToCache(queue);
cache_.push_back(cand);
}
if (!translation_->Next()) {
break;
}
}
AppendToCache(queue);
return !cache_.empty();
}

// Re-scores |phrase| against the preceding text.
//
// A scratch Sentence is created for the phrase's language, offset to the
// phrase's start, and extended with the phrase's dictionary entry; the
// resulting sentence weight then replaces the phrase's own weight.
// Returns the same |phrase| object that was passed in.
an<Phrase> ContextualTranslation::Evaluate(an<Phrase> phrase) {
  auto scratch = New<Sentence>(phrase->language());
  scratch->Offset(phrase->start());
  // True when the phrase reaches the very end of the input string.
  const bool reaches_input_end = (input_.length() == phrase->end());
  scratch->Extend(phrase->entry(),
                  phrase->end(),
                  reaches_input_end,
                  preceding_text_,
                  grammar_);
  const auto contextual_weight = scratch->weight();
  phrase->set_weight(contextual_weight);
  DLOG(INFO) << "contextual suggestion: " << phrase->text()
             << " weight: " << phrase->weight();
  return phrase;
}

// Ordering predicate: true when |a| outweighs |b|, so that sorting with it
// places heavier phrases first.
static bool compare_by_weight_desc(const an<Phrase>& a, const an<Phrase>& b) {
  const auto lhs_weight = a->weight();
  const auto rhs_weight = b->weight();
  return rhs_weight < lhs_weight;
}

// Moves the pending phrases in |queue| to the end of |cache_|, heaviest
// first, and leaves |queue| empty. Does nothing for an empty queue.
//
// A stable sort is used so that phrases with equal weight keep the order in
// which the upstream translation produced them; std::sort gives no such
// guarantee and could shuffle ties arbitrarily.
void ContextualTranslation::AppendToCache(vector<of<Phrase>>& queue) {
  if (queue.empty()) return;
  DLOG(INFO) << "appending to cache " << queue.size() << " candidates.";
  std::stable_sort(queue.begin(), queue.end(), compare_by_weight_desc);
  std::copy(queue.begin(), queue.end(), std::back_inserter(cache_));
  queue.clear();
}

} // namespace rime
38 changes: 38 additions & 0 deletions src/rime/gear/contextual_translation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//
// Copyright RIME Developers
// Distributed under the BSD License
//

#include <utility>
#include <rime/common.h>
#include <rime/translation.h>

namespace rime {

class Candidate;
class Grammar;
class Phrase;

// A PrefetchTranslation that re-weights and re-orders the phrase candidates
// of a wrapped translation; the work is done in Replenish(), Evaluate() and
// AppendToCache(), which are defined out of line.
//
// NOTE(review): this header has no include guard, although it is included
// from another header (poet.h); consider adding one.
class ContextualTranslation : public PrefetchTranslation {
 public:
  // |translation|    source translation, forwarded to PrefetchTranslation.
  // |input|          input string; stored by value.
  // |preceding_text| text preceding the input; stored by value.
  // |grammar|        stored as a raw pointer and never deleted here.
  //                  NOTE(review): presumably must outlive this object —
  //                  confirm against callers.
  ContextualTranslation(an<Translation> translation,
                        string input,
                        string preceding_text,
                        Grammar* grammar)
      : PrefetchTranslation(translation),
        // The string parameters are by-value sinks: move them into the
        // members instead of copying a second time.
        input_(std::move(input)),
        preceding_text_(std::move(preceding_text)),
        grammar_(grammar) {}

 protected:
  // Overrides the PrefetchTranslation hook.
  bool Replenish() override;

 private:
  // Takes a phrase and returns a phrase; see the out-of-line definition.
  an<Phrase> Evaluate(an<Phrase> phrase);
  // Takes the pending phrase queue by mutable reference.
  void AppendToCache(vector<of<Phrase>>& queue);

  string input_;
  string preceding_text_;
  Grammar* grammar_;
};

} // namespace rime
15 changes: 15 additions & 0 deletions src/rime/gear/poet.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
#define RIME_POET_H_

#include <rime/common.h>
#include <rime/translation.h>
#include <rime/dict/user_dictionary.h>
#include <rime/gear/translator_commons.h>
#include <rime/gear/contextual_translation.h>

namespace rime {

Expand All @@ -39,6 +41,19 @@ class Poet {
size_t total_length,
const string& preceding_text);

// Wraps |translation| in a ContextualTranslation when |translator| reports
// contextual suggestions as enabled and this object has a grammar.
// Otherwise |translation| is returned as is.
//
// TranslatorT must provide contextual_suggestions() and GetPrecedingText().
template <class TranslatorT>
an<Translation> ContextualWeighted(an<Translation> translation,
                                   const string& input,
                                   TranslatorT* translator) {
  if (translator->contextual_suggestions() && grammar_) {
    return New<ContextualTranslation>(translation,
                                      input,
                                      translator->GetPrecedingText(),
                                      grammar_.get());
  }
  return translation;
}

private:
const Language* language_;
the<Grammar> grammar_;
Expand Down
6 changes: 5 additions & 1 deletion src/rime/gear/script_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,11 @@ an<Translation> ScriptTranslator::Query(const string& input,
enable_user_dict ? user_dict_.get() : NULL)) {
return nullptr;
}
return New<DistinctTranslation>(result);
auto deduped = New<DistinctTranslation>(result);
if (contextual_suggestions_) {
return poet_->ContextualWeighted(deduped, input, this);
}
return deduped;
}

string ScriptTranslator::FormatPreedit(const string& preedit) {
Expand Down
12 changes: 7 additions & 5 deletions src/rime/gear/table_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ TableTranslator::TableTranslator(const Ticket& ticket)
&max_phrase_length_);
config->GetInt(name_space_ + "/max_homographs",
&max_homographs_);
if (enable_sentence_ || sentence_over_completion_) {
if (enable_sentence_ || sentence_over_completion_ ||
contextual_suggestions_) {
poet_.reset(new Poet(language(), config, Poet::LeftAssociateCompare));
}
}
Expand Down Expand Up @@ -306,11 +307,12 @@ an<Translation> TableTranslator::Query(const string& input,
translation = sentence + translation;
}
}
if (translation) {
translation = New<DistinctTranslation>(translation);
}
if (translation && translation->exhausted()) {
translation.reset(); // discard futile translation
return nullptr;
}
translation = New<DistinctTranslation>(translation);
if (contextual_suggestions_) {
return poet_->ContextualWeighted(translation, input, this);
}
return translation;
}
Expand Down
2 changes: 1 addition & 1 deletion src/rime/gear/translator_commons.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ class Phrase : public Candidate {
// Stores |syllabifier| in syllabifier_, replacing any previous value.
void set_syllabifier(an<PhraseSyllabifier> syllabifier) {
syllabifier_ = syllabifier;
}

// Returns the weight stored on the underlying entry (entry_->weight).
double weight() const { return entry_->weight; }
// Overwrites the weight stored on the underlying entry. The value is written
// through entry_, so it changes the entry object itself, not a copy.
void set_weight(double weight) { entry_->weight = weight; }
// Returns a non-const reference to the entry's code; callers can modify it
// even though this accessor is const.
Code& code() const { return entry_->code; }
// Returns a read-only reference to the entry that entry_ points to.
const DictEntry& entry() const { return *entry_; }
// Returns the stored language_ pointer.
const Language* language() const { return language_; }
Expand Down

0 comments on commit 2390da3

Please sign in to comment.