# 0. Utterance

In [1]:
from pathlib import Path
import logging
from io import BytesIO
import librosa
import numpy as np
import random

class Utterance(object):

    def __init__(
        self,
        _id: str = None,
        raw_file: Path | BytesIO = None,
        processor=None,
    ):
        self._id = _id
        self.raw_file = raw_file
        self.processor = processor if processor is not None else AudioPreprocessor()
#         self.audio = self.raw()

    def raw(self):
        if isinstance(self.raw_file, Path) and self.raw_file.suffix == '.npy':
            return np.load(self.raw_file)
            
        audio, _ = librosa.load(
            self.raw_file, sr=self.processor.config.SAMPLE_RATE
        )
        
        if audio.size == 0:
            raise ValueError("Empty audio")

        audio = (
            self.processor.config.SCALING_FACTOR
            * librosa.util.normalize(audio)
        )
        return audio

    def mel_in_db(self):
        try:
            return self.processor.audio_to_mel_db(self.raw())
        except Exception:

            logging.debug(
                "Failed to load Mel spectrogram, raw file: %s", {self.raw_file}
            )
            raise
    
    def random_mel_in_db(self, seq_len):
        random_mel = self.mel_in_db()
        _, tempo_len = random_mel.shape
        if tempo_len < seq_len:
            pad_left = (seq_len - tempo_len) // 2
            pad_right = seq_len - tempo_len - pad_left
            random_mel = np.pad(random_mel, ((0, 0), (pad_left, pad_right)), mode="reflect")
        elif tempo_len > seq_len:
            max_seq_start = tempo_len - seq_len
            seq_start = np.random.randint(0, max_seq_start)
            seq_end = seq_start + seq_len
            random_mel = random_mel[:, seq_start:seq_end]
        return random_mel
    
    def magtitude(self):
        return self.processor.audio_to_magnitude_db(self.raw())

# 1. Processor

## 1.1 Audio Processor

In [2]:
class AudioConfig:
    """Empty base class that the audio configuration classes derive from."""

In [3]:
import librosa
import numpy as np


class AudioPreprocessor:
    """Spectral feature extraction driven by the constants on ``config``."""

    def __init__(self, config):
        self.config = config

    def normalize(self, spectrogram_in_db):
        """Rescale a decibel spectrogram and clip it to [ZERO_THRESHOLD, 1]."""
        cfg = self.config
        shifted = spectrogram_in_db - cfg.REF_LEVEL_DB - cfg.MIN_LEVEL_DB
        rescaled = shifted / -cfg.MIN_LEVEL_DB
        return np.clip(rescaled, cfg.ZERO_THRESHOLD, 1)

    def magnitude_to_mel(self, magnitude):
        """Project a magnitude spectrogram onto the mel scale via librosa."""
        cfg = self.config
        mel_options = dict(
            sr=cfg.SAMPLE_RATE,
            n_fft=cfg.N_FFT,
            n_mels=cfg.N_MELS,
            fmin=cfg.FMIN,
            fmax=cfg.FMAX,
        )
        return librosa.feature.melspectrogram(S=magnitude, **mel_options)

    def amp_to_db(self, mel_spectrogram):
        """Convert amplitudes to decibels, flooring inputs at ZERO_THRESHOLD."""
        floored = np.maximum(self.config.ZERO_THRESHOLD, mel_spectrogram)
        return 20.0 * np.log10(floored)

    def audio_to_stft(self, audio):
        """Compute the short-time Fourier transform of an audio time series."""
        cfg = self.config
        return librosa.stft(
            y=audio,
            n_fft=cfg.N_FFT,
            hop_length=cfg.HOP_LENGTH,
            win_length=cfg.WIN_LENGTH,
        )

    def apply_pre_emphasis(self, y):
        """Apply the first-order pre-emphasis filter with coefficient PRE_EMPHASIS."""
        filtered_tail = y[1:] - self.config.PRE_EMPHASIS * y[:-1]
        return np.append(y[0], filtered_tail)

    def stft_to_magnitude(self, linear):
        """Return the element-wise absolute value of an STFT matrix."""
        return np.abs(linear)

    def audio_to_mel_db(self, audio):
        """Turn audio samples into a normalised mel spectrogram in decibels."""
        magnitude = self.stft_to_magnitude(self.audio_to_stft(audio))
        mel_in_db = self.amp_to_db(self.magnitude_to_mel(magnitude))
        return self.normalize(mel_in_db)

    def audio_to_magnitude_db(self, audio):
        """Turn audio samples into a normalised magnitude spectrogram in decibels."""
        magnitude = self.stft_to_magnitude(self.audio_to_stft(audio))
        return self.normalize(self.amp_to_db(magnitude))

## 1.2 Text Processor

In [4]:
def sort_by_key(dictionary: dict):
    """Return a new dict whose entries are ordered by ascending key.

    Args:
        dictionary: Either a mapping or an iterable of ``(key, value)`` pairs
            such as ``some_dict.items()``.

    Returns:
        A ``dict`` with the same entries, inserted in sorted order.
    """
    from collections.abc import Mapping

    # Sorting a mapping directly yields bare keys, which dict() cannot consume,
    # so switch to its item view first.
    pairs = dictionary.items() if isinstance(dictionary, Mapping) else dictionary
    return dict(sorted(pairs))

### 1.2.1 Acronym Normalizer

In [5]:
import json

ACRONYMS_FILEPATH = "./acronyms.json"

# Read the acronym table, then reorder its entries by key.
with open(ACRONYMS_FILEPATH, mode="r", encoding="utf-8") as file:
    ACRONYMS = json.load(file)
ACRONYMS = sort_by_key(ACRONYMS.items())

In [6]:
import re

class AcronymNormalizer(object):
    """Replace every whole-word occurrence of an ``ACRONYMS`` key by its value."""

    # Regex alternation takes the first alternative that lets the whole pattern
    # match. With alphabetical order a key that is a prefix of a longer key
    # (and is followed by a word boundary) always wins, leaving the longer
    # entry unreachable, so the longest keys are tried first.
    pattern = re.compile(
        r"\b("
        + "|".join(map(re.escape, sorted(ACRONYMS, key=len, reverse=True)))
        + r")\b"
    )

    @classmethod
    def normalize(cls, text: str):
        """Return ``text`` with all matched acronyms expanded."""

        def replace_unit(match):
            return ACRONYMS[match.group(0)]

        return cls.pattern.sub(replace_unit, text)


### 1.2.2 Break Normalizer

In [7]:
# Spoken replacement (padded with spaces) for each break punctuation mark.
BREAKS = {".": " chấm ", ",": " phẩy "}

In [8]:
import re

class BreakNormalizer(object):
    """Tidy up '.' and ',' so each ends up as a separate, space-delimited token."""

    BREAKS = BREAKS

    duplicate_dot_comma_pattern = re.compile(r"([,.]){2,}")
    adjacent_symbols_pattern = re.compile(r"(\S)([,.])(\S)")
    left_symbol_pattern = re.compile(r"(\S)([,.])")
    right_symbol_pattern = re.compile(r"([,.])(\S)")

    @classmethod
    def normalize(cls, text):
        """Apply the four break rewrites in sequence and return the result."""
        # 1) A run of dots/commas shrinks to the last mark of that run.
        collapsed = cls.duplicate_dot_comma_pattern.sub(r"\1", text)

        # 2) A mark squeezed between two non-space characters is spelled out.
        def spell_out(found):
            before, mark, after = found.groups()
            return before + cls.BREAKS[mark] + after

        spelled = cls.adjacent_symbols_pattern.sub(spell_out, collapsed)

        # 3) and 4) Insert a space on whichever side still touches a character.
        detached_left = cls.left_symbol_pattern.sub(r"\1 \2", spelled)
        return cls.right_symbol_pattern.sub(r"\1 \2", detached_left)

### 1.2.3 Character Normalizer

In [9]:
import re


class CharacterNormalizer(object):
    """Strip every character outside the whitelist encoded in ``pattern``."""

    # Kept: ASCII letters and digits, whitespace, the listed Vietnamese
    # letters, '.' and ','.
    pattern = re.compile(r"[^a-zA-Z0-9\sđâăêôơư.,]")

    @classmethod
    def normalize(cls, text: str):
        """Return ``text`` with all non-whitelisted characters deleted."""
        return re.sub(cls.pattern, "", text)


### 1.2.4 Date Normalizer

In [10]:
import json


DATE_PREFIXS_FILEPATH = "./date_prefixs.json"

# Read the prefixes, then order them longest-first.
with open(DATE_PREFIXS_FILEPATH, mode="r", encoding="utf-8") as file:
    DATE_PREFIXS = json.load(file)
DATE_PREFIXS = sorted(DATE_PREFIXS, key=len, reverse=True)

In [11]:
import re


class DateNormalizer(object):
    """Spell out numeric dates using the Vietnamese words ngày / tháng / năm.

    Four shapes are recognised:
        1. ``dd/mm/yyyy``
        2. ``dd-mm-yyyy``
        3. ``mm/yyyy`` or ``mm-yyyy``
        4. ``<prefix> dd/mm`` or ``<prefix> dd-mm`` where ``<prefix>`` is one
           of ``DATE_PREFIXS``
    """

    DATE_PREFIXS = DATE_PREFIXS

    # Group 1 of patterns 1-3 captures an optional short word before the date.
    date_pattern1 = re.compile(r"(\b\w{0,4}\b)\s*([12][0-9]|3[01]|0?[1-9])\/(1[0-2]|0?[1-9])\/(\d{1,4})")
    date_pattern2 = re.compile(r"(\b\w{0,4}\b)\s*([12][0-9]|3[01]|0?[1-9])\-(1[0-2]|0?[1-9])\-(\d{1,4})")
    # Inside a character class ',' and '|' are literals: the former
    # ``1[0,1,2]`` also matched "1," and ``[\/|\-]`` also accepted "|" as a
    # separator. Both classes are now restricted to what was intended.
    date_pattern3 = re.compile(r"(\b\w{0,5}\b)\s*(0?[1-9]|1[0-2])[\/\-](\d{4})")
    # Escaped so a prefix containing a regex metacharacter is matched literally.
    prefixs = "|".join(map(re.escape, DATE_PREFIXS))
    date_pattern4 = re.compile(r"(" + prefixs + r")\s([12][0-9]|3[01]|0?[1-9])[\-\/](1[0-2]|0?[1-9])")

    @staticmethod
    def _replace_full_date(match):
        """Build the replacement for a day/month/year match (patterns 1 and 2)."""
        prefix = match.group(1).strip()
        day, month, year = match.group(2, 3, 4)
        # A captured "ngày" is not repeated; any other word is kept in front.
        lead = "" if prefix in ("", "ngày") else prefix + " "
        return f"{lead}ngày {day} tháng {month} năm {year}"

    @classmethod
    def normalize_date_pattern1(cls, text: str):
        """Rewrite dates such as ``11/12/2002``."""
        return cls.date_pattern1.sub(cls._replace_full_date, text)

    @classmethod
    def normalize_date_pattern2(cls, text: str):
        """Rewrite dates such as ``11-12-2002``."""
        return cls.date_pattern2.sub(cls._replace_full_date, text)

    @classmethod
    def normalize_date_pattern3(cls, text: str):
        """Rewrite month/year pairs, e.g. ``12/2022`` -> ``tháng 12 năm 2022``."""

        def replace(match):
            prefix, month, year = match.group(1, 2, 3)
            lead = "" if prefix in ("", "tháng") else prefix + " "
            return f"{lead}tháng {month} năm {year}"

        return cls.date_pattern3.sub(replace, text)

    @classmethod
    def normalize_date_pattern4(cls, text: str):
        """Rewrite prefixed day/month pairs, e.g. ``ngày 11/12``."""

        def replace(match):
            prefix, day, month = match.group(1, 2, 3)
            lead = "" if prefix in ("", "ngày") else prefix + " "
            return f"{lead}ngày {day} tháng {month}"

        return cls.date_pattern4.sub(replace, text)

    @classmethod
    def normalize(cls, text: str):
        """Apply pattern 4, then patterns 1, 2 and 3, and return the text."""
        text = cls.normalize_date_pattern4(text)
        text = cls.normalize_date_pattern1(text)
        text = cls.normalize_date_pattern2(text)
        text = cls.normalize_date_pattern3(text)
        return text


