In [None]:
from datasets import load_dataset, DatasetDict

from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer
from transformers import WhisperProcessor
from transformers import WhisperForConditionalGeneration

from transformers import Seq2SeqTrainingArguments
from IPython.display import Audio

from dataclasses import dataclass
from typing import Any, Dict, List, Union

In [2]:
import librosa
import torch
import augment
import numpy as np
import evaluate

In [3]:
from datasets import load_metric

In [None]:
# Read both CSV manifests the same way; each file is exposed under the
# "train" key of the object returned by load_dataset, hence the indexing.
train, dev = (
    load_dataset("csv", data_files=csv_file)["train"]
    for csv_file in ("train_merge.csv", "dev_merge.csv")
)

In [5]:
# Stand-alone feature extractor for the whisper-tiny checkpoint.
# NOTE(review): the later visible cells go through `processor.feature_extractor`
# instead of this variable.
feature_extractor = WhisperFeatureExtractor.from_pretrained(
    "openai/whisper-tiny"
)

In [6]:
# Tokenizer configured for Thai transcription; it is the module-level
# `tokenizer` that `load_audio_batch` uses to encode transcripts.
tokenizer = WhisperTokenizer.from_pretrained(
    "openai/whisper-tiny",
    language="Thai",
    task="transcribe",
)

In [7]:
# Sanity check: encode the first training transcript and make sure decoding
# (with special tokens stripped) gives back exactly the same text.
input_str = train[0]["sentence"]
labels = tokenizer(input_str).input_ids

# Decode twice: first keeping the special tokens, then dropping them.
decoded_with_special, decoded_str = (
    tokenizer.decode(labels, skip_special_tokens=skip_special)
    for skip_special in (False, True)
)

print(
    f"Input:                 {input_str}",
    f"Decoded w/ special:    {decoded_with_special}",
    f"Decoded w/out special: {decoded_str}",
    f"Are equal:             {input_str == decoded_str}",
    sep="\n",
)

Input:                 ใครเป็นผู้รับ
Decoded w/ special:    <|startoftranscript|><|th|><|transcribe|><|notimestamps|>ใครเป็นผู้รับ<|endoftext|>
Decoded w/out special: ใครเป็นผู้รับ
Are equal:             True


In [8]:
# Processor for the same checkpoint; the data collator reads its
# `.feature_extractor` and `.tokenizer` attributes.
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-tiny",
    language="Thai",
    task="transcribe",
)

In [9]:
def load_audio_batch(batch, sampling_rate=16000):
    """Load the audio files of a batch and tokenize their transcripts.

    Written for ``Dataset.map(..., batched=True)``: every value in ``batch``
    is a list with one entry per example.

    Args:
        batch: Mapping with a ``"path"`` list of audio file paths and a
            ``"sentence"`` list of transcript strings.
        sampling_rate: Value forwarded to ``librosa.load`` as ``sr``.
            Defaults to 16000, the value that used to be hard-coded here.

    Returns:
        The same ``batch`` object, with two keys added:
        ``"input_features"`` (the waveforms returned by ``librosa.load``, one
        per path) and ``"labels"`` (``tokenizer(...).input_ids`` for the
        sentences).

    Note:
        Despite its name, ``"input_features"`` holds raw waveforms at this
        stage; the feature extractor is only applied later, in the data
        collator. Transcripts are encoded with the module-level ``tokenizer``.
    """
    # librosa.load returns (waveform, rate); only the waveform is kept.
    batch["input_features"] = [
        librosa.load(path, sr=sampling_rate)[0] for path in batch["path"]
    ]
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

In [10]:
# Apply the same preprocessing to both splits: decode audio and tokenize text
# in batches of 8 across 4 worker processes, dropping the raw columns.
train, dev = (
    split.map(
        load_audio_batch,
        remove_columns=["path", "sentence"],
        batched=True,
        batch_size=8,
        num_proc=4,
    )
    for split in (train, dev)
)

Loading cached processed dataset at /home/jui/.cache/huggingface/datasets/csv/default-2f1816aee5d0c184/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-1145b0691893d1e3_*_of_00004.arrow
Loading cached processed dataset at /home/jui/.cache/huggingface/datasets/csv/default-149a15838ff7908b/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-e0b7068f2cec181a_*_of_00004.arrow


