Skip to content

Commit

Permalink
CTCPrefixBeamSearcher timestamps
Browse files Browse the repository at this point in the history
  • Loading branch information
Adel-Moumen committed Sep 25, 2023
1 parent e67c761 commit c3d3cf7
Showing 1 changed file with 41 additions and 1 deletion.
42 changes: 41 additions & 1 deletion speechbrain/decoders/ctc.py
Expand Up @@ -1568,6 +1568,8 @@ def get_lm_beams(
partial_word=beam.partial_word,
last_token=beam.last_token,
last_token_index=beam.last_token_index,
text_frames=beam.text_frames,
partial_frames=beam.partial_frames,
p=beam.p,
p_b=beam.p_b,
p_nb=beam.p_nb,
Expand Down Expand Up @@ -1618,6 +1620,8 @@ def get_lm_beams(
partial_word=beam.partial_word,
last_token=beam.last_token,
last_token_index=beam.last_token_index,
text_frames=beam.text_frames,
partial_frames=beam.partial_frames,
p=beam.p,
p_b=beam.p_b,
p_nb=beam.p_nb,
Expand All @@ -1632,6 +1636,7 @@ def get_lm_beams(

def _get_new_beam(
self,
frame_index: int,
new_prefix: str,
new_token: str,
new_token_index: int,
Expand All @@ -1643,6 +1648,8 @@ def _get_new_beam(
Arguments
---------
frame_index : int
The index of the current frame.
new_prefix : str
The new prefix.
new_token : str
Expand All @@ -1668,6 +1675,12 @@ def _get_new_beam(
return beam

if not self.is_spm and new_token_index == self.space_index:
new_frame_list = (
beam.text_frames
if beam.partial_word == ""
else beam.text_frames + [beam.partial_frames]
)

# if we extend the beam with a space, we need to reset the partial word
# and move it to the next word
new_beam = CTCBeam(
Expand All @@ -1677,6 +1690,8 @@ def _get_new_beam(
partial_word="",
last_token=new_token,
last_token_index=new_token_index,
text_frames=new_frame_list,
partial_frames=(-1, -1),
score=-math.inf,
score_ctc=-math.inf,
p_b=-math.inf,
Expand All @@ -1685,6 +1700,12 @@ def _get_new_beam(
# remove the spm token at the beginning of the token
clean_token = new_token[1:]

new_frame_list = (
beam.text_frames
if beam.partial_word == ""
else beam.text_frames + [beam.partial_frames]
)

# If the beginning of the token is the spm_token
# then it means that we are extending the beam with a new word.
# We need to change the new_word with the partial_word
Expand All @@ -1697,11 +1718,19 @@ def _get_new_beam(
partial_word=clean_token,
last_token=new_token,
last_token_index=new_token_index,
text_frames=new_frame_list,
partial_frames=(frame_index, frame_index + 1),
score=-math.inf,
score_ctc=-math.inf,
p_b=-math.inf,
)
elif new_token_index == previous_beam.last_token_index:
new_end_frame = frame_index + 1

new_part_frames = (
beam.partial_frames if new_token_index == self.blank_index else (beam.partial_frames[0], new_end_frame)
)

# if repeated token, we only change the score
new_beam = CTCBeam(
text=new_prefix,
Expand All @@ -1710,11 +1739,19 @@ def _get_new_beam(
partial_word=previous_beam.partial_word,
last_token=new_token,
last_token_index=new_token_index,
text_frames=beam.text_frames,
partial_frames=new_part_frames,
score=-math.inf,
score_ctc=-math.inf,
p_b=-math.inf,
)
else:
new_part_frames = (
(frame_index, frame_index + 1)
if beam.partial_frames[0] < 0
else (beam.partial_frames[0], frame_index + 1)
)

# last case, we are extending the partial_word with a new token
new_beam = CTCBeam(
text=new_prefix,
Expand All @@ -1723,6 +1760,8 @@ def _get_new_beam(
partial_word=previous_beam.partial_word + new_token,
last_token=new_token,
last_token_index=new_token_index,
text_frames=beam.text_frames,
partial_frames=new_part_frames,
score=-math.inf,
score_ctc=-math.inf,
p_b=-math.inf,
Expand Down Expand Up @@ -1769,7 +1808,7 @@ def partial_decoding(
# select only the valid frames, i.e., the frames that are not padded
log_probs = log_probs[:wav_len]

for _, logit_col in enumerate(log_probs, start=processed_frames):
for frame_index, logit_col in enumerate(log_probs, start=processed_frames):
# skip the frame if the blank probability is higher than the threshold
if (
self.blank_skip_threshold is not None
Expand Down Expand Up @@ -1811,6 +1850,7 @@ def partial_decoding(
new_text = beam.text + token

new_beam = self._get_new_beam(
frame_index,
new_text,
token,
token_index,
Expand Down

0 comments on commit c3d3cf7

Please sign in to comment.