In [1]:
%load_ext autoreload
%autoreload 2
import gc

import pathlib

import torch
import numpy as np
import soundfile as sf

from rt_vocaltract import datasets
from rt_vocaltract import models

import IPython

# LibriTTS Pseudolabeled Data Exploration

Feature order: EMA (in MNGU0 format), followed by pitch, periodicity, and finally `wavlm_cnn`.

In [2]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs on machines without a GPU (a bare integer index would fail there).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Root directory of the LibriTTS_R data; waveforms and features are read from
# its "wavs" and "features" sub-directories respectively.
libri_root = pathlib.Path("/data/common/LibriTTS_R/")
wav_root, feat_root = (libri_root.joinpath(sub_dir) for sub_dir in ("wavs", "features"))

In [4]:
# Take whichever .wav file the directory listing yields first.
first_wav = next(wav_root.glob("*.wav"))
# The matching feature file has the same stem with a .npy extension. It stores
# a pickled dict (hence allow_pickle=True and .item()); only load trusted files.
first_feat_dict = np.load(
    feat_root.joinpath(first_wav.stem + ".npy"),
    allow_pickle=True,
).item()
first_feat_ema, first_feat_wlm = (first_feat_dict[key] for key in ("ema", "wavlm_cnn"))
# Last expression of the cell: renders an audio player for the selected file.
IPython.display.Audio(first_wav)

## Using our dataset to sketch a SoundStream skeleton

In [5]:
# Build the project's LibriTTS_R dataset wrapper from the dataset root directory.
# NOTE(review): this is the same location as `libri_root` above, spelled as a
# plain string; the constructor logs the wav/feature roots it resolves.
data_path = "/data/common/LibriTTS_R"
lttsr = datasets.LibriTTSRDataset(data_path)

--- loading LibriTTS_R dataset ---
-- wav root: /data/common/LibriTTS_R/wavs ---
-- feature root: /data/common/LibriTTS_R/features ---
--- loaded LibriTTS_R dataset ---


In [11]:
# Pull a single (waveform, features) example and report the raw shapes.
wav, feat = lttsr[4]
print(wav.shape, feat.shape)
# Prepend a batch axis so the waveform is 3-D for the Conv1d experiments below.
wav = wav[None]

torch.Size([1, 1052164]) torch.Size([2191, 14])


In [21]:
# First layer of the skeleton: an unpadded, stride-1 convolution that lifts the
# single-channel waveform to 512 channels (stride=1 and padding=0 are the
# Conv1d defaults).
conv1 = torch.nn.Conv1d(1, 512, kernel_size=7)

# Without padding the time axis shrinks by kernel_size - 1 = 6 samples.
x = conv1(wav)
print(x.shape)

torch.Size([1, 886084]) torch.Size([1845, 14])
torch.Size([1, 512, 886078])


In [54]:
class ResidualUnit(torch.nn.Module):
    """Residual block: a dilated convolution, a 1x1 convolution, and a skip connection.

    Neither convolution is padded, so the convolutional branch is shorter than
    the input along the time axis. The skip path is cropped to the branch
    length (keeping the leading samples) before the two are summed.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the unit.
        kernel_size: Kernel size of the dilated convolution.
        stride: Stride applied to both convolutions.
        dilation: Dilation of the first convolution.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 dilation):
        super().__init__()
        self.dilated_conv = torch.nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation)
        # The pointwise convolution consumes the output of `dilated_conv`, so
        # its input width must be `out_channels`. It used to be wired to
        # `in_channels`, which only worked when both widths were equal.
        self.single_conv = torch.nn.Conv1d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride)
        # The residual sum needs matching channel counts. With equal widths the
        # skip path is a parameter-free identity (same behaviour and state_dict
        # as before); otherwise a 1x1 convolution projects the input.
        if in_channels == out_channels:
            self.skip = torch.nn.Identity()
        else:
            self.skip = torch.nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1)

    def forward(self, x):
        """Return the cropped skip path plus the convolutional branch.

        Args:
            x: Input tensor indexed as (batch, channels, time).

        Returns:
            Tensor with `out_channels` channels and the (shorter) time length
            of the convolutional branch.
        """
        conv_x = self.dilated_conv(x)
        conv_x = self.single_conv(conv_x)
        # Crop the skip path to the branch length so the shapes line up.
        skip = self.skip(x)[:, :, :conv_x.shape[2]]
        return skip + conv_x

class EncoderBlock(torch.nn.Module):
    """Encoder stage: a chain of residual units followed by a strided convolution.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the strided convolution.
        stride: Stride of the final convolution; its kernel size is 2 * stride.
        dilations: Iterable of dilation factors, one residual unit per entry.
        kernel_size: Kernel size of the dilated convolution in each residual unit.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 dilations,
                 kernel_size):
        super().__init__()
        # The residual units are chained and feed `self.conv`, which expects
        # `in_channels` inputs, so they must preserve the input width. The
        # previous `out_channels // 2` only matched that when the block exactly
        # doubled the channel count; for those configurations this is identical.
        self.res_units = torch.nn.ModuleList([
            ResidualUnit(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=kernel_size,
                stride=1,
                dilation=dilation
                )
            for dilation in dilations
        ])
        # Strided convolution that changes the channel count and shortens the time axis.
        self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size=2 * stride, stride=stride)

    def forward(self, x):
        """Apply every residual unit in order, then the strided convolution."""
        for res_unit in self.res_units:
            x = res_unit(x)
        return self.conv(x)

