
Question: does Mamba support variable-length input or cu_seqlens like flash attention? #180

Open
zigzagcai opened this issue Feb 20, 2024 · 12 comments

Comments

@zigzagcai
Contributor

zigzagcai commented Feb 20, 2024

We know that flash attention supports cu_seqlens, which removes padding for variable-length input in a batch and stores only the real tokens. This is useful for improving computational efficiency when packing multiple short sequences together.

So, does Mamba also have a mechanism for variable-length input, such as cu_seqlens in flash attention?

@tridao
Collaborator

tridao commented Feb 20, 2024

Yes, there should be ways to deal with variable length. It's not implemented yet, however.

@zigzagcai
Contributor Author

Got it. Thank you Tri Dao!

@zigzagcai
Contributor Author

zigzagcai commented Feb 21, 2024

Yes, there should be ways to deal with variable length. It's not implemented yet, however.

Sorry, but I still have some confusion:

Is it theoretically possible for Mamba to provide a variable-length API like Flash-Attention's flash_attn_varlen_qkvpacked_func (Dao-AILab/flash-attention#432 (comment))?
In most cases, for computing efficiency, we want to concatenate short samples into one packed sample. For Transformer-based models, we can use the flash-attention API, which provides cu_seqlens to process packed samples.
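For reference, here is a minimal sketch of how the flash-attention varlen API is typically used with cu_seqlens (shapes and the exact signature may differ between flash-attn releases, so please double-check against the installed version):

```python
import torch
from flash_attn import flash_attn_varlen_qkvpacked_func

# Pack three sequences of lengths 5, 3 and 7 into one tensor with no padding.
# cu_seqlens holds the exclusive prefix sums of the sequence lengths.
cu_seqlens = torch.tensor([0, 5, 8, 15], dtype=torch.int32, device="cuda")
total_tokens, nheads, headdim = 15, 8, 64
qkv = torch.randn(total_tokens, 3, nheads, headdim,
                  dtype=torch.float16, device="cuda")

out = flash_attn_varlen_qkvpacked_func(
    qkv, cu_seqlens, max_seqlen=7, causal=True
)
# out: (total_tokens, nheads, headdim); each packed sequence attends only to itself.
```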

@zigzagcai
Contributor Author

zigzagcai commented Feb 21, 2024

From my understanding, since the conv1d and the parallel associative scan in the Mamba block are linear operations, in theory we could make the Mamba block capable of processing packed sequences with the help of an attention mask or cu_seqlens.
For example, we would want the Mamba block to process (packed_sequence, hidden_size) instead of (batch_size, seq_length, hidden_size), as flash attention does.

I'm not sure whether my understanding is correct. I'm just curious whether it is possible to feed one packed sequence of shape (packed_sequence, hidden_size) into the Mamba block, as has been done for LSTMs (here) or Transformer blocks. A rough sketch of the state-reset idea follows below.
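To make the idea concrete, here is a purely sequential, reference-only sketch of what cu_seqlens-style packing could look like for a diagonal linear recurrence, with the hidden state reset to zero at every sequence boundary. This only illustrates the intended semantics and is not the actual Mamba kernel; the function name packed_linear_scan is made up for this example:

```python
import torch

def packed_linear_scan(a, b, cu_seqlens):
    """Reference scan over a packed sequence of total length T.

    Computes h_t = a_t * h_{t-1} + b_t element-wise and resets h to zero at
    every boundary given by cu_seqlens (flash-attention-style prefix sums).
    a, b: (T, hidden_size); cu_seqlens: (num_seqs + 1,) int tensor.
    """
    out = torch.empty_like(b)
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        h = torch.zeros(b.shape[-1], dtype=b.dtype, device=b.device)
        for t in range(start, end):
            h = a[t] * h + b[t]
            out[t] = h
    return out

# Two sequences of lengths 3 and 2 packed into a single (5, hidden) tensor.
a = torch.full((5, 4), 0.9)
b = torch.randn(5, 4)
out = packed_linear_scan(a, b, torch.tensor([0, 3, 5]))
```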

@zigzagcai
Contributor Author

zigzagcai commented Feb 26, 2024

Just another question: could Mamba be parallelized over the seq_len dimension, like what has been done in flash-attention?

@tridao
Collaborator

tridao commented Feb 26, 2024

It's theoretically possible to process variable lengths / packed sequences, but the implementation will be a bit tricky.
Parallelizing over the seq_len dimension reduces to how one would parallelize an associative scan (e.g. with the Blelloch scan).
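As a small illustration of why the recurrence parallelizes over seq_len: the scan operation is associative, so it can be computed in O(log L) parallel steps. The sketch below uses a simple log-step doubling scheme (Hillis-Steele style) rather than the work-efficient Blelloch scan mentioned above, purely to show the idea; it is not the kernel used in this repo:

```python
import torch
import torch.nn.functional as F

def parallel_linear_scan(a, b):
    """Inclusive scan for h_t = a_t * h_{t-1} + b_t over the last dimension.

    Uses O(log L) doubling steps; each step composes every element with the
    element `shift` positions to its left under the associative operation
    (a1, b1) o (a2, b2) = (a2 * a1, a2 * b1 + b2).
    a, b: (..., L) tensors.
    """
    L = a.shape[-1]
    shift = 1
    while shift < L:
        a_prev = F.pad(a[..., :-shift], (shift, 0), value=1.0)  # identity: a = 1
        b_prev = F.pad(b[..., :-shift], (shift, 0), value=0.0)  # identity: b = 0
        a, b = a * a_prev, a * b_prev + b
        shift *= 2
    return b  # b[..., t] now equals h_t (with h_0 = 0)

# Check against the sequential recurrence on a small example.
a, b = torch.rand(2, 8), torch.randn(2, 8)
h, ref = torch.zeros(2), []
for t in range(8):
    h = a[:, t] * h + b[:, t]
    ref.append(h)
assert torch.allclose(parallel_linear_scan(a, b), torch.stack(ref, dim=-1), atol=1e-6)
```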

@albertfgu
Contributor

In practice, depending on your setting, you may be able to simply concatenate the sequences and pass the whole sequence in (without enforcing state resetting at sequence boundaries). I've used this in the past where it has worked fine in some settings.

@deroholic
Contributor

In practice, depending on your setting, you may be able to simply concatenate the sequences and pass the whole sequence in (without enforcing state resetting at sequence boundaries). I've used this in the past where it has worked fine in some settings.

It is often done that way, but it does cause cross-contamination between samples during training, and that is usually not desirable.

@albertfgu
Contributor

Yes. I'm just saying sometimes it's also fine :)

@zigzagcai
Contributor Author

zigzagcai commented Mar 4, 2024

In practice, depending on your setting, you may be able to simply concatenate the sequences and pass the whole sequence in (without enforcing state resetting at sequence boundaries). I've used this in the past where it has worked fine in some settings.

Hi @albertfgu @tridao, I have another point of confusion about Mamba. Does that mean the selective SSM mechanism can learn the boundary patterns through delta, or that we can set delta -> inf to manually specify the sequence boundaries in a packed (cumulative) sequence input?
I found the following description in Section 3.5.2 of the Mamba paper:
[screenshot of the relevant passage from Section 3.5.2 of the Mamba paper]
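As a quick numerical illustration of the reset intuition (my own sketch, not code from this repo): under ZOH discretization A_bar = exp(delta * A), and with A negative a very large delta drives A_bar toward 0, so the recurrence h_t = A_bar * h_{t-1} + B_bar * x_t effectively forgets the previous state at that position:

```python
import torch

# With A < 0, A_bar = exp(delta * A) -> 0 as delta grows, so the carried term
# A_bar * h_{t-1} vanishes and the SSM effectively resets its state at that token.
A = torch.tensor(-1.0)
for delta in [0.1, 1.0, 10.0, 100.0]:
    print(f"delta={delta:6.1f}  A_bar={torch.exp(delta * A).item():.3e}")
# A_bar drops from roughly 0.9 at delta=0.1 to about 4.5e-05 at delta=10,
# and is numerically negligible by delta=100.
```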

@zigzagcai
Contributor Author

zigzagcai commented Mar 5, 2024

I also saw blog posts on together.ai and on cartesia.ai, where the listed next steps show that variable-length training is on the future roadmap.
It would be fantastic if Mamba could provide such a feature, like Transformers do, in the future!
[screenshot of the blog's roadmap section]

@zigzagcai
Contributor Author

zigzagcai commented Mar 14, 2024

Update:
Variable-length sequences are now supported in Mamba via #244.
