In [38]:
import librosa
import librosa.display
import numpy as np
import torch
import os
from os import listdir
from os.path import isfile, join
import ntpath
from torch.utils.data import DataLoader
import warnings
from g2p_en import G2p
#import pytorch_lightning as pl

def logmelfilterbank(audio,
                     sampling_rate,
                     fft_size=1024,
                     hop_size=256,
                     win_length=None,
                     window="hann",
                     num_mels=80,
                     fmin=None,
                     fmax=None,
                     eps=1e-10,
                     ):
    """Compute log-Mel filterbank feature.

    Args:
        audio (ndarray): Audio signal (T,).
        sampling_rate (int): Sampling rate.
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length. If set to None, it will be the same as fft_size.
        window (str): Window function type.
        num_mels (int): Number of mel basis.
        fmin (int): Minimum frequency in mel basis calculation.
        fmax (int): Maximum frequency in mel basis calculation.
        eps (float): Epsilon value to avoid inf in log calculation.

    Returns:
        ndarray: Log Mel filterbank feature (#frames, num_mels).

    """
    # get amplitude spectrogram
    x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size,
                          win_length=win_length, window=window, pad_mode="reflect")
    spc = np.abs(x_stft).T  # (#frames, #bins)

    # get mel basis
    fmin = 0 if fmin is None else fmin
    fmax = sampling_rate / 2 if fmax is None else fmax
    mel_basis = librosa.filters.mel(sampling_rate, fft_size, num_mels, fmin, fmax)
    mel = np.dot(spc, mel_basis.T)
    return np.log10(np.maximum(1e-5, mel)).T

class LJDataset(torch.utils.data.Dataset):
    def __init__(self, hp, split='train'):
        self.hp = hp
        self.split = split
        self.data_files = self._get_data_files(hp.dataset, hp.data_dir, hp.data_file)
        self.mel_matrix = librosa.filters.mel(sr=22050, n_fft=1024, n_mels=80)
        self.g2p_en = G2p()
        
    def _get_data_files(self, dataset, root_dir, file):
        if dataset == 'lj':
            metadata = root_dir + 'metadata.csv'

            data_files = []
            with open(metadata, 'r') as f:
                l = f.readline().strip()
                while l:
                    l = l.split('|')
                    if self.split == 'test':
                        if 'LJ003' in l[0]:
                            pass
                        else:
                            l = f.readline().strip()
                            continue
                            
                    elif self.split == 'valid':
                        if 'LJ001' in l[0] or 'LJ002' in l[0]:
                            pass
                        else:
                            l = f.readline().strip()
                            continue
                    
                    else: # train
                        if 'LJ001' in l[0] or 'LJ002' in l[0] or 'LJ003' in l[0]:
                            l = f.readline().strip()
                            continue
                        else:
                            pass
                    
                    print(l[0])
                    wav_file = root_dir + 'wavs/' + l[0] + '.wav'
                    text = l[2]
                    data_files.append((wav_file, text))
                    l = f.readline().strip()

            return data_files    
        
        elif dataset == 'kss':
            metadata = root_dir + file

            data_files = []
            with open(metadata, 'r') as f:
                l = f.readline().strip()
                while l:
                    l = l.split('|')
                    wav_file = root_dir + l[0]
                    text = l[2]
                    data_files.append((wav_file, text))
                    l = f.readline().strip()

            return data_files
    
    def _get_mel(self, data_file):
        wav, _ = librosa.core.load(data_file, sr=22050)
        #wav, _ = librosa.effects.trim(wav, top_db=40)
        
        with warnings.catch_warnings():
            mel = logmelfilterbank(wav, sampling_rate=22050, fft_size=1024, hop_size=256, fmin=80, fmax=7600)
    
        if self.hp.mel_norm:
            mel = (mel + 5) / 5
            
        return mel
    
    def _get_utf8_values(self, text):
        if self.hp.g2p:
            if self.hp.dataset == 'lj':
                text_array = self.g2p_en(text)
                text = ""
                for t in text_array:
                    text += t
                
        #text = g2p(text)
        text_utf = text.encode()
        ts = [0]
        for t in text_utf:
            ts.append(t)
        ts.append(0)    
        utf8_values = np.array(ts)
        
        return utf8_values
        
        
    def __getitem__(self, index):
        mel = self._get_mel(self.data_files[index][0])
        string = self.data_files[index][1]
        text = self._get_utf8_values(string)
        
        return torch.LongTensor(text), torch.FloatTensor(mel), string, self.data_files[index][0]
        
    def __len__(self):
        return len(self.data_files)
    
