In [1]:
import os
import random as rnd
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchaudio
from pytorch_model_summary import summary
import torch.nn as nn
from torch.nn.functional import normalize
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor
import torchaudio.prototype.models
import torchaudio.prototype.pipelines

In [2]:
# Select the compute device once: CUDA when PyTorch reports it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# When running on CUDA, echo the name of GPU 0 so the cell output records the hardware used.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))

NVIDIA GeForce RTX 3060


In [3]:
class SpikerboxRecordings(Dataset):
    """Dataset that turns annotated audio files into fixed-size log-mel spectrograms.

    The annotations table is read positionally: column 0 holds the file name
    (joined onto ``audio_dir``) and column 3 holds the label.

    Each item goes through: load -> move to ``device`` -> resample to
    ``target_sample_rate`` (only when the file's rate differs) -> truncate /
    zero-pad to ``num_samples`` -> ``transformation`` -> amplitude-to-dB ->
    truncate / zero-pad the last axis to ``mel_width``.

    Args:
        annotations_file: Path to a CSV file, or an already-loaded DataFrame
            when ``already_dataframe`` is true.
        audio_dir: Directory that contains the audio files.
        transformation: Module applied to the waveform (moved to ``device``).
        target_sample_rate: Sample rate every waveform is brought to.
        num_samples: Number of samples kept per waveform.
        device: Device the waveform and the transforms are placed on.
        already_dataframe: Treat ``annotations_file`` as a DataFrame instead of
            a path. Defaults to ``False``.
        mel_width: Size of the last axis of every returned spectrogram.
            Defaults to 96, the value that used to be hard-coded.
    """

    def __init__(self, annotations_file, audio_dir, transformation, target_sample_rate,
                 num_samples, device, already_dataframe=False, mel_width=96):
        if already_dataframe:
            self.annotations = annotations_file
        else:
            self.annotations = pd.read_csv(annotations_file)
        self.audio_dir = audio_dir
        self.device = device
        self.transformation = transformation.to(self.device)
        self.target_sample_rate = target_sample_rate
        self.num_samples = num_samples
        self.mel_width = mel_width
        # Built once here instead of on every __getitem__ call.
        self._amplitude_to_db = torchaudio.transforms.AmplitudeToDB().to(self.device)

    def __len__(self):
        """Return the number of annotated recordings."""
        return len(self.annotations)

    def __getitem__(self, index):
        """Return ``(log_mel_spectrogram, label)`` for the row at ``index``."""
        audio_sample_path = self._get_audio_sample_path(index)
        label = self._get_audio_sample_label(index)
        signal, sample_rate = torchaudio.load(audio_sample_path)
        signal = signal.to(self.device)
        # Previously the loaded sample rate was discarded, so target_sample_rate
        # had no effect. Files already at the target rate pass through unchanged.
        signal = self._resample_if_necessary(signal, sample_rate)
        signal = self._cut_if_necessary(signal)
        signal = self._right_pad_if_necessary(signal)
        signal = self.transformation(signal)
        signal = self._make_log_mels(signal)
        signal = self._adjust_mel_width_if_necessary(signal, self.mel_width)
        return signal, label

    def _get_audio_sample_path(self, index):
        """Join ``audio_dir`` with the file name stored in column 0."""
        return os.path.join(self.audio_dir, self.annotations.iloc[index, 0])

    def _resample_if_necessary(self, signal, sample_rate):
        """Resample ``signal`` to ``target_sample_rate`` when its rate differs."""
        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate).to(self.device)
            signal = resampler(signal)
        return signal

    def _get_audio_sample_label(self, index):
        """Return the label stored in column 3 of the annotations."""
        return self.annotations.iloc[index, 3]

    def _cut_if_necessary(self, signal):
        """Drop everything past ``num_samples`` along axis 1."""
        if signal.shape[1] > self.num_samples:
            signal = signal[:, :self.num_samples]
        return signal

    def _right_pad_if_necessary(self, signal):
        """Zero-pad the end of the last axis when axis 1 is shorter than ``num_samples``."""
        length_signal = signal.shape[1]
        if length_signal < self.num_samples:
            num_missing_samples = self.num_samples - length_signal
            signal = torch.nn.functional.pad(signal, (0, num_missing_samples))
        return signal

    def _make_log_mels(self, signal):
        """Apply the amplitude-to-dB transform created in ``__init__``."""
        return self._amplitude_to_db(signal)

    def _adjust_mel_width_if_necessary(self, log_mel_spectrogram, width):
        """Zero-pad or crop the last axis so that it is exactly ``width`` long."""
        current_width = log_mel_spectrogram.shape[-1]
        if current_width < width:
            log_mel_spectrogram = torch.nn.functional.pad(log_mel_spectrogram, (0, width - current_width))
        elif current_width > width:
            # Ellipsis keeps the crop independent of the number of leading axes.
            log_mel_spectrogram = log_mel_spectrogram[..., :width]
        return log_mel_spectrogram

