# Reference training notebook for Stylish TTS

## Download dataset and unzip

In [None]:
from IPython.utils import io
print("Downloading and extracting dataset...")
with io.capture_output():
    !mkdir /content
    %cd /content
    !rm -r dataset
    !mkdir dataset
    !wget https://huggingface.co/datasets/hr16/ViVoicePP/resolve/main/CDMedia_processed.zip -nc
    %cd dataset
    !unzip ../CDMedia_processed.zip
print("Finished")

## Install dependencies

In [None]:

import requests
from bs4 import BeautifulSoup
import torch
import sys
from IPython.utils import io

print("Installing k2-fsa...")
with io.capture_output():
    py_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
    torch_ver, device = torch.__version__.split('+')
    device = device[:len('cu12')] + '.' + device[len('cu12'):]
    device = device.replace('cu', 'cuda')
    base_url = "https://k2-fsa.github.io/k2/installation/pre-compiled-cuda-wheels-linux/"
    index_page = BeautifulSoup(requests.get(base_url).text)
    whl_url = None
    for a_el in BeautifulSoup(requests.get(base_url + torch_ver).text).select('a.external'):
        _whl_url = a_el.get('href')
        if device in _whl_url and py_version in _whl_url:
            whl_url = _whl_url
            break
    !pip install {whl_url}


print("Installing Stylish...")
with io.capture_output():
    %cd /content
    !wget https://huggingface.co/Politrees/RVC_resources/resolve/main/predictors/rmvpe.pt -nc
    !rm -r /content/stylish-dataset
    !git clone https://github.com/Fannovel16/stylish-dataset/
    
    %cd /content
    !rm -r stylish-tts
    !git clone https://github.com/Fannovel16/stylish-tts/
    %cd stylish-tts
    try:
        import munch
    except:
        !pip install matplotlib resampy -U
        !pip install -r requirements.txt
    %cd /content
print("Finished")

## Extract pitches using RMVPE

In [None]:
# Run the all-pitch.py script from the cloned stylish-dataset repository with
# the RMVPE method and the checkpoint downloaded to /content/rmvpe.pt.
# It is given the dataset directory plus the train/val lists, and its output
# path is /content/dataset/pitch.safetensors (the `pitch_path` used in config.yml).
%cd /content/stylish-dataset
print("Warmup may take a bit of time")
!python stylish-dataset/all-pitch.py \
    --method rmvpe \
    --rmvpe_checkpoint /content/rmvpe.pt \
    --wavdir /content/dataset \
    --trainpath /content/dataset/train_list.txt \
    --valpath /content/dataset/val_list.txt \
    --outpath /content/dataset/pitch.safetensors

## Write config files

In [None]:
# The following cells use paths relative to the stylish-tts checkout
# (config/model.yml, config/config.yml, train/stylish_train/...).
%cd /content/stylish-tts/

In [None]:
%%writefile config/model.yml
# Model hyperparameters. This file is passed to the training and alignment
# scripts through --model_config_path.
multispeaker: false
n_mels: 80
sample_rate: 24000
n_fft: 2048
win_length: 1200
hop_length: 300
style_dim: 128
inter_dim: 512

text_aligner:
  hidden_dim: 256
  token_embedding_dim: 512

decoder:
  hidden_dim: 1024
  residual_dim: 64

# NOTE(review): product of upsample_rates (4 * 5 = 20) times gen_istft_hop_size (15)
# equals hop_length (300); presumably these must stay consistent -- confirm
# before changing any of them.
generator:
  type: 'ringformer'
  resblock_kernel_sizes: [ 3, 7, 11 ]
  upsample_rates: [ 4, 5 ]
  upsample_initial_channel: 512
  resblock_dilation_sizes: [ [ 1, 3, 5 ], [ 1, 3, 5 ], [ 1, 3, 5 ] ]
  upsample_kernel_sizes: [ 8, 10 ]
  gen_istft_n_fft: 60
  gen_istft_hop_size: 15
  depth: 2

text_encoder:
  tokens: 178 # number of phoneme tokens
  hidden_dim: 192
  filter_channels: 768
  heads: 2
  layers: 6
  kernel_size: 3
  dropout: 0.1

style_encoder:
  layers: 4

duration_predictor:
  n_layer: 3
  max_dur: 50 # maximum duration of a single phoneme
  dropout: 0.2

pitch_energy_predictor:
  dropout: 0.2

# speech language model config
slm:
  model: 'microsoft/wavlm-base-plus'
  sr: 16000 # sampling rate of SLM

# Symbol inventory: padding symbol, punctuation, Latin letters, IPA letters,
# and two groupings of symbols labelled voiced / unvoiced.
# NOTE(review): "ǀ" is listed under `unvoiced` but not under `letters_ipa`, and
# "ă" is listed under `letters_ipa` but in neither grouping -- confirm this is
# intended and consistent with text_encoder.tokens.
symbol:
  pad: "$"
  punctuation: ";:,.!?¡¿—…\"()“” "
  letters: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
  letters_ipa: "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢăǁᵊǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
  voiced: "ɑɐɒæɓʙβɔɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɥɨɪʝɭɫɮʟɱɯɰŋɳɲɴøɵœɶɹɺɾʀʁɽʉʊʋⱱʌɣɤʎʏʑʐʒʔʕʢᵊᵻˈˌ"
  unvoiced: "ɕçɧħʜɬɸθʘɻʂʃʈʧʍχʡǀǁǃːˑʼʴʰʱʲʷˠˤ˞pftsckxqh↓↑→↗↘'̩';:,.!?¡¿—…\"()“” $"


