diff --git a/ctcdecode/src/decoder_utils.cpp b/ctcdecode/src/decoder_utils.cpp index 2a35fe48..d37493ed 100644 --- a/ctcdecode/src/decoder_utils.cpp +++ b/ctcdecode/src/decoder_utils.cpp @@ -23,12 +23,12 @@ std::vector> get_pruned_log_probs( std::sort( prob_idx.begin(), prob_idx.end(), pair_comp_second_rev); if (log_cutoff_prob < 0.0) { - double cum_prob = 0.0; - cutoff_len = 0; - for (size_t i = 0; i < prob_idx.size(); ++i) { + double cum_prob = log_input ? prob_idx[0].second : log(prob_idx[0].second); + cutoff_len = 1; + for (size_t i = 1; i < prob_idx.size(); ++i) { + if (cum_prob >= log_cutoff_prob || cutoff_len >= cutoff_top_n) break; cum_prob = log_sum_exp(cum_prob, log_input ? prob_idx[i].second : log(prob_idx[i].second) ); cutoff_len += 1; - if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) break; } }else{ cutoff_len = cutoff_top_n;