In [1]:
import time
import numpy as np
import pandas as pd
import torch
from mamba_ssm import Mamba

batch, length, dim = 1, 16, 1
torch.manual_seed(0)
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)

assert y.shape == x.shape
y = y.cpu().detach().numpy()
y = np.squeeze(y)
print(y)

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd
  from .autonotebook import tqdm as notebook_tqdm


selective_scan_cuda
[ 0.06407421  0.06236958  0.01166452  0.02509372 -0.01790208 -0.01710338
  0.00957332  0.1328741  -0.00565795  0.10410124 -0.00940671 -0.01141407
 -0.0037374  -0.00544129 -0.01197526  0.00503347]


In [2]:
time_dict = {}
length_range = [2 ** power for power in range(7, 17)]
for length in length_range:
    time_dict[length] = 0
    repeats = 3
    for _ in range(repeats):
        x = torch.randn(batch, length, dim).to("cuda")

        end = time.time()
        y = model(x)
        time_dict[length] += time.time() - end
    time_dict[length] /= repeats
    print(f"Length {length} took {time_dict[length]} seconds")

df = pd.DataFrame(time_dict.items(), columns=["Length", "Time"])
df.to_csv("mamba_ssm.csv", index=False)

selective_scan_cuda
selective_scan_cuda
selective_scan_cuda
Length 128 took 0.001697222391764323 seconds
selective_scan_cuda
selective_scan_cuda
selective_scan_cuda
Length 256 took 0.0005888144175211588 seconds
selective_scan_cuda
selective_scan_cuda
selective_scan_cuda
Length 512 took 0.0006221135457356771 seconds
selective_scan_cuda
selective_scan_cuda
selective_scan_cuda
Length 1024 took 0.0005869865417480469 seconds
selective_scan_cuda
selective_scan_cuda
selective_scan_cuda
Length 2048 took 0.0008376439412434896 seconds
selective_scan_cuda
selective_scan_cuda
selective_scan_cuda
Length 4096 took 0.0020596186319986978 seconds
selective_scan_cuda
selective_scan_cuda
selective_scan_cuda
Length 8192 took 0.0005873839060465494 seconds
selective_scan_cuda
selective_scan_cuda
selective_scan_cuda
Length 16384 took 0.0007665157318115234 seconds
selective_scan_cuda
selective_scan_cuda
selective_scan_cuda
Length 32768 took 0.0006613731384277344 seconds
selective_scan_cuda
selective_scan_cuda