From 09326e50b4c2ea74237068258431bff7bd7b0d1a Mon Sep 17 00:00:00 2001
From: Parcollet Titouan
Date: Fri, 17 Sep 2021 17:47:34 +0200
Subject: [PATCH] Bring back EncoderASR

---
 speechbrain/pretrained/interfaces.py | 114 +++++++++++++++++++++++++++
 1 file changed, 114 insertions(+)

diff --git a/speechbrain/pretrained/interfaces.py b/speechbrain/pretrained/interfaces.py
index e53df8c328..cf57f3b55a 100644
--- a/speechbrain/pretrained/interfaces.py
+++ b/speechbrain/pretrained/interfaces.py
@@ -475,6 +475,120 @@ def transcribe_batch(self, wavs, wav_lens):
                 for token_seq in predicted_tokens
             ]
         return predicted_words, predicted_tokens
+
+class EncoderASR(Pretrained):
+    """A ready-to-use Encoder ASR model
+
+    The class can be used either to run only the encoder (encode()) to extract
+    features, or to run the full encoder + decoding function (transcribe())
+    to transcribe speech. The given YAML must contain the fields specified
+    in the *_NEEDED[] lists.
+
+    Example
+    -------
+    >>> from speechbrain.pretrained import EncoderASR
+    >>> tmpdir = getfixture("tmpdir")
+    >>> asr_model = EncoderASR.from_hparams(
+    ...     source="speechbrain/asr-wav2vec2-commonvoice-fr",
+    ...     savedir=tmpdir,
+    ... ) # doctest: +SKIP
+    >>> asr_model.transcribe_file("samples/audio_samples/example_fr.wav") # doctest: +SKIP
+    """
+
+    HPARAMS_NEEDED = ["tokenizer", "decoding_function"]
+    MODULES_NEEDED = ["encoder"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.tokenizer = self.hparams.tokenizer
+        self.decoding_function = self.hparams.decoding_function
+
+    def transcribe_file(self, path):
+        """Transcribes the given audio file into a sequence of words.
+
+        Arguments
+        ---------
+        path : str
+            Path to the audio file to transcribe.
+
+        Returns
+        -------
+        str
+            The audio file transcription produced by this ASR system.
+        """
+        waveform = self.load_audio(path)
+        # Fake a batch:
+        batch = waveform.unsqueeze(0)
+        rel_length = torch.tensor([1.0])
+        predicted_words, predicted_tokens = self.transcribe_batch(
+            batch, rel_length
+        )
+        return str(predicted_words[0])
+
+    def encode_batch(self, wavs, wav_lens):
+        """Encodes the input audio into a sequence of hidden states
+
+        The waveforms should already be in the model's desired format.
+        You can call:
+        ``normalized = EncoderASR.normalizer(signal, sample_rate)``
+        to get a correctly converted signal in most cases.
+
+        Arguments
+        ---------
+        wavs : torch.tensor
+            Batch of waveforms [batch, time, channels] or [batch, time]
+            depending on the model.
+        wav_lens : torch.tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and the others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        torch.tensor
+            The encoded batch
+        """
+        wavs = wavs.float()
+        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
+        encoder_out = self.modules.encoder(wavs, wav_lens)
+        return encoder_out
+
+    def transcribe_batch(self, wavs, wav_lens):
+        """Transcribes the input audio into a sequence of words
+
+        The waveforms should already be in the model's desired format.
+        You can call:
+        ``normalized = EncoderASR.normalizer(signal, sample_rate)``
+        to get a correctly converted signal in most cases.
+
+        Arguments
+        ---------
+        wavs : torch.tensor
+            Batch of waveforms [batch, time, channels] or [batch, time]
+            depending on the model.
+        wav_lens : torch.tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and the others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        list
+            The transcription of each waveform in the batch.
+        tensor
+            The predicted token ids for each waveform.
+        """
+        with torch.no_grad():
+            wav_lens = wav_lens.to(self.device)
+            encoder_out = self.encode_batch(wavs, wav_lens)
+            predictions = self.decoding_function(encoder_out, wav_lens)
+            predicted_words = [
+                self.tokenizer.decode_ids(token_seq)
+                for token_seq in predictions
+            ]
+        return predicted_words, predictions
 
 
 class EncoderClassifier(Pretrained):
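
A minimal usage sketch of the restored interface, assuming the pretrained speechbrain/asr-wav2vec2-commonvoice-fr model referenced in the docstring example; the savedir path and the example_fr.wav file are placeholders. from_hparams comes from the Pretrained base class, while transcribe_file and transcribe_batch are the methods added by this patch.

    import torch
    from speechbrain.pretrained import EncoderASR

    # Fetch the hyperparameters, tokenizer, and encoder weights
    # (savedir is a placeholder local cache directory).
    asr_model = EncoderASR.from_hparams(
        source="speechbrain/asr-wav2vec2-commonvoice-fr",
        savedir="pretrained_models/asr-wav2vec2-commonvoice-fr",
    )

    # Single file: loading, batching, and decoding are handled internally.
    print(asr_model.transcribe_file("example_fr.wav"))

    # Batched input: pad the waveforms yourself and pass relative lengths.
    # Random tensors stand in for two 16 kHz signals here; real audio is
    # needed to obtain a meaningful transcription.
    wavs = torch.randn(2, 32000)
    wav_lens = torch.tensor([1.0, 0.75])  # second signal is 75% of the longest
    words, tokens = asr_model.transcribe_batch(wavs, wav_lens)

The decoding step itself is whatever callable the model's YAML exposes as decoding_function (for CTC models this is typically a greedy CTC decoder), so swapping decoders only requires editing the hyperparameters file, not this class.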