# Usage

In [34]:
import os
import unicodedata

import pandas as pd
import torch
import torchaudio
from datasets import load_dataset, load_metric
from tqdm import tqdm
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

In [2]:
# Root directory that holds the dataset folders.
PATH = '../'
# Checkpoint shared by the processor and the acoustic model (kept in one place so they cannot drift apart).
MODEL_ID = "anton-l/wav2vec2-large-xlsr-53-tatar"
# Prefer the GPU, but fall back to the CPU so this cell also runs on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID).to(DEVICE)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
def speech_file_to_array_fn(batch, target_sampling_rate=16_000):
    """Load one example's audio file and store its waveform under ``batch["speech"]``.

    Args:
        batch: A single dataset example; ``batch["audio"]["path"]`` must point
            to the audio file on disk.
        target_sampling_rate: Sampling rate (Hz) the stored waveform should
            have. Defaults to 16 kHz, the rate passed to the processor
            elsewhere in this notebook.

    Returns:
        The same ``batch`` mapping with an added ``"speech"`` NumPy array.
    """
    speech_array, sampling_rate = torchaudio.load(batch["audio"]['path'])

    # torchaudio.load returns the waveform with a leading channel axis; average
    # multi-channel recordings so the result is always a single channel.
    if speech_array.dim() > 1 and speech_array.size(0) > 1:
        speech_array = speech_array.mean(dim=0, keepdim=True)

    # The file's own sampling rate used to be ignored; resample when it differs
    # from the target so the model never receives audio at the wrong rate.
    if sampling_rate != target_sampling_rate:
        speech_array = torchaudio.functional.resample(
            speech_array, sampling_rate, target_sampling_rate
        )

    batch["speech"] = speech_array.squeeze().numpy()
    return batch

In [5]:
# Load the first corpus with the "audiofolder" builder.
# os.path.join keeps the path valid whether or not PATH ends with a separator.
dataset_1 = load_dataset("audiofolder", data_dir=os.path.join(PATH, 'tatar_asr_1'))
# A second corpus ('tatar_asr_2') can be loaded the same way; it is not used in the cells shown here.

Resolving data files:   0%|          | 0/146372 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/28508 [00:00<?, ?it/s]

In [6]:
# Display the loaded splits with their columns and row counts.
dataset_1

DatasetDict({
    train: Dataset({
        features: ['audio', 'Unnamed: 0', 'transcription'],
        num_rows: 73185
    })
    test: Dataset({
        features: ['audio', 'Unnamed: 0', 'transcription'],
        num_rows: 14253
    })
})

In [7]:
# Index the row first, then the column, so only this one example is read and decoded.
# Selecting the whole 'audio' column first would decode every file in the split.
dataset_1['test'][0]['audio']

{'path': '/home/asr/projects/speach/tatar_asr_1/test/20.1.wav',
 'array': array([ 2.13623047e-04,  6.10351562e-05, -3.96728516e-04, ...,
         3.35693359e-04,  3.96728516e-04,  2.13623047e-04]),
 'sampling_rate': 16000}

In [8]:
# Run speech_file_to_array_fn through .map so each example gains a "speech" entry;
# the mapped result replaces dataset_1.
dataset_1 = dataset_1.map(speech_file_to_array_fn)

In [9]:
# Slice the rows first so only two examples are materialised, not the whole "speech" column.
inputs = processor(dataset_1['test'][:2]["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

In [10]:
# Forward pass without gradient tracking. Inputs are moved to whatever device the
# model currently lives on, so no device name is hard-coded here.
with torch.no_grad():
    logits = model(
        inputs.input_values.to(model.device),
        attention_mask=inputs.attention_mask.to(model.device),
    ).logits

In [11]:
# Greedy decoding: take the highest-scoring vocabulary id along the last axis.
predicted_ids = logits.argmax(dim=-1)

In [15]:
# Print the decoded predictions followed by the matching reference transcriptions.
for label, value in (
    ("Prediction:", processor.batch_decode(predicted_ids)),
    ("Reference:", dataset_1['test']["transcription"][:2]),
):
    print(label, value)

Prediction: ['шедөният зехнән килеханын сорады', 'сүзләр вәкдәләшеләр төммуена барды']
Reference: ['фидания цехның телефонын сорады', 'сүзләр вәгъдәләшүләр төн буена барды']


# Evaluation

In [21]:
! pip install jiwer -qq

In [24]:
# Word-error-rate metric used for the evaluation below.
# NOTE(review): newer `datasets` releases deprecate `load_metric` in favour of the
# separate `evaluate` package — confirm before upgrading the environment.
wer = load_metric("wer")
# The trailing semicolon keeps the notebook from echoing the model repr.
model.to("cuda");

In [25]:
def clean_sentence(sent: str) -> str:
    """Normalise a transcription so it can be compared word by word.

    Steps, in order: NFC-normalise, lowercase, map 'ё' to 'е', replace every
    non-alphabetic character with a space, and collapse runs of whitespace.

    Args:
        sent: Raw sentence text.

    Returns:
        The cleaned sentence: lowercase alphabetic words separated by single
        spaces (an empty string if no alphabetic characters remain).
    """
    # Compose decomposed characters first (e.g. 'и' + U+0306 -> 'й'). A bare
    # combining mark is not alphabetic and would otherwise split the word.
    sent = unicodedata.normalize("NFC", sent)
    sent = sent.lower()
    # 'ё' is equivalent to 'е'
    sent = sent.replace('ё', 'е')
    # replace non-alpha characters with space
    sent = "".join(ch if ch.isalpha() else " " for ch in sent)
    # remove repeated spaces
    return " ".join(sent.split())

In [26]:
# Accumulators for the evaluation: reference texts and model predictions.
targets, preds = [], []

In [35]:
# Start from empty lists so re-running this cell does not append a second copy
# of every result to the previous run's lists.
targets = []
preds = []

# Transcribe the test split one example at a time, collecting the cleaned
# reference and the greedy prediction for each.
for row in tqdm(dataset_1['test'], total=len(dataset_1['test'])):
    # Keep the cleaned reference in a local name instead of writing it back into `row`.
    reference = clean_sentence(row["transcription"])

    inputs = processor(row["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

    # Move the inputs to the device the model is on; no device name is hard-coded.
    with torch.no_grad():
        logits = model(
            inputs.input_values.to(model.device),
            attention_mask=inputs.attention_mask.to(model.device),
        ).logits

    pred_ids = torch.argmax(logits, dim=-1)

    targets.append(reference)
    # NOTE(review): predictions are stored as decoded, without clean_sentence —
    # confirm the asymmetry with the references is intended.
    preds.append(processor.batch_decode(pred_ids)[0])

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14253/14253 [11:16<00:00, 21.06it/s]


In [36]:
# Scale the metric by 100 and print it with two decimals.
# The previous spec "{:2f}" set a minimum width of 2, not a precision, so six decimals were printed.
print("WER: {:.2f}".format(100 * wer.compute(predictions=preds, references=targets)))

WER: 47.310808
