Refactor datasets and tokenizer #624

Merged: 11 commits merged into main from datasets-refactor on Apr 2, 2024

Conversation

@ebsmothers (Contributor) commented Mar 31, 2024

Note: many of the template changes here were done primarily by @RdoubleA.

Context

  • Currently, the way we handle tokenization and masking of prompts and responses across instruct and chat datasets is fragmented and hard to follow. Additionally, as pointed out in "Potential future tokenization issue in Alpaca" (#366), there is potential for tokenization errors when tokenizing across message boundaries.
  • The safest thing to do is to call the tokenizer's encode method at the level of individual messages, then stitch them together in our own tokenizer class with whatever custom logic we need. This PR allows us to do that.

Changelog

  • Add a tokenize_messages API to our tokenizer. This takes in a list of messages, tokenizes each one individually, then stitches the outputs together with any requisite special tokens, truncation, etc. For the SentencePiece tokenizer (which is all we currently support) this is just BOS and EOS, but for tokenizers with more complicated sets of special tokens this allows us to customize at the tokenizer level without the dataset class having to worry about it (a rough usage sketch follows at the end of this list).
  • Align instruct and chat datasets on a unified API where both call tokenize_messages. This is more general than our current usage of tokenize_prompt_and_response.
  • Split templates into chat formats (for chat datasets) and instruct templates (for instruct datasets). Chat formats are List[Message] -> List[Message] to better align with the natural format of chat conversations. Instruct templates still operate as string formatters to maintain the ability for simple prompt formatting on instruct datasets.
  • Other smaller stuff: simplify the truncate API (it shouldn't require tokens, labels, and the tokenizer), and add handling of leading whitespace when splitting across messages (SentencePiece defaults to prepending a whitespace to any string it tokenizes, which we don't always want).
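
To make the new flow concrete, here is a rough usage sketch. The import of Llama2ChatFormat, the from_file constructor, the masked flag on Message, and the max_seq_len argument are assumptions for illustration rather than the exact final API.

from torchtune.data import Llama2ChatFormat  # import path assumed
from torchtune.data._types import Message
from torchtune.modules.tokenizer import Tokenizer

# Construction shown schematically; the from_file constructor and the path are assumptions
tokenizer = Tokenizer.from_file("/tmp/tokenizer.model")

# Messages as a convert_to_messages transform might produce them; the masked flag
# marks messages whose tokens should be excluded from the loss
messages = [
    Message(role="user", content="Correct this to standard English: he go home", masked=True),
    Message(role="assistant", content="He goes home.", masked=False),
]

# Chat formats are List[Message] -> List[Message]; format is a classmethod, so the
# class itself (no parentheses) is what gets passed around
formatted = Llama2ChatFormat.format(messages)

# Each message is tokenized individually, then the outputs are stitched together with
# any special tokens (just BOS/EOS for SentencePiece) and truncated as needed
tokens, mask = tokenizer.tokenize_messages(formatted, max_seq_len=4096)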

Test plan

Added new tokenizer tests for (a) the tokenize_messages API and (b) encoding without leading whitespace.
Refactored existing template tests into instruct template tests and chat format tests (again, thanks @RdoubleA).

pytest tests
...
===== 199 passed, 17 skipped, 5 warnings in 90.91s (0:01:30) ===
pytest tests -m integration_test
...
====== 15 passed, 201 deselected, 3 warnings in 149.88s (0:02:29) =====

LoRA finetune on the slimorca dataset (manually deleting dataset.use_clean from the config because I can't override it from the CLI properly):

tune run --nproc_per_node 4 lora_finetune_distributed --config llama2/7B_lora checkpointer.checkpoint_dir=/data/users/ebs/checkpoints checkpointer.checkpoint_files=['llama2-7b-torchtune.pt'] checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer tokenizer.path=/data/users/ebs/checkpoints/lora-debug/tokenizer.model dataset=torchtune.datasets.slimorca_dataset metric_logger=torchtune.utils.metric_logging.WandBLogger metric_logger.project=lora-debug
[Screenshot attached: 2024-04-01 at 9:49 PM]

pytorch-bot bot commented Mar 31, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/624

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5603599 with merge base 73647e2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Mar 31, 2024
@RdoubleA (Contributor) left a comment:
This is overall quite intuitive and lets Instruct and ChatDataset still play specialized roles while utilizing common APIs. I think we've struck a nice balance here.

@@ -50,6 +52,45 @@ def eos_id(self):
def bos_id(self):
return 0

def tokenize_messages(
Contributor:
We should think of an alternative solution to this, because we'll have to update this every time the real tokenize messages is updated. Can we do a more stripped down approach for testing purposes or is this the most barebones?

Contributor Author:
Yeah I agree. The main thing I wanted here was to guarantee I could replicate the performance in the existing unit test using the same logic. We can definitely use a simpler method but will have to change the expected values. (Really I should just add a test for tokenize_messages on the tokenizer, then we can use something simple here and still be confident it's working as expected.)

Contributor:
you added the tokenize_messages test - should we simplify here?

Contributor Author:
Yeah, I do like explicitly testing fetching a sample from ChatDataset though. I have another idea to simplify the code here 😃

formatted_dialogue = []
for message in messages:
content = ""
if message.role == "system":
Contributor:
nit: could put system, user, assistant in a dictionary and just key on message.role
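
For reference, a minimal sketch of that suggestion; the tag strings below are placeholders rather than the real Llama2 tags.

# Hypothetical illustration of the nit above: key the per-role template on
# message.role instead of branching with if/elif. Tag strings are placeholders;
# Message and messages are the same objects as in the surrounding code.
ROLE_TEMPLATES = {
    "system": "<<SYS>>\n{content}\n<</SYS>>\n\n",
    "user": "[INST] {content} [/INST] ",
    "assistant": "{content} ",
}

formatted_dialogue = []
for message in messages:
    content = ROLE_TEMPLATES[message.role].format(content=message.content)
    formatted_dialogue.append(Message(role=message.role, content=content))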


from torchtune.datasets import slimorca_dataset
from torchtune.modules.tokenizer import Tokenizer

LLAMA_TEMPLATE = Llama2ChatTemplate()
LLAMA_TEMPLATE = Llama2ChatFormat()
Contributor:
This should not be initialized here, just the class name

def format(
cls,
sample: List[Message],
) -> str:
Contributor:
Need to update return to List[Message]

Contributor:
This file should be deleted?

)
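# Wherever mask is True the token is replaced by CROSS_ENTROPY_IGNORE_IDX so it is
# ignored by the loss; wherever mask is False the token itself becomes the label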
labels = list(np.where(np.logical_not(mask), tokens, CROSS_ENTROPY_IGNORE_IDX))
Contributor:
Very nice

Contributor:
Nice indeed, but please add comment for mere mortals like me who need to take more than a min to understand this :)

Contributor Author:
I made this needlessly complicated so that's my fault

convert_to_dialogue=convert_to_dialogue,
template=_get_template(template),
convert_to_messages=convert_to_messages,
chat_format=chat_format,
Contributor:
This needs a mapping from str to the actual class pointer, a simpler version of get_template

Contributor Author:
I am just gonna pass the ChatFormat directly for now, lmk if that makes sense

Contributor:
hm, then this builder will require a nested component to use from the config. It won't work from the config then. I can update it in a follow up if needed

Contributor:
will leave here and duck - is the one layer of component instantiation coming in the way of getting stuff done?

Contributor Author:
Yeah let's do in a follow-up. This will only affect custom chat datasets, right? Which we don't have any of yet anyways. Re nested instantiation, imo this is not a sufficient reason to add it.. I think we can find another way here

Contributor:
I think the string mapping is very simple and uses our existing tools without requiring nested instantiation or impacting UX. Will do a follow-up and we can discuss there.
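
For reference, the follow-up being described is roughly a string-to-class lookup along these lines; the helper name, error message, and import path are illustrative, while the chat format class names come from this PR.

from torchtune.data import ChatMLFormat, Llama2ChatFormat, MistralChatFormat  # import path assumed

# Hypothetical sketch of the proposed str -> ChatFormat mapping, analogous to the
# old _get_template helper, so configs can keep passing a plain string
_CHAT_FORMAT_DICT = {
    "Llama2ChatFormat": Llama2ChatFormat,
    "MistralChatFormat": MistralChatFormat,
    "ChatMLFormat": ChatMLFormat,
}

def _get_chat_format(chat_format: str):
    if chat_format not in _CHAT_FORMAT_DICT:
        raise ValueError(f"Unknown chat format: {chat_format}")
    return _CHAT_FORMAT_DICT[chat_format]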

convert_to_dialogue=sharegpt_to_llama2_dialogue,
template=Llama2ChatTemplate(),
convert_to_messages=sharegpt_to_llama2_messages,
chat_format=Llama2ChatFormat(),
Contributor:
No () here because format is a class method

mask.extend([message.masked] * len(tokens))

# Break out early if we reach max_seq_len
if max_seq_len and len(tokens) >= max_seq_len:
Contributor:
len(tokenized_messages)?

torchtune/modules/tokenizer.py (resolved thread)
self._tokenizer, prompt_tokens, label_tokens, self.max_seq_len
messages = self._convert_to_messages(sample)
messages = self.chat_format.format(messages)
tokens, mask = self._tokenizer.tokenize_messages(
Contributor:
Where is train_on_input used for ChatDataset?

Contributor Author:
Good point. Gonna pass it to _convert_to_messages, I think it makes most sense there tbh. Basically we want it wherever we are building the messages

Contributor:
Hmm, _convert_to_messages is provided by the user, right? So then we have to enforce that their Callables take in this parameter, and the burden is on them to mask appropriately. Maybe pass it to chat_format or tokenize_messages instead, and let the user know that if they set it, it will overwrite whatever custom masking they set up in _convert_to_messages?

Contributor Author:
Yeah I agree. Similar to my comment above, can we handle this in a follow-up? I think it's a relatively standalone change not too related to all this tokenizer business

@pytest.mark.parametrize(
"config", ["full_single_device_low_memory", "full_single_device"]
)
# @pytest.mark.parametrize(
Contributor Author:
just for some testing, will remove

) -> List[int]:
"""Encode text into token IDs.

Args:
text (str): The input text to be encoded, unbatched.
add_bos (bool): Whether to prepend BOS to the input, defaults to True.
add_eos (bool): Whether to append EOS to the input, defaults to True.

trim_leading_whitespace (bool): Whether to trim leading whitespace from
Contributor:
wanna link the discussion you found here so people know why this is a thing?

Contributor Author:
On second thought, I think it might be a bit misleading to just link that with no context. I am gonna add my own comment that imo is more relevant.

@RdoubleA (Contributor) left a comment:
Llama2ChatFormat and MistralChatFormat have a lot of whitespace between messages - have you considered how that interacts with the trimming? Should we remove that whitespace pre-emptively in the ChatFormat classes?

CHAT_SAMPLE = [
Message(
role="system",
content="You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.", # noqa: B950
Contributor:
nit: maybe can just chunk the string like the expected_dialogues below so we can remove the noqa?

tests/torchtune/datasets/test_chat_dataset.py (resolved thread)
tests/torchtune/modules/test_tokenizer.py (resolved thread)
@@ -9,7 +9,9 @@
from torchtune.data._types import Message


def sharegpt_to_llama2_messages(sample: Mapping[str, Any]) -> List[Message]:
def sharegpt_to_llama2_messages(
sample: Mapping[str, Any], train_on_input: bool = False
Contributor:
don't know if we should make these transforms require train_on_input, then we need to enforce a certain API if a user passes in their own transform

Contributor Author:
Yeah I agree. I'd like to tackle this as a follow-up tbh

self.whitespace_encodings = {
c: self.spm_model.encode(c) for c in WHITESPACE_CHARS
}
self.encodes_whitespace = any(self.whitespace_encodings.values())
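# (sentencepiece returns an empty list for a bare whitespace character it does not
# explicitly encode, so encodes_whitespace is True only if this vocab actually
# encodes whitespace; see the thread below)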
Contributor:
what happens if spm_model fails to encode a character, what is returned?

Contributor Author:
Empty list. That's why I'm doing this -- to check if whitespaces are encoded as part of the model. In fact prob don't even need to save the encodings dict at all

Returns:
List[int]: The encoded token IDs.
"""
if trim_leading_whitespace:
# Can define our own custom prefix depending on vocab if needed
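# The prefix and its encoding are cached on first use; they stay fixed for the
# life of the tokenizer, so later calls skip re-encoding (see the thread below)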
if not hasattr(self, "prefix"):
self.prefix = prefix or "pre"
Contributor:
is "\n" not a reasonable default?

@@ -63,6 +69,7 @@ def encode(
add_bos: bool = True,
add_eos: bool = True,
trim_leading_whitespace: bool = False,
prefix: Optional[str] = None,
Contributor:
do you need to assert that trim_leading_whitespace is True if prefix is set?

Contributor Author:
Idk about assert, but maybe warn? Cause we can set prefix and have it just be a no-op. In general I think it is quite rare to explicitly set prefix though, it should not be needed for the canonical Llama2 tokenizer vocab

Contributor Author:
OK yeah I am gonna leave out the assertion cause tbh prefix is not really something that I expect people to be experimenting with. Lmk if that makes sense

# Can define our own custom prefix depending on vocab if needed
if not hasattr(self, "prefix"):
self.prefix = prefix or "pre"
self.encoded_prefix = self.spm_model.encode(
Contributor:
shouldn't this be outside the if statement if prefix is already set?

Contributor Author:
Similar to above, we do this one time and cache the result.

Returns:
List[int]: The encoded token IDs.
"""
if trim_leading_whitespace:
# Can define our own custom prefix depending on vocab if needed
if not hasattr(self, "prefix"):
Contributor:
where is self.prefix set?

Contributor Author:
This is basically just caching it the first time we use it, hence why it's set in the following line. If it's already been encoded, no need to keep re-encoding it since it will remain fixed for the life of the program.

trim_leading_whitespace = (
(not start_of_turn)
and self.encodes_whitespace
and not prev_ends_with_space
Contributor:
where is prev_ends_with_space defined?

Contributor Author:
L153 as of this version. This is so that we can work with e.g. the grammar format

Contributor:
this should throw an error because I don't see it defined on the very first message, it is only defined after. you should initialize it as False (or whichever is the default behavior) above the for loop

Contributor Author:
Oh god, I think it's short-circuiting out of the and expression early in the test cases, so we never hit this. Good catch.

@ebsmothers changed the title from "[RFC] Refactor datasets and tokenizer" to "Refactor datasets and tokenizer" on Apr 1, 2024
@@ -50,6 +52,45 @@ def eos_id(self):
def bos_id(self):
return 0

def tokenize_messages(
Contributor:
delete

Comment on lines +24 to +25
@classmethod
@abstractmethod
Contributor:
TIL!

Comment on lines +146 to +148
raise ValueError(
"System prompts are not supported in MistralChatFormat"
)
Contributor:
Sorry for the very noob question, but why is this not a pass-through as we do for assistant? Basically, are we saying that this format doesn't support "system" as a role, or that it doesn't support the system tags? Or are they the same thing? And if "system" isn't supported, then should we be setting system above to None instead of an empty string?

Contributor Author:
@RdoubleA may have the best answer here tbh

Contributor:
Mistral does not support the system role, so if a user passes in a message with a system role, we need to either error out or raise a warning that it will be ignored. See for context: vllm-project/vllm#2080 (comment)

Agreed on setting system to None instead of an empty string.


class ChatMLFormat(ChatFormat):
"""
OpenAI's Chat Markup Language used by their chat models:
Contributor:
As far as I know, OpenAI models support only tiktoken. Are we adding support for that in this PR or in a follow up? Or not yet?

Contributor Author:
Let's leave it for a follow-up if you're good with that

Contributor:
I believe they use this format with TikToken: https://community.openai.com/t/how-does-chatml-do-the-exact-formatting/80751

The main motivation to add this template is that it is default in HF if the model has no custom template: https://github.com/huggingface/transformers/blob/096f304695f7e7b169b031f7814352e900ad71c4/src/transformers/tokenization_utils_base.py#L1838
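
For context, ChatML wraps each message in im_start/im_end sentinels, roughly like this (illustrative; exact whitespace handling may differ):

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is the capital of France?<|im_end|>
<|im_start|>assistant
The capital of France is Paris.<|im_end|>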

pass


class AlpacaInstructTemplate(InstructTemplate):
Contributor:
Sorry if I misunderstood earlier, but I thought we're converting this to follow a similar structure to the chat formats i.e. tokenize instruction, input and response separately?

Contributor Author:
No this is kinda our middle ground for now. This keeps the convenience of string formatting as an option for instruct datasets then casts to messages on the dataset side. Tbh this is one of the kinks I still want to iron out here, we could go all the way on the Message format but I'm hedging a bit for now cause I know people like the string-formatting of prompts. But yeah this may change in the future

dialogue.append(Message(role=role, content=content))

return dialogue
masked = (role != "assistant") and train_on_input
Contributor:
Hmm, maybe I'm misunderstanding this, but shouldn't this be

masked = (role != "assistant") and not train_on_input

i.e. we don't mask if train_on_input is TRUE?

Contributor Author:
Thanks, good catch. I got thrown off by the test_slimorca_dataset test, which was always asserting that the final token of labels was equal to eos_id, even when max_seq_len was so short that we were not including the assistant message and train_on_input was False. So I think the test was doing the wrong thing in that case (if train_on_input=False and we only have inputs, then everything should be masked, even the EOS token). Lmk if this makes sense to you
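
Spelling out the behavior of the corrected line suggested above:

# masked = (role != "assistant") and not train_on_input
# train_on_input=True  -> masked is always False, so prompt tokens also contribute to the loss
# train_on_input=False -> masked is True for system/user messages and False for assistant
#                         messages, so the loss is computed only on assistant tokens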

Comment on lines +47 to +49
self.encodes_whitespace = any(
[self.spm_model.encode(c) for c in WHITESPACE_CHARS]
)
Contributor:
no idea what's going on here...

Contributor Author:
I should probably add a comment

Contributor Author:
But yeah, discussed the purpose of this variable in the second point here

underlying sentencepiece tokenization. Sentencepiece normally prepends
whitespace to any tokenized text, which can cause differences where
encode(s1) + encode(s2) != encode(s1 + s2) due to leading whitespace
added to s2. Default: False
Contributor:
So when would I set this to True?

Contributor Author:
I tried to describe in the comment on L138-L141. First, we only run into mismatches due to leading whitespace on the nth split of a string, where n>1. This is because sentencepiece automatically prepends whitespace to a string prior to encoding it, so when we do encode(s1+s2) and encode(s1) + encode(s2), s1 has been prepended with whitespace in both cases, and so the behavior is the same (this is point (a)).

Second, we also only see a mismatch if the tokenizer actually explicitly encodes whitespace, which not all tokenizers do. E.g. if spm_test corresponds to our tokenizer in tests/assets/m.model, and spm_llama corresponds to the usual Llama2 tokenizer, we can actually see that

>>> spm_test.encode(" ", add_bos=False, add_eos=False)
[]
>>> spm_test.encode("\n", add_bos=False, add_eos=False)
[]
>>> spm_llama.encode(" ", add_bos=False, add_eos=False)
[259]
>>> spm_llama.encode("\n", add_bos=False, add_eos=False)
[29871, 13]

so our test tokenizer does not return any tokens when it only sees whitespace. Because of this, we get different behavior in both cases when splitting a single string.

# When the tokenizer doesn't tokenize whitespace, the results match
>>> spm_test.encode("hi\nthere", add_bos=False, add_eos=False)
[476, 70]
>>> spm_test.encode("hi\n", add_bos=False, add_eos=False) + spm_test.encode("there", add_bos=False, add_eos=False)
[476, 70]

# On the regular Llama2 tokenizer, the results do not match
>>> spm_llama.encode("hi\nthere", add_bos=False, add_eos=False)
[7251, 13, 12711]
>>> spm_llama.encode("hi\n", add_bos=False, add_eos=False) + spm_llama.encode("there", add_bos=False, add_eos=False)
[7251, 13, 727]

Finally, there are some prompts that end with " " (e.g. our GrammarErrorCorrectionTemplate). In this case, the concatenated string actually contains a space, so we do not want to strip it from s2. But we do want to remove it from the preceding string to ensure we don't double-count it. This is the reason for the .rstrip(" ") in L150.

So these are the three conditions we need to check to determine whether to do this special handling, and they are all checked in tokenize_messages below.
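
To make the mechanics concrete, the prefix trick in encode amounts to something like the following simplified sketch; the exact slicing in the real implementation may differ.

# Simplified sketch: encode a known prefix together with the text, then drop the
# prefix's tokens from the front, so the text itself does not pick up the extra
# leading-whitespace token that sentencepiece would otherwise prepend.
# text and spm_model refer to the same objects as in the encode method discussed above.
prefix = "pre"
encoded_prefix = spm_model.encode(prefix)  # cached once in practice
tokens = spm_model.encode(prefix + text)[len(encoded_prefix):]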

mask = []
for message in messages:
# If assistant message, this is the end of a turn
end_of_turn = message.role == "assistant"
Contributor:
I'm probably grossly misunderstanding, but how are start_of_turn and end_of_turn both True here?

Contributor Author:
They shouldn't be. End of turn is just on an assistant message. We set that at the beginning of the iteration, then at the end of the iteration, we set end of turn to False and start of turn to True, since the next iteration will be the start of a turn. Lmk if I'm misunderstanding your question here though

Contributor:
start of turn and end of turn can both be true here if the list of messages the user passes in consists of only a single assistant message. This would not be a valid dialogue, but the BOS and EOS should still be appended correctly. I wonder if we should validate that the dialogue is well-formed here or outside this method

Contributor Author:
Yeah good point. Imo it's not up to the tokenizer to decide this. Let's consider adding a separate utility e.g. validate_messages that can be called from the dataset class or elsewhere as a follow-up.
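
For reference, a minimal sketch of what such a validate_messages utility might look like; this is purely illustrative, and nothing like it exists in the codebase yet.

from typing import List

from torchtune.data._types import Message


def validate_messages(messages: List[Message]) -> None:
    # Hypothetical checks: a dialogue should not start with an assistant message
    # and should not contain two consecutive messages from the same role
    last_role = None
    for i, message in enumerate(messages):
        if i == 0 and message.role == "assistant":
            raise ValueError("A dialogue cannot start with an assistant message")
        if message.role == last_role:
            raise ValueError(f"Two consecutive {message.role} messages at index {i}")
        last_role = message.role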

@RdoubleA (Contributor) left a comment:
Couple of minor comments, otherwise looks good to go! This was a tough one, thanks for pushing this through. Let's make sure we keep track of the follow-up items that were discussed.

@@ -31,8 +31,11 @@
"llama2_7b": "/tmp/test-artifacts/llama2-7b-torchtune.pt",
}

# Inherit from tokenizer class to reuse its tokenize_messages method
Contributor:
did you just say inherit 👀

Contributor Author:
Best way to reduce copy-pasted code 😉. But for this case I do think it's the right thing, since it still allows us to call tokenize_messages (e.g. when testing on datasets) with the stripped-down tokenizer but using equivalent logic. The only usage currently is in test_slimorca.py, but I do kinda like that test, so keeping it like this for now.


@ebsmothers merged commit 8183b42 into main on Apr 2, 2024
20 checks passed
@ebsmothers deleted the datasets-refactor branch on April 2, 2024 at 15:36
tcapelle pushed a commit to tcapelle/torchtune that referenced this pull request Apr 5, 2024