In [None]:
import argparse
import codecs
import os
import re
from importlib.resources import files
from pathlib import Path

import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path

from f5_tts.infer.utils_infer import (
    infer_process,
    load_vocoder,
    preprocess_ref_audio_text,
    remove_silence_for_generated_wav,
)
from f5_tts.model import DiT, UNetT, CFM
from f5_tts.train.utils import draw_plot
from f5_tts.train.validation import validate
import torch
from f5_tts.model.utils import (
    get_tokenizer,
    convert_char_to_pinyin,
    list_str_to_idx,
    lens_to_mask,
    mask_from_frac_lengths
)
from audiotools import AudioSignal
from transformers import T5EncoderModel, AutoTokenizer
from torch.cuda.amp import autocast
from accelerate import Accelerator, DistributedDataParallelKwargs
from f5_tts.model.cfm import T5Conditioner

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# ----------------------------------------- audio / mel-spectrogram and sampling settings

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos"
target_rms = 0.1
cross_fade_duration = 0.15
ode_method = "euler"
nfe_step = 32  # 16, 32
cfg_strength = 2.0
sway_sampling_coef = -1.0
speed = 1.0
fix_duration = None
# Project root: the vocoder, checkpoints and CSV manifests are all loaded from under this path.
preset = "/workspace/tts_sfx"

# ----------------------------------------- caption (text description) encoder

t5_model_name = "t5-base"
# Pass the variable (not a duplicated literal) so editing t5_model_name actually
# changes which encoder is loaded.
text_conditioner = T5Conditioner(t5_model_name=t5_model_name, max_length=128).to(device)

# NOTE(review): currently unused - every accelerator.* call in this notebook is commented out.
accelerator = Accelerator()

In [None]:
# Backbone: a DiT transformer, wrapped below in the CFM (conditional flow matching) model.
model_cls = DiT
model_cfg = {
    "dim": 1024,
    "depth": 22,
    "heads": 16,
    "ff_mult": 2,
    "text_dim": 512,
    "conv_layers": 4,
}

# Tokenizer / vocoder selection.
vocab_file = "./f5_tts/infer/examples/vocab.txt"
tokenizer = "custom"
vocoder_name = "vocos"
ode_method = "euler"  # same value as in the config cell

vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
print(vocab_size)
vocoder = load_vocoder(
    vocoder_name=vocoder_name,
    is_local=True,
    local_path=f"{preset}/src/f5_tts/vocoder",
)

# The text embedding table is sized from the tokenizer's vocabulary.
transformer = model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels)

mel_spec_kwargs = {
    "n_fft": n_fft,
    "hop_length": hop_length,
    "win_length": win_length,
    "n_mel_channels": n_mel_channels,
    "target_sample_rate": target_sample_rate,
    "mel_spec_type": mel_spec_type,
}
odeint_kwargs = {"method": ode_method}

model = CFM(
    transformer=transformer,
    mel_spec_kwargs=mel_spec_kwargs,
    odeint_kwargs=odeint_kwargs,
    vocab_char_map=vocab_char_map,
    frac_lengths_mask=(0.7, 1.0),
    audio_drop_prob=0.3,
    cond_drop_prob=0.2,
    caption_drop_prob=0.2,
).to(device)

In [None]:
# Load pre-trained weights from the EMA copy stored in the checkpoint.
# NOTE(review): `dtype` is only printed here; the model keeps the dtype it was
# built with (mixed precision in training comes from `autocast`).
dtype = (
    torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
)
ckpt_path = f"{preset}/ckpts/model_1200000.pt"
print(dtype)
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)

# Strip the "ema_model." prefix and drop the EMA wrapper's bookkeeping entries.
checkpoint["model_state_dict"] = {
    k.replace("ema_model.", ""): v
    for k, v in checkpoint["ema_model_state_dict"].items()
    if k not in ["initted", "step"]
}

# patch for backward compatibility, 305e3ea: these buffers are not loaded from the checkpoint.
for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
    checkpoint["model_state_dict"].pop(key, None)

# strict=False tolerates keys that differ between checkpoint and model (presumably the
# new caption-conditioning layers - confirm), but it also hides renames/typos. Report
# what was skipped so an unexpectedly large mismatch is visible.
load_result = model.load_state_dict(checkpoint["model_state_dict"], strict=False)
print(f"missing keys ({len(load_result.missing_keys)}): {load_result.missing_keys[:20]}")
print(f"unexpected keys ({len(load_result.unexpected_keys)}): {load_result.unexpected_keys[:20]}")

