<a href="https://colab.research.google.com/github/usamireko/StableTTS-Training-Colab/blob/main/StableTTS_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@markdown Installation
from google.colab import drive
drive.mount("/content/drive")
!git clone https://github.com/KdaiP/StableTTS.git
%cd /content/StableTTS
!pip install -r requirements.txt

In [None]:
#@markdown A pretrained model will be downloaded in the folder checkpoints created for you inside the StableTTS folder on you GDrive :)
project_name = "" #@param {type:"string"}
!mkdir /content/drive/MyDrive/StableTTS/{project_name}
!mkdir /content/drive/MyDrive/StableTTS/{project_name}/checkpoints
!mkdir /content/drive/MyDrive/StableTTS/{project_name}/wavs
!mkdir /content/drive/MyDrive/StableTTS/{project_name}/logs
!wget "https://huggingface.co/KdaiP/StableTTS1.1/resolve/main/StableTTS/checkpoint_0.pt" -O /content/drive/MyDrive/StableTTS/{project_name}/checkpoints/checkpoint_0.pt

In [None]:
#@markdown Convert your Tacotron2 list to StableTTS
def update_transcript(file_path, project_name):
    """Rewrite a Tacotron2-style transcript so its wav paths point into Google Drive.

    Each non-blank line is expected to look like ``wavs/<name>.wav|<text>``. A
    leading ``wavs/`` is replaced by
    ``/content/drive/MyDrive/StableTTS/<project_name>/wavs/`` and the file is
    overwritten in place with ``<path>|<text>`` lines. Only the first two
    ``|``-separated columns are kept.

    Running the function again on an already converted file leaves it unchanged,
    because only a *leading* ``wavs/`` is rewritten.

    Args:
        file_path: Path of the transcript (list.txt) to rewrite in place.
        project_name: Project folder name under ``MyDrive/StableTTS``.

    Raises:
        ValueError: If a non-blank line has no ``|`` separator. The file is left
            untouched in that case.
    """
    source_prefix = 'wavs/'
    wav_dir = f'/content/drive/MyDrive/StableTTS/{project_name}/wavs/'

    # utf-8-sig also accepts plain UTF-8 and drops a leading BOM, which would
    # otherwise end up glued to the first wav path.
    with open(file_path, 'r', encoding='utf-8-sig') as file:
        lines = file.readlines()

    # Update each line with the new path format
    updated_lines = []
    for line_number, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # ignore blank lines (e.g. a trailing empty line)
        parts = line.split('|')
        if len(parts) < 2:
            raise ValueError(
                f"{file_path}, line {line_number}: expected 'wavs/xxx.wav|text', got {line.strip()!r}"
            )
        wav_file = parts[0].strip()
        if wav_file.startswith(source_prefix):
            wav_file = wav_dir + wav_file[len(source_prefix):]
        text = parts[1].strip()  # Strip to remove any extra whitespace/newlines
        updated_lines.append(f'{wav_file}|{text}\n')

    # Write the updated lines back to the same file
    with open(file_path, 'w', encoding='utf-8') as file:
        file.writelines(updated_lines)

    print(f"File '{file_path}' updated successfully.")


# Specify the path to your list.txt in Google Drive
# (the file is overwritten in place by update_transcript)
file_path = '' #@param {type:"string"}



# Run the update function
# project_name is defined by the project-folder cell earlier in the notebook,
# so that cell must have been run in this session.
update_transcript(file_path, project_name)

In [None]:
#@markdown Download the FireFly-GAN vocoder
# Saves the generator checkpoint to the path that the inference cell later uses as vocoder_model_path.
!wget "https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt" -O /content/StableTTS/vocoders/pretrained/firefly-gan-base-generator.ckpt

Now upload your wav files into the `wavs` folder that was created inside the `StableTTS/{project_name}` folder on your Google Drive.

Your list should now point to the full path of each wav file instead of just "wavs/xxx.wav" as in Tacotron2.

`output_filelist_path` points to the JSON file that preprocessing creates; put it in the same folder as your list.txt.

**Configuration**

_resample is not really needed_

In [None]:
%%writefile /content/StableTTS/preprocess.py

