# CRI CIFAR Demonstration with snnTorch 

## Training SNN with snnTorch

In [1]:
!pip install snntorch

Collecting snntorch
  Using cached snntorch-0.5.3-py2.py3-none-any.whl (95 kB)
Collecting torch>=1.1.0
  Using cached torch-1.13.1-cp39-cp39-manylinux1_x86_64.whl (887.4 MB)
Collecting matplotlib
  Downloading matplotlib-3.6.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.8 MB)
[K     |████████████████████████████████| 11.8 MB 11.1 MB/s eta 0:00:01
[?25hCollecting numpy>=1.17
  Downloading numpy-1.24.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.3 MB)
[K     |████████████████████████████████| 17.3 MB 112.0 MB/s eta 0:00:01
[?25hCollecting pandas
  Downloading pandas-1.5.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.2 MB)
[K     |████████████████████████████████| 12.2 MB 101.5 MB/s eta 0:00:01
[?25hCollecting nvidia-cuda-nvrtc-cu11==11.7.99
  Using cached nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl (21.0 MB)
Collecting nvidia-cublas-cu11==11.10.3.66
  Using cached nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1

In [1]:
# imports
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
from snntorch import spikeplot as splt
from snntorch import spikegen
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import os
import shutil
import time
from quant_layer import weight_quantize_fn

  from .autonotebook import tqdm as notebook_tqdm
  warn(f"Failed to load image Python extension: {e}")


### Import CIFAR datasets

In [4]:
# dataloader arguments
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Define a transform
transform = transforms.Compose([
            transforms.ToTensor()])

batch_size = 4
subset = 10

CIFAR_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
CIFAR_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)


train_loader = DataLoader(CIFAR_train,batch_size=batch_size,shuffle=True,num_workers=2)
test_loader = DataLoader(CIFAR_test,batch_size=batch_size,shuffle=True, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
classes = CIFAR_train.classes

In [7]:
classes

['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

### Define the network

In [15]:
# Network Architecture
num_inputs = 28*28
num_hidden_0 = 2500
num_hidden_1 = 2000
num_hidden_2 = 1500
num_hidden_3 = 1000
num_hidden_4 = 500
num_hidden = 1000
num_outputs = 10

# Temporal Dynamics
num_steps = 25
beta = 1.0
spike_grad = surrogate.sigmoid(slope=25)

In [40]:
net = nn.Sequential(nn.Linear(num_inputs, num_hidden, bias=True), 
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
                    nn.Linear(num_hidden, num_outputs, bias=True),
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True, output=True)).to(device)
            

In [55]:
net6 = nn.Sequential(nn.Linear(num_inputs, num_hidden_0), 
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
                    nn.Linear(num_hidden_0, num_hidden_1), 
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
                    nn.Linear(num_hidden_1, num_hidden_2), 
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
                    nn.Linear(num_hidden_2, num_hidden_3), 
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
                    nn.Linear(num_hidden_3, num_hidden_4), 
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
                    nn.Linear(num_hidden_4, num_outputs),
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True, output=True)).to(device)

In [17]:
num_hidden_0 = 6000
num_hidden_1 = 5500
num_hidden_2 = 5000
num_hidden_3 = 4500
num_hidden_4 = 4000
num_hidden_5 = 3500
num_hidden_6 = 3000
num_hidden_7 = 2500
num_hidden_8 = 2000
num_hidden_9 = 1500
num_hidden_10 = 500
num_hidden_11 = 100
num_hidden_12 = 600
num_hidden_13 = 400
num_hidden_14 = 200
num_hidden_15 = 100


In [20]:
net12 = nn.Sequential(nn.Linear(num_inputs, num_hidden_0), 
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
                    nn.Linear(num_hidden_0, num_hidden_1), 
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
                    nn.Linear(num_hidden_1, num_hidden_2), 
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
                    nn.Linear(num_hidden_2, num_hidden_3), 
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
                    nn.Linear(num_hidden_3, num_hidden_4), 
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
                    nn.Linear(num_hidden_4, num_hidden_5), 
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
                    nn.Linear(num_hidden_5, num_hidden_6), 
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
                    nn.Linear(num_hidden_6, num_hidden_7), 
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
                    nn.Linear(num_hidden_7, num_hidden_8), 
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
                    nn.Linear(num_hidden_8, num_hidden_9), 
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
                    nn.Linear(num_hidden_9, num_hidden_10), 
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
#                     nn.Linear(num_hidden_10, num_hidden_11), 
#                     snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
#                     nn.Linear(num_hidden_11, num_hidden_12), 
#                     snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
#                     nn.Linear(num_hidden_12, num_hidden_13), 
#                     snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
#                     nn.Linear(num_hidden_13, num_hidden_14), 
#                     snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
#                     nn.Linear(num_hidden_14, num_hidden_15), 
#                     snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
                    nn.Linear(num_hidden_10, num_outputs),
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True, output=True)).to(device)

In [30]:
net12

Sequential(
  (0): Linear(in_features=3072, out_features=6000, bias=True)
  (1): Leaky()
  (2): Linear(in_features=6000, out_features=5500, bias=True)
  (3): Leaky()
  (4): Linear(in_features=5500, out_features=5000, bias=True)
  (5): Leaky()
  (6): Linear(in_features=5000, out_features=4500, bias=True)
  (7): Leaky()
  (8): Linear(in_features=4500, out_features=4000, bias=True)
  (9): Leaky()
  (10): Linear(in_features=4000, out_features=3500, bias=True)
  (11): Leaky()
  (12): Linear(in_features=3500, out_features=3000, bias=True)
  (13): Leaky()
  (14): Linear(in_features=3000, out_features=2500, bias=True)
  (15): Leaky()
  (16): Linear(in_features=2500, out_features=2000, bias=True)
  (17): Leaky()
  (18): Linear(in_features=2000, out_features=1500, bias=True)
  (19): Leaky()
  (20): Linear(in_features=1500, out_features=500, bias=True)
  (21): Leaky()
  (22): Linear(in_features=500, out_features=10, bias=True)
  (23): Leaky()
)

In [21]:
data, targets = next(iter(train_loader))
data = data.to(device)
targets = targets.to(device)

In [22]:
def forward_pass(net, num_steps, data, batch_size):
    mem_rec = []
    spk_rec = []
    utils.reset(net)  # resets hidden states for all LIF neurons in net

    for step in range(num_steps):
        spk_out, mem_out = net(data.view(batch_size, -1))
        spk_rec.append(spk_out)
        mem_rec.append(mem_out)
  
    return torch.stack(spk_rec), torch.stack(mem_rec)

In [23]:
spk_rec, mem_rec = forward_pass(net6, num_steps, data, batch_size)

In [32]:
spk_rec.shape

torch.Size([25, 4, 10])

### Loss Functions

In [33]:
loss_fn = SF.ce_rate_loss()

In [36]:
loss_val = loss_fn(spk_rec, targets)

print(f"The loss from an untrained network is {loss_val.item():.3f}")

The loss from an untrained network is 2.299


### Accuracy 

In [37]:
acc = SF.accuracy_rate(spk_rec, targets)

print(f"The accuracy of a single batch using an untrained network is {acc*100:.3f}%")

The accuracy of a single batch using an untrained network is 25.000%


In [38]:
def batch_accuracy(train_loader, net, num_steps, batch_size):
    with torch.no_grad():
        total = 0
        acc = 0
        net.eval()

        train_loader = iter(train_loader)
        for data, targets in train_loader:
            data = data.to(device)
            targets = targets.to(device)
            spk_rec, _ = forward_pass(net, num_steps, data, batch_size)

            acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)
            total += spk_rec.size(1)

    return acc/total

### Training

In [44]:
optimizer = torch.optim.SGD(net6.parameters(), lr=1e-3, momentum=0.9)

In [56]:
num_epochs = 2
loss_hist = []
test_loss_hist = []
counter = 0

# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        net6.train()
        spk_rec, mem_rec = forward_pass(net6, num_steps, data, batch_size)

        # initialize the loss & sum over time
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        loss_val += loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            net6.eval()
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Test set forward pass
            test_spk, test_mem = forward_pass(net6, num_steps, test_data,batch_size)

            # Test set loss
            test_loss = torch.zeros((1), dtype=dtype, device=device)
            test_loss += loss_fn(test_spk, test_targets)

            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy
            if counter % 50 == 0:
                    print(f"Epoch {epoch}, Iteration {iter_counter}")
                    print(f"Train Set Loss: {loss_hist[counter]:.3f}")
                    print(f"Test Set Loss: {test_loss_hist[counter]:.3f}")
                    train_acc = SF.accuracy_rate(spk_rec, targets)
                    test_acc = SF.accuracy_rate(test_spk, test_targets)
                    print(f"Train set accuracy for a single minibatch: {train_acc*100:.2f}%")
                    print(f"Test set accuracy for a single minibatch: {test_acc*100:.2f}%")
                    print("\n")
            counter += 1
            iter_counter +=1

Epoch 0, Iteration 0
Train Set Loss: 2.303
Test Set Loss: 2.303
Train set accuracy for a single minibatch: 0.00%
Test set accuracy for a single minibatch: 0.00%


Epoch 0, Iteration 50
Train Set Loss: 2.303
Test Set Loss: 2.303
Train set accuracy for a single minibatch: 0.00%
Test set accuracy for a single minibatch: 0.00%


Epoch 0, Iteration 100
Train Set Loss: 2.303
Test Set Loss: 2.303
Train set accuracy for a single minibatch: 0.00%
Test set accuracy for a single minibatch: 0.00%


Epoch 0, Iteration 150
Train Set Loss: 2.303
Test Set Loss: 2.303
Train set accuracy for a single minibatch: 0.00%
Test set accuracy for a single minibatch: 0.00%


Epoch 0, Iteration 200
Train Set Loss: 2.303
Test Set Loss: 2.303
Train set accuracy for a single minibatch: 0.00%
Test set accuracy for a single minibatch: 0.00%


Epoch 0, Iteration 250
Train Set Loss: 2.303
Test Set Loss: 2.303
Train set accuracy for a single minibatch: 0.00%
Test set accuracy for a single minibatch: 25.00%


Epoch 0, Ite

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f32f2682dc0>
Traceback (most recent call last):
  File "/Volumes/export/isn/keli/miniconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__
    self._shutdown_workers()
  File "/Volumes/export/isn/keli/miniconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1430, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Volumes/export/isn/keli/miniconda3/lib/python3.9/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/Volumes/export/isn/keli/miniconda3/lib/python3.9/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/Volumes/export/isn/keli/miniconda3/lib/python3.9/multiprocessing/connection.py", line 936, in wait
    ready = selector.select(timeout)
  File "/Volumes/export/isn/keli/miniconda3/lib/python3.9/selectors.py", line 416, in select
 

RuntimeError: DataLoader worker (pid(s) 322224, 322236) exited unexpectedly

In [None]:
test_acc = batch_accuracy(test_loader, net6, num_steps, batch_size)

print(f"The total accuracy on the test set is: {test_acc * 100:.2f}%")

In [None]:
# Plot Loss
fig = plt.figure(facecolor="w", figsize=(10, 5))
plt.plot(loss_hist)
plt.plot(test_loss_hist)
plt.title("Loss Curves")
plt.legend(["Train Loss", "Test Loss"])
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.show()

# Plot Accuracy
fig = plt.figure(facecolor="w")
plt.plot(test_acc_hist)
plt.title("Test Set Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.show()

### Save Models

In [10]:
def save_checkpoint(state, is_quan, fdir):
    filepath = os.path.join(fdir, 'checkpoint.pth')
    torch.save(state, filepath)
    if is_quan:
        shutil.copyfile(filepath, os.path.join(fdir, 'model_quan.pth.tar'))
    else:
        shutil.copyfile(filepath, os.path.join(fdir, 'model_best.pth.tar'))

In [11]:
if not os.path.exists('result'):
    os.makedirs('result')
fdir = 'result/'
if not os.path.exists(fdir):
    os.makedirs(fdir)

In [None]:
save_checkpoint({'state_dict': net16.state_dict(),}, 0, fdir)

### Quantization

In [None]:
def weight_quantization(b):

    def uniform_quant(x, b):
        xdiv = x.mul((2 ** b - 1))
        xhard = xdiv.round().div(2 ** b - 1)
        #print('uniform quant bit: ', b)
        return xhard

    class _pq(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, alpha):
            input.div_(alpha)                          # weights are first divided by alpha
            input_c = input.clamp(min=-1, max=1)       # then clipped to [-1,1]
            sign = input_c.sign()
            input_abs = input_c.abs()
            input_q = uniform_quant(input_abs, b).mul(sign)
            ctx.save_for_backward(input, input_q)
            input_q = input_q.mul(alpha)               # rescale to the original range
            return input_q

        @staticmethod
        def backward(ctx, grad_output):
            grad_input = grad_output.clone()             # grad for weights will not be clipped
            input, input_q = ctx.saved_tensors
            i = (input.abs()>1.).float()     # >1 means clipped. # output matrix is a form of [True, False, True, ...]
            sign = input.sign()              # output matrix is a form of [+1, -1, -1, +1, ...]
            #grad_alpha = (grad_output*(sign*i + (input_q-input)*(1-i))).sum()
            grad_alpha = (grad_output*(sign*i + (0.0)*(1-i))).sum()
            # above line, if i = True,  and sign = +1, "grad_alpha = grad_output * 1"
            #             if i = False, "grad_alpha = grad_output * (input_q-input)"
            grad_input = grad_input*(1-i)
            return grad_input, grad_alpha

    return _pq().apply

class weight_quantize_fn(nn.Module):
    def __init__(self, w_bit):
        super(weight_quantize_fn, self).__init__()
        self.w_bit = w_bit-1
        #self.wgt_alpha = wgt_alpha
        self.weight_q = weight_quantization(b=self.w_bit)
        #self.register_parameter('wgt_alpha', Parameter(torch.tensor(3.0)))
    def forward(self, weight):
        #mean = weight.data.mean()
        #std = weight.data.std()
        #weight = weight.add(-mean).div(std)      # weights normalization
        weight_q = self.weight_q(weight, self.wgt_alpha)

        return weight_q

In [None]:
w_alpha=1
w_bits=16
weight_quant = weight_quantize_fn(w_bit= w_bits)  ## define quant function
weight_quant.wgt_alpha = w_alpha
fc1_quant      = weight_quant(net6[0].weight)
w_delta        = w_alpha/(2**(w_bits-1)-1)
fc1_int        = fc1_quant/w_delta
print("FC1 Weights: \n",fc1_int)

In [None]:

for layer in net6:
        if isinstance(layer, torch.nn.Linear):
                layer.weight = Parameter(weight_quant(layer.weight))
                w_delta = w_alpha/(2**(w_bits-1)-1)
                layer.weight = Parameter(layer.weight/w_delta)
                layer.bias = Parameter(layer.bias/w_delta)
#                 print(layer.weight)
#                 print(layer.bias)
        if isinstance(layer, torch.nn.Conv2d):
                layer.weight = Parameter(weight_quant(layer.weight))
                w_delta = w_alpha/(2**(w_bits-1)-1)
                layer.weight = Parameter(layer.weight/w_delta)
                layer.bias = Parameter(layer.bias/w_delta)
#                 print(layer.weight)
#                 print(layer.bias)
        if isinstance(layer, snn.Leaky):
                layer.threshold = layer.threshold/w_delta

In [None]:
test_acc = batch_accuracy(test_loader, net6, num_steps, batch_size)

print(f"The total accuracy on the test set is: {test_acc * 100:.2f}%")

In [None]:
save_checkpoint({'state_dict': net6.state_dict(),}, 1, fdir)

### Load Saved Model

In [41]:
best_model_path = '/Volumes/export/isn/keli/Desktop/CRI/result/model_quan.pth.tar'
checkpoint = torch.load(best_model_path)
net.load_state_dict(checkpoint['state_dict'])
net.eval()

Sequential(
  (0): Linear(in_features=784, out_features=1000, bias=True)
  (1): Leaky()
  (2): Linear(in_features=1000, out_features=10, bias=True)
  (3): Leaky()
)

In [42]:
test_acc = batch_accuracy(test_loader, net, num_steps, batch_size)
print(f"The total accuracy on the test set is: {test_acc * 100:.2f}%")

NameError: name 'batch_accuracy' is not defined

## Mapping into CRI

In [19]:
for i, layer in enumerate(net):
    if i % 2 == 0:
        print(layer.weight.shape)

torch.Size([1000, 784])
torch.Size([10, 1000])


In [21]:
# extract weights and bias for torchsnn
layers, biases = [], []
for i, layer in enumerate(net):
    if i % 2 == 0:
        layers.append(layer.weight.detach().cpu().numpy())
        biases.append(layer.bias.detach().cpu().numpy())

print(np.min(layers[1]))
print(np.max(layers[1]))

-7424.0
5286.0


In [22]:
for layerNum, layer in enumerate(layers):
    print(layer.shape)
    print(biases[layerNum].shape)

(1000, 784)
(1000,)
(10, 1000)
(10,)


In [43]:
def conv2dOutputSize(layer,inputSize):
    H_out = (inputSize[0] + layer.padding[0]-layer.dilation[0]*(layer.kernel_size[0]-1)-1)/layer.stride[0] +1
    W_out = (inputSize[1] + layer.padding[0]-layer.dilation[1]*(layer.kernel_size[1]-1)-1)/layer.stride[1] +1
    return [layer.out_channels,int(H_out),int(W_out)]

def maxPoolOutputSize(layer,inputSize):
    H_out = (inputSize[1] + layer.padding - layer.dilation*(layer.kernel_size-1)-1)/layer.stride +1
    W_out = (inputSize[2] + layer.padding - layer.dilation*(layer.kernel_size-1)-1)/layer.stride +1
    return [inputSize[0],int(H_out),int(W_out)]

def AvgPoolOutputSize(layer,inputSize):
    H_out = (inputSize[1] + layer.padding*2 - (layer.kernel_size-1))/layer.stride +1
    W_out = (inputSize[2] + layer.padding*2 - (layer.kernel_size-1))/layer.stride +1
    return [inputSize[0],int(H_out),int(W_out)]

def conv2dToCRI(inputs,output,layer,axonsDict=None,neuronsDict=None):
    Hk, Wk = layer.kernel_size
    Ho, Wo = output.shape[1],output.shape[2]
    pad_top,pad_left = Hk//2,Wk//2
    filters = layer.weight.detach().cpu().numpy()
    if axonsDict is not None:
        Hi, Wi = inputs.shape
        for row in range(pad_top,Hi-pad_top):
            for col in range(pad_left,Wi-pad_left):
                patch = inputs[row-pad_top:row+pad_top+1,col-pad_left:col+pad_left+1]
                for filIdx, fil in enumerate(filters):
                    postSynapticID = str(output[filIdx,row-pad_top,col-pad_left])
                    for i,axons in enumerate(patch):
                        for j,axon in enumerate(axons):
                            axonsDict[axon].append((postSynapticID,int(fil[0,i,j])))
    else:
        Hi, Wi = inputs.shape[1],inputs.shape[2]
        for channel in range(inputs.shape[0]):
            for row in range(pad_top,Hi-pad_top):
                for col in range(pad_left,Wi-pad_left):
                    patch = inputs[channel,row-pad_top:row+pad_top+1,col-pad_left:col+pad_left+1]
                    for filIdx, fil in enumerate(filters):
                        postSynapticID = str(output[filIdx,row-pad_top,col-pad_left])
                        for i,neurons in enumerate(patch):
                            for j,neuron in enumerate(neurons):
                                neuronsDict[str(neuron)].append((postSynapticID,int(fil[channel,i,j])))

def maxPoolToCRI(inputs,output,layer,neuronsDict):
    Hk, Wk = layer.kernel_size, layer.kernel_size
    Hi, Wi = inputs.shape[1],inputs.shape[2]
    Ho, Wo = output.shape[1],output.shape[2]
    pad_top,pad_left = Hk//2,Wk//2
    scaler = 1e6
    for row in range(0,Hi,2):
        for col in range(0,Wi,2):
            for channel in range(inputs.shape[0]):
                patch = inputs[channel,row:row+pad_top+1,col:col+pad_left+1]
                postSynapticID = str(output[channel,row//2,col//2])
                for i,preSynNeurons in enumerate(patch):
                    for j,preSynNeuron in enumerate(preSynNeurons):
                        neuronsDict[str(preSynNeuron)].append((postSynapticID,scaler))
                        
                        
def avgPoolToCRI(inputs,output,layer,neuronsDict):
    Hk, Wk = layer.kernel_size, layer.kernel_size
    Hi, Wi = inputs.shape[1],inputs.shape[2]
    Ho, Wo = output.shape[1],output.shape[2]
    pad_top,pad_left = Hk//2,Wk//2
    scaler = 1e6
    for row in range(0,Hi,2):
        for col in range(0,Wi,2):
            for channel in range(inputs.shape[0]):
                patch = inputs[channel,row:row+pad_top+1,col:col+pad_left+1]
                postSynapticID = str(output[channel,row//2,col//2])
                for i,preSynNeurons in enumerate(patch):
                    for j,preSynNeuron in enumerate(preSynNeurons):
                        neuronsDict[str(preSynNeuron)].append((postSynapticID,scaler))
  

In [61]:
def linearToCRI(inputs,output,layer,axonsDict=None,neuronsDict=None,outputNeurons=None):
    inputs = inputs.flatten()
    weight = layer.weight.detach().cpu().numpy()
    if axonsDict is not None:
        for baseNeuronIdx, neuron in enumerate(weight.T):
            axonID = inputs[baseNeuronIdx]
            axonsDict[axonID] = [(str(basePostSynapticID), int(synapseWeight)) for basePostSynapticID, synapseWeight in enumerate(neuron) if synapseWeight != 0]
    else:
        currLayerNeuronIdxOffset,nextLayerNeuronIdxOffset = inputs[0],inputs[-1]+1
        for baseNeuronIdx, neuron in enumerate(weight.T):
            neuronID = str(baseNeuronIdx+currLayerNeuronIdxOffset)
            neuronEntry = [(str(basePostSynapticID+nextLayerNeuronIdxOffset), int(synapseWeight)) for basePostSynapticID, synapseWeight in enumerate(neuron) if synapseWeight != 0]
            neuronsDict[neuronID] = neuronEntry
    if outputNeurons is not None:
        print('instantiate output neurons')
        for baseNeuronIdx in range(layer.out_features):
            neuronID = str(baseNeuronIdx+nextLayerNeuronIdxOffset)
            neuronsDict[neuronID] = []
            outputNeurons.append(neuronID)

In [62]:
def linearBiasAxons(layer,axonsDict,axonOffset,outputs):
    biases = layer.bias.detach().cpu().numpy()
    for biasIdx, bias in enumerate(biases):
        biasID = 'a'+str(biasIdx+axonOffset)
        axonsDict[biasID] = [(str(outputs[biasIdx]),int(bias))]

In [63]:
from collections import defaultdict
axonsDict = defaultdict(list)
neuronsDict = defaultdict(list)
outputNeurons = []
H_in, W_in = 28, 28
inputSize = np.array([H_in, W_in])
axonOffset = 0
neuronOffset = 0
currInput = None

for layerIdx, layer in enumerate(net):
    if layerIdx == 0: #input layer
        if isinstance(layer,torch.nn.Conv2d):
            print('constructing Axons')
            outputSize = conv2dOutputSize(layer,inputSize)
            print("Input layer shape(infeature, outfeature): ", inputSize,',',outputSize)
            input = np.arange(0,np.prod(inputSize),dtype=int).reshape(inputSize)
            inputAxons = np.array([['a'+str(i) for i in row] for row in input])
            output = np.arange(0,np.prod(outputSize),dtype=int).reshape(outputSize)
            conv2dToCRI(inputAxons,output,layer,axonsDict)
            axonOffset += len(axonsDict)
            print('constructing bias axons for input layer:',layer.bias.shape[0],'axons')
            convBiasAxons(layer,axonsDict,axonOffset,output)
            axonOffset += layer.bias.shape[0]
            currInput = output
        if isinstance(layer,torch.nn.Linear):
            print('constructing Axons')
            outputSize = layer.out_features
            print("output layer shape(infeature, outfeature): ", inputSize,',',outputSize)
            input = np.arange(0,np.prod(inputSize),dtype=int).reshape(inputSize)
            inputAxons = np.array([['a'+str(i) for i in row] for row in input])
            output = np.arange(0,outputSize,dtype=int)
            linearToCRI(inputAxons,output,layer,axonsDict=axonsDict)
            axonOffset += len(axonsDict)
            print('constructing bias axons for input layer:',layer.bias.shape[0],'axons')
            linearBiasAxons(layer,axonsDict,axonOffset,output)
            axonOffset += layer.bias.shape[0]
            currInput = output
    elif layerIdx == len(net)-2: #output layer
        if isinstance(layer,torch.nn.Linear):
            print('constructing output layer')
            outputSize = layer.out_features
            print("output layer shape(infeature, outfeature): ", currInput.flatten().shape[0],',',outputSize)
            neuronOffset += np.prod(currInput.shape)
            output = np.arange(neuronOffset,neuronOffset+outputSize,dtype=int)
            linearToCRI(currInput,output,layer,neuronsDict = neuronsDict,outputNeurons=outputNeurons)
            print('constructing bias axons for output linearlayer:',layer.bias.shape[0],'axons')
            print('Numer of neurons:',len(neuronsDict))
            linearBiasAxons(layer,axonsDict,axonOffset,output)
            axonOffset += layer.bias.shape[0]
    else: #hidden layer
        if isinstance(layer,torch.nn.AvgPool2d):
            print('constructing hidden avgpool layer')
            outputSize = AvgPoolOutputSize(layer,currInput.shape)
            print("Hidden layer shape(infeature, outfeature): ", currInput.shape,',',outputSize)
            neuronOffset += np.prod(currInput.shape)
            output = np.arange(neuronOffset,neuronOffset+np.prod(outputSize.shape),dtype=int).reshape(outputSize)
            avgPoolToCRI(currInput,output,layer,neuronsDict)
            currInput = output
            print('Numer of neurons:',len(neuronsDict))
        if isinstance(layer,torch.nn.Conv2d):
            print('constructing hidden conv2d layer')
            outputSize = conv2dOutputSize(layer,currInput.shape)
            print("Hidden layer shape(infeature, outfeature): ", currInput.shape,',',outputSize)
            neuronOffset += np.prod(currInput.shape)
            output = np.arange(neuronOffset,neuronOffset+np.prod(outputSize.shape),dtype=int).reshape(outputSize)
            conv2dToCRI(currInput,output,layer,neuronsDict=neuronsDict)
            print('constructing bias axons for hidden conv2d layer:',layer.bias.shape[0],'axons')
            convBiasAxons(layer,axonsDict,axonOffset,output)
            axonOffset += layer.bias.shape[0]
            currInput = output
            print('Numer of neurons:',len(neuronsDict))


constructing Axons
output layer shape(infeature, outfeature):  [28 28] , 1000
constructing bias axons for input layer: 1000 axons
constructing output layer
output layer shape(infeature, outfeature):  1000 , 10
instantiate output neurons
constructing bias axons for output linearlayer: 10 axons
Numer of neurons: 1010


In [65]:
output[0]

1000

In [64]:
print("Number of axons: ",len(axonsDict))
totalAxonSyn = 0
maxFan = 0
for key in axonsDict.keys():
    totalAxonSyn += len(axonsDict[key])
    if len(axonsDict[key]) > maxFan:
        maxFan = len(axonsDict[key])
print("Total number of connections between axon and neuron: ", totalAxonSyn)
print("Max fan out of axon: ", maxFan)
print('---')
print("Number of neurons: ", len(neuronsDict))
totalSyn = 0
maxFan = 0
for key in neuronsDict.keys():
    totalSyn += len(neuronsDict[key])
    if len(neuronsDict[key]) > maxFan:
        maxFan = len(neuronsDict[key])
print("Total number of connections between hidden and output layers: ", totalSyn)
print("Max fan out of neuron: ", maxFan)

Number of axons:  1794
Total number of connections between axon and neuron:  784716
Max fan out of axon:  1000
---
Number of neurons:  1010
Total number of connections between hidden and output layers:  9997
Max fan out of neuron:  10


In [23]:
from l2s.api import CRI_network
import cri_simulations



In [None]:
config = {}
config['neuron_type'] = "I&F"
config['global_neuron_params'] = {}
config['global_neuron_params']['v_thr'] = 9*10**4
#softwareNetwork = CRI_network(axons=axonsDict,connections=neuronsDict,config=config,target='simpleSim', outputs = outputs)
hardwareNetwork = CRI_network(axons=axonsDict,connections=neuronsDict,config=config,target='CRI', outputs = outputs,simDump = False)

added axons to connectome
added neurons to connectome
added axon synpases
added neuron synapses
generated Connectome
begin: 1
end: 1
begin: 2
end: 2
begin: 3
end: 3
begin: 4
end: 4
begin: 5
end: 5
begin: 6
end: 6
begin: 7
end: 7
begin: 8
end: 8
begin: 9
end: 9
begin: 10
end: 10
begin: 11
end: 11
begin: 12
end: 12
begin: 13
end: 13
begin: 14
end: 14
begin: 15
end: 15
begin: 16
end: 16
begin: 17
end: 17
begin: 18
end: 18
begin: 19
end: 19
begin: 20
end: 20
begin: 21
end: 21
begin: 22
end: 22
begin: 23
end: 23
begin: 24
end: 24
begin: 25
end: 25
begin: 26
end: 26
begin: 27
end: 27
begin: 28
end: 28
begin: 29
end: 29
begin: 30
end: 30
begin: 31
end: 31
begin: 32
end: 32
begin: 33
end: 33
begin: 34
end: 34
begin: 35
end: 35
begin: 36
end: 36
begin: 37
end: 37
begin: 38
end: 38
begin: 39
end: 39
begin: 40
end: 40
begin: 41
end: 41
begin: 42
end: 42
begin: 43
end: 43
begin: 44
end: 44
begin: 45
end: 45
begin: 46
end: 46
begin: 47
end: 47
begin: 48
end: 48
begin: 49
end: 49
begin: 50
end: 50
b

In [None]:
def input_to_CRI(currentInput):
    num_steps = 10
    currentInput = currentInput.view(currentInput.size(0), -1)
    batch = []
    for element in currentInput:
        timesteps = []
        rateEnc = spikegen.rate(element,num_steps)
        rateEnc = rateEnc.detach().cpu().numpy()
        for element in rateEnc:
            currInput = ['a'+str(idx) for idx,axon in enumerate(element) if axon != 0]
            biasInput = ['a'+str(idx) for idx in range(784,len(axonsDict))]
#             timesteps.append(currInput)
#             timesteps.append(biasInput)
            timesteps.append(currInput+biasInput)
        batch.append(timesteps)
    return batch

In [None]:
def run_CRI(inputList,output_offset):
    predictions = []
    total_time_cri = 0
    #each image
    for currInput in inputList:
        #reset the membrane potential to zero
        softwareNetwork.simpleSim.initialize_sim_vars(len(neuronsDict))
        spikeRate = [0]*10
        #each time step
        for slice in currInput:
            start_time = time.time()
            swSpike = softwareNetwork.step(slice, membranePotential=False)
            end_time = time.time()
            total_time_cri = total_time_cri + end_time-start_time
            for spike in swSpike:
                spikeIdx = int(spike) - output_offset 
                try: 
                    if spikeIdx >= 0: 
                        spikeRate[spikeIdx] += 1 
                except:
                    print("SpikeIdx: ", spikeIdx,"\n SpikeRate:",spikeRate )
        predictions.append(spikeRate.index(max(spikeRate)))
    print(f"Total simulation execution time: {total_time_cri:.5f} s")
    return(predictions)

In [None]:
def run_CRI_hw(inputList,output_offset):
    predictions = []
    #each image
    total_time_cri = 0
    for currInput in inputList:
        #initiate the softwareNetwork for each image
        cri_simulations.FPGA_Execution.fpga_controller.clear(len(neuronsDict), False, 0)  ##Num_neurons, simDump, coreOverride
        spikeRate = [0]*10
        #each time step
        for slice in currInput:
            start_time = time.time()
            hwSpike = hardwareNetwork.step(slice, membranePotential=False)
#             print("Mem:",mem)
            end_time = time.time()
            total_time_cri = total_time_cri + end_time-start_time
            print(hwSpike)
            for spike in hwSpike:
                print(int(spike))
                spikeIdx = int(spike) - output_offset 
                if spikeIdx >= 0: 
                    spikeRate[spikeIdx] += 1 
        predictions.append(spikeRate.index(max(spikeRate))) 
    print(f"Total execution time CRIFPGA: {total_time_cri:.5f} s")
    return(predictions)

In [None]:
total = 0
correct = 0
cri_correct = 0
cri_correct_hw = 0
batch_size = 128
# drop_last switched to False to keep all samples
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)
output_offset = int(outputs[0])
with torch.no_grad():
    net6.eval()
    
    train_loader = iter(train_loader)
    count = 0
    for data, targets in test_loader:
        data = data.to(device)
        targets = targets.to(device)
        input = input_to_CRI(data)
#         criPred = torch.tensor(run_CRI(input,output_offset)).to(device)
        criPred_hw = torch.tensor(run_CRI_hw(input,output_offset)).to(device)

        # calculate total accuracy
        spk_rec, _ = forward_pass(net6, num_steps, data, batch_size)

        correct += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)
        total += spk_rec.size(1)
#         cri_correct += (criPred == targets).sum().item()
        cri_correct_hw += (criPred_hw == targets).sum().item()
        count += 1
#         if count == 12:
        break

ERROR:root:non-spike packet encountered during spike flush: [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0 205 171 205 171]
ERROR:root:non-spike packet encountered during spike flush: [227 242 130  22   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0 186 202 186 202]


[0 0 0 ... 0 0 0]


ERROR:root:non-spike packet encountered during spike flush: [  3  55   8  23   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0 186 186 186 186]
ERROR:root:non-spike packet encountered during spike flush: [181  95 131  22   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0 186 202 186 202]


[]
[0 0 0 ... 0 0 0]


ERROR:root:non-spike packet encountered during spike flush: [221 161   8  23   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0 186 186 186 186]


[]
[0 0 0 ... 0 0 0]


ERROR:root:non-spike packet encountered during spike flush: [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0 205 171 205 171]
ERROR:root:non-spike packet encountered during spike flush: [172  18   9  23   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0 186 186 186 186]
ERROR:root:non-spike packet encountered during spike flush: [197  54 132  22   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0 186 202 186 202]


[]
[0 0 0 ... 0 0 0]


ERROR:root:non-spike packet encountered during spike flush: [233 130   9  23   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0 186 186 186 186]


[]
[0 0 0 ... 0 0 0]


ERROR:root:non-spike packet encountered during spike flush: [145 242   9  23   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0 186 186 186 186]


[]
[0 0 0 ... 0 0 0]


ERROR:root:non-spike packet encountered during spike flush: [ 99 106  10  23   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0 186 186 186 186]


[]
[0 0 0 ... 0 0 0]


ERROR:root:non-spike packet encountered during spike flush: [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0 205 171 205 171]
ERROR:root:non-spike packet encountered during spike flush: [113 231  10  23   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0 186 186 186 186]
ERROR:root:non-spike packet encountered during spike flush: [145   2 134  22   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0 186 202 186 202]


[]
[0 0 0 ... 0 0 0]


ERROR:root:non-spike packet encountered during spike flush: [191 102  11  23   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0 186 186 186 186]


[]
[0 0 0 ... 0 0 0]


ERROR:root:non-spike packet encountered during spike flush: [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0 205 171 205 171]


[]
[0 0 0 ... 0 0 0]
[]
[0 0 0 ... 0 0 0]
[]
[0 0 0 ... 0 0 0]
[]
[0 0 0 ... 0 0 0]
[]
[0 0 0 ... 0 0 0]
[]
[0 0 0 ... 0 0 0]
[]
[0 0 0 ... 0 0 0]
[]
[0 0 0 ... 0 0 0]
[]
[0 0 0 ... 0 0 0]
[]
[0 0 0 ... 0 0 0]
[]
[0 0 0 ... 0 0 0]


In [None]:
# print(f"Totoal execution time: {end_time-start_time:.2f} s")
print(f"Total correctly classified test set images for TorchSNN: {correct}/{total}")
print(f"Total correctly classified test set images for CRI: {cri_correct}/{total}")
print(f"Test Set Accuracy for TorchSNN: {100 * correct / total:.2f}%")
print(f"Test Set Accuracy for CRI: {100 * cri_correct / total:.2f}%")