In [None]:
import random
import sys
from importlib.resources import files

import soundfile as sf
import tqdm
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf

from f5_tts.infer.utils_infer import (
    infer_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
    remove_silence_for_generated_wav,
    save_spectrogram,
    transcribe,
)
from f5_tts.model.utils import seed_everything


class F5TTS:
    """Convenience wrapper around the F5-TTS inference helpers.

    Constructing an instance reads the model config, selects a device, and
    loads both the vocoder and the acoustic model. ``infer`` then synthesises
    audio for a target text conditioned on a reference clip and its transcript.

    Attributes:
        mel_spec_type: ``model.mel_spec.mel_spec_type`` from the model config.
        target_sample_rate: ``model.mel_spec.target_sample_rate`` from the
            model config; used as the sample rate when writing wav files.
        device: Device string handed to the loaders and to ``infer_process``.
        seed: Seed used by the most recent ``infer`` call, or ``None`` if
            ``infer`` has not been called yet.
    """

    def __init__(
        self,
        model="F5TTS_v1_Base",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=True,
        vocoder_local_path=None,
        device=None,
        hf_cache_dir=None,
    ):
        """Load the config, vocoder and model weights.

        Args:
            model: Config name; resolved to ``configs/{model}.yaml`` inside the
                ``f5_tts`` package.
            ckpt_file: Checkpoint path. When empty, the checkpoint
                ``hf://SWivid/F5-TTS/{model}/model_1250000.safetensors`` is
                fetched through ``cached_path``.
            vocab_file: Forwarded unchanged to ``load_model``.
            ode_method: Forwarded unchanged to ``load_model``.
            use_ema: Forwarded unchanged to ``load_model``.
            vocoder_local_path: Optional local vocoder path; when given, the
                vocoder loader is told to use it.
            device: Device string. When ``None`` the first available of
                cuda / xpu / mps is used, falling back to cpu.
            hf_cache_dir: Cache directory passed to the vocoder loader and to
                ``cached_path``.
        """
        model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
        model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
        model_arc = model_cfg.model.arch

        self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
        self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate

        self.ode_method = ode_method
        self.use_ema = use_ema

        # Defined up front so reading ``.seed`` before the first ``infer`` call
        # yields None instead of raising AttributeError.
        self.seed = None

        self.device = device if device is not None else self._auto_device()

        # Load models
        self.vocoder = load_vocoder(
            self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir
        )

        repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"

        if not ckpt_file:
            ckpt_file = str(
                cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir)
            )

        self.ema_model = load_model(
            model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device
        )

    @staticmethod
    def _auto_device():
        """Pick the best available backend: cuda, then xpu, then mps, then cpu."""
        # Imported lazily so torch is only touched when no device was given.
        import torch

        if torch.cuda.is_available():
            return "cuda"
        # ``torch.xpu`` (and, on very old builds, ``torch.backends.mps``) may be
        # absent; probe with getattr so auto-detection degrades instead of raising.
        xpu = getattr(torch, "xpu", None)
        if xpu is not None and xpu.is_available():
            return "xpu"
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            return "mps"
        return "cpu"

    def transcribe(self, ref_audio, language=None):
        """Delegate to the module-level ``transcribe`` helper and return its result."""
        return transcribe(ref_audio, language)

    def export_wav(self, wav, file_wave, remove_silence=False):
        """Write ``wav`` to ``file_wave`` at ``self.target_sample_rate``.

        When ``remove_silence`` is true the written file is post-processed by
        ``remove_silence_for_generated_wav``.
        """
        sf.write(file_wave, wav, self.target_sample_rate)

        if remove_silence:
            remove_silence_for_generated_wav(file_wave)

    def export_spectrogram(self, spec, file_spec):
        """Save ``spec`` to ``file_spec`` via ``save_spectrogram``."""
        save_spectrogram(spec, file_spec)

    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        show_info=print,
        progress=tqdm,
        target_rms=0.1,
        cross_fade_duration=0.15,
        sway_sampling_coef=-1,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        fix_duration=None,
        remove_silence=False,
        file_wave=None,
        file_spec=None,
        seed=None,
    ):
        """Generate speech for ``gen_text`` using a reference clip.

        Args:
            ref_file: Reference audio; preprocessed together with ``ref_text``
                by ``preprocess_ref_audio_text`` before inference.
            ref_text: Transcript of the reference audio.
            gen_text: Text to synthesise.
            show_info, progress, target_rms, cross_fade_duration,
            sway_sampling_coef, cfg_strength, nfe_step, speed, fix_duration:
                Forwarded unchanged to ``infer_process``.
            remove_silence: Only used when ``file_wave`` is given; forwarded to
                ``export_wav``.
            file_wave: If not ``None``, the generated audio is written here.
            file_spec: If not ``None``, the spectrogram is written here.
            seed: Seed for ``seed_everything``. When ``None`` a random integer
                in ``[0, sys.maxsize]`` is drawn. The value used is stored on
                ``self.seed``.

        Returns:
            The ``(wav, sr, spec)`` tuple returned by ``infer_process``.
        """
        if seed is None:
            seed = random.randint(0, sys.maxsize)
        seed_everything(seed)
        self.seed = seed

        ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text)

        wav, sr, spec = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            self.vocoder,
            self.mel_spec_type,
            show_info=show_info,
            progress=progress,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=self.device,
        )

        if file_wave is not None:
            self.export_wav(wav, file_wave, remove_silence)

        if file_spec is not None:
            self.export_spectrogram(spec, file_spec)

        return wav, sr, spec