del checkpoint
torch.cuda.empty_cache()

In [None]:
# Optimisation / data-loading hyper-parameters.
batch_size = 8
lr = 0.00001
weight_decay = 0.001
betas = (0.9, 0.999)
sample_rate = 24000
train_duration = 30.0

num_workers = 8
num_epochs = 100

# Checkpoints, plots and logs are written under this directory.
output_dir = 'weights_1125'
# makedirs(exist_ok=True) is idempotent on re-run, has no exists()/mkdir() race,
# and also creates missing parent directories.
os.makedirs(output_dir, exist_ok=True)

In [None]:
from f5_tts.custom.dataset import CustomDataset, collate_fn
from torch.utils.data import Dataset, DataLoader, RandomSampler

def _make_split(csv_name, mode):
    """Build a CustomDataset over `{preset}/datas/<csv_name>` with the shared mel settings."""
    return CustomDataset(
        f"{preset}/datas/{csv_name}",
        target_sample_rate=target_sample_rate,
        mode=mode,
        hop_length=hop_length,
        n_mel_channels=n_mel_channels,
        win_length=win_length,
        n_fft=n_fft,
        mel_spec_type="vocos",
        preprocessed_mel=False,
        mel_spec_module=None,
    )


def _make_loader(dataset, **sampling):
    """Wrap `dataset` in a DataLoader using the notebook-wide batching/worker settings.

    `sampling` carries either `sampler=...` or `shuffle=...`.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=2,
        **sampling,
    )


train_dataset = _make_split("train_expresso.csv", "train")
valid_dataset = _make_split("valid_expresso.csv", "validation")

# An "epoch" is a fixed number of optimizer steps whose samples are drawn with
# replacement, independent of the dataset size.
steps_per_epoch = 3001
num_samples_per_epoch = steps_per_epoch * batch_size  # samples per epoch = steps * batch size (24,008 here)
train_sampler = RandomSampler(train_dataset, replacement=True, num_samples=num_samples_per_epoch)

train_loader = _make_loader(train_dataset, sampler=train_sampler)
valid_loader = _make_loader(valid_dataset, shuffle=True)  # shuffled validation order, no sampler

libri_valid_dataset = _make_split("libri_test_clean.csv", "validation")
libri_valid_loader = _make_loader(libri_valid_dataset, shuffle=True)

In [None]:
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts

# Metric history consumed by draw_plot(); one list per plotted series.
trainer = {
    series: []
    for series in ('train_losses', 'valid_losses', 'libri_valid_losses', 'wer', 'cer', 'lrs')
}

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, betas=betas)

# Cosine schedule with warm-up and restarts between max_lr and min_lr (= lr / 10).
# NOTE(review): the original note said the lr drops to 1/10 every 20 epochs, but the
# training loop calls scheduler.step() every 1000 iterations AND at each epoch end,
# so one scheduler step is not one epoch - confirm the intended schedule.
scheduler = CosineAnnealingWarmupRestarts(
    optimizer,
    first_cycle_steps=21,
    cycle_mult=1.0,
    max_lr=lr,
    min_lr=lr / 10,
    warmup_steps=20,
    gamma=0.8,  # per the original note: each completed cycle scales max_lr to 80%
)

noise_scheduler = None  # not configured yet
max_grad_norm = 1.0

# Accelerate integration (accelerator.prepare / unwrap_model) is currently disabled;
# the model and loaders are used directly on `device`.

In [None]:
from tqdm import tqdm
import random

print("start training")

# Single source of truth for whether the Expresso validation generates samples.
# WER/CER are only produced (and only logged) when it does.
make_samples = True

# autocast() runs eligible CUDA ops in float16. GradScaler scales the loss so small
# float16 gradients do not underflow to zero (standard PyTorch AMP recipe). With
# enabled=False (no CUDA), every scaler call is a pass-through.
grad_scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

for epoch in range(num_epochs):
    model.train()
    text_conditioner.eval()  # caption encoder is frozen; it only supplies embeddings
    epoch_loss = 0.0
    num_steps = 0
    tqdm_bar = tqdm(total=len(train_loader), desc="F5-TTS Training")

    for batch in train_loader:
        mel = batch["mel"].to(device)
        mel_lengths = batch["mel_lengths"].to(device)
        scripts = batch["script"]
        caption = batch["caption"]
        # Swap the last two axes; presumably (B, n_mels, T) -> (B, T, n_mels) - TODO confirm
        mel_spec = mel.permute(0, 2, 1)

        with torch.no_grad():
            caption_embed, attention_mask = text_conditioner(caption, device=device)

        with autocast():
            output = model(
                mel_spec, text=scripts, lens=mel_lengths, noise_scheduler=noise_scheduler,
                caption_embed=caption_embed, attention_mask=attention_mask,
            )
        # `validate` unpacks (loss, cond, pred) from this same forward call; accept
        # that tuple form as well as a bare loss tensor.
        loss = output[0] if isinstance(output, (tuple, list)) else output

        grad_scaler.scale(loss).backward()
        # Unscale first so max_grad_norm applies to the true gradient magnitudes.
        grad_scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        grad_scaler.step(optimizer)
        grad_scaler.update()
        optimizer.zero_grad(set_to_none=True)

        tqdm_bar.update()
        epoch_loss += loss.detach().item()
        num_steps += 1
        del output, loss, caption_embed, attention_mask

        if num_steps % 1000 == 0:
            # Running mean over the steps completed so far (a count, not a 0-based index).
            running_loss = epoch_loss / num_steps
            scheduler.step()
            trainer['train_losses'].append(running_loss)
            trainer['lrs'].append(optimizer.param_groups[0]['lr'])

            draw_plot('train_losses', trainer, output_dir=output_dir)
            draw_plot('lrs', trainer, output_dir=output_dir)
            # Append the running loss to the text log.
            with open(f'./{output_dir}/middle_logs.txt', 'a') as file:
                file.write(f"\nEpoch - {epoch} : {running_loss}\n")
            torch.cuda.empty_cache()

    tqdm_bar.close()
    train_loss = epoch_loss / max(num_steps, 1)

    print(train_loss)
    with open(f'./{output_dir}/middle_logs.txt', 'a') as file:
        file.write(f"\nEpoch - {epoch} : {train_loss}\n")
    scheduler.step()
    trainer['train_losses'].append(train_loss)
    trainer['lrs'].append(optimizer.param_groups[0]['lr'])

    draw_plot('train_losses', trainer, output_dir=output_dir)
    draw_plot('lrs', trainer, output_dir=output_dir)

    model.eval()
    unwrapped_model = model  # would be accelerator.unwrap_model(model) if Accelerate is re-enabled
    torch.save(unwrapped_model.state_dict(), f'./{output_dir}/model_{epoch}.pt')

    # Distinct names: binding `wer`/`cer` here would overwrite the module-level
    # jiwer `wer` function that calculate_wer calls.
    valid_expresso_loss, valid_wer, valid_cer = validate(
        unwrapped_model,
        valid_loader,
        vocoder,
        text_conditioner,
        output_dir=output_dir,
        epoch=epoch,
        noise_scheduler=None,
        is_caption=True,
        is_make_samples=make_samples,
        is_whisper=True,
        device=device,
    )
    valid_libri_loss, _, _ = validate(
        unwrapped_model,
        libri_valid_loader,
        vocoder,
        text_conditioner,
        output_dir=output_dir,
        epoch=epoch,
        noise_scheduler=None,
        is_caption=False,
        is_make_samples=False,
        is_whisper=False,
        device=device,
    )

    trainer['valid_losses'].append(valid_expresso_loss)
    trainer['libri_valid_losses'].append(valid_libri_loss)

    draw_plot('valid_losses', trainer, output_dir=output_dir)
    draw_plot('libri_valid_losses', trainer, output_dir=output_dir)

    # validate() returns loss, WER and CER only (no CLAP score), so only those are tracked.
    if make_samples:
        trainer['wer'].append(valid_wer)
        trainer['cer'].append(valid_cer)

        draw_plot('wer', trainer, output_dir=output_dir)
        draw_plot('cer', trainer, output_dir=output_dir)

    # Append this epoch's summary to the text log.
    with open(f'./{output_dir}/logs.txt', 'a') as file:
        file.write(f"\nEpoch - {epoch}\n")
        file.write(f"Train loss : {train_loss}\n")
        file.write(f"valid_expresso_loss : {valid_expresso_loss}\n")
        file.write(f"valid_libri_loss : {valid_libri_loss}\n")

        if make_samples:
            file.write(f"Wer : {valid_wer}\n")
            file.write(f"Cer : {valid_cer}\n")

    del unwrapped_model
    torch.cuda.empty_cache()

In [None]:

import torch
from f5_tts.infer.utils_infer import (
    infer_process,
    load_vocoder,
    preprocess_ref_audio_text,
    remove_silence_for_generated_wav,
)
from f5_tts.train.custom_prompts import custom_prompts, creature_prompts
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from f5_tts.train.utils import make_html
import torchaudio
import numpy as np
from jiwer import wer
from difflib import SequenceMatcher
from torch.cuda.amp import autocast

def calculate_wer(reference, hypothesis):
    """Word Error Rate of `hypothesis` against `reference`, computed with jiwer.

    jiwer is imported locally rather than through the module-level name `wer`.
    Notebook cells rebind that global to a number (e.g. `..., wer, cer = validate(...)`),
    after which calling the global raises "'float' object is not callable".

    Args:
        reference: ground-truth transcript.
        hypothesis: transcript to score.

    Returns:
        The WER as a float. For an empty / whitespace-only reference (which recent
        jiwer versions reject): 0.0 if the hypothesis is also empty, else 1.0.
    """
    if not reference.strip():
        return 0.0 if not hypothesis.strip() else 1.0
    import jiwer

    return jiwer.wer(reference, hypothesis)

def calculate_cer(reference, hypothesis):
    """Character Error Rate: character-level Levenshtein distance / len(reference).

    Counts substitutions, insertions and deletions needed to turn `reference` into
    `hypothesis`. No case or punctuation normalisation is applied.

    Args:
        reference: ground-truth transcript.
        hypothesis: transcript to score.

    Returns:
        A float >= 0; it can exceed 1.0 when the hypothesis has many insertions.
        For an empty reference: 0.0 if the hypothesis is also empty, else 1.0.
    """
    if not reference:
        return 0.0 if not hypothesis else 1.0

    # Two-row dynamic programme: previous[j] = distance(reference[:i-1], hypothesis[:j]).
    previous = list(range(len(hypothesis) + 1))
    for i, ref_char in enumerate(reference, start=1):
        current = [i]
        for j, hyp_char in enumerate(hypothesis, start=1):
            current.append(min(
                previous[j] + 1,                             # deletion
                current[j - 1] + 1,                          # insertion
                previous[j - 1] + (ref_char != hyp_char),    # substitution or match
            ))
        previous = current
    return previous[-1] / len(reference)

def _stereo(audio):
    """Duplicate a mono waveform into a 2-row tensor (the format passed to make_html)."""
    mono = torch.tensor(audio).unsqueeze(dim=0)
    return torch.stack((mono, mono), dim=0).squeeze()


def _validation_loss(model, valid_loader, text_conditioner, noise_scheduler, is_caption, device):
    """Sum the model's loss over every batch of `valid_loader` without gradients."""
    total_loss = 0.0
    with torch.no_grad():
        for batch in valid_loader:
            mel_lengths = batch["mel_lengths"].to(device)
            # Same axis swap as the training loop; presumably (B, n_mels, T) -> (B, T, n_mels)
            mel_spec = batch["mel"].to(device).permute(0, 2, 1)
            with autocast():
                if is_caption:
                    caption_embed, attention_mask = text_conditioner(batch["caption"], device=device)
                else:
                    caption_embed, attention_mask = None, None
                loss, _cond, _pred = model(
                    mel_spec, text=batch["script"], lens=mel_lengths, noise_scheduler=noise_scheduler,
                    caption_embed=caption_embed, attention_mask=attention_mask,
                )
            total_loss += loss.item()
    print(total_loss)
    return total_loss