### 1.2.5 Letter Normalizer

In [12]:
import json

LETTERS_FILEPATH = "./letters.json"

# Read the letter table, then reorder its entries by key.
with open(LETTERS_FILEPATH, mode="r", encoding="utf-8") as file:
    LETTERS = json.load(file)
LETTERS = sort_by_key(LETTERS.items())

In [13]:
import re

class LetterNormalizer(object):
    """Swap each whole-word occurrence of a ``LETTERS`` key for its value."""

    pattern = re.compile(r"\b(" + "|".join(map(re.escape, LETTERS)) + r")\b")

    @classmethod
    def normalize(cls, text: str):
        """Return ``text`` with every matched key replaced."""
        return cls.pattern.sub(lambda found: LETTERS[found.group(0)], text)


### 1.2.6 Number Normalizer

In [14]:
import json

BASE_NUMBERS_FILEPATH = "./base_numbers.json"

# Decode explicitly as UTF-8, like the other resource loaders in this file;
# without it open() falls back to the platform's locale encoding.
with open(BASE_NUMBERS_FILEPATH, "r", encoding="utf-8") as file:
    # JSON object keys are always strings; convert them to ints for lookup.
    BASE_NUMBERS = {int(key): value for key, value in json.load(file).items()}

In [15]:
import json

NUMBER_LEVEL_FILEPATH = "./number_levels.json"

# Decode explicitly as UTF-8, like the other resource loaders in this file;
# without it open() falls back to the platform's locale encoding.
with open(NUMBER_LEVEL_FILEPATH, "r", encoding="utf-8") as file:
    # JSON object keys are always strings; convert them to ints for lookup.
    NUMBER_LEVELS = {int(key): value for key, value in json.load(file).items()}

In [16]:
import re

class NumberNomalizer(object):
    """Replace every run of digits in a text by its Vietnamese reading.

    Relies on two module-level tables: ``BASE_NUMBERS`` (int -> word) and
    ``NUMBER_LEVELS`` (magnitude -> word). The code treats the first level as
    1000 and each following level as 1000 times the previous one
    -- TODO confirm against number_levels.json.
    """

    pattern = re.compile(r"\d+")

    @classmethod
    def _convert_number_2_digits(cls, number: int):
        """Read a number below 100, preferring a direct ``BASE_NUMBERS`` entry."""
        if number in BASE_NUMBERS:
            return BASE_NUMBERS[number]

        tens = number // 10
        base = number % 10
        if base > 0:
            return f"{BASE_NUMBERS[tens]} mươi {BASE_NUMBERS[base]}"

        return f"{BASE_NUMBERS[tens]} mươi"

    @classmethod
    def _convert_number_3_digits(cls, number: int):
        """Read a three-digit group, always naming the hundreds digit.

        Returns an empty string for 0 so callers can skip empty groups.
        """
        if number == 0:
            return ""

        remainder = number % 100
        hundred = number // 100
        if remainder == 0:
            return f"{BASE_NUMBERS[hundred]} trăm"

        if remainder < 10:
            return f"{BASE_NUMBERS[hundred]} trăm linh {BASE_NUMBERS[remainder]}"

        return f"{BASE_NUMBERS[hundred]} trăm {cls._convert_number_2_digits(remainder)}"

    @classmethod
    def number_to_vietnamese(cls, number: int):
        """Return the Vietnamese reading of a non-negative integer.

        The number is split into groups of three digits. Groups equal to zero
        are skipped together with their level word, and the parts are joined
        by single spaces, so the result has no doubled or trailing spaces.
        """
        if number == 0:
            return "không"

        if number in BASE_NUMBERS:
            return BASE_NUMBERS[number]

        if number < 100:
            return cls._convert_number_2_digits(number)

        # Parts are collected from least to most significant, reversed at the end.
        parts = []
        lowest_group = cls._convert_number_3_digits(number % 1000)
        if lowest_group:
            parts.append(lowest_group)

        levels = sorted(NUMBER_LEVELS)
        for index, level in enumerate(levels):
            quotient = number // level
            if quotient == 0:
                break

            level_name = NUMBER_LEVELS[level]
            if quotient < 1000 or index == len(levels) - 1:
                # Leading group: read without hundreds padding. Above the
                # largest known level the whole quotient is read recursively
                # instead of being emitted twice.
                parts.append(f"{cls.number_to_vietnamese(quotient)} {level_name}")
                break

            group = quotient % 1000
            if group:
                parts.append(f"{cls._convert_number_3_digits(group)} {level_name}")

        return " ".join(reversed(parts))

    @classmethod
    def normalize(cls, text: str) -> str:
        """Return ``text`` with each digit run replaced by its reading."""
        return cls.pattern.sub(
            lambda found: cls.number_to_vietnamese(int(found.group())), text
        )


### 1.2.7 Phoneme Normalizer

In [17]:
import json

SAME_PHONEMES_FILEPATH = "./same_phonemes.json"

# Read the phoneme substitution table, then reorder its entries by key.
with open(SAME_PHONEMES_FILEPATH, mode="r", encoding="utf-8") as file:
    SAME_PHONEMES = json.load(file)
SAME_PHONEMES = sort_by_key(SAME_PHONEMES.items())

In [18]:
import re

class PhonemeNormalizer(object):
    """Replace every occurrence of a ``SAME_PHONEMES`` key by its value."""

    # Regex alternation is ordered: the first alternative that matches wins.
    # In alphabetical order a key that is a prefix of a longer key comes first
    # and makes the longer entry unreachable, so the longest keys go first.
    pattern = re.compile(
        r"("
        + "|".join(map(re.escape, sorted(SAME_PHONEMES, key=len, reverse=True)))
        + r")"
    )

    @classmethod
    def normalize(cls, text: str):
        """Return ``text`` with all matched spellings substituted."""

        def replace_symbol(match):
            return SAME_PHONEMES[match.group(0)]

        return cls.pattern.sub(replace_symbol, text)


### 1.2.8 Symbol Normalizer

In [19]:
import json

SYMBOLS_FILEPATH = "./symbols.json"

# Read the symbol table, then reorder its entries by key.
with open(SYMBOLS_FILEPATH, mode="r", encoding="utf-8") as file:
    SYMBOLS = json.load(file)
SYMBOLS = sort_by_key(SYMBOLS.items())

In [20]:
import re

class SymbolNormalizer(object):
    """Replace symbols from ``SYMBOLS`` that have a character on both sides."""

    # The pattern needs one character before and one after the symbol, so a
    # symbol at the very start or end of the text is left untouched.
    # Longest keys come first: alternation is ordered, and in alphabetical
    # order a symbol that is a prefix of a longer one would always win.
    pattern = re.compile(
        r"([\s\S])("
        + "|".join(map(re.escape, sorted(SYMBOLS, key=len, reverse=True)))
        + r")([\s\S])"
    )

    @classmethod
    def normalize(cls, text: str):
        """Return ``text`` with every matched symbol replaced by its value."""

        def replace_symbol(match):
            before, symbol, after = match.groups()
            # Each neighbour gets a space appended unless it is a space itself.
            # NOTE(review): the space is added after the following character,
            # not between the replacement and it -- confirm that the SYMBOLS
            # values carry their own surrounding spaces.
            return (
                (before if before == " " else before + " ")
                + SYMBOLS[symbol]
                + (after if after == " " else after + " ")
            )

        return cls.pattern.sub(replace_symbol, text)


### 1.2.9 Tone Normalizer

In [21]:
import json

TONES_FILEPATH = "./tones.json"

# The table is indexed with accented Vietnamese vowels, so decode explicitly
# as UTF-8 like the other resource loaders in this file; without it open()
# falls back to the platform's locale encoding.
with open(TONES_FILEPATH, "r", encoding="utf-8") as file:
    TONES = json.load(file)

In [22]:
import re

class ToneNormalizer(object):
    """Move the tone of an accented vowel to the end of its word.

    ``TONES`` maps each accented vowel to a ``(base, tone)`` pair; the vowel is
    replaced by ``base`` and ``tone`` is appended after the word.
    """

    pattern = re.compile(r"(\w*)([áàảãạấầẩẫậắằẳẵặéèẻẽẹếềểễệíìỉĩịóòỏõọốồổỗộớờởỡợúùủũụứừửữựýỳỷỹỵ])(\w*)")

    @classmethod
    def normalize(cls, text):
        """Return ``text`` with each matched word rewritten."""

        def relocate_tone(found):
            leading, accented_vowel, trailing = found.groups()
            base, tone = TONES[accented_vowel]
            return f"{leading}{base}{trailing}{tone}"

        return cls.pattern.sub(relocate_tone, text)


### 1.2.10 Unit Normalizer

In [23]:
import json

UNITS_FILEPATH = "./units.json"

# Read the unit table, then reorder its entries by key.
with open(UNITS_FILEPATH, mode="r", encoding="utf-8") as file:
    UNITS = json.load(file)
UNITS = sort_by_key(UNITS.items())

In [24]:
import re

class UnitNormalizer(object):
    """Replace every whole-word occurrence of a ``UNITS`` key by its value."""

    # Regex alternation takes the first alternative that lets the whole pattern
    # match. With alphabetical order a key that is a prefix of a longer key
    # (and is followed by a word boundary) always wins, leaving the longer
    # entry unreachable, so the longest keys are tried first.
    pattern = re.compile(
        r"\b("
        + "|".join(map(re.escape, sorted(UNITS, key=len, reverse=True)))
        + r")\b"
    )

    @classmethod
    def normalize(cls, text):
        """Return ``text`` with all matched units expanded."""

        def replace_unit(match):
            return UNITS[match.group(0)]

        return cls.pattern.sub(replace_unit, text)


### 1.2.11 Text Normalizer

In [25]:
# Normalisation stages, applied in this order by TextNormalizer. The order is
# significant: each stage receives the text produced by the one before it.
DEFAULT_PIPELINE = [
    # Runs before the number stage, which rewrites every run of digits and
    # would otherwise consume the digits the date patterns rely on.
    DateNormalizer,
    NumberNomalizer,
    LetterNormalizer,
    AcronymNormalizer,
    SymbolNormalizer,
    UnitNormalizer,
    PhonemeNormalizer,
    # Runs before the character filter, whose whitelist does not contain
    # tone-marked vowels.
    ToneNormalizer,
    CharacterNormalizer,
    BreakNormalizer,
]

