In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

num_frames = 16384
num_channels = 1
num_layers = 12
num_initial_filters = 24
filter_size = 15
merge_filter_size = 5
input_filter_size = 15
output_filter_size = 1
padding = 'same'  
upsampling = 'linear'  # or 'learned'
output_type = 'direct'  # or 'difference'
source_names = ["accompaniment", "vocals"]
activation = 'tanh'
training = True


In [3]:

def crop(tensor, target_shape):
    """
    Crop the tensor to match the target shape. 
    It is assumed that the tensor shape and target shape only differ in the time dimension.
    """
    shape = tensor.shape
    target_length = target_shape[2]
    tensor_length = shape[2]
    
    start = (tensor_length - target_length) // 2
    end = start + target_length

    return tensor[:, :, start:end]


def AudioClip(x, training):
    if training:
        return x
    else:
        return torch.clamp(x, -1.0, 1.0)

def difference_output(input_mix, featuremap, source_names, num_channels, filter_width, padding, activation, training):
    outputs = {}
    sum_source = 0
    
    # Convert padding type from 'same' or 'valid' to PyTorch's 'same' or 'valid' equivalent
    padding_val = (filter_width - 1) // 2 if padding == 'same' else 0
    
    for name in source_names[:-1]:
        conv_layer = nn.Conv1d(featuremap.shape[1], num_channels, filter_width, padding=padding_val)
        out = conv_layer(featuremap)
        
        if activation == 'tanh':
            out = torch.tanh(out)
        # Add other activation checks here if needed
        
        outputs[name] = out
        sum_source += out

    # Use the crop function to ensure the shapes match
    last_source = crop(input_mix, sum_source.shape) - sum_source
    last_source = AudioClip(last_source, training)
    outputs[source_names[-1]] = last_source
    
    return outputs


In [4]:

class LearnedInterpolationLayer(nn.Module):
    def __init__(self, features, padding, level):
        super(LearnedInterpolationLayer, self).__init__()
        self.weights = nn.Parameter(torch.nn.init.xavier_uniform_(torch.empty(features)))
        self.padding = padding
        self.level = level

    def forward(self, x):
        weights_scaled = torch.sigmoid(self.weights)
        counter_weights = 1.0 - weights_scaled

        # Create diagonal matrices for weights and counter-weights
        conv_weights = torch.diag(weights_scaled).unsqueeze(0)
        counter_conv_weights = torch.diag(counter_weights).unsqueeze(0)

        # Perform matrix multiplication
        intermediate_vals = torch.bmm(x, conv_weights)
        counter_intermediate_vals = torch.bmm(x, counter_conv_weights)

        # Concatenate along the channel dimension
        output = torch.cat([intermediate_vals, counter_intermediate_vals], dim=1)

        if self.padding == "valid":
            output = output[:, :, :-1]
        return output


In [16]:
class Encoder(nn.Module):
    def __init__(self, num_initial_filters, num_layers, filter_size, input_filter_size, padding, dropout_rate=0.3):
        super(Encoder, self).__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        
        # Input convolution layer
        self.layers.append(nn.Conv1d(1, num_initial_filters, input_filter_size, stride=1, padding=self.get_padding(padding, input_filter_size)))
        self.layers.append(nn.LeakyReLU(0.01))
        self.layers.append(nn.Dropout(dropout_rate))
      
        # Loop for adding convolution layers
        for i in range(num_layers - 1):
            in_channels = num_initial_filters * (i + 1)
            out_channels = in_channels + num_initial_filters
            self.layers.append(nn.Conv1d(in_channels, out_channels, filter_size, stride=1, padding=self.get_padding(padding, filter_size)))
            self.layers.append(nn.LeakyReLU(0.01))
            self.layers.append(nn.Dropout(dropout_rate))

    def forward(self, x):
        enc_outputs = []
        for i in range(0, 3 * self.num_layers, 3):
            x = self.layers[i](x)
            x = self.layers[i + 1](x)
            x = self.layers[i + 2](x)
            if i < 3 * (self.num_layers - 1):
                x = x[:, :, ::2]  # Decimate by factor of 2
            enc_outputs.append(x)
        return enc_outputs
    
    def get_padding(self, padding_type, kernel_size):
        return (kernel_size - 1) // 2 if padding_type == 'same' else 0

