In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Fri Jul 26 16:21:46 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:01:00.0  On |                  N/A |
| 53%   43C    P8             27W /  350W |     173MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [3]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
dataset_name = "JRHuy/vivos-fleurs"

In [5]:
from datasets import load_dataset, DatasetDict

vi_asr = DatasetDict()

vi_asr["train"] = load_dataset(dataset_name, split="train+validation")
vi_asr["test"] = load_dataset(dataset_name, split="test")

print(vi_asr)

DatasetDict({
    train: Dataset({
        features: ['audio', 'transcription'],
        num_rows: 15015
    })
    test: Dataset({
        features: ['audio', 'transcription'],
        num_rows: 1617
    })
})


In [6]:
vi_asr["train"][0]

{'audio': {'path': 'VIVOSSPK27_066.wav',
  'array': array([0.        , 0.        , 0.        , ..., 0.01083374, 0.0128479 ,
         0.01464844]),
  'sampling_rate': 16000},
 'transcription': 'TÌNH YÊU THƯƠNG THẬT SỰ SỰ KIÊN TRÌ VÀ LÍ TƯỞNG TỐT ĐẸP NHẤT ĐỊNH SẼ CHIẾN THẮNG TẤT CẢ TRONG ĐÓ CÓ CẢ ĐÓI NGHÈO VÀ LẠC HẬU'}

In [7]:
sample = vi_asr["train"][0]

In [8]:
import IPython.display as ipd

ipd.Audio(sample['audio']['array'], rate=sample['audio']['sampling_rate'])

In [9]:
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small", language="vi", task="transcribe"
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [10]:
def prepare_dataset(example):
    audio = example["audio"]

    example = processor(
        audio=audio["array"],
        sampling_rate=audio["sampling_rate"],
        text=example["transcription"],
    )

    # compute input length of audio sample in seconds
    example["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    return example

In [11]:
vi_asr = vi_asr.map(
    prepare_dataset, remove_columns=vi_asr.column_names["train"], num_proc=1
)

In [12]:
vi_asr

DatasetDict({
    train: Dataset({
        features: ['input_features', 'labels', 'input_length'],
        num_rows: 15015
    })
    test: Dataset({
        features: ['input_features', 'labels', 'input_length'],
        num_rows: 1617
    })
})

In [13]:
max_input_length = 30.0


def is_audio_in_length_range(length):
    return length < max_input_length

In [14]:
vi_asr["train"] = vi_asr["train"].filter(
    is_audio_in_length_range,
    input_columns=["input_length"],
)

In [15]:
vi_asr["train"]

Dataset({
    features: ['input_features', 'labels', 'input_length'],
    num_rows: 15009
})

In [16]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [
            {"input_features": feature["input_features"][0]} for feature in features
        ]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [17]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [18]:
import evaluate

metric = evaluate.load("wer")

In [19]:
from transformers import WhisperForConditionalGeneration, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_8bit=True, # or using load_in_4bit 
    llm_int8_threshold=6.0, # for real-time apps with limited hardware resources, a lower threshold might be more beneficial but it may be potential accuracy loss
    # Skip modules should not undergo 8-bit quantization
    llm_int8_skip_modules=None, # for example, ["LayerNorm"]
)
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small", device_map="auto", quantization_config=bnb_config)



In [20]:
from functools import partial

# disable cache during training since it's incompatible with gradient checkpointing
model.config.use_cache = False

# set language and task for generation and re-enable cache
model.generate = partial(
    model.generate, language="vi", task="transcribe", use_cache=True
)

In [21]:
from peft import prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model,use_gradient_checkpointing=True)

In [22]:
def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)

<torch.utils.hooks.RemovableHandle at 0x7f9742eb0d10>

In [23]:
from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model

config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"], lora_dropout=0.1, bias="none", peft_type="SEQ_2_SEQ_LM")

model = get_peft_model(model, config)
model.print_trainable_parameters()

trainable params: 884,736 || all params: 242,619,648 || trainable%: 0.3647


In [24]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="stevehoang9/whisper-small-vi",  # name on the HF Hub
    overwrite_output_dir=True,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-3,
    lr_scheduler_type="constant_with_warmup",
#     lr_scheduler_type="linear", # if max_steps over 4000
    warmup_steps=50,
    max_steps=700,  # increase to 4000 if you have your own GPU or a Colab paid plan
    gradient_checkpointing=True,
    fp16=True,
    fp16_full_eval=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    generation_max_length=225,
    # optim="adamw_torch",
#     save_steps=500,
#     eval_steps=500,
    save_steps=100,
    eval_steps=100,
    logging_steps=50,
    # weigth_decay=0.01,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
