In [1]:
! pip install snntorch -q

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/125.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.6/125.6 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
# snnTorch
import snntorch as snn
from snntorch import spikeplot as splt

# pyTorch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
import itertools

# plotting
import matplotlib.pyplot as plt

In [4]:
# dataloader args
batch_size = 128
data_path = "/data/cifar10"

dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [5]:
# define a transform
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, ))
])

In [6]:
# import datasets
cifar_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform)
cifar_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)

# create DataLoaders
train_loader = DataLoader(cifar_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(cifar_test, batch_size=batch_size, shuffle=True, drop_last=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /data/cifar10/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:04<00:00, 37.2MB/s]


Extracting /data/cifar10/cifar-10-python.tar.gz to /data/cifar10
Files already downloaded and verified


In [9]:
# Network Architecture
num_inputs = 32 * 32
num_hidden = 2000
num_outputs = 10

# Temporal Dynamics
num_steps = 25
beta = 0.9

In [10]:
# Define Network
class Net(nn.Module):

  def __init__(self):
    super().__init__()

    # Initialize layers
    self.fc1 = nn.Linear(num_inputs, num_hidden)
    self.lif1 = snn.Leaky(beta=beta)
    self.fc2 = nn.Linear(num_hidden, num_outputs)
    self.lif2 = snn.Leaky(beta=beta)

  def forward(self, x):

    # Initialize hidden states at t=0
    mem1 = self.lif1.init_leaky()
    mem2 = self.lif2.init_leaky()

    # record the final layer
    spk2_rec = []
    mem2_rec = []

    for step in range(num_steps):
      cur1 = self.fc1(x)
      spk1, mem1 = self.lif1(cur1, mem1)
      cur2 = self.fc2(spk1)
      spk2, mem2 = self.lif2(cur2, mem2)
      spk2_rec.append(spk2)
      mem2_rec.append(mem2)

    return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

In [11]:
# Load the network onto CUDA if available
net = Net().to(device)

In [12]:
# Training the SNN

loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=4e-3, betas=(0.9, 0.999))

num_epochs = 1
loss_hist = []
test_loss_hist = []
counter = 0

# Outer training loop
for epoch in range(num_epochs):
  train_batch = iter(train_loader)

  # Minibatch training loop
  for data, targets in train_batch:
    data = data.to(device)
    targets = targets.to(device)

    # forward pass
    net.train()
    spk_rec, _ = net(data.flatten(1))

    # initialize the loss & sum over time
    loss_val = torch.zeros((1), dtype=dtype, device=device)
    loss_val += loss(spk_rec.sum(0), targets)

    # Gradient calculation + weight update
    optimizer.zero_grad()
    loss_val.backward()
    optimizer.step()

    # Store loss history for future plotting
    loss_hist.append(loss_val.item())

    # Print train/test loss/acc
    if (counter % 10 == 0):
      print(f"Iteration : {counter} \t Train Loss : {loss_val.item()}")
    counter += 1

    if (counter == 100):
      break

Iteration : 0 	 Train Loss : 2.5993261337280273
Iteration : 10 	 Train Loss : 2.381847620010376
Iteration : 20 	 Train Loss : 2.3025853633880615
Iteration : 30 	 Train Loss : 2.3025853633880615
Iteration : 40 	 Train Loss : 2.3025853633880615
Iteration : 50 	 Train Loss : 2.3025853633880615
Iteration : 60 	 Train Loss : 2.3025853633880615
Iteration : 70 	 Train Loss : 2.3025853633880615
Iteration : 80 	 Train Loss : 2.3025853633880615
Iteration : 90 	 Train Loss : 2.3025853633880615


In [13]:
def measure_accuracy(model, dataloader):
  with torch.no_grad():
    model.eval()
    running_len = 0
    running_acc = 0

    for data, targets in iter(dataloader):
      data = data.to(device)
      targets = targets.to(device)

      # forward pass
      spk_rec, _ = model(data.flatten(1))
      spike_count = spk_rec.sum(0)
      _, max_spike = spike_count.max(1)

      # correct classes for one batch
      num_correct = (max_spike == targets).sum()

      # total accuracy
      running_len += len(targets)
      running_acc += num_correct

    acc = (running_acc / running_len)

    return acc.item()

In [14]:
acc = measure_accuracy(net, test_loader)
print(f"Test set accuracy : {acc}")

Test set accuracy : 0.09985977411270142
