# Fine-tune Whisper for Enenlhet on Google Colab

There are several challenges at the outset:

- Whisper is a sprawling, poorly maintained, and brittle ecosystem that has blockers practically built into it
- Enenlhet is a low-resource, endangered language, so it doesn't have any pre-built tokenizers or any other items that Whisper models require
- Google Colab is unstable and liable to crash or disconnect in the middle of fine-tuning

But, I have tried fine-tuning on a CPU, and that was a long nightmare of monkey code and hardware snafus.

I'm going to try here to make a notebook that will successfully fine-tune a Whisper model. To do that, I need to:

1. Set up the environment properly
    a. Install packages
    b. Log into the Hugging Face Hub
    c. Make sure we're using the GPU
    d. Create directories
    e. Set a random seed for reproducibility
    f. Set the name of the model we'll be fine-tuning.
2. Prepare the dataset
3. Download and prepare the model
4. Set up and initialize a trainer with arguments optimized for GPU
5. Set up a `compute_metrics` function to use Word Error Rate to measure the model's performance
6. Save the output
7. Upload the model to Hugging Face Hub.

## Set up the environment

### Install and import packages

Several packages are not installed by default on Google Colab, so they must be added.

In [18]:
# Install necessary packages
!pip install --quiet transformers datasets accelerate evaluate huggingface_hub codecarbon jiwer --quiet

In [19]:

# Import necessary libraries
from codecarbon import EmissionsTracker
from glob import glob
import json
from dataclasses import dataclass
from datasets import (
    Audio,
    Dataset,
    DatasetDict,
    load_dataset
)
import evaluate
from huggingface_hub import snapshot_download, notebook_login
import numpy as np
import os
import random
import torch
from transformers import (
    EarlyStoppingCallback,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    WhisperFeatureExtractor,
    WhisperTokenizer
)
from transformers.trainer_seq2seq import Seq2SeqTrainer
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
from typing import Any, Dict, List, Union

### Log into the Hugging Face Hub

In [20]:
# Authenticate with Hugging Face Hub
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

### Make sure that we're using a GPU

In [21]:
# Check if GPU is available
if not torch.cuda.is_available():
    raise RuntimeError("GPU is not available. Please enable GPU in 'Runtime > Change runtime type'.")
else:
    print("GPU is available:", torch.cuda.get_device_name(0))

GPU is available: NVIDIA A100-SXM4-40GB


### Create directories

In [22]:
# Create necessary directories
output_dir = "./enenlhet-whisper-model"
log_dir = "./logs"
dataset_dir = "./enenlhet-dataset"

