From c3d3cf756b1c099b1919f3f28d42bab819ebfe4b Mon Sep 17 00:00:00 2001 From: Adel Moumen Date: Mon, 25 Sep 2023 11:22:56 +0200 Subject: [PATCH] CTCPrefixBeamSearcher timestamps --- speechbrain/decoders/ctc.py | 42 ++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/speechbrain/decoders/ctc.py b/speechbrain/decoders/ctc.py index 785c092459..aba2aaec7e 100644 --- a/speechbrain/decoders/ctc.py +++ b/speechbrain/decoders/ctc.py @@ -1568,6 +1568,8 @@ def get_lm_beams( partial_word=beam.partial_word, last_token=beam.last_token, last_token_index=beam.last_token_index, + text_frames=beam.text_frames, + partial_frames=beam.partial_frames, p=beam.p, p_b=beam.p_b, p_nb=beam.p_nb, @@ -1618,6 +1620,8 @@ def get_lm_beams( partial_word=beam.partial_word, last_token=beam.last_token, last_token_index=beam.last_token_index, + text_frames=beam.text_frames, + partial_frames=beam.partial_frames, p=beam.p, p_b=beam.p_b, p_nb=beam.p_nb, @@ -1632,6 +1636,7 @@ def get_lm_beams( def _get_new_beam( self, + frame_index: int, new_prefix: str, new_token: str, new_token_index: int, @@ -1643,6 +1648,8 @@ def _get_new_beam( Arguments --------- + frame_index : int + The index of the current frame. new_prefix : str The new prefix. new_token : str @@ -1668,6 +1675,12 @@ def _get_new_beam( return beam if not self.is_spm and new_token_index == self.space_index: + new_frame_list = ( + beam.text_frames + if beam.partial_word == "" + else beam.text_frames + [beam.partial_frames] + ) + # if we extend the beam with a space, we need to reset the partial word # and move it to the next word new_beam = CTCBeam( @@ -1677,6 +1690,8 @@ def _get_new_beam( partial_word="", last_token=new_token, last_token_index=new_token_index, + text_frames=new_frame_list, + partial_frames=(-1, -1), score=-math.inf, score_ctc=-math.inf, p_b=-math.inf, @@ -1685,6 +1700,12 @@ def _get_new_beam( # remove the spm token at the beginning of the token clean_token = new_token[1:] + new_frame_list = ( + beam.text_frames + if beam.partial_word == "" + else beam.text_frames + [beam.partial_frames] + ) + # If the beginning of the token is the spm_token # then it means that we are extending the beam with a new word. # We need to change the new_word with the partial_word @@ -1697,11 +1718,19 @@ def _get_new_beam( partial_word=clean_token, last_token=new_token, last_token_index=new_token_index, + text_frames=new_frame_list, + partial_frames=(frame_index, frame_index + 1), score=-math.inf, score_ctc=-math.inf, p_b=-math.inf, ) elif new_token_index == previous_beam.last_token_index: + new_end_frame = frame_index + 1 + + new_part_frames = ( + beam.partial_frames if new_token_index == self.blank_index else (beam.partial_frames[0], new_end_frame) + ) + # if repeated token, we only change the score new_beam = CTCBeam( text=new_prefix, @@ -1710,11 +1739,19 @@ def _get_new_beam( partial_word=previous_beam.partial_word, last_token=new_token, last_token_index=new_token_index, + text_frames=beam.text_frames, + partial_frames=new_part_frames, score=-math.inf, score_ctc=-math.inf, p_b=-math.inf, ) else: + new_part_frames = ( + (frame_index, frame_index + 1) + if beam.partial_frames[0] < 0 + else (beam.partial_frames[0], frame_index + 1) + ) + # last case, we are extending the partial_word with a new token new_beam = CTCBeam( text=new_prefix, @@ -1723,6 +1760,8 @@ def _get_new_beam( partial_word=previous_beam.partial_word + new_token, last_token=new_token, last_token_index=new_token_index, + text_frames=beam.text_frames, + partial_frames=new_part_frames, score=-math.inf, score_ctc=-math.inf, p_b=-math.inf, @@ -1769,7 +1808,7 @@ def partial_decoding( # select only the valid frames, i.e., the frames that are not padded log_probs = log_probs[:wav_len] - for _, logit_col in enumerate(log_probs, start=processed_frames): + for frame_index, logit_col in enumerate(log_probs, start=processed_frames): # skip the frame if the blank probability is higher than the threshold if ( self.blank_skip_threshold is not None @@ -1811,6 +1850,7 @@ def partial_decoding( new_text = beam.text + token new_beam = self._get_new_beam( + frame_index, new_text, token, token_index,