In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Sun Aug  4 18:11:58 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:01:00.0  On |                  N/A |
|  0%   47C    P5             91W /  350W |     324MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [3]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
train_path = '/media/sanslab/Data/stevehoang/funnyproject/speech-to-command-with-whisper/data/vivos/train'
test_path = '/media/sanslab/Data/stevehoang/funnyproject/speech-to-command-with-whisper/data/vivos/test'

In [5]:
from data_preparation.load_dataset import load_dataset

vivos = load_dataset(train_path=train_path, test_path=test_path)

In [6]:
vivos

DatasetDict({
    train: Dataset({
        features: ['audio', 'transcription', 'gender'],
        num_rows: 11660
    })
    test: Dataset({
        features: ['audio', 'transcription', 'gender'],
        num_rows: 760
    })
})

In [7]:
vivos = vivos.remove_columns(["gender"])

In [8]:
# from datasets import load_dataset, DatasetDict

# vi_asr = DatasetDict()

# vi_asr["train"] = load_dataset(dataset_name, split="train+validation")
# vi_asr["test"] = load_dataset(dataset_name, split="test")

# print(vi_asr)

In [9]:
from datasets import Audio 

vivos = vivos.cast_column("audio", Audio(sampling_rate=16000))

In [10]:
vivos["train"].features

{'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None),
 'transcription': Value(dtype='string', id=None)}

In [11]:
vivos["train"][0]

{'audio': {'path': '/media/sanslab/Data/stevehoang/funnyproject/speech-to-command-with-whisper/data/vivos/train/waves/VIVOSSPK01/VIVOSSPK01_T018.wav',
  'array': array([ 0.00000000e+00,  0.00000000e+00, -3.05175781e-05, ...,
          3.05175781e-05, -1.83105469e-04, -3.35693359e-04]),
  'sampling_rate': 16000},
 'transcription': 'NHỮNG SAI LẦM KHI ĐI CHỌN KHÁCH SẠN'}

In [12]:
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small", language="vi", task="transcribe"
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [13]:
def prepare_dataset(example):
    audio = example["audio"]

    example = processor(
        audio=audio["array"],
        sampling_rate=audio["sampling_rate"],
        text=example["transcription"],
    )

    # compute input length of audio sample in seconds
    example["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    return example

In [14]:
vivos = vivos.map(
    prepare_dataset, remove_columns=vivos.column_names["train"], num_proc=1
)

Map:   0%|          | 0/11660 [00:00<?, ? examples/s]

Map:   0%|          | 0/760 [00:00<?, ? examples/s]

In [15]:
vivos 

DatasetDict({
    train: Dataset({
        features: ['input_features', 'labels', 'input_length'],
        num_rows: 11660
    })
    test: Dataset({
        features: ['input_features', 'labels', 'input_length'],
        num_rows: 760
    })
})

In [16]:
max_input_length = 30.0


def is_audio_in_length_range(length):
    return length < max_input_length

In [17]:
vivos["train"] = vivos["train"].filter(
    is_audio_in_length_range,
    input_columns=["input_length"],
)

Filter:   0%|          | 0/11660 [00:00<?, ? examples/s]

In [18]:
vivos["train"]

Dataset({
    features: ['input_features', 'labels', 'input_length'],
    num_rows: 11660
})

In [19]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [
            {"input_features": feature["input_features"][0]} for feature in features
        ]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [20]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [21]:
import evaluate

metric = evaluate.load("wer")

In [22]:
from transformers import WhisperForConditionalGeneration, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_8bit=True, # or using load_in_4bit 
    llm_int8_threshold=3.0, # for real-time apps with limited hardware resources, a lower threshold might be more beneficial but it may be potential accuracy loss
    # Skip modules should not undergo 8-bit quantization
    llm_int8_skip_modules=None, # for example, ["LayerNorm"]
)
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small", device_map="auto", quantization_config=bnb_config)



In [23]:
from functools import partial

# disable cache during training since it's incompatible with gradient checkpointing
model.config.use_cache = False

# set language and task for generation and re-enable cache
model.generate = partial(
    model.generate, language="vi", task="transcribe", use_cache=True
)

In [24]:
from peft import prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model,use_gradient_checkpointing=True)

In [25]:
def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)

<torch.utils.hooks.RemovableHandle at 0x7f0edf856610>

In [26]:
from peft import LoraConfig, get_peft_model

config = LoraConfig(r=32,
                    lora_alpha=8, 
                    target_modules=["q_proj", "v_proj"], 
                    lora_dropout=0.05, 
                    bias="none", 
                    task_type="CAUSAL_LM",
                    target_modules=["k_proj", "q_proj", "v_proj", "up_proj", "down_proj", "gate_proj"],
                    modules_to_save=["embed_tokens", "input_layernorm", "post_attention_layernorm", "norm"],
                    )
