In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import spikegen
import snntorch.functional as SF
from snntorch import spikegen

  from snntorch import backprop


In [3]:
# Define transforms
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Load CIFAR-10 dataset
batch_size = 64

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Define the SNN model
class SNN(nn.Module):
    def __init__(self):
        super(SNN, self).__init__()
        self.num_inputs = 32 * 32 * 3
        self.num_hidden = 256
        self.num_outputs = 10
        self.beta = 0.9

        # Initialize layers
        self.fc1 = nn.Linear(self.num_inputs, self.num_hidden)
        self.lif1 = snn.Leaky(beta=self.beta, spike_grad=surrogate.fast_sigmoid())

        self.fc2 = nn.Linear(self.num_hidden, self.num_hidden)
        self.lif2 = snn.Leaky(beta=self.beta, spike_grad=surrogate.fast_sigmoid())

        self.fc3 = nn.Linear(self.num_hidden, self.num_outputs)
        self.lif3 = snn.Leaky(beta=self.beta, spike_grad=surrogate.fast_sigmoid(), output=True)

    def forward(self, x):
        # Flatten the image
        x = x.view(-1, self.num_inputs)

        x = spikegen.rate(x, num_steps=100)

        # Layer 1
        x = self.fc1(x)
        spk1, mem1 = self.lif1(x)

        # Layer 2
        x = self.fc2(spk1)
        spk2, mem2 = self.lif2(x)

        # Layer 3
        x = self.fc3(spk2)
        spk3, mem3 = self.lif3(x)

        return spk3, mem3

# Instantiate the model
net = SNN()


In [11]:
num_epochs = 5
batch_size = 64
num_steps = 100  # Number of time steps for rate encoding
learning_rate = 1e-3

# Define loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # Reset the gradients
        optimizer.zero_grad()

        # Forward pass
        spk_out, mem_out = net(inputs)

        # Sum spikes across the time dimension (dim=0), preserving batch_size and class dimensions
        spk_out_sum = spk_out.sum(dim=0)  # Summing over time steps, resulting in [batch_size, num_classes]

        # Compute loss using the summed output spikes as logits
        loss_val = loss_fn(spk_out_sum, labels)

        # Backpropagation
        loss_val.backward()
        optimizer.step()

        running_loss += loss_val.item()
        if i % 100 == 99:    # Print every 100 mini-batches
            print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / 100:.3f}')
            running_loss = 0.0

print('Finished Training')


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [None]:
correct = 0
total = 0

with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        spk_out, _ = net(images)
        _, predicted = torch.max(spk_out.sum(dim=0), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy on the 10000 test images: {100 * correct / total:.2f}%')