# Mamba

In this example, we will implement the Mamba model by Albert Gu and Tri Dao [[1]](https://arxiv.org/abs/2312.00752) using the new `SelectiveStateSpaceModel` layer.

Along the way, you will learn:

- how to implement Mamba
- how to use a shared layer

Special thanks and credits go to John (Zhiyao) Ma and his excellent Mamba implementation in PyTorch, which served as a great inspiration and foundation for this Equinox version. Go check it out [here](https://github.com/johnma2006/mamba-minimal).

The original implementation includes **a lot** of CUDA code [[2]](https://github.com/state-spaces/mamba) to optimise the so-called `selective_scan` algorithm. This first iteration of the `SelectiveStateSpaceModel` layer is not as heavily optimised; future iterations may close the performance gap by using custom Pallas kernels.
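For reference, the recurrence behind `selective_scan` can be written, for a single sequence, as $h_t = \exp(\Delta_t A) \odot h_{t-1} + \Delta_t B_t u_t$ and $y_t = C_t h_t + D u_t$, where $\Delta_t$, $B_t$ and $C_t$ depend on the input. Below is a minimal, unoptimised sketch of that recurrence using `jax.lax.scan`, modelled on the reference loop in mamba-minimal. The function name `selective_scan_reference`, its argument names and the shapes are our own assumptions for illustration; this is not the interface of the `SelectiveStateSpaceModel` layer.

```python
import jax
import jax.numpy as jnp


def selective_scan_reference(u, delta, A, B, C, D):
    """Sequential selective scan for one (unbatched) sequence.

    Hypothetical shapes, following mamba-minimal's naming:
      u:     (seq_len, d_inner)  input sequence
      delta: (seq_len, d_inner)  input-dependent step sizes (positive)
      A:     (d_inner, d_state)  per-channel diagonal state matrix
      B:     (seq_len, d_state)  input-dependent input projection
      C:     (seq_len, d_state)  input-dependent output projection
      D:     (d_inner,)          skip connection
    """
    # Discretise A with exp(delta * A); approximate B_bar * u by delta * B * u
    delta_A = jnp.exp(jnp.einsum("ld,dn->ldn", delta, A))
    delta_B_u = jnp.einsum("ld,ln,ld->ldn", delta, B, u)

    def step(h, inputs):
        dA, dBu, c = inputs
        h = dA * h + dBu  # h_t = A_bar_t * h_{t-1} + B_bar_t u_t
        y = h @ c  # y_t = C_t h_t, shape (d_inner,)
        return h, y

    h0 = jnp.zeros((u.shape[1], A.shape[1]))
    _, ys = jax.lax.scan(step, h0, (delta_A, delta_B_u, C))
    return ys + u * D  # add the D skip connection


# Usage with small random inputs
seq_len, d_inner, d_state = 8, 4, 3
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
u = jax.random.normal(k1, (seq_len, d_inner))
delta = jax.nn.softplus(jax.random.normal(k2, (seq_len, d_inner)))
# Negative real A, as in Mamba's initialisation A = -exp(A_log)
A = -jnp.tile(jnp.arange(1, d_state + 1, dtype=jnp.float32), (d_inner, 1))
B = jax.random.normal(k3, (seq_len, d_state))
C = jax.random.normal(k4, (seq_len, d_state))
D = jnp.ones((d_inner,))
y = selective_scan_reference(u, delta, A, B, C, D)  # shape (seq_len, d_inner)
```

Because `lax.scan` steps through the sequence one token at a time, this sketch is the slow path; roughly speaking, it is this computation that the fused CUDA kernels are written to speed up.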

The following image shows the high-level architecture of Mamba, which we will implement.

<div style="display: flex; justify-content: center; margin-left: auto; width: 100%">
    <img src="../imgs/Mamba1.drawio.svg" width="30%">
</div>

Before diving into the `ResidualBlock`, which contains the main `SelectiveStateSpaceModel` code, let's quickly build everything around it. Note that the embedding layer and the final linear layer share their weights. This is not a problem, because `eqx.nn.Shared` lets us implement exactly this kind of weight tying.

```python
import equinox as eqx
import jax
from jaxtyping import Array, Float, Int, PRNGKeyArray
```

```python
class Mamba(eqx.Module, strict=True):
    layers: eqx.nn.Sequential
    # eqx.nn.RMSNorm is only available in recent Equinox versions
    normalization: eqx.nn.RMSNorm
    shared_emb_lm_head: eqx.nn.Shared

    def __init__(self, n_layers: int, n_dims: int, n_embd: int, *, key: PRNGKeyArray):
        # One key for the embedding, one per residual block, one for the LM head
        subkeys = jax.random.split(key, n_layers + 2)
        self.layers = eqx.nn.Sequential(
            [ResidualBlock(key=subkeys[i + 1]) for i in range(n_layers)],
        )
        self.normalization = eqx.nn.RMSNorm(n_embd)

        # n_dims is the vocabulary size, n_embd the embedding width
        embedding = eqx.nn.Embedding(n_dims, n_embd, key=subkeys[0])
        lm_head = eqx.nn.Linear(
            n_embd,
            n_dims,
            use_bias=False,
            key=subkeys[-1],
        )
        # Tie the LM head weight to the embedding weight
        where = lambda embed_and_lin: embed_and_lin[1].weight
        get = lambda embed_and_lin: embed_and_lin[0].weight
        self.shared_emb_lm_head = eqx.nn.Shared(
            (embedding, lm_head), where=where, get=get
        )

    def __call__(
        self,
        x: Int[Array, "seq_len"],  # noqa
        *,
        key: PRNGKeyArray = None,
    ) -> Float[Array, "seq_len n_dims"]:  # noqa
        embedding, linear = self.shared_emb_lm_head()
        x = jax.vmap(embedding)(x)  # (seq_len, n_embd)

        x = self.layers(x)
        x = jax.vmap(self.normalization)(x)
        logits = jax.vmap(linear)(x)  # (seq_len, n_dims)
        return logits
```

We haven't implemented `ResidualBlock` yet, but we will get there soon. Note the usage of `eqx.nn.Shared`:

```python
# Embedding layer: its weight has shape (n_dims, n_embd)
embedding = eqx.nn.Embedding(
    n_dims, n_embd, key=subkeys[0]
)
# Linear layer (LM head): its weight also has shape (n_dims, n_embd)
lm_head = eqx.nn.Linear(
    n_embd,
    n_dims,
    use_bias=False,
    key=subkeys[-1],
)
# `where` selects the node to be replaced: the linear layer's weight
where = lambda embed_and_lin: embed_and_lin[1].weight

# `get` selects the node whose value is used instead: the embedding weight
get = lambda embed_and_lin: embed_and_lin[0].weight

# Create a shared layer
self.shared_emb_lm_head = eqx.nn.Shared(
    (embedding, lm_head), where=where, get=get
)
```

To use the shared layers, we first retrieve them by calling the `Shared` module:

```python
embedding, linear = self.shared_emb_lm_head()
# embedding and linear are an eqx.nn.Embedding and an eqx.nn.Linear respectively,
# with linear.weight now tied to embedding.weight; use them as usual
```

Let's continue with the `ResidualBlock`.

```python
class ResidualBlock(eqx.Module):
    pass  # placeholder, to be filled in
```