In [2]:
import torch
from spikingjelly.activation_based import neuron, encoding, functional
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import torch.optim as optim
import numpy as np
train_data= np.load("trainset_normalized.npz")
train_data_tensor = torch.from_numpy(train_data['data'])
train_data_tensor = train_data_tensor.permute(0,2,1) 
print(train_data_tensor.shape)
train_label_tensor = torch.from_numpy(train_data['labels'])
train_dataset = TensorDataset(train_data_tensor, train_label_tensor)

print(train_dataset)
      
test_data= np.load("testset_normalized.npz")
test_data_tensor = torch.from_numpy(test_data['data'])
print(test_data_tensor.shape)
test_data_tensor = test_data_tensor.permute(0,2,1) 
test_label_tensor = torch.from_numpy(test_data['labels'])
test_dataset = TensorDataset(test_data_tensor, test_label_tensor)

print(test_dataset)
# data.shape = [samples, 190, 16], labels.shape = [samples]

encoder = encoding.PoissonEncoder()
class SNN(nn.Module):
    def __init__(self):
        super(SNN, self).__init__()
        self.fc1 = nn.Linear(16*180, 128)  # 输入层到隐藏层
        self.lif1 = neuron.LIFNode()  # LIF 神经元
        self.fc2 = nn.Linear(128, 5)  # 隐藏层到输出层，假设有5个类别

    def forward(self, x):
        x = x.reshape(x.size(0), -1).float()  # 将输入展平
        x = self.fc1(x)
        x = self.lif1(x)
        x = self.fc2(x)
        return x
import tqdm
# 初始化模型和优化器
model = SNN()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
# Create a DataLoader
batch_size = 5  # Set your desired batch size
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
for epoch in tqdm.trange(30):
    model.train()
    for i, (data, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        # 编码数据
        out_fr = 0.
        for t in range(30):
            encoded_img = encoder(data)
            out_fr += model(encoded_img)
        output = out_fr / 30
        loss = criterion(output, labels)
        loss.backward(retain_graph=True)
        optimizer.step()
        functional.reset_net(model)
    print(f'Epoch {epoch}, Loss: {loss.item()}')
    


torch.Size([975, 16, 180])
<torch.utils.data.dataset.TensorDataset object at 0x000002A150040310>
torch.Size([245, 180, 16])
<torch.utils.data.dataset.TensorDataset object at 0x000002A147060810>


  0%|          | 0/30 [00:00<?, ?it/s]

  3%|▎         | 1/30 [00:07<03:44,  7.74s/it]

Epoch 0, Loss: 1.5303056716918946


  7%|▋         | 2/30 [00:14<03:28,  7.44s/it]

Epoch 1, Loss: 1.6823700666427612


 10%|█         | 3/30 [00:23<03:29,  7.75s/it]

Epoch 2, Loss: 1.6322103977203368


 13%|█▎        | 4/30 [00:31<03:33,  8.21s/it]

Epoch 3, Loss: 1.5577417612075806


 17%|█▋        | 5/30 [00:41<03:32,  8.50s/it]

Epoch 4, Loss: 1.6582450866699219


 20%|██        | 6/30 [00:49<03:25,  8.57s/it]

Epoch 5, Loss: 1.625087833404541


 23%|██▎       | 7/30 [00:58<03:18,  8.62s/it]

Epoch 6, Loss: 1.5850774765014648


 27%|██▋       | 8/30 [01:07<03:11,  8.73s/it]

Epoch 7, Loss: 1.7211539268493652


 30%|███       | 9/30 [01:16<03:03,  8.74s/it]

Epoch 8, Loss: 1.548635721206665


 33%|███▎      | 10/30 [01:25<02:56,  8.82s/it]

Epoch 9, Loss: 1.5748395681381226


 37%|███▋      | 11/30 [01:34<02:49,  8.92s/it]

Epoch 10, Loss: 1.7877707719802856


 40%|████      | 12/30 [01:44<02:45,  9.19s/it]

Epoch 11, Loss: 1.6436540365219117


 43%|████▎     | 13/30 [01:53<02:37,  9.29s/it]

Epoch 12, Loss: 1.5280796766281128


 47%|████▋     | 14/30 [02:02<02:28,  9.29s/it]

Epoch 13, Loss: 1.7013749599456787


 50%|█████     | 15/30 [02:12<02:19,  9.30s/it]

Epoch 14, Loss: 1.5336036443710328


 53%|█████▎    | 16/30 [02:21<02:10,  9.29s/it]

Epoch 15, Loss: 1.5460233688354492


 57%|█████▋    | 17/30 [02:30<02:00,  9.30s/it]

Epoch 16, Loss: 1.7060786962509156


 60%|██████    | 18/30 [02:40<01:51,  9.33s/it]

Epoch 17, Loss: 1.6296915531158447


 63%|██████▎   | 19/30 [02:49<01:43,  9.44s/it]

Epoch 18, Loss: 1.6789665222167969


 67%|██████▋   | 20/30 [02:59<01:35,  9.50s/it]

Epoch 19, Loss: 1.6168400764465332


 70%|███████   | 21/30 [03:09<01:26,  9.58s/it]

Epoch 20, Loss: 1.6014222860336305


 73%|███████▎  | 22/30 [03:19<01:16,  9.62s/it]

Epoch 21, Loss: 1.6383400440216065


 77%|███████▋  | 23/30 [03:29<01:08,  9.75s/it]

Epoch 22, Loss: 1.5894567489624023


 80%|████████  | 24/30 [03:39<00:59,  9.87s/it]

Epoch 23, Loss: 1.6994383573532104


 83%|████████▎ | 25/30 [03:48<00:48,  9.76s/it]

Epoch 24, Loss: 1.6373874187469482


 87%|████████▋ | 26/30 [03:58<00:39,  9.82s/it]

Epoch 25, Loss: 1.566546368598938


 90%|█████████ | 27/30 [04:08<00:29,  9.67s/it]

Epoch 26, Loss: 1.6119741678237915


 93%|█████████▎| 28/30 [04:17<00:18,  9.50s/it]

Epoch 27, Loss: 1.5290060758590698


 97%|█████████▋| 29/30 [04:26<00:09,  9.42s/it]

Epoch 28, Loss: 1.6465307950973511


100%|██████████| 30/30 [04:35<00:00,  9.19s/it]

Epoch 29, Loss: 1.7118331432342528





Accuracy of the network on the test images: 0.0%


In [7]:
model.eval()
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
correct = 0
total = 0
with torch.no_grad():
    for data, labels in test_loader:
        out_fr = 0.
        for t in range(10):
            encoded_img = encoder(data)
            out_fr += model(encoded_img)
        output = out_fr / 10
        _, predicted = torch.max(output, 1)
        total += labels.size(0)
        _, labels = torch.max(labels, 1)
        correct += (predicted == labels).sum().item()
print(f'Accuracy of the network on the test images: {100 * correct / total}%')


Accuracy of the network on the test images: 21.224489795918366%
