From ca567b632813b183450d3dac1fcdf401e0608646 Mon Sep 17 00:00:00 2001 From: hwangjeff Date: Wed, 15 Jun 2022 18:21:13 +0000 Subject: [PATCH] [release 0.12] Remove examples code using prototype features from release --- examples/asr/emformer_rnnt/README.md | 4 +- examples/asr/emformer_rnnt/pipeline_demo.py | 15 +- .../emformer_rnnt/tedlium3/eval_pipeline.py | 90 -------- .../asr/librispeech_conformer_rnnt/README.md | 49 ----- .../librispeech_conformer_rnnt/data_module.py | 194 ------------------ .../asr/librispeech_conformer_rnnt/eval.py | 75 ------- .../global_stats.json | 166 --------------- .../librispeech_conformer_rnnt/lightning.py | 183 ----------------- .../asr/librispeech_conformer_rnnt/train.py | 106 ---------- .../librispeech_conformer_rnnt/train_spm.py | 80 -------- .../librispeech_conformer_rnnt/transforms.py | 109 ---------- 11 files changed, 2 insertions(+), 1069 deletions(-) delete mode 100644 examples/asr/emformer_rnnt/tedlium3/eval_pipeline.py delete mode 100644 examples/asr/librispeech_conformer_rnnt/README.md delete mode 100644 examples/asr/librispeech_conformer_rnnt/data_module.py delete mode 100644 examples/asr/librispeech_conformer_rnnt/eval.py delete mode 100644 examples/asr/librispeech_conformer_rnnt/global_stats.json delete mode 100644 examples/asr/librispeech_conformer_rnnt/lightning.py delete mode 100644 examples/asr/librispeech_conformer_rnnt/train.py delete mode 100644 examples/asr/librispeech_conformer_rnnt/train_spm.py delete mode 100644 examples/asr/librispeech_conformer_rnnt/transforms.py diff --git a/examples/asr/emformer_rnnt/README.md b/examples/asr/emformer_rnnt/README.md index 5925bf742f..1dce040f10 100644 --- a/examples/asr/emformer_rnnt/README.md +++ b/examples/asr/emformer_rnnt/README.md @@ -15,7 +15,7 @@ This directory contains sample implementations of training and evaluation pipeli ### Pipeline Demo [`pipeline_demo.py`](./pipeline_demo.py) demonstrates how to use the `EMFORMER_RNNT_BASE_LIBRISPEECH` -or `EMFORMER_RNNT_BASE_TEDLIUM3` bundle that wraps a pre-trained Emformer RNN-T produced by the corresponding recipe below to perform streaming and full-context ASR on several audio samples. +bundle that wraps a pre-trained Emformer RNN-T produced by the LibriSpeech recipe below to perform streaming and full-context ASR on several audio samples. ## Model Types @@ -67,8 +67,6 @@ The table below contains WER results for dev and test subsets of TED-LIUM releas | dev | 0.108 | | test | 0.098 | -[`tedlium3/eval_pipeline.py`](./tedlium3/eval_pipeline.py) evaluates the pre-trained `EMFORMER_RNNT_BASE_TEDLIUM3` bundle on the dev and test sets of TED-LIUM release 3. Running the script should produce WER results that are identical to those in the above table. - ### MuST-C release v2.0 The MuST-C model is configured with a vocabulary size of 500. Consequently, the MuST-C model's last linear layer in the joiner has an output dimension of 501 (500 + 1 to account for the blank symbol). In contrast to those of the datasets for the above two models, MuST-C's transcripts are cased and punctuated; we preserve the casing and punctuation when training the SentencePiece model. 
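For reference, the retained `EMFORMER_RNNT_BASE_LIBRISPEECH` bundle is exercised as in the following minimal full-context decoding sketch, which mirrors the flow of the TED-LIUM evaluation script removed below. The audio path `sample.wav` is a placeholder, and the beam width of 20 matches the removed script:

```python
import torch
import torchaudio
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH

# The RNNTBundle exposes the three components needed for inference.
feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_feature_extractor()
decoder = EMFORMER_RNNT_BASE_LIBRISPEECH.get_decoder()
token_processor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_token_processor()

# "sample.wav" is a placeholder; the bundle expects 16 kHz mono audio.
waveform, sample_rate = torchaudio.load("sample.wav")

with torch.no_grad():
    features, length = feature_extractor(waveform.squeeze())
    hypotheses = decoder(features, length, 20)  # beam width of 20
    transcript = token_processor(hypotheses[0][0])  # tokens of the best hypothesis

print(transcript)
```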
diff --git a/examples/asr/emformer_rnnt/pipeline_demo.py b/examples/asr/emformer_rnnt/pipeline_demo.py index 4821c1b8ca..de935f0167 100644 --- a/examples/asr/emformer_rnnt/pipeline_demo.py +++ b/examples/asr/emformer_rnnt/pipeline_demo.py @@ -13,13 +13,8 @@ import torch import torchaudio -from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_MUSTC, MODEL_TYPE_TEDLIUM3 -from mustc.dataset import MUSTC +from common import MODEL_TYPE_LIBRISPEECH from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle -from torchaudio.prototype.pipelines import ( - EMFORMER_RNNT_BASE_MUSTC, - EMFORMER_RNNT_BASE_TEDLIUM3, -) logger = logging.getLogger(__name__) @@ -35,14 +30,6 @@ class Config: partial(torchaudio.datasets.LIBRISPEECH, url="test-clean"), EMFORMER_RNNT_BASE_LIBRISPEECH, ), - MODEL_TYPE_MUSTC: Config( - partial(MUSTC, subset="tst-COMMON"), - EMFORMER_RNNT_BASE_MUSTC, - ), - MODEL_TYPE_TEDLIUM3: Config( - partial(torchaudio.datasets.TEDLIUM, release="release3", subset="test"), - EMFORMER_RNNT_BASE_TEDLIUM3, - ), } diff --git a/examples/asr/emformer_rnnt/tedlium3/eval_pipeline.py b/examples/asr/emformer_rnnt/tedlium3/eval_pipeline.py deleted file mode 100644 index fa34098baa..0000000000 --- a/examples/asr/emformer_rnnt/tedlium3/eval_pipeline.py +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env python3 -import logging -import pathlib -from argparse import ArgumentParser, RawTextHelpFormatter - -import torch -import torchaudio -from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_TEDLIUM3 - - -logger = logging.getLogger(__name__) - - -def compute_word_level_distance(seq1, seq2): - return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split()) - - -def _eval_subset(tedlium_path, subset, feature_extractor, decoder, token_processor, use_cuda): - total_edit_distance = 0 - total_length = 0 - if subset == "dev": - dataset = torchaudio.datasets.TEDLIUM(tedlium_path, release="release3", subset="dev") - elif subset == "test": - dataset = torchaudio.datasets.TEDLIUM(tedlium_path, release="release3", subset="test") - with torch.no_grad(): - for idx in range(len(dataset)): - sample = dataset[idx] - waveform = sample[0].squeeze() - if use_cuda: - waveform = waveform.to(device="cuda") - actual = sample[2].replace("\n", "") - if actual == "ignore_time_segment_in_scoring": - continue - features, length = feature_extractor(waveform) - hypos = decoder(features, length, 20) - hypothesis = hypos[0] - hypothesis = token_processor(hypothesis[0]) - total_edit_distance += compute_word_level_distance(actual, hypothesis) - total_length += len(actual.split()) - if idx % 100 == 0: - print(f"Processed elem {idx}; WER: {total_edit_distance / total_length}") - print(f"Final WER for {subset} set: {total_edit_distance / total_length}") - - -def run_eval_pipeline(args): - decoder = EMFORMER_RNNT_BASE_TEDLIUM3.get_decoder() - token_processor = EMFORMER_RNNT_BASE_TEDLIUM3.get_token_processor() - feature_extractor = EMFORMER_RNNT_BASE_TEDLIUM3.get_feature_extractor() - - if args.use_cuda: - feature_extractor = feature_extractor.to(device="cuda").eval() - decoder = decoder.to(device="cuda") - _eval_subset(args.tedlium_path, "dev", feature_extractor, decoder, token_processor, args.use_cuda) - _eval_subset(args.tedlium_path, "test", feature_extractor, decoder, token_processor, args.use_cuda) - - -def _parse_args(): - parser = ArgumentParser( - description=__doc__, - formatter_class=RawTextHelpFormatter, - ) - parser.add_argument( - "--tedlium-path", - type=pathlib.Path, - help="Path to 
TED-LIUM release 3 dataset.",
-    )
-    parser.add_argument(
-        "--use-cuda",
-        action="store_true",
-        default=False,
-        help="Run using CUDA.",
-    )
-    parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
-    return parser.parse_args()
-
-
-def _init_logger(debug):
-    fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
-    level = logging.DEBUG if debug else logging.INFO
-    logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
-
-
-def cli_main():
-    args = _parse_args()
-    _init_logger(args.debug)
-    run_eval_pipeline(args)
-
-
-if __name__ == "__main__":
-    cli_main()
diff --git a/examples/asr/librispeech_conformer_rnnt/README.md b/examples/asr/librispeech_conformer_rnnt/README.md
deleted file mode 100644
index 4ad68c3165..0000000000
--- a/examples/asr/librispeech_conformer_rnnt/README.md
+++ /dev/null
@@ -1,49 +0,0 @@
-# Conformer RNN-T ASR Example
-
-This directory contains sample implementations of training and evaluation pipelines for a Conformer RNN-T ASR model.
-
-## Setup
-### Install PyTorch and TorchAudio nightly or from source
-Because Conformer RNN-T is currently a prototype feature, you will need to either use the TorchAudio nightly build or build TorchAudio from source. Note also that GPU support is required for training.
-
-To install the nightly, follow the directions at [pytorch.org](https://pytorch.org/).
-
-To build TorchAudio from source, refer to the [contributing guidelines](https://github.com/pytorch/audio/blob/main/CONTRIBUTING.md).
-
-### Install additional dependencies
-```bash
-pip install pytorch-lightning sentencepiece
-```
-
-## Usage
-
-### Training
-
-[`train.py`](./train.py) trains a Conformer RNN-T model (30.2M parameters, 121MB) on LibriSpeech using PyTorch Lightning. Note that the script expects users to have the following:
-- Access to GPU nodes for training.
-- Full LibriSpeech dataset.
-- SentencePiece model to be used to encode targets; the model can be generated using [`train_spm.py`](./train_spm.py).
-- File (`--global_stats_path`) that contains training set feature statistics; this file can be generated using [`global_stats.py`](../emformer_rnnt/global_stats.py).
-
-Sample SLURM command:
-```
-srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train.py --exp_dir ./experiments --librispeech_path ./librispeech/ --global_stats_path ./global_stats.json --sp_model_path ./spm_unigram_1023.model --epochs 160
-```
-
-### Evaluation
-
-[`eval.py`](./eval.py) evaluates a trained Conformer RNN-T model on LibriSpeech test-clean.
-
-Sample SLURM command:
-```
-srun python eval.py --checkpoint_path ./experiments/checkpoints/epoch=159.ckpt --librispeech_path ./librispeech/ --sp_model_path ./spm_unigram_1023.model --use_cuda
-```
-
-The table below contains WER results for various splits.
- -| | WER | -|:-------------------:|-------------:| -| test-clean | 0.0310 | -| test-other | 0.0805 | -| dev-clean | 0.0314 | -| dev-other | 0.0827 | diff --git a/examples/asr/librispeech_conformer_rnnt/data_module.py b/examples/asr/librispeech_conformer_rnnt/data_module.py deleted file mode 100644 index b256a8c902..0000000000 --- a/examples/asr/librispeech_conformer_rnnt/data_module.py +++ /dev/null @@ -1,194 +0,0 @@ -import os -import random - -import torch -import torchaudio -from pytorch_lightning import LightningDataModule - - -def _batch_by_token_count(idx_target_lengths, max_tokens, batch_size=None): - batches = [] - current_batch = [] - current_token_count = 0 - for idx, target_length in idx_target_lengths: - if current_token_count + target_length > max_tokens or (batch_size and len(current_batch) == batch_size): - batches.append(current_batch) - current_batch = [idx] - current_token_count = target_length - else: - current_batch.append(idx) - current_token_count += target_length - - if current_batch: - batches.append(current_batch) - - return batches - - -def get_sample_lengths(librispeech_dataset): - fileid_to_target_length = {} - - def _target_length(fileid): - if fileid not in fileid_to_target_length: - speaker_id, chapter_id, _ = fileid.split("-") - - file_text = speaker_id + "-" + chapter_id + librispeech_dataset._ext_txt - file_text = os.path.join(librispeech_dataset._path, speaker_id, chapter_id, file_text) - - with open(file_text) as ft: - for line in ft: - fileid_text, transcript = line.strip().split(" ", 1) - fileid_to_target_length[fileid_text] = len(transcript) - - return fileid_to_target_length[fileid] - - return [_target_length(fileid) for fileid in librispeech_dataset._walker] - - -class CustomBucketDataset(torch.utils.data.Dataset): - def __init__( - self, - dataset, - lengths, - max_tokens, - num_buckets, - shuffle=False, - batch_size=None, - ): - super().__init__() - - assert len(dataset) == len(lengths) - - self.dataset = dataset - - max_length = max(lengths) - min_length = min(lengths) - - assert max_tokens >= max_length - - buckets = torch.linspace(min_length, max_length, num_buckets) - lengths = torch.tensor(lengths) - bucket_assignments = torch.bucketize(lengths, buckets) - - idx_length_buckets = [(idx, length, bucket_assignments[idx]) for idx, length in enumerate(lengths)] - if shuffle: - idx_length_buckets = random.sample(idx_length_buckets, len(idx_length_buckets)) - else: - idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[1], reverse=True) - - sorted_idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[2]) - self.batches = _batch_by_token_count( - [(idx, length) for idx, length, _ in sorted_idx_length_buckets], - max_tokens, - batch_size=batch_size, - ) - - def __getitem__(self, idx): - return [self.dataset[subidx] for subidx in self.batches[idx]] - - def __len__(self): - return len(self.batches) - - -class TransformDataset(torch.utils.data.Dataset): - def __init__(self, dataset, transform_fn): - self.dataset = dataset - self.transform_fn = transform_fn - - def __getitem__(self, idx): - return self.transform_fn(self.dataset[idx]) - - def __len__(self): - return len(self.dataset) - - -class LibriSpeechDataModule(LightningDataModule): - librispeech_cls = torchaudio.datasets.LIBRISPEECH - - def __init__( - self, - *, - librispeech_path, - train_transform, - val_transform, - test_transform, - max_tokens=700, - batch_size=2, - train_num_buckets=50, - train_shuffle=True, - num_workers=10, - ): - super().__init__() - 
self.librispeech_path = librispeech_path - self.train_dataset_lengths = None - self.val_dataset_lengths = None - self.train_transform = train_transform - self.val_transform = val_transform - self.test_transform = test_transform - self.max_tokens = max_tokens - self.batch_size = batch_size - self.train_num_buckets = train_num_buckets - self.train_shuffle = train_shuffle - self.num_workers = num_workers - - def train_dataloader(self): - datasets = [ - self.librispeech_cls(self.librispeech_path, url="train-clean-360"), - self.librispeech_cls(self.librispeech_path, url="train-clean-100"), - self.librispeech_cls(self.librispeech_path, url="train-other-500"), - ] - - if not self.train_dataset_lengths: - self.train_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets] - - dataset = torch.utils.data.ConcatDataset( - [ - CustomBucketDataset( - dataset, - lengths, - self.max_tokens, - self.train_num_buckets, - batch_size=self.batch_size, - ) - for dataset, lengths in zip(datasets, self.train_dataset_lengths) - ] - ) - dataset = TransformDataset(dataset, self.train_transform) - dataloader = torch.utils.data.DataLoader( - dataset, - num_workers=self.num_workers, - batch_size=None, - shuffle=self.train_shuffle, - ) - return dataloader - - def val_dataloader(self): - datasets = [ - self.librispeech_cls(self.librispeech_path, url="dev-clean"), - self.librispeech_cls(self.librispeech_path, url="dev-other"), - ] - - if not self.val_dataset_lengths: - self.val_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets] - - dataset = torch.utils.data.ConcatDataset( - [ - CustomBucketDataset( - dataset, - lengths, - self.max_tokens, - 1, - batch_size=self.batch_size, - ) - for dataset, lengths in zip(datasets, self.val_dataset_lengths) - ] - ) - dataset = TransformDataset(dataset, self.val_transform) - dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=self.num_workers) - return dataloader - - def test_dataloader(self): - dataset = self.librispeech_cls(self.librispeech_path, url="test-clean") - dataset = TransformDataset(dataset, self.test_transform) - dataloader = torch.utils.data.DataLoader(dataset, batch_size=None) - return dataloader diff --git a/examples/asr/librispeech_conformer_rnnt/eval.py b/examples/asr/librispeech_conformer_rnnt/eval.py deleted file mode 100644 index 15c4cb6646..0000000000 --- a/examples/asr/librispeech_conformer_rnnt/eval.py +++ /dev/null @@ -1,75 +0,0 @@ -import logging -import pathlib -from argparse import ArgumentParser - -import torch -import torchaudio -from lightning import ConformerRNNTModule, get_data_module - - -logger = logging.getLogger() - - -def compute_word_level_distance(seq1, seq2): - return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split()) - - -def run_eval(args): - model = ConformerRNNTModule.load_from_checkpoint(args.checkpoint_path, sp_model_path=str(args.sp_model_path)).eval() - data_module = get_data_module(str(args.librispeech_path), str(args.global_stats_path), str(args.sp_model_path)) - - if args.use_cuda: - model = model.to(device="cuda") - - total_edit_distance = 0 - total_length = 0 - dataloader = data_module.test_dataloader() - with torch.no_grad(): - for idx, (batch, sample) in enumerate(dataloader): - actual = sample[0][2] - predicted = model(batch) - total_edit_distance += compute_word_level_distance(actual, predicted) - total_length += len(actual.split()) - if idx % 100 == 0: - logger.warning(f"Processed elem {idx}; WER: {total_edit_distance / total_length}") - 
logger.warning(f"Final WER: {total_edit_distance / total_length}") - - -def cli_main(): - parser = ArgumentParser() - parser.add_argument( - "--checkpoint-path", - type=pathlib.Path, - help="Path to checkpoint to use for evaluation.", - required=True, - ) - parser.add_argument( - "--global-stats-path", - default=pathlib.Path("global_stats.json"), - type=pathlib.Path, - help="Path to JSON file containing feature means and stddevs.", - ) - parser.add_argument( - "--librispeech-path", - type=pathlib.Path, - help="Path to LibriSpeech datasets.", - required=True, - ) - parser.add_argument( - "--sp-model-path", - type=pathlib.Path, - help="Path to SentencePiece model.", - required=True, - ) - parser.add_argument( - "--use-cuda", - action="store_true", - default=False, - help="Run using CUDA.", - ) - args = parser.parse_args() - run_eval(args) - - -if __name__ == "__main__": - cli_main() diff --git a/examples/asr/librispeech_conformer_rnnt/global_stats.json b/examples/asr/librispeech_conformer_rnnt/global_stats.json deleted file mode 100644 index c182ce7c65..0000000000 --- a/examples/asr/librispeech_conformer_rnnt/global_stats.json +++ /dev/null @@ -1,166 +0,0 @@ -{ - "mean": [ - 15.058613777160645, - 16.34557342529297, - 16.34653663635254, - 16.240671157836914, - 17.45355224609375, - 17.445302963256836, - 17.52323341369629, - 18.076807022094727, - 17.699262619018555, - 17.706790924072266, - 17.24724578857422, - 17.153791427612305, - 17.213361740112305, - 17.347240447998047, - 17.331117630004883, - 17.21516227722168, - 17.030071258544922, - 16.818960189819336, - 16.573062896728516, - 16.29717254638672, - 16.00996971130371, - 15.794167518615723, - 15.616395950317383, - 15.459056854248047, - 15.306838989257812, - 15.199165344238281, - 15.208144187927246, - 14.883454322814941, - 14.787869453430176, - 14.947835922241211, - 14.5912504196167, - 14.76955509185791, - 14.617781639099121, - 14.840407371520996, - 14.83073616027832, - 14.909119606018066, - 14.89070987701416, - 14.918207168579102, - 14.939517974853516, - 14.913643836975098, - 14.863334655761719, - 14.803299903869629, - 14.751264572143555, - 14.688116073608398, - 14.63498306274414, - 14.615056037902832, - 14.680213928222656, - 14.616259574890137, - 14.707776069641113, - 14.630264282226562, - 14.644737243652344, - 14.547430038452148, - 14.529033660888672, - 14.49357795715332, - 14.411538124084473, - 14.33312702178955, - 14.260393142700195, - 14.204919815063477, - 14.130182266235352, - 14.06987476348877, - 14.010197639465332, - 13.938552856445312, - 13.750232696533203, - 13.607213973999023, - 13.457777976989746, - 13.31512451171875, - 13.167718887329102, - 13.019341468811035, - 12.8869047164917, - 12.795098304748535, - 12.685126304626465, - 12.620392799377441, - 12.58949089050293, - 12.537697792053223, - 12.496938705444336, - 12.410022735595703, - 12.346826553344727, - 12.221966743469238, - 12.122841835021973, - 12.005624771118164 - ], - "invstddev": [ - 0.25952333211898804, - 0.2590482831001282, - 0.24866817891597748, - 0.24776232242584229, - 0.22200720012187958, - 0.21363843977451324, - 0.20652402937412262, - 0.19909949600696564, - 0.2021811604499817, - 0.20355898141860962, - 0.20546883344650269, - 0.2061648815870285, - 0.20569036900997162, - 0.20412985980510712, - 0.20357738435268402, - 0.2041499763727188, - 0.2055872678756714, - 0.20807604491710663, - 0.21054454147815704, - 0.21341396868228912, - 0.21418628096580505, - 0.22065168619155884, - 0.2248840034008026, - 0.22723940014839172, - 0.230172261595726, - 0.23371541500091553, - 
0.23734734952449799, - 0.23960146307945251, - 0.24088498950004578, - 0.241532102227211, - 0.24218633770942688, - 0.24371792376041412, - 0.2447739839553833, - 0.25564682483673096, - 0.2632736265659332, - 0.2549223005771637, - 0.24608071148395538, - 0.2464841604232788, - 0.2470586597919464, - 0.24785254895687103, - 0.24904784560203552, - 0.2503036856651306, - 0.25226327776908875, - 0.2532329559326172, - 0.2527913451194763, - 0.2518651783466339, - 0.2504975199699402, - 0.24836081266403198, - 0.24765831232070923, - 0.24767662584781647, - 0.24965286254882812, - 0.2501370906829834, - 0.2508895993232727, - 0.2512582540512085, - 0.25150999426841736, - 0.2525503635406494, - 0.25313329696655273, - 0.2534785270690918, - 0.25330957770347595, - 0.25366073846817017, - 0.25502219796180725, - 0.2608155608177185, - 0.25662899017333984, - 0.2558451294898987, - 0.25671014189720154, - 0.2577403485774994, - 0.25914356112480164, - 0.2596718966960907, - 0.25953933596611023, - 0.2610883116722107, - 0.26132410764694214, - 0.26272818446159363, - 0.26397505402565, - 0.26440608501434326, - 0.26543495059013367, - 0.26753780245780945, - 0.26935192942619324, - 0.26732245087623596, - 0.26666897535324097, - 0.2663257420063019 - ] -} diff --git a/examples/asr/librispeech_conformer_rnnt/lightning.py b/examples/asr/librispeech_conformer_rnnt/lightning.py deleted file mode 100644 index 2f8c09beb5..0000000000 --- a/examples/asr/librispeech_conformer_rnnt/lightning.py +++ /dev/null @@ -1,183 +0,0 @@ -import logging -import math -from typing import List, Tuple - -import sentencepiece as spm -import torch -import torchaudio -from data_module import LibriSpeechDataModule -from pytorch_lightning import LightningModule -from torchaudio.models import Hypothesis, RNNTBeamSearch -from torchaudio.prototype.models import conformer_rnnt_base -from transforms import Batch, TestTransform, TrainTransform, ValTransform - -logger = logging.getLogger() - -_expected_spm_vocab_size = 1023 - - -class WarmupLR(torch.optim.lr_scheduler._LRScheduler): - r"""Learning rate scheduler that performs linear warmup and exponential annealing. - - Args: - optimizer (torch.optim.Optimizer): optimizer to use. - warmup_steps (int): number of scheduler steps for which to warm up learning rate. - force_anneal_step (int): scheduler step at which annealing of learning rate begins. - anneal_factor (float): factor to scale base learning rate by at each annealing step. - last_epoch (int, optional): The index of last epoch. (Default: -1) - verbose (bool, optional): If ``True``, prints a message to stdout for - each update. 
(Default: ``False``) - """ - - def __init__( - self, - optimizer: torch.optim.Optimizer, - warmup_steps: int, - force_anneal_step: int, - anneal_factor: float, - last_epoch=-1, - verbose=False, - ): - self.warmup_steps = warmup_steps - self.force_anneal_step = force_anneal_step - self.anneal_factor = anneal_factor - super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose) - - def get_lr(self): - if self._step_count < self.force_anneal_step: - return [(min(1.0, self._step_count / self.warmup_steps)) * base_lr for base_lr in self.base_lrs] - else: - scaling_factor = self.anneal_factor ** (self._step_count - self.force_anneal_step) - return [scaling_factor * base_lr for base_lr in self.base_lrs] - - -def post_process_hypos( - hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor -) -> List[Tuple[str, float, List[int], List[int]]]: - tokens_idx = 0 - score_idx = 3 - post_process_remove_list = [ - sp_model.unk_id(), - sp_model.eos_id(), - sp_model.pad_id(), - ] - filtered_hypo_tokens = [ - [token_index for token_index in h[tokens_idx][1:] if token_index not in post_process_remove_list] for h in hypos - ] - hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens] - hypos_ids = [h[tokens_idx][1:] for h in hypos] - hypos_score = [[math.exp(h[score_idx])] for h in hypos] - - nbest_batch = list(zip(hypos_str, hypos_score, hypos_ids)) - - return nbest_batch - - -class ConformerRNNTModule(LightningModule): - def __init__(self, sp_model_path): - super().__init__() - - self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path) - spm_vocab_size = self.sp_model.get_piece_size() - assert spm_vocab_size == _expected_spm_vocab_size, ( - "The model returned by conformer_rnnt_base expects a SentencePiece model of " - f"vocabulary size {_expected_spm_vocab_size}, but the given SentencePiece model has a vocabulary size " - f"of {spm_vocab_size}. Please provide a correctly configured SentencePiece model." - ) - self.blank_idx = spm_vocab_size - - # ``conformer_rnnt_base`` hardcodes a specific Conformer RNN-T configuration. - # For greater customizability, please refer to ``conformer_rnnt_model``. - self.model = conformer_rnnt_base() - self.loss = torchaudio.transforms.RNNTLoss(reduction="sum") - self.optimizer = torch.optim.Adam(self.model.parameters(), lr=8e-4, betas=(0.9, 0.98), eps=1e-9) - self.warmup_lr_scheduler = WarmupLR(self.optimizer, 40, 120, 0.96) - - self.automatic_optimization = False - - def _step(self, batch, _, step_type): - if batch is None: - return None - - prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1]) - prepended_targets[:, 1:] = batch.targets - prepended_targets[:, 0] = self.blank_idx - prepended_target_lengths = batch.target_lengths + 1 - output, src_lengths, _, _ = self.model( - batch.features, - batch.feature_lengths, - prepended_targets, - prepended_target_lengths, - ) - loss = self.loss(output, batch.targets, src_lengths, batch.target_lengths) - self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True) - - return loss - - def configure_optimizers(self): - return ( - [self.optimizer], - [{"scheduler": self.warmup_lr_scheduler, "interval": "epoch"}], - ) - - def forward(self, batch: Batch): - decoder = RNNTBeamSearch(self.model, self.blank_idx) - hypotheses = decoder(batch.features.to(self.device), batch.feature_lengths.to(self.device), 20) - return post_process_hypos(hypotheses, self.sp_model)[0][0] - - def training_step(self, batch: Batch, batch_idx): - """Custom training step. 
- - By default, DDP does the following on each train step: - - For each GPU, compute loss and gradient on shard of training data. - - Sync and average gradients across all GPUs. The final gradient - is (sum of gradients across all GPUs) / N, where N is the world - size (total number of GPUs). - - Update parameters on each GPU. - - Here, we do the following: - - For k-th GPU, compute loss and scale it by (N / B_total), where B_total is - the sum of batch sizes across all GPUs. Compute gradient from scaled loss. - - Sync and average gradients across all GPUs. The final gradient - is (sum of gradients across all GPUs) / B_total. - - Update parameters on each GPU. - - Doing so allows us to account for the variability in batch sizes that - variable-length sequential data commonly yields. - """ - - opt = self.optimizers() - opt.zero_grad() - loss = self._step(batch, batch_idx, "train") - batch_size = batch.features.size(0) - batch_sizes = self.all_gather(batch_size) - self.log("Gathered batch size", batch_sizes.sum(), on_step=True, on_epoch=True) - loss *= batch_sizes.size(0) / batch_sizes.sum() # world size / batch size - self.manual_backward(loss) - torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10.0) - opt.step() - - # step every epoch - sch = self.lr_schedulers() - if self.trainer.is_last_batch: - sch.step() - - return loss - - def validation_step(self, batch, batch_idx): - return self._step(batch, batch_idx, "val") - - def test_step(self, batch, batch_idx): - return self._step(batch, batch_idx, "test") - - -def get_data_module(librispeech_path, global_stats_path, sp_model_path): - train_transform = TrainTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path) - val_transform = ValTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path) - test_transform = TestTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path) - return LibriSpeechDataModule( - librispeech_path=librispeech_path, - train_transform=train_transform, - val_transform=val_transform, - test_transform=test_transform, - ) diff --git a/examples/asr/librispeech_conformer_rnnt/train.py b/examples/asr/librispeech_conformer_rnnt/train.py deleted file mode 100644 index 0d86b1db96..0000000000 --- a/examples/asr/librispeech_conformer_rnnt/train.py +++ /dev/null @@ -1,106 +0,0 @@ -import pathlib -from argparse import ArgumentParser - -from lightning import ConformerRNNTModule, get_data_module -from pytorch_lightning import seed_everything, Trainer -from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint -from pytorch_lightning.plugins import DDPPlugin - - -def run_train(args): - seed_everything(1) - checkpoint_dir = args.exp_dir / "checkpoints" - checkpoint = ModelCheckpoint( - checkpoint_dir, - monitor="Losses/val_loss", - mode="min", - save_top_k=5, - save_weights_only=False, - verbose=True, - ) - train_checkpoint = ModelCheckpoint( - checkpoint_dir, - monitor="Losses/train_loss", - mode="min", - save_top_k=5, - save_weights_only=False, - verbose=True, - ) - lr_monitor = LearningRateMonitor(logging_interval="step") - callbacks = [ - checkpoint, - train_checkpoint, - lr_monitor, - ] - trainer = Trainer( - default_root_dir=args.exp_dir, - max_epochs=args.epochs, - num_nodes=args.nodes, - gpus=args.gpus, - accelerator="gpu", - strategy=DDPPlugin(find_unused_parameters=False), - callbacks=callbacks, - reload_dataloaders_every_n_epochs=1, - ) - - model = ConformerRNNTModule(str(args.sp_model_path)) - data_module = 
get_data_module(str(args.librispeech_path), str(args.global_stats_path), str(args.sp_model_path))
-    trainer.fit(model, data_module, ckpt_path=args.checkpoint_path)
-
-
-def cli_main():
-    parser = ArgumentParser()
-    parser.add_argument(
-        "--checkpoint-path",
-        default=None,
-        type=pathlib.Path,
-        help="Path to checkpoint to resume training from.",
-    )
-    parser.add_argument(
-        "--exp-dir",
-        default=pathlib.Path("./exp"),
-        type=pathlib.Path,
-        help="Directory to save checkpoints and logs to. (Default: './exp')",
-    )
-    parser.add_argument(
-        "--global-stats-path",
-        default=pathlib.Path("global_stats.json"),
-        type=pathlib.Path,
-        help="Path to JSON file containing feature means and stddevs.",
-    )
-    parser.add_argument(
-        "--librispeech-path",
-        type=pathlib.Path,
-        help="Path to LibriSpeech datasets.",
-        required=True,
-    )
-    parser.add_argument(
-        "--sp-model-path",
-        type=pathlib.Path,
-        help="Path to SentencePiece model.",
-        required=True,
-    )
-    parser.add_argument(
-        "--nodes",
-        default=4,
-        type=int,
-        help="Number of nodes to use for training. (Default: 4)",
-    )
-    parser.add_argument(
-        "--gpus",
-        default=8,
-        type=int,
-        help="Number of GPUs per node to use for training. (Default: 8)",
-    )
-    parser.add_argument(
-        "--epochs",
-        default=120,
-        type=int,
-        help="Number of epochs to train for. (Default: 120)",
-    )
-    args = parser.parse_args()
-    run_train(args)
-
-
-if __name__ == "__main__":
-    cli_main()
diff --git a/examples/asr/librispeech_conformer_rnnt/train_spm.py b/examples/asr/librispeech_conformer_rnnt/train_spm.py
deleted file mode 100644
index 75dba161c4..0000000000
--- a/examples/asr/librispeech_conformer_rnnt/train_spm.py
+++ /dev/null
@@ -1,80 +0,0 @@
-#!/usr/bin/env python3
-"""Trains a SentencePiece model on transcripts across LibriSpeech train-clean-100, train-clean-360, and train-other-500.
-
-Example:
-python train_spm.py --librispeech-path ./datasets
-"""
-
-import io
-import pathlib
-from argparse import ArgumentParser, RawTextHelpFormatter
-
-import sentencepiece as spm
-
-
-def get_transcript_text(transcript_path):
-    with open(transcript_path) as f:
-        return [line.strip().split(" ", 1)[1].lower() for line in f]
-
-
-def get_transcripts(dataset_path):
-    transcript_paths = dataset_path.glob("*/*/*.trans.txt")
-    merged_transcripts = []
-    for path in transcript_paths:
-        merged_transcripts += get_transcript_text(path)
-    return merged_transcripts
-
-
-def train_spm(input):
-    model_writer = io.BytesIO()
-    spm.SentencePieceTrainer.train(
-        sentence_iterator=iter(input),
-        model_writer=model_writer,
-        vocab_size=1023,
-        model_type="unigram",
-        input_sentence_size=-1,
-        character_coverage=1.0,
-        bos_id=0,
-        pad_id=1,
-        eos_id=2,
-        unk_id=3,
-    )
-    return model_writer.getvalue()
-
-
-def parse_args():
-    default_output_path = "./spm_unigram_1023.model"
-    parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
-    parser.add_argument(
-        "--librispeech-path",
-        required=True,
-        type=pathlib.Path,
-        help="Path to LibriSpeech dataset.",
-    )
-    parser.add_argument(
-        "--output-file",
-        default=pathlib.Path(default_output_path),
-        type=pathlib.Path,
-        help=f"File to save model to. 
(Default: '{default_output_path}')", - ) - return parser.parse_args() - - -def run_cli(): - args = parse_args() - - root = args.librispeech_path / "LibriSpeech" - splits = ["train-clean-100", "train-clean-360", "train-other-500"] - merged_transcripts = [] - for split in splits: - path = pathlib.Path(root) / split - merged_transcripts += get_transcripts(path) - - model = train_spm(merged_transcripts) - - with open(args.output_file, "wb") as f: - f.write(model) - - -if __name__ == "__main__": - run_cli() diff --git a/examples/asr/librispeech_conformer_rnnt/transforms.py b/examples/asr/librispeech_conformer_rnnt/transforms.py deleted file mode 100644 index 80531ea373..0000000000 --- a/examples/asr/librispeech_conformer_rnnt/transforms.py +++ /dev/null @@ -1,109 +0,0 @@ -import json -import math -from collections import namedtuple -from functools import partial -from typing import List - -import sentencepiece as spm -import torch -import torchaudio - - -Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"]) - - -_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max) -_gain = pow(10, 0.05 * _decibel) - -_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160) - - -def _piecewise_linear_log(x): - x = x * _gain - x[x > math.e] = torch.log(x[x > math.e]) - x[x <= math.e] = x[x <= math.e] / math.e - return x - - -class FunctionalModule(torch.nn.Module): - def __init__(self, functional): - super().__init__() - self.functional = functional - - def forward(self, input): - return self.functional(input) - - -class GlobalStatsNormalization(torch.nn.Module): - def __init__(self, global_stats_path): - super().__init__() - - with open(global_stats_path) as f: - blob = json.loads(f.read()) - - self.mean = torch.tensor(blob["mean"]) - self.invstddev = torch.tensor(blob["invstddev"]) - - def forward(self, input): - return (input - self.mean) * self.invstddev - - -def _extract_labels(sp_model, samples: List): - targets = [sp_model.encode(sample[2].lower()) for sample in samples] - lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32) - targets = torch.nn.utils.rnn.pad_sequence( - [torch.tensor(elem) for elem in targets], - batch_first=True, - padding_value=1.0, - ).to(dtype=torch.int32) - return targets, lengths - - -def _extract_features(data_pipeline, samples: List): - mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples] - features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True) - features = data_pipeline(features) - lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32) - return features, lengths - - -class TrainTransform: - def __init__(self, global_stats_path: str, sp_model_path: str): - self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path) - self.train_data_pipeline = torch.nn.Sequential( - FunctionalModule(_piecewise_linear_log), - GlobalStatsNormalization(global_stats_path), - FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)), - torchaudio.transforms.FrequencyMasking(27), - torchaudio.transforms.FrequencyMasking(27), - torchaudio.transforms.TimeMasking(100, p=0.2), - torchaudio.transforms.TimeMasking(100, p=0.2), - FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)), - ) - - def __call__(self, samples: List): - features, feature_lengths = _extract_features(self.train_data_pipeline, samples) - targets, target_lengths = _extract_labels(self.sp_model, samples) 
- return Batch(features, feature_lengths, targets, target_lengths) - - -class ValTransform: - def __init__(self, global_stats_path: str, sp_model_path: str): - self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path) - self.valid_data_pipeline = torch.nn.Sequential( - FunctionalModule(_piecewise_linear_log), - GlobalStatsNormalization(global_stats_path), - ) - - def __call__(self, samples: List): - features, feature_lengths = _extract_features(self.valid_data_pipeline, samples) - targets, target_lengths = _extract_labels(self.sp_model, samples) - return Batch(features, feature_lengths, targets, target_lengths) - - -class TestTransform: - def __init__(self, global_stats_path: str, sp_model_path: str): - self.val_transforms = ValTransform(global_stats_path, sp_model_path) - - def __call__(self, sample): - return self.val_transforms([sample]), [sample]
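
A closing note on `_piecewise_linear_log` in the removed `transforms.py`: it first multiplies by a fixed gain referenced to int16 full scale (the `_decibel`/`_gain` constants), then applies the natural log above `e` and a linear ramp below it. The two pieces agree at `x = e` in both value (1) and slope (1/e), so the compression is continuous while avoiding `log` diverging near zero. A self-contained numeric check, with arbitrarily chosen input values:

```python
import math

import torch

# Same constants as the removed transforms.py: gain referenced to int16 full scale.
_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel)


def piecewise_linear_log(x: torch.Tensor) -> torch.Tensor:
    # Linear below e, natural log above; both pieces equal 1 at x = e.
    # Note: modifies x in place, as the removed code did.
    x = x * _gain
    x[x > math.e] = torch.log(x[x > math.e])
    x[x <= math.e] = x[x <= math.e] / math.e
    return x


print(piecewise_linear_log(torch.tensor([1e-12, 1e-10, 1e-6])))
```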
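Likewise, the removed `WarmupLR` scheduler in `lightning.py` warms the learning rate up linearly over `warmup_steps` scheduler steps, holds the base rate until `force_anneal_step`, then decays it by `anneal_factor` per step. A functional restatement of its `get_lr()` for one parameter group, run with the removed recipe's settings (`WarmupLR(optimizer, 40, 120, 0.96)`, base LR 8e-4, stepped once per epoch):

```python
def warmup_lr(step: int, base_lr: float, warmup_steps: int, force_anneal_step: int, anneal_factor: float) -> float:
    """Mirrors the removed WarmupLR.get_lr() for a single parameter group."""
    if step < force_anneal_step:
        # Linear warmup, capped at the base learning rate.
        return min(1.0, step / warmup_steps) * base_lr
    # Exponential annealing once the forced-anneal step is reached.
    return (anneal_factor ** (step - force_anneal_step)) * base_lr


# Warm up over 40 epochs, hold, then decay by 0.96 per epoch from epoch 120.
for epoch in (1, 20, 40, 119, 120, 160):
    print(epoch, f"{warmup_lr(epoch, 8e-4, 40, 120, 0.96):.2e}")
```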
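Finally, on the custom `training_step` above: each rank scales its sum-reduced loss by N / B_total because DDP averages gradients over the N ranks, so the net effect is dividing the summed gradient by B_total, the per-sample average that plain averaging would distort when per-rank batch sizes differ. A small arithmetic sketch with hypothetical per-rank batch sizes:

```python
# Hypothetical per-rank batch sizes, as variable-length audio bucketing yields.
batch_sizes = [3, 5, 2, 6]
N = len(batch_sizes)        # world size
B_total = sum(batch_sizes)  # 16

# Each rank scales its sum-reduced loss (hence its gradient g_k) by N / B_total.
# DDP then averages over ranks:
#   (1 / N) * sum_k(g_k * N / B_total) = sum_k(g_k) / B_total
scale = N / B_total
print(scale)  # 0.25
```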