WhisperAI Fine-tuning for location names in Singapore

In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
import librosa
import evaluate
import pandas as pd
import soundfile as sf

from torch.utils.data import DataLoader
from transformers import Trainer, TrainingArguments, DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperTokenizer
from datasets import Dataset, DatasetDict, load_dataset, ClassLabel, Features, Value




In [3]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device used:', device)
print(f"GPU name: {torch.cuda.get_device_name(torch.cuda.current_device())}")

Device used: cuda
GPU name: NVIDIA GeForce RTX 4070 SUPER


Example Use Case that the model does not recognize Singapore locations

In [37]:
# Read the .wav file and downsample the audio to 16000Hz as required by WhisperAI
file_path = "Location_Trial.wav"
audio_data, sampling_rate = librosa.load(file_path, sr=16000)  # sr=16000 to convert the sampling rate to 16000Hz

# Instantiante the processor and model from HuggingFace Hub
processor = WhisperProcessor.from_pretrained("openai/whisper-base.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base.en")

# Use the model and processor to transcribe the audio:
input_features = processor(
    audio_data, sampling_rate=sampling_rate, return_tensors="pt"
).input_features

# Generate token ids
predicted_ids = model.generate(input_features)

# Decode token ids to text
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

In [38]:
transcription[0]

' I have a friend staying at Songye Kadut. We love to go to Gailan to eat and we have family staying in Telot Blanga.'

Creating the Dataset

In [6]:
# Function to load audio and metadata
def load_audio(data):
    audio_path = data['audio_path']
    audio, sr = librosa.load(audio_path, sr=16000)
    return {'audio': audio, 'text': data['transcription']}

# Load metadata into a DataFrame
metadata_df = pd.read_csv('metadata.csv')

# Define features for the dataset
features = Features({
    'audio': Value('float32'),
    'transcription': Value('string'),
})

# Load audio files into dataset
train_set = Dataset.from_pandas(metadata_df)

# Preprocess audio data
train_set = train_set.map(load_audio, remove_columns=['audio_path', 'transcription'])

Map: 100%|██████████| 648/648 [00:06<00:00, 97.45 examples/s] 


In [7]:
train_set

Dataset({
    features: ['audio', 'text'],
    num_rows: 648
})

Now we repeat the process to create the test dataset

In [8]:
# Load metadata into a DataFrame
test_metadata_df = pd.read_csv('metadata_test.csv')

# Define features for the dataset
features = Features({
    'audio': Value('float32'),
    'transcription': Value('string'),
})

# Load audio files into dataset
test_set = Dataset.from_pandas(test_metadata_df)

# Preprocess audio data
test_set = test_set.map(load_audio, remove_columns=['audio_path', 'transcription'])

Map: 100%|██████████| 162/162 [00:00<00:00, 482.95 examples/s]


In [9]:
test_set

Dataset({
    features: ['audio', 'text'],
    num_rows: 162
})

Create a DatasetDict with both train and test sets

In [10]:
whisper_dataset = DatasetDict({
        'train': train_set,
        'test': test_set,
})

whisper_dataset

DatasetDict({
    train: Dataset({
        features: ['audio', 'text'],
        num_rows: 648
    })
    test: Dataset({
        features: ['audio', 'text'],
        num_rows: 162
    })
})

Fine-tuning WhisperAI

In [11]:
# Defining the model and processor

model = WhisperForConditionalGeneration.from_pretrained('openai/whisper-base.en')
processor = WhisperProcessor.from_pretrained('openai/whisper-base.en')
tokenizer = WhisperTokenizer.from_pretrained('openai/whisper-base.en')

In [12]:
def preprocess_function(data):
    # Process audio
    audio = data["audio"]

    # Get the processed audio input values
    processed_audio = processor(audio, sampling_rate=16000)

    # Get the input features
    input_features = processed_audio["input_features"]

    # Convert input_features to a tensor
    data["input_features"] = torch.tensor(input_features[0], dtype=torch.float32)

    # Process text (tokenize)
    labels = tokenizer(
        data["text"], return_tensors="pt", padding="max_length", max_length=128, truncation=True
    ).input_ids

    # Convert labels to a tensor (1D)
    data["labels"] = labels[0]

    return data

