In [2]:
from datasets import load_dataset
from typing import Any, Dict, List, Union
from datasets import load_metric
from dataclasses import dataclass
from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer,  WhisperProcessor

In [3]:
from pythainlp.tokenize import word_tokenize

In [4]:
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import gc

In [5]:
import librosa
import torch

In [None]:
test = load_dataset("csv", data_files="test.csv")["train"]
gowajee_test = load_dataset("csv", data_files="gowajee_test.csv")["train"]
health_test = load_dataset("csv", data_files="health_test.csv")["train"]
smart_home_test = load_dataset("csv", data_files="smart_home_test.csv")["train"]

In [7]:
model = WhisperForConditionalGeneration.from_pretrained("./whisper-tiny-thai/checkpoint-20000").to("cuda")

In [8]:
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")

In [9]:
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="Thai", task="transcribe")

In [10]:
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny", language="Thai", task="transcribe")

In [11]:
def load_audio(batch):
    audio, sr = librosa.load("/mnt/d/data/cv-corpus-13.0-2023-03-09/th/clips/" + batch["path"], sr=16000)
    audio = torch.from_numpy(audio)
    batch["input_features"] = audio
    batch["labels"] = batch["sentence"]
    return batch

def load_audio_gowajee(batch):
    audio, sr = librosa.load("/mnt/d/data/gowajee_v0.9.2/v0.9.2/" + batch["path"], sr=16000)
    audio = torch.from_numpy(audio)
    batch["input_features"] = audio
    batch["labels"] = batch["sentence"]
    return batch

def load_audio_smart_home(batch):
    audio, sr = librosa.load("/mnt/d/data/Dataset/Smarthome/Record/" + batch["path"], sr=16000)
    audio = torch.from_numpy(audio)
    batch["input_features"] = audio
    batch["labels"] = batch["sentence"]
    return batch

def load_audio_health(batch):
    audio, sr = librosa.load("/mnt/d/data/Dataset/Healthcare/Record/" + batch["path"], sr=16000)
    audio = torch.from_numpy(audio)
    batch["input_features"] = audio
    batch["labels"] = batch["sentence"]
    return batch

In [12]:
test = test.map(load_audio, remove_columns=["path", "sentence"], num_proc=8)
gowajee_test = gowajee_test.map(load_audio_gowajee, remove_columns=["path", "sentence"], num_proc=8)
health_test = health_test.map(load_audio_health, remove_columns=["path", "sentence"], num_proc=8)
smart_home_test = smart_home_test.map(load_audio_smart_home, remove_columns=["path", "sentence"], num_proc=8)

Loading cached processed dataset at /home/jui/.cache/huggingface/datasets/csv/default-bf161c8f204ed3d4/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-fc77c088ad08ab1d_*_of_00008.arrow
Loading cached processed dataset at /home/jui/.cache/huggingface/datasets/csv/default-a21d5816bce7a1de/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-8987b10fb65b3ae4_*_of_00008.arrow
Loading cached processed dataset at /home/jui/.cache/huggingface/datasets/csv/default-25577eb20c5736fb/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-819123d436bb3ea4_*_of_00008.arrow
Loading cached processed dataset at /home/jui/.cache/huggingface/datasets/csv/default-f5a8e28433f5a511/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-760e3589aa3fd906_*_of_00008.arrow


In [13]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        audios = [self.processor.feature_extractor(feature["input_features"], sampling_rate=16000).input_features[0] for feature in features]
        sentences = [feature["labels"] for feature in features]

        input_features = [{"input_features": audio} for audio in audios]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        batch["labels"] = sentences

        return batch

In [14]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [15]:
cer_metric = load_metric("cer")
wer_metric = load_metric("wer")

  cer_metric = load_metric("cer")


In [16]:
eval_dataloader = DataLoader(test, batch_size=32, collate_fn=data_collator)

model.eval()
for step, batch in enumerate(tqdm(eval_dataloader)):
    with torch.no_grad():
        generated_tokens = (
            model.generate(
                input_features=batch["input_features"].to("cuda"),
                max_new_tokens=255,
                language="Thai"
            )
            .cpu()
            .numpy()
        )
        labels = batch["labels"]
        transcriptions = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

        cer_metric.add_batch(predictions=[pred_str.replace(" ", "") for pred_str in transcriptions], references=[label_str.replace(" ", "") for label_str in labels])

        pred_str_newmm = [word_tokenize(text=e, engine='newmm', keep_whitespace=False) for e in transcriptions]
        label_str_newmm = [word_tokenize(text=e, engine='newmm', keep_whitespace=False) for e in labels]
        wer_metric.add_batch(predictions=pred_str_newmm, references=label_str_newmm)
    del generated_tokens, labels, batch
    gc.collect()
