In [None]:
%pip install transformers datasets jiwer accelerate torchaudio librosa phonemizer editdistance

In [None]:
import editdistance
from g2p_en import G2p

# Build the grapheme-to-phoneme converter once; the g2p-based phoneme_sequence()
# defined in this cell calls this shared instance.
g2p = G2p()

def phoneme_sequence(text):
    """Convert text into a list of ARPAbet phonemes using the module-level ``g2p``.

    g2p_en emits vowels with a trailing stress digit (e.g. ``AH0``, ``OW1``)
    and also yields word separators and punctuation tokens. The stress digit
    is stripped before the alphabetic check; testing the raw token with
    ``str.isalpha()`` would reject every vowel and leave only consonants.

    Args:
        text: Text to convert.

    Returns:
        list[str]: Stress-free phoneme symbols in order; separators and
        punctuation tokens are dropped.
    """
    # "AH0" -> "AH"; tokens such as " " or "," are unchanged and filtered below.
    unstressed = (p.rstrip("012") for p in g2p(text))
    return [p for p in unstressed if p.isalpha()]

def phoneme_error_rate(ref, hyp):
    """Compute the phoneme error rate (PER) of ``hyp`` against ``ref``.

    Both texts are converted with ``phoneme_sequence`` and compared with
    ``editdistance.eval`` on the phoneme lists.

    Args:
        ref: Reference text.
        hyp: Hypothesis (predicted) text.

    Returns:
        tuple: ``(per, ref_ph, hyp_ph, dist)`` where ``per`` is
        ``dist / len(ref_ph)`` and ``dist`` is the edit distance between the
        two phoneme lists. When the reference yields no phonemes, ``per`` and
        ``dist`` are ``None``; the phoneme lists are still returned.
    """
    ref_ph = phoneme_sequence(ref)
    hyp_ph = phoneme_sequence(hyp)

    if not ref_ph:
        # PER is undefined without reference phonemes. Keep the hypothesis
        # phonemes so callers can still report their length and content.
        return None, ref_ph, hyp_ph, None

    dist = editdistance.eval(ref_ph, hyp_ph)
    per = dist / len(ref_ph)

    return per, ref_ph, hyp_ph, dist


In [None]:
import os, glob, json
import torch
import librosa
import soundfile as sf
import pandas as pd
from tqdm import tqdm
from phonemizer import phonemize
from jiwer import cer
import numpy as np


# ------------------------
#  PHONEME ERROR RATE
# ------------------------
def phoneme_sequence(text):
    """Convert text -> list of phoneme characters (phonemizer, espeak, en-us).

    The phonemizer output has its spaces removed and is then split with
    ``list()``, so the result holds one element per character of the
    phonemized string.

    Args:
        text: Text to phonemize. Empty, whitespace-only or ``None`` input
            yields an empty list without calling the backend.

    Returns:
        list[str]: The phoneme characters, or ``[]`` if the input is empty or
        phonemization fails. A failure emits a ``RuntimeWarning`` rather than
        being silently discarded.
    """
    import warnings

    try:
        # Nothing to phonemize; skip the backend call.
        if not text or not text.strip():
            return []

        ph = phonemize(
            text,
            language="en-us",
            backend="espeak",
            strip=True,
            preserve_punctuation=False,
            with_stress=False
        )
        return list(ph.replace(" ", ""))
    except Exception as exc:
        # Best-effort: keep the run going, but surface the problem (e.g. a
        # missing espeak install). Catching Exception rather than everything
        # lets KeyboardInterrupt / SystemExit propagate.
        warnings.warn(
            f"phonemize failed for {text!r}: {exc}",
            RuntimeWarning,
            stacklevel=2,
        )
        return []


