# Fine-tuning Multilingual Whisper

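In this notebook, we fine-tune the multilingual Whisper small model on the Khanty dataset. The fine-tuned ASR model reaches a word error rate (WER) of 44.57% on the held-out test split at the final training step; the lowest WER logged during training is 44.20%, at step 3,000. The notebook is adapted from https://huggingface.co/blog/fine-tune-whisper, which explains the code in more detail.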

In [None]:
# Install the required libraries
!pip install --upgrade pip
!pip install --upgrade datasets[audio] transformers accelerate evaluate jiwer tensorboard gradio

Collecting pip
  Downloading pip-24.0-py3-none-any.whl (2.1 MB)
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.1.2
    Uninstalling pip-23.1.2:
      Successfully uninstalled pip-23.1.2
Successfully installed pip-24.0
Collecting accelerate
  Downloading accelerate-0.29.3-py3-none-any.whl.metadata (18 kB)
Collecting evaluate
  Downloading evaluate-0.4.2-py3-none-any.whl.metadata (9.3 kB)
Collecting jiwer
  Downloading jiwer-3.0.3-py3-none-any.whl.metadata (2.6 kB)
Collecting tensorboard
  Downloading tensorboard-2.16.2-py3-none-any.whl.metadata (1.6 kB)
Collecting gradio
  Downloading gradio-4.28.3-py3-none-any.whl.metadata (15 kB)
Collecting datasets[audio]
  Downloading datasets-2.19.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets[audio])
  Downloading dill-0.3.8-py3-none-an

In [None]:
from google.colab import drive

drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
import pandas as pd
df = pd.read_csv('/content/gdrive/MyDrive/folder_for_export/metadata.csv')
df['transcription'] = df['transcription'].fillna('silence')


In [None]:
df.to_csv('metadata.csv')

In [None]:
# Load the dataset created in creating_asr_dataset.ipynb
from datasets import load_dataset
dataset = load_dataset("audiofolder", data_dir="/content/gdrive/MyDrive/folder_for_export")

Resolving data files:   0%|          | 0/3564 [00:00<?, ?it/s]

In [None]:
# Split the data into train and test sets
dataset = dataset['train'].train_test_split(test_size=0.2)
dataset = dataset.remove_columns(['Unnamed: 0'])
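The split above is unseeded, so each run produces a different train/test partition. In the notebook itself, passing a `seed` argument to `train_test_split` should make the split repeatable. The sketch below only illustrates the idea with the standard library: `seeded_split` is a hypothetical helper, the example count of 3563 is inferred from the 2850 + 713 examples reported by the `map` progress bars later on, and rounding the test size up is an assumption that happens to reproduce those sizes.

```python
import math
import random


def seeded_split(n_examples, test_size=0.2, seed=42):
    """Return (train_indices, test_indices) for a shuffled, seeded split.

    Hypothetical helper: it mimics an 80/20 split on indices only.
    """
    indices = list(range(n_examples))
    # A dedicated Random instance with a fixed seed gives the same order every run.
    random.Random(seed).shuffle(indices)
    n_test = math.ceil(n_examples * test_size)  # assumption: test size is rounded up
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = seeded_split(3563)
print(len(train_idx), len(test_idx))  # 2850 713

# The same seed reproduces exactly the same partition.
assert seeded_split(3563) == (train_idx, test_idx)
```

A fixed seed matters here because the reported WER depends on which utterances end up in the test set.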

In [None]:
from huggingface_hub import notebook_login
notebook_login()


In [None]:
# Load the Whisper feature extractor
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
# Load the Whisper tokenizer for Hungarian (the closest high-resource language to Khanty that Whisper supports)
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="hungarian", task="transcribe")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
# Check that the tokenizer round-trips a Khanty transcription
input_str = dataset["train"][0]["transcription"]
labels = tokenizer(input_str).input_ids
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)

print(f"Input:                 {input_str}")
print(f"Decoded w/ special:    {decoded_with_special}")
print(f"Decoded w/out special: {decoded_str}")
print(f"Are equal:             {input_str == decoded_str}")


Input:                 śăta rəpitti mănԑma iśi əmăś pităs śit mit oλ wəs
Decoded w/ special:    <|startoftranscript|><|hu|><|transcribe|><|notimestamps|>śăta rəpitti mănԑma iśi əmăś pităs śit mit oλ wəs<|endoftext|>
Decoded w/out special: śăta rəpitti mănԑma iśi əmăś pităs śit mit oλ wəs
Are equal:             True
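The "Decoded w/ special" line shows what the tokenizer adds around each transcription: a start token, a language token, a task token, a no-timestamps token, and an end-of-text token. The sketch below rebuilds that string layout with plain string formatting. `wrap_with_special_tokens` is a hypothetical helper written only for illustration; it works on text, not on token ids, and is not part of the transformers API.

```python
def wrap_with_special_tokens(text, language_code="hu", task="transcribe"):
    """Rebuild the decoded-with-special-tokens string layout shown above."""
    prefix = f"<|startoftranscript|><|{language_code}|><|{task}|><|notimestamps|>"
    return f"{prefix}{text}<|endoftext|>"


sample = "śăta rəpitti mănԑma iśi əmăś pităs śit mit oλ wəs"
print(wrap_with_special_tokens(sample))
```

Because the language token is `<|hu|>`, every training label tells the decoder to treat the Khanty audio as if it were Hungarian.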


In [None]:
# Load the Whisper processor (feature extractor + tokenizer)
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="hungarian", task="transcribe")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
from datasets import Audio

dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))


In [None]:
print(dataset["train"][0])

{'audio': {'path': '/content/gdrive/MyDrive/folder_for_export/data/ua_punshum_yasa_samn_tayae_27.wav', 'array': array([-0.0042136 , -0.00557795, -0.00451304, ..., -0.00035517,
        0.00120888,  0.        ]), 'sampling_rate': 16000}, 'transcription': 'śăta rəpitti mănԑma iśi əmăś pităs śit mit oλ wəs'}


In [None]:
def prepare_dataset(batch):
    # accessing "audio" decodes the file and resamples it to 16 kHz (see cast_column above)
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids
    batch["labels"] = tokenizer(batch["transcription"]).input_ids
    return batch
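To see what `input_features` looks like for each example, the arithmetic below works out the shape of the log-Mel spectrogram. It assumes the usual defaults of the `openai/whisper-small` feature extractor (30-second chunks, a hop length of 160 samples, 80 Mel bins); these values are not printed anywhere in this notebook, so treat them as assumptions.

```python
# Assumed defaults of the whisper-small feature extractor (not shown in this notebook).
SAMPLING_RATE = 16_000   # Hz, matches the cast_column call above
CHUNK_SECONDS = 30       # every clip is padded or truncated to this length
HOP_LENGTH = 160         # samples between consecutive spectrogram frames
N_MELS = 80              # Mel frequency bins

n_samples = SAMPLING_RATE * CHUNK_SECONDS   # samples per padded clip
n_frames = n_samples // HOP_LENGTH          # spectrogram frames per clip

print(n_samples)            # 480000
print((N_MELS, n_frames))   # (80, 3000)
```

Under these assumptions every example has the same feature shape regardless of the clip's real duration, which is why the data collator later only needs to pad the labels.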


In [None]:
dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names["train"], num_proc=2)


Map (num_proc=2):   0%|          | 0/2850 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/713 [00:00<?, ? examples/s]

In [None]:
# Load the pre-trained Whisper model for conditional generation
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

In [None]:
model.generation_config.language = "hungarian"
model.generation_config.task = "transcribe"

model.generation_config.forced_decoder_ids = None


In [None]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if the bos token was prepended in the previous tokenization step,
        # cut it here, since it is added again later anyway
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch
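The label handling in the collator can be sketched with plain Python lists. This is a simplified illustration, not the collator itself: `pad_and_mask` is a hypothetical helper, the token ids are toy values, and it writes -100 directly instead of padding with the pad token and then masking through the attention mask.

```python
IGNORE_INDEX = -100  # positions with this label are skipped by the loss


def pad_and_mask(label_seqs, bos_id):
    """Pad label sequences to equal length with -100 and strip a shared leading BOS."""
    max_len = max(len(seq) for seq in label_seqs)
    padded = [seq + [IGNORE_INDEX] * (max_len - len(seq)) for seq in label_seqs]
    # Drop the first column only if every sequence starts with the BOS token.
    if all(row[0] == bos_id for row in padded):
        padded = [row[1:] for row in padded]
    return padded


toy_labels = [[1, 5, 6, 2], [1, 7, 2]]   # toy ids: 1 = BOS, 2 = EOS
print(pad_and_mask(toy_labels, bos_id=1))  # [[5, 6, 2], [7, 2, -100]]
```

The shorter sequence is filled with -100 so that padding positions do not contribute to the training loss.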


In [None]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)


In [None]:
import evaluate

metric = evaluate.load("wer")


In [None]:
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}
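For readers unfamiliar with the metric, WER is the word-level edit distance (substitutions, deletions, insertions) between prediction and reference, divided by the number of reference words. The function below is a minimal pure-Python sketch of that definition, pooled over all utterances; `word_error_rate` is a hypothetical name, and the notebook itself relies on `evaluate.load("wer")` rather than this code.

```python
def word_error_rate(references, predictions):
    """Corpus-level WER in percent: total word edits / total reference words."""
    total_edits = 0
    total_words = 0
    for ref, hyp in zip(references, predictions):
        ref_words, hyp_words = ref.split(), hyp.split()
        # prev[j] = edit distance between the reference prefix so far and hyp_words[:j]
        prev = list(range(len(hyp_words) + 1))
        for i, ref_word in enumerate(ref_words, start=1):
            curr = [i]
            for j, hyp_word in enumerate(hyp_words, start=1):
                cost = 0 if ref_word == hyp_word else 1
                curr.append(min(prev[j] + 1,          # deletion
                                curr[j - 1] + 1,      # insertion
                                prev[j - 1] + cost))  # match or substitution
            prev = curr
        total_edits += prev[-1]
        total_words += len(ref_words)
    return 100 * total_edits / total_words


refs = ["śăta rəpitti mănԑma iśi"]
hyps = ["śăta rəpiti mănԑma"]  # one misspelled word, one missing word
print(word_error_rate(refs, hyps))  # 50.0
```

Since WER counts whole words, a single wrong character makes the entire word an error, so morphologically rich languages tend to score worse on it than a character-level metric would suggest.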


In [None]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="khanty_whisper_asr",  # change to a repo name of your choice
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=5000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
    # hub_token is omitted: notebook_login() above already stores the Hub credential (never hard-code tokens)
)
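With `learning_rate=1e-5`, `warmup_steps=500` and `max_steps=5000`, the learning rate is not constant during training. The sketch below assumes the Trainer's default linear scheduler, which this notebook does not override: the rate climbs linearly from 0 to the peak during warmup and then decays linearly to 0 at the last step. `linear_schedule_lr` is a hypothetical helper that mirrors that formula; it is not the library's implementation.

```python
def linear_schedule_lr(step, base_lr=1e-5, warmup_steps=500, max_steps=5000):
    """Learning rate at a given step under linear warmup followed by linear decay."""
    if step < warmup_steps:
        # Warmup: ramp from 0 up to base_lr.
        return base_lr * step / warmup_steps
    # Decay: ramp from base_lr at the end of warmup down to 0 at max_steps.
    remaining = max(0, max_steps - step)
    return base_lr * remaining / (max_steps - warmup_steps)


for step in (0, 250, 500, 2750, 5000):
    print(step, linear_schedule_lr(step))
```

Under this assumption the peak rate of 1e-5 is only reached at step 500, and by step 2750 it has already fallen to half of that.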


In [None]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)


In [None]:
# Train the model
trainer.train()

`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


Step,Training Loss,Validation Loss,Wer
1000,0.0936,0.657697,47.867536
2000,0.0069,0.779062,45.576183
3000,0.0006,0.815544,44.204717
4000,0.0003,0.842916,44.50577
5000,0.0003,0.854894,44.572671
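The training arguments set `load_best_model_at_end=True` with `metric_for_best_model="wer"` and `greater_is_better=False`, so the checkpoint kept at the end should be the one with the lowest WER, not the last one. The short sketch below applies that selection rule to the evaluation log above; the values are copied from the table, and the selection code is only an illustration of the rule, not the Trainer's internals.

```python
# (step, validation loss, WER in percent), copied from the evaluation log above
eval_log = [
    (1000, 0.657697, 47.867536),
    (2000, 0.779062, 45.576183),
    (3000, 0.815544, 44.204717),
    (4000, 0.842916, 44.505770),
    (5000, 0.854894, 44.572671),
]

# greater_is_better=False means the smallest WER wins.
best_step, best_loss, best_wer = min(eval_log, key=lambda row: row[2])
print(best_step, best_wer)  # 3000 44.204717
```

Note that validation loss rises at every evaluation while WER improves until step 3000 and then drifts up slightly, which suggests that most of the useful learning happens in the first few thousand steps.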


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}

TrainOutput(global_step=5000, training_loss=0.1792287979803048, metrics={'train_runtime': 23675.4405, 'train_samples_per_second': 3.379, 'train_steps_per_second': 0.211, 'total_flos': 2.297774674427904e+19, 'train_loss': 0.1792287979803048, 'epoch': 27.932960893854748})
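The numbers in `TrainOutput` can be cross-checked with a little arithmetic. The sketch below assumes a single device (so the per-device batch size of 16 is the total batch size) and that the last, partial batch of each epoch counts as a step; the 2850 training examples come from the `map` progress bar earlier in the notebook.

```python
import math

train_examples = 2850          # from the map progress bar
batch_size = 16                # per_device_train_batch_size, assuming one device
max_steps = 5000
train_runtime_s = 23675.4405   # 'train_runtime' in TrainOutput

steps_per_epoch = math.ceil(train_examples / batch_size)   # partial last batch counts
epochs = max_steps / steps_per_epoch
samples_per_second = max_steps * batch_size / train_runtime_s
steps_per_second = max_steps / train_runtime_s
hours = train_runtime_s / 3600

print(steps_per_epoch)               # 179
print(round(epochs, 2))              # 27.93
print(round(samples_per_second, 3))  # 3.379
print(round(steps_per_second, 3))    # 0.211
print(round(hours, 1))               # 6.6
```

These values agree with the reported `epoch`, `train_samples_per_second` and `train_steps_per_second`, so the run made roughly 28 passes over the training set in about 6.6 hours.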