In [2]:
from models.codec import SwinCrossScaleCodec

# Swin-based cross-scale codec under inspection. Every hyper-parameter is
# passed by keyword, one per line, so individual values are easy to tweak.
model = SwinCrossScaleCodec(
    patch_size=[3, 2],
    swin_depth=2,
    swin_heads=[3, 3, 6, 12, 24],
    window_size=4,
    mlp_ratio=4.0,
    in_dim=2,
    in_freq=192,
    h_dims=[45, 45, 72, 96, 192, 384],
    max_streams=6,
    proj=[4, 4, 2, 2, 2, 2],
    overlap=2,
    num_vqs=6,
    codebook_size=1024,
)

Audio Codec 18.0kbps Initialized
Quantization Vis: 
     Freq dims:  [2, 2, 4, 8, 16, 32]
     Channel(hidden) dims:  [384, 384, 192, 96, 72, 45]
     projections from:  [768, 768, 768, 768, 1152, 1440]
     projections to:  [192, 192, 384, 384, 576, 720]
     group_vq_dims:  [384, 384, 768, 768, 1152, 1440]


In [6]:
import torchaudio
import torch
# Test clip; the path is relative to the notebook's working directory.
audio_path = "../swin-debug-vis/test/spanish_instance1.wav"
# Switch the codec (built in an earlier cell) to evaluation mode.
model.eval()
# torchaudio.load returns (waveform, sample_rate); `sr` is not used in this cell.
x, sr = torchaudio.load(audio_path)
# Drop the last 80 samples along the final axis.
# NOTE(review): presumably this aligns the length with the codec's framing — confirm.
x = x[:, :-80]

In [11]:
x = torch.rand(1, 160000-80)

In [12]:
model.ft(x).shape

torch.Size([1, 192, 2000])

In [13]:
# Training-mode pass: the mode is switched first, then one training step is
# run on `x` with `None` as the second argument and all 6 streams.
model.train()
outputs = model.train_one_step(x, None, streams=6)

# Evaluation-mode pass on the same input; inference_mode disables autograd
# tracking for everything executed inside the `with` block.
model.eval()
with torch.inference_mode():
   test_outputs =  model.test_one_step(x, None, streams=6)

In [16]:
test_outputs["recon_feat"].shape

torch.Size([1, 2, 192, 2000])

In [1]:
from models.losses import TimeLoss, FreqLoss
recon_loss = TimeLoss()
mel_loss = FreqLoss()



In [2]:
import torch
x = torch.randn(1,47920)
x_ = torch.randn(1,47920)

In [3]:
recon_loss(x, x_)

tensor(2.0037)

In [4]:
mel_loss(x, x_)

