Automatic Speech Recognition (ASR) has leapt from research labs onto every phone and smart speaker we own, yet the recipe that turns a WAV file into words is still a bit of a black box for many data scientists. This notebook lifts the lid by guiding you, code cell by code cell, through the entire pipeline needed to build a character-level speech recogniser using nothing more than open-source PyTorch, torchaudio and a GPU.

By the end you’ll have a working speech recogniser trained from scratch, a solid grasp of each design decision, and ready-to-reuse code snippets for your own audio projects.

In [17]:
import torch, torchaudio, random, re
from torch import nn
from itertools import groupby
from pathlib import Path
from jiwer import wer, cer

SAMPLE_RATE = 16_000       
BLANK_TOKEN = "<BLANK>"

## Dataset

* train-clean-100: small enough that you can train a character-level model in a couple of hours on a single GPU, yet large enough to reach ~20 % WER with a modest network.

* test-clean: the standard benchmark used in almost every ASR paper; you can directly compare your WER/CER to published numbers.

In [3]:
train_ds = torchaudio.datasets.LIBRISPEECH(".", url="train-clean-100", download=True)
test_ds  = torchaudio.datasets.LIBRISPEECH(".", url="test-clean",  download=True)

100%|██████████| 5.95G/5.95G [03:42<00:00, 28.7MB/s] 
100%|██████████| 331M/331M [00:13<00:00, 25.3MB/s] 


## Tokenizer

Before the acoustic model can learn anything, we must convert text transcripts into integer sequences that a loss function (CTC in our case) can compare to the model’s output. This cell builds the character-level vocabulary ("tokenizer") from the training set and supplies two helper functions:

* encode() – turns a raw transcript string into a 1-D torch.Tensor of token IDs.

* decode() – performs the inverse: given a list/array of IDs it reconstructs the text.

Because CTC needs a special blank token to indicate “no character emitted”, we place BLANK_TOKEN at index 0. Characters are then assigned indices 1 … N in deterministic (sorted) order so that training and inference share exactly the same mapping.

In [5]:
def yield_text(data_set):
    """
    Generator that iterates over the dataset and yields each transcript
    as a *list of individual characters* (lower-cased).
    
    We ignore all the other items in the dataset tuple (waveform, rate, etc.)
    by unpacking them into underscores.
    """
    for _, _, txt, _, _, _ in data_set: # (wave, sr, text, speaker_id, chapter_id, utt_id)
        yield list(txt.lower()) # lower-case for case-insensitive training

# Build character vocabulary
#  • BLANK_TOKEN goes at index 0 for CTC
#  • sorted() gives deterministic order (important when you resume training)
vocab = [BLANK_TOKEN] + sorted(
            set(ch                           # unique characters …
                for txt in yield_text(train_ds)   # … across every transcript
                for ch in txt)
        )

# stoi = "string-to-index" lookup: character -> integer ID
stoi  = {ch: i for i, ch in enumerate(vocab)}

def encode(text: str) -> torch.Tensor:
    """
    Map a string to a 1-D tensor of integer IDs (dtype long).
    """
    return torch.tensor([stoi[c] for c in text.lower()], dtype=torch.long)

def decode(ids) -> str:
    """
    Map a sequence (list / tensor) of IDs back to a string.
    """
    return "".join(vocab[i] for i in ids)
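
As a quick sanity check, the same vocabulary recipe can be run on a couple of made-up transcripts. This is a minimal sketch that uses plain Python lists instead of tensors; `toy_transcripts`, `toy_vocab`, `toy_encode` and `toy_decode` are illustrative names and are not part of the notebook.

```python
BLANK_TOKEN = "<BLANK>"

# Two made-up transcripts standing in for the training set (assumption).
toy_transcripts = ["HE HOPED", "THERE WOULD"]

# Same recipe as above: blank at index 0, then the sorted unique characters.
toy_vocab = [BLANK_TOKEN] + sorted({ch for txt in toy_transcripts for ch in txt.lower()})
toy_stoi = {ch: i for i, ch in enumerate(toy_vocab)}

def toy_encode(text):
    """Map a string to a list of integer IDs."""
    return [toy_stoi[c] for c in text.lower()]

def toy_decode(ids):
    """Map a list of IDs back to a string."""
    return "".join(toy_vocab[i] for i in ids)

ids = toy_encode("he would")
print(toy_vocab)
print(ids, "->", toy_decode(ids))
```

With this scheme a character that never occurred while the vocabulary was built raises a KeyError in the encoder, which is one reason to build the vocabulary from the full training set.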

## Audio Transform

Many neural ASR models, including the one built here, do not consume raw 16 kHz samples directly; they expect a 2-D time × frequency representation that captures the perceptually important parts of speech.
A widely used front-end is the log-Mel spectrogram:

1. Short-time Fourier Transform (STFT) chops the waveform into overlapping 25 ms windows (400 samples at 16 kHz) and converts them to frequency bins.
   
2. A Mel filterbank (80 filters here) compresses those bins onto a scale that mimics human pitch perception.

3. We take the logarithm (dB) so differences in loudness become additive.

During training we additionally apply SpecAugment, masking random time and frequency stripes, to make the model robust to speaker and noise variation.
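
To make step 3 concrete, the sketch below converts power values to decibels and shows that a fixed loudness ratio becomes a fixed additive offset. It is a plain-Python illustration of the formula 10·log10(power), not torchaudio's implementation; the function name `power_to_db` and the floor `amin` are assumptions chosen for the example.

```python
import math

def power_to_db(power, amin=1e-10):
    """10 * log10 of a power value, floored at amin to avoid log(0)."""
    return 10.0 * math.log10(max(power, amin))

# Doubling the power adds the same number of dB, whatever the starting level.
print(power_to_db(2.0) - power_to_db(1.0))        # gap between 1 and 2
print(power_to_db(200.0) - power_to_db(100.0))    # gap between 100 and 200
```

Both gaps come out at roughly 3 dB, which is what "differences in loudness become additive" means in practice.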

In [6]:
mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,   # target 16 kHz
        n_fft        = 400,        # 25 ms analysis window (400 / 16 kHz)
        hop_length   = 160,        # 10 ms stride  (overlap = 15 ms)
        n_mels       = 80)         # # of Mel bands -> feature dim F = 80

db  = torchaudio.transforms.AmplitudeToDB()   # converts power -> log-dB

def spec_augment(spec):
    """
    Light variant of SpecAugment:
    • FrequencyMasking    hides  up to 15 Mel bands   (simulates channel noise)
    • TimeMasking         hides  up to 35 frames      (simulates drop-outs)
    """
    spec = torchaudio.transforms.FrequencyMasking(freq_mask_param=15)(spec)
    return torchaudio.transforms.TimeMasking(time_mask_param=35)(spec)

def preprocess(wav, sr, augment=False):
    """
    wav      : Tensor (1, N)   – mono waveform
    sr       : int             – original sample-rate
    augment  : bool            – enable SpecAugment (use True for training)
    
    Returns: Tensor (T, F)     – log-Mel spectrogram
    """
    # Resample if the incoming clip is not already 16 kHz
    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
    
    # STFT -> Mel filterbank -> log-dB  (shape: 1 × F × T)
    spec = db(mel(wav))
    
    # Optional data-augmentation
    if augment:
        spec = spec_augment(spec)
    
    # Tidy up tensor shape:   (1, F, T) -> (T, F)
    return spec.squeeze(0).transpose(0, 1)
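
A little arithmetic helps to sanity-check the shapes that preprocess returns. The sketch below assumes centred framing (which, as far as we know, is torchaudio's default), so a clip of N samples yields 1 + N // hop_length frames; `expected_frames` and `freq_bins` are hypothetical helper names.

```python
SAMPLE_RATE = 16_000
N_FFT = 400
HOP_LENGTH = 160

def expected_frames(num_samples, hop_length=HOP_LENGTH):
    """Frame count for a centred STFT: one frame per hop, plus one."""
    return 1 + num_samples // hop_length

def freq_bins(n_fft=N_FFT):
    """Number of one-sided STFT bins before the Mel filterbank."""
    return n_fft // 2 + 1

ten_seconds = 10 * SAMPLE_RATE
print(expected_frames(ten_seconds))   # frames T for a 10 s clip
print(freq_bins())                    # linear-frequency bins fed to the filterbank
print(SAMPLE_RATE / HOP_LENGTH)       # frames per second
```

Under these assumptions a 10-second clip becomes about a thousand frames of 80 Mel features each, i.e. 100 frames per second of audio.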

## Model

Automatic speech-recognition models need to:

1. Condense local time-frequency patterns (formants, plosives, s-bursts…) that appear in the log-Mel spectrogram.

2. Model long-range context so that each frame’s prediction is informed by the audio before and after it.

3. Project the encoded sequence onto the vocabulary that the CTC loss expects (blank + characters).

A Convolution + Recurrent + feed-forward (CRNN) stack is a light-weight way to do all three:

* Small 2D CNN front-end learns local spectral filters and reduces frequency variance.

* Bidirectional LSTM encoder captures left-hand and right-hand context over hundreds of frames.

* Linear head maps the hidden representation to per-time-step log-probabilities over the vocabulary.

In [7]:
class CRNN(nn.Module):

    def __init__(self, vocab_size,
                 feat_dim=80,          # F  – Mel bands coming from preprocess()
                 hidden=384,           # H  – LSTM hidden size
                 layers=5,             # #  – stacked BiLSTM layers
                 dropout=0.1):
        super().__init__()

        # ---------- Convolutional front-end --------------------------
        # Two 3×3 conv layers (stride 1) act on the spectrogram as if it
        # were a single-channel image:  (B, 1, T, F) -> (B, 32, T, F)
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 3), padding=1), nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=(3, 3), padding=1), nn.ReLU()
        )
        
        # ---------- Recurrent encoder (bidirectional) ---------------
        # Input size to the LSTM is 32 * feat_dim because the conv keeps
        # 32 channels and we later flatten the frequency axis.
        self.rnn = nn.LSTM(
            input_size   = feat_dim * 32,  # 32×80 = 2560
            hidden_size  = hidden,         # 384
            num_layers   = layers,         # 5 stacked layers
            batch_first  = True,           # (B, T, ·)
            bidirectional=True,
            dropout      = dropout
        )

        # ---------- Classification head -----------------------------
        # Two-layer MLP with GELU + LayerNorm stabilises training before
        # projecting to vocab_size (includes the CTC blank token).
        self.fc = nn.Sequential(
            nn.Linear(hidden * 2, hidden),   # ×2 for bidirection
            nn.GELU(),
            nn.LayerNorm(hidden),
            nn.Linear(hidden, vocab_size)    # logits per character
        )

    def forward(self, x):
        """
        x: Tensor (B, T, F) – batch of log-Mel spectrograms
        Returns: log-softmax probabilities  (B, T, vocab)
        """
        # 1. CNN expects channel dim -> add unsqueeze(1) to get (B, 1, T, F)
        x = self.cnn(x.unsqueeze(1))                 # (B, 32, T, F)

        # 2. Flatten CNN output so each time-step is a 1-D vector
        B, C, T, F = x.shape                         # C = 32
        x = x.permute(0, 2, 1, 3).reshape(B, T, C * F)  # (B, T, 32·F)

        # 3. Bidirectional LSTM encoder
        x, _ = self.rnn(x)                           # (B, T, 2H)

        # 4. MLP head -> character logits, then log-softmax for CTC loss
        return self.fc(x).log_softmax(dim=-1)        # (B, T, vocab)

## Data Loader

PyTorch’s DataLoader batches together whatever a Dataset yields. ASR datasets, however, contain variable-length utterances:

* Spectrograms differ in time steps T.

* Transcripts differ in number of characters.

The model still wants tensors of equal shape in a batch, and CTC needs the true lengths of every sequence to ignore the right-hand padding. Therefore we write a small collate function that:

1. Converts each (waveform, ..., text) sample -> (spec, label).

2. Records the real input and target lengths.

3. Pads spectrograms so they stack into a single tensor.

During training we also enable SpecAugment inside the same function so that augmentation is applied after the audio is loaded but before the model sees it.

In [13]:
from torch.utils.data import DataLoader

def collate(batch, augment=False):
    """
    batch  : list[ tuple ]  – items coming from the Dataset
    augment: bool           – if True, apply SpecAugment
    
    Returns
    -------
    specs        : Tensor (B, T_max, F)      – zero-padded log-Mel batches
    labels       : 1-D Tensor (sum targets)  – all character IDs concatenated
    input_lens   : list[int]                 – true T for each utterance
    target_lens  : list[int]                 – number of chars per transcript
    """
    specs, labels, input_lens, target_lens = [], [], [], []

    # unpack the dataset tuples
    for wav, sr, txt, *_ in batch:          # we ignore the speaker/chapter IDs
        spec  = preprocess(wav, sr, augment)   # (T, F)
        specs.append(spec)
        
        lab   = encode(txt)                   # 1-D tensor of char IDs
        labels.append(lab)
        
        input_lens.append(spec.shape[0])      # T  (time frames)
        target_lens.append(len(lab))          # #chars

    # ---------- pad the variable-length spectrograms -------------------
    # pad_sequence -> (B, T_max, F), filled with zeros on the right
    specs  = nn.utils.rnn.pad_sequence(specs, batch_first=True).float()
    
    # ---------- flatten the list of 1-D label tensors ------------------
    labels = torch.cat(labels)               # CTC expects 1-D concat
    
    return specs, labels, input_lens, target_lens

train_loader = DataLoader(train_ds, 32, shuffle=True, collate_fn=lambda b: collate(b, True))
test_loader  = DataLoader(test_ds,  8, shuffle=False, collate_fn=collate)
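
The lengths returned by collate matter because CTC can only align a transcript to an utterance that has enough frames. The following sketch, with hypothetical helper names, computes the minimum number of frames a transcript needs: one per character, plus one blank between each pair of identical neighbouring characters.

```python
def min_ctc_frames(transcript):
    """Fewest frames that can emit `transcript` under CTC.

    Each character needs one frame, and each pair of identical
    neighbours needs an extra blank frame between them.
    """
    repeats = sum(1 for a, b in zip(transcript, transcript[1:]) if a == b)
    return len(transcript) + repeats

def is_ctc_feasible(num_frames, transcript):
    """True if an utterance with num_frames frames can be aligned to transcript."""
    return num_frames >= min_ctc_frames(transcript)

print(min_ctc_frames("hello"))        # 5 characters + 1 for the double "l"
print(is_ctc_feasible(6, "hello"))
print(is_ctc_feasible(5, "hello"))
```

At roughly 100 frames per second the constraint is rarely binding for read speech, but a check like this becomes useful if time sub-sampling is ever added to the front-end.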

## Training

We are now ready to instantiate the network with the right vocabulary size and to detect whether the notebook is running on one GPU, several, or none. If multiple GPUs are available, we wrap the model in nn.DataParallel, which splits each batch across the GPUs, runs the same forward/backward step on every replica, and accumulates the gradients on the primary device. Finally, we move the (possibly wrapped) model onto the chosen device so that all its parameters and buffers live in GPU memory (or on the CPU if no GPU is present).

In [None]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
n_gpus  = torch.cuda.device_count()
model  = CRNN(len(vocab))
if n_gpus > 1:
    print(f"Using {n_gpus} GPUs via DataParallel")
    model = nn.DataParallel(model)  # replicates model each forward
model = model.to(device)

In [52]:
print(model)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")

DataParallel(
  (module): CRNN(
    (cnn): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU()
    )
    (rnn): LSTM(2560, 384, num_layers=5, batch_first=True, dropout=0.1, bidirectional=True)
    (fc): Sequential(
      (0): Linear(in_features=768, out_features=384, bias=True)
      (1): GELU(approximate='none')
      (2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (3): Linear(in_features=384, out_features=29, bias=True)
    )
  )
)
Total parameters: 23547261
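
The printed total can be reproduced by hand, which is a useful check that the architecture is what we think it is. The sketch below uses the standard parameter formulas for convolution, LSTM and linear layers; the helper names are illustrative, and the hyper-parameters are copied from the model summary above.

```python
def conv2d_params(c_in, c_out, k_h, k_w):
    return c_in * c_out * k_h * k_w + c_out           # weights + bias

def lstm_params(input_size, hidden, layers, bidirectional=True):
    directions = 2 if bidirectional else 1
    total = 0
    for layer in range(layers):
        in_size = input_size if layer == 0 else hidden * directions
        # 4 gates, each with input weights, recurrent weights and two bias vectors
        per_direction = 4 * hidden * (in_size + hidden) + 8 * hidden
        total += per_direction * directions
    return total

def linear_params(n_in, n_out):
    return n_in * n_out + n_out

feat_dim, hidden, layers, vocab_size = 80, 384, 5, 29

cnn = conv2d_params(1, 32, 3, 3) + conv2d_params(32, 32, 3, 3)
rnn = lstm_params(32 * feat_dim, hidden, layers)
layer_norm = 2 * hidden                               # scale + shift
head = linear_params(2 * hidden, hidden) + layer_norm + linear_params(hidden, vocab_size)

print(f"CNN   {cnn:>12,}")
print(f"LSTM  {rnn:>12,}")
print(f"Head  {head:>12,}")
print(f"Total {cnn + rnn + head:>12,}")
```

The sum matches the 23,547,261 reported above. Almost all of it sits in the LSTM, and the first LSTM layer alone holds about 9 M parameters because its input is the 2560-wide flattened CNN output.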


In [16]:
from torch.amp import GradScaler, autocast

ctc    = nn.CTCLoss(blank=0, zero_infinity=True) # alignment-free loss
opt    = torch.optim.AdamW(model.parameters(), 5e-4, weight_decay=1e-4) # Adam + decoupled weight-decay
# One-cycle policy: large LR ▸ small LR over 1 training “cycle” (here 20 epochs)
sched  = torch.optim.lr_scheduler.OneCycleLR(opt, 5e-4, steps_per_epoch=len(train_loader), epochs=20)
scaler = GradScaler() # dynamic loss-scaling for FP16

best_wer, patience, PATIENCE = 1.0, 0, 4   # early stop if WER doesn’t improve 4 evals

for epoch in range(20):
    
    model.train() # enable dropout, etc.
    
    for step, (x, y, in_len, tar_len) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        opt.zero_grad()

        # Automatic Mixed Precision: FP16 forward & backward,
        # FP32 master weights held by optimiser.
        # autocast expects a bare device type ("cuda" / "cpu"), not "cuda:0".
        with autocast(device_type=torch.device(device).type):
            logprobs = model(x)                     # (B, T, V)
            loss = ctc(logprobs.permute(1,0,2), y, in_len, tar_len) # CTC expects (T, B, V)

        # scale -> backward (prevent underflow), then unscale for safe clipping
        scaler.scale(loss).backward()
        scaler.unscale_(opt)
        nn.utils.clip_grad_norm_(model.parameters(), 5.0) # exploding grads
        
        scaler.step(opt) # FP32 weight update
        scaler.update() # adjust scaling factor for next iter
        sched.step() # advance LR schedule
        
        if step % 100 == 0:
            print(f"[{epoch}] step {step}   loss {loss.item():.3f}")

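    # --- validation & early-stopping
    # (scored on test_loader here; a held-out split such as dev-clean
    #  would avoid tuning on the test set)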
    model.eval(); hyp, ref = [], []
    with torch.no_grad():
        for x, y, in_len, tar_len in test_loader:
            x = x.to(device)
            lp = model(x)                    # (B, T, V)
            ids = lp.argmax(-1)              # greedy decode
            for i, l in enumerate(in_len):
                hyp.append(decode([k for k,_ in groupby(ids[i][:l].cpu().tolist()) if k!=0]))
            split = torch.split(y, tar_len)
            ref.extend([decode(t.tolist()) for t in split])
    val_wer = wer(ref, hyp)
    print(f"Epoch {epoch}  valid WER: {val_wer:.3%}")
    if val_wer + 0.002 < best_wer:
        best_wer, patience = val_wer, 0
        torch.save(model.state_dict(), "best_asr.pt")
    else:
        patience += 1
        if patience >= PATIENCE:
            print("Early stopping – no improvement"); break

cuda:0
Using 2 GPUs via DataParallel
[0] step 0   loss 18.155
[0] step 100   loss 6.043
[0] step 200   loss 5.074
[0] step 300   loss 3.448
[0] step 400   loss 2.872
[0] step 500   loss 2.840
[0] step 600   loss 2.749
[0] step 700   loss 2.740
[0] step 800   loss 2.672
Epoch 0  valid WER: 98.741%
[1] step 0   loss 2.681
[1] step 100   loss 2.661
[1] step 200   loss 2.536
[1] step 300   loss 2.259
[1] step 400   loss 1.996
[1] step 500   loss 1.850
[1] step 600   loss 1.693
[1] step 700   loss 1.500
[1] step 800   loss 1.299
Epoch 1  valid WER: 87.403%
[2] step 0   loss 1.323
[2] step 100   loss 1.219
[2] step 200   loss 1.192
[2] step 300   loss 1.075
[2] step 400   loss 1.030
[2] step 500   loss 1.005
[2] step 600   loss 0.913
[2] step 700   loss 0.923
[2] step 800   loss 0.852
Epoch 2  valid WER: 67.565%
[3] step 0   loss 0.797
[3] step 100   loss 0.752
[3] step 200   loss 0.790
[3] step 300   loss 0.752
[3] step 400   loss 0.681
[3] step 500   loss 0.778
[3] step 600   loss 0.760
[3

## Post-training Evaluation

After the network has finished learning we need an objective scoreboard. This block switches the model to inference mode, runs it once over the test_loader, and prints the two ASR staples: Word-Error-Rate (WER) and Character-Error-Rate (CER). We use greedy decoding (arg-max per frame, then CTC collapse) because it is fast; it also gives a conservative estimate of the model’s quality, since beam search with a language model typically scores better.
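
The "CTC collapse" step is easy to see on a hand-written example. The sketch below uses a toy vocabulary (blank, space, then a-z) and a made-up sequence of per-frame arg-max IDs; both are assumptions for illustration, and only the collapse rule itself mirrors the evaluation loop.

```python
from itertools import groupby

BLANK_ID = 0
# Toy vocabulary (assumption): blank, space, then a-z.
toy_vocab = ["<BLANK>"] + list(" abcdefghijklmnopqrstuvwxyz")

def ctc_greedy_collapse(frame_ids, blank=BLANK_ID):
    """Merge runs of identical IDs, then drop blanks."""
    return [k for k, _ in groupby(frame_ids) if k != blank]

# Hand-written per-frame argmax IDs: h h _ e e l _ l l o _
frames = [9, 9, 0, 6, 6, 13, 0, 13, 13, 16, 0]
ids = ctc_greedy_collapse(frames)
print("".join(toy_vocab[i] for i in ids))   # -> hello
```

Note that the blank between the two runs of "l" is what keeps the double letter: without it the repeated IDs would merge into a single "l".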

In [19]:
model.eval(); hyp, ref = [], []
with torch.no_grad():
    for x, y, in_len, tar_len in test_loader:
        x = x.to(device)
        lp = model(x)                    # (B, T, V)
        ids = lp.argmax(-1)              # greedy
        for i, l in enumerate(in_len):
            hyp.append(decode([k for k,_ in groupby(ids[i][:l].cpu().tolist()) if k!=0]))
        split = torch.split(y, tar_len)
        ref.extend([decode(t.tolist()) for t in split])
val_wer = wer(ref, hyp)
val_cer = cer(ref, hyp)
print(f"WER: {val_wer:.3%}")
print(f"CER: {val_cer:.3%}")

WER: 19.975%
CER: 6.395%


* The network gets most letters right, so character accuracy is good.

* But a single letter slip can break an entire word, inflating WER.

* Word-boundary mistakes (extra or missing spaces) also hurt WER but not CER.
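
The gap between the two metrics can be reproduced on a three-word toy example. This is a minimal sketch with a hand-rolled Levenshtein distance; `toy_wer` and `toy_cer` are illustrative stand-ins and may differ from jiwer in how they normalise text.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (lists or strings)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            curr.append(min(prev[j] + 1,              # deletion
                            curr[j - 1] + 1,          # insertion
                            prev[j - 1] + (r != h)))  # substitution or match
        prev = curr
    return prev[-1]

def toy_wer(ref, hyp):
    """Word-level edit distance divided by the number of reference words."""
    return edit_distance(ref.split(), hyp.split()) / len(ref.split())

def toy_cer(ref, hyp):
    """Character-level edit distance divided by the reference length."""
    return edit_distance(ref, hyp) / len(ref)

ref = "fat mutton pieces"
hyp = "fat muttin pieces"
print(f"WER {toy_wer(ref, hyp):.1%}")
print(f"CER {toy_cer(ref, hyp):.1%}")
```

One wrong letter out of 17 characters costs under 6 % CER, yet it invalidates one word out of three, so WER jumps to about 33 %.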

How good is ~20 % WER on train-clean-100 -> test-clean?

| Setup (literature)                                            | WER on *test-clean* |
| ------------------------------------------------------------- | ------------------- |
| Shallow CRNN (≈23 M params) trained **from scratch** on 100 h | **18–22 %**         |
| Same network + **2-gram KenLM beam search**                   | 14–17 %             |
| Pre-trained *wav2vec 2.0* base → fine-tuned on 100 h          | 4–6 %               |
| OpenAI *Whisper-tiny.en* fine-tune (≈40 M params)             | ≈ 4 %               |


In [41]:
print("Target example :", ref[0])
print("Prediction :", hyp[0])

Target example : he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick peppered flour fattened sauce
Prediction : he hoped there would be stool for dinner turnips and carots and bruised potatos and fat muttin pieces to be ladled out and thick peppered flower fattens sauce


* The core content is intact; every noun but one is recognisable, which aligns with the low CER (6 %).

* Most mistakes are homophones or single-letter drops—classic symptoms of a character CTC model decoded without linguistic context.

* A small language model would likely fix the real-word confusions (flour/flower, stew/stool, in/and, fattened/fattens), whereas a simple spell-checker post-filter can only repair non-words such as carots, potatos and muttin. Either step should reduce WER by several points.

This single sentence illustrates how the quantitative numbers (≈ 20 % WER, 6 % CER) translate into qualitative perception: the transcript is easily understandable, yet word-level polish is needed for production-grade accuracy.

## Post-correction

After CTC decoding the acoustic model still produces small, but WER-inflating, surface errors—single-letter drops (potatos), homophones (flour -> flower), or wrong inflections (fattened -> fattens). Rather than retrain the network we pass the raw transcript through a lightweight dictionary-based spell-checker (pyspellchecker). 

For every token that is not in its dictionary, the module picks the most frequent candidate word within a small edit distance (two edits in the cell below), exploiting corpus word-frequency statistics. This CPU-only step is cheap compared with retraining and, in the run below, lowers WER by about three points while leaving CER essentially unchanged.

In [28]:
from spellchecker import SpellChecker

spell = SpellChecker(distance=2)          # candidates within edit distance ≤ 2 (distance=1 is faster but corrects less)

def post_correct(sentence: str) -> str:
    """
    Replace each OOV token with SpellChecker's best suggestion.
    Keeps punctuation and numbers untouched.
    """
    corrected = []
    for token in sentence.split():
        # skip tokens that are in dictionary OR contain non-letters
        if token.lower() in spell or not token.isalpha():
            corrected.append(token)
        else:
            corrected.append(spell.correction(token) or token)  # fallback: keep as is
    return " ".join(corrected)
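
To see why this kind of correction repairs non-words but leaves homophones alone, here is a toy frequency-based corrector limited to one edit. It is written in the spirit of dictionary spell-checkers and is not pyspellchecker's actual code; the word list `toy_freq` and its counts are made up.

```python
import string

# Tiny made-up word-frequency table standing in for a real dictionary.
toy_freq = {"carrots": 50, "carts": 20, "mutton": 30, "flour": 40, "flower": 60}

def edits1(word):
    """All strings one deletion, transposition, substitution or insertion away."""
    letters = string.ascii_lowercase
    splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
    deletes = [a + b[1:] for a, b in splits if b]
    transposes = [a + b[1] + b[0] + b[2:] for a, b in splits if len(b) > 1]
    replaces = [a + c + b[1:] for a, b in splits if b for c in letters]
    inserts = [a + c + b for a, b in splits for c in letters]
    return set(deletes + transposes + replaces + inserts)

def toy_correct(word):
    """Return the most frequent known word one edit away, or the word itself."""
    if word in toy_freq:                      # known word: left untouched
        return word
    candidates = [w for w in edits1(word) if w in toy_freq]
    return max(candidates, key=toy_freq.get) if candidates else word

for w in ["carots", "muttin", "flower", "xyzzy"]:
    print(w, "->", toy_correct(w))
```

Here "carots" has two known neighbours (carrots and carts) and the more frequent one wins, while "flower" is already a valid word and so is never reconsidered, even when the reference says "flour".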

In [29]:
hyps_raw, refs = [], []

with torch.no_grad():
    for feats, labels, in_len, tar_len in test_loader:
        feats = feats.to(device, non_blocking=True)
        logp  = model(feats)
        pred  = logp.argmax(-1).cpu()

        for i, T in enumerate(in_len):
            ids = [k for k, _ in groupby(pred[i][:T].tolist()) if k != 0]
            txt = decode(ids)
            hyps_raw.append(txt)

        split = torch.split(labels, tar_len)
        refs.extend(decode(t.tolist()) for t in split)

In [30]:
# Apply post-correction
hyps_spell = [post_correct(h) for h in hyps_raw]

# Compute metrics – before / after spell-check
wer_raw  = wer(refs, hyps_raw)
cer_raw  = cer(refs, hyps_raw)
wer_post = wer(refs, hyps_spell)
cer_post = cer(refs, hyps_spell)

print(f"Raw  →  WER {wer_raw:.2%}   |  CER {cer_raw:.2%}")
print(f"Spell-corrected  →  WER {wer_post:.2%}   |  CER {cer_post:.2%}")

Raw  →  WER 19.97%   |  CER 6.40%
Spell-corrected  →  WER 16.91%   |  CER 6.49%


In [40]:
print("Target example :", refs[0])
print("Prediction :", hyps_spell[0])

Target example : he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick peppered flour fattened sauce
Prediction : he hoped there would be stool for dinner turnips and carrots and bruised potato and fat mutton pieces to be ladled out and thick peppered flower fattens sauce


## Transcription

After training we need a convenient, reusable function that turns an arbitrary .wav clip into plain text. The transcribe routine below wraps the whole inference pipeline in five concise steps:

1. Load and preprocess: the waveform is read with torchaudio.load, resampled (if needed) and converted to a log-Mel spectrogram by reusing the preprocess function.

2. Batchify and move to GPU/CPU: a dummy batch dimension is added so the tensor shape matches what the network expects, and the tensor is copied to the chosen device.

3. Forward pass (no gradients): inside torch.no_grad() the model produces log-probabilities over the vocabulary for every time frame.

4. Greedy CTC decoding: argmax selects the highest-probability token per frame; consecutive duplicates and blank tokens are collapsed, then mapped back to characters with decode.

5. Optional spell-checking: if spellcheck=True, the raw string is run through the post_correct spell-checker to fix common non-word spelling errors, which improves readability and lowered WER by about three points on test-clean above.

The function returns the final transcript string and can be called on any short (≈10 s) WAV file in a single line.

In [42]:
# Single-file transcription routine
def transcribe(wav_path: str, spellcheck: bool = True) -> str:
    """Return the (optionally spell-corrected) transcription of wav_path."""
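    wav, sr = torchaudio.load(wav_path)              # (channels, N)
    wav = wav.mean(dim=0, keepdim=True)              # down-mix to mono -> (1, N)
    feat = preprocess(wav, sr, augment=False)        # (T, F)
    feat = feat.unsqueeze(0).to(device)              # (1, T, F)

    model.eval()                                     # make sure dropout is disabled
    with torch.no_grad():
        logits = model(feat)                         # (1, T, V) log-probabilities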

    ids = logits.argmax(-1)[0].cpu().tolist()   # (T,)
    ids = [k for k, _ in groupby(ids) if k != 0]
    text = decode(ids)
        
    if spellcheck:
        text = post_correct(text)

    return text

To test how well our purely English-trained model copes with foreign accents, we fed it a 10-second extract from a speech by François Hollande, former President of France and a clearly non-native English speaker. Although LibriSpeech is dominated by native English readers, the model still produces a number of recognisable English words, even if the sentence as a whole is hard to follow.

In [48]:
wav_file = "/kaggle/input/hollande/francois_hollande.wav"
print("Transcription:\n", transcribe(wav_file))
print("Transcription (no post-correction):\n", transcribe(wav_file, False))

Transcription:
 repro all you because you gon't be do what we wale to do
Transcription (no post-correction):
 reprod  all you becauese you gon't be do whatk we wale to do


The lightweight post-correction layer repairs some of the letter-level slips (becauese -> because, whatk -> what) with no extra acoustic training, but it cannot recover words that the acoustic model misheard, such as gon’t or wale.

This small experiment highlights two points:

* The CRNN-CTC network carries over to accented speech only partially: common short words survive, while others come out garbled.

* A cheap text-only spell-checker can clean up some accent-induced spelling noise, but closing the gap to native-speaker performance would likely require a language model or training on accented data.