# fine-tune whisper tiny with traditional approach

[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/phineas-pta/fine-tune-whisper-vi/blob/main/train/whisper-tiny-traditional.ipynb)

On Colab: mount Google Drive using the GUI before training.

On Kaggle: select the free T4×2 accelerator so that the effective batch size is doubled automatically.

Evaluation on the test sets is disabled because it gets stuck indefinitely.

In [None]:
# Log in to the Hugging Face Hub from inside the notebook.
# NOTE(review): presumably needed for gated datasets and/or pushing results — confirm.
from huggingface_hub import notebook_login
notebook_login()
# command-line alternative (never commit a real token):
# !huggingface-cli login --token=███

In [None]:
# Workaround for a bug in the `datasets` package:
# uninstall cudf / dask-cuda / dask-cudf, then install cudf-cu12 from NVIDIA's package index.
%pip uninstall -y cudf dask-cuda dask-cudf
%pip install -q cudf-cu12 --extra-index-url=https://pypi.nvidia.com
# Training dependencies, upgraded to their latest releases (versions are not pinned).
%pip install -qU 'datasets[audio]' accelerate transformers jiwer
# `evaluate` is deliberately left out: installing it and then running `import evaluate` raises an error on Kaggle,
# so WER is computed with `jiwer` directly.

In [None]:
from dataclasses import dataclass
import datasets as hugDS
from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer
import jiwer

In [None]:
SAMPLING_RATE = 16_000  # Hz; every `audio` column is cast to this rate

def load_my_data(mode, **kwargs):
	"""Open one dataset in streaming mode and bring it to the common column layout.

	Args:
		mode: how the source columns map onto the target layout:
			0 - return the dataset unchanged;
			1 - keep only the `audio` and `transcription` columns;
			2 - keep only `audio` and `sentence`, renaming `sentence` to `transcription`.
		**kwargs: forwarded to `datasets.load_dataset` (e.g. `path`, `name`, `split`).

	Returns:
		The streamed dataset with its `audio` column cast to `SAMPLING_RATE`.

	Raises:
		ValueError: if `mode` is not 0, 1 or 2. This is checked before anything is loaded.
	"""
	# fail fast: reject a bad mode before spending time on `load_dataset`
	if mode not in (0, 1, 2):
		raise ValueError(f"unsupported mode {mode!r}: expected 0, 1 or 2")

	streamed = hugDS.load_dataset(**kwargs, trust_remote_code=True, streaming=True).cast_column("audio", hugDS.Audio(sampling_rate=SAMPLING_RATE))
	if mode == 1:
		return streamed.select_columns(["audio", "transcription"])
	if mode == 2:
		return streamed.select_columns(["audio", "sentence"]).rename_column("sentence", "transcription")
	return streamed  # mode 0: returned unchanged

# Build the train/test mixtures from the individual corpora.
# Each table row is (column-layout mode, [split,] `load_dataset` kwargs); trailing comments give per-corpus sample counts.
MY_DATA = hugDS.IterableDatasetDict()

MY_DATA["train"] = hugDS.concatenate_datasets([  # total: 1.5M samples
	load_my_data(mode, split="train", **source)
	for mode, source in (
		(1, {"path": "google/fleurs", "name": "vi_vn"}),                      # 3k
		(2, {"path": "mozilla-foundation/common_voice_16_1", "name": "vi"}),  # 2.3k
		(2, {"path": "vivos"}),                                               # 11.7k
		(0, {"path": "doof-ferb/fpt_fosd"}),                                  # 25.9k
		(0, {"path": "doof-ferb/infore1_25hours"}),                           # 14.9k
		(0, {"path": "doof-ferb/vlsp2020_vinai_100h"}),                       # 56.4k
		(1, {"path": "doof-ferb/LSVSC"}),                                     # 45k
		(0, {"path": "quocanh34/viet_vlsp"}),                                 # 171k
		(1, {"path": "linhtran92/viet_youtube_asr_corpus_v2"}),               # 195k
		(0, {"path": "doof-ferb/infore2_audiobooks"}),                        # 315k
		(0, {"path": "linhtran92/viet_bud500"}),                              # 634k
	)
])

MY_DATA["test"] = hugDS.concatenate_datasets([  # total: 59k samples
	load_my_data(mode, split=split, **source)
	for mode, split, source in (
		(1, "validation", {"path": "google/fleurs", "name": "vi_vn"}),                      # .3k
		(1, "test",       {"path": "google/fleurs", "name": "vi_vn"}),                      # .8k
		(2, "validation", {"path": "mozilla-foundation/common_voice_16_1", "name": "vi"}),  # .4k
		(2, "test",       {"path": "mozilla-foundation/common_voice_16_1", "name": "vi"}),  # 1.3k
		(2, "test",       {"path": "vivos"}),                                               # .7k
		(1, "validation", {"path": "doof-ferb/LSVSC"}),                                     # 5.7k
		(1, "test",       {"path": "doof-ferb/LSVSC"}),                                     # 5.7k
		(0, "validation", {"path": "quocanh34/viet_vlsp"}),                                 # 7.5k
		(1, "test",       {"path": "linhtran92/viet_youtube_asr_corpus_v2"}),               # 21.6k
		(0, "validation", {"path": "linhtran92/viet_bud500"}),                              # 7.5k
		(0, "test",       {"path": "linhtran92/viet_bud500"}),                              # 7.5k
	)
])

