# Text To Speech Model

## Google Colab Integration

In [None]:
# Mount Google Drive at /content/drive; the dataset paths defined later in this
# notebook live under that mount point.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


## Installations

In [None]:
!apt-get install sox
!apt-get update
!apt-get install espeak
!pip install unidecode
!pip install inflect
!pip install nltk
!pip install whisperx
!pip install torchaudio
!pip install phonemizer
!pip install nemo_toolkit[tts]

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
sox is already the newest version (14.4.2+git20190427-2+deb11u2ubuntu0.22.04.1).
0 upgraded, 0 newly installed, 0 to remove and 74 not upgraded.
Get:1 https://cli.github.com/packages stable InRelease [3,917 B]
Hit:2 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease
Hit:3 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Hit:4 http://archive.ubuntu.com/ubuntu jammy InRelease
Hit:5 http://security.ubuntu.com/ubuntu jammy-security InRelease
Hit:6 http://archive.ubuntu.com/ubuntu jammy-updates InRelease
Hit:7 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:8 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Hit:9 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Hit:10 https://r2u.stat.illinois.edu/ubuntu jammy InRelease
Hit:11 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubun

Collecting numpy>=1.22 (from nemo_toolkit[tts])
  Using cached numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Using cached numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.0 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.3.2
    Uninstalling numpy-2.3.2:
      Successfully uninstalled numpy-2.3.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
whisperx 3.4.2 requires numpy>=2.0.2, but you have numpy 1.26.4 which is incompatible.
google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.3.2 which is incompatible.
cudf-cu12 25.6.0 requires pandas<2.2.4dev0,>=2.0, but you have pandas 2.3.2 which is incompatible.
opencv-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= "3.9", but you have numpy 1.26.4 which is 

## Imports

In [None]:
import gc
import glob
import os
import pickle
import random
import subprocess
from typing import List, Dict

import librosa
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import tensorflow as tf
import torch
import torchaudio
from nemo.collections.tts.models import HifiGanModel
from phonemizer import phonemize
from phonemizer.backend import EspeakBackend
from tqdm import tqdm
from transformers import Wav2Vec2Model, Wav2Vec2Processor

# Report library versions and the GPUs TensorFlow can see, for reproducibility.
for label, value in (
    ("TensorFlow:", tf.__version__),
    ("NumPy:", np.__version__),
    ("GPUs:", tf.config.list_physical_devices('GPU')),
):
    print(label, value)


TensorFlow: 2.19.0
NumPy: 1.26.4
GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


## Download Dataset

In [None]:
# Define paths
dataset_path = "/content/drive/MyDrive/final_project/Dataset_3"
# dataset_path = "/content/drive/MyDrive/final_project/Dataset_1"
wav_folder = os.path.join(dataset_path, "wav")
stm_folder = os.path.join(dataset_path, "data/stm")
sph_folder = os.path.join(dataset_path, "data/sph")
clips_folder = os.path.join(dataset_path, "clips")
audio_outputs_folder = os.path.join(dataset_path, "audio_outputs")


# Create necessary directories
os.makedirs(wav_folder, exist_ok=True)
os.makedirs(clips_folder, exist_ok=True)
os.makedirs(audio_outputs_folder, exist_ok=True)


## Pre-Processing

### Convert .sph to .wav

In [None]:
def convert_sph_to_wav(sph_folder, wav_folder, max_files=None):
    """Convert the ``*.sph`` files of ``sph_folder`` to ``.wav`` with the ``sox`` CLI.

    Args:
        sph_folder: Directory searched (non-recursively) for ``*.sph`` files.
        wav_folder: Output directory; created if missing. Each output is named
            after the input's basename with a ``.wav`` extension.
        max_files: If not None, only the first ``max_files`` inputs (in sorted
            path order) are converted.

    Inputs whose output file already exists are skipped. A failed ``sox`` run is
    reported and its partial output removed, so a later re-run retries the file
    instead of treating it as already converted.
    """
    os.makedirs(wav_folder, exist_ok=True)
    # Sorted so that the ``max_files`` subset is the same on every run
    # (glob order is arbitrary).
    sph_files = sorted(glob.glob(os.path.join(sph_folder, "*.sph")))
    total_files = len(sph_files)  # remembered before truncation for the message
    if max_files is not None:
        sph_files = sph_files[:max_files]
        print(f"Processing {len(sph_files)} files out of {total_files} total files")
    for sph_file in tqdm(sph_files, desc="Converting .sph to .wav"):
        basename = os.path.splitext(os.path.basename(sph_file))[0]
        wav_path = os.path.join(wav_folder, f"{basename}.wav")
        if os.path.exists(wav_path):
            continue
        result = subprocess.run(["sox", sph_file, wav_path], capture_output=True, text=True)
        if result.returncode != 0:
            print(f"[sox-error] {sph_file}: {result.stderr.strip()}")
            if os.path.exists(wav_path):
                os.remove(wav_path)

### Parse .stm files