tensor(12.3666)

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Convolution1D(nn.Conv1d):
    """1D convolution with optional causal (left-only) padding.

    In causal mode the wrapped ``nn.Conv1d`` is built with ``padding=0`` and
    the input is instead padded on the left by ``dilation * (kernel_size - 1)``
    samples in ``forward``, so output step ``t`` never reads input steps
    after ``t``. In non-causal mode the layer behaves like a plain
    ``nn.Conv1d`` with the given ``padding``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel length, as an int or a 1-tuple.
        stride: Convolution stride.
        padding: Symmetric padding; only used when ``causal`` is False.
        dilation: Kernel dilation, as an int or a 1-tuple.
        groups: Number of channel groups.
        bias: Whether to add a learnable bias.
        causal: If True, use left-only padding and ignore ``padding``.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 bias=True,
                 causal=True):
        super(Convolution1D, self).__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=0 if causal else padding,
            dilation=dilation, groups=groups, bias=bias)

        # nn.Conv1d has normalised kernel_size/dilation to 1-tuples by now;
        # reading them back makes both int and tuple arguments work.
        self.left_pad = self.dilation[0] * (self.kernel_size[0] - 1) if causal else 0

    def forward(self, input):
        """Left-pad the last axis by ``self.left_pad`` and apply the convolution."""
        if self.left_pad:
            input = F.pad(input, (self.left_pad, 0))

        return super(Convolution1D, self).forward(input)

In [15]:
x = torch.randn(1, 2, 300)

In [22]:
causalconv1d = Convolution1D(2, 4, 2, causal=True)
causalconv1d(x).shape

torch.Size([1, 4, 300])

In [26]:
from torch.nn.modules.utils import _pair

class Convolution2D(nn.Conv2d):
    """2D convolution with optional causal (top/left-only) padding.

    In causal mode the wrapped ``nn.Conv2d`` is built with ``padding=0`` and
    ``forward`` pads the input by ``(kernel - 1) * dilation`` on the leading
    side of each of the last two axes, so the output keeps the input's
    spatial size (for stride 1) and never reads positions past the current
    one. In non-causal mode the layer is a plain ``nn.Conv2d``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Int or pair ``(kh, kw)``.
        stride: Int or pair.
        padding: Padding for non-causal mode; ``None`` (the default) means
            no padding. Ignored when ``causal`` is True.
        dilation: Int or pair.
        groups: Number of channel groups.
        bias: Whether to add a learnable bias.
        causal: If True, use leading-side padding only.
    """

    def __init__(self, 
                 in_channels, 
                 out_channels, 
                 kernel_size, 
                 stride=1, 
                 padding=None, 
                 dilation=1, 
                 groups=1, 
                 bias=True,
                 causal=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        dilation = _pair(dilation)

        if causal:
            # All padding is applied manually in forward(), on the leading side.
            left_pad = tuple(int((k - 1) * d) for k, d in zip(kernel_size, dilation))
            conv_padding = 0
        else:
            left_pad = _pair(0)
            # The default padding=None used to be forwarded to nn.Conv2d as-is,
            # which made the layer fail in forward(); treat it as "no padding".
            conv_padding = 0 if padding is None else padding

        super(Convolution2D, self).__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=conv_padding,
            dilation=dilation, groups=groups, bias=bias)

        self.left_pad = left_pad

    def forward(self, inputs):
        """Pad the leading side of the last two axes, then convolve."""
        # F.pad lists the last axis first: (w_left, w_right, h_top, h_bottom).
        x = F.pad(inputs, (self.left_pad[1], 0, self.left_pad[0], 0))

        return super(Convolution2D, self).forward(x)


In [74]:
x = torch.randn(1, 2, 192, 600)
causalconv2d = Convolution2D(2, 1, 3, 1, causal=True)

In [75]:
causalconv2d(x).shape

torch.Size([1, 1, 192, 600])

In [69]:
import torch.nn as nn

class ConvolutionTranspose2D(nn.ConvTranspose2d):
    """2D transposed convolution whose output is sized to ``input * stride``.

    ``forward`` always returns a tensor whose last two axes are exactly
    ``x.shape[2] * stride[0]`` and ``x.shape[3] * stride[1]``: any surplus
    is removed from the trailing side and any shortfall is zero-padded on
    the trailing side.

    In causal mode the wrapped layer is built with ``padding=0``. The raw
    transposed-convolution output at position ``o`` only depends on inputs
    at positions ``<= o // stride``, so it is the *trailing* kernel overhang
    that has to be discarded. Cropping the leading side instead would shift
    the output and let it depend on future inputs.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Int or pair ``(kh, kw)``.
        stride: Int or pair.
        padding: Padding for non-causal mode; ``None`` (the default) means
            no padding. Ignored when ``causal`` is True.
        dilation: Int or pair.
        groups: Number of channel groups.
        bias: Whether to add a learnable bias.
        causal: If True, build the layer without padding and trim only the
            trailing overhang.
    """

    def __init__(self, 
                 in_channels, 
                 out_channels, 
                 kernel_size, 
                 stride=1, 
                 padding=None,
                 dilation=1, 
                 groups=1, 
                 bias=True, 
                 causal=False):
        # padding=None used to reach nn.ConvTranspose2d unchanged in
        # non-causal mode, which broke forward(); treat it as "no padding".
        conv_padding = 0 if (causal or padding is None) else padding
        super(ConvolutionTranspose2D, self).__init__(
            in_channels, out_channels, kernel_size, stride=stride,
            padding=conv_padding,
            dilation=dilation, groups=groups, bias=bias)
        self.causal = causal

        # The parent class has normalised these to pairs already.
        kh, kw = self.kernel_size
        dh, dw = self.dilation

        # Kernel overhang per axis (zero when non-causal). Kept as public
        # attributes; forward() derives the actual trim from the target size.
        self.crop_h = (kh - 1) * dh if causal else 0
        self.crop_w = (kw - 1) * dw if causal else 0

    def forward(self, x):
        """Apply the transposed convolution and fit the output to ``input * stride``."""
        x_ = super(ConvolutionTranspose2D, self).forward(x)

        target_h = x.shape[2] * self.stride[0]
        target_w = x.shape[3] * self.stride[1]

        if self.causal:
            # Keep the leading part: it is the part that only sees past inputs.
            x_ = x_[:, :, :target_h, :target_w]

        # Positive values zero-pad the trailing side; negative values (only
        # possible in non-causal mode here) crop it.
        h_pad, w_pad = target_h - x_.shape[2], target_w - x_.shape[3]
        if h_pad or w_pad:
            x_ = F.pad(x_, (0, w_pad, 0, h_pad))

        return x_


In [72]:
x = torch.randn(1, 4, 48, 600)
causalconvtranspose2d = ConvolutionTranspose2D(4, 2, (5,2), (1,1), causal=True)

In [73]:
causalconvtranspose2d(x).shape

torch.Size([1, 2, 48, 600])

In [76]:
import torch.nn as nn

class RNNFilter(nn.GRU):
    """GRU that returns only the output sequence.

    Construction arguments are forwarded to ``nn.GRU`` unchanged (note that
    ``batch_first`` defaults to True here). ``forward`` discards the final
    hidden state that ``nn.GRU`` returns alongside the output sequence.
    """

    def __init__(self, 
                 input_size, 
                 hidden_size,
                 num_layers,
                 bias=True,
                 batch_first=True,
                 dropout=0.,
                 bidirectional=False,):
        gru_options = dict(
            bias=bias,
            batch_first=batch_first,
            dropout=dropout,
            bidirectional=bidirectional,
        )
        super(RNNFilter, self).__init__(input_size, hidden_size, num_layers, **gru_options)

    def forward(self, x):
        """Run the GRU on ``x`` and return the first element of its result."""
        sequence_output = super(RNNFilter, self).forward(x)[0]
        return sequence_output

In [77]:
x = torch.randn(1, 600, 8)
rnn = RNNFilter(8, 8, 1)

In [78]:
rnn(x).shape

torch.Size([1, 600, 8])

In [79]:
from models.tfs import TCM
tcm = TCM(8, 16, (1,2,4,8))

In [81]:
x = torch.randn(1, 8, 600)
tcm(x).shape

torch.Size([1, 8, 600])

In [82]:
from models.codec import ConvCrossScaleCodec

model = ConvCrossScaleCodec(fuse_net=True, scalable=True, use_tf=True)

Use Fuse Merge Net
Audio Codec 18.0kbps Initialized
Quantization Vis: 
     Freq dims:  [6, 6, 12, 24, 48, 96]
     Channel(hidden) dims:  [64, 64, 32, 24, 24, 16]
     projections from:  [384, 384, 384, 576, 1152, 1536]
     projections to:  [192, 192, 192, 288, 576, 768]
     group_vq_dims:  [768, 768, 768, 1152, 2304, 3072]


In [83]:
model

ConvCrossScaleCodec(
  (ft): Spectrogram()
  (ift): InverseSpectrogram()
  (recon_loss): MSELoss()
  (mel_loss): MELLoss(
    (mel_transf1): MelSpectrogram(
      (spectrogram): Spectrogram()
      (mel_scale): MelScale()
    )
    (mel_transf2): MelSpectrogram(
      (spectrogram): Spectrogram()
      (mel_scale): MelScale()
    )
    (mel_transf3): MelSpectrogram(
      (spectrogram): Spectrogram()
      (mel_scale): MelScale()
    )
    (mel_transf4): MelSpectrogram(
      (spectrogram): Spectrogram()
      (mel_scale): MelScale()
    )
    (mel_transf5): MelSpectrogram(
      (spectrogram): Spectrogram()
      (mel_scale): MelScale()
    )
    (mel_transf6): MelSpectrogram(
      (spectrogram): Spectrogram()
      (mel_scale): MelScale()
    )
  )
  (encoder): ConvEncoder(
    (blocks): ModuleList(
      (0): ConvEncoderLayer(
        (conv): Convolution2D(
          (conv): Conv2d(2, 16, kernel_size=(5, 2), stride=(1, 1))
        )
        (norm): BatchNorm2d(16, eps=1e-05, moment

In [138]:
from timm.models.layers import trunc_normal_, to_2tuple
import torch
import torch.nn as nn
class WindowAttention(nn.Module):
    """Window-based multi-head self-attention with a learned relative position bias.

    Args:
        dim: Number of input channels; must be divisible by ``num_heads``.
        window_size: Pair ``(Wh, Ww)`` giving the window height and width.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        qk_scale: Optional override for the default ``head_dim ** -0.5`` scale.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        verbose: If True (the default, which keeps this notebook's recorded
            traces), ``forward`` prints intermediate shapes and tensors.
            Pass False to silence it.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
                 verbose=True):

        super().__init__()
        if dim % num_heads != 0:
            # Otherwise the per-head reshape in forward() fails with an opaque error.
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")

        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        self.verbose = verbose
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # One learnable bias per head for every possible relative offset
        # between two tokens of a window.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # Pair-wise relative position index for each token inside the window.
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # indexing="ij" is the layout the code below assumes; passing it
        # explicitly avoids the deprecation warning of the implicit default.
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # row offset -> flat table index
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # torch's built-in initialiser (default truncation bounds [-2, 2]).
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def _trace(self, *values):
        """Print a debug trace line, but only when ``self.verbose`` is set."""
        if self.verbose:
            print(*values)

    def forward(self, x, mask=None):
        """ Forward function.
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        Returns:
            Tensor with the same (num_windows*B, N, C) shape as ``x``.
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        self._trace("q k v shape: ", q.shape, k.shape, v.shape)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        self._trace("attn q@k shape: ", attn.shape)

        tokens = self.window_size[0] * self.window_size[1]
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            tokens, tokens, -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Broadcast the per-window mask over the batch and head axes.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            self._trace("adding mask: ", attn)
            attn = self.softmax(attn)
            self._trace("after softmax: ", attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)
        self._trace("attn_dist: ", attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        self._trace("attn_out: ", x)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

In [139]:
# Deliberately tiny configuration so the printed attention tensors stay
# readable: 3 channels, a window size of 2 and a single head.
d_model = 3
window_size = 2
num_heads = 1

# WindowAttention indexes window_size as a pair, hence to_2tuple.
attn = WindowAttention(d_model, window_size=to_2tuple(window_size), num_heads=num_heads)

In [140]:
def create_causal_mask_for_windows(num_windows, window_size):
    """
    Create a causal mask for a given number of windows and window size.

    Tokens of a window are taken in flattened order: token ``i`` may attend
    to tokens ``0..i`` (mask value 0) and is blocked from later tokens
    (mask value ``-inf``).

    Args:
        num_windows (int): The number of windows for which the mask is to be created.
        window_size (int or tuple): Side length of a square window, or a
            ``(Wh, Ww)`` pair as accepted by ``WindowAttention``.

    Returns:
        torch.Tensor: A causal mask of shape (num_windows, Wh * Ww, Wh * Ww).
    """
    if isinstance(window_size, (tuple, list)):
        window_h, window_w = window_size
    else:
        window_h = window_w = window_size
    tokens_per_window = window_h * window_w

    # triu(diagonal=1) keeps -inf strictly above the diagonal and zeroes the rest.
    single_window_mask = torch.triu(
        torch.full((tokens_per_window, tokens_per_window), float('-inf')), diagonal=1)
    # Every window gets its own copy of the same mask.
    return single_window_mask.unsqueeze(0).repeat(num_windows, 1, 1)

# A single window; window_size comes from the attention configuration cell.
num_windows = 1
causal_mask = create_causal_mask_for_windows(num_windows, window_size)
# Display the mask: 0 where attention is allowed, -inf where it is blocked.
causal_mask

tensor([[[0., -inf, -inf, -inf],
         [0., 0., -inf, -inf],
         [0., 0., 0., -inf],
         [0., 0., 0., 0.]]])

In [141]:
bs = 1

# Hand-written input for one window: 4 tokens with 3 channels each.
# torch.tensor replaces the legacy torch.Tensor constructor; the dtype is
# pinned to the default float type, which is what torch.Tensor produced.
x = torch.tensor([[[1, 2, 3],
                   [4, 3, 1],
                   [2, 1, 1],
                   [3, 4, 2]]], dtype=torch.get_default_dtype())
# The earlier torch.ones(...) placeholder was overwritten immediately; its
# intended shape is kept as an explicit check instead.
assert x.shape == (bs * num_windows, window_size * window_size, d_model)
print(x.shape)
# Unmasked attention pass; the cell displays the projected output.
attn_out = attn(x)
attn_out

torch.Size([1, 4, 3])
q k v shape:  torch.Size([1, 1, 4, 3]) torch.Size([1, 1, 4, 3]) torch.Size([1, 1, 4, 3])
attn q@k shape:  torch.Size([1, 1, 4, 4])
attn_dist:  tensor([[[[0.4157, 0.0918, 0.0954, 0.3972],
          [0.3914, 0.0856, 0.0759, 0.4471],
          [0.3792, 0.1426, 0.1625, 0.3156],
          [0.3711, 0.0644, 0.0543, 0.5101]]]], grad_fn=<SoftmaxBackward0>)
attn_out:  tensor([[[-1.6874,  0.9687,  1.0449],
         [-1.7257,  0.9954,  1.0435],
         [-1.6834,  1.0428,  1.0451],
         [-1.7572,  1.0056,  1.0389]]], grad_fn=<ReshapeAliasBackward0>)


tensor([[[-0.5366,  1.9465,  0.9556],
         [-0.5383,  1.9671,  0.9580],
         [-0.5156,  1.9583,  0.9194],
         [-0.5417,  1.9801,  0.9638]]], grad_fn=<ViewBackward0>)

In [143]:
torch.ones(1,3)

tensor([[1., 1., 1.]])

In [142]:
attn_out = attn(x, causal_mask)
attn_out

q k v shape:  torch.Size([1, 1, 4, 3]) torch.Size([1, 1, 4, 3]) torch.Size([1, 1, 4, 3])
attn q@k shape:  torch.Size([1, 1, 4, 4])
adding mask:  tensor([[[[ 1.4206,    -inf,    -inf,    -inf],
          [ 1.4375, -0.0829,    -inf,    -inf],
          [ 0.9014, -0.0769,  0.0541,    -inf],
          [ 1.6335, -0.1172, -0.2888,  1.9517]]]], grad_fn=<ViewBackward0>)
after softmax:  tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.8206, 0.1794, 0.0000, 0.0000],
          [0.5542, 0.2083, 0.2375, 0.0000],
          [0.3711, 0.0644, 0.0543, 0.5101]]]], grad_fn=<SoftmaxBackward0>)
attn_dist:  tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.8206, 0.1794, 0.0000, 0.0000],
          [0.5542, 0.2083, 0.2375, 0.0000],
          [0.3711, 0.0644, 0.0543, 0.5101]]]], grad_fn=<SoftmaxBackward0>)
attn_out:  tensor([[[-1.1582,  0.2155,  1.0944],
         [-1.3923,  0.5742,  1.1232],
         [-1.4680,  0.8772,  1.0700],
         [-1.7572,  1.0056,  1.0389]]], grad_fn=<ReshapeAliasBackwar

tensor([[[-0.6273,  1.6082,  1.1172],
         [-0.5991,  1.7889,  1.0587],
         [-0.5160,  1.8484,  0.9215],
         [-0.5417,  1.9801,  0.9638]]], grad_fn=<ViewBackward0>)