In [33]:
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen
from snntorch import utils
from snntorch import functional as SF
from snntorch import surrogate

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import os
import shutil
import matplotlib.pyplot as plt
import numpy as np
import itertools
from tqdm import tqdm

from custom_data import LoadDataset
import custom_data
from model import model, compute_loss

import matplotlib.pyplot as plt
from IPython.display import HTML

from collections import defaultdict


In [34]:

def forward_pass(net, data):
    """Run ``net`` over every time step of ``data`` and collect its output spikes.

    The hidden state of every LIF neuron in ``net`` is reset first (via
    ``utils.reset``), so each call starts from a clean state. ``data`` is
    walked along its first dimension, which is treated as the time axis.

    Args:
        net: callable that returns a ``(spikes, membrane)`` pair for one step.
        data: tensor whose leading dimension indexes time steps.

    Returns:
        The per-step spike outputs stacked along a new leading (time) dimension.
        Membrane potentials are discarded.
    """
    utils.reset(net)  # clear hidden states for all LIF neurons in net

    step_outputs = (net(data[t]) for t in range(data.size(0)))
    # Unpacking each result keeps the original's check that net returns exactly two values.
    return torch.stack([spk for spk, _mem in step_outputs])


In [35]:
# Network Architecture
# num_inputs / num_hidden / num_outputs / dtype are not referenced by the cells below.
num_inputs = 28*28
num_hidden = 1000
num_outputs = 10
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Temporal Dynamics
num_steps = 10  # time steps per sample; used to reshape test inputs in the evaluation loop
beta = 0.95     # passed to model.cnn
dataset_path = "dataset/"
batch_size = 1  # evaluation/plotting below reshapes labels to a single (pixel, pixel) frame

train_dataset = LoadDataset(dir=dataset_path, train=True)
test_dataset = LoadDataset(dir=dataset_path, train=False)

train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=custom_data.custom_collate, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=custom_data.custom_collate, shuffle=False)


spike_grad = surrogate.atan()
net = model.cnn(beta=beta, spike_grad=spike_grad).to(device)
model_path = 'models/model1.pth'
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a GPU
# can still be loaded when this notebook falls back to the CPU.
net.load_state_dict(torch.load(model_path, map_location=device))


# Training-related objects; not used by the evaluation loop below.
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

num_epochs = 100
num_iters = 50
pixel = 64          # side length of the square input/label frames
correct_rate = 0.1  # spike-rate threshold passed to compute_loss.culc_iou / show_pred
loss_hist = []



In [58]:


def save_img(spk_rec, label, image_path):
    """Save a 2x5 grid comparing the label with predictions at several thresholds.

    Panel 0 shows ``label``. Panel k (k = 1..9, row-major) shows
    ``compute_loss.show_pred(spk_rec, k/10)``. Its title gives the matching
    spike-count threshold and ``compute_loss.culc_iou`` at the same rate.

    Args:
        spk_rec: recorded output spikes; ``len(spk_rec)`` is used as the number
            of time steps.
        label: target mask, reshaped to ``(pixel, pixel)`` for display
            (presumably a single sample, i.e. batch size 1 — TODO confirm).
        image_path: destination file for the saved figure.
    """
    num_steps = len(spk_rec)
    nrows, ncols = 2, 5
    rate_denominator = 10  # panel k uses spike-rate threshold k / 10
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(13, 8), tight_layout=True)
    flat_axes = axes.ravel()  # row-major, same order as iterating rows then columns

    _label = label.reshape((pixel, pixel)).to('cpu')
    flat_axes[0].imshow(_label)
    flat_axes[0].set_title('label')

    for k, ax in enumerate(flat_axes[1:], start=1):
        # Derive the rate from the integer panel index. Accumulating "+= 0.1" drifts
        # (e.g. 0.7999999999999999), which mislabelled the 0.8 panel as th=7.
        th_rate = k / rate_denominator
        pred = compute_loss.show_pred(spk_rec, th_rate)
        iou = compute_loss.culc_iou(spk_rec, label, th_rate)
        pred = pred.reshape((pixel, pixel)).to('cpu')
        ax.imshow(pred)
        # Exact floor(th_rate * num_steps), computed in integers to avoid float truncation errors.
        th_count = (k * num_steps) // rate_denominator
        ax.set_title(f'pred(th={th_count})\n IoU:{round(iou,3)}')

    # The figure was created with tight_layout=True, so the layout is applied on save.
    fig.savefig(image_path)
    plt.close(fig)  # release the figure; this is called once per test sample


result_dir = 'result_img'
# Start from an empty output directory so images from earlier runs don't linger.
if os.path.exists(result_dir):
    shutil.rmtree(result_dir)
os.makedirs(result_dir)

hist = defaultdict(list)  # per-sample metrics, keyed 'loss' and 'iou'
with torch.no_grad():
    net.eval()
    for i, (data, label) in enumerate(tqdm(test_loader, desc='eval')):
        data = data.to(device)
        label = label.to(device)
        # NOTE(review): assumes len(data[0]) is the batch size — confirm against custom_collate.
        batch = len(data[0])
        data = data.reshape(num_steps, batch, 1, pixel, pixel)
        spk_rec = forward_pass(net, data)
        loss_val = compute_loss.spike_mse_loss(spk_rec, label)
        iou = compute_loss.culc_iou(spk_rec, label, correct_rate)
        # These metrics were previously computed and discarded; keep them for summaries.
        # float() assumes scalar results (batch_size == 1).
        hist['loss'].append(float(loss_val))
        hist['iou'].append(float(iou))
        img_path = os.path.join(result_dir, f'{i:04d}.png')
        # save_img recomputes the thresholded predictions it plots.
        save_img(spk_rec, label, img_path)
