What arguments should I be using for tf.nn.ctc_loss? #53105

@thetruejacob

I'm having a lot of trouble converting tf.compat.v1.nn.ctc_loss to tf.nn.ctc_loss (or even to tf.compat.v1.nn.ctc_loss_v2, which seems to be based on the same thing). The lack of example usages on https://www.tensorflow.org/api_docs/python/tf/nn/ctc_loss makes the expected arguments rather unclear. Is there an end-to-end example that I can use? Specifically, I'm trying to convert:

input_tensor=tf.compat.v1.nn.ctc_loss(
    labels=self.gt_texts,
    inputs=self.ctc_in_3d_tbc,
    sequence_length=self.seq_len,
    ctc_merge_repeated=True,
)

The above works fine. My understanding is this is equivalent to

input_tensor=tf.nn.ctc_loss(
    labels=self.gt_texts,
    logits=self.ctc_in_3d_tbc,
    label_length=None,
    logit_length=self.seq_len,
    blank_index=-1,
)

but it doesn't seem to be working as well. Can someone confirm that I'm using the arguments properly (especially logit_length: is it supposed to be the same as the previously specified sequence_length argument)?
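
To make concrete what I think each argument means, here is a minimal pure-Python sketch of the CTC forward algorithm for a single sequence. It is not TensorFlow's implementation. The function name ctc_loss_single and its helpers are hypothetical, and the parameter names are only borrowed from tf.nn.ctc_loss for illustration. It assumes logit_length counts the valid (unpadded) time frames, which is the role sequence_length plays in the v1 call, that label_length counts the valid label entries, and that a negative blank_index is counted from the number of classes, so -1 means the last class.

```python
import math


def log_softmax(row):
    """Normalize one frame of raw scores into log-probabilities."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def logsumexp(*vals):
    """Numerically stable log(sum(exp(v))) that tolerates -inf entries."""
    m = max(vals)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in vals))


def ctc_loss_single(logits, labels, logit_length, label_length, blank_index=-1):
    """Negative log-likelihood of `labels` under CTC for ONE sequence.

    logits: list of frames (time-major), each a list of num_classes scores.
    labels: list of integer class ids, possibly padded past label_length.
    Hypothetical helper for illustration; assumes logit_length >= 1.
    """
    num_classes = len(logits[0])
    if blank_index < 0:
        blank_index += num_classes  # -1 selects the last class

    # Only the first logit_length frames and label_length labels are used.
    frames = [log_softmax(f) for f in logits[:logit_length]]
    target = labels[:label_length]

    # Extended label sequence: blank, l1, blank, l2, ..., blank.
    ext = [blank_index]
    for lab in target:
        ext += [lab, blank_index]
    size = len(ext)

    # alpha[s]: log-probability of all partial paths ending in ext[s].
    alpha = [-math.inf] * size
    alpha[0] = frames[0][ext[0]]
    if size > 1:
        alpha[1] = frames[0][ext[1]]

    for t in range(1, len(frames)):
        new = [-math.inf] * size
        for s in range(size):
            cands = [alpha[s]]
            if s >= 1:
                cands.append(alpha[s - 1])
            # Skipping a blank is allowed only between two different labels.
            if s >= 2 and ext[s] != blank_index and ext[s] != ext[s - 2]:
                cands.append(alpha[s - 2])
            new[s] = logsumexp(*cands) + frames[t][ext[s]]
        alpha = new

    # Valid paths end on the last label or on the trailing blank.
    total = logsumexp(alpha[-1], alpha[-2]) if size > 1 else alpha[-1]
    return -total
```

With two valid frames of uniform scores over three classes and the single label 0, three of the nine equally likely paths collapse to that label, so the sketch returns ln 3 (about 1.0986). The padded third frame and padded second label are ignored because of logit_length=2 and label_length=1.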

Labels

comp:apis, stale, stat:awaiting response, type:others
