In [1]:
# imports
import os
import h5py

import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen
from snntorch import utils

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import snntorch.spikeplot as splt
import numpy as np
import itertools
import tables

In [2]:
# Dataset layout
# The SHD HDF5 files (e.g. "hdspikes/shd_train.h5") are organised as:
#   root
#   |-spikes
#   |   |-times[]
#   |   |-units[]
#   |-labels[]
#   |-extra
#       |-speaker[]
#       |-keys[]
#       |-meta_info
#           |-gender[]
#           |-age[]
#           |-body_height[]
# CustomHDF5Dataset reads spikes/units, spikes/times and labels with h5py.

# Training parameters
batch_size = 32

# Torch variables
dtype = torch.float
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class CustomHDF5Dataset(Dataset):
    """Map-style dataset over an SHD-style HDF5 file.

    Reads ``spikes/units``, ``spikes/times`` and ``labels`` from the file. Each
    sample is returned as ``(unit_data, time_data, label)``, where both data
    vectors are zero-padded or truncated to a fixed length. This lets the
    default DataLoader collate function stack samples into a batch.

    Args:
        file_path: Path to the HDF5 file (opened read-only).
        target_length: Length that every ``unit_data`` / ``time_data`` vector
            is padded/truncated to. The default of 700 matches ``num_inputs``
            used by the network, which concatenates the two vectors and so
            expects ``2 * target_length`` input features.

    NOTE(review): ``units`` and ``times`` look like parallel per-spike event
    arrays. Their length would then be the spike count of a sample, not the
    number of input channels. Padding/truncating to ``target_length`` drops
    every event past the first ``target_length``; confirm that is intended.
    """

    # Class-level default so instances built without __init__ still have it.
    target_length = 700

    def __init__(self, file_path, target_length=700):
        self.file_path = file_path
        self.target_length = target_length
        self.fileh = h5py.File(file_path, 'r')
        self.units = self.fileh['spikes']['units']
        self.times = self.fileh['spikes']['times']
        self.labels = self.fileh['labels']

    def __len__(self):
        return len(self.labels)

    @staticmethod
    def _fit_length(data, target_length):
        """Zero-pad (at the end) or truncate a 1-D tensor to ``target_length``."""
        length = data.size(0)
        if length < target_length:
            return torch.nn.functional.pad(data, (0, target_length - length), mode='constant', value=0)
        return data[:target_length]

    def __getitem__(self, index):
        """Return ``(unit_data, time_data, label)`` for sample ``index``.

        ``unit_data`` and ``time_data`` are float32 tensors of shape
        ``(target_length,)``; ``label`` is an int64 tensor.
        """
        # Convert through NumPy first: some PyTorch versions cannot build a
        # tensor directly from unsigned integer arrays (e.g. uint16) that
        # HDF5 datasets may hold.
        unit_data = torch.tensor(np.asarray(self.units[index], dtype=np.float32))
        time_data = torch.tensor(np.asarray(self.times[index], dtype=np.float32))
        label = torch.tensor(np.asarray(self.labels[index], dtype=np.int64), dtype=torch.long)

        # Fixed length so samples with different spike counts can be batched.
        unit_data = self._fit_length(unit_data, self.target_length)
        time_data = self._fit_length(time_data, self.target_length)

        return unit_data, time_data, label

    def close(self):
        """Close the underlying HDF5 file handle."""
        self.fileh.close()

In [9]:
# Make sure you have correct file path
train_dataset = CustomHDF5Dataset("hdspikes/shd_train.h5")
test_dataset = CustomHDF5Dataset("hdspikes/shd_test.h5")
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Shuffling does not help evaluation; a fixed order keeps test metrics
# reproducible from run to run.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# from https://github.com/fzenke/spytorch/blob/main/notebooks/SpyTorchTutorial4.ipynb
# Network Architecture Parameters
num_inputs = 700   # length of each of unit_data / time_data; must match the dataset's padded length (700)
num_hidden = 100
num_outputs = 20   # one output neuron per class scored by the cross-entropy loss

# Temporal Dynamics
num_steps = 25     # number of steps the network is unrolled for in Net.forward
time_step = 1e-3
tau_mem = 10e-3
tau_syn = 5e-3
alpha = float(np.exp(-time_step/tau_syn))  # synaptic decay; not used by the Leaky-only Net
beta = float(np.exp(-time_step/tau_mem))   # membrane decay used by Net's Leaky layers
# NOTE(review): this standalone neuron (hard-coded beta=0.9) is not used by Net,
# which builds its own Leaky layers with `beta` above. Kept only so any existing
# references to `lif1` keep working.
lif1 = snn.Leaky(beta=0.9)

In [25]:
# Define Network
class Net(nn.Module):
    """Two-layer fully connected spiking network built from snntorch Leaky neurons.

    The per-sample ``unit_data`` and ``time_data`` vectors are concatenated and
    presented as the same (static) input at every one of ``num_steps`` steps.
    Layer sizes, the step count and the membrane decay are read from the
    module-level ``num_inputs``, ``num_hidden``, ``num_outputs``, ``num_steps``
    and ``beta``.
    """

    def __init__(self):
        super().__init__()
        combined_input_size = 2 * num_inputs  # since we're concatenating unit_data and time_data

        self.fc1 = nn.Linear(combined_input_size, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, unit_data, time_data):
        """Simulate the network for ``num_steps`` steps.

        Args:
            unit_data: Presumably ``(batch, num_inputs)``; concatenated with
                ``time_data`` along dim 1, and the result must have
                ``2 * num_inputs`` features for ``fc1``.
            time_data: Presumably ``(batch, num_inputs)``.

        Returns:
            ``(spk2_rec, mem2_rec)``: the output layer's spikes and membrane
            potentials, stacked along a new leading time dimension
            (``(num_steps, batch, num_outputs)``).
        """
        # Concatenate unit_data and time_data
        x = torch.cat((unit_data, time_data), dim=1)

        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # The input is identical at every step, so its projection is too.
        # Compute it once rather than once per step.
        cur1 = self.fc1(x)

        spk2_rec = []
        mem2_rec = []

        for _ in range(num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)


net = Net().to(device)


In [51]:
# Objective: cross-entropy. The training cells apply it to each step's output
# membrane potential (treated as logits) and sum the result over steps.
loss = nn.CrossEntropyLoss()

# Adam over every trainable parameter of `net`.
optimizer = torch.optim.Adam(
    net.parameters(),
    lr=5e-4,
    betas=(0.9, 0.999),
)

In [67]:
def print_accuracy(mem_rec, targets, num_steps, batch_size):
    """Print (and return) the per-time-step classification accuracy of a batch.

    At each of the first ``num_steps`` steps, the predicted class is the argmax
    of the membrane potentials over the class dimension. The accuracy is the
    fraction of (step, sample) pairs whose prediction equals ``targets``.

    Args:
        mem_rec: Tensor of shape ``(steps, batch, num_classes)``.
        targets: Tensor of shape ``(batch,)`` with integer class labels.
        num_steps: Number of leading time steps to score.
        batch_size: Nominal batch size, kept for backward compatibility. The
            denominator uses the actual number of (step, sample) predictions,
            so a smaller final batch is not under-reported.

    Returns:
        The accuracy as a float in [0, 1] (0.0 for an empty batch).
    """
    # (steps, batch) matrix of predicted class indices.
    predicted_labels = mem_rec[:num_steps].argmax(dim=-1)
    total = predicted_labels.numel()
    # targets (batch,) broadcasts across the step dimension.
    correct = (predicted_labels == targets).sum().item()
    accuracy = correct / total if total else 0.0

    print(f"Accuracy on this batch: {accuracy:.2%}")
    return accuracy

In [68]:
# First get first batch of data
# Draws one (shuffled) batch of (unit_data, time_data, targets) from the
# training loader. `data_iter` stays in scope, so next(data_iter) can pull the
# following batches.
data_iter = iter(train_loader)
unit_data, time_data, targets = next(data_iter)

In [69]:
# Collapse every dimension after the batch axis and move the batch onto `device`.
unit_data, time_data = (t.view(len(t), -1).to(device) for t in (unit_data, time_data))
targets = targets.to(device)

# Full forward pass; mem_rec is stacked over time steps (see the printed size).
spk_rec, mem_rec = net(unit_data, time_data)
print(mem_rec.size())

torch.Size([25, 32, 20])


In [74]:
# Loss: cross-entropy between each step's output membrane potential (used as
# logits) and the targets, summed over all num_steps steps.
loss_val = torch.zeros((1), dtype=dtype, device=device)
for step in range(num_steps):
    loss_val += loss(mem_rec[step], targets)

print(f"Training loss: {loss_val.item():.3f}")

# Accuracy over every (step, sample) pair. print_accuracy does the counting;
# pass the real batch size so a smaller batch is scored correctly.
print_accuracy(mem_rec, targets, num_steps, targets.size(0));

Training loss: 83.497
Accuracy on this batch: 2.88%


In [75]:
# One optimisation step on the loss of the current batch (loss_val from the
# previous cell). The order matters: zero, backpropagate, then update.
# clear previously stored gradients
optimizer.zero_grad()

# calculate the gradients
# By default backward() releases loss_val's autograd graph, so a fresh forward
# pass is needed before backward() can be called again.
loss_val.backward()

# weight update
optimizer.step()

In [76]:
# now rerun after a single iteration
# Re-evaluate the SAME batch after one weight update. Use the batch's own size
# rather than `batch_size`, so a smaller (e.g. final) batch still reshapes correctly.
spk_rec, mem_rec = net(unit_data.view(unit_data.size(0), -1), time_data.view(time_data.size(0), -1))

# Initialize the total loss value again for the new outputs
loss_val = torch.zeros((1), dtype=dtype, device=device)

# Sum loss at every step for the new outputs
for step in range(num_steps):
    loss_val += loss(mem_rec[step], targets)

print(f"Training loss: {loss_val.item():.3f}")
print_accuracy(mem_rec, targets, num_steps, targets.size(0));



Training loss: 81.158
Accuracy on this batch: 2.75%


In [115]:
# One more training step on the SAME batch, then re-evaluate that batch.
# Each run lowers the loss and raises this batch's accuracy (in one session,
# 10 runs took it from 9% to 11%). That only shows the network fitting this
# single batch; it says nothing about accuracy on held-out data.
# NOTE(review): backward() below needs loss_val to carry a fresh graph from the
# last forward pass (the re-run cell above, or this cell's previous run).

# Clear previously stored gradients
optimizer.zero_grad()

# Calculate the gradients
loss_val.backward()

# Weight update
optimizer.step()

# Use the batch's own size rather than `batch_size`, so a smaller batch still reshapes correctly.
spk_rec, mem_rec = net(unit_data.view(unit_data.size(0), -1), time_data.view(time_data.size(0), -1))

# Initialize the total loss value again for the new outputs
loss_val = torch.zeros((1), dtype=dtype, device=device)

# Sum loss at every step for the new outputs
for step in range(num_steps):
    loss_val += loss(mem_rec[step], targets)

print(f"Training loss: {loss_val.item():.3f}")
print_accuracy(mem_rec, targets, num_steps, targets.size(0));

Training loss: 65.492
Accuracy on this batch: 14.00%
