[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/snntorch_alpha_w.png?raw=true' width="400">](https://github.com/jeshraghian/snntorch/)

# snnTorch - Surrogate Gradient Descent in a Convolutional Spiking Neural Network
## Tutorial 6
### By Jason K. Eshraghian (www.ncg.ucsc.edu)

<a href="https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_6_CNN.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub-Mark-Light-120px-plus.png?raw=true' width="28">](https://github.com/jeshraghian/snntorch/) [<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub_Logo_White.png?raw=true' width="80">](https://github.com/jeshraghian/snntorch/)

The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:

> <cite> [Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. "Training Spiking Neural Networks Using Lessons From Deep Learning". arXiv preprint arXiv:2109.12894, September 2021.](https://arxiv.org/abs/2109.12894) </cite>


# Introduction
In this tutorial, you will:
* Learn how to use surrogate gradient descent to overcome the dead neuron problem
* Construct and train a convolutional spiking neural network
* Use a sequential container, `nn.Sequential`, to simplify model construction
* Use the `snn.backprop` module to reduce the time it takes to design a neural network

Part of this tutorial was inspired by Friedemann Zenke’s extensive
work on SNNs. Check out his repo on surrogate gradients
[here](https://github.com/fzenke/spytorch), and a favourite paper
of mine: E. O. Neftci, H. Mostafa, F. Zenke, [Surrogate Gradient
Learning in Spiking Neural Networks: Bringing the Power of
Gradient-based optimization to spiking neural
networks.](https://ieeexplore.ieee.org/document/8891809) IEEE
Signal Processing Magazine 36, 51–63.

At the end of the tutorial, we will train a convolutional spiking neural network (CSNN) using the MNIST dataset to perform image classification. The background theory follows on from [Tutorials 2, 4 and 5](https://snntorch.readthedocs.io/en/latest/tutorials/index.html), so feel free to go back if you need to brush up.

If running in Google Colab:
* You may connect to GPU by checking `Runtime` > `Change runtime type` > `Hardware accelerator: GPU`
* Next, install the latest PyPI distribution of snnTorch by clicking into the following cell and pressing `Shift+Enter`.

In [1]:
!pip install snntorch

Collecting snntorch
  Downloading snntorch-0.6.2-py2.py3-none-any.whl (104 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.7/104.7 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m00:01[0m
Installing collected packages: snntorch
Successfully installed snntorch-0.6.2


In [2]:
!pip install torch-summary

Collecting torch-summary
  Downloading torch_summary-1.4.5-py3-none-any.whl (16 kB)
Installing collected packages: torch-summary
Successfully installed torch-summary-1.4.5


In [None]:
# imports
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
from snntorch import spikeplot as splt
from snntorch import spikegen
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.nn.parameter import Parameter
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
import itertools

import os
import shutil

### Define the Surrogate Gradient

In [None]:
spike_grad = surrogate.fast_sigmoid(slope=25)
beta = 0.5

lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)

To explore the other surrogate gradient functions available, [take a look at the documentation here.](https://snntorch.readthedocs.io/en/latest/snntorch.surrogate.html)
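As a rough illustration of what the surrogate does, the sketch below writes out a spike function whose forward pass is a Heaviside step and whose backward pass uses the fast-sigmoid derivative $1/(k|U|+1)^2$, where $U$ is the membrane potential minus the threshold and $k$ is `slope`. That derivative form is assumed here from the snnTorch surrogate documentation; `heaviside` and `fast_sigmoid_grad` are hypothetical plain-Python helpers, not snnTorch functions.

```python
def heaviside(u):
    """Forward pass: emit a spike (1) only when the shifted membrane potential u is positive."""
    return 1.0 if u > 0 else 0.0


def fast_sigmoid_grad(u, slope=25.0):
    """Backward pass (assumed fast-sigmoid form): smooth stand-in for dS/dU, peaking at 1.0 when u == 0."""
    return 1.0 / (slope * abs(u) + 1.0) ** 2


# u = membrane potential minus threshold (hypothetical values)
for u in (-0.2, -0.04, 0.0, 0.04, 0.2):
    print(f"u={u:+.2f}  spike={heaviside(u):.0f}  surrogate grad={fast_sigmoid_grad(u):.4f}")
```

The exact derivative of the step is zero almost everywhere, which is what leaves neurons "dead" during training; the surrogate keeps a non-zero gradient near threshold, and a larger `slope` makes that region narrower.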

## 2.1 DataLoaders
Note that `utils.data_subset()` is called to reduce the size of the dataset by a factor of 10 to speed up training.

In [None]:
# dataloader arguments
batch_size = 128
data_path='~/justinData/mnist'
subset=10

dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Define a transform
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# reduce datasets by 10x to speed up training
utils.data_subset(mnist_train, subset)
utils.data_subset(mnist_test, subset)

# Create DataLoaders
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)
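To get a feel for how much data this leaves, here is the arithmetic for the standard MNIST split of 60,000 training and 10,000 test images. It assumes that `utils.data_subset` keeps `len(dataset) // subset` samples (an assumption about its behaviour, not taken from its source), and `loader_stats` is a hypothetical helper. Because both loaders use `drop_last=True`, any incomplete final batch is discarded.

```python
MNIST_TRAIN, MNIST_TEST = 60_000, 10_000   # standard MNIST split
subset, batch_size = 10, 128               # values from the cell above


def loader_stats(num_samples, subset, batch_size):
    """Samples kept, full batches, and samples actually seen per epoch with drop_last=True."""
    kept = num_samples // subset           # assumed behaviour of utils.data_subset
    num_batches = kept // batch_size       # drop_last=True discards the partial batch
    return kept, num_batches, num_batches * batch_size


print(loader_stats(MNIST_TRAIN, subset, batch_size))  # (6000, 46, 5888)
print(loader_stats(MNIST_TEST, subset, batch_size))   # (1000, 7, 896)
```

Under that assumption, 104 of the 1,000 test images are never evaluated by `test_loader`, which is worth keeping in mind when comparing accuracies computed with `drop_last=True` and `drop_last=False`.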

## 2.2 Define the Network

The convolutional network architecture to be used is: 12C5-MP2-64C5-MP2-1024FC10

- 12C5 is a 5$\times$5 convolutional kernel with 12 filters
- MP2 is a 2$\times$2 max-pooling function
- 1024FC10 is a fully-connected layer that maps 1,024 neurons to 10 outputs
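The 1,024 inputs to the final layer follow from the spatial sizes. With no padding and stride 1, a $K\times K$ convolution shrinks each side by $K-1$, and a 2$\times$2 pool with stride 2 halves it. A quick plain-Python check using PyTorch's output-size formula for undilated layers (`conv_out` is a hypothetical helper):

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of an undilated Conv2d or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                                  # MNIST images are 28x28
side = conv_out(side, kernel=5)            # 12C5 -> 24
side = conv_out(side, kernel=2, stride=2)  # MP2  -> 12
side = conv_out(side, kernel=5)            # 64C5 -> 8
side = conv_out(side, kernel=2, stride=2)  # MP2  -> 4
flat_features = 64 * side * side
print(side, flat_features)                 # 4 1024
```

This is why both network definitions below use `nn.Linear(64*4*4, 10)`.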

In [None]:
# neuron and simulation parameters
spike_grad = surrogate.fast_sigmoid(slope=25)
beta = 1.0
num_steps = 15

In [None]:
# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.conv1 = nn.Conv2d(1, 12, 5)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.conv2 = nn.Conv2d(12, 64, 5)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.fc1 = nn.Linear(64*4*4, 10)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)

    def forward(self, x):

        # Initialize hidden states and outputs at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        cur1 = F.max_pool2d(self.conv1(x), 2)
        spk1, mem1 = self.lif1(cur1, mem1)
        
        cur2 = F.max_pool2d(self.conv2(spk1), 2)
        spk2, mem2 = self.lif2(cur2, mem2)
    
        cur3 = self.fc1(spk2.view(spk2.size(0), -1))  # flatten using the actual batch size
        spk3, mem3 = self.lif3(cur3, mem3)
        return spk3, mem3

In the previous tutorial, the network was wrapped inside a class, as shown above. 
As networks grow in complexity, this adds a lot of boilerplate code that we might wish to avoid. Instead, the `nn.Sequential` container can be used.

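> Note: the following code block simulates a single time step, and requires a separate for-loop over time. It also uses average pooling (`nn.AvgPool2d`) in place of the max-pooling used in the class-based version above.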

In [None]:
#  Initialize Network
net = nn.Sequential(nn.Conv2d(1, 12, 5),
                    nn.AvgPool2d(2, stride=2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Conv2d(12, 64, 5),
                    nn.AvgPool2d(2, stride=2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Flatten(),
                    nn.Linear(64*4*4, 10),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
                    ).to(device)


The `init_hidden` argument initializes the hidden states of the neuron (here, membrane potential). This takes place in the background as an instance variable. 
If `init_hidden` is activated, the membrane potential is not explicitly returned to the user, ensuring only the output spikes are sequentially passed through the layers wrapped in `nn.Sequential`. 

To train a model using the final layer's membrane potential, set the argument `output=True`. 
This enables the final layer to return both the spike and membrane potential response of the neuron.
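The difference between explicit and instance-held state can be seen with a toy scalar neuron. `ToyLeaky` below is a hypothetical, simplified class (not snnTorch's `snn.Leaky`; for instance, it resets by subtraction within the same step): with `init_hidden=False` the caller threads `mem` through every call, while with `init_hidden=True` the neuron keeps `mem` itself and returns only spikes, or `(spk, mem)` when `output=True`.

```python
class ToyLeaky:
    """Simplified scalar leaky integrate-and-fire neuron (illustration only, not snn.Leaky)."""

    def __init__(self, beta, threshold=1.0, init_hidden=False, output=False):
        self.beta = beta
        self.threshold = threshold
        self.init_hidden = init_hidden
        self.output = output
        self.mem = 0.0  # hidden state, only used when init_hidden=True

    def reset(self):
        self.mem = 0.0

    def _step(self, x, mem):
        mem = self.beta * mem + x           # leaky integration
        spk = 1 if mem > self.threshold else 0
        mem = mem - spk * self.threshold    # reset by subtraction
        return spk, mem

    def __call__(self, x, mem=None):
        if not self.init_hidden:
            # explicit state: the caller passes mem in and gets it back
            return self._step(x, mem)
        # implicit state: mem is kept on the instance between calls
        spk, self.mem = self._step(x, self.mem)
        return (spk, self.mem) if self.output else spk


# explicit state, as in the class-based Net
layer = ToyLeaky(beta=1.0)
spk, mem = layer(0.5, 0.0)

# implicit state, as in nn.Sequential: only spikes flow between layers
hidden = ToyLeaky(beta=1.0, init_hidden=True)
print([hidden(0.5) for _ in range(5)])      # [0, 0, 1, 0, 1]

# a final layer with output=True also exposes its membrane potential
out = ToyLeaky(beta=1.0, init_hidden=True, output=True)
print(out(1.5))                             # (1, 0.5)
```

Because the state lives on the instance, it has to be cleared between samples, which is the role `utils.reset(net)` plays in the forward-pass function.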

## 2.3 Forward-Pass


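Because the network above simulates only a single time step, wrap it in a function that resets the hidden states, loops over `num_steps`, and records the membrane potential and spike response at each step: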

In [None]:
def forward_pass(net, num_steps, data):
  mem_rec = []
  spk_rec = []
  utils.reset(net)  # resets hidden states for all LIF neurons in net

  for step in range(num_steps):
      spk_out, mem_out = net(data)
      spk_rec.append(spk_out)
      mem_rec.append(mem_out)
  
  return torch.stack(spk_rec), torch.stack(mem_rec)

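In [None]:
# fetch one batch of training data
data, targets = next(iter(train_loader))
data = data.to(device)
targets = targets.to(device)

spk_rec, mem_rec = forward_pass(net, num_steps, data)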

In [None]:
net

# 3. Training Loop

## 3.1 Loss Using snn.Functional

In the previous tutorial, the Cross Entropy Loss between the membrane potential of the output neurons and the target was used to train the network. 
This time, the total number of spikes from each neuron will be used to calculate the Cross Entropy instead.

A variety of loss functions are included in the `snn.functional` module, which is analogous to `torch.nn.functional` in PyTorch. 
These implement a mix of cross entropy and mean square error losses, which can be applied to spikes and/or membrane potential to train a rate- or latency-coded network. 

The approach below applies the cross entropy loss to the output spike count in order to train a rate-coded network:

In [None]:
# already imported snntorch.functional as SF 
loss_fn = SF.ce_rate_loss()

The recorded spikes are passed as the first argument to `loss_fn`, and the target neuron indices as the second, to generate a loss. [The documentation provides further information and examples.](https://snntorch.readthedocs.io/en/latest/snntorch.functional.html#snntorch.functional.ce_rate_loss)

In [None]:
loss_val = loss_fn(spk_rec, targets)

print(f"The loss from an untrained network is {loss_val.item():.3f}")
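To make the idea concrete, here is cross entropy computed on summed spike counts in plain Python: the counts are treated as logits, passed through a softmax, and the negative log-probability of the target class is averaged over the batch. This follows the description above, but snnTorch's `ce_rate_loss` may apply the loss at each time step rather than to the total count, so treat `spike_count_ce` (a hypothetical helper) as an illustration of the principle, not a re-implementation.

```python
import math


def spike_count_ce(spk_rec, targets):
    """spk_rec[t][b][c] holds 0/1 spikes; targets[b] is the correct class index."""
    num_batch = len(targets)
    losses = []
    for b in range(num_batch):
        # total spikes per output neuron across all time steps
        counts = [sum(step[b][c] for step in spk_rec) for c in range(len(spk_rec[0][b]))]
        # numerically stable log-softmax of the counts
        m = max(counts)
        log_norm = m + math.log(sum(math.exp(x - m) for x in counts))
        losses.append(log_norm - counts[targets[b]])
    return sum(losses) / num_batch


silent = [[[0, 0]] for _ in range(3)]     # 3 steps, batch of 1, 2 classes, no spikes
print(spike_count_ce(silent, [0]))       # log(2) ~ 0.693: no evidence either way
correct = [[[1, 0]] for _ in range(3)]   # the target neuron fires at every step
print(spike_count_ce(correct, [0]))      # ~0.049
```

The loss falls as the target neuron fires more often than the others, which is what pushes a rate-coded network toward the correct output neuron.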

## 3.2 Accuracy Using snn.Functional
The `SF.accuracy_rate()` function works similarly, in that the predicted output spikes and actual targets are supplied as arguments. `accuracy_rate` assumes a rate code is used to interpret the output by checking if the index of the neuron with the highest spike count matches the target index.

In [None]:
acc = SF.accuracy_rate(spk_rec, targets)

print(f"The accuracy of a single batch using an untrained network is {acc*100:.3f}%")
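In plain Python, a rate-coded accuracy check amounts to summing each output neuron's spikes over time, taking the index of the largest count as the prediction, and comparing it with the target. `rate_accuracy` is a hypothetical helper; ties go to the lowest index here, which may differ from how `SF.accuracy_rate` breaks ties.

```python
def rate_accuracy(spk_rec, targets):
    """spk_rec[t][b][c] holds 0/1 spikes; returns the fraction of correct predictions."""
    num_classes = len(spk_rec[0][0])
    correct = 0
    for b, target in enumerate(targets):
        counts = [sum(step[b][c] for step in spk_rec) for c in range(num_classes)]
        predicted = max(range(num_classes), key=lambda c: counts[c])  # first index on ties
        correct += int(predicted == target)
    return correct / len(targets)


# 2 time steps, batch of 2, 3 classes (hypothetical spikes)
spk_rec = [
    [[1, 0, 0], [0, 1, 1]],
    [[1, 0, 1], [0, 0, 1]],
]
print(rate_accuracy(spk_rec, [0, 1]))  # 0.5: sample 0 -> class 0 (correct), sample 1 -> class 2 (wrong)
```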





As the above function only returns the accuracy of a single batch of data, the following function returns the accuracy on the entire DataLoader object:

In [None]:
def batch_accuracy(loader, net, num_steps):
  with torch.no_grad():
    total = 0
    acc = 0
    net.eval()

    for data, targets in loader:
      data = data.to(device)
      targets = targets.to(device)
      spk_rec, _ = forward_pass(net, num_steps, data)

      # weight each batch's accuracy by its batch size (spk_rec is [num_steps, batch, classes])
      acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)
      total += spk_rec.size(1)

  return acc/total

In [None]:
test_acc = batch_accuracy(test_loader, net, num_steps)

print(f"The total accuracy on the test set is: {test_acc * 100:.2f}%")
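`batch_accuracy` weights each batch's accuracy by its size before dividing by the total number of samples, rather than averaging per-batch accuracies directly. The distinction only matters when batches differ in size (for instance, a final partial batch when `drop_last=False`). A small worked example with hypothetical numbers:

```python
# (accuracy, batch size) for two hypothetical batches: a full one and a partial one
batches = [(0.5, 128), (1.0, 32)]

weighted = sum(acc * size for acc, size in batches) / sum(size for _, size in batches)
unweighted = sum(acc for acc, _ in batches) / len(batches)

print(weighted)    # 0.6  -> (64 + 32) / 160, the true fraction of correct samples
print(unweighted)  # 0.75 -> over-counts the small batch
```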

## 3.3 Training Automation Using snn.backprop

Training SNNs can become arduous even with simple networks, so the `snn.backprop` module is here to reduce some of this effort.

The `backprop.BPTT` function automatically performs a single epoch of training, where you need only provide the training parameters, dataloader, and several other arguments. 
The average loss across iterations is returned. 
The argument `time_var` indicates whether the input data is time-varying. 
As we are using the MNIST dataset, we explicitly specify `time_var=False`. 

The following code block may take a while to run. If you are not connected to a GPU, consider reducing `num_epochs`.

In [None]:
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999))
num_epochs = 20
test_acc_hist = []

# training loop
for epoch in range(num_epochs):

    avg_loss = backprop.BPTT(net, train_loader, optimizer=optimizer, criterion=loss_fn, 
                            num_steps=num_steps, time_var=False, device=device)
    
    print(f"Epoch {epoch}, Train Loss: {avg_loss.item():.2f}")

    # Test set accuracy
    test_acc = batch_accuracy(test_loader, net, num_steps)
    test_acc_hist.append(test_acc)

    print(f"Epoch {epoch}, Test Acc: {test_acc * 100:.2f}%\n")

# 4. Results
## 4.1 Plot Test Accuracy

In [None]:
# Plot Accuracy
fig = plt.figure(facecolor="w")
plt.plot(test_acc_hist)
plt.title("Test Set Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.show()

## 4.2 Spike Counter

Despite the fairly generic hyperparameters and architecture, the test set accuracy should be competitive given the brief training run!

In [None]:
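from IPython.display import HTML

# run a forward pass with the trained network on the batch fetched earlier
spk_rec, mem_rec = forward_pass(net, num_steps, data)

idx = 0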

fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))
labels=['0', '1', '2', '3', '4', '5', '6', '7', '8','9']
print(f"The target label is: {targets[idx]}")

# plt.rcParams['animation.ffmpeg_path'] = 'C:\\path\\to\\your\\ffmpeg.exe'

#  Plot spike count histogram
anim = splt.spike_count(spk_rec[:, idx].detach().cpu(), fig, ax, labels=labels, 
                        animate=True, interpolate=4)


HTML(anim.to_html5_video())
# anim.save("spike_bar.mp4")

### Quantization 

In [None]:
def weight_quantization(b):

    def uniform_quant(x, b):
        # map values in [0, 1] onto 2**b - 1 evenly spaced levels
        xdiv = x.mul((2 ** b - 1))
        xhard = xdiv.round().div(2 ** b - 1)
        return xhard

    class _pq(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, alpha):
            input = input.div(alpha)                   # weights are first divided by alpha (out of place, so the layer's weights are not modified)
            input_c = input.clamp(min=-1, max=1)       # then clipped to [-1,1]
            sign = input_c.sign()
            input_abs = input_c.abs()
            input_q = uniform_quant(input_abs, b).mul(sign)
            ctx.save_for_backward(input, input_q)
            input_q = input_q.mul(alpha)               # rescale to the original range
            return input_q

        @staticmethod
        def backward(ctx, grad_output):
            grad_input = grad_output.clone()
            input, input_q = ctx.saved_tensors
            i = (input.abs()>1.).float()     # 1.0 where the weight was clipped, 0.0 elsewhere
            sign = input.sign()              # +1 / -1 per weight
            # alternative (unused): grad_alpha = (grad_output*(sign*i + (input_q-input)*(1-i))).sum()
            # clipped weights contribute sign * grad_output to grad_alpha; unclipped weights contribute 0
            grad_alpha = (grad_output*(sign*i + (0.0)*(1-i))).sum()
            # straight-through estimator: gradients pass unchanged for unclipped weights and are zeroed for clipped ones
            grad_input = grad_input*(1-i)
            return grad_input, grad_alpha

    # autograd Functions are applied through the class itself, not through an instance
    return _pq.apply

class weight_quantize_fn(nn.Module):
    def __init__(self, w_bit, wgt_alpha):
        super().__init__()
        self.w_bit = w_bit-1         # one bit is reserved for the sign
        self.wgt_alpha = wgt_alpha   # weights are clipped to [-wgt_alpha, wgt_alpha]
        self.weight_q = weight_quantization(b=self.w_bit)
        #self.register_parameter('wgt_alpha', Parameter(torch.tensor(3.0)))
    def forward(self, weight):
        #mean = weight.data.mean()
        #std = weight.data.std()
        #weight = weight.add(-mean).div(std)      # weights normalization
        weight_q = self.weight_q(weight, self.wgt_alpha)

        return weight_q

In [None]:
w_alpha=1
w_bits=16
weight_quant = weight_quantize_fn(w_bit=w_bits, wgt_alpha=w_alpha)  ## define quant function (wgt_alpha is a required argument)
conv1_quant      = weight_quant(net[0].weight)
w_delta          = w_alpha/(2**(w_bits-1)-1)
conv1_int        = conv1_quant/w_delta
print("Conv1 Weights: \n",conv1_int)
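The integer weights printed above can be traced by hand. With `w_bits=16`, one bit is reserved for the sign, so `uniform_quant` works with $2^{15}-1 = 32767$ levels per sign, and dividing the quantized weight by `w_delta` $=$ `w_alpha`$/32767$ recovers the integer level. `quantize_to_int` is a hypothetical plain-Python sketch of that forward arithmetic for a single weight; Python's `round` uses round-half-to-even, the same convention as `torch.round`.

```python
def quantize_to_int(w, w_bits=16, w_alpha=1.0):
    """Integer level that weight_quant(w) / w_delta corresponds to (forward pass only)."""
    levels = 2 ** (w_bits - 1) - 1          # 32767 for 16 bits
    x = max(-1.0, min(1.0, w / w_alpha))    # divide by alpha, clip to [-1, 1]
    sign = (x > 0) - (x < 0)
    return sign * round(abs(x) * levels)    # uniform_quant, expressed in units of w_delta


for w in (0.25, -0.25, 0.0, 2.0):
    print(w, quantize_to_int(w))
```

Weights whose magnitude exceeds `w_alpha` saturate at $\pm 32767$, so `w_alpha` effectively sets the representable range.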

In [None]:
for layer in net:
    print(layer)

In [None]:
net[1]

In [None]:
# step between adjacent quantization levels (one bit is reserved for the sign)
w_delta = w_alpha/(2**(w_bits-1)-1)

for layer in net:
    if isinstance(layer, (torch.nn.Linear, torch.nn.Conv2d)):
        # quantize the weights, then express weights and biases in units of w_delta
        layer.weight = Parameter(weight_quant(layer.weight)/w_delta)
        layer.bias = Parameter(layer.bias/w_delta)
    if isinstance(layer, snn.Leaky):
        # scale the firing threshold by the same factor as the weights and biases
        layer.threshold = layer.threshold/w_delta

In [None]:
test_acc = batch_accuracy(test_loader, net, num_steps)

print(f"The total accuracy on the test set is: {test_acc * 100:.2f}%")
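Dividing the weights, biases and firing threshold by the same `w_delta` should leave the spike trains unchanged: with `beta = 1.0` the membrane potential is a linear sum of weighted inputs, and the threshold comparison is unaffected by a common positive scale. The sketch below checks this with exact rational arithmetic on a toy integrate-and-fire neuron with reset by subtraction (`if_spikes` is a hypothetical, simplified model assumed for illustration, not snnTorch's `Leaky`).

```python
from fractions import Fraction


def if_spikes(inputs, weight, bias, threshold):
    """Integrate-and-fire (beta = 1) with reset by subtraction; returns the 0/1 spike train."""
    mem, spikes = Fraction(0), []
    for x in inputs:
        mem += weight * x + bias
        spk = int(mem > threshold)
        mem -= spk * threshold
        spikes.append(spk)
    return spikes


inputs = [1, 0, 1, 1, 0, 1, 1, 1]                       # hypothetical input spikes
weight, bias, threshold = Fraction(3, 10), Fraction(1, 20), Fraction(1)
w_delta = Fraction(1, 2 ** 15 - 1)                       # w_alpha / 32767 with w_alpha = 1

original = if_spikes(inputs, weight, bias, threshold)
scaled = if_spikes(inputs, weight / w_delta, bias / w_delta, threshold / w_delta)
print(original == scaled, original)                      # True [0, 0, 0, 1, 0, 0, 0, 1]
```

Any accuracy change seen after quantization should therefore come mainly from rounding the weights to 16-bit levels (and clipping them to $\pm$`w_alpha`), rather than from the rescaling itself.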

### Save Model

In [None]:
def save_checkpoint(state, is_best, fdir):
    filepath = os.path.join(fdir, 'checkpoint.pth')
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(fdir, 'model_cnn_best.pth.tar'))
    else:
        shutil.copyfile(filepath, os.path.join(fdir, 'model_cnn_quan.pth.tar'))

In [None]:
# create the output directory (no error if it already exists)
fdir = 'result/'
os.makedirs(fdir, exist_ok=True)

In [None]:
# is_best=0: save the quantized model as model_cnn_quan.pth.tar
save_checkpoint({'state_dict': net.state_dict(),}, 0, fdir)

In [None]:
# is_best=1: save the same weights as model_cnn_best.pth.tar
save_checkpoint({'state_dict': net.state_dict(),}, 1, fdir)

### Load Saved Model

In [None]:
best_model_path = os.path.join(fdir, 'model_cnn_quan.pth.tar')  # written by save_checkpoint(..., 0, fdir) above
checkpoint = torch.load(best_model_path, map_location=device)
net.load_state_dict(checkpoint['state_dict'])
net.eval()

In [None]:
test_acc = batch_accuracy(test_loader, net, num_steps)
print(f"The total accuracy on the test set is: {test_acc * 100:.2f}%")

# Conclusion
You should now have a grasp of the basic features of snnTorch and be able to start running your own experiments. [In the next tutorial](https://snntorch.readthedocs.io/en/latest/tutorials/index.html), we will train a network using a neuromorphic dataset.

A special thanks to [Gianfresco Angelini](https://github.com/gianfa) for providing valuable feedback on the tutorial.

If you like this project, please consider starring ⭐ the repo on GitHub as it is the easiest and best way to support it.

# Additional Resources
* [Check out the snnTorch GitHub project here.](https://github.com/jeshraghian/snntorch)

# Mapping into CRI 

In [1]:
# weights and biases of the two Conv2d layers (net[0], net[3]) and the Linear layer (net[7])
layer_ids = [0, 3, 7]
layers = [net.state_dict()[f'{i}.weight'].detach().cpu().numpy() for i in layer_ids]
biases = [net.state_dict()[f'{i}.bias'].detach().cpu().numpy() for i in layer_ids]


In [None]:
print(np.min(layers[1]))
print(np.max(layers[1]))

In [None]:
def conv2dOutputSize(layer,inputSize):
    # inputSize may be (H, W) or (C, H, W); the spatial dimensions are always the last two
    H_in, W_in = inputSize[-2], inputSize[-1]
    H_out = (H_in + 2*layer.padding[0] - layer.dilation[0]*(layer.kernel_size[0]-1) - 1)//layer.stride[0] + 1
    W_out = (W_in + 2*layer.padding[1] - layer.dilation[1]*(layer.kernel_size[1]-1) - 1)//layer.stride[1] + 1
    return [layer.out_channels,int(H_out),int(W_out)]


In [None]:
def maxPoolOutputSize(layer,inputSize):
    # assumes int kernel_size, stride, padding and dilation (e.g. nn.MaxPool2d(2)); inputSize is (C, H, W)
    H_out = (inputSize[1] + 2*layer.padding - layer.dilation*(layer.kernel_size-1) - 1)//layer.stride + 1
    W_out = (inputSize[2] + 2*layer.padding - layer.dilation*(layer.kernel_size-1) - 1)//layer.stride + 1
    return [inputSize[0],int(H_out),int(W_out)]


In [None]:
def AvgPoolOutputSize(layer,inputSize):
    # assumes int kernel_size, stride and padding (e.g. nn.AvgPool2d(2, stride=2)); inputSize is (C, H, W)
    H_out = (inputSize[1] + 2*layer.padding - layer.kernel_size)//layer.stride + 1
    W_out = (inputSize[2] + 2*layer.padding - layer.kernel_size)//layer.stride + 1
    return [inputSize[0],int(H_out),int(W_out)]

In [None]:
def conv2dToCRI(inputs,output,layer,layerIdx,axonsDict=None,neuronsDict=None):
    # Unroll a stride-1, unpadded convolution with an odd kernel into explicit synapses:
    # every ID inside a patch gets one synapse to the output neuron at the patch centre.
    Hk, Wk = layer.kernel_size
    pad_top,pad_left = Hk//2,Wk//2
    filters = layer.weight.detach().cpu().numpy()
    if layerIdx==0:
        # input layer: inputs is a 2D array of axon IDs ('a0', 'a1', ...) with a single channel
        Hi, Wi = inputs.shape
        for row in range(pad_top,Hi-pad_top):
            for col in range(pad_left,Wi-pad_left):
                patch = inputs[row-pad_top:row+pad_top+1,col-pad_left:col+pad_left+1]
                for filIdx, fil in enumerate(filters):
                    postSynapticID = str(output[filIdx,row-pad_top,col-pad_left])
                    for i,axons in enumerate(patch):
                        for j,axon in enumerate(axons):
                            # round (rather than truncate) the integer-valued quantized weight
                            axonsDict[axon].append((postSynapticID,int(round(fil[0,i,j]))))
    else:
        # hidden layer: inputs is a (C, H, W) array of neuron IDs
        Hi, Wi = inputs.shape[1],inputs.shape[2]
        for channel in range(inputs.shape[0]):
            for row in range(pad_top,Hi-pad_top):
                for col in range(pad_left,Wi-pad_left):
                    patch = inputs[channel,row-pad_top:row+pad_top+1,col-pad_left:col+pad_left+1]
                    for filIdx, fil in enumerate(filters):
                        postSynapticID = str(output[filIdx,row-pad_top,col-pad_left])
                        for i,neurons in enumerate(patch):
                            for j,neuron in enumerate(neurons):
                                neuronsDict[str(neuron)].append((postSynapticID,int(round(fil[channel,i,j]))))
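To see the connection lists this unrolling produces, here is the same idea applied to a toy 4$\times$4 input with a single 3$\times$3 filter (valid convolution, stride 1), written as a standalone plain-Python sketch. `unroll_conv`, the IDs and the filter values are hypothetical. Interior inputs fall inside several patches, so their fan-out is larger than that of corner inputs; this is what drives the fan-out figures printed by the statistics cell.

```python
from collections import defaultdict


def unroll_conv(height, width, kernel, weights):
    """Map every input ID to the (output ID, weight) synapses of a valid, stride-1 convolution."""
    pad = kernel // 2                                   # assumes an odd kernel
    out_w = width - 2 * pad
    synapses = defaultdict(list)
    for row in range(pad, height - pad):
        for col in range(pad, width - pad):
            out_id = (row - pad) * out_w + (col - pad)  # output neuron at the patch centre
            for i in range(kernel):
                for j in range(kernel):
                    in_id = (row - pad + i) * width + (col - pad + j)
                    synapses[f"a{in_id}"].append((str(out_id), weights[i][j]))
    return synapses


syn = unroll_conv(4, 4, 3, [[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
fan_out = {axon: len(conns) for axon, conns in syn.items()}
print(sum(fan_out.values()), fan_out["a0"], fan_out["a5"])   # 36 1 4
```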


In [None]:
def maxPoolToCRI(inputs,output,layer,neuronsDict):
    # assumes int kernel_size and stride (e.g. nn.MaxPool2d(2)) and no padding
    Hk = Wk = layer.kernel_size
    stride = layer.stride
    Ho, Wo = output.shape[1],output.shape[2]
    # large weight: with the CRI threshold configured below, a single presynaptic spike drives the pooling neuron over threshold
    scaler = 1e6
    for row in range(Ho):
        for col in range(Wo):
            for channel in range(inputs.shape[0]):
                patch = inputs[channel,row*stride:row*stride+Hk,col*stride:col*stride+Wk]
                postSynapticID = str(output[channel,row,col])
                for preSynNeuron in patch.flatten():
                    neuronsDict[str(preSynNeuron)].append((postSynapticID,scaler))

In [None]:
def avgPoolToCRI(inputs,output,layer,neuronsDict):
    # assumes int kernel_size and stride (e.g. nn.AvgPool2d(2, stride=2)) and no padding
    Hk = Wk = layer.kernel_size
    stride = layer.stride
    Ho, Wo = output.shape[1],output.shape[2]
    # large weight: with the CRI threshold configured below, a single presynaptic spike drives the pooling neuron over threshold
    scaler = 1e6
    for row in range(Ho):
        for col in range(Wo):
            for channel in range(inputs.shape[0]):
                patch = inputs[channel,row*stride:row*stride+Hk,col*stride:col*stride+Wk]
                postSynapticID = str(output[channel,row,col])
                for preSynNeuron in patch.flatten():
                    neuronsDict[str(preSynNeuron)].append((postSynapticID,scaler))

In [None]:
def linearToCRI(inputs,output,layer,layerIdx,neuronsDict,outputNeurons=None):
    # inputs: presynaptic neuron IDs, flattened in the same (C, H, W) order as nn.Flatten
    inputs = inputs.flatten()
    weight = layer.weight.detach().cpu().numpy()
    # IDs are consecutive, so the first input ID and (last input ID + 1) serve as offsets
    currLayerNeuronIdxOffset,nextLayerNeuronIdxOffset = inputs[0],inputs[-1]+1
    for baseNeuronIdx, neuron in enumerate(weight.T):   # one row per presynaptic neuron
        neuronID = str(baseNeuronIdx+currLayerNeuronIdxOffset)
        roundedWeights = [int(round(w)) for w in neuron]
        neuronEntry = [(str(basePostSynapticID+nextLayerNeuronIdxOffset), synapseWeight) for basePostSynapticID, synapseWeight in enumerate(roundedWeights) if synapseWeight != 0]
        neuronsDict[neuronID] = neuronEntry
    print('instantiate output neurons')
    for baseNeuronIdx in range(layer.out_features):
        # output neurons have no outgoing synapses
        neuronID = str(baseNeuronIdx+nextLayerNeuronIdxOffset)
        neuronsDict[neuronID] = []
        outputNeurons.append(neuronID)

In [None]:
def convBiasAxons(layer,axonsDict,axonOffset,outputs):
    biases = layer.bias.detach().cpu().numpy()
    for biasIdx, bias in enumerate(biases):
        biasID = 'a'+str(biasIdx+axonOffset)
        axonsDict[biasID] = [(str(neuronIdx),int(round(bias))) for neuronIdx in outputs[biasIdx].flatten()]

In [None]:
def linearBiasAXons(layer,axonsDict,axonOffset,outputs):
    biases = layer.bias.detach().cpu().numpy()
    for biasIdx, bias in enumerate(biases):
        biasID = 'a'+str(biasIdx+axonOffset)
        axonsDict[biasID] = [(str(outputs[biasIdx]),int(round(bias)))]

In [None]:

from collections import defaultdict
axonsDict = defaultdict(list)
neuronsDict = defaultdict(list)
outputNeurons = []
H_in, W_in = 28, 28
inputSize = (H_in, W_in)
axonOffset = 0
neuronOffset = 0
currInput = None

for layerIdx, layer in enumerate(net):
    if layerIdx == 0: #input layer
        if isinstance(layer,torch.nn.Conv2d):
            print('constructing Axons')
            outputSize = conv2dOutputSize(layer,inputSize)
            print("Input layer shape(infeature, outfeature): ", inputSize,',',outputSize)
            inputIDs = np.arange(0,inputSize[0]*inputSize[1],dtype=int).reshape(inputSize)
            inputAxons = np.array([['a'+str(i) for i in row] for row in inputIDs])
            output = np.arange(0,outputSize[0]*outputSize[1]*outputSize[2],dtype=int).reshape(outputSize)
            conv2dToCRI(inputAxons,output,layer,layerIdx,axonsDict)
            axonOffset += len(axonsDict)
            print('constructing bias axons for input layer:',layer.bias.shape[0],'axons')
            convBiasAxons(layer,axonsDict,axonOffset,output)
            axonOffset += layer.bias.shape[0]
            currInput = output
    elif layerIdx == len(net)-2: #output layer
        if isinstance(layer,torch.nn.Linear):
            print('constructing output layer')
            outputSize = layer.out_features
            print("output layer shape(infeature, outfeature): ", currInput.flatten().shape[0],',',outputSize)
            neuronOffset += currInput.shape[0]*currInput.shape[1]*currInput.shape[2]
            output = np.arange(neuronOffset,neuronOffset+outputSize,dtype=int)
            linearToCRI(currInput,output,layer,layerIdx,neuronsDict=neuronsDict,outputNeurons=outputNeurons)
            print('constructing bias axons for output linear layer:',layer.bias.shape[0],'axons')
            print('Number of neurons:',len(neuronsDict))
            linearBiasAXons(layer,axonsDict,axonOffset,output)
            axonOffset += layer.bias.shape[0]
    else: #hidden layer
        if isinstance(layer,torch.nn.AvgPool2d):
            print('constructing hidden avgpool layer')
            outputSize = AvgPoolOutputSize(layer,currInput.shape)
            print("Hidden layer shape(infeature, outfeature): ", currInput.shape,',',outputSize)
            neuronOffset += currInput.shape[0]*currInput.shape[1]*currInput.shape[2]
            output = np.arange(neuronOffset,neuronOffset+outputSize[0]*outputSize[1]*outputSize[2],dtype=int).reshape(outputSize)
            avgPoolToCRI(currInput,output,layer,neuronsDict)
            currInput = output
            print('Number of neurons:',len(neuronsDict))
        if isinstance(layer,torch.nn.Conv2d):
            print('constructing hidden conv2d layer')
            outputSize = conv2dOutputSize(layer,currInput.shape)
            print("Hidden layer shape(infeature, outfeature): ", currInput.shape,',',outputSize)
            neuronOffset += currInput.shape[0]*currInput.shape[1]*currInput.shape[2]
            output = np.arange(neuronOffset,neuronOffset+outputSize[0]*outputSize[1]*outputSize[2],dtype=int).reshape(outputSize)
            conv2dToCRI(currInput,output,layer,layerIdx,neuronsDict=neuronsDict)
            print('constructing bias axons for hidden conv2d layer:',layer.bias.shape[0],'axons')
            convBiasAxons(layer,axonsDict,axonOffset,output)
            axonOffset += layer.bias.shape[0]
            currInput = output
            print('Number of neurons:',len(neuronsDict))


In [None]:
print("Number of axons: ",len(axonsDict))
axonFanOuts = [len(synapses) for synapses in axonsDict.values()]
print("Total number of connections between axon and neuron: ", sum(axonFanOuts))
print("Max fan out of axon: ", max(axonFanOuts))
print('---')
print("Number of neurons: ", len(neuronsDict))
neuronFanOuts = [len(synapses) for synapses in neuronsDict.values()]
print("Total number of neuron-to-neuron connections: ", sum(neuronFanOuts))
print("Max fan out of neuron: ", max(neuronFanOuts))

In [None]:
axonsDict, neuronsDict = dict(axonsDict), dict(neuronsDict)
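As a cross-check on the counts printed above, the mapping loop numbers neurons consecutively, layer by layer, starting from the first convolution's outputs, and numbers axons as 784 pixel axons followed by one bias axon per output channel or output neuron. Summing the layer sizes (12$\times$24$\times$24, 12$\times$12$\times$12, 64$\times$8$\times$8, 64$\times$4$\times$4 and 10, from the architecture above) reproduces the hard-coded `firstOutput = 13760` used by the read-out functions and the bias-axon range starting at `'a784'`. The names in this sketch are illustrative only.

```python
# (name, neurons) for each layer that becomes a block of CRI neurons
layer_sizes = [
    ("conv1", 12 * 24 * 24),
    ("pool1", 12 * 12 * 12),
    ("conv2", 64 * 8 * 8),
    ("pool2", 64 * 4 * 4),
    ("fc1", 10),
]

offset = 0
first_id = {}
for name, size in layer_sizes:
    first_id[name] = offset          # first neuron ID assigned to this layer
    offset += size

num_neurons = offset
num_axons = 28 * 28 + 12 + 64 + 10   # pixel axons + bias axons for conv1, conv2 and fc1
print(first_id["fc1"], num_neurons, num_axons)   # 13760 13770 870
```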

In [None]:
import time

In [None]:
from l2s.api import CRI_network
import cri_simulations

In [None]:
config = {}
config['neuron_type'] = "I&F"
config['global_neuron_params'] = {}
config['global_neuron_params']['v_thr'] = 9*10**4
#softwareNetwork = CRI_network(axons=axonsDict,connections=neuronsDict,config=config,target='simpleSim', outputs = outputNeurons)
hardwareNetwork = CRI_network(axons=axonsDict,connections=neuronsDict,config=config,target='CRI', outputs = outputNeurons,simDump = False)

In [None]:
def input_to_CRI(currentInput):
    num_steps = 10  # number of rate-coded time steps presented to CRI
    currentInput = currentInput.view(currentInput.size(0), -1)   # flatten each image to 784 pixels
    # bias axons (IDs 784 and up) are driven at every time step
    biasInput = ['a'+str(idx) for idx in range(784,len(axonsDict))]
    batch = []
    for image in currentInput:
        timesteps = []
        rateEnc = spikegen.rate(image,num_steps)    # rate-coded spikes, shape [num_steps, 784]
        rateEnc = rateEnc.detach().cpu().numpy()
        for step in rateEnc:
            # axons of the pixels that spiked at this time step
            currInput = ['a'+str(idx) for idx,axon in enumerate(step) if axon != 0]
            timesteps.append(currInput+biasInput)
        batch.append(timesteps)
    return batch
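Rate coding treats each normalized pixel intensity as a per-step firing probability. The sketch below mimics that idea with Python's `random` module to show what one image looks like to CRI: a list of active axon IDs per time step, with the bias axons appended at every step. It assumes Bernoulli sampling as a stand-in for `spikegen.rate`; `rate_encode_to_axons`, the 4-pixel image and the bias axons `a4`/`a5` are hypothetical.

```python
import random


def rate_encode_to_axons(pixels, num_steps, bias_axons, seed=0):
    """Return, for each time step, the axon IDs that fire (pixel spikes plus all bias axons)."""
    rng = random.Random(seed)
    timesteps = []
    for _ in range(num_steps):
        # pixel i fires with probability equal to its intensity (clipped to [0, 1])
        active = [f"a{i}" for i, p in enumerate(pixels) if rng.random() < min(max(p, 0.0), 1.0)]
        timesteps.append(active + bias_axons)
    return timesteps


pixels = [0.0, 1.0, 0.5, 0.0]            # hypothetical 4-pixel image
bias_axons = ["a4", "a5"]                # bias axons are driven at every step
for step in rate_encode_to_axons(pixels, num_steps=5, bias_axons=bias_axons):
    print(step)
```

Black pixels never fire, white pixels fire at every step, and mid-grey pixels fire on roughly half of the steps.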

In [None]:
def run_CRI(inputList):
    # software simulation; requires softwareNetwork (see the commented-out line in the configuration cell)
    global cri_sw_runtime
    firstOutput = 13760   # ID of the first output neuron
    predictions = []
    total_time_cri = 0
    #each image
    for currInput in inputList:
        #reset the membrane potential to zero
        softwareNetwork.simpleSim.initialize_sim_vars(len(neuronsDict))
        spikeRate = [0]*10
        #each time step
        for timestep in currInput:
            start_time = time.time()
            swSpike = softwareNetwork.step(timestep, membranePotential=False)
            end_time = time.time()
            total_time_cri = total_time_cri + end_time-start_time
            for spike in swSpike:
                spikeIdx = int(spike) - firstOutput
                if spikeIdx >= len(spikeRate):
                    print("SpikeIdx: ", spikeIdx,"\n SpikeRate:",spikeRate )
                elif spikeIdx >= 0:
                    spikeRate[spikeIdx] += 1
        predictions.append(spikeRate.index(max(spikeRate)))
    print(f"Total simulation execution time: {total_time_cri:.5f} s")
    cri_sw_runtime += total_time_cri
    return(predictions)

In [None]:
def run_CRI_hw(inputList):
    global cri_hw_runtime
    firstOutput = 13760   # ID of the first output neuron
    predictions = []
    total_time_cri = 0
    #each image
    for currInput in inputList:
        # clear the hardware network state before each image
        cri_simulations.FPGA_Execution.fpga_controller.clear(len(neuronsDict), False, 0)  ##Num_neurons, simDump, coreOverride
        spikeRate = [0]*10
        #each time step
        for timestep in currInput:
            start_time = time.time()
            hwSpike = hardwareNetwork.step(timestep)
            total_time_cri += time.time()-start_time
            for spike in hwSpike:
                spikeIdx = int(spike[0]) - firstOutput
                if spikeIdx >= 0:
                    spikeRate[spikeIdx] += 1
        predictions.append(spikeRate.index(max(spikeRate)))
    print(f"Total execution time CRIFPGA: {total_time_cri:.5f} s")
    cri_hw_runtime += total_time_cri
    return(predictions)
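Both read-out functions decode the CRI output the same way: subtract the first output neuron ID from each spiking neuron's ID to get a class index, count spikes per class over all time steps, and predict the class with the highest count. A standalone sketch with hypothetical spike lists (IDs as strings); `decode_prediction` is an illustrative helper, and IDs below `first_output` are ignored as in the functions above.

```python
def decode_prediction(spikes_per_step, first_output=13760, num_classes=10):
    """spikes_per_step: one list of spiking neuron IDs (strings) per time step."""
    counts = [0] * num_classes
    for step in spikes_per_step:
        for neuron_id in step:
            idx = int(neuron_id) - first_output
            if 0 <= idx < num_classes:        # skip non-output neurons
                counts[idx] += 1
    # list.index returns the first maximum, so ties go to the lowest class
    return counts.index(max(counts)), counts


steps = [["13763"], ["13763", "13761"], ["13763"], []]   # hypothetical output spikes
print(decode_prediction(steps))   # (3, [0, 1, 0, 3, 0, 0, 0, 0, 0, 0])
```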

In [None]:
total = 0
correct = 0
cri_correct = 0
cri_correct_hw = 0
# wall-clock time accumulators, updated by run_CRI / run_CRI_hw and by the loop below
snnTorch_runtime = 0.0
cri_hw_runtime = 0.0
cri_sw_runtime = 0.0
# drop_last switched to False to keep all samples
test_loader = DataLoader(mnist_test, batch_size=128, shuffle=True, drop_last=False)
with torch.no_grad():
    net.eval()
    for data, targets in test_loader:
        data = data.to(device)
        targets = targets.to(device)
        criInput = input_to_CRI(data)
#         criPred = torch.tensor(run_CRI(criInput)).to(device)
        criPred_hw = torch.tensor(run_CRI_hw(criInput)).to(device)
        print("CRI Predicted HW: ",criPred_hw)
        print("Target: ",targets)
        snn_bTime = time.time()
        test_spk, _ = forward_pass(net, num_steps, data)
        snn_eTime = time.time()
        snnTorch_runtime += snn_eTime-snn_bTime
        # calculate total accuracy
        _, predicted = test_spk.sum(dim=0).max(1)
        print("Torchsnn Predicted: ",predicted)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()
#         cri_correct += (criPred == targets).sum().item()
        cri_correct_hw += (criPred_hw == targets).sum().item()
        break #run for one batch

In [None]:
print(f"Total correctly classified test set images for TorchSNN: {correct}/{total}")
print(f"Total correctly classified test set images for CRI (hardware): {cri_correct_hw}/{total}")
print(f"Test Set Accuracy for TorchSNN: {100 * correct / total:.2f}%")
print(f"Test Set Accuracy for CRI (hardware): {100 * cri_correct_hw / total:.2f}%")