In [None]:
!pip install accelerate -U
!pip install transformers soundfile datasets jiwer gdown torchmetrics
!mkdir ./dataset

In [None]:
import os

import gdown


def drive_download(idx, output, skip_existing=False):
    """Download a Google Drive file by its file id using ``gdown``.

    Args:
        idx: Google Drive file id.
        output: Local path the file is written to.
        skip_existing: When True and ``output`` already exists, skip the download so a
            notebook re-run does not fetch the archive again.

    Returns:
        Whatever ``gdown.download`` returns, or ``output`` when the download was skipped.
    """
    if skip_existing and os.path.exists(output):
        return output
    url = 'https://drive.google.com/uc?id=' + idx
    return gdown.download(url, output, quiet=False)


# Idempotent replacement for `!mkdir ./dataset` (which errors if the directory exists).
os.makedirs("./dataset", exist_ok=True)
drive_download("1ZBL3h6bHMmd8MIUNXqg72PucUkC9ZSWJ", "./dataset/train_data.zip", skip_existing=True)
drive_download("1ZepptsTrVSjQEx-dpBBmQ2b7xYFLn_64", "./dataset/public_test.zip", skip_existing=True);
# drive_download("1K_07kix1OgBGO2FNPh-Lxqr1yLbtqFYt", "./dataset/train.jsonl")

In [None]:
!unzip ./dataset/public_test.zip -d ./dataset/test
!unzip ./dataset/train_data.zip -d ./dataset/train

In [None]:
# Fetch the data through the project-local `utils` helper (its implementation is not shown here).
# NOTE(review): this looks like an alternative to the gdown cell above — confirm which one is
# meant to run; presumably this is also what provides train.jsonl (its gdown line is commented out).
import utils
utils.download_data()

In [None]:
!unzip /kaggle/working/dataset/train_data.zip -d /kaggle/working/dataset/train

In [None]:
import os, glob, re, torch, json, utils, numpy as np, soundfile as sf
from functools import partial
from datasets import load_metric
from torch.utils.data import DataLoader, Dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor, Seq2SeqTrainingArguments, Seq2SeqTrainer

In [None]:
class WhisperDataset(Dataset):
    """Map-style dataset that turns audio files (and optional transcripts) into Whisper inputs.

    Items are built lazily: each ``__getitem__`` reads one audio file from disk and runs it
    through ``processor``. When labels are given, the cleaned transcript is tokenized as well.

    Args:
        processor: A ``WhisperProcessor`` (or compatible callable) invoked as
            ``processor(speech, text=..., sampling_rate=...)``.
        root_path: Directory that the entries of ``files_id`` are relative to.
        files_id: Sequence of audio file names, one per item.
        labels: Optional sequence aligned with ``files_id``. Each element must be a mapping
            with a ``"sentence"`` transcript. Use ``None`` for unlabeled (inference) data.
    """

    # Punctuation removed from transcripts before tokenization. A raw string avoids Python's
    # invalid-escape-sequence warnings for ``\,``, ``\!`` etc.; the matched set is unchanged.
    _CHARS_TO_IGNORE = re.compile(r'[\,\?\.\!\-\;\:\"]')

    def __init__(self, processor, root_path, files_id, labels=None):
        self.processor = processor
        self.root_path = root_path
        self.files_id = files_id
        self.labels = labels

    @classmethod
    def _clean_text(cls, txt):
        """Lower-case ``txt`` and strip the punctuation listed in ``_CHARS_TO_IGNORE``."""
        return cls._CHARS_TO_IGNORE.sub('', txt.lower())

    def _process_sound_file(self, idx):
        """Read the ``idx``-th audio file and return the processor output for it."""
        speech, samplerate = sf.read(os.path.join(self.root_path, self.files_id[idx]))
        # soundfile returns a (frames, channels) array for multi-channel audio, while the
        # Whisper feature extractor treats a 2-D array as a batch. Downmix to mono.
        if speech.ndim > 1:
            speech = speech.mean(axis=1)
        label = self._clean_text(self.labels[idx]["sentence"]) if self.labels is not None else None
        return self.processor(speech, text=label, sampling_rate=samplerate)

    def __len__(self):
        return len(self.files_id)

    def __getitem__(self, idx):
        """Return ``{"input_features", "labels" (None when unlabeled), "file_id"}`` for item ``idx``."""
        data = self._process_sound_file(idx)
        return {
            "input_features": data.input_features,
            "labels": data.labels if "labels" in data else None,
            "file_id": self.files_id[idx],
        }

