In [1]:
import numpy as np
import tonic
import tonic.transforms as transforms
import pandas as pd
from torch.utils.data import Dataset, DataLoader
#from tonic.dataset import Dataset
from typing import Callable, Optional
import torch
import matplotlib.pyplot as plt
from torchvision.ops import masks_to_boxes
from torchvision.utils import draw_bounding_boxes
import matplotlib.patches as patches
from tqdm.notebook import tqdm
from torchvision.transforms import Lambda

In [2]:
import sys
sys.path.append("./mnist_sg_cnn")

In [3]:
import utils
import snn_utils
import base_model
import lenet_decolle_model

In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [5]:
num_labels = 10
num_bins_per_frame = 100
epochs = 100
batch_size = 128

In [6]:
sensor_size = tonic.datasets.NMNIST.sensor_size
sensor_size

frame_transform = transforms.ToFrame(sensor_size=sensor_size, n_time_bins=num_bins_per_frame)
denoise_transform = tonic.transforms.Denoise(filter_time=10000)
transform = transforms.Compose([denoise_transform, frame_transform])


target_transform = Lambda(lambda y: torch.zeros(num_labels, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1)) # One hot encode

In [7]:
trainset = tonic.datasets.NMNIST(save_to="./data/tonic/NMNIST", train=True, transform=transform, target_transform=target_transform)
testset = tonic.datasets.NMNIST(save_to='./data/tonic/NMNIST', transform=frame_transform, train=False, target_transform=target_transform)

In [8]:
from torch.utils.data import DataLoader
#from tonic import DiskCachedDataset


#cached_trainset = DiskCachedDataset(trainset, cache_path='./data/cache/nmnist/train')
#cached_testset = DiskCachedDataset(testset, cache_path='./data/cache/nmnist/test')


trainloader = DataLoader(trainset, batch_size=batch_size) #collate_fn=tonic.collation.PadTensors())
testloader = DataLoader(testset, batch_size=batch_size) #collate_fn=tonic.collation.PadTensors(batch_first=False))

In [9]:
data, target = next(iter(trainloader))

In [10]:
data.shape

torch.Size([128, 100, 2, 34, 34])

In [11]:
target.shape

torch.Size([128, 10])

In [12]:
#target

In [13]:

def decolle_loss(r, s, tgt):
    loss_tv = 0
    for i in range(len(r)):
        #print(r[i].shape)
        #print(tgt.shape)
        loss_tv += loss(r[i],tgt) 
    return loss_tv

loss = torch.nn.SmoothL1Loss()

convnet_sg = lenet_decolle_model.LenetDECOLLE( out_channels=10,
                    Nhid=[16,32], #Number of convolution channels
                    Mhid=[64],
                    kernel_size=[7],
                    pool_size=[2,2],
                    input_shape=data.shape[2:],
                    alpha=[.95],
                    alpharp=[.65],
                    beta=[.92],
                    num_conv_layers=2,
                    num_mlp_layers=1,
                    lc_ampl=.5).to(device)

convnet_sg

data_d = data.to(device)
target_d = target.to(device)
convnet_sg.init_parameters(data_d)

LAYER SIZE: 4624
STDV: 0.007352941176470588
LAYER SIZE: 2048
STDV: 0.011048543456039804
LAYER SIZE: 64
STDV: 0.0625


In [14]:
# %timeit -o -r 10 data_d.transpose(0, 1)

1.38 µs ± 8.88 ns per loop (mean ± std. dev. of 10 runs, 1,000,000 loops each)


<TimeitResult : 1.38 µs ± 8.88 ns per loop (mean ± std. dev. of 10 runs, 1,000,000 loops each)>

In [16]:
from tqdm.notebook import tqdm

opt_conv = torch.optim.Adamax(convnet_sg.get_trainable_parameters(), lr=1e-9, betas=[0., .95])
for e in range(epochs):        
    error = []
    accuracy=[]
    for data, label in tqdm(iter(trainloader)):
        convnet_sg.train()
        loss_hist = 0
        data_d = data.to(device)
        label_d = label.to(device)
        convnet_sg.init(data_d, burnin=10)
        readout = 0

        #print(label_d)
        #print(label_d.shape)
        #break

        data_d = data_d.transpose(0, 1)
        
        for n in range(num_bins_per_frame):
           # print(f"Data shape: {data_d[n].shape}, overall: {data_d.shape}")
            st, rt, ut = convnet_sg.forward(data_d[n])
            #print("Readout")
            #print(len(rt), rt[0].shape, rt[1].shape, rt[2].shape)
            #print(f"Label: {label_d[n].shape} overall {label_d.shape}")
            #print(label_d[n].shape)
            #print(label_d[n])
            loss_tv = decolle_loss(rt, st, label_d)
            loss_tv.backward()
            opt_conv.step()
            opt_conv.zero_grad()
            loss_hist += loss_tv
            readout += rt[-1]
        error += (readout.argmax(axis=1)!=label_d.argmax(axis=1)).float()
        accuracy+=(readout.argmax(axis=1)==label_d.argmax(axis=1)).float()
    print('Training Error', torch.mean(torch.Tensor(error)).data)
    print('Training accuracy', torch.mean(torch.Tensor(accuracy)).data)     
    print('Epoch', e, 'Loss', loss_hist.data)
    PATH = './mnist_network_sg_conv.pth'
    torch.save(convnet_sg.state_dict(), PATH)

  0%|          | 0/469 [00:00<?, ?it/s]

Training Error tensor(0.2207)
Training accuracy tensor(0.7793)
Epoch 0 Loss tensor(7.8938, device='cuda:0')


  0%|          | 0/469 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import spikeplot as splt
from snntorch import utils
import torch.nn as nn

In [None]:
# neuron and simulation parameters
spike_grad = surrogate.atan()
beta = 0.5

#  Initialize Network
net = nn.Sequential(nn.Conv2d(2, 12, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Conv2d(12, 32, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Flatten(),
                    nn.Linear(32*5*5, 10),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
                    ).to(device)

In [None]:
# this time, we won't return membrane as we don't need it

def forward_pass(net, data):
  spk_rec = []
  utils.reset(net)  # resets hidden states for all LIF neurons in net

  for step in range(data.size(0)):  # data.size(0) = number of time steps
      spk_out, mem_out = net(data[step])
      spk_rec.append(spk_out)

  return torch.stack(spk_rec)

In [None]:
optimizer = torch.optim.Adam(net.parameters(), lr=2e-2, betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

In [None]:
num_epochs = 1
num_iters = 50

loss_hist = []
acc_hist = []

# training loop
for epoch in range(num_epochs):
    for i, (data, targets) in tqdm(enumerate(iter(dataloader)), desc=f"Epoch {e}"):
        data = data.to(device)
        targets = targets.to(device)

        net.train()
        spk_rec = forward_pass(net, data)
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss_val.item():.2f}")

        acc = SF.accuracy_rate(spk_rec, targets)
        acc_hist.append(acc)
        print(f"Accuracy: {acc * 100:.2f}%\n")

        # training loop breaks after 50 iterations
        if i == num_iters:
          break