# Fine-tune Whisper large with PEFT-LoRA + int8

[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/phineas-pta/fine-tune-whisper-vi/blob/main/whisper-large-lora.ipynb)

On Colab: mount Google Drive via the GUI before training.

On Kaggle: select the free P100 GPU, because running PEFT on multiple GPUs is complicated.

https://huggingface.co/blog/fine-tune-whisper

https://github.com/huggingface/peft/blob/main/examples/int8_training/peft_bnb_whisper_large_v2_training.ipynb

In [None]:
# Log in to the Hugging Face Hub through the interactive notebook widget.
# NOTE(review): presumably required for gated datasets loaded below (e.g. common_voice) — confirm.
from huggingface_hub import notebook_login
notebook_login()
# !huggingface-cli login --token=███  # non-interactive alternative (token redacted)

In [None]:
# workaround for a bug in `datasets` package
# NOTE(review): the bug is not named here — document it (or pin versions) so this can be removed later
%pip install -q cudf-cu12 --extra-index-url=https://pypi.nvidia.com
# `-U` always pulls the latest releases, so the transformers/peft APIs used below may drift; consider pinning
%pip install -qU "datasets[audio]" accelerate transformers bitsandbytes peft
# no compute metrics so no `jiwer`

In [None]:
from dataclasses import dataclass
import datasets as hugDS
from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer
from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model

In [None]:
SAMPLING_RATE = 16_000
def load_data(mode, **kwargs):
	"""Stream one Hugging Face dataset split, resampled to SAMPLING_RATE.

	mode 0: return the stream unchanged (presumably the source already has
	        exactly `audio` + `transcription` columns — verify per dataset)
	mode 1: keep only the `audio` and `transcription` columns
	mode 2: keep `audio` and `sentence`, renaming `sentence` to `transcription`
	kwargs are forwarded to `datasets.load_dataset` (e.g. path, name, split).

	Raises ValueError for any other mode — checked before the dataset is opened,
	so a typo fails immediately instead of after contacting the Hub.
	"""
	if mode not in (0, 1, 2):
		raise ValueError(f"unknown mode {mode!r}: expected 0, 1 or 2")
	tmp = hugDS.load_dataset(**kwargs, trust_remote_code=True, streaming=True).cast_column("audio", hugDS.Audio(sampling_rate=SAMPLING_RATE))
	if mode == 1:
		return tmp.select_columns(["audio", "transcription"])
	if mode == 2:
		return tmp.select_columns(["audio", "sentence"]).rename_column("sentence", "transcription")
	return tmp

# Build one streaming train split and one streaming test split. Each source goes through `load_data`,
# which (modes 1/2) reduces it to `audio` + `transcription` columns. The sources are then concatenated.
# NOTE(review): concatenating streams keeps source order, so the train stream yields all of
# vlsp2020_vinai_100h (56.4k) before any other source. With max_steps=500 × BATCH_SIZE=16 in the training
# cell, only ~8k samples are consumed, presumably all from that first source. If mixing sources is intended,
# consider `hugDS.interleave_datasets`.
MY_DATA = hugDS.IterableDatasetDict()

MY_DATA["train"] = hugDS.concatenate_datasets([  # total: 1.4M samples
	load_data(path="doof-ferb/vlsp2020_vinai_100h",                      split="train", mode=0),  # 56.4k
	load_data(path="doof-ferb/fpt_fosd",                                 split="train", mode=0),  # 25.9k
	load_data(path="doof-ferb/infore1_25hours",                          split="train", mode=0),  # 14.9k
	load_data(path="doof-ferb/infore2_audiobooks",                       split="train", mode=0),  # 315k
	load_data(path="quocanh34/viet_vlsp",                                split="train", mode=0),  # 171k
	load_data(path="linhtran92/final_dataset_500hrs_wer0",               split="train", mode=1),  # 649k
	load_data(path="linhtran92/viet_youtube_asr_corpus_v2",              split="train", mode=1),  # 195k
	load_data(path="google/fleurs",                        name="vi_vn", split="train", mode=1),  # 3k
	load_data(path="mozilla-foundation/common_voice_16_1", name="vi",    split="train", mode=2),  # 2.3k
	load_data(path="vivos",                                              split="train", mode=2),  # 11.7k
])

