In [3]:
import torch
from torchvision.models import resnet50, resnet18

import torch.nn as nn

from torchinfo import summary
"""

SubBand : 6 - 15 Extract
Shape: 288 X Frames

Input Data Range : about -1 to 1
Input Data Shape : B x 320 x Frames

"""

# ResNet18: 11M Model
# ResNet50: 25M Model
model = resnet18()
print(model)

data = torch.rand(1,3,28,28)
# print(summary(model, input_data=data))

# class ResNet18Encoder(nn.Module):
#     def __init__(self):
#         super(ResNet18Encoder, self).__init__()
#         # Load a pre-trained ResNet18 model
#         resnet = resnet18(pretrained=True)
        
#         # Modify the first convolutional layer to accept 1-channel input
#         resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        
#         # Remove the fully connected layer and pooling layer to keep only convolutional features
#         self.encoder = nn.Sequential(*list(resnet.children())[:-2])
        
#     def forward(self, x):
#         x = self.encoder(x)
#         return x
    


  from .autonotebook import tqdm as notebook_tqdm


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [4]:
from FeatureExtractor.dataset_FE import FeatureExtractorDataset
from torch.utils.data import DataLoader

path_dir_wb = ["/mnt/hdd/Dataset/FSD50K_48kHz", "/mnt/hdd/Dataset/MUSDB18_HQ_mono_48kHz"]
dataset = FeatureExtractorDataset(path_dir_wb, seg_len=4)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

for batch in dataloader:
    wave, spec, spec_m, spec_e, spec_me, name = batch
    print(len(batch))
    print(spec_me.shape)
    print(2048 / 48000 * 1000, "ms window")
    print(42 * 94)
    break


GT 51947 file numbers loaded!
51947 files loaded
6
torch.Size([1, 1, 320, 94])
42.666666666666664 ms window
3948


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

""" Causal Convolutions """

class CausalConv1d(nn.Conv1d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.causal_padding = self.dilation[0] * (self.kernel_size[0] - 1) + 1 - self.stride[0]
 
    def forward(self, x):
        return self._conv_forward(F.pad(x, [self.causal_padding, 0]), self.weight, self.bias)

class CausalConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
        # Calculate padding for temporal dimension (T)
        self.temporal_padding = self.dilation[1] * (self.kernel_size[1] - 1) - (self.stride[1] - 1)
        
        # Calculate total padding for frequency dimension (F)
        total_f_padding = self.dilation[0] * (self.kernel_size[0] - 1) - (self.stride[0] - 1)
        
        # Split total padding into top and bottom (asymmetrical padding if needed)
        self.frequency_padding_top = math.ceil(total_f_padding / 2)
        self.frequency_padding_bottom = math.floor(total_f_padding / 2)
        
    def forward(self, x):
        # Apply padding: F (top and bottom), T (only to the left)
        print(f"Temporal Padding (T): {self.temporal_padding}")
        print(f"Frequency Padding (F): top={self.frequency_padding_top}, bottom={self.frequency_padding_bottom}")
        x = F.pad(x, [self.temporal_padding, 0, self.frequency_padding_top, self.frequency_padding_bottom])
        return self._conv_forward(x, self.weight, self.bias)

if __name__ == "__main__":
    in_channels = 3  # Number of input channels
    out_channels = 64  # Number of output channels
    kernel_size = (7, 7)  # Kernel size for (F, T) dimensions
    stride = (2,1)  # Stride for convolution (F, T)
    dilation = (1,1)  # Dilation for convolution (F, T)

    causal_conv2d = CausalConv2d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation)

    # Example input tensor with shape: B x C x F x T
    input_tensor = torch.randn(8, in_channels, 320, 100) 
    output = causal_conv2d(input_tensor)
    
    # Output shape
    print("Output shape:", output.shape)


Temporal Padding (T): 6
Frequency Padding (F): top=3, bottom=2
Output shape: torch.Size([8, 64, 160, 100])


In [63]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm

class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = weight_norm(CausalConv2d(in_channels, out_channels, kernel_size=3, stride=stride, bias=False))
        # self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 =  weight_norm(CausalConv2d(out_channels, out_channels, kernel_size=3, stride=1, bias=False))
        # self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.downsample = downsample  # Used for downsampling


    def forward(self, x):
        print("input shape", x.shape)
        identity = x

        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        print(out.shape)
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

model = BasicBlock(1,64, stride=1, downsample=None)
data = torch.rand(8,1,320,80)
# print(summary(model, input_data = data))
print(model(data).shape)
# 8 x 1 x 320 x 80 -> 8 x 64 x 320 x 80


a = torch.rand(1,1,320,80)
b = torch.rand(1,64,320,80)
c = a+b
print(c.shape)

input shape torch.Size([8, 1, 320, 80])
Temporal Padding (T): 2
Frequency Padding (F): top=1, bottom=1
Temporal Padding (T): 2
Frequency Padding (F): top=1, bottom=1
torch.Size([8, 64, 320, 80])
torch.Size([8, 64, 320, 80])
torch.Size([1, 64, 320, 80])


In [64]:
class ResNet(nn.Module):
    def __init__(self, block, layers):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.conv1 = weight_norm(CausalConv2d(1, 64, kernel_size=(7,7), stride=(2,1), bias=False))
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=(2,1), padding=1)

        """ Below Must be Modified """
        # ResNet layers
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=(2,1))
        self.layer3 = self._make_layer(block, 256, layers[2], stride=(2,1))
        self.layer4 = self._make_layer(block, 512, layers[3], stride=(2,1))

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 : # Downsampling layer needs channel modification
            downsample = CausalConv2d(self.in_channels, out_channels,
                             kernel_size=1, stride=stride, bias=False)
                        
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x
    
def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])

# Example usage
if __name__ == "__main__":
    model = ResNet18()  # 10 class
    input_tensor = torch.randn(8, 1, 320, 80)  
    output = model(input_tensor)
    # print(output.shape)  
    print(summary(model, input_data=input_tensor))


Temporal Padding (T): 6
Frequency Padding (F): top=3, bottom=2
input shape torch.Size([8, 64, 80, 80])
Temporal Padding (T): 2
Frequency Padding (F): top=1, bottom=1
Temporal Padding (T): 2
Frequency Padding (F): top=1, bottom=1
torch.Size([8, 64, 80, 80])
input shape torch.Size([8, 64, 80, 80])
Temporal Padding (T): 2
Frequency Padding (F): top=1, bottom=1
Temporal Padding (T): 2
Frequency Padding (F): top=1, bottom=1
torch.Size([8, 64, 80, 80])
input shape torch.Size([8, 64, 80, 80])
Temporal Padding (T): 2
Frequency Padding (F): top=1, bottom=0
Temporal Padding (T): 2
Frequency Padding (F): top=1, bottom=1
torch.Size([8, 128, 40, 80])
Temporal Padding (T): 0
Frequency Padding (F): top=0, bottom=-1
input shape torch.Size([8, 128, 40, 80])
Temporal Padding (T): 2
Frequency Padding (F): top=1, bottom=1
Temporal Padding (T): 2
Frequency Padding (F): top=1, bottom=1
torch.Size([8, 128, 40, 80])
input shape torch.Size([8, 128, 40, 80])
Temporal Padding (T): 2
Frequency Padding (F): top=1,

In [35]:
# from torchutils
model = resnet18()
print(model.layer2)

data = torch.rand(8,64,228,228)
print(model.layer2(data).shape)

Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (downsample): Sequential(
      (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): BasicBlock(
    (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-