In [None]:
# Build the wrapper with every default: model config "F5TTS_v1_Base" and an
# automatically selected device. This loads the vocoder and model weights.
f5tts = F5TTS()

In [None]:
# Reference clip bundled inside the f5_tts package; its transcript is passed as ref_text.
ref_wav_path = files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")

# No output files are requested and no seed is fixed, so infer() draws a random
# seed and only returns the results in memory.
wav, sr, spec = f5tts.infer(
    ref_file=str(ref_wav_path),
    ref_text="some call me nature, others call me mother nature.",
    gen_text=(
        "I don't really care what you call me. "
        "I've been a silent spectator, watching species evolve, empires rise and fall."
    ),
    file_wave=None,
    file_spec=None,
    seed=None,
)

# Report the seed that infer() stored so the run can be reproduced.
print("seed :", f5tts.seed)

In [None]:
# Duration of the generated audio in seconds. Uses the sample rate returned by
# infer() so the result stays correct if the model's rate is not 24 kHz.
wav.shape[-1] / sr

In [None]:
from audiotools import AudioSignal

# Show the generated audio with the audiotools widget, using the sample rate
# returned by infer() instead of a hard-coded value.
AudioSignal(wav, sample_rate=sr).widget()

In [None]:
from f5_tts.model.modules import ConvNeXtV2Block

