In [1]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import json
import torch
import librosa
import numpy as np
import webrtcvad
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# =========================
# Config
# =========================
SAMPLE_RATE = 22050
N_MELS = 80
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SPEAKER_EMB_DIM = 64

# =========================
# Audio Utils
# =========================
def trim_silence(wav, sr):
    vad = webrtcvad.Vad(2)
    frame = int(sr * 0.02)  # 20ms
    voiced = []

    for i in range(0, len(wav) - frame, frame):
        chunk = wav[i:i+frame]
        pcm = (chunk * 32768).astype(np.int16).tobytes()
        try:
            if vad.is_speech(pcm, sr):
                voiced.extend(chunk)
        except:
            continue

    return np.array(voiced) if len(voiced) > 0 else wav


def wav_to_mel(wav):
    mel = librosa.feature.melspectrogram(
        y=wav,
        sr=SAMPLE_RATE,
        n_mels=N_MELS
    )
    return torch.FloatTensor(np.log(mel + 1e-6))

# =========================
# Dataset
# =========================
class TTSDataset(Dataset):
    def __init__(self, manifest):
        self.data = json.load(open(manifest))
        speakers = sorted({d["speaker"] for d in self.data})
        self.spk2id = {s: i for i, s in enumerate(speakers)}

    def __getitem__(self, idx):
        item = self.data[idx]

        wav, _ = librosa.load(item["audio"], sr=SAMPLE_RATE)
        wav = trim_silence(wav, SAMPLE_RATE)

        mel = wav_to_mel(wav)

        text = torch.LongTensor(item["phonemes"])

        # pseudo duration
        dur_value = max(1, mel.shape[1] // len(text))
        duration = torch.ones(len(text)) * dur_value

        speaker_id = torch.LongTensor([self.spk2id[item["speaker"]]])

        return text, duration, mel, speaker_id

    def __len__(self):
        return len(self.data)

# =========================
# LoRA Linear
# =========================
class LoRALinear(nn.Module):
    def __init__(self, in_f, out_f, r=8, alpha=16):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(out_f, in_f),
            requires_grad=False
        )
        self.A = nn.Parameter(torch.randn(r, in_f) * 0.01)
        self.B = nn.Parameter(torch.randn(out_f, r) * 0.01)
        self.scale = alpha / r

    def forward(self, x):
        base = x @ self.weight.T
        lora = (x @ self.A.T @ self.B.T) * self.scale
        return base + lora

# =========================
# Model
# =========================
class ResearchTTS_LoRA(nn.Module):
    def __init__(self, vocab, dim=256, num_speakers=1):
        super().__init__()

        self.embed = nn.Embedding(vocab, dim)
        self.spk_embed = nn.Embedding(num_speakers, SPEAKER_EMB_DIM)

        self.encoder = nn.LSTM(
            dim + SPEAKER_EMB_DIM,
            dim,
            batch_first=True
        )

        self.duration = nn.Conv1d(dim, 1, 3, padding=1)

        self.decoder = nn.LSTM(dim, dim, batch_first=True)

        self.mel_proj = LoRALinear(dim, N_MELS)

        # Freeze backbone
        for p in self.embed.parameters():
            p.requires_grad = False
        for p in self.encoder.parameters():
            p.requires_grad = False
        for p in self.decoder.parameters():
            p.requires_grad = False

    def forward(self, text, speaker_id):
        x = self.embed(text)

        spk = self.spk_embed(speaker_id).squeeze(1)
        spk = spk.unsqueeze(1).expand(-1, x.size(1), -1)

        x = torch.cat([x, spk], dim=-1)

        x, _ = self.encoder(x)

        dur = self.duration(x.transpose(1, 2)).squeeze(1)

        # ===== Length Regulator =====
        expanded = []
        for b in range(x.size(0)):
            reps = torch.clamp(dur[b].round().long(), min=1)
            expanded_seq = torch.repeat_interleave(x[b], reps, dim=0)
            expanded.append(expanded_seq)

        x = torch.nn.utils.rnn.pad_sequence(expanded, batch_first=True)

        x, _ = self.decoder(x)
        mel = self.mel_proj(x).transpose(1, 2)

        return mel, dur

# =========================
# Loss
# =========================
def l1(a, b):
    return torch.mean(torch.abs(a - b))

