Skip to content

Commit

Permalink
feat(script_translator): word completion from 2nd place (#848)
Browse files Browse the repository at this point in the history
* prefer exact match phrase on top
* set word "completion" type
  • Loading branch information
lotem committed Mar 14, 2024
1 parent 9184ae6 commit 5c7fb64
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 66 deletions.
15 changes: 15 additions & 0 deletions src/rime/dict/user_dictionary.cc
Expand Up @@ -325,6 +325,21 @@ an<UserDictEntryCollector> UserDictionary::Lookup(
for (auto& v : state.query_result) {
v.second.Sort();
}
auto entries_with_word_completion =
state.query_result.find(state.predict_word_from_depth);
if (entries_with_word_completion != state.query_result.end()) {
auto& entries = entries_with_word_completion->second;
// if the top candidate is a predictive match,
if (!entries.empty() && entries.front()->IsPredictiveMatch()) {
auto found =
std::find_if(entries.begin(), entries.end(),
[](const auto& e) { return e->IsExactMatch(); });
if (found != entries.end()) {
// move the first exact match candidate to top.
std::rotate(entries.begin(), found, found + 1);
}
}
}
return collect(&state.query_result);
}

Expand Down
4 changes: 0 additions & 4 deletions src/rime/dict/vocabulary.cc
Expand Up @@ -57,10 +57,6 @@ string Code::ToString() const {
return stream.str();
}

// Produces the compact representation of this entry, carrying over its
// text, code and weight.
//
// Intentionally not marked `inline`: an inline function has to be defined in
// every translation unit that uses it, and this definition is only visible
// inside this .cc file, so callers in other translation units would be left
// with an unresolved symbol.
ShortDictEntry DictEntry::ToShort() const {
  return {text, code, weight};
}

bool ShortDictEntry::operator<(const ShortDictEntry& other) const {
// Sort different entries sharing the same code by weight desc.
if (weight != other.weight)
Expand Down
8 changes: 7 additions & 1 deletion src/rime/dict/vocabulary.h
Expand Up @@ -55,7 +55,13 @@ struct DictEntry {
int matching_code_size = 0;

DictEntry() = default;
ShortDictEntry ToShort() const;
// Returns a ShortDictEntry assembled from this entry's text, code and weight.
ShortDictEntry ToShort() const {
  return {text, code, weight};
}
// True when the entry matches its code exactly: either matching_code_size is
// zero (the member's default value) or it equals the full length of `code`.
bool IsExactMatch() const {
  if (matching_code_size == 0) {
    return true;
  }
  return matching_code_size == code.size();
}
// True when only a leading part of `code` was matched, i.e. matching_code_size
// is non-zero and strictly smaller than the length of `code`.
bool IsPredictiveMatch() const {
  if (matching_code_size == 0) {
    return false;
  }
  return matching_code_size < code.size();
}
bool operator<(const DictEntry& other) const;
};

Expand Down
3 changes: 2 additions & 1 deletion src/rime/gear/contextual_translation.cc
Expand Up @@ -18,7 +18,8 @@ bool ContextualTranslation::Replenish() {
DLOG(INFO) << cand->text() << " cache/queue: " << cache_.size() << "/"
<< queue.size();
if (cand->type() == "phrase" || cand->type() == "user_phrase" ||
cand->type() == "table" || cand->type() == "user_table") {
cand->type() == "table" || cand->type() == "user_table" ||
cand->type() == "completion") {
if (end_pos != cand->end() || last_type != cand->type()) {
end_pos = cand->end();
last_type = cand->type();
Expand Down
153 changes: 93 additions & 60 deletions src/rime/gear/script_translator.cc
Expand Up @@ -130,7 +130,7 @@ class ScriptTranslation : public Translation {
protected:
bool CheckEmpty();
bool IsNormalSpelling() const;
void PrepareCandidate();
bool PrepareCandidate();
template <class QueryResult>
void EnrollEntries(map<int, DictEntryList>& entries_by_end_pos,
const an<QueryResult>& query_result);
Expand All @@ -147,11 +147,19 @@ class ScriptTranslation : public Translation {
an<Sentence> sentence_;

an<Phrase> candidate_ = nullptr;
size_t candidate_index_ = 0;
// Records which source produced the candidate currently held in candidate_,
// so that Next() knows which iterator (or the sentence) to advance past.
enum CandidateSource {
// No candidate has been prepared yet; also the state set when the
// translation is exhausted.
kUninitialized,
// The candidate was taken from the user dictionary (user_phrase_iter_).
kUserPhrase,
// The candidate was taken from the system dictionary (phrase_iter_).
kSysPhrase,
// The candidate is the composed sentence (sentence_).
kSentence,
};
CandidateSource candidate_source_ = kUninitialized;

DictEntryCollector::reverse_iterator phrase_iter_;
UserDictEntryCollector::reverse_iterator user_phrase_iter_;

size_t max_corrections_ = 4;
const size_t max_corrections_ = 4;
size_t correction_count_ = 0;

bool enable_correction_;
Expand Down Expand Up @@ -342,9 +350,10 @@ string ScriptSyllabifier::GetOriginalSpelling(const Phrase& cand) const {
return string();
}

static bool is_exact_match_phrase(const an<DictEntry>& entry) {
return entry && entry->matching_code_size == 0 ||
entry->matching_code_size == entry->code.size();
// Reports whether the entry collector `ptr`, at reverse position `iter`,
// currently offers an exact-match phrase keyed by `consumed`.
//
// Parameters:
//   ptr      - (smart) pointer to the collector map; may be null.
//   iter     - reverse iterator into *ptr; may equal ptr->rend().
//   consumed - the key (code length) the entry at `iter` must be filed under.
//
// Returns true only if all of the following hold: `ptr` is non-null, `iter`
// refers to a real element, that element's key equals `consumed`, its entry
// iterator is not exhausted, and the entry it peeks at is an exact match.
//
// Both arguments are taken by const reference so that passing a shared
// pointer does not cost a reference-count round trip.
template <class Ptr, class Iter>
static bool has_exact_match_phrase(const Ptr& ptr,
                                   const Iter& iter,
                                   size_t consumed) {
  // An empty collector yields rbegin() == rend(); dereferencing that would be
  // undefined behaviour, so rule it out before touching *iter.
  if (!ptr || iter == ptr->rend()) {
    return false;
  }
  if (iter->first != consumed || iter->second.exhausted()) {
    return false;
  }
  return iter->second.Peek()->IsExactMatch();
}

// ScriptTranslation implementation
Expand Down Expand Up @@ -372,15 +381,10 @@ bool ScriptTranslation::Evaluate(Dictionary* dict, UserDictionary* user_dict) {
user_phrase_iter_ = user_phrase_->rbegin();

// make sentences when there is no exact-matching phrase candidate
bool has_exact_match_phrase =
phrase_ && phrase_iter_->first == consumed &&
is_exact_match_phrase(phrase_iter_->second.Peek());
bool has_exact_match_user_phrase =
user_phrase_ && user_phrase_iter_->first == consumed &&
is_exact_match_phrase(user_phrase_iter_->second.Peek());
bool has_at_least_two_syllables = syllable_graph.edges.size() >= 2;
if (!has_exact_match_phrase && !has_exact_match_user_phrase &&
has_at_least_two_syllables) {
if (has_at_least_two_syllables &&
!has_exact_match_phrase(phrase_, phrase_iter_, consumed) &&
!has_exact_match_phrase(user_phrase_, user_phrase_iter_, consumed)) {
sentence_ = MakeSentence(dict, user_dict);
}

Expand All @@ -393,43 +397,42 @@ bool ScriptTranslation::Next() {
is_correction = false;
if (exhausted())
return false;
if (sentence_) {
sentence_.reset();
return !CheckEmpty();
if (candidate_source_ == kUninitialized) {
PrepareCandidate(); // to determine candidate_source_
}
int user_phrase_code_length = 0;
if (user_phrase_ && user_phrase_iter_ != user_phrase_->rend()) {
user_phrase_code_length = user_phrase_iter_->first;
}
int phrase_code_length = 0;
if (phrase_ && phrase_iter_ != phrase_->rend()) {
phrase_code_length = phrase_iter_->first;
}
if (user_phrase_code_length > 0 &&
user_phrase_code_length >= phrase_code_length) {
UserDictEntryIterator& uter(user_phrase_iter_->second);
if (!uter.Next()) {
++user_phrase_iter_;
}
} else if (phrase_code_length > 0) {
DictEntryIterator& iter(phrase_iter_->second);
if (!iter.Next()) {
++phrase_iter_;
}
switch (candidate_source_) {
case kUninitialized:
break;
case kSentence:
sentence_.reset();
break;
case kUserPhrase: {
UserDictEntryIterator& uter(user_phrase_iter_->second);
if (!uter.Next()) {
++user_phrase_iter_;
}
} break;
case kSysPhrase: {
DictEntryIterator& iter(phrase_iter_->second);
if (!iter.Next()) {
++phrase_iter_;
}
} break;
}
candidate_.reset();
candidate_source_ = kUninitialized;
if (enable_correction_) {
PrepareCandidate();
if (!candidate_) {
// populate the next candidate and skip it if it is a correction beyond the
// maximum allowed number.
if (!PrepareCandidate()) {
break;
}
is_correction = syllabifier_->IsCandidateCorrection(*candidate_);
}
} while ( // limit the number of correction candidates
enable_correction_ && is_correction &&
correction_count_ > max_corrections_);
if (is_correction) {
++correction_count_;
}
} while (enable_correction_ &&
syllabifier_->IsCandidateCorrection(*candidate_) &&
// limit the number of correction candidates
++correction_count_ > max_corrections_);
++candidate_index_;
return !CheckEmpty();
}

Expand All @@ -440,8 +443,7 @@ bool ScriptTranslation::IsNormalSpelling() const {
}

an<Candidate> ScriptTranslation::Peek() {
PrepareCandidate();
if (!candidate_) {
if (candidate_source_ == kUninitialized && !PrepareCandidate()) {
return nullptr;
}
if (candidate_->preedit().empty()) {
Expand All @@ -458,14 +460,29 @@ an<Candidate> ScriptTranslation::Peek() {
return candidate_;
}

void ScriptTranslation::PrepareCandidate() {
// Trivial predicate that unconditionally answers true; serves as the default
// tie-breaker for prefer_user_phrase().
static bool always_true() {
  const bool verdict = true;
  return verdict;
}

// Decides whether the user-dictionary side should win over the system side.
//
// The user side wins outright when its weight is strictly greater. When the
// two weights compare equal, the decision is delegated to `compare_on_tie`,
// which by default always favours the user side. Otherwise the system side
// wins. `compare_on_tie` is invoked only in the tie case.
template <typename T>
inline static bool prefer_user_phrase(
    T user_phrase_weight,
    T sys_phrase_weight,
    function<bool()> compare_on_tie = always_true) {
  if (user_phrase_weight > sys_phrase_weight) {
    return true;
  }
  if (user_phrase_weight == sys_phrase_weight) {
    return compare_on_tie();
  }
  return false;
}

bool ScriptTranslation::PrepareCandidate() {
if (exhausted()) {
candidate_source_ = kUninitialized;
candidate_ = nullptr;
return;
return false;
}
if (sentence_) {
candidate_source_ = kSentence;
candidate_ = sentence_;
return;
return true;
}
size_t user_phrase_code_length = 0;
if (user_phrase_ && user_phrase_iter_ != user_phrase_->rend()) {
Expand All @@ -475,28 +492,44 @@ void ScriptTranslation::PrepareCandidate() {
if (phrase_ && phrase_iter_ != phrase_->rend()) {
phrase_code_length = phrase_iter_->first;
}
an<Phrase> cand;
if (user_phrase_code_length > 0 &&
user_phrase_code_length >= phrase_code_length) {
prefer_user_phrase(user_phrase_code_length, phrase_code_length, [this]() {
const int kNumExactMatchOnTop = 1;
size_t full_code_length = end_of_input_ - start_;
return candidate_index_ >= kNumExactMatchOnTop ||
prefer_user_phrase(
has_exact_match_phrase(user_phrase_, user_phrase_iter_,
full_code_length),
has_exact_match_phrase(phrase_, phrase_iter_,
full_code_length));
})) {
UserDictEntryIterator& uter = user_phrase_iter_->second;
const auto& entry = uter.Peek();
DLOG(INFO) << "user phrase '" << entry->text
<< "', code length: " << user_phrase_code_length;
cand = New<Phrase>(translator_->language(), "user_phrase", start_,
start_ + user_phrase_code_length, entry);
cand->set_quality(std::exp(entry->weight) + translator_->initial_quality() +
(IsNormalSpelling() ? 0.5 : -0.5));
candidate_source_ = kUserPhrase;
candidate_ =
New<Phrase>(translator_->language(),
entry->IsPredictiveMatch() ? "completion" : "user_phrase",
start_, start_ + user_phrase_code_length, entry);
candidate_->set_quality(std::exp(entry->weight) +
translator_->initial_quality() +
(IsNormalSpelling() ? 0.5 : -0.5));
} else if (phrase_code_length > 0) {
DictEntryIterator& iter = phrase_iter_->second;
const auto& entry = iter.Peek();
DLOG(INFO) << "phrase '" << entry->text
<< "', code length: " << phrase_code_length;
cand = New<Phrase>(translator_->language(), "phrase", start_,
start_ + phrase_code_length, entry);
cand->set_quality(std::exp(entry->weight) + translator_->initial_quality() +
(IsNormalSpelling() ? 0 : -1));
candidate_source_ = kSysPhrase;
candidate_ =
New<Phrase>(translator_->language(),
entry->IsPredictiveMatch() ? "completion" : "phrase",
start_, start_ + phrase_code_length, entry);
candidate_->set_quality(std::exp(entry->weight) +
translator_->initial_quality() +
(IsNormalSpelling() ? 0 : -1));
}
candidate_ = cand;
return true;
}

bool ScriptTranslation::CheckEmpty() {
Expand Down

0 comments on commit 5c7fb64

Please sign in to comment.