# Instantiate the library's ConvNeXtV2Block with positional arguments (256, 1024).
# NOTE(review): presumably (dim, intermediate_dim), mirroring the local
# re-implementation of this class later in the notebook — confirm against f5_tts.
conv = ConvNeXtV2Block(256, 256 * 4)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# GRN (Global Response Normalization) Layer as a prerequisite for ConvNeXtV2Block
# This is a common implementation of GRN.
class GRN(nn.Module):
    """Global Response Normalization for a 3-D input whose last axis has size ``dim``.

    The layer rescales ``x`` by a response ratio and adds the result back onto
    ``x``. Because ``gamma`` and ``beta`` are initialised to zero, the module is
    an exact identity until they are trained.
    """

    def __init__(self, dim):
        super().__init__()
        param_shape = (1, 1, dim)  # broadcasts over the first two axes of the input
        self.gamma = nn.Parameter(torch.zeros(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        # L2 norm taken along axis 1, one value per (batch, last-axis) entry.
        energy = torch.norm(x, p=2, dim=1, keepdim=True)
        # Normalise each energy by the mean energy across the last axis.
        mean_energy = energy.mean(dim=-1, keepdim=True)
        ratio = energy / (mean_energy + 1e-6)
        rescaled = x * ratio
        return self.gamma * rescaled + self.beta + x

# The user-provided ConvNeXtV2Block
class ConvNeXtV2Block(nn.Module):
    """One residual ConvNeXt-V2 block operating on ``(batch, sequence, dim)`` tensors.

    Pipeline: depthwise Conv1d (kernel 7) -> LayerNorm -> Linear expand ->
    GELU -> GRN -> Linear project, with the block input added to the result.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        kernel_size = 7
        # Padding chosen so the dilated convolution keeps the sequence length.
        same_padding = (dilation * (kernel_size - 1)) // 2
        # groups=dim makes this a depthwise convolution.
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=same_padding,
            groups=dim,
            dilation=dilation,
        )
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # The pointwise convolutions are expressed as per-position linear layers.
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Conv1d wants channels on axis 1, so swap in and straight back out.
        hidden = self.dwconv(x.transpose(1, 2)).transpose(1, 2)
        hidden = self.pwconv1(self.norm(hidden))
        hidden = self.grn(self.act(hidden))
        # Residual connection around the whole block.
        return x + self.pwconv2(hidden)

# Time Conditioning Block as described in the paper (Appendix A.2.3)
class TimeCondBlock(nn.Module):
    """Adds a linearly projected time embedding to every position of a sequence."""

    def __init__(self, hidden_dim, time_emb_dim):
        super().__init__()
        self.linear = nn.Linear(time_emb_dim, hidden_dim)

    def forward(self, x, t):
        # Map the time embedding to hidden_dim, then insert an axis at position 1
        # so the same offset broadcasts across the sequence axis of x.
        offset = torch.unsqueeze(self.linear(t), 1)
        return x + offset

# Text and Reference Conditioning Blocks using Cross-Attention
class CrossAttentionCondBlock(nn.Module):
    """Residual multi-head cross-attention conditioning block.

    The input sequence ``x`` supplies the queries; ``key`` and ``value`` are
    the conditioning sequences. All three are projected with
    ``hidden_dim -> hidden_dim`` linears, so each must have ``hidden_dim``
    features. Attention is computed with ``F.scaled_dot_product_attention``.
    """

    def __init__(self, hidden_dim, num_heads=4):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        # Bias-free input projections, biased output projection.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, key, value):
        batch_size, query_len, _ = x.shape
        # The key's length is also used to reshape the value projection.
        _, context_len, _ = key.shape

        def split_heads(projected, length):
            # (batch, length, hidden_dim) -> (batch, num_heads, length, head_dim)
            return projected.view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

        queries = split_heads(self.q_proj(x), query_len)
        keys = split_heads(self.k_proj(key), context_len)
        values = split_heads(self.v_proj(value), context_len)

        attended = F.scaled_dot_product_attention(queries, keys, values)

        # (batch, num_heads, query_len, head_dim) -> (batch, query_len, hidden_dim)
        merged = attended.transpose(1, 2).contiguous().view(batch_size, query_len, -1)

        # Residual connection around the attention output.
        return x + self.out_proj(merged)

# The main VF Estimator model
class VFEstimator(nn.Module):
    """Vector Field (VF) Estimator from the Supertonic-TTS paper.

    Structure: a linear input projection, ``num_main_blocks`` repetitions of
    (six ConvNeXtV2 blocks -> time conditioning -> text cross-attention ->
    reference cross-attention), four trailing ConvNeXtV2 blocks, and a linear
    projection back to ``input_dim``.

    Args:
        input_dim: Feature size of the noisy latents (and of the output).
        hidden_dim: Working width of every internal block.
        time_emb_dim: Feature size of the time embedding passed to ``forward``.
        text_cond_dim: Currently unused. The cross-attention blocks project
            keys/values with ``hidden_dim -> hidden_dim`` linears, so the text
            conditioning passed to ``forward`` must have ``hidden_dim`` features.
        ref_cond_dim: Currently unused, for the same reason as ``text_cond_dim``.
        num_main_blocks: Number of repeated main stages.
    """

    # Dilations of the six ConvNeXtV2 blocks in each main stage: four with
    # growing dilation followed by two undilated ones.
    _MAIN_DILATIONS = (1, 2, 4, 8, 1, 1)
    # Number of undilated ConvNeXtV2 blocks after the main stages.
    _NUM_FINAL_BLOCKS = 4

    def __init__(self, input_dim=144, hidden_dim=256, time_emb_dim=64, text_cond_dim=128, ref_cond_dim=128, num_main_blocks=4):
        super().__init__()

        def conv_stack(dilations):
            # One ConvNeXtV2 block per dilation, applied in order.
            return nn.Sequential(
                *[ConvNeXtV2Block(hidden_dim, hidden_dim * 4, dilation=d) for d in dilations]
            )

        # Initial projection layer
        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # Main repeating blocks
        self.main_blocks = nn.ModuleList(
            [conv_stack(self._MAIN_DILATIONS) for _ in range(num_main_blocks)]
        )

        # Conditioning modules, one of each kind per main stage
        self.time_cond_blocks = nn.ModuleList([TimeCondBlock(hidden_dim, time_emb_dim) for _ in range(num_main_blocks)])
        self.text_cond_blocks = nn.ModuleList([CrossAttentionCondBlock(hidden_dim) for _ in range(num_main_blocks)])
        self.ref_cond_blocks = nn.ModuleList([CrossAttentionCondBlock(hidden_dim) for _ in range(num_main_blocks)])

        # Final processing blocks
        self.final_blocks = conv_stack((1,) * self._NUM_FINAL_BLOCKS)

        # Output projection layer
        self.output_proj = nn.Linear(hidden_dim, input_dim)

    def forward(self, noisy_latents, time_embedding, text_cond, ref_cond_key, ref_cond_value):
        """Estimate the vector field for ``noisy_latents``.

        Args:
            noisy_latents: Tensor whose last axis has ``input_dim`` features.
            time_embedding: Tensor whose last axis has ``time_emb_dim`` features;
                it is broadcast over axis 1 of the hidden sequence.
            text_cond: Text conditioning, used as both key and value; must have
                ``hidden_dim`` features.
            ref_cond_key: Reference conditioning keys (``hidden_dim`` features).
            ref_cond_value: Reference conditioning values (``hidden_dim`` features).

        Returns:
            Tensor with the same leading shape as ``noisy_latents`` and
            ``input_dim`` features.
        """
        # 1. Initial projection
        x = self.input_proj(noisy_latents)

        # 2. Main stages, each followed by time, text and reference conditioning
        stages = zip(self.main_blocks, self.time_cond_blocks, self.text_cond_blocks, self.ref_cond_blocks)
        for conv_blocks, time_cond, text_attn, ref_attn in stages:
            x = conv_blocks(x)
            x = time_cond(x, time_embedding)
            x = text_attn(x, key=text_cond, value=text_cond)
            x = ref_attn(x, key=ref_cond_key, value=ref_cond_value)

        # 3. Tail ConvNeXt stack and projection back to the latent dimension.
        x = self.final_blocks(x)
        return self.output_proj(x)

In [None]:
import time

# Use the GPU when one is present; otherwise the demo still runs on CPU.
demo_device = "cuda" if torch.cuda.is_available() else "cpu"

# These dimensions are based on the paper's appendix.
# The conditioning dims are set to hidden_dim because the cross-attention
# blocks project keys/values from hidden_dim features.
vf_estimator = VFEstimator(
    input_dim=144,       # Dimension of compressed latents
    hidden_dim=256,      # Hidden dimension inside the estimator
    time_emb_dim=64,     # Dimension of the time embedding
    text_cond_dim=256,   # Dimension of the text encoder output
    ref_cond_dim=256,    # Dimension of the reference encoder output
    num_main_blocks=4    # Number of repeating main blocks
).to(demo_device)

# --- Dummy Input Creation ---
batch_size = 1
sequence_length = 200 # Example sequence length for the latents
text_seq_len = 50     # Example sequence length for text condition
ref_seq_len = 50      # Example sequence length for reference condition

# Noisy latents (output from forward process)
# Shape: (batch_size, sequence_length, input_dim)
dummy_noisy_latents = torch.randn(batch_size, sequence_length, 144)

# Time embedding for the current diffusion step
# Shape: (batch_size, time_emb_dim)
dummy_time_embedding = torch.randn(batch_size, 64)

# Text conditioning variable (output from Text Encoder)
# Shape: (batch_size, text_seq_len, text_cond_dim)
dummy_text_cond = torch.randn(batch_size, text_seq_len, 256)

# Reference conditioning variables (output from Reference Encoder)
# Shape: (batch_size, ref_seq_len, ref_cond_dim)
dummy_ref_key = torch.randn(batch_size, ref_seq_len, 256)
dummy_ref_value = torch.randn(batch_size, ref_seq_len, 256)

# Move the inputs once, outside the timed region, so the measurement below
# covers the forward pass only and not host-to-device copies.
device_inputs = [
    tensor.to(demo_device)
    for tensor in (
        dummy_noisy_latents,
        dummy_time_embedding,
        dummy_text_cond,
        dummy_ref_key,
        dummy_ref_value,
    )
]


# --- Model Forward Pass ---
print("--- Running VF Estimator ---")
print(f"Input noisy_latents shape: {dummy_noisy_latents.shape}")

# Inference only: switch the module to eval mode and disable autograd.
vf_estimator.eval()

with torch.no_grad():
    # CUDA kernels launch asynchronously; synchronise on both sides of the
    # call so the wall-clock time reflects completed work.
    if demo_device == "cuda":
        torch.cuda.synchronize()
    st = time.perf_counter()
    output_vector_field = vf_estimator(*device_inputs)
    if demo_device == "cuda":
        torch.cuda.synchronize()
    print(time.perf_counter() - st)

print(f"Output vector_field shape: {output_vector_field.shape}")
print("Model ran successfully!")

In [None]:
# vf_estimator.py
import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# ----------------------------
# Speed knobs (optional)
# ----------------------------
# Allow TF32 matmuls and let cuDNN benchmark convolution algorithms; both
# trade a little precision / determinism for throughput.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
# set_float32_matmul_precision does not exist on older PyTorch releases. Probe
# for it explicitly rather than swallowing every exception it could raise.
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")

# ----------------------------
# GRN (for ConvNeXt V2 block)
# ----------------------------
class GRN(nn.Module):
    """Global Response Normalization (ConvNeXt V2) for channels-last sequences.

    For an input of shape ``(B, N, C)`` the layer:
      1. aggregates each channel's response over the sequence axis with an L2
         norm, giving ``(B, 1, C)``;
      2. divides by the mean of those responses across channels;
      3. rescales ``x`` by that ratio, applies the learnable ``gamma`` / ``beta``
         and adds the input back.

    ``gamma`` and ``beta`` start at zero, so the layer is the identity at init.

    Args:
        dim: Number of channels ``C`` (size of the last axis).
        eps: Small constant added to the channel mean to avoid division by zero.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # (B, N, C)
        # Global per-channel response: L2 norm over the sequence axis.
        Gx = torch.norm(x, p=2, dim=-2, keepdim=True)                 # (B, 1, C)
        # Normalise each channel's response by the mean response across channels.
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + self.eps)          # (B, 1, C)
        return self.gamma * (x * Nx) + self.beta + x

# ----------------------------
# Your ConvNeXtV2Block (as given)
# ----------------------------
class ConvNeXtV2Block(nn.Module):
    """Residual ConvNeXt-V2 block for ``(B, N, D)`` sequences.

    Applies a depthwise Conv1d with kernel size 7, LayerNorm, a two-layer
    pointwise MLP with GELU and GRN in between, and adds the block input back.

    Args:
        dim: Channel count ``D`` of the input and output.
        intermediate_dim: Hidden width of the pointwise MLP.
        dilation: Dilation of the depthwise convolution; padding is derived
            from it so the sequence length is unchanged.
    """

    _KERNEL_SIZE = 7

    def __init__(self, dim: int, intermediate_dim: int, dilation: int = 1):
        super().__init__()
        pad = (dilation * (self._KERNEL_SIZE - 1)) // 2
        # Depthwise: one filter per channel (groups == channels).
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=self._KERNEL_SIZE,
            padding=pad,
            groups=dim,
            dilation=dilation,
        )
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convolutions written as linear layers.
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # (B, N, D)
        skip = x
        channels_first = x.transpose(1, 2)               # (B, D, N)
        y = self.dwconv(channels_first).transpose(1, 2)  # back to (B, N, D)
        for layer in (self.norm, self.pwconv1, self.act, self.grn, self.pwconv2):
            y = layer(y)
        return skip + y

