Skip to content

Commit

Permalink
Use id in Model
Browse files Browse the repository at this point in the history
  • Loading branch information
unnonouno committed Apr 12, 2015
1 parent 67ad312 commit 600c06a
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 25 deletions.
47 changes: 26 additions & 21 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,27 +62,14 @@ void make_left_candidates(
}
}

void make_sentence(
const std::vector<Word>& left,
const std::vector<Word>& right,
Sentence& sentence) {
std::vector<Word> ws(left.rbegin(), left.rend());
if (!right.empty()) {
ws.insert(ws.end(), right.begin() + 1, right.end());
}
ws.swap(sentence.words);
}

} // namespace

bool Model::sample_next(
bool go_right,
SearchState& state,
bool ignore_eos) const {
const vector<Word>& h = go_right ? state.right : state.left;
const string& history = h[h.size() - 1].str;
id_t history_id = dictionary.id_of_string(history);
// LOG() << "kaibun::sample_next: history: " << history << " read: " << read;
const vector<WordId>& h = go_right ? state.right : state.left;
id_t history_id = h[h.size() - 1].str;

const unsigned len = state.read.size();
const Ngram& ngram = go_right ? forward : backward;
Expand All @@ -105,12 +92,11 @@ bool Model::sample_next(
state.read.erase(state.read.begin(), state.read.begin() + res_len);
}

std::string rword = dictionary.string_of_id(r);
if (go_right) {
state.right.push_back(Word(rword, r_read));
state.right.push_back(WordId(r, r_read));
} else {
read_t rev(r_read.rbegin(), r_read.rend());
state.left.push_back(Word(rword, rev));
state.left.push_back(WordId(r, rev));
}

return true;
Expand All @@ -135,12 +121,11 @@ bool Model::sample_center(SearchState& state) const {
return false;
}

std::string center = dictionary.string_of_id(center_id);
unsigned id = random_int(cands.size());
state.state = cands[id].first;
state.read = cands[id].second;
state.right.push_back(Word(center, center_read));
state.left.push_back(Word(center, center_read));
state.right.push_back(WordId(center_id, center_read));
state.left.push_back(WordId(center_id, center_read));

return true;
}
Expand Down Expand Up @@ -201,6 +186,26 @@ bool Model::try_make(Sentence& ret) const {
}
}

Word Model::make_word(const WordId& w) const {
  // Resolve the compact id form back to a displayable Word.  Note that
  // WordId::str holds a dictionary id, not a string, despite its name; the
  // reading is carried over unchanged.
  const std::string surface = dictionary.string_of_id(w.str);
  return Word(surface, w.read);
}

void Model::make_sentence(
const std::vector<WordId>& left,
const std::vector<WordId>& right,
Sentence& sentence) const {
std::vector<Word> ws;
for (int i = left.size() - 1; i >= 0; --i) {
ws.push_back(make_word(left[i]));
}
if (!right.empty()) {
for (size_t i = 1; i < right.size(); ++i) {
ws.push_back(make_word(right[i]));
}
}
ws.swap(sentence.words);
}

// Replaces this model's dictionary with `dictionary`.
// NOTE(review): despite the name, this copy-assigns the argument into the
// member — the caller's dictionary is left unchanged, not swapped.  If
// Dictionary is expensive to copy, a true swap may have been intended;
// confirm against Dictionary's API and the call sites.
void Model::swap_dictionary(Dictionary& dictionary) {
  // `Model::` qualification disambiguates the member from the
  // identically-named parameter.
  Model::dictionary = dictionary;
}
Expand Down
14 changes: 10 additions & 4 deletions src/model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@
#include "dictionary.hpp"
#include "fwd.hpp"
#include "ngram.hpp"
#include "sentence.hpp"
#include "unigram.hpp"
#include "util.hpp"

namespace ppg {

class Ngram;
struct Word;
struct Sentence;

enum State {
BALANCE, RIGHT, LEFT
Expand All @@ -36,8 +35,8 @@ class Model {
private:
struct SearchState {
read_t read;
std::vector<Word> right;
std::vector<Word> left;
std::vector<WordId> right;
std::vector<WordId> left;
State state;
};

Expand All @@ -59,6 +58,13 @@ class Model {

bool sample_center(SearchState& state) const;

Word make_word(const WordId& w) const;

void make_sentence(
const std::vector<WordId>& left,
const std::vector<WordId>& right,
Sentence& sentence) const;

Dictionary dictionary;
Ngram forward;
Ngram backward;
Expand Down
8 changes: 8 additions & 0 deletions src/sentence.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ struct Word {
: str(s), read(r) {}
};

// Compact counterpart of Word used during search: the surface string is
// stored as a dictionary id instead of a std::string.  Model::make_word
// converts a WordId back into a Word via the dictionary.
struct WordId {
  id_t str;    // dictionary id of the surface string (named `str` to mirror Word::str)
  read_t read; // reading of the word

  WordId(id_t s, const read_t& r)
    : str(s), read(r) {}
};

// A generated sentence: the ordered list of resolved words.
struct Sentence {
  std::vector<Word> words;
};
Expand Down

0 comments on commit 600c06a

Please sign in to comment.