# Import all required packages

In [None]:
!pip install librosa
!pip install pandas
!pip install transformers
!pip install jiwer
!pip install scikit-learn
!pip install torch
!pip install datasets
!pip install dataclasses
!pip install typing
!pip install numpy

In [2]:
import os
import re
from tqdm import tqdm
import librosa
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Trainer, TrainingArguments
import pandas as pd
from jiwer import wer, cer
from sklearn.model_selection import train_test_split
from datasets import Dataset, ClassLabel, load_metric
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import random
import json
import numpy as np
import warnings
from pyaspeller import YandexSpeller
from tqdm import tqdm

# Processing source files

Get all files, define new sort function to sort as \[1, 2, 3 ... 100\], because built-in function sorts strings as \[1, 100, 101 ...\].

In [6]:
def atoi(text):
    return int(text) if text.isdigit() else text


def natural_keys(text):
    return [atoi(c) for c in re.split(r'(\d+)', text)]

In [7]:
def prepare_files(directory, file_with_text, inf):
    with open(file_with_text, encoding='utf-16') as f:
        text = f.readlines()
    files = os.listdir(directory)
    files_full = []
    for filename in files:
        if '.DS_Store' not in filename:
            f = os.path.join(directory, filename)
            files_full.append(f)
    files_full.sort(key=natural_keys)
    j = 0
    dict_for_inf = []
    for filename in tqdm(files_full):
        if not '=' in text[j] and not 'Ð½Ñ€Ð·Ð±' in text[j] and not '[' in text[j] and not '<' in text[j]:
            x = text[j].replace('\n', '').lower()
            x = x.replace('.', ' ')
            x = x.replace(',', ' ')
            x = x.replace(':', ' ')
            x = x.replace('?', ' ')
            x = x.replace('!', ' ')
            x = x.replace('â€“', ' ')
            x = x.replace('-', ' ')
            x = x.replace('Ñ‘', 'Ðµ')
            x = re.sub('(\s){2,}', ' ', x)
            x = re.sub('\(.*\)', '', x)
            x = x.rstrip()
            dict_for_inf.append({'respondent':inf, 'path': filename, 'sentence': x})
        j += 1  
    return dict_for_inf

In [8]:
enm = prepare_files('/content/input_opochka/new_mono_enm20180618', 
                    '/content/input_opochka/20180618_enm1930_1to487.txt', 'ENM1930')
ive = prepare_files('/content/input_opochka/new_mono_ive20190702', 
                    '/content/input_opochka/20190702_ive1949_1to234.txt', 'IVE1949')
onv = prepare_files('/content/input_opochka/new_mono_onv20180622', 
                    '/content/input_opochka/20180622_onv1972_1to529.txt', 'ONV1972')
saf = prepare_files('/content/input_opochka/new_mono_saf20190701', 
                    '/content/input_opochka/20190701_saf1973_1to434.txt', 'IVE1949')
tai = prepare_files('/content/input_opochka/new_mono_tai20190706', 
                    '/content/input_opochka/20190706_tai1955_1to167.txt', 'TAI1955')
tve = prepare_files('/content/input_opochka/new_mono_tve20190702', 
                    '/content/input_opochka/20190702_tve1955_1to709.txt', 'TVE1955')
vav = prepare_files('/content/input_opochka/new_mono_vav20180619', 
                    '/content/input_opochka/20180619_vav1949_1to277.txt', 'VAV1949')

100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 487/487 [00:00<00:00, 243488.62it/s]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 234/234 [00:00<00:00, 233905.42it/s]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 529/529 [00:00<00:00, 264456.12it/s]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–

In [9]:
all_data = enm + ive + onv + saf + tai + tve + vav
len(all_data)

2256

# Read audio

In [11]:
def speech_file_to_array_fn(batch):
    speech_array, sampling_rate = librosa.load(batch["path"], sr=16000)
    batch["speech"] = speech_array
    batch["sentence"] = batch["sentence"]
    return batch

test_dataset = []
for l in tqdm(all_data):
    test_dataset.append(speech_file_to_array_fn(l))
data = [d['speech'] for d in test_dataset]

100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 2256/2256 [02:11<00:00, 17.18it/s]


# Fine-tune

In [12]:
df = pd.DataFrame(test_dataset, columns=['respondent', 'path', 'sentence', 'speech'])
ds = Dataset.from_pandas(df[['sentence', 'speech']])
ds = ds.train_test_split(test_size=0.3, seed=22)

In [13]:
def extract_all_chars(batch):
    all_text = " ".join(batch["sentence"])
    vocab = list(set(all_text))
    return {"vocab": [vocab], "all_text": [all_text]}

vocabs = ds.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=ds.column_names["train"])


Map:   0%|          | 0/1579 [00:00<?, ? examples/s]

Map:   0%|          | 0/677 [00:00<?, ? examples/s]

In [14]:
vocab_list = list(set(vocabs["train"]["vocab"][0]) | set(vocabs["test"]["vocab"][0]))

vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)

In [15]:
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In [16]:
LANG_ID = "ru"
MODEL_ID = "bond005/wav2vec2-large-ru-golos-with-lm"

processor = Wav2Vec2Processor.from_pretrained(MODEL_ID, padding=True)
tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)

