In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from src.ops.ssd_combined import mamba_chunk_scan_combined


# Seed every RNG-driven initialisation below (and the random inputs drawn in
# later cells) so that "Restart & Run All" reproduces the printed numbers.
seed = 0
torch.manual_seed(seed)

# Mixer dimensions. The inner dimension is split into `num_heads` heads of
# `head_dim` channels each, so the two must tile it exactly.
hidden_size = 768
intermediate_size = 1536
ssm_state_size = 128
chunk_size = 256
head_dim = 64
num_heads = intermediate_size // head_dim
assert num_heads * head_dim == intermediate_size, "heads must tile intermediate_size"

# Depthwise causal conv over the concatenated [x, B, C] projection.
conv_kernel_size = 4
conv1d_dim = intermediate_size + 2 * ssm_state_size

# One projection producing z (intermediate_size), xBC (conv1d_dim) and dt (num_heads).
in_proj = nn.Linear(
    in_features=hidden_size,
    out_features=2 * (intermediate_size + ssm_state_size) + num_heads,
)
out_proj = nn.Linear(intermediate_size, hidden_size, bias=True)
conv1d = nn.Conv1d(
    in_channels=conv1d_dim,
    out_channels=conv1d_dim,
    bias=True,
    kernel_size=conv_kernel_size,
    groups=conv1d_dim,  # depthwise: one filter per channel
    # Pads kernel_size - 1 on both sides; the forward pass keeps only the first
    # seq_len outputs, which makes the convolution causal.
    padding=conv_kernel_size - 1,
)

# dt is sampled log-uniformly in [dt_init_min, dt_init_max) and floored.
dt_init_min, dt_init_max, dt_init_floor = 0.001, 0.1, 1e-4
dt = torch.exp(
    torch.rand(num_heads) * (math.log(dt_init_max) - math.log(dt_init_min))
    + math.log(dt_init_min)
).clamp(min=dt_init_floor)
# Inverse of softplus, so that softplus(dt_bias) == dt at initialisation:
# https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
dt_bias = nn.Parameter(inv_dt.clone())

# A is stored as log(A); the forward pass uses -exp(A_log) (a negative decay).
A_init_range = (1, 16)
A = torch.empty(num_heads, dtype=torch.float32).uniform_(*A_init_range)
A_log = nn.Parameter(torch.log(A))

# Per-head skip-connection weight.
D = nn.Parameter(torch.ones(num_heads))

