In [43]:
# imports
import snntorch as snn
from snntorch import surrogate
#from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
from snntorch import spikeplot as splt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
import itertools

In [44]:
class Net(nn.Module):
    def __init__(self, num_steps, beta, spike_grad, dropout_rate=0.5):
        super().__init__()

        # Initialize Attributes
        self.num_steps = num_steps
        
        # Initialize layers
        self.conv1 = nn.Conv2d(3, 12, 5)
        self.bn1 = nn.BatchNorm2d(12)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        
        self.conv2 = nn.Conv2d(12, 64, 5)
        self.bn2 = nn.BatchNorm2d(64)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)

        self.dropout = nn.Dropout(dropout_rate)
        
        self.fc1 = nn.Linear(64*5*5, 10)
        self.bn3 = nn.BatchNorm1d(10)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)

    def recursive_forward(self, cur, mem, net, num_steps):
        spk_rec = []
        mem_rec = []
        utils.reset(net)
        for step in range(num_steps):
            spk, mem = net(cur)
            spk_rec.append(spk)
            mem_rec.append(mem)
        return torch.stack(spk_rec), torch.stack(mem_rec)
    
    def forward(self, x):

        # Initialize hidden states and outputs at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        cur1 = F.max_pool2d(self.bn1(self.conv1(x)), 2)
        #spk1, mem1 = self.lif1(cur1, mem1)
        spk1, mem1 = self.recursive_forward(cur1, mem1, self.lif1, self.num_steps)
        spk1 = torch.mean(spk1, dim=0)

        cur2 = F.max_pool2d(self.bn2(self.conv2(spk1)), 2)
        #spk2, mem2 = self.lif2(cur2, mem2)
        spk2, mem2 = self.recursive_forward(cur2, mem2, self.lif2, self.num_steps)
        spk2 = torch.mean(spk2, dim=0)

        cur3 = self.dropout(spk2.view(spk2.shape[0], -1))
        cur3 = self.fc1(cur3)

        cur3 = self.bn3(cur3)
        #spk3, mem3 = self.lif3(cur3, mem3)
        spk3, mem3 = self.recursive_forward(cur3, mem3, self.lif3, self.num_steps)

        return spk3, mem3

In [45]:
class MyNormalize:
    def __init__(self):
        pass
    def __call__(self, data):
        vmax, vmin = data.max(), data.min()
        return (data-vmin)/(vmax-vmin)

In [46]:
def forward_pass(net, num_steps, data):
  mem_rec = []
  spk_rec = []
  utils.reset(net)  # resets hidden states for all LIF neurons in net

  for step in range(num_steps):
      spk_out, mem_out = net(data)
      spk_rec.append(spk_out)
      mem_rec.append(mem_out)

  return torch.stack(spk_rec), torch.stack(mem_rec)

In [47]:
def batch_accuracy(train_loader, net, num_steps):
  with torch.no_grad():
    total = 0
    acc = 0
    net.eval()

    train_loader = iter(train_loader)
    for data, targets in train_loader:
      data = data.to(device)
      targets = targets.to(device)
      #spk_rec, _ = forward_pass(net, num_steps, data)
      spk_rec, _ = net(data)

      acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)
      total += spk_rec.size(1)

  return acc/total

In [48]:
transform = transforms.Compose([transforms.ToTensor(), MyNormalize()])
train_dataset = datasets.CIFAR10(root='/home/hwkang/jupyter/root/dataset', train=True, transform=transform, download=False)
test_dataset = datasets.CIFAR10(root='/home/hwkang/jupyter/root/dataset', train=False, transform=transform, download=False)

In [49]:
batch_size=64

In [50]:
### 훈련 데이터셋과 검증 데이터셋으로 분할
dataset_size = len(train_dataset)
train_size = int(0.9 * dataset_size)
valid_size = dataset_size - train_size
train_dataset, valid_dataset = random_split(train_dataset, [train_size, valid_size])
##*

### 데이터로더 준비
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [51]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [52]:
num_steps=50

# Model
net = Net(num_steps=num_steps, beta=0.9, spike_grad=surrogate.atan())
net = net.to(device)

# Epoch
num_epochs = 50

# Loss func.
loss_fn_rate = SF.ce_rate_loss()

# Optim.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999))

# Util.
loss_hist = []
test_acc_hist = []
counter = 0
min_acc = 0.0
best_acc_epoch = 0

In [53]:
# Training loop
for epoch in range(num_epochs):
    for data, targets in iter(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        net.train()
        #spk_rec, _ = forward_pass(net, num_steps, data)
        spk_rec, _ = net(data)

        # initialize the loss & sum over time
        loss = loss_fn_rate(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if counter % 100 == 0:
            test_acc = batch_accuracy(valid_loader, net, num_steps)
            print(f'Epoch: {epoch+1}/{num_epochs} | Iter.: {counter} | Acc.: {test_acc*100:.2f}%')
        counter += 1

        if( test_acc > min_acc ):
            best_acc_epoch = epoch
            best_model_state = net.state_dict()

Epoch: 1/50 | Iter.: 0 | Acc.: 9.74%
Epoch: 1/50 | Iter.: 100 | Acc.: 31.08%
Epoch: 1/50 | Iter.: 200 | Acc.: 35.54%
Epoch: 1/50 | Iter.: 300 | Acc.: 35.12%
Epoch: 1/50 | Iter.: 400 | Acc.: 35.54%
Epoch: 1/50 | Iter.: 500 | Acc.: 40.08%
Epoch: 1/50 | Iter.: 600 | Acc.: 37.18%
Epoch: 1/50 | Iter.: 700 | Acc.: 35.62%
Epoch: 2/50 | Iter.: 800 | Acc.: 36.26%
Epoch: 2/50 | Iter.: 900 | Acc.: 43.06%
Epoch: 2/50 | Iter.: 1000 | Acc.: 34.22%
Epoch: 2/50 | Iter.: 1100 | Acc.: 42.14%
Epoch: 2/50 | Iter.: 1200 | Acc.: 32.42%
Epoch: 2/50 | Iter.: 1300 | Acc.: 40.02%
Epoch: 2/50 | Iter.: 1400 | Acc.: 40.58%
Epoch: 3/50 | Iter.: 1500 | Acc.: 38.36%
Epoch: 3/50 | Iter.: 1600 | Acc.: 41.26%
Epoch: 3/50 | Iter.: 1700 | Acc.: 40.10%
Epoch: 3/50 | Iter.: 1800 | Acc.: 40.10%
Epoch: 3/50 | Iter.: 1900 | Acc.: 47.00%
Epoch: 3/50 | Iter.: 2000 | Acc.: 47.84%
Epoch: 3/50 | Iter.: 2100 | Acc.: 43.12%
Epoch: 4/50 | Iter.: 2200 | Acc.: 47.14%
Epoch: 4/50 | Iter.: 2300 | Acc.: 46.20%
Epoch: 4/50 | Iter.: 2400 | A

In [54]:
print(f'Load model from epoch [{best_acc_epoch}]')
net.load_state_dict(best_model_state)

# Test set forward pass
test_acc = batch_accuracy(test_loader, net, num_steps)
print(f"Test Acc: {test_acc * 100:.2f}%\n")

Load model from epoch [49]
Test Acc: 63.58%



In [55]:
!pwd

/home/hwkang/jupyter/root
