In [1]:
# Imports: third-party first, then project-local modules.
import librosa
import torch

from newmodel.textencoder import TextEncoder
from newmodel.vf import VFEstimator
from tokenizerown import LibriTTSTokenizer
from utils.feature import TorchAudioFbank, TorchAudioFbankConfig

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed before the models are built so their random initialisation, and any
# random tensors sampled further down the notebook, are reproducible across runs.
seed = 42
torch.manual_seed(seed)

# Audio front-end settings.
sample_rate = 24000
n_mels      = 100
n_fft       = 1024
hop_length  = 256   # => sample_rate / hop_length = 93.75 mel frames per second

# Model hyper-parameters.
vocab_size  = 160   # NOTE(review): must exceed the largest token id in vocab_small.txt — confirm
d_model     = 256
depth       = 5
num_heads   = 4
downsample_factors = [1, 2, 4, 2, 1]

text_encoder = TextEncoder(vocab_size=vocab_size, emb_dim=128).to(device)
vf_estimator = VFEstimator(
    dim_in=n_mels,
    dim_model=d_model,
    conv_hidden=1024,
    num_heads=num_heads,
    Nm=depth,
    downsample_factors=downsample_factors,
).to(device)

ref_audio_path = './test.wav'
script = "Hello, world!"

tokenizer = LibriTTSTokenizer(
    special_tokens=["<filler>"],
    token_file="./vocab_small.txt",
    lowercase=True,
    oov_policy="skip",        # drop out-of-vocabulary tokens (alternatives: "use_unk", "error")
    unk_token="[UNK]",        # only needed when oov_policy="use_unk"
)
fbank = TorchAudioFbank(
    config=TorchAudioFbankConfig(
        sampling_rate=sample_rate,
        n_mels=n_mels,
        n_fft=n_fft,
        hop_length=hop_length,
    )
)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load the reference clip, extract its log-mel features, and tokenize the script.
# Resample to `sample_rate` (the rate the fbank extractor is configured with)
# instead of a hardcoded literal, so loading and feature extraction cannot drift apart.
audio, sr = librosa.load(ref_audio_path, sr=sample_rate, mono=True)
audio = torch.from_numpy(audio)
logmel = fbank.extract(audio, sr).unsqueeze(0)  # add a leading batch axis
print("logmel : ", logmel.shape)
# NOTE(review): assumes fbank.extract returns (frames, n_mels) — confirm.
# logmel stays on the CPU; only its shape is used in this cell.
B, T_ref, _ = logmel.shape

token_ids = tokenizer.texts_to_token_ids([script])
# Force an integer dtype: with oov_policy="skip" a fully out-of-vocabulary script
# yields [[]], which torch.tensor would otherwise infer as float32. Integer-index
# ops such as embedding lookups reject float32.
token_ids = torch.tensor(token_ids, dtype=torch.long, device=device)
print("token_ids : ", token_ids)

In [None]:
# Run the text encoder on the token ids and report the size of the embeddings it returns.
text_emb_128 = text_encoder(
    token_ids,
)
print(
    "text_emb_128 : ",
    text_emb_128.size(),
)

In [None]:
# Smoke-test the vector-field estimator on random inputs.
# 400 latent frames. This is not 4 s of audio: with sample_rate=24000 and the
# fbank's hop_length=256 the frame rate is 93.75 frames/s, so 400 frames ≈ 4.27 s.
num_latent_frames = 100 * 4
# NOTE(review): latents are laid out (B, frames, n_mels). The datamodule's `y`
# batches later in this notebook print as torch.Size([16, 100, 846]), i.e.
# (B, n_feats, frames). Confirm which layout VFEstimator expects.
noisy_latents = torch.randn(B, num_latent_frames, n_mels, device=device)
time_t = torch.rand((B,), device=device)  # one timestep per batch item, uniform in [0, 1)
print("time_t : ", time_t, time_t.shape)

output = vf_estimator(noisy_latents, time_t, text_emb_128)
print("output : ", output.shape)  # expected (B, num_latent_frames, n_mels) — TODO confirm

In [None]:
# Report each model's total parameter count, labelled with the model name so
# the two lines in the output can be told apart.
# (The printed Korean label "전체 파라미터 수" means "total parameter count".)
# `total_params` is left holding the text_encoder count after the loop.
for module_name, module in (("vf_estimator", vf_estimator), ("text_encoder", text_encoder)):
    total_params = sum(p.numel() for p in module.parameters())
    print(f"{module_name} 전체 파라미터 수: {total_params:,}")

In [4]:
from datasets import load_dataset

# Fetch the LJSpeech dataset and print the number of examples in the
# train and validation splits, one per line.
ds = load_dataset("atmansingh/ljspeech")
print(len(ds["train"]), len(ds["validation"]), sep="\n")

11775
1309


In [5]:
from data.text_mel_datamodule import TextMelDataset, TextMelDataModule

# Build the text/mel datamodule over the loaded LJSpeech splits.
# NOTE(review): the rate here (22050, LJSpeech's distributed rate) differs from
# `sample_rate` (24000), which the reference-audio fbank extractor in the first cell uses.
# With hop_length=256 that is ~86.1 vs 93.75 frames/s, so dataset mel features and
# fbank features would not share a time scale. Confirm which rate is intended.
dm = TextMelDataModule(
    name="ljspeech",
    dataset=ds,
    batch_size=16,
    num_workers=4,
    pin_memory=True,
    n_spks=1,             # LJSpeech is single-speaker -> 1
    n_fft=1024,
    n_feats=n_mels,       # mel bins; tied to the model's `n_mels` so the two cannot diverge
    sample_rate=22050,
    hop_length=256,
    f_min=0,
    f_max=8000,
    # Placeholder statistics; presumably an identity normalisation. TODO: compute real ones.
    data_statistics={"mel_mean": 0.0, "mel_std": 1.0},
    seed=42,
    load_durations=False, # False when alignment (duration) information is not needed
)


In [6]:
# Prepare the datamodule before any dataloader is requested
# (presumably builds the train/validation datasets from `ds`; see TextMelDataModule.setup).
dm.setup()

In [None]:
# Pull one training batch and sanity-check its contents.
data = next(iter(dm.train_dataloader()))

# Per the recorded output: x is (16, 163) token ids and y is (16, 100, 846).
print(data['x'].shape, data['y'].shape)

# Decode the first utterance back to text. Rows of `x` are padded to a common
# length, so trim the padding first when the collate function reports lengths.
# NOTE(review): the "x_lengths" key is assumed (Matcha-style collate). Without it,
# the padded row is decoded as-is.
first_ids = data['x'][0]
if "x_lengths" in data:
    first_ids = first_ids[: int(data["x_lengths"][0])]
print(''.join(tokenizer.token_ids_to_tokens([first_ids.tolist()])[0]))

  return {"x": torch.tensor(token_ids), "y": torch.tensor(mel).transpose(1, 0), "filepath": "filepath", "durations":None}
  return {"x": torch.tensor(token_ids), "y": torch.tensor(mel).transpose(1, 0), "filepath": "filepath", "durations":None}
  return {"x": torch.tensor(token_ids), "y": torch.tensor(mel).transpose(1, 0), "filepath": "filepath", "durations":None}
  return {"x": torch.tensor(token_ids), "y": torch.tensor(mel).transpose(1, 0), "filepath": "filepath", "durations":None}


torch.Size([16, 163]) torch.Size([16, 100, 846])
