In [None]:
"""
    Author: Shakil Mahmud Arafat, EEE AUST
    Date last updated: 28 Nov, 2024
    Description: 
"""


'\n    Author: Shakil Mahmud Arafat, EEE AUST\n    Date last updated: 28 Nov, 2024\n    Description: \n'

In [17]:
# Project setup and configs
FILETYPES = (".ogg", ".mp3", "wav")
AUDIO_INPUT_TEST = "./data/audio_dataset/test"
AUDIO_INPUT_TRAIN = "./data/audio_dataset/train"
STFT_PLOTS = "./data/stft_plots"

In [18]:
import os
import librosa
import numpy as np
import pandas as pd  # Import pandas for CSV reading

def audio_to_stft(dir):
    """
    Reads audio files from a directory, performs STFT on each file,
    and returns dictionaries containing STFT matrices, audio data with
    sampling rates, and audio type tags from corresponding CSV files.

    Args:
        dir (str): Directory containing audio files and CSV files.

    Returns:
        tuple: A tuple containing three dictionaries:
               - stft_data: keys are filenames and values are the corresponding STFT matrices.
               - audio_data: keys are filenames and values are tuples of (audio_data, sampling_rate).
               - audio_tags: keys are filenames and values are the audio type tags from the CSV.
    """

    stft_data = {}
    audio_data = {}
    audio_tags = {}

    for filename in os.listdir(dir):
        if filename.endswith(FILETYPES):
            file_path = os.path.join(dir, filename)
            audio, sr = librosa.load(file_path)
            stft = np.abs(librosa.stft(audio))
            stft_data[filename] = stft
            audio_data[filename] = (audio, sr)

            # Extract audio type tag from corresponding CSV file using pandas
            csv_filename = os.path.splitext(filename)[0] + ".csv"
            csv_filepath = os.path.join(dir, csv_filename)

            try:
                df = pd.read_csv(csv_filepath, names=['startTime', 'endTime', 'quantity', 'label'])  # Read CSV with pandas
                audio_tags[filename] = df.iloc[0,3]  # label on 4th column
            except FileNotFoundError:
                print(f"Warning: CSV file {csv_filename} not found.")
                audio_tags[filename] = None
            except IndexError:
                print(f"Warning: CSV file {csv_filename} label not found.")
                audio_tags[filename] = None

    return stft_data, audio_data, audio_tags

In [19]:
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np


def plot_audio_stft(audio, sr, stft, output_filename):
    """
    Plots the waveform and STFT of a single audio file and saves the plot.

    Args:
        audio (np.ndarray): The audio time series.
        sr (int): The sampling rate of the audio file.
        stft (np.ndarray): The STFT matrix of the audio file.
        output_filename (str): The name of the output file for the plot.
    """

    # Calculate time axis for both plots
    time_axis = np.arange(len(audio)) / sr  # Time axis for waveform
    stft_time_axis = np.arange(stft.shape[1]) * (
        stft.shape[1] / sr
    )  # Time axis for STFT

    # Create a figure with two rows and shared x-axis
    fig, axs = plt.subplots(2, 1, figsize=(12, 6), sharex=True)

    # Plot the waveform in the first row
    axs[0].plot(time_axis, audio)
    axs[0].set_ylabel("Amplitude")
    axs[0].set_title("Waveform")

    # Plot the STFT in the second row
    librosa.display.specshow(
        librosa.amplitude_to_db(stft, ref=np.max),
        sr=sr,
        x_axis="time",
        y_axis="hz",
        cmap="viridis",
        ax=axs[1],
    )
    axs[1].set_xlabel("Time (s)")
    axs[1].set_ylabel("Frequency (Hz)")
    axs[1].set_title("STFT")

    # Layout so plots do not overlap
    fig.tight_layout()

    # Save the plot
    plt.savefig(os.path.join(STFT_PLOTS, f"{output_filename}.png"))
    plt.close()

In [20]:
import torch
from torch.utils.data import Dataset
import snntorch.spikegen as spikegen

class PreProcessAudioDataset(Dataset):
    def __init__(self, stft_data, audio_labels, encoding_type='rate', num_steps=50): 
        self.stft_data = stft_data
        self.audio_labels = audio_labels
        self.filenames = list(stft_data.keys())
        self.label_map = {"gun_shot": 0, "dog_bark": 1, "children_playing": 2}
        self.encoding_type = encoding_type
        self.num_steps = num_steps  # Number of time steps for encoding

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filename = self.filenames[idx]
        stft = self.stft_data[filename]
            
        # Pad/truncate in the frequency dimension
        num_freq_bins = stft.shape[1]
        max_freq_bins = 100
        if num_freq_bins < max_freq_bins:
            pad_amount = max_freq_bins - num_freq_bins
            stft = np.pad(stft, ((0, 0), (0, pad_amount)), 'constant')
        elif num_freq_bins > max_freq_bins:
            stft = stft[:, :max_freq_bins] 
            
        # Min-Max Normalization (before tensor conversion)
        min_val = stft.min()
        max_val = stft.max()
        stft = (stft - min_val) / (max_val - min_val) 

        # Convert to tensor
        stft = torch.from_numpy(stft).float()  

        # Spike Encoding
        if self.encoding_type == 'rate':
            spike_train = spikegen.rate(stft, num_steps=self.num_steps)
        elif self.encoding_type == 'latency':
            spike_train = spikegen.latency(stft, num_steps=self.num_steps)
        else:
            raise ValueError("Invalid encoding_type. Choose 'rate' or 'latency'")

        label = self.label_map[self.audio_labels[filename]] 
        return stft, label

