In [1]:
from bm_features import XTTSFeatureExtractor, XTTSProcessor, XTTSTokenizer
from datasets import load_dataset
import datasets
import torch

In [2]:
DEVICE = torch.device('mps')
dataset = load_dataset("oza75/bambara-tts", "denoised")
dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=22050)).rename_column('bambara', 'text')
dataset

Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/18 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['audio', 'text', 'french', 'duration', 'speaker_embeddings', 'speaker_id'],
        num_rows: 30765
    })
})

In [3]:
feature_extractor = XTTSFeatureExtractor("../mel_stats.pth", sampling_rate=22050, max_samples=221000)
tokenizer = XTTSTokenizer.from_pretrained("openai-community/gpt2-medium")
processor = XTTSProcessor(feature_extractor, tokenizer)

batch = processor(dataset['train'][:2], device=DEVICE)
batch

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPT2Tokenizer'. 
The class this function is called from is 'XTTSTokenizer'.


{'audio': [{'path': None,
   'array': array([ 0.00099032,  0.00157957,  0.00133549, ..., -0.00630357,
          -0.00620323, -0.00616569]),
   'sampling_rate': 22050},
  {'path': None,
   'array': array([ 0.00011191,  0.00036082,  0.00023281, ..., -0.00214735,
          -0.0011257 , -0.00175181]),
   'sampling_rate': 22050}],
 'text': ['Jigi, i bolo degunnen don wa ?', 'Dɔɔnin !'],
 'french': ['Jigi, es-tu occupé ?', 'Un peu! '],
 'duration': [2.645986394557823, 0.8740136054421769],
 'speaker_embeddings': tensor([[ -2.5645, -20.9284,  69.9060,  ...,  22.0783,  42.1153,  19.8596],
         [ -2.5297,  57.3677, -24.4568,  ...,  72.7989, 108.3561,  71.1548]],
        device='mps:0'),
 'speaker_id': [22, 27],
 'wav_mels': tensor([[[0.8004, 0.9520, 1.2643,  ..., 1.6425, 1.6425, 1.6425],
          [1.0748, 1.2971, 1.7636,  ..., 1.8926, 1.8926, 1.8926],
          [1.6542, 1.9314, 2.3495,  ..., 2.4791, 2.4791, 2.4791],
          ...,
          [1.2285, 1.2285, 1.2285,  ..., 1.2285, 1.2285, 1.2

In [4]:
from bm_models import XttsConfig, Xtts

model = Xtts.from_pretrained("./bm_xtts")
model.to(DEVICE)
model

Some weights of the model checkpoint at ./bm_xtts were not used when initializing Xtts: ['gpt_wrapper.conditioning_perceiver.norm.weight']
- This IS expected if you are initializing Xtts from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Xtts from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Xtts were not initialized from the model checkpoint at ./bm_xtts and are newly initialized: ['gpt_wrapper.conditioning_perceiver.norm.gamma']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Xtts(
  (gpt_wrapper): GPTWrapper(
    (text_emb): Embedding(50257, 1024)
    (mel_emb): Embedding(605, 1024)
    (gpt): GPT2Model(
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-29): 30 x GPT2Block(
          (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
    (text_pos_emb): LearnedPositionEmbeddings(
      (emb): Embedding(402, 1024)
    )
    (mel_pos_emb): LearnedPos

In [26]:
model(
    input_ids=batch['text_tokens'],
    label_ids=batch['text_tokens'],  # same as text_tokens
    text_lengths=batch['text_lengths'],
    cond_mels=batch['cond_mels'],
    cond_idxs=batch['cond_idxs'],
    cond_lens=batch['cond_len'],
    wav_mels=batch['wav_mels'],
    wav_mel_attention_masks=batch['wav_mel_attention_masks'],
    wav_lengths=batch['wav_lengths'],
)

Xtts(
  (gpt_wrapper): GPTWrapper(
    (text_emb): Embedding(50257, 1024)
    (mel_emb): Embedding(605, 1024)
    (gpt): GPT2Model(
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-29): 30 x GPT2Block(
          (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
    (text_pos_emb): LearnedPositionEmbeddings(
      (emb): Embedding(402, 1024)
    )
    (mel_pos_emb): LearnedPos