In [4]:
import sys

# Put the parent directory on the import path (presumably where the `pflow`
# package imported in the next cell lives -- confirm against the repo layout).
# Guarded so that re-running this cell does not keep appending duplicate
# '..' entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [5]:
from pflow.models.pflow_tts import pflowTTS
import torch
from dataclasses import dataclass

@dataclass
class DurationPredictorParams:
    """Settings bundle for the duration predictor.

    In this notebook an instance is stored on
    ``EncoderOverallParams.duration_predictor_params`` and reaches ``pflowTTS``
    through its ``encoder`` argument. The per-field notes are inferred from the
    field names only; how ``pflowTTS`` consumes them is not visible here.
    """
    filter_channels_dp: int  # presumably the predictor's hidden channel count -- TODO confirm
    kernel_size: int  # presumably a convolution kernel width -- TODO confirm
    p_dropout: float  # presumably a dropout probability -- TODO confirm

@dataclass
class EncoderParams:
    """Settings bundle for the text encoder.

    In this notebook an instance is stored on
    ``EncoderOverallParams.encoder_params`` and reaches ``pflowTTS`` through its
    ``encoder`` argument. ``filter_channels_dp``, ``kernel_size`` and
    ``p_dropout`` also exist on ``DurationPredictorParams``; nothing in this
    notebook enforces that the two copies agree. The per-field notes are inferred
    from the field names only -- verify against ``pflowTTS``.
    """
    n_feats: int  # the example uses 80, the same value given to pflowTTS(n_feats=...)
    n_channels: int  # presumably the encoder's hidden width -- TODO confirm
    filter_channels: int  # presumably the feed-forward width -- TODO confirm
    filter_channels_dp: int  # same name as the DurationPredictorParams field
    n_heads: int  # presumably the attention head count -- TODO confirm
    n_layers: int  # presumably the number of encoder layers -- TODO confirm
    kernel_size: int  # same name as the DurationPredictorParams field
    p_dropout: float  # same name as the DurationPredictorParams field
    spk_emb_dim: int  # presumably the speaker-embedding size -- TODO confirm
    n_spks: int  # presumably the number of speakers -- TODO confirm
    prenet: bool  # presumably toggles an encoder pre-net -- TODO confirm

@dataclass
class CFMParams:
    """Settings bundle passed to ``pflowTTS`` as its ``cfm`` argument.

    The per-field notes are inferred from the field names and the example
    values in this notebook; verify against ``pflowTTS``.
    """
    name: str  # the example uses 'CFM'
    solver: str  # the example uses 'euler'; presumably selects an ODE solver -- TODO confirm
    sigma_min: float  # the example uses 1e-4; presumably a minimum noise scale -- TODO confirm

# Example values for the three component configs. The statements do not
# depend on each other, so their order carries no meaning.
cfm_params = CFMParams(name='CFM', solver='euler', sigma_min=1e-4)

duration_predictor_params = DurationPredictorParams(
    filter_channels_dp=256, kernel_size=3, p_dropout=0.1,
)

encoder_params = EncoderParams(
    n_feats=80, n_channels=192,
    filter_channels=768, filter_channels_dp=256,
    n_heads=2, n_layers=6,
    kernel_size=3, p_dropout=0.1,
    spk_emb_dim=64, n_spks=1,
    prenet=True,
)

@dataclass
class EncoderOverallParams:
    """Top-level encoder config: a type label plus the two nested bundles.

    An instance of this class is what the notebook passes to ``pflowTTS`` as
    its ``encoder`` argument.
    """
    encoder_type: str  # the example uses 'RoPE Encoder'
    encoder_params: EncoderParams  # nested encoder settings
    duration_predictor_params: DurationPredictorParams  # nested duration-predictor settings

# Wrap the encoder and duration-predictor settings, together with the encoder
# type label, in the single object later handed to pflowTTS(encoder=...).
encoder_overall_params = EncoderOverallParams(
    duration_predictor_params=duration_predictor_params,
    encoder_params=encoder_params,
    encoder_type='RoPE Encoder',
)

@dataclass
class DecoderParams:
    """Settings bundle passed to ``pflowTTS`` as its ``decoder`` argument.

    The per-field notes are inferred from the field names and the example
    values in this notebook; verify against ``pflowTTS``.
    """
    channels: list  # the example uses [256, 256]
    dropout: float  # presumably a dropout probability -- TODO confirm
    attention_head_dim: int  # presumably the per-head attention width -- TODO confirm
    n_blocks: int  # presumably a block count per stage -- TODO confirm
    num_mid_blocks: int  # presumably the number of middle blocks -- TODO confirm
    num_heads: int  # presumably the attention head count -- TODO confirm
    act_fn: str  # the example uses 'snakebeta'; presumably an activation name -- TODO confirm

