Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a normalization term to ctc_beam_search_decoder #21187

Merged
merged 4 commits into from
Aug 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 15 additions & 3 deletions tensorflow/core/util/ctc/ctc_beam_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,16 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
} else {
max_coeff = raw_input.maxCoeff();
}

// Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
float logsumexp = 0.0;
for (int j = 0; j < raw_input.size(); ++j) {
logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
}
logsumexp = Eigen::numext::log(logsumexp);
// Final normalization offset to get correct log probabilities.
float norm_offset = max_coeff + logsumexp;

const float label_selection_input_min =
(label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
: -std::numeric_limits<float>::infinity();
Expand Down Expand Up @@ -290,10 +300,10 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
beam_scorer_->GetStateExpansionScore(b->state, previous));
}
// Plabel(l=abc @ t=6) *= P(c @ 6)
b->newp.label += raw_input(b->label) - max_coeff;
b->newp.label += raw_input(b->label) - norm_offset;
}
// Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff;
b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset;
// P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
b->newp.total = LogSumExp(b->newp.blank, b->newp.label);

Expand Down Expand Up @@ -328,6 +338,8 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
// Perform label selection: if input for this label looks very
// unpromising, never evaluate it with a scorer.
// We may compare logits instead of log probabilities,
// since the difference is the same in both cases.
if (logit < label_selection_input_min) {
continue;
}
Expand All @@ -341,7 +353,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
// Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
c.newp.label = logit - max_coeff +
c.newp.label = logit - norm_offset +
beam_scorer_->GetStateExpansionScore(c.state, previous);
// P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
c.newp.total = c.newp.label;
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/python/keras/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1384,8 +1384,8 @@ def test_ctc_decode(self):
np.array([seq_len_0], dtype=np.int32))
# batch_size length vector of negative log probabilities
log_prob_truth = np.array([
0.584855, # output beam 0
0.389139 # output beam 1
-3.5821197, # output beam 0
-3.777835 # output beam 1
], np.float32)[np.newaxis, :]

decode_truth = [np.array([1, 0]), np.array([0, 1, 0])]
Expand Down
18 changes: 9 additions & 9 deletions tensorflow/python/kernel_tests/ctc_decoder_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,23 +188,23 @@ def testCTCDecoderBeamSearch(self):
],
dtype=np.float32)
# Add arbitrary offset - this is fine
input_log_prob_matrix_0 = np.log(input_prob_matrix_0) + 2.0
input_prob_matrix_0 = input_prob_matrix_0 + 2.0

# len max_time_steps array of batch_size x depth matrices
inputs = ([
input_log_prob_matrix_0[t, :][np.newaxis, :] for t in range(seq_len_0)
input_prob_matrix_0[t, :][np.newaxis, :] for t in range(seq_len_0)
] # Pad to max_time_steps = 8
+ 2 * [np.zeros(
(1, depth), dtype=np.float32)])

# batch_size length vector of sequence_lengths
seq_lens = np.array([seq_len_0], dtype=np.int32)

# batch_size length vector of negative log probabilities
# batch_size length vector of log probabilities
log_prob_truth = np.array(
[
0.584855, # output beam 0
0.389139 # output beam 1
-5.811451, # output beam 0
-6.63339 # output beam 1
],
np.float32)[np.newaxis, :]

Expand All @@ -215,11 +215,11 @@ def testCTCDecoderBeamSearch(self):
[[0, 0], [0, 1]], dtype=np.int64), np.array(
[1, 0], dtype=np.int64), np.array(
[1, 2], dtype=np.int64)),
# beam 1, batch 0, three outputs decoded
# beam 1, batch 0, one output decoded
(np.array(
[[0, 0], [0, 1], [0, 2]], dtype=np.int64), np.array(
[0, 1, 0], dtype=np.int64), np.array(
[1, 3], dtype=np.int64)),
[[0, 0]], dtype=np.int64), np.array(
[1], dtype=np.int64), np.array(
[1, 1], dtype=np.int64)),
]

# Test correct decoding.
Expand Down