
add FlashAttentionKwargs and seq_idx to flat collator #36456

Open
wants to merge 23 commits into base: main

Conversation

@garrett361 commented Feb 27, 2025

What does this PR do?

Adds additional, optional return values in DataCollatorWithFlattening as needed for padding-free training with particular models.
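
As a rough usage sketch (flag names are assumed from the PR title and diff, and the exact output keys may differ):

from transformers import DataCollatorWithFlattening

# Sketch only: return_seq_idx / return_flash_attn_kwargs are assumed flag names.
collator = DataCollatorWithFlattening(
    return_position_ids=True,       # existing behavior
    return_seq_idx=True,            # assumed: per-token sequence index for mamba-style models
    return_flash_attn_kwargs=True,  # assumed: cu_seq_lens_{q,k} and max lengths for FA2 varlen kernels
)

features = [
    {"input_ids": [1, 2, 3, 4]},
    {"input_ids": [5, 6, 7]},
]
batch = collator(features)
# The two examples are concatenated into a single row with no padding; the extra,
# optional outputs let FA2 / mamba kernels recover the sequence boundaries.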

Relates to #35861 and #35941.

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@garrett361 garrett361 marked this pull request as draft February 27, 2025 16:23

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

@garrett361 garrett361 force-pushed the flash-attn-kwargs-in-flat-collator branch from 1d732bf to f0113af Compare February 27, 2025 16:23
@garrett361 garrett361 mentioned this pull request Feb 27, 2025
@vasqu (Contributor) commented Feb 27, 2025

I'd like to take a look as well when you think it's ready, so feel free to ping me then :)

@vasqu (Contributor) commented Feb 27, 2025

cc @ArthurZucker

@garrett361 garrett361 force-pushed the flash-attn-kwargs-in-flat-collator branch 2 times, most recently from 3daac1b to a3fc94c Compare March 4, 2025 14:39
@garrett361 garrett361 marked this pull request as ready for review March 4, 2025 14:57
@garrett361 (Author)

@vasqu could you please take a look when you have time? Thank you.

@vasqu (Contributor) left a comment

Smaller issues/nits overall. I'd be in favor of a warning for settings that might cause issues (RoPE models receiving only the FA kwargs).

Otherwise, not on you but I think it would be nice to have equivalent collator tests on torch at least. Usually, all corresponding paths are tested (pt, tf, np).

Comment on lines +1799 to +1897
return_position_ids=True,
return_flash_attn_kwargs=False,
Contributor

IIUC, #35941 points to a problem where the FA kwargs will subsequently cause issues on the RoPE paths (with FA kwargs True and position ids False).

I'd be in favor of a warning on init for the bad combo of FA kwargs True and position ids False (maybe someone has a different use case, so it shouldn't directly cause errors).
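
A minimal sketch of the kind of warning meant here (a hypothetical helper standing in for a check in DataCollatorWithFlattening.__init__, flag names taken from the diff above):

import warnings

def _warn_on_risky_flag_combo(return_flash_attn_kwargs: bool, return_position_ids: bool) -> None:
    # Hypothetical helper: warn on the combination discussed above rather than erroring out.
    if return_flash_attn_kwargs and not return_position_ids:
        warnings.warn(
            "return_flash_attn_kwargs=True with return_position_ids=False may break models "
            "that compute RoPE from position_ids."
        )

_warn_on_risky_flag_combo(return_flash_attn_kwargs=True, return_position_ids=False)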

Author

maybe someone has a different use case which shouldn't directly cause errors

Yeah, I thought about warnings in cases like that, but I was hesitant because of different requirements for different models.

For example, if a transformer model uses FA but not RoPE, then FA True, pos_ids False makes sense. And for a mamba-only model (like mamba2), FA False, pos_ids False, seq_idx True is what you'd use.
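
As a sketch of those combinations (again assuming the flag names from this PR; return_seq_idx in particular is assumed from the PR title):

from transformers import DataCollatorWithFlattening

# Transformer with FA2 and RoPE: position_ids plus the FA kwargs.
rope_fa_collator = DataCollatorWithFlattening(
    return_position_ids=True, return_flash_attn_kwargs=True
)

# Transformer with FA2 but no RoPE: the FA kwargs alone are enough.
fa_only_collator = DataCollatorWithFlattening(
    return_position_ids=False, return_flash_attn_kwargs=True
)

# Mamba-only model (e.g. mamba2): only the per-token sequence index is needed.
mamba_collator = DataCollatorWithFlattening(
    return_position_ids=False, return_flash_attn_kwargs=False, return_seq_idx=True
)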

Author

So IMO it should be up to the model to ensure it's getting the inputs it needs, and to raise a ValueError or similar if an improper combination is passed.

Contributor

Fair point; then the FA utils should be written in a way that ensures this (which should be a follow-up PR to this one).

Author

I don't think this can be addressed at the level of the FA utils, since different models can logically use different valid combinations here. The FA utils just need to be able to handle the different combos.

It does look like passing non-trivial FlashAttentionKwargs with position_ids=None and attention_mask=None is currently not supported, though. IIUC you'd end up in this block with all of your cu_seq_lens_{q,k} etc. ignored.

Contributor

Couldn't we just detect whether FA kwargs were passed (before the if/else that branches into the different paths) and handle that case when position_ids is None? It might warrant an error (unintentional path) or a warning (no-padding path); unsure here. (Or we could even implement that path ourselves, which seems unwanted.)

IMO it's a bit confusing that the original flash-attn can handle those args while we can't. As a user familiar with FA but not transformers, it would silently fall through on me.
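
A rough sketch of the kind of detection being proposed (illustrative only, not the actual transformers FA utils; names are made up):

def choose_fa_path(attention_mask, position_ids, cu_seq_lens_q, cu_seq_lens_k):
    # Check for explicit FA kwargs before branching on attention_mask / position_ids,
    # instead of silently falling through to the standard (padded) path.
    if cu_seq_lens_q is not None and cu_seq_lens_k is not None:
        return "varlen_from_fa_kwargs"
    if attention_mask is not None:
        return "varlen_from_attention_mask"
    if position_ids is not None:
        return "varlen_from_position_ids"
    return "standard"

# With FA kwargs but no mask/positions, we'd now take the varlen path rather than the padded one.
assert choose_fa_path(None, None, [0, 4, 7], [0, 4, 7]) == "varlen_from_fa_kwargs"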

Author

Couldn't we just detect if fa kwargs were passed (before the if else into the different paths) and handle it then if position_ids is None?

Yeah, not sure why this isn't done in the current FA utils. Also found this confusing.

Contributor

Yeah, either way I think some sort of handling is warranted here, which should (hopefully) help ease up the bamba checks.

@garrett361 (Author)

I think it would be nice to have equivalent collator tests on torch at least

I found it a little surprising that it's all numpy, as far as I saw.

@vasqu (Contributor) commented Mar 4, 2025

Would you be willing to add the other paths or at least pt? Seems like an oversight on the initial PR 👀

@garrett361 (Author)

Would you be willing to add the other paths or at least pt? Seems like an oversight on the initial PR

Yep, and I just discovered that the "pt" path is broken because int entries like those in cu_seq_lens_{q,k} turn into torch.int64s, whereas FA2 needs them to be torch.int32s 😭

So, this is all going to take some reworking. I'll ping again later.
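
For reference, a minimal illustration of the dtype issue (FA2's varlen interface expects int32 cumulative sequence lengths):

import torch

cu_seq_lens = [0, 4, 7, 12]

# Default "pt" conversion of Python ints gives int64, which FA2 rejects.
assert torch.tensor(cu_seq_lens).dtype == torch.int64

# The varlen FA2 kernels want int32.
cu_seq_lens_q = torch.tensor(cu_seq_lens, dtype=torch.int32)
assert cu_seq_lens_q.dtype == torch.int32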

@garrett361 garrett361 force-pushed the flash-attn-kwargs-in-flat-collator branch from a3fc94c to 601e649 Compare March 14, 2025 15:48
@garrett361 (Author) commented Mar 14, 2025

Alright, made a few changes:

  • No batch dimension is expected on any of the FlashAttentionKwargs, and these variables and seq_idx must be int32 rather than the default int64 (see the sketch after this list). I removed the reliance on the default data collators to achieve this.
  • I added a ModelTesterMixin::test_flash_attention_2_padding_matches_padding_free_with_position_ids_from_flat_collator which both has an awfully long name and is probably an unnecessarily expensive addition to the test suite.
  • Added tf and pt tests for the flat collator.
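
A sketch of the output format described in the first bullet, for two packed sequences of lengths 4 and 3 (the FA kwarg key names are assumed from the discussion and may differ):

import torch

expected = {
    "input_ids": torch.tensor([[1, 2, 3, 4, 5, 6, 7]]),                 # usual [1, total_len] shape
    "position_ids": torch.tensor([[0, 1, 2, 3, 0, 1, 2]]),              # restarts at each sequence
    "seq_idx": torch.tensor([0, 0, 0, 0, 1, 1, 1], dtype=torch.int32),  # int32, no batch dim
    "cu_seq_lens_q": torch.tensor([0, 4, 7], dtype=torch.int32),        # int32, no batch dim
    "cu_seq_lens_k": torch.tensor([0, 4, 7], dtype=torch.int32),
    "max_seq_len_q": 4,  # assumed to be a plain Python int; see the review comments below
    "max_seq_len_k": 4,
}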

Do you have any advice on the test, @vasqu? I wanted a non-trivial test that the collator outputs are in the right format (which is easy to get wrong), but the above seems like overkill.

EDIT: I only verified that the LlamaModelTest::test_flash_attention_2_padding_matches_padding_free_with_position_ids_from_flat_collator version of the test passes, I should probably say.

@vasqu (Contributor) left a comment

I think it's solid overall. Some implementations could be simplified IMO (e.g. the wrapper at the end of the collator), and I left some other smaller comments.

No batch dimension is expected on any of the FlashAttentionKwargs, and all these variables and seq_idx must be int32 rather than the default int64. I removed the reliance on the default data collators to achieve this.

Personally, I don't think the comment about the batch dim provides any real value; I would drop it, as it kind of confused me at first. The int32 adjustments are good; I'm only worried about max_seq_len_q/k (should be a simple Python int32, but see comments).

I added a ModelTesterMixin::test_flash_attention_2_padding_matches_padding_free_with_position_ids_from_flat_collator which both has an awfully long name and is probably an unnecessarily expensive addition to the test suite.

IMO, it could be added to the previous padding-free test. If you want to keep it as is, I'd suggest a rename (the "from flattening collator" suffix is rather non-telling).

Added tf and pt tests for the flat collator.

🥳

Comment on lines 1872 to 1873
ret["input_ids"] = data_cls(ret["input_ids"], dtype=dtype_64)[None]
ret["labels"] = data_cls(ret["labels"], dtype=dtype_64)[None]
Contributor

I have a feeling that this (and the following calls) could be simplified, e.g. something like

# Convert only the keys that are present, each with its matching dtype.
for key, dtype in zip(["input_ids", ...], [dtype_64, ...]):
    if ret.get(key) is not None:
        ret[key] = data_cls(ret[key], dtype=dtype)

rough draft ^

Author

Cleaned up.

@vasqu (Contributor) left a comment

LGTM! cc @ArthurZucker for core maintainer review

@vasqu (Contributor) commented Mar 17, 2025

Looks like you'll need to run make style for the CI to be happy. I think the PR is overall ready, though.

@garrett361 garrett361 force-pushed the flash-attn-kwargs-in-flat-collator branch from 5590060 to b582c82 Compare March 21, 2025 14:08
@garrett361 (Author)

@ArthurZucker please let me know if I can answer any further questions

@garrett361 (Author)

@vasqu any advice here?

@vasqu (Contributor) commented Mar 29, 2025

@garrett361 sorry, I don't have anything to add. Arthur is quite busy, so notifications can easily get lost at times; I'm going to cc again.

@vasqu (Contributor) commented Mar 29, 2025

cc @ArthurZucker @Cyrilvallez for core maintainer review
