In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import snntorch as snn
from snntorch import utils
from snntorch import spikegen
from snntorch import functional as SF

import snntorch.spikeplot as splt
import matplotlib.pyplot as plt

import sys
sys.path.append('utility')

from preprocessing import *

In [2]:
# Seed torch's RNGs so weight init and loader shuffling are reproducible.
torch.manual_seed(42)

# ToTensor turns each MNIST image into a [1, 28, 28] float tensor in [0, 1];
# flattening then yields the 784-dim vector the fully connected models expect.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(torch.flatten),  # [1, 28, 28] -> [784]
])

# Download MNIST (if not already present) and load both splits.
mnist_kwargs = dict(root='dataset/', transform=transform, download=True)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

# Batch the data; only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

In [3]:
class ANN(nn.Module):
    """Fully connected baseline classifier: Linear -> ReLU -> Linear.

    Args:
        num_inputs: size of each flattened input vector (784 for MNIST).
        num_hidden: width of the hidden layer.
        num_outputs: number of classes; ``forward`` returns one raw score
            (logit) per class, suitable for ``nn.CrossEntropyLoss``.
    """

    def __init__(self, num_inputs=784, num_hidden=1000, num_outputs=10):
        super().__init__()
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(num_hidden, num_outputs)

    def forward(self, x):
        """Map ``[..., num_inputs]`` inputs to ``[..., num_outputs]`` logits."""
        x = self.fc1(x)
        x = self.relu(x)
        return self.fc2(x)

In [4]:
class SNN(nn.Module):
    """Two-layer spiking network: FC -> Leaky (LIF) -> FC -> Leaky.

    Args:
        num_inputs: size of each input vector (784 for flattened MNIST).
        num_hidden: number of hidden leaky integrate-and-fire neurons.
        num_outputs: number of output neurons (one per class).
        beta: membrane decay rate passed to both ``snn.Leaky`` layers.
    """

    def __init__(self, num_inputs=784, num_hidden=1000, num_outputs=10, beta=0.5):
        super().__init__()
        # FC - Leaky - FC - Leaky
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        """Run the network over every time step of ``x``.

        Args:
            x: input spikes with time as the leading dimension. In this notebook
                it is ``spikegen.rate(images, num_steps)``, i.e.
                ``[num_steps, batch_size, num_inputs]``.

        Returns:
            Tuple ``(spk1, mem1, spk2, mem2)``. Each element stacks the per-step
            tensors along a new leading time dimension: ``[num_steps, ..., num_hidden]``
            for layer 1 and ``[num_steps, ..., num_outputs]`` for layer 2.
        """
        # Put the initial membrane state on the input's device instead of
        # hard-coding CUDA, so the same model also runs on CPU.
        mem1 = self.lif1.init_leaky().to(x.device)
        mem2 = self.lif2.init_leaky().to(x.device)

        num_steps = x.size(0)
        spk1_rec, mem1_rec, spk2_rec, mem2_rec = [], [], [], []
        for step in range(num_steps):
            cur1 = self.fc1(x[step])  # [..., num_inputs] -> [..., num_hidden]
            spk1, mem1 = self.lif1(cur1, mem1)
            spk1_rec.append(spk1)
            mem1_rec.append(mem1)

            cur2 = self.fc2(spk1)  # [..., num_hidden] -> [..., num_outputs]
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        # num_steps * [..., n] -> [num_steps, ..., n]
        return (torch.stack(spk1_rec), torch.stack(mem1_rec),
                torch.stack(spk2_rec), torch.stack(mem2_rec))

