In [None]:
from model_pytorch import Mamba, ModelArgs
from transformers import AutoTokenizer
import torch
import onnxruntime as ort
import numpy as np

In [None]:
# Checkpoint to load. Published options:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
pretrained_model_name = "state-spaces/mamba-130m"

# Placeholder prompt; its token ids are reused as the example input
# by the export cells further down.
dummy_input = "test"

# Load the pretrained weights and the GPT-NeoX tokenizer.
model = Mamba.from_pretrained(pretrained_model_name)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

# Tokenise the placeholder prompt into a PyTorch tensor of token ids.
dummy_encoding = tokenizer(dummy_input, return_tensors="pt")
input_ids = dummy_encoding.input_ids

In [None]:
# Export the pretrained model to ONNX and keep a pickled PyTorch copy.
export_name = "mamba_model"

# Put the module in inference mode before tracing and saving, the same way
# the minimal-model cells do; otherwise the pickled module keeps whatever
# mode `from_pretrained` left it in.
model.eval()

torch.onnx.export(
    model,
    input_ids,  # example input used to trace the graph
    f"{export_name}.onnx",
    export_params=True,
    do_constant_folding=True,
    input_names=['input_ids'],
    output_names=['output'],
    # Mark the first two axes of both tensors as dynamic so the exported
    # graph is not tied to the shape of the example prompt.
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'seq_length'},
        'output': {0: 'batch_size', 1: 'seq_length'}
    }
)

# NOTE: torch.save on a whole module pickles a reference to its class, so
# `model_pytorch` must be importable wherever this file is loaded.
torch.save(model, f"{export_name}.pt")

In [None]:
# Build a tiny single-layer model (no pretrained weights are loaded here),
# then write it out both as a pickled PyTorch module and as an ONNX graph.
export_name = "mamba_minimal_1_layer"

args = ModelArgs(d_model=5, n_layer=1, vocab_size=50277)
model_1 = Mamba(args)
model_1.eval()

# PyTorch copy first, ONNX second.
torch.save(model_1, f"{export_name}.pt")

# Both the input and the output keep their first two axes dynamic.
batch_and_seq_axes = {0: 'batch_size', 1: 'seq_length'}

torch.onnx.export(
    model_1,
    input_ids,  # example input used to trace the graph
    f"{export_name}.onnx",
    export_params=True,
    do_constant_folding=True,
    input_names=['input_ids'],
    output_names=['output'],
    dynamic_axes={
        'input_ids': dict(batch_and_seq_axes),
        'output': dict(batch_and_seq_axes),
    },
)

In [None]:
# Run the exported minimal model through ONNX Runtime as a smoke test.
# The provider is named explicitly: ORT builds that ship more than one
# execution provider refuse to create a session without it, and the input
# below is prepared on the CPU anyway.
ort_session = ort.InferenceSession(
    'mamba_minimal_1_layer.onnx',
    providers=['CPUExecutionProvider'],
)

# Generate a model dummy input
dummy_prompt_1 = "Harry test ciao bla bla"
tokens_1 = tokenizer(dummy_prompt_1, return_tensors="pt")
input_ids_1 = tokens_1.input_ids.to(device="cpu")
# ONNX Runtime consumes NumPy arrays, not torch tensors.
input_ids_np = input_ids_1.numpy()
print(input_ids_np.shape)

# Inference: feed the array under the graph's first input name.
inputs = {ort_session.get_inputs()[0].name: input_ids_np}
out = ort_session.run(None, inputs)

# Output: `run` returns one array per graph output; report the shape of the
# first one (previously the input shape was printed a second time here).
print(out[0].shape)
out


In [None]:
# Export a second tiny single-layer model through the dynamo-based ONNX
# exporter (beta), with dynamic shapes enabled.
# NOTE(review): recent PyTorch releases appear to deprecate
# `torch.onnx.dynamo_export` in favour of `torch.onnx.export(..., dynamo=True)`
# — confirm against the installed version.
export_name = "mamba_minimal_1_layer_dyn"

args = ModelArgs(d_model=5, n_layer=1, vocab_size=50277)
model_dyn = Mamba(args)
model_dyn.eval()

dyn_export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_program = torch.onnx.dynamo_export(
    model_dyn,
    input_ids,  # example input for the exporter
    export_options=dyn_export_options,
)
onnx_program.save(f"{export_name}.onnx")