In [4]:
import torch
import torchaudio
from torch.utils.data import DataLoader

from transformers import (
    Speech2TextProcessor, 
    Speech2TextTokenizer, 
    Speech2TextForConditionalGeneration
)

from typing import Tuple, Dict

In [5]:
processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")
tokenizer = Speech2TextTokenizer.from_pretrained("facebook/s2t-small-librispeech-asr")
model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")

In [19]:
def preprocessing(sample: Tuple[torch.Tensor, int, str, int, int, int]) -> Dict[str, torch.Tensor]:
    batch = dict()
    batch["input_values"] = processor(
        sample[0][0], sampling_rate=16_000, padding=True, return_tensors="pt"
    ).input_features[0]
    
    with processor.as_target_processor():
        batch["labels"] = processor(sample[2], padding=True, return_tensors="pt").input_ids
    
    return batch

In [20]:
dataset = torchaudio.datasets.LIBRISPEECH(root="", url='dev-clean')
dataset = dataset.map(preprocessing)

In [21]:
dataset[0]

{'input_values': tensor([[-1.7439, -1.5824, -1.3852,  ..., -1.4764, -1.3188, -1.5373],
         [-1.7372, -1.5597, -1.5623,  ..., -1.3479, -1.4273, -1.6317],
         [-1.4334, -1.3532, -1.2806,  ..., -1.5108, -1.3358, -1.4236],
         ...,
         [-1.6064, -1.5294, -1.5514,  ..., -1.0002, -1.0575, -1.0322],
         [-1.7444, -1.6128, -1.4223,  ..., -1.0696, -0.9122, -1.0744],
         [-1.3830, -1.3087, -1.3786,  ..., -1.0823, -1.0664, -1.1232]]),
 'labels': tensor([[ 129, 8053,   66,   30,    4, 5878,    8,    4, 1080, 3353,    5,    6,
            52,   60,  534,    9, 1524,   20, 5517,    2]])}

In [29]:
class DataCollator(object):
    def __init__(self, processor: Speech2TextProcessor, padding=True, 
                 max_length=None, max_length_labels=None, 
                 pad_to_multiple_of=None, pad_to_multiple_of_labels=None):
        self.processor = processor
        self.padding = padding
        self.max_length = max_length
        self.max_length_labels = max_length_labels
        self.pad_to_multiple_of = pad_to_multiple_of
        self.pad_to_multiple_of_labels = pad_to_multiple_of_labels
    
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        
        batch = self.processor.pad(
            input_features, 
            padding=self.padding, 
            max_length=self.max_length, 
            pad_to_multiple_of=self.pad_to_multiple_of, 
            return_tensors="pt"
        )
        
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features, 
                padding=self.padding, 
                max_length=self.max_length, 
                pad_to_multiple_of=self.pad_to_multiple_of_labels, 
                return_tensors="pt"
            )
        
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        
        batch["labels"] = labels
        
        return batch

## Wav2Vec

In [1]:
from huggingface_hub import notebook_login

In [2]:
notebook_login()

VBox(children=(HTML(value='<center>\n<img src=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
from datasets import load_dataset, load_metric

In [9]:
timit = load_dataset("timit_asr")

Reusing dataset timit_asr (/home/tsimur/.cache/huggingface/datasets/timit_asr/clean/2.0.1/b11b576ddcccbcefa7c9f0c4e6c2a43756f3033adffe0fb686aa61043d0450ad)


  0%|          | 0/2 [00:00<?, ?it/s]

In [10]:
timit = timit.remove_columns(["phonetic_detail", "word_detail", "dialect_region", "id", "sentence_type", "speaker_id"])
timit

DatasetDict({
    train: Dataset({
        features: ['file', 'audio', 'text'],
        num_rows: 4620
    })
    test: Dataset({
        features: ['file', 'audio', 'text'],
        num_rows: 1680
    })
})

In [11]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

In [12]:
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = list()
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

show_random_elements(timit["train"].remove_columns(["file", "audio"]))

Unnamed: 0,text
0,Baseball's no cinch.
1,"For roast, insert meat thermometer diagonally so it does not rest on bone."
2,Where're you takin' me?
3,Chocolate and roses never fail as a romantic gift.
4,"Although always alone, we survive."
5,Those answers will be straightforward if you think them through carefully first.
6,Challenge each general's intelligence.
7,We saw eight tiny icicles below our roof.
8,Thomas thinks a larger clamp solves the problem.
9,"The larvae, kept warm by the queen, are full grown in about ten days."


In [13]:
import re
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"]'

def remove_special_characters(batch):
    batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower()
    return batch

timit = timit.map(remove_special_characters)

0ex [00:00, ?ex/s]

0ex [00:00, ?ex/s]

In [14]:
show_random_elements(timit["train"].remove_columns(["file", "audio"]))

Unnamed: 0,text
0,an adult male baboon's teeth are not suitable for eating shellfish
1,vietnamese cuisine is exquisite
2,she had your dark suit in greasy wash water all year
3,cut a small corner off each edge
4,keep your seats boys i just want to put some finishing touches on this thing
5,rockandroll music has a great rhythm
6,she had your dark suit in greasy wash water all year
7,we experience distress and frustration obtaining our degrees
8,don't ask me to carry an oily rag like that
9,please sing just the club theme


In [15]:
def extract_all_chars(batch):
    all_text = " ".join(batch['text'])
    vocab = list(set(all_text))
    
    return {
        'vocab': [vocab], 
        'all_text': [all_text]
    }

In [16]:
vocabs = timit.map(extract_all_chars, batched=True, batch_size=-1, 
                   keep_in_memory=True, remove_columns=timit.column_names['train'])

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [19]:
vocabs

DatasetDict({
    train: Dataset({
        features: ['vocab', 'all_text'],
        num_rows: 1
    })
    test: Dataset({
        features: ['vocab', 'all_text'],
        num_rows: 1
    })
})

In [20]:
vocab_list = list(set(vocabs["train"]["vocab"][0]) | set(vocabs["test"]["vocab"][0]))

vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict

{'j': 0,
 'y': 1,
 'r': 2,
 't': 3,
 "'": 4,
 'k': 5,
 'z': 6,
 'i': 7,
 'e': 8,
 'p': 9,
 'q': 10,
 's': 11,
 'o': 12,
 'l': 13,
 'w': 14,
 'u': 15,
 ' ': 16,
 'c': 17,
 'x': 18,
 'b': 19,
 'n': 20,
 'f': 21,
 'd': 22,
 'm': 23,
 'v': 24,
 'h': 25,
 'g': 26,
 'a': 27}