# Assignment for Probabilistic Machine Learning

Tom OVISTE

___

## Description

In order to apply the knowledge and skills acquired in the Probabilistic Machine Learning course, we will build a Bayesian (probabilistic) neural network for a classification problem.  


Our field of research is speech enhancement, i.e., the task of removing background noise, reverberation and distortion from a target speech signal. We therefore focus here on a related, albeit simpler, problem: classifying different noise environments.


For this task, we will be using audio signals from the [DEMAND dataset](https://www.kaggle.com/datasets/chrisfilo/demand), which contains several recordings taken in different noisy environments (e.g., a restaurant, a laundromat).


In this Jupyter notebook, we apply some pre-processing to the data, then build a classification model using the [Pyro framework](https://pyro.ai/).

A GitHub repository for this assignment is available [here](https://github.com/ovistetom/probabilistic-machine-learning-assignment).

___

## Import Dependencies

Remember to install the required libraries using the `requirements.sh` script.

In [None]:
import torch
import torchaudio
import librosa
import os
import shutil
import random
import kagglehub
import pyro
import matplotlib.pyplot as plt

In [None]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

___

## Download and Pre-Process Database (DEMAND)

We will use the [DEMAND dataset](https://www.kaggle.com/datasets/chrisfilo/demand) for this assignment.  
Some pre-processing is required before we can use the dataset for training, such as splitting the audio files into chunks of a few seconds.

In [None]:
root = r".\data\demand\versions\1"
root_preprocessed = r".\data\demand\versions\preprocessed"
os.makedirs(root, exist_ok=True)
os.makedirs(root_preprocessed, exist_ok=True)

In [None]:
# TODO: Change this to download dataset.
if False: 
    path = kagglehub.dataset_download("chrisfilo/demand")
    print(f"Initial path to dataset: {path}")
    shutil.move(src=path, dst=root)
print(f"Path to raw dataset: {root}")    

In [None]:
# TODO: Change this to clean dataset (remove 48k files and move 16k files to root).
if False:
    for dir_name in os.listdir(root):
        dir_path = os.path.join(root, dir_name)
        if dir_name.endswith('48k'):
            shutil.rmtree(dir_path)
        elif dir_name.endswith('16k'):
            dir_path_full = os.path.join(dir_path, dir_name[:-4])
            shutil.move(src=dir_path_full, dst=root)
            shutil.rmtree(dir_path)
print(f"Path to clean dataset: {root}")

In [None]:
# TODO: Change this to preprocess dataset.
if False:
    for dir_name in os.listdir(root):
        c = 0
        dir_path = os.path.join(root, dir_name)
        if os.path.isdir(dir_path):
            for file_name in os.listdir(dir_path):
                # Open audio file.
                file_path = os.path.join(dir_path, file_name)
                waveform, sr = torchaudio.load(file_path, channels_first=True)
                # Split waveform into chunks of length 5 seconds.
                target_len = 5 * sr
                fade_in_len = 512
                fade_out_len = 512
                fader = torchaudio.transforms.Fade(fade_in_len, fade_out_len)               
                waveform_segments = torch.split(waveform, target_len, dim=-1)
                waveform_segments = [s for s in waveform_segments if s.size(1)==target_len]
                for segment in waveform_segments:
                    # Apply fading and normalize.
                    segment = fader(segment)
                    segment = segment / segment.abs().max()
                    # Save short audio segment.
                    segment_name = f'{dir_name}_{c:0>4}.flac'
                    segment_path = os.path.join(root_preprocessed, segment_name)
                    torchaudio.save(uri=segment_path, src=segment, sample_rate=sr, channels_first=True)
                    c += 1
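The chunking and normalization logic can be checked on a synthetic signal without touching the dataset. The snippet below is a minimal sketch: the 12-second random waveform is an assumption made for illustration, and the fade step is left out so that only `torch` is needed.

```python
import torch

sr = 16000                      # Sampling rate of the 16k DEMAND files.
target_len = 5 * sr             # Chunk length in samples (5 seconds).

# Synthetic mono signal of 12 seconds (illustrative stand-in for a recording).
waveform = torch.randn(1, 12 * sr)

# torch.split keeps the shorter remainder as a last chunk, so it is filtered out.
segments = torch.split(waveform, target_len, dim=-1)
full_segments = [s for s in segments if s.size(-1) == target_len]

# Peak-normalize each chunk so that its maximum absolute value is 1.
normalized = [s / s.abs().max() for s in full_segments]

print(f"{len(segments)} chunks in total, {len(full_segments)} kept")
```

A 12-second signal yields two full chunks plus a 2-second remainder, which is discarded.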

In [None]:
# TODO: Change this to split pre-processed dataset into Train, Validation and Test subsets.
if False:
    list_files = [f for f in os.listdir(root_preprocessed) if f.endswith('.flac')]
    random.shuffle(list_files)
    num_files = len(list_files)
    trn_size = round(0.8*num_files)
    tst_size = round(0.1*num_files)
    trn_files = list_files[:trn_size]
    tst_files = list_files[trn_size:trn_size+tst_size]
    # Give the remainder to validation so that the three subsets never overlap.
    val_files = list_files[trn_size+tst_size:]
    for subset, subset_name in [(trn_files, 'trn'), (tst_files, 'tst'), (val_files, 'val')]:
        subset_path = os.path.join(root_preprocessed, subset_name)
        os.makedirs(subset_path, exist_ok=True)
        for file_name in subset:
            file_path = os.path.join(root_preprocessed, file_name)
            shutil.move(src=file_path, dst=subset_path)
print(f"Path to preprocessed dataset: {root_preprocessed}")            
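Because the subset sizes are obtained by rounding, it can be worth verifying that the three subsets are disjoint and cover every file. The following is a small sketch on made-up file names; the helper `split_files` and its `seed` argument are hypothetical and not part of the notebook.

```python
import random

def split_files(files, trn_frac=0.8, tst_frac=0.1, seed=0):
    """Shuffle file names and split them into train / test / validation lists."""
    files = list(files)
    random.Random(seed).shuffle(files)
    trn_size = round(trn_frac * len(files))
    tst_size = round(tst_frac * len(files))
    trn = files[:trn_size]
    tst = files[trn_size:trn_size + tst_size]
    val = files[trn_size + tst_size:]   # Remainder, so nothing is lost or duplicated.
    return trn, tst, val

# Made-up file names following the '<CLASS>_<index>.flac' pattern used above.
names = [f'NPARK_{i:0>4}.flac' for i in range(15)]
trn, tst, val = split_files(names)
print(len(trn), len(tst), len(val))
```

Taking the validation subset as "everything that is left" avoids any dependence on how the two rounded sizes add up.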

In [None]:
NOISE_CLASSES = {
    'DKITCHEN': 0,
    'DLIVING': 1,
    'DWASHING': 2,
    'NFIELD': 3,
    'NPARK': 4,
    'NRIVER': 5,
    'OHALLWAY': 6,
    'OMEETING': 7,
    'OOFFICE': 8,
    'PCAFETER': 9,
    'PRESTO': 10,
    'PSTATION': 11,
    'SPSQUARE': 12,
    'STRAFFIC': 13,
    'TBUS': 14,
    'TCAR': 15,
    'TMETRO': 16,    
}

___

## Define DataLoader

Here, we define a PyTorch `DemandDataset` class which will allow us to parse through the (pre-processed) database more naturally.

In [None]:
class DemandDataset(torch.utils.data.Dataset):
    def __init__(self, 
                 root: str, 
                 subset: str, 
                 num_samples: int | None = None,
    ):
        """Dataset class to handle the DEMAND database.

        Args:
            root (str): Path to the root directory of the dataset.
            subset (str): Name of the subset to load.
            num_samples (int, optional): Number of samples to load. Defaults to `None` (all samples). 
        """
        super().__init__()
        self.subset_name = subset
        self.subset_path = os.path.join(root, subset)
        self.num_samples = num_samples
        self.sample_names = self._collect_samples()
    
    def _collect_samples(self):
        list_samples = [f for f in os.listdir(self.subset_path) if f.endswith('.flac')]
        random.shuffle(list_samples)
        return sorted(list_samples[:self.num_samples])

    def _get_file_path(self, sample_name):
        return os.path.join(self.subset_path, sample_name)

    def _load_sample(self, n: int):
        sample_name = self.sample_names[n]
        sample_path = self._get_file_path(sample_name)
        waveform, _ = torchaudio.load(sample_path, channels_first=True)
        noise_type = self._get_noise_type(sample_name)
        return waveform.squeeze(0), noise_type
    
    def _get_noise_type(self, sample_name):
        return NOISE_CLASSES[sample_name.split('_')[0]]

    def __getitem__(self, n: int):
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded.

        Returns:
            tuple: A tuple containing the waveform and the corresponding noise type.
        """
        return self._load_sample(n)  

    def __len__(self):
        return len(self.sample_names)
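The label of each sample is recovered from its file name alone: the pre-processing step saves chunks as `<CLASS>_<index>.flac`, and the prefix before the underscore is looked up in `NOISE_CLASSES`. A minimal sketch of that lookup, using only an excerpt of the class dictionary and a standalone helper `noise_label` introduced here for illustration:

```python
# Excerpt of the class dictionary defined earlier (illustrative subset).
NOISE_CLASSES_EXCERPT = {'DKITCHEN': 0, 'NPARK': 4, 'TMETRO': 16}

def noise_label(sample_name: str) -> int:
    """Map a file name such as 'NPARK_0042.flac' to its integer class label."""
    class_name = sample_name.split('_')[0]
    return NOISE_CLASSES_EXCERPT[class_name]

# File names are built with the same format string as in the pre-processing cell.
example_name = f"{'NPARK'}_{42:0>4}.flac"
print(example_name, noise_label(example_name))
```

This relies on class names never containing an underscore, which holds for the DEMAND directory names listed above.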

___

## Define Audio Features

To reduce model complexity, we do not feed the raw waveforms from the database directly into our neural network.  
Instead, we define a function that computes various audio features from each waveform using the [Librosa library](https://librosa.org/), and use these features as input to the model.

In [None]:
def compute_audio_features(batch_waveform, sr=16000):
    """Compute a fixed-size feature vector for each waveform in a batch of shape (batch, time)."""
    batch_waveform_numpy = batch_waveform.numpy()
    mfcc = torch.from_numpy(librosa.feature.mfcc(y=batch_waveform_numpy, sr=sr).mean(2))
    spectral_kwargs = {'n_fft': 2**15, 'hop_length': 2**13}
    temporal_kwargs = {'frame_length': 2**15, 'hop_length': 2**13}
    spectral_centroid = torch.from_numpy(librosa.feature.spectral_centroid(y=batch_waveform_numpy, sr=sr, **spectral_kwargs))
    # Pass `sr` explicitly; otherwise librosa falls back to its default of 22050 Hz.
    spectral_bandwidth = torch.from_numpy(librosa.feature.spectral_bandwidth(y=batch_waveform_numpy, sr=sr, **spectral_kwargs))
    spectral_contrast = torch.from_numpy(librosa.feature.spectral_contrast(y=batch_waveform_numpy, sr=sr).mean(2))
    spectral_flatness = torch.from_numpy(librosa.feature.spectral_flatness(y=batch_waveform_numpy, **spectral_kwargs))
    spectral_rolloff = torch.from_numpy(librosa.feature.spectral_rolloff(y=batch_waveform_numpy, sr=sr, **spectral_kwargs))
    zero_crossing_rate = torch.from_numpy(librosa.feature.zero_crossing_rate(y=batch_waveform_numpy, **temporal_kwargs))
    root_mean_square = torch.from_numpy(librosa.feature.rms(y=batch_waveform_numpy, **temporal_kwargs))
    tempo = torch.from_numpy(librosa.feature.tempo(y=batch_waveform_numpy, sr=sr))
    audio_features = torch.cat(
        [mfcc,
        spectral_centroid.squeeze(1),
        spectral_bandwidth.squeeze(1),
        spectral_contrast,
        spectral_flatness.squeeze(1),
        spectral_rolloff.squeeze(1),
        zero_crossing_rate.squeeze(1),
        root_mean_square.squeeze(1),
        tempo],
        dim=1,
    )
    return audio_features


Let's check that these audio features can discriminate between the different noise types.

In [None]:
NUM_SAMPLES = 500
dataset_tst = DemandDataset(root=root_preprocessed, subset='tst', num_samples=NUM_SAMPLES)
dataloader_tst = torch.utils.data.DataLoader(dataset_tst, batch_size=NUM_SAMPLES, shuffle=False)

In [None]:
for batch_waveform, labels in dataloader_tst:
    spectral_features = compute_audio_features(batch_waveform, sr=16000)
    spectral_rolloff = torch.from_numpy(librosa.feature.spectral_rolloff(y=batch_waveform.numpy(), sr=16000).mean(2)).squeeze(1)
    zero_crossing_rate = torch.from_numpy(librosa.feature.zero_crossing_rate(y=batch_waveform.numpy()).mean(2)).squeeze(1)
    break
print(f"{spectral_features.shape = }")
print(f"{spectral_rolloff.shape = }")
print(f"{zero_crossing_rate.shape = }")

In [None]:
fig, axs = plt.subplots(1, 1, figsize=(8, 6))

axs.scatter(spectral_rolloff, zero_crossing_rate, c=labels, marker='.', cmap='tab20')
axs.set_xlabel("Spectral Rolloff")
axs.set_ylabel("Zero Crossing Rate")
axs.grid()
axs.set_title("Audio Features w/ Different Colors for Different Noise Types")

fig.tight_layout()
plt.show()

Although not sufficient on their own, the two plotted features are somewhat discriminative between the different classes.
We can reasonably expect that, with the additional features, our model will have enough information to separate the different noise types.

___

## Define Bayesian Neural Network

In this section, we define our Bayesian neural network by leveraging the [`PyroModule` tools](http://pyro.ai/examples/modules.html) from Pyro.  
The proposed neural architecture is rather simple, consisting of stacked fully-connected layers.

In [None]:
from pyro.nn.module import PyroModule, PyroSample
import pyro.distributions as dist

In [None]:
class BayesianNeuralNetwork(PyroModule):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Define neural layers.
        self.fc_in = PyroModule[torch.nn.Linear](in_features=input_size, out_features=hidden_size)
        self.fc_hid = PyroModule[torch.nn.Linear](in_features=hidden_size, out_features=hidden_size)
        self.fc_out = PyroModule[torch.nn.Linear](in_features=hidden_size, out_features=output_size)
        # Define prior distributions of neural parameters.
        self.fc_in.weight = PyroSample(prior=dist.Normal(0., 1.).expand([hidden_size, input_size]).to_event(2))
        self.fc_in.bias = PyroSample(prior=dist.Normal(0., 1.).expand([hidden_size]).to_event(1))
        self.fc_hid.weight = PyroSample(prior=dist.Normal(0., 1.).expand([hidden_size, hidden_size]).to_event(2))
        self.fc_hid.bias = PyroSample(prior=dist.Normal(0., 1.).expand([hidden_size]).to_event(1))
        self.fc_out.weight = PyroSample(prior=dist.Normal(0., 1.).expand([output_size, hidden_size]).to_event(2))
        self.fc_out.bias = PyroSample(prior=dist.Normal(0., 1.).expand([output_size]).to_event(1))
        # Define layer normalizations.
        self.layer_norm_in = PyroModule[torch.nn.LayerNorm](hidden_size)
        self.layer_norm_hid = PyroModule[torch.nn.LayerNorm](hidden_size)
        self.layer_norm_out = PyroModule[torch.nn.LayerNorm](output_size)
        # Define activation functions.
        self.activation = PyroModule[torch.nn.ReLU]()
        self.log_softmax = PyroModule[torch.nn.LogSoftmax](dim=1)        
    
    def forward(self, x):
        x = self.fc_in(x)
        x = self.layer_norm_in(x)
        x = self.activation(x)
        x = self.fc_hid(x)
        x = self.layer_norm_hid(x)
        x = self.activation(x)
        x = self.fc_out(x)
        x = self.layer_norm_out(x)
        x = self.log_softmax(x)
        return x

In [None]:
INPUT_SIZE = 88  # Number of audio features.
HIDDEN_SIZE = 64  # Size of hidden layers.
OUTPUT_SIZE = 17  # Number of output classes (i.e., different noise types).
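The value `INPUT_SIZE = 88` can be reconstructed from the feature function, assuming 5-second chunks at 16 kHz and librosa's defaults (20 MFCCs, 7 spectral-contrast rows, centered framing). The short calculation below is a sketch of that accounting rather than a substitute for checking the shape returned by `compute_audio_features`.

```python
hop_length = 2**13                 # Hop used for the frame-wise features.
num_samples = 5 * 16000            # 5-second chunk at 16 kHz.

# With centered framing, the number of frames is 1 + floor(num_samples / hop_length).
num_frames = 1 + num_samples // hop_length

feature_sizes = {
    'mfcc': 20,                        # Default n_mfcc, averaged over time.
    'spectral_centroid': num_frames,   # One value per frame.
    'spectral_bandwidth': num_frames,
    'spectral_contrast': 7,            # Default n_bands=6 gives 7 rows, averaged over time.
    'spectral_flatness': num_frames,
    'spectral_rolloff': num_frames,
    'zero_crossing_rate': num_frames,
    'root_mean_square': num_frames,
    'tempo': 1,                        # One tempo estimate per waveform.
}

total = sum(feature_sizes.values())
print(f"{num_frames = }, {total = }")
```

If the chunk length, hop length or librosa defaults change, `INPUT_SIZE` has to be updated accordingly.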

___

## Define Pyro Model

Now, in order to perform inference with our Bayesian neural network (i.e., take a sample), we need to wrap it in a probabilistic model with a defined likelihood.

In [None]:
class PyroModel(PyroModule):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.network = BayesianNeuralNetwork(input_size, hidden_size, output_size)

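    def forward(self, input, output=None):
        # The network returns per-class log-probabilities, which are valid (normalized) logits.
        logits = self.network(input)
        # No `size` is passed to the plate, so the likelihood is not rescaled for mini-batches.
        with pyro.plate("data"):
            return pyro.sample("obs", dist.Categorical(logits=logits), obs=output)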

In [None]:
model = PyroModel(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE)
guide = pyro.infer.autoguide.AutoNormal(model)

___

## Train the Model

We can now train the model.

NB: The implementation is deliberately kept simple.  
For example, cross-validation and other more advanced training methods could be added.

In [None]:
LEARNING_RATE = 0.001
NUM_EPOCHS = 10
NUM_SAMPLES = 10000
BATCH_SIZE = 10

In [None]:
# Load the dataset.
dataset_trn = DemandDataset(root=root_preprocessed, subset='trn')
dataloader_trn = torch.utils.data.DataLoader(dataset_trn, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Set up the optimizer and inference method.
pyro.clear_param_store()
optim = pyro.optim.Adam({'lr': LEARNING_RATE})
svi = pyro.infer.SVI(model, guide, optim, loss=pyro.infer.Trace_ELBO())

In [None]:
if False:  # TODO: Change this to train the model.    
    best_loss = float('inf')
    for epoch in range(NUM_EPOCHS):
        epoch_loss = 0
        for batch_waveform, labels in dataloader_trn:
            batch_features = compute_audio_features(batch_waveform).to(DEVICE).float()
            labels = labels.to(DEVICE)
            # Perform a training step.
            loss = svi.step(batch_features, labels) / labels.numel()
            epoch_loss += loss
        
        print(f"Epoch {epoch:02} - Loss: {epoch_loss}")

        # Save the model.
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), 'model.pt')

Here is the output of the training loop:
```
Epoch 00 - Loss: 1767609.256354446
Epoch 01 - Loss: 512494.62495819316
Epoch 02 - Loss: 60607.16488494873
Epoch 03 - Loss: 7810.326712163294
Epoch 04 - Loss: 4633.628274218243
Epoch 05 - Loss: 4408.207026608792
Epoch 06 - Loss: 4373.623391977942
Epoch 07 - Loss: 4371.596310106918
Epoch 08 - Loss: 4386.552696482344
Epoch 09 - Loss: 4391.185009002682
```

We can see that the loss decreases sharply over the first epochs before plateauing around 4,400: the model seems to be learning properly.
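As a rough sanity check on these numbers, the printed loss can be converted to a per-sample value and compared with the chance-level cross-entropy for 17 equiprobable classes, `ln(17) ≈ 2.83` nats. Each `svi.step` result is divided by the batch size and summed over batches, so dividing the epoch loss by the number of batches gives a per-sample ELBO loss (which also contains the KL term). In the sketch below, `num_batches = 1000` is a hypothetical placeholder, not a value from the experiment; in the notebook it would be `len(dataloader_trn)`.

```python
import math

NUM_CLASSES = 17

# Cross-entropy of a classifier that assigns equal probability to every class.
chance_nll = math.log(NUM_CLASSES)

def per_sample_loss(epoch_loss: float, num_batches: int) -> float:
    """Average the summed per-batch losses over the number of batches."""
    return epoch_loss / num_batches

num_batches = 1000                 # HYPOTHETICAL: replace with len(dataloader_trn).
final_epoch_loss = 4391.185009002682   # Last value printed by the training loop.

final = per_sample_loss(final_epoch_loss, num_batches)
print(f"chance-level NLL: {chance_nll:.3f} nats, per-sample loss: {final:.3f}")
```

A per-sample loss that stays at or above the chance-level value would suggest that the likelihood term has not improved much, even though the total loss decreased.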

___

## Evaluate the Accuracy of the Model

In [None]:
# Load the model.
model.load_state_dict(torch.load('model.pt'))

In [None]:
NUM_SAMPLES = 200
dataset_tst = DemandDataset(root=root_preprocessed, subset='tst', num_samples=NUM_SAMPLES)
dataloader_tst = torch.utils.data.DataLoader(dataset_tst, batch_size=NUM_SAMPLES, shuffle=True)

In [None]:
# Evaluate accuracy of the model.
for batch_waveform, labels in dataloader_tst:
    batch_features = compute_audio_features(batch_waveform).to(DEVICE).float()
    labels = labels.to(DEVICE)
    # Calling the model returns class labels sampled at the "obs" site, not logits.
    predictions = model(batch_features)
    # Compute accuracy.
    accuracy = torch.mean((predictions == labels).float())
    print(f"Accuracy: {accuracy*100:.2f} %")


Evaluating our model reveals that it does not perform any better than random guessing.  
Indeed, accuracy is around `1/17 ≃ 6%` on average.  

This is rather disappointing. However, the loss is undoubtedly decreasing; therefore, the model must be learning something!  
Unfortunately, despite our best efforts and after much research, we have not found the reason for this apparent inconsistency.