In [41]:
import torch
from torch import nn
import torch.nn.functional as F

from x_transformers.x_transformers import RotaryEmbedding

from f5_tts.model.modules import (
    TimestepEmbedding,
    ConvNeXtV2Block,
    ConvPositionEmbedding,
    DiTBlock,
    AdaLayerNormZero_Final,
    precompute_freqs_cis,
    get_pos_embed_indices,
)
from torch import nn
import torch
import logging, warnings
import string
import typing as tp
import gc
import random

# Text embedding
class TextEmbedding(nn.Module):
    def __init__(self, vocab_size, text_dim, conv_layers=0, conv_mult=2):
        super().__init__()
        self.text_embed = nn.Embedding(vocab_size+1, text_dim)  # use 0 as filler token

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            # self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
            self.freqs_cis = precompute_freqs_cis(text_dim, self.precompute_max_pos) # (precompute_max_pos, text_dim)
            self.text_blocks = nn.Sequential(
                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
            )
        else:
            self.extra_modeling = False

    def forward(self, text, seq_len, drop_text=False):  # noqa: F722
        # text is tokenized by custom vocab.
        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
        batch, text_len = text.shape[0], text.shape[1]
        text = F.pad(text, (0, seq_len - text_len), value=0)
        print("padded text ", text.shape)

        if drop_text and random.random()<drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b n -> b n d

        # possible extra modeling
        if self.extra_modeling:
            # sinus pos emb
            batch_start = torch.zeros((batch,), dtype=torch.long)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

            # convnextv2 blocks
            text = self.text_blocks(text)

        return text

In [42]:
from f5_tts.model.utils import get_tokenizer, list_str_to_idx

text_num_embeds=2546
text_dim=512
conv_layers=4

vocab_file = "./f5_tts/infer/examples/vocab.txt"
tokenizer = "custom"
vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer) # len(vocab_char_map) = 2545 = vocab_size

In [43]:
text = ["hello", "it's raining"]
text = list_str_to_idx(text, vocab_char_map)
print(text.shape)

text_module = TextEmbedding(vocab_size=vocab_size, text_dim=text_dim, conv_layers=conv_layers)

torch.Size([2, 12])


In [44]:
text_embed = text_module(text, seq_len=1000, drop_text=0.0) # seq_len if for setting max_len

text_embed.shape

padded text  torch.Size([2, 1000])


torch.Size([2, 1000, 512])

In [57]:
text_embed[0][20:25]

tensor([[ 0.6813,  0.6497,  0.2855,  ..., -1.3252, -0.0537,  0.0632],
        [-0.3184, -0.1016,  0.0722,  ..., -1.2688, -0.0505,  0.0849],
        [-0.7829, -1.0163, -0.7877,  ..., -1.2776, -0.0253,  0.1682],
        [-0.2993, -1.3094, -1.6249,  ..., -1.3394,  0.0140,  0.2614],
        [ 0.6841, -0.7223, -1.7785,  ..., -1.4046,  0.0351,  0.2931]],
       grad_fn=<SliceBackward0>)

In [58]:
text_embed[1][20:25]

tensor([[ 0.6797,  0.6491,  0.2860,  ..., -1.3269, -0.0527,  0.0625],
        [-0.3184, -0.1016,  0.0722,  ..., -1.2688, -0.0505,  0.0849],
        [-0.7829, -1.0163, -0.7878,  ..., -1.2776, -0.0253,  0.1683],
        [-0.2993, -1.3094, -1.6249,  ..., -1.3394,  0.0140,  0.2614],
        [ 0.6841, -0.7223, -1.7785,  ..., -1.4046,  0.0351,  0.2931]],
       grad_fn=<SliceBackward0>)

In [4]:
seq_len = 1000

# text is tokenized by custom vocab.
text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
batch, text_len = text.shape[0], text.shape[1]
text = F.pad(text, (0, seq_len - text_len), value=0)
print("text : ", text)

text = text_embed.text_embed(text)  # b n -> b n d
print("text : ", text.shape)

text :  tensor([[ 442,  326,  616,  ...,    0,    0,    0],
        [ 508, 1083,    8,  ...,    0,    0,    0]], device='cuda:0')
text :  torch.Size([2, 1000, 512])


In [5]:
precompute_max_pos = 4096

batch_start = torch.zeros((batch,), dtype=torch.long)
print(batch_start)

pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=precompute_max_pos)

tensor([0, 0])


In [6]:
pos_idx

tensor([[  0,   1,   2,  ..., 997, 998, 999],
        [  0,   1,   2,  ..., 997, 998, 999]])

In [13]:
text = text + text_embed.freqs_cis[pos_idx].to('cuda')

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [32]:
import argparse
import codecs
import os
import re
from importlib.resources import files
from pathlib import Path

import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path

