In [26]:
# Using the spectrograms previously made, we are going to create a Autoencoder which will create a 
# prediction mask that can be used to calculate the spatial audio

In [27]:
# Load the data
import sys
sys.path.append('/workspace/fourth_year_project/Freqency Domain/')
from SpectroDataset import SpectroDataset


In [28]:
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

In [410]:
class AutoEncoder(nn.Module):
    def __init__(self, dropout_rate=0.5):
        super(AutoEncoder, self).__init__()
        def conv_block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, 3, padding=1),
                nn.ReLU(inplace=True)
            )   

        self.encoder = nn.Sequential(
            conv_block(552, 128),
            nn.MaxPool2d(2,1),  # 276 x 50
            conv_block(128, 128),
            nn.MaxPool2d(2,1),  # 138 x 25
        )

        self.bottleneck = conv_block(128, 256)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=1, stride=1),  # 276 x 50
            conv_block(128, 128),
            nn.ConvTranspose2d(128, 256, kernel_size=1, stride=1),  # 552 x 100
            conv_block(256, 256),
            nn.ConvTranspose2d(256, 552, kernel_size=1, stride=1),  # 552 x 100
            #nn.ConstantPad2d((0, 1, 0, 0), 0),  
        )

        self.final_conv = nn.Conv2d(552, 552, kernel_size=1)

    '''
torch.Size([32, 552, 101, 2])
In:  torch.Size([32, 552, 101, 2])
En:  torch.Size([32, 128, 25, 2])
bt:  torch.Size([32, 256, 25, 2])
d1:  torch.Size([32, 128, 50, 2])
d2:  torch.Size([32, 128, 50, 2])
d3:  torch.Size([32, 256, 100, 2])
d4:  torch.Size([32, 256, 100, 2])
d5:  torch.Size([32, 552, 100, 2])
Out:  torch.Size([32, 552, 100, 2])


    '''
    def forward(self, x):
        # Pass the input through the encoder
        print("In: ", x.shape)
        x = self.encoder(x)
        print("En: ", x.shape)

        # Pass the result through the bottleneck
        x = self.bottleneck(x)
        print("bt: ", x.shape)

        # Pass the result through the decoder
        # Break down decoder into its 7 steps
        d1 = self.decoder[0](x)
        print("d1: ", d1.shape)
        d2 = self.decoder[1](d1)
        print("d2: ", d2.shape)
        d3 = self.decoder[2](d2)
        print("d3: ", d3.shape)
        d4 = self.decoder[3](d3)
        print("d4: ", d4.shape)
        d5 = self.decoder[4](d4)
        print("d5: ", d5.shape)

        # d6 = self.decoder[5](d5)
        #x = self.decoder(d5)
        # print("de: ", d6.shape)
        


        # Pass the result through the final convolution
        x = self.final_conv(d5)
        print("Out: ", x.shape)

        return x


    def train_loop(self, train_dataset, val_dataset, batch_size, epochs, lr):
        dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
        #optimizer = optim.Adam(self.parameters(), lr=lr)
        # Add momentum to the optimizer
        optimizer = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
        criterion = nn.MSELoss()
        #print("Starting training loop")

        for epoch in range(epochs):
            #print("Epoch", epoch)
            for i, (targets, inputs, angle) in enumerate(dataloader):
                inputs, targets = inputs.cuda(), targets.cuda()

                optimizer.zero_grad()
                # Convert to real
                inputs = torch.view_as_real(inputs)
                print(inputs.shape)
                
                outputs = self(inputs)
                # Convert back to complex
                outputs = torch.view_as_complex(outputs)

                # The output from the model is a mask that represents the difference between the input and the target
                # We can use the difference to get both left and right channels
                print("Output shape: ", outputs.shape)
                print("Target shape: ", targets.shape)
                print("Input shape: ", inputs.shape)
                output_stft = inputs * outputs
                # Compare this to the target stft

                # We compare the istft to the target
                # XD --> output stft
                # targets --> stft
                loss = criterion(output_stft, targets)

                loss.backward()
                optimizer.step()

             # Validation
            self.eval()  # Set the model to evaluation mode
            with torch.no_grad():  # No need to track gradients in validation
                val_loss = 0
                for i, (targets, inputs, angle) in enumerate(val_loader):
                    inputs, targets = inputs.cuda(), targets.cuda()
                    outputs = self(inputs)
                    outputs = inputs * outputs
                    val_loss += criterion(outputs, targets).item()
                val_loss /= len(val_loader)  # Calculate average validation loss

            scheduler.step()
            if optimizer.param_groups[0]['lr'] < 0.0001:
                optimizer.param_groups[0]['lr'] = 0.0001
            print(f'Epoch: {epoch}, Training Loss: {loss.item()}, Validation Loss: {val_loss}, LR: {scheduler.get_last_lr()[0]}')
            

In [411]:
#dataset = SpectroDataset()

In [412]:
# import pickle
# with open('/workspace/extension/train_dataset.pkl', 'wb') as f:
#     pickle.dump(dataset, f)

In [413]:
# def my_get_item(self, index):
#     temp = self.data_map[index]
#     return temp["target_spec"], temp["orig_spec"], temp["label"]

In [414]:
# dataset.__class__.__getitem__ = my_get_item