In [26]:
class TextNormalizer(object):
    """Run a text through an ordered list of normalisation stages.

    Args:
        pipeline: Sequence of objects exposing ``normalize(text)``. Defaults
            to ``DEFAULT_PIPELINE`` when omitted or ``None``.
        lower: When true, the text is lower-cased before the first stage.
    """

    def __init__(self, pipeline=None, lower=True):
        # Keep a private copy so that editing ``self.pipeline`` never changes
        # the module-level default or the list passed in by the caller.
        self.pipeline = list(DEFAULT_PIPELINE if pipeline is None else pipeline)
        self.lower = lower

    def normalize(self, text):
        """Return ``text`` after every stage has been applied in order."""
        if self.lower:
            text = text.lower()

        for processor in self.pipeline:
            text = processor.normalize(text)

        return text

    def __call__(self, text):
        """Shorthand for :meth:`normalize`."""
        return self.normalize(text)

### 1.2.12 Text To Sequence

In [27]:
# Accent symbols, written as the digit strings '1' through '5'.
ACCENTS = [str(accent) for accent in range(1, 6)]

In [28]:
VOWELS_FILEPATH = "./vowels.json"

# Read the vowel list, then order it longest-first.
with open(VOWELS_FILEPATH, mode="r", encoding="utf-8") as file:
    VOWELS = json.load(file)
VOWELS = sorted(VOWELS, key=len, reverse=True)

In [29]:
HEAD_CONSONANTS_FILEPATH = "./head_consonants.json"

# Read the head-consonant list, then order it longest-first.
with open(HEAD_CONSONANTS_FILEPATH, mode="r", encoding="utf-8") as file:
    HEAD_CONSONANTS = json.load(file)
HEAD_CONSONANTS = sorted(HEAD_CONSONANTS, key=len, reverse=True)

In [30]:
FINAL_CONSONANTS_FILEPATH = "./final_consonants.json"

# Read the final-consonant list, then order it longest-first.
with open(FINAL_CONSONANTS_FILEPATH, mode="r", encoding="utf-8") as file:
    FINAL_CONSONANTS = json.load(file)
FINAL_CONSONANTS = sorted(FINAL_CONSONANTS, key=len, reverse=True)


In [31]:
# Full symbol inventory: vowels, head and final consonants, accents, break
# marks and the space, ordered longest-first (ties keep concatenation order).
PHONEMES = [*VOWELS, *HEAD_CONSONANTS, *FINAL_CONSONANTS, *ACCENTS, *BREAKS, " "]
PHONEMES.sort(key=len, reverse=True)

In [32]:
# Notebook cell: evaluates the size of the phoneme inventory for display.
len(PHONEMES)

92

In [33]:
# Map each symbol to its index in PHONEMES; a repeated symbol keeps the index
# of its last occurrence.
phoneme_to_ids = dict(zip(PHONEMES, range(len(PHONEMES))))

In [34]:
import re

class WordByPhonemesEmbedding(object):
    """Turn text into a flat list of phoneme ids, one word at a time.

    Each word is consumed left to right as head consonant, vowel and final
    consonant; whatever remains may contribute one trailing symbol.
    """

    def __init__(self, phonemes=PHONEMES, normalizer=TextNormalizer(), spliter=" "):
        self.phonemes = phonemes
        self.normalizer = normalizer
        self.spliter = spliter

    @staticmethod
    def _consume_prefix(alternatives, word):
        """Match one of ``alternatives`` at the start of ``word``.

        Returns the word with that prefix removed and the captured text
        (``None`` when nothing matched).
        """
        pattern = r'^(' + '|'.join(alternatives) + ')'
        found = re.match(pattern, word)
        captured = found.group(1) if found else None
        return re.sub(pattern, '', word), captured

    def _parse_head_constants(self, word):
        """Split off the head consonant; the returned key is prefixed with ``\\b``."""
        rest, captured = self._consume_prefix(HEAD_CONSONANTS, word)
        if captured is None:
            return rest, None
        return rest, r'\b' + captured

    def _parse_vowels(self, word):
        """Split off the leading vowel group."""
        return self._consume_prefix(VOWELS, word)

    def _parse_final_constants(self, word):
        """Split off the leading final consonant."""
        return self._consume_prefix(FINAL_CONSONANTS, word)

    def word2vec(self, word:str):
        """Decompose one word and collect the ids of the parts that were found."""
        rest, head_consonant = self._parse_head_constants(word)
        rest, vowel = self._parse_vowels(rest)
        rest, final_consonant = self._parse_final_constants(rest)

        embedding_vector = [
            phoneme_to_ids[part]
            for part in (head_consonant, vowel, final_consonant)
            if part is not None
        ]

        # The last leftover character is kept when it is a known symbol
        # (for example an accent digit or a break mark).
        if len(rest) > 0 and rest[-1] in PHONEMES:
            embedding_vector.append(phoneme_to_ids[rest[-1]])

        return {
            "head_consonant": head_consonant,
            "final_consonant": final_consonant,
            "vowel": vowel,
            "emmbedding_vector": embedding_vector
        }

    def embedding(self, text):
        """Normalise ``text`` and return its ids, with the separator id between words."""
        normalized = self.normalizer.normalize(text)
        sequence = []
        for token in normalized.split(self.spliter):
            sequence += self.word2vec(token)["emmbedding_vector"]
            sequence.append(phoneme_to_ids[self.spliter])
        # Drop the separator appended after the final word.
        return sequence[:-1]

    def __call__(self, text):
        """Shorthand for :meth:`embedding`."""
        return self.embedding(text)


## English Normalizer

In [None]:
import re


class ENAcronymNormalizer(object):
    """Expand dotted English abbreviations such as "mr." into full words."""

    # (compiled regex, replacement) pairs; each regex matches the abbreviation
    # at a word boundary followed by a literal dot, ignoring case.
    _acronyms = [
        (re.compile("\\b%s\\." % abbreviation, re.IGNORECASE), expansion)
        for abbreviation, expansion in (
            ("mrs", "misess"),
            ("mr", "mister"),
            ("dr", "doctor"),
            ("st", "saint"),
            ("co", "company"),
            ("jr", "junior"),
            ("maj", "major"),
            ("gen", "general"),
            ("drs", "doctors"),
            ("rev", "reverend"),
            ("lt", "lieutenant"),
            ("hon", "honorable"),
            ("sgt", "sergeant"),
            ("capt", "captain"),
            ("esq", "esquire"),
            ("ltd", "limited"),
            ("col", "colonel"),
            ("ft", "fort"),
        )
    ]

    def normalize_acronym(self, text):
        """Return ``text`` after applying every abbreviation rule in order."""
        expanded = text
        for matcher, expansion in self._acronyms:
            expanded = matcher.sub(expansion, expanded)
        return expanded

    def __call__(self, text):
        """Shorthand for :meth:`normalize_acronym`."""
        return self.normalize_acronym(text)

In [None]:
import re
import inflect


