In [1]:
# clone our repo

%cd /content
# if the repository is private, authenticate with your own GitHub access token (do not hard-code it in the notebook)
!git clone https://github.com/rishi1001/mamba_needle.git
%cd /content/mamba_needle

/content
Cloning into 'mamba_needle'...
remote: Enumerating objects: 2840, done.
remote: Counting objects: 100% (623/623), done.
remote: Compressing objects: 100% (215/215), done.
remote: Total 2840 (delta 355), reused 577 (delta 335), pack-reused 2217 (from 1)
Receiving objects: 100% (2840/2840), 177.80 MiB | 16.29 MiB/s, done.
Resolving deltas: 100% (454/454), done.
/content/mamba_needle


In [2]:
!pip3 install --upgrade --no-deps git+https://github.com/dlsys10714/mugrade.git
!pip3 install pybind11

Collecting git+https://github.com/dlsys10714/mugrade.git
  Cloning https://github.com/dlsys10714/mugrade.git to /tmp/pip-req-build-ws3pj2p0
  Running command git clone --filter=blob:none --quiet https://github.com/dlsys10714/mugrade.git /tmp/pip-req-build-ws3pj2p0
  Resolved https://github.com/dlsys10714/mugrade.git to commit 656cdc2b7ad5a37e7a5347a7b0405df0acd72380
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: mugrade
  Building wheel for mugrade (setup.py) ... done
  Created wheel for mugrade: filename=mugrade-1.2-py3-none-any.whl size=3935 sha256=46efcc00ede63588feb1beabcd1cca0308675a69c3d3ef282cd4b39da08a6d5a
  Stored in directory: /tmp/pip-ephem-wheel-cache-kh3fjopv/wheels/8b/ba/3a/621da1207eab160c01968c5e0bd1266f505b9e3f8010376d61
Successfully built mugrade
Installing collected packages: mugrade
Successfully installed mugrade-1.2
Collecting pybind11
  Downloading pybind11-2.13.6-py3-none-any.whl.metadata (9.5 kB)
Download

In [3]:
!make

-- The C compiler identification is GNU 11.4.0
-- The CXX compiler identification is GNU 11.4.0
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: /usr/bin/cc - skipped
-- Detecting C compile features
-- Detecting C compile features - done
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Found Python: /usr/local/bin/python (found version "3.10.12") found components: Development Interpreter Development.Module Development.Embed
-- Performing Test HAS_FLTO
-- Performing Test HAS_FLTO - Success
-- Found pybind11: /usr/local/lib/python3.10/dist-packages/pybind11/include (found version "2.13.6")
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Success
-- Found Threads: TRUE
-- Found CUDA: /usr/local/cuda (found version "12.2")
-- Found cu

In [4]:
%set_env PYTHONPATH ./python
%set_env NEEDLE_BACKEND nd

env: PYTHONPATH=./python
env: NEEDLE_BACKEND=nd


In [5]:
import sys
sys.path.append('./python')
sys.path.append('./apps')
sys.path.append('.')

# **Mamba in Needle**


![Mamba Architecture](https://github.com/rishi1001/mamba_needle/blob/main/images/mamba.png?raw=1)


## **A State-Space Model Architecture**

### **Introduction**

In the homeworks for this course, we implemented three sequence models in the Needle library: recurrent neural networks (RNNs), long short-term memory (LSTM) networks and transformers. In this project, we add a fourth model to this list: Mamba.

Mamba is a state-space model architecture designed to combine the high accuracy of transformers with the efficiency of linear RNNs. This report explains the key principles behind Mamba, its efficiency and scalability, and how it achieves constant-time inference.

### **Efficiency**

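* **Linear Time Training**: Unlike transformers, whose attention cost grows quadratically with sequence length, Mamba's training cost grows linearly with sequence length.  
* **Constant Time Inference**: Each generated token costs a fixed amount of computation, because the model carries a fixed-size hidden state instead of attending over all previous tokens. This makes real-time applications practical.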

### **Scalability**

* **Long-Range Tasks**: Handles long sequences efficiently.  
* **Fixed Computation per Timestep**: Ensures scalability while avoiding quadratic complexity.

The state-space model (SSM) forms the backbone of Mamba's architecture:

1. **State Equation**: $h'(t) = A h(t) + B x(t)$  
2. **Output Equation**: $y(t) = C h(t) + D x(t)$


![SSM Equations](https://github.com/rishi1001/mamba_needle/blob/main/images/ssm_equations.jpg?raw=1)


The state equation describes how the state changes as a function of time and as a function of its input x.

* **h’(t)**: the time derivative of the state, i.e., how the state changes  
* **Ah(t)**: how the current state influences its own change over time  
* **Bx(t)**: how the input influences the state

The output equation describes how the output changes as a function of the state and as a function of its input x.

* **y(t)**: the output  
* **Ch(t)**: how the current state affects the output  
* **Dx(t)**: how the input directly influences the output

We transform the continuous h(t) from above into a discrete form. A continuous h(t) is difficult to work with, and it does not match the input, which is usually discrete (e.g., a sequence of text tokens). This process can be compared to approximating the area under a curve with a Riemann sum. The discretized matrices A and B are calculated using the formula shown below.
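As a point of reference, the zero-order hold (ZOH) rule used in the Mamba paper, with step size $\Delta$, is given here. Implementations often simplify $\bar{B}$ to $\Delta B$, so treat this as the general form rather than necessarily the exact variant in our code:

$$\bar{A} = \exp(\Delta A), \qquad \bar{B} = (\Delta A)^{-1}\big(\exp(\Delta A) - I\big)\, \Delta B$$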

So now we have the discretized version:
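Written in standard notation (a sketch consistent with the continuous equations above, where $\bar{A}$ and $\bar{B}$ are the discretized matrices and $t$ now indexes tokens), the recurrence is:

$$h_t = \bar{A}\, h_{t-1} + \bar{B}\, x_t, \qquad y_t = C\, h_t + D\, x_t$$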

Note that SSMs can also be computed as a convolution with a kernel. Because the discretized matrices are the same at every step, we can precalculate the kernel and apply it to the whole input in parallel, similar to a CNN.
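To make the kernel view concrete, here is a small NumPy sketch (not part of our Needle code) for a single input channel with a diagonal state matrix A, discretized with the ZOH rule. It builds the kernel $K_k = C \bar{A}^k \bar{B}$ and checks that a causal convolution with it reproduces the recurrence. The function names are hypothetical, and the $D x_t$ skip term is omitted for brevity.

```python
import numpy as np


def zoh_discretize(a, b, delta):
    """ZOH discretization for a diagonal state matrix A = diag(a) (single input channel).

    a, b: arrays of shape (N,), with a < 0 for stability; delta: scalar step size.
    """
    a_bar = np.exp(delta * a)
    # (delta*A)^-1 (exp(delta*A) - I) delta*B, which is elementwise for diagonal A
    b_bar = (a_bar - 1.0) / a * b
    return a_bar, b_bar


def ssm_recurrent(a_bar, b_bar, c, x):
    """Recurrent view: h_t = A_bar h_{t-1} + B_bar x_t, y_t = C h_t, starting from h = 0."""
    h = np.zeros_like(a_bar)
    ys = []
    for x_t in x:
        h = a_bar * h + b_bar * x_t
        ys.append(c @ h)
    return np.array(ys)


def ssm_kernel(a_bar, b_bar, c, length):
    """Convolutional view: K[k] = C A_bar^k B_bar for k = 0, ..., length - 1."""
    powers = a_bar[None, :] ** np.arange(length)[:, None]  # (length, N)
    return powers @ (c * b_bar)                              # (length,)


def ssm_convolutional(kernel, x):
    """Causal convolution y_t = sum_{k <= t} K[k] x_{t - k}."""
    return np.convolve(x, kernel)[: len(x)]


rng = np.random.default_rng(0)
N, L = 4, 16
a = -rng.uniform(0.5, 2.0, N)
b, c = rng.standard_normal(N), rng.standard_normal(N)
x = rng.standard_normal(L)

a_bar, b_bar = zoh_discretize(a, b, delta=0.1)
y_rec = ssm_recurrent(a_bar, b_bar, c, x)
y_conv = ssm_convolutional(ssm_kernel(a_bar, b_bar, c, L), x)
print(np.allclose(y_rec, y_conv))  # True
```

The kernel only exists because $\bar{A}$ and $\bar{B}$ are fixed across time steps; this is exactly the property the selective SSM gives up.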

### **Mamba Block**

The Mamba model is composed of a sequence of Mamba blocks. The architecture of the Mamba block is shown below.


![Mamba Block](https://github.com/rishi1001/mamba_needle/blob/main/images/mamba_block.jpg?raw=1)


We implemented the following blocks:

* **RMSNorm**: Applied over the last (not batch size or sequence length) dimension. Divides each input by the root mean square of the last dimension (its L2 norm divided by the square root of the dimension size). A small value (epsilon) is added to the denominator to avoid division by 0.  
* **Conv1D**: While Conv2D was applied over the last 2 dimensions, Conv1D is applied over the last dimension. It takes input of shape (batch\_size, in\_channels, sequence\_length) and slides a kernel of size (kernel\_size, in\_channels, out\_channels) over the length of the sequence. We added a groups argument, which determines how many groups to divide the inputs and outputs into. Each group is convolved separately, so the inputs only affect outputs in the same group.  
* **SiLU**: The input times its logistic sigmoid, SiLU(x) = x · σ(x).  
* **Selective SSM** (see next section)
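A minimal NumPy sketch of the first three building blocks follows; it is only an illustration, not our Needle implementation. The weight layout `(kernel_size, in_channels // groups, out_channels)` and the function names are assumptions, and the causal use of `conv1d` (pad by `kernel_size - 1`, keep the first L outputs) follows the usual Mamba convention.

```python
import numpy as np


def rms_norm(x, weight=None, eps=1e-5):
    """Normalize over the last axis by the root mean square; optionally scale per feature."""
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    out = x / rms
    return out if weight is None else out * weight


def silu(x):
    """SiLU(x) = x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def conv1d(x, w, groups=1, padding=0):
    """Grouped 1D convolution (cross-correlation, as in most deep learning libraries).

    x: (batch, in_channels, length); w: (kernel_size, in_channels // groups, out_channels).
    Each group of input channels only feeds the matching group of output channels.
    """
    batch, c_in, _ = x.shape
    k_size, cin_g, c_out = w.shape
    assert c_in % groups == 0 and c_out % groups == 0 and cin_g == c_in // groups
    cout_g = c_out // groups
    xp = np.pad(x, ((0, 0), (0, 0), (padding, padding)))
    l_out = xp.shape[-1] - k_size + 1
    out = np.zeros((batch, c_out, l_out))
    for g in range(groups):
        xg = xp[:, g * cin_g:(g + 1) * cin_g, :]       # this group's input channels
        wg = w[:, :, g * cout_g:(g + 1) * cout_g]      # this group's filters
        for k in range(k_size):
            # out[b, o, t] += sum_c xg[b, c, t + k] * wg[k, c, o]
            out[:, g * cout_g:(g + 1) * cout_g, :] += np.einsum(
                "bcl,co->bol", xg[:, :, k:k + l_out], wg[k]
            )
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 10))                # (batch, channels, length)
w = rng.standard_normal((3, 1, 4))                 # depthwise: groups == channels == 4
y = conv1d(x, w, groups=4, padding=2)[..., :10]    # causal: pad k - 1, keep first L outputs
print(y.shape)                                     # (2, 4, 10)
print(silu(rms_norm(x)).shape)                     # (2, 4, 10)
```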

### **Selective SSM**

SSMs are poor at selectively remembering specific inputs. For instance, an SSM would fail at copying specific parts of an input and outputting them in order. This is because the matrices A, B, and C are the same for every token the SSM processes, so the model cannot treat tokens differently.

A selective SSM solves this problem by making the step size Δ (dt), B, and C dependent on the input. Recall that Bx(t) represents how the input influences the state and Ch(t) represents how the current state affects the output. The architecture for the Selective SSM is shown below.


![Selective SSM Block](https://github.com/rishi1001/mamba_needle/blob/main/images/selective_ssm.jpg?raw=1)
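The sketch below shows one simple way to make the SSM parameters input-dependent, for a single sequence of shape (L, D). It is an illustration under stated assumptions, not our Needle code: the single linear projection for Δ, the softplus, the parameterization A = -exp(a_log) (a common choice), and the simplified discretization $\bar{B}_t \approx \Delta_t B_t$ are all choices made here for brevity, and the function names are hypothetical.

```python
import numpy as np


def softplus(z):
    return np.log1p(np.exp(z))


def selective_params(x, w_delta, b_delta, w_b, w_c, a_log):
    """Input-dependent SSM parameters for one sequence (illustrative sketch).

    x: (L, D); w_delta: (D, D); b_delta: (D,); w_b, w_c: (D, N); a_log: (D, N).
    Returns A_bar (L, D, N), B_bar (L, D, N), and C (L, N).
    """
    delta = softplus(x @ w_delta + b_delta)   # (L, D): a positive step size per token and channel
    b = x @ w_b                               # (L, N): B_t now depends on the token
    c = x @ w_c                               # (L, N): C_t now depends on the token
    a = -np.exp(a_log)                        # (D, N): A itself stays fixed (and negative)
    a_bar = np.exp(delta[:, :, None] * a[None, :, :])  # exp(delta_t * A), varies with t
    b_bar = delta[:, :, None] * b[:, None, :]          # simplified: B_bar_t = delta_t * B_t
    return a_bar, b_bar, c


rng = np.random.default_rng(0)
L, D, N = 6, 3, 4
x = rng.standard_normal((L, D))
a_bar, b_bar, c = selective_params(
    x,
    rng.standard_normal((D, D)), np.zeros(D),
    rng.standard_normal((D, N)), rng.standard_normal((D, N)),
    np.zeros((D, N)),
)
print(a_bar.shape, b_bar.shape, c.shape)  # (6, 3, 4) (6, 3, 4) (6, 4)
```

Unlike the time-invariant case, `a_bar` now has a leading time axis, so no single convolution kernel describes the whole sequence.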


### **Sequential Selective Scan**

The selective scan implements the updates to the state equation and the output equation (see the state-space model equations above). Because each state depends on the previous one, the updates can be formulated as an RNN, as shown below.

Note that, since these matrices are now input-dependent, the scan cannot be computed with the convolution representation, which assumes a fixed kernel (as in time-invariant SSMs). Only the recurrent representation remains, and it loses the parallelism the convolution provides. We therefore use a parallel prefix scan algorithm during training, as described next, and the recurrent representation during inference.
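The recurrent representation amounts to a simple loop over time steps. A minimal NumPy sketch for one sequence, assuming the discretized $\bar{A}_t$ and the products $\bar{B}_t x_t$ have already been computed with shape (L, D, N) (the function name is hypothetical, and the $D x_t$ skip term is omitted):

```python
import numpy as np


def selective_scan_sequential(a_bar, bx, c):
    """Recurrent (RNN-style) selective scan for one sequence.

    a_bar: (L, D, N) discretized A_t; bx: (L, D, N) precomputed B_bar_t * x_t; c: (L, N) C_t.
    Returns y: (L, D), where h_t = A_bar_t * h_{t-1} + B_bar_t x_t and y_t = h_t C_t.
    """
    L, D, N = a_bar.shape
    h = np.zeros((D, N))
    y = np.empty((L, D))
    for t in range(L):                # one step per token: inherently sequential
        h = a_bar[t] * h + bx[t]      # elementwise, since A is diagonal
        y[t] = h @ c[t]               # contract the state dimension N with C_t
    return y


rng = np.random.default_rng(0)
a_bar = rng.uniform(0.0, 1.0, (5, 3, 4))
bx = rng.standard_normal((5, 3, 4))
c = rng.standard_normal((5, 4))
print(selective_scan_sequential(a_bar, bx, c).shape)  # (5, 3)
```

During inference only the latest `h` needs to be kept, which is why each new token costs the same amount of work.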

# **Parallel prefix scans**

So far, we have described a Mamba implementation that performs the forward pass sequentially, in time linear in the length of the input sequence. This is the natural implementation of the equations above, and it already improves on the quadratic time complexity of the transformer.

However, the authors of the Mamba paper argued that the computation can equivalently be written as a prefix scan. Using existing parallel prefix-sum algorithms, the number of sequential steps drops from linear to logarithmic in the sequence length, while the total work stays linear.

The parallel prefix sum algorithm we used was first described by CMU Professor Guy Blelloch and is further documented on NVIDIA’s website [1]. It is divided into two phases: the “up sweep” phase and the “down sweep” phase.
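A sequential Python sketch of Blelloch's two phases for a plain sum is shown below; every index visited by an inner `for` loop is independent of the others, which is what a GPU runs in parallel. It assumes a power-of-2 length and only illustrates the algorithm in [1]; it is not our CUDA kernel. The scan it produces is exclusive (each output omits its own element), so the usage example adds the input back to obtain the inclusive version.

```python
import numpy as np


def blelloch_exclusive_sum(values):
    """Work-efficient exclusive prefix sum (up-sweep, then down-sweep); length must be 2**k."""
    a = np.array(values, dtype=float)
    n = len(a)
    assert n > 0 and n & (n - 1) == 0, "length must be a power of 2"
    # Up-sweep (reduce): each level adds the left sibling's partial sum into the right one.
    d = 1
    while d < n:
        for i in range(2 * d - 1, n, 2 * d):   # independent indices -> parallel on a GPU
            a[i] += a[i - d]
        d *= 2
    # Down-sweep: clear the root, then push prefixes back down the tree.
    a[n - 1] = 0.0
    d = n // 2
    while d >= 1:
        for i in range(2 * d - 1, n, 2 * d):
            left = a[i - d]
            a[i - d] = a[i]          # left child receives the parent's prefix
            a[i] = a[i] + left       # right child receives parent prefix + left-subtree total
        d //= 2
    return a


x = np.arange(1, 9)
exclusive = blelloch_exclusive_sum(x)
inclusive = exclusive + x   # add each element back in to get the inclusive scan
print(exclusive)
print(inclusive)            # matches np.cumsum(x)
```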

Each phase takes on the order of $\log_2 L$ steps, where $L$ is the length of the input sequence. The total number of operations across all steps, which the NVIDIA article discusses under the name “work efficiency”, is also linear in $L$, so asymptotically the algorithm does no more computation than the sequential implementation.
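These counts can be checked by instrumenting the same loop structure. The hypothetical helper below counts the parallel levels (steps) and the operator applications of the up-sweep and down-sweep for a power-of-2 length `n`:

```python
def blelloch_op_count(n):
    """Return (steps, ops) for the Blelloch scan on n = 2**k elements."""
    steps = ops = 0
    d = 1
    while d < n:                                   # up-sweep levels
        ops += len(range(2 * d - 1, n, 2 * d))
        steps += 1
        d *= 2
    d = n // 2
    while d >= 1:                                  # down-sweep levels
        ops += len(range(2 * d - 1, n, 2 * d))
        steps += 1
        d //= 2
    return steps, ops


for n in (8, 64, 128):
    print(n, blelloch_op_count(n), "sequential ops:", n - 1)
```

For n = 64 this gives 12 levels and 126 operations, versus 63 operations in 63 sequential steps for a plain loop: roughly twice the work, but far fewer sequential steps.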

With all this said, we needed to extend this basic prefix sum algorithm in two ways. First, we needed to support multiple dimensions, since our input is not just a 1D array.

The solution to the first problem is simple: we launch a separate thread block for every channel (each independent sequence) that we scan, like this:
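As a rough NumPy analogue of that launch layout (an assumption, not our CUDA code), take the (B, D, L, N) layout used in the benchmark below, with the scan running along L. Flattening the other axes gives one row per independent length-L sequence, i.e., the work that one thread block would handle. The helper names are hypothetical:

```python
import numpy as np


def to_rows(x):
    """(B, D, L, N) -> (B * D * N, L): one row per independent length-L sequence."""
    b, d, l, n = x.shape
    return np.moveaxis(x, 2, -1).reshape(b * d * n, l)


def from_rows(rows, shape):
    """Inverse of to_rows for an original (B, D, L, N) shape."""
    b, d, l, n = shape
    return np.moveaxis(rows.reshape(b, d, n, l), -1, 2)


x = np.random.default_rng(0).random((2, 3, 8, 4))   # (B, D, L, N)
rows = to_rows(x)                                    # each row ~ one thread block's work
scanned = np.cumsum(rows, axis=1)                    # scan every row independently
print(rows.shape)                                    # (24, 8)
print(np.allclose(from_rows(scanned, x.shape), np.cumsum(x, axis=2)))  # True
```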

The second problem is that we are not computing a basic prefix sum: each step includes a multiplication and an addition. Let’s take a second look at the state equation to make this more concrete:

$$h_t = A_t h_{t-1} + B_t x_t.$$

First, notice that the $B_t x_t$ term does not depend on previous hidden states; it can be computed for all $t$ with a single element-wise multiplication before running the prefix sum algorithm.

The next step is to generalize our “prefix sum” approach to a “prefix scan” approach. A prefix scan generalizes the prefix sum by abstracting the accumulation function; for prefix sums, that function is addition. The semantics are equivalent to those of the itertools.accumulate function in Python. For the parallel algorithm to be valid, the accumulation function must be associative, because the up sweep and down sweep group the operations differently from a left-to-right loop.
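For example, `itertools.accumulate` computes an inclusive scan with any binary function (addition by default):

```python
import operator
from itertools import accumulate

values = [3, 1, 4, 1, 5]
print(list(accumulate(values)))                # [3, 4, 8, 9, 14]  (addition: a prefix sum)
print(list(accumulate(values, operator.mul)))  # [3, 3, 12, 12, 60]
print(list(accumulate(values, max)))           # [3, 3, 4, 4, 5]
```

All three functions here are associative, so a parallel scan could regroup the work without changing the result; a non-associative function such as subtraction would give different answers under the up sweep/down sweep grouping.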

So our goal now is to rewrite $h_t = A_t h_{t-1} + B_t x_t$ in a form that can be computed with the up sweep/down sweep algorithm described above. This means we need to be able to express $h_t$ by combining the partial accumulations of `h[0:i]` and `h[i:t]`, which requires us to define the accumulation functions.

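The final accumulation of $h_t$ is actually `sum(h[i:t]) + prod(A[i:t]) * sum(h[0:i])`, where “sum” denotes a partial accumulation of the recurrence rather than a plain sum. Written out precisely, for any split point $0 < i \le t$:

$$h_t = \Big(\prod_{k=i}^{t} A_k\Big) h_{i-1} + \sum_{j=i}^{t} \Big(\prod_{k=j+1}^{t} A_k\Big) B_j x_j$$

Here $h_{i-1}$ is the accumulation over the left segment $0, \dots, i-1$, and the second term is the accumulation over the right segment $i, \dots, t$ started from a zero state. To see why this is the case, we can try the simple case of computing $h_1$ with $i = 1$:

$$h_1 = A_1 h_0 + B_1 x_1 = A_1 (B_0 x_0) + B_1 x_1$$

with $h_0 = B_0 x_0$ (the state before the first token is zero). The product over the right segment is just $A_1$, the left accumulation is $h_0$, and the right segment's own accumulation is $B_1 x_1$.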

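Based on this, computing the prefix scan of $h$ actually amounts to two coupled prefix scans: one that keeps track of the partial products of the $A$ values, and one that keeps the partial accumulations of the hidden states.

Therefore, the accumulation function for the $A$ prefix scan is just multiplication, and the one for the hidden states is `sum(h[i:t]) + prod(A[i:t]) * sum(h[0:i])`. Summarizing each segment as a pair $(a, b)$, with $a$ the product of its $A$ values and $b$ its partial state, a left segment followed by a right segment combines as

$$(a_L, b_L) \oplus (a_R, b_R) = (a_L a_R,\; a_R b_L + b_R)$$

This operator is associative, which is exactly what the up sweep/down sweep algorithm needs.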
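The NumPy sketch below puts these pieces together: a hypothetical `combine` function implementing the pair operator, a check that it is associative, and a Blelloch-style scan (as in [1], written sequentially, power-of-2 lengths only) that turns the exclusive result into the inclusive one and is compared against the sequential recurrence. It illustrates the idea and is not our CUDA implementation.

```python
import numpy as np


def combine(left, right):
    """Pair operator: summary of segment `left` followed by segment `right`.

    A segment is summarized as (a, b): a = product of its A_t, b = its state when started from zero.
    """
    a_l, b_l = left
    a_r, b_r = right
    return a_l * a_r, a_r * b_l + b_r


def pscan_inclusive(a, bx):
    """Inclusive scan of h_t = a_t * h_{t-1} + bx_t along axis 0 (Blelloch-style; length 2**k)."""
    n = len(a)
    assert n > 0 and n & (n - 1) == 0, "length must be a power of 2"
    A, X = a.astype(float), bx.astype(float)   # astype returns copies, inputs stay untouched
    d = 1
    while d < n:                                # up sweep: combine sibling subtrees
        for i in range(2 * d - 1, n, 2 * d):
            A[i], X[i] = combine((A[i - d], X[i - d]), (A[i], X[i]))
        d *= 2
    A[n - 1], X[n - 1] = 1.0, 0.0               # identity element (a=1, b=0) at the root
    d = n // 2
    while d >= 1:                               # down sweep: distribute prefixes
        for i in range(2 * d - 1, n, 2 * d):
            left = (np.copy(A[i - d]), np.copy(X[i - d]))  # left-subtree summary
            A[i - d], X[i - d] = A[i], X[i]                # left child: parent's prefix
            A[i], X[i] = combine((A[i], X[i]), left)       # right child: prefix, then left subtree
        d //= 2
    _, h = combine((A, X), (a, bx))             # exclusive -> inclusive: fold each element in
    return h


def scan_sequential(a, bx):
    """Reference: h_0 = bx_0, then h_t = a_t * h_{t-1} + bx_t."""
    h = bx.astype(float)
    for t in range(1, len(a)):
        h[t] = a[t] * h[t - 1] + bx[t]
    return h


rng = np.random.default_rng(0)
p, q, r = [tuple(rng.random(2)) for _ in range(3)]
print(np.allclose(combine(combine(p, q), r), combine(p, combine(q, r))))  # True: associative

a, bx = rng.random((16, 3)), rng.random((16, 3))
print(np.allclose(pscan_inclusive(a, bx), scan_sequential(a, bx)))        # True
```

The final `combine` call is one way to obtain the inclusive scan this recurrence needs: the exclusive result at position $t$ holds $h_{t-1}$, and folding in $(a_t, B_t x_t)$ yields $h_t$.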

One last implementation detail is that we require an inclusive prefix scan rather than an exclusive prefix scan. We notice that the

In order to test the running time of this algorithm versus the sequential prefix scan implementation, we ran a simple experiment where we initialized random A and X matrices and ran each scan implementation 1,000 times with different sequence lengths, keeping the rest of the dimensions constant:

In [10]:
from time import perf_counter

import numpy as np

import needle as ndl

B, D, N = 10, 28, 10  # batch size, channels, state size
device = ndl.cuda()

L = 64  # sequence length (the CUDA kernel currently needs a power of 2, L <= 128)

# Random scan inputs: A holds the per-step multipliers, X the precomputed B_t * x_t terms.
A_n = np.random.rand(B, D, L, N)
X_n = np.random.rand(B, D, L, N)

A = ndl.Tensor(A_n, device=device)
X = ndl.Tensor(X_n, device=device)

# Parallel scan on the GPU; .numpy() copies the result back to the host on every iteration.
start = perf_counter()
for _ in range(1_000):
  y = A.cached_data.pscan(X.cached_data).numpy()
print("CUDA pscan implementation (s):", perf_counter() - start)

# Reference: sequential scan h_t = A_t * h_{t-1} + X_t along the sequence axis, with h_0 = X_0.
start = perf_counter()
for _ in range(1_000):
  result = np.zeros((B, D, L, N))
  result[:, :, 0, :] = X_n[:, :, 0, :]
  for i in range(1, L):
    result[:, :, i, :] = result[:, :, i - 1, :] * A_n[:, :, i, :] + X_n[:, :, i, :]

print("numpy sequential scan benchmark (s):", perf_counter() - start)

print("testing correctness")
np.testing.assert_allclose(y, result, atol=1e-5, rtol=1e-5)
print("correct")

CUDA pscan implementation (s): 0.6208266129999629
numpy sequential scan benchmark (s): 1.9076007060000393
testing correctness
correct


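Even for a sequence length of 64, the parallel CUDA version was about 3 times faster than the sequential version (0.62 s versus 1.91 s). Note that this compares a GPU kernel against a NumPy loop on the CPU, so part of the gap reflects the hardware rather than the algorithm alone. We expect the gap to widen as the sequence length increases, given the logarithmic number of sequential steps in the parallel algorithm.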

Finally, we acknowledge that the prefix scan algorithm, as we implemented it in CUDA, has a few opportunities for improvement that we can work on in the future:

* currently only supports short sequences (L ≤ 128), since we compute the prefix sum for an entire channel within a single thread block  
* currently only supports sequence lengths that are powers of 2
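One possible workaround for the power-of-2 restriction (not something we have implemented) is to pad the inputs up to the next power of 2. Because the scan is causal, values appended at the end cannot change the first L results; padding with the operator's identity element (a = 1, Bx = 0) additionally keeps the padded states equal to the last real state. The cost is up to twice the memory and work, and it does not lift the L ≤ 128 limit. A hypothetical NumPy sketch:

```python
import numpy as np


def pad_to_pow2(a, bx):
    """Pad scan inputs along axis 0 to the next power of 2 with the identity element (a=1, bx=0)."""
    length = a.shape[0]
    target = 1 << (length - 1).bit_length()   # smallest power of 2 >= length
    pad = [(0, target - length)] + [(0, 0)] * (a.ndim - 1)
    return (np.pad(a, pad, constant_values=1.0),
            np.pad(bx, pad, constant_values=0.0),
            length)


def scan_sequential(a, bx):
    """Reference scan h_t = a_t * h_{t-1} + bx_t with h_0 = bx_0."""
    h = bx.astype(float)
    for t in range(1, len(a)):
        h[t] = a[t] * h[t - 1] + bx[t]
    return h


rng = np.random.default_rng(0)
a, bx = rng.random(5), rng.random(5)
a_p, bx_p, length = pad_to_pow2(a, bx)
h_p = scan_sequential(a_p, bx_p)
print(a_p.shape)                                           # (8,)
print(np.allclose(h_p[:length], scan_sequential(a, bx)))   # True
print(np.allclose(h_p[length:], h_p[length - 1]))          # True: padded states repeat h_{L-1}
```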

# **Using Mamba in Needle**

add code snippet to instantiate mamba model, and explain hyperparameters

Add training code

Add generating code

Add plots

# **Future Work**

* Making the code hardware-aware: On modern GPUs, moving data between the large but comparatively slow high-bandwidth memory (HBM) and the small but fast on-chip SRAM is often the bottleneck. So, instead of materializing the scan inputs (𝑨, 𝑩) of size (B, L, D, N) in HBM, we could load the SSM parameters (Δ, 𝑨, 𝑩, 𝑪) directly from slow HBM into fast SRAM, perform the discretization and recurrence in SRAM, and then write only the final outputs of size (B, L, D) back to HBM.
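A rough estimate shows why this matters. The hypothetical helper below compares the bytes needed to materialize the scan inputs (two tensors of shape (B, L, D, N)) with the bytes of the final output (B, L, D), assuming float32; the example sizes are made up for illustration. The ratio is 2N regardless of the other dimensions.

```python
def scan_memory_bytes(batch, length, d_inner, d_state, bytes_per_elem=4):
    """Bytes for materialized scan inputs (A_bar and B_bar*x) versus the final output y."""
    scan_inputs = 2 * batch * length * d_inner * d_state * bytes_per_elem  # two (B, L, D, N) tensors
    output = batch * length * d_inner * bytes_per_elem                     # one (B, L, D) tensor
    return scan_inputs, output


# Hypothetical sizes, chosen only for illustration.
inputs, output = scan_memory_bytes(batch=8, length=2048, d_inner=1536, d_state=16)
print(f"scan inputs: {inputs / 2**20:.0f} MiB, output: {output / 2**20:.0f} MiB")
print(inputs // output)  # 32 == 2 * d_state
```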

# **References**