# L15a: What comes after Transformers and LLMs?
In this lecture, we'll speculate (wildly) about what might come after Transformers and Large Language Models (LLMs). For context, see the conversation [between Yann LeCun and Bill Dally | NVIDIA GTC 2025](https://www.youtube.com/watch?v=eyrDM3A_YFc).

Transformers have been a huge success, but they are not the end of the line. Many other architectures and techniques could be used to build even more powerful models. Let's explore a few of these possibilities; for a more in-depth discussion, see the review paper [Schneider, J. (2024). What comes after transformers? - A selective survey connecting ideas in deep learning. ArXiv, abs/2408.00386](https://arxiv.org/abs/2305.13936).

The Schneider paper reviews a number of different approaches that have been proposed as alternatives to Transformers, and far-out ideas for future architectures. In this lecture we'll explore a few of these ideas:

* __State Space Models (SSMs)__ are an emerging alternative to transformers for sequence modeling, using a fixed-size latent state that enables efficient processing of extremely long inputs, such as entire books or audio streams, without the quadratic computational cost of attention mechanisms. While SSMs like Mamba can match or even outperform transformers at small to medium scale, they are generally less effective than transformers at tasks requiring selective attention or copying from specific parts of the input, due to their reliance on compressing information into a fixed-size state rather than dynamically attending to all previous tokens.
* __Spiking Neural Networks (SNNs)__ are brain-inspired models that process information using discrete spikes over time, offering energy-efficient and biologically plausible alternatives to transformers, especially for tasks with strong temporal dynamics. By leveraging event-driven computation and sparse communication, SNNs can achieve high computational efficiency and are particularly well-suited for deployment on [neuromorphic hardware](https://en.wikipedia.org/wiki/Neuromorphic_computing), addressing some of the limitations of transformer architectures in terms of power consumption and real-time processing.

The sources used to prepare this lecture are:
* [Schneider, J. (2024). What comes after transformers? - A selective survey connecting ideas in deep learning. ArXiv, abs/2408.00386](https://arxiv.org/abs/2305.13936)
* [Smith, J., Warrington, A., & Linderman, S.W. (2022). Simplified State Space Layers for Sequence Modeling. ArXiv, abs/2208.04933.](https://arxiv.org/abs/2208.04933)
* [Limbacher, T., Özdenizci, O., & Legenstein, R.A. (2022). Memory-enriched computation and learning in spiking neural networks through Hebbian plasticity. ArXiv, abs/2205.11276.](https://arxiv.org/abs/2205.11276)

___

In [2]:
include("Include.jl"); # we'll need a couple of packages later

## Review: Linear Time Invariant State Space Models
Linear time invariant (LTI) state space models are a class of _continuous-time_ models that can represent a system's dynamics over time. The following equations characterize them:
$$
\begin{align*}
\dot{\mathbf{x}} &= \mathbf{A} \mathbf{x} + \mathbf{B} \mathbf{u} \\
\mathbf{y} &= \mathbf{C} \mathbf{x} + \mathbf{D} \mathbf{u}
\end{align*}
$$
where $\mathbf{x}\in\mathbb{R}^{h}$ is an $h$-dimensional state vector, $\mathbf{u}\in\mathbb{R}^{d_{in}}$ is the $d_{in}$ dimensional input vector, $\mathbf{y}\in\mathbb{R}^{d_{out}}$ is the $d_{out}$ dimensional output vector. The LTI system is defined by the system matrices (and the initial state and input):
* The $\mathbf{A}\in\mathbb{R}^{h\times{h}}$ matrix is the state transition matrix, which describes how the state depends upon itself over time.
* The $\mathbf{B}\in\mathbb{R}^{h\times{d_{in}}}$ matrix is the input matrix, which describes how the input vector affects the state.
* The $\mathbf{C}\in\mathbb{R}^{d_{out}\times{h}}$ matrix is the output matrix, which describes how the state affects the output vector.
* The $\mathbf{D}\in\mathbb{R}^{d_{out}\times{d_{in}}}$ matrix is the feedforward matrix, which describes how the input vector affects the output vector.

Linear time-invariant state space models have been widely used in control theory, signal processing, and other fields. You may be familiar with these models from your automatic control class, where they are used to model system dynamics. In this lecture, we will focus on the discrete-time version of these models, which is the form most often used in machine learning and signal processing applications.
* __Single Input Single Output (SISO)__: The simplest case of a linear time invariant state space model is the single input single output (SISO) case, where there is one input $d_{in} = 1$ and one output $d_{out} = 1$ _per time step_. In this case, the system can be represented by a single transfer function, which describes the relationship between the input and output.
* __Multiple Input Multiple Output (MIMO)__: In the multiple input multiple output (MIMO) case, there are multiple inputs and multiple outputs. In this case, the system can be represented by a matrix of transfer functions, which describes the relationship between the inputs and outputs.

### Discretization
Whether the model is SISO or MIMO, our inputs arrive as a sequence of discrete time steps, so we discretize the continuous-time state space model and use the discrete-time hidden state in all calculations.
The discrete-time state space model is given by:
$$
\begin{align*}
\mathbf{x}_{t} &= \mathbf{\bar{A}} \mathbf{x}_{t-1} + \mathbf{\bar{B}} \mathbf{u}_{t} \\
\mathbf{y}_{t} &= \mathbf{\bar{C}} \mathbf{x}_{t} + \mathbf{\bar{D}} \mathbf{u}_{t}
\end{align*}
$$
where $\mathbf{x}_{t}$ is the hidden state at time $t$, $\mathbf{u}_{t}$ is the input at time $t$, and $\mathbf{y}_{t}$ is the output at time $t$. The discretized matrices $\mathbf{\bar{A}}$, $\mathbf{\bar{B}}$, and $\mathbf{\bar{C}}$ can be obtained from a variety of methods, such as the bilinear method:
$$
\begin{align*}
\mathbf{\bar{A}} &= \left(\mathbf{I}-\left(\Delta/2\right)\cdot\mathbf{A}\right)^{-1}\left(\mathbf{I}+\left(\Delta/2\right)\cdot\mathbf{A}\right) \\
\mathbf{\bar{B}} &= \left(\mathbf{I}-\left(\Delta/2\right)\cdot\mathbf{A}\right)^{-1}\left(\Delta\cdot\mathbf{B}\right) \\
\mathbf{\bar{C}} &= \mathbf{C}
\end{align*}
$$
where $\Delta$ is the time step size (the sampling period, i.e., the inverse of the sampling frequency), and $\mathbf{I}$ is the identity matrix. The bilinear method is a standard method for discretizing continuous-time state-space models, and it is used in many applications.
* _Simplification_: In most applications, we set $\mathbf{D} = 0$, which means that the output depends only on the hidden state and not directly on the input, thus $\mathbf{\bar{D}} = 0$. Otherwise, we set $\mathbf{\bar{D}} = \mathbf{D}$, so that the output depends on both the hidden state and the input.
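
The bilinear formulas above can be sketched in a few lines of Julia. This is a minimal sketch, not code from `Include.jl`: the helper name `discretize_bilinear` and the toy system matrices are assumptions made up for illustration. We solve linear systems with `\` rather than forming the inverse explicitly.

```julia
using LinearAlgebra

# Bilinear discretization of a continuous-time LTI model (hypothetical helper).
function discretize_bilinear(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, Δ::Real)
    M = I - (Δ / 2) * A              # (I - Δ/2 ⋅ A)
    Abar = M \ (I + (Δ / 2) * A)     # (I - Δ/2 ⋅ A)⁻¹ (I + Δ/2 ⋅ A)
    Bbar = M \ (Δ * B)               # (I - Δ/2 ⋅ A)⁻¹ (Δ ⋅ B)
    Cbar = copy(C)                   # the output matrix is unchanged
    return Abar, Bbar, Cbar
end

# toy 2-state SISO system (made-up numbers, eigenvalues of A are -1 and -2)
A = [-1.0 0.0; 1.0 -2.0]
B = reshape([1.0, 0.5], 2, 1)
C = [1.0 1.0]

Abar, Bbar, Cbar = discretize_bilinear(A, B, C, 0.1)
```

For this toy system the continuous-time eigenvalues have negative real parts, and the bilinear map sends them inside the unit circle, so the discrete-time recursion is stable as well.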

### SISO Leg-S HiPPO matrices
The Leg-S HiPPO matrices are a specific type of structured state space model designed to capture long-range dependencies in sequential data efficiently. The Leg-S approach is based on [Legendre polynomials](https://en.wikipedia.org/wiki/Legendre_polynomials), which are a set of _orthogonal polynomials_ that can be used to represent functions over the finite interval $[-1,1]$. The Leg-S HiPPO state transition matrix $a_{ik}\in\mathbf{A}$ for a `SISO` problem is constructed as:
$$
\begin{align*}
a_{ik} &= -\begin{cases}
    \left(2i+1\right)^{1/2}\left(2k+1\right)^{1/2} & \text{if } i>k \\
    \left(i+1\right) & \text{if } i=k \\
    0 & \text{if } i<k \\
\end{cases}
\end{align*}
$$
and the $b_{i}\in\mathbf{B}$ input matrix is constructed as:
$$
\begin{align*}
b_{i} &= \left(2i+1\right)^{1/2} \\
\end{align*}
$$

The Leg-S HiPPO matrices are designed to capture long-range dependencies in sequential data efficiently. However, the matrix $\mathbf{A}$ cannot be diagonalized in a numerically stable way, which leads to some complications and computational overhead. We'll discuss this in more detail when we get to the S5 model later in this lecture.

## Two views of the same state space model
Linear state space models, either SISO or MIMO, can be operated in two ways: (i) they can process one input token at a time, or (ii) they can process all input tokens at once. The first approach is called _sequential operation_, and the second approach is called _convolutional operation_.

### Sequential operation
Imagine that we have [a queue of input tokens $x\in\mathcal{Q}$](https://en.wikipedia.org/wiki/Queue_(abstract_data_type)) that we want to process one at a time. We can use a linear state space model to process each input token in the queue sequentially, and then place the output token in a corresponding [output queue $\mathcal{O}$](https://en.wikipedia.org/wiki/Queue_(abstract_data_type)). The time steps of the input and output queues are aligned, so that the output token at time $t$ corresponds to the input token at time $t$. Let's look at a simple algorithm for sequential processing of input tokens:

__Initialization__: The user provides the $\mathbf{\bar{A}}$, $\mathbf{\bar{B}}$, $\mathbf{\bar{C}}$, and $\mathbf{\bar{D}}$ matrices. Initialize the input queue $\mathcal{Q}$, the output queue $\mathcal{O}$, and the hidden state storage $\mathcal{H}$. Set the initial hidden state $\mathbf{x}_{0} = \mathbf{0}$, store it $\mathbf{x}_{0}\rightarrow\mathcal{H}$, and set $t = 1$.

While $\mathcal{Q}$ is not empty:
1. Read the input token $x\gets\mathcal{Q}$ and set $\mathbf{u}_{t} = x$.
2. Get the _previous_ hidden state $\mathbf{x}_{t-1}\gets\mathcal{H}$, and compute the _next_ hidden state $\mathbf{x}_{t} = \mathbf{\bar{A}} \mathbf{x}_{t-1} + \mathbf{\bar{B}} \mathbf{u}_{t}$.
3. Compute the _next_ output token $\mathbf{y}_{t} = \mathbf{\bar{C}} \mathbf{x}_{t} + \mathbf{\bar{D}} \mathbf{u}_{t}$.
4. Write the _next_ output token $\mathbf{y}_{t}$ to the output queue $\mathbf{y}_{t}\rightarrow\mathcal{O}$.
5. Store the new hidden state $\mathbf{x}_{t}\rightarrow\mathcal{H}$, and increment the time index $t \gets t + 1$.
6. If the input queue $\mathcal{Q}$ is __not empty__ continue, otherwise __stop__.
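
The steps above can be sketched as a short Julia function. This is a minimal sketch under a few assumptions: the name `run_sequential` is hypothetical, the queues $\mathcal{Q}$, $\mathcal{O}$ and the storage $\mathcal{H}$ are plain `Vector`s that we iterate over (instead of dequeuing), and the discrete-time matrices are made-up toy values.

```julia
# Sequential operation: one token in, one token out (hypothetical helper).
function run_sequential(Abar, Bbar, Cbar, Dbar, Q::Vector{<:AbstractVector})
    h = size(Abar, 1)
    H = [zeros(h)]                       # hidden state storage, starts with x₀ = 0
    O = Vector{Vector{Float64}}()        # output queue
    for u in Q                           # step 1: read the tokens in arrival order
        xprev = H[end]                   # step 2: previous hidden state ...
        x = Abar * xprev + Bbar * u      #         ... and the next hidden state
        y = Cbar * x + Dbar * u          # step 3: next output token
        push!(O, y)                      # step 4: write to the output queue
        push!(H, x)                      # step 5: store the new hidden state
    end                                  # step 6: stop when the queue is exhausted
    return O, H
end

# toy SISO example (made-up numbers): feed a unit impulse followed by zeros
Abar = [0.9 0.0; 0.1 0.8]
Bbar = reshape([1.0, 0.5], 2, 1)
Cbar = [1.0 1.0]
Dbar = zeros(1, 1)
Q = [[1.0], [0.0], [0.0]]

O, H = run_sequential(Abar, Bbar, Cbar, Dbar, Q)
```

With an impulse input and $\mathbf{\bar{D}} = 0$, the outputs are $\mathbf{\bar{C}}\mathbf{\bar{B}}$, $\mathbf{\bar{C}}\mathbf{\bar{A}}\mathbf{\bar{B}}$, $\mathbf{\bar{C}}\mathbf{\bar{A}}^{2}\mathbf{\bar{B}}$, which is a handy sanity check for the model matrices.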

### Convolutional operation
In the convolutional operation, we process all input tokens $\mathbf{u}_{0},\mathbf{u}_{1},\dots,\mathbf{u}_{t}$ at once to produce the output tokens $\mathbf{y}_{0},\mathbf{y}_{1},\dots,\mathbf{y}_{t}$. Let's unroll each time step of the sequential operation, indexing time from zero with the initial state $\mathbf{x}_{-1} = \mathbf{0}$, and $\mathbf{\bar{D}} = 0$. For time steps $0, 1, 2, \dots, t$ we have the output tokens:
$$
\begin{align*}
\mathbf{y}_{0} & = \mathbf{\bar{C}}\mathbf{\bar{B}}\mathbf{u}_{0}\quad |~\textit{substitute}\quad\mathbf{x}_{0} = \mathbf{\bar{B}}\mathbf{u}_{0} \\
\mathbf{y}_{1} & = \mathbf{\bar{C}}\mathbf{\bar{A}}\mathbf{\bar{B}}\mathbf{u}_{0} + \mathbf{\bar{C}}\mathbf{\bar{B}}\mathbf{u}_{1}\quad |~\textit{substitute}\quad\mathbf{x}_{1} = \mathbf{\bar{A}}\mathbf{\bar{B}}\mathbf{u}_{0}+\mathbf{\bar{B}}\mathbf{u}_{1}\\
\mathbf{y}_{2} & = \mathbf{\bar{C}}\mathbf{\bar{A}}^{2}\mathbf{\bar{B}}\mathbf{u}_{0} + \mathbf{\bar{C}}\mathbf{\bar{A}}\mathbf{\bar{B}}\mathbf{u}_{1} + \mathbf{\bar{C}}\mathbf{\bar{B}}\mathbf{u}_{2}\quad |~\textit{substitute}\quad\mathbf{x}_{2} = \mathbf{\bar{A}}^{2}\mathbf{\bar{B}}\mathbf{u}_{0} + \mathbf{\bar{A}}\mathbf{\bar{B}}\mathbf{u}_{1} + \mathbf{\bar{B}}\mathbf{u}_{2} \\
\vdots & \\
\mathbf{y}_{t} & = \mathbf{\bar{C}}\mathbf{\bar{A}}^{t}\mathbf{\bar{B}}\mathbf{u}_{0} + \mathbf{\bar{C}}\mathbf{\bar{A}}^{t-1}\mathbf{\bar{B}}\mathbf{u}_{1} + \mathbf{\bar{C}}\mathbf{\bar{A}}^{t-2}\mathbf{\bar{B}}\mathbf{u}_{2} + \cdots + \mathbf{\bar{C}}\mathbf{\bar{B}}\mathbf{u}_{t}\quad\blacksquare
\end{align*}
$$
which we can rewrite as the convolution operation:
$$
\begin{align*}
\mathbf{y}_{t} & = \sum_{i=0}^{t}\mathbf{\bar{C}}\mathbf{\bar{A}}^{t-i}\mathbf{\bar{B}}\mathbf{u}_{i} \\
& = \mathbf{\bar{C}}\sum_{i=0}^{t}\mathbf{\bar{A}}^{t-i}\mathbf{\bar{B}}\mathbf{u}_{i} \\
& = \mathbf{\bar{C}}\sum_{i=0}^{t}\mathbf{\bar{A}}^{i}\mathbf{\bar{B}}\mathbf{u}_{t-i}\qquad\blacksquare\\
\end{align*}
$$
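
To see that the two views agree, here is a minimal SISO sketch that builds the kernel $\mathbf{\bar{C}}\mathbf{\bar{A}}^{i}\mathbf{\bar{B}}$, applies the causal sum above, and compares the result against the recurrence. The function names (`ssm_kernel`, `causal_conv`, `recurrent`) and the toy matrices are assumptions for illustration. The direct double loop costs $O(L^{2})$ for $L$ tokens; in practice, the convolution is typically evaluated with an FFT instead.

```julia
using LinearAlgebra

# Kernel entries K[i] = C̄ Ā^(i-1) B̄ for i = 1..L (arrays are 1-based).
function ssm_kernel(Abar, Bbar::Vector{Float64}, Cbar::Vector{Float64}, L::Int)
    K = zeros(L)
    v = copy(Bbar)                 # holds Ā^(i-1) B̄
    for i in 1:L
        K[i] = dot(Cbar, v)
        v = Abar * v
    end
    return K
end

# Causal convolution: y[t] = Σ_{i=1}^{t} K[i] * u[t-i+1].
function causal_conv(K::Vector{Float64}, u::Vector{Float64})
    y = zeros(length(u))
    for t in eachindex(u), i in 1:t
        y[t] += K[i] * u[t - i + 1]
    end
    return y
end

# Sequential (recurrent) evaluation of the same model, for comparison.
function recurrent(Abar, Bbar::Vector{Float64}, Cbar::Vector{Float64}, u::Vector{Float64})
    x = zeros(length(Bbar))
    y = zeros(length(u))
    for t in eachindex(u)
        x = Abar * x + Bbar * u[t]
        y[t] = dot(Cbar, x)
    end
    return y
end

# toy SISO model and input sequence (made-up numbers)
Abar = [0.9 0.0; 0.1 0.8]
Bbar = [1.0, 0.5]
Cbar = [1.0, 1.0]
u = [1.0, 2.0, -1.0, 0.5]

K = ssm_kernel(Abar, Bbar, Cbar, length(u))
y_conv = causal_conv(K, u)
y_rec = recurrent(Abar, Bbar, Cbar, u)
println("views agree: ", y_conv ≈ y_rec)
```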

## S5: Extension to MIMO Systems
The Simplified Structured State Space Sequence (S5) model is a generalization of the Leg-S HiPPO model to multiple input multiple output (MIMO) systems. See: [Smith, J., Warrington, A., & Linderman, S.W. (2022). Simplified State Space Layers for Sequence Modeling. ArXiv, abs/2208.04933.](https://arxiv.org/abs/2208.04933)

The S5 system is similar to the Leg-S HiPPO S4 system, but it uses different matrices for the state transition and input matrices. In particular, in the S5 system, we want the state transition matrix $\mathbf{A}$ to be a _diagonal matrix_. Let's try to diagonalize the Leg-S HiPPO matrix $\mathbf{A}$.

### Background: Diagonalization of a square matrix
A square matrix $\mathbf{A}\in\mathbb{R}^{n\times{n}}$ is said to be diagonalizable if there exists an invertible matrix $\mathbf{V}$ and a diagonal matrix $\mathbf{D}$ (not to be confused with the feedforward matrix of the state space model) such that:
$$
\mathbf{A} = \mathbf{V}\mathbf{D}\mathbf{V}^{-1}
$$
where $\mathbf{D}$ is a diagonal matrix with the eigenvalues of $\mathbf{A}$ along the diagonal, and $\mathbf{V}$ is the matrix of eigenvectors of $\mathbf{A}$ (eigenvectors on the columns). Let's try this with the Leg-S HiPPO matrix $\mathbf{A}$:


In [8]:
A = let

    # initialize -
    h = 100; # internal hidden state memory size
    A = Array{Float64,2}(undef, h, h); # h × h state transition matrix
    
    # build the A-matrix (Julia's 1-based indices i,k are used directly in the formula)
    for i ∈ 1:h
        for k ∈ 1:h
            
            if (i > k)
                A[i,k] = -sqrt((2*i+1))*sqrt((2*k+1)); # below the diagonal
            elseif (i == k)
                A[i,k] = -(i+1); # on the diagonal
            else
                A[i,k] = 0.0; # above the diagonal
            end
        end
    end

    A; # return -
end

100×100 Matrix{Float64}:
  -2.0        0.0        0.0        0.0      …     0.0       0.0       0.0
  -3.87298   -3.0        0.0        0.0            0.0       0.0       0.0
  -4.58258   -5.91608   -4.0        0.0            0.0       0.0       0.0
  -5.19615   -6.7082    -7.93725   -5.0            0.0       0.0       0.0
  -5.74456   -7.4162    -8.77496   -9.94987        0.0       0.0       0.0
  -6.245     -8.06226   -9.53939  -10.8167   …     0.0       0.0       0.0
  -6.7082    -8.66025  -10.247    -11.619          0.0       0.0       0.0
  -7.14143   -9.21954  -10.9087   -12.3693         0.0       0.0       0.0
  -7.54983   -9.74679  -11.5326   -13.0767         0.0       0.0       0.0
  -7.93725  -10.247    -12.1244   -13.7477         0.0       0.0       0.0
   ⋮                                         ⋱                      
 -23.5584   -30.4138   -35.9861   -40.8044         0.0       0.0       0.0
 -23.6854   -30.5778   -36.1801   -41.0244         0.0       0.0       0.0
 -23.8

__Blast from the past!__ In week 2, we learned about the eigendecomposition of a matrix. Let's compute the eigenvalues and eigenvectors of $\mathbf{A}$ using [the `eigen(...)` function](https://docs.julialang.org/en/v1/stdlib/LinearAlgebra/#LinearAlgebra.eigen), which takes a square array `A` as an argument and returns its eigendecomposition.

In [None]:
(Λ,V) = let

    # initialize -
    (n,m) = size(A); # what is the dimension of A?
    Λ = Matrix{Float64}(1.0*I, n, n); # builds the I matrix, we'll update with λ -
    
    # Decompose using the built-in function
    F = eigen(A);   # eigenvalues and vectors in F of type Eigen
    λ = F.values;   # vector of eigenvalues
    V = F.vectors;  # n x n matrix of eigenvectors, each col is an eigenvector

    # package the eigenvalues into Λ -
    for i ∈ 1:n
        Λ[i,i] = λ[i];
    end

    Λ,V
end;

Great! Now that we have the eigenvalues and eigenvectors of the Leg-S HiPPO matrix $\mathbf{A}$, we can compute the diagonal matrix $\mathbf{D}$ and the invertible matrix $\mathbf{V}$:
$$
\begin{align*}
\mathbf{A} &= \mathbf{V}\mathbf{D}\mathbf{V}^{-1} \\
\mathbf{V}^{-1}\mathbf{A} &= \mathbf{D}\mathbf{V}^{-1} \\
\mathbf{V}^{-1}\mathbf{A}\mathbf{V} &= \mathbf{D} \\
\mathbf{V}^{-1}\mathbf{A}\mathbf{V} &= \text{diag}(\lambda_{1},\lambda_{2},\dots,\lambda_{n}) \\
\end{align*}
$$
where $\lambda_{i}$ are the eigenvalues of the matrix $\mathbf{A}$, which sit along the diagonal of $\mathbf{D}$, and the columns of the invertible matrix $\mathbf{V}$ are the corresponding eigenvectors.

In [12]:
D = inv(V)*A*V; # diagonalize A using the eigenvectors

__Check__: If this worked, then $\mathbf{\Lambda} = \mathbf{D}$. Let's use [the `isapprox(...)` function](https://docs.julialang.org/en/v1/base/math/#Base.isapprox) (written with the `≈` operator) and [the @assert macro](https://docs.julialang.org/en/v1/base/base/#Base.@assert) to check if the two matrices are approximately equal:

In [None]:
@assert D ≈ Λ; # check if D is diagonal equal to Λ

AssertionError: AssertionError: D ≈ Λ
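
A lower-triangular matrix with distinct diagonal entries has those entries as its eigenvalues, so the failed check does not have to mean that no eigendecomposition exists. One plausible culprit is the conditioning of the eigenvector matrix $\mathbf{V}$: if $\mathbf{V}$ is badly conditioned, `inv(V)` amplifies round-off error. The sketch below is a diagnostic under that assumption; the helper name `legs_matrix` is hypothetical and simply repeats the construction from the cell above for several sizes.

```julia
using LinearAlgebra

# Same construction as the earlier cell, wrapped in a function (hypothetical helper).
function legs_matrix(h::Int)
    A = zeros(h, h)
    for i in 1:h, k in 1:h
        if i > k
            A[i, k] = -sqrt(2i + 1) * sqrt(2k + 1)
        elseif i == k
            A[i, k] = -(i + 1)
        end
    end
    return A
end

# print the condition number of the computed eigenvector matrix for a few sizes
for h in (4, 8, 16, 32)
    F = eigen(legs_matrix(h))
    println("h = $h: cond(V) = ", cond(F.vectors))
end
```

If the printed condition numbers grow rapidly with `h`, then the product `inv(V)*A*V` for `h = 100` is dominated by numerical error, which would explain the failed assertion.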

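### Rethinking the $\mathbf{A}$ matrix
__Hmmm__. As it turns out, the _original_ Leg-S HiPPO matrix $\mathbf{A}$ cannot be diagonalized in a numerically stable way. It is lower triangular with distinct diagonal entries, so an eigendecomposition exists on paper, but the computed eigenvector matrix $\mathbf{V}$ is too ill-conditioned for `inv(V)*A*V` to recover $\mathbf{\Lambda}$. However, we can fix this issue with a little magic, i.e., by reformulating the $\mathbf{A}$ matrix. Let's rewrite the Leg-S HiPPO matrix $\mathbf{A}$ as the sum of [a _normal_ matrix](https://en.wikipedia.org/wiki/Normal_matrix) and a low-rank (rank-one) term built from a vector $\mathbf{P}\in\mathbb{R}^{h}$: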
$$
\begin{align*}
\mathbf{A} &= \mathbf{A}^{\text{normal}} - \mathbf{P}\mathbf{P}^{\top} \\
\end{align*}
$$
where $\mathbf{A}^{\text{normal}}$ is a normal matrix, and $\mathbf{P}$ is a column vector, so that $\mathbf{P}\mathbf{P}^{\top}$ is a rank-one (low-rank) matrix. The normal matrix and the vector $\mathbf{P}$ are defined as follows:
$$
\begin{align*}
a^{\text{normal}}_{ik} &= \begin{cases}
    -\left(i+\frac{1}{2}\right)^{1/2}\left(k+\frac{1}{2}\right)^{1/2} & \text{if } i>k \\
    -\frac{1}{2} & \text{if } i=k \\
    \left(i+\frac{1}{2}\right)^{1/2}\left(k+\frac{1}{2}\right)^{1/2} & \text{if } i<k \\
\end{cases}
\end{align*}
$$
and
$$
\begin{align*}
p_{i} & = \left(i+\frac{1}{2}\right)^{1/2}
\end{align*}
$$
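
As a quick numerical sanity check of the decomposition, we can solve $\mathbf{A} = \mathbf{A}^{\text{normal}} - \mathbf{P}\mathbf{P}^{\top}$ for $\mathbf{A}^{\text{normal}} = \mathbf{A} + \mathbf{P}\mathbf{P}^{\top}$ and inspect the result. The sketch below is a minimal check, assuming a small state size `h = 8` and Julia's 1-based indices used directly in the formulas (as in the cells of this notebook). Because $\mathbf{A}$ is zero above the diagonal, the entries of $\mathbf{A} + \mathbf{P}\mathbf{P}^{\top}$ are $-\sqrt{(i+1/2)(k+1/2)}$ below the diagonal, $-1/2$ on it, and $+\sqrt{(i+1/2)(k+1/2)}$ above it, i.e., $-\frac{1}{2}\mathbf{I}$ plus a skew-symmetric matrix.

```julia
using LinearAlgebra

h = 8   # small state size, chosen only for illustration

# A and p from the definitions in this section (1-based indices)
A = [i > k ? -sqrt(2i + 1) * sqrt(2k + 1) : (i == k ? -(i + 1.0) : 0.0) for i in 1:h, k in 1:h]
p = [sqrt(i + 1/2) for i in 1:h]

AN = A + p * p'     # candidate normal part, since A = AN - p pᵀ

println("AN is normal:             ", AN * AN' ≈ AN' * AN)
println("AN + I/2 is skew:         ", (AN + I/2)' ≈ -(AN + I/2))
println("AN - p pᵀ recovers A:     ", AN - p * p' ≈ A)
println("real parts of eigenvalues: ", real.(eigvals(AN)))
```

A normal matrix is diagonalized by a unitary matrix, so its eigendecomposition is well conditioned; the price is that the eigenvalues and eigenvectors of this matrix are complex.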
So let's try this again. Compute the revised $\mathbf{A}$ matrix using the normal and low-rank matrices (save this in the `Â::Array{Float64,2}` variable).

In [19]:
Â = let

    # initialize -
    h = 10; # internal hidden state memory size
    AN = Array{Float64,2}(undef, h, h); # normal part of the decomposition
    P = Array{Float64,2}(undef, h, 1); # h × 1 column vector for the rank-one term

    # build the normal matrix AN
    for i ∈ 1:h
        for k ∈ 1:h
            
            if (i > k)
                AN[i,k] = -sqrt((i+1/2))*sqrt((k+1/2));

            elseif (i == k)
                AN[i,k] = -1/2;
            else
                # NOTE: same sign as the lower triangle, so AN (and hence Â) is symmetric here.
                # A + P*P' for the lower-triangular A has +sqrt((i+1/2))*sqrt((k+1/2)) above the diagonal.
                AN[i,k] = -sqrt((i+1/2))*sqrt((k+1/2));
            end
        end
    end

    # build the P vector
    for i ∈ 1:h
        P[i,1] = sqrt((i+1/2));
    end

   
    # compute -
    A = AN - P*P'; # Â = AN - P*P'
   
    A; # return -
end

10×10 Matrix{Float64}:
 -2.0       -3.87298   -4.58258  …   -7.14143   -7.54983   -7.93725
 -3.87298   -3.0       -5.91608      -9.21954   -9.74679  -10.247
 -4.58258   -5.91608   -4.0         -10.9087   -11.5326   -12.1244
 -5.19615   -6.7082    -7.93725     -12.3693   -13.0767   -13.7477
 -5.74456   -7.4162    -8.77496     -13.6748   -14.4568   -15.1987
 -6.245     -8.06226   -9.53939  …  -14.8661   -15.7162   -16.5227
 -6.7082    -8.66025  -10.247       -15.9687   -16.8819   -17.7482
 -7.14143   -9.21954  -10.9087       -9.0      -17.9722   -18.8944
 -7.54983   -9.74679  -11.5326      -17.9722   -10.0      -19.975
 -7.93725  -10.247    -12.1244      -18.8944   -19.975    -11.0

Then compute the eigendecomposition of the matrix $\mathbf{\hat{A}}$ using [the `eigen(...)` function](https://docs.julialang.org/en/v1/stdlib/LinearAlgebra/#LinearAlgebra.eigen). Inside the cell, the eigenvalues are stored in the `λ::Array{Float64,1}` variable and the eigenvectors in the `V::Array{Float64,2}` variable. The cell returns the diagonal eigenvalue matrix `Λ̂` and the eigenvector matrix `V̂`:

In [32]:
(Λ̂, V̂) = let

    # initialize -
    (n,m) = size(Â); # what is the dimension of A?
    Λ = Matrix{Float64}(1.0*I, n, n); # builds the I matrix, we'll update with λ -
    
    # Decompose using the built-in function
    F = eigen(Â);   # eigenvalues and vectors in F of type Eigen
    λ = F.values;   # vector of eigenvalues
    V = F.vectors;  # n x n matrix of eigenvectors, each col is an eigenvector

    # package the eigenvalues into Λ -
    for i ∈ 1:n
        Λ[i,i] = λ[i];
    end

    # @assert Â ≈ V*Λ*V'; # check if Â is diagonal equal to Λ

    Λ, V;
end;

In [33]:
inv(V̂)*Â*V̂ ≈ Λ̂ # diagonalize A using the eigenvectors

true

## Lab: Spiking Neural Networks and S5
In lab `L15b`, we will implement an embedding layer composed of a Spiking Neural Network (SNN) and then play around with the S5 model for a few text classification tasks. Should be fun!

# Today?
That's a wrap! What are some of the interesting things we discussed today?