wer = 100 * wer_metric.compute()
cer = 100 * cer_metric.compute()
print(f"wer: {wer}")
print(f"cer: {cer}")

  0%|                                                                                            | 0/90 [00:00<?, ?it/s]2023-05-27 13:28:04.178649: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-05-27 13:28:05.270520: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
100%|███████████████████████████████████████████████████████████████████████████████████| 90/90 [06:11<00:00,  4.13s/it]

wer: 23.144494550920708
cer: 6.740680318230452





In [17]:
eval_dataloader = DataLoader(gowajee_test, batch_size=32, collate_fn=data_collator)

model.eval()
for step, batch in enumerate(tqdm(eval_dataloader)):
    with torch.no_grad():
        generated_tokens = (
            model.generate(
                input_features=batch["input_features"].to("cuda"),
                max_new_tokens=255,
                language="Thai"
            )
            .cpu()
            .numpy()
        )
        labels = batch["labels"]
        transcriptions = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

        cer_metric.add_batch(predictions=[pred_str.replace(" ", "") for pred_str in transcriptions], references=[label_str.replace(" ", "") for label_str in labels])

        pred_str_newmm = [word_tokenize(text=e, engine='newmm', keep_whitespace=False) for e in transcriptions]
        label_str_newmm = [word_tokenize(text=e, engine='newmm', keep_whitespace=False) for e in labels]
        wer_metric.add_batch(predictions=pred_str_newmm, references=label_str_newmm)
    del generated_tokens, labels, batch
    gc.collect()
wer = 100 * wer_metric.compute()
cer = 100 * cer_metric.compute()
print(f"wer: {wer}")
print(f"cer: {cer}")

100%|███████████████████████████████████████████████████████████████████████████████████| 32/32 [01:25<00:00,  2.67s/it]

wer: 24.792643346556076
cer: 11.394521138912856





In [18]:
eval_dataloader = DataLoader(health_test, batch_size=32, collate_fn=data_collator)

model.eval()
for step, batch in enumerate(tqdm(eval_dataloader)):
    with torch.no_grad():
        generated_tokens = (
            model.generate(
                input_features=batch["input_features"].to("cuda"),
                max_new_tokens=255,
                language="Thai"
            )
            .cpu()
            .numpy()
        )
        labels = batch["labels"]
        transcriptions = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

        cer_metric.add_batch(predictions=[pred_str.replace(" ", "") for pred_str in transcriptions], references=[label_str.replace(" ", "") for label_str in labels])

        pred_str_newmm = [word_tokenize(text=e, engine='newmm', keep_whitespace=False) for e in transcriptions]
        label_str_newmm = [word_tokenize(text=e, engine='newmm', keep_whitespace=False) for e in labels]
        wer_metric.add_batch(predictions=pred_str_newmm, references=label_str_newmm)
    del generated_tokens, labels, batch
    gc.collect()
wer = 100 * wer_metric.compute()
cer = 100 * cer_metric.compute()
print(f"wer: {wer}")
print(f"cer: {cer}")

100%|███████████████████████████████████████████████████████████████████████████████████| 30/30 [01:36<00:00,  3.21s/it]

wer: 13.28364752301622
cer: 4.143479718404291





In [19]:
eval_dataloader = DataLoader(smart_home_test, batch_size=32, collate_fn=data_collator)

model.eval()
for step, batch in enumerate(tqdm(eval_dataloader)):
    with torch.no_grad():
        generated_tokens = (
            model.generate(
                input_features=batch["input_features"].to("cuda"),
                max_new_tokens=255,
                language="Thai"
            )
            .cpu()
            .numpy()
        )
        labels = batch["labels"]
        transcriptions = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

        cer_metric.add_batch(predictions=[pred_str.replace(" ", "") for pred_str in transcriptions], references=[label_str.replace(" ", "") for label_str in labels])

        pred_str_newmm = [word_tokenize(text=e, engine='newmm', keep_whitespace=False) for e in transcriptions]
        label_str_newmm = [word_tokenize(text=e, engine='newmm', keep_whitespace=False) for e in labels]
        wer_metric.add_batch(predictions=pred_str_newmm, references=label_str_newmm)
    del generated_tokens, labels, batch
    gc.collect()
wer = 100 * wer_metric.compute()
cer = 100 * cer_metric.compute()
print(f"wer: {wer}")
print(f"cer: {cer}")

100%|███████████████████████████████████████████████████████████████████████████████████| 30/30 [01:31<00:00,  3.06s/it]

wer: 12.992943129929433
cer: 3.4138499936545537





In [None]:
model.push_to_hub(repo_id="juierror/whisper-tiny-thai")

In [None]:
tokenizer.push_to_hub(repo_id="juierror/whisper-tiny-thai")

In [None]:
feature_extractor.push_to_hub(repo_id="juierror/whisper-tiny-thai")

In [None]:
processor.push_to_hub(repo_id="juierror/whisper-tiny-thai")