# WAV2VEC2 Inference Notebook 2

Based on https://pytorch.org/tutorials/intermediate/speech_recognition_pipeline_tutorial.html

and
https://pytorch.org/audio/main/tutorials/asr_inference_with_ctc_decoder_tutorial.html

Recorded audio of *Mary Had a Little Lamb* and tested greedy and beam-search decoding on it.

## PyTorch Inference

In [1]:
import os
import time
import torch
import torchaudio
from torchaudio.models.decoder import ctc_decoder
from torchaudio.models.decoder import download_pretrained_files

import numpy as np
import IPython

# Everything runs on the CPU; seed torch's global RNG with 0 for repeatability.
device = torch.device("cpu")
torch.random.manual_seed(0)

print(
    f"PyTorch Version: {torch.__version__}, Pytorchaudio Version: {torchaudio.__version__}, "
    f"Targeted Device: {device}"
)

PyTorch Version: 1.12.0, Pytorchaudio Version: 0.12.0, Targeted Device: cpu


In [2]:
# Load the pretrained wav2vec 2.0 ASR pipeline (the "960H" suffix presumably refers to
# 960 hours of fine-tuning data) and move its acoustic model to the chosen device.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model()
model = model.to(device)

## Part 1: Mary Had a Little Lamb Spoken

In [8]:
# Part 1 input: the spoken (not sung) rhyme, plus its reference words for computing WER.
SPEECH_FILE = "data/mary_had_a_little_lamb_spoken.wav"
actual_transcript = "mary had a little lamb little lamb little lamb mary had a little lamb her fleece was white as snow everywhere that mary went mary went mary went everywhere that mary went her lamb was sure to go".split()

if not os.path.exists(SPEECH_FILE):
    # Fail loudly: continuing would hit a NameError, or silently reuse a `waveform`
    # left over from an earlier cell and score the wrong audio.
    raise FileNotFoundError(f"NO FILE HERE! Missing audio file: {SPEECH_FILE}")

waveform, sample_rate = torchaudio.load(SPEECH_FILE)
if waveform.size(0) > 1:
    # Downmix to mono: extra channels would be treated as extra batch items, and the
    # decoding cells only look at the first one (emission[0] / result[0][0]).
    waveform = waveform.mean(dim=0, keepdim=True)
waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
sample_rate = bundle.sample_rate
waveform = waveform.to(device)

dim, num_samples = waveform.size()
print(f"Song is {num_samples / sample_rate:.2f} seconds long.  {num_samples} samples long.")
# `rate` only matters for raw sample arrays; for a file IPython embeds the file's own bytes.
IPython.display.Audio(filename=SPEECH_FILE)

Song is 22.24 seconds long.  355872 samples long.


## Greedy Decoding Transcription
**Objective**: Exercise the model with all of the audio data for *Mary Had a Little Lamb*. <br>

