In [1]:
# load audio
# perform stft
# plot stfts
# spatiotemporal gen


In [None]:
import os
from typing import Dict, Optional, Tuple, List
import librosa
import numpy as np
import pandas as pd  # Import pandas for CSV reading
import matplotlib.pyplot as plt
import cv2
from brian2 import *

import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools


# Configuration Constants
FILETYPES = (".ogg", ".mp3", ".wav")
BASE_PATH = "./data/audio_dataset/"
AUDIO_INPUT_TEST = os.path.join(BASE_PATH, "test")
AUDIO_INPUT_TRAIN = os.path.join(BASE_PATH, "train")
STFT_PLOTS = "./data/stft_plots/"
SPATIOTEMPORAL_PLOTS = "./data/spatiotemporal_plots"

In [4]:
def audio_to_stft(audio: np.ndarray) -> np.ndarray:
    """
    Convert audio to Short-Time Fourier Transform (STFT).

    Args:
        audio (np.ndarray): Input audio time series

    Returns:
        np.ndarray: Magnitude of STFT
    """
    return np.abs(librosa.stft(audio))

def read_audio_label(dir: str, filename: str) -> Optional[str]:
    """
    Extract audio type tag from corresponding CSV file.

    Args:
        dir (str): Directory containing the CSV file
        filename (str): Name of the audio file

    Returns:
        Optional label from the CSV file
    """
    csv_filename = os.path.splitext(filename)[0] + ".csv"
    csv_filepath = os.path.join(dir, csv_filename)
    try:
        df = pd.read_csv(
            csv_filepath, names=["startTime", "endTime", "quantity", "label"]
        )
        return df.iloc[0, 3]  # label in 4th column
    except (FileNotFoundError, IndexError) as e:
        print(f"Warning: Unable to read label for {csv_filename}. Error: {e}")
    return None

In [5]:
import os
import librosa
import numpy as np
from typing import Dict, Optional, Tuple

def process_audio_dataset(
    data_path: str, 
    filetypes: Tuple[str, ...]
) -> Tuple[
    Dict[str, np.ndarray],  # audio dictionary 
    Dict[str, int],          # sampling rate dictionary
    Dict[str, np.ndarray],  # STFT data dictionary
    Dict[str, Optional[str]]  # audio labels dictionary
]:
    """
    Process audio files in a given directory.
    
    Args:
        data_path (str): Path to the directory containing audio files
        filetypes (Tuple[str, ...]): Allowed audio file extensions
    
    Returns:
        Tuple of dictionaries containing:
        - Audio time series
        - Sampling rates
        - STFT data
        - Audio labels
    """
    stft_data = {}
    audio_labels = {}
    sr = {}
    audio = {}
    
    for filename in os.listdir(data_path):
        file_path = os.path.join(data_path, filename)
        
        if file_path.endswith(filetypes):
            try:
                # Load audio
                audio[filename], sr[filename] = librosa.load(file_path, sr=None)
                
                # Compute STFT
                stft_data[filename] = np.abs(librosa.stft(audio[filename]))
                
                # Get label
                audio_labels[filename] = read_audio_label(data_path, filename)
            
            except Exception as e:
                print(f"Error processing {filename}: {e}")
    
    return audio, sr, stft_data, audio_labels

In [6]:
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np

def plot_audio_stft(audio, sr, stft, output_filename):
    """
    Plots the waveform and STFT of a single audio file and saves the plot.

    Args:
        stft (np.ndarray): The STFT matrix of the audio file.
        output_filename (str): The name of the output file for the plot.
    """


    # Create a figure with two rows and shared x-axis
    plt.plot(figsize=(12, 6))


    # Plot the STFT in the second row
    librosa.display.specshow(
        librosa.amplitude_to_db(stft, ref=np.max),
        sr=sr,
        cmap="viridis",
    )

    # Save the plot
    plt.savefig(os.path.join(STFT_PLOTS, f"{output_filename}.png"), bbox_inches='tight', pad_inches=0)
    plt.close()

