### Fine-tuning wav2vec 2.0 for English ASR with TIMIT

In [1]:
%%capture
!pip install datasets==1.18.3
!pip install transformers==4.17.0
!pip install jiwer

In [2]:
from datasets import load_dataset, load_metric
timit = load_dataset("timit_asr")

Downloading:   0%|          | 0.00/2.40k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

Downloading and preparing dataset timit_asr/clean (download: 828.75 MiB, generated: 7.90 MiB, post-processed: Unknown size, total: 836.65 MiB) to /root/.cache/huggingface/datasets/timit_asr/clean/2.0.1/b11b576ddcccbcefa7c9f0c4e6c2a43756f3033adffe0fb686aa61043d0450ad...


Downloading:   0%|          | 0.00/869M [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset timit_asr downloaded and prepared to /root/.cache/huggingface/datasets/timit_asr/clean/2.0.1/b11b576ddcccbcefa7c9f0c4e6c2a43756f3033adffe0fb686aa61043d0450ad. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
timit

DatasetDict({
    train: Dataset({
        features: ['file', 'audio', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id'],
        num_rows: 4620
    })
    test: Dataset({
        features: ['file', 'audio', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id'],
        num_rows: 1680
    })
})

In [4]:
timit = timit.remove_columns(["phonetic_detail", "word_detail", "dialect_region", "id", "sentence_type", "speaker_id"])

In [5]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

In [6]:
show_random_elements(timit["train"].remove_columns(["audio", "file"]), num_examples=10)

Unnamed: 0,text
0,Of particular importance is the study of the actions of drugs in this respect.
1,She had your dark suit in greasy wash water all year.
2,There is little doubt that the students benefit from vocational education.
3,Don't ask me to carry an oily rag like that.
4,Cooperation along with understanding alleviate dispute.
5,How much and how many profits could a majority take out of the losses of a few?
6,A moth zig-zagged along the path through Otto's garden.
7,The paper boy bought two apples and three ices.
8,A doctor was in the ambulance with the patient.
9,"By eating yogurt, you may live longer."


In [7]:
import re
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"]'

def remove_special_characters(batch):
    batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower() + " "
    return batch

In [8]:
timit = timit.map(remove_special_characters)



0ex [00:00, ?ex/s]

0ex [00:00, ?ex/s]

In [9]:
show_random_elements(timit["train"].remove_columns(["audio", "file"]))

Unnamed: 0,text
0,she had your dark suit in greasy wash water all year
1,his scalp was blistered from today's hot sun
2,this truth that the moral law is natural has other important corollaries
3,don't ask me to carry an oily rag like that
4,in my place you'd follow such advice as you give me
5,the rose corsage smelled sweet
6,she had your dark suit in greasy wash water all year
7,george seldom watches daytime movies
8,move the garbage nearer to the large window
9,she had your dark suit in greasy wash water all year


In [10]:
def extract_all_chars(batch):
  all_text = " ".join(batch["text"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}

In [11]:
vocabs = timit.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=timit.column_names["train"])

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [12]:
vocab_list = list(set(vocabs["train"]["vocab"][0]) | set(vocabs["test"]["vocab"][0]))

In [13]:
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict

{'h': 0,
 'e': 1,
 't': 2,
 'g': 3,
 'b': 4,
 'i': 5,
 'a': 6,
 'n': 7,
 'j': 8,
 'r': 9,
 "'": 10,
 'y': 11,
 'u': 12,
 'f': 13,
 'p': 14,
 'o': 15,
 'q': 16,
 ' ': 17,
 'v': 18,
 'k': 19,
 's': 20,
 'x': 21,
 'z': 22,
 'd': 23,
 'c': 24,
 'm': 25,
 'w': 26,
 'l': 27}

In [14]:
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

In [15]:
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)

30

In [16]:
import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In [17]:
from transformers import Wav2Vec2CTCTokenizer
tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

In [18]:
from transformers import Wav2Vec2FeatureExtractor
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, 
                                             padding_value=0.0, do_normalize=True, return_attention_mask=False)

In [19]:
from transformers import Wav2Vec2Processor
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

#### Preprocess Data

In [20]:
timit["train"][0]["file"]

'/root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV'

In [21]:
timit["train"][0]["audio"]

{'path': '/root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV',
 'array': array([-2.1362305e-04,  6.1035156e-05,  3.0517578e-05, ...,
        -3.0517578e-05, -9.1552734e-05, -6.1035156e-05], dtype=float32),
 'sampling_rate': 16000}

In [22]:
import IPython.display as ipd
import numpy as np
import random

rand_int = random.randint(0, len(timit["train"]))

print(timit["train"][rand_int]["text"])
ipd.Audio(data=np.asarray(timit["train"][rand_int]["audio"]["array"]), autoplay=True, rate=16000)

we do not arrive at spatial images by means of the sense of touch by itself 


In [23]:
rand_int = random.randint(0, len(timit["train"]))

print("Target text:", timit["train"][rand_int]["text"])
print("Input array shape:", np.asarray(timit["train"][rand_int]["audio"]["array"]).shape)
print("Sampling rate:", timit["train"][rand_int]["audio"]["sampling_rate"])

Target text: we got drenched from the uninterrupted rain 
Input array shape: (44954,)
Sampling rate: 16000


In [24]:
def prepare_dataset(batch):
    audio = batch["audio"]

    # batched output is "un-batched" to ensure mapping is correct
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["input_length"] = len(batch["input_values"])
    
    with processor.as_target_processor():
        batch["labels"] = processor(batch["text"]).input_ids
    return batch

In [25]:
timit = timit.map(prepare_dataset, remove_columns=timit.column_names["train"], num_proc=4)

In [26]:
max_input_length_in_sec = 4.0
timit["train"] = timit["train"].filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])

  0%|          | 0/5 [00:00<?, ?ba/s]

