Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Bring back EncoderASR
  • Loading branch information
TParcollet committed Sep 17, 2021
1 parent 2ec4839 commit 09326e5
Showing 1 changed file with 114 additions and 0 deletions.
114 changes: 114 additions & 0 deletions speechbrain/pretrained/interfaces.py
Expand Up @@ -475,6 +475,120 @@ def transcribe_batch(self, wavs, wav_lens):
for token_seq in predicted_tokens
]
return predicted_words, predicted_tokens

class EncoderASR(Pretrained):
"""A ready-to-use Encoder ASR model
The class can be used either to run only the encoder (encode()) to extract
features or to run the entire encoder + decoder function model
(transcribe()) to transcribe speech. The given YAML must contains the fields
specified in the *_NEEDED[] lists.
Example
-------
>>> from speechbrain.pretrained import EncoderASR
>>> tmpdir = getfixture("tmpdir")
>>> asr_model = EncoderASR.from_hparams(
... source="speechbrain/asr-wav2vec2-commonvoice-fr",
... savedir=tmpdir,
... ) # doctest: +SKIP
>>> asr_model.transcribe_file("samples/audio_samples/example_fr.wav") # doctest: +SKIP
"""

HPARAMS_NEEDED = ["tokenizer", "decoding_function"]
MODULES_NEEDED = ["encoder"]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tokenizer = self.hparams.tokenizer
self.decoding_function = self.hparams.decoding_function

def transcribe_file(self, path):
"""Transcribes the given audiofile into a sequence of words.
Arguments
---------
path : str
Path to audio file which to transcribe.
Returns
-------
str
The audiofile transcription produced by this ASR system.
"""
waveform = self.load_audio(path)
# Fake a batch:
batch = waveform.unsqueeze(0)
rel_length = torch.tensor([1.0])
predicted_words, predicted_tokens = self.transcribe_batch(
batch, rel_length
)
return str(predicted_words[0])

def encode_batch(self, wavs, wav_lens):
"""Encodes the input audio into a sequence of hidden states
The waveforms should already be in the model's desired format.
You can call:
``normalized = EncoderASR.normalizer(signal, sample_rate)``
to get a correctly converted signal in most cases.
Arguments
---------
wavs : torch.tensor
Batch of waveforms [batch, time, channels] or [batch, time]
depending on the model.
wav_lens : torch.tensor
Lengths of the waveforms relative to the longest one in the
batch, tensor of shape [batch]. The longest one should have
relative length 1.0 and others len(waveform) / max_length.
Used for ignoring padding.
Returns
-------
torch.tensor
The encoded batch
"""
wavs = wavs.float()
wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
encoder_out = self.modules.encoder(wavs, wav_lens)
return encoder_out

def transcribe_batch(self, wavs, wav_lens):
"""Transcribes the input audio into a sequence of words
The waveforms should already be in the model's desired format.
You can call:
``normalized = EncoderASR.normalizer(signal, sample_rate)``
to get a correctly converted signal in most cases.
Arguments
---------
wavs : torch.tensor
Batch of waveforms [batch, time, channels] or [batch, time]
depending on the model.
wav_lens : torch.tensor
Lengths of the waveforms relative to the longest one in the
batch, tensor of shape [batch]. The longest one should have
relative length 1.0 and others len(waveform) / max_length.
Used for ignoring padding.
Returns
-------
list
Each waveform in the batch transcribed.
tensor
Each predicted token id.
"""
with torch.no_grad():
wav_lens = wav_lens.to(self.device)
encoder_out = self.encode_batch(wavs, wav_lens)
predictions = self.decoding_function(encoder_out, wav_lens)
predicted_words = [
self.tokenizer.decode_ids(token_seq)
for token_seq in predictions
]
return predicted_words, predictions


class EncoderClassifier(Pretrained):
Expand Down

0 comments on commit 09326e5

Please sign in to comment.