<a href="https://colab.research.google.com/github/zypchn/Spiking-Neural-Networks/blob/main/SNN_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install snntorch -q

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/125.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.6/125.6 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from snntorch import spikegen
import numpy as np
import itertools
import matplotlib.pyplot as plt
import snntorch.spikeplot as splt
from IPython.display import HTML

In [None]:
# Preprocessing applied to every image: force a single channel, convert to a
# tensor, then normalize with mean 0 / std 1 (which leaves the values as-is).
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0,), std=(1,)),
    ]
)

In [None]:
# Configuration: dataset location, batch size, working dtype and RNG seed.
data_path = "/content/data/mnist"
batch_size = 128
dtype = torch.float
seed = 0

# Seed PyTorch's global RNG so the shuffled batch order (and any later sampling
# that draws from the global RNG) is the same on every Restart & Run All.
torch.manual_seed(seed)

# MNIST train / test splits, downloaded on first use and preprocessed with `transform`.
mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

In [None]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [None]:
# Rate Coding (batch)

# Number of simulation time steps generated per image.
num_steps = 5

# Pull a single batch of images and labels from the training loader.
data = iter(train_loader)
data_it, target_it = next(data)

# Spiking Data
# Encode the batch as spike trains, then move the leading time axis to the end
# so the layout becomes (batch, channel, height, width, time).
spike_data = spikegen.rate(data_it, num_steps=num_steps).permute(1, 2, 3, 4, 0)
spike_data.shape

torch.Size([128, 1, 28, 28, 5])

In [None]:
class LIFNeuron:
  """Leaky integrate-and-fire (LIF) neurons with a hard reset to zero.

  One independent neuron is simulated for every element of the input, i.e. for
  every index over all axes except the last one, which is treated as time.
  At each time step the membrane potential is scaled by ``1 - decay``, the
  current input is added, a spike is emitted wherever the potential reaches
  ``threshold``, and the potential is set to zero at the positions that spiked.

  The membrane potential is stored on the instance and carried over between
  ``forward`` calls; call ``reset`` to start again from a zero potential.
  """

  def __init__(self, threshold=1.0, decay=0.1):
    """Create the neuron layer.

    Args:
      threshold: Potential at or above which a neuron emits a spike.
      decay: Fraction of the membrane potential lost at every time step.
    """
    # Allocated lazily on the first forward pass, once the input shape is known.
    self.membrane_potential = None
    self.threshold = threshold
    self.decay = decay

  def reset(self):
    """Discard the stored membrane potential so the next call starts from zero."""
    self.membrane_potential = None

  def forward(self, input_spikes):
    """Run the neurons over every time step of ``input_spikes``.

    Args:
      input_spikes: Tensor whose last axis is time, e.g.
        ``(batch, channel, height, width, time)``.

    Returns:
      Float tensor with the same shape as ``input_spikes`` holding 1.0 where a
      neuron fired and 0.0 elsewhere.
    """
    state_shape = input_spikes.shape[:-1]
    num_steps = input_spikes.shape[-1]

    potential = self.membrane_potential
    # (Re)allocate the state when there is none yet or when the stored state
    # cannot be combined with this input (e.g. a smaller final batch, or data
    # living on another device); reusing it would raise inside the update.
    if (potential is None
        or potential.shape != state_shape
        or potential.device != input_spikes.device):
      potential = torch.zeros(
          state_shape, dtype=input_spikes.dtype, device=input_spikes.device)
      self.membrane_potential = potential

    # Nothing to simulate: return an empty spike train instead of letting
    # torch.stack fail on an empty list.
    if num_steps == 0:
      return torch.zeros(
          input_spikes.shape, dtype=torch.float, device=input_spikes.device)

    spike_outs = []
    for t in range(num_steps):
      # Leak, then integrate the input arriving at this time step.
      potential = (1 - self.decay) * potential + input_spikes[..., t]

      fired = potential >= self.threshold
      # Hard reset: neurons that fired restart from zero. torch.where builds a
      # new tensor rather than writing into the existing one in place.
      potential = torch.where(fired, torch.zeros_like(potential), potential)

      spike_outs.append(fired.float())

    self.membrane_potential = potential
    return torch.stack(spike_outs, dim=-1)

In [None]:
# Instantiate the LIF layer, spelling out the settings it runs with.
neuron = LIFNeuron(threshold=1.0, decay=0.1)

In [None]:
# Feed the rate-coded batch through the LIF neurons.
out = neuron.forward(input_spikes=spike_data)

In [None]:
# Input and output spike trains should have the same shape.
tuple(tensor.shape for tensor in (spike_data, out))

(torch.Size([128, 1, 28, 28, 5]), torch.Size([128, 1, 28, 28, 5]))

In [None]:
# Total number of spikes going into and coming out of the neurons.
torch.sum(spike_data), torch.sum(out)

(tensor(51570.), tensor(51570.))

In [None]:
# (batch index, channel, row, column) of the pixel whose spike trains are shown.
sample_pixel = (0, 0, 10, 10)

# Each call prints the heading and the spike train on separate lines.
print("Input spike train for one pixel:", spike_data[sample_pixel].cpu().numpy(), sep="\n")
print("Output spike train for the same pixel:", out[sample_pixel].cpu().numpy(), sep="\n")

Input spike train for one pixel:
[0. 0. 0. 0. 0.]
Output spike train for the same pixel:
[0. 0. 0. 0. 0.]
