In [None]:
import os
import numpy as np
import pandas as pd
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import soundfile as sf
import random
import re
import IPython.display as ipd
from IPython.display import display, HTML
from datasets import load_dataset, load_metric, ClassLabel
import json

from constants import SAMPLE_RATE


# Torch device used for wav2vec2 feature extraction.
device = "cpu"
# Punctuation stripped from transcripts before lower-casing. A raw string is used
# so the backslashes reach the regex engine instead of being parsed as (invalid)
# Python string escapes, which newer Python versions warn about.
chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"]'

file_phoneme_dict = "../model/phoneme_dict.json"
file_index = "../data/timit/index.txt"

# Load TIMIT through Hugging Face `datasets` and drop metadata this notebook never
# reads. "phonetic_detail" is intentionally kept: later cells take its "utterance"
# field as the phoneme target.
timit = load_dataset("timit_asr").remove_columns(
    ["word_detail", "dialect_region", "id", "sentence_type", "speaker_id"]
)


In [None]:
def show_random_elements(dataset, num_examples=10):
    """
    Display ``num_examples`` distinct, randomly chosen rows of ``dataset`` as an HTML table.

    Parameters
    ----------
    dataset : sized, list-indexable collection
        E.g. a Hugging Face ``Dataset``. ``dataset[list_of_ints]`` must return
        something ``pd.DataFrame`` accepts (such as a dict of columns).
    num_examples : int, default 10
        Number of distinct rows to show.

    Returns
    -------
    None
        The table is rendered through IPython ``display`` as a side effect.

    Raises
    ------
    AssertionError
        If ``num_examples`` exceeds ``len(dataset)``.
    ValueError
        If ``num_examples`` is negative (raised by ``random.sample``).
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."

    # Sample distinct indices in one call instead of rejection sampling against a
    # growing list (quadratic, and slow when num_examples approaches len(dataset)).
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))


# Preview 20 raw training rows; the audio path column is dropped to keep the table readable.
show_random_elements(num_examples=20, dataset=timit["train"].remove_columns(["file"]))


In [None]:
def remove_special_characters(batch):
    """
    Normalise the transcript stored in ``batch["text"]``.

    Every character matched by the module-level ``chars_to_ignore_regex`` is
    deleted and the remaining text is lower-cased.

    Parameters
    ----------
    batch : dict-like
        One example; must hold a string under ``"text"``.

    Returns
    -------
    The same ``batch`` object, with ``"text"`` overwritten by the cleaned string.
    """
    stripped = re.sub(chars_to_ignore_regex, "", batch["text"])
    batch["text"] = stripped.lower()
    return batch


# Strip punctuation / lower-case every transcript, then eyeball the result.
timit = timit.map(function=remove_special_characters)
show_random_elements(timit["train"].remove_columns(["file"]), num_examples=10)


In [None]:
def get_target_phn(batch):
    """
    Expose the phoneme sequence of one example as ``batch["target_phoneme"]``.

    The value is taken as-is from ``batch["phonetic_detail"]["utterance"]``;
    symbols are kept as strings (no mapping to integer ids is applied here).

    Parameters
    ----------
    batch : dict-like
        One example containing a ``"phonetic_detail"`` mapping with an
        ``"utterance"`` entry.

    Returns
    -------
    The same ``batch`` object with ``"target_phoneme"`` added.
    """
    phonetic = batch["phonetic_detail"]
    batch["target_phoneme"] = phonetic["utterance"]
    return batch


# Attach the raw phoneme sequence to every example in every split. Phonemes stay
# as strings here; the symbol -> id mapping (phoneme_dict.json) is only written
# at the end of this notebook, so it is not loaded at this point.
timit = timit.map(function=get_target_phn)


In [None]:
def speech_file_to_array_fn(batch):
    """
    Decode the audio file referenced by ``batch["file"]`` and add model-ready fields.

    Fields added, in this order:
      * ``"speech"``        - sample array returned by ``soundfile.read``
      * ``"sampling_rate"`` - sample rate reported by ``soundfile.read``
      * ``"target_text"``   - copy of ``batch["text"]``
      * ``"phoneme"``       - ``batch["phonetic_detail"]["utterance"]``

    Returns
    -------
    The same (mutated) ``batch`` object.
    """
    samples, rate = sf.read(batch["file"])
    batch["speech"], batch["sampling_rate"] = samples, rate
    batch["target_text"], batch["phoneme"] = batch["text"], batch["phonetic_detail"]["utterance"]
    return batch


# Replace every existing column with the decoded-audio fields, using 4 worker processes.
timit = timit.map(
    speech_file_to_array_fn,
    num_proc=4,
    remove_columns=timit.column_names["train"],
)


In [None]:
# Listen to one random training utterance. randrange excludes its upper bound;
# random.randint(0, len(...)) is inclusive and could index one past the end.
rand_int = random.randrange(len(timit["train"]))
random_example = timit["train"][rand_int]
# Play back at the rate the file was actually decoded with rather than a hard-coded value.
ipd.Audio(data=np.asarray(random_example["speech"]), autoplay=True, rate=random_example["sampling_rate"])


In [None]:
# Inspect one random training example. randrange excludes its upper bound;
# random.randint(0, len(...)) is inclusive and could index one past the end.
rand_int = random.randrange(len(timit["train"]))
# Fetch the row once instead of re-reading it for every field.
random_example = timit["train"][rand_int]

print("Target text:", random_example["target_text"])
print("Input array shape:", np.asarray(random_example["speech"]).shape)
print("Sampling rate:", random_example["sampling_rate"])


In [None]:
# Load the preprocessor and the encoder from one checkpoint name so the two cannot drift apart.
# NOTE(review): only the audio feature-extraction part of the processor is used
# below; if this checkpoint ships no tokenizer files, Wav2Vec2FeatureExtractor may
# be the better fit — verify.
pretrained_name = "facebook/wav2vec2-base"
processor = Wav2Vec2Processor.from_pretrained(pretrained_name)
model = Wav2Vec2Model.from_pretrained(pretrained_name)


# Generate wav2vec2 hidden-state representations for each sample

In [None]:
model.to(device)
# Inference only. from_pretrained already returns an eval-mode model; this makes it explicit.
model.eval()

for split in ["train", "test"]:
    path_sample = os.path.join("../data/timit", split)
    # np.save does not create missing directories.
    os.makedirs(path_sample, exist_ok=True)

    dataset_split = timit[split]
    num_sample = len(dataset_split)
    print("number of samples in {} set: {}".format(split, num_sample))

    for i in range(num_sample):
        # Read the row once; each timit[split][i] lookup materialises the full
        # example, including its speech array.
        example = dataset_split[i]
        input_values = processor(example["speech"],
                                 sampling_rate=example["sampling_rate"],
                                 return_tensors="pt").input_values.to(device)

        with torch.no_grad():
            hidden = model(input_values).last_hidden_state
            # One utterance per forward pass: drop the leading batch axis (same
            # result as the previous reshape to the last two dims). Move to CPU
            # before .numpy() so this also works when `device` is a GPU.
            last_hidden_state = hidden[0].cpu().numpy()

        file_sample = os.path.join(path_sample, "{}.npy".format(str(i).zfill(4)))
        np.save(file_sample, last_hidden_state)


In [None]:
# Gather the per-utterance phoneme sequences of each split.
phonemes_by_split = {}
for split in ["train", "test"]:
    num_sample = len(timit[split])
    print("number of samples in {} set: {}".format(split, num_sample))
    # Column access reads only "phoneme"; indexing row by row (timit[split][i])
    # would also materialise every "speech" array just to discard it.
    phonemes_by_split[split] = list(timit[split]["phoneme"])

list_phn_train = phonemes_by_split["train"]
list_phn_test = phonemes_by_split["test"]


In [None]:
def _index_frame(split, phonemes):
    """
    Build the index rows for one split.

    Columns: ``split`` (constant), ``sample_index`` (0..n-1, matching the
    ``{index:04d}.npy`` feature files), ``target_text`` and ``target_phoneme``.
    """
    # Fetch the transcript column once and reuse it for both the length and the values.
    texts = list(timit[split]["target_text"])
    return pd.DataFrame({
        "split": split,
        "sample_index": range(len(texts)),
        "target_text": texts,
        "target_phoneme": phonemes,
    })


df_index_train = _index_frame("train", list_phn_train)
df_index_test = _index_frame("test", list_phn_test)

df_index = pd.concat([df_index_train, df_index_test]).reset_index(drop=True)

# Ensure the output directory exists before writing.
index_dir = os.path.dirname(file_index)
if index_dir:
    os.makedirs(index_dir, exist_ok=True)
# NOTE: the list-valued target_phoneme column is written as its Python repr;
# readers must parse it back (e.g. with ast.literal_eval).
df_index.to_csv(file_index, sep="\t", index=False)


In [None]:
# The 61 TIMIT phone symbols in sorted order; a symbol's list position is its integer id.
list_phoneme = [
    'aa', 'ae', 'ah', 'ao', 'aw', 'ax', 'ax-h', 'axr', 'ay',
    'b', 'bcl', 'ch', 'd', 'dcl', 'dh', 'dx',
    'eh', 'el', 'em', 'en', 'eng', 'epi', 'er', 'ey',
    'f', 'g', 'gcl', 'h#', 'hh', 'hv',
    'ih', 'ix', 'iy', 'jh', 'k', 'kcl', 'l', 'm', 'n', 'ng', 'nx',
    'ow', 'oy', 'p', 'pau', 'pcl', 'q', 'r', 's', 'sh',
    't', 'tcl', 'th', 'uh', 'uw', 'ux', 'v', 'w', 'y', 'z', 'zh',
]
# A duplicate symbol would be silently collapsed by the dict, leaving ids
# non-contiguous and letting "[PAD]" reuse an id already assigned to a phone.
assert len(set(list_phoneme)) == len(list_phoneme), "duplicate symbol in list_phoneme"
dict_phoneme = {phoneme: idx for idx, phoneme in enumerate(list_phoneme)}

# The padding token takes the next free id. No "[UNK]" token is added.
dict_phoneme["[PAD]"] = len(dict_phoneme)

# Persist the symbol -> id mapping to the path configured at the top of the
# notebook (file_phoneme_dict), creating the target directory if needed.
phoneme_dict_dir = os.path.dirname(file_phoneme_dict)
if phoneme_dict_dir:
    os.makedirs(phoneme_dict_dir, exist_ok=True)
with open(file_phoneme_dict, "w") as f_phoneme_dict:
    json.dump(dict_phoneme, f_phoneme_dict, indent=4)
