Sample packing for map datasets with correct RoPE encoding and no cross-contamination #875
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/875. Note: links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 97e69f4 with merge base f3611e5.
Super happy to see this project :) I wanted to find a codebase that's simple, usable, with reasonable performance. I think torchtune pretty much satisfies most of my needs except sentence packing (which you already did in this PR!) and integrating more models (e.g., Mixtral). Would be excited to see this PR go through!
torchtune/datasets/_packed.py
Outdated
    self.ds, desc="Packing dataset", dynamic_ncols=True
):
    buffer["input_ids"].extend(input_ids)
    buffer["labels"].extend(labels)
I wonder if we want to optionally allow the user to add a separator (e.g., `<eod>`) for packing (e.g., via an argument)?
The EOS tokens should serve as inherent separators, no? And your other comment about creating a sentence mask is a preferable approach imo
`sentence_mask` is the ultimate solution, I agree! With that we probably don't need this separator at all.
torchtune/datasets/_packed.py
Outdated
# If buffer has reached max_seq_len, append packed sample
while len(buffer["input_ids"]) > self.max_seq_len:
    self.samples.append(
        {k: v[: self.max_seq_len] for k, v in buffer.items()}
Maybe it would be good to support packing examples together only when all of them fit into ONE window.
If adding an example would exceed the context window, maybe we should NOT add that example to the current pack, pad the rest of the current pack with `<PAD>`, and move that example to the next pack? This would be very helpful for people doing SFT (fewer examples, but they get the benefit of packing without needing to truncate any sentences).
I did something like I described above here - feel free to re-use some part of the logic if interested: https://github.com/xingyaoww/Megatron-LLM/blob/main/tools/preprocess_instruct_data.py#L148-L194
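The no-split strategy described above can be sketched in a few lines of pure Python. The names (`pack_samples`, `PAD_ID`) are illustrative only, not torchtune's API; a real implementation would also carry labels, mask, and position ids alongside the tokens.

```python
PAD_ID = 0

def pack_samples(samples, max_seq_len, pad_id=PAD_ID):
    """Greedily pack token lists; a sample that would overflow the current
    pack starts the next pack instead, and the remainder of the current
    pack is filled with pad tokens (no sample is ever split)."""
    packs, current = [], []
    for tokens in samples:
        if len(tokens) > max_seq_len:
            raise ValueError("sample longer than max_seq_len; cannot pack without splitting")
        if len(current) + len(tokens) > max_seq_len:
            # Flush the current pack, padded out to max_seq_len
            packs.append(current + [pad_id] * (max_seq_len - len(current)))
            current = []
        current.extend(tokens)
    if current:
        packs.append(current + [pad_id] * (max_seq_len - len(current)))
    return packs

# Example: the 4-token sample does not fit after the first two samples,
# so the first pack is padded and the sample moves to the next pack.
packs = pack_samples([[1, 2, 3], [4, 5], [6, 7, 8, 9]], max_seq_len=6)
# → [[1, 2, 3, 4, 5, 0], [6, 7, 8, 9, 0, 0]]
```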
Excellent point. I agree, I think we should do this by default for finetune. Although, for pretraining datasets / unstructured text data maybe we don't need to perform padding and it's ok to split samples?
Agree! For pre-training, we can just go with the current implementation!
I'll add a flag then to control this!
torchtune/datasets/_packed.py
Outdated
    buffer["labels"].extend(labels)

# If buffer has reached max_seq_len, append packed sample
while len(buffer["input_ids"]) > self.max_seq_len:
A relatively unimportant point, but I believe this `>` could be `>=`, which may end up creating an extra sample.
Hmm let me double check this
@RdoubleA this looks great. Your implementation of the packing logic is much cleaner than mine :-) One thing you might want to add is a test that validates how individual tokens are packed. This was easy for me to do because my code had complete control over initialization; it might be more complex now. This kind of test might be overkill at the moment, but further down the line, when someone else is working on this class (or other dataset classes that integrate with it), it might be a helpful check. Here is how I did it:
@calmitchell617 great suggestion, I was actually planning to borrow some of your testing logic, so I'll try to include this. Also, I'm glad someone else caught my hidden SpongeBob reference 👌 Overall, does this unblock what you wanted to achieve with your earlier PR? Or do you need streaming first?

I do not need streaming; this unblocks my use case.
This is awesome! Left a handful of comments but no huge concerns from my side
if split_samples:
    # If we split samples, we'll know how many samples by taking the
    # full length and dividing by sample size
    last_index, remainder = divmod(max_rows * max_seq_len, sample_size)
Did not even know divmod was a thing. Nice
Actually could also use math.ceil(max_rows * max_seq_len / sample_size) here I think? Maybe that's clearer tbh
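For concreteness, a quick sketch of the two formulations being compared (the numbers are illustrative, not from the PR). Note they are not quite interchangeable: `math.ceil` counts the trailing partial sample, so it is one higher than the `divmod` quotient whenever there is a nonzero remainder.

```python
import math

# Illustrative values: 10 packed rows of 512 tokens, samples of 300 tokens
max_rows, max_seq_len, sample_size = 10, 512, 300

# divmod gives the number of full samples plus the leftover token count
last_index, remainder = divmod(max_rows * max_seq_len, sample_size)
# 5120 tokens // 300 = 17 full samples, 20 tokens left over
assert (last_index, remainder) == (17, 20)

# ceil also counts the partial sample, so it is one higher here
assert math.ceil(max_rows * max_seq_len / sample_size) == 18
```

Whether `ceil` is a drop-in replacement therefore depends on how the remainder is used downstream.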
torchtune/datasets/_packed.py
Outdated
raise ValueError(
    f"Dataset sample is too long ({len(input_ids)} > {self.max_seq_len}). "
    "Please set `split_samples=True` or increase `max_seq_len`."
)
We should think about whether this is the right thing to do here. Without sample packing we wouldn't error out here, right? We would just truncate. I wonder if it makes sense to do the same here.
hm yeah that makes sense to me
torchtune/datasets/_packed.py
Outdated
current_pack["input_ids"].extend(input_ids)
current_pack["labels"].extend(labels)

if len(current_pack["input_ids"]) > self.max_seq_len:
So where did we land on the > vs >= thing here? I still don't understand why this isn't >= personally
Maybe this is just me, but I find the logic a bit roundabout. Why not just check whether the current length of the pack plus the length of the incoming sample exceeds max length? If it does, you write the current pack out; if not, you just add the sample. Maybe I'm missing some complexity?
The nuance is when you split the sample: if you check current length + length of incoming, then you either write the current pack out, or split the incoming sample up to max seq len and write the pack out. IMO that logic and the logic here are almost identical.
As discussed offline, let's figure out masking correctly before landing this PR.
torchtune/datasets/_packed.py
Outdated
    inputs and labels.
max_seq_len (int): Maximum number of tokens to pack
max_rows (Optional[int]): maximum number of samples to pack. Default is None, which will pack as many samples as possible.
split_samples (bool): if the last sample in a pack does not fit in ``max_seq_len``,
I feel like `split_across_instances` or `split_across_boundary` would be more intuitive. If I just found this flag in a config, I'd interpret it as splitting all samples across instances, i.e. no packing. Ignore this comment if this is an accepted convention in the community. If not, I'd suggest thinking a bit about this flag.
I'll change it to split_across_pack
torchtune/datasets/_packed.py
Outdated
)

previous_sample_boundary = len(current_pack["input_ids"])
if self.max_rows is not None and len(self.samples) >= self.max_rows:
So max_rows is only respected if we have < max_seq_len tokens right? Then why have that option at all? User can just reduce the seq len to get fewer samples? Or why would they want fewer samples?
Not exactly - max_rows lets users limit how many total packs the dataset returns. I imagine this will be a lot more relevant for streamed / iterable datasets. max_seq_len controls the size of each individual pack; reducing max_seq_len while keeping dataset size the same will actually result in more samples.
Yes, I agree with @RdoubleA. Max rows is how many packed rows you end up with. Most users will choose `None` for datasets that fit in RAM, but will choose some lesser value for OOM datasets. A very nice helper arg might be one that calculates `max_rows` for you. Not needed for now.
Oh hmm, then I'm not sure this is clear in the docstring:

max_rows (Optional[int]): maximum number of samples to pack. Default is None, which will pack as many samples as possible.

This leads me to believe this param controls the number of samples read in, not the number of packs returned? Or am I misunderstanding?
I wrote that comment. The intent is to signal that this arg controls the total number of packed samples you will end up with at the end of the process. However, if you don't understand it, then others won't either, so the phrasing should probably be revised for clarity.
These are great suggestions for further optimizations that I will leave as follow-ups :)
This is a good point, but where would you suggest creating this mask? This would require keeping track of sequence lengths if we create it outside

There is a tradeoff: pre-computing them is faster, but if we serialize, it would be very memory expensive. We can compute them with the position ids, maybe storing them as ids like

Also, maybe renaming
I debated this exact thing when trying to figure out where to create the mask and position ids; I was mainly focused on simplicity and speed. If we have to recreate the mask in TransformerDecoderLayer or the collator, we need to loop through the entire batch and the seqlens for each pack, which is quite inefficient, though it might be worth benchmarking the time difference. But you're right that this would be much cheaper memory-wise.
This was the existing name, so I'd rather not add a refactor on top of all the changes here :)
My opinion is that this is incorrect and should not be supported. Samples in an SFT dataset are typically highly correlated and should not cross-attend. But others are welcome to add their thoughts on this.
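For reference, a minimal pure-Python sketch of the block-causal mask under discussion, built from the per-sample lengths inside one pack: position i may attend to position j only if j comes no later than i and both positions belong to the same sample. The function name is illustrative; the actual implementation builds this with `torch.block_diag` over lower-triangular blocks.

```python
def block_causal_mask(seqlens):
    """Build a block-causal (per-sample causal) boolean mask for one pack,
    given the list of sample lengths in that pack."""
    total = sum(seqlens)
    # sample_id[i] = index of the sample that token position i belongs to
    sample_id = []
    for s, n in enumerate(seqlens):
        sample_id.extend([s] * n)
    return [
        [j <= i and sample_id[i] == sample_id[j] for j in range(total)]
        for i in range(total)
    ]

# Pack of two samples with lengths 2 and 3: tokens of the second sample
# can attend causally within themselves but never back into the first.
mask = block_causal_mask([2, 3])
```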
This is more regarding the naming of the padding functions. When I read this, I got the idea that the

Hm, I see what you mean; technically this is not a collator,

Another option is building it in the model forward rather than in each decoder layer. That way it only needs to be generated once.
torchtune/utils/collate.py
Outdated
mask = torch.block_diag(mask, mask_pad)
# For position ids, continue to increment for pad tokens
input_pos_pad = torch.arange(
    input_pos[-1] + 1, max_seq_len - len(input_pos) + input_pos[-1] + 1
nit: a bit confusing, maybe just define `next_pos = input_pos[-1] + 1` on a separate line to make the arange clearer
torchtune/utils/collate.py
Outdated
    input_pos[-1] + 1, max_seq_len - len(input_pos) + input_pos[-1] + 1
)
# Do not go beyond max_seq_len - 1
input_pos_pad = input_pos_pad.clamp(max=max_seq_len - 1)
I don't fully follow this. Does this only happen when `split_across_pack=True` or something? Also, in that case, where are we truncating the tokens and labels? Is it in the base dataset? If so, it's a bit confusing to have it spread out across three places (though admittedly I don't really see a better way, since the base dataset obviously shouldn't care about input_pos).
We are not truncating - the creation of the packs is what "truncates" to a set max_seq_len. If we split across packs, there is an edge case for the last sample when it bleeds over into the next pack, where the pos ids continue to the end of that sample. Say the sample had length 5 and max_seq_len is 8: [..., 0, 1, 2] ended up in the previous pack, and the sample got split into the last pack as [3, 4]. We still have 6 slots left - how do we pad position ids here without going beyond max_seq_len? So I just enforced this clamp; these position ids should be ignored by the loss anyway. But the better solution is a padding mask.
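To make the edge case concrete, here is a pure-Python stand-in for the `torch.arange` + `clamp` logic in the snippet above (the helper name is illustrative):

```python
def pad_input_pos(input_pos, max_seq_len):
    """Pad a pack's position ids out to max_seq_len by continuing to
    increment past the last real position, then clamp so no position id
    ever exceeds max_seq_len - 1."""
    next_pos = input_pos[-1] + 1
    num_pad = max_seq_len - len(input_pos)
    padded = input_pos + list(range(next_pos, next_pos + num_pad))
    return [min(p, max_seq_len - 1) for p in padded]

# The scenario described above: a length-5 sample split across packs,
# positions [3, 4] landed in this pack with max_seq_len = 8. Padding
# continues 5, 6, 7 and then clamps at 7.
padded = pad_input_pos([3, 4], max_seq_len=8)
# → [3, 4, 5, 6, 7, 7, 7, 7]
```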
torchtune/modules/transformer.py
Outdated
# shape: [b, 1, s, s]
if mask is not None:
    mask = mask[:, None, :, :]
Thoughts on doing this in attention.py? Cause (a) that's where we actually need it to be this shape, and (b) then we do not have separate contracts on mask shape in `TransformerDecoder` and `TransformerDecoderLayer`.
yes, I think that makes sense
Codecov Report
Attention: Patch coverage is

Additional details and impacted files:

@@ Coverage Diff @@
##             main     #875       +/-   ##
===========================================
- Coverage   66.39%   26.85%   -39.54%
===========================================
  Files         155      176       +21
  Lines        6484     7532     +1048
===========================================
- Hits         4305     2023     -2282
- Misses       2179     5509     +3330

☔ View full report in Codecov by Sentry.
from torchtune.modules.tokenizers import Tokenizer


def alpaca_dataset(
    tokenizer: Tokenizer,
    *,
nit: be consistent here. In some builders you put the kw-only delineator as the first arg, and in others you put it after `tokenizer`. Just pick one; either works.
Hm for instruct and chat dataset builders there are many more required positional arguments which is why I put the asterisk in the beginning. I can just put it in the beginning for all of them
One nit regarding consistency where we place the kw-only delineator in the dataset builders, but otherwise this looks amazing!
@RdoubleA is the hero we don't deserve 🫡
Amazing work, this is probably the best PR I've seen in torchtune!
# If the current pack is long enough, add it to self.packs and retain
# any truncated samples for next pack, if splitting samples
if len(current_pack["tokens"]) > self.max_seq_len:
Curious, why not `>=`?
ds (Dataset): dataset to sample pack. This should return a dictionary with fields
    "tokens" and "labels" containing the tokenized and label samples.
max_seq_len (int): Maximum number of tokens to pack
max_packs (Optional[int]): maximum number of packs. Default is None, which will create as many
Do we have the appropriate error checking here and do we really need to expose this? What if I specify max_packs as something silly like 1 or 2, but the dataset packed into 2 samples isn't possible because it exceeds max_seq_len?
It won't pack the entire dataset into 2 packs; it will just create 2 packs with whatever samples fit and then drop the rest. `max_packs` may become more relevant for iterable datasets, but maybe it's not as useful right now. Will revisit this. This was originally added by @calmitchell617, so it may have been needed for his use case?
# Add the last pack with remaining samples that did not fit in previous
if len(current_pack["tokens"]) > 0 and (
    self.max_packs is None or len(self.packs) < self.max_packs
Don't fully understand why we need `max_packs`; if we don't need it, this code might become simpler.
General concern about data parallelism:
- When running sample packing, will each rank read different data, to ensure the generated packs are still correctly sharded across ranks and ranks don't see duplicated data (and no data is dropped)?
- Assuming each rank is indeed packing different data as expected, there seems to be potential for `self.packs` to become different sizes across ranks in some situations. For example, if rank 0 encounters two max_seq_len samples, it might pack them into 2 packs, while rank 1 encounters a bunch of smaller samples and packs them into 1. In general I don't see any guarantee that the number of packs is the same across ranks.
if rank == 0:
    pbar = tqdm(total=len(self.ds), desc="Packing dataset", dynamic_ncols=True)

for batch in self.ds:
Is this iteration rank / distributed aware? i.e. will each data parallel rank read different data to pack?
As discussed offline, you're right that it may be worth using `DistributedSampler` here so that packing is partitioned across ranks (cc @tcapelle, who also suggested this).
This is a great point and something we'll need to consider anyway when we move to iterable datasets. I need to think about this more; @gokulavasan had some suggestions on this, so let me connect with him offline.
Thanks for pushing this through!
Thank you very much for your contribution. I have a question: so the current implementation still uses PyTorch's mem-eff SDPA instead of flash_attn_varlen_func, right?

Yes, that is correct. Unfortunately, PyTorch SDPA does not support arbitrary masks for flash attention, and we need a block-causal mask for sample packing. We need to do a follow-up investigation comparing the performance gains of flash_attn_varlen_func to see if it's worthwhile to take a dependency on it (especially for a core op like attention). I will note that in all other cases, when sample packing is not used, SDPA should default to flash attention.
…ss-contamination (pytorch#875)
The Problem
Packing multiple samples within a single context window means the model may accidentally attend to other samples it should not attend to. If there are sequences that are completely orthogonal in topic/content, cross-contamination adds a lot of noise for that sample. Additionally, if position ids are not correctly adjusted for each individual sequence within a packed sample, then tokens later in the pack get an unwarranted later-position bias. So, we need to add two things: a block-causal attention mask that prevents samples in a pack from attending to each other, and position ids that restart for each sample within a pack.
This is a highly requested feature by the community, as seen in many discussions in TRL, Hugging Face, and llama recipes.
However, PyTorch Core's SDPA does not support flash attention with a non-causal mask (there are several open discussions on this). This means that to do sample packing properly, we need to either turn to other implementations or use PyTorch's memory-efficient SDPA.
The Approach
Mem-eff SDPA is not as fast as flash v2 (2x slower), so we could turn to Dao's original implementation of flash attention, which supports varied sample lengths. Since torchtune does not use fused QKV to support MQA and GQA, we should use `flash_attn_varlen_func` as an alternative to SDPA. This allows our between-sample masking via cumulative sequence lengths in the batch. However, adding a third-party dependency, especially in such a core module, should not be taken lightly and invites risk of breakages, lack of support, lack of control, etc. We need to evaluate whether the improved performance significantly outweighs this tradeoff.
Changelog
In this update, to unblock sample packing I decided to stick with mem-eff SDPA + a proper packing mask. More benchmarking of Dao's flash attention compared to mem-eff attention needs to be done to understand the tradeoffs, which I will do as a follow-up. Here is a summary of the updates:
- Leveraged the existing `mask` and `input_pos` kwargs in the transformer layers. This was enough to cover both without requiring any additional conversion or mask calculations, while still maintaining minimal changes to the model forward signature.
- Updated `PackedDataset` to create the lower-triangular block mask and position ids for each subsample within a pack, since it's much more efficient to create them while packing, alongside the tokens and labels. Having to infer the mask and position ids batchwise is expensive and requires iterating through the whole batch again.
- `PackedDataset` now returns tokens, labels, mask, and input_pos. To generalize this for all datasets, all dataset classes now return a dictionary. All of them will still just have "tokens" and "labels" keys, but `PackedDataset` will additionally return "mask" and "input_pos". This also gives flexibility for new datasets to return mask and input_pos if needed.
- `TransformerDecoder` forward can now take an optional mask. This cannot be used during inference when the causal mask is used. The mask is propagated down to the attention module, so the packing mask can now be used in SDPA.
- Rotary positional embeddings can handle `input_pos` as a 2D tensor that has position ids for each sample in each pack in a batch, with a very tiny change.
- Added `_padded_collate_packed`, so we do not use a collator in the dataloader if the dataset is packed.

Test plan
tune run full_finetune_single_device --config llama3/8B_full_single_device epochs=1 dataset.packed=True dataset.max_seq_len=4096
Sample packing improves performance proportional to max sequence length - at 4096, tokens/sec were at about 1800, which is nearly 6x the unpacked version at a higher max sequence length of 8192. This effect may be exaggerated the more long-tailed the sequence length distribution of the dataset is. Longer sequence length (4096) also brought down run time to about 1 hour per epoch, compared to 1 hr 45 min at 512 when packed and nearly 5-6 hours at 512 when not packed. Packing for 52k samples for Alpaca only takes 40 seconds.
Loss curves across all packed and unpacked runs align.
Evals are nearly identical for packed runs vs non-packed runs.
Huge shoutout to @calmitchell617 for initiating this in #827, I've incorporated some of your logic here for a basic sample packing feature.
What is sample packing?
Sample packing maximizes the sequence length / context window by jamming as many samples as can fit so we don't waste compute on padding tokens. This leads to much faster training because the model can process more data with fewer forward passes. However, there's usually a slower start up if you perform packing on your dataset prior to training.
There are ways to do this offline, with a bin packing algorithm, or on-the-fly, but for now I do the naive approach of greedily packing samples as a preprocessing step when initializing the dataset. This will need to evolve further once we support IterableDatasets and streaming.
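The naive greedy approach described above can be sketched in pure Python as follows. Names are illustrative; torchtune's actual `PackedDataset` also tracks labels, the block mask, and position ids for each pack, and this sketch shows only the split-across-pack behavior (overflow is carried into the next pack rather than padded out).

```python
def greedy_pack(samples, max_seq_len):
    """Greedily concatenate token lists into fixed-size packs, splitting
    samples across pack boundaries when they overflow."""
    packs, buffer = [], []
    for tokens in samples:
        buffer.extend(tokens)
        # Flush full packs; any overflow stays in the buffer for the next pack
        while len(buffer) >= max_seq_len:
            packs.append(buffer[:max_seq_len])
            buffer = buffer[max_seq_len:]
    if buffer:
        packs.append(buffer)  # last, possibly short, pack
    return packs

# A 4-token sample straddles the boundary: its first token completes the
# first pack and the rest spills into the second.
packs = greedy_pack([[1, 2, 3], [4, 5, 6, 7]], max_seq_len=4)
# → [[1, 2, 3, 4], [5, 6, 7]]
```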
How is shuffling handled?
The DistributedSampler will still handle shuffling in the recipe layer, so for a packed dataset only the packs are shuffled and not within a pack. We may need to add within-pack shuffling later, although it will be simpler with IterableDatasets.
How can I configure packing?
Simply set `packed=True` in the config. This gets routed to the builder, so all dataset builders are updated with this flag. This is preferable to a recipe parameter, because that would require updating all configs.
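For illustration, a hypothetical config fragment mirroring the CLI overrides shown in the test plan (`dataset.packed=True dataset.max_seq_len=4096`); the exact schema may differ:

```yaml
# Hypothetical fragment of a torchtune config; field names assumed from
# the CLI overrides above, not copied from a shipped config.
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: True
  max_seq_len: 4096
```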