# Set up the dataset (CIFAR-10) and device

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as data_utils

import matplotlib.pyplot as plt
import numpy as np

# Both CIFAR-10 splits go through the same tensor-conversion pipeline and are
# downloaded into the current directory (train split first, then test).
_cifar_to_tensor = transforms.Compose([transforms.ToTensor()])

train_set_CIFAR10, test_set_CIFAR10 = (
    torchvision.datasets.CIFAR10(
        root="./",
        train=is_train,
        download=True,
        transform=_cifar_to_tensor,
    )
    for is_train in (True, False)
)

# Keep only the first 10000 samples of each split.
_subset_indices = torch.arange(10000)
train_set_CIFAR10 = data_utils.Subset(train_set_CIFAR10, _subset_indices)
test_set_CIFAR10 = data_utils.Subset(test_set_CIFAR10, _subset_indices)

def loading_data(batch_size, train_set, test_set):
    """Wrap the train and test datasets in shuffling DataLoaders.

    Args:
        batch_size: number of samples per batch for both loaders.
        train_set: dataset used for training.
        test_set: dataset used for evaluation. It is shuffled as well, so every
            fresh iterator over its loader starts with a different random batch.

    Returns:
        ``(trainloader, testloader)``.
    """
    return tuple(
        torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        for dataset in (train_set, test_set)
    )

# Mini-batch size shared by the train and test loaders.
batch_size = 512
trainloader, testloader = loading_data(
    batch_size=batch_size,
    train_set=train_set_CIFAR10,
    test_set=test_set_CIFAR10,
)

In [2]:
# Floating-point dtype used for the loss accumulators.
dtype = torch.float

# Prefer Apple's Metal (MPS) backend, then CUDA; otherwise fall back to the CPU
# so the notebook still runs on machines without a GPU backend.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

# Training function

