# Install and import all required packages

In [None]:
!pip install librosa
!pip install torch
!pip install transformers
!pip install pandas
!pip install datasets
!pip install pyaspeller

In [2]:
import os
import re
from tqdm import tqdm
import librosa
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import pandas as pd
from datasets import Dataset
from pyaspeller import YandexSpeller
from tqdm import tqdm
from jiwer import wer, cer

# Processing source files

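List the audio files and sort them in natural (numeric) order: 1, 2, 3, …, 100. Python's built-in sort compares file names as strings and would produce 1, 10, 100, 101, …, 11, …. That order would misalign the audio files with the transcript lines they are paired with by position.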

In [4]:
def atoi(text):
    """Return ``text`` converted to ``int`` if it is made only of digits, else unchanged."""
    if text.isdigit():
        return int(text)
    return text


def natural_keys(text):
    """Sort key that orders embedded numbers numerically ("f2" before "f10").

    The string is split into alternating non-digit / digit runs (the capture
    group keeps the digit runs in the result) and each run goes through
    :func:`atoi`, so digit runs compare as integers.
    """
    chunks = re.split(r'(\d+)', text)
    return list(map(atoi, chunks))

In [5]:
# Matches an innermost "( ... )" group (no nested parentheses inside).
_PAREN_RE = re.compile(r'\([^()]*\)')
# Transcript lines containing any of these annotation markers are skipped.
_SKIP_MARKERS = ('=', 'нрзб', '[', '<')
# Punctuation replaced by a space before tokenisation.
_PUNCTUATION = ('.', ',', ':', '?', '!', '–', '-')


def _normalize_sentence(line):
    """Turn one transcript line into a lower-case, punctuation-free sentence.

    Punctuation is replaced by spaces, 'ё' is folded to 'е', parenthesised
    comments are removed, and whitespace is collapsed and trimmed.
    """
    sentence = line.replace('\n', '').lower()
    for mark in _PUNCTUATION:
        sentence = sentence.replace(mark, ' ')
    sentence = sentence.replace('ё', 'е')
    # Strip innermost "(...)" groups until none are left. Nested groups are
    # handled, and words *between* two separate groups survive. A greedy
    # r'\(.*\)' would delete everything from the first "(" to the last ")".
    while True:
        sentence, n_removed = _PAREN_RE.subn('', sentence)
        if not n_removed:
            break
    # Collapse whitespace only after the removals, since they leave gaps behind.
    return ' '.join(sentence.split())


def prepare_files(directory, file_with_text, inf):
    """Pair the audio files in ``directory`` with the lines of a transcript.

    Files are sorted with ``natural_keys`` and matched to transcript lines by
    position (file k <-> line k). Lines containing an annotation marker
    ('=', 'нрзб', '[', '<') are skipped together with their file.

    Args:
        directory: folder with the audio files (``.DS_Store`` entries ignored).
        file_with_text: UTF-16 transcript, one line per audio file.
        inf: respondent code stored in every record.

    Returns:
        A list of ``{'respondent', 'path', 'sentence'}`` dicts with normalised
        sentences.

    Raises:
        ValueError: if the transcript has fewer lines than there are files.
    """
    with open(file_with_text, encoding='utf-16') as transcript:
        lines = transcript.readlines()

    files_full = [
        os.path.join(directory, filename)
        for filename in os.listdir(directory)
        if '.DS_Store' not in filename
    ]
    files_full.sort(key=natural_keys)

    # Pairing is purely positional. A short transcript would otherwise fail with
    # an IndexError part-way through the loop.
    if len(lines) < len(files_full):
        raise ValueError(
            f'{file_with_text} has {len(lines)} lines but {directory} '
            f'contains {len(files_full)} files'
        )

    dict_for_inf = []
    for j, filename in enumerate(tqdm(files_full)):
        line = lines[j]
        if any(marker in line for marker in _SKIP_MARKERS):
            continue
        dict_for_inf.append({'respondent': inf, 'path': filename,
                             'sentence': _normalize_sentence(line)})
    return dict_for_inf

In [6]:
# (audio directory, transcript file, respondent code) for every recording
# session. prepare_files turns each entry into a list of
# {'respondent', 'path', 'sentence'} records, in the order listed here.
_session_specs = [
    ('/content/input/new_mono_lnt20210706', '/content/input/20210706_lnt1950_1to831.txt', 'LNT1950'),
    ('/content/input/new_mono_mga20210713', '/content/input/20210713mga1932_1to1159.txt', 'MGA1932'),
    ('/content/input/new_mono_mga20210716', '/content/input/20210716mga1932_1to856.txt', 'MGA1932'),
    ('/content/input/new_mono_mga20220710', '/content/input/20220710mga1932_1to304.txt', 'MGA1932'),
    ('/content/input/new_mono_gip20210707', '/content/input/20210707gip1953_1to1607.txt', 'GIP1953'),
    ('/content/input/new_mono_gip20220715', '/content/input/20220715gip1953_1to332.txt', 'GIP1953'),
    ('/content/input/new_mono_gip20230427', '/content/input/20230427gip1953_1to873.txt', 'GIP1953'),
    ('/content/input/new_mono_apb20220707', '/content/input/20220707apb1940_1to674.txt', 'AB1940'),
    ('/content/input/new_mono_apb20220710', '/content/input/20220710apb1940EZ_1to659.txt', 'AB1940'),
    ('/content/input/new_mono_apb20230427', '/content/input/20230427apb1940_1to557.txt', 'AB1940'),
    ('/content/input/new_mono_zns20220710', '/content/input/20220710zns1939_1to677.txt', 'ZNS1939'),
    ('/content/input/new_mono_zns20220711', '/content/input/20220711zns1939_1to379.txt', 'ZNS1939'),
]

(lnt, mga_1307, mga_1607, mga_1007,
 gip_0707, gip_1507, gip_2704,
 apb_0707, apb_1007, apb_2704,
 zns_1007, zns_1107) = [prepare_files(*spec) for spec in _session_specs]

100%|████████████████████████████████████████████████████████████████████████████| 831/831 [00:00<00:00, 207703.15it/s]
100%|██████████████████████████████████████████████████████████████████████████| 1159/1159 [00:00<00:00, 289701.93it/s]
100%|████████████████████████████████████████████████████████████████████████████| 856/856 [00:00<00:00, 285331.34it/s]
100%|████████████████████████████████████████████████████████████████████████████| 304/304 [00:00<00:00, 303804.72it/s]
100%|██████████████████████████████████████████████████████████████████████████| 1607/1607 [00:00<00:00, 229533.34it/s]
100%|████████████████████████████████████████████████████████████████████████████| 332/332 [00:00<00:00, 332103.25it/s]
100%|████████████████████████████████████████████████████████████████████████████| 873/873 [00:00<00:00, 291856.16it/s]
100%|████████████████████████████████████████████████████████████████████████████| 674/674 [00:00<00:00, 337064.61it/s]
100%|███████████████████████████████████

In [7]:
# Flatten every session's records, in session order, into one list.
all_records = [
    *lnt, *mga_1307, *mga_1607, *mga_1007,
    *gip_0707, *gip_1507, *gip_2704,
    *apb_0707, *apb_1007, *apb_2704,
    *zns_1007, *zns_1107,
]
len(all_records)

7922

# Define the model and read audio

In [8]:
LANG_ID = "ru"  # NOTE(review): not referenced by any visible cell
MODEL_ID = "bond005/wav2vec2-large-ru-golos-with-lm"
# Local training checkpoint (presumably fine-tuned on the Opochka data, going by
# the directory name; confirm). The processor still comes from MODEL_ID.
CHECKPOINT_DIR = '/content/wav2vec2-large-ru-golos-with-lm-opochka/checkpoint-2574/'

processor = Wav2Vec2Processor.from_pretrained(MODEL_ID, padding=True)
model = Wav2Vec2ForCTC.from_pretrained(CHECKPOINT_DIR, local_files_only=True)

In [10]:
def speech_file_to_array_fn(batch, sampling_rate=16000):
    """Load the audio file referenced by ``batch["path"]``.

    Returns a shallow copy of ``batch`` with an added ``"speech"`` entry: the
    waveform returned by ``librosa.load`` resampled to ``sampling_rate`` Hz.
    The input record is not modified. Mutating it made ``all_records`` and the
    per-session lists (which share the same dict objects) silently carry the
    audio arrays.

    Args:
        batch: record with at least a ``"path"`` key.
        sampling_rate: target sample rate in Hz (default 16000).
    """
    speech_array, _ = librosa.load(batch["path"], sr=sampling_rate)
    return {**batch, "speech": speech_array}

# Decode the audio of every record (tqdm reports progress over all_records).
test_dataset = [speech_file_to_array_fn(record) for record in tqdm(all_records)]
# Bare waveforms, in the same order as test_dataset.
data = [entry['speech'] for entry in test_dataset]

100%|██████████████████████████████████████████████████████████████████████████████| 7922/7922 [05:48<00:00, 22.75it/s]


In [14]:
# Tabulate the records, keep only what the model needs and make a
# reproducible 70/30 train/test split.
df = pd.DataFrame(test_dataset, columns=['respondent', 'path', 'sentence', 'speech'])
ds = (
    Dataset.from_pandas(df.loc[:, ['sentence', 'speech']])
    .train_test_split(test_size=0.3, seed=22)
)

# Prepare the dataset for testing

In [15]:
def prepare_dataset(batch, processor):
    """Attach model inputs and label ids to one dataset example.

    ``batch["speech"]`` is passed to ``processor`` as 16 kHz audio and the first
    row of the resulting ``input_values`` is stored. ``batch["sentence"]`` is
    processed inside ``processor.as_target_processor()`` and its ``input_ids``
    become ``"labels"``. The example is updated in place and returned.
    """
    features = processor(batch["speech"], sampling_rate=16000)
    batch["input_values"] = features.input_values[0]

    # NOTE(review): as_target_processor() is deprecated in newer transformers
    # releases in favour of processor(text=...); confirm the installed version.
    with processor.as_target_processor():
        label_ids = processor(batch["sentence"]).input_ids
    batch["labels"] = label_ids
    return batch

In [16]:
# Add input_values / labels to both splits; the processor is bound via fn_kwargs.
ds = ds.map(prepare_dataset, fn_kwargs={"processor": processor})

Map:   0%|          | 0/5545 [00:00<?, ? examples/s]



Map:   0%|          | 0/2377 [00:00<?, ? examples/s]

# Test

In [17]:
def map_to_result(batch):
    """Transcribe one example with the global ``model`` and decode its reference.

    Adds two keys to ``batch`` and returns it:
      * ``"pred_str"``: ``processor.batch_decode`` of the per-frame argmax of
        the logits (greedy decoding);
      * ``"text"``: ``batch["labels"]`` decoded with ``group_tokens=False``, since
        reference ids are not CTC output and repeats must not be collapsed.

    The input tensor is created on ``model.device``, so this keeps working after
    ``model.to("cuda")``. Relies on the notebook-level ``model`` and ``processor``.
    """
    with torch.no_grad():
        input_values = torch.tensor(batch["input_values"], device=model.device).unsqueeze(0)
        logits = model(input_values).logits

    # Back to CPU before decoding so decoding never depends on the model device.
    pred_ids = torch.argmax(logits, dim=-1).cpu()
    batch["pred_str"] = processor.batch_decode(pred_ids)[0]
    batch["text"] = processor.decode(batch["labels"], group_tokens=False)
    return batch

# Run inference on the held-out split. Every original column is dropped, so only
# the pred_str / text fields added by map_to_result remain.
test_split = ds["test"]
results = test_split.map(map_to_result, remove_columns=test_split.column_names)

Map:   0%|          | 0/2377 [00:00<?, ? examples/s]

# Evaluate

In [18]:
# Mean per-utterance WER / CER of the raw (not spellchecked) predictions.
# Note: this is an unweighted average over utterances, not corpus-level WER.
# Blank references (empty or whitespace-only) cannot be scored by jiwer, so skip them.
wers = []
cers = []

for item in results:
    reference, hypothesis = item['text'], item['pred_str']
    if not reference.strip():
        continue
    wers.append(wer(reference, hypothesis))
    cers.append(cer(reference, hypothesis))

if wers:
    print('Mean WER: ', sum(wers)/len(wers))
    print('Mean CER: ', sum(cers)/len(cers))
else:
    print('No non-empty references to score.')

Mean WER:  0.5828604111874978
Mean CER:  0.3063438050043265


In [19]:
test_results = results.to_pandas()
path = "/content/wav2vec_opochka_on_shetnevo_without_spellcheck.xlsx"
# The context manager writes and closes the workbook exactly once. The previous
# save() + close() pair closed it twice (hence the "Calling close() on already
# closed file" warning), and ExcelWriter.save() no longer exists in pandas 2.x.
with pd.ExcelWriter(path, engine='xlsxwriter') as writer:
    test_results.to_excel(writer)

  warn("Calling close() on already closed file.")


# Spellcheck the predicted transcriptions

In [20]:
# Run every raw prediction through YandexSpeller.spelled(); the output list is
# aligned index-for-index with results['pred_str'].
speller = YandexSpeller()
transcrtiptions_spelled = [speller.spelled(prediction) for prediction in tqdm(results['pred_str'])]

100%|██████████████████████████████████████████████████████████████████████████████| 2377/2377 [05:58<00:00,  6.64it/s]


In [21]:
# Re-score with the spellchecked transcriptions and write them back into
# `results`, so the "with_spellcheck" export really contains them.
# `results` is a datasets.Dataset: `results['pred_str']` returns a copy of the
# column, so the old `results['pred_str'][i] = ...` never changed the dataset.
# Each column is also materialised once here instead of once per row.
references = list(results['text'])
spelled_predictions = list(results['pred_str'])
wers = []
cers = []

for i, (reference, transcrtiption_spelled) in enumerate(zip(references, transcrtiptions_spelled)):
    # Blank references (empty or whitespace-only) cannot be scored by jiwer.
    if not reference.strip():
        continue
    wers.append(wer(reference, transcrtiption_spelled))
    cers.append(cer(reference, transcrtiption_spelled))
    spelled_predictions[i] = transcrtiption_spelled

# Replace the pred_str column; map keeps the column order (pred_str, text).
results = results.map(lambda _, idx: {'pred_str': spelled_predictions[idx]}, with_indices=True)

if wers:
    print('Mean WER: ', sum(wers)/len(wers))
    print('Mean CER: ', sum(cers)/len(cers))
else:
    print('No non-empty references to score.')

Mean WER:  0.5378986179137468
Mean CER:  0.30670844485045673


In [22]:
test_results = results.to_pandas()
path = "/content/wav2vec_opochka_on_shetnevo_with_spellcheck.xlsx"
# The context manager writes and closes the workbook exactly once. The previous
# save() + close() pair closed it twice (hence the "Calling close() on already
# closed file" warning), and ExcelWriter.save() no longer exists in pandas 2.x.
with pd.ExcelWriter(path, engine='xlsxwriter') as writer:
    test_results.to_excel(writer)

  warn("Calling close() on already closed file.")
