# WAV2VEC2 Inference Benchmarking

Based on https://pytorch.org/tutorials/intermediate/speech_recognition_pipeline_tutorial.html


## Pytorch Inference

In [1]:
# was recommended for an error with mps mode (didn't seem to help).  
# %env PYTORCH_ENABLE_MPS_FALLBACK=1 

import os
import time
import torch
import torchaudio
import numpy as np

torch.random.manual_seed(0)
print(f"PyTorch Version: {torch.__version__}, Pytorchaudio Version: {torchaudio.__version__}")

#SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
#SPEECH_FILE = "_assets/speech.wav"
basename = "data/mary_had_a_little_lamb_spoken"
SPEECH_FILE = basename + ".wav"
waveform, sample_rate = torchaudio.load(SPEECH_FILE)

PyTorch Version: 2.2.1, Pytorchaudio Version: 2.2.1


In [2]:
def truncate_waveform(data: torch.Tensor,sr: float, seconds: float) ->torch.Tensor:
    """
    Summary: truncates tensor to length that cooresponds to sr * seconds
    
    Input Arguments: 
    data (torch.Tensor) - audio data
    sr (torch.Float) - sampling rate of audio data
    seconds - seconds of data to return. 
            if the seconds is specified to be larger than the len(tensor) / sr, then an error is raised.  
            if seconds is 0, then error is raised.
    
    Return: torch.Tensor that is truncated from input, or unmodified.
    
    """
    # calculate the truncation length of tensor.
    truncation_len = int(np.floor(sr * seconds))
    
    # pull out dimensions.
    dim, max_len = data.shape
    
    if max_len < truncation_len: 
        raise Exception(f"Expected seconds to be less than: {max_len / sr:.2f}.  Seconds is: {seconds}.")
    
    # check for errors in the input.
    if seconds == 0.0: 
        raise Exception("seconds cannot be 0.0.")

    #if max_len > truncation_len:
    #    print("normal truncation occuring.") 
    
    return data[:,:truncation_len]
    
# tests in this assume the waveform is 5s long.
testing_truncate_waveform = True
wavey, sample_rate = torchaudio.load(SPEECH_FILE)
if testing_truncate_waveform:
    print('Test 1')
    try: 
        w1 = truncate_waveform(wavey, sr=sample_rate, seconds=0.0)
    except Exception as e:
        print(e)

    print('Test 2')
    try: 
        w2 = truncate_waveform(wavey, sr=sample_rate, seconds=30.0)
    except Exception as e:
        print(e)

    print('Test 3')
    try: 
        w3 = truncate_waveform(wavey, sr=sample_rate, seconds=1.0)
        print(w3.size())
    except Exception as e:
        print(e)


Test 1
seconds cannot be 0.0.
Test 2
Expected seconds to be less than: 22.24.  Seconds is: 30.0.
Test 3
torch.Size([1, 44100])


## Pytorch Inference (CPU) Benchmark
**Objective**: Measure single inference lookback to understand if I can keep up real time.  <br>

In [21]:
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
device = torch.device('mps')
model = bundle.get_model().to(device)

In [22]:
if os.path.exists(SPEECH_FILE):
    max_length_secs = 1.0
    waveform, sample_rate = torchaudio.load(SPEECH_FILE)
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
    sample_rate = bundle.sample_rate
    waveform_trunc = truncate_waveform(waveform, sample_rate, seconds=max_length_secs)
    waveform = waveform_trunc.to(device)
    print(f"Target Sample Rate: {sample_rate}")
else:
    print('NO FILE HERE!')

chans, samples = waveform.size()
print(f"wavform is: {samples / bundle.sample_rate}s long ({samples} samples).")

Target Sample Rate: 16000
wavform is: 1.0s long (16000 samples).


In [23]:
torch.randn(1,16000).shape

torch.Size([1, 16000])

