Skip to content

Commit

Permalink
feat(grammar): compare homophones/homographs in sentence
Browse files Browse the repository at this point in the history
add inteface to grammar plugin; fall back to naive formula if missing "grammar" module
  • Loading branch information
lotem committed Mar 13, 2019
1 parent fcf36bc commit 9248a6b
Show file tree
Hide file tree
Showing 10 changed files with 156 additions and 34 deletions.
2 changes: 2 additions & 0 deletions src/rime/dict/dictionary.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class DictEntryIterator : public DictEntryFilterBinder {
DictEntryIterator() = default;
DictEntryIterator(DictEntryIterator&& other) = default;
DictEntryIterator& operator= (DictEntryIterator&& other) = default;
DictEntryIterator(const DictEntryIterator& other) = default;
DictEntryIterator& operator= (const DictEntryIterator& other) = default;

void AddChunk(dictionary::Chunk&& chunk, Table* table);
void Sort();
Expand Down
31 changes: 31 additions & 0 deletions src/rime/gear/grammar.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#ifndef RIME_GRAMMAR_H_
#define RIME_GRAMMAR_H_

#include <rime/common.h>
#include <rime/component.h>
#include <rime/dict/vocabulary.h>

namespace rime {

class Config;

class Grammar : public Class<Grammar, Config*> {
public:
virtual ~Grammar() {}
virtual double Query(const string& context,
const string& word,
bool is_rear) = 0;

inline static double Evaluate(const string& context,
const DictEntry& entry,
bool is_rear,
Grammar* grammar) {
const double kPenalty = -18.420680743952367; // log(1e-8)
return entry.weight +
(grammar ? grammar->Query(context, entry.text, is_rear) : kPenalty);
}
};

} // namespace rime

#endif // RIME_GRAMMAR_H_
27 changes: 21 additions & 6 deletions src/rime/gear/poet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,30 @@
//
// 2011-10-06 GONG Chen <chen.sst@gmail.com>
//
#include <rime/common.h>
#include <rime/candidate.h>
#include <rime/config.h>
#include <rime/dict/vocabulary.h>
#include <rime/gear/grammar.h>
#include <rime/gear/poet.h>

namespace rime {

inline static Grammar* create_grammar(Config* config) {
if (auto* grammar = Grammar::Require("grammar")) {
return grammar->Create(config);
}
return nullptr;
}

Poet::Poet(const Language* language, Config* config)
: language_(language),
grammar_(create_grammar(config)) {}

Poet::~Poet() {}

an<Sentence> Poet::MakeSentence(const WordGraph& graph,
size_t total_length) {
const int kMaxHomophonesInMind = 1;
size_t total_length) {
// TODO: save more intermediate sentence candidates
map<int, an<Sentence>> sentences;
sentences[0] = New<Sentence>(language_);
// dynamic programming
Expand All @@ -30,15 +44,16 @@ an<Sentence> Poet::MakeSentence(const WordGraph& graph,
continue; // exclude single words from the result
DLOG(INFO) << "end pos: " << end_pos;
const DictEntryList& entries(x.second);
for (size_t i = 0; i < kMaxHomophonesInMind && i < entries.size(); ++i) {
for (size_t i = 0; i < entries.size(); ++i) {
const auto& entry(entries[i]);
auto new_sentence = New<Sentence>(*sentences[start_pos]);
new_sentence->Extend(*entry, end_pos);
bool is_rear = end_pos == total_length;
new_sentence->Extend(*entry, end_pos, is_rear, grammar_.get());
if (sentences.find(end_pos) == sentences.end() ||
sentences[end_pos]->weight() < new_sentence->weight()) {
DLOG(INFO) << "updated sentences " << end_pos << ") with '"
<< new_sentence->text() << "', " << new_sentence->weight();
sentences[end_pos] = new_sentence;
sentences[end_pos] = std::move(new_sentence);
}
}
}
Expand Down
6 changes: 5 additions & 1 deletion src/rime/gear/poet.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,27 @@
#ifndef RIME_POET_H_
#define RIME_POET_H_

#include <rime/common.h>
#include <rime/dict/user_dictionary.h>
#include <rime/gear/translator_commons.h>

namespace rime {

using WordGraph = map<int, UserDictEntryCollector>;

class Grammar;
class Language;

class Poet {
public:
Poet(const Language* language) : language_(language) {}
Poet(const Language* language, Config* config);
~Poet();

an<Sentence> MakeSentence(const WordGraph& graph, size_t total_length);

protected:
const Language* language_;
the<Grammar> grammar_;
};

} // namespace rime
Expand Down
15 changes: 11 additions & 4 deletions src/rime/gear/script_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,11 @@ class ScriptTranslation : public Translation {
public:
ScriptTranslation(ScriptTranslator* translator,
Corrector* corrector,
Poet* poet,
const string& input,
size_t start)
: translator_(translator),
poet_(poet),
start_(start),
syllabifier_(New<ScriptSyllabifier>(
translator, corrector, input, start)),
Expand All @@ -124,6 +126,7 @@ class ScriptTranslation : public Translation {
void PrepareCandidate();

ScriptTranslator* translator_;
Poet* poet_;
size_t start_;
an<ScriptSyllabifier> syllabifier_;

Expand Down Expand Up @@ -156,6 +159,8 @@ ScriptTranslator::ScriptTranslator(const Ticket& ticket)
config->GetBool(name_space_ + "/always_show_comments",
&always_show_comments_);
config->GetBool(name_space_ + "/enable_correction", &enable_correction_);
config->GetInt(name_space_ + "/max_homophones", &max_homophones_);
poet_.reset(new Poet(language(), config));
}
if (enable_correction_) {
if (auto* corrector = Corrector::Require("corrector")) {
Expand All @@ -181,6 +186,7 @@ an<Translation> ScriptTranslator::Query(const string& input,
// the translator should survive translations it creates
auto result = New<ScriptTranslation>(this,
corrector_.get(),
poet_.get(),
input,
segment.start);
if (!result ||
Expand Down Expand Up @@ -523,15 +529,16 @@ an<Sentence> ScriptTranslation::MakeSentence(Dictionary* dict,
// merge lookup results
for (auto& y : *phrase) {
DictEntryList& entries(dest[y.first]);
if (entries.empty()) {
while (entries.size() < translator_->max_homophones() &&
!y.second.exhausted()) {
entries.push_back(y.second.Peek());
if (!y.second.Next())
break;
}
}
}
}
Poet poet(translator_->language());
auto sentence = poet.MakeSentence(graph,
syllable_graph.interpreted_length);
auto sentence = poet_->MakeSentence(graph, syllable_graph.interpreted_length);
if (sentence) {
sentence->Offset(start_);
sentence->set_syllabifier(syllabifier_);
Expand Down
4 changes: 4 additions & 0 deletions src/rime/gear/script_translator.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class Corrector;
struct DictEntry;
struct DictEntryCollector;
class Dictionary;
class Poet;
class UserDictionary;
struct SyllableGraph;

Expand All @@ -38,14 +39,17 @@ class ScriptTranslator : public Translator,
string Spell(const Code& code);

// options
int max_homophones() const { return max_homophones_; }
int spelling_hints() const { return spelling_hints_; }
bool always_show_comments() const { return always_show_comments_; }

protected:
int max_homophones_ = 1;
int spelling_hints_ = 0;
bool always_show_comments_ = false;
bool enable_correction_ = false;
the<Corrector> corrector_;
the<Poet> poet_;
};

} // namespace rime
Expand Down
81 changes: 64 additions & 17 deletions src/rime/gear/table_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <rime/dict/dictionary.h>
#include <rime/dict/user_dictionary.h>
#include <rime/gear/charset_filter.h>
#include <rime/gear/grammar.h>
#include <rime/gear/table_translator.h>
#include <rime/gear/translator_commons.h>
#include <rime/gear/unity_table_encoder.h>
Expand Down Expand Up @@ -216,6 +217,13 @@ TableTranslator::TableTranslator(const Ticket& ticket)
&encode_commit_history_);
config->GetInt(name_space_ + "/max_phrase_length",
&max_phrase_length_);
config->GetInt(name_space_ + "/max_homographs",
&max_homographs_);
if (enable_sentence_ || sentence_over_completion_) {
if (auto* grammar_component = Grammar::Require("grammar")) {
grammar_.reset(grammar_component->Create(config));
}
}
}
if (enable_encoder_ && user_dict_) {
encoder_.reset(new UnityTableEncoder(user_dict_.get()));
Expand All @@ -231,7 +239,7 @@ static bool starts_with_completion(an<Translation> translation) {
}

an<Translation> TableTranslator::Query(const string& input,
const Segment& segment) {
const Segment& segment) {
if (!segment.HasTag(tag_))
return nullptr;
DLOG(INFO) << "input = '" << input
Expand Down Expand Up @@ -519,7 +527,7 @@ bool SentenceTranslation::PreferUserPhrase() const {
return false;
}

static size_t consume_trailing_delimiters(size_t pos,
inline static size_t consume_trailing_delimiters(size_t pos,
const string& input,
const string& delimiters) {
while (pos < input.length() &&
Expand All @@ -529,11 +537,25 @@ static size_t consume_trailing_delimiters(size_t pos,
return pos;
}

template <class Iter>
inline static void collect_entries(DictEntryList& dest,
Iter& iter,
int max_entries) {
if (dest.size() < max_entries && !iter.exhausted()) {
dest.push_back(iter.Peek());
// alters iter if collecting more than 1 entries
while (dest.size() < max_entries && iter.Next()) {
dest.push_back(iter.Peek());
}
}
}

an<Translation>
TableTranslator::MakeSentence(const string& input, size_t start,
bool include_prefix_phrases) {
bool filter_by_charset = enable_charset_filter_ &&
!engine_->context()->get_option("extended_charset");
const int max_entries = max_homographs_;
DictEntryCollector collector;
UserDictEntryCollector user_phrase_collector;
map<int, an<Sentence>> sentences;
Expand All @@ -543,13 +565,14 @@ TableTranslator::MakeSentence(const string& input, size_t start,
continue;
string active_input = input.substr(start_pos);
string active_key = active_input + ' ';
vector<of<DictEntry>> entries(active_input.length() + 1);
UserDictEntryCollector collected_entries;
// lookup dictionaries
if (user_dict_ && user_dict_->loaded()) {
for (size_t len = 1; len <= active_input.length(); ++len) {
size_t consumed_length =
consume_trailing_delimiters(len, active_input, delimiters_);
if (entries[consumed_length])
auto& dest(collected_entries[consumed_length]);
if (dest.size() >= max_entries)
continue;
DLOG(INFO) << "active input: " << active_input << "[0, " << len << ")";
UserDictEntryIterator uter;
Expand All @@ -560,9 +583,15 @@ TableTranslator::MakeSentence(const string& input, size_t start,
uter.AddFilter(CharsetFilter::FilterDictEntry);
}
if (!uter.exhausted()) {
entries[consumed_length] = uter.Peek();
if (start_pos == 0 && max_entries > 1) {
UserDictEntryIterator uter_copy(uter);
collect_entries(dest, uter_copy, max_entries);
} else {
collect_entries(dest, uter, max_entries);
}
if (start_pos == 0) {
// also provide words for manual composition
// uter must not be consumed
uter.Release(&user_phrase_collector[consumed_length]);
DLOG(INFO) << "user phrase[" << consumed_length << "]: "
<< user_phrase_collector[consumed_length].size();
Expand All @@ -578,7 +607,8 @@ TableTranslator::MakeSentence(const string& input, size_t start,
for (size_t len = 1; len <= active_input.length(); ++len) {
size_t consumed_length =
consume_trailing_delimiters(len, active_input, delimiters_);
if (entries[consumed_length])
auto& dest(collected_entries[consumed_length]);
if (!dest.empty())
continue;
DLOG(INFO) << "active input: " << active_input << "[0, " << len << ")";
UserDictEntryIterator uter;
Expand All @@ -589,9 +619,15 @@ TableTranslator::MakeSentence(const string& input, size_t start,
uter.AddFilter(CharsetFilter::FilterDictEntry);
}
if (!uter.exhausted()) {
entries[consumed_length] = uter.Peek();
if (start_pos == 0 && max_entries > 1) {
UserDictEntryIterator uter_copy(uter);
collect_entries(dest, uter_copy, max_entries);
} else {
collect_entries(dest, uter, max_entries);
}
if (start_pos == 0) {
// also provide words for manual composition
// uter must not be consumed
uter.Release(&user_phrase_collector[consumed_length]);
DLOG(INFO) << "unity phrase[" << consumed_length << "]: "
<< user_phrase_collector[consumed_length].size();
Expand All @@ -612,17 +648,24 @@ TableTranslator::MakeSentence(const string& input, size_t start,
continue;
size_t consumed_length =
consume_trailing_delimiters(m.length, active_input, delimiters_);
if (entries[consumed_length])
auto& dest(collected_entries[consumed_length]);
if (dest.size() >= max_entries)
continue;
DictEntryIterator iter;
dict_->LookupWords(&iter, active_input.substr(0, m.length), false);
if (filter_by_charset) {
iter.AddFilter(CharsetFilter::FilterDictEntry);
}
if (!iter.exhausted()) {
entries[consumed_length] = iter.Peek();
if (start_pos == 0 && max_entries - dest.size() > 1) {
DictEntryIterator iter_copy = iter;
collect_entries(dest, iter_copy, max_entries);
} else {
collect_entries(dest, iter, max_entries);
}
if (start_pos == 0) {
// also provide words for manual composition
// iter must not be consumed
collector[consumed_length] = std::move(iter);
DLOG(INFO) << "table[" << consumed_length << "]: "
<< collector[consumed_length].entry_count();
Expand All @@ -631,16 +674,20 @@ TableTranslator::MakeSentence(const string& input, size_t start,
}
}
for (size_t len = 1; len <= active_input.length(); ++len) {
if (!entries[len])
const auto& entries(collected_entries[len]);
if (entries.empty())
continue;
size_t end_pos = start_pos + len;
// create a new sentence
auto new_sentence = New<Sentence>(*sentences[start_pos]);
new_sentence->Extend(*entries[len], end_pos);
// compare and update sentences
if (sentences.find(end_pos) == sentences.end() ||
sentences[end_pos]->weight() <= new_sentence->weight()) {
sentences[end_pos] = std::move(new_sentence);
bool is_rear = end_pos == input.length();
for (const auto& entry : entries) {
// create a new sentence
auto new_sentence = New<Sentence>(*sentences[start_pos]);
new_sentence->Extend(*entry, end_pos, is_rear, grammar_.get());
// compare and update sentences
if (sentences.find(end_pos) == sentences.end() ||
sentences[end_pos]->weight() <= new_sentence->weight()) {
sentences[end_pos] = std::move(new_sentence);
}
}
}
}
Expand Down
Loading

0 comments on commit 9248a6b

Please sign in to comment.