# Fine-tuning Whisper

This notebook fine-tunes Whisper on French. Out of the box, the default multilingual Whisper model seems to perform rather poorly on French.

This notebook roughly follows [this blog post](https://huggingface.co/blog/fine-tune-whisper).

**Goal**: Fine-tune `whisper-tiny` to have medium to high performance on French-language input *without* timestamps.

In [1]:
!pip install --upgrade pip
# jiwer is used for the word error rate (WER) metric
!pip install --upgrade datasets[audio] transformers evaluate jiwer



In [2]:
!pip install pyspellchecker==0.8.1



In [3]:
import wandb
# See https://discuss.huggingface.co/t/how-to-turn-wandb-off-in-trainer/6237/10
wandb.init(mode='disabled')

In [4]:
from pathlib import Path

checkpoint_remote_path = Path('./final-checkpoints').resolve()
def connect_to_google_drive():
    """ Mounts Google Drive and returns the directory that final checkpoints
        should be copied to. """
    from google.colab import drive

    drive.mount('/content/drive')
    return Path('/content/drive/My Drive') / 'whisper' / 'checkpoints'

# Optional:
#checkpoint_remote_path = connect_to_google_drive()

In [5]:
if not checkpoint_remote_path.parent.exists():
    checkpoint_remote_path.parent.mkdir(parents=True)

In [6]:
checkpoint_path = Path('./whisper/checkpoints').resolve()

In [7]:
import shutil


## Load data

The [voxpopuli](https://huggingface.co/datasets/facebook/voxpopuli/viewer/fr/train?f%5Braw_text%5D%5Bmin%5D=236&f%5Braw_text%5D%5Bmax%5D=354&f%5Braw_text%5D%5Btransform%5D=length&row=45) and [CommonVoice](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) datasets will be used to fine-tune Whisper. We'll also modify the CommonVoice data so that the training data is multi-sentence.

To speed up processing later on, we download the full training split at once (`streaming=False`) and convert it to an iterable dataset; the test split is streamed. The initial download may take some time.

In [8]:
from datasets import load_dataset, IterableDatasetDict, interleave_datasets

def load_dataset_from_id(dataset_id: str):
    data_raw = IterableDatasetDict()

    data_raw['train'] = load_dataset(dataset_id, 'fr', split='train', streaming=False).to_iterable_dataset()
    print('Loaded training data. Loading test data:')
    data_raw['test'] = load_dataset(dataset_id, 'fr', split='test', streaming=True)
    return data_raw


In [11]:
print('Loading Voxpopuli')
voxpopuli_data_raw = load_dataset_from_id('facebook/voxpopuli').rename_column('raw_text', 'text')


Loading Voxpopuli



Loaded training data. Loading test data:


In [12]:
print('Loading CommonVoice...')
common_voice_data_raw = load_dataset_from_id('mozilla-foundation/common_voice_11_0')\
    .rename_column('sentence', 'text')

Loading CommonVoice...




Loaded training data. Loading test data:


In [13]:
# Normalize text
import unicodedata, re
from datasets import Audio

audioFeature = Audio(sampling_rate=16_000)

def normalize_text(text: str):
    replacements = [
        ['’', '\''],
        ['‘', '\''],
        ['́a', 'á'], # Convert from two-character á to one-character á
        ['́u', 'ú'],
        ['́e', 'é'],
        ['̀e', 'è'],
        ['̀a', 'à'],
        # Some characters don't work with the GGML conversion script:
        ['œ', '[oe]'],
        ['́', '\''],
        ['̂', '\''],
        ['̀', '\''],
        ['—', '--'],
        ['…', '...'],
        ['の', ''],
    ]
    for [orig, replace] in replacements:
        text = text.replace(orig, replace)

    if len(text) == 0:
        return '[BLANK_AUDIO]'

    # Capitalize the first character
    return text[0].upper() + text[1:]

def normalize_texts(batch):
    batch['text'] = [normalize_text(text) for text in batch['text']]
    # Re-encode the audio-data: This seems necessary because after mapping,
    # `datasets` attempts to **re-decode audio-data**. Since double-decoding
    # breaks things, we encode the audio here.
    batch['audio'] = [audioFeature.encode_example(audio) for audio in batch['audio']]
    return batch

def normalize_dataset_text(dataset):
    def normalize_dataset_part(dataset):
        return dataset.map(
            normalize_texts,
            batched=True,
            features=dataset.features
        )
    dataset['train'] = normalize_dataset_part(dataset['train'])
    dataset['test'] = normalize_dataset_part(dataset['test'])

normalize_dataset_text(voxpopuli_data_raw)

In [14]:

def cast_audio(data):
    return data.cast_column('audio', audioFeature)

voxpopuli_data = cast_audio(voxpopuli_data_raw)

In [15]:
next(iter(voxpopuli_data['train']))

{'audio_id': '20200212-0900-PLENARY-fr_20200212-18:11:25_1',
 'language': 2,
 'audio': {'path': None,
  'array': array([ 3.66210938e-04, -9.15527344e-05, -2.74658203e-04, ...,
         -6.40869141e-04, -7.93457031e-04, -9.15527344e-04]),
  'sampling_rate': 16000},
 'text': 'Notre délégation défendra la lutte contre les écarts salariaux à travail égal, contre les cyberviolences, les mariages forcés et les mutilations génitales.',
 'normalized_text': 'notre délégation défendra la lutte contre les écarts salariaux à travail égal contre les cyberviolences les mariages forcés et les mutilations génitales.',
 'gender': 'female',
 'speaker_id': '182995',
 'is_gold_transcript': True,
 'accent': 'None'}

In [16]:
from random import randint
import numpy as np

def combine_sentences(batch):
    # See https://github.com/huggingface/datasets/issues/5361
    if len(batch['audio']) > 0:
        joinedAudio = audioFeature.encode_example({
            'array': np.concatenate([ audio['array'] for audio in batch['audio'] ]),
            'sampling_rate': batch['audio'][0]['sampling_rate']
        })
        batch['audio'] = [joinedAudio]
        batch['text'] = [ normalize_text(' '.join(batch['text'])) ]
        return batch
    else:
        return batch
common_voice_data = common_voice_data_raw
common_voice_data = common_voice_data\
    .remove_columns(['accent', 'age', 'client_id', 'locale', 'segment', 'gender', 'up_votes', 'down_votes', 'path'])


def map_subdataset(key: str):
    common_voice_data[key] = common_voice_data[key].map(
        combine_sentences,
        batched=True,
        batch_size=3,
        # Pass features to allow casting audio later. See https://github.com/huggingface/datasets/issues/5828
        features=common_voice_data[key].features
    )
common_voice_data = cast_audio(common_voice_data)
map_subdataset('train')
map_subdataset('test')
common_voice_data

IterableDatasetDict({
    train: IterableDataset({
        features: ['audio', 'text'],
        num_shards: 1
    })
    test: IterableDataset({
        features: ['audio', 'text'],
        num_shards: 1
    })
})

In [17]:
test_data = next(iter(common_voice_data['train']))
print(test_data)

{'audio': {'path': None, 'array': array([-3.05175781e-05, -3.05175781e-05,  0.00000000e+00, ...,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00]), 'sampling_rate': 16000}, 'text': 'Il est dissous à Trèves. Les candidats sont présentés par le Fili. Il y rencontre plusieurs éditeurs dont il rejette pourtant les propositions de traductions.'}
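The sample above joins three Common Voice sentences into one clip. As a rough illustration of what `combine_sentences` does for each batch of three rows, here is a pure-Python sketch that uses plain lists instead of NumPy arrays. `group_clips` is a hypothetical helper written only for this illustration; unlike the notebook's function, it does not re-run `normalize_text` on the joined transcript or re-encode the audio.

```python
def group_clips(clips, group_size=3):
    """Hypothetical helper: merge consecutive (samples, transcript) clips.

    Audio sample lists are concatenated and transcripts are joined with a
    space, similar to `combine_sentences` mapped with `batch_size=3`.
    """
    combined = []
    for start in range(0, len(clips), group_size):
        group = clips[start:start + group_size]
        audio = [sample for samples, _ in group for sample in samples]
        text = ' '.join(transcript for _, transcript in group)
        combined.append((audio, text))
    return combined

clips = [
    ([0.1, 0.2], 'Il est dissous à Trèves.'),
    ([0.3], 'Les candidats sont présentés par le Fili.'),
    ([0.4, 0.5], 'Il y rencontre plusieurs éditeurs.'),
    ([0.6], 'Bonjour.'),
]
for audio, text in group_clips(clips):
    print(len(audio), text)
```

As with a batched `map` using `batch_size=3`, the final group may contain fewer than three clips.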


In [18]:
from IPython.display import Audio as AudioDisplay
AudioDisplay(test_data['audio']['array'], rate=test_data['audio']['sampling_rate'])

In [19]:

voice_data = IterableDatasetDict()
voice_data['train'] = interleave_datasets([
    voxpopuli_data['train'], common_voice_data['train'],
])
voice_data['test'] = interleave_datasets([
    voxpopuli_data['test'], common_voice_data['test'],
])
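`interleave_datasets` is called here without `probabilities`, so, as we read the `datasets` documentation, it alternates between Voxpopuli and Common Voice one example at a time and, with the default `stopping_strategy='first_exhausted'`, stops as soon as either source runs out. The sketch below approximates that behaviour in plain Python; `interleave_first_exhausted` is a hypothetical name, not a `datasets` API.

```python
from itertools import cycle

def interleave_first_exhausted(*sources):
    """Hypothetical sketch: round-robin over sources, stop when any runs out."""
    iterators = [iter(source) for source in sources]
    output = []
    for iterator in cycle(iterators):
        try:
            output.append(next(iterator))
        except StopIteration:
            break  # One source is exhausted: stop immediately
    return output

voxpopuli_ids = ['vp-1', 'vp-2', 'vp-3']
common_voice_ids = ['cv-1', 'cv-2']
print(interleave_first_exhausted(voxpopuli_ids, common_voice_ids))
```

This is consistent with the first training sample printed further down being a Voxpopuli clip.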

In [20]:
voice_data = voice_data\
    .remove_columns(['gender', 'normalized_text', 'accent', 'is_gold_transcript', 'audio_id', 'language'])

voice_data

IterableDatasetDict({
    train: IterableDataset({
        features: ['audio', 'text', 'speaker_id'],
        num_shards: 1
    })
    test: IterableDataset({
        features: ['audio', 'text', 'speaker_id'],
        num_shards: 1
    })
})

In [21]:
print(next(iter(voice_data['test'])))

{'audio': {'path': None, 'array': array([-0.01843262, -0.00570679,  0.01486206, ...,  0.00610352,
        0.00418091,  0.00317383]), 'sampling_rate': 16000}, 'text': 'Je souhaite juste rappeler que ces droits de plantation étaient intégrés dans la réforme de 2008, qui a été adoptée par le Conseil des ministres. Mais malgré cela et au delà de cela, dans le cadre de cette réforme 2013, nous sommes revenus sur cette question, compte tenu aussi du rapport', 'speaker_id': 'None'}


The GGML conversion script has trouble with some characters (e.g. `\u0301`, the combining acute accent). `normalize_text`, defined above, replaces these characters early so that they won't appear in the updated vocabulary. Printing the first training sample after these steps:



In [22]:


print(next(iter(voice_data['train'])))

{'audio': {'path': None, 'array': array([ 3.66210938e-04, -9.15527344e-05, -2.74658203e-04, ...,
       -6.40869141e-04, -7.93457031e-04, -9.15527344e-04]), 'sampling_rate': 16000}, 'text': 'Notre délégation défendra la lutte contre les écarts salariaux à travail égal, contre les cyberviolences, les mariages forcés et les mutilations génitales.', 'speaker_id': '182995'}
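To see which characters are affected, Python's standard `unicodedata` module can flag combining marks such as U+0301. The sketch below is not part of the notebook: `find_combining_marks` is a hypothetical helper, and `unicodedata.normalize('NFC', ...)` is shown only as one possible complement to the hand-written replacements in `normalize_text`. NFC composes a base letter *followed by* a combining accent into a single code point; it would not fix an accent that precedes its letter.

```python
import unicodedata

def find_combining_marks(text: str) -> list[str]:
    """Hypothetical helper: list the combining characters (e.g. U+0301) in `text`."""
    return [f'U+{ord(ch):04X}' for ch in text if unicodedata.combining(ch)]

decomposed = 'e\u0301te\u0301'  # 'été' written with combining acute accents
print(len(decomposed), find_combining_marks(decomposed))

composed = unicodedata.normalize('NFC', decomposed)  # Compose into single code points
print(len(composed), find_combining_marks(composed))
```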


## Inspecting a sample

Let's check that the expected columns are still present in the training data:

In [23]:
sample = next(iter(voice_data['train']))

In [24]:
sample

{'audio': {'path': None,
  'array': array([ 3.66210938e-04, -9.15527344e-05, -2.74658203e-04, ...,
         -6.40869141e-04, -7.93457031e-04, -9.15527344e-04]),
  'sampling_rate': 16000},
 'text': 'Notre délégation défendra la lutte contre les écarts salariaux à travail égal, contre les cyberviolences, les mariages forcés et les mutilations génitales.',
 'speaker_id': '182995'}

## Create the feature extractor and tokenizer

We'll be fine-tuning the `openai/whisper-tiny` model. Here, the feature extractor and tokenizer for this model are fetched from the Hugging Face Hub:

In [25]:
from transformers import WhisperFeatureExtractor, WhisperTokenizer

finetune_from_id = 'openai/whisper-tiny'
# The feature extractor is language-independent; language and task only apply to the tokenizer
feature_extractor = WhisperFeatureExtractor.from_pretrained(finetune_from_id)
tokenizer_original = WhisperTokenizer.from_pretrained(finetune_from_id, language='french', task='transcribe')

We'll create a customized tokenizer based on `tokenizer_original` in the next section.

## Vocabulary adjustments

**Note**: Adjusting the vocabulary makes training Whisper a bit more difficult. Consider skipping this section.

At present, this notebook only supports fine-tuning languages supported by the upstream Whisper project.

It may be possible to get better accuracy by customizing the vocabulary. One way to do this might be with the (very slow) `tokenizer.train_new_from_iterator` function. For example, with something similar to the following:
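```python
from transformers import WhisperTokenizerFast

def sentence_data_generator():
    """ Outputs *just* the batched string data from voice_data """
    texts = voice_data['train'].select_columns(['text'])
    for samples in texts.iter(batch_size=500):
        # Yields a list of all sentences in the batch
        yield samples['text']

text_data = sentence_data_generator()
print(next(text_data)) # Note: this first batch is consumed and won't be trained on

# train_new_from_iterator is only available on "fast" tokenizers
tokenizer_fast = WhisperTokenizerFast.from_pretrained(finetune_from_id, language='french', task='transcribe')
# 50257 is roughly the size of whisper-tiny's default vocabulary
tokenizer = tokenizer_fast.train_new_from_iterator(text_data, 50257)
```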

Changing the vocabulary like this may also increase the time needed to train the model.

For now, we demonstrate replacing unused/unwanted tokens with ones that might be more useful and reloading the tokenizer:

In [26]:
# Step 1: Save the vocabulary to a file
tokenizer_directory = Path('whisper-default-tokenizer')
tokenizer_original.save_pretrained(tokenizer_directory)


('whisper-default-tokenizer/tokenizer_config.json',
 'whisper-default-tokenizer/special_tokens_map.json',
 'whisper-default-tokenizer/vocab.json',
 'whisper-default-tokenizer/merges.txt',
 'whisper-default-tokenizer/normalizer.json',
 'whisper-default-tokenizer/added_tokens.json')

Now that the tokenizer is saved in `tokenizer_directory`, we can load `tokenizer_directory/vocab.json` and modify it:

In [27]:
# Step 2: Get vocab.json
import json

def json_from_path(path: Path):
    with open(path, 'r', encoding='utf-8') as f:
        return json.loads(f.read())

vocab = json_from_path(tokenizer_directory / 'vocab.json')
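Keys in `vocab.json` are byte-level BPE strings, which is why the next cell looks for words starting with `Ġ`. As we understand the GPT-2-style byte-level BPE that Whisper's tokenizer uses, every byte is mapped to a printable character and the space byte (0x20) becomes `Ġ` (U+0120). The sketch below re-implements that mapping from scratch for illustration; `bytes_to_unicode` mirrors the name used in GPT-2's reference code, and `to_bpe_string` is a hypothetical helper.

```python
def bytes_to_unicode() -> dict[int, str]:
    """Map each byte 0-255 to a printable character, GPT-2 style."""
    printable = (
        list(range(ord('!'), ord('~') + 1))
        + list(range(ord('¡'), ord('¬') + 1))
        + list(range(ord('®'), ord('ÿ') + 1))
    )
    byte_values = printable[:]
    code_points = printable[:]
    offset = 0
    for b in range(256):
        if b not in byte_values:
            # Non-printable bytes (including the space) are shifted past 255
            byte_values.append(b)
            code_points.append(256 + offset)
            offset += 1
    return {b: chr(c) for b, c in zip(byte_values, code_points)}

byte_encoder = bytes_to_unicode()

def to_bpe_string(text: str) -> str:
    """Hypothetical helper: show a string in vocab.json-style byte-level form."""
    return ''.join(byte_encoder[b] for b in text.encode('utf-8'))

print(to_bpe_string(' pour'))  # Ġpour
print(to_bpe_string(' été'))   # Multi-byte UTF-8 characters become several symbols
```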

In [28]:
# Step 3: Find some words we can definitely remove
from spellchecker import SpellChecker

english_checker = SpellChecker(language='en')
french_checker = SpellChecker(language='fr')
def is_known_word(spell_checker, word: str):
    """ Returns true if the `spell_checker` thinks `word` is spelled correctly.
        Changing the `spell_checker` changes which words are considered correct.
    """
    return len(spell_checker.unknown([word.lower()])) == 0

def is_english_only_word(word: str):
    """ Returns true if `word` is an English word, but not a French word """
    is_english = is_known_word(english_checker, word)
    is_french = is_known_word(french_checker, word)
    return is_english and not is_french

print('The is_english_only_word function should return True if a word is spelled correctly in English, but not in French:', is_english_only_word('testing'))

# This character marks the beginning of a word in vocab.json
word_start_char = 'Ġ'
replacable_keys = []

def mark_english_only_words():
    """ Marks all English-only words as replaceable """
    for key in vocab:
        if not key.startswith(word_start_char):
            continue

        # Skip short words, as they're more likely to be prefixes of French words, too.
        if len(key) <= 5:
            continue
        word = key[1:]
        if is_english_only_word(word):
            replacable_keys.append(key)

mark_english_only_words()
replacable_keys[0:10]

The is_english_only_word function should return True if a word is spelled correctly in English, but not in French: True


['ĠABOUT',
 'ĠANNOUNCER',
 'ĠAPPLAUSE',
 'ĠAbigail',
 'ĠAboriginal',
 'ĠAbout',
 'ĠAbove',
 'ĠAbsolutely',
 'ĠAcademic',
 'ĠAcademy']

In [29]:
# Step 4: Collect information about French words
from collections import defaultdict
import re

NONWORD_REGEX = re.compile(r'[ \t?.,;!()/\-«»]+')
def split_by_word(text: str):
    """ Splits the given `text` into words. Returns a list of those words. """
    return NONWORD_REGEX.split(text)

def build_word_counts():
    """ Builds a map from certain words to the number of times they appear.
        This map will not include all words in the training set.
    """
    # Constants: Ignore short words
    min_word_length = 4
    max_sentences_to_process = 7_000 # Don't process more than roughly this number of sentences

    # Output
    word_counts = defaultdict(int)

    text_data = voice_data['train'].select_columns(['text'])
    sentences_processed = 0
    for batch in text_data.iter(batch_size=100):
        for sentence in batch['text']:
            for word in split_by_word(sentence):
                if len(word) >= min_word_length:
                    word_counts[word.lower()] += 1
            sentences_processed += 1

        if sentences_processed > max_sentences_to_process:
            break
    return word_counts

word_counts = build_word_counts()
# Sort by occurrences
def get_val(pair):
    (key, val) = pair
    return val
most_common_words = sorted(word_counts.items(), key=get_val, reverse=True)

In [30]:
most_common_words[0:10]

[('pour', 1674),
 ('dans', 1642),
 ('nous', 1320),
 ('plus', 945),
 ('sont', 913),
 ('elle', 873),
 ('cette', 832),
 ("c'est", 789),
 ('vous', 705),
 ('avec', 677)]
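Step 4 could also be expressed with `collections.Counter`, whose `most_common` method returns `(word, count)` pairs sorted by count, replacing the `defaultdict` plus `sorted` combination. A minimal sketch on two made-up sentences, reusing the same splitting regex and minimum word length as `build_word_counts` (`count_words` and `MIN_WORD_LENGTH` are names introduced here for illustration):

```python
import re
from collections import Counter

NONWORD_REGEX = re.compile(r'[ \t?.,;!()/\-«»]+')
MIN_WORD_LENGTH = 4

def count_words(sentences):
    """Count lower-cased words with at least MIN_WORD_LENGTH characters."""
    counts = Counter()
    for sentence in sentences:
        counts.update(
            word.lower()
            for word in NONWORD_REGEX.split(sentence)
            if len(word) >= MIN_WORD_LENGTH  # Also drops the empty trailing split
        )
    return counts

sentences = [
    "Nous sommes revenus sur cette question.",
    "C'est pour cette raison que nous votons pour.",
]
print(count_words(sentences).most_common(3))
```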

In [31]:
# Step 5: Replace!
next_replacement_idx = 0
new_vocab = dict(vocab)
replaced_keys = set()

for key in replacable_keys:
    if next_replacement_idx >= len(most_common_words):
        # Out of words to replace with
        break
    (replacement,count) = most_common_words[next_replacement_idx]
    next_replacement_idx += 1
    new_key = word_start_char + replacement
    # Don't map multiple keys to the same token value
    if new_key in new_vocab:
        continue
    # Don't add uncommon words
    if count <= 4:
        continue

    # Replace [key] with [new_key]
    token_value = new_vocab[key]
    del new_vocab[key]
    new_vocab[new_key] = token_value
    replaced_keys.add(key)

print("Made {} replacements".format(len(replaced_keys)))

new_merges = []
with open(tokenizer_directory / 'merges.txt', 'r', encoding='utf-8') as merges:
    for line in merges.readlines():
        if len(line) == 0:
            continue
        words = split_by_word(line)
        if not (words[0] in replaced_keys):
            new_merges.append(line.strip())

Made 2967 replacements
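Step 5 keeps every token's numeric ID and only changes the string it maps to, so the model's embedding matrix keeps its shape. A toy-sized sketch of that idea, with made-up IDs and a hypothetical `replace_tokens` helper that simplifies the loop above (it pairs keys and words one-to-one, skips words already in the vocabulary, and ignores word counts):

```python
def replace_tokens(vocab: dict[str, int], old_keys: list[str], new_words: list[str],
                   word_start: str = 'Ġ') -> dict[str, int]:
    """Hypothetical helper: reassign the IDs of `old_keys` to word-initial `new_words`."""
    new_vocab = dict(vocab)
    for old_key, word in zip(old_keys, new_words):
        new_key = word_start + word
        if new_key in new_vocab:
            continue  # Never map two strings to the same ID
        new_vocab[new_key] = new_vocab.pop(old_key)  # Same ID, new string
    return new_vocab

toy_vocab = {'Ġthe': 0, 'ĠAPPLAUSE': 1, 'ĠAbigail': 2, 'Ġpour': 3}
updated = replace_tokens(toy_vocab, ['ĠAPPLAUSE', 'ĠAbigail'], ['cette', 'pour'])
print(updated)
```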


Great! We now have a vocabulary that includes more common French words. Let's write it to a new tokenizer directory and load it:

In [32]:
# Write to a file
tokenizer_fr_directory = Path('updated-tokenizer')
if tokenizer_fr_directory.exists():
    shutil.rmtree(tokenizer_fr_directory)
shutil.copytree(tokenizer_directory, tokenizer_fr_directory)
with open(tokenizer_fr_directory / 'vocab.json', 'w', encoding='utf-8') as f:
    json.dump(new_vocab, f, ensure_ascii=False)


with open(tokenizer_fr_directory / 'merges.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(new_merges))
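Rewriting `merges.txt` matters because a BPE tokenizer only emits a token when its merge rules can build it. The simplified sketch below (a hypothetical `apply_bpe` that greedily applies the lowest-ranked merge, in the style of GPT-2's BPE rather than a copy of the Hugging Face implementation) shows that a vocabulary entry such as `Ġcette` is only produced if some merge joins its final two pieces. The merge list is made up for the example.

```python
def apply_bpe(word: str, merges: list[tuple[str, str]]) -> list[str]:
    """Greedy BPE: repeatedly merge the adjacent pair with the lowest rank."""
    ranks = {pair: rank for rank, pair in enumerate(merges)}
    parts = list(word)
    while len(parts) > 1:
        pairs = [(parts[i], parts[i + 1]) for i in range(len(parts) - 1)]
        best = min(pairs, key=lambda pair: ranks.get(pair, float('inf')))
        if best not in ranks:
            break  # No applicable merge left
        merged, i = [], 0
        while i < len(parts):
            if i < len(parts) - 1 and (parts[i], parts[i + 1]) == best:
                merged.append(parts[i] + parts[i + 1])
                i += 2
            else:
                merged.append(parts[i])
                i += 1
        parts = merged
    return parts

merges = [('Ġ', 'c'), ('e', 't'), ('Ġc', 'et'), ('t', 'e')]
print(apply_bpe('Ġcette', merges))                     # Stops at two pieces
print(apply_bpe('Ġcette', merges + [('Ġcet', 'te')]))  # One token
```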

In [33]:
from transformers import WhisperTokenizer

# Use a normal WhisperTokenizer -- WhisperTokenizerFast has trouble with the updated
# vocabulary.
tokenizer = WhisperTokenizer(
    tokenizer_fr_directory / 'vocab.json',
    tokenizer_fr_directory / 'merges.txt',
    tokenizer_fr_directory / 'normalizer.json',
    bos_token='<|startoftranscript|>',
    unk_token='',
    pad_token='<|endoftext|>',
    language='french',
    task='transcribe',
)

# See https://discuss.huggingface.co/t/fine-tuning-whisper-on-my-own-dataset-with-a-customized-tokenizer/25903
tokenizer.add_special_tokens(tokenizer_original.special_tokens_map)

105

In [34]:
# For debugging, update the output directory
shutil.rmtree(tokenizer_fr_directory)
tokenizer.save_pretrained(tokenizer_fr_directory)

('updated-tokenizer/tokenizer_config.json',
 'updated-tokenizer/special_tokens_map.json',
 'updated-tokenizer/vocab.json',
 'updated-tokenizer/merges.txt',
 'updated-tokenizer/normalizer.json',
 'updated-tokenizer/added_tokens.json')

In [35]:
# Uncomment to use the default tokenizer
#tokenizer=tokenizer_original

## Create the processor

Next, create the `WhisperProcessor`, which combines a feature extractor and a tokenizer.

In [36]:
from transformers import WhisperProcessor

processor = WhisperProcessor(feature_extractor, tokenizer)

Use the feature extractor and tokenizer to convert the data into a format suitable for the model:

In [37]:
def map_sample(batch):
    audio_data = batch['audio']['array']
    audio_sample_rate = batch['audio']['sampling_rate']
    features = processor.feature_extractor(audio_data, sampling_rate=audio_sample_rate)

    batch['input_features'] = features.input_features[0]
    batch['labels'] = processor.tokenizer(batch['text']).input_ids
    return batch

# Remove columns no longer used
voice_data_original = voice_data # For debugging
voice_data = voice_data.map(map_sample, remove_columns=['audio', 'text'])
voice_data

IterableDatasetDict({
    train: IterableDataset({
        features: Unknown,
        num_shards: 1
    })
    test: IterableDataset({
        features: Unknown,
        num_shards: 1
    })
})
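As a rough check on what `map_sample` produces: as we read the `WhisperFeatureExtractor` defaults (16 kHz input, 30-second windows, a hop length of 160 samples, 80 Mel bins for `whisper-tiny`, `padding='max_length'` and `truncation=True`), every clip is padded or truncated to 30 s, so each `input_features` entry should be an 80 × 3000 log-Mel spectrogram. The arithmetic, with parameter names of our own choosing:

```python
def whisper_feature_shape(sampling_rate=16_000, chunk_length_s=30,
                          hop_length=160, n_mels=80):
    """Expected (Mel bins, frames) of one padded Whisper input window."""
    n_samples = chunk_length_s * sampling_rate  # Audio is padded/truncated to this
    n_frames = n_samples // hop_length          # One frame every `hop_length` samples
    return n_mels, n_frames

print(whisper_feature_shape())  # (80, 3000)
```

One consequence worth checking: if three concatenated Common Voice clips or a long Voxpopuli segment run past 30 seconds, the end of the audio is cut off while its transcript is kept in full.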

In [38]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained(finetune_from_id)
model.generation_config.language = 'french'
model.generation_config.task = 'transcribe'
model.generation_config.forced_decoder_ids = None


In [39]:
from dataclasses import dataclass
from typing import Any
import torch
# See the linked blog post and https://huggingface.co/docs/transformers/main_classes/data_collator

@dataclass
class DataCollatorWithPadding:
    ''' Converts raw data into a batch ready for the model '''
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: list) -> dict[str, torch.Tensor]:
        input_features = [{'input_features': f['input_features']} for f in features]
        label_features = [{'input_ids': f['labels']} for f in features]

        # According to the linked blog post, the input and label features need
        # to be padded separately (due to different final lengths), then
        # recombined:
        batch = self.processor.feature_extractor.pad(input_features, return_tensors='pt')

        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors='pt')

        # transformers uses -100 for masking
        labels = labels_batch['input_ids'].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Don't double-prepend the beginning of sequence token:
        if (labels[:,0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch['labels'] = labels
        return batch

data_collator = DataCollatorWithPadding(processor=processor, decoder_start_token_id=model.config.decoder_start_token_id)
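The label handling in `DataCollatorWithPadding` can be illustrated without PyTorch. The sketch below (a hypothetical `pad_labels` helper operating on plain lists, with made-up token IDs) right-pads each label sequence with -100 so the loss ignores the padding, then drops the first column when every sequence starts with the decoder start token, since the model prepends that token itself when it shifts the labels to build decoder inputs.

```python
def pad_labels(label_seqs, start_id, ignore_index=-100):
    """Hypothetical helper: right-pad labels and mask the padding with `ignore_index`."""
    max_len = max(len(seq) for seq in label_seqs)
    # Mask by position (like the attention mask), not by value: the pad token
    # can equal <|endoftext|>, which must stay in the labels.
    padded = [seq + [ignore_index] * (max_len - len(seq)) for seq in label_seqs]
    # Don't double-prepend the start token: drop it if every sequence has it
    if all(seq[0] == start_id for seq in padded):
        padded = [seq[1:] for seq in padded]
    return padded

START, EOT = 1, 2  # Made-up IDs for the start and end-of-text tokens
print(pad_labels([[START, 10, 11, EOT], [START, 12, EOT]], start_id=START))
```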

## Viewing sample data

Let's look at some of the training data:

In [40]:
sample_data = next(iter(voice_data['test']))
sample_labels = sample_data['labels']

In [41]:
processor.decode(sample_labels)

'<|startoftranscript|><|fr|><|transcribe|><|notimestamps|>Je souhaite juste rappeler que ces droits de plantation étaient intégrés dans la réforme de 2008, qui a été<|endoftext|>ée par le Conseil des ministres. Mais malgré cela et au delà de cela, dans le cadre de cette réforme 2013, nous sommes revenus sur cette question, compte tenu aussi du rapport<|endoftext|>'

In [42]:
def run_on_sample_audio():
    """ Returns the (text) result of running the model on a single audio sample. """
    sample_audio = next(iter(voice_data_original['test']))['audio']
    inputs = processor(
        sample_audio['array'],
        sampling_rate=sample_audio['sampling_rate'],
        return_tensors='pt',
    )
    # Run on whichever device the model is on (CPU before training, usually GPU after)
    input_features = inputs.input_features.to(model.device)
    generated_ids = model.generate(input_features=input_features)
    return processor.batch_decode(generated_ids)

In [43]:
print(run_on_sample_audio())

It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
You have passed task=transcribe, but also have set `forced_decoder_ids` to [[1, 50259], [2, 50359], [3, 50363]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=transcribe.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


[' juste rappeler que ces droits de plantation étaient intégrés dans la réforme de Miluit qui a été voté par le Conseil de ministre. Mais malgré cela et au-delà de cela dans le cadre de cette réforme 2013, nous sommes révenus sur cette question, et continue aussi de rapports']


## Preparing an evaluation function


In [44]:
import evaluate

wer_metric = evaluate.load('wer')
cer_metric = evaluate.load('cer')

def compute_metrics(data):
    true_labels = data.label_ids
    predictions = data.predictions

    # Replace -100 (used to mask padding in the labels) with the pad token so the labels can be decoded
    true_labels[true_labels == -100] = processor.tokenizer.pad_token_id

    predicted_text = processor.batch_decode(predictions, skip_special_tokens=True)
    label_text = processor.batch_decode(true_labels, skip_special_tokens=True)

    wer = wer_metric.compute(predictions=predicted_text, references=label_text)
    cer = cer_metric.compute(predictions=predicted_text, references=label_text)
    return { 'wer': wer, 'cer': cer }
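For intuition about the `wer` and `cer` numbers: word error rate is the word-level edit distance between prediction and reference (substitutions + deletions + insertions) divided by the number of reference words, and character error rate is the same at the character level. The pure-Python sketch below implements that definition for a single pair; the helper names are hypothetical, and `evaluate`'s metrics (backed by `jiwer`) aggregate errors over the whole evaluation set rather than averaging per-sentence scores.

```python
def edit_distance(ref: list[str], hyp: list[str]) -> int:
    """Levenshtein distance between two token sequences."""
    previous = list(range(len(hyp) + 1))
    for i, ref_token in enumerate(ref, start=1):
        current = [i]
        for j, hyp_token in enumerate(hyp, start=1):
            current.append(min(
                previous[j] + 1,                             # Deletion
                current[j - 1] + 1,                          # Insertion
                previous[j - 1] + (ref_token != hyp_token),  # Substitution or match
            ))
        previous = current
    return previous[-1]

def word_error_rate(reference: str, hypothesis: str) -> float:
    ref_words = reference.split()
    return edit_distance(ref_words, hypothesis.split()) / len(ref_words)

def char_error_rate(reference: str, hypothesis: str) -> float:
    return edit_distance(list(reference), list(hypothesis)) / len(reference)

# One misspelled word, as in the model output shown earlier
reference = 'nous sommes revenus sur cette question'
hypothesis = 'nous sommes révenus sur cette question'
print(word_error_rate(reference, hypothesis), char_error_rate(reference, hypothesis))
```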


## Preparing training arguments

In [45]:
from transformers import Seq2SeqTrainingArguments

# TODO: Update this if you're planning to push the custom model to
# huggingface (ignore otherwise):
hub_model_id = 'personalizedrefrigerator/whisper-tiny-fr'

def make_training_args(max_steps: int):
    return Seq2SeqTrainingArguments(
        output_dir=checkpoint_path,
        per_device_train_batch_size=16,
        gradient_accumulation_steps=1,
        hub_model_id=hub_model_id,
        learning_rate=1e-5,
        max_steps=max_steps,
        gradient_checkpointing=True,
        logging_first_step=True,
        fp16=True,
        eval_strategy='steps',
        per_device_eval_batch_size=8,
        generation_max_length=256,
        predict_with_generate=True,
        auto_find_batch_size=True,
        save_steps=3000,
        eval_steps=1000,
        logging_steps=25,
        save_total_limit=1,
    )

In [46]:
small_eval_dataset = voice_data['test'].shuffle(seed=11).take(128)
large_eval_dataset = voice_data['test'].shuffle(seed=12).take(512)

In [47]:
from transformers import Seq2SeqTrainer

def make_trainer(max_steps: int = 16_000):
    return Seq2SeqTrainer(
        args=make_training_args(max_steps),
        model=model,
        train_dataset=voice_data['train'],
        eval_dataset=small_eval_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        processing_class=processor.feature_extractor,
    )

trainer = make_trainer()

## Training and evaluation

In [48]:
trainer.evaluate(large_eval_dataset)

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


{'eval_loss': 1.4306994676589966,
 'eval_model_preparation_time': 0.004,
 'eval_wer': 0.6166231830040999,
 'eval_cer': 0.3612077328979437,
 'eval_runtime': 208.996,
 'eval_samples_per_second': 2.45,
 'eval_steps_per_second': 0.306}

In [49]:
trainer.train()

`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


| Step | Training Loss | Validation Loss | Model Preparation Time | WER | CER |
|-----:|-----:|-----:|-----:|-----:|-----:|
| 1000 | 0.6588 | 1.051852 | 0.004 | 0.508676 | 0.297035 |
| 2000 | 0.6227 | 0.985802 | 0.004 | 0.543683 | 0.309958 |
| 3000 | 0.6213 | 0.959001 | 0.004 | 0.542466 | 0.332634 |
| 4000 | 0.5999 | 0.925832 | 0.004 | 0.538508 | 0.338633 |
| 5000 | 0.5183 | 0.909204 | 0.004 | 0.47519 | 0.287672 |
| 6000 | 0.5623 | 0.899834 | 0.004 | 0.453881 | 0.275237 |
| 7000 | 0.5388 | 0.888408 | 0.004 | 0.482192 | 0.299668 |




KeyboardInterrupt: 

In [50]:
if checkpoint_remote_path.exists():
    shutil.rmtree(checkpoint_remote_path)
shutil.copytree(checkpoint_path, checkpoint_remote_path)

PosixPath('/content/final-checkpoints')

In [51]:
trainer.evaluate(large_eval_dataset)



| Step | Training Loss | Validation Loss | Model Preparation Time | WER | CER |
|-----:|-----:|-----:|-----:|-----:|-----:|
| 1000 | 0.6588 | 1.051852 | 0.004 | 0.508676 | 0.297035 |
| 2000 | 0.6227 | 0.985802 | 0.004 | 0.543683 | 0.309958 |
| 3000 | 0.6213 | 0.959001 | 0.004 | 0.542466 | 0.332634 |
| 4000 | 0.5999 | 0.925832 | 0.004 | 0.538508 | 0.338633 |
| 5000 | 0.5183 | 0.909204 | 0.004 | 0.47519 | 0.287672 |
| 6000 | 0.5623 | 0.899834 | 0.004 | 0.453881 | 0.275237 |
| 7000 | 0.5388 | 0.888408 | 0.004 | 0.482192 | 0.299668 |
| 7050 | 0.5847 | 0.752852 | 0.004 | 0.454342 | 0.272734 |


{'eval_loss': 0.7528520822525024,
 'eval_model_preparation_time': 0.004,
 'eval_wer': 0.4543421543048826,
 'eval_cer': 0.2727337261304749}

In [52]:
model_output_dir = Path('./final-model').resolve()
trainer.save_model(model_output_dir)
tokenizer.save_pretrained(model_output_dir)

('/content/final-model/tokenizer_config.json',
 '/content/final-model/special_tokens_map.json',
 '/content/final-model/vocab.json',
 '/content/final-model/merges.txt',
 '/content/final-model/normalizer.json',
 '/content/final-model/added_tokens.json')

In [53]:
print(run_on_sample_audio())

It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


['Je souhaite juste rappeler que ces droits de plantation étaient intégrés dans la réforme 2008 qui a été voté par le Conseil de ministre, mais Mald Grécela et au delà de cellars, dans le cadre de cette réforme 2013, nous sommes révénés sur cette question et continue aussi de rapports']


# Model conversion

Next, we need to convert the model into a format that Joplin can use. This step converts the model from the Hugging Face (PyTorch) format to GGML, the format read by whisper.cpp.

In [54]:
!git clone https://github.com/openai/whisper whisper-github
!git clone https://github.com/ggerganov/whisper.cpp
!cd whisper.cpp && git checkout v1.7.4

Cloning into 'whisper-github'...
remote: Enumerating objects: 828, done.[K
remote: Counting objects: 100% (370/370), done.[K
remote: Compressing objects: 100% (69/69), done.[K
remote: Total 828 (delta 333), reused 301 (delta 301), pack-reused 458 (from 2)[K
Receiving objects: 100% (828/828), 8.26 MiB | 8.99 MiB/s, done.
Resolving deltas: 100% (496/496), done.
Cloning into 'whisper.cpp'...
remote: Enumerating objects: 15734, done.[K
remote: Counting objects: 100% (2936/2936), done.[K
remote: Compressing objects: 100% (586/586), done.[K
remote: Total 15734 (delta 2472), reused 2453 (delta 2345), pack-reused 12798 (from 3)[K
Receiving objects: 100% (15734/15734), 19.06 MiB | 14.20 MiB/s, done.
Resolving deltas: 100% (10842/10842), done.
Note: switching to 'v1.7.4'.

You are in 'detached HEAD' state. You can look around, make experimental
changes and commit them, and you can discard any commits you make in this
state without impacting any branches by switching back to a branch.

If

In [55]:
# Patch convert-h5-to-ggml.py to work with more recent model versions, whose
# config can set max_length to null (None) instead of omitting it.
conversion_script_path = Path('whisper.cpp/models/convert-h5-to-ggml.py')
conversion_script_content = conversion_script_path.read_text()
bad_if_statement = 'if "max_length" not in hparams:'
replaced_if_statement = 'if "max_length" not in hparams or hparams["max_length"] is None:'
if bad_if_statement not in conversion_script_content:
    print('Warning: expected line not found (script already patched or changed upstream)')
conversion_script_path.write_text(conversion_script_content.replace(bad_if_statement, replaced_if_statement))
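The patch matters because a key that is present but set to `null` passes the original `not in` check, so it is never replaced. The sketch below mirrors the patched condition. It assumes the script's fallback inside that `if` is `hparams["max_length"] = hparams["max_target_positions"]`; the function name `resolve_max_length` is hypothetical.

```python
def resolve_max_length(hparams: dict) -> int:
    """Mirror of the patched check in convert-h5-to-ggml.py (a sketch).

    ASSUMPTION: the script falls back to max_target_positions.
    """
    if "max_length" not in hparams or hparams["max_length"] is None:
        hparams["max_length"] = hparams["max_target_positions"]
    return hparams["max_length"]


# json.load turns "max_length": null into None; the unpatched check would keep it.
print(resolve_max_length({"max_length": None, "max_target_positions": 448}))  # 448
print(resolve_max_length({"max_target_positions": 448}))  # 448
print(resolve_max_length({"max_length": 300, "max_target_positions": 448}))  # 300
```

The value 448 matches `n_text_ctx` in the whisper.cpp loading logs further down.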

In [56]:
!mkdir ./ggml
!python whisper.cpp/models/convert-h5-to-ggml.py ./final-model ./whisper-github ./ggml

model.encoder.conv1.weight  ->  encoder.conv1.weight
encoder.conv1.weight 3 (384, 80, 3)
model.encoder.conv1.bias  ->  encoder.conv1.bias
  Reshaped variable:  encoder.conv1.bias  to shape:  (384, 1)
encoder.conv1.bias 2 (384, 1)
  Converting to float32
model.encoder.conv2.weight  ->  encoder.conv2.weight
encoder.conv2.weight 3 (384, 384, 3)
model.encoder.conv2.bias  ->  encoder.conv2.bias
  Reshaped variable:  encoder.conv2.bias  to
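As a quick sanity check of the converted file that does not require building whisper.cpp, the following sketch reads its header. It assumes the layout written by `convert-h5-to-ggml.py` at the checked-out version: an int32 magic number `0x67676d6c` followed by eleven int32 hyperparameters, in the order whisper.cpp prints them when it loads the model. It also assumes little-endian storage; the script uses native byte order, which is little-endian on typical x86 Colab machines. The names `parse_ggml_header` and `read_ggml_header` are hypothetical.

```python
import struct
from pathlib import Path

GGML_MAGIC = 0x67676D6C

# ASSUMPTION: field order matches convert-h5-to-ggml.py and the
# whisper_model_load log lines (n_vocab, n_audio_ctx, ...).
HEADER_FIELDS = (
    "n_vocab", "n_audio_ctx", "n_audio_state", "n_audio_head", "n_audio_layer",
    "n_text_ctx", "n_text_state", "n_text_head", "n_text_layer", "n_mels", "ftype",
)
HEADER_SIZE = 4 * (1 + len(HEADER_FIELDS))  # magic + fields, 4 bytes each


def parse_ggml_header(data: bytes) -> dict[str, int]:
    """Parse the magic number and hyperparameters at the start of a Whisper GGML file."""
    if len(data) < HEADER_SIZE:
        raise ValueError("Data too short to contain a GGML header")
    magic, *values = struct.unpack(f"<{1 + len(HEADER_FIELDS)}i", data[:HEADER_SIZE])
    if magic != GGML_MAGIC:
        raise ValueError(f"Unexpected magic number: {magic:#x}")
    return dict(zip(HEADER_FIELDS, values))


def read_ggml_header(path: Path) -> dict[str, int]:
    """Read only the header bytes of a GGML file and parse them."""
    with open(path, "rb") as model_file:
        return parse_ggml_header(model_file.read(HEADER_SIZE))
```

For example, `read_ggml_header(Path('./ggml/ggml-model.bin'))` should report the same values that `whisper_model_load` prints when the model is loaded.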

To reduce the model's file size, and potentially speed up inference, we can also quantize the GGML model to 8 bits (`q8_0`):

In [57]:
!cd whisper.cpp && cmake -B build && cmake --build build --config Release
!./whisper.cpp/build/bin/quantize ./ggml/ggml-model.bin ./ggml/ggml-model-q8_0.bin q8_0

  Compatibility with CMake < 3.10 will be removed from a future version of
  CMake.

  Update the VERSION argument <min> value.  Or, use the <min>...<max> syntax
  to tell CMake that the project requires at least <min> but has been updated
  to work with policies introduced by <max> or earlier.

-- The C compiler identification is GNU 11.4.0
-- The CXX compiler identification is GNU 11.4.0
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: /usr/bin/cc - skipped
-- Detecting C compile features
-- Detecting C compile features - done
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Found Git: /usr/bin/git (found version "2.34.1")
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Success
-- Found Threads: TRUE
-- CMAKE_SYSTEM_PROCES
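The following is a pure-Python sketch of the idea behind GGML's `Q8_0` format. The real implementation is C code inside ggml; this only illustrates it. Weights are split into blocks of 32. Each block stores one float16 scale `d = max|x| / 127` and 32 int8 values `round(x / d)`. The helper names are hypothetical, and the sketch covers a single block rather than a whole tensor.

```python
import math
import struct

QK8_0 = 32  # number of weights per Q8_0 block in GGML


def _round_half_away(x: float) -> int:
    # C's roundf rounds halfway cases away from zero; Python's round() does not.
    return int(math.copysign(math.floor(abs(x) + 0.5), x))


def quantize_q8_0_block(block: list[float]) -> tuple[float, list[int]]:
    """Quantize up to 32 floats to one float16 scale plus int8 values."""
    amax = max((abs(v) for v in block), default=0.0)
    scale = amax / 127.0
    inv_scale = 1.0 / scale if scale else 0.0
    quants = [_round_half_away(v * inv_scale) for v in block]
    # The scale is stored as float16; emulate that with struct's "e" format.
    scale_f16 = struct.unpack("<e", struct.pack("<e", scale))[0]
    return scale_f16, quants


def dequantize_q8_0_block(scale: float, quants: list[int]) -> list[float]:
    """Reconstruct approximate float weights from a quantized block."""
    return [scale * q for q in quants]


# Bytes per 32-weight block: 32 int8 values + a 2-byte float16 scale,
# compared with 32 float16 weights in the unquantized (ftype = 1) model.
Q8_0_BLOCK_BYTES = QK8_0 + 2
F16_BLOCK_BYTES = QK8_0 * 2
print(Q8_0_BLOCK_BYTES / F16_BLOCK_BYTES)  # 0.53125
```

Each weight's reconstruction error is at most about half a quantization step. The tensors that get quantized shrink to roughly 53% of their float16 size.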

Now, let's make sure that the converted GGML model works. Start by downloading some test audio:

In [58]:
!mkdir ./test-audio
# Download the first chapter of Alice in Wonderland (in French)
!wget -P ./test-audio/ https://www.archive.org/download/alice_au_pays_des_merveilles_1811_librivox/aliceaupays_01_carroll_128kb.mp3
# Convert it to the format whisper.cpp expects (16 kHz, mono, 16-bit PCM WAV):
# -t 30                 Take the first 30s
# -i ...                Input path
# -ar 16000             Sample rate of 16000 Hz
# -ac 1                 1 audio channel
# -codec:a pcm_s16le    Audio codec (signed 16-bit little-endian PCM)
!ffmpeg -t 30 -i ./test-audio/aliceaupays_01_carroll_128kb.mp3 -ar 16000 -ac 1 -codec:a pcm_s16le ./test-audio/recording-fr.wav

--2025-02-26 22:38:25--  https://www.archive.org/download/alice_au_pays_des_merveilles_1811_librivox/aliceaupays_01_carroll_128kb.mp3
Resolving www.archive.org (www.archive.org)... 207.241.224.2
Connecting to www.archive.org (www.archive.org)|207.241.224.2|:443... connected.
HTTP request sent, awaiting response... 302 Moved Temporarily
Location: https://archive.org/download/alice_au_pays_des_merveilles_1811_librivox/aliceaupays_01_carroll_128kb.mp3 [following]
--2025-02-26 22:38:26--  https://archive.org/download/alice_au_pays_des_merveilles_1811_librivox/aliceaupays_01_carroll_128kb.mp3
Resolving archive.org (archive.org)... 207.241.224.2
Connecting to archive.org (archive.org)|207.241.224.2|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://ia803201.us.archive.org/25/items/alice_au_pays_des_merveilles_1811_librivox/aliceaupays_01_carroll_128kb.mp3 [following]
--2025-02-26 22:38:27--  https://ia803201.us.archive.org/25/items/alice_au_pays_des_mervei
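Before running whisper.cpp, you may want to confirm that the WAV file has the format requested from ffmpeg above: 16 kHz, one channel, and 16-bit samples. The following sketch uses Python's standard `wave` module. The function name `check_whisper_cpp_wav` is hypothetical.

```python
import wave
from pathlib import Path


def check_whisper_cpp_wav(path: Path) -> float:
    """Raise ValueError unless `path` is 16 kHz, mono, 16-bit PCM; return its duration in seconds."""
    with wave.open(str(path), "rb") as wav:
        problems = []
        if wav.getframerate() != 16000:
            problems.append(f"sample rate is {wav.getframerate()} Hz, expected 16000")
        if wav.getnchannels() != 1:
            problems.append(f"{wav.getnchannels()} channels, expected 1")
        if wav.getsampwidth() != 2:
            problems.append(f"sample width is {wav.getsampwidth()} bytes, expected 2 (16-bit)")
        if problems:
            raise ValueError(f"{path}: " + "; ".join(problems))
        return wav.getnframes() / wav.getframerate()
```

For example, `check_whisper_cpp_wav(Path('./test-audio/recording-fr.wav'))` should return about 30.0, since ffmpeg was asked for the first 30 seconds.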

Next, use the `whisper-cli` command to transcribe the audio using our GGML model:

In [59]:
# Test converting the WAV file to text using the GGML file that we built
!./whisper.cpp/build/bin/whisper-cli --language fr --no-timestamps -m ./ggml/ggml-model.bin ./test-audio/recording-fr.wav

whisper_init_from_file_with_params_no_state: loading model from './ggml/ggml-model.bin'
whisper_init_with_params_no_state: use gpu    = 1
whisper_init_with_params_no_state: flash attn = 0
whisper_init_with_params_no_state: gpu_device = 0
whisper_init_with_params_no_state: dtw        = 0
whisper_init_with_params_no_state: devices    = 1
whisper_init_with_params_no_state: backends   = 1
whisper_model_load: loading model
whisper_model_load: n_vocab       = 51865
whisper_model_load: n_audio_ctx   = 1500
whisper_model_load: n_audio_state = 384
whisper_model_load: n_audio_head  = 6
whisper_model_load: n_audio_layer = 4
whisper_model_load: n_text_ctx    = 448
whisper_model_load: n_text_state  = 384
whisper_model_load: n_text_head   = 6
whisper_model_load: n_text_layer  = 4
whisper_model_load: n_mels        = 80
whisper_model_load: ftype         = 1
whisper_model_load: qntvr         = 0
whisper_model_load: type          = 1 (tiny)
whisper_model_load: adding 1607 extra tokens
whisper_model_load
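To compare models from Python instead of reading the cell output, you can wrap the same `whisper-cli` call with `subprocess`. The following sketch uses only the flags shown above (`--language`, `--no-timestamps`, `-m`). It assumes that the transcript goes to stdout and the `whisper_model_load` logs go to stderr; check this against your whisper.cpp build. The helper names are hypothetical.

```python
import subprocess
from pathlib import Path

WHISPER_CLI = Path("./whisper.cpp/build/bin/whisper-cli")


def whisper_cli_command(model: Path, audio: Path, language: str = "fr") -> list[str]:
    """Build the same whisper-cli invocation used in the cells above."""
    return [str(WHISPER_CLI), "--language", language, "--no-timestamps",
            "-m", str(model), str(audio)]


def transcribe(model: Path, audio: Path, language: str = "fr") -> str:
    """Run whisper-cli and return its stdout.

    ASSUMPTION: the transcript is printed to stdout and logs go to stderr.
    """
    result = subprocess.run(whisper_cli_command(model, audio, language),
                            capture_output=True, text=True, check=True)
    return result.stdout.strip()
```

For example, `transcribe(Path('./ggml/ggml-model.bin'), Path('./test-audio/recording-fr.wav'))` runs the command from the previous cell. Passing `./ggml-upstream/ggml-tiny.bin` as the model runs the upstream comparison instead.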

In [60]:
# Compare with the upstream model
!mkdir ./ggml-upstream/
!sh ./whisper.cpp/models/download-ggml-model.sh tiny ./ggml-upstream/
!./whisper.cpp/build/bin/whisper-cli --language fr --no-timestamps -m ./ggml-upstream/ggml-tiny.bin ./test-audio/recording-fr.wav

Downloading ggml model tiny from 'https://huggingface.co/ggerganov/whisper.cpp' ...
Done! Model 'tiny' saved in './ggml-upstream//ggml-tiny.bin'
You can now use it like this:

  $ ./main -m ./ggml-upstream//ggml-tiny.bin -f samples/jfk.wav

whisper_init_from_file_with_params_no_state: loading model from './ggml-upstream/ggml-tiny.bin'
whisper_init_with_params_no_state: use gpu    = 1
whisper_init_with_params_no_state: flash attn = 0
whisper_init_with_params_no_state: gpu_device = 0
whisper_init_with_params_no_state: dtw        = 0
whisper_init_with_params_no_state: devices    = 1
whisper_init_with_params_no_state: backends   = 1
whisper_model_load: loading model
whisper_model_load: n_vocab       = 51865
whisper_model_load: n_audio_ctx   = 1500
whisper_model_load: n_audio_state = 384
whisper_model_load: n_audio_head  = 6
whisper_model_load: n_audio_layer = 4
whisper_model_load: n_text_ctx    = 448
whisper_model_load: n_text_state  = 384
whisper_model_load: n_text_head   = 6
whisper_mode
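To compare the fine-tuned and upstream transcriptions with a number rather than by eye, you could compute the word error rate of each against a reference transcript. You would need to type out the reference yourself from the first 30 seconds of the LibriVox recording; it is not included here. The following sketch is a plain word-level Levenshtein implementation; jiwer, installed at the top of the notebook, computes the same metric. It only splits on whitespace, so lowercasing and stripping punctuation beforehand will change the numbers.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """(substitutions + deletions + insertions) / number of reference words."""
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # previous[j] = edit distance between ref[:i-1] and hyp[:j] (starts as row i = 0)
    previous = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        current = [i] + [0] * len(hyp)
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = previous[j - 1] + (ref_word != hyp_word)
            current[j] = min(previous[j] + 1,      # deletion of ref_word
                             current[j - 1] + 1,   # insertion of hyp_word
                             substitution)
        previous = current
    return previous[-1] / len(ref)


print(word_error_rate("le chat noir", "le chien noir"))  # 0.3333333333333333
```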

In [67]:
from huggingface_hub import notebook_login, HfApi

# (Optional) Publish the GGML files, model, and tokenizer to the Hugging Face Hub
def push_to_hub():
    notebook_login()

    revision = 'train-on-voxpopuli-and-commonvoice'
    # Publish the GGML files
    api = HfApi()
    # Commit to base the new branch on (replace this):
    #base_on = '9dc99c95056795aaa8fbed87c976965c7ff0a129'
    #api.create_branch(repo_id = hub_model_id, branch=revision, revision=base_on)
    api.upload_folder(
        folder_path='./ggml',
        repo_id=hub_model_id,
        path_in_repo='ggml/',
        revision=revision
    )

    # Publish the model, processor
    trainer.push_to_hub(
        dataset_tags=['facebook/voxpopuli', 'mozilla-foundation/common_voice_11_0'],
        language='fr',
        model_name='Whisper Tiny (Finetuned on French)',
        finetuned_from=finetune_from_id,
        tasks='automatic-speech-recognition',
        revision=revision
    )
    # Note: If this creates a new repo, it will be public
    tokenizer.push_to_hub(hub_model_id, revision=revision)

In [None]:
# Publish (comment out this call to skip publishing)
push_to_hub()


Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]