# fine-tune whisper tiny with traditional approach + DDP

Makes better use of multiple GPUs with Distributed Data Parallelism (DDP); the other notebook uses naive model parallelism.

**Attention**: in this case the batch size value is the total across all GPUs and is split evenly among them (instead of the effective batch size being batch size × GPU count).

In [None]:
# log in to the Hugging Face Hub from within the notebook
# NOTE(review): presumably required to stream gated datasets such as Common Voice — confirm
from huggingface_hub import notebook_login
notebook_login()
# non-interactive alternative (token redacted):
# !huggingface-cli login --token=███

In [None]:
# workaround for a bug in the `datasets` package: remove these RAPIDS packages and install `cudf-cu12` from NVIDIA's index instead
%pip uninstall -y cudf dask-cuda dask-cudf
%pip install -q cudf-cu12 --extra-index-url=https://pypi.nvidia.com
# training dependencies: `jiwer` computes the WER metric, `bitsandbytes` provides the 8-bit AdamW optimizer
%pip install -qU 'datasets[audio]' accelerate transformers jiwer bitsandbytes
# `evaluate` is deliberately not installed: installing it and then running `import evaluate` throws an error on Kaggle

In [None]:
# everything must be inside this function
def train_ddp(pretrained_model, batch_size, total_steps, save_path, resume_training):
	import torch
	from dataclasses import dataclass
	import datasets as hugDS
	from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer
	from accelerate import Accelerator
	import jiwer

	has_bf16 = torch.cuda.is_bf16_supported()  # GPU Ampere or later
	accelerator = Accelerator(project_dir=save_path, log_with="tensorboard", mixed_precision="bf16" if has_bf16 else "fp16")

	FEATURE_EXTRACTOR = WhisperFeatureExtractor.from_pretrained(pretrained_model)
	TOKENIZER = WhisperTokenizer.from_pretrained(pretrained_model, language="vi", task="transcribe")
	MODEL = WhisperForConditionalGeneration.from_pretrained(pretrained_model, use_cache=False, device_map={"": accelerator.device})
	MODEL.config.forced_decoder_ids = None
	MODEL.config.suppress_tokens = []

	DUMMY_TOKEN = -100

	SAMPLING_RATE = 16_000
	def load_my_data(mode, **kwargs):
		tmp = hugDS.load_dataset(**kwargs, trust_remote_code=True, streaming=True).cast_column("audio", hugDS.Audio(sampling_rate=SAMPLING_RATE))
		match mode:
			case 0:
				return tmp
			case 1:
				return tmp.select_columns(["audio", "transcription"])
			case 2:
				return tmp.select_columns(["audio", "sentence"]).rename_column("sentence", "transcription")
			case _:
				raise ValueError("oh no!")


	with accelerator.main_process_first():
		MY_DATA = hugDS.IterableDatasetDict()
		MY_DATA["train"] = hugDS.concatenate_datasets([  # total: 1.5M samples
			load_my_data(path="google/fleurs",                        name="vi_vn", split="train", mode=1),  # 3k
			load_my_data(path="mozilla-foundation/common_voice_16_1", name="vi",    split="train", mode=2),  # 2.3k
			load_my_data(path="vivos",                                              split="train", mode=2),  # 11.7k
			load_my_data(path="doof-ferb/fpt_fosd",                                 split="train", mode=0),  # 25.9k
			load_my_data(path="doof-ferb/infore1_25hours",                          split="train", mode=0),  # 14.9k
			load_my_data(path="doof-ferb/vlsp2020_vinai_100h",                      split="train", mode=0),  # 56.4k
			load_my_data(path="doof-ferb/LSVSC",                                    split="train", mode=1),  # 45k
			load_my_data(path="quocanh34/viet_vlsp",                                split="train", mode=0),  # 171k
			load_my_data(path="linhtran92/viet_youtube_asr_corpus_v2",              split="train", mode=1),  # 195k
			load_my_data(path="doof-ferb/infore2_audiobooks",                       split="train", mode=0),  # 315k
			load_my_data(path="linhtran92/viet_bud500",                             split="train", mode=0),  # 634k
		])
		MY_DATA["test"] = hugDS.concatenate_datasets([  # total: 15k samples
			load_my_data(path="mozilla-foundation/common_voice_16_1", name="vi", split="test", mode=2),  # 1.3k
			# remove FLEURS because error when running in batch
			load_my_data(path="vivos",                                           split="test", mode=2),  # .7k
		])

	# some samples will be filtered out later (unknown how many)

	def prepare_dataset(batch):
		audio = batch["audio"]
		batch["input_length"] = len(audio["array"])  # compute input length
		batch["input_features"] = FEATURE_EXTRACTOR(audio["array"], sampling_rate=SAMPLING_RATE).input_features[0]  # compute log-Mel input features
		batch["labels"] = TOKENIZER(batch["transcription"]).input_ids  # encode target text to label ids
		batch["labels_length"] = len(batch["labels"])  # compute labels length
		return batch

	def filter_inputs(input_length):
		"""Filter inputs with zero input length or longer than 30s"""
		return 0 < input_length < 48e4  # 30s × 16kHz

	def filter_labels(labels_length):
		"""Filter label sequences longer than max length 448 tokens"""
		return labels_length < 448  # MODEL.config.max_length

	MY_DATA = (MY_DATA
		# .shuffle(seed=42)  # useless coz streaming multiple datasets (cannot set buffer too high coz not enough RAM)
		.map(prepare_dataset)  # no `num_proc` coz streaming
		.filter(filter_inputs, input_columns= ["input_length"])  # no `remove_columns` coz streaming
		.filter(filter_labels, input_columns=["labels_length"])  # no `remove_columns` coz streaming
	)  # TODO: enable `batched=True` but don’t know how to write functions

	@dataclass
	class DataCollatorSpeechSeq2SeqWithPadding:
		def __call__(self, features):
			# split inputs and labels since they have to be of different lengths and need different padding methods
			input_features = [{"input_features": feature["input_features"]} for feature in features]
			label_features = [{"input_ids"     : feature["labels"]        } for feature in features]  # get the tokenized label sequences

			batch = FEATURE_EXTRACTOR.pad(input_features, return_tensors="pt")  # treat the audio inputs by simply returning torch tensors
			labels_batch =  TOKENIZER.pad(label_features, return_tensors="pt")  # pad the labels to max length
			labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), DUMMY_TOKEN)  # replace padding with -100 to ignore loss correctly

			if (labels[:, 0] == TOKENIZER.bos_token_id).all().cpu().item():  # if bos token is appended in previous tokenization step,
				labels = labels[:, 1:]  # cut bos token here as it’s append later anyways

			batch["labels"] = labels
			return batch

	DATA_COLLATOR = DataCollatorSpeechSeq2SeqWithPadding()

	JIWER_TRANS = jiwer.Compose([  # DO NOT use `jiwer.RemoveEmptyStrings` it can cause rows count mismatch
		jiwer.ToLowerCase(),
		jiwer.RemoveKaldiNonWords(),
		jiwer.RemoveMultipleSpaces(),
		jiwer.Strip(),
		jiwer.RemovePunctuation(),
		jiwer.ReduceToListOfListOfWords(),
	])

	def compute_metrics(pred):
		pred_ids = pred.predictions
		label_ids = pred.label_ids
		label_ids[label_ids == DUMMY_TOKEN] = TOKENIZER.pad_token_id  # replace -100 with the pad_token_id

		wer = jiwer.wer(  # we do not want to group tokens when computing the metrics
			reference=TOKENIZER.batch_decode(label_ids, skip_special_tokens=True),
			hypothesis=TOKENIZER.batch_decode(pred_ids, skip_special_tokens=True),
			reference_transform=JIWER_TRANS, hypothesis_transform=JIWER_TRANS
		)
		return {"wer": wer}

	TRAINING_ARGS = Seq2SeqTrainingArguments(
		output_dir=save_path,
		per_device_train_batch_size=batch_size,
		per_device_eval_batch_size=batch_size,
		fp16=not has_bf16,
		bf16=has_bf16, tf32=has_bf16,
		# torch_compile=True,  # SDPA not support whisper yet
		report_to=["tensorboard"],

		max_steps=total_steps,
		logging_steps=25,
		save_steps=50,
		eval_steps=50,
		evaluation_strategy="steps",
		save_total_limit=3,
		accelerator_config={"split_batches": True},

		optim="adamw_bnb_8bit",  # 8-bit AdamW optimizer: lower vram usage than default AdamW
		learning_rate=3.75e-5,
		warmup_ratio=.05,  # keep between 5-15%
		gradient_accumulation_steps=1 if batch_size >= 8 else 8 // batch_size,
		gradient_checkpointing=True,
		gradient_checkpointing_kwargs={"use_reentrant": False},
		predict_with_generate=True,
		# generation_num_beams=5,  # require more VRAM
		load_best_model_at_end=True,
		metric_for_best_model="wer",
		greater_is_better=False,  # WER is better when lower
	)

	TRAINER = Seq2SeqTrainer(
		args=TRAINING_ARGS,
		model=MODEL,
		train_dataset=MY_DATA["train"],
		eval_dataset=MY_DATA["test"],
		data_collator=DATA_COLLATOR,
		compute_metrics=compute_metrics,
		tokenizer=FEATURE_EXTRACTOR,  # not TOKENIZER
	)

	TRAINER.train(resume_from_checkpoint=resume_training)

	accelerator.wait_for_everyone()
	if accelerator.is_main_process:
		TRAINER.save_model()

In [None]:
from accelerate import notebook_launcher

In [None]:
# start from a clean output directory: the launcher cell below passes resume_training=False
%cd /kaggle/working
!rm -rf ./my-whisper-tiny

In [None]:
# positional args map to train_ddp(pretrained_model, batch_size, total_steps, save_path, resume_training)
# batch size 16 is the total for both processes (train_ddp enables `split_batches`), i.e. split evenly between the 2 GPUs
notebook_launcher(train_ddp, args=("openai/whisper-tiny", 16, 21000, "./my-whisper-tiny", False), mixed_precision="fp16", num_nodes=1, num_processes=2)

In [None]:
# pack the output directory (recursively) into res.zip so it can be downloaded
!zip -FSr res.zip ./my-whisper-tiny