In [25]:
import librosa
import random
import numpy as np
import pandas as pd
import json
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import wave
import os 
from scipy.signal import find_peaks
import glob
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.transforms as T
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

In [26]:
# Layer / data parameters
batch_size = 1
n_mels = 80          # number of Mel bands; also used as the network input width

num_hidden = 20      # width of the single spiking layer defined in Net
num_outputs = 2      # NOTE(review): not referenced by any visible cell — confirm it is still needed
beta = 0.9           # value passed to snn.Leaky(beta=...)
num_inputs = n_mels

num_steps = 10       # NOTE(review): not referenced by any visible cell — confirm it is still needed

# Seed every random number generator the notebook can draw from so that a
# Restart & Run All produces the same weight initialisation and shuffling.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


cpu


In [27]:


def load_mel_spectrogram(file_path, sr=22050/16, n_mels=n_mels, duration=None):
    """Load an audio file and return its Mel spectrogram as a NumPy array.

    Parameters
    ----------
    file_path : str
        Path of the audio file; read with ``torchaudio.load``.
    sr : float or None
        Target sample rate in Hz. It is rounded to the nearest integer before
        use because ``T.Resample`` is documented to take integer frequencies.
        When ``None`` the file's native sample rate is kept.
    n_mels : int
        Number of Mel bands passed to ``T.MelSpectrogram``.
    duration : float or None
        If given, only the first ``duration`` seconds of audio are kept.

    Returns
    -------
    numpy.ndarray or None
        2-D array (Mel bands on axis 0, frames on axis 1), or ``None`` when a
        ``RuntimeError`` is raised while loading or transforming the file.
    """
    try:
        waveform, native_sr = torchaudio.load(file_path)

        # Downmix multi-channel audio so the result is always 2-D; callers
        # read shape[1] as the number of frames.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        if sr is None:
            # No resampling requested: keep working at the file's own rate.
            target_sr = native_sr
        else:
            target_sr = int(round(sr))
            if target_sr != native_sr:
                resampler = T.Resample(orig_freq=native_sr, new_freq=target_sr)
                waveform = resampler(waveform)

        if duration is not None:
            waveform = waveform[:, :int(target_sr * duration)]

        mel_transform = T.MelSpectrogram(sample_rate=target_sr, n_mels=n_mels)
        mel_spectrogram = mel_transform(waveform)

        # Drop only the channel axis; a bare squeeze() would also remove a
        # frame axis of length one.
        return mel_spectrogram.squeeze(0).numpy()
    except RuntimeError:
        print(f"Error: Failed to load audio from {file_path}")
        return None




def latency_coding(audio_signal, threshold=0.5, duration=50):
    """Encode a 1-D signal as latency-coded spike trains.

    The signal is min-max normalised to [0, 1]; each element then fires a
    single spike at time step ``int((1 - value) * duration)``, so larger
    values fire earlier. Elements whose spike time equals ``duration`` (the
    minimum of the signal) never fire.

    Parameters
    ----------
    audio_signal : torch.Tensor or array-like
        1-D sequence of intensities.
    threshold : float
        Unused; kept so existing callers that pass it keep working.
    duration : int
        Number of time steps in the generated spike trains.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(duration, len(audio_signal))`` holding 0.0 / 1.0.
        An empty or constant signal yields an all-zero tensor.
    """
    signal = audio_signal if isinstance(audio_signal, torch.Tensor) else torch.tensor(audio_signal)

    num_values = len(signal)
    spike_trains = torch.zeros(duration, num_values)
    if num_values == 0:
        return spike_trains

    lowest = signal.min()
    value_range = signal.max() - lowest
    if value_range == 0:
        # A constant signal has no intensity contrast to encode; normalising
        # it would divide by zero and cast NaN to an integer index.
        return spike_trains

    normalized = (signal - lowest) / value_range

    # Spike time is inversely related to intensity; .long() truncates.
    spike_times = ((1 - normalized) * duration).long()

    # Set every spike in one indexed assignment instead of a Python loop.
    fires = spike_times < duration
    columns = torch.arange(num_values, device=spike_times.device)
    # spike_trains is created on the CPU, so index it with CPU tensors.
    spike_trains[spike_times[fires].cpu(), columns[fires].cpu()] = 1.0

    return spike_trains



