## final code

In [55]:
import numpy as np
from tqdm import tqdm
from transformers import Wav2Vec2Processor, Data2VecAudioForCTC, Wav2Vec2ForCTC
from datasets import load_dataset
import jiwer
import torch

# load model and processor
processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-large-960h")
mdl_data2vec = Data2VecAudioForCTC.from_pretrained("facebook/data2vec-audio-large-960h")
mdl_wav2vec = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h")

librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")

sample_test = sample(librispeech_eval['audio'], 20)

X = []
Y = []
for ind in tqdm(np.random.randint(0, len(librispeech_eval['audio']), 20)):
    X.append(librispeech_eval['audio'][ind]['array'])
    Y.append(librispeech_eval['text'][ind])
    
input_values = processor(X, return_tensors="pt", sampling_rate=16000, padding="longest").input_values
with torch.no_grad():
    logits_data2vec = mdl_data2vec(input_values).logits
    logits_wav2vec = mdl_wav2vec(input_values).logits
    
Y_pred = processor.batch_decode(torch.argmax(logits_data2vec, dim=-1))
Y_base = processor.batch_decode(torch.argmax(logits_wav2vec, dim=-1))


jiwer.wer(Y, Y_pred)
jiwer.cer(Y, Y_pred)

Downloading: 100%|██████████| 843/843 [00:00<00:00, 211kB/s]
Downloading: 100%|██████████| 1.18G/1.18G [01:05<00:00, 19.1MB/s]
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Reusing dataset librispeech_asr (C:\Users\Jayma\.cache\huggingface\datasets\librispeech_asr\clean\2.1.0\14c8bffddb861b4b3a4fcdff648a56980dbb808f3fc56f5a3d56b18ee88458eb)
100%|██████████| 20/20 [02:57<00:00,  8.87s/it]


0.0021333333333333334

In [57]:
print("WER baseline : {:.2}, CER baseline {:.2}".format(jiwer.wer(Y, Y_base), jiwer.cer(Y, Y_base)))
print("WER data2vec : {:.2}, CER data2vec {:.2}".format(jiwer.wer(Y, Y_pred), jiwer.cer(Y, Y_pred)))

WER baseline : 0.031, CER baseline 0.0096
WER data2vec : 0.011, CER data2vec 0.0021