def phoneme_error_rate(ref_text, pred_text):
    """Compute PER and extra phoneme stats.

    Both texts are converted with ``phoneme_sequence``, joined back into
    strings, and scored with jiwer's ``cer``.

    Args:
        ref_text: Reference transcript.
        pred_text: Predicted transcript.

    Returns:
        tuple: ``(per, ref_ph, pred_ph, edit_distance)``. ``per`` is the value
        returned by ``cer``; ``edit_distance`` is the absolute edit count
        recovered from it. When the reference yields no phonemes, ``per`` and
        ``edit_distance`` are ``None``; the phoneme lists are still returned.
    """
    ref_ph = phoneme_sequence(ref_text)
    pred_ph = phoneme_sequence(pred_text)

    if len(ref_ph) == 0:
        return None, ref_ph, pred_ph, None

    ref_str = "".join(ref_ph)
    pred_str = "".join(pred_ph)

    # CER works as Levenshtein distance on strings
    per = cer(ref_str, pred_str)

    # Recover the absolute edit count from the rate. round() rather than int():
    # the product can land just below the true integer (e.g. 2.9999999996),
    # and truncation would then under-report the distance by one.
    edit_distance = round(per * len(ref_str))

    return per, ref_ph, pred_ph, edit_distance


# ------------------------
#  MAIN METRIC FUNCTION
# ------------------------
def transcribe_whisper_and_compute_PER(
    main_folder,
    model_name="openai/whisper-small",
    device=None,
    batch_size=4
):
    """Transcribe every speaker folder with Whisper and score each file by PER.

    Each immediate sub-directory of ``main_folder`` is treated as one speaker
    folder. A folder is processed only if it contains at least one ``*.json``
    file (the first in sorted order is used) and at least one ``*.wav`` file.
    The JSON must hold a ``"Files"`` list whose items provide ``"Filename"``
    and ``"Prompt" -> "Transcript"``. Transcripts and Whisper predictions are
    stripped and lower-cased before scoring. WAV files with no matching JSON
    entry are scored against an empty reference.

    Global statistics, a per-folder summary and the 15 highest-PER samples are
    printed before returning.

    Args:
        main_folder: Directory containing one sub-directory per speaker.
        model_name: Hugging Face model id passed to ``from_pretrained``.
        device: Torch device string; defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.
        batch_size: Number of audio files per ``model.generate`` call.

    Returns:
        pandas.DataFrame: One row per WAV file with the columns ``Folder``,
        ``Filename``, ``Reference``, ``Prediction``, ``PER``,
        ``PhonemeAccuracy``, ``EditDistance``, ``RefLength``, ``PredLength``,
        ``Rating``, ``RefPhonemes`` and ``PredPhonemes``. The frame is empty
        (but still has these columns) when nothing was processed.

    Raises:
        OSError: If ``main_folder`` cannot be listed.
    """
    result_columns = [
        "Folder", "Filename", "Reference", "Prediction", "PER",
        "PhonemeAccuracy", "EditDistance", "RefLength", "PredLength",
        "Rating", "RefPhonemes", "PredPhonemes",
    ]

    # Discover the work first so an empty input does not pay for a model load.
    # Sorted for a reproducible row order (os.listdir order is arbitrary).
    subfolders = sorted(
        os.path.join(main_folder, d)
        for d in os.listdir(main_folder)
        if os.path.isdir(os.path.join(main_folder, d))
    )

    print(f"Found {len(subfolders)} speaker folders.")
    if not subfolders:
        return pd.DataFrame(columns=result_columns)

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    from transformers import WhisperProcessor, WhisperForConditionalGeneration

    print("Loading Whisper model:", model_name)
    processor = WhisperProcessor.from_pretrained(model_name)
    model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)
    model.eval()

    all_results = []

    def load_audio(path, target_sr=16000):
        """Read a WAV file as mono float32 audio at ``target_sr`` Hz."""
        audio, sr = sf.read(path, dtype="float32")
        # soundfile returns (frames, channels) for multi-channel files; the
        # feature extractor and librosa.resample (last axis) expect 1-D mono.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)
        if sr != target_sr:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
        return audio

    @torch.inference_mode()
    def batch_transcribe(wav_paths):
        """Transcribe the files in batches; returns stripped, lower-cased text."""
        outputs = []
        for i in tqdm(range(0, len(wav_paths), batch_size)):
            batch = [load_audio(p) for p in wav_paths[i:i+batch_size]]
            feats = processor(batch, sampling_rate=16000, return_tensors="pt").input_features.to(device)
            pred_ids = model.generate(input_features=feats)
            decoded = processor.batch_decode(pred_ids, skip_special_tokens=True)
            outputs.extend([t.strip().lower() for t in decoded])
        return outputs

    def classify(per):
        """Map a PER value to a coarse rating label."""
        if per is None:
            return "No Reference"
        if per < 0.20:
            return "Excellent"
        if per < 0.40:
            return "Good"
        if per < 0.60:
            return "Fair"
        return "Poor"

    # Loop each speaker folder
    for folder in subfolders:
        print(f"\nProcessing folder: {folder}")

        # Sorted so the choice of JSON is deterministic when several exist.
        json_files = sorted(glob.glob(os.path.join(folder, "*.json")))
        if not json_files:
            print("No JSON file found. Skipping.")
            continue

        with open(json_files[0], "r", encoding="utf-8") as f:
            data = json.load(f)

        reference_dict = {
            item["Filename"]: item["Prompt"]["Transcript"].strip().lower()
            for item in data["Files"]
        }

        wav_files = sorted(glob.glob(os.path.join(folder, "*.wav")))
        if not wav_files:
            print("No WAV files found. Skipping.")
            continue

        preds = batch_transcribe(wav_files)

        # Results
        for wav_path, pred in zip(wav_files, preds):
            fname = os.path.basename(wav_path)
            ref = reference_dict.get(fname, "")

            per, ref_ph, pred_ph, edit_dist = phoneme_error_rate(ref, pred)
            phoneme_acc = None if per is None else (1 - per)

            all_results.append({
                "Folder": os.path.basename(folder),
                "Filename": fname,
                "Reference": ref,
                "Prediction": pred,
                "PER": per,
                "PhonemeAccuracy": phoneme_acc,
                "EditDistance": edit_dist,
                "RefLength": len(ref_ph),
                "PredLength": len(pred_ph),
                "Rating": classify(per),
                "RefPhonemes": " ".join(ref_ph),
                "PredPhonemes": " ".join(pred_ph)
            })

    # ------------------------
    # FINAL DATAFRAME
    # ------------------------
    df = pd.DataFrame(all_results, columns=result_columns)

    if df.empty:
        # Every folder was skipped; there is nothing to summarise.
        print("No samples were processed; skipping statistics.")
        return df

    # If every PER is None the column would be object-typed; force numeric
    # (None -> NaN) so the aggregations below behave uniformly.
    df["PER"] = pd.to_numeric(df["PER"], errors="coerce")

    # ------------------------
    # GLOBAL STATS
    # ------------------------
    print("\n================ GLOBAL PER STATISTICS ================\n")
    print(df["PER"].describe())
    print(f"\nMedian PER: {df['PER'].median():.3f}")
    print(f"Best PER (min): {df['PER'].min():.3f}")
    print(f"Worst PER (max): {df['PER'].max():.3f}")

    # ------------------------
    # FOLDER SUMMARY
    # ------------------------
    print("\n================ PER-FOLDER SUMMARY =================\n")
    folder_stats = df.groupby("Folder").agg(
        Mean_PER=("PER", "mean"),
        Median_PER=("PER", "median"),
        Std_PER=("PER", "std"),
        Min_PER=("PER", "min"),
        Max_PER=("PER", "max"),
        Samples=("PER", "count")
    )
    print(folder_stats)

    # ------------------------
    # WORST SAMPLES
    # ------------------------
    print("\n================ WORST SAMPLES (HIGH PER) =================\n")
    print(df.sort_values("PER", ascending=False).head(15).to_string(index=False))

    return df


In [None]:
# Run the evaluation over every speaker folder under ./test_data using the
# function's default model, device and batch size.
df = transcribe_whisper_and_compute_PER("./test_data")


In [None]:
# Save the per-file results to CSV, omitting the DataFrame index column.
df.to_csv("full_results.csv", index=False)