Skip to content

Commit

Permalink
Merge pull request #39324 from pvarouktsis:master
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 373627910
Change-Id: Ic7defb871fec22ff189b2e0b6c39258efd349184
  • Loading branch information
tensorflower-gardener committed May 13, 2021
2 parents 4fe82d0 + 5f30c1d commit 053a191
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 30 deletions.
10 changes: 8 additions & 2 deletions tensorflow/core/kernels/ctc_decoder_ops.cc
Expand Up @@ -187,6 +187,7 @@ class CTCGreedyDecoderOp : public OpKernel {
public:
// Reads the `merge_repeated` and `blank_index` op attributes into the
// matching members; each GetAttr status is checked with OP_REQUIRES_OK.
explicit CTCGreedyDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
// The raw attr value is stored as-is; a negative value is resolved
// relative to num_classes at the point where it is used.
OP_REQUIRES_OK(ctx, ctx->GetAttr("blank_index", &blank_index_));
}

void Compute(OpKernelContext* ctx) override {
Expand Down Expand Up @@ -223,8 +224,12 @@ class CTCGreedyDecoderOp : public OpKernel {

log_prob_t.setZero();

// Assumption: the blank index is num_classes - 1
int blank_index = num_classes - 1;
int blank_index =
(blank_index_ < 0) ? num_classes + blank_index_ : blank_index_;
OP_REQUIRES(ctx, FastBoundsCheck(blank_index, num_classes),
errors::InvalidArgument("blank_index expected to be between ",
-num_classes, " and ", num_classes - 1,
" but was ", blank_index_));

// Perform best path decoding
std::vector<std::vector<std::vector<int> > > sequences(batch_size);
Expand Down Expand Up @@ -263,6 +268,7 @@ class CTCGreedyDecoderOp : public OpKernel {
private:
CTCDecodeHelper decode_helper_;
bool merge_repeated_;
int blank_index_;

TF_DISALLOW_COPY_AND_ASSIGN(CTCGreedyDecoderOp);
};
Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/ops/ctc_ops.cc
Expand Up @@ -103,6 +103,7 @@ REGISTER_OP("CTCGreedyDecoder")
.Input("inputs: T")
.Input("sequence_length: int32")
.Attr("merge_repeated: bool = false")
.Attr("blank_index: int = -1")
.Output("decoded_indices: int64")
.Output("decoded_values: int64")
.Output("decoded_shape: int64")
Expand Down
61 changes: 43 additions & 18 deletions tensorflow/python/kernel_tests/ctc_decoder_ops_test.py
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ctc_ops.ctc_loss_op."""
"""Tests for tensorflow.ctc_ops.ctc_decoder_ops."""

from __future__ import absolute_import
from __future__ import division
Expand Down Expand Up @@ -100,39 +100,39 @@ def testCTCGreedyDecoder(self):
"""Test two batch entries - best path decoder."""
max_time_steps = 6
# depth == 4

seq_len_0 = 4
input_prob_matrix_0 = np.asarray(
[[1.0, 0.0, 0.0, 0.0], # t=0
[0.0, 0.0, 0.4, 0.6], # t=1
[0.0, 0.0, 0.4, 0.6], # t=2
[0.0, 0.9, 0.1, 0.0], # t=3
[0.0, 0.0, 0.0, 0.0], # t=4 (ignored)
[0.0, 0.0, 0.0, 0.0]], # t=5 (ignored)
[
[1.0, 0.0, 0.0, 0.0], # t=0
[0.0, 0.0, 0.4, 0.6], # t=1
[0.0, 0.0, 0.4, 0.6], # t=2
[0.0, 0.9, 0.1, 0.0], # t=3
[0.0, 0.0, 0.0, 0.0], # t=4 (ignored)
[0.0, 0.0, 0.0, 0.0]
], # t=5 (ignored)
dtype=np.float32)
input_log_prob_matrix_0 = np.log(input_prob_matrix_0)

seq_len_1 = 5
# dimensions are time x depth

input_prob_matrix_1 = np.asarray(
[
[0.1, 0.9, 0.0, 0.0], # t=0
[0.0, 0.9, 0.1, 0.0], # t=1
[0.0, 0.0, 0.1, 0.9], # t=2
[0.0, 0.9, 0.1, 0.1], # t=3
[0.9, 0.1, 0.0, 0.0], # t=4
[0.0, 0.0, 0.0, 0.0]
], # t=5 (ignored)
[0.0, 0.0, 0.0, 0.0] # t=5 (ignored)
],
dtype=np.float32)
input_log_prob_matrix_1 = np.log(input_prob_matrix_1)

# len max_time_steps array of batch_size x depth matrices
inputs = [
inputs = np.array([
np.vstack(
[input_log_prob_matrix_0[t, :], input_log_prob_matrix_1[t, :]])
for t in range(max_time_steps)
]
])

# batch_size length vector of sequence_lengths
seq_lens = np.array([seq_len_0, seq_len_1], dtype=np.int32)
Expand All @@ -157,21 +157,46 @@ def testCTCGreedyDecoder(self):
dtype=np.int64),
np.array(
[
0,
1, # batch 0
0, # batch 0, 2 values
1,
1, # batch 1, 3 values
1,
0
], # batch 1
],
dtype=np.int64),
# shape is batch x max_decoded_length
np.array(
[2, 3], dtype=np.int64)),
np.array([2, 3], dtype=np.int64)),
]

# Test without defining blank_index
self._testCTCDecoder(ctc_ops.ctc_greedy_decoder, inputs, seq_lens,
log_prob_truth, decode_truth)

# Shift blank_index to be somewhere in the middle of inputs
blank_index = 2
inputs = np.concatenate(
(inputs[:, :, :blank_index], inputs[:, :, -1:], inputs[:, :,
blank_index:-1]),
axis=2)

# Test positive value in blank_index
self._testCTCDecoder(
ctc_ops.ctc_greedy_decoder,
inputs,
seq_lens,
log_prob_truth,
decode_truth,
blank_index=2)

# Test negative value in blank_index
self._testCTCDecoder(
ctc_ops.ctc_greedy_decoder,
inputs,
seq_lens,
log_prob_truth,
decode_truth,
blank_index=-2)

@test_util.run_deprecated_v1
def testCTCDecoderBeamSearch(self):
"""Test one batch, two beams - hibernating beam search."""
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/python/kernel_tests/ctc_loss_op_test.py
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ctc_ops.ctc_decoder_ops."""
"""Tests for tensorflow.ctc_ops.ctc_loss_op."""

from __future__ import absolute_import
from __future__ import division
Expand Down
44 changes: 39 additions & 5 deletions tensorflow/python/ops/ctc_ops.py
Expand Up @@ -287,12 +287,38 @@ def _CTCLossV2Grad(op, grad_loss, _):

@tf_export("nn.ctc_greedy_decoder")
@dispatch.add_dispatch_support
def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
def ctc_greedy_decoder(inputs,
sequence_length,
merge_repeated=True,
blank_index=None):
"""Performs greedy decoding on the logits given in input (best path).
Note: Regardless of the value of merge_repeated, if the maximum index of a
given time and batch corresponds to the blank index `(num_classes - 1)`, no
new element is emitted.
Given a tensor as `inputs`, the `blank_index` parameter defines the class
index of the blank symbol.
For example:
If `blank_index` is equal to 1:
>>> inf = float("inf")
>>> logits = tf.constant([[[ 0., -inf, -inf],
... [ -2.3, -inf, -0.1]],
... [[ -inf, -0.5, -inf],
... [ -inf, -inf, -0.1]],
... [[ -inf, -inf, -inf],
... [ -0.1, -inf, -2.3]]])
>>> seq_lens = tf.constant([2, 3])
>>> outputs = tf.nn.ctc_greedy_decoder(
... logits,
... seq_lens,
... blank_index=1)
Notes:
- Regardless of the value of `merge_repeated`, if the maximum index of a
given time and batch corresponds to the `blank_index`, no new
element is emitted.
- Default `blank_index` is `(num_classes - 1)`, unless overridden.
If `merge_repeated` is `True`, merge repeated classes in output.
This means that if consecutive logits' maximum indices are the same,
Expand All @@ -308,6 +334,10 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
sequence_length: 1-D `int32` vector containing sequence lengths, having size
`[batch_size]`.
merge_repeated: Boolean. Default: True.
blank_index: (Optional). Default: `num_classes - 1`. Defines the class index
to use for the blank label. Negative values count back from `num_classes`,
i.e., -1 reproduces the previous `ctc_greedy_decoder` behavior of using
`num_classes - 1` for the blank symbol, which corresponds to the default.
Returns:
A tuple `(decoded, neg_sum_logits)` where
Expand All @@ -328,8 +358,12 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
sequence found, the negative of the sum of the greatest logit at each
timeframe.
"""

outputs = gen_ctc_ops.ctc_greedy_decoder(
inputs, sequence_length, merge_repeated=merge_repeated)
inputs,
sequence_length,
merge_repeated=merge_repeated,
blank_index=blank_index)
(decoded_ix, decoded_val, decoded_shape, log_probabilities) = outputs
return ([sparse_tensor.SparseTensor(decoded_ix, decoded_val,
decoded_shape)], log_probabilities)
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt
Expand Up @@ -126,7 +126,7 @@ tf_module {
}
member_method {
name: "ctc_greedy_decoder"
argspec: "args=[\'inputs\', \'sequence_length\', \'merge_repeated\'], varargs=None, keywords=None, defaults=[\'True\'], "
argspec: "args=[\'inputs\', \'sequence_length\', \'merge_repeated\', \'blank_index\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "ctc_loss"
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
Expand Up @@ -690,7 +690,7 @@ tf_module {
}
member_method {
name: "CTCGreedyDecoder"
argspec: "args=[\'inputs\', \'sequence_length\', \'merge_repeated\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
argspec: "args=[\'inputs\', \'sequence_length\', \'merge_repeated\', \'blank_index\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\', \'None\'], "
}
member_method {
name: "CTCLoss"
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/tools/api/golden/v2/tensorflow.nn.pbtxt
Expand Up @@ -106,7 +106,7 @@ tf_module {
}
member_method {
name: "ctc_greedy_decoder"
argspec: "args=[\'inputs\', \'sequence_length\', \'merge_repeated\'], varargs=None, keywords=None, defaults=[\'True\'], "
argspec: "args=[\'inputs\', \'sequence_length\', \'merge_repeated\', \'blank_index\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "ctc_loss"
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
Expand Up @@ -690,7 +690,7 @@ tf_module {
}
member_method {
name: "CTCGreedyDecoder"
argspec: "args=[\'inputs\', \'sequence_length\', \'merge_repeated\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
argspec: "args=[\'inputs\', \'sequence_length\', \'merge_repeated\', \'blank_index\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\', \'None\'], "
}
member_method {
name: "CTCLoss"
Expand Down

0 comments on commit 053a191

Please sign in to comment.