# WAV2VEC2 Inference Benchmarking

Based on https://pytorch.org/audio/stable/tutorials/speech_recognition_pipeline_tutorial.html<br>

Originally, this notebook was meant to help me understand how to run WAV2VEC2 on a GPU.  However, MPS (Apple-silicon GPU) support for WAV2VEC2 appears to be incomplete: the MPS run below returns only blank tokens.  <br>

See other notebooks on decoding for more detailed examples of using WAV2VEC for inference, etc.<br>


## Pytorch Setup

In [134]:
# was recommended for an error with mps mode (didn't seem to help).  
# %env PYTORCH_ENABLE_MPS_FALLBACK=1 

import os
import time
import torch
import torchaudio
import numpy as np
import IPython

from transformers import pipeline

# Seed PyTorch's RNG so random tensors created later in the notebook are repeatable.
torch.random.manual_seed(0)
print(f"PyTorch Version: {torch.__version__}, Pytorchaudio Version: {torchaudio.__version__}")

# Sample clip from the PyTorch tutorial assets (kept for reference, not used):
#SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
#SPEECH_FILE = "_assets/speech.wav"
# Recording used by the rest of the notebook; the path is relative to the working directory.
SPEECH_FILE = "data/mary_had_a_little_lamb_spoken.wav"
# Loaded at the file's own sample rate; a later cell resamples and truncates it.
waveform, sample_rate = torchaudio.load(SPEECH_FILE)

PyTorch Version: 2.2.1, Pytorchaudio Version: 2.2.1


In [135]:
# Create Pipeline
# Build a Hugging Face ASR pipeline for the checkpoint, then pull out the pieces
# that the benchmark cells use directly.
model_checkpoint="facebook/wav2vec2-large-960h-lv60-self"
asr_pipeline = pipeline("automatic-speech-recognition", model=model_checkpoint)
# The underlying model object; the benchmark cells call it directly on raw waveforms.
model = asr_pipeline.model
# Lower-cased labels, in the iteration order of the tokenizer's vocab.
# NOTE(review): GreedyCTCDecoder indexes this list by argmax id, so this assumes
# the iteration order matches the token ids — confirm.
vocab = [label.lower() for label in asr_pipeline.tokenizer.vocab]
# Sampling rate reported by the feature extractor; audio is resampled to it before inference.
target_sampling_rate = asr_pipeline.feature_extractor.sampling_rate

Some weights of the model checkpoint at facebook/wav2vec2-large-960h-lv60-self were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h-lv60-self and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.maske

In [136]:
def truncate_waveform(data: torch.Tensor, sr: float, seconds: float) -> torch.Tensor:
    """
    Summary: keeps only the first `seconds` of audio, i.e. the leading
             floor(sr * seconds) samples along the last dimension.

    Input Arguments:
    data (torch.Tensor) - audio data with time on the last dimension,
            e.g. [channels, samples] or a 1-D [samples] tensor.
    sr (float) - sampling rate of the audio data, in samples per second.
    seconds (float) - seconds of data to keep, measured from the start.

    Return: torch.Tensor with the leading floor(sr * seconds) samples of `data`
            (same length as the input when exactly the full duration is requested).

    Raises:
    ValueError - if seconds is 0 or negative, or if the requested window is
            longer than the audio in `data`. ValueError is a subclass of
            Exception, so `except Exception` handlers still catch it.
    """
    # Reject non-positive windows up front: a negative length would otherwise
    # become a negative slice bound and silently trim samples off the END.
    if seconds == 0.0:
        raise ValueError("seconds cannot be 0.0.")
    if seconds < 0.0:
        raise ValueError(f"seconds cannot be negative (got {seconds}).")

    # Number of samples that correspond to the requested window.
    truncation_len = int(np.floor(sr * seconds))

    # Time is the last dimension, however many leading dimensions there are.
    max_len = data.shape[-1]

    if max_len < truncation_len:
        raise ValueError(f"Size of data: {data.shape}.  Size of Requested Window: {seconds} seconds.  Expected window to be less than: {max_len / sr:.2f}.")

    return data[..., :truncation_len]
    
