Add prefix scan op for Mamba SSM prefill #9553

Closed

esmalTT opened this issue Jun 19, 2024 · 0 comments
Assignees
esmalTT
Labels
mamba, prefill (LLM models have a prefill mode whose optimization is usually handled separately from decode mode.)

Comments


esmalTT commented Jun 19, 2024

We require a prefix scan operator to implement Mamba SSM prefill. The op should match the behavior of the following torch reference code:

import torch

def sequential_prefix_scan(a, bx):
    # Recurrence: h[i] = a[i] * h[i-1] + bx[i], with h[-1] = 0
    (_, _, L, EN) = bx.shape
    hidden_states = torch.zeros((1, 1, L, EN), device=a.device, dtype=bx.dtype)
    for i in range(L):
        # At i == 0, index i - 1 wraps to -1, which reads the zero-initialized last row
        hidden_states[:, :, i] = a[:, :, i] * hidden_states[:, :, i - 1] + bx[:, :, i]
    return hidden_states

Where L is the sequence length and EN = 5120 * 32.
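For reference, a minimal sketch of how the reference scan can be exercised and checked. The sizes below are small illustrative values (not the EN = 5120 * 32 from this issue), and the combine helper is a hypothetical illustration of why the recurrence admits a prefix-scan formulation, not part of the requested op:

import torch

# Illustrative sizes only.
L, EN = 8, 16
a = torch.rand(1, 1, L, EN)
bx = torch.rand(1, 1, L, EN)

h = sequential_prefix_scan(a, bx)

# Spot-check the recurrence: h[0] = bx[0], h[1] = a[1] * bx[0] + bx[1]
assert torch.allclose(h[:, :, 0], bx[:, :, 0])
assert torch.allclose(h[:, :, 1], a[:, :, 1] * bx[:, :, 0] + bx[:, :, 1])

# The recurrence is a scan under the associative combine
# (a1, b1) ∘ (a2, b2) = (a1 * a2, a2 * b1 + b2),
# which is what makes a parallel prefix-scan implementation possible.
def combine(x, y):
    a1, b1 = x
    a2, b2 = y
    return a1 * a2, a2 * b1 + b2

acc = (torch.ones_like(a[:, :, 0]), torch.zeros_like(bx[:, :, 0]))
for i in range(L):
    acc = combine(acc, (a[:, :, i], bx[:, :, i]))
    assert torch.allclose(acc[1], h[:, :, i])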

@esmalTT esmalTT added the mamba, prefill labels Jun 19, 2024
@esmalTT esmalTT self-assigned this Jun 19, 2024
@esmalTT esmalTT closed this as completed Jul 8, 2024
Projects
None yet
Development

No branches or pull requests

1 participant