In [1]:
!pip install -U torchaudio librosa jiwer datasets transformers huggingface_hub evaluate python-dotenv wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchaudio
  Downloading torchaudio-2.0.1-cp39-cp39-manylinux1_x86_64.whl (4.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.4/4.4 MB[0m [31m24.1 MB/s[0m eta [36m0:00:00[0m
Collecting jiwer
  Downloading jiwer-3.0.1-py3-none-any.whl (21 kB)
Collecting datasets
  Downloading datasets-2.11.0-py3-none-any.whl (468 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m468.7/468.7 KB[0m [31m48.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting transformers
  Downloading transformers-4.27.4-py3-none-any.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m78.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting huggingface_hub
  Downloading huggingface_hub-0.13.3-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.8/199.8 KB[0m [31m25.8 MB/s[0m eta [36m0:00:00[

## Load and prepare data

In [2]:
import os
from dotenv import load_dotenv
# Load variables from a local .env file into os.environ (e.g. HUGGING_FACE_ACCESS_TOKEN,
# WANDB_API_KEY used later). The displayed `False` presumably means nothing was loaded
# in this session — NOTE(review): confirm those variables are provided some other way.
load_dotenv()

False

In [3]:
import evaluate
from datasets import load_dataset, load_metric, Audio, concatenate_datasets
from pandas import DataFrame, Series
import pandas as pd
import numpy as np
from functools import partial
from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor


In [4]:
def normalize_sentence_ends(batch):
    """Normalize the start and end of one transcription (for use with ``Dataset.map``).

    - If the whole transcription is wrapped in straight double quotes, that outer
      pair (leading AND trailing) is removed.
    - If the result is non-empty and does not end in ".", "?" or "!", a full stop
      is appended.

    Empty transcriptions (including one that was only a quote character) are left
    empty instead of raising ``IndexError``.

    Args:
        batch: a single example dict with a ``"transcription"`` string.

    Returns:
        The same ``batch`` with ``"transcription"`` updated.
    """
    transcription = batch["transcription"]

    if transcription.startswith('"') and transcription.endswith('"'):
        # the enclosing quote pair carries no spoken content
        transcription = transcription[1:-1]

    if transcription and not transcription.endswith((".", "?", "!")):
        # append a full stop to sentences that do not end in punctuation
        transcription += "."

    batch["transcription"] = transcription
    return batch


def prepare_model_inputs(batch, processor, text_column="transcription"):
    """Turn one (un-batched) dataset example into wav2vec2 CTC model inputs.

    Adds to ``batch``:
      * ``input_values``  - feature-extracted audio for this single example
      * ``input_length``  - number of entries in ``input_values``
      * ``labels``        - token ids of the example's text
      * ``labels_length`` - number of label tokens, tokenized without special tokens

    Args:
        batch: a single example with an ``"audio"`` dict (``"array"``,
            ``"sampling_rate"``) and a text column.
        processor: the Wav2Vec2Processor whose feature extractor and tokenizer are used.
            Its own tokenizer is used for both label fields (no global state).
        text_column: column holding the target text. Defaults to ``"transcription"``;
            the tokenizer vocabulary must have been built from the same column.

    Returns:
        The same ``batch`` dict, updated.
    """
    audio = batch["audio"]
    text = batch[text_column]

    # the processor returns a batch of one; take the single element back out
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["input_length"] = len(batch["input_values"])

    # tokenize labels with the processor's tokenizer directly instead of the
    # deprecated `processor.as_target_processor()` context manager
    batch["labels"] = processor.tokenizer(text).input_ids
    batch["labels_length"] = len(processor.tokenizer(text, add_special_tokens=False).input_ids)

    return batch



MAX_DURATION_IN_SECONDS = 30.0
# maximum number of 16 kHz samples allowed per example
max_input_length = 16000 * MAX_DURATION_IN_SECONDS

def filter_inputs(input_length):
    """Keep-predicate for ``Dataset.filter``.

    Returns True when ``input_length`` is positive and strictly below
    ``max_input_length``, i.e. empty inputs and inputs of 30 s or more are dropped.
    """
    return input_length > 0 and input_length < max_input_length

# Upper bound on label tokens per example.
# NOTE(review): CTC limits labels by the number of output frames rather than a fixed
# token count — confirm this 448 cap is still wanted (it is not applied in the cells shown).
max_label_length = 448

def filter_labels(labels_length):
    """Keep-predicate for ``Dataset.filter``: True when ``labels_length`` is strictly
    below ``max_label_length`` (448)."""
    return max_label_length > labels_length

In [5]:
import torch

from dataclasses import dataclass #, field
# from typing import Any, Dict, List, Optional, Union
from typing import Dict, List, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that dynamically pads audio inputs and CTC labels per batch.

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data. Its feature extractor pads
            ``input_values``; its tokenizer pads ``labels``.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding if only
              a single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Inputs and labels have different lengths and need different padding values,
        # so each is padded by its own processor component.
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.feature_extractor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        # Pad labels with the tokenizer directly rather than through the deprecated
        # `processor.as_target_processor()` context manager (same component, same call).
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            return_tensors="pt",
        )

        # replace label padding with -100 so padded positions are ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

In [6]:
# Hugging Face access token, read from the environment (load_dotenv() in the setup
# cell can populate it from a .env file). Tokens must never be hardcoded in a notebook.
# NOTE(review): a token was previously hardcoded in this cell and is exposed in the
# notebook's history — it should be revoked.
# If the variable is unset, None is passed; presumably `datasets` then falls back to a
# locally cached `huggingface-cli login` token — confirm for the installed version.
HF_TOKEN = os.getenv("HUGGING_FACE_ACCESS_TOKEN")
SAMPLING_RATE = 16_000

fleurs_train = load_dataset("google/fleurs", "yo_ng", split="train+validation", use_auth_token=HF_TOKEN)
fleurs_test = load_dataset("google/fleurs", "yo_ng", split="test", use_auth_token=HF_TOKEN)

# AfriSpeech Yoruba (currently disabled):
#afrispeech_train = load_dataset("tobiolatunji/afrispeech-200", "yoruba", split="train", use_auth_token=HF_TOKEN).rename_column('transcript', 'transcription')
#afrispeech_test = load_dataset("tobiolatunji/afrispeech-200", "yoruba", split="validation", use_auth_token=HF_TOKEN).rename_column('transcript', 'transcription')

# Common Voice calls the text column "sentence"; rename it to match FLEURS.
cv_train = load_dataset("mozilla-foundation/common_voice_12_0", "yo", split="train+validation", use_auth_token=HF_TOKEN).rename_column('sentence', 'transcription')
cv_test = load_dataset("mozilla-foundation/common_voice_12_0", "yo", split="test", use_auth_token=HF_TOKEN).rename_column('sentence', 'transcription')


Downloading builder script:   0%|          | 0.00/12.6k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/13.3k [00:00<?, ?B/s]

Downloading and preparing dataset fleurs/yo_ng to /root/.cache/huggingface/datasets/google___fleurs/yo_ng/2.0.0/af82dbec419a815084fa63ebd5d5a9f24a6e9acdf9887b9e3b8c6bbd64e0b7ac...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.83G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/316M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/692M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.52M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/238k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/536k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Dataset fleurs downloaded and prepared to /root/.cache/huggingface/datasets/google___fleurs/yo_ng/2.0.0/af82dbec419a815084fa63ebd5d5a9f24a6e9acdf9887b9e3b8c6bbd64e0b7ac. Subsequent calls will reuse this data.




Downloading builder script:   0%|          | 0.00/11.6k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/17.9k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/30.4k [00:00<?, ?B/s]

Downloading and preparing dataset afrispeech-200/yoruba to /root/.cache/huggingface/datasets/tobiolatunji___afrispeech-200/yoruba/1.0.0/0994341a78a520144afc15e99c95aacfe056e9833f4becf9efa34969c3f81c5e...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.40G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.50G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.49G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.50G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.50G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.50G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.47G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.50G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.47G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/286M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/5.00M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/83.7k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]




Reading metadata...: 0it [00:00, ?it/s][A[A[A


Reading metadata...: 14369it [00:00, 87726.02it/s]


Generating validation split: 0 examples [00:00, ? examples/s]





Reading metadata...: 361it [00:00, 71394.93it/s]


Dataset afrispeech-200 downloaded and prepared to /root/.cache/huggingface/datasets/tobiolatunji___afrispeech-200/yoruba/1.0.0/0994341a78a520144afc15e99c95aacfe056e9833f4becf9efa34969c3f81c5e. Subsequent calls will reuse this data.




Downloading builder script:   0%|          | 0.00/8.25k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/14.5k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.57k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/63.2k [00:00<?, ?B/s]

Downloading and preparing dataset common_voice_12_0/yo to /root/.cache/huggingface/datasets/mozilla-foundation___common_voice_12_0/yo/12.0.0/dd534e3c6006ee4b577c176df4a8ef23bced8b3150a3b64d2d0a7a5e3f942efb...


Downloading data:   0%|          | 0.00/12.7k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/5 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.52M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/809k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.15M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/246k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/860k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/5 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/5 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/10.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.60k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.89k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.83k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.39k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/5 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]





Reading metadata...: 39it [00:00, 24069.73it/s]


Generating validation split: 0 examples [00:00, ? examples/s]






Reading metadata...: 26it [00:00, 49636.73it/s]


Generating test split: 0 examples [00:00, ? examples/s]





Reading metadata...: 27it [00:00, 13051.31it/s]


Generating other split: 0 examples [00:00, ? examples/s]







Reading metadata...: 7it [00:00, 6073.67it/s]


Generating invalidated split: 0 examples [00:00, ? examples/s]




Reading metadata...: 20it [00:00, 11473.95it/s]


Dataset common_voice_12_0 downloaded and prepared to /root/.cache/huggingface/datasets/mozilla-foundation___common_voice_12_0/yo/12.0.0/dd534e3c6006ee4b577c176df4a8ef23bced8b3150a3b64d2d0a7a5e3f942efb. Subsequent calls will reuse this data.




In [7]:
#dataset_card = "mozilla-foundation/common_voice_11_0"
#HF_TOKEN = os.getenv("HUGGING_FACE_ACCESS_TOKEN")
#SAMPLING_RATE = 16_000

#common_voice_train = load_dataset(dataset_card, "ha", split="train+validation", use_auth_token=HF_TOKEN)
#common_voice_test = load_dataset(dataset_card, "ha", split="test", use_auth_token=HF_TOKEN)

In [8]:
SAMPLING_RATE = 16000
rem_cols_cv = ["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"]
rem_cols_fleurs = ['id', 'num_samples', 'raw_transcription', 'gender', 'lang_id', 'language', 'lang_group_id']
rem_cols_afr = ['speaker_id', 'audio_id', 'age_group', 'gender', 'accent', 'domain', 'country', 'duration']
r_col = ['id', 'num_samples', 'raw_transcription', 'gender', 'lang_id', 'language', 'lang_group_id',
         'locale', 'age', 'client_id', 'down_votes', 'segment', 'up_votes', 'speaker_id', 'audio_id', 'age_group', 'accent', 'domain', 'country', 'duration']


def _normalize_and_resample(ds, columns_to_drop):
    """Normalize sentence ends, drop metadata columns and resample audio to SAMPLING_RATE."""
    normalized = ds.map(normalize_sentence_ends, desc="preprocess dataset")
    trimmed = normalized.remove_columns(columns_to_drop)
    return trimmed.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))


fleurs_train = _normalize_and_resample(fleurs_train, rem_cols_fleurs)
fleurs_test = _normalize_and_resample(fleurs_test, rem_cols_fleurs)

# AfriSpeech (currently disabled) would use the same helper with rem_cols_afr:
# afrispeech_train = _normalize_and_resample(afrispeech_train, rem_cols_afr)
# afrispeech_test = _normalize_and_resample(afrispeech_test, rem_cols_afr)

cv_train = _normalize_and_resample(cv_train, rem_cols_cv)
cv_test = _normalize_and_resample(cv_test, rem_cols_cv)



preprocess dataset:   0%|          | 0/2717 [00:00<?, ? examples/s]

preprocess dataset:   0%|          | 0/831 [00:00<?, ? examples/s]

preprocess dataset:   0%|          | 0/65 [00:00<?, ? examples/s]

preprocess dataset:   0%|          | 0/27 [00:00<?, ? examples/s]

In [9]:
# Despite the names, these hold Common Voice + FLEURS combined (train and test respectively).
common_voice_train, common_voice_test = [
    concatenate_datasets(parts)
    for parts in ([cv_train, fleurs_train], [cv_test, fleurs_test])
]

In [10]:
# Preview rows 1-4 of the combined training set as a table.
pd.DataFrame(common_voice_train[1:5])

Unnamed: 0,path,audio,transcription
0,/root/.cache/huggingface/datasets/downloads/ex...,{'path': '/root/.cache/huggingface/datasets/do...,Kí ní wọ́n ń pè ní “Alágbe?”.
1,/root/.cache/huggingface/datasets/downloads/ex...,{'path': '/root/.cache/huggingface/datasets/do...,Iléẹjọ́ dá Adélékè sílẹ̀ lórí ẹ̀sùn màgòmág...
2,/root/.cache/huggingface/datasets/downloads/ex...,{'path': '/root/.cache/huggingface/datasets/do...,Iṣẹ́ ni àwọn eléré ẹ̀fẹ̀ máa ń ṣe.
3,/root/.cache/huggingface/datasets/downloads/ex...,{'path': '/root/.cache/huggingface/datasets/do...,Ojisẹ Ọlọ́run fún osere tíátà ni owó láti ṣiṣẹ...


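Remove special characters and lowercase the text. The cleaned text is stored in a new `sentence` column; note that the vocabulary and the training labels built below still read the uncleaned `transcription` column.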

In [11]:
import re
# Character class of punctuation / quote marks to strip. A raw string avoids the
# invalid-escape warnings the previous non-raw literal triggered (`\,`, `\?`, ...);
# the set of matched characters is unchanged.
chars_to_remove_regex = r'[,?.!\-;:"“%’ʻ”�\']'

def remove_special_characters(batch):
    """Store a punctuation-free, lowercased copy of ``batch["transcription"]``.

    Args:
        batch: a single example dict with a ``"transcription"`` string.

    Returns:
        The same ``batch`` with a new ``"sentence"`` key; ``"transcription"`` is unchanged.
    """
    batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["transcription"]).lower()
    return batch


# Add the cleaned "sentence" column to both splits.
common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)

Map:   0%|          | 0/2782 [00:00<?, ? examples/s]

Map:   0%|          | 0/858 [00:00<?, ? examples/s]

In [12]:
def extract_all_chars(batch, text_column="transcription"):
    """Collect the character set of a batched text column (for vocabulary building).

    Meant for ``Dataset.map(..., batched=True, batch_size=-1)`` so a whole split is
    processed as one batch.

    Args:
        batch: batched example dict; ``batch[text_column]`` is a list of strings.
        text_column: which text column to read. Defaults to ``"transcription"``; use the
            same column the labels are tokenized from, or label characters map to [UNK].

    Returns:
        ``{"vocab": [chars], "all_text": [joined_text]}``, each wrapped in a one-element
        list so the map yields a single row.
    """
    all_text = " ".join(batch[text_column])
    vocab = list(set(all_text))
    return {"vocab": [vocab], "all_text": [all_text]}

In [13]:
vocab_train, vocab_test = [
    split.map(
        extract_all_chars,
        batched=True,
        batch_size=-1,  # whole split in one batch -> a single output row
        keep_in_memory=True,
        remove_columns=split.column_names,
    )
    for split in (common_voice_train, common_voice_test)
]

Map:   0%|          | 0/2782 [00:00<?, ? examples/s]

Map:   0%|          | 0/858 [00:00<?, ? examples/s]

In [14]:
# union of all characters seen in the train and test splits
vocab_list = list(set(vocab_train["vocab"][0]).union(vocab_test["vocab"][0]))

In [15]:
vocab_dict = {char: idx for idx, char in enumerate(sorted(vocab_list))}
# The CTC tokenizer uses "|" as its word delimiter, so the space character is re-keyed to "|".
vocab_dict["|"] = vocab_dict.pop(" ")
# special tokens take the next free ids, in this order
for special_token in ("[UNK]", "[PAD]"):
    vocab_dict[special_token] = len(vocab_dict)

vocab_dict

{'!': 1,
 '$': 2,
 '%': 3,
 "'": 4,
 '+': 5,
 ',': 6,
 '-': 7,
 '.': 8,
 '/': 9,
 '0': 10,
 '1': 11,
 '2': 12,
 '3': 13,
 '4': 14,
 '5': 15,
 '6': 16,
 '7': 17,
 '8': 18,
 '9': 19,
 ':': 20,
 ';': 21,
 '=': 22,
 '?': 23,
 'A': 24,
 'B': 25,
 'E': 26,
 'F': 27,
 'G': 28,
 'I': 29,
 'K': 30,
 'M': 31,
 'N': 32,
 'O': 33,
 'R': 34,
 'S': 35,
 'T': 36,
 'W': 37,
 'Y': 38,
 '[': 39,
 '\\': 40,
 ']': 41,
 'a': 42,
 'b': 43,
 'c': 44,
 'd': 45,
 'e': 46,
 'f': 47,
 'g': 48,
 'h': 49,
 'i': 50,
 'j': 51,
 'k': 52,
 'l': 53,
 'm': 54,
 'n': 55,
 'o': 56,
 'p': 57,
 'q': 58,
 'r': 59,
 's': 60,
 't': 61,
 'u': 62,
 'v': 63,
 'w': 64,
 'x': 65,
 'y': 66,
 'z': 67,
 '}': 68,
 '£': 69,
 '°': 70,
 '²': 71,
 '´': 72,
 '½': 73,
 '¾': 74,
 'À': 75,
 'È': 76,
 'Ì': 77,
 'Ó': 78,
 'à': 79,
 'á': 80,
 'ç': 81,
 'è': 82,
 'é': 83,
 'ë': 84,
 'ì': 85,
 'í': 86,
 'ï': 87,
 'ò': 88,
 'ó': 89,
 'õ': 90,
 'ù': 91,
 'ú': 92,
 'ü': 93,
 'ę': 94,
 'ń': 95,
 'ǹ': 96,
 '̀': 97,
 '́': 98,
 '̂': 99,
 '̄': 100,
 '̣': 1

In [16]:
import json
# Persist the character vocabulary; the tokenizer is loaded from "./" in the next cell.
with open("vocab.json", mode="w") as fh:
    fh.write(json.dumps(vocab_dict))

In [17]:
# Character-level CTC tokenizer built from the vocab.json in the working directory.
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
# Use the shared SAMPLING_RATE constant (the rate the audio columns are cast to) so the
# feature extractor cannot silently disagree with the data.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, sampling_rate=SAMPLING_RATE, padding_value=0.0, do_normalize=True, return_attention_mask=True
)

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [18]:
# Inspect the first training example's decoded audio: path, waveform array and the
# sampling rate it was cast to.
common_voice_train[0]["audio"]

{'path': '/root/.cache/huggingface/datasets/downloads/extracted/f58cc55a7ff71370b30d424fa1398dff72b601ee9dc12eddb26d33d2e54d93af/common_voice_yo_36518280.mp3',
 'array': array([-8.18545232e-12,  6.36646291e-12,  6.36646291e-12, ...,
         5.02646435e-06,  1.75111927e-06,  3.75934178e-06]),
 'sampling_rate': 16000}

Play random audio:

In [19]:
import IPython.display as ipd
import numpy as np
import random

# Pick one training example at random, print its cleaned text and play its audio.
rand_int = random.randrange(len(common_voice_train))
random_example = common_voice_train[rand_int]
print(random_example["sentence"])
ipd.Audio(data=random_example["audio"]["array"], autoplay=True, rate=SAMPLING_RATE)

ìjàm̀bá ojú òfúrufú ló wọ́pọ̀ ní ilẹ̀ iran èyí tó ní ìtàn pípẹ́ fún àìse àbòjútó tó péye fún ìlò ará ìlú àti ológun


In [20]:
# Replace every original column with model inputs (input_values/labels and their lengths).
prepare_fn = partial(prepare_model_inputs, processor=processor)
common_voice_train = common_voice_train.map(prepare_fn, remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(prepare_fn, remove_columns=common_voice_test.column_names)

Map:   0%|          | 0/2782 [00:00<?, ? examples/s]



Map:   0%|          | 0/858 [00:00<?, ? examples/s]

In [21]:
# TODO: here is a good place to split / truncate long sequences

#max_input_length_in_sec = 5.0
#common_voice_train = common_voice_train.filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])

## Training

### Prepare

In [22]:
def compute_metrics(pred, wer_metric):
    """Compute the word error rate for one Trainer evaluation pass.

    Args:
        pred: the Trainer's prediction object; ``pred.predictions`` holds logits
            (argmax over the last axis gives token ids) and ``pred.label_ids`` uses
            -100 for ignored (padded) positions, as produced by the data collator.
        wer_metric: a metric exposing ``compute(predictions=..., references=...)``.

    Returns:
        ``{"wer": <value returned by wer_metric.compute>}``.

    Decoding uses the module-level ``processor``.
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # Map ignored positions back to the pad token so they decode cleanly, building a
    # new array instead of mutating the one owned by the Trainer.
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics: labels are not CTC
    # output, so repeated characters must not be merged
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    return {"wer": wer_metric.compute(predictions=pred_str, references=label_str)}

In [23]:
# padding defaults to True: pad to the longest sequence in each batch
data_collator = DataCollatorCTCWithPadding(processor=processor)

In [24]:
# Word error rate metric; bound into compute_metrics via functools.partial in the Trainer cell.
wer_metric = evaluate.load("wer")

Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

In [30]:

from transformers import set_seed

# Seed before building the model: the CTC head (lm_head) is newly initialised
# (see the load warning below), so without a seed its initial weights differ per run.
# 42 matches TrainingArguments' default seed.
SEED = 42
set_seed(SEED)

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-xls-r-300m",
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.0,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)

# keep the convolutional feature encoder frozen; only the transformer and head are fine-tuned
model.freeze_feature_encoder()

Some weights of the model checkpoint at facebook/wav2vec2-xls-r-300m were not used when initializing Wav2Vec2ForCTC: ['quantizer.codevectors', 'project_hid.weight', 'project_q.weight', 'project_q.bias', 'project_hid.bias', 'quantizer.weight_proj.weight', 'quantizer.weight_proj.bias']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-xls-r-300m and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it 

In [31]:
# Make sure the TrainingArguments output_dir ("output/models") exists.
!mkdir -p output/models

### Run

In [32]:
import wandb

wandb_run_config = {"project": "FEM", "job_type": "training", "name": "wav2vec2-xls-r-300m"}

wandb.login()  # relies on the WANDB_API_KEY environment variable
run = wandb.init(**wandb_run_config)



VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

In [33]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
  output_dir="output/models",
  group_by_length=True,
  # group_by_length reads precomputed lengths from this column; prepare_model_inputs
  # stores them as "input_length". The default ("length") does not exist in this
  # dataset, so lengths would otherwise be recomputed from the inputs at train startup.
  length_column_name="input_length",
  per_device_train_batch_size=16,
  gradient_accumulation_steps=2,
  evaluation_strategy="steps",
  num_train_epochs=30,
  gradient_checkpointing=True,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=100,
  learning_rate=3e-4,
  warmup_steps=500,
  save_total_limit=2,
  push_to_hub=False,
  report_to="wandb"
)

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=partial(compute_metrics, wer_metric=wer_metric),
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    # saved alongside checkpoints; NOTE(review): the CTC tokenizer/vocab is not, so
    # checkpoints alone cannot rebuild the full processor
    tokenizer=processor.feature_extractor,
)

In [34]:
# Run training; the returned TrainOutput (global step, training loss, runtime metrics) is displayed.
trainer.train()

Step,Training Loss,Validation Loss,Wer
100,16.7002,7.534271,1.0
200,4.7004,3.454164,1.0
300,3.3521,3.323382,1.0
400,3.1359,2.115851,1.0
500,1.3154,1.035044,0.794148
600,0.9025,0.910602,0.749449
700,0.745,0.804528,0.670441




Step,Training Loss,Validation Loss,Wer
100,16.7002,7.534271,1.0
200,4.7004,3.454164,1.0
300,3.3521,3.323382,1.0
400,3.1359,2.115851,1.0
500,1.3154,1.035044,0.794148
600,0.9025,0.910602,0.749449
700,0.745,0.804528,0.670441
800,0.572,0.841581,0.656913
900,0.4625,0.843035,0.629809
1000,0.3554,0.901169,0.617703




TrainOutput(global_step=2610, training_loss=1.2959767858758284, metrics={'train_runtime': 17416.3482, 'train_samples_per_second': 4.792, 'train_steps_per_second': 0.15, 'total_flos': 4.036316099401732e+19, 'train_loss': 1.2959767858758284, 'epoch': 30.0})

In [35]:
# Close the wandb run started with wandb.init(); its history and summary are shown below.
wandb.finish()

0,1
eval/loss,█▄▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂
eval/runtime,▂▆▄▃▃▅▃▃▃▂▁█▂▂▃▇▂▁▁▂▃▂▂▄▃▃
eval/samples_per_second,▇▃▄▆▆▄▆▆▆▇█▁▇▇▆▂▇██▇▆▇▇▅▅▆
eval/steps_per_second,▇▃▅▆▆▄▆▆▆▇▇▁▇▇▆▂▇██▇▆▇▇▅▅▆
eval/wer,████▅▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████
train/learning_rate,▂▄▅▇██▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▁▁
train/loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/total_flos,▁

0,1
eval/loss,1.33663
eval/runtime,129.2102
eval/samples_per_second,6.64
eval/steps_per_second,0.836
eval/wer,0.57864
train/epoch,30.0
train/global_step,2610.0
train/learning_rate,0.0
train/loss,0.0359
train/total_flos,4.036316099401732e+19