import os
import json
from tqdm import tqdm
from dataclasses import dataclass, asdict

import torch
from torch.multiprocessing import Pool, set_start_method
import torchaudio

from config import MelConfig, TrainConfig
from utils.audio import LogMelSpectrogram, load_and_resample_audio

from text.mandarin import chinese_to_cnm3
from text.english import english_to_ipa2
from text.japanese import japanese_to_ipa2

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@dataclass
class DataConfig:
    """Preprocessing settings, filled in through the Colab form fields.

    The attributes are annotated so that ``@dataclass`` actually treats them as
    fields (un-annotated class attributes are ignored by the decorator), matching
    the style of the config classes in config.py.
    """
    # Transcript to read: one "audio_path|text" entry per line.
    input_filelist_path: str = '' #@param {type:"string"}
    # JSON-lines filelist written after preprocessing.
    output_filelist_path: str = '' #@param {type:"string"}
    # Folder that receives the "mels" (and, when resampling, "waves") sub-folders.
    output_feature_path: str = '' #@param {type:"string"}
    # Key into g2p_mapping selecting the text-to-phoneme converter.
    language: str = 'english' #@param ["english", "japanese", "chinese"]
    # When True, the loaded audio is also saved as wav files under "waves".
    resample: bool = False #@param{type:"boolean"}

# Text-to-phoneme converter for each language selectable in DataConfig.language.
g2p_mapping = {
    'chinese': chinese_to_cnm3,
    'japanese': japanese_to_ipa2,
    'english': english_to_ipa2,
}

data_config = DataConfig()
train_config = TrainConfig()
mel_config = MelConfig()

input_filelist_path = data_config.input_filelist_path
output_filelist_path = data_config.output_filelist_path
output_feature_path = data_config.output_feature_path

# Validate the settings before any audio is processed. Without these checks an
# unknown language leaves g2p as None (every file then fails inside
# process_filelist and an empty filelist is written), and an empty output path
# only surfaces as an obscure error.
if data_config.language not in g2p_mapping:
    raise ValueError(
        f"Unsupported language {data_config.language!r}; expected one of {sorted(g2p_mapping)}"
    )
if not output_filelist_path:
    raise ValueError('DataConfig.output_filelist_path is empty; set it to the JSON filelist to write')

# Ensure output directories exist
output_mel_dir = os.path.join(output_feature_path, 'mels')
os.makedirs(output_mel_dir, exist_ok=True)
# dirname() is '' when the filelist path has no directory component, and
# os.makedirs('') raises FileNotFoundError, so only create a real directory.
output_filelist_dir = os.path.dirname(output_filelist_path)
if output_filelist_dir:
    os.makedirs(output_filelist_dir, exist_ok=True)

if data_config.resample:
    output_wav_dir = os.path.join(output_feature_path, 'waves')
    os.makedirs(output_wav_dir, exist_ok=True)

# Mel-spectrogram extractor built from the MelConfig fields, placed on the selected device.
mel_extractor = LogMelSpectrogram(**asdict(mel_config)).to(device)

g2p = g2p_mapping[data_config.language]

def load_filelist(path) -> list:
    """Read an ``audio_path|text`` transcript into a list of work items.

    Blank lines are skipped. The index stored with each entry is the zero-based
    line number in the file, so it is unaffected by skipped blank lines.
    Everything after the first ``|`` is kept as the text.

    Args:
        path: Path of the transcript file (UTF-8, with or without a BOM).

    Returns:
        A list of ``(index_as_str, audio_path, text)`` tuples.

    Raises:
        ValueError: If a non-blank line contains no ``|`` separator.
    """
    file_list = []
    # utf-8-sig also reads plain UTF-8 and strips a leading BOM that would
    # otherwise become part of the first audio path.
    with open(path, 'r', encoding='utf-8-sig') as f:
        for idx, line in enumerate(f):
            line = line.strip()
            if not line:
                continue  # tolerate empty lines such as a trailing newline
            if '|' not in line:
                raise ValueError(f"{path}, line {idx + 1}: expected 'audio_path|text', got {line!r}")
            audio_path, text = line.split('|', maxsplit=1)
            file_list.append((str(idx), audio_path, text))
    return file_list

