In [1]:
# Using the previously generated spectrograms, we build an autoencoder that predicts a
# mask which can be applied to the input spectrogram to calculate the spatial audio.

In [1]:
# Make the project's SpectroDataset class importable (the data itself is loaded further below)
import sys
sys.path.append('/workspace/fourth_year_project/Freqency Domain/')
from SpectroDataset import SpectroDataset


In [2]:
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

In [13]:
class AutoEncoder(nn.Module):
    """Convolutional encoder/decoder that predicts a mask over a spectrogram.

    Architecture: two conv blocks with 2x2 max-pooling (encoder), a conv
    bottleneck, and two stride-2 transposed convs each followed by a conv
    block (decoder), then a 1x1 conv down to a single output channel. The
    output always has the same height/width as the input, including when
    they are not divisible by 4 (see ``_match_size``).

    Args:
        dropout_rate: currently unused -- no dropout layer is built from it.
            Kept so existing ``AutoEncoder(dropout_rate=...)`` calls still work.
    """

    def __init__(self, dropout_rate=0.5):
        super().__init__()

        def conv_block(in_channels, out_channels):
            # Two 3x3 convs with padding=1: channels change, height/width do not.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, 3, padding=1),
                nn.ReLU(inplace=True),
            )

        # Shape comments assume a 552 x 101 (H x W) spectrogram, the shape
        # printed for dataset samples -- NOTE(review): confirm for `inputs`.
        self.encoder = nn.Sequential(
            conv_block(1, 64),   # 552 x 101
            nn.MaxPool2d(2, 2),  # 276 x 50
            conv_block(64, 128),
            nn.MaxPool2d(2, 2),  # 138 x 25
        )

        self.bottleneck = conv_block(128, 256)  # 138 x 25

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),  # 276 x 50
            conv_block(128, 128),
            # Produces 552 x 100: one column short of the input because
            # MaxPool2d floored 101 -> 50. forward() pads it back to 552 x 101.
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            conv_block(64, 64),
        )

        self.final_conv = nn.Conv2d(64, 1, kernel_size=1)

    @staticmethod
    def _match_size(x, ref):
        """Zero-pad the last two dims of ``x`` so they equal those of ``ref``.

        MaxPool2d floors odd sizes, so a stride-2 transposed conv can come back
        one row/column smaller than the encoder feature map it mirrors. The
        padding is split between both sides, as in the usual UNet recipe.
        """
        diff_h = ref.size(-2) - x.size(-2)
        diff_w = ref.size(-1) - x.size(-1)
        if diff_h == 0 and diff_w == 0:
            return x
        return nn.functional.pad(
            x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]
        )

    def forward(self, x):
        """Map a (batch, 1, H, W) tensor to a (batch, 1, H, W) tensor."""
        enc1 = self.encoder[0](x)
        enc2 = self.encoder[2](self.encoder[1](enc1))
        bottleneck = self.bottleneck(self.encoder[3](enc2))
        # Restore each upsampled map to the size of its encoder counterpart so
        # odd input dimensions survive the round trip.
        up1 = self._match_size(self.decoder[0](bottleneck), enc2)
        dec1 = self.decoder[1](up1)
        up2 = self._match_size(self.decoder[2](dec1), enc1)
        dec2 = self.decoder[3](up2)
        return self.final_conv(dec2)

    def train_loop(self, train_dataset, val_dataset, batch_size, epochs, lr,
                   apply_mask=True, min_lr=0.0001):
        """Train with SGD (momentum 0.9) and a step learning-rate schedule.

        Each epoch runs one shuffled pass over ``train_dataset`` followed by a
        no-grad pass over ``val_dataset``, then prints the mean training loss,
        the mean validation loss and the learning rate for the next epoch.
        The model is left in eval mode when the method returns.

        Dataset items are unpacked as ``(left_stft, targets, _, inputs, angle, sr)``;
        only ``inputs`` (field 3) and ``targets`` (field 1) are used. Batches
        of 2-D samples get a channel dimension so they reach the network as
        (batch, 1, H, W). NOTE(review): assumes these fields are real-valued
        tensors (Conv2d weights are real) -- confirm for the STFT fields.

        Args:
            train_dataset: dataset used for gradient updates.
            val_dataset: dataset used only to measure loss.
            batch_size: batch size for both loaders.
            epochs: number of passes over ``train_dataset``.
            lr: initial learning rate; halved every 50 epochs.
            apply_mask: if True (default) the network output is treated as a
                mask and ``inputs * outputs`` is compared with ``targets``;
                if False the raw output is compared with ``targets``.
            min_lr: floor applied to the learning rate after each scheduler step.
        """
        # Send batches to whatever device the model lives on (GPU after .cuda()).
        device = next(self.parameters()).device
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        # Order is irrelevant for an averaged validation loss.
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        optimizer = optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
        criterion = nn.MSELoss()

        def compute_loss(batch):
            # Shared by training and validation so the two losses are comparable.
            _, targets, _, inputs, _, _ = batch
            inputs, targets = inputs.to(device), targets.to(device)
            # Conv2d expects (batch, channels, H, W); add the single channel.
            if inputs.dim() == 3:
                inputs = inputs.unsqueeze(1)
            if targets.dim() == 3:
                targets = targets.unsqueeze(1)
            outputs = self(inputs)
            prediction = inputs * outputs if apply_mask else outputs
            return criterion(prediction, targets)

        for epoch in range(epochs):
            # The validation pass below switches to eval mode; switch back.
            self.train()
            train_total = 0.0
            for batch in train_loader:
                optimizer.zero_grad()
                loss = compute_loss(batch)
                loss.backward()
                optimizer.step()
                train_total += loss.item()
            train_loss = train_total / len(train_loader) if len(train_loader) else float('nan')

            self.eval()
            with torch.no_grad():
                val_total = sum(compute_loss(batch).item() for batch in val_loader)
            val_loss = val_total / len(val_loader) if len(val_loader) else float('nan')

            scheduler.step()
            # StepLR has no lower bound, so clamp the learning rate by hand.
            for group in optimizer.param_groups:
                group['lr'] = max(group['lr'], min_lr)
            # Read from the optimizer: scheduler.get_last_lr() was recorded
            # before the clamp above.
            current_lr = optimizer.param_groups[0]['lr']
            print(f'Epoch: {epoch}, Training Loss: {train_loss}, Validation Loss: {val_loss}, LR: {current_lr}')
            

