In [1]:
from model.vf import VFEstimator
from model.refencoder import ReferenceEncoder
from model.textencoder import TextEncoder
import torch
import torch.nn as nn
import torch.nn.functional as F
import librosa
import torch
from utils.feature import TorchAudioFbank, TorchAudioFbankConfig
from tokenizerown import LibriTTSTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch before any module is constructed: parameter init below and the
# torch.randn / torch.rand calls in later cells all draw from torch's RNG, so a
# fixed seed makes Restart & Run All reproduce the same numbers.
SEED = 0
torch.manual_seed(SEED)

sample_rate = 24000  # Hz; passed to the fbank config below
n_mels = 100         # mel bins; channel dim of logmel and VFEstimator's dim_in
d_model = 256        # VFEstimator's dim_model

ref_encoder = ReferenceEncoder(in_dim=n_mels, num_heads=4).to(device)
# NOTE(review): vocab_size=160 is hard-coded — confirm it covers every id the
# LibriTTSTokenizer below can emit for vocab_small.txt (plus special tokens).
text_encoder = TextEncoder(vocab_size=160, emb_dim=128).to(device)
vf_estimator = VFEstimator(dim_in=n_mels, dim_model=d_model, conv_hidden=1024, num_heads=4, Nm=4).to(device)

ref_audio_path = './test.wav'
script = "Hello, world!"

tokenizer = LibriTTSTokenizer(
    special_tokens=["<filler>"],
    token_file="./vocab_small.txt",
    lowercase=True,
    oov_policy="skip",        # drop OOV tokens (alternatives: "use_unk", "error")
    unk_token="[UNK]",        # only needed when oov_policy="use_unk"
)
# hop_length=256 at 24 kHz -> 24000 / 256 = 93.75 frames per second
fbank = TorchAudioFbank(config=TorchAudioFbankConfig(sampling_rate=sample_rate, n_mels=n_mels, n_fft=1024, hop_length=256))

  from .autonotebook import tqdm as notebook_tqdm


In [3]:

# Load the reference clip at the configured sample rate instead of a hard-coded 24000.
audio, sr = librosa.load(ref_audio_path, sr=sample_rate, mono=True)
audio = torch.from_numpy(audio)
# fbank.extract yields (T, n_mels); add a batch axis -> (1, T, n_mels)
logmel = fbank.extract(audio, sr).unsqueeze(0)
print("logmel : ", logmel.shape)
B, T_ref, _ = logmel.shape

token_ids = tokenizer.texts_to_token_ids([script])
# Explicit long dtype: torch.tensor([[]]) would otherwise infer float32.
token_ids = torch.tensor(token_ids, dtype=torch.long, device=device)
# oov_policy="skip" drops unknown symbols, so a script can tokenize to nothing.
assert token_ids.numel() > 0, f"script {script!r} produced no tokens"
print("token_ids : ", token_ids)

logmel :  torch.Size([1, 301, 100])
token_ids :  tensor([[20, 18, 24, 24, 27,  8,  3, 35, 27, 30, 24, 17,  4]], device='cuda:0')


In [4]:
# Encode the reference log-mel once; `ref` is used by both the text encoder here
# and the VF estimator in the next cell.
ref = ref_encoder(logmel.to(device))
print("ref : ", ref.shape)

# The text encoder consumes the token ids together with `ref` and returns the
# text embedding plus a set of learned keys.
text_emb_128, learned_keys_50 = text_encoder(token_ids, ref, ref_key=None)
for label, tensor in (("text_emb_128 : ", text_emb_128), ("learned_keys_50 : ", learned_keys_50)):
    print(label, tensor.shape)

ref :  torch.Size([1, 50, 128])
text_emb_128 :  torch.Size([1, 13, 128])
learned_keys_50 :  torch.Size([50, 128])


In [5]:
# Random stand-in for the noisy input: 400 frames x n_mels channels.
# NOTE: 400 was written as "4 s x 100 frames/s", but the fbank in cell 1
# (hop 256 @ 24 kHz) actually produces ~93.75 frames/s.
noisy_latents = torch.randn((B, 4 * 100, n_mels), device=device)
# One time value per batch item, drawn uniformly from [0, 1).
time_t = torch.rand(B, device=device)
print("time_t : ", time_t, time_t.shape)

output = vf_estimator(noisy_latents, time_t, text_emb_128, ref, learned_keys_50)
print("output : ", output.shape)  # observed to match noisy_latents: (B, 400, n_mels)

time_t :  tensor([0.2730], device='cuda:0') torch.Size([1])
output :  torch.Size([1, 400, 100])


In [6]:
def count_parameters(module: nn.Module) -> int:
    """Return the total number of scalar elements across all parameters of ``module``."""
    return sum(p.numel() for p in module.parameters())


# Label every count with its module name; three identical labels would make the
# printed lines indistinguishable.
for module_name, module in (
    ("vf_estimator", vf_estimator),
    ("text_encoder", text_encoder),
    ("ref_encoder", ref_encoder),
):
    print(f"{module_name} 전체 파라미터 수: {count_parameters(module):,}")

전체 파라미터 수: 17,062,756
전체 파라미터 수: 1,748,480
전체 파라미터 수: 948,352


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List

class Adapter128to256(nn.Module):
    """Linear projection of the last (feature) dimension from 128 to 256.

    Leading dimensions pass through unchanged, e.g. (B, T, 128) -> (B, T, 256).
    """

    def __init__(self):
        super(Adapter128to256, self).__init__()
        # Attribute name `lin` keeps state_dict keys as "lin.weight" / "lin.bias".
        self.lin = nn.Linear(in_features=128, out_features=256)

    def forward(self, x):
        projected = self.lin(x)
        return projected

# -------------------------------------------------
# 4) 전체 파이프라인 예시
# -------------------------------------------------
class TTSVectorFieldPipeline(nn.Module):
    """Reference encoder -> text encoder -> 128->256 adapters -> VFEstimator.

    NOTE(review): this cell has no recorded output, so the wiring is unverified.
    Cell 5 calls VFEstimator positionally with 128-dim text/reference embeddings,
    whereas this class passes 256-dim adapted embeddings through keyword arguments
    (including ``proj_text_to_dim`` / ``proj_ref_to_dim``); confirm those keyword
    names and expected widths against ``model.vf.VFEstimator``.
    """

    def __init__(self, vocab_size: int, device="cuda" if torch.cuda.is_available() else "cpu"):
        super().__init__()
        self.device = device

        # Text encoder. NOTE(review): emb_dim is left at TextEncoder's default,
        # whereas cell 1 passes emb_dim=128 explicitly; the adapters assume 128.
        self.text_encoder = TextEncoder(vocab_size=vocab_size).to(device)

        # Reference encoder: log-mel (100 bins) -> 128-dim reference values.
        # `ReferenceValueEncoder` is neither defined nor imported in this notebook
        # (NameError on construction); use the imported ReferenceEncoder with the
        # cell-1 configuration, which maps a (1, 301, 100) log-mel to (1, 50, 128).
        self.ref_encoder = ReferenceEncoder(in_dim=100, num_heads=4).to(device)

        # 128 -> 256 projections for the text and reference streams.
        self.text_adapter = Adapter128to256().to(device)
        self.ref_adapter = Adapter128to256().to(device)

        # Vector-field estimator over 144-channel inputs with a 256-wide model.
        self.vf = VFEstimator(dim_in=144, dim_model=256, conv_hidden=1024, num_heads=4, Nm=4).to(device)

    @torch.no_grad()
    def forward_once(
        self,
        text_ids: torch.Tensor,      # (B, T_text), long
        logmel: torch.Tensor,        # (B, T_ref, 100), float
        noisy_latents: torch.Tensor, # (B, T, 144)
        t_scalar: float = 0.5,       # single time value shared by the whole batch
    ):
        """Run one gradient-free vector-field prediction.

        Steps:
            1) ref_value = ref_encoder(logmel)       -> (B, N_ref, 128); N_ref was 50 in cell 4
            2) text_emb, learned_keys = text_encoder(text_ids, ref_value)
            3) project text / reference embeddings 128 -> 256
            4) vf(noisy_latents, t, text256, ref256, learned_keys)

        Args:
            text_ids: token ids, (B, T_text).
            logmel: reference log-mel features, (B, T_ref, 100).
            noisy_latents: VF estimator input with 144 channels.
            t_scalar: time value broadcast to a (B,) tensor.

        Returns:
            The VFEstimator output (in cell 5 it had the same shape as its noisy input).
        """
        batch_size = logmel.shape[0]

        # 1) Reference values from the log-mel.
        ref_value_128 = self.ref_encoder(logmel)  # (B, N_ref, 128)

        # 2) Speaker-adaptive text embedding + learned keys ((50, 128) in cell 4).
        text_emb_128, learned_keys_50 = self.text_encoder(text_ids, ref_value_128, ref_key=None)

        # 3) 128 -> 256
        text_emb_256 = self.text_adapter(text_emb_128)   # (B, T_text, 256)
        ref_value_256 = self.ref_adapter(ref_value_128)  # (B, N_ref, 256)

        # 4) VFEstimator forward with the same time value for every batch item.
        time_t = torch.full((batch_size,), float(t_scalar), device=text_ids.device)
        vf_pred = self.vf(
            noisy_latents=noisy_latents,
            time_t=time_t,                      # (B,)
            text_embed=text_emb_256,            # (B, T_text, 256)
            ref_value=ref_value_256,            # (B, N_ref, 256)
            learned_keys_50=learned_keys_50,    # (50, 128); presumably projected to 256 inside VFEstimator
            proj_text_to_dim=None,              # embeddings are already 256-dim
            proj_ref_to_dim=None,
        )
        return vf_pred


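# -------------------------------------------------
# 5) Usage example
#    - logmel: a precomputed (T, 100) log-mel, i.e. fbank.extract(audio, sr)
#      without the batch axis that cell 3 adds
#    - the text string is tokenized into a batch of size 1
#    - noisy_latents: a random placeholder tensor with the same length T_ref
#      as the reference and 144 channels
# -------------------------------------------------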
def run_forward_example(text: str, logmel: torch.Tensor, tokenizer=None, t_scalar: float = 0.5):
    """Build a fresh (randomly initialised) pipeline and run one forward pass.

    Args:
        text: input text.
        logmel: log-mel features, either (T, 100) as returned by
            ``fbank.extract(audio, sr)`` or already batched as (B, T, 100)
            (e.g. the ``logmel`` built in cell 3).
        tokenizer: optional tokenizer exposing ``batchify(texts, device=...)`` and
            ``vocab_size``. Defaults to ``SimpleCharTokenizer()``.
            NOTE(review): SimpleCharTokenizer is not defined or imported in this
            notebook — pass a tokenizer explicitly or define it.
        t_scalar: time value forwarded to ``forward_once``.

    Returns:
        The pipeline's vector-field prediction.

    Raises:
        ValueError: if ``logmel`` is not 2-D or 3-D, or ``text`` yields no token ids.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # (a) Normalise logmel to (B, T_ref, 100) float32 on the target device.
    #     A 3-D input used to be unsqueezed to 4-D and crash on shape unpacking.
    if logmel.dim() == 2:
        logmel = logmel.unsqueeze(0)
    elif logmel.dim() != 3:
        raise ValueError(
            f"logmel must have shape (T, n_mels) or (B, T, n_mels), got {tuple(logmel.shape)}"
        )
    logmel = logmel.to(device=device, dtype=torch.float32)
    B, T_ref, _ = logmel.shape

    # (b) Tokenize one copy of the text per reference so both batches agree.
    if tokenizer is None:
        tokenizer = SimpleCharTokenizer()
    text_ids = tokenizer.batchify([text] * B, device=device)  # (B, T_text)
    if text_ids.numel() == 0:
        raise ValueError(f"text {text!r} produced no token ids")

    # (c) Placeholder noisy latents (normally supplied by the noise scheduler / sampler).
    noisy_latents = torch.randn(B, T_ref, 144, device=device)

    # (d) Build the pipeline and run a single forward pass.
    pipe = TTSVectorFieldPipeline(vocab_size=tokenizer.vocab_size, device=device)
    pipe.eval()

    vf_pred = pipe.forward_once(
        text_ids=text_ids,
        logmel=logmel,
        noisy_latents=noisy_latents,
        t_scalar=t_scalar,
    )
    print("vf_pred:", vf_pred.shape)
    return vf_pred


# -------------------------------------------------
# 6) Example run
# -------------------------------------------------
# Example (same imports as cell 1):
# from utils.feature import TorchAudioFbank, TorchAudioFbankConfig
# fbank = TorchAudioFbank(config=TorchAudioFbankConfig(sampling_rate=24000, n_mels=100, n_fft=1024, hop_length=256))
# audio, sr = librosa.load("./test.wav", sr=24000)
# audio = torch.from_numpy(audio)
# logmel = fbank.extract(audio, sr)  # (T, 100); cell 3's `logmel` is (1, T, 100) because it adds .unsqueeze(0)

# Then run the forward pass:
# vf = run_forward_example("오늘 회의 주제 뭐였지?", logmel)