# Quick manual checks of truncate_waveform against the clip loaded above.
# Set the flag to False to skip them.
testing_truncate_waveform = True
if testing_truncate_waveform:
    # Test 1: a zero-second window is expected to raise; the message is printed.
    print('Test 1')
    try: 
        w1 = truncate_waveform(waveform, sr=sample_rate, seconds=0.0)
    except Exception as e:
        print(e)

    # Test 2: a 30 s window; expected to raise when the clip is shorter than that.
    print('Test 2')
    try: 
        w2 = truncate_waveform(waveform, sr=sample_rate, seconds=30.0)
    except Exception as e:
        print(e)

    # Test 3: a 1 s window; on success the truncated size is printed.
    print('Test 3')
    try: 
        w3 = truncate_waveform(waveform, sr=sample_rate, seconds=1.0)
        print(w3.size())
    except Exception as e:
        print(e)


Test 1
seconds cannot be 0.0.
Test 2
Size of data: torch.Size([1, 980870]).  Size of Requested Window: 30.0 seconds.  Expected window to be less than: 22.24.
Test 3
torch.Size([1, 44100])


In [137]:
class GreedyCTCDecoder(torch.nn.Module):
    """
    Summary: greedy (best-path) CTC decoder. Takes the highest-scoring label
             at every frame, then collapses repeats and drops blanks.

    Note: a decoder that maximizes the likelihood of the whole sequence
          (i.e. one that looks at adjacent frames together) would do better.
    """
    def __init__(self, vocab, pad=0):
        """
        Args:
          vocab: sequence of label strings, indexed by label id.
          pad (int): id of the blank label; frames with this id are discarded.
        """
        super().__init__()
        self.blank = pad
        self.labels = vocab

    def forward(self, emission: torch.Tensor):
        """Decode one utterance by following the best path through the emission.

        Args:
          emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
          List[str]: The words of the transcript.
        """
        # Best label per frame -> [num_seq,]
        best_path = torch.argmax(emission, dim=-1)
        # Merge each run of identical labels into a single occurrence.
        collapsed = torch.unique_consecutive(best_path, dim=-1)

        # Keep every label id except the blank.
        kept_ids = []
        for label_id in collapsed:
            if label_id != self.blank:
                kept_ids.append(int(label_id))

        # "|" is turned into whitespace so the text can be split into words.
        characters = [self.labels[label_id] for label_id in kept_ids]
        text = "".join(characters)
        return text.replace("|", " ").strip().split()
        
# Decoder instance shared by the benchmark cells; the blank id is left at its default of 0.
greedy_decoder = GreedyCTCDecoder(vocab)

In [138]:
# Load Song, truncate to 1.0 second
# NOTE(review): this cell rebinds the notebook-level `waveform` and `sample_rate`
# that the first cell created, so re-running earlier cells afterwards sees the
# resampled/truncated values.
if os.path.exists(SPEECH_FILE):
    max_length_secs = 1.0
    data_waveform, data_sample_rate = torchaudio.load(SPEECH_FILE)
    print(f"File Sample Rate: {data_sample_rate}, Number of Samples: {data_waveform.shape}")
    # Resample from the file's rate to the rate the feature extractor reports.
    data_waveform = torchaudio.functional.resample(data_waveform, data_sample_rate, target_sampling_rate)
    sample_rate = target_sampling_rate
    # Keep only the first max_length_secs seconds for the benchmarks.
    waveform = truncate_waveform(data_waveform, sample_rate, seconds=max_length_secs)
    print(f"Target Sample Rate: {sample_rate}")
else:
    # NOTE(review): this only prints; the lines below still run and use whatever
    # `waveform` / `sample_rate` are already defined.
    print('NO FILE HERE!')

chans, samples = waveform.size()
print(f"Number of Samples: {samples}")
print(f"Wavform is: {samples / sample_rate}s long.")
# Last expression of the cell: renders an audio player for the truncated clip.
IPython.display.Audio(waveform,rate=sample_rate)

