## Fine-tuning facebook/wav2vec2-large-960h

Here, we fine-tune the `facebook/wav2vec2-large-960h` model from Hugging Face on the `cv-valid-train` split of the Common Voice dataset. This notebook follows the fine-tuning framework from this [Hugging Face blog post](https://huggingface.co/blog/fine-tune-wav2vec2-english), with minor adaptations. First, we import the required libraries.

In [1]:
# Imports
import os
import re
import random
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import gc
from multiprocessing import Pool, cpu_count

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from IPython.display import Audio as PlayAudio

from accelerate import Accelerator
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
from transformers import TrainingArguments, Trainer
from datasets import load_dataset, Audio, DatasetDict, load_from_disk, Dataset
import evaluate

import torch
from torch.utils.data import DataLoader
import torchaudio
from transformers import get_linear_schedule_with_warmup
from torch.utils.tensorboard import SummaryWriter
from torch.optim import AdamW
from torch.amp import autocast, GradScaler
from tqdm import tqdm

from pydub import AudioSegment
import soundfile as sf
from mutagen import File

from jiwer import wer

# Current user's home directory; every project path in this notebook is built from it.
HOME_DIR = os.path.expanduser("~")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Helpers
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo

def print_gpu_utilization(device_index=0):
    """Print how much memory is currently used on one NVIDIA GPU, in MB (MiB).

    Args:
        device_index: NVML index of the GPU to query. Defaults to 0, the GPU this
            notebook has always inspected.
    """
    # Local import: the cell header only brings in the init/query helpers.
    from pynvml import nvmlShutdown

    nvmlInit()
    try:
        handle = nvmlDeviceGetHandleByIndex(device_index)
        info = nvmlDeviceGetMemoryInfo(handle)
        print(f"GPU memory occupied: {info.used//1024**2} MB.")
    finally:
        # Balance nvmlInit() so repeated calls do not leave NVML initialised.
        nvmlShutdown()


def print_summary(result):
    """Print runtime and throughput of a training run, followed by current GPU memory use.

    Args:
        result: the object returned by ``trainer.train()``; its ``metrics`` dict must
            contain ``'train_runtime'`` and ``'train_samples_per_second'``.
    """
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()

### Pre-processing

We first convert all mp3 files to wav files for use in the wav2vec2 pipeline. We also convert the transcript texts to upper case to match the vocabulary of the original model. This step may take some time.

In [3]:
# File locations. All files are assumed to be placed under the ~/asr_project folder.
# Root of the original cv-valid-train download.
audio_or_dir = os.path.join(HOME_DIR, "asr_project", "common_voice", "cv-valid-train", "")
# Folder holding the individual clip files.
audio_dir = os.path.join(HOME_DIR, "asr_project", "common_voice", "cv-valid-train", "cv-valid-train", "")
# Original transcript CSV shipped with the dataset.
audioloc_transcript_or_dir = os.path.join(HOME_DIR, "asr_project", "common_voice", "cv-valid-train.csv")
# Transcripts kept after dropping clips of 15 s or longer.
audioloc_transcript_dir = os.path.join(HOME_DIR, "asr_project", "asr-train", "selected_transcript.csv")
# Scratch CSV (a file, despite the name) holding absolute clip paths for load_dataset.
temp_dir = os.path.join(HOME_DIR, "asr_project", "asr-train", "temp.csv")

In [4]:
# # Function to convert mp3 to wav
# def convert_mp3_to_wav(mp3_file):
#     # Generate the output wav file path
#     wav_file = mp3_file.replace('.mp3', '.wav')
    
#     # Convert mp3 to wav if wav file does not exist
#     if not os.path.exists(wav_file):
#         waveform, sample_rate = torchaudio.load(mp3_file)
#         torchaudio.save(wav_file, waveform, sample_rate)
    
#     return wav_file


# df = pd.read_csv(audioloc_transcript_or_dir)

# # Convert mp3 to wav. Change mp3 file extension in df accordingly
# df['filename'] = df['filename'].apply(
#     lambda filename: convert_mp3_to_wav(
#         os.path.join(audio_or_dir, filename)))

# # Put texts to uppercase to match pre-finetuned model
# df['text'] = df['text'].str.upper()
# df['filename'] = df['filename'].map(lambda x: os.path.basename(x))

# df_transcript = df

Checking audio file characteristics ...

In [5]:
# def get_audio_info(file_path):
#     # Extract filename and extension
#     file_name, file_ext = os.path.splitext(os.path.basename(file_path))
#     file_size = os.path.getsize(file_path)  # Size in bytes

#     # Try to get audio length with mutagen
#     try:
#         audio = File(file_path)
#         audio_length = audio.info.length if audio and audio.info else None
#     except Exception as e:
#         print(f"Could not process file {file_name}: {e}")
#         audio_length = None

#     return {
#         'filename': file_name,
#         'extension': file_ext,
#         'size_bytes': file_size,
#         'length_seconds': audio_length
#     }

# def process_directory(directory):
#     # List all audio files in directory
#     audio_files = [
#         os.path.join(directory, f) for f in os.listdir(directory) 
#         if os.path.isfile(os.path.join(directory, f))
#     ]

#     # Use tqdm with multiprocessing
#     with Pool(cpu_count()) as pool:
#         # Wrap audio files list with tqdm for progress bar
#         audio_info = list(tqdm(pool.imap(get_audio_info, audio_files), total=len(audio_files), desc="Processing files"))

#     # Create DataFrame from the list of dictionaries
#     df = pd.DataFrame(audio_info)
#     return df

# # Get audio file information
# audio_df = process_directory(audio_dir)
# audio_df_mp3 = audio_df.loc[audio_df['extension']=='.mp3'].copy()
# audio_df_wav = audio_df.loc[audio_df['extension']=='.wav'].copy().drop(columns=['length_seconds'])
# audio_df_wav = audio_df_wav.merge(audio_df_mp3[['filename','length_seconds']], on='filename', how='left')
# audio_df_wav['filename'] = audio_df_wav['filename'].map(lambda x: x+'.wav')
# audio_df_wav.head()

Some clips are very long, with durations of up to 6 minutes.

In [6]:
# audio_df_wav.describe()

Checking the transcripts, we find that the longest one is only 33 words long, which should take nowhere near that long to read.

In [7]:
# df_transcript['len'] = df_transcript['text'].str.len()
# df_transcript = df_transcript[['filename','len','text']]

# filename_longest = df_transcript.loc[df_transcript['len']==df_transcript['len'].max(), 'filename'].item()
# text_longest = df_transcript.loc[df_transcript['len']==df_transcript['len'].max(), 'text'].item()

# print(f'Filename with longest transcript: {filename_longest}')
# print(f'Longest transcript text: {text_longest}')

# longest_clip_duration = audio_df_wav.loc[audio_df_wav['filename']==filename_longest,'length_seconds'].item()
# print(f'Longest transcript duration: {longest_clip_duration}s')

The clip with the longest transcript is 11s long. Considering differences in reading speeds, we assume the longest legitimate script reading to be 15s long. __We discard all samples with durations above 15s__. This will help prevent memory issues during model finetuning. We drop a total of 397 samples, keeping ~195k samples, saving a copy as csv file for later reference.

In [8]:
# df_transcript = df_transcript.merge(audio_df_wav[['filename', 'length_seconds']], on='filename', how='left')
# (df_transcript['length_seconds'] > 15).sum().item(),  (df_transcript['length_seconds'] < 15).sum().item()

In [9]:
# df_transcript = df_transcript.loc[df_transcript['length_seconds']<15].drop(columns=['len','length_seconds'])
# df_transcript.to_csv(audioloc_transcript_dir, index=False)

We create a `DatasetDict` for easy access to train-val splits.

In [10]:
# # Load csv file with wav filenames, complete path and create dataset
# df = pd.read_csv(audioloc_transcript_dir)
# df['filename'] = df['filename'].map(lambda x: os.path.join(audio_dir,x))
# df.to_csv(temp_dir,index=False)

In [11]:
# dataset = load_dataset('csv', data_files=temp_dir, split='train')
# dataset = dataset.cast_column("filename",
#                               Audio(sampling_rate=16000))         # Cast audio files with 16kHz sampling rate

# # train-val 70-30 split
# dataset = dataset.train_test_split(test_size=0.3, seed=42)        # Split to train-val

# # Final, combined dataset
# dataset = DatasetDict({
#     'train': dataset['train'],
#     'val': dataset['test']})

# dataset

We will use the tokenizer and processor from `facebook/wav2vec2-large-960h` in the model fine-tuning below. First, the transcripts are converted to the format expected by the model. For this purpose, the transcripts were already converted to upper case earlier. Below, we add the start (`<s>`) and end (`</s>`) tokens and replace spaces with the word-delimiter token (`|`).

In [12]:
# # Following the style of facebook/wav2ec2-large-960h model
# start_token = "<s>"
# end_token = "</s>"
# word_delimiter_token = "|"

# # Define the preprocessing function
# def preprocess_transcript(example):
#     transcript = example['text']  # Assuming the column with text is named 'text'
    
#     # Step 1: Replace multiple spaces with a single space
#     transcript = re.sub(r'\s+', ' ', transcript)  # Remove extra spaces
    
#     # Step 2: Add start and end tokens, and replace spaces with '|'
#     processed_transcript = start_token + transcript.replace(" ", f"{word_delimiter_token}") + end_token
    
#     return {"processed_text": processed_transcript}  # Return the processed text in a dictionary

# # Apply the preprocessing to both train and validation splits
# dataset = dataset.map(preprocess_transcript, remove_columns=["text"],num_proc=4)

Converting to column names expected by model.

In [13]:
# dataset = dataset.rename_column("filename", "input_values")
# dataset = dataset.rename_column("processed_text", "labels")

Then, we tokenize the transcripts and use the `input_values` and `labels` column names in the datasets.

In [14]:
# # Load processor
# processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")

# def prepare_dataset(batch):
#     # Process 'input_values' column for 1D waveform values
#     batch["input_values"] = processor(batch["input_values"]["array"],
#                                       sampling_rate=16000).input_values[0]
    
#     # Process the 'labels' column to create 'labels' (text data)
#     batch["labels"] = processor(text=batch["labels"]).input_ids
    
#     return batch

# # Map the dataset transformation to both 'train' and 'val' splits
# dataset = dataset.map(prepare_dataset, num_proc=2)


In [15]:
# # Save the dataset to a directory
# dataset.save_to_disk("temp_dataset")

For a quick check, play a random audio file below...

In [16]:
# rand_int = random.randint(0, len(dataset["train"]))
# print(dataset["train"]["labels"][rand_int])

# audio_data = dataset["train"][rand_int]["input_values"]
# PlayAudio(data=audio_data, rate=16000)

... and check the data formats, e.g. 1-D waveform.

In [17]:
# rand_int = random.randint(0, len(dataset["train"]))

# print("Target (encoded) text:", dataset["train"][rand_int]["labels"])
# print("Input array shape:", np.asarray(dataset["train"][rand_int]["input_values"]).shape)

### Training

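As elaborated [here](https://huggingface.co/blog/fine-tune-wav2vec2-english), a data collator with dynamic padding is more efficient for ASR applications. Input sequence lengths vary widely, and dynamic padding pads each batch only to the length of its own longest sample rather than to a global maximum.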

In [4]:
# Reload the preprocessed train/val DatasetDict written earlier with `save_to_disk`.
dataset = load_from_disk(dataset_path="temp_dataset")

In [5]:
# Feature extractor + tokenizer pair of the base model, shared by the collator and the metrics.
processor = Wav2Vec2Processor.from_pretrained(
    pretrained_model_name_or_path="facebook/wav2vec2-large-960h"
)

In [6]:
@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.

    Audio inputs and label ids have different lengths and need different padding, so they
    are padded separately: the waveforms with the processor's feature extractor, and the
    label ids with its tokenizer. Padded label positions are set to ``-100`` so the loss
    ignores them.

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding if
              only a single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set, pads the ``input_values`` to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.0 (Volta).
        pad_to_multiple_of_labels (:obj:`int`, `optional`):
            Same as ``pad_to_multiple_of``, applied to the ``labels``.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split inputs and labels: they have different lengths and need different padding.
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # Audio side: pad the waveform values with the feature extractor.
        batch = self.processor.feature_extractor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Label side: pad the token ids with the tokenizer. This calls the tokenizer directly;
        # the deprecated `processor.as_target_processor()` context manager routed
        # `processor.pad` to it.
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            max_length=self.max_length_labels,
            pad_to_multiple_of=self.pad_to_multiple_of_labels,
            return_tensors="pt",
        )

        # Replace padding with -100 so those positions are ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

Using the WER metric.

In [7]:
# Word error rate metric from the `evaluate` library; used by `compute_metrics` and by the WER cells below.
wer_metric = evaluate.load("wer")

def remove_start_end_tags(texts):
    """Return a new list in which each string loses a leading ``<s>`` and a trailing ``</s>``.

    Tags that appear anywhere other than the very start or the very end are kept.

    Args:
        texts: iterable of decoded transcription strings.

    Returns:
        list[str]: the cleaned strings, in the same order.
    """
    cleaned = []
    for text in texts:
        cleaned.append(re.sub(r"^<s>|</s>$", "", text))
    return cleaned

def compute_metrics(pred):
    """Compute the word error rate (WER) of greedy CTC predictions for the HF ``Trainer``.

    Args:
        pred: evaluation output exposing ``predictions`` (logits whose last axis is
            arg-maxed into token ids) and ``label_ids`` (token ids in which ``-100`` marks
            padded positions, as produced by ``DataCollatorCTCWithPadding``).

    Returns:
        dict: ``{"wer": <float>}``.
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # Map the -100 loss-ignore marker back to the tokenizer's pad id so the labels can be
    # decoded. Build a new array instead of writing into `pred.label_ids`, so the caller's
    # evaluation data is left untouched.
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)

    # Predictions use the default CTC grouping of repeated tokens; references are plain
    # token sequences and must not be grouped.
    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    # Remove the <s> and </s> tags from the decoded strings
    pred_str = remove_start_end_tags(pred_str)
    label_str = remove_start_end_tags(label_str)

    word_error_rate = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": word_error_rate}

Finally, we load the pre-trained model.

In [None]:
# Load the pre-trained CTC model, passing the tokenizer's pad id to the model config.
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-960h",
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
)

# Alternative (not used here): freeze only the convolutional feature encoder with
# model.freeze_feature_encoder()

# Fine-tune only the output projection `lm_head`: freeze every parameter of the model,
# then re-enable gradients for the head's parameters.
model.requires_grad_(False)
model.lm_head.requires_grad_(True)


Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


We first get a baseline WER for a quick comparison with the finetuned model's performance later.

In [23]:
# def map_to_result(batch):
#     with torch.no_grad():
#         input_values = torch.tensor(batch["input_values"], device="cuda").unsqueeze(0)
#         logits = model(input_values).logits

#     pred_ids = torch.argmax(logits, dim=-1)
#     batch["pred_str"] = processor.batch_decode(pred_ids)[0]
#     batch["text"] = processor.decode(batch["labels"], group_tokens=False)
  
#     return batch

# model.to('cuda')
# results = dataset["val"].map(map_to_result, remove_columns=dataset["val"].column_names)

In [24]:
# def remove_start_end_tags(batch):
#     # Remove the <s> and </s> tags from both ends of each string in 'pred_str' and 'text'
#     batch["pred_str"] = re.sub(r"^<s>|</s>$", "", batch["pred_str"])
#     batch["text"] = re.sub(r"^<s>|</s>$", "", batch["text"])
#     return batch

# # Apply the function to the entire dataset
# results = results.map(remove_start_end_tags)


# print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["text"])))

Before fine-tuning, the model has a WER of about 10.5% on the validation split. We complete the setup for the Hugging Face `Trainer` and begin training below.

In [None]:
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments


In [26]:
# Start training.
# `trainer` was not constructed anywhere in this notebook (the cell above only imports), so
# this cell raised NameError on a fresh kernel. The settings below are reconstructed from
# the log of the original run:
#   - 2137 optimizer steps for the 136,764 training samples (1 epoch) -> 64 samples per step
#   - learning rate rises linearly to 5e-5 at step 1000, then decays linearly to 0
#   - loss logged every 250 steps, evaluation every 500 steps
#   - eval samples/s divided by eval steps/s is about 16 -> eval batch size 16
#   - checkpoints were read back from ~/asr_project/asr-train/model_outputs
# NOTE(review): the 16 x 4 split of the 64-sample effective batch is an assumption; confirm
# against the original configuration.
training_args = TrainingArguments(
    output_dir=os.path.join(HOME_DIR, "asr_project/asr-train/model_outputs"),
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    learning_rate=5e-5,
    warmup_steps=1000,
    logging_steps=250,
    eval_strategy="steps",  # called `evaluation_strategy` in transformers < 4.41
    eval_steps=500,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=dataset["train"],
    eval_dataset=dataset["val"],
)

result = trainer.train()
print_summary(result)

 12%|█▏        | 250/2137 [05:12<38:47,  1.23s/it]

{'loss': 9.9285, 'grad_norm': 4.3315043449401855, 'learning_rate': 1.2450000000000001e-05, 'epoch': 0.12}


 23%|██▎       | 500/2137 [10:20<29:53,  1.10s/it]

{'loss': 9.3718, 'grad_norm': 3.421187162399292, 'learning_rate': 2.495e-05, 'epoch': 0.23}


                                                  
 23%|██▎       | 500/2137 [19:23<29:53,  1.10s/it] 

{'eval_loss': 2.665811061859131, 'eval_wer': 0.1104161030247281, 'eval_runtime': 542.9331, 'eval_samples_per_second': 107.958, 'eval_steps_per_second': 6.749, 'epoch': 0.23}


 35%|███▌      | 750/2137 [24:13<28:57,  1.25s/it]    

{'loss': 8.5016, 'grad_norm': 3.1834044456481934, 'learning_rate': 3.745e-05, 'epoch': 0.35}


 47%|████▋     | 1000/2137 [29:04<22:49,  1.20s/it]

{'loss': 7.5337, 'grad_norm': 2.6247568130493164, 'learning_rate': 4.995e-05, 'epoch': 0.47}


                                                   
 47%|████▋     | 1000/2137 [35:58<22:49,  1.20s/it]

{'eval_loss': 1.6904981136322021, 'eval_wer': 0.11616434897101528, 'eval_runtime': 413.9829, 'eval_samples_per_second': 141.586, 'eval_steps_per_second': 8.851, 'epoch': 0.47}


 58%|█████▊    | 1250/2137 [41:07<16:26,  1.11s/it]    

{'loss': 6.5242, 'grad_norm': 2.016247510910034, 'learning_rate': 3.9050131926121375e-05, 'epoch': 0.58}


 70%|███████   | 1500/2137 [46:22<13:34,  1.28s/it]

{'loss': 5.7639, 'grad_norm': 1.742484450340271, 'learning_rate': 2.805628847845207e-05, 'epoch': 0.7}


                                                   
 70%|███████   | 1500/2137 [53:04<13:34,  1.28s/it]

{'eval_loss': 1.2211542129516602, 'eval_wer': 0.1202514293959562, 'eval_runtime': 401.3354, 'eval_samples_per_second': 146.047, 'eval_steps_per_second': 9.13, 'epoch': 0.7}


 82%|████████▏ | 1750/2137 [58:12<08:27,  1.31s/it]    

{'loss': 5.2705, 'grad_norm': 2.2663869857788086, 'learning_rate': 1.706244503078276e-05, 'epoch': 0.82}


 94%|█████████▎| 2000/2137 [1:03:24<02:56,  1.29s/it]

{'loss': 4.9872, 'grad_norm': 2.0656380653381348, 'learning_rate': 6.068601583113457e-06, 'epoch': 0.94}


                                                     
 94%|█████████▎| 2000/2137 [1:10:00<02:56,  1.29s/it]

{'eval_loss': 1.0519750118255615, 'eval_wer': 0.1210215897408149, 'eval_runtime': 396.6229, 'eval_samples_per_second': 147.783, 'eval_steps_per_second': 9.238, 'epoch': 0.94}


100%|██████████| 2137/2137 [1:12:38<00:00,  2.04s/it]   

{'train_runtime': 4358.677, 'train_samples_per_second': 31.377, 'train_steps_per_second': 0.49, 'train_loss': 7.080849418586585, 'epoch': 1.0}
Time: 4358.68
Samples/second: 31.38
GPU memory occupied: 14071 MB.





In [27]:
# Load audio file. Built from `audio_dir` instead of a hard-coded absolute /home/<user> path.
audio_file = os.path.join(audio_dir, "sample-000000.wav")  # Replace with your audio file path
waveform, sample_rate = torchaudio.load(audio_file)


# If the sample rate is not 16kHz, resample it
if sample_rate != 16000:
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
    waveform = resampler(waveform)

# Down-mix to a single 1-D channel. torchaudio.load returns (channels, frames) by default;
# for a mono clip this equals `waveform.squeeze()`, and for a stereo clip it averages the
# channels instead of passing a 2-D array to the processor.
mono_waveform = waveform.mean(dim=0)

# Convert to the right format for the model
input_values = processor(mono_waveform.numpy(), sampling_rate=16000, return_tensors="pt").input_values

# Get logits from the model on CPU. Switch explicitly to eval mode: the model may still be
# in training mode after `trainer.train()`, and train-only behaviour such as dropout would
# otherwise perturb the transcription.
model.to('cpu')
model.eval()
with torch.no_grad():
    logits = model(input_values).logits

# Get predicted ids
predicted_ids = logits.argmax(dim=-1)

# Decode the predicted ids to text
transcription = processor.batch_decode(predicted_ids)

print(transcription)  # Print the transcription result



['LEARNED T RECOGNIE OMEN AND  FOLLOW THEM THEOLD  KING HAD SAID']


In [8]:

# Reload the final fine-tuned checkpoint written by the Trainer. The processor comes from
# the base model, because only the model weights are loaded from the checkpoint here.
model_dir = os.path.expanduser(
    os.path.join("~", "asr_project", "asr-train", "model_outputs", "checkpoint-2137")
)

model = Wav2Vec2ForCTC.from_pretrained(model_dir)
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")


In [None]:
def map_to_result(batch):
    """Greedy-decode one dataset example with the current `model`.

    Args:
        batch: a single dataset row with ``input_values`` (the 1-D model input) and
            ``labels`` (reference token ids).

    Returns:
        The same row with ``pred_str`` (model transcription) and ``text`` (decoded
        reference, without grouping repeated tokens) added.
    """
    with torch.no_grad():
        # Create the input on the model's own device instead of hard-coding "cuda".
        input_values = torch.tensor(batch["input_values"], device=model.device).unsqueeze(0)
        logits = model(input_values).logits

    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_str"] = processor.batch_decode(pred_ids)[0]
    batch["text"] = processor.decode(batch["labels"], group_tokens=False)

    return batch


# Use the GPU when one is available; otherwise fall back to CPU.
model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
results = dataset["val"].map(map_to_result, remove_columns=dataset["val"].column_names)

In [None]:
# Remove the leading <s> and trailing </s> tags from both 'pred_str' and 'text'.
# A lambda is used instead of redefining `remove_start_end_tags` with a per-example
# signature, which would shadow the list-based helper that `compute_metrics` calls.
results = results.map(
    lambda example: {
        "pred_str": re.sub(r"^<s>|</s>$", "", example["pred_str"]),
        "text": re.sub(r"^<s>|</s>$", "", example["text"]),
    }
)


print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["text"])))
# WER of 0.12 for eval set

Map: 100%|██████████| 58614/58614 [00:00<00:00, 73247.36 examples/s]


Test WER: 0.121


In [None]:
# Transcribe the training split with `map_to_result` from the validation-set cell above,
# rather than redefining a copy of the same function here.
model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
results = dataset["train"].map(map_to_result, remove_columns=dataset["train"].column_names)

Map: 100%|██████████| 136764/136764 [55:19<00:00, 41.20 examples/s] 


In [None]:
# Remove the leading <s> and trailing </s> tags from both 'pred_str' and 'text'.
# A lambda is used instead of redefining `remove_start_end_tags` with a per-example
# signature, which would shadow the list-based helper that `compute_metrics` calls.
results = results.map(
    lambda example: {
        "pred_str": re.sub(r"^<s>|</s>$", "", example["pred_str"]),
        "text": re.sub(r"^<s>|</s>$", "", example["text"]),
    }
)


print("train WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["text"])))
# WER of 0.12 for train set.

Map: 100%|██████████| 136764/136764 [00:02<00:00, 66451.57 examples/s]


train WER: 0.120
