# Evaluation

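This notebook runs automatic speech recognition (ASR) over the evaluation set (`eval.csv`). It uses the baseline `facebook/wav2vec2-base-960h` model and the local `w2v2960h_lr*` checkpoints, and adds each model's predicted transcription as a column next to the reference transcription.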

In [2]:
# Import packages
import pandas as pd
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import librosa

In [3]:
def load_audio(file_path, sr=16000):
    '''
    Load an audio file from `file_path` as a waveform resampled to `sr` Hz.

    Parameters
    ----------
    file_path : path to the audio file, passed straight to `librosa.load`.
    sr : target sampling rate in Hz. Defaults to 16000, the rate this
        notebook passes to the processor, so existing callers are unchanged.

    Returns
    -------
    The waveform from `librosa.load`. The sampling rate that `librosa.load`
    also returns is discarded because it is the requested `sr`.
    '''
    audio, _ = librosa.load(file_path, sr=sr)
    return audio


def transcribe_audio(model, proc, file_name):
    '''
    Transcribe one audio file with a CTC model and return lower-cased text.

    Parameters
    ----------
    model : model whose forward call returns an object with `.logits`
        (e.g. Wav2Vec2ForCTC).
    proc : processor used both to turn the waveform into `input_values`
        and to decode predicted token ids back to text.
    file_name : path of the audio file, read via `load_audio`.

    Returns
    -------
    The greedy (per-frame argmax) decoding of the first batch item, lower-cased.
    '''
    waveform = load_audio(file_name)
    features = proc(waveform, return_tensors='pt', sampling_rate=16000)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(features.input_values)

    # Greedy decoding: take the most likely token id at every frame.
    best_ids = output.logits.argmax(dim=-1)
    return proc.decode(best_ids[0]).lower()


In [4]:
# Checkpoints to compare: the public baseline and the local ./w2v2960h_lr* runs.
base = "facebook/wav2vec2-base-960h"
lr0 = './w2v2960h_lr0'
lr1e4 = './w2v2960h_lr1e4'
lr1e6 = './w2v2960h_lr1e6'
lr1e8 = "./w2v2960h_lr1e8"
lr1e16 = "./w2v2960h_lr1e16"

# Short label -> checkpoint path; the labels match the variable names below.
MODEL_PATHS = {
    "base": base,
    "lr0": lr0,
    "lr1e4": lr1e4,
    "lr1e6": lr1e6,
    "lr1e8": lr1e8,
    "lr1e16": lr1e16,
}


def load_checkpoint(path):
    '''
    Load a Wav2Vec2ForCTC model and its Wav2Vec2Processor from the same path.

    Parameters
    ----------
    path : Hugging Face Hub id or local directory passed to `from_pretrained`.

    Returns
    -------
    (model, processor) tuple.
    '''
    return Wav2Vec2ForCTC.from_pretrained(path), Wav2Vec2Processor.from_pretrained(path)


# Label -> (model, processor), loaded in MODEL_PATHS order.
models = {label: load_checkpoint(path) for label, path in MODEL_PATHS.items()}

# Individual names are kept because later cells refer to them directly.
base_model, base_proc = models["base"]
lr0_model, lr0_proc = models["lr0"]
lr1e4_model, lr1e4_proc = models["lr1e4"]
lr1e6_model, lr1e6_proc = models["lr1e6"]
lr1e8_model, lr1e8_proc = models["lr1e8"]
lr1e16_model, lr1e16_proc = models["lr1e16"]

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Evaluation metadata; only the first 30 rows are transcribed below.
df = pd.read_csv("./eval.csv").head(30)

In [7]:
# Run every model over the evaluation files; each model gets a "<label>_pred" column.
# NOTE: transcribe_audio re-reads each file from disk, so every file is loaded once per model.
prediction_runs = {
    "base": (base_model, base_proc),
    "lr0": (lr0_model, lr0_proc),
    "lr1e16": (lr1e16_model, lr1e16_proc),
    "lr1e8": (lr1e8_model, lr1e8_proc),
    "lr1e6": (lr1e6_model, lr1e6_proc),
    "lr1e4": (lr1e4_model, lr1e4_proc),
}
for label, (asr_model, asr_proc) in prediction_runs.items():
    df[f"{label}_pred"] = [
        transcribe_audio(asr_model, asr_proc, path) for path in df["file_name"]
    ]
df

Unnamed: 0,file_name,ppt_id,transcription,base_pred,lr0_pred,lr1e16_pred,lr1e8_pred,lr1e6_pred,lr1e4_pred
0,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,pe12,sand,sand,sand,sand,sand,sand,
1,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,pe12,fox,foxs,foxs,foxs,foxs,foks,
2,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,so09,chick,ho,ho,ho,ho,he,
3,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,so09,one,why,why,why,why,why,
4,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,so05,rainbow,rar yumow,rar yumow,rar yumow,rar yumow,ramow,
5,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,so05,fan,beine,beine,beine,beine,bain,
6,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,so09,rock,roc,roc,roc,roc,rock,
7,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,so09,sun,five,five,five,five,five,
8,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,so05,leg,red,red,red,red,red,
9,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,so05,sun,fine,fine,fine,fine,fine,