from f5_tts.infer.utils_infer import (
    infer_process,
    load_vocoder,
    preprocess_ref_audio_text,
    remove_silence_for_generated_wav,
)
from f5_tts.model import DiT, UNetT, CFM
import torch
from f5_tts.model.utils import (
    get_tokenizer,
    convert_char_to_pinyin,
    list_str_to_idx,
    lens_to_mask,
    mask_from_frac_lengths
)
from audiotools import AudioSignal
from transformers import T5EncoderModel, AutoTokenizer
from torch.cuda.amp import autocast
from accelerate import Accelerator, DistributedDataParallelKwargs
from f5_tts.model.cfm import T5Conditioner

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# -----------------------------------------

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos"
target_rms = 0.1
cross_fade_duration = 0.15
ode_method = "euler"
nfe_step = 32  # 16, 32
cfg_strength = 2.0
sway_sampling_coef = -1.0
speed = 1.0
fix_duration = None

# -----------------------------------------

t5_model_name = "t5-base"
text_conditioner = T5Conditioner(t5_model_name="t5-base", max_length=128).to(device)

accelerator = Accelerator(
    mixed_precision = "fp16",
)

  warn(


config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

In [34]:
model_cls = DiT
model_cfg = dict(
    dim=1024, 
    depth=22, 
    heads=16, 
    ff_mult=2, 
    text_dim=512, 
    conv_layers=4
)

vocab_file = "./f5_tts/infer/examples/vocab.txt"
tokenizer = "custom"
vocoder_name = "vocos"
ode_method = "euler"

vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
print(vocab_size)
vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=True, local_path="../src/f5_tts/vocoder")

transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels)

mel_spec_kwargs=dict(
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=win_length,
    n_mel_channels=n_mel_channels,
    target_sample_rate=target_sample_rate,
    mel_spec_type=mel_spec_type,
)

odeint_kwargs=dict(
    method=ode_method,
)

model = CFM(
    transformer=transformer,
    mel_spec_kwargs=mel_spec_kwargs,
    odeint_kwargs=odeint_kwargs,
    vocab_char_map=vocab_char_map,
    frac_lengths_mask=(0.7, 1.0),
    audio_drop_prob=0.3,
    cond_drop_prob=0.2,
    caption_drop_prob=0.2
).to(device)

num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_trainable_params)

2545
Load vocos from local path ../src/f5_tts/vocoder
532889700


In [36]:
transformer.__dict__

{'training': True,
 '_parameters': {},
 '_buffers': {},
 '_non_persistent_buffers_set': set(),
 '_backward_pre_hooks': OrderedDict(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_hooks_with_kwargs': OrderedDict(),
 '_forward_hooks_always_called': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_forward_pre_hooks_with_kwargs': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': {'time_embed': TimestepEmbedding(
    (time_embed): SinusPositionEmbedding()
    (time_mlp): Sequential(
      (0): Linear(in_features=256, out_features=1024, bias=True)
      (1): SiLU()
      (2): Linear(in_features=1024, out_features=1024, bias=True)
    )
  ),
  'text_embed': TextEmbedding(
    (text_embed): Embedding(2546, 512)
    (text_blocks): Sequential(
      (0): ConvNeXtV2Block(
    

In [3]:
text = ["a man saying"]

caption_embed, attention_mask = text_conditioner(text, device=device)

caption_embed.shape

mel = torch.zeros((1, 100, 100)).to(device)
mel_lengths = torch.tensor([100]).to(device)
scripts = ["HHIH"]

In [4]:
loss, cond, pred = model(
    mel, text=scripts, lens=mel_lengths, noise_scheduler=None, caption_embed=caption_embed, attention_mask=attention_mask
)

torch.Size([1, 100, 100]) torch.Size([1, 100, 100]) torch.Size([1, 100, 512])
caption_embed :  torch.Size([1024])
t :  torch.Size([1, 1024])
caption_embed :  torch.Size([1024])
t :  torch.Size([1, 1024])
caption_embed :  torch.Size([1024])
t :  torch.Size([1, 1024])
caption_embed :  torch.Size([1024])
t :  torch.Size([1, 1024])
caption_embed :  torch.Size([1024])
t :  torch.Size([1, 1024])
caption_embed :  torch.Size([1024])
t :  torch.Size([1, 1024])
caption_embed :  torch.Size([1024])
t :  torch.Size([1, 1024])
caption_embed :  torch.Size([1024])
t :  torch.Size([1, 1024])
caption_embed :  torch.Size([1024])
t :  torch.Size([1, 1024])
caption_embed :  torch.Size([1024])
t :  torch.Size([1, 1024])
caption_embed :  torch.Size([1024])
t :  torch.Size([1, 1024])
caption_embed :  torch.Size([1024])
t :  torch.Size([1, 1024])
caption_embed :  torch.Size([1024])
t :  torch.Size([1, 1024])
caption_embed :  torch.Size([1024])
t :  torch.Size([1, 1024])
caption_embed :  torch.Size([1024])
t : 

In [4]:
import torch.nn as nn

layer = nn.Linear(768, 1024)

num_trainable_params = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(num_trainable_params * 22)

17324032


In [5]:
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_trainable_params)

337096804