# =========================
# Train
# =========================
def train(manifest, saved_model_name):
    dataset = TTSDataset(manifest)

    # Auto vocab
    max_token = max(max(d["phonemes"]) for d in dataset.data)
    vocab_size = max_token + 10
    print("Auto vocab size:", vocab_size)

    loader = DataLoader(dataset, batch_size=1, shuffle=True)

    model = ResearchTTS_LoRA(
        vocab=vocab_size,
        num_speakers=len(dataset.spk2id)
    ).to(DEVICE)

    optim = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        lr=1e-4
    )

    for epoch in range(30):
        total = 0

        for text, dur, mel, spk in tqdm(loader):

            text = text.to(DEVICE)
            mel = mel.to(DEVICE)
            spk = spk.to(DEVICE)
            dur = dur.to(DEVICE)

            pred_mel, pred_dur = model(text, spk)

            # Crop mel to match
            min_len = min(pred_mel.size(2), mel.size(2))
            pred_mel = pred_mel[:, :, :min_len]
            mel = mel[:, :, :min_len]

            loss = (
                l1(pred_mel, mel) +
                l1(pred_dur, dur)
            )

            optim.zero_grad()
            loss.backward()
            optim.step()

            total += loss.item()

        print(f"[LoRA][Multi-Spk] Epoch {epoch}: {total/len(loader):.4f}")

    torch.save(model.state_dict(), saved_model_name)


  import pkg_resources


In [2]:
train("manifest_whisper_rename_path.json", "research_tts_lora_multispk_whisper.pt")

Auto vocab size: 5643


100%|██████████| 2558/2558 [00:59<00:00, 42.76it/s]


[LoRA][Multi-Spk] Epoch 0: 22.8745


100%|██████████| 2558/2558 [01:32<00:00, 27.78it/s]


[LoRA][Multi-Spk] Epoch 1: 16.1821


100%|██████████| 2558/2558 [01:37<00:00, 26.33it/s]


[LoRA][Multi-Spk] Epoch 2: 14.6404


100%|██████████| 2558/2558 [01:39<00:00, 25.79it/s]


[LoRA][Multi-Spk] Epoch 3: 14.1161


100%|██████████| 2558/2558 [01:43<00:00, 24.63it/s]


[LoRA][Multi-Spk] Epoch 4: 13.8483


100%|██████████| 2558/2558 [01:46<00:00, 23.93it/s]


[LoRA][Multi-Spk] Epoch 5: 13.6743


100%|██████████| 2558/2558 [01:46<00:00, 23.98it/s]


[LoRA][Multi-Spk] Epoch 6: 13.5486


100%|██████████| 2558/2558 [01:46<00:00, 24.06it/s]


[LoRA][Multi-Spk] Epoch 7: 13.4524


100%|██████████| 2558/2558 [01:45<00:00, 24.15it/s]


[LoRA][Multi-Spk] Epoch 8: 13.3754


100%|██████████| 2558/2558 [01:55<00:00, 22.24it/s]


[LoRA][Multi-Spk] Epoch 9: 13.3140


100%|██████████| 2558/2558 [01:58<00:00, 21.61it/s]


[LoRA][Multi-Spk] Epoch 10: 13.2602


100%|██████████| 2558/2558 [01:45<00:00, 24.20it/s]


[LoRA][Multi-Spk] Epoch 11: 13.2129


100%|██████████| 2558/2558 [01:55<00:00, 22.10it/s]


[LoRA][Multi-Spk] Epoch 12: 13.1708


100%|██████████| 2558/2558 [01:51<00:00, 22.91it/s]


[LoRA][Multi-Spk] Epoch 13: 13.1339


100%|██████████| 2558/2558 [01:46<00:00, 23.95it/s]


[LoRA][Multi-Spk] Epoch 14: 13.1008


100%|██████████| 2558/2558 [01:57<00:00, 21.86it/s]


[LoRA][Multi-Spk] Epoch 15: 13.0698


100%|██████████| 2558/2558 [01:52<00:00, 22.69it/s]


[LoRA][Multi-Spk] Epoch 16: 13.0416


100%|██████████| 2558/2558 [01:54<00:00, 22.25it/s]


[LoRA][Multi-Spk] Epoch 17: 13.0135


100%|██████████| 2558/2558 [01:56<00:00, 21.96it/s]


[LoRA][Multi-Spk] Epoch 18: 12.9861


100%|██████████| 2558/2558 [01:55<00:00, 22.14it/s]


[LoRA][Multi-Spk] Epoch 19: 12.9591


100%|██████████| 2558/2558 [01:57<00:00, 21.80it/s]


[LoRA][Multi-Spk] Epoch 20: 12.9321


100%|██████████| 2558/2558 [01:57<00:00, 21.83it/s]


[LoRA][Multi-Spk] Epoch 21: 12.9047


100%|██████████| 2558/2558 [01:56<00:00, 21.98it/s]


[LoRA][Multi-Spk] Epoch 22: 12.8785


100%|██████████| 2558/2558 [01:55<00:00, 22.21it/s]


[LoRA][Multi-Spk] Epoch 23: 12.8550


100%|██████████| 2558/2558 [01:55<00:00, 22.18it/s]


[LoRA][Multi-Spk] Epoch 24: 12.8327


100%|██████████| 2558/2558 [01:56<00:00, 21.88it/s]


[LoRA][Multi-Spk] Epoch 25: 12.8134


100%|██████████| 2558/2558 [01:50<00:00, 23.12it/s]


[LoRA][Multi-Spk] Epoch 26: 12.7965


100%|██████████| 2558/2558 [01:58<00:00, 21.64it/s]


