# Finetune an Automatic Speech Recognition (ASR) AI model

In [3]:
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import pandas as pd
import torchaudio
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch

## Load Pre-Trained Model

In [2]:
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h")

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Import cv-valid-train.csv

In [3]:
ds = pd.read_csv('../data/cv-valid-train.csv')
ds.head()

Unnamed: 0,filename,text,up_votes,down_votes,age,gender,accent,duration
0,cv-valid-train/sample-000000.mp3,learn to recognize omens and follow them the o...,1,0,,,,
1,cv-valid-train/sample-000001.mp3,everything in the universe evolved he said,1,0,,,,
2,cv-valid-train/sample-000002.mp3,you came so that you could learn about your dr...,1,0,,,,
3,cv-valid-train/sample-000003.mp3,so now i fear nothing because it was those ome...,1,0,,,,
4,cv-valid-train/sample-000004.mp3,if you start your emails with greetings let me...,3,2,,,,


In [4]:
ds.shape

(195776, 8)

## Split Data into Training and Validation Set

Due to time and resource constraint, I will be using a small subset of the data provided to fine-tune the model.
Consequently, the usage of fewer data will lead to lower model performance.

The resulting data will then be split into training and validation data using a 70-30 ratio.

In [None]:
size = 51

In [None]:
generator1 = torch.Generator().manual_seed(1)
train_set, val_set = torch.utils.data.random_split(ds[:size], [0.7, 0.3], generator = generator1)

## Data Preprocessing

During data preprocessing, audio that are not 16kHz will be resampled. It will then pass through the processor from the pre-trained model to obtain the input values required for fine-tuning. To ensure that the input are of the same shape/ dimension for fine-tuning, the input values will be padded to standardise its length to 180,000. Any longer will be truncated.

In [15]:
required_rate = 16000
max_len = 0
for i in tqdm(train_set.dataset['filename']):
    waveform, sample_rate = torchaudio.load(f'../data/cv-valid-train/{i}')
    if sample_rate != required_rate:
        resampler = torchaudio.transforms.Resample(sample_rate, required_rate)
        resampled_waveform = resampler(waveform)
    else:
        resampled_waveform = waveform

    input_values = processor(resampled_waveform[0], return_tensors="pt", sampling_rate = required_rate, padding = 'do_not_pad').input_values
    if len(input_values.squeeze(0)) > max_len:
        max_len = len(input_values.squeeze(0))

print(f'Maximum length: {max_len}')

100%|██████████| 51/51 [00:01<00:00, 50.33it/s]

Maximum length: 174336





In [4]:
def preprocess_input(audio_filename):
    required_rate = 16000
    waveform, sample_rate = torchaudio.load(f'../data/{audio_filename}')
    if sample_rate != required_rate:
        resampler = torchaudio.transforms.Resample(sample_rate, required_rate)
        resampled_waveform = resampler(waveform)
    else:
        resampled_waveform = waveform

    input_values = processor(resampled_waveform[0], return_tensors="pt", padding="max_length"
                             , sampling_rate = required_rate, max_length=180000, truncation = True).input_values
    return input_values

In [8]:
training_data = []
validation_data = []

for i in tqdm(train_set.indices):
    input = preprocess_input(f'cv-valid-train/{ds.loc[i, 'filename']}').squeeze(0)
    output = processor.tokenizer(ds.loc[i, 'text'], return_tensors="pt", padding="max_length", max_length = 200, truncation = True)
    training_data.append({'input': input, 'output': output['input_ids'][0]})

for i in tqdm(val_set.indices):
    input = preprocess_input(f'cv-valid-train/{ds.loc[i, 'filename']}').squeeze(0)
    output = processor.tokenizer(ds.loc[i, 'text'], return_tensors="pt", padding="max_length", max_length = 200, truncation = True)
    validation_data.append({'input': input, 'output': output['input_ids'][0]})

100%|██████████| 36/36 [00:00<00:00, 63.80it/s]
100%|██████████| 15/15 [00:00<00:00, 60.47it/s]


In [9]:
print(f'Size of training data: {round(len(train_set.indices)/train_set.dataset.shape[0],3)}')
print(f'Size of validation data: {round(len(val_set.indices)/train_set.dataset.shape[0],3)}')

Size of training data: 0.706
Size of validation data: 0.294


## Model Training

In [10]:
def model_training(batch_size, num_epochs, learning_rate):

    training_dataloader = DataLoader(training_data, batch_size = batch_size, shuffle = True)
    validation_dataloader = DataLoader(validation_data, batch_size = batch_size, shuffle = True)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr = learning_rate)
    loss_fn = torch.nn.CTCLoss()
    min_loss = 10**10

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        
        for batch in tqdm(training_dataloader):
            inputs = batch['input']            
            labels = batch['output']
            optimizer.zero_grad()
            outputs = model(inputs).logits
            input_lengths = torch.full((outputs.shape[0],), outputs.shape[1], dtype=torch.long)
            target_lengths = torch.full((labels.shape[0],), labels.shape[1], dtype=torch.long)
            loss = loss_fn(outputs.transpose(0, 1), labels, input_lengths, target_lengths)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        avg_train_loss = train_loss/len(training_dataloader)
        
        model.eval()
        val_loss = 0
        
        with torch.no_grad():
            for batch in tqdm(validation_dataloader):
                inputs = batch['input']
                labels = batch['output']
                outputs = model(inputs).logits
                input_lengths = torch.full((outputs.shape[0],), outputs.shape[1], dtype=torch.long)
                target_lengths = torch.full((labels.shape[0],), labels.shape[1], dtype=torch.long)
                loss = loss_fn(outputs.transpose(0, 1), labels, input_lengths, target_lengths)
                val_loss += loss.item()
        avg_val_loss = val_loss/len(validation_dataloader)
        
        print(epoch)
        print(avg_train_loss)
        print(avg_val_loss)

        if avg_val_loss < min_loss:
            min_loss = avg_val_loss
            best_model = model.state_dict()
    
    return min_loss, best_model
    

