In [1]:
import sys
sys.path.append('..')

from pflow.models.pflow_tts import pflowTTS
import torch
from dataclasses import dataclass

@dataclass
class DurationPredictorParams:
    """Settings bundle for the duration predictor.

    In this notebook an instance is stored on
    ``EncoderOverallParams.duration_predictor_params`` and reaches ``pflowTTS``
    through its ``encoder`` argument. Field order is part of the constructor
    signature and must not be rearranged.
    """

    # 256 in this notebook; presumably the predictor's hidden channel count
    # -- confirm against the pflow implementation.
    filter_channels_dp: int

    # 3 in this notebook; presumably a convolution kernel width -- confirm.
    kernel_size: int

    # 0.1 in this notebook; presumably a dropout probability -- confirm.
    p_dropout: float

@dataclass
class EncoderParams:
    """Settings bundle for the text encoder.

    In this notebook an instance is stored on
    ``EncoderOverallParams.encoder_params`` and reaches ``pflowTTS`` through
    its ``encoder`` argument. Field order is part of the constructor signature
    and must not be rearranged.

    The per-field notes record the values this notebook uses; the meanings are
    inferred from the names and should be confirmed against the pflow package.
    """

    # --- feature / channel sizes ---
    n_feats: int             # 80 here; the same number is passed to pflowTTS(n_feats=...)
    n_channels: int          # 192 here
    filter_channels: int     # 768 here
    filter_channels_dp: int  # 256 here; same value as DurationPredictorParams.filter_channels_dp

    # --- encoder stack shape (presumably attention heads / layer count) ---
    n_heads: int             # 2 here
    n_layers: int            # 6 here

    # --- layer hyperparameters ---
    kernel_size: int         # 3 here
    p_dropout: float         # 0.1 here

    # --- speaker handling (presumably embedding size / number of speakers) ---
    spk_emb_dim: int         # 64 here
    n_spks: int              # 1 here

    # --- optional front end ---
    prenet: bool             # True here

@dataclass
class CFMParams:
    """Settings bundle handed to ``pflowTTS`` as its ``cfm`` argument.

    Field order is part of the constructor signature and must not be
    rearranged. Meanings below are inferred from the names -- confirm against
    the pflow package.
    """

    # 'CFM' in this notebook.
    name: str

    # 'euler' in this notebook; presumably selects the ODE solver -- confirm.
    solver: str

    # 1e-4 in this notebook.
    sigma_min: float

# Example configuration objects for the smoke test further down.
# Every value is passed by keyword so it stays visibly tied to its field name.

# Duration-predictor settings.
duration_predictor_params = DurationPredictorParams(
    filter_channels_dp=256, kernel_size=3, p_dropout=0.1
)

# Text-encoder settings, one related pair of fields per line.
encoder_params = EncoderParams(
    n_feats=80, n_channels=192,
    filter_channels=768, filter_channels_dp=256,
    n_heads=2, n_layers=6,
    kernel_size=3, p_dropout=0.1,
    spk_emb_dim=64, n_spks=1,
    prenet=True,
)

# Settings passed to pflowTTS as `cfm`.
cfm_params = CFMParams(name='CFM', solver='euler', sigma_min=1e-4)

@dataclass
class EncoderOverallParams:
    """Top-level encoder configuration passed to ``pflowTTS`` as ``encoder``.

    Groups the encoder settings and the duration-predictor settings under one
    object. Field order is part of the constructor signature and must not be
    rearranged.
    """

    # 'RoPE Encoder' in this notebook; how pflow interprets the string is not
    # visible here.
    encoder_type: str

    # Settings for the encoder itself.
    encoder_params: EncoderParams

    # Settings for the duration predictor.
    duration_predictor_params: DurationPredictorParams

# Bundle the encoder and duration-predictor settings into the single object
# that pflowTTS receives as its `encoder` argument below.
encoder_overall_params = EncoderOverallParams(
    encoder_type='RoPE Encoder',
    encoder_params=encoder_params,
    duration_predictor_params=duration_predictor_params
)

