### Export Mamba models to ONNX and save PyTorch models
#### Requirements

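- Install PyTorch: https://pytorch.org/get-started/locally/
- Install the remaining dependencies from `requirements.txt` (first code cell below; it only needs to be run once).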

In [None]:
#Required once
%pip install --user -qqr requirements.txt

In [4]:
#Required

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import torch
from transformers import AutoTokenizer

from mamba_ssm.onnx.model_wrapper import ModelWrapper, BlockModelWrapper, MambaModelWrapper

import onnxruntime as ort
import numpy as np

# Config
# Seed torch first so the not-pretrained models and torch.randn example inputs
# created further down are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model_name = "state-spaces/mamba-130m"  # checkpoint name passed to the pretrained loaders
device = "cpu"
dtype = torch.float32

In [None]:
# Tokenize a short prompt to serve as the example input for tracing/export.
# The EOS token is set explicitly and then reused as the padding token.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token

dummy_prompt = "Harry Potter"
tokens = tokenizer(dummy_prompt, return_tensors="pt")
input_ids = tokens["input_ids"].to(device)

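### Model
Run the cells below to export `MambaLMHeadModel` to ONNX. First initialize either the pretrained model or the custom (not pretrained) model. The export cell uses whichever `model` was created last.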

In [11]:
# Wrap the pretrained checkpoint named in the config cell.
# NOTE(review): use_generation=False is forwarded to ModelWrapper; presumably it
# selects a plain forward pass instead of generation — confirm in model_wrapper.py.
model = ModelWrapper(
    model_name=model_name,
    use_generation=False,
    device=device,
    dtype=dtype,
)

Number of layers: 24
Size of d: 768
Number of parameters: 129135360


In [7]:
# Custom, not-pretrained model built from a small config (1 layer, d_model=200).
config = MambaConfig()
config.n_layer = 1
config.d_model = 200
model = ModelWrapper(
    model_name=None,  # no checkpoint name; the model comes from `config` instead
    use_generation=False,
    config=config,
    device=device,
    dtype=dtype,
)

Number of layers: 1
Size of d: 200
Number of parameters: 10328800


In [None]:
onnx_model_path = "model.onnx"

# Batch and sequence dimensions are marked symbolic in the exported graph
# rather than being fixed to the shape of the tracing input `input_ids`.
dynamic_axes = {
    "input_ids": {0: "batch_size", 1: "seq_length"},
    "output": {0: "batch_size", 1: "seq_length"},
}

# Export to ONNX
torch.onnx.export(
    model,
    input_ids,
    onnx_model_path,
    verbose=False,
    input_names=["input_ids"],
    output_names=["output"],
    dynamic_axes=dynamic_axes,
)

# Also keep a pickled copy of the PyTorch wrapper next to the ONNX file.
torch.save(model, "model_wrapper.pt")

print(f"Model exported in {onnx_model_path}")


In [4]:
# Pickle whatever `model` currently refers to (same file the export cell writes).
torch.save(obj=model, f="model_wrapper.pt")

Run to save the pretrained PyTorch model:

In [5]:
# Load the pretrained checkpoint as a plain MambaLMHeadModel (no wrapper) and pickle it.
model = MambaLMHeadModel.from_pretrained(
    model_name,
    device=device,
    dtype=dtype,
)
torch.save(obj=model, f="model_original_pretrained.pt")

Run to save a custom (not pretrained) PyTorch model:

In [2]:
# Two-layer custom model (d_model=200) built from a config rather than a checkpoint.
config = MambaConfig()
config.n_layer = 2
config.d_model = 200
model = MambaLMHeadModel(
    config=config,
    device=device,
    dtype=dtype,
)
torch.save(obj=model, f="model_custom.pt")

### Block layer
Run to export a single block layer to ONNX:

In [None]:
# Export a single block layer. Tracing needs concrete example tensors, so build
# random hidden states plus an all-zero residual of a fixed example shape.
config = MambaConfig()
config.d_model = 5
config.n_layer = 1
block_model_wrapper = BlockModelWrapper(config=config, device=device, dtype=dtype)

batch_size, seq_length = 1, 10  # example shape used only for tracing
# Create the example inputs on the configured device/dtype (the same values the
# wrapper was built with) instead of always CPU + torch's default dtype.
hidden_states = torch.randn(batch_size, seq_length, config.d_model, device=device, dtype=dtype)
residual = torch.zeros_like(hidden_states)  # inherits device and dtype from hidden_states

torch.onnx.export(
    block_model_wrapper,
    (input_ids, hidden_states, residual),
    'block_model.onnx',
    input_names=['input_ids', 'hidden_states', 'residual'],
    output_names=['output'],
)
torch.save(block_model_wrapper, "block_model.pt")

### Mamba layer
Run to export the Mamba layer to ONNX:

In [None]:
# Export the Mamba layer wrapper (tiny config: one layer, d_model=5), marking
# the batch and sequence dimensions as dynamic in the ONNX graph.
config = MambaConfig()
config.n_layer = 1
config.d_model = 5
block_model_wrapper = MambaModelWrapper(config=config, device=device, dtype=dtype)

mamba_dynamic_axes = {
    'input_ids': {0: 'batch_size', 1: 'seq_length'},
    'output': {0: 'batch_size', 1: 'seq_length'},
}
torch.onnx.export(
    block_model_wrapper,
    input_ids,
    'mamba_model.onnx',
    export_params=True,  # store the trained parameters inside the .onnx file
    do_constant_folding=True,
    input_names=['input_ids'],
    output_names=['output'],
    dynamic_axes=mamba_dynamic_axes,
)

torch.save(block_model_wrapper, "mamba_model.pt")

In [None]:
# Sanity-check a freshly built Mamba wrapper in PyTorch on a short prompt.
# NOTE(review): this wrapper is constructed anew here, so (unless construction is
# deterministic) its weights are not the ones exported to mamba_model.onnx above;
# do not expect its output to match the ONNX Runtime cell.
config = MambaConfig()
config.d_model = 5
config.n_layer = 1
block_model_wrapper = MambaModelWrapper(config=config, device=device, dtype=dtype)
block_model_wrapper.eval()

# Generate a model dummy input
dummy_prompt_1 = "Harry test ciao"
tokens_1 = tokenizer(dummy_prompt_1, return_tensors="pt")
input_ids_1 = tokens_1.input_ids.to(device=device)

# Pure inference: disable autograd so no graph is recorded for the output.
with torch.no_grad():
    out = block_model_wrapper(input_ids_1)

print("output", out.shape)
out

In [None]:
# Run the exported mamba_model.onnx with ONNX Runtime on a short prompt.
session_options = ort.SessionOptions()
session_options.log_severity_level = 0  # 0 = VERBOSE; higher numbers log less (4 = FATAL only)
session_options.log_verbosity_level = 5  # VLOG level; only used when severity is 0 (VERBOSE) in debug builds

ort_session = ort.InferenceSession('mamba_model.onnx', sess_options=session_options)

# Generate a model dummy input
dummy_prompt_1 = "Harry"
tokens_1 = tokenizer(dummy_prompt_1, return_tensors="pt")
input_ids_1 = tokens_1.input_ids.to(device=device)
# ONNX Runtime consumes NumPy arrays. Copy to host memory first so this also
# works when `device` is a GPU; np.array() cannot convert a CUDA tensor.
input_ids_np = input_ids_1.cpu().numpy()
print(input_ids_np.shape)

# Inference: feed by the graph's declared input name; None = fetch every output.
inputs = {ort_session.get_inputs()[0].name: input_ids_np}
out = ort_session.run(None, inputs)

# Output: `out` is a list with one array per graph output, so the stacked shape
# has a leading dimension equal to the number of outputs.
print(np.array(out).shape)
out



In [None]:
# Export the same tiny Mamba layer with the dynamo-based exporter, requesting
# dynamic input shapes, and write the program to mamba_dyn.onnx.
# NOTE(review): newer PyTorch releases deprecate dynamo_export in favour of
# torch.onnx.export(..., dynamo=True) — check against the installed version.
config = MambaConfig()
config.n_layer = 1
config.d_model = 5
block_model_wrapper = MambaModelWrapper(config=config, device=device, dtype=dtype)

dynamo_export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_program = torch.onnx.dynamo_export(
    block_model_wrapper,
    input_ids,
    export_options=dynamo_export_options,
)
onnx_program.save("mamba_dyn.onnx")