In [None]:
learning_rates = [0.00001, 0.0001]
batch_sizes = [2, 5]
num_epochs = 2
results = []

for i in learning_rates:
    for j in batch_sizes:
        print(f"learning_rate={i}, batch_size={j}")
        min_loss, best_model = model_training(learning_rate=i, batch_size=j, num_epochs=num_epochs)
        results.append((i, j, min_loss))

learning_rate=1e-05, batch_size=2


100%|██████████| 18/18 [19:24<00:00, 64.71s/it]
100%|██████████| 8/8 [01:29<00:00, 11.13s/it]


0
-13.44194910261366
-14.360321521759033


100%|██████████| 18/18 [12:41<00:00, 42.29s/it]
100%|██████████| 8/8 [01:24<00:00, 10.57s/it]


1
-11.269831551445854
-11.981998085975647
learning_rate=1e-05, batch_size=5


100%|██████████| 8/8 [41:04<00:00, 308.09s/it]
100%|██████████| 3/3 [03:42<00:00, 74.02s/it]


0
-8.014303758740425
-5.567410945892334


100%|██████████| 8/8 [29:16<00:00, 219.58s/it]
100%|██████████| 3/3 [01:35<00:00, 31.86s/it]


1
-4.39245143532753
-2.1101353963216147
learning_rate=0.0001, batch_size=2


100%|██████████| 18/18 [22:58<00:00, 76.59s/it]
100%|██████████| 8/8 [03:34<00:00, 26.81s/it]


0
0.4001913805388742
-0.23253468051552773


100%|██████████| 18/18 [20:27<00:00, 68.17s/it]
100%|██████████| 8/8 [00:53<00:00,  6.63s/it]


1
1.0771342772576544
0.9891796633601189
learning_rate=0.0001, batch_size=5


100%|██████████| 8/8 [17:50<00:00, 133.81s/it]
100%|██████████| 3/3 [00:51<00:00, 17.05s/it]


0
3.307461053133011
4.7875870068868


100%|██████████| 8/8 [29:48<00:00, 223.56s/it]
100%|██████████| 3/3 [03:06<00:00, 62.16s/it]

1
1.875274233520031
0.8801562984784445





Learning rate of 1e-05 with batch size of 2 performed best out of the different combinations of hyperparameters tested since it has the lowest loss value.

In [17]:
results

[(1e-05, 2, -14.360321521759033),
 (1e-05, 5, -5.567410945892334),
 (0.0001, 2, -0.23253468051552773),
 (0.0001, 5, 0.8801562984784445)]

In [30]:
final_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h")
final_model.load_state_dict(best_model)
processor.save_pretrained('./wav2vec2-large-960h-cv')
processor.tokenizer.save_pretrained('./wav2vec2-large-960h-cv')
final_model.save_pretrained('./wav2vec2-large-960h-cv')

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Utilize Fine-Tuned Model on cv-valid-test

In [13]:
ds_test = pd.read_csv('../data/cv-valid-test.csv')

In [None]:
ds_test = ds_test[:10]

In [8]:
processor = Wav2Vec2Processor.from_pretrained('./wav2vec2-large-960h-cv')
model = Wav2Vec2ForCTC.from_pretrained('./wav2vec2-large-960h-cv')

In [None]:
test_data = []

for i in tqdm(range(ds_test.shape[0])):
    input = preprocess_input(f'cv-valid-test/{ds_test.loc[i, 'filename']}').squeeze(0)
    output = processor.tokenizer(ds_test.loc[i, 'text'], return_tensors="pt", padding="max_length", max_length = 200, truncation = True)
    test_data.append({'input': input, 'output': output['input_ids'][0]})

  0%|          | 0/3995 [00:00<?, ?it/s]

 13%|█▎        | 504/3995 [00:03<00:28, 121.57it/s]

In [None]:
test_dataloader = DataLoader(test_data, batch_size = 20, shuffle = True)
loss_fn = torch.nn.CTCLoss()
test_loss = 0

for batch in tqdm(test_dataloader):
    inputs = batch['input']
    labels = batch['output']
    outputs = model(inputs).logits
    input_lengths = torch.full((outputs.shape[0],), outputs.shape[1], dtype=torch.long)
    target_lengths = torch.full((labels.shape[0],), labels.shape[1], dtype=torch.long)
    loss = loss_fn(outputs.transpose(0, 1), labels, input_lengths, target_lengths)
    test_loss += loss.item()

100%|██████████| 1/1 [01:11<00:00, 71.72s/it]


In [None]:
test_loss

0.19517859816551208

## Compare Fine-Tuned wav2vec2-large-960h-cv Model with Pre-Trained Model wav2vec2-large-960h

In [5]:
ds_dev = pd.read_csv('../asr/cv-valid-dev.csv')

In [6]:
ds_dev.loc[0]

filename                             cv-valid-dev/sample-000000.mp3
text              be careful with your prognostications said the...
up_votes                                                          1
down_votes                                                        0
age                                                             NaN
gender                                                          NaN
accent                                                          NaN
duration                                                        NaN
generated_text    BE CAREFUL WITH YOUR PROGNOSTICATIONS SAID THE...
Name: 0, dtype: object

In [11]:
input = preprocess_input(f'cv-valid-dev/{ds_dev.loc[0, 'filename']}')
logits = model(input).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)

In [12]:
transcription

['']