In [1]:
import torch
from torch import nn
from einops.layers.torch import Rearrange
from einops import rearrange
from model import createFreqBands, torch_mdct
import torch.nn.functional as F

# Random test signal, assumed (batch, channel, samples): 2 x 2 x 6 s of audio,
# presumably at the 44.1 kHz sample rate used below. Seeded so that
# Restart & Run All reproduces the same tensor.
torch.manual_seed(0)
x = torch.rand((2, 2, 44100 * 6))

In [2]:
# MDCT of the first item of x (both of its channels) with a 2048-point window;
# the recorded log shows "KBD Window: 2048" and the result shape is checked below.
x_mdct = torch_mdct(x.select(0, 0), 2048)

KBD Window: 2048
audio_shape: (2, 264600)
batch: 0 channel: 0
num_times: 260
batch: 0 channel: 1
num_times: 260


In [3]:
# Inspect the MDCT output shape (recorded output: (2, 1023, 260)).
x_mdct.size()

torch.Size([2, 1023, 260])

In [4]:
class BandSplitModule(nn.Module):
    """Split a spectrogram into frequency bands and embed each band separately.

    Each band ``spec[:, bands[i]:bands[i + 1]]`` is flattened over its
    (frequency, channel) axes, layer-normalised, and projected by its own
    Linear layer to ``channels * fc_dim`` features. The per-band embeddings
    are then stacked along a dedicated band axis.

    Args:
        sample_rate: Audio sample rate, forwarded to ``createFreqBands``.
        n_fft: Transform size, forwarded to ``createFreqBands``.
        channels: Number of audio channels (last axis of ``spec``).
        fc_dim: Per-channel embedding width; each band is projected to
            ``channels * fc_dim`` features.
        bands: Band-edge values forwarded to ``createFreqBands`` (presumably
            frequencies in Hz -- TODO confirm). Defaults to
            ``[1000, 4000, 8000, 16000, 20000]``.
        num_subBands: Sub-band counts forwarded to ``createFreqBands``.
            Defaults to ``[10, 12, 8, 8, 2, 1]``.

    Input:  ``spec`` of shape ``(b, f, t, channels)``.
    Output: tensor of shape ``(b, channels * fc_dim, num_bands, t)`` with
        ``num_bands == len(self.bands) - 1``.
    """

    def __init__(self, sample_rate, n_fft, channels=2,
                 fc_dim=128, bands=None, num_subBands=None):
        super().__init__()
        # None sentinels instead of mutable list defaults: a default list
        # object would be shared by every instance that omits the argument.
        if bands is None:
            bands = [1000, 4000, 8000, 16000, 20000]
        if num_subBands is None:
            num_subBands = [10, 12, 8, 8, 2, 1]

        # Band boundaries as bin indices. Assumed to be an ascending array
        # (numpy/torch) so the elementwise subtraction below works -- TODO
        # confirm against createFreqBands.
        self.bands = createFreqBands(sample_rate, n_fft, bands, num_subBands)
        # Number of bins in each band.
        self.band_intervals = self.bands[1:] - self.bands[:-1]
        self.channels = channels

        self.layer_list = nn.ModuleList([
            nn.Sequential(
                # (b, f_band, t, c) -> (b, t, f_band * c)
                Rearrange('b f t c -> b t (f c)', c=self.channels),
                nn.LayerNorm(int(band_interval) * channels),
                nn.Linear(channels * int(band_interval), channels * fc_dim),
                # (b, t, n) -> (b, n, 1, t): add a singleton band axis so the
                # per-band outputs are concatenated along bands, not time.
                Rearrange('b t n -> b n 1 t'),
            )
            for band_interval in self.band_intervals
        ])

    def forward(self, spec):
        """Embed every frequency band of ``spec``.

        Args:
            spec: Tensor of shape ``(b, f, t, channels)`` (mono or stereo).

        Returns:
            Tensor of shape ``(b, channels * fc_dim, num_bands, t)``.
        """
        outputs = [
            layer(spec[:, start:stop])
            for start, stop, layer in zip(self.bands[:-1], self.bands[1:], self.layer_list)
        ]
        # dim -2 is the singleton band axis created by each layer's final
        # Rearrange.
        return torch.cat(outputs, dim=-2)

In [5]:
# Audio / transform settings.
sample_rate = 44100
n_fft = 2048
channels = 2

# Band-split layout handed to createFreqBands (band edges presumably in Hz;
# one sub-band count per region).
bands = [1000, 4000, 8000, 16000, 20000]
num_subBands = [10, 12, 8, 8, 2, 1]

# Per-channel embedding width: each band is projected to channels * fc_dim features.
fc_dim = 128
# Instantiate the band splitter with the notebook-level settings above.
bandSplitter = BandSplitModule(
    sample_rate=sample_rate,
    n_fft=n_fft,
    channels=channels,
    fc_dim=fc_dim,
    bands=bands,
    num_subBands=num_subBands,
)

In [6]:
# Recompute the band boundaries outside the module so the cells below can
# step through the band-split computation by hand.
freqBands = createFreqBands(sample_rate, n_fft, bands, num_subBands)
# Width (in bins) of each band: difference of consecutive boundaries.
band_intervals = freqBands[1:] - freqBands[:-1]

In [7]:
# Unpack the (batch*channel) leading axis of the MDCT output and move
# channels last: ((b c), f, t) -> (b, f, t, c), the layout that
# BandSplitModule.forward expects.
x_mdct_rearrange = rearrange(
    x_mdct,
    '(b c) f t -> b f t c',
    c=channels,
)

In [9]:
# Slice the (b, f, t, c) spectrogram along f into consecutive bands, one per
# pair of adjacent boundaries [lo, hi).
spec_bands = [
    x_mdct_rearrange[:, lo:hi]
    for lo, hi in zip(freqBands[:-1], freqBands[1:])
]

In [12]:
# Shape probe: push each band through the same ops BandSplitModule applies
# (flatten band x channel, LayerNorm, Linear, add band axis) using throwaway,
# untrained Linear layers, printing each band's input shape (1-based index).
outputs = []
i = 0  # stays 0 if there are no bands; otherwise ends equal to the band count
for i, (spec_band, band_interval) in enumerate(zip(spec_bands, band_intervals), start=1):
    print(i, spec_band.shape)
    f1 = rearrange(spec_band, 'b f t c -> b t (f c)', c=channels)
    f2 = F.layer_norm(f1, [band_interval * channels])
    lin_layer = nn.Linear(channels * band_interval, channels * fc_dim)
    f3 = lin_layer(f2)
    # (b, t, n) -> (b, n, 1, t)
    f4 = rearrange(f3, 'b t n -> b n 1 t')
    outputs.append(f4)

1 torch.Size([1, 4, 260, 2])
2 torch.Size([1, 5, 260, 2])
3 torch.Size([1, 4, 260, 2])
4 torch.Size([1, 5, 260, 2])
5 torch.Size([1, 5, 260, 2])
6 torch.Size([1, 4, 260, 2])
7 torch.Size([1, 5, 260, 2])
8 torch.Size([1, 4, 260, 2])
9 torch.Size([1, 5, 260, 2])
10 torch.Size([1, 5, 260, 2])
11 torch.Size([1, 11, 260, 2])
12 torch.Size([1, 12, 260, 2])
13 torch.Size([1, 11, 260, 2])
14 torch.Size([1, 12, 260, 2])
15 torch.Size([1, 11, 260, 2])
16 torch.Size([1, 12, 260, 2])
17 torch.Size([1, 12, 260, 2])
18 torch.Size([1, 11, 260, 2])
19 torch.Size([1, 12, 260, 2])
20 torch.Size([1, 11, 260, 2])
21 torch.Size([1, 12, 260, 2])
22 torch.Size([1, 12, 260, 2])
23 torch.Size([1, 23, 260, 2])
24 torch.Size([1, 23, 260, 2])
25 torch.Size([1, 23, 260, 2])
26 torch.Size([1, 24, 260, 2])
27 torch.Size([1, 23, 260, 2])
28 torch.Size([1, 23, 260, 2])
29 torch.Size([1, 23, 260, 2])
30 torch.Size([1, 24, 260, 2])
31 torch.Size([1, 46, 260, 2])
32 torch.Size([1, 47, 260, 2])
33 torch.Size([1, 46, 260, 

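Mask Estimation Fix: prototype the per-band mask-estimation MLP (LayerNorm → Linear → tanh → Linear) on a random stand-in for the band-split output, and check the shape of each band's mask.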

In [1]:
import torch
from torch import nn
from einops.layers.torch import Rearrange
from einops import rearrange
from model import createFreqBands, torch_mdct
import torch.nn.functional as F

# Random stand-in for the band-split output, assumed
# (batch, channels * fc_dim = 256, num_bands = 41 = sum(num_subBands), time = 260).
# Seeded so that Restart & Run All reproduces the same tensor.
torch.manual_seed(0)
x = torch.rand([2, 256, 41, 260])

In [2]:
# Audio / transform settings.
sample_rate = 44100
n_fft = 2048
channels = 2

# Band-split layout handed to createFreqBands (band edges presumably in Hz;
# one sub-band count per region).
bands = [1000, 4000, 8000, 16000, 20000]
num_subBands = [10, 12, 8, 8, 2, 1]

# Per-channel embedding width used by the band splitter and mask estimator.
fc_dim = 128
# Band boundaries (bin indices) for the configured layout.
freqBands = createFreqBands(sample_rate, n_fft, bands, num_subBands)
# Width (in bins) of each band: difference of consecutive boundaries.
band_intervals = freqBands[1:] - freqBands[:-1]

In [23]:
# Shape probe for per-band mask estimation. For each band: normalise the
# band's embedding, run a 2-layer tanh MLP with fresh (untrained) weights, and
# reshape the result into a ((b c), f_band, t) mask slice.

# Input width per band: BandSplitModule projects every band to
# channels * fc_dim features, so derive it from `channels` instead of a literal 2.
in_dim = channels * fc_dim
assert x.shape[1] == in_dim, f"expected {in_dim} features on dim 1, got {x.shape[1]}"
assert x.shape[2] == len(band_intervals), "dim 2 of x must index the bands"

outputs = []
for i, band_interval in enumerate(band_intervals):
    # (b, n, t) -> (b, t, n) so LayerNorm/Linear act on the feature axis.
    f1 = rearrange(x[:, :, i, :], 'b n t -> b t n')
    f2 = F.layer_norm(f1, [in_dim])
    lin_layer1 = nn.Linear(in_dim, 4 * fc_dim)
    lin_layer2 = nn.Linear(4 * fc_dim, band_interval * channels)
    f3 = lin_layer1(f2)
    f4 = torch.tanh(f3)
    f5 = lin_layer2(f4)
    # (b, t, f_band * c) -> ((b c), f_band, t): the '(b c) f t' layout of the
    # MDCT output, so band slices can be concatenated on dim -2.
    f6 = rearrange(f5, 'b t (f c) -> (b c) f t', c=channels)
    print(f6.shape)
    outputs.append(f6)

torch.Size([4, 4, 260])
torch.Size([4, 5, 260])
torch.Size([4, 4, 260])
torch.Size([4, 5, 260])
torch.Size([4, 5, 260])
torch.Size([4, 4, 260])
torch.Size([4, 5, 260])
torch.Size([4, 4, 260])
torch.Size([4, 5, 260])
torch.Size([4, 5, 260])
torch.Size([4, 11, 260])
torch.Size([4, 12, 260])
torch.Size([4, 11, 260])
torch.Size([4, 12, 260])
torch.Size([4, 11, 260])
torch.Size([4, 12, 260])
torch.Size([4, 12, 260])
torch.Size([4, 11, 260])
torch.Size([4, 12, 260])
torch.Size([4, 11, 260])
torch.Size([4, 12, 260])
torch.Size([4, 12, 260])
torch.Size([4, 23, 260])
torch.Size([4, 23, 260])
torch.Size([4, 23, 260])
torch.Size([4, 24, 260])
torch.Size([4, 23, 260])
torch.Size([4, 23, 260])
torch.Size([4, 23, 260])
torch.Size([4, 24, 260])
torch.Size([4, 46, 260])
torch.Size([4, 47, 260])
torch.Size([4, 46, 260])
torch.Size([4, 47, 260])
torch.Size([4, 46, 260])
torch.Size([4, 47, 260])
torch.Size([4, 46, 260])
torch.Size([4, 47, 260])
torch.Size([4, 92, 260])
torch.Size([4, 93, 260])
torch.Size

In [29]:
# Stitch the per-band ((b c), f_band, t) mask slices back together along the
# frequency axis (dim -2).
mask = torch.cat(tuple(outputs), dim=-2)

In [1]:
# Inspect the assembled mask shape. NOTE: the saved NameError below comes from
# running this cell alone in a fresh kernel (In [1]); run the cells above first.
mask.size()

NameError: name 'mask' is not defined