# Apply the preprocessing function to the datasets
whisper_dataset = whisper_dataset.map(preprocess_function, remove_columns=['audio', 'text'])
whisper_dataset.set_format(type="torch", columns=["input_features", "labels"])

Map: 100%|██████████| 648/648 [01:02<00:00, 10.32 examples/s]
Map: 100%|██████████| 162/162 [00:05<00:00, 32.08 examples/s]


In [13]:
# Sanity check to ensure the data has been preprocessed to the right format
whisper_dataset['train'][0]

{'input_features': tensor([[ 0.7029,  0.6257,  0.4525,  ..., -0.7932, -0.7932, -0.7932],
         [ 0.4301,  0.3983,  0.2610,  ..., -0.7932, -0.7932, -0.7932],
         [ 0.2831,  0.1739,  0.1047,  ..., -0.7932, -0.7932, -0.7932],
         ...,
         [-0.4738, -0.7932, -0.7357,  ..., -0.7932, -0.7932, -0.7932],
         [-0.4476, -0.7013, -0.6744,  ..., -0.7932, -0.7932, -0.7932],
         [-0.3572, -0.7842, -0.7932,  ..., -0.7932, -0.7932, -0.7932]]),
 'labels': tensor([50257, 50362,    40,   423,   257,  1256,   286,  2460, 10589,   287,
          2895,  4270,   509,   952,    13,  3574,   616,  2156,    11,  2895,
          4270,   509,   952,   318,   257,  1310,  1290,    13,   317,  4451,
          1545,   286,  6164,   468,   257, 16394,   290, 14768,   287,  2895,
          4270,   509,   952,    13,  6674,   314,   815,  1445,   284,  2895,
          4270,   509,   952,    13,  3914,   338,   467,   284,  2895,  4270,
           509,   952,     0, 50256, 50256, 50256, 50256

Now we create an evaluation metric called "Word Error Rate" or WER in short. This is a commonly used metric to evaluate the performance of automatic speech recognition systems.

In [14]:
eval_metric = evaluate.load('wer')

def compute_metrics(eval_pred):
    
    # Extract logits and labels from the eval_pred tuple
    logits, labels = eval_pred

    # Get the predicted indices
    predictions = logits[0].argmax(axis=-1)

    # Decode predictions and references
    decoded_predictions = processor.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)

    # Update WER metric
    wer_score = eval_metric.compute(predictions=decoded_predictions, references=decoded_labels)
    return {"wer": wer_score}
    

'\n    predictions = pred.predictions.argmax(-1)\n    decoded_predictions = processor.batch_decode(predictions, skip_special_tokens=True)\n    decoded_labels = processor.batch_decode(pred.label_ids, skip_special_tokens=True)\n\n    # Update WER metric\n    wer_score = eval_metric.compute(predictions=decoded_predictions, references=decoded_labels)\n    return {"wer": wer_score["wer"]}\n    \n    pred_ids = eval_pred.predictions\n    label_ids = eval_pred.label_ids\n\n    pred_decoded = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n    label_decoded = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n\n    wer = 100 * eval_metric.compute(predictions=pred_decoded, references=label_decoded)\n\n    return {\'WER\': wer}    '

Pushing the model to GPU

In [15]:
model.to(device)

WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(80, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(512, 512, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 512)
      (layers): ModuleList(
        (0-5): 6 x WhisperEncoderLayer(
          (self_attn): WhisperSdpaAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=False)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          

In [16]:
# Create Data Collator
def data_collator(batch):
    input_values = torch.stack([f["input_features"] for f in batch])
    labels = torch.stack([f["labels"] for f in batch])
    return {"input_features": input_values, "labels": labels}

# Define training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir='./results',              # output directory
    evaluation_strategy="epoch",         # evaluation strategy
    save_strategy="epoch",              # save checkpoint after every epoch
    learning_rate=2e-5,                  # learning rate
    per_device_train_batch_size=8,      # batch size for training
    per_device_eval_batch_size=2,       # batch size for evaluation
    num_train_epochs=5,                  # number of training epochs
    weight_decay=0.01,                   # strength of weight decay
    logging_dir='./logs',                # directory for storing logs
    fp16=True,                           # enable mixed precision to reduce memory usage
)

# Instantiate the Trainer
trainer = Seq2SeqTrainer(
    model=model,                                            # the model to be trained
    args=training_args,                                     # training arguments
    data_collator=data_collator,                            # data collator
    train_dataset=whisper_dataset['train'],                 # training dataset
    eval_dataset=whisper_dataset['test'],                   # evaluation dataset (optional)
    compute_metrics=compute_metrics
)

In [17]:
torch.cuda.empty_cache()

In [18]:
# Start fine-tuning the model
trainer.train()

                                                
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357, 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959, 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992, 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549, 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361], 'begin_suppress_tokens': [220, 50256]}


{'eval_loss': 0.016931457445025444, 'eval_wer': 0.06804478897502153, 'eval_runtime': 38.9152, 'eval_samples_per_second': 4.163, 'eval_steps_per_second': 2.081, 'epoch': 1.0}


                                                  
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357, 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959, 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992, 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549, 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361], 'begin_suppress_tokens': [220, 50256]}