MY_DATA["test"] = hugDS.concatenate_datasets([  # total: 32.6k samples
	load_data(path="quocanh34/viet_vlsp",                                split="validation", mode=0),  # 7.5k
	load_data(path="linhtran92/viet_youtube_asr_corpus_v2",              split="test",       mode=1),  # 21.6k
	load_data(path="google/fleurs",                        name="vi_vn", split="validation", mode=1),  # .3k
	load_data(path="google/fleurs",                        name="vi_vn", split="test",       mode=1),  # .8k
	load_data(path="mozilla-foundation/common_voice_16_1", name="vi",    split="validation", mode=2),  # .4k
	load_data(path="mozilla-foundation/common_voice_16_1", name="vi",    split="test",       mode=2),  # 1.3k
	load_data(path="vivos",                                              split="test",       mode=2),  # .7k
])

# some samples will be filtered out later (unknown how many)

In [None]:
from transformers import BitsAndBytesConfig  # not part of the import cell above

modelID = "openai/whisper-large-v3"
FEATURE_EXTRACTOR = WhisperFeatureExtractor.from_pretrained(modelID)
TOKENIZER = WhisperTokenizer.from_pretrained(modelID, language="vi", task="transcribe")
# 8-bit weights via bitsandbytes: passing `load_in_8bit=True` straight to `from_pretrained` is deprecated
# in favour of an explicit quantization config.
MODEL = WhisperForConditionalGeneration.from_pretrained(
	modelID,
	use_cache=False,  # no KV cache needed while training
	forced_decoder_ids=None, suppress_tokens=[],  # clear these config entries, as in the HF fine-tune-whisper blog
	quantization_config=BitsAndBytesConfig(load_in_8bit=True),
	device_map="auto",
)

DUMMY_TOKEN = -100  # label id ignored by the loss (used for padding in the collator)

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
	"""Collate prepared examples into one padded training batch.

	Audio features are padded with FEATURE_EXTRACTOR and label ids with TOKENIZER.
	Label padding is replaced by DUMMY_TOKEN so the loss ignores it.
	"""
	# Id of the decoder start token that is prepended again later, so a leading copy is cut from the labels.
	# None -> read `MODEL.config.decoder_start_token_id` when called.
	decoder_start_token_id: int | None = None

	def __call__(self, features):
		# inputs and labels have different lengths and need different padding methods, so split them
		input_features = [{"input_features": feature["input_features"]} for feature in features]
		label_features = [{"input_ids": feature["labels"]} for feature in features]  # tokenized label sequences

		batch = FEATURE_EXTRACTOR.pad(input_features, return_tensors="pt")  # audio inputs as torch tensors
		labels_batch = TOKENIZER.pad(label_features, return_tensors="pt")  # pad labels to the longest in the batch
		# replace padding with DUMMY_TOKEN (-100) so it is ignored by the loss
		labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), DUMMY_TOKEN)

		# Compare against the decoder start token, not `TOKENIZER.bos_token_id`.
		# NOTE(review): for Whisper the bos token is <|endoftext|>, not the <|startoftranscript|> prefix the tokenizer adds,
		# so the old check never matched and the start token ended up duplicated. Matches the HF fine-tune-whisper blog.
		start_id = self.decoder_start_token_id
		if start_id is None:
			start_id = MODEL.config.decoder_start_token_id
		if (labels[:, 0] == start_id).all().cpu().item():
			labels = labels[:, 1:]  # cut it here as it is added again later

		batch["labels"] = labels
		return batch

DATA_COLLATOR = DataCollatorSpeechSeq2SeqWithPadding()

In [None]:
def prepare_dataset(batch):
	"""Turn one raw example into model inputs and return the same dict, extended.

	Expects an `audio` dict holding an `array` of samples and a `transcription` string.
	Adds `input_length` (number of audio samples), `input_features` (log-Mel features),
	`labels` (token ids of the transcription) and `labels_length`.
	"""
	audio = batch["audio"]
	# Count waveform samples. `len(audio)` counted the dict's keys (~3), which let every clip
	# pass `filter_inputs`, including ones longer than 30 s.
	batch["input_length"] = len(audio["array"])
	batch["input_features"] = FEATURE_EXTRACTOR(audio["array"], sampling_rate=SAMPLING_RATE).input_features[0]  # compute log-Mel input features
	batch["labels"] = TOKENIZER(batch["transcription"]).input_ids  # encode target text to label ids
	batch["labels_length"] = len(batch["labels"])  # number of label tokens
	return batch

def filter_inputs(input_length):
	"""Keep a clip only when it is non-empty and strictly shorter than 30 s.

	`input_length` is a sample count; 30 s at 16 kHz is 480 000 samples.
	"""
	max_samples = 30 * 16_000
	return input_length > 0 and input_length < max_samples

def filter_labels(labels_length):
	"""Keep a label sequence only when it has fewer than 448 tokens."""
	max_tokens = 448  # MODEL.config.max_length
	return max_tokens > labels_length