# ----------------------------
# Time embedding (Grad-TTS style)
# ----------------------------
def sinusoidal_t_embed(t: torch.Tensor, dim: int = 64) -> torch.Tensor:
    """Sinusoidal embedding of a batch of scalar times.

    Args:
        t: Tensor of shape ``(B,)``; the caller is expected to pass values in [0, 1].
        dim: Size of the returned embedding.

    Returns:
        Tensor of shape ``(B, dim)``: ``dim // 2`` sines followed by
        ``dim // 2`` cosines, plus one trailing zero column when ``dim`` is odd.
    """
    n_freqs = dim // 2
    # Frequencies fall geometrically from 1 down to 1/10000.
    log_span = torch.linspace(0, math.log(10000), n_freqs, device=t.device)
    freqs = torch.exp(log_span * (-1))
    # Times are scaled by 2*pi before being multiplied by each frequency.
    phase = (2.0 * math.pi * t).unsqueeze(1) * freqs.unsqueeze(0)  # (B, n_freqs)
    emb = torch.cat((phase.sin(), phase.cos()), dim=-1)
    is_odd = dim % 2 == 1
    if is_odd:
        # Pad one zero on the right so the width matches an odd dim.
        emb = F.pad(emb, (0, 1))
    return emb  # (B, dim)