model = get_peft_model(model, config)
model.print_trainable_parameters()

SyntaxError: keyword argument repeated: target_modules (1258718003.py, line 9)

: 

In [27]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="stevehoang9/whisper-small-vi",  # name on the HF Hub
    overwrite_output_dir=True,
    per_device_train_batch_size=64,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-3,
    lr_scheduler_type="constant_with_warmup",
    # lr_scheduler_type="linear", # if max_steps over 4000
    # warmup_steps=250,
    warmup_steps=50,
    max_steps=500,  # increase to 4000 if you have your own GPU or a Colab paid plan
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs = {"use_reentrant": True}, #must be false for DDP
    fp16=True,
    fp16_full_eval=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=32,
    predict_with_generate=True,
    generation_max_length=225,
    # optim="adamw_torch",
#     save_steps=500,
#     eval_steps=500,
    save_steps=100,
    eval_steps=100,
    # logging_steps=100,
    logging_steps=25,
    # weight_decay=0.01,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
#     metric_for_best_model="wer",
    greater_is_better=False,
    remove_unused_columns=False,  # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
    label_names=["labels"],  # same reason as above
    push_to_hub=True,
)

In [28]:
from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

# This callback helps to save only the adapter weights and remove the base model weights.
class SavePeftModelCallback(TrainerCallback):
    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)
        return control


trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=vivos["train"],
    eval_dataset=vivos["test"],
    data_collator=data_collator,
    # compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
    callbacks=[SavePeftModelCallback],
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False)
  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


In [29]:
trainer.train()

  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
# torch.save(trainer.state_dict(), 'model.pth')

In [None]:
# model = trainer.load_state_dict(torch.load('model.pth'))

In [None]:
# filepath = 'model.pth'
# input_sample = torch.randn(1, 16000)
# model.to_onnx(filepath, input_sample, export_params=True)

In [29]:
peft_model_id = "stevehoang9/whisper-small-vi-500steps"
model.push_to_hub(peft_model_id)

adapter_model.safetensors:   0%|          | 0.00/3.56M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/stevehoang9/whisper-small-vi-500steps/commit/48e42c171cec6cf974b5df07e56376e18d791d7a', commit_message='Upload model', commit_description='', oid='48e42c171cec6cf974b5df07e56376e18d791d7a', pr_url=None, pr_revision=None, pr_num=None)

In [30]:
from peft import PeftModel, PeftConfig
from transformers import WhisperForConditionalGeneration, Seq2SeqTrainer

peft_model_id = "stevehoang9/whisper-small-vi-500steps" # Use the same model ID as before.
peft_config = PeftConfig.from_pretrained(peft_model_id)
model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path, load_in_8bit=True, device_map="auto"
)
model = PeftModel.from_pretrained(model, peft_model_id)
model.config.use_cache = True

adapter_config.json:   0%|          | 0.00/767 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/3.56M [00:00<?, ?B/s]

In [31]:
import gc
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers.models.whisper.english_normalizer import BasicTextNormalizer


model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="vi", task="transcribe")
eval_dataloader = DataLoader(vi_asr["test"], batch_size=8, collate_fn=data_collator)
# forced_decoder_ids = processor.get_decoder_prompt_ids(language="vi", task="transcribe")
normalizer = BasicTextNormalizer()

predictions = []
references = []
normalized_predictions = []
normalized_references = []

model.eval()
for step, batch in enumerate(tqdm(eval_dataloader)):
    with torch.cuda.amp.autocast():
        with torch.no_grad():
            generated_tokens = (
                model.generate(
                    input_features=batch["input_features"].to("cuda"),
                    # forced_decoder_ids=forced_decoder_ids,
                    max_new_tokens=255,
                )
                .cpu()
                .numpy()
            )
            labels = batch["labels"].cpu().numpy()
            labels = np.where(labels != -100, labels, processor.tokenizer.pad_token_id)
            decoded_preds = processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            decoded_labels = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)
            predictions.extend(decoded_preds)
            references.extend(decoded_labels)
            normalized_predictions.extend([normalizer(pred).strip() for pred in decoded_preds])
            normalized_references.extend([normalizer(label).strip() for label in decoded_labels])
        del generated_tokens, labels, batch
    gc.collect()
wer = 100 * metric.compute(predictions=predictions, references=references)
normalized_wer = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references)
eval_metrics = {"eval/wer": wer, "eval/normalized_wer": normalized_wer}

print(f"{wer=} and {normalized_wer=}")
print(eval_metrics)

100%|██████████| 203/203 [15:55<00:00,  4.71s/it]


wer=21.464691534297103 and normalized_wer=20.962005808784305
{'eval/wer': 21.464691534297103, 'eval/normalized_wer': 20.962005808784305}
