# Set up data set (MNIST) and device

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as data_utils

import matplotlib.pyplot as plt
import numpy as np

# MNIST training split, stored under the current directory (download=True
# asks torchvision to fetch it); ToTensor converts each image to a tensor.
train_set_MNIST = torchvision.datasets.MNIST(
  root="./",
  download=True,
  train=True,
  transform=transforms.Compose([transforms.ToTensor()]),
)

# MNIST test split (train=False), with the same root and transform.
test_set_MNIST = torchvision.datasets.MNIST(
  root="./",
  download=True,
  train=False,
  transform=transforms.Compose([transforms.ToTensor()]),
)

# Restrict both splits to the samples at indices 0..9999. The original
# dataset objects are rebound to the Subset wrappers.
train_set_MNIST = data_utils.Subset(train_set_MNIST, torch.arange(10000))
test_set_MNIST = data_utils.Subset(test_set_MNIST, torch.arange(10000))

def loading_data(batch_size, train_set, test_set):
    """Build shuffling DataLoaders for a train and a test dataset.

    Args:
        batch_size: Number of samples per batch for both loaders.
        train_set: Dataset wrapped by the first returned loader.
        test_set: Dataset wrapped by the second returned loader.

    Returns:
        A ``(trainloader, testloader)`` tuple. Both loaders shuffle their
        dataset on every pass.
    """
    loaders = tuple(
        torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        for dataset in (train_set, test_set)
    )
    return loaders

# Samples per batch, used for both the train and the test loader.
batch_size = 512
# Module-level loaders built from the MNIST subsets defined above.
trainloader, testloader = loading_data(batch_size, train_set_MNIST, test_set_MNIST)

In [2]:
# Default floating-point dtype for tensors created by the training code.
dtype = torch.float

# Pick the best available accelerator instead of hard-coding "mps", so the
# notebook also runs on machines without Apple-Silicon GPU support.
# The getattr guard covers torch builds that have no `torch.backends.mps`.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Train function

In [3]:
def _summed_loss_over_time(loss_fn, mem_rec, targets, num_steps):
    """Return ``loss_fn(mem_rec[t], targets)`` summed over ``t < num_steps``.

    The result is a tensor of shape ``(1,)`` created with the module-level
    ``dtype`` and ``device``.
    """
    total = torch.zeros((1), dtype=dtype, device=device)
    for step in range(num_steps):
        total += loss_fn(mem_rec[step], targets)
    return total


