In [4]:
import sys
from pathlib import Path

# Make the parent of the current working directory importable (presumably
# where the `pflow` package lives — confirm against the repository layout).
# The path is resolved to an absolute one so it keeps pointing at the same
# directory even if the working directory changes later, and the membership
# check keeps re-running this cell from piling up duplicate entries.
_PARENT_DIR = str(Path('..').resolve())
if _PARENT_DIR not in sys.path:
    sys.path.append(_PARENT_DIR)

In [5]:
from pflow.models.pflow_tts import pflowTTS
import torch
from dataclasses import dataclass

@dataclass
class DurationPredictorParams:
    """Plain settings container for the duration predictor.

    In this notebook an instance is stored on ``EncoderOverallParams``, which
    is in turn passed to ``pflowTTS`` as its ``encoder`` argument. How the
    model consumes each field is not visible here; the per-field notes are
    inferred from the names and should be confirmed against the ``pflow``
    package.
    """
    filter_channels_dp: int  # presumably hidden channel width of the predictor — TODO confirm
    kernel_size: int  # presumably convolution kernel width — TODO confirm
    p_dropout: float  # presumably dropout probability — TODO confirm

@dataclass
class EncoderParams:
    """Plain settings container for the text encoder.

    In this notebook an instance is stored on ``EncoderOverallParams``, which
    is in turn passed to ``pflowTTS`` as its ``encoder`` argument. How the
    model consumes each field is not visible here; the per-field notes are
    inferred from the names and should be confirmed against the ``pflow``
    package.
    """
    n_feats: int  # the notebook uses 80 here, for pflowTTS(n_feats=...) and for dim 1 of `y`
    n_channels: int  # presumably encoder hidden width — TODO confirm
    filter_channels: int  # presumably feed-forward width — TODO confirm
    filter_channels_dp: int  # same field name as in DurationPredictorParams; relation not shown here
    n_heads: int  # presumably number of attention heads — TODO confirm
    n_layers: int  # presumably number of encoder layers — TODO confirm
    kernel_size: int  # presumably convolution kernel width — TODO confirm
    p_dropout: float  # presumably dropout probability — TODO confirm
    spk_emb_dim: int  # presumably speaker-embedding size — TODO confirm
    n_spks: int  # presumably number of speakers — TODO confirm
    prenet: bool  # presumably toggles an encoder pre-net — TODO confirm

@dataclass
class CFMParams:
    """Plain settings container passed to ``pflowTTS`` as its ``cfm`` argument.

    "CFM" presumably stands for conditional flow matching — confirm against
    the ``pflow`` package. How the model consumes each field is not visible
    in this notebook.
    """
    name: str  # free-form identifier; the notebook uses 'CFM'
    solver: str  # presumably the ODE solver name; the notebook uses 'euler' — TODO confirm
    sigma_min: float  # presumably the minimum noise scale — TODO confirm

# Concrete settings used for the smoke test further down. The three objects
# are independent of one another, so their construction order is arbitrary.

# Settings later handed to pflowTTS through its `cfm` argument.
cfm_params = CFMParams(name='CFM', solver='euler', sigma_min=1e-4)

# Settings for the duration predictor.
duration_predictor_params = DurationPredictorParams(
    filter_channels_dp=256, kernel_size=3, p_dropout=0.1,
)

# Settings for the text encoder, grouped by theme for readability; every
# argument is passed by keyword, so the grouping does not affect the result.
encoder_params = EncoderParams(
    # widths
    n_feats=80, n_channels=192, filter_channels=768, filter_channels_dp=256,
    # depth / attention / convolution
    n_heads=2, n_layers=6, kernel_size=3,
    # regularisation
    p_dropout=0.1,
    # speaker handling
    spk_emb_dim=64, n_spks=1,
    # optional front-end
    prenet=True,
)

@dataclass
class EncoderOverallParams:
    """Bundle passed to ``pflowTTS`` as its ``encoder`` argument.

    Groups an encoder type tag with the encoder and duration-predictor
    settings so they travel as a single object. How ``pflowTTS`` reads these
    attributes is not visible in this notebook.
    """
    encoder_type: str  # free-form tag; the notebook uses 'RoPE Encoder'
    encoder_params: EncoderParams  # text-encoder settings
    duration_predictor_params: DurationPredictorParams  # duration-predictor settings

