# SNN Test


In [48]:
# import numpy as np
# import itertools
# #import cv2
# #from google.colab import files
# import timeit
# from os import listdir
# import os
# from zipfile import ZipFile
# import gdown
# import shutil

# import snntorch as snn
# from snntorch import spikeplot as splt
# from snntorch import spikegen

# import torchvision.transforms.functional as TF
# from torch.utils.data import DataLoader, random_split
# from torchvision import datasets, transforms
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torchvision
# from torchview import draw_graph
# from torchvision.transforms import ToTensor
# from torchshape import tensorshape

# import matplotlib.pyplot as plt
# from matplotlib.animation import FuncAnimation
# from IPython.display import HTML
# from sklearn.metrics import classification_report

#-----------------------------------------------
import sys
import snntorch as snn
from snntorch import spikeplot as split
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import itertools


print('System Version:', sys.version)
print('PyTorch version', torch.__version__)
# print('Torchvision version', torchvision.__version__)
print('Numpy version', np.__version__)
# print('Pandas version', pd.__version__)
print("env path: ", sys.executable) #[+]

#[+] check to see if gpu is available, else use cpu
device = "cuda" if torch.cuda.is_available() else "cpu"
print('Using device: ',device)
for i in range(torch.cuda.device_count()): print('       -->',i,':', torch.cuda.get_device_name(i))

System Version: 3.9.18 (main, Sep 11 2023, 13:30:38) [MSC v.1916 64 bit (AMD64)]
PyTorch version 2.3.1+cu118
Numpy version 1.26.4
env path:  c:\Users\richa\OneDrive\Dugree\Project\cuda\Scripts\python.exe
Using device:  cuda
       --> 0 : NVIDIA GeForce RTX 2070


## DataLoading MNIST

In [49]:
batch_size= 128
data_path= r'./data'
dtype= torch.float

# Create the transoform for MNIST dataset to make sure its 28x28, grayscale, a tensor, and vals normalized to fall between 0 and 1
transform= transforms.Compose([
    transforms.Resize((28,28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,)),])

# Automatically downloads and splits the MNIST dataset
mnist_train = datasets.MNIST(data_path, train= True , download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train= False, download=True, transform=transform)

#create DataLoaders
train_loader= DataLoader(mnist_train , batch_size= batch_size, shuffle=True, drop_last=True)
test_loader= DataLoader(mnist_test , batch_size= batch_size, shuffle=True, drop_last=True)

# Construct a Fully Connected SNN Architecture

In [50]:
# Number of inputs should match number of pixels in the MNIST img
num_inputs= 28*28  #= 784

# Hidden layer is however big you want as long as it fits in your GPU
num_hidden= 1000

# One output neuron for each of the 10 MNIST digits
num_outputs= 10

# 25 time steps is a quick simulation
num_steps= 25

# Rate of decay
beta= 0.95

In [51]:


# Define Network
class Snn(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers before defining the forward function
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)

        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)
        

    def forward(self, x):

        # init hidden states at t=0, mem is membrane potential
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        
        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for step in range(num_steps):
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            # store in list
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        # The network returns a tensor of spike recordings over time, and a tensor of membrane potential recordnigs over time
        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)
        
# Load the network onto CUDA if available
snn = Snn().to(device)

## Training

In [52]:
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(snn.parameters(), lr=5e-4, betas=(0.9, 0.999))

# # 60000 data samples / 128 samples per batch = approx 468 iterations
num_epochs = 1
loss_hist = []
test_loss_hist = []
counter = 0

# Outer training loop
for epoch in range(num_epochs):
    train_batch = iter(train_loader)

    # Minibatch training loop: each batch will have data which is the 128 samples, and each sample will have target labels (digits 0-9), we load all into cuda
    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        #forward pass: we set our network to train mode, and pass the data into it
        snn.train()
        spk_rec, mem_rec = snn(data.flatten(1))

        # initialize the loss & sum over time
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        loss_val += loss(spk_rec.sum(0), targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Print train/test loss/accuracy
        if counter % 10 == 0:
            print(f"Iteration: {counter} \t Train Loss: {loss_val.item()}")
        counter += 1

        if counter == 100:
          torch.save(snn.state_dict(), 'snn_mdl.pth')
          print(' --> model saved')
          break



Iteration: 0 	 Train Loss: 2.8155517578125
Iteration: 10 	 Train Loss: 1.1944315433502197
Iteration: 20 	 Train Loss: 0.7074097990989685
Iteration: 30 	 Train Loss: 0.6654691100120544
Iteration: 40 	 Train Loss: 0.6876184940338135
Iteration: 50 	 Train Loss: 0.7761414647102356
Iteration: 60 	 Train Loss: 0.7439694404602051
Iteration: 70 	 Train Loss: 0.9466572403907776
Iteration: 80 	 Train Loss: 0.7870963215827942
Iteration: 90 	 Train Loss: 0.6423141360282898
 --> model saved


## Test Accuracy

In [53]:
def measure_accuracy(model, dataloader):
  with torch.no_grad():
    model.eval()
    running_length = 0
    running_accuracy = 0

    for data, targets in iter(dataloader):
      data = data.to(device)
      targets = targets.to(device)

      # forward-pass
      # spk_rec, mem_rec = model(data)
      spk_rec, mem_rec = model(data.flatten(1))
      
      spike_count = spk_rec.sum(0)
      _, max_spike = spike_count.max(1)

      # correct classes for one batch
      num_correct = (max_spike == targets).sum()

      # total accuracy
      running_length += len(targets)
      running_accuracy += num_correct
    
    accuracy = (running_accuracy / running_length)

    return accuracy.item()



In [54]:
print( 'Accuracy: ', measure_accuracy(snn, test_loader))

Accuracy:  0.8388421535491943
