# Classifying Audio using Spectrograms and CNNs

## Note on Data locations

I recommend putting the audio (WAV) and image (PNG) files in `data\audio` and `data\images` directories, respectively.

The `data\` directory is already included in the `.gitignore` file, and so these large binary files won't be included in commits.

### Example:

<img src="data_structure_example.png" style="width:400px; height:auto;">

## Imports

In [None]:
import os
import datetime
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np
import IPython.display as ipd
from timeit import default_timer as timer
from typing import Tuple

import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

import librosa
import librosa.display

## AudioFile class

- `file_path`: Path to the audio file
- `file_name`: Name of the audio file (extracted from the path)
- `label`: Label of the audio file (derived from the parent directory name)
- `audio`: Loaded audio data
- `sample_rate`: Sampling rate of the audio file
- `duration`: Duration of the audio file in seconds

### Methods

- `display_waveform()`: Display the waveform of the audio file
- `play()`: Play the audio file in an inline IPython audio player
- `trim(top_db=50)`: Trim silent parts of the audio using a decibel threshold
- `create_spectrogram()`: Generate a mel spectrogram of the audio file
- `display_spectrogram()`: Display the spectrogram of the audio file
- `save_spectrogram(output_dir=None, skip_existing=True)`: Save the spectrogram as a PNG file


In [None]:
class AudioFile:
    """
    A class to handle audio files and provide utilities for analysis and visualization.

    Attributes:
        file_path (str): Path to the audio file.
        file_name (str): Name of the audio file (extracted from the path).
        label (str): Label of the audio file (name of the directory containing it).
        audio (np.ndarray): Loaded audio samples, peak-normalized with librosa.util.normalize.
        sample_rate (int): Sampling rate returned by librosa.load (no `sr` is passed,
            so this is librosa's default target rate).
        duration (float): Duration of `audio` in seconds; recomputed by `trim`.
    """

    def __init__(self, file_path):
        """
        Initialize the AudioFile instance by loading the audio file and extracting metadata.

        Args:
            file_path (str): Path to the audio file.
        """
        self.file_path = file_path
        self.file_name = os.path.basename(file_path)
        # the containing folder's name is the class label (".../backward/x.wav" -> "backward")
        self.label = os.path.basename(os.path.dirname(self.file_path))
        self.audio, self.sample_rate = librosa.load(file_path)
        self.audio = librosa.util.normalize(self.audio)   # normalize audio
        self.duration = librosa.get_duration(y=self.audio, sr=self.sample_rate)

    def display_waveform(self):
        """
        Display the waveform of the audio file.
        """
        librosa.display.waveshow(self.audio, sr=self.sample_rate)
        plt.show()
        plt.close()

    def play(self):
        """
        Show an inline IPython audio player for the current (possibly trimmed) audio.

        Returns:
            Whatever IPython.display.display returns (None unless a display_id is requested).
        """
        return ipd.display(ipd.Audio(self.audio, rate=self.sample_rate))

    def trim(self, top_db=50):
        """
        Trim leading and trailing silence in place and update `duration`.

        Args:
            top_db (int, optional): Decibels below the reference (peak) under which
                audio is considered silent. Defaults to 50.
        """
        self.audio, _ = librosa.effects.trim(self.audio, top_db=top_db)
        # keep the metadata consistent with the shortened signal
        self.duration = librosa.get_duration(y=self.audio, sr=self.sample_rate)

    def create_spectrogram(self):
        """
        Create a mel spectrogram of the audio file.

        ``power=1`` produces a magnitude (not power) mel spectrogram, which is why
        ``amplitude_to_db`` is used. ``ref=np.min`` measures decibels relative to the
        quietest bin (librosa floors the reference at its ``amin``), so the values are
        non-negative; librosa's default ``top_db`` limits the dynamic range.

        Returns:
            np.ndarray: The mel spectrogram in decibel units.
        """
        mel_scale_sgram = librosa.feature.melspectrogram(
            y=self.audio,
            sr=self.sample_rate,
            power=1)
        mel_sgram = librosa.amplitude_to_db(mel_scale_sgram, ref=np.min)
        return mel_sgram

    def display_spectrogram(self):
        """
        Display the spectrogram of the audio file with time/mel axes and a dB colorbar.
        """
        _spectrogram = self.create_spectrogram()

        fig, ax = plt.subplots()
        img = librosa.display.specshow(
            _spectrogram,
            sr=self.sample_rate,
            x_axis='time',
            y_axis='mel',
            ax=ax)
        plt.colorbar(img, format='%+2.0f dB')

        # remove whitespace around image
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

        plt.show()
        plt.close(fig)

    def save_spectrogram(self, output_dir=None, skip_existing=True):
        """
        Save the spectrogram as a PNG file.

        The PNG is written to ``<output_dir>/<label>/<name>.png`` when `output_dir`
        is given, otherwise next to the audio file as ``<name>.png``.

        Args:
            output_dir (str, optional): Root directory for spectrograms. Defaults to the directory of the audio file.
            skip_existing (bool, optional): Whether to skip saving if the file already exists. Defaults to True.
        """
        if not output_dir:
            output_dir = os.path.dirname(self.file_path)
        else:
            output_dir = os.path.join(output_dir, self.label)

        base, _ = os.path.splitext(self.file_name)
        output_file = os.path.join(output_dir, base + ".png")

        if skip_existing and os.path.exists(output_file):
            return

        spectrogram = self.create_spectrogram()

        # a bare file name has no directory part; os.makedirs('') would raise
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Draw on a dedicated figure so no other open figure leaks into the PNG, and
        # always close it: bulk conversion would otherwise pile up open figures
        # whenever a save fails.
        fig, ax = plt.subplots()
        try:
            librosa.display.specshow(spectrogram, sr=self.sample_rate, ax=ax)
            # save, removing whitespace
            fig.savefig(output_file, bbox_inches='tight', pad_inches=0)
        finally:
            plt.close(fig)


### Example of using AudioFile class

In [None]:
_audio_file = os.path.join("data", "audio", "Speech Commands", "backward", "0a2b400e_nohash_0.wav")
test_audio = AudioFile(_audio_file)

# before trimming: look at and listen to the raw clip
test_audio.display_waveform()
test_audio.display_spectrogram()
test_audio.play()

# after trimming silence: same views, for comparison
test_audio.trim(top_db=50)
test_audio.display_waveform()
test_audio.display_spectrogram()
test_audio.play()

## Convert Audio Files to Spectrograms

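 - pass the audio directory as `input_dir`; PNGs are written to `data/images/<input_dir name>` (and to `data/images/<input_dir name> (trimmed)` when `include_trimmed=True`)
 - call process_directory()
 - if skip_existing is True, existing spectrogram PNG files will be skipped (recommended)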


### NOTE:

- Only run this cell if you need to save out all the spectrograms. It takes a while, and is prone to crashing (hence the use of skip_existing, so it can continue where it left off).
- Commented out the "process_directory(...)" line at the bottom to avoid accidental runs

In [None]:
def process_directory(input_dir, skip_existing=True, include_trimmed=False):
    """Convert every .wav file under `input_dir` into a mel-spectrogram PNG.

    PNGs go to ``data/images/<input_dir name>/<subdir>/<name>.png`` and, when
    `include_trimmed` is True, also to
    ``data/images/<input_dir name> (trimmed)/<subdir>/<name>.png`` (silence trimmed
    with AudioFile.trim's default threshold).

    Args:
        input_dir (str): Root directory of the audio files (one subdirectory per label).
        skip_existing (bool, optional): Skip files whose PNG already exists, so an
            interrupted run can resume. Defaults to True.
        include_trimmed (bool, optional): Also write trimmed spectrograms. Defaults to False.
    """
    # normpath drops a trailing separator, which would otherwise make basename() empty
    output_dir = os.path.join("data", "images", os.path.basename(os.path.normpath(input_dir)))
    output_dir_trimmed = output_dir + " (trimmed)"
    for root, dirs, files in os.walk(input_dir):
        # sorting in place makes os.walk visit subdirectories alphabetically
        dirs.sort()
        directory = os.path.basename(root)
        print(f"Processing directory: {directory}")
        for file in files:
            if not file.endswith('.wav'):
                continue
            audio = None
            # trim off .wav from file
            base, _ = os.path.splitext(file)
            output_file = os.path.join(output_dir, directory, base + ".png")

            if not (skip_existing and os.path.exists(output_file)):
                audio = AudioFile(os.path.join(root, file))
                audio.save_spectrogram(output_dir, skip_existing=skip_existing)

            if include_trimmed:
                output_file_trimmed = os.path.join(output_dir_trimmed, directory, base + ".png")
                if not (skip_existing and os.path.exists(output_file_trimmed)):
                    # reuse the already-loaded file if the untrimmed pass loaded it
                    if audio is None:
                        audio = AudioFile(os.path.join(root, file))
                    audio.trim()
                    # pass the trimmed *root directory*: save_spectrogram appends
                    # <label>/<name>.png itself
                    audio.save_spectrogram(output_dir_trimmed, skip_existing=skip_existing)
                
# Disabled to avoid accidental long-running conversions; uncomment to regenerate spectrograms.
# process_directory('data/audio/Speech Commands_noise', skip_existing=True, include_trimmed=False)

### Corrupt image check

In [None]:
from PIL import Image

# Spectrogram folders written by process_directory: the plain one and its trimmed twin
output_dir = os.path.join("data", "images", "Speech Commands")
output_dir_trimmed = output_dir + " (trimmed)"

def check_png_corruption(directories, output_file="corrupt_pngs.txt"):
    """Recursively scan directories for PNG files that Pillow cannot verify.

    Args:
        directories (Iterable[str]): Directories to walk.
        output_file (str, optional): Where to write the list of corrupt paths (one per
            line). Only written when at least one corrupt file is found.

    Returns:
        list[str]: Paths of PNGs that failed verification (empty if none).
    """
    corrupt_files = []
    for directory in directories:
        if not os.path.isdir(directory):
            # os.walk yields nothing for a missing path, which would otherwise be
            # reported as "no corrupt files"
            print(f"Directory not found, skipping: {directory}")
            continue
        for root, _dirs, files in os.walk(directory):
            for file in files:
                if not file.lower().endswith('.png'):
                    continue
                file_path = os.path.join(root, file)
                try:
                    with Image.open(file_path) as img:
                        img.verify()  # This will raise an exception if the file is corrupted.
                except Exception as e:  # Pillow raises several exception types for damaged files
                    print(f"Corrupted PNG found: {file_path} (Error: {e})")
                    corrupt_files.append(file_path)

    # Write the results to a text file
    if corrupt_files:
        with open(output_file, "w", encoding="utf-8") as f:
            f.writelines(f"{filename}\n" for filename in corrupt_files)
        print(f"Found {len(corrupt_files)} corrupt PNG(s). Details saved to '{output_file}'.")
    else:
        print("No corrupt PNG files found.")

    return corrupt_files

# Scan the plain and the trimmed spectrogram folders in one pass
directories_to_check = list((output_dir, output_dir_trimmed))
check_png_corruption(directories_to_check)

## Convolutional Neural Net

*Pre-process TODO*

- spectrogram: look into librosa specshow options. remove black bar at top of many. check axes.

#### Setup and Parameters

In [None]:
# Pick the fastest available backend: CUDA GPU, then Apple Metal, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"  # Apple Metal Performance Shaders
else:
    device = "cpu"

print(f"Using {device} device")

# One timestamp per run so the TensorBoard and model folders share the same suffix
# (two separate now() calls could straddle a minute boundary).
run_timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M')

# setup tensorboard
writer_path = f"./logs/run_{run_timestamp}"
writer = SummaryWriter(writer_path)

model_path = f"./models/cnn_ryan/{run_timestamp}"


## PARAMETERS ##
# (height, width): the order used by transforms.Resize and by the torchsummary
# input_size calls. Reading the source size 496x369 as width x height, (190, 256)
# keeps roughly the same landscape aspect ratio; (256, 190) swapped the axes and
# stretched every spectrogram into portrait.
image_size = (190, 256)
num_channels = 3 # RGB images
# for DataLoader
batch_size = 256
num_workers = 4
# for training: stop once validation accuracy (%) reaches this value
accuracy_threshold = 96


#### Helper Functions

In [None]:
def display_spectrogram(img):
    """Plot one image tensor with matplotlib, axes hidden.

    `img` is a (channels, height, width) tensor, presumably normalized to [-1, 1]
    by data_transforms; it is mapped back to [0, 1] before plotting.
    """
    pixels_hwc = np.transpose((img / 2 + 0.5).numpy(), (1, 2, 0))  # CHW -> HWC for imshow
    plt.axis('off')
    plt.imshow(pixels_hwc)

def display_spectrogram_batches(batches, writer_path=None):
    """Show the first batch of `batches` as an image grid, optionally logging it to TensorBoard.

    Args:
        batches: Iterable of (images, labels) pairs, e.g. a DataLoader; images are
            assumed to be normalized to [-1, 1] (as data_transforms does).
        writer_path (str, optional): Used as the TensorBoard tag for the logged
            grid; nothing is logged when it is falsy.
    """
    images, _ = next(iter(batches))
    img_grid = torchvision.utils.make_grid(images)
    display_spectrogram(img_grid)

    if writer_path:
        # add_images expects float pixels in [0, 1]; undo Normalize(0.5, 0.5) first,
        # otherwise all negative values are clipped to black
        writer.add_images(writer_path, (img_grid / 2 + 0.5).unsqueeze(0))

def calc_convolution_output(image_size: Tuple[int, int], kernel_size, stride=1, padding=0):
    """Return the spatial output size of an nn.Conv2d layer (dilation 1).

    Applies floor((n - k + 2p) / s) + 1 to each dimension.

    Args:
        image_size: Input size per dimension, e.g. (height, width).
        kernel_size: int or (k0, k1).
        stride: int or (s0, s1). Defaults to 1.
        padding: int or (p0, p1). Defaults to 0.

    Returns:
        Tuple[int, int]: Output size in the same dimension order as `image_size`.
    """
    def as_pair(value):
        # accept a single number or a per-dimension pair, like nn.Conv2d does
        return tuple(value) if isinstance(value, (tuple, list)) else (value, value)

    kernel, strides, pads = as_pair(kernel_size), as_pair(stride), as_pair(padding)
    return tuple(
        (size - k + 2 * p) // s + 1
        for size, k, s, p in zip(image_size, kernel, strides, pads)
    )

def calc_pooling_output(image_size: Tuple[int, int], kernel_size, stride=None, padding=0):
    """Return the spatial output size of an nn.MaxPool2d layer (ceil_mode=False, dilation 1).

    Applies floor((n + 2p - k) / s) + 1 to each dimension.

    Args:
        image_size: Input size per dimension, e.g. (height, width).
        kernel_size: int or (k0, k1).
        stride: int or (s0, s1); defaults to `kernel_size`, as in nn.MaxPool2d.
        padding: int or (p0, p1). Defaults to 0.

    Returns:
        Tuple[int, int]: Output size in the same dimension order as `image_size`.
    """
    def as_pair(value):
        # accept a single number or a per-dimension pair, like nn.MaxPool2d does
        return tuple(value) if isinstance(value, (tuple, list)) else (value, value)

    kernel = as_pair(kernel_size)
    strides = kernel if stride is None else as_pair(stride)
    pads = as_pair(padding)
    return tuple(
        (size + 2 * p - k) // s + 1
        for size, k, s, p in zip(image_size, kernel, strides, pads)
    )

 #### Transform and Load data 

In [None]:
data_transforms = transforms.Compose([
    # Resize takes (height, width); image_size is defined in that order
    transforms.Resize((image_size[0], image_size[1])),
    transforms.ToTensor(),
    # map each RGB channel from [0, 1] to [-1, 1] (display_spectrogram undoes this)
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# ImageFolder loads one class per subdirectory and assigns integer labels in sorted order
data_dir = "data/images/Speech Commands"
dataset = datasets.ImageFolder(root=data_dir, transform=data_transforms)

# get class labels (dataset.classes is ordered by class index)
class_labels = list(dataset.classes)
num_classes = len(class_labels)
print(f"Classes: {num_classes}")

# split into training and validation sets
val_split = 0.2
val_size = int(val_split * len(dataset))
train_size = len(dataset) - val_size

# Seeded so the same files land in validation on every run; otherwise re-running
# this cell (e.g. to evaluate a saved checkpoint) mixes training files into validation.
split_seed = 42
X_train, X_val = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed))
print(f"Training: {train_size}\nValidation: {val_size}")

# pinned host memory only speeds up copies to a CUDA device; elsewhere PyTorch may just warn
pin_memory = device == "cuda"

# load batches
train_batches = torch.utils.data.DataLoader(
    X_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory)

# evaluation does not depend on order, so validation batches are not shuffled
val_batches = torch.utils.data.DataLoader(
    X_val,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory)

# show a spectrogram
# display_spectrogram(X_train[0][0])

# show some spectrograms
# display_spectrogram_batches(val_batches)

#### __Neural Net Class__


In [None]:
class CNN_1(nn.Module):
    """Earlier, smaller CNN (instantiated below as `model_old`).

    The classifier's in_features (480) is hard-coded to the flattened size
    `features` produces for a 256x190 or 190x256 input (10 channels x 8 x 6).
    forward() raises a descriptive error when the input yields a different size.
    """

    def __init__(self, num_classes=None):
        """
        Args:
            num_classes (int, optional): Number of output logits. Defaults to the
                notebook-level `num_classes` set when the dataset is loaded.
        """
        super().__init__()
        if num_classes is None:
            # the parameter shadows the notebook global, so look it up explicitly
            num_classes = globals()["num_classes"]

        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=1, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            nn.Conv2d(10, 10, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )

        self.classifier = nn.Linear(in_features=480, out_features=num_classes)

    def forward(self, x):
        x = self.features(x)
        expected = self.classifier.in_features
        if x.shape[1] != expected:
            # previously this was printed and then `logits` was unbound; fail loudly instead
            raise RuntimeError(
                f"ERROR: Linear block in_features needs to be: {x.shape[1]} (currently {expected})")
        return self.classifier(x)

    
model_old = CNN_1()
# the model is still on the CPU here, so have torchsummary build its dummy input on the CPU too
print(summary(model_old, input_size=(num_channels, image_size[0], image_size[1]), device="cpu"))

In [None]:
class SpectrogramCNN(nn.Module):
    """
    CNN classifier for spectrogram images.

    Three convolutional blocks (conv -> normalization -> ReLU -> 2x2 max-pool ->
    spatial dropout) feed a three-layer fully connected head that outputs one
    logit per class.
    """

    def __init__(self, num_classes, input_size=None):
        """
        Args:
            num_classes (int): Number of output logits.
            input_size (Tuple[int, int], optional): Spatial size of the input images,
                in the same order as the notebook-level ``image_size`` (used when this
                is omitted). Only needed to size fc1.
        """
        super().__init__()
        if input_size is None:
            input_size = image_size

        # Kernel calculations, per spatial dimension:
        # W: input dimension; F: filter size; P: padding; S: stride
        # floor((W - F + 2P) / S) + 1

        # input: (3, input_size[0], input_size[1])
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2, padding=1)
        _map_size = calc_convolution_output(input_size, kernel_size=self.conv1.kernel_size, stride=self.conv1.stride[0], padding=self.conv1.padding[0])
        self.instance_norm = nn.InstanceNorm2d(64)    # per-sample, per-channel normalization (intended as "speaker normalization")
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
        _map_size = calc_pooling_output(_map_size, kernel_size=self.pool1.kernel_size)

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        _map_size = calc_convolution_output(_map_size, kernel_size=self.conv2.kernel_size, stride=self.conv2.stride[0], padding=self.conv2.padding[0])
        self.batch_norm_2 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2))
        _map_size = calc_pooling_output(_map_size, kernel_size=self.pool2.kernel_size)

        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        _map_size = calc_convolution_output(_map_size, kernel_size=self.conv3.kernel_size, stride=self.conv3.stride[0], padding=self.conv3.padding[0])
        self.batch_norm_3 = nn.BatchNorm2d(256)
        self.pool3 = nn.MaxPool2d(kernel_size=(2, 2))
        _map_size = calc_pooling_output(_map_size, kernel_size=self.pool3.kernel_size)

        # final feature dimensions - channels x dim0 x dim1
        feature_dims = 256 * _map_size[0] * _map_size[1]
        print(f'feature dimensions: {feature_dims}')

        # Fully Connected with larger intermediate layers
        self.fc1 = nn.Linear(in_features=feature_dims, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=256)
        self.fc3 = nn.Linear(in_features=256, out_features=num_classes)

        # dropout for regularization
        self.dropout2d = nn.Dropout2d(0.2)  # spatial dropout for conv layers
        self.dropout = nn.Dropout(0.4)      # regular dropout for FC layers

    def forward(self, x):
        # block 1
        x = F.relu(self.instance_norm(self.conv1(x)))  # using speaker normalization here
        x = self.pool1(x)
        x = self.dropout2d(x)

        # block 2
        x = F.relu(self.batch_norm_2(self.conv2(x)))
        x = self.pool2(x)
        x = self.dropout2d(x)

        # block 3
        x = F.relu(self.batch_norm_3(self.conv3(x)))
        x = self.pool3(x)
        x = self.dropout2d(x)

        # flatten everything except the batch dimension
        x = torch.flatten(x, 1)

        # FC layers
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)

        return x
    

model = SpectrogramCNN(num_classes)
# the model is still on the CPU here, so have torchsummary build its dummy input on the CPU too
print(summary(model, input_size=(num_channels, image_size[0], image_size[1]), device="cpu"))
model = model.to(device)

#### Training

In [None]:
def display_training_time(start, end):
    """Print the elapsed time between two timer readings and return it (seconds)."""
    elapsed_seconds = end - start
    print(f"Training time : {elapsed_seconds:.3f} seconds")
    return elapsed_seconds

def display_training_info(epoch, val_loss, train_loss, accuracy):
    """Print a one-line summary of an epoch.

    `val_loss` and `train_loss` must support `.item()` (e.g. 0-dim tensors);
    every metric is rounded to two decimals.
    """
    summary_parts = [
        f"Epoch: {epoch}",
        f"Training loss: {round(train_loss.item(), 2)}",
        f"Validation loss: {round(val_loss.item(), 2)}",
        f"Accuracy: {round(accuracy, 2)}%",
    ]
    print("\n" + " | ".join(summary_parts))
    
def accuracy_fn(y_true, y_pred):
    """Return the share of predictions equal to the targets, as a percentage (0-100)."""
    matches = (y_pred == y_true).sum().item()
    return matches / len(y_pred) * 100

def train_neural_net(epochs, model, loss_func, optimizer, train_batches, val_batches, patience=5):
    """
    Train a model with per-epoch validation, checkpointing and early stopping.

    Uses the notebook-level ``device``, ``writer`` (TensorBoard) and
    ``accuracy_threshold``. Whenever validation accuracy improves, the state_dict
    and the full model are saved under ``./models/cnn_ryan/<YYYYmmdd-HHMM>/``.
    Training stops early after ``patience`` consecutive epochs without
    improvement, or once validation accuracy reaches ``accuracy_threshold``.

    Args:
        epochs (int): Maximum number of epochs.
        model (nn.Module): Model to train; must already be on ``device``.
        loss_func: Callable taking (logits, labels); assumed to average over the
            batch (reduction='mean', as nn.CrossEntropyLoss does by default).
        optimizer (torch.optim.Optimizer): Optimizer over the model's parameters.
        train_batches: Sized iterable of (images, labels) training batches.
        val_batches: Iterable of (images, labels) validation batches.
        patience (int, optional): Consecutive non-improving epochs tolerated. Defaults to 5.

    Returns:
        float: Validation accuracy (percent) of the LAST completed epoch, which may
        be lower than the best, checkpointed accuracy.
    """
    final_accuracy = 0
    best_val_accuracy = 0
    epochs_without_improvement = 0

    # path for saved models
    model_path = f"./models/cnn_ryan/{datetime.datetime.now().strftime('%Y%m%d-%H%M')}"

    for epoch in tqdm(range(epochs), desc="Epochs"):
        # training mode
        model.train()
        with torch.enable_grad():
            train_loss = 0
            for images, labels in tqdm(train_batches, desc="Training Batches", leave=False):
                images = images.to(device)
                labels = labels.to(device)
                predictions = model(images)
                loss = loss_func(predictions, labels)
                # detach: summing `loss` itself would chain every batch's autograd
                # graph onto the running total
                train_loss += loss.detach()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            train_loss /= len(train_batches)

        # evaluation mode
        model.eval()
        val_loss, val_accuracy, val_samples = 0, 0, 0
        with torch.inference_mode():
            for images, labels in tqdm(val_batches, desc="Validation Batches", leave=False):
                images = images.to(device)
                labels = labels.to(device)
                predictions = model(images)
                n = labels.size(0)
                # weight each batch by its size so a short final batch does not skew the averages
                val_loss += loss_func(predictions, labels) * n
                val_accuracy += accuracy_fn(y_true=labels, y_pred=predictions.argmax(dim=1)) * n
                val_samples += n
            val_loss /= val_samples
            val_accuracy /= val_samples
            final_accuracy = val_accuracy
        display_training_info(epoch+1, val_loss, train_loss, val_accuracy)

        # write to tensorboard
        writer.add_scalars("Loss", {
            "Training": train_loss,
            "Validation": val_loss
        }, epoch)
        writer.add_scalar("Accuracy", val_accuracy, epoch)

        # save models if validation accuracy improves
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            # patience counts *consecutive* epochs without improvement
            epochs_without_improvement = 0
            os.makedirs(model_path, exist_ok=True)
            torch.save(model.state_dict(), f"{model_path}/model_params.pt")
            torch.save(model, f"{model_path}/model_full.pt")
            print(f"\nCheckpoint saved at epoch {epoch+1} with val accuracy {val_accuracy:.2f}%")
        else:
            epochs_without_improvement += 1
            print(f"No improvement for {epochs_without_improvement} epoch(s)")

        if epochs_without_improvement >= patience:
            print("Early stopping: no improvement in validation accuracy for ", patience, "epochs.")
            break

        if val_accuracy >= accuracy_threshold:
            print("Early stopping: accuracy threshold reached.")
            break

    return final_accuracy

In [None]:
# Hyperparameters for this run
max_epochs = 50
learning_rate = 0.001
gradient_momentum = 0.9  # only used by the (commented-out) SGD optimizer below

loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=gradient_momentum)

# start the clock right before training (times whichever `device` is in use)
train_time_start_on_gpu = timer()
model_accuracy = train_neural_net(
    max_epochs,
    model,
    loss_func,
    optimizer,
    train_batches,
    val_batches,
    patience=5,
)
print(f"\nTraining complete : {model_accuracy} %")
display_training_time(start=train_time_start_on_gpu, end=timer())

# push any pending TensorBoard events to disk, then release the writer
writer.flush()
writer.close()