#### Training & Evaluation

In [27]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

In [28]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [29]:
wer_metric = load_metric("wer")

Downloading:   0%|          | 0.00/1.90k [00:00<?, ?B/s]

In [30]:
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [31]:
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base",
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
)

Downloading:   0%|          | 0.00/1.80k [00:00<?, ?B/s]

  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "


Downloading:   0%|          | 0.00/363M [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForCTC: ['project_hid.weight', 'quantizer.weight_proj.bias', 'quantizer.weight_proj.weight', 'project_q.bias', 'quantizer.codevectors', 'project_q.weight', 'project_hid.bias']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['lm_head.weight', 'lm_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predicti

In [32]:
model.freeze_feature_encoder()

In [36]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="wav2vec_sov",
  group_by_length=True,
  per_device_train_batch_size=8,
  evaluation_strategy="steps",
  num_train_epochs=30,
  fp16=True,
  gradient_checkpointing=True,
  save_steps=500,
  eval_steps=500,
  logging_steps=500,
  learning_rate=1e-4,
  weight_decay=0.005,
  warmup_steps=1000,
  save_total_limit=2,
)

In [37]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=timit["train"],
    eval_dataset=timit["test"],
    tokenizer=processor.feature_extractor,
)

Using amp half precision backend


### Training

In [38]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length. If input_length are not expected by `Wav2Vec2ForCTC.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 3978
  Num Epochs = 30
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 14940


Step,Training Loss,Validation Loss,Wer
500,3.5565,1.706126,1.007512
1000,0.87,0.560123,0.535042
1500,0.4519,0.450682,0.452829
2000,0.3106,0.440572,0.430294
2500,0.2292,0.453765,0.410723
3000,0.1906,0.432837,0.395149
3500,0.1591,0.455827,0.38557
4000,0.1358,0.425028,0.383916
4500,0.1247,0.463069,0.379436
5000,0.1078,0.446104,0.370271


The following columns in the evaluation set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length. If input_length are not expected by `Wav2Vec2ForCTC.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 1680
  Batch size = 8
Saving model checkpoint to wav2vec_sov/checkpoint-500
Configuration saved in wav2vec_sov/checkpoint-500/config.json
Model weights saved in wav2vec_sov/checkpoint-500/pytorch_model.bin
Feature extractor saved in wav2vec_sov/checkpoint-500/preprocessor_config.json
The following columns in the evaluation set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length. If input_length are not expected by `Wav2Vec2ForCTC.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 1680
  Batch size = 8
Saving model checkpoint to wav2vec_sov/checkpoint-1000
Configuration saved in wav2vec_sov/checkpoint-1000/c

TrainOutput(global_step=14940, training_loss=0.23804458703701117, metrics={'train_runtime': 6684.2754, 'train_samples_per_second': 17.854, 'train_steps_per_second': 2.235, 'total_flos': 3.0766610595969556e+18, 'train_loss': 0.23804458703701117, 'epoch': 30.0})

### Evaluate

In [39]:
model = trainer.model

In [40]:
def map_to_result(batch):
  with torch.no_grad():
    input_values = torch.tensor(batch["input_values"], device="cuda").unsqueeze(0)
    logits = model(input_values).logits

  pred_ids = torch.argmax(logits, dim=-1)
  batch["pred_str"] = processor.batch_decode(pred_ids)[0]
  batch["text"] = processor.decode(batch["labels"], group_tokens=False)
  
  return batch

In [41]:
results = timit["test"].map(map_to_result, remove_columns=timit["test"].column_names)

0ex [00:00, ?ex/s]



Let's compute the overall WER now.

In [42]:
print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["text"])))

Test WER: 0.287


28.7% WER

In [43]:
show_random_elements(results)

Unnamed: 0,pred_str,text
0,young children should avoide exposiure to contages diseses,young children should avoid exposure to contagious diseases
1,artofficial intelligence is for real,artificial intelligence is for real
2,their propes were two steplauters a chaire and a pomb fan,their props were two stepladders a chair and a palm fan
3,if people were more generous there would be no need for wealfare,if people were more generous there would be no need for welfare
4,the fish began to leep frantically on the surfacse of the small eac,the fish began to leap frantically on the surface of the small lake
5,her rite hand akes when ever is the warametric pressur changes,her right hand aches whenever the barometric pressure changes
6,only laowyers love milions,only lawyers love millionaires
7,the neareus sentagaud my not be with then walck in distance,the nearest synagogue may not be within walking distance
8,basket ball can be an entertaining sport,basketball can be an entertaining sport
9,she had your dark suit in greasy wash water all year,she had your dark suit in greasy wash water all year


In [44]:
model.to("cuda")

with torch.no_grad():
  logits = model(torch.tensor(timit["test"][:1]["input_values"], device="cuda")).logits

pred_ids = torch.argmax(logits, dim=-1)

# convert ids to tokens
" ".join(processor.tokenizer.convert_ids_to_tokens(pred_ids[0].tolist()))



'[PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] t h e e | b b [PAD] [PAD] [PAD] u n n g g g [PAD] l l l l [PAD] [PAD] [PAD] o o o | w w a a s s | | [PAD] [PAD] p p l l [PAD] e s s s s [PAD] e n n t [PAD] l l l y | | | s s [PAD] [PAD] i i i t t [PAD] [PAD] [PAD] u u u [PAD] [PAD] [PAD] a a a t t [PAD] e e d d | | [PAD] n e e e a a r r | t h e e | | s s s h h [PAD] o [PAD] [PAD] [PAD] o r r r r | | [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'