In [415]:
# len(dataset)

In [416]:
# print(dataset[1][0].shape)
# print(dataset[3][1].shape)
# print(dataset[5][2])

In [417]:
# dataset[1][0]

In [418]:
# from torch.utils.data import random_split

# # Define the proportions for the split
# train_proportion = 0.8 
# val_proportion = 0.1
# test_proportion = 0.1 

# # Calculate the number of samples for train, validation and test
# train_size = int(train_proportion * len(dataset))
# val_size = int(val_proportion * len(dataset))
# test_size = len(dataset) - train_size - val_size

# train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])


In [419]:
autoencoder = AutoEncoder().cuda()

In [420]:
def weights_init(m):
    if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight.data)

autoencoder.apply(weights_init)
''

''

In [421]:
autoencoder.train()
''

''

In [422]:
autoencoder.train_loop(train_dataset, val_dataset, batch_size=32, epochs=100, lr=0.01)

torch.Size([32, 552, 101, 2])
In:  torch.Size([32, 552, 101, 2])


RuntimeError: Given input size: (128x100x1). Calculated output size: (128x99x0). Output size is too small

In [400]:
# NOTE THAT THE MODEL GENERATES STFTs not the actual audio
# MAKE SURE THIS IS THE SAME AS ConvertData.ipynb
sr = 44100  # Sample rate in Hz (change this to match your audio's sample rate)
window_length_ms = 25  # Window length in ms
hop_length_ms = 10  # Hop length in ms

# Convert window length and hop length from ms to samples
window_length_samples = int(sr * window_length_ms / 1000)
hop_length_samples = int(sr * hop_length_ms / 1000)

n_fft = window_length_samples

# This is the stft, so we take the istft
# Use the same parameters as the stft
# return_complex since we want phase information
# Don't normalize since that is the amplitude information
XD = torch.istft(output_stft, n_fft=n_fft, hop_length=hop_length_samples, normalize=False, return_complex=True) # , normalized=True, return_complex=False

NameError: name 'output_stft' is not defined

In [404]:
dataset[1][0].shape

(552, 101)

In [408]:
dataset[1][0][:,0]

array([ 5.02988875e-01+0.00000000e+00j, -5.30593574e-01+1.64057776e-01j,
        3.47780168e-01-3.93227488e-01j, -8.54631141e-02+3.25779945e-01j,
        6.24152794e-02-1.84804529e-01j, -8.60450119e-02+1.62846655e-01j,
        8.65100101e-02-1.64212421e-01j, -6.56432435e-02+1.88583001e-01j,
        3.80171137e-03-1.80625960e-01j,  2.13879067e-02+1.24031000e-01j,
       -2.28486024e-03-9.55331177e-02j, -5.54099970e-04+9.37462971e-02j,
       -3.94957606e-03-8.19811001e-02j,  7.14307185e-04+7.21289963e-02j,
        1.07573939e-03-6.65651262e-02j, -3.90548399e-03+6.20640144e-02j,
        4.82445862e-03-6.31022155e-02j,  2.50251358e-03+6.14278987e-02j,
       -5.29625593e-03-4.95583564e-02j, -3.42971180e-03+4.36155759e-02j,
        6.41115708e-03-4.89205308e-02j, -2.44011080e-05+4.98691574e-02j,
       -4.19482961e-03-4.37827036e-02j,  3.49496654e-03+3.68963890e-02j,
        1.00282789e-03-3.28809544e-02j, -4.34094621e-03+3.43883708e-02j,
        2.03237310e-03-3.51826176e-02j, -1.54861400

In [409]:
dataset[1][0][0,:]

array([ 0.5029889 +0.j, -0.4718467 +0.j, -0.6029801 +0.j,  1.176311  +0.j,
        0.648138  +0.j, -0.9083143 +0.j,  0.37553746+0.j,  1.668564  +0.j,
       -1.2395344 +0.j, -0.63539296+0.j,  0.10222008+0.j, -0.5448244 +0.j,
       -1.7610011 +0.j, -0.8694553 +0.j,  0.963181  +0.j, -0.5163584 +0.j,
       -1.0425301 +0.j, -0.67282873+0.j,  0.17496312+0.j,  1.1337073 +0.j,
        1.526577  +0.j,  1.9941056 +0.j,  0.71489406+0.j, -0.5192214 +0.j,
        1.2931379 +0.j,  0.63454354+0.j, -1.012174  +0.j,  0.5408937 +0.j,
        0.03543916+0.j, -0.43480355+0.j,  1.174363  +0.j,  1.1227303 +0.j,
       -1.7879429 +0.j, -2.5298352 +0.j,  0.06283827+0.j, -0.16407955+0.j,
       -0.05243171+0.j,  0.17424308+0.j,  1.2525527 +0.j,  0.47287044+0.j,
       -0.77015984+0.j,  0.8128915 +0.j,  0.6842846 +0.j, -2.1751873 +0.j,
       -2.0081465 +0.j, -1.1924818 +0.j, -1.2335767 +0.j, -0.05667074+0.j,
       -1.0939306 +0.j, -1.7757728 +0.j, -0.46993595+0.j,  1.8803167 +0.j,
        1.3843821 +0.j, -