In [None]:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union

class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate ``WhisperDataset`` items into padded tensors for training or inference.

    Args:
        processor: A ``WhisperProcessor``. Its ``feature_extractor.pad`` and ``tokenizer.pad``
            are used to batch the audio features and the label ids.
        decoder_start_token_id: The token the model prepends when it shifts ``labels`` right.
            If every label row in a batch starts with it, the token is stripped so it is not
            duplicated. When ``None``, it is looked up from the tokenizer as
            ``<|startoftranscript|>``. For Whisper tokenizers, ``bos_token_id`` is
            ``<|endoftext|>``, so it cannot be used for this check.
    """

    def __init__(self, processor=None, decoder_start_token_id=None):
        self.processor = processor
        self.decoder_start_token_id = decoder_start_token_id

    def _start_token_id(self):
        """Return the decoder start token id, resolving it from the tokenizer when unset."""
        if self.decoder_start_token_id is not None:
            return self.decoder_start_token_id
        return self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")

    def __call__(self, features):
        """Pad a list of dataset items into one batch.

        Unlabeled batches (first item's ``"labels"`` is ``None``) also get a ``"file_id"`` list.
        Labeled batches get ``"labels"``, with padding positions set to -100 so the loss ignores them.
        """
        # Each item's input_features holds a single feature matrix; unwrap it before padding.
        input_features = [{"input_features": feature["input_features"][0]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        if features[0]["labels"] is None:
            batch["file_id"] = [feature["file_id"] for feature in features]
            return batch
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        # The tokenizer already put the start token at position 0. The model prepends it again
        # when shifting labels right, so drop it here.
        if bool((labels[:, 0] == self._start_token_id()).all()):
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch

In [None]:
def train_test_split(processor, root_path, notation_file, test_size=0.3, seed=42):
    """Build a ``WhisperDataset`` from an annotation file and split it into train/validation.

    Args:
        processor: Processor forwarded to ``WhisperDataset``.
        root_path: Directory containing the audio files named in the annotations.
        notation_file: File read with ``utils.load_annotation``. Each entry must provide
            ``"file"`` (and ``"sentence"``, which ``WhisperDataset`` reads).
        test_size: Fraction of items placed in the validation split, in ``[0, 1]``.
        seed: Seed for the split's generator, so the same items land in validation across
            kernel restarts. Pass ``None`` for an unseeded (non-reproducible) split.

    Returns:
        ``(train_set, valid_set)`` as returned by ``torch.utils.data.random_split``.

    Raises:
        ValueError: If ``test_size`` is outside ``[0, 1]``.
    """
    if not 0.0 <= test_size <= 1.0:
        raise ValueError(f"test_size must be in [0, 1], got {test_size!r}")
    notations = utils.load_annotation(notation_file)
    dataset = WhisperDataset(processor, root_path, [i["file"] for i in notations], notations)
    N = len(dataset)
    print(f"Len dataset: {N}")
    # round() instead of int() truncation: float error (e.g. 0.29 * 100 == 28.999...)
    # would otherwise shift an item between splits.
    valid_size = int(round(N * test_size))
    train_size = N - valid_size
    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_set, valid_set = torch.utils.data.random_split(dataset, [train_size, valid_size], **split_kwargs)
    return train_set, valid_set

In [None]:
# Processor (feature extractor + tokenizer) used to build training features and label ids.
# language/task presumably configure the tokenizer's Vietnamese / transcribe prefix tokens.
processor = WhisperProcessor.from_pretrained("geninhu/whisper-medium-vi", language="Vietnamese", task="transcribe")

In [None]:
# Split the labeled training audio 70/30 into train/validation and show the split sizes.
# NOTE(review): train.jsonl is not fetched by the gdown cell (that line is commented out);
# presumably it comes from utils.download_data() — confirm.
train_ds, valid_ds = train_test_split(processor, "/kaggle/working/dataset/train/Train/", "/kaggle/working/dataset/train.jsonl", test_size=0.3)
len(train_ds), len(valid_ds)

In [None]:
# Alternative processor kept for reference; the active processor comes from geninhu/whisper-medium-vi.
# processor = WhisperProcessor.from_pretrained("GeoffVdr/whisper-medium-nlcv11", language="Vietnamese", task="transcribe")
# Base checkpoint to fine-tune.
# NOTE(review): weights come from a different repo than the processor that tokenizes the labels —
# confirm both share the same Whisper tokenizer vocabulary.
model = WhisperForConditionalGeneration.from_pretrained("GeoffVdr/whisper-medium-nlcv11")

In [None]:
# Freeze the whole encoder-decoder backbone (model.model) and keep only the output
# projection (proj_out) trainable.
# NOTE(review): in HF Whisper, proj_out.weight is typically tied to the decoder token embedding,
# so re-enabling it also unfreezes that embedding — confirm this is intended.
model.model.requires_grad_(False)
model.proj_out.requires_grad_(True)

In [None]:
wer_metric = load_metric("wer")


def compute_metrics(pred):
    """Compute word error rate for ``Seq2SeqTrainer`` evaluation.

    Args:
        pred: ``EvalPrediction`` with ``predictions`` and ``label_ids``. With
            ``predict_with_generate=True`` (as set in this notebook's training arguments),
            ``predictions`` are generated token ids of shape (batch, seq). A 3-D logits array
            (batch, seq, vocab) is also accepted and reduced with argmax.

    Returns:
        ``{"wer": float}`` computed on decoded text with special tokens removed.
    """
    pred_ids = pred.predictions
    # Some models return a tuple whose first element holds the ids/logits.
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]
    # argmax only applies to logits. Generated ids are already token ids; taking argmax over
    # them would collapse each sequence to a single position index.
    if pred_ids.ndim == 3:
        pred_ids = np.argmax(pred_ids, axis=-1)
    pad_id = processor.tokenizer.pad_token_id
    # -100 marks ignored label positions and may also appear as cross-batch padding in
    # predictions. It is not a valid token id, so map it to the pad token before decoding.
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_id)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_id)
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

In [None]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper_v1.0",
    # --- optimisation schedule: 1000 warmup steps, then a constant LR until step 5000 ---
    max_steps=5000,
    learning_rate=1e-4,
    lr_scheduler_type="constant_with_warmup",
    warmup_steps=1000,
    # --- batching, memory and precision ---
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    fp16=True,
    fp16_full_eval=True,
    # --- evaluation: generate transcripts every 300 steps and score them via compute_metrics ---
    evaluation_strategy="steps",  # newer transformers releases call this `eval_strategy`
    eval_steps=300,
    predict_with_generate=True,
    generation_max_length=256,
    # --- checkpointing and logging ---
    save_steps=300,
    logging_steps=300,
    save_total_limit=1,
    # Reload the checkpoint with the lowest WER once training ends.
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
)

In [None]:
# Collator that pads audio features and label ids for the trainer.
# NOTE(review): the next cell rebuilds the same collator; this cell is redundant.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [None]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
# Wire the model, the train/validation splits, the collator and the WER metric into a
# Seq2SeqTrainer driven by training_args.
# NOTE: newer transformers releases deprecate `tokenizer=` in favour of `processing_class=`.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    data_collator=data_collator,
    tokenizer=processor,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning; evaluation, checkpointing and logging cadence come from training_args.
trainer.train()

In [None]:
# Reload for inference. Use the same processor repo that tokenized the training labels so
# decoding matches training (the GeoffVdr repo only supplied the initial model weights).
# NOTE(review): 5000 is not a multiple of save_steps=300, so checkpoint-5000 exists only if the
# trainer saved at the final step — verify the directory before running.
CHECKPOINT_DIR = "./whisper_v1.0/checkpoint-5000"
processor = WhisperProcessor.from_pretrained("geninhu/whisper-medium-vi", language="Vietnamese", task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained(CHECKPOINT_DIR)

In [None]:
# Unlabeled public-test audio. Sorted because os.listdir order is arbitrary; this keeps batch
# composition reproducible across runs.
TEST_DIR = "./dataset/test/public_test"
test_set = WhisperDataset(processor, TEST_DIR, sorted(os.listdir(TEST_DIR)))
len(test_set)

In [None]:
# Batch the unlabeled test set. The collator adds "file_id" for unlabeled items, and
# shuffle=False keeps batches in test_set order.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
test_loader = DataLoader(test_set, batch_size=16, shuffle=False, collate_fn=data_collator)

In [None]:
def whisper_inference(model, test_loader, processor, device=None, **generate_kwargs):
    """Transcribe every batch from ``test_loader`` with ``model.generate``.

    Args:
        model: A ``WhisperForConditionalGeneration`` (or any module with ``generate``).
            NOTE: it is switched to eval mode and moved to ``device`` in place, and on CUDA
            it is also converted to fp16 in place.
        test_loader: Iterable of batches with ``"input_features"`` (tensor) and ``"file_id"``
            (list of names), e.g. a DataLoader using ``DataCollatorSpeechSeq2SeqWithPadding``
            on unlabeled data. Must support ``len()``.
        processor: Object whose ``batch_decode`` turns generated token ids into text.
        device: Target device (``torch.device`` or string). Defaults to ``cuda:0`` when
            available, otherwise CPU.
        **generate_kwargs: Extra keyword arguments forwarded to ``model.generate``
            (e.g. ``num_beams``, ``max_new_tokens``).

    Returns:
        Dict mapping each file id to its transcription.
    """
    import contextlib

    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    device = torch.device(device)
    # fp16 only on GPU: many CPU kernels have no half-precision implementation.
    use_fp16 = device.type == "cuda"
    model.eval()
    model = model.to(device)
    if use_fp16:
        model.half()
    model_dtype = next(model.parameters()).dtype
    amp_ctx = torch.autocast("cuda", dtype=torch.float16) if use_fp16 else contextlib.nullcontext()

    n_batches = len(test_loader)
    pred_sentences = {}
    with torch.set_grad_enabled(False), amp_ctx:
        for idx, batch in enumerate(test_loader, 1):
            # Match the model's dtype so fp32 features never meet fp16 weights.
            X_test = batch["input_features"].to(device, dtype=model_dtype)
            generated_ids = model.generate(inputs=X_test, **generate_kwargs)
            transcriptions = processor.batch_decode(generated_ids, skip_special_tokens=True)
            pred_sentences.update(zip(batch["file_id"], transcriptions))
            # Overwrite the progress line in place; end with a newline after the last batch.
            print(f"\r {idx} / {n_batches}", end="\n" if idx == n_batches else "")
    return pred_sentences

In [None]:
# Let whisper_inference choose the device (cuda:0 when available, else CPU) instead of
# hard-coding cuda:0, which fails on CPU-only runtimes.
pred_sentences = whisper_inference(model, test_loader, processor)

In [None]:
# Show the number of transcriptions plus a small sample instead of dumping the whole dict.
len(pred_sentences), dict(list(pred_sentences.items())[:5])

In [None]:
# Persist predictions as UTF-8 JSON. ensure_ascii=False writes Vietnamese characters as-is
# instead of \u escapes. The `with` block closes the file, so no explicit close() is needed.
with open("./whisper_test_sentences.json", "w", encoding="utf-8") as f:
    json.dump(pred_sentences, f, ensure_ascii=False)