In [1]:
import os
import glob
import pandas as pd
import numpy as np
import librosa
import librosa.display
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import logging
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import Dataset

In [2]:
class TIMITDataset(Dataset):
    def __init__(self, main_dir, mode: str):
        super(TIMITDataset, self).__init__()
        paths = glob.glob(f"{main_dir}\*\*\*.wav")
        self.sample_paths = [os.path.splitext(path)[0] 
                for path in paths if not path.endswith(".WAV.wav")]
        self.dictionary = self.phoneme_dict()
        self.mode = mode

    def dataframe(self):
        def sentence_type(sample_name):
            if sample_name.startswith("SA"):
                return "dialect"
            elif sample_name.startswith("SX"):
                return "compact"
            else:
                return "diverse"   

        sample_dict = {}
        for i, sample in tqdm(
                enumerate(self.sample_paths), total=len(self.sample_paths),
                desc="Generating"):

            _, _, dialect, speaker_id, sample_name = sample.split("\\")
            S, mfcc, frames, phonemes = self.process_file(sample)
            
            if self.mode == "full":
                sample_type = sentence_type(sample_name)
                sample_rate = librosa.get_samplerate(sample + ".wav")
                with open(sample + ".txt", 'r') as f:
                    context = f.read().split()
                    n_frames, text = context[1], " ".join(context[2::])
                sample_dict[i] = (sample_name, dialect, speaker_id,
                    sample_type, sample_rate, n_frames, text, sample, S, 
                    mfcc, frames, phonemes)
                columns = ["sample_name", "dialect", "speaker_id", 
                    "sample_type", "sample_rate", "n_frames", "text", 
                    "sample_path", "spec_array", "mfcc_array", "frame_array", 
                    "phoneme_array"]
            elif self.mode == "partial":
                sample_dict[i] = (sample_name, speaker_id, S, mfcc, frames, phonemes)
                columns = ["sample_name", "speaker_id", "spec_array", "mfcc_array",
                    "frame_array", "phoneme_array"]
            else:
                logging.error("Invalid mode, only full or partial allowed.")
                break

        return pd.DataFrame.from_dict(
            sample_dict, orient="index", columns=columns)


    def phoneme_dict(self):
        phonemes = set()
        for sample in self.sample_paths:
            with open(sample + ".phn", "r") as f:
                for line in f.readlines():
                    phonemes.add(line.split()[-1])
        return {phoneme: i for i, phoneme in enumerate(phonemes)}


    def spectral_features(self, path, type):
        y, sr = librosa.load(path, librosa.get_samplerate(path))
        S = librosa.feature.melspectrogram(
                y=y, sr=sr, n_mels=128, fmax=8000)
        if type=="mel":
            S = librosa.power_to_db(S, ref=np.max)
        elif type=="mfcc":
            S = librosa.feature.mfcc(
                S=S, n_mfcc=128).T
        else:
            logging.error("Invalid type input, only mel or mfcc allowed.")
        return torch.from_numpy(S)


    def process_file(self, path):
        wav_path = path + ".wav"
        phn_path = path + ".phn"
        with open(phn_path, "r") as f:
            frames, phonemes = [0], []
            for line in f.readlines():
                _, time, phoneme = line.split(" ")
                frames.append(int(time))
                phonemes.append(phoneme.strip('\n'))
        frames = torch.Tensor(frames)
        phoneme = np.array(phonemes)   

        mfcc = self.spectral_features(wav_path, type="mel")
        S = self.spectral_features(wav_path, type="mfcc")
        return S, mfcc, frames, phonemes


    def split_dataset(self):
        data = self.dataframe()
        train, test = train_test_split(
            data, test_size=0.25, shuffle=True,
            stratify=data['speaker_id'], random_state=42)
        return train, test

dataset = TIMITDataset(main_dir="TIMIT-dataset\data", mode="partial")
train, test = dataset.split_dataset()
train

Generating:   0%|          | 0/6300 [00:00<?, ?it/s]

Unnamed: 0,sample_name,speaker_id,spec_array,mfcc_array,frame_array,phoneme_array
3687,SX319,FJCS0,"[[tensor(1.2060e-05), tensor(8.8126e-06), tens...","[[tensor(-65.7270), tensor(-65.1304), tensor(-...","[tensor(0.), tensor(4360.), tensor(5320.), ten...","[h#, ah, bcl, b, ih, gcl, g, ow, tcl, q, ay, d..."
6171,SA2,MKDD0,"[[tensor(7.4580e-06), tensor(6.2724e-06), tens...","[[tensor(-50.0056), tensor(-52.0300), tensor(-...","[tensor(0.), tensor(2180.), tensor(2450.), ten...","[h#, d, ow, n, q, ae, s, kcl, m, iy, tcl, t, i..."
905,SX158,MDEM0,"[[tensor(0.0003), tensor(0.0002), tensor(-0.00...","[[tensor(-74.3413), tensor(-76.5413), tensor(-...","[tensor(0.), tensor(2600.), tensor(2799.), ten...","[h#, dh, ix, dcl, d, r, ah, n, kcl, k, er, dcl..."
334,SI2327,MPSW0,"[[tensor(1.0861e-05), tensor(1.1324e-05), tens...","[[tensor(-55.5831), tensor(-57.2442), tensor(-...","[tensor(0.), tensor(1949.), tensor(3270.), ten...","[h#, em, s, aa, r, ix, v, ix, dx, ae, sh, epi,..."
397,SX309,MRSO0,"[[tensor(1.7078e-05), tensor(1.3041e-05), tens...","[[tensor(-52.4337), tensor(-51.8094), tensor(-...","[tensor(0.), tensor(1960.), tensor(2597.), ten...","[h#, dh, ax, pcl, p, r, uw, f, epi, dh, eh, q,..."
...,...,...,...,...,...,...
3991,SA2,MDHL0,"[[tensor(0.0003), tensor(9.2377e-05), tensor(-...","[[tensor(-64.3449), tensor(-64.8335), tensor(-...","[tensor(0.), tensor(2170.), tensor(2750.), ten...","[h#, d, ow, n, q, ae, s, kcl, m, iy, tcl, t, i..."
1352,SI1298,MRJM1,"[[tensor(3.3208e-06), tensor(1.6774e-06), tens...","[[tensor(-58.9032), tensor(-58.2838), tensor(-...","[tensor(0.), tensor(1086.), tensor(1437.), ten...","[h#, b, ax-h, tcl, t, w, iy, n, m, iy, dx, iy,..."
5689,SX43,MNLS0,"[[tensor(3.2796e-05), tensor(2.5904e-05), tens...","[[tensor(-57.8896), tensor(-59.7843), tensor(-...","[tensor(0.), tensor(2321.), tensor(3103.), ten...","[h#, q, eh, l, dcl, d, axr, l, iy, pcl, p, iy,..."
1779,SX393,MAPV0,"[[tensor(1.0439e-05), tensor(7.2835e-06), tens...","[[tensor(-66.2206), tensor(-67.2553), tensor(-...","[tensor(0.), tensor(2310.), tensor(4840.), ten...","[h#, sh, iy, y, ux, z, ix, z, bcl, b, ow, th, ..."
