In [26]:
import os
import re

import torch
import torchaudio
from einops import rearrange
from vocos import Vocos
from datetime import datetime
from IPython.display import Audio, display

from model import CFM, UNetT, DiT, MMDiT
from model.utils import (
    load_checkpoint,
    get_tokenizer,
    convert_char_to_pinyin,
    save_spectrogram,
)


In [27]:
# Pick the compute device in order of preference: CUDA GPU, then Apple MPS, then CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


In [28]:
# Dataset Settings
# These values are shared by the model's mel-spectrogram config, the reference
# audio preprocessing and the duration estimate further down.
target_sample_rate = 24000  # reference audio is resampled to this rate; also used for playback and saving
n_mel_channels = 100  # passed to the model as both `mel_dim` and the mel-spectrogram channel count
hop_length = 256  # samples per mel frame; used to convert audio length into frame counts
target_rms = 0.1  # reference audio quieter than this RMS is scaled up to it before inference

tokenizer = "pinyin"  # the text prompt is run through convert_char_to_pinyin only when this is "pinyin"
dataset_name = "Emilia_ZH_EN"  # passed to get_tokenizer together with `tokenizer`


In [29]:
# Inference Settings
# Everything here except `speed`, `fix_duration` and `use_ema` is forwarded to
# model.sample() or to the CFM constructor.
seed = None  # int | None; passed to model.sample as `seed`

nfe_step = 32  # 16, 32; passed to model.sample as `steps`
cfg_strength = 2.0  # passed to model.sample unchanged
ode_method = 'euler'  # euler | midpoint; becomes `method` in the model's odeint_kwargs
sway_sampling_coef = -1.0  # passed to model.sample unchanged
speed = 1.0  # divides the estimated length of the generated part (larger => shorter output)
# Set fix_duration to a number (e.g. 27) to force the output length instead of
# estimating it linearly from the text. It is multiplied by
# target_sample_rate / hop_length to get a frame count, and that count is used
# as the full sampling duration (the reference frames are sliced off afterwards).
fix_duration = None
use_ema = True  # passed to load_checkpoint as `use_ema`


In [30]:
# Output Settings
# Directory that receives the generated .wav files.
output_dir = "tests"
# exist_ok=True makes this cell idempotent and avoids the check-then-create
# race of a separate os.path.exists() test.
os.makedirs(output_dir, exist_ok=True)


In [31]:
exp_name = "F5TTS_Base"  # F5TTS_Base | E2TTS_Base

# Select the backbone class and its constructor arguments for the experiment.
if exp_name == "F5TTS_Base":
    model_cls = DiT
    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

elif exp_name == "E2TTS_Base":
    model_cls = UNetT
    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)

else:
    # Fail here instead of with a NameError (fresh kernel) or a silently stale
    # model_cls/model_cfg (re-run kernel) in the model-construction cell.
    raise ValueError(
        f"Unknown exp_name {exp_name!r}; expected 'F5TTS_Base' or 'E2TTS_Base'"
    )

# Checkpoint to load, identified by experiment name and training step.
ckpt_step = 1200000
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"


In [32]:
# Vocoder model
# `local` chooses where the Vocos weights come from: the pretrained
# "charactr/vocos-mel-24khz" model, or a checkpoint directory on disk.
local = False
if not local:
    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
else:
    vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
    # Build the architecture from the config file, then load the weights into it
    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
    vocos.load_state_dict(state_dict)
    vocos.eval()


In [33]:
# Tokenizer
# get_tokenizer returns a pair: the vocabulary map handed to CFM as
# `vocab_char_map`, and the vocabulary size used as the backbone's
# `text_num_embeds`.
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

In [34]:
# Model
# Wrap the selected backbone in CFM, using the dataset's mel-spectrogram
# settings and the configured ODE solver method.
model = CFM(
    transformer=model_cls(
        **model_cfg,
        text_num_embeds=vocab_size,
        mel_dim=n_mel_channels,
    ),
    mel_spec_kwargs={
        "target_sample_rate": target_sample_rate,
        "n_mel_channels": n_mel_channels,
        "hop_length": hop_length,
    },
    odeint_kwargs={"method": ode_method},
    vocab_char_map=vocab_char_map,
)
model = model.to(device)

# Load the checkpoint weights; `use_ema` is forwarded to load_checkpoint
model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)


In [45]:
# Reference audio
# Path to the prompt recording and the exact transcript of what is spoken in it.
ref_audio = "tests/ref_audio/ms_emily.wav"
ref_text = "The demon king, who was considered the pinnacle of all evil, lived in an out-of-place castle in the desolate land at the heart of the Ninth Hell. He had managed to unify Hell, which had been divided into seven factions."

# Audio
audio, sr = torchaudio.load(ref_audio)
if audio.numel() == 0:
    raise ValueError(f"Reference audio {ref_audio!r} contains no samples")

# Down-mix multi-channel audio to mono by averaging over the channel axis
if audio.shape[0] > 1:
    audio = torch.mean(audio, dim=0, keepdim=True)

