In [None]:
import torch
from Levenshtein import distance as levenshtein_distance

def greedy_ctc_decode(log_probs, phoneme_vocab, prob=True, blank_id=0):
    """
    Greedy CTC decoding: take the best label per frame, merge consecutive
    duplicates, and drop blanks.

    Args:
        log_probs (Tensor): When ``prob`` is True, per-frame scores of shape
            [T, P]; the arg-max over the last dimension gives the frame labels.
            When ``prob`` is False, the tensor is used directly as the
            per-frame label IDs.
        phoneme_vocab (List[str]): ID-to-phoneme mapping
        prob (bool): Whether ``log_probs`` holds scores (True) or label IDs (False)
        blank_id (int): Index of the CTC blank token

    Returns:
        List[str]: Collapsed phoneme sequence
    """
    frame_ids = (torch.argmax(log_probs, dim=-1) if prob else log_probs).tolist()

    decoded = []
    for pos, label in enumerate(frame_ids):
        # Only the immediately preceding frame matters: a blank between two
        # identical labels keeps them as two separate phonemes.
        repeats_previous = pos > 0 and frame_ids[pos - 1] == label
        if label == blank_id or repeats_previous:
            continue
        decoded.append(phoneme_vocab[label])
    return decoded


def phoneme_error_rate(log_probs_ref, log_probs_hyp, phoneme_vocab, blank_id=0):
    """
    Phoneme error rate (PER) of one hypothesis against one reference.

    Both inputs are collapsed with greedy_ctc_decode; the edit distance is then
    measured between the two phoneme sequences and divided by the reference
    length.

    Args:
        log_probs_ref (Tensor): Frame-level reference phoneme IDs. Decoded with
            prob=False, so the values are used as IDs directly (no arg-max),
            despite the parameter name.
        log_probs_hyp (Tensor): [T, P] frame-level scores over phoneme classes
        phoneme_vocab (List[str]): ID-to-phoneme mapping
        blank_id (int): Index of the CTC blank token

    Returns:
        tuple: (per, dist, N, ref_seq, hyp_seq) where
            per (float): dist / N, or dist itself when the reference is empty
            dist (int): Phoneme-level Levenshtein distance
            N (int): Number of reference phonemes
            ref_seq (List[str]): Decoded reference sequence
            hyp_seq (List[str]): Decoded hypothesis sequence
    """
    ref_seq = greedy_ctc_decode(log_probs_ref, phoneme_vocab, prob=False, blank_id=blank_id)
    hyp_seq = greedy_ctc_decode(log_probs_hyp, phoneme_vocab, prob=True, blank_id=blank_id)

    # Compare the token lists, not a space-joined string: on a string the
    # distance counts character edits (so 'ah' -> 'b' costs 2 and separators
    # count too), which does not match the per-phoneme denominator below.
    # NOTE(review): passing lists needs a Levenshtein release whose distance()
    # accepts sequences of hashables - confirm the installed version.
    dist = levenshtein_distance(ref_seq, hyp_seq)
    N = len(ref_seq)
    per = dist / (N if N > 0 else 1)
    return per, dist, N, ref_seq, hyp_seq


In [None]:
# Smoke test of phoneme_error_rate on random inputs. No RNG seed is set in this
# cell, so the printed numbers differ from run to run.
phoneme_vocab = ['_', 'ah', 'b', 'd', 'eh', 'f', 'g', 'k', 's', 't']  # example vocabulary; index 0 ('_') matches the default blank_id
log_probs_ref = torch.randint(0, len(phoneme_vocab), (50,))  # [T] integer phoneme IDs per frame (not log-probs, despite the name)
log_probs_hyp = torch.randn(50, len(phoneme_vocab))  # [T, P] random scores standing in for model output

# r_seq / h_seq receive the decoded reference / hypothesis phoneme lists.
per, dist, N, r_seq, h_seq = phoneme_error_rate(log_probs_ref, log_probs_hyp, phoneme_vocab)
print(f"PER: {per:.3f} ({dist} edits over {N} phonemes)")


PER: 1.111 (40 edits over 36 phonemes)


In [29]:
from Levenshtein import distance as levenshtein_distance

