In [12]:
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch
class ColorizedMNISTDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.data = []
        self.labels = []

        # 遍历文件夹加载图像和标签
        for label in range(10):  # 文件夹名为 0-9
            folder_path = os.path.join(root_dir, str(label))
            for file_name in os.listdir(folder_path):
                if file_name.endswith('.png'):
                    file_path = os.path.join(folder_path, file_name)
                    self.data.append(file_path)
                    self.labels.append(label)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data[idx]
        label = self.labels[idx]
        img = Image.open(img_path).convert('RGB')  # 转为 RGB 彩色
        if self.transform:
            img = self.transform(img)
        return img, label

# 数据变换
transform = transforms.Compose([
    transforms.ToTensor(),  # 转为 Tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化
])

# 加载数据集
train_dataset = ColorizedMNISTDataset("colorized-MNIST-master/training", transform)
test_dataset = ColorizedMNISTDataset("colorized-MNIST-master/testing", transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


In [13]:
import spikingjelly.activation_based.neuron as sj_neuron
import spikingjelly.activation_based.functional as sj_functional
import torch.nn as nn

class TransformerSNN(nn.Module):
    def __init__(self, num_classes=10):
        super(TransformerSNN, self).__init__()
        # Transformer Encoder
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=1024),
            num_layers=6
        )
        self.embedding = nn.Linear(28 * 28 * 3, 512)  # 将图像展平并嵌入到 Transformer 的输入维度
        
        # 使用 LIF 神经元
        self.lif_neuron = sj_neuron.LIFNode()  
        self.fc = nn.Linear(512, num_classes)  # 输出分类

    def forward(self, x):
        x = x.view(x.size(0), -1)  # 展平图像
        x = self.embedding(x)
        x = self.transformer(x.unsqueeze(1))  # 增加时间步维度
        
        # 应用 SNN 的 LIFNode
        x = self.lif_neuron(x)
        
        x = self.fc(x[:, -1, :])  # 使用最后时间步的输出
        return x
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerSNN(num_classes=10).to(device)

In [14]:
from spikingjelly.activation_based.functional import reset_net
import torch.optim as optim

# 假设 model 已定义好
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.93)
class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs, targets):
        BCE_loss = nn.CrossEntropyLoss()(inputs, targets)
        pt = torch.exp(-BCE_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        return focal_loss

criterion = FocalLoss()


# 修正后的训练函数
def train_model(model, train_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        
        # 重置网络状态
        reset_net(model)

        # 前向传播
        outputs = model(imgs)
        
        # 计算损失
        loss = criterion(outputs, labels)
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    print(f"Train Loss: {total_loss:.4f}, Accuracy: {correct / total:.4f}")

# 修正后的测试函数
def test_model(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():  # 测试时不需要计算梯度
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            
            # 重置网络状态
            reset_net(model)

            # 前向传播
            outputs = model(imgs)
            _, preds = torch.max(outputs, 1)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

    print(f"Test Accuracy: {correct / total:.4f}")
# 训练模型
for epoch in range(40):  # 假设训练 10 个 Epoch
    print(f"Epoch {epoch + 1}")
    train_model(model, train_loader, optimizer, criterion, device)
    test_model(model, test_loader, device)


Epoch 1
Train Loss: 63.3149, Accuracy: 0.2108
Test Accuracy: 0.1324
Epoch 2
Train Loss: 57.0664, Accuracy: 0.4758
Test Accuracy: 0.2580
Epoch 3
Train Loss: 47.6427, Accuracy: 0.5739
Test Accuracy: 0.4118
Epoch 4
Train Loss: 37.6107, Accuracy: 0.6388
Test Accuracy: 0.5145
Epoch 5
Train Loss: 28.7830, Accuracy: 0.7007
Test Accuracy: 0.6320
Epoch 6
Train Loss: 22.0325, Accuracy: 0.7494
Test Accuracy: 0.6885
Epoch 7
Train Loss: 17.1527, Accuracy: 0.7916
Test Accuracy: 0.7269
Epoch 8
Train Loss: 13.6151, Accuracy: 0.8212
Test Accuracy: 0.7401
Epoch 9
Train Loss: 11.0201, Accuracy: 0.8360
Test Accuracy: 0.7900
Epoch 10
Train Loss: 9.1877, Accuracy: 0.8518
Test Accuracy: 0.7736
Epoch 11
Train Loss: 7.7434, Accuracy: 0.8609
Test Accuracy: 0.7792
Epoch 12
Train Loss: 6.4889, Accuracy: 0.8735
Test Accuracy: 0.7955
Epoch 13
Train Loss: 5.7566, Accuracy: 0.8746
Test Accuracy: 0.8144
Epoch 14
Train Loss: 5.1808, Accuracy: 0.8809
Test Accuracy: 0.8127
Epoch 15
Train Loss: 4.6539, Accuracy: 0.8872
Te

In [15]:
#模型保存
torch.save(model.state_dict(), "transformer_snn_model3.pth")

In [16]:
# 加载模型
#model = TransformerSNN(num_classes=10)  # 与保存时模型结构相同
#model.load_state_dict(torch.load("model_epoch_10.pth"))
#model.to(device)  # 将模型加载到对应的设备