# Example decoder settings, later handed to pflowTTS(decoder=...).
decoder_params = DecoderParams(
    channels=[256, 256],
    n_blocks=1, num_mid_blocks=2,
    num_heads=2, attention_head_dim=64,
    act_fn='snakebeta', dropout=0.05,
)

# Seed torch's global RNG so that the randomly initialised weights and the random
# batch below -- and therefore the printed losses -- repeat from run to run on
# the same setup. Randomness drawn from other sources (e.g. `random`, NumPy) is
# not covered by this call.
torch.manual_seed(0)

# Named once so the model and the synthetic batch cannot drift apart.
n_vocab = 100
n_feats = 80

model = pflowTTS(
    n_vocab=n_vocab,
    n_feats=n_feats,
    encoder=encoder_overall_params,
    decoder=decoder_params,
    cfm=cfm_params,
    data_statistics=None,
)

# Synthetic batch of 4 items for a forward-pass smoke test.
# Integer ids in [0, n_vocab), shape (4, 20); presumably token ids.
x = torch.randint(0, n_vocab, (4, 20))
# One length per item, drawn from [10, 20).
x_lengths = torch.randint(10, 20, (4,))
# Gaussian noise of shape (4, n_feats, 500); presumably stands in for
# (batch, features, frames) acoustic features -- TODO confirm.
y = torch.randn(4, n_feats, 500)
# One length per item, drawn from [300, 500).
y_lengths = torch.randint(300, 500, (4,))

# The call returns four values; only the three losses are printed.
dur_loss, prior_loss, diff_loss, attn = model(x, x_lengths, y, y_lengths)

print(dur_loss, prior_loss, diff_loss)

tensor(4.4542, grad_fn=<DivBackward0>) tensor(1.5045, grad_fn=<DivBackward0>) tensor(7.3645, grad_fn=<DivBackward0>)


In [6]:
# Switch to evaluation mode before sampling. Assuming pflowTTS is a
# torch.nn.Module, it is still in training mode at this point (nothing above
# calls eval()), which would keep the dropout configured in the params
# (p_dropout=0.1, dropout=0.05) active during synthesis.
# The mode persists: call model.train() again before any further training.
model.eval()

# Single-item random inputs; same value ranges as the forward-pass cell.
x = torch.randint(0, 100, (1, 20))
x_lengths = torch.randint(10, 20, (1,))
# Random (1, 80, 264) tensor passed as the third positional argument.
y_slice = torch.randn(1, 80, 264)

# Trailing expression: the notebook renders the returned dict, which includes
# the keys 'encoder_outputs', 'decoder_outputs' and 'attn'.
model.synthesise(x, x_lengths, y_slice, n_timesteps=10)

{'encoder_outputs': tensor([[[-0.4563, -0.4563, -0.4325,  ..., -0.1074, -0.2714, -0.2714],
          [ 0.6352,  0.6352,  0.1177,  ...,  0.5231, -0.8408, -0.8408],
          [-0.1301, -0.1301,  0.0195,  ...,  0.3655, -0.3479, -0.3479],
          ...,
          [ 0.0163,  0.0163,  0.5394,  ..., -0.4139, -0.4440, -0.4440],
          [ 0.5111,  0.5111,  0.4567,  ..., -0.6954, -0.0434, -0.0434],
          [ 0.3311,  0.3311, -0.4192,  ..., -0.6273,  0.1724,  0.1724]]]),
 'decoder_outputs': tensor([[[ 1.7638, -2.8382,  0.6564,  ...,  0.6531,  0.0379, -0.4501],
          [-0.0241,  1.2763, -0.8939,  ..., -1.9034,  0.5025, -1.7826],
          [ 0.4985, -0.7862,  0.5175,  ..., -0.5421, -0.1646,  0.6659],
          ...,
          [-0.4191,  1.0218,  2.2283,  ...,  0.9479, -0.4108,  0.1039],
          [ 0.0335,  0.4822,  0.2790,  ...,  0.3322,  1.0656, -0.8636],
          [-0.0633, -0.0229, -1.7531,  ..., -0.3215, -0.2620, -0.1053]]]),
 'attn': tensor([[[[1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.