In [4]:
class GreedyCTCDecoder(torch.nn.Module):
    """Greedy (best-path) CTC decoder.

    Picks the highest-scoring label at every frame with ``argmax``, collapses
    consecutive repeats, drops the blank label, and splits the result into
    words on the ``"|"`` word-boundary label.

    Note: only the single best label per frame is considered. Decoders that
    score whole label sequences (e.g. a beam search with a language model)
    can recover from per-frame mistakes that this one cannot.
    """

    def __init__(self, labels, blank=0):
        """
        Args:
          labels (Sequence[str]): Label string for each emission column.
          blank (int): Index of the CTC blank label in ``labels``.
        """
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor):
        """Given a sequence emission over labels, get the best path.

        Args:
          emission (Tensor): Logit tensor of shape `[num_seq, num_label]`.

        Returns:
          List[str]: The transcript as a list of words (empty if every frame is blank).
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        # Convert to Python ints once, rather than comparing and indexing with a 0-d tensor per frame.
        joined = "".join(self.labels[i] for i in indices.tolist() if i != self.blank)
        # str.split() with no argument already discards leading/trailing whitespace.
        return joined.replace("|", " ").split()

# Lower-case the model's labels so decoded words line up with the lower-case reference transcript.
tokens = list(map(str.lower, bundle.get_labels()))
greedy_decoder = GreedyCTCDecoder(tokens)

In [9]:
# Time the model forward pass plus greedy decoding on the clip loaded above.
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
with torch.inference_mode():
    emission, _ = model(waveform)
    # Decode the first batch item (a mono waveform yields a batch of one).
    greedy_result = greedy_decoder(emission[0])
finish = time.perf_counter()

# Build the transcript string and the WER (word-level edit distance / number of reference words).
greedy_transcript = " ".join(greedy_result)
greedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_result) / len(actual_transcript)

print(f"Time to perform inference (with decoding): {finish-start:.1f} seconds.")
print(f"Song is {num_samples / sample_rate:.2f} seconds long.  {num_samples} samples long.")
print(f"Transcript: {greedy_transcript}\n WER: {greedy_wer:.3f}")

Time to perform inference (with decoding): 9.9 seconds.
Song is 22.24 seconds long.  355872 samples long.
Transcript: mary had a little lamb little lamb little lamb mary had a little lamb her fleece was white as snow everywhere that mary went mary went mary went everywhere that mary want her lamb was sure to go
 WER: 0.026


## Beam Search Decoding Transcription
**Objective**: Exercise the model with all of the audio data for *Mary Had a Little Lamb*. <br>

In [11]:
# Pretrained LibriSpeech 4-gram language model, plus its lexicon and token list.
files = download_pretrained_files("librispeech-4-gram")

# Decoder weights; presumably taken from the torchaudio CTC-decoder tutorial linked at the top.
LM_WEIGHT = 3.23  # weight applied to the language-model score
WORD_SCORE = -0.26  # word insertion score

beam_search_decoder = ctc_decoder(
    tokens=files.tokens,
    lexicon=files.lexicon,
    lm=files.lm,
    lm_weight=LM_WEIGHT,
    word_score=WORD_SCORE,
    beam_size=1500,  # hypotheses kept after each decoding step
    nbest=3,  # hypotheses returned per utterance
)

In [12]:
# Time the model forward pass plus beam-search decoding (perf_counter: monotonic clock).
start = time.perf_counter()
with torch.inference_mode():
    emission, _ = model(waveform)
    beam_search_result = beam_search_decoder(emission)
finish = time.perf_counter()

# Results are indexed [batch item][hypothesis rank]; keep the top hypothesis for the first item.
best_hypothesis = beam_search_result[0][0]
beam_search_transcript = " ".join(best_hypothesis.words).strip()
beam_search_wer = torchaudio.functional.edit_distance(actual_transcript, best_hypothesis.words) / len(
    actual_transcript
)

print(f"Transcript: {beam_search_transcript}")
print(f"Time to perform inference (with decoding): {finish-start:.1f} seconds.")
print(f"Song is {num_samples / sample_rate:.2f} seconds long.  {num_samples} samples long.")
# `.3f` (no stray sign-space flag) so the WER line matches the greedy cell's formatting.
print(f"WER: {beam_search_wer:.3f}")

Transcript: mary had a little lamb little lamb little lamb mary had a little lamb her fleece was white as snow everywhere that mary went mary went mary went everywhere that mary went the lamb was sure to go
Time to perform inference (with decoding): 11.8 seconds.
Song is 22.24 seconds long.  355872 samples long.
WER: 0.02631578947368421


## Part 2: Mary Had a Little Lamb Song

In [14]:
# Part 2 input: the sung rhyme, plus its reference words for computing WER.
SPEECH_FILE = "data/mary_had_a_little_lamb_song.wav"
actual_transcript = "mary had a little lamb little lamb little lamb mary had a little lamb her fleece was white as snow everywhere that mary went mary went mary went everywhere that mary went her lamb was sure to go".split()

if not os.path.exists(SPEECH_FILE):
    # Fail loudly: continuing would silently reuse the previous part's `waveform`
    # and score the wrong audio.
    raise FileNotFoundError(f"NO FILE HERE! Missing audio file: {SPEECH_FILE}")

waveform, sample_rate = torchaudio.load(SPEECH_FILE)
if waveform.size(0) > 1:
    # Downmix to mono: extra channels would be treated as extra batch items, and the
    # decoding cells only look at the first one (emission[0] / result[0][0]).
    waveform = waveform.mean(dim=0, keepdim=True)
waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
sample_rate = bundle.sample_rate
waveform = waveform.to(device)

dim, num_samples = waveform.size()
print(f"Song is {num_samples / sample_rate:.2f} seconds long.  {num_samples} samples long.")
# `rate` only matters for raw sample arrays; for a file IPython embeds the file's own bytes.
IPython.display.Audio(filename=SPEECH_FILE)

Song is 22.95 seconds long.  367272 samples long.


## Greedy Decoding Transcription
**Objective**: Exercise the model with all of the audio data for *Mary Had a Little Lamb*. <br>

In [15]:
class GreedyCTCDecoder(torch.nn.Module):
    """Greedy (best-path) CTC decoder.

    Picks the highest-scoring label at every frame with ``argmax``, collapses
    consecutive repeats, drops the blank label, and splits the result into
    words on the ``"|"`` word-boundary label.

    Note: only the single best label per frame is considered. Decoders that
    score whole label sequences (e.g. a beam search with a language model)
    can recover from per-frame mistakes that this one cannot.
    """

    def __init__(self, labels, blank=0):
        """
        Args:
          labels (Sequence[str]): Label string for each emission column.
          blank (int): Index of the CTC blank label in ``labels``.
        """
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor):
        """Given a sequence emission over labels, get the best path.

        Args:
          emission (Tensor): Logit tensor of shape `[num_seq, num_label]`.

        Returns:
          List[str]: The transcript as a list of words (empty if every frame is blank).
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        # Convert to Python ints once, rather than comparing and indexing with a 0-d tensor per frame.
        joined = "".join(self.labels[i] for i in indices.tolist() if i != self.blank)
        # str.split() with no argument already discards leading/trailing whitespace.
        return joined.replace("|", " ").split()

