In [1]:
from model_pytorch import Mamba, ModelArgs
from transformers import AutoTokenizer
import torch
import onnxruntime as ort
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
# One of:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
pretrained_model_name = 'state-spaces/mamba-130m'
dummy_input = "test"

model = Mamba.from_pretrained(pretrained_model_name)
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
input_ids = tokenizer(dummy_input, return_tensors='pt').input_ids

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [13]:
#Export model
export_name = "mamba_model"

torch.onnx.export(
    model,
    input_ids,  
    f"{export_name}.onnx",
    export_params=True,
    do_constant_folding=True,
    input_names=['input_ids'],
    output_names=['output'],
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'seq_length'},  
        'output': {0: 'batch_size', 1: 'seq_length'}
    }
)
torch.save(model, f"{export_name}.pt")

  outputs = orig_method(*args, **kwargs)
  outputs = orig_method(*args, **kwargs)


In [3]:
# Export minimal model

args = ModelArgs(
    d_model=5,
    n_layer=1,
    vocab_size=50277
)
model_1 = Mamba(args)
model_1.eval()
export_name = "mamba_minimal_1_layer"

torch.save(model_1, f"{export_name}.pt")

torch.onnx.export(
    model_1,
    input_ids,  
    f"{export_name}.onnx",
    export_params=True,
    do_constant_folding=True,
    input_names=['input_ids'],
    output_names=['output'],
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'seq_length'},  
        'output': {0: 'batch_size', 1: 'seq_length'}
    }
)

  outputs = orig_method(*args, **kwargs)
  outputs = orig_method(*args, **kwargs)


In [14]:
ort_session = ort.InferenceSession('mamba_model.onnx')

# Generate a model dummy input
dummy_prompt_1 = "Harry test ciao"  
tokens_1 = tokenizer(dummy_prompt_1, return_tensors="pt")
input_ids_1 = tokens_1.input_ids.to(device="cpu")
input_ids_np = np.array(input_ids_1)
print(input_ids_np.shape)

# Inference
inputs = {ort_session.get_inputs()[0].name: input_ids_np}
out = ort_session.run(None, inputs)

# Output
print(input_ids_np.shape)
out


(1, 4)
(1, 4)


[array([[[47.18734  , 37.300926 , 49.27418  , ..., 37.411118 ,
          37.17351  , 37.505386 ],
         [17.432875 ,  6.7534537, 17.546019 , ...,  6.6828403,
           6.663456 ,  6.74733  ],
         [-5.2570963, -8.190227 , -5.680652 , ..., -8.190908 ,
          -8.089128 , -8.0610895],
         [50.62617  , 37.7902   , 52.534973 , ..., 37.788456 ,
          37.55665  , 37.887684 ]]], dtype=float32)]

In [24]:
import torch
import torch.nn.functional as F

ort_session = ort.InferenceSession('mamba_model.onnx')

def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 10,
             sample: bool = True,
             top_k: int = 40):
    model.eval()
    
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
        
    for token_n in range(n_tokens_to_gen):
        with torch.no_grad():
            indices_to_input = input_ids
            inputs = {ort_session.get_inputs()[0].name: np.array(indices_to_input)}
            # Utilizzare np.squeeze per rimuovere le dimensioni singole iniziali o specifiche
            output_array = np.array(ort_session.run(None, inputs))
            output_tensor = torch.from_numpy(output_array).squeeze(0)  # Rimuove la dimensione extra in posizione 0
            next_token_logits = output_tensor[:, -1, :]  # Seleziona l'ultimo token generato
            # La dimensione ora sarà [1, 50280], corrispondente all'ultimo set di logits
            print(next_token_logits.shape)

        probs = F.softmax(next_token_logits, dim=-1)
        (batch, vocab_size) = probs.shape
        
        if top_k is not None:
            (values, indices) = torch.topk(probs, k=top_k)
            probs[probs < values[:, -1, None]] = 0
            probs = probs / probs.sum(axis=1, keepdims=True)
        
        if sample:
            next_indices = torch.multinomial(probs, num_samples=1)
        else:
            next_indices = torch.argmax(probs, dim=-1)[:, None]
        
        input_ids = torch.cat([input_ids, next_indices], dim=1)

    output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]
    
    return output_completions

In [25]:
print(generate(model=model, tokenizer=tokenizer, prompt="Harry Potter is "))

torch.Size([1, 50280])
torch.Size([1, 50280])
torch.Size([1, 50280])
torch.Size([1, 50280])
torch.Size([1, 50280])
torch.Size([1, 50280])
torch.Size([1, 50280])
torch.Size([1, 50280])
torch.Size([1, 50280])
torch.Size([1, 50280])
Harry Potter is  LINEAR AND DO.

    disappe disappe disappe disappe


In [None]:
# Export dynamic model (Beta)

args = ModelArgs(
    d_model=5,
    n_layer=1,
    vocab_size=50277
)
model_dyn = Mamba(args)
model_dyn.eval()
export_name = "mamba_minimal_1_layer_dyn"

torch.onnx.dynamo_export(
    model_dyn,
    input_ids,
    export_options=torch.onnx.ExportOptions(dynamic_shapes=True)
).save(f"{export_name}.onnx")