class TextMelCollate():
    
    """ Zero-pads model inputs and targets based on number of frames per setep
    """
    def __init__(self, hp):
        self.hp = hp
        
    def __call__(self, batch):
        """Collate's training batch from normalized text and mel-spectrogram
        PARAMS
        ------
        batch: [text_normalized, mel_normalized]
        """
        
        outputs = {}
        # Right zero-pad all one-hot text sequences to max input length
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x[0]) for x in batch]),
            dim=0, descending=True)
        max_input_len = input_lengths[0]

        text_padded = torch.LongTensor(len(batch), max_input_len)
        text_padded.zero_()
        strings = []
        files = []
        for i in range(len(ids_sorted_decreasing)):
            text = batch[ids_sorted_decreasing[i]][0]
            text_padded[i, :text.size(0)] = text
            strings.append(batch[ids_sorted_decreasing[i]][2])
            files.append(batch[ids_sorted_decreasing[i]][3])
        outputs['text'] = text_padded
        outputs['text_lengths'] = input_lengths
        outputs['strings'] = strings
        outputs['files'] = files
            
        # include mel padded and gate padded
        num_mels = batch[0][1].size(0)    
        max_target_len = max([x[1].shape[1] for x in batch])
        #max_target_len = 1024
        mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
        if self.hp.mel_norm:
            mel_padded.fill_(0)
        else:
            mel_padded.fill_(-5)
            
        output_lengths = torch.LongTensor(len(batch))
        for i in range(len(ids_sorted_decreasing)):
            mel = batch[ids_sorted_decreasing[i]][1]
            mel_padded[i, :, :mel.size(1)] = mel
            output_lengths[i] = mel.size(1)
            
        outputs['mels'] = mel_padded
        outputs['mel_lengths'] = output_lengths

        return outputs


In [39]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
from torch import nn

from hparams.light57_hparams import create_hparams
from model import Model
from utils import sizeof_fmt, Logger

In [40]:
stt_hparams, tts_hparams = create_hparams()

In [44]:
trainset = LJDataset(tts_hparams, split='train')
print(trainset)

LJ004-0001
LJ004-0002
LJ004-0003
LJ004-0004
LJ004-0005
LJ004-0006
LJ004-0007
LJ004-0008
LJ004-0009
LJ004-0010
LJ004-0011
LJ004-0012
LJ004-0013
LJ004-0014
LJ004-0015
LJ004-0016
LJ004-0017
LJ004-0018
LJ004-0019
LJ004-0020
LJ004-0021
LJ004-0022
LJ004-0023
LJ004-0024
LJ004-0025
LJ004-0026
LJ004-0027
LJ004-0028
LJ004-0029
LJ004-0030
LJ004-0031
LJ004-0032
LJ004-0033
LJ004-0034
LJ004-0035
LJ004-0036
LJ004-0037
LJ004-0038
LJ004-0039
LJ004-0040
LJ004-0041
LJ004-0042
LJ004-0043
LJ004-0044
LJ004-0045
LJ004-0046
LJ004-0047
LJ004-0048
LJ004-0049
LJ004-0050
LJ004-0051
LJ004-0052
LJ004-0054
LJ004-0055
LJ004-0056
LJ004-0057
LJ004-0058
LJ004-0059
LJ004-0060
LJ004-0061
LJ004-0062
LJ004-0063
LJ004-0064
LJ004-0065
LJ004-0066
LJ004-0067
LJ004-0068
LJ004-0069
LJ004-0070
LJ004-0071
LJ004-0072
LJ004-0073
LJ004-0074
LJ004-0075
LJ004-0076
LJ004-0077
LJ004-0078
LJ004-0079
LJ004-0080
LJ004-0081
LJ004-0082
LJ004-0083
LJ004-0084
LJ004-0085
LJ004-0086
LJ004-0087
LJ004-0088
LJ004-0089
LJ004-0090
LJ004-0091
LJ004-0092

