# StyleTTS 2 Demo (ClArTTS)


### Utils

In [1]:
import torch
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

import random
random.seed(0)

import numpy as np
np.random.seed(0)

In [2]:
# load packages
import time
import random
import yaml
from munch import Munch
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
import librosa
from nltk.tokenize import word_tokenize

from models import *
from utils import *
from text_utils import TextCleaner
text_cleaner = TextCleaner()

49


In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
mean, std = -4, 4

def length_to_mask(lengths):
    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    mask = torch.gt(mask+1, lengths.unsqueeze(1))
    return mask

def preprocess(wave):
    wave_tensor = torch.from_numpy(wave).float()
    mel_tensor = to_mel(wave_tensor)
    mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
    return mel_tensor

def compute_style(ref_dicts):
    reference_embeddings = {}
    for key, path in ref_dicts.items():
        wave, sr = librosa.load(path, sr=24000)
        audio, index = librosa.effects.trim(wave, top_db=30)
        if sr != 24000:
            audio = librosa.resample(audio, sr, 24000)
        mel_tensor = preprocess(audio).to(device)

        with torch.no_grad():
            ref = model.style_encoder(mel_tensor.unsqueeze(1))
        reference_embeddings[key] = (ref.squeeze(1), audio)
    
    return reference_embeddings

### Load models

In [None]:
# load phonemizer
import phonemizer
global_phonemizer = phonemizer.backend.EspeakBackend(language='ar', preserve_punctuation=True,  with_stress=True)

In [8]:
def load_F0_models(path):
    F0_model = JDCNet(num_class=1, seq_len=192)
    all_params = torch.load(path, map_location='cpu', weights_only=False)
    if path.endswith('.t7'):
        params = all_params['net']
    else:
        params = all_params['model']
    F0_model.load_state_dict(params)
    _ = F0_model.train()
    
    return F0_model

In [9]:
config = yaml.safe_load(open("../models/ClArTTS/config_ClArTTS.yml"))

# load pretrained ASR model
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)

# load pretrained F0 model
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)

# load BERT model
from Utils.PLBERT.util import load_plbert
BERT_path = config.get('PLBERT_dir', False)
plbert = load_plbert(BERT_path)



In [10]:
model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)
_ = [model[key].eval() for key in model]
_ = [model[key].to(device) for key in model]

  WeightNorm.apply(module, name, dim)




In [11]:
params_whole = torch.load("../models/ClArTTS_ASR/epoch_2nd_00095.pth", map_location='cpu')
params = params_whole['net']

In [12]:
params.keys()

dict_keys(['bert', 'bert_encoder', 'predictor', 'decoder', 'text_encoder', 'predictor_encoder', 'style_encoder', 'diffusion', 'text_aligner', 'pitch_extractor', 'mpd', 'msd', 'wd'])

In [21]:
params['text_aligner'].keys()

