In [2]:
import nemo.collections.asr as nemo_asr
import os
from tqdm.auto import tqdm
import numpy as np
import json
import pickle
from nemo.utils import logging
import torch
import contextlib
import nemo
import editdistance
import csv

[NeMo W 2022-04-04 09:39:20 optimizers:55] Apex was not found. Using the lamb or fused_adam optimizer will error out.
################################################################################
###          (please add 'export KALDI_ROOT=<your_path>' in your $HOME/.profile)
###          (or run as: KALDI_ROOT=<your_path> python <your_script>.py)
################################################################################



## Prepare Manifest

In [3]:
test_manifest = "/home/khoatlv/manifests/test_manifest_processed.json"
train_manifest = "/home/khoatlv/manifests/train_manifest_processed.json"
all_data_manifest = "/home/khoatlv/Conformer_ASR/scripts/evaluation/all_data_manifest.json"

# pickle data
probs_cache_file = "/home/khoatlv/Conformer_ASR/scripts/evaluation/eval_asr_model/probs_cache_file"
conformer_transcribe_log = "/home/khoatlv/Conformer_ASR/scripts/evaluation/eval_asr_model/conformer_log.json"

os.system(f"cat {test_manifest} {train_manifest} > {all_data_manifest}")
os.system(f"wc -l {all_data_manifest}")

use_amp = True
acoustic_batch_size = 16
beam_width = 200
alpha=2
beta=2.5

55188 /home/khoatlv/Conformer_ASR/scripts/evaluation/all_data_manifest.json


In [4]:
def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum(axis=-1).reshape([x.shape[0], 1])

## Evaluate Conformer Model

In [5]:
lm_path = "/home/khoatlv/Conformer_ASR/n_gram_lm/n_gram_lm_model/6-conformer-small-gram_trained.bin"
asr_model_path = "/home/khoatlv/Conformer_ASR/models/conformer/Conformer_small_epoch=98.nemo"

# Restore ASR Model
asr_model = nemo_asr.models.EncDecCTCModelBPE.restore_from(
    restore_path=asr_model_path,
    map_location='cuda'    
)

# Restore Beam Search N-LM
TOKEN_OFFSET = 100
vocab = asr_model.decoder.vocabulary
vocab = [chr(idx + TOKEN_OFFSET) for idx in range(len(vocab))]
ids_to_text_func = asr_model.tokenizer.ids_to_text

beam_search_lm = nemo_asr.modules.BeamSearchDecoderWithLM(
    vocab=list(vocab),
    beam_width=beam_width,
    alpha=alpha, 
    beta=beta,
    lm_path=lm_path,
    num_cpus=max(os.cpu_count(), 1),
    input_tensor=False
)

[NeMo I 2022-04-04 09:39:23 mixins:146] Tokenizer SentencePieceTokenizer initialized with 256 tokens


[NeMo W 2022-04-04 09:39:23 modelPT:148] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: /home/nhan/NovaIntechs/data/ASR_Data/manifests/train_manifest_processed.json
    sample_rate: 16000
    max_duration: 16.7
    min_duration: 0.1
    is_tarred: false
    tarred_audio_filepaths: null
    shuffle_n: 2048
    bucketing_strategy: synced_randomized
    bucketing_batch_size: null
    shuffle: true
    batch_size: 32
    pin_memory: true
    trim_silence: true
    use_start_end_token: true
    normalize_transcripts: false
    num_workers: 16
    
[NeMo W 2022-04-04 09:39:23 modelPT:155] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config : 
    manif

[NeMo I 2022-04-04 09:39:23 features:255] PADDING: 0
[NeMo I 2022-04-04 09:39:23 features:272] STFT using torch
[NeMo I 2022-04-04 09:39:29 save_restore_connector:157] Model EncDecCTCModelBPE was successfully restored from /home/khoatlv/Conformer_ASR/models/conformer/Conformer_small_epoch=98.nemo.


