In [80]:
import os

import torch
from torch import nn
import torch.nn.functional as F
import torchaudio

from tqdm import tqdm
import matplotlib.pyplot as plt

## Define model architecture

In [68]:
# Causal convolution modules adapted from https://github.com/wesbz/SoundStream/blob/main/net.py
# (borrowed as-is rather than re-deriving the causal padding arithmetic by hand)

class CausalConv1d(nn.Conv1d):
    """
    ``nn.Conv1d`` variant that pads the input on the left (start) side only.

    The constructor accepts exactly the arguments of ``nn.Conv1d``. After the
    parent is initialised, the amount of left padding is stored in
    ``causal_padding`` as ``dilation * (kernel_size - 1)``.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # nn.Conv1d stores these hyper-parameters as 1-tuples.
        (dilation,) = self.dilation
        (kernel_width,) = self.kernel_size
        self.causal_padding = dilation * (kernel_width - 1)

    def forward(self, x):
        """Pad the last axis of ``x`` at the start only, then convolve."""
        left_padded = F.pad(x, (self.causal_padding, 0))
        return self._conv_forward(left_padded, self.weight, self.bias)

class CausalConvTranspose1d(nn.ConvTranspose1d):
    """
    1D transposed convolution whose output is trimmed at the end only.

    The constructor accepts exactly the arguments of ``nn.ConvTranspose1d``.
    The number of trailing samples to drop is computed once at construction
    and stored in ``causal_padding``:

        dilation * (kernel_size - 1) + output_padding + 1 - stride

    With ``padding=0`` and the construction-time ``output_padding`` this leaves
    ``input_length * stride`` output samples.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.causal_padding = (self.dilation[0] * (self.kernel_size[0] - 1)
                               + self.output_padding[0] + 1 - self.stride[0])

    def forward(self, x, output_size=None):
        """
        Upsample ``x`` and drop the last ``causal_padding`` output samples.

        x:           input tensor; the last axis is the one being upsampled
        output_size: optional, forwarded to the parent's output-padding
                     computation (the requested size is the size *before*
                     the trailing samples are trimmed)

        Raises ValueError if ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')

        assert isinstance(self.padding, tuple)
        if output_size is None:
            # Nothing to resolve: use the configured value directly instead of
            # going through the private helper.
            output_padding = self.output_padding
        else:
            # NOTE(review): `_output_padding` is a private nn.ConvTranspose1d
            # helper whose positional signature has differed between PyTorch
            # releases -- confirm this call against the installed version.
            output_padding = self._output_padding(
                x, output_size, self.stride, self.padding, self.kernel_size, self.dilation)

        out = F.conv_transpose1d(
            x, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation)

        # `out[..., :-0]` would be an empty tensor and a negative amount would
        # keep only a short prefix, so only trim when there is a tail to drop.
        if self.causal_padding <= 0:
            return out
        return out[..., :out.shape[-1] - self.causal_padding]

In [70]:
class ResidualUnit(nn.Module):
    """
    Residual block: dilated causal conv (kernel 7) -> ELU -> 1x1 causal conv,
    added back onto the input and passed through a final ELU.
    """
    def __init__(self, N, dilation):
        """
        N is the number of channels (stays constant)
        dilation is Conv1d dilation for first convolution
        """
        super().__init__()
        self.conv1 = CausalConv1d(in_channels=N, out_channels=N, kernel_size=7,
                                  dilation=dilation)
        self.conv2 = CausalConv1d(in_channels=N, out_channels=N, kernel_size=1)

    def forward(self, x):
        """Return ``elu(x + conv2(elu(conv1(x))))``."""
        out = F.elu(self.conv1(x))
        # The residual branch must consume `out`; feeding `x` to conv2 would
        # throw away the dilated convolution entirely.
        out = F.elu(x + self.conv2(out))
        return out

In [72]:
# Encoder architecture
class EncoderBlock(nn.Module):
    """
    Three residual units (dilations 1, 3, 9) followed by a strided causal
    convolution that doubles the channel count, then ELU.
    """
    def __init__(self, N, S):
        """
        N is the number of channels produced by this block;
            the block expects N // 2 channels on its input
        S is the stride of the downsampling convolution
            (its kernel size is 2 * S)
        """
        super().__init__()
        in_channels = N // 2
        self.resunits = nn.Sequential(
            *(ResidualUnit(in_channels, dilation=rate) for rate in (1, 3, 9))
        )
        self.conv1 = CausalConv1d(in_channels=in_channels, out_channels=N,
                                  kernel_size=2 * S, stride=S)

    def forward(self, x):
        """Run the residual stack, then the strided convolution and ELU."""
        return F.elu(self.conv1(self.resunits(x)))

class Encoder(nn.Module):
    """
    Waveform encoder: causal conv (1 -> C channels), four encoder blocks that
    widen the channels to 2C, 4C, 8C, 16C with strides 2, 4, 5, 8, and a final
    causal conv (16C -> D). ELU follows both convolutions.
    """
    def __init__(self, C, D):
        """
        C is the number of channels after the first convolution
        D is the dimensionality of the encoded vectors
        """
        super().__init__()
        self.conv1 = CausalConv1d(in_channels=1, out_channels=C, kernel_size=7)
        # (channel multiplier, stride) for each successive block
        stages = ((2, 2), (4, 4), (8, 5), (16, 8))
        self.encblocks = nn.Sequential(
            *(EncoderBlock(N=mult * C, S=stride) for mult, stride in stages)
        )
        self.conv2 = CausalConv1d(in_channels=16 * C, out_channels=D, kernel_size=3)

    def forward(self, x):
        """Encode ``x`` by running the input conv, the blocks and the output conv."""
        hidden = F.elu(self.conv1(x))
        hidden = self.encblocks(hidden)
        return F.elu(self.conv2(hidden))

In [73]:
# Decoder architecture
class DecoderBlock(nn.Module):
    """
    Causal transposed convolution that halves the channel count, then ELU,
    followed by three residual units (dilations 1, 3, 9).
    """
    def __init__(self, N, S):
        """
        N is the number of input channels; the block outputs N // 2 channels
        S is the stride of the transposed convolution (its kernel size is 2 * S)
        """
        super().__init__()
        out_channels = N // 2
        self.convT1 = CausalConvTranspose1d(in_channels=N, out_channels=out_channels,
                                            kernel_size=2 * S, stride=S)
        residual_stack = [ResidualUnit(out_channels, dilation=rate) for rate in (1, 3, 9)]
        self.resunits = nn.Sequential(*residual_stack)

    def forward(self, x):
        """Run the transposed convolution and ELU, then the residual stack."""
        upsampled = F.elu(self.convT1(x))
        return self.resunits(upsampled)

class Decoder(nn.Module):
    """
    Embedding decoder: causal conv (D -> 16C channels), four decoder blocks
    taking 16C, 8C, 4C, 2C input channels with strides 8, 5, 4, 2, and a final
    causal conv (C -> 1). ELU follows both convolutions, including the last one.
    """
    def __init__(self, C, D):
        """
        C is "channel scale"
        D is dimensionality of input embeddings
        """
        super().__init__()
        self.conv1 = CausalConv1d(in_channels=D, out_channels=16 * C, kernel_size=7)
        # (input-channel multiplier, stride) for each successive block
        stages = ((16, 8), (8, 5), (4, 4), (2, 2))
        self.decblocks = nn.Sequential(
            *(DecoderBlock(N=mult * C, S=stride) for mult, stride in stages)
        )
        self.conv2 = CausalConv1d(in_channels=C, out_channels=1, kernel_size=7)

    def forward(self, x):
        """x is the embeddings"""
        hidden = F.elu(self.conv1(x))
        hidden = self.decblocks(hidden)
        return F.elu(self.conv2(hidden))

In [74]:
class SoundStream(nn.Module):
    """Autoencoder that chains an ``Encoder`` and a ``Decoder``."""
    def __init__(self, C_enc, C_dec, D):
        """
        C_enc: "channel scale" for encoder
        C_dec: "channel scale" for decoder
        D:     dimensionality of embeddings
        """
        # nn.Module must be initialised before sub-modules are assigned;
        # otherwise the attribute assignments below raise.
        super().__init__()
        self.encoder = Encoder(C_enc, D)
        self.decoder = Decoder(C_dec, D)

    def forward(self, x):
        """
        x is a waveform of shape (batch_size, 1, sequence_length)

        Returns the decoder's output for the encoder's embeddings of x.
        """
        # This will be more complicated once we get the RVQ working
        return self.decoder(self.encoder(x))

## Define data loader

In [92]:
class AudioDataset(torch.utils.data.Dataset):
    """
    Dataset over the ``.wav`` files found exactly two directory levels below
    ``data_dir``; files at any other depth are ignored.
    """
    # Directory depth (relative to data_dir) at which wav files are collected.
    _WAV_DEPTH = 2

    def __init__(self, data_dir="data/LibriTTS/dev-clean"):
        """
        data_dir: root directory to scan (defaults to the LibriTTS dev-clean split)

        Raises FileNotFoundError if data_dir is not an existing directory.
        """
        if not os.path.isdir(data_dir):
            raise FileNotFoundError(f"Audio data directory not found: {data_dir!r}")
        # Normalise so a trailing separator cannot skew the depth computation.
        self.data_dir = os.path.normpath(data_dir)

        # https://stackoverflow.com/questions/42720627/python-os-walk-to-certain-level
        # Get a list of wav file paths
        self.wav_files = []
        for root, dirs, files in tqdm(os.walk(self.data_dir)):
            rel = os.path.relpath(root, self.data_dir)
            depth = 0 if rel == os.curdir else rel.count(os.sep) + 1
            if depth < self._WAV_DEPTH:
                continue
            # Nothing below this level is collected, so stop os.walk descending.
            dirs[:] = []
            self.wav_files.extend(
                os.path.join(root, name) for name in files if name.endswith(".wav")
            )
        # os.walk order depends on the filesystem; sort so indices are reproducible.
        self.wav_files.sort()

    def __len__(self):
        """Number of wav files found."""
        return len(self.wav_files)

    def __getitem__(self, idx):
        """Load file number idx and return the first element of torchaudio.load's
        result (the waveform tensor); the second element is discarded."""
        return torchaudio.load(self.wav_files[idx])[0]

In [93]:
# Collect the wav files under data/LibriTTS/dev-clean and wrap them in a
# shuffling loader.
train_data = AudioDataset()
# batch_size=1: presumably because clips differ in length and the default
# collate function cannot stack them -- TODO confirm.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=1,
    shuffle=True,
)

137it [00:00, 2408.81it/s]
