In [None]:
import json
import re

import numpy as np
import pandas as pd

# libraries that connect the custom dataset to the training pipeline
import torch
from torch.utils.data import DataLoader

# wav2vec 2.0 specific
from datasets import load_metric
from transformers import (Wav2Vec2CTCTokenizer, 
                          Wav2Vec2FeatureExtractor, 
                          Wav2Vec2Processor, 
                          Wav2Vec2ForCTC,
                          TrainingArguments,
                          Trainer)

# project-local: libraries that prepare the custom datasets
from data_processor.dataset import MLS, CommonVoice
from data_processor.cleaner import CreateTidyDataset

from core.utils import DataCollatorCTCWithPadding

In [None]:
# MLS Portuguese corpus: one directory per split (train / test / dev).
mls = MLS(
    data_train_dir="data/mls_portuguese/train",
    data_test_dir="data/mls_portuguese/test",
    data_dev_dir="data/mls_portuguese/dev",
)

# Common Voice 7.0 (2021-07-21 release), Portuguese subset.
cov = CommonVoice(
    main_path="data/common_voice/cv-corpus-7.0-2021-07-21/pt",
)

# Each entry pairs a corpus object with a boolean flag.
# NOTE(review): the meaning of the flag is defined by CreateTidyDataset and is
# not visible in this notebook -- confirm in data_processor.cleaner.
databases = [
    (cov, True),
    (mls, True),
]
tidy_dataset = CreateTidyDataset(databases)

In [None]:
# Run the builder's audio step before parsing the datasets.
# NOTE(review): judging by the name this converts the corpora's audio files;
# the target format / sample rate and output location are not visible here --
# confirm in data_processor.cleaner.
tidy_dataset.converter_audio()

In [None]:
# Unpack the two results as the train and test tables. Later cells use them as
# pandas DataFrames and read their "file" and "text" columns.
train_df, test_df = tidy_dataset.parse_datasets()

In [None]:
# Build the character-level vocabulary from the training transcripts and
# write it to vocab.json, the file the CTC tokenizer is created from.

# Punctuation (and "ü") to blank out before collecting characters. Raw string:
# the same character class as before, without invalid string escape sequences.
regex = r'[\,\?\.\!\-\;\:\"\'\“\&\«\´\»\”\ü]'

# Every distinct character of the lower-cased, punctuation-stripped corpus.
corpus_text = train_df["text"].str.cat(sep='').lower()
vocab = set(re.sub(regex, ' ', corpus_text))

# The tokenizer is instantiated with word_delimiter_token="|", so the word
# boundary has to be stored under "|" rather than as a literal space.
vocab.discard(' ')
vocab.update({"|", "[UNK]", "[PAD]"})

# Sort before numbering: iterating a set of strings is not stable across
# interpreter runs, which would give every token a different id on each
# restart and invalidate previously saved checkpoints.
vocab_dict = {token: index for index, token in enumerate(sorted(vocab))}

with open('vocab.json', 'w', encoding='utf-8') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In [None]:
class Wav2vecDataset(torch.utils.data.Dataset):
    """Map-style dataset exposing the "file" and "text" columns of a DataFrame.

    Args:
        df: table with at least a "file" and a "text" column, one row per
            sample.
    """

    def __init__(self, df:pd.DataFrame):
        self.df = df
        self.max_size = len(self.df)

    def __getitem__(self, idx):
        """Return row number ``idx`` as ``{"file": ..., "text": ...}``.

        Rows are addressed by position, not by index label, so the dataset
        works with frames whose index is not 0..n-1 (e.g. after filtering or
        concatenation). An out-of-range ``idx`` raises ``IndexError``, which
        is what Python's sequence iteration protocol expects.
        """
        return self.df.iloc[idx][["file", "text"]].to_dict()

    def __len__(self):
        """Number of rows in the wrapped DataFrame at construction time."""
        return self.max_size

    
# Wrap the parsed tables so they can be consumed by PyTorch / the Trainer.
train_dataset = Wav2vecDataset(train_df)
test_dataset = Wav2vecDataset(test_df)

# The loaders must be given the Dataset wrappers: indexing a raw DataFrame
# with an integer performs a column lookup, not a row lookup.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# Evaluation data does not need shuffling; a fixed order keeps runs comparable.
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [None]:
# CTC tokenizer built from the vocab.json file written by this notebook.
tokenizer = Wav2Vec2CTCTokenizer(
    "./vocab.json",
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)

# It is important to know the sampling rate the pretrained embeddings were
# trained with, and to use that same value here.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=False,
)

# Bundle both halves so one object handles audio features and text labels.
processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
)

# Batch collator (project helper from core.utils) and the word-error-rate metric.
data_collator = DataCollatorCTCWithPadding(
    processor=processor,
    padding=True,
)
wer_metric = load_metric("wer")

In [None]:
def compute_metrics(pred):
    """Compute the word error rate for one evaluation pass.

    Args:
        pred: prediction object exposing ``predictions`` (per-frame logits)
            and ``label_ids`` (reference token ids).

    Returns:
        dict: ``{"wer": <value returned by wer_metric.compute>}``.
    """
    # Greedy decoding: pick the highest-scoring token at every frame.
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # -100 is not a valid token id (presumably the ignore index written by the
    # data collator -- confirm in core.utils), so map it to the pad token
    # before decoding. np.where builds a new array and leaves the caller's
    # label_ids untouched.
    label_ids = np.where(
        pred.label_ids == -100,
        processor.tokenizer.pad_token_id,
        pred.label_ids,
    )

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [None]:
# Load the pretrained wav2vec 2.0 base encoder with a CTC head on top.
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base",
    gradient_checkpointing=True,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    # Size the output layer to this notebook's tokenizer. Without it the head
    # keeps the size from the checkpoint's config, which need not match
    # vocab.json and leaves some label ids without an output unit.
    vocab_size=len(processor.tokenizer),
)

In [None]:
# Hyper-parameters for the fine-tuning run.
# On Colab, point output_dir at a mounted Drive folder instead, e.g.
# "/content/gdrive/MyDrive/wav2vec2-base-timit-demo".
# NOTE(review): group_by_length needs per-sample lengths from the training
# dataset; confirm that Wav2vecDataset items provide what the Trainer expects.
training_args = TrainingArguments(
  output_dir="./wav2vec2-base-timit-demo",
  group_by_length=True,
  per_device_train_batch_size=8,
  evaluation_strategy="steps",
  num_train_epochs=30,
  # Mixed precision needs a CUDA device; fall back to full precision
  # otherwise instead of failing when the arguments are built.
  fp16=torch.cuda.is_available(),
  # Save, evaluate and log on the same 500-step cadence.
  save_steps=500,
  eval_steps=500,
  logging_steps=500,
  learning_rate=1e-4,
  weight_decay=0.005,
  warmup_steps=1000,
  # Keep only the two most recent checkpoints on disk.
  save_total_limit=2,)

In [None]:
# Wire the model, data, collator and metric into a single Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # The feature extractor is what gets passed in the tokenizer slot.
    tokenizer=processor.feature_extractor,
)

In [None]:
# Start fine-tuning with the Trainer built in the previous cell. The training
# arguments request a checkpoint every 500 steps under ./wav2vec2-base-timit-demo.
trainer.train()