File Sample Rate: 44100, Number of Samples: torch.Size([1, 980870])
Target Sample Rate: 16000
Number of Samples: 16000
Wavform is: 1.0s long.


## Pytorch Inference (CPU) Benchmark
**Objective**: Measure the latency of a single inference pass over a 1-second lookback window, to understand whether I can keep up with real time.  <br>

In [139]:
# benchmark for 1.0s with CPU
# Only run a single forward pass, batch size=1
# Move both the input and the model to the CPU before timing.
waveform = waveform.to('cpu')
model.to('cpu')
print(f"Model Running on: {model.device}, Waveform Running on CPU: {waveform.is_cpu}")
# The timed region covers the forward pass, the shape print and the greedy decoding.
start = time.time()
with torch.inference_mode():
    emission = model(waveform)
    # emission[0] is the first element of the model output (presumably the logits — confirm).
    print(emission[0].shape)
    # Decode the first (only) item of the batch.
    transcript = greedy_decoder(emission[0][0,:,:])
finish = time.time()
print(f"Time to perform inference (with decoding): {finish-start}")
print(f"Transcript: {transcript}")

Model Running on: cpu, Waveform Running on CPU: True
torch.Size([1, 49, 32])
Time to perform inference (with decoding): 0.3105967044830322
Transcript: ['mary', 'had']


**Conclusion**: CPU-based inference is not fast enough for a real-time lookback, because the delay would make the predicted words practically useless.  I need a response time of at most about 100 ms (speculation).  

Question: I believe the model emissions / logits will perform better if there is more history.  With this in mind, can I actually just concatenate prior emissions with a good CTC decoder?

COA 1: concatenate emissions, perform decoding every pass on longer array of emissions.<br>
COA 2: concatenate emissions, perform decoding every pass on longer array of emissions AND add a LM that can correct for issues in text.<br>
COA 3: make a guess at what will be needed 1 second into the future (bad idea).<br>



# Pytorch Inference (M1 GPU) Benchmark 

In [140]:
# benchmark for 1.0s with GPU
# Only run a single forward pass, batch size=1
# Move both the input and the model to the MPS device before timing.
waveform = waveform.to('mps')
model.to('mps') 
print(f"Model Running on: {model.device}, Waveform Running on GPU: {waveform.is_mps}")

# The timed region covers the forward pass, the shape print and the greedy decoding,
# mirroring the CPU benchmark cell.
start = time.time()
with torch.inference_mode():
    gpu_emission = model(waveform)
    print(gpu_emission[0].shape)
    # Decode the first (only) item of the batch.
    transcript_gpu = greedy_decoder(gpu_emission[0][0,:,:])
finish = time.time()
print(f"Time to perform inference (with decoding): {finish-start}")
print(f"Transcript: {transcript_gpu}")


# Per-frame argmax label ids, printed to inspect what the decoder was given.
print(torch.argmax(gpu_emission[0][0,:,:],dim=1))

Model Running on: mps:0, Waveform Running on GPU: True
torch.Size([1, 49, 32])
Time to perform inference (with decoding): 0.07041597366333008
Transcript: []
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0], device='mps:0')


In [144]:
# Per-frame argmax label ids from the MPS run, for comparison with the CPU run.
torch.argmax(gpu_emission[0],dim=-1)

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0]], device='mps:0')

In [143]:
# Per-frame argmax label ids from the CPU run, for comparison with the MPS run.
torch.argmax(emission[0],dim=-1)

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0, 17,  0,  0,  0,  7,  0,  0,  0, 13,
         13,  0,  0,  0,  0, 22,  0,  0,  0,  0,  0,  0,  4,  0, 11,  0,  0,  7,
          7,  0,  0,  0, 14,  0,  0,  0,  0,  0,  0,  0,  4]])

**Conclusion**: GPU support for weight norm is causing some kind of issue.  I used the FALLBACK environment variable (see above) to see if that helped, but the response is still blank, so I'm not sure whether the weight norm is doing a pass-through and then corrupting the results, or whether the data is not making it to the output for some reason.  This issue may be similar to the CoreML issue below.

