# Evaluation

This notebook runs ASR on the models.

In [53]:
# Import packages
import pandas as pd
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import librosa
import os

In [None]:
def load_audio(file_path):
    '''
    Load audio file from filepath.
    '''
    audio, _ = librosa.load(file_path, sr=16000)
    return audio


def transcribe_audio(model, proc, file_name):
    '''
    Transcribes audio using defined model.
    '''
    audio = load_audio(file_name)
    
    input_vals = proc(
        audio, 
        return_tensors = 'pt',
        sampling_rate = 16000
    ).input_values 

    with torch.no_grad():
        logits = model(input_vals).logits 
        
    predicted_ids = torch.argmax(logits, dim = -1)
    predicted_word = proc.decode(predicted_ids[0])
    
    return predicted_word.lower()


def load_model(model_path):
    '''
    Loads model from local model filepath.
    '''
    model = Wav2Vec2ForCTC.from_pretrained(model_path)
    proc = Wav2Vec2Processor.from_pretrained(model_path)

    return model, proc


def transcribe_dataset(df, models):
    '''
    Transcribes audio files and adds a column for predictions of each model.

    Inputs: 
    -------
    df : pandas DataFrame
    models : dictionary of models

    '''
    for i in models.keys():
        model_name = i
        model_path = models[i]

        model, proc = load_model(model_path)

        column_name = model_name + '_pred'
    
        df[column_name] = df["file_name"].apply(lambda x: transcribe_audio(model, proc, x))

In [65]:
kids_df = pd.read_csv("./eval.csv")
adults_df = pd.read_csv("./adults_eval.csv")

In [61]:
models = {
    "base" :"facebook/wav2vec2-base-960h",
    "lr0" : './w2v2960h_lr0',
    "lr1e4" : './w2v2960h_lr1e4',
    "lr1e6" : './w2v2960h_lr1e6',
    "lr1e8" : "./w2v2960h_lr1e8",
    "lr1e16" : "./w2v2960h_lr1e16"
}

In [25]:
transcribe_dataset(kids_df, models)
kids_df

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Unnamed: 0,file_name,ppt_id,transcription,base_pred,lr0_pred,lr1e4_pred,lr1e6_pred,lr1e8_pred,lr1e16_pred
0,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,pe12,sand,sand,sand,,sand,sand,sand
1,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,pe12,fox,foxs,foxs,,foks,foxs,foxs
2,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,so09,chick,ho,ho,,he,ho,ho
3,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,so09,one,why,why,,why,why,why
4,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,so05,rainbow,rar yumow,rar yumow,,ramow,rar yumow,rar yumow
...,...,...,...,...,...,...,...,...,...
394,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,so10,wing,wing,wing,,wing,wing,wing
395,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,dd08,socks,chok,chok,,cholk,chok,chok
396,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,so10,rock,ra,ra,,ra,ra,ra
397,/home/cogsci-lasrlab/Documents/CSS_Capstone/KT...,lt33,wing,wig,wig,,wig,wig,wig


In [26]:
kids_df.to_csv('kids_eval.csv', index=False)

In [66]:
transcribe_dataset(adults_df, models)
adults_df

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Unnamed: 0,file_name,transcription,base_pred,lr0_pred,lr1e4_pred,lr1e6_pred,lr1e8_pred,lr1e16_pred
0,/home/cogsci-lasrlab/Documents/CSS_Capstone/re...,one,one,one,,one,one,one
1,/home/cogsci-lasrlab/Documents/CSS_Capstone/re...,sick,sick,sick,,sick,sick,sick
2,/home/cogsci-lasrlab/Documents/CSS_Capstone/re...,fox,fox,fox,,fox,fox,fox
3,/home/cogsci-lasrlab/Documents/CSS_Capstone/re...,chip,chip,chip,,chep,chip,chip
4,/home/cogsci-lasrlab/Documents/CSS_Capstone/re...,sand,sand,sand,,sand,sand,sand
...,...,...,...,...,...,...,...,...
395,/home/cogsci-lasrlab/Documents/CSS_Capstone/re...,fox,fox,fox,,fox,fox,fox
396,/home/cogsci-lasrlab/Documents/CSS_Capstone/re...,chip,chip,chip,,,chip,chip
397,/home/cogsci-lasrlab/Documents/CSS_Capstone/re...,run,ran,ran,,ran,ran,ran
398,/home/cogsci-lasrlab/Documents/CSS_Capstone/re...,zoo,zoo,zoo,,zoo,zoo,zoo


In [67]:
adults_df.to_csv("adults_eval.csv", index=False)

In [None]:
num_correct = 0

for i in adults_df.iterrows():
    target = adults_df['transcription']
    pred = adults_df['lr1e6_pred']
    if target == pred:
        num_correct += 1



KeyError: 'key of type tuple not found and not a MultiIndex'

In [None]:
# dir = "/home/cogsci-lasrlab/Documents/CSS_Capstone/researcher"

# files = []
# words = []

# for file in os.listdir(dir):
#     path = dir+"/"+file
#     files.append(path)

#     if file.endswith('.wav') and "researcher_" in file:
#         base = file[:-4]  # remove '.wav'
#         before, text = base.split("researcher_")

#         words.append(text)

# adults_df = pd.DataFrame({
#     "file_name":files,
#     "transcription":words
# })

# sample_adults_df = adults_df.sample(n=400, random_state=123)
# sample_adults_df

# sample_adults_df.to_csv("adults_eval.csv", index=False)

Unnamed: 0,file_name,transcription
676,/home/cogsci-lasrlab/Documents/CSS_Capstone/re...,one
1226,/home/cogsci-lasrlab/Documents/CSS_Capstone/re...,sick
437,/home/cogsci-lasrlab/Documents/CSS_Capstone/re...,fox
1371,/home/cogsci-lasrlab/Documents/CSS_Capstone/re...,chip
1659,/home/cogsci-lasrlab/Documents/CSS_Capstone/re...,sand
...,...,...
927,/home/cogsci-lasrlab/Documents/CSS_Capstone/re...,fox
1445,/home/cogsci-lasrlab/Documents/CSS_Capstone/re...,chip
429,/home/cogsci-lasrlab/Documents/CSS_Capstone/re...,run
116,/home/cogsci-lasrlab/Documents/CSS_Capstone/re...,zoo