In [4]:
# Instantiate the project's spectrogram dataset (SpectroDataset module, importable
# via the sys.path entry added at the top). Items are indexed as tuples below.
dataset = SpectroDataset()

In [5]:
# Number of samples available for the train/val/test split.
len(dataset)

17100

In [9]:
# Spot-check tensor shapes for a few (sample index, tuple field) pairs.
for sample_idx, field_idx in ((1, 0), (3, 1), (5, 2)):
    print(dataset[sample_idx][field_idx].shape)

(552, 101)
(552, 101)
(552, 101)


In [10]:
from torch.utils.data import random_split

# Fraction of the dataset assigned to each split. test_size is computed as the
# remainder, so the three sizes always add up to len(dataset) despite int()
# truncation; test_proportion is only checked for consistency.
train_proportion = 0.8
val_proportion = 0.1
test_proportion = 0.1
assert abs(train_proportion + val_proportion + test_proportion - 1.0) < 1e-9, \
    "split proportions must sum to 1"

# Seeded so split membership is identical on every Restart & Run All; an
# unseeded split silently reshuffles which samples are held out.
SPLIT_SEED = 42

# Calculate the number of samples for train, validation and test
train_size = int(train_proportion * len(dataset))
val_size = int(val_proportion * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)


In [19]:
# Use the GPU when one is available; fall back to the CPU instead of crashing
# on .cuda() on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
autoencoder = AutoEncoder().to(device)

In [20]:
def weights_init(m):
    """Xavier-uniform initialise the weight of a conv or linear layer.

    Intended to be passed to ``module.apply``. Modules of any other type are
    ignored, and biases keep their default initialisation.

    Args:
        m: any ``nn.Module``.
    """
    # Conv2d / ConvTranspose2d are the layers AutoEncoder is built from;
    # Conv1d / Linear are kept so other models are still covered.
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)

# Run weights_init on every submodule of the model (nn.Module.apply), re-initialising
# the weights of the layer types it matches.
autoencoder.apply(weights_init)
''

''

In [22]:
# Switch the model to training mode (affects dropout / batch-norm layers, if any)
# before starting the training loop.
autoencoder.train()
''

''

In [None]:
# Train for 100 epochs, batch size 32, starting learning rate 1e-3.
autoencoder.train_loop(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    batch_size=32,
    epochs=100,
    lr=0.001,
)