In [1]:
from datasets import load_dataset
import datasets
from vq_vae.features import VQVAEProcessor, VQVAEFeatureExtractor
import torch

# Audio settings shared by the feature extractor and the dataset decoding.
# Named once so the two sampling rates below cannot drift apart.
SAMPLING_RATE = 22050  # Hz
MAX_SAMPLES = 221000   # 221000 / 22050 ~= 10.02 s of audio

# Instantiate the feature extractor and processor
feature_extractor = VQVAEFeatureExtractor(
    sampling_rate=SAMPLING_RATE,
    mel_norm_file='../mel_stats.pth',
    max_samples=MAX_SAMPLES,
)
processor = VQVAEProcessor(feature_extractor)

# Load the Bambara TTS dataset ("denoised" configuration) and have the audio
# column decoded at the same sampling rate the feature extractor expects.
dataset = load_dataset("oza75/bambara-tts", "denoised")
dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=SAMPLING_RATE))
dataset

Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/18 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['audio', 'bambara', 'french', 'duration', 'speaker_embeddings', 'speaker_id'],
        num_rows: 30765
    })
})

In [32]:
# Duration in seconds covered by max_samples (221000) at the 22050 Hz sampling rate.
221000 / 22050

10.022675736961451

In [7]:
# Run the processor on the first two training examples and pull out the
# three entries used by the cells below.
batch = processor(dataset['train'][:2])
inputs = batch["mel_spectrogram"]
attention_masks = batch["attention_masks"]
speaker_embeddings = batch['speaker_embeddings']

In [3]:
from vq_vae.models import BMSpeechVQVAE, BMSpeechVQVAEConfig

# Configure and build the VQ-VAE; the bare `model` at the end displays its
# module tree.
config = BMSpeechVQVAEConfig(
    in_channels=1,
    out_channels=1,
    num_layers=4,
    latent_channels=512,
    speaker_embed_dim=512,
    act_fn='relu',
)
model = BMSpeechVQVAE(config)
model

in_channels:  1
in_channels:  64
in_channels:  128
in_channels:  256


