In [44]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import json
import torch
import librosa
import numpy as np
import webrtcvad
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import soundfile as sf

In [45]:
# Audio / feature settings shared by the cells below.
SAMPLE_RATE = 22050       # Hz; used when loading, analysing and writing audio
N_MELS = 80               # mel bands per frame (also the model's output width)
SPEAKER_EMB_DIM = 64      # size of the learned per-speaker embedding

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [46]:
def inference(model, phoneme_ids, speaker_id):
    """Predict a mel spectrogram for a single utterance.

    Args:
        model: Module called as ``model(text, spk)`` with ``text`` of shape
            (1, T) and ``spk`` of shape (1, 1), returning ``(mel, dur)``.
            It is switched to eval mode and left that way.
        phoneme_ids: Non-empty sequence of integer phoneme ids.
        speaker_id: Integer speaker index.

    Returns:
        The predicted mel with its batch dimension removed, moved to the CPU.
        For ``ResearchTTS_LoRA`` this is (N_MELS, frames).

    Raises:
        ValueError: If ``phoneme_ids`` is empty.
    """
    if len(phoneme_ids) == 0:
        raise ValueError("phoneme_ids must contain at least one phoneme id")

    # Place the inputs wherever the model's weights live, so the call also
    # works for a model that was not moved to the notebook-wide DEVICE.
    # Parameter-free models fall back to DEVICE.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    model.eval()

    with torch.no_grad():
        text = torch.as_tensor(phoneme_ids, dtype=torch.long).unsqueeze(0).to(device)
        # Shape (1, 1): the model squeezes dim 1 after the speaker lookup.
        spk = torch.as_tensor([speaker_id], dtype=torch.long).unsqueeze(0).to(device)

        # The predicted durations are not needed for synthesis.
        pred_mel, _ = model(text, spk)

    return pred_mel.squeeze(0).cpu()


In [47]:
def mel_to_wav(mel, n_iter=60):
    """Invert a log-mel spectrogram to a waveform via librosa.

    Args:
        mel: Log-mel spectrogram as produced by ``wav_to_mel`` or predicted
            by the model; a ``torch.Tensor`` or a NumPy array.
        n_iter: Number of iterations passed to
            ``librosa.feature.inverse.mel_to_audio``.

    Returns:
        The reconstructed waveform at ``SAMPLE_RATE`` as returned by librosa.
    """
    if isinstance(mel, torch.Tensor):
        # .numpy() refuses tensors that live on the GPU or track gradients.
        mel = mel.detach().cpu().numpy()

    # Exact inverse of wav_to_mel's log(mel + 1e-6). Clamp afterwards because
    # rounding can push near-silent bins slightly below zero.
    power_mel = np.maximum(np.exp(mel) - 1e-6, 0.0)

    wav = librosa.feature.inverse.mel_to_audio(
        power_mel,
        sr=SAMPLE_RATE,
        n_iter=n_iter
    )
    return wav


In [48]:
# =========================
# LoRA Linear
# =========================
class LoRALinear(nn.Module):
    """Bias-free linear layer with a frozen base weight and a low-rank update.

    The output is ``x @ weight.T + (x @ A.T @ B.T) * (alpha / r)``.
    ``weight`` has ``requires_grad=False``; ``A`` (r x in_f) and ``B``
    (out_f x r) are trainable.
    """

    def __init__(self, in_f, out_f, r=8, alpha=16):
        super().__init__()
        # Parameters are created in the order weight, A, B so that seeded
        # initialisation draws the same random numbers.
        base_init = torch.randn(out_f, in_f)
        self.weight = nn.Parameter(base_init, requires_grad=False)

        down_init = torch.randn(r, in_f) * 0.01
        self.A = nn.Parameter(down_init)

        up_init = torch.randn(out_f, r) * 0.01
        self.B = nn.Parameter(up_init)

        self.scale = alpha / r

    def forward(self, x):
        """Apply the frozen projection plus the scaled low-rank correction."""
        frozen_out = torch.matmul(x, self.weight.t())
        # Project down to rank r, then back up to out_f.
        low_rank = torch.matmul(torch.matmul(x, self.A.t()), self.B.t())
        return frozen_out + low_rank * self.scale