In [None]:
def parse_stm_file(stm_path):
    """Parse one ``.stm`` transcript file into segment records.

    Each usable line is split on runs of whitespace into at most seven fields:
    field 0 is the talk id, field 1 is stored as ``speaker``, fields 3 and 4 are
    the start/end times in seconds, field 5 is ignored and everything after it
    is the transcript text (internal spacing preserved).

    Blank lines, lines starting with ``;;``, lines with fewer than five fields
    and lines whose times are not numeric are skipped instead of raising.

    NOTE(review): in the NIST STM layout field 1 is usually the channel and
    field 2 the speaker -- confirm which one callers expect as ``speaker``.

    Args:
        stm_path: Path of the ``.stm`` file.

    Returns:
        ``(segments, segments_dict)`` where ``segments`` is a list of
        ``(talk_id, speaker, start, end, text)`` tuples in file order and
        ``segments_dict`` maps ``talk_id -> {"<start:.2f>_<end:.2f>": text}``.
    """
    segments = []
    segments_dict = {}
    with open(stm_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith(';;'):
                continue
            # maxsplit=6 keeps the transcript (7th field) intact as one string.
            parts = line.split(None, 6)
            if len(parts) < 5:
                continue
            try:
                start, end = float(parts[3]), float(parts[4])
            except ValueError:
                continue
            talk_id, speaker = parts[0], parts[1]
            text = parts[6] if len(parts) > 6 else ''
            segments.append((talk_id, speaker, start, end, text))
            segments_dict.setdefault(talk_id, {})[f'{start:.2f}_{end:.2f}'] = text

    return segments, segments_dict


### Split wavs to small clips

In [None]:
def extract_audio_segments(wav_folder, segments, output_folder):
    """Cut every segment out of its full-talk wav and save it as its own clip.

    Args:
        wav_folder: Directory holding ``<talk_id>.wav`` files.
        segments: Iterable of ``(talk_id, speaker, start, end, text)`` tuples;
            ``start``/``end`` are in seconds.
        output_folder: Directory for the clips; created if missing. Clips are
            written as ``<talk_id>_<start:.2f>_<end:.2f>.wav`` at the source
            file's own sample rate.

    Segments whose talk wav is missing are reported and skipped.
    """
    os.makedirs(output_folder, exist_ok=True)
    # Keep the most recently decoded talk so consecutive segments of the same
    # talk do not decode the whole file again.
    loaded_path, loaded_audio, loaded_sr = None, None, None
    for (talk_id, speaker, start, end, text) in tqdm(segments, desc="Segmenting audio"):
        wav_path = os.path.join(wav_folder, f"{talk_id}.wav")
        if wav_path != loaded_path:
            if not os.path.exists(wav_path):
                print(f"Skipping {wav_path} because it does not exist")
                continue
            loaded_audio, loaded_sr = librosa.load(wav_path, sr=None)
            loaded_path = wav_path
        start_sample = int(start * loaded_sr)
        end_sample = int(end * loaded_sr)
        segment = loaded_audio[start_sample:end_sample]
        output_path = os.path.join(output_folder, f"{talk_id}_{start:.2f}_{end:.2f}.wav")
        sf.write(output_path, segment, loaded_sr)


### Convert wav to mel spectrograms

In [None]:
def wav_to_mel(wav_path, sr=22050, n_fft=1024, hop_length=256, win_length=1024, n_mels=80, fmin=0, fmax=8000):
    """Load a wav file and return its dB-scaled mel spectrogram.

    The audio is loaded as mono and resampled to ``sr``; a power (``power=2.0``)
    mel spectrogram is computed and converted to dB with ``ref=1.0``.

    Returns:
        ``np.float32`` array of shape ``[n_mels, T]``.
    """
    waveform, _ = librosa.load(wav_path, sr=sr, mono=True)
    mel_kwargs = dict(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax,
    )
    power_mel = librosa.feature.melspectrogram(y=waveform, sr=sr, power=2.0, **mel_kwargs)
    return librosa.power_to_db(power_mel, ref=1.0).astype(np.float32)


In [None]:
# Load the pretrained FastPitch spectrogram generator used by text_to_mel.
# NOTE(review): a later cell of this notebook loads the same checkpoint again
# into the same name; one of the two loads is redundant.
from nemo.collections.tts.models import FastPitchModel
spec_generator = FastPitchModel.from_pretrained("nvidia/tts_en_fastpitch")

 NeMo-text-processing :: INFO     :: Creating ClassifyFst grammars.
INFO:NeMo-text-processing:Creating ClassifyFst grammars.
[NeMo W 2025-08-21 18:00:36 nemo_logging:405] apply_to_oov_word=None, This means that some of words will remain unchanged if they are not handled by any of the rules in self.parse_one_word(). This may be intended if phonemes and chars are both valid inputs, otherwise, you may see unexpected deletions in your input.
[NeMo W 2025-08-21 18:00:36 nemo_logging:405] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    dataset:
      _target_: nemo.collections.tts.torch.data.TTSDataset
      manifest_filepath: /ws/LJSpeech/nvidia_ljspeech_train_clean_ngc.json
      sample_rate: 22050
      sup_data_path: /raid/LJSpeech/supplementary
      sup_data_types:
      - align_prior_matrix
      - pitch
      n_fft: 1024
      win_length: 10

[NeMo I 2025-08-21 18:00:36 nemo_logging:393] PADDING: 1
[NeMo I 2025-08-21 18:00:37 nemo_logging:393] Model FastPitchModel was successfully restored from /root/.cache/huggingface/hub/models--nvidia--tts_en_fastpitch/snapshots/2c8305b7b41b33fd6367f0635796dc3a7a33cbf9/tts_en_fastpitch.nemo.


In [None]:
def text_to_mel(text):
    """Return the raw FastPitch spectrogram for ``text``, exactly as the model emits it.

    NOTE(review): ``text_to_mel`` is defined again in later cells of this
    notebook; when cells run top to bottom those definitions replace this one.
    """
    token_ids = spec_generator.parse(text)
    return spec_generator.generate_spectrogram(tokens=token_ids)

In [None]:
import numpy as np
import torch
from nemo.collections.tts.models import FastPitchModel

# NOTE(review): this repeats the FastPitch load already done in an earlier cell
# of this notebook and rebinds the same name.
spec_generator = FastPitchModel.from_pretrained("nvidia/tts_en_fastpitch")
# Put the generator in evaluation mode before it is used for synthesis.
spec_generator.eval()

@torch.no_grad()
def text_to_mel(text):
    """Synthesize ``text`` with FastPitch and return the mel as a NumPy array.

    A 3-D result has its first (batch) entry selected; a 2-D result whose
    second axis (but not its first) has size 80 is transposed so the 80-sized
    axis comes first.

    Returns:
        ``np.float32`` array, moved to CPU.

    NOTE(review): a later cell of this notebook defines ``text_to_mel`` again
    (returning a torch tensor); that definition wins when cells run in order.
    """
    generated = spec_generator.generate_spectrogram(tokens=spec_generator.parse(text))

    # Drop the batch axis when present.
    if generated.dim() == 3:
        generated = generated[0]

    # Put the 80-channel axis first if it arrived second.
    channels_last = generated.shape[0] != 80 and generated.shape[1] == 80
    if channels_last:
        generated = generated.transpose(0, 1)

    return generated.detach().cpu().numpy().astype(np.float32)


 NeMo-text-processing :: INFO     :: Creating ClassifyFst grammars.
INFO:NeMo-text-processing:Creating ClassifyFst grammars.
[NeMo W 2025-08-21 18:01:15 nemo_logging:405] apply_to_oov_word=None, This means that some of words will remain unchanged if they are not handled by any of the rules in self.parse_one_word(). This may be intended if phonemes and chars are both valid inputs, otherwise, you may see unexpected deletions in your input.
[NeMo W 2025-08-21 18:01:15 nemo_logging:405] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    dataset:
      _target_: nemo.collections.tts.torch.data.TTSDataset
      manifest_filepath: /ws/LJSpeech/nvidia_ljspeech_train_clean_ngc.json
      sample_rate: 22050
      sup_data_path: /raid/LJSpeech/supplementary
      sup_data_types:
      - align_prior_matrix
      - pitch
      n_fft: 1024
      win_length: 10

[NeMo I 2025-08-21 18:01:15 nemo_logging:393] PADDING: 1
[NeMo I 2025-08-21 18:01:16 nemo_logging:393] Model FastPitchModel was successfully restored from /root/.cache/huggingface/hub/models--nvidia--tts_en_fastpitch/snapshots/2c8305b7b41b33fd6367f0635796dc3a7a33cbf9/tts_en_fastpitch.nemo.


In [None]:
import torch
@torch.no_grad()
def text_to_mel(text):
    """Synthesize ``text`` with FastPitch and return the mel as a float tensor.

    The generator is switched to eval mode on every call. A 3-D result has its
    first (batch) entry selected, then the two remaining axes are swapped.

    Returns:
        Contiguous float tensor; with a ``[B, 80, T]`` model output this is
        ``[T, 80]``. No device transfer happens here, so the tensor stays on
        whatever device the generator produced it on.
    """
    spec_generator.eval()
    token_ids = spec_generator.parse(text)
    spectrogram = spec_generator.generate_spectrogram(tokens=token_ids)
    if spectrogram.dim() == 3:
        spectrogram = spectrogram[0]
    time_major = spectrogram.transpose(0, 1).contiguous()
    return time_major.float()


### Convert all wav files in a folder to mel spectrograms

In [None]:
def create_full_talks_mel(wav_folder):
    """Compute a mel spectrogram for every ``*.wav`` file directly inside ``wav_folder``.

    Args:
        wav_folder: Directory holding the full-talk wav files.

    Returns:
        Dict mapping talk id (file name without its extension) to the array
        returned by ``wav_to_mel`` for that file.
    """
    talks_paths = sorted(glob.glob(os.path.join(wav_folder, "*.wav")))
    # Summarise instead of dumping every path into the cell output.
    print(f"Found {len(talks_paths)} full-talk wav files in {wav_folder}")
    full_talks_mel = {}
    for talk_path in tqdm(talks_paths, desc="Creating full talks mel"):
        # splitext (not split(".")) so ids containing dots survive, matching
        # how convert_sph_to_wav names the wav files.
        talk_id = os.path.splitext(os.path.basename(talk_path))[0]
        full_talks_mel[talk_id] = wav_to_mel(talk_path)
    return full_talks_mel


### Dataset Creation

In [None]:
def preprocess_tedlium(base_path, target_size=10000, max_files=None, output_path='custom_data_phoneme', _convert_sph_to_wav=True):
    """Run the preprocessing pipeline: sph -> wav, parse stm, mel per talk, cut clips.

    Args:
        base_path: Dataset root containing ``data/sph`` and ``data/stm``.
        target_size: Maximum number of segments to cut into clips.
        max_files: Forwarded to ``convert_sph_to_wav``.
        output_path: Root for the generated ``wavs`` and ``clips`` folders.
        _convert_sph_to_wav: When False the conversion step is skipped and the
            wavs already present under ``<output_path>/wavs`` are used.

    Returns:
        ``(custom_clip_folder, all_segments_dict, full_talks_mel)``. Only the
        list of segments that gets cut into clips is limited to
        ``target_size``; ``all_segments_dict`` still holds every parsed segment.
    """
    sph_folder = os.path.join(base_path, "data/sph")
    stm_folder = os.path.join(base_path, "data/stm")
    custom_wav_folder = f"{output_path}/wavs"
    custom_clip_folder = f"{output_path}/clips"

    os.makedirs(custom_wav_folder, exist_ok=True)
    os.makedirs(custom_clip_folder, exist_ok=True)

    # Convert SPH to WAV
    if _convert_sph_to_wav:
        convert_sph_to_wav(sph_folder, custom_wav_folder, max_files=max_files)

    def _stem(path):
        # File name without directory or extension.
        return os.path.splitext(os.path.basename(path))[0]

    # Keep only the transcripts whose talk has a converted wav. Sorted so the
    # target_size truncation below picks the same segments on every run.
    available_talks = {_stem(p) for p in glob.glob(os.path.join(custom_wav_folder, "*.wav"))}
    stm_files = [
        stm_file
        for stm_file in sorted(glob.glob(os.path.join(stm_folder, "*.stm")))
        if _stem(stm_file) in available_talks
    ]

    all_segments = []
    all_segments_dict = {}

    for stm_path in stm_files:
        segments, segments_dict = parse_stm_file(stm_path)
        all_segments.extend(segments)
        # Merge per talk so two files mentioning the same talk id do not
        # overwrite each other's segments.
        for talk_id, ranges in segments_dict.items():
            all_segments_dict.setdefault(talk_id, {}).update(ranges)

    # Create full talks mel
    full_talks_mel = create_full_talks_mel(custom_wav_folder)

    # Limit dataset size
    all_segments = all_segments[:target_size]

    # Extract audio clips
    extract_audio_segments(custom_wav_folder, all_segments, custom_clip_folder)

    print(f"Preprocessing done. {len(all_segments)} clips created.")

    return custom_clip_folder, all_segments_dict, full_talks_mel


In [None]:
def create_dataset(clips_folder, all_segments, full_talks_mel: Dict[str, np.ndarray], full_talks_folder: str, processor: Wav2Vec2Processor, specific_clips: List[str]=None, shuffle=False):
    """Build ``(full_talk_wav_path, clip_path, mel, token_ids)`` records for the clips.

    Clip file names are expected to look like ``<talk_id>_<start>_<end>.wav``;
    the transcript is looked up as ``all_segments[talk_id]["<start:.2f>_<end:.2f>"]``.

    Args:
        clips_folder: Directory with the clip wav files.
        all_segments: Mapping ``talk_id -> {time_range: text}``.
        full_talks_mel: Not read by this function.
        full_talks_folder: Folder used to build the full-talk wav path of each record.
        processor: Its ``tokenizer`` turns the transcript into token ids.
        specific_clips: Optional collection of talk ids; when truthy only clips
            of those talks are kept.
        shuffle: When True the records are shuffled in place with ``random.shuffle``.

    Returns:
        List of ``(full_talk_wav_path, clip_path, mel, token_ids)`` tuples.
    """
    def _name_fields(path):
        # "<talk_id>_<start>_<end>.wav" split on underscores.
        return os.path.basename(path).split("_")

    clip_paths = glob.glob(os.path.join(clips_folder, "*.wav"))
    if specific_clips:
        clip_paths = [
            path for path in clip_paths
            if '_'.join(_name_fields(path)[:-2]) in specific_clips
        ]

    records = []
    for clip_path in tqdm(clip_paths, desc="Creating dataset"):
        mel = wav_to_mel(clip_path)
        fields = _name_fields(clip_path)
        talk_id = '_'.join(fields[:-2])
        start = float(fields[-2])
        end = float(fields[-1].replace(".wav", ""))
        text = all_segments[talk_id][f'{start:.2f}_{end:.2f}']
        token_ids = processor.tokenizer(text, return_tensors="pt").input_ids.squeeze(0).numpy()
        records.append((os.path.join(full_talks_folder, f"{talk_id}.wav"), clip_path, mel, token_ids))

    if shuffle:
        random.shuffle(records)

    return records




### Make a word-based dataset

In [None]:
import os, glob, json, re, csv, math
import soundfile as sf
import numpy as np
import torch
import torchaudio
from phonemizer import phonemize
from phonemizer.backend import EspeakBackend

# ---------- helpers ----------

# Matches tokens made only of ASCII letters and apostrophes; used with
# fullmatch() to drop aligned tokens that contain anything else (e.g. digits).
_word_re = re.compile(r"[a-zA-Z']+")

def normalize_text_for_alignment(text: str) -> str:
    """Lower-case ``text`` and reduce it to letters, digits, apostrophes and single spaces.

    Every run of other characters becomes one space, whitespace runs are
    collapsed, and leading/trailing whitespace is removed.
    """
    lowered = text.strip().lower()
    # Apostrophes are kept so contractions stay intact.
    allowed_only = re.sub(r"[^a-z0-9' ]+", " ", lowered)
    return re.sub(r"\s+", " ", allowed_only).strip()

def slice_audio(audio, sr, t0, t1, pad=0.0):
    """Return ``audio`` between ``t0 - pad`` and ``t1 + pad`` seconds.

    Times are converted to sample indices with ``sr`` and clamped to
    ``[0, len(audio)]``.
    """
    first = int((t0 - pad) * sr)
    last = int((t1 + pad) * sr)
    if first < 0:
        first = 0
    if last > len(audio):
        last = len(audio)
    return audio[first:last]

# Lazily created espeak backend, shared by all maybe_phonemes calls so it is
# not rebuilt for every word.
_espeak_backend_en_us = None


def maybe_phonemes(word: str) -> str:
    """Best-effort IPA transcription of ``word`` using espeak (``en-us``).

    The ``backend`` argument of ``phonemize()`` takes a backend *name*, so the
    backend object is used directly through its own ``phonemize`` method.

    Args:
        word: Text to transcribe.

    Returns:
        The phoneme string, or ``""`` when ``word`` is blank or phonemization
        fails. Failures stay non-fatal but are surfaced as a warning rather
        than being swallowed silently.
    """
    import warnings

    global _espeak_backend_en_us
    if not word or not word.strip():
        return ""
    try:
        if _espeak_backend_en_us is None:
            _espeak_backend_en_us = EspeakBackend(language='en-us')
        # The backend API works on a list of lines and returns a list.
        phonemized = _espeak_backend_en_us.phonemize([word], strip=True, njobs=1)
        return phonemized[0] if phonemized else ""
    except Exception as exc:
        warnings.warn(f"phonemization failed ({type(exc).__name__}: {exc}); returning empty phonemes")
        return ""

# ---------- WhisperX alignment ----------

class WhisperXAligner:
    """Word-level forced alignment of a known sentence against its audio clip via WhisperX."""

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu", compute_type="float16"):
        """Load the WhisperX ASR model and the English alignment model.

        Args:
            device: Device for both models; defaults to CUDA when available.
            compute_type: Compute type for the ASR model. ``"float16"`` is
                replaced by ``"int8"`` on CPU, where float16 is not supported,
                so the CPU fallback of ``device`` can actually load.

        NOTE(review): ``self.model`` is loaded but not used by
        ``align_sentence`` in this class -- confirm whether it is needed.
        """
        import whisperx
        if str(device).startswith("cpu") and compute_type == "float16":
            compute_type = "int8"
        self.device = device
        self.model = whisperx.load_model("large-v2", device, compute_type=compute_type)
        self.alignment_model, self.metadata = whisperx.load_align_model(language_code="en", device=device)

    def align_sentence(self, wav_path: str, sentence_text: str):
        """Align the words of ``sentence_text`` to the audio in ``wav_path``.

        No transcription is run: one pseudo-segment spanning the whole clip and
        carrying the normalized reference text is passed to ``whisperx.align``.

        Returns:
            List of ``{"word": str, "start": float, "end": float}`` dicts, in
            seconds relative to the clip. Tokens that are not purely letters
            and apostrophes, or that came back without timings, are left out.
            An empty list is returned when the normalized text is empty.
        """
        import whisperx

        norm_text = normalize_text_for_alignment(sentence_text)
        if not norm_text:
            return []

        audio = whisperx.load_audio(wav_path)
        # Duration assumes 16 kHz samples from load_audio.
        duration = len(audio) / 16000.0

        # One pseudo-segment covering the full utterance with the gold text.
        segments = [{"start": 0.0, "end": duration, "text": norm_text}]

        # Align with the CTC model to obtain word times.
        aligned = whisperx.align(segments, self.alignment_model, self.metadata, audio, self.device)

        words = []
        for seg in aligned["segments"]:
            for w in seg.get("words", []):
                token = w.get("word", "").strip()
                if not _word_re.fullmatch(token):
                    continue
                # Words the aligner could not place may lack timing keys; skip
                # just that word rather than failing the whole sentence.
                start, end = w.get("start"), w.get("end")
                if start is None or end is None:
                    continue
                words.append({"word": token, "start": start, "end": end})
        return words

# ---------- main word dataset builder ----------

def build_word_dataset_from_sentence_clips(
    clips_folder: str,
    segments_dict: dict,
    out_folder: str = f"{dataset_path}/custom_data/words",
    pad_seconds: float = 0.0,
    min_word_dur: float = 0.05,
    max_word_dur: float = 1.5,
    skip_short_tokens=False
):
    """Split sentence-level clips into single-word wavs and write a metadata CSV.

    Args:
        clips_folder: Folder of sentence clips named
            ``<talk_id>_<start>_<end>.wav``.
        segments_dict: Mapping ``talk_id -> {"<start:.2f>_<end:.2f>": text}``.
            A dict value holding the text under ``"text"``, ``"sentence"`` or
            ``"transcript"`` is also accepted.
        out_folder: Output root; word wavs go to ``<out_folder>/wavs`` and the
            pipe-delimited CSV to ``<out_folder>/metadata_words.csv``.
        pad_seconds: Padding added on both sides of each word when slicing.
        min_word_dur, max_word_dur: Duration bounds in seconds; only enforced
            when ``skip_short_tokens`` is True.
        skip_short_tokens: Enables the duration filter above.

    Clips with an unparsable name, no transcript or a failed alignment are
    reported and skipped. The ``speaker`` CSV column is written empty.
    """
    os.makedirs(out_folder, exist_ok=True)
    word_wavs = os.path.join(out_folder, "wavs")
    os.makedirs(word_wavs, exist_ok=True)

    aligner = WhisperXAligner()

    metadata_csv = os.path.join(out_folder, "metadata_words.csv")
    with open(metadata_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, delimiter="|")
        writer.writerow(["path", "word", "phonemes", "speaker", "talk_id", "sentence_id", "start", "end", "dur"])

        for clip_path in sorted(glob.glob(os.path.join(clips_folder, "*.wav"))):
            base = os.path.splitext(os.path.basename(clip_path))[0]
            parts = base.split("_")
            if len(parts) < 3:
                print(f"[skip] bad filename: {base}")
                continue

            try:
                start = float(parts[-2])
                end = float(parts[-1])
            except ValueError:
                # A stray wav with non-numeric fields must not abort the run.
                print(f"[skip] bad filename: {base}")
                continue
            talk_id = "_".join(parts[:-2])
            # Must match the key format used when the segments were parsed.
            time_range = f"{start:.2f}_{end:.2f}"

            text_entry = segments_dict.get(talk_id, {}).get(time_range)
            if isinstance(text_entry, dict):
                text = text_entry.get("text") or text_entry.get("sentence") or text_entry.get("transcript") or ""
            else:
                text = text_entry or ""

            if not text:
                print(f"[skip] no text for talk_id={talk_id} time={time_range}")
                continue

            # The clip's base name doubles as a stable sentence id in the CSV.
            sentence_id = base

            # Align
            try:
                words = aligner.align_sentence(clip_path, text)
            except Exception as e:
                print(f"[align-error] {sentence_id}: {e}")
                continue

            # Load audio once per clip; down-mix multi-channel audio to mono.
            audio, sr = sf.read(clip_path)
            if audio.ndim > 1:
                audio = np.mean(audio, axis=1)

            for j, w in enumerate(words):
                t0, t1 = w.get("start"), w.get("end")
                if t0 is None or t1 is None:
                    continue
                dur = float(t1 - t0)
                if skip_short_tokens and (dur < min_word_dur or dur > max_word_dur):
                    continue

                word_audio = slice_audio(audio, sr, t0, t1, pad=pad_seconds)
                if len(word_audio) < int(0.002 * sr):  # drop slices shorter than 2 ms
                    continue

                token = w["word"]
                phone = maybe_phonemes(token)

                # talk_id + time_range + word index keeps names unique and stable.
                out_name = f"{talk_id}_{time_range}__{j:03d}_{token}.wav"
                out_path = os.path.join(word_wavs, out_name)
                sf.write(out_path, word_audio, sr)

                writer.writerow([
                    os.path.relpath(out_path, out_folder),  # path
                    token,                                  # word
                    phone,                                  # phonemes
                    "",                                     # speaker (not available here)
                    talk_id,                                # talk_id
                    sentence_id,                            # sentence_id == clip base name
                    f"{t0:.3f}", f"{t1:.3f}", f"{dur:.3f}",
                ])
    print(f"Done. Wrote single-word clips + {metadata_csv}")



### Dataset Convertor

In [None]:
# Simplified TextMelDataset (same interface, minimal logic)
import os
import csv
import random
import numpy as np
import torch
from torch.utils.data import Dataset


class TextMelDataset(Dataset):
    """Word-level dataset yielding ``(text_ids, mel)`` pairs from a metadata CSV.

    The CSV is pipe-delimited with a header row; the columns read here are
    ``path``, ``word``, ``phonemes``, ``talk_id`` and ``sentence_id``.
    """

    def __init__(
        self,
        words_folder: str,
        metadata_csv: str,
        specific_items=None,  # list of talk_id or sentence_id strings
        shuffle: bool = False,
        use_phonemes: bool = True,
    ):
        """Load and optionally filter/shuffle the metadata rows.

        Args:
            words_folder: Folder that the ``path`` column is relative to.
            metadata_csv: Path of the metadata CSV.
            specific_items: Optional talk ids and/or sentence ids; when truthy
                only rows matching one of them are kept.
            shuffle: When True the access order is shuffled with a fixed seed
                (1234), using a private RNG so the global ``random`` state is
                left untouched.
            use_phonemes: Prefer the ``phonemes`` column over ``word`` as text.

        Raises:
            FileNotFoundError: If ``metadata_csv`` does not exist.
        """
        self.mel_dim = 80
        self.words_folder = words_folder
        self.use_phonemes = use_phonemes

        # Load all rows once
        if not os.path.exists(metadata_csv):
            raise FileNotFoundError(f"metadata csv not found: {metadata_csv}")
        with open(metadata_csv, "r", encoding="utf-8") as f:
            rows = list(csv.DictReader(f, delimiter="|"))

        # Optional filter by talk_id or sentence_id
        if specific_items:
            keep = set(specific_items)
            rows = [
                r for r in rows
                if (r.get("talk_id") in keep) or (r.get("sentence_id") in keep)
            ]

        self.rows = rows

        # Build index order
        self.indices = list(range(len(self.rows)))
        if shuffle:
            # Same permutation as seeding the module RNG with 1234, without
            # reseeding the global generator as a side effect.
            random.Random(1234).shuffle(self.indices)

    def __len__(self):
        """Number of rows kept after filtering."""
        return len(self.rows)

    def _load_mel(self, wav_path: str) -> torch.Tensor:
        """Create a ``[T, mel_dim]`` FloatTensor using the global ``wav_to_mel``.

        Not called by ``__getitem__`` in this class.

        Raises:
            RuntimeError: If the mel's channel count differs from ``mel_dim``.
        """
        mel = wav_to_mel(wav_path)         # [80, T] np.float32 (dB)
        mel = torch.from_numpy(mel.T).float()  # [T, 80]
        if mel.size(1) != self.mel_dim:
            raise RuntimeError(f"Mel dim mismatch: {mel.size(1)} != {self.mel_dim}")
        return mel

    def _get_text(self, row) -> torch.Tensor:
        """Encode the row's text (phonemes if enabled and present, else the word).

        NOTE(review): ``text_to_sequence`` as defined in this notebook drops
        every character outside its ASCII symbol table, so IPA phoneme strings
        would lose most of their symbols -- confirm the intended symbol set.
        """
        word = (row.get("word") or "").strip()
        phon = (row.get("phonemes") or "").strip()
        text = phon if (self.use_phonemes and phon) else word
        if not text:
            text = word or phon or ""
        ids = text_to_sequence(text)
        return torch.IntTensor(ids)

    def __getitem__(self, index):
        """Return ``(text_ids, mel)`` for the ``index``-th item in access order.

        The word's wav must exist, but the mel target is generated from the
        word text with the global ``text_to_mel`` rather than from that audio.

        Raises:
            FileNotFoundError: If the row's wav file is missing.
        """
        row = self.rows[self.indices[index]]
        rel_path = (row.get("path") or "").strip()
        wav_path = os.path.join(self.words_folder, rel_path)
        if not os.path.exists(wav_path):
            raise FileNotFoundError(f"Missing wav: {wav_path}")
        word = (row.get("word") or "").strip()
        mel = text_to_mel(word)
        text_ids = self._get_text(row)
        return text_ids, mel


In [None]:
class TextMelCollate():
    """Batch collator: sorts samples by text length and zero-pads text and mel.

    ``r`` is the number of frames per decoder step; the padded mel length is
    rounded up to a multiple of it.
    """

    def __init__(self, r):
        self.r = r

    def __call__(self, batch):
        """Collate a list of ``(text_ids, mel)`` samples.

        PARAMS
        ------
        batch: list of ``(text_ids, mel)`` with ``text_ids`` a 1-D tensor and
            ``mel`` a ``[T, n_mels]`` tensor.

        RETURNS
        -------
        ``(text_padded, input_lengths, mel_padded, gate_padded, output_lengths)``,
        all ordered by decreasing text length. ``gate_padded`` is 1 from each
        sample's last real frame onwards and 0 before it.
        """
        batch_size = len(batch)

        # Order samples by decreasing text length.
        text_lengths = torch.LongTensor([len(sample[0]) for sample in batch])
        input_lengths, ids_sorted_decreasing = torch.sort(text_lengths, dim=0, descending=True)
        order = ids_sorted_decreasing.tolist()

        text_padded = torch.zeros(batch_size, int(input_lengths[0]), dtype=torch.long)

        # Longest mel, rounded up to a multiple of the frames-per-step factor.
        num_mels = batch[0][1].size(1)
        max_target_len = max(sample[1].size(0) for sample in batch)
        remainder = max_target_len % self.r
        if remainder != 0:
            max_target_len += self.r - remainder

        mel_padded = torch.zeros(batch_size, max_target_len, num_mels, dtype=torch.float32)
        gate_padded = torch.zeros(batch_size, max_target_len, dtype=torch.float32)
        output_lengths = torch.zeros(batch_size, dtype=torch.long)

        for row, source in enumerate(order):
            text, mel = batch[source][0], batch[source][1]
            n_frames = mel.size(0)
            text_padded[row, :text.size(0)] = text
            mel_padded[row, :n_frames, :] = mel
            gate_padded[row, n_frames - 1:] = 1
            output_lengths[row] = n_frames

        return text_padded, input_lengths, mel_padded, gate_padded, output_lengths


### Mel Spectrogram Figure

In [None]:
%matplotlib inline

In [None]:
def show_or_save_mel_spectrogram(mel, title="Mel Spectrogram", save_path=None, cmap="magma"):
    """
    Display or save a Mel spectrogram as an image.

    Args:
        mel: 2D or 3D Mel spectrogram laid out as (time, mel) or
            (batch, time, mel); for 3D input only the first batch entry is
            drawn. The array is transposed for plotting so time runs along x.
        title (str): Title of the plot.
        save_path (str): If provided, the figure is saved to this path (parent
            directories are created when the path has any) and then closed;
            otherwise the figure is shown.
        cmap (str): Colormap for visualization (e.g., "magma", "inferno", "viridis").
    """
    if mel.ndim == 3:
        mel = mel[0]  # Remove batch dimension if present

    plt.figure(figsize=(10, 4))
    plt.imshow(mel.T, aspect='auto', origin='lower', cmap=cmap)
    plt.title(title)
    plt.xlabel("Time")
    plt.ylabel("Mel Frequency Channels")
    plt.colorbar(format='%+2.0f dB')

    if save_path:
        # A bare file name has an empty dirname, which os.makedirs rejects.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        plt.savefig(save_path)
        print(f"Saved mel spectrogram to {save_path}")
        plt.close()
    else:
        plt.show()


## Model Architecture

### Set Logger

In [None]:
import copy

# Default hyper-parameters, grouped by sub-network; create_model() passes a
# deep copy of this dict to the model as ``model_cfg``.
DEFAULT_T2_CFG = dict(
    encoder=dict(
        num_convs=3,
        conv_channels=512,
        conv_kernel_size=5,
        conv_dropout=0.5,
        blstm_units=512,
    ),
    decoder=dict(
        prenet_dims=[256, 256],
        prenet_dropout=0.5,
        attention_dim=128,
        attention_rnn_units=1024,
        attention_dropout=0.1,
        attention_location_filters=32,
        attention_location_kernel_size=31,
        decoder_rnn_units=1024,
        decoder_rnn_layers=2,
        decoder_dropout=0.1,
    ),
    postnet=dict(
        num_convs=5,
        conv_channels=512,
        conv_kernel_size=5,
        conv_dropout=0.5,
    ),
)


def create_model():
    """
    Build the TextToMelSpectrogramModel and its loss.

    The model receives its own deep copy of DEFAULT_T2_CFG, so the shared
    default dict is never handed to the model directly.

    Returns:
        ``(model, criterion)``.
    """
    model = TextToMelSpectrogramModel(
        model_cfg=copy.deepcopy(DEFAULT_T2_CFG),
        embed_dim=512,
        mel_dim=80,
        max_decoder_steps=1000,
        stop_threshold=0.5,
        r=3,
    )
    return model, TextToMelSpectrogramLoss()


In [None]:
import re
from typing import List

# Symbol inventory for the text front-end. The position of a symbol in
# EN_SYMBOLS is its integer id, so padding ("_") is id 0 and end-of-sequence
# ("~") is id 1.
_pad = "_"
_eos = "~"
_ascii_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'\"(),-.:;? %/"
_digits = "0123456789"

EN_SYMBOLS: List[str] = [_pad, _eos, *_ascii_chars, *_digits]

# Forward and reverse lookup tables between symbols and ids.
_id_to_symbol = dict(enumerate(EN_SYMBOLS))
_symbol_to_id = {symbol: index for index, symbol in _id_to_symbol.items()}

_whitespace_re = re.compile(r"\s+")

def basic_cleaners(text: str) -> str:
    """Lower-case ``text``, collapse each whitespace run to one space and trim the ends."""
    collapsed = re.sub(r"\s+", " ", text.lower())
    return collapsed.strip()


def symbols(lang: str = "en") -> List[str]:
    """Return the symbol list; ``lang`` is accepted but not used, EN_SYMBOLS is always returned."""
    return EN_SYMBOLS


def _symbols_to_sequence(chars: List[str]) -> List[int]:
    """Map symbols to ids, silently skipping any symbol missing from the table."""
    ids = []
    for symbol in chars:
        if symbol in _symbol_to_id:
            ids.append(_symbol_to_id[symbol])
    return ids


def text_to_sequence(text: str) -> List[int]:
    """Clean ``text``, encode it character by character and append the end-of-sequence id.

    Characters without an id are dropped by ``_symbols_to_sequence``.
    """
    cleaned = basic_cleaners(text)
    encoded = _symbols_to_sequence(list(cleaned))
    return encoded + [_symbol_to_id[_eos]]


In [None]:
from torch.utils.tensorboard import SummaryWriter

class TextToMelSpectrogramLogger(SummaryWriter):
    """SummaryWriter with helpers for the scalars logged during training and validation."""

    def __init__(self, logdir: str):
        super().__init__(logdir)

    def log_training(self, loss, grad_norm, learning_rate, duration, iteration: int):
        """Log loss, gradient norm, learning rate and step duration for ``iteration``."""
        scalars = (
            ("training.loss", loss),
            ("grad.norm", grad_norm),
            ("learning.rate", learning_rate),
            ("duration", duration),
        )
        for tag, value in scalars:
            self.add_scalar(tag, float(value), iteration)

    def log_validation(self, loss, model, targets, predicts, iteration: int):
        """Log the validation loss; ``model``, ``targets`` and ``predicts`` are accepted but not used."""
        self.add_scalar("validation.loss", float(loss), iteration)

### Create TextToMelSpectrogram Model

In [None]:
from math import sqrt
import torch
from torch import nn


class BahdanauAttention(nn.Module):
    """Additive attention: scores ``v^T tanh(W q + processed_memory)``, softmaxed over dim 1."""

    def __init__(self, query_dim: int, attn_dim: int, score_mask_value: float = -float("inf")):
        super().__init__()
        self.query_layer = nn.Linear(query_dim, attn_dim, bias=False)
        self.tanh = nn.Tanh()
        self.v = nn.Linear(attn_dim, 1, bias=False)
        self.score_mask_value = score_mask_value

    def forward(self, query, processed_memory, mask=None):
        """Return alignment weights for ``query`` over ``processed_memory``.

        A 2-D query gets a singleton axis inserted at dim 1. Where ``mask`` is
        True the score is overwritten with ``score_mask_value`` before the
        softmax.
        """
        query_3d = query.unsqueeze(1) if query.dim() == 2 else query
        scores = self.get_energies(query_3d, processed_memory)
        if mask is not None:
            flat_mask = mask.view(query_3d.size(0), -1)
            scores.data.masked_fill_(flat_mask, self.score_mask_value)
        return self.get_probabilities(scores)

    def init_attention(self, processed_memory):
        """No per-sequence state to reset for plain additive attention."""
        return

    def get_energies(self, query, processed_memory):
        """Unnormalized scores with the trailing singleton axis removed."""
        projected_query = self.query_layer(query)
        scores = self.v(self.tanh(projected_query + processed_memory))
        return scores.squeeze(-1)

    def get_probabilities(self, energies):
        """Softmax of the scores along dim 1."""
        return torch.softmax(energies, dim=1)


class LocationSensitiveAttention(BahdanauAttention):
    """Additive attention with an extra term computed from the running sum of past alignments."""

    def __init__(self, query_dim, attn_dim, filters=32, kernel_size=31, score_mask_value=-float("inf")):
        super().__init__(query_dim, attn_dim, score_mask_value)
        same_padding = (kernel_size - 1) // 2
        self.conv = nn.Conv1d(1, filters, kernel_size=kernel_size, padding=same_padding, bias=True)
        self.L = nn.Linear(filters, attn_dim, bias=False)
        # Running sum of alignments; created by init_attention().
        self.cumulative = None

    def init_attention(self, processed_memory):
        """Reset the cumulative alignment to zeros of shape ``[batch, time]``."""
        batch, steps, _ = processed_memory.size()
        self.cumulative = processed_memory.new_zeros((batch, steps))

    def get_energies(self, query, processed_memory):
        """Scores from query, memory and convolved cumulative alignment."""
        projected_query = self.query_layer(query)
        location_features = self.conv(self.cumulative.unsqueeze(1)).transpose(1, 2)
        projected_location = self.L(location_features)
        scores = self.v(self.tanh(projected_query + processed_memory + projected_location))
        return scores.squeeze(-1)

    def get_probabilities(self, energies):
        """Softmax along dim 1; the result is also added to the cumulative alignment."""
        alignment = torch.softmax(energies, dim=1)
        self.cumulative = self.cumulative + alignment
        return alignment


def get_mask_from_lengths(memory, memory_lengths):
    """Return a ``[batch, time]`` boolean mask that is True at padded positions.

    Row ``i`` is False for its first ``memory_lengths[i]`` steps and True after
    them. Only the first two dimensions and the device of ``memory`` are used.
    """
    padding = torch.ones(memory.size(0), memory.size(1), dtype=torch.bool, device=memory.device)
    for row, valid_steps in enumerate(memory_lengths):
        padding[row, :valid_steps] = False
    return padding


class AttentionWrapper(nn.Module):
    """Couples a recurrent cell with an attention mechanism over encoder memory.

    Each call runs one cell step on ``[query; previous context]``, uses the new
    hidden state as the attention query, and returns the cell output, the new
    context vector and the alignment weights.
    """

    def __init__(self, rnn_cell, attention_mechanism):
        super().__init__()
        self.rnn_cell = rnn_cell
        self.attention_mechanism = attention_mechanism

    def forward(self, query, attention, cell_state, memory, processed_memory=None, mask=None, memory_lengths=None):
        # Score against the raw memory when no pre-projected version is supplied.
        keys = memory if processed_memory is None else processed_memory
        # An explicit mask wins; otherwise derive one from the lengths if given.
        if mask is None and memory_lengths is not None:
            mask = get_mask_from_lengths(memory, memory_lengths)

        cell_output = self.rnn_cell(torch.cat((query, attention), -1), cell_state)
        # An LSTM cell returns (h, c); only the hidden state h is used as the query.
        if isinstance(self.rnn_cell, nn.LSTMCell):
            attention_query = cell_output[0]
        else:
            attention_query = cell_output

        alignment = self.attention_mechanism(attention_query, keys, mask)
        # Weighted sum of memory rows: (B, 1, T) x (B, T, D) -> (B, D).
        context = torch.bmm(alignment.unsqueeze(1), memory).squeeze(1)
        return cell_output, context, alignment


class Prenet(nn.Module):
    """Stack of Linear -> ReLU -> Dropout layers.

    ``sizes`` lists the output width of each layer; the first layer takes
    ``in_dim`` features and every later layer takes the previous layer's width.
    """

    def __init__(self, in_dim, sizes=[256, 128], dropout=0.5):
        super().__init__()
        stack = []
        for fan_in, fan_out in zip([in_dim] + sizes[:-1], sizes):
            stack.append(nn.Linear(fan_in, fan_out))
        self.layers = nn.ModuleList(stack)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = x
        for layer in self.layers:
            hidden = self.relu(layer(hidden))
            hidden = self.dropout(hidden)
        return hidden


class BatchNormConv1d(nn.Module):
    """Bias-free Conv1d, then an optional activation, then BatchNorm1d.

    Note the order: the activation (when given) is applied *before* batch
    normalisation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, activation=None):
        super().__init__()
        self.conv1d = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm1d(out_channels)
        self.activation = activation

    def forward(self, x):
        features = self.conv1d(x)
        if self.activation is None:
            return self.bn(features)
        return self.bn(self.activation(features))


class BatchNormConv1dStack(nn.Module):
    """Chain of BatchNormConv1d layers, each followed by the same dropout.

    ``out_channels`` gives the width of every layer; ``activations`` gives one
    activation (or None) per layer and defaults to no activation anywhere.
    """

    def __init__(self, in_channel, out_channels=[512, 512, 512], kernel_size=3, stride=1, padding=1, activations=None, dropout=0.5):
        super().__init__()
        per_layer_activation = [None] * len(out_channels) if activations is None else activations
        blocks = []
        previous_width = in_channel
        for width, activation in zip(out_channels, per_layer_activation):
            blocks.append(
                BatchNormConv1d(
                    previous_width,
                    width,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    activation=activation,
                )
            )
            previous_width = width
        self.convs = nn.ModuleList(blocks)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = x
        for block in self.convs:
            hidden = self.dropout(block(hidden))
        return hidden


class Postnet(nn.Module):
    """Convolutional stack applied to a mel spectrogram.

    All layers but the last use ``conv_channels`` filters with tanh; the last
    layer maps back to ``mel_dim`` channels with no activation. Input and
    output are laid out as (batch, time, mel_dim).
    """

    def __init__(self, mel_dim, num_convs=5, conv_channels=512, conv_kernel_size=5, conv_dropout=0.5):
        super().__init__()
        hidden_layers = num_convs - 1
        layer_widths = [conv_channels] * hidden_layers + [mel_dim]
        layer_activations = [torch.tanh] * hidden_layers + [None]
        same_padding = (conv_kernel_size - 1) // 2
        self.convs = BatchNormConv1dStack(
            mel_dim,
            layer_widths,
            kernel_size=conv_kernel_size,
            stride=1,
            padding=same_padding,
            activations=layer_activations,
            dropout=conv_dropout,
        )

    def forward(self, x):
        # The conv stack is channels-first, so swap to (B, mel, T) and back.
        channels_first = x.transpose(1, 2)
        return self.convs(channels_first).transpose(1, 2)


class Encoder(nn.Module):
    """Convolution stack followed by a single-layer bidirectional LSTM.

    Input is (batch, time, embed_dim). The output has ``blstm_units`` features
    per step: ``blstm_units // 2`` from each LSTM direction.
    """

    def __init__(self, embed_dim, num_convs=3, conv_channels=512, conv_kernel_size=5, conv_dropout=0.5, blstm_units=512):
        super().__init__()
        same_padding = (conv_kernel_size - 1) // 2
        self.convs = BatchNormConv1dStack(
            embed_dim,
            [conv_channels] * num_convs,
            kernel_size=conv_kernel_size,
            stride=1,
            padding=same_padding,
            # One ReLU instance is shared by every layer (it holds no state).
            activations=[nn.ReLU()] * num_convs,
            dropout=conv_dropout,
        )
        self.lstm = nn.LSTM(
            conv_channels,
            blstm_units // 2,
            1,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        # The conv stack is channels-first, so swap to (B, C, T) and back.
        conv_features = self.convs(x.transpose(1, 2)).transpose(1, 2)
        self.lstm.flatten_parameters()
        sequence_outputs, _ = self.lstm(conv_features)
        return sequence_outputs


class Decoder(nn.Module):
    """Autoregressive attention decoder that emits ``r`` mel frames per step.

    Each step runs: prenet -> attention LSTM cell (location-sensitive
    attention over the encoder outputs) -> decoder LSTM cell -> linear
    projections to ``r`` mel frames and one stop probability.

    With ``inputs`` given, decoding is teacher-forced and runs for exactly
    ``inputs.size(1)`` frames. With ``inputs=None`` decoding is greedy and stops
    once every sequence in the batch has emitted a stop probability above
    ``stop_threshold``, or after ``max_decoder_steps`` frames.

    Note: ``decoder_rnn_layers`` is accepted but unused; a single LSTMCell is
    always built.
    """

    def __init__(self, mel_dim, r, encoder_output_dim, prenet_dims=[256, 256], prenet_dropout=0.5, attention_dim=128, attention_rnn_units=1024, attention_dropout=0.1, attention_location_filters=32, attention_location_kernel_size=31, decoder_rnn_units=1024, decoder_rnn_layers=2, decoder_dropout=0.1, max_decoder_steps=1000, stop_threshold=0.5):
        super().__init__()
        self.mel_dim = mel_dim
        self.r = r
        self.attention_context_dim = encoder_output_dim
        self.attention_rnn_units = attention_rnn_units
        self.decoder_rnn_units = decoder_rnn_units
        self.max_decoder_steps = max_decoder_steps
        self.stop_threshold = stop_threshold

        self.prenet = Prenet(mel_dim, prenet_dims, prenet_dropout)
        self.attention_rnn = AttentionWrapper(
            nn.LSTMCell(prenet_dims[-1] + encoder_output_dim, attention_rnn_units),
            LocationSensitiveAttention(attention_rnn_units, attention_dim, filters=attention_location_filters, kernel_size=attention_location_kernel_size),
        )
        self.attention_dropout = nn.Dropout(attention_dropout)
        self.memory_layer = nn.Linear(encoder_output_dim, attention_dim, bias=False)

        self.decoder_rnn = nn.LSTMCell(attention_rnn_units + encoder_output_dim, decoder_rnn_units)
        self.decoder_dropout = nn.Dropout(decoder_dropout)

        self.mel_proj = nn.Linear(decoder_rnn_units + encoder_output_dim, mel_dim * self.r)
        self.stop_proj = nn.Linear(decoder_rnn_units + encoder_output_dim, 1)

    def forward(self, encoder_outputs, inputs=None, memory_lengths=None):
        """Decode mel frames from ``encoder_outputs``.

        Args:
            encoder_outputs: encoder memory, indexed (batch, time, features).
            inputs: target mels for teacher forcing, transposed here from
                (batch, T, mel_dim) to time-major; ``None`` selects greedy
                decoding. T must be a multiple of ``r`` (asserted on exit).
            memory_lengths: optional valid lengths used to mask padded memory.

        Returns:
            ``(mel_outputs, stop_tokens, attn_scores)`` concatenated over time
            on dim 1. The stop probability of a step is repeated ``r`` times so
            it lines up with that step's frames.
        """
        bsz = encoder_outputs.size(0)
        processed_memory = self.memory_layer(encoder_outputs)
        mask = get_mask_from_lengths(processed_memory, memory_lengths) if memory_lengths is not None else None
        greedy = inputs is None
        if inputs is not None:
            inputs = inputs.transpose(0, 1)
            T_decoder = inputs.size(0)

        # All-zero "go" frame and zero initial states, on the memory's device/dtype.
        go_frame = encoder_outputs.data.new(bsz, self.mel_dim).zero_()
        self.attention_rnn.attention_mechanism.init_attention(processed_memory)
        attn_h = encoder_outputs.data.new(bsz, self.attention_rnn_units).zero_()
        attn_c = encoder_outputs.data.new(bsz, self.attention_rnn_units).zero_()
        dec_h = encoder_outputs.data.new(bsz, self.decoder_rnn_units).zero_()
        dec_c = encoder_outputs.data.new(bsz, self.decoder_rnn_units).zero_()
        attn_ctx = encoder_outputs.data.new(bsz, self.attention_context_dim).zero_()

        mel_outputs, attn_scores, stop_tokens = [], [], []
        # Per-sequence flag: has this sequence ever crossed the stop threshold?
        finished = None
        t = 0
        current = go_frame
        while True:
            if t > 0:
                # Feed back the last frame of the previous group: the model's own
                # prediction when greedy, the ground-truth frame otherwise.
                current = mel_outputs[-1][:, -1, :] if greedy else inputs[t - 1]
            t += self.r

            current = self.prenet(current)
            (attn_h, attn_c), attn_ctx, attn_score = self.attention_rnn(
                current, attn_ctx, (attn_h, attn_c), encoder_outputs, processed_memory=processed_memory, mask=mask
            )
            attn_h = self.attention_dropout(attn_h)

            dec_input = torch.cat((attn_h, attn_ctx), -1)
            dec_h, dec_c = self.decoder_rnn(dec_input, (dec_h, dec_c))
            dec_h = self.decoder_dropout(dec_h)

            proj_in = torch.cat((dec_h, attn_ctx), -1)
            out = self.mel_proj(proj_in).view(bsz, -1, self.mel_dim)
            stop = torch.sigmoid(self.stop_proj(proj_in))

            mel_outputs.append(out)
            attn_scores.append(attn_score.unsqueeze(1))
            stop_tokens.extend([stop] * self.r)

            if greedy:
                # `stop` is (batch, 1). Comparing it directly inside `if` raises for
                # batch size > 1, so accumulate a per-sequence flag and stop once
                # every sequence has signalled the end. Same as before for batch 1.
                crossed = stop.squeeze(-1) > self.stop_threshold
                finished = crossed if finished is None else (finished | crossed)
                if bool(finished.all()) or t > self.max_decoder_steps:
                    break
            else:
                if t >= T_decoder:
                    break

        mel_outputs = torch.cat(mel_outputs, dim=1)
        attn_scores = torch.cat(attn_scores, dim=1)
        stop_tokens = torch.cat(stop_tokens, dim=1)
        assert greedy or mel_outputs.size(1) == T_decoder
        return mel_outputs, stop_tokens, attn_scores


class TextToMelSpectrogramModel(nn.Module):
    """Text-to-mel network: embedding -> Encoder -> Decoder -> Postnet.

    Args:
        model_cfg: mapping with ``"encoder"``, ``"decoder"`` and ``"postnet"``
            keyword dictionaries forwarded to the respective sub-modules.
        embed_dim: width of the symbol embedding.
        mel_dim: number of mel channels.
        max_decoder_steps: greedy decoding limit passed to the decoder.
        stop_threshold: stop-probability threshold passed to the decoder.
        r: number of mel frames emitted per decoder step.
        num_symbols: size of the input vocabulary (rows of the embedding table).
            NOTE(review): the default of 1 reproduces the previously hard-coded
            value, which only admits symbol id 0 -- confirm this matches the
            ids produced by the text front-end and pass the real size if not.
    """

    def __init__(self, model_cfg, embed_dim=512, mel_dim=80, max_decoder_steps=1000, stop_threshold=0.5, r=3, num_symbols=1):
        super().__init__()
        self.mel_dim = mel_dim
        self.embedding = nn.Embedding(num_symbols, embed_dim)
        # Uniform init whose range is derived from the table's two dimensions.
        std = sqrt(2.0 / (num_symbols + embed_dim))
        val = sqrt(3.0) * std
        self.embedding.weight.data.uniform_(-val, val)

        enc_cfg = model_cfg["encoder"]
        self.encoder = Encoder(embed_dim, **enc_cfg)
        # The decoder attends over encoder outputs of this width, so the
        # encoder config must state "blstm_units" explicitly.
        encoder_out_dim = enc_cfg["blstm_units"]

        dec_cfg = model_cfg["decoder"]
        self.decoder = Decoder(mel_dim, r, encoder_out_dim, **dec_cfg, max_decoder_steps=max_decoder_steps, stop_threshold=stop_threshold)

        self.postnet = Postnet(mel_dim, **model_cfg["postnet"])

    def parse_data_batch(self, batch):
        """Move a collated batch to the model's device and split it into inputs and targets.

        ``batch`` is ``(text, text_length, mel, stop, _)``. Returns
        ``((text, text_length, mel), (mel, stop))``; the mel tensor is moved
        and cast once and shared between the two tuples.
        """
        device = next(self.parameters()).device
        text, text_length, mel, stop, _ = batch
        mel = mel.to(device).float()
        inputs = (text.to(device).long(), text_length.to(device).long(), mel)
        targets = (mel, stop.to(device).float())
        return inputs, targets

    def forward(self, inputs):
        """Run the model on ``(symbol_ids, input_lengths, mels)``.

        Returns ``(mel_out, mel_post, stop_tokens, alignments)`` where
        ``mel_post`` is ``mel_out`` plus the postnet residual. Passing
        ``mels=None`` makes the decoder run greedily.
        """
        inputs, input_lengths, mels = inputs
        x = self.embedding(inputs)
        enc_out = self.encoder(x)
        mel_out, stop_tokens, alignments = self.decoder(enc_out, mels, memory_lengths=input_lengths)
        mel_post = self.postnet(mel_out)
        mel_post = mel_out + mel_post
        return mel_out, mel_post, stop_tokens, alignments

    def inference(self, inputs):
        """Greedy synthesis from symbol ids only (no lengths, no target mels)."""
        return self.forward((inputs, None, None))


class TextToMelSpectrogramLoss(nn.Module):
    """Training loss: MSE before the postnet + MSE after it + stop-token BCE.

    ``predicts`` is ``(mel, mel_post, stop, alignments)`` and ``targets`` is
    ``(mel_target, stop_target)``. The stop predictions are fed to ``BCELoss``,
    which expects probabilities rather than logits.
    """

    def __init__(self):
        super().__init__()

    def forward(self, predicts, targets):
        mel_target, stop_target = targets
        # Targets never receive gradients.
        for reference in (mel_target, stop_target):
            reference.requires_grad = False
        mel_pred, mel_post_pred, stop_pred, _ = predicts
        mse, bce = nn.MSELoss(), nn.BCELoss()
        return mse(mel_pred, mel_target) + mse(mel_post_pred, mel_target) + bce(stop_pred, stop_target)


### Build Model

In [None]:
# Fix the CPU and CUDA RNG seeds so model initialisation is repeatable.
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)

# Prefer the GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Build the model and its loss, then move the model onto the chosen device.
print("\nInitialising TextToMelSpectrogram Model...\n")
model, criterion = create_model()
model = model.to(device)


Using device: cuda

Initialising Tacotron Model...



In [None]:
# Adam over every model parameter, with a small L2 penalty (weight decay).
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-6,
)

In [None]:
# Make sure the output folder exists, then log into its "logs" sub-folder.
output_dir = ''
os.makedirs(audio_outputs_folder, exist_ok=True)

log_dir = os.path.join(audio_outputs_folder, "logs")
logger = TextToMelSpectrogramLogger(log_dir)


### Speaker Embeddings

In [None]:
MAX_CHUNKS = 5
CHUNK_DURATION_SEC = 5
WAV2VEC_SAMPLE_RATE = 16000  # sampling rate handed to the Wav2Vec2 processor
CHUNK_LEN = WAV2VEC_SAMPLE_RATE * CHUNK_DURATION_SEC
HIDDEN_DIM = 1024  # from Wav2Vec2

def extract_wav2vec_features_fixed(audio_path, processor, model):
    """Embed the start of an audio file as one fixed-size Wav2Vec2 vector.

    The audio is resampled to ``WAV2VEC_SAMPLE_RATE``, downmixed to mono and
    cut into at most ``MAX_CHUNKS`` chunks of ``CHUNK_LEN`` samples. A short
    final chunk is zero-padded. Each chunk is passed through ``model`` and its
    ``last_hidden_state`` is averaged over time. Files shorter than
    ``MAX_CHUNKS`` chunks get zero vectors for the missing chunks.

    Args:
        audio_path: path readable by ``torchaudio.load``.
        processor: callable feature extractor returning ``.input_values``.
        model: torch module returning ``.last_hidden_state``; inputs are sent
            to the device its parameters live on.

    Returns:
        ``np.ndarray`` of shape ``(1, MAX_CHUNKS * HIDDEN_DIM)``, assuming the
        model's hidden size equals ``HIDDEN_DIM``.
    """
    waveform, sr = torchaudio.load(audio_path)
    if sr != WAV2VEC_SAMPLE_RATE:
        waveform = torchaudio.transforms.Resample(sr, WAV2VEC_SAMPLE_RATE)(waveform)
    # torchaudio returns (channels, samples). Average multi-channel audio to mono,
    # otherwise the (1, n) zero padding below cannot be concatenated and each
    # chunk would be treated as a batch of channels.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Use the model's own device rather than a notebook global defined in a later cell.
    device = next(model.parameters()).device

    total_len = waveform.shape[1]
    embeddings = []

    for i in range(MAX_CHUNKS):
        start = i * CHUNK_LEN
        if start >= total_len:
            break  # no more data

        chunk = waveform[:, start:start + CHUNK_LEN]
        missing = CHUNK_LEN - chunk.shape[1]
        if missing > 0:
            # Pad short chunk with zeros
            pad = torch.zeros((1, missing), dtype=chunk.dtype)
            chunk = torch.cat((chunk, pad), dim=1)

        inputs = processor(
            chunk.squeeze(0).numpy(), sampling_rate=WAV2VEC_SAMPLE_RATE, return_tensors="pt"
        ).input_values.to(device)
        with torch.no_grad():
            outputs = model(inputs)
        emb = outputs.last_hidden_state.mean(dim=1).cpu().numpy()  # shape: (1, HIDDEN_DIM)
        embeddings.append(emb)

        # Release GPU memory between chunks.
        del inputs, outputs
        torch.cuda.empty_cache()
        gc.collect()

    # Pad with zero-vectors if we didn't get enough chunks. Match the dtype of the
    # real embeddings so the result is not silently upcast to float64.
    pad_dtype = embeddings[0].dtype if embeddings else np.float32
    while len(embeddings) < MAX_CHUNKS:
        embeddings.append(np.zeros((1, HIDDEN_DIM), dtype=pad_dtype))

    final_embedding = np.concatenate(embeddings, axis=1)  # shape: (1, MAX_CHUNKS * HIDDEN_DIM)
    return final_embedding


In [None]:
def create_speaker_embeddings(dataset, wav2vec_processor, wav2vec_model):
    """Compute one Wav2Vec2 embedding per distinct full-talk audio file.

    ``dataset`` yields 4-tuples whose first element is the path of the full
    talk recording. Each distinct path is embedded once, on first sight.

    Returns:
        dict mapping full-talk wav path to its embedding.
    """
    embeddings_by_talk = {}

    for full_talk_wav_path, _, _, _ in tqdm(dataset, desc="Generating speaker embeddings"):
        if full_talk_wav_path in embeddings_by_talk:
            continue  # this talk was already embedded
        embeddings_by_talk[full_talk_wav_path] = extract_wav2vec_features_fixed(
            full_talk_wav_path, wav2vec_processor, wav2vec_model
        )

    # Free cached GPU blocks and unreachable Python objects once all talks are done.
    torch.cuda.empty_cache()
    gc.collect()

    return embeddings_by_talk


### Build Models

In [None]:
# Fall back to the CPU when no GPU is visible so this cell also runs on CPU-only
# runtimes (same policy as the device selection for the main model).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Wav2Vec2 provides the speaker embeddings; HiFi-GAN is the mel -> waveform vocoder.
wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")
wav2vec_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-960h", use_safetensors=True).to(DEVICE)
hifigan_model = HifiGanModel.from_pretrained(model_name="nvidia/tts_hifigan")



Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[NeMo W 2025-08-21 18:02:11 nemo_logging:405] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    dataset:
      _target_: nemo.collections.tts.data.datalayers.MelAudioDataset
      manifest_filepath: /home/fkreuk/data/train_finetune.txt
      min_duration: 0.75
      n_segments: 8192
    dataloader_params:
      drop_last: false
      shuffle: true
      batch_size: 64
      num_workers: 4
    
[NeMo W 2025-08-21 18:02:11 nemo_logging:405] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a 

[NeMo I 2025-08-21 18:02:11 nemo_logging:393] PADDING: 0


[NeMo W 2025-08-21 18:02:11 nemo_logging:405] Using torch_stft is deprecated and has been removed. The values have been forcibly set to False for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True as needed.


[NeMo I 2025-08-21 18:02:11 nemo_logging:393] PADDING: 0
[NeMo I 2025-08-21 18:02:12 nemo_logging:393] Model HifiGanModel was successfully restored from /root/.cache/huggingface/hub/models--nvidia--tts_hifigan/snapshots/3ba1fed954276287015654bf4c78060ffc9a4772/tts_hifigan.nemo.


## Create Dataset

In [None]:
# Set to True to reuse the pickled preprocessing results instead of rebuilding them.
import_from_local = False
all_segments_dict = []
if not import_from_local:
    # Run the full preprocessing, then cache its results next to the dataset.
    target_dataset_size = 10000
    clips_folder, all_segments_dict, full_talks_mel = preprocess_tedlium(
        dataset_path,
        target_dataset_size,
        max_files=100000,
        output_path=f'{dataset_path}/custom_data',
        _convert_sph_to_wav=False,
    )
    with open(f"{dataset_path}/all_segments_dict.pkl", "wb") as f:
        pickle.dump(all_segments_dict, f)

    with open(f"{dataset_path}/full_talks_mel.pkl", "wb") as f:
        pickle.dump(full_talks_mel, f)

else:
    clips_folder = f'{dataset_path}/custom_data/clips'
    # clips_folder = f'{dataset_path}/custom_data_phoneme/clips'
    print(clips_folder)
    # NOTE: pickle executes code on load; only read caches this notebook wrote itself.
    with open(f"{dataset_path}/all_segments_dict.pkl", "rb") as f:
        all_segments_dict = pickle.load(f)
    with open(f"{dataset_path}/full_talks_mel.pkl", "rb") as f:
        full_talks_mel = pickle.load(f)


['/content/drive/MyDrive/final_project/Dataset_3/custom_data/wavs/AbeDavis_2015.wav', '/content/drive/MyDrive/final_project/Dataset_3/custom_data/wavs/AdamOckelford_2013X.wav', '/content/drive/MyDrive/final_project/Dataset_3/custom_data/wavs/AdamDavidson_2012S.wav', '/content/drive/MyDrive/final_project/Dataset_3/custom_data/wavs/AJJacobs_2014A.wav', '/content/drive/MyDrive/final_project/Dataset_3/custom_data/wavs/AJJacobs_2014A (1).wav', '/content/drive/MyDrive/final_project/Dataset_3/custom_data/wavs/AdamdelaZerda_2016X.wav', '/content/drive/MyDrive/final_project/Dataset_3/custom_data/wavs/AlanSmith_2016X.wav', '/content/drive/MyDrive/final_project/Dataset_3/custom_data/wavs/AdamGrant_2016S.wav', '/content/drive/MyDrive/final_project/Dataset_3/custom_data/wavs/AditiGupta_2015X.wav', '/content/drive/MyDrive/final_project/Dataset_3/custom_data/wavs/AdamFoss_2016.wav', '/content/drive/MyDrive/final_project/Dataset_3/custom_data/wavs/AlejandroAravena_2014G.wav', '/content/drive/MyDrive/f

Creating full talks mel: 100%|██████████| 184/184 [05:07<00:00,  1.67s/it]
Segmenting audio: 100%|██████████| 10000/10000 [1:25:09<00:00,  1.96it/s]


Preprocessing done. 10000 clips created.


### One-word based dataset

In [None]:
# Set to True to skip rebuilding the per-word clips.
import_from_local = False

# Input folders produced by the sentence-level preprocessing step.
clips_folder = f'{dataset_path}/custom_data/clips'
full_talks_folder = f'{dataset_path}/custom_data/wavs'

# Destination of the one-word clips and their metadata.
output_words_path = f"{dataset_path}/custom_data/words"

if not import_from_local:
    # Cut each sentence clip into single-word clips, padding every cut by 50 ms.
    build_word_dataset_from_sentence_clips(
        clips_folder=clips_folder,
        segments_dict=all_segments_dict,
        out_folder=output_words_path,
        pad_seconds=0.05,
    )


DEBUG:speechbrain.utils.checkpoints:Registered checkpoint save hook for _speechbrain_save
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint load hook for _speechbrain_load
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint save hook for save
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint load hook for load
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint save hook for _save
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint load hook for _recover


config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

model.bin:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

vocabulary.txt: 0.00B [00:01, ?B/s]

INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.5.3. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../usr/local/lib/python3.12/dist-packages/whisperx/assets/pytorch_model.bin`


No language specified, language will be first be detected for each audio file (increases inference time).
>>Performing voice activity detection using Pyannote...
Model was trained with pyannote.audio 0.0.1, yours is 3.3.2. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.8.0+cu126. Bad things might happen unless you revert torch to 1.x.
Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth


100%|██████████| 360M/360M [00:00<00:00, 587MB/s]


[align-error] AalaElKhani_2016X_308.92_322.19: CUDA out of memory. Tried to allocate 84.00 MiB. GPU 0 has a total capacity of 39.56 GiB of which 32.52 GiB is free. Process 19480 has 7.03 GiB memory in use. Of the allocated memory 3.13 GiB is allocated by PyTorch, and 470.24 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
[align-error] AalaElKhani_2016X_368.76_379.70: CUDA out of memory. Tried to allocate 70.00 MiB. GPU 0 has a total capacity of 39.56 GiB of which 32.52 GiB is free. Process 19480 has 7.03 GiB memory in use. Of the allocated memory 3.11 GiB is allocated by PyTorch, and 484.95 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragm

In [None]:
# The word clips are read from the words output folder itself. (An earlier
# assignment pointing at its "wavs" sub-folder was immediately overwritten,
# so it has been dropped.)
words_folder = output_words_path
metadata_csv = os.path.join(output_words_path, "metadata_words.csv")

dataset = TextMelDataset(
    words_folder=words_folder,
    metadata_csv=metadata_csv,
    specific_items=None,   # or a list of sentence_ids / talk_ids to keep
    shuffle=False,
    use_phonemes=True      # switch to False to train on plain words
)

In [None]:
# Report how many items the dataset exposes.
print(len(dataset))

274669


In [None]:
from torch.utils.data import DataLoader, random_split
import torch

# Hold out 10% of the items (at least one) for validation; the split is seeded.
val_ratio = 0.1
dataset_size = len(dataset)
val_size = max(1, int(val_ratio * dataset_size))
train_size = dataset_size - val_size
trainset, valset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

collate_fn = TextMelCollate(r=3)

# Options shared by both loaders.
_loader_options = dict(
    batch_size=32,
    num_workers=0,                # start with 0 to avoid RAM spikes
    pin_memory=False,             # enable later if using GPU + enough RAM
    collate_fn=collate_fn,
    prefetch_factor=None,         # IMPORTANT: None when num_workers=0
    persistent_workers=False,
)

# Train loader: reshuffled every epoch, incomplete final batch dropped.
train_loader = DataLoader(trainset, shuffle=True, drop_last=True, **_loader_options)

# Val loader: fixed order, every item kept.
val_loader = DataLoader(valset, shuffle=False, drop_last=False, **_loader_options)


## TextToMelSpectrogram Train

In [None]:
import os
import time
import torch


def validate(model, criterion, iteration, device, valset, batch_size, collate_fn, logger):
    """Evaluate ``model`` on ``valset`` and return the mean batch loss.

    Args:
        model: network exposing ``parse_data_batch`` and callable on its inputs.
        criterion: callable ``(predicts, targets) -> loss``.
        iteration: training iteration, used only for printing and logging.
        device: unused; kept for signature compatibility.
        valset: dataset to evaluate.
        batch_size: validation batch size.
        collate_fn: collate function for the validation loader.
        logger: optional object with ``log_validation``; skipped when ``None``
            or when the validation set produced no batches.

    Returns:
        Mean loss over validation batches (0.0 when there are none).

    The model is always switched back to training mode afterwards, even if a
    batch raises.
    """
    val_loader = torch.utils.data.DataLoader(
        valset, shuffle=False, batch_size=batch_size, num_workers=0,
        pin_memory=False, collate_fn=collate_fn
    )
    total_loss = 0.0
    num_batches = 0
    last_targets, last_predicts = None, None

    model.eval()
    try:
        with torch.no_grad():
            for batch in val_loader:
                inputs, targets = model.parse_data_batch(batch)
                predicts = model(inputs)
                loss = criterion(predicts, targets)
                total_loss += float(loss)
                num_batches += 1
                last_targets, last_predicts = targets, predicts
    finally:
        model.train()

    # Count batches explicitly: the loop variable is unbound for an empty loader.
    avg_loss = total_loss / max(1, num_batches)
    print(f"Validation loss {iteration}: {avg_loss:.6f}")
    if logger is not None and num_batches > 0:
        logger.log_validation(avg_loss, model, last_targets, last_predicts, iteration)
    return avg_loss


def train(model, logger, learning_rate, optimizer, train_loader, collate_fn, valset, epochs=10, loss_fn=None, validate_every=1000):
    """Train ``model`` for ``epochs`` passes over ``train_loader`` with periodic validation.

    Args:
        model: network exposing ``parse_data_batch`` and callable on its inputs.
        logger: optional object with ``log_training``; skipped when ``None``.
        learning_rate: learning rate written into every param group each step.
        optimizer: torch optimizer over the model's parameters.
        train_loader: iterable of collated training batches.
        collate_fn: collate function forwarded to ``validate``.
        valset: validation dataset forwarded to ``validate``.
        epochs: number of passes over ``train_loader``.
        loss_fn: callable ``(predicts, targets) -> loss``. When ``None`` the
            notebook-level ``criterion`` is used, as before.
        validate_every: run validation whenever ``iteration`` is a multiple of
            this value (so also at iteration 0); 0 or ``None`` disables it.

    Gradients are clipped to a global norm of 1.0. No checkpoints are written.
    """
    if loss_fn is None:
        # Backward compatible: fall back to the loss created next to the model.
        loss_fn = criterion

    iteration = -1
    epoch_offset = 0

    model.train()
    for epoch in range(epoch_offset, epochs):
        print(f"Epoch: {epoch}")
        for batch in train_loader:
            start = time.perf_counter()
            iteration += 1

            # Update LR each step (keeps existing behavior)
            for group in optimizer.param_groups:
                group['lr'] = learning_rate

            # Forward + loss
            inputs, targets = model.parse_data_batch(batch)
            predicts = model(inputs)
            loss = loss_fn(predicts, targets)

            # Backward
            optimizer.zero_grad()
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Logs
            duration = time.perf_counter() - start
            print(f"Train loss {iteration} {float(loss):.6f} Grad Norm {float(grad_norm):.6f} {duration:.2f}s/it")
            if logger is not None:
                logger.log_training(float(loss), float(grad_norm), float(learning_rate), float(duration), iteration)

            # Periodic validation
            if validate_every and iteration % validate_every == 0:
                validate(model, loss_fn, iteration, None, valset, 32, collate_fn, logger)


## Train

In [None]:
# Constant learning rate, re-applied to the optimizer on every step inside train().
learning_rate = 1e-3

train(
    model=model,
    logger=logger,
    learning_rate=learning_rate,
    optimizer=optimizer,
    train_loader=train_loader,
    collate_fn=collate_fn,
    valset=valset,
)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Train loss 22010 0.452703 Grad Norm 1.043250 0.15s/it
Train loss 22011 0.708000 Grad Norm 2.547370 0.13s/it
Train loss 22012 0.903015 Grad Norm 3.468312 0.11s/it
Train loss 22013 0.600642 Grad Norm 1.005891 0.13s/it
Train loss 22014 0.650044 Grad Norm 2.080561 0.12s/it
Train loss 22015 0.713655 Grad Norm 1.287138 0.11s/it
Train loss 22016 0.634973 Grad Norm 1.103881 0.13s/it
Train loss 22017 0.422491 Grad Norm 0.799460 0.16s/it
Train loss 22018 0.667909 Grad Norm 1.564371 0.15s/it
Train loss 22019 0.376232 Grad Norm 1.849930 0.13s/it
Train loss 22020 0.600515 Grad Norm 2.138355 0.11s/it
Train loss 22021 0.462078 Grad Norm 0.707953 0.14s/it
Train loss 22022 0.589383 Grad Norm 1.415639 0.12s/it
Train loss 22023 0.521892 Grad Norm 1.561337 0.11s/it
Train loss 22024 0.700345 Grad Norm 1.488307 0.13s/it
Train loss 22025 0.591218 Grad Norm 4.123748 0.11s/it
Train loss 22026 0.574837 Grad Norm 0.979283 0.14s/it
Train loss 22027 

In [None]:
# Persist only the parameters (state dict); reload with model.load_state_dict().
torch.save(
    model.state_dict(),
    f"{dataset_path}/model_weights_21_08_25-4.pth",
)

## Voice Generation Model

In [None]:
# model.load_state_dict(torch.load(f"{dataset_path}/model_weights_21_08_25-2.pth"))
# model.eval()

In [None]:
import torch
import numpy as np

def prepare_mel_for_vocoder(
    mel_2d: torch.Tensor,
    *,
    mean: float | torch.Tensor | None = None,
    std: float  | torch.Tensor | None = None,
    hifigan_model=None,
) -> torch.Tensor:
    """
    Converts your current mel (as returned by text_to_mel) to the format HiFi-GAN expects.
    - Input:  [T, 80] torch.FloatTensor   (no batch)
    - Output: [1, 80, T] torch.FloatTensor (batched, on hifigan device)

    mean/std are OPTIONAL. Pass the SAME values you used during training normalization
    (if any). If you didn't normalize, leave them None.
    """
    assert isinstance(mel_2d, torch.Tensor) and mel_2d.ndim == 2, "expected [T, 80] torch tensor"

    x = mel_2d.detach().float()       # [T, 80]

    # 1) (Optional) de-normalize back to the original scale you trained on
    if (mean is not None) and (std is not None):
        # mean/std can be scalars or shape-[80]; broadcast handles both
        if not torch.is_tensor(mean): mean = torch.tensor(mean, dtype=x.dtype, device=x.device)
        if not torch.is_tensor(std):  std  = torch.tensor(std,  dtype=x.dtype, device=x.device)
        x = x * std + mean            # still [T, 80]

    # 2) Layout for HiFi-GAN: [T,80] -> [80,T] -> [1,80,T]
    x = x.transpose(0, 1).contiguous().unsqueeze(0)  # [1, 80, T]

    # 3) Put on same device as vocoder
    if hifigan_model is not None:
        device = next(hifigan_model.parameters()).device
        x = x.to(device)

    return x


@torch.no_grad()
def mel_to_audio_with_hifigan(
    mel_2d: torch.Tensor,
    *,
    mean: float | torch.Tensor | None = None,
    std: float  | torch.Tensor | None = None,
) -> np.ndarray:
    """
    Convenience wrapper: takes your [T,80] mel, adapts it, and runs the vocoder.
    Returns: 1D np.float32 waveform.
    """
    spec = prepare_mel_for_vocoder(mel_2d, mean=mean, std=std, hifigan_model=hifigan_model)  # [1,80,T]
    audio = hifigan_model.convert_spectrogram_to_audio(spec=spec)  # [1, samples]
    return audio[0].detach().cpu().float().numpy()


In [None]:
# Inference: generate mel from text, compare to dataset mel, and save both audios
import os
import re
import time
import numpy as np
import torch


def sanitize_filename(s: str, max_len: int = 32) -> str:
    """Turn arbitrary text into a short, filesystem-friendly tag.

    Every run of characters other than ASCII letters, digits, ``-`` and ``_``
    becomes a single ``_``; leading and trailing underscores are stripped. An
    empty result is replaced by ``"text"``, and the tag is cut to ``max_len``
    characters.
    """
    cleaned = re.sub(r"[^a-zA-Z0-9-_]+", "_", s).strip("_")
    return (cleaned or "text")[:max_len]


def infer_and_compare(sample_index: int = 10, input_text: str | None = None, save_plots: bool = False):
    """
    - If input_text is None: use dataset[sample_index] text ids for inference.
    - Compare predicted mel to the dataset target mel (L1/MSE, length-aligned).
    - Save two audios to audio_outputs_folder: target_*.wav and pred_*.wav
    """
    model_device = next(model.parameters()).device

    # Load dataset item
    text_ids, target_mel = dataset[sample_index]
    # Ensure tensor types and devices
    if input_text is None:
        seq = text_ids.long().unsqueeze(0).to(model_device)  # [1, T]
        name_tag = f"idx{sample_index}"
    else:
        seq_list = text_to_sequence(input_text)
        seq = torch.tensor(seq_list, dtype=torch.long, device=model_device).unsqueeze(0)
        name_tag = sanitize_filename(input_text)

    # Run model inference
    model.eval()
    with torch.no_grad():
        mel_pred, mel_pred_post, _, alignments = model.inference(seq)
    # Choose postnet output
    mel_pred_2d = mel_pred_post[0].detach().cpu()  # [T_pred, 80]

    # Prepare target mel [T, 80]
    target_mel_2d = target_mel.detach().cpu().float()

    # Length align for metrics
    t_min = min(target_mel_2d.size(0), mel_pred_2d.size(0))
    tgt_crop = target_mel_2d[:t_min]
    pred_crop = mel_pred_2d[:t_min]

    l1 = torch.mean(torch.abs(pred_crop - tgt_crop)).item()
    mse = torch.mean((pred_crop - tgt_crop) ** 2).item()
    print(f"Comparison ({name_tag}): L1={l1:.6f}, MSE={mse:.6f}, T_target={target_mel_2d.size(0)}, T_pred={mel_pred_2d.size(0)})")

    # Optional visualization
    if save_plots:
        show_or_save_mel_spectrogram(target_mel_2d.numpy(), title=f"Target Mel ({name_tag})")
        show_or_save_mel_spectrogram(mel_pred_2d.numpy(), title=f"Predicted Mel ({name_tag})")

    # Synthesize and save audios
    ts = time.strftime("%Y%m%d-%H%M%S")
    target_path = os.path.join(audio_outputs_folder, f"target_{name_tag}_{ts}.wav")
    pred_path = os.path.join(audio_outputs_folder, f"pred_{name_tag}_{ts}.wav")

    os.makedirs(audio_outputs_folder, exist_ok=True)

    # target_audio = tensor_to_hifigan_audio(target_mel_2d)
    # pred_audio = tensor_to_hifigan_audio(mel_pred_2d)
    target_audio = mel_to_audio_with_hifigan(target_mel_2d)
    pred_audio = mel_to_audio_with_hifigan(mel_pred_2d)

    sf.write(target_path, target_audio, 22050, format='WAV', subtype='PCM_16')
    sf.write(pred_path, pred_audio, 22050, format='WAV', subtype='PCM_16')

    print(f"Saved target audio -> {target_path}")
    print(f"Saved predicted audio -> {pred_path}")

    # Return artifacts for further use
    return {
        "sample_index": sample_index,
        "name_tag": name_tag,
        "metrics": {"l1": l1, "mse": mse},
        "mel_target": target_mel_2d,
        "mel_pred": mel_pred_2d,
        "alignment": alignments[0].detach().cpu() if alignments is not None else None,
        "paths": {"target": target_path, "pred": pred_path},
    }


In [None]:
# target_path = os.path.join(audio_outputs_folder, f"target_hello.wav")

# os.makedirs(audio_outputs_folder, exist_ok=True)

# target_mel_2d = text_to_mel("hello hello hello hello hello")

# target_audio = mel_to_audio_with_hifigan(target_mel_2d)

# sf.write(target_path, target_audio, 22050, format='WAV', subtype='PCM_16')

In [None]:
# Listen to five randomly chosen dataset items.
for i in range(5):
  # randrange(n) draws from 0..n-1. randint(1, n) is inclusive on both ends, so it
  # could return n (IndexError) and could never select index 0.
  sample_index = random.randrange(len(dataset))
  infer_and_compare(sample_index=sample_index, input_text=None)