In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim
from torch.utils.data import DataLoader
from norse.torch.module.lif import LIFCell

# Define the fully Spiking Neural Network
class FullySpikingNN(nn.Module):
    def __init__(self, num_classes=10):
        super(FullySpikingNN, self).__init__()
        
        # Spiking layers
        self.lif1 = LIFCell(input_size=32 * 32 * 3, hidden_size=1024)
        self.lif2 = LIFCell(input_size=1024, hidden_size=512)
        self.lif3 = LIFCell(input_size=512, hidden_size=256)
        # self.lif4 = LIFCell(input_size=256, hidden_size=num_classes)
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        # Flatten the input image
        x = x.view(-1, 32 * 32 * 3)
        
        # Initialize the state for each LIF layer
        batch_size = x.size(0)
        lif1_state = self.lif1.initial_state(batch_size=batch_size, device=x.device)
        lif2_state = self.lif2.initial_state(batch_size=batch_size, device=x.device)
        lif3_state = self.lif3.initial_state(batch_size=batch_size, device=x.device)
        # lif4_state = self.lif4.initial_state(batch_size=batch_size, device=x.device)
        
        # Fully spiking layers
        z, lif1_state = self.lif1(x, lif1_state)
        z, lif2_state = self.lif2(z, lif2_state)
        z, lif3_state = self.lif3(z, lif3_state)
        # z, lif4_state = self.lif4(z, lif4_state)
        x = self.fc(z)

        return x

# Training and Evaluation
def train_snn(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if i % 100 == 99:    # Print every 100 mini-batches
            print(f'[{i + 1}, {len(train_loader)}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

def test_snn(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy: {100 * correct / total:.2f}%')


# Load CIFAR-10 data
transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.RandomCrop(32, padding=4),
     transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(trainset, batch_size=100, shuffle=True)

testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(testset, batch_size=100, shuffle=False)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the model, criterion, and optimizer
model = FullySpikingNN(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
for epoch in range(10):  # Loop over the dataset multiple times
    print(f'Epoch {epoch + 1}')
    train_snn(model, train_loader, criterion, optimizer, device)
    test_snn(model, test_loader, device)

print('Finished Training')


Files already downloaded and verified
Files already downloaded and verified
Epoch 1
[100, 500] loss: 2.303
[200, 500] loss: 2.303
[300, 500] loss: 2.303
[400, 500] loss: 2.303
[500, 500] loss: 2.303
Accuracy: 10.00%
Epoch 2
[100, 500] loss: 2.303
[200, 500] loss: 2.303
[300, 500] loss: 2.303
[400, 500] loss: 2.303
[500, 500] loss: 2.303
Accuracy: 10.00%
Epoch 3
[100, 500] loss: 2.303
[200, 500] loss: 2.303
[300, 500] loss: 2.303
[400, 500] loss: 2.303
[500, 500] loss: 2.303
Accuracy: 10.00%
Epoch 4
[100, 500] loss: 2.303
[200, 500] loss: 2.303
[300, 500] loss: 2.303
[400, 500] loss: 2.303
[500, 500] loss: 2.303
Accuracy: 10.00%
Epoch 5
[100, 500] loss: 2.303
[200, 500] loss: 2.303
[300, 500] loss: 2.303
[400, 500] loss: 2.303
[500, 500] loss: 2.303
Accuracy: 10.00%
Epoch 6
[100, 500] loss: 2.303
[200, 500] loss: 2.303
[300, 500] loss: 2.303
[400, 500] loss: 2.303
[500, 500] loss: 2.303
Accuracy: 10.00%
Epoch 7
[100, 500] loss: 2.303
[200, 500] loss: 2.303
[300, 500] loss: 2.303
[400, 5