# LMUFormer - psMNIST Experiments

#### This notebook includes Non-spiking LMUFormer and Spiking LMUFormer examples running on psMNIST.

In [None]:
# To reset the notebook, run from this point
%reset -f

## Setup

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import accuracy_score
import random

import torch
from torch import nn
from torch import fft
from torch.nn import init
from torch.nn import functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

from spikingjelly.clock_driven.neuron import MultiStepLIFNode
from spikingjelly.clock_driven import functional
from scipy.signal import cont2discrete

In [None]:
# Connect to GPU
# Preferred CUDA device index; falls back to GPU 0 when the machine has fewer GPUs.
PREFERRED_GPU_INDEX = 3

if torch.cuda.is_available():
    if PREFERRED_GPU_INDEX < torch.cuda.device_count():
        gpu_index = PREFERRED_GPU_INDEX
    else:
        gpu_index = 0
    DEVICE = f"cuda:{gpu_index}"
    # Clear cache if non-empty
    torch.cuda.empty_cache()
    # Report the GPU that DEVICE actually refers to (not merely the current default device)
    print(torch.cuda.get_device_name(gpu_index))
else:
    DEVICE = "cpu"

In [None]:
SEED = 0

def setSeed(seed):
    """
        Seed every random number generator used in this notebook.

        Seeds Python's `random`, NumPy and PyTorch (CPU, current CUDA device and all
        CUDA devices) with `seed`, then puts cuDNN in deterministic, non-benchmarking mode.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

setSeed(SEED)

### Utils

In [None]:
def disp(img):
    """
        Show `img` as a grayscale picture with the axes hidden.

        A 3-D input has its leading dimension removed with `squeeze(0)` before plotting.
    """
    image = img.squeeze(0) if len(img.shape) == 3 else img
    plt.imshow(image, cmap = "gray")
    plt.axis("off")
    plt.tight_layout()
    plt.show()

In [None]:
def dispSeq(seq, rows = 8):
    """ Renders a pixel sequence as an image by folding it into `rows` rows and passing it to `disp` """
    # The column count is inferred from the sequence length
    disp(seq.reshape(rows, -1))

In [None]:
def countParameters(model):
    """ Returns a sentence reporting how many parameters of `model` are trainable and how many are frozen """
    trainable = 0
    frozen = 0
    # Single pass over the parameters, bucketed by `requires_grad`
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
        else:
            frozen += param.numel()
    return f"The model has {trainable:,} trainable parameters and {frozen:,} frozen parameters"

### Dataset

In [None]:
class psMNIST(Dataset):
    """
        Permuted sequential MNIST: wraps an MNIST-style dataset and serves each image
        as a pixel sequence reordered by one fixed permutation.
    """

    def __init__(self, mnist, perm):
        # `mnist` is indexed as mnist[idx] -> (image, label); `perm` indexes the flattened image
        self.mnist = mnist
        self.perm  = perm

    def __len__(self):
        # Same number of samples as the wrapped dataset
        return len(self.mnist)

    def __getitem__(self, idx):
        image, target = self.mnist[idx]
        # Flatten, reorder the pixels, then present them as a column: [n_pixels, 1]
        sequence = image.reshape(-1)[self.perm].reshape(-1, 1)
        return sequence, target

### Model

In [None]:
def leCunUniform(tensor):
    """ 
        LeCun uniform initializer: fills `tensor` in place with samples from
        U(-bound, bound), where bound = sqrt(3 / fan_in). Returns nothing.

        References: 
        [1] https://keras.rstudio.com/reference/initializer_lecun_uniform.html
        [2] Source code of _calculate_correct_fan can be found in https://pytorch.org/docs/stable/_modules/torch/nn/init.html
        [3] Yann A LeCun, Léon Bottou, Genevieve B Orr, and Klaus-Robert Müller. Efficient backprop. In Neural networks: Tricks of the trade, pages 9–48. Springer, 2012
    """
    n_inputs = init._calculate_correct_fan(tensor, "fan_in")
    bound = np.sqrt(3. / n_inputs)
    init.uniform_(tensor, a = -bound, b = bound)

In [None]:
class LMUFFT(nn.Module):
    """
        Parallel Legendre Memory Unit layer: the memory is computed for all time steps
        at once as a convolution of the encoded input with a precomputed impulse
        response, evaluated with FFTs.

        Parameters:
            input_size (int):
                Number of features of each input time step
            hidden_size (int):
                Number of features of each output (hidden) time step
            memory_size (int):
                Dimension of the memory (order of the state space system)
            seq_len (int):
                Maximum sequence length; the impulse response is precomputed for this many steps
            theta (float):
                Length of the sliding window used to build the state space matrices
    """

    def __init__(self, input_size, hidden_size, memory_size, seq_len, theta):

        super(LMUFFT, self).__init__()

        self.hidden_size = hidden_size
        self.memory_size = memory_size
        self.seq_len = seq_len
        self.theta = theta

        self.W_u = nn.Linear(in_features = input_size, out_features = 1)
        self.f_u = nn.ReLU()
        self.W_h = nn.Linear(in_features = memory_size + input_size, out_features = hidden_size)
        self.f_h = nn.ReLU()

        # Fixed (non-trainable) matrices are buffers so they follow the module across devices
        A, B = self.stateSpaceMatrices()
        self.register_buffer("A", A) # [memory_size, memory_size]
        self.register_buffer("B", B) # [memory_size, 1]

        H, fft_H = self.impulse()
        self.register_buffer("H", H) # [memory_size, seq_len]
        self.register_buffer("fft_H", fft_H) # [memory_size, seq_len + 1]


    def stateSpaceMatrices(self):
        """ Returns the discretized state space matrices A and B """

        Q = np.arange(self.memory_size, dtype = np.float64).reshape(-1, 1)
        R = (2*Q + 1) / self.theta
        i, j = np.meshgrid(Q, Q, indexing = "ij")

        # Continuous
        A = R * np.where(i < j, -1, (-1.0)**(i - j + 1))
        B = R * ((-1.0)**Q)
        C = np.ones((1, self.memory_size))
        D = np.zeros((1,))

        # Convert to discrete (zero-order hold, unit time step); only A and B are needed
        A, B, *_ = cont2discrete(
            system = (A, B, C, D), 
            dt = 1.0, 
            method = "zoh"
        )

        # To torch.tensor
        A = torch.from_numpy(A).float() # [memory_size, memory_size]
        B = torch.from_numpy(B).float() # [memory_size, 1]
        
        return A, B


    def impulse(self):
        """ Returns the matrices H and the 1D Fourier transform of H (Equations 23, 26 of the paper) """

        H = []
        # Identity built with A's dtype/device so this also works after the module has been moved
        A_i = torch.eye(self.memory_size, dtype = self.A.dtype, device = self.A.device)
        for _ in range(self.seq_len):
            H.append(A_i @ self.B) # column t is A^t @ B
            A_i = self.A @ A_i

        H = torch.cat(H, dim = -1) # [memory_size, seq_len]
        # Zero-padded to 2*seq_len so the FFT product is a linear (not circular) convolution
        fft_H = fft.rfft(H, n = 2*self.seq_len, dim = -1) # [memory_size, seq_len + 1]

        return H, fft_H


    def forward(self, x):
        """
        Parameters:
            x (torch.tensor): 
                Input of size [batch_size, seq_len, input_size], where seq_len may be
                any length up to the `seq_len` given to the constructor

        Returns:
            h (torch.tensor):
                Hidden states for all time steps, [batch_size, seq_len, hidden_size]
            h_n (torch.tensor):
                Hidden state of the last time step, [batch_size, hidden_size]

        Raises:
            ValueError: if the input is longer than the precomputed impulse response
        """

        _, seq_len, _ = x.shape
        if seq_len > self.seq_len:
            raise ValueError(
                f"Input sequence length {seq_len} exceeds the maximum length {self.seq_len} this layer was built for"
            )

        # Always use the FFT length of the precomputed fft_H; shorter inputs are zero-padded,
        # which leaves the first seq_len outputs of the convolution unchanged
        fft_len = 2*self.seq_len

        # Equation 18 of the paper
        u = self.f_u(self.W_u(x)) # [batch_size, seq_len, 1]

        # Equation 26 of the paper
        fft_input = u.permute(0, 2, 1) # [batch_size, 1, seq_len]
        fft_u = fft.rfft(fft_input, n = fft_len, dim = -1) # [batch_size, 1, self.seq_len+1]

        # Element-wise multiplication (uses broadcasting)
        # [batch_size, 1, self.seq_len+1] * [1, memory_size, self.seq_len+1]
        temp = fft_u * self.fft_H.unsqueeze(0) # [batch_size, memory_size, self.seq_len+1]

        m = fft.irfft(temp, n = fft_len, dim = -1) # [batch_size, memory_size, 2*self.seq_len]
        m = m[:, :, :seq_len] # [batch_size, memory_size, seq_len]
        m = m.permute(0, 2, 1) # [batch_size, seq_len, memory_size]

        # Equation 20 of the paper (W_m@m + W_x@x  W@[m;x])
        input_h = torch.cat((m, x), dim = -1) # [batch_size, seq_len, memory_size + input_size]
        h = self.f_h(self.W_h(input_h)) # [batch_size, seq_len, hidden_size]

        h_n = h[:, -1, :] # [batch_size, hidden_size]

        return h, h_n

In [None]:
class SpikingLMUFFT(LMUFFT):
    """
        Spiking variant of LMUFFT: the two ReLU activations of the parent layer are
        replaced by multi-step LIF neurons; the FFT-based memory computation is the same.
    """

    def __init__(self, input_size, hidden_size, memory_size, seq_len, theta):

        super().__init__(input_size, hidden_size, memory_size, seq_len, theta)

        # Override the activations created by the parent constructor
        self.f_u = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')
        self.f_h = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

    def forward(self, x):
        """
        Parameters:
            x (torch.tensor): 
                Input of size [batch_size, seq_len, input_size]

        Returns:
            h (torch.tensor):
                Output of the hidden LIF neurons for all time steps, [batch_size, seq_len, hidden_size]
            h_n (torch.tensor):
                Output at the last time step, [batch_size, hidden_size]
        """

        _, seq_len, _ = x.shape
        fft_len = 2*seq_len

        # Equation 18 of the paper
        # The LIF node is fed the sequence dimension first: [B,N,1] -> [N,B,1]
        u_pre = self.W_u(x).transpose(0, 1).contiguous()
        u = self.f_u(u_pre).transpose(0, 1).contiguous() # back to [batch_size, seq_len, 1]

        # Equation 26 of the paper
        fft_u = fft.rfft(u.permute(0, 2, 1), n = fft_len, dim = -1) # [batch_size, 1, seq_len+1]

        # Element-wise multiplication (uses broadcasting)
        # [batch_size, 1, seq_len+1] * [1, memory_size, seq_len+1]
        spectrum = fft_u * self.fft_H.unsqueeze(0) # [batch_size, memory_size, seq_len+1]

        m = fft.irfft(spectrum, n = fft_len, dim = -1) # [batch_size, memory_size, 2*seq_len]
        m = m[:, :, :seq_len].permute(0, 2, 1) # [batch_size, seq_len, memory_size]

        # Equation 20 of the paper (W_m@m + W_x@x  W@[m;x])
        input_h = torch.cat((m, x), dim = -1) # [batch_size, seq_len, memory_size + input_size]

        # Same sequence-first layout for the hidden LIF node: [B,N,C] -> [N,B,C] -> [B,N,C]
        h_pre = self.W_h(input_h).transpose(0, 1).contiguous()
        h = self.f_h(h_pre).transpose(0, 1).contiguous() # [batch_size, seq_len, hidden_size]
        h_n = h[:, -1, :] # [batch_size, hidden_size]

        return h, h_n

### Training

In [None]:
def train(model, loader, optimizer, criterion):
    """
        Runs one training epoch over `loader`.

        Returns the loss averaged over batches and the accuracy over all samples seen.
    """

    running_loss = 0
    predictions, targets = [], []

    model.train()
    for inputs, labels in loader:

        torch.cuda.empty_cache()

        inputs = inputs.to(DEVICE)
        labels = labels.long().to(DEVICE)

        # Standard optimisation step
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Book-keeping for the epoch metrics
        predictions.extend(logits.argmax(dim = 1).tolist())
        targets.extend(labels.tolist())
        running_loss += loss.item()

        # Reset the network state before the next batch
        functional.reset_net(model)

    mean_loss = running_loss / len(loader)
    accuracy = accuracy_score(targets, predictions)

    return mean_loss, accuracy

In [None]:
def validate(model, loader, criterion):
    """
        A single validation epoch on the psMNIST data.

        Runs `model` in eval mode without gradients over every batch of `loader` and
        returns the loss averaged over batches and the accuracy over all samples.
    """

    epoch_loss = 0
    y_pred = []
    y_true = []
    
    model.eval()
    with torch.no_grad():
        for batch, labels in loader:

            torch.cuda.empty_cache()

            batch = batch.to(DEVICE)
            labels = labels.long().to(DEVICE)

            output = model(batch)
            loss = criterion(output, labels)
            
            preds  = output.argmax(dim = 1)
            y_pred += preds.tolist()
            y_true += labels.tolist()
            epoch_loss += loss.item()

            # Reset the network state after every batch, exactly as `train` does, so that
            # state kept by spiking neurons does not leak into the next batch or the next
            # training epoch
            functional.reset_net(model)
            
    # Loss
    avg_epoch_loss = epoch_loss / len(loader)

    # Accuracy
    epoch_acc = accuracy_score(y_true, y_pred)

    return avg_epoch_loss, epoch_acc

### Non-Spiking LMUFormer

In [None]:
# Hyperparameters of the non-spiking experiment
N_x = 1 # dimension of the input, a single pixel
N_t = 784 # sequence length; also the size of the pixel permutation
N_h = 346 # dimension of the hidden state
N_m = 468 # dimension of the memory
N_c = 10 # number of classes
THETA = 784 # passed to the model as `theta`
N_b = 100 # batch size
N_epochs = 50 # number of training epochs

In [None]:
# MNIST images as tensors; the test split is used for validation
transform = transforms.ToTensor()
mnist_train = datasets.MNIST("data", train = True, download = True, transform = transform)
mnist_val   = datasets.MNIST("data", train = False, download = True, transform = transform)

# Reuse the saved permutation when there is one, so every run and every experiment
# sees the same pixel ordering; create and save it only when missing or of the wrong length.
try:
    perm = torch.load("permutation.pt").long()
except FileNotFoundError:
    perm = None
if perm is None or perm.numel() != N_t:
    perm = torch.randperm(N_t)
    torch.save(perm, "permutation.pt")

ds_train = psMNIST(mnist_train, perm)
ds_val   = psMNIST(mnist_val, perm)

dl_train = DataLoader(ds_train, batch_size = N_b, shuffle = True, num_workers = 2, pin_memory = True)
# NOTE(review): the validation loader is shuffled as well; consider shuffle = False
dl_val   = DataLoader(ds_val, batch_size = N_b, shuffle = True, num_workers = 2, pin_memory = True)

In [None]:
# Example of the data
# First training sample: the permuted pixel sequence and its label
eg_img, eg_label = ds_train[0]
print("Label:", eg_label)
# Shown folded into rows (see dispSeq)
dispSeq(eg_img)

In [None]:
class Model(nn.Module):
    """
        A simple model for the psMNIST dataset: a pointwise convolution with batch
        normalization, a single LMUFFT layer, mean pooling over time, dropout and a
        dense classifier.

        Parameters:
            input_size (int):
                Number of features of each input time step
            output_size (int):
                Number of classes
            hidden_size (int):
                Hidden dimension of the LMUFFT layer
            memory_size (int):
                Memory dimension of the LMUFFT layer
            seq_len (int):
                Sequence length the LMUFFT layer is built for
            theta (float):
                Window length of the LMUFFT layer
            conv_channels (int):
                Number of channels produced by the pointwise convolution and fed to the LMUFFT layer
    """

    def __init__(self, input_size, output_size, hidden_size, memory_size, seq_len, theta, conv_channels = 3):
        super(Model, self).__init__()
        # `input_size` now sets the convolution's input channels instead of being overwritten
        self.conv1 = nn.Conv1d(input_size, conv_channels, kernel_size = 1, stride = 1)
        self.bn1 = nn.BatchNorm1d(conv_channels)
        # NOTE(review): `act` is created but never applied in forward(); kept so the module
        # structure is unchanged - confirm whether it was meant to follow `bn1`
        self.act = nn.ReLU()
        self.lmu_fft = LMUFFT(conv_channels, hidden_size, memory_size, seq_len, theta)
        self.dropout1 = nn.Dropout(p = 0.4)
        self.classifier = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """
        Parameters:
            x (torch.tensor):
                Input of size [batch_size, seq_len, input_size]

        Returns:
            Class scores of size [batch_size, output_size]
        """
        # Conv1d / BatchNorm1d work channels-first: [B, N, C] -> [B, C, N] and back
        x = self.conv1(x.permute(0, 2, 1).contiguous())
        x = self.bn1(x).permute(0, 2, 1).contiguous()
        h, _ = self.lmu_fft(x) # h: [batch_size, seq_len, hidden_size]
        x = h.mean(dim=1) # average over time: [batch_size, hidden_size]
        x = self.dropout1(x)
        output = self.classifier(x)
        return output # [batch_size, output_size]

In [None]:
# Build the non-spiking model from the hyperparameters above and move it to DEVICE
model = Model(
    input_size = N_x, 
    output_size = N_c, 
    hidden_size = N_h, 
    memory_size = N_m, 
    seq_len = N_t, 
    theta = THETA
)
model = model.to(DEVICE)

In [None]:
# Module structure, followed by the parameter-count sentence
# (countParameters returns a string; it is the last expression of the cell)
print(model)
countParameters(model)

In [None]:
# Adam over all model parameters; no other hyperparameters are passed
optimizer = optim.Adam(params = model.parameters())

In [None]:
# Cross-entropy loss on the classifier output, moved to the same device as the model
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(DEVICE)

In [None]:
# Per-epoch history, used for the learning curves
train_losses = []
train_accs = []
val_losses = []
val_accs = []

for epoch in range(N_epochs):

    print(f"Epoch: {epoch+1:02}/{N_epochs:02}")

    # One pass over the training data, then one pass over the validation data
    train_loss, train_acc = train(model, dl_train, optimizer, criterion)
    val_loss, val_acc = validate(model, dl_val, criterion)

    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    # Accuracies are fractions; shown as percentages
    print(f"Train Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%")
    print(f"Val. Loss: {val_loss:.3f} |  Val. Acc: {val_acc*100:.2f}%")
    print()

In [None]:
# Learning curves
# NOTE: the x-axis is range(N_epochs), so the history lists must hold exactly N_epochs entries

# Loss per epoch
plt.plot(range(N_epochs), train_losses)
plt.plot(range(N_epochs), val_losses)
plt.ylabel("Loss")
plt.xlabel("Epochs")
plt.legend(["Train", "Val."])
plt.show()

# Accuracy per epoch
plt.plot(range(N_epochs), train_accs)
plt.plot(range(N_epochs), val_accs)
plt.ylabel("Accuracy")
plt.xlabel("Epochs")
plt.legend(["Train", "Val."])
plt.show()

### Spiking LMUFormer

In [None]:
# Hyperparameters of the spiking experiment
N_x = 1 # dimension of the input, a single pixel
N_t = 784 # sequence length; also the size of the pixel permutation
N_h = 346 # dimension of the hidden state
N_m = 468 # dimension of the memory
N_c = 10 # number of classes
THETA = 784 # passed to the model as `theta`
N_b = 100 # batch size
N_epochs = 50 # number of training epochs

In [None]:
# MNIST images as tensors; the test split is used for validation
transform = transforms.ToTensor()
mnist_train = datasets.MNIST("data", train = True, download = True, transform = transform)
mnist_val   = datasets.MNIST("data", train = False, download = True, transform = transform)

# Reuse the saved permutation when there is one, so every run and every experiment
# sees the same pixel ordering; create and save it only when missing or of the wrong length.
try:
    perm = torch.load("permutation.pt").long()
except FileNotFoundError:
    perm = None
if perm is None or perm.numel() != N_t:
    perm = torch.randperm(N_t)
    torch.save(perm, "permutation.pt")

ds_train = psMNIST(mnist_train, perm)
ds_val   = psMNIST(mnist_val, perm)

dl_train = DataLoader(ds_train, batch_size = N_b, shuffle = True, num_workers = 2, pin_memory = True)
# NOTE(review): the validation loader is shuffled as well; consider shuffle = False
dl_val   = DataLoader(ds_val, batch_size = N_b, shuffle = True, num_workers = 2, pin_memory = True)

In [None]:
# Example of the data
# First training sample: the permuted pixel sequence and its label
eg_img, eg_label = ds_train[0]
print("Label:", eg_label)
# Shown folded into rows (see dispSeq)
dispSeq(eg_img)

In [None]:
class Model(nn.Module):
    """ Spiking psMNIST classifier: a SpikingLMUFFT layer whose last-step output goes through dropout and a dense classifier """

    def __init__(self, input_size, output_size, hidden_size, memory_size, seq_len, theta):
        super().__init__()
        self.lmu_fft = SpikingLMUFFT(input_size, hidden_size, memory_size, seq_len, theta)
        self.dropout = nn.Dropout(p = 0.5)
        self.classifier = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Only the output of the final time step is used: [batch_size, hidden_size]
        last_step = self.lmu_fft(x)[1]
        logits = self.classifier(self.dropout(last_step))
        return logits # [batch_size, output_size]
    

In [None]:
# Build the spiking model from the hyperparameters above and move it to DEVICE
model = Model(
    input_size = N_x, 
    output_size = N_c, 
    hidden_size = N_h, 
    memory_size = N_m, 
    seq_len = N_t, 
    theta = THETA
)
model = model.to(DEVICE)

In [None]:
# Module structure, followed by the parameter-count sentence
# (countParameters returns a string; it is the last expression of the cell)
print(model)
countParameters(model)

In [None]:
# Adam over all model parameters; no other hyperparameters are passed
optimizer = optim.Adam(params = model.parameters())

In [None]:
# Cross-entropy loss on the classifier output, moved to the same device as the model
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(DEVICE)

In [None]:
# Per-epoch history, used for the learning curves
train_losses = []
train_accs = []
val_losses = []
val_accs = []

for epoch in range(N_epochs):

    print(f"Epoch: {epoch+1:02}/{N_epochs:02}")

    # One pass over the training data, then one pass over the validation data
    train_loss, train_acc = train(model, dl_train, optimizer, criterion)
    val_loss, val_acc = validate(model, dl_val, criterion)

    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    # Accuracies are fractions; shown as percentages
    print(f"Train Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%")
    print(f"Val. Loss: {val_loss:.3f} |  Val. Acc: {val_acc*100:.2f}%")
    print()

In [None]:
# Learning curves
# NOTE: the x-axis is range(N_epochs), so the history lists must hold exactly N_epochs entries

# Loss per epoch
plt.plot(range(N_epochs), train_losses)
plt.plot(range(N_epochs), val_losses)
plt.ylabel("Loss")
plt.xlabel("Epochs")
plt.legend(["Train", "Val."])
plt.show()

# Accuracy per epoch
plt.plot(range(N_epochs), train_accs)
plt.plot(range(N_epochs), val_accs)
plt.ylabel("Accuracy")
plt.xlabel("Epochs")
plt.legend(["Train", "Val."])
plt.show()