# some samples will be filtered out later (unknown how many)

In [None]:
# Load the pretrained checkpoint; the same id is used for the feature extractor, the tokenizer and the model.
modelID = "openai/whisper-tiny"
FEATURE_EXTRACTOR = WhisperFeatureExtractor.from_pretrained(modelID)  # used in `prepare_dataset` to compute log-Mel input features
TOKENIZER = WhisperTokenizer.from_pretrained(modelID, language="vi", task="transcribe")  # configured for Vietnamese transcription
# NOTE(review): `use_cache=False` is presumably set because training enables gradient checkpointing — confirm
MODEL = WhisperForConditionalGeneration.from_pretrained(modelID, use_cache=False)
MODEL.config.forced_decoder_ids = None  # clear any forced decoder ids stored in the model config
MODEL.config.suppress_tokens = []  # empty the suppressed-token list stored in the model config

# Label value written over padding positions by the data collator and swapped back to the pad token in `compute_metrics`.
DUMMY_TOKEN = -100

In [None]:
def prepare_dataset(batch):
	"""Add model inputs, labels and their lengths to a single sample.

	Reads `batch["audio"]["array"]` and `batch["transcription"]`, then writes four keys
	back onto the same mapping: `input_length`, `input_features`, `labels` and `labels_length`.
	Returns the updated `batch`.
	"""
	waveform = batch["audio"]["array"]

	# number of audio samples, consumed later by `filter_inputs`
	batch["input_length"] = len(waveform)

	# log-Mel input features; the extractor returns a list, so take its first entry
	extracted = FEATURE_EXTRACTOR(waveform, sampling_rate=SAMPLING_RATE)
	batch["input_features"] = extracted.input_features[0]

	# target text encoded as label ids, plus its token count for `filter_labels`
	encoded = TOKENIZER(batch["transcription"])
	batch["labels"] = encoded.input_ids
	batch["labels_length"] = len(batch["labels"])

	return batch

def filter_inputs(input_length):
	"""Return True only for non-empty audio strictly shorter than 30 s (in samples)."""
	max_samples = 48e4  # 30s × 16kHz
	return 0 < input_length and input_length < max_samples

def filter_labels(labels_length):
	"""Return True only for label sequences with fewer than 448 tokens."""
	max_label_tokens = 448  # MODEL.config.max_length
	return labels_length < max_label_tokens

# Attach the preprocessing step and the two length filters to every split.
# `.shuffle(seed=42)` is skipped on purpose: useless coz streaming multiple datasets
# (cannot set buffer too high coz not enough RAM)
MY_DATA = MY_DATA.map(prepare_dataset)  # no `num_proc` coz streaming
# no `remove_columns` on either filter coz streaming
MY_DATA = MY_DATA.filter(filter_inputs, input_columns=["input_length"])
MY_DATA = MY_DATA.filter(filter_labels, input_columns=["labels_length"])
# TODO: enable `batched=True` but don’t know how to write functions

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
	"""Collate preprocessed samples into one padded batch of input features and labels.

	Uses the module-level `FEATURE_EXTRACTOR`, `TOKENIZER`, `MODEL` and `DUMMY_TOKEN`.
	"""

	def __call__(self, features):
		"""Pad a list of samples and return the batch.

		Args:
			features: list of mappings, each holding `input_features` and `labels`.

		Returns:
			The padded feature batch (torch tensors) with an extra `labels` entry in which
			padding positions are set to `DUMMY_TOKEN`.
		"""
		# split inputs and labels since they have to be of different lengths and need different padding methods
		input_features = [{"input_features": feature["input_features"]} for feature in features]
		label_features = [{"input_ids"     : feature["labels"]        } for feature in features]  # get the tokenized label sequences

		batch = FEATURE_EXTRACTOR.pad(input_features, return_tensors="pt")  # treat the audio inputs by simply returning torch tensors
		labels_batch =  TOKENIZER.pad(label_features, return_tensors="pt")  # pad the labels to max length
		labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), DUMMY_TOKEN)  # replace padding with -100 to ignore loss correctly

		# The model prepends its decoder start token when it derives decoder inputs from the labels,
		# so a start token already sitting at position 0 must be cut here to avoid duplicating it.
		# The comparison uses the model's `decoder_start_token_id` rather than the tokenizer's
		# `bos_token_id`: for Whisper these differ, and the tokenized labels begin with the former.
		if (labels[:, 0] == MODEL.config.decoder_start_token_id).all().cpu().item():
			labels = labels[:, 1:]

		batch["labels"] = labels
		return batch

DATA_COLLATOR = DataCollatorSpeechSeq2SeqWithPadding()

