# Mamba

In this example, we will implement the new Mamba model from Albert Gu and Tri Dao [[1]](https://arxiv.org/abs/2312.00752) by utilising the new `SelectiveStateSpaceModel` layer. 

In this example, you will learn the following:

- how to implement Mamba
- how to use a shared layer

Special thanks and credits go to John (Zhiyao) Ma and his excellent Mamba implementation in PyTorch, which served as a great inspiration and foundation for this Equinox version. Go check it out [here](https://github.com/johnma2006/mamba-minimal).

The original implementation includes **a lot** of CUDA code [[2]](https://github.com/state-spaces/mamba) to optimise the so-called `selective_scan` algorithm. This first iteration of the `SelectiveStateSpaceModel` implementation is not as heavily optimised; in future iterations, we hope to close the performance gap by using custom Pallas kernels.
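To make the `selective_scan` mentioned above more concrete, here is a minimal, unoptimised sketch of the recurrence for a single sequence. It follows our reading of John Ma's reference implementation linked above: $A$ is discretised with zero-order hold, $B$ with a simplified Euler step, and the hidden state evolves as $h_t = \exp(\Delta_t A)\, h_{t-1} + \Delta_t B_t u_t$ with output $y_t = C_t h_t + D u_t$. The function names and argument shapes are our own assumptions and are not the API of the `SelectiveStateSpaceModel` layer. Because the recurrence is linear in $h$, it can also be evaluated with `jax.lax.associative_scan`; both versions are shown and should agree up to floating-point error.

```python
import jax
import jax.numpy as jnp


def _discretise(u, delta, A, B):
    # Shapes (assumed): u, delta: (seq_len, d_inner); A: (d_inner, d_state); B: (seq_len, d_state)
    delta_A = jnp.exp(jnp.einsum("ld,dn->ldn", delta, A))  # zero-order hold for A
    delta_B_u = jnp.einsum("ld,ln,ld->ldn", delta, B, u)  # simplified Euler step for B
    return delta_A, delta_B_u


def selective_scan_sequential(u, delta, A, B, C, D):
    """Hypothetical reference scan: one step per time step via jax.lax.scan."""
    delta_A, delta_B_u = _discretise(u, delta, A, B)

    def step(h, inputs):
        dA, dBu, c = inputs
        h = dA * h + dBu  # h_t = exp(delta_t A) h_{t-1} + delta_t B_t u_t
        y = jnp.einsum("dn,n->d", h, c)  # y_t = C_t h_t
        return h, y

    h0 = jnp.zeros(A.shape)
    _, ys = jax.lax.scan(step, h0, (delta_A, delta_B_u, C))
    return ys + u * D  # skip connection through D


def selective_scan_parallel(u, delta, A, B, C, D):
    """Same recurrence, evaluated with a parallel associative scan."""
    delta_A, delta_B_u = _discretise(u, delta, A, B)

    def combine(left, right):
        # Composing h -> a_l h + b_l followed by h -> a_r h + b_r
        a_l, b_l = left
        a_r, b_r = right
        return a_r * a_l, a_r * b_l + b_r

    _, hs = jax.lax.associative_scan(combine, (delta_A, delta_B_u))
    ys = jnp.einsum("ldn,ln->ld", hs, C)
    return ys + u * D


# Usage with random inputs (arbitrary sizes).
seq_len, d_inner, d_state = 16, 4, 3
k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)
u = jax.random.normal(k1, (seq_len, d_inner))
delta = jax.nn.softplus(jax.random.normal(k2, (seq_len, d_inner)))  # positive step sizes
A = -jnp.exp(jax.random.normal(k3, (d_inner, d_state)))  # negative, so exp(delta A) < 1
B = jax.random.normal(k4, (seq_len, d_state))
C = jax.random.normal(k5, (seq_len, d_state))
D = jnp.ones((d_inner,))

y_seq = selective_scan_sequential(u, delta, A, B, C, D)
y_par = selective_scan_parallel(u, delta, A, B, C, D)
```

The sequential version is easiest to read; the associative form trades it for a logarithmic-depth computation, which is the kind of structure an optimised kernel would exploit.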

The following image shows the high-level architecture of Mamba, which we will implement.

<div style="display: flex; justify-content: center; margin-left: auto; width: 100%">
    <img src="../imgs/Mamba1.drawio.svg" width="30%">
</div>

Before diving into the `ResidualBlock`, which contains the main `SelectiveStateSpaceModel` code, let's quickly build everything around it. Also note that the weights of the embedding layer and the final linear layer are shared (tied)! This is not a problem, though, because we can use `eqx.nn.Shared` to implement this.

```python
from typing import Optional

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, Int, PRNGKeyArray
```

```python
class Mamba(eqx.Module, strict=True):
    layers: eqx.nn.Sequential
    normalization: eqx.nn.RMSNorm
    shared_emb_lm_head: eqx.nn.Shared

    def __init__(
        self,
        n_layers: int,
        n_dims: int,  # vocabulary size
        n_embd: int,
        d_inner: int,
        dt_rank: int,
        d_conv: int,
        *,
        key: PRNGKeyArray,
    ):
        # subkeys[0]: embedding, subkeys[1..n_layers]: residual blocks, subkeys[-1]: LM head
        key, *subkeys = jax.random.split(key, n_layers + 3)
        self.layers = eqx.nn.Sequential(
            [
                ResidualBlock(n_embd, d_inner, dt_rank, d_conv, key=subkeys[i + 1])
                for i in range(n_layers)
            ],
        )
        self.normalization = eqx.nn.RMSNorm(n_embd)

        embedding = eqx.nn.Embedding(n_dims, n_embd, key=subkeys[0])
        lm_head = eqx.nn.Linear(
            n_embd,
            n_dims,
            use_bias=False,
            key=subkeys[-1],
        )
        # Tie the LM head's weight to the embedding weight
        where = lambda embed_and_lin: embed_and_lin[1].weight
        get = lambda embed_and_lin: embed_and_lin[0].weight
        self.shared_emb_lm_head = eqx.nn.Shared(
            (embedding, lm_head), where=where, get=get
        )

    def __call__(
        self,
        x: Int[Array, "seq_len"],  # noqa
        *,
        key: Optional[PRNGKeyArray] = None,
    ) -> Float[Array, "seq_len n_dims"]:  # noqa
        embedding, linear = self.shared_emb_lm_head()
        x = jax.vmap(embedding)(x)  # (seq_len, n_embd)

        x = self.layers(x)
        x = jax.vmap(self.normalization)(x)
        logits = jax.vmap(linear)(x)  # (seq_len, n_dims)
        return logits
```

We haven't implemented `ResidualBlock` yet, but we will get there soon. Note the usage of `eqx.nn.Shared`:

```python
# Embedding layer
embedding = eqx.nn.Embedding(n_dims, n_embd, key=subkeys[0])
# Linear layer (the "LM head")
lm_head = eqx.nn.Linear(
    n_embd,
    n_dims,
    use_bias=False,
    key=subkeys[-1],
)
# Both weights have shape (n_dims, n_embd), so they can be tied.

# `where` selects the node to be replaced: the linear layer's weights
where = lambda embed_and_lin: embed_and_lin[1].weight

# `get` selects the node whose value is used instead: the embedding weights
get = lambda embed_and_lin: embed_and_lin[0].weight

# Create a shared layer
self.shared_emb_lm_head = eqx.nn.Shared(
    (embedding, lm_head), where=where, get=get
)
```

To use the shared layers, we first have to retrieve them by calling the `eqx.nn.Shared` wrapper:

```python
embedding, linear = self.shared_emb_lm_head()
# embedding and linear are eqx.nn.Embedding and eqx.nn.Linear respectively;
# use them as usual
```

Let's continue with the `ResidualBlock`.

Here's an overview of what the components of the `ResidualBlock` will look like.

<div style="display: flex; justify-content: center; margin-left: auto; width: 100%">
    <img src="../imgs/Mamba2.drawio.svg" width="30%">
</div>

As you can see, we keep diving further into the model. Let's implement this `ResidualBlock` now.

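```python
class ResidualBlock(eqx.Module, strict=True):
    # String annotation: `MambaBlock` is only defined in a later cell.
    mamba_block: "MambaBlock"
    rms_norm: eqx.nn.RMSNorm

    def __init__(
        self,
        n_embd: int,
        d_inner: int,
        dt_rank: int,
        d_conv: int,
        *,
        key: PRNGKeyArray,
    ):
        self.mamba_block = MambaBlock(n_embd, d_inner, dt_rank, d_conv, key=key)
        self.rms_norm = eqx.nn.RMSNorm(n_embd)

    def __call__(
        self, x: Float[Array, "seq_len n_embd"], *, key: Optional[PRNGKeyArray] = None
    ) -> Float[Array, "seq_len n_embd"]:
        # Pre-norm residual connection: x + MambaBlock(RMSNorm(x))
        return self.mamba_block(jax.vmap(self.rms_norm)(x)) + x
```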

We're getting closer and closer to the heart of the Mamba model. Let's look at the `MambaBlock`. This time, I've included the shapes of the arrays as they pass through the various transformations.

<div style="display: flex; justify-content: center; margin-left: auto; width: 100%">
    <img src="../imgs/Mamba3.drawio.svg" width="60%">
</div>

```python
class MambaBlock(eqx.Module):
    in_proj: eqx.nn.Linear
    conv1d: eqx.nn.Conv1d
    ssm: "SSM"  # string annotation: `SSM` is defined in a later cell
    out_proj: eqx.nn.Linear

    def __init__(
        self,
        n_embd: int,
        d_inner: int,
        dt_rank: int,
        d_conv: int,
        use_in_projection_bias: bool = True,
        use_conv_bias: bool = True,
        use_out_proj_bias: bool = True,
        *,
        key: PRNGKeyArray,
    ):
        linear_key, conv1d_key, ssm_key, out_proj_key = jax.random.split(key, 4)

        # Projects each token to 2 * d_inner: one half for the SSM branch,
        # the other half for the gate (`res`)
        self.in_proj = eqx.nn.Linear(
            n_embd,
            d_inner * 2,
            use_bias=use_in_projection_bias,
            key=linear_key,
        )

        # Depthwise convolution (one filter per channel); together with the
        # slicing in `__call__`, the padding of d_conv - 1 makes it causal
        self.conv1d = eqx.nn.Conv1d(
            in_channels=d_inner,
            out_channels=d_inner,
            kernel_size=d_conv,
            use_bias=use_conv_bias,
            groups=d_inner,
            padding=d_conv - 1,
            key=conv1d_key,
        )
        self.ssm = SSM(key=ssm_key)
        self.out_proj = eqx.nn.Linear(
            d_inner,
            n_embd,
            use_bias=use_out_proj_bias,
            key=out_proj_key,
        )

    def __call__(
        self, x: Float[Array, "seq_len n_embd"]
    ) -> Float[Array, "seq_len n_embd"]:
        seq_len, _ = x.shape
        x_and_res = jax.vmap(self.in_proj)(x)  # (seq_len, 2 * d_inner)

        x, res = jnp.split(x_and_res, 2, axis=-1)  # each (seq_len, d_inner)
        x = jnp.transpose(x)  # Conv1d expects (channels, length)
        x = self.conv1d(x)[:, :seq_len]  # keep the first seq_len outputs
        x = jnp.transpose(x)
        x = jax.nn.silu(x)

        y = self.ssm(x)
        y = y * jax.nn.silu(res)  # gate the SSM output

        output = jax.vmap(self.out_proj)(y)
        return output
```

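```python
class SSM(eqx.Module):
    # Placeholder: the selective state space model is not implemented yet.
    def __init__(self, *, key: PRNGKeyArray):
        pass

    def __call__(
        self, x: Float[Array, "seq_len d_inner"]
    ) -> Float[Array, "seq_len d_inner"]:
        raise NotImplementedError
```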