In [None]:
import argparse
import codecs
import os
import re
from importlib.resources import files
from pathlib import Path

import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path

from infer.utils_infer import (
    infer_process,
    load_vocoder,
    preprocess_ref_audio_text,
    remove_silence_for_generated_wav,
)
from model import DiT, UNetT
import torch

# Pick the torch backend: CUDA when available, otherwise Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# -----------------------------------------

# Audio / mel-spectrogram settings (forwarded to the model through mel_spec_kwargs).
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos"

# Sampling / inference knobs. In the cells of this notebook only `ode_method`
# (via odeint_kwargs) and `speed` (via infer_process) are passed on explicitly.
target_rms = 0.1
cross_fade_duration = 0.15
ode_method = "euler"
nfe_step = 32  # 16, 32
cfg_strength = 2.0
sway_sampling_coef = -1.0
speed = 1.0
fix_duration = None

# -----------------------------------------

In [None]:
# Backbone class and the keyword arguments it is constructed with in a later cell.
model_cls = DiT
model_cfg = {
    "dim": 1024,
    "depth": 22,
    "heads": 16,
    "ff_mult": 2,
    "text_dim": 512,
    "conv_layers": 4,
}

In [None]:
from model.utils import convert_char_to_pinyin, get_tokenizer

# Tokenizer selection: the "custom" mode is given an explicit vocabulary file.
tokenizer = "custom"
vocab_file = "./infer/examples/vocab.txt"

# get_tokenizer returns the char->index map and the vocabulary size;
# the size is printed as a quick sanity check.
vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
print(vocab_size)

In [None]:
# Instantiate the backbone from the shared config; the text embedding table is
# sized by the vocabulary and the mel dimension by the number of mel channels.
transformer = model_cls(
    **model_cfg,
    text_num_embeds=vocab_size,
    mel_dim=n_mel_channels,
)

In [None]:
# ODE solver name handed to the CFM wrapper through odeint_kwargs.
ode_method = "euler"

# Mel-spectrogram configuration, assembled from the notebook-level constants.
mel_spec_kwargs = {
    "n_fft": n_fft,
    "hop_length": hop_length,
    "win_length": win_length,
    "n_mel_channels": n_mel_channels,
    "target_sample_rate": target_sample_rate,
    "mel_spec_type": mel_spec_type,
}

odeint_kwargs = {"method": ode_method}

# Vocoder used to turn generated mel-spectrograms back into audio.
vocoder_name = "vocos"

In [None]:
from model import CFM

# Wrap the backbone in the CFM model, then move the whole module to the chosen device.
model = CFM(
    transformer=transformer,
    mel_spec_kwargs=mel_spec_kwargs,
    odeint_kwargs=odeint_kwargs,
    vocab_char_map=vocab_char_map,
)
model = model.to(device)

In [None]:
# float16 is selected on CUDA devices with compute capability major >= 6, float32 otherwise.
# NOTE(review): `dtype` is only printed in this cell; the model is never cast to it
# here - confirm whether a `model.to(dtype)` was intended.
dtype = (
    torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
)
ckpt_path = "/workspace/tts/ckpts/model_1200000.pt"
print(dtype)
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
use_ema = True

# Select which weights to load. Previously the load call lived inside the
# `if use_ema:` branch, so switching the flag off silently loaded nothing.
if use_ema:
    # Strip the EMA wrapper prefix and drop its bookkeeping entries.
    state_dict = {
        k.replace("ema_model.", ""): v
        for k, v in checkpoint["ema_model_state_dict"].items()
        if k not in ["initted", "step"]
    }
else:
    state_dict = checkpoint["model_state_dict"]

# patch for backward compatibility, 305e3ea
for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
    state_dict.pop(key, None)

# strict=False: keys missing from / unexpected in the checkpoint are tolerated.
model.load_state_dict(state_dict, strict=False)

# Release the checkpoint tensors and any cached CUDA memory.
del checkpoint, state_dict
torch.cuda.empty_cache()

