# WAV2VEC2 Inference Notebook 2

Based on https://pytorch.org/tutorials/intermediate/speech_recognition_pipeline_tutorial.html

And 
https://pytorch.org/audio/main/tutorials/asr_inference_with_ctc_decoder_tutorial.html

Created audio for Mary Had a Little Lamb, and Tested Greedy and Beamsearch Decoding.  

## Pytorch Inference

In [1]:
import os
import time
import copy
import torch
import torchaudio
from torchaudio.models.decoder import ctc_decoder
from torchaudio.models.decoder import download_pretrained_files

import numpy as np
import IPython
from IPython.display import clear_output

torch.random.manual_seed(0)
device = torch.device('cpu')

print(f"PyTorch Version: {torch.__version__}, Pytorchaudio Version: {torchaudio.__version__}, Targeted Device: {device}")

PyTorch Version: 2.2.1, Pytorchaudio Version: 2.2.1, Targeted Device: cpu


In [2]:
# Loading ASR Model
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)

## Part 1: Mary Had a Little Lamb Spoken

In [3]:
SPEECH_FILE = "data/mary_had_a_little_lamb_spoken.wav"
actual_transcript = "mary had a little lamb little lamb little lamb mary had a little lamb her fleece was white as snow everywhere that mary went mary went mary went everywhere that mary went her lamb was sure to go"
actual_transcript = actual_transcript.split()
if os.path.exists(SPEECH_FILE):
    waveform, sample_rate = torchaudio.load(SPEECH_FILE)
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
    waveform = waveform.to(device)
else:
    print('NO FILE HERE!')
dim, num_samples = waveform.size()
print(f"Song is {num_samples / bundle.sample_rate:.2f} seconds long.  {num_samples} samples long.")
IPython.display.Audio(SPEECH_FILE,rate=bundle.sample_rate)

Song is 22.24 seconds long.  355872 samples long.


## Beam Search Decoding Transcription
**Objective**: Execise the model with all audio data for Mary Had a L. <br>

In [4]:
files = download_pretrained_files("librispeech-4-gram")
LM_WEIGHT = 3.23
WORD_SCORE = -0.26

beam_search_decoder = ctc_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,
    nbest=3,
    beam_size=1500,
    lm_weight=LM_WEIGHT,
    word_score=WORD_SCORE,
)

In [5]:
start = time.time()
with torch.inference_mode():
    emission, _ = model(waveform)
    beam_search_result = beam_search_decoder(emission)

finish = time.time()
chan, samples = waveform.size()
print(f"Total Time: {samples / bundle.sample_rate:.1f}")
print(f"Time to perform inference: {finish-start:.1f}")

beam_search_transcript = " ".join(beam_search_result[0][0].words).strip()
beam_search_wer = torchaudio.functional.edit_distance(actual_transcript, beam_search_result[0][0].words) / len(
    actual_transcript
)

print(f"Transcript: {beam_search_transcript}")
print(f"WER: {beam_search_wer:.3f}")

Total Time: 22.2
Time to perform inference: 2.4
Transcript: mary had a little lamb little lamb little lamb mary had a little lamb her fleece was white as snow everywhere that mary went mary went mary went everywhere that mary went the lamb was sure to go
WER: 0.026


# Stepping through File

Per the paper: 400 input samples (25ms), with a stride of 20ms (320 samples) therefore 253 input tensors (5.076s @ 16k = 81216 samples and divide this by 320).  <br>

**Objective**: Take the 22 second audio, and step through it and transcribe 1s of data by sliding along every 20ms.

In [6]:
# Creating indices for iteracting with a file 
# (mimic-ing the ability to pull things from an sound card buffer).

# Number of seconds for the QUEUE
SECONDS = 5
# start_index
si = 0 
# ending index
# samples in 1 second (want 1s worth of data)
ei = max_len = int(bundle.sample_rate) * SECONDS

# WAV2VEC is trained to encode 320 samples 
# at a time into embeddings
step = 320 

# how many 400 sample chunks are in a window 
# (i.e. input to wav2vec model will be reshaped to B=2,C=chunk,N=768]) )
chunks = int(np.floor(max_len / step)) 

CHANNEL = 0
dim, samples = waveform.size()
num_file_chunks = int(np.floor(samples / step))
print(f"Total Time in the file: {samples / bundle.sample_rate:.1f}s")
print(f"Total Samples in the file: {samples}")
print(f"Total inference windows: {num_file_chunks}")

print(f"Length of queue for lookback: {max_len} samples")
print(f"Time of lookback: {max_len / bundle.sample_rate}s")

Total Time in the file: 22.2s
Total Samples in the file: 355872
Total inference windows: 1112
Length of queue for lookback: 80000 samples
Time of lookback: 5.0s


In [7]:
# iterate through file with a sliding window.
# implement a queue structure that is effectively max_len long
# 
# stride 
for i in range(0,100,10):
    # ok not really a queue just a list...sigh.
    queue = []
    # remove 1xchunk samples from front of queue, 
    # append first 1xchunk samples from next part of the waveform
    queue = list(waveform[0,i*step:i*step+max_len].detach().cpu().numpy())
    
    #print(f"Now Processing: {i*step} to {i*step+max_len} | {i*step / bundle.sample_rate} to {(i*step+max_len) / bundle.sample_rate} seconds")
    with torch.inference_mode():
        input_data = torch.Tensor(queue).reshape(1,max_len).to(device)
        emission, _ = model(input_data)
        beam_search_result = beam_search_decoder(emission)
    beam_search_transcript = " ".join(beam_search_result[0][0].words).strip()
    print(f"Transcript: {beam_search_transcript}")
    
    # Used to make the experiment readable
    #time.sleep(1.0)
    #clear_output(wait=True)

