Minimal reference implementation #10
There's a reference implementation of the selective scan in PyTorch here. That's the main primitive that requires CUDA.
Right, I did see the reference implementation; I just wondered how close to "minimal" we should consider it. How close could Mamba get to this kind of minimalism?
The core is actually just a for loop; the code simplifies a lot if you only take the path where B/C are input-dependent.
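To illustrate the "just a for loop" point, here is a minimal sketch of a selective scan for a single channel in plain Python, with Δ, B and C supplied per time step (the input-dependent path). It assumes a diagonal A and the discretization Ā = exp(ΔA), B̄ = ΔB, a first-order simplification of zero-order hold for B̄. The function and argument names are hypothetical and not taken from the repository's reference code.

```python
import math

def selective_scan(u, delta, A, B, C, D=0.0):
    """Sequential selective scan for one channel (hypothetical sketch).

    Assumes a diagonal A and the discretization exp(delta*A) for the state,
    delta*B for the input.

    u:     list of L floats, the input sequence
    delta: list of L floats, per-step (input-dependent) step sizes
    A:     list of N floats, diagonal of the state matrix (typically negative)
    B, C:  lists of L vectors of length N (input-dependent)
    D:     skip-connection scalar
    Returns a list of L output floats.
    """
    N = len(A)
    h = [0.0] * N  # hidden state starts at zero
    ys = []
    for t in range(len(u)):
        for n in range(N):
            # h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * u_t, per state dim
            h[n] = math.exp(delta[t] * A[n]) * h[n] + delta[t] * B[t][n] * u[t]
        # Read out with the input-dependent C_t, plus the D skip term
        ys.append(sum(C[t][n] * h[n] for n in range(N)) + D * u[t])
    return ys
```

With A = [0.0] and Δ, B, C all set to 1, the decay factor is exp(0) = 1 and the state simply accumulates the input, so `selective_scan([1.0, 2.0, 3.0], [1.0]*3, [0.0], [[1.0]]*3, [[1.0]]*3)` returns the running sum `[1.0, 3.0, 6.0]`.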
Hi all, I wrote a minimal implementation here: https://github.com/johnma2006/mamba-minimal/tree/master. Hope it helps! |
Thanks, that looks really clean and should be trivial to port to JAX. From what I understand, using JAX scan also isn't competitive for LLM-scale models, but my intuition is that it'd be fine for some of the smaller things I'd want to try it on.
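For a JAX port, the loop can be rewritten as a step function with the `(carry, x) -> (carry, y)` signature that `jax.lax.scan` expects. So that it runs without JAX installed, the sketch below uses a pure-Python `scan` that only mimics that contract; in an actual port you would pass a similar step function, written with `jax.numpy` operations on arrays, to `jax.lax.scan`. It assumes a diagonal A and the discretization exp(ΔA), ΔB, and the helper names are hypothetical.

```python
import math

def scan(f, init, xs):
    """Pure-Python stand-in for jax.lax.scan's contract (hypothetical helper):
    f(carry, x) -> (new_carry, y); returns (final_carry, list_of_ys)."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys

def make_step(A, D=0.0):
    """Build one scan step for a single channel of a selective scan.
    Assumes a diagonal A and the discretization exp(delta*A), delta*B."""
    def step(h, inputs):
        u_t, delta_t, B_t, C_t = inputs
        # State update: h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * u_t
        h = [math.exp(delta_t * a) * h_n + delta_t * b * u_t
             for a, h_n, b in zip(A, h, B_t)]
        # Output: y_t = C_t . h_t + D * u_t
        y_t = sum(c * h_n for c, h_n in zip(C_t, h)) + D * u_t
        return h, y_t
    return step

def selective_scan_via_scan(u, delta, A, B, C, D=0.0):
    # Pack per-step inputs the way a scan consumes them: one tuple per time step
    xs = list(zip(u, delta, B, C))
    _, ys = scan(make_step(A, D), [0.0] * len(A), xs)
    return ys
```

Keeping the state as the carry and the per-step (u, Δ, B, C) as the scanned inputs is what lets the same step function move to `jax.lax.scan` largely unchanged.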
Thanks so much for providing this code; looks very useful and reproducible.
As I understand it, the custom scan kernel can be quite important for performance, so it is great to see it here as well.
However, as a suggestion, I think it'd be super neat to also see a minimal Mamba reference implementation with minimal dependencies, simply for clarity of exposition: something that could be unit tested to behave the same as the custom kernel, at least on small datasets. Would that be a lot of work? Does it already exist somewhere? If a torch version exists, I'd be happy to port it to JAX as well.
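A unit test of the kind described could follow a standard pattern: run two independent implementations on small random inputs and compare them within a tolerance. This sketch assumes a single-channel recurrence with diagonal A, h_t = exp(Δ_t A) h_{t-1} + Δ_t B_t u_t and y_t = C_t · h_t + D u_t, and checks a sequential loop against an O(L²) unrolled closed form of the same recurrence. Comparing against the fused CUDA kernel would slot in the same way but is not shown; all names here are hypothetical.

```python
import math
import random

def scan_loop(u, delta, A, B, C, D):
    # Sequential recurrence (assumed): h_t = exp(delta_t*A) * h_{t-1} + delta_t * B_t * u_t
    h = [0.0] * len(A)
    ys = []
    for t in range(len(u)):
        h = [math.exp(delta[t] * a) * hn + delta[t] * b * u[t]
             for a, hn, b in zip(A, h, B[t])]
        ys.append(sum(c * hn for c, hn in zip(C[t], h)) + D * u[t])
    return ys

def scan_unrolled(u, delta, A, B, C, D):
    # Closed form of the same recurrence:
    # h_t[n] = sum_{s<=t} exp(A[n] * sum_{r=s+1..t} delta_r) * delta_s * B_s[n] * u_s
    ys = []
    for t in range(len(u)):
        y = D * u[t]
        for n, a in enumerate(A):
            h_n = 0.0
            for s in range(t + 1):
                decay = math.exp(a * sum(delta[s + 1:t + 1]))  # empty sum -> 1.0
                h_n += decay * delta[s] * B[s][n] * u[s]
            y += C[t][n] * h_n
        ys.append(y)
    return ys

def random_case(L, N, rng):
    u = [rng.gauss(0, 1) for _ in range(L)]
    delta = [rng.uniform(0.01, 1.0) for _ in range(L)]  # positive step sizes
    A = [-rng.uniform(0.1, 2.0) for _ in range(N)]      # negative for a decaying state
    B = [[rng.gauss(0, 1) for _ in range(N)] for _ in range(L)]
    C = [[rng.gauss(0, 1) for _ in range(N)] for _ in range(L)]
    return u, delta, A, B, C, rng.gauss(0, 1)

def check(trials=20, L=8, N=4, tol=1e-9, seed=0):
    # Compare both implementations on small random problems
    rng = random.Random(seed)
    for _ in range(trials):
        case = random_case(L, N, rng)
        a, b = scan_loop(*case), scan_unrolled(*case)
        assert len(a) == len(b)
        assert all(math.isclose(x, y, rel_tol=tol, abs_tol=tol) for x, y in zip(a, b))
    return True
```

The unrolled form is slow, but because it computes the outputs without the step-by-step state update, agreement with the loop is a meaningful check on small inputs.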