In [24]:
class GreedyCTCDecoder(torch.nn.Module):
    """
    Summary: simple decoder using argmax to determine best character, 
             then remove duplicates per CTC's algorithm.
    
    Note: would be better to use CTC using maximumizing liklihood of sequence 
         (i.e. using adjacent logits to guess characters.) 
    """
    def __init__(self, labels, blank=0):
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor):
        """Given a sequence emission over labels, get the best path
        Args:
          emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
          List[str]: The resulting transcript
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        indices = [i for i in indices if i != self.blank]
        joined = "".join([self.labels[i] for i in indices])
        return joined.replace("|", " ").strip().split()

tokens = [label.lower() for label in bundle.get_labels()]
greedy_decoder = GreedyCTCDecoder(tokens)

In [25]:
# benchmark for 1.0s
start = time.time()
with torch.inference_mode():
  emission, _ = model(waveform)
  transcript = greedy_decoder(emission[0])
finish = time.time()
print(f"Time to perform inference (with decoding): {finish-start}")
print(f"Transcript: {transcript}")

Time to perform inference (with decoding): 0.08699202537536621
Transcript: ['mary', 'had', 'a']


In [26]:
emission.shape

torch.Size([1, 49, 29])

**Conclusion**: CPU based inference is not good enough a realtime lookback.  because the delays are going to make predicting words practically, not useful.  Need a response time near 100ms max (speculation).  

Question: I believe the model emissions / logits will perform better if there is more history.  With this in mind, can I actually just concatenate prior emissions with a good CTC decoder?

COA 1: concatenate emissions, perform decoding every pass on longer array of emissions.<br>
COA 2: concatenate emissions, perform decoding every pass on longer array of emissions AND add a LM that can correct for issues in text.<br>
COA 3: make a guess at what will be needed 1 second into the future (bad idea).<br>



# Pytorch Inference (M1 GPU) Benchmark 

In [27]:
device   = torch.device('mps')
waveform = waveform_trunc.to(device)
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)

In [28]:
# benchmark for 1.0s
start = time.time()
transcript_gpu = ''
try: 
    with torch.inference_mode():
        emission, _ = model(waveform[0,:].reshape(1,16000))
        transcript_gpu = greedy_decoder(emission[0])

except NotImplementedError as e:
    print("NotImplementedError:",e)

finish = time.time()
print(f"Time to perform inference (with decoding): {finish-start}")
print(f"Transcript: {transcript_gpu}")

Time to perform inference (with decoding): 0.23835325241088867
Transcript: ['mary', 'had', 'a']


**Conclusion**: GPU support for weight norm is cuasing some kind of issue.  I used the FALLBACK (see above) environmental variable to see if that helped, but the response is blank, so Im not sure if the weight norm is doing a pass through thing, and then corrupting the results, or if the data is not making it to the output for some reason.  This issue may be similar to the CoreML issue below.

Issue here: 
https://github.com/pytorch/pytorch/issues/77764

# CoreML Benchmark

Pytorch directly to CoreML:<br>
https://coremltools.readme.io/docs/pytorch-conversion#generate-a-torchscript-version

**Conclusion**: there is some issue with weight norm that I think there is a work around for, but I haven't have the time to work through it. <br>
https://github.com/pytorch/pytorch/issues/57289

Apparently, this may work:<br> 
`for layer in layers_with_weight_norm:`<br>
`   torch.nn.utils.remove_weight_norm(layer)`

# ONNX Benchmark
**Objective**: Attempt to see if ONNX gives you some imporovement in performance on CPU
https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html


In [35]:
device   = torch.device('cpu')
waveform = waveform_trunc.to(device)
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)

In [36]:
import onnx, onnxruntime

# Exporting Pytorch Model to ONNX file...

# Putting into eval mode: i.e. disabling calculation of gradients.
model.eval()

# Adding example inputs to perform something called tracing...
example_input = torch.randn(1,16000)
# returns example emissions (garbage output) from the pytorch model
torch_out = model(example_input)

# take the original pytorch model, and the example_input and export to file...
torch.onnx.export(model,
                  example_input,
                  'wav2vec2.onnx')

print("ONNX file saved.")

ONNX file saved.


In [37]:
# Loading ONNX file and preparing for Inference.

# loading the model into memory, checking for issues, c
omodel = onnx.load('wav2vec2.onnx')
onnx.checker.check_model(omodel)

# InferenceSession is the main class of ONNX Runtime. 
# It is used to load and run an ONNX model, 
# as well as specify environment and application configuration options.
# as far as I can tell GPU support is spotty for M1
# https://onnxruntime.ai/docs/execution-providers/
ort_session = onnxruntime.InferenceSession('wav2vec2.onnx')
print("Model Ready for Inference")

Model Ready for Inference


In [38]:
start= time.time()
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(waveform[0,:].reshape(1,16000))}
ort_outs = ort_session.run(None, ort_inputs)
transcript = greedy_decoder(torch.Tensor(ort_outs[0][0]))
finish = time.time()
print(f"Time to perform inference: {finish-start}")
print(transcript)


Time to perform inference: 0.09635710716247559
['mary', 'had', 'a']


**Conclusion**: ONNX used to work, but with the latest update there are some holes that haven't been filled, so I'm going to ignore ONNX until I absolutely need it.

In [None]:
import time
from IPython.display import clear_output

i = 0
while True: 
    print(f"Hello: {i}")
    i += 1
    time.sleep(0.2)
    clear_output(wait=True)

# Next Steps
Attempt to profile the ONNX model.<br>
Is this running on the M1 chip? <br>

Here is an example for profiling:<br>
https://machinelearning.apple.com/research/neural-engine-transformers<br>