/
ctc_ops.py
1467 lines (1204 loc) · 56 KB
/
ctc_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CTC (Connectionist Temporal Classification) Operations."""
import uuid
from tensorflow.python.eager import context
from tensorflow.python.eager import function as function_eager
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_ctc_ops
from tensorflow.python.ops import inplace_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.nn_grad import _BroadcastMul
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
# Attribute keys attached to functions built by _generate_defun_backend().
_DEFUN_API_NAME_ATTRIBUTE = "api_implements"
_DEFUN_DEVICE_ATTRIBUTE = "api_preferred_device"
# Device-type names; presumably compared against _get_context_device_type()
# elsewhere in this module (that usage is not visible here — confirm).
_CPU_DEVICE_NAME = "CPU"
_GPU_DEVICE_NAME = "GPU"
def _get_context_device_type():
  """Returns the device type of the current eager context (e.g. CPU/GPU).

  Returns:
    The `device_type` parsed from the context's device name, or `None` when
    the context has no device name set.
  """
  device_name = context.context().device_name
  if device_name is not None:
    return device.DeviceSpec.from_string(device_name).device_type
  return None
def _generate_defun_backend(unique_api_name, preferred_device, func):
  """Wraps `func` in a defun tagged with API-implementation attributes.

  Args:
    unique_api_name: Value stored under the `api_implements` attribute.
    preferred_device: Value stored under the `api_preferred_device` attribute.
    func: The Python callable to wrap.

  Returns:
    The result of `function_eager.defun_with_attributes` for `func`, built
    with autograph disabled.
  """
  attributes = dict()
  attributes[_DEFUN_API_NAME_ATTRIBUTE] = unique_api_name
  attributes[_DEFUN_DEVICE_ATTRIBUTE] = preferred_device
  return function_eager.defun_with_attributes(
      func=func, attributes=attributes, autograph=False)
# pylint: disable=protected-access, invalid-name
@tf_export(v1=["nn.ctc_loss"])
@dispatch.add_dispatch_support
def ctc_loss(labels,
             inputs=None,
             sequence_length=None,
             preprocess_collapse_repeated=False,
             ctc_merge_repeated=True,
             ignore_longer_outputs_than_inputs=False,
             time_major=True,
             logits=None):
  """Computes the CTC (Connectionist Temporal Classification) Loss.

  This op implements the CTC loss as presented in (Graves et al., 2006).

  Input requirements:

  ```
  sequence_length(b) <= time for all b

  max(labels.indices[labels.indices[:, 0] == b, 1])
    <= sequence_length(b) for all b.
  ```

  Notes:

  This function performs the softmax operation for you, so `inputs` should
  be e.g. linear projections of outputs by an LSTM.

  The innermost dimension of `inputs`, `num_classes`, represents
  `num_labels + 1` classes, where `num_labels` is the number of true labels,
  and the largest value `(num_classes - 1)` is reserved for the blank label.

  For example, for a vocabulary containing 3 labels `[a, b, c]`,
  `num_classes = 4` and the labels indexing is `{a: 0, b: 1, c: 2, blank: 3}`.

  Regarding the arguments `preprocess_collapse_repeated` and
  `ctc_merge_repeated`:

  If `preprocess_collapse_repeated` is True, then a preprocessing step runs
  before loss calculation, wherein repeated labels passed to the loss
  are merged into single labels. This is useful if the training labels come
  from, e.g., forced alignments and therefore have unnecessary repetitions.

  If `ctc_merge_repeated` is set False, then deep within the CTC calculation,
  repeated non-blank labels will not be merged and are interpreted
  as individual labels. This is a simplified (non-standard) version of CTC.

  Here is a table of the (roughly) expected first order behavior:

  * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=True`

    Classical CTC behavior: Outputs true repeated classes with blanks in
    between, and can also output repeated classes with no blanks in
    between that need to be collapsed by the decoder.

  * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=False`

    Never learns to output repeated classes, as they are collapsed
    in the input labels before training.

  * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=False`

    Outputs repeated classes with blanks in between, but generally does not
    require the decoder to collapse/merge repeated classes.

  * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=True`

    Untested. Very likely will not learn to output repeated classes.

  The `ignore_longer_outputs_than_inputs` option specifies the behavior of
  the CTC loss for sequences whose outputs are longer than their inputs. If
  true, the CTC loss simply returns zero gradient for those items. Otherwise,
  an InvalidArgument error is returned, which stops training.

  Args:
    labels: An `int32` `SparseTensor`.
      `labels.indices[i, :] == [b, t]` means `labels.values[i]` stores the id
      for (batch b, time t). `labels.values[i]` must take on values in `[0,
      num_labels)`. See `core/ops/ctc_ops.cc` for more details.
    inputs: 3-D `float` `Tensor`.
      If time_major == False, this will be a `Tensor` shaped: `[batch_size,
      max_time, num_classes]`.
      If time_major == True (default), this will be a `Tensor` shaped:
      `[max_time, batch_size, num_classes]`. The logits.
    sequence_length: 1-D `int32` vector, size `[batch_size]`. The sequence
      lengths.
    preprocess_collapse_repeated: Boolean. Default: False. If True, repeated
      labels are collapsed prior to the CTC calculation.
    ctc_merge_repeated: Boolean. Default: True.
    ignore_longer_outputs_than_inputs: Boolean. Default: False. If True,
      sequences with longer outputs than inputs will be ignored.
    time_major: The shape format of the `inputs` Tensors. If True, these
      `Tensors` must be shaped `[max_time, batch_size, num_classes]`. If False,
      these `Tensors` must be shaped `[batch_size, max_time, num_classes]`.
      Using `time_major = True` (default) is a bit more efficient because it
      avoids transposes at the beginning of the ctc_loss calculation. However,
      most TensorFlow data is batch-major, so this function also accepts
      inputs in batch-major form.
    logits: Alias for inputs.

  Returns:
    A 1-D `float` `Tensor`, size `[batch]`, containing the negative log
      probabilities.

  Raises:
    TypeError: if labels is not a `SparseTensor`.

  References:
      Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
      with Recurrent Neural Networks:
        [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
        ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
  """
  # The public v1 API always uses the standard (non-cuDNN) kernel, whose
  # blank label is the last class.
  return _ctc_loss_impl(
      labels=labels,
      inputs=inputs,
      sequence_length=sequence_length,
      preprocess_collapse_repeated=preprocess_collapse_repeated,
      ctc_merge_repeated=ctc_merge_repeated,
      ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs,
      time_major=time_major,
      logits=logits,
      use_cudnn=False)
def _ctc_loss_impl(labels,
                   inputs=None,
                   sequence_length=None,
                   preprocess_collapse_repeated=False,
                   ctc_merge_repeated=True,
                   ignore_longer_outputs_than_inputs=False,
                   time_major=True,
                   logits=None,
                   use_cudnn=False):
  """Shared implementation behind `ctc_loss` and the blank-index helpers.

  Takes the same arguments as `ctc_loss`, plus:
    use_cudnn: If True, dispatch to `gen_ctc_ops.ctc_loss_v2`, which treats
      class 0 as the blank label. Otherwise use `gen_ctc_ops.ctc_loss`, which
      treats the last class as the blank label.

  Returns:
    The loss (first output of the kernel). The kernel's second output, which
    holds the gradient, is discarded here. The registered gradient functions
    read it back from `op.outputs[1]`.

  Raises:
    TypeError: if `labels` is not a `SparseTensor`.
  """
  if not isinstance(labels, sparse_tensor.SparseTensor):
    raise TypeError("Expected argument `labels` to be a SparseTensor. "
                    f"Received labels={labels} of type: "
                    f"{type(labels).__name__}")

  inputs = deprecation.deprecated_argument_lookup("logits", logits, "inputs",
                                                  inputs)
  # The kernels work on time-major [time, batch, num_classes] input.
  if not time_major:
    inputs = array_ops.transpose(inputs, [1, 0, 2])  # (B,T,N) => (T,B,N)

  kernel = gen_ctc_ops.ctc_loss_v2 if use_cudnn else gen_ctc_ops.ctc_loss
  per_batch_loss, _ = kernel(
      inputs,
      labels.indices,
      labels.values,
      sequence_length,
      preprocess_collapse_repeated=preprocess_collapse_repeated,
      ctc_merge_repeated=ctc_merge_repeated,
      ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs)
  return per_batch_loss
# pylint: disable=unused-argument
def _CTCLossGradImpl(op, grad_loss, _):
  """Shared gradient for the CTCLoss and CTCLossV2 ops.

  The fused kernel already computes the gradient with respect to its
  `inputs` and returns it as its second output (`op.outputs[1]`). That tensor
  is combined with the incoming `grad_loss` via `_BroadcastMul`.

  Args:
    op: The CTCLoss / CTCLossV2 op being differentiated.
    grad_loss: Backprop signal for the loss output.
    _: Backprop signal for the op's gradient output; unused.

  Returns:
    A list with one entry per op input: the gradient for `inputs`, then `None`
    for labels_indices, labels_values and sequence_length.
  """
  # Currently there is no way to take the second derivative of this op because
  # of the fused implementation's interaction with tf.gradients(). Wrapping the
  # precomputed gradient in prevent_gradient makes such a request raise an
  # error instead of silently producing incorrect results.
  grad_without_gradient = array_ops.prevent_gradient(
      op.outputs[1],
      message="Currently there is no way to take the second "
      "derivative of ctc_loss due to the fused implementation's interaction "
      "with tf.gradients()")
  return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None]
# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLoss")
def _CTCLossGrad(op, grad_loss, _):
  """Gradient function registered for the "CTCLoss" op.

  Args:
    op: The CTCLoss op being differentiated.
    grad_loss: Backprop signal for the loss output.
    _: Backprop signal for the op's gradient output; ignored.

  Returns:
    Gradients for the op's inputs, as produced by `_CTCLossGradImpl`.
  """
  gradients = _CTCLossGradImpl(op, grad_loss, _)
  return gradients
# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLossV2")
def _CTCLossV2Grad(op, grad_loss, _):
  """Gradient function registered for the "CTCLossV2" op.

  Args:
    op: The CTCLossV2 op being differentiated.
    grad_loss: Backprop signal for the loss output.
    _: Backprop signal for the op's gradient output; ignored.

  Returns:
    Gradients for the op's inputs, as produced by `_CTCLossGradImpl`.
  """
  gradients = _CTCLossGradImpl(op, grad_loss, _)
  return gradients
@tf_export("nn.ctc_greedy_decoder")
@dispatch.add_dispatch_support
def ctc_greedy_decoder(inputs,
                       sequence_length,
                       merge_repeated=True,
                       blank_index=None):
  """Performs greedy decoding on the logits given in input (best path).

  Given a tensor as `inputs`, the `blank_index` parameter defines the class
  index of the blank symbol.

  For example:

  If `blank_index` is equal to 1:

  >>> inf = float("inf")
  >>> logits = tf.constant([[[   0., -inf, -inf],
  ...                        [ -2.3, -inf, -0.1]],
  ...                       [[ -inf, -0.5, -inf],
  ...                        [ -inf, -inf, -0.1]],
  ...                       [[ -inf, -inf, -inf],
  ...                        [ -0.1, -inf, -2.3]]])
  >>> seq_lens = tf.constant([2, 3])
  >>> outputs = tf.nn.ctc_greedy_decoder(
  ...     logits,
  ...     seq_lens,
  ...     blank_index=1)

  Notes:

  - Regardless of the value of `merge_repeated`, if an index of a
    given time and batch corresponds to the `blank_index`, no new
    element is emitted.
  - Default `blank_index` is `(num_classes - 1)`, unless overridden.

  If `merge_repeated` is `True`, merge repeated classes in output.
  This means that if consecutive logits' maximum indices are the same,
  only the first of these is emitted.  The sequence `A B B * B * B` (where '*'
  is the blank label) becomes

    * `A B B B` if `merge_repeated=True`.
    * `A B B B B` if `merge_repeated=False`.

  Args:
    inputs: 3-D `float` `Tensor` sized `[max_time, batch_size, num_classes]`.
      The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths, having size
      `[batch_size]`.
    merge_repeated: Boolean. Default: True.
    blank_index: (Optional). Default: `num_classes - 1`. Define the class index
      to use for the blank label. Negative values will start from num_classes,
      ie, -1 will reproduce the ctc_greedy_decoder behavior of using
      num_classes - 1 for the blank symbol, which corresponds to the default.

  Returns:
    A tuple `(decoded, neg_sum_logits)` where

    decoded: A single-element list. `decoded[0]`
      is a `SparseTensor` containing the decoded outputs s.t.:

      `decoded.indices`: Indices matrix `(total_decoded_outputs, 2)`.
        The rows store: `[batch, time]`.

      `decoded.values`: Values vector, size `(total_decoded_outputs)`.
        The vector stores the decoded classes.

      `decoded.dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length]`

    neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the
        sequence found, the negative of the sum of the greatest logit at each
        timeframe.
  """
  decoded_ix, decoded_val, decoded_shape, log_probabilities = (
      gen_ctc_ops.ctc_greedy_decoder(
          inputs,
          sequence_length,
          merge_repeated=merge_repeated,
          blank_index=blank_index))
  decoded = sparse_tensor.SparseTensor(decoded_ix, decoded_val, decoded_shape)
  return [decoded], log_probabilities
@tf_export(v1=["nn.ctc_beam_search_decoder"])
@dispatch.add_dispatch_support
def ctc_beam_search_decoder(inputs,
                            sequence_length,
                            beam_width=100,
                            top_paths=1,
                            merge_repeated=True):
  """Performs beam search decoding on the logits given in input.

  **Note** The `ctc_greedy_decoder` is a special case of the
  `ctc_beam_search_decoder` with `top_paths=1` and `beam_width=1` (but
  that decoder is faster for this special case).

  If `merge_repeated` is `True`, merge repeated classes in the output beams.
  This means that if consecutive entries in a beam are the same,
  only the first of these is emitted.  That is, when the sequence is
  `A B B * B * B` (where '*' is the blank label), the return value is:

    * `A B` if `merge_repeated = True`.
    * `A B B B` if `merge_repeated = False`.

  Args:
    inputs: 3-D `float` `Tensor`, size `[max_time x batch_size x num_classes]`.
      The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths, having size
      `[batch_size]`.
    beam_width: An int scalar >= 0 (beam search beam width).
    top_paths: An int scalar >= 0, <= beam_width (controls output size).
    merge_repeated: Boolean.  Default: True.

  Returns:
    A tuple `(decoded, log_probabilities)` where

    decoded: A list of length top_paths, where `decoded[j]`
      is a `SparseTensor` containing the decoded outputs:

      `decoded[j].indices`: Indices matrix `(total_decoded_outputs[j] x 2)`
        The rows store: [batch, time].

      `decoded[j].values`: Values vector, size `(total_decoded_outputs[j])`.
        The vector stores the decoded classes for beam j.

      `decoded[j].dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length[j]]`.

    log_probabilities: A `float` matrix `(batch_size x top_paths)` containing
        sequence log-probabilities.
  """
  ixs, vals, shapes, log_probabilities = gen_ctc_ops.ctc_beam_search_decoder(
      inputs,
      sequence_length,
      beam_width=beam_width,
      top_paths=top_paths,
      merge_repeated=merge_repeated)
  # One SparseTensor per returned path.
  decoded = []
  for ix, val, shape in zip(ixs, vals, shapes):
    decoded.append(sparse_tensor.SparseTensor(ix, val, shape))
  return decoded, log_probabilities
@tf_export("nn.ctc_beam_search_decoder", v1=["nn.ctc_beam_search_decoder_v2"])
@dispatch.add_dispatch_support
def ctc_beam_search_decoder_v2(inputs,
                               sequence_length,
                               beam_width=100,
                               top_paths=1):
  """Performs beam search decoding on the logits given in input.

  **Note** The `ctc_greedy_decoder` is a special case of the
  `ctc_beam_search_decoder` with `top_paths=1` and `beam_width=1` (but
  that decoder is faster for this special case).

  Args:
    inputs: 3-D `float` `Tensor`, size `[max_time, batch_size, num_classes]`.
      The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths, having size
      `[batch_size]`.
    beam_width: An int scalar >= 0 (beam search beam width).
    top_paths: An int scalar >= 0, <= beam_width (controls output size).

  Returns:
    A tuple `(decoded, log_probabilities)` where

    decoded: A list of length top_paths, where `decoded[j]`
      is a `SparseTensor` containing the decoded outputs:

      `decoded[j].indices`: Indices matrix `[total_decoded_outputs[j], 2]`;
        The rows store: `[batch, time]`.

      `decoded[j].values`: Values vector, size `[total_decoded_outputs[j]]`.
        The vector stores the decoded classes for beam `j`.

      `decoded[j].dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length[j]]`.

    log_probabilities: A `float` matrix `[batch_size, top_paths]` containing
        sequence log-probabilities.
  """
  # `merge_repeated` is deliberately kept out of this public API: merging
  # repeated labels during beam search is an invalid optimization that returns
  # low-probability paths, so it is always disabled here.
  return ctc_beam_search_decoder(
      inputs,
      sequence_length=sequence_length,
      beam_width=beam_width,
      top_paths=top_paths,
      merge_repeated=False)
# Register the decoder ops as non-differentiable: no gradient function exists
# for them.
ops.NotDifferentiable("CTCGreedyDecoder")
ops.NotDifferentiable("CTCBeamSearchDecoder")
def _ctc_state_trans(label_seq):
  """Compute CTC alignment model transition matrix.

  State layout (from the index arithmetic below): with
  `num_labels = label_seq.shape[1]`, state 0 is the start state, states
  1..num_labels are label states, and states num_labels+1..2*num_labels+1 are
  blank states (`blank_states = label_states + num_label_states`). Each
  allowed transition is written at `[destination, source]`; for example, the
  start -> first-label transition is stored at `[1, 0]`.
  NOTE(review): confirm this orientation against how `_forward_backward_log`
  consumes the matrix.

  Args:
    label_seq: tensor of shape [batch_size, max_seq_length]

  Returns:
    tensor of shape [batch_size, states, states] with a state transition matrix
    computed for each sequence of the batch. Entries are 1.0 for allowed
    transitions and 0.0 otherwise. `ctc_loss_and_grad` takes the log of this
    matrix.
  """
  with ops.name_scope("ctc_state_trans"):
    label_seq = ops.convert_to_tensor(label_seq, name="label_seq")
    batch_size = _get_dim(label_seq, 0)
    num_labels = _get_dim(label_seq, 1)

    # Label states are the start state (0) plus one per label position. Each
    # label state has a matching blank state, hence 2 * num_label_states.
    num_label_states = num_labels + 1
    num_states = 2 * num_label_states

    label_states = math_ops.range(num_label_states)
    blank_states = label_states + num_label_states

    # Start state to first label.
    start_to_label = [[1, 0]]

    # Blank to label transitions.
    blank_to_label = array_ops.stack([label_states[1:], blank_states[:-1]], 1)

    # Label to blank transitions.
    label_to_blank = array_ops.stack([blank_states, label_states], 1)

    # Scatter transitions that don't depend on sequence.
    indices = array_ops.concat([start_to_label, blank_to_label, label_to_blank],
                               0)
    values = array_ops.ones([_get_dim(indices, 0)])
    trans = array_ops.scatter_nd(
        indices, values, shape=[num_states, num_states])
    trans += linalg_ops.eye(num_states)  # Self-loops.

    # Label to label transitions. Disallow transitions between repeated labels
    # with no blank state in between.
    # Build [0, k, k-1] for k = 2..num_labels, tile it across the batch, then
    # add b to the leading (batch) coordinate of the b-th copy.
    batch_idx = array_ops.zeros_like(label_states[2:])
    indices = array_ops.stack([batch_idx, label_states[2:], label_states[1:-1]],
                              1)
    indices = array_ops.tile(
        array_ops.expand_dims(indices, 0), [batch_size, 1, 1])
    batch_idx = array_ops.expand_dims(math_ops.range(batch_size), 1) * [1, 0, 0]
    indices += array_ops.expand_dims(batch_idx, 1)
    # 1.0 where consecutive labels differ, 0.0 where they repeat.
    repeats = math_ops.equal(label_seq[:, :-1], label_seq[:, 1:])
    values = 1.0 - math_ops.cast(repeats, dtypes.float32)
    batched_shape = [batch_size, num_states, num_states]
    label_to_label = array_ops.scatter_nd(indices, values, batched_shape)

    # Broadcast the shared transitions over the batch and add per-sequence ones.
    return array_ops.expand_dims(trans, 0) + label_to_label
def ctc_state_log_probs(seq_lengths, max_seq_length):
  """Builds the initial and final CTC state log-probabilities.

  The values are produced directly in log space. This avoids having to take
  a float64 log on TPU, which does not support it.

  Args:
    seq_lengths: int tensor of shape [batch_size], seq lengths in the batch.
    max_seq_length: int, max sequence length possible.

  Returns:
    A pair `(initial_state_log_probs, final_state_log_probs)`, each shaped
    [batch_size, 2 * (max_seq_length + 1)]. In both tensors, allowed states
    hold 0 and every other state holds a large negative float32 stand-in for
    log(0).
  """
  batch_size = _get_dim(seq_lengths, 0)
  label_state_count = max_seq_length + 1
  duration_state_count = 2
  state_count = duration_state_count * label_state_count

  # log(1e-307) computed in float64 then cast: a finite substitute for log(0).
  tiny_f64 = math_ops.cast(0, dtypes.float64) + 1e-307
  log_zero = math_ops.cast(math_ops.log(tiny_f64), dtypes.float32)

  # Every sequence starts in state 0.
  start_state = array_ops.zeros([batch_size], dtype=dtypes.int32)
  initial_log_probs = array_ops.one_hot(
      indices=start_state,
      depth=state_count,
      on_value=0.0,
      off_value=log_zero,
      axis=1)

  # Final states: index seq_lengths[b] within each of the two halves of the
  # state vector ([2, L, B] mask reshaped to [2L, B]).
  label_mask = array_ops.one_hot(seq_lengths, depth=label_state_count, axis=0)
  duration_mask = array_ops.ones([duration_state_count, 1, batch_size])
  final_mask = duration_mask * label_mask
  final_log_probs = array_ops.reshape((1.0 - final_mask) * log_zero,
                                      [state_count, batch_size])
  return initial_log_probs, array_ops.transpose(final_log_probs)
def _ilabel_to_state(labels, num_labels, ilabel_log_probs):
  """Projects per-class log-probs onto the CTC state space.

  Args:
    labels: Label tensor; dimension 1 is the padded label-sequence length.
    num_labels: Number of classes (depth used to one-hot `labels`).
    ilabel_log_probs: Per-frame class log-probs, rank 3 with classes on the
      last axis. Class 0 is used as the blank emission.

  Returns:
    State log-probs whose last axis is laid out as
    [start, label positions..., blank states...]. The start slot is log(0).
  """
  max_label_len = _get_dim(labels, 1)

  # Emission log-prob for each label position: select the class that the
  # label at that position refers to, via a one-hot dot product.
  label_one_hot = array_ops.expand_dims(
      array_ops.one_hot(labels, depth=num_labels), axis=0)
  expanded_log_probs = array_ops.expand_dims(ilabel_log_probs, axis=2)
  per_position = math_ops.reduce_sum(expanded_log_probs * label_one_hot, axis=3)

  # Every blank state emits class 0.
  blank_log_probs = array_ops.tile(ilabel_log_probs[:, :, :1],
                                   [1, 1, max_label_len + 1])

  states = array_ops.concat([per_position, blank_log_probs], axis=2)
  # Prepend the start state with log(0) emission.
  return array_ops.pad(
      states, [[0, 0], [0, 0], [1, 0]], constant_values=math_ops.log(0.0))
def _state_to_olabel(labels, num_labels, states):
  """Sums state log-probs back into per-class log-probs.

  Args:
    labels: Label tensor; dimension 1 is the padded label-sequence length.
      Label ids appear to be 1-based (class 0 is blank), given the
      `labels - 1` below.
    num_labels: Number of classes in the result's last axis.
    states: Per-frame state log-probs laid out as
      [start, label states..., blank states...] on the last axis.

  Returns:
    Per-frame class log-probs. Slot 0 is the log-sum over blank states;
    slots 1.. are the log-sums over label states that carry each class.
  """
  label_state_end = _get_dim(labels, 1) + 1
  label_part = array_ops.expand_dims(states[:, :, 1:label_state_end], axis=3)
  blank_part = states[:, :, label_state_end:]

  # Log-space one-hot: 0 where a position holds class c + 1, log(0) elsewhere.
  log_one_hot = array_ops.expand_dims(
      array_ops.one_hot(
          labels - 1,
          depth=(num_labels - 1),
          on_value=0.0,
          off_value=math_ops.log(0.0)),
      axis=0)

  per_class = math_ops.reduce_logsumexp(label_part + log_one_hot, axis=2)
  blank = math_ops.reduce_logsumexp(blank_part, axis=2, keepdims=True)
  return array_ops.concat([blank, per_class], axis=-1)
# pylint: disable=redefined-outer-name
def _state_to_olabel_unique(labels, num_labels, states, unique):
  """Sum state log probs to ilabel log probs using unique label indices.

  This is a variant of `_state_to_olabel`. Instead of a one-hot reduction over
  `num_labels`, it combines label states with `_sum_states` over unique-label
  slots and then scatters the result into per-class slots.

  Args:
    labels: Label tensor; dimension 1 is the padded label-sequence length.
    num_labels: Number of classes in the result's last axis.
    states: Per-frame state log-probs laid out as
      [start, label states..., blank states...] on the last axis.
    unique: Pair `(unique_y, unique_idx)`. Presumably this is the output of
      `ctc_unique_labels(labels)` — TODO confirm. `unique_y` gives the class
      id of each unique slot, and `unique_idx` is passed to `_sum_states`.

  Returns:
    Per-frame class log-probs. Slot 0 is the log-sum over blank states,
    followed by label classes 1..num_labels-1.
  """
  num_label_states = _get_dim(labels, 1) + 1
  label_states = states[:, :, 1:num_label_states]
  blank_states = states[:, :, num_label_states:]
  unique_y, unique_idx = unique
  mul_reduce = _sum_states(unique_idx, label_states)

  # NOTE(review): these are static shapes. The reshapes below need both dims
  # known at graph-construction time; confirm callers only take this path with
  # statically shaped `states`.
  num_frames = states.shape[0]
  batch_size = states.shape[1]
  num_states = num_label_states - 1
  # Flatten to one row per (batch, unique slot), with frames along columns.
  # Assumes mul_reduce is [frames, batch, num_states] — TODO confirm against
  # _sum_states.
  batch_state_major = array_ops.transpose(mul_reduce, perm=[1, 2, 0])
  batch_state_major = array_ops.reshape(batch_state_major,
                                        [batch_size * num_states, num_frames])
  # Offset class ids by b * num_labels so that (batch, class) maps into a flat
  # [batch_size * num_labels] index space.
  batch_offset = math_ops.range(batch_size, dtype=unique_y.dtype) * num_labels
  indices = unique_y + array_ops.expand_dims(batch_offset, axis=-1)
  indices = array_ops.reshape(indices, [-1, 1])
  # NOTE(review): scatter_nd adds together updates that share an index. This
  # relies on unique_y not repeating a class id within a row, except for ids
  # that are discarded afterwards (class 0 is dropped below) — confirm.
  scatter = array_ops.scatter_nd(
      indices=indices,
      updates=batch_state_major,
      shape=[batch_size * num_labels, num_frames])
  scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames])

  # Mark which (batch, class) slots received an update. All other slots are
  # set to log(0) rather than the scatter's implicit 0.
  mask = array_ops.ones_like(batch_state_major, dtype=dtypes.bool)
  mask = array_ops.scatter_nd(
      indices=indices,
      updates=mask,
      shape=[batch_size * num_labels, num_frames])
  mask = array_ops.reshape(mask, [batch_size, num_labels, num_frames])

  scatter = array_ops.where(
      mask, scatter,
      array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0)))

  # Back to [frames, batch, classes]. Class 0 is dropped because the blank
  # column is rebuilt from blank_states instead.
  label_olabels = array_ops.transpose(scatter, [2, 0, 1])
  label_olabels = label_olabels[:, :, 1:]

  blank_olabels = math_ops.reduce_logsumexp(blank_states, axis=2, keepdims=True)

  return array_ops.concat([blank_olabels, label_olabels], axis=-1)
def ctc_loss_and_grad(logits, labels, label_length, logit_length, unique=None):
  """Computes the CTC loss and its gradient with respect to `logits`.

  The gradient is returned as an ordinary output. This function does not have
  a gradient of its own defined.

  Args:
    logits: tensor of shape [frames, batch_size, num_labels]
    labels: tensor of shape [batch_size, max_label_seq_length]
    label_length: tensor of shape [batch_size] Length of reference label
      sequence in labels.
    logit_length: tensor of shape [batch_size] Length of input sequence in
      logits.
    unique: (optional) unique label indices as computed by unique(labels). If
      supplied, enables an implementation that is faster and more memory
      efficient on TPU.

  Returns:
    loss: tensor of shape [batch_size]
    gradient: tensor of shape [frames, batch_size, num_labels]
  """
  num_labels = _get_dim(logits, 2)
  max_label_seq_length = _get_dim(labels, 1)

  ilabel_log_probs = nn_ops.log_softmax(logits)
  state_log_probs = _ilabel_to_state(labels, num_labels, ilabel_log_probs)
  state_trans_probs = _ctc_state_trans(labels)
  initial_state_log_probs, final_state_log_probs = ctc_state_log_probs(
      label_length, max_label_seq_length)
  fwd_bwd_log_probs, log_likelihood = _forward_backward_log(
      state_trans_log_probs=math_ops.log(state_trans_probs),
      initial_state_log_probs=initial_state_log_probs,
      final_state_log_probs=final_state_log_probs,
      observed_log_probs=state_log_probs,
      sequence_length=logit_length)

  olabel_log_probs = (
      _state_to_olabel_unique(labels, num_labels, fwd_bwd_log_probs, unique)
      if unique else
      _state_to_olabel(labels, num_labels, fwd_bwd_log_probs))

  # Zero the gradient beyond each sequence's length. Only the ilabel term
  # strictly needs it (the olabel term already accounts for the mask), but
  # masking the whole gradient is simpler and safe.
  frame_mask = array_ops.sequence_mask(logit_length, _get_dim(logits, 0),
                                       dtypes.float32)
  # [batch, frames] -> [frames, batch, 1] so it broadcasts over classes.
  frame_mask = array_ops.expand_dims(
      array_ops.transpose(frame_mask, perm=[1, 0]), axis=2)
  grad = (math_ops.exp(ilabel_log_probs) -
          math_ops.exp(olabel_log_probs)) * frame_mask

  return -log_likelihood, grad
def _ctc_loss_grad(op, grad_loss, _):
  """Gradient for an op whose second output is a precomputed input gradient.

  Presumably this is used with ops returning `(loss, grad)` as in
  `ctc_loss_and_grad` — confirm at the registration site.

  Args:
    op: The op being differentiated; `op.outputs[1]` is the gradient for
      `op.inputs[0]`.
    grad_loss: Backprop signal for the loss output, reshaped to [1, -1, 1]
      before scaling.
    _: Backprop signal for the gradient output; unused.

  Returns:
    A list with one entry per op input: the scaled gradient for input 0, then
    `None` for every remaining input.
  """
  scale = array_ops.reshape(grad_loss, [1, -1, 1])
  input_grad = scale * op.outputs[1]
  return [input_grad] + [None] * (len(op.inputs) - 1)
def _ctc_loss_op_standard(labels, logits, logit_length, logits_time_major,
                          blank_index):
  """Runs the standard CTC kernel with an arbitrary blank class.

  The standard kernel treats the last class as blank. The `blank_index`
  column of `logits` is therefore moved to the end, and label ids that are
  not below `blank_index` are decremented by one to match.
  NOTE(review): assumes `blank_index` is already non-negative. For -1,
  `blank_index + 1` is 0 and the slices below would be wrong.

  Args:
    labels: `SparseTensor` of label ids.
    logits: Logits with classes on the last axis.
    logit_length: Per-sequence input lengths.
    logits_time_major: Forwarded as `time_major`.
    blank_index: Class index of the blank label in `logits`.

  Returns:
    The per-batch loss from `_ctc_loss_impl`.
  """
  head = logits[:, :, :blank_index]
  tail = logits[:, :, blank_index + 1:]
  blank_column = logits[:, :, blank_index:blank_index + 1]
  reordered_logits = array_ops.concat([head, tail, blank_column], axis=2)

  shifted_values = array_ops.where(labels.values < blank_index, labels.values,
                                   labels.values - 1)
  shifted_labels = sparse_tensor.SparseTensor(labels.indices, shifted_values,
                                              labels.dense_shape)

  return _ctc_loss_impl(
      labels=shifted_labels,
      inputs=reordered_logits,
      sequence_length=logit_length,
      time_major=logits_time_major,
      use_cudnn=False)
def _ctc_loss_op_cudnn(labels, logits, logit_length, logits_time_major,
                       blank_index):
  """Runs the cuDNN-path CTC kernel with an arbitrary blank class.

  With `use_cudnn=True`, `_ctc_loss_impl` uses a kernel that treats class 0 as
  blank. The `blank_index` column of `logits` is therefore moved to the front,
  and label ids below `blank_index` are incremented by one to match.
  NOTE(review): assumes `blank_index` is already non-negative.

  Args:
    labels: `SparseTensor` of label ids.
    logits: Logits with classes on the last axis.
    logit_length: Per-sequence input lengths.
    logits_time_major: Forwarded as `time_major`.
    blank_index: Class index of the blank label in `logits`.

  Returns:
    The per-batch loss from `_ctc_loss_impl`.
  """
  head = logits[:, :, :blank_index]
  tail = logits[:, :, blank_index + 1:]
  blank_column = logits[:, :, blank_index:blank_index + 1]
  reordered_logits = array_ops.concat([blank_column, head, tail], axis=2)

  shifted_values = array_ops.where(labels.values < blank_index,
                                   labels.values + 1, labels.values)
  shifted_labels = sparse_tensor.SparseTensor(labels.indices, shifted_values,
                                              labels.dense_shape)

  return _ctc_loss_impl(
      labels=shifted_labels,
      inputs=reordered_logits,
      sequence_length=logit_length,
      time_major=logits_time_major,
      use_cudnn=True)
def _ctc_loss_shape(op):
return [op.inputs[2].get_shape(), op.inputs[0].get_shape()]
# pylint: disable=protected-access, invalid-name
@tf_export(v1=["nn.ctc_loss_v2"])
@dispatch.add_dispatch_support
def ctc_loss_v2(labels,
                logits,
                label_length,
                logit_length,
                logits_time_major=True,
                unique=None,
                blank_index=None,
                name=None):
  """Computes CTC (Connectionist Temporal Classification) loss.

  This op implements the CTC loss as presented in (Graves et al., 2006).

  Notes:

  - Matches the "Classic CTC" of TensorFlow 1.x's tf.compat.v1.nn.ctc_loss
    with preprocess_collapse_repeated=False and ctc_merge_repeated=True.
  - Labels may be given either as a dense, zero-padded tensor together with a
    vector of label sequence lengths, OR as a SparseTensor.
  - On TPU and GPU: only dense padded labels are supported.
  - On CPU: callers may pass a SparseTensor or dense padded labels, but a
    SparseTensor is significantly faster.
  - The default blank label is 0 rather than num_classes - 1, unless
    overridden by blank_index.
  - With SparseTensor labels, `label_length`, `unique` and `name` are unused.

  Args:
    labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor.
    logits: tensor of shape [frames, batch_size, num_labels]; if
      logits_time_major == False, shape is [batch_size, frames, num_labels].
    label_length: tensor of shape [batch_size] holding the length of each
      reference label sequence in labels; None if labels is a SparseTensor.
    logit_length: tensor of shape [batch_size] holding the length of each
      input sequence in logits.
    logits_time_major: (optional) If True (default), logits is shaped [time,
      batch, logits]. If False, shape is [batch, time, logits].
    unique: (optional) Unique label indices as computed by
      ctc_unique_labels(labels). If supplied, enables a faster, memory
      efficient implementation on TPU.
    blank_index: (optional) Class index to use for the blank label; required
      when labels is a SparseTensor. Negative values count back from
      num_classes, i.e., -1 reproduces the ctc_loss behavior of using
      num_classes - 1 for the blank symbol. Moving away from the default of 0
      has some memory/performance overhead, since an additional shifted copy
      of the logits may be created.
    name: A name for this `Op`. Defaults to "ctc_loss_dense".

  Returns:
    loss: tensor of shape [batch_size], negative log probabilities.

  Raises:
    ValueError: if labels is a SparseTensor and blank_index is None.

  References:
      Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
      with Recurrent Neural Networks:
        [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
        ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
  """
  if not isinstance(labels, sparse_tensor.SparseTensor):
    # Dense, zero-padded labels: the blank defaults to class 0.
    return ctc_loss_dense(
        labels=labels,
        logits=logits,
        label_length=label_length,
        logit_length=logit_length,
        logits_time_major=logits_time_major,
        unique=unique,
        blank_index=0 if blank_index is None else blank_index,
        name=name)

  if blank_index is None:
    raise ValueError(
        "Argument `blank_index` must be provided when labels is a "
        "SparseTensor.")

  if blank_index < 0:
    blank_index += _get_dim(logits, 2)

  if blank_index != _get_dim(logits, 2) - 1:
    # The sparse ctc_loss expects the blank as the last class: move its column
    # to the end and slide the label ids at or above it down by one.
    logits = array_ops.concat(
        [
            logits[:, :, :blank_index],
            logits[:, :, blank_index + 1:],
            logits[:, :, blank_index:blank_index + 1],
        ],
        axis=2)
    label_ids = labels.values
    labels = sparse_tensor.SparseTensor(
        labels.indices,
        array_ops.where(label_ids < blank_index, label_ids, label_ids - 1),
        labels.dense_shape)

  return ctc_loss(
      labels=labels,
      inputs=logits,
      sequence_length=logit_length,
      time_major=logits_time_major)
@tf_export("nn.ctc_loss", v1=[])
@dispatch.add_dispatch_support
def ctc_loss_v3(labels,
                logits,
                label_length,
                logit_length,
                logits_time_major=True,
                unique=None,
                blank_index=None,
                name=None):
  """Computes CTC (Connectionist Temporal Classification) loss.

  This op implements the CTC loss as presented in (Graves et al., 2006).

  Notes:

  - Matches the "Classic CTC" of TensorFlow 1.x's tf.compat.v1.nn.ctc_loss
    with preprocess_collapse_repeated=False and ctc_merge_repeated=True.
  - Labels may be given either as a dense, zero-padded tensor together with a
    vector of label sequence lengths, OR as a SparseTensor.
  - On TPU and GPU: only dense padded labels are supported.
  - On CPU: callers may pass a SparseTensor or dense padded labels, but a
    SparseTensor is significantly faster.
  - The default blank label is 0 rather than num_classes - 1, unless
    overridden by blank_index.
  - With SparseTensor labels, `label_length`, `unique` and `name` are unused.

  Args:
    labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor.
    logits: tensor of shape [frames, batch_size, num_labels]; if
      logits_time_major == False, shape is [batch_size, frames, num_labels].
    label_length: tensor of shape [batch_size] holding the length of each
      reference label sequence in labels; None if labels is a SparseTensor.
    logit_length: tensor of shape [batch_size] holding the length of each
      input sequence in logits.
    logits_time_major: (optional) If True (default), logits is shaped [time,
      batch, logits]. If False, shape is [batch, time, logits].
    unique: (optional) Unique label indices as computed by
      ctc_unique_labels(labels). If supplied, enables a faster, memory
      efficient implementation on TPU.
    blank_index: (optional) Class index to use for the blank label; required
      when labels is a SparseTensor. Negative values count back from
      num_classes, i.e., -1 reproduces the ctc_loss behavior of using
      num_classes - 1 for the blank symbol. Moving away from the default of 0
      has some memory/performance overhead, since an additional shifted copy
      of the logits may be created.
    name: A name for this `Op`. Defaults to "ctc_loss_dense".

  Returns:
    loss: tensor of shape [batch_size], negative log probabilities.

  Raises:
    ValueError: if labels is a SparseTensor and blank_index is None.

  References:
      Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
      with Recurrent Neural Networks:
        [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
        ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
  """
  if not isinstance(labels, sparse_tensor.SparseTensor):
    # Dense, zero-padded labels: the blank defaults to class 0.
    return ctc_loss_dense(
        labels=labels,
        logits=logits,
        label_length=label_length,
        logit_length=logit_length,
        logits_time_major=logits_time_major,
        unique=unique,
        blank_index=0 if blank_index is None else blank_index,
        name=name)

  if blank_index is None:
    raise ValueError(
        "Argument `blank_index` must be provided when labels is a "
        "SparseTensor.")

  if blank_index < 0:
    blank_index += _get_dim(logits, 2)

  backend_kwargs = dict(
      labels=labels,
      logits=logits,
      logit_length=logit_length,
      logits_time_major=logits_time_major,
      blank_index=blank_index)

  if not context.executing_eagerly():
    # Graph mode: build a CPU (standard) and a GPU (cuDNN) backend under one
    # shared API name. The standard one is called to produce the result, and
    # the cuDNN one is registered alongside it. Presumably this lets the
    # runtime pick the cuDNN variant when the op lands on a GPU; see
    # _generate_defun_backend.
    api_name = "ctc_loss_" + str(uuid.uuid4())
    standard_backend = _generate_defun_backend(api_name, _CPU_DEVICE_NAME,
                                               _ctc_loss_op_standard)
    cudnn_backend = _generate_defun_backend(api_name, _GPU_DEVICE_NAME,
                                            _ctc_loss_op_cudnn)
    loss = standard_backend(**backend_kwargs)
    function_eager.register(cudnn_backend, **backend_kwargs)
    return loss

  # Eager mode: inspect the current device placement. Prefer the cuDNN
  # backend when a GPU is explicitly requested, or when no device is
  # specified and a GPU is available.
  device_type = _get_context_device_type()
  use_gpu = (device_type == _GPU_DEVICE_NAME or
             (device_type is None and context.num_gpus() > 0))
  backend = _ctc_loss_op_cudnn if use_gpu else _ctc_loss_op_standard
  return backend(**backend_kwargs)
def ctc_loss_dense(labels,
logits,
label_length,
logit_length,
logits_time_major=True,
unique=None,
blank_index=0,
name=None):
"""Computes CTC (Connectionist Temporal Classification) loss.
This op implements the CTC loss as presented in (Graves et al., 2006),
using the batched forward backward algorithm described in (Sim et al., 2017).
Notes:
Significant differences from tf.compat.v1.nn.ctc_loss:
Supports GPU and TPU (tf.compat.v1.nn.ctc_loss supports CPU only):
For batched operations, GPU and TPU are significantly faster than using
ctc_loss on CPU.
This implementation runs on CPU, but significantly slower than ctc_loss.
Blank label is 0 rather num_classes - 1, unless overridden by blank_index.
Logits and labels are dense arrays with padding rather than SparseTensor.
The only mode supported is the same as:
preprocess_collapse_repeated=False, ctc_merge_repeated=True
To collapse labels, the caller can preprocess label sequence first.
The dense implementation supports both CPU, GPU and TPU. A fast path is
provided that significantly improves memory use for large vocabulary if the
caller preprocesses label sequences to get unique label indices on the CPU