In [None]:
# Text normalisation pipeline passed to `jiwer.wer` as both the reference and the hypothesis transform (see `compute_metrics`).
JIWER_TRANS = jiwer.Compose([  # DO NOT use `jiwer.RemoveEmptyStrings` it can cause rows count mismatch
	jiwer.ToLowerCase(),
	jiwer.RemoveKaldiNonWords(),
	jiwer.RemoveMultipleSpaces(),
	jiwer.Strip(),
	jiwer.RemovePunctuation(),
	# NOTE(review): presumably has to stay last so `jiwer.wer` receives lists of words — confirm
	jiwer.ReduceToListOfListOfWords(),
])

def compute_metrics(pred):
	"""Compute the word error rate of generated predictions against the reference labels.

	Args:
		pred: evaluation output exposing `predictions` (generated token ids) and
			`label_ids` (reference token ids, with `DUMMY_TOKEN` on padding positions).

	Returns:
		A dict with a single key, `"wer"`.

	Note:
		Both id arrays are modified in place (`DUMMY_TOKEN` is overwritten with the pad token id).
	"""
	pred_ids = pred.predictions
	if isinstance(pred_ids, tuple):  # some configurations return (token ids, extra outputs)
		pred_ids = pred_ids[0]
	label_ids = pred.label_ids

	# -100 is not a real token id and cannot be decoded: replace it with the pad token id,
	# which decoding with `skip_special_tokens=True` then drops.
	label_ids[label_ids == DUMMY_TOKEN] = TOKENIZER.pad_token_id
	# Generated sequences of different lengths can also be padded with -100 when batches are gathered,
	# so the predictions need the same treatment before decoding.
	pred_ids[pred_ids == DUMMY_TOKEN] = TOKENIZER.pad_token_id

	wer = jiwer.wer(  # we do not want to group tokens when computing the metrics
		reference=TOKENIZER.batch_decode(label_ids, skip_special_tokens=True),
		hypothesis=TOKENIZER.batch_decode(pred_ids, skip_special_tokens=True),
		reference_transform=JIWER_TRANS, hypothesis_transform=JIWER_TRANS
	)
	return {"wer": wer}

In [None]:
# Switch to the working directory; the relative `./my-whisper-tiny` output path resolves against it.
# Colab: mount gdrive using GUI before training
%cd '/content/drive/My Drive/coder'
# Kaggle alternative:
# %cd /kaggle/working
# uncomment to delete the output directory of a previous run:
# !rm -rf ./my-whisper-tiny

In [None]:
SAVE_PATH = "./my-whisper-tiny"  # checkpoint / output directory, relative to the current working directory
BATCH_SIZE = 16  # per device; should be a multiple of 8
# kaggle free P100 train faster than colab free T4
# kaggle free T4×2: no speed up but auto double batch size

# colab free tier can only run for 8-12h max daily
# kaggle free tier can only run for 30h max weekly but max 12h per session

TRAINING_ARGS = Seq2SeqTrainingArguments(
	output_dir=SAVE_PATH,
	per_device_train_batch_size=BATCH_SIZE,
	per_device_eval_batch_size=BATCH_SIZE,
	fp16=True,
	# bf16=True, tf32=True, torch_compile=True,  # GPU Ampere or later
	report_to=["tensorboard"],

	max_steps=21000,  # no `num_train_epochs` coz streaming
	logging_steps=25,
	save_steps=50,
	save_total_limit=3,
	# Evaluation stays at the library default ("no"), exactly as before. The explicit keyword is omitted
	# because `evaluation_strategy` was renamed `eval_strategy` in newer transformers releases and the
	# install cell upgrades to the latest one. To re-enable evaluation, uncomment both lines below
	# (use `evaluation_strategy` instead on older releases):
	# eval_strategy="steps",
	# eval_steps=50,

	learning_rate=3.75e-5,
	warmup_ratio=.05,  # keep between 5-15%
	# gradient_accumulation_steps=1,  # to increase if decrease batch size
	gradient_checkpointing=True,
	gradient_checkpointing_kwargs={"use_reentrant": False},
	predict_with_generate=True,
	# generation_num_beams=5,  # require more VRAM
	# load_best_model_at_end=True,
	# metric_for_best_model="wer",
	# greater_is_better=False,  # WER is better when lower
)

TRAINER = Seq2SeqTrainer(
	args=TRAINING_ARGS,
	model=MODEL,
	train_dataset=MY_DATA["train"],
	eval_dataset=MY_DATA["test"],  # only used if evaluation is re-enabled above
	data_collator=DATA_COLLATOR,
	compute_metrics=compute_metrics,
	# NOTE(review): recent transformers releases deprecate `tokenizer=` in favour of `processing_class=` — confirm against the installed version
	tokenizer=FEATURE_EXTRACTOR,  # not TOKENIZER
)

In [None]:
# Run training for `max_steps` steps, saving a checkpoint under `SAVE_PATH` every `save_steps` steps.
TRAINER.train()  # pass resume_from_checkpoint=True only when resuming an interrupted run

In [None]:
# Save the final model (no path given; presumably written to the trainer's `output_dir` — confirm),
# then bundle the whole output directory into `res.zip`.
TRAINER.save_model()
!zip -FSr res.zip ./my-whisper-tiny