In [15]:

from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from pyctcdecode import Alphabet, BeamSearchDecoderCTC, LanguageModel
import kenlm
import copy
import os
import numpy as np
import torch
from tqdm.auto import tqdm
import json
from joblib.parallel import Parallel
import joblib
import soundfile as sf

import soundfile as sf
from jiwer import wer

In [16]:
manifests_dir = "/home/khoatlv/manifests"
test_manifest_processed_processed = os.path.join(manifests_dir, "test_manifest_processed.json")
train_manifest_processed_processed = os.path.join(manifests_dir, "train_manifest_processed.json")

manifests_processed_all = "/home/khoatlv/Conformer_ASR/scripts/evaluation/manifests/data_collected_eval_manifests.json"

# Rebuild the combined evaluation manifest (test entries first, then train
# entries) in pure Python instead of shelling out to `cat` / `wc`:
#   * no shell string is built from paths, so spaces or shell metacharacters
#     in a path cannot break or hijack the command;
#   * a missing input file raises immediately instead of being silently
#     ignored (the exit status of os.system was never checked);
#   * a part that lacks a trailing newline no longer gets its last record
#     glued to the first record of the next part.
# Opening with "wb" truncates any previous file, so no explicit removal is needed.
manifest_line_count = 0
with open(manifests_processed_all, "wb") as combined_manifest:
    for part_path in (test_manifest_processed_processed, train_manifest_processed_processed):
        with open(part_path, "rb") as part_file:
            part_bytes = part_file.read()
        combined_manifest.write(part_bytes)
        manifest_line_count += part_bytes.count(b"\n")
        if part_bytes and not part_bytes.endswith(b"\n"):
            # Terminate the last record so the next part starts on its own line.
            combined_manifest.write(b"\n")
            manifest_line_count += 1

# Same "<line count> <path>" summary that `wc -l` used to print.
print("{} {}".format(manifest_line_count, manifests_processed_all))

9141 /home/khoatlv/Conformer_ASR/scripts/evaluation/manifests/data_collected_eval_manifests.json


0

In [17]:
def get_decoder_ngram_model(tokenizer, ngram_lm_path):
    """Build a CTC beam-search decoder backed by a KenLM n-gram model.

    Args:
        tokenizer: tokenizer exposing ``get_vocab()`` plus the
            ``pad_token_id``, ``unk_token_id`` and ``word_delimiter_token_id``
            attributes.
        ngram_lm_path: path handed to ``kenlm.Model``.

    Returns:
        A ``BeamSearchDecoderCTC`` built from the tokenizer's labels and the
        loaded language model.
    """
    # Order the tokens by their id (ties broken by the token text), then drop
    # the two highest-id entries.
    ordered = sorted(tokenizer.get_vocab().items(), key=lambda item: (item[1], item[0]))
    labels = [token for token, _ in ordered][:-2]

    # Blank out pad/unk and map the word delimiter to a plain space.
    labels[tokenizer.pad_token_id] = ""
    labels[tokenizer.unk_token_id] = ""
    labels[tokenizer.word_delimiter_token_id] = " "

    # The pad token id is passed as the CTC token index.
    alphabet = Alphabet.build_alphabet(labels, ctc_token_idx=tokenizer.pad_token_id)
    language_model = LanguageModel(kenlm.Model(ngram_lm_path))
    return BeamSearchDecoderCTC(alphabet, language_model=language_model)

# Local filesystem locations of the saved Wav2Vec2 processor, the CTC model
# checkpoint, and the n-gram language-model binary handed to kenlm.
wav2vec2_processor = "/home/khoatlv/Conformer_ASR/scripts/evaluation/wav2vec_models/preprocessor"
wav2vec2_model = "/home/khoatlv/Conformer_ASR/scripts/evaluation/wav2vec_models/CTCModel"
lm_file = "/home/khoatlv/Conformer_ASR/scripts/evaluation/wav2vec_models/4-gram-lm_large.bin"

