# Fine-tune Whisper large with PEFT-LoRA + int4 + DDP

This notebook makes better use of multiple GPUs through Distributed Data Parallelism (DDP); the other notebook uses naive model parallelism.

In [None]:
# Authenticate with the Hugging Face Hub through the interactive notebook widget.
from huggingface_hub import notebook_login
notebook_login()
# Non-interactive alternative: !huggingface-cli login --token=<YOUR_HF_TOKEN>

In [None]:
# Workaround for a bug in the `datasets` package.
# NOTE(review): the exact issue is not recorded here; presumably the preinstalled cuDF / Dask-cuDF
# packages conflict with the `datasets` upgrade below - confirm before removing these two lines.
%pip uninstall -y cudf dask-cuda dask-cudf
%pip install -q cudf-cu12 --extra-index-url=https://pypi.nvidia.com
%pip install -qU "datasets[audio]" accelerate transformers bitsandbytes peft
# `jiwer` (WER metric) is not installed because no metrics are computed during training.

In [None]:
# Everything (imports included) must live inside this function: `notebook_launcher` runs it in
# freshly spawned worker processes, and CUDA must not be initialised in the notebook process itself.
def train_ddp(pretrained_model, batch_size, total_steps, save_path, resume_training):
	"""Fine-tune a Whisper checkpoint with 4-bit (NF4) quantised weights + LoRA adapters under DDP.

	Args:
		pretrained_model: Hugging Face Hub id or local path of a Whisper checkpoint. The learning
			rate is chosen from its name (tiny / base / small / medium / large).
		batch_size: value passed as `per_device_train_batch_size`. Because `split_batches=True`,
			each fetched batch is split across the processes. Below 8, gradient accumulation
			(8 // batch_size steps) is enabled.
		total_steps: number of optimizer steps (`max_steps`); required because the data is streamed
			and therefore has no length.
		save_path: output directory for checkpoints, TensorBoard logs and the final model.
		resume_training: forwarded to `Trainer.train(resume_from_checkpoint=...)`
			(False, True for the latest checkpoint in `save_path`, or a checkpoint path).
	"""
	import inspect
	import math
	from dataclasses import dataclass
	import torch
	import datasets as hugDS
	from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer, BitsAndBytesConfig, Seq2SeqTrainingArguments, Seq2SeqTrainer
	from accelerate import Accelerator
	import peft

	has_bf16 = torch.cuda.is_bf16_supported()  # GPU Ampere or later
	accelerator = Accelerator(project_dir=save_path, log_with="tensorboard", mixed_precision="bf16" if has_bf16 else "fp16")

	FEATURE_EXTRACTOR = WhisperFeatureExtractor.from_pretrained(pretrained_model)
	TOKENIZER = WhisperTokenizer.from_pretrained(pretrained_model, language="vi", task="transcribe")
	MODEL = WhisperForConditionalGeneration.from_pretrained(
		pretrained_model, use_cache=False, device_map={"": accelerator.device},  # one full copy per process (DDP)
		quantization_config=BitsAndBytesConfig(
			load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4",
			bnb_4bit_compute_dtype=torch.bfloat16 if has_bf16 else torch.float16
		)
	)
	MODEL.config.forced_decoder_ids = None
	MODEL.config.suppress_tokens = []

	# The model prepends this token itself when it shifts `labels` right to build the decoder inputs.
	DECODER_START_TOKEN_ID = MODEL.config.decoder_start_token_id

	DUMMY_TOKEN = -100  # label value ignored by the cross-entropy loss

	MODEL_BIS = peft.get_peft_model(
		peft.prepare_model_for_kbit_training(MODEL, use_gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": False}),
		peft.LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=.05, bias="none")
	)
	if accelerator.is_main_process:
		MODEL_BIS.print_trainable_parameters()  # 16 million = 1% of the 1.6 billion params of Whisper large

	###############################################################################
	# prepare data

	SAMPLING_RATE = 16_000
	def load_my_data(**kwargs):
		"""Stream the `train` split of a Hub dataset, resampling its `audio` column to 16 kHz."""
		return hugDS.load_dataset(**kwargs, split="train", trust_remote_code=True, streaming=True).cast_column("audio", hugDS.Audio(sampling_rate=SAMPLING_RATE))


	MY_DATA = hugDS.concatenate_datasets([  # total: 86k samples
		load_my_data(path="doof-ferb/fpt_fosd"),  # 25.9k
		load_my_data(path="doof-ferb/infore1_25hours"),  # 14.9k
		load_my_data(path="doof-ferb/LSVSC").select_columns(["audio", "transcription"]),  # 45k
	])


	def prepare_dataset(batch):
		"""Add log-Mel `input_features`, tokenised `labels` and their lengths to one example."""
		audio = batch["audio"]
		batch["input_length"] = len(audio["array"])  # number of audio samples
		batch["input_features"] = FEATURE_EXTRACTOR(audio["array"], sampling_rate=SAMPLING_RATE).input_features[0]  # log-Mel input features
		batch["labels"] = TOKENIZER(batch["transcription"]).input_ids  # encode target text to label ids
		batch["labels_length"] = len(batch["labels"])
		return batch

	def filter_inputs(input_length):
		"""Filter inputs with zero input length or longer than 30s"""
		return 0 < input_length < 48e4  # 30s × 16kHz

	def filter_labels(labels_length):
		"""Filter label sequences longer than max length 448 tokens"""
		return labels_length < 448  # MODEL.config.max_length

	MY_DATA = (MY_DATA
		.map(prepare_dataset)  # no `num_proc` because the dataset is streamed
		.filter(filter_inputs, input_columns=["input_length"])
		.filter(filter_labels, input_columns=["labels_length"])
	)

	@dataclass
	class DataCollatorSpeechSeq2SeqWithPadding:
		"""Pad the log-Mel features and label ids of a list of examples into one batch of tensors."""
		def __call__(self, features):
			# inputs and labels have different lengths and need different padding methods
			input_features = [{"input_features": feature["input_features"]} for feature in features]
			label_features = [{"input_ids"     : feature["labels"]        } for feature in features]

			batch = FEATURE_EXTRACTOR.pad(input_features, return_tensors="pt")
			labels_batch = TOKENIZER.pad(label_features, return_tensors="pt")  # pad to the longest label sequence in the batch
			# replace padding with -100 so the loss ignores it
			labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), DUMMY_TOKEN)

			# If the tokenizer already put the decoder start token at position 0, strip it: the model
			# prepends it again when building decoder inputs. (Comparing with `bos_token_id`, which is
			# <|endoftext|> for Whisper, never matched, so every target started with that token twice.)
			if (labels[:, 0] == DECODER_START_TOKEN_ID).all().cpu().item():
				labels = labels[:, 1:]

			batch["labels"] = labels
			return batch

	DATA_COLLATOR = DataCollatorSpeechSeq2SeqWithPadding()

	###############################################################################
	# training setup

	# a practical learning rate for fine-tuning is about 40× smaller than the one used for pre-training
	if "tiny" in pretrained_model:
		LEARNING_RATE = 3.75e-5
	elif "base" in pretrained_model:
		LEARNING_RATE = 2.5e-5
	elif "small" in pretrained_model:
		LEARNING_RATE = 1.25e-5
	elif "medium" in pretrained_model:
		LEARNING_RATE = 6.25e-6
	elif "large" in pretrained_model:
		LEARNING_RATE = 5e-6
	else:
		LEARNING_RATE = 5e-5


	TRAINING_ARGS = Seq2SeqTrainingArguments(
		output_dir=save_path,
		per_device_train_batch_size=batch_size,
		fp16=not has_bf16,
		bf16=has_bf16, tf32=has_bf16,
		# torch_compile=True,  # SDPA does not support Whisper yet
		report_to=["tensorboard"],

		max_steps=total_steps,
		logging_steps=25,
		save_steps=50,
		# No evaluation: the default strategy is already "no". `evaluation_strategy` was renamed
		# `eval_strategy` in newer transformers, so neither is passed, which keeps this version-agnostic.
		save_total_limit=5,
		accelerator_config={"split_batches": True},  # mandatory for streaming datasets

		optim="adamw_bnb_8bit",  # 8-bit AdamW optimizer: lower VRAM usage than default AdamW
		learning_rate=LEARNING_RATE,
		# `warmup_steps` is a step count, so the former `.05` meant ~0 warmup steps rather than 5 %.
		# Convert the intended 5 % (keep between 5-15 %) into an explicit integer number of steps.
		warmup_steps=math.ceil(.05 * total_steps),
		gradient_accumulation_steps=1 if batch_size >= 8 else 8 // batch_size,  # keep effective batch size at least 8
		remove_unused_columns=False, label_names=["labels"],  # required by PEFT
		# predict_with_generate=True,  # must stay disabled because of PEFT
	)

	# `tokenizer=` was renamed `processing_class=` in newer transformers; use whichever this version accepts.
	trainer_params = inspect.signature(Seq2SeqTrainer.__init__).parameters
	processor_kwarg = "processing_class" if "processing_class" in trainer_params else "tokenizer"

	TRAINER = Seq2SeqTrainer(
		args=TRAINING_ARGS,
		model=MODEL_BIS,
		train_dataset=MY_DATA,
		data_collator=DATA_COLLATOR,
		# compute_metrics=compute_metrics,  # must stay disabled because of PEFT
		**{processor_kwarg: FEATURE_EXTRACTOR},  # the feature extractor, not TOKENIZER
	)

	TRAINER.train(resume_from_checkpoint=resume_training)

	accelerator.wait_for_everyone()
	if accelerator.is_main_process:
		TRAINER.save_model()  # writes to TRAINING_ARGS.output_dir, i.e. save_path

In [None]:
from accelerate import notebook_launcher

In [None]:
# Work from the Kaggle working directory and start from an empty output folder.
# This deletes any earlier checkpoints, so skip it when resuming training (`resume_training` != False).
%cd /kaggle/working
!rm -rf ./my-whisper-lora

In [None]:
# Spawn 2 worker processes on this single node (one per GPU), each running `train_ddp`.
notebook_launcher(
	train_ddp,
	args=(
		"vinai/PhoWhisper-large",  # pretrained_model
		12,                        # batch_size
		100,                       # total_steps
		"./my-whisper-lora",       # save_path
		False,                     # resume_training
	),
	num_processes=2,
	num_nodes=1,
	mixed_precision="fp16",
)

In [None]:
# Archive the training output for download (-r: recurse into the folder; -FS: sync an existing
# res.zip with the folder's current contents instead of only adding to it).
!zip -FSr res.zip ./my-whisper-lora