In [15]:
# Restrict this kernel to GPU 2. This cell must run before `import jax` in the
# next cell: CUDA reads CUDA_VISIBLE_DEVICES when the runtime initializes, so
# changing it after JAX has already touched the GPU has no effect.
#
# Note: if the CUDA toolkit is not on PATH, the magic
#     %env PATH=/usr/local/cuda-12.1/bin:$PATH
# can be run in its own cell first.
import os

os.environ.update(CUDA_VISIBLE_DEVICES="2")

env: PATH=/usr/local/cuda-12.1/bin:$PATH


In [16]:
import jax
import jax.numpy as np
from einops import rearrange, repeat
from flax import linen as nn
from flax.linen.initializers import normal as flax_normal
from jax.nn.initializers import lecun_normal, normal
from jax.scipy.linalg import block_diag

from __future__ import annotations
from typing import Union
import math
import json
from dataclasses import dataclass
from einops import rearrange, repeat, einsum

In [17]:
@dataclass
class ModelArgs:
    """Hyper-parameters for a Mamba model.

    Three values are derived in ``__post_init__``:
      * ``d_inner``    -- width of the expanded inner stream, ``int(expand * d_model)``
                          (an attribute set after construction, not a dataclass field).
      * ``dt_rank``    -- if left as ``'auto'``, replaced by ``ceil(d_model / 16)``;
                          any other value is kept unchanged.
      * ``vocab_size`` -- rounded up to the next multiple of
                          ``pad_vocab_size_multiple`` (unchanged if already a multiple).
    """
    d_model: int
    n_layer: int
    vocab_size: int
    d_state: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'
    d_conv: int = 4
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False

    def __post_init__(self):
        # Expanded width used by the inner projections.
        self.d_inner = int(self.expand * self.d_model)

        # Resolve the 'auto' sentinel for the low-rank delta projection.
        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

        # Pad the vocabulary up to a multiple of `pad_vocab_size_multiple`.
        step = self.pad_vocab_size_multiple
        overshoot = self.vocab_size % step
        if overshoot != 0:
            self.vocab_size = self.vocab_size + (step - overshoot)



In [18]:

class MambaBlock(nn.Module):
    """A single Mamba block (Figure 3, Section 3.4 of the Mamba paper [1]) in Flax Linen.

    Shape names used in the docstrings below:
        b    -- batch size
        l    -- sequence length
        d    -- ``args.d_model``
        d_in -- ``args.d_inner``
        n    -- ``args.d_state``
    """
    args: ModelArgs  # model hyper-parameters (see the ModelArgs dataclass)

    def setup(self):
        args = self.args

        # Flax Dense/Conv infer their input width on first call: the first argument is
        # the OUTPUT width (`features`), and the bias flag is spelled `use_bias`.
        self.in_proj = nn.Dense(args.d_inner * 2, use_bias=args.bias, kernel_init=normal())

        # Depthwise (one group per channel) causal 1-D convolution along the length axis.
        # Flax convolutions are channels-last, so inputs stay (b, l, d_in); padding only
        # the left side by d_conv - 1 keeps the output length equal to l.
        self.conv1d = nn.Conv(
            features=args.d_inner,
            kernel_size=(args.d_conv,),
            padding=((args.d_conv - 1, 0),),
            feature_group_count=args.d_inner,
            use_bias=args.conv_bias,
        )

        # Per-position projection to (low-rank delta, B, C).
        self.x_proj = nn.Dense(args.dt_rank + args.d_state * 2, use_bias=False)
        # Lifts the low-rank delta back up to d_in channels.
        self.dt_proj = nn.Dense(args.d_inner, use_bias=True)

        def _init_A_log(rng, shape, dtype=np.float32):
            # Every row is log([1, 2, ..., n]), i.e. all d_in channels start from the same A.
            d_in, n = shape
            A = np.tile(np.arange(1, n + 1, dtype=dtype), (d_in, 1))
            return np.log(A)

        self.A_log = self.param('A_log', _init_A_log, (args.d_inner, args.d_state))
        self.D = self.param('D', nn.initializers.ones, (args.d_inner,))

        self.out_proj = nn.Dense(args.d_model, use_bias=args.bias, kernel_init=normal())

    def __call__(self, x):
        """Mamba block forward. This follows Figure 3 in Section 3.4 of the Mamba paper [1].

        Args:
            x: shape (b, l, d)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        l = x.shape[1]

        x_and_res = self.in_proj(x)  # (b, l, 2 * d_in)
        x, res = np.split(x_and_res, [self.args.d_inner], axis=-1)  # each (b, l, d_in)

        # Channels-last conv, so no transpose is needed. With the causal padding from
        # setup the slice is a no-op; it keeps the output at length l regardless.
        x = self.conv1d(x)[:, :l, :]
        x = jax.nn.silu(x)

        y = self.ssm(x)  # (b, l, d_in)
        y = y * jax.nn.silu(res)  # gate with the second half of in_proj's output

        return self.out_proj(y)  # (b, l, d)

    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        d_in, n = self.A_log.shape
        dt_rank = self.args.dt_rank

        # The parameter stores log(-A), so A = -exp(A_log) is strictly negative.
        A = -np.exp(self.A_log.astype(np.float32))  # (d_in, n)
        D = self.D.astype(np.float32)  # (d_in,)

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2 * n)
        delta, B, C = np.split(x_dbl, [dt_rank, dt_rank + n], axis=-1)
        # softplus keeps the per-channel step size positive.
        delta = jax.nn.softplus(self.dt_proj(delta))  # (b, l, d_in)

        return self.selective_scan(x, delta, A, B, C, D)

    def selective_scan(self, u, delta, A, B, C, D):
        """Does selective scan algorithm. See:
            - Section 2 State Space Models in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        This is the classic discrete state space formula:
            x(t + 1) = Ax(t) + Bu(t)
            y(t)     = Cx(t) + Du(t)
        except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).

        Args:
            u: shape (b, l, d_in)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: some parts of `selective_scan_ref` are refactored out, so the functionality doesn't match exactly.
        """
        (b, l, d_in) = u.shape
        n = A.shape[1]

        # Discretize continuous parameters (A, B)
        # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
        # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
        #   "A is the more important term and the performance doesn't change much with the simplification on B"
        deltaA = np.exp(np.einsum('bld,dn->bldn', delta, A))  # (b, l, d_in, n)
        deltaB_u = np.einsum('bld,bln,bld->bldn', delta, B, u)  # (b, l, d_in, n)

        # lax.scan iterates over the leading axis, so move time to the front.
        deltaA_t = np.moveaxis(deltaA, 1, 0)  # (l, b, d_in, n)
        deltaB_u_t = np.moveaxis(deltaB_u, 1, 0)  # (l, b, d_in, n)
        C_t = np.moveaxis(C, 1, 0)  # (l, b, n)

        def step(x, inputs):
            dA, dBu, c = inputs
            # State update is elementwise over (b, d_in, n).
            x = dA * x + dBu
            y = np.einsum('bdn,bn->bd', x, c)  # (b, d_in)
            return x, y

        # The carry dtype must match what `step` returns, or lax.scan rejects it.
        x_init = np.zeros((b, d_in, n), dtype=np.result_type(deltaA, deltaB_u))
        _, ys = jax.lax.scan(step, x_init, (deltaA_t, deltaB_u_t, C_t))  # ys: (l, b, d_in)

        # Back to batch-major before adding the D*u skip term.
        y = np.moveaxis(ys, 0, 1)  # (b, l, d_in)
        return y + u * D
