In [1]:
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen
from snntorch import utils
from snntorch import functional as SF
from snntorch import surrogate

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F

import os
import shutil
import matplotlib.pyplot as plt
import numpy as np
import itertools
from tqdm import tqdm
# from collections import defaultdict

from custom_data import LoadDataset
import custom_data
from module import network, compute_loss

import matplotlib.pyplot as plt
from IPython.display import HTML

from collections import defaultdict


In [2]:
def forward_pass(net, data):
    soft = nn.Softmax2d()
    spk_rec = []
    utils.reset(net)  # resets hidden states for all LIF neurons in net

    for step in range(data.size(0)):  # data.size(0) = number of time steps
        spk_out, mem_out = net(data[step])
        spk_rec.append(spk_out)

    spk_rec = torch.stack(spk_rec)
    spk_cnt = compute_loss.spike_count(spk_rec, channel=True)# batch channel(n_class) pixel pixel 
    # pred_pro = torch.sigmoid(spk_cnt)# batch channel(n_class) pixel pixel
    # pred_pro = torch.tanh(spk_cnt/5)# batch channel(n_class) pixel pixel
    # print(spk_cnt)
    # pred_pro = soft(spk_cnt)
    pred_pro = F.softmax(spk_cnt, dim=1)
    print(spk_cnt.shape, pred_pro.shape)
    
    return pred_pro

# いろいろ定義

In [3]:
# Network Architecture
num_inputs = 28*28
num_hidden = 1000
num_outputs = 10
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Temporal Dynamics
num_steps = 10
beta = 0.95
dataset_path = "dataset/"
batch_size = 1

train_dataset = LoadDataset(dir = dataset_path, train=True)
test_dataset = LoadDataset(dir = dataset_path,  train=False)

train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=custom_data.custom_collate, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=custom_data.custom_collate, shuffle=False,)


spike_grad = surrogate.atan()
net = network.fcn2(beta=beta, spike_grad=spike_grad).to(device)
model_path = 'models/model1.pth'
net.load_state_dict(torch.load(model_path))



optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))
# loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

num_epochs = 100
num_iters = 50
pixel = 64
correct_rate = 0.5
loss_hist = []



# 出力結果の解析
危険とみなすスパイク数の閾値を変化させたときの、出力画像やiouを算出する


In [4]:


def save_img(pred_pro, label, image_path):
    # label = label.reshape((pixel, pixel)).to('cpu')
    num_steps = len(pred_pro)
    nrows = 2
    ncols = 5
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(13,8), tight_layout = True)
    th_rate = 0.1
    for i in range(nrows):
        for j in range(ncols):
            if i == 0 and j == 0:
                _label = label.reshape((pixel, pixel)).to('cpu')
                axes[i, j].imshow(_label)
                axes[i, j].set_title('label')
                continue
            # pred_pro = compute_loss.show_pred(pred_pro, th_rate)
            iou = compute_loss.culc_iou(pred_pro, label, th_rate)
            pred_pro_ = pred_pro[:, 1, :, :]
        
            pred_pro_ = pred_pro_.reshape((pixel, pixel)).to('cpu')
            pred_pro_ = torch.where(pred_pro_>= th_rate, 1, 0)
            axes[i,j].imshow(pred_pro_)
            axes[i, j].set_title(f'pred_pro(th={th_rate})\n IoU:{round(iou,3)}')
            hist[round(th_rate, 1)].append(iou)
            th_rate += 0.1
    plt.tight_layout()
    plt.savefig(image_path)
    plt.close()

hist = defaultdict(list)
result_dir = 'result_img'
if os.path.exists(result_dir):
        shutil.rmtree(result_dir)
os.makedirs(result_dir)

with torch.no_grad():
    net.eval()
    for i, (data, label) in enumerate(iter(test_loader)):
        data = data.to(device)
        label = label.to(device)
        batch = len(data[0])
        data = data.reshape(num_steps, batch, 1, pixel, pixel)
        pred_pro = forward_pass(net, data)
        
        iou = compute_loss.culc_iou(pred_pro, label, correct_rate)

        # pred_pro = compute_loss.show_pred(pred_pro, correct_rate)
        img_path = os.path.join(result_dir, f'{str(i).zfill(4)}.png')
      
        save_img(pred_pro, label, img_path)


torch.Size([1, 2, 64, 64]) torch.Size([1, 2, 64, 64])
torch.Size([1, 2, 64, 64]) torch.Size([1, 2, 64, 64])
torch.Size([1, 2, 64, 64]) torch.Size([1, 2, 64, 64])
torch.Size([1, 2, 64, 64]) torch.Size([1, 2, 64, 64])
torch.Size([1, 2, 64, 64]) torch.Size([1, 2, 64, 64])
torch.Size([1, 2, 64, 64]) torch.Size([1, 2, 64, 64])
torch.Size([1, 2, 64, 64]) torch.Size([1, 2, 64, 64])
torch.Size([1, 2, 64, 64]) torch.Size([1, 2, 64, 64])
torch.Size([1, 2, 64, 64]) torch.Size([1, 2, 64, 64])
torch.Size([1, 2, 64, 64]) torch.Size([1, 2, 64, 64])
torch.Size([1, 2, 64, 64]) torch.Size([1, 2, 64, 64])
torch.Size([1, 2, 64, 64]) torch.Size([1, 2, 64, 64])
torch.Size([1, 2, 64, 64]) torch.Size([1, 2, 64, 64])
torch.Size([1, 2, 64, 64]) torch.Size([1, 2, 64, 64])
torch.Size([1, 2, 64, 64]) torch.Size([1, 2, 64, 64])
torch.Size([1, 2, 64, 64]) torch.Size([1, 2, 64, 64])
torch.Size([1, 2, 64, 64]) torch.Size([1, 2, 64, 64])
torch.Size([1, 2, 64, 64]) torch.Size([1, 2, 64, 64])
torch.Size([1, 2, 64, 64]) t

## 閾値毎のIOUの平均算出

In [9]:
for key, value in hist.items():
    print(f'IoU_{key}:{np.mean(value)}')

IoU_0.1:0.19448467977694237
IoU_0.2:0.1929588017729111
IoU_0.3:0.19132550317794084
IoU_0.4:0.1894285321631469
IoU_0.5:0.1864378122438211
IoU_0.6:0.18147931992542
IoU_0.7:0.169953452370828
IoU_0.8:0.13539415887673387
IoU_0.9:0.026044897423125803
