In [7]:
# Tonic packages
import tonic
from tonic import DiskCachedDataset
from tonic.dataset import Dataset
from tonic.download_utils import extract_archive
from tonic.io import make_structured_array

# PyTorch packages
import torch
import torchvision as tv
import torch.nn as nn
import tonic.transforms as transforms
from torch.utils.data import DataLoader #

# SNN PyTorch extension packages
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import spikeplot as splt
from snntorch import utils

# Visualization packages
from IPython.display import HTML 
import matplotlib.pyplot as plt
import seaborn as sns

# Other packages
import numpy as np
import scipy.io as scio
import os, time
from typing import Any, Callable, Optional, Tuple

In [None]:
class ASLDVS(Dataset):
    """`ASL-DVS <https://github.com/PIX2NVS/NVS2Graph>`_ event-camera letter dataset.

    Every ``.mat`` file found (recursively) under ``self.location_on_system``
    becomes one sample. Its label is ``int_classes[c]``, where ``c`` is the last
    character of the directory path that contains the file.

    Parameters:
        save_to: root directory handed to the tonic ``Dataset`` base class.
        transform: optional callable applied to the events of each sample.
        target_transform: optional callable applied to the integer label.
        transforms: optional callable applied jointly to ``(events, target)``
            after ``transform`` / ``target_transform``.

    NOTE(review): this class defines no ``url``, ``filename`` or ``file_md5``
    attributes. The base-class ``download()`` and ``_is_file_present()`` used
    below presumably read them, in which case construction raises
    ``AttributeError`` — TODO confirm against the installed tonic version and
    add them.
    """

    classes = [chr(letter) for letter in range(97, 123)]  # 'a'..'z'
    # letter -> integer class index (0..25)
    # NOTE(review): 100800 expected files is not a multiple of 26; if the
    # recordings cover fewer letters, some of these indices never occur — confirm.
    int_classes = dict(zip(classes, range(len(classes))))
    # (x extent, y extent, third entry presumably polarities); index 1 is the
    # y extent, as the flip in __getitem__ shows.
    sensor_size = (240, 180, 2)
    dtype = np.dtype([("t", int), ("x", int), ("y", int), ("p", int)])
    ordering = dtype.names
    # Minimum number of extracted .mat sample files for the data to count as present.
    expected_num_files = 100800

    def __init__(self,
                 save_to: str,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 transforms: Optional[Callable] = None):
        super().__init__(save_to, transform, target_transform, transforms)

        if not self._check_exists():
            self.download()
            self._extract_nested_archives()

        for path, dirs, files in os.walk(self.location_on_system):
            # Sort in place so the walk, and therefore sample order, is deterministic.
            dirs.sort()
            files.sort()
            for file in files:
                if file.endswith(".mat"):
                    self.data.append(os.path.join(path, file))
                    # The letter label is the last character of the containing directory.
                    self.targets.append(self.int_classes[path[-1]])

    def _extract_nested_archives(self):
        """Extract every ``Yin*.zip`` archive found under the dataset directory."""
        for path, dirs, files in os.walk(self.location_on_system):
            dirs.sort()
            for file in files:
                if file.startswith("Yin") and file.endswith(".zip"):
                    # Join with the directory os.walk is visiting, not the dataset
                    # root: an archive found in a subdirectory would otherwise be
                    # looked up at a path that does not exist.
                    extract_archive(os.path.join(path, file))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load sample ``index`` from its ``.mat`` file.

        Returns:
            (events, target) where events is a structured array with fields
            ``t, x, y, p`` and target is the index of the target class.
        """
        events, target = scio.loadmat(self.data[index]), self.targets[index]
        events = make_structured_array(
            events["ts"],
            events["x"],
            # Mirror y within the sensor height so rows are counted from the opposite edge.
            self.sensor_size[1] - 1 - events["y"],
            events["pol"],
            dtype=self.dtype,
        )
        if self.transform is not None:
            events = self.transform(events)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.transforms is not None:
            events, target = self.transforms(events, target)
        return events, target

    def __len__(self):
        """Number of discovered ``.mat`` samples."""
        return len(self.data)

    def _check_exists(self):
        """True when ``_is_file_present()`` passes and enough ``.mat`` files exist."""
        return (
            self._is_file_present()
            and self._folder_contains_at_least_n_files_of_type(
                self.expected_num_files, ".mat"
            )
        )

## CSNN Model Architecture

In [5]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed weight initialisation so Restart & Run All reproduces the same network.
SEED = 0
torch.manual_seed(SEED)

# neuron and simulation parameters
spike_grad = surrogate.atan()  # arctan surrogate gradient used by every Leaky layer
beta = 0.5                     # membrane decay factor passed to every snn.Leaky layer

# One output neuron per label ASLDVS can emit: its int_classes map letters to
# 0..len(classes)-1. The previous 10-neuron head could not represent labels >= 10.
NUM_CLASSES = len(ASLDVS.classes)

# 12C5 - MP2 - 32C5 - MP2 - FC(32*5*5 -> NUM_CLASSES).
# 2 input channels, matching ASLDVS.sensor_size[2].
# 32*5*5 only holds when input frames are 32-35 px per side (e.g. 34x34), so two
# conv5 + maxpool2 stages leave a 5x5 map. Raw 240x180 frames would need a
# different in_features — TODO confirm against the transform that feeds trainloader.
net = nn.Sequential(
    # 1st layer
    nn.Conv2d(2, 12, 5),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
    nn.MaxPool2d(2),
    # 2nd layer
    nn.Conv2d(12, 32, 5),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
    nn.MaxPool2d(2),
    # 3rd layer: readout; output=True makes the last Leaky return (spikes, membrane)
    nn.Flatten(),
    nn.Linear(32 * 5 * 5, NUM_CLASSES),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True),
).to(device)

optimizer = torch.optim.Adam(net.parameters(), lr=2e-2)  # Adam optimizer
# Spike-count MSE: correct class should fire 80% of steps, others 20%.
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

In [None]:
def forward_pass(net, data):
    """Feed ``data`` to ``net`` one time step at a time and stack the output spikes.

    Args:
        net: spiking network; its hidden state is cleared via ``utils.reset``
            first, and each call must return a ``(spikes, membrane)`` pair.
        data: tensor whose dim 0 indexes time steps; ``net`` is called once per slice.

    Returns:
        Tensor of per-step output spikes stacked along a new leading (time) dim.
    """
    utils.reset(net)  # clear LIF hidden states left over from the previous batch
    num_steps = data.size(0)
    step_outputs = (net(data[t]) for t in range(num_steps))
    # Unpacking in the comprehension keeps the membrane output discarded.
    return torch.stack([spk for spk, _mem in step_outputs])

In [None]:
# Training budget: passes over the loader, and the batch-index limit checked inside each epoch.
num_epochs, num_iters = 1, 200

In [None]:
# NOTE(review): `trainloader` is not defined in any cell shown here; it must be
# built (presumably a DataLoader over ASLDVS) before this cell runs — TODO confirm.
loss_hist = []
acc_hist = []

net.train()  # set training mode once, before the loop

# training loop
for epoch in range(num_epochs):
    for i, (data, targets) in enumerate(trainloader):
        # Train on exactly `num_iters` batches per epoch. Checking `i == num_iters`
        # after the step, as before, trained num_iters + 1 batches.
        if i >= num_iters:
            break

        data = data.to(device)
        targets = targets.to(device)

        spk_rec = forward_pass(net, data)
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss = loss_val.item()
        loss_hist.append(loss)
        print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss:.2f}")

        # Accuracy on the current training batch
        acc = SF.accuracy_rate(spk_rec, targets)
        acc_hist.append(acc)
        print(f"Accuracy: {acc * 100:.2f}%\n")