def TrainSNN(net, nepochs=1, print_epoch=True, plot_mem=False, break_after_convergence=True):
    """Train ``net`` with Adam and record per-epoch losses and test accuracy.

    The function reads the module-level ``trainloader``, ``testloader``,
    ``device`` and ``dtype``. After every training batch, one batch drawn
    from ``testloader`` is evaluated and its loss is added to the epoch's
    test loss. The reported test accuracy uses only the last test batch
    evaluated in the epoch.

    Args:
        net: Module exposing ``num_steps``; calling it must return a
            ``(spike_record, membrane_record)`` pair, both indexed by time
            step along dimension 0.
        nepochs: Maximum number of epochs to run.
        print_epoch: If true, print losses and accuracy after each epoch,
            plus a message when convergence is detected.
        plot_mem: If true, for the first batch of every epoch plot the
            recorded output membrane traces of sample 0 for output indices
            0-9, one figure per index.
        break_after_convergence: If true, stop as soon as convergence is
            detected.

    Returns:
        ``(nepochs_to_converge, train_loss_rec, test_loss_rec, acc_rec)``.
        ``nepochs_to_converge`` is the 1-based epoch at which the summed
        training loss first changed by less than 0.10 from the previous
        epoch, or ``nepochs`` if that never happened. The three lists hold
        one entry per completed epoch; accuracies are percentages.

    Note:
        ``net`` is left in evaluation mode when at least one batch was
        processed, because the last step of every batch is a test-set
        evaluation.
    """
    num_steps = net.num_steps

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.01, betas=(0.9, 0.999))

    prev_train_loss = np.inf   # Training loss of the previous epoch.
    converged = False
    nepochs_to_converge = nepochs   # No. of epochs to converge.

    train_loss_rec = []
    test_loss_rec = []
    acc_rec = []

    for epoch in range(nepochs):
        train_loss = 0   # Losses summed over all batches of this epoch.
        test_loss = 0

        for idx, (data, targets) in enumerate(trainloader):
            # The test evaluation below switches the network to eval mode;
            # switch back so every training batch runs in training mode.
            net.train()

            data = data.to(device)
            targets = targets.to(device)

            # Forward pass.
            spk_rec, mem_rec = net(data)

            # Training loss summed over time.
            loss_val = _summed_loss_over_time(loss_fn, mem_rec, targets, num_steps)

            # Backward pass.
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            train_loss += loss_val.item()

            # Test set: evaluate a single batch without tracking gradients.
            with torch.no_grad():
                net.eval()
                test_data, test_targets = next(iter(testloader))
                test_data = test_data.to(device)
                test_targets = test_targets.to(device)

                # Test set forward pass.
                test_spk, test_mem = net(test_data)

                # Test set loss.
                test_loss_val = _summed_loss_over_time(loss_fn, test_mem, test_targets, num_steps)
                test_loss += test_loss_val.item()

            if plot_mem and idx == 0:   # Plot output layer membrane potentials.
                for i in range(10):
                    plt.plot(mem_rec[:,0,i].detach().cpu().numpy())
                    plt.xlabel("Time Step")
                    plt.ylabel("Membrane Potential (V)")
                    plt.grid()
                    plt.show()

        # Test set accuracy on the last evaluated test batch: the predicted
        # class is the output with the highest spike count over time.
        _, preds = test_spk.sum(dim=0).max(1)
        acc = np.mean((test_targets == preds).detach().cpu().numpy())

        if print_epoch:   # Print epoch, losses, and test accuracy.
            print(f'Epoch: {epoch+1} | Train Loss: {train_loss:.04} | Test Loss: {test_loss:.04} | Test Accuracy: {acc*100:.2f}%')

        train_loss_rec.append(train_loss)
        test_loss_rec.append(test_loss)
        acc_rec.append(acc*100)

        # Convergence: the epoch's training loss barely changed.
        if np.abs(prev_train_loss - train_loss) < 0.10 and not converged:
            converged = True
            nepochs_to_converge = epoch+1
            if print_epoch:
                print("-"*75, f"Converged after {nepochs_to_converge} epochs")
            if break_after_convergence:
                break

        prev_train_loss = train_loss

    return nepochs_to_converge, train_loss_rec, test_loss_rec, acc_rec

# LTC

In [4]:
class LTC(nn.Module):
    """Spiking neuron update whose leak and drive are gated by sigmoid(input).

    One call to ``forward`` performs a single time step:

    * a spike is emitted where ``mem > threshold``;
    * the membrane is updated as
      ``(beta - K*sigmoid(input)) * mem + A*K*sigmoid(input) - spike*threshold``.

    The module holds no state between calls; the caller passes the membrane
    potential in and receives the updated one back.

    Args:
        beta: Base decay factor applied to the membrane potential.
        A: Scale of the sigmoid-gated drive term.
        K: Gain applied to ``sigmoid(input)`` in both the decay and the drive.
        threshold: Membrane value above which a spike is emitted; also the
            amount subtracted from the membrane on a spike.
    """

    class FastSig(torch.autograd.Function):
        """Step-function spike with surrogate gradient 1 / (1 + pi*|x|)^2."""

        @staticmethod
        def forward(ctx, mem):
            # Forward pass is a hard threshold at zero. The result is created
            # on the same device as `mem`, so no explicit device move is needed.
            ctx.save_for_backward(mem)
            return (mem > 0).float()

        @staticmethod
        def backward(ctx, grad_output):
            # The step function has zero gradient almost everywhere, so a
            # smooth surrogate is used for backpropagation instead.
            (mem,) = ctx.saved_tensors
            grad = 1 / (1 + np.pi * torch.abs(mem))**2 * grad_output
            return grad

    def __init__(self, beta=0.8, A=1, K=1, threshold=1.0):
        super(LTC, self).__init__()
        self.beta = beta
        self.A = A
        self.K = K
        self.threshold = threshold
        self.spike_gradient = self.FastSig.apply

    def forward(self, input_, mem, timestep):
        """Advance the neuron state by one time step.

        Args:
            input_: Input current; must broadcast against ``mem``.
            mem: Membrane potential from the previous step.
            timestep: Index of the current step. Unused; kept so existing
                callers that pass it keep working.

        Returns:
            ``(spk, mem)``: the spikes computed from the incoming ``mem``
            and the updated membrane potential.
        """
        spk = self.spike_gradient(mem - self.threshold)
        # The reset is detached so no gradient flows through the subtraction.
        reset = (spk * self.threshold).detach()
        # Computed once and shared by the decay and the drive terms.
        gate = self.K * torch.sigmoid(input_)
        mem = (self.beta - gate) * mem + self.A * gate - reset
        return spk, mem

