In [49]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
from tqdm.auto import tqdm
import numpy as np

from typing import Union, List, Tuple

In [85]:
class SSResidualUnit(nn.Module):
    """Residual unit computing ``x + Conv1d_k1(ELU(Conv1d_k7_dilated(x)))``.

    Both convolutions keep ``channels`` channels and use ``padding="same"``, so the
    output has exactly the same shape as the input (required for the skip addition).

    Args:
        channels: number of input and output channels.
        dilation: dilation of the kernel-size-7 convolution.
    """

    def __init__(self, channels: int, dilation: int):
        super().__init__()

        self.channels = channels
        self.dilation = dilation

        dilated_conv = nn.Conv1d(channels, channels, kernel_size=7, dilation=dilation, padding="same")
        pointwise_conv = nn.Conv1d(channels, channels, kernel_size=1, padding="same")
        # Kept as a Sequential named `sequence` (indices 0/1/2) so state_dict keys are stable.
        self.sequence = nn.Sequential(dilated_conv, nn.ELU(), pointwise_conv)

    def forward(self, x):
        """Return ``x`` plus the residual branch evaluated on ``x``."""
        return x + self.sequence(x)
    
residual_unit = SSResidualUnit(channels=2, dilation=1)
x = torch.randn((1, 2, 64))  # (batch, channels, time); later demo cells may reuse `x`
residual_out = residual_unit(x)
assert residual_out.shape == x.shape  # residual units preserve shape
residual_out.shape

torch.Size([1, 2, 64])

