## Init

In [10]:
from model_pytorch import Mamba, ModelArgs
from transformers import AutoTokenizer
import torch.nn.functional as F
import onnxruntime as ort
import numpy as np
import onnx
import torch

In [None]:
# Pretrained checkpoint to load. One of:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
pretrained_model_name = 'state-spaces/mamba-130m'
dummy_input = "Harry Potter"

# Seed torch's RNG so the sampled generations (torch.multinomial) and the
# randomly initialised minimal models further down are reproducible under
# Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

model = Mamba.from_pretrained(pretrained_model_name)
# Inference only: put the module in eval mode before exporting / comparing
# (harmless if the model has no train-mode-specific layers).
model.eval()
# This notebook pairs the Mamba checkpoints with the GPT-NeoX-20B tokenizer.
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
# Token ids of the dummy prompt; used as the tracing input for the ONNX exports.
input_ids = tokenizer(dummy_input, return_tensors='pt').input_ids

## Exporting

In [7]:
# Export the pretrained model to ONNX with dynamic batch and sequence axes on
# both the input and the output, then pickle the PyTorch module alongside it.
export_name = "mamba_model_130m_cumsum_no_einsum"

# Both tensors share the same dynamic axes: 0 = batch, 1 = sequence.
export_dynamic_axes = {
    tensor_name: {0: 'batch_size', 1: 'seq_length'}
    for tensor_name in ('input_ids', 'output')
}

torch.onnx.export(
    model,
    input_ids,
    f"{export_name}.onnx",
    export_params=True,
    do_constant_folding=True,
    input_names=['input_ids'],
    output_names=['output'],
    dynamic_axes=export_dynamic_axes,
)
torch.save(model, f"{export_name}.pt")

In [None]:
# Export a tiny, randomly initialised 1-layer Mamba (d_model=5) to both
# PyTorch and ONNX, as a lightweight artefact for inspecting the graph.
#
# The module is bound to `model_minimal`, NOT `model`: rebinding `model` here
# would silently replace the pretrained 130m model, which the PyTorch-vs-ONNX
# comparison and the PyTorch text-generation cells below rely on when the
# notebook is run top to bottom.
args = ModelArgs(
    d_model=5,
    n_layer=1,
    vocab_size=50277,  # presumably chosen to match the tokenizer's vocab — verify
)
model_minimal = Mamba(args)
model_minimal.eval()
export_name = "mamba_minimal_1_layer_cumsum_no_einsum"

torch.save(model_minimal, f"{export_name}.pt")

torch.onnx.export(
    model_minimal,
    input_ids,
    f"{export_name}.onnx",
    export_params=True,
    do_constant_folding=True,
    input_names=['input_ids'],
    output_names=['output'],
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'seq_length'},
        'output': {0: 'batch_size', 1: 'seq_length'}
    }
)

## Checking the models

In [12]:
# Load the exported 130m ONNX graph and validate its structure;
# onnx.checker.check_model raises on an invalid model.
model_onnx = onnx.load("mamba_model_130m_cumsum_no_einsum.onnx")
onnx.checker.check_model(model_onnx)

# Uncomment for a human-readable dump of the graph:
# print(onnx.helper.printable_graph(model_onnx.graph))

In [13]:
# Run the exported 130m graph with ONNX Runtime on a short prompt.
ort_session = ort.InferenceSession('mamba_model_130m_cumsum_no_einsum.onnx')

dummy_prompt_1 = "Harry Potter test"
tokens_1 = tokenizer(dummy_prompt_1, return_tensors="pt")
input_ids_1 = tokens_1.input_ids.to(device="cpu")
# ONNX Runtime consumes NumPy arrays rather than torch tensors.
input_ids_np = np.array(input_ids_1)
print(input_ids_np.shape)

# Feed the prompt under the graph's first input name; `None` fetches every output.
ort_input_name = ort_session.get_inputs()[0].name
inputs = {ort_input_name: input_ids_np}
onnx_out = ort_session.run(None, inputs)

onnx_out


(1, 3)


[array([[[ -8.417368 , -22.110765 ,  -2.4200068, ..., -21.946516 ,
          -21.980217 , -21.913406 ],
         [  0.3511305, -26.0508   ,   1.545126 , ..., -25.823263 ,
          -26.040113 , -25.818874 ],
         [-40.83871  , -53.83088  , -38.9781   , ..., -53.959682 ,
          -53.70117  , -53.945526 ]]], dtype=float32)]

### Comparing PyTorch and ONNX inference

In [14]:
# Compare PyTorch and ONNX Runtime outputs for the same prompt.
# NOTE: relies on `model` still being the pretrained 130m model exported above.
with torch.no_grad():  # pure inference: don't build an autograd graph
    torch_out = model(input_ids_1)
print(torch_out)

