/
ctc.py
2225 lines (1964 loc) · 79.9 KB
/
ctc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Decoders and output normalization for CTC.
Authors
* Mirco Ravanelli 2020
* Aku Rouhe 2020
* Sung-Lin Yeh 2020
* Adel Moumen 2023
"""
from itertools import groupby
from speechbrain.dataio.dataio import length_to_mask
import math
import dataclasses
import numpy as np
import heapq
import logging
import torch
from typing import Dict, List, Optional, Union, Any, Tuple
# Module-level logger used by the searchers in this file for configuration
# warnings (e.g. missing space token, vocab size mismatch) and info messages.
logger = logging.getLogger(__name__)
class CTCPrefixScore:
    """This class implements the CTC prefix score of Algorithm 2 in
    reference: https://www.merl.com/publications/docs/TR2017-190.pdf.
    Official implementation: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py

    Arguments
    ---------
    x : torch.Tensor
        The encoder states, used as log-probabilities over the vocabulary and
        indexed as (batch, time, vocab). NOTE: the padded frames of ``x``
        are overwritten in place.
    enc_lens : torch.Tensor
        The actual length (in frames) of each enc_states sequence.
    blank_index : int
        The index of the blank token.
    eos_index : int
        The index of the end-of-sequence (eos) token.
    ctc_window_size: int
        Compute the ctc scores over the time frames using windowing based on attention peaks.
        If 0, no windowing applied.
    """

    def __init__(
        self, x, enc_lens, blank_index, eos_index, ctc_window_size=0,
    ):
        self.blank_index = blank_index
        self.eos_index = eos_index
        self.batch_size = x.size(0)
        self.max_enc_len = x.size(1)
        self.vocab_size = x.size(-1)
        self.device = x.device
        self.minus_inf = -1e20
        self.last_frame_index = enc_lens - 1
        self.ctc_window_size = ctc_window_size
        # Number of forward steps done so far, i.e. |g| (prefix length
        # without bos) at the beginning of the next forward_step call.
        self.prefix_length = 0

        # Mask frames >= enc_lens. The mask is built over the time axis of x
        # so that its shape always matches x, even if max(enc_lens) is
        # smaller than x.size(1).
        frames = torch.arange(self.max_enc_len, device=self.device)
        pad_mask = frames.unsqueeze(0) >= enc_lens.to(self.device).unsqueeze(1)
        x.masked_fill_(
            pad_mask.unsqueeze(-1).expand(-1, -1, self.vocab_size),
            self.minus_inf,
        )
        # Padded frames emit blank with log-probability 0, so that the blank
        # path survives until the last frame. This must target the blank
        # column (not column 0, which is only the blank when blank_index=0).
        x[:, :, self.blank_index] = x[:, :, self.blank_index].masked_fill(
            pad_mask, 0.0
        )

        # dim=0: xnb, nonblank posteriors, dim=1: xb, blank posteriors
        xnb = x.transpose(0, 1)
        xb = (
            xnb[:, :, self.blank_index]
            .unsqueeze(2)
            .expand(-1, -1, self.vocab_size)
        )
        # (2, L, batch_size, vocab_size)
        self.x = torch.stack([xnb, xb])

        # indices of batch.
        self.batch_index = torch.arange(self.batch_size, device=self.device)

    @torch.no_grad()
    def forward_step(self, inp_tokens, states, candidates=None, attn=None):
        """This method performs one step of forwarding operation
        for the prefix ctc scorer.

        Arguments
        ---------
        inp_tokens : torch.Tensor
            The last chars of prefix label sequences g, where h = g + c.
        states : tuple
            Previous ctc states ``(r_prev, psi_prev)`` as returned by
            ``permute_mem``, or None at the first step.
        candidates : torch.Tensor
            (batch_size * beam_size, ctc_beam_size), The topk candidates for rescoring.
            If given, performing partial ctc scoring.
        attn : torch.Tensor
            (batch_size * beam_size, max_enc_len), The attention weights.

        Returns
        -------
        scores : torch.Tensor
            (batch_size * beam_size, vocab_size), psi - psi_prev: the change
            of the prefix log-score for every token c.
        memory : tuple
            ``(r, psi, scoring_table)`` to be synchronized by ``permute_mem``.
        """
        n_bh = inp_tokens.size(0)
        beam_size = n_bh // self.batch_size
        last_char = inp_tokens
        # |g| must be read *before* the counter is incremented: otherwise the
        # (Alg.2-6) initialisation below is unreachable and the first scored
        # frame is shifted by one, dropping the earliest alignments of c.
        prefix_length = self.prefix_length
        self.prefix_length += 1
        self.num_candidates = (
            self.vocab_size if candidates is None else candidates.size(-1)
        )
        if states is None:
            # r_prev: (L, 2, batch_size * beam_size)
            r_prev = torch.full(
                (self.max_enc_len, 2, self.batch_size, beam_size),
                self.minus_inf,
                device=self.device,
            )
            # Accumulate blank posteriors at each step
            r_prev[:, 1] = torch.cumsum(
                self.x[0, :, :, self.blank_index], 0
            ).unsqueeze(2)
            r_prev = r_prev.view(-1, 2, n_bh)
            psi_prev = torch.full(
                (n_bh, self.vocab_size), 0.0, device=self.device,
            )
        else:
            r_prev, psi_prev = states

        # for partial search
        if candidates is not None:
            # The first index of each candidate.
            cand_offset = self.batch_index * self.vocab_size
            scoring_table = torch.full(
                (n_bh, self.vocab_size),
                -1,
                dtype=torch.long,
                device=self.device,
            )
            # Assign indices of candidates to their positions in the table
            col_index = torch.arange(n_bh, device=self.device).unsqueeze(1)
            scoring_table[col_index, candidates] = torch.arange(
                self.num_candidates, device=self.device
            )
            # Select candidates indices for scoring
            scoring_index = (
                candidates
                + cand_offset.unsqueeze(1).repeat(1, beam_size).view(-1, 1)
            ).view(-1)
            x_inflate = torch.index_select(
                self.x.view(2, -1, self.batch_size * self.vocab_size),
                2,
                scoring_index,
            ).view(2, -1, n_bh, self.num_candidates)
        # for full search
        else:
            scoring_table = None
            # Inflate x to (2, -1, batch_size * beam_size, num_candidates)
            # It is used to compute forward probs in a batched way
            x_inflate = (
                self.x.unsqueeze(3)
                .repeat(1, 1, 1, beam_size, 1)
                .view(2, -1, n_bh, self.num_candidates)
            )

        # Prepare forward probs
        r = torch.full(
            (self.max_enc_len, 2, n_bh, self.num_candidates),
            self.minus_inf,
            device=self.device,
        )
        # (Alg.2-6): with an empty prefix, h = c may be emitted at frame 0.
        if prefix_length == 0:
            r[0, 0] = x_inflate[0, 0]

        # (Alg.2-10): phi = prev_nonblank + prev_blank = r_t-1^nb(g) + r_t-1^b(g)
        r_sum = torch.logsumexp(r_prev, 1)
        phi = r_sum.unsqueeze(2).repeat(1, 1, self.num_candidates)

        # (Alg.2-10): if last token of prefix g in candidates, phi = prev_b + 0
        if candidates is not None:
            for i in range(n_bh):
                pos = scoring_table[i, last_char[i]]
                if pos != -1:
                    phi[:, i, pos] = r_prev[:, 1, i]
        else:
            for i in range(n_bh):
                phi[:, i, last_char[i]] = r_prev[:, 1, i]

        # Start, end frames for scoring (|g| < |h|).
        # Scoring based on attn peak if ctc_window_size > 0
        if self.ctc_window_size == 0 or attn is None:
            start = max(1, prefix_length)
            end = self.max_enc_len
        else:
            _, attn_peak = torch.max(attn, dim=1)
            max_frame = torch.max(attn_peak).item() + self.ctc_window_size
            min_frame = torch.min(attn_peak).item() - self.ctc_window_size
            start = max(max(1, prefix_length), int(min_frame))
            end = min(self.max_enc_len, int(max_frame))

        # Compute forward prob log(r_t^nb(h)) and log(r_t^b(h)):
        for t in range(start, end):
            # (Alg.2-11): dim=0, p(h|cur step is nonblank) = [p(prev step=y) + phi] * p(c)
            rnb_prev = r[t - 1, 0]
            # (Alg.2-12): dim=1, p(h|cur step is blank) = [p(prev step is blank) + p(prev step is nonblank)] * p(blank)
            rb_prev = r[t - 1, 1]
            r_ = torch.stack([rnb_prev, phi[t - 1], rnb_prev, rb_prev]).view(
                2, 2, n_bh, self.num_candidates
            )
            r[t] = torch.logsumexp(r_, 1) + x_inflate[:, t]

        # Compute the prefix prob, psi
        psi_init = r[start - 1, 0].unsqueeze(0)
        # phi is prob at t-1 step, shift one frame and add it to the current prob p(c)
        phix = torch.cat((phi[0].unsqueeze(0), phi[:-1]), dim=0) + x_inflate[0]

        # (Alg.2-13): psi = psi + phi * p(c)
        if candidates is not None:
            psi = torch.full(
                (n_bh, self.vocab_size), self.minus_inf, device=self.device,
            )
            psi_ = torch.logsumexp(
                torch.cat((phix[start:end], psi_init), dim=0), dim=0
            )
            # only assign prob to candidates
            for i in range(n_bh):
                psi[i, candidates[i]] = psi_[i]
        else:
            psi = torch.logsumexp(
                torch.cat((phix[start:end], psi_init), dim=0), dim=0
            )

        # (Alg.2-3): if c = <eos>, psi = log(r_T^n(g) + r_T^b(g)), where T is
        # the last valid frame of the sequence.
        for i in range(n_bh):
            psi[i, self.eos_index] = r_sum[
                self.last_frame_index[i // beam_size], i
            ]

        # Exclude blank probs for joint scoring
        psi[:, self.blank_index] = self.minus_inf

        return psi - psi_prev, (r, psi, scoring_table)

    def permute_mem(self, memory, index):
        """This method permutes the CTC model memory
        to synchronize the memory index with the current output.

        Arguments
        ---------
        memory : tuple
            ``(r, psi, scoring_table)`` as returned by ``forward_step``.
        index : torch.Tensor
            (batch_size, beam_size), the flat ``hyp * vocab_size + token``
            index of the previous path selected for each new hypothesis.

        Returns
        -------
        tuple
            ``(r, psi)``: the permuted forward probabilities and prefix
            scores, usable as ``states`` for the next ``forward_step``.
        """
        r, psi, scoring_table = memory

        beam_size = index.size(1)
        n_bh = self.batch_size * beam_size

        # The first index of each batch.
        beam_offset = self.batch_index * beam_size
        # The index of top-K vocab came from in (t-1) timesteps at batch * beam * vocab dimension.
        cand_index = (
            index + beam_offset.unsqueeze(1).expand_as(index) * self.vocab_size
        ).view(n_bh)
        # synchronize forward prob
        psi = torch.index_select(psi.view(-1), dim=0, index=cand_index)
        psi = (
            psi.view(-1, 1)
            .repeat(1, self.vocab_size)
            .view(n_bh, self.vocab_size)
        )
        # The index of top-K vocab came from in (t-1) timesteps at batch * beam dimension.
        hyp_index = (
            torch.div(index, self.vocab_size, rounding_mode="floor")
            + beam_offset.unsqueeze(1).expand_as(index)
        ).view(n_bh)
        # synchronize ctc states
        if scoring_table is not None:
            selected_vocab = (index % self.vocab_size).view(-1)
            score_index = scoring_table[hyp_index, selected_vocab]
            score_index[score_index == -1] = 0
            cand_index = score_index + hyp_index * self.num_candidates

        r = torch.index_select(
            r.view(-1, 2, n_bh * self.num_candidates), dim=-1, index=cand_index,
        )
        r = r.view(-1, 2, n_bh)

        return r, psi
def filter_ctc_output(string_pred, blank_id=-1):
    """Collapse a raw CTC output sequence into its final label sequence.

    Consecutive repeated symbols are merged first, then every occurrence of
    the blank symbol is dropped (so a blank between two equal symbols keeps
    both of them).

    Arguments
    ---------
    string_pred : list
        The per-frame strings/ints predicted by the CTC system.
    blank_id : int, string
        The blank symbol to remove.

    Returns
    -------
    list
        The prediction without repetitions and without blank symbols.

    Raises
    ------
    ValueError
        If ``string_pred`` is not a Python list.

    Example
    -------
    >>> string_pred = ['a','a','blank','b','b','blank','c']
    >>> string_out = filter_ctc_output(string_pred, blank_id='blank')
    >>> print(string_out)
    ['a', 'b', 'c']
    """
    if not isinstance(string_pred, list):
        raise ValueError("filter_ctc_out can only filter python lists")

    # One representative per run of identical consecutive symbols.
    collapsed = (symbol for symbol, _run in groupby(string_pred))
    return [symbol for symbol in collapsed if symbol != blank_id]
def ctc_greedy_decode(probabilities, seq_lens, blank_id=-1):
    """Greedy decode a batch of probabilities and apply CTC rules.

    For each sequence, the argmax label is taken at every valid frame, then
    repetitions are merged and blanks removed with ``filter_ctc_output``.

    Arguments
    ---------
    probabilities : torch.tensor
        Output probabilities (or log-probabilities) from the network with shape
        [batch, time, vocab]. The argmax is taken over the last (vocab) axis.
    seq_lens : torch.tensor
        Relative true sequence lengths (to deal with padded inputs),
        the longest sequence has length 1.0, others a value between zero and one,
        shape [batch]. The number of decoded frames is
        ``round(seq_len * time)``.
    blank_id : int, string
        The blank symbol/index. Default: -1. If a negative number is given,
        it is assumed to mean counting down from the maximum possible index,
        so that -1 refers to the maximum possible index.

    Returns
    -------
    list
        Outputs as Python list of lists, with "ragged" dimensions; padding
        has been removed.

    Example
    -------
    >>> import torch
    >>> probs = torch.tensor([[[0.3, 0.7], [0.0, 0.0]],
    ...                       [[0.2, 0.8], [0.9, 0.1]]])
    >>> lens = torch.tensor([0.51, 1.0])
    >>> blank_id = 0
    >>> ctc_greedy_decode(probs, lens, blank_id)
    [[1], [1]]
    """
    # Resolve a negative blank index against the vocabulary size.
    if isinstance(blank_id, int) and blank_id < 0:
        blank_id = probabilities.shape[-1] + blank_id
    batch_max_len = probabilities.shape[1]
    batch_outputs = []
    for seq, seq_len in zip(probabilities, seq_lens):
        # Relative length -> absolute number of frames to keep.
        actual_size = int(torch.round(seq_len * batch_max_len))
        _, predictions = torch.max(seq.narrow(0, 0, actual_size), dim=1)
        out = filter_ctc_output(predictions.tolist(), blank_id=blank_id)
        batch_outputs.append(out)
    return batch_outputs
@dataclasses.dataclass
class CTCBeam:
    """This class handles the CTC beam information during decoding.

    Arguments
    ---------
    text : str
        The current text of the beam.
    full_text : str
        The full text of the beam.
    next_word : str
        The next word to be added to the beam.
    partial_word : str
        The partial word being added to the beam.
    last_token : str, optional
        The last token of the beam.
    last_token_index : int, optional
        The index of the last token of the beam.
    text_frames : List[Tuple[int, int]]
        The start and end frame of the text.
    partial_frames : Tuple[int, int]
        The start and end frame of the partial word.
    p : float
        The probability of the beam.
    p_b : float
        The probability of the beam ending in a blank.
    p_nb : float
        The probability of the beam not ending in a blank.
    n_p_b : float
        The pending (next-step) probability of the beam ending in a blank.
    n_p_nb : float
        The pending (next-step) probability of the beam not ending in a blank.
    score : float
        The score of the beam (LM + CTC)
    score_ctc : float
        The CTC score computed.

    Example
    -------
    >>> beam = CTCBeam(
    ...     text="",
    ...     full_text="",
    ...     next_word="",
    ...     partial_word="",
    ...     last_token=None,
    ...     last_token_index=None,
    ...     text_frames=[(0, 0)],
    ...     partial_frames=(0, 0),
    ...     p=-math.inf,
    ...     p_b=-math.inf,
    ...     p_nb=-math.inf,
    ...     n_p_b=-math.inf,
    ...     n_p_nb=-math.inf,
    ...     score=-math.inf,
    ...     score_ctc=-math.inf,
    ... )
    """

    text: str
    full_text: str
    next_word: str
    partial_word: str
    last_token: Optional[str]
    last_token_index: Optional[int]
    text_frames: List[Tuple[int, int]]
    partial_frames: Tuple[int, int]
    p: float = -math.inf
    p_b: float = -math.inf
    p_nb: float = -math.inf
    n_p_b: float = -math.inf
    n_p_nb: float = -math.inf
    score: float = -math.inf
    score_ctc: float = -math.inf

    @classmethod
    def from_lm_beam(cls, lm_beam: "LMCTCBeam") -> "CTCBeam":
        """Create a CTCBeam from a LMCTCBeam.

        Every field declared on CTCBeam is copied (by reference) from
        ``lm_beam``; extra fields such as ``lm_score`` are dropped.

        Arguments
        ---------
        lm_beam : LMCTCBeam
            The LMCTCBeam to convert.

        Returns
        -------
        CTCBeam
            The CTCBeam converted. Always a plain CTCBeam, even when the
            method is called on a subclass.
        """
        # Iterate over CTCBeam's own fields so the copy can never drift from
        # the declared fields.
        return CTCBeam(
            **{
                field.name: getattr(lm_beam, field.name)
                for field in dataclasses.fields(CTCBeam)
            }
        )

    def step(self) -> None:
        """Update the beam probabilities.

        The pending blank/non-blank probabilities become the current ones,
        the pending ones are reset to -inf, and ``score_ctc`` (and ``score``)
        are set to the log-sum of the current blank/non-blank probabilities.
        """
        self.p_b, self.p_nb = self.n_p_b, self.n_p_nb
        self.n_p_b = self.n_p_nb = -math.inf
        self.score_ctc = np.logaddexp(self.p_b, self.p_nb)
        self.score = self.score_ctc
@dataclasses.dataclass
class LMCTCBeam(CTCBeam):
    """A CTCBeam that additionally carries its language-model score.

    Arguments
    ---------
    lm_score : float
        The LM score of the beam (default: -inf).
    **kwargs
        Every field of CTCBeam.
    """

    lm_score: float = dataclasses.field(default=-math.inf)
@dataclasses.dataclass
class CTCHypothesis:
    """This class is a data handler over the generated hypotheses.

    This class is the default output of the CTC beam searchers.

    It can be re-used for other decoders if using
    the beam searchers in an online fashion.

    Arguments
    ---------
    text : str
        The text of the hypothesis.
    last_lm_state : Any
        The last LM state of the hypothesis (None when no LM state is kept).
    score : float
        The score of the hypothesis.
    lm_score : float
        The LM score of the hypothesis.
    text_frames : List[Tuple[str, Tuple[int, int]]], optional
        The list of the text and the corresponding frames.
    """

    text: str
    # The previous annotation ``None`` claimed the state is always None.
    last_lm_state: Any
    score: float
    lm_score: float
    text_frames: Optional[List[Tuple[str, Tuple[int, int]]]] = None
class CTCBaseSearcher(torch.nn.Module):
"""CTCBaseSearcher class to be inherited by other
CTC beam searchers.
This class provides the basic functionalities for
CTC beam search decoding.
The space_token is required with a non-sentencepiece vocabulary list
if your transcription is expecting to contain spaces.
Arguments
---------
blank_index : int
The index of the blank token.
vocab_list : list
The list of the vocabulary tokens.
space_token : int, optional
The index of the space token. (default: -1)
kenlm_model_path : str, optional
The path to the kenlm model. Use .bin for a faster loading.
If None, no language model will be used. (default: None)
unigrams : list, optional
The list of known word unigrams. (default: None)
alpha : float
Weight for language model during shallow fusion. (default: 0.5)
beta : float
Weight for length score adjustment of during scoring. (default: 1.5)
unk_score_offset : float
Amount of log score offset for unknown tokens. (default: -10.0)
score_boundary : bool
Whether to have kenlm respect boundaries when scoring. (default: True)
beam_size : int, optional
The width of the beam. (default: 100)
beam_prune_logp : float, optional
The pruning threshold for the beam. (default: -10.0)
token_prune_min_logp : float, optional
The pruning threshold for the tokens. (default: -5.0)
prune_history : bool, optional
Whether to prune the history. (default: True)
Note: when using topk > 1, this should be set to False as
it is pruning a lot of beams.
blank_skip_threshold : float, optional
Skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding.
Note: This is only used when using the CUDA decoder, and it might worsen the WER/CER results. Use it at your own risk. (default: 1.0)
topk : int, optional
The number of top hypotheses to return. (default: 1)
spm_token: str, optional
The sentencepiece token. (default: "▁")
Example
-------
>>> blank_index = 0
>>> vocab_list = ['blank', 'a', 'b', 'c', ' ']
>>> space_token = ' '
>>> kenlm_model_path = None
>>> unigrams = None
>>> beam_size = 100
>>> beam_prune_logp = -10.0
>>> token_prune_min_logp = -5.0
>>> prune_history = True
>>> blank_skip_threshold = 1.0
>>> topk = 1
>>> searcher = CTCBaseSearcher(
... blank_index=blank_index,
... vocab_list=vocab_list,
... space_token=space_token,
... kenlm_model_path=kenlm_model_path,
... unigrams=unigrams,
... beam_size=beam_size,
... beam_prune_logp=beam_prune_logp,
... token_prune_min_logp=token_prune_min_logp,
... prune_history=prune_history,
... blank_skip_threshold=blank_skip_threshold,
... topk=topk,
... )
"""
def __init__(
    self,
    blank_index: int,
    vocab_list: List[str],
    space_token: str = " ",
    kenlm_model_path: Union[None, str] = None,
    unigrams: Union[None, List[str]] = None,
    alpha: float = 0.5,
    beta: float = 1.5,
    unk_score_offset: float = -10.0,
    score_boundary: bool = True,
    beam_size: int = 100,
    beam_prune_logp: float = -10.0,
    token_prune_min_logp: float = -5.0,
    prune_history: bool = True,
    blank_skip_threshold: Union[None, float] = 1.0,
    topk: int = 1,
    spm_token: str = "▁",
):
    """Store the decoding options and, if requested, load a KenLM model.

    Arguments
    ---------
    blank_index : int
        The index of the blank token.
    vocab_list : list
        The list of the vocabulary tokens.
    space_token : str, optional
        The token separating words in a non-SentencePiece vocabulary; its
        index is stored as ``space_index`` (-1 if absent). (default: " ")
    kenlm_model_path : str, optional
        Path to a KenLM model (.arpa or .bin). If None, no LM is used.
    unigrams : list, optional
        Known word unigrams. If None and an .arpa model is given, they are
        read from the .arpa file.
    alpha, beta, unk_score_offset, score_boundary :
        Forwarded to the ``LanguageModel``.
    beam_size, beam_prune_logp, token_prune_min_logp, prune_history, topk :
        Beam search options, stored as attributes.
    blank_skip_threshold : float, optional
        Stored as ``log(blank_skip_threshold)``. None is stored as 0.0,
        i.e. the same value as the default 1.0. (default: 1.0)
    spm_token : str, optional
        The SentencePiece word-start marker used to detect a SentencePiece
        vocabulary. (default: "▁")

    Raises
    ------
    ImportError
        If ``kenlm_model_path`` is given and ``kenlm`` or
        ``speechbrain.decoders.language_model`` cannot be imported.
    """
    super().__init__()

    self.blank_index = blank_index
    self.vocab_list = vocab_list
    self.space_token = space_token
    self.kenlm_model_path = kenlm_model_path
    self.unigrams = unigrams
    self.alpha = alpha
    self.beta = beta
    self.unk_score_offset = unk_score_offset
    self.score_boundary = score_boundary
    self.beam_size = beam_size
    self.beam_prune_logp = beam_prune_logp
    self.token_prune_min_logp = token_prune_min_logp
    self.prune_history = prune_history
    # Stored in the log domain. The type hint allows None, on which
    # math.log would raise; map it to log(1.0) == 0.0 (the default).
    self.blank_skip_threshold = (
        0.0
        if blank_skip_threshold is None
        else math.log(blank_skip_threshold)
    )
    self.topk = topk
    self.spm_token = spm_token

    # check if the vocab is coming from SentencePiece
    self.is_spm = any(token.startswith(self.spm_token) for token in vocab_list)

    # fetch the index of space_token (only needed without SentencePiece)
    if not self.is_spm:
        try:
            self.space_index = vocab_list.index(space_token)
        except ValueError:
            logger.warning(
                f"space_token `{space_token}` not found in the vocabulary. "
                "Using value -1 as `space_index`. "
                "Note: If your transcription is not expected to contain spaces, "
                "you can ignore this warning."
            )
            self.space_index = -1
        else:
            # Only report a found index when the token actually exists.
            logger.info(f"Found `space_token` at index {self.space_index}.")

    self.kenlm_model = None
    self.lm = None
    if kenlm_model_path is not None:
        try:
            import kenlm  # type: ignore

            from speechbrain.decoders.language_model import (
                LanguageModel,
                load_unigram_set_from_arpa,
            )
        except ImportError as err:
            raise ImportError(
                "kenlm python bindings are not installed. To install it use: "
                "pip install https://github.com/kpu/kenlm/archive/master.zip"
            ) from err

        # Warn before the (potentially slow) model loading starts.
        if kenlm_model_path.endswith(".arpa"):
            logger.info(
                "Using arpa instead of binary LM file, decoder instantiation might be slow."
            )
        self.kenlm_model = kenlm.Model(kenlm_model_path)

        if unigrams is None:
            if kenlm_model_path.endswith(".arpa"):
                unigrams = load_unigram_set_from_arpa(kenlm_model_path)
            else:
                logger.warning(
                    "Unigrams not provided and cannot be automatically determined from LM file (only "
                    "arpa format). Decoding accuracy might be reduced."
                )

        self.lm = LanguageModel(
            kenlm_model=self.kenlm_model,
            unigrams=unigrams,
            alpha=self.alpha,
            beta=self.beta,
            unk_score_offset=self.unk_score_offset,
            score_boundary=self.score_boundary,
        )
def partial_decoding(
    self,
    log_probs: torch.Tensor,
    beams: List[CTCBeam],
    cached_lm_scores: dict,
    cached_p_lm_scores: dict,
    processed_frames: int = 0,
):
    """Run a single decoding step; must be implemented by subclasses.

    Arguments
    ---------
    log_probs : torch.Tensor
        The CTC log-probabilities to consume in this step.
    beams : list
        The current beams.
    cached_lm_scores : dict
        Cache of language model scores.
    cached_p_lm_scores : dict
        Cache of prefix language model scores.
    processed_frames : int, default: 0
        The frame at which this decoding step starts.

    Raises
    ------
    NotImplementedError
        Always; this base class provides no decoding step.
    """
    raise NotImplementedError
def normalize_whitespace(self, text: str) -> str:
    """Collapse each whitespace run into one space and trim both ends.

    Arguments
    ---------
    text : str
        The text to clean up.

    Returns
    -------
    str
        The words of ``text`` joined by single spaces.
    """
    words = text.split()
    return " ".join(words)
def merge_tokens(self, token_1: str, token_2: str) -> str:
    """Join two tokens with a space, skipping whichever one is empty.

    Taken from: https://github.com/kensho-technologies/pyctcdecode

    Arguments
    ---------
    token_1 : str
        The left token.
    token_2 : str
        The right token.

    Returns
    -------
    str
        ``"token_1 token_2"`` when both are non-empty, otherwise the
        non-empty one (or an empty string when both are empty).
    """
    if token_1 and token_2:
        return token_1 + " " + token_2
    return token_1 or token_2
def merge_beams(self, beams: List[CTCBeam]) -> List[CTCBeam]:
    """Fold together beams that would produce the same text.

    Beams are keyed by (merged text, partial word, last token). When a key
    repeats, the latest beam replaces the stored one, and its ``score`` is
    set to the log-sum of both scores. The first-seen order is preserved.

    Taken from: https://github.com/kensho-technologies/pyctcdecode

    Arguments
    ---------
    beams : list
        The beams to merge.

    Returns
    -------
    list
        One CTCBeam per distinct key.
    """
    merged = {}
    for candidate in beams:
        key = (
            self.merge_tokens(candidate.text, candidate.next_word),
            candidate.partial_word,
            candidate.last_token,
        )
        previous = merged.get(key)
        if previous is None:
            merged[key] = candidate
            continue
        # Same text seen before: accumulate the probability mass.
        merged[key] = dataclasses.replace(
            candidate, score=np.logaddexp(previous.score, candidate.score),
        )
    return list(merged.values())
def sort_beams(self, beams: List[CTCBeam]) -> List[CTCBeam]:
    """Return the ``self.beam_size`` best beams, highest ``lm_score`` first.

    Arguments
    ---------
    beams : list
        The candidate beams; each must expose ``lm_score``.

    Returns
    -------
    list
        At most ``self.beam_size`` beams in decreasing ``lm_score`` order.
    """

    def lm_score_of(beam):
        return beam.lm_score

    return heapq.nlargest(self.beam_size, beams, key=lm_score_of)
def _prune_history(
    self, beams: List[CTCBeam], lm_order: int
) -> List[CTCBeam]:
    """Drop beams that coincide over the history an n-gram LM can still see.

    An n-gram LM only looks a bounded number of words back, so beams that
    differ only further in the past will be scored identically from now on.
    For each such group only the first beam (in input order) is kept. This
    speeds decoding up but reduces beam diversity; it should potentially be
    disabled when more than the top beam is used.

    Taken from: https://github.com/kensho-technologies/pyctcdecode

    Arguments
    ---------
    beams : list
        The beams to filter (they must expose the CTCBeam fields).
    lm_order : int
        The order of the language model.

    Returns
    -------
    list
        The surviving beams, converted with ``CTCBeam.from_lm_beam``.
    """
    # Always keep at least one word of history.
    history_len = max(1, lm_order - 1)
    kept_keys = set()
    pruned = []
    for candidate in beams:
        # Only the recent words, the partial word and the last token can
        # still influence future LM scores.
        key = (
            tuple(candidate.text.split()[-history_len:]),
            candidate.partial_word,
            candidate.last_token,
        )
        if key in kept_keys:
            continue
        kept_keys.add(key)
        pruned.append(CTCBeam.from_lm_beam(candidate))
    return pruned
def finalize_decoding(
    self,
    beams: List[CTCBeam],
    cached_lm_scores: dict,
    cached_p_lm_scores: dict,
    force_next_word=False,
    is_end=False,
) -> List[CTCBeam]:
    """Finalize the decoding process by adding and scoring the last partial word.

    When ``force_next_word`` or ``is_end`` is set, each beam's partial word
    is turned into its ``next_word`` (with its frames appended when it is
    non-empty), and beams with identical text are merged. The beams are
    then scored with ``self.get_lm_beams``. Beams whose ``lm_score`` is
    below ``max_score + self.beam_prune_logp`` are dropped, and the rest are
    ordered with ``self.sort_beams``.

    Arguments
    ---------
    beams : list
        The list of CTCBeam.
    cached_lm_scores : dict
        The cached language model scores.
    cached_p_lm_scores : dict
        The cached prefix language model scores.
    force_next_word : bool, default: False
        Whether to force the next word.
    is_end : bool, default: False
        Whether the end of the sequence has been reached.

    Returns
    -------
    list
        The list of the CTCBeam (empty if there is nothing to score).
    """
    if force_next_word or is_end:
        new_beams = []
        for beam in beams:
            new_token_times = (
                beam.text_frames
                if beam.partial_word == ""
                else beam.text_frames + [beam.partial_frames]
            )
            new_beams.append(
                CTCBeam(
                    text=beam.text,
                    full_text=beam.full_text,
                    next_word=beam.partial_word,
                    partial_word="",
                    last_token=None,
                    last_token_index=None,
                    text_frames=new_token_times,
                    partial_frames=(-1, -1),
                    score=beam.score,
                )
            )
        new_beams = self.merge_beams(new_beams)
    else:
        new_beams = list(beams)

    scored_beams = self.get_lm_beams(
        new_beams, cached_lm_scores, cached_p_lm_scores,
    )
    # max() raises on an empty sequence: nothing to prune or sort.
    if not scored_beams:
        return []

    # remove beam outliers
    max_score = max(b.lm_score for b in scored_beams)
    scored_beams = [
        b
        for b in scored_beams
        if b.lm_score >= max_score + self.beam_prune_logp
    ]
    return self.sort_beams(scored_beams)
def decode_beams(
    self,
    log_probs: torch.Tensor,
    wav_lens: Optional[torch.Tensor] = None,
    lm_start_state: Any = None,
) -> List[List[CTCHypothesis]]:
    """Decodes the log probabilities of the CTC output.

    It automatically converts the SpeechBrain's relative length of the wav input
    to the absolute length (rounded to the nearest frame).

    Each tensor is moved to CPU and converted to numpy, as it is faster and
    consumes less memory. Each sequence is then decoded with
    ``self.decode_log_probs``.

    Arguments
    ---------
    log_probs : torch.Tensor
        The log probabilities of the CTC output.
        The expected shape is [batch_size, seq_length, vocab_size].
    wav_lens : torch.Tensor, optional (default: None)
        The SpeechBrain's relative length of the wav input.
        If None, every sequence uses all ``seq_length`` frames.
    lm_start_state : Any, optional (default: None)
        The start state of the language model.

    Returns
    -------
    list of list
        The list of topk list of CTCHypothesis.
    """
    # check that the last dimension of log_probs is equal to the vocab size
    if log_probs.size(2) != len(self.vocab_list):
        logger.warning(
            f"Vocab size mismatch: log_probs vocab dim is {log_probs.size(2)} "
            f"while vocab_list is {len(self.vocab_list)}. "
            "During decoding, going to truncate the log_probs vocab dim to match vocab_list."
        )

    # compute wav_lens and cast to numpy as it is faster
    if wav_lens is not None:
        # Round instead of truncating: relative length * frames can land
        # just below an integer (float error), and astype(int) would then
        # drop the last valid frame.
        wav_lens = torch.round(log_probs.size(1) * wav_lens)
        wav_lens = wav_lens.cpu().numpy().astype(int)
    else:
        wav_lens = [log_probs.size(1)] * log_probs.size(0)

    # detach() so that tensors still tracking gradients can go to numpy.
    log_probs = log_probs.detach().cpu().numpy()

    return [
        self.decode_log_probs(log_prob, wav_len, lm_start_state)
        for log_prob, wav_len in zip(log_probs, wav_lens)
    ]
def __call__(
self,
log_probs: torch.Tensor,
wav_lens: Optional[torch.Tensor] = None,
lm_start_state: Any = None,
) -> List[List[CTCHypothesis]]:
"""Decodes the log probabilities of the CTC output.
It automatically converts the SpeechBrain's relative length of the wav input
to the absolute length.
Each tensors is converted to numpy and CPU as it is faster and consummes less memory.
Arguments
---------
log_probs : torch.Tensor
The log probabilities of the CTC output.
The expected shape is [batch_size, seq_length, vocab_size].
wav_lens : torch.Tensor, optional (default: None)
The SpeechBrain's relative length of the wav input.
lm_start_state : Any, optional (default: None)