class LTCNet(nn.Module):
    """Fully connected spiking network built from ``LTC`` neuron updates.

    The layout is ``fc1 -> (depth-1) x [LTC -> fc2[j]] -> LTC -> fc3 -> LTC``.
    The same static input is presented at every one of ``num_steps`` time
    steps.

    Args:
        dim: Number of features per sample after flattening.
        nclass: Number of output neurons.
        width: Number of neurons in each hidden layer.
        depth: Number of hidden layers (``depth - 1`` extra ``width x width``
            linear layers are created).
        num_steps: Number of simulated time steps per forward pass.
        A: Passed through to ``LTC`` as ``A``.
        K: Passed through to ``LTC`` as ``K``.
    """

    def __init__(self, dim=784, nclass=10, width=1000, depth=1, num_steps=15, A=1, K=1):
        super().__init__()
        self.dim = dim
        self.width = width
        self.depth = depth
        self.nclass = nclass
        self.num_steps = num_steps
        self.A = A
        self.K = K

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(dim, width)
        self.fc2 = nn.ModuleList([nn.Linear(width, width) for _ in range(self.depth-1)])
        self.fc3 = nn.Linear(width, nclass)
        # A single neuron module is shared by all layers; the per-layer
        # membrane state is held in `forward`, not in the module.
        self.ltc = LTC(A=self.A, K=self.K)

    def forward(self, input):
        """Simulate the network for ``num_steps`` steps on a static input.

        Args:
            input: Batch of samples with the batch on dimension 0; the
                remaining dimensions are flattened.

        Returns:
            ``(spk_rec, mem_rec)``: output-layer spikes and membrane
            potentials stacked over time, each of shape
            ``(num_steps, batch, nclass)``.
        """
        cur_batch_size = input.shape[0]

        # The input does not change over time, so the first-layer current is
        # computed once and reused at every step.
        first_cur = self.fc1(self.flatten(input))

        # Initialize membrane potentials on the same device/dtype as the
        # activations instead of relying on a module-level device:
        # mem1: hidden layers excl. last; mem2: last hidden layer; mem3: output layer
        mem1 = [first_cur.new_zeros(cur_batch_size, self.width) for _ in range(self.depth-1)]
        mem2 = first_cur.new_zeros(cur_batch_size, self.width)
        mem3 = first_cur.new_zeros(cur_batch_size, self.nclass)

        spk3_rec = []   # Final spike record.
        mem3_rec = []   # Final membrane record.

        for step in range(self.num_steps):
            cur = first_cur

            for j, hidden in enumerate(self.fc2):
                spk, mem1[j] = self.ltc(cur, mem1[j], step)
                cur = hidden(spk)

            spk, mem2 = self.ltc(cur, mem2, step)
            cur = self.fc3(spk)

            spk3, mem3 = self.ltc(cur, mem3, step)
            spk3_rec.append(spk3)
            mem3_rec.append(mem3)

        return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)