In [1]:
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
from scipy.io import wavfile
import os
import torch
from jiwer import wer, cer
from scipy.io import wavfile
from scipy.signal import resample_poly

In [2]:
# The CTC model and its processor are both restored from the same saved
# directory, so the path is spelled out only once.
_saved_model_dir = "../savedModel"
processor = Wav2Vec2Processor.from_pretrained(_saved_model_dir)
model = Wav2Vec2ForCTC.from_pretrained(_saved_model_dir)

In [12]:
# Evaluation inputs: the .wav clips and their .txt transcripts sit in sibling
# folders under a single dataset root.
_dataset_root = "../dataset_experiment"
audio_dir = f"{_dataset_root}/audio"
text_dir = f"{_dataset_root}/text"

def load_text_labels(text_dir):
    """Read every ``.txt`` transcript in ``text_dir`` into a dictionary.

    Args:
        text_dir: Path of a directory holding one ``<file_id>.txt`` file per
            utterance. Entries without a ``.txt`` suffix are ignored.

    Returns:
        dict mapping each file name with its trailing ``.txt`` removed to the
        file's contents with leading/trailing whitespace stripped.

    Raises:
        OSError: if ``text_dir`` or one of its files cannot be read.
        UnicodeDecodeError: if a transcript is not valid UTF-8.
    """
    suffix = ".txt"
    labels = {}
    # Sorted so the result's insertion order does not depend on the filesystem.
    for filename in sorted(os.listdir(text_dir)):
        if not filename.endswith(suffix):
            continue
        # Strip only the trailing suffix; str.replace would also remove any
        # ".txt" occurring in the middle of the name.
        file_id = filename[: -len(suffix)]
        # Explicit encoding so transcripts decode identically on every platform
        # instead of following the locale default.
        with open(os.path.join(text_dir, filename), "r", encoding="utf-8") as f:
            labels[file_id] = f.read().strip()
    return labels

# Load ground truth text labels (file id -> reference transcript)
text_labels = load_text_labels(text_dir)

# Parallel lists: predictions[i] is the hypothesis for ground_truths[i].
# They must stay the same length for the corpus-level metrics below.
predictions = []
ground_truths = []

# Choose the device and move the model once instead of on every iteration.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Sampling rate the processor's feature extractor is configured for; every
# clip is converted to this rate before being handed to the processor.
target_rate = processor.feature_extractor.sampling_rate


def _load_audio(path, target_rate):
    """Read a WAV file as mono float32 samples at ``target_rate`` Hz.

    Args:
        path: Path of the ``.wav`` file.
        target_rate: Desired sampling rate in Hz.

    Returns:
        1-D float32 array. Sample amplitudes keep the scale stored in the file
        (no rescaling to [-1, 1] is applied).
    """
    source_rate, samples = wavfile.read(path)
    samples = samples.astype("float32")
    # Multi-channel files come back as (n_samples, n_channels); average the
    # channels so the processor receives a single waveform.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    # Resample instead of merely declaring the target rate to the processor.
    if source_rate != target_rate:
        samples = resample_poly(samples, target_rate, source_rate).astype("float32")
    return samples


# Process each audio file in the directory (sorted for reproducible output)
for audio_file in sorted(os.listdir(audio_dir)):
    if not audio_file.endswith(".wav"):
        continue
    file_id = audio_file[: -len(".wav")]

    # Skip unlabeled clips *before* inference so that nothing is appended to
    # `predictions` without a matching entry in `ground_truths`.
    reference = text_labels.get(file_id)
    if reference is None:
        print(f"Warning: No label found for {file_id}. Skipping.")
        continue

    # Read and preprocess audio
    audio = _load_audio(os.path.join(audio_dir, audio_file), target_rate)
    input_values = processor(audio, sampling_rate=target_rate, return_tensors="pt").input_values
    input_values = input_values.float().to(device)

    # Perform inference
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy decoding: most likely token per frame, then decode to text
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]

    predictions.append(transcription)
    ground_truths.append(reference)

    # Display per-file results
    print(f"File ID: {file_id} (CER: {cer(reference, transcription)}, WER: {wer(reference, transcription)})")
    print(f"    Predicted: {transcription}")
    print(f"    Ground Truth: {reference}\n")

# Calculate corpus-level CER and WER over all labeled clips
if ground_truths:
    cer_score = cer(ground_truths, predictions)
    wer_score = wer(ground_truths, predictions)
else:
    # Nothing was evaluated; report NaN rather than calling the metrics on
    # empty input.
    cer_score = wer_score = float("nan")
    print("Warning: No labeled audio files were evaluated.")

# Display results
print("\nEvaluation Results:")
print(f"Overall Character Error Rate (CER): {cer_score}")
print(f"Overall Word Error Rate (WER): {wer_score}")

File ID: Faris_1 (CER: 0.16666666666666666, WER: 0.6666666666666666)
    Predicted: saiya orang pantung
    Ground Truth: saya orang bandung

File ID: Faris_2 (CER: 0.16666666666666666, WER: 0.6666666666666666)
    Predicted: iya sudah ngantu
    Ground Truth: saya sudah ngantuk

File ID: Faris_3 (CER: 0.1111111111111111, WER: 0.3333333333333333)
    Predicted: pari sedang lapar
    Ground Truth: faris sedang lapar

File ID: Faris_4 (CER: 0.23809523809523808, WER: 1.0)
    Predicted: sa iya sedang perjaran jalu andi kutapandung
    Ground Truth: saya sedang berjalan jalan di kota bandung

File ID: Jason_1 (CER: 0.5, WER: 2.0)
    Predicted: a ku si ya
    Ground Truth: aku siap

File ID: Jason_2 (CER: 0.125, WER: 0.6666666666666666)
    Predicted: aku u sing nubes
    Ground Truth: aku pusing nubes

File ID: Jason_3 (CER: 0.16666666666666666, WER: 0.5)
    Predicted: didu ena kali ya
    Ground Truth: tidur enak kali ya

File ID: Louis_1 (CER: 0.16666666666666666, WER: 1.0)
    Predict