In [None]:
# Enable IPython's autoreload extension in mode 2 (reload all modules before each
# cell runs) so edits to pflow_encodec are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
from tqdm.auto import tqdm

from pflow_encodec.modules.spk_enc import SpeakerEncoder
from pflow_encodec.modules.transformer import Transformer

In [None]:
# Text encoder: 6-layer, 192-dim Transformer using "ada_embed" norm/scale
# conditioning with a 192-dim conditioning vector (dim_cond).
text_encoder_kwargs = dict(
    depth=6,
    dim=192,
    heads=2,
    dim_head=96,
    ff_type="conv",
    ff_mult=4.0,
    ff_kernel_size=9,
    ff_groups=4,
    attn_dropout=0.1,
    ff_dropout=0.0,
    norm_type="ada_embed",
    scale_type="ada_embed",
    dim_cond=192,
)
text_encoder = Transformer(**text_encoder_kwargs)

# Projects one 192-dim condition vector into 6 * 192 features, which the next
# cell splits with `.chunk(6, dim=-1)` into the per-layer conditioning inputs.
cond_linear = nn.Linear(in_features=192, out_features=192 * 6)

In [None]:
# Smoke-test the text encoder: a random (1, 64, 192) input plus one random
# (1, 1, 192) condition vector, projected by cond_linear into six 192-dim chunks.
x = torch.randn(1, 64, 192)
cond = torch.randn(1, 1, 192)
(
    attn_norm_scale,
    attn_norm_bias,
    attn_scale,
    ff_norm_scale,
    ff_norm_bias,
    ff_scale,
) = cond_linear(cond).chunk(6, dim=-1)

# Norm conditions are the (scale, bias) pair concatenated on the last dim (384 wide).
# NOTE(review): final_norm_cond receives the raw 192-dim `cond`, not a cond_linear
# projection like the per-layer entries — confirm this is what Transformer expects.
cond_input = dict(
    attn_norm_cond=torch.cat((attn_norm_scale, attn_norm_bias), dim=-1),
    attn_scale_cond=attn_scale,
    ff_norm_cond=torch.cat((ff_norm_scale, ff_norm_bias), dim=-1),
    ff_scale_cond=ff_scale,
    final_norm_cond=cond,
)
out = text_encoder(x, cond_input=cond_input)

In [None]:
sum(map(torch.Tensor.numel, text_encoder.parameters())) / 1e6  # text encoder size, millions of params

In [None]:
# Decoder: 12-layer, 512-dim Transformer (8 heads x 64) with "ada_embed"
# norm/scale conditioning on a 512-dim condition; built here to report its size.
decoder_kwargs = dict(
    depth=12,
    dim=512,
    heads=8,
    dim_head=64,
    ff_type="conv",
    ff_mult=4.0,
    ff_kernel_size=3,
    ff_groups=4,
    attn_dropout=0.1,
    ff_dropout=0.0,
    norm_type="ada_embed",
    scale_type="ada_embed",
    dim_cond=512,
)
decoder = Transformer(**decoder_kwargs)

In [None]:
sum(map(torch.Tensor.numel, decoder.parameters())) / 1e6  # decoder size, millions of params

In [None]:
# Speaker encoder: 2-layer, 192-dim Transformer over 128-dim input frames,
# using plain "layer" norm and no adaptive scaling (scale_type="none").
spk_encoder_kwargs = dict(
    dim_input=128,
    depth=2,
    dim=192,
    heads=2,
    dim_head=96,
    ff_type="conv",
    ff_mult=4.0,
    ff_kernel_size=9,
    ff_groups=4,
    attn_dropout=0.1,
    ff_dropout=0.0,
    norm_type="layer",
    scale_type="none",
)
spk_encoder = SpeakerEncoder(**spk_encoder_kwargs)

In [None]:
sum(map(torch.Tensor.numel, spk_encoder.parameters())) / 1e6  # speaker encoder size, millions of params

In [None]:
# Speaker-encoder smoke test: a random 225-frame prompt of 128-dim features;
# display the size of the encoder output.
prompt = torch.randn((1, 225, 128))
spk_encoder(prompt).size()

In [None]:
import os
from pathlib import Path

from pflow_encodec.data.datamodule import TextLatentLightningDataModule

