In [1]:
!pip install datasets

import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import librosa



In [2]:
# ---------------------
# 1. Vector Quantization Codec
# ---------------------
class ResidualVQ(nn.Module):
    def __init__(self, num_semantic, num_acoustic, dim, mel_dim):
        super().__init__()
        self.dim = dim
        self.project = nn.Linear(mel_dim, dim)
        self.deproject = nn.Linear(dim, mel_dim)
        self.quant1 = nn.Embedding(num_semantic, dim)
        self.quant2 = nn.Embedding(num_acoustic, dim)

    def forward(self, mel_frames):
        # mel_frames: (B, T, mel_dim)
        B, T, _ = mel_frames.size()
        x = self.project(mel_frames)                    # (B, T, dim)
        flat = x.view(-1, self.dim)                     # (B*T, dim)

        # Semantic quantization
        codebook1 = self.quant1.weight                   # (N1, dim)
        d1 = (flat.pow(2).sum(1, keepdim=True)
              + codebook1.pow(2).sum(1)
              - 2 * flat @ codebook1.t())              # (B*T, N1)
        idx1 = d1.argmin(1)                             # (B*T,)
        quant1 = self.quant1(idx1).view(B, T, self.dim)

        # Acoustic (residual) quantization
        residual = flat - quant1.view(-1, self.dim)     # (B*T, dim)
        codebook2 = self.quant2.weight                   # (N2, dim)
        d2 = (residual.pow(2).sum(1, keepdim=True)
              + codebook2.pow(2).sum(1)
              - 2 * residual @ codebook2.t())          # (B*T, N2)
        idx2 = d2.argmin(1)                             # (B*T,)
        quant2 = self.quant2(idx2).view(B, T, self.dim)

        return idx1.view(B, T), idx2.view(B, T)

    def decode(self, sem_idxs, ac_idxs):
        # sem_idxs, ac_idxs: (B, T)
        sem_emb = self.quant1(sem_idxs)                  # (B, T, dim)
        ac_emb = self.quant2(ac_idxs)                    # (B, T, dim)
        x = sem_emb + ac_emb                             # (B, T, dim)
        mel_recon = self.deproject(x)                    # (B, T, mel_dim)
        return mel_recon

# ---------------------
# 2. Transformer Modules
# ---------------------
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, dim_ff=2048, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.ReLU(inplace=True),
            nn.Linear(dim_ff, d_model)
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: (B, T, d_model)
        y, _ = self.attn(x, x, x)
        x = self.ln1(x + y)
        y = self.ff(x)
        return self.ln2(x + y)

class TemporalTransformer(nn.Module):
    def __init__(self, d_model=512, n_heads=8, n_layers=6):
        super().__init__()
        self.layers = nn.ModuleList([TransformerBlock(d_model, n_heads) for _ in range(n_layers)])

    def forward(self, token_history):
        # token_history: (B, T, d_model)
        x = token_history
        for layer in self.layers:
            x = layer(x)
        return x  # (B, T, d_model)

class DepthTransformer(nn.Module):
    def __init__(self, d_model, num_semantic, num_acoustic):
        super().__init__()
        self.to_semantic = nn.Linear(d_model, num_semantic)
        self.to_acoustic = nn.Linear(d_model, num_acoustic)

    def forward(self, z_t):
        # z_t: (B, d_model)
        sem_logits = self.to_semantic(z_t)
        ac_logits = self.to_acoustic(z_t)
        return sem_logits, ac_logits