In [17]:
def prepare_dataset(batch, processor):
    audio = batch["speech"]
    batch["input_values"] = processor(audio, sampling_rate=16000).input_values[0]
    
    with processor.as_target_processor():
        batch["labels"] = processor(batch["sentence"]).input_ids
    return batch

In [18]:
ds = ds.map(lambda examples: prepare_dataset(examples, processor))

Map:   0%|          | 0/1579 [00:00<?, ? examples/s]



Map:   0%|          | 0/677 [00:00<?, ? examples/s]

In [19]:
@dataclass
class DataCollatorCTCWithPadding:
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",).to('cpu')
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",).to('cpu')

        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100).to('cpu')
        batch["labels"] = labels.to('cpu')
        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [21]:
wer_metric = load_metric("wer")
cer_metric = load_metric("cer", revision="master")

  wer_metric = load_metric("wer")


In [22]:
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
    
    pred_str = [pred_str[i] for i in range(len(pred_str)) if len(label_str[i]) > 0]
    label_str = [label_str[i] for i in range(len(label_str)) if len(label_str[i]) > 0]
    
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    
    return {"wer": wer, 'cer': cer}

In [23]:
model = Wav2Vec2ForCTC.from_pretrained(
    "bond005/wav2vec2-large-ru-golos-with-lm", 
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
)
model.freeze_feature_extractor()



In [24]:
training_args = TrainingArguments(
  output_dir='./wav2vec2-large-ru-golos-with-lm-opochka',
  per_device_train_batch_size=8,
  learning_rate=1e-4,
  evaluation_strategy="epoch",
  num_train_epochs=15,
  report_to="none",
  save_strategy="epoch"
)


In [25]:
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    tokenizer=processor.feature_extractor,
)

In [26]:
device = torch.device("cpu")

In [27]:
warnings.filterwarnings('ignore')

In [28]:
trainer.train()

Epoch,Training Loss,Validation Loss,Wer,Cer
1,No log,3.443566,0.554121,0.260587
2,No log,2.855818,0.528615,0.249001
3,1.278200,3.339507,0.522398,0.24516
4,1.278200,3.669631,0.515383,0.238513
5,1.278200,3.860431,0.51076,0.236776
6,0.951400,3.772517,0.509645,0.234245
7,0.951400,4.202432,0.508529,0.235891
8,0.774500,3.944754,0.508529,0.233696
9,0.774500,4.184476,0.509166,0.233971
10,0.774500,4.385952,0.500717,0.229245


TrainOutput(global_step=2970, training_loss=0.8033801557239058, metrics={'train_runtime': 124430.5653, 'train_samples_per_second': 0.19, 'train_steps_per_second': 0.024, 'total_flos': 6.411286297934812e+18, 'train_loss': 0.8033801557239058, 'epoch': 15.0})

# Testing

In [39]:
MODEL_ID = "bond005/wav2vec2-large-ru-golos-with-lm"
model = Wav2Vec2ForCTC.from_pretrained('/content/wav2vec2-large-ru-golos-with-lm-opochka/checkpoint-2574/', local_files_only=True)
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID, padding=True)

In [40]:
def map_to_result(batch):
    with torch.no_grad():
        input_values = torch.tensor(batch["input_values"]).unsqueeze(0)
        logits = model(input_values).logits

    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_str"] = processor.batch_decode(pred_ids)[0]
    batch["text"] = processor.decode(batch["labels"], group_tokens=False)

    return batch

results = ds["test"].map(map_to_result, remove_columns=ds["test"].column_names)

Map:   0%|          | 0/677 [00:00<?, ? examples/s]

In [41]:
wers = []
cers = []


for item in results:
    if item['text'] != '' and item['text'] != ' ':
        w = wer(item['text'], item['pred_str'])
        wers.append(w)
        c = cer(item['text'], item['pred_str'])
        cers.append(c)

print('Mean WER: ', sum(wers)/len(wers))
print('Mean CER: ', sum(cers)/len(cers))

Mean WER:  0.5665289617662056
Mean CER:  0.3175943809726962


In [42]:
test_results = results.to_pandas()
path = "/content/wav2vec_opochka_without_spellcheck.xlsx"
writer = pd.ExcelWriter(path, engine = 'xlsxwriter')

test_results.to_excel(writer) 

writer.save()
writer.close()

# Use a spellchecker for the received transcriptions

In [43]:
speller = YandexSpeller()
transcrtiptions_spelled = []
for t in tqdm(results['pred_str']):
    transcrtiptions_spelled.append(speller.spelled(t))

100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 677/677 [01:48<00:00,  6.23it/s]


In [44]:
wers = []
cers = []

for i, transcrtiption_spelled in enumerate(transcrtiptions_spelled):
    if results['text'][i] != '' and results['text'][i] != ' ':
        w = wer(results['text'][i], transcrtiption_spelled)
        wers.append(w)
        c = cer(results['text'][i], transcrtiption_spelled)
        cers.append(c)
        results['pred_str'][i] = transcrtiption_spelled

print('Mean WER: ', sum(wers)/len(wers))
print('Mean CER: ', sum(cers)/len(cers))

Mean WER:  0.526308835090948
Mean CER:  0.3189697004219541


In [35]:
test_results = results.to_pandas()
path = "/content/wav2vec_opochka_with_spellcheck.xlsx"
writer = pd.ExcelWriter(path, engine = 'xlsxwriter')

test_results.to_excel(writer) 

writer.save()
writer.close()