In [1]:
import json
import os
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch
import torchaudio
from tqdm import tqdm

from features_makers import get_pitch, get_energy
from waveglow.text import text_to_sequence

%load_ext autoreload
%autoreload 2

In [2]:
# Root directory of the LJSpeech data. Not referenced by the other cells
# visible in this notebook: TrainConfig spells the same prefix out literally.
dataset_dir = 'data/datasets/ljspeech/'

In [3]:
@dataclass
class TrainConfig:
    """Paths and hyper-parameters for the LJSpeech FastSpeech pipeline.

    None of the attributes below carries a type annotation, so ``@dataclass``
    registers no fields: they are plain class attributes shared by every
    instance, and ``TrainConfig()`` takes no arguments.
    """

    # Output locations.
    checkpoint_path = "./model_new"
    logger_path = "./logger"

    # Dataset locations.
    mel_ground_truth = "data/datasets/ljspeech/mels"
    alignment_path = "data/datasets/ljspeech/alignments"
    data_path = 'data/datasets/ljspeech/train.txt'
    energy_path = 'data/datasets/ljspeech/energies'
    pitch_path = 'data/datasets/ljspeech/pitches'
    wav_path = 'data/datasets/ljspeech/LJSpeech-1.1/wavs'

    wandb_project = 'fastspeech_example'

    # Cleaner names handed to text_to_sequence.
    text_cleaners = ['english_cleaners']

    # Single assignment: 'cuda:0' when CUDA is usable, otherwise 'cpu'.
    # Previously a hard-coded 'cuda:0' overwrote the availability check, so the
    # config named a non-existent device on CPU-only machines. Kept as a plain
    # string so the value is unchanged on GPU machines.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Training loop sizes.
    batch_size = 16
    epochs = 2000
    n_warm_up_step = 4000

    # Optimiser settings.
    learning_rate = 1e-3
    weight_decay = 1e-6
    grad_clip_thresh = 1.0
    decay_step = [500000, 1000000, 2000000]

    # Step intervals (presumably for checkpointing / logging / output
    # clearing -- TODO confirm against the training loop, which is not in
    # this notebook).
    save_step = 3000
    log_step = 5
    clear_Time = 20

    batch_expand_size = 32

In [4]:
def process_text(train_text_path):
    """Return the lines of the UTF-8 transcript file at ``train_text_path``.

    Lines are returned exactly as read in text mode, i.e. each one still ends
    with its newline (except possibly the last); nothing is stripped here.
    """
    with open(train_text_path, "r", encoding="utf-8") as transcript_file:
        # Iterating a text file yields the same strings as readlines().
        return list(transcript_file)

In [5]:
def make_index(train_config):
    """Compute pitch/energy features per utterance and build the dataset index.

    For every line ``i`` of the transcript at ``train_config.data_path`` this
    loads the matching mel spectrogram and waveform, computes pitch and energy
    with ``get_pitch`` / ``get_energy``, saves them as ``<i>.npy`` under
    ``train_config.pitch_path`` / ``train_config.energy_path``, and records one
    index entry.

    Args:
        train_config: object exposing ``data_path``, ``wav_path``,
            ``mel_ground_truth``, ``alignment_path``, ``pitch_path``,
            ``energy_path`` and ``text_cleaners``.

    Returns:
        A list with one dict per transcript line, holding the keys ``text``,
        ``tokens``, ``duration_path``, ``mel_path``, ``energy_path``,
        ``pitch_path``, ``audio_len`` and ``audio_path``.

    Raises:
        KeyError: if the wav directory holds fewer files than the transcript
            has lines.
    """
    entries = []
    lines = process_text(train_config.data_path)

    # Create the directories np.save actually writes into below (previously
    # sibling folders of data_path were created, which only coincides with
    # pitch_path / energy_path for the default config).
    os.makedirs(train_config.pitch_path, exist_ok=True)
    os.makedirs(train_config.energy_path, exist_ok=True)

    # Transcript line i is paired with the i-th wav in sorted filename order.
    wav_names = sorted(
        entry.name for entry in Path(train_config.wav_path).iterdir()
    )
    i_to_filename = dict(enumerate(wav_names))

    for i, raw_line in enumerate(tqdm(lines)):
        # Mel files are numbered from 1; alignment, pitch and energy files
        # are numbered from 0.
        mel_path = os.path.join(
            train_config.mel_ground_truth, "ljspeech-mel-%05d.npy" % (i + 1)
        )
        duration_path = os.path.join(
            train_config.alignment_path, str(i) + ".npy"
        )

        # Strip only the line terminator. Slicing off the last character
        # would eat a real character when the final line has no newline.
        og_text = raw_line.rstrip("\n")
        tokens = text_to_sequence(og_text, train_config.text_cleaners)

        audio_name = i_to_filename[i]
        audio_path = os.path.join(train_config.wav_path, audio_name)
        audio, sr = torchaudio.load(audio_path)
        # Collapse axis 0 (the channel axis in torchaudio's default
        # channels-first layout) into a single 1-D float64 waveform.
        audio = audio.to(torch.float64).numpy().sum(axis=0)

        mel = np.load(mel_path)

        pitch = get_pitch(mel, audio, sr)
        pitch_path = os.path.join(train_config.pitch_path, str(i) + ".npy")
        np.save(pitch_path, pitch)

        energy = get_energy(mel)
        energy_path = os.path.join(train_config.energy_path, str(i) + ".npy")
        np.save(energy_path, energy)

        entries.append({
            "text": og_text,
            "tokens": tokens,
            "duration_path": duration_path,
            "mel_path": mel_path,
            "energy_path": energy_path,
            "pitch_path": pitch_path,
            "audio_len": audio.shape[0],
            "audio_path": audio_path
        })

    return entries

In [6]:
# Build the index with the default TrainConfig. As a side effect this writes
# the per-utterance pitch/energy .npy files; the recorded progress bar below
# shows 13100 items processed in roughly 42 minutes.
index = make_index(TrainConfig())

100%|██████████| 13100/13100 [42:16<00:00,  5.16it/s]


In [7]:
# Persist the index as pretty-printed JSON (2-space indent).
index_path = 'data/datasets/ljspeech/train_index.json'
with open(index_path, 'w') as index_file:
    index_file.write(json.dumps(index, indent=2))