def extract_label_from_filename(filename):
    """Return the label encoded in a path: the part of its base name before the first dot."""
    prefix, _, _ = os.path.basename(filename).partition('.')
    return prefix






class AudioDataset(Dataset):
    """Dataset of latency-coded Mel spectrograms with integer class labels.

    Every file is loaded once in ``__init__`` through ``load_mel_spectrogram``;
    files for which that returns ``None`` are dropped. Labels come from
    ``extract_label_from_filename`` and are mapped to integers in sorted
    order, so the mapping is identical on every run.

    Parameters
    ----------
    audio_files : iterable of str
        Paths of the audio files.
    sr : int
        Sample rate forwarded to ``load_mel_spectrogram``.
    threshold : float
        Forwarded to ``latency_coding``.
    duration : int
        Forwarded to ``load_mel_spectrogram`` as its ``duration`` argument.
        NOTE(review): ``__getitem__`` does not forward it to
        ``latency_coding``, which therefore uses its own default — confirm
        which meaning is intended.
    n_mels : int
        Number of Mel bands forwarded to ``load_mel_spectrogram``.

    Raises
    ------
    ValueError
        If none of the files could be loaded.
    """

    def __init__(self, audio_files, sr=22050, threshold=0.5, duration=50, n_mels=40):
        self.audio_files = audio_files
        self.sr = sr
        self.threshold = threshold
        self.duration = duration
        self.n_mels = n_mels

        # Load each file once and keep only those that loaded successfully.
        self.audio_data = []
        for path in audio_files:
            mel = load_mel_spectrogram(path, sr, n_mels, duration)
            if mel is not None:
                self.audio_data.append((path, mel))

        if not self.audio_data:
            raise ValueError("AudioDataset: none of the given audio files could be loaded")

        self.labels = [extract_label_from_filename(path) for path, _ in self.audio_data]

        # Shortest spectrogram along axis 1; every item is trimmed to it so
        # that all samples have the same size.
        self.min_length = min(mel.shape[1] for _, mel in self.audio_data)
        print(f"Length of the smallest audio file (in Mel bins): {self.min_length}")

        self.label_to_int, self.int_to_label, self.encoded_labels = \
            self._build_label_maps(self.labels)

    @staticmethod
    def _build_label_maps(labels):
        """Return (label->int, int->label, encoded labels) using sorted label order."""
        label_to_int = {label: index for index, label in enumerate(sorted(set(labels)))}
        int_to_label = {index: label for label, index in label_to_int.items()}
        encoded_labels = [label_to_int[label] for label in labels]
        return label_to_int, int_to_label, encoded_labels

    def __len__(self):
        """Number of successfully loaded audio files."""
        return len(self.audio_data)

    def __getitem__(self, idx):
        """Return ``(spike_trains, encoded_label)`` for the item at ``idx``."""
        _, mel_spectrogram = self.audio_data[idx]

        # Trim to the shortest spectrogram in the dataset.
        mel_spectrogram = mel_spectrogram[:, :self.min_length]

        # The spectrogram is flattened before encoding, so each element gets
        # its own spike train.
        spike_trains = latency_coding(mel_spectrogram.flatten(), self.threshold)
        encoded_label = self.encoded_labels[idx]
        return spike_trains, encoded_label



In [28]:
# Collect every .wav file in the data folder. The list is sorted because
# glob returns files in an OS-dependent order, which would otherwise change
# dataset indexing between runs.
folder_path = "snnTorch_audio"
audio_files = sorted(glob.glob(os.path.join(folder_path, "*.wav")))
if not audio_files:
    raise FileNotFoundError(f"No .wav files found in {folder_path!r}")

# NOTE(review): AudioDataset is built with its default n_mels=40, while the
# network input width is num_inputs = n_mels = 80 — confirm whether
# n_mels=n_mels should be passed here.
dataset = AudioDataset(audio_files)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)






Length of the smallest audio file (in Mel bins): 3309


In [29]:

# Network definition
class Net(nn.Module):
    """One linear projection followed by one leaky spiking layer.

    ``forward`` expects a 3-D tensor with time on axis 1 (it reads
    ``x[:, t, :]``) and returns the spikes and membrane potentials of the
    spiking layer, each stacked along a new leading time axis.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)

    def forward(self, x):
        # Fresh membrane state for every forward pass.
        membrane = self.lif1.init_leaky()

        spike_history, membrane_history = [], []

        # The number of time steps is taken from the input itself.
        for t in range(x.size(1)):
            spikes, membrane = self.lif1(self.fc1(x[:, t, :]), membrane)
            spike_history.append(spikes)
            membrane_history.append(membrane)

        return torch.stack(spike_history, dim=0), torch.stack(membrane_history, dim=0)

# Instantiate the network on the selected device
net = Net().to(device)



In [30]:
# pass data into the network, sum the spikes over time
# and compare the neuron with the highest number of spikes
# with the target

def print_batch_accuracy(data, targets, train=False):
    """Print the accuracy of the module-level ``net`` on one minibatch.

    The predicted class is the output neuron with the most spikes summed
    over time.

    Parameters
    ----------
    data : torch.Tensor
        Input batch; regrouped to ``(batch, steps, num_inputs)`` before the
        forward pass.
    targets : torch.Tensor
        Integer class labels, one per batch element.
    train : bool
        Selects the "Train"/"Test" wording of the printed message.
    """
    # Net.forward reads x[:, step, :], so the input must be 3-D; a flat
    # (batch, -1) view would raise an IndexError there.
    with torch.no_grad():  # evaluation only, no autograd graph needed
        output, _ = net(data.reshape(data.size(0), -1, num_inputs))
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer():
    """Print the current epoch/iteration, the latest train and test losses,
    and the single-minibatch train and test accuracies.

    Takes no arguments; it reads the module-level names ``epoch``,
    ``iter_counter``, ``loss_hist``, ``counter``, ``test_loss_hist``,
    ``data``, ``targets``, ``test_data`` and ``test_targets``.

    NOTE(review): ``iter_counter``, ``loss_hist``, ``counter``,
    ``test_loss_hist``, ``test_data`` and ``test_targets`` are not assigned
    in any visible cell, and no visible cell calls this function — confirm
    whether it is still needed.
    """
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    print_batch_accuracy(data, targets, train=True)
    print_batch_accuracy(test_data, test_targets, train=False)
    print("\n")

In [31]:
from torch.nn.utils import clip_grad_norm_
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

# Smoke test: push a single batch through the network and check the shapes.
# The batch is fetched once only, since every fetch re-encodes the sample.
data, targets = next(iter(dataloader))

# Regroup the flat spike data into (batch, steps, num_inputs), which is the
# layout Net.forward indexes.
flattened_size = data.numel()
print(flattened_size)
expected_num_steps = flattened_size // (data.size(0) * num_inputs)
print(expected_num_steps)
data = data.view(data.size(0), expected_num_steps, num_inputs)

data = data.to(device)
targets = targets.to(device)

# Shape check only, so skip building the autograd graph over all time steps.
with torch.no_grad():
    spk_rec, mem_rec = net(data)
print(mem_rec.size())

6618000
82725
torch.Size([82725, 1, 20])


In [32]:
# Training parameters
num_epochs = 10

# Loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

# Training loop
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for batch_idx, (data, targets) in enumerate(dataloader):
        # Move data to device
        data, targets = data.to(device), targets.to(device)

        # Regroup the flat spike data into (batch, steps, num_inputs). The
        # batch dimension comes from the tensor itself so that a smaller
        # final batch still reshapes correctly.
        flattened_size = data.numel()
        expected_num_steps = flattened_size // (data.size(0) * num_inputs)
        data = data.view(data.size(0), expected_num_steps, num_inputs)

        # Forward pass; outputs is stacked with time on axis 0
        outputs, _ = net(data)

        # Per-neuron spike counts over time are used as the class scores
        outputs_sum = outputs.sum(dim=0)

        # Calculate the loss
        loss = loss_fn(outputs_sum, targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()

        # Clip the gradient norm
        clip_grad_norm_(net.parameters(), max_norm=1)

        optimizer.step()

        epoch_loss += loss.item()

    # Print average loss for the epoch
    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

Epoch [1/10], Loss: 571.4331
Epoch [2/10], Loss: 76.5000
Epoch [3/10], Loss: 46.5000
Epoch [4/10], Loss: 15.5001
Epoch [5/10], Loss: 23.3018
Epoch [6/10], Loss: 34.5000
Epoch [7/10], Loss: 33.8313
Epoch [8/10], Loss: 16.0321
Epoch [9/10], Loss: 22.2000
Epoch [10/10], Loss: 41.2000