In [None]:
%%writefile config/config.yml
# Training configuration. This file is passed to the training and alignment
# scripts through --config_path.
training:
  # log training data every this number of steps
  log_interval: 100
  # validate and save model every this number of steps
  save_interval: 2500
  # validate model every this number of steps
  val_interval: 2500
  device: "cuda"
  # Keep this as 'no' if you have the VRAM.
  # Lower precision slows training.
  # "bf16", "fp16", or "no" for no mixed precision
  mixed_precision: "no"

# Number of epochs, max batch sizes and general learning rate of each stage.
training_plan:
  alignment:
    # alignment pretraining
    epochs: 8
    # Maximum number of segments per batch.
    probe_batch_max: 128
    # Learning rate for this stage
    lr: 3e-5
  acoustic:
    # training of acoustic models and vocoder
    epochs: 10
    probe_batch_max: 8
    lr: 1e-4
  textual:
    # training for duration/pitch/energy from text
    epochs: 20
    probe_batch_max: 16
    lr: 1e-4

# All dataset paths point into /content/dataset. The filtering cell later
# rewrites "/train_list.txt" to "/filtered_train_list.txt" in this file.
dataset:
  train_data: "/content/dataset/train_list.txt"
  val_data: "/content/dataset/val_list.txt"
  wav_path: "/content/dataset"
  pitch_path: "/content/dataset/pitch.safetensors"
  alignment_path: "/content/dataset/alignment.safetensors"

validation:
  # Number of samples to generate per validation step
  sample_count: 10
  # Specific segments to use for validation
  # force_samples:
  # - "filename.from.val_data.txt"
  # - "other.filename.from.val_data.txt"
  # - "other.other.filename.from.val_data.txt"


# Weights are pre-tuned. Do not change these unless you
# know what you are doing.
loss_weight:
  # mel reconstruction loss
  mel: 5
  # generator loss
  generator: 2
  # speech-language model feature matching loss
  slm: 0.2
  # pitch F0 reconstruction loss
  pitch: 3
  # energy reconstruction loss
  energy: 1
  # duration loss
  duration: 1
  # duration predictor probability output cross entropy loss
  duration_ce: 1
  # style reconstruction loss
  style: 1
  # magnitude/phase loss
  magphase: 1
  # confidence for alignment (placeholder)
  confidence: 1
  # alignment loss
  align_loss: 1
  # discriminator loss (placeholder)
  discriminator: 1

## Train alignment model TDNN_BLSTM-CTC and cache the alignment

In [None]:
# Run the trainer with `--stage alignment`, using the two config files written
# above; output goes to /content/checkpoints/.
# PYTHONPATH is extended with the repository's train/ and lib/ directories.
# The relative paths assume the working directory is /content/stylish-tts.
!PYTHONPATH="$PYTHONPATH:/content/stylish-tts/train:/content/stylish-tts/lib" python \
    train/stylish_train/train.py \
    --model_config_path config/model.yml \
    --config_path config/config.yml \
    --stage alignment \
    --out_dir /content/checkpoints/

In [None]:
# Run dataprep/align_text.py with the alignment checkpoint (--model) and write
# /content/dataset/alignment.safetensors (--out), which is the `alignment_path`
# listed in config.yml.
# NOTE(review): the filtering cell reads /content/dataset/scores_train.txt,
# presumably also produced by this script -- confirm.
!PYTHONPATH="$PYTHONPATH:/content/stylish-tts/train:/content/stylish-tts/lib" python \
    train/stylish_train/dataprep/align_text.py \
    --model_config_path config/model.yml \
    --config_path config/config.yml \
    --model /content/checkpoints/alignment_model.safetensors \
    --out /content/dataset/alignment.safetensors

## Filter out samples with misaligned phonemes by confidence score

In [None]:
FILTER_PERCENTAGE = 0.2

from pathlib import Path
lines = Path("/content/dataset/scores_train.txt").read_text().splitlines()
data = []
for line in lines:
    score, path = line.split(' ', 1)
    score = float(score)
    data.append((path, score))
data.sort(key=lambda x: x[1])
threshold = data[:int(len(data) * FILTER_PERCENTAGE)][-1][1]
threshold = round(threshold, 2)
removed = set([path for path, score in data if score < threshold])

filtered_lines = []
for line in Path("/content/dataset/train_list.txt").read_text(encoding="utf-8").splitlines():
    path, ipa, *others = line.split('|')
    if path not in removed: filtered_lines.append(line)

Path("/content/dataset/filtered_train_list.txt").write_text('\n'.join(filtered_lines), encoding="utf-8")
!sed -i 's/\/train_list.txt/\/filtered_train_list.txt/' config/config.yml
print(f"Filtered out {len(removed)} files")

## Actual training

In [None]:
# Run the trainer with `--stage acoustic`, using the same config files and
# output directory as the alignment stage.
# NOTE(review): the upload cell reads /content/checkpoints/textual/, but no cell
# passes `--stage textual`; presumably the trainer continues into the textual
# stage on its own -- confirm.
!PYTHONPATH="$PYTHONPATH:/content/stylish-tts/train:/content/stylish-tts/lib" python \
    train/stylish_train/train.py \
    --model_config_path config/model.yml \
    --config_path config/config.yml \
    --stage acoustic \
    --out_dir /content/checkpoints/

## Upload checkpoints to HuggingFace

In [None]:
!huggingface-cli login --token="YOUR_HF_MODEL_TOKEN"
!huggingface-cli upload stylish-tts-cd-media /content/checkpoints/textual/ .