Minimal reference implementation #10

Open
EelcoHoogendoorn opened this issue Dec 5, 2023 · 5 comments

Comments

@EelcoHoogendoorn

Thanks so much for providing this code; it looks very useful and reproducible.

As I understand it, the custom scan kernel can be quite important for performance, so it is great to see it here as well.

However, as a suggestion, I think it'd be super neat to also see a minimal Mamba reference implementation with minimal dependencies, simply for clarity of exposition: something that could be unit tested to behave the same as the custom kernel, at least on small datasets. Would that be a lot of work? Does it already exist somewhere? If a torch version exists, I'd be happy to port it to JAX as well.

@tridao
Collaborator

tridao commented Dec 5, 2023

There's a reference implementation of the selective scan in PyTorch here. That's the main primitive that requires CUDA.
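To make "reference implementation of the selective scan" concrete, here is a minimal sequential sketch in NumPy. It is not the repository's code; it assumes per-channel diagonal `A` of shape `(d, n)`, the discretization commonly used for Mamba (`exp(Δ·A)` for the state transition and the simplified step `Δ·B·u` for the input), and a skip term `D`. All names and shapes are illustrative.

```python
import numpy as np


def selective_scan_loop(u, delta, A, B, C, D):
    """Sequential selective-scan sketch (illustrative, not the repo's kernel).

    u, delta: (L, d)  inputs and per-step, per-channel step sizes
    A:        (d, n)  diagonal state matrix, one row per channel
    B, C:     (L, n)  per-step input/output projections
    D:        (d,)    skip connection
    returns:  (L, d)
    """
    L, d = u.shape
    h = np.zeros(A.shape)  # hidden state, (d, n)
    ys = []
    for t in range(L):
        dA = np.exp(delta[t][:, None] * A)                        # discretized A, (d, n)
        dBu = delta[t][:, None] * B[t][None, :] * u[t][:, None]   # discretized B times u, (d, n)
        h = dA * h + dBu                                          # h_t = Abar_t h_{t-1} + Bbar_t u_t
        ys.append(h @ C[t] + D * u[t])                            # y_t = C_t h_t + D u_t
    return np.stack(ys)


# With A = 0 and all-ones delta, B, C (and D = 0) the scan is a running sum of u.
u = np.array([[1.0], [2.0]])
y = selective_scan_loop(u, np.ones((2, 1)), np.zeros((1, 1)),
                        np.ones((2, 1)), np.ones((2, 1)), np.zeros(1))
# y[:, 0] == [1.0, 3.0]
```

A degenerate case like this running sum is a cheap sanity check before comparing the loop against a fused kernel on small random inputs.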

@EelcoHoogendoorn
Author

Right, I did see there is a reference implementation; I just wondered how close it is to being 'minimal'.

How close could Mamba get to this kind of minimalism?

@tridao
Collaborator

tridao commented Dec 6, 2023

The core is actually just a for loop; the code simplifies a lot if you only take the path where B/C are input-dependent.
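A sketch of what "only the path where B/C are input-dependent" might look like: Δ, B and C are all computed from the input `x`, and the rest is the plain loop over time. The weights `W_delta`, `W_B`, `W_C` are hypothetical dense projections standing in for whatever layers a real model uses, and softplus for Δ is an assumption to keep step sizes positive.

```python
import numpy as np


def softplus(z):
    # Numerically stable log(1 + exp(z)); keeps step sizes positive.
    return np.logaddexp(0.0, z)


def selective_ssm(x, A, W_delta, W_B, W_C, D):
    """Input-dependent (selective) SSM sketch.

    x: (L, d); A: (d, n); W_delta: (d, d); W_B, W_C: (d, n); D: (d,)
    W_* are hypothetical projections, not names from the repository.
    """
    delta = softplus(x @ W_delta)  # (L, d): step size per time step and channel
    B = x @ W_B                    # (L, n): input projection, depends on x
    C = x @ W_C                    # (L, n): output projection, depends on x
    h = np.zeros(A.shape)          # (d, n) hidden state
    ys = []
    for t in range(x.shape[0]):    # the core: one pass over time
        h = np.exp(delta[t][:, None] * A) * h + delta[t][:, None] * B[t] * x[t][:, None]
        ys.append(h @ C[t] + D * x[t])
    return np.stack(ys)
```

Because Δ, B and C change with every token, the recurrence cannot be folded into one fixed convolution kernel, which is why this path reduces to a scan. It is also causal: perturbing the last input leaves all earlier outputs unchanged.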

@johnma2006

Hi all, I wrote a minimal implementation here: https://github.com/johnma2006/mamba-minimal/tree/master. Hope it helps!

@EelcoHoogendoorn
Author

> Hi all, I wrote a minimal implementation here: https://github.com/johnma2006/mamba-minimal/tree/master. Hope it helps!

Thanks, that looks really clean and should be trivial to port to JAX. From what I understand, using JAX's scan also isn't competitive for LLM-scale models, but my intuition is that it'd be fine for some of the smaller experiments I'd want to try it on.
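For a JAX port, one natural structure is a step function over the carried state `h`, matching the contract of `jax.lax.scan(f, init, xs)`, where `f(carry, x)` returns `(new_carry, y)`. So the sketch runs without JAX, `scan` below is a hypothetical pure-NumPy stand-in with that contract, and the step assumes the same `exp(Δ·A)` / `Δ·B·u` discretization as a typical reference loop. In an actual port you would swap `np` for `jax.numpy` and `scan` for `jax.lax.scan`.

```python
import numpy as np


def scan(f, init, xs):
    """NumPy stand-in mimicking jax.lax.scan: f(carry, x) -> (carry, y).

    xs is a tuple of arrays sharing a leading time axis; each step receives
    a tuple with one row from each array. Returns (final_carry, stacked ys).
    """
    carry, ys = init, []
    for x in zip(*xs):
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


def make_step(A, D):
    # A: (d, n) diagonal state matrix; D: (d,) skip term (closed over, not scanned).
    def step(h, inputs):
        u_t, delta_t, B_t, C_t = inputs
        h = np.exp(delta_t[:, None] * A) * h + delta_t[:, None] * B_t[None, :] * u_t[:, None]
        return h, h @ C_t + D * u_t
    return step


# Degenerate check: A = 0, D = 0, all-ones delta/B/C gives a running sum of u.
A, D = np.zeros((1, 1)), np.zeros(1)
xs = (np.array([[1.0], [2.0]]), np.ones((2, 1)), np.ones((2, 1)), np.ones((2, 1)))
h_last, y = scan(make_step(A, D), np.zeros((1, 1)), xs)
# h_last == [[3.0]], y[:, 0] == [1.0, 3.0]
```

`jax.lax.scan` still executes the steps one after another, which is consistent with the intuition that it won't match a fused kernel at LLM scale but is likely adequate for small models. `jax.lax.associative_scan` is the parallel alternative if the recurrence is rewritten in associative form.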