Issue here: 
https://github.com/pytorch/pytorch/issues/77764

**Conclusion**: there is some issue with weight norm that I think there is a workaround for, but I haven't had the time to work through it. <br>
https://github.com/pytorch/pytorch/issues/57289

Apparently, this may work:<br> 
`for layer in layers_with_weight_norm:`<br>
`   torch.nn.utils.remove_weight_norm(layer)`

# ONNX Benchmark
**Objective**: See whether ONNX gives any improvement in performance on the CPU.<br>
https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html


In [145]:
import onnx, onnxruntime

# Exporting Pytorch Model to ONNX file...
# Use the 1-second `waveform` prepared earlier in the notebook. (`waveform_trunc`
# is not defined anywhere in this notebook, so referencing it fails on a fresh kernel.)
waveform = waveform.to('cpu')
model.to('cpu')
print(f"Model Running on: {model.device}, Waveform Running on CPU: {waveform.is_cpu}")

# Put the model into eval mode. This switches layers such as dropout to their
# inference behaviour; it does NOT disable gradient tracking.
model.eval()

# Example input used for tracing: random audio shaped [batch=1, 16000 samples].
example_input = torch.randn(1,16000)
# Returns example emissions (garbage output) from the pytorch model.
# Run without gradient tracking, since this is inference only.
with torch.no_grad():
    torch_out = model(example_input)

# Trace the pytorch model with example_input and write the result to file.
torch.onnx.export(model,
                  example_input,
                  'wav2vec2.onnx')

print("ONNX file saved.")

Model Running on: cpu, Waveform Running on CPU: True


  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):


ONNX file saved.


In [146]:
# Loading ONNX file and preparing for Inference.

# Load the exported model into memory and run the ONNX checker on it.
omodel = onnx.load('wav2vec2.onnx')
onnx.checker.check_model(omodel)

# InferenceSession is the main class of ONNX Runtime.
# It is used to load and run an ONNX model,
# as well as specify environment and application configuration options.
# As far as I can tell, GPU support is spotty for M1:
# https://onnxruntime.ai/docs/execution-providers/
# No providers are passed here, so ONNX Runtime picks its default.
ort_session = onnxruntime.InferenceSession('wav2vec2.onnx')
print("Model Ready for Inference")

Model Ready for Inference


In [147]:
# The timer starts here, so the measurement also covers defining the helper below.
start = time.time()
def to_numpy(tensor):
    """Return `tensor` as a NumPy array on the CPU, detaching it first if it requires grad."""
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
# Feed the first channel of the waveform, reshaped to [1, 16000], to the session's first input.
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(waveform[0,:].reshape(1,16000))}
# Passing None as the output names requests the model's outputs without naming them.
ort_outs = ort_session.run(None, ort_inputs)
# First output, first batch item -> greedy CTC decoding.
transcript = greedy_decoder(torch.Tensor(ort_outs[0][0]))
finish = time.time()
print(f"Time to perform inference: {finish-start}")
print(transcript)


Time to perform inference: 0.18062806129455566
[]


**Conclusion**: ONNX used to work, but with the latest update there are some holes that haven't been filled, so I'm going to ignore ONNX until I absolutely need it.

In [None]:
import time
from IPython.display import clear_output

# Demo of a live-updating cell output: each message replaces the previous one.
# The loop is bounded (rather than `while True`) so the cell finishes on its own
# and "Restart Kernel -> Run All" can complete.
NUM_UPDATES = 25  # about 5 seconds at 0.2 s per update
for i in range(NUM_UPDATES):
    print(f"Hello: {i}")
    time.sleep(0.2)
    # wait=True delays the clear until new output is available, which avoids flicker.
    clear_output(wait=True)

# Next Steps
Attempt to profile the ONNX model.<br>
Is this running on the M1 chip? <br>

Here is an example for profiling:<br>
https://machinelearning.apple.com/research/neural-engine-transformers<br>