In [1]:
"""
g2p_nettalk_zero_delay.py

Nettalk-style grapheme-to-phoneme (G2P) training with:
- Zero-delay alignment (pad phoneme seq with φ so output length == grapheme length)
- ARPAbet tokens (optionally strip stress digits)
- Fixed context window size = 7 (3 left, center, 3 right)
- PyTorch feed-forward neural network (embedding -> hidden -> softmax)

Usage:
    python g2p_nettalk_zero_delay.py

Requires:
    - Python 3.8+
    - PyTorch
    - tqdm (optional)
"""

import re
import random
import argparse
from collections import Counter
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# ---------------------------
# Config / hyperparameters
# ---------------------------
WINDOW = 7               # Nettalk context window (odd)
LEFT_CTX = RIGHT_CTX = (WINDOW - 1) // 2
EMBED_DIM = 32
HIDDEN_DIM = 256
BATCH_SIZE = 256
LR = 1e-3
EPOCHS = 12
MAX_WORDS = None         # optionally limit number of words loaded for quick debug (None -> all)
STRIP_STRESS = True      # strip numeric stress markers from ARPAbet phones (AH0 -> AH)
PAD_GRAP = '<pad_g>'     # grapheme pad for context window
PAD_PHON = 'φ'           # zero-delay padding symbol for phoneme sequence (phi)
START_SYM = '<s>'
END_SYM = '</s>'

CMU_DICT_PATH = 'cmudict.dict.txt'

# ---------------------------
# Utilities: load CMUDict
# ---------------------------
def load_cmudict(path: str, max_words=None) -> List[Tuple[str, List[str]]]:
    """
    Returns list of tuples (word, phoneme_list)
    Strips CMUDict comment lines beginning with ';;;'
    Strips variant numbers from words like 'WORD(1)' -> 'WORD'
    """
    entries = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith(';;;'):
                continue
            parts = line.split()
            word_raw = parts[0]
            # remove variant numbering (e.g., WORD(1) -> WORD)
            word = re.sub(r'\(\d+\)$', '', word_raw).lower()
            phones = parts[1:]
            if STRIP_STRESS:
                phones = [re.sub(r'\d+$', '', p) for p in phones]
            entries.append((word, phones))
            if max_words and len(entries) >= max_words:
                break
    return entries

# ---------------------------
# Zero-delay alignment
# ---------------------------
def zero_delay_align(graphemes: List[str], phonemes: List[str], pad_phi: str = PAD_PHON) -> List[str]:
    """
    Implements zero-delay alignment: produce an output phoneme sequence
    whose length equals number of graphemes by padding phoneme sequence
    at the end with pad_phi. If there are more phonemes than graphemes,
    we truncate (rare).
    Example:
        graphemes = ['g','o','o','g','l','e']
        phonemes  = ['g','u','g','@','l']  -> pad with φ -> 6 outputs
        result    = ['g','u','g','@','l','φ']
    """
    if len(phonemes) <= len(graphemes):
        out = phonemes + [pad_phi] * (len(graphemes) - len(phonemes))
    else:
        # truncate extra phonemes (rare); alternative: collapse multiple phonemes per grapheme with more complex alignment
        out = phonemes[:len(graphemes)]
    return out

# ---------------------------
# Dataset creation (per-glyph training examples)
# ---------------------------
class G2PDataset(Dataset):
    def __init__(self, word_phone_pairs, char2idx, phone2idx, window=WINDOW):
        self.window = window
        self.left = (window - 1) // 2
        self.char2idx = char2idx
        self.phone2idx = phone2idx
        self.examples = []  # list of (context_indices_tensor, target_phone_idx, word_idx, position_in_word)
        for w_idx, (word, phones) in enumerate(word_phone_pairs):
            grap = list(word)  # characters
            # optionally add start/end markers to grapheme seq so net can see boundaries
            grap_padded = [START_SYM] * self.left + grap + [END_SYM] * self.left
            # zero-delay align phonemes to graphemes (no start/end) -> length equals len(grap)
            aligned_phones = zero_delay_align(graphemes=grap, phonemes=phones, pad_phi=PAD_PHON)
            # for each position in the grapheme sequence, create a window and target
            for pos in range(len(grap)):
                # window centered at pos within original grap (not including padded START/END used above)
                left_idx = pos
                window_chars = grap_padded[left_idx:left_idx + self.window]
                # convert to indices
                context_idxs = [char2idx.get(ch, char2idx[PAD_GRAP]) for ch in window_chars]
                target_phone = aligned_phones[pos]
                target_idx = phone2idx[target_phone]
                # store (list of ints) and target
                self.examples.append((context_idxs, target_idx, w_idx, pos))
    def __len__(self):
        return len(self.examples)
    def __getitem__(self, idx):
        ctx, tgt, widx, pos = self.examples[idx]
        return torch.tensor(ctx, dtype=torch.long), torch.tensor(tgt, dtype=torch.long), widx, pos

