# Fine-tuning a Hugging Face Wav2Vec2 model on the LibriSpeech dataset

## Initialize the Hugging Face Hub

In [None]:
from huggingface_hub import notebook_login

In [None]:
# Authenticate this notebook session with the Hugging Face Hub; the
# push_to_hub calls later in the notebook need a logged-in user.
notebook_login()

In [None]:
# Name shared by the Hub repository the tokenizer is pushed to and by the
# Trainer's output directory.
repo_name = "wav2vec2-finetuning-model"

## LibriSpeech dataset

In [None]:
import torch
import torchaudio
from datasets import Dataset, DatasetDict

In [None]:
# LibriSpeech subsets loaded through torchaudio, keyed by the role they play
# in this notebook (training / validation / test).
# NOTE(review): `download` is left at its default, so this presumably expects
# the corpus to already be present under the working directory — confirm.
dataset = {
    "train": torchaudio.datasets.LIBRISPEECH(root="", url="train-clean-100"), 
    "val": torchaudio.datasets.LIBRISPEECH(root="", url="dev-clean"), 
    "test": torchaudio.datasets.LIBRISPEECH(root="", url="test-clean")
}

In [None]:
def transform_dataset(sample):
    """Convert one raw LibriSpeech record into the dict layout used downstream.

    Args:
        sample: Indexable record whose item 0 is the waveform (only its first
            row is kept), item 1 the sampling rate and item 2 the transcript
            string.

    Returns:
        dict with the keys ``"audio"``, ``"sampling_rate"`` and ``"text"``.
        The transcript is lower-cased because the vocabulary built later in
        this notebook only contains lower-case letters.
    """
    audio = sample[0][0]
    sampling_rate = sample[1]
    text = sample[2].lower()

    return {
        "audio": audio,
        "sampling_rate": sampling_rate,
        "text": text,
    }

In [None]:
def _to_hf_dataset(torch_dataset):
    """Build a Hugging Face ``Dataset`` from a map-style torchaudio dataset.

    ``torchaudio.datasets.LIBRISPEECH`` is a ``torch.utils.data.Dataset`` and
    has no ``.map`` method, so every record is passed through
    ``transform_dataset`` and streamed into a ``datasets.Dataset``.  The
    result supports the ``.map`` call used by the preprocessing step.

    Args:
        torch_dataset: Dataset supporting ``len()`` and integer indexing.

    Returns:
        ``datasets.Dataset`` with the columns produced by
        ``transform_dataset``.
    """
    def generate_samples(source):
        # A generator keeps memory flat: records are written out one at a
        # time instead of being collected in a Python list first.
        for index in range(len(source)):
            yield transform_dataset(source[index])

    return Dataset.from_generator(
        generate_samples, gen_kwargs={"source": torch_dataset}
    )


for split in dataset.keys():
    dataset[split] = _to_hf_dataset(dataset[split])

## Vocab

In [None]:
import string
import json

In [None]:
# Map each lower-case ASCII letter to its alphabet position (a -> 0 ... z -> 25).
vocab = dict(zip(string.ascii_lowercase, range(len(string.ascii_lowercase))))

In [None]:
# Append the word delimiter, the apostrophe and the two special tokens,
# numbering them consecutively after the entries already in the vocabulary.
vocab.update(
    zip(
        ("|", "'", "<UNK>", "<PAD>"),
        range(len(vocab), len(vocab) + 4),
    )
)

In [None]:
# Notebook cell output: show the finished token -> id mapping for inspection.
vocab

In [None]:
# Persist the vocabulary so the tokenizer can be constructed from a file.
# The encoding is spelled out so the file does not depend on the platform's
# default locale encoding.
with open(r'vocab.json', 'w', encoding='utf-8') as vocab_file:
    json.dump(vocab, vocab_file)

## Text tokenizer

In [None]:
from transformers import Wav2Vec2CTCTokenizer

