<a href="https://colab.research.google.com/github/alexjercan/asr-toolkit/blob/master/examples/03_LanguageModels_with_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.

Instructions for setting up Colab are as follows:
1. Open a new Python 3 notebook.
2. Import this notebook from GitHub (File -> Upload Notebook -> "GITHUB" tab -> copy/paste GitHub URL)
3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select "GPU" for hardware accelerator)
4. Run this cell to set up dependencies.
5. Restart the runtime (Runtime -> Restart Runtime) for any upgraded packages to take effect
"""
# If you're using Google Colab and not running locally, run this cell.

## Install dependencies
!pip install wget
!apt-get install sox libsndfile1 ffmpeg
!pip install unidecode
!pip install matplotlib>=3.3.2
!apt-get install libsox-fmt-all libsox-dev sox > /dev/null
!pip install torchaudio
!python -m pip install git+https://github.com/facebookresearch/WavAugment.git > /dev/null
!pip install wandb

## Install NeMo
BRANCH = 'main'
!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]

# install beam search decoder
!apt-get install -y swig
!git clone https://github.com/NVIDIA/NeMo -b "$BRANCH"
!cd NeMo && bash scripts/asr_language_modeling/ngram_lm/install_beamsearch_decoders.sh


"""
Remember to restart the runtime for the kernel to pick up any upgraded packages (e.g. matplotlib)!
Alternatively, you can uncomment the exit() below to crash and restart the kernel, in the case
that you want to use the "Run All Cells" (or similar) option.
"""
# exit()
from IPython.display import clear_output
clear_output()

In [None]:
import gzip
import os
import re
import shutil
import urllib.request

import wget

import nemo
import nemo.collections.asr as nemo_asr
import torch
import torch.nn as nn
import numpy as np
import augment
import torchaudio
import torchaudio.datasets

from datetime import datetime as dt
from tqdm import tqdm
import matplotlib.pyplot as plt

from asr.metrics import ASRMetricFunction, CTCLossFunction
from asr.visualisation import play_audio, print_err_html, print_stats, plot_waveform
from asr.general import set_parameter_requires_grad, load_checkpoint, save_checkpoint, tensors_to_device, tensor_to_string
from asr.utils import ChainRunner
from asr.models import GreedyDecoder, BeamSearchDecoderWithLM
from asr.datasets import librispeech_dataloader
from IPython.display import YouTubeVideo

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print('Setup complete. Using torch {} {}'.format(
    torch.__version__,
    torch.cuda.get_device_properties(0) if DEVICE == "cuda" else 'CPU',
))

# Pretrained NeMo checkpoint and the ARPA language-model files fetched later from OpenSLR.
MODEL_NAME = 'stt_en_jasper10x5dr'
LM_3GRAM_PATH, LM_4GRAM_PATH = '3-gram.arpa', '4-gram.arpa'
# Dataset root: the current working directory (os.path.join(".") is simply ".").
ROOT = "."

[NeMo W 2021-08-11 13:04:35 optimizers:47] Apex was not found. Using the lamb optimizer will error out.
################################################################################
###          (please add 'export KALDI_ROOT=<your_path>' in your $HOME/.profile)
###          (or run as: KALDI_ROOT=<your_path> python <your_script>.py)
################################################################################



[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package cmudict to /root/nltk_data...
[nltk_data]   Package cmudict is already up-to-date!


[NeMo W 2021-08-11 13:04:47 experimental:28] Module <class 'nemo.collections.asr.data.audio_to_text_dali._AudioTextDALIDataset'> is experimental, not ready for production and is not fully supported. Use at your own risk.


Setup complete. Using torch 1.9.0+cu102 _CudaDeviceProperties(name='Tesla K80', major=3, minor=7, total_memory=11441MB, multi_processor_count=13)


In [None]:
def download_lm(lm_path, base_url="https://www.openslr.org/resources/11", force=False):
    """Download and decompress an ARPA language model from OpenSLR.

    Fetches ``{base_url}/{lm_path}.gz`` into ``{lm_path}.gz`` and decompresses it
    into ``lm_path``. Both paths are relative to the current working directory.
    Files that already exist are reused, so re-running the notebook does not
    re-download the multi-hundred-megabyte archives.

    Args:
        lm_path: File name of the decompressed model, e.g. ``'3-gram.arpa'``.
        base_url: URL directory that contains ``{lm_path}.gz``.
        force: If True, delete any existing copies first and download again.
    """
    archive_path = f"{lm_path}.gz"

    if force:
        for stale_path in (lm_path, archive_path):
            if os.path.exists(stale_path):
                os.remove(stale_path)

    if os.path.exists(lm_path):
        print(f"Using cached {lm_path}")
        return

    if not os.path.exists(archive_path):
        url = f"{base_url}/{archive_path}"
        print(f"Downloading {url} -> {archive_path}")
        # Write to a temporary name and rename only on success, so an interrupted
        # download is never mistaken for a complete archive on the next run.
        partial_archive = f"{archive_path}.part"
        with urllib.request.urlopen(url) as response, open(partial_archive, "wb") as out:
            shutil.copyfileobj(response, out)
        os.replace(partial_archive, archive_path)

    print(f"Decompressing {archive_path} -> {lm_path}")
    # Same temporary-name-then-rename pattern for the decompressed model.
    partial_lm = f"{lm_path}.part"
    with gzip.open(archive_path, "rb") as src, open(partial_lm, "wb") as dst:
        shutil.copyfileobj(src, dst)
    os.replace(partial_lm, lm_path)

# Pretrained CTC acoustic model, moved to the selected device.
model = nemo_asr.models.EncDecCTCModel.from_pretrained(
    model_name=MODEL_NAME,
    strict=False,
).to(DEVICE)

# Upper-case the model's character vocabulary (presumably to match upper-case
# LibriSpeech transcripts and ARPA LMs — TODO confirm) and append a '<pad>' token
# whose index, the last one, is used as BLANK by the loss and decoders below.
VOCABULARY = [token.upper() for token in model.decoder.vocabulary]
vocab = [*VOCABULARY, '<pad>']
BLANK = len(vocab) - 1

# token -> index, and the inverse index -> token mapping.
DICTIONARY = {token: index for index, token in enumerate(vocab)}
LABELS = {index: token for token, index in DICTIONARY.items()}

_, test_dataloader = librispeech_dataloader(
    DICTIONARY,
    root=ROOT,
    urls=["test-clean"],
    folder_in_archive="LibriSpeech",
    batch_size=4,
    download=True,
)

download_lm(LM_3GRAM_PATH)
download_lm(LM_4GRAM_PATH)

[NeMo I 2021-08-11 13:05:29 cloud:56] Found existing object /root/.cache/torch/NeMo/NeMo_1.2.0/stt_en_jasper10x5dr/856ae08d5c4bd78b5e27f696e96f7aab/stt_en_jasper10x5dr.nemo.
[NeMo I 2021-08-11 13:05:29 cloud:62] Re-using file from: /root/.cache/torch/NeMo/NeMo_1.2.0/stt_en_jasper10x5dr/856ae08d5c4bd78b5e27f696e96f7aab/stt_en_jasper10x5dr.nemo
[NeMo I 2021-08-11 13:05:29 common:681] Instantiating model from pre-trained checkpoint


[NeMo W 2021-08-11 13:06:04 modelPT:131] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: /data2/voices/train_1k.json
    sample_rate: 16000
    labels:
    - ' '
    - a
    - b
    - c
    - d
    - e
    - f
    - g
    - h
    - i
    - j
    - k
    - l
    - m
    - 'n'
    - o
    - p
    - q
    - r
    - s
    - t
    - u
    - v
    - w
    - x
    - 'y'
    - z
    - ''''
    batch_size: 32
    trim_silence: true
    max_duration: 16.7
    shuffle: true
    is_tarred: false
    tarred_audio_filepaths: null
    
[NeMo W 2021-08-11 13:06:04 modelPT:138] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config : 
    manifest_filepath: /data2/vo

[NeMo I 2021-08-11 13:06:04 features:252] PADDING: 16
[NeMo I 2021-08-11 13:06:04 features:269] STFT using torch
[NeMo I 2021-08-11 13:06:29 save_restore_connector:143] Model EncDecCTCModel was successfully restored from /root/.cache/torch/NeMo/NeMo_1.2.0/stt_en_jasper10x5dr/856ae08d5c4bd78b5e27f696e96f7aab/stt_en_jasper10x5dr.nemo.
removed '3-gram.arpa'
removed '3-gram.arpa.gz'
--2021-08-11 13:06:30--  https://www.openslr.org/resources/11/3-gram.arpa.gz
Resolving www.openslr.org (www.openslr.org)... 46.101.158.64
Connecting to www.openslr.org (www.openslr.org)|46.101.158.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 759636181 (724M) [application/x-gzip]
Saving to: ‘3-gram.arpa.gz’


2021-08-11 13:06:37 (104 MB/s) - ‘3-gram.arpa.gz’ saved [759636181/759636181]

3-gram.arpa.gz:	 68.3%
removed '4-gram.arpa'
removed '4-gram.arpa.gz'
--2021-08-11 13:07:22--  https://www.openslr.org/resources/11/4-gram.arpa.gz
Resolving www.openslr.org (www.openslr.org)... 46.

In [None]:
# Test model: evaluate on test-clean with greedy CTC decoding (no beam search,
# no language model). The progress bar shows loss_fn's running summary; WER/CER
# are printed at the end.
greedy_lm = GreedyDecoder(LABELS, BLANK)
metric_fn = ASRMetricFunction()
loss_fn = CTCLossFunction(blank=BLANK)

model.eval()
loop = tqdm(test_dataloader, position=0, leave=True)
for batch_idx, tensors in enumerate(loop):
    valid_lengths, waveform, target_lengths, utterance = tensors_to_device(tensors, DEVICE)

    with torch.no_grad():
        log_probs, encoded_len, greedy_predictions = model(
            input_signal=waveform,
            input_signal_length=valid_lengths,
        )
        # The first two axes are swapped before the loss; presumably the model emits
        # batch-major log-probs and the loss wants time-major — TODO confirm.
        loss_fn(log_probs.permute(1, 0, 2), utterance, encoded_len, target_lengths)
        transcriptions = greedy_lm(greedy_predictions, predictions_len=encoded_len)

    metric_fn(
        tensor_to_string(utterance, target_lengths, LABELS),
        transcriptions,
    )
    loop.set_postfix(loss=loss_fn.show())
loop.close()
print(metric_fn.show())

100%|██████████| 655/655 [18:51<00:00,  1.73s/it, loss=(ctc:0.0531)]


WER=4.1016	CER=1.2654






In [None]:
# Test model: CTC beam search (width 16) without an external language model,
# as a baseline for the n-gram runs that follow.
print("Testing without language model")
beam_search_lm = BeamSearchDecoderWithLM(
    vocab=VOCABULARY,
    beam_width=16,
    alpha=1.5, beta=1.5,
    lm_path=None,
    # os.cpu_count() returns None when the count cannot be determined, and
    # max(None, 1) raises TypeError; `or 1` falls back to a single worker instead.
    num_cpus=os.cpu_count() or 1)

def best_transcriptions(transcriptions):
    """Return the top-ranked hypothesis text for each utterance in a batch.

    Picks element ``[0][1]`` of every per-utterance entry, presumably the text of
    the best (score, text) beam with beams ordered best-first — TODO confirm
    against ``BeamSearchDecoderWithLM``.
    """
    return [hypotheses[0][1] for hypotheses in transcriptions]

model.eval()
metric_fn = ASRMetricFunction()
loss_fn = CTCLossFunction(blank=BLANK)
loop = tqdm(test_dataloader, position=0, leave=True)

for tensors in loop:
    valid_lengths, waveform, target_lengths, utterance = tensors_to_device(tensors, DEVICE)

    with torch.no_grad():
        log_probs, encoded_len, greedy_predictions = model(input_signal=waveform, input_signal_length=valid_lengths)
        loss_fn(log_probs.permute(1, 0, 2), utterance, encoded_len, target_lengths)

        # Beam search decodes from the log-probs; greedy_predictions is not used here.
        transcriptions = beam_search_lm(log_probs=log_probs, log_probs_length=encoded_len)

    metric_fn(tensor_to_string(utterance, target_lengths, LABELS), best_transcriptions(transcriptions))

    loop.set_postfix(loss=loss_fn.show())
loop.close()
print(metric_fn.show())

Testing without languange model


100%|██████████| 655/655 [19:42<00:00,  1.81s/it, loss=(ctc:0.0531)]


WER=4.1086	CER=1.2684






In [None]:
# Test model: CTC beam search (width 16) rescored with the 3-gram ARPA language model.
print("Testing 3-gram language model")
beam_search_lm = BeamSearchDecoderWithLM(
    vocab=VOCABULARY,
    beam_width=16,
    alpha=1.5, beta=1.5,
    lm_path=LM_3GRAM_PATH,
    # os.cpu_count() returns None when the count cannot be determined, and
    # max(None, 1) raises TypeError; `or 1` falls back to a single worker instead.
    num_cpus=os.cpu_count() or 1)

# `best_transcriptions` is defined in the no-LM beam-search cell above; it is
# not re-defined here, to avoid duplicate definitions in the notebook.

model.eval()
metric_fn = ASRMetricFunction()
loss_fn = CTCLossFunction(blank=BLANK)
loop = tqdm(test_dataloader, position=0, leave=True)

for tensors in loop:
    valid_lengths, waveform, target_lengths, utterance = tensors_to_device(tensors, DEVICE)

    with torch.no_grad():
        log_probs, encoded_len, greedy_predictions = model(input_signal=waveform, input_signal_length=valid_lengths)
        loss_fn(log_probs.permute(1, 0, 2), utterance, encoded_len, target_lengths)

        # Beam search decodes from the log-probs; greedy_predictions is not used here.
        transcriptions = beam_search_lm(log_probs=log_probs, log_probs_length=encoded_len)

    metric_fn(tensor_to_string(utterance, target_lengths, LABELS), best_transcriptions(transcriptions))

    loop.set_postfix(loss=loss_fn.show())
loop.close()
print(metric_fn.show())

Testing 3-gram languange model


  0%|          | 0/655 [00:00<?, ?it/s][NeMo W 2021-08-11 13:11:08 patch_utils:50] torch.stft() signature has been updated for PyTorch 1.7+
    Please update PyTorch to remain compatible with later versions of NeMo.
    To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /pytorch/aten/src/ATen/native/BinaryOps.cpp:467.)
      return torch.floor_divide(self, other)
    
100%|██████████| 655/655 [19:11<00:00,  1.76s/it, loss=(ctc:0.0531)]


WER=3.7385	CER=1.3473






In [None]:
# Test model: CTC beam search (width 16) rescored with the 4-gram ARPA language model.
print("Testing 4-gram language model")
beam_search_lm = BeamSearchDecoderWithLM(
    vocab=VOCABULARY,
    beam_width=16,
    alpha=1.5, beta=1.5,
    lm_path=LM_4GRAM_PATH,
    # os.cpu_count() returns None when the count cannot be determined, and
    # max(None, 1) raises TypeError; `or 1` falls back to a single worker instead.
    num_cpus=os.cpu_count() or 1)

# `best_transcriptions` is defined in the no-LM beam-search cell above; it is
# not re-defined here, to avoid duplicate definitions in the notebook.

model.eval()
metric_fn = ASRMetricFunction()
loss_fn = CTCLossFunction(blank=BLANK)
loop = tqdm(test_dataloader, position=0, leave=True)

for tensors in loop:
    valid_lengths, waveform, target_lengths, utterance = tensors_to_device(tensors, DEVICE)

    with torch.no_grad():
        log_probs, encoded_len, greedy_predictions = model(input_signal=waveform, input_signal_length=valid_lengths)
        loss_fn(log_probs.permute(1, 0, 2), utterance, encoded_len, target_lengths)

        # Beam search decodes from the log-probs; greedy_predictions is not used here.
        transcriptions = beam_search_lm(log_probs=log_probs, log_probs_length=encoded_len)

    metric_fn(tensor_to_string(utterance, target_lengths, LABELS), best_transcriptions(transcriptions))

    loop.set_postfix(loss=loss_fn.show())
loop.close()
print(metric_fn.show())

  0%|          | 0/655 [00:00<?, ?it/s]

Testing 4-gram languange model


100%|██████████| 655/655 [19:12<00:00,  1.76s/it, loss=(ctc:0.0531)]


WER=3.7676	CER=1.3546