In [45]:
# Hand-wired encoder stage: three 512-channel residual units with dilations
# 1, 3 and 9, followed by a stride-2 convolution. The units are built in the
# same order as before, from one parameterised expression instead of three copies.
res_unit1, res_unit2, res_unit3 = (
    ResidualUnit(in_channels=512,
                 out_channels=512,
                 kernel_size=7,
                 stride=1,
                 dilation=dilation).to(device)
    for dilation in (1, 3, 9)
)
conv2 = torch.nn.Conv1d(in_channels=512,
                        out_channels=512,
                        kernel_size=4,
                        stride=2).to(device)

# This is only a shape check, so run it without autograd bookkeeping; the
# 512-channel activations over a full utterance are large.
with torch.no_grad():
    # `conv1` and `wav` come from the earlier cells. The result is moved to
    # `device` (not a hard-coded GPU index) so it sits with the layers above.
    x = conv1(wav).to(device)
    x = res_unit1(x)
    x = res_unit2(x)
    x = res_unit3(x)
    x = conv2(x)
print(x.shape)

torch.Size([1, 512, 442999])


In [84]:
# Base channel width of the skeleton; the final layer maps back to C channels.
C = 14
with torch.no_grad():
    # Input layer (stride=1 and padding=0 are the Conv1d defaults).
    conv1 = torch.nn.Conv1d(1, C, kernel_size=7)
    # Four encoder stages: each doubles the channel count, with strides 2, 5, 6 and 8.
    enc1, enc2, enc3, enc4 = (
        EncoderBlock(
            in_channels=C * width,
            out_channels=C * width * 2,
            stride=stride,
            kernel_size=7,
            dilations=[1, 3, 9],
        ).to(device)
        for width, stride in ((1, 2), (2, 5), (4, 6), (8, 8))
    )
    # Output layer: collapse the 16*C channels back down to C.
    conv_end = torch.nn.Conv1d(C * 16, C, kernel_size=3, stride=1).to(device)

    # Dummy all-zero waveform with the length of the dataset example printed
    # earlier; note that this rebinds `wav`.
    wav = torch.zeros(1, 1, 1052164)

    # `conv1` stays where it was created; its output is moved to `device`.
    x = conv1(wav).to(device)
    print(x.shape)

    # Push the signal through each stage, printing the shape after every one.
    x = enc1(x)
    print(x.shape)
    x = enc2(x)
    print(x.shape)
    x = enc3(x)
    print(x.shape)
    x = enc4(x)
    print(x.shape)
    x = conv_end(x)
    print(x.shape)

torch.Size([1, 14, 1052158])
torch.Size([1, 28, 526039])
torch.Size([1, 56, 105191])
torch.Size([1, 112, 17517])
torch.Size([1, 224, 2178])
torch.Size([1, 14, 2176])


## SoundStream Model Dimension Sanity Check

In [6]:
# Instantiate the packaged inversion model with the same kernel size, strides
# and dilations as the hand-built skeleton above, then move it to `device`.
# NOTE(review): presumably maps a 1-channel waveform to 14 feature channels at
# roughly a 2*5*6*8 = 480x lower rate — confirm against models.SoundStreamInversion.
inv_model = models.SoundStreamInversion(in_channels=1, out_channels=14, kernel_size=7, strides=[2, 5, 6, 8], dilations=[1, 3, 9]).to(device)

In [8]:
# Batch the dataset (8 examples per batch) for the dimension sanity check.
# NOTE(review): shuffle=True and no seed is set in this notebook, so the batch
# drawn below differs from run to run.
dataloader = torch.utils.data.DataLoader(lttsr, batch_size=8, shuffle=True)

In [10]:
# Draw one batch; this rebinds `wav` and `feat` to batched tensors.
wav, feat = next(iter(dataloader))

# Run the model on the batch and display the output shape as the cell result.
# NOTE(review): this forward pass tracks gradients; wrap it in torch.no_grad()
# if only the shape is needed.
inv_model(wav.to(device)).shape

torch.Size([8, 14, 2176])