try:
    np.testing.assert_allclose(
        torch_out.detach().cpu().numpy(), onnx_out[0], rtol=1e-02, atol=1e-03
    )
    print("Exported model has been tested with ONNXRuntime, and the result looks good!")
except AssertionError as e:
    # Shown as cell output (rather than raised) so the mismatch report is visible.
    print("AssertionError:", str(e))

tensor([[[ -8.4174, -22.1108,  -2.4200,  ..., -21.9465, -21.9803, -21.9134],
         [  0.3512, -26.0508,   1.5452,  ..., -25.8232, -26.0401, -25.8188],
         [-40.8387, -53.8309, -38.9781,  ..., -53.9597, -53.7011, -53.9456]]],
       grad_fn=<UnsafeViewBackward0>)
Exported model has been tested with ONNXRuntime, and the result looks good!


## Text generation test

In [15]:
def get_next_token_logits(model, indices_to_input, is_onnx=False):
    """Return the model's logits at the last sequence position.

    Args:
        model: an ``onnxruntime.InferenceSession`` when ``is_onnx`` is True,
            otherwise a callable PyTorch model returning logits.
        indices_to_input: token ids, presumably shaped (batch, seq_len) — a
            torch tensor or anything ``np.asarray`` accepts.
        is_onnx: select the ONNX Runtime path instead of calling ``model``.

    Returns:
        torch.Tensor equal to the model output indexed ``[:, -1]``; for a
        (batch, seq_len, vocab) output that is (batch, vocab).
    """
    if is_onnx:
        input_name = model.get_inputs()[0].name
        # Force int64 so the fed dtype does not depend on the platform's
        # default integer type (the graph was traced from torch.long ids).
        input_array = np.asarray(indices_to_input, dtype=np.int64)
        # run() returns a list with one array per graph output; logits are first.
        logits = torch.from_numpy(model.run(None, {input_name: input_array})[0])
        return logits[:, -1]
    return model(indices_to_input)[:, -1]

def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 50,
             sample: bool = True,
             top_k: int = 40,
             is_onnx=False):
    """Autoregressively extend ``prompt`` by ``n_tokens_to_gen`` tokens.

    At every step the whole sequence generated so far is re-fed to the model
    (no recurrent-state caching). The last-position logits are turned into
    probabilities and optionally restricted to the ``top_k`` most likely
    tokens. The next token is then sampled or picked greedily.

    Args:
        model: a PyTorch model, or, when ``is_onnx`` is True, the path to an
            exported ONNX file, which is opened with ONNX Runtime.
        tokenizer: tokenizer used to encode the prompt and decode the result.
        prompt: text to continue.
        n_tokens_to_gen: number of new tokens to append.
        sample: sample from the (filtered) distribution if True, otherwise
            take the argmax.
        top_k: keep only the ``top_k`` most probable tokens before
            renormalising; ``None`` disables the filter. Values larger than
            the vocabulary are clamped to the vocabulary size.
        is_onnx: treat ``model`` as an ONNX file path.

    Returns:
        The decoded text (prompt + continuation) of the first sequence.
    """
    if is_onnx:
        model = ort.InferenceSession(model)
    else:
        model.eval()

    input_ids = tokenizer(prompt, return_tensors='pt').input_ids

    for _ in range(n_tokens_to_gen):
        with torch.no_grad():
            next_token_logits = get_next_token_logits(model, input_ids, is_onnx)

        probs = F.softmax(next_token_logits, dim=-1)

        if top_k is not None:
            # torch.topk raises if k exceeds the size of the vocab axis.
            k = min(top_k, probs.shape[-1])
            kth_largest = torch.topk(probs, k=k).values[:, -1, None]
            # Zero everything below the k-th largest probability, then renormalise.
            probs[probs < kth_largest] = 0
            probs = probs / probs.sum(dim=-1, keepdim=True)

        if sample:
            next_indices = torch.multinomial(probs, num_samples=1)
        else:
            next_indices = torch.argmax(probs, dim=-1, keepdim=True)

        input_ids = torch.cat([input_ids, next_indices], dim=1)

    # Only the first sequence is returned, so only it is decoded.
    return tokenizer.decode(input_ids[0].tolist())

In [20]:
# Sample a 50-token continuation from the exported ONNX graph (ONNX Runtime).
onnx_completion = generate(
    model='mamba_model_130m_cumsum_no_einsum.onnx',
    is_onnx=True,
    tokenizer=tokenizer,
    prompt="Harry Potter is",
    n_tokens_to_gen=50,
)
print(onnx_completion)

Harry Potter is not your child. You have the ability to change it for the best of you, to make your son a better person in life. You may even be able to change him for you. But you can only change who you always make it. You


In [19]:
# Same prompt, generated with the PyTorch model.
torch_completion = generate(
    model=model,
    tokenizer=tokenizer,
    prompt="Harry Potter is",
    n_tokens_to_gen=50,
)
print(torch_completion)

