From da6088a94a70d6c89a42798f4b3f98362025029e Mon Sep 17 00:00:00 2001 From: JIJIN CHEN <380717149@qq.com> Date: Wed, 8 Jun 2022 19:56:12 +0800 Subject: [PATCH] Update decoder_utils.cpp Modified the calculation method of cum_prob. --- ctcdecode/src/decoder_utils.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ctcdecode/src/decoder_utils.cpp b/ctcdecode/src/decoder_utils.cpp index 2a35fe48..d37493ed 100644 --- a/ctcdecode/src/decoder_utils.cpp +++ b/ctcdecode/src/decoder_utils.cpp @@ -23,12 +23,12 @@ std::vector> get_pruned_log_probs( std::sort( prob_idx.begin(), prob_idx.end(), pair_comp_second_rev); if (log_cutoff_prob < 0.0) { - double cum_prob = 0.0; - cutoff_len = 0; - for (size_t i = 0; i < prob_idx.size(); ++i) { + double cum_prob = log_input ? prob_idx[0].second : log(prob_idx[0].second); + cutoff_len = 1; + for (size_t i = 1; i < prob_idx.size(); ++i) { + if (cum_prob >= log_cutoff_prob || cutoff_len >= cutoff_top_n) break; cum_prob = log_sum_exp(cum_prob, log_input ? prob_idx[i].second : log(prob_idx[i].second) ); cutoff_len += 1; - if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) break; } }else{ cutoff_len = cutoff_top_n;