# Lower-case the model's labels so decoded words line up with the lower-case reference transcript.
tokens = list(map(str.lower, bundle.get_labels()))
greedy_decoder = GreedyCTCDecoder(tokens)

In [16]:
# Time the model forward pass plus greedy decoding on the clip loaded above.
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
with torch.inference_mode():
    emission, _ = model(waveform)
    # Decode the first batch item (a mono waveform yields a batch of one).
    greedy_result = greedy_decoder(emission[0])
finish = time.perf_counter()

# Build the transcript string and the WER (word-level edit distance / number of reference words).
greedy_transcript = " ".join(greedy_result)
greedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_result) / len(actual_transcript)

print(f"Time to perform inference (with decoding): {finish-start:.1f} seconds.")
print(f"Song is {num_samples / sample_rate:.2f} seconds long.  {num_samples} samples long.")
print(f"Transcript: {greedy_transcript}\n WER: {greedy_wer:.3f}")

Time to perform inference (with decoding): 10.4 seconds.
Song is 22.95 seconds long.  367272 samples long.
Transcript: mery had a litle lamp litle lamp litle lamp mery had a litle lamb her fleece was white as sno everywhere that mery went mery went mery went everywhere that mery went her lamb was sure to go
 WER: 0.368


## Beam Search Decoding Transcription
**Objective**: Exercise the model with all of the audio data for *Mary Had a Little Lamb*. <br>

In [17]:
# Pretrained LibriSpeech 4-gram language model, plus its lexicon and token list.
files = download_pretrained_files("librispeech-4-gram")

