ASR evaluator (NVIDIA#5728)
* backbone

Signed-off-by: fayejf <fayejf07@gmail.com>

* engineer and analyzer

Signed-off-by: fayejf <fayejf07@gmail.com>

* offline_by_chunked

Signed-off-by: fayejf <fayejf07@gmail.com>

* test_ds wip

Signed-off-by: fayejf <fayejf07@gmail.com>

* temp remove inference

Signed-off-by: fayejf <fayejf07@gmail.com>

* mandarin yaml

Signed-off-by: fayejf <fayejf07@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* augmentor and a few updates

Signed-off-by: fayejf <fayejf07@gmail.com>

* address alerts and revert unnecessary changes

Signed-off-by: fayejf <fayejf07@gmail.com>

* Add readme

Signed-off-by: fayejf <fayejf07@gmail.com>

* rename

Signed-off-by: fayejf <fayejf07@gmail.com>

* typo fix

Signed-off-by: fayejf <fayejf07@gmail.com>

* small fix

Signed-off-by: fayejf <fayejf07@gmail.com>

* add missing header

Signed-off-by: fayejf <fayejf07@gmail.com>

* rename augmentor_config to augmentor

Signed-off-by: fayejf <fayejf07@gmail.com>

* rename inference_mode to inference

Signed-off-by: fayejf <fayejf07@gmail.com>

* move utils.py

Signed-off-by: fayejf <fayejf07@gmail.com>

* update temp file

Signed-off-by: fayejf <fayejf07@gmail.com>

* make wer cer clear

Signed-off-by: fayejf <fayejf07@gmail.com>

* seed_everything

Signed-off-by: fayejf <fayejf07@gmail.com>

* fix missing rn augmentor_config in rnnt

Signed-off-by: fayejf <fayejf07@gmail.com>

* fix rnnt transcribe

Signed-off-by: fayejf <fayejf07@gmail.com>

* add more docstring and style fix

Signed-off-by: fayejf <fayejf07@gmail.com>

* address codeQL

Signed-off-by: fayejf <fayejf07@gmail.com>

* reflect comments

Signed-off-by: fayejf <fayejf07@gmail.com>

* update readme

Signed-off-by: fayejf <fayejf07@gmail.com>

* clearer

Signed-off-by: fayejf <fayejf07@gmail.com>

Signed-off-by: fayejf <fayejf07@gmail.com>
Signed-off-by: fayejf <36722593+fayejf@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com>
3 people authored and titu1994 committed Mar 24, 2023
1 parent 9879c1f commit 600dd46
Showing 12 changed files with 685 additions and 3 deletions.
@@ -41,6 +41,7 @@
from dataclasses import dataclass, is_dataclass
from typing import Optional

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

@@ -71,6 +72,7 @@ class TranscriptionConfig:
num_workers: int = 0
append_pred: bool = False # Sets mode of work, if True it will add new field transcriptions.
pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than standard one.
random_seed: Optional[int] = None # Seed to be used in seed_everything()

# Chunked configs
chunk_len_in_secs: float = 1.6 # Chunk length in seconds
@@ -96,6 +98,9 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg)

if cfg.random_seed:
pl.seed_everything(cfg.random_seed)

if cfg.model_path is None and cfg.pretrained_name is None:
raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
if cfg.audio_dir is None and cfg.dataset_manifest is None:
@@ -62,6 +62,7 @@
from dataclasses import dataclass, is_dataclass
from typing import Optional

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf, open_dict

@@ -95,6 +96,7 @@ class TranscriptionConfig:
num_workers: int = 0
append_pred: bool = False # Sets mode of work, if True it will add new field transcriptions.
pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than standard one.
random_seed: Optional[int] = None # Seed to be used in seed_everything()

# Chunked configs
chunk_len_in_secs: float = 1.6 # Chunk length in seconds
@@ -127,6 +129,9 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg)

if cfg.random_seed:
pl.seed_everything(cfg.random_seed)

if cfg.model_path is None and cfg.pretrained_name is None:
raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
if cfg.audio_dir is None and cfg.dataset_manifest is None:
15 changes: 15 additions & 0 deletions examples/asr/transcribe_speech.py
@@ -108,13 +108,15 @@ class TranscriptionConfig:
dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest
channel_selector: Optional[int] = None # Used to select a single channel from multi-channel files
audio_key: str = 'audio_filepath' # Used to override the default audio key in dataset_manifest
eval_config_yaml: Optional[str] = None # Path to a YAML file with the evaluation config

# General configs
output_filename: Optional[str] = None
batch_size: int = 32
num_workers: int = 0
append_pred: bool = False # Sets mode of work, if True it will add new field transcriptions.
pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than standard one.
random_seed: Optional[int] = None # Seed to be used in seed_everything()

# Set to True to output greedy timestamp information (only supported models)
compute_timestamps: bool = False
@@ -152,11 +154,21 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg)

if cfg.random_seed:
pl.seed_everything(cfg.random_seed)

if cfg.model_path is None and cfg.pretrained_name is None:
raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
if cfg.audio_dir is None and cfg.dataset_manifest is None:
raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")

# Load augmentor from an external yaml file which contains eval info; could be extended to other features such as VAD, P&C
augmentor = None
if cfg.eval_config_yaml:
eval_config = OmegaConf.load(cfg.eval_config_yaml)
augmentor = eval_config.test_ds.get("augmentor")
logging.info(f"Will apply on-the-fly augmentation on samples during transcription: {augmentor} ")

# setup GPU
if cfg.cuda is None:
if torch.cuda.is_available():
@@ -253,6 +265,7 @@ def autocast():
num_workers=cfg.num_workers,
return_hypotheses=return_hypotheses,
channel_selector=cfg.channel_selector,
augmentor=augmentor,
)
else:
logging.warning(
@@ -264,6 +277,7 @@ def autocast():
num_workers=cfg.num_workers,
return_hypotheses=return_hypotheses,
channel_selector=cfg.channel_selector,
augmentor=augmentor,
)
else:
transcriptions = asr_model.transcribe(
@@ -272,6 +286,7 @@ def autocast():
num_workers=cfg.num_workers,
return_hypotheses=return_hypotheses,
channel_selector=cfg.channel_selector,
augmentor=augmentor,
)

logging.info(f"Finished transcribing {len(filepaths)} files !")
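For reference, the `test_ds.augmentor` section that `eval_config_yaml` is expected to provide could look like the sketch below. Only `noise.manifest_path` appears elsewhere in this commit; the remaining keys (`prob`, `min_snr_db`, `max_snr_db`) are assumptions based on NeMo's audio augmentation configs:

```
test_ds:
  augmentor:
    noise:
      prob: 1.0
      manifest_path: <manifest file for noise data>
      min_snr_db: 0
      max_snr_db: 15
```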
3 changes: 3 additions & 0 deletions nemo/collections/asr/models/ctc_bpe_models.py
@@ -227,6 +227,9 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
'use_start_end_token': self.cfg.validation_ds.get('use_start_end_token', False),
}

if config.get("augmentor"):
dl_config['augmentor'] = config.get("augmentor")

temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
return temporary_datalayer

11 changes: 10 additions & 1 deletion nemo/collections/asr/models/ctc_models.py
@@ -116,8 +116,12 @@ def transcribe(
return_hypotheses: bool = False,
num_workers: int = 0,
channel_selector: Optional[ChannelSelectorType] = None,
augmentor: DictConfig = None,
) -> List[str]:
"""
If you modify this function, please remember to update transcribe_partial_audio() in
nemo/collections/asr/parts/utils/transcribe_utils.py
Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
Args:
@@ -131,7 +135,7 @@
With hypotheses can do some postprocessing like getting timestamp or rescoring
num_workers: (int) number of workers for DataLoader
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`.
augmentor (DictConfig): If provided, audio samples are augmented on the fly during transcription.
Returns:
A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files
"""
@@ -182,6 +186,9 @@
'channel_selector': channel_selector,
}

if augmentor:
config['augmentor'] = augmentor

temporary_datalayer = self._setup_transcribe_dataloader(config)
for test_batch in tqdm(temporary_datalayer, desc="Transcribing"):
logits, logits_len, greedy_predictions = self.forward(
@@ -724,6 +731,8 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
'pin_memory': True,
'channel_selector': config.get('channel_selector', None),
}
if config.get("augmentor"):
dl_config['augmentor'] = config.get("augmentor")

temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
return temporary_datalayer
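A minimal usage sketch of the new `augmentor` argument to `transcribe()`; the checkpoint name and augmentor contents are illustrative assumptions:

```
from omegaconf import OmegaConf

import nemo.collections.asr as nemo_asr

# Any CTC checkpoint works here; the name below is only an example
asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained("stt_en_conformer_ctc_large")

# On-the-fly noise augmentation applied to samples during transcription
augmentor = OmegaConf.create({"noise": {"prob": 1.0, "manifest_path": "noise_manifest.json"}})

transcriptions = asr_model.transcribe(
    paths2audio_files=["sample1.wav", "sample2.wav"],
    batch_size=2,
    augmentor=augmentor,
)
```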
3 changes: 3 additions & 0 deletions nemo/collections/asr/models/rnnt_bpe_models.py
@@ -579,5 +579,8 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
'use_start_end_token': self.cfg.validation_ds.get('use_start_end_token', False),
}

if config.get("augmentor"):
dl_config['augmentor'] = config.get("augmentor")

temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
return temporary_datalayer
9 changes: 8 additions & 1 deletion nemo/collections/asr/models/rnnt_models.py
@@ -217,6 +217,7 @@ def transcribe(
partial_hypothesis: Optional[List['Hypothesis']] = None,
num_workers: int = 0,
channel_selector: Optional[ChannelSelectorType] = None,
augmentor: DictConfig = None,
) -> Tuple[List[str], Optional[List['Hypothesis']]]:
"""
Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
@@ -232,7 +233,7 @@
With hypotheses can do some postprocessing like getting timestamp or rescoring
num_workers: (int) number of workers for DataLoader
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
augmentor (DictConfig): If provided, audio samples are augmented on the fly during transcription.
Returns:
A list of transcriptions in the same order as paths2audio_files, along with an optional list of hypotheses (see the Tuple return type).
"""
@@ -277,6 +278,9 @@
'channel_selector': channel_selector,
}

if augmentor:
config['augmentor'] = augmentor

temporary_datalayer = self._setup_transcribe_dataloader(config)
for test_batch in tqdm(temporary_datalayer, desc="Transcribing"):
encoded, encoded_len = self.forward(
@@ -938,6 +942,9 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
'pin_memory': True,
}

if config.get("augmentor"):
dl_config['augmentor'] = config.get("augmentor")

temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
return temporary_datalayer

7 changes: 6 additions & 1 deletion nemo/collections/asr/parts/utils/transcribe_utils.py
@@ -334,13 +334,16 @@ def write_transcription(

def transcribe_partial_audio(
asr_model,
- path2manifest: str,
+ path2manifest: str = None,
batch_size: int = 4,
logprobs: bool = False,
return_hypotheses: bool = False,
num_workers: int = 0,
channel_selector: Optional[int] = None,
augmentor: DictConfig = None,
) -> List[str]:
"""
See the description of this function in transcribe() in nemo/collections/asr/models/ctc_models.py """

assert isinstance(asr_model, EncDecCTCModel), "Currently support CTC model only."

@@ -377,6 +380,8 @@ def transcribe_partial_audio(
'num_workers': num_workers,
'channel_selector': channel_selector,
}
if augmentor:
config['augmentor'] = augmentor

temporary_datalayer = asr_model._setup_transcribe_dataloader(config)
for test_batch in tqdm(temporary_datalayer, desc="Transcribing"):
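`transcribe_partial_audio()` accepts the same optional `augmentor`; a sketch under the same assumptions as above (manifest paths are placeholders):

```
from omegaconf import OmegaConf

from nemo.collections.asr.models import EncDecCTCModel
from nemo.collections.asr.parts.utils.transcribe_utils import transcribe_partial_audio

asr_model = EncDecCTCModel.from_pretrained("stt_en_conformer_ctc_large")  # CTC models only
transcriptions = transcribe_partial_audio(
    asr_model,
    path2manifest="eval_manifest.json",  # placeholder manifest path
    batch_size=4,
    augmentor=OmegaConf.create({"noise": {"prob": 1.0, "manifest_path": "noise_manifest.json"}}),
)
```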
44 changes: 44 additions & 0 deletions tools/asr_evaluator/README.md
@@ -0,0 +1,44 @@
ASR evaluator
--------------------

A tool for thoroughly evaluating the performance of ASR models, as well as additional features such as Voice Activity Detection.

Features:
- A simple way to evaluate a model in all three modes currently supported by NeMo: offline, chunked, and offline_by_chunked.
- On-the-fly data augmentation (silence, noise, etc.) for ASR robustness evaluation.
- Detailed insertion, deletion, and substitution error rates for individual samples as well as the whole dataset.
- Evaluation of model reliability across target groups such as gender and audio length, when metadata is present.


The ASR evaluator contains two main parts:
- **ENGINE**: conducts ASR inference.
- **ANALYST**: evaluates model performance based on the predictions.

In the Analyst, we can evaluate on metadata (such as duration, emotion, etc.) if it is present in the manifest. For example, with the following config we can calculate WERs for audios in different duration groups, where each group (in seconds) is defined by [[0,2],[2,5],[5,10],[10,20],[20,100000]]. We can also calculate WERs for three groups of emotions, where each group is defined by [['happy','laugh'],['neutral'],['sad']]. Moreover, if we set save_wer_per_class=True, it will calculate WERs for all classes present in the data (i.e., the classes listed above plus 'cry', which is present in the data but not in the slots).

```
analyst:
metadata:
duration:
enable: True
slot: [[0,2],[2,5],[5,10],[10,20],[20,100000]]
save_wer_per_class: False # whether to save WER for each class present in the data.
emotion:
enable: True
slot: [['happy','laugh'],['neutral'],['sad']] # the data may contain 'cry', which is not in the slots we focus on.
save_wer_per_class: False
```
Check `./conf/eval.yaml` for the supported configuration.
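
Based on the keys that `asr_evaluator.py` reads (`env.save_git_hash`, `engine`, `analyst.metric_calculator`, `analyst.metadata`, `writer.report_filename`), the overall config is organized roughly as in the sketch below; field names not referenced in this commit are assumptions:

```
env:
  save_git_hash: True

engine:
  pretrained_name: "stt_en_conformer_transducer_large"
  inference:
    mode: "offline"  # offline, chunked, or offline_by_chunked
  test_ds:
    augmentor:
      noise:
        manifest_path: <manifest file for noise data>

analyst:
  metric_calculator:
    output_filename: <path for predictions with per-sample errors>
  metadata:
    duration:
      enable: True
      slot: [[0,2],[2,5],[5,10],[10,20],[20,100000]]
      save_wer_per_class: False

writer:
  report_filename: "report.json"
```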

If you plan to evaluate or add new tasks such as Punctuation and Capitalization, add them to the engine.

Run
```
python asr_evaluator.py \
engine.pretrained_name="stt_en_conformer_transducer_large" \
engine.inference.mode="offline" \
engine.test_ds.augmentor.noise.manifest_path=<manifest file for noise data>
```
92 changes: 92 additions & 0 deletions tools/asr_evaluator/asr_evaluator.py
@@ -0,0 +1,92 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json

import git
from omegaconf import OmegaConf
from utils import cal_target_metadata_wer, cal_write_wer, run_asr_inference

from nemo.core.config import hydra_runner
from nemo.utils import logging


"""
This script serves as an evaluator of ASR models.
Usage:
python asr_evaluator.py \
engine.pretrained_name="stt_en_conformer_transducer_large" \
engine.inference.mode="offline" \
engine.test_ds.augmentor.noise.manifest_path=<manifest file for noise data> \
.....
Check out parameters in ./conf/eval.yaml
"""


@hydra_runner(config_path="conf", config_name="eval.yaml")
def main(cfg):
report = {}
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

# Store git hash for reproducibility
if cfg.env.save_git_hash:
repo = git.Repo(search_parent_directories=True)
report['git_hash'] = repo.head.object.hexsha

## Engine
# The next call could be skipped in order to reuse a previously generated manifest

# If you need to change more parameters for ASR inference, change them in
# 1) shell script in eval_utils.py in nemo/collections/asr/parts/utils or
# 2) TranscriptionConfig on top of the executed scripts such as transcribe_speech.py in examples/asr
cfg.engine = run_asr_inference(cfg=cfg.engine)

## Analyst
cfg, total_res, eval_metric = cal_write_wer(cfg)
report.update({"res": total_res})

for target in cfg.analyst.metadata:
if cfg.analyst.metadata[target].enable:
occ_avg_wer = cal_target_metadata_wer(
manifest=cfg.analyst.metric_calculator.output_filename,
target=target,
meta_cfg=cfg.analyst.metadata[target],
eval_metric=eval_metric,
)
report[target] = occ_avg_wer

config_engine = OmegaConf.to_object(cfg.engine)
report.update(config_engine)

config_metric_calculator = OmegaConf.to_object(cfg.analyst.metric_calculator)
report.update(config_metric_calculator)

pretty = json.dumps(report, indent=4)
res = "%.3f" % (report["res"][eval_metric] * 100)
logging.info(pretty)
logging.info(f"Overall {eval_metric} is {res} %")

## Writer
report_file = "report.json"
if "report_filename" in cfg.writer and cfg.writer.report_filename:
report_file = cfg.writer.report_filename

with open(report_file, "a") as fout:
json.dump(report, fout)
fout.write('\n')
fout.flush()


if __name__ == "__main__":
main()
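
Because the writer appends one JSON object per evaluation run, `report.json` accumulates a line-delimited history that can be read back with a sketch like:

```
import json

# Each evaluation run appends one JSON object (one line) to report.json
with open("report.json") as fin:
    reports = [json.loads(line) for line in fin if line.strip()]

for report in reports:
    print(report.get("git_hash"), report.get("res"))
```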