In [None]:
# Keep the mel-spectrogram type aligned with the selected vocoder.
mel_spec_type = vocoder_name

# NOTE(review): `vocoder_local_path` is assigned here but is not what gets passed to
# load_vocoder below, which receives a fixed local_path - confirm which one is intended.
_vocoder_dirs = {
    "vocos": "../checkpoints/vocos-mel-24khz",
    "bigvgan": "../checkpoints/bigvgan_v2_24khz_100band_256x",
}
if vocoder_name in _vocoder_dirs:
    vocoder_local_path = _vocoder_dirs[vocoder_name]

vocoder = load_vocoder(
    vocoder_name=mel_spec_type,
    is_local=True,
    local_path="/workspace/tts/src/f5_tts/vocoder",
)

In [None]:
import torchaudio

# Define the reference clip in this cell so it runs on a fresh kernel: `ref_audio`
# was otherwise first assigned only in the following cell, making this cell fail
# with NameError under Restart & Run All.
ref_audio = "/workspace/dit_audio/valid_data/monster-saying-i-love-you-sound-effect-234404303_nw_prev.mp3"

# Load the clip and print the tensor shape as a quick sanity check.
adu, sr = torchaudio.load(ref_audio)
print(adu.shape)

In [None]:
from audiotools import AudioSignal

# Reference clip and its transcript.
ref_audio = "/workspace/dit_audio/valid_data/monster-saying-i-love-you-sound-effect-234404303_nw_prev.mp3"
ref_text = "I love you."

# Voice table with a single "main" entry.
main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
voices = {"main": main_voice}

# Replace each entry's audio/text with the values returned by
# preprocess_ref_audio_text, then print what will be used.
for voice, entry in voices.items():
    processed_audio, processed_text = preprocess_ref_audio_text(entry["ref_audio"], entry["ref_text"])
    entry["ref_audio"] = processed_audio
    entry["ref_text"] = processed_text
    print("Voice:", voice)
    print("Ref_audio:", entry["ref_audio"])
    print("Ref_text:", entry["ref_text"])

# Show a playback widget for the original (unprocessed) reference file.
AudioSignal(ref_audio).widget()

In [None]:
import re

gen_text = "Go kill them all. Get the fuck out of here!"

# Split the text in front of every "[name]" tag; reg2 matches the tag itself.
# The split result is not consumed by the infer_process call below.
generated_audio_segments = []
reg1 = r"(?=\[\w+\])"
chunks = re.split(reg1, gen_text)
reg2 = r"\[(\w+)\]"

# Select the reference voice by its explicit key instead of relying on the loop
# variable `voice` left over from an earlier cell.
main_ref = voices["main"]

audio, final_sample_rate, spectragram = infer_process(
    main_ref["ref_audio"],
    main_ref["ref_text"],
    gen_text,
    model,
    vocoder,
    mel_spec_type=mel_spec_type,
    speed=speed,
)

In [None]:
import re

gen_text = "Go kill them all. Get the fuck out of here!"

# Tag-splitting scaffolding; the infer_process call below does not consume it.
reg1 = r"(?=\[\w+\])"
reg2 = r"\[(\w+)\]"
chunks = re.split(reg1, gen_text)
generated_audio_segments = []

# Same generation as with a reference voice, but with no reference audio and a
# single-space reference text.
# NOTE(review): presumably infer_process accepts None as the reference audio - confirm.
no_ref_audio = None
blank_ref_text = " "
audio, final_sample_rate, spectragram = infer_process(
    no_ref_audio, blank_ref_text, gen_text, model, vocoder, mel_spec_type=mel_spec_type, speed=speed
)

In [None]:
from audiotools import AudioSignal

# Play the generated waveform at the sample rate returned by infer_process
# instead of a hard-coded 24000, so playback stays correct if the rate changes.
AudioSignal(audio, sample_rate=final_sample_rate).widget()