In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [3]:
from model import SpeechRecognizer
from helpers.an4_sphere_dataset import build_dataloader

In [4]:
BATCH_SIZE = 16
EPOCHS = 100

In [5]:
# load json unified_vocab.json
import json
with open('datasets/unified_vocab.json', 'r') as f:
    unified_vocab = json.load(f)

In [6]:
# the build_dataloader function contains a collate_fn function that pads the input sequences to the maximum length in the batch

train_loader, val_loader = build_dataloader(
    data_dir="datasets/train",
    batch_size=BATCH_SIZE,
    shuffle=True,
    split=True,
    val_ratio=0.2
)

test_loader = build_dataloader(
    "datasets/test",
    batch_size=BATCH_SIZE,
    split=False,
    shuffle=False,
)

Dataset split: 759 training samples, 189 validation samples


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SpeechRecognizer(input_dim=64, hidden_dim=256, vocab_size=len(unified_vocab['word_to_idx']), dropout_rate=0.5)
model.to(device)

SpeechRecognizer(
  (lstm): LSTM(64, 256, num_layers=2, batch_first=True, bidirectional=True)
  (ln): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (dropout_lstm): Dropout(p=0.0, inplace=False)
  (fc): Linear(in_features=512, out_features=107, bias=True)
)

In [8]:

ctc_loss = nn.CTCLoss(blank=0)  # index 0 is blank
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

tensor(0.)

In [9]:
# 6. Training loop with CTC loss
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    for mel_specs, vocab_ids, mel_lengths, vocab_lengths in train_loader:
        mel_specs, vocab_ids = mel_specs.to(device), vocab_ids.to(device)

        B, T = vocab_ids.size()

        # Forward pass
        log_probs = model(mel_specs)

        log_probs_ctc = log_probs.transpose(0, 1)  # https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html
        # The CTC loss function needs actual lengths, not masks
        loss = ctc_loss(
            log_probs_ctc,
            vocab_ids,
            input_lengths=mel_lengths,  # Use the actual mel spectrogram lengths
            target_lengths=vocab_lengths  # Use the actual target sequence lengths
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * B
    print(f"Epoch {epoch+1}: training avg CTC loss = {total_loss/len(train_loader.dataset):.3f}")
    # Validate on validation set
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for mel_specs, vocab_ids, mel_lengths, vocab_lengths in val_loader:
            mel_specs, vocab_ids = mel_specs.to(device), vocab_ids.to(device)

            B, T = vocab_ids.size()

            log_probs = model(mel_specs)
            log_probs_ctc = log_probs.transpose(0, 1)  # https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html
            # The CTC loss function needs actual lengths, not masks
            loss = ctc_loss(
                log_probs_ctc,
                vocab_ids,
                input_lengths=mel_lengths,  # Use the actual mel spectrogram lengths
                target_lengths=vocab_lengths  # Use the actual target sequence lengths
            )
            total_loss += loss.item() * B
        print(f"Epoch {epoch+1}: validation avg CTC loss = {total_loss/len(val_loader.dataset):.3f}")


Epoch 1: training avg CTC loss = 7.067
Epoch 1: validation avg CTC loss = 3.455
Epoch 2: training avg CTC loss = 3.237
Epoch 2: validation avg CTC loss = 3.178
Epoch 3: training avg CTC loss = 3.072
Epoch 3: validation avg CTC loss = 3.053
Epoch 4: training avg CTC loss = 2.994
Epoch 4: validation avg CTC loss = 2.985
Epoch 5: training avg CTC loss = 2.920
Epoch 5: validation avg CTC loss = 3.014
Epoch 6: training avg CTC loss = 2.921
Epoch 6: validation avg CTC loss = 2.915
Epoch 7: training avg CTC loss = 2.977
Epoch 7: validation avg CTC loss = 3.176
Epoch 8: training avg CTC loss = 3.084
Epoch 8: validation avg CTC loss = 3.058
Epoch 9: training avg CTC loss = 2.932
Epoch 9: validation avg CTC loss = 3.161
Epoch 10: training avg CTC loss = 2.870
Epoch 10: validation avg CTC loss = 2.893
Epoch 11: training avg CTC loss = 2.688
Epoch 11: validation avg CTC loss = 2.757
Epoch 12: training avg CTC loss = 2.659
Epoch 12: validation avg CTC loss = 2.761
Epoch 13: training avg CTC loss = 

KeyboardInterrupt: 

In [None]:
def ctc_greedy_decode(log_probs, idx_to_char):
    """
    Decode CTC output using greedy algorithm
    
    Args:
        log_probs: tensor of shape (batch_size, seq_len, num_classes)
        idx_to_char: dictionary mapping indices to characters
    
    Returns:
        list of decoded strings, one per batch item
    """
    # Get the most likely class at each timestep
    pred_indices = log_probs.argmax(dim=2).cpu().numpy()  # (batch_size, seq_len)
    
    batch_results = []
    for indices in pred_indices:  # Process each sequence in the batch
        # Collapse repeats and remove blanks
        prev = None
        pred_words = []
        for idx in indices:
            if idx != prev and idx != 0:  # 0 is CTC blank
                pred_words.append(idx)
            prev = idx
        batch_results.append(pred_words)
    
    return batch_results

In [None]:
total_words, total_errors = 0, 0
model.eval()

with torch.no_grad():
    for mel_specs, vocab_ids, mel_lengths, vocab_lengths in test_loader:
        mel_specs, vocab_ids = mel_specs.to(device), vocab_ids.to(device)
        log_probs = model(mel_specs)  # (T,1,vocab)
        hyp_words = ctc_greedy_decode(log_probs, unified_vocab['idx_to_word'])
        # Compute word error rate for this utterance
        ref_words = vocab_ids
        # Levenshtein edit distance for words:
        # Initialize DP table
        d = [[0]*(len(hyp_words)+1) for _ in range(len(ref_words)+1)]
        for i in range(len(ref_words)+1): 
            d[i][0] = i
        for j in range(len(hyp_words)+1): 
            d[0][j] = j
        for i, rw in enumerate(ref_words, start=1):
            for j, hw in enumerate(hyp_words, start=1):
                cost = 0 if rw == hw else 1
                d[i][j] = min(d[i-1][j] + 1, d[i][j-1] + 1, d[i-1][j-1] + cost)
        total_errors += d[len(ref_words)][len(hyp_words)]
        total_words += len(ref_words)
wer = total_errors / total_words
print(f"Test WER: {wer*100:.2f}%")


Test WER: 100.00%