@torch.inference_mode()
def process_filelist(line) -> str:
    """Extract and save the mel spectrogram for one transcript entry.

    Args:
        line: An ``(idx, audio_path, text)`` tuple as produced by load_filelist.

    Returns:
        A JSON string with the keys ``mel_path``, ``phone``, ``audio_path``,
        ``text`` and ``mel_length``. ``None`` is returned when the audio loader
        returns None, when the phoneme sequence is empty, or when any step
        raises (the error is printed, not propagated).
    """
    idx, audio_path, text = line
    # shape: [1, time] (per the original author's note)
    waveform = load_and_resample_audio(audio_path, mel_config.sample_rate, device=device)
    if waveform is None:
        return None

    # File stem of the source audio, used to name the outputs.
    stem = os.path.splitext(os.path.basename(audio_path))[0]

    try:
        phone = g2p(text)
        if len(phone) == 0:
            return None

        # shape: [n_mels, time // hop_length] (per the original author's note)
        mel = mel_extractor(waveform.to(device)).cpu().squeeze(0)
        output_mel_path = os.path.join(output_mel_dir, f'{idx}_{stem}.pt')
        torch.save(mel, output_mel_path)

        if data_config.resample:
            # From here on the entry refers to the re-saved wav, not the source file.
            audio_path = os.path.join(output_wav_dir, f'{idx}_{stem}.wav')
            torchaudio.save(audio_path, waveform.cpu(), mel_config.sample_rate)

        entry = {
            'mel_path': output_mel_path,
            'phone': phone,
            'audio_path': audio_path,
            'text': text,
            'mel_length': mel.size(-1),
        }
        return json.dumps(entry, ensure_ascii=False, allow_nan=False)
    except Exception as err:
        print(f'Error processing {audio_path}: {str(err)}')
    return None


def main():
    """Preprocess every transcript entry in a 2-worker pool and write the filelist.

    Entries for which process_filelist returns None are left out of the output.
    """
    set_start_method('spawn') # CUDA must use spawn method
    input_filelist = load_filelist(input_filelist_path)

    with Pool(processes=2) as pool:
        # imap yields results in input order; tqdm shows progress over them.
        progress = tqdm(pool.imap(process_filelist, input_filelist), total=len(input_filelist))
        results = [f'{entry}\n' for entry in progress if entry is not None]

    # save filelist
    with open(output_filelist_path, 'w', encoding='utf-8') as out_file:
        out_file.writelines(results)
    print(f"filelist file has been saved to {output_filelist_path}")

# faster and use much less CPU
# Restrict torch to a single intra-op thread and a single inter-op thread.
# These run at import time, so they take effect before main() starts the pool.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

# Run main() only when this file is executed as a script. main() uses the
# 'spawn' start method, under which worker processes import this module again,
# so the guard keeps them from starting the pipeline recursively.
if __name__ == '__main__':
    main()

In [None]:
#@markdown Preprocess dataset
!python preprocess.py

**Modifying config.py**



*   num_epochs doesn't really need to be changed when training from the pretrained model
*   test_dataset_path isn't needed; just copy and paste the same path you used for train_dataset_path






```
Variables that you might want to modify later if needed
*   batch_size
*   learning_rate (found in the config.py file)
```








In [None]:
%%writefile /content/StableTTS/config.py

from dataclasses import dataclass

@dataclass
class MelConfig:
    """Mel-spectrogram settings.

    preprocess.py passes every field of this class as a keyword argument to
    LogMelSpectrogram, so the field names are part of the interface.
    """
    sample_rate: int = 44100
    n_fft: int = 2048
    win_length: int = 2048
    hop_length: int = 512
    f_min: float = 0.0
    f_max: float = None
    pad: int = 0
    n_mels: int = 128
    center: bool = False
    pad_mode: str = "reflect"
    mel_scale: str = "slaney"

    def __post_init__(self):
        # A pad of 0 means "derive it from the FFT and hop sizes";
        # any other value is kept exactly as given.
        if self.pad != 0:
            return
        self.pad = (self.n_fft - self.hop_length) // 2

