In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from src.ops.ssd_combined import mamba_chunk_scan_combined


# ---------------------------------------------------------------------------
# Configuration and parameter initialisation for a single Mamba2-style mixer.
#
# Everything below is module-level state that later cells read:
#   sizes      : hidden_size, intermediate_size, ssm_state_size, chunk_size,
#                num_heads, head_dim, conv1d_dim
#   layers     : in_proj (nn.Linear), conv1d (depthwise nn.Conv1d)
#   parameters : dt_bias, A_log, D  (one scalar per head)
#   init-only  : dt, inv_dt, A      (samples used to build the parameters)
# ---------------------------------------------------------------------------

# Seed the global RNG so "Restart Kernel -> Run All" reproduces the same
# weights and the same dt / A samples every time.
SEED = 0
torch.manual_seed(SEED)

hidden_size = 768
intermediate_size = 1536
ssm_state_size = 128
chunk_size = 256
head_dim = 64
# Derived instead of hard-coded (was `1536 // 64`) so it cannot drift away
# from intermediate_size / head_dim when either one is edited.
num_heads = intermediate_size // head_dim
assert num_heads * head_dim == intermediate_size, "intermediate_size must be divisible by head_dim"

# Width of the depthwise convolution window; also determines its padding.
conv_kernel_size = 4
# Channels that go through the convolution: intermediate_size + 2 * ssm_state_size
# (split later into x, B and C).
conv1d_dim = intermediate_size + 2 * ssm_state_size

# One projection produces every stream at once:
#   z (intermediate_size) | xBC (conv1d_dim) | dt (num_heads)
in_proj = nn.Linear(
    in_features=hidden_size,
    out_features=2 * (intermediate_size + ssm_state_size) + num_heads,
)
# groups == channels -> depthwise conv. padding = kernel_size - 1 yields
# seq_len + kernel_size - 1 outputs; the consumer keeps the first seq_len.
conv1d = nn.Conv1d(
    in_channels=conv1d_dim,
    out_channels=conv1d_dim,
    bias=True,
    kernel_size=conv_kernel_size,
    groups=conv1d_dim,
    padding=conv_kernel_size - 1,
)

# Sample the initial step size log-uniformly in [dt_init_min, dt_init_max],
# then floor it so the inverse-softplus below stays finite.
dt_init_min, dt_init_max, dt_init_floor = 0.001, 0.1, 1e-4
dt = torch.exp(
    torch.rand(num_heads)
    * (math.log(dt_init_max) - math.log(dt_init_min))
    + math.log(dt_init_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
# dt + log(1 - exp(-dt)) == log(exp(dt) - 1), hence softplus(inv_dt) == dt.
inv_dt = dt + torch.log(-torch.expm1(-dt))
# Store the bias in pre-softplus space so that softplus(dt_bias) == dt.
# clone() keeps the parameter's storage independent from inv_dt.
dt_bias = nn.Parameter(inv_dt.clone())

# A is sampled positive and stored as its log.
A_init_range = (1, 16)
A = torch.empty(num_heads, dtype=torch.float32).uniform_(*A_init_range)
A_log = nn.Parameter(torch.log(A))

# Per-head scalar, initialised to one.
D = nn.Parameter(torch.ones(num_heads))

In [3]:
# ---------------------------------------------------------------------------
# Forward pass of the Mamba2-style mixer on a tiny random batch:
#   in_proj -> split -> depthwise conv + SiLU -> chunked SSD scan.
# Reads the sizes, layers and parameters defined in the setup cell.
# ---------------------------------------------------------------------------
hidden_states = torch.randn(size=(2, 3, hidden_size), dtype=torch.float32)
bsz, seq_len, _ = hidden_states.shape

# Single projection, then split into the individual streams. d_mlp is the
# width left over after z, xBC and dt; it is 0 for the in_proj built in the
# setup cell, so z0 / x0 are empty.
zxbcdt = in_proj(hidden_states)
d_mlp = (zxbcdt.shape[-1] - 2 * intermediate_size - 2 * ssm_state_size - num_heads) // 2
z0, x0, z, xBC, dt = torch.split(
    zxbcdt,
    [d_mlp, d_mlp, intermediate_size, intermediate_size + 2 * ssm_state_size, num_heads],
    dim=-1
)
# dt is deliberately left as the raw projection output. The scan below is
# called with dt_bias=dt_bias and dt_softplus=True, so the bias add and the
# softplus are expected to happen inside the kernel. The previous code also
# applied softplus(dt + dt_bias) here, which applied both twice.
# NOTE(review): assumes src.ops.ssd_combined follows the upstream mamba_ssm
# semantics for dt_bias / dt_softplus - confirm against the kernel source.

# Conv1d wants (batch, channels, length). With padding = kernel_size - 1 the
# output is longer than the input; keeping the first seq_len positions makes
# the convolution causal.
xBC = F.silu(
    conv1d(xBC.transpose(1, 2))[..., :seq_len].transpose(1, 2)
)
x, B, C = torch.split(
    xBC, [intermediate_size, ssm_state_size, ssm_state_size], dim=-1
)

# The state matrix must be negative so the recurrence decays. It is rebuilt
# from the learnable A_log instead of reusing the positive init sample `A`,
# which made the state grow and bypassed the parameter.
A_neg = -torch.exp(A_log.float())

device = "cuda"

y, last_state = mamba_chunk_scan_combined(
    x=rearrange(x, "b l (h p) -> b l h p", p=head_dim).to(device),
    dt=dt.to(device),
    A=A_neg.to(device),
    # B and C get a singleton axis in the head position (one shared group).
    B=rearrange(B, "b l n -> b l 1 n").to(device),
    C=rearrange(C, "b l n -> b l 1 n").to(device),
    chunk_size=chunk_size,
    D=D.to(device),
    z=None,
    initial_states=None,
    dt_bias=dt_bias.to(device),
    dt_softplus=True,
    seq_idx=None,
    dt_min=0.0,
    dt_max=float("inf"),
    return_final_states=True
)
y = rearrange(y, "b l h p -> b l (h p)")
y.shape, z.shape, last_state.shape

(torch.Size([2, 3, 1536]),
 torch.Size([2, 3, 1536]),
 torch.Size([2, 24, 64, 128]))