In [4]:
import math
import nltk
from nltk.corpus import cmudict

# 1. Install and imports
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torch.utils.data import Dataset, DataLoader, random_split
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
    import torchaudio
except ModuleNotFoundError:
    raise ModuleNotFoundError("Please install dependencies: pip install torch torchaudio nltk")

In [5]:
# Data preprocessing
# Fetch the CMU Pronouncing Dictionary corpus; quiet=True suppresses the download log.
nltk.download('cmudict', quiet=True)
# Each entry is unpacked as a (word, pronunciation) pair by the loop further down.
cmu_entries = cmudict.entries()

def build_vocab(tokens):
    """Create index mappings for a collection of symbols.

    Index 0 is reserved for '<pad>'; the remaining symbols follow in sorted order.

    Returns:
        A (symbol -> index dict, index -> symbol list) tuple.
    """
    ordered = ['<pad>']
    ordered.extend(sorted(tokens))
    lookup = dict(zip(ordered, range(len(ordered))))
    return lookup, ordered

# Collect (letters, phonemes) pairs for every entry whose spelling has at most
# MAX_LEN characters, gathering the grapheme and phoneme inventories on the way.
MAX_LEN = 15
chars, phs = set(), set()
pairs = []
for w, pron in cmu_entries:
    w = w.lower()
    if len(w) > MAX_LEN:
        continue  # skip over-long spellings
    chars.update(w)
    phs.update(pron)
    pairs.append((list(w), pron))

# Symbol tables; the phoneme table additionally carries the '<blank>' symbol.
char2id, id2char = build_vocab(chars)
ph2id, id2ph = build_vocab(phs | {'<blank>'})
blank_id = ph2id['<blank>']

# Integer-encode every pair.
pairs_ids = [
    ([char2id[ch] for ch in letters], [ph2id[ph] for ph in phones])
    for letters, phones in pairs
]

In [18]:
# Check the number of unique phoneme token ids in the vocabulary
unique_values_count = len(set(ph2id.values()))
print(unique_values_count)

72


In [6]:
# Dataset & DataLoader
class CMUDictDataset(Dataset):
    """Thin Dataset wrapper around an in-memory sequence of samples."""

    def __init__(self, data):
        # Samples are stored as given and returned untouched by __getitem__.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch):
    """Turn a list of (char_ids, phoneme_ids) samples into padded batch tensors.

    Returns:
        x_pad:  LongTensor [B, max_x_len], padded with the '<pad>' character id.
        x_lens: LongTensor [B] holding the unpadded input lengths.
        y_pad:  LongTensor [B, max_y_len], padded with blank_id.
        y_lens: LongTensor [B] holding the unpadded target lengths.
    """
    inputs, targets = [], []
    for src, tgt in batch:
        inputs.append(torch.tensor(src, dtype=torch.long))
        targets.append(torch.tensor(tgt, dtype=torch.long))
    x_lens = torch.tensor([seq.size(0) for seq in inputs], dtype=torch.long)
    y_lens = torch.tensor([seq.size(0) for seq in targets], dtype=torch.long)
    x_pad = pad_sequence(inputs, batch_first=True, padding_value=char2id['<pad>'])
    y_pad = pad_sequence(targets, batch_first=True, padding_value=blank_id)
    return x_pad, x_lens, y_pad, y_lens

# 80/20 train/test split and batching.
dataset = CMUDictDataset(pairs_ids)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
# Seed the split with its own generator so the same words land in the test set on
# every run. With an unseeded split, test scores are not comparable across kernel
# restarts, and a checkpoint reloaded in a new session may be evaluated on words it
# was trained on.
SPLIT_SEED = 42
train_ds, test_ds = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False, collate_fn=collate_fn)

In [7]:
class Encoder(nn.Module):
    """Bidirectional LSTM encoder over character ids, with an extra linear CTC head."""

    def __init__(self, vocab_size, embed_dim=256, hid_dim=512, num_layers=3, dropout=0.3):
        super().__init__()
        pad_idx = char2id['<pad>']
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hid_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )
        # Forward and backward hidden states are concatenated on the feature axis.
        self.out_dim = 2 * hid_dim
        self.ctc_fc = nn.Linear(self.out_dim, len(ph2id))

    def forward(self, x, lengths):
        """Encode the padded id batch `x` whose true lengths are `lengths`."""
        embedded = self.embed(x)
        # Lengths are moved to the CPU for packing; enforce_sorted=False lets the
        # batch arrive in any length order.
        packed_in = pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed_in)
        padded_out, _ = pad_packed_sequence(packed_out, batch_first=True)
        return padded_out

    def ctc_logits(self, enc_out):
        """Project encoder outputs through the CTC head."""
        return self.ctc_fc(enc_out)