# If directories do not exist, create them
os.makedirs(output_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
os.makedirs(dataset_dir, exist_ok=True)

### Set a random seed for reproducibility

In [23]:
# Set random seed for reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7b1a80abfcb0>

### Set the model_name variable

### Download the dataset

First I'll use `snapshot_download()` to download the content of the dataset repository.

In [24]:
whisper_dataset = load_dataset("enenlhet-asr/enenlhet-whisper-dataset")
# Ensure PyTorch tensors are returned
# whisper_dataset["train"].set_format(type="torch", columns=["input_features", "attention_mask", "labels"])
# whisper_dataset["validation"].set_format(type="torch", columns=["input_features", "attention_mask", "labels"])
# whisper_dataset["test"].set_format(type="torch", columns=["input_features", "attention_mask", "labels"])
# Use only the first 100 samples from the training set
# whisper_dataset["train"] = whisper_dataset["train"].select(range(100))
# whisper_dataset["validation"] = whisper_dataset["validation"].select(range(50))
# whisper_dataset["test"] = whisper_dataset["test"].select(range(50))

  table = cls._concat_blocks(blocks, axis=0)


In [25]:
# # Collect all label lengths from the training set
# label_lengths = [len(example["labels"]) for example in whisper_dataset["train"]]

# # Calculate the 95th percentile
# percentile_95 = int(np.percentile(label_lengths, 95))
# print(f"95th percentile label length: {percentile_95}")

In [26]:
# max_label = max(label_lengths)
# print(f"Max label length: {max_label}")

## Set up the model

In [27]:
# Define the name of the model once so it can be changed easily, if necessary
model_name = "openai/whisper-small"
# Load the model
model = WhisperForConditionalGeneration.from_pretrained(model_name)
# Get the processor
processor = WhisperProcessor.from_pretrained(model_name)
# Load the tokenizer
tokenizer = processor.tokenizer
tokenizer.pad_token = "<|pad|>"
tokenizer.add_tokens(["<|pad|>"])
model.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|pad|>")
model.resize_token_embeddings(len(tokenizer))
# Load the feature extractor
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
# Create the processor with feature extractor and tokenizer
processor = WhisperProcessor(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer
)

model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


In [28]:
# Set the device to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# Move the model to the GPU
model.to(device)

Using device: cuda


WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(80, 768, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(768, 768, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 768)
      (layers): ModuleList(
        (0-11): 12 x WhisperEncoderLayer(
          (self_attn): WhisperSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
        

### Create the data collator

A data collator takes elements from the prepared datasets and creates batches for passing to the model. It also applies extra processing steps, like padding and masking here, to ensure that all the inputs are the same length.

Note that `input_features` and `label_features` correspond to "audio" and text, respectively, in the original dataset.

In [29]:
# Define the custom DataCollator class
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # If the beginning of sentence (bos) token is appended in previous tokenization step,
        # cut it, since it's append later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

# Initialize the data collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

# Figure out the number of training steps
num_train_samples = len(whisper_dataset["train"])
num_eval_samples = len(whisper_dataset["validation"])
recommended_max_steps = (num_train_samples // 16) * 3

print(f"Maximum training steps: {recommended_max_steps}")

Maximum training steps: 942


In [30]:
from torch.utils.data import DataLoader
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch

# Sample dummy processor + model for context (skip if you already have these)
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

# Your custom collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

# Take a small subset of your dataset to test
sample_batch = [whisper_dataset["train"][i] for i in range(4)]

# Run the collator
batch = data_collator(sample_batch)

# Inspect shapes and tensor values
for k, v in batch.items():
    print(f"{k}: shape={v.shape}, dtype={v.dtype}")

# Optional: check if padding mask and labels match up
print("\nLabels (first example):")
print(batch["labels"][0])
print("\nInput features (first example):")
print(batch["input_features"][0].shape)

input_features: shape=torch.Size([4, 80, 3000]), dtype=torch.float32
labels: shape=torch.Size([4, 12]), dtype=torch.int64

Labels (first example):
tensor([ 2330,   297, 18275,   514,   463,     6,    64,  -100,  -100,  -100,
         -100,  -100])

Input features (first example):
torch.Size([80, 3000])


### Define a custom `compute_metrics() function

The best metric for an ASR model is Word Error Rate (WER), so the `compute_metrics()` function must focus on that.

In [31]:
# Evaluation metric
wer_metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # Print shapes and a few values
    print("pred_ids shape:", pred_ids.shape)
    print("label_ids shape:", label_ids.shape)
    print("First 10 pred_ids[0]:", pred_ids[0][:10])
    print("First 10 label_ids[0]:", label_ids[0][:10])

    # If pred_ids are logits (3D), take argmax
    if len(pred_ids.shape) == 3:
        print("Predictions appear to be logits. Taking argmax...")
        pred_ids = np.argmax(pred_ids, axis=-1)

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # Print decoded example
    print("Decoded pred:", tokenizer.decode(pred_ids[0], skip_special_tokens=True))
    print("Decoded label:", tokenizer.decode(label_ids[0], skip_special_tokens=True))

    # Decode all predictions and references
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    print("First pred_str:", pred_str[0])
    print("First label_str:", label_str[0])
    print("Lengths:", len(pred_str), len(label_str))

    wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

### Define the model's settings

The `training_args` variable holds many important settings that affect the outcome of the fine-tuning. The settings here are optimized for use when fine-tuning on a GPU.

I'm going to explain the settings, even though they are [documented on Hugging Face](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments), since I want to be sure that I understand them. 🤓

- `output_dir`: This is where the model's files will be stored. It was defined earlier in the notebook.
- `per_device_train_batch_size`: The number of input and label pairs included per batch sent to the device (GPU) for training.
- `per_device_eval_batch_size`: The number of input and label pairs included per batch sent to the device (GPU) for evaluation.
- `gradient_accumulation_steps`: The trainer will perform a backward pass after two steps. The backward pass is part of the learning process, where the model makes adjustments based on what it has learned to that point.
- `learning_rate`: This is the, well, learning rate for the optimizer. I have selected `1.25e-5` (i.e., 1.25 x 10 − 5) as suggested at <https://github.com/vasistalodagala/whisper-finetune>
- `warmup_steps`: The warmup process helps to avoid any big changes in the model's settings at the start of training.
- `max_steps`: This is the number of steps that the training will last. I calculated it above by dividing the length of the training set by the number I'd select for the `per_device_train_batch_size`.
- `eval_strategy`: The evaluation (WER) will be performed after the number of steps assigned in `eval_steps`.
- `eval_steps`: This is calculated above by dividing the length of the validation set by the number assigned to `per_device_eval_batch_size`.
- `save_steps`: How often the model will be saved. I'm saving as often as I evaluate.
- `save_total_limit`: I don't want to fill up my space with checkpoints, so I'm setting it to 2.
- `logging_steps`: How often information will be logged
- `predict_with_generate`: This is set to `False` because I'm not using the model to generate text.
- `report_to`: This will send the log data to Tensorboard, which is a nice way of visualizing the information.
- `load_best_model_at_end`: This ensures that only the best model is loaded for saving.
- `metric_for_best_model`: Defines Word Error Rate as the metric for determining the best model.
- `greater_is_better`: This is set to false because a lower WER is better.
- `fp16`: This is a performance boost. It tells the trainer to use 16-bit floating point numbers instead of the default 32-bit, which take longer to calculate.
- `gradient_checkpointing`: Another performance boost by storing only a small number of checkpoints and recomputing during the backward pass.
- `hub_model_id`: Identifies the model repo on Hugging Face Hub.
- ` hub_strategy`: Set to `end` to push to the hub when the trainer has finished.
- `push_to_hub`: Pushes the model to the hub at the end of the training.

I have used steps instead of epochs because Whisper typically isn't trained for more than two or three epochs for a small dataset.

In [32]:
# # Define the training arguments
# training_args = Seq2SeqTrainingArguments(
#     output_dir=output_dir,                      # Output directory for model predictions and checkpoints
#     per_device_train_batch_size=4,             # Batch size for training
#     per_device_eval_batch_size=2,               # Batch size for evaluation
#     gradient_accumulation_steps=2,              # Number of updates steps to accumulate before performing a backward/update pass
#     learning_rate=1.25e-5,                      # Initial learning rate for the optimizer
#     #warmup_steps=200,                           # Number of warmup steps for learning rate scheduler
#     max_steps=max_steps,                        # Total number of training steps to perform
#     eval_strategy="no",                      # Evaluation strategy to adopt during training
#     #eval_steps=500,                             # Number of steps between evaluations
#     save_steps=500,                             # Number of steps between model saves
#     save_total_limit=2,                         # Limit the total amount of checkpoints. Deletes the older checkpoints.
#     logging_steps=25,                           # Number of steps between logging
#     #predict_with_generate=False,                # Whether to use generate for predictions
#     report_to="none",                           # List of integrations to report the results
#     #generation_max_length=64,                   # Maximum length of generated sequences
#     #generation_num_beams=1,                     # Number of beams to use for generation
#     #load_best_model_at_end=True,                # Load the best model when finished training
#     #metric_for_best_model="wer",                # Use word error rate (WER) to evaluate the best model
#     #greater_is_better=False,                    # WER is lower when better, so we set this to False
#     fp16=True,                                  # GPU: use FP16 for speed
#     gradient_checkpointing=True,                # Helps with memory on GPU
#     hub_model_id="sjhuskey/enenlhet-whisper",   # Hugging Face Hub model ID
#     hub_strategy="end",                         # Hub strategy to use when pushing model
#     push_to_hub=True,                           # Set to True if pushing model
# )

In [33]:
training_args = Seq2SeqTrainingArguments(
    run_name="enenlhet-whisper",
    output_dir=output_dir,  # change to a repo name of your choice
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=100,
    max_steps=4710,
    gradient_checkpointing=True,
    fp16=True,
    eval_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=65,
    save_steps=200,
    save_total_limit=2,
    eval_steps=200,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    hub_model_id="enenlhet-asr/enenlhet-whisper",   # Hugging Face Hub model ID
    hub_strategy="end",                         # Hub strategy to use when pushing model
    push_to_hub=True,                           # Set to True if pushing model
)

### Initialize the trainer

The `trainer` gets some additional settings here, including the splits to use for training and testing. Note that most of the settings just point back to previously defined variables.

In [34]:
# model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="spanish", task="transcribe")
model.config.forced_decoder_ids = None

In [35]:
# Initialize the trainer
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=whisper_dataset["train"],
    eval_dataset=whisper_dataset["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

### Initialize an emissions tracker

In [36]:
# tracker = EmissionsTracker(
#     project_name="whisper-enenlhet-gpu",
#     output_dir=log_dir,
#     output_file="whisper-emissions-gpu.csv"
# )

## Train the model

I have wrapped the training step in `try` and `except` to handle the seemingly inevitable problems that crop up at this stage. If all went well with the previous steps, the model should start training. If not, then debugging is the name of the game.

In [37]:
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
try:
    # Start emissions tracking
    print("Starting emissions tracking...")
    # tracker.start()
    # Train the model
    print("Starting training...")
    train_result = trainer.train()
    # Stop emissions tracking
    # tracker.stop()
    print("\n Training completed successfully!")
    print(f" Final training loss: {train_result.training_loss:.4f}")
    print(f" Total training steps: {train_result.global_step}")

    # Final evaluation
    print("\n Running final evaluation...")
    eval_result = trainer.evaluate()
    print(f"Final WER: {eval_result['eval_wer']:.4f}")

    # Save final model
    print(f"Saving final model to {training_args.output_dir}")
    trainer.save_model()
    processor.save_pretrained(training_args.output_dir)

    print("Training and saving completed successfully!")

# If nothing went as planned …
except Exception as e:
    print(f"\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()

    # Still try to save current model state
    print("💾 Attempting to save current model state...")
    try:
        trainer.save_model()
        processor.save_pretrained(training_args.output_dir)
        print("Model saved despite training error")
    except:
        print("Could not save model")

Starting emissions tracking...
Starting training...


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


Step,Training Loss,Validation Loss,Wer
200,4.8362,4.636086,433.616384
400,3.2841,3.34328,701.998002
600,2.9488,3.11349,489.010989
800,2.6678,3.076094,401.848152
1000,2.2033,3.02431,516.783217
1200,2.2217,2.991132,562.187812
1400,1.9348,3.058273,410.689311


Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


pred_ids shape: (630, 448)
label_ids shape: (630, 448)
First 10 pred_ids[0]: [  297   304    71  1501   330   287    71  1857 31494   287]
First 10 label_ids[0]: [   77  1684 23255 20332   364  2918     6    64  -100  -100]
Decoded pred:  nalhengke lhennenga lhennenga lhennenga lhennenga lhennenga lhennenga lhennenga lhennenga lhennenga lhennenga lhennenga lhennenga lhennenga lhennenga
Decoded label: nietnek nak aniam'a
First pred_str:  nalhengke lhennenga lhennenga lhennenga lhennenga lhennenga lhennenga lhennenga lhennenga lhennenga lhennenga lhennenga lhennenga lhennenga lhennenga
First label_str: nietnek nak aniam'a
Lengths: 630 630
pred_ids shape: (630, 448)
label_ids shape: (630, 448)
First 10 pred_ids[0]: [ 2012  1301     6   514  1641 17342  1301     6   514 31332]
First 10 label_ids[0]: [   77  1684 23255 20332   364  2918     6    64  -100  -100]
Decoded pred:  Amai'akha svai'ak kelha kelha kelha kelha kelha kelha kelha kelha kelha kelha kelha kelha kelha kelha kelha kelha ke

There were missing keys in the checkpoint model loaded: ['proj_out.weight'].



 Training completed successfully!
 Final training loss: 3.3636
 Total training steps: 1400

 Running final evaluation...


pred_ids shape: (630, 448)
label_ids shape: (630, 448)
First 10 pred_ids[0]: [  287 39903   220   220   220   220   220   220   220   220]
First 10 label_ids[0]: [   77  1684 23255 20332   364  2918     6    64  -100  -100]
Decoded pred:  lhta                                                           
Decoded label: nietnek nak aniam'a
First pred_str:  lhta                                                           
First label_str: nietnek nak aniam'a
Lengths: 630 630
Final WER: 401.8482
Saving final model to ./enenlhet-whisper-model


Uploading...:   0%|          | 0.00/967M [00:00<?, ?B/s]

Training and saving completed successfully!