# Decoder weights; presumably taken from the torchaudio CTC-decoder tutorial linked at the top.
LM_WEIGHT = 3.23  # weight applied to the language-model score
WORD_SCORE = -0.26  # word insertion score

beam_search_decoder = ctc_decoder(
    tokens=files.tokens,
    lexicon=files.lexicon,
    lm=files.lm,
    lm_weight=LM_WEIGHT,
    word_score=WORD_SCORE,
    beam_size=1500,  # hypotheses kept after each decoding step
    nbest=3,  # hypotheses returned per utterance
)

In [18]:
# Time the model forward pass plus beam-search decoding (perf_counter: monotonic clock).
start = time.perf_counter()
with torch.inference_mode():
    emission, _ = model(waveform)
    beam_search_result = beam_search_decoder(emission)
finish = time.perf_counter()

# Results are indexed [batch item][hypothesis rank]; keep the top hypothesis for the first item.
best_hypothesis = beam_search_result[0][0]
beam_search_transcript = " ".join(best_hypothesis.words).strip()
beam_search_wer = torchaudio.functional.edit_distance(actual_transcript, best_hypothesis.words) / len(
    actual_transcript
)

print(f"Transcript: {beam_search_transcript}")
print(f"Time to perform inference (with decoding): {finish-start:.1f} seconds.")
print(f"Song is {num_samples / sample_rate:.2f} seconds long.  {num_samples} samples long.")
# `.3f` (no stray sign-space flag) so the WER line matches the greedy cell's formatting.
print(f"WER: {beam_search_wer:.3f}")

Transcript: mary had a little lamb little lamb little lamb mary had a little lamb her fleece was white as snow everywhere that mary went mary went mary went everywhere that mary went the lamb was sure to go
Time to perform inference (with decoding): 12.9 seconds.
Song is 22.95 seconds long.  367272 samples long.
WER:  0.026


## Part 3: Mary Had a Little Lamb Metal Song!
Rock A-Bye Melodies: https://www.youtube.com/watch?v=ZYv8Ro3crP8

In [19]:
# Part 3 input: the metal-song version, whose reference covers a single verse of the rhyme.
SPEECH_FILE = "data/mary_had_a_little_lamb_metal.wav"
actual_transcript = "mary had a little lamb her fleece was white as snow everywhere that mary went her lamb was sure to go".split()

if not os.path.exists(SPEECH_FILE):
    # Fail loudly: continuing would silently reuse the previous part's `waveform`
    # and score the wrong audio.
    raise FileNotFoundError(f"NO FILE HERE! Missing audio file: {SPEECH_FILE}")

waveform, sample_rate = torchaudio.load(SPEECH_FILE)
if waveform.size(0) > 1:
    # Downmix to mono: extra channels would be treated as extra batch items, and the
    # decoding cells only look at the first one (emission[0] / result[0][0]).
    waveform = waveform.mean(dim=0, keepdim=True)
waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
sample_rate = bundle.sample_rate
waveform = waveform.to(device)

dim, num_samples = waveform.size()
print(f"Song is {num_samples / sample_rate:.2f} seconds long.  {num_samples} samples long.")
# `rate` only matters for raw sample arrays; for a file IPython embeds the file's own bytes.
IPython.display.Audio(filename=SPEECH_FILE)

Song is 12.37 seconds long.  197962 samples long.


## Greedy Decoding Transcription
**Objective**: Exercise the model with all of the audio data for *Mary Had a Little Lamb*. <br>

