What arguments should I be using for tf.nn.ctc_loss? #53105
Description
I'm having a lot of trouble converting tf.compat.v1.nn.ctc_loss to tf.nn.ctc_loss (or even to tf.compat.v1.nn.ctc_loss_v2, which seems to be based on the same thing). The documentation at https://www.tensorflow.org/api_docs/python/tf/nn/ctc_loss has no usage examples, which makes the arguments rather unclear. Is there an end-to-end example that I can use? Specifically, I'm trying to convert:
input_tensor=tf.compat.v1.nn.ctc_loss(
    labels=self.gt_texts,
    inputs=self.ctc_in_3d_tbc,
    sequence_length=self.seq_len,
    ctc_merge_repeated=True,
)
The above works fine. My understanding is that it is equivalent to:
input_tensor=tf.nn.ctc_loss(
    labels=self.gt_texts,
    logits=self.ctc_in_3d_tbc,
    label_length=None,
    logit_length=self.seq_len,
    blank_index=-1,
)
but it doesn't seem to be working as well. Can someone confirm that I'm passing the arguments correctly? In particular, is logit_length supposed to be the same as the sequence_length argument I passed to the v1 function?
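
For reference, here is how I currently understand the two arguments I'm least sure about, written as a small pure-Python sketch of the CTC negative log-likelihood for a single sequence. This is not TensorFlow's implementation. The function names (`ctc_neg_log_likelihood`, `log_softmax`, `logsumexp`) are made up for illustration. It assumes that `logit_length` is the number of valid leading time steps, which is the role `sequence_length` played in the v1 call, and that a negative `blank_index` counts back from the number of classes, so `-1` selects the last class as the blank.

```python
import math


def log_softmax(row):
    """Turn one time step of unnormalized scores into log-probabilities."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def logsumexp(*vals):
    """Numerically stable log(sum(exp(v))) that tolerates -inf entries."""
    m = max(vals)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in vals))


def ctc_neg_log_likelihood(logits, label, logit_length, blank_index=-1):
    """CTC loss for one sequence (hypothetical helper, illustration only).

    logits:       list of T rows, each holding num_classes raw scores
    label:        list of class ids, without blanks
    logit_length: number of valid leading time steps (>= 1); later rows
                  are treated as padding and ignored
    blank_index:  class id of the blank; negative values count from
                  num_classes (assumption: -1 means the last class)
    """
    num_classes = len(logits[0])
    blank = blank_index + num_classes if blank_index < 0 else blank_index

    # Extended label sequence: blank, l1, blank, l2, ..., blank
    ext = [blank]
    for c in label:
        ext += [c, blank]
    size = len(ext)

    # Only the first logit_length time steps contribute to the loss.
    log_probs = [log_softmax(row) for row in logits[:logit_length]]

    # Forward variables in log space, initialised at t = 0.
    alpha = [-math.inf] * size
    alpha[0] = log_probs[0][ext[0]]
    if size > 1:
        alpha[1] = log_probs[0][ext[1]]

    for t in range(1, logit_length):
        prev = alpha
        alpha = [-math.inf] * size
        for s in range(size):
            candidates = [prev[s]]
            if s >= 1:
                candidates.append(prev[s - 1])
            # Skipping over a blank is allowed only between different labels.
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                candidates.append(prev[s - 2])
            alpha[s] = logsumexp(*candidates) + log_probs[t][ext[s]]

    # A valid alignment ends on the last label or on the trailing blank.
    total = logsumexp(alpha[-1], alpha[-2]) if size > 1 else alpha[-1]
    return -total


if __name__ == "__main__":
    # Two time steps, three classes (two labels plus the blank as class 2).
    uniform = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
    print(ctc_neg_log_likelihood(uniform, label=[0], logit_length=2))
```

With uniform scores over three classes and two time steps, three of the nine possible alignments collapse to the label `[0]`, so the sketch returns ln 3. Appending extra rows while keeping `logit_length=2` does not change the result, which is the behaviour I would expect if `logit_length` corresponds to the old `sequence_length`.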