LJ014-0293
LJ014-0294
LJ014-0295
LJ014-0296
LJ014-0297
LJ014-0298
LJ014-0299
LJ014-0300
LJ014-0301
LJ014-0302
LJ014-0303
LJ014-0304
LJ014-0305
LJ014-0306
LJ014-0307
LJ014-0308
LJ014-0309
LJ014-0310
LJ014-0311
LJ014-0312
LJ014-0313
LJ014-0314
LJ014-0315
LJ014-0316
LJ014-0317
LJ014-0318
LJ014-0320
LJ014-0321
LJ014-0322
LJ014-0323
LJ014-0324
LJ014-0325
LJ014-0326
LJ014-0327
LJ014-0328
LJ014-0329
LJ014-0330
LJ014-0331
LJ014-0332
LJ014-0333
LJ014-0334
LJ014-0335
LJ014-0336
LJ014-0337
LJ014-0338
LJ014-0339
LJ014-0340
LJ015-0001
LJ015-0002
LJ015-0003
LJ015-0004
LJ015-0005
LJ015-0006
LJ015-0007
LJ015-0008
LJ015-0009
LJ015-0010
LJ015-0011
LJ015-0012
LJ015-0013
LJ015-0014
LJ015-0015
LJ015-0016
LJ015-0017
LJ015-0018
LJ015-0019
LJ015-0020
LJ015-0021
LJ015-0022
LJ015-0023
LJ015-0024
LJ015-0025
LJ015-0026
LJ015-0027
LJ015-0028
LJ015-0029
LJ015-0030
LJ015-0031
LJ015-0032
LJ015-0033
LJ015-0034
LJ015-0035
LJ015-0036
LJ015-0037
LJ015-0038
LJ015-0039
LJ015-0040
LJ015-0041
LJ015-0042
LJ015-0043
LJ015-0044

LJ029-0120
LJ029-0121
LJ029-0122
LJ029-0123
LJ029-0124
LJ029-0125
LJ029-0126
LJ029-0127
LJ029-0128
LJ029-0129
LJ029-0130
LJ029-0131
LJ029-0132
LJ029-0133
LJ029-0134
LJ029-0135
LJ029-0136
LJ029-0137
LJ029-0138
LJ029-0139
LJ029-0140
LJ029-0141
LJ029-0142
LJ029-0143
LJ029-0144
LJ029-0145
LJ029-0146
LJ029-0147
LJ029-0148
LJ029-0149
LJ029-0150
LJ029-0151
LJ029-0152
LJ029-0153
LJ029-0154
LJ029-0155
LJ029-0156
LJ029-0157
LJ029-0158
LJ029-0159
LJ029-0160
LJ029-0161
LJ029-0162
LJ029-0163
LJ029-0164
LJ029-0165
LJ029-0166
LJ029-0167
LJ029-0168
LJ029-0169
LJ029-0170
LJ029-0171
LJ029-0172
LJ029-0173
LJ029-0174
LJ029-0175
LJ029-0176
LJ029-0177
LJ029-0178
LJ029-0179
LJ029-0180
LJ029-0181
LJ029-0182
LJ029-0183
LJ029-0184
LJ029-0185
LJ029-0186
LJ029-0187
LJ029-0188
LJ029-0189
LJ029-0190
LJ029-0191
LJ029-0192
LJ029-0193
LJ029-0194
LJ029-0195
LJ029-0196
LJ029-0197
LJ029-0198
LJ029-0199
LJ029-0200
LJ029-0201
LJ029-0202
LJ029-0203
LJ029-0204
LJ029-0205
LJ029-0206
LJ029-0207
LJ029-0208
LJ029-0209
LJ029-0210

