### Export to ONNX and save models notebook
#### Requirements

- Install PyTorch: https://pytorch.org/get-started/locally/

In [None]:
# Required once: install this notebook's dependencies from requirements.txt.
# `%pip` is used so the packages go into the environment of the running kernel;
# `-qq` keeps pip's output quiet and `-r` reads the package list from the file.
# NOTE(review): pip may refuse `--user` inside a virtualenv — confirm and drop the flag there if so.
%pip install --user -qqr requirements.txt

In [None]:
#Required

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import torch
from transformers import AutoTokenizer

from mamba_ssm.onnx.model_wrapper import ModelWrapper, BlockModelWrapper

# Notebook-wide settings, read by the cells below.
model_name = "state-spaces/mamba-130m"  # checkpoint passed to the pretrained loaders
device, dtype = "cpu", torch.float32  # where tensors live and their precision

Run the cells below to export `MambaLMHeadModel` to ONNX:

In [None]:
# Option A: wrap the pretrained checkpoint named by `model_name`.
model = ModelWrapper(
    model_name=model_name,
    use_generation=False,
    device=device,
    dtype=dtype,
)

In [None]:
# Option B: wrap a small model built from a custom config (not pretrained).
# Running this cell replaces any `model` created by an earlier cell.
config = MambaConfig()
config.d_model, config.n_layer = 200, 1
model = ModelWrapper(
    model_name=None,
    use_generation=False,
    config=config,
    device=device,
    dtype=dtype,
)

In [None]:
# Generate a model input: tokenize a short prompt to get example `input_ids`.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
dummy_prompt = "Hello, world!"
tokens = tokenizer(dummy_prompt, return_tensors="pt")
input_ids = tokens.input_ids.to(device=device)

onnx_model_path = "model.onnx"

# Export to ONNX.
# NOTE(review): only axis 0 (batch) is declared dynamic, so every other
# dimension is exported with the shape of this dummy prompt — confirm a fixed
# sequence length is acceptable for consumers of the ONNX file.
torch.onnx.export(
    model,
    (input_ids,),  # trailing comma: a real one-element tuple of positional inputs
    onnx_model_path,
    verbose=False,
    input_names=['input_ids'],
    output_names=['output'],
    dynamic_axes={'input_ids': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
)

# Also keep the wrapper itself as a PyTorch file next to the ONNX export.
torch.save(model, "model_wrapper.pt")

print(f"Model exported in {onnx_model_path}")


In [4]:
# Serialize the current wrapper object (the whole module, not just its weights).
torch.save(obj=model, f="model_wrapper.pt")

Run the cell below to save the pretrained PyTorch model:

In [5]:
# Load the pretrained language model and write the whole module to disk.
# This rebinds `model`, replacing any wrapper created by earlier cells.
model = MambaLMHeadModel.from_pretrained(
    model_name,
    device=device,
    dtype=dtype,
)
torch.save(obj=model, f="model_original_pretrained.pt")

Run the cell below to save a custom (not pretrained) model:

In [2]:
# Build a small two-layer model from a custom config (no pretrained weights
# are loaded here) and write the whole module to disk.
config = MambaConfig()
config.d_model, config.n_layer = 200, 2
model = MambaLMHeadModel(
    config=config,
    device=device,
    dtype=dtype,
)
torch.save(obj=model, f="model_custom.pt")

In [5]:
# Export the block-level wrapper to ONNX using a tiny custom config, then save
# the wrapper itself as a PyTorch file.
config = MambaConfig()
config.d_model = 5
config.n_layer = 1
block_model_wrapper = BlockModelWrapper(config=config, device=device, dtype=dtype)

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
dummy_prompt = "Hello, world!"
tokens = tokenizer(dummy_prompt, return_tensors="pt")

input_ids = tokens.input_ids.to(device=device)

# Dummy activations: batch size = 1, seq length = 10. They are created with the
# notebook-level `device`/`dtype` so they match the wrapper and `input_ids`
# instead of being pinned to CPU.
# NOTE(review): `input_ids` takes its length from the tokenized prompt while
# `hidden_states` uses a fixed length of 10 — confirm the wrapper does not
# require the two to agree.
hidden_states = torch.randn(1, 10, config.d_model, device=device, dtype=dtype)

# zeros_like already inherits shape, device and dtype from `hidden_states`.
residual = torch.zeros_like(hidden_states)

torch.onnx.export(
    block_model_wrapper,
    (input_ids, hidden_states, residual),
    'block_model.onnx',
    input_names=['input_ids', 'hidden_states', 'residual'],
    output_names=['output'],
)
torch.save(block_model_wrapper, "block_model.pt")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Size of d: 5
Number of parameters: 725
