In [1]:
import argparse
import codecs
import os
import re
from importlib.resources import files
from pathlib import Path
import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path

from model import DiT, CFM
import torch
from f5_tts.model.utils import (
    get_tokenizer,
    convert_char_to_pinyin,
)
import torchaudio
from f5_tts.model.modules import MelSpec

from infer.utils_infer import (
    infer_process,
    load_vocoder,
    preprocess_ref_audio_text,
    remove_silence_for_generated_wav,
)

In [2]:
# Compute device; on CUDA, let float32 matmuls use faster reduced-precision (e.g. TF32) kernels.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if torch.cuda.is_available():
    torch.set_float32_matmul_precision("high")

# Sampling guidance
cfg_strength = 3.0
scale_phi = 0.75

# Audio / mel front-end
target_sample_rate = 24000
mel_spec_type  = "vocos"
vocoder_name   = "vocos"
n_mel_channels = 100
win_length     = 1024
hop_length     = win_length // 4  # 256 samples -> 24000 / 256 = 93.75 mel frames per second
n_fft          = 1024
target_rms     = 0.1

# ODE sampling (ode_method was previously assigned twice with the same value; defined once here)
ode_method     = "euler"
nfe_step       = 32  # 16, 32
sway_sampling_coef = -1.0
speed          = 1.0
fix_duration   = None
cross_fade_duration = 0.15

# Tokenizer
vocab_file     = "./infer/examples/vocab.txt"
tokenizer      = "custom"

In [3]:
# Build the DiT backbone and load the vocoder.
# Transformer hyper-parameters, forwarded unchanged to DiT.
model_cfg = {
    "dim": 1024,
    "depth": 22,
    "heads": 16,
    "ff_mult": 2,
    "text_dim": 512,
    "conv_layers": 4,
}

# Character-to-index map (later used by list_str_to_idx) and the vocabulary size,
# which sizes DiT's text embedding table.
vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
# Vocoder weights come from the local ./vocoder directory.
vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=True, local_path="./vocoder")
transformer = DiT(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels)

Load vocos from local path ./vocoder


In [4]:
# Mel front-end settings, shared by the standalone MelSpec and the CFM wrapper.
mel_spec_kwargs = {
    "n_fft": n_fft,
    "hop_length": hop_length,
    "win_length": win_length,
    "n_mel_channels": n_mel_channels,
    "target_sample_rate": target_sample_rate,
    "mel_spec_type": mel_spec_type,
}

# ODE solver options handed to CFM.
odeint_kwargs = {"method": ode_method}

mel_spec = MelSpec(**mel_spec_kwargs)

# Wrap the transformer in CFM, move it to the compute device and switch to evaluation mode.
model = CFM(
    transformer=transformer,
    mel_spec_kwargs=mel_spec_kwargs,
    odeint_kwargs=odeint_kwargs,
    vocab_char_map=vocab_char_map,
).to(device)
model.eval()
print("-")

-


In [5]:
dtype = torch.float32
ckpt_path = '/workspace/f5tts_clone_qwen_filter_7.pt'
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)

# Pick the weights to load. Training checkpoints keep the EMA copy under
# "ema_model_state_dict" with each parameter key prefixed by "ema_model."; strip that
# prefix so the keys line up with `model`. Entries without the prefix are skipped.
if "ema_model_state_dict" in checkpoint:
    new_state_dict = {
        k.removeprefix("ema_model."): v
        for k, v in checkpoint["ema_model_state_dict"].items()
        if k.startswith("ema_model.")
    }
elif "model_state_dict" in checkpoint:
    # NOTE(review): assumes non-EMA checkpoints store weights under this key — confirm with the trainer.
    new_state_dict = checkpoint["model_state_dict"]
else:
    # Fall back to treating the file as a bare state dict.
    new_state_dict = checkpoint

# Load the extracted weights, not the raw checkpoint: its top-level keys
# (e.g. "ema_model_state_dict") match no module parameter, so with strict=False the
# model would silently stay randomly initialised.
load_result = model.load_state_dict(new_state_dict, strict=False)
if len(load_result.missing_keys) == len(model.state_dict()):
    raise RuntimeError(
        f"No weights from {ckpt_path} matched the model; check the checkpoint key layout."
    )
if load_result.missing_keys:
    print(f"Missing keys ({len(load_result.missing_keys)}): {load_result.missing_keys[:10]}")
