In [1]:
!pip install snntorch

Collecting snntorch
  Downloading snntorch-0.9.4-py2.py3-none-any.whl.metadata (15 kB)
Downloading snntorch-0.9.4-py2.py3-none-any.whl (125 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/125.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.6/125.6 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: snntorch
Successfully installed snntorch-0.9.4


In [5]:
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt

from snntorch import utils

In [6]:
# Training Parameters
batch_size=128
data_path='/tmp/data/fashion-mnist'
num_classes = 10  # fashion MNIST has 10 output classes

dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# print(device)

In [7]:
# Define a transform
transform = transforms.Compose([
            transforms.Resize((28,28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

fmnist_train = datasets.FashionMNIST(data_path, train=True, download=True, transform=transform)
fmnist_test = datasets.FashionMNIST(data_path, train=False, download=True, transform=transform)

100%|██████████| 26.4M/26.4M [00:02<00:00, 12.2MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 210kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.89MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 9.89MB/s]


In [48]:
# Loader serves minibatches of batch_size

train_loader = DataLoader(fmnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(fmnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

In [84]:
# Time to combine to build our model
class LeakySurrogate(nn.Module):
  def __init__(self, beta, threshold=1) -> None:
    super(LeakySurrogate, self).__init__()

    # Initialize decay rate constant and threshold
    self.beta = beta
    self.threshold = threshold # R=1 to simplify
    self.spike_grad = self.Arctan.apply
    # self._init_mem()

  def forward(self, input_, mem):
    # Define forward pass
    spike = self.spike_grad(mem-self.threshold) # spike = 0.0 or 1.0
    reset = (self.beta * spike * self.threshold).detach() # reset membrane potential
    mem = self.beta * mem + input_ - reset # remove reset before backward pass
    return spike, mem

  def _init_mem(self):
        mem = torch.zeros(0)
        self.register_buffer("mem", mem, False)

  def init_leaky(self, input_):
    # Reset membrane
    return torch.zeros_like(input_)

  # Autograd func
  @staticmethod
  class Arctan(torch.autograd.Function):
    @staticmethod
    def forward(ctx, mem):
      # Returns 1.0 if over threshold, 0.0 otherwise
      ctx.save_for_backward(mem)
      return (mem > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
      # Custom to avoid dead neuron
      (mem, ) = ctx.saved_tensors  # Retrieve
      grad_input = 1 / (1+(np.pi*mem).pow_(2)) * grad_output # Smoothing, modifying tensor in place for space efficiency
      return grad_input

In [50]:
# Network arch
num_inputs = 28*28
num_hidden = 1000
num_outputs = 10

# Temporal dynamics
num_steps = 25
beta = 0.95

In [88]:
# Define network
class Network(nn.Module):
  def __init__(self) -> None:
    super().__init__()

    # Initialize layers
    self.fc1 = nn.Linear(num_inputs, num_hidden)  # transform input pixels
    self.lif1 = LeakySurrogate(beta)              # weighted input over time
    self.fc2 = nn.Linear(num_hidden, num_outputs) # transform output spikes
    self.lif2 = LeakySurrogate(beta)              # weighted spikes over time

  def forward(self, x):

    # At t=0
    mem1 = None
    mem2 = None

    # Save final layer
    spk2_rec = []
    mem2_rec = []

    for step in range(num_steps):
      # Apply layers

      cur1 = self.fc1(x)

      if mem1 is None:
        mem1 = self.lif1.init_leaky(cur1)
      spk1, mem1 = self.lif1(cur1, mem1)
      # print("spk1 shape:", spk1.shape)

      cur2 = self.fc2(spk1)

      if mem2 is None:
        mem2 = self.lif2.init_leaky(cur2)
      spk2, mem2 = self.lif2(cur2, mem2)

      # Save for recording
      spk2_rec.append(spk2)
      mem2_rec.append(mem2)
    return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)


In [89]:
# Load network onto cuda
net = Network().to(device)

In [67]:
# Accuracy metrics

def print_batch_accuracy(data, targets, train=False):
    output, _ = net(data.view(batch_size, -1))
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer(
    data, targets, epoch,
    counter, iter_counter,
        loss_hist, test_loss_hist, test_data, test_targets):
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    print_batch_accuracy(data, targets, train=True)
    print_batch_accuracy(test_data, test_targets, train=False)
    print("\n")

In [55]:
# Loss func
loss = nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

In [56]:
# Training params
num_epochs = 1
loss_hist = []
test_loss_hist = []
counter = 0

In [91]:
net.train()

# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        spk_rec, mem_rec = net(data.view(batch_size, -1))

        # initialize the loss & sum over time
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for step in range(num_steps):
            loss_val += loss(mem_rec[step], targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            net.eval()
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Test set forward pass
            test_spk, test_mem = net(test_data.view(batch_size, -1))

            # Test set loss
            test_loss = torch.zeros((1), dtype=dtype, device=device)
            for step in range(num_steps):
                test_loss += loss(test_mem[step], test_targets)
            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy
            if counter % 50 == 0:
                train_printer(
                    data, targets, epoch,
                    counter, iter_counter,
                    loss_hist, test_loss_hist,
                    test_data, test_targets)
            counter += 1
            iter_counter +=1

Epoch 0, Iteration 9
Train Set Loss: 61.69
Test Set Loss: 62.13
Train set accuracy for a single minibatch: 13.28%
Test set accuracy for a single minibatch: 17.19%


Epoch 0, Iteration 59
Train Set Loss: 64.37
Test Set Loss: 61.93
Train set accuracy for a single minibatch: 7.81%
Test set accuracy for a single minibatch: 12.50%


Epoch 0, Iteration 109
Train Set Loss: 62.54
Test Set Loss: 60.52
Train set accuracy for a single minibatch: 10.16%
Test set accuracy for a single minibatch: 10.16%


Epoch 0, Iteration 159
Train Set Loss: 65.42
Test Set Loss: 61.18
Train set accuracy for a single minibatch: 11.72%
Test set accuracy for a single minibatch: 14.06%


Epoch 0, Iteration 209
Train Set Loss: 62.23
Test Set Loss: 63.95
Train set accuracy for a single minibatch: 7.81%
Test set accuracy for a single minibatch: 13.28%


Epoch 0, Iteration 259
Train Set Loss: 61.91
Test Set Loss: 61.84
Train set accuracy for a single minibatch: 7.81%
Test set accuracy for a single minibatch: 8.59%


Epoch