# Lazy streaming pipeline: featurize every example, then drop clips/labels that are too long.
# .shuffle(seed=42) skipped: useless coz streaming multiple datasets (cannot set buffer too high coz not enough RAM)
MY_DATA = MY_DATA.map(prepare_dataset)  # no `num_proc` coz streaming
MY_DATA = MY_DATA.filter(filter_inputs, input_columns=["input_length"])  # no `remove_columns` coz streaming
MY_DATA = MY_DATA.filter(filter_labels, input_columns=["labels_length"])  # no `remove_columns` coz streaming
# TODO: enable `batched=True` but don’t know how to write functions

In [None]:
# `prepare_model_for_int8_training` is deprecated (and removed in recent PEFT); this is its successor.
from peft import prepare_model_for_kbit_training

def _make_inputs_require_grad(module, inputs, output):
	"""Forward hook: make the encoder's first conv output require grad.

	Mirrors the PEFT int8 Whisper example linked at the top of this notebook. The encoder input
	(log-Mel features) never requires grad, so without this the encoder LoRA layers may receive
	no gradient under gradient checkpointing.
	"""
	output.requires_grad_(True)

MODEL_BIS = get_peft_model(
	prepare_model_for_kbit_training(MODEL),  # freezes base weights for k-bit training, enables gradient checkpointing by default
	LoraConfig(r=32, lora_alpha=64, target_modules=["fc1", "fc2", "q_proj", "v_proj", "k_proj", "out_proj"], lora_dropout=.1, bias="none"),
)
# conv1 is not a LoRA target, so it is the same module object inside MODEL_BIS
MODEL.model.encoder.conv1.register_forward_hook(_make_inputs_require_grad)
MODEL_BIS.print_trainable_parameters()

In [None]:
# mount gdrive using GUI before training
# Switch to a persistent working directory; SAVE_PATH ("./my-whisper-large") is relative to it.
%cd '/content/drive/My Drive/coder'
# %cd /kaggle/working  # Kaggle alternative
# !rm -rf ./my-whisper-large  # wipe a previous run (presumably only when NOT resuming)

In [None]:
SAVE_PATH = "./my-whisper-large"  # mount gdrive using GUI before training
BATCH_SIZE = 16  # should be a multiple of 8
# kaggle free P100 train faster than colab free T4
# kaggle free T4×2: problem with peft + multi-gpu

# colab free tier can only run for 8-12h max daily
# kaggle free tier can only run for 30h max weekly but max 12h per session

# NOTE(review): 500 steps × BATCH_SIZE 16 (no gradient accumulation) means training sees only 8 000 streamed samples.
# Each evaluation presumably walks the whole test stream (~32.6k samples before filtering),
# so evaluation may dominate wall-clock time.
TRAINING_ARGS = Seq2SeqTrainingArguments(
	output_dir=SAVE_PATH,
	per_device_train_batch_size=BATCH_SIZE,
	per_device_eval_batch_size=BATCH_SIZE,
	fp16=True,
	# bf16=True, tf32=True, torch_compile=True,  # GPU Ampere or later
	report_to=["tensorboard"],

	max_steps=500,  # no `num_train_epochs` coz streaming
	# 10% of max_steps (keep between 5-15%). `warmup_steps` overrides `warmup_ratio`,
	# so the former `warmup_ratio=.05` had no effect and was dropped.
	warmup_steps=50,
	logging_steps=25,
	save_steps=50,
	eval_steps=50,
	eval_strategy="steps",  # formerly `evaluation_strategy`, deprecated and later removed in transformers
	save_total_limit=3,

	learning_rate=1e-3,
	# gradient_accumulation_steps=1,  # to increase if decrease batch size
	remove_unused_columns=False, label_names=["labels"],  # required by PEFT
	# predict_with_generate=True,  # must disable coz PEFT
)

# Wire the PEFT model, arguments, streaming splits and collator into a trainer.
TRAINER = Seq2SeqTrainer(
	model=MODEL_BIS,
	args=TRAINING_ARGS,
	data_collator=DATA_COLLATOR,
	train_dataset=MY_DATA["train"],
	eval_dataset=MY_DATA["test"],
	tokenizer=FEATURE_EXTRACTOR,  # deliberately the feature extractor, not TOKENIZER
	# compute_metrics=compute_metrics,  # must disable coz PEFT
)

In [None]:
# Run training for TRAINING_ARGS.max_steps steps.
# To continue from the latest checkpoint in SAVE_PATH instead, use TRAINER.train(resume_from_checkpoint=True).
TRAINER.train()  # resume_from_checkpoint=True  # only if resume

In [None]:
TRAINER.save_model()
!zip -FSr res.zip ./my-whisper-large