# Directory holding the LibriTTS-R duration manifests. Set the LIBRITTS_R_DIR
# environment variable to point elsewhere instead of editing a user-specific
# absolute path; the default keeps the original location.
LIBRITTS_R_DIR = Path(os.environ.get("LIBRITTS_R_DIR", "/home/seastar105/datasets/libritts_r"))

dm = TextLatentLightningDataModule(
    train_tsv_path=str(LIBRITTS_R_DIR / "train_duration.tsv"),
    val_tsv_path=str(LIBRITTS_R_DIR / "dev_duration.tsv"),
    num_workers=8,
    return_upsampled=False,
)
dm.setup("fit")
dl = dm.train_dataloader()

In [None]:
from pflow_encodec.models.pflow import PFlow

# Build the full model with its constructor defaults (no arguments); its forward
# pass is exercised on a real batch in the final cell.
model = PFlow()

In [None]:
batch = next(iter(dl))  # one training batch from a fresh iterator each time the cell runs

In [None]:
# The batch must hold exactly these six items, in this order.
(
    text_tokens,
    text_token_lens,
    durations,
    duration_lens,
    latents,
    latent_lens,
) = batch

In [None]:
import torch


def slice_segments(x, ids_str, segment_size=4):
    """Copy a fixed-length window out of every sequence in a batch.

    Args:
        x: tensor indexed as ``x[batch, time, feature]``.
        ids_str: per-row start indices along the time axis (indexable by row).
        segment_size: window length along the time axis.

    Returns:
        A tensor shaped like ``x[:, :segment_size, :]`` whose row ``i`` holds
        ``x[i, ids_str[i]:ids_str[i] + segment_size, :]``.
    """
    segments = torch.zeros_like(x[:, :segment_size, :])
    batch_size = x.size(0)
    for row in range(batch_size):
        start = ids_str[row]
        segments[row] = x[row, start : start + segment_size, :]
    return segments


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Cut one randomly placed ``segment_size``-frame window out of each sequence.

    Args:
        x: tensor indexed as ``x[batch, time, feature]``.
        x_lengths: valid length of each sequence — a 1-D tensor, a sequence of
            ints, or a single int. ``None`` means every sequence spans all ``t``
            frames. Lengths are moved to ``x.device`` before use.
        segment_size: window length in frames.

    Returns:
        ``(ret, mask)``: ``ret`` is the ``slice_segments`` output (shaped like
        ``x[:, :segment_size, :]``); ``mask`` is a ``(b, t)`` bool tensor that is
        True exactly on the frames each row's window was taken from.

    Note:
        Each start index is drawn from ``[0, length - segment_size]``. When a
        sequence is shorter than ``segment_size`` its start is clamped to 0, so its
        window (and mask) runs past ``length`` into whatever follows (e.g. padding).
    """
    b, t, _ = x.size()
    if x_lengths is None:
        x_lengths = t
    # Keep all arithmetic on x's device: CPU lengths combined with a GPU random
    # tensor would otherwise raise a device-mismatch error. Also accepts ints/lists.
    x_lengths = torch.as_tensor(x_lengths, device=x.device)
    ids_str_max = x_lengths - segment_size + 1
    # Sample on the default generator then move, preserving the original RNG stream.
    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
    # Sequences shorter than segment_size yield a non-positive bound; start them at 0.
    ids_str = ids_str.clamp(min=0)
    ret = slice_segments(x, ids_str, segment_size)
    positions = torch.arange(t, device=x.device).unsqueeze(0)  # (1, t), broadcasts over the batch
    mask = (positions >= ids_str.unsqueeze(1)) & (positions < (ids_str + segment_size).unsqueeze(1))
    return ret, mask

In [None]:
# Take a random 225-frame window of each latent sequence as the prompt; the start
# is drawn from the sequence's valid range (0 when it is shorter than 225 frames).
prompts, prompt_masks = rand_slice_segments(
    latents,
    x_lengths=latent_lens,
    segment_size=225,
)

In [None]:
prompts.size()  # shaped like latents[:, :225, :]

In [None]:
# End-to-end smoke test: PFlow forward on the real batch plus the sampled prompts.
model(
    text_tokens,
    text_token_lens,
    durations,
    duration_lens,
    latents,
    latent_lens,
    prompts,
    prompt_masks,
)