In [5]:
class RetrogradeSNN(nn.Module):
    """Spiking network with a retrograde (output -> hidden) path that is not yet wired in.

    Architecture: FC -> Leaky -> FC -> Leaky, plus ``retrograde_fc`` mapping the
    ``num_outputs`` output spikes back to ``num_hidden`` hidden units.

    NOTE(review): ``retrograde_fc`` is created, so it appears in ``parameters()``
    and in the state dict, but ``forward`` never applies it. This model therefore
    currently computes exactly what ``SNN`` computes, and the parameters of
    ``retrograde_fc`` never receive gradients.

    Args:
        num_inputs: size of each input vector (784 for flattened MNIST).
        num_hidden: number of hidden leaky integrate-and-fire neurons.
        num_outputs: number of output neurons (one per class).
        beta: membrane decay rate passed to both ``snn.Leaky`` layers.
    """

    def __init__(self, num_inputs=784, num_hidden=1000, num_outputs=10, beta=0.5):
        super().__init__()
        # FC - Leaky - FC - Leaky (+ retrograde FC: outputs -> hidden)
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)
        self.retrograde_fc = nn.Linear(num_outputs, num_hidden)

    def forward(self, x):
        """Run the network over every time step of ``x``.

        Args:
            x: input spikes with time as the leading dimension. In this notebook
                it is ``spikegen.rate(images, num_steps)``, i.e.
                ``[num_steps, batch_size, num_inputs]``.

        Returns:
            Tuple ``(spk1, mem1, spk2, mem2)``. Each element stacks the per-step
            tensors along a new leading time dimension.
        """
        # Keep the membrane state on the input's device (same as SNN).
        mem1 = self.lif1.init_leaky().to(x.device)
        mem2 = self.lif2.init_leaky().to(x.device)
        # Output spikes of the previous time step, meant for the retrograde path.
        # Shaped like spk2 ([..., num_outputs]) and allocated on x's device.
        # The former torch.zeros(1000) had the wrong size and lived on the CPU.
        retrograde_spk2 = x.new_zeros((*x.shape[1:-1], self.fc2.out_features))

        num_steps = x.size(0)
        spk1_rec, mem1_rec, spk2_rec, mem2_rec = [], [], [], []
        for step in range(num_steps):
            # TODO(review): retrograde_spk2 is tracked but never consumed, and
            # self.retrograde_fc is never applied; wire the feedback in here.
            cur1 = self.fc1(x[step])  # [..., num_inputs] -> [..., num_hidden]
            spk1, mem1 = self.lif1(cur1, mem1)
            spk1_rec.append(spk1)
            mem1_rec.append(mem1)

            cur2 = self.fc2(spk1)  # [..., num_hidden] -> [..., num_outputs]
            spk2, mem2 = self.lif2(cur2, mem2)

            retrograde_spk2 = torch.clone(spk2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        # num_steps * [..., n] -> [num_steps, ..., n]
        return (torch.stack(spk1_rec), torch.stack(mem1_rec),
                torch.stack(spk2_rec), torch.stack(mem2_rec))

In [6]:
def train_ann(model, train_loader, criterion, optimizer, epochs=3):
    """Train a conventional (non-spiking) classifier and print the mean loss per epoch.

    Each batch is moved to the device holding ``model``'s parameters, so the
    function works for CPU and GPU models alike.

    Args:
        model: module mapping a batch of inputs to class scores.
        train_loader: iterable of ``(images, labels)`` batches.
        criterion: loss taking ``(outputs, labels)``, e.g. ``nn.CrossEntropyLoss()``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``train_loader``.
    """
    device = next(model.parameters()).device
    model.train()  # enable training-mode behavior
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()  # clear gradients from the previous step
            outputs = model(images)  # forward pass
            loss = criterion(outputs, labels)
            loss.backward()  # backward pass
            optimizer.step()  # parameter update

            running_loss += loss.item()
            num_batches += 1

        # Batches are counted while iterating, so loaders without __len__ work.
        # An empty loader reports nan instead of raising ZeroDivisionError.
        avg_loss = running_loss / num_batches if num_batches else float('nan')
        print(f'Epoch [{epoch+1}/{epochs}] | Loss: {avg_loss:.4f}')

def evaluate_ann(model, test_loader):
    """Print the top-1 accuracy (in percent) of a classifier on ``test_loader``.

    Each batch is moved to the device holding ``model``'s parameters.

    Args:
        model: module mapping a batch of inputs to class scores ``[batch_size, num_classes]``.
        test_loader: iterable of ``(images, labels)`` batches.
    """
    device = next(model.parameters()).device
    model.eval()  # switch to evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)  # highest-scoring class per sample
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # An empty loader reports nan instead of raising ZeroDivisionError.
    accuracy = 100 * correct / total if total else float('nan')
    print(f'Accuracy on test set: {accuracy:.2f}%')

In [7]:
# ANN baseline on the GPU: cross-entropy over raw logits, Adam with default hyperparameters.
ann = ANN().to('cuda')
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(ann.parameters())

In [8]:
train_ann(ann, train_loader, criterion, optimizer, epochs=1)  # single training epoch for the ANN baseline

Epoch [1/1] | Loss: 0.2550


In [9]:
evaluate_ann(model=ann, test_loader=test_loader)  # top-1 accuracy of the ANN baseline

Accuracy on test set: 96.66%


In [10]:
def train_snn(model, train_loader, criterion, optimizer, epochs=3, num_steps=8):
    """Train a spiking model on rate-coded inputs and print the mean loss per epoch.

    Each batch is moved to the device holding ``model``'s parameters and
    rate-coded with ``spikegen.rate`` into ``num_steps`` time steps.

    Args:
        model: spiking model whose forward returns ``(spk1, mem1, spk2, mem2)``;
            the output-layer spikes ``spk2`` are passed to ``criterion``.
        train_loader: iterable of ``(images, labels)`` batches.
        criterion: loss taking ``(spk_out, labels)``, e.g. ``SF.ce_rate_loss()``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``train_loader``.
        num_steps: number of simulation time steps for rate coding.
    """
    device = next(model.parameters()).device
    model.train()  # enable training-mode behavior
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Rate-code pixel intensities into spike trains: [num_steps, batch_size, 784]
            spk_in = spikegen.rate(images, num_steps)

            optimizer.zero_grad()  # clear gradients from the previous step
            _, _, spk_out, _ = model(spk_in)  # forward pass; keep output spikes only
            loss = criterion(spk_out, labels)
            loss.backward()  # backward pass
            optimizer.step()  # parameter update

            running_loss += loss.item()
            num_batches += 1

        # Batches are counted while iterating, so loaders without __len__ work.
        # An empty loader reports nan instead of raising ZeroDivisionError.
        avg_loss = running_loss / num_batches if num_batches else float('nan')
        print(f'Epoch [{epoch+1}/{epochs}] | Loss: {avg_loss:.4f}')

def evaluate_snn(model, test_loader, num_steps=8):
    """Print the spike-count accuracy (in percent) of a spiking model on ``test_loader``.

    Output spikes from all batches are gathered and scored together with
    ``SF.accuracy_rate``.

    Args:
        model: spiking model whose forward returns ``(spk1, mem1, spk2, mem2)``.
        test_loader: iterable of ``(images, labels)`` batches.
        num_steps: number of simulation time steps for rate coding; should match
            the value used in training.
    """
    device = next(model.parameters()).device
    model.eval()  # switch to evaluation mode
    all_spk_out = []
    all_labels = []
    with torch.no_grad():  # no gradients are needed for evaluation
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            spk_in = spikegen.rate(images, num_steps)

            _, _, spk_out, _ = model(spk_in)  # [num_steps, batch_size, num_neurons]
            all_spk_out.append(spk_out)
            all_labels.append(labels)

    if not all_labels:
        # torch.cat would fail on an empty list; accuracy is undefined here.
        print('Accuracy on test set: nan%')
        return

    # Join batches along the batch dimension (dim 1; dim 0 is time).
    all_spk_out = torch.cat(all_spk_out, dim=1)
    all_labels = torch.cat(all_labels, dim=0)

    accuracy = SF.accuracy_rate(all_spk_out, all_labels)
    print(f'Accuracy on test set: {accuracy*100:.2f}%')

In [11]:
def count_spikes(spk_in):
    """Print the number of nonzero entries (spikes) for each index along dim 1 of ``spk_in``.

    One ``#<index>: <count>`` line is printed per index, followed by a blank line.
    """
    columns = spk_in.split(1, dim=1)
    for column_idx, column in enumerate(columns):
        spike_total = int(column.count_nonzero())
        print(f'#{column_idx}: {spike_total}')
    print()

In [12]:
# Plain SNN on the GPU, trained with snntorch's cross-entropy spike-rate loss and default Adam.
simple_snn = SNN().to('cuda')
criterion = SF.ce_rate_loss()
optimizer = optim.Adam(simple_snn.parameters())

In [13]:
train_snn(simple_snn, train_loader=train_loader, criterion=criterion, optimizer=optimizer, epochs=1)  # single epoch

Epoch [1/1] | Loss: 1.5643


In [14]:
evaluate_snn(model=simple_snn, test_loader=test_loader)  # spike-count accuracy of the plain SNN

Accuracy on test set: 94.08%


In [None]:
# Retrograde SNN on the GPU with the same loss and optimizer settings as the plain SNN.
retrograde_snn = RetrogradeSNN().to('cuda')
criterion = SF.ce_rate_loss()
optimizer = optim.Adam(retrograde_snn.parameters())

In [None]:
train_snn(retrograde_snn, train_loader=train_loader, criterion=criterion, optimizer=optimizer, epochs=1)  # single epoch

In [None]:
evaluate_snn(model=retrograde_snn, test_loader=test_loader)  # spike-count accuracy of the retrograde SNN