class ENNumberNormalizer(object):
    """Rewrite digits, decimals, ordinals and currency amounts as English words."""

    _inflect = inflect.engine()
    _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
    _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
    _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
    _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
    _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
    _number_re = re.compile(r"[0-9]+")

    def _remove_commas(self, m):
        """Delete the thousands separators from a matched number."""
        grouped_digits = m.group(1)
        return grouped_digits.replace(",", "")

    def _expand_decimal_point(self, m):
        """Spell the decimal point as the word "point"."""
        return " point ".join(m.group(1).split("."))

    def _expand_dollars(self, m):
        """Turn a dollar amount into "<n> dollar(s), <m> cent(s)" form."""
        amount = m.group(1)
        pieces = amount.split(".")

        # More than one dot: not a plain amount, just append the unit.
        if len(pieces) > 2:
            return amount + " dollars"

        whole = pieces[0]
        fraction = pieces[1] if len(pieces) > 1 else ""
        dollars = int(whole) if whole else 0
        cents = int(fraction) if fraction else 0

        dollar_unit = "dollar" if dollars == 1 else "dollars"
        cent_unit = "cent" if cents == 1 else "cents"

        if dollars and cents:
            return f"{dollars} {dollar_unit}, {cents} {cent_unit}"
        if dollars:
            return f"{dollars} {dollar_unit}"
        if cents:
            return f"{cents} {cent_unit}"
        return "zero dollars"

    def _expand_ordinal(self, m):
        """Spell out an ordinal such as "3rd"."""
        return self._inflect.number_to_words(m.group(0))

    def _expand_number(self, m):
        """Spell out an integer; values strictly between 1000 and 3000 get special forms."""
        value = int(m.group(0))
        to_words = self._inflect.number_to_words

        if not (value > 1000 and value < 3000):
            return to_words(value, andword="")
        if value == 2000:
            return "two thousand"
        if value > 2000 and value < 2010:
            return "two thousand " + to_words(value % 100)
        if value % 100 == 0:
            return to_words(value // 100) + " hundred"
        return to_words(value, andword="", zero="oh", group=2).replace(", ", " ")

    def normalize_numbers(self, text):
        """Apply all number rewrites in their fixed order and return the text."""
        rewrites = (
            (self._comma_number_re, self._remove_commas),
            (self._pounds_re, r"\1 pounds"),
            (self._dollars_re, self._expand_dollars),
            (self._decimal_number_re, self._expand_decimal_point),
            (self._ordinal_re, self._expand_ordinal),
            (self._number_re, self._expand_number),
        )
        for matcher, replacement in rewrites:
            text = re.sub(matcher, replacement, text)

        return text

    def __call__(self, text):
        """Shorthand for :meth:`normalize_numbers`."""
        return self.normalize_numbers(text)

In [None]:
import re


class WhiteSpaceNormalizer(object):
    """Squash every run of whitespace into a single space."""

    _whitespace_re = re.compile(r"\s+")

    def collapse_whitespace(self, text):
        """Return ``text`` with each whitespace run replaced by one space."""
        return self._whitespace_re.sub(" ", text)

    def __call__(self, text):
        """Shorthand for :meth:`collapse_whitespace`."""
        return self.collapse_whitespace(text)


In [None]:
import re
from unidecode import unidecode

class EnglishText2Sequence(object):
    """Convert English text into a list of integer symbol ids.

    The text is transliterated to ASCII, lower-cased, has numbers and
    abbreviations expanded and whitespace collapsed; each remaining character
    found in the symbol table is mapped to its id, and the end-of-sequence id
    is appended last.
    """

    _pad        = "_"
    _eos        = "~"
    _characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'\"(),-.:;? "

    symbols = [_pad, _eos] + list(_characters)

    _symbol_to_id = {s: i for i, s in enumerate(symbols)}
    _id_to_symbol = {i: s for i, s in enumerate(symbols)}
    # Splits "<before>{<inside>}<after>" on the first pair of curly braces.
    _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")

    _acronym_normalizer = ENAcronymNormalizer()
    _number_normalizer = ENNumberNormalizer()
    _whitespace_normalizer = WhiteSpaceNormalizer()

    def lowercase(self, text):
        """Return ``text`` lower-cased."""
        return text.lower()

    def convert_to_ascii(self, text):
        """Return ``text`` transliterated with ``unidecode``."""
        return unidecode(text)

    def normalize(self, text):
        """Run the full cleaning chain on ``text``."""
        text = self.convert_to_ascii(text)
        text = self.lowercase(text)
        text = self._number_normalizer(text)
        text = self._acronym_normalizer(text)
        text = self._whitespace_normalizer(text)

        return text

    def _symbols_to_sequence(self, symbols):
        """Map symbols to ids, skipping unknown symbols and the pad/eos marks."""
        return [self._symbol_to_id[s] for s in symbols if s in self._symbol_to_id and s not in ("_", "~")]

    def _arpabet_to_sequence(self, text):
        """Map space-separated phoneme tokens, looked up as "@<token>", to ids.

        This method was called by ``text_to_sequence`` but never defined, so
        any input containing ``{...}`` raised AttributeError. The symbol table
        above holds no "@"-prefixed entries, which means the tokens are
        currently skipped and an empty list is returned.
        """
        return self._symbols_to_sequence(["@" + s for s in text.split()])

    def text_to_sequence(self, text):
        """Return the id sequence for ``text``, terminated by the eos id.

        Text inside curly braces bypasses normalisation and is passed to
        ``_arpabet_to_sequence``.
        """
        sequence = []

        while len(text):
            m = self._curly_re.match(text)
            if not m:
                sequence += self._symbols_to_sequence(self.normalize(text))
                break

            sequence += self._symbols_to_sequence(self.normalize(m.group(1)))
            sequence += self._arpabet_to_sequence(m.group(2))
            text = m.group(3)

        sequence.append(self._symbol_to_id["~"])
        return sequence

    def __call__(self, text):
        """Shorthand for :meth:`text_to_sequence`."""
        return self.text_to_sequence(text)

# 2. Speech Encoder

## 2.1 Speech Encoder Audio Config

In [35]:
class SpeakerEncoderAudioConfig(AudioConfig):
    """Audio constants for the speaker-encoder feature pipeline.

    AudioPreprocessor and Utterance read these values through
    ``processor.config``.
    """

    N_MELS = 80
    SAMPLE_RATE = 16000
    # presumably expressed in seconds, given the multiplication by SAMPLE_RATE
    # below -- TODO confirm
    FRAME_SHIFT = 0.01
    FRAME_LENGTH = 0.025
    # Evaluate to 160 and 400; passed to librosa.stft as hop_length and
    # win_length.
    HOP_LENGTH = int(SAMPLE_RATE * FRAME_SHIFT)
    WIN_LENGTH = int(SAMPLE_RATE * FRAME_LENGTH)
    N_FFT = 1024
    FMIN = 90
    FMAX = 7600
    # Floor applied before log10 and lower clip bound after normalisation.
    ZERO_THRESHOLD = 1e-5
    # MIN_AMPLITUDE, MAX_AMPLITUDE and NUM_FRAMES are not read by any code
    # visible in this section.
    MIN_AMPLITUDE = 0.3
    MAX_AMPLITUDE = 1.0
    MIN_LEVEL_DB = -100
    REF_LEVEL_DB = 0
    NUM_FRAMES = 160 * 30
    # Multiplier applied to the peak-normalised waveform in Utterance.raw().
    SCALING_FACTOR = 0.95

## 2.2 Speech Encoder

### 2.2.1 Define Model

In [36]:
import math
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve
from torch.nn.utils import clip_grad_norm_
from scipy.optimize import brentq
import torch
import torch.nn as nn
import numpy as np

class SpeechTransformerEncoder(nn.Module):
    """Transformer encoder that maps mel spectrograms to unit-length embeddings.

    Also provides the GE2E similarity matrix and loss over a batch of
    embeddings.

    Args:
        input_size: Feature size of each frame; used as the transformer
            ``d_model``.
        hidden_size: Width of the feed-forward sub-layer.
        num_layers: Number of stacked encoder layers.
        num_heads: Number of attention heads.
        device: Device holding the encoder and the projection layer.
        loss_device: Device holding the similarity scale/bias and the loss.
    """

    def __init__(self, input_size=80, hidden_size=786, num_layers=12, num_heads=8, device='cpu', loss_device='cpu'):
        super().__init__()
        self.device = device
        self.loss_device = loss_device

        # Transformer encoder layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_size,
            nhead=num_heads,
            dim_feedforward=hidden_size,
            dropout=0.05,
            activation='relu',
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers).to(device)

        self.linear = nn.Linear(in_features=input_size, out_features=256).to(device)
        self.relu = nn.ReLU().to(device)

        # Cosine similarity scaling (with fixed initial parameter values).
        # The tensor is moved first and wrapped afterwards: calling .to() on a
        # Parameter returns a plain tensor when the device changes, which
        # would not be registered on the module and would never get a .grad.
        self.similarity_weight = nn.Parameter(torch.tensor([10.]).to(loss_device))
        self.similarity_bias = nn.Parameter(torch.tensor([-5.]).to(loss_device))

        # Loss
        self.loss_fn = nn.CrossEntropyLoss().to(loss_device)

    def do_gradient_ops(self):
        """Scale the similarity gradients by 0.01, then clip all gradients to norm 3."""
        # Gradient scale
        self.similarity_weight.grad *= 0.01
        self.similarity_bias.grad *= 0.01

        # Gradient clipping
        clip_grad_norm_(self.parameters(), 3, norm_type=2)

    def forward(self, utterances, hidden_init=None):
        """
        Computes the embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: not used in the Transformer version
        :return: the embeddings as a tensor of shape (batch_size, embedding_size)
        """
        utterances = utterances.to(self.device)
        # Pass the input through the Transformer Encoder
        out = self.transformer_encoder(utterances)

        # We take the mean of all time steps (similar to a global pooling)
        embeds_raw = self.relu(self.linear(out.mean(dim=1)))

        # L2-normalize it. normalize() clamps the norm at a tiny epsilon, so an
        # all-zero ReLU output gives a zero vector instead of NaN.
        embeds = nn.functional.normalize(embeds_raw, p=2, dim=1)

        return embeds

    def similarity_matrix(self, embeds):
        """
        Computes the similarity matrix according of GE2E.

        Each speaker needs at least two utterances: the exclusive centroid is
        divided by ``utterances_per_speaker - 1``.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the similarity matrix as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, speakers_per_batch)
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
        centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
        centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True)

        # Exclusive centroids (1 per utterance)
        centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
        centroids_excl /= (utterances_per_speaker - 1)
        centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True)

        # Similarity matrix computation: an utterance is compared with its own
        # speaker through the exclusive centroid and with every other speaker
        # through the inclusive centroid.
        sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
                                 speakers_per_batch).to(self.loss_device)
        mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int64)
        for j in range(speakers_per_batch):
            mask = np.where(mask_matrix[j])[0]
            sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
            sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)

        sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
        return sim_matrix

    def loss(self, embeds):
        """
        Computes the softmax loss according the section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the loss and the EER for this batch of embeddings.
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Loss
        sim_matrix = self.similarity_matrix(embeds)
        sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
                                         speakers_per_batch))
        ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
        target = torch.from_numpy(ground_truth).long().to(self.loss_device)
        loss = self.loss_fn(sim_matrix, target)

        # EER (not backpropagated)
        with torch.no_grad():
            # One-hot row with a 1 at the index of the true speaker.
            inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int64)[0]
            labels = np.array([inv_argmax(i) for i in ground_truth])
            preds = sim_matrix.detach().cpu().numpy()

            # Snippet from https://yangcha.github.io/EER-ROC/
            fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
            eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

        return loss, eer


### 2.2.2 Speech Encoder Model Config

In [37]:
class SpeechrTransformerEncoderModelConfigs:
    """Settings read by load_speaker_transformer_encoder."""

    # Device passed to the model constructor and to the final model.to().
    DEVICE = "cuda:0"
    # Device passed to the model as loss_device.
    LOSS_DEVICE = "cpu"
    # Checkpoint file handed to torch.load.
    MODEL_PATH = "./000000036000.pt"

### 2.2.3 Load Speech Encoder Model

In [38]:
import torch
import torch.nn as nn

def load_speaker_transformer_encoder(model_settings):
    """Build a SpeechTransformerEncoder and load its checkpoint.

    Args:
        model_settings: Object exposing ``DEVICE``, ``LOSS_DEVICE`` and
            ``MODEL_PATH``.

    Returns:
        The model in eval mode, moved to ``DEVICE`` and wrapped in
        ``nn.DataParallel``.
    """
    model = SpeechTransformerEncoder(
        device=model_settings.DEVICE, loss_device=model_settings.LOSS_DEVICE
    )
    # Map the stored tensors onto the configured device. The former hard-coded
    # 'cuda:0' made loading fail on machines without that GPU even when DEVICE
    # was set to "cpu".
    # SECURITY: weights_only=False unpickles arbitrary objects; only load
    # checkpoint files from a trusted source.
    ckpt = torch.load(
        model_settings.MODEL_PATH,
        weights_only=False,
        map_location=torch.device(model_settings.DEVICE),
    )

    if ckpt:
        # strict=False: keys that do not line up with the model are ignored.
        model.load_state_dict(ckpt["model_state_dict"], strict=False)

    model.eval()
    model.to(model_settings.DEVICE)
    return nn.DataParallel(model)

In [39]:
# Checkpoint loading is switched off in this cell: the call is kept commented
# out and the module-level encoder is left unset (None).
# SPEECH_TRANSFORMER_ENCODER = load_speaker_transformer_encoder(
#     SpeechrTransformerEncoderModelConfigs
# )

SPEECH_TRANSFORMER_ENCODER = None

# 3. Text to Speech

## 3.1 Text to Speech Audio Config

In [40]:
class Text2SpeechAudioConfig(AudioConfig):
    """Audio constants for the text-to-speech feature pipeline.

    NOTE(review): no SCALING_FACTOR is defined here, yet Utterance.raw() reads
    ``config.SCALING_FACTOR`` for non-.npy input -- confirm whether this
    config is ever used on that path.
    """

    N_MELS = 80
    SAMPLE_RATE = 16000
    N_FFT = 800
    # presumably expressed in seconds, given the multiplication by SAMPLE_RATE
    # below -- TODO confirm
    FRAME_SHIFT = 0.0125
    FRAME_LENGTH = 0.05
    REF_LEVEL_DB = 20
    # Evaluate to 200 and 800; passed to librosa.stft as hop_length and
    # win_length.
    HOP_LENGTH = int(SAMPLE_RATE * FRAME_SHIFT)
    WIN_LENGTH = int(SAMPLE_RATE * FRAME_LENGTH)
    # Coefficient read by AudioPreprocessor.apply_pre_emphasis.
    PRE_EMPHASIS = 0.97
    # Not read by any code visible in this section.
    POWER = 1.2
    FMIN = 90
    FMAX = 7600
    # Floor applied before log10 and lower clip bound after normalisation.
    ZERO_THRESHOLD = 1e-5
    MIN_LEVEL_DB = -100

## 3.2 Text to Speech Model

### 3.2.1 Positional Encoding

### 3.2.2 Model

In [42]:
import ast
import pprint

class HParams(object):
    """Attribute-style container for hyperparameters.

    Values are stored directly in the instance ``__dict__`` and can be read or
    written either as attributes (``hp.lr``) or as items (``hp["lr"]``).
    """

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def __getitem__(self, key):
        return getattr(self, key)

    def __repr__(self):
        return pprint.pformat(self.__dict__)

    def parse(self, string):
        """Override hyperparameters from a comma-separated ``name=value`` string.

        Each value is evaluated with ``ast.literal_eval``, so it must be a
        Python literal (number, quoted string, ...). When a name appears more
        than once, the last occurrence wins. Values that themselves contain a
        comma are not supported because pairs are split on ``,``.

        :param string: e.g. ``"sample_rate=22050, power=1.2"``; an empty string
            leaves the object untouched
        :return: ``self``, to allow chaining
        :raises ValueError: if a pair has no ``=`` or the value is not a literal
        """
        if len(string) > 0:
            for pair in string.split(","):
                # Split on the first "=" only, so values may contain "=".
                key, value = pair.split("=", 1)
                self.__dict__[key.strip()] = ast.literal_eval(value.strip())
        return self

