In [1]:
import torch
from spikingjelly.activation_based import neuron, encoding, layer, functional
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import torch.optim as optim
import numpy as np
train_data= np.load("trainset_normalized.npz")
train_data_tensor = torch.from_numpy(train_data['data'])
print(train_data_tensor.shape)

train_label_tensor = torch.from_numpy(train_data['labels'])
print(train_label_tensor.shape)
train_dataset = TensorDataset(train_data_tensor, train_label_tensor)

print(train_dataset)
      
test_data= np.load("testset_normalized.npz")
test_data_tensor = torch.from_numpy(test_data['data'])
print(test_data_tensor.shape)
test_label_tensor = torch.from_numpy(test_data['labels'])
test_dataset = TensorDataset(test_data_tensor, test_label_tensor)

print(test_dataset)
# data.shape = [samples, 190, 16], labels.shape = [samples]

encoder = encoding.PoissonEncoder()

torch.Size([975, 180, 16])
torch.Size([975, 5])
<torch.utils.data.dataset.TensorDataset object at 0x000002206E2CEC90>
torch.Size([245, 180, 16])
<torch.utils.data.dataset.TensorDataset object at 0x000002206DEDB690>


In [2]:
class SNN(nn.Module):
    def __init__(self):
        super(SNN, self).__init__()
        self.SNN_net=nn.Sequential(
            layer.Linear(16, 200),
            neuron.LIFNode(),
            layer.Linear(200, 5)
        )

    def forward(self, x):
        
        output = self.SNN_net(x).mean(0)
        return output


In [5]:
model = SNN()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
batch_size = 50  
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
for epoch in range(100):
    model.train()
    for i, (data, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        data = data.permute(1, 0, 2)
        output = model(data.float())
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        functional.reset_net(model)
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}')


Epoch 0, Loss: 1.6094723176956176
Epoch 10, Loss: 1.550252046585083
Epoch 20, Loss: 1.1957360053062438
Epoch 30, Loss: 0.9791638445854187
Epoch 40, Loss: 0.9815982019901276
Epoch 50, Loss: 0.8030106419324875
Epoch 60, Loss: 0.8280121329426765
Epoch 70, Loss: 0.5409609532356262
Epoch 80, Loss: 0.531818352304399
Epoch 90, Loss: 0.49698328018188476


In [7]:
   
model.eval()
test_loader = DataLoader(test_dataset, batch_size=10)
correct = 0
total = 0
with torch.no_grad():
    for data, labels in test_loader:
        functional.reset_net(model)
        data = data.permute(1, 0, 2)
        print(data.shape)
        output = model(data.float())
        print(output.shape)
        predicted = torch.max(output, 1)[1]
        labels = torch.max(labels, 1)[1]
        print(predicted)
        print(labels)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        functional.reset_net(model)
print(f'Accuracy of the network on the test images: {100 * correct / total}%')


torch.Size([180, 10, 16])
torch.Size([10, 5])
tensor([0, 0, 0, 0, 0, 0, 3, 4, 3, 3])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
torch.Size([180, 10, 16])
torch.Size([10, 5])
tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 1])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
torch.Size([180, 10, 16])
torch.Size([10, 5])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
torch.Size([180, 10, 16])
torch.Size([10, 5])
tensor([0, 0, 2, 0, 0, 1, 1, 1, 1, 1])
tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1])
torch.Size([180, 10, 16])
torch.Size([10, 5])
tensor([4, 1, 1, 4, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
torch.Size([180, 10, 16])
torch.Size([10, 5])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
torch.Size([180, 10, 16])
torch.Size([10, 5])
tensor([4, 1, 1, 1, 1, 1, 1, 1, 1, 0])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
torch.Size([180, 10, 16])
torch.Size([10, 5])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
torch.Si