Skip to content

Commit

Permalink
Merge pull request #40 from ryanleary/11-oov-weight
Browse files Browse the repository at this point in the history
Make OOV weight configurable. Closes #11
  • Loading branch information
ryanleary committed Nov 7, 2017
2 parents a776b66 + 322e442 commit 1cd9178
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 8 deletions.
8 changes: 8 additions & 0 deletions ctcdecode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def __init__(self, labels, trie_path, blank_index=0, space_index=28):
self._scorer_type = 1
self._scorer = ctc_decode.get_dict_scorer(labels, len(labels), space_index, blank_index, trie_path.encode())

def set_min_unigram_weight(self, weight):
    """Set the score applied to out-of-vocabulary expansions.

    A ``weight`` of ``None`` is a no-op and leaves the native
    scorer's default untouched.
    """
    if weight is None:
        return
    ctc_decode.set_dict_min_unigram_weight(self._scorer, weight)


class KenLMScorer(BaseScorer):
def __init__(self, labels, lm_path, trie_path, blank_index=0, space_index=28):
Expand All @@ -103,6 +107,10 @@ def set_word_weight(self, weight):
if weight is not None:
ctc_decode.set_kenlm_scorer_wc_weight(self._scorer, weight)

def set_min_unigram_weight(self, weight):
    """Set the score applied to out-of-vocabulary expansions.

    A ``weight`` of ``None`` is a no-op and leaves the native
    scorer's default untouched.
    """
    if weight is None:
        return
    ctc_decode.set_kenlm_min_unigram_weight(self._scorer, weight)


class CTCBeamDecoder(BaseCTCBeamDecoder):
def __init__(self, scorer, labels, top_paths=1, beam_width=10, blank_index=0, space_index=28):
Expand Down
12 changes: 12 additions & 0 deletions ctcdecode/src/cpu_binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,18 @@ namespace pytorch {
#endif
}

// Sets the minimum unigram (OOV) log-probability on a KenLM beam scorer.
// No-op when the extension was built without KenLM support.
void set_kenlm_min_unigram_weight(void *scorer, float weight) {
#ifdef INCLUDE_KENLM
    ctc::KenLMBeamScorer *beam_scorer = static_cast<ctc::KenLMBeamScorer *>(scorer);
    beam_scorer->SetMinimumUnigramProbability(weight);
#else
    // Silence -Wunused-parameter when KenLM support is compiled out.
    (void)scorer;
    (void)weight;
#endif
}

void set_dict_min_unigram_weight(void *scorer, float weight) {
ctc::DictBeamScorer *beam_scorer = static_cast<ctc::DictBeamScorer *>(scorer);
beam_scorer->SetMinimumUnigramProbability(weight);
}

