Add truncate option in Tokenizer #213

Closed · wants to merge 6 commits

Conversation

@gokulavasan (Contributor) commented Jan 18, 2024

Context

  • Some training samples, when encoded, can produce a token list longer than what the model can accept (the model's max_sequence_length). Training then fails because the model cannot accept that input.
  • Add an option to truncate the token id list returned by the tokenizer.
  • Although the alpaca dataset's input ids all fit within 4096 tokens, slimorca (Adding SlimOrca Dataset to the datasets collection #116) has samples longer than 4096.

Reference implementation - HF BertTokenizer - https://colab.research.google.com/drive/1BBq5BPf1zjlPs0A5ky0mP-gNu_zFi0r5#scrollTo=iqmKgNj647FN. If a different tokenizer is recommended, please suggest it. Note that lit-gpt truncates on the right, including the EOS, while HF BertTokenizer does not drop the EOS.
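
For readers unfamiliar with the HF behavior referenced above, a minimal sketch (the checkpoint name, input text, and max_length are illustrative assumptions, not values from the linked notebook):

# Sketch only: HF BertTokenizer truncation keeps the trailing [SEP] (its EOS-equivalent),
# whereas lit-gpt truncates on the right and can drop the EOS.
from transformers import BertTokenizer

hf_tok = BertTokenizer.from_pretrained("bert-base-uncased")
ids = hf_tok.encode("a deliberately long input " * 50, truncation=True, max_length=16)
assert len(ids) == 16
assert ids[-1] == hf_tok.sep_token_id  # [SEP] survives truncation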

Changelog

  • Add max_len to the tokenizer constructor; it is used during the encode operation when truncation is set to True. The tokenizer's max_len is set at initialization to the model's max_seq_length; in this case, llama_tokenizer is initialized with the llama2 model's max_seq_len.
  • Add a truncate option to encode which, when set, uses the max_len param set on the tokenizer (a hedged sketch of this behavior follows below).
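
A minimal sketch of the behavior described in the changelog, not the actual torchtune implementation; it assumes a SentencePiece-backed tokenizer and borrows the max_len/truncate names from this PR:

# Hedged sketch of the proposed max_len / truncate behavior; not torchtune's real code.
from typing import List, Optional
from sentencepiece import SentencePieceProcessor

class Tokenizer:
    def __init__(self, spm_path: str, max_len: Optional[int] = None):
        self.spm_model = SentencePieceProcessor(model_file=spm_path)
        self.eos_id = self.spm_model.eos_id()
        # max_len is typically the model's max_seq_length (e.g. 4096 for llama2).
        self.max_len = max_len

    def encode(
        self,
        text: str,
        add_bos: bool = True,
        add_eos: bool = True,
        truncate: bool = False,
    ) -> List[int]:
        tokens = self.spm_model.encode(
            text, add_bos=add_bos, add_eos=add_eos, out_type=int
        )
        if truncate and self.max_len is not None:
            tokens = tokens[: self.max_len]
            # Keep EOS as the final token if it was requested (HF BertTokenizer-style);
            # dropping it instead would match lit-gpt's behavior.
            if add_eos and tokens and tokens[-1] != self.eos_id:
                tokens[-1] = self.eos_id
        return tokens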

Test plan

  • Added unit tests that verify this behavior (a hedged test sketch follows below).
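
The actual tests are in the diff; below is only a sketch of what such a test could look like, assuming the Tokenizer sketch above and a placeholder model path:

# Hypothetical test sketch; the model path and expected lengths are placeholders.
def test_encode_respects_max_len_when_truncate_is_set():
    tokenizer = Tokenizer("/tmp/spm.model", max_len=5)
    tokens = tokenizer.encode("a fairly long example sentence", truncate=True)
    assert len(tokens) <= 5
    assert tokens[-1] == tokenizer.eos_id  # EOS survives truncation

def test_encode_is_unchanged_without_truncate():
    tokenizer = Tokenizer("/tmp/spm.model", max_len=5)
    tokens = tokenizer.encode("a fairly long example sentence", truncate=False)
    assert len(tokens) > 5  # assumes the sentence tokenizes to more than 5 ids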

@facebook-github-bot added the CLA Signed label on Jan 18, 2024 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed).
netlify bot commented Jan 18, 2024

Deploy Preview for torchtune-preview ready!

🔨 Latest commit: c8cb6ec
🔍 Latest deploy log: https://app.netlify.com/sites/torchtune-preview/deploys/65aa49ab4c57c20007357fb1
😎 Deploy Preview: https://deploy-preview-213--torchtune-preview.netlify.app

Review thread on torchtune/modules/tokenizer.py (outdated), excerpt:
text,
add_bos=add_bos,
add_eos=add_eos,
out_type=int,
)
if truncate and self.max_len is not None:
Contributor commented:

If the user sets truncate to True but forgets to set max length, perhaps we should raise a warning that the output will not be truncated.

Member commented:

yeah, or an error.

Contributor Author (@gokulavasan) commented:

@RdoubleA @rohan-varma Is there a way to log a warning/error only once (or, say, once every 10 seconds)? If truncate is called and max len isn't set, it might flood the logs.
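
(Not torchtune-specific, but two generic ways to avoid flooding the logs: Python's warnings module deduplicates repeated warnings from the same call site under its default filter, or a small time-based rate limiter can be written by hand. A sketch:)

# Generic sketch, not torchtune's logging setup.
import time
import warnings

# Option 1: warnings.warn is shown only once per call site under the default filter.
warnings.warn("max_len is not set; truncation will be a no-op")

# Option 2: a simple time-based rate limiter around a logger call.
_last_emit = 0.0

def warn_rate_limited(logger, msg: str, interval_s: float = 10.0) -> None:
    global _last_emit
    now = time.monotonic()
    if now - _last_emit >= interval_s:
        logger.warning(msg)
        _last_emit = now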

Contributor commented:

Shouldn't the warning be logged from the constructor? If so, why would it flood the logs?

Contributor commented:

In that case let's just go with an error if truncate is True but max length is None.
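
(A sketch of the check being suggested, using the parameter names from this PR; the placement, constructor vs. encode, was still under discussion:)

# Hedged sketch only; not the final implementation.
from typing import Optional

def check_truncation_args(truncate: bool, max_len: Optional[int]) -> None:
    if truncate and max_len is None:
        raise ValueError(
            "truncate=True requires the tokenizer to be constructed with max_len set"
        )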

Contributor Author (@gokulavasan) commented:

@kartikayk The max token len option is generally derived from the model's max seq length, so it is part of the tokenizer constructor. Whether the dataset wants to perform truncation or not is controlled at the encode method call. What I can probably do is add a warning at tokenizer initialization to call out that max token len has not been set and thus no truncation will be performed.

Contributor commented:

@gokulavasan thanks for calling this out, I was wondering about this while reviewing this PR but decided against asking the question.

"Whether the dataset wants to perform truncation or not is controlled at the encode method call"

The fact that the dataset decides whether truncation needs to be performed makes a LOT of sense to me. If that is true, then shouldn't the truncation happen where encode is called instead of within the encode function itself? What's the value in having encode (and the tokenizer) be aware of this param? It's not like it saves you anything, since you still have to tokenize the entire input AND then truncate.

Contributor Author (@gokulavasan) commented Jan 19, 2024:

@kartikayk Is your suggestion that the truncation happen in the dataset code, right after tokenizer.encode is called? I would imagine both the alpaca and slimorca datasets using the truncation feature (based on the max sequence length of the model), so I thought having it in the tokenizer would provide that capability in one place.

In the HF tokenizer, it is available in the encode method - https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L2553-L2554.

In the lit-gpt tokenizer, it is also available in the encode method - https://github.com/Lightning-AI/lit-gpt/blob/main/lit_gpt/tokenizer.py#L88.

If you think we don't have to move this to the tokenizer just yet, and can instead do the truncation in the slimorca dataset, I can do that and revisit this PR later.

Contributor commented:

@gokulavasan yeah, exactly. Thinking through this conceptually, I'm trying to figure out what the value of adding this to the tokenizer's encode is, since it doesn't really impact the tokenization functionality. max_seq_len depends on the (dataset, model) tuple. We could just configure it as a param in the dataset and, when getting a sample, tokenize, truncate, and return. The warning can be added to the dataset init, and tests can be coupled with the dataset itself to make sure samples are truncated appropriately. WDYT?
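
(A hedged sketch of the dataset-side alternative described above; the class and parameter names are illustrative, not torchtune's actual dataset API:)

# Illustrative sketch: tokenize and truncate inside the dataset rather than the tokenizer.
import logging
from typing import List, Optional
from torch.utils.data import Dataset

logger = logging.getLogger(__name__)

class TruncatingTextDataset(Dataset):
    def __init__(self, samples: List[str], tokenizer, max_seq_len: Optional[int] = None):
        self._samples = samples
        self._tokenizer = tokenizer
        self._max_seq_len = max_seq_len
        if max_seq_len is None:
            # Warn once at init instead of on every sample.
            logger.warning("max_seq_len is not set; samples will not be truncated")

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, idx: int) -> List[int]:
        tokens = self._tokenizer.encode(self._samples[idx], add_bos=True, add_eos=True)
        if self._max_seq_len is not None:
            tokens = tokens[: self._max_seq_len]  # right-truncation
        return tokens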

torchtune/modules/tokenizer.py (review thread, outdated, resolved)
Member @rohan-varma left a comment:

Same question as @RdoubleA regarding the eos token. Please also take some time to build and render the documentation and add docs as needed, which might as well be done while this file is being touched.

tests/torchtune/modules/test_tokenizer.py (review thread, outdated, resolved)
tests/torchtune/modules/test_tokenizer.py (review thread, outdated, resolved)
torchtune/modules/tokenizer.py (review thread, outdated, resolved)
torchtune/modules/tokenizer.py (review thread, outdated, resolved)

@gokulavasan (Contributor Author) commented:

Updated the PR @rohan-varma @RdoubleA. Do you have any suggestions on limiting logger output (e.g. to once every N seconds)? Or I can just log at the warning level.

@kartikayk (Contributor) commented:

@gokulavasan thanks for adding this. Have we validated this with any reference code? If so, do you mind adding that information in the context of the PR?

Contributor @RdoubleA left a comment:

Accepting to unblock, just one comment and the truncate error thing.

torchtune/modules/tokenizer.py (review thread, resolved)
@gokulavasan (Contributor Author) commented:

@kartikayk Added reference implementation in the description (HF BertTokenizer). Note that lit-gpt has a different behavior where the EOS is truncated as well.

@kartikayk (Contributor) commented:

Thanks @gokulavasan, I figured as much. So do we need to let the calling function decide whether it wants to truncate the EOS or not, and make this a flag in that function that we can turn on and off?
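
(One way to expose that choice, sketched as a hypothetical helper; keep_eos is an illustrative name, not an agreed API:)

# Hypothetical helper sketch: the caller decides whether the EOS survives truncation.
from typing import List

def truncate(tokens: List[int], max_len: int, eos_id: int, keep_eos: bool = True) -> List[int]:
    # Right-truncate to max_len tokens.
    out = tokens[:max_len]
    if keep_eos and tokens and tokens[-1] == eos_id and out and out[-1] != eos_id:
        # HF BertTokenizer-style: force EOS back as the final token.
        out[-1] = eos_id
    # keep_eos=False leaves the plain right-truncation, matching lit-gpt's behavior.
    return out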

Contributor @kartikayk left a comment:

Changing this to "Request Changes" since we have a couple of open discussions. Happy to revert back to "Approved" if those don't make sense to address here.

@gokulavasan (Contributor Author) commented:

Closing this PR as I moved the truncation logic from this PR to the SlimOrca Dataset PR #116.

@joecummings deleted the tokenizer-truncate-option branch April 12, 2024 23:13