In [1]:
!pip install spikingjelly

Collecting spikingjelly
  Downloading spikingjelly-0.0.0.0.14-py3-none-any.whl.metadata (15 kB)
Downloading spikingjelly-0.0.0.0.14-py3-none-any.whl (437 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m437.6/437.6 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: spikingjelly
Successfully installed spikingjelly-0.0.0.0.14


In [None]:
import torch
import torch.nn as nn
from spikingjelly.activation_based import learning, layer, neuron, functional
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define Hyperparameters
T = 8  # Time steps (each static image is presented T times)
N = 64  # Batch size
C = 1  # Number of input channels (grayscale for MNIST)
H = 28  # Height of the input image
W = 28  # Width of the input image
lr_stdp = 0.01  # Learning rate for STDP
tau_pre = 2.  # Time constant for pre-synaptic traces
tau_post = 100.  # Time constant for post-synaptic traces
step_mode = 'm'  # Multi-step mode for SNN processing
epochs = 5  # Number of training epochs
spike_threshold = 0.5  # Pixel intensity (after ToTensor) above which a pixel is encoded as a spike
seed = 0  # Seeds weight initialisation and DataLoader shuffling so runs are reproducible

# Seed torch's global RNG before any weights are created or batches are shuffled.
torch.manual_seed(seed)

# Load the MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert to a float tensor
    # Binarize the data to spike (1.0) or no spike (0.0); the threshold is bound at definition
    # time so later edits to `spike_threshold` cannot silently change an existing transform.
    transforms.Lambda(lambda x, threshold=spike_threshold: (x > threshold).float())
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=N, shuffle=True)

# Define the SNN Network
def f_weight(x):
    """Weight-dependence function handed to the STDP learners as f_pre / f_post.

    Saturates every element of the tensor ``x`` into the closed interval [-1, 1];
    elements already inside the interval come back unchanged.
    """
    low, high = -1, 1.
    return x.clamp(low, high)

# Define a simple SNN: two (Conv -> IF -> MaxPool) stages followed by a Linear -> IF readout.
# The 3x3 convs use padding=1 and stride=1, so they keep the spatial size; each MaxPool2d(2, 2)
# floors it by half. The flattened feature size is therefore derived from H and W
# (28x28 -> 7x7) instead of being hard-coded.
feature_h, feature_w = H // 4, W // 4
net = nn.Sequential(
    layer.Conv2d(C, 16, kernel_size=3, stride=1, padding=1, bias=False),
    neuron.IFNode(),
    layer.MaxPool2d(2, 2),
    layer.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False),
    neuron.IFNode(),
    layer.MaxPool2d(2, 2),
    layer.Flatten(),
    layer.Linear(32 * feature_h * feature_w, 10, bias=False),  # Final layer for 10 classes
    neuron.IFNode(),
)

# Put every module of `net` into `step_mode` ('m' = inputs carry a leading time dimension T).
functional.set_step_mode(net, step_mode)

# Define the STDP Learners
# STDP is applied to the weights of every Conv2d and Linear layer
instances_stdp = (layer.Conv2d, layer.Linear)

# Create one STDP learner per synapse layer. Each learner pairs the synapse with the module that
# immediately follows it in `net`; that module's output is used as the post-synaptic spikes, so it
# must be a spiking neuron. Fail loudly if an architecture edit breaks this pairing instead of
# silently learning from non-spike activity (or hitting an IndexError on a trailing synapse).
stdp_learners = []

for i, synapse in enumerate(net):
    if not isinstance(synapse, instances_stdp):
        continue
    sn = net[i + 1] if i + 1 < len(net) else None
    if not isinstance(sn, neuron.BaseNode):
        raise TypeError(
            f'STDP synapse net[{i}] ({type(synapse).__name__}) must be followed by a spiking neuron, '
            f'got {type(sn).__name__}'
        )
    stdp_learners.append(
        learning.STDPLearner(step_mode=step_mode, synapse=synapse, sn=sn, tau_pre=tau_pre, tau_post=tau_post,
                             f_pre=f_weight, f_post=f_weight)
    )

# Create Optimizer for STDP Updates
# Gather the parameters of every STDP synapse; they are only ever updated through STDP.
params_stdp = [p for m in net.modules() if isinstance(m, instances_stdp) for p in m.parameters()]

# The learners add `-delta_w` to .grad, so plain SGD without momentum
# (w <- w - lr * grad) applies w <- w + lr * delta_w.
optimizer_stdp = torch.optim.SGD(params_stdp, lr=lr_stdp, momentum=0.)

# Training Loop
# Pure (unsupervised) STDP training: no loss is computed and backward() is never called, so the
# forward pass and the STDP step run under torch.no_grad() instead of building an autograd graph
# that is immediately thrown away. The learners still write their updates into .grad directly,
# which the optimizer then applies.
device = next(net.parameters()).device  # feed batches to wherever the network lives

for epoch in range(epochs):
    for batch_idx, (data, _target) in enumerate(train_loader):  # labels are not used by STDP
        # Present the same static image at every time step: [N, C, H, W] -> [T, N, C, H, W]
        x_seq = data.to(device).unsqueeze(0).repeat(T, 1, 1, 1, 1)

        # Zero gradients (STDP updates are accumulated into .grad below)
        optimizer_stdp.zero_grad()

        with torch.no_grad():
            # A single multi-step forward over all T steps. This is equivalent to feeding the T
            # slices one at a time, because neuron states persist until functional.reset_net().
            net(x_seq)
            # Each learner consumes its recorded pre/post spikes and adds `- delta_w * scale` to grad
            for stdp_learner in stdp_learners:
                stdp_learner.step(on_grad=True)

        # Apply the weight update
        optimizer_stdp.step()

        # Reset the network and learners to clear internal states before the next batch
        functional.reset_net(net)
        for stdp_learner in stdp_learners:
            stdp_learner.reset()

        if batch_idx % 100 == 0:
            print(f'Epoch {epoch + 1}, Batch {batch_idx + 1} - STDP Updates Applied')


Epoch 1, Batch 1 - STDP Updates Applied
Epoch 1, Batch 101 - STDP Updates Applied
Epoch 1, Batch 201 - STDP Updates Applied
Epoch 1, Batch 301 - STDP Updates Applied
Epoch 1, Batch 401 - STDP Updates Applied
Epoch 1, Batch 501 - STDP Updates Applied
Epoch 1, Batch 601 - STDP Updates Applied
Epoch 1, Batch 701 - STDP Updates Applied
Epoch 1, Batch 801 - STDP Updates Applied
Epoch 1, Batch 901 - STDP Updates Applied
Epoch 2, Batch 1 - STDP Updates Applied
Epoch 2, Batch 101 - STDP Updates Applied
Epoch 2, Batch 201 - STDP Updates Applied
Epoch 2, Batch 301 - STDP Updates Applied
Epoch 2, Batch 401 - STDP Updates Applied
Epoch 2, Batch 501 - STDP Updates Applied
Epoch 2, Batch 601 - STDP Updates Applied
Epoch 2, Batch 701 - STDP Updates Applied
Epoch 2, Batch 801 - STDP Updates Applied
Epoch 2, Batch 901 - STDP Updates Applied
Epoch 3, Batch 1 - STDP Updates Applied
Epoch 3, Batch 101 - STDP Updates Applied
Epoch 3, Batch 201 - STDP Updates Applied
Epoch 3, Batch 301 - STDP Updates Applie