In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchaudio
from os import path, walk
import torch.nn as nn
from IPython.display import Audio, display
from tqdm import tqdm
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Mount Google Drive at /content/drive: the zipped datasets, the sampling manifest and the
# saved model weights used later in this notebook are all read from / written to that mount.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!unzip "/content/drive/My Drive/all_samples_500ms.zip" -d "/content"
!unzip "/content/drive/My Drive/BRIRs_downsampled.zip" -d "/content"

In [None]:
def data_loader(random_sampling_file):
    """
    Read a sampling manifest into a list of (brir_name, audio_filename) pairs.

    Each non-blank line of the manifest holds two whitespace-separated fields,
    ``<audio_filename> <brir_name>``; the returned tuples put the BRIR name first.

    Args:
        random_sampling_file: path to the manifest text file.

    Returns:
        list of (brir_name, audio_filename) tuples, in file order.

    Raises:
        ValueError: if a non-blank line does not contain exactly two fields.
    """
    data = []
    # `with` guarantees the file is closed even if a malformed line raises.
    with open(random_sampling_file, "r") as f:
        for line in f:
            fields = line.split()  # also drops the trailing newline
            if not fields:
                continue  # tolerate blank lines, e.g. an empty line at EOF
            file, brir_name = fields
            data.append((brir_name, file))
    return data

In [None]:
class RNN(nn.Module):
    """
    Three stacked unidirectional GRUs (256 -> 128 -> 64 units) followed by a linear classifier.

    Every GRU uses ``batch_first=True``, so the input is (batch, seq, n_observations).
    forward() concatenates the last GRU's outputs at sequence positions 0 and 1, so seq must be
    at least 2. In this notebook seq is presumably the two binaural channels (TODO confirm for
    other callers).
    """

    def __init__(self, n_observations, n_angles):
        super().__init__()
        # NOTE(review): not referenced anywhere in this class; kept so the attribute still exists.
        self.n_hidden = 1000
        # Layers are created in the same order as before, so parameter initialisation
        # consumes the global RNG identically (same weights for the same seed).
        self.gru1 = nn.GRU(input_size=n_observations, hidden_size=256, num_layers=1,
                           batch_first=True, bidirectional=False)
        self.gru2 = nn.GRU(input_size=256, hidden_size=128, num_layers=1,
                           batch_first=True, bidirectional=False)
        self.gru3 = nn.GRU(input_size=128, hidden_size=64, num_layers=1,
                           batch_first=True, bidirectional=False)
        # Input = two 64-d GRU outputs (sequence positions 0 and 1) concatenated.
        self.fc = nn.Linear(in_features=2 * 64, out_features=n_angles)
        self.dropout20 = nn.Dropout(p=0.2)
        self.dropout50 = nn.Dropout(p=0.5)

    def forward(self, x):
        """Map (batch, seq, n_observations) to unnormalised class scores of shape (batch, n_angles)."""
        stages = (
            (self.gru1, self.dropout20),
            (self.gru2, self.dropout20),
            (self.gru3, self.dropout50),
        )
        for gru, drop in stages:
            sequence_out, _hidden = gru(x)
            x = drop(sequence_out)
        first_two_steps = [x[:, 0, :], x[:, 1, :]]
        return self.fc(torch.cat(first_two_steps, dim=1))


In [None]:
# Contiguous slices of the manifest form the splits: 320 000 train / 64 000 validation /
# remainder test. NOTE(review): this assumes the manifest is already shuffled (its name
# suggests random sampling) — confirm.
data = data_loader("/content/drive/MyDrive/random_sampling_file_brir.txt")
train_end, val_end = 320_000, 384_000
train_data, val_data, test_data = data[:train_end], data[train_end:val_end], data[val_end:]
del data  # free the full list; only the three splits are used from here on

In [None]:
fs = 8000  # sample rate seen by the network (NOTE: assumes the source audio is decimated 2x below — confirm)
sample_length_secs = 0.5
n_observations = int(sample_length_secs * fs)  # samples per channel = input size of the first GRU
n_angles = 13 * 5  # 13 azimuths x 5 elevations
rnn = RNN(n_observations=n_observations, n_angles=n_angles).to(device)

In [None]:
# Multi-class objective over the (azimuth, elevation) grid, optimised with AdamW.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=rnn.parameters(), lr=1e-3)

# Angle grids. Class index = az_index * len(el_angles) + el_index (same ordering as get_label).
az_angles = ["270", "285", "300", "315", "330", "345", "000", "015", "030", "045", "060", "075", "090"]
el_angles = ["-45", "-20", "000", "020", "045"]

In [None]:
def convolve_sound(brir, sample, target_device=None):
    """
    Spatialise a sample by convolving it with a BRIR, truncated to the sample's length.

    ``sample`` is repeated along dim 0 (``repeat([2, 1])``), so a (1, N) sample becomes two
    identical rows. Each row is convolved with the corresponding BRIR row (torchaudio's
    ``convolve`` broadcasts the leading dims), and the full-mode result is cut back to N samples.

    Args:
        brir: impulse-response tensor; presumably (2, M), one response per ear (TODO confirm).
        sample: audio tensor; expected to be mono, shape (1, N).
        target_device: device to run the convolution on. Defaults to the notebook-level
            ``device`` global, which matches the previous behaviour.

    Returns:
        Tensor of shape (2, N) on ``target_device`` (for a (1, N) sample and a (2, M) BRIR).
    """
    if target_device is None:
        target_device = device
    stereo_sample = sample.repeat([2, 1]).to(target_device)
    full = torchaudio.functional.convolve(stereo_sample, brir.to(target_device))
    # Full-mode output has N + M - 1 samples; keep the first N so the output is aligned with the input.
    return full[:, :sample.shape[1]]

def get_label(brir_name):
    """
    Map a BRIR file name to its class index on the (azimuth, elevation) grid.

    The name is split on "_": field 12 is the azimuth string and field 14 the elevation string.
    The class index is azimuth-major: az_index * 5 + el_index.

    Raises:
        IndexError: if the name has fewer than 15 "_"-separated fields.
        ValueError: if either angle string is not one of the known values.
    """
    azimuths = ["270", "285", "300", "315", "330", "345", "000", "015", "030", "045", "060", "075", "090"]
    elevations = ["-45", "-20", "000", "020", "045"]
    fields = brir_name.split("_")
    az_field, el_field = fields[12], fields[14]
    return azimuths.index(az_field) * len(elevations) + elevations.index(el_field)

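## Train

Each epoch trains on `train_data` and then evaluates on `val_data`. The weights are saved to Drive whenever the validation loss has not increased, and training stops as soon as it does.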

In [None]:
BRIR_DIR = "/content/BRIRs_downsampled"
AUDIO_DIR = "/content/all_samples_500ms"


def _load_spatial_sample(brir_name, filename, brir_dir=BRIR_DIR, audio_dir=AUDIO_DIR):
    """
    Load one (BRIR, audio) pair and return the spatialised audio on ``device``.

    Both signals are decimated by 2 with ``[:, ::2]`` before convolution.
    NOTE(review): this is plain decimation without an anti-aliasing filter — confirm intended.
    """
    brir, _ = torchaudio.load(path.join(brir_dir, brir_name), format="wav")
    audio, _ = torchaudio.load(path.join(audio_dir, filename), format="flac")
    audio = audio[:, ::2].to(device)
    return convolve_sound(brir[:, ::2], audio).to(device)


batch_size = 256
writer = SummaryWriter()
total_test_loss = float("inf")  # best (lowest) validation loss seen so far
training_losses = list()
test_losses = list()
test_accuracies = list()
n_val = len(val_data)

for epoch in range(50):  # maximum number of epochs; early stopping usually ends sooner
    # ---- training ----
    rnn.train()
    training_loss = 0.0
    n_batches = 0
    for idx in tqdm(range(0, len(train_data), batch_size)):
        batch = train_data[idx:idx + batch_size]
        # Size by the actual batch so a short final batch is not padded with all-zero
        # inputs that would be trained as class 0.
        spatial_audios = torch.zeros(len(batch), 2, n_observations, dtype=torch.float32)
        labels = torch.zeros(len(batch), dtype=torch.long)

        for j, (brir_name, filename) in enumerate(batch):
            spatial_audios[j] = _load_spatial_sample(brir_name, filename)
            labels[j] = get_label(brir_name)

        spatial_audios = spatial_audios.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        output = rnn(spatial_audios)
        loss = criterion(output, labels)
        loss.backward()
        # Clip between backward() and step(): clipping after step() has no effect on the update.
        nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1)
        optimizer.step()
        training_loss += loss.item()
        n_batches += 1

    total_training_loss = training_loss / max(n_batches, 1)
    print("Epoch {}:\n\ttraining loss: {:.3f}".format(epoch, total_training_loss))
    training_losses.append(total_training_loss)
    writer.add_scalar('Loss/train', total_training_loss, epoch)

    # ---- validation (one sample at a time) ----
    rnn.eval()
    correct = 0
    total = 0
    true_labels = torch.zeros(n_val)
    predicted_labels = torch.zeros(n_val)

    with torch.no_grad():
        test_loss = 0.0
        for i, (brir_name, filename) in enumerate(tqdm(val_data)):
            spatial_audio = _load_spatial_sample(brir_name, filename)
            output = rnn(spatial_audio.unsqueeze(0))
            label = get_label(brir_name)
            loss = criterion(output, torch.tensor([label], device=device))
            test_loss += loss.item()
            predicted = int(output.argmax(dim=1).item())
            total += 1
            correct += int(predicted == label)
            true_labels[i] = label
            predicted_labels[i] = predicted

    test_accuracy = 100 * correct / max(total, 1)
    print("\ttest accuracy: {:,.2f}".format(test_accuracy))
    writer.add_scalar('Accuracy/test', test_accuracy, epoch)
    test_accuracies.append(test_accuracy)

    new_total_test_loss = test_loss / max(n_val, 1)
    print("\ttest loss: {:,.3f}\n".format(new_total_test_loss))
    # Log every epoch, including the one that triggers early stopping.
    writer.add_scalar('Loss/test', new_total_test_loss, epoch)
    test_losses.append(new_total_test_loss)
    if new_total_test_loss > total_test_loss:
        break  # early stopping: validation loss went up; the best weights are already saved
    torch.save(rnn.state_dict(), '/content/drive/My Drive/BRIR_weights.pth')
    total_test_loss = new_total_test_loss

writer.flush()


## Evaluate the best saved weights on the held-out test set

In [None]:
# Reload the checkpoint written by the training loop (the best validation loss seen).
# map_location remaps tensors saved on the GPU onto the current device, so this also works on a CPU runtime.
rnn.load_state_dict(torch.load('/content/drive/My Drive/BRIR_weights.pth', map_location=device))

In [None]:
rnn.eval()
correct = 0
total = 0
# Size the buffers from the actual split. A hard-coded length either overflows or leaves
# zero-filled rows that later appear as fake "class 0 -> class 0" hits in the confusion matrices.
n_test = len(test_data)
n_el = len(el_angles)
true_labels = torch.zeros(n_test)
true_az = torch.zeros(n_test)
true_el = torch.zeros(n_test)
predicted_labels = torch.zeros(n_test)
predicted_az = torch.zeros(n_test)
predicted_el = torch.zeros(n_test)

with torch.no_grad():
    for i, (brir_name, filename) in enumerate(tqdm(test_data)):
        # Same extraction directories (see the unzip cell) and preprocessing as training/validation.
        brir, _ = torchaudio.load("/content/BRIRs_downsampled/" + brir_name, format="wav")
        audio, _ = torchaudio.load("/content/all_samples_500ms/" + filename, format='flac')
        audio = audio[:, ::2].to(device)
        spatial_audio = convolve_sound(brir[:, ::2], audio).to(device)
        output = rnn(spatial_audio.unsqueeze(0))
        label = get_label(brir_name)
        predicted = int(output.argmax(dim=1).item())
        total += 1
        correct += int(predicted == label)
        true_labels[i] = label
        predicted_labels[i] = predicted
        # Class index = az_index * n_el + el_index (get_label), so both parts can be recovered from it.
        true_az[i] = label // n_el
        true_el[i] = label % n_el
        predicted_az[i] = predicted // n_el
        predicted_el[i] = predicted % n_el

test_accuracy = 100 * correct / max(total, 1)
print(test_accuracy)

## Statistics:

In [None]:
# Confusion matrix over all (azimuth, elevation) classes, each row normalised to % of that true class.
n_classes = len(az_angles) * len(el_angles)
# Passing labels= keeps the matrix n_classes x n_classes even if some class never occurs.
cm = confusion_matrix(true_labels.long().numpy(), predicted_labels.long().numpy(),
                      labels=list(range(n_classes)))
# keepdims turns the row sums into a column vector, so row i is divided by its own total.
# A plain 1-D sum would broadcast along columns and divide by the wrong totals.
cm = 100 * cm / np.maximum(cm.sum(axis=1, keepdims=True), 1)
fig, ax = plt.subplots(figsize=(50, 50))
# Class index = az_index * len(el_angles) + el_index (get_label), so azimuth must be the outer loop.
labels = ["(" + az + ",\n" + el + ")" for az in az_angles for el in el_angles]
matrix = ConfusionMatrixDisplay(cm, display_labels=labels)
matrix.plot(ax=ax)
plt.title("Confusion matrix for reverberant data")
plt.savefig("azelbrir.svg", format='svg')
plt.show()

In [None]:
# Azimuth-only confusion matrix (elevations pooled), each row normalised to % of that true azimuth.
cmaz = confusion_matrix(true_az.long().numpy(), predicted_az.long().numpy(),
                        labels=list(range(len(az_angles))))
# Divide each row by its own total (keepdims -> column vector), not by the column-aligned sums.
cmaz = 100 * cmaz / np.maximum(cmaz.sum(axis=1, keepdims=True), 1)
fig, ax = plt.subplots(figsize=(10, 10))
matrix = ConfusionMatrixDisplay(cmaz, display_labels=az_angles)
matrix.plot(ax=ax)
plt.title("Confusion matrix for reverberant data for azimuth angles combined")
# Save before show(): in Jupyter's inline backend show() closes the figure, so a later savefig writes a blank image.
plt.savefig("azbrir.svg")
plt.show()

In [None]:
# Elevation-only confusion matrix (azimuths pooled), each row normalised to % of that true elevation.
cmel = confusion_matrix(true_el.long().numpy(), predicted_el.long().numpy(),
                        labels=list(range(len(el_angles))))
# Divide each row by its own total (keepdims -> column vector), not by the column-aligned sums.
cmel = 100 * cmel / np.maximum(cmel.sum(axis=1, keepdims=True), 1)
fig, ax = plt.subplots(figsize=(5, 5))
matrix = ConfusionMatrixDisplay(cmel, display_labels=el_angles)
matrix.plot(ax=ax)
plt.title("Confusion matrix for reverberant data for\nelevation angles combined")
# Save before show(): in Jupyter's inline backend show() closes the figure, so a later savefig writes a blank image.
plt.savefig("elbrir.svg")
plt.show()