{'eval_loss': 0.015643645077943802, 'eval_wer': 0.06632213608957795, 'eval_runtime': 20.1312, 'eval_samples_per_second': 8.047, 'eval_steps_per_second': 4.024, 'epoch': 2.0}


                                                 
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357, 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959, 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992, 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549, 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361], 'begin_suppress_tokens': [220, 50256]}


{'eval_loss': 0.014437884092330933, 'eval_wer': 0.06287683031869079, 'eval_runtime': 32.6822, 'eval_samples_per_second': 4.957, 'eval_steps_per_second': 2.478, 'epoch': 3.0}


                                                 
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357, 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959, 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992, 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549, 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361], 'begin_suppress_tokens': [220, 50256]}


{'eval_loss': 0.014445878565311432, 'eval_wer': 0.06287683031869079, 'eval_runtime': 23.7332, 'eval_samples_per_second': 6.826, 'eval_steps_per_second': 3.413, 'epoch': 4.0}


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357, 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959, 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992, 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549, 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361], 'begin_suppress_tokens': [220, 50256]}
                                                 
100%|██████████| 405/405 [06:17<00:00,  1.46it/s]

{'eval_loss': 0.014320522546768188, 'eval_wer': 0.06201550387596899, 'eval_runtime': 30.6139, 'eval_samples_per_second': 5.292, 'eval_steps_per_second': 2.646, 'epoch': 5.0}


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357, 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959, 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992, 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549, 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361], 'begin_suppress_tokens': [220, 50256]}
100%|██████████| 405/405 [06:18<00:00,  1.07it/s]

{'train_runtime': 378.4499, 'train_samples_per_second': 8.561, 'train_steps_per_second': 1.07, 'train_loss': 0.022598238344545718, 'epoch': 5.0}





TrainOutput(global_step=405, training_loss=0.022598238344545718, metrics={'train_runtime': 378.4499, 'train_samples_per_second': 8.561, 'train_steps_per_second': 1.07, 'total_flos': 2.101463875584e+17, 'train_loss': 0.022598238344545718, 'epoch': 5.0})

Manually saving the processor after the fine-tuning has been completed.

In [26]:
processor.save_pretrained(training_args.output_dir)

[]

Loading a checkpoint from the fine-tuning to be used as the model for prediction. From the above fine-tuning results, the last checkpoint (checkpoint-405) gave the best wer.

In [39]:
checkpoint_path = "results/checkpoint-405"

# Load the model and processor from the checkpoint
model = WhisperForConditionalGeneration.from_pretrained(checkpoint_path)
processor = WhisperProcessor.from_pretrained("results")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Comparison of the first file we trialed on the original Whisper-base.en model, after fine-tuning.

In [40]:
# Moving model to CUDA
model.to(device)

# Read the .wav file and downsample the audio to 16000Hz as required by WhisperAI
file_path = "Location_Trial.wav"
audio_data, sampling_rate = librosa.load(file_path, sr=16000)  # sr=16000 to convert the sampling rate to 16000Hz

# Use the model and processor to transcribe the audio:
input_features = processor(
    audio_data, sampling_rate=sampling_rate, return_tensors="pt"
).input_features

# Move input features to the same device as the model
input_features = input_features.to(device)

# Generate token ids using the model
with torch.no_grad():  # Disable gradient tracking for inference
    predicted_ids = model.generate(input_features)

# Decode token ids to text
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

# Print the transcription
print(transcription[0])

I have a friend staying at Sungei Kadut. We love to go to Geylang to eat. And we have family staying in Telok Blangah.