In [2]:
def forward(hidden_states, initial_state=None, return_final_state=False, cache=None, use_cache=False, device="cuda"):
    """Run the mixer: in_proj -> causal depthwise conv -> chunked SSD scan -> out_proj.

    Uses the module-level layers/parameters ``in_proj``, ``conv1d``, ``A_log``,
    ``D``, ``dt_bias`` and ``out_proj``.

    Args:
        hidden_states: ``(batch, seq_len, hidden_size)`` input. In cached
            decoding mode (``use_cache`` with a cache whose ``seq_offset > 0``)
            ``seq_len`` must be 1.
        initial_state: optional SSM state to start the scan from (same layout as
            ``cache["ssm_state"]``). Not allowed together with an explicit
            ``cache`` when ``use_cache`` is set.
        return_final_state: if true, the second output is the final SSM state.
        cache: dict with ``"conv_state"``, ``"ssm_state"`` and ``"seq_offset"``
            as returned by a previous ``use_cache=True`` call; a fresh one is
            created when ``None`` and ``use_cache`` is set.
        use_cache: first call with a fresh cache runs the full sequence and
            seeds the cache (prefill); later calls decode one token recurrently.
        device: device the chunked-scan inputs are moved to (default ``"cuda"``).

    Returns:
        ``(y, final_state)`` with ``y`` of shape ``(batch, seq_len, hidden_size)``
        on ``out_proj``'s device and ``final_state`` ``None`` unless
        ``return_final_state``; with ``use_cache`` the updated cache is appended
        as a third element.

    Raises:
        ValueError: if ``initial_state`` and ``cache`` are combined with
            ``use_cache``, or a cached decoding step receives ``seq_len != 1``.
    """
    if initial_state is not None and cache is not None and use_cache:
        raise ValueError("Caching and passing initial states is not possible at the same time!")

    bsz, seq_len, _ = hidden_states.shape
    # Read the conv width from the layer so the cache stays in sync with it.
    kernel_size = conv1d.kernel_size[0]

    if cache is None and use_cache:
        cache = {
            "conv_state": torch.zeros(
                bsz, conv1d.weight.shape[0], kernel_size, device=conv1d.weight.device
            ),
            "ssm_state": torch.zeros(
                bsz, num_heads, head_dim, ssm_state_size, device=conv1d.weight.device
            ),
            "seq_offset": 0,
        }
    # seq_offset == 0 -> prefill (full-sequence path, then seed the cache);
    # seq_offset > 0  -> single-token recurrent decoding from the cache.
    cached_start = use_cache and cache["seq_offset"] == 0
    cached_forward = use_cache and cache["seq_offset"] > 0
    if cached_forward:
        if seq_len != 1:
            raise ValueError(
                f"Cached decoding expects exactly one token per step, got seq_len={seq_len}."
            )
        hidden_states = hidden_states.squeeze(1)

    zxbcdt = in_proj(hidden_states)
    # Width of optional extra MLP slices; 0 for the in_proj sized as z + xBC + dt.
    d_mlp = (zxbcdt.shape[-1] - 2 * intermediate_size - 2 * ssm_state_size - num_heads) // 2
    # z is split off but not used below (the scan is called with z=None).
    z0, x0, z, xBC, dt = torch.split(
        zxbcdt,
        [d_mlp, d_mlp, intermediate_size, intermediate_size + 2 * ssm_state_size, num_heads],
        dim=-1,
    )

    if cached_start:
        # Store the last `kernel_size` pre-conv inputs: F.pad crops with a
        # negative pad when the prompt is longer, zero-pads on the left otherwise.
        xBC_t = rearrange(xBC, "b l d -> b d l")
        cache["conv_state"].copy_(F.pad(xBC_t, (kernel_size - xBC_t.shape[-1], 0)))

    if cached_forward:
        # Slide the window by one token and evaluate the depthwise conv as a dot product.
        cache["conv_state"].copy_(torch.roll(cache["conv_state"], shifts=-1, dims=-1))
        cache["conv_state"][:, :, -1] = xBC
        xBC = torch.sum(cache["conv_state"] * rearrange(conv1d.weight, "d 1 w -> d w"), dim=-1)
        if conv1d.bias is not None:
            xBC = xBC + conv1d.bias
        xBC = F.silu(xBC)
    else:
        # Keeping only the first seq_len outputs of the padded conv makes it causal.
        xBC = F.silu(conv1d(xBC.transpose(1, 2))[..., :seq_len].transpose(1, 2))

    A = -torch.exp(A_log)
    x, B, C = torch.split(xBC, [intermediate_size, ssm_state_size, ssm_state_size], dim=-1)

    if cached_forward:
        # Re-insert a length-1 sequence axis so the chunked kernel also serves decoding.
        init_state = cache["ssm_state"]
        x = rearrange(x, "b (h p) -> b 1 h p", p=head_dim)
        B = rearrange(B, "b n -> b 1 1 n")
        C = rearrange(C, "b n -> b 1 1 n")
        dt = dt.unsqueeze(1)
    else:
        init_state = initial_state
        x = rearrange(x, "b l (h p) -> b l h p", p=head_dim)
        B = rearrange(B, "b l n -> b l 1 n")
        C = rearrange(C, "b l n -> b l 1 n")

    want_final_state = return_final_state or use_cache
    y = mamba_chunk_scan_combined(
        x=x.to(device),
        dt=dt.to(device),
        A=A.to(device=device, dtype=torch.float32),
        B=B.to(device),
        C=C.to(device),
        chunk_size=chunk_size,
        D=D.to(device=device, dtype=torch.float32),
        z=None,
        initial_states=init_state.to(device) if init_state is not None else None,
        dt_bias=dt_bias.to(device),
        dt_softplus=True,
        seq_idx=None,
        dt_min=0.0,
        dt_max=float("inf"),
        return_final_states=want_final_state,
    )
    last_state = None
    if want_final_state:
        y, last_state = y

    y = rearrange(y, "b l h p -> b l (h p)")
    # Bring the kernel output back to wherever out_proj's weights live.
    y = out_proj(y.to(out_proj.weight.device))

    out = (y, last_state if return_final_state else None)
    if use_cache:
        cache["ssm_state"].copy_(last_state)
        # Prefill advances by the prompt length, each decode step by its single token.
        cache["seq_offset"] += y.shape[1]
        out += (cache,)

    return out

In [3]:
# Check that prefill on the first seq_len - 1 tokens followed by one cached
# decoding step reproduces a single uncached forward over the whole sequence.
batch_size, seq_len = 2, 256
prefix_len = seq_len - 1
hidden_states = torch.randn(size=(batch_size, seq_len, hidden_size), dtype=torch.float32)

# Inference-only comparison: no autograd graph is needed.
with torch.no_grad():
    # normal forward over the full sequence
    res_1, _ = forward(hidden_states)

    # cached forward: prefill on all but the last token ...
    prefix_out, _, cache = forward(hidden_states[:, :prefix_len, :], use_cache=True)
    # ... then decode the last token from the cache
    last_out, _, _ = forward(hidden_states[:, prefix_len:, :], use_cache=True, cache=cache)

prefix_ok = torch.allclose(prefix_out, res_1[:, :prefix_len, :], atol=0.01, rtol=0.01)
print((res_1[:, :prefix_len, :] - prefix_out).abs().sum())
print(prefix_ok)

res_2 = torch.cat((prefix_out, last_out), dim=1)
full_ok = torch.allclose(res_1, res_2, atol=0.01, rtol=0.01)
print((res_1 - res_2).abs().sum())
print(full_ok)

assert prefix_ok, "prefill output diverges from the uncached forward"
assert full_ok, "cached decoding step diverges from the uncached forward"

tensor(0.0129, grad_fn=<SumBackward0>)
True
tensor(0.0157, grad_fn=<SumBackward0>)
True
