
Sequence parallelism in the mixer (Context Parallelism) #482

Open
TranSirius opened this issue Jul 21, 2024 · 1 comment
@TranSirius

The general question is: does mamba-ssm currently support sequence parallelism in the mixer?

I noticed that Section 8.2 of the Mamba-2 paper proposes a potential way to split the activations across multiple devices while mixing information among tokens. Does the current version of mamba-ssm support such a context-parallelism scheme?

By the way, if it is possible to confirm this: the suggested implementation would have to be incorporated into the fast-scan algorithm. Since that algorithm is a parallel tree traversal, each node should be computed on a single device. In the leaf-to-root pass, communication is invoked whenever two sibling nodes are computed on different devices, in order to transmit the hidden state; in the root-to-leaf pass, communication is triggered in the same way. I attach a simple illustration of how CP could be implemented. As a consequence, the CP_SIZE is also constrained by the number of children per node in the fast-scan algorithm.
(Just to confirm whether I am understanding this correctly, thanks.)

[Image: illustration of the proposed context-parallel tree scan]
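
For reference, here is a minimal single-process sketch of the up-sweep/down-sweep structure I have in mind, using the scalar recurrence h_t = a_t * h_{t-1} + b_t as a stand-in for the SSM state update. All names are made up for illustration and nothing here is the mamba-ssm API; the lines marked "boundary" are where a real context-parallel run would replace a local combine with a send/recv between devices.

```python
def combine(earlier, later):
    """Compose (A1, B1) followed by (A2, B2) into one pair (A, B)."""
    A1, B1 = earlier
    A2, B2 = later
    return A2 * A1, A2 * B1 + B2

def exclusive_tree_scan(pairs):
    """Blelloch-style scan over per-device (A, B) pairs.

    Each "device" is reduced to a pair (A, B) with h_end = A * h_start + B.
    Returns, for each device, the pair summarizing every previous device,
    i.e. the prefix that device needs to finish its own chunk locally.
    Assumes len(pairs) is a power of two.
    """
    n = len(pairs)
    tree = list(pairs)
    # Leaf-to-root pass: combine sibling nodes.
    step = 1
    while step < n:
        for i in range(2 * step - 1, n, 2 * step):
            tree[i] = combine(tree[i - step], tree[i])  # boundary: send/recv here
        step *= 2
    # Root-to-leaf pass: push exclusive prefixes back down.
    tree[n - 1] = (1.0, 0.0)  # identity: h -> 1*h + 0
    step = n // 2
    while step >= 1:
        for i in range(2 * step - 1, n, 2 * step):
            parent, left = tree[i], tree[i - step]
            tree[i - step] = parent
            tree[i] = combine(parent, left)             # boundary: send/recv here
        step //= 2
    return tree

if __name__ == "__main__":
    pairs = [(0.9, 1.0), (0.8, 2.0), (0.7, 3.0), (0.6, 4.0)]  # one pair per "device"
    prefixes = exclusive_tree_scan(pairs)
    # Check against a plain left-to-right scan.
    acc, expected = (1.0, 0.0), []
    for p in pairs:
        expected.append(acc)
        acc = combine(acc, p)
    assert all(abs(a - c) < 1e-9 and abs(b - d) < 1e-9
               for (a, b), (c, d) in zip(prefixes, expected))
```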
@josiahbjorgaard

josiahbjorgaard commented Sep 24, 2024

I don't believe it does; I'm currently working on an implementation of it.

I think there is another detail beyond what you've sketched out here. There is a weighted cumulative sum over the final states of all previous chunks in the sequence, and it needs to be updated for each group of chunks once those chunks have been scattered across multiple GPUs. It is shown in Figure 7 of the Mamba-2 paper (the yellow arrows). It is here in the code, but has not been modified for context parallelism.
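
Roughly, the recurrence I mean is the chunk-level one below. This is just a plain restatement, not the actual kernel, and `chunk_decay` / `local_state` are names I'm making up for illustration:

```python
import torch

def cross_chunk_states(chunk_decay, local_state):
    """Sequential pass over chunks that supplies each chunk's incoming state.

    chunk_decay: (num_chunks,) product of the per-step A factors inside each chunk.
    local_state: (num_chunks, d_state) final state of each chunk, computed as if
                 that chunk had started from a zero state.
    Returns init_state, where init_state[k] is the state entering chunk k:
        init_state[0] = 0
        init_state[k] = chunk_decay[k-1] * init_state[k-1] + local_state[k-1]
    """
    init_state = torch.zeros_like(local_state)
    for k in range(1, local_state.shape[0]):
        init_state[k] = chunk_decay[k - 1] * init_state[k - 1] + local_state[k - 1]
    return init_state
```

This sequential coupling between chunks is exactly the part that has to be distributed when the chunks live on different GPUs.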

Distributing it requires one of two approaches: either gather the final states from all GPUs that hold earlier chunks of the sequence, compute the weight updates (the products of the A elements), and then do a weighted-sum reduction of those previous 'final' states on each GPU; or pass the state sequentially, point-to-point from GPU to GPU, so that each GPU weights the incoming 'final' state before forwarding it.
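
The second (sequential point-to-point) option would look something like this. It assumes a torch.distributed process group is already initialized (e.g. launched via torchrun), one contiguous group of chunks per rank, and the function/variable names are made up:

```python
import torch
import torch.distributed as dist

def pass_state_p2p(local_final_state, local_total_decay):
    """Sequentially forward the SSM state across ranks.

    local_final_state: (d_state,) final state of this rank's chunks, computed
                       as if the incoming state were zero.
    local_total_decay: scalar tensor, product of all A factors on this rank.
    Returns the state entering this rank's first chunk.
    """
    rank, world = dist.get_rank(), dist.get_world_size()
    incoming = torch.zeros_like(local_final_state)
    if rank > 0:
        dist.recv(incoming, src=rank - 1)   # wait for the left neighbour's state
    if rank < world - 1:
        # Correct this rank's final state with the incoming prefix, then
        # forward it to the right neighbour.
        outgoing = local_total_decay * incoming + local_final_state
        dist.send(outgoing, dst=rank + 1)
    return incoming
```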

It looks to me like everything else can be computed per chunk (i.e. per GPU), except for the convolution over the sequence that runs before the mixer, which may also need to be modified so it doesn't become a bottleneck when running sequence/context parallel.
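
For the convolution, I think the issue is just that tokens at a shard boundary need the last (kernel_size - 1) tokens of the previous shard as left context. A sketch of that halo exchange with a plain depthwise F.conv1d (again, not the mamba-ssm conv kernel, and assuming an initialized process group and kernel_size > 1):

```python
import torch
import torch.nn.functional as F
import torch.distributed as dist

def causal_conv_with_halo(x_local, weight, kernel_size):
    """Causal depthwise conv over a sequence shard, with boundary exchange.

    x_local: (batch, channels, local_len) this rank's shard of the sequence.
    weight:  (channels, 1, kernel_size) depthwise conv weight.
    """
    rank, world = dist.get_rank(), dist.get_world_size()
    halo = torch.zeros(x_local.shape[0], x_local.shape[1], kernel_size - 1,
                       device=x_local.device, dtype=x_local.dtype)
    # Send my last (kernel_size - 1) tokens to the next rank; receive the
    # previous rank's tail as my left context. Rank 0 keeps zeros, which
    # matches the usual causal zero-padding at the start of the sequence.
    if rank < world - 1:
        dist.send(x_local[..., -(kernel_size - 1):].contiguous(), dst=rank + 1)
    if rank > 0:
        dist.recv(halo, src=rank - 1)
    x_padded = torch.cat([halo, x_local], dim=-1)
    # No extra padding: the halo supplies the causal left context, so the
    # output length equals local_len.
    return F.conv1d(x_padded, weight, groups=x_local.shape[1])
```

Unlike the state recurrence, this is only a neighbour-to-neighbour exchange, so it shouldn't serialize the whole sequence, but it still has to be added on top of the existing conv path.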

It would be great to get some feedback on this if anyone else is working on it or understands the context-parallel strategy for the SSD model.
