# Train a Spiking Neural Network (SNN) on MNIST to Verify That It Works

In [1]:
import torch
import torchvision ## Contains some utilities for working with the image data
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import torch.nn.functional as F

import snntorch.surrogate as surrogate

import numpy as np
from scipy.signal import convolve2d


from model import SpikingNetwork

In [2]:
# This instance is not used further in this notebook; it appears to exist only
# to trigger the dataset download before the tensor-transformed splits are built.
dataset = MNIST(root='../data/', download=True)
# Training split, with images converted to tensors by ToTensor().
mnist_dataset = MNIST(root='../data/', train=True, transform=transforms.ToTensor())
# NOTE(review): despite its name this is MNIST's official *test* split
# (train=False); it is not used anywhere below.
pretrain_dataset = MNIST(root='../data/', train=False, transform=transforms.ToTensor())

# Quick sanity check that a sample can be loaded.
image_tensor, label = mnist_dataset[0]

# Hold out 10k of the training images for validation. A fixed-seed generator
# makes the split (and therefore the reported accuracy) reproducible across runs.
train_data, validation_data = random_split(
    mnist_dataset,
    [50000, 10000],
    generator=torch.Generator().manual_seed(42),
)

print("length of Train Datasets: ", len(train_data))
print("length of Validation Datasets: ", len(validation_data))

length of Train Datasets:  50000
length of Validation Datasets:  10000


In [3]:
# Training batches are reshuffled every epoch; validation batches are read in a
# fixed order. Note that `test_loader` iterates the held-out validation split,
# not MNIST's official test set.
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(validation_data, batch_size=128, shuffle=False)

In [8]:
def train_step(model, train_loader, optimizer, epoch, device='cpu', num_steps=16, log_interval=500):
    """Train ``model`` for one epoch on rate-encoded images.

    Each image batch is flattened to ``(batch, features)`` and converted into a
    Bernoulli spike train with ``num_steps`` time steps, giving an input of
    shape ``(batch, num_steps, features)``. The model's output spikes are summed
    over dim 1 and the sums are used as logits for cross-entropy.

    Args:
        model: Network whose forward returns ``(spikes, voltages)``; ``spikes``
            is presumably ``(batch, time, classes)`` -- TODO confirm against
            ``SpikingNetwork``.
        train_loader: Iterable of ``(images, targets)`` batches that also
            exposes ``.dataset`` and ``len()`` (e.g. a ``DataLoader``).
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Epoch index, used only in progress messages.
        device: Device the batches are moved to before encoding.
        num_steps: Number of time steps in the spike encoding.
        log_interval: Print progress every this many batches.
    """
    # test() switches the model to eval mode; make sure we actually train.
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # Move the raw images first: the spike train is num_steps times larger.
        data, target = data.to(device), target.to(device)

        # (B, C, H, W) -> (B, C*H*W). Unlike flatten(2).squeeze(), this keeps
        # the batch dimension when the final batch holds a single sample.
        data = torch.flatten(data, start_dim=1)
        # Rate coding: at each step a pixel spikes with probability equal to
        # its intensity. Stacking on dim 1 yields (B, num_steps, features).
        data = torch.stack(
            [torch.greater(data, torch.rand_like(data)) for _ in range(num_steps)],
            dim=1,
        ).float()

        optimizer.zero_grad()
        spikes, voltages = model(data)
        # Spike counts over time serve as class logits.
        output = spikes.sum(1)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            
def test(model, test_loader, device='cpu'):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            data = torch.flatten(data, 2).squeeze()
            data = torch.stack([torch.greater(data, torch.rand_like(data)) for _ in range(16)]).transpose(0, 1).float()

            spikes, voltages = model(data)
            output = spikes.sum(1)

            test_loss += F.cross_entropy(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [10]:
# Fully connected spiking network on flattened 28x28 images, with hidden/output
# sizes given by fc_dims (presumably 784 -> 128 -> 10; confirm in SpikingNetwork).
# Neurons use the 'lif' model with an arctan surrogate gradient for the spike function.
lif_options = dict(
    beta=0.9,
    threshold=1.0,
    spike_fn=surrogate.atan(alpha=2),
)

network = SpikingNetwork(
    in_dims=28 * 28,
    fc_dims=[128, 10],
    neuron_models='lif',
    neuron_options=lif_options,
    linear_options=dict(bias=True),
)

optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)

# Five passes over the training split.
for epoch in range(5):
    train_step(network, train_loader, optimizer, epoch=epoch)



In [11]:
# Evaluate the trained network on the held-out validation split.
test(network, test_loader=test_loader)


Test set: Average loss: 0.0010, Accuracy: 9632/10000 (96%)