# Global hyperparameter set for the synthesizer, stored in an `HParams`
# container so that each entry is reachable as an attribute
# (e.g. `hparams.sample_rate`) or as an item (`hparams["sample_rate"]`).
hparams = HParams(
        ### Signal Processing (used in both synthesizer and vocoder)
        sample_rate = 16000,
        n_fft = 800,
        num_mels = 80,
        hop_size = 200,                             # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
        win_size = 800,                             # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
        fmin = 55,
        min_level_db = -100,
        ref_level_db = 20,
        max_abs_value = 4.,                         # Gradient explodes if too big, premature convergence if too small.
        preemphasis = 0.97,                         # Filter coefficient to use if preemphasize is True
        preemphasize = True,

        ### Tacotron Text-to-Speech (TTS)
        tts_embed_dims = 512,                       # Embedding dimension for the graphemes/phoneme inputs
        tts_encoder_dims = 256,
        tts_decoder_dims = 128,
        tts_postnet_dims = 512,
        tts_encoder_K = 5,
        tts_lstm_dims = 1024,
        tts_postnet_K = 5,
        tts_num_highways = 4,
        tts_dropout = 0.5,
        tts_cleaner_names = ["english_cleaners"],
        tts_stop_threshold = -3.4,                  # Value below which audio generation ends.
                                                    # For example, for a range of [-4, 4], this
                                                    # will terminate the sequence at the first
                                                    # frame that has all values < -3.4

        ### Tacotron Training
        tts_schedule = [(2,  1e-3,  20_000,  2),   # Progressive training schedule
                        (2,  5e-4,  40_000,  2),   # (r, lr, step, batch_size)
                        (2,  2e-4,  80_000,  2),   #
                        (2,  1e-4, 160_000,  2),   # r = reduction factor (# of mel frames
                        (2,  3e-5, 320_000,  2),   #     synthesized for each decoder iteration)
                        (2,  1e-5, 640_000,  2)],  # lr = learning rate

        tts_clip_grad_norm = 1.0,                   # clips the gradient norm to prevent explosion - set to None if not needed
        tts_eval_interval = 500,                    # Number of steps between model evaluation (sample generation)
                                                    # Set to -1 to generate after completing epoch, or 0 to disable

        tts_eval_num_samples = 1,                   # Makes this number of samples

        ### Data Preprocessing
        max_mel_frames = 900,
        rescale = True,
        rescaling_max = 0.9,
        synthesis_batch_size = 16,                  # For vocoder preprocessing and inference.

        ### Mel Visualization and Griffin-Lim
        signal_normalization = True,
        power = 1.5,
        griffin_lim_iters = 60,

        ### Audio processing options
        fmax = 7600,                                # Should not exceed (sample_rate // 2)
        allow_clipping_in_normalization = True,     # Used when signal_normalization = True
        clip_mels_length = True,                    # If true, discards samples exceeding max_mel_frames
        use_lws = False,                            # "Fast spectrogram phase recovery using local weighted sums"
        symmetric_mels = True,                      # Sets mel range to [-max_abs_value, max_abs_value] if True,
                                                    #               and [0, max_abs_value] if False
        trim_silence = True,                        # Use with sample_rate of 16000 for best results

        ### SV2TTS
        speaker_embedding_size = 256,               # Dimension for the speaker embedding
        silence_min_duration_split = 0.4,           # Duration in seconds of a silence for an utterance to be split
        utterance_min_duration = 1.6,               # Duration in seconds below which utterances are discarded
        )

def hparams_debug_string():
    """Return the module-level ``hparams`` object rendered as a string."""
    return f"{hparams}"

In [43]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from typing import Union


class HighwayNetwork(nn.Module):
    """One highway layer.

    A sigmoid gate computed by ``W2`` blends the ReLU-activated output of
    ``W1`` with the untouched input: ``g * relu(W1 x) + (1 - g) * x``.
    """

    def __init__(self, size):
        super().__init__()
        self.W1 = nn.Linear(size, size)
        self.W2 = nn.Linear(size, size)
        # The transform branch starts out with a zero bias.
        nn.init.zeros_(self.W1.bias)

    def forward(self, x):
        gate = torch.sigmoid(self.W2(x))
        transformed = F.relu(self.W1(x))
        return gate * transformed + (1. - gate) * x


class Encoder(nn.Module):
    """Text encoder: embedding -> PreNet -> CBHG, optionally followed by
    concatenation of a speaker embedding onto every encoder step (SV2TTS)."""

    def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
        super().__init__()
        prenet_dims = (encoder_dims, encoder_dims)
        cbhg_channels = encoder_dims
        self.embedding = nn.Embedding(num_chars, embed_dims)
        self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
                              dropout=dropout)
        self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
                         proj_channels=[cbhg_channels, cbhg_channels],
                         num_highways=num_highways)

    def forward(self, x, speaker_embedding=None):
        """Encode a batch of token ids.

        :param x: integer tensor of token ids, indexed as (batch, num_chars)
        :param speaker_embedding: optional speaker embedding, see
            `add_speaker_embedding`
        :return: encoder output; when a speaker embedding is given its values
            are appended along the last dimension
        """
        x = self.embedding(x)
        x = self.pre_net(x)
        # The CBHG convolutions expect channels on dim 1.
        x = x.transpose(1, 2)
        x = self.cbhg(x)
        if speaker_embedding is not None:
            x = self.add_speaker_embedding(x, speaker_embedding)
        return x

    def add_speaker_embedding(self, x, speaker_embedding):
        """Append the speaker embedding to every step of the encoder output.

        :param x: encoder output of size (batch_size, num_chars, channels)
        :param speaker_embedding: either a 2-D tensor of size
            (batch_size, speaker_embedding_size) or a 1-D tensor of size
            (speaker_embedding_size,), which is shared by every item of the
            batch
        :return: tensor of size
            (batch_size, num_chars, channels + speaker_embedding_size)
        """
        batch_size, num_chars = x.size(0), x.size(1)

        # Treat a single embedding as a batch of one so that it broadcasts.
        if speaker_embedding.dim() == 1:
            speaker_embedding = speaker_embedding.unsqueeze(0)

        # (batch, 1, E) -> (batch, num_chars, E) without copying the data.
        e = speaker_embedding.unsqueeze(1).expand(batch_size, num_chars, -1)

        # Concatenate the tiled speaker embedding with the encoder output
        return torch.cat((x, e), 2)


class BatchNormConv(nn.Module):
    """1-D convolution (stride 1, ``kernel // 2`` padding, no bias), an
    optional ReLU, then batch normalisation — applied in that order."""

    def __init__(self, in_channels, out_channels, kernel, relu=True):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel,
            stride=1,
            padding=kernel // 2,
            bias=False,
        )
        self.bnorm = nn.BatchNorm1d(out_channels)
        self.relu = relu

    def forward(self, x):
        activated = self.conv(x)
        # Only the literal value True switches the activation on.
        if self.relu is True:
            activated = F.relu(activated)
        return self.bnorm(activated)


class CBHG(nn.Module):
    """Convolution bank + highway layers + bidirectional GRU.

    The input is expected with channels on dim 1 and time on the last dim;
    the output of the GRU (batch first) is returned.
    """

    def __init__(self, K, in_channels, channels, proj_channels, num_highways):
        super().__init__()

        # RNN modules on which `flatten_parameters()` has to be called
        self._to_flatten = []

        # Bank of convolutions with kernel sizes 1..K
        self.bank_kernels = list(range(1, K + 1))
        self.conv1d_bank = nn.ModuleList(
            [BatchNormConv(in_channels, channels, width) for width in self.bank_kernels]
        )

        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        bank_channels = len(self.bank_kernels) * channels
        self.conv_project1 = BatchNormConv(bank_channels, proj_channels[0], 3)
        self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)

        # A linear adapter is needed when the projection width differs from
        # the width the highway layers work on.
        if proj_channels[-1] == channels:
            self.highway_mismatch = False
        else:
            self.highway_mismatch = True
            self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)

        self.highways = nn.ModuleList(
            [HighwayNetwork(channels) for _ in range(num_highways)]
        )

        self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
        self._to_flatten.append(self.rnn)

        # Keep the RNN weights contiguous to avoid the fragmentation warning
        self._flatten_parameters()

    def forward(self, x):
        # Replication by DataParallel can leave the weights non-contiguous
        # again, so flatten on every call and not only at construction time.
        self._flatten_parameters()

        skip = x
        n_steps = x.size(-1)

        # Run every bank convolution, trim each output to the input length and
        # stack the results along the channel axis.
        bank_out = torch.cat(
            [conv(x)[:, :, :n_steps] for conv in self.conv1d_bank], dim=1
        )

        # Max-pooling pads by one step; drop the extra step to fit the residual
        pooled = self.maxpool(bank_out)[:, :, :n_steps]

        # Two convolutional projections followed by the residual connection
        projected = self.conv_project2(self.conv_project1(pooled))
        h = (projected + skip).transpose(1, 2)

        # Highway stack (with the width adapter when required)
        if self.highway_mismatch is True:
            h = self.pre_highway(h)
        for highway in self.highways:
            h = highway(h)

        # Bidirectional GRU; the final hidden state is discarded
        out, _ = self.rnn(h)
        return out

    def _flatten_parameters(self):
        """Call `flatten_parameters` on every RNN registered in `_to_flatten`."""
        for rnn in self._to_flatten:
            rnn.flatten_parameters()

class PreNet(nn.Module):
    """Two linear layers, each followed by ReLU and dropout.

    Dropout is applied with ``training=True`` regardless of the module's
    train/eval mode.
    """

    def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(in_dims, fc1_dims)
        self.fc2 = nn.Linear(fc1_dims, fc2_dims)
        self.p = dropout

    def forward(self, x):
        for layer in (self.fc1, self.fc2):
            # linear -> ReLU -> always-on dropout
            x = F.dropout(F.relu(layer(x)), self.p, training=True)
        return x


class Attention(nn.Module):
    """Additive attention: scores are a softmax over dim 1 of
    ``v(tanh(encoder_seq_proj + W(query)))``."""

    def __init__(self, attn_dims):
        super().__init__()
        self.W = nn.Linear(attn_dims, attn_dims, bias=False)
        self.v = nn.Linear(attn_dims, 1, bias=False)

    def forward(self, encoder_seq_proj, query, t):
        # `t` is accepted for interface compatibility but not used here.
        projected_query = self.W(query).unsqueeze(1)
        energies = self.v(torch.tanh(encoder_seq_proj + projected_query))
        weights = F.softmax(energies, dim=1)
        return weights.transpose(1, 2)