In [4]:
# Directory with the test recordings, and the CSV that labels them.
input_dir = 'test_data/files'
# read_csv already returns a DataFrame, so no extra pd.DataFrame(...) wrapper is needed.
df_metadata_test = pd.read_csv('test_data/metadata/file_labels.csv')

# An exam id is the part of a file name before the first underscore.
# np.unique returns the distinct ids in sorted order.
exam_names = np.unique([filename.split('_')[0] for filename in df_metadata_test.filename.values])

In [5]:
# Settings shared by every per-exam dataset (hoisted out of the loop: they never change).
AUDIO_DIR = input_dir
SAMPLE_RATE = 10000
NUM_SAMPLES = 9600

# Exam id of every row, computed the same way exam_names was derived
# (text before the first underscore of the file name).
exam_id_per_file = df_metadata_test.filename.str.split('_').str[0]

# One SpikerboxRecordings dataset per exam, in the order of exam_names.
exams = list()

for exam_name in exam_names:
    # Exact match on the id. A prefix test (startswith) would also pull in the
    # recordings of any exam whose id merely begins with this one, e.g. "exam10" for "exam1".
    exam_annotations = df_metadata_test[exam_id_per_file == exam_name]

    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate = SAMPLE_RATE,
        n_fft = 400,
        hop_length = 160,
        n_mels = 64
    )

    exam_dataset = SpikerboxRecordings(
        exam_annotations,
        AUDIO_DIR,
        mel_spectrogram,
        SAMPLE_RATE,
        NUM_SAMPLES,
        device,
        already_dataframe = True
    )
    exams.append(exam_dataset)

In [6]:
# Dataset over the whole labelled test set; the annotations are read from the CSV path.
ANNOTATIONS_FILE = 'test_data/metadata/file_labels.csv'
AUDIO_DIR = input_dir
SAMPLE_RATE = 10000
NUM_SAMPLES = 9600

# Mel-spectrogram front end: 64 mel bands, FFT size 400, hop of 160 samples.
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE, n_fft=400, hop_length=160, n_mels=64
)

spr = SpikerboxRecordings(
    annotations_file=ANNOTATIONS_FILE,
    audio_dir=AUDIO_DIR,
    transformation=mel_spectrogram,
    target_sample_rate=SAMPLE_RATE,
    num_samples=NUM_SAMPLES,
    device=device,
    already_dataframe=False,
)