# =========================
# Model
# =========================
class ResearchTTS_LoRA(nn.Module):
    """Phoneme-to-mel model with a frozen backbone and a LoRA output head.

    Pipeline: phoneme embedding + speaker embedding -> encoder LSTM ->
    per-phoneme duration (Conv1d) -> length regulator -> decoder LSTM ->
    ``LoRALinear`` projection to ``N_MELS`` channels.

    The phoneme embedding and both LSTMs are frozen; the speaker embedding,
    the duration predictor and the LoRA factors stay trainable.
    """

    def __init__(self, vocab, dim=256, num_speakers=1):
        super().__init__()

        self.embed = nn.Embedding(vocab, dim)
        self.spk_embed = nn.Embedding(num_speakers, SPEAKER_EMB_DIM)

        # The encoder sees each phoneme embedding concatenated with the
        # speaker vector, hence the wider input.
        self.encoder = nn.LSTM(
            input_size=dim + SPEAKER_EMB_DIM,
            hidden_size=dim,
            batch_first=True,
        )
        self.duration = nn.Conv1d(dim, 1, kernel_size=3, padding=1)
        self.decoder = nn.LSTM(input_size=dim, hidden_size=dim, batch_first=True)
        self.mel_proj = LoRALinear(dim, N_MELS)

        # Freeze the backbone.
        for backbone_part in (self.embed, self.encoder, self.decoder):
            backbone_part.requires_grad_(False)

    def forward(self, text, speaker_id):
        """Return ``(mel, dur)`` for a batch of phoneme sequences.

        ``mel`` is (batch, N_MELS, frames) after zero-padding every item to
        the longest expanded sequence; ``dur`` is the raw (unrounded)
        duration prediction of shape (batch, T).
        """
        tokens = self.embed(text)
        n_steps = tokens.size(1)

        # One speaker vector per item, broadcast over the time axis.
        voice = self.spk_embed(speaker_id).squeeze(1)
        voice = voice.unsqueeze(1).expand(-1, n_steps, -1)

        hidden, _ = self.encoder(torch.cat([tokens, voice], dim=-1))

        dur = self.duration(hidden.transpose(1, 2)).squeeze(1)

        # Length regulator: repeat each encoder state for its rounded
        # duration, with at least one frame per phoneme.
        stretched = [
            torch.repeat_interleave(
                states,
                torch.clamp(item_dur.round().long(), min=1),
                dim=0,
            )
            for states, item_dur in zip(hidden, dur)
        ]
        padded = torch.nn.utils.rnn.pad_sequence(stretched, batch_first=True)

        decoded, _ = self.decoder(padded)
        mel = self.mel_proj(decoded).transpose(1, 2)

        return mel, dur

