In [23]:
import librosa
import random
import numpy as np
import pandas as pd
import json
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import wave
import os 
from scipy.signal import find_peaks
import glob
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio

In [24]:

def load_audio(file_path, sr=22050, duration=None):
    """Load an audio file as a mono, 1-D NumPy array.

    Args:
        file_path: Path of the audio file, handed to ``torchaudio.load``.
        sr: Target sample rate in Hz. ``None`` keeps the file's native rate.
        duration: Maximum clip length in seconds; anything beyond it is cut
            off. ``None`` keeps the whole clip.

    Returns:
        A 1-D ``numpy.ndarray`` of samples, or ``None`` if loading or
        resampling raised ``RuntimeError`` (a message is printed in that case).
    """
    try:
        waveform, orig_sr = torchaudio.load(file_path)

        # Fall back to the native rate so the duration maths below never
        # multiplies by None.
        target_sr = orig_sr if sr is None else sr

        # Resample only when the rates actually differ.
        if target_sr != orig_sr:
            resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)
            waveform = resampler(waveform)

        if duration is not None:
            waveform = waveform[:, :int(target_sr * duration)]

        # Average the channels instead of squeeze(): squeeze() leaves stereo
        # input 2-D, which makes len() report the channel count downstream.
        return waveform.mean(dim=0).numpy()
    except RuntimeError:
        print(f"Error: Failed to load audio from {file_path}")
        return None




def latency_coding(audio_signal, threshold=0.5, duration=50):
    """Latency-encode a 1-D signal: stronger samples spike earlier.

    The signal is min-max normalised to [0, 1]; a sample with normalised
    value ``v`` spikes once at time step ``int((1 - v) * duration)``. Samples
    whose spike time equals ``duration`` (the minimum of the signal) never
    spike.

    Args:
        audio_signal: 1-D sequence, NumPy array or ``torch.Tensor`` of samples.
        threshold: Currently unused; kept so existing callers keep working.
        duration: Number of time steps in the generated spike train.

    Returns:
        A float tensor of shape ``(duration, len(audio_signal))`` holding 1.0
        where a sample spikes and 0.0 elsewhere. Empty or constant signals
        yield an all-zero tensor.
    """
    signal = audio_signal if isinstance(audio_signal, torch.Tensor) else torch.as_tensor(audio_signal)

    num_samples = len(signal)
    spike_trains = torch.zeros(duration, num_samples)

    # min()/max() raise on an empty tensor; there is nothing to encode anyway.
    if num_samples == 0:
        return spike_trains

    lowest, highest = signal.min(), signal.max()

    # A flat signal has zero range: normalising would divide by zero and turn
    # every spike time into NaN. Treat it as carrying no spikes.
    if highest == lowest:
        return spike_trains

    normalized = (signal - lowest) / (highest - lowest)

    # High intensity -> early spike. Truncate to an integer time step and keep
    # the indices on the CPU, where spike_trains lives.
    spike_times = ((1 - normalized) * duration).long().cpu()

    # Vectorised scatter: one spike per sample whose time falls inside the
    # window. The lower bound also discards garbage indices from NaN input.
    sample_idx = torch.arange(num_samples)
    fires = (spike_times >= 0) & (spike_times < duration)
    spike_trains[spike_times[fires], sample_idx[fires]] = 1.0

    return spike_trains



def extract_label_from_filename(filename):
    """Return the portion of the file's base name that precedes its first dot."""
    stem, _, _ = os.path.basename(filename).partition('.')
    return stem






class AudioDataset(Dataset):
    """Audio clips served as latency-coded spike trains with filename labels.

    All files are loaded eagerly in ``__init__``; files for which
    ``load_audio`` returns ``None`` are dropped. Every item is trimmed to the
    length of the shortest loaded clip so the encoded items share one shape.

    Args:
        audio_files: Iterable of audio file paths.
        sr: Sample rate forwarded to ``load_audio``.
        threshold: Forwarded to ``latency_coding``.
        duration: Forwarded to both ``load_audio`` and ``latency_coding``.

    Raises:
        ValueError: If none of the given files could be loaded.
    """

    def __init__(self, audio_files, sr=22050, threshold=0.5, duration=50):
        self.audio_files = audio_files
        self.sr = sr
        self.threshold = threshold
        self.duration = duration

        # NOTE(review): `duration` is used here as a clip length in seconds
        # (load_audio) and in __getitem__ as the number of spike-train time
        # steps (latency_coding). Confirm that sharing one value is intended.
        loaded = ((path, load_audio(path, sr, duration)) for path in audio_files)
        # Keep only the files that actually loaded.
        self.audio_data = [(path, signal) for path, signal in loaded if signal is not None]

        # Fail with a clear message instead of min()'s
        # "arg is an empty sequence" further down.
        if not self.audio_data:
            raise ValueError("No audio files could be loaded; cannot build an AudioDataset.")

        self.labels = [extract_label_from_filename(path) for path, _ in self.audio_data]

        # Stable label <-> integer mapping for callers that need numeric targets.
        self.classes = sorted(set(self.labels))
        self.class_to_idx = {name: index for index, name in enumerate(self.classes)}

        # Determine the length of the smallest audio file
        self.min_length = min(len(signal) for _, signal in self.audio_data)
        print(f"Length of the smallest audio file: {self.min_length}")

    def __len__(self):
        """Return the number of successfully loaded clips."""
        return len(self.audio_data)

    def __getitem__(self, idx):
        """Return ``(spike_trains, label)`` for the clip at position ``idx``.

        ``spike_trains`` is the output of ``latency_coding`` on the clip
        trimmed to ``self.min_length`` samples; ``label`` is the string taken
        from the file name.
        """
        _, audio_signal = self.audio_data[idx]

        # Trim the audio signal to the length of the smallest file
        audio_signal = audio_signal[:self.min_length]

        spike_trains = latency_coding(audio_signal, self.threshold, self.duration)
        # Labels were already parsed in __init__ in the same order as audio_data.
        return spike_trains, self.labels[idx]


In [25]:
# Collect every .wav file in the data folder. glob returns paths in arbitrary
# order, so sort them to keep the dataset's index -> file mapping stable
# between runs and machines.
folder_path = "snnTorch_audio"
audio_files = sorted(glob.glob(os.path.join(folder_path, "*.wav")))

# Loading happens eagerly inside AudioDataset; unreadable files are reported
# and skipped there.
dataset = AudioDataset(audio_files)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)



Error: Failed to load audio from snnTorch_audio/jazz.00054.wav
Length of the smallest audio file: 660000


In [None]:
# Placeholder pass over the loader: unpack each batch and do nothing with it yet.
for batch in dataloader:
    # Use spikes and labels in your training loop
    spikes, labels = batch