Harry Potter is over, then I'm really glad you're not going to my parents! That is, if you were to have a mother and a father, that would be the best. The man who comes on like an ice-fish has got enough heart disease


## Test dynamic export (Beta)

In [None]:
# Beta: export a tiny 1-layer model with the dynamo-based exporter, with
# dynamic shapes enabled, and save the resulting ONNX program.
# NOTE(review): newer PyTorch releases deprecate torch.onnx.dynamo_export in
# favour of torch.onnx.export(..., dynamo=True) — check the installed version.
args = ModelArgs(d_model=5, n_layer=1, vocab_size=50277)
model_dyn = Mamba(args)
model_dyn.eval()
export_name = "mamba_minimal_1_layer_dyn"

dynamo_export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_program = torch.onnx.dynamo_export(
    model_dyn,
    input_ids,
    export_options=dynamo_export_options,
)
onnx_program.save(f"{export_name}.onnx")

## Examples: implementing the einsum operations without `einsum`

In [21]:
# Reference: the einsum that the next cells re-implement with broadcasting.
# 'bld,dn->bldn' scales the row A[d, :] by every dt[b, l, d].
b, l, d, n = 2, 3, 4, 5

dt = torch.randn(b, l, d)  # (b, l, d)
A = torch.randn(d, n)      # (d, n)

print(dt.shape)
print(A.shape)

result = torch.einsum('bld,dn->bldn', dt, A)

# Expected: torch.Size([2, 3, 4, 5]), i.e. (b, l, d, n)
print(result.shape)

torch.Size([2, 3, 4])
torch.Size([4, 5])
torch.Size([2, 3, 4, 5])


In [3]:
# EXPANSION OF deltaA = torch.einsum('bld,dn->bldn', dt, A)
# Goal: reproduce the einsum with plain view/expand/multiply ops.

# Example dimensions
b, l, d, n = 2, 3, 4, 5

# Example tensors
dt = torch.randn(b, l, d)
A = torch.randn(d, n)

# Step 1: view A as (1, 1, d, n) and expand it to (b, l, d, n).
# expand does not allocate new memory; plain broadcasting would also work.
A_expanded = A.view(1, 1, d, n).expand(b, l, d, n)

# Step 2: add a trailing singleton axis to dt, (b, l, d) -> (b, l, d, 1),
# so each dt[b, l, d] multiplies the whole row A[d, :].
dt_expanded = dt.unsqueeze(-1)

# Step 3: element-wise multiplication -> (b, l, d, n)
result = dt_expanded * A_expanded

print(result.shape)  # torch.Size([2, 3, 4, 5])

# A matching shape does not prove the rewrite is right: compare the values
# against the einsum being replaced.
assert torch.allclose(result, torch.einsum('bld,dn->bldn', dt, A), atol=1e-6)

torch.Size([2, 3, 4, 5])


In [4]:
# EXPANSION OF deltaB_u = torch.einsum('bld,bld,bln->bldn', dt, u, B)

# Example dimensions
b, l, d, n = 2, 3, 4, 5

# Example tensors
dt = torch.randn(b, l, d)
u = torch.randn(b, l, d)
B = torch.randn(b, l, n)

# Step 1: dt and u share every index, so they combine element-wise -> (b, l, d)
dt_u_product = dt * u

# Step 2: trailing singleton axis -> (b, l, d, 1)
dt_u_expanded = dt_u_product.unsqueeze(-1)

# Step 3: singleton axis at position 2 -> (b, l, 1, n)
B_expanded = B.unsqueeze(2)

# Step 4: broadcasting the two singleton axes against each other -> (b, l, d, n)
deltaB_u = dt_u_expanded * B_expanded

print(deltaB_u.shape)  # torch.Size([2, 3, 4, 5])

# Check the values, not just the shape, against the einsum being replaced.
assert torch.allclose(deltaB_u, torch.einsum('bld,bld,bln->bldn', dt, u, B), atol=1e-6)

torch.Size([2, 3, 4, 5])


In [5]:
# EXPANSION OF y = torch.einsum('bldn,bln->bld', x, C)

# Example dimensions
b, l, d, n = 2, 3, 4, 5

# Example tensors
x = torch.randn(b, l, d, n)
C = torch.randn(b, l, n)

# C has no d axis: insert a singleton there -> (b, l, 1, n)
C_expanded = C.unsqueeze(2)

# Broadcast multiply -> (b, l, d, n)
product = torch.mul(x, C_expanded)

# 'n' is missing from the output spec, so it is contracted: sum it out -> (b, l, d)
y = torch.sum(product, dim=-1)

print(y.shape)  # torch.Size([2, 3, 4])

# The summation order may differ from einsum's, so compare with a small atol.
assert torch.allclose(y, torch.einsum('bldn,bln->bld', x, C), atol=1e-5)

torch.Size([2, 3, 4])