In [20]:
class GreedyCTCDecoder(torch.nn.Module):
    """Greedy (best-path) CTC decoder.

    Picks the highest-scoring label at every frame with ``argmax``, collapses
    consecutive repeats, drops the blank label, and splits the result into
    words on the ``"|"`` word-boundary label.

    Note: only the single best label per frame is considered. Decoders that
    score whole label sequences (e.g. a beam search with a language model)
    can recover from per-frame mistakes that this one cannot.
    """

    def __init__(self, labels, blank=0):
        """
        Args:
          labels (Sequence[str]): Label string for each emission column.
          blank (int): Index of the CTC blank label in ``labels``.
        """
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor):
        """Given a sequence emission over labels, get the best path.

        Args:
          emission (Tensor): Logit tensor of shape `[num_seq, num_label]`.

        Returns:
          List[str]: The transcript as a list of words (empty if every frame is blank).
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        # Convert to Python ints once, rather than comparing and indexing with a 0-d tensor per frame.
        joined = "".join(self.labels[i] for i in indices.tolist() if i != self.blank)
        # str.split() with no argument already discards leading/trailing whitespace.
        return joined.replace("|", " ").split()

# Lower-case the model's labels so decoded words line up with the lower-case reference transcript.
tokens = list(map(str.lower, bundle.get_labels()))
greedy_decoder = GreedyCTCDecoder(tokens)

In [21]:
# Time the model forward pass plus greedy decoding on the clip loaded above.
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
with torch.inference_mode():
    emission, _ = model(waveform)
    # Decode the first batch item (a mono waveform yields a batch of one).
    greedy_result = greedy_decoder(emission[0])
finish = time.perf_counter()

# Build the transcript string and the WER (word-level edit distance / number of reference words).
greedy_transcript = " ".join(greedy_result)
greedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_result) / len(actual_transcript)

print(f"Time to perform inference (with decoding): {finish-start:.1f} seconds.")
print(f"Song is {num_samples / sample_rate:.2f} seconds long.  {num_samples} samples long.")
print(f"Transcript: {greedy_transcript}\n WER: {greedy_wer:.3f}")

Time to perform inference (with decoding): 10.1 seconds.
Song is 12.37 seconds long.  197962 samples long.
Transcript: 
 WER: 1.000


## Beam Search Decoding Transcription
**Objective**: Exercise the model with all of the audio data for *Mary Had a Little Lamb*. <br>

In [22]:
# Pretrained LibriSpeech 4-gram language model, plus its lexicon and token list.
files = download_pretrained_files("librispeech-4-gram")

# Decoder weights; presumably taken from the torchaudio CTC-decoder tutorial linked at the top.
LM_WEIGHT = 3.23  # weight applied to the language-model score
WORD_SCORE = -0.26  # word insertion score

beam_search_decoder = ctc_decoder(
    tokens=files.tokens,
    lexicon=files.lexicon,
    lm=files.lm,
    lm_weight=LM_WEIGHT,
    word_score=WORD_SCORE,
    beam_size=1500,  # hypotheses kept after each decoding step
    nbest=3,  # hypotheses returned per utterance
)

In [23]:
# Time the model forward pass plus beam-search decoding (perf_counter: monotonic clock).
start = time.perf_counter()
with torch.inference_mode():
    emission, _ = model(waveform)
    beam_search_result = beam_search_decoder(emission)
finish = time.perf_counter()

# Results are indexed [batch item][hypothesis rank]; keep the top hypothesis for the first item.
best_hypothesis = beam_search_result[0][0]
beam_search_transcript = " ".join(best_hypothesis.words).strip()
beam_search_wer = torchaudio.functional.edit_distance(actual_transcript, best_hypothesis.words) / len(
    actual_transcript
)

print(f"Transcript: {beam_search_transcript}")
print(f"Time to perform inference (with decoding): {finish-start:.1f} seconds.")
print(f"Song is {num_samples / sample_rate:.2f} seconds long.  {num_samples} samples long.")
# `.3f` (no stray sign-space flag) so the WER line matches the greedy cell's formatting.
print(f"WER: {beam_search_wer:.3f}")

Transcript: 
Time to perform inference (with decoding): 15.1 seconds.
Song is 12.37 seconds long.  197962 samples long.
WER:  1.000