def _generate_samples(model, vocoder, text_conditioner, device, speed=1.0, mel_spec_type="vocos"):
    """Synthesise every entry of `custom_prompts` and `creature_prompts`.

    Returns a list of dicts with keys 'array' (two identical channels), 'caption'
    (label shown in the HTML report) and 'script' (target text).
    """
    output_list = []
    # Inference only: no autograd graph through the caption encoder or the model.
    with torch.no_grad():
        for data in custom_prompts:
            caption = data['text']
            script = data['script']
            caption_embed, attention_mask = text_conditioner(caption, device=device)
            audio, _final_sample_rate, _spectrogram = infer_process(
                None,
                # Placeholder reference text: per the original note, ~1 s of zeros
                # stands in for the reference audio when no_ref_audio=True.
                "kill all. ",
                script,
                model,
                vocoder,
                mel_spec_type=mel_spec_type,
                speed=speed,
                no_ref_audio=True,
                caption_embed=caption_embed,
                attention_mask=attention_mask,
            )
            output_list.append({
                'array': _stereo(audio),
                'caption': caption + " - " + script,
                'script': script,
            })

        for data in creature_prompts:
            caption = data['caption']
            script = data['script']
            ref_audio, ref_text = preprocess_ref_audio_text(data['prefix_path'], data['prefix_script'])
            caption_embed, attention_mask = text_conditioner(caption, device=device)
            # NOTE(review): no_ref_audio=True is passed even though a reference prefix is
            # supplied - confirm this is intended.
            audio, _final_sample_rate, _spectrogram = infer_process(
                ref_audio,
                ref_text,
                script,
                model,
                vocoder,
                mel_spec_type=mel_spec_type,
                speed=speed,
                no_ref_audio=True,
                caption_embed=caption_embed,
                attention_mask=attention_mask,
            )
            print(audio.shape)
            output_list.append({
                'array': _stereo(audio),
                'caption': "prefix + " + caption + " - " + script,
                'script': script,
            })
    return output_list