if load_result.unexpected_keys:
    print(f"Unexpected keys ({len(load_result.unexpected_keys)}): {load_result.unexpected_keys[:10]}")

del checkpoint
torch.cuda.empty_cache()

In [6]:
import time
import librosa
import torchaudio

reference_audio_path = "/workspace/tts/src/zombie-or-monster-says-i-sound-effect-079567563_nw_prev.mp3"

# Build the resampler from the file's real sample rate rather than assuming 44.1 kHz.
_, source_sr = torchaudio.load(reference_audio_path)
resampler = torchaudio.transforms.Resample(
    orig_freq=source_sr,
    new_freq=target_sample_rate,
).to(device)

# Benchmark: total wall time of 100 load + resample round trips.
tt = 0
for _ in range(100):
    st = time.perf_counter()
    audio, sr = torchaudio.load(reference_audio_path)
    audio = resampler(audio.to(device))
    if device.type == "cuda":
        # CUDA kernels run asynchronously; wait so the timer covers the actual work.
        torch.cuda.synchronize()
    tt += time.perf_counter() - st
print(tt)

2.1242799758911133


In [7]:
# print(reference_audio.ndim)

# if reference_audio.ndim == 2:
#     reference_audio_mel = model.mel_spec(reference_audio)
#     reference_audio_mel = reference_audio_mel.permute(0, 2, 1)
#     assert reference_audio_mel.shape[-1] == model.num_channels
# reference_audio_mel = reference_audio_mel.to(next(model.parameters()).dtype)
# print(reference_audio_mel.shape)

# ref_text_len = len(reference_script.encode("utf-8"))
# gen_text_len = len("I love her not bad.".encode("utf-8"))
# duration = reference_audio_len + int(reference_audio_len / ref_text_len * gen_text_len)
# print(duration)

In [8]:
from f5_tts.model.utils import (
    list_str_to_idx,
    list_str_to_tensor,
)

# Preprocess the text batch the same way it is prepared before model.sample() elsewhere in
# this notebook, then show the token ids the model actually receives.
text = ["I thought I saw you in my dream."]
final_text_list = convert_char_to_pinyin(text)

list_str_to_idx(final_text_list, model.vocab_char_map)

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.798 seconds.
Prefix dict has been built successfully.


tensor([[  40,    0, 1082,  441,  827, 1147,  369,  441, 1082,    0,   40,    0,
          973,   62, 1149,    0, 1236,  827, 1147,    0,  507,  765,    0,  704,
         1236,    0,  250,  940,  325,   62,  704,   14]])

In [9]:
import time
import librosa
import torchaudio

model.eval()

st = time.time()

reference_audio_path = "/workspace/tts/src/zombie-or-monster-says-i-sound-effect-079567563_nw_prev.mp3"
# NOTE(review): reference_script is not passed to model.sample — confirm whether the
# reference transcript should be prepended to the text, as in the batched benchmark cell.
reference_script = "I hate you so much."
audio, sr = torchaudio.load(reference_audio_path)
# Resample from the file's actual rate instead of assuming 44.1 kHz.
resampler = torchaudio.transforms.Resample(
    orig_freq=sr,
    new_freq=target_sample_rate,
).to(device)
# Keep only the first channel, shaped (1, samples).
reference_audio = resampler(audio.to(device))[0].unsqueeze(dim=0)
print(reference_audio.shape)

reference_audio_len = reference_audio.shape[-1] / target_sample_rate
print("reference_audio_len : ", reference_audio_len, "s")

script = "I thought I saw you in my dream."
target_second = 3.0
# Mel frames per second is sample_rate / hop_length (24000 / 256 = 93.75), not 100.
mel_frames_per_second = target_sample_rate / hop_length
target_duration = int(target_second * mel_frames_per_second)
# Pass a one-element batch of pinyin-converted text, matching the batched sample() call.
# NOTE(review): a bare str here raised `maximum(): argument 'input' must be Tensor, not int`
# inside sample(); a list batch is the suspected fix — confirm against CFM.sample.
script_text_list = convert_char_to_pinyin([script])

with torch.no_grad():
    generated, sec = model.sample(
        reference_audio=reference_audio,
        text=script_text_list,
        duration=target_duration,
        steps=16,
        cfg_strength=0.0,
        sway_sampling_coef=sway_sampling_coef,

        duplicate_test=False,
        t_inter=None,
        no_ref_audio=False,
        batch_size=1
    )