# Load everything once at module level; transcribe_ASR() reads `processor`,
# `model` and `ngram_lm_model` as globals.
processor = Wav2Vec2Processor.from_pretrained(wav2vec2_processor)
# The model is moved to CUDA unconditionally, so this cell needs a GPU.
model = Wav2Vec2ForCTC.from_pretrained(wav2vec2_model).to(torch.device('cuda'))
# Beam-search decoder built from the processor's tokenizer and the LM file.
ngram_lm_model = get_decoder_ngram_model(processor.tokenizer, lm_file)

In [18]:
def read_manifest(path):
    """Read a JSON-lines manifest into a list of dicts.

    The file is decoded as UTF-8 explicitly: the manifests contain Vietnamese
    text, and relying on the platform's default encoding can fail or corrupt
    the transcripts on machines with a non-UTF-8 locale.

    Args:
        path: path to a manifest with one JSON object per line.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    manifest = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc="Reading manifest data"):
            line = line.strip()
            # Blank lines (e.g. a trailing newline at the end of the file or
            # a gap between concatenated manifests) carry no record.
            if not line:
                continue
            manifest.append(json.loads(line))
    return manifest


def speech_file_to_array_fn(batch):
    """Load the audio referenced by ``batch["path"]`` into the batch itself.

    The value pair returned by ``sf.read`` is stored under the ``"speech"``
    and ``"sampling_rate"`` keys, and the same (mutated) ``batch`` object is
    returned.
    """
    samples, rate = sf.read(batch["path"])
    batch["speech"], batch["sampling_rate"] = samples, rate
    return batch

def read_wav(data_manifest):
    """Load the audio for one manifest entry, best-effort.

    Args:
        data_manifest: manifest entry with ``"audio_filepath"`` and ``"text"``
            keys.

    Returns:
        A dict with the decoded samples under ``"speed"`` (the key name is
        kept as-is because the evaluation loop reads ``data["speed"]``), plus
        the entry's ``"text"`` and ``"audio_filepath"``; or ``None`` when the
        entry is malformed or the file cannot be read.
    """
    try:
        # Only the samples are used; the sample rate returned by sf.read is
        # not propagated.
        samples, _sample_rate = sf.read(data_manifest["audio_filepath"])
        return {
            "speed": samples,
            "text": data_manifest["text"],
            "audio_filepath": data_manifest["audio_filepath"],
        }
    except Exception:
        # Deliberate best-effort: a bad entry must not abort the whole batch.
        # `Exception` (rather than a bare `except:`) lets KeyboardInterrupt
        # and SystemExit propagate so the job can still be stopped.
        return None

def read_batch(paths):
    """Run ``read_wav`` over every manifest entry in ``paths`` in parallel.

    Uses a joblib ``Parallel`` pool with 16 jobs and verbose level 10, and
    returns whatever the pool call returns for the submitted ``read_wav``
    tasks.
    """
    with Parallel(n_jobs=16, verbose=10) as pool:
        return pool(joblib.delayed(read_wav)(entry) for entry in paths)

def transcribe_ASR(raw_signal, beam_width=200):
    """Transcribe one utterance with the global Wav2Vec2 model and LM decoder.

    Relies on the module-level ``processor``, ``model`` and ``ngram_lm_model``.

    Args:
        raw_signal: audio samples (anything ``np.asarray`` accepts); they are
            cast to float32 and flattened to 1-D.
        beam_width: beam width passed to the decoder. Defaults to 200, the
            value that used to be hard-coded.

    Returns:
        The value returned by ``ngram_lm_model.decode`` for this utterance.
    """
    # NOTE(review): flatten() would interleave the channels of a
    # multi-channel array — confirm that inputs are always mono.
    signal = np.asarray(raw_signal, dtype=np.float32).flatten()
    # The processor is told the audio is 16 kHz; the input is placed on the
    # model's own device rather than a hard-coded "cuda", so the function
    # follows wherever the model was loaded.
    input_values = processor(
        signal,
        sampling_rate=16000,
        return_tensors="pt"
    ).input_values.to(model.device)
    # Inference only: skip autograd bookkeeping so no graph or activations
    # are retained for each utterance.
    with torch.no_grad():
        logits = model(input_values).logits[0]
    beam_search_output = ngram_lm_model.decode(
        logits.cpu().detach().numpy(), beam_width=beam_width
    )
    return beam_search_output

def save_log(log_data):
    """Write evaluation records to the CSV file at the global ``log_data_path``.

    Each record in ``log_data`` must provide the ``audio_filepath``, ``text``,
    ``pred`` and ``wer`` keys. If the file already exists the rows are
    appended (and "old" is printed); otherwise the file is created with a
    header row first (and "new" is printed).
    """
    header = ['audio_filepath', 'text', 'pred', 'wer']
    # One CSV row per record, columns in header order.
    rows = [[entry[column] for column in header] for entry in log_data]

    appending = os.path.exists(log_data_path)
    print("old" if appending else "new")

    mode = 'a' if appending else 'w'
    with open(log_data_path, mode, encoding='UTF8', newline='') as handle:
        out = csv.writer(handle)
        if not appending:
            # A fresh file gets the column names before any data.
            out.writerow(header)
        out.writerows(rows)

In [19]:
# Parse the combined manifest into a list of dicts (one per JSON line), then
# show the entry count and the first entry as a quick sanity check.
manifests_processed_all_data = read_manifest(manifests_processed_all)
print(len(manifests_processed_all_data))
print(manifests_processed_all_data[0])

Reading manifest data: 0it [00:00, ?it/s]

9141
{'audio_filepath': '/home/khoatlv/data/data_collected/viettel/resample/mở-điều-hòa-ga-ra_2_0.wav', 'duration': 1.6138125, 'text': 'mở điều hòa ga ra'}


In [20]:
import csv  # used by save_log(), which resolves the name at call time

dataset_src = "data_collected"
log_data_path = os.path.join(f"/home/khoatlv/Conformer_ASR/scripts/evaluation/eval_data/{dataset_src}", "data_log.csv")
log_error_files_path = os.path.join(f"/home/khoatlv/Conformer_ASR/scripts/evaluation/eval_data/{dataset_src}", "log_error_files.csv")
log_data = []      # records waiting to be flushed to log_data_path
error_files = []   # audio paths that could not be read or transcribed

len_test = len(manifests_processed_all_data)
# len_test = 5
step = 1000
create_file = True
# Start from a clean slate: save_log() appends whenever the CSV already exists.
if os.path.exists(log_data_path): os.remove(log_data_path)
if os.path.exists(log_error_files_path): os.remove(log_error_files_path)

for i in range(0, len_test, step):
    print("Process {} files".format(i))
    # min() keeps the last (short) batch inside len_test, also when len_test
    # is overridden to a value smaller than the manifest.
    batch_manifest = manifests_processed_all_data[i: min(i + step, len_test)]
    batch_data = read_batch(batch_manifest)

    # Pair every loaded sample with its manifest entry so that the file path
    # is still known when read_wav() returned None for it.
    for manifest_entry, data in zip(batch_manifest, batch_data):
        audio_filepath = str(manifest_entry.get("audio_filepath", "<missing audio_filepath>"))

        if data is None:
            # Unreadable audio or malformed entry: record it instead of
            # failing on `None[...]` and losing track of the file.
            error_files.append(audio_filepath)
            print("Error reading file {}".format(audio_filepath))
            continue

        try:
            pred = transcribe_ASR(data["speed"])
            # jiwer expects the reference first and the hypothesis second;
            # WER is normalised by the reference length, so the order matters.
            wer_score = wer([data["text"]], [pred])
        except Exception as e:
            error_files.append(audio_filepath)
            # Report the path and the cause only; the full sample array is
            # deliberately not printed.
            print("Error in file {}: {}".format(audio_filepath, e))
            continue

        log_data.append({
            "audio_filepath": data["audio_filepath"],
            "text": data["text"],
            "wer": wer_score,
            "pred": pred,
        })

        # Flush periodically so results survive an interrupted run.
        if len(log_data) >= 1000:
            save_log(log_data)
            log_data = []

# Flush whatever is left after the final batch.
if len(log_data) > 0:
    save_log(log_data)

# One failed audio path per line.
with open(log_error_files_path, 'w', encoding='UTF8', newline='') as f:
    f.write("\n".join(error_files))

Process 0 files


[Parallel(n_jobs=16)]: Using backend LokyBackend with 16 concurrent workers.
[Parallel(n_jobs=16)]: Done   9 tasks      | elapsed:    4.2s
[Parallel(n_jobs=16)]: Done  18 tasks      | elapsed:    4.2s
[Parallel(n_jobs=16)]: Done  29 tasks      | elapsed:    4.2s
[Parallel(n_jobs=16)]: Done  40 tasks      | elapsed:    4.2s
[Parallel(n_jobs=16)]: Batch computation too fast (0.1735s.) Setting batch_size=2.
[Parallel(n_jobs=16)]: Done  53 tasks      | elapsed:    4.2s
[Parallel(n_jobs=16)]: Done  66 tasks      | elapsed:    4.2s
[Parallel(n_jobs=16)]: Done  82 tasks      | elapsed:    4.3s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0327s.) Setting batch_size=4.
[Parallel(n_jobs=16)]: Done 112 tasks      | elapsed:    4.3s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0312s.) Setting batch_size=8.
[Parallel(n_jobs=16)]: Done 152 tasks      | elapsed:    4.3s
[Parallel(n_jobs=16)]: Done 216 tasks      | elapsed:    4.3s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0

new
Process 1000 files


[Parallel(n_jobs=16)]: Using backend LokyBackend with 16 concurrent workers.
[Parallel(n_jobs=16)]: Batch computation too fast (0.0179s.) Setting batch_size=2.
[Parallel(n_jobs=16)]: Done   9 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  18 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0170s.) Setting batch_size=4.
[Parallel(n_jobs=16)]: Done  30 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  48 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done  74 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0237s.) Setting batch_size=8.
[Parallel(n_jobs=16)]: Done 110 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done 164 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done 224 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0471s.) Setting batch_size=16.
[Parallel(n_jobs=16)]: Done 360 tasks      | elapsed:    0.2s
[Parallel(n_jobs=16)]: Batch com

old
Process 2000 files


[Parallel(n_jobs=16)]: Using backend LokyBackend with 16 concurrent workers.
[Parallel(n_jobs=16)]: Batch computation too fast (0.0136s.) Setting batch_size=2.
[Parallel(n_jobs=16)]: Done   9 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  18 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  29 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0327s.) Setting batch_size=4.
[Parallel(n_jobs=16)]: Done  48 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done  74 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0223s.) Setting batch_size=8.
[Parallel(n_jobs=16)]: Done 104 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done 164 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0353s.) Setting batch_size=16.
[Parallel(n_jobs=16)]: Done 228 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done 360 tasks      | elapsed:    0.2s
[Parallel(n_jobs=16)]: Batch com

old
Process 3000 files


[Parallel(n_jobs=16)]: Using backend LokyBackend with 16 concurrent workers.
[Parallel(n_jobs=16)]: Batch computation too fast (0.0196s.) Setting batch_size=2.
[Parallel(n_jobs=16)]: Done   9 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  18 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  29 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0344s.) Setting batch_size=4.
[Parallel(n_jobs=16)]: Done  48 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done  74 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0436s.) Setting batch_size=8.
[Parallel(n_jobs=16)]: Done 104 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done 164 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done 224 tasks      | elapsed:    0.2s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0703s.) Setting batch_size=16.
[Parallel(n_jobs=16)]: Done 360 tasks      | elapsed:    0.2s
[Parallel(n_jobs=16)]: Batch com

old
Process 4000 files


[Parallel(n_jobs=16)]: Using backend LokyBackend with 16 concurrent workers.
[Parallel(n_jobs=16)]: Batch computation too fast (0.0152s.) Setting batch_size=2.
[Parallel(n_jobs=16)]: Done   9 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  18 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  29 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0319s.) Setting batch_size=4.
[Parallel(n_jobs=16)]: Done  48 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done  74 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0236s.) Setting batch_size=8.
[Parallel(n_jobs=16)]: Done 106 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done 164 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done 224 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0568s.) Setting batch_size=16.
[Parallel(n_jobs=16)]: Done 360 tasks      | elapsed:    0.2s
[Parallel(n_jobs=16)]: Batch com

old
Process 5000 files


[Parallel(n_jobs=16)]: Using backend LokyBackend with 16 concurrent workers.
[Parallel(n_jobs=16)]: Batch computation too fast (0.0114s.) Setting batch_size=2.
[Parallel(n_jobs=16)]: Done   9 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  18 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0171s.) Setting batch_size=4.
[Parallel(n_jobs=16)]: Done  30 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  49 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  74 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0345s.) Setting batch_size=8.
[Parallel(n_jobs=16)]: Done 104 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done 164 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done 224 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0498s.) Setting batch_size=16.
[Parallel(n_jobs=16)]: Done 360 tasks      | elapsed:    0.2s
[Parallel(n_jobs=16)]: Batch com

old
Process 6000 files


[Parallel(n_jobs=16)]: Using backend LokyBackend with 16 concurrent workers.
[Parallel(n_jobs=16)]: Batch computation too fast (0.0121s.) Setting batch_size=2.
[Parallel(n_jobs=16)]: Done   9 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  18 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  29 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0191s.) Setting batch_size=4.
[Parallel(n_jobs=16)]: Done  48 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done  74 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0269s.) Setting batch_size=8.
[Parallel(n_jobs=16)]: Done 104 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done 164 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done 224 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0595s.) Setting batch_size=16.
[Parallel(n_jobs=16)]: Done 360 tasks      | elapsed:    0.2s
[Parallel(n_jobs=16)]: Batch com

old
Process 7000 files


[Parallel(n_jobs=16)]: Using backend LokyBackend with 16 concurrent workers.
[Parallel(n_jobs=16)]: Batch computation too fast (0.0138s.) Setting batch_size=2.
[Parallel(n_jobs=16)]: Done   9 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  18 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  29 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0171s.) Setting batch_size=4.
[Parallel(n_jobs=16)]: Done  48 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done  74 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0223s.) Setting batch_size=8.
[Parallel(n_jobs=16)]: Done 106 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done 164 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done 224 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0535s.) Setting batch_size=16.
[Parallel(n_jobs=16)]: Done 360 tasks      | elapsed:    0.2s
[Parallel(n_jobs=16)]: Batch com

old
Process 8000 files


[Parallel(n_jobs=16)]: Using backend LokyBackend with 16 concurrent workers.
[Parallel(n_jobs=16)]: Batch computation too fast (0.0146s.) Setting batch_size=2.
[Parallel(n_jobs=16)]: Done   9 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  18 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  29 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0356s.) Setting batch_size=4.
[Parallel(n_jobs=16)]: Done  48 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done  74 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0418s.) Setting batch_size=8.
[Parallel(n_jobs=16)]: Done 104 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done 164 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done 224 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0536s.) Setting batch_size=16.
[Parallel(n_jobs=16)]: Done 360 tasks      | elapsed:    0.2s
[Parallel(n_jobs=16)]: Batch com

old
Process 9000 files


[Parallel(n_jobs=16)]: Using backend LokyBackend with 16 concurrent workers.
[Parallel(n_jobs=16)]: Batch computation too fast (0.0085s.) Setting batch_size=2.
[Parallel(n_jobs=16)]: Done   9 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  18 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done  29 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Batch computation too fast (0.0376s.) Setting batch_size=4.
[Parallel(n_jobs=16)]: Done  48 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done  73 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done  98 tasks      | elapsed:    0.1s
[Parallel(n_jobs=16)]: Done 125 out of 141 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=16)]: Done 141 out of 141 | elapsed:    0.1s finished


old