# Seed torch's global RNG *before* the model is built, so that anything drawn
# from it (the dummy inputs below and any torch-RNG-driven weight
# initialisation or sampling inside pflowTTS) repeats on every
# Restart & Run All. Without a seed each run displays different numbers.
torch.manual_seed(0)

# Vocabulary size. It is also the exclusive upper bound for the dummy ids in
# `x`, so the two are tied to one name instead of two separate literals.
n_vocab = 100

model = pflowTTS(
    n_vocab=n_vocab,
    # Reuse the encoder's feature count rather than repeating the literal 80.
    n_feats=encoder_params.n_feats,
    encoder=encoder_overall_params,
    decoder=None,
    cfm=cfm_params,
    data_statistics=None,
)

# Dummy batch of 4 items (shapes presumably (batch, time) for `x` and
# (batch, n_feats, frames) for `y` -- confirm against pflowTTS.forward).
x = torch.randint(0, n_vocab, (4, 20))           # integer ids in [0, n_vocab)
x_lengths = torch.randint(10, 20, (4,))          # 10..19; upper bound is exclusive
y = torch.randn(4, encoder_params.n_feats, 500)  # random features
y_lengths = torch.randint(300, 500, (4,))        # 300..499; upper bound is exclusive

# Forward pass; the value of this last expression is what the cell displays
# (a tuple of three scalar tensors in the recorded output).
model(x, x_lengths, y, y_lengths)

(tensor(0.3590, grad_fn=<DivBackward0>),
 tensor(1.5666, grad_fn=<DivBackward0>),
 tensor(8.3302, grad_fn=<DivBackward0>))

In [4]:
# Switch to evaluation mode before synthesising. The encoder was configured
# with p_dropout=0.1, and torch modules keep dropout active until .eval() is
# called, so without this the synthesis would run with dropout enabled.
# Assumes pflowTTS is a torch.nn.Module (it is called like one in the cell
# above) -- confirm. Call model.train() again before any further training.
model.eval()

# Single dummy utterance: one sequence of 20 integer ids with a random valid
# length in 10..19 (upper bound of randint is exclusive).
x = torch.randint(0, 100, (1, 20))
x_lengths = torch.randint(10, 20, (1,))

# Random (1, 80, 264) tensor passed as the third positional argument of
# synthesise; presumably the acoustic prompt -- confirm against pflowTTS.
y_slice = torch.randn(1, 80, 264)

# The returned dict is the cell's displayed value.
model.synthesise(x, x_lengths, y_slice, n_timesteps=10)

{'encoder_outputs': tensor([[[ 0.0571,  0.2162, -0.0617,  ..., -0.0474,  0.9882, -0.2352],
          [-0.1764,  0.4822,  0.5068,  ...,  0.4122,  0.0049,  0.1976],
          [-0.0915,  0.5326, -0.6715,  ...,  0.0528, -0.5645,  0.2096],
          ...,
          [ 0.3362,  0.4159,  0.1211,  ...,  0.1333, -0.7746,  0.0222],
          [-0.5805, -0.2404,  0.5159,  ...,  0.2905, -0.6415,  0.5865],
          [ 0.2180,  0.1045, -0.6131,  ..., -0.4079, -0.3824, -0.8299]]]),
 'decoder_outputs': tensor([[[-0.6374, -1.2095, -2.0675,  ...,  1.0466, -0.5316, -0.9766],
          [-0.0597, -1.1466,  0.7942,  ...,  0.7352,  0.8900, -0.3178],
          [-0.5340, -0.3991,  0.5070,  ...,  0.4689,  0.4287,  1.8043],
          ...,
          [-1.4877, -0.7701,  1.5185,  ..., -0.9086,  0.5233, -0.9295],
          [-1.3017,  1.3422,  0.5179,  ...,  0.2207,  0.8652, -0.4455],
          [ 1.4924, -0.6081, -1.2654,  ...,  0.5243,  0.0148, -0.5286]]]),
 'attn': tensor([[[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.