BMSpeechVQVAE(
  (encoder): Encoder(
    (encoder): Sequential(
      (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
      (2): GroupNorm(32, 64, eps=1e-05, affine=True)
      (3): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (4): ReLU()
      (5): GroupNorm(32, 128, eps=1e-05, affine=True)
      (6): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (7): ReLU()
      (8): GroupNorm(32, 256, eps=1e-05, affine=True)
      (9): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (10): ReLU()
    )
  )
  (quant_conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))
  (quantize): VectorQuantizer(
    (embedding): Embedding(512, 64)
  )
  (post_quant_conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1))
  (speaker_latents_fc): Linear(in_features=1024, out_features=512, bias=True)
  (decoder): Decoder(
    (decoder): Sequential(
      (0): GroupNorm(32, 512, eps=1e-05, af

In [9]:
# Shape check: add two trailing singleton axes to the speaker embeddings and
# expand them to a 5 x 54 grid (the spatial size of the latents printed by the
# forward pass), then look at the per-example shape.
# NOTE(review): this calls tensor methods on `speaker_embeddings`, which is only
# converted with torch.stack in a later cell - run that cell first.
speaker_embeddings.unsqueeze(2).unsqueeze(3).expand(-1, -1, 5, 54)[0].shape

torch.Size([512, 5, 54])

In [8]:
# Build the model inputs directly from `batch` rather than rebinding the
# variables from themselves. The previous version overwrote `inputs` with its
# 4-D view, so running the cell a second time failed on the 3-way shape unpack;
# this version can be re-run any number of times with the same result.
mel_spectrogram = batch["mel_spectrogram"]
bsz, n_mels, time_steps = mel_spectrogram.shape

# Insert a singleton channel axis: (bsz, n_mels, time_steps) -> (bsz, 1, n_mels, time_steps).
inputs = mel_spectrogram.view(bsz, 1, n_mels, time_steps)
attention_masks = batch["attention_masks"].view(bsz, 1, n_mels, time_steps)

# Stack the per-example speaker embeddings into a single tensor.
speaker_embeddings = torch.stack([torch.tensor(item) for item in batch["speaker_embeddings"]])
inputs.shape

torch.Size([2, 1, 80, 864])

In [9]:
# Forward pass on the two-example batch. Judging by the result displayed in the
# next cell, this returns a 2-tuple: a scalar tensor (presumably the loss - TODO
# confirm) and a tensor with the same shape as `inputs`.
outputs = model(inputs, attention_masks, speaker_embeddings=speaker_embeddings)

z_e shape: torch.Size([2, 512, 5, 54])
z_e quant conved shape: torch.Size([2, 64, 5, 54])
z_q shape: torch.Size([2, 64, 5, 54])
z_q post shape: torch.Size([2, 512, 5, 54])
z_q concat shape: torch.Size([2, 1024, 5, 54])
z_q fc shape: torch.Size([2, 512, 5, 54])
z_recon shape: torch.Size([2, 1, 80, 864])


In [10]:
# Display the raw result of the forward pass.
outputs

(tensor(4.3154, grad_fn=<AddBackward0>),
 tensor([[[[0.1472, 0.3314, 0.1975,  ..., 0.2223, 0.2452, 0.1016],
           [0.0570, 0.2349, 0.0000,  ..., 0.2809, 0.0000, 0.2425],
           [0.2627, 0.0885, 0.2913,  ..., 0.3492, 0.2087, 0.1468],
           ...,
           [0.0092, 0.4624, 0.0000,  ..., 0.0000, 0.0000, 0.3139],
           [0.1110, 0.0383, 0.0590,  ..., 0.0000, 0.0000, 0.3534],
           [0.0869, 0.3015, 0.0000,  ..., 0.3252, 0.2253, 0.0530]]],
 
 
         [[[0.2451, 0.3609, 0.4276,  ..., 0.2718, 0.2316, 0.1087],
           [0.0013, 0.2102, 0.0000,  ..., 1.2279, 0.0000, 0.5147],
           [0.2772, 0.0000, 0.4362,  ..., 0.4660, 0.3872, 0.2619],
           ...,
           [0.0000, 0.8945, 0.0000,  ..., 0.7992, 0.0000, 0.0175],
           [0.1648, 0.0000, 0.5626,  ..., 0.0000, 0.0000, 0.0564],
           [0.0753, 0.4211, 0.0000,  ..., 0.6858, 0.1625, 0.0405]]]],
        grad_fn=<ReluBackward0>))

In [5]:
from vq_vae.models import Encoder

# Build the encoder on its own with 80 input channels and display its layers.
# NOTE(review): the repr recorded below is a Conv1d stack, while the encoder
# inside BMSpeechVQVAE printed earlier is Conv2d with GroupNorm - this output
# looks like it came from an older version of `Encoder`; re-run to confirm.
encoder = Encoder(input_channels=80)
encoder

Encoder(
  (encoder): Sequential(
    (0): Conv1d(80, 64, kernel_size=(4,), stride=(2,), padding=(1,))
    (1): ReLU()
    (2): Conv1d(64, 128, kernel_size=(4,), stride=(2,), padding=(1,))
    (3): ReLU()
    (4): Conv1d(128, 256, kernel_size=(4,), stride=(2,), padding=(1,))
    (5): ReLU()
    (6): Conv1d(256, 512, kernel_size=(4,), stride=(2,), padding=(1,))
    (7): ReLU()
    (8): Conv1d(512, 1024, kernel_size=(4,), stride=(2,), padding=(1,))
    (9): ReLU()
    (10): Conv1d(1024, 2048, kernel_size=(4,), stride=(2,), padding=(1,))
    (11): ReLU()
  )
)

In [6]:
import numpy as np
def count_trainable_parameters(model):
    """Return the number of trainable scalar weights in ``model``.

    Args:
        model: Any object whose ``parameters()`` yields tensors exposing
            ``requires_grad`` and ``numel()`` (e.g. a ``torch.nn.Module``).

    Returns:
        int: Total element count over the parameters that have
        ``requires_grad=True``; ``0`` when there are none.
    """
    # numel() yields an exact Python int per tensor. The earlier
    # np.prod(p.size()) evaluated to the float 1.0 for 0-dim parameters
    # (empty shape), which silently turned the whole total into a float.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [7]:
# Trainable parameter count of the BMSpeechVQVAE instance built above.
count_trainable_parameters(model)

6132161

In [6]:
from vq_vae.models import SpeechVQConfig, SpeechVQVAE

# Build a SpeechVQVAE from a SpeechVQConfig created with no overrides, then
# display its module tree.
speech_config = SpeechVQConfig()
speech_model = SpeechVQVAE(speech_config)
speech_model

SpeechVQVAE(
  (model): VQVAE(
    (encoder): Encoder(
      (conv_in): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (down_blocks): ModuleList(
        (0-4): 5 x DownEncoderBlock2D(
          (resnets): ModuleList(
            (0): ResnetBlock2D(
              (norm1): GroupNorm(32, 64, eps=1e-06, affine=True)
              (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm2): GroupNorm(32, 64, eps=1e-06, affine=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (nonlinearity): SiLU()
            )
          )
          (downsamplers): ModuleList(
            (0): Downsample2D(
              (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
            )
          )
        )
        (5): DownEncoderBlock2D(
          (resnets): ModuleList(
            (0): ResnetBlock2D(
              (norm1): Gr

In [7]:
# Look at one raw training record.
# NOTE(review): this prints the whole record, including every value of
# `speaker_embeddings`; consider displaying only selected keys to keep the
# output short.
dataset['train'][0]

{'audio': {'path': None,
  'array': array([ 0.00099032,  0.00157957,  0.00133549, ..., -0.00630357,
         -0.00620323, -0.0061657 ]),
  'sampling_rate': 22050},
 'bambara': 'Jigi, i bolo degunnen don wa ?',
 'french': 'Jigi, es-tu occupé ?',
 'duration': 2.645986394557823,
 'speaker_embeddings': [-2.564516305923462,
  -20.928388595581055,
  69.90596008300781,
  8.361804962158203,
  14.13325309753418,
  50.45071792602539,
  80.53385162353516,
  20.306468963623047,
  -35.76181411743164,
  -18.653125762939453,
  -4.586198329925537,
  -88.45294952392578,
  14.038538932800293,
  -1.9949610233306885,
  29.295623779296875,
  35.923561096191406,
  -4.508488655090332,
  22.126203536987305,
  -20.97467803955078,
  39.27812194824219,
  15.961697578430176,
  35.7476806640625,
  26.484188079833984,
  -12.542716979980469,
  -35.30205154418945,
  92.43451690673828,
  -11.966684341430664,
  -48.78108596801758,
  -42.39558792114258,
  -20.03965187072754,
  21.1246395111084,
  -3.3788418769836426,
  