In [7]:
class VGGishNetwork(nn.Module):
    """VGG-style CNN that maps a single-channel spectrogram to 10 output scores.

    ``features`` is a stack of 3x3 convolutions (padding 1) with ReLU, with four
    2x2 max-poolings that each halve both spatial axes. ``embeddings`` is a
    three-layer MLP whose first layer expects 12288 flattened features, i.e.
    512 channels x 4 x 6, which is what a ``(N, 1, 64, 96)`` input produces.
    """

    def __init__(self):
        super().__init__()

        def conv_relu(in_channels, out_channels):
            # 3x3 convolution that preserves the spatial size, followed by ReLU.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                nn.ReLU(inplace=True),
            ]

        def halve():
            # 2x2 max-pooling with stride 2.
            return [nn.MaxPool2d(kernel_size=2, stride=2)]

        # Layer order (and therefore the Sequential indices used as
        # state-dict keys) must stay exactly as listed here.
        feature_layers = (
            conv_relu(1, 64) + halve()
            + conv_relu(64, 128) + halve()
            + conv_relu(128, 256) + conv_relu(256, 256) + halve()
            + conv_relu(256, 512) + conv_relu(512, 512) + halve()
        )
        self.features = nn.Sequential(*feature_layers)

        head_layers = [
            nn.Linear(12288, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 10),
        ]
        self.embeddings = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the convolutional stack, flatten all but the batch axis, apply the MLP."""
        feature_maps = self.features(x)
        flattened = torch.flatten(feature_maps, 1)
        return self.embeddings(flattened)

In [8]:
def evaluate_exam(model, test_dl, device):
    """Predict one class for a whole exam by averaging the per-sample predictions.

    The model is put in eval mode and run without gradients on every batch
    that ``test_dl`` yields. The arg-max class index of each sample is
    collected, and the mean of those indices, rounded with Python's
    ``round``, is returned.

    Earlier only the first batch was consumed, so a loader with a batch size
    smaller than the exam was scored on part of the exam. A single-batch
    loader gives the same result as before.

    Args:
        model: Module producing per-class scores of shape ``(batch, classes)``.
        test_dl: Iterable of ``(inputs, targets)`` batches; targets are ignored.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        int: Rounded mean of the predicted class indices.

    Raises:
        ValueError: If ``test_dl`` yields no batches.
    """
    model.eval()
    batch_predictions = []
    with torch.no_grad():
        for test_inputs, _ in test_dl:
            test_outputs = model(test_inputs.to(device))
            _, predicted_val = torch.max(test_outputs, 1)
            batch_predictions.append(predicted_val)
    if not batch_predictions:
        raise ValueError('evaluate_exam received a data loader that yields no batches')
    all_predictions = torch.cat(batch_predictions)
    return round(torch.mean(all_predictions.float()).item())

In [9]:
def scaled_accuracy(output, target, max_distance, device):
    """Distance-aware accuracy: 1 for an exact hit, falling linearly to 0.

    Each sample scores ``1 - |predicted_class - target| / max_distance``,
    clamped below at 0, where the predicted class is the arg-max over axis 1
    of ``output``. The mean over all samples is returned as a Python float.

    Args:
        output: Per-class scores; arg-max is taken along dimension 1.
        target: Class indices; flattened to one dimension before comparison.
        max_distance: Class distance at which the score reaches 0.
        device: Device both tensors are moved to before computing.
    """
    scores = output.to(device)
    labels = target.to(device)
    with torch.no_grad():
        predicted = scores.argmax(dim=1)
        class_gap = (predicted - labels.view(-1)).abs()
        per_sample = (1 - class_gap.float() / max_distance).clamp(min=0.0)
        return per_sample.mean().item()

In [10]:
def test_model(model, test_dl, loss_func, device):
    model.eval()
    running_test_loss = 0.0
    running_scaled_accuracy_test = 0.0
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        for test_inputs, test_targets in test_dl:
            test_inputs, test_targets = test_inputs.to(device), test_targets.to(device)
            test_outputs = model(test_inputs)
            running_test_loss += loss_func(test_outputs, test_targets).item()
            running_scaled_accuracy_test += scaled_accuracy(test_outputs, test_targets, 9, device)
            _, predicted_test = torch.max(test_outputs, 1)
            total_test += test_targets.size(0)
            correct_test += (predicted_test == test_targets).sum().item()
        
    test_loss = running_test_loss / len(test_dl)
    test_scaled_accuracy = (running_scaled_accuracy_test / len(test_dl)) * 100
    test_accuracy = correct_test / total_test
    print(f'[Test Results]\nTest Loss: {test_loss:.7f}, Test Accuracy: {100 * test_accuracy:.2f}%, Scaled Test Accuracy: {test_scaled_accuracy:.2f}%')

In [11]:
# Build the network on the selected device, then restore the trained weights.
VGGish_Stress = VGGishNetwork().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved on a GPU
# still loads when this notebook falls back to the CPU.
# SECURITY NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects,
# which can execute code. Only load checkpoints from a trusted source; if this file is a
# plain state dict, consider switching to weights_only=True after confirming it loads.
VGGish_Stress.load_state_dict(
    torch.load('trained_models/VGGish_Stress.pth', map_location = device, weights_only = False)
)

<All keys matched successfully>

In [12]:
# Seeded generator so the shuffled batch order is identical on every run.
generator = torch.Generator().manual_seed(569567390)

# Whole test set in shuffled batches of 128, scored with cross-entropy.
test_dl = DataLoader(dataset=spr, batch_size=128, shuffle=True, generator=generator)
loss_fn = nn.CrossEntropyLoss().to(device)

test_model(model=VGGish_Stress, test_dl=test_dl, loss_func=loss_fn, device=device)

[Test Results]
Test Loss: 1.4868198, Test Accuracy: 48.30%, Scaled Test Accuracy: 80.18%


In [13]:
# Score each exam as a whole: one batch containing every recording of the exam.
for exam in exams:
    # Order is irrelevant for a single full-size batch, so no shuffling is needed.
    test_dl = DataLoader(exam, batch_size = len(exam), shuffle = False)
    prediction = evaluate_exam(VGGish_Stress, test_dl, device)
    print("Predicted Stress:", prediction)
    # Read the label of the exam's first recording straight from the annotations;
    # exam[0][1] would load and transform an audio file only to discard the signal.
    # NOTE(review): this assumes every recording of an exam shares one label - confirm.
    print("Actual Stress:", exam._get_audio_sample_label(0))

Predicted Stress: 6
Actual Stress: 3
Predicted Stress: 5
Actual Stress: 5
Predicted Stress: 6
Actual Stress: 4