odict_keys(['module.to_mfcc.dct_mat', 'module.init_cnn.conv.weight', 'module.init_cnn.conv.bias', 'module.cnns.0.0.blocks.0.0.conv.weight', 'module.cnns.0.0.blocks.0.0.conv.bias', 'module.cnns.0.0.blocks.0.2.weight', 'module.cnns.0.0.blocks.0.2.bias', 'module.cnns.0.0.blocks.0.4.conv.weight', 'module.cnns.0.0.blocks.0.4.conv.bias', 'module.cnns.0.0.blocks.1.0.conv.weight', 'module.cnns.0.0.blocks.1.0.conv.bias', 'module.cnns.0.0.blocks.1.2.weight', 'module.cnns.0.0.blocks.1.2.bias', 'module.cnns.0.0.blocks.1.4.conv.weight', 'module.cnns.0.0.blocks.1.4.conv.bias', 'module.cnns.0.0.blocks.2.0.conv.weight', 'module.cnns.0.0.blocks.2.0.conv.bias', 'module.cnns.0.0.blocks.2.2.weight', 'module.cnns.0.0.blocks.2.2.bias', 'module.cnns.0.0.blocks.2.4.conv.weight', 'module.cnns.0.0.blocks.2.4.conv.bias', 'module.cnns.0.1.weight', 'module.cnns.0.1.bias', 'module.cnns.1.0.blocks.0.0.conv.weight', 'module.cnns.1.0.blocks.0.0.conv.bias', 'module.cnns.1.0.blocks.0.2.weight', 'module.cnns.1.0.blocks.0

In [13]:
for key in model:
    if key in params:
        print('%s loaded' % key)
        try:
            model[key].load_state_dict(params[key])
        except:
            from collections import OrderedDict
            state_dict = params[key]
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:] # remove `module.`
                new_state_dict[name] = v
            # load params
            model[key].load_state_dict(new_state_dict, strict=False)
_ = [model[key].eval() for key in model]

In [None]:
_ = [model[key].eval() for key in model]

In [15]:
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
sampler = DiffusionSampler(
    model.diffusion.diffusion,
    sampler=ADPM2Sampler(),
    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
    clamp=False
)

### Synthesize speech

In [22]:
import subprocess

def phonemize_espeak(text, lang="ar"):
    cmd = ["/home/mabrouk/espeak-ng/src/espeak-ng", "-v", lang, "--ipa", text]
    return subprocess.check_output(cmd, text=True)

In [None]:
def inference(text, noise, diffusion_steps=5, embedding_scale=1):
    text = text.strip()
    text = text.replace('"', '')
    ps = global_phonemizer.phonemize([text])
    ps = word_tokenize(ps[0])
    ps = ' '.join(ps)

    tokens = text_cleaner(ps)
    # tokens = remap_arabic_to_english_indices(tokens)
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
    
    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)
        text_mask = length_to_mask(input_lengths).to(tokens.device)

        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2) 

        s_pred = sampler(noise, 
              embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,
              embedding_scale=embedding_scale).squeeze(0)

        s = s_pred[:, 128:]
        ref = s_pred[:, :128]

        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)

        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        pred_dur[-1] += 5

        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)

        # encode prosody
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
        out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)), 
                                F0_pred, N_pred, ref.squeeze().unsqueeze(0))
        
    return out.squeeze().cpu().numpy()

In [None]:
# Reset everything
torch.manual_seed(99)
np.random.seed(99)
random.seed(99)

# Test 1: Load normal weights
params_whole = torch.load("../models/ClArTTS_ASR/epoch_2nd_00095.pth", map_location='cpu')
params = params_whole['net']

for key in model:
    if key in params:
        model[key].load_state_dict(params[key], strict=False)

torch.manual_seed(99)
noise = torch.randn(1,1,256).to(device)
audio_normal = inference(text, noise, diffusion_steps=3, embedding_scale=1)

torch.manual_seed(99)
params_whole = torch.load("../models/ClArTTS_ASR/epoch_2nd_00095.pth", map_location='cpu')
params = params_whole['net']

tmp = 0.9
for name, tensor in params['text_aligner'].items():
    if any(k in name for k in ['running_mean', 'running_var', 'num_batches_tracked']):
        continue
    rand = torch.empty_like(tensor)
    if tensor.ndim >= 2:
        nn.init.xavier_uniform_(rand)
    else:
        rand.normal_(mean=0.0, std=1e-3)
    tensor.mul_(1 - tmp).add_(tmp * rand)

for key in model:
    if key in params:
        model[key].load_state_dict(params[key], strict=False)

torch.manual_seed(99)
noise = torch.randn(1,1,256).to(device)
audio_random = inference(text, noise, diffusion_steps=3, embedding_scale=1)

print("Are they Differet:", np.allclose(audio_normal, audio_random, atol=1e-5))
print(f"Normal audio length: {len(audio_normal)}")
print(f"Random audio length: {len(audio_random)}")

[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] x is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError]

In [None]:
model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)
_ = [model[key].to(device) for key in model]

params_whole = torch.load("../models/ClArTTS_ASR/epoch_2nd_00095.pth", map_location='cpu')
params = params_whole['net']