In [7]:
def eval_comformer():
    # Load manifest data and extract audio_path, target text
    target_transcripts = []
    with open(all_data_manifest, 'r') as manifest_file:
        audio_file_paths = []
        for line in tqdm(manifest_file, desc=f"Reading Manifest {all_data_manifest} ...", ncols=120):
            data = json.loads(line)
            target_transcripts.append(data['text'])
            audio_file_paths.append(data['audio_filepath'])
    
    # Load audio wav and transribe
    if probs_cache_file and os.path.exists(probs_cache_file):
        logging.info(f"Found a pickle file of probabilities at '{probs_cache_file}'.")
        logging.info(f"Loading the cached pickle file of probabilities from '{probs_cache_file}' ...")
        with open(probs_cache_file, 'rb') as probs_file:
            all_probs = pickle.load(probs_file)

        if len(all_probs) != len(audio_file_paths):
            raise ValueError(
                f"The number of samples in the probabilities file '{probs_cache_file}' does not "
                f"match the manifest file. You may need to delete the probabilities cached file."
            )
    else:
        if use_amp:
            if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
                logging.info("AMP is enabled!\n")
                autocast = torch.cuda.amp.autocast
        else:

            @contextlib.contextmanager
            def autocast():
                yield

        with autocast():
            with torch.no_grad():
                all_logits = asr_model.transcribe(audio_file_paths, batch_size=acoustic_batch_size, logprobs=True)
        all_probs = [softmax(logits) for logits in all_logits]
        if probs_cache_file:
            logging.info(f"Writing pickle files of probabilities at '{probs_cache_file}'...")
            with open(probs_cache_file, 'wb') as f_dump:
                pickle.dump(all_probs, f_dump)
                
    logging.info(f"==============================Starting the beam search decoding===============================")
    # logging.info(f"Grid search size: {len([]])}")
    logging.info(f"It may take some time...")
    logging.info(f"==============================================================================================")
    
    wer_dist_count = 0
    words_count = 0
    sample_idx = 0
    
    if conformer_transcribe_log:
        out_file = open(conformer_transcribe_log, 'w', encoding='UTF8', newline='')
        writer = csv.writer(out_file)
        headers = ["audio_filepath", "pred_text", "reference", "wer"]
        writer.writerow(headers)
    
    it = tqdm(
        range(int(np.ceil(len(all_probs) / acoustic_batch_size))),
        desc=f"Beam search decoding with width={beam_width}, alpha={alpha}, beta={beta}",
        ncols=120,
    )
    for batch_idx in it:
        # disabling type checking
        with nemo.core.typecheck.disable_checks():
            probs_batch = all_probs[batch_idx * acoustic_batch_size : (batch_idx + 1) * acoustic_batch_size]
            beams_batch = beam_search_lm.forward(log_probs=probs_batch, log_probs_length=None,)
        
        for beams_idx, beams in enumerate(beams_batch):
            target = target_transcripts[sample_idx + beams_idx]
            target_split_w = target.split()
            words_count += len(target_split_w)
            
            # For BPE encodings, need to shift by TOKEN_OFFSET to retrieve the original sub-word ids
            pred_text = ids_to_text_func([ord(c) - TOKEN_OFFSET for c in beams[0][1]])
            pred_split_w = pred_text.split()
            wer_dist = editdistance.eval(target_split_w, pred_split_w)
            wer_dist_count += wer_dist
            
            wer = round(float("{:.2}".format(wer_dist / len(target_split_w))), 2)
            print(f"target: {target}. pred_text: {pred_text}. wer: {wer}%")
            
            audio_path = audio_file_paths[sample_idx + beams_idx]
            if round(float(wer), 2) > 0.5: writer.writerow([audio_path, pred_text, target, wer])
        # break
        sample_idx += len(probs_batch)
    
    logging.info(
        'WER with beam search decoding and N-gram model = {:.2}'.format(wer_dist_count / words_count))
    
    if conformer_transcribe_log:
        out_file.close()
eval_comformer()

Reading Manifest /home/khoatlv/Conformer_ASR/scripts/evaluation/all_data_manifest.json ...: 0it [00:00, ?it/s]

[NeMo I 2022-04-04 09:40:15 987078218:13] Found a pickle file of probabilities at '/home/khoatlv/Conformer_ASR/scripts/evaluation/eval_asr_model/probs_cache_file'.
[NeMo I 2022-04-04 09:40:15 987078218:14] Loading the cached pickle file of probabilities from '/home/khoatlv/Conformer_ASR/scripts/evaluation/eval_asr_model/probs_cache_file' ...
[NeMo I 2022-04-04 09:40:21 987078218:45] It may take some time...


Beam search decoding with width=200, alpha=2, beta=2.5:   0%|                                  | 0/3450 [00:00…

target: em liền gọi to. pred_text: em liền gọi to. wer: 0.0%
target: ở cái nơi rừng thiêng nước độc này. pred_text: ở nơi dừng vui huy động này. wer: 0.62%
target: anh nấy cũng khoe rằng ảnh của mình hơn. pred_text: anh ấy cũng khoe rằng ảnh của mình hơn. wer: 0.11%
target: nhuộm ánh nắng tà qua mái tóc. pred_text: ánh nắng tang bên ngoài tóc. wer: 0.57%
target: không không không sao đâu. pred_text: không song không sao đâu. wer: 0.2%
target: nó cứ nhìn chằm chằm vào đó. pred_text: nó cứ nhìn chằm chằm vào đó. wer: 0.0%
target: nên mỗi lần đi đâu về muộn một chút. pred_text: nên mỗi lần đi đâu về muộn một chút. wer: 0.0%
target: những câu chuyện hay những bức ảnh được chúng tôi khoe lại. pred_text: những câu chuyện hay những bức ảnh được chúng tôi khoe lại. wer: 0.0%
target: cũng vừa mới ngồi thôi nước còn chưa uống hết một nửa đây này. pred_text: cũng vừa mới ngồi thôi nước còn chưa uống hết một nửa đây này. wer: 0.0%
target: vậy thì tại sao không giúp mình chiếm được tình cảm của quâ