In [None]:
import tonic
from tqdm.notebook import tqdm

import torch
import torchvision

# Storage location handed to tonic for both NMNIST splits.
NMNIST_ROOT = "/DATA/hwkang"

sensor_size = tonic.datasets.NMNIST.sensor_size

# Convert each recording to 30 time-binned frames, then wrap the array as a tensor.
transform = tonic.transforms.Compose(
    [
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=30),
        torch.from_numpy,
    ]
)

# Train split first, then test split; both share the same root and transform.
train_dataset, test_dataset = (
    tonic.datasets.NMNIST(save_to=NMNIST_ROOT, train=is_train, transform=transform)
    for is_train in (True, False)
)

# Wrap each split in tonic's in-memory cache.
cached_train_dataset, cached_test_dataset = (
    tonic.cached_dataset.MemoryCachedDataset(split)
    for split in (train_dataset, test_dataset)
)

from torch.utils.data import DataLoader
import multiprocessing

# Both loaders use half of the reported CPU count as worker processes.
loader_workers = multiprocessing.cpu_count() // 2

# Training batches are shuffled; each loader gets its own PadTensors collator.
train_loader = DataLoader(
    cached_train_dataset,
    batch_size=128,
    num_workers=loader_workers,
    shuffle=True,
    collate_fn=tonic.collation.PadTensors(batch_first=False),
)
# Test batches keep the dataset order.
test_loader = DataLoader(
    cached_test_dataset,
    batch_size=128,
    num_workers=loader_workers,
    shuffle=False,
    collate_fn=tonic.collation.PadTensors(batch_first=False),
)

In [None]:
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import spikeplot as splt
from snntorch import utils
import torch.nn as nn

In [None]:
# Device preference order: CUDA, then MPS, then CPU as the fallback.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# neuron and simulation parameters
spike_grad = surrogate.atan()  # passed as spike_grad to every Leaky layer below
beta = 0.5  # passed as beta to every Leaky layer below

# Initialize Network
# Layers are created in the order they run: two conv -> pool -> Leaky stages,
# then flatten -> linear -> output Leaky layer.
layers = [
    nn.Conv2d(2, 12, 5),
    nn.MaxPool2d(2),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
    nn.Conv2d(12, 32, 5),
    nn.MaxPool2d(2),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
    nn.Flatten(),
    # 32*5*5 input features: presumably 32 channels of 5x5 maps after the second
    # pooling stage -- TODO confirm against the dataset's frame size.
    nn.Linear(32*5*5, 10),
    # output=True: forward_pass unpacks this layer's result into two values.
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True),
]
net = nn.Sequential(*layers).to(device)

In [None]:
def forward_pass(net, data):
    """Run ``net`` over every time step of ``data`` and collect its output spikes.

    Args:
        net: Network that returns a ``(spikes, membrane)`` pair for one time step.
        data: Input indexed by time step along dimension 0.

    Returns:
        The per-step output spikes stacked with ``torch.stack``, so the first
        dimension of the result is the time step.
    """
    # Clear the hidden states of all LIF neurons in net before the sequence starts.
    utils.reset(net)

    num_steps = data.size(0)  # number of time steps
    output_spikes = []
    for t in range(num_steps):
        # The second returned value is not used here.
        step_spikes, _step_mem = net(data[t])
        output_spikes.append(step_spikes)

    return torch.stack(output_spikes)

In [None]:
# Adam over all parameters of net, with learning rate and betas given explicitly.
optimizer = torch.optim.Adam(
    net.parameters(),
    lr=2e-2,
    betas=(0.9, 0.999),
)

# Loss from snntorch.functional, configured with the correct/incorrect rates below.
loss_fn = SF.mse_count_loss(
    correct_rate=0.8,
    incorrect_rate=0.2,
)

In [None]:
num_epochs = 1
num_iters = 50  # number of training batches processed per epoch before stopping

loss_hist = []  # loss value of every processed batch
acc_hist = []   # accuracy of every processed batch

# Nothing inside the loop leaves training mode, so switching once is enough.
net.train()

# training loop
for epoch in range(num_epochs):
    for i, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        # Output spikes for every time step of this batch.
        spk_rec = forward_pass(net, data)
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_value = loss_val.item()
        loss_hist.append(loss_value)

        print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss_value:.2f}")

        acc = SF.accuracy_rate(spk_rec, targets)
        acc_hist.append(acc)
        print(f"Accuracy: {acc * 100:.2f}%\n")

        # Stop once num_iters batches have been processed in this epoch.
        # `i` is 0-based, so the previous `i == num_iters` check let one extra
        # batch (num_iters + 1 in total) through before breaking.
        if i + 1 >= num_iters:
            break