@dataclass
class ModelConfig:
    """Model hyperparameters.

    NOTE(review): the model code is not visible in this notebook, so what each
    field controls should be confirmed against the StableTTS sources.
    """
    hidden_channels: int = 256
    filter_channels: int = 1024
    n_heads: int = 4
    n_enc_layers: int = 3
    n_dec_layers: int = 6
    kernel_size: int = 3
    # Annotated as float: the default 0.1 is not an int.
    p_dropout: float = 0.1
    gin_channels: int = 256

@dataclass
class TrainConfig:
    # Training settings; the values marked #@param are filled in through the Colab form.
    # NOTE(review): train.py is not visible in this notebook, so how each field is
    # consumed should be confirmed there.

    # presumably the JSON filelist written by preprocess.py — TODO confirm
    train_dataset_path: str = '' #@param {type:"string"}
    # The notebook text says this can simply be the same path as train_dataset_path.
    test_dataset_path: str = '' #@param {type:"string"}
    batch_size: int = 32 #@param {type:"integer"}
    learning_rate: float = 1e-4
    # The notebook text says this does not need changing when starting from the pretrained model.
    num_epochs: int = 10000 #@param {type:"integer"}
    # presumably the project's checkpoints folder on Drive — TODO confirm
    model_save_path: str = '' #@param {type:"string"}
    # presumably the same folder given to TensorBoard in the training cell — TODO confirm
    log_dir: str = '' #@param {type:"string"}
    log_interval: int = 25 #@param {type:"integer"}
    save_interval: int = 50 #@param {type:"integer"}
    warmup_steps: int = 200

@dataclass
class VocosConfig:
    # Settings container named after the Vocos vocoder; it is not referenced
    # anywhere else in this notebook.
    # NOTE(review): input_channels equals MelConfig.n_mels (128) — presumably the
    # number of mel bins fed to the vocoder; confirm against the vocoder code.
    input_channels: int = 128
    dim: int = 512
    intermediate_dim: int = 1536
    num_layers: int = 8

In [None]:
#@markdown Start training!
log_dir = '' #@param {type:"string"}
%load_ext tensorboard
%tensorboard --logdir {log_dir}
!python train.py

In [None]:
#@markdown Inference Tab!
from IPython.display import Audio, display
import torch

# NOTE(review): `api` is presumably a module in the StableTTS repo root, so this
# import relies on the working directory set by the installation cell — confirm.
from api import StableTTSAPI

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Path of the trained TTS checkpoint to load.
tts_model_path = '' #@param {type:"string"}
vocoder_model_path = '/content/StableTTS/vocoders/pretrained/firefly-gan-base-generator.ckpt' # path to vocoder checkpoint
vocoder_type = 'ffgan' # ffgan or vocos

# Alternative vocoder settings, left disabled:
# vocoder_model_path = './vocoders/pretrained/vocos.pt'
# vocoder_type = 'vocos'

# `model` is reused by the inference cell below, so this cell must be run first.
model = StableTTSAPI(tts_model_path, vocoder_model_path, vocoder_type)
model.to(device)

tts_param, vocoder_param = model.get_params()
print(f'tts_param: {tts_param}, vocoder_param: {vocoder_param}')

In [None]:
# Synthesis inputs, set through the Colab form. `model`, `display` and `Audio`
# come from the inference setup cell above.
text = 'test' #@param {type:"string"}
# Path of the reference audio file; it is also played back below for comparison.
ref_audio = 'path'#@param {type:"string"}
language = 'english' #@param ["english", "japanese", "chinese"]
solver = 'dopri5' #@param ["dopri5", "euler", "midpoint"]
steps = 30 #@param {type:"slider", min:0, max:50, step:1}
cfg = 3  #@param {type:"slider", min:0, max:10, step:1}

# NOTE(review): the two literal 1 arguments are positional parameters of
# StableTTSAPI.inference whose meaning is not visible here — confirm against api.py.
audio_output, mel_output = model.inference(text, ref_audio, language, steps, 1, 1, solver, cfg)

# Play the reference audio first, then the synthesized audio at the model's sample rate.
display(Audio(ref_audio))
display(Audio(audio_output, rate=model.mel_config.sample_rate))