class LSA(nn.Module):
    """Location-sensitive attention.

    The running sum of all previous attention weights is convolved and added
    to the usual query/encoder terms before the scores are computed. The
    running sum (`cumulative`) and the latest weights (`attention`) are kept
    on the module between decoder steps and reset when ``t == 0``.
    """

    def __init__(self, attn_dim, kernel_size=31, filters=32):
        super().__init__()
        self.conv = nn.Conv1d(
            1,
            filters,
            padding=(kernel_size - 1) // 2,
            kernel_size=kernel_size,
            bias=True,
        )
        self.L = nn.Linear(filters, attn_dim, bias=False)
        # The attention bias lives in this layer
        self.W = nn.Linear(attn_dim, attn_dim, bias=True)
        self.v = nn.Linear(attn_dim, 1, bias=False)
        self.cumulative = None
        self.attention = None

    def init_attention(self, encoder_seq_proj):
        """Reset both attention buffers to zeros of size (batch, steps)."""
        param_device = next(self.parameters()).device
        batch, steps, _ = encoder_seq_proj.size()
        self.cumulative = torch.zeros(batch, steps, device=param_device)
        self.attention = torch.zeros(batch, steps, device=param_device)

    def forward(self, encoder_seq_proj, query, t, chars):
        # First decoder step: start from empty attention history
        if t == 0:
            self.init_attention(encoder_seq_proj)

        query_term = self.W(query).unsqueeze(1)

        # Location term derived from the accumulated attention weights
        history = self.cumulative.unsqueeze(1)
        location_term = self.L(self.conv(history).transpose(1, 2))

        energies = self.v(torch.tanh(query_term + encoder_seq_proj + location_term))
        energies = energies.squeeze(-1)

        # Energies at positions whose char id is 0 (padding) are set to zero
        energies = energies * (chars != 0).float()

        weights = F.softmax(energies, dim=1)
        self.attention = weights
        self.cumulative = self.cumulative + self.attention

        # (batch, steps) -> (batch, 1, steps)
        return weights.unsqueeze(1)


class Decoder(nn.Module):
    """Single-step Tacotron decoder.

    One call consumes the previous mel frame and the running states and
    produces up to `max_r` mel frames (truncated to the current reduction
    factor `r`), the attention scores, the updated states, the new context
    vector and a stop-token probability.
    """

    # Largest supported reduction factor. It is the same for every instance,
    # hence a class attribute rather than an instance one.
    max_r = 20

    def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
                 dropout, speaker_embedding_size):
        super().__init__()
        self.register_buffer("r", torch.tensor(1, dtype=torch.int))
        self.n_mels = n_mels

        prenet_width = decoder_dims * 2
        # Width of one encoder step once the speaker embedding is appended
        context_width = encoder_dims + speaker_embedding_size

        self.prenet = PreNet(n_mels, fc1_dims=prenet_width, fc2_dims=prenet_width,
                             dropout=dropout)
        self.attn_net = LSA(decoder_dims)
        self.attn_rnn = nn.GRUCell(context_width + prenet_width, decoder_dims)
        self.rnn_input = nn.Linear(context_width + decoder_dims, lstm_dims)
        self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
        self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
        self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
        self.stop_proj = nn.Linear(context_width + lstm_dims, 1)

    def zoneout(self, prev, current, p=0.1):
        """Keep each element of `prev` with probability `p`, else take `current`."""
        param_device = next(self.parameters()).device
        keep_prev = torch.zeros(prev.size(), device=param_device).bernoulli_(p)
        return prev * keep_prev + current * (1 - keep_prev)

    def _residual_lstm(self, rnn, x, hidden, cell):
        """Run one residual LSTM cell; zoneout is applied only in training mode."""
        hidden_next, cell = rnn(x, (hidden, cell))
        if self.training:
            hidden = self.zoneout(hidden, hidden_next)
        else:
            hidden = hidden_next
        return x + hidden, hidden, cell

    def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
                hidden_states, cell_states, context_vec, t, chars):
        # Needed to reshape the projected mels
        batch_size = encoder_seq.size(0)

        attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
        rnn1_cell, rnn2_cell = cell_states

        # Attention RNN, fed with the previous context and the pre-net output
        prenet_out = self.prenet(prenet_in)
        attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
        attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)

        # Attention scores and the context vector they select
        scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)
        context_vec = (scores @ encoder_seq).squeeze(1)

        # Project [context, attention state] to the LSTM width
        x = self.rnn_input(torch.cat([context_vec, attn_hidden], dim=1))

        # Two residual LSTM cells
        x, rnn1_hidden, rnn1_cell = self._residual_lstm(
            self.res_rnn1, x, rnn1_hidden, rnn1_cell
        )
        x, rnn2_hidden, rnn2_cell = self._residual_lstm(
            self.res_rnn2, x, rnn2_hidden, rnn2_cell
        )

        # Project to max_r frames and keep only the first r of them
        mels = self.mel_proj(x)
        mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]

        hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
        cell_states = (rnn1_cell, rnn2_cell)

        # Stop-token probability
        stop_tokens = torch.sigmoid(self.stop_proj(torch.cat((x, context_vec), dim=1)))

        return mels, scores, hidden_states, cell_states, context_vec, stop_tokens


