<a href="https://colab.research.google.com/github/usamireko/StableTTS-Training-Colab/blob/main/StableTTS_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@markdown Installation
from google.colab import drive
drive.mount("/content/drive")
!git clone https://github.com/KdaiP/StableTTS.git
%cd /content/StableTTS
!pip install -r requirements.txt

In [None]:
#@markdown A pretrained model will be downloaded into the `checkpoints` folder created for you inside the StableTTS folder on your GDrive :)
project_name = "" #@param {type:"string"}
!mkdir -p /content/drive/MyDrive/StableTTS/{project_name}
!mkdir -p /content/drive/MyDrive/StableTTS/{project_name}/checkpoints
!mkdir -p /content/drive/MyDrive/StableTTS/{project_name}/wavs
!mkdir -p /content/drive/MyDrive/StableTTS/{project_name}/logs
!wget "https://huggingface.co/KdaiP/StableTTS1.1/resolve/main/StableTTS/checkpoint_0.pt" -O /content/drive/MyDrive/StableTTS/{project_name}/checkpoints/checkpoint_0.pt

In [None]:
#@markdown Download the FireFly-GAN vocoder
!mkdir -p /content/StableTTS/vocoders/pretrained
!wget "https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt" -O /content/StableTTS/vocoders/pretrained/firefly-gan-base-generator.ckpt

Now upload your WAV files into the `wavs` folder created inside the `StableTTS/{project_name}` folder on your Drive.
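
Optionally, you can sanity-check the upload before preprocessing. This is a minimal sketch using only the standard library `wave` module; `summarize_wavs` is a hypothetical helper and `my_voice` stands in for your `project_name`. Files that `wave` cannot parse (for example 32-bit float WAVs) are listed as unreadable, which does not necessarily mean StableTTS cannot load them.

```python
# Hypothetical helper: list the .wav files in a folder and report the sample
# rate and duration of each, using only the standard library.
import wave
from pathlib import Path


def summarize_wavs(wav_dir):
    """Return (readable, unreadable): readable maps file name -> (rate, seconds)."""
    readable, unreadable = {}, []
    for path in sorted(Path(wav_dir).glob("*.wav")):
        try:
            with wave.open(str(path), "rb") as w:
                rate = w.getframerate()
                readable[path.name] = (rate, w.getnframes() / rate)
        except (wave.Error, EOFError, ZeroDivisionError):
            # Not PCM WAV as far as the stdlib can tell, truncated, or a bad header.
            unreadable.append(path.name)
    return readable, unreadable


if __name__ == "__main__":
    # Placeholder path: replace my_voice with your project_name.
    ok, bad = summarize_wavs("/content/drive/MyDrive/StableTTS/my_voice/wavs")
    minutes = sum(seconds for _, seconds in ok.values()) / 60
    print(f"{len(ok)} readable files ({minutes:.1f} min), unreadable: {bad}")
```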

Your filelist should already be adapted so that each entry points to the full path of its WAV file, not just "wavs/xxx.wav" as in Tacotron2.
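
If your list is still in Tacotron2 form, a small script can rewrite it. This is a sketch that assumes one `path|text` entry per line, split on the first `|` (so any extra fields stay in the text); `absolutize_filelist`, `list.txt` and `list_full.txt` are hypothetical names, and `my_voice` stands in for your `project_name`.

```python
# Hypothetical converter: turn "wavs/xxx.wav|text" lines into lines whose
# audio path is absolute. Assumes one "path|text" entry per line.
import os


def absolutize_filelist(lines, wav_dir):
    """Yield 'path|text' lines with each path re-rooted under wav_dir."""
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        if "|" not in line:
            raise ValueError(f"expected 'path|text', got: {line!r}")
        path, text = line.split("|", 1)
        # Keep only the file name and join it onto the real WAV folder.
        yield f"{os.path.join(wav_dir, os.path.basename(path))}|{text}"


if __name__ == "__main__":
    root = "/content/drive/MyDrive/StableTTS/my_voice"  # placeholder: use your project_name
    src, dst = f"{root}/list.txt", f"{root}/list_full.txt"  # hypothetical file names
    if os.path.exists(src):
        with open(src, encoding="utf-8") as f_in, open(dst, "w", encoding="utf-8") as f_out:
            for entry in absolutize_filelist(f_in, f"{root}/wavs"):
                f_out.write(entry + "\n")
```

The converted file is what `input_filelist_path` should point to in the next step.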

Edit preprocess.py and, under `DataConfig`, set the path of your converted .txt filelist, the output filelist, the output feature folder, and the language.

**Variables that need to be edited**

*If you followed this notebook step by step, copy the quoted paths below into those variables, replacing {project_name} with the name you chose earlier :3*

```
*   input_filelist_path: your converted .txt filelist inside "/content/drive/MyDrive/StableTTS/{project_name}/"
*   output_filelist_path: "/content/drive/MyDrive/StableTTS/{project_name}/filelist.json"
*   output_feature_path: "/content/drive/MyDrive/StableTTS/{project_name}"
*   language: english, japanese or chinese, depending on your dataset
```
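
For reference, here is roughly what the edited values look like, written as a standalone dataclass. This is a hypothetical mirror of the four fields above, not the actual `DataConfig` from preprocess.py (which may contain other fields); `my_voice` and `list_full.txt` are placeholders.

```python
# Hypothetical stand-in for the DataConfig fields listed above; copy the
# string values (not this class) into preprocess.py.
from dataclasses import dataclass

ROOT = "/content/drive/MyDrive/StableTTS/my_voice"  # placeholder: use your project_name


@dataclass
class DataConfig:
    input_filelist_path: str = f"{ROOT}/list_full.txt"  # hypothetical filelist name
    output_filelist_path: str = f"{ROOT}/filelist.json"
    output_feature_path: str = ROOT
    language: str = "english"  # or "japanese" / "chinese"


print(DataConfig())
```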


In [None]:
!python preprocess.py

Edit config.py to match your setup.

**Variables that need to be modified under `TrainConfig`**

*If you followed this notebook step by step, copy the quoted paths below into those variables, replacing {project_name} with the name you chose earlier :3*

```
*   train_dataset_path: path to filelist.json, "/content/drive/MyDrive/StableTTS/{project_name}/filelist.json"
*   test_dataset_path: path to filelist.json, "/content/drive/MyDrive/StableTTS/{project_name}/filelist.json" (not used)
*   model_save_path: folder where checkpoints are saved, "/content/drive/MyDrive/StableTTS/{project_name}/checkpoints"
*   log_dir: folder where logs are saved, "/content/drive/MyDrive/StableTTS/{project_name}/logs"
*   log_interval: logs are written every x epochs
*   save_interval: same as log_interval, but for checkpoints
```
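
Before launching training, a quick check can catch path mistakes early. This is a hypothetical helper, not part of StableTTS: it reports a missing `filelist.json` (meaning preprocess.py has not written it at the configured path) and otherwise makes sure the checkpoint and log folders exist. `my_voice` stands in for your `project_name`.

```python
# Hypothetical pre-flight check mirroring the TrainConfig paths above.
import os


def preflight(train_dataset_path, model_save_path, log_dir):
    """Return a list of problems; create the output folders only if the filelist exists."""
    if not os.path.isfile(train_dataset_path):
        return [f"{train_dataset_path} not found: run preprocess.py first"]
    for folder in (model_save_path, log_dir):
        os.makedirs(folder, exist_ok=True)  # no-op if the folder already exists
    return []


if __name__ == "__main__":
    root = "/content/drive/MyDrive/StableTTS/my_voice"  # placeholder: use your project_name
    problems = preflight(f"{root}/filelist.json", f"{root}/checkpoints", f"{root}/logs")
    print("\n".join(problems) or "Paths look OK.")
```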


**Variables you _might_ want to modify later if needed**


```
*   batch_size
*   learning_rate
*   num_epochs
```
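
When tuning these, it helps to know how many optimizer steps a run amounts to. A minimal sketch of the arithmetic, assuming the last incomplete batch is kept (if the training DataLoader uses `drop_last=True`, use `num_samples // batch_size` instead); the dataset size below is made up:

```python
# Rough step count for a training run; training_steps is a hypothetical helper.
import math


def training_steps(num_samples, batch_size, num_epochs):
    """Return (steps_per_epoch, total_steps), keeping the last partial batch."""
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return steps_per_epoch, steps_per_epoch * num_epochs


print(training_steps(num_samples=1000, batch_size=32, num_epochs=100))  # → (32, 3200)
```

Halving `batch_size` roughly doubles the steps per epoch, so `learning_rate` and `num_epochs` are usually worth revisiting together with it.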

In [None]:
#@markdown Start training!
%load_ext tensorboard
%tensorboard --logdir /content/drive/MyDrive/StableTTS/{project_name}/logs
!python train.py
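
Once training has saved a few checkpoints, the inference cell needs a path for `tts_model_path`. A sketch for finding the newest one, assuming checkpoints are named `checkpoint_<number>.pt` like the pretrained `checkpoint_0.pt` downloaded earlier; `latest_checkpoint` is a hypothetical helper and `my_voice` stands in for your `project_name`.

```python
# Hypothetical helper: return the highest-numbered checkpoint_<n>.pt in a folder.
import re
from pathlib import Path


def latest_checkpoint(ckpt_dir):
    """Return the path of the highest-numbered checkpoint_*.pt, or None if there is none."""
    best, best_num = None, -1
    for path in Path(ckpt_dir).glob("checkpoint_*.pt"):
        match = re.fullmatch(r"checkpoint_(\d+)\.pt", path.name)
        if match and int(match.group(1)) > best_num:
            best, best_num = str(path), int(match.group(1))
    return best


if __name__ == "__main__":
    # Placeholder path: replace my_voice with your project_name.
    print(latest_checkpoint("/content/drive/MyDrive/StableTTS/my_voice/checkpoints"))
```

Comparing the parsed numbers (rather than sorting file names as strings) keeps `checkpoint_12.pt` ahead of `checkpoint_3.pt`.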

In [None]:
#@markdown Inference Tab!
from IPython.display import Audio, display
import torch

from api import StableTTSAPI

device = 'cuda' if torch.cuda.is_available() else 'cpu'

tts_model_path = '' #@param {type:"string"}
vocoder_model_path = '/content/StableTTS/vocoders/pretrained/firefly-gan-base-generator.ckpt' # path to vocoder checkpoint
vocoder_type = 'ffgan' # ffgan or vocos

# vocoder_model_path = './vocoders/pretrained/vocos.pt'
# vocoder_type = 'vocos'

model = StableTTSAPI(tts_model_path, vocoder_model_path, vocoder_type)
model.to(device)

tts_param, vocoder_param = model.get_params()
print(f'tts_param: {tts_param}, vocoder_param: {vocoder_param}')

In [None]:
text = 'test' #@param {type:"string"}
ref_audio = 'path' #@param {type:"string"}
language = 'english' #@param ["english", "japanese", "chinese"]
solver = 'dopri5' #@param ["dopri5", "euler", "midpoint"]
steps = 30 #@param {type:"slider", min:0, max:50, step:1}
cfg = 3  #@param {type:"slider", min:0, max:10, step:1}

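# The two 1s are temperature and length_scale (names as in api.py's inference signature; check your version)
audio_output, mel_output = model.inference(text, ref_audio, language, steps, 1, 1, solver, cfg)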

display(Audio(ref_audio))
display(Audio(audio_output, rate=model.mel_config.sample_rate))
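
To keep a result, you can write it to a WAV file. A sketch using only the standard library; `save_wav` is a hypothetical helper, and it assumes `audio_output` holds mono float samples in [-1, 1] that `.squeeze().tolist()` can turn into a flat list (true for a CPU tensor or NumPy array with a single non-singleton axis; move a GPU tensor to the CPU first).

```python
# Hypothetical helper: write mono float samples in [-1, 1] to a 16-bit PCM WAV.
import sys
import wave
from array import array


def save_wav(samples, sample_rate, path):
    # Clip each sample to [-1, 1] and scale it to the int16 range.
    pcm = array("h", (int(max(-1.0, min(1.0, s)) * 32767) for s in samples))
    if sys.byteorder == "big":
        pcm.byteswap()  # WAV sample data is little-endian
    with wave.open(path, "wb") as w:
        w.setnchannels(1)
        w.setsampwidth(2)  # 2 bytes = 16-bit samples
        w.setframerate(sample_rate)
        w.writeframes(pcm.tobytes())


# Usage in the notebook (output path is a placeholder):
# save_wav(audio_output.squeeze().tolist(), model.mel_config.sample_rate,
#          "/content/drive/MyDrive/StableTTS/my_voice/output.wav")
```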