void set_label_selection_parameters(void *decoder, int label_selection_size, float label_selection_margin) {
ctc::CTCBeamSearchDecoder<> *beam_decoder = static_cast<ctc::CTCBeamSearchDecoder<> *>(decoder);
beam_decoder->SetLabelSelectionParameters(label_selection_size, label_selection_margin);
Expand Down
3 changes: 3 additions & 0 deletions ctcdecode/src/cpu_binding.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ void free_kenlm_scorer(void* kenlm_scorer);

void set_kenlm_scorer_lm_weight(void *scorer, float weight);
void set_kenlm_scorer_wc_weight(void *scorer, float weight);
void set_kenlm_min_unigram_weight(void *scorer, float weight);
void set_dict_min_unigram_weight(void *scorer, float weight);

// void set_kenlm_scorer_vwc_weight(void *scorer, float weight);
void set_label_selection_parameters(void *decoder, int label_selection_size, float label_selection_margin);
void* get_base_scorer();
Expand Down
14 changes: 10 additions & 4 deletions ctcdecode/src/ctc_beam_scorer_dict.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ namespace ctc_beam_search {
}

DictBeamScorer(Labels *labels, const char *trie_path) :
labels_(labels)
labels_(labels),
default_min_unigram_(kLogZero)
{
std::ifstream in;
in.open(trie_path, std::ios::in);
Expand All @@ -50,11 +51,11 @@ namespace ctc_beam_search {
// check to see if we're on a word boundary
if (labels_->IsSpace(to_label) && from_state.node != trie_root_) {
// check if from_state is valid
to_state->score = StateIsCandidate(from_state, true) ? 0.0 : kLogZero;
to_state->score = StateIsCandidate(from_state, true) ? 0.0 : default_min_unigram_;
to_state->node = trie_root_;
} else {
to_state->node = (from_state.node == nullptr) ? nullptr : from_state.node->GetChildAt(to_label);
to_state->score = StateIsCandidate(*to_state, false) ? 0.0 : kLogZero;
to_state->score = StateIsCandidate(*to_state, false) ? 0.0 : default_min_unigram_;
}
}

Expand All @@ -63,7 +64,7 @@ namespace ctc_beam_search {
// allow a final scoring of the beam in its current state, before resorting
// and retrieving the TopN requested candidates. Called at most once per beam.
// Final scoring hook, called at most once per beam before the TopN
// candidates are selected: beams that do not end on a complete
// dictionary word receive the configured OOV penalty, and the trie
// cursor is reset to the root.
void ExpandStateEnd(DictBeamState* state) const {
    if (StateIsCandidate(*state, true)) {
        state->score = 0.0;
    } else {
        state->score = default_min_unigram_;
    }
    state->node = trie_root_;
}

Expand All @@ -88,9 +89,14 @@ namespace ctc_beam_search {
return state.score;
}

void SetMinimumUnigramProbability(float min_unigram) {
this->default_min_unigram_ = min_unigram;
}

private:
Labels *labels_;
TrieNode *trie_root_;
float default_min_unigram_;

bool StateIsCandidate(const DictBeamState& state, bool word) const;
};
Expand Down
10 changes: 7 additions & 3 deletions ctcdecode/src/ctc_beam_scorer_klm.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ namespace pytorch {
if (labels_->IsSpace(to_label) && from_state.node != trie_root_) {
// check if from_state is valid
to_state->score = StateIsCandidate(from_state, true) ?
ScoreNewWord(from_state.ngram_state, from_state.word_prefix, &to_state->ngram_state)/kLogE : kLogZero;
ScoreNewWord(from_state.ngram_state, from_state.word_prefix, &to_state->ngram_state)/kLogE : default_min_unigram_;
to_state->num_words = from_state.num_words + 1;
to_state->node = trie_root_;
to_state->word_prefix = L"";
} else {
to_state->node = (from_state.node == nullptr) ? nullptr : from_state.node->GetChildAt(to_label);
to_state->score = StateIsCandidate(*to_state, false) ? 0.0 : kLogZero;
to_state->score = StateIsCandidate(*to_state, false) ? 0.0 : default_min_unigram_;
to_state->word_prefix = from_state.word_prefix + to_char;
to_state->ngram_state = from_state.ngram_state;
to_state->num_words = from_state.num_words;
Expand All @@ -98,7 +98,7 @@ namespace pytorch {
model_->NullContextWrite(&a);
while (token != NULL) {
token = strtok(NULL, " ");
state->score = StateIsCandidate(*state, true) ? model_->BaseScore(&a, model_->BaseVocabulary().Index(token), &b)/kLogE : kLogZero;
state->score = StateIsCandidate(*state, true) ? model_->BaseScore(&a, model_->BaseVocabulary().Index(token), &b)/kLogE : default_min_unigram_;
a = b;
}
free(dup);
Expand Down Expand Up @@ -135,6 +135,10 @@ namespace pytorch {
this->word_insertion_weight_ = word_count_weight;
}

void SetMinimumUnigramProbability(float min_unigram) {
this->default_min_unigram_ = min_unigram;
}

private:
Labels *labels_;
TrieNode *trie_root_;
Expand Down
12 changes: 11 additions & 1 deletion tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,17 @@ def test_real_ctc_decode(self):
txt_result = ''.join([labels[x] for x in decode_result[0][0][0:decode_len[0][0]]])
self.assertEqual("the fak friend of the fomcly hae tC", txt_result)

# dictionary-based decoding
# dictionary-based decoding where non-words and words are equiprobable. Equivalent to standard beam decoding
scorer = ctcdecode.DictScorer(labels, "data/ocr.trie", blank_index=labels.index('_'),
space_index=labels.index(' '))
scorer.set_min_unigram_weight(0.0)
decoder = ctcdecode.CTCBeamDecoder(scorer, labels, blank_index=labels.index('_'),
space_index=labels.index(' '), top_paths=1, beam_width=25)
decode_result, scores, decode_len, alignments, char_probs = decoder.decode(th_input, th_seq_len)
txt_result = ''.join([labels[x] for x in decode_result[0][0][0:decode_len[0][0]]])
self.assertEqual("the fak friend of the fomcly hae tC", txt_result)

# dictionary-based decoding - only dictionary words can be emitted
scorer = ctcdecode.DictScorer(labels, "data/ocr.trie", blank_index=labels.index('_'),
space_index=labels.index(' '))
decoder = ctcdecode.CTCBeamDecoder(scorer, labels, blank_index=labels.index('_'),
Expand Down

0 comments on commit 1cd9178

Please sign in to comment.