# ---------------------------
# Model
# ---------------------------
class NettalkModel(nn.Module):
    def __init__(self, n_chars, embed_dim, window, hidden_dim, n_phones, dropout=0.1):
        super().__init__()
        self.window = window
        self.embed = nn.Embedding(n_chars, embed_dim, padding_idx=None)
        self.fc1 = nn.Linear(embed_dim * window, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, n_phones)
    def forward(self, x):  # x: (B, window) long
        emb = self.embed(x)              # (B, window, embed_dim)
        flat = emb.view(emb.size(0), -1) # (B, window*embed_dim)
        h = self.dropout(self.relu(self.fc1(flat)))
        out = self.fc2(h)                # logits (B, n_phones)
        return out

# ---------------------------
# Helpers: vocab build, evaluation
# ---------------------------
def build_vocabs(entries):
    # chars: lowercase letters a-z, apostrophe, and special PAD/START/END
    char_counter = Counter()
    phone_counter = Counter()
    for word, phones in entries:
        for ch in word:
            char_counter[ch] += 1
        for p in phones:
            phone_counter[p] += 1
    # char->idx (reserve small indices for PAD, start/end)
    chars = [PAD_GRAP, START_SYM, END_SYM] + sorted([c for c in char_counter])
    char2idx = {c: i for i, c in enumerate(chars)}
    idx2char = {i: c for c, i in char2idx.items()}
    # phones: include PAD_PHON (phi) if not present
    phones = sorted(set(list(phone_counter.keys()) + [PAD_PHON]))
    phone2idx = {p: i for i, p in enumerate(phones)}
    idx2phone = {i: p for p, i in phone2idx.items()}
    return char2idx, idx2char, phone2idx, idx2phone

def compute_word_level_errors(model, dataset, idx2phone, entries, device):
    """
    Reconstruct predicted phoneme sequence per word (by collecting per-position predictions),
    then compute simple PER and WER (word considered error if any phoneme differs or a phone missing).
    This treats the phi pad as normal symbol: when the true phone is phi it means no real phone at that grapheme.
    Returns: (phoneme_error_rate, word_error_rate)
    """
    model.eval()
    # prepare containers
    word_preds = dict()  # widx -> list of preds per position (length = word length)
    word_trues = dict()
    # initialize
    for i, (w, phones) in enumerate(entries):
        word_trues[i] = zero_delay_align(list(w), phones, pad_phi=PAD_PHON)
        word_preds[i] = [''] * len(w)
    loader = DataLoader(dataset, batch_size=1024, shuffle=False)
    with torch.no_grad():
        for batch in loader:
            ctx, tgt, widxs, poss = batch
            ctx = ctx.to(device)
            logits = model(ctx)
            preds = logits.argmax(dim=1).cpu().tolist()
            for p_idx, widx, pos in zip(preds, widxs.tolist(), poss.tolist()):
                phone_pred = idx2phone[p_idx]
                word_preds[widx][pos] = phone_pred
    # compute PER and WER
    total_phonemes = 0
    phone_errors = 0
    total_words = len(entries)
    word_errors = 0
    for i in range(len(entries)):
        true_seq = word_trues[i]
        pred_seq = word_preds[i]
        # phoneme-level edit distance (simple Levenshtein)
        # compute substitutions/insertions/deletions between true_seq and pred_seq
        # we'll compute exact sequence edit distance:
        a = true_seq
        b = pred_seq
        n, m = len(a), len(b)
        # dp
        dp = [[0]*(m+1) for _ in range(n+1)]
        for ii in range(n+1):
            dp[ii][0] = ii
        for jj in range(m+1):
            dp[0][jj] = jj
        for ii in range(1, n+1):
            for jj in range(1, m+1):
                cost = 0 if a[ii-1] == b[jj-1] else 1
                dp[ii][jj] = min(dp[ii-1][jj] + 1,     # deletion
                                 dp[ii][jj-1] + 1,     # insertion
                                 dp[ii-1][jj-1] + cost) # substitution
        ed = dp[n][m]
        phone_errors += ed
        total_phonemes += n
        if ed > 0:
            word_errors += 1
    per = phone_errors / total_phonemes if total_phonemes > 0 else 0.0
    wer = word_errors / total_words if total_words > 0 else 0.0
    return per, wer