class Tacotron(nn.Module):
    """Tacotron synthesizer conditioned on a speaker embedding (SV2TTS).

    `forward` runs the decoder with teacher forcing on ground-truth mel
    frames; `generate` runs it autoregressively on its own output. Both
    return mel frames, the post-net output and the attention scores
    (`forward` additionally returns the stop-token outputs).

    The reduction factor `r` (mel frames produced per decoder step) is stored
    as a buffer on the decoder and exposed through the `r` property.
    """

    def __init__(self, embed_dims=512, num_chars=81, encoder_dims=256, decoder_dims=128, n_mels=80,
                 fft_bins=80, postnet_dims=512, encoder_K=5, lstm_dims=1024, postnet_K=5, num_highways=4,
                 dropout=0.5, stop_threshold=-3.4, speaker_embedding_size=256):
        super().__init__()
        self.n_mels = n_mels
        self.lstm_dims = lstm_dims
        self.encoder_dims = encoder_dims
        self.decoder_dims = decoder_dims
        self.speaker_embedding_size = speaker_embedding_size
        self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
                               encoder_K, num_highways, dropout)
        self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
        self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
                               dropout, speaker_embedding_size)
        self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
                            [postnet_dims, fft_bins], num_highways)
        self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)

        self.init_model()
        self.num_params()

        self.register_buffer("step", torch.zeros(1, dtype=torch.long))
        self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))

    @property
    def r(self):
        return self.decoder.r.item()

    @r.setter
    def r(self, value):
        self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)

    def _init_decoder_state(self, batch_size, device):
        """Return zeroed (hidden_states, cell_states, go_frame, context_vec)."""
        def zeros(width):
            return torch.zeros(batch_size, width, device=device)

        # (attention GRU, residual LSTM 1, residual LSTM 2)
        hidden_states = (zeros(self.decoder_dims), zeros(self.lstm_dims), zeros(self.lstm_dims))
        cell_states = (zeros(self.lstm_dims), zeros(self.lstm_dims))
        # <GO> frame fed to the decoder on its first step
        go_frame = zeros(self.n_mels)
        context_vec = zeros(self.encoder_dims + self.speaker_embedding_size)
        return hidden_states, cell_states, go_frame, context_vec

    def _postprocess(self, mel_outputs):
        """Run the post-net and its projection on the generated mel frames."""
        postnet_out = self.postnet(mel_outputs)
        linear = self.post_proj(postnet_out)
        return linear.transpose(1, 2)

    def forward(self, x, m, speaker_embedding):
        """Teacher-forced pass.

        :param x: token ids, indexed as (batch, num_chars)
        :param m: ground-truth mel frames; unpacked as (batch, n_mels, steps)
        :param speaker_embedding: speaker embedding passed to the encoder
        :return: (mel_outputs, linear, attn_scores, stop_outputs)
        """
        device = next(self.parameters()).device  # use same device as parameters

        self.step += 1
        batch_size, _, steps = m.size()

        hidden_states, cell_states, go_frame, context_vec = \
            self._init_decoder_state(batch_size, device)

        # SV2TTS: Run the encoder with the speaker embedding
        # The projection avoids unnecessary matmuls in the decoder loop
        encoder_seq = self.encoder(x, speaker_embedding)
        encoder_seq_proj = self.encoder_proj(encoder_seq)

        mel_outputs, attn_scores, stop_outputs = [], [], []

        r = self.r
        for t in range(0, steps, r):
            # Teacher forcing: feed the ground-truth frame preceding step t
            prenet_in = m[:, :, t - 1] if t > 0 else go_frame
            mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
                self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
                             hidden_states, cell_states, context_vec, t, x)
            mel_outputs.append(mel_frames)
            attn_scores.append(scores)
            # One stop value per produced frame
            stop_outputs.extend([stop_tokens] * r)

        # Concat the mel outputs into sequence
        mel_outputs = torch.cat(mel_outputs, dim=2)

        linear = self._postprocess(mel_outputs)

        # For easy visualisation
        attn_scores = torch.cat(attn_scores, 1)
        stop_outputs = torch.cat(stop_outputs, 1)

        return mel_outputs, linear, attn_scores, stop_outputs

    def generate(self, x, speaker_embedding=None, steps=2000):
        """Autoregressive synthesis.

        The module is switched to eval mode for the duration of the call and
        its previous train/eval mode is restored afterwards.

        :param x: token ids, unpacked as (batch, num_chars)
        :param speaker_embedding: optional speaker embedding for the encoder
        :param steps: upper bound on the number of decoder time steps
        :return: (mel_outputs, linear, attn_scores)
        """
        # Remember the caller's mode so that an eval-mode model is not
        # silently flipped to training mode by this method.
        was_training = self.training
        self.eval()
        device = next(self.parameters()).device  # use same device as parameters

        batch_size, _ = x.size()

        hidden_states, cell_states, go_frame, context_vec = \
            self._init_decoder_state(batch_size, device)

        # SV2TTS: Run the encoder with the speaker embedding
        # The projection avoids unnecessary matmuls in the decoder loop
        encoder_seq = self.encoder(x, speaker_embedding)
        encoder_seq_proj = self.encoder_proj(encoder_seq)

        mel_outputs, attn_scores = [], []

        for t in range(0, steps, self.r):
            # Feed back the last frame produced by the previous step
            prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
            mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
                self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
                             hidden_states, cell_states, context_vec, t, x)
            mel_outputs.append(mel_frames)
            attn_scores.append(scores)
            # Stop once every item of the batch reports a stop probability
            # above 0.5 (never within the first steps, t <= 10). The
            # `stop_threshold` buffer is not consulted here.
            if (stop_tokens > 0.5).all() and t > 10:
                break

        # Concat the mel outputs into sequence
        mel_outputs = torch.cat(mel_outputs, dim=2)

        linear = self._postprocess(mel_outputs)

        # For easy visualisation
        attn_scores = torch.cat(attn_scores, 1)

        self.train(was_training)

        return mel_outputs, linear, attn_scores

    def init_model(self):
        """Xavier-initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def get_step(self):
        return self.step.data.item()

    def reset_step(self):
        # assignment to parameters or buffers is overloaded, updates internal dict entry
        # NOTE(review): this stores the scalar value 1, whereas the buffer is
        # created as zeros of shape (1,) — confirm this is the intended reset.
        self.step = self.step.data.new_tensor(1)

    def log(self, path, msg):
        """Append `msg` as one line to the text file at `path`."""
        with open(path, "a") as f:
            print(msg, file=f)

    def load(self, path, optimizer=None):
        """Load model (and, if present and requested, optimizer) state from `path`."""
        # Use device of model params as location for loaded state
        device = next(self.parameters()).device
        checkpoint = torch.load(str(path), map_location=device)
        self.load_state_dict(checkpoint["model_state"])

        if "optimizer_state" in checkpoint and optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer_state"])

    def save(self, path, optimizer=None):
        """Save the model state, plus the optimizer state when one is given."""
        state = {"model_state": self.state_dict()}
        if optimizer is not None:
            state["optimizer_state"] = optimizer.state_dict()
        torch.save(state, str(path))

    def num_params(self, print_out=True):
        """Return (and optionally print) the trainable parameter count in millions."""
        trainable = (p for p in self.parameters() if p.requires_grad)
        parameters = sum([np.prod(p.size()) for p in trainable]) / 1_000_000
        if print_out:
            print("Trainable Parameters: %.3fM" % parameters)
        return parameters

In [44]:
# Build the synthesizer with its default arguments; the constructor prints
# the trainable-parameter count.
model = Tacotron()

Trainable Parameters: 30.878M


In [45]:
# Initialise the model from the "synthesizer.pt" checkpoint while keeping the
# freshly initialised text embedding of the encoder.
checkpoint = torch.load("synthesizer.pt")
model_dict = model.state_dict()

# Filter out the embedding layer from the checkpoint
filtered_dict = {k: v for k, v in checkpoint['model_state'].items() if "encoder.embedding" not in k}

# Load the filtered state dict
# (entries missing from the checkpoint keep the values of the current model)
model_dict.update(filtered_dict)
model.load_state_dict(model_dict)

  checkpoint = torch.load("synthesizer.pt")


<All keys matched successfully>

# 4. Dataset

## 4.1 Dataset Model

In [46]:
import numpy as np
from pathlib import Path

def load_mel(file_path: Path):
    """
    Load a Mel spectrogram saved as a .npy file.

    :param file_path: Path to the .npy file containing the Mel spectrogram
    :return: Numpy array of the Mel spectrogram
    """
    if not file_path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")
    try:
        mel_spectrogram = np.load(file_path)
    except EOFError:
        print(f"File could not be loaded due to EOFError: {random_file}")
        return False
    return mel_spectrogram


In [47]:
import json
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset
import os

class TTSDataset(Dataset):
    """Dataset of (text, mel spectrogram, speaker embedding) samples.

    The sample list is built from the file names found in ``./trimmed_mels``;
    cached speaker embeddings are looked up under ``./encoded_speechs/<name>/``.
    """

    def __init__(self, root: Path=None, text_map: dict=None, text_map_file: Path=None, tokenizer=None):
        """
        :param root: Path to the audio files directory
        :param text_map: A dictionary mapping audio file names to their corresponding texts
        :param text_map_file: Path to the JSON file containing the text map
            (read only when ``text_map`` is not given)
        :param tokenizer: A custom tokenizer for text processing
        """
        self.root = root
        self.text_map = text_map if text_map is not None else self.load_text_map(text_map_file)
        self.encoded_speech_dir = Path("./encoded_speechs")
        self.mel_dir = Path("./trimmed_mels")
        # Sample names are the mel file names without their 4-character suffix
        self.audios = [f[:-4] for f in os.listdir(self.mel_dir)]
        self.texts = self.text_map
        self.tokenizer = tokenizer if tokenizer is not None else self.default_tokenizer

    def load_text_map(self, text_map_file: Path):
        """Read the name -> text mapping from a JSON file."""
        with text_map_file.open('r') as file:
            text_map = json.load(file)
        return text_map

    def __len__(self):
        return len(self.audios)

    def _get_audio_input(self, speech_output):
        """
        Build the decoder input for teacher forcing: the target shifted by one
        step along axis 0, with a leading all-zero frame of width N_MELS.
        """
        speech_input = np.concatenate([np.zeros([1, Text2SpeechAudioConfig.N_MELS], np.float32),
                                       speech_output[:-1, :]], axis=0)
        return speech_input

    def _get_audio_output(self, idx):
        """
        Get ground truth audio (target output), loading from .npy file if it exists.

        Falls back to computing the mel spectrogram from ``root/<name>.npy``
        when no precomputed mel file is found.
        """
        mel_file = self.mel_dir / f"{self.audios[idx]}.npy"

        if mel_file.exists():
            return load_mel(mel_file)
        else:
            print(f"not found: {mel_file}")
            utterance = Utterance(
                raw_file=self.root / f"{self.audios[idx]}.npy",
                processor=AudioPreprocessor(Text2SpeechAudioConfig)
            )
            return utterance.mel_in_db().T

    def default_tokenizer(self, text):
        """
        Default tokenizer used when no custom tokenizer is provided.
        Maps every character of ``text`` to its Unicode code point.
        """
        return [ord(char) for char in text]

    def _encode_speech_from_audio(self, idx):
        """Compute a speaker embedding with the speech encoder for sample ``idx``.

        :return: the encoder output as a numpy array
        :raises RuntimeError: if ``SPEECH_TRANSFORMER_ENCODER`` has not been loaded
        """
        if SPEECH_TRANSFORMER_ENCODER is None:
            raise RuntimeError(
                f"No cached speaker embedding for '{self.audios[idx]}' and "
                "SPEECH_TRANSFORMER_ENCODER is not loaded"
            )

        utterance = Utterance(raw_file=self.root / self.audios[idx],
                              processor=AudioPreprocessor(SpeakerEncoderAudioConfig))

        random_mel = torch.tensor(np.array([utterance.random_mel_in_db(4800)])).transpose(1, 2)

        with torch.no_grad():
            encoded_speech = SPEECH_TRANSFORMER_ENCODER(random_mel)

        # Always hand back a numpy array, like the cached path does, so that
        # batches can be stacked with numpy regardless of where the embedding
        # came from.
        return encoded_speech.cpu().numpy()

    def _get_encoded_speech(self, idx):
        """
        Get a speaker embedding for sample ``idx``.

        A random cached ``.npy`` file from the sample's sub-folder is used when
        available; if there is none, or the chosen file is truncated, the
        embedding is computed with the speech encoder instead.
        """
        encoded_speech_subfolder = self.encoded_speech_dir / self.audios[idx]
        encoded_speech_files = list(encoded_speech_subfolder.glob("*.npy"))

        if encoded_speech_files:
            random_file = random.choice(encoded_speech_files)
            try:
                return np.load(random_file)
            except EOFError:
                # Truncated cache file: recompute the embedding below.
                pass

        return self._encode_speech_from_audio(idx)

    def __getitem__(self, idx):
        """
        Return ``(text_sequence, output_audio, encoded_speech, idx, mel_length)``
        where ``mel_length`` is ``output_audio.shape[1]``.
        """
        # Text to sequence using custom tokenizer
        text_sequence = self.tokenizer(self.texts[self.audios[idx]])

        # Get Encoding Speech
        encoded_speech = self._get_encoded_speech(idx)

        # Get the audio output (target output)
        output_audio = self._get_audio_output(idx)

        return text_sequence, output_audio, encoded_speech, idx, output_audio.shape[1]

In [48]:
def collate_synthesizer(batch, r, hparams):
    """Collate dataset samples into padded batch tensors.

    Each sample is indexed as (text, mel, speaker embedding, index, mel length).
    Texts are right-padded with 0 to the longest text; mels are right-padded
    with 0 along their last axis to one frame more than the longest mel.
    `r` and `hparams` are accepted but not used.

    :return: (texts, mel, embeds, indices, mel_len) — `indices` stays a Python
        list, everything else is a tensor
    """
    # Text: pad every sequence to the longest one in the batch
    longest_text = max(len(sample[0]) for sample in batch)
    texts = np.stack([pad1d(sample[0], longest_text) for sample in batch])

    # Mel spectrogram: zero-pad to (longest + 1) frames
    padded_frames = max(sample[1].shape[-1] for sample in batch) + 1
    mel = np.stack([pad2d(sample[1], padded_frames, pad_value=0) for sample in batch])

    # Speaker embedding (SV2TTS)
    embeds = np.array([sample[2] for sample in batch])

    # Index (for vocoder preprocessing) and original mel lengths
    indices = [sample[3] for sample in batch]
    mel_len = [sample[4] for sample in batch]

    return (
        torch.tensor(texts).long(),
        torch.tensor(mel),
        torch.tensor(embeds),
        indices,
        torch.tensor(mel_len),
    )

def pad1d(x, max_len, pad_value=0):
    """Right-pad the 1-D sequence `x` with `pad_value` up to length `max_len`."""
    missing = max_len - len(x)
    return np.pad(x, (0, missing), mode="constant", constant_values=pad_value)

def pad2d(x, max_len, pad_value=0):
    """Right-pad the last axis of the 2-D array `x` with `pad_value` up to `max_len`."""
    missing = max_len - x.shape[-1]
    # First axis is left untouched; only the end of the second axis grows.
    pad_widths = ((0, 0), (0, missing))
    return np.pad(x, pad_widths, mode="constant", constant_values=pad_value)

## 4.2 Load Dataset

In [49]:
from pathlib import Path

# Root directory passed to TTSDataset as `root` (the current working directory).
DATA_PATH = Path(r".")
# JSON file passed to TTSDataset as `text_map_file`.
# Presumably maps audio entries to their transcripts — TODO confirm against TTSDataset.
TEXT_MAP_PATH = Path(r"./transcripts.json")

In [50]:
# Build the full dataset from the paths above, tokenizing text with a
# freshly constructed WordByPhonemesEmbedding.
dataset = TTSDataset(root=DATA_PATH, text_map_file=TEXT_MAP_PATH, tokenizer=WordByPhonemesEmbedding())

In [51]:
from torch.utils.data import random_split

# 80/20 train/test split; the test set takes whatever the integer
# truncation leaves over so the two sizes always sum to len(dataset).
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# A fixed generator seed keeps the split reproducible across runs.
train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42))

In [52]:
# Sanity check: number of fields in one sample (the recorded output is 5).
len(train_dataset[0])

5

# 5. Training

In [53]:
def save_checkpoint(
    model, optimizer, epoch, train_losses, eval_losses, save_path, subset=None
):
    """
    Save a training checkpoint under `save_path`.

    The file is written to
    ``{save_path}/model_epoch_{epoch}_subset_{subset}.pt`` and holds the epoch,
    the model and optimizer state dicts, and both loss histories.

    Args:
        model: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``.
        epoch: Epoch number stored in the checkpoint and used in the file name.
        train_losses: Training loss history to store.
        eval_losses: Evaluation loss history to store.
        save_path: Directory the checkpoint file is written into.
        subset: Label used in the file name. When omitted, the module-level
            ``current_subset`` is used if it exists, otherwise ``"full"``.
    """
    if subset is None:
        # Earlier versions read the global `current_subset` directly and
        # raised NameError when it was not defined yet; keep honouring the
        # global when present, but fall back to a sane label otherwise.
        subset = globals().get("current_subset", "full")

    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "train_losses": train_losses,
        "eval_losses": eval_losses,
    }
    model_path = f"{save_path}/model_epoch_{epoch}_subset_{subset}.pt"
    torch.save(checkpoint, model_path)
    print(f"Checkpoint saved at {model_path}")


In [54]:
import os
def load_checkpoint(save_path, model, optimizer, device):
    """
    Restore training state from `save_path` when that file exists.

    Returns (model, optimizer, epoch, train_losses, eval_losses). When no
    file is found, the model and optimizer are returned untouched together
    with epoch 0 and two empty loss lists.
    """
    # Guard clause: nothing on disk means a fresh run.
    if not os.path.exists(save_path):
        print(f"No checkpoint found at {save_path}, starting fresh.")
        return model, optimizer, 0, [], []

    state = torch.load(save_path, map_location=device)
    model.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])
    # Pull the bookkeeping fields before reporting success.
    resume_state = (state["epoch"], state["train_losses"], state["eval_losses"])
    print(f"Checkpoint loaded from {save_path}")
    return (model, optimizer, *resume_state)


In [55]:
from datetime import datetime
from functools import partial
from pathlib import Path
import os
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader

def np_now(x: torch.Tensor):
    """Return `x` as a NumPy array, detached from autograd and moved to the CPU."""
    detached = x.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()

def time_string():
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM``."""
    now = datetime.now()
    return now.strftime("%Y-%m-%d %H:%M")

