In [17]:
# ignore all warnings
import warnings
warnings.filterwarnings("ignore")

In [18]:
import torch
from mamba_ssm.modules.mamba_simple import Mamba

In [16]:
# Smoke test: a single mamba_ssm `Mamba` block must map a (batch, length, dim)
# input to an output of exactly the same shape.
torch.manual_seed(0)  # reproducible random input and parameter initialisation
DEVICE = "cuda"  # NOTE(review): mamba_ssm's selective-scan kernels appear to be CUDA-only — confirm before changing

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to(DEVICE)
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to(DEVICE)
y = model(x)
assert y.shape == x.shape, f"expected output shape {tuple(x.shape)}, got {tuple(y.shape)}"

In [7]:
# Display the Mamba block's output tensor; its shape equals `x`'s (asserted where `y` is computed).
y

tensor([[[ 0.0061, -0.0075, -0.0036,  ..., -0.0224, -0.0146,  0.0187],
         [-0.0254,  0.0448, -0.0635,  ..., -0.0174,  0.0193,  0.0190],
         [ 0.0478, -0.0245, -0.0183,  ...,  0.0371, -0.0228, -0.0039],
         ...,
         [-0.0108, -0.0210,  0.0096,  ...,  0.0085,  0.0064, -0.0039],
         [ 0.0010,  0.0085,  0.0560,  ..., -0.0292, -0.0099, -0.0102],
         [ 0.0068,  0.0129,  0.0102,  ...,  0.0255,  0.0638, -0.0435]],

        [[-0.0515, -0.0012, -0.0054,  ..., -0.0310,  0.0122, -0.0046],
         [ 0.0053,  0.0221,  0.0181,  ..., -0.0531, -0.0148,  0.0050],
         [-0.0265,  0.0225, -0.0192,  ...,  0.0397,  0.0001, -0.0244],
         ...,
         [ 0.0538,  0.0222,  0.0367,  ..., -0.0277,  0.0373, -0.0108],
         [ 0.0104, -0.0272,  0.0313,  ...,  0.0040,  0.0191, -0.0087],
         [ 0.0245, -0.0115,  0.0113,  ..., -0.0025, -0.0172,  0.0341]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)

In [8]:
# Show the submodules of the mamba_ssm `Mamba` block built above. Note that `model`
# is rebound to a `MambaForCausalLM` later in the notebook, so run cells in order.
model

Mamba(
  (in_proj): Linear(in_features=16, out_features=64, bias=False)
  (conv1d): Conv1d(32, 32, kernel_size=(4,), stride=(1,), padding=(3,), groups=32)
  (act): SiLU()
  (x_proj): Linear(in_features=32, out_features=33, bias=False)
  (dt_proj): Linear(in_features=1, out_features=32, bias=True)
  (out_proj): Linear(in_features=32, out_features=16, bias=False)
)

In [9]:
# load pretrained mamba model
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch

MAMBA_CHECKPOINT = "state-spaces/mamba-130m-hf"

tokenizer = AutoTokenizer.from_pretrained(MAMBA_CHECKPOINT)
model = MambaForCausalLM.from_pretrained(MAMBA_CHECKPOINT)
inputs = tokenizer("Once upon a time?", return_tensors="pt")
input_ids = inputs["input_ids"]

# Pass the attention mask and pad id explicitly: the tokenizer's pad token equals
# its EOS token, so `generate` cannot infer the mask and warns without them.
out = model.generate(
    input_ids,
    attention_mask=inputs["attention_mask"],
    max_new_tokens=10,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.batch_decode(out))


The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


['Once upon a time?\n\nThe first time I saw the world,']


In [10]:
def get_output(prompt, max_new_tokens=10):
    """Continue `prompt` with the Mamba causal LM and return the decoded text.

    Uses the module-level `tokenizer` and `model` loaded earlier in the notebook.

    Args:
        prompt: Text to continue.
        max_new_tokens: Number of tokens to generate after the prompt (default 10).

    Returns:
        The prompt plus its continuation, decoded as a single string
        (special tokens are kept, matching `batch_decode`'s default).
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    out = model.generate(
        inputs["input_ids"],
        # Explicit mask and pad id: pad == EOS for this tokenizer, so they cannot be inferred.
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.batch_decode(out)[0]

get_output("Hey how are you doing?")

"Hey how are you doing?\n\nI'm so glad you're here."

In [11]:
# Probe simple arithmetic; the 130M model continues with LaTeX-style text rather than a single answer.
get_output("1+1=")

'1+1=2$ and $1+1=3$'

In [13]:
from transformers import AutoModelForCausalLM, AutoTokenizer

pythia_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")
pythia_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")

def get_output_pythia(prompt, max_new_tokens=10):
    """Continue `prompt` with Pythia-160M and return the decoded text.

    Counterpart of `get_output` so both models can be compared on identical prompts.

    Args:
        prompt: Text to continue.
        max_new_tokens: Number of tokens to generate after the prompt (default 10).

    Returns:
        The prompt plus its continuation, decoded as a single string.
    """
    inputs = pythia_tokenizer(prompt, return_tensors="pt")
    out = pythia_model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        # Pythia defines no pad token; reuse EOS explicitly instead of the warning-path default.
        pad_token_id=pythia_tokenizer.eos_token_id,
    )
    # Decode with Pythia's own tokenizer (previously the Mamba `tokenizer` was used by mistake).
    return pythia_tokenizer.batch_decode(out)[0]

get_output_pythia("1+1=")

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


'1+1=0$ and $1+1=1$.'

In [7]:
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

def generate_text(model, tokenizer, prompt, max_length=100):
    """Generate a continuation of `prompt` and time the `generate` call.

    Args:
        model: A model exposing a Hugging Face style `generate(**inputs, ...)`.
        tokenizer: The matching tokenizer (callable, with `decode`,
            `pad_token_id` and `eos_token_id`).
        prompt: Text to continue.
        max_length: Total token budget passed to `generate`, i.e. prompt
            tokens plus generated tokens.

    Returns:
        `(generated_text, generation_time)`. `generated_text` is the first
        sequence decoded without special tokens. `generation_time` is the
        duration of `generate` alone in seconds; tokenization and decoding
        are excluded.
    """
    inputs = tokenizer(prompt, return_tensors="pt")

    # Models such as Pythia ship without a pad token; fall back to EOS so
    # `generate` has an explicit pad id instead of warning and guessing.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    start = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=max_length, pad_token_id=pad_token_id)
    generation_time = time.perf_counter() - start

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text, generation_time

def main():
    """Generate from one prompt with Mamba-130M and Pythia-160M, then print times and texts.

    Both models get the same total token budget (`max_length`, prompt
    included). Each printed text is followed by its length in characters.

    Caveat: the printed times are not a like-for-like speed comparison.
    Neither model gets a warm-up run. In this environment transformers also
    reported that the Mamba fast path was unavailable and fell back to the
    sequential implementation.
    """
    # Local import keeps this cell runnable on a fresh kernel, even when the
    # earlier cell that imports MambaForCausalLM has not been executed.
    from transformers import MambaForCausalLM

    prompt = "Once upon a time, in a land far away,"
    max_length = 50

    # Mamba 130M model
    mamba_model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
    mamba_tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")

    # Pythia-160M model
    pythia_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")
    pythia_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")

    print("Generating text with Mamba 130M...")
    mamba_text, mamba_time = generate_text(mamba_model, mamba_tokenizer, prompt, max_length)

    print("Generating text with Pythia-160M...")
    pythia_text, pythia_time = generate_text(pythia_model, pythia_tokenizer, prompt, max_length)

    print("\nResults:")
    print(f"Mamba 130M generation time: {mamba_time:.4f} seconds")
    print(f"Pythia-160M generation time: {pythia_time:.4f} seconds")
    # Put the character count on its own labelled line; appended after a comma
    # it read like part of the generated text.
    print(f"\nMamba 130M generated text:\n{mamba_text}\n({len(mamba_text)} characters)")
    print(f"\nPythia-160M generated text:\n{pythia_text}\n({len(pythia_text)} characters)")

# Inside a notebook kernel `__name__` is "__main__", so executing this cell runs the comparison.
if __name__ == "__main__":
    main()

Generating text with Mamba 130M...


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Generating text with Pythia-160M...

Results:
Mamba 130M generation time: 18.1766 seconds
Pythia-160M generation time: 4.0884 seconds

Mamba 130M generated text:
Once upon a time, in a land far away, there lived a man who was a great hunter. He had a great hunting lodge, and he hunted for his prey. He hunted for his prey, and he killed many of them. He was, 196

Pythia-160M generated text:
Once upon a time, in a land far away, the

government of the United States of America was established.

The United States was a nation of the United States, and the

United States was a nation of the United States, 213
