In [1]:
import sys
sys.path.append('..')

In [2]:
from pflow.models.pflow_tts import pflowTTS
import torch
from dataclasses import dataclass

@dataclass
class DurationPredictorParams:
    filter_channels_dp: int
    kernel_size: int
    p_dropout: float

@dataclass
class EncoderParams:
    n_feats: int
    n_channels: int
    filter_channels: int
    filter_channels_dp: int
    n_heads: int
    n_layers: int
    kernel_size: int
    p_dropout: float
    spk_emb_dim: int
    n_spks: int
    prenet: bool

@dataclass
class CFMParams:
    name: str
    solver: str
    sigma_min: float

# Example usage
duration_predictor_params = DurationPredictorParams(
    filter_channels_dp=256,
    kernel_size=3,
    p_dropout=0.1
)

encoder_params = EncoderParams(
    n_feats=80,
    n_channels=192,
    filter_channels=768,
    filter_channels_dp=256,
    n_heads=2,
    n_layers=6,
    kernel_size=3,
    p_dropout=0.1,
    spk_emb_dim=64,
    n_spks=1,
    prenet=True
)

cfm_params = CFMParams(
    name='CFM',
    solver='euler',
    sigma_min=1e-4
)

@dataclass
class EncoderOverallParams:
    encoder_type: str
    encoder_params: EncoderParams
    duration_predictor_params: DurationPredictorParams

encoder_overall_params = EncoderOverallParams(
    encoder_type='RoPE Encoder',
    encoder_params=encoder_params,
    duration_predictor_params=duration_predictor_params
)

@dataclass
class DecoderParams:
    channels: list
    dropout: float
    attention_head_dim: int
    n_blocks: int
    num_mid_blocks: int
    num_heads: int
    act_fn: str

decoder_params = DecoderParams(
    channels=[256, 256],
    dropout=0.05,
    attention_head_dim=64,
    n_blocks=1,
    num_mid_blocks=2,
    num_heads=2,
    act_fn='snakebeta'
)

model = pflowTTS(
    n_vocab=100,
    n_feats=80,
    encoder=encoder_overall_params,
    decoder=None,
    cfm=cfm_params,
    data_statistics=None,
)

x = torch.randint(0, 100, (4, 20))
x_lengths = torch.randint(10, 20, (4,))
y = torch.randn(4, 80, 500)
y_lengths = torch.randint(300, 500, (4,))

dur_loss, prior_loss, diff_loss, attn = model(x, x_lengths, y, y_lengths)

print(dur_loss, prior_loss, diff_loss)

tensor(3.8836, grad_fn=<DivBackward0>) tensor(1.4566, grad_fn=<DivBackward0>) tensor(6.6788, grad_fn=<DivBackward0>)


In [3]:
x = torch.randint(0, 100, (1, 20))
x_lengths = torch.randint(10, 20, (1,))
y_slice = torch.randn(1, 80, 264)

model.synthesise(x, x_lengths, y_slice, n_timesteps=10)

{'encoder_outputs': tensor([[[ 4.9688e-01, -1.2686e-01,  5.3712e-01,  ..., -8.9837e-04,
           -7.4237e-01, -7.4237e-01],
          [-6.8723e-01, -9.5133e-01, -3.3626e-03,  ..., -7.4941e-01,
           -9.8903e-01, -9.8903e-01],
          [ 2.0464e-01, -4.6147e-01, -2.9429e-01,  ..., -7.6495e-01,
           -8.0589e-01, -8.0589e-01],
          ...,
          [ 1.9040e-01,  1.0653e-01, -2.4789e-01,  ..., -7.0640e-01,
           -6.6554e-01, -6.6554e-01],
          [-7.0117e-01, -7.8523e-01, -9.4743e-02,  ..., -9.1728e-01,
           -1.1041e+00, -1.1041e+00],
          [-4.5640e-01, -1.1332e-01, -3.9241e-01,  ..., -5.6445e-02,
           -8.6922e-01, -8.6922e-01]]]),
 'decoder_outputs': tensor([[[-1.2112, -0.9451, -0.6372,  ...,  0.8843, -0.1819, -1.4297],
          [ 1.9485,  0.2319, -0.7402,  ..., -0.3386,  0.2191,  1.3389],
          [-0.7073,  0.0750, -1.2341,  ..., -1.7076, -0.9495, -1.5823],
          ...,
          [ 1.5402,  0.0086,  1.5628,  ..., -0.4434,  0.9534,  0.6925],