tmp = 0.9
for name, tensor in params['text_aligner'].items():
    if any(k in name for k in ['running_mean', 'running_var', 'num_batches_tracked']):
        continue
    rand = torch.empty_like(tensor)
    if tensor.ndim >= 2:
        nn.init.xavier_uniform_(rand)
    else:
        rand.normal_(mean=0.0, std=1e-3)
    tensor.mul_(1 - tmp).add_(tmp * rand)

for key in model:
    if key in params:
        model[key].load_state_dict(params[key], strict=False)

noise = torch.randn(1,1,256).to(device)
audio_noisy = inference(text, noise, diffusion_steps=3, embedding_scale=1)

model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)
_ = [model[key].to(device) for key in model]

params_whole = torch.load("../models/ClArTTS_ASR/epoch_2nd_00095.pth", map_location='cpu')
params = params_whole['net']

for key in model:
    if key in params:
        model[key].load_state_dict(params[key], strict=False)

noise = torch.randn(1,1,256).to(device)
audio_normal = inference(text, noise, diffusion_steps=3, embedding_scale=1)

print(f"Normal length: {len(audio_normal)}, Noisy length: {len(audio_noisy)}")
print(f"Are they Different: {not np.allclose(audio_normal, audio_noisy, atol=0.1)}")

  WeightNorm.apply(module, name, dim)




[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] x is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError]

ValueError: operands could not be broadcast together with shapes (2930400,) (2889600,) 

In [17]:
import nltk
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to
[nltk_data]     /home/mabrouk/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

#### Basic synthesis (5 diffusion steps)

In [18]:
text = '''
مرحبًا بكم في اختبار نموذج تحويل النص إلى كلام! 
اليوم هو 3 ديسمبر 2025، والساعة الآن 1:30 ظهرًا. 
شكرًا لكم، وإلى اللقاء.
'''

In [131]:
from camel_tools.utils.charmap import CharMapper
from camel_tools.utils.transliterate import Transliterator

ar2bw = CharMapper.builtin_mapper('ar2safebw')
ar2bw_translit = Transliterator(ar2bw)

sentence = '''
مرحبًا بكم في اختبار نموذج تحويل النص إلى كلام! 
اليوم هو 3 ديسمبر 2025، والساعة الآن 1:30 ظهرًا. 
شكرًا لكم، وإلى اللقاء.
'''

sentence_bw = ar2bw_translit.transliterate(sentence)
print("\n" + "=" * 60)
print("SAMPLE ARABIC SENTENCE IN BUCKWALTER TRANSCRIPTION:")
print("=" * 60)
print(f"Original: {sentence}")
print(f"Buckwalter: {sentence_bw}")


SAMPLE ARABIC SENTENCE IN BUCKWALTER TRANSCRIPTION:
Original: 
مرحبًا بكم في اختبار نموذج تحويل النص إلى كلام! 
اليوم هو 3 ديسمبر 2025، والساعة الآن 1:30 ظهرًا. 
شكرًا لكم، وإلى اللقاء.

Buckwalter: 
mrHbFA bkm fy AxtbAr nmwVj tHwyl AlnS IlY klAm! 
Alywm hw 3 dysmbr 2025، wAlsAEp AlMn 1:30 ZhrFA. 
ckrFA lkm، wIlY AllqAC.



In [19]:
start = time.time()
noise = torch.randn(1,1,256).to(device)
wav = inference(text, noise, diffusion_steps=3, embedding_scale=1)
rtf = (time.time() - start) / (len(wav) / 24000)
print(f"RTF = {rtf:5f}")

[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError]

  return torch.abs(forward_transform), torch.angle(forward_transform)


RTF = 0.137226


In [20]:
import IPython.display as ipd
display(ipd.Audio(wav, rate=24000))

In [21]:
# save
torchaudio.save("Output/ClArTTS_random_pitch_extractor_weights.wav", torch.from_numpy(wav).unsqueeze(0), 24000)

