In [2]:
import torch
from model import Wav2Vec2ForCTC
from utils import Wav2Vec2Config
from safetensors.torch import load_file
from transformers import Wav2Vec2CTCTokenizer
import torchaudio
import jiwer

### Load Weights

We first need to load the weights of the model we trained! The only thing that really matters is that we let the config know whether we are loading pretrained weights from our own model or from the Huggingface backbone. This is mainly because the pretrained backbone from Huggingface has a slightly different structure than ours (i.e. groupnorm vs layernorm and some other stuff). Once the skeleton of the model is built, we can load our own weights from our checkpoint into it. We will also be loading the Wav2Vec2 tokenizer from Huggingface that we used for our training.

In [4]:
### Pick the backbone flavour (the Huggingface backbone is slightly different from ours) ###
huggingface_backbone = False
path_to_backbone = "work_dir/finetune_my_backbone/model.safetensors"

### Build the model config ###
### "random" selects our own implementation, "pretrained_huggingface" the Huggingface backbone ###
### Either way, the weights are replaced by our checkpoint further down ###
if huggingface_backbone:
    config = Wav2Vec2Config(pretrained_backbone="pretrained_huggingface")
else:
    config = Wav2Vec2Config(pretrained_backbone="random")

### Read the checkpoint weights from disk ###
model_weights = load_file(path_to_backbone)

### Build the model skeleton, load the checkpoint into it, and switch to eval mode ###
model = Wav2Vec2ForCTC(config)
model.load_state_dict(model_weights)
model.eval()

### The tokenizer is looked up by the Huggingface model name stored in the config ###
hf_model_name = config.hf_model_name
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(hf_model_name)

### Inference and Audio File

I provide a sample audio file that I just grabbed from the LibriSpeech dataset, along with its corresponding transcript. You can try any audio you want here though! We just need to make sure the audio is resampled to 16000Hz and normalized before passing it to the model.

In [5]:
### Sample clip to transcribe and its reference transcript ###
### (the reference is what the prediction is scored against when computing WER) ###
path_to_audio = "sample_audio/sample_audio.wav"
true_transcript = "CHAPTER SIXTEEN I MIGHT HAVE TOLD YOU OF THE BEGINNING OF THIS LIAISON IN A FEW LINES BUT I WANTED YOU TO SEE EVERY STEP BY WHICH WE CAME I TO AGREE TO WHATEVER MARGUERITE WISHED"

def transcript_audio(path_to_audio, model, target_sample_rate=16000):
    """
    Transcribe a single audio file with a CTC speech model.

    The waveform is loaded, downmixed to mono, resampled to
    ``target_sample_rate``, standardized to zero mean / unit variance and
    passed through ``model``. The highest-scoring id at every output frame
    is then decoded with the notebook-level ``tokenizer``.

    Args:
        path_to_audio: Path to an audio file readable by ``torchaudio.load``.
        model: Callable that returns a ``(loss, logits)`` pair for a batched
            waveform. It is not switched to eval mode here, so the caller
            should do that beforehand.
        target_sample_rate: Sampling rate in Hz that the audio is resampled
            to before inference. Defaults to 16000.

    Returns:
        Whatever ``tokenizer.decode`` returns for the predicted ids
        (the predicted transcription).

    Note:
        Reads the global ``tokenizer``; it must be defined before calling.
    """
    ### Load Audio (torchaudio returns a channels-first waveform and its sample rate) ###
    audio, sr = torchaudio.load(path_to_audio)

    ### Downmix to Mono ###
    ### Multi-channel files would otherwise reach the model with extra channels ###
    ### Mono input is left untouched ###
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)

    ### Resample Audio to the Target Rate ###
    resample = torchaudio.transforms.Resample(sr, target_sample_rate)
    audio = resample(audio)

    ### Normalize Audio (epsilon guards against division by zero on silent clips) ###
    normed_audio = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)

    ### Add Batch Dimension ###
    normed_audio = normed_audio.unsqueeze(0)

    ### Inference Audio through Model (the loss is not needed here) ###
    with torch.inference_mode():
        _, logits = model(normed_audio)

    ### Grab Predicted Characters at each token ###
    ### Flatten instead of squeeze() so a single-frame output is still a list of ids ###
    pred_ids = torch.argmax(logits, dim=-1).reshape(-1).tolist()

    ### Decode Predicted Characters with Tokenizer ###
    pred_transcription = tokenizer.decode(pred_ids)

    return pred_transcription


### Transcribe the sample clip and show the result under a heading ###
pred_transcript = transcript_audio(path_to_audio, model)
print("Predicted Transcript", pred_transcript, sep="\n")

Predicted Transcript
CHAPTER SIXTEEN I MIGHT HAVE TOLD YOU OF THE BEGINNING OF THIS LAAISON IN A FEW LINES BUT I WANTED YOU TO SEE EVERY STEP BY WHICH WE CAME I TO AGREE TO WHATEVER MARGUERITE WISHED


### Compute WER

We can use `jiwer` to compute the Word Error Rate (WER).

In [6]:
### Score the predicted transcript against the reference transcript ###
word_error_rate = jiwer.wer(
    reference=true_transcript,
    hypothesis=pred_transcript,
)
print("Word Error Rate:", round(word_error_rate, ndigits=3))

Word Error Rate: 0.028