[LoRA][Multi-Spk] Epoch 27: 12.7800


100%|██████████| 2558/2558 [01:55<00:00, 22.09it/s]


[LoRA][Multi-Spk] Epoch 28: 12.7676


100%|██████████| 2558/2558 [01:56<00:00, 21.88it/s]

[LoRA][Multi-Spk] Epoch 29: 12.7559





In [3]:
train("manifest_pathumma_rename_path.json", "research_tts_lora_multispk_pathumma.pt")

Auto vocab size: 5643


100%|██████████| 2590/2590 [01:10<00:00, 36.86it/s]


[LoRA][Multi-Spk] Epoch 0: 34.4155


100%|██████████| 2590/2590 [01:14<00:00, 34.90it/s]


[LoRA][Multi-Spk] Epoch 1: 27.1116


100%|██████████| 2590/2590 [01:23<00:00, 31.14it/s]


[LoRA][Multi-Spk] Epoch 2: 24.7394


100%|██████████| 2590/2590 [01:34<00:00, 27.27it/s]


[LoRA][Multi-Spk] Epoch 3: 23.7159


100%|██████████| 2590/2590 [01:39<00:00, 26.16it/s]


[LoRA][Multi-Spk] Epoch 4: 23.2205


100%|██████████| 2590/2590 [01:43<00:00, 24.95it/s]


[LoRA][Multi-Spk] Epoch 5: 22.9487


100%|██████████| 2590/2590 [01:46<00:00, 24.22it/s]


[LoRA][Multi-Spk] Epoch 6: 22.7712


100%|██████████| 2590/2590 [01:48<00:00, 23.97it/s]


[LoRA][Multi-Spk] Epoch 7: 22.6448


100%|██████████| 2590/2590 [01:48<00:00, 23.91it/s]


[LoRA][Multi-Spk] Epoch 8: 22.5514


100%|██████████| 2590/2590 [01:47<00:00, 24.00it/s]


[LoRA][Multi-Spk] Epoch 9: 22.4759


100%|██████████| 2590/2590 [01:47<00:00, 24.10it/s]


[LoRA][Multi-Spk] Epoch 10: 22.4114


100%|██████████| 2590/2590 [01:49<00:00, 23.58it/s]


[LoRA][Multi-Spk] Epoch 11: 22.3550


100%|██████████| 2590/2590 [01:47<00:00, 24.05it/s]


[LoRA][Multi-Spk] Epoch 12: 22.3077


100%|██████████| 2590/2590 [01:47<00:00, 24.19it/s]


[LoRA][Multi-Spk] Epoch 13: 22.2658


100%|██████████| 2590/2590 [01:48<00:00, 23.81it/s]


[LoRA][Multi-Spk] Epoch 14: 22.2275


100%|██████████| 2590/2590 [01:47<00:00, 24.08it/s]


[LoRA][Multi-Spk] Epoch 15: 22.1919


100%|██████████| 2590/2590 [01:49<00:00, 23.72it/s]


[LoRA][Multi-Spk] Epoch 16: 22.1575


100%|██████████| 2590/2590 [01:49<00:00, 23.56it/s]


[LoRA][Multi-Spk] Epoch 17: 22.1244


100%|██████████| 2590/2590 [01:47<00:00, 24.16it/s]


[LoRA][Multi-Spk] Epoch 18: 22.0919


100%|██████████| 2590/2590 [01:48<00:00, 23.81it/s]


[LoRA][Multi-Spk] Epoch 19: 22.0585


100%|██████████| 2590/2590 [01:49<00:00, 23.69it/s]


[LoRA][Multi-Spk] Epoch 20: 22.0251


100%|██████████| 2590/2590 [01:49<00:00, 23.62it/s]


[LoRA][Multi-Spk] Epoch 21: 21.9948


100%|██████████| 2590/2590 [01:48<00:00, 23.79it/s]


[LoRA][Multi-Spk] Epoch 22: 21.9667


100%|██████████| 2590/2590 [01:48<00:00, 23.84it/s]


[LoRA][Multi-Spk] Epoch 23: 21.9442


100%|██████████| 2590/2590 [01:49<00:00, 23.61it/s]


[LoRA][Multi-Spk] Epoch 24: 21.9259


100%|██████████| 2590/2590 [01:44<00:00, 24.73it/s]


[LoRA][Multi-Spk] Epoch 25: 21.9110


100%|██████████| 2590/2590 [01:52<00:00, 23.08it/s]


[LoRA][Multi-Spk] Epoch 26: 21.9000


100%|██████████| 2590/2590 [01:50<00:00, 23.37it/s]


[LoRA][Multi-Spk] Epoch 27: 21.8887


100%|██████████| 2590/2590 [01:51<00:00, 23.22it/s]


[LoRA][Multi-Spk] Epoch 28: 21.8792


100%|██████████| 2590/2590 [01:52<00:00, 23.10it/s]

[LoRA][Multi-Spk] Epoch 29: 21.8701