# `rms` is the loudness of the *original* reference; it is kept so the
# generated audio can be scaled back to the same level after synthesis.
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms == 0:
    # A silent clip would make target_rms / rms below non-finite, and that
    # non-finite audio would be fed to the model.
    raise ValueError(f"Reference audio {ref_audio!r} is silent (RMS is 0)")

# Boost quiet references up to the target loudness (loud ones are left as is)
if rms < target_rms:
    audio = audio * target_rms / rms

# Bring the reference to the sample rate configured for the model
if sr != target_sample_rate:
    resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
    audio = resampler(audio)
audio = audio.to(device)

display(Audio(audio.cpu().numpy(), rate=target_sample_rate))


In [50]:
# Text Prompt
# The model is conditioned on the reference transcript followed by the new
# text; the part matching the reference audio is cut off after generation.
gen_text = "Let us hunt those who have fallen to darkness."

# Batch of one prompt
text_list = [ref_text + " " + gen_text]
if tokenizer == "pinyin":
    final_text_list = convert_char_to_pinyin(text_list)
else:
    # text_list is already a batch (a list with one string per utterance).
    # Wrapping it in another list would turn the whole sentence into a single
    # "token" of a single utterance, so pass it through unchanged.
    # NOTE(review): assumes model.sample accepts list[str] — confirm in model code.
    final_text_list = text_list

print(f"text  : {text_list}")
print(f"pinyin: {final_text_list}")


text  : ['The demon king, who was considered the pinnacle of all evil, lived in an out-of-place castle in the desolate land at the heart of the Ninth Hell. He had managed to unify Hell, which had been divided into seven factions. Let us hunt those who have fallen to darkness.']
pinyin: [['T', 'h', 'e', ' ', 'd', 'e', 'm', 'o', 'n', ' ', 'k', 'i', 'n', 'g', ',', ' ', 'w', 'h', 'o', ' ', 'w', 'a', 's', ' ', 'c', 'o', 'n', 's', 'i', 'd', 'e', 'r', 'e', 'd', ' ', 't', 'h', 'e', ' ', 'p', 'i', 'n', 'n', 'a', 'c', 'l', 'e', ' ', 'o', 'f', ' ', 'a', 'l', 'l', ' ', 'e', 'v', 'i', 'l', ',', ' ', 'l', 'i', 'v', 'e', 'd', ' ', 'i', 'n', ' ', 'a', 'n', ' ', 'o', 'u', 't', '-', ' ', 'o', 'f', '-', ' ', 'p', 'l', 'a', 'c', 'e', ' ', 'c', 'a', 's', 't', 'l', 'e', ' ', 'i', 'n', ' ', 't', 'h', 'e', ' ', 'd', 'e', 's', 'o', 'l', 'a', 't', 'e', ' ', 'l', 'a', 'n', 'd', ' ', 'a', 't', ' ', 't', 'h', 'e', ' ', 'h', 'e', 'a', 'r', 't', ' ', 'o', 'f', ' ', 't', 'h', 'e', ' ', 'N', 'i', 'n', 't', 'h', ' ', '

In [51]:
# Duration Calculation
# `duration` is the total number of mel frames to sample: the reference frames
# plus the frames for the newly generated speech.
ref_audio_len = audio.shape[-1] // hop_length
if fix_duration is not None:
    # Convert the fixed duration into mel frames via sample rate / hop length
    duration = int(fix_duration * target_sample_rate / hop_length)
else:  # simple linear scaling by text length
    # Character class: each Chinese pause punctuation mark counts as one extra
    # unit of length. Without the brackets the pattern would only match the
    # seven marks appearing consecutively, i.e. practically never.
    zh_pause_punc = r"[。，、；：？！]"
    ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
    gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
    if ref_text_len == 0:
        raise ValueError("ref_text must not be empty when fix_duration is None")
    # Frames-per-character of the reference, applied to the new text and
    # shortened or lengthened by `speed`
    duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)


In [52]:
# Audio Generation
# Sample mel frames for the reference plus the new text, without tracking gradients
with torch.inference_mode():
    generated, trajectory = model.sample(
        cond=audio,
        text=final_text_list,
        duration=duration,
        steps=nfe_step,
        cfg_strength=cfg_strength,
        sway_sampling_coef=sway_sampling_coef,
        seed=seed,
    )

# Final result
# Drop the leading frames that correspond to the reference audio
generated = generated[:, ref_audio_len:]
# Swap the last two axes for the vocoder; the leading `1` requires a batch of one
generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
generated_wave = vocos.decode(generated_mel_spec.cpu())
# If the reference was boosted to target_rms, scale the output back down to
# the reference's original loudness
if rms < target_rms:
    generated_wave = generated_wave * rms / target_rms

display(Audio(generated_wave.cpu().numpy(), rate=target_sample_rate))


In [53]:
# Save results
# Timestamped file name so repeated runs never overwrite each other
timestamp = f"{datetime.now():%Y-%m-%d_%H-%M-%S}"
file_name = f"test_single_{timestamp}.wav"
torchaudio.save(f"{output_dir}/{file_name}", generated_wave, target_sample_rate)

print(f"Saved as: {file_name}")


Saved as: test_single_2024-10-14_02-26-51.wav