if device.type == "cuda":
    # Wait for queued CUDA work so the elapsed time is meaningful.
    torch.cuda.synchronize()
print(time.time() - st)

torch.Size([1, 74676])
reference_audio_len :  3.1115 s
1 tensor([292], device='cuda:0')


TypeError: maximum(): argument 'input' (position 1) must be Tensor, not int

In [None]:
from audiotools import AudioSignal

# Decode the generated mel spectrogram to a waveform and play it back.
# Decoding needs no autograd graph. If the vocoder's parameters require grad (load_vocoder is
# not shown here), decoding outside no_grad yields a tensor that .numpy() refuses to convert.
with torch.no_grad():
    generated = generated.to(torch.float32)
    print(generated.shape)
    # Swap the last two axes before decoding; presumably (batch, frames, mels) -> (batch, mels, frames) — confirm.
    generated_mel_spectrogram = generated.permute(0, 2, 1)
    generated_wave_form = vocoder.decode(generated_mel_spectrogram)
print(generated_wave_form.shape)

AudioSignal(generated_wave_form.cpu().numpy()[0], sample_rate=target_sample_rate).widget()

In [None]:
# Parameter counts: every parameter, and the subset that is trainable (requires_grad=True).
all_params = list(model.parameters())
total_params = sum(param.numel() for param in all_params)
trainable_params = sum(param.numel() for param in all_params if param.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

In [None]:
import librosa
import time
from tqdm import tqdm

ref_audio = "/workspace/tts/valid_data/monster-saying-i-love-you-sound-effect-234404303_nw_prev.mp3"
ref_text = "I love you."

audio, sr = librosa.load(ref_audio, sr=target_sample_rate, mono=True)
audio = torch.tensor(audio).unsqueeze(dim=0)
print(audio.shape)  # (1, samples); ~2.18 s of audio ≈ 204 mel frames at 24000 / 256 frames per second

# Batched prompt: each entry is the reference transcript followed by the text to generate.
text = [ref_text+"I thought I saw you in my dream. model will be optimized by back propagation with tons of data.", ref_text+"Hey how are you?"]
duration = 920
final_text_list = convert_char_to_pinyin(text)

total = 0
# Inference only: without no_grad every sampling step records an autograd graph (extra memory),
# and the decoded waveform keeps grad history, so a later .numpy() call would fail.
with torch.no_grad():
    for _ in range(10):
        start = time.time()
        generated, sec = model.sample(
            cond           = audio.to(device),
            text           = final_text_list,
            duration       = duration,
            steps          = 32,
            cfg_strength   = cfg_strength,
            sway_sampling_coef=sway_sampling_coef,

            duplicate_test = False,
            t_inter        = None,
            no_ref_audio   = False,
            batch_size     = 1
        )

        generated = generated.to(torch.float32)
        generated_mel_spectrogram = generated.permute(0, 2, 1)
        generated_wave_form = vocoder.decode(generated_mel_spectrogram)
        if device.type == "cuda":
            # CUDA work is asynchronous; synchronize so each iteration's time is real.
            torch.cuda.synchronize()
        total += time.time() - start
print('10번 반복 ', total)

In [None]:
from audiotools import AudioSignal

# Play back the reference audio and both generated waveforms.
# .detach() guards against waveforms that still carry autograd history, which .numpy() rejects.
AudioSignal(audio, sample_rate=target_sample_rate).widget()
AudioSignal(generated_wave_form[0].detach().cpu().numpy(), sample_rate=target_sample_rate).widget()
AudioSignal(generated_wave_form[1].detach().cpu().numpy(), sample_rate=target_sample_rate).widget()

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

# Visualise how sway sampling warps a uniform time grid on [0, 1]:
#   t' = t + s * (cos(pi/2 * t) - 1 + t)
# A local coefficient is used so this illustration does not overwrite the
# `sway_sampling_coef` setting from the configuration cell.
sway_coef = -1
t = np.linspace(0, 1, 32)
t_transformed = t + sway_coef * (np.cos(np.pi / 2 * t) - 1 + t)

fig, ax = plt.subplots()
ax.plot(t_transformed)
ax.set(xlabel="Step Index", ylabel="Transformed t Value", title="Sway Sampling Transformed t")
plt.show()