# ---------------------------
# Training routine
# ---------------------------
def train():
    # load data
    entries = load_cmudict(CMU_DICT_PATH, max_words=MAX_WORDS)
    print(f"Loaded {len(entries)} dictionary entries (words).")

    # build vocabs
    char2idx, idx2char, phone2idx, idx2phone = build_vocabs(entries)
    print(f"Characters: {len(char2idx)}  Phones: {len(phone2idx)}")
    print("Sample phones:", list(phone2idx.keys())[:20])

    # shuffle and split words into train/dev/test (by words)
    random.seed(42)
    indices = list(range(len(entries)))
    random.shuffle(indices)
    n = len(indices)
    n_dev = int(0.02 * n)    # 2% dev
    n_test = int(0.10 * n)   # 10% test
    dev_idx = indices[:n_dev]
    test_idx = indices[n_dev:n_dev + n_test]
    train_idx = indices[n_dev + n_test:]
    train_pairs = [entries[i] for i in train_idx]
    dev_pairs = [entries[i] for i in dev_idx]
    test_pairs = [entries[i] for i in test_idx]
    print(f"Split: train {len(train_pairs)}  dev {len(dev_pairs)}  test {len(test_pairs)}")

    # create datasets
    train_ds = G2PDataset(train_pairs, char2idx, phone2idx, window=WINDOW)
    dev_ds   = G2PDataset(dev_pairs, char2idx, phone2idx, window=WINDOW)
    test_ds  = G2PDataset(test_pairs, char2idx, phone2idx, window=WINDOW)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    dev_loader   = DataLoader(dev_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Using device:", device)

    model = NettalkModel(n_chars=len(char2idx), embed_dim=EMBED_DIM, window=WINDOW,
                         hidden_dim=HIDDEN_DIM, n_phones=len(phone2idx)).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    # training loop
    best_dev_per = 1.0
    for epoch in range(1, EPOCHS + 1):
        model.train()
        total_loss = 0.0
        total_examples = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
        for batch in pbar:
            ctx, tgt, _, _ = batch
            ctx = ctx.to(device)
            tgt = tgt.to(device)
            optimizer.zero_grad()
            logits = model(ctx)   # (B, n_phones)
            loss = criterion(logits, tgt)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * ctx.size(0)
            total_examples += ctx.size(0)
            pbar.set_postfix(loss=total_loss/total_examples)
        avg_loss = total_loss / total_examples
        print(f"Epoch {epoch} training loss: {avg_loss:.4f}")

        # dev evaluation: phoneme accuracy and PER/WER
        dev_per, dev_wer = compute_word_level_errors(model, dev_ds, idx2phone, dev_pairs, device)
        print(f"Dev PER: {dev_per*100:.2f}%  WER: {dev_wer*100:.2f}%")

        if dev_per < best_dev_per:
            best_dev_per = dev_per
            torch.save({
                'model_state_dict': model.state_dict(),
                'char2idx': char2idx,
                'phone2idx': phone2idx,
                'idx2phone': idx2phone,
                'idx2char': idx2char,
            }, 'best_nettalk_zero_delay.pth')
            print("Saved best model.")

    # final test evaluation
    checkpoint = torch.load('best_nettalk_zero_delay.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    test_per, test_wer = compute_word_level_errors(model, test_ds, checkpoint['idx2phone'], test_pairs, device)
    print(f"Test PER: {test_per*100:.2f}%  WER: {test_wer*100:.2f}%")

if __name__ == "__main__":
    train()


Loaded 135166 dictionary entries (words).
Characters: 32  Phones: 40
Sample phones: ['AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH', 'D', 'DH', 'EH', 'ER', 'EY', 'F', 'G', 'HH', 'IH', 'IY', 'JH', 'K']
Split: train 118947  dev 2703  test 13516
Using device: cuda


                                                                        

Epoch 1 training loss: 1.0312
Dev PER: 24.65%  WER: 79.17%
Saved best model.


                                                                         

Epoch 2 training loss: 0.8175
Dev PER: 23.13%  WER: 74.51%
Saved best model.


                                                                         

Epoch 3 training loss: 0.7688
Dev PER: 21.97%  WER: 73.99%
Saved best model.


                                                                         

Epoch 4 training loss: 0.7446
Dev PER: 21.55%  WER: 72.59%
Saved best model.


                                                                         

Epoch 5 training loss: 0.7281
Dev PER: 21.31%  WER: 71.66%
Saved best model.


                                                                         

Epoch 6 training loss: 0.7163
Dev PER: 20.99%  WER: 72.33%
Saved best model.


                                                                         

Epoch 7 training loss: 0.7066
Dev PER: 20.80%  WER: 70.55%
Saved best model.


                                                                         

Epoch 8 training loss: 0.6997
Dev PER: 20.89%  WER: 72.07%


                                                                         

Epoch 9 training loss: 0.6934
Dev PER: 20.81%  WER: 72.36%


                                                                          

Epoch 10 training loss: 0.6886
Dev PER: 20.15%  WER: 68.78%
Saved best model.


                                                                          

Epoch 11 training loss: 0.6841
Dev PER: 20.40%  WER: 70.44%


                                                                          

Epoch 12 training loss: 0.6801
Dev PER: 20.53%  WER: 70.37%
Test PER: 20.62%  WER: 71.46%


In [3]:
# ==============================
# Test trained Nettalk G2P model
# ==============================

import torch

# Load model + mappings
checkpoint = torch.load('best_nettalk_zero_delay.pth', map_location='cpu')
char2idx = checkpoint['char2idx']
idx2char = checkpoint['idx2char']
phone2idx = checkpoint['phone2idx']
idx2phone = checkpoint['idx2phone']

# Rebuild model
model = NettalkModel(
    n_chars=len(char2idx),
    embed_dim=EMBED_DIM,
    window=WINDOW,
    hidden_dim=HIDDEN_DIM,
    n_phones=len(phone2idx)
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# helper functions
def predict_word(word: str):
    """Predict phoneme sequence for a single word (zero-delay context window)."""
    word = word.lower()
    graphemes = list(word)
    # pad with start/end symbols for context
    padded = [START_SYM]*LEFT_CTX + graphemes + [END_SYM]*RIGHT_CTX
    preds = []
    for i in range(len(graphemes)):
        window_chars = padded[i : i + WINDOW]
        ctx_idx = torch.tensor(
            [char2idx.get(ch, char2idx[PAD_GRAP]) for ch in window_chars],
            dtype=torch.long
        ).unsqueeze(0)  # (1, window)
        with torch.no_grad():
            logits = model(ctx_idx)
            pred_idx = logits.argmax(dim=1).item()
        pred_phone = idx2phone[pred_idx]
        preds.append(pred_phone)
    return preds

# test words from CMUdict (examples)
test_words = ["google", "physics", "chatgpt", "algorithm", "data"]

for w in test_words:
    preds = predict_word(w)
    print(f"{w:>12s} → {' '.join(preds)}")


cmu = load_cmudict(CMU_DICT_PATH)
cmu_dict = dict(cmu)
for w in test_words:
    if w in cmu_dict:
        print(f"{w}: {' '.join(cmu_dict[w])}")

      google → G UW G AH L L
     physics → F IH Z IH K K S
     chatgpt → CH AE T G P T φ
   algorithm → AE L G ER TH DH TH M φ
        data → D AA T AH
google: G UW G AH L
physics: F IH Z IH K S
algorithm: AE L G ER IH DH AH M
data: D AE T AH