# Bundle the encoder tag with the two settings objects built above.
# Arguments are positional, in the dataclass's declared field order:
# (encoder_type, encoder_params, duration_predictor_params).
encoder_overall_params = EncoderOverallParams(
    'RoPE Encoder', encoder_params, duration_predictor_params
)

# Seed PyTorch's global RNG before anything random happens, so that any
# random weight initialisation inside pflowTTS and the random dummy batch
# below (and therefore the printed losses) repeat under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Sizes shared between the model definition and the dummy batch. Keeping them
# in one place prevents the model and its inputs from drifting apart.
N_VOCAB = 100      # upper bound (exclusive) of the random ids in `x`
BATCH_SIZE = 4
MAX_X_LEN = 20     # padded width of `x`
MAX_Y_LEN = 500    # padded width of `y`

model = pflowTTS(
    n_vocab=N_VOCAB,
    n_feats=encoder_params.n_feats,  # single source of truth for the feature size
    encoder=encoder_overall_params,
    decoder=None,
    cfm=cfm_params,
    data_statistics=None,
)

# Random dummy batch. torch.randint's upper bound is exclusive, so the sampled
# lengths are always strictly smaller than the padded widths.
x = torch.randint(0, N_VOCAB, (BATCH_SIZE, MAX_X_LEN))
x_lengths = torch.randint(10, MAX_X_LEN, (BATCH_SIZE,))
y = torch.randn(BATCH_SIZE, encoder_params.n_feats, MAX_Y_LEN)
y_lengths = torch.randint(300, MAX_Y_LEN, (BATCH_SIZE,))

# Forward pass: the model returns three loss terms plus `attn`, which is not
# inspected in this cell.
dur_loss, prior_loss, diff_loss, attn = model(x, x_lengths, y, y_lengths)

print(dur_loss, prior_loss, diff_loss)

tensor(4.0884, grad_fn=<DivBackward0>) tensor(1.5378, grad_fn=<DivBackward0>) tensor(6.8176, grad_fn=<DivBackward0>)


In [6]:
# Make this cell repeatable regardless of how much randomness earlier cells
# consumed.
torch.manual_seed(0)

# The model was freshly constructed and never switched out of training mode.
# Assuming pflowTTS is a regular torch.nn.Module (it is called like one
# above), the configured dropout (p_dropout=0.1) would otherwise stay active
# during synthesis. Call model.train() again before resuming any training.
model.eval()

# Single-item dummy input; randint's upper bound is exclusive.
x = torch.randint(0, 100, (1, 20))
x_lengths = torch.randint(10, 20, (1,))
# NOTE(review): 264 frames presumably matches the model's expected prompt
# length — confirm against pflowTTS.
y_slice = torch.randn(1, 80, 264)

model.synthesise(x, x_lengths, y_slice, n_timesteps=10)

{'encoder_outputs': tensor([[[ 0.5504,  0.1342,  0.1342,  ...,  1.2872,  1.2872,  0.5121],
          [ 0.2854, -0.0067, -0.0067,  ..., -0.2164, -0.2164,  0.1162],
          [ 1.0909,  0.0971,  0.0971,  ..., -0.4140, -0.4140,  0.0093],
          ...,
          [-0.6167,  0.0214,  0.0214,  ...,  0.1322,  0.1322,  0.0024],
          [ 0.7357,  0.7161,  0.7161,  ...,  0.0576,  0.0576,  0.1908],
          [-0.3782, -0.0351, -0.0351,  ...,  0.5459,  0.5459, -0.3888]]]),
 'decoder_outputs': tensor([[[ 0.2233,  0.6986, -0.4587,  ...,  1.7759, -1.5674, -0.4869],
          [ 0.3813,  0.3476,  0.1070,  ..., -1.4641, -0.0952,  1.0354],
          [ 1.4565,  1.2124, -0.3740,  ..., -0.8082,  0.4223, -1.3775],
          ...,
          [ 1.7223, -1.4008,  0.5498,  ...,  0.7512,  0.2925, -0.6928],
          [ 0.8185,  0.6916,  0.1859,  ..., -0.4052, -0.8805,  0.8896],
          [ 0.9668, -0.3577, -0.0522,  ..., -0.3391,  1.8045,  1.4378]]]),
 'attn': tensor([[[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.