class PhonemeErrorRate:
    """
    Accumulates a micro-averaged phoneme error rate (PER) across samples.

    Call add_batch() for each evaluation batch, then compute() to obtain the
    total edit count divided by the total number of reference phonemes.
    """

    def __init__(self, phoneme_vocab, blank_id=0):
        """
        Args:
            phoneme_vocab (List[str]): ID-to-phoneme mapping
            blank_id (int): ID of the CTC blank token
        """
        self.phoneme_vocab = phoneme_vocab
        self.blank_id = blank_id
        # Running totals over every sample passed to add_batch().
        self.total_edits = 0
        self.total_ref_phonemes = 0
        # Decoded sequences, one entry per sample, kept for inspection.
        self.decoded_refs = []
        self.decoded_hyps = []

    def greedy_ctc_decode(self, log_probs, prob=True):
        """
        Greedy CTC decoding to a phoneme sequence: merge consecutive duplicate
        labels, then drop blanks.

        Args:
            log_probs (Tensor): If ``prob`` is True, per-frame scores of shape
                [T, V] (arg-max over the last dimension gives the labels);
                if False, the per-frame label IDs themselves.
            prob (bool): If True, assumes input is model output (hyp);
                         If False, assumes manual or clean target IDs (ref)

        Returns:
            List[str]: Decoded phoneme sequence
        """
        import torch
        if prob:
            pred_ids = torch.argmax(log_probs, dim=-1).tolist()
        else:
            pred_ids = log_probs.tolist()

        seq = []
        prev = None
        for idx in pred_ids:
            if idx != self.blank_id and idx != prev:
                seq.append(self.phoneme_vocab[idx])
            # `prev` must be updated on blanks too: a blank between two equal
            # labels marks a genuine repetition ("a _ a" -> "a a"), so the
            # second label must not be merged into the first.
            prev = idx
        return seq

    def add_batch(self, log_probs_ref_batch, log_probs_hyp_batch):
        """
        Add a batch of phoneme predictions and references.

        Samples are paired positionally; the two batches are iterated together
        with zip().

        Args:
            log_probs_ref_batch: Iterable of per-sample reference tensors holding
                frame-level phoneme IDs (decoded with prob=False, i.e. no arg-max)
            log_probs_hyp_batch: Iterable of per-sample hypothesis tensors [T_i, V]
                of frame-level scores
        """
        for log_probs_ref, log_probs_hyp in zip(log_probs_ref_batch, log_probs_hyp_batch):
            ref_seq = self.greedy_ctc_decode(log_probs_ref, prob=False)
            hyp_seq = self.greedy_ctc_decode(log_probs_hyp, prob=True)

            self.decoded_refs.append(ref_seq)
            self.decoded_hyps.append(hyp_seq)

            # Distance over phoneme tokens, not over a space-joined string:
            # on a string it would count character edits, which does not match
            # the per-phoneme denominator accumulated below.
            # NOTE(review): passing lists needs a Levenshtein release whose
            # distance() accepts sequences of hashables - confirm the version.
            dist = levenshtein_distance(ref_seq, hyp_seq)

            self.total_edits += dist
            self.total_ref_phonemes += len(ref_seq)

    def compute(self):
        """
        Returns:
            per (float): Micro-average PER over all batches (total edits divided
                by total reference phonemes; the divisor is clamped to 1 so an
                empty accumulator does not divide by zero)
            total_edits (int): Total Levenshtein distance
            total_ref_phonemes (int): Total number of reference phonemes
        """
        per = self.total_edits / max(self.total_ref_phonemes, 1)
        return per, self.total_edits, self.total_ref_phonemes


In [33]:
# Smoke test of PhonemeErrorRate on a random batch of 4 samples. `torch` and
# `phoneme_vocab` are not defined in this cell and must already exist in the
# kernel. No RNG seed is set here, so the printed numbers differ between runs.
log_probs_ref_batch = torch.randint(0, len(phoneme_vocab), (4, 50,))  # [B, T] integer phoneme IDs per frame (not log-probs, despite the name)
log_probs_hyp_batch = torch.randn(4, 50, len(phoneme_vocab))  # [B, T, P] random scores standing in for model output

# phoneme_vocab is the ID-to-symbol list; blank_id is the index of the blank token in it
per_metric = PhonemeErrorRate(phoneme_vocab, blank_id=0)

# For each batch during eval (iterating the batch tensors yields one sample per row):
per_metric.add_batch(log_probs_ref_batch, log_probs_hyp_batch)

# At the end:
per, edits, total = per_metric.compute()
print(f"Micro PER: {per:.3f} ({edits} edits over {total} phonemes)")


Micro PER: 1.112 (179 edits over 161 phonemes)