def _whisper_error_rates(output_list, device):
    """Transcribe each generated sample with Whisper large-v2; return (sum of WER, sum of CER)."""
    whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
    whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2").to(device)
    whisper_model.config.forced_decoder_ids = None
    # Generated audio is 24 kHz; Whisper's feature extractor is fed 16 kHz.
    resampler = torchaudio.transforms.Resample(orig_freq=24000, new_freq=16000)

    total_wer = 0
    total_cer = 0
    try:
        with torch.no_grad():
            for sample in output_list:
                audio = sample['array'].squeeze()[0]  # first of the two identical channels
                script = sample['script']
                resampled_audio = resampler(audio)
                input_features = whisper_processor(
                    resampled_audio, sampling_rate=16000, return_tensors="pt"
                ).input_features.to(device)
                predicted_ids = whisper_model.generate(input_features)
                transcription = whisper_processor.batch_decode(predicted_ids.cpu(), skip_special_tokens=True)[0]
                total_wer += calculate_wer(script, transcription)
                total_cer += calculate_cer(script, transcription)
    finally:
        # Release the large ASR model even if transcription fails.
        del whisper_processor, whisper_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return total_wer, total_cer


def validate(model, valid_loader, vocoder, text_conditioner, output_dir, epoch, noise_scheduler, is_caption=False, is_make_samples=False, is_whisper=False, device='cuda'):
    """Evaluate `model` on `valid_loader` and optionally generate/score audio samples.

    Args:
        model: CFM model; its forward must return (loss, cond, pred).
        valid_loader: loader yielding dicts with "mel", "mel_lengths", "script", "caption".
        vocoder: vocoder passed to infer_process for sample generation.
        text_conditioner: caption encoder returning (caption_embed, attention_mask).
        output_dir: directory make_html writes the sample report into.
        epoch: epoch number passed to make_html.
        noise_scheduler: forwarded to the model's forward.
        is_caption: condition the loss computation on batch captions.
        is_make_samples: synthesise the prompt sets and write an HTML report.
        is_whisper: with is_make_samples, transcribe samples and compute WER/CER.
        device: device batches and embeddings are moved to.

    Returns:
        (mean loss per batch, WER summed over samples, CER summed over samples).
        WER/CER are 0 unless both is_make_samples and is_whisper are set.
        The mean loss is NaN when `valid_loader` is empty.
    """
    model.eval()
    total_loss = _validation_loss(model, valid_loader, text_conditioner, noise_scheduler, is_caption, device)

    total_wer = 0
    total_cer = 0
    if is_make_samples:
        output_list = _generate_samples(model, vocoder, text_conditioner, device)
        make_html(epoch, output_dir, output_list)
        if is_whisper:
            total_wer, total_cer = _whisper_error_rates(output_list, device)

    num_batches = len(valid_loader)
    mean_loss = total_loss / num_batches if num_batches else float("nan")
    return mean_loss, total_wer, total_cer

In [None]:
# Stand-alone validation run: loss on the Expresso split plus sample generation and
# Whisper scoring. `ol` is (mean loss, summed WER, summed CER).
# Uses the configured output_dir / device rather than repeating their values.
ol = validate(
    model,
    valid_loader,
    vocoder,
    text_conditioner,
    output_dir=output_dir,
    epoch=0,
    noise_scheduler=None,
    is_caption=False,
    is_make_samples=True,
    is_whisper=True,
    device=device,
)