In [3]:
def TrainSNN(net, nepochs=1, print_epoch=True, plot_mem=False, break_after_convergence=True,
             train_loader=None, test_loader=None, lr=0.01, patience=5):
    """Train a spiking network with Adam, recording per-epoch losses and accuracies.

    The loss for a batch is the cross-entropy of the output membrane potential,
    summed over all ``net.num_steps`` time steps. After every training step, one
    batch drawn from a fresh iterator over the test loader is evaluated without
    gradients, and its loss is added to the epoch's test loss. The predicted
    class is the output neuron with the most spikes summed over time.

    Args:
        net: spiking model with a ``num_steps`` attribute whose forward returns
            ``(spk_rec, mem_rec)``, each indexed ``[step, sample, class]``.
        nepochs: maximum number of epochs.
        print_epoch: print one summary line per epoch and the convergence message.
        plot_mem: for the first batch of each epoch, plot the output membrane
            potentials of the first sample.
        break_after_convergence: stop as soon as convergence is detected.
        train_loader: DataLoader for training; defaults to the notebook-level
            ``trainloader``.
        test_loader: DataLoader for evaluation; defaults to the notebook-level
            ``testloader``.
        lr: Adam learning rate.
        patience: number of epochs without a new minimum test loss after which
            training counts as converged.

    Returns:
        ``(nepochs_to_converge, train_loss_rec, test_loss_rec, train_acc_rec,
        test_acc_rec)``. Losses are summed over an epoch's batches. Accuracies
        are percentages over every sample seen during the epoch.
        ``nepochs_to_converge`` is the (1-based) epoch with the lowest test loss
        when training stopped early, otherwise ``nepochs``.
    """
    if train_loader is None:
        train_loader = trainloader
    if test_loader is None:
        test_loader = testloader

    # Feed batches to whichever device the network's parameters live on.
    run_device = next(net.parameters()).device
    num_steps = net.num_steps

    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999))

    def _summed_loss(mem_rec, targets):
        # Cross-entropy of the output membrane potential, summed over every time step.
        return sum(loss(mem_rec[step], targets) for step in range(num_steps))

    min_test_loss = np.inf
    min_test_epoch = 0
    nepochs_to_converge = nepochs

    train_loss_rec, test_loss_rec = [], []
    train_acc_rec, test_acc_rec = [], []

    net.train()
    for epoch in range(nepochs):
        train_loss = 0.0
        test_loss = 0.0
        train_correct = train_total = 0
        test_correct = test_total = 0

        for idx, (data, targets) in enumerate(train_loader):
            data = data.to(run_device)
            targets = targets.to(run_device)

            # Forward pass, loss, and parameter update.
            spk_rec, mem_rec = net(data)
            loss_val = _summed_loss(mem_rec, targets)

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            train_loss += loss_val.item()
            _, preds = spk_rec.sum(dim=0).max(1)
            train_correct += (preds == targets).sum().item()
            train_total += targets.numel()

            # Evaluate one freshly drawn test batch without tracking gradients.
            net.eval()
            with torch.no_grad():
                test_data, test_targets = next(iter(test_loader))
                test_data = test_data.to(run_device)
                test_targets = test_targets.to(run_device)

                test_spk, test_mem = net(test_data)
                test_loss += _summed_loss(test_mem, test_targets).item()

                _, test_preds = test_spk.sum(dim=0).max(1)
                test_correct += (test_preds == test_targets).sum().item()
                test_total += test_targets.numel()
            # Switch back so the next training step runs in training mode.
            net.train()

            if plot_mem and idx == 0:  # Plot output-layer membrane potentials.
                for i in range(10):
                    plt.plot(mem_rec[:, 0, i].detach().cpu().numpy())
                    plt.xlabel("Time Step")
                    plt.ylabel("Membrane Potential (V)")
                    plt.grid()
                    plt.show()

        # Fractions of correct predictions over the whole epoch (0 if no samples were seen).
        train_acc = train_correct / max(train_total, 1)
        test_acc = test_correct / max(test_total, 1)

        if print_epoch:
            print(f'Epoch: {epoch+1} | Train Loss: {train_loss:.04} | Test Loss: {test_loss:.04} | Train Accuracy: {train_acc*100:.2f}% | Test Accuracy: {test_acc*100:.2f}%')

        train_loss_rec.append(train_loss)
        test_loss_rec.append(test_loss)
        train_acc_rec.append(train_acc * 100)
        test_acc_rec.append(test_acc * 100)

        # Converged once `patience` epochs pass without a new minimum test loss.
        if test_loss < min_test_loss:
            min_test_loss = test_loss
            min_test_epoch = epoch + 1
        elif epoch + 1 - min_test_epoch == patience:
            if print_epoch:
                print("-"*75, f"Converged after {min_test_epoch} epochs")
            if break_after_convergence:
                nepochs_to_converge = min_test_epoch
                break

    return nepochs_to_converge, train_loss_rec, test_loss_rec, train_acc_rec, test_acc_rec

# Models

### Leaky

In [4]:
class Leaky(nn.Module):
    """Leaky integrate-and-fire neuron with subtractive reset.

    On each call, spikes are emitted from the *incoming* membrane potential.
    The potential is then updated as ``beta * mem + input_ - threshold * spk``,
    where the reset term is detached from the autograd graph.
    """

    def __init__(self, beta=0.8, threshold=1.0):
        super().__init__()
        self.beta = beta            # membrane decay factor per step
        self.threshold = threshold  # firing threshold and reset amount
        self.spike_gradient = self.FastSig.apply

    def forward(self, input_, mem, timestep):
        """Advance the neuron by one step.

        Args:
            input_: input current, same shape as ``mem``.
            mem: membrane potential from the previous step.
            timestep: current step index (unused; kept for a uniform call signature).

        Returns:
            ``(spk, mem)``: binary spikes and the updated membrane potential.
        """
        spk = self.spike_gradient(mem - self.threshold)
        reset = (spk * self.threshold).detach()
        mem = self.beta * mem + input_ - reset
        return spk, mem

    class FastSig(torch.autograd.Function):
        """Heaviside step in the forward pass; surrogate gradient ``1 / (1 + pi*|x|)**2``."""

        @staticmethod
        def forward(ctx, mem):
            # The comparison already produces a tensor on mem's own device, so no
            # transfer to a global device is needed.
            spk = (mem > 0).float()
            ctx.save_for_backward(mem)
            return spk

        @staticmethod
        def backward(ctx, grad_output):
            (mem,) = ctx.saved_tensors
            grad = 1 / (1 + np.pi * torch.abs(mem))**2 * grad_output
            return grad