# Instantiate and test
encoder = Encoder(num_initial_filters, num_layers, filter_size, input_filter_size, padding)
dummy_input = torch.randn(1, 1, 16384)  # Batch size of 1, 1 channel, 16384 frames
enc_out = encoder(dummy_input)
for out in enc_out:
    print(out.shape)


torch.Size([1, 24, 8192])
torch.Size([1, 48, 4096])
torch.Size([1, 72, 2048])
torch.Size([1, 96, 1024])
torch.Size([1, 120, 512])
torch.Size([1, 144, 256])
torch.Size([1, 168, 128])
torch.Size([1, 192, 64])
torch.Size([1, 216, 32])
torch.Size([1, 240, 16])
torch.Size([1, 264, 8])
torch.Size([1, 288, 8])


In [None]:


class UNet(nn.Module):
    def __init__(self, num_frames, num_channels, num_layers, num_initial_filters, filter_size, merge_filter_size, 
                 input_filter_size, output_filter_size, padding, upsampling, output_type, source_names, activation, training):
        super(UNet, self).__init__()

        self.num_layers = num_layers
        self.upsampling = upsampling
        self.output_type = output_type
        self.source_names = source_names
        self.training = training
        
        # Define the encoder layers
        self.encoders = nn.ModuleList()
        self.encoders.append(nn.Conv1d(num_channels, num_initial_filters, input_filter_size, stride=1, padding=self.get_padding(padding, input_filter_size)))
        self.encoders.append(nn.Dropout(0.3))
        for i in range(num_layers - 1):
            self.encoders.append(nn.Conv1d(num_initial_filters + (num_initial_filters * i), 
                                           num_initial_filters + (num_initial_filters * (i + 1)), 
                                           filter_size, stride=1, padding=self.get_padding(padding, filter_size)))
            self.encoders.append(nn.Dropout(0.3))
        
        # Define the decoder layers
        self.decoders = nn.ModuleList()
        for i in range(num_layers - 1):
            self.decoders.append(nn.Conv1d(num_initial_filters * (num_layers - i + 1), num_initial_filters * (num_layers - i), merge_filter_size, stride=1, padding=self.get_padding(padding, merge_filter_size)))
            self.decoders.append(nn.Dropout(0.3))
        
        # Define the final output layer
        self.outputs = nn.ModuleList()
        for _ in source_names:
            self.outputs.append(nn.Conv1d(num_initial_filters, num_channels, output_filter_size, stride=1, padding=self.get_padding(padding, output_filter_size)))

    def forward(self, x):
        # Encoder
        enc_outputs = []
        for i in range(0, 2 * self.num_layers, 2):
            x = self.encoders[i](x)
            x = F.leaky_relu(x, 0.01)
            x = self.encoders[i + 1](x)
            enc_outputs.append(x)
            if i < 2 * (self.num_layers - 1):
                x = x[:, :, ::2]  # Decimate by factor of 2
        
        # Decoder
        for i in range(0, 2 * (self.num_layers - 1), 2):
            if self.upsampling == 'linear':
                x = F.interpolate(x, scale_factor=2)
            elif self.upsampling == 'learned':
                # Implement learned interpolation if needed
                pass
            x = torch.cat([x, enc_outputs[-(i//2 + 1)]], dim=1)
            x = self.decoders[i](x)
            x = F.leaky_relu(x, 0.01)
            x = self.decoders[i + 1](x)
        
        # Output Layer
        outputs = {}
        for i, name in enumerate(self.source_names):
            outputs[name] = torch.tanh(self.outputs[i](x))
        
        return outputs
    
    def get_padding(self, padding_type, kernel_size):
        return (kernel_size - 1) // 2 if padding_type == 'same' else 0


In [None]:

num_frames = 16384
num_channels = 1
num_layers = 12
num_initial_filters = 24
filter_size = 15

merge_filter_size = 5
input_filter_size = 15
output_filter_size = 1
padding = 'same'  
upsampling = 'linear'  # or 'learned'
output_type = 'direct'  # or 'difference'
source_names = ["accompaniment", "vocals"]
activation = 'tanh'
training = True

model = UNet(num_frames, num_channels, num_layers, num_initial_filters, filter_size, merge_filter_size, 
             input_filter_size, output_filter_size, padding, upsampling, output_type, source_names, activation, training)

# If needed, define the loss and optimizer for training
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

In [None]:
class DebugUNet(UNet):
    def forward(self, x):
        print("Input shape:", x.shape)
        
        # Encoder
        enc_outputs = []
        for i in range(0, 2 * self.num_layers, 2):
            x = self.encoders[i](x)
            print(f"After encoder Conv[{i//2}]:", x.shape)
            x = F.leaky_relu(x, 0.01)
            x = self.encoders[i + 1](x)
            enc_outputs.append(x)
            if i < 2 * (self.num_layers - 1):
                x = x[:, :, ::2]  # Decimate by factor of 2
                print(f"After decimation[{i//2}]:", x.shape)
        
        # Decoder
        for i in range(0, 2 * (self.num_layers - 1), 2):
            if self.upsampling == 'linear':
                x = F.interpolate(x, scale_factor=2)
                print(f"After upsampling[{i//2}]:", x.shape)
            elif self.upsampling == 'learned':
                # Implement learned interpolation if needed
                pass
            
            cropped_enc_output = crop(enc_outputs[-(i//2 + 1)], x.shape)
            x = torch.cat([x, cropped_enc_output], dim=1)
            print(f"After concatenation[{i//2}]:", x.shape)
            x = self.decoders[i](x)
            print(f"After decoder Conv[{i//2}]:", x.shape)
            x = F.leaky_relu(x, 0.01)
            x = self.decoders[i + 1](x)
        
        # Output Layer
        outputs = {}
        for i, name in enumerate(self.source_names):
            outputs[name] = torch.tanh(self.outputs[i](x))
            print(f"Output shape of {name}:", outputs[name].shape)
        
        return outputs
    

debug_model = DebugUNet(num_frames, num_channels, num_layers, num_initial_filters, filter_size, merge_filter_size, 
                        input_filter_size, output_filter_size, padding, upsampling, output_type, source_names, activation, training)

dummy_input = torch.randn(1, num_channels, num_frames)  # Batch size of 1
output = debug_model(dummy_input)


Input shape: torch.Size([1, 1, 16384])
After encoder Conv[0]: torch.Size([1, 24, 16384])
After decimation[0]: torch.Size([1, 24, 8192])
After encoder Conv[1]: torch.Size([1, 48, 8192])
After decimation[1]: torch.Size([1, 48, 4096])
After encoder Conv[2]: torch.Size([1, 72, 4096])
After decimation[2]: torch.Size([1, 72, 2048])
After encoder Conv[3]: torch.Size([1, 96, 2048])
After decimation[3]: torch.Size([1, 96, 1024])
After encoder Conv[4]: torch.Size([1, 120, 1024])
After decimation[4]: torch.Size([1, 120, 512])
After encoder Conv[5]: torch.Size([1, 144, 512])
After decimation[5]: torch.Size([1, 144, 256])
After encoder Conv[6]: torch.Size([1, 168, 256])
After decimation[6]: torch.Size([1, 168, 128])
After encoder Conv[7]: torch.Size([1, 192, 128])
After decimation[7]: torch.Size([1, 192, 64])
After encoder Conv[8]: torch.Size([1, 216, 64])
After decimation[8]: torch.Size([1, 216, 32])
After encoder Conv[9]: torch.Size([1, 240, 32])
After decimation[9]: torch.Size([1, 240, 16])
Afte

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 16 but got size 4 for tensor number 1 in the list.

In [None]:
from torchsummary import summary

# Assuming your model is designed for a specific input size
input_size = (num_channels, num_frames)
summary(model, input_size)


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 16 but got size 8 for tensor number 1 in the list.

In [None]:
#!pip install torchsummary


Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1