# ----------------------------
# Cross-Attention Block (uses SDPA)
# ----------------------------
class CrossAttention(nn.Module):
    """Multi-head cross-attention built on ``F.scaled_dot_product_attention``.

    Queries come from ``x_q`` and keys/values from ``x_kv``. The module returns
    only the attention output; any residual connection is the caller's job.

    Args:
        dim_q: Feature size of the query input (and of the output).
        dim_kv: Feature size of the key/value input.
        num_heads: Number of attention heads.
        head_dim: Width of each head; the inner width is ``num_heads * head_dim``.
        attn_drop: Attention dropout probability, applied only in training mode.
    """

    def __init__(self, dim_q: int, dim_kv: int, num_heads: int = 4, head_dim: int = 64, attn_drop: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.inner_dim = num_heads * head_dim

        self.q_proj = nn.Linear(dim_q, self.inner_dim, bias=True)
        self.k_proj = nn.Linear(dim_kv, self.inner_dim, bias=True)
        self.v_proj = nn.Linear(dim_kv, self.inner_dim, bias=True)
        self.out_proj = nn.Linear(self.inner_dim, dim_q, bias=True)
        self.attn_drop = attn_drop

    def forward(
        self,
        x_q: torch.Tensor,      # (B, Tq, Dq)
        x_kv: torch.Tensor,     # (B, Tk, Dkv)
        attn_mask: Optional[torch.Tensor] = None  # (B, Tq, Tk) or None
    ) -> torch.Tensor:
        """Attend from ``x_q`` to ``x_kv``.

        Args:
            x_q: Query sequence of shape ``(B, Tq, dim_q)``.
            x_kv: Key/value sequence of shape ``(B, Tk, dim_kv)``.
            attn_mask: Optional mask passed to
                ``F.scaled_dot_product_attention``. A boolean mask marks
                positions that may be attended with ``True``; a float mask is
                added to the attention scores. A 3-D ``(B, Tq, Tk)`` mask is
                broadcast over heads; other shapes are passed through as-is.

        Returns:
            Tensor of shape ``(B, Tq, dim_q)``.
        """
        B, Tq, _ = x_q.shape
        _, Tk, _ = x_kv.shape

        q = self.q_proj(x_q)  # (B, Tq, inner)
        k = self.k_proj(x_kv) # (B, Tk, inner)
        v = self.v_proj(x_kv) # (B, Tk, inner)

        # reshape to (B, heads, T, head_dim)
        q = q.view(B, Tq, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, Tk, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, Tk, self.num_heads, self.head_dim).transpose(1, 2)

        # A per-batch (B, Tq, Tk) mask needs a singleton head axis so it
        # broadcasts against the (B, heads, Tq, Tk) attention scores.
        if attn_mask is not None and attn_mask.dim() == 3:
            attn_mask = attn_mask.unsqueeze(1)

        # Dropout is active only while the module is in training mode.
        out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=self.attn_drop if self.training else 0.0,
            is_causal=False,
        )  # (B, heads, Tq, head_dim)

        out = out.transpose(1, 2).contiguous().view(B, Tq, self.inner_dim)
        return self.out_proj(out)  # (B, Tq, Dq)