In [28]:
def spatiotemporal_gen(stft_plot_path, audio_data, sr, output_filename):
    """
    Generates spatiotemporal plots using the STFT plot and audio data.

    Args:
        stft_plot_path (str): Path to the STFT plot image.
        audio_data (np.ndarray): The audio data.
        sr (int): The sample rate of the audio data.
        output_filename (str): The name of the output file for the spatiotemporal plot.
    """
    # Verify that the file exists
    if not os.path.exists(stft_plot_path):
        raise FileNotFoundError(f"The file {stft_plot_path} does not exist.")

    # Load the chirplet transform spectrogram PNG image using OpenCV
    duration = librosa.get_duration(y=audio_data, sr=sr)
    image = cv2.imread(stft_plot_path, cv2.IMREAD_GRAYSCALE)

    # Check if the image was loaded successfully
    if image is None:
        raise ValueError(f"Failed to load image at {stft_plot_path}. Please check the file.")

    image_data = image.astype(float) / 255.0
    height, width = image_data.shape

    time_duration = duration * second  # Use the actual audio duration
    num_neurons = height
    global_threshold = 0.09
    block_size = 19
    C = -5
    adaptive_thresh = cv2.adaptiveThreshold(
        (image_data * 255).astype(np.uint8), 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, block_size, C
    )
    adaptive_thresh = adaptive_thresh / 255.0

    times = []
    neurons = []
    for ex in range(width):  # ex and why used to not conflict with x and y
        for why in range(height):
            if image_data[why, ex] > global_threshold and adaptive_thresh[why, ex] > 0:
                time_point = ex / width * duration
                neuron_point = why
                times.append(time_point)
                neurons.append(neuron_point)

    times = np.array(times) * second
    neurons = np.array(neurons, dtype=int)

    G = NeuronGroup(num_neurons, 'v : 1', threshold='v>1', reset='v=0', method='exact')
    G.v = 0
    input_group = SpikeGeneratorGroup(num_neurons, indices=neurons, times=times)
    S = Synapses(input_group, G, on_pre='v_post += 1')
    S.connect(j='i')
    run(time_duration)

    # Create the spatiotemporal plot
    plt.figure(figsize=(12, 6))
    plt.scatter(times / second, neurons, s=1, c='blue', alpha=0.6)
    plt.title("Spatiotemporal Activity")
    plt.xlabel("Time (s)")
    plt.ylabel("Neuron Index")
    plt.tight_layout()
    plt.savefig(os.path.join(SPATIOTEMPORAL_PLOTS, f"{output_filename}.png"), bbox_inches='tight', pad_inches=0)
    plt.close()


In [19]:
"""
Main function to process train and test audio datasets
"""

directories = [AUDIO_INPUT_TRAIN, AUDIO_INPUT_TEST]

filenames = []
for directory in directories:
    filenames.extend([
        filename for filename in os.listdir(directory) 
        if filename.lower().endswith(FILETYPES)
])
    
# Process train dataset
train_audio, train_sr, train_stft_data, train_audio_labels = process_audio_dataset(
    AUDIO_INPUT_TRAIN, 
    FILETYPES
)

# Process test dataset
test_audio, test_sr, test_stft_data, test_audio_labels = process_audio_dataset(
    AUDIO_INPUT_TEST, 
    FILETYPES
)

# test
print(f"Processed {len(train_stft_data)} train audio files")
print(f"Processed {len(test_stft_data)} test audio files")

for file in filenames:
    if file in train_stft_data:
        plot_audio_stft(train_audio[file], train_sr[file], train_stft_data[file], os.path.splitext(file)[0])
    else:
        plot_audio_stft(test_audio[file], test_sr[file], test_stft_data[file], os.path.splitext(file)[0])


KeyboardInterrupt: 

In [29]:
for file in filenames:
    
    if file in train_stft_data:
        filename_without_ext = os.path.splitext(file)[0]
        filepath = os.path.join(STFT_PLOTS, filename_without_ext + '.png')
        spatiotemporal_gen(filepath, train_audio[file], train_sr[file], os.path.splitext(file)[0])
    else:
        spatiotemporal_gen(filepath, test_audio[file], test_sr[file], os.path.splitext(file)[0])

In [35]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import snntorch as snn
from snntorch import surrogate
from PIL import Image
import numpy as np