In [None]:
# Character-level CTC tokenizer built from the vocab.json file written earlier
# in this notebook. "|" is registered as the word delimiter, and "<UNK>" /
# "<PAD>" as the unknown and padding tokens.
tokenizer = Wav2Vec2CTCTokenizer(
    "./vocab.json", unk_token="<UNK>", pad_token="<PAD>", word_delimiter_token="|"
)

In [None]:
# Upload the tokenizer to the Hub repository named by repo_name.
tokenizer.push_to_hub(repo_name)

## Wav2Vec2 audio processing

In [None]:
from transformers import (
    Wav2Vec2FeatureExtractor, 
    Wav2Vec2Processor
)

In [None]:
# Feature extractor for raw 16 kHz waveforms: one feature per time step,
# zero padding, and normalisation of each input enabled.
# No attention mask is requested from the extractor.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, sampling_rate=16000, padding_value=0.0,
    do_normalize=True, return_attention_mask=False
)

In [None]:
# Bundle the feature extractor and the tokenizer into a single processor
# object, which is what the preprocessing, collator and metric code use.
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

## Preprocessing dataset

In [None]:
def data_preprocessing(sample):
    """Turn one sample into model inputs and CTC label ids.

    Args:
        sample: Mapping with the keys ``'audio'`` (raw waveform),
            ``'sampling_rate'`` and ``'text'`` (transcript).

    Returns:
        The same mapping, where ``'audio'`` now holds the processor's
        ``input_values`` for the waveform, ``'label'`` holds the token ids of
        the transcript, and ``'length'`` holds the number of input values.
    """
    input_values = processor(
        sample['audio'], sampling_rate=sample["sampling_rate"]
    ).input_values[0]
    sample['audio'] = input_values

    # The Trainer is configured with group_by_length=True. Its sampler takes
    # lengths from a "length" column (the default length_column_name) when one
    # exists; otherwise it looks for an "input_values" key, which this dataset
    # does not have because the inputs are stored under "audio".
    sample['length'] = len(input_values)

    # Inside this context the processor delegates to the tokenizer.
    with processor.as_target_processor():
        sample['label'] = processor(sample['text']).input_ids

    return sample

In [None]:
# Run the processor over every split, replacing each entry in place.
for split_name, split_data in dataset.items():
    dataset[split_name] = split_data.map(data_preprocessing)

## Data Collator

In [None]:
import torch

from dataclasses import dataclass, field
from typing import Optional, Union

In [None]:
@dataclass
class DataCollatorCTCWithPadding:
    """Collate preprocessed samples into one padded batch for CTC training.

    Every sample must be a mapping with an ``'audio'`` entry (model input
    values) and a ``'label'`` entry (target token ids).  Inputs and labels are
    padded by two separate ``processor.pad`` calls, each with its own
    ``max_length`` / ``pad_to_multiple_of`` settings.
    """

    # Processor whose ``pad`` method is used for the inputs and, inside
    # ``as_target_processor()``, for the labels.
    processor: Wav2Vec2Processor
    # Padding strategy forwarded to both ``pad`` calls.
    padding: Union[bool, str] = True
    # ``max_length`` forwarded when padding the inputs.
    max_length: Optional[int] = None
    # ``max_length`` forwarded when padding the labels.
    max_length_labels: Optional[int] = None
    # ``pad_to_multiple_of`` forwarded when padding the inputs.
    pad_to_multiple_of: Optional[int] = None
    # ``pad_to_multiple_of`` forwarded when padding the labels.
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, samples):
        """Pad ``samples`` and return the batch with a ``'labels'`` entry added."""
        audio_features = [{'input_values': item['audio']} for item in samples]
        target_features = [{'input_ids': item['label']} for item in samples]

        batch = self._pad_inputs(audio_features)
        padded_targets = self._pad_labels(target_features)

        # Positions the attention mask does not mark with 1 are padding; they
        # are set to -100 so the loss skips them.
        is_padding = padded_targets.attention_mask.ne(1)
        batch['labels'] = padded_targets['input_ids'].masked_fill(is_padding, -100)

        return batch

    def _pad_inputs(self, features):
        """Pad the input features into PyTorch tensors."""
        return self.processor.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

    def _pad_labels(self, features):
        """Pad the label features into PyTorch tensors using the target processor."""
        with self.processor.as_target_processor():
            return self.processor.pad(
                features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

In [None]:
# Collator instance handed to the Trainer; padding=True is forwarded to the
# processor's pad calls.
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

## Metrics

In [None]:
from datasets import load_metric

In [None]:
# Word-error-rate metric used by compute_metrics.
# NOTE(review): `load_metric` is deprecated in newer `datasets` releases in
# favour of the separate `evaluate` package — confirm the pinned version.
wer_metric = load_metric('wer')

In [None]:
def compute_metrics(pred):
    """Compute the word error rate for one evaluation pass.

    Args:
        pred: Prediction object exposing ``predictions`` (model logits) and
            ``label_ids`` (reference token ids in which ignored positions are
            marked with -100).  ``label_ids`` is modified in place.

    Returns:
        dict with a single ``'wer'`` entry.
    """
    # Imported here because no cell of the notebook imports numpy.
    import numpy as np

    pred_logits = pred.predictions
    # Pick the highest-scoring token id along the last axis.
    pred_ids = np.argmax(pred_logits, axis=-1)

    # -100 is not a valid token id; swap it for the pad id before decoding.
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # The references are plain transcripts, so repeated characters are real
    # and must not be collapsed.
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {'wer': wer}

## Model

In [None]:
from transformers import Wav2Vec2ForCTC

In [None]:
# Load the pretrained base checkpoint with a CTC head on top.
model = Wav2Vec2ForCTC.from_pretrained(
    'facebook/wav2vec2-base',
    ctc_loss_reduction='mean',
    pad_token_id=processor.tokenizer.pad_token_id,
    # Size the output layer from the tokenizer built in this notebook rather
    # than relying on the checkpoint's default vocabulary size.
    vocab_size=len(processor.tokenizer)
)

In [None]:
# Exclude the model's feature encoder from gradient updates during fine-tuning.
model.freeze_feature_encoder()

## Init arguments

In [None]:
from transformers import TrainingArguments

In [None]:
# Hyper-parameters and bookkeeping for the Trainer.
training_args = TrainingArguments(
    output_dir=repo_name,
    # Batch samples of similar length together to reduce padding.
    group_by_length=True,
    per_device_train_batch_size=32,
    evaluation_strategy='steps',
    num_train_epochs=30,
    fp16=True,
    gradient_checkpointing=True,
    # Checkpoint, evaluate and log on the same 500-step cadence.
    save_steps=500,
    eval_steps=500,
    logging_steps=500,
    learning_rate=1e-4,
    weight_decay=5e-3,
    warmup_steps=1000,
    # Keep only the two most recent checkpoints on disk.
    save_total_limit=2,
    # By default the Trainer drops dataset columns that are not arguments of
    # model.forward. The collator reads the "audio" column, so every column
    # must be kept.
    remove_unused_columns=False
)

## Training...

In [None]:
from transformers import Trainer

In [None]:
# Wire the model, collator, metric function and data splits into a Trainer.
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    # The preprocessed LibriSpeech splits live in the `dataset` dict.
    train_dataset=dataset['train'],
    # Periodic evaluation uses the validation split, so the test split stays
    # untouched until a final assessment.
    eval_dataset=dataset['val'],
    tokenizer=processor.feature_extractor
)

In [None]:
# Run fine-tuning with the Trainer configured in the previous cell.
trainer.train()

In [None]:
# Push the training results to the Hugging Face Hub.
trainer.push_to_hub()