In [25]:
import torch
from mamba_ssm.modules.mamba_simple import Mamba

In [26]:
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape

In [27]:
y

tensor([[[-1.3734e-02, -2.1213e-03,  3.9866e-03,  ..., -2.4578e-02,
          -8.9588e-03, -1.9662e-02],
         [ 4.9733e-02,  2.0208e-02, -7.2006e-04,  ...,  7.7961e-05,
           2.0560e-02, -3.7630e-02],
         [-9.8380e-02,  6.8168e-02,  6.5655e-03,  ..., -6.3021e-02,
          -2.4585e-02, -2.5683e-02],
         ...,
         [ 8.0266e-02,  5.1947e-02,  7.4426e-02,  ...,  7.4200e-02,
           5.6962e-02, -4.3922e-02],
         [ 2.9499e-02,  5.9678e-04,  2.3381e-02,  ...,  4.3481e-03,
          -5.3622e-02,  2.4644e-02],
         [ 5.8308e-02,  2.7924e-02,  5.9399e-02,  ..., -2.2119e-02,
           2.2731e-02,  2.8583e-03]],

        [[ 1.5607e-02, -3.3724e-04, -9.1325e-03,  ..., -1.6032e-02,
          -1.0729e-02, -5.7896e-05],
         [-1.2276e-02, -1.6701e-02,  7.6755e-03,  ..., -1.3373e-02,
           4.0862e-02,  3.6515e-03],
         [ 9.5950e-03,  2.3054e-03,  2.2636e-02,  ...,  2.2525e-02,
          -8.1435e-03, -5.7061e-03],
         ...,
         [ 1.3179e-02,  4

In [28]:
model

Mamba(
  (in_proj): Linear(in_features=16, out_features=64, bias=False)
  (conv1d): Conv1d(32, 32, kernel_size=(4,), stride=(1,), padding=(3,), groups=32)
  (act): SiLU()
  (x_proj): Linear(in_features=32, out_features=33, bias=False)
  (dt_proj): Linear(in_features=1, out_features=32, bias=True)
  (out_proj): Linear(in_features=32, out_features=16, bias=False)
)

In [29]:
# load pretrained mamba model
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
mamba_model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
input_ids = tokenizer("Once upon a time?", return_tensors="pt")["input_ids"]

out = mamba_model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))


['Once upon a time?\n\nThe first time I saw the world,']


In [30]:
def get_output_mamba(prompt):
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    out = mamba_model.generate(input_ids, max_new_tokens=10)
    return tokenizer.batch_decode(out)[0]

get_output_mamba("Hey how are you doing?")

"Hey how are you doing?\n\nI'm so glad you're here."

In [31]:
get_output_mamba("1+1=")

'1+1=2$ and $1+1=3$'

In [32]:
from transformers import AutoModelForCausalLM, AutoTokenizer

pythia_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")
pythia_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")

def get_output_pythia(prompt):
    input_ids = pythia_tokenizer(prompt, return_tensors="pt")["input_ids"]
    out = pythia_model.generate(input_ids, max_new_tokens=10)
    return tokenizer.batch_decode(out)[0]

get_output_pythia("1+1=")

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


'1+1=0$ and $1+1=1$.'