In [11]:
import random

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Turn a list of raw-audio examples into one padded training batch.

    Each example is a dict with ``"input_features"`` (a waveform; the feature
    extractor is applied here, at collation time) and ``"labels"`` (token ids).
    Audio and labels are padded separately because they have different
    lengths and need different padding methods.

    Leading-token handling: when every label row starts with the same start
    token it is cut off, since (per the original comment) it is appended again
    later. By default the token compared against is ``tokenizer.bos_token_id``.
    NOTE(review): the tokenizer output earlier in this notebook shows labels
    starting with ``<|startoftranscript|>``; if that id differs from
    ``bos_token_id`` the default check never fires. In that case construct the
    collator with ``decoder_start_token_id=model.config.decoder_start_token_id``
    (confirm against the installed transformers version).
    """

    # Object exposing `.feature_extractor` and `.tokenizer`.
    processor: Any
    # Token id stripped from the front of the labels when every row starts
    # with it. None keeps the original behaviour (tokenizer.bos_token_id).
    decoder_start_token_id: Union[int, None] = None
    # Sampling rate reported to the feature extractor for every waveform.
    sampling_rate: int = 16000

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` into a batch dict.

        Args:
            features: One dict per example with ``"input_features"`` and
                ``"labels"`` entries.

        Returns:
            The padded feature-extractor output with an extra ``"labels"``
            tensor in which padded positions are set to -100.
        """
        extractor = self.processor.feature_extractor
        tokenizer = self.processor.tokenizer

        # Run the feature extractor on each waveform on its own, then let the
        # extractor stack the per-example results into tensors.
        input_features = [
            {
                "input_features": extractor(
                    feature["input_features"], sampling_rate=self.sampling_rate
                ).input_features[0]
            }
            for feature in features
        ]
        batch = extractor.pad(input_features, return_tensors="pt")

        # Pad the tokenized label sequences to the longest one in the batch.
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 so those positions are ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Cut the start token if the earlier tokenization step put it at the
        # front of every row.
        start_token_id = self.decoder_start_token_id
        if start_token_id is None:
            start_token_id = tokenizer.bos_token_id
        if (
            start_token_id is not None
            and labels.shape[1] > 0
            and (labels[:, 0] == start_token_id).all().cpu().item()
        ):
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [12]:
# Collator instance handed to the trainer; it pads audio features and labels
# per batch using the processor's feature extractor and tokenizer.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
)

In [15]:
# Load the pretrained whisper-tiny weights, then move the model to the GPU.
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model = model.cuda()

In [16]:
# Clear the checkpoint's generation constraints before fine-tuning:
# an empty suppress list, and no forced decoder prompt ids.
model.config.suppress_tokens = []
model.config.forced_decoder_ids = None

In [17]:
training_args = Seq2SeqTrainingArguments(
    # --- output / bookkeeping ---
    output_dir="./whisper-tiny-thai",
    save_total_limit=3,
    push_to_hub=False,
    report_to=["tensorboard"],
    # --- optimisation ---
    num_train_epochs=5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,  # 8 x 4 = 32 examples per optimizer step per device
    learning_rate=1e-5,
    warmup_steps=500,
    fp16=True,
    # --- evaluation ---
    evaluation_strategy="steps",
    per_device_eval_batch_size=32,
    predict_with_generate=True,
    generation_max_length=225,
    greater_is_better=False,
    # --- cadence: evaluate, checkpoint and log on the same 2500-step schedule ---
    eval_steps=2500,
    save_steps=2500,
    logging_steps=2500,
)

In [18]:
from transformers import Seq2SeqTrainer

# Wire model, arguments, data and collator together.
# NOTE(review): no `compute_metrics` is passed, so evaluation presumably only
# reports loss — confirm whether a WER/CER metric was intended.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train,
    eval_dataset=dev,
    tokenizer=processor.feature_extractor,
)

In [None]:
# Start fine-tuning with the trainer built in the previous cell. Left as the
# cell's last expression so the notebook displays the value train() returns.
trainer.train()