class LeakyNet(nn.Module):
    """Convolutional spiking classifier built from one shared ``Leaky`` neuron module.

    The same (static) input is presented at each of ``num_steps`` time steps. The
    network is three conv layers (two followed by 2x2 max-pooling) and two
    linear layers, each feeding a ``Leaky`` neuron. Per-layer membrane state is
    kept locally in ``forward``.
    """

    def __init__(self, num_steps=15):
        super().__init__()
        self.num_steps = num_steps

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, stride=2)
        self.fc1 = nn.Linear(4096, 64)  # 64 channels * 8 * 8 for a 32x32 input
        self.fc2 = nn.Linear(64, 10)
        self.flatten = nn.Flatten()
        self.lif = Leaky()

    def forward(self, input):
        """Run the network for ``num_steps`` steps.

        Args:
            input: image batch, ``(batch, 3, H, W)``; ``fc1`` expects H = W = 32.

        Returns:
            ``(spk, mem)``: output-layer spikes and membrane potentials stacked
            along a new leading time axis, each ``(num_steps, batch, 10)``.
        """
        batch = input.shape[0]
        height, width = input.shape[-2:]
        dev = input.device

        # The input does not change between steps, so conv1 is evaluated once.
        cur_conv1 = self.conv1(input)

        # Membrane potentials start at zero; spatial sizes halve after each pooling.
        mem_conv1 = torch.zeros_like(cur_conv1)
        mem_conv2 = torch.zeros(batch, self.conv2.out_channels, height // 2, width // 2, device=dev)
        mem_conv3 = torch.zeros(batch, self.conv3.out_channels, height // 4, width // 4, device=dev)
        mem_fc1 = torch.zeros(batch, self.fc1.out_features, device=dev)
        mem_fc2 = torch.zeros(batch, self.fc2.out_features, device=dev)

        spk_fc2_rec = []   # Output spikes per step.
        mem_fc2_rec = []   # Output membrane potentials per step.

        for step in range(self.num_steps):
            spk, mem_conv1 = self.lif(cur_conv1, mem_conv1, step)

            cur = self.conv2(self.pool(spk))
            spk, mem_conv2 = self.lif(cur, mem_conv2, step)

            cur = self.conv3(self.pool(spk))
            spk, mem_conv3 = self.lif(cur, mem_conv3, step)

            cur = self.fc1(self.flatten(spk))
            spk, mem_fc1 = self.lif(cur, mem_fc1, step)

            cur = self.fc2(spk)
            spk_fc2, mem_fc2 = self.lif(cur, mem_fc2, step)
            spk_fc2_rec.append(spk_fc2)
            mem_fc2_rec.append(mem_fc2)

        return torch.stack(spk_fc2_rec, dim=0), torch.stack(mem_fc2_rec, dim=0)

### LTC

In [5]:
class LTC(nn.Module):
    """Spiking neuron with an input-dependent (liquid time-constant style) decay.

    Spikes are emitted from the incoming membrane potential. With
    ``g = K * sigmoid(input_)``, the potential is then updated as
    ``(beta - g) * mem + A * g - threshold * spk``, where the reset term is
    detached from the autograd graph.
    """

    def __init__(self, beta=0.8, A=1, K=1, threshold=1.0):
        super().__init__()
        self.beta = beta
        self.A = A
        self.K = K
        self.threshold = threshold
        self.spike_gradient = self.FastSig.apply

    def forward(self, input_, mem, timestep):
        """Advance the neuron by one step.

        Args:
            input_: input current, same shape as ``mem``.
            mem: membrane potential from the previous step.
            timestep: current step index (unused; kept for a uniform call signature).

        Returns:
            ``(spk, mem)``: binary spikes and the updated membrane potential.
        """
        spk = self.spike_gradient(mem - self.threshold)
        reset = (spk * self.threshold).detach()
        gate = torch.sigmoid(input_)  # computed once, used in both terms
        mem = (self.beta - self.K * gate) * mem + self.A * self.K * gate - reset
        return spk, mem

    class FastSig(torch.autograd.Function):
        """Heaviside step in the forward pass; surrogate gradient ``1 / (1 + pi*|x|)**2``."""

        @staticmethod
        def forward(ctx, mem):
            # The comparison already produces a tensor on mem's own device.
            spk = (mem > 0).float()
            ctx.save_for_backward(mem)
            return spk

        @staticmethod
        def backward(ctx, grad_output):
            (mem,) = ctx.saved_tensors
            grad = 1 / (1 + np.pi * torch.abs(mem))**2 * grad_output
            return grad

class LTCNet(nn.Module):
    """Convolutional spiking classifier built from one shared ``LTC`` neuron module.

    The layer layout matches the other networks in this notebook: three conv
    layers (two followed by 2x2 max-pooling) and two linear layers. The same
    input is presented at each of ``num_steps`` time steps.
    """

    def __init__(self, num_steps=15, A=1, K=1):
        super().__init__()
        self.num_steps = num_steps
        self.A = A
        self.K = K

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, stride=2)
        self.fc1 = nn.Linear(4096, 64)  # 64 channels * 8 * 8 for a 32x32 input
        self.fc2 = nn.Linear(64, 10)
        self.flatten = nn.Flatten()
        self.ltc = LTC(A=self.A, K=self.K)

    def forward(self, input):
        """Run the network for ``num_steps`` steps.

        Args:
            input: image batch, ``(batch, 3, H, W)``; ``fc1`` expects H = W = 32.

        Returns:
            ``(spk, mem)``: output-layer spikes and membrane potentials stacked
            along a new leading time axis, each ``(num_steps, batch, 10)``.
        """
        batch = input.shape[0]
        height, width = input.shape[-2:]
        dev = input.device

        # The input does not change between steps, so conv1 is evaluated once.
        cur_conv1 = self.conv1(input)

        # Membrane potentials start at zero; spatial sizes halve after each pooling.
        mem_conv1 = torch.zeros_like(cur_conv1)
        mem_conv2 = torch.zeros(batch, self.conv2.out_channels, height // 2, width // 2, device=dev)
        mem_conv3 = torch.zeros(batch, self.conv3.out_channels, height // 4, width // 4, device=dev)
        mem_fc1 = torch.zeros(batch, self.fc1.out_features, device=dev)
        mem_fc2 = torch.zeros(batch, self.fc2.out_features, device=dev)

        spk_fc2_rec = []   # Output spikes per step.
        mem_fc2_rec = []   # Output membrane potentials per step.

        for step in range(self.num_steps):
            spk, mem_conv1 = self.ltc(cur_conv1, mem_conv1, step)

            cur = self.conv2(self.pool(spk))
            spk, mem_conv2 = self.ltc(cur, mem_conv2, step)

            cur = self.conv3(self.pool(spk))
            spk, mem_conv3 = self.ltc(cur, mem_conv3, step)

            cur = self.fc1(self.flatten(spk))
            spk, mem_fc1 = self.ltc(cur, mem_fc1, step)

            cur = self.fc2(spk)
            spk_fc2, mem_fc2 = self.ltc(cur, mem_fc2, step)
            spk_fc2_rec.append(spk_fc2)
            mem_fc2_rec.append(mem_fc2)

        return torch.stack(spk_fc2_rec, dim=0), torch.stack(mem_fc2_rec, dim=0)

### LIF2D

In [6]:
class LIF2D(nn.Module):
    """Two-variable spiking neuron: a membrane potential coupled to a gating variable.

    Spikes are emitted from the incoming membrane potential. The update is:

    - ``mem <- beta * mem + input_ - gat - threshold * spk``, where the reset
      term is detached from the autograd graph;
    - ``gat <- B0 * gat + B1 * mem``, using the freshly updated ``mem``.
    """

    def __init__(self, beta=0.8, B0=1, B1=1, threshold=1.0):
        super().__init__()
        self.beta = beta
        self.B0 = B0
        self.B1 = B1
        self.threshold = threshold
        self.spike_gradient = self.FastSig.apply

    def forward(self, input_, mem, gat, timestep):
        """Advance the neuron by one step.

        Args:
            input_: input current, same shape as ``mem``.
            mem: membrane potential from the previous step.
            gat: gating variable from the previous step.
            timestep: current step index (unused; kept for a uniform call signature).

        Returns:
            ``(spk, mem, gat)``: binary spikes, then the updated membrane and
            gating variables.
        """
        spk = self.spike_gradient(mem - self.threshold)
        reset = (spk * self.threshold).detach()
        mem = self.beta * mem + input_ - gat - reset
        gat = self.B0 * gat + self.B1 * mem
        return spk, mem, gat

    class FastSig(torch.autograd.Function):
        """Heaviside step in the forward pass; surrogate gradient ``1 / (1 + pi*|x|)**2``."""

        @staticmethod
        def forward(ctx, mem):
            # The comparison already produces a tensor on mem's own device.
            spk = (mem > 0).float()
            ctx.save_for_backward(mem)
            return spk

        @staticmethod
        def backward(ctx, grad_output):
            (mem,) = ctx.saved_tensors
            grad = 1 / (1 + np.pi * torch.abs(mem))**2 * grad_output
            return grad

class LIF2DNet(nn.Module):
    """Convolutional spiking classifier built from one shared ``LIF2D`` neuron module.

    The layer layout matches the other networks in this notebook: three conv
    layers (two followed by 2x2 max-pooling) and two linear layers. The same
    input is presented at each of ``num_steps`` time steps. Every layer carries
    a membrane potential (initialised to zero) and a gating variable
    (initialised to one).
    """

    def __init__(self, num_steps=15, B0=1, B1=1):
        super().__init__()
        self.num_steps = num_steps
        self.B0 = B0
        self.B1 = B1

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, stride=2)
        self.fc1 = nn.Linear(4096, 64)  # 64 channels * 8 * 8 for a 32x32 input
        self.fc2 = nn.Linear(64, 10)
        self.flatten = nn.Flatten()
        self.lif2d = LIF2D(B0=self.B0, B1=self.B1)

    def forward(self, input):
        """Run the network for ``num_steps`` steps.

        Args:
            input: image batch, ``(batch, 3, H, W)``; ``fc1`` expects H = W = 32.

        Returns:
            ``(spk, mem)``: output-layer spikes and membrane potentials stacked
            along a new leading time axis, each ``(num_steps, batch, 10)``.
        """
        batch = input.shape[0]
        height, width = input.shape[-2:]
        dev = input.device

        # The input does not change between steps, so conv1 is evaluated once.
        cur_conv1 = self.conv1(input)

        # Membrane potentials start at zero; spatial sizes halve after each pooling.
        mem_conv1 = torch.zeros_like(cur_conv1)
        mem_conv2 = torch.zeros(batch, self.conv2.out_channels, height // 2, width // 2, device=dev)
        mem_conv3 = torch.zeros(batch, self.conv3.out_channels, height // 4, width // 4, device=dev)
        mem_fc1 = torch.zeros(batch, self.fc1.out_features, device=dev)
        mem_fc2 = torch.zeros(batch, self.fc2.out_features, device=dev)

        # Gating variables start at one, with the same shapes as the membranes.
        gat_conv1 = torch.ones_like(mem_conv1)
        gat_conv2 = torch.ones_like(mem_conv2)
        gat_conv3 = torch.ones_like(mem_conv3)
        gat_fc1 = torch.ones_like(mem_fc1)
        gat_fc2 = torch.ones_like(mem_fc2)

        spk_fc2_rec = []   # Output spikes per step.
        mem_fc2_rec = []   # Output membrane potentials per step.

        for step in range(self.num_steps):
            spk, mem_conv1, gat_conv1 = self.lif2d(cur_conv1, mem_conv1, gat_conv1, step)

            cur = self.conv2(self.pool(spk))
            spk, mem_conv2, gat_conv2 = self.lif2d(cur, mem_conv2, gat_conv2, step)

            cur = self.conv3(self.pool(spk))
            spk, mem_conv3, gat_conv3 = self.lif2d(cur, mem_conv3, gat_conv3, step)

            cur = self.fc1(self.flatten(spk))
            spk, mem_fc1, gat_fc1 = self.lif2d(cur, mem_fc1, gat_fc1, step)

            cur = self.fc2(spk)
            spk_fc2, mem_fc2, gat_fc2 = self.lif2d(cur, mem_fc2, gat_fc2, step)
            spk_fc2_rec.append(spk_fc2)
            mem_fc2_rec.append(mem_fc2)

        return torch.stack(spk_fc2_rec, dim=0), torch.stack(mem_fc2_rec, dim=0)

# Tests

### Convergence

In [7]:
# Number of independent training runs averaged for each model in the cells below.
trials = 5

In [None]:
# Leaky: train `trials` fresh networks (up to 100 epochs each) and average the
# epoch count TrainSNN returns together with the metrics recorded at that epoch.
sum_nepochs = sum_trainloss = sum_testloss = sum_trainacc = sum_testacc = 0
for _ in range(trials):
    net = LeakyNet().to(device)
    nepochs_to_converge, train_loss_rec, test_loss_rec, train_acc_rec, test_acc_rec = TrainSNN(
        net, nepochs=100, print_epoch=False
    )
    at = nepochs_to_converge - 1  # records are 0-indexed, epochs are 1-indexed
    sum_nepochs += nepochs_to_converge
    sum_trainloss += train_loss_rec[at]
    sum_testloss += test_loss_rec[at]
    sum_trainacc += train_acc_rec[at]
    sum_testacc += test_acc_rec[at]

for total in (sum_nepochs, sum_trainloss, sum_testloss, sum_trainacc, sum_testacc):
    print(total / trials)

In [None]:
# LTC (A=8, K=0.2): train `trials` fresh networks (up to 100 epochs each) and
# average the epoch count TrainSNN returns together with the metrics at that epoch.
sum_nepochs = sum_trainloss = sum_testloss = sum_trainacc = sum_testacc = 0
for _ in range(trials):
    net = LTCNet(A=8, K=0.2).to(device)
    nepochs_to_converge, train_loss_rec, test_loss_rec, train_acc_rec, test_acc_rec = TrainSNN(
        net, nepochs=100, print_epoch=False
    )
    at = nepochs_to_converge - 1  # records are 0-indexed, epochs are 1-indexed
    sum_nepochs += nepochs_to_converge
    sum_trainloss += train_loss_rec[at]
    sum_testloss += test_loss_rec[at]
    sum_trainacc += train_acc_rec[at]
    sum_testacc += test_acc_rec[at]

for total in (sum_nepochs, sum_trainloss, sum_testloss, sum_trainacc, sum_testacc):
    print(total / trials)

In [None]:
# LIF2D (B0=0.5, B1=0.5): train `trials` fresh networks (up to 100 epochs each)
# and average the epoch count TrainSNN returns together with the metrics at that epoch.
sum_nepochs = sum_trainloss = sum_testloss = sum_trainacc = sum_testacc = 0
for _ in range(trials):
    net = LIF2DNet(B0=0.5, B1=0.5).to(device)
    nepochs_to_converge, train_loss_rec, test_loss_rec, train_acc_rec, test_acc_rec = TrainSNN(
        net, nepochs=100, print_epoch=False
    )
    at = nepochs_to_converge - 1  # records are 0-indexed, epochs are 1-indexed
    sum_nepochs += nepochs_to_converge
    sum_trainloss += train_loss_rec[at]
    sum_testloss += test_loss_rec[at]
    sum_trainacc += train_acc_rec[at]
    sum_testacc += test_acc_rec[at]

for total in (sum_nepochs, sum_trainloss, sum_testloss, sum_trainacc, sum_testacc):
    print(total / trials)