In [21]:
sentences = [
    "ولد محمد بن سلمان في 31 أغسطس 1985 في مدينة جدة بالمملكة العربية السعودية.",
    "يبلغ عدد سكان دول مجلس التعاون الخليجي حوالي 57845000 نسمة حسب إحصائيات 2023.",
    "اجتمع الشيخ محمد بن زايد مع الرئيس الفرنسي إيمانويل ماكرون يوم الثلاثاء 18 سبتمبر في أبوظبي.",
    "اخترع ألبرت أينشتاين نظرية النسبية عام 1905، وحصل على جائزة نوبل عام 1921.",
    "درجة الحرارة اليوم الأحد 7 ديسمبر ستصل إلى 35.5 درجة مئوية مع رطوبة 78%.",
    "فاز منتخب الأرجنتين بكأس العالم 2022 بعد فوزه على فرنسا بركلات الترجيح 4-2.",
    "سعر برميل النفط اليوم هو 89.25 دولاراً أمريكياً، بينما الذهب يسجل 2034.50 دولاراً للأونصة.",
    "قال الشاعر أحمد شوقي: 'قُم للمعلم وفِّه التبجيلا، كاد المعلم أن يكون رسولا'.",
    "موعد صلاة الفجر يوم الخميس 25 رمضان هو الساعة 4:32 صباحاً والمغرب 6:18 مساءً.",
    "شركة أرامكو السعودية سجلت أرباحاً قدرها 161.1 مليار دولار في 2022."
    "الاستثمارات الأجنبية المباشرة والتحول الرقمي والطاقة المتجددة تمثل أولويات استراتيجية لرؤية 2030.",
    "صرح معالي الوزير د. عبدالله العثيم بتاريخ 15/04/2024 أن الناتج المحلي الإجمالي نما بنسبة 8.7%.",
    "للتواصل اتصل على +971501234567 أو أرسل بريداً إلى contact@gcc-company.ae قبل 28/02/2025.",
    "أسعار العملات: الريال السعودي 3.75 مقابل الدولار، الدرهم الإماراتي 3.67، الدينار الكويتي 0.31.",
    "قال صاحب السمو الشيخ محمد بن راشد آل مكتوم: 'المستقبل ملك لمن يستطيع تخيله وتصميمه وتنفيذه'.",
    "درجات الحرارة المتوقعة: الرياض 42°، دبي 39°، الدوحة 41°، المنامة 38°، مسقط 37°."
]

In [22]:
from tqdm.notebook import tqdm

In [23]:
outpath = "Output/ClArTTS_base/"
os.makedirs(outpath, exist_ok=True)

In [24]:
for i, text in tqdm(enumerate(sentences)):
    start = time.time()
    noise = torch.randn(1,1,256).to(device)
    wav = inference(text, noise, diffusion_steps=3, embedding_scale=1)
    rtf = (time.time() - start) / (len(wav) / 24000)
    print(f"Sentence {i+1}/{len(sentences)} RTF = {rtf:5f}")

    import soundfile as sf
    sf.write(os.path.join(outpath, f"{i}.wav"), wav, 24000)

0it [00:00, ?it/s]

[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError]

#### With higher diffusion steps (more diverse)
Since the sampler is ancestral, the higher the stpes, the more diverse the samples are, with the cost of slower synthesis speed.

In [28]:
start = time.time()
noise = torch.randn(1,1,256).to(device)
wav = inference(text, noise, diffusion_steps=10, embedding_scale=2)
rtf = (time.time() - start) / (len(wav) / 24000)
print(f"RTF = {rtf:5f}")
import IPython.display as ipd
display(ipd.Audio(wav, rate=24000))

[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError]

RTF = 0.037657


