In [5]:
import torch, torchaudio, evaluate
from transformers import TrainingArguments, Trainer, Wav2Vec2ForCTC, Wav2Vec2Processor
from datasets import load_from_disk
import pandas as pd
import numpy as np

processor = Wav2Vec2Processor.from_pretrained("wav2vec2-xlsr53-TH-cmv-processor")
model = Wav2Vec2ForCTC.from_pretrained("/project/lt200007-tspai2/thepeach/wav2vec2-xlsr53-TH-cmv-ckp3/checkpoint-48000").to("cuda")
wer_metric = evaluate.load("metric/wer.py")
cer_metric = evaluate.load("metric/cer.py")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
def evaluate(batch):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = batch['input_values']
    inputs = inputs.to(device)

    with torch.no_grad():
        logits = model(inputs)

    pred_ids = torch.argmax(logits.logits, dim=-1)
    batch["pred_sentence"] = processor.batch_decode(pred_ids)
    return batch

In [7]:
model.eval()

Wav2Vec2ForCTC(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (1): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (2): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (3): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elemen

In [8]:
dataset = load_from_disk('dataset_wav2vec2/dataset').with_format("torch")

In [9]:
from tqdm import tqdm

In [10]:
sentences = []
pred_sentences= []

for i in tqdm(range(len(dataset['test'])), desc="Processing test data"):
    data = dataset['test'][i]['input_values']
    with torch.no_grad():
        logits = model(data.reshape(1, -1).to("cuda")).logits
    predicted_ids = torch.argmax(logits, dim=-1)

    sentences.append(" ".join(processor.batch_decode(dataset['test'][i]['labels'])).replace("  ", "x").replace(" ", "").replace("x", " "))
    pred_sentences.append(processor.batch_decode(predicted_ids.cpu().numpy().tolist())[0])

Processing test data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21030/21030 [07:40<00:00, 45.64it/s]


In [11]:
result_df = pd.DataFrame({
    'sentence': sentences,
    'pred_sentence': pred_sentences,
})

In [12]:
wer_metric.compute(predictions=result_df.pred_sentence,references=result_df.sentence)

0.11236078815322482

In [13]:
result_df['sentence_join']= result_df['sentence'].apply(lambda x : x.replace(' ',''))
result_df['pred_sentence_join']= result_df['pred_sentence'].apply(lambda x : x.replace(' ',''))

In [14]:
cer_metric.compute(predictions=result_df.pred_sentence_join,references=result_df.sentence_join)

0.034740422415786634