# ---------------------
# 3. RQ‑Tacotron Model
# ---------------------
class RQTacotron(nn.Module):
    def __init__(
        self,
        vocab_size,
        embed_dim,
        encoder_hidden,
        d_model,
        num_semantic,
        num_acoustic,
        mel_dim=80
    ):
        super().__init__()
        # Text Encoder (LSTM)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.LSTM(embed_dim, encoder_hidden, batch_first=True, bidirectional=True)
        self.enc_proj = nn.Linear(2*encoder_hidden, d_model)

        # VQ codec
        self.vq = ResidualVQ(num_semantic, num_acoustic, d_model, mel_dim)

        # Streams
        self.sem_emb = nn.Embedding(num_semantic, d_model)
        self.ac_emb = nn.Embedding(num_acoustic, d_model)

        # Transformers
        self.temporal = TemporalTransformer(d_model)
        self.depth = DepthTransformer(d_model, num_semantic, num_acoustic)

        # Vocoder
        bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
        self.vocoder = bundle.get_vocoder()

    def forward(self, text_ids, sem_targets=None, ac_targets=None, max_steps=None):
        # text_ids: (B, T_text)
        B = text_ids.size(0)
        device = text_ids.device

        # Encode text → context
        x = self.embedding(text_ids)
        enc_out, _ = self.encoder(x)
        ctx = self.enc_proj(enc_out)  # (B, T_text, d_model)

        # Prepare histories
        sem_history_emb = torch.zeros(B, 0, self.enc_proj.out_features, device=device)
        ac_history_emb = torch.zeros(B, 0, self.enc_proj.out_features, device=device)

        # Determine steps
        if sem_targets is not None:
            T = sem_targets.size(1)
        else:
            T = max_steps

        loss_sem = 0.0
        loss_ac = 0.0
        ce_loss = nn.CrossEntropyLoss()
        sem_logits_prev, ac_logits_prev = None, None

        sem_preds = []
        ac_preds = []

        for t in range(T):
            # Build full history sequence
            history = torch.cat([ctx, sem_history_emb, ac_history_emb], dim=1)
            # Compute context vector for this step
            z_seq = self.temporal(history)
            z_t = z_seq[:, -1, :]

            # Depth predictions
            sem_logits, ac_logits = self.depth(z_t)

            if sem_targets is not None:
                # Training: accumulate semantic loss
                loss_sem += ce_loss(sem_logits, sem_targets[:, t])
                # Acoustic loss delayed by one step
                if ac_logits_prev is not None:
                    loss_ac += ce_loss(ac_logits_prev, ac_targets[:, t-1])
            else:
                # Inference: sample or argmax
                sem_idx = torch.argmax(sem_logits, dim=-1)
                sem_preds.append(sem_idx)
                if ac_logits_prev is not None:
                    ac_idx = torch.argmax(ac_logits_prev, dim=-1)
                    ac_preds.append(ac_idx)

            # Update histories
            if sem_targets is not None:
                sem_idx = sem_targets[:, t]
            sem_emb_t = self.sem_emb(sem_idx).unsqueeze(1)  # (B,1,d_model)
            sem_history_emb = torch.cat([sem_history_emb, sem_emb_t], dim=1)

            ac_logits_prev = ac_logits
            if ac_targets is not None and t > 0:
                ac_idx = ac_targets[:, t-1]
                ac_emb_t = self.ac_emb(ac_idx).unsqueeze(1)
                ac_history_emb = torch.cat([ac_history_emb, ac_emb_t], dim=1)

        # Final acoustic for last step
        if sem_targets is None:
            # append last acoustic prediction
            ac_idx = torch.argmax(ac_logits_prev, dim=-1)
            ac_preds.append(ac_idx)

        if sem_targets is not None:
            return {
                'sem_loss': loss_sem / T,
                'ac_loss': loss_ac / max(1, T-1)
            }
        else:
            sem_seq = torch.stack(sem_preds, dim=1)
            ac_seq = torch.stack(ac_preds, dim=1)
            return sem_seq, ac_seq

    @torch.no_grad()
    def generate(self, text, max_steps=200):
        self.eval()
        text_ids = TextProcessor()(text).unsqueeze(0).to(next(self.parameters()).device)
        sem_seq, ac_seq = self.forward(text_ids, sem_targets=None, ac_targets=None, max_steps=max_steps)

        # Decode to mel
        mel = self.vq.decode(sem_seq, ac_seq)
        # Vocoder to waveform
        audio = self.vocoder(mel.transpose(1, 2))
        return audio[0]


In [3]:
# ---------------------
# 4. Dataset & Training Loop
# ---------------------
class TextProcessor():
    def __init__(self):
        self.char_to_id = {c: i for i, c in enumerate("abcdefghijklmnopqrstuvwxyz '.,?!")}
    def __call__(self, text):
        ids = [self.char_to_id[c] for c in text.lower() if c in self.char_to_id]
        return torch.LongTensor(ids)

class LJSpeechDataset(Dataset):
    def __init__(self):
        self.dataset = load_dataset("lj_speech", split="train")
        self.text_processor = TextProcessor()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        wav, sr = torchaudio.load(item['file'])
        wav = torchaudio.functional.resample(wav, sr, 22050)[0].numpy()
        mel = librosa.feature.melspectrogram(
            y=wav, sr=22050, n_fft=1024, hop_length=256, n_mels=80
        ).T  # (T,80)
        text_ids = self.text_processor(item['normalized_text'])
        return text_ids, torch.FloatTensor(mel)

def collate_fn(batch):
    texts, mels = zip(*batch)
    texts_padded = nn.utils.rnn.pad_sequence(texts, batch_first=True)
    mels_padded = nn.utils.rnn.pad_sequence(mels, batch_first=True)
    return texts_padded, mels_padded

In [4]:
dataset = LJSpeechDataset()
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyperparams
vocab_size = len(TextProcessor().char_to_id)
model = RQTacotron(
    vocab_size=vocab_size,
    embed_dim=16,
    encoder_hidden=16,
    d_model=32,
    num_semantic=32,
    num_acoustic=32,
    mel_dim=80
).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [6]:
# Training
for epoch in range(1, 101):
    for texts, mels in loader:
        texts, mels = texts.to(device), mels.to(device)
        # Quantize targets
        sem_toks, ac_toks = model.vq(mels)

        # Forward + losses
        losses = model(texts, sem_targets=sem_toks, ac_targets=ac_toks)
        loss = 100 * losses['sem_loss'] + 1 * losses['ac_loss']

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
        break


OutOfMemoryError: CUDA out of memory. Tried to allocate 14.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 6.12 MiB is free. Process 26568 has 14.73 GiB memory in use. Of the allocated memory 14.00 GiB is allocated by PyTorch, and 618.71 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Inference example
audio = model.generate("Hello world, streaming TTS!")
torchaudio.save("output_stream.wav", audio.cpu(), 22050)


In [None]:
from pydub import AudioSegment
import IPython

path = "output_stream.wav"
IPython.display.Audio(path)