#     metric_for_best_model="wer",
    greater_is_better=False,
    remove_unused_columns=False,  # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
    label_names=["labels"],  # same reason as above
    push_to_hub=True,
)

In [25]:
from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

# This callback helps to save only the adapter weights and remove the base model weights.
class SavePeftModelCallback(TrainerCallback):
    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)
        return control


trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=vi_asr["train"],
    eval_dataset=vi_asr["test"],
    data_collator=data_collator,
    # compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
    callbacks=[SavePeftModelCallback],
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False)


In [26]:
import warnings

warnings.filterwarnings('ignore', category=FutureWarning, message='`resume_download` is deprecated')

In [27]:
import torch.utils.checkpoint as checkpoint 

checkpoint.use_reentrant = False 

In [28]:
trainer.train()



Step,Training Loss,Validation Loss
100,1.4586,0.873886
200,0.5456,0.634125
300,0.4047,0.563164
400,0.2732,0.41197
500,0.2448,0.402061
600,0.2123,0.397244
700,0.2137,0.389582


Checkpoint destination directory stevehoang9/whisper-small-vi/checkpoint-100 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory stevehoang9/whisper-small-vi/checkpoint-200 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory stevehoang9/whisper-small-vi/checkpoint-300 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory stevehoang9/whisper-small-vi/checkpoint-400 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory stevehoang9/whisper-small-vi/checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory stevehoang9/whisper-small-vi/checkpoint-600 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination

TrainOutput(global_step=700, training_loss=0.5242434910365513, metrics={'train_runtime': 3548.4088, 'train_samples_per_second': 6.313, 'train_steps_per_second': 0.197, 'total_flos': 6.48386536955904e+18, 'train_loss': 0.5242434910365513, 'epoch': 1.49})

In [29]:
peft_model_id = "stevehoang9/whisper-small-vi-300steps"
model.push_to_hub(peft_model_id)

adapter_model.safetensors:   0%|          | 0.00/3.56M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/stevehoang9/whisper-small-vi-300steps/commit/5682f061928d15c75a4a3791b00ce9b7f0f10ea2', commit_message='Upload model', commit_description='', oid='5682f061928d15c75a4a3791b00ce9b7f0f10ea2', pr_url=None, pr_revision=None, pr_num=None)

In [30]:
from peft import PeftModel, PeftConfig
from transformers import WhisperForConditionalGeneration, Seq2SeqTrainer

peft_model_id = "stevehoang9/whisper-small-vi-300steps" # Use the same model ID as before.
peft_config = PeftConfig.from_pretrained(peft_model_id)
model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path, load_in_8bit=True, device_map="auto"
)
model = PeftModel.from_pretrained(model, peft_model_id)
model.config.use_cache = True

adapter_model.safetensors:   0%|          | 0.00/7.10M [00:00<?, ?B/s]

In [31]:
import gc
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers.models.whisper.english_normalizer import BasicTextNormalizer


model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="vi", task="transcribe")
eval_dataloader = DataLoader(vi_asr["test"], batch_size=8, collate_fn=data_collator)
# forced_decoder_ids = processor.get_decoder_prompt_ids(language="vi", task="transcribe")
normalizer = BasicTextNormalizer()

predictions = []
references = []
normalized_predictions = []
normalized_references = []

model.eval()
for step, batch in enumerate(tqdm(eval_dataloader)):
    with torch.cuda.amp.autocast():
        with torch.no_grad():
            generated_tokens = (
                model.generate(
                    input_features=batch["input_features"].to("cuda"),
                    # forced_decoder_ids=forced_decoder_ids,
                    max_new_tokens=255,
                )
                .cpu()
                .numpy()
            )
            labels = batch["labels"].cpu().numpy()
            labels = np.where(labels != -100, labels, processor.tokenizer.pad_token_id)
            decoded_preds = processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            decoded_labels = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)
            predictions.extend(decoded_preds)
            references.extend(decoded_labels)
            normalized_predictions.extend([normalizer(pred).strip() for pred in decoded_preds])
            normalized_references.extend([normalizer(label).strip() for label in decoded_labels])
        del generated_tokens, labels, batch
    gc.collect()
wer = 100 * metric.compute(predictions=predictions, references=references)
normalized_wer = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references)
eval_metrics = {"eval/wer": wer, "eval/normalized_wer": normalized_wer}

print(f"{wer=} and {normalized_wer=}")
print(eval_metrics)

  0%|          | 0/203 [00:00<?, ?it/s]

100%|██████████| 203/203 [16:04<00:00,  4.75s/it]

wer=44.62192872865727 and normalized_wer=41.01120265544425
{'eval/wer': 44.62192872865727, 'eval/normalized_wer': 41.01120265544425}



