Have you tested PL-BERT with FastSpeech 2? #3

Hi @yl4579,
The MP-BERT paper publishes its results with FastSpeech 2, so to make that comparison, have you integrated PL-BERT with FastSpeech 2? If not, I will test it on my side, as I am interested in how PL-BERT behaves in comparison to MP-BERT with FastSpeech 2.
Also, let me know when the code is available if you are planning to open-source it. Otherwise, I will try to implement the paper on my side.
Thanks

Comments
No, I only tested it on StyleTTS, if you mean testing for MOS comparison. However, I did train FS2 and VITS with PL-BERT, and they do sound better than without PL-BERT, though I don't see much point in evaluating them formally: they are not state-of-the-art models, and MOS comparison with American speakers costs a lot of money to run. Do you need the code for both MP-BERT and PL-BERT, or just PL-BERT? It might be difficult to give you the pre-trained MP-BERT models because I have deleted them to save space. I still have the pre-trained BPE tokenizers and the code, but not the pre-trained models. I can provide the pre-trained models for PL-BERT along with its training code; I will push them to this repo in the next week or so.
Thanks @yl4579
Hi @yl4579, could you share the MP-BERT BPE tokenizer?
Yes, I can share the MP-BERT BPE tokenizer here: https://drive.google.com/file/d/1h3WGT0Vb2x9ft4kDZtCWpYy9z35wMJ6q/view?usp=sharing. Example use case:

```python
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer-sup-30000.json")
```
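For illustration (the phoneme string below is made up), encoding a space-joined phoneme sequence exposes the `.tokens` and `.ids` fields that the data loader below relies on:

```python
# Hypothetical input: the phonemes of a sentence, joined by spaces.
# The actual phoneme alphabet depends on the phonemizer used.
encoded = tokenizer.encode("ðɪs ɪz ə tɛst")
print(encoded.tokens)  # sup-phoneme (BPE) units, e.g. ['ðɪs', 'ɪz', ...]
print(encoded.ids)     # integer ids of those units in the 30000-entry vocabulary
```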
Data loader:

```python
# coding: utf-8
import random
import logging

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# TextCleaner maps each phoneme character to an integer index;
# it is not defined in this snippet (see text_utils.py in the repo).
from text_utils import TextCleaner

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

np.random.seed(1)
random.seed(1)


class FilePathDataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.data = df
        self.max_mel_length = 512
        self.word_mask_prob = 0.15
        self.phoneme_mask_prob = 0.1
        self.replace_prob = 0.2
        self.text_cleaner = TextCleaner()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        phonemes = self.data[idx]['phonemes']

        words = []
        word_labels = []
        labels = ""
        phoneme = ""
        phoneme_list = ' '.join(phonemes)
        masked_index = []

        # uses the module-level BPE tokenizer loaded above
        encoded = tokenizer.encode(phoneme_list)
        tokens = encoded.tokens
        ids = encoded.ids

        for token_id, token in zip(ids, tokens):
            labels += token + " "
            # BERT-style masking at the sup-phoneme level
            if np.random.rand() < self.word_mask_prob:
                if np.random.rand() < self.replace_prob:
                    if np.random.rand() < 0.5:
                        # replace with random phonemes and a random unit id
                        phoneme += ''.join([phoneme_list[np.random.randint(0, len(phoneme_list))]
                                            for _ in range(len(token))])
                        fake_word = np.random.randint(5, 29999)
                    else:
                        # keep the unit unchanged
                        phoneme += token
                        fake_word = token_id
                else:
                    # mask out the whole unit
                    phoneme += 'M' * len(token)
                    fake_word = 4
                masked_index.extend((np.arange(len(phoneme) - len(token), len(phoneme))).tolist())
            else:
                phoneme += token
                fake_word = token_id

            # the unit id is repeated for each phoneme it covers, then a separator (2)
            word_labels.extend([token_id] * len(token))
            word_labels.append(2)
            words.extend([fake_word] * len(token))
            words.append(2)
            phoneme += " "

        mel_length = len(phoneme)
        masked_idx = np.array(masked_index)
        masked_index = []
        if mel_length > self.max_mel_length:
            # crop a random window and remap the masked positions into it
            random_start = np.random.randint(0, mel_length - self.max_mel_length)
            phoneme = phoneme[random_start:random_start + self.max_mel_length]
            words = words[random_start:random_start + self.max_mel_length]
            word_labels = word_labels[random_start:random_start + self.max_mel_length]
            labels = labels[random_start:random_start + self.max_mel_length]

            for m in masked_idx:
                if m >= random_start and m < random_start + self.max_mel_length:
                    masked_index.append(m - random_start)
        else:
            # keep all masked positions when no cropping is needed
            masked_index = masked_idx.tolist()

        phoneme = self.text_cleaner(phoneme)
        labels = self.text_cleaner(labels)

        assert len(phoneme) == len(words)
        assert len(phoneme) == len(labels)
        assert len(phoneme) == len(word_labels)

        phonemes = torch.LongTensor(phoneme)
        labels = torch.LongTensor(labels)
        words = torch.LongTensor(words)
        word_labels = torch.LongTensor(word_labels)

        return phonemes, words, word_labels, labels, masked_index
```
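`TextCleaner` isn't included in the snippet above. As a rough sketch of what it does (an assumption modeled on the repo's `text_utils.py`; the full inventory there appears to have 178 symbols, matching `num_tokens=178` in the model below), it maps each phoneme character to an integer id:

```python
# Minimal sketch, assuming a fixed symbol inventory (not the repo's actual list).
# 'M' must be in the inventory since the masking above writes 'M' characters.
PAD = "$"
SYMBOLS = [PAD] + list(';:,.!?¡¿ ') + list("Mabcdefghijklmnopqrstuvwxyz")  # plus IPA symbols in practice

class TextCleaner:
    def __init__(self):
        self.word_index_dictionary = {s: i for i, s in enumerate(SYMBOLS)}

    def __call__(self, text):
        # characters outside the inventory are silently skipped in this sketch
        return [self.word_index_dictionary[c] for c in text if c in self.word_index_dictionary]
```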
```python
class Collater(object):
    """Pads a batch of variable-length sequences to the longest item in the batch."""

    def __init__(self, return_wave=False):
        self.text_pad_index = 0
        self.return_wave = return_wave

    def __call__(self, batch):
        # each batch item: (phonemes, words, word_labels, labels, masked_index)
        batch_size = len(batch)

        # sort by sequence length, longest first
        lengths = [b[1].shape[0] for b in batch]
        batch_indexes = np.argsort(lengths)[::-1]
        batch = [batch[bid] for bid in batch_indexes]

        max_text_length = max([b[1].shape[0] for b in batch])

        word_labels = torch.zeros((batch_size, max_text_length)).long()
        words = torch.zeros((batch_size, max_text_length)).long()
        labels = torch.zeros((batch_size, max_text_length)).long()
        phonemes = torch.zeros((batch_size, max_text_length)).long()
        input_lengths = []
        masked_indices = []

        for bid, (phoneme, word, word_label, label, masked_index) in enumerate(batch):
            text_size = phoneme.size(0)
            words[bid, :text_size] = word
            word_labels[bid, :text_size] = word_label
            labels[bid, :text_size] = label
            phonemes[bid, :text_size] = phoneme
            input_lengths.append(text_size)
            masked_indices.append(masked_index)

        return words, word_labels, labels, phonemes, input_lengths, masked_indices
```
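A toy sanity check of the collater (made-up tensors, not real data), showing the zero-padding to the longest item in the batch:

```python
# Two items of lengths 7 and 5, each (phonemes, words, word_labels, labels, masked_index)
item_a = (torch.ones(7).long(), torch.ones(7).long(), torch.ones(7).long(), torch.ones(7).long(), [1, 2])
item_b = (torch.ones(5).long(), torch.ones(5).long(), torch.ones(5).long(), torch.ones(5).long(), [0])
words, word_labels, labels, phonemes, input_lengths, masked_indices = Collater()([item_a, item_b])
print(phonemes.shape)  # torch.Size([2, 7]): padded to the longest item
print(input_lengths)   # [7, 5]
print(masked_indices)  # [[1, 2], [0]]
```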
```python
def build_dataloader(df,
                     validation=False,
                     batch_size=4,
                     num_workers=1,
                     device='cpu',
                     collate_config={},
                     dataset_config={}):
    dataset = FilePathDataset(df, **dataset_config)
    collate_fn = Collater(**collate_config)
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=(not validation),
                             num_workers=num_workers,
                             drop_last=(not validation),
                             collate_fn=collate_fn,
                             pin_memory=(device != 'cpu'))
    return data_loader
```
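For context, a hypothetical call (the `dataset` name is an assumption; it would be the phonemized Wikipedia dataset loaded with `load_from_disk` as shown later in this thread):

```python
# Hypothetical usage, assuming `dataset` is the processed phoneme dataset
train_loader = build_dataloader(dataset, batch_size=8, num_workers=4, device='cuda')
words, word_labels, labels, phonemes, input_lengths, masked_indices = next(iter(train_loader))
```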
MP-BERT architecture:

```python
class MultiTaskModel(nn.Module):
    def __init__(self, model, dropout=0.1, num_tokens=178, num_vocab=30000, hidden_size=768):
        super().__init__()
        self.phoneme_emb = nn.Embedding(num_tokens, 128)
        self.word_emb = nn.Embedding(num_vocab, 128)
        self.encoder = model
        self.mask_predictor = nn.Linear(hidden_size, num_tokens)  # predicts masked phonemes
        self.word_predictor = nn.Linear(hidden_size, num_vocab)   # predicts sup-phoneme units

    def forward(self, phonemes, words, attention_mask=None):
        # sum the phoneme and sup-phoneme embeddings and feed them to the encoder
        word_emb = self.word_emb(words)
        phoneme_emb = self.phoneme_emb(phonemes)
        embs = word_emb + phoneme_emb
        output = self.encoder(inputs_embeds=embs, attention_mask=attention_mask)
        tokens_pred = self.mask_predictor(output.last_hidden_state)
        words_pred = self.word_predictor(output.last_hidden_state)
        return tokens_pred, words_pred
```

That's all I have about MP-BERT for now. Unfortunately, I have deleted the pretrained MP-BERT models. I can upload the entire Jupyter notebook of the training if it helps. The code to preprocess the dataset needs some time to be cleaned up because I'm busy with another project. I will try to upload it this week.
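The encoder passed in as `model` above isn't specified in this thread. A plausible instantiation, assuming a Hugging Face ALBERT encoder (whose factorized embeddings accept 128-dim `inputs_embeds` and project them to the 768-dim hidden size); the config values here are assumptions, not confirmed by the author:

```python
from transformers import AlbertConfig, AlbertModel

# Assumed configuration: embedding_size must match the 128-dim embeddings in
# MultiTaskModel, and hidden_size must match its prediction heads.
config = AlbertConfig(vocab_size=178, embedding_size=128, hidden_size=768,
                      num_attention_heads=12, intermediate_size=2048)
encoder = AlbertModel(config)
bert = MultiTaskModel(encoder, num_tokens=178, num_vocab=30000, hidden_size=768)
```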
@yl4579 Thanks for the data loader code. I've finished the model code, but I'm struggling a bit with data management and pre-processing, as I haven't worked with such large text data before; I'm more of an image processing person.
@yl4579 As per the above code, you separate each phoneme with the separator token 2. Is it good to put the separator at the phoneme level?
@rishikksh20 They are separated at the word level. Note that each word (in this case, sup-phoneme unit; I didn't change the variable name because in the PL-BERT implementation they are words instead of sup-phoneme units) is repeated for every phoneme it covers, and the separator (2) is appended only at unit boundaries; see the toy example below.
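As a concrete toy illustration of that alignment (the unit and its id are made up), consistent with `__getitem__` above:

```python
# Hypothetical sup-phoneme unit "wɜːld" with made-up tokenizer id 1234
token, token_id = "wɜːld", 1234
word_labels = [token_id] * len(token) + [2]  # id repeated per phoneme, then separator 2
print(word_labels)  # [1234, 1234, 1234, 1234, 1234, 2]
```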
@yl4579 How do you tokenize the words for the PL-BERT classification task? Do you simply create a unique word dictionary from the Wikipedia corpus and assign a unique integer id to each word?
I just used a pre-trained tokenizer from XLNet that does not do BPE encoding, so there's no need to do any alignment for subword units.
Thanks, I will look at it.
The XLNet tokenizer is also a subword-unit-based tokenizer. How are you using it for whole-word tokenization in classification?
I used
I've uploaded the full preprocessing and training code along with the pre-trained model.
Thanks
@yl4579 One very interesting use case for this kind of phoneme BERT is that it can be trained in a multilingual way.
Yes, I have tested a multilingual PL-BERT with StyleTTS, and it works quite well. I tested joint Chinese, English, and Japanese, and it does improve the clarity and accent if you use PL-BERT in the first stage of training (not just the second stage).
Thanks |
@yl4579 That sounds really good; it has worked with IPA in a multilingual setting. Could you explain a bit more? Have you trained PL-BERT with Japanese, Chinese, and English text corpora? Would it be open-sourced too? Really interesting topic.
@seastar105 Yes. I have trained a PL-BERT on Japanese, Chinese, and English, where Japanese and Chinese are handled at the character level for simplicity and English is the same as in this repo. The total vocabulary is around 10k, and the corpus was Wikipedia for these three languages. I will make a branch for Japanese preprocessing at the word level; the character-level preprocessing is much simpler, so I don't think you will need my code for that. The word level is more complicated, so I will open-source that one. Chinese should be very similar to Japanese.
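For anyone curious what the character-level handling might look like, here is a rough sketch under my own assumptions (toy data and made-up names, not the author's code):

```python
# Hypothetical joint vocabulary for a multilingual PL-BERT:
# Chinese/Japanese text used at the character level, English as IPA phonemes.
zh_ja_corpus = ["こんにちは", "你好世界"]       # made-up sample "corpus"
en_phoneme_corpus = ["həˈloʊ wɜːld"]           # phonemized English

charset = set()
for text in zh_ja_corpus + en_phoneme_corpus:
    charset.update(text)

# reserve low ids for special tokens (pad, sep, mask, ...), as in the loader above
symbol_to_id = {s: i + 5 for i, s in enumerate(sorted(charset))}
print(len(symbol_to_id))  # roughly 10k symbols on real Wikipedia data
```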
@rishikksh20 Are you planning on training such a multilingual PL-BERT? If so, I would be happy to help with this task.
Eventually I will work on that, but right now I am engaged with some other things. Thanks for your support @lexkoro, I will let you know when I start working on it.
@yl4579 Are you able to share the training and pre-processing (including BPE dictionary) notebook for MP-BERT with me?
@rishikksh20 Yes. The training code is exactly the same as this one, except that the input now includes both the phonemes and the sup-phonemes, so I think you can figure it out yourself. Basically, you only need to change the training code with the following lines:

```python
# `bert` is the MultiTaskModel above; `criterion` is the existing training loss
# (e.g. a cross-entropy over token ids), as in the rest of the training script.
words, word_labels, labels, phonemes, input_lengths, masked_indices = batch

# build an attention mask from the sequence lengths (1 = real token, 0 = padding)
text_mask = length_to_mask(torch.Tensor(input_lengths))  # .to(device)
tokens_pred, words_pred = bert(phonemes, words, attention_mask=(~text_mask).int())

# masked-phoneme prediction loss, averaged over items with masked positions
loss_token = 0
sizes = 0
for _s2s_pred, _text_input, _text_length, _masked_indices in zip(
        tokens_pred, labels, input_lengths, masked_indices):
    if len(_masked_indices) > 0:
        _text_input = _text_input[:_text_length][_masked_indices]
        loss_tmp = criterion(_s2s_pred[:_text_length][_masked_indices], _text_input)
        loss_token += loss_tmp
        sizes += 1
loss_token /= sizes

# sup-phoneme (vocabulary) prediction loss over the same masked positions
loss_vocab = 0
sizes = 0
for _s2s_pred, _text_input, _text_length, _masked_indices in zip(
        words_pred, word_labels, input_lengths, masked_indices):
    if len(_masked_indices) > 0:
        _text_input = _text_input[:_text_length][_masked_indices]
        loss_tmp = criterion(_s2s_pred[:_text_length][_masked_indices], _text_input)
        loss_vocab += loss_tmp
        sizes += 1
loss_vocab /= sizes
```

Here "words" are the sup-phoneme units.
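`length_to_mask` isn't defined in this thread. A sketch of what it presumably does, judging from the `(~text_mask).int()` usage above (True marks padded positions, so the inversion yields 1 for real tokens):

```python
import torch

def length_to_mask(lengths):
    # True where the position index exceeds the sequence length (i.e., padding)
    mask = (torch.arange(int(lengths.max()))
            .unsqueeze(0)
            .expand(lengths.shape[0], -1)
            .type_as(lengths))
    return torch.gt(mask + 1, lengths.unsqueeze(1))
```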
The BPE tokenizers can be trained with the following code:

```python
from datasets import load_from_disk
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer

dataset = load_from_disk("wikipedia_20220301.en.processed")

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(vocab_size=30000,
                     special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
                     show_progress=True)

batch_size = 1000

def batch_iterator():
    # yield space-joined phoneme strings, batch by batch
    for i in range(0, len(dataset), batch_size):
        yield [' '.join(p) for p in dataset[i : i + batch_size]['phonemes']]

tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)
tokenizer.save("tokenizer-sup-30000.json")
```