In [None]:
"""
    Loading and spectogram
"""
import os
from torch.utils.data import DataLoader


# Loads the data, performs stft and returns the stft, audio and labels
train_stft_data, train_audio_data, train_audio_label = audio_to_stft(AUDIO_INPUT_TRAIN)
test_stft_data, test_audio_data, test_audio_label = audio_to_stft(AUDIO_INPUT_TEST)


# Plots the audio and stft spectogram
for filename in os.listdir(AUDIO_INPUT_TRAIN):
    # print(filename)
    if filename in train_stft_data:
        (audio, sr) = train_audio_data[filename]
        print(filename, train_audio_label[filename], train_stft_data[filename].shape)
        # plot_audio_stft(audio, sr, train_stft_data[filename], os.path.splitext(filename)[0])

7068.wav gun_shot (1025, 137)
76089.wav gun_shot (1025, 44)
52441.wav dog_bark (1025, 3208)
7913.wav dog_bark (1025, 358)
97331.wav children_playing (1025, 9740)
97331.mp3 children_playing (1025, 9739)
99500.wav children_playing (1025, 4878)


In [26]:
# preprocess
train_dataset = PreProcessAudioDataset(
    train_stft_data, 
    train_audio_label, 
    encoding_type='rate',  # or 'rate'
    num_steps=100  # Adjust the number of time steps as needed
)
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)

test_dataset = PreProcessAudioDataset(
    test_stft_data, 
    test_audio_label, 
    encoding_type='rate',
    num_steps=100 
)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)


In [36]:
# dataset shape
for train, label in train_dataset:
    print(train.shape)
    # break
# dataloader and input shape
for spike_train, label in train_dataloader:
    print(spike_train.shape)  # shape before flattening
    x = spike_train.view(spike_train.size(0), -1)  # flattened
    print(x.shape)  # shape after flattening 
    break
    

torch.Size([1025, 100])
torch.Size([1025, 100])
torch.Size([1025, 100])
torch.Size([1025, 100])
torch.Size([1025, 100])
torch.Size([1025, 100])
torch.Size([1025, 100])
torch.Size([2, 1025, 100])
torch.Size([2, 102500])


In [None]:
import torch.nn as nn
import snntorch as snn

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# print(device)
# Network Architecture
num_inputs = 1025*100
num_hidden = 100
num_outputs = 3

# Temporal Dynamics
num_steps = 2
beta = 0.95
# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)
        elf.fc3 = nn.Linear(num_hidden, num_outputs)
        self.lif3 = snn.Leaky(beta=beta)

    def forward(self, x):

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for step in range(num_steps):
            x = x.view(x.size(0), -1)
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# Load the network onto CUDA if available
net = Net().to(device)

In [9]:
def print_batch_accuracy(data, targets, train=False):
    output, _ = net(data.view(2, -1))
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer(
    data, targets, epoch,
    counter, iter_counter,
        loss_hist, test_loss_hist, test_data, test_targets):
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    print_batch_accuracy(data, targets, train=True)
    print_batch_accuracy(test_data, test_targets, train=False)
    print("\n")

In [42]:
num_epochs = 2
loss_hist = []
test_loss_hist = []
counter = 0
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))
batch_size = 2
dtype = torch.float

# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0
    train_batch = iter(train_dataloader)

    # Minibatch training loop
    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        net.train()
        spk_rec, mem_rec = net(data.view(batch_size, -1))

        # initialize the loss & sum over time
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for step in range(num_steps):
            loss_val += loss(mem_rec[step], targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            net.eval()
            test_data, test_targets = next(iter(test_dataloader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Test set forward pass
            test_spk, test_mem = net(test_data.view(batch_size, -1))

            # Test set loss
            test_loss = torch.zeros((1), dtype=dtype, device=device)
            for step in range(num_steps):
                test_loss += loss(test_mem[step], test_targets)
            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy
            if counter % 50 == 0:
                train_printer(
                    data, targets, epoch,
                    counter, iter_counter,
                    loss_hist, test_loss_hist,
                    test_data, test_targets)
            counter += 1
            iter_counter +=1

RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x51250 and 102500x100)