### Speech expressiveness
The following section recreates the samples shown in [Section 6](https://styletts2.github.io/#emo) of the demo page.

#### With embedding_scale=1
This is the classifier-free guidance scale. The higher the scale, the more conditional the style is to the input text and hence more emotional. 

In [33]:
texts = {}
# convert each sentence to arabic
# texts['Happy'] = "We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands."
texts['Happy'] = "نحن سعداء بدعوتكم للانضمام إلينا في رحلة إلى الماضي، حيث سنزور أعظم المعالم التي بناها البشر على الإطلاق."
# texts['Sad'] = "I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence."
texts['Sad'] = "أنا آسف أن أقول إننا عانينا من نكسة شديدة في جهودنا لاستعادة الازدهار والثقة."
# texts['Angry'] = "The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!"
texts['Angry'] = "مجال علم الفلك هو مزحة! نظرياته تستند إلى ملاحظات معيبة وتفسيرات متحيزة!"
# texts['Surprised'] = "I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?"
texts['Surprised'] = "لا أستطيع أن أصدق ذلك! هل تعني أن تخبرني أنك اكتشفت نوعًا جديدًا من البكتيريا في هذه البركة؟"

for k,v in texts.items():
    noise = torch.randn(1,1,256).to(device)
    wav = inference(v, noise, diffusion_steps=10, embedding_scale=1)
    print(k + ": ")
    display(ipd.Audio(wav, rate=24000, normalize=False))

[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
Happy: 


[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
Sad: 


[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
Angry: 


[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
[KeyError] ˈ is not found in the dictionary.
[KeyError] ˌ is not found in the dictionary.
Surprised:

#### With embedding_scale=2

In [None]:
texts = {}
texts['Happy'] = "We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands."
texts['Sad'] = "I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence."
texts['Angry'] = "The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!"
texts['Surprised'] = "I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?"

for k,v in texts.items():
    noise = torch.randn(1,1,256).to(device)
    wav = inference(v, noise, diffusion_steps=10, embedding_scale=2) # embedding_scale=2 for more pronounced emotion
    print(k + ": ")
    display(ipd.Audio(wav, rate=24000, normalize=False))

### Long-form generation
This section includes basic implementation of Algorithm 1 in the paper for consistent longform audio generation. The example passage is taken from [Section 5](https://styletts2.github.io/#long) of the demo page. 

In [None]:
passage = '''If the supply of fruit is greater than the family needs, it may be made a source of income by sending the fresh fruit to the market if there is one near enough, or by preserving, canning, and making jelly for sale. To make such an enterprise a success the fruit and work must be first class. There is magic in the word "Homemade," when the product appeals to the eye and the palate; but many careless and incompetent people have found to their sorrow that this word has not magic enough to float inferior goods on the market. As a rule large canning and preserving establishments are clean and have the best appliances, and they employ chemists and skilled labor. The home product must be very good to compete with the attractive goods that are sent out from such establishments. Yet for first-class homemade products there is a market in all large cities. All first-class grocers have customers who purchase such goods.'''

In [None]:
def LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=5, embedding_scale=1):
    text = text.strip()
    text = text.replace('"', '')
    ps = global_phonemizer.phonemize([text])
    ps = word_tokenize(ps[0])
    ps = ' '.join(ps)

    tokens = text_cleaner(ps)
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
    
    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)
        text_mask = length_to_mask(input_lengths).to(tokens.device)

        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2) 

        s_pred = sampler(noise, 
              embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,
              embedding_scale=embedding_scale).squeeze(0)
        
        if s_prev is not None:
            # convex combination of previous and current style
            s_pred = alpha * s_prev + (1 - alpha) * s_pred
        
        s = s_pred[:, 128:]
        ref = s_pred[:, :128]

        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)

        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)

        # encode prosody
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
        out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)), 
                                F0_pred, N_pred, ref.squeeze().unsqueeze(0))
        
    return out.squeeze().cpu().numpy(), s_pred

In [None]:
sentences = passage.split('.') # simple split by comma
wavs = []
s_prev = None
for text in sentences:
    if text.strip() == "": continue
    text += '.' # add it back
    noise = torch.randn(1,1,256).to(device)
    wav, s_prev = LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=10, embedding_scale=1.5)
    wavs.append(wav)
display(ipd.Audio(np.concatenate(wavs), rate=24000, normalize=False))

## TMP