In [None]:
from typing import Optional

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, Int, PRNGKeyArray

# Mamba

In this example, we will implement the new Mamba model from Albert Gu and Tri Dao [[1]](https://arxiv.org/abs/2312.00752) by utilising the new `SelectiveStateSpaceModel` layer. 

In this example, you will learn the following:

- how to implement Mamba
- how to use a shared layer

Special thanks and credits go to John (Zhiyao) Ma and his excellent Mamba implementation in PyTorch, which served as a great inspiration and foundation for this Equinox version. Go check it out [here](https://github.com/johnma2006/mamba-minimal).

The original implementation includes **a lot** of CUDA code [[2]](https://github.com/state-spaces/mamba) to optimise the so-called `selective_scan` algorithm. This first iteration of the `SelectiveStateSpaceModel` implementation uses a straightforward sequential `jax.lax.scan` instead, so it is not as heavily optimised. In future iterations, custom Pallas kernels could help close that performance gap. 

The following image shows the high level architecture of Mamba which we will implement.

<div style="display: flex; justify-content: center; margin-left: auto; width: 100%">
    <img src="../imgs/Mamba1.drawio.svg" width="30%">
</div>

If we zoom into the `ResidualBlock`, we find the following:

<div style="display: flex; justify-content: center; margin-left: auto; width: 100%">
    <img src="../imgs/Mamba2.drawio.svg" width="30%">
</div>

As you can see, we keep diving further into the model. Before implementing this `ResidualBlock`, let's keep zooming in until we reach the deepest component — at which point we can start implementing everything and work our way back up. Let's keep going.

We're getting closer and closer to the heart of the Mamba model. Let's look at what the `MambaBlock` looks like. This time, I've included the shapes of the matrices as they traverse through all kinds of transformations. 

<div style="display: flex; justify-content: center; margin-left: auto; width: 100%">
    <img src="../imgs/Mamba3.drawio.svg" width="60%">
</div>

Most of the parts we need are already available in Equinox. What's missing, though, is the new `SelectiveStateSpaceModel` (abbreviated as `SSM` above). Everything in green is a trainable parameter. 

<div style="display: flex; justify-content: center; margin-left: auto; width: 100%">
    <img src="../imgs/Mamba4.drawio.svg" width="60%">
</div>

Alright! This is the deepest we can get. We've reached the point at which we have all needed components available to us (except for the `selective_scan` function, but that's not a problem). Let's start with the `SelectiveStateSpaceModel` and then work our way back up again.

In [1]:
def selective_scan(
    x: Float[Array, "seq_length d_inner"],
    delta: Float[Array, "seq_length d_inner"],
    A: Float[Array, "d_inner d_state"],
    B: Float[Array, "seq_length d_state"],
    C: Float[Array, "seq_length d_state"],
    D: Float[Array, " d_inner"],
) -> Float[Array, "seq_length d_inner"]:
    """Run the sequential selective scan of a discretised diagonal SSM.

    Starting from a zero state ``h`` of shape ``(d_inner, d_state)``, each time
    step ``t`` computes::

        h_t = exp(delta_t[:, None] * A) * h_{t-1} + delta_t[:, None] * B_t[None, :] * x_t[:, None]
        y_t = h_t @ C_t + D * x_t

    Args:
        x: Input sequence.
        delta: Per-step, per-channel step sizes.
        A: Per-channel state transition coefficients.
        B: Per-step input matrix.
        C: Per-step output matrix.
        D: Per-channel skip-connection weight.

    Returns:
        The output sequence, with the same shape as ``x``.
    """
    seq_len, n_channels = x.shape
    n_state = A.shape[1]

    # Discretised transition (decay) and input (drive) terms for every step.
    decay = jnp.exp(jnp.einsum("l d,d n -> l d n", delta, A))
    drive = jnp.einsum("l d,l n,l d -> l d n", delta, B, x)

    def recurrence(state, t):
        # One step of the linear recurrence, followed by the read-out through C_t.
        new_state = decay[t] * state + drive[t]
        readout = jnp.einsum("d n,n -> d", new_state, C[t, :])
        return new_state, readout

    initial_state = jnp.zeros(shape=(n_channels, n_state))
    _, outputs = jax.lax.scan(recurrence, initial_state, jnp.arange(seq_len))

    # Skip connection: D scales the raw input and adds it to the SSM output.
    return outputs + x * D

NameError: name 'Float' is not defined

In [None]:
class SelectiveStateSpaceModel(eqx.Module, strict=True):
    """Selective state space model layer used inside the Mamba block.

    ``delta``, ``B`` and ``C`` are computed from the input at every time step
    (which is what makes the SSM *selective*). ``A`` (stored as ``A_log``) and
    ``D`` are learned parameters that do not depend on the input. The
    recurrence itself is evaluated by ``selective_scan``.
    """

    # Maps each d_inner vector to [delta (dt_rank) | B (d_state) | C (d_state)].
    input_proj: eqx.nn.Linear
    # Expands the low-rank delta from dt_rank back to d_inner.
    delta_proj: eqx.nn.Linear
    # A is parameterised as -exp(A_log), which keeps every entry of A negative.
    A_log: Float[Array, "d_inner d_state"]
    # Per-channel skip-connection weight.
    D: Float[Array, " d_inner"]

    d_inner: int = eqx.field(static=True)
    dt_rank: int = eqx.field(static=True)
    d_state: int = eqx.field(static=True)

    def __init__(
        self,
        d_inner: int,
        dt_rank: int,
        d_state: int,
        use_input_proj_bias: bool = False,
        use_delta_proj_bias: bool = False,
        *,
        key: PRNGKeyArray,
    ):
        """Create the layer.

        Args:
            d_inner: Number of channels of the input sequence.
            dt_rank: Rank of the low-rank projection used to produce ``delta``.
            d_state: Size of the latent state per channel.
            use_input_proj_bias: Whether ``input_proj`` has a bias.
            use_delta_proj_bias: Whether ``delta_proj`` has a bias.
            key: PRNG key used to initialise the linear projections.
        """
        self.d_inner = d_inner
        self.dt_rank = dt_rank
        self.d_state = d_state
        # The first subkey is unused. Splitting into 3 keeps the projection
        # initialisation identical for a given key.
        _, input_proj_key, delta_proj_key = jax.random.split(key, 3)
        self.input_proj = eqx.nn.Linear(
            d_inner,
            dt_rank + d_state * 2,
            use_bias=use_input_proj_bias,
            key=input_proj_key,
        )

        self.delta_proj = eqx.nn.Linear(
            dt_rank, d_inner, use_bias=use_delta_proj_bias, key=delta_proj_key
        )
        # Every channel's row of A is [1, 2, ..., d_state], so that
        # A = -exp(A_log) starts at -[1, ..., d_state] for each channel.
        # (repeat + reshape would instead give rows of a single repeated value.)
        A = jnp.tile(jnp.arange(1, d_state + 1), (d_inner, 1))
        self.A_log = jnp.log(A)
        self.D = jnp.ones(d_inner)

    def __call__(self, x: Float[Array, "seq_length d_inner"]):
        """Apply the selective SSM to ``x`` of shape (seq_length, d_inner)."""
        A = -jnp.exp(self.A_log)
        D = self.D

        delta_b_c = jax.vmap(self.input_proj)(x)

        split_indices = [
            self.dt_rank,
            self.dt_rank + self.d_state,
        ]
        delta, B, C = jnp.split(delta_b_c, split_indices, axis=-1)
        # Project the low-rank delta up to d_inner; softplus makes the step
        # sizes positive.
        delta = jax.nn.softplus(jax.vmap(self.delta_proj)(delta))

        y = selective_scan(x, delta, A, B, C, D)
        return y

## Detour: State Space Models
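A state space model (SSM) maps an input sequence $x_t$ to an output sequence $y_t$ through a latent state $h_t$. It starts from the continuous-time system

$$h'(t) = A\,h(t) + B\,x(t), \qquad y(t) = C\,h(t) + D\,x(t),$$

which is discretised with a step size $\Delta$ into the recurrence

$$h_t = \bar{A}\,h_{t-1} + \bar{B}\,x_t, \qquad y_t = C\,h_t + D\,x_t, \qquad \bar{A} = \exp(\Delta A),\ \bar{B} \approx \Delta B.$$

This is exactly what `selective_scan` computes step by step with `jax.lax.scan`. Mamba's SSM is *selective* because $\Delta$, $B$ and $C$ are computed from the input at every time step (via `input_proj` and `delta_proj`). As a result, the model can decide, per token, how much of the state to keep and how much of the new input to write into it.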

Armed with the `SSM`, we can now implement the `MambaBlock` part. See the images above for where we are right now!

In [None]:
class MambaBlock(eqx.Module):
    """Core Mamba block: a gated, causally convolved selective SSM.

    Computes ``out_proj(SSM(silu(conv1d(u))) * silu(res))``, where ``u`` and
    ``res`` are the two halves of ``in_proj(x)``.
    """

    in_proj: eqx.nn.Linear  # n_embd -> 2 * d_inner (SSM branch + gate branch)
    conv1d: eqx.nn.Conv1d  # depthwise (groups=d_inner) causal conv over time
    ssm: SelectiveStateSpaceModel
    out_proj: eqx.nn.Linear  # d_inner -> n_embd

    def __init__(
        self,
        n_embd: int,
        d_inner: int,
        dt_rank: int,
        d_conv: int,
        use_in_projection_bias: bool = True,
        use_conv_bias: bool = True,
        use_out_proj_bias: bool = True,
        ssm_use_delta_proj_bias: bool = False,
        ssm_use_input_proj_bias: bool = False,
        *,
        d_state: int = 16,
        key: PRNGKeyArray,
    ):
        """Create the block.

        Args:
            n_embd: Model (embedding) width of the input and output.
            d_inner: Width of the inner SSM / gate branches.
            dt_rank: Rank of the SSM's low-rank ``delta`` projection.
            d_conv: Kernel size of the causal depthwise convolution.
            use_in_projection_bias: Whether ``in_proj`` has a bias.
            use_conv_bias: Whether ``conv1d`` has a bias.
            use_out_proj_bias: Whether ``out_proj`` has a bias.
            ssm_use_delta_proj_bias: Forwarded to the SSM's ``delta_proj``.
            ssm_use_input_proj_bias: Forwarded to the SSM's ``input_proj``.
            d_state: Latent state size per channel of the SSM. This is
                independent of ``d_inner``; 16 is the Mamba default.
            key: PRNG key for parameter initialisation.
        """
        (
            _,
            linear_key,
            conv1d_key,
            ssm_key,
            out_proj_key,
        ) = jax.random.split(key, 5)

        self.in_proj = eqx.nn.Linear(
            n_embd,
            d_inner * 2,
            use_bias=use_in_projection_bias,
            key=linear_key,
        )

        # padding=d_conv - 1 plus truncation to seq_len in __call__ makes the
        # convolution causal: output t only sees inputs <= t.
        self.conv1d = eqx.nn.Conv1d(
            in_channels=d_inner,
            out_channels=d_inner,
            kernel_size=d_conv,
            use_bias=use_conv_bias,
            groups=d_inner,
            padding=d_conv - 1,
            key=conv1d_key,
        )
        self.ssm = SelectiveStateSpaceModel(
            d_inner=d_inner,
            dt_rank=dt_rank,
            d_state=d_state,
            use_delta_proj_bias=ssm_use_delta_proj_bias,
            use_input_proj_bias=ssm_use_input_proj_bias,
            key=ssm_key,
        )
        self.out_proj = eqx.nn.Linear(
            d_inner,
            n_embd,
            use_bias=use_out_proj_bias,
            key=out_proj_key,
        )

    def __call__(self, x: Array):
        """Map ``x`` of shape (seq_len, n_embd) to an output of the same shape."""
        seq_len = x.shape[0]
        x_and_res = jax.vmap(self.in_proj)(x)

        (x, res) = jnp.split(x_and_res, 2, axis=-1)
        # Conv1d expects (channels, length); keep only the first seq_len
        # outputs to stay causal.
        x = jnp.transpose(x)
        x = self.conv1d(x)[:, :seq_len]
        x = jnp.transpose(x)
        x = jax.nn.silu(x)

        y = self.ssm(x)
        # Gate the SSM output with the second branch.
        y = y * jax.nn.silu(res)

        output = jax.vmap(self.out_proj)(y)
        return output

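Now we can wrap the `MambaBlock` into the `ResidualBlock`. The input is first normalised with `RMSNorm` and passed through the `MambaBlock`. As the name suggests, the block then adds a residual connection (or in non-_sciency_ words: it adds the original input to the output of the transformation).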

In [None]:
class ResidualBlock(eqx.Module, strict=True):
    """Pre-norm residual wrapper around a ``MambaBlock``.

    Computes ``x + MambaBlock(RMSNorm(x))``, with the norm applied to each
    position of the sequence independently.
    """

    mamba_block: MambaBlock
    # RMS normalisation applied before the Mamba block (field name kept as-is).
    rns_norm: eqx.nn.RMSNorm

    def __init__(
        self,
        n_embd: int,
        d_inner: int,
        dt_rank: int,
        d_conv: int,
        use_in_projection_bias: bool = True,
        use_conv_bias: bool = True,
        use_out_proj_bias: bool = True,
        ssm_use_delta_proj_bias: bool = False,
        ssm_use_input_proj_bias: bool = False,
        *,
        key: PRNGKeyArray,
    ):
        """Create the block; every argument except ``key`` is forwarded to ``MambaBlock``."""
        self.rns_norm = eqx.nn.RMSNorm(n_embd)
        self.mamba_block = MambaBlock(
            n_embd,
            d_inner,
            dt_rank,
            d_conv,
            use_in_projection_bias=use_in_projection_bias,
            use_conv_bias=use_conv_bias,
            use_out_proj_bias=use_out_proj_bias,
            ssm_use_delta_proj_bias=ssm_use_delta_proj_bias,
            ssm_use_input_proj_bias=ssm_use_input_proj_bias,
            key=key,
        )

    def __call__(
        self, x: Float[Array, "seq_len n_embd"], *, key: Optional[PRNGKeyArray] = None
    ) -> Array:
        """Apply the block; ``key`` is accepted for ``eqx.nn.Sequential`` and unused."""
        normalised = jax.vmap(self.rns_norm)(x)
        transformed = self.mamba_block(normalised)
        return transformed + x

We've arrived at the highest point again. We can put everything into the `Mamba` class now. Note that the weights of the embedding layer and the final linear layer are shared! This is not a problem though, because we can use `eqx.nn.Shared` to implement this.

In [None]:
class Mamba(eqx.Module, strict=True):
    """Mamba language model.

    The pipeline is token embedding -> stack of ``ResidualBlock``s -> RMSNorm
    -> LM head. The embedding matrix and the LM-head weight are tied through
    ``eqx.nn.Shared``.
    """

    layers: eqx.nn.Sequential  # the stack of ResidualBlocks
    normalization: eqx.nn.RMSNorm  # final norm before the LM head
    shared_emb_lm_head: eqx.nn.Shared  # (Embedding, Linear) with tied weights

    def __init__(
        self,
        n_layers: int,
        n_dims: int,
        n_embd: int,
        d_inner: int,
        dt_rank: int,
        d_conv: int,
        use_in_projection_bias: bool = True,
        use_conv_bias: bool = True,
        use_out_proj_bias: bool = True,
        ssm_use_delta_proj_bias: bool = False,
        ssm_use_input_proj_bias: bool = False,
        *,
        key: PRNGKeyArray,
    ):
        """Create the model.

        Args:
            n_layers: Number of residual blocks.
            n_dims: Number of token ids (rows of the embedding / LM-head outputs).
            n_embd: Model (embedding) width.
            d_inner, dt_rank, d_conv, use_*_bias, ssm_use_*_bias: Forwarded to
                every ``ResidualBlock``.
            key: PRNG key for parameter initialisation.
        """
        # One key each for the embedding and the LM head, plus one per residual
        # block. This avoids reusing keys and never indexes past the split.
        embedding_key, lm_head_key, *layer_keys = jax.random.split(key, n_layers + 2)
        self.layers = eqx.nn.Sequential(
            [
                ResidualBlock(
                    n_embd=n_embd,
                    d_inner=d_inner,
                    dt_rank=dt_rank,
                    d_conv=d_conv,
                    use_in_projection_bias=use_in_projection_bias,
                    use_conv_bias=use_conv_bias,
                    use_out_proj_bias=use_out_proj_bias,
                    ssm_use_delta_proj_bias=ssm_use_delta_proj_bias,
                    ssm_use_input_proj_bias=ssm_use_input_proj_bias,
                    key=layer_key,
                )
                for layer_key in layer_keys
            ],
        )
        self.normalization = eqx.nn.RMSNorm(n_embd)

        embedding = eqx.nn.Embedding(n_dims, n_embd, key=embedding_key)
        lm_head = eqx.nn.Linear(
            n_embd,
            n_dims,
            use_bias=False,
            key=lm_head_key,
        )
        # `where` picks the node to be replaced (the LM-head weight); `get`
        # picks the node that replaces it (the embedding weight), tying them.
        where = lambda embed_and_lin: embed_and_lin[1].weight
        get = lambda embed_and_lin: embed_and_lin[0].weight
        self.shared_emb_lm_head = eqx.nn.Shared(
            (embedding, lm_head), where=where, get=get
        )

    def __call__(
        self,
        x: Int[Array, "seq_len"],  # noqa
        *,
        key: Optional[PRNGKeyArray] = None,
    ) -> Float[Array, "seq_len n_dims"]:  # noqa
        """Map a sequence of token ids to per-position logits; ``key`` is unused."""
        # Calling the Shared module returns the pair with the tied weight in place.
        embedding, linear = self.shared_emb_lm_head()
        x = jax.vmap(embedding)(x)

        x = self.layers(x)
        x = jax.vmap(self.normalization)(x)
        logits = jax.vmap(linear)(x)
        return logits

Note the usage of `eqx.nn.Shared`:

```python
    # Embedding layer
    embedding = eqx.nn.Embedding(
        n_dims, n_embd, key=subkeys[0]
    )
    # Linear layer
    lm_head = eqx.nn.Linear(
        n_embd,
        n_dims,
        use_bias=False,
        key=subkeys[-1],
    )
    # refers to the linear weights
    where = lambda embed_and_lin: embed_and_lin[1].weight 

    # refers to the embedding weights
    get = lambda embed_and_lin: embed_and_lin[0].weight

    # Create a shared layer
    self.shared_emb_lm_head = eqx.nn.Shared(
        (embedding, lm_head), where=where, get=get
    )
```

To use the shared layers, we first have to retrieve them by calling the `Shared` module:

```python
    embedding, linear = self.shared_emb_lm_head()
    # embedding and linear are eqx.nn.Embedding and eqx.nn.Linear respectively
    # proceed usage as usual
```

Excellent! We have successfully implemented the Mamba model!