add FlashAttentionKwargs and seq_idx to flat collator #36456
base: main
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).
(force-pushed from 1d732bf to f0113af)
I'd like to take a look as well when you think you're ready, so gladly ping then :)
(force-pushed from 3daac1b to a3fc94c)
@vasqu could you please take a look when you have time? Thank you.
Smaller issues/nits overall. I'd be for a warning in case the settings might cause issues (RoPE with FA kwargs only).
Otherwise, not on you, but I think it would be nice to have equivalent collator tests for torch at least. Usually, all corresponding paths are tested (pt, tf, np).
return_position_ids=True,
return_flash_attn_kwargs=False,
IIUC, #35941 points to a problem where the FA kwargs will subsequently cause issues on the RoPE paths (with FA kwargs True and position ids False).
I'd be for a warning on init for the bad combo of FA kwargs True and position ids False (maybe someone has a different use case, which shouldn't directly cause errors).
maybe someone has a different use case which shouldn't directly cause errors
Yeah, I thought about warnings in cases like that, but I was hesitant because different models have different requirements.
For example, if a transformer model uses FA but not RoPE, then FA True, pos_ids False makes sense. And for a mamba-only model (like mamba2), FA False, pos_ids False, seq_idx True is what you'd use.
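For concreteness, a rough sketch of those combinations (a sketch only: the flag names follow this PR's diff, `return_seq_idx` is assumed from the PR title, and defaults may differ):

```python
from transformers import DataCollatorWithFlattening

# Transformer that uses FA2 but no RoPE: the FA kwargs alone are enough.
fa_only_collator = DataCollatorWithFlattening(
    return_position_ids=False,
    return_flash_attn_kwargs=True,
)

# Mamba-only model (e.g. mamba2): no FA kwargs, no position ids, just seq_idx.
mamba_collator = DataCollatorWithFlattening(
    return_position_ids=False,
    return_flash_attn_kwargs=False,
    return_seq_idx=True,
)
```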
So IMO it should be up to the model to ensure it's getting the inputs it needs, and to raise a ValueError or similar if an improper combination is passed.
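A minimal sketch of the kind of check I mean, with hypothetical names (a real model would validate whatever combination it actually requires):

```python
def _check_padding_free_inputs(position_ids, cu_seq_lens_q):
    # Hypothetical helper: a RoPE-based model that receives FA kwargs but no
    # position_ids cannot compute rotary embeddings correctly on the flattened batch.
    if cu_seq_lens_q is not None and position_ids is None:
        raise ValueError(
            "FlashAttention kwargs were passed without position_ids; this model "
            "needs position_ids for RoPE on the padding-free path."
        )
```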
Fair point; then the FA utils should be written in a way that ensures this (which should be a follow-up PR to this one).
I don't think this can be addressed at the level of the FA utils, since different models can legitimately use different valid combinations here. The FA utils just need to be able to handle the different combos.
It does look like non-trivial FlashAttentionKwargs with position_ids=attention_mask=None is currently not supported, though. IIUC you'd end up in this block with all of your cu_seq_lens_{q,k} etc. ignored.
Couldn't we just detect whether FA kwargs were passed (before the if/else into the different paths) and handle it there if position_ids is None? It might be an error (unintentional path) or a warning (no-padding path); unsure here. (Or we could even implement that path ourselves, which seems unwanted.)
Imo it's a bit confusing that the original flash attn can handle those args while we can't. As a user who is familiar with fa but not transformers, it would silently fall through on me.
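Roughly what I'm picturing, as a sketch only (this is not the actual structure of the FA utils; `flash_attn_varlen_func` / `flash_attn_func` are the upstream flash-attn API, and the surrounding names are illustrative):

```python
from flash_attn import flash_attn_func, flash_attn_varlen_func


def _dispatch_flash_attention(query, key, value, attention_mask, position_ids,
                              cu_seq_lens_q=None, cu_seq_lens_k=None,
                              max_seq_len_q=None, max_seq_len_k=None):
    # Prefer explicit varlen metadata from the collator over inferring packing
    # from attention_mask / position_ids.
    if cu_seq_lens_q is not None and cu_seq_lens_k is not None:
        return flash_attn_varlen_func(
            query, key, value,
            cu_seqlens_q=cu_seq_lens_q, cu_seqlens_k=cu_seq_lens_k,
            max_seqlen_q=max_seq_len_q, max_seqlen_k=max_seq_len_k,
            causal=True,
        )
    if attention_mask is not None:
        ...  # existing unpad -> varlen -> pad path
    elif position_ids is not None:
        ...  # existing path that derives cu_seqlens from position_ids
    else:
        return flash_attn_func(query, key, value, causal=True)
```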
Couldn't we just detect whether FA kwargs were passed (before the if/else into the different paths) and handle it there if position_ids is None?
Yeah, not sure why this isn't done in the current FA utils. Also found this confusing.
Yeah, either way I think some sort of handling is warranted here which should help ease up the bamba checks (hopefully).
I found it a little surprising that it's all …
Would you be willing to add the other paths or at least pt? Seems like an oversight on the initial PR 👀
Yep, and I just discovered that the … So, this is all going to take some reworking. I'll ping again later.
(force-pushed from a3fc94c to 601e649)
Alright, made a few changes:
Do you have any advice on the test @vasqu? I wanted a non-trivial test that the collator outputs are in the right format (which is easy to get wrong), but the above seems like overkill. EDIT: I only verified that the …
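For illustration, the kind of hand-checkable format assertions I have in mind (key names, dtypes, and shapes here are assumptions for the sketch, not necessarily what lands in the PR):

```python
def test_flat_collator_output_format(collator):
    # Two toy examples of lengths 3 and 2, flattened into one sequence of 5 tokens.
    features = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
    batch = collator(features, return_tensors="np")

    # Cumulative sequence lengths start at 0 and end at the total token count.
    assert batch["cu_seq_lens_q"].tolist() == [0, 3, 5]
    assert str(batch["cu_seq_lens_q"].dtype) == "int32"
    # seq_idx labels every token with the index of the example it came from.
    assert batch["seq_idx"].flatten().tolist() == [0, 0, 0, 1, 1]
```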
I think it's overall solid. Some implementations could be simplified imo (e.g. the wrapper at the end of the collator), and I left some other smaller comments.
No batch dimension is expected on any of the FlashAttentionKwargs, and all these variables and seq_idx must be int32 rather than the default int64. I removed the reliance on the default data collators to achieve this.
Personally, I don't think the comment about the batch dim provides any real value - I'd drop it, it kinda confused me at first. The int32 adjustments are good; I'm only worried about max_seq_len_q/k (should be a simple Python int, but see comments). A tiny sketch of what I mean follows at the end of this comment.
I added a ModelTesterMixin::test_flash_attention_2_padding_matches_padding_free_with_position_ids_from_flat_collator which both has an awfully long name and is probably an unnecessarily expensive addition to the test suite.
Imo, could be added to the previous padding-free test. If you want to keep it as is, I'd suggest renaming it (the from_flat_collator part is rather non-telling).
Added tf and pt tests for the flat collator.
🥳
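As an aside, a tiny sketch of the int32 / plain-int split I mean (kwarg names follow this thread, not necessarily the final API):

```python
import torch

seq_lens = [3, 5, 2]  # lengths of the examples packed into one flattened batch

# Cumulative sequence lengths as an int32 tensor, as flash-attn expects.
cu_seq_lens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)
# Max lengths stay plain Python ints rather than tensors.
max_seq_len = max(seq_lens)

flash_attn_kwargs = {
    "cu_seq_lens_q": cu_seq_lens,
    "cu_seq_lens_k": cu_seq_lens,
    "max_seq_len_q": max_seq_len,
    "max_seq_len_k": max_seq_len,
}
```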
ret["input_ids"] = data_cls(ret["input_ids"], dtype=dtype_64)[None] | ||
ret["labels"] = data_cls(ret["labels"], dtype=dtype_64)[None] |
I have a feeling that this (and the following calls) could be simplified, e.g. something like
for key, dtype in zip(["input_ids", ...], [dtype_64, ...]):
    if ret.get(key) is not None:
        ret[key] = data_cls(ret[key], dtype=dtype)
rough draft ^
Cleaned up.
LGTM! cc @ArthurZucker for core maintainer review
Looks like you'll need to run …
(force-pushed from 5590060 to b582c82)
@ArthurZucker please let me know if I can answer any further questions
@vasqu any advice here?
@garrett361 sorry, I don't have anything to add. Arthur is quite busy, so notifications can easily get lost at times; I'm gonna cc again.
cc core maintainer review @ArthurZucker @Cyrilvallez
What does this PR do?
Adds additional, optional return values in DataCollatorWithFlattening, as needed for padding-free training with particular models. Relates to #35861 and #35941.
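For example, a minimal usage sketch under the assumption that the new flags are named `return_flash_attn_kwargs` and `return_seq_idx` (as in the diff and PR title):

```python
from transformers import AutoTokenizer, DataCollatorWithFlattening

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any fast tokenizer

collator = DataCollatorWithFlattening(
    return_position_ids=True,
    return_flash_attn_kwargs=True,
    return_seq_idx=True,
)

features = [tokenizer(text) for text in ["first example", "a second, longer example"]]
batch = collator(features)
# Besides the flattened input_ids/labels and position_ids, the batch now also
# carries the FlashAttention kwargs (cu_seq_lens_*, max_seq_len_*) and seq_idx.
```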
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.