# ----------------------------
# VF Estimator
# ----------------------------
@dataclass
class VFConfig:
    """Hyper-parameters consumed by ``VFEstimator``."""

    d_in: int = 144        # compressed latent dim (C*Kc); input and output width of the estimator
    d_model: int = 256     # hidden channels inside VF
    d_time: int = 64       # time embedding dim (before projection)
    n_repeats: int = 4     # Nm in the paper; number of repeated main stages
    n_post: int = 4        # post ConvNeXt blocks count
    inter_dim: int = 1024  # ConvNeXt MLP hidden (passed as intermediate_dim)
    # NOTE(review): kernel_size is not read by the visible code; ConvNeXtV2Block
    # hard-codes a kernel size of 7. Confirm whether 5 or 7 is intended.
    kernel_size: int = 5
    num_heads: int = 4     # cross-attention head count
    head_dim: int = 64     # width of each cross-attention head

class VFEstimator(nn.Module):
    """
    Implements the VF Estimator (Fig. 4(c) & Appx A.2.3) of Supertonic-TTS.
    - Input: noisy compressed latents z_t (B, T, 144)
    - Conditions: time t (B,), text_kv (B, Nt, Dt), ref_kv (B, Nr, Dr)
    - Output: vector field v(z_t, text, ref, t) with shape (B, T, 144)

    The text / reference cross-attention modules and the LayerNorms are single
    instances reused by every stage. Both attention modules are built with
    ``dim_kv=128``, so ``text_kv`` and ``ref_kv`` must have 128 features.

    Args:
        cfg: Hyper-parameters. When ``None`` a fresh ``VFConfig()`` is created,
            so separate models never share (and cannot mutate) one default
            config object.
    """

    # Keys of the four dilated blocks inside each stage, in application order.
    _DILATED_KEYS = ("dilated_1", "dilated_2", "dilated_3", "dilated_4")

    def __init__(self, cfg: Optional[VFConfig] = None):
        super().__init__()
        if cfg is None:
            cfg = VFConfig()
        self.cfg = cfg

        # input/output projections
        self.proj_in  = nn.Linear(cfg.d_in, cfg.d_model)
        self.proj_out = nn.Linear(cfg.d_model, cfg.d_in)

        # time conditioning: sinusoidal embedding -> project to d_model, then global add
        self.time_proj = nn.Linear(cfg.d_time, cfg.d_model)

        # text & ref cross-attention (reused many times)
        self.text_attn = CrossAttention(cfg.d_model, dim_kv=128, num_heads=cfg.num_heads, head_dim=cfg.head_dim)
        self.ref_attn  = CrossAttention(cfg.d_model, dim_kv=128, num_heads=cfg.num_heads, head_dim=cfg.head_dim)

        # a helper to build blocks
        def convnext_block(dilation: int = 1):
            return ConvNeXtV2Block(dim=cfg.d_model, intermediate_dim=cfg.inter_dim, dilation=dilation)

        # main repeated stages
        self.stages = nn.ModuleList()
        for _ in range(cfg.n_repeats):
            stage = nn.ModuleDict(
                dict(
                    dilated_1=convnext_block(dilation=1),
                    dilated_2=convnext_block(dilation=2),
                    dilated_3=convnext_block(dilation=4),
                    dilated_4=convnext_block(dilation=8),
                    # after time add, two standard ConvNeXt blocks around cross-attn
                    pre_text = convnext_block(dilation=1),
                    post_text= convnext_block(dilation=1),
                )
            )
            self.stages.append(stage)

        # post blocks
        self.post_blocks = nn.Sequential(*[convnext_block(dilation=1) for _ in range(cfg.n_post)])

        # lightweight norms on the path
        self.pre_text_ln = nn.LayerNorm(cfg.d_model)
        # NOTE(review): post_text_ln is registered but never applied in forward().
        # It is kept so parameter names and counts are unchanged.
        self.post_text_ln = nn.LayerNorm(cfg.d_model)
        self.pre_ref_ln = nn.LayerNorm(cfg.d_model)
        self.post_ref_ln = nn.LayerNorm(cfg.d_model)

    def forward(
        self,
        zt: torch.Tensor,               # (B, T, 144)
        t: torch.Tensor,                # (B,)
        text_kv: torch.Tensor,          # (B, Nt, 128)
        ref_kv: torch.Tensor,           # (B, Nr, 128)
    ) -> torch.Tensor:
        """Estimate the vector field for ``zt`` at time ``t``.

        Returns a tensor of shape ``(B, T, d_in)``.
        """
        x = self.proj_in(zt)  # (B, T, d_model)

        # precompute time embedding once
        t_emb = sinusoidal_t_embed(t, dim=self.cfg.d_time)
        # The embedding is built in t's dtype (fp32 in the demo). Match the
        # projection's weight dtype so a half-precision model also works when
        # called outside an autocast context.
        t_emb = t_emb.to(dtype=self.time_proj.weight.dtype)
        t_add = self.time_proj(t_emb)  # (B, d_model)

        for stage in self.stages:
            # 4 dilated ConvNeXt blocks
            for key in self._DILATED_KEYS:
                x = stage[key](x)

            # time conditioning: global add (broadcast over time)
            x = x + t_add[:, None, :]

            # ConvNeXt -> Text cross-attn -> ConvNeXt
            x = stage["pre_text"](x)
            x = self.pre_text_ln(x)
            x = x + self.text_attn(x, text_kv)  # residual
            x = stage["post_text"](x)

            # LayerNorm -> Ref cross-attn (residual) -> LayerNorm
            x = self.pre_ref_ln(x)
            x = x + self.ref_attn(x, ref_kv)
            x = self.post_ref_ln(x)

        # tail ConvNeXt stack + projection back to d_in
        x = self.post_blocks(x)
        out = self.proj_out(x)  # (B, T, 144)
        return out

