In [5]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from mamba_ssm import Mamba


In [6]:
# Build a small Mamba block and smoke-test that a forward pass preserves the
# input shape (batch, length, dim).
DEVICE = "cuda"  # single place to name the device used for both input and model

# Seed before creating the input and the model so the random input and the
# randomly initialised weights are reproducible across Restart & Run All.
torch.manual_seed(0)

batch, length, dim = 2, 64, 5
# Allocate the input directly on the target device instead of creating it on
# the CPU and copying it over.
x = torch.randn(batch, length, dim, device=DEVICE)
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=11,  # SSM state expansion factor
    d_conv=3,    # Local convolution width
    expand=7,    # Block expansion factor
).to(DEVICE)
y = model(x)
assert y.shape == x.shape, (
    f"Mamba should preserve the input shape: expected {tuple(x.shape)}, got {tuple(y.shape)}"
)

In [7]:
# Display the module repr (bare last expression of the cell) to inspect the
# sub-layers of the Mamba block built in the previous cell.
model
# Expected repr for d_model=5, d_state=11, d_conv=3, expand=7; it mirrors this cell's output.
# The recurring width 35 matches expand * d_model (7 * 5); in_proj's 70 outputs are twice that.
# conv1d is depthwise (groups=35 equals its channel count) and its kernel_size of 3 matches d_conv.
# NOTE(review): x_proj's 23 outputs look like dt_rank + 2 * d_state (1 + 2 * 11) — confirm against mamba_ssm.
# Mamba(
#   (in_proj): Linear(in_features=5, out_features=70, bias=False)
#   (conv1d): Conv1d(35, 35, kernel_size=(3,), stride=(1,), padding=(2,), groups=35)
#   (act): SiLU()
#   (x_proj): Linear(in_features=35, out_features=23, bias=False)
#   (dt_proj): Linear(in_features=1, out_features=35, bias=True)
#   (out_proj): Linear(in_features=35, out_features=5, bias=False)
# )

Mamba(
  (in_proj): Linear(in_features=5, out_features=70, bias=False)
  (conv1d): Conv1d(35, 35, kernel_size=(3,), stride=(1,), padding=(2,), groups=35)
  (act): SiLU()
  (x_proj): Linear(in_features=35, out_features=23, bias=False)
  (dt_proj): Linear(in_features=1, out_features=35, bias=True)
  (out_proj): Linear(in_features=35, out_features=5, bias=False)
)

In [8]:
# List every learnable tensor registered on the Mamba block, one per line,
# as "<parameter name> <shape>".
for layer_name, weights in model.named_parameters():
    print(f"{layer_name} {weights.shape}")

# Expected output for d_model=5, d_state=11, d_conv=3, expand=7:
# A_log torch.Size([35, 11])
# D torch.Size([35])
# in_proj.weight torch.Size([70, 5])
# conv1d.weight torch.Size([35, 1, 3])
# conv1d.bias torch.Size([35])
# x_proj.weight torch.Size([23, 35])
# dt_proj.weight torch.Size([35, 1])
# dt_proj.bias torch.Size([35])
# out_proj.weight torch.Size([5, 35])

A_log torch.Size([35, 11])
D torch.Size([35])
in_proj.weight torch.Size([70, 5])
conv1d.weight torch.Size([35, 1, 3])
conv1d.bias torch.Size([35])
x_proj.weight torch.Size([23, 35])
dt_proj.weight torch.Size([35, 1])
dt_proj.bias torch.Size([35])
out_proj.weight torch.Size([5, 35])
