In [None]:
# Mount Google Drive at /content/drive. force_remount=True is passed so the mount
# is redone even when a previous mount already exists in this runtime.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)


In [None]:
!pip install datasets transformers evaluate soundfile librosa --quiet jiwer

import os
import glob
import json
import torch
import evaluate
import soundfile as sf
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from tqdm import tqdm
import librosa
import pandas as pd


In [None]:

# Location of the audio sample folder on the mounted Drive and of its metadata file.
FOLDER = "/content/drive/My Drive/VoiceBridge_SAP_sample"
JSON_FILE = os.path.join(FOLDER, "digital_assistant_metadata.json")
# Prefer the GPU when one is available; otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Load metadata
# JSON is decoded explicitly as UTF-8 so the result does not depend on the
# runtime's locale default encoding.
with open(JSON_FILE, 'r', encoding="utf-8") as f:
    json_data = json.load(f)


# Split the metadata filenames by whether the entry has a truthy 'Ratings' value
# (a missing key and an empty value both count as "empty").
files_with_ratings = [
    record['Filename']
    for record in json_data
    if record.get('Ratings')
]
print(f"{len(files_with_ratings)} files have non-empty ratings.")

files_without_ratings = [
    record['Filename']
    for record in json_data
    if not record.get('Ratings')
]
print(f"{len(files_without_ratings)} files have empty ratings.")


# Lookup table: filename -> prompt transcript, stripped and lower-cased.
reference_dict = dict(
    (record["Filename"], record["Prompt"]["Transcript"].strip().lower())
    for record in json_data
)

# Lookup table: filename -> ratings (empty list when the entry has no "Ratings" key).
ratings_dict = dict(
    (record["Filename"], record.get("Ratings", []))
    for record in json_data
)

# All .wav files directly inside FOLDER, in sorted order.
wav_files = glob.glob(os.path.join(FOLDER, "*.wav"))
wav_files.sort()

# wav2vec2 model
model_name = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
# Move the model to DEVICE rather than a hard-coded "cuda": this keeps CPU-only
# runtimes working and puts the weights on the same device the inputs are sent to.
model = Wav2Vec2ForCTC.from_pretrained(model_name).to(DEVICE)

In [None]:
def load_audio(path, target_sr=16000):
    """Load an audio file as a mono waveform sampled at ``target_sr`` Hz.

    Args:
        path: Path of the audio file, read with ``soundfile``.
        target_sr: Desired sampling rate in Hz (default 16000).

    Returns:
        A 1-D array of samples. Multi-channel files are downmixed by averaging
        their channels; the audio is resampled only when the file's sampling
        rate differs from ``target_sr``.
    """
    speech_array, sr = sf.read(path)
    if speech_array.ndim > 1:
        # soundfile returns multi-channel audio as (frames, channels). Average the
        # channels first so resampling runs along the time axis and callers always
        # receive a single 1-D waveform.
        speech_array = speech_array.mean(axis=1)
    if sr != target_sr:
        speech_array = librosa.resample(speech_array, orig_sr=sr, target_sr=target_sr)
    return speech_array



In [None]:

@torch.inference_mode()
def batch_transcribe(wav_paths, batch_size=4):
    """Transcribe audio files with the module-level ``processor`` and ``model``.

    Each file is loaded with ``load_audio``; files are grouped into batches,
    padded by the processor, run through the model on ``DEVICE``, and decoded
    from the argmax of the logits with ``processor.batch_decode``.

    Args:
        wav_paths: Sequence of audio file paths.
        batch_size: Number of files per forward pass; must be at least 1.

    Returns:
        A list with one stripped, lower-cased transcription per input path,
        in input order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A negative step makes range() empty, which would silently return no
        # transcriptions at all; reject it explicitly instead.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    transcriptions = []
    for i in tqdm(range(0, len(wav_paths), batch_size), desc="Transcribing"):
        batch = [load_audio(p) for p in wav_paths[i:i+batch_size]]
        inputs = processor(batch, sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.to(DEVICE)
        # No inner torch.no_grad() is needed: the inference_mode decorator already
        # disables gradient tracking for the whole function.
        logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        decoded = processor.batch_decode(predicted_ids)
        transcriptions.extend([t.strip().lower() for t in decoded])
    return transcriptions


def extract_disease(ratings):
    """Return the first non-blank "Dimension Category Description" in ``ratings``.

    Args:
        ratings: Iterable of rating dicts, or ``None``.

    Returns:
        The stripped description of the first rating whose
        "Dimension Category Description" is non-blank, or ``None`` when there is
        no such rating (including when ``ratings`` is empty or ``None``).
    """
    # `ratings or ()` tolerates a JSON null in place of the ratings list.
    for rating in ratings or ():
        # `or ""` tolerates a JSON null (or a missing key) for the description.
        description = (rating.get("Dimension Category Description") or "").strip()
        if description:
            return description
    return None

# `ratings_dict` and `reference_dict` are built once in the metadata-loading cell
# and reused here instead of being rebuilt from `json_data`.

if not wav_files:
    # Fail early with a clear message; otherwise the empty result table has no
    # "WER" column and the ranking step fails with an unrelated-looking KeyError.
    raise FileNotFoundError(f"No .wav files found in {FOLDER}")

# Load the metric before the slow transcription pass so a metric-loading failure
# does not discard finished transcriptions.
metric = evaluate.load("wer")
preds = batch_transcribe(wav_files)

results = []
for wav_path, pred in zip(wav_files, preds):
    filename = os.path.basename(wav_path)
    ref = reference_dict.get(filename, "")
    # The WER metric cannot score an empty reference (jiwer raises on it), so a
    # file without a transcript gets NaN instead of aborting the whole evaluation.
    wer = metric.compute(predictions=[pred], references=[ref]) if ref else float("nan")
    disease = extract_disease(ratings_dict.get(filename, []))
    results.append({
        "Filename": filename,
        "Prediction": pred,
        "Reference": ref,
        "WER": wer,
        "Disease": disease
    })


# Collect the per-file records into a table for ranking and aggregation.
results_df = pd.DataFrame(results)
print("Finished evaluating all samples.")


# Rank files from highest to lowest WER and keep the ten worst.
report_columns = ["Filename", "WER", "Reference", "Prediction"]
ranked_df = results_df.sort_values("WER", ascending=False)
top10 = ranked_df.head(10)
print("\nTop 10 highest WER samples:")
print(top10[report_columns].to_string(index=False))


# Unweighted mean of the "WER" column (every row counts equally).
overall_wer = results_df["WER"].mean()
print(f"\nOverall average WER: {overall_wer:.3f}")
