Support for multi-speaker (r9y9#10)
* Trying this

* Support for speaker adaptation

* eval more

* New preprocess

* Try this

* this

* remove debug

* trying this

* a

* a

* Does this work?

* Fix order

* hua

* dropout

* Add comment

* This should be safe

* Be explicit to avoid future bug

* Add README for VCTK preprocessing

* rename and tweak

* Change names: use preset="deepvoice3_ljspeech" instead of
use_preset=True

* Revert "This should be safe"

This reverts commit 54fcbec.

* remove latest and set default to deepvoice3

* minor fix

* Add brief guide for speaker adaptation and building multi-speaker model

* key/value projections switch

* Add more doc

* fix up

* never mind

* this

* Revert "this"

This reverts commit ab2fc5e.

* Cleanup hparams and add comment

* maybe

* yeah

* Fix

* try this again

* Revert "try this again"

This reverts commit dc5d38e.
r9y9 committed Dec 21, 2017
1 parent a934acd commit 0421749
Showing 14 changed files with 1,875 additions and 132 deletions.
91 changes: 76 additions & 15 deletions README.md
@@ -1,4 +1,4 @@
# deepvoice3_pytorch
# Deepvoice3_pytorch

[![Build Status](https://travis-ci.org/r9y9/deepvoice3_pytorch.svg?branch=master)](https://travis-ci.org/r9y9/deepvoice3_pytorch)

@@ -12,11 +12,11 @@ Current progress and planned TO-DOs can be found at [#1](https://github.com/r9y9
## Highlights

- Convolutional sequence-to-sequence model with attention for text-to-speech synthesis
- Preprocessor for [LJSpeech (en)](https://keithito.com/LJ-Speech-Dataset/) and [JSUT (jp)](https://sites.google.com/site/shinnosuketakamichi/publication/jsut) datasets
- Multi-speaker and single speaker versions of DeepVoice3
- Audio samples and pre-trained models
- Preprocessor for [LJSpeech (en)](https://keithito.com/LJ-Speech-Dataset/), [JSUT (jp)](https://sites.google.com/site/shinnosuketakamichi/publication/jsut) and [VCTK](http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html) datasets
- Language-dependent frontend text processor for English and Japanese

Support for multi-speaker models is planned but not completed yet.

## Audio samples

- [DeepVoice3] Samples from the model trained on LJ Speech Dataset: https://www.dropbox.com/sh/uq4tsfptxt0y17l/AADBL4LsPJRP2PjAAJRSH5eta?dl=0
@@ -26,29 +26,33 @@

| URL | Model | Data | Hyper parameters | Git commit | Steps |
|-----|------------|----------|--------------------------------------------------|----------------------|--------|
| [link](https://www.dropbox.com/s/4r207fq6s8gt2sl/20171213_deepvoice3_checkpoint_step00021000.pth?dl=0) | DeepVoice3 | LJSpeech | `--hparams="builder=deepvoice3,use_preset=True"` | [4357976](https://github.com/r9y9/deepvoice3_pytorch/tree/43579764f35de6b8bac2b18b52a06e4e11b705b2)| 210000 |
| [link](https://www.dropbox.com/s/j8ywsvs3kny0s0x/20171129_nyanko_checkpoint_step000585000.pth?dl=0) | Nyanko | LJSpeech | `--hparams="builder=nyanko,use_preset=True"` | [ba59dc7](https://github.com/r9y9/deepvoice3_pytorch/tree/ba59dc75374ca3189281f6028201c15066830116) | 585000 |
| [link](https://www.dropbox.com/s/4r207fq6s8gt2sl/20171213_deepvoice3_checkpoint_step00021000.pth?dl=0) | DeepVoice3 | LJSpeech | `builder=deepvoice3,preset=deepvoice3_ljspeech` | [4357976](https://github.com/r9y9/deepvoice3_pytorch/tree/43579764f35de6b8bac2b18b52a06e4e11b705b2)| 210000 |
| [link](https://www.dropbox.com/s/j8ywsvs3kny0s0x/20171129_nyanko_checkpoint_step000585000.pth?dl=0) | Nyanko | LJSpeech | `builder=nyanko,preset=nyanko_ljspeech` | [ba59dc7](https://github.com/r9y9/deepvoice3_pytorch/tree/ba59dc75374ca3189281f6028201c15066830116) | 585000 |
| [TODO](https://www.dropbox.com/s/j8ywsvs3kny0s0x/20171129_nyanko_checkpoint_step000585000.pth?dl=0) | Multi-speaker DeepVoice3 | VCTK | `builder=deepvoice3_multispeaker,preset=deepvoice3_vctk` | [TODO](https://github.com/r9y9/deepvoice3_pytorch/tree/ba59dc75374ca3189281f6028201c15066830116) | 300000 |

See the `Synthesize from a checkpoint` section in the README for how to generate speech samples. Please make sure that you are on the specific git commit noted above.

## Notes on hyper parameters

- Default hyper parameters, used during the preprocessing/training/synthesis stages, are tuned for English TTS using the LJSpeech dataset. You will have to change some of the parameters if you want to try other datasets. See `hparams.py` for details.
- `builder` specifies which model you want to use. `deepvoice3` [1] and `nyanko` [2] are surpprted.
- `presets` represents hyper parameters known to work well for LJSpeech dataset from my experiments. Before you try to find your best parameters, I would recommend you to try those presets by setting `use_preset=True`. E.g,
- `builder` specifies which model you want to use. `deepvoice3`, `deepvoice3_multispeaker` [1] and `nyanko` [2] are supported.
- `presets` represents hyper parameters known to work well for a particular dataset/model from my experiments. Before searching for your own best parameters, I would recommend trying these presets by setting `preset=${name}`; a sketch of how a preset might be applied follows this list. E.g., for LJSpeech, you can try either
```
python train.py --data-root=./data/ljspeech --checkpoint-dir=checkpoints_deepvoice3 \
--hparams="use_preset=True,builder=deepvoice3" \
--hparams="builder=deepvoice3,preset=deepvoice3_ljspeech" \
--log-event-path=log/deepvoice3_preset
```
or
```
python train.py --data-root=./data/ljspeech --checkpoint-dir=checkpoints_nyanko \
--hparams="use_preset=True,builder=nyanko" \
--hparams="builder=nyanko,preset=nyanko_ljspeech" \
--log-event-path=log/nyanko_preset
```
- Hyper parameters described in DeepVoice3 paper for single speaker didn't work for LJSpeech dataset, so I changed a few things. Add dilated convolution, more channels, more layers and add guided loss, etc. See code for details.

- The hyper parameters described in the DeepVoice3 paper for the single-speaker model didn't work for the LJSpeech dataset, so I changed a few things: dilated convolutions, more channels, more layers, a guided attention loss, etc. See the code for details. The same changes are also applied to the multi-speaker model.
- Multiple attention layers are hard to learn. Empirically, one or two (first and last) attention layers seem to be enough.
- With guided attention (see https://arxiv.org/abs/1710.08969), alignments become monotonic more quickly and reliably when using multiple attention layers. With guided attention, I can confirm that five attention layers become monotonic, though I could not get speech quality improvements.
- Binary divergence (described in https://arxiv.org/abs/1710.08969) seems to stabilize training, particularly for deep (> 10 layers) networks.
- Adam with step LR decay works. However, for deeper networks, I find Adam + Noam's LR scheduler more stable.
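
For a rough sense of what `preset=${name}` does, here is a minimal sketch of how a named preset could be merged into the default hyper parameters. This is an illustration only; the `PRESETS` dict, the key names and the values below are hypothetical stand-ins for whatever `hparams.py` actually defines.

```
# Hypothetical preset resolution; not the repository's actual hparams.py code.
PRESETS = {
    "deepvoice3_ljspeech": {"builder": "deepvoice3", "outputs_per_step": 4},
    "nyanko_ljspeech": {"builder": "nyanko", "outputs_per_step": 1},
    "deepvoice3_vctk": {"builder": "deepvoice3_multispeaker", "n_speakers": 108},
}

def apply_preset(hparams, name):
    """Overwrite default hyper parameters with the values of the named preset."""
    for key, value in PRESETS[name].items():
        setattr(hparams, key, value)
    return hparams
```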

## Requirements

@@ -79,6 +83,7 @@ pip install -e ".[jp]"
### 0. Download dataset

- LJSpeech (en): https://keithito.com/LJ-Speech-Dataset/
- VCTK (en): http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html
- JSUT (jp): https://sites.google.com/site/shinnosuketakamichi/publication/jsut

### 1. Preprocessing
@@ -89,7 +94,13 @@ Preprocessing can be done by `preprocess.py`. Usage is:
python preprocess.py ${dataset_name} ${dataset_path} ${out_dir}
```

Supported `${dataset_name}`s for now are `ljspeech` and `jsut`. Suppose you will want to preprocess LJSpeech dataset and have it in `~/data/LJSpeech-1.0`, then you can preprocess data by:
Supported `${dataset_name}`s for now are

- `ljspeech` (en, single speaker)
- `vctk` (en, multi-speaker)
- `jsut` (jp, single speaker)

Suppose you want to preprocess the LJSpeech dataset and have it in `~/data/LJSpeech-1.0`; then you can preprocess the data by:

```
python preprocess.py ljspeech ~/data/LJSpeech-1.0/ ./data/ljspeech
@@ -108,15 +119,15 @@ python train.py --data-root=${data-root} --hparams="parameters you want to override"
Suppose you want to build a DeepVoice3-style model using the LJSpeech dataset with default hyper parameters; then you can train your model by:

```
python train.py --data-root=./data/ljspeech/ --hparams="use_preset=True,builder=deepvoice3"
python train.py --data-root=./data/ljspeech/ --hparams="builder=deepvoice3,preset=deepvoice3_ljspeech"
```

Model checkpoints (.pth) and alignments (.png) are saved in the `./checkpoints` directory every 5000 steps by default.

If you are building a Japanese TTS model, then for example,

```
python train.py --data-root=./data/jsut --hparams="frontend=jp" --hparams="use_preset=True,builder=deepvoice3"
python train.py --data-root=./data/jsut --hparams="frontend=jp" --hparams="builder=deepvoice3,preset=deepvoice3_ljspeech"
```

`frontend=jp` tells the training script to use the Japanese text processing frontend. The default is `en`, which uses the English text processing frontend.
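
The frontend switch simply selects a language-specific text processing module at run time. A minimal sketch of how that could look is below; the module path and the `text_to_sequence` call are assumptions for illustration, not a documented API.

```
# Hypothetical frontend selection; the module layout and function name are assumptions.
import importlib

def get_frontend(name="en"):
    # e.g. deepvoice3_pytorch.frontend.en or deepvoice3_pytorch.frontend.jp
    return importlib.import_module("deepvoice3_pytorch.frontend." + name)

frontend = get_frontend("jp")
sequence = frontend.text_to_sequence("こんにちは")  # assumed function name
```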
@@ -148,6 +159,56 @@ Once upon a time there was a dear little girl who was loved by every one who looked
A text-to-speech synthesis system typically consists of multiple stages, such as a text analysis frontend, an acoustic model and an audio synthesis module.
```

## Advanced usage

### Multi-speaker model

Currently VCTK is the only supported dataset for building a multi-speaker model. Since some audio samples in VCTK have long silences that affect performance, it's recommended to do phoneme alignment and remove silences according to [vctk_preprocess/README.md](vctk_preprocess/README.md).

Once you have phoneme alignment for each utterance, you can extract features by:

```
python preprocess.py vctk ${your_vctk_root_path} ./data/vctk
```

Now that you have the data prepared, you can train a multi-speaker version of DeepVoice3 by:

```
python train.py --data-root=./data/vctk --checkpoint-dir=checkpoints_vctk \
--hparams="preset=deepvoice3_vctk,builder=deepvoice3_multispeaker" \
--log-event-path=log/deepvoice3_multispeaker_vctk_preset
```

If you want to reuse a learned embedding from another dataset, you can do this instead:

```
python train.py --data-root=./data/vctk --checkpoint-dir=checkpoints_vctk \
--hparams="preset=deepvoice3_vctk,builder=deepvoice3_multispeaker" \
--log-event-path=log/deepvoice3_multispeaker_vctk_preset \
--load-embedding=20171213_deepvoice3_checkpoint_step000210000.pth
```

This may improve training speed a bit.
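
A minimal sketch of what `--load-embedding` might do, assuming a standard PyTorch checkpoint with a `state_dict` entry and a text-embedding parameter whose name is guessed below:

```
# Hypothetical embedding transfer; the checkpoint layout and key name are assumptions.
import torch

def load_text_embedding(model, checkpoint_path,
                        key="seq2seq.encoder.embed_tokens.weight"):
    state = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
    with torch.no_grad():
        model.state_dict()[key].copy_(state[key])  # copy only the embedding weights
    return model
```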

### Speaker adaptation

If you have very limited data, you can consider fine-tuning a pre-trained model. For example, using a model pre-trained on LJSpeech, you can adapt it to data from VCTK speaker `p225` (30 mins) with the following command:

```
python train.py --data-root=./data/vctk --checkpoint-dir=checkpoints_vctk_adaptation \
--hparams="builder=deepvoice3,preset=deepvoice3_ljspeech" \
--log-event-path=log/deepvoice3_vctk_adaptation \
--restore-parts="20171213_deepvoice3_checkpoint_step000210000.pth" \
--speaker-id=0
```

In my experience, this reaches reasonable speech quality much more quickly than training the model from scratch.

There are two important options used above:

- `--restore-parts=<N>`: Specifies where to load model parameters from. The differences from `--checkpoint=<N>` are: 1) `--restore-parts=<N>` ignores all incompatible parameters, while `--checkpoint=<N>` doesn't; 2) `--restore-parts=<N>` tells the trainer to start from step 0, while `--checkpoint=<N>` tells the trainer to continue from the last saved step. `--checkpoint=<N>` is fine if you are using exactly the same model and simply continuing training, but `--restore-parts=<N>` is useful if you want to customize your model architecture while still taking advantage of a pre-trained model (a rough sketch of this behavior follows this list).
- `--speaker-id=<N>`: Specifies which speaker's data is used for training. This should only be specified if you are using a multi-speaker dataset. For VCTK, speaker ids are assigned automatically and incrementally (0, 1, ..., 107) according to `speaker_info.txt` in the dataset.
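
A rough sketch of the `--restore-parts` behavior described above, assuming a standard PyTorch `state_dict` checkpoint; this illustrates the idea, not the trainer's actual code:

```
# Hypothetical partial restore: copy parameters whose name and shape match,
# silently skip everything else, and do not touch the optimizer or global step.
import torch

def restore_parts(checkpoint_path, model):
    state = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
    model_dict = model.state_dict()
    with torch.no_grad():
        for name, param in state.items():
            if name in model_dict and model_dict[name].shape == param.shape:
                model_dict[name].copy_(param)
```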

## Acknowledgements

Part of the code was adapted from the following projects:
71 changes: 39 additions & 32 deletions deepvoice3_pytorch/builder.py
@@ -23,6 +23,8 @@ def deepvoice3(n_vocab, embed_dim=256, mel_dim=80, linear_dim=513, r=4,
freeze_embedding=False,
window_ahead=3,
window_backward=1,
key_projection=False,
value_projection=False,
):
"""Build deepvoice3
"""
@@ -59,6 +61,8 @@ def deepvoice3(n_vocab, embed_dim=256, mel_dim=80, linear_dim=513, r=4,
use_memory_mask=use_memory_mask,
window_ahead=window_ahead,
window_backward=window_backward,
key_projection=key_projection,
value_projection=value_projection,
)

seq2seq = AttentionSeq2Seq(encoder, decoder)
@@ -107,12 +111,16 @@ def nyanko(n_vocab, embed_dim=128, mel_dim=80, linear_dim=513, r=1,
freeze_embedding=False,
window_ahead=3,
window_backward=1,
key_projection=False,
value_projection=False,
):
from deepvoice3_pytorch.nyanko import Encoder, Decoder, Converter
assert encoder_channels == decoder_channels

if n_speakers != 1:
raise ValueError("Multi-speaker is not supported")
if not (downsample_step == 4 and r == 1):
raise RuntimeError("Not supported. You need to change hardcoded parameters")
raise ValueError("Not supported. You need to change hardcoded parameters")

# Seq2seq
encoder = Encoder(
@@ -133,6 +141,8 @@
use_memory_mask=use_memory_mask,
window_ahead=window_ahead,
window_backward=window_backward,
key_projection=key_projection,
value_projection=value_projection,
)

seq2seq = AttentionSeq2Seq(encoder, decoder)
@@ -159,26 +169,28 @@ def nyanko(n_vocab, embed_dim=128, mel_dim=80, linear_dim=513, r=1,
return model


def deepvoice3_vctk(n_vocab, embed_dim=256, mel_dim=80, linear_dim=513, r=4,
downsample_step=1,
n_speakers=1, speaker_embed_dim=16, padding_idx=0,
dropout=(1 - 0.95), kernel_size=5,
encoder_channels=128,
decoder_channels=256,
converter_channels=256,
query_position_rate=1.0,
key_position_rate=1.29,
use_memory_mask=False,
trainable_positional_encodings=False,
force_monotonic_attention=True,
use_decoder_state_for_postnet_input=True,
max_positions=512,
embedding_weight_std=0.1,
speaker_embedding_weight_std=0.01,
freeze_embedding=False,
window_ahead=3,
window_backward=1,
):
def deepvoice3_multispeaker(n_vocab, embed_dim=256, mel_dim=80, linear_dim=513, r=4,
downsample_step=1,
n_speakers=1, speaker_embed_dim=16, padding_idx=0,
dropout=(1 - 0.95), kernel_size=5,
encoder_channels=128,
decoder_channels=256,
converter_channels=256,
query_position_rate=1.0,
key_position_rate=1.29,
use_memory_mask=False,
trainable_positional_encodings=False,
force_monotonic_attention=True,
use_decoder_state_for_postnet_input=True,
max_positions=512,
embedding_weight_std=0.1,
speaker_embedding_weight_std=0.01,
freeze_embedding=False,
window_ahead=3,
window_backward=1,
key_projection=True,
value_projection=True,
):
"""Build multi-speaker deepvoice3
"""
from deepvoice3_pytorch.deepvoice3 import Encoder, Decoder, Converter
@@ -194,8 +206,8 @@ def deepvoice3_vctk(n_vocab, embed_dim=256, mel_dim=80, linear_dim=513, r=4,
dropout=dropout, max_positions=max_positions,
embedding_weight_std=embedding_weight_std,
# (channels, kernel_size, dilation)
convolutions=[(h, k, 1), (h, k, 3), (h, k, 9),
(h, k, 1), (h, k, 3), (h, k, 9),
convolutions=[(h, k, 1), (h, k, 3), (h, k, 9), (h, k, 27),
(h, k, 1), (h, k, 3), (h, k, 9), (h, k, 27),
(h, k, 1), (h, k, 3)],
)

@@ -205,18 +217,17 @@ def deepvoice3_vctk(n_vocab, embed_dim=256, mel_dim=80, linear_dim=513, r=4,
n_speakers=n_speakers, speaker_embed_dim=speaker_embed_dim,
dropout=dropout, max_positions=max_positions,
preattention=[(h, k, 1)],
convolutions=[(h, k, 1), (h, k, 3), (h, k, 9),
(h, k, 1), (h, k, 3), (h, k, 9),
convolutions=[(h, k, 1), (h, k, 3), (h, k, 9), (h, k, 27),
(h, k, 1)],
attention=[True, False, False,
False, False, False,
True],
attention=[True, False, False, False, False],
force_monotonic_attention=force_monotonic_attention,
query_position_rate=query_position_rate,
key_position_rate=key_position_rate,
use_memory_mask=use_memory_mask,
window_ahead=window_ahead,
window_backward=window_backward,
key_projection=key_projection,
value_projection=value_projection,
)

seq2seq = AttentionSeq2Seq(encoder, decoder)
@@ -245,7 +256,3 @@ def deepvoice3_vctk(n_vocab, embed_dim=256, mel_dim=80, linear_dim=513, r=4,
freeze_embedding=freeze_embedding)

return model


# TODO:
latest = nyanko
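
For reference, a hedged sketch of how the renamed multi-speaker builder with its new key/value projection switches might be invoked; the import path and keyword arguments follow this diff, but the values are illustrative placeholders, not the VCTK preset.

```
# Hypothetical builder usage; values are placeholders, not the repository's preset values.
from deepvoice3_pytorch.builder import deepvoice3_multispeaker

model = deepvoice3_multispeaker(
    n_vocab=256,            # placeholder; normally taken from the text frontend
    n_speakers=108,         # VCTK has speakers 0..107 per the README
    speaker_embed_dim=16,
    key_projection=True,    # switch added in this commit
    value_projection=True,  # switch added in this commit
)
```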