Transcript: mary had a little lamb little lamb little lamb
Transcript: mary had a little lamb little lamb little lamb
Transcript: he had a little lamb little lamb little lamb
Transcript: had a little lamb little lamb little lamb
Transcript: a little lamb little lamb little lamb mary
Transcript: ittle lamb little lamb little lamb mary
Transcript: lamb little lamb little lamb mary had
Transcript: little lamb little lamb mary had a little
Transcript: little lamb little lamb mary had a little
Transcript: little lamb little lamb mary had a little lamb


**Conclusion**: 1 second is really not a great look back length even if that might be what I'm limited too.  I might have to concatenate with a LM with a longer history for streaming data.

Furthermore, 3 seconds appears more readable / understandable.  But I would think that maybe the window can be longer too if needed.  but for responsiveness sake, it would be good to keep estimates happening every 0.02s.

In [8]:
import pandas as pd
from collections import Counter

class SongTranscript:
    """
    SongTranscript - container for song transcript metadata.
    
    1. list of all unique words
    2. histogram of all unique words
    3. probability of words occuring in song
    """
    def __init__(self,filename):

        fd = open(filename)
            
        # Create transcript
        obj = fd.read()
        obj = obj.replace('\n\n',' | ')
        obj = obj.replace('\n',' ')
        obj = obj.replace(',',' ')
        
        # String sequence of the words inthe song including a | for song segment delimiting
        self._transcript_with_segments = obj.replace('  ',' ').strip(' ')
        
        # String Sequence of the words in the song without the song segment character
        self._transcript = obj.replace('  ',' ').replace(' | ',' ').strip(' ')
        fd.close()
        
        # Create word sequence (list of all the words as they appear in the transcript)
        self._word_sequence = self._transcript.split(' ')
        
        # Create unique word list
        self._unique_words = set(self._word_sequence)
        
        # Create histogram for the song
        self._frequency = dict(Counter(self._word_sequence))
    
        # Create statistic of a word occuring in the song
        self._statistic = {k:(v / len(self._unique_words)) for k,v in self._frequency.items()}
        
        # Create a list of words by segment / slide
        segments = [s.strip(' ') for s in self._transcript_with_segments.split('|')]
        self._segment_frequency = [dict(Counter(s.split(' '))) for s in segments]
        
        # Create segment specific statistic of occurance based on the segment frequency 
        # and total word count for the segment
        sumval = []
        for i, d in enumerate(self._segment_frequency):
            sumval.append(0)
            for k,v in d.items():
                sumval[i] += v
        #print(sumval)
        self._segment_statistic = []
        for j, d in enumerate(self._segment_frequency):
            self._segment_statistic.append({})
            for k,v in d.items():
                self._segment_statistic[j][k] = v / sumval[j]
    
    def get_frequency(self):
        return self._frequency
    
    def get_statistic(self):
        return self._statistic
    
    def get_words(self):
        return self._unique_words
    
    def get_word_sequence(self):
        return self._word_sequence

    def get_segment_frequency(self):
        return self._segment_frequency
    
    def get_segment_statistic(self):
        return self._segment_statistic
    
    def __str__(self):
        return self._transcript_with_segments

st = SongTranscript('data/mary_had_a_little_lamb_song.transcript')
print(st)
print()
print(st.get_words())
print()
print(st.get_word_sequence())
print()
print(len(st.get_words()))
print()
print(st.get_frequency())
print()
print(st.get_statistic())
print()
print(st.get_segment_frequency())
print()
print(st.get_segment_statistic())

mary had a little lamb little lamb little lamb mary had a little lamb her fleece was white as snow | everywhere that mary went mary went mary went everywhere that mary went the lamb was sure to go

{'mary', 'snow', 'her', 'the', 'everywhere', 'a', 'go', 'lamb', 'fleece', 'as', 'that', 'went', 'little', 'was', 'had', 'sure', 'to', 'white'}

['mary', 'had', 'a', 'little', 'lamb', 'little', 'lamb', 'little', 'lamb', 'mary', 'had', 'a', 'little', 'lamb', 'her', 'fleece', 'was', 'white', 'as', 'snow', 'everywhere', 'that', 'mary', 'went', 'mary', 'went', 'mary', 'went', 'everywhere', 'that', 'mary', 'went', 'the', 'lamb', 'was', 'sure', 'to', 'go']

18

{'mary': 6, 'had': 2, 'a': 2, 'little': 4, 'lamb': 5, 'her': 1, 'fleece': 1, 'was': 2, 'white': 1, 'as': 1, 'snow': 1, 'everywhere': 2, 'that': 2, 'went': 4, 'the': 1, 'sure': 1, 'to': 1, 'go': 1}

{'mary': 0.3333333333333333, 'had': 0.1111111111111111, 'a': 0.1111111111111111, 'little': 0.2222222222222222, 'lamb': 0.2777777777777778, 'her':