def save_checkpoint(model, optimizer, epoch, train_losses, eval_losses, save_path, subset="full"):
    """
    Write a checkpoint to ``{save_path}/model_epoch_{epoch}_subset_{subset}.pt``.

    The stored dict contains the epoch, the model and optimizer state dicts,
    and the training / evaluation loss histories.
    """
    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        train_losses=train_losses,
        eval_losses=eval_losses,
    )
    model_path = f"{save_path}/model_epoch_{epoch}_subset_{subset}.pt"
    torch.save(payload, model_path)
    print(f"Checkpoint saved at {model_path}")

def load_checkpoint(save_path, model, optimizer, device):
    """
    Load a checkpoint from `save_path` into `model` and `optimizer`.

    Returns (model, optimizer, epoch, train_losses, eval_losses). If the path
    does not exist, the inputs are returned unchanged with epoch 0 and empty
    loss histories.
    """
    found = os.path.exists(save_path)
    if not found:
        print(f"No checkpoint found at {save_path}, starting fresh.")
        return model, optimizer, 0, [], []

    ckpt = torch.load(save_path, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    last_epoch, train_history, eval_history = (
        ckpt["epoch"],
        ckpt["train_losses"],
        ckpt["eval_losses"],
    )
    print(f"Checkpoint loaded from {save_path}")
    return model, optimizer, last_epoch, train_history, eval_history

def train(model, train_dataset, eval_dataset, hparams):
    """
    Train `model` through every session listed in `hparams.tts_schedule`.

    Each schedule entry unpacks into (r, lr, max_step, batch_size). For every
    session the data loaders are rebuilt with that batch size, the optimizer
    learning rate is overwritten with `lr`, and epochs run until either epoch
    99 is reached or `model.get_step()` reaches `max_step`.

    Args:
        model: Callable as ``model(texts, mels, embeds)`` returning four
            values; it must also expose ``r`` and ``get_step()``.
        train_dataset: Dataset fed to the training DataLoader.
        eval_dataset: Dataset fed to the evaluation DataLoader.
        hparams: Object providing ``tts_schedule``; also forwarded to the
            collate function and to ``evaluate``.

    Returns:
        None. Progress is printed, and checkpoints are written every 10 epochs.
    """
    # NOTE(review): resuming reads "saved_models/synthesizer.pt", while the
    # periodic checkpoints below are written by save_checkpoint under
    # per-epoch file names — confirm something actually produces this file.
    weights_fpath = Path("saved_models/synthesizer.pt")
    save_dir = Path("saved_models")
    save_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    model = model.to(device)
    # Adam starts with its default learning rate; each session overrides it below.
    optimizer = optim.Adam(model.parameters())

    model, optimizer, start_epoch, train_losses, eval_losses = load_checkpoint(weights_fpath, model, optimizer, device)

    for session in hparams.tts_schedule:
        r, lr, max_step, batch_size = session
        model.r = r

        collate_fn = partial(collate_synthesizer, r=r, hparams=hparams)
        train_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=2, collate_fn=collate_fn)
        eval_loader = DataLoader(eval_dataset, batch_size, shuffle=False, num_workers=2, collate_fn=collate_fn)

        # Apply this session's learning rate to every parameter group.
        for p in optimizer.param_groups:
            p['lr'] = lr

        # NOTE(review): start_epoch is never updated, so every session
        # iterates over the same epoch range — confirm this is intended.
        for epoch in range(start_epoch + 1, 100):  # Continue from last epoch
            total_loss = 0
            for texts, mels, embeds, idx, mel_lens in train_loader:
                texts, mels, embeds = texts.to(device), mels.to(device), embeds.to(device)

                # Stop-token target: 1 from position mel_len - 1 onward, 0 before it.
                # The loop variable `k` is unused; only the row index `j` matters.
                stop = torch.ones(mels.shape[0], mels.shape[2]).to(device)
                for j, k in enumerate(idx):
                    stop[j, :int(mel_lens[j]) - 1] = 0

                # The third model output is ignored here.
                m1_hat, m2_hat, _, stop_pred = model(texts, mels, embeds)

                # First output: MSE + L1 against the target; second output: MSE only.
                m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
                m2_loss = F.mse_loss(m2_hat, mels)
                stop_loss = F.binary_cross_entropy(stop_pred, stop)

                loss = m1_loss + m2_loss + stop_loss
                total_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Mean loss per batch for this epoch.
            avg_loss = total_loss / len(train_loader)
            print(f"Epoch {epoch} - Loss: {avg_loss:.4f}")

            train_losses.append(avg_loss)

            # Leaving the session here skips the evaluation/checkpoint step
            # below for the epoch that reached max_step.
            if model.get_step() >= max_step:
                break

            if epoch % 10 == 0:  # Evaluate every 10 epochs
                eval_loss = evaluate(model, eval_loader, device, hparams)
                eval_losses.append(eval_loss)
                save_checkpoint(model, optimizer, epoch, train_losses, eval_losses, save_dir)

def evaluate(model, data_loader, device, hparams):
    """
    Compute the average loss of `model` over `data_loader` without gradients.

    Batches are expected in the layout produced by ``collate_synthesizer``:
    (texts, mels, embeds, indices, mel_lens). The older four-item layout
    (texts, mels, embeds, indices) is still accepted; in that case the stop
    target falls back to all ones because no lengths are available.

    Args:
        model: Callable as ``model(texts, mels, embeds)`` returning four
            values, the last being the stop prediction.
        data_loader: Sized iterable of batches.
        device: Device the batch tensors are moved to.
        hparams: Unused; kept so the signature matches the call in ``train``.

    Returns:
        The mean per-batch loss as a float. The model is switched back to
        training mode before returning.
    """
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in data_loader:
            # Tolerate both batch layouts instead of failing with
            # "too many values to unpack" on the five-item one.
            texts, mels, embeds, *extras = batch
            mel_lens = extras[1] if len(extras) > 1 else None
            texts, mels, embeds = texts.to(device), mels.to(device), embeds.to(device)

            # Same stop target as the training loop: 1 from position
            # mel_len - 1 onward, 0 before it.
            stop = torch.ones(mels.shape[0], mels.shape[2]).to(device)
            if mel_lens is not None:
                for j, mel_len in enumerate(mel_lens):
                    stop[j, :int(mel_len) - 1] = 0

            m1_hat, m2_hat, _, stop_pred = model(texts, mels, embeds)

            m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
            m2_loss = F.mse_loss(m2_hat, mels)
            stop_loss = F.binary_cross_entropy(stop_pred, stop)

            loss = m1_loss + m2_loss + stop_loss
            total_loss += loss.item()

    avg_loss = total_loss / len(data_loader)
    print(f"Evaluation Loss: {avg_loss:.4f}")
    model.train()
    return avg_loss


In [56]:
# Start training, using the held-out test split as the evaluation set.
train(model, train_dataset, test_dataset, hparams)

Using device: cuda
No checkpoint found at saved_models/synthesizer.pt, starting fresh.


RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`

In [None]:
import time

import torch
import torch.nn as nn
import torch.optim as optim

# Adjust learning rate function
def adjust_learning_rate(optimizer, step_num, warmup_step=4000, base_lr=None):
    """
    Set the optimizer learning rate from a warmup-then-decay schedule.

    The rate grows linearly with `step_num` until `warmup_step`, where it
    equals the base rate, and then decays with the inverse square root of
    the step.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts).
        step_num: Current step. Values below 1 are treated as 1, because the
            formula raises a zero step to a negative power.
        warmup_step: Number of warmup steps.
        base_lr: Rate reached at the end of warmup. Defaults to the
            module-level ``LEARNING_RATE`` when omitted.
    """
    if base_lr is None:
        base_lr = LEARNING_RATE
    # Step 0 would raise ZeroDivisionError in `step_num ** -0.5`.
    step_num = max(step_num, 1)
    lr = base_lr * warmup_step**0.5 * min(step_num * warmup_step**-1.5, step_num**-0.5)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [None]:
# Show how many CUDA devices PyTorch can see.
torch.cuda.device_count()

In [None]:
# Print the name of every visible CUDA device.
for i in range(torch.cuda.device_count()):
   print(torch.cuda.get_device_properties(i).name)

In [None]:
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split

def split_train_dataset(dataset, num_splits=4):
    """
    Cut `dataset` into `num_splits` contiguous `Subset`s, in index order.

    Every subset holds ``len(dataset) // num_splits`` items except the last
    one, which also takes whatever remainder is left.
    """
    total = len(dataset)
    all_indices = list(range(total))
    chunk = total // num_splits
    last_part = num_splits - 1

    def _bounds(part):
        # The final part runs to the end so no sample is dropped.
        stop = total if part == last_part else (part + 1) * chunk
        return part * chunk, stop

    return [
        Subset(dataset, all_indices[slice(*_bounds(part))])
        for part in range(num_splits)
    ]

In [None]:
# With num_splits=1 the result is a single subset covering all of train_dataset.
train_splits = split_train_dataset(train_dataset, num_splits=1)

In [None]:
# model_save_path ='./saved_models/model_epoch_59_subset_0.pt'
# model, optimizer, start_epoch, global_step, current_subset, train_losses, eval_losses = load_checkpoint(
#     save_path=model_save_path, model=model, optimizer=optimizer, device=device
# )
# current_subset = 0
# start_epoch += 1
# if current_subset == 3:
#     current_subset = 0
#     start_epoch += 1
# else:
#     current_subset += 1

In [None]:
# Check that a CUDA device is usable before the GPU-only setup that follows.
torch.cuda.is_available()

In [None]:
# Split dataset into subsets (num_splits=1 keeps the whole training set as one subset)
train_splits = split_train_dataset(train_dataset, num_splits=1)
EPOCHS = 10000
# Base learning rate; adjust_learning_rate also reads this name as a global.
LEARNING_RATE = 1e-4
# Default Model, Optimizer and params
# NOTE(review): the device is hard-coded to the first GPU with no CPU fallback.
device = "cuda:0"
# Move the model to the device, then wrap it in DataParallel.
model = nn.DataParallel(model.to(device))
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Fresh-run bookkeeping handed to train_and_evaluate_splits in the next cell.
global_step = 0
current_subset = 0
start_epoch = 0
train_losses = []
eval_losses = []
# Directory passed on as `save_path` to the training call in the next cell.
save_path = "./saved_models2"

In [None]:
# Launch training over the prepared subsets with the state set up above.
# NOTE(review): `train_and_evaluate_splits` and `eval_dataloader` are not
# defined in the visible part of this notebook — confirm they exist before
# running this cell.
train_and_evaluate_splits(
    model, optimizer, criterion, train_splits, eval_dataloader,
    device, num_epochs=EPOCHS, start_epoch=start_epoch, save_path=save_path,
    global_step=global_step, current_subset=current_subset,
    train_losses=train_losses, eval_losses=eval_losses
)