In [14]:
import torch
import torch.utils.data

from wav2avatar.streaming.hifigan.wav_ema_dataset import WavEMADataset
from wav2avatar.streaming.hifigan.layers import PastFCEncoder, HiFiGANResidualBlock

import typing


In [2]:
# Input: (B, 512, T)
# Output: (B, 512, ws), where ws is window size

# E.g.: T = 615, ws = 100

In [4]:
def collate_features(features: typing.List[torch.FloatTensor], window_size=50):
    """
    Performs the autoregressive collating from the car_collate function.
    """
    feats_collated = []
    for feature in features:
        collates = [feature[:, :, i:i+window_size] for i in range(0, feature.shape[2] // window_size * window_size, window_size)]
        feature_collated = torch.cat(collates, dim=0)
        feats_collated.append(feature_collated)
    return torch.concatenate(feats_collated, dim=0)

#if __name__ == "__main__":
#    wav_ema_dataset = WavEMADataset()
#
#    data_loader = torch.utils.data.DataLoader(dataset=wav_ema_dataset, batch_size=1, shuffle=True, collate_fn=car_collate)
#
#    print(next(iter(data_loader)))

In [5]:
def car_collate(batch):
    """
    Collates a batch of audio_features and corresponding EMA features in the
    following way:

    For time interval [curr, curr + window_size], we use the audio features
    and pseudolabeled EMA from the same time interval 
    [curr, curr + window_size]. As we are autoregressive w.r.t to the EMA 
    predictions, we use the EMA_ar features from the previous time interval 
    [curr - window_size, curr].

    We then concatenate the audio features, EMA features, EMA_ar features
    independently through all time steps, jumping by window_size.

    Args:
        batch: Batch of tuple - (audio_features, ema_features)

    Returns:
        audio_feats_collated: Collated audio features
        ema_collated: Collated EMA features
        ema_collated_ar: Collated EMA_ar features
    """
    audio_feats, ema = zip(*batch)

    audio_feats = list(audio_feats)
    ema = list(ema)

    for i in range(len(audio_feats)):
        audio_feats[i] = audio_feats[i].detach().cpu()
        ema[i] = ema[i].detach().cpu()

    window_size = 50
    audio_feats_collated = collate_features(audio_feats, window_size)

    ema_collated = collate_features(ema, window_size).float()
    ema_collated_ar = ema_collated[:len(ema_collated) - 1]
    first_ema_batch = torch.zeros(1, 12, window_size)

    ema_collated_ar = torch.concatenate([first_ema_batch, ema_collated_ar], dim=0).float()

    return audio_feats_collated, ema_collated, ema_collated_ar


In [6]:
ar_model = PastFCEncoder(input_len=600, hidden_dim=256, output_dim=128)

In [15]:
ar = torch.randn(1, 12, 50)
x = torch.randn(1, 512, 615)
print(f"Initial x: {x.shape}, ar: {ar.shape}")

ar_feats = ar_model(ar).unsqueeze(2).repeat(1, 1, 615)
print(f"ar feats: {ar_feats.shape}")

x = torch.cat([x, ar_feats], dim=1)
print(f"Concatenated x: {x.shape}")

Initial x: torch.Size([1, 512, 615]), ar: torch.Size([1, 12, 50])


NameError: name 'ar_model' is not defined

In [8]:
input_conv = torch.nn.Conv1d(
    512 + 128,
    512 // 2,
    3,
    3,
    padding=(3 - 1) // 2,
)
x = input_conv(x)
print(f"Input conv: {x.shape}")

Input conv: torch.Size([1, 256, 205])


In [9]:
blocks = []
resblock_kernel_sizes=(3, 7, 11, 15)
resblock_dilations=[(1, 3, 5), (3, 5, 7), (1, 3, 5), (1, 3, 5)]
for j in range(len(resblock_kernel_sizes)):
    blocks += [
        HiFiGANResidualBlock(
            kernel_size=resblock_kernel_sizes[j],
            channels=512 // (2 ** (0 + 1)),
            dilations=resblock_dilations[j],
            bias=True,
            use_additional_convs=True,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.1},
        )
    ]

cs = 0.0
for block in blocks:
    cs += block(x)

x = cs / len(blocks)

In [10]:
output_conv = torch.nn.Sequential(
    # NOTE(kan-bayashi): follow official implementation but why
    #   using different slope parameter here? (0.1 vs. 0.01)
    torch.nn.LeakyReLU(),
    torch.nn.Conv1d(
        512 // (2 ** (1)),
        1,
        3,
        1,
        padding=(3 - 1) // 2,
    ),
)

In [11]:
x = output_conv(x)
print(x.shape)

torch.Size([1, 1, 205])


In [21]:
# Input AR: (B, 12, 50)
# Output feat: (B, 128)

ar_conv = torch.nn.Sequential(
    torch.nn.LeakyReLU(),
    torch.nn.Conv1d(
        12,
        128,
        3,
        1,
        padding=1
    ),
    torch.nn.Conv1d(
        128,
        128,
        3,
        5,
        padding=1
    ),
    torch.nn.Conv1d(
        128,
        128,
        3,
        5,
        padding=1
    )
)
ar_linear = torch.nn.Linear(256, 128)

ar = torch.randn(1, 12, 50)
ar = ar_conv(ar)
print(ar.shape)
ar = ar.reshape(ar.shape[0], -1)
print(ar.shape)
ar = ar_linear(ar)
ar = ar.unsqueeze(2).repeat(1, 1, 615)
print(ar.shape)

torch.Size([1, 128, 2])
torch.Size([1, 256])
torch.Size([1, 128, 615])


In [17]:
ar_linear = torch.nn.Linear(50, 128)

ar = torch.randn(1, 12, 50)
ar = ar_conv(ar)

ar = ar_linear(ar)

print(ar.shape)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x10 and 50x128)

In [8]:
128 * 50

6400