LJ045-0126
LJ045-0127
LJ045-0128
LJ045-0129
LJ045-0130
LJ045-0131
LJ045-0132
LJ045-0133
LJ045-0134
LJ045-0135
LJ045-0136
LJ045-0137
LJ045-0138
LJ045-0139
LJ045-0140
LJ045-0141
LJ045-0142
LJ045-0143
LJ045-0144
LJ045-0145
LJ045-0146
LJ045-0147
LJ045-0148
LJ045-0149
LJ045-0150
LJ045-0151
LJ045-0152
LJ045-0153
LJ045-0154
LJ045-0155
LJ045-0156
LJ045-0157
LJ045-0158
LJ045-0159
LJ045-0160
LJ045-0161
LJ045-0162
LJ045-0163
LJ045-0164
LJ045-0165
LJ045-0166
LJ045-0167
LJ045-0168
LJ045-0169
LJ045-0170
LJ045-0171
LJ045-0172
LJ045-0173
LJ045-0174
LJ045-0175
LJ045-0176
LJ045-0177
LJ045-0178
LJ045-0179
LJ045-0180
LJ045-0181
LJ045-0182
LJ045-0183
LJ045-0184
LJ045-0185
LJ045-0186
LJ045-0187
LJ045-0188
LJ045-0189
LJ045-0190
LJ045-0191
LJ045-0192
LJ045-0193
LJ045-0194
LJ045-0195
LJ045-0196
LJ045-0197
LJ045-0198
LJ045-0199
LJ045-0200
LJ045-0201
LJ045-0202
LJ045-0203
LJ045-0204
LJ045-0205
LJ045-0206
LJ045-0207
LJ045-0208
LJ045-0209
LJ045-0210
LJ045-0211
LJ045-0212
LJ045-0213
LJ045-0214
LJ045-0215
LJ045-0216

<__main__.LJDataset object at 0x7fe88ffb00d0>


In [45]:
batch = trainset[0]

In [46]:
batch

(tensor([ 0, 68, 72, 65, 72, 48, 32, 75, 82, 65, 65, 49, 78, 73, 72, 48, 75, 65,
         72, 48, 76, 90, 32, 65, 72, 49, 86, 32, 78, 85, 87, 49, 71, 69, 89, 48,
         84, 32, 44, 32, 86, 65, 65, 49, 76, 89, 85, 87, 48, 77, 32, 84, 85, 87,
         49, 32, 46, 32, 66, 65, 89, 49, 32, 65, 65, 49, 82, 84, 72, 69, 82, 48,
         32, 71, 82, 73, 72, 49, 70, 73, 72, 48, 84, 72, 83, 32, 46, 32, 83, 69,
         72, 49, 75, 83, 72, 65, 72, 48, 78, 32, 83, 69, 72, 49, 86, 65, 72, 48,
         78, 32, 68, 72, 65, 72, 48, 32, 66, 73, 72, 48, 71, 73, 72, 49, 78, 73,
         72, 48, 78, 71, 90, 32, 65, 72, 49, 86, 32, 80, 82, 73, 72, 49, 90, 65,
         72, 48, 78, 32, 82, 65, 72, 48, 70, 65, 79, 49, 82, 77, 32, 46,  0]),
 tensor([[-2.3109, -2.1379, -2.0853,  ..., -2.0462, -2.0138, -2.0290],
         [-1.9713, -1.9684, -1.7868,  ..., -1.8993, -2.1514, -2.1673],
         [-2.4588, -1.5695, -1.4098,  ..., -2.4075, -2.4679, -2.2240],
         ...,
         [-3.9692, -3.9631, -3.4529,  ..., -4.