class CustomLabeledImageDataset(Dataset):
    def __init__(self, folder_path, labels, transform=None):
        """
        Custom dataset for loading images with provided labels
        
        Args:
            folder_path (string): Path to the folder with images
            labels (list): List of labels corresponding to images
            transform (callable, optional): Optional transform to be applied on a sample
        """
        self.image_paths = []
        self.labels = []
        
        # Supported image extensions
        img_extensions = ['.png', '.jpg', '.jpeg', '.bmp', '.gif']
        
        # Scan directory for image files
        filenames = os.listdir(folder_path)
        
        # Match images with labels
        for filename in filenames:
            if any(filename.lower().endswith(ext) for ext in img_extensions):
                # Find corresponding label
                matching_labels = [label for label in labels if filename.split('.')[0] in str(label)]
                
                if matching_labels:
                    self.image_paths.append(os.path.join(folder_path, filename))
                    self.labels.append(matching_labels[0])
        
        # Convert labels to categorical
        unique_labels = sorted(set(self.labels))
        self.label_to_index = {label: idx for idx, label in enumerate(unique_labels)}
        self.labels = [self.label_to_index[label] for label in self.labels]
        
        self.transform = transform
        
        print(f"Loaded {len(self.image_paths)} images")
        print(f"Unique labels: {unique_labels}")
        print(f"Label mapping: {self.label_to_index}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Open image
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('L')  # Convert to grayscale
        
        # Apply transforms if any
        if self.transform:
            image = self.transform(image)
        
        return image, self.labels[idx]

# Hyperparameters
batch_size = 2  # Adjusted for small dataset
data_path = './data/spatiotemporal_plots'  # Update this to your actual path
learning_rate = 1e-4
num_epochs = 50

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data transformations with augmentation
transform = transforms.Compose([
    transforms.Resize((28, 28)),  # Resize to consistent size
    transforms.RandomRotation(10),  # Light data augmentation
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalization
])

# Create datasets
# TODO: Replace these with your actual labels from preprocessing
train_dataset = CustomLabeledImageDataset(
    os.path.join(data_path, 'train'), 
    train_audio_labels,  # Use your preprocessed labels 
    transform=transform
)

test_dataset = CustomLabeledImageDataset(
    os.path.join(data_path, 'test'), 
    test_audio_labels,  # Use your preprocessed labels
    transform=transform
)

# Create data loaders
train_loader = DataLoader(
    train_dataset, 
    batch_size=batch_size, 
    shuffle=True,
    num_workers=0
)

test_loader = DataLoader(
    test_dataset, 
    batch_size=batch_size, 
    shuffle=False,
    num_workers=0
)

# Number of classes (inferred from dataset)
num_classes = len(train_dataset.label_to_index)

# Simplified Spiking Neural Network Model
class SmallDatasetSpikingNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        # Leaky neuron parameters
        beta = 0.5
        spike_grad = surrogate.sigmoid(slope=10)

        # Simplified network architecture
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        
        self.pool = nn.MaxPool2d(2)
        
        # Reduced fully connected layers
        self.fc1 = nn.Linear(8 * 14 * 14, 32)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x):
        # Tracking membrane potential
        spk_out_list = []
        
        # Convolution and spiking layer
        x = self.pool(self.conv1(x))
        x, mem1 = self.lif1(x)
        spk_out_list.append(x)
        
        # Flatten
        x = x.view(x.size(0), -1)
        
        # Fully connected layers
        x = self.fc1(x)
        x, mem2 = self.lif2(x)
        spk_out_list.append(x)
        
        # Final layer
        x = self.fc2(x)
        
        return x, spk_out_list

# Initialize the model
net = SmallDatasetSpikingNet(num_classes=num_classes).to(device)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=1e-5)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='min', 
    factor=0.5, 
    patience=5, 
    verbose=True
)

# Training Loop
def train(net, train_loader, num_epochs):
    net.train()
    for epoch in range(num_epochs):
        total_loss = 0
        correct = 0
        total = 0
        
        for batch_idx, (data, targets) in enumerate(train_loader):
            data, targets = data.to(device), targets.to(device)
            
            # Zero the parameter gradients
            optimizer.zero_grad()
            
            # Forward pass
            output, _ = net(data)
            loss = criterion(output, targets)
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()
            
            # Compute accuracy
            _, predicted = output.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
            total_loss += loss.item()
        
        # Step the scheduler
        scheduler.step(total_loss)
        
        print(f'Epoch [{epoch+1}/{num_epochs}], '
              f'Loss: {total_loss/len(train_loader):.4f}, '
              f'Accuracy: {100.*correct/total:.2f}%')

# Evaluation Function
def evaluate(net, test_loader):
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)
            
            output, _ = net(data)
            _, predicted = output.max(1)
            
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    print(f'Test Accuracy: {100.*correct/total:.2f}%')

# Run training and evaluation
train(net, train_loader, num_epochs)
evaluate(net, test_loader)

# Optional: Save the model
# torch.save(net.state_dict(), 'snn_small_dataset_model.pth')

# Print dataset information
print(f"Total training images: {len(train_dataset)}")
print(f"Total test images: {len(test_dataset)}")
print(f"Number of classes: {num_classes}")
print("Label mapping:", train_dataset.label_to_index)



Loaded 6 images
Unique labels: ['52441.wav', '7068.wav', '76089.wav', '7913.wav', '97331.wav', '99500.wav']
Label mapping: {'52441.wav': 0, '7068.wav': 1, '76089.wav': 2, '7913.wav': 3, '97331.wav': 4, '99500.wav': 5}
Loaded 3 images
Unique labels: ['20571.wav', '7067.wav', '97317.wav']
Label mapping: {'20571.wav': 0, '7067.wav': 1, '97317.wav': 2}


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.