In [None]:
# ----------------------------
# Quick demo (random inputs)
# ----------------------------
import time

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

B, T = 2, 860     # batch, time steps of compressed latents
Nt, Nr = 256, 50  # text tokens, reference tokens

torch.manual_seed(0)

zt = torch.randn(B, T, 144, device=device, dtype=dtype)
t  = torch.rand(B, device=device, dtype=torch.float32)  # keep time in fp32 for embedding stability
text_kv = torch.randn(B, Nt, 128, device=device, dtype=dtype)
ref_kv  = torch.randn(B, Nr, 128, device=device, dtype=dtype)

model = VFEstimator().to(device=device, dtype=dtype)
model.eval()

# Optional: compile for speed (PyTorch 2.0+). Compilation is best-effort, so a
# failure falls back to eager mode, but it is reported rather than hidden.
try:
    model = torch.compile(model)
except Exception as exc:
    print(f"torch.compile unavailable, running eagerly: {exc}")


def _demo_autocast():
    """Autocast context for the demo: enabled on CUDA, a no-op on CPU."""
    return torch.autocast(device_type=device, enabled=(device == "cuda"))


def _demo_sync():
    """Wait for queued CUDA work so wall-clock timings cover finished kernels."""
    if device == "cuda":
        torch.cuda.synchronize()


# Inference (also serves as the warm-up call before timing)
with torch.inference_mode(), _demo_autocast():
    y = model(zt, t, text_kv, ref_kv)

print("Output shape:", y.shape)  # (B, T, 144)

# Small timing test
iters = 10
_demo_sync()
st = time.perf_counter()
with torch.inference_mode(), _demo_autocast():
    for _ in range(iters):
        _ = model(zt, t, text_kv, ref_kv)
_demo_sync()
# Divide by the iteration count so the printed figure really is an average.
avg_latency_ms = (time.perf_counter() - st) * 1000 / iters
print(f"{iters} iters avg latency: {avg_latency_ms:.2f} ms")

# length 200 -> 120, latency 60ms -> 64?

In [None]:
import torch

def count_parameters(model):
    """Return the total element count of ``model``'s parameters that require gradients."""
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Count the trainable parameters of `model`.
total_params = count_parameters(model)
print(f"Total parameters: {total_params}")