In [49]:
def trim_silence(wav, sr, aggressiveness=2):
    """Keep only the 20 ms frames of ``wav`` that WebRTC VAD marks as speech.

    Args:
        wav: 1-D float waveform. Samples are scaled by 32768 to 16-bit PCM
            for the VAD, so values are expected to lie in [-1, 1].
        sr: Sample rate of ``wav`` in Hz.
        aggressiveness: Mode passed to ``webrtcvad.Vad``.

    Returns:
        The voiced frames concatenated in their original order, or ``wav``
        itself when it is empty or no frame is classified as speech. A
        trailing partial frame is never kept.
    """
    wav = np.asarray(wav)
    if wav.size == 0:
        return wav

    sr = int(sr)
    frame_ms = 20

    # webrtcvad rejects every other sample rate. For such input (e.g.
    # 22050 Hz) a 16 kHz copy is used for the speech decision only; the
    # samples that are kept always come from the original ``wav``.
    if sr in (8000, 16000, 32000, 48000):
        vad_sr, vad_wav = sr, wav
    else:
        vad_sr = 16000
        vad_wav = librosa.resample(wav, orig_sr=sr, target_sr=vad_sr)

    # Clip before the cast: +1.0 * 32768 does not fit in int16 and would
    # otherwise wrap around to -32768.
    pcm = np.clip(vad_wav * 32768.0, -32768, 32767).astype(np.int16)

    vad = webrtcvad.Vad(aggressiveness)
    vad_frame = vad_sr * frame_ms // 1000

    kept = []
    for k in range(len(pcm) // vad_frame):
        frame_bytes = pcm[k * vad_frame:(k + 1) * vad_frame].tobytes()
        if vad.is_speech(frame_bytes, vad_sr):
            # Map the frame's time span back onto the original samples.
            start = k * frame_ms * sr // 1000
            stop = (k + 1) * frame_ms * sr // 1000
            kept.append(wav[start:stop])

    return np.concatenate(kept) if kept else wav


def wav_to_mel(wav):
    """Convert a waveform into a log-mel spectrogram tensor.

    The mel spectrogram is computed by librosa at ``SAMPLE_RATE`` with
    ``N_MELS`` bands, offset by 1e-6, passed through a natural log and
    returned as a float32 ``torch`` tensor.
    """
    mel_spec = librosa.feature.melspectrogram(y=wav, sr=SAMPLE_RATE, n_mels=N_MELS)
    # The small offset keeps the log finite for bins that are exactly zero.
    log_mel = np.log(mel_spec + 1e-6)
    return torch.tensor(log_mel, dtype=torch.float32)

class TTSDataset(Dataset):
    """Dataset over a JSON manifest of utterances.

    The manifest is a JSON list of dicts with the keys ``"audio"`` (path to
    an audio file), ``"phonemes"`` (list of integer ids) and ``"speaker"``
    (speaker name). Speaker names are mapped to contiguous ids in sorted
    order and exposed as ``spk2id``.
    """

    def __init__(self, manifest):
        # Read as UTF-8 explicitly: the notebook writes its manifests as
        # UTF-8 with ensure_ascii=False, which the platform default encoding
        # may not decode. The context manager closes the file handle.
        with open(manifest, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        speakers = sorted({d["speaker"] for d in self.data})
        self.spk2id = {s: i for i, s in enumerate(speakers)}

    def __getitem__(self, idx):
        """Return ``(text, duration, mel, speaker_id)`` for one utterance.

        Raises:
            ValueError: If the item has no phonemes, since no per-phoneme
                duration can be derived for it.
        """
        item = self.data[idx]

        text = torch.LongTensor(item["phonemes"])
        if len(text) == 0:
            raise ValueError(f"manifest item {idx} has an empty phoneme list")

        wav, _ = librosa.load(item["audio"], sr=SAMPLE_RATE)
        wav = trim_silence(wav, SAMPLE_RATE)

        mel = wav_to_mel(wav)

        # Pseudo duration: spread the mel frames evenly over the phonemes,
        # with at least one frame each.
        dur_value = max(1, mel.shape[1] // len(text))
        duration = torch.full((len(text),), float(dur_value))

        speaker_id = torch.LongTensor([self.spk2id[item["speaker"]]])

        return text, duration, mel, speaker_id

    def __len__(self):
        return len(self.data)

In [50]:
# Read the inference manifest.
with open("inference_whisper.json", "r", encoding="utf-8") as src:
    inference_whisper = json.load(src)

# Assign every utterance to the speaker "Lisa".
for entry in inference_whisper:
    entry["speaker"] = "Lisa"

# Save the tagged manifest under a new name. ensure_ascii=False keeps
# non-ASCII text readable in the output file.
with open("inference_whisper_Lisa.json", "w", encoding="utf-8") as dst:
    json.dump(inference_whisper, dst, ensure_ascii=False, indent=2)

In [51]:
# Load the model.
# Only the manifest metadata (phoneme ids, speaker names) is used below;
# no audio is read because the dataset is never indexed.
dataset = TTSDataset("inference_whisper_Lisa.json")

# map_location lets a checkpoint saved on a GPU be restored when DEVICE has
# fallen back to "cpu".
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load(
    "research_tts_lora_multispk_whisper.pt",
    map_location=DEVICE
)

# The model must be built with the checkpoint's embedding-table sizes or
# load_state_dict fails, so read them from the saved weights instead of
# hard-coding them.
vocab_size = state_dict["embed.weight"].shape[0]
num_speakers = state_dict["spk_embed.weight"].shape[0]

model = ResearchTTS_LoRA(
    vocab=vocab_size,
    num_speakers=num_speakers
).to(DEVICE)

model.load_state_dict(state_dict)
model.eval()

# An out-of-range id would otherwise surface as an opaque embedding index
# error (a device-side assert on CUDA).
max_token = max(max(d["phonemes"]) for d in dataset.data)
if max_token >= vocab_size:
    raise ValueError(
        f"phoneme id {max_token} is outside the checkpoint vocabulary "
        f"(size {vocab_size})"
    )

# phoneme
phonemes = dataset.data[0]["phonemes"]
speaker_name = dataset.data[0]["speaker"]
# NOTE(review): spk2id is rebuilt from this manifest alone, so this id may
# differ from the id the speaker had when the checkpoint was trained —
# confirm the mapping.
speaker_id = dataset.spk2id[speaker_name]
if speaker_id >= num_speakers:
    raise ValueError(
        f"speaker id {speaker_id} is outside the checkpoint's "
        f"{num_speakers} speakers"
    )

mel = inference(model, phonemes, speaker_id)

wav = mel_to_wav(mel)

# Peak-normalise; skip the division for an all-zero signal so the file does
# not end up filled with NaNs.
peak = np.max(np.abs(wav))
if peak > 0:
    wav = wav / peak
sf.write("inference_whisper_Lisa_output.wav", wav, SAMPLE_RATE)