In [8]:
class Predictor(nn.Module):
    """Unidirectional LSTM over token ids (the transducer's prediction network)."""

    def __init__(self, vocab_size, embed_dim=256, hid_dim=512, num_layers=2, dropout=0.3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hid_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.hid_dim = hid_dim

    def forward(self, y):
        """Return the LSTM output for every position of the id batch `y`."""
        hidden_seq, _final_state = self.lstm(self.embed(y))
        return hidden_seq

In [9]:
class Joiner(nn.Module):
    """Joint network: projects both streams to a common size, adds them, and classifies."""

    def __init__(self, enc_dim, pred_dim, joint_dim=512):
        super().__init__()
        self.e_proj = nn.Linear(enc_dim, joint_dim)
        self.p_proj = nn.Linear(pred_dim, joint_dim)
        self.act = nn.ReLU()
        # Output size follows the phoneme vocabulary (which includes '<blank>').
        self.fc = nn.Linear(joint_dim, len(ph2id))

    def forward(self, e, p):
        """Return vocabulary logits for encoder features `e` and predictor features `p`."""
        fused = self.e_proj(e) + self.p_proj(p)
        hidden = self.act(fused)
        return self.fc(hidden)

In [10]:
class RNNTransducer(nn.Module):
    """Bundles encoder, predictor and joiner into one module producing the joint lattice."""

    def __init__(self, encoder, predictor, joiner, blank_id):
        super().__init__()
        self.encoder = encoder
        self.predictor = predictor
        self.joiner = joiner
        self.blank_id = blank_id

    def forward(self, x, x_lens, y, y_lens):
        """Compute joiner outputs for every (input step, target prefix) pair.

        `y_lens` is part of the signature but is not read here.
        """
        enc_states = self.encoder(x, x_lens)                        # [B, T, 2H]
        batch_size, _, _ = enc_states.size()
        # Prefix every target row with a blank so the predictor has a start symbol.
        start = y.new_full((batch_size, 1), self.blank_id)
        pred_states = self.predictor(torch.cat((start, y), dim=1))  # [B, U+1, H]
        # Broadcast [B, T, 1, *] against [B, 1, U+1, *] inside the joiner.
        return self.joiner(enc_states.unsqueeze(2), pred_states.unsqueeze(1))  # [B, T, U+1, V]

In [11]:
# Build the model, optimizer, LR schedule and CTC criterion.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
enc = Encoder(vocab_size=len(char2id)).to(device)
pred = Predictor(vocab_size=len(ph2id)).to(device)
joi = Joiner(enc_dim=enc.out_dim, pred_dim=pred.hid_dim).to(device)
model = RNNTransducer(encoder=enc, predictor=pred, joiner=joi, blank_id=blank_id).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
ctc_loss = nn.CTCLoss(blank=blank_id, zero_infinity=True)


In [12]:
# Loss: RNN-T only (the auxiliary CTC term is currently disabled).
# Presumably the weight of the auxiliary CTC term; nothing in this cell reads it
# while that term is disabled.
ALPHA = 0.3

def total_loss_fn(x, xl, y, yl, logits):
    """Return the RNN-T loss for one batch.

    Args:
        x: padded input ids; not read here, kept so existing callers keep working.
        xl: input lengths per sample.
        y: padded target ids.
        yl: target lengths per sample.
        logits: raw joiner outputs, [B, T, U+1, V].

    Returns:
        The loss tensor produced by torchaudio.functional.rnnt_loss.
    """
    # rnnt_loss applies log-softmax itself by default (fused_log_softmax=True), so
    # the raw logits go straight in. Normalizing them beforehand is mathematically
    # a no-op (log-softmax is idempotent) and only materialized a second
    # [B, T, U+1, V] tensor.
    return torchaudio.functional.rnnt_loss(
        logits,
        y.to(torch.int32),
        xl.to(torch.int32),
        yl.to(torch.int32),
        blank=blank_id,
    )

In [13]:
# Beam search decoding
def beam_search(enc_out, width=5):
    """Decode a single sample with a time-synchronous beam search.

    At every encoder step each surviving hypothesis is extended by its `width`
    best joiner outputs (a blank leaves the hypothesis unchanged), then the pool
    is pruned back to the `width` highest-scoring candidates. At most one symbol
    is emitted per encoder step.

    Args:
        enc_out: encoder output for one sample, [T, D] or [1, T, D].
        width: beam width.

    Returns:
        List of phoneme ids of the best-scoring hypothesis (blanks excluded).
        An input with zero time steps yields [].

    Raises:
        AssertionError: if a 3-D input has a batch dimension other than 1.
    """
    if enc_out.dim() == 3:
        assert enc_out.size(0) == 1, "batch size >1 not supported"
        enc_out = enc_out[0]
    T, _ = enc_out.size()
    device = enc_out.device
    beams = [(0.0, [])]  # (cumulative log-prob, emitted ids)
    with torch.no_grad():  # decoding never needs gradients
        for t in range(T):
            enc_vec = enc_out[t].unsqueeze(0)  # [1, D]; shared by every beam at step t
            all_cands = []
            for score, hyp in beams:
                # Predictor input: blank start symbol followed by the hypothesis so far.
                y_input = torch.tensor([[blank_id] + hyp], dtype=torch.long, device=device)
                pred_vec = pred(y_input)[:, -1, :]  # [1, H], state after the last symbol
                # Use the joiner's own forward pass instead of re-implementing it here.
                logp = F.log_softmax(joi(enc_vec, pred_vec).squeeze(0), dim=-1)  # [V]
                # topk raises if asked for more entries than the vocabulary holds.
                k = min(width, logp.size(-1))
                topk = torch.topk(logp, k)
                for lp, tok in zip(topk.values.tolist(), topk.indices.tolist()):
                    new_hyp = hyp if tok == blank_id else hyp + [tok]
                    all_cands.append((score + lp, new_hyp))
            # Keep only the `width` best candidates for the next step.
            beams = sorted(all_cands, key=lambda cand: cand[0], reverse=True)[:width]
    return beams[0][1]

In [14]:
# Training loop

def train(epochs=15):
    """Run `epochs` passes over train_loader, reporting loss and CER after each pass."""
    for ep in range(1, epochs + 1):
        model.train()
        running_loss = 0
        n_steps = 0
        for b, batch in enumerate(train_loader, start=1):
            x, xl, y, yl = (t.to(device) for t in batch)
            logits = model(x, xl, y, yl)
            loss = total_loss_fn(x, xl, y, yl, logits)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_steps += 1
            if b % 100 == 0:
                print(f"Ep{ep} B{b}/{len(train_loader)} Loss={loss.item():.3f}")

        # The LR schedule advances once per epoch.
        scheduler.step()
        print(f"Epoch{ep} AvgLoss={running_loss/n_steps:.3f} LR={scheduler.get_last_lr()[0]:.5f}")
        evaluate_cer('Train', train_loader)
        evaluate_cer('Test', test_loader)

In [15]:
# CER and single-word decode

def edit_distance(a, b):
    """Levenshtein distance between sequences `a` and `b` (unit-cost edits)."""
    # Only the previous row of the DP table is needed to fill in the current one.
    prev_row = list(range(len(b) + 1))
    for i, item_a in enumerate(a, start=1):
        cur_row = [i]
        for j, item_b in enumerate(b, start=1):
            if item_a == item_b:
                cur_row.append(prev_row[j - 1])
            else:
                cur_row.append(1 + min(prev_row[j], cur_row[j - 1], prev_row[j - 1]))
        prev_row = cur_row
    return prev_row[-1]
def cer(ref, hyp):
    """Edit distance between `ref` and `hyp` divided by len(ref); 0 for an empty reference."""
    if not ref:
        return 0
    return edit_distance(ref, hyp) / len(ref)

def evaluate_cer(name, loader):
    """Print the mean error rate of beam-search decoding over `loader`.

    Only the first sample of each batch is decoded, so the result is a sampled
    estimate rather than a score over every item. The sequences compared are
    phoneme symbols, so the printed "CER" is measured over phoneme tokens.

    Args:
        name: label printed in front of the score.
        loader: iterable yielding (x, x_lens, y, y_lens) batches.
    """
    model.eval()
    tot, ct = 0, 0
    with torch.no_grad():
        for x, xl, y, yl in loader:
            x, xl, y, yl = x.to(device), xl.to(device), y.to(device), yl.to(device)
            # The encoder output is padded to the longest input in the batch. Trim
            # the first sample to its own length; otherwise beam search keeps
            # decoding over padding frames and can emit spurious phonemes.
            eo = enc(x, xl)[0, :int(xl[0])]
            pred_ids = beam_search(eo, width=5)
            ref = [id2ph[i] for i in y[0, :int(yl[0])].tolist()]
            hyp = [id2ph[i] for i in pred_ids]
            tot += cer(ref, hyp)
            ct += 1
    # An empty loader yields no samples; report NaN instead of dividing by zero.
    mean_cer = tot / ct if ct else float('nan')
    print(f"{name} CER={mean_cer:.3f}")


In [None]:
def main():
    """Train for 3 epochs, then print the decoded phonemes for a few sample words."""
    train(3)
    pad_id = char2id['<pad>']
    for w in ['hello', 'world', 'chatgpt']:
        # Fall back to the pad *id* for unseen characters. The previous default was
        # the string '<pad>', which cannot be placed in an integer tensor.
        ids = [char2id.get(c, pad_id) for c in w]
        with torch.no_grad():  # inference only
            eo = enc(torch.tensor([ids], dtype=torch.long, device=device),
                     torch.tensor([len(ids)], device=device))
            phs = [id2ph[i] for i in beam_search(eo, 5)]
        print(f"{w} -> {' '.join(phs)}")

if __name__ == '__main__':
    main()

In [23]:
# Save model state
# Only the state_dict is written, not the module object itself.
torch.save(model.state_dict(), "rnn_transducer_wo_ctc.pth")
print("Model saved to rnn_transducer_wo_ctc.pth")

Model saved to rnn_transducer.pth


In [16]:
# Load model state
# weights_only=True restricts unpickling to tensors and plain containers, which is
# all a state_dict holds, so a tampered checkpoint file cannot run arbitrary code.
# NOTE(review): this loads rnn_transducer.pth from a Kaggle input directory, not the
# rnn_transducer_wo_ctc.pth file written by the save cell — confirm that is intended.
model.load_state_dict(
    torch.load(
        "/kaggle/input/rnn-t/pytorch/default/1/rnn_transducer.pth",
        map_location=device,
        weights_only=True,
    )
)
model.eval()
print("Model loaded from rnn_transducer.pth")

  model.load_state_dict(torch.load("/kaggle/input/rnn-t/pytorch/default/1/rnn_transducer.pth", map_location=device))


Model loaded from rnn_transducer.pth


In [39]:
def predict_word(word):
    """Predict the phoneme sequence for a single word.

    Args:
        word: spelling to transcribe; it is lower-cased first. Characters missing
            from char2id are mapped to the '<pad>' id.

    Returns:
        List of phoneme symbols found by beam search (width 5); [] for an empty word.
    """
    # Preprocess word
    ids = [char2id.get(c, char2id['<pad>']) for c in word.lower()]
    if not ids:
        # The encoder packs its input, and packing rejects zero-length sequences.
        return []
    # Inference only: skip autograd bookkeeping for the encoder and decoder passes.
    with torch.no_grad():
        x = torch.tensor([ids], dtype=torch.long, device=device)
        x_lens = torch.tensor([len(ids)], dtype=torch.long, device=device)
        # Encode
        enc_out = enc(x, x_lens)[0]  # take [1,T,2H] -> [T,2H]
        # Beam search decode
        pred_ids = beam_search(enc_out, width=5)
    # Map to phonemes
    return [id2ph[i] for i in pred_ids]

In [43]:
# Spot-check the decoder on a single word
predict_word("cement")

['S', 'AH0', 'M', 'EH1', 'N', 'T']

In [23]:
# Decode a handful of additional words and print each predicted phoneme list
Words=['February','Pronunciation','Salmon','Handkerchief','Honest','Gnome']
for word in Words:
    res=predict_word(word)
    print(res)

['F', 'EH1', 'B', 'R', 'UW0', 'EH2', 'R', 'IY0']
['P', 'R', 'OW0', 'N', 'AH2', 'N', 'S', 'IY0', 'EY1', 'SH', 'AH0', 'N']
['S', 'AE1', 'L', 'M', 'AH0', 'N']
['HH', 'AE1', 'NG', 'K', 'ER0', 'CH', 'IY2', 'F']
['HH', 'OW1', 'N', 'AH0', 'S', 'T']
['N', 'OW1', 'M']