In [86]:
class SSEncoderBlock(nn.Module):
    """Encoder block: three residual units followed by a strided downsampling conv.

    Expects input of shape ``(B, channels // 2, T)``. The residual units (dilations
    1, 3, 9) operate at ``channels // 2`` channels; the final ``Conv1d`` (kernel
    ``2 * stride``) maps ``channels // 2 -> channels`` and downsamples time by ``stride``.

    Args:
        channels: number of output channels (input has ``channels // 2``).
        stride: temporal downsampling factor.
    """

    def __init__(self, channels: int, stride: int):

        super(SSEncoderBlock, self).__init__()

        self.channels = channels
        self.stride = stride

        self.resid_0 = SSResidualUnit(self.channels // 2, dilation=1)
        self.resid_1 = SSResidualUnit(self.channels // 2, dilation=3)
        self.resid_2 = SSResidualUnit(self.channels // 2, dilation=9)

        # Conv1d output length is floor((T + 2p - 2s) / s) + 1. With p = ceil(s / 2) this is
        # exactly T / s whenever s divides T; p = s // 2 loses one frame for odd s (e.g. 5).
        # Even strides get the same padding as p = s // 2; stride 1 keeps padding 0.
        padding = (self.stride + 1) // 2 if self.stride > 1 else 0
        self.out_conv = nn.Conv1d(
            self.channels // 2,
            self.channels,
            kernel_size=2 * self.stride,
            stride=self.stride,
            padding=padding,
        )

    def forward(self, x):
        # Apply the dilation-1, -3 and -9 units in sequence to grow the receptive field.
        h = self.resid_0(x)
        h = self.resid_1(h)
        h = self.resid_2(h)
        return self.out_conv(h)
    
encoder_block = SSEncoderBlock(channels=4, stride=2)
encoder_block_in = torch.randn(1, 2, 64)  # (batch, channels // 2, time) — self-contained, no reliance on another cell's `x`
encoder_block_out = encoder_block(encoder_block_in)
assert encoder_block_out.shape == (1, 4, 32)  # channels doubled, time halved by stride 2
encoder_block_out.shape

torch.Size([1, 4, 32])

In [87]:
class SSEncoder(nn.Module):
    """SoundStream-style encoder mapping ``(B, in_channels, T)`` to ``(B, embedding_dims, T')``.

    A kernel-7 input conv lifts the signal to ``channels``. Four ``SSEncoderBlock``
    stages with strides 2, 4, 5, 8 each double the channel count (up to
    ``16 * channels``) while downsampling time. A kernel-3 conv then projects to
    ``embedding_dims``.

    Args:
        in_channels: channels of the input signal.
        channels: base channel width.
        embedding_dims: channels of the output latent sequence.
    """

    def __init__(self, in_channels: int, channels: int, embedding_dims: int):
        super().__init__()

        self.in_channels = in_channels
        self.channels = channels
        self.embedding_dims = embedding_dims

        self.in_conv = nn.Conv1d(in_channels, channels, kernel_size=7, padding=3)

        # (channel multiplier, stride) per stage; stage i maps (mult // 2) * channels -> mult * channels.
        stage_specs = ((2, 2), (4, 4), (8, 5), (16, 8))
        for stage_idx, (multiplier, stage_stride) in enumerate(stage_specs):
            # setattr on an nn.Module registers the child under encoder_block_<i>, in order.
            setattr(self, f"encoder_block_{stage_idx}", SSEncoderBlock(channels * multiplier, stage_stride))

        self.out_conv = nn.Conv1d(channels * 16, embedding_dims, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Run the input conv, the four encoder stages, then the output projection."""
        features = self.in_conv(x)
        for stage in (self.encoder_block_0, self.encoder_block_1, self.encoder_block_2, self.encoder_block_3):
            features = stage(features)
        return self.out_conv(features)
    
encoder = SSEncoder(in_channels=1, channels=4, embedding_dims=128)
waveform = torch.randn(1, 1, 128 * 128)
latents = encoder(waveform)
assert latents.shape == (1, 128, 51)  # (batch, embedding_dims, frames)
latents.shape

torch.Size([1, 128, 51])

In [88]:
class SSDecoderBlock(nn.Module):
    """Decoder block: strided transposed conv followed by three residual units.

    Expects input of shape ``(B, channels, T)``. The ``ConvTranspose1d`` (kernel
    ``2 * stride``) maps ``channels -> channels // 2`` and upsamples time by
    ``stride``. Residual units with dilations 1, 3, 9 follow at ``channels // 2``.

    Args:
        channels: number of input channels (output has ``channels // 2``).
        stride: temporal upsampling factor.
    """

    def __init__(self, channels: int, stride: int):

        super(SSDecoderBlock, self).__init__()

        self.channels = channels
        self.stride = stride

        # ConvTranspose1d output length with kernel 2s is T*s + s - 2p + output_padding.
        # Even s: p = s / 2 gives exactly T*s. Odd s > 1: p = s // 2 would give T*s + 1,
        # so use p = (s + 1) / 2 with output_padding = 1 (must be < stride). Stride 1 is unchanged.
        padding = self.stride // 2
        output_padding = 0
        if self.stride > 1 and self.stride % 2 == 1:
            padding += 1
            output_padding = 1

        self.in_conv = nn.ConvTranspose1d(
            self.channels,
            self.channels // 2,
            kernel_size=2 * self.stride,
            stride=self.stride,
            padding=padding,
            output_padding=output_padding,
        )

        self.resid_0 = SSResidualUnit(self.channels // 2, dilation=1)
        self.resid_1 = SSResidualUnit(self.channels // 2, dilation=3)
        self.resid_2 = SSResidualUnit(self.channels // 2, dilation=9)

    def forward(self, x):
        # Upsample first, then refine with the dilation-1, -3, -9 residual units.
        h = self.in_conv(x)
        h = self.resid_0(h)
        h = self.resid_1(h)
        return self.resid_2(h)
    
decoder_block = SSDecoderBlock(channels=8, stride=2)
decoder_block_in = torch.randn(1, 8, 64)  # (batch, channels, time)
decoder_block_out = decoder_block(decoder_block_in)
assert decoder_block_out.shape == (1, 4, 128)  # channels halved, time doubled by stride 2
decoder_block_out.shape

torch.Size([1, 4, 128])

In [None]:
class SSDecoder(nn.Module):
    """SoundStream-style decoder mapping ``(B, embedding_dims, T')`` back to ``(B, in_channels, T)``.

    Mirrors ``SSEncoder``. A kernel-7 conv lifts the latents to ``16 * channels``.
    Four ``SSDecoderBlock`` stages with strides 8, 5, 4, 2 (the encoder's strides in
    reverse) each halve the channel count (down to ``channels``) while upsampling
    time. A kernel-7 conv then projects to ``in_channels``.

    Args:
        in_channels: channels of the reconstructed signal (named to match the
            ``SSEncoder`` whose input it reconstructs).
        channels: base channel width; should match the paired encoder.
        embedding_dims: channels of the latent input.
    """

    def __init__(self, in_channels: int, channels: int, embedding_dims: int):

        super(SSDecoder, self).__init__()

        self.in_channels = in_channels
        self.channels = channels
        self.embedding_dims = embedding_dims

        self.in_conv = nn.Conv1d(self.embedding_dims, self.channels * 16, kernel_size=7, padding=3)

        # Each SSDecoderBlock(N, s) maps N -> N // 2 channels and upsamples time by s.
        self.decoder_block_0 = SSDecoderBlock(self.channels * 16, 8)
        self.decoder_block_1 = SSDecoderBlock(self.channels * 8, 5)
        self.decoder_block_2 = SSDecoderBlock(self.channels * 4, 4)
        self.decoder_block_3 = SSDecoderBlock(self.channels * 2, 2)

        self.out_conv = nn.Conv1d(self.channels, self.in_channels, kernel_size=7, padding=3)

    def forward(self, x):

        h = self.in_conv(x)
        h0 = self.decoder_block_0(h)
        h1 = self.decoder_block_1(h0)
        h2 = self.decoder_block_2(h1)
        h3 = self.decoder_block_3(h2)

        out = self.out_conv(h3)

        return out

In [54]:
# This is not the quantizer used in the SoundStream paper (which uses residual VQ);
# it is a simple single-codebook VQ for experimenting.
class EmbeddingQuantizer(nn.Module):
    """Single-codebook vector quantizer using nearest-neighbour codebook lookup.

    Args:
        codebook_size: number of codebook vectors.
        embedding_dims: size of each codebook vector; must equal the channel
            dimension ``C`` of the input.
    """

    def __init__(self, codebook_size:int, embedding_dims:int) -> None:

        super(EmbeddingQuantizer, self).__init__()

        self.embedding_dims = embedding_dims
        self.codebook_size = codebook_size

        self.embeddings = nn.Embedding(self.codebook_size, self.embedding_dims)

    def forward(self, x):
        """Quantize each time step of ``x`` to its nearest codebook vector.

        Args:
            x: tensor of shape ``(B, C, T)`` with ``C == embedding_dims``.

        Returns:
            quantized: ``(B, C, T)`` tensor whose values are the selected codebook
                vectors. It uses the straight-through estimator, so gradients flow
                back to ``x`` unchanged.
            codes: ``(B * T,)`` long tensor of codebook indices, ordered batch-major.
            loss: MSE between ``x`` and the selected codebook vectors. Its gradient
                reaches both ``x`` and the codebook.

        Raises:
            ValueError: if ``x`` is not 3-D or its channel dim != ``embedding_dims``.
        """
        # Without this check a mismatched C can still pass the reshapes below and
        # silently quantize vectors built from neighbouring time steps.
        if x.dim() != 3 or x.shape[1] != self.embedding_dims:
            raise ValueError(
                f"expected input of shape (B, {self.embedding_dims}, T), got {tuple(x.shape)}"
            )

        B, C, T = x.shape

        # (B, C, T) -> (B*T, E): one row per (batch, time) position.
        flat_inputs = x.permute(0, 2, 1).reshape(-1, self.embedding_dims)
        codebook = self.embeddings.weight  # (codebook_size, E)

        # Squared Euclidean distances ||a||^2 + ||b||^2 - 2 a.b, shape (B*T, codebook_size).
        xs = (flat_inputs ** 2).sum(dim=1, keepdim=True)
        ys = (codebook ** 2).sum(dim=1)
        dots = flat_inputs @ codebook.t()
        distances = (xs + ys) - (2 * dots)

        codes = torch.argmin(distances, dim=1)  # (B*T,)
        quantized = self.embeddings(codes).view(B, T, C).permute(0, 2, 1).contiguous()

        loss = F.mse_loss(x, quantized)

        # argmin + lookup is not differentiable w.r.t. x; the straight-through estimator
        # keeps the forward values of `quantized` but passes gradients directly to x.
        quantized_st = x + (quantized - x).detach()

        return quantized_st, codes, loss

quantizer = EmbeddingQuantizer(codebook_size=1024, embedding_dims=128)
fake_latents = torch.randn(1, 128, 51)  # (batch, embedding_dims, frames), shaped like the encoder output
vq_embs, vq_codes, vq_loss = quantizer(fake_latents)
assert vq_embs.shape == fake_latents.shape
assert vq_codes.shape == (51,)  # one code per (batch, frame), flattened
vq_embs.shape

torch.Size([1, 128, 51])

In [57]:
class SoundStream(nn.Module):
    """Encoder followed by a vector quantizer (this class has no decoder stage).

    Args:
        in_channels: channels of the input signal.
        channels: base channel width of the encoder.
        embedding_dims: latent / codebook vector size.
        codebook_size: number of codebook vectors in the quantizer.
    """

    def __init__(self, in_channels: int, channels: int, embedding_dims: int, codebook_size: int):
        super().__init__()

        self.encoder = SSEncoder(in_channels, channels, embedding_dims)
        self.quantizer = EmbeddingQuantizer(codebook_size, embedding_dims)

    def forward(self, x):
        """Encode ``x`` and quantize it; returns ``(quantized, codes, quantizer_loss)``."""
        latents = self.encoder(x)
        quantized, code_indices, quantizer_loss = self.quantizer(latents)
        return quantized, code_indices, quantizer_loss
    
model = SoundStream(in_channels=1, channels=4, embedding_dims=128, codebook_size=512)
embs, codes, q_loss = model(torch.randn(1, 1, 128 * 128))
(embs.shape, codes, q_loss)

(torch.Size([1, 128, 51]),
 tensor([254, 254, 388, 388, 254, 388, 254, 388, 254, 388, 388, 388, 254, 388,
         254, 254, 388, 254, 388, 388, 254, 254, 388, 254, 388, 388, 388, 254,
         254, 388, 254, 388, 254, 254, 388, 132, 254, 254, 254, 254, 254, 254,
         254, 254, 